listening-analysis/playlist-classifier.ipynb

403 lines
579 KiB
Plaintext
Raw Normal View History

2021-02-04 13:34:25 +00:00
{
"metadata": {
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
2021-02-13 11:53:15 +00:00
"version": "3.8.4-final"
2021-02-04 13:34:25 +00:00
},
"orig_nbformat": 2,
"kernelspec": {
"name": "python3",
"display_name": "Python 3",
"language": "python"
}
},
"nbformat": 4,
"nbformat_minor": 2,
"cells": [
{
"source": [
"# Playlist Classifier\n",
"\n",
"Given a list of playlists, can unknown tracks be correctly classified?"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
2021-02-13 11:53:15 +00:00
"execution_count": 95,
2021-02-04 13:34:25 +00:00
"metadata": {},
"outputs": [],
"source": [
"playlist_names = [\"RAP\", \"EDM\", \"ROCK\", \"METAL\", \"JAZZ\", \"POP\"] # super-genres\n",
2021-02-13 11:53:15 +00:00
"# playlist_names = [\"RAP\", \"EDM\", \"ROCK\", \"METAL\", \"JAZZ\"] # super-genres without POP\n",
2021-02-04 13:34:25 +00:00
"# playlist_names = [\"DNB\", \"HOUSE\", \"TECHNO\", \"GARAGE\", \"DUBSTEP\", \"BASS\"] # EDM playlists\n",
"# playlist_names = [\"20s rap\", \"10s rap\", \"00s rap\", \"90s rap\", \"80s rap\"] # rap decades\n",
"# playlist_names = [\"UK RAP\", \"US RAP\"] # UK/US split\n",
"# playlist_names = [\"uk rap\", \"grime\", \"drill\", \"afro bash\"] # british rap playlists\n",
"# playlist_names = [\"20s rap\", \"10s rap\", \"00s rap\", \"90s rap\", \"80s rap\", \"trap\", \"gangsta rap\", \"industrial rap\", \"weird rap\", \"jazz rap\", \"boom bap\", \"trap metal\"] # american rap playlists\n",
"# playlist_names = [\"rock\", \"indie\", \"punk\", \"pop rock\", \"bluesy rock\", \"hard rock\", \"chilled rock\", \"emo\", \"pop punk\", \"stoner rock/metal\", \"post-hardcore\", \"melodic hardcore\", \"art rock\", \"post-rock\", \"classic pop punk\", \"90s rock & grunge\", \"90s indie & britpop\", \"psych\"] # rock playlists\n",
"# playlist_names = [\"metal\", \"metalcore\", \"mathcore\", \"hardcore\", \"black metal\", \"death metal\", \"doom metal\", \"sludge metal\", \"classic metal\", \"industrial\", \"nu metal\", \"calm metal\", \"thrash metal\"] # metal playlists\n",
"\n",
"# headers = float_headers + [\"duration_ms\", \"mode\", \"loudness\", \"tempo\"]\n",
"headers = float_headers"
]
},
{
"source": [
"Pull and process playlist information.\n",
"\n",
"1. Get live playlist track information from Spotify\n",
"2. Filter listening history for these tracks\n",
"\n",
"Filter out tracks without features and drop duplicates before taking only the descriptor parameters"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
2021-02-13 11:53:15 +00:00
"execution_count": 96,
2021-02-04 13:34:25 +00:00
"metadata": {},
"outputs": [],
"source": [
"playlists = [get_playlist(i, spotnet) for i in playlist_names] # 1)\n",
"\n",
"# filter playlists by join with playlist track/artist names\n",
"filtered_playlists = [pd.merge(track_frame(i.tracks), scrobbles, on=['track', 'artist']) for i in playlists] # 2)\n",
"\n",
"filtered_playlists = [i[pd.notnull(i[\"uri\"])] for i in filtered_playlists]\n",
"# distinct on uri\n",
"filtered_playlists = [i.drop_duplicates(['uri']) for i in filtered_playlists]\n",
"# select only descriptor float columns\n",
"filtered_playlists = [i.loc[:, headers] for i in filtered_playlists]"
]
},
{
"source": [
"Construct the dataset with associated labels before splitting into a train and test set."
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
2021-02-13 11:53:15 +00:00
"execution_count": 97,
2021-02-04 13:34:25 +00:00
"metadata": {},
"outputs": [],
"source": [
"dataset = pd.concat(filtered_playlists)\n",
"labels = [np.full(len(plst), idx) for idx, plst in enumerate(filtered_playlists)]\n",
"labels = np.concatenate(labels)\n",
"\n",
"data_train, data_test, labels_train, labels_test = train_test_split(dataset, labels, test_size=0.33, random_state=70)"
]
},
{
"source": [
"# SVM\n",
"Support Vector Machine"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
2021-02-13 11:53:15 +00:00
"execution_count": 102,
2021-02-04 13:34:25 +00:00
"metadata": {},
"outputs": [
{
2021-02-13 11:53:15 +00:00
"output_type": "execute_result",
2021-02-04 13:34:25 +00:00
"data": {
2021-02-13 11:53:15 +00:00
"text/plain": [
" uw-rbf w-rbf linear poly sigmoid\n",
"accuracy % 72.95 69.95 67.05 69.95 30.98"
],
"text/html": "<div>\n<style scoped>\n .dataframe tbody tr th:only-of-type {\n vertical-align: middle;\n }\n\n .dataframe tbody tr th {\n vertical-align: top;\n }\n\n .dataframe thead th {\n text-align: right;\n }\n</style>\n<table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: right;\">\n <th></th>\n <th>uw-rbf</th>\n <th>w-rbf</th>\n <th>linear</th>\n <th>poly</th>\n <th>sigmoid</th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <th>accuracy %</th>\n <td>72.95</td>\n <td>69.95</td>\n <td>67.05</td>\n <td>69.95</td>\n <td>30.98</td>\n </tr>\n </tbody>\n</table>\n</div>"
2021-02-04 13:34:25 +00:00
},
2021-02-13 11:53:15 +00:00
"metadata": {},
"execution_count": 102
2021-02-04 13:34:25 +00:00
}
],
"source": [
"### TRAIN ###\n",
"clf = svm.SVC()\n",
"clf.fit(data_train, labels_train)\n",
"\n",
2021-02-13 11:53:15 +00:00
"wclf = svm.SVC(class_weight='balanced') # weight classes based on prevalence\n",
"wclf.fit(data_train, labels_train)\n",
"\n",
"lclf = svm.SVC(kernel='linear', class_weight='balanced')\n",
"lclf.fit(data_train, labels_train)\n",
"\n",
"pclf = svm.SVC(kernel='poly', degree=3, class_weight='balanced')\n",
"pclf.fit(data_train, labels_train)\n",
"\n",
"sclf = svm.SVC(kernel='sigmoid', class_weight='balanced')\n",
"sclf.fit(data_train, labels_train)\n",
2021-02-04 13:34:25 +00:00
"\n",
"### EVALUATE ###\n",
2021-02-13 11:53:15 +00:00
"models = {'uw-rbf': clf, 'w-rbf': wclf, 'linear': lclf, 'poly': pclf, 'sigmoid': sclf}\n",
"accuracy = {i: j.score(data_test, labels_test) for i, j in models.items()}\n",
"\n",
"(pd.DataFrame(accuracy, index=['accuracy %']) * 100).round(decimals=2)"
]
},
{
"cell_type": "code",
"execution_count": 103,
"metadata": {},
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": "<Figure size 1200x240 with 5 Axes>",
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\r\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\r\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\r\n<!-- Created with matplotlib (https://matplotlib.org/) -->\r\n<svg height=\"125.725022pt\" version=\"1.1\" viewBox=\"0 0 578.921541 125.725022\" width=\"578.921541pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\r\n <metadata>\r\n <rdf:RDF xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\r\n <cc:Work>\r\n <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\r\n <dc:date>2021-02-13T11:41:39.494748</dc:date>\r\n <dc:format>image/svg+xml</dc:format>\r\n <dc:creator>\r\n <cc:Agent>\r\n <dc:title>Matplotlib v3.3.4, https://matplotlib.org/</dc:title>\r\n </cc:Agent>\r\n </dc:creator>\r\n </cc:Work>\r\n </rdf:RDF>\r\n </metadata>\r\n <defs>\r\n <style type=\"text/css\">*{stroke-linecap:butt;stroke-linejoin:round;}</style>\r\n </defs>\r\n <g id=\"figure_1\">\r\n <g id=\"patch_1\">\r\n <path d=\"M 0 125.725022 \r\nL 578.921541 125.725022 \r\nL 578.921541 0 \r\nL 0 0 \r\nz\r\n\" style=\"fill:none;\"/>\r\n </g>\r\n <g id=\"axes_1\">\r\n <g id=\"patch_2\">\r\n <path d=\"M 18.186235 75.356587 \r\nC 19.135179 82.695261 22.180991 89.606796 26.956925 95.258972 \r\nC 31.732859 100.911148 38.039207 105.067618 45.116622 107.227902 \r\nC 52.194038 109.388186 59.746576 109.461951 66.864836 107.440315 \r\nC 73.983096 105.418678 80.369425 101.386176 85.254845 95.828361 \r\nC 90.140266 90.270547 93.320493 83.419821 94.412594 76.101082 \r\nC 95.504695 68.782342 94.463005 61.301623 91.412953 54.559678 \r\nC 88.362901 47.817734 83.432026 42.096479 77.214055 38.084813 \r\nC 70.996084 34.073146 63.751022 31.938815 56.35125 31.938815 \r\nL 56.35125 70.421573 \r\nL 18.186235 75.356587 \r\nz\r\n\" style=\"fill:#008000;\"/>\r\n </g>\r\n <g id=\"patch_3\">\r\n 
<path d=\"M 56.35125 31.938815 \r\nC 50.881072 31.938815 45.473243 33.105162 40.489355 35.35987 \r\nC 35.505467 37.614579 31.059133 40.906249 27.447655 45.014785 \r\nC 23.836176 49.123322 21.141881 53.955066 19.544921 59.186946 \r\nC 17.947961 64.418826 17.484742 69.931575 18.186235 75.356587 \r\nL 56.35125 70.421573 \r\nL 56.35125 31.938815 \r\nz\r\n\" style=\"fill:#ff0000;\"/>\r\n </g>\r\n <g id=\"matplotlib.axis_1\"/>\r\n <g id=\"matplotlib.axis_2\"/>\r\n <g id=\"text_1\">\r\n <!-- Uw-rbf Accuracy -->\r\n <g transform=\"translate(7.2 16.318125)scale(0.12 -0.12)\">\r\n <defs>\r\n <path d=\"M 8.6875 72.90625 \r\nL 18.609375 72.90625 \r\nL 18.609375 28.609375 \r\nQ 18.609375 16.890625 22.84375 11.734375 \r\nQ 27.09375 6.59375 36.625 6.59375 \r\nQ 46.09375 6.59375 50.34375 11.734375 \r\nQ 54.59375 16.890625 54.59375 28.609375 \r\nL 54.59375 72.90625 \r\nL 64.5 72.90625 \r\nL 64.5 27.390625 \r\nQ 64.5 13.140625 57.4375 5.859375 \r\nQ 50.390625 -1.421875 36.625 -1.421875 \r\nQ 22.796875 -1.421875 15.734375 5.859375 \r\nQ 8.6875 13.140625 8.6875 27.390625 \r\nz\r\n\" id=\"DejaVuSans-85\"/>\r\n <path d=\"M 4.203125 54.6875 \r\nL 13.1875 54.6875 \r\nL 24.421875 12.015625 \r\nL 35.59375 54.6875 \r\nL 46.1875 54.6875 \r\nL 57.421875 12.015625 \r\nL 68.609375 54.6875 \r\nL 77.59375 54.6875 \r\nL 63.28125 0 \r\nL 52.6875 0 \r\nL 40.921875 44.828125 \r\nL 29.109375 0 \r\nL 18.5 0 \r\nz\r\n\" id=\"DejaVuSans-119\"/>\r\n <path d=\"M 4.890625 31.390625 \r\nL 31.203125 31.390625 \r\nL 31.203125 23.390625 \r\nL 4.890625 23.390625 \r\nz\r\n\" id=\"DejaVuSans-45\"/>\r\n <path d=\"M 41.109375 46.296875 \r\nQ 39.59375 47.171875 37.8125 47.578125 \r\nQ 36.03125 48 33.890625 48 \r\nQ 26.265625 48 22.1875 43.046875 \r\nQ 18.109375 38.09375 18.109375 28.8125 \r\nL 18.109375 0 \r\nL 9.078125 0 \r\nL 9.078125 54.6875 \r\nL 18.109375 54.6875 \r\nL 18.109375 46.1875 \r\nQ 20.953125 51.171875 25.484375 53.578125 \r\nQ 30.03125 56 36.53125 56 \r\nQ 37.453125 56 38.578125 55.875 \r\nQ 39.703125
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA8UAAADRCAYAAADlnRB8AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAABJ0AAASdAHeZh94AABOfElEQVR4nO3dd7wcVfnH8c+Tm0YKCQlptIQeOiGUEBCC9CagdBSD/gBBEJUiRVzGgghSRRBQiCJVOkgvCb2T0JEWQg0llfRyfn+cWe5ms7fsvbt7Zna+79drX5vMzO4+d3dmzjxzmjnnEBEREREREcmiDqEDEBEREREREQlFSbGIiIiIiIhklpJiERERERERySwlxSIiIiIiIpJZSopFREREREQks5QUi4iIiIiISGYpKRYREREREZHMUlIsIiIiIiIimaWkWERERERERDJLSbGIiIiIiIhklpJiERERERERyazMJcVmdoaZOTMbVebrNjWzB8zsy/j146sSoEgFaD+XJIr3qbGh4xCpN2Y2Jj6+hoSORbLJzEbH++Do0LG0h5mNiv+OM8p4TZuuuSRZ2pQUF+wwY5vZZki8zcS2BpcUZrYs8F9gc+B6IAL+VuZ7XBF/H7PNrHflo5SQzGx4/Ps+08T6g+L1zsxWLbF+GTObG+8fXaofcckYtZ9L2fL7deg4kkrHhOQVlAH5x6L4BuTDZnZw6Phay8y2Kvgbjggdj1SHmTWY2eFmNs7MppjZAjP73MxeNrO/m9l3QsdYD8ysm5lNi4+na0PHk2UdQweQEpsD/YHTnHNnlvtiM+sJHAg4YBng+8DFFY1QQnsJmAoMN7NlnXMzitZvj//9Dfg28I+i9VsBXYAHnHPzqh1sE7SfS7WtA8wOHUSt6JiQJkTxcydgKLAXsJ2Zbeqc+2W4sFotnwi7+N+XB4xFqsDMGoC7gF2Aafgb5h8BnYH1gIPx++4dBS+7FXga+LSWsVbBs/iy6ssafd4BQC/88fRdM+vrnPuqRp8tBTLXfLqNVoifP2nj6w8GegDnA/OBwysRlCSHc24xMBZoALYtscm34/Vfxf8utR7goSqE11raz6WqnHNvOucmhY6jhnRMyFKcc2fEj9Occ98DdsZfEP886c2f49YO+wFvA7fgbwQPCxqUVMNB+IR4AjDEOfd959zJzrlfOud2BpYHTi98gXNuenyOnx4g3opxzs2O/45aJcVHAIuBc/CVIz+s0edKkZomxWZ2ZNw84PCi5YcVNC/rUrTumbhZ6TKt/IyJ8WNZMzsv/veCUn0DzOyHZvaSmc2Jm4RcaWYDC9YPiZsF/jNedFVBk6HRZfzph+N3+AuAO4ENzWyLZv6GoXEsE81sXhzbY2Z2VFu2LWjKPqaJzxtb3PyxsE+FmW1uZv+Nm89802fJzLYzs8vN7HUzmxF/j6+aWc7MujbxWQ1m9hMze8LMpseveSduirNmvM0f488peWKwxqbKdzX1HQaST2iXSHrj72vVeP04YLsSry0rKdZ+rv08jaxEtxsr6ItlZvua2bPmy4IpZna9ma3YxHv1ib/DN+Lvd7qZPWRmO5XYtpeZnWi+mepHZjbfzL4wszvMbMvmYjWzgfHv9rH55q6jy/iTdUw0vq+OiSY45x4C3sS3JNosvzz+e2+Of9t5ZvaBmV1iZoNaes94/3Bm9kgz27xivtxo8f0KfB/f6mFM/IDGmuNSn9HNzH5lZs+b2Uwz+zo+Zi8yswFt2bbUfluwrmS/VmuhzDSzFczsN/H++Vl8jvjEzK41s3Wb+fs2N7Mb4vPDPDP71MzuN7P94/XV+h2qbWT8PKZUkhsnjkv8TU199/G6nePvdlZ8Prkt/m6W6gtfeN4ys9XN7CYz+yreJ+43s/Xj7frF56FPzecJz5lZqeurfBnwRzN7K952qpndZ2Y7lNj2m3NgiXXDzezeOJYZZvagNVGGtEb8t4zAX/v9CX/z9P9aeM0B5su6KfHfMtHMrj
OzTduyrTXTH9qaKEMKfrfVzOxY803q51hcvptZZzM7xszuNn/emhfH8KCZ7drM37ZSfLy/Hb/fFPPXBKfH6xvM7MP4u+/RxHv8JY5t3+a+x1JqXVOcv+Dfvmh5/v/LAN/sXGbWCxgOPOWcm1PG53QGHgb2Bu4HLgTeL9rmF/j+khPwFyxvAYcBT5pZv3ibafhmTrfH/789/n8EjG9NIObvoA4HHnLOfUgLhYiZ7Q68iL9T9BpwHnAzvgbypLZu2w5bAo8BXYEr8YnT/Hjdr4Cd8N/FZcDf43VnAPeYb35TGG9n4B7gUmBl4FrgIuAFYB98E2Li91pM0wXtkfFzWf1da+Dh+Lmp/fth4BFgUGEha74v76b45tcvlvF52s+1n9eTo4F/AxOBvwKv4puVPWhL3ywdjP8+Twa+wH9HN+CbvN1rRTde4+V/wH/f/8XvQw/gb0Y9ama7NBFTH3xzwBH4WrGLgcmt+WN0TCwRr46Jlln87ADMbA/gSWBP4EH8b/wWcBTwvJUYm6KQc+5NfHkzyszWWurDzEYC6wO3O+fKae6av9HzL+Be4DPgYDPrXuIzlov/hrPwLSauxO8Db+DLoXXasm07NFdmboM/n0zDH0vn44/9fYFnzWyjEn/f4XHMe8fP5+LPL/3x57Nq/g7Vlm++u1TM5TKzA/HH/zDgP/jjfDngKWBIMy8dAjwDDMCfP+8HdgDGmr+R9jT+JtINwI3ARvjzzypFn98b//ucDEzHXwvdjD/H3W9mR9IK8W/1WBzDPfjyYD6+FWCTNztbkD/XjXHOTcHfPF3HzL5V4vMtTk6vBzbEl0nnxzF9C9ijLdu204XA74BX4n8/ES/vE/+/J76sPQ/f1H4YcLeZLZX4x4n6BOBYfIvFi4BrgJn4sgXn3CLgivh9DyrxHvluSp/ReE3bes65sh/AKPyJe2wz2wyJt5lYtPwD4HPACpZ9gk+YFwG/K1i+V/wep5cR28T4NQ8C3UusPyNePx8YVrTu/HjdP4qWj46Xj27Dd/W3+LUHxf/viO9v8TWwbNG2y+MP2PnAtiXea6U2bpv/LcY0EeNYvyuU/I0dcGQTr1ut8HcsWP67+HUHFC0/M15+B9ClaF0XoF/B/++Kt12/aLue+ANkEtDQlv23mo94X15c9LfkD+qO+L44DjimYP2e8bJbtJ9rP0/Dfl4Uqyv+XlvYdmwT++oMYIOiddfG6/Yv8VsuBg4sWt4bn6jNAQYULO8FLF9q/8Efs2809XfhL/47tuF70TGR0WOihf3flVi+Q7w/LwYG45PCr/DXRN8q2vZX8fvcX7R8TLx8SMGyfeNlfy7xmfntdywj/hHxa+4rWPbneNmPS2yfP34vBToUresB9GrjtkvttwXrRlOiHKPlMrM/0LPE8o3wx+w9RcvXBRYAU4D1Sryu8Dis6O9Qo311GP4csxi4GvguMLiF1yz13cfH7VRgHrBR0fZn0XiuKdxvhxQsP63oNafHy6fgz7EdCtb9IF53ftFrLouXX8aSucea+HPpvKLPHxVvf0bBMsO35nDAXkXvf1xBvKPK+I67xn/HNGCZeNke8ftcXWL7I+J1zxYeD/G6BmBQG7c9o6nYaaIMKdhvPwZWLfG6LoXHQMHyXvgb3lPyf3O8vDP+BpUDDm7heBqEP/aeb2Yf/EOb9vs2Hiz5HWZsM9vkv8iJRcuvipdvGP9/3fj/RwHPAU8WbHtRvG5kGbFNjF+zURPr8z/+P0qs6xXvnHMoKLhpY7IAdMdf6E0DuhYszxciRxVtf3y8/MJWvHc525bcqQvWj6XpC6OX2rB/9Ilfe2XBsob4e5gNrNCK99g9fo+/FC0/Ml7+m7bsu9V+4AsPR8FFPP6i++6C/0+mIAGmMUn9aRmfo/186W21nwd4ULmk+Pcltt+OootJ/EWqA/7TxGfsFa8/upUx5cuZVUrEOg/o34bvRMdE47LMHRPN/G0ufpwRP/4A3AQsjJefF293SPz/a0u8R0caLx
5XKVg+hqWTi4748udLljzX945/j3cocXOjmfj/QcGNnnjZ+vGyZ4q27Y9P6j+hRBLa1m2b2m8L1o2m+aR4ozb8bnc
},
"metadata": {}
}
],
"source": [
"fig, ax = plt.subplots(nrows=1, ncols=len(models))\n",
"fig.set_figwidth(2 * len(models))\n",
"fig.set_figheight(2)\n",
2021-02-04 13:34:25 +00:00
"\n",
2021-02-13 11:53:15 +00:00
"for (name, acc), ax in zip(accuracy.items(), ax):\n",
" ax.pie([acc, 1 - acc], colors=['g', 'r'], startangle=90, counterclock=False)\n",
" ax.set_title(f\"{name.capitalize()} Accuracy\")\n",
"\n",
"fig.show()"
2021-02-04 13:34:25 +00:00
]
},
{
"cell_type": "code",
2021-02-13 11:53:15 +00:00
"execution_count": 104,
2021-02-04 13:34:25 +00:00
"metadata": {},
"outputs": [
{
"output_type": "display_data",
"data": {
2021-02-13 11:53:15 +00:00
"text/plain": "<Figure size 1080x480 with 2 Axes>",
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\r\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\r\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\r\n<!-- Created with matplotlib (https://matplotlib.org/) -->\r\n<svg height=\"278.674375pt\" version=\"1.1\" viewBox=\"0 0 602.462131 278.674375\" width=\"602.462131pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\r\n <metadata>\r\n <rdf:RDF xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\r\n <cc:Work>\r\n <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\r\n <dc:date>2021-02-13T11:41:40.828768</dc:date>\r\n <dc:format>image/svg+xml</dc:format>\r\n <dc:creator>\r\n <cc:Agent>\r\n <dc:title>Matplotlib v3.3.4, https://matplotlib.org/</dc:title>\r\n </cc:Agent>\r\n </dc:creator>\r\n </cc:Work>\r\n </rdf:RDF>\r\n </metadata>\r\n <defs>\r\n <style type=\"text/css\">*{stroke-linecap:butt;stroke-linejoin:round;}</style>\r\n </defs>\r\n <g id=\"figure_1\">\r\n <g id=\"patch_1\">\r\n <path d=\"M 0 278.674375 \r\nL 602.462131 278.674375 \r\nL 602.462131 0 \r\nL 0 0 \r\nz\r\n\" style=\"fill:none;\"/>\r\n </g>\r\n <g id=\"axes_1\">\r\n <g id=\"patch_2\">\r\n <path d=\"M 60.570313 241.118125 \r\nL 279.370313 241.118125 \r\nL 279.370313 22.318125 \r\nL 60.570313 22.318125 \r\nz\r\n\" style=\"fill:#ffffff;\"/>\r\n </g>\r\n <g clip-path=\"url(#p4b8a98466a)\">\r\n <image height=\"219\" id=\"image100b087ff4\" transform=\"scale(1 -1)translate(0 -219)\" width=\"219\" x=\"60.570313\" 
xlink:href=\"data:image/png;base64,\r\niVBORw0KGgoAAAANSUhEUgAAAW0AAAFtCAYAAADMATsiAAAGKElEQVR4nO3WMYoVVhiG4bnjNShBglWIoo2BNGIIDAixTDWQDQjuwNoFWFkE0ti5i0AgjY0gdhYWdgFxIIQUKSQZMcGZmy2kOj8v8zwr+A4HXv7Nrfs/7vbOmPdfnLkn71199u/0hKXOP305PWG53bdfT09YbvPi1fSE5fanBwDw/4k2QIhoA4SINkCIaAOEiDZAiGgDhIg2QIhoA4SINkCIaAOEiDZAiGgDhIg2QIhoA4SINkCIaAOEiDZAiGgDhIg2QIhoA4SINkCIaAOEiDZAiGgDhIg2QIhoA4SINkCIaAOEiDZAiGgDhIg2QIhoA4SINkCIaAOEiDZAiGgDhIg2QIhoA4SINkCIaAOEiDZAiGgDhIg2QIhoA4SINkCIaAOEiDZAiGgDhIg2QIhoA4SINkCIaAOEiDZAiGgDhIg2QIhoA4SINkCIaAOEiDZAyPbKL79Nb1ju75ufT09Y7uLrs/XPp5cuTU9Y7vIPR9MTlvvzzvSC9VzaACGiDRAi2gAhog0QItoAIaINECLaACGiDRAi2gAhog0QItoAIaINECLaACGiDRAi2gAhog0QItoAIaINECLaACGiDRAi2gAhog0QItoAIaINECLaACGiDRAi2gAhog0QItoAIaINECLaACGiDRAi2gAhog0QItoAIaINECLaACGiDRAi2gAhog0QItoAIaINECLaACGiDRAi2gAhog0QItoAIaINECLaACGiDRAi2gAhog0QItoAIaINECLaACGiDRAi2gAhog0QItoAIaINELL9+Obt9IblPr3wyfSE5f66fX16wlLfPXw+PWG5F99cnJ7AAi5tgBDRBggRbYAQ0QYIEW2AENEGCBFtgBDRBggRbYAQ0QYIEW2AENEGCBFtgBDRBggRbYAQ0QYIEW2AENEGCBFtgBDRBggRbYAQ0QYIEW2AENEGCBFtgBDRBggRbYAQ0QYIEW2AENEGCBFtgBDRBggRbYAQ0QYIEW2AENEGCBFtgBDRBggRbYAQ0QYIEW2AENEGCBFtgBDRBggRbYAQ0QYIEW2AENEGCBFtgBDRBggRbYAQ0QYIEW2AENEGCBFtgBDRBggRbYAQ0QYIEW2AENEGCNluttvpDct9uPbZ9ITl7j36eXrCUj8dHkxPWG5z/o/pCcvt/jmZnrCcSxsgRLQBQkQbIES0AUJEGyBEtAFCRBsgRLQBQkQbIES0AUJEGyBEtAFCRBsgRLQBQkQbIES0AUJEGyBEtAFCRBsgRLQBQkQbIES0AUJEGyBEtAFCRBsgRLQBQkQbIES0AUJEGyBEtAFCRBsgRLQBQkQbIES0AUJEGyBEtAFCRBsgRLQBQkQbIES0AUJEGyBEtAFCRBsgRLQBQkQbIES0AUJEGyBEtAFCRBsgRLQBQkQbIES0AUJEGyBEtAFCRBsgRLQBQkQbIES0AUJEGyBku//VjekNyz1+8nh6wnIPDr6fnrDWx3fTC5bbbDbTE5bbTQ8Y4NIGCBFtgBDRBggRbYAQ0QYIEW2AENEGCBFtgBDRBggRbYAQ0QYIEW2AENEGCBFtgBDRBggRbYAQ0QYIEW2AENEGCBFtgBDRBggRbYAQ0QYIEW2AENEGCBFtgBDRBggRbYAQ0QYIEW2AENEGCBFtgBDRBggRbYAQ0QYIEW2AENEGCBFtgBDRBggRbYAQ0QYIEW2AENEGCBFtgBDRBggRbYAQ0QYIEW2AENEGCBFtgBDRBggRbYAQ0QYIEW2AENEGCBFtgBDRBggRbYAQ0QYIEW2AkM3J71/upkesdnh4d3rCer8eTS9Y6vT4eHrCevvnphesd3oyvWA5lzZAiGgDhIg2QIhoA4SINkCIaAOEiDZAiGgDhIg2QIhoA4SINkCIaAOEiDZAiGgDhIg2QIhoA4SINkCIaAOEiDZAiGgDh
Ig2QIhoA4SINkCIaAOEiDZAiGgDhIg2QIhoA4SINkCIaAOEiDZAiGgDhIg2QIhoA4SINkCIaAOEiDZAiGgDhIg2QIhoA4SINkCIaAOEiDZAiGgDhIg2QIhoA4SINkCIaAOEiDZAiGgDhIg2QIhoA4SINkCIaAOEiDZAiGgDhIg2QIhoA4SINkCIaAOE/AdU+zTq63xgWQAAAABJRU5ErkJggg==\" y=\"-22.118125\"/>\r\n </g>\r\n <g id=\"matplotlib.axis_1\">\r\n <g id=\"xtick_1\">\r\n <g id=\"line2d_1\">\r\n <defs>\r\n <path d=\"M 0 0 \r\nL 0 3.5 \r\n\" id=
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA+0AAAHUCAYAAABPrclfAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAABJ0AAASdAHeZh94AAEAAElEQVR4nOzddXgUV9vA4d/ZZDfuBiRogOBSb5GipUihFOql1F3eUu9Xl7eub41SpS5QL6W0QJFCDZcEiBB3183uzvfHbGSzuxFIQlqe+7pytZyZMzuzMud5zpw5ozRNQwghhBBCCCGEEF2P4UjvgBBCCCGEEEIIIVyTpF0IIYQQQgghhOiiJGkXQgghhBBCCCG6KEnahRBCCCGEEEKILkqSdiGEEEIIIYQQoouSpF0IIYQQQgghhOiiJGkXQgghhBBCCCG6KEnahRBCCCGEEEKILkqSdiGEEEIIIYQQoouSpF0IIYQQQgghhOiiJGkXQgghhBBCCCG6KEnahRAopTSl1Np22M5apZTWDrvUIZRSD9qPdcKR3hchhBCiK1BK9bG3je+2w7ZSlFIph79XHUMp9a79WPsc6X0Roi0kaReiBUqpCS0ltY0avJTO27N/h/bqMOgoSqk4pdQSpdQBpVS1UqpCKZWslPpJKXW/UirKvt6V9mP5rBXbvMe+7ov2f9d9fzSlVLlSKsBNPaWUSmy07oT2PFYhhBCdTyl1rP2c/rub5ec3Ou/3dbHcx94+VSqlvDp+j9tPe3YYdBSl1PFKqQ+VUgeVUjVKqVJ7W/ytUuoOpZSffb3H7MfyVCu2+YZ93Vvs/57Q6DNOVkopN/X87a9ft26fdj1Y0WVJ0i6EABgMXHykd6KrUUpNArYBVwC5wJvAi8CvQB/gIWC4ffWPgXJgjlIqvJltKuBy+z/faLLYAvgB57upPhnoZ19PCCHEv8NWoAg4VikV6GL5ZKBuFNskF8vHAF7ABk3Tag7h9TPQ44C7D6Huv5pS6iJgM3q7vB94DXgF+B0YCTwJdLev/ib653SxUsrYzDb9gPOAGuC9Jost6PHFVDfVzwMCkDjgqCNJuxACTdPiNU1LPdL70QUtBryBSzRNO0XTtBs0TbtH07RLNE0biN5gJwBomlaOnribaL4DZBJ64v2bpmm7myz7G8gGrnRT90r0Rn7VoR6QEEKIrkXTNBuwFvAATnWxyiT78gJcJ+11Zb8c4uvX2uOArEOp/2+llPJFT9A1YKqmaVM0TfuPpml3a5p2gaZpvdA7TPIBNE1LBn4GooAzmtl0XeK9TNO0wibLfkZv55uLA7LQ4wVxFJGkXYgO1PgeaqXUfKXUH/bha4VKqU+UUtFN1v/Yvv6AJuXv2ct/aVIeoJSqVUqtc/Ha5yul1iiliu3D5vYqpe51NXTO3RB1pVR3pdQ7SqlcpVSVUmqbUmpho2FcD7o5bk/7EPD99qFkaUqpJ5VSpkbrXKIa7n8/tdFQL6ftKqVOVEp9oZTKVkqZ7dtbrJTq4eb1j1VK/aiUKrMPI/tZKXWyq3XdUUpFAv2BEk3TmvaEA6Bp2g5N09IaFdVdOb+imU3XNcRNr7KD3nP+DnCcUmpkk/0JB84ElgFNG3khhBD/bHXtu0NSbh/+3Ne+/Fdgoou6Tkm7vR2+Tim12d4OViqltiqlblBKOcT/zQ1RV0oNVEotU0oVKf32sN+UUjPr2nCl1CWuDkYp5aeUeloplWqPAw4ope5sPOzb3tYn2/+5sEkccEmT7U1TSv2glMq3by/Rvv1gN68/RSm13r7PhUqpr5RSg1yt24xhQCCwS9M0lx0imqb9pmlacaOiurbdXdLdeJmrOKAAWI4+ai+i8QKl1AjgBPQ4Qa60H2UkaReic1wHfACkoPfa7gLOBX5ukkTXNQqTm9Sv+/cpSinvRuWnAp406V1XSr0NfISedC6zv2Yh8Ajwo1LKs6Udtietm4BLgL3AC+hD+F4Fbm6h+kfAjcB69KFkVc
Ad6Feu62xDH14OcND+/3V/axvtx2XARmA6sMa+H3+hJ8Z/KaV6NdnvU+yvOwVYAbwMmO3bPLGl426kBL1R9FdKdW9pZQBN0/6yH9dgpdSYpsuVUmHoiXcJ4O7e97rhdU0b/IXoV/GXtGZfhBBC/KOstv/XXfu/Gr0N7K6UGlK3UOnD6Y9DH16/xV5mBL5Db/uD0dvkN9Dj/v/hPCTbJXuSuxk4C70dfhFIBb5Eb8vcMQIrgXno7fCbgA/wBHB/o/XW2rcJsB3HOGBbo/14APgRvQ3/HngJOADcBmxUTW4pUErNt7/+ccDn6LFHGHpM4zQnQDMK7P/toez3rbfC1+i3053WND6x79sw+3Hs0zTtVzfbWIL+Hi5sUn4lenzwViv3RfybaJomf/Inf838ARPQT5Jrm1mnj32dlCblD9rLS4HhTZZ9ZF92TqOyfvayzxuVxdnLfrL/d3KjZc/by8Y1KrvEXrYc8HGzPzc3KXc6PvRGQQOebFI+En3olgY82GTZWnv530Boo3I/9AbWCnRr6bUbLRuInnAfAKKbLJts396XjcoUEG/f5pwm699sL9eACa387L+wr5+IHhycCPi2UOdae513XSy7xb7sZTffnw32f/+MHoD5NFpnL3ojD3oHUKuPQ/7kT/7kT/66/h+QCdiAiEZlHwJl6B30Q+3n/hsaLT+jrs1vVFbX1v8P8GhU7tGobZ/TqLyuDXq3yf78Yi+/tkn59Ebt6SVNlqXYy39o0oZFAsX2P2NLr91o+UT78t+A4CbLLrEve75RmT96sl0LHNdk/bqYSQP6tOLzUMAf9vW3AdcDowFTC/WexEWMZF/2on3ZbU3KJ9jLP7C/7n4gvtFyH3tcsMr+7w2tPQ75+3f8yZV2ITrHS5qm7WxSVnfF9IS6Ak3TktAbvImNhpDV9bLfj56kNu6FnwxUoPeE17kZ/QrxZZqmVTV5zUfQG7MLm9tZ+zD289GvCD/aeJmmaduBpc3VB+7UGt2npWlaBXrgYUDv+W6ta9F7m2/WNC2jyX78AnwDnKEaZls/Bb2TY52maV832dbL6Ml3W1yJ3vnRF3ga/X0uU0ptV0o9quwzxzfxIfpncnbT3n8ahs27GhLX2BL0qyNnAyilxgGD0K9WCCGE+Hf6BT1hazwEfiKwXtM0i6bPg5KL4xB6h6Hx9qHvN6LPj3KLpmnWuhXt/38rerLXUhzQ077tAziOkkPTtBXoncvNualxDKJpWi76Vegg9Ha6tW6y//dKzXEYOpqmvYueTDc+ljlAKPCRpo9+a+xB9LimVTRN04D56BckRqLHEVuAcqXU7/bh/q4mDlyC/h5f2vhWBPvIyovQL0a828LrvgnEKaXG24vno8cFMtruKNXiEFkhRLto2nAA1N0LHdKkfDVwGTAKfTj6JCBL07TNSqm/sSft9nudhgE/aZpWay/zRW9Y8oH/KNdPDKlBnyW2OXHovbp/aZpW5mL5Bpq/b7stx9ucuvvQT1VKHe9ieST6lYOB6Ff3j7GXOw050zTNqpTaAMS29sU1TSsC5tnvKZyG3uFwPDDC/netUup0TdP+bFSnVCn1KfpneCH67QF1w/aHAH9omrajhZf+Ev0zvBK9g+Qq9KsG77Z234UQQvzjrEZP6iYBnymlBqPPTP58o3XWAlOVUgZNn8Cu6f3sA9GT1v3AvW7igCpajgNG2f+7yf46TW1Avw3NlRJN0w64KD/UOKAWvSP8bBfLTUCEUipM07QCmo8DSpRS23A92Z9Lmj5J70T7ZzEVPQ44odHfdUqpCZo+CV1dnQNKqTXon8009FsEQL9dIBT4TNO0/BZe+l30Cy1XAuvQ44B84KvW7rv4d5GkXYiW1TVWzY1MqVvmqmEDfThYU3WTiHg0Kf8FPeGbrJTajt7L/kOjZXcopYLQGwOF4/3sIfayCOCBZva3JUH2/+a4We6uHICmveF27o63OWH2/97ewnr+9v+2tN/ZbXjtepqmpaBfaVgMoJSKQb
+3/wz0Xu9RTaosQf8Mr8CetNP6q+xommZWSi0FFtkn0JsPfGO/UiGEEOLfqem8No3vZ6+zFjgHGK2USkV/7GiGpmn
2021-02-04 13:34:25 +00:00
},
"metadata": {
"needs_background": "light"
}
}
],
"source": [
2021-02-13 11:53:15 +00:00
"fig, ax = plt.subplots(nrows=1, ncols=2)\n",
"fig.set_figwidth(9)\n",
"# fig.set_figheight(15)\n",
"\n",
"args = {\n",
" 'normalize': 'true',\n",
" 'colorbar': False,\n",
" 'display_labels': playlist_names\n",
"}\n",
"\n",
"plot_confusion_matrix(clf, data_test, labels_test, ax=ax[0], **args)\n",
"ax[0].set_title('Unweighted SVM')\n",
"\n",
"plot_confusion_matrix(wclf, data_test, labels_test, ax=ax[1], **args)\n",
"ax[1].set_title('Weighted SVM')\n",
"\n",
"fig.tight_layout()\n",
"fig.show()"
2021-02-04 13:34:25 +00:00
]
},
{
"source": [
2021-02-13 11:53:15 +00:00
"## Unweighted Classes\n",
"\n",
"From the above unweighted scenario, it is clear that the Pop playlist was not effective for classifying similar tracks. This is likely primarily due to the larger size of the Rap (~800), EDM (~1,300) and Rock (~700) playlists compared to Pop (~125). Additionally, there is overlap with other genres such as Rap and EDM, where much of the confusion occurred. Also not helping is that one of the sub-playlists, electropop, is shared across EDM and Pop. As EDM is already a much larger playlist, it is unsurprising that this performance was poor. The overlap with Rock is understandable as Pop contains an Indie Pop sub-playlist which could have caused some confusion. Quite surprising was the confusion with Jazz, as I wouldn't have thought there would be much overlap here.\n",
"\n",
"The other major confusion was with Rock and Metal, specifically classing Metal tracks as Rock. This could be expected due to the similarity in tone.\n",
2021-02-04 13:34:25 +00:00
"\n",
2021-02-13 11:53:15 +00:00
"## Weighted Classes\n",
"\n",
"When weighting the classes by prevalence in the dataset, the model is generally better at classification. The clearest difference is the ability to classify Pop songs. Without weighting, no songs were correctly classified as Pop but were instead mis-identified as Rap, EDM, Rock and Jazz. After re-weighting, Pop tracks were correctly classified almost 60% of the time. Mis-identification as Rap, EDM and Rock dropped from a combined 85% to just 20%. Interestingly, the mis-classification of Pop as Jazz increased from 15% to 21%.\n",
"\n",
"The improved accuracy of the Pop model reduced the accuracy of some others. The accuracy of Rap, EDM and Rock decreased as some tracks were instead classified as Pop. EDM and Rock were more badly affected than Rap, with around a 15% Pop error rate compared to Rap's 9%. As discussed previously, this could be due to the overlap in aural tone. The overall accuracy of Rap was not significantly affected by this Pop error rate as, to compensate, the EDM error rate dropped from 12% to just 3%."
2021-02-04 13:34:25 +00:00
],
"cell_type": "markdown",
"metadata": {}
},
2021-02-13 11:53:15 +00:00
{
"cell_type": "code",
"execution_count": 105,
"metadata": {},
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": "<Figure size 1080x960 with 4 Axes>",
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\r\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\r\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\r\n<!-- Created with matplotlib (https://matplotlib.org/) -->\r\n<svg height=\"566.674375pt\" version=\"1.1\" viewBox=\"0 0 603.346449 566.674375\" width=\"603.346449pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\r\n <metadata>\r\n <rdf:RDF xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\r\n <cc:Work>\r\n <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\r\n <dc:date>2021-02-13T11:41:44.865763</dc:date>\r\n <dc:format>image/svg+xml</dc:format>\r\n <dc:creator>\r\n <cc:Agent>\r\n <dc:title>Matplotlib v3.3.4, https://matplotlib.org/</dc:title>\r\n </cc:Agent>\r\n </dc:creator>\r\n </cc:Work>\r\n </rdf:RDF>\r\n </metadata>\r\n <defs>\r\n <style type=\"text/css\">*{stroke-linecap:butt;stroke-linejoin:round;}</style>\r\n </defs>\r\n <g id=\"figure_1\">\r\n <g id=\"patch_1\">\r\n <path d=\"M 0 566.674375 \r\nL 603.346449 566.674375 \r\nL 603.346449 0 \r\nL 0 0 \r\nz\r\n\" style=\"fill:none;\"/>\r\n </g>\r\n <g id=\"axes_1\">\r\n <g id=\"patch_2\">\r\n <path d=\"M 60.570313 246.518125 \r\nL 284.770312 246.518125 \r\nL 284.770312 22.318125 \r\nL 60.570313 22.318125 \r\nz\r\n\" style=\"fill:#ffffff;\"/>\r\n </g>\r\n <g clip-path=\"url(#p694abe5739)\">\r\n <image height=\"224.4\" id=\"image2109f66ef9\" transform=\"scale(1 -1)translate(0 -224.4)\" width=\"224.4\" x=\"60.570313\" 
xlink:href=\"data:image/png;base64,\r\niVBORw0KGgoAAAANSUhEUgAAAXYAAAF2CAYAAAB6XrNlAAAGOUlEQVR4nO3WPYodBBiF4dw7g1NMpXEELWVAEIS0imAlIWDhBlJlB4KENDZioY2WIli5K0EQm5A0o1j5M1w3IRx453lWcKr3+w4P33l6uncHnZ6/XE+YObz5xnrCxO3Pv6wnTPz++P31hJmrJ7+uJ0wc1wMA+H8JO0CMsAPECDtAjLADxAg7QIywA8QIO0CMsAPECDtAjLADxAg7QIywA8QIO0CMsAPECDtAjLADxAg7QIywA8QIO0CMsAPECDtAjLADxAg7QIywA8QIO0CMsAPECDtAjLADxAg7QIywA8QIO0CMsAPECDtAjLADxAg7QIywA8QIO0CMsAPECDtAjLADxAg7QIywA8QIO0CMsAPECDtAjLADxAg7QIywA8QIO0CMsAPECDtAjLADxAg7QIywA8QIO0CMsAPECDtAjLADxAg7QIywA8Sc3/vjz/WGieP9V9cTdv7+Z71g4uz+a+sJE8+++Gk9YeaHRx+vJ0z42AFihB0gRtgBYoQdIEbYAWKEHSBG2AFihB0gRtgBYoQdIEbYAWKEHSBG2AFihB0gRtgBYoQdIEbYAWKEHSBG2AFihB0gRtgBYoQdIEbYAWKEHSBG2AFihB0gRtgBYoQdIEbYAWKEHSBG2AFihB0gRtgBYoQdIEbYAWKEHSBG2AFihB0gRtgBYoQdIEbYAWKEHSBG2AFihB0gRtgBYoQdIEbYAWKEHSBG2AFihB0gRtgBYoQdIEbYAWKEHSBG2AFihB0gRtgBYoQdIEbYAWKEHSBG2AFihB0g5vz2xcv1honj5eV6wszNp++tJ0z8+NW36wkTn7394XrC0G/rARM+doAYYQeIEXaAGGEHiBF2gBhhB4gRdoAYYQeIEXaAGGEHiBF2gBhhB4gRdoAYYQeIEXaAGGEHiBF2gBhhB4gRdoAYYQeIEXaAGGEHiBF2gBhhB4gRdoAYYQeIEXaAGGEHiBF2gBhhB4gRdoAYYQeIEXaAGGEHiBF2gBhhB4gRdoAYYQeIEXaAGGEHiBF2gBhhB4gRdoAYYQeIEXaAGGEHiBF2gBhhB4gRdoAYYQeIEXaAGGEHiBF2gBhhB4gRdoAYYQeIEXaAGGEHiBF2gBhhB4gRdoCY8/WAmePdvWnffPn9esLE5w8frydMHC9frCfM/Pvgej1h4u7WDSBK2AFihB0gRtgBYoQdIEbYAWKEHSBG2AFihB0gRtgBYoQdIEbYAWKEHSBG2AFihB0gRtgBYoQdIEbYAWKEHSBG2AFihB0gRtgBYoQdIEbYAWKEHSBG2AFihB0gRtgBYoQdIEbYAWKEHSBG2AFihB0gRtgBYoQdIEbYAWKEHSBG2AFihB0gRtgBYoQdIEbYAWKEHSBG2AFihB0gRtgBYoQdIEbYAWKEHSBG2AFihB0gRtgBYoQdIEbYAWKEHSBG2AFihB0gRtgBYoQdIEbYAWKEHSDm/Pjg3fWGibPvbtYTZr7+6JP1hInD8a/1hI2LV9YLZk5nh/WECR87QIywA8QIO0CMsAPECDtAjLADxAg7QIywA8QIO0CMsAPECDtAjLADxAg7QIywA8QIO0CMsAPECDtAjLADxAg7QIywA8QIO0CMsAPECDtAjLADxAg7QIywA8QIO0CMsAPECDtAjLADxAg7QIywA8QIO0CMsAPECDtAjLADxAg7QIywA8QIO0CMsAPECDtAjLADxAg7QIywA8QIO0CMsAPECDtAjLADxAg7QIywA8QIO0CMsAPECDtAjLADxAg7QIywA8QIO0CMsAPECDtAjLADxAg7QIywA8Qcbp9fn9YjFh5df7CeMHO4uFhPmLi9uVlPmDi7ulpPmDm99fp6woSPHSBG2AFihB0gRtgBYoQdIEbYAWKEHSBG2AFihB0gRtgBYoQdIEbYAWKEHSBG2AFihB0gRtgBY
oQdIEbYAWKEHSBG2AFihB0gRtgBYoQdIEbYAWKEHSBG2AFihB0gRtgBYoQdIEbYAWKEHSBG2AFihB0gRtgBYoQdIEbYAWKEHSBG2AFihB0gRtgBYoQdIEbYAWKEHSBG2AFihB0gRtgBYoQdIEbYAWKEHSBG2AFihB0gRtgBYoQdIEbYAWKEHSBG2AFihB0gRtgBYoQdIEbYAWKEHSDmP6qsLgEb2KFPAAAAAElFTkSuQmCC\" y=\"-22.118125\"/>\r\n </g>\r\n <g id=\"matplotlib.axis_1\">\r\n <g id=\"xtick_1\">\r\n <g id=\"line2d_1\">\r\n <defs>\r\n <path d=\"M
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA+8AAAOzCAYAAADJNjepAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAABJ0AAASdAHeZh94AAEAAElEQVR4nOzdd3hUVfrA8e+ZJJPeC70mEHqxF0C6NEUEXFxFEMW1/nCxr+6Kuva+69qwgb2AXQRRREBQkN4hhBZIT0jPJDPn98edSWYyMymQkBHez/PME7ht7pk7c8976lVaa4QQQgghhBBCCOG7TM19AkIIIYQQQgghhKidFN6FEEIIIYQQQggfJ4V3IYQQQgghhBDCx0nhXQghhBBCCCGE8HFSeBdCCCGEEEIIIXycFN6FEEIIIYQQQggfJ4V3IYQQQgghhBDCx0nhXQghhBBCCCGE8HFSeBdCCCGEEEIIIXycFN6FEEIIIYQQQggfJ4V3IYQQQgghhBDCx0nhXQgnSimtlPq5uc/jVKWUGmz/jOc0wrF8+loppX5WSunmPg8hhBC+SSn1jj0v69jc59LYlFL7lVL7G+E4Pv0ZKaWm289venOfizg9SOFdnPLsN1UpRAFKqYn2z+NjL+vvs68vU0oFeVjfyb5+X9OfbeNqzIqDpqKUGqGU+lwpdUQpZVFK5SmldiulPlVK/Z9SStm3e9+elpvrccwl9m0n2P/vCDS0UuqXWvbrqJSyye9HCHE6c7pfOl5WpVS2UuonpdRfm/v8PFFK+Sml8pVSFUqpcA/rWzml51ovx1huXz+o6c+4cTVWxUFTUUq1Vko9r5TarpQqUUqVKqUO2j/zR5VSifbtRtivwe/1OOZf7dt+6bTMcY1tjmN62XeZ07bTGyWRoslI4V0IV92Ba5r7JJrQMsAGDHYUBGsYBmggELjQy3qApcf5/r9jfMYvHef+pyyl1D+AJcA44A/gP8DrwGbgIuBFwM+++Vz73+vrOGZHYDhwFPi6xupKYKBSKtnL7tcDyr6dEEKc7h6yv54AlgODgPeVUs8161l5oLW2Aj8D/hjnWZMjL9fA0JorlVIhwHlAMbD6OE9jmNP7CDulVC9gC3A7Rv46D3gOWAxEAf8Ahtg3XwqkAmcrpXrXceiZ9r+v11heiZGXX+flfLoAg5G8/k9DCu9CONFa79RaH2zu82gqWutcYBOQAPRyXqeUCgQuAD7HKOC7ZehOy348zvcvsX/G2cez/6lKKdUBeBgoAM7QWl+itb5Ta32P1noS0BIYBVgBtNY/A7uB/kqpM2o59HUYmfbbWuuaGfM39r9uFQBKKT/gWmAtkHHcCRNCiFOE1nqO/XW/1noicDFG4fd2H+3S/ZP9r7e8vBT4juqCorMBgBlYobWuOJ4311qnaK1TjmffU9wLQAwwR2vdR2t9k/07NVNr3RdIBFYCaK018IZ9v5kejwYopZIwKvkPAYtqrM4A1gHXKqX8PezuiAFqVvALHyWFdyGceBpHrZSaY18+WCk1SSn1u72bU65S6iOlVBsvx4pRSj2ulNph7xJ1TCn1o1JqpIdtI5VSd9m74R22d5nOUkp9pZQ6v7ZzVUq1VEq9oZRKs3fnm15HMh0F75oZ+vlAMLAQ2OhhPRiZvKY6KEApFaKM7vYblVLFSqkipdRqpdSVHs7Za9d1pdTZ9i7ehUqpAqXUUqXU+c6fv5fPIU4p9bpS6qhSqlwpta1mN0Cl1DsYvQ4AHqzRBXJwjW2vtHchy1fG8IEdSqkH7JUbnt5/ilLqD/s1zlRKvauUau1p21qci9GqvkxrvaXmSq21TWu92J6ROzha3z1m6E4FcOfM39k2jBaVaUqpgBrrxgKtnd5DCCGEE631j8BOjArSsx3LlVJnKqUW2PODcqXUAaXUy0qpVnUdUynVzZ4vLatlmy3K6A5f1/G85fWOZaswWnvbeOiB5bGiXil1sVLqO2UMGyhXSqUopZ5WSkV5OE
+PXdft8c4L9linTCm1Uyk1WynV2Z72d7wlSCn1N3v6y5RSGfa8P9Jp/WBlDPPqAHSokde/U+NY3ZQxnv6QPebKUEp94OGzcGyfpIwhbHn2WOdXpdRYb+daiwvsf1/0tFJrvU9rvdNp0VsYreJXKw/DGe0cPeXe1FrbPKyfi9EIMM55oT3vnw78CmyvbwJE85LCuxD1dzPwHrAf+B+wFfgLsLRmwU4ZLal/APcCWcCrwMcYXca/V0rVLHB1Bx7FaPH+FqML1Q8YGegvSqlRXs4pBliD0b1tIUZ39LpaSh0F75rd2YY5rV8GnKWcxsoppXpi3Py3aK2z7MuiMGqIH8NoFX4LowtYPPCBUurfdZyL49iDgBUYlQPf2dNRaj+Pc2rZNQojADkf+Mz+3q2Bt5RS05y2+8K+Dozujg85vfY7ncdbwAdAErAA4zrnAo9gXDeXWmul1N+BD4HOwHzgbaA3RkYYXZ+02+XY/3a2F7rrYx5gAa5URhfHmkYDbYClWutUL8eYi3GtxtdYPhMowkibEEIIzxzDzzSAUmocxv3/Eowuz88Bu4CbgHVKqU61HcxeaFuGMbStq9ubKXUBRq+5L7XWR+s41jaMeKCvUirW6RiJGIVbR14P3uOBH532exD4HqOy+VuMoV17gTuBVUqpiNrOx36MIPv7zgIyMQqwPwP3A8/WsftT9tcmjLw5DSOv+txpm/0Y+fox+8s5r//C6TxGAeuBqzB6mL1gT+vlwO+qRo82ZXQtXwNMwqj0fhE4bD/m5XWluwZHfu92fT3RWqdj9JSLBibWXG+PS6ZRHYN58iHGEIiaPe0uxeiJKRX1fyZaa3nJ65R+YWSqugHb/lxj2Rz78gKgd411H9jXXVFj+c8YBfEpNZZHYbRqlwItnJZHAnEezqctcATY4S1dGIVG/wZ8HqEYhb58wM9p+Spgp/3fY+3HHue0/jb7sueclr1jX3Z3jfcIwsjkbUA/p+WD7dvPcVpmAvbYl4+ucZwbndI52Ev636iRjh4YtdTba2zv9t411k+3r18IBHv5DsxyWtbR/jnmAh1rpGdBA793oRhBhwZ+AWYAPZ3T5WW/j+37TPew7kv7ukle0vlv+/seAxY7rW9j//zm2v9/uL7pkJe85CWvU+3l7V6OMZ+Izf7qAIRhFMyswMAa295jP86SGssdeWhHp2WT7Mue8fCeju1H1PPcHTHKJKdlM+3LzsOofMgEPnNaH2nPA7IBZV/m6HX3KxBV4z0cecrzNZbvB/bXWPZP+7YfOo5tX94Oo6FDA+94SfNBoL3Tcn97fqmBc+p6b6d10UCePX09aqzrhVFxvb7G8iXUiAHsy8dTHYtMr+c1eca+fTrwIMacBBF17DMaD/Gpfd0E+7pvvHx3D9v//Yb9urZ1Wv89RgwQghET1Dsd8mq+l7S8C1F//9HuXZodtZVVrcNKqb4YY48WaK0/ct5Ya52PcbMOwqkGVWt9THsYB661PozRotxNKdXewzlZgDu1+3hmr7TWxcBvGBn0mfZzDsPo9udolV+BEYA4d7dz6UZnr8m/GlintX6qxnuUYQQrCqhrNt4LMFq6l2mta47Veh1jbLc3JcBsbUzO43jv7RgVEd3t6aqvWRgZ2wytdWmNdY9gBGVXOS27CggA/qu13u/0/jbgLoyArl7s1+RSjIqdgcCbGD07CpUx++zNXrrtOyamcalNt3enHIMRlH1Zc6ca7/sBMEJVj9mcgdGFX2rihRDCzj6Ea44yZgP/DKPgo4AXtNYHMApyMcDHWusVNXZ/FqNAOcJLXu7sC4xJRqc73/ftPd2uAFKo/6SxnrrODwUKMfJujdHY4DyJ7WCqh3Fp+7L/s/+daY9jqmit38HIu5zzR2+mYeSN9zkdG631IYzW79o8rJ3mJLLHPW/b/1tbD72arsFoSHnQHi9U0Vpvxcj7+iulegAopdoCIzAmjnupxvZfYvTma4j77e8Ri9EwsBzItw8feEEp1dnDPouBA8
BF9l4Azhz5f82J6mqai3FdZ0BVD9ERwPta65IGpkE0I08TFwghPFvnYdkh+1/nLtKOMeqRyvNjyeLtf7s7L1RKXYh
},
"metadata": {
"needs_background": "light"
}
}
],
"source": [
"# 2x2 grid: one confusion matrix per SVM variant, all scored on data_test / labels_test\n",
"fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(9, 8))\n",
"\n",
"kernel_panels = [(lclf, 'Linear Weighted SVM'),\n",
"                 (pclf, 'Poly Weighted SVM'),\n",
"                 (wclf, 'RBF Weighted SVM'),\n",
"                 (sclf, 'Sigmoid Weighted SVM')]\n",
"\n",
"# ax.flat walks the grid row by row: top-left, top-right, bottom-left, bottom-right\n",
"for kernel_ax, (kernel_clf, kernel_title) in zip(ax.flat, kernel_panels):\n",
"    plot_confusion_matrix(kernel_clf, data_test, labels_test, ax=kernel_ax, **args)\n",
"    kernel_ax.set_title(kernel_title)\n",
"\n",
"fig.tight_layout()\n",
"fig.show()"
]
},
2021-02-04 13:34:25 +00:00
{
"source": [
"## Other Tests\n",
"\n",
"Take a handful of other tracks that I don't listen to and that aren't in any of the playlists, to see whether they can also be classified."
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
2021-02-13 11:53:15 +00:00
"execution_count": 34,
2021-02-04 13:34:25 +00:00
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Top Of The World (Five Finger Death Punch) could be ROCK ✓\nAston Martin Music (Rick Ross) could be RAP ✓\nOn The Sunny Side Of The Street (Dizzy Gillespie) could be JAZZ ✓\nVibez (ZAYN) could be EDM ✗\nShot In The Dark (AC/DC) could be ROCK ✓\nTo Hell and Back (Sabaton) could be ROCK ✗\nWithstand The Fall Of Time (Immortal) could be METAL ✓\nAlone Together - Rudy Van Gelder Remaster (Kenny Dorham) could be JAZZ ✓\nFeel No Ways (Drake) could be RAP ✓\nBO$$ (Fifth Harmony) could be EDM ✗\n\n70.00% Accurate\n"
]
}
],
"source": [
"### PREPARE ###\n",
"# each URI is paired with the playlist I would expect it to land in (inferred by hand)\n",
"labelled_uris = [(\"spotify:track:53yqxU2EKKzbuQZEUEVtxc\", \"ROCK\"),\n",
"                 (\"spotify:track:5W7xC99N2Zzfh69r7I7zWK\", \"RAP\"),\n",
"                 (\"spotify:track:38R2EViAkYOFG8ZkG3GLtW\", \"JAZZ\"),\n",
"                 (\"spotify:track:6T6D9CIrHkALcHPafDFA6L\", \"POP\"),\n",
"                 (\"spotify:track:0sfdiwck2xr4PteGOdyOfz\", \"ROCK\"),\n",
"                 (\"spotify:track:1BrgjqSg9du0lj3TUMLluL\", \"METAL\"),\n",
"                 (\"spotify:track:5nCnSnLtotQ8eB4E189U91\", \"METAL\"),\n",
"                 (\"spotify:track:3GOZbK2epuHzCt5YvvVFHO\", \"JAZZ\"),\n",
"                 (\"spotify:track:3cjF2OFRmip8spwZYQRKxP\", \"RAP\"),\n",
"                 (\"spotify:track:1COvXs6jaykXC73h9OSBVM\", \"POP\")]\n",
"test_uris, test_labels = map(list, zip(*labelled_uris))\n",
"\n",
"# fetch the track objects; audio features are then populated on those same objects\n",
"# (the call's return value is unused and .audio_features is read off test_tracks below)\n",
"test_tracks = spotnet.tracks(uris=test_uris)\n",
"spotnet.populate_track_audio_features(tracks=test_tracks)\n",
"\n",
"# keep only the descriptor columns named in headers\n",
"# NOTE(review): column order follows audio_features.to_dict(), not headers - confirm it matches the order clf was trained on\n",
"test_features = [{name: value\n",
"                  for name, value in track.audio_features.to_dict().items()\n",
"                  if name in headers}\n",
"                 for track in test_tracks]\n",
"\n",
"### PREDICT ###\n",
"predictable_frame = pd.DataFrame(test_features)\n",
"\n",
"predicted_labels = clf.predict(predictable_frame)\n",
2021-02-13 11:53:15 +00:00
"# predicted_labels = wclf.predict(predictable_frame)\n",
2021-02-04 13:34:25 +00:00
"# predictions are indices into playlist_names; compare each with its hand-assigned label\n",
"labels_correct = [playlist_names[predicted_labels[pos]] == expected\n",
"                  for pos, expected in enumerate(test_labels)]\n",
"\n",
"### EVALUATE ###\n",
"for track, label, is_match in zip(test_tracks, predicted_labels, labels_correct):\n",
"    print(f'{track.name} ({track.artists[0].name}) could be {playlist_names[label]} {\"✓\" if is_match else \"✗\"}')\n",
"\n",
"# fraction of tracks whose predicted playlist matched the expected one\n",
"correct = sum(labels_correct) / len(labels_correct)\n",
"print(f'\\n{correct:.2%} Accurate')"
]
},
{
"source": [
"# Imports & Setup"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
2021-02-13 11:53:15 +00:00
"execution_count": 2,
2021-02-04 13:34:25 +00:00
"metadata": {},
"outputs": [],
"source": [
"from datetime import datetime\n",
"\n",
"from google.cloud import bigquery\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib as mpl\n",
"mpl.rcParams['figure.dpi'] = 120\n",
"\n",
"from analysis.net import get_spotnet, get_playlist, track_frame\n",
"from analysis.query import *\n",
"from analysis import spotify_descriptor_headers, float_headers, days_since\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"from sklearn import svm\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.metrics import plot_confusion_matrix\n",
"\n",
"# service handles\n",
"client = bigquery.Client()\n",
"spotnet = get_spotnet()\n",
"\n",
"# path handed to get_query() / to_csv() in the scrobble cells below\n",
"cache = 'query.csv'\n",
"first_day = datetime(2017, 11, 3)"
]
},
{
"source": [
"## Read Scrobble Frame"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# presumably loads the scrobble frame from the cache path when it exists - confirm in analysis.query\n",
"scrobbles = get_query(cache=cache)"
]
},
{
"source": [
"## Write Scrobble Frame"
],
"cell_type": "markdown",
"metadata": {}
},
{
"source": [
"# reset_index() turns the existing index into regular column(s); the file is written tab-separated\n",
"scrobbles.reset_index().to_csv(path_or_buf=cache, sep='\\t')"
],
"cell_type": "code",
"metadata": {},
"execution_count": 6,
"outputs": []
}
]
}