listening-analysis/playlist-nn.ipynb

418 lines
93 KiB
Plaintext
Raw Normal View History

2021-05-06 16:19:44 +01:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Playlist Neural Network\n",
"\n",
"Given a set of playlists, can a neural network trained on Spotify track descriptors correctly classify unseen tracks into the playlist they belong to?"
]
},
{
"cell_type": "code",
2021-05-07 01:03:08 +01:00
"execution_count": 3,
2021-05-06 16:19:44 +01:00
"metadata": {},
"outputs": [],
"source": [
"# --- Playlists to classify: keep exactly one `playlist_names` assignment uncommented ---\n",
"\n",
"# super-genres\n",
"# playlist_names = [\"RAP\", \"EDM\", \"ROCK\", \"METAL\", \"JAZZ\", \"POP\"]\n",
"# playlist_names = [\"ALL RAP\", \"EDM\", \"ROCK\", \"METAL\", \"JAZZ\", \"POP\"]\n",
"\n",
"# super-genres without POP (active)\n",
"# playlist_names = [\"RAP\", \"EDM\", \"ROCK\", \"METAL\", \"JAZZ\"]\n",
"playlist_names = [\"ALL RAP\", \"EDM\", \"ROCK\", \"METAL\", \"JAZZ\"]\n",
"\n",
"# EDM playlists\n",
"# playlist_names = [\"DNB\", \"HOUSE\", \"TECHNO\", \"GARAGE\", \"DUBSTEP\", \"BASS\"]\n",
"# rap decades\n",
"# playlist_names = [\"20s rap\", \"10s rap\", \"00s rap\", \"90s rap\", \"80s rap\"]\n",
"# UK/US split\n",
"# playlist_names = [\"UK RAP\", \"US RAP\"]\n",
"# british rap playlists\n",
"# playlist_names = [\"uk rap\", \"grime\", \"drill\", \"afro bash\"]\n",
"# american rap playlists\n",
"# playlist_names = [\"20s rap\", \"10s rap\", \"00s rap\", \"90s rap\", \"80s rap\", \"trap\", \"gangsta rap\", \"industrial rap\", \"weird rap\", \"jazz rap\", \"boom bap\", \"trap metal\"]\n",
"# rock playlists\n",
"# playlist_names = [\"rock\", \"indie\", \"punk\", \"pop rock\", \"bluesy rock\", \"hard rock\", \"chilled rock\", \"emo\", \"pop punk\", \"stoner rock/metal\", \"post-hardcore\", \"melodic hardcore\", \"art rock\", \"post-rock\", \"classic pop punk\", \"90s rock & grunge\", \"90s indie & britpop\", \"psych\"]\n",
"# metal playlists\n",
"# playlist_names = [\"metal\", \"metalcore\", \"mathcore\", \"hardcore\", \"black metal\", \"death metal\", \"doom metal\", \"sludge metal\", \"classic metal\", \"industrial\", \"nu metal\", \"calm metal\", \"thrash metal\"]\n",
"\n",
"# --- Feature columns fed to the network ---\n",
"# wider alternative: headers = float_headers + [\"duration_ms\", \"mode\", \"loudness\", \"tempo\"]\n",
"headers = float_headers\n",
"\n",
"# when True, model.fit receives 'balanced' class weights (each class weighted inversely to its frequency)\n",
"BALANCED_WEIGHTS = True"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Pull and process playlist information:\n",
"\n",
"1. Get live playlist track information from Spotify.\n",
"2. Filter the listening history (scrobbles) down to these tracks.\n",
"\n",
"Tracks without features (no `uri`) are then dropped and duplicates removed, before only the descriptor columns are kept."
]
},
{
"cell_type": "code",
2021-05-07 01:03:08 +01:00
"execution_count": 4,
2021-05-06 16:19:44 +01:00
"metadata": {},
"outputs": [],
"source": [
"def playlist_features(playlist):\n",
"    \"\"\"Return the de-duplicated descriptor features of the scrobbled tracks in `playlist`.\"\"\"\n",
"    # 2) inner join on track/artist keeps only playlist tracks present in the listening history\n",
"    listened = pd.merge(track_frame(playlist.tracks), scrobbles, on=['track', 'artist'])\n",
"    # drop rows without a uri (no features), then keep one row per uri\n",
"    with_uri = listened[pd.notnull(listened[\"uri\"])]\n",
"    distinct = with_uri.drop_duplicates(['uri'])\n",
"    # select only the descriptor float columns\n",
"    return distinct.loc[:, headers]\n",
"\n",
"\n",
"playlists = [get_playlist(name, spotnet) for name in playlist_names]  # 1) live playlist contents\n",
"filtered_playlists = [playlist_features(playlist) for playlist in playlists]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Construct the dataset and its labels (one integer label per playlist), then split it into stratified train and test sets."
]
},
{
"cell_type": "code",
2021-05-07 01:03:08 +01:00
"execution_count": 5,
2021-05-06 16:19:44 +01:00
"metadata": {},
"outputs": [],
"source": [
"dataset = pd.concat(filtered_playlists)\n",
"# one integer label per row: the index of the playlist the track came from\n",
"labels = np.concatenate([np.full(len(plst), idx) for idx, plst in enumerate(filtered_playlists)])\n",
"n_classes = len(filtered_playlists)\n",
"assert len(dataset) == len(labels), \"every row needs exactly one label\"\n",
"\n",
"# stratify: maintains class proportions in test and train set\n",
"data_train, data_test, labels_train, labels_test = train_test_split(dataset, labels, \n",
"                                                                    test_size=0.1, \n",
"                                                                    random_state=None,  # set an int (e.g. 70) for a reproducible split\n",
"                                                                    stratify=labels\n",
"                                                                   )\n",
"\n",
"# Key each weight by the class it was computed for, not by position: a playlist left empty\n",
"# after filtering is missing from np.unique and would otherwise shift every later weight.\n",
"present_classes = np.unique(labels_train)\n",
"balanced = class_weight.compute_class_weight('balanced',\n",
"                                             classes=present_classes,\n",
"                                             y=labels_train)\n",
"# Keras expects keys 0..n_classes-1; classes absent from training never contribute to the loss\n",
"class_weights = {cls: 1.0 for cls in range(n_classes)}\n",
"class_weights.update({int(cls): weight for cls, weight in zip(present_classes, balanced)})\n",
"\n",
"labels_train = tf.one_hot(labels_train, n_classes)\n",
"labels_test = tf.one_hot(labels_test, n_classes)"
]
},
{
"cell_type": "code",
2021-05-07 01:03:08 +01:00
"execution_count": 6,
2021-05-06 16:19:44 +01:00
"metadata": {},
"outputs": [],
"source": [
"def tensorboard_callback(path='tensorboard-logs', prefix=''):\n",
"    \"\"\"Build a TensorBoard callback that logs to a fresh, timestamped run directory.\n",
"\n",
"    The run directory is `path/<prefix><YYYYmmdd-HHMMSS>` (normalised); weight histograms\n",
"    are computed every epoch (histogram_freq=1).\n",
"    \"\"\"\n",
"    run_name = prefix + datetime.now().strftime(\"%Y%m%d-%H%M%S\")\n",
"    log_dir = os.path.normpath(os.path.join(path, run_name))\n",
"    return tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)"
]
},
{
"cell_type": "code",
2021-05-07 01:03:08 +01:00
"execution_count": 7,
2021-05-06 16:19:44 +01:00
"metadata": {},
"outputs": [],
"source": [
"def get_model(hidden_nodes=9,\n",
"              layers=1,\n",
"              classes=None,\n",
"              activation=lambda: 'sigmoid', \n",
"              weight_init=lambda: 'glorot_uniform',\n",
"              input_dim=None):\n",
"    \"\"\"Build an uncompiled fully-connected softmax classifier.\n",
"\n",
"    hidden_nodes: width of every hidden Dense layer.\n",
"    layers: number of hidden Dense layers.\n",
"    classes: width of the softmax output; defaults to the current number of playlists.\n",
"    activation / weight_init: zero-argument factories, called once per layer, so each layer\n",
"        can receive its own activation / initializer object.\n",
"    input_dim: number of input features; defaults to the column count of `data_train`.\n",
"    \"\"\"\n",
"    # resolve defaults at call time so they follow the current data, not the data at definition time\n",
"    if classes is None:\n",
"        classes = len(filtered_playlists)\n",
"    if input_dim is None:\n",
"        input_dim = data_train.shape[1]\n",
"\n",
"    model_layers = [tf.keras.layers.InputLayer(input_shape=(input_dim,), name='Input')]\n",
"    for i in range(layers):\n",
"        model_layers.append(\n",
"            tf.keras.layers.Dense(hidden_nodes, \n",
"                                  activation=activation(), \n",
"                                  kernel_initializer=weight_init(), \n",
"                                  name=f'Hidden{i+1}')\n",
"        )\n",
"\n",
"    model_layers.append(tf.keras.layers.Dense(classes, \n",
"                                              activation='softmax', \n",
"                                              kernel_initializer=weight_init(), \n",
"                                              name='Output'))\n",
"\n",
"    return tf.keras.models.Sequential(model_layers)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Single Model"
]
},
{
"cell_type": "code",
2021-05-07 01:03:08 +01:00
"execution_count": 8,
2021-05-06 16:19:44 +01:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2021-05-07 01:03:08 +01:00
"Model: \"sequential\"\n",
2021-05-06 16:19:44 +01:00
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"Hidden1 (Dense) (None, 64) 512 \n",
"_________________________________________________________________\n",
"Hidden2 (Dense) (None, 64) 4160 \n",
"_________________________________________________________________\n",
"Output (Dense) (None, 5) 325 \n",
"=================================================================\n",
"Total params: 4,997\n",
"Trainable params: 4,997\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"model = get_model(hidden_nodes=64, layers=2)\n",
"\n",
"optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)\n",
"# alternative: optimizer = tf.keras.optimizers.SGD(learning_rate=0.01, momentum=0.9)\n",
"\n",
"model.compile(optimizer=optimizer,\n",
"              loss='categorical_crossentropy',\n",
"              metrics=['accuracy'])\n",
"model.summary()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train"
]
},
{
"cell_type": "code",
2021-05-07 01:03:08 +01:00
"execution_count": 9,
2021-05-06 16:19:44 +01:00
"metadata": {},
"outputs": [],
"source": [
"# class weights are only passed to fit when BALANCED_WEIGHTS is enabled in the config cell\n",
"cw = class_weights if BALANCED_WEIGHTS else None\n",
"\n",
"# 0.11 of the 90% training split is held out for validation, roughly 10% of the full dataset\n",
"history = model.fit(data_train.to_numpy(),\n",
"                    labels_train,\n",
"                    epochs=50,\n",
"                    validation_split=0.11,\n",
"                    class_weight=cw,\n",
"                    callbacks=[tensorboard_callback()],\n",
"                    verbose=0)"
]
},
{
"cell_type": "code",
2021-05-07 01:03:08 +01:00
"execution_count": 10,
2021-05-06 16:19:44 +01:00
"metadata": {},
"outputs": [
{
"data": {
2021-05-07 01:03:08 +01:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAoQAAAG/CAYAAADB4sa8AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAABJ0AAASdAHeZh94AABhqUlEQVR4nO3dd5iU1dnH8e+9CyxlqSJVEERFYsOGSlSwR8USg7GgsZCosSAafTXYS2I0xK6JJsYKRiEqGrsR7AEbooJYAEF6X+oCO/f7x5lhZnZnd2d3tswyv891Pdc+c552Zs+Ue057zN0RERERkdyVV98ZEBEREZH6pYBQREREJMcpIBQRERHJcQoIRURERHKcAkIRERGRHKeAUERERCTHKSAUERERyXEKCEVERERynAJCERERkRyngFBEREQkxykgFBEREclxCghFREREclxWB4RmVmhmN5rZq2a2zMzczM6qwvFtzOwhM1tsZmvMbLyZ7VmLWRYRERFpcLI6IATaA9cBfYDPq3KgmeUBLwGnAfcB/wd0ACaY2Q41nE8RERGRBqtRfWegEvOBzu6+wMz2Bj6qwrGDgf7ASe4+FsDMngG+AW4kBIoiIiIiOS+rawjdvdjdF1Tz8MHAQuDZhPMtBp4BjjezghrIooiIiEiDl+01hJnYA/jU3SOl0icB5wI7Al+kOtDMOgBbl0oujB7zJbChZrMqIiIiUqOaAN2At919ZWU7b8kBYWfgnRTp86N/u1BOQAhcAFxfG5kSERERqUPHAy9UttOWHBA2A4pTpK9P2F6eB4AxpdJ2AsaOHj2a7bbbrgayV9a6dev44osv2HXXXWnWrKLsSX1Q+WQvlU12U/lkL5VNdsukfGbMmMFpp50GMCed/bfkgHAdkKqfYNOE7Sm5+yJgUWKamQGw2267sfPOO9dQFpMVFRWxcuVK9txzT1q1alUr15DqU/lkL5VNdlP5ZC+VTXbLpHwKCwtjq2l1c8vqQSUZmk9oNi4tljavDvMiIiIikrW25IBwMrBndD7CRPsCawnTz4iIiIjkvC0iIDSzzma2k5k1TkgeC3QETkzYrz1wEvCiu6fqXygiIiKSc7K+D6GZXQS0IYwKBjjWzLaJrt8bHUp9K3Am0BOYFd02Fvgf8IiZ/QRYQhg9nI9GEIuIiIhslvUBIXA5sG3C4xOJ1/o9CaScW8fdS8zsaODPwDDCqOKPgLPcfXrtZVdERESkYcn6gNDde6Sxz1nAWSnSlwO/ji4iIiJZKRKJsHDhQoqLi4lESt9PofZs3LiRdu3aMX/+fJYsWVJn15X0lC6fvLw8CgoK6NixI3l5NdvrL+sDQhERkS1ZJBJh9uzZrFu3jvz8fPLz8zdPdVbbGjVqxNZbb02jRgoHslFi+bg7GzZsYN26dRQXF9O9e/caDQr1ChAREalHCxcuZN26dbRr144OHTrUWTAIUFJSwqpVq2jZsiX5+fl1dl1JT+nycXcWLVrEsmXLWLhwIZ07p5pdr3q2iFHGIiIiDVVxcTH5+fl1HgxKw2NmdOjQgfz8fIqLa3ayFAWEIiIi9SgSidRpM7E0bGZGfn5+jfc1VUAoIiJSzxQMSlXUxutFAaGIiIhIjlNAKCIiIpLjFBCKiIhIg3XWWWfRo0eP+s5Gg6eAUERERGqcmaW1TJgwob6zKmgeQhEREakFTzzxRNLjxx9/nDfeeKNMep8+fTK6zt///vc6vbvLlkoBoYiIiNS4008/Penx//73P954440y6aWtXbuW5s2bp32dxo0bVyt/kkxNxiIiIlIvBg4cyC677MInn3zCQQcdRPPmzRkxYgQA48aN45hjjqFLly4UFBTQq1cvbr75ZkpKSpLOUboP4axZszAzRo4cyUMPPUSvXr0oKChgn3324aOPPqrLp9egqIZQRERE6s3SpUs56qijOOWUUzj99NPp2LEjAI8++iiFhYVcdtllFBYW8tZbb3HddddRVFTEn//850
rPO3r0aFatWsV5552HmXH77bdz4oknMmPGDNUqpqCAUEREJAsNHw6TJ9fuNdzzKCkpJD8/j8rmOu7bF+66q+bzsGDBAv72t79x3nnnJaWPHj2aZs2abX58/vnnc/755/PAAw9wyy23UFBQUOF5Z8+ezbfffkvbtm0B6N27N8cffzyvvfYagwYNqvkn0sApIBQREclCkyfD22/X9lWM+g4FCgoKOPvss8ukJwaDq1atori4mAMPPJAHH3yQr7/+mt13373C85588smbg0GAAw88EIAZM2bUUM63LAoIRUREslDfvrV/DXenpKQkrXsp11Z+unbtSpMmTcqkf/XVV1xzzTW89dZbFBUVJW1buXJlpeft3r170uNYcLh8+fIMcrvlUkAoIiKShWqjeba0kpIIq1atpmXLluTn59f+BVNIrAmMWbFiBQMGDKBVq1bcdNNN9OrVi6ZNm/Lpp59y5ZVXpjXNTHnPx90zzvOWSAGhiIiIZJUJEyawdOlSnn32WQ466KDN6TNnzqzHXG3ZNO2MiIiIZJVY7V5ibd6GDRt44IEH6itLWzzVEIqIiEhW6d+/P23btuXMM89k2LBhmBlPPPGEmntrkWoIRUREJKtstdVW/Oc//6Fz585cc801jBw5ksMPP5zbb7+9vrO2xVINoYiIiNS6++67j/vuuy8pbcKECeXu379/fz788MMy6aVrCR999NGkxz169Ci3JlE1jOVTDaGIiIhIjlNAKCIiIpLjFBCKiIiI5DgFhCIiIiI5TgGhiIiISI5TQCgiIiKS4xQQioiIiOQ4BYQiIiIiOU4BoYiIiEiOU0AoIiIikuMUEIqIiIjkOAWEIiIikvVmzZqFmSXdu/iGG27AzNI63sy44YYbajRPAwcOZODAgTV6zvqigFBERERq3HHHHUfz5s1ZtWpVufsMGTKEJk2asHTp0jrMWdVMnTqVG264gVmzZtV3VmqVAkIRERGpcUOGDGHdunU899xzKbevXbuWcePG8bOf/YytttqqWte45pprWLduXSbZrNTUqVO58cYbUwaEr7/+Oq+//nqtXr+uKCAUERGRGnfcccfRsmVLRo8enXL7uHHjWLNmDUOGDKn2NRo1akTTpk2rfXymmjRpQpMmTert+jVJAaGIiIjUuGbNmnHiiSfy3//+l0WLFpXZPnr0aFq2bMkBBxzA5Zdfzq677kphYSGtWrXiqKOO4vPPP6/0Gqn6EBYXF3PppZey9dZb07JlS4477jh+/PHHMsf+8MMPXHDBBfTu3ZtmzZqx1VZbcdJJJyXVBD766KOcdNJJABx88MGYGWbGhAkTgNR9CBctWsTQoUPp2LEjTZs2Zffdd+exxx5L2ifWH3LkyJE89NBD9OrVi4KCAvbZZx8++uijSp93bWhUL1cVERGRLd6QIUN47LHHeOaZZ7jooos2py9btozXXnuNU089lfnz5/P8889z0kkn0bNnTxYuXMiDDz7IgAEDmDp1Kl26dKnSNX/961/z5JNPctppp9G/f3/eeustjjnmmDL7ffTRR3zwwQeccsopbLPNNsyaNYu//vWvDBw4kKlTp9K8eXMOOugghg0bxj333MOIESPo06cPwOa/pa1bt46BAwfy3XffcdFFF9GzZ0/GjBnDWWedxYoVK7jkkkuS9h89ejSrVq3ivPPOw8y4/fbbOfHEE5kxYwaNGzeu0vPOlAJCERERqRWHHHIInTt3ZvTo0UkB4ZgxY9i4cSNDhgxh11135ZtvviEvL95oecYZZ7DTTjvx8MMPc+2116Z9vc8//5wnn3ySCy64gPvvvx+ACy+8kCFDhjBlypSkfY855hgGDx6clHbsscey//778+9//5szzjiD7bbbjgMPPJB77rmHww8/vNIRxQ899BDTpk3jySef3NwUfv755zNgwACuueYazjnnHFq2bLl5/9mzZ/Ptt9/Stm1bAHr37s3xxx/Pa6+9xqBBg9J+3jVBAaGIiEg2Gj4cJk+u1UvkuV
NYUkJefj5UNn1L375w111VOn9+fj6nnHIKd955J7NmzaJHjx5AqBnr2LEjhx56KPn5+Zv3LykpYcWKFRQWFtK7d28
2021-05-06 16:19:44 +01:00
"text/plain": [
"<Figure size 720x480 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# per-epoch accuracy on the training data vs. the validation split held out by model.fit\n",
"fig, ax = plt.subplots()\n",
"ax.plot(history.history[\"accuracy\"], label=\"Train\", c=(0, 0, 1))\n",
"ax.plot(history.history[\"val_accuracy\"], label=\"Validation\", c=(1, 0, 0))\n",
"\n",
"ax.set(title=\"Training vs. validation accuracy\",\n",
"       xlabel=\"Epochs\",\n",
"       ylabel=\"Accuracy\",\n",
"       ylim=(0, 1))\n",
"ax.grid()\n",
"ax.legend()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Test\n",
"\n",
"Evaluate on the held-out test set; the output below is `[loss, accuracy]`."
]
},
{
"cell_type": "code",
2021-05-07 01:03:08 +01:00
"execution_count": 11,
2021-05-06 16:19:44 +01:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2021-05-07 01:03:08 +01:00
"10/10 [==============================] - 0s 2ms/step - loss: 0.6832 - accuracy: 0.7634\n"
2021-05-06 16:19:44 +01:00
]
},
{
"data": {
"text/plain": [
2021-05-07 01:03:08 +01:00
"[0.6831779479980469, 0.7634069323539734]"
2021-05-06 16:19:44 +01:00
]
},
2021-05-07 01:03:08 +01:00
"execution_count": 11,
2021-05-06 16:19:44 +01:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# returns [test loss, test accuracy] (the loss and metric passed to model.compile)\n",
"model.evaluate(x=data_test.to_numpy(), y=labels_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Get raw predictions for the test data and plot a row-normalised confusion matrix (rows: true playlist, columns: predicted playlist)."
]
},
{
"cell_type": "code",
2021-05-07 01:03:08 +01:00
"execution_count": 12,
2021-05-06 16:19:44 +01:00
"metadata": {},
"outputs": [
{
"data": {
2021-05-07 01:03:08 +01:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAkgAAAGgCAYAAABR4ZjdAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAABJ0AAASdAHeZh94AACMAUlEQVR4nOzdd3wURRvA8d+kBxI6IaEjGBCl96KAgFgQpEiVJoqoIEr1FbsiqDRRERBEUYrSm40iSBECSFcJvSYkdALpmfePvVxyubuQ3F0qz9fPfmJmZ3ZnN8fdczOzM0prjRBCCCGESOGW0xUQQgghhMhtJEASQgghhEhDAiQhhBBCiDQkQBJCCCGESEMCJCGEEEKINCRAEkIIIYRIQwIkIYQQQog0JEASQgghhEhDAiQhhBBCiDQkQBJCCCGESEMCJCGEEEKINDxyugIZlch8WTQui1X2/yGnq5DvXYjZm9NVuCskJF7P6Srke1on5HQV8j2t41V2ns9Vn7Pu9M7WemeVPBMgCSGEECLrJCUluuQ47vmkbyqfXIYQQgghhOtIC5IQQgghpNs0DQmQhBBCCIHWruliyy+ki00IIYQQIg1pQRJCCCEESdLFZkECJCGEEELIGKQ0JEASQgghhARIacgYJCGEEEKINKQFSQghhBDoJGlBSk0CJCGEEEKAdLFZkC42IYQQQog0pAVJCCGEEDJIOw0JkIQQQggBSfE5XYNcRbrYhBBCCCHSkABJCCGEEGid4JLNUUopb6XUx0qpC0qpaKXUTqVU2wyWbaOU+kMpdUkpdU0pFaKU6uNwZZAASQghhBAASQmu2Rz3LTAcmA8MAxKBn5VSzdMrpJTqAPwOeAHvAmOBaGCeUuo1RysjY5CEEEII4Wxw4xSlVEOgBzBKaz3RlDYPOAR8AjRNp/gQIAx4WGsdayo7E/gP6A9McaRO0oIkhBBCiJzWFaPFaFZygtY6BpgDNFFKlUunbCHganJwZCqbAFzCaElyiLQgCSGEEMJlE0UqpQKAkjZ2RWqtI+wUqwOEaq1vpEkPMf2sDZy1U3YTMEYp9QHwHaCBXkB9oFvGa25JAiQhhBBCoFzXxfYS8I6N9PcwxgjZEoTRTZZWclrpdM73AVAJY+zRm6a020AXrfXKO1XWHgmQhBBCCOFK04HFNtIj0ynjC8TaSI9Jtd+eWCAUWAIsA9yBQcAPSqm2Wusdd6yxDRIgCSGEEMJlg7RN3Wj2utLsiQa8baT7pNpvzxdAY6Cu1joJQCn1E3AY+AxolMm6ADJIWwghhBCQ04/5h2F0s6WVnHbBViGllBcwEFibHBwBaK3jgV+A+qY8mSYtSEIIIYRA5exabPuAVkqpQmkGajdKtd+W4hixjLuNfZ4YDUG29t2RtCAJIYQQIqctIWXsEGDMrA0MAHZqrc+a0sorpaqlKhcBXAM6pW4pUkr5AU8C/2mtHXrUX1qQhBBCCAFJiTl2aq31TqXUYmC8aZqAY0A/oCJGF1qyeUALQJnKJSqlJgIfAjtMk0u6m8qUBZ5xtE7SgmRHXFwCkz5dT4vmk6lT8yO6Pz2b7duOZ6js9u0n6N/nO5o2+pRG9T+me9fZrFpxwCpf9arv29y+nrXV1ZeTp3h5efL6+8+yM/QH/otYwYqNU2jeqs4dy91zbxneGj+IpesncSRyJadu/kLZ8gHZUOPcwcvLi48+epNTp/dx/cZJtm77mdatH8pQ2dKlA1mwYBYRkUe4dPkoS5d+S6VK5S3y+Pj4MHPWZPbu3UTkpVCuXD3O7j0bGDL0OTw80v+u9dWMicTFh7N8xfcOX19u4eXlxYQJH3Lu3Alu3brCX3/9SZs2D2eobOnSpVm06AeuXAnj2rWLLF/+E5UqVbTKN3jw8/z443xOnQolKSmab76ZZX0wIDAwkPHjP2DDhl+5fj2CpKRoWrR40JnLy1eMv9VHnD9/mtu3b7BjxzbatGmd09XKtVRSgks2J/QFpgJ9gG
kYXWTttdZ/pldIaz0O6A3EY0wv8AFwA+iqtZ7vaGWkBcmON15fye+//Uufvo2oULEYK5bvZ/Cghcz9ri/16pe3W27jhiMMfflHatcuy8tDW6CU4tdfDvP6mBVcvXabfv0bW+Rv2uweOnSsaZF2X/XALLmmvGLijOE89lRzvpm+glPHL9C1dxvmLn2fnk+8zu6/DtstV7fhffR/sQNH/zvDsSNnub9W5Wysdc6bM+czOndpz7RpX3Ps2An69u3OqtXzadu2C9u3hdgtV7BgAdatW0qhwoX4eMI04hPieeWVQazfsJwG9dtw5cpVAHx9fahevSq//rqBU6fPkpSURJMmDZg48X0aNqhL374v2Tx+3Xq16Nu3O9HRDk9om6vMnfs1Xbt24rPPvuDo0WP069eHtWtX8PDDj7Jt23a75QoWLMjGjb9SuHAhxo//lPj4eF59dSibNq2jTp1GXLlyxZx39OgR+Pv7ERKym6Ag++8HVavey5gxIwkNPcrBg4dp2rSx3bx3o2+/nUPXrl2YOnUaR48eo3//vvz882patWrLtm3bcrp6Ig3TzNmjTJu9PC3tpC8AFriyPkpr7crjZZlE5mdbRQ8cOE+Pp+cwcnQbnh1oLP8SG5tAh/ZfUbx4QRYsetZu2eee/YFjRyP5fcNQvLyM+DMhIYknHvuSAr5eLF/1gjlv9arv06t3A958+7GsvaAMquz/Q05XgVr1glm56TPGjZ3N19OWAuDt7clvO2dw+dI1urQZYbds4aJ+JMQncisqmudf6cLYcc/R/P5+nDuT2adNs86FmL1Zctz6DeqwffsvjBn9HlOmfAWAt7c3e/dtIjLyEi0eetJu2REjXmb8hLdo0uRR9uzeB0DVqlXYu28TkyZ+yVtvjU/33FOmjuPllwdSrmwNLl60nuZk85+r+e+/o7Rq9SCHD/9Hp6ecWmA7QxISr2fJcRs0qM/OnVsYNep/TJo0FTDu88GDe4iIiKR581Z2y44aNZyPPx5Hw4bN2b17DwBVqwZz8OAePv10MmPHpsyrV758ec6cOQPAjRuRLFmynGefHWR1TD8/Pzw9Pbl69SpdunRi8eIFtGr1CJs3b3HhVdvmzKrt2aFBgwaEhGxn5MjRTJpkLMXl7e3NoUP7iIiIpFmzjLWu5iSt41V2nu/Wqa4u+ZwtWHFJttY7qzjUxaaUGqSU+lcpFaOUOq+UmmIaTJUv/P7rP7i7K7p1r2dO8/b2oEvXOuzbe46wMPtvvlFRsRQq7GMOjgA8PNwoWrQA3j62G+xiYuKJjc3dbzbZ5bGnmpOQkMjCub+Y02Jj4/np+9+o16g6QWVK2C17/WoUt6LyRytFZnXp3J6EhARmz07pwoqNjeXbuQto0qQBZcvan4S2c5f27Nq11xwcARw5coyNG7fQpWuHO5779Clj9v8iRQpb7Xvmmae5//5qvH2HICuv6Nq1EwkJCcyaNcecFhsbyzfffEvTpo0pW7as3bJdunQiJGS3OTgCOHIklA0b/uDpp7tY5E0Oju4kKiqKq1evZvIq7g5du3Y2/a1mm9NiY2OZM2cuTZs2SfdvddfK2cf8c51MB0hKqaeAGRiDnw6YjvGKKS1f+PffcCpULI6fn2XMV6Om8SHz37/hdss2bFiBY0cjmTb1D06fvsKZM1f46ss/OXzoAgOfs16MePnyfdSrPZ46NT+i/ePTWbP6oGsvJo+5v2ZlTh47T9TN2xbp+3aHAlC95t3VbZZRtWo/wNHQE9y8GWWRvmuX0WJVq9b9NssppahR4z727NlvtW/3rr1UqVIJP7+CFumenp4UL16MsmVL07HjY7w2/EVOnTrLsWMnLfL5+RVk3Edv8vGEz2y2LOVFtWvXIjT0KDdv3rRIDwnZbdpf01YxlFLUrPkAe/bssdq3a9duqlSpjJ+fn+srfBerU6c2oaGhNv5WuwDjbyksqaREl2z5hSNjkIYDx4HmWuuLSikP4Hugt1JqmI2F5vKcyMgoSpa0fr
MqWdIfgIiIKKt9yQa/9BDnzl1j5owtzPjKaOb29fVk6rRutG5T1SJvnTplefSx+ylTtggRETdZuGA3o0cuJ+pmLD1
2021-05-06 16:19:44 +01:00
"text/plain": [
"<Figure size 720x480 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"predictions = model(data_test.to_numpy())\n",
"\n",
"# rows: true playlist, columns: predicted playlist (vectorised argmax over the class axis)\n",
"conf = tf.math.confusion_matrix(tf.math.argmax(labels_test, axis=1),\n",
"                                tf.math.argmax(predictions, axis=1),\n",
"                                num_classes=len(filtered_playlists)).numpy()\n",
"\n",
"# normalise each row to per-class recall; a class with no test samples stays 0 instead of NaN\n",
"row_totals = conf.sum(axis=1, keepdims=True)\n",
"normalised_conf = np.divide(conf, row_totals,\n",
"                            out=np.zeros(conf.shape, dtype=float),\n",
"                            where=row_totals != 0)\n",
"\n",
"ax = sns.heatmap(normalised_conf, \n",
"                 annot=True, \n",
"                 xticklabels=playlist_names, yticklabels=playlist_names, \n",
"                 cmap='inferno')\n",
"ax.set(xlabel=\"Predicted playlist\", ylabel=\"True playlist\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
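"# Imports & Setup\n",
"\n",
"Run these cells first (execution counts 1 and 2): every cell above depends on the imports, `spotnet` and `scrobbles` defined here."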
]
},
{
"cell_type": "code",
2021-05-07 01:03:08 +01:00
"execution_count": 1,
2021-05-06 16:19:44 +01:00
"metadata": {},
"outputs": [],
"source": [
"from datetime import datetime\n",
"import os\n",
"\n",
"from google.cloud import bigquery\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib as mpl\n",
"mpl.rcParams['figure.dpi'] = 120\n",
"import seaborn as sns\n",
"\n",
"from analysis.net import get_spotnet, get_playlist, track_frame\n",
"from analysis.query import *  # wildcard: supplies get_query, used to load the scrobble frame\n",
"from analysis import spotify_descriptor_headers, float_headers, days_since\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"from sklearn import svm\n",
"from sklearn.model_selection import train_test_split\n",
"# sklearn.metrics.plot_confusion_matrix is intentionally not imported: it was removed in\n",
"# scikit-learn 1.2 and the confusion matrix is built with tf.math.confusion_matrix + seaborn\n",
"from sklearn.utils import class_weight\n",
"\n",
"import tensorflow as tf\n",
"\n",
"client = bigquery.Client()\n",
"spotnet = get_spotnet()\n",
"cache = 'query.csv'  # scrobble cache path: passed to get_query and rewritten by the write cell\n",
"first_day = datetime(year=2017, month=11, day=3)\n",
"sig_max, c_max = 0.5, 20"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Read Scrobble Frame"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# load the scrobble (listening history) frame via get_query, passing the local cache path\n",
"scrobbles = get_query(cache=cache)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Write Scrobble Frame"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# persist the scrobble frame (its index moved into a regular column) as tab-separated text at the cache path\n",
"scrobble_table = scrobbles.reset_index()\n",
"scrobble_table.to_csv(cache, sep='\\t')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.9"
}
},
"nbformat": 4,
"nbformat_minor": 4
}