diff --git a/admin.py b/admin.py index 6199e85..76788ef 100755 --- a/admin.py +++ b/admin.py @@ -97,7 +97,7 @@ class Admin(Cmd): print('>> user not found') name = input('enter playlist name: ') - playlist = Playlist.collection.parent(user.key).filter('name', '==', name).get() + playlist = user.get_playlist(name) if playlist is None: print('>> playlist not found') diff --git a/music/api/api.py b/music/api/api.py index 08a9acd..041bf40 100644 --- a/music/api/api.py +++ b/music/api/api.py @@ -50,7 +50,7 @@ def all_playlists_route(user=None): @validate_args(('name', str)) def playlist_get_delete_route(user=None): - playlist = Playlist.collection.parent(user.key).filter('name', '==', request.args['name']).get() + playlist = user.get_playlist(request.args['name'], raise_error=False) if playlist is None: return jsonify({'error': f'playlist {request.args["name"]} not found'}), 404 @@ -77,7 +77,7 @@ def playlist_post_put_route(user=None): if request_json['playlist_references'] != -1: for i in request_json['playlist_references']: - playlist = Playlist.collection.parent(user.key).filter('name', '==', i).get() + playlist = user.get_playlist(i, raise_error=False) if playlist is not None: playlist_references.append(db.document(playlist.key)) else: @@ -86,7 +86,7 @@ def playlist_post_put_route(user=None): if len(playlist_references) == 0 and request_json.get('playlist_references', None) != -1: playlist_references = None - searched_playlist = Playlist.collection.parent(user.key).filter('name', '==', playlist_name).get() + searched_playlist = user.get_playlist(playlist_name, raise_error=False) # CREATE if request.method == 'PUT': @@ -129,34 +129,32 @@ def playlist_post_put_route(user=None): if searched_playlist is None: return jsonify({'error': "playlist doesn't exist"}), 400 - playlist = Playlist.collection.parent(user.key).filter('name', '==', playlist_name).get() - # ATTRIBUTES for rec_key, rec_item in request_json.items(): # type and parts require extra validation if rec_key in [k for k in Playlist.mutable_keys if k not in ['type', 'parts', 'playlist_references']]: - setattr(playlist, rec_key, request_json[rec_key]) + setattr(searched_playlist, rec_key, request_json[rec_key]) # COMPONENTS if request_json.get('parts'): if request_json['parts'] == -1: - playlist.parts = [] + searched_playlist.parts = [] else: - playlist.parts = request_json['parts'] + searched_playlist.parts = request_json['parts'] if playlist_references is not None: if playlist_references == -1: - playlist.playlist_references = [] + searched_playlist.playlist_references = [] else: - playlist.playlist_references = playlist_references + searched_playlist.playlist_references = playlist_references # ATTRIBUTE WITH CHECKS if request_json.get('type'): playlist_type = request_json['type'].strip().lower() if playlist_type in ['default', 'recents', 'fmchart']: - playlist.type = playlist_type + searched_playlist.type = playlist_type - playlist.update() + searched_playlist.update() logger.info(f'updated {user.username} / {playlist_name}') return jsonify({"message": 'playlist updated', "status": "success"}), 200 @@ -304,7 +302,7 @@ def run_users(user=None): @validate_args(('name', str)) def image(user=None): - _playlist = Playlist.collection.parent(user.key).filter('name', '==', request.args['name']).get() + _playlist = user.get_playlist(request.args['name'], raise_error=False) if _playlist is None: return jsonify({'error': "playlist not found"}), 404 diff --git a/music/model/user.py b/music/model/user.py index 0a6015d..84b1c65 100644 --- a/music/model/user.py +++ b/music/model/user.py @@ -1,8 +1,14 @@ +import logging + from fireo.models import Model from fireo.fields import TextField, BooleanField, DateTime, NumberField +from music.model.playlist import Playlist + from werkzeug.security import check_password_hash +logger = logging.getLogger(__name__) + class User(Model): class Meta: @@ -40,3 +46,46 @@ class User(Model): to_return.pop('key', None) return to_return + + def get_playlist(self, playlist_name: str, single_return=True, raise_error=True): + """Get a user's playlist by name with smart case sensitivity + + Will return an exact match if possible, otherwise will return the first case-insensitive match + + Args: + playlist_name (str): Subject playlist name + single_return (bool, optional): Return the best match, otherwise return (, ). will be None if not found. Defaults to True. + raise_error (bool, optional): Raise a NameError if nothing found. Defaults to True. + + Raises: + NameError: If no matching playlists found + + Returns: + Optional[Playlist] or (, ): Found user's playlists + """ + + smart_playlists = Playlist.collection.parent(self.key).fetch() + + exact_match = None + matches = list() + for playlist in smart_playlists: + if playlist.name == playlist_name: + exact_match = playlist + if playlist.name.lower() == playlist_name.lower(): + matches.append(playlist) + + if len(matches) == 0: + # NO PLAYLIST FOUND + logger.critical(f'playlist not found {self.username} / {playlist_name}') + if raise_error: + raise NameError(f'Playlist {playlist_name} not found for {self.username}') + else: + return None + + if single_return: + if exact_match: + return exact_match + else: + return matches[0] + else: + return exact_match, matches diff --git a/music/tasks/refresh_lastfm_stats.py b/music/tasks/refresh_lastfm_stats.py index 15fb689..51c9e34 100644 --- a/music/tasks/refresh_lastfm_stats.py +++ b/music/tasks/refresh_lastfm_stats.py @@ -25,7 +25,7 @@ def refresh_lastfm_track_stats(username, playlist_name): spotnet = database.get_authed_spotify_network(user) counter = Counter(fmnet=fmnet, spotnet=spotnet) - playlist = Playlist.collection.parent(user.key).filter('name', '==', playlist_name).get() + playlist = user.get_playlist(playlist_name) if playlist is None: logger.critical(f'playlist {playlist_name} for {username} not found') @@ -71,7 +71,7 @@ def refresh_lastfm_album_stats(username, playlist_name): spotnet = database.get_authed_spotify_network(user) counter = Counter(fmnet=fmnet, spotnet=spotnet) - playlist = Playlist.collection.parent(user.key).filter('name', '==', playlist_name).get() + playlist = user.get_playlist(playlist_name) if playlist is None: logger.critical(f'playlist {playlist_name} for {username} not found') @@ -117,7 +117,7 @@ def refresh_lastfm_artist_stats(username, playlist_name): spotnet = database.get_authed_spotify_network(user) counter = Counter(fmnet=fmnet, spotnet=spotnet) - playlist = Playlist.collection.parent(user.key).filter('name', '==', playlist_name).get() + playlist = user.get_playlist(playlist_name) if playlist is None: logger.critical(f'playlist {playlist_name} for {username} not found') diff --git a/music/tasks/run_user_playlist.py b/music/tasks/run_user_playlist.py index 3c08351..9b48e7a 100644 --- a/music/tasks/run_user_playlist.py +++ b/music/tasks/run_user_playlist.py @@ -56,13 +56,10 @@ def run_user_playlist(user: User, playlist: Playlist, spotnet: SpotNetwork = Non if isinstance(playlist, str): playlist_name = playlist - playlist = Playlist.collection.parent(user.key).filter('name', '==', playlist_name).get() + playlist = user.get_playlist(playlist_name) + else: - playlist_name = playlist.name - - if playlist is None: - logger.critical(f'playlist not found {username} / {playlist_name}') - raise NameError(f'Playlist {playlist_name} not found for {username}') + playlist_name = playlist.name if playlist.uri is None: logger.critical(f'no playlist id to populate {username} / {playlist_name}') diff --git a/tests/test_model.py b/tests/test_model.py index 039ab68..a155b11 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -18,4 +18,42 @@ class TestUser(unittest.TestCase): for user in users: for key in ['password', 'access_token', 'refresh_token', 'token_expiry', 'id', 'key']: - self.assertNotIn(key, user.to_dict()) \ No newline at end of file + self.assertNotIn(key, user.to_dict()) + + def test_get_playlist(self): + test_user = User.collection.filter('username', '==', "test").get() + + test_playlist = test_user.get_playlist("test_playlist") + self.assertIsNotNone(test_playlist) + + def test_get_playlist_all_returned(self): + test_user = User.collection.filter('username', '==', "test").get() + + exact, matches = test_user.get_playlist("test_playlist", single_return=False) + self.assertIsNotNone(exact) + self.assertEqual(len(matches), 1) + + def test_get_playlist_wrong_case(self): + test_user = User.collection.filter('username', '==', "test").get() + + test_playlist = test_user.get_playlist("TEST_PLAYLIST") + self.assertIsNotNone(test_playlist) + + def test_get_playlist_wrong_case_not_exact(self): + test_user = User.collection.filter('username', '==', "test").get() + + exact, matches = test_user.get_playlist("TEST_PLAYLIST", single_return=False) + self.assertIsNone(exact) + self.assertEqual(len(matches), 1) + + def test_get_playlist_missing_key(self): + test_user = User.collection.filter('username', '==', "test").get() + + with self.assertRaises(NameError): + test_playlist = test_user.get_playlist("test_playlist_missing") + + def test_get_playlist_missing_key_without_error(self): + test_user = User.collection.filter('username', '==', "test").get() + + test_playlist = test_user.get_playlist("test_playlist_missing", raise_error=False) + self.assertIsNone(test_playlist) \ No newline at end of file