listening-analysis/playlist-nn.ipynb

549 lines
165 KiB
Plaintext
Raw Normal View History

2021-05-06 16:19:44 +01:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Playlist Neural Network\n",
"\n",
"Given a set of labelled playlists, can tracks held out from training be assigned to the correct playlist?"
]
},
{
"cell_type": "code",
"execution_count": 3,
2021-05-06 16:19:44 +01:00
"metadata": {},
"outputs": [],
"source": [
"# playlist_names = [\"RAP\", \"EDM\", \"ROCK\", \"METAL\", \"JAZZ\", \"POP\"] # super-genres\n",
"# playlist_names = [\"ALL RAP\", \"EDM\", \"ROCK\", \"METAL\", \"JAZZ\", \"POP\"] # super-genres\n",
"# playlist_names = [\"RAP\", \"EDM\", \"ROCK\", \"METAL\", \"JAZZ\"] # super-genres without POP\n",
2021-05-10 00:18:57 +01:00
"# playlist_names = [\"ALL RAP\", \"EDM\", \"ROCK\", \"METAL\", \"JAZZ\"] # super-genres without POP\n",
"playlist_names = [\"ALL RAP\", \"DNB\", \"4/4\", \"cRock\", \"METAL\", \"cJazz\"] # super-genres with decomposed EDM\n",
2021-05-06 16:19:44 +01:00
"# playlist_names = [\"DNB\", \"HOUSE\", \"TECHNO\", \"GARAGE\", \"DUBSTEP\", \"BASS\"] # EDM playlists\n",
"# playlist_names = [\"20s rap\", \"10s rap\", \"00s rap\", \"90s rap\", \"80s rap\"] # rap decades\n",
"# playlist_names = [\"UK RAP\", \"US RAP\"] # UK/US split\n",
"# playlist_names = [\"uk rap\", \"grime\", \"drill\", \"afro bash\"] # british rap playlists\n",
"# playlist_names = [\"20s rap\", \"10s rap\", \"00s rap\", \"90s rap\", \"80s rap\", \"trap\", \"gangsta rap\", \"industrial rap\", \"weird rap\", \"jazz rap\", \"boom bap\", \"trap metal\"] # american rap playlists\n",
"# playlist_names = [\"rock\", \"indie\", \"punk\", \"pop rock\", \"bluesy rock\", \"hard rock\", \"chilled rock\", \"emo\", \"pop punk\", \"stoner rock/metal\", \"post-hardcore\", \"melodic hardcore\", \"art rock\", \"post-rock\", \"classic pop punk\", \"90s rock & grunge\", \"90s indie & britpop\", \"psych\"] # rock playlists\n",
"# playlist_names = [\"metal\", \"metalcore\", \"mathcore\", \"hardcore\", \"black metal\", \"death metal\", \"doom metal\", \"sludge metal\", \"classic metal\", \"industrial\", \"nu metal\", \"calm metal\", \"thrash metal\"] # metal playlists\n",
"\n",
"# headers = float_headers + [\"duration_ms\", \"mode\", \"loudness\", \"tempo\"]\n",
2021-05-10 00:18:57 +01:00
"headers = float_headers + [\"mode\", \"loudness\", \"tempo\"]\n",
"# headers = float_headers\n",
2021-05-06 16:19:44 +01:00
"\n",
"BALANCED_WEIGHTS = True"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
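"Pull and process playlist information.\n",
"\n",
"1. Get live playlist track information from Spotify.\n",
"2. Join the playlist tracks against the listening history (`scrobbles`) on track and artist name, keeping only tracks that appear in both.\n",
"\n",
"Then drop tracks without a Spotify URI (these have no features), de-duplicate on URI, and keep only the descriptor columns listed in `headers`."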
]
},
{
"cell_type": "code",
"execution_count": 4,
2021-05-06 16:19:44 +01:00
"metadata": {},
"outputs": [],
"source": [
"playlists = [get_playlist(i, spotnet) for i in playlist_names] # 1)\n",
"\n",
"# filter playlists by join with playlist track/artist names\n",
"filtered_playlists = [pd.merge(track_frame(i.tracks), scrobbles, on=['track', 'artist']) for i in playlists] # 2)\n",
"\n",
"filtered_playlists = [i[pd.notnull(i[\"uri\"])] for i in filtered_playlists]\n",
"# distinct on uri\n",
"filtered_playlists = [i.drop_duplicates(['uri']) for i in filtered_playlists]\n",
"# select only descriptor float columns\n",
"filtered_playlists = [i[headers] for i in filtered_playlists]"
2021-05-06 16:19:44 +01:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
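"Construct the dataset, labelling each track with the index of its playlist in `playlist_names`, then split it into a stratified train set and a 10% test set. Balanced class weights are computed from the training labels, and both label sets are one-hot encoded."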
]
},
{
"cell_type": "code",
"execution_count": 5,
2021-05-06 16:19:44 +01:00
"metadata": {},
"outputs": [],
"source": [
"dataset = pd.concat(filtered_playlists)\n",
"labels = [np.full(len(plst), idx) for idx, plst in enumerate(filtered_playlists)]\n",
"labels = np.concatenate(labels)\n",
"\n",
"# stratify: maintains class proportions in test and train set\n",
"data_train, data_test, labels_train, labels_test = train_test_split(dataset, labels, \n",
" test_size=0.1, \n",
"# random_state=70, \n",
" stratify=labels\n",
" )\n",
"\n",
"class_weights = class_weight.compute_class_weight('balanced',\n",
" classes=np.unique(labels_train),\n",
" y=labels_train)\n",
"class_weights = {i: j for i, j in zip(range(len(filtered_playlists)), class_weights)}\n",
"\n",
"labels_train = tf.one_hot(labels_train, len(filtered_playlists))\n",
"labels_test = tf.one_hot(labels_test, len(filtered_playlists))"
]
},
{
"cell_type": "code",
"execution_count": 6,
2021-05-06 16:19:44 +01:00
"metadata": {},
"outputs": [],
"source": [
"def tensorboard_callback(path='tensorboard-logs', prefix=''):\n",
" return tf.keras.callbacks.TensorBoard(\n",
" log_dir=os.path.normpath(os.path.join(path, prefix + datetime.now().strftime(\"%Y%m%d-%H%M%S\"))), histogram_freq=1\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 7,
2021-05-06 16:19:44 +01:00
"metadata": {},
"outputs": [],
"source": [
2021-05-10 00:18:57 +01:00
"def get_model(hidden_nodes=128,\n",
" layers=2,\n",
2021-05-06 16:19:44 +01:00
" classes=len(filtered_playlists),\n",
" activation=lambda: 'sigmoid', \n",
" weight_init=lambda: 'glorot_uniform'):\n",
" l = [tf.keras.layers.InputLayer(input_shape=data_train.to_numpy()[0].shape, name='Input')]\n",
" \n",
" for i in range(layers):\n",
" l.append(\n",
" tf.keras.layers.Dense(hidden_nodes, \n",
" activation=activation(), \n",
" kernel_initializer=weight_init(), \n",
" name=f'Hidden{i+1}')\n",
" )\n",
" \n",
" l.append(tf.keras.layers.Dense(classes, \n",
" activation='softmax', \n",
" kernel_initializer=weight_init(), \n",
" name='Output'))\n",
" \n",
" model = tf.keras.models.Sequential(l)\n",
" return model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Single Model"
]
},
{
"cell_type": "code",
"execution_count": 8,
2021-05-06 16:19:44 +01:00
"metadata": {},
"outputs": [
{
2021-05-10 00:18:57 +01:00
"name": "stdout",
"output_type": "stream",
2021-05-06 16:19:44 +01:00
"text": [
"Model: \"sequential\"\n",
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"Hidden1 (Dense) (None, 64) 704 \n",
"_________________________________________________________________\n",
"Output (Dense) (None, 6) 390 \n",
"=================================================================\n",
"Total params: 1,094\n",
"Trainable params: 1,094\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
2021-05-06 16:19:44 +01:00
]
}
],
"source": [
2021-05-10 00:18:57 +01:00
"model = get_model(hidden_nodes=64, layers=1)\n",
2021-05-06 16:19:44 +01:00
"\n",
"model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.01), \n",
"# optimizer=tf.keras.optimizers.SGD(learning_rate=0.01, momentum=0.9),\n",
" loss='categorical_crossentropy', \n",
" metrics=['accuracy'])\n",
"model.summary()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train"
]
},
{
"cell_type": "code",
"execution_count": 9,
2021-05-06 16:19:44 +01:00
"metadata": {},
"outputs": [],
"source": [
"if BALANCED_WEIGHTS:\n",
" cw = class_weights\n",
"else:\n",
" cw = None\n",
"history = model.fit(data_train.to_numpy(), labels_train, \n",
" callbacks=[tensorboard_callback()], \n",
" validation_split=0.11,\n",
" verbose=0,\n",
" class_weight=cw,\n",
" epochs=50)"
]
},
{
"cell_type": "code",
"execution_count": 10,
2021-05-06 16:19:44 +01:00
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAoQAAAG/CAYAAADB4sa8AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAABJ0AAASdAHeZh94AAB6/klEQVR4nO3dd3iUVdrH8e9JQkIJVbqC9KKAqAiCotgFseuKooJiWyu2taxdV9e2dnd1XwuKKIKL2CuCIs2G0hWp0qWFQGjJef848/BMJpNkJjOTTDK/z3XNlWeePnOSzD33acZai4iIiIikrrSKvgERERERqVgKCEVERERSnAJCERERkRSngFBEREQkxSkgFBEREUlxCghFREREUpwCQhEREZEUp4BQREREJMUpIBQRERFJcQoIRURERFKcAkIRERGRFKeAUERERCTFJXVAaIzJNsbca4z5xBizwRhjjTFDozi+njHmRWPMOmPMVmPMV8aYgxJ4yyIiIiKVTlIHhEBD4C6gM/BzNAcaY9KAD4HzgGeBvwGNgYnGmPZxvk8RERGRSiujom+gFKuAZtba1caYHsB3URx7FtAHONtaOxbAGPM28CtwLy5QFBEREUl5SZ0htNbusNauLuPhZwFrgP8FnW8d8DZwqjEmKw63KCIiIlLpJXuGMBYHAj9aawtC1s8ALgM6ALPCHWiMaQw0ClmdHThmNrAzvrcqIiIiEleZQAtgkrV2c2k7V+WAsBnwdZj1qwI/m1NMQAhcCdydiJsSERERKUenAu+VtlNVDghrADvCrN8etL04zwNjQtZ1AsaOGjWKNm3axOH2isrLy2PWrFl07dqVGjVKuj2pCCqf5KWySW4qn+SlsklusZTPokWLOO+88wCWR7J/VQ4I84Bw7QSrB20Py1q7FlgbvM4YA0C3bt3Yf//943SLheXk5LB582YOOugg6tSpk5BrSNmpfJKXyia5qXySl8omucVSPtnZ2d5iRM3ckrpTSYxW4aqNQ3nrVpbjvYiIiIgkraocEM4EDgqMRxisF7ANN/yMiIiISMqrEgGhMaaZMaaTMaZa0OqxQBPgjKD9GgJnA+9ba8O1LxQRERFJOUnfhtAYczVQD9crGOBkY8w+geVnAl2pHwKGAK2BJYFtY4FpwCvGmP2AP3G9h9NRD2IRERGRPZI+IARuAvYNen4GftZvJBB2bB1rbb4xZgDwKHAtrlfxd8BQa+2CxN2uiIiISOWS9AGhtbZVBPsMBYaGWb8RuCTwEBERSUoFBQWsWbOGHTt2UFAQOp9C4uzatYsGDRqwatUq/vzzz3K7rkQmtHzS0tLIysqiSZMmpKXFt9Vf0geEIiIiVVlBQQHLli0jLy+P9PR00tPT9wx1lmgZGRk0atSIjAyFA8kouHystezcuZO8vDx27NhBy5Yt4xoU6jdARESkAq1Zs4a8vDwaNGhA48aNyy0YBMjPz2fLli3Url2b9PT0cruuRCa0fKy1rF27lg0bNrBmzRqaNQs3ul7ZVIlexiIiIpXVjh07SE9PL/dgUCofYwyNGzcmPT2dHTviO1iKAkIREZEKVFBQUK7VxFK5GWNIT0+Pe1tTBYQiIiIVTMGgRCMRvy8KCEVERERSnAJCERERkRSngFBEREQqraFDh9KqVauKvo1KTwGhiIiIxJ0xJqLHxIkTK/pWBY1DKCIiIgnw+uuvF3r+2muv8fnnnxdZ37lz55iu89///rdcZ3epqhQQioiISNydf/75hZ5PmzaNzz//vMj6UNu2baNmzZoRX6datWpluj8pTFXGIiIiUiH69etHly5d+OGHHzjiiCOoWbMmt99+OwDjx4/npJNOonnz5mRlZdG2bVvuv/9+8vPzC50jtA3hkiVLMMbw2GOP8eKLL9K2bVuysrI45JBD+O6778rz5VUqyhCKiIhIhVm/fj39+/dn0KBBnH/++TRp0gSAV199lezsbG644Qays7OZMGECd911Fzk5OT
z66KOlnnfUqFFs2bKFyy+/HGMMjzzyCGeccQaLFi1SVjEMBYQiIiJJaPhwmDkzsdewNo38/GzS09Mobazj7t3hySfjfw+rV6/mP//5D5dffnmh9aNGjaJGjRp7nl9xxRVcccUVPP/88zzwwANkZWWVeN5ly5bx22+/Ub9+fQA6duzIqaeeyqeffsrAgQPj/0IqOQWEIiIiSWjmTJg0KdFXMVR0KJCVlcVFF11UZH1wMLhlyxZ27NhB3759eeGFF5g/fz4HHHBAiec955xz9gSDAH379gVg0aJFcbrzqkUBoYiISBLq3j3x17DWkp+fH9Fcyom6n7333pvMzMwi6+fMmcMdd9zBhAkTyMnJKbRt8+bNpZ63ZcuWhZ57weHGjRtjuNuqSwGhiIhIEkpE9Wyo/PwCtmzJpXbt2qSnpyf+gmEEZwI9mzZt4sgjj6ROnTrcd999tG3blurVq/Pjjz9yyy23RDTMTHGvx1ob8z1XRQoIRUREJKlMnDiR9evX87///Y8jjjhiz/rFixdX4F1VbRp2RkRERJKKl90Lzubt3LmT559/vqJuqcpThlBERESSSp8+fahfvz5Dhgzh2muvxRjD66+/rureBFKGUERERJLKXnvtxQcffECzZs244447eOyxxzjuuON45JFHKvrWqixlCEVERCThnn32WZ599tlC6yZOnFjs/n369GHq1KlF1odmCV999dVCz1u1alVsJlEZxuIpQygiIiKS4hQQioiIiKQ4BYQiIiIiKU4BoYiIiEiKU0AoIiIikuIUEIqIiIikOAWEIiIiIilOAaGIiIhIilNAKCIiIpLiFBCKiIiIpDgFhCIiIiIpTgGhiIiIJL0lS5ZgjCk0d/E999yDMSai440x3HPPPXG9p379+tGvX7+4nrOiKCAUERGRuDvllFOoWbMmW7ZsKXafwYMHk5mZyfr168vxzqIzd+5c7rnnHpYsWVLRt5JQCghFREQk7gYPHkxeXh7jxo0Lu33btm2MHz+eE088kb322qtM17jjjjvIy8uL5TZLNXfuXO69996wAeFnn33GZ599ltDrlxcFhCIiIhJ3p5xyCrVr12bUqFFht48fP56tW7cyePDgMl8jIyOD6tWrl/n4WGVmZpKZmVlh148nBYQiIiISdzVq1OCMM87gyy+/ZO3atUW2jxo1itq1a3P44Ydz00030bVrV7Kzs6lTpw79+/fn559/LvUa4doQ7tixg+uvv55GjRpRu3ZtTjnlFP74448ixy5dupQrr7ySjh07UqNGDfbaay/OPvvsQpnAV199lbPPPhuAo446CmMMxhgmTpwIhG9DuHbtWoYNG0aTJk2oXr06BxxwACNGjCi0j9ce8rHHHuPFF1+kbdu2ZGVlccghh/Ddd9+V+roTIaNCrioiIiJV3uDBgxkxYgRvv/02V1999Z71GzZs4NNPP+Xcc89l1apVvPvuu5x99tm0bt2aNWvW8MILL3DkkUcyd+5cmjdvHtU1L7nkEkaOHMl5551Hnz59mDBhAieddFKR/b777jumTJnCoEGD2GeffViyZAn//ve/6devH3PnzqVmzZocccQRXHvttTz99NPcfvvtdO7cGWDPz1B5eXn069ePhQsXcvXVV9O6dWvGjBnD0KFD2bRpE9ddd12h/UeNGsWWLVu4/PLLMcbwyCOPcMYZZ7Bo0SKqVasW1euOlQJCERERSYijjz6aZs2aMWrUqEIB4ZgxY9i1axeDBw+ma9eu/Prrr6Sl+ZWWF1xwAZ06deKll17izjvvjPh6P//8MyNHjuTKK6/kueeeA+Cqq65i8ODB/PLLL4X2PemkkzjrrLMKrTv55JPp3bs377zzDhdccAFt2rShb9++PP300xx33HGl9ih+8cUXmTdvHiNHjtxTFX7FFVdw5JFHcscdd3DxxRdTu3btPfsvW7aM3377jfr16wPQsWNHTj31VD799FMGDhwY8euOBwWEIiIiyWj4cJg5M6GXSLOW7Px80t
LTobThW7p3hyefjOr86enpDBo0iCeeeIIlS5bQqlUrwGXGmjRpwjHHHEN6evqe/fPz89m0aRPZ2dl07NiRH3/8Mar
"text/plain": [
"<Figure size 720x480 with 1 Axes>"
]
2021-05-06 16:19:44 +01:00
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
2021-05-06 16:19:44 +01:00
}
],
"source": [
"history.history\n",
"plt.plot(range(len(history.history[\"accuracy\"])), history.history[\"accuracy\"], label=\"Train\", c=(0, 0, 1))\n",
"plt.plot(range(len(history.history[\"val_accuracy\"])), history.history[\"val_accuracy\"], label=\"Validation\", c=(1, 0, 0))\n",
"\n",
"plt.xlabel(\"Epochs\")\n",
"plt.ylabel(\"Accuracy\")\n",
"plt.ylim(0, 1)\n",
"\n",
"plt.grid()\n",
"plt.legend()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Test\n",
"\n",
"Test-set loss and accuracy, as returned by `model.evaluate`."
]
},
{
"cell_type": "code",
"execution_count": 11,
2021-05-06 16:19:44 +01:00
"metadata": {},
"outputs": [
{
2021-05-10 00:18:57 +01:00
"name": "stdout",
"output_type": "stream",
2021-05-06 16:19:44 +01:00
"text": [
"10/10 [==============================] - 0s 971us/step - loss: 0.7446 - accuracy: 0.7384\n"
2021-05-06 16:19:44 +01:00
]
},
{
"data": {
"text/plain": [
"[0.7446056008338928, 0.7384105920791626]"
2021-05-06 16:19:44 +01:00
]
},
"execution_count": 11,
2021-05-06 16:19:44 +01:00
"metadata": {},
"output_type": "execute_result"
2021-05-06 16:19:44 +01:00
}
],
"source": [
"model.evaluate(data_test.to_numpy(), labels_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
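"Run the model on the test set and plot a row-normalised confusion matrix: rows are the true playlists, columns are the predicted playlists, and each row is scaled to sum to 1."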
]
},
{
"cell_type": "code",
"execution_count": 12,
2021-05-06 16:19:44 +01:00
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAnkAAAGgCAYAAADW0HHbAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAABJ0AAASdAHeZh94AACdjElEQVR4nOzdd1xV5R/A8c/DuICCihMUU0vR1Nzb3FqZmrsst6UNy0pNzdKsLKtf5azMXLmy3Jqj4d6Ke6WCW0FAXCib5/fHgYt470WEq8D1+3697ot8zvme+5zTHd/7rKO01gghhBBCCMfilNUVEEIIIYQQ9idJnhBCCCGEA5IkTwghhBDCAUmSJ4QQQgjhgCTJE0IIIYRwQJLkCSGEEEI4IEnyhBBCCCEckCR5QgghhBAOSJI8IYQQQggHJEmeEEIIIYQDkiRPCCGEEMIBuWR1BR4lcVsryI2C08m94YWsrkKOUMDjyayuQo4QdmtPVlchx3ByypXVVcgR3F0LZnUVcoxb0UHqYT5fAnPt8l3rTNeHWu8HQZI8IYQQQjiMxMQEuxzH2QH6Oh3gFIQQQgghxN2kJU8IIYQQDkPr+KyuQrYhSZ4QQgghHIbW9umudQTSXSuEEEII4YCkJU8IIYQQDiNRumvNJMkTQgghhMOQMXkpJMkTQgghhMOQJC+FjMkTQgghhHBA0pInhBBCCIehE6UlL5kkeUIIIYRwHNJdaybdtUIIIYQQDkha8oQQQgjhMGTiRQpJ8oQQQgjhOBLjsroG2YZ01wohhBBCOCBpyRNCCCGEw5Du2hSS5AkhhBDCccgSKmaS5AkhhBDCcUiSZyZj8oQQQgghHJC05AkhhBDCcciYPDNJ8oQQQgjhMJR015pJd20OFxun+X5BAk3ej6f66/G8/Hk8244kpit2+5FEen+TwNMD4qnbP54un8ezfJtl7Pz1iQz8MYHmg+Op2Ceej6Yl2Ps07MpkMjFmzKecP/8fkZEhbNu2lubNm6QrtmhRX+bPn8mVK2e5evU8S5bMo1Spkqn28fMrxogRQ9m+fR3h4We5fPkUa9f+SbNmjS2O17RpI6ZOncSxY3u4eTOYkycPMGXKRHx8itjhTDPPZHJlxGf9OXhyBWfDNrB6/TQaNamVrlgf30L8Mms0Jy/8Q9Cltfw6/xtKlCxqsZ9XntyM/PxtduxfwNmwDew5uoSxPwynmJ/lNWjXqTn/bvmVc+EbOXpmNWN/GE7+AnkzfZ7Ziclk4quvvuTixbPcvn2DHTu20rx5s6yull3Iey/9TCYTn48eQuCpbYRfPcKGTYto2qx+umJ9ixZh1pwJXAzZR3Dofn5fMJmSpYpb7HcrOsjqY9Dg160et2OnVqzbsIDQK4e4GLKPtesX0Khx3Uydp8haSmud1XV4ZMRtrWD3i/3B5AT+2aPp1kJRorBi6dZEjpyB6R84U81f2Yxbvy+RAZMSqfwEPF/bCQX8tTuRgBMwpIsTPZ5Jyf+f+SCeW9HwVCnFjmOaVnUUX7zqbO9TScW94YUMx86dO42OHdsyfvxPBAYG0aPHK9SsWY1mzVqzdesOm3G5c+cmIGATefPm4fvvJxEXF8d7772FUopq1Z4mIuIqAG+91Zevv/6MZctWsnXrDlxcXOje/WWqV6/Cq6++xcyZc83H3LlzA/nze7Nw4VICA4MoVaok/fv35fbtKKpVe5rLl0MzfJ4ABTyezFT85Bmf0aZdU6b8MJ9TQefp0rUVVaqXp8Pz/dm5/YDNuNy5Pfh3y6/kyevJTxPmERcXz+tvd0EpRdN63bkacQMApRSr10+lbLlSzPhlMUGB5yj1uB+9+3bk5s1b1K/ehVuRtwHo9VoHvhk3hE3rd7Ny+QZ8ixai71svcebUBZ5r/CoxMbEZPs+wW3syHGtv8+bNplOnjowbN4GTJwPp1asHNWvWoEmTFmzdujWrq4eTU64Mxz5K7z1314KZip85axzt2j/HDxNnEhh0hm
7dOlK9xlO0fLYr27fZfr3mzp2LrTuWkzePFxPGTyMuLo63B/RBKUXdWq2JiLhm3vdWdBBr/93MvLlLUh3jwP6jHDt2MlXZ8I8H8OHwd1iyeDUbNmzH1cWF8hX82bF9D7/NW5qpc70VHWT7y+gBuHm5n12+a72KTHmo9X4QMp3kKaXeAn4Admmta9vYRwM/aK3fTuM4G4CCWuuK9/n8jYH1dxQlAleATcAIrfUxG3HPAyuBYMBPa23RhKWUOgOUuKMoDDgOfK+1XnL3/vdi7yTv0CnNy6MTGPSiE72fM5KymDhNuxEJ5PeCuR/Z7o3v+10CQRc1a752xuRqvI7jEzRtPkrAwwSLP0uJvRSu8S1gfGHXfDOeZ2pk3ySvZs1q7Nixng8++Jjvv58IgJubGwcP7iA0NIwGDZ6xGTt48Lt8/fVn1K7dhICAvQCULVuGgwd38L//jefjjz8DoHz5cly+HMqVKxHmWJPJxN69W/D0zE3JkhXM5Q0a1GPLlu3c+T5r0KAeGzas5osv/sfIkaMzdJ7JMpPkVa1enr82TmfU8An8OGEeAG5uJjbtmkt42FVaNe9nM/bt97oxcvTbPNOwN/v3Gm+x0v4l2LRrLpPGzuHLTycDULP2U6xc+wvDBv6P6VMWmeO7dGvFhMkj6PXyUFat2IirqwtHTq3i6JFA2j33lnm/Fs/VZ+7C7/hw8HdMm7wgw+eaXZK8mjVrsmvXNgYPHsJ3340FjNfn4cP7CQ0No379hllcw4wneY/aey8zSV71GpXYtGUJw4eNYfy4qYDx3tu9dw1hoVdo1qSzzdj3B/Zj9JdDaVC/HXv3HALA3/9xdu9dzdjvpzBq5HfmfW9FBzH5p1kMev/TNOtTs1YV1m1YwIdDv2TSxBkZPi9bHnqSF9zHPkme7/Qcn+TZo7u2K3AGqKWUKm2H42XUBKA78BowF2gFbFZK+djYP7nevkDTNI67P+m43YFvgaLAYqXUG3apdSb8HZCIsxN0bpTyOnRzVXRo4MSBIAiOsP06vxWlyZMbc4IH4OKs8PYEd1PqfYsWVCiVM17rHTu2Iz4+nl9+mWkui4mJYfr02dSrVxs/v2JpxLZl16495i8ZgOPHT7Ju3UY6d25vLjt69L9UXzIAsbGxrF79N8WL++Hp6Wku37x5G3f/kNq8eRtXrkTw5JNlM3qadtGmXVPi4+OZNWOpuSwmJpa5s1ZQs04lihYrbDO2dbsm7A04Yk7wAAJPnGXzhgDadkjpevT0yg1AaGjq63U55AoAUVExAJQr/wT5vPOwbNHaVPv9s2YrkTdv0b5ji4ydZDbTqVMH4uPjmTJlqrksJiaGadNmUK9eXfz8/LKwdpkj7730a9++JfHx8UyfNt9cFhMTy6yZf1CnbjWK+fnajG3X/jkCdh8wJ3gAJ06cYsP6bXTo2MpqjLu7G25uJqvbAPq/3ZvLIWH8MGkmYLQW5mRKx9vl4QgyleQppUoB9YCBGK1cXe1RqQzarLWeo7WeobV+H3gfKAD0uHtHpVRuoC3wPbCPtOt9Mem4c7TW3wD1gVtJx89Sx85BiSLg6ZE6AXuqlPH3+DnbSV7NsorAizBxcQLnLmvOhWomLze6enu3zLlDNatWrcSJE4HcvHkzVfnu3UZLTpUqT1mNU0pRqVIF9uzZZ7Ft1649lC79eKovEGt8fIpw69Ytbt++neZ+uXPnxtMzN+HhV9Lc70F7qrI/QYHnibyZur779hwFoGIlf6txSinKVyzNgX3/WWzbu+copZ4oTm5P40viwL5j3Iq8zbARr/N0o+r4+Bai7tNV+WT02+wNOMKm9bsBcHNzBSAqKtrimFHRMVSs7J9jfmikpWrVKpw4ccLi9blrl3EdqlSpnBXVsgt576Vf5SrlOXnyNDdvRqYqDwg4CEClStZb6JVSVHyqHHv3HrLYFhBwkCeeKIGnZ+5U5d26dyQs4jAR148RsG8NL77UxiK2cZO67NlzkLf69+Tshd2EXjlE0O
ntvP5G94yeosgmMju7titwFaPbc2HSv9NuF354Nif9fcLKtvaAB7Ag6e9HSqk3tdaW3zB30VqHKKWOAVn+aRx+XVM
"text/plain": [
"<Figure size 720x480 with 2 Axes>"
]
2021-05-06 16:19:44 +01:00
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
2021-05-06 16:19:44 +01:00
}
],
"source": [
"predictions = model(data_test.to_numpy())\n",
"\n",
"conf = tf.math.confusion_matrix([tf.math.argmax(i) for i in labels_test], \n",
" [tf.math.argmax(i) for i in predictions], \n",
" num_classes=len(filtered_playlists))\n",
"\n",
"normalised_conf = np.ndarray((len(filtered_playlists), len(filtered_playlists)))\n",
"for idx, row in enumerate(conf):\n",
" normalised_conf[idx, :] = row / np.sum(row)\n",
"\n",
"sns.heatmap(normalised_conf, \n",
" annot=True, \n",
" xticklabels=playlist_names, yticklabels=playlist_names, \n",
" cmap='inferno')\n",
"plt.show()"
]
},
2021-05-10 00:18:57 +01:00
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Ensemble Model"
]
},
{
"cell_type": "code",
"execution_count": 13,
2021-05-10 00:18:57 +01:00
"metadata": {},
"outputs": [],
"source": [
"models = [get_model(hidden_nodes=random.randint(16, 128), \n",
" layers=random.randint(1, 2)) \n",
" for _ in range(9)]\n",
"\n",
"for m in models:\n",
" m.compile(\n",
" optimizer=tf.keras.optimizers.Adam(learning_rate=0.01), \n",
"# optimizer=tf.keras.optimizers.SGD(learning_rate=0.01, momentum=0.9),\n",
" loss='categorical_crossentropy', \n",
" metrics=['accuracy'])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Train each model of the ensemble. Each model trains for a random number of epochs in [20, 100] to introduce variation between models. *Experimental*: each model is also randomly trained either with balanced class weights (3 times out of 4) or without them, regardless of `BALANCED_WEIGHTS`. Class weights scale each sample's contribution to the loss so that smaller classes count for more. It is not yet clear whether this helps."
]
},
{
"cell_type": "code",
"execution_count": 14,
2021-05-10 00:18:57 +01:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
2021-05-10 00:18:57 +01:00
"text": [
"training model 1\n",
"training model 2\n",
"training model 3\n",
"training model 4\n",
"training model 5\n",
"training model 6\n",
"training model 7\n",
"training model 8\n",
"training model 9\n"
]
}
],
"source": [
"if BALANCED_WEIGHTS:\n",
" cw = class_weights\n",
"else:\n",
" cw = None\n",
"\n",
"ensem_histories = list()\n",
"for idx, m in enumerate(models):\n",
" print(f'training model {idx+1}')\n",
" h = m.fit(data_train.to_numpy(), labels_train, \n",
" callbacks=[tensorboard_callback()], \n",
" validation_split=0.11,\n",
" verbose=0,\n",
"# class_weight=cw,\n",
" class_weight=random.choice([*([class_weights]*3), None]),\n",
" epochs=random.randint(20, 100))\n",
" ensem_histories.append(h)"
]
},
{
"cell_type": "code",
"execution_count": 15,
2021-05-10 00:18:57 +01:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
2021-05-10 00:18:57 +01:00
"text": [
"77.2% Accuracy, 77.3% Agreement, 62.2% Ind. Accuracy\n"
2021-05-10 00:18:57 +01:00
]
}
],
"source": [
"ensem_results = ensem_classify(models, data_test, labels_test)\n",
"print(f\"{ensem_results[2]*100:.3}% Accuracy, {ensem_results[3]*100:.3}% Agreement, {ensem_results[4]*100:.3}% Ind. Accuracy\")"
]
},
{
"cell_type": "code",
"execution_count": 16,
2021-05-10 00:18:57 +01:00
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAnkAAAGnCAYAAADL1UFjAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAABJ0AAASdAHeZh94AACjCElEQVR4nOzdd3gU1RrA4d9J2SSQAKEmITQVUEB6twCCgoL0Jh1BVK4iXcUGiCJIEytNEBFQOkhRqvTee5OekNAJpOfcP2azyWZ3QxIWEtbvvc8+uZyZb/bMuOXb00ZprRFCCCGEEK7FLbMrIIQQQgghnE+SPCGEEEIIFyRJnhBCCCGEC5IkTwghhBDCBUmSJ4QQQgjhgiTJE0IIIYRwQZLkCSGEEEK4IEnyhBBCCCFckCR5QgghhBAuSJI8IYQQQggXJEmeEEIIIYQLkiRPCCGEEMIFSZInhBBCCOEkSikvpdQIpdQlpVSkUmqbUurFNMa2VUrtVkpFKaXClVJTlFJ5M1oXSfKEEEIIIZxnGtAX+A14D4gHlimlnk0tSCn1NjALuGaOnwS0BVYrpbwzUhGltc5InBBCCCGESEYpVRXYBgzQWo8yl3kDB4EwrXVNB3Em4DKwH6itzcmZUqoRsATopbX+Nr31kZY8IYQQQgjnaInRcjcxsUBrHQVMAWoopQo5iCsD5AJ+18la37TWfwIRGC166SZJnhBCCCGEc1QAjmutb6Uo327+W95BnJf5b6SdbZFABaVUunM2j/QGCCGEEEK4OqVUfiCfnU3hWuswB2GBQIid8sSyIAdxJwANPANMTVaHksnq4A9cvUe1rUiS9xDF85sMgEwjT7fumV2FR4LJI09mV+GREB17ObOr8MhQSr4W0sLdzTezq/DIiI0LVw/z+Zz4XTsE+MxB+WAHMT5AtJ3yqGTbbWitryil/gA6K6WOAAuAgsC3QCzg6Sg2NfJuFkIIIYTLSEiId9ahfgDm2CkPTyUmkqSu1+S8k2135E2MRG6U+QEwAzgFNMcYm5cukuQJIYQQQqRg7pJ11C3rSAhGC1xKgea/l1J5vptAE6VUYaAocFZrfVYptRmji/hGOusiSZ4QQgghXIfWcZn59HuBOkqpHCkmX1RLtj1VWutzwDkApVQuoBIwLyOVkdm1QgghhHAZWsc75ZFBcwF3oEdigVLKC+gKbNNanzeXFVZKPZmG4w3HaJAbm5HKSEueEEIIIYQTaK23KaXmAMPNs3NPAp0xul+7Jdt1OlALsExKUUp9gLFe3jYgDmgKvAR8rLXekZH6SJInhBBCCJeRkLndtQCdgM+BjhjLnuwHGmmt198j7gDQDGiM0Rq4H2ittbY3+SNNJMkTQgghhMvI5DF5iXe4GGB+ONqntp2ypcBSZ9ZFkjwhhBBCuIzMTvKyEpl4IYQQQgjhgqQlTwghhBAuQydIS14iSfKEEEII4Tqku9ZCumuFEEIIIVyQtOQJIYQQwmXIxIskkuQJIYQQwnUkxGZ2DbIM6a4VQgghhHBB0pInhBBCCJch3bVJJMkTQgghhOuQJVQsJMkTQgghhOuQJM9CxuQJIYQQQrggackTQgghhOuQMXkWkuQJIYQQwmUo6a61kO5aFxQTE8for1dR69kxVCj7JW1aTWbzplNpil229CAtmk2k/NNf8Ez1UXw8aDHXr919wDV2LpPJxFdfDePChdPcuXONLVvWU6/eC2mKDQoKYvbsGVy7FsKNG5dZsOAPihUrarPfW2+9we+//8aZM8dJSIjk558n2j1eQEAAw4d/zurVK7h5M4yEhEhq1Xrufk7PqUwmE8O+HMTpMzu5fvMk6zcuoW7dtNUvKCiAGTN/JDTsEGFXjjBn3hSKFSucakzNmlWIirlAVMwF8uTxT3XfpctmEhVzgbHjhqX5fB4FxuvzSy5ePMvdu7fYunUT9erVzexqOYW899LOZDLx5fBPOHvuALdun2PT5hXUrVcrTbFBQQHMnD
WZ8CsnuXrtNPPmT6dYsSJW+3h7ezNx4jj27F3PlaunuH7jDLt2reXdd3vg4WHdvvPsczWYv+BXTv+7l9sR5zl/4RB/Lv2dmjWrOu18ReaQJM8FDfpgEb9M20qjV5/mw4/q4+7uxls9ZrFr57lU42bP3En/vvPJmdOH9z94iZatK7B82SFe7/Ir0dGPzi+jqVMn0adPL2bOnE3v3v2Jj49n6dKFPPNMzVTjsmfPzpo1K6hV61mGD/+awYOHUaFCedatW0nu3Lmt9h04sB8vvFCLQ4cOExvreOHNkiWL8/77/SlYMIgDBw455fycafKUMbz33hvMnrWQ/n0/Iz4+noWLp1OzZpVU47Jnz8ZfK//gueeqM3LEd3w+dDTly5Vh5aq55M6dy26MUoox4z4nIuLOPevVpOnLVKteKSOnlOVNmzaFvn1789tvs3jvvb7Ex8ezbNkSnnnmmcyu2n2T917aTfn5W3r3fptZs+bSt89HxMfHs2TJLJ55plqqcdmzZ2flqoU8/3wNvvpqHEOHjKB8+adZvWYRuXMn/XDy8fGmVOmSrFixio8/Gsb7Az9j//5DjBr9OT9P/c7qmCWKP05CQgITJ/5Cr3c/YMyY7wkokJ81axfzUv20JelZSkKccx4uQGmt7+8ASvUEvge2a63tvjqVUhr4Xmv9TirHWQfk1VqXSefz1wbWJitKAK4C64FPtNZHHMS9AiwFQoBgrXWCnX3OAMl/HoUDx4AxWusF6aknQDy/3d/FToP9+y/SttUU+g+sx+vdjA/W6Og4Gjf6kTx5sjNz9ut242Ji4nn+mdGUKJmfX37tjFIKgHVrj9PzrdkM+rgBHTo+vF91nm7dMxRXpUpltm3bwIABHzJ69DgAvLy8OHBgF2Fh4Tz7bB2HsQMG9GXEiC+oWvVZdu7cBUDJkiU4cGAXX389ho8++syyb+HChTl3zkiab90KZ+7cBbz+eg+bY/r6+uLp6cn169dp0aIZc+bMpE6dl/jnnw0ZOr+UTB55MhxbuXJ5Nm7+kw/e/5xxYycAxrXavWc1YeFXqFOrqcPYvv3e5svhH/FMjYbs2rUPgBIlH2f3ntWMGf0jn34ywiam+xsdGDxkILNmzufdXt0pGPg0V69et9nPy8uLffvX8ssvv/PZ4AH8+MM0+vT+OMPnCRAde/m+4p2lSpUqbN++mf79BzJ69FjAON+DB/cSFhbOM888n8k1BKUyNornv/bec3fzzXBslSoV2LzlbwYO/IyxY34AjGu1d98GwsPDef65hg5j+/V/h6+++owa1V9k5869AJQs+QR7921g1Kjv+OTjL1J97nHjhvO/d7oTXLA0ly+HOdzPx8eH4yd2sm/fQRo1bJP+k0wmNi5c3dcB0ul2yOtO+a71C/z5odb7QXBGS1574AxQVSn1hBOOl1HjgY5Ad+A3oCGwQSkV4GD/xHoHAqn9VNlrPm5HYBQQBMxXSr3llFo72d8rDuPurmjdJqkVxMvLgxYtK7B3zwVCQm7ajTt5Ioxbt6J4+eXSlgQPoHadEmTLZmL50oMPvO7O0LJlM+Li4pg4cYqlLDo6mp9/nkbNmtUJDg52GNuiRTO2b99p+ZIBOHbsOKtXr6VVqxZW+yZ+ydxLREQE16/bJjJZQbMWDYmLi2PK5N8sZdHR0UybNosaNSoTHBzoOLZ5Q3bs2GtJ8ACOHzvF2jUbadHiVZv9/f1zMXjIQIYOGcXNm7dSrVe//m/j5ubG2DETMnBWWVvLls3Nr8/JlrLo6GimTJlKzZo1Un19ZnXy3ku75i1eJS4ujsmTplvKoqOjmTr1N2rUqEpwcJDD2BYtXmXHjt2WBA/g2LGTrFmzgZYtm9zzuc+cNa5frlw5Ut0vMjKSK+FX7rlfVqR0nFMeruC+kjylVDGgJtAXo5WrvTMqlUEbtNYztNZTtdZ9gD5AHqBTyh2VUtmBJsAYYA+p1/ui+bgztNYjgWeAO+bjZzlHjoRSpGgefH29rMqfLmt8aB
w9Emo3LibGeEF7eXvabPP29uDIkVASEh54Q+R9K1++HMePn+D27dtW5du37zRvL2s3TilF2bJl2LVrl822HTt28sQ
"text/plain": [
"<Figure size 720x480 with 2 Axes>"
]
2021-05-10 00:18:57 +01:00
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
2021-05-10 00:18:57 +01:00
}
],
"source": [
"ensem_conf = tf.math.confusion_matrix([tf.math.argmax(i) for i in labels_test], \n",
" ensem_results[0], \n",
" num_classes=len(filtered_playlists))\n",
"\n",
"normalised_ensem_conf = np.ndarray((len(filtered_playlists), len(filtered_playlists)))\n",
"for idx, row in enumerate(ensem_conf):\n",
" normalised_ensem_conf[idx, :] = row / np.sum(row)\n",
"\n",
"sns.heatmap(normalised_ensem_conf, \n",
" annot=True, \n",
" xticklabels=playlist_names, yticklabels=playlist_names, \n",
" cmap='inferno')\n",
"plt.show()"
]
},
2021-05-06 16:19:44 +01:00
{
"cell_type": "markdown",
"metadata": {},
"source": [
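"# Imports & Setup\n",
"\n",
"These cells sit at the end of the notebook but must be run first: the cells above rely on the imports, `spotnet`, and the `scrobbles` frame defined here."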
]
},
{
"cell_type": "code",
2021-05-07 01:03:08 +01:00
"execution_count": 1,
2021-05-06 16:19:44 +01:00
"metadata": {},
"outputs": [],
"source": [
"from datetime import datetime\n",
"import os\n",
2021-05-10 00:18:57 +01:00
"import random\n",
2021-05-06 16:19:44 +01:00
"\n",
"from google.cloud import bigquery\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib as mpl\n",
"mpl.rcParams['figure.dpi'] = 120\n",
"import seaborn as sns\n",
"\n",
2021-05-10 00:18:57 +01:00
"from analysis.nn import ensem_classify\n",
2021-05-06 16:19:44 +01:00
"from analysis.net import get_spotnet, get_playlist, track_frame\n",
"from analysis.query import *\n",
"from analysis import spotify_descriptor_headers, float_headers, days_since\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"from sklearn import svm\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.metrics import plot_confusion_matrix\n",
"from sklearn.utils import class_weight\n",
"\n",
"import tensorflow as tf\n",
"\n",
"client = bigquery.Client()\n",
"spotnet = get_spotnet()\n",
"cache = 'query.csv'\n",
"first_day = datetime(year=2017, month=11, day=3)\n",
"sig_max, c_max = 0.5, 20"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Read Scrobble Frame"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"scrobbles = get_query(cache=cache)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Write Scrobble Frame"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"scrobbles.reset_index().to_csv(cache, sep='\\t')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
2021-05-06 16:19:44 +01:00
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.9"
2021-05-10 00:18:57 +01:00
},
"metadata": {
"interpreter": {
"hash": "bce1a3677099e73bf385a0de8ef462673e03f7df0abce93e57e7ca76e8c504a2"
}
2021-05-06 16:19:44 +01:00
}
},
"nbformat": 4,
"nbformat_minor": 4
}