feat: add types to args

This commit is contained in:
bvandercar-vt 2025-12-02 15:04:44 -07:00
parent 9119b6a070
commit efc3474e27
7 changed files with 419 additions and 284 deletions

View File

@ -11,6 +11,8 @@ import errno
import json
import logging
import os
from json import JSONEncoder
from typing import Optional
from redis import RedisError
@ -49,10 +51,12 @@ class CacheFileHandler(CacheHandler):
as json files on disk.
"""
def __init__(self,
cache_path=None,
username=None,
encoder_cls=None):
def __init__(
self,
cache_path: Optional[str] = None,
username: Optional[str] = None,
encoder_cls: Optional[JSONEncoder] = None,
):
"""
Parameters:
* cache_path: May be supplied, will otherwise be generated
@ -185,7 +189,7 @@ class RedisCacheHandler(CacheHandler):
A cache handler that stores the token info in the Redis.
"""
def __init__(self, redis, key=None):
def __init__(self, redis, key: Optional[str] = None):
"""
Parameters:
* redis: Redis object provided by redis-py library
@ -194,7 +198,7 @@ class RedisCacheHandler(CacheHandler):
(takes precedence over `token_info`)
"""
self.redis = redis
self.key = key if key else 'token_info'
self.key: str = key if key else 'token_info'
def get_cached_token(self):
token_info = None
@ -218,7 +222,7 @@ class MemcacheCacheHandler(CacheHandler):
"""A Cache handler that stores the token info in Memcache using the pymemcache client
"""
def __init__(self, memcache, key=None) -> None:
def __init__(self, memcache, key: Optional[str] = None):
"""
Parameters:
* memcache: memcache client object provided by pymemcache
@ -227,7 +231,7 @@ class MemcacheCacheHandler(CacheHandler):
(takes precedence over `token_info`)
"""
self.memcache = memcache
self.key = key if key else 'token_info'
self.key: str = key if key else 'token_info'
def get_cached_token(self):
from pymemcache import MemcacheError

File diff suppressed because it is too large Load Diff

View File

@ -34,8 +34,16 @@ class SpotifyOauthError(SpotifyBaseException):
class SpotifyStateError(SpotifyOauthError):
""" The state sent and state received were different """
def __init__(self, local_state=None, remote_state=None, message=None,
error=None, error_description=None, *args, **kwargs):
def __init__(
self,
local_state=None,
remote_state=None,
message=None,
error=None,
error_description=None,
*args,
**kwargs,
):
if not message:
message = ("Expected " + local_state + " but received "
+ remote_state)

View File

@ -16,6 +16,7 @@ import urllib.parse as urllibparse
import warnings
import webbrowser
from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import Optional, Union, Any
from urllib.parse import parse_qsl, urlparse
import requests
@ -23,19 +24,21 @@ import requests
from spotipy.cache_handler import CacheFileHandler, CacheHandler
from spotipy.exceptions import SpotifyOauthError, SpotifyStateError
from spotipy.util import (CLIENT_CREDS_ENV_VARS, REQUESTS_SESSION,
get_host_port, normalize_scope)
get_host_port, normalize_scope, ScopeArgType)
logger = logging.getLogger(__name__)
# TODO: improve the types that are "Any"
def _make_authorization_headers(client_id, client_secret):
def _make_authorization_headers(client_id: str, client_secret: str):
auth_header = base64.b64encode(
str(client_id + ":" + client_secret).encode("ascii")
)
return {"Authorization": f"Basic {auth_header.decode('ascii')}"}
def _ensure_value(value, env_key):
def _ensure_value(value: Optional[str], env_key: str) -> str:
env_val = CLIENT_CREDS_ENV_VARS[env_key]
_val = value or os.getenv(env_val)
if _val is None:
@ -45,7 +48,7 @@ def _ensure_value(value, env_key):
class SpotifyAuthBase:
def __init__(self, requests_session):
def __init__(self, requests_session: Optional[Union[requests.Session, bool]] = None):
if isinstance(requests_session, requests.Session):
self._session = requests_session
else:
@ -55,7 +58,7 @@ class SpotifyAuthBase:
from requests import api
self._session = api
def _normalize_scope(self, scope):
def _normalize_scope(self, scope: Optional[ScopeArgType]):
return normalize_scope(scope)
@property
@ -63,7 +66,7 @@ class SpotifyAuthBase:
return self._client_id
@client_id.setter
def client_id(self, val):
def client_id(self, val: Optional[str]):
self._client_id = _ensure_value(val, "client_id")
@property
@ -71,7 +74,7 @@ class SpotifyAuthBase:
return self._client_secret
@client_secret.setter
def client_secret(self, val):
def client_secret(self, val: Optional[str]):
self._client_secret = _ensure_value(val, "client_secret")
@property
@ -79,30 +82,30 @@ class SpotifyAuthBase:
return self._redirect_uri
@redirect_uri.setter
def redirect_uri(self, val):
def redirect_uri(self, val: Optional[str]):
self._redirect_uri = _ensure_value(val, "redirect_uri")
@staticmethod
def _get_user_input(prompt):
def _get_user_input(prompt) -> str:
try:
return raw_input(prompt)
except NameError:
return input(prompt)
@staticmethod
def is_token_expired(token_info):
def is_token_expired(token_info) -> bool:
now = int(time.time())
return token_info["expires_at"] - now < 60
@staticmethod
def _is_scope_subset(needle_scope, haystack_scope):
def _is_scope_subset(
needle_scope: Optional[str], haystack_scope: Optional[str]
) -> bool:
needle_scope = set(needle_scope.split()) if needle_scope else set()
haystack_scope = (
set(haystack_scope.split()) if haystack_scope else set()
)
haystack_scope = set(haystack_scope.split()) if haystack_scope else set()
return needle_scope <= haystack_scope
def _handle_oauth_error(self, http_error):
def _handle_oauth_error(self, http_error: requests.exceptions.HTTPError):
response = http_error.response
try:
error_payload = response.json()
@ -133,12 +136,12 @@ class SpotifyClientCredentials(SpotifyAuthBase):
def __init__(
self,
client_id=None,
client_secret=None,
proxies=None,
requests_session=True,
requests_timeout=None,
cache_handler=None
client_id: Optional[str] = None,
client_secret: Optional[str] = None,
proxies: Optional[Any] = None,
requests_session: Union[requests.Session, bool] = True,
requests_timeout: Optional[int] = None,
cache_handler: Optional[CacheHandler] = None,
):
"""
Creates a Client Credentials Flow Manager.
@ -181,7 +184,8 @@ class SpotifyClientCredentials(SpotifyAuthBase):
else:
self.cache_handler = CacheFileHandler()
def get_access_token(self, as_dict=True, check_cache=True):
# TODO: better return type based oninput type (overrides)
def get_access_token(self, as_dict: bool = True, check_cache: bool = True):
"""
If a valid access token is in memory, returns it
Else fetches a new token and returns it
@ -255,19 +259,19 @@ class SpotifyOAuth(SpotifyAuthBase):
def __init__(
self,
client_id=None,
client_secret=None,
redirect_uri=None,
state=None,
scope=None,
cache_path=None,
username=None,
proxies=None,
show_dialog=False,
requests_session=True,
requests_timeout=None,
open_browser=True,
cache_handler=None
client_id: Optional[str] = None,
client_secret: Optional[str] = None,
redirect_uri: Optional[str] = None,
state: Optional[Any] = None,
scope: Optional[ScopeArgType] = None,
cache_path: Optional[str] = None,
username: Optional[str] = None,
proxies: Optional[Any] = None,
show_dialog: bool = False,
requests_session: Union[requests.Session, bool] = True,
requests_timeout: Optional[int] = None,
open_browser: bool = True,
cache_handler: Optional[CacheHandler] = None,
):
"""
Creates a SpotifyOAuth object
@ -352,7 +356,7 @@ class SpotifyOAuth(SpotifyAuthBase):
return token_info
def get_authorize_url(self, state=None):
def get_authorize_url(self, state: Optional[Any] = None) -> str:
""" Gets the URL to use to authorize this app
"""
payload = {
@ -405,7 +409,7 @@ class SpotifyOAuth(SpotifyAuthBase):
except webbrowser.Error:
logger.error(f"Please navigate here: {auth_url}")
def _get_auth_response_interactive(self, open_browser=False):
def _get_auth_response_interactive(self, open_browser: bool = False):
if open_browser:
self._open_auth_url()
prompt = "Enter the URL you were redirected to: "
@ -435,7 +439,7 @@ class SpotifyOAuth(SpotifyAuthBase):
else:
raise SpotifyOauthError("Server listening on localhost has not been accessed")
def get_auth_response(self, open_browser=None):
def get_auth_response(self, open_browser: Optional[bool] = None):
logger.info('User authentication requires interaction with your '
'web browser. Once you enter your credentials and '
'give authorization, you will be redirected to '
@ -476,12 +480,14 @@ class SpotifyOAuth(SpotifyAuthBase):
return self._get_auth_response_interactive(open_browser=open_browser)
def get_authorization_code(self, response=None):
def get_authorization_code(self, response: Optional[Any] = None):
if response:
return self.parse_response_code(response)
return self.get_auth_response()
def get_access_token(self, code=None, as_dict=True, check_cache=True):
def get_access_token(
self, code: Optional[Any] = None, as_dict: bool = True, check_cache: bool = True
):
""" Gets the access token for the app given the code
Parameters:
@ -540,7 +546,7 @@ class SpotifyOAuth(SpotifyAuthBase):
except requests.exceptions.HTTPError as http_error:
self._handle_oauth_error(http_error)
def refresh_access_token(self, refresh_token):
def refresh_access_token(self, refresh_token: str):
payload = {
"refresh_token": refresh_token,
"grant_type": "refresh_token",
@ -619,18 +625,20 @@ class SpotifyPKCE(SpotifyAuthBase):
OAUTH_AUTHORIZE_URL = "https://accounts.spotify.com/authorize"
OAUTH_TOKEN_URL = "https://accounts.spotify.com/api/token"
def __init__(self,
client_id=None,
redirect_uri=None,
state=None,
scope=None,
cache_path=None,
username=None,
proxies=None,
requests_timeout=None,
requests_session=True,
open_browser=True,
cache_handler=None):
def __init__(
self,
client_id: Optional[str] = None,
redirect_uri: Optional[str] = None,
state: Optional[Any] = None,
scope: Optional[ScopeArgType] = None,
cache_path: Optional[str] = None,
username: Optional[str] = None,
proxies: Optional[Any] = None,
requests_timeout: Optional[int] = None,
requests_session: Union[requests.Session, bool] = True,
open_browser: bool = True,
cache_handler: Optional[CacheHandler] = None,
):
"""
Creates Auth Manager with the PKCE Auth flow.
@ -695,7 +703,7 @@ class SpotifyPKCE(SpotifyAuthBase):
self.authorization_code = None
self.open_browser = open_browser
def _get_code_verifier(self):
def _get_code_verifier(self) -> str:
""" Spotify PCKE code verifier - See step 1 of the reference guide below
Reference:
https://developer.spotify.com/documentation/general/guides/authorization-guide/#authorization-code-flow-with-proof-key-for-code-exchange-pkce
@ -709,7 +717,7 @@ class SpotifyPKCE(SpotifyAuthBase):
import secrets
return secrets.token_urlsafe(length)
def _get_code_challenge(self):
def _get_code_challenge(self) -> str:
""" Spotify PCKE code challenge - See step 1 of the reference guide below
Reference:
https://developer.spotify.com/documentation/general/guides/authorization-guide/#authorization-code-flow-with-proof-key-for-code-exchange-pkce
@ -720,7 +728,7 @@ class SpotifyPKCE(SpotifyAuthBase):
code_challenge = base64.urlsafe_b64encode(code_challenge_digest).decode('utf-8')
return code_challenge.replace('=', '')
def get_authorize_url(self, state=None):
def get_authorize_url(self, state: Optional[Any] = None) -> str:
""" Gets the URL to use to authorize this app """
if not self.code_challenge:
self.get_pkce_handshake_parameters()
@ -740,7 +748,7 @@ class SpotifyPKCE(SpotifyAuthBase):
urlparams = urllibparse.urlencode(payload)
return f"{self.OAUTH_AUTHORIZE_URL}?{urlparams}"
def _open_auth_url(self, state=None):
def _open_auth_url(self, state: Optional[Any] = None):
auth_url = self.get_authorize_url(state)
try:
webbrowser.open(auth_url)
@ -748,7 +756,7 @@ class SpotifyPKCE(SpotifyAuthBase):
except webbrowser.Error:
logger.error(f"Please navigate here: {auth_url}")
def _get_auth_response(self, open_browser=None):
def _get_auth_response(self, open_browser: Optional[bool] = None):
logger.info('User authentication requires interaction with your '
'web browser. Once you enter your credentials and '
'give authorization, you will be redirected to '
@ -803,7 +811,7 @@ class SpotifyPKCE(SpotifyAuthBase):
else:
raise SpotifyOauthError("Server listening on localhost has not been accessed")
def _get_auth_response_interactive(self, open_browser=False):
def _get_auth_response_interactive(self, open_browser: bool = False):
if open_browser or self.open_browser:
self._open_auth_url()
prompt = "Enter the URL you were redirected to: "
@ -817,7 +825,7 @@ class SpotifyPKCE(SpotifyAuthBase):
raise SpotifyStateError(self.state, state)
return code
def get_authorization_code(self, response=None):
def get_authorization_code(self, response: Optional[Any] = None):
if response:
return self.parse_response_code(response)
return self._get_auth_response()
@ -851,7 +859,7 @@ class SpotifyPKCE(SpotifyAuthBase):
self.code_verifier = self._get_code_verifier()
self.code_challenge = self._get_code_challenge()
def get_access_token(self, code=None, check_cache=True):
def get_access_token(self, code: Optional[Any] = None, check_cache: bool = True):
""" Gets the access token for the app
If the code is not given and no cached token is used, an
@ -906,7 +914,7 @@ class SpotifyPKCE(SpotifyAuthBase):
except requests.exceptions.HTTPError as http_error:
self._handle_oauth_error(http_error)
def refresh_access_token(self, refresh_token):
def refresh_access_token(self, refresh_token: str):
payload = {
"refresh_token": refresh_token,
"grant_type": "refresh_token",
@ -1008,15 +1016,17 @@ class SpotifyImplicitGrant(SpotifyAuthBase):
"""
OAUTH_AUTHORIZE_URL = "https://accounts.spotify.com/authorize"
def __init__(self,
client_id=None,
redirect_uri=None,
state=None,
scope=None,
cache_path=None,
username=None,
show_dialog=False,
cache_handler=None):
def __init__(
self,
client_id: Optional[str] = None,
redirect_uri: Optional[str] = None,
state: Optional[Any] = None,
scope: Optional[ScopeArgType] = None,
cache_path: Optional[str] = None,
username: Optional[str] = None,
show_dialog: bool = False,
cache_handler: Optional[CacheHandler] = None,
):
""" Creates Auth Manager using the Implicit Grant flow
**See help(SpotifyImplicitGrant) for full Security Warning**
@ -1092,10 +1102,12 @@ class SpotifyImplicitGrant(SpotifyAuthBase):
return token_info
def get_access_token(self,
state=None,
response=None,
check_cache=True):
def get_access_token(
self,
state: Optional[Any] = None,
response: Optional[Any] = None,
check_cache: bool = True,
):
""" Gets Auth Token from cache (preferred) or user interaction
Parameters
@ -1118,7 +1130,7 @@ class SpotifyImplicitGrant(SpotifyAuthBase):
return token_info["access_token"]
def get_authorize_url(self, state=None):
def get_authorize_url(self, state: Optional[Any] = None) -> str:
""" Gets the URL to use to authorize this app """
payload = {
"client_id": self.client_id,
@ -1138,7 +1150,7 @@ class SpotifyImplicitGrant(SpotifyAuthBase):
return f"{self.OAUTH_AUTHORIZE_URL}?{urlparams}"
def parse_response_token(self, url, state=None):
def parse_response_token(self, url, state: Optional[Any] = None):
""" Parse the response code in the given response url """
remote_state, token, t_type, exp_in = self.parse_auth_response_url(url)
if state is None:
@ -1163,7 +1175,7 @@ class SpotifyImplicitGrant(SpotifyAuthBase):
return tuple(form.get(param) for param in ["state", "access_token",
"token_type", "expires_in"])
def _open_auth_url(self, state=None):
def _open_auth_url(self, state: Optional[Any] = None):
auth_url = self.get_authorize_url(state)
try:
webbrowser.open(auth_url)
@ -1171,7 +1183,7 @@ class SpotifyImplicitGrant(SpotifyAuthBase):
except webbrowser.Error:
logger.error(f"Please navigate here: {auth_url}")
def get_auth_response(self, state=None):
def get_auth_response(self, state: Optional[Any] = None):
""" Gets a new auth **token** with user interaction """
logger.info('User authentication requires interaction with your '
'web browser. Once you enter your credentials and '
@ -1274,7 +1286,7 @@ Close Window
</body>
</html>""")
def _write(self, text):
def _write(self, text: str):
return self.wfile.write(text.encode("utf-8"))
def log_message(self, format, *args):

View File

@ -8,6 +8,7 @@ import logging
import os
import warnings
from types import TracebackType
from typing import Optional, Union, Tuple, List
import requests
import urllib3
@ -26,17 +27,19 @@ CLIENT_CREDS_ENV_VARS = {
# workaround for garbage collection
REQUESTS_SESSION = requests.Session
StrListOrTuple = Union[List[str], Tuple[str, ...]]
def prompt_for_user_token(
username=None,
scope=None,
client_id=None,
client_secret=None,
redirect_uri=None,
cache_path=None,
oauth_manager=None,
show_dialog=False
):
username: Optional[str] = None,
scope: Optional[str] = None,
client_id: Optional[str] = None,
client_secret: Optional[str] = None,
redirect_uri: Optional[str] = None,
cache_path: Optional[str] = None,
oauth_manager: Optional[spotipy.SpotifyOAuth] = None,
show_dialog: bool = False,
) -> Union[str, None]:
""" Prompt the user to login if necessary and returns a user token
suitable for use with the spotipy.Spotify constructor.
@ -116,7 +119,7 @@ def prompt_for_user_token(
return None
def get_host_port(netloc):
def get_host_port(netloc: str):
""" Split the network location string into host and port and returns a tuple
where the host is a string and the the port is an integer.
@ -132,8 +135,9 @@ def get_host_port(netloc):
return host, port
ScopeArgType = Union[str, StrListOrTuple]
def normalize_scope(scope):
def normalize_scope(scope: Optional[ScopeArgType]) -> Union[str, None]:
"""Normalize the scope to verify that it is a list or tuple. A string
input will split the string by commas to create a list of scopes.
A list or tuple input is used directly.
@ -164,12 +168,12 @@ class Retry(urllib3.Retry):
def increment(
self,
method: str | None = None,
url: str | None = None,
response: urllib3.BaseHTTPResponse | None = None,
error: Exception | None = None,
_pool: urllib3.connectionpool.ConnectionPool | None = None,
_stacktrace: TracebackType | None = None,
method: Optional[str] = None,
url: Optional[str] = None,
response: Optional[urllib3.BaseHTTPResponse] = None,
error: Optional[Exception] = None,
_pool: Optional[urllib3.connectionpool.ConnectionPool] = None,
_stacktrace: Optional[TracebackType] = None,
) -> urllib3.Retry:
if response:
retry_header = response.headers.get("Retry-After")

View File

@ -1,9 +1,12 @@
import base64
from typing import Union
import requests
from spotipy import Spotify
def get_spotify_playlist(spotify_object, playlist_name, username):
def get_spotify_playlist(spotify_object: Spotify, playlist_name: str, username: str):
playlists = spotify_object.user_playlists(username)
while playlists:
for item in playlists['items']:
@ -12,5 +15,5 @@ def get_spotify_playlist(spotify_object, playlist_name, username):
playlists = spotify_object.next(playlists)
def get_as_base64(url):
def get_as_base64(url: Union[str, bytes]) -> str:
return base64.b64encode(requests.get(url).content).decode("utf-8")

View File

@ -13,7 +13,7 @@ patch = mock.patch
DEFAULT = mock.DEFAULT
def _make_fake_token(expires_at, expires_in, scope):
def _make_fake_token(expires_at: int, expires_in: int, scope: str):
return dict(
expires_at=expires_at,
expires_in=expires_in,