From 51408922885cd3ce31a07b7f3e90981d2b26c0ef Mon Sep 17 00:00:00 2001 From: aj Date: Thu, 12 Sep 2019 08:41:48 +0100 Subject: [PATCH] return type annotations --- spotframework/engine/playlistengine.py | 17 ++--- spotframework/engine/processor/abstract.py | 14 ++-- spotframework/engine/processor/added.py | 5 +- spotframework/engine/processor/deduplicate.py | 4 +- spotframework/engine/processor/popularity.py | 2 +- spotframework/engine/processor/shuffle.py | 4 +- spotframework/engine/processor/sort.py | 4 +- spotframework/model/album.py | 2 +- spotframework/model/playlist.py | 4 +- spotframework/model/track.py | 4 +- spotframework/net/network.py | 68 ++++++++++++------- spotframework/net/parse/parse.py | 12 ++-- spotframework/net/user.py | 7 +- spotframework/util/monthstrings.py | 6 +- 14 files changed, 88 insertions(+), 65 deletions(-) diff --git a/spotframework/engine/playlistengine.py b/spotframework/engine/playlistengine.py index f16c2bf..8338291 100644 --- a/spotframework/engine/playlistengine.py +++ b/spotframework/engine/playlistengine.py @@ -5,12 +5,13 @@ import logging import spotframework.util.monthstrings as monthstrings from spotframework.engine.processor.added import AddedSince -from typing import List +from typing import List, Optional from spotframework.model.track import SpotifyTrack from spotframework.model.playlist import SpotifyPlaylist from spotframework.net.network import Network from spotframework.engine.processor.abstract import AbstractProcessor from datetime import datetime +from requests.models import Response logger = logging.getLogger(__name__) @@ -21,7 +22,7 @@ class PlaylistEngine: self.playlists = [] self.net = net - def load_user_playlists(self): + def load_user_playlists(self) -> None: logger.info('loading') playlists = self.net.get_playlists() @@ -30,7 +31,7 @@ class PlaylistEngine: else: logger.error('error getting playlists') - def append_user_playlists(self): + def append_user_playlists(self) -> None: logger.info('loading') playlists = self.net.get_playlists() @@ -40,7 +41,7 @@ class PlaylistEngine: logger.error('error getting playlists') def get_playlist_tracks(self, - playlist: SpotifyPlaylist): + playlist: SpotifyPlaylist) -> None: logger.info(f"pulling tracks for {playlist.name}") tracks = self.net.get_playlist_tracks(playlist.playlist_id) @@ -53,7 +54,7 @@ class PlaylistEngine: playlist_parts: List[str], processors: List[AbstractProcessor] = None, include_recommendations: bool = False, - recommendation_limit: int = 10): + recommendation_limit: int = 10) -> List[SpotifyTrack]: if processors is None: processors = [] @@ -102,7 +103,7 @@ class PlaylistEngine: include_recommendations: bool = False, recommendation_limit: int = 10, add_this_month: bool = False, - add_last_month: bool = False): + add_last_month: bool = False) -> List[SpotifyTrack]: if processors is None: processors = [] @@ -129,7 +130,7 @@ class PlaylistEngine: def execute_playlist(self, tracks: List[SpotifyTrack], - playlist_id: str): + playlist_id: str) -> Optional[Response]: resp = self.net.replace_playlist_tracks(playlist_id, [i.uri for i in tracks]) if resp: @@ -142,7 +143,7 @@ class PlaylistEngine: playlistparts: List[str], playlist_id: str, overwrite: bool = None, - suffix: str = None): + suffix: str = None) -> Optional[Response]: if overwrite: string = overwrite diff --git a/spotframework/engine/processor/abstract.py b/spotframework/engine/processor/abstract.py index 98974e3..820192f 100644 --- a/spotframework/engine/processor/abstract.py +++ b/spotframework/engine/processor/abstract.py @@ -8,24 +8,24 @@ class AbstractProcessor(ABC): def __init__(self, names: List[str] = None): self.playlist_names = names - def has_targets(self): + def has_targets(self) -> bool: if self.playlist_names: return True else: return False @abstractmethod - def process(self, tracks: List[Track]): + def process(self, tracks: List[Track]) -> List[Track]: pass class BatchSingleProcessor(AbstractProcessor, ABC): @staticmethod - def process_single(track: Track): + def process_single(track: Track) -> Track: return track - def process_batch(self, tracks: List[Track]): + def process_batch(self, tracks: List[Track]) -> List[Track]: processed = [] for track in tracks: @@ -34,8 +34,8 @@ class BatchSingleProcessor(AbstractProcessor, ABC): return processed - def process(self, tracks: List[Track]): - return [i for i in self.process_batch(tracks) if i] + def process(self, tracks: List[Track]) -> List[Track]: + return [i for i in self.process_batch(tracks) if i is not None] class BatchSingleTypeAwareProcessor(BatchSingleProcessor, ABC): @@ -48,7 +48,7 @@ class BatchSingleTypeAwareProcessor(BatchSingleProcessor, ABC): self.instance_check = instance_check self.append_malformed = append_malformed - def process(self, tracks: List[Track]): + def process(self, tracks: List[Track]) -> List[Track]: if self.instance_check: return_tracks = [] diff --git a/spotframework/engine/processor/added.py b/spotframework/engine/processor/added.py index 75c46bc..72eb602 100644 --- a/spotframework/engine/processor/added.py +++ b/spotframework/engine/processor/added.py @@ -2,6 +2,7 @@ from .abstract import BatchSingleTypeAwareProcessor import datetime from typing import List from spotframework.model.track import PlaylistTrack +from typing import Optional class Added(BatchSingleTypeAwareProcessor): @@ -17,12 +18,12 @@ class Added(BatchSingleTypeAwareProcessor): class AddedBefore(Added): - def process_single(self, track: PlaylistTrack): + def process_single(self, track: PlaylistTrack) -> Optional[PlaylistTrack]: if track.added_at < self.boundary: return track class AddedSince(Added): - def process_single(self, track: PlaylistTrack): + def process_single(self, track: PlaylistTrack) -> Optional[PlaylistTrack]: if track.added_at > self.boundary: return track diff --git a/spotframework/engine/processor/deduplicate.py b/spotframework/engine/processor/deduplicate.py index 90beeee..b87ab40 100644 --- a/spotframework/engine/processor/deduplicate.py +++ b/spotframework/engine/processor/deduplicate.py @@ -12,7 +12,7 @@ class DeduplicateByID(BatchSingleTypeAwareProcessor): instance_check=SpotifyTrack, append_malformed=append_malformed) - def process_batch(self, tracks: List[SpotifyTrack]): + def process_batch(self, tracks: List[SpotifyTrack]) -> List[SpotifyTrack]: return_tracks = [] for track in tracks: @@ -24,7 +24,7 @@ class DeduplicateByID(BatchSingleTypeAwareProcessor): class DeduplicateByName(BatchSingleProcessor): - def process_batch(self, tracks: List[Track]): + def process_batch(self, tracks: List[Track]) -> List[Track]: return_tracks = [] for to_check in tracks: diff --git a/spotframework/engine/processor/popularity.py b/spotframework/engine/processor/popularity.py index 43aaf6f..310d0ed 100644 --- a/spotframework/engine/processor/popularity.py +++ b/spotframework/engine/processor/popularity.py @@ -14,6 +14,6 @@ class SortPopularity(BatchSingleTypeAwareProcessor): append_malformed=append_malformed) self.reverse = reverse - def process_batch(self, tracks: List[SpotifyTrack]): + def process_batch(self, tracks: List[SpotifyTrack]) -> List[SpotifyTrack]: tracks.sort(key=lambda x: x.popularity, reverse=self.reverse) return tracks diff --git a/spotframework/engine/processor/shuffle.py b/spotframework/engine/processor/shuffle.py index 11f0490..e94c0fc 100644 --- a/spotframework/engine/processor/shuffle.py +++ b/spotframework/engine/processor/shuffle.py @@ -6,7 +6,7 @@ from spotframework.model.track import Track class Shuffle(AbstractProcessor): - def process(self, tracks: List[Track]): + def process(self, tracks: List[Track]) -> List[Track]: random.shuffle(tracks) return tracks @@ -19,5 +19,5 @@ class RandomSample(Shuffle): super().__init__(names) self.sample_size = sample_size - def process(self, tracks: List[Track]): + def process(self, tracks: List[Track]) -> List[Track]: return super().process(tracks)[:self.sample_size] diff --git a/spotframework/engine/processor/sort.py b/spotframework/engine/processor/sort.py index 49ea780..0ec6bf3 100644 --- a/spotframework/engine/processor/sort.py +++ b/spotframework/engine/processor/sort.py @@ -14,13 +14,13 @@ class BasicReversibleSort(AbstractProcessor, ABC): class SortReleaseDate(BasicReversibleSort): - def process(self, tracks: List[Track]): + def process(self, tracks: List[Track]) -> List[Track]: tracks.sort(key=lambda x: x.album.release_date, reverse=self.reverse) return tracks class SortArtistName(BasicReversibleSort): - def process(self, tracks: List[Track]): + def process(self, tracks: List[Track]) -> List[Track]: tracks.sort(key=lambda x: x.artists[0].name, reverse=self.reverse) return tracks diff --git a/spotframework/model/album.py b/spotframework/model/album.py index 898187b..ccd6f68 100644 --- a/spotframework/model/album.py +++ b/spotframework/model/album.py @@ -11,7 +11,7 @@ class Album: self.artists = artists @property - def artists_names(self): + def artists_names(self) -> str: return self._join_strings([i.name for i in self.artists]) @staticmethod diff --git a/spotframework/model/playlist.py b/spotframework/model/playlist.py index 26009bc..0157bbf 100644 --- a/spotframework/model/playlist.py +++ b/spotframework/model/playlist.py @@ -16,7 +16,7 @@ class Playlist: self.name = name self.description = description - def has_tracks(self): + def has_tracks(self) -> bool: if len(self.tracks) > 0: return True else: @@ -26,7 +26,7 @@ class Playlist: return len(self.tracks) @property - def tracks(self): + def tracks(self) -> List[Track]: return self._tracks @tracks.setter diff --git a/spotframework/model/track.py b/spotframework/model/track.py index 5f540d9..a3ac1fd 100644 --- a/spotframework/model/track.py +++ b/spotframework/model/track.py @@ -27,11 +27,11 @@ class Track: self.explicit = excplicit @property - def artists_names(self): + def artists_names(self) -> str: return self._join_strings([i.name for i in self.artists]) @property - def album_artists_names(self): + def album_artists_names(self) -> str: return self.album.artists_names @staticmethod diff --git a/spotframework/net/network.py b/spotframework/net/network.py index 5ad8e09..fda65a4 100644 --- a/spotframework/net/network.py +++ b/spotframework/net/network.py @@ -2,10 +2,12 @@ import requests import random import logging import time -from typing import List +from typing import List, Optional from . import const from spotframework.net.parse import parse from spotframework.model.playlist import SpotifyPlaylist +from spotframework.model.track import Track, PlaylistTrack +from requests.models import Response limit = 50 @@ -17,7 +19,7 @@ class Network: def __init__(self, user): self.user = user - def _make_get_request(self, method, url, params=None, headers={}): + def _make_get_request(self, method, url, params=None, headers={}) -> Optional[dict]: headers['Authorization'] = 'Bearer ' + self.user.accesstoken @@ -44,7 +46,7 @@ class Network: return None - def _make_post_request(self, method, url, params=None, json=None, headers={}): + def _make_post_request(self, method, url, params=None, json=None, headers={}) -> Optional[Response]: headers['Authorization'] = 'Bearer ' + self.user.accesstoken @@ -71,7 +73,7 @@ class Network: return None - def _make_put_request(self, method, url, params=None, json=None, headers={}): + def _make_put_request(self, method, url, params=None, json=None, headers={}) -> Optional[Response]: headers['Authorization'] = 'Bearer ' + self.user.accesstoken @@ -98,7 +100,7 @@ class Network: return None - def get_playlist(self, playlistid: str): + def get_playlist(self, playlistid: str) -> Optional[SpotifyPlaylist]: logger.info(f"{playlistid}") @@ -114,7 +116,12 @@ class Network: logger.error(f"{playlistid} - no tracks returned") return None - def create_playlist(self, username, name='New Playlist', public=True, collaborative=False, description=None): + def create_playlist(self, + username, + name='New Playlist', + public=True, + collaborative=False, + description=None) -> Optional[dict]: json = {"name": name, "public": public, "collaborative": collaborative} @@ -129,7 +136,7 @@ class Network: logger.error('error creating playlist') return None - def get_playlists(self, offset=0): + def get_playlists(self, offset=0) -> Optional[List[SpotifyPlaylist]]: logger.info(f"{offset}") @@ -155,7 +162,7 @@ class Network: logger.error(f'error getting playlists offset={offset}') return None - def get_user_playlists(self): + def get_user_playlists(self) -> Optional[List[SpotifyPlaylist]]: logger.info('retrieved') @@ -167,7 +174,7 @@ class Network: logger.error('no playlists returned to filter') return None - def get_playlist_tracks(self, playlistid, offset=0): + def get_playlist_tracks(self, playlistid, offset=0) -> List[PlaylistTrack]: logger.info(f"{playlistid}{' ' + str(offset) if offset is not 0 else ''}") @@ -192,7 +199,7 @@ class Network: return tracks - def get_available_devices(self): + def get_available_devices(self) -> Optional[dict]: logger.info("retrieving") @@ -203,7 +210,7 @@ class Network: logger.error('no devices returned') return None - def get_player(self): + def get_player(self) -> Optional[dict]: logger.info("retrieved") @@ -214,7 +221,7 @@ class Network: logger.error('no player returned') return None - def get_device_id(self, devicename): + def get_device_id(self, devicename) -> Optional[str]: logger.info(f"{devicename}") @@ -225,7 +232,7 @@ class Network: logger.error('no devices returned') return None - def play(self, uri=None, uris=None, deviceid=None): + def play(self, uri=None, uris=None, deviceid=None) -> Optional[Response]: logger.info(f"{uri}{' ' + deviceid if deviceid is not None else ''}") @@ -247,10 +254,12 @@ class Network: raise Exception('need either context uri or uris') req = self._make_put_request('play', 'me/player/play', params=params, json=payload) - if req is None: + if req: + return req + else: logger.error('error playing') - def pause(self, deviceid=None): + def pause(self, deviceid=None) -> Optional[Response]: logger.info(f"{deviceid if deviceid is not None else ''}") @@ -260,10 +269,12 @@ class Network: params = None req = self._make_put_request('pause', 'me/player/pause', params=params) - if req is None: + if req: + return req + else: logger.error('error pausing') - def next(self, deviceid=None): + def next(self, deviceid=None) -> Optional[Response]: logger.info(f"{deviceid if deviceid is not None else ''}") @@ -273,10 +284,12 @@ class Network: params = None req = self._make_post_request('next', 'me/player/next', params=params) - if req is None: + if req: + return req + else: logger.error('error skipping') - def set_shuffle(self, state, deviceid=None): + def set_shuffle(self, state, deviceid=None) -> Optional[Response]: logger.info(f"{state}{' ' + deviceid if deviceid is not None else ''}") @@ -286,10 +299,12 @@ class Network: params['device_id'] = deviceid req = self._make_put_request('setShuffle', 'me/player/shuffle', params=params) - if req is None: + if req: + return req + else: logger.error(f'error setting shuffle {state}') - def set_volume(self, volume, deviceid=None): + def set_volume(self, volume, deviceid=None) -> Optional[Response]: logger.info(f"{volume}{' ' + deviceid if deviceid is not None else ''}") @@ -330,7 +345,12 @@ class Network: else: logger.error(f'error replacing playlist tracks, total: {len(uris)}') - def change_playlist_details(self, playlistid, name=None, public=None, collaborative=None, description=None): + def change_playlist_details(self, + playlistid, + name=None, + public=None, + collaborative=None, + description=None) -> Optional[Response]: logger.info(f"{playlistid}") @@ -361,7 +381,7 @@ class Network: logger.error('error updating details') return None - def add_playlist_tracks(self, playlistid: str, uris: List[str]): + def add_playlist_tracks(self, playlistid: str, uris: List[str]) -> List[dict]: logger.info(f"{playlistid}") @@ -386,7 +406,7 @@ class Network: logger.error(f'error retrieving tracks {playlistid}, total: {len(uris)}') return [] - def get_recommendations(self, tracks=None, artists=None, response_limit=10): + def get_recommendations(self, tracks=None, artists=None, response_limit=10) -> Optional[List[Track]]: logger.info(f'sample size: {response_limit}') diff --git a/spotframework/net/parse/parse.py b/spotframework/net/parse/parse.py index 8766a13..b89c353 100644 --- a/spotframework/net/parse/parse.py +++ b/spotframework/net/parse/parse.py @@ -1,11 +1,11 @@ -from spotframework.model.artist import Artist, SpotifyArtist -from spotframework.model.album import Album, SpotifyAlbum +from spotframework.model.artist import SpotifyArtist +from spotframework.model.album import SpotifyAlbum from spotframework.model.track import Track, SpotifyTrack, PlaylistTrack from spotframework.model.playlist import SpotifyPlaylist from spotframework.model.user import User -def parse_artist(artist_dict) -> Artist: +def parse_artist(artist_dict) -> SpotifyArtist: name = artist_dict.get('name', None) @@ -28,7 +28,7 @@ def parse_artist(artist_dict) -> Artist: popularity=popularity) -def parse_album(album_dict) -> Album: +def parse_album(album_dict) -> SpotifyAlbum: name = album_dict.get('name', None) if name is None: @@ -139,7 +139,7 @@ def parse_track(track_dict) -> Track: popularity=popularity) -def parse_user(user_dict): +def parse_user(user_dict) -> User: display_name = user_dict.get('display_name', None) spotify_id = user_dict.get('id', None) @@ -152,7 +152,7 @@ def parse_user(user_dict): display_name=display_name) -def parse_playlist(playlist_dict): +def parse_playlist(playlist_dict) -> SpotifyPlaylist: collaborative = playlist_dict.get('collaborative', None) diff --git a/spotframework/net/user.py b/spotframework/net/user.py index baaa8a9..38b025e 100644 --- a/spotframework/net/user.py +++ b/spotframework/net/user.py @@ -3,6 +3,7 @@ from spotframework.model.user import User from base64 import b64encode import logging import time +from typing import Optional logger = logging.getLogger(__name__) @@ -21,7 +22,7 @@ class NetworkUser(User): self.refresh_token() self.refresh_info() - def refresh_token(self): + def refresh_token(self) -> None: if self.refreshtoken is None: raise NameError('no refresh token to query') @@ -58,7 +59,7 @@ class NetworkUser(User): error_text = req.json()['error']['message'] logger.error(f'refresh_token get {req.status_code} {error_text}') - def refresh_info(self): + def refresh_info(self) -> None: info = self.get_info() if info.get('display_name', None): @@ -77,7 +78,7 @@ class NetworkUser(User): if info.get('uri', None): self.uri = info['uri'] - def get_info(self): + def get_info(self) -> Optional[dict]: headers = {'Authorization': 'Bearer %s' % self.accesstoken} diff --git a/spotframework/util/monthstrings.py b/spotframework/util/monthstrings.py index 580ff9e..a4cd74c 100644 --- a/spotframework/util/monthstrings.py +++ b/spotframework/util/monthstrings.py @@ -1,14 +1,14 @@ import datetime -def get_this_month(): +def get_this_month() -> str: return datetime.date.today().strftime('%B %y').lower() -def get_last_month(): +def get_last_month() -> str: month = datetime.date.today().replace(day=1) - datetime.timedelta(days=1) return month.strftime('%B %y').lower() -def get_this_year(): +def get_this_year() -> str: return datetime.date.today().strftime('%y')