From cceba0bd78c384c700ae8861dac57d5de7850ab2 Mon Sep 17 00:00:00 2001 From: aj Date: Fri, 7 Aug 2020 10:59:46 +0100 Subject: [PATCH] filter json by dataclass keys, multi-type uris on network methods, introducing playlists --- backup.py | 2 +- spotframework/engine/playlistengine.py | 13 +- spotframework/model/__init__.py | 29 ++++ spotframework/model/album.py | 44 ++---- spotframework/model/artist.py | 13 +- spotframework/model/playlist.py | 9 +- spotframework/model/podcast.py | 110 +++++++++++++++ spotframework/model/track.py | 60 +++++--- spotframework/model/uri.py | 2 + spotframework/model/user.py | 5 +- spotframework/net/network.py | 185 +++++++++---------------- spotframework/util/__init__.py | 10 ++ spotframework/util/decorators.py | 30 ++++ 13 files changed, 322 insertions(+), 190 deletions(-) create mode 100644 spotframework/model/podcast.py diff --git a/backup.py b/backup.py index a0aa7f3..7412ea2 100644 --- a/backup.py +++ b/backup.py @@ -35,7 +35,7 @@ if __name__ == '__main__': for playlist in playlists: try: - playlist.tracks = network.get_playlist_tracks(playlist.uri) + playlist.tracks = network.get_playlist_tracks(uri=playlist.uri) csvwrite.export_playlist(playlist, totalpath) except SpotifyNetworkException: logger.exception(f'error occured during {playlist.name} track retrieval') diff --git a/spotframework/engine/playlistengine.py b/spotframework/engine/playlistengine.py index d1a497f..877a666 100644 --- a/spotframework/engine/playlistengine.py +++ b/spotframework/engine/playlistengine.py @@ -145,9 +145,10 @@ class PlaylistEngine: counter_track = track if counter_track != tracks_to_sort[0]: - self.net.reorder_playlist_tracks(playlist.uri, - i + tracks_to_sort.index(counter_track), - 1, i) + self.net.reorder_playlist_tracks(uri=playlist.uri, + range_start=i + tracks_to_sort.index(counter_track), + range_length=1, + insert_before=i) tracks_to_sort.remove(counter_track) def execute_playlist(self, @@ -179,7 +180,7 @@ class PlaylistEngine: logger.error('no string generated') return None - resp = self.net.change_playlist_details(uri, description=string) + resp = self.net.change_playlist_details(uri=uri, description=string) if resp: return resp else: @@ -230,7 +231,7 @@ class PlaylistSource(TrackSource): playlist: FullPlaylist) -> None: logger.info(f"pulling tracks for {playlist.name}") - tracks = self.net.get_playlist_tracks(playlist.uri) + tracks = self.net.get_playlist_tracks(uri=playlist.uri) if tracks and len(tracks) > 0: playlist.tracks = tracks else: @@ -263,7 +264,7 @@ class PlaylistSource(TrackSource): if playlist: playlists.append(playlist) else: - playlist = self.net.get_playlist(uri) + playlist = self.net.get_playlist(uri=uri) if playlist: playlists.append(playlist) self.playlists.append(playlist) diff --git a/spotframework/model/__init__.py b/spotframework/model/__init__.py index e69de29..a1ef949 100644 --- a/spotframework/model/__init__.py +++ b/spotframework/model/__init__.py @@ -0,0 +1,29 @@ +import logging + +logger = logging.getLogger(__name__) + +def init_with_key_filter(class_type: type, dict_obj: dict = None, merge_unrecognised_keys: bool = True, **kwargs): + + if '__dataclass_fields__' not in class_type.__dict__: + logger.error(f'{class_type} not a dataclass') + return + + if dict_obj is None: + dict_obj = dict() + + filtered_dict = dict() + unrecognised_keys = dict() + for i, j in {**dict_obj, **kwargs}.items(): + if i in class_type.__dict__['__dataclass_fields__'].keys(): + filtered_dict[i] = j + else: + unrecognised_keys[i] = j + logger.warning(f'unrecognised key found for {class_type}: {i} {type(j)}') + + obj = class_type(**filtered_dict) + + if merge_unrecognised_keys: + for i, j in unrecognised_keys.items(): + setattr(obj, i, j) + + return obj \ No newline at end of file diff --git a/spotframework/model/album.py b/spotframework/model/album.py index 8987c8d..f6f190a 100644 --- a/spotframework/model/album.py +++ b/spotframework/model/album.py @@ -3,11 +3,15 @@ from enum import Enum from dataclasses import dataclass from datetime import datetime from typing import List, Union +import logging from spotframework.model.uri import Uri import spotframework.model.artist import spotframework.model.service import spotframework.model.track +from spotframework.model import init_with_key_filter + +logger = logging.getLogger(__name__) @dataclass class SimplifiedAlbum: @@ -39,14 +43,14 @@ class SimplifiedAlbum: self.uri = Uri(self.uri) if self.uri: - if self.uri.object_type != Uri.ObjectType.album: + if self.uri.object_type not in [Uri.ObjectType.album, Uri.ObjectType.show]: raise TypeError('provided uri not for an album') if all((isinstance(i, dict) for i in self.artists)): - self.artists = [spotframework.model.artist.SimplifiedArtist(**i) for i in self.artists] + self.artists = [init_with_key_filter(spotframework.model.artist.SimplifiedArtist, i) for i in self.artists] if all((isinstance(i, dict) for i in self.images)): - self.images = [spotframework.model.service.Image(**i) for i in self.images] + self.images = [init_with_key_filter(spotframework.model.service.Image, i) for i in self.images] if isinstance(self.release_date, str): if self.release_date_precision == 'year': @@ -55,6 +59,11 @@ class SimplifiedAlbum: self.release_date = datetime.strptime(self.release_date, '%Y-%m') elif self.release_date_precision == 'day': self.release_date = datetime.strptime(self.release_date, '%Y-%m-%d') + else: + logger.error(f'invalid release date type {self.release_date_precision} - {self.release_date}') + + elif self.release_date is None and self.release_date_precision is None: # for podcasts + self.release_date = datetime(year=1900, month=1, day=1) @property def artists_names(self) -> str: @@ -82,33 +91,10 @@ class AlbumFull(SimplifiedAlbum): tracks: List[spotframework.model.track.SimplifiedTrack] = None def __post_init__(self): - - if isinstance(self.album_type, str): - self.album_type = SimplifiedAlbum.Type[self.album_type] - - if isinstance(self.uri, str): - self.uri = Uri(self.uri) - - if self.uri: - if self.uri.object_type != Uri.ObjectType.album: - raise TypeError('provided uri not for an album') - - if all((isinstance(i, dict) for i in self.artists)): - self.artists = [spotframework.model.artist.SimplifiedArtist(**i) for i in self.artists] - - if all((isinstance(i, dict) for i in self.images)): - self.images = [spotframework.model.service.Image(**i) for i in self.images] + super().__post_init__() if all((isinstance(i, dict) for i in self.tracks)): - self.tracks = [spotframework.model.track.SimplifiedTrack(**i) for i in self.tracks] - - if isinstance(self.release_date, str): - if self.release_date_precision == 'year': - self.release_date = datetime.strptime(self.release_date, '%Y') - elif self.release_date_precision == 'month': - self.release_date = datetime.strptime(self.release_date, '%Y-%m') - elif self.release_date_precision == 'day': - self.release_date = datetime.strptime(self.release_date, '%Y-%m-%d') + self.tracks = [init_with_key_filter(spotframework.model.track.SimplifiedTrack, i) for i in self.tracks] @dataclass @@ -118,7 +104,7 @@ class LibraryAlbum: def __post_init__(self): if isinstance(self.album, dict): - self.album = AlbumFull(**self.album) + self.album = init_with_key_filter(AlbumFull, self.album) if isinstance(self.added_at, str): self.added_at = datetime.strptime(self.added_at, '%Y-%m-%dT%H:%M:%S%z') diff --git a/spotframework/model/artist.py b/spotframework/model/artist.py index 4cd5c02..175a464 100644 --- a/spotframework/model/artist.py +++ b/spotframework/model/artist.py @@ -3,6 +3,8 @@ from typing import List, Union from spotframework.model.uri import Uri from spotframework.model.service import Image +from spotframework.model import init_with_key_filter + @dataclass class SimplifiedArtist: @@ -18,7 +20,7 @@ class SimplifiedArtist: self.uri = Uri(self.uri) if self.uri: - if self.uri.object_type != Uri.ObjectType.artist: + if self.uri.object_type not in [Uri.ObjectType.artist, Uri.ObjectType.show]: raise TypeError('provided uri not for an artist') def __str__(self): @@ -32,12 +34,7 @@ class ArtistFull(SimplifiedArtist): popularity: int def __post_init__(self): - if isinstance(self.uri, str): - self.uri = Uri(self.uri) - - if self.uri: - if self.uri.object_type != Uri.ObjectType.artist: - raise TypeError('provided uri not for an artist') + super().__post_init__() if all((isinstance(i, dict) for i in self.images)): - self.images = [Image(**i) for i in self.images] + self.images = [init_with_key_filter(Image, i) for i in self.images] diff --git a/spotframework/model/playlist.py b/spotframework/model/playlist.py index 138bcbf..3509c14 100644 --- a/spotframework/model/playlist.py +++ b/spotframework/model/playlist.py @@ -3,6 +3,7 @@ from spotframework.model.user import PublicUser from spotframework.model.track import TrackFull, PlaylistTrack from spotframework.model.uri import Uri from spotframework.model.service import Image +from spotframework.model import init_with_key_filter from tabulate import tabulate from typing import List, Union import logging @@ -39,10 +40,10 @@ class SimplifiedPlaylist: raise TypeError('provided uri not for a playlist') if all((isinstance(i, dict) for i in self.images)): - self.images = [Image(**i) for i in self.images] + self.images = [init_with_key_filter(Image, i) for i in self.images] if isinstance(self.owner, dict): - self.owner = PublicUser(**self.owner) + self.owner = init_with_key_filter(PublicUser, self.owner) def has_tracks(self) -> bool: return bool(len(self.tracks) > 0) @@ -131,10 +132,10 @@ class FullPlaylist(SimplifiedPlaylist): raise TypeError('provided uri not for a playlist') if all((isinstance(i, dict) for i in self.images)): - self.images = [Image(**i) for i in self.images] + self.images = [init_with_key_filter(Image, i) for i in self.images] if isinstance(self.owner, dict): - self.owner = PublicUser(**self.owner) + self.owner = init_with_key_filter(PublicUser, self.owner) def __str__(self): prefix = f'\n==={self.name}===\n\n' if self.name is not None else '' diff --git a/spotframework/model/podcast.py b/spotframework/model/podcast.py new file mode 100644 index 0000000..fe3d617 --- /dev/null +++ b/spotframework/model/podcast.py @@ -0,0 +1,110 @@ +from typing import List, Union +from dataclasses import dataclass +from datetime import datetime + +from spotframework.model import init_with_key_filter + +from spotframework.model.service import Image +from spotframework.model.uri import Uri + +@dataclass +class ResumePoint: + fully_played: bool + resume_position_ms: int + +@dataclass +class SimplifiedEpisode: + audio_preview_url: str + description: str + duration_ms: int + explicit: bool + external_urls: dict + href: str + id: str + images: List[Image] + is_externally_hosted: bool + is_playable: bool + languages: List[str] + name: str + release_date: datetime + release_date_precision: str + resume_point: ResumePoint + type: str + uri: Union[str, Uri] + + def __post_init__(self): + + if isinstance(self.uri, str): + self.uri = Uri(self.uri) + + if self.uri: + if self.uri.object_type != Uri.ObjectType.episode: + raise TypeError('provided uri not for an episode') + + if isinstance(self.resume_point, ResumePoint): + self.resume_point = init_with_key_filter(ResumePoint, self.resume_point) + + if all((isinstance(i, dict) for i in self.images)): + self.images = [init_with_key_filter(Image, i) for i in self.images] + + if isinstance(self.release_date, str): + if self.release_date_precision == 'year': + self.release_date = datetime.strptime(self.release_date, '%Y') + elif self.release_date_precision == 'month': + self.release_date = datetime.strptime(self.release_date, '%Y-%m') + elif self.release_date_precision == 'day': + self.release_date = datetime.strptime(self.release_date, '%Y-%m-%d') + +@dataclass +class SimplifiedShow: + available_markets: List[str] + copyrights: List[dict] + description: str + explicit: bool + external_urls: dict + href: str + id: str + images: List[Image] + is_externally_hosted: bool + languages: List[str] + media_type: str + name: str + publisher: str + type: str + uri: Union[str, Uri] + + def __post_init__(self): + + if isinstance(self.uri, str): + self.uri = Uri(self.uri) + + if self.uri: + if self.uri.object_type != Uri.ObjectType.episode: + raise TypeError('provided uri not for an episode') + + if all((isinstance(i, dict) for i in self.images)): + self.images = [init_with_key_filter(Image, i) for i in self.images] + +@dataclass +class EpisodeFull(SimplifiedEpisode): + show: SimplifiedShow + + def __post_init__(self): + super().__post_init__() + + if isinstance(self.show, SimplifiedShow): + self.show = init_with_key_filter(SimplifiedShow, self.show) + +@dataclass +class ShowFull(SimplifiedShow): + episodes: List[SimplifiedEpisode] + +@dataclass +class SavedShow: + added_at: datetime + show: ShowFull + + def __post_init__(self): + + if isinstance(self.show, ShowFull): + self.show = init_with_key_filter(ShowFull, self.show) diff --git a/spotframework/model/track.py b/spotframework/model/track.py index 30fef89..11374c2 100644 --- a/spotframework/model/track.py +++ b/spotframework/model/track.py @@ -2,6 +2,7 @@ from __future__ import annotations from typing import Union, List from datetime import datetime from dataclasses import dataclass, field +import logging import spotframework.model from spotframework.model.uri import Uri @@ -10,6 +11,11 @@ import spotframework.model.album import spotframework.model.artist import spotframework.model.service import spotframework.model.user +from spotframework.model.podcast import EpisodeFull + +from spotframework.model import init_with_key_filter + +logger = logging.getLogger(__name__) @dataclass @@ -37,11 +43,11 @@ class SimplifiedTrack: self.uri = Uri(self.uri) if self.uri: - if self.uri.object_type != Uri.ObjectType.track: + if self.uri.object_type not in [Uri.ObjectType.track, Uri.ObjectType.episode]: raise TypeError('provided uri not for a track') if all((isinstance(i, dict) for i in self.artists)): - self.artists = [spotframework.model.artist.SimplifiedArtist(**i) for i in self.artists] + self.artists = [init_with_key_filter(spotframework.model.artist.SimplifiedArtist, i) for i in self.artists] @property def artists_names(self) -> str: @@ -71,18 +77,10 @@ class TrackFull(SimplifiedTrack): return self.album.artists_names def __post_init__(self): - if isinstance(self.uri, str): - self.uri = Uri(self.uri) - - if self.uri: - if self.uri.object_type != Uri.ObjectType.track: - raise TypeError('provided uri not for a track') - - if all((isinstance(i, dict) for i in self.artists)): - self.artists = [spotframework.model.artist.SimplifiedArtist(**i) for i in self.artists] + super().__post_init__() if isinstance(self.album, dict): - self.album = spotframework.model.album.SimplifiedAlbum(**self.album) + self.album = init_with_key_filter(spotframework.model.album.SimplifiedAlbum, self.album) def __eq__(self, other): return isinstance(other, TrackFull) and other.uri == self.uri @@ -95,7 +93,7 @@ class LibraryTrack: def __post_init__(self): if isinstance(self.track, dict): - self.track = TrackFull(**self.track) + self.track = init_with_key_filter(TrackFull, self.track) if isinstance(self.added_at, str): self.added_at = datetime.strptime(self.added_at, '%Y-%m-%dT%H:%M:%S%z') @@ -107,15 +105,31 @@ class PlaylistTrack: added_by: spotframework.model.user.PublicUser is_local: bool primary_color: str - track: TrackFull + track: Union[TrackFull, EpisodeFull] video_thumbnail: dict def __post_init__(self): if isinstance(self.track, dict): - self.track = TrackFull(**self.track) + + # below seems more intuitive, currently parsing episode to track/album/artist structure for + # serialising over api, below could be implemented + + # obj_type = None + # if self.track['type'] == 'track': + # obj_type = TrackFull + # + # if self.track['type'] == 'episode': + # obj_type = EpisodeFull + # + # if obj_type is None: + # raise TypeError(f'unkown obj type found {self.track["type"]}') + + obj_type = TrackFull + + self.track = init_with_key_filter(obj_type, self.track) if isinstance(self.added_by, dict): - self.added_by = spotframework.model.user.PublicUser(**self.added_by) + self.added_by = init_with_key_filter(spotframework.model.user.PublicUser, self.added_by) if isinstance(self.added_at, str): self.added_at = datetime.strptime(self.added_at, '%Y-%m-%dT%H:%M:%S%z') @@ -129,9 +143,9 @@ class PlayedTrack: def __post_init__(self): if isinstance(self.context, dict): - self.context = Context(**self.context) + self.context = init_with_key_filter(Context, self.context) if isinstance(self.track, dict): - self.track = TrackFull(**self.track) + self.track = init_with_key_filter(TrackFull, self.track) if isinstance(self.played_at, str): self.played_at = datetime.strptime(self.played_at, '%Y-%m-%dT%H:%M:%S%z') @@ -345,13 +359,13 @@ class CurrentlyPlaying: def __post_init__(self): if isinstance(self.context, Context): - self.context = Context(**self.context) + self.context = init_with_key_filter(Context, self.context) if isinstance(self.item, spotframework.model.track.SimplifiedTrack): - self.item = spotframework.model.track.SimplifiedTrack(**self.item) + self.item = init_with_key_filter(spotframework.model.track.SimplifiedTrack, self.item) if isinstance(self.device, Device): - self.device = Device(**self.device) + self.device = init_with_key_filter(Device, self.device) def __eq__(self, other): return isinstance(other, CurrentlyPlaying) and other.item == self.item and other.context == self.context @@ -389,7 +403,7 @@ class Recommendations: def __post_init__(self): if all((isinstance(i, dict) for i in self.seeds)): - self.seeds = [RecommendationsSeed(**i) for i in self.seeds] + self.seeds = [init_with_key_filter(RecommendationsSeed, i) for i in self.seeds] if all((isinstance(i, dict) for i in self.tracks)): - self.tracks = [spotframework.model.track.TrackFull(**i) for i in self.tracks] \ No newline at end of file + self.tracks = [init_with_key_filter(spotframework.model.track.TrackFull, i) for i in self.tracks] \ No newline at end of file diff --git a/spotframework/model/uri.py b/spotframework/model/uri.py index 2698d01..0bd0dd9 100644 --- a/spotframework/model/uri.py +++ b/spotframework/model/uri.py @@ -9,6 +9,8 @@ class Uri: artist = 3 user = 4 playlist = 5 + episode = 6 + show = 7 def __init__(self, input_string: str): self.object_type = None diff --git a/spotframework/model/user.py b/spotframework/model/user.py index 845fed1..45c3ee3 100644 --- a/spotframework/model/user.py +++ b/spotframework/model/user.py @@ -2,6 +2,7 @@ from typing import Union, List from dataclasses import dataclass, field from spotframework.model.uri import Uri from spotframework.model.service import Image +from spotframework.model import init_with_key_filter @dataclass @@ -26,7 +27,7 @@ class PublicUser: raise TypeError('provided uri not for a user') if all((isinstance(i, dict) for i in self.images)): - self.images = [Image(**i) for i in self.images] + self.images = [init_with_key_filter(Image, i) for i in self.images] def __str__(self): return f'{self.id}' @@ -47,5 +48,5 @@ class PrivateUser(PublicUser): raise TypeError('provided uri not for a user') if all((isinstance(i, dict) for i in self.images)): - self.images = [Image(**i) for i in self.images] + self.images = [init_with_key_filter(Image, i) for i in self.images] diff --git a/spotframework/net/network.py b/spotframework/net/network.py index 3efe8e0..0691fe5 100644 --- a/spotframework/net/network.py +++ b/spotframework/net/network.py @@ -8,14 +8,19 @@ from typing import List, Optional, Union import datetime from json import JSONDecodeError -from spotframework.model.artist import ArtistFull -from spotframework.model.user import PublicUser from spotframework.net.user import NetworkUser + +from spotframework.model import init_with_key_filter + +from spotframework.model.user import PublicUser from spotframework.model.playlist import SimplifiedPlaylist, FullPlaylist +from spotframework.model.artist import ArtistFull +from spotframework.model.album import AlbumFull, LibraryAlbum, SimplifiedAlbum from spotframework.model.track import SimplifiedTrack, TrackFull, PlaylistTrack, PlayedTrack, LibraryTrack, \ AudioFeatures, Device, CurrentlyPlaying, Recommendations -from spotframework.model.album import AlbumFull, LibraryAlbum, SimplifiedAlbum +from spotframework.model.podcast import SimplifiedEpisode, EpisodeFull from spotframework.model.uri import Uri +from spotframework.util.decorators import inject_uri limit = 50 @@ -245,9 +250,9 @@ class Network: def refresh_user_info(self): self.user.user = self.get_current_user() + @inject_uri(uris=False) def get_playlist(self, - uri: Uri = None, - uri_string: str = None, + uri: Uri, tracks: bool = True) -> FullPlaylist: """get playlist object with tracks for uri @@ -257,16 +262,10 @@ class Network: :return: playlist object """ - if uri is None and uri_string is None: - raise NameError('no uri provided') - - if uri_string is not None: - uri = Uri(uri_string) - logger.info(f"retrieving {uri}") resp = self.get_request(f'playlists/{uri.object_id}') - playlist = FullPlaylist(**resp) + playlist = init_with_key_filter(FullPlaylist, resp) if resp.get('tracks') and tracks: if 'next' in resp['tracks']: @@ -275,10 +274,10 @@ class Network: track_pager = PageCollection(net=self, page=resp['tracks']) track_pager.continue_iteration() - playlist.tracks = [PlaylistTrack(**i) for i in track_pager.items] + playlist.tracks = [init_with_key_filter(PlaylistTrack, i) for i in track_pager.items] else: logger.debug(f'parsing {len(resp.get("tracks"))} tracks for {uri}') - playlist.tracks = [PlaylistTrack(**i) for i in resp.get('tracks', [])] + playlist.tracks = [init_with_key_filter(PlaylistTrack, i) for i in resp.get('tracks', [])] return playlist @@ -309,7 +308,7 @@ class Network: public=public, collaborative=collaborative, description=description) - return FullPlaylist(**req) + return init_with_key_filter(FullPlaylist, req) def get_playlists(self, response_limit: int = None) -> Optional[List[SimplifiedPlaylist]]: """get current users playlists @@ -325,7 +324,7 @@ class Network: pager.total_limit = response_limit pager.iterate() - return_items = [SimplifiedPlaylist(**i) for i in pager.items] + return_items = [init_with_key_filter(SimplifiedPlaylist, i) for i in pager.items] if len(return_items) == 0: logger.error('no playlists returned') @@ -346,7 +345,7 @@ class Network: pager.total_limit = response_limit pager.iterate() - return_items = [LibraryAlbum(**i) for i in pager.items] + return_items = [init_with_key_filter(LibraryAlbum, i) for i in pager.items] if len(return_items) == 0: logger.error('no albums returned') @@ -367,7 +366,7 @@ class Network: pager.total_limit = response_limit pager.iterate() - return_items = [LibraryTrack(**i) for i in pager.items] + return_items = [init_with_key_filter(LibraryTrack, i) for i in pager.items] if len(return_items) == 0: logger.error('no tracks returned') @@ -394,9 +393,9 @@ class Network: else: logger.error('no playlists returned to filter') + @inject_uri(uris=False) def get_playlist_tracks(self, - uri: Uri = None, - uri_string: str = None, + uri: Uri, response_limit: int = None) -> List[PlaylistTrack]: """get list of playlists tracks for uri @@ -406,12 +405,6 @@ class Network: :return: list of playlist tracks if available """ - if uri is None and uri_string is None: - raise NameError('no uri provided') - - if uri_string is not None: - uri = Uri(uri_string) - logger.info(f"paging tracks for {uri}") pager = PageCollection(net=self, url=f'playlists/{uri.object_id}/tracks', name='getPlaylistTracks') @@ -419,7 +412,7 @@ class Network: pager.total_limit = response_limit pager.iterate() - return_items = [PlaylistTrack(**i) for i in pager.items] + return_items = [init_with_key_filter(PlaylistTrack, i) for i in pager.items] if len(return_items) == 0: logger.error('no tracks returned') @@ -435,7 +428,7 @@ class Network: if len(resp['devices']) == 0: logger.error('no devices returned') - return [Device(**i) for i in resp['devices']] + return [init_with_key_filter(Device, i) for i in resp['devices']] def get_recently_played_tracks(self, response_limit: int = None, @@ -468,7 +461,7 @@ class Network: pager.total_limit = 20 pager.continue_iteration() - return [PlayedTrack(**i) for i in pager.items] + return [init_with_key_filter(PlayedTrack, i) for i in pager.items] def get_player(self) -> CurrentlyPlaying: """get currently playing snapshot (player)""" @@ -476,7 +469,7 @@ class Network: logger.info("polling player") resp = self.get_request('me/player') - return CurrentlyPlaying(**resp) + return init_with_key_filter(CurrentlyPlaying, resp) def get_device_id(self, device_name: str) -> Optional[str]: """return device id of device as searched for by name @@ -498,27 +491,20 @@ class Network: logger.info(f"getting current user") resp = self.get_request('me') - return PublicUser(**resp) + return init_with_key_filter(PublicUser, resp) def change_playback_device(self, device_id: str): """migrate playback to different device""" logger.info(f'shifting playback to {device_id}') self.put_request('me/player', device_ids=[device_id], play=True) + @inject_uri(uri_optional=True, uris_optional=True) def play(self, uri: Uri = None, - uri_string: str = None, uris: List[Uri] = None, - uri_strings: List[str] = None, deviceid: str = None): """begin playback""" - if uri_string is not None: - uri = Uri(uri_string) - - if uri_strings is not None: - uris = [Uri(i) for i in uri_strings] - logger.info(f"{uri}{' ' + deviceid if deviceid is not None else ''}") if deviceid is not None: @@ -600,25 +586,19 @@ class Network: else: logger.error(f"{volume} not accepted value") + @inject_uri def replace_playlist_tracks(self, - uri: Uri = None, - uri_string: str = None, - uris: List[Uri] = None, - uri_strings: List[str] = None) -> Optional[List[str]]: - - if uri_string is not None: - uri = Uri(uri_string) - - if uri_strings is not None: - uris = [Uri(i) for i in uri_strings] + uri: Uri, + uris: List[Uri]) -> Optional[List[str]]: logger.info(f"replacing {uri} with {'0' if uris is None else len(uris)} tracks") self.put_request(f'playlists/{uri.object_id}/tracks', uris=[str(i) for i in uris[:100]]) if len(uris) > 100: - return self.add_playlist_tracks(uri, uris[100:]) + return self.add_playlist_tracks(uri=uri, uris=uris[100:]) + @inject_uri(uris=False) def change_playlist_details(self, uri: Uri, name: str = None, @@ -638,6 +618,7 @@ class Network: collaborative=collaborative, description=description) + @inject_uri def add_playlist_tracks(self, uri: Uri, uris: List[Uri]) -> List[str]: logger.info(f"adding {len(uris)} tracks to {uri}") @@ -648,7 +629,7 @@ class Network: ] if len(uris) > 100: - snapshot_ids += self.add_playlist_tracks(uri, uris[100:]) + snapshot_ids += self.add_playlist_tracks(uri=uri, uris=uris[100:]) return snapshot_ids @@ -673,7 +654,7 @@ class Network: if len(params) == 1: logger.warning('update dictionairy length 0') else: - return Recommendations(**self.get_request('recommendations', params=params)) + return init_with_key_filter(Recommendations, self.get_request('recommendations', params=params)) def write_playlist_object(self, playlist: FullPlaylist, @@ -686,22 +667,23 @@ class Network: self.replace_playlist_tracks(uri=playlist.uri, uris=[]) elif playlist.tracks: if append_tracks: - self.add_playlist_tracks(playlist.uri, [i.uri for i in playlist.tracks if - isinstance(i, SimplifiedTrack)]) + self.add_playlist_tracks(uri=playlist.uri, uris=[i.uri for i in playlist.tracks if + isinstance(i, SimplifiedTrack)]) else: self.replace_playlist_tracks(uri=playlist.uri, uris=[i.uri for i in playlist.tracks if isinstance(i, SimplifiedTrack)]) if playlist.name or playlist.collaborative or playlist.public or playlist.description: - self.change_playlist_details(playlist.uri, - playlist.name, - playlist.public, - playlist.collaborative, - playlist.description) + self.change_playlist_details(uri=playlist.uri, + name=playlist.name, + public=playlist.public, + collaborative=playlist.collaborative, + description=playlist.description) else: logger.error('playlist has no id') + @inject_uri(uris=False) def reorder_playlist_tracks(self, uri: Uri, range_start: int, @@ -725,6 +707,7 @@ class Network: range_length=range_length, insert_before=insert_before) + @inject_uri(uri=False) def get_track_audio_features(self, uris: List[Uri]) -> Optional[List[AudioFeatures]]: logger.info(f'getting {len(uris)} features') @@ -734,7 +717,7 @@ class Network: resp = self.get_request(url='audio-features', ids=','.join(i.object_id for i in chunk)) if resp.get('audio_features', None): - return [AudioFeatures(**i) for i in resp['audio_features']] + return [init_with_key_filter(AudioFeatures, i) for i in resp['audio_features']] else: logger.error('no audio features included') @@ -747,7 +730,7 @@ class Network: logger.info(f'populating {len(tracks)} features') if isinstance(tracks, SimplifiedTrack): - audio_features = self.get_track_audio_features([tracks.uri]) + audio_features = self.get_track_audio_features(uris=[tracks.uri]) if audio_features: if len(audio_features) == 1: @@ -760,7 +743,7 @@ class Network: elif isinstance(tracks, List): if all(isinstance(i, SimplifiedTrack) for i in tracks): - audio_features = self.get_track_audio_features([i.uri for i in tracks]) + audio_features = self.get_track_audio_features(uris=[i.uri for i in tracks]) if audio_features: if len(audio_features) != len(tracks): @@ -775,15 +758,8 @@ class Network: else: raise TypeError('must provide either single or list of spotify tracks') - def get_tracks(self, - uris: List[Uri] = None, - uri_strings: List[str] = None) -> List[TrackFull]: - - if uris is None and uri_strings is None: - raise NameError('no uris provided') - - if uri_strings is not None: - uris = [Uri(i) for i in uri_strings] + @inject_uri(uri=False) + def get_tracks(self, uris: List[Uri]) -> List[TrackFull]: logger.info(f'getting {len(uris)} tracks') @@ -795,31 +771,21 @@ class Network: for chunk in chunked_uris: resp = self.get_request(url='tracks', ids=','.join([i.object_id for i in chunk])) if resp: - tracks += [TrackFull(**i) for i in resp.get('tracks', [])] + tracks += [init_with_key_filter(TrackFull, i) for i in resp.get('tracks', [])] return tracks - def get_track(self, uri: Uri = None, uri_string: str = None) -> Optional[TrackFull]: + @inject_uri(uris=False) + def get_track(self, uri) -> Optional[TrackFull]: - if uri is None and uri_string is None: - raise NameError('no uri provided') - - if uri_string is not None: - uri = Uri(uri_string) - - track = self.get_tracks([uri]) + track = self.get_tracks(uris=[uri]) if len(track) == 1: return track[0] else: return None - def get_albums(self, uris: List[Uri] = None, uri_strings: List[str] = None) -> List[AlbumFull]: - - if uris is None and uri_strings is None: - raise NameError('no uris provided') - - if uri_strings is not None: - uris = [Uri(i) for i in uri_strings] + @inject_uri(uri=False) + def get_albums(self, uris: List[Uri]) -> List[AlbumFull]: logger.info(f'getting {len(uris)} albums') @@ -831,31 +797,21 @@ class Network: for chunk in chunked_uris: resp = self.get_request(url='albums', ids=','.join([i.object_id for i in chunk])) if resp: - albums += [AlbumFull(**i) for i in resp.get('albums', [])] + albums += [init_with_key_filter(AlbumFull, i) for i in resp.get('albums', [])] return albums - def get_album(self, uri: Uri = None, uri_string: str = None) -> Optional[AlbumFull]: + @inject_uri(uris=False) + def get_album(self, uri: Uri) -> Optional[AlbumFull]: - if uri is None and uri_string is None: - raise NameError('no uri provided') - - if uri_string is not None: - uri = Uri(uri_string) - - album = self.get_albums([uri]) + album = self.get_albums(uris=[uri]) if len(album) == 1: return album[0] else: return None - def get_artists(self, uris: List[Uri] = None, uri_strings: List[str] = None) -> List[ArtistFull]: - - if uris is None and uri_strings is None: - raise NameError('no uris provided') - - if uri_strings is not None: - uris = [Uri(i) for i in uri_strings] + @inject_uri(uri=False) + def get_artists(self, uris) -> List[ArtistFull]: logger.info(f'getting {len(uris)} artists') @@ -867,19 +823,14 @@ class Network: for chunk in chunked_uris: resp = self.get_request(url='artists', ids=','.join([i.object_id for i in chunk])) if resp: - artists += [ArtistFull(**i) for i in resp.get('artists', [])] + artists += [init_with_key_filter(ArtistFull, i) for i in resp.get('artists', [])] return artists - def get_artist(self, uri: Uri = None, uri_string: str = None) -> Optional[ArtistFull]: + @inject_uri(uris=False) + def get_artist(self, uri) -> Optional[ArtistFull]: - if uri is None and uri_string is None: - raise NameError('no uri provided') - - if uri_string is not None: - uri = Uri(uri_string) - - artist = self.get_artists([uri]) + artist = self.get_artists(uris=[uri]) if len(artist) == 1: return artist[0] else: @@ -914,10 +865,10 @@ class Network: type=','.join([i.name for i in query_types]), limit=response_limit) - albums = [SimplifiedAlbum(**i) for i in resp.get('albums', {}).get('items', [])] - artists = [ArtistFull(**i) for i in resp.get('artists', {}).get('items', [])] - tracks = [TrackFull(**i) for i in resp.get('tracks', {}).get('items', [])] - playlists = [SimplifiedPlaylist(**i) for i in resp.get('playlists', {}).get('items', [])] + albums = [init_with_key_filter(SimplifiedAlbum, i) for i in resp.get('albums', {}).get('items', [])] + artists = [init_with_key_filter(ArtistFull, i) for i in resp.get('artists', {}).get('items', [])] + tracks = [init_with_key_filter(TrackFull, i) for i in resp.get('tracks', {}).get('items', [])] + playlists = [init_with_key_filter(SimplifiedPlaylist, i) for i in resp.get('playlists', {}).get('items', [])] return SearchResponse(tracks=tracks, albums=albums, artists=artists, playlists=playlists) @@ -996,7 +947,7 @@ class PageCollection: self.iterate(page.next) def add_page(self, page_dict): - page = Page(**page_dict) + page = init_with_key_filter(Page, page_dict) self.pages.append(page) return page diff --git a/spotframework/util/__init__.py b/spotframework/util/__init__.py index 48e9586..9c2e4aa 100644 --- a/spotframework/util/__init__.py +++ b/spotframework/util/__init__.py @@ -15,3 +15,13 @@ def validate_uri_string(uri_string: str): return uri except ValueError: return False + +def get_uri(uri_in): + + if isinstance(uri_in, str): + return Uri(input_string=uri_in) + + if isinstance(uri_in, Uri): + return uri_in + + raise TypeError(f'invalid uri type provided - {type(uri_in)}') diff --git a/spotframework/util/decorators.py b/spotframework/util/decorators.py index 3494c96..5bb73fc 100644 --- a/spotframework/util/decorators.py +++ b/spotframework/util/decorators.py @@ -1,5 +1,8 @@ import functools import logging + +from spotframework.util import get_uri + logger = logging.getLogger(__name__) @@ -14,3 +17,30 @@ def debug(func): print(f"{func.__name__!r} -> {value!r}") return value return wrapper_debug + +def inject_uri(_func=None, *, uri=True, uris=True, uri_optional=False, uris_optional=False): + + def decorator_inject_uri(func): + @functools.wraps(func) + def inject_uri_wrapper(*args, **kwargs): + if uri: + if uri_optional: + kwargs['uri'] = get_uri(kwargs['uri']) if kwargs.get('uri') else None + else: + kwargs['uri'] = get_uri(kwargs['uri']) + + if uris: + if uris_optional: + kwargs['uris'] = [get_uri(i) for i in kwargs['uris']] if kwargs.get('uris') else None + else: + kwargs['uris'] = [get_uri(i) for i in kwargs['uris']] + + return func(*args, **kwargs) + + return inject_uri_wrapper + + if _func is None: + return decorator_inject_uri + else: + return decorator_inject_uri(_func) +