listening-analysis/playlist-classifier.ipynb

311 lines
159 KiB
Plaintext
Raw Normal View History

2021-02-04 13:34:25 +00:00
{
"metadata": {
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.6-final"
},
"orig_nbformat": 2,
"kernelspec": {
"name": "python3",
"display_name": "Python 3",
"language": "python"
}
},
"nbformat": 4,
"nbformat_minor": 2,
"cells": [
{
"source": [
"# Playlist Classifier\n",
"\n",
"Given a list of playlists, can unknown tracks be correctly classified?"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"playlist_names = [\"RAP\", \"EDM\", \"ROCK\", \"METAL\", \"JAZZ\", \"POP\"] # super-genres\n",
"# playlist_names = [\"DNB\", \"HOUSE\", \"TECHNO\", \"GARAGE\", \"DUBSTEP\", \"BASS\"] # EDM playlists\n",
"# playlist_names = [\"20s rap\", \"10s rap\", \"00s rap\", \"90s rap\", \"80s rap\"] # rap decades\n",
"# playlist_names = [\"UK RAP\", \"US RAP\"] # UK/US split\n",
"# playlist_names = [\"uk rap\", \"grime\", \"drill\", \"afro bash\"] # british rap playlists\n",
"# playlist_names = [\"20s rap\", \"10s rap\", \"00s rap\", \"90s rap\", \"80s rap\", \"trap\", \"gangsta rap\", \"industrial rap\", \"weird rap\", \"jazz rap\", \"boom bap\", \"trap metal\"] # american rap playlists\n",
"# playlist_names = [\"rock\", \"indie\", \"punk\", \"pop rock\", \"bluesy rock\", \"hard rock\", \"chilled rock\", \"emo\", \"pop punk\", \"stoner rock/metal\", \"post-hardcore\", \"melodic hardcore\", \"art rock\", \"post-rock\", \"classic pop punk\", \"90s rock & grunge\", \"90s indie & britpop\", \"psych\"] # rock playlists\n",
"# playlist_names = [\"metal\", \"metalcore\", \"mathcore\", \"hardcore\", \"black metal\", \"death metal\", \"doom metal\", \"sludge metal\", \"classic metal\", \"industrial\", \"nu metal\", \"calm metal\", \"thrash metal\"] # metal playlists\n",
"\n",
"# headers = float_headers + [\"duration_ms\", \"mode\", \"loudness\", \"tempo\"]\n",
"headers = float_headers"
]
},
{
"source": [
"Pull and process playlist information.\n",
"\n",
"1. Get live playlist track information from Spotify\n",
"2. Filter listening history for these tracks\n",
"\n",
"Filter out tracks without features and drop duplicates before taking only the descriptor parameters."
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"# 1) fetch the live track listing for every named playlist\n",
"playlists = [get_playlist(name, spotnet) for name in playlist_names]\n",
"\n",
"# 2) join each playlist's tracks with the scrobbles on track/artist, then\n",
"#    drop rows without a uri, keep one row per uri and select only the descriptor columns\n",
"filtered_playlists = [\n",
"    pd.merge(track_frame(playlist.tracks), scrobbles, on=['track', 'artist'])\n",
"      .dropna(subset=['uri'])\n",
"      .drop_duplicates(['uri'])\n",
"      .loc[:, headers]\n",
"    for playlist in playlists\n",
"]"
]
},
{
"source": [
"Construct the dataset with associated labels before splitting into a train and test set."
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"# stack every playlist's tracks into a single frame\n",
"dataset = pd.concat(filtered_playlists)\n",
"# label each row with the index of the playlist it came from (its position in filtered_playlists)\n",
"playlist_sizes = [len(frame) for frame in filtered_playlists]\n",
"labels = np.repeat(np.arange(len(filtered_playlists)), playlist_sizes)\n",
"\n",
"# hold out 33% of the tracks for testing; the fixed random_state keeps the split repeatable\n",
"data_train, data_test, labels_train, labels_test = train_test_split(\n",
"    dataset, labels, test_size=0.33, random_state=70\n",
")"
]
},
{
"source": [
"# SVM\n",
"Support Vector Machine"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"73.47% Accurate\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": "<Figure size 720x480 with 1 Axes>",
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<!-- Created with matplotlib (https://matplotlib.org/) -->\n<svg height=\"246.958125pt\" version=\"1.1\" viewBox=\"0 0 238.239215 246.958125\" width=\"238.239215pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n <metadata>\n <rdf:RDF xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n <cc:Work>\n <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n <dc:date>2021-02-04T13:23:58.773655</dc:date>\n <dc:format>image/svg+xml</dc:format>\n <dc:creator>\n <cc:Agent>\n <dc:title>Matplotlib v3.3.4, https://matplotlib.org/</dc:title>\n </cc:Agent>\n </dc:creator>\n </cc:Work>\n </rdf:RDF>\n </metadata>\n <defs>\n <style type=\"text/css\">*{stroke-linecap:butt;stroke-linejoin:round;}</style>\n </defs>\n <g id=\"figure_1\">\n <g id=\"patch_1\">\n <path d=\"M 0 246.958125 \nL 238.239215 246.958125 \nL 238.239215 0 \nL 0 0 \nz\n\" style=\"fill:none;\"/>\n </g>\n <g id=\"axes_1\">\n <g id=\"patch_2\">\n <path d=\"M 115.92 44.062125 \nC 99.075171 44.062125 82.587642 48.955686 68.470345 58.145363 \nC 54.353047 67.335039 43.204871 80.430983 36.386877 95.834333 \nC 29.568883 111.237684 27.370307 128.294994 30.05957 144.923768 \nC 32.748832 161.552542 40.211847 177.047346 51.537469 189.516449 \nC 62.863092 201.985552 77.570861 210.899983 93.865132 215.171397 \nC 110.159404 219.44281 127.348934 218.890002 143.335086 213.580463 \nC 159.321238 208.270925 173.425839 198.4299 183.926885 185.25886 \nC 194.427931 172.087821 200.879942 156.145515 202.495204 139.37831 \nL 115.92 131.038125 \nL 115.92 44.062125 \nz\n\" style=\"fill:#008000;\"/>\n </g>\n <g id=\"patch_3\">\n <path d=\"M 202.495204 139.37831 \nC 203.657854 127.309428 202.284095 115.129721 
198.462337 103.623035 \nC 194.640579 92.116349 188.454605 81.534941 180.302525 72.559795 \nC 172.150445 63.58465 162.210972 56.412525 151.123781 51.504997 \nC 140.036589 46.597469 128.044737 44.062123 115.919984 44.062125 \nL 115.92 131.038125 \nL 202.495204 139.37831 \nz\n\" style=\"fill:#ff0000;\"/>\n </g>\n <g id=\"matplotlib.axis_1\"/>\n <g id=\"matplotlib.axis_2\"/>\n <g id=\"text_1\">\n <!-- Correct -->\n <g transform=\"translate(8.603904 198.123656)scale(0.1 -0.1)\">\n <defs>\n <path d=\"M 64.40625 67.28125 \nL 64.40625 56.890625 \nQ 59.421875 61.53125 53.78125 63.8125 \nQ 48.140625 66.109375 41.796875 66.109375 \nQ 29.296875 66.109375 22.65625 58.46875 \nQ 16.015625 50.828125 16.015625 36.375 \nQ 16.015625 21.96875 22.65625 14.328125 \nQ 29.296875 6.6875 41.796875 6.6875 \nQ 48.140625 6.6875 53.78125 8.984375 \nQ 59.421875 11.28125 64.40625 15.921875 \nL 64.40625 5.609375 \nQ 59.234375 2.09375 53.4375 0.328125 \nQ 47.65625 -1.421875 41.21875 -1.421875 \nQ 24.65625 -1.421875 15.125 8.703125 \nQ 5.609375 18.84375 5.609375 36.375 \nQ 5.609375 53.953125 15.125 64.078125 \nQ 24.65625 74.21875 41.21875 74.21875 \nQ 47.75 74.21875 53.53125 72.484375 \nQ 59.328125 70.75 64.40625 67.28125 \nz\n\" id=\"DejaVuSans-67\"/>\n <path d=\"M 30.609375 48.390625 \nQ 23.390625 48.390625 19.1875 42.75 \nQ 14.984375 37.109375 14.984375 27.296875 \nQ 14.984375 17.484375 19.15625 11.84375 \nQ 23.34375 6.203125 30.609375 6.203125 \nQ 37.796875 6.203125 41.984375 11.859375 \nQ 46.1875 17.53125 46.1875 27.296875 \nQ 46.1875 37.015625 41.984375 42.703125 \nQ 37.796875 48.390625 30.609375 48.390625 \nz\nM 30.609375 56 \nQ 42.328125 56 49.015625 48.375 \nQ 55.71875 40.765625 55.71875 27.296875 \nQ 55.71875 13.875 49.015625 6.21875 \nQ 42.328125 -1.421875 30.609375 -1.421875 \nQ 18.84375 -1.421875 12.171875 6.21875 \nQ 5.515625 13.875 5.515625 27.296875 \nQ 5.515625 40.765625 12.171875 48.375 \nQ 18.84375 56 30.609375 56 \nz\n\" id=\"DejaVuSans-111\"/>\n <path d=\"M 41.109375 46.296875 
\nQ 39.59375
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYwAAAGbCAYAAADX6qdpAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAABJ0AAASdAHeZh94AAAtuUlEQVR4nO3debxcdX3G8c+THbISwhqWQBKqgoqCsogaUMAFq62oKCpQWdRaoK61FYcRS21RUKsii4BaF1woVVS0KKBFFkFQQUBZghBkzU42knz7x+9ELpPJzbm5985v5pzn/XrNa5KZMzPP3JvMM+f3O4siAjMzs40ZkTuAmZn1BheGmZmV4sIwM7NSXBhmZlaKC8PMzEpxYZiZWSkuDDMzK8WFYWZmpbgwzMysFBeGmZmV4sIwM7NSXBgVIGmGpJB0Ue4srSSdWmSb0+a+N0u6WdKSYplPF7fPlTS3s0nNbGNcGF1K0jMk/aekWyUtkrRK0oOSfiDpHZLG5s44GJL2A74GTATOBprA5VlDbYCk84pCWyZpSu48ZrmMyh3A1ifpo0CDVOjXAl8GlgLbAHOA84F3AXtnijgQnwO+Cfyp5fZXAwLeHhG/bLnvZZ0IVoakicARQACbAW8lvSez2nFhdBlJ/0z6tn0/8IaIuL7NMocB7+t0tk0REY8Bj7W5a/vi+sE2j7l7WEMNzFuACcCZwHuA43BhWE15SKqLSJoBnAo8CbyqXVkARMRlwCtKPN9ukj4h6UZJj0paKek+SedK2qHN8pJ0lKRfFsuvkHS/pB9LelPLss+R9I1ivmFlsfyvJX1a0ug+yz1tDkPS0ZICOKZY5N7i/ijef79zGMW8x5WSFhb5bpf0kXZDdMVzXiVpW0nnS5onaY2kozf2s+vjOGAt8Gng+8BzJO2zoYWLocQL+vxcHpH0C0nv2pRlNzY/Vby/aLltTvGYUyW9sBjGnN/yMz6w+Hfwe0mLJS0vhj8bksZt4LVGSnqnpGuKYdLlku4qfrazi2X+rXidozbwHHsV91+2oZ+hdS+vYXSXY4DRwDcj4tb+FoyIlSWe72+BdwJXAr8EVgG7A8cCr5G0d0TM67P8vwIfBu4FvgUsArYDXgC8AbgYUlkA15OGab5XLD8JmAW8G/gIqfTauYW0BvU64LnAZ4CFxX0L2z1gHUkXkH5GDwDfLZbfFzgNeJmkgyNidcvDpgLXkYb0LiF9+D/c3+v0eb3nAXsB/xsR9xcf2q8Hjie9/9blXw18GxhLmo/5BjCleJ8fJM3VDHjZQdiP9Pv8P+ACYBrp3wDAh4BnkP5d/AAYB7yI9IVljqSXR8SaPnnHAJcBB5PWfr8OLAZmAH9TvMYfgXOK/MeThlJbnVBcf3EI3p91WkT40iUX4KekD+FjB/i4GcXjLmq5fTowts3yhwBrgLNbbn+c9GG8eZvHTOvz508Vr/faNsttAYzo8/dTi2XntCx3UXH7jDbPMReY23Lb0cXylwCbtdy37jVOark9istXgFGb8Pv4YvH4Nxd/HwX8mVQ+k1p/PqSCXQW8tM1z7bCJy7b93fa5/6r03/hpt83p895P2MDjdgXU5vbTise9qeX203nqC8LYlvvGAlv1+ftlxbJ7tCw3EVhCms8aOVT/b3zp3MVDUt1lu+L6gaF4soiYF23WRCLiJ8BtwKFtHvYkqUxaH9NuHmJ5m+UWRMTaTYi7MScBq4G/i4jW1z2NVHZHtnncKuD9sf6aR78kjSfNXywC/hugeI6vAePbvNZRpLWssyPi6tbni4gHNnHZwbglIs5pd0dE3BPFp3iLs4rrv/zbkDSStOa4HHhn67+piFgZEY/2uWnd2tEJPN26+aDzo8/ai/UOD0lVmCSRPtiOJg11bAGM7LPIqpaHfA34B+D3kr4FXA1cGxGLWpa7mPQBfqmk7wBXANfEME1WS9q8yP8YcHJ6W+tZCTyzze1zI+KRTXjZI0jfiM+JiBV9br+ItMHBcT
x92Gjf4vpHJZ57IMsOxg0buqMoxJNIw0m7kd5r3x/s9D5/fgYwGbg+ItbbSKGNH5GGKd8m6UMRsay4/XhS6Z9f+h1YV3FhdJc/kz70pm9swZLOBE4unvfHwDyeWis4Gti5Zfl/BO4hzRP8U3FZLemHwPsi4i6AiLhB0ouBfwEOB94GIOlOoBkR3xii/OtsQfow24q0ufFAPLSJr3l8cX1R3xsj4lZJNwF7FXNANxZ3TSmu+84JbchAlh2Mtu+92CjhZ8ALgVtJXwAe5al5pwZpmGmdKcV1qbwRsVbSOcAngDcBF0raC3g+cGnJ0rEu5MLoLv8HHETaD+FLg3kiSVsDJ5I+EPaPiCUt97+59THFMMGngU8Xjz+A9E37DcDuknZfNxwREdcChxVbJ+1F2mrrH4CvS3o0Iq4YTP4W69Zwbo6I5w/wse2GXfpVTOq/sPjrtRtYo4FUKusKY2FxPR343UZeYiDLrhve29D/1Sn9PHZD7/21pPd3UUQc0/cOSduxfikvLK4H8kXmAtLGDScAF/LU8FTbITLrDZ7D6C4Xkr7lvV7Ss/pbsN1mpC12Jf1+f9KmLHYo7t+giHgkIi6JiDeSvo3OBPZos9zKiPhlRHyUVFCQPpCGTEQsJc257C5p6lA+9wasW7u4ilTc7S7LgTdLmlAse11x/coSzz+QZRcU1zu23iFpEmk4aaBmFdeXtLnvpW1uu4NUGs+RtH2b+9dTzGl8B9hH0ouAN5OGqX4y4LTWNVwYXSQi5pK2+BkD/EBS2z25Jb2CjY9/zy2uDygmLdc9dgJwHi3fWCWNLf5jt77WaNKmqQDLitv2l7RZm9fcpu9yQ+xM0s/lArU5PIekLSQNdO1jPcX7OpI08X9kRBzb7kLarHcC6YMQ0iaki4F3SXpJm+ftu99L6WWLsr8DeFHfLxHF7/RM0t7nAzW3uJ7T8rq7Av/eunCx5vmF4rW+2PplRdIYSVu1eZ11czwXk35W5w3TBhHWIR6S6jIRcbqkUaRhgV9J+iVp2GPdoUFeAszmqaGQDT3PQ5K+SRpSukXST0gTlwcDK0j7Q+zZ5yGbAf8n6S7gJuA+0rb5B5PmVb4XEbcXy34QOEjSL0jfGpeS9u94Jekb8bmD+BFs6P1cUIyDvxu4W9KPSZtnTgV2If1cLiTtdzIYbyIN83x/I2Pt55MOE3I86YPwMUlvIX2rvlLSj4DfkraGeg5pDWGX4r2UXrZwBmmt5hpJ3yb9/g4k7bPzG9IGAQPxfeAu4L2Sng3cDOwEHEbaJ2OnNo9pAvsArwH+UOx4t6TIegjwAdaf77lG0rp8T5KGqayX5d6u15f2F9KH9H+S5iAWk7Zo+jNpzeId9NkWng3vh7E5aWe8u0gfMvcDnwe2pGX7fdKHzweL5/9TsfyjpOGTdwJj+ix7COnD+fek+YUngDuBzwI7t2Q4lSHYD6PPfYeRtvN/pPiZPETaGujjwDNalg3gqgH+3K8pHvfXJZa9s1h2zz637U7a72Neke9h0tZmx7d5/ECWfQdpWG5l8Z7Pafd7LJadU+Q6tZ/sO5K2ilu3IcRtxe9/1IZ+bsV97yl+3kuL3/sfSV8QZm3gdU4qnu/buf9P+TL4i4pfqpnZkCv2jj8KeHlE/DRzHBskF4aZDQtJO5LWQO4Bdg9/2PQ8z2GY2ZAq5md2I82fjQVOcVlUg9cwzGxISbqKtBHC/cBZEfHprIFsyLgwzMysFO+HYWZmpbgwzMysFBeGmZmV4sIwM7NSXBhmZlaKC8PMzEpxYZiZWSkuDDMzK8WFYWZmpbgwzMysFBeGmZmV4sIwM7NSXBhmZlaKC8PMzEpxYZiZWSkuDDMzK8WFYWZmpbgwzMysFBeGmZmV4sIwM7NSXBhmZlaKC8PMzEpxYZiZWSkuDDMzK8WFYWZmpbgwzMysFBeGmZmV4sIwM7NSXBhmZlaKC8PMzEpxYZiZWSkuDDMzK8WFYWZmpbgwzMysFBeGmZmV4sIwM7NSXB
hmZlaKC8PMzEpxYZiZWSkuDDMzK2VU7gBmnaCmBEwBNgdG9r385gvwnEcAWAusAVYAS4ClRESOvGbdyIVhPUtNTQB
},
"metadata": {}
}
],
"source": [
"### TRAIN ###\n",
"clf = svm.SVC()\n",
"clf.fit(data_train, labels_train)\n",
"\n",
"### TEST ###\n",
"pred_test_labels = clf.predict(data_test)\n",
"\n",
"### EVALUATE ###\n",
"# fraction of held-out tracks whose predicted playlist index equals the true one\n",
"correct = np.mean(pred_test_labels == labels_test)\n",
"print(f'{correct*100:.2f}% Accurate')\n",
"\n",
"fig, ax = plt.subplots()\n",
"ax.pie([correct, 1 - correct], labels=[\"Correct\", \"Incorrect\"], colors=['g', 'r'], startangle=90)\n",
"ax.set_title(\"Classifier Accuracy\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": "<Figure size 720x480 with 2 Axes>",
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<!-- Created with matplotlib (https://matplotlib.org/) -->\n<svg height=\"277.314375pt\" version=\"1.1\" viewBox=\"0 0 335.725437 277.314375\" width=\"335.725437pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n <metadata>\n <rdf:RDF xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n <cc:Work>\n <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n <dc:date>2021-02-04T13:23:59.434830</dc:date>\n <dc:format>image/svg+xml</dc:format>\n <dc:creator>\n <cc:Agent>\n <dc:title>Matplotlib v3.3.4, https://matplotlib.org/</dc:title>\n </cc:Agent>\n </dc:creator>\n </cc:Work>\n </rdf:RDF>\n </metadata>\n <defs>\n <style type=\"text/css\">*{stroke-linecap:butt;stroke-linejoin:round;}</style>\n </defs>\n <g id=\"figure_1\">\n <g id=\"patch_1\">\n <path d=\"M 0 277.314375 \nL 335.725437 277.314375 \nL 335.725437 0 \nL 0 0 \nz\n\" style=\"fill:none;\"/>\n </g>\n <g id=\"axes_1\">\n <g id=\"patch_2\">\n <path d=\"M 60.570312 239.758125 \nL 278.010312 239.758125 \nL 278.010312 22.318125 \nL 60.570312 22.318125 \nz\n\" style=\"fill:#ffffff;\"/>\n </g>\n <g clip-path=\"url(#pca4b808932)\">\n <image height=\"217.8\" id=\"imageb83a10a26a\" transform=\"scale(1 -1)translate(0 -217.8)\" width=\"217.8\" x=\"60.570312\" 
xlink:href=\"data:image/png;base64,\niVBORw0KGgoAAAANSUhEUgAAAWsAAAFrCAYAAAAXRqh4AAAGCUlEQVR4nO3WoWoWYBiGYTf/waYGizgMFgWbDBYMGgU12D0Dww5CjBaD1WKymS07AmHR6jBMMQji0pz/5iFY9Pu42XUdwQMv3Lwrt3denp47Q35dnr1gvKsfjmZPGGptd2/2hOGO72/PnjDUWbzx6uwBAPydWAMEiDVAgFgDBIg1QIBYAwSINUCAWAMEiDVAgFgDBIg1QIBYAwSINUCAWAMEiDVAgFgDBIg1QIBYAwSINUCAWAMEiDVAgFgDBIg1QIBYAwSINUCAWAMEiDVAgFgDBIg1QIBYAwSINUCAWAMEiDVAgFgDBIg1QIBYAwSINUCAWAMEiDVAgFgDBIg1QIBYAwSINUCAWAMEiDVAgFgDBIg1QIBYAwSINUCAWAMEiDVAgFgDBIg1QIBYAwSINUCAWAMEiDVAgFgDBIg1QIBYAwQsrr0/mL1hqJ9bm7MnDLfx8Wzd+GR9ffaE4Taff5o9Yajvu7MXjOezBggQa4AAsQYIEGuAALEGCBBrgACxBggQa4AAsQYIEGuAALEGCBBrgACxBggQa4AAsQYIEGuAALEGCBBrgACxBggQa4AAsQYIEGuAALEGCBBrgACxBggQa4AAsQYIEGuAALEGCBBrgACxBggQa4AAsQYIEGuAALEGCBBrgACxBggQa4AAsQYIEGuAALEGCBBrgACxBggQa4AAsQYIEGuAALEGCBBrgACxBggQa4AAsQYIEGuAALEGCBBrgACxBggQa4AAsQYIEGuAALEGCFj83v88e8NQly5uzJ4w3OGd67MnDHXv2ZfZE4bb2z6ePYH/zGcNECDWAAFiDRAg1gABYg0QINYAAWINECDWAAFiDRAg1gABYg0QINYAAWINECDWAAFiDRAg1gABYg0QINYAAWINECDWAAFiDRAg1gABYg0QINYAAWINECDWAAFiDRAg1gABYg0QINYAAWINECDWAAFiDRAg1gABYg0QINYAAWINECDWAAFiDRAg1gABYg0QINYAAWINECDWAAFiDRAg1gABYg0QINYAAWINECDWAAFiDRAg1gABYg0QINYAAWINECDWAAFiDRAg1gABYg0QINYAAYuVxWL2hqGOr1yYPWG4py/ezZ4w1NsHd2dPGG5l7dvsCUOdHi1nTxjOZw0QINYAAWINECDWAAFiDRAg1gABYg0QINYAAWINECDWAAFiDRAg1gABYg0QINYAAWINECDWAAFiDRAg1gABYg0QINYAAWINECDWAAFiDRAg1gABYg0QINYAAWINECDWAAFiDRAg1gABYg0QINYAAWINECDWAAFiDRAg1gABYg0QINYAAWINECDWAAFiDRAg1gABYg0QINYAAWINECDWAAFiDRAg1gABYg0QINYAAWINECDWAAFiDRAg1gABYg0QINYAAWINECDWAAFiDRCwWL11Y/aGoV6/eTV7wnA7W49nTxhr+WP2AvjnfNYAAWINECDWAAFiDRAg1gABYg0QINYAAWINECDWAAFiDRAg1gABYg0QINYAAWINECDWAAFiDRAg1gABYg0QINYAAWINECDWAAFiDRAg1gABYg0QINYAAWINECDWAAFiDRAg1gABYg0QINYAAWINECDWAAFiDRAg1gABYg0QINYAAWINECDWAAFiDRAg1gABYg0QINYAAWINECDWAAFiDRAg1gABYg0QINYAAWINECDWAAFiDRAg1gABYg0QINYAAWINECDWAAFiDRAg1gABYg0QsLL8evN09oiRHj18MnvCePsHsxcMdXJ4OHvCeKvnZy8Y62Q5e8FwPmuAALEGCBBrgACxBggQa4AAsQYIEGuAALEGCBBrgACxBggQa4AAsQYIEGuAALEGCBBrgACxBggQa4AAsQYIEGuAALEGCBBrgACxBggQa4AAsQYIEGuAALEGCBBrgACxBgg
Qa4AAsQYIEGuAALEGCBBrgACxBggQa4AAsQYIEGuAALEGCBBrgACxBggQa4AAsQYIEGuAALEGCBBrgACxBggQa4AAsQYIEGuAALEGCBBrgACxBggQa4AAsQYIEGuAALEGCBBrgACxBggQa4CAP3+8NGg/ulWcAAAAAElFTkSuQmCC\" y=\"-21.958125\"/>\n </g>\n <g id=\"matplotlib.axis_1\">\n <g id=\"xtick_1\">\n <g id=\"line2d_1\">\n <defs>\n <path d=\"M 0 0 \nL 0 3.5 \n\" id=\"mebdb0f4e4d\" style=\"stroke:#000000;stroke-width:0.8;\"/>\n </defs>\n <g>\n <use style=\"stroke:#000000;stroke-width:0.
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAjEAAAHRCAYAAACbw+jrAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAABJ0AAASdAHeZh94AAC6L0lEQVR4nOzdd3wU1drA8d9J2fReaKH3XlRABZVmAURBVCyA3au+drFcvdZrw36vV0UsiF2xK4gVAaVY6B1CCwmk97bZPe8fs5tksyVl0zY+Xz/7iZyZM3Nmy8wzp43SWiOEEEII4Wv8WroAQgghhBANIUGMEEIIIXySBDFCCCGE8EkSxAghhBDCJ0kQI4QQQgifJEGMEEIIIXySBDFCCCGE8EkSxAghhBDCJ0kQI4QQQgifJEGMEEIIIXySBDFCCCGE8EkSxAghaqWUukgptUEpVaCU0kqp55thnyuUUvJwt0aglDqglDrQ0uUQorFJECN8mlLKXyl1tVLqF6VUtlLKrJRKV0ptVkq9ppSaZlvPTyl1yHYBHlDLNkOVUrlKqXKlVKItbZEtr1ZKPewh79xq661owPF0Vko9oZT6UymVU+14flBK3ayUiqrvNr2llDoReBeIAF4GHgK+be5ytAb2wMr2usLDeg9UW2+Rl/vs1hjbEaItCmjpAgjRUEopf+Br4EwgF/gGSAFMwEDgYqAf8KXW2qqUegN4ALgKuM3Dps8HooAlWuv0GssqgMuVUg9prS0u8l5tW6fevy2l1FXAi0AQsAl4H8gB4oAxwPPAv4D4+m7bS1MABczRWv/WjPudA4Q24/7qowLje/RGzQVKKT/gChr4PWgiE1q6AEI0hdbyAxOiIS7CCGA2AadqrfOqL1RKhQKjqiW9DtwHzFZK3a21Lnez3atsf191sexr4Fzbfr+psb/+wMnAZ8D0+hyIUuoSYCFG0HKe1vobF+ucDPyvPtttJB1tf1Obc6da60PNub96+ho4Vyk1UGu9rcayM4AuNOB70FS01vtaugxCNAVpThK+7CTb30U1AxgArXWx1vrnav8+jNEMEo+bi4tSqh9GrUcy8IOLVd4FSjBqXGqyp71W1wOw7TMC+I/tn7NcBTAAWutfcQzK7PknKKW+tTWnlSmldtuapJyanqo1hwQopf6plNpjy3NYKfWkUspUbd3LbH1SLrcl7a/WRNLNto7bZrNqTXDdaqRPU0r9qJRKs+071dYceL2rsrrYrp9S6h9Kqd+VUoVKqSLb/19nqwWpub62bSteKfVqtf1uU0pdXnP9OrJ/xu6+B8UY3xUnSqmOSqn7lVK/KqWO2potU5VS79Vs6lRKPQjst/2zelOlVkpdZlvnNNu/H1RKjVRKfWP7LlT/nBz6xCilYmxpZUqp42rs008p9bMt/+z6vjFCNCepiRG+LMv2t0898izEaB65CvjQxXJ7LczrWmtXnUpzgY+Bi5VS7bXWRwGUUkEYzR+/ALvrUR6AmUAssFZr/Z2nFbXWZdX/rZS6FqOfSpGtXOnAacBdwNlKqZO11rkuNvUeMBZYBuQDk4E7gUSqgpaNGP1fzgWGAi9gHD/V/taLUuoaYAFwFPgKyLTtc4htvy/VYTNvYzQVHsYIJjRGUPoSRgB6iYs80cCvQDmwBKPJ7nzgDaWUVWv9Vj0PZRewErhUKXWX/XNRSrUHzsYIYJwCa5tTgLuBn4FPgEKgN8b3YJrtM9tkW3eFrew3Y9Q4fl5tOxtrbPdE4B5gNUYzV7zteJ1orXOUUhfZjuFDpdRwrXWBbfEDGN+hRVrrtz28B0K0PK21vOTlky9gOMZJ2opxYZsBdK0ljz9wxJane41lJowgwAy0r7FsEcbFciLGhVID91RbPsuWdgnQy/b/K+p4HK/b1v93PY+/K1CGEYT0q7HsJds2X62RvsKW/icQWy09DNgLWDwcezcXZXB7nK7y2fZbBiS6WD/eVVlrpF1k2+ZfQHiN8v9hW3axiz
JqjIDHv1r6AIx+K9vr8Z7b379ewKW2/7+o2vK7bWkn274rGiMYqL6NRCDCxbaHYgQ0y2qkd3O1nWrLT6t2jNe6WecAcMBF+p22fO/b/j3O9h3YDoR68/uUl7ya4yXNScJnaa03YFxIjtn+fgIcUEplKaU+U0qd7SKPBeMuVQFX1lh8DpAAfKVtNSxu9rsa2AlcpZRStuSrMfqzfNKAQ+lg+5tSz3yXYgReL2qtd9ZYdi9QgNH/J8hF3ru01tn2f2itizBqD/yA4+tZjvqqwAgUHWitM+uQ1z4i6G6tdWG1vEUYtU9QVZtWXTFwm67WGVtrvR2jdqa/Uiq8jmWvbgnGZ341gO27cBWwQxtNfy5prdN1Va1H9fRNwE/AOKVUYAPKs1FrvaCeeZ7CaGKdpZS6B+M7UAZcqLUubkAZhGhWEsQIn6a1/gijE+UZwCMYHS79MJpAvlRKvVUt0LB7DaMm5nLbCCc7e/+GhXXY9UKgBzBeKdUL4w72ba11aUOPpQFG2P7+VHOB1joH2AAEY4zQqukPF2mHbX9jGqV0rr2LMeJou1LqOaXUuUqphHrkH4Hx2a1wsewXjFqE4S6W7dFa57tIb/Ax2z7rd4DTbN+B8UBP6vD9UUpNUUp9ZeufY7b3c8FoigqiYSPQ1tc3g9ZaYzSDpgKPYQTUt2ittzRg/0I0OwlihM/TWpu11t9pre/XWp+NcQG4EKOfyByMGpbq6x8EvscYdTMZjLk4MKr/DwLL67DbxRh3rFfZXoq6BT+upNn+dqpnPnvH3TQ3y+3p0TUXaNf9ZCpsf/1dLGsUWutngbkY7/NNGCN4jtk6ktalBigKyNYuRpZprSsw+ti4mksn1832vD3mhVTV6l2N8Z1Y7CmDUupmjGD7ZIw+Kc8DD2P0P7L3hXFVe1Ybt7WHnmitM2zlAKOfmfSDET5DghjR5mitLbYamudsSeNdrGYfPm2vfbkS42L0utbaWod9ZFI1hPYKYI3WemsDi7za9re+c3nYO462d7O8Q431moLG/QCBaJcZtF6stR6NMf/NFIw+QacAy+tQK5MHxLpqblFKBWAEsK5qXJqErcZiLcb3ZzrwqdY6y936tjI+iBFwDNRaX6i1nqe1fkBr/SBG02iDi9OQTEqpWRh9ujIxPpP/eM4hROshQYxoy+z9Dmo2JwF8iXEhmayU6owxMsbeX6auFmLcMSfQ8FoYMPpWZAMnKqUmelqxRv+WDba/p7lYLxoYBpQCO7woW21ygM4u9u9v279bWutcrfVSrfXVGJ2AYzGCGU82YJy3XK13CkaNyl+1lrpxLcT4Dpio/XsQjxHc/aa1dqhBs/XLGeEij70fT6PXkNmawV4FMjCa4VZi9PWa1dj7EqIpSBAjfJYynuczyc3cIO2pqmVZWXO5relhEcaF4V2MppylWusj9SjCzxhNVdOBD+pXeoeyFGA0rYAx3PUMV+sppUYDa6olvYPRQfZG28WoukeASOAdXWNYdiNbD3RRSp1eI/0+jNFTDpRS41z0UQJjxA4YHXA9sQeZjytjMkP7dkOBJ2z/fL3WUjeuDzC+A+fguq9OdekYx3hc9c7EtpqlF3DdFyYHo5alS2MUtto+TRhlDwfmaq1TMIauZwELlFI9G3N/QjQFmSdG+LJRGPNnHFVKraZqUrDuGM0UIcAXGDUdrizEGNEy1vZvVzP0umXrFPllPcvsblvvKqVCMB478K1SaiPwG1WPHTgRYwhuZrU8B5RSt2DM4vuXUuojjDvqU23r76RqxE5TeRqjU/UXSqkPMWqUTsL4DFbgXEv0GVColFqLMexXYbz/J2AMv3Y1wWAlrfV7SqlzgAuAbUqpzzEu8Ofa9vmh1trlJHNNxTaK5/M6rmtVSv0HYyj2FqXUFxg1OOMwaqJ+tv1/9TyFSql1wFil1LsY8xBZMB6nsdmLos8HjgOe1Vovs+3riG0Sva8wAuqTXPU/EqK1kJoY4cueAf4Po0/CEOAfwC0Y87isAG
YDM2zBhhOtdTLwo+2fKRgTv7UYrfVrGJOezadqzpm7MCZBywFuxRj9Uj3PSxhBxFrgPIxnQiViDJ09sfow6iYq848
},
"metadata": {
"needs_background": "light"
}
}
],
"source": [
"# confusion matrix on the held-out set, normalised over the true labels\n",
"disp = plot_confusion_matrix(clf, data_test, labels_test, display_labels=playlist_names, normalize='true')\n",
"disp.ax_.set_title('SVM Confusion Matrix')\n",
"plt.show()"
]
},
{
"source": [
"From the above it is clear that the Pop playlist was not effective for classifying similar tracks. This is likely because of the overlap with other genres such as Rap and EDM, where much of the confusion occurred. Also not helping is that one of the sub-playlists, electropop, is shared across EDM and Pop. As EDM is already a much larger playlist, it is unsurprising that this performance was poor. Additionally, the overlap with Rock is understandable as Pop contains an Indie Pop sub-playlist which could have caused some confusion. Quite surprising was the confusion with Jazz, as I wouldn't have thought there would be much overlap here.\n",
"\n",
"The other major confusion was between Rock and Metal, specifically classing Metal tracks as Rock. This could be expected due to the similarity in tone."
],
"cell_type": "markdown",
"metadata": {}
},
{
"source": [
"## Other Tests\n",
"\n",
"Take a handful of other tracks which I don't listen to and aren't in any playlists to see if they can also be classified"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Top Of The World (Five Finger Death Punch) could be ROCK ✓\nAston Martin Music (Rick Ross) could be RAP ✓\nOn The Sunny Side Of The Street (Dizzy Gillespie) could be JAZZ ✓\nVibez (ZAYN) could be EDM ✗\nShot In The Dark (AC/DC) could be ROCK ✓\nTo Hell and Back (Sabaton) could be ROCK ✗\nWithstand The Fall Of Time (Immortal) could be METAL ✓\nAlone Together - Rudy Van Gelder Remaster (Kenny Dorham) could be JAZZ ✓\nFeel No Ways (Drake) could be RAP ✓\nBO$$ (Fifth Harmony) could be EDM ✗\n\n70.00% Accurate\n"
]
}
],
"source": [
"### PREPARE ###\n",
"test_uris = [\"spotify:track:53yqxU2EKKzbuQZEUEVtxc\",\n",
" \"spotify:track:5W7xC99N2Zzfh69r7I7zWK\",\n",
" \"spotify:track:38R2EViAkYOFG8ZkG3GLtW\",\n",
" \"spotify:track:6T6D9CIrHkALcHPafDFA6L\",\n",
" \"spotify:track:0sfdiwck2xr4PteGOdyOfz\",\n",
" \"spotify:track:1BrgjqSg9du0lj3TUMLluL\",\n",
" \"spotify:track:5nCnSnLtotQ8eB4E189U91\",\n",
" \"spotify:track:3GOZbK2epuHzCt5YvvVFHO\",\n",
" \"spotify:track:3cjF2OFRmip8spwZYQRKxP\",\n",
" \"spotify:track:1COvXs6jaykXC73h9OSBVM\"]\n",
"# inferring what playlists these would go in\n",
"test_labels = [\"ROCK\", \"RAP\", \"JAZZ\", \"POP\", \"ROCK\", \"METAL\", \"METAL\", \"JAZZ\", \"RAP\", \"POP\"] \n",
"# the expected labels are matched to the uris by position, so the two lists must line up\n",
"assert len(test_labels) == len(test_uris), 'one expected label per test uri'\n",
"\n",
"test_tracks = spotnet.tracks(uris=test_uris)\n",
"spotnet.populate_track_audio_features(tracks=test_tracks)\n",
"\n",
"test_features = [ {j: k for j, k in i.audio_features.to_dict().items() \n",
" if j in headers} \n",
" for i in test_tracks] # filter down to descriptor columns\n",
"\n",
"### PREDICT ###\n",
"# The classifier was fit on frames whose columns were selected as .loc[:, headers].\n",
"# The dicts above keep whatever key order to_dict() returns, so reorder the columns\n",
"# to match training; this also raises if a descriptor column is missing.\n",
"predictable_frame = pd.DataFrame(test_features).loc[:, headers]\n",
"\n",
"predicted_labels = clf.predict(predictable_frame)\n",
"# predicted labels are indices into playlist_names; compare by name with the expected label\n",
"labels_correct = [i == playlist_names[predicted_labels[idx]] for idx, i in enumerate(test_labels)]\n",
"\n",
"### EVALUATE ###\n",
"for track, label, correct in zip(test_tracks, predicted_labels, labels_correct):\n",
" print(f'{track.name} ({track.artists[0].name}) could be {playlist_names[label]} {\"✓\" if correct else \"✗\"}')\n",
"\n",
"correct = sum(labels_correct) / len(labels_correct)\n",
"print(f'\\n{correct*100:.2f}% Accurate')"
]
},
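{
"source": [
"# Imports & Setup\n",
"\n",
"Run this section and *Read Scrobble Frame* below before any of the cells above: they rely on the imports and `spotnet` defined here and on the `scrobbles` frame loaded there."
],
"cell_type": "markdown",
"metadata": {}
},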
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"from datetime import datetime\n",
"\n",
"from google.cloud import bigquery\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib as mpl\n",
"mpl.rcParams['figure.dpi'] = 120\n",
"\n",
"from analysis.net import get_spotnet, get_playlist, track_frame\n",
"from analysis.query import *\n",
"from analysis import spotify_descriptor_headers, float_headers, days_since\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"from sklearn import svm\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.metrics import plot_confusion_matrix\n",
"\n",
"client = bigquery.Client()\n",
"spotnet = get_spotnet()\n",
"cache = 'query.csv'\n",
"first_day = datetime(year=2017, month=11, day=3)"
]
},
{
"source": [
"## Read Scrobble Frame"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"scrobbles = get_query(cache=cache)"
]
},
{
"source": [
"## Write Scrobble Frame"
],
"cell_type": "markdown",
"metadata": {}
},
{
"source": [
"scrobbles.reset_index().to_csv(cache, sep='\\t')"
],
"cell_type": "code",
"metadata": {},
"execution_count": 6,
"outputs": []
}
]
}