listening-analysis/playlist-nn.ipynb

533 lines
285 KiB
Plaintext
Raw Normal View History

2021-05-06 16:19:44 +01:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Playlist Neural Network\n",
"\n",
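"Given a set of playlists, can previously unseen tracks be classified into the correct playlist from their descriptor features?\n",
"\n",
"**Note:** none of the cells before their first use import libraries (`pd`, `np`, `tf`, ...) or define `float_headers`, `get_playlist`, `spotnet`, `track_frame`, `scrobbles` or `ensem_classify`; these must already exist in the kernel before *Run All* will work."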
]
},
{
"cell_type": "code",
2021-05-10 00:18:57 +01:00
"execution_count": 17,
2021-05-06 16:19:44 +01:00
"metadata": {},
"outputs": [],
"source": [
"# Named experiment configurations: select one by key instead of (un)commenting lines.\n",
"PLAYLIST_SETS = {\n",
"    \"super_genres\": [\"RAP\", \"EDM\", \"ROCK\", \"METAL\", \"JAZZ\", \"POP\"],\n",
"    \"super_genres_all_rap\": [\"ALL RAP\", \"EDM\", \"ROCK\", \"METAL\", \"JAZZ\", \"POP\"],\n",
"    \"super_genres_no_pop\": [\"RAP\", \"EDM\", \"ROCK\", \"METAL\", \"JAZZ\"],\n",
"    \"super_genres_all_rap_no_pop\": [\"ALL RAP\", \"EDM\", \"ROCK\", \"METAL\", \"JAZZ\"],\n",
"    \"super_genres_decomposed_edm\": [\"ALL RAP\", \"DNB\", \"4/4\", \"cRock\", \"METAL\", \"cJazz\"],\n",
"    \"edm\": [\"DNB\", \"HOUSE\", \"TECHNO\", \"GARAGE\", \"DUBSTEP\", \"BASS\"],\n",
"    \"rap_decades\": [\"20s rap\", \"10s rap\", \"00s rap\", \"90s rap\", \"80s rap\"],\n",
"    \"uk_us_rap\": [\"UK RAP\", \"US RAP\"],\n",
"    \"british_rap\": [\"uk rap\", \"grime\", \"drill\", \"afro bash\"],\n",
"    \"american_rap\": [\"20s rap\", \"10s rap\", \"00s rap\", \"90s rap\", \"80s rap\", \"trap\", \"gangsta rap\", \"industrial rap\", \"weird rap\", \"jazz rap\", \"boom bap\", \"trap metal\"],\n",
"    \"rock\": [\"rock\", \"indie\", \"punk\", \"pop rock\", \"bluesy rock\", \"hard rock\", \"chilled rock\", \"emo\", \"pop punk\", \"stoner rock/metal\", \"post-hardcore\", \"melodic hardcore\", \"art rock\", \"post-rock\", \"classic pop punk\", \"90s rock & grunge\", \"90s indie & britpop\", \"psych\"],\n",
"    \"metal\": [\"metal\", \"metalcore\", \"mathcore\", \"hardcore\", \"black metal\", \"death metal\", \"doom metal\", \"sludge metal\", \"classic metal\", \"industrial\", \"nu metal\", \"calm metal\", \"thrash metal\"],\n",
"}\n",
"playlist_names = PLAYLIST_SETS[\"super_genres_decomposed_edm\"]\n",
"\n",
"# Descriptor columns used as network inputs. NOTE: float_headers is not defined in this notebook\n",
"# before this cell; it must already exist in the kernel.\n",
"HEADER_SETS = {\n",
"    \"floats_duration_mode_loudness_tempo\": float_headers + [\"duration_ms\", \"mode\", \"loudness\", \"tempo\"],\n",
"    \"floats_mode_loudness_tempo\": float_headers + [\"mode\", \"loudness\", \"tempo\"],\n",
"    \"floats_only\": float_headers,\n",
"}\n",
"headers = HEADER_SETS[\"floats_mode_loudness_tempo\"]\n",
"\n",
"# use sklearn's 'balanced' class weights when fitting the single model\n",
"BALANCED_WEIGHTS = True"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Pull and process playlist information.\n",
"\n",
"1. Get live playlist track information from Spotify\n",
"2. Filter the listening history (scrobbles) down to these tracks\n",
"\n",
"Drop tracks without features (no URI) and duplicate URIs, then keep only the descriptor columns."
]
},
{
"cell_type": "code",
2021-05-10 00:18:57 +01:00
"execution_count": 18,
2021-05-06 16:19:44 +01:00
"metadata": {},
"outputs": [],
"source": [
"def select_features(playlist):\n",
"    \"\"\"Join a playlist's tracks onto the listening history and keep its distinct, featured tracks.\n",
"\n",
"    Rows without a URI are dropped, one row is kept per URI, and only the `headers` columns are returned.\n",
"    \"\"\"\n",
"    joined = pd.merge(track_frame(playlist.tracks), scrobbles, on=['track', 'artist'])  # 2)\n",
"    has_uri = joined[pd.notnull(joined[\"uri\"])]\n",
"    return has_uri.drop_duplicates(['uri']).loc[:, headers]\n",
"\n",
"\n",
"playlists = [get_playlist(name, spotnet) for name in playlist_names]  # 1)\n",
"filtered_playlists = [select_features(plst) for plst in playlists]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Construct the dataset with one integer label per playlist, then split it into stratified train (90%) and test (10%) sets and compute balanced class weights."
]
},
{
"cell_type": "code",
2021-05-10 00:18:57 +01:00
"execution_count": 19,
2021-05-06 16:19:44 +01:00
"metadata": {},
"outputs": [],
"source": [
"dataset = pd.concat(filtered_playlists)\n",
"# integer label = index of the playlist each row came from\n",
"labels = np.concatenate([np.full(len(plst), idx) for idx, plst in enumerate(filtered_playlists)])\n",
"assert len(labels) == len(dataset)\n",
"\n",
"# stratify: maintains class proportions in test and train set\n",
"# (pass random_state=<int> for a reproducible split)\n",
"data_train, data_test, labels_train, labels_test = train_test_split(dataset, labels,\n",
"                                                                    test_size=0.1,\n",
"                                                                    stratify=labels)\n",
"\n",
"# Key each weight by the class it was computed for, not by position: if a playlist were empty,\n",
"# positional keys would silently shift every later weight onto the wrong class.\n",
"train_classes = np.unique(labels_train)\n",
"assert len(train_classes) == len(filtered_playlists), \"every playlist needs at least one training track\"\n",
"class_weights = dict(zip(train_classes.tolist(),\n",
"                         class_weight.compute_class_weight('balanced',\n",
"                                                           classes=train_classes,\n",
"                                                           y=labels_train)))\n",
"\n",
"labels_train = tf.one_hot(labels_train, len(filtered_playlists))\n",
"labels_test = tf.one_hot(labels_test, len(filtered_playlists))"
]
},
{
"cell_type": "code",
2021-05-10 00:18:57 +01:00
"execution_count": 20,
2021-05-06 16:19:44 +01:00
"metadata": {},
"outputs": [],
"source": [
"def tensorboard_callback(path='tensorboard-logs', prefix=''):\n",
"    \"\"\"Return a TensorBoard callback (histogram_freq=1) that logs under `path` in a\n",
"    sub-directory named `prefix` + a second-resolution timestamp.\"\"\"\n",
"    run_dir = prefix + datetime.now().strftime(\"%Y%m%d-%H%M%S\")\n",
"    log_dir = os.path.normpath(os.path.join(path, run_dir))\n",
"    return tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)"
]
},
{
"cell_type": "code",
2021-05-10 00:18:57 +01:00
"execution_count": 21,
2021-05-06 16:19:44 +01:00
"metadata": {},
"outputs": [],
"source": [
2021-05-10 00:18:57 +01:00
"def get_model(hidden_nodes=128,\n",
"              layers=2,\n",
"              classes=None,\n",
"              activation=lambda: 'sigmoid',\n",
"              weight_init=lambda: 'glorot_uniform',\n",
"              input_shape=None):\n",
"    \"\"\"Build an uncompiled fully-connected classifier with a softmax output.\n",
"\n",
"    hidden_nodes: units in each hidden Dense layer.\n",
"    layers: number of hidden Dense layers.\n",
"    classes: output units; defaults to the current number of playlists.\n",
"    activation, weight_init: zero-argument factories called once per layer, so each\n",
"        layer can receive its own activation/initializer object.\n",
"    input_shape: shape of one sample; defaults to the column count of `data_train`.\n",
"    \"\"\"\n",
"    # Resolve defaults at call time rather than def time so they track re-run data cells.\n",
"    if classes is None:\n",
"        classes = len(filtered_playlists)\n",
"    if input_shape is None:\n",
"        input_shape = (data_train.shape[1],)\n",
"\n",
"    model_layers = [tf.keras.layers.InputLayer(input_shape=input_shape, name='Input')]\n",
"    for i in range(layers):\n",
"        model_layers.append(\n",
"            tf.keras.layers.Dense(hidden_nodes,\n",
"                                  activation=activation(),\n",
"                                  kernel_initializer=weight_init(),\n",
"                                  name=f'Hidden{i+1}')\n",
"        )\n",
"    model_layers.append(tf.keras.layers.Dense(classes,\n",
"                                              activation='softmax',\n",
"                                              kernel_initializer=weight_init(),\n",
"                                              name='Output'))\n",
"\n",
"    return tf.keras.models.Sequential(model_layers)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Single Model"
]
},
{
"cell_type": "code",
2021-05-10 00:18:57 +01:00
"execution_count": 22,
2021-05-06 16:19:44 +01:00
"metadata": {},
"outputs": [
{
"output_type": "stream",
2021-05-10 00:18:57 +01:00
"name": "stdout",
2021-05-06 16:19:44 +01:00
"text": [
2021-05-10 00:18:57 +01:00
"Model: \"sequential_10\"\n_________________________________________________________________\nLayer (type) Output Shape Param # \n=================================================================\nHidden1 (Dense) (None, 64) 704 \n_________________________________________________________________\nOutput (Dense) (None, 6) 390 \n=================================================================\nTotal params: 1,094\nTrainable params: 1,094\nNon-trainable params: 0\n_________________________________________________________________\n"
2021-05-06 16:19:44 +01:00
]
}
],
"source": [
2021-05-10 00:18:57 +01:00
"# one hidden layer of 64 units using get_model's default (sigmoid) activation\n",
"model = get_model(hidden_nodes=64, layers=1)\n",
"\n",
"# alternative optimiser tried: tf.keras.optimizers.SGD(learning_rate=0.01, momentum=0.9)\n",
"optimiser = tf.keras.optimizers.Adam(learning_rate=0.01)\n",
"model.compile(optimizer=optimiser,\n",
"              loss='categorical_crossentropy',\n",
"              metrics=['accuracy'])\n",
"model.summary()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train"
]
},
{
"cell_type": "code",
2021-05-10 00:18:57 +01:00
"execution_count": 23,
2021-05-06 16:19:44 +01:00
"metadata": {},
"outputs": [],
"source": [
"# sklearn 'balanced' weights only when the BALANCED_WEIGHTS flag is set\n",
"cw = class_weights if BALANCED_WEIGHTS else None\n",
"history = model.fit(data_train.to_numpy(), labels_train,\n",
"                    callbacks=[tensorboard_callback()],\n",
"                    validation_split=0.11,  # 0.11 * 0.9 ~= 10% of all data after the 10% test split\n",
"                    verbose=0,\n",
"                    class_weight=cw,\n",
"                    epochs=50)"
]
},
{
"cell_type": "code",
2021-05-10 00:18:57 +01:00
"execution_count": 24,
2021-05-06 16:19:44 +01:00
"metadata": {},
"outputs": [
{
2021-05-10 00:18:57 +01:00
"output_type": "display_data",
2021-05-06 16:19:44 +01:00
"data": {
2021-05-10 00:18:57 +01:00
"text/plain": "<Figure size 720x480 with 1 Axes>",
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<svg height=\"265.995469pt\" version=\"1.1\" viewBox=\"0 0 385.78125 265.995469\" width=\"385.78125pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n <metadata>\n <rdf:RDF xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n <cc:Work>\n <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n <dc:date>2021-05-09T23:59:37.477771</dc:date>\n <dc:format>image/svg+xml</dc:format>\n <dc:creator>\n <cc:Agent>\n <dc:title>Matplotlib v3.4.1, https://matplotlib.org/</dc:title>\n </cc:Agent>\n </dc:creator>\n </cc:Work>\n </rdf:RDF>\n </metadata>\n <defs>\n <style type=\"text/css\">*{stroke-linecap:butt;stroke-linejoin:round;}</style>\n </defs>\n <g id=\"figure_1\">\n <g id=\"patch_1\">\n <path d=\"M 0 265.995469 \nL 385.78125 265.995469 \nL 385.78125 0 \nL 0 0 \nz\n\" style=\"fill:none;\"/>\n </g>\n <g id=\"axes_1\">\n <g id=\"patch_2\">\n <path d=\"M 43.78125 228.439219 \nL 378.58125 228.439219 \nL 378.58125 10.999219 \nL 43.78125 10.999219 \nz\n\" style=\"fill:#ffffff;\"/>\n </g>\n <g id=\"matplotlib.axis_1\">\n <g id=\"xtick_1\">\n <g id=\"line2d_1\">\n <path clip-path=\"url(#pfdea193794)\" d=\"M 58.999432 228.439219 \nL 58.999432 10.999219 \n\" style=\"fill:none;stroke:#b0b0b0;stroke-linecap:square;stroke-width:0.8;\"/>\n </g>\n <g id=\"line2d_2\">\n <defs>\n <path d=\"M 0 0 \nL 0 3.5 \n\" id=\"ma9c7a78e69\" style=\"stroke:#000000;stroke-width:0.8;\"/>\n </defs>\n <g>\n <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"58.999432\" xlink:href=\"#ma9c7a78e69\" y=\"228.439219\"/>\n </g>\n </g>\n <g id=\"text_1\">\n <!-- 0 -->\n <g transform=\"translate(55.818182 243.037656)scale(0.1 -0.1)\">\n <defs>\n <path d=\"M 2034 4250 \nQ 1547 4250 1301 
3770 \nQ 1056 3291 1056 2328 \nQ 1056 1369 1301 889 \nQ 1547 409 2034 409 \nQ 2525 409 2770 889 \nQ 3016 1369 3016 2328 \nQ 3016 3291 2770 3770 \nQ 2525 4250 2034 4250 \nz\nM 2034 4750 \nQ 2819 4750 3233 4129 \nQ 3647 3509 3647 2328 \nQ 3647 1150 3233 529 \nQ 2819 -91 2034 -91 \nQ 1250 -91 836 529 \nQ 422 1150 422 2328 \nQ 422 3509 836 4129 \nQ 1250 4750 2034 4750 \nz\n\" id=\"DejaVuSans-30\" transform=\"scale(0.015625)\"/>\n </defs>\n <use xlink:href=\"#DejaVuSans-30\"/>\n </g>\n </g>\n </g>\n <g id=\"xtick_2\">\n <g id=\"line2d_3\">\n <path clip-path=\"url(#pfdea193794)\" d=\"M 121.11446 228.439219 \nL 121.11446 10.999219 \n\" style=\"fill:none;stroke:#b0b0b0;stroke-linecap:square;stroke-width:0.8;\"/>\n </g>\n <g id=\"line2d_4\">\n <g>\n <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"121.11446\" xlink:href=\"#ma9c7a78e69\" y=\"228.439219\"/>\n </g>\n </g>\n <g id=\"text_2\">\n <!-- 10 -->\n <g transform=\"translate(114.75196 243.037656)scale(0.1 -0.1)\">\n <defs>\n <path d=\"M 794 531 \nL 1825 531 \nL 1825 4091 \nL 703 3866 \nL 703 4441 \nL 1819 4666 \nL 2450 4666 \nL 2450 531 \nL 3481 531 \nL 3481 0 \nL 794 0 \nL 794 531 \nz\n\" id=\"DejaVuSans-31\" transform=\"scale(0.015625)\"/>\n </defs>\n <use xlink:href=\"#DejaVuSans-31\"/>\n <use x=\"63.623047\" xlink:href=\"#DejaVuSans-30\"/>\n </g>\n </g>\n </g>\n <g id=\"xtick_3\">\n <g id=\"line2d_5\">\n <path clip-path=\"url(#pfdea193794)\" d=\"M 183.229487 228.439219 \nL 183.229487 10.999219 \n\" style=\"fill:none;stroke:#b0b0b0;stroke-linecap:square;stroke-width:0.8;\"/>\n </g>\n <g id=\"line2d_6\">\n <g>\n <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"183.229487\" xlink:href=\"#ma9c7a78e69\" y=\"228.439219\"/>\n </g>\n </g>\n <g id=\"text_3\">\n <!-- 20 -->\n <g transform=\"translate(176.8
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAoQAAAG/CAYAAADB4sa8AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAABJ0AAASdAHeZh94AAB72klEQVR4nO3dd3hUVf7H8fdJgFBCE+mCINJEmigqAoINC+qu4lqw66o/166ra13b6trWLmvvKIptrVhBFAQEAUGKSJUqHUIIKef3x5mbO0kmydTMJPN5Pc88uXPrmZxk5jvnnO+5xlqLiIiIiKSvjGQXQERERESSSwGhiIiISJpTQCgiIiKS5hQQioiIiKQ5BYQiIiIiaU4BoYiIiEiaU0AoIiIikuYUEIqIiIikOQWEIiIiImlOAaGIiIhImlNAKCIiIpLmFBCKiIiIpLmUDgiNMdnGmDuMMZ8ZYzYaY6wx5twIjm9ijHnGGPOHMSbHGPONMWa/BBZZREREpNpJ6YAQ2B24DegOzIrkQGNMBvAxcAbwBHA90AIYb4zpHOdyioiIiFRbtZJdgEqsBlpba9cYY/YHpkVw7AhgAHCKtXYsgDHmLWAhcAcuUBQRERFJeyndQmitzbPWrony8BHAWuDdoPP9AbwFnGiMyYpDEUVERESqvVRvIYxFX2CGtbao1PqpwEVAF+DnUAcaY1oAzUutzg4cMwfYFd+iioiIiMRVHaAdMMFau6WynWtyQNga+DbE+tWBn20oJyAELgX+mYhCiYiIiFShE4H/VbZTTQ4I6wF5IdbvDNpenqeAt0ut6waMHT16NHvttVccildWbm4uP//8Mz179qRevYqKJ8mg+kldqpvUpvpJXaqb1BZL/SxevJgzzjgDYEU4+9fkgDAXCDVOsG7Q9pCsteuAdcHrjDEA9OrVix49esSpiCVt3bqVLVu2sN9++9GoUaOEXEOip/pJXaqb1Kb6SV2qm9QWS/1kZ2d7i2ENc0vppJIYrcZ1G5fmrVtVhWURERERSVk1OSCcCewXmI8w2IHADtz0MyIiIiJpr0YEhMaY1saYbsaY2kGrxwItgZOC9tsdOAX40FobanyhiIiISNpJ+TGExpjLgCa4rGCA440xewSWHw+kUt8LnAN0BJYGto0FfgBeNMbsA6zHZQ9nogxiERERkWIpHxAC1wF7Bj0/Cb/V7zUg5Nw61tpCY8yxwAPAFbis4mnAudbaBYkrroiIiEj1kvIBobW2Qxj7nAucG2L9JuDCwENERCQlFRUVsXbtWvLy8igqKn0/hcTJz89nt912Y/Xq1axfv77KrivhKV0/GRkZZGVl0bJlSzIy4jvqL+UDQhERkZqsqKiI5cuXk5ubS2ZmJpmZmcVTnSVarVq1aN68ObVqKRxIRcH1Y61l165d5ObmkpeXR/v27eMaFOovQEREJInWrl1Lbm4uu+22Gy1atKiyYBCgsLCQbdu20bBhQzIzM6vsuhKe0vVjrWXdunVs3LiRtWvX0rp1qNn1olMjsoxFRESqq7y8PDIzM6s8GJTqxxhDixYtyMzMJC8vvpOlKCAUERFJoqKioirtJpbqzRhDZmZm3MeaKiAUERFJMgWDEolE/L0oIBQRERFJcwoIRURERNKcAkIRERGpts4991w6dOiQ7GJUewoIRUREJO6MMWE9xo8fn+yiCpqHUERERBLg1VdfLfH8lVde4Ysvviizvnv37jFd59lnn63Su7vUVAoIRUREJO7OPPPMEs9/+OEHvvjiizLrS9uxYwf169cP+zq1a9eOqnxSkrqMRUREJCmGDBnCvvvuy/Tp0xk8eDD169fnpptuAuCDDz7guOOOo02bNmRlZdGpUyfuuusuCgsLS5yj9BjCpUuXYozhwQcf5JlnnqFTp05kZWVxwAEHMG3atKp8edWKWghFREQkaTZs2MAxxxzDaaedxplnnknLli0BeOmll8jOzuaaa64hOzubr7/+mttuu42tW7
fywAMPVHre0aNHs23bNi6++GKMMdx///2cdNJJLF68WK2KISggFBERSUFXXQUzZyb2GtZmUFiYTWZmBpXNddynDzzySPzLsGbNGv773/9y8cUXl1g/evRo6tWrV/z8kksu4ZJLLuGpp57i7rvvJisrq8LzLl++nF9//ZWmTZsC0LVrV0488UTGjRvH8OHD4/9CqjkFhCIiIilo5kyYMCHRVzEkOxTIysrivPPOK7M+OBjctm0beXl5DBo0iKeffpr58+fTu3fvCs976qmnFgeDAIMGDQJg8eLFcSp5zaKAUEREJAX16ZP4a1hrKSwsDOteyokqT9u2balTp06Z9XPnzuWWW27h66+/ZuvWrSW2bdmypdLztm/fvsRzLzjctGlTDKWtuRQQioiIpKBEdM+WVlhYxLZt22nYsCGZmZmJv2AIwS2Bns2bN3PooYfSqFEj7rzzTjp16kTdunWZMWMGN9xwQ1jTzJT3eqy1MZe5JlJAKCIiIill/PjxbNiwgXfffZfBgwcXr1+yZEkSS1WzadoZERERSSle615wa96uXbt46qmnklWkGk8thCIiIpJSBgwYQNOmTTnnnHO44oorMMbw6quvqrs3gdRCKCIiIimlWbNmfPTRR7Ru3ZpbbrmFBx98kCOPPJL7778/2UWrsdRCKCIiIgn3xBNP8MQTT5RYN378+HL3HzBgAJMnTy6zvnQr4UsvvVTieYcOHcptSVQLY/nUQigiIiKS5hQQioiIiKQ5BYQiIiIiaU4BoYiIiEiaU0AoIiIikuYUEIqIiIikOQWEIiIiImlOAaGIiIhImlNAKCIiIpLmFBCKiIiIpDkFhCIiIiJpTgGhiIiIpLylS5dijClx7+Lbb78dY0xYxxtjuP322+NapiFDhjBkyJC4njNZFBCKiIhI3J1wwgnUr1+fbdu2lbvPyJEjqVOnDhs2bKjCkkXml19+4fbbb2fp0qXJLkpCKSAUERGRuBs5ciS5ubm89957Ibfv2LGDDz74gKOPPppmzZpFdY1bbrmF3NzcWIpZqV9++YU77rgjZED4+eef8/nnnyf0+lVFAaGIiIjE3QknnEDDhg0ZPXp0yO0ffPABOTk5jBw5Mupr1KpVi7p160Z9fKzq1KlDnTp1knb9eFJAKCIiInFXr149TjrpJL766ivWrVtXZvvo0aNp2LAhAwcO5LrrrqNnz55kZ2fTqFEjjjnmGGbNmlXpNUKNIczLy+Pqq6+mefPmNGzYkBNOOIHff/+9zLHLli3j0ksvpWvXrtSrV49mzZpxyimnlGgJfOmllzjllFMAGDp0KMYYjDGMHz8eCD2GcN26dVxwwQW0bNmSunXr0rt3b15++eUS+3jjIR988EGeeeYZOnXqRFZWFgcccADTpk2r9HUnQq2kXFVERERqvJEjR/Lyyy/z1ltvcdlllxWv37hxI+PGjeP0009n9erVvP/++5xyyil07NiRtWvX8vTTT3PooYfyyy+/0KZNm4iueeGFF/Laa69xxhlnMGDAAL7++muOO+64MvtNmzaNSZMmcdppp7HHHnuwdOlSRo0axZAhQ/jll1+oX78+gwcP5oorruCxxx7jpptuonv37gDFP0vLzc1lyJAhLFq0iMsuu4yOHTvy9ttvc+6557J582auvPLKEvuPHj2abdu2cfHFF2OM4f777+ekk05i8eLF1K5dO6LXHSsFhCIiIpIQhx12GK1bt2b06NElAsK3336b/Px8Ro4cSc+ePVm4cCEZGX6n5VlnnUW3bt14/vnnufXWW8O+3qxZs3jttde49NJLefLJJwH429/+xsiRI5k9e3aJfY877jhGjBhRYt3xxx/PwQcfzDvvvMNZZ53FXnvtxaBBg3jsscc48sgjK80ofuaZZ5g3bx6vvfZacVf4JZdcwqGHHsott9zC+eefT8OGDYv3X758Ob/++itNmzYFoGvXrpx44omMGzeO4cOHh/2640EBoYiISCq66iqYOTOhl8iwluzCQj
IyM6Gy6Vv69IFHHono/JmZmZx22mk8/PDDLF26lA4dOgCuZaxly5YcfvjhZGZmFu9fWFjI5s2byc7OpmvXrsyYMSO
2021-05-06 16:19:44 +01:00
},
"metadata": {
"needs_background": "light"
2021-05-10 00:18:57 +01:00
}
2021-05-06 16:19:44 +01:00
}
],
"source": [
"# Keras records one entry per epoch; entry 0 is the state after the first epoch.\n",
"train_acc = history.history[\"accuracy\"]\n",
"val_acc = history.history[\"val_accuracy\"]\n",
"epoch_numbers = range(1, len(train_acc) + 1)\n",
"\n",
"fig, ax = plt.subplots()\n",
"ax.plot(epoch_numbers, train_acc, label=\"Train\", c=(0, 0, 1))\n",
"ax.plot(epoch_numbers, val_acc, label=\"Validation\", c=(1, 0, 0))\n",
"ax.set(title=\"Single model accuracy\", xlabel=\"Epochs\", ylabel=\"Accuracy\", ylim=(0, 1))\n",
"ax.grid()\n",
"ax.legend()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Test\n",
"\n",
"Evaluate on the held-out test set: `model.evaluate` returns `[loss, accuracy]`."
]
},
{
"cell_type": "code",
2021-05-10 00:18:57 +01:00
"execution_count": 25,
2021-05-06 16:19:44 +01:00
"metadata": {},
"outputs": [
{
"output_type": "stream",
2021-05-10 00:18:57 +01:00
"name": "stdout",
2021-05-06 16:19:44 +01:00
"text": [
2021-05-10 00:18:57 +01:00
"10/10 [==============================] - 0s 2ms/step - loss: 0.7529 - accuracy: 0.7609\n"
2021-05-06 16:19:44 +01:00
]
},
{
2021-05-10 00:18:57 +01:00
"output_type": "execute_result",
2021-05-06 16:19:44 +01:00
"data": {
"text/plain": [
2021-05-10 00:18:57 +01:00
"[0.7529351115226746, 0.7609427571296692]"
2021-05-06 16:19:44 +01:00
]
},
"metadata": {},
2021-05-10 00:18:57 +01:00
"execution_count": 25
2021-05-06 16:19:44 +01:00
}
],
"source": [
"# [loss, accuracy] on the held-out test set\n",
"model.evaluate(x=data_test.to_numpy(), y=labels_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Predict on the test set and plot a confusion matrix, with each row normalised by the number of test tracks in that true playlist."
]
},
{
"cell_type": "code",
2021-05-10 00:18:57 +01:00
"execution_count": 26,
2021-05-06 16:19:44 +01:00
"metadata": {},
"outputs": [
{
2021-05-10 00:18:57 +01:00
"output_type": "display_data",
2021-05-06 16:19:44 +01:00
"data": {
2021-05-10 00:18:57 +01:00
"text/plain": "<Figure size 720x480 with 2 Axes>",
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<svg height=\"248.518125pt\" version=\"1.1\" viewBox=\"0 0 380.336375 248.518125\" width=\"380.336375pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n <metadata>\n <rdf:RDF xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n <cc:Work>\n <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n <dc:date>2021-05-09T23:59:37.976238</dc:date>\n <dc:format>image/svg+xml</dc:format>\n <dc:creator>\n <cc:Agent>\n <dc:title>Matplotlib v3.4.1, https://matplotlib.org/</dc:title>\n </cc:Agent>\n </dc:creator>\n </cc:Work>\n </rdf:RDF>\n </metadata>\n <defs>\n <style type=\"text/css\">*{stroke-linecap:butt;stroke-linejoin:round;}</style>\n </defs>\n <g id=\"figure_1\">\n <g id=\"patch_1\">\n <path d=\"M 0 248.518125 \nL 380.336375 248.518125 \nL 380.336375 0 \nL 0 0 \nz\n\" style=\"fill:none;\"/>\n </g>\n <g id=\"axes_1\">\n <g id=\"patch_2\">\n <path d=\"M 54.78125 224.64 \nL 322.62125 224.64 \nL 322.62125 7.2 \nL 54.78125 7.2 \nz\n\" style=\"fill:#ffffff;\"/>\n </g>\n <g id=\"QuadMesh_1\">\n <path clip-path=\"url(#pe16421a0a6)\" d=\"M 54.78125 7.2 \nL 99.42125 7.2 \nL 99.42125 43.44 \nL 54.78125 43.44 \nL 54.78125 7.2 \n\" style=\"fill:#fbbe23;\"/>\n <path clip-path=\"url(#pe16421a0a6)\" d=\"M 99.42125 7.2 \nL 144.06125 7.2 \nL 144.06125 43.44 \nL 99.42125 43.44 \nL 99.42125 7.2 \n\" style=\"fill:#040314;\"/>\n <path clip-path=\"url(#pe16421a0a6)\" d=\"M 144.06125 7.2 \nL 188.70125 7.2 \nL 188.70125 43.44 \nL 144.06125 43.44 \nL 144.06125 7.2 \n\" style=\"fill:#320a5e;\"/>\n <path clip-path=\"url(#pe16421a0a6)\" d=\"M 188.70125 7.2 \nL 233.34125 7.2 \nL 233.34125 43.44 \nL 188.70125 43.44 \nL 188.70125 7.2 \n\" style=\"fill:#0b0724;\"/>\n <path 
clip-path=\"url(#pe16421a0a6)\" d=\"M 233.34125 7.2 \nL 277.98125 7.2 \nL 277.98125 43.44 \nL 233.34125 43.44 \nL 233.34125 7.2 \n\" style=\"fill:#000004;\"/>\n <path clip-path=\"url(#pe16421a0a6)\" d=\"M 277.98125 7.2 \nL 322.62125 7.2 \nL 322.62125 43.44 \nL 277.98125 43.44 \nL 277.98125 7.2 \n\" style=\"fill:#010108;\"/>\n <path clip-path=\"url(#pe16421a0a6)\" d=\"M 54.78125 43.44 \nL 99.42125 43.44 \nL 99.42125 79.68 \nL 54.78125 79.68 \nL 54.78125 43.44 \n\" style=\"fill:#040312;\"/>\n <path clip-path=\"url(#pe16421a0a6)\" d=\"M 99.42125 43.44 \nL 144.06125 43.44 \nL 144.06125 79.68 \nL 99.42125 79.68 \nL 99.42125 43.44 \n\" style=\"fill:#fa9008;\"/>\n <path clip-path=\"url(#pe16421a0a6)\" d=\"M 144.06125 43.44 \nL 188.70125 43.44 \nL 188.70125 79.68 \nL 144.06125 79.68 \nL 144.06125 43.44 \n\" style=\"fill:#340a5f;\"/>\n <path clip-path=\"url(#pe16421a0a6)\" d=\"M 188.70125 43.44 \nL 233.34125 43.44 \nL 233.34125 79.68 \nL 188.70125 79.68 \nL 188.70125 43.44 \n\" style=\"fill:#040312;\"/>\n <path clip-path=\"url(#pe16421a0a6)\" d=\"M 233.34125 43.44 \nL 277.98125 43.44 \nL 277.98125 79.68 \nL 233.34125 79.68 \nL 233.34125 43.44 \n\" style=\"fill:#260c51;\"/>\n <path clip-path=\"url(#pe16421a0a6)\" d=\"M 277.98125 43.44 \nL 322.62125 43.44 \nL 322.62125 79.68 \nL 277.98125 79.68 \nL 277.98125 43.44 \n\" style=\"fill:#000004;\"/>\n <path clip-path=\"url(#pe16421a0a6)\" d=\"M 54.78125 79.68 \nL 99.42125 79.68 \nL 99.42125 115.92 \nL 54.78125 115.92 \nL 54.78125 79.68 \n\" style=\"fill:#040312;\"/>\n <path clip-path=\"url(#pe16421a0a6)\" d=\"M 99.42125 79.68 \nL 144.06125 79.68 \nL 144.06125 115.92 \nL 99.42125 115.92 \nL 99.42125 79.68 \n\" style=\"fill:#040312;\"/>\n <path clip-path=\"url(#pe16421a0a6)\" d=\"M 144.06125 79.68 \nL 188.70125 79.68 \nL 188.70125 115.92 \nL 144.06125 115.92 \nL 144.06125 79.68 \n\" style=\"fill:#fcffa4;\"/>\n <path clip-path=\"url(#pe16421a0a6)\" d=\"M 188.70125 79.68 \nL 233.34125 79.68 \nL 233.34125 115.92 \nL 188.70125 11
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAnkAAAGgCAYAAADW0HHbAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAABJ0AAASdAHeZh94AAChZklEQVR4nOzdd3wURRvA8d+kxyRAqAkdpSud0JFuo0kTkC6IiGABREWxgAL60u10UBGkgwJKlU7oTSWEXhISOoH0zPvHXi7l7pKQHCQcz9fPfmJmd3Znlrvcc9NWaa0RQgghhBCOxSm7CyCEEEIIIexPgjwhhBBCCAckQZ4QQgghhAOSIE8IIYQQwgFJkCeEEEII4YAkyBNCCCGEcEAS5AkhhBBCOCAJ8oQQQgghHJAEeUIIIYQQDkiCPCGEEEIIByRBnhBCCCGEA3LJ7gI8SuI2lZYHBWeQZ4ub2V2Eh0Jl9xeyuwgPhePxO7K7CA+NO9FnsrsIDwWl5OMzoxISItWDvF48v9jls9aZbg+03PeDvEqFEEII4TASEuLtch5nB+jrdIAqCCGEEEKI1KQlTwghhBAOQ+u47C5CjiFBnhBCCCEchtb26a51BNJdK4QQQgjhgKQlTwghhBAOI0G6a80kyBNCCCGEw5AxeUkkyBNCCCGEw5AgL4mMyRNCCCGEcEDSkieEEEIIh6ETpCUvkQR5QgghhHAc0l1rJt21QgghhBAOSFryhBBCCOEwZOJFEgnyhBBCCOE4EmKzuwQ5hnTXCiGEEEI4IGnJE0IIIYTDkO7aJBLkCSGEEMJxyBIqZhLkCSGEEMJxSJBnJmPyhBBCCCEckLTkCSGEEMJxyJg8MwnyhBBCCOEwlHTXmkl37UMuJhYmLFU0fs+J6oOd6DLOiR3/pJ+vxQgnnhzgbHV7fmTKl4Wt46avVfepVlnj5ubGmLEjOXvuCLdun2P7jrU0a94oQ3kLF/Zj/q8zCL8SzNVrp1iydB6lSpVIcYyHhwfTpk3mwMEtXLl6kus3zrBv3yYGD+6Pi0va35t++GEisXHhLF/xS6brd7+5urkwePRLrA2ezPYr05m7+WNqN30y3Xwlyvgx5MuXmbXhI3Zcnc6+O3PxL57f6rGr/hnPvjtzLbYPpvSyd3Xsws3NjVGfv8uJU9sIv36ETVsW06RZ/Qzl9S9ciHk/T+FC6D4uhR1gwaLvKVmqmMVxEVEnrG5DhvVPcdyx45tsHnvw6Dq71DcncHNzY9y4MVy8eJa7d2+xa9d2mjdvlt3Fshujfp9z4cIp7ty5xs6dW2jevGmG8hYuXJgFC37m2rUQbty4zLJlv1GqVEmL4wYMeJWFC3/hzJkgEhIimTVrmtXz+fn5MXbsaDZsWMvNm2EkJETSqFHDrFRP5BDSkveQGzFXsW6/okczTfGCsGKn4vVvnJg1JIEapW3ne/+lBO5GpwzSLl2FqSudqFdRWxxfr4KmTZ2U6RWKWR6XE8yc9TUdOrRm6tQfCT5xip69urBq1a+0aN6O7dt328zn5eXFuvXLyZ3bh3HjJhMXG8ubbw1gw8YV1KzRhGvXrgPg6elBxSfLsXbtes6eOU9CQgJ16wYwfsJoAmpVp2ePAVbPX6NGFXr26kJkZOR9qbe9fDrtVZq/WJP53/7FueDLtO7egKlLh/Da8+M4uPOEzXyVa5emy+stOP3fRU4fD6F8lRI2jwX479BZfp66NkXauROhdqmDvf0440tebPcs3349l5Mnz9Cte3uWLp/OC8/2YOeOfTbzeXk9xuo/fyJ3Lh/Gf/UDsbGxDHqzD2vX/UK9Wm24du1GiuM3rN/G/F+WpUg7fDDlt7b3hn2Bl/djKdKKFy/CJ58NYeP6bVmraA4yZ85MOnbswOTJUzlxIpjevXuyevUqmjRpwfbt27O7eFk2e/Z0OnZsx5Qp33DiRDC9evXgjz+W07Tpc2zfvsNmPi8vLzZuXEvu3LkYO/Z/xMbG8vbbg9m8eR
3VqtXm2rVr5mOHDx+Kj483gYF78ff3s3nOcuXK8N57wwgKOsGRI8eoV6+OXev6wElLnlmWgzyl1EDgWyBQa13bxjEa+FZrPSiN82wG8mutn7rH6zcGNiVLSgCuAluAkVrrf23kewH4AwgBimqtE6wccwZI/kkVDhwHJmqtl6U+/kE7fBrW7HViWPsE+jxjBFxt62jajnJi4lInfhluUSWzZlUBUgZpP6w2gr5WtSyDtxIFNa1r58ygLrmAgGp06dKe4cM/YdLE7wD46affOHhoK2PHfczTDVvazDvg9T6ULfsEdeu0YO/egwCsXbuBg4e28s6QgYz86AsArl+/QYP6z6fIO23aXG7evM0bg/rx7rCPuXw5zOL8kyaP4eeffqNJ05z7DfnJGo/zXKc6TB6xgJ+mrAHgj/nb+W3PF7z5eWdeafa5zbx//3GADYVf525EFD3eej7dIC/80nXWLLD9YZZT1KhZmU4vtWLE++OYOnkmAPN/Xkbg/tWM/mI4zZt0tpn31de6UaZMKZ6u3579+44AsO7PLQTu/4PBb7/CZx9PTHF88InTLPx1ZZrl+X3Veou04e8PBGDhgrTzPiwCAgLo2rULw4YNZ8KESQDMm/cTR48e5KuvxlK//tPZXMKsCQioSdeuL/Huux8wYcJkAObN+4UjR/bx5Zdf0KBBE5t5Bw58jbJly1CrVgP27jW+YKxZ8ydHjuxj6NC3+PDDT8zHNm78DOfOnQPg1q1wm+fct+8A+fIV5vr163To0E6CPAdij+7absAZoJZSKo22o/tuKtAD6Af8ArQEtiqlbH19SSy3P5BWG/lB03l7AOOBwsBSpZT15poH6K/9CmcnTaeGScGXuyt0qK85eEoRci2NzFb8Eagoml9T7Qnr+6NiIDqHPy2mfYfWxMXFMWP6PHNadHQ0s2f/Qt26tShatLDNvB06tGbPnv3mAA/g+PFgNm7cSseObdO99pmzxh/TPHlyWezr3v0lnnyyAiNHjrmH2jx4zdrVJC4unqWzkr43xUTHsmLeFqrUKUOhInlt5r11/Q53I6Lu6Xours54POaW6fI+CC+2e464uDhmz1xoTouOjmHenEXUqVudIkVtt5C82O5Z9u45ZA7wAIKCTrF5007ad3jBah4PD3fc3e/tnrzUuTWnT59n964D95Qvp+rYsT1xcXFMmzbDnBYdHc3MmbOpV68uRYsWzcbSZV3Hju1M9ZtpTouOjmbWrDnUq1cnzfp16NCOwMC95gAP4PjxIDZs2ESnTh1SHJsY4KUnIiKC69ev32Mtci6l4+yyOYIsBXlKqVJAPWAIRitXN3sUKpO2aq1/1lrP1lq/A7wD5AN6pj5QKeUFtAUmAgdIu9wXTef9WWv9FVAfuGM6f7b677yiREHw9kyZXqmkEfT9dyHj5/r3HJwKVbwQYL21bvkuRc23nKg+2JnWnzrxe2DOHI9XtWolgoJOcvt2RIr0PXv2A1ClqvWGYqUUlSpVZF+yAC953tKlS+Ht7ZUi3dXVlXz58lK0aGHatn2BIUPe4MyZcwQHn05xnLe3F2PGfsy4cZOttvDlJOWqlODciVDu3E4ZrB3dewqAspWL2+1aAY0qsP3KdLaHT2fVP+PpOrCF3c5tT1WqViT4xBmL19S+vYcBqFy5otV8SimeqlSeA/uPWuzbt/cwTzxRwuI11a1He8KuHebqzWPsPbCGTp1bp1u+ylUqUr5CaRYtXJXRKuV41apVJSgoiNu3b6dIDwzcA0DVqlWyo1h2U7VqFYKCTlip317T/spW8ymlqFz5KfbtsxwisGfPXkqXfgJvb2/7F1g8tLLaXdsNuI7R7bnY9PtnWS2UnWw1/bTWLtUO8AQWmX5+qJR6XWudbjOE1jpUKfUvkO1/ZcJvQYHclun5TQ1J4TcUqbtkbUkM2qx11VZ9XPNcDU2R/Jrwm4pfNyvem+VERGQCXRrlrC5cP79ChIZetkgPDTHSCtsYl5I3ry8eHh6EpJW3sB
9BQSfN6e3ateSX+dPNv+/dc4BXX32L+Pj4FPk/GjmMyMhIpkz+4d4r9IDl98vDlcs3LNKvhBppBfzz2OU6J45e4OD
2021-05-06 16:19:44 +01:00
},
"metadata": {
"needs_background": "light"
2021-05-10 00:18:57 +01:00
}
2021-05-06 16:19:44 +01:00
}
],
"source": [
"predictions = model(data_test.to_numpy())\n",
"\n",
"# vectorised argmax over the one-hot / softmax rows instead of a per-row Python loop\n",
"true_classes = tf.math.argmax(labels_test, axis=1)\n",
"predicted_classes = tf.math.argmax(predictions, axis=1)\n",
"conf = tf.math.confusion_matrix(true_classes, predicted_classes,\n",
"                                num_classes=len(filtered_playlists)).numpy()\n",
"\n",
"# Row-normalise: each row shows how one true playlist's test tracks were classified.\n",
"# Rows with no test tracks stay 0 instead of dividing by zero.\n",
"row_totals = conf.sum(axis=1, keepdims=True)\n",
"normalised_conf = np.divide(conf, row_totals,\n",
"                            out=np.zeros(conf.shape, dtype=float),\n",
"                            where=row_totals > 0)\n",
"\n",
"ax = sns.heatmap(normalised_conf,\n",
"                 annot=True,\n",
"                 xticklabels=playlist_names, yticklabels=playlist_names,\n",
"                 cmap='inferno')\n",
"ax.set(xlabel='Predicted playlist', ylabel='True playlist')\n",
"plt.show()"
]
},
2021-05-10 00:18:57 +01:00
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Ensemble Model"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"# Nine ensemble members, each with a random width (16-128 units) and depth (1-2 hidden layers).\n",
"models = []\n",
"for _ in range(9):\n",
"    width = random.randint(16, 128)\n",
"    depth = random.randint(1, 2)\n",
"    models.append(get_model(hidden_nodes=width, layers=depth))\n",
"\n",
"# Same optimiser and loss as the single model (SGD(learning_rate=0.01, momentum=0.9) is an alternative).\n",
"for m in models:\n",
"    m.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.01),\n",
"              loss='categorical_crossentropy',\n",
"              metrics=['accuracy'])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Train the models of the ensemble. Each model trains for a random number of epochs (20-100) to introduce variation between models. *Unusual choice:* each model also randomly either uses the balanced class weights (with probability 3/4) or none, independent of `BALANCED_WEIGHTS`. Class weights scale the loss so that mistakes on smaller classes are penalised more heavily. It is not yet clear whether this helps."
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"training model 1\n",
"training model 2\n",
"training model 3\n",
"training model 4\n",
"training model 5\n",
"training model 6\n",
"training model 7\n",
"training model 8\n",
"training model 9\n"
]
}
],
"source": [
"# Each member trains for a random number of epochs (20-100) and uses the balanced class\n",
"# weights with probability 3/4, to vary the ensemble. BALANCED_WEIGHTS does not apply here.\n",
"ensem_histories = list()\n",
"ensem_settings = list()  # (used class weights?, epochs) per model, for later inspection\n",
"for idx, m in enumerate(models):\n",
"    print(f'training model {idx+1}')\n",
"    member_weights = random.choice([*([class_weights]*3), None])\n",
"    n_epochs = random.randint(20, 100)\n",
"    h = m.fit(data_train.to_numpy(), labels_train,\n",
"              callbacks=[tensorboard_callback()],\n",
"              validation_split=0.11,\n",
"              verbose=0,\n",
"              class_weight=member_weights,\n",
"              epochs=n_epochs)\n",
"    ensem_histories.append(h)\n",
"    ensem_settings.append((member_weights is not None, n_epochs))"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"79.5% Accuracy, 82.7% Agreement, 67.6% Ind. Accuracy\n"
]
}
],
"source": [
"ensem_results = ensem_classify(models, data_test, labels_test)\n",
"# NOTE(review): indices 2-4 are assumed to be ensemble accuracy, agreement and individual\n",
"# accuracy as fractions in [0, 1] - confirm against ensem_classify.\n",
"# `.1%` formats a fraction as a percentage with one decimal place (never scientific notation).\n",
"print(f\"{ensem_results[2]:.1%} Accuracy, {ensem_results[3]:.1%} Agreement, {ensem_results[4]:.1%} Ind. Accuracy\")"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": "<Figure size 720x480 with 2 Axes>",
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<svg height=\"248.518125pt\" version=\"1.1\" viewBox=\"0 0 380.336375 248.518125\" width=\"380.336375pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n <metadata>\n <rdf:RDF xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n <cc:Work>\n <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n <dc:date>2021-05-10T00:00:22.760673</dc:date>\n <dc:format>image/svg+xml</dc:format>\n <dc:creator>\n <cc:Agent>\n <dc:title>Matplotlib v3.4.1, https://matplotlib.org/</dc:title>\n </cc:Agent>\n </dc:creator>\n </cc:Work>\n </rdf:RDF>\n </metadata>\n <defs>\n <style type=\"text/css\">*{stroke-linecap:butt;stroke-linejoin:round;}</style>\n </defs>\n <g id=\"figure_1\">\n <g id=\"patch_1\">\n <path d=\"M 0 248.518125 \nL 380.336375 248.518125 \nL 380.336375 0 \nL 0 0 \nz\n\" style=\"fill:none;\"/>\n </g>\n <g id=\"axes_1\">\n <g id=\"patch_2\">\n <path d=\"M 54.78125 224.64 \nL 322.62125 224.64 \nL 322.62125 7.2 \nL 54.78125 7.2 \nz\n\" style=\"fill:#ffffff;\"/>\n </g>\n <g id=\"QuadMesh_1\">\n <path clip-path=\"url(#p19e3caf8b9)\" d=\"M 54.78125 7.2 \nL 99.42125 7.2 \nL 99.42125 43.44 \nL 54.78125 43.44 \nL 54.78125 7.2 \n\" style=\"fill:#fcffa4;\"/>\n <path clip-path=\"url(#p19e3caf8b9)\" d=\"M 99.42125 7.2 \nL 144.06125 7.2 \nL 144.06125 43.44 \nL 99.42125 43.44 \nL 99.42125 7.2 \n\" style=\"fill:#010108;\"/>\n <path clip-path=\"url(#p19e3caf8b9)\" d=\"M 144.06125 7.2 \nL 188.70125 7.2 \nL 188.70125 43.44 \nL 144.06125 43.44 \nL 144.06125 7.2 \n\" style=\"fill:#08051d;\"/>\n <path clip-path=\"url(#p19e3caf8b9)\" d=\"M 188.70125 7.2 \nL 233.34125 7.2 \nL 233.34125 43.44 \nL 188.70125 43.44 \nL 188.70125 7.2 \n\" style=\"fill:#02020e;\"/>\n <path 
clip-path=\"url(#p19e3caf8b9)\" d=\"M 233.34125 7.2 \nL 277.98125 7.2 \nL 277.98125 43.44 \nL 233.34125 43.44 \nL 233.34125 7.2 \n\" style=\"fill:#000004;\"/>\n <path clip-path=\"url(#p19e3caf8b9)\" d=\"M 277.98125 7.2 \nL 322.62125 7.2 \nL 322.62125 43.44 \nL 277.98125 43.44 \nL 277.98125 7.2 \n\" style=\"fill:#040314;\"/>\n <path clip-path=\"url(#p19e3caf8b9)\" d=\"M 54.78125 43.44 \nL 99.42125 43.44 \nL 99.42125 79.68 \nL 54.78125 79.68 \nL 54.78125 43.44 \n\" style=\"fill:#260c51;\"/>\n <path clip-path=\"url(#p19e3caf8b9)\" d=\"M 99.42125 43.44 \nL 144.06125 43.44 \nL 144.06125 79.68 \nL 99.42125 79.68 \nL 99.42125 43.44 \n\" style=\"fill:#ea632a;\"/>\n <path clip-path=\"url(#p19e3caf8b9)\" d=\"M 144.06125 43.44 \nL 188.70125 43.44 \nL 188.70125 79.68 \nL 144.06125 79.68 \nL 144.06125 43.44 \n\" style=\"fill:#260c51;\"/>\n <path clip-path=\"url(#p19e3caf8b9)\" d=\"M 188.70125 43.44 \nL 233.34125 43.44 \nL 233.34125 79.68 \nL 188.70125 79.68 \nL 188.70125 43.44 \n\" style=\"fill:#180c3c;\"/>\n <path clip-path=\"url(#p19e3caf8b9)\" d=\"M 233.34125 43.44 \nL 277.98125 43.44 \nL 277.98125 79.68 \nL 233.34125 79.68 \nL 233.34125 43.44 \n\" style=\"fill:#180c3c;\"/>\n <path clip-path=\"url(#p19e3caf8b9)\" d=\"M 277.98125 43.44 \nL 322.62125 43.44 \nL 322.62125 79.68 \nL 277.98125 79.68 \nL 277.98125 43.44 \n\" style=\"fill:#000004;\"/>\n <path clip-path=\"url(#p19e3caf8b9)\" d=\"M 54.78125 79.68 \nL 99.42125 79.68 \nL 99.42125 115.92 \nL 54.78125 115.92 \nL 54.78125 79.68 \n\" style=\"fill:#160b39;\"/>\n <path clip-path=\"url(#p19e3caf8b9)\" d=\"M 99.42125 79.68 \nL 144.06125 79.68 \nL 144.06125 115.92 \nL 99.42125 115.92 \nL 99.42125 79.68 \n\" style=\"fill:#040312;\"/>\n <path clip-path=\"url(#p19e3caf8b9)\" d=\"M 144.06125 79.68 \nL 188.70125 79.68 \nL 188.70125 115.92 \nL 144.06125 115.92 \nL 144.06125 79.68 \n\" style=\"fill:#f2f482;\"/>\n <path clip-path=\"url(#p19e3caf8b9)\" d=\"M 188.70125 79.68 \nL 233.34125 79.68 \nL 233.34125 115.92 \nL 188.70125 11
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAnkAAAGgCAYAAADW0HHbAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAABJ0AAASdAHeZh94AACfdUlEQVR4nOzdd3gU1RrA4d9J2SSQAKEmEJogICA9IE1AQFSQjoB0UVRERUWxgCKgiBcRwUoXkaJUkaL0DqF3Cb2HhE4gPef+MZtN2d0QkiVl+d77zBM5c87smbmzu9+eNkprjRBCCCGEcC4uWV0BIYQQQgjheBLkCSGEEEI4IQnyhBBCCCGckAR5QgghhBBOSII8IYQQQggnJEGeEEIIIYQTkiBPCCGEEMIJSZAnhBBCCOGEJMgTQgghhHBCEuQJIYQQQjghCfKEEEIIIZyQW1ZX4GESx+/yoOA0cnd5OaurkCN4uBXK6irkCNFxN7O6CjlGfPzdrK5CjuDmmi+rq5BjxMSGqcx8PUd917rSLVPr/SBIkCeEEEIIpxEfH+eQ47g6QV+nE5yCEEIIIYRISVryhBBCCOE0tI7N6ipkGxLkCSGEEMJpaO2Y7lpnIN21QgghhBBOSFryhBBCCOE04qW71kKCPCGEEEI4DRmTl0iCPCGEEEI4DQnyEsmYPCGEEEIIJyQteUIIIYRwGjpeWvISSJAnhBBCCOch3bUW0l0rhBBCCOGEpCVPCCGEEE5DJl4kkiBPCCGEEM4jPiara5BtSHetEEIIIYQTkpY8IYQQQjgN6a5NJEGeEEIIIZyHLKFiIUGeEEIIIZyHBHkWMiZPCCGEEMIJSUueEEIIIZyHjMmzkCBPCCGEEE5DSXethXTX5nDR0bF8879VNGowlupVvqRzp8ls2XwiTWW3bDlJ7x6/Uq/O/6hTazSdO07mr0X7rfJduRLOxx8tpkHdMVSv8iUd2k1kxfLDjj4VhzGZTHz11UjOnz/JnTvX2Lp1A82aPZWmskWLFmXOnJlcu3aJGzcus3DhH5QuXcoq32uvvcLcub9z+nQw8fERTJ060ebx/Pz8GDVqBKtXr+DmzVDi4yNo1KhhRk7PoUwmEyO//IiTp3dw7WYwGzYt5qmmaatf0aJFmDnrRy6FHuDylUP8MX8ypUqXSLVMvXqBRESfJSL6LAUK+Cbb98nQdyz7km7XbwWn+/wcxWQyMWrU55w79x/h4SFs2bKaZs2apKls0aL+zJkznatXz3D9+jkWLpxldU8FBBRj6NDBbN26hitXznD58klWr/6bpk0bWx2vYcN6LFo0m9OnD3HnzmUuXAhm2bL51KtXxwFnmn0Y7+MvuXDhDHfv3mLbts00a9Y0q6vlMCaTiS9HDeXM2QPcun2WzVtW0LRZozSVLVrUj1mzJxN25ThXr51k/oIZlC5dMlkeT09PJk4cx569G7hy9QTXb5xm1661vPlmP9zcUm/f+fnnscTEhrFo8e/pPj+RPUhLXg738YeL+fefI/ToWYeSpfKzaOE+Xus3m2m/9qRmLftfuGtWH+XNN+ZSrVoAb7zZCKUUK5Yf4sPBi7h+4y69ej8BQHh4FN1fnM7VK+H06FmHgoW8WbH8EO8OnEdsbDtaPf94Zp1qmk2bNomOHdvx3Xffc+zYcXr16sHSpYt46qln2Lx5i91yuXPnZs2aFeTNm4dRo/5HTEwMAwe+ybp1K6levQ7Xrl2z5P3gg/fw8fEmKGgn/v5+do9ZvvyjDB48iODgYxw4cIh69Z5w6Llm1KQp39Cu/XN8P34Kx4+fpkfPjiz6azrPNO/Cli077JbLnTsXK1bOJU8eH/43+gdiYmJ4862XWbnqD+oEPsO1azesyiil+Gbc54SH38HbO7fdY7/5xseE37lj+Xd8XFyGztERpk37iQ4d2vDddz9x/PgJevZ8kb///pOmTVuxefM2u+Vy587N6tV/m++pseZ7qj9r1y6lRo0GXLt2HYDWrZ
/jgw8GsnjxUmbMmIWbmxs9enTl338X07dvf6ZPT/yyLVeuLPHx8fzyyzQuX75Mvnz56NbtBdatW87zz3fin39WP/DrkRmmT59Cx44dGDduPMeOHad3754sW7aEJk2as3nz5qyuXoZNmTqBDh2eZ/z4Xzh+7CQ9e3VhyZLZNG/Wjs2bt9stlzt3blauWkTevD589dU4YmNieOvt11i9ZjG1ajax3FNeXp5UrFSeFStWceb0OeLj46lbN5Ax34wgsHYNevZ4zebxa9asSs9eXYiIiHgg550ppCXPQmmtM3YApfoDPwBBWmubPyWVUhr4QWs9IJXjrAMKaq0r3+frNwbWJkmKB64CG4ChWusjdso9BywFLgEBWut4G3lOA0l/HoUBR4GxWuuF91NPgDh+z9jFTmH//gt06TSFQR8046W+9QCIioqldaufKFAgN7PmvGS37MsvzeT4sTD+Xf0mJpMR68fGxtPy2R/I5WVi4V+vAjBl8ha++d8qpk7vwRN1SwMQH6/p+sIULoXcYtWatzGZXB15WgC4u7ycrnKBgbXYvn0j77//Ed98Mw4ADw8PDhzYRWhoGA0a2G99ef/9dxk9+gtq127Azp27AChfvhwHDuzif/8byyeffGbJW6JECc6ePQvArVthzJu3kJde6md1TG9vb9zd3bl+/TodOrTjzz9n0aTJ06xfvzFd55eSh1uhdJetVasqG7cs4aPBIxn3rdES6eHhwa49KwkLu0KTRu3tln33vdf4YtTHNKjbil27jNbfcuXLsGvPSsZ+8zOfDf3aqszLr3Tns88HMWfWQga81ZcA/6pcvXrdsv+Toe8wZOg7VumOEB13M91lAwNrsG3bWt5/fwhjx04AjOu0f/82QkPDaNjwabtlBw16m9Gjh1OnThN27twNGIH//v3b+N//vmPIkOEAVKxYgcuXQ7l6NfGHhMlkYvfuTXh756ZUqUqp1tHLy4vjx/exb98BnnuuQ7rPFSA+/m6GyjtCYGAgQUFbGDToA7755lvAuOYHD+4lNDSM+vWfzOIagptrvnSXDQyszpat//LBB5/x7dgfAeP89u7bSFhYGE82bGm37HuDBvDVV59R94nm7Ny5F4Dy5cuyd99Gxoz5nqFDvkj1tceNG8UbA14moFglLl8Otdq/YeNS/jtyjCZPNeTQof9o26Zbus8zQUxsmMrwQe7D7UsvOeS71sd/aqbW+0FwRHdtN+A0UFspVdYBx0uv8UAP4GXgd6AlsFEpZa+ZJaHe/kBqfXl7zcftAYwBigILlFK2fwZlon9XHMbVVfFC55qWNA8PNzp0rM7ePee5dMn+F1t4eBR58npaAjwANzcXfH1z4eGZmLZ751ny589lCfAAXFwULZ6tyJWwcHbuOO3Yk8qgjh3bERsby8SJUyxpUVFRTJ06nXr1niAgIMBu2Q4d2hEUtNMS4AEcPRrM6tVr6dQp+RdnQoB3L+Hh4Vy/7tiAxVHadWhJbGwsUybPsqRFRUUxffpcnqhbi4AAf/tl2z/Hzh17LQEeQPDRE6xds5kOHVpZ5ff1zctnnw9ixOffcOPmrVTrpZTCx8c7HWf0YHTo0JbY2FgmTZpuSTPuqd+oV68OAQHFUinbhqCgXZYAD+Do0WOsWbOeTp3aWdIOH/4vWYAHEB0dzfLl/1K8eADe3qlfj4iICMLCrpA3b977PLvsqWPH9ub38WRLWlRUFFOmTKNevbqpvo9zgvYdnic2NpbJk2ZY0qKiopg27Xfq1q1NQEBRu2U7dHieHTt2WwI8gKNHj7NmzUY6dmxzz9c+fcb47MqXL4/Vvu7dX6BSpccYOvTL+zib7EfpWIdsziBDQZ5SqjRQD3gXo5Ur4yF/+m3UWs/UWk/TWr8DvAMUAHqmzKiUyg20AcYCe0i93hfMx52ptf4aqA/cMR8/Sx05EkLJUgXw9vZIlv54FeMD4r8jIXbL1q5dkuPHwhg/bi1nzlzj7Nlr/PTDBg4dvEjfl+tZ8kXHxO
Lh6W5V3sucdujQJUecisNUq1aV4OBj3L59O1l6UNBO8/4qNssppahSpTK7du2y2rdjx07Kli1zzy/anKZq1UocO3a
},
"metadata": {
"needs_background": "light"
}
}
],
"source": [
"# Row-normalise the ensemble confusion matrix and plot it as a heatmap.\n",
"# tf.math.confusion_matrix puts true labels on rows and predictions on columns,\n",
"# so each normalised row gives the fraction of that class's test samples that\n",
"# were assigned to each predicted class (the diagonal is per-class recall).\n",
"\n",
"# Tick labels come from playlist_names while the matrix size comes from\n",
"# filtered_playlists; fail loudly rather than silently mislabel the plot.\n",
"assert len(playlist_names) == len(filtered_playlists), \"heatmap labels must match the number of classes\"\n",
"\n",
"ensem_conf = tf.math.confusion_matrix([tf.math.argmax(i) for i in labels_test], \n",
"                                      ensem_results[0], \n",
"                                      num_classes=len(filtered_playlists))\n",
"\n",
"conf_counts = ensem_conf.numpy().astype(float)\n",
"row_totals = conf_counts.sum(axis=1, keepdims=True)\n",
"# A class with no test samples gets an all-zero row instead of NaN from 0/0.\n",
"normalised_ensem_conf = np.divide(conf_counts, row_totals,\n",
"                                  out=np.zeros_like(conf_counts),\n",
"                                  where=row_totals != 0)\n",
"\n",
"ax = sns.heatmap(normalised_ensem_conf, \n",
"                 annot=True, \n",
"                 xticklabels=playlist_names, yticklabels=playlist_names, \n",
"                 cmap='inferno')\n",
"ax.set(xlabel='Predicted playlist', ylabel='Actual playlist')\n",
"plt.show()"
]
},
2021-05-06 16:19:44 +01:00
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Imports & Setup"
]
},
{
"cell_type": "code",
2021-05-07 01:03:08 +01:00
"execution_count": 1,
2021-05-06 16:19:44 +01:00
"metadata": {},
"outputs": [],
"source": [
"from datetime import datetime\n",
"import os\n",
2021-05-10 00:18:57 +01:00
"import random\n",
2021-05-06 16:19:44 +01:00
"\n",
"from google.cloud import bigquery\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib as mpl\n",
"mpl.rcParams['figure.dpi'] = 120\n",
"import seaborn as sns\n",
"\n",
2021-05-10 00:18:57 +01:00
"from analysis.nn import ensem_classify\n",
2021-05-06 16:19:44 +01:00
"from analysis.net import get_spotnet, get_playlist, track_frame\n",
"from analysis.query import *\n",
"from analysis import spotify_descriptor_headers, float_headers, days_since\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"from sklearn import svm\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.metrics import plot_confusion_matrix\n",
"from sklearn.utils import class_weight\n",
"\n",
"import tensorflow as tf\n",
"\n",
"client = bigquery.Client()\n",
"spotnet = get_spotnet()\n",
"cache = 'query.csv'\n",
"first_day = datetime(year=2017, month=11, day=3)\n",
"sig_max, c_max = 0.5, 20"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Read Scrobble Frame"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"scrobbles = get_query(cache=cache)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Write Scrobble Frame"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"scrobbles.reset_index().to_csv(cache, sep='\\t')"
]
}
],
"metadata": {
"kernelspec": {
2021-05-10 00:18:57 +01:00
"name": "pythonjvsc74a57bd0bce1a3677099e73bf385a0de8ef462673e03f7df0abce93e57e7ca76e8c504a2",
"display_name": "Python 3.8.9 ('.venv': poetry)"
2021-05-06 16:19:44 +01:00
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.9"
2021-05-10 00:18:57 +01:00
},
"metadata": {
"interpreter": {
"hash": "bce1a3677099e73bf385a0de8ef462673e03f7df0abce93e57e7ca76e8c504a2"
}
2021-05-06 16:19:44 +01:00
}
},
"nbformat": 4,
"nbformat_minor": 4
2021-05-10 00:18:57 +01:00
}