2021-05-06 16:19:44 +01:00
|
|
|
{
|
|
|
|
"cells": [
|
|
|
|
{
|
|
|
|
"cell_type": "markdown",
|
|
|
|
"metadata": {},
|
|
|
|
"source": [
|
|
|
|
"# Playlist Neural Network\n",
|
|
|
|
"\n",
|
|
|
|
"Given a list of playlists, can unknown tracks be correctly classified?"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-05-07 01:03:08 +01:00
|
|
|
"execution_count": 3,
|
2021-05-06 16:19:44 +01:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
|
"# playlist_names = [\"RAP\", \"EDM\", \"ROCK\", \"METAL\", \"JAZZ\", \"POP\"] # super-genres\n",
|
|
|
|
"# playlist_names = [\"ALL RAP\", \"EDM\", \"ROCK\", \"METAL\", \"JAZZ\", \"POP\"] # super-genres\n",
|
|
|
|
"# playlist_names = [\"RAP\", \"EDM\", \"ROCK\", \"METAL\", \"JAZZ\"] # super-genres without POP\n",
|
|
|
|
"playlist_names = [\"ALL RAP\", \"EDM\", \"ROCK\", \"METAL\", \"JAZZ\"] # super-genres without POP\n",
|
|
|
|
"# playlist_names = [\"DNB\", \"HOUSE\", \"TECHNO\", \"GARAGE\", \"DUBSTEP\", \"BASS\"] # EDM playlists\n",
|
|
|
|
"# playlist_names = [\"20s rap\", \"10s rap\", \"00s rap\", \"90s rap\", \"80s rap\"] # rap decades\n",
|
|
|
|
"# playlist_names = [\"UK RAP\", \"US RAP\"] # UK/US split\n",
|
|
|
|
"# playlist_names = [\"uk rap\", \"grime\", \"drill\", \"afro bash\"] # british rap playlists\n",
|
|
|
|
"# playlist_names = [\"20s rap\", \"10s rap\", \"00s rap\", \"90s rap\", \"80s rap\", \"trap\", \"gangsta rap\", \"industrial rap\", \"weird rap\", \"jazz rap\", \"boom bap\", \"trap metal\"] # american rap playlists\n",
|
|
|
|
"# playlist_names = [\"rock\", \"indie\", \"punk\", \"pop rock\", \"bluesy rock\", \"hard rock\", \"chilled rock\", \"emo\", \"pop punk\", \"stoner rock/metal\", \"post-hardcore\", \"melodic hardcore\", \"art rock\", \"post-rock\", \"classic pop punk\", \"90s rock & grunge\", \"90s indie & britpop\", \"psych\"] # rock playlists\n",
|
|
|
|
"# playlist_names = [\"metal\", \"metalcore\", \"mathcore\", \"hardcore\", \"black metal\", \"death metal\", \"doom metal\", \"sludge metal\", \"classic metal\", \"industrial\", \"nu metal\", \"calm metal\", \"thrash metal\"] # metal playlists\n",
|
|
|
|
"\n",
|
|
|
|
"# headers = float_headers + [\"duration_ms\", \"mode\", \"loudness\", \"tempo\"]\n",
|
|
|
|
"headers = float_headers\n",
|
|
|
|
"\n",
|
|
|
|
"BALANCED_WEIGHTS = True"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "markdown",
|
|
|
|
"metadata": {},
|
|
|
|
"source": [
|
|
|
|
"Pull and process playlist information.\n",
|
|
|
|
"\n",
|
|
|
|
"1. Get live playlist track information from spotify\n",
|
|
|
|
"2. Filter listening history for these tracks\n",
|
|
|
|
"\n",
|
|
|
|
"Filter out tracks without features and drop duplicates before taking only the descriptor parameters"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-05-07 01:03:08 +01:00
|
|
|
"execution_count": 4,
|
2021-05-06 16:19:44 +01:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
|
"playlists = [get_playlist(i, spotnet) for i in playlist_names] # 1)\n",
|
|
|
|
"\n",
|
|
|
|
"# filter playlists by join with playlist track/artist names\n",
|
|
|
|
"filtered_playlists = [pd.merge(track_frame(i.tracks), scrobbles, on=['track', 'artist']) for i in playlists] # 2)\n",
|
|
|
|
"\n",
|
|
|
|
"filtered_playlists = [i[pd.notnull(i[\"uri\"])] for i in filtered_playlists]\n",
|
|
|
|
"# distinct on uri\n",
|
|
|
|
"filtered_playlists = [i.drop_duplicates(['uri']) for i in filtered_playlists]\n",
|
|
|
|
"# select only descriptor float columns\n",
|
|
|
|
"filtered_playlists = [i.loc[:, headers] for i in filtered_playlists]"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "markdown",
|
|
|
|
"metadata": {},
|
|
|
|
"source": [
|
|
|
|
"Construct the dataset with associated labels before splitting into a train and test set."
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-05-07 01:03:08 +01:00
|
|
|
"execution_count": 5,
|
2021-05-06 16:19:44 +01:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
|
"dataset = pd.concat(filtered_playlists)\n",
|
|
|
|
"labels = [np.full(len(plst), idx) for idx, plst in enumerate(filtered_playlists)]\n",
|
|
|
|
"labels = np.concatenate(labels)\n",
|
|
|
|
"\n",
|
|
|
|
"# stratify: maintains class proportions in test and train set\n",
|
|
|
|
"data_train, data_test, labels_train, labels_test = train_test_split(dataset, labels, \n",
|
|
|
|
" test_size=0.1, \n",
|
|
|
|
"# random_state=70, \n",
|
|
|
|
" stratify=labels\n",
|
|
|
|
" )\n",
|
|
|
|
"\n",
|
|
|
|
"class_weights = class_weight.compute_class_weight('balanced',\n",
|
|
|
|
" classes=np.unique(labels_train),\n",
|
|
|
|
" y=labels_train)\n",
|
|
|
|
"class_weights = {i: j for i, j in zip(range(len(filtered_playlists)), class_weights)}\n",
|
|
|
|
"\n",
|
|
|
|
"labels_train = tf.one_hot(labels_train, len(filtered_playlists))\n",
|
|
|
|
"labels_test = tf.one_hot(labels_test, len(filtered_playlists))"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-05-07 01:03:08 +01:00
|
|
|
"execution_count": 6,
|
2021-05-06 16:19:44 +01:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
|
"def tensorboard_callback(path='tensorboard-logs', prefix=''):\n",
|
|
|
|
" return tf.keras.callbacks.TensorBoard(\n",
|
|
|
|
" log_dir=os.path.normpath(os.path.join(path, prefix + datetime.now().strftime(\"%Y%m%d-%H%M%S\"))), histogram_freq=1\n",
|
|
|
|
" )"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-05-07 01:03:08 +01:00
|
|
|
"execution_count": 7,
|
2021-05-06 16:19:44 +01:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
|
"def get_model(hidden_nodes=9,\n",
|
|
|
|
" layers=1,\n",
|
|
|
|
" classes=len(filtered_playlists),\n",
|
|
|
|
" activation=lambda: 'sigmoid', \n",
|
|
|
|
" weight_init=lambda: 'glorot_uniform'):\n",
|
|
|
|
" l = [tf.keras.layers.InputLayer(input_shape=data_train.to_numpy()[0].shape, name='Input')]\n",
|
|
|
|
" \n",
|
|
|
|
" for i in range(layers):\n",
|
|
|
|
" l.append(\n",
|
|
|
|
" tf.keras.layers.Dense(hidden_nodes, \n",
|
|
|
|
" activation=activation(), \n",
|
|
|
|
" kernel_initializer=weight_init(), \n",
|
|
|
|
" name=f'Hidden{i+1}')\n",
|
|
|
|
" )\n",
|
|
|
|
" \n",
|
|
|
|
" l.append(tf.keras.layers.Dense(classes, \n",
|
|
|
|
" activation='softmax', \n",
|
|
|
|
" kernel_initializer=weight_init(), \n",
|
|
|
|
" name='Output'))\n",
|
|
|
|
" \n",
|
|
|
|
" model = tf.keras.models.Sequential(l)\n",
|
|
|
|
" return model"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "markdown",
|
|
|
|
"metadata": {},
|
|
|
|
"source": [
|
|
|
|
"# Single Model"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-05-07 01:03:08 +01:00
|
|
|
"execution_count": 8,
|
2021-05-06 16:19:44 +01:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [
|
|
|
|
{
|
|
|
|
"name": "stdout",
|
|
|
|
"output_type": "stream",
|
|
|
|
"text": [
|
2021-05-07 01:03:08 +01:00
|
|
|
"Model: \"sequential\"\n",
|
2021-05-06 16:19:44 +01:00
|
|
|
"_________________________________________________________________\n",
|
|
|
|
"Layer (type) Output Shape Param # \n",
|
|
|
|
"=================================================================\n",
|
|
|
|
"Hidden1 (Dense) (None, 64) 512 \n",
|
|
|
|
"_________________________________________________________________\n",
|
|
|
|
"Hidden2 (Dense) (None, 64) 4160 \n",
|
|
|
|
"_________________________________________________________________\n",
|
|
|
|
"Output (Dense) (None, 5) 325 \n",
|
|
|
|
"=================================================================\n",
|
|
|
|
"Total params: 4,997\n",
|
|
|
|
"Trainable params: 4,997\n",
|
|
|
|
"Non-trainable params: 0\n",
|
|
|
|
"_________________________________________________________________\n"
|
|
|
|
]
|
|
|
|
}
|
|
|
|
],
|
|
|
|
"source": [
|
|
|
|
"model = get_model(hidden_nodes=64, layers=2)\n",
|
|
|
|
"\n",
|
|
|
|
"model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.01), \n",
|
|
|
|
"# optimizer=tf.keras.optimizers.SGD(learning_rate=0.01, momentum=0.9),\n",
|
|
|
|
" loss='categorical_crossentropy', \n",
|
|
|
|
" metrics=['accuracy'])\n",
|
|
|
|
"model.summary()"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "markdown",
|
|
|
|
"metadata": {},
|
|
|
|
"source": [
|
|
|
|
"## Train"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-05-07 01:03:08 +01:00
|
|
|
"execution_count": 9,
|
2021-05-06 16:19:44 +01:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
|
"if BALANCED_WEIGHTS:\n",
|
|
|
|
" cw = class_weights\n",
|
|
|
|
"else:\n",
|
|
|
|
" cw = None\n",
|
|
|
|
"history = model.fit(data_train.to_numpy(), labels_train, \n",
|
|
|
|
" callbacks=[tensorboard_callback()], \n",
|
|
|
|
" validation_split=0.11,\n",
|
|
|
|
" verbose=0,\n",
|
|
|
|
" class_weight=cw,\n",
|
|
|
|
" epochs=50)"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-05-07 01:03:08 +01:00
|
|
|
"execution_count": 10,
|
2021-05-06 16:19:44 +01:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [
|
|
|
|
{
|
|
|
|
"data": {
|
2021-05-07 01:03:08 +01:00
|
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAoQAAAG/CAYAAADB4sa8AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAABJ0AAASdAHeZh94AABhqUlEQVR4nO3dd5iU1dnH8e+9CyxlqSJVEERFYsOGSlSwR8USg7GgsZCosSAafTXYS2I0xK6JJsYKRiEqGrsR7AEbooJYAEF6X+oCO/f7x5lhZnZnd2d3tswyv891Pdc+c552Zs+Ue057zN0RERERkdyVV98ZEBEREZH6pYBQREREJMcpIBQRERHJcQoIRURERHKcAkIRERGRHKeAUERERCTHKSAUERERyXEKCEVERERynAJCERERkRyngFBEREQkxykgFBEREclxCghFREREclxWB4RmVmhmN5rZq2a2zMzczM6qwvFtzOwhM1tsZmvMbLyZ7VmLWRYRERFpcLI6IATaA9cBfYDPq3KgmeUBLwGnAfcB/wd0ACaY2Q41nE8RERGRBqtRfWegEvOBzu6+wMz2Bj6qwrGDgf7ASe4+FsDMngG+AW4kBIoiIiIiOS+rawjdvdjdF1Tz8MHAQuDZhPMtBp4BjjezghrIooiIiEiDl+01hJnYA/jU3SOl0icB5wI7Al+kOtDMOgBbl0oujB7zJbChZrMqIiIiUqOaAN2At919ZWU7b8kBYWfgnRTp86N/u1BOQAhcAFxfG5kSERERqUPHAy9UttOWHBA2A4pTpK9P2F6eB4AxpdJ2AsaOHj2a7bbbrgayV9a6dev44osv2HXXXWnWrKLsSX1Q+WQvlU12U/lkL5VNdsukfGbMmMFpp50GMCed/bfkgHAdkKqfYNOE7Sm5+yJgUWKamQGw2267sfPOO9dQFpMVFRWxcuVK9txzT1q1alUr15DqU/lkL5VNdlP5ZC+VTXbLpHwKCwtjq2l1c8vqQSUZmk9oNi4tljavDvMiIiIikrW25IBwMrBndD7CRPsCawnTz4iIiIjkvC0iIDSzzma2k5k1TkgeC3QETkzYrz1wEvCiu6fqXygiIiKSc7K+D6GZXQS0IYwKBjjWzLaJrt8bHUp9K3Am0BOYFd02Fvgf8IiZ/QRYQhg9nI9GEIuIiIhslvUBIXA5sG3C4xOJ1/o9CaScW8fdS8zsaODPwDDCqOKPgLPcfXrtZVdERESkYcn6gNDde6Sxz1nAWSnSlwO/ji4iIiJZKRKJsHDhQoqLi4lESt9PofZs3LiRdu3aMX/+fJYsWVJn15X0lC6fvLw8CgoK6NixI3l5NdvrL+sDQhERkS1ZJBJh9uzZrFu3jvz8fPLz8zdPdVbbGjVqxNZbb02jRgoHslFi+bg7GzZsYN26dRQXF9O9e/caDQr1ChAREalHCxcuZN26dbRr144OHTrUWTAIUFJSwqpVq2jZsiX5+fl1dl1JT+nycXcWLVrEsmXLWLhwIZ07p5pdr3q2iFHGIiIiDVVxcTH5+fl1HgxKw2NmdOjQgfz8fIqLa3ayFAWEIiIi9SgSidRpM7E0bGZGfn5+jfc1VUAoIiJSzxQMSlXUxutFAaGIiIhIjlNAKCIiIpLjFBCKiIhIg3XWWWfRo0eP+s5Gg6eAUERERGqcmaW1TJgwob6zKmgeQhEREakFTzzxRNLjxx9/nDfeeKNMep8+fTK6zt///vc6vbvLlkoBoYiIiNS4008/Penx//73P954440y6aWtXbuW5s2bp32dxo0bVyt/kkxNxiIiIlIvBg4cyC677MInn3zCQQcdRPPmzRkxYgQA48aN45hjjqFLly4UFBTQq1cvbr75ZkpKSpLOUboP4axZszAzRo4cyUMPPUSvXr0oKChgn3324aOPPqrLp9egqIZQRERE6s3SpUs56qijOOWUUzj99NPp2LEjAI8++iiFhYVcdtllFBYW8tZbb3HddddRVFTEn//850rPO3r0aFatWsV5552HmXH77bdz4oknMmPGDNUqpqCAUEREJAsNHw6TJ9fuNdzzKCkpJD8/j8rmOu7bF+66q+bzsGDBAv72t79x3nnnJaWPHj2aZs2abX58/vnnc/755/PAAw9wyy23UFBQUOF5Z8+ezbfffkvbtm0B6N27N8cffzyvvfYagwYNqvkn0sApIBQREclCkyfD22/X9lWM+g4FCgoKOPvss8ukJwaDq1atori4mAMPPJAHH3yQr7/+mt13373C85588smbg0GAAw88EIAZM2bUUM63LAoIRUREslDfvrV/DXenpKQkrXsp11Z+unbtSpMmTcqkf/XVV1xzzTW89dZbFBUVJW1buXJlpeft3r170uNYcLh8+fIMcrvlUkAoIiKShWqjeba0kpIIq1atpmXLluTn59f+BVNIrAmMWbFiBQMGDKBVq1bcdNNN9OrVi6ZNm/Lpp59y5ZVXpjXNTHnPx90zzvOWSAGhiIiIZJUJEyawdOlSnn32WQ466KDN6TNnzqzHXG3ZNO2MiIiIZJVY7V5ibd6GDRt44IEH6itLWzzVEIqIiEhW6d+/P23btuXMM89k2LBhmBlPPPGEmntrkWoIRUREJKtstdVW/Oc//6Fz585cc801jBw5ksMPP5zbb7+9vrO2xVINoYiIiNS6++67j/vuuy8pbcKECeXu379/fz788MMy6aVrCR999NGkxz169Ci3JlE1jOVTDaGIiIhIjlNAKCIiIpLjFBCKiIiI5DgFhCIiIiI5TgGhiIiISI5TQCgiIiKS4xQQioiIiOQ4BYQiIiIiOU4BoYiIiEiOU0AoIiIikuMUEIqIiIjkOAWEIiIikvVmzZqFmSXdu/iGG27AzNI63sy44YYbajRPAwcOZODAgTV6zvqigFBERERq3HHHHUfz5s1ZtWpVufsMGTKEJk2asHTp0jrMWdVMnTqVG264gVmzZtV3VmqVAkIRERGpcUOGDGHdunU899xzKbevXbuWcePG8bOf/YytttqqWte45pprWLduXSbZrNTUqVO58cYbUwaEr7/+Oq+//nqtXr+uKCAUERGRGnfcccfRsmVLRo8enXL7uHHjWLNmDUOGDKn2NRo1akTTpk2rfXymmjRpQpMmTert+jVJAaGIiIjUuGbNmnHiiSfy3//+l0WLFpXZPnr0aFq2bMkBBxzA5Zdfzq677kphYSGtWrXiqKOO4vPPP6/0Gqn6EBYXF3PppZey9dZb07JlS4477jh+/PHHMsf+8MMPXHDBBfTu3ZtmzZqx1VZbcdJJJyXVBD766KOcdNJJABx88MGYGWbGhAkTgNR9CBctWsTQoUPp2LEjTZs2Zffdd+exxx5L2ifWH3LkyJE89NBD9OrVi4KCAvbZZx8++uijSp93bWhUL1cVERGRLd6QIUN47LHHeOaZZ7jooos2py9btozXXnuNU089lfnz5/P8889z0kkn0bNnTxYuXMiDDz7IgAEDmDp1Kl26dKnSNX/961/z5JNPctppp9G/f3/eeustjjnmmDL7ffTRR3zwwQeccsopbLPNNsyaNYu//vWvDBw4kKlTp9K8eXMOOugghg0bxj333MOIESPo06cPwOa/pa1bt46BAwfy3XffcdFFF9GzZ0/GjBnDWWedxYoVK7jkkkuS9h89ejSrVq3ivPPOw8y4/fbbOfHEE5kxYwaNGzeu0vPOlAJCERERqRWHHHIInTt3ZvTo0UkB4ZgxY9i4cSNDhgxh11135ZtvviEvL95oecYZZ7DTTjvx8MMPc+2116Z9vc8//5wnn3ySCy64gPvvvx+ACy+8kCFDhjBlypSkfY855hgGDx6clHbsscey//778+9//5szzjiD7bbbjgMPPJB77rmHww8/vNIRxQ899BDTpk3jySef3NwUfv755zNgwACuueYazjnnHFq2bLl5/9mzZ/Ptt9/Stm1bAHr37s3xxx/Pa6+9xqBBg9J+3jVBAaGIiEg2Gj4cJk+u1UvkuVNYUkJefj5UNn1L375w111VOn9+fj6nnHIKd955J7NmzaJHjx5AqBnr2LEjhx56KPn5+Zv3LykpYcWKFRQWFtK7d28
|
2021-05-06 16:19:44 +01:00
|
|
|
"text/plain": [
|
|
|
|
"<Figure size 720x480 with 1 Axes>"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
"metadata": {
|
|
|
|
"needs_background": "light"
|
|
|
|
},
|
|
|
|
"output_type": "display_data"
|
|
|
|
}
|
|
|
|
],
|
|
|
|
"source": [
|
|
|
|
"history.history\n",
|
|
|
|
"plt.plot(range(len(history.history[\"accuracy\"])), history.history[\"accuracy\"], label=\"Train\", c=(0, 0, 1))\n",
|
|
|
|
"plt.plot(range(len(history.history[\"val_accuracy\"])), history.history[\"val_accuracy\"], label=\"Validation\", c=(1, 0, 0))\n",
|
|
|
|
"\n",
|
|
|
|
"plt.xlabel(\"Epochs\")\n",
|
|
|
|
"plt.ylabel(\"Accuracy\")\n",
|
|
|
|
"plt.ylim(0, 1)\n",
|
|
|
|
"\n",
|
|
|
|
"plt.grid()\n",
|
|
|
|
"plt.legend()\n",
|
|
|
|
"plt.show()"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "markdown",
|
|
|
|
"metadata": {},
|
|
|
|
"source": [
|
|
|
|
"## Test\n",
|
|
|
|
"\n",
|
|
|
|
"Single number below from the evaluate function"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-05-07 01:03:08 +01:00
|
|
|
"execution_count": 11,
|
2021-05-06 16:19:44 +01:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [
|
|
|
|
{
|
|
|
|
"name": "stdout",
|
|
|
|
"output_type": "stream",
|
|
|
|
"text": [
|
2021-05-07 01:03:08 +01:00
|
|
|
"10/10 [==============================] - 0s 2ms/step - loss: 0.6832 - accuracy: 0.7634\n"
|
2021-05-06 16:19:44 +01:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"data": {
|
|
|
|
"text/plain": [
|
2021-05-07 01:03:08 +01:00
|
|
|
"[0.6831779479980469, 0.7634069323539734]"
|
2021-05-06 16:19:44 +01:00
|
|
|
]
|
|
|
|
},
|
2021-05-07 01:03:08 +01:00
|
|
|
"execution_count": 11,
|
2021-05-06 16:19:44 +01:00
|
|
|
"metadata": {},
|
|
|
|
"output_type": "execute_result"
|
|
|
|
}
|
|
|
|
],
|
|
|
|
"source": [
|
|
|
|
"model.evaluate(data_test.to_numpy(), labels_test)"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "markdown",
|
|
|
|
"metadata": {},
|
|
|
|
"source": [
|
|
|
|
"Get raw predictions from test data to generate a confusion matrix"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-05-07 01:03:08 +01:00
|
|
|
"execution_count": 12,
|
2021-05-06 16:19:44 +01:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [
|
|
|
|
{
|
|
|
|
"data": {
|
2021-05-07 01:03:08 +01:00
|
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAkgAAAGgCAYAAABR4ZjdAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAABJ0AAASdAHeZh94AACMAUlEQVR4nOzdd3wURRvA8d+kBxI6IaEjGBCl96KAgFgQpEiVJoqoIEr1FbsiqDRRERBEUYrSm40iSBECSFcJvSYkdALpmfePvVxyubuQ3F0qz9fPfmJmZ3ZnN8fdczOzM0prjRBCCCGESOGW0xUQQgghhMhtJEASQgghhEhDAiQhhBBCiDQkQBJCCCGESEMCJCGEEEKINCRAEkIIIYRIQwIkIYQQQog0JEASQgghhEhDAiQhhBBCiDQkQBJCCCGESEMCJCGEEEKINDxyugIZlch8WTQui1X2/yGnq5DvXYjZm9NVuCskJF7P6Srke1on5HQV8j2t41V2ns9Vn7Pu9M7WemeVPBMgCSGEECLrJCUluuQ47vmkbyqfXIYQQgghhOtIC5IQQgghpNs0DQmQhBBCCIHWruliyy+ki00IIYQQIg1pQRJCCCEESdLFZkECJCGEEELIGKQ0JEASQgghhARIacgYJCGEEEKINKQFSQghhBDoJGlBSk0CJCGEEEKAdLFZkC42IYQQQog0pAVJCCGEEDJIOw0JkIQQQggBSfE5XYNcRbrYhBBCCCHSkABJCCGEEGid4JLNUUopb6XUx0qpC0qpaKXUTqVU2wyWbaOU+kMpdUkpdU0pFaKU6uNwZZAASQghhBAASQmu2Rz3LTAcmA8MAxKBn5VSzdMrpJTqAPwOeAHvAmOBaGCeUuo1RysjY5CEEEII4Wxw4xSlVEOgBzBKaz3RlDYPOAR8AjRNp/gQIAx4WGsdayo7E/gP6A9McaRO0oIkhBBCiJzWFaPFaFZygtY6BpgDNFFKlUunbCHganJwZCqbAFzCaElyiLQgCSGEEMJlE0UqpQKAkjZ2RWqtI+wUqwOEaq1vpEkPMf2sDZy1U3YTMEYp9QHwHaCBXkB9oFvGa25JAiQhhBBCoFzXxfYS8I6N9PcwxgjZEoTRTZZWclrpdM73AVAJY+zRm6a020AXrfXKO1XWHgmQhBBCCOFK04HFNtIj0ynjC8TaSI9Jtd+eWCAUWAIsA9yBQcAPSqm2Wusdd6yxDRIgCSGEEMJlg7RN3Wj2utLsiQa8baT7pNpvzxdAY6Cu1joJQCn1E3AY+AxolMm6ADJIWwghhBCQ04/5h2F0s6WVnHbBViGllBcwEFibHBwBaK3jgV+A+qY8mSYtSEIIIYRA5exabPuAVkqpQmkGajdKtd+W4hixjLuNfZ4YDUG29t2RtCAJIYQQIqctIWXsEGDMrA0MAHZqrc+a0sorpaqlKhcBXAM6pW4pUkr5AU8C/2mtHXrUX1qQhBBCCAFJiTl2aq31TqXUYmC8aZqAY0A/oCJGF1qyeUALQJnKJSqlJgIfAjtMk0u6m8qUBZ5xtE7SgmRHXFwCkz5dT4vmk6lT8yO6Pz2b7duOZ6js9u0n6N/nO5o2+pRG9T+me9fZrFpxwCpf9arv29y+nrXV1ZeTp3h5efL6+8+yM/QH/otYwYqNU2jeqs4dy91zbxneGj+IpesncSRyJadu/kLZ8gHZUOPcwcvLi48+epNTp/dx/cZJtm77mdatH8pQ2dKlA1mwYBYRkUe4dPkoS5d+S6VK5S3y+Pj4MHPWZPbu3UTkpVCuXD3O7j0bGDL0OTw80v+u9dWMicTFh7N8xfcOX19u4eXlxYQJH3Lu3Alu3brCX3/9SZs2D2eobOnSpVm06AeuXAnj2rWLLF/+E5UqVbTKN3jw8/z443xOnQolKSmab76ZZX0wIDAwkPHjP2DDhl+5fj2CpKRoWrR40JnLy1eMv9VHnD9/mtu3b7BjxzbatGmd09XKtVRSgks2J/QFpgJ9gGkYXWTttdZ/pldIaz0O6A3EY0wv8AFwA+iqtZ7vaGWkBcmON15fye+//Uufvo2oULEYK5bvZ/Cghcz9ri/16pe3W27jhiMMfflHatcuy8tDW6CU4tdfDvP6mBVcvXabfv0bW+Rv2uweOnSsaZF2X/XALLmmvGLijOE89lRzvpm+glPHL9C1dxvmLn2fnk+8zu6/DtstV7fhffR/sQNH/zvDsSNnub9W5Wysdc6bM+czOndpz7RpX3Ps2An69u3OqtXzadu2C9u3hdgtV7BgAdatW0qhwoX4eMI04hPieeWVQazfsJwG9dtw5cpVAHx9fahevSq//rqBU6fPkpSURJMmDZg48X0aNqhL374v2Tx+3Xq16Nu3O9HRDk9om6vMnfs1Xbt24rPPvuDo0WP069eHtWtX8PDDj7Jt23a75QoWLMjGjb9SuHAhxo//lPj4eF59dSibNq2jTp1GXLlyxZx39OgR+Pv7ERKym6Ag++8HVavey5gxIwkNPcrBg4dp2rSx3bx3o2+/nUPXrl2YOnUaR48eo3//vvz882patWrLtm3bcrp6Ig3TzNmjTJu9PC3tpC8AFriyPkpr7crjZZlE5mdbRQ8cOE+Pp+cwcnQbnh1oLP8SG5tAh/ZfUbx4QRYsetZu2eee/YFjRyP5fcNQvLyM+DMhIYknHvuSAr5eLF/1gjlv9arv06t3A958+7GsvaAMquz/Q05XgVr1glm56TPGjZ3N19OWAuDt7clvO2dw+dI1urQZYbds4aJ+JMQncisqmudf6cLYcc/R/P5+nDuT2adNs86FmL1Zctz6DeqwffsvjBn9HlOmfAWAt7c3e/dtIjLyEi0eetJu2REjXmb8hLdo0uRR9uzeB0DVqlXYu28TkyZ+yVtvjU/33FOmjuPllwdSrmwNLl60nuZk85+r+e+/o7Rq9SCHD/9Hp6ecWmA7QxISr2fJcRs0qM/OnVsYNep/TJo0FTDu88GDe4iIiKR581Z2y44aNZyPPx5Hw4bN2b17DwBVqwZz8OAePv10MmPHpsyrV758ec6cOQPAjRuRLFmynGefHWR1TD8/Pzw9Pbl69SpdunRi8eIFtGr1CJs3b3HhVdvmzKrt2aFBgwaEhGxn5MjRTJpkLMXl7e3NoUP7iIiIpFmzjLWu5iSt41V2nu/Wqa4u+ZwtWHFJttY7qzjUxaaUGqSU+lcpFaOUOq+UmmIaTJUv/P7rP7i7K7p1r2dO8/b2oEvXOuzbe46wMPtvvlFRsRQq7GMOjgA8PNwoWrQA3j62G+xiYuKJjc3dbzbZ5bGnmpOQkMjCub+Y02Jj4/np+9+o16g6QWVK2C17/WoUt6LyRytFZnXp3J6EhARmz07pwoqNjeXbuQto0qQBZcvan4S2c5f27Nq11xwcARw5coyNG7fQpWuHO5779Clj9v8iRQpb7Xvmmae5//5qvH2HICuv6Nq1EwkJCcyaNcecFhsbyzfffEvTpo0pW7as3bJdunQiJGS3OTgCOHIklA0b/uDpp7tY5E0Oju4kKiqKq1evZvIq7g5du3Y2/a1mm9NiY2OZM2cuTZs2SfdvddfK2cf8c51MB0hKqaeAGRiDnw6YjvGKKS1f+PffcCpULI6fn2XMV6Om8SHz37/hdss2bFiBY0cjmTb1D06fvsKZM1f46ss/OXzoAgOfs16MePnyfdSrPZ46NT+i/ePTWbP6oGsvJo+5v2ZlTh47T9TN2xbp+3aHAlC95t3VbZZRtWo/wNHQE9y8GWWRvmuX0WJVq9b9NssppahR4z727NlvtW/3rr1UqVIJP7+CFumenp4UL16MsmVL07HjY7w2/EVOnTrLsWMnLfL5+RVk3Edv8vGEz2y2LOVFtWvXIjT0KDdv3rRIDwnZbdpf01YxlFLUrPkAe/bssdq3a9duqlSpjJ+fn+srfBerU6c2oaGhNv5WuwDjbyksqaREl2z5hSNjkIYDx4HmWuuLSikP4Hugt1JqmI2F5vKcyMgoSpa0frMqWdIfgIiIKKt9yQa/9BDnzl1j5owtzPjKaOb29fVk6rRutG5T1SJvnTplefSx+ylTtggRETdZuGA3o0cuJ+pmLD1
|
2021-05-06 16:19:44 +01:00
|
|
|
"text/plain": [
|
|
|
|
"<Figure size 720x480 with 2 Axes>"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
"metadata": {
|
|
|
|
"needs_background": "light"
|
|
|
|
},
|
|
|
|
"output_type": "display_data"
|
|
|
|
}
|
|
|
|
],
|
|
|
|
"source": [
|
|
|
|
"predictions = model(data_test.to_numpy())\n",
|
|
|
|
"\n",
|
|
|
|
"conf = tf.math.confusion_matrix([tf.math.argmax(i) for i in labels_test], \n",
|
|
|
|
" [tf.math.argmax(i) for i in predictions], \n",
|
|
|
|
" num_classes=len(filtered_playlists))\n",
|
|
|
|
"\n",
|
|
|
|
"normalised_conf = np.ndarray((len(filtered_playlists), len(filtered_playlists)))\n",
|
|
|
|
"for idx, row in enumerate(conf):\n",
|
|
|
|
" normalised_conf[idx, :] = row / np.sum(row)\n",
|
|
|
|
"\n",
|
|
|
|
"sns.heatmap(normalised_conf, \n",
|
|
|
|
" annot=True, \n",
|
|
|
|
" xticklabels=playlist_names, yticklabels=playlist_names, \n",
|
|
|
|
" cmap='inferno')\n",
|
|
|
|
"plt.show()"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "markdown",
|
|
|
|
"metadata": {},
|
|
|
|
"source": [
|
|
|
|
"# Imports & Setup"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2021-05-07 01:03:08 +01:00
|
|
|
"execution_count": 1,
|
2021-05-06 16:19:44 +01:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
|
"from datetime import datetime\n",
|
|
|
|
"import os\n",
|
|
|
|
"\n",
|
|
|
|
"from google.cloud import bigquery\n",
|
|
|
|
"import matplotlib.pyplot as plt\n",
|
|
|
|
"import matplotlib as mpl\n",
|
|
|
|
"mpl.rcParams['figure.dpi'] = 120\n",
|
|
|
|
"import seaborn as sns\n",
|
|
|
|
"\n",
|
|
|
|
"from analysis.net import get_spotnet, get_playlist, track_frame\n",
|
|
|
|
"from analysis.query import *\n",
|
|
|
|
"from analysis import spotify_descriptor_headers, float_headers, days_since\n",
|
|
|
|
"\n",
|
|
|
|
"import numpy as np\n",
|
|
|
|
"import pandas as pd\n",
|
|
|
|
"\n",
|
|
|
|
"from sklearn import svm\n",
|
|
|
|
"from sklearn.model_selection import train_test_split\n",
|
|
|
|
"from sklearn.metrics import plot_confusion_matrix\n",
|
|
|
|
"from sklearn.utils import class_weight\n",
|
|
|
|
"\n",
|
|
|
|
"import tensorflow as tf\n",
|
|
|
|
"\n",
|
|
|
|
"client = bigquery.Client()\n",
|
|
|
|
"spotnet = get_spotnet()\n",
|
|
|
|
"cache = 'query.csv'\n",
|
|
|
|
"first_day = datetime(year=2017, month=11, day=3)\n",
|
|
|
|
"sig_max, c_max = 0.5, 20"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "markdown",
|
|
|
|
"metadata": {},
|
|
|
|
"source": [
|
|
|
|
"## Read Scrobble Frame"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
|
|
|
"execution_count": 2,
|
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
|
"scrobbles = get_query(cache=cache)"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "markdown",
|
|
|
|
"metadata": {},
|
|
|
|
"source": [
|
|
|
|
"## Write Scrobble Frame"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
|
|
|
"execution_count": 6,
|
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
|
"scrobbles.reset_index().to_csv(cache, sep='\\t')"
|
|
|
|
]
|
|
|
|
}
|
|
|
|
],
|
|
|
|
"metadata": {
|
|
|
|
"kernelspec": {
|
|
|
|
"display_name": "Python 3",
|
|
|
|
"language": "python",
|
|
|
|
"name": "python3"
|
|
|
|
},
|
|
|
|
"language_info": {
|
|
|
|
"codemirror_mode": {
|
|
|
|
"name": "ipython",
|
|
|
|
"version": 3
|
|
|
|
},
|
|
|
|
"file_extension": ".py",
|
|
|
|
"mimetype": "text/x-python",
|
|
|
|
"name": "python",
|
|
|
|
"nbconvert_exporter": "python",
|
|
|
|
"pygments_lexer": "ipython3",
|
|
|
|
"version": "3.8.9"
|
|
|
|
}
|
|
|
|
},
|
|
|
|
"nbformat": 4,
|
|
|
|
"nbformat_minor": 4
|
|
|
|
}
|