Update search_markets method to apply the total parameter to all types, add tests (#901)

* Update search_markets method to apply the total parameter to all types, fixes #534

* Add integration tests for searching multiple types in multiple markets

* Update search_markets method to apply the total parameter to all types, add tests

---------

Co-authored-by: Stéphane Bruckert <stephane.bruckert@gmail.com>
This commit is contained in:
Richard Ngo-Lam 2023-03-15 16:46:08 -07:00 committed by GitHub
parent f2d23e2219
commit fe438c0432
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 104 additions and 13 deletions

View File

@ -7,12 +7,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## Unreleased ## Unreleased
- Added optional `encoder_cls` argument to `CacheFileHandler`, which overwrite default encoder for token before writing to disk
### Added ### Added
- Added optional `encoder_cls` argument to `CacheFileHandler`, which overwrite default encoder for token before writing to disk
- Integration tests for searching multiple types in multiple markets (non-user endpoints)
### Fixed ### Fixed
- Fixed the regex for matching playlist URIs with the format spotify:user:USERNAME:playlist:PLAYLISTID. - Fixed the regex for matching playlist URIs with the format spotify:user:USERNAME:playlist:PLAYLISTID.
- `search_markets` now factors the counts of all types in the `total` rather than just the first type ([#534](https://github.com/spotipy-dev/spotipy/issues/534))
### Removed ### Removed

View File

@ -15,6 +15,8 @@ import urllib3
from spotipy.exceptions import SpotifyException from spotipy.exceptions import SpotifyException
from collections import defaultdict
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -594,12 +596,12 @@ class Spotify(object):
official documentation https://developer.spotify.com/documentation/web-api/reference/search/) # noqa official documentation https://developer.spotify.com/documentation/web-api/reference/search/) # noqa
- limit - the number of items to return (min = 1, default = 10, max = 50). If a search is to be done on multiple - limit - the number of items to return (min = 1, default = 10, max = 50). If a search is to be done on multiple
markets, then this limit is applied to each market. (e.g. search US, CA, MX each with a limit of 10). markets, then this limit is applied to each market. (e.g. search US, CA, MX each with a limit of 10).
If multiple types are specified, this applies to each type.
- offset - the index of the first item to return - offset - the index of the first item to return
- type - the types of items to return. One or more of 'artist', 'album', - type - the types of items to return. One or more of 'artist', 'album',
'track', 'playlist', 'show', or 'episode'. If multiple types are desired, pass in a comma separated string. 'track', 'playlist', 'show', or 'episode'. If multiple types are desired, pass in a comma separated string.
- markets - A list of ISO 3166-1 alpha-2 country codes. Search all country markets by default. - markets - A list of ISO 3166-1 alpha-2 country codes. Search all country markets by default.
- total - the total number of results to return if multiple markets are supplied in the search. - total - the total number of results to return across multiple markets and types.
If multiple types are specified, this only applies to the first type.
""" """
warnings.warn( warnings.warn(
"Searching multiple markets is an experimental feature. " "Searching multiple markets is an experimental feature. "
@ -2005,22 +2007,29 @@ class Spotify(object):
UserWarning, UserWarning,
) )
results = {} results = defaultdict(dict)
first_type = type.split(",")[0] + 's' item_types = [item_type + "s" for item_type in type.split(",")]
count = 0 count = 0
for country in markets: for country in markets:
result = self._get( result = self._get(
"search", q=q, limit=limit, offset=offset, type=type, market=country "search", q=q, limit=limit, offset=offset, type=type, market=country
) )
results[country] = result for item_type in item_types:
results[country][item_type] = result[item_type]
count += len(result[first_type]['items']) # Truncate the items list to the current limit
if total and count >= total: if len(results[country][item_type]['items']) > limit:
break results[country][item_type]['items'] = \
results[country][item_type]['items'][:limit]
count += len(results[country][item_type]['items'])
if total and limit > total - count: if total and limit > total - count:
# when approaching `total` results, adjust `limit` to not request more # when approaching `total` results, adjust `limit` to not request more
# items than needed # items than needed
limit = total - count limit = total - count
if total and count >= total:
return results
return results return results

View File

@ -221,6 +221,87 @@ class AuthTestSpotipy(unittest.TestCase):
total_limited_results += len(results_limited[country]['artists']['items']) total_limited_results += len(results_limited[country]['artists']['items'])
self.assertTrue(total_limited_results <= total) self.assertTrue(total_limited_results <= total)
def test_multiple_types_search_with_multiple_markets(self):
total = 14
countries_list = ['GB', 'US', 'AU']
countries_tuple = ('GB', 'US', 'AU')
results_multiple = self.spotify.search_markets(q='abba', type='artist,track',
markets=countries_list)
results_all = self.spotify.search_markets(q='abba', type='artist,track')
results_tuple = self.spotify.search_markets(q='abba', type='artist,track',
markets=countries_tuple)
results_limited = self.spotify.search_markets(q='abba', limit=3, type='artist,track',
markets=countries_list, total=total)
# Asserts 'artists' property is present in all responses
self.assertTrue(
all('artists' in results_multiple[country] for country in results_multiple))
self.assertTrue(all('artists' in results_all[country] for country in results_all))
self.assertTrue(all('artists' in results_tuple[country] for country in results_tuple))
self.assertTrue(all('artists' in results_limited[country] for country in results_limited))
# Asserts 'tracks' property is present in all responses
self.assertTrue(
all('tracks' in results_multiple[country] for country in results_multiple))
self.assertTrue(all('tracks' in results_all[country] for country in results_all))
self.assertTrue(all('tracks' in results_tuple[country] for country in results_tuple))
self.assertTrue(all('tracks' in results_limited[country] for country in results_limited))
# Asserts 'artists' list is nonempty in unlimited searches
self.assertTrue(
all(len(results_multiple[country]['artists']['items']) > 0 for country in
results_multiple))
self.assertTrue(all(len(results_all[country]['artists']
['items']) > 0 for country in results_all))
self.assertTrue(
all(len(results_tuple[country]['artists']['items']) > 0 for country in results_tuple))
# Asserts 'tracks' list is nonempty in unlimited searches
self.assertTrue(
all(len(results_multiple[country]['tracks']['items']) > 0 for country in
results_multiple))
self.assertTrue(all(len(results_all[country]['tracks']
['items']) > 0 for country in results_all))
self.assertTrue(all(len(results_tuple[country]['tracks']
['items']) > 0 for country in results_tuple))
# Asserts artist name is the first artist result in all searches
self.assertTrue(all(results_multiple[country]['artists']['items']
[0]['name'] == 'ABBA' for country in results_multiple))
self.assertTrue(all(results_all[country]['artists']['items']
[0]['name'] == 'ABBA' for country in results_all))
self.assertTrue(all(results_tuple[country]['artists']['items']
[0]['name'] == 'ABBA' for country in results_tuple))
self.assertTrue(all(results_limited[country]['artists']['items']
[0]['name'] == 'ABBA' for country in results_limited))
# Asserts track name is present in responses from specified markets
self.assertTrue(all('Dancing Queen' in
[item['name'] for item in results_multiple[country]['tracks']['items']]
for country in results_multiple))
self.assertTrue(all('Dancing Queen' in
[item['name'] for item in results_tuple[country]['tracks']['items']]
for country in results_tuple))
# Asserts expected number of items are returned based on the total
# 3 artists + 3 tracks = 6 items returned from first market
# 3 artists + 3 tracks = 6 items returned from second market
# 2 artists + 0 tracks = 2 items returned from third market
# 14 items returned total
self.assertEqual(len(results_limited['GB']['artists']['items']), 3)
self.assertEqual(len(results_limited['GB']['tracks']['items']), 3)
self.assertEqual(len(results_limited['US']['artists']['items']), 3)
self.assertEqual(len(results_limited['US']['tracks']['items']), 3)
self.assertEqual(len(results_limited['AU']['artists']['items']), 2)
self.assertEqual(len(results_limited['AU']['tracks']['items']), 0)
item_count = sum([len(market_result['artists']['items']) + len(market_result['tracks']
['items']) for market_result in results_limited.values()])
self.assertEqual(item_count, total)
def test_artist_albums(self): def test_artist_albums(self):
results = self.spotify.artist_albums(self.weezer_urn) results = self.spotify.artist_albums(self.weezer_urn)
self.assertTrue('items' in results) self.assertTrue('items' in results)