diff --git a/music/api/api.py b/music/api/api.py index dfc9740..0de9dc8 100644 --- a/music/api/api.py +++ b/music/api/api.py @@ -38,7 +38,7 @@ def playlist(username=None): user_playlists = database.get_user_playlists(username) - user_ref = database.get_user_doc_ref(username) + user_ref = database.get_user(username).db_ref playlists = user_ref.collection(u'playlists') if request.method == 'GET' or request.method == 'DELETE': @@ -387,7 +387,8 @@ def run_playlist_task(): @login_or_basic_auth def run_user(username=None): - if database.get_user_doc_ref(username).get().to_dict()['type'] == 'admin': + db_user = database.get_user(username) + if db_user.type == db_user.Type.admin: user_name = request.args.get('username', username) else: user_name = username diff --git a/music/api/spotfm.py b/music/api/spotfm.py index 82cb8f8..3a90e57 100644 --- a/music/api/spotfm.py +++ b/music/api/spotfm.py @@ -159,7 +159,8 @@ def run_users_task(): @login_or_basic_auth def run_user(username=None): - if database.get_user_doc_ref(username).get().to_dict()['type'] == 'admin': + db_user = database.get_user(username) + if db_user.type == db_user.Type.admin: user_name = request.args.get('username', username) else: user_name = username diff --git a/music/auth/auth.py b/music/auth/auth.py index 5c1c983..5e10df5 100644 --- a/music/auth/auth.py +++ b/music/auth/auth.py @@ -32,31 +32,19 @@ def login(): return redirect(url_for('index')) username = username.lower() + user = database.get_user(username) - users = database.get_user_query_stream(username) - - if len(users) == 0: + if user is None: flash('user not found') return redirect(url_for('index')) - if len(users) > 1: - flash('multiple users found') - return redirect(url_for('index')) - - doc = users[0].to_dict() - if doc is None: - flash('username not found') - return redirect(url_for('index')) - - if check_password_hash(doc['password'], password): - - if doc['locked']: + if user.check_password(password): + if user.locked: logger.warning(f'locked account attempt {username}') flash('account locked') return redirect(url_for('index')) - user_reference = db.collection(u'spotify_users').document(u'{}'.format(users[0].id)) - user_reference.update({'last_login': datetime.datetime.utcnow()}) + user.last_login = datetime.datetime.utcnow() logger.info(f'success {username}') session['username'] = username @@ -108,18 +96,7 @@ def register(): flash('username already registered') return redirect('authapi.register') - db.collection(u'spotify_users').add({ - 'access_token': None, - 'email': None, - 'last_login': datetime.datetime.utcnow(), - 'locked': False, - 'password': generate_password_hash(password), - 'refresh_token': None, - 'spotify_linked': False, - 'type': 'user', - 'username': username, - 'validated': True - }) + database.create_user(username=username, password=password) logger.info(f'new user {username}') session['username'] = username @@ -173,9 +150,9 @@ def token(): resp = req.json() - user_reference = database.get_user_doc_ref(session['username']) + user = database.get_user(session['username']) - user_reference.update({ + user.update_database({ 'access_token': resp['access_token'], 'refresh_token': resp['refresh_token'], 'last_refreshed': datetime.datetime.now(datetime.timezone.utc), @@ -197,9 +174,9 @@ def deauth(): if 'username' in session: - user_reference = database.get_user_doc_ref(session['username']) + user = database.get_user(session['username']) - user_reference.update({ + user.update_database({ 'access_token': None, 'refresh_token': None, 'last_refreshed': datetime.datetime.now(datetime.timezone.utc), diff --git a/music/db/database.py b/music/db/database.py index 3ee4a7a..8191ff9 100644 --- a/music/db/database.py +++ b/music/db/database.py @@ -2,7 +2,7 @@ from google.cloud import firestore import logging from datetime import timedelta, datetime, timezone from typing import List, Optional -from werkzeug.security import check_password_hash +from werkzeug.security import generate_password_hash from spotframework.net.network import Network as SpotifyNetwork from fmframework.net.network import Network as FmNetwork @@ -17,9 +17,9 @@ logger = logging.getLogger(__name__) def refresh_token_database_callback(user): if isinstance(user, DatabaseUser): - user_ref = get_user_doc_ref(user.user_id) + user = get_user(user.user_id) - user_ref.update({ + user.update_database({ 'access_token': user.access_token, 'refresh_token': user.refresh_token, 'last_refreshed': user.last_refreshed, @@ -32,21 +32,19 @@ def refresh_token_database_callback(user): def get_authed_spotify_network(username): - user = get_user_doc_ref(username) - if user: - user_dict = user.get().to_dict() - - if user_dict.get('spotify_linked', None): + user = get_user(username) + if user is not None: + if user.spotify_linked: spotify_keys = db.document('key/spotify').get().to_dict() user_obj = DatabaseUser(client_id=spotify_keys['clientid'], client_secret=spotify_keys['clientsecret'], - refresh_token=user_dict['refresh_token'], + refresh_token=user.refresh_token, user_id=username, - access_token=user_dict['access_token']) + access_token=user.access_token) user_obj.on_refresh.append(refresh_token_database_callback) - if user_dict['last_refreshed'] + timedelta(seconds=user_dict['token_expiry'] - 1) \ + if user.last_refreshed + timedelta(seconds=user.token_expiry - 1) \ < datetime.now(timezone.utc): user_obj.refresh_access_token() @@ -60,116 +58,17 @@ def get_authed_spotify_network(username): def get_authed_lastfm_network(username): - user = get_user_doc_ref(username) + user = get_user(username) if user: - user_dict = user.get().to_dict() - - if user_dict.get('lastfm_username', None): + if user.lastfm_username: fm_keys = db.document('key/fm').get().to_dict() - - return FmNetwork(username=user_dict['lastfm_username'], api_key=fm_keys['clientid']) + return FmNetwork(username=user.lastfm_username, api_key=fm_keys['clientid']) else: logger.error(f'{username} has no last.fm username') else: logger.error(f'user {username} not found') -def check_user_password(username, password): - - user = get_user_doc_ref(user=username) - if user: - user_dict = user.get().to_dict() - - if check_password_hash(user_dict['password'], password): - return True - else: - logger.error(f'password mismatch {username}') - else: - logger.error(f'user {username} not found') - - return False - - -def get_user_query_stream(user: str) -> List[firestore.DocumentSnapshot]: - - users = [i for i in db.collection(u'spotify_users').where(u'username', u'==', user).stream()] - - if len(users) > 0: - return users - else: - logger.warning(f'{user} not found') - return [] - - -def get_user_doc_ref(user: str) -> Optional[firestore.DocumentReference]: - - users = get_user_query_stream(user) - - if len(users) > 0: - if len(users) == 1: - return users[0].reference - - else: - logger.error(f"multiple {user}'s found") - return None - - else: - logger.error(f'{user} not found') - return None - - -def get_user_playlists_collection(user_id: str) -> firestore.CollectionReference: - - playlists = db.document(u'spotify_users/{}'.format(user_id)).collection(u'playlists') - - return playlists - - -def get_user_playlist_ref_by_username(user: str, playlist: str) -> Optional[firestore.DocumentReference]: - - user_ref = get_user_doc_ref(user) - - if user_ref: - - return get_user_playlist_ref_by_user_ref(user_ref, playlist) - - else: - logger.error(f'{user} not found, looking up {playlist}') - return None - - -def get_user_playlist_ref_by_user_ref(user_ref: firestore.DocumentReference, - playlist: str) -> Optional[firestore.DocumentReference]: - - playlist_collection = get_user_playlists_collection(user_ref.id) - - username = user_ref.get().to_dict()['username'] - - if playlist_collection: - query = [i for i in playlist_collection.where(u'name', u'==', playlist).stream()] - - if len(query) > 0: - if len(query) == 1: - if query[0].exists: - return query[0].reference - - else: - logger.error(f'{playlist} for {username} does not exist') - return query[0] - - else: - logger.error(f'{username} multiple response playlists found for {playlist}') - return query[0] - - else: - logger.error(f'{username} no playlist found for {playlist}') - return None - - else: - logger.error(f'{username} playlist collection not found, looking up {playlist}') - return None - - def get_users() -> List[User]: logger.info('retrieving users') return [parse_user_reference(user_snapshot=i) for i in db.collection(u'spotify_users').stream()] @@ -235,6 +134,22 @@ def update_user(username: str, updates: dict) -> None: user.update(updates) +def create_user(username: str, password: str): + db.collection(u'spotify_users').add({ + 'access_token': None, + 'email': None, + 'last_login': datetime.utcnow(), + 'last_refreshed': None, + 'locked': False, + 'password': generate_password_hash(password), + 'refresh_token': None, + 'spotify_linked': False, + 'type': 'user', + 'username': username, + 'validated': True + }) + + def get_user_playlists(username: str) -> List[Playlist]: logger.info(f'getting playlists for {username}')