listening-analysis/playlist-classifier.ipynb

530 lines
458 KiB
Plaintext
Raw Normal View History

2021-02-04 13:34:25 +00:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
2021-02-04 13:34:25 +00:00
"source": [
"# Playlist Classifier\n",
"\n",
"Given a list of playlists, can unknown tracks be correctly classified?"
]
2021-02-04 13:34:25 +00:00
},
{
"cell_type": "code",
2021-03-10 20:01:13 +00:00
"execution_count": 5,
2021-02-04 13:34:25 +00:00
"metadata": {},
"outputs": [],
"source": [
"# Playlists to classify between. Exactly one playlist_names assignment should be active;\n",
"# the others are kept as commented-out alternatives.\n",
"# playlist_names = [\"RAP\", \"EDM\", \"ROCK\", \"METAL\", \"JAZZ\", \"POP\"] # super-genres (plain RAP playlist)\n",
"playlist_names = [\"ALL RAP\", \"EDM\", \"ROCK\", \"METAL\", \"JAZZ\", \"POP\"] # super-genres\n",
2021-02-13 11:53:15 +00:00
"# playlist_names = [\"RAP\", \"EDM\", \"ROCK\", \"METAL\", \"JAZZ\"] # super-genres without POP\n",
2021-02-04 13:34:25 +00:00
"# playlist_names = [\"DNB\", \"HOUSE\", \"TECHNO\", \"GARAGE\", \"DUBSTEP\", \"BASS\"] # EDM playlists\n",
"# playlist_names = [\"20s rap\", \"10s rap\", \"00s rap\", \"90s rap\", \"80s rap\"] # rap decades\n",
"# playlist_names = [\"UK RAP\", \"US RAP\"] # UK/US split\n",
"# playlist_names = [\"uk rap\", \"grime\", \"drill\", \"afro bash\"] # british rap playlists\n",
"# playlist_names = [\"20s rap\", \"10s rap\", \"00s rap\", \"90s rap\", \"80s rap\", \"trap\", \"gangsta rap\", \"industrial rap\", \"weird rap\", \"jazz rap\", \"boom bap\", \"trap metal\"] # american rap playlists\n",
"# playlist_names = [\"rock\", \"indie\", \"punk\", \"pop rock\", \"bluesy rock\", \"hard rock\", \"chilled rock\", \"emo\", \"pop punk\", \"stoner rock/metal\", \"post-hardcore\", \"melodic hardcore\", \"art rock\", \"post-rock\", \"classic pop punk\", \"90s rock & grunge\", \"90s indie & britpop\", \"psych\"] # rock playlists\n",
"# playlist_names = [\"metal\", \"metalcore\", \"mathcore\", \"hardcore\", \"black metal\", \"death metal\", \"doom metal\", \"sludge metal\", \"classic metal\", \"industrial\", \"nu metal\", \"calm metal\", \"thrash metal\"] # metal playlists\n",
"\n",
"# Feature columns used for both training and prediction.\n",
"# headers = float_headers + [\"duration_ms\", \"mode\", \"loudness\", \"tempo\"]\n",
"headers = float_headers"
]
},
{
"cell_type": "markdown",
"metadata": {},
2021-02-04 13:34:25 +00:00
"source": [
"Pull and process playlist information.\n",
"\n",
"1. Get live playlist track information from spotify\n",
"2. Filter listening history for these tracks\n",
"\n",
"Filter out tracks without features and drop duplicates before taking only the descriptor parameters"
]
2021-02-04 13:34:25 +00:00
},
{
"cell_type": "code",
2021-03-10 20:01:13 +00:00
"execution_count": 6,
2021-02-04 13:34:25 +00:00
"metadata": {},
"outputs": [],
"source": [
"# 1) fetch every named playlist\n",
"playlists = [get_playlist(name, spotnet) for name in playlist_names]\n",
"\n",
"# 2) per playlist: join its tracks onto the scrobbles by track/artist name,\n",
"#    keep only rows that have a uri, make them distinct on uri,\n",
"#    then select only the descriptor columns listed in headers\n",
"filtered_playlists = [\n",
"    pd.merge(track_frame(plst.tracks), scrobbles, on=['track', 'artist'])\n",
"    .loc[lambda frame: pd.notnull(frame[\"uri\"])]\n",
"    .drop_duplicates(['uri'])\n",
"    .loc[:, headers]\n",
"    for plst in playlists\n",
"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
2021-02-04 13:34:25 +00:00
"source": [
"Construct the dataset with associated labels before splitting into a train and test set."
]
2021-02-04 13:34:25 +00:00
},
{
"cell_type": "code",
2021-03-10 20:01:13 +00:00
"execution_count": 7,
2021-02-04 13:34:25 +00:00
"metadata": {},
"outputs": [],
"source": [
"# one integer class label per row: the index of the playlist the row came from\n",
"labels = np.concatenate([\n",
"    np.full(len(frame), class_idx)\n",
"    for class_idx, frame in enumerate(filtered_playlists)\n",
"])\n",
"dataset = pd.concat(filtered_playlists)\n",
"\n",
"# stratified split: train and test keep the same class proportions\n",
"(data_train, data_test,\n",
" labels_train, labels_test) = train_test_split(\n",
"    dataset, labels,\n",
"    test_size=0.25, random_state=70, stratify=labels,\n",
")"
2021-02-04 13:34:25 +00:00
]
},
{
"cell_type": "markdown",
"metadata": {},
2021-02-04 13:34:25 +00:00
"source": [
2021-02-23 18:22:32 +00:00
"# SVM Kernels\n",
2021-02-04 13:34:25 +00:00
"Support Vector Machine"
]
2021-02-04 13:34:25 +00:00
},
{
"cell_type": "code",
2021-03-10 20:01:13 +00:00
"execution_count": 12,
2021-02-04 13:34:25 +00:00
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>uw-rbf</th>\n",
" <th>w-rbf</th>\n",
" <th>linear</th>\n",
" <th>poly</th>\n",
" <th>sigmoid</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>accuracy %</th>\n",
2021-03-10 20:01:13 +00:00
" <td>75.19</td>\n",
" <td>72.42</td>\n",
" <td>68.26</td>\n",
" <td>71.54</td>\n",
" <td>32.49</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
2021-02-13 11:53:15 +00:00
"text/plain": [
" uw-rbf w-rbf linear poly sigmoid\n",
2021-03-10 20:01:13 +00:00
"accuracy % 75.19 72.42 68.26 71.54 32.49"
]
2021-02-04 13:34:25 +00:00
},
2021-03-10 20:01:13 +00:00
"execution_count": 12,
2021-02-13 11:53:15 +00:00
"metadata": {},
"output_type": "execute_result"
2021-02-04 13:34:25 +00:00
}
],
"source": [
"### TRAIN ###\n",
"# Fit one SVC per kernel on the training split. Only clf (uw-rbf) leaves the classes unweighted;\n",
"# every other model passes class_weight='balanced'. These model names are reused by later cells.\n",
2021-03-10 20:01:13 +00:00
"clf = svm.SVC(kernel='rbf')\n",
2021-02-04 13:34:25 +00:00
"clf.fit(data_train, labels_train)\n",
"\n",
"# gamma is derived from sig_max as 1/(2*sigma^2); sig_max and c_max come from the setup / grid-search cells\n",
2021-03-10 20:01:13 +00:00
"wclf = svm.SVC(kernel='rbf', gamma = 1/(2*(sig_max**2)), C=c_max, class_weight='balanced') # weight classes based on prevalence\n",
2021-02-13 11:53:15 +00:00
"wclf.fit(data_train, labels_train)\n",
"\n",
"lclf = svm.SVC(kernel='linear', class_weight='balanced')\n",
"lclf.fit(data_train, labels_train)\n",
"\n",
"pclf = svm.SVC(kernel='poly', degree=3, class_weight='balanced')\n",
"pclf.fit(data_train, labels_train)\n",
"\n",
"sclf = svm.SVC(kernel='sigmoid', class_weight='balanced')\n",
"sclf.fit(data_train, labels_train)\n",
2021-02-04 13:34:25 +00:00
"\n",
"### EVALUATE ###\n",
"# score every model on the held-out test split; the scores are scaled by 100 below for display as a percentage\n",
2021-02-13 11:53:15 +00:00
"models = {'uw-rbf': clf, 'w-rbf': wclf, 'linear': lclf, 'poly': pclf, 'sigmoid': sclf}\n",
"accuracy = {i: j.score(data_test, labels_test) for i, j in models.items()}\n",
"\n",
"(pd.DataFrame(accuracy, index=['accuracy %']) * 100).round(decimals=2)"
]
},
{
"cell_type": "code",
2021-03-10 20:01:13 +00:00
"execution_count": 13,
2021-02-13 11:53:15 +00:00
"metadata": {},
"outputs": [
{
"data": {
2021-03-10 20:01:13 +00:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA8UAAADRCAYAAADlnRB8AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAABJ0AAASdAHeZh94AABOZElEQVR4nO3dd5gb1dnG4d+7uy64G2NjwI1ueu/FpiehJqGX4BQglBRISAiBiPlCTUIJCSGUgIGEltB7t+m992p6MS644rJ7vj/OiJVlaXe1K+mMNM99XbrWHs1Kz0rT3pkz55hzDhEREREREZE0aggdQERERERERCQUFcUiIiIiIiKSWiqKRUREREREJLVUFIuIiIiIiEhqqSgWERERERGR1FJRLCIiIiIiIqmlolhERERERERSS0WxiIiIiIiIpJaKYhEREREREUktFcUiIiIiIiKSWiqKRUREREREJLVSVxSb2Ulm5sxsbIm/t6GZ3WNmX8a//3xFAoqUgZZzSaJ4mZoQOodIvTGz8fH6NSp0FkknMxsXL4PjQmfpCjMbG/8dJ5XwO5065pJk6VRRnLPATGhjnlHxPJM6Gy4pzKwfcBuwMXA1EAH/LPE1Loo/jzlmNqD8KSUkM9sg/n6fKPL8fvHzzsyWL/D8Emb2dbx89Kh84oIZtZxLybLLdegcSaV1QrJy9gHZR3N8AvJ+M9s/dL6OMrMtcv6GQ0Pnkcows0YzO8TMJprZVDNbYGZfmNmLZnaxme0WOmM9MLNeZjY9Xp+uDJ0nzZpCB6gRGwNDgN87504t9ZfNrC+wL+CAJYADgb+XNaGE9hwwDdjAzPo552bkPb8d/vs3YFvgX3nPbwH0AO5xzs2rdNgitJxLpa0GzAkdolq0TkgRUfyzGzAa2B3Yxsw2dM4dEy5Wh2ULYRf/+8KAWaQCzKwRuBX4FjAdf8L8I6A7sAawP37ZvTnn124AHgc+rWbWCngSv6/6skrvtw/QH78+fc/MBjnnplTpvSVH6ppPd9Ky8c9POvn7+wN9gLOB+cAh5QglyeGcawEmAI3AmAKzbBs/PyX+d6HnAe6rQLyO0nIuFeWce90590HoHFWkdUIW45w7KX783jn3fWAn/AHxL5Pe/Dlu7bAX8BZwPf5E8HpBQ0kl7IcviF8ARjnnDnTOHeecO8Y5txOwFHBi7i84576Kt/FfBchbNs65OfHfUa2i+FCgBfgz/uLIwVV6X8lT1aLYzA6Lmwcckjf9hznNy3rkPfdE3Kx0iQ6+x6T40c/Mzor/vaDQvQFmdrCZPWdmc+MmIZeY2dCc50fFzQIviyddmtNkaFwJf/oh+AX+HOAWYG0z26SNv2F0nGWSmc2Lsz1kZod3Zt6cpuzji7zfhPzmj7n3VJjZxmZ2W9x85pt7lsxsGzO70MxeNbMZ8ef4spllzKxnkfdqNLOfmtkjZvZV/Dtvx01xVo7nOS1+n4IbBmttqnxrsc8wkGxBu0jRG39ey8fPTwS2KfC7JRXFWs61nNciK3DbjeXci2Vme5rZk+b3BVPN7GozW67Iay0Zf4avxZ/vV2Z2n5ntWGDe/mZ2rPlmqh+Z2Xwzm2xmN5vZZm1lNbOh8ff2sfnmruNK+JO1TrS+rtaJIpxz9wGv41sSbZSdHv+918Xf7Twze9/M/mFmy7T3mvHy4czsgTbmecn8fqPd18txIL7Vw/j4Aa1Xjgu9Ry8z+62ZPW1mM81sVrzOnmtmS3dm3kLLbc5zBe9rtXb2mWa2rJn9IV4+P4u3EZ+Y2ZVmtnobf9/GZnZNvH2YZ2afmtndZrZ3/HylvodK2zz+Ob5QkRsXjov8TcU++/i5neLPdna8Pbkx/mwWuxc+d7tlZiua2f/MbEq8TNxtZmvG8w2Ot0Ofmq8TnjKzQsdX2X3AaWb2RjzvNDO7y8y2LzDvN9vAAs9tYGZ3xllmmNm9VmQf0hHx37Ip/tjvDPzJ05+08zv7mN/XTY3/lklmdpWZbd
iZea2N+6GtyD4k53tbwcx+Zr5J/VyL9+9m1t3MjjKz281vt+bFGe41s2+38bcNi9f3t+LXm2r+mODE+PlGM/sw/uz7FHmNv8XZ9mzrcyyk2leKswf82+VNz/5/CeCbhcvM+gMbAI855+aW8D7dgfuBPYC7gb8C7+XNczT+fskX8AcsbwA/BB41s8HxPNPxzZxuiv9/U/z/CHi+I0HMn0HdALjPOfch7exEzGxn4Fn8maJXgLOA6/BXIH/T2Xm7YDPgIaAncAm+cJofP/dbYEf8Z3EBcHH83EnAHeab3+Tm7Q7cAZwPDAeuBM4FngG+i29CTPxaLRTf0R4W/yzpftcquD/+WWz5vh94AFgmdydr/l7eDfHNr58t4f20nGs5rydHAP8GJgHnAS/jm5Xda4ufLB2J/zyPAybjP6Nr8E3e7rS8E6/x9FPwn/dt+GXoHvzJqAfN7FtFMi2Jbw64Kf6q2N+Bzzvyx2idWCSv1on2WfzTAZjZLsCjwK7Avfjv+A3gcOBpK9A3RS7n3Ov4/c1YM1tlsTcz2xxYE7jJOVdKc9fsiZ7LgTuBz4D9zax3gfcYGP8Np+NbTFyCXwZew++HVuvMvF3Q1j5za/z2ZDp+XTobv+7vCTxpZusU+PsOiTPvEf88E799GYLfnlXye6i0bPPdxTKXysz2xa//6wH/xa/nA4HHgFFt/Ooo4Algafz2825ge2CC+RNpj+NPIl0DXAusg9/+jMh7/wH47+c44Cv8sdB1+G3c3WZ2GB0Qf1cPxRnuwO8P5uNbARY92dmO7LZuvHNuKv7k6WpmtlWB97e4OL0aWBu/Tzo7zrQVsEtn5u2ivwJ/BF6K//1IPH3J+P998fvas/BN7dcDbjezxQr/uFB/AfgZvsXiucB/gJn4fQvOuWbgovh19yvwGtnblD6j9Zi245xzJT+AsfgN94Q25hkVzzMpb/r7wBeA5Uz7BF8wNwN/zJm+e/waJ5aQbVL8O/cCvQs8f1L8/Hxgvbznzo6f+1fe9HHx9HGd+Kz+Gf/ufvH/m/D3W8wC+uXNuxR+hZ0PjCnwWsM6OW/2uxhfJOMEvygU/I4dcFiR31sh93vMmf7H+Pf2yZt+ajz9ZqBH3nM9gME5/781nnfNvPn64leQD4DGziy/lXzEy3JL3t+SXamb8PfiOOConOd3jaddr+Vcy3ktLOd5WV3+59rOvBOKLKszgLXynrsyfm7vAt9lC7Bv3vQB+EJtLrB0zvT+wFKFlh/8Ovtasb8Lf/Df1InPRetESteJdpZ/V2D69vHy3AKMxBeFU/DHRFvlzfvb+HXuzps+Pp4+KmfanvG0vxR4z+z8O5SQf9P4d+7KmfaXeNqPC8yfXX/PBxrynusD9O/kvIsttznPjaPAfoz295lDgL4Fpq+DX2fvyJu+OrAAmAqsUeD3ctfDsn4PVVpW18NvY1qAK4DvASPb+Z3FPvt4vZ0GzAPWyZv/dFq3NbnL7aic6b/P+50T4+lT8dvYhpznDoqfOzvvdy6Ip1/AorXHyvht6by89x8bz39SzjTDt+ZwwO55r/+LnLxjS/iMe8Z/x3RgiXjaLvHrXFFg/kPj557MXR/i5xqBZTo570nFslNkH5Kz3H4MLF/g93rkrgM50/vjT3hPzf7N8fTu+BNUDti/nfVpGfy693Qby+ApnVruO7myZBeYCW3Mk/0gJ+VNvzSevnb8/9Xj/x8OPAU8mjPvufFzm5eQbVL8O+sUeT775f+rwHP944VzLjk7bjpZLAC98Qd604GeOdOzO5HD8+b/VTz9rx147VLmLbhQ5zw/geIHRs91YvlYMv7dS3KmNcafwxxg2Q68xs7xa/wtb/ph8fQ/dGbZrfQDv/Nw5BzE4w+6b8/5/+fkFMC0FqlHlvA+Ws4Xn1fLeYAH5SuKTy4w/zbkHUziD1Id8N8i77F7/PwRHcyU3c+MKJB1HjCkE5+J1onWaalbJ9r421z8OCl+nAL8D1gYTz8rnu+A+P9XFniNJl
oPHkfkTB/P4sVFE37/8yWLbusHxN/H2xQ4udFG/n+Rc6InnrZmPO2JvHmH4Iv6TyhQhHZ23mLLbc5z42i7KF6nE9/
"text/plain": [
"<Figure size 1200x240 with 5 Axes>"
]
2021-02-13 11:53:15 +00:00
},
"metadata": {},
"output_type": "display_data"
2021-02-13 11:53:15 +00:00
}
],
"source": [
"# One pie per model: green slice = accuracy, red slice = error.\n",
"# squeeze=False always returns a 2-D array of Axes, so this also works for a single model.\n",
"fig, axes = plt.subplots(nrows=1, ncols=len(models), squeeze=False)\n",
"fig.set_figwidth(2 * len(models))\n",
"fig.set_figheight(2)\n",
2021-02-04 13:34:25 +00:00
"\n",
"# the loop variable gets its own name so it does not rebind the array being iterated\n",
2021-02-13 11:53:15 +00:00
"for (name, acc), pie_ax in zip(accuracy.items(), axes.flat):\n",
"    pie_ax.pie([acc, 1 - acc], colors=['g', 'r'], startangle=90, counterclock=False)\n",
"    pie_ax.set_title(f\"{name.capitalize()} Accuracy\")"
2021-02-04 13:34:25 +00:00
]
},
{
"cell_type": "code",
2021-03-10 20:01:13 +00:00
"execution_count": 14,
2021-02-04 13:34:25 +00:00
"metadata": {},
"outputs": [
{
"data": {
2021-03-10 20:01:13 +00:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA/sAAAHUCAYAAABlHTjwAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAABJ0AAASdAHeZh94AAEAAElEQVR4nOzdd3gU1frA8e/ZZNN7pQQEAoRe9FpBRZpSFAsqFgR7uZar2Hu599qv5dqxwfVnB7sIiiJFsNJLKAkljfRetp3fH7ObZLO7KZAC4f08Tx7IzJzZmZ3JnPfUUVprhBBCCCGEEEII0XmYOvoAhBBCCCGEEEII0bqksC+EEEIIIYQQQnQyUtgXQgghhBBCCCE6GSnsCyGEEEIIIYQQnYwU9oUQQgghhBBCiE5GCvtCCCGEEEIIIUQnI4V9IYQQQgghhBCik5HCvhBCCCGEEEII0clIYV8IIYQQQgghhOhkpLAvhBBCCCGEEEJ0MlLYF0IIIYQQQgghOhkp7AshDohSSiullrXCfpYppXQrHFKbUEo97DzXMR19LEIIIcShQCnVy5k3vtsK+9qtlNp98EfVNpRS7zrPtVdHH4sQLSWFfSHagFJqTFOF4XoZ5e72O7LOobUqGtqKUipFKTVXKbVTKVWtlKpQSqUrpZYopR5USiU6t7vaeS4fN2Of9zq3fcH5u+v+0UqpcqVUuI90Sim1q962Y1rzXIUQQrQ/pdQxzmf6rz7WX1Tvud/by/pgZ/5UqZQKbPsjbj2tWdHQVpRSxyql/k8ptUcpVaOUKnXmxV8ppe5USoU6t/uX81yeasY+33Bue6vz9zH1rnG6Ukr5SBfm/HzXtr1a9WTFIU0K+0KIAzUQuKyjD+JQo5QaC6wDrgJygTeBF4CfgV7AI8BQ5+YfAOXANKVUXCP7VMCVzl/faLDaBoQCF/lIPg7o49xOCCFE57AWKAKOUUpFeFk/DnD1mhvrZf0oIBBYqbWuOYDPz8SIA+45gLSdmlLqUmANRr68A3gVeBn4FRgOPAl0dW7+JsZ1ukwpZW5kn6HADKAGmNdgtQ0jvpjgI/kMIByJA45IUtgXQhwQrfU2rfXejj6OQ9DrQBAwW2t9ktb6Rq31vVrr2Vrr/hgZfSqA1roco8AfQOMVJ2MxCuy/aK03N1j3J5ADXO0j7dUYwcH3B3pCQgghDi1aawewDPADTvWyyVjn+gK8F/Zdy5Ye4OdbnXFA9oGk76yUUiEYBXsNTNBaj9da/0NrfY/W+mKtdU+MipZ8AK11OvADkAic2ciuXQX2BVrrwgbrfsDI5xuLA7Ix4gVxhJHCvhCHmPpjxJVS05VSvzm72RUqpT5USnVvsP0Hzu37NVg+z7l8aYPl4Uopq1JquZfPvkgp9ZNSqtjZvW+rUup+b138fHWlV0p1VUq9o5TKVUpVKaXWKaVm1etu9rCP8/Z3dlXf4ezytk8p9aRSKqDeNrNV3fj+U+t1SfPYr1LqeKXUp0qpHKWUxbm/15VS3Xx8/jFKqe+UUmXO7m4/KKVO9LatL0qpBKAvUKK1bljzDoDWeoPWel+9Ra6W+qsa2bUrA2/Yqg9GTf07wN+UUsMbHE8ccDawAGgYHAghhDi8ufJ3t8K8s5t2b+f6n4HTvKT1KOw78+EblFJrnPlgpVJqrVLqRqWUW5mhsa70Sqn+SqkFSqkiZQxj+0UpNcWVhyulZns7GaVUqFLqaaXUXmccsFMpdVf97unOvD7d+eusBnHA7Ab7O10p9a1SKt+5v13O/Uf5+PzxSqkVzmMuVEp9rpQa4G3bRgwBIoBNWmuvFSla61+01sX1Frnydl+F9frrvMUBBcBCjF6C8fVXKKWGAcdhxAnSsn8EksK+EIeuG4D3gN0YtcSbgAuBHxoUvl2ZybgG6V2/n6SUCqq3/FTAnwa1+Uqpt4H3MQqrC5yfWQg8BnynlPJv6oCdhd3VwGxgK/A8RlfDV4Bbmkj+PnATsAKjy1sVcC
dGS7nLOoxu8AB7nP93/SyrdxxXAKuAScBPzuP4A6NA/YdSqmeD4z7J+bnjgUXAS4DFuc/jmzrvekowMtMwpVTXpjYG0Fr/4TyvgUqpUQ3XK6ViMQrsJYCvsf2uboANA4VZGL0G5jbnWIQQQhxWfnT+6yv//xEjD+yqlBrkWqmMbv9/wxgG8JdzmRn4GiPvj8LIk9/AKCv8F8+u4145C8drgHMx8uEXgL3AZxh5mS9mYDFwHkY+/CYQDDwBPFhvu2XOfQKsxz0OWFfvOB4CvsPIw78BXgR2ArcDq1SDoQ9KqenOz/8b8AlG7BGLEdN4zHnQiALnv92Uc1x+M3yBMexvYsP4xHlsQ5znsV1r/bOPfczF+A5nNVh+NUZ88FYzj0V0Nlpr+ZEf+WnlH2AMxsN1WSPb9HJus7vB8oedy0uBoQ3Wve9cd0G9ZX2cyz6ptyzFuWyJ899x9dY951x2cr1ls53LFgLBPo7nlgbLPc4PIzPRwJMNlg/H6GKmgYcbrFvmXP4nEFNveShGxmwHujT12fXW9ccoqO8EujdYN865v8/qLVPANuc+pzXY/hbncg2Maea1/9S5/S6MoOJ4IKSJNNc707zrZd2tznUv+bh/Vjp//wEjcAuut81WjOAAjIqjZp+H/MiP/MiP/Bz6P0AW4ADi6y37P6AMo2J/sPPZf2O99We68vx6y1x5/X8Bv3rL/erl7dPqLXflQe82OJ6lzuXXN1g+qV5+OrvBut3O5d82yMMSgGLnj7mpz663/jTn+l+AqAbrZjvXPVdvWRhGId0K/K3B9q6YSQO9mnE9FPCbc/t1wN+BkUBAE+mexEuM5Fz3gnPd7Q2Wj3Euf8/5uTuAbfXWBzvjgu+dv69s7nnIT+f5kZZ9IQ5dL2qtNzZY5mqhPc61QGudhpFRnlavq5urVv9BjMJt/Vr/cUAFRs27yy0YLdJXaK2rGnzmYxiZ4CWNHayzu/1FGC3Q/6y/Tmu9HpjfWHrgLl1vHJrWugIjYDFh1LQ31/UYtdu3aK0zGxzHUuBL4ExVN3v9SRiVI8u11l802NdLGIX2lrgao9KkN/A0xvdcppRar5T6p3LOxN/A/2Fck/MbtjZQ173fW9e9+uZitMacD6CUOhkYgNE6IoQQonNailHQq99V/zRghdbapo15XnJx7+rv1oXf2UX/Joz5X27VWttdGzr/PwejkNhUHNDDue+duPfKQ2u9CKNSujE3149BtNa5GK3ekRj5dHPd7Pz3au3eXR6t9bsYhfD65zINiAHe10Zvu/oexohrmkVrrYHpGA0ZwzHiiL+AcqXUr85hCd4mVJyL8R1fXn/IhLMn56UYjRjvNvG5bwIpSqlTnIunY8QF0rvvCNZkt1whRIdpmOEAuMZ6RzdY/iNwBTACo9v8WCBba71GKfUnzsK+cyzXEGCJ1trqXBaCkSHlA/9Q3t/cUoMx625jUjBqkf/QWpd5Wb+Sxselt+R8G+MaZ3+qUupYL+sTMFoq+mP0Jjjaudyja5zW2q6UWgkkN/fDtdZFwHnOMZOnY1RUHAsMc/5cr5Q6Q2v9e700pUqpjzCu4SUYwxhcwwsGAb9prTc08dGfYVzDqzEqVq7BaKV4t7nHLoQQ4rDzI0ZhcCzwsVJqIMZM78/V22YZMEEpZdLGxH4Nx+v3xyjs7gDu9xEHVNF0HDDC+e9q5+c0tBJjuJw3JVrrnV6WH2gcYMWoQD/fy/oAIF4pFau1LqDxOKBEKbUO75MgeqWNyYtPc16LCRhxwHH1fm5QSo3RxuR8rjQ7lVI/YVyb0zGGMoAxrCEG+Fhrnd/ER7+L0UBzNbAcIw7IBz5v7rGLzkcK+0K0DVcm11jvGdc6bxkiGN3WGnJNruLXYPlSjILiOKXUeoxa/W/rrbtTKRWJkYko3MfrRzuXxQMPNXK8TYl0/rvfx3pfywFoWPvu5Ot8GxPr/PeOJrYLc/7b1HHntOCza2mtd2O0bLwOoJRKwpi74EyMWvYRDZLMxb
iGV+Es7NP8Vn201hal1HzgNufEgtOBL50tI0IIITqnhvP21B+v77IMuAAYqZTai/H610yt9Tbnele+2Y/G44CwRtb
"text/plain": [
"<Figure size 1080x480 with 2 Axes>"
]
2021-02-04 13:34:25 +00:00
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
2021-02-04 13:34:25 +00:00
}
],
"source": [
2021-02-13 11:53:15 +00:00
"# shared confusion-matrix options; the four-kernel comparison cell further down reuses args\n",
"args = dict(normalize='true', colorbar=False, display_labels=playlist_names)\n",
"\n",
"fig, ax = plt.subplots(nrows=1, ncols=2)\n",
"fig.set_figwidth(9)\n",
"\n",
"# left panel: unweighted RBF model, right panel: class-weighted RBF model (both on the test split)\n",
"for panel, (model, title) in zip(ax, [(clf, 'Unweighted SVM'), (wclf, 'Weighted SVM')]):\n",
"    plot_confusion_matrix(model, data_test, labels_test, ax=panel, **args)\n",
"    panel.set_title(title)\n",
"\n",
"fig.tight_layout()"
2021-02-04 13:34:25 +00:00
]
},
{
"cell_type": "markdown",
"metadata": {},
2021-02-04 13:34:25 +00:00
"source": [
2021-02-13 11:53:15 +00:00
"## Unweighted Classes\n",
"\n",
"From the above unweighted scenario, it is clear that the Pop playlist was not effective for classifying similar tracks. This is likely primarily due to the larger size of the Rap (\\~800), EDM (\\~1,300) and Rock (\\~700) playlists compared to Pop (\\~125). Additionally, there is overlap with other genres such as Rap and EDM, where much of the confusion occurred. Also not helping is that one of the sub-playlists, electropop, is shared across EDM and Pop. As EDM is already a much larger playlist, it is unsurprising that this performance was poor. The overlap with Rock is understandable as Pop contains an Indie Pop sub-playlist which could have caused some confusion. Quite surprising was the confusion for Jazz as I wouldn't have thought there would be much overlap here.\n",
2021-02-13 11:53:15 +00:00
"\n",
"The other major confusion was with Rock and Metal, specifically classing Metal tracks as Rock. This could be expected due to the similarity in tone.\n",
2021-02-04 13:34:25 +00:00
"\n",
2021-02-13 11:53:15 +00:00
"## Weighted Classes\n",
"\n",
"When weighting the classes by prevalence in the dataset, the model is generally better at classification. The clearest difference is the ability to classify Pop songs. Without weighting, no songs were correctly classified as Pop but were instead mis-identified as Rap, EDM, Rock and Jazz. When re-weighting, the Pop playlist was now correctly classified almost 60% of the time. Mis-identification as Rap, EDM and Rock dropped from a combined 85% to just 20%. Interestingly, the mis-classification of Pop as Jazz increased from 15% to 21%.\n",
"\n",
"The improved accuracy of the Pop model reduced the accuracy of some others. The accuracy of Rap, EDM and Rock decreased as some tracks were instead classified as Pop. EDM and Rock were worse affected than Rap with around 15% Pop error rate compared to Rap's 9%. As discussed previously, this could be due to the overlap in aural tone. The overall accuracy of Rap was not significantly affected by this Pop error rate as, to compensate, the EDM error rate dropped from 12% to just 3%."
]
2021-02-04 13:34:25 +00:00
},
2021-02-13 11:53:15 +00:00
{
"cell_type": "code",
2021-03-10 20:01:13 +00:00
"execution_count": 15,
2021-02-13 11:53:15 +00:00
"metadata": {},
"outputs": [
{
"data": {
2021-03-10 20:01:13 +00:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA/sAAAOzCAYAAADncxY7AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAABJ0AAASdAHeZh94AAEAAElEQVR4nOzdd3gU1frA8e/ZZDe9N3qH0ItdKYIUQVAsqFgQRMXeu957Rb3+7PXasWGviKIgCKICgorSe0kI6b23ze75/TG7STa7m4IJgfB+nidPYGbO7JzdzZz31FFaa4QQQgghhBBCCNF2mFr7AoQQQgghhBBCCNG8pLIvhBBCCCGEEEK0MVLZF0IIIYQQQggh2hip7AshhBBCCCGEEG2MVPaFEEIIIYQQQog2Rir7QgghhBBCCCFEGyOVfSGEEEIIIYQQoo2Ryr4QQgghhBBCCNHGSGVfCCGEEEIIIYRoY6SyL4QQQgghhBBCtDFS2RdCCCGEEEIIIdoYqewL8Q8ppbRS6ufWvo62Sik12vEez22Gcx3Rn5VS6mellG7t6xBCCHFkUkq95yjLurX2tTQ3pVSiUiqxGc5zRL9HSqlZjuub1drXIto+qewL4YHjJiyVLkApdYHj/fjMy/77HfvLlVL+HvZ3d+zf3/JX27yas6GhpSilxiulvlZKpSqlKpVSeUqp3UqpL5RStyillOO4jxx5uaER51zmOPY8x/+dgYlWSv1aT7puSim7/P0IIY5lte6Xzh+bUipbKfWTUurS1r4+T5RSPkqpfKWUVSkV4mF/+1r5udLLOX5x7B/V8lfcvJqroaGlKKU6KKWeV0ptV0qVKqXKlFJJjvf8MaVUT8dx4x2fwR+NOOeljmO/qbXN+Rnbnef0knZlrWNnNUsmRYuQyr4Q/1w/4IrWvogWtBKwA6OdFcc6xgIa8AOGe9kPsPwQX/8PjPf45UNM32YppR4AlgFTgL+Al4A3gc3A6cCLgI/j8HmO31c3cM5uwDggDVhUZ3cVMFIpFe8l+dWAchwnhBDHuocdP08AvwCjgI+UUs+16lV5oLW2AT8DvhjXWZezLNfAGXV3KqUCgVOAEmDtIV7G2FqvIxyUUgOBLcBtGOXrfOA5YCkQDjwAjHEcvhxIAE5USg1q4NTXOH6/WWd7FUZZfpWX6+kNjEbK+qOCVPaF+Ie01ju11kmtfR0tRWudC2wCYoGBtfcppfyA04CvMRoE3AKAWttWHOLrlzre4+xDSd9WKaW6Ao8AhcBxWuuztdZ3aa3v1VpPA9oBEwEbgNb6Z2A3MEwpdVw9p74Ko5B/V2tdtyD/zvHbrcFAKeUDXAn8CWQccsaEEKKN0FrPdfw8qLW+ADgTo7J82xE6xPwnx29vZXkZsJiaimVtIwALsEprbT2UF9da79Na7zuUtG3cC0AkMFdrPVhrfb3jO3WN1noI0BNYDaC11sBbjnTXeDwboJTqhdEpcBBYUmd3BrAeuFIp5eshuTMGqNshII5AUtkX4h/yNA9cKTXXsX20UmqaUuoPx7CrXKXUp0qpjl7OFamUelwptcMxRKtAKbVCKTXBw7FhSqm7HcMCkx1DuLOUUt8qpU6t71qVUu2UUm8ppVIcwwtnNZBNZ0W9bgBwKhAALAA2etgPRlCgqQkiUEoFKmP4/0alVIlSqlgptVYpdYmHa/Y6lF4pdaJjyHmRUqpQKbVcKXVq7fffy/sQrZR6UymVppSqUEptqzssUSn1HsaoBoCH6gzJHF3n2EscQ9rylTGdYYdS6l+OxhBPrz9dKfWX4zPOVEp9oJTq4OnYepyM0Wu/Umu9pe5OrbVda73UUfA7OXv3PQYAtSrstYOF2rZh9NjMVEqZ6+ybDHSo9RpCCCFq0VqvAHZiNKie6NyulDpeKfWVozyoUEodUEq9qpRq39A5lVJ9HeXSynqO2aKM4fkNnc9bWe/ctgajN7mjhxFeHhv2lVJnKqUWK2MaQ4VSap9S6mmlVL
iH6/Q4lN4R77zgiHXKlVI7lVJ3KKV6OPL+nrcMKaWudeS/XCmV4Sj7w2rtH62MaWddga51yvr36pyrrzLWAzjoiLkylFIfe3gvnMf3UsaUujxHrPObUmqyt2utx2mO3y962qm13q+13llr0zsYve6XKw/TKx2cI/He1lrbPeyfh9FpMKX2RkfZPwv4Ddje2AyI1iOVfSFa1g3Ah0Ai8AqwFbgYWF63IqiMntq/gPuALOB14DOMIew/KKXqVtD6AY9h9Kh/jzGk60eMAvdXpdREL9cUCazDGG63AGN4fEM9sc6Ket3hdWNr7V8JnKBqzfVTSg3AKCy2aK2zHNvCMVqg/w+j1/kdjCFpMcDHSqn/NnAtznOPAlZhNCYsduSjzHEdJ9WTNBwjYDkV+NLx2h2Ad5RSM2sdt9CxD4zhlw/X+kmsdR3vAB8DvYCvMD7nXOBRjM/NpVVcKXU78AnQA3gfeBcYhFFwRjQm7w45jt89HJX0xpgPVAKXKGPIZV2TgI7Acq11gpdzzMP4rKbW2X4NUIyRNyGEEJ45p8NpAKXUFIz7/9kYQ7CfA3YB1wPrlVLd6zuZo5K3EmOqXR+3F1PqNIxRed9ordMaONc2jHhgiFIqqtY5emJUhp1lPXiPB1bUSvcQ8ANG4/T3GFPN9gJ3AWuUUqH1XY/jHP6O170VyMSo8P4MPAg820Dypxw/mzDK5hSMsurrWsckYpTrBY6f2mX9wlrXMRH4G7gMYwTbC468ng/8oeqMmFPGUPd1wDSMRvIXgWTHOc9vKN91OMt7t8/XE611OsZIvAjggrr7HXHJTGpiME8+wZiSUXck3zkYIz2lYf9oobWWH/mRnzo/GIWwbsKxP9fZNtexvRAYVGffx459F9XZ/jNGxX16ne3hGL3mZUBcre1hQLSH6+kEpAI7vOULo5Lp24T3IwijkpgP+NTavgbY6fj3ZMe5p9Taf7Nj23O1tr3n2HZPndfwxwgK7MDQWttHO46fW2ubCdjj2D6pznmuq5XP0V7y/1adfPTHaAXfXud4t9eus3+WY/8CIMDLd+DWWtu6Od7HXKBbnfx81cTvXRBGkKKBX4HZwIDa+fKS7jNHmlke9n3j2DfNSz7/63jdAmBprf0dHe/fPMf/kxubD/mRH/mRn7b24+1ejrEeit3x0xUIxqjI2YCRdY6913GeZXW2O8vQbrW2TXNse8bDazqPH9/Ia3fGKNNqbbvGse0UjMaKTODLWvvDHGVANqAc25yj+n4Dwuu8hrNMeb7O9kQgsc62fzuO/cR5bsf2zhgdIxp4z0uek4Autbb7OspLDZzU0GvX2hcB5Dny17/OvoEYDd1/19m+jDoxgGP7VGpikVmN/EyecRyfDjyEsaZCaANpJuEhPnXsO8+x7zsv391kx7/fcnyunWrt/wEjBgjEiAkanQ/5aZ0f6dkXomW9pN2HWDtbQ6t7n5VSQzDmTn2ltf609sFa63yMm7s/tVpotdYF2sM8dq11MkaPdV+lVBcP11QJ3KXd52N7pbUuAX7HKNCPd1xzMMYwRGev/yqMgKX28D+XYX2OnoLLgfVa66fqvEY5RnCjgIZWKz4Noyd9pda67lyzNzHmpntTCtyhjcWInK+9HaPhop8jX411K0ZBOFtrXVZn36MYQdxltbZdBpiB/2mtE2u9vh24GyMAbBTHZ3IORkPQSOBtjJEjRcpYnfcGL9MInAvxuLTWO4Z3noURxH1TN1Gd1/0YGK9q5pzOxphSIC39Qgjh4JhSNlcZq6V/iVFRUsALWusDGBW/SOAzrfWqOsmfxaiAjvdSlte2EGNR1Vm17/uOkXQXAfto/CK5nobynwEUYZTdGqNzovaivaOpmVamHdtucfy+xhHHVNNav4dRdtUuH72ZiVE23l/r3GitD2L0rtfnEV1rTSVH3POu47/1jQCs6wqMjpeHHPFCNa31Voyyb5hSqj+AUqoTMB5jobyX6xz/DcZowaZ40PEaURgdCb8A+Y7pDC
8opXp4SLMUOACc7hhlUJuz/K+7MF9d8zA+19lQPQJ1PPCR1rq0iXkQrcTTogtCiOaz3sO2g47ftYdsO+fYhynPj3m
"text/plain": [
"<Figure size 1080x960 with 4 Axes>"
]
2021-02-13 11:53:15 +00:00
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
2021-02-13 11:53:15 +00:00
}
],
"source": [
"fig, ax = plt.subplots(nrows=2, ncols=2)\n",
"fig.set_figwidth(9)\n",
"fig.set_figheight(8)\n",
"\n",
"# the four class-weighted models, listed in the order they fill the 2x2 grid\n",
"weighted_models = [\n",
"    (lclf, 'Linear Weighted SVM'),\n",
"    (pclf, 'Poly Weighted SVM'),\n",
"    (wclf, 'RBF Weighted SVM'),\n",
"    (sclf, 'Sigmoid Weighted SVM'),\n",
"]\n",
"\n",
"# ax.flat walks the grid row by row: [0][0], [0][1], [1][0], [1][1]\n",
"for panel, (model, title) in zip(ax.flat, weighted_models):\n",
"    plot_confusion_matrix(model, data_test, labels_test, ax=panel, **args)\n",
"    panel.set_title(title)\n",
"\n",
"fig.tight_layout()"
2021-02-13 11:53:15 +00:00
]
},
2021-02-04 13:34:25 +00:00
{
"cell_type": "markdown",
"metadata": {},
2021-02-04 13:34:25 +00:00
"source": [
"## Other Tests\n",
"\n",
"Take a handful of other tracks which I don't listen to and aren't in any playlists to see if they can also be classified"
]
2021-02-04 13:34:25 +00:00
},
{
"cell_type": "code",
2021-03-10 20:01:13 +00:00
"execution_count": 16,
2021-02-04 13:34:25 +00:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
2021-02-04 13:34:25 +00:00
"text": [
"Top Of The World (Five Finger Death Punch) could be ROCK ✓\n",
"Aston Martin Music (Rick Ross) could be ALL RAP ✗\n",
"On The Sunny Side Of The Street (Dizzy Gillespie) could be JAZZ ✓\n",
"Vibez (ZAYN) could be ALL RAP ✗\n",
"Shot In The Dark (AC/DC) could be ROCK ✓\n",
"To Hell and Back (Sabaton) could be ROCK ✗\n",
"Withstand The Fall Of Time (Immortal) could be METAL ✓\n",
"Alone Together - Rudy Van Gelder Remaster (Kenny Dorham) could be JAZZ ✓\n",
"Feel No Ways (Drake) could be ALL RAP ✗\n",
"BO$$ (Fifth Harmony) could be EDM ✗\n",
"\n",
"50.00% Accurate\n"
2021-02-04 13:34:25 +00:00
]
}
],
"source": [
"### PREPARE ###\n",
"# hand-picked tracks used as an informal sanity check of the classifier\n",
"test_uris = [\"spotify:track:53yqxU2EKKzbuQZEUEVtxc\",\n",
"             \"spotify:track:5W7xC99N2Zzfh69r7I7zWK\",\n",
"             \"spotify:track:38R2EViAkYOFG8ZkG3GLtW\",\n",
"             \"spotify:track:6T6D9CIrHkALcHPafDFA6L\",\n",
"             \"spotify:track:0sfdiwck2xr4PteGOdyOfz\",\n",
"             \"spotify:track:1BrgjqSg9du0lj3TUMLluL\",\n",
"             \"spotify:track:5nCnSnLtotQ8eB4E189U91\",\n",
"             \"spotify:track:3GOZbK2epuHzCt5YvvVFHO\",\n",
"             \"spotify:track:3cjF2OFRmip8spwZYQRKxP\",\n",
"             \"spotify:track:1COvXs6jaykXC73h9OSBVM\"]\n",
"# inferring what playlists these would go in\n",
"# NOTE: every label must match an entry of playlist_names exactly (\"ALL RAP\", not \"RAP\");\n",
"# the check below compares strings, so a mismatched name marks a correct prediction as wrong\n",
"test_labels = [\"ROCK\", \"ALL RAP\", \"JAZZ\", \"POP\", \"ROCK\", \"METAL\", \"METAL\", \"JAZZ\", \"ALL RAP\", \"POP\"]\n",
"\n",
"test_tracks = spotnet.tracks(uris=test_uris)\n",
"spotnet.populate_track_audio_features(tracks=test_tracks)\n",
"\n",
"test_features = [ {j: k for j, k in i.audio_features.to_dict().items()\n",
"                   if j in headers}\n",
"                 for i in test_tracks] # filter down to descriptor columns\n",
"\n",
"### PREDICT ###\n",
"predictable_frame = pd.DataFrame(test_features)\n",
"\n",
"# predictions are integer class indices into playlist_names\n",
"predicted_labels = clf.predict(predictable_frame)\n",
2021-02-13 11:53:15 +00:00
"# predicted_labels = wclf.predict(predictable_frame)\n",
2021-02-04 13:34:25 +00:00
"labels_correct = [i == playlist_names[predicted_labels[idx]] for idx, i in enumerate(test_labels)]\n",
"\n",
"### EVALUATE ###\n",
"for track, label, correct in zip(test_tracks, predicted_labels, labels_correct):\n",
"    print(f'{track.name} ({track.artists[0].name}) could be {playlist_names[label]} {\"✓\" if correct else \"✗\"}')\n",
"\n",
"correct = sum(labels_correct) / len(labels_correct)\n",
"print(f'\\n{correct*100:.2f}% Accurate')"
]
},
2021-02-23 18:22:32 +00:00
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# SVM Cross-Validation"
]
},
{
"cell_type": "code",
2021-03-10 20:01:13 +00:00
"execution_count": 11,
2021-02-23 18:22:32 +00:00
"metadata": {},
"outputs": [
{
"data": {
2021-03-10 20:01:13 +00:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAd4AAAG+CAYAAAA9amWZAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAABJ0AAASdAHeZh94AAEAAElEQVR4nOz9d5wk613fi7+f6jg557yTZ3Zmd2d2T5Z0JCEhgQQYdI0xGDBwne+1zTXXvr72BX4m2IZrYxv7GjAYMMJE2xhEkBCS2D3n6Ohsmpxzzrlz1fP7o7p6e3q6ezrOhlPv12tfc0539VPV1VX1eb7f5xuElBITExMTExOTq0F52gdgYmJiYmLyfsIUXhMTExMTkyvEFF4TExMTE5MrxBReExMTExOTK8QUXhMTExMTkyvEFF4TExMTE5MrxBReExMTExOTK8QUXhMTExMTkyvEFF4TExMTE5MrxBReExMTExOTK8QUXhMTExMTkyvEFN5nDCHEm0IIKYT4kSvYlxRCfDnb+0kUIcSiEGLxaR+HiYnJE2I9k4QQXxZCJFXs/yqeOVf5DE2VF0Z4gyc6/J8qhNgPXhzfK4QQT/sY3++kcqM+Cwid7xVCfE0IcSiEOBJCDAshfloIUZnm2HVCiH8mhHhHCLErhPAHx38khPgPQog3M/Mt3l8IIbqEEP9OCDEaPJ8+IcS6EOJzQojvF0I4nvYxmqSGEKI5+Iz/5ad9LKlifdoHkAV+NPjXBrQBfwH4EHAb+DtP66CeUboB19M+iDA++rQPIAb/DPi/gSPgt4BjoBP4m8AfANupDCqE+GvAvwGcwCTw34Nj5QXH/27gbwohflpK+UNpfof3DUKI/wf4YXTD4h3gV4BToAp4E/hP6L/d7ad0iC8K3w3kPu2DiMLX0J9tu0/7QGLxwgmvlPJHwv9fCPE68OfA3xJC/L9SyoWncmDPIFLKyad9DOFIKeee9jHE4G8BEngl/JwJIfIASyoDCiG+G/g54AD4S1LK34uyTSHw14HGVPbxfkQI8Y/RJ98rwP8ipXw3yjafAv6Pqz62Fw0p5fLTPoZoSCld6BPZZxcp5QvxD/3BKGO8NxZ8/zNR3nsZ+B1gE/Ch37A/B9TGGOsO8HngBN3y+VPgVeBHgvt4M2zb5uBrvxxjrC9HHjP6jFwCPxLx+iC6dTQE7AMeYAb4f4GSKGN/b3Cc7wU+EdzXUfj+gu9/OeJzBcA/BUaD3+8EmAN+ExiMso/fBeYBd3D7t4DvitjOOA/R/n05bLtFYDHKd3EA/wgYQbfQj4G7wF+Msm3onAf/+zfQZ74e4D7wqRSure3gubNk6FotCv6GEvhYAttbI/6/Fvh/gufauG7XgV8Hei45J63o1/te8Lf9PHA9uF0F8PPARvB8vQd8OMp4PxIc703gO4AHwd9lHfhXgCO43UeC190x+gTjvwBlUcb7cHC/48Ft3cHr74cBZxLntTl4LnzGd4qzrSPG+elAv9a3AY3g/YxuPf+N4Dk5Bc6C//03ASXK+B8Afh9YBbzB3+mrwA9HbFcF/DQwFRzzMPjfvwxcS+A7/8fgsX9zjPdfDr7/O2GvdQD/PHg/7ASPbyn4G9RHGeNNoj+TvkyUZy5gR3+GzAXHXgB+DP0+jvbMSfh6Drv2ov373njHG3yvHfhVYC1sP78KtF9ynX8G3ZJ2od+7vwHUpfoMeOEs3kvwh/+PEOL70C82L/A/0UW3HfgB4NNCiFdk2KxOCPFB9AeVBfhv6BdWH/Al4M+yfOz/K7rb/CvoYq+gi/EPAp8UQrwspTyJ8rnPoAvvH6HfpE2xdhBcB/9j4DV0F91/AgJAPfrD8S76Q9bg/0Of1Pw5+sO6DPgG4L8IITqllP80uN0huhXyvcH9/2jYGIvxvrQQwg78CfpywSTw79HdW58BflMIcVNK+Y+jfLQJ/UaZR3/glwLfDvyeEOLrpJRfirffCH4MfdLzD4GfSOJzsfgMUAK8La
X8wmUbSykDES99EH0i8iX0ic8p+nX7GeCbhBCvSymHogzVDLwLTPBkYvIXgC8LIV5F/+2P0YWnFPhLwB8JITpkdOvmfwM+CfwP9Ifwx4G/D5QKIX4P/eH0OfR77DXgu4Dy4GfC+YdAF/B2cHsn8Dr6g+/N4O+lxjxBT/ir6EtMvyGlHI23oZTSG+XlVvTzMw18FshBPx+gX0N/Gf0Z8Z/QH8h/AfgPwBvAdxqDCCE+Efwex+jPlTX089mN7j350eB2uehi0wp8AV2oBfq1+83oE6T5S77zr6B7Rb4buOA1Ab4n+PeXw177VvRJxJfQz7kP6OXJc++2lHLtkv1GJfgM+a3g8c8BP4suxN+H/qyMRjLX85eBYuDvohsh/yNsnMeXHNsd9GdnAfrvMo5+3X0X8M3B6+y9KB/9W8A3BT/zFfTJzLcDN4LPn2jXUnxSVexn7R8xLF70H1VFF9eaiFmfD5glYuaCvtaoAv897DUF3cKUwCcjtv8bPJl1vRn2ejOZs3ibiGJxAd8f3P4fRrz+vcHXNeATcc7Zl8P+vy/42n+Psq1ChGUNtEbZzg58EX2SE3leL3zfiPcXibB4gf8reEx/SJjlB1QGt5fAa1HOueSidfH1xlhJXlv/OGzMf5jMZ2OM90vBsf5Zip+vBAqivH4D/aH1RxGvh5+T/zvivX8afH0ffWKmhL33V4Lv/euIz/xI8PUjoDvsdQf6RExFt6g/FHH9fCH4uZsR410DRJTv88+C2397gufli8HtfyDJ8xl+fn4iyvvfEXzvIZAf9noeutUogb8c9vrvBl+7EWWs8rD//nS08yuf3EcXfuMYxz+F/nwrjXjdEfxdtyLunTrCLP6w1z8e/O3+v4jX3yRBixd9ciLRJ+7OsNdL0YU4msWb6vX8yzHOx4XjRZ/QTARf/86I7b89+PpkxPVvXOfHQF/EZ349+N4Fr1si/16YqGYDIcSPBP/9uBDiN9FnOAL4B1LKjbBN/yb67PjvyojZnZTyi+izm08LIQqCL7+GHqz1JSnlH0Xs9ufRZ8lZQ0q5JKPP+n8J/cL4+hgf/T0p5R8nuTt3lP1rUsqDiNcurMlKKX3oVqmVzARLfR/6Bf6DMszyk1Juoz+YQZ+pR7KEbqmGH9ufAMvAS4nuXAjxT4AfR3eh/nvgnwshfioySl4I8Y+DkZbfmsCw1cG/F6wKIURx2DUc+hfxPbZlFO+G1K2CPwM+LISwRdnvIrqLMZxfCf51AD8kpdTC3vt1dI/HzRjf499KKSfC9u9Ft5YV4HNSyq+EvacBvxb83xsRxz0vg0+zCP518G+sazuSmuDf1QS3j2SL894Yg+8L/v1HUspT40Up5Rm6tQ7Rr8Fo91G0gJ9o2/mi/cYx+BV0of6OiNc/je5Z+WzEvbMmo1hpUsrPo0+cEj3f0firwb//WErpCRt7nyf3a+R+U72ek+E1dOv2HSnlZyP285vAPfSAxjeifPbfSilHIl77heDfhJ8l4byIruYfjvh/CXy/lPI/R7z+avDvh4IuiEgq0V3KHeju1VvB1+9Fbiil1IQQbwe3zQrBC++vo7v/etDXCcMnTnUxPvq1JHYzju6u+Q4hRBO66+oecD8oqJHH1Ij+4PkoegBQTsQmsY4pIYKTnjZgTUYPBDPc+7eivPc4xkRlhSe//WX770B/EL+NPnGTQcH9B0C5EOIHwvbRHvx7P5Gx41DMxWsY9Nl3+LF9I7qn5Ta6+zbyXi5Hd/+HE+2crAf/Tkc+/KSUqhBiC32pIRrRvqsx3oMo7xkTjXPjBYPU/i6667YD3RUYPrFJ6zpKgqFoggQMoHuOvhzlva+gW4nh1+Bn0d257wYn/18C3pJSRk4IvoJ+Tv6REGIA3avzFhG/kxCiGPh7Ufb9M1LKQ/Q1yn+G7lb+92HvR3MzG+7g70T3it1AF+fwIMEL93oSGOfqwnOS6OfPOKZUrudkjwtiLwn+Gbro3kJfOgsn2nW+EvxbksrBvHDCK6
UUELqZXwV+EfiPQoglKWX4SS8L/r0sTSM/+Lco+HcrxnaxXs8Uv4n+YJpHF8RNdPcS6DdlrLzEzUR3EHzQfgQ90OE
2021-02-23 18:22:32 +00:00
"text/plain": [
"<Figure size 720x480 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
2021-03-10 20:01:13 +00:00
},
{
"data": {
"text/plain": [
"(0.5, 20)"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
2021-02-23 18:22:32 +00:00
}
],
"source": [
"# %matplotlib widget\n",
"\n",
"# Grid search over the RBF width (sigma) and the regularisation strength (C).\n",
2021-03-10 20:01:13 +00:00
"sigma = [0.01, 0.05, 0.1, 0.5, 1, 5, 10, 15, 20, 30, 40, 50, 60] # for gamma parameter, for non-linear hyperplanes. Higher gammas (lower sigmas)\n",
"C = [0.01, 0.03, 0.1, 0.25, 0.5, 0.75, 1, 2, 4, 6, 8, 10, 20, 30] # for C. Penalty for error term. Balances smooth decision boundary with correctly classifying points (wiggle)\n",
2021-02-23 18:22:32 +00:00
"# sigma = [0.01, 0.03]\n",
"# C = [0.01, 0.03, 2]\n",
"\n",
"# Select hyper-parameters on a validation split carved out of the training data,\n",
"# so the test split never influences the chosen (sigma, C) and stays a fair benchmark.\n",
"cv_train, cv_val, cv_labels_train, cv_labels_val = train_test_split(\n",
"    data_train, labels_train, test_size=0.25, random_state=70, stratify=labels_train)\n",
"\n",
"scores = np.zeros((len(sigma), len(C)))\n",
"param_pairs = list()\n",
"\n",
"for x, s in enumerate(sigma):\n",
"    for y, c in enumerate(C):\n",
"        # own name for the candidate model: `clf` is the unweighted RBF model used by other cells\n",
"        grid_clf = svm.SVC(kernel='rbf', gamma = 1/(2*(s**2)), C=c, class_weight='balanced')\n",
"        grid_clf.fit(cv_train, cv_labels_train)\n",
"        scores[x, y] = grid_clf.score(cv_val, cv_labels_val)\n",
"        param_pairs.append((s, c))\n",
"#         print(scores[x, y], s, c)\n",
"\n",
"# print(scores)\n",
"\n",
"# param_pairs was filled in the same row-major order that argmax flattens scores in\n",
"index = np.argmax(scores)\n",
"sig_max, c_max = param_pairs[index]\n",
"\n",
"X, Y = np.meshgrid(C, sigma)\n",
"\n",
"fig = plt.figure()\n",
"ax = plt.axes(projection='3d')\n",
"surf = ax.plot_surface(X, Y, scores, cmap=mpl.cm.coolwarm)\n",
"ax.set_title('Regularisation & Gamma Cross-validation')\n",
2021-03-10 20:01:13 +00:00
"# ax.set_xscale('log')\n",
"ax.set_xlabel('Regularisation (C)')\n",
2021-02-23 18:22:32 +00:00
"ax.set_ylabel('Sigma')\n",
"ax.set_zlabel('Score')\n",
"ax.view_init(50, -140)\n",
"\n",
"# fig.colorbar(surf, shrink=0.3, aspect=6)\n",
2021-03-10 20:01:13 +00:00
"plt.show()\n",
"sig_max, c_max"
2021-02-23 18:22:32 +00:00
]
},
2021-02-04 13:34:25 +00:00
{
"cell_type": "markdown",
"metadata": {},
2021-02-04 13:34:25 +00:00
"source": [
"# Imports & Setup"
]
2021-02-04 13:34:25 +00:00
},
{
"cell_type": "code",
2021-03-10 20:01:13 +00:00
"execution_count": 2,
2021-02-04 13:34:25 +00:00
"metadata": {},
"outputs": [],
"source": [
"from datetime import datetime\n",
"\n",
"from google.cloud import bigquery\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib as mpl\n",
"mpl.rcParams['figure.dpi'] = 120\n",
"\n",
"from analysis.net import get_spotnet, get_playlist, track_frame\n",
"from analysis.query import *\n",
"from analysis import spotify_descriptor_headers, float_headers, days_since\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"from sklearn import svm\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.metrics import plot_confusion_matrix\n",
"\n",
"# shared handles used throughout the notebook\n",
"client = bigquery.Client()\n",
"spotnet = get_spotnet()\n",
"# path of the scrobble cache file used by the read/write scrobble-frame cells\n",
"cache = 'query.csv'\n",
2021-03-10 20:01:13 +00:00
"first_day = datetime(year=2017, month=11, day=3)\n",
"# default RBF hyper-parameters for the weighted model; (0.5, 20) is the pair shown in the grid-search cell's saved output\n",
"sig_max, c_max = 0.5, 20"
2021-02-04 13:34:25 +00:00
]
},
{
"cell_type": "markdown",
"metadata": {},
2021-02-04 13:34:25 +00:00
"source": [
"## Read Scrobble Frame"
]
2021-02-04 13:34:25 +00:00
},
{
"cell_type": "code",
2021-03-10 20:01:13 +00:00
"execution_count": 3,
2021-02-04 13:34:25 +00:00
"metadata": {},
"outputs": [],
"source": [
"# load the scrobble frame via get_query, passing the cache path\n",
"# NOTE(review): presumably this reads the cached file instead of querying when it exists - confirm in analysis.query\n",
"scrobbles = get_query(cache=cache)"
]
},
{
"cell_type": "markdown",
"metadata": {},
2021-02-04 13:34:25 +00:00
"source": [
"## Write Scrobble Frame"
]
2021-02-04 13:34:25 +00:00
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# write the scrobble frame to the cache path as tab-separated text, with the index moved into a column\n",
"# NOTE(review): the file is named .csv but is tab-delimited - confirm get_query reads it with the same separator\n",
"scrobbles.reset_index().to_csv(cache, sep='\\t')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.4"
2021-02-04 13:34:25 +00:00
}
},
"nbformat": 4,
"nbformat_minor": 4
}