418 lines
84 KiB
Plaintext
418 lines
84 KiB
Plaintext
|
{
|
||
|
"cells": [
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"# Playlist Neural Network\n",
|
||
|
"\n",
|
||
|
"Given a list of playlists, can unknown tracks be correctly classified?"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 216,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# playlist_names = [\"RAP\", \"EDM\", \"ROCK\", \"METAL\", \"JAZZ\", \"POP\"] # super-genres\n",
|
||
|
"# playlist_names = [\"ALL RAP\", \"EDM\", \"ROCK\", \"METAL\", \"JAZZ\", \"POP\"] # super-genres\n",
|
||
|
"# playlist_names = [\"RAP\", \"EDM\", \"ROCK\", \"METAL\", \"JAZZ\"] # super-genres without POP\n",
|
||
|
"# Active playlist selection (the alternatives around it are commented out).\n",
|
||
|
"# List order matters: position idx in this list becomes the integer class label and the confusion-matrix tick order.\n",
|
||
|
"playlist_names = [\"ALL RAP\", \"EDM\", \"ROCK\", \"METAL\", \"JAZZ\"] # super-genres without POP\n",
|
||
|
"# playlist_names = [\"DNB\", \"HOUSE\", \"TECHNO\", \"GARAGE\", \"DUBSTEP\", \"BASS\"] # EDM playlists\n",
|
||
|
"# playlist_names = [\"20s rap\", \"10s rap\", \"00s rap\", \"90s rap\", \"80s rap\"] # rap decades\n",
|
||
|
"# playlist_names = [\"UK RAP\", \"US RAP\"] # UK/US split\n",
|
||
|
"# playlist_names = [\"uk rap\", \"grime\", \"drill\", \"afro bash\"] # british rap playlists\n",
|
||
|
"# playlist_names = [\"20s rap\", \"10s rap\", \"00s rap\", \"90s rap\", \"80s rap\", \"trap\", \"gangsta rap\", \"industrial rap\", \"weird rap\", \"jazz rap\", \"boom bap\", \"trap metal\"] # american rap playlists\n",
|
||
|
"# playlist_names = [\"rock\", \"indie\", \"punk\", \"pop rock\", \"bluesy rock\", \"hard rock\", \"chilled rock\", \"emo\", \"pop punk\", \"stoner rock/metal\", \"post-hardcore\", \"melodic hardcore\", \"art rock\", \"post-rock\", \"classic pop punk\", \"90s rock & grunge\", \"90s indie & britpop\", \"psych\"] # rock playlists\n",
|
||
|
"# playlist_names = [\"metal\", \"metalcore\", \"mathcore\", \"hardcore\", \"black metal\", \"death metal\", \"doom metal\", \"sludge metal\", \"classic metal\", \"industrial\", \"nu metal\", \"calm metal\", \"thrash metal\"] # metal playlists\n",
|
||
|
"\n",
|
||
|
"# Feature columns selected from each playlist frame; the commented line is a wider alternative feature set.\n",
|
||
|
"# headers = float_headers + [\"duration_ms\", \"mode\", \"loudness\", \"tempo\"]\n",
|
||
|
"headers = float_headers\n",
|
||
|
"\n",
|
||
|
"# When True the training cell passes the 'balanced' class weights to model.fit; when False it passes None.\n",
|
||
|
"BALANCED_WEIGHTS = True"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"Pull and process playlist information.\n",
|
||
|
"\n",
|
||
|
"1. Get live playlist track information from Spotify\n",
|
||
|
"2. Filter listening history for these tracks\n",
|
||
|
"\n",
|
||
|
"Filter out tracks without features and drop duplicates before taking only the descriptor parameters"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 217,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# 1) fetch each named playlist\n",
|
||
|
"playlists = [get_playlist(name, spotnet) for name in playlist_names]\n",
|
||
|
"\n",
|
||
|
"# 2) per playlist: join its tracks onto the listening history by track/artist name,\n",
|
||
|
"#    drop rows without a uri, keep one row per uri, then keep only the descriptor columns\n",
|
||
|
"filtered_playlists = [\n",
|
||
|
"    pd.merge(track_frame(playlist.tracks), scrobbles, on=['track', 'artist'])\n",
|
||
|
"    .loc[lambda frame: pd.notnull(frame[\"uri\"])]\n",
|
||
|
"    .drop_duplicates(['uri'])\n",
|
||
|
"    .loc[:, headers]\n",
|
||
|
"    for playlist in playlists\n",
|
||
|
"]"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"Construct the dataset with associated labels before splitting into a train and test set."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 218,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"dataset = pd.concat(filtered_playlists)\n",
|
||
|
"# integer label per row = index of the playlist the row came from\n",
|
||
|
"labels = np.concatenate([np.full(len(plst), idx) for idx, plst in enumerate(filtered_playlists)])\n",
|
||
|
"\n",
|
||
|
"# stratify: maintains class proportions in test and train set\n",
|
||
|
"data_train, data_test, labels_train, labels_test = train_test_split(\n",
|
||
|
"    dataset, labels,\n",
|
||
|
"    test_size=0.1,\n",
|
||
|
"    # random_state=70,\n",
|
||
|
"    stratify=labels,\n",
|
||
|
")\n",
|
||
|
"\n",
|
||
|
"# 'balanced' weights are computed only for the classes present in labels_train, so key them by\n",
|
||
|
"# those class ids rather than by position. A playlist left empty by the filtering step gets a\n",
|
||
|
"# neutral weight of 1.0 so that every class index 0..n-1 still has an entry.\n",
|
||
|
"present_classes = np.unique(labels_train)\n",
|
||
|
"balanced = class_weight.compute_class_weight('balanced',\n",
|
||
|
"                                             classes=present_classes,\n",
|
||
|
"                                             y=labels_train)\n",
|
||
|
"class_weights = {idx: 1.0 for idx in range(len(filtered_playlists))}\n",
|
||
|
"class_weights.update(zip(present_classes.tolist(), balanced))\n",
|
||
|
"\n",
|
||
|
"# one-hot encode for the categorical_crossentropy loss\n",
|
||
|
"labels_train = tf.one_hot(labels_train, len(filtered_playlists))\n",
|
||
|
"labels_test = tf.one_hot(labels_test, len(filtered_playlists))"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 219,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"def tensorboard_callback(path='tensorboard-logs', prefix=''):\n",
|
||
|
"    \"\"\"Return a TensorBoard callback writing to <path>/<prefix><YYYYmmdd-HHMMSS>.\n",
|
||
|
"\n",
|
||
|
"    The timestamp is taken when this function is called, the directory is\n",
|
||
|
"    normalised with os.path.normpath, and the callback uses histogram_freq=1.\n",
|
||
|
"    \"\"\"\n",
|
||
|
"    run_name = prefix + datetime.now().strftime(\"%Y%m%d-%H%M%S\")\n",
|
||
|
"    log_dir = os.path.normpath(os.path.join(path, run_name))\n",
|
||
|
"    return tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 220,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"def get_model(hidden_nodes=9,\n",
|
||
|
"              layers=1,\n",
|
||
|
"              classes=None,\n",
|
||
|
"              activation=lambda: 'sigmoid',\n",
|
||
|
"              weight_init=lambda: 'glorot_uniform',\n",
|
||
|
"              input_shape=None):\n",
|
||
|
"    \"\"\"Build an uncompiled fully-connected softmax classifier.\n",
|
||
|
"\n",
|
||
|
"    hidden_nodes: units in each hidden Dense layer.\n",
|
||
|
"    layers: number of hidden Dense layers.\n",
|
||
|
"    classes: number of output units. None (the default) means len(filtered_playlists),\n",
|
||
|
"        looked up when the function is called rather than when it is defined, so a\n",
|
||
|
"        re-run of the data cells with a different playlist selection is picked up.\n",
|
||
|
"    activation, weight_init: zero-argument factories returning the activation and the\n",
|
||
|
"        kernel initializer; each is called once per layer that uses it.\n",
|
||
|
"    input_shape: shape of one sample. None (the default) means one value per column\n",
|
||
|
"        of data_train.\n",
|
||
|
"    \"\"\"\n",
|
||
|
"    if classes is None:\n",
|
||
|
"        classes = len(filtered_playlists)\n",
|
||
|
"    if input_shape is None:\n",
|
||
|
"        # same tuple as data_train.to_numpy()[0].shape, without copying the frame or needing a first row\n",
|
||
|
"        input_shape = (data_train.shape[1],)\n",
|
||
|
"\n",
|
||
|
"    l = [tf.keras.layers.InputLayer(input_shape=input_shape, name='Input')]\n",
|
||
|
"\n",
|
||
|
"    l.extend(\n",
|
||
|
"        tf.keras.layers.Dense(hidden_nodes,\n",
|
||
|
"                              activation=activation(),\n",
|
||
|
"                              kernel_initializer=weight_init(),\n",
|
||
|
"                              name=f'Hidden{i+1}')\n",
|
||
|
"        for i in range(layers)\n",
|
||
|
"    )\n",
|
||
|
"\n",
|
||
|
"    l.append(tf.keras.layers.Dense(classes,\n",
|
||
|
"                                   activation='softmax',\n",
|
||
|
"                                   kernel_initializer=weight_init(),\n",
|
||
|
"                                   name='Output'))\n",
|
||
|
"\n",
|
||
|
"    return tf.keras.models.Sequential(l)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"# Single Model"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 226,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Model: \"sequential_27\"\n",
|
||
|
"_________________________________________________________________\n",
|
||
|
"Layer (type) Output Shape Param # \n",
|
||
|
"=================================================================\n",
|
||
|
"Hidden1 (Dense) (None, 64) 512 \n",
|
||
|
"_________________________________________________________________\n",
|
||
|
"Hidden2 (Dense) (None, 64) 4160 \n",
|
||
|
"_________________________________________________________________\n",
|
||
|
"Output (Dense) (None, 5) 325 \n",
|
||
|
"=================================================================\n",
|
||
|
"Total params: 4,997\n",
|
||
|
"Trainable params: 4,997\n",
|
||
|
"Non-trainable params: 0\n",
|
||
|
"_________________________________________________________________\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# two hidden layers of 64 units each\n",
|
||
|
"model = get_model(layers=2, hidden_nodes=64)\n",
|
||
|
"\n",
|
||
|
"model.compile(\n",
|
||
|
"    loss='categorical_crossentropy',  # the labels are one-hot encoded\n",
|
||
|
"    metrics=['accuracy'],\n",
|
||
|
"    optimizer=tf.keras.optimizers.Adam(learning_rate=0.01),\n",
|
||
|
"    # optimizer=tf.keras.optimizers.SGD(learning_rate=0.01, momentum=0.9),\n",
|
||
|
")\n",
|
||
|
"model.summary()"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"## Train"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 227,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# class weighting is optional, toggled by BALANCED_WEIGHTS\n",
|
||
|
"cw = class_weights if BALANCED_WEIGHTS else None\n",
|
||
|
"\n",
|
||
|
"history = model.fit(\n",
|
||
|
"    data_train.to_numpy(),\n",
|
||
|
"    labels_train,\n",
|
||
|
"    epochs=50,\n",
|
||
|
"    validation_split=0.11,\n",
|
||
|
"    class_weight=cw,\n",
|
||
|
"    callbacks=[tensorboard_callback()],\n",
|
||
|
"    verbose=0,\n",
|
||
|
")"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 228,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAoQAAAG/CAYAAADB4sa8AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAABJ0AAASdAHeZh94AABaA0lEQVR4nO3deXxU1f3/8dcnAcISNtlXQVSkiuKGgiKL2lZFbRXccLd1q+JS/dqitlat/rRad6u21h2r4L60KgruGjdEVNwAAdnXQAgBks/vjzOTTIbJOplkwryfj8d9zMxdz+Qkmfecc++55u6IiIiISObKaugCiIiIiEjDUiAUERERyXAKhCIiIiIZToFQREREJMMpEIqIiIhkOAVCERERkQynQCgiIiKS4RQIRURERDKcAqGIiIhIhlMgFBEREclwCoQiIiIiGU6BUERERCTDpXUgNLNcM/uLmf3PzFaamZvZqTXYvp2Z3Wdmy8yswMymmtkeKSyyiIiISKOT1oEQ6Aj8CRgAfF6TDc0sC3gJOAG4E/g/oDMwzcx2qONyioiIiDRaTRq6AFVYBHRz98VmthfwUQ22HQMMBca6+2QAM3sS+Bb4CyEoioiIiGS8tG4hdPcid19cy83HAEuAp2P2twx4EjjSzHLqoIgiIiIijV66txAmY3fgU3cviZufB5wJ7Ah8kWhDM+sMdIqbnRvZZiawsW6LKiIiIlKnmgG9gDfdfU1VK2/NgbAb8FaC+Ysij92pIBAC5wJ/TkWhREREROrRkcDzVa20NQfCFkBRgvkbYpZX5G5gUty8nYDJEydOZLvttquD4m2psLCQL774goEDB9KiRWXFk4ag+klfqpv0pvpJX6qb9JZM/cyePZsTTjgBYH511t+aA2EhkOg8weYxyxNy96XA0th5ZgbArrvuys4771xHRSwvPz+fNWvWsMcee9CmTZuUHENqT/WTvlQ36U31k75UN+ktmfrJzc2NPq3WaW5pfVFJkhYRuo3jRectrMeyiIiIiKStrTkQTgf2iIxHGGsfYD1h+BkRERGRjLdVBEIz62ZmO5lZ05jZk4EuwFEx63UExgIvuHui8wtFREREMk7an0NoZucB7QhXBQMcbmY9I8/viFxKfT1wCtAXmBtZNhn4AHjAzH4GLCdcPZyNriAWERERKZX2gRC4BNg25vVRlLX6PQokHFvH3YvN7FDgb8B4wlXFHwGnuvs3qSuuiIiISOOS9oHQ3ftUY51TgVMTzF8F/CYyiYiIpKWSkhKWLFlCUVERJSXx91NInU2bNrHNNtuwaNEili9fXm/HleqJr5+srCxycnLo0qULWVl1e9Zf2gdCERGRrVlJSQnz5s2jsLCQ7OxssrOzS4c6S7UmTZrQqVMnmjRRHEhHsfXj7mzcuJHCwkKKioro3bt3nYZC/QaIiIg0oCVLllBYWMg222xD586d6y0MAhQXF7N27Vpat25NdnZ2vR1Xqie+ftydpUuXsnLlSpYsWUK3bolG16udreIqYxERkcaqqKiI7Ozseg+D0viYGZ07dyY7O5uiorodLEWBUEREpAGVlJTUazexNG5mRnZ2dp2fa6pAKCIi0sAUBqUmUvH7okAoIiIikuEUCEVEREQynAKhiIiINFqnnnoqffr0aehiNHoKhCIiIlLnzKxa07Rp0xq6qILGIRQREZEUeOSRR8q9fvjhh3nttde2mD9gwICkjvPPf/6zXu/usrVSIBQREZE6d+KJJ5Z7/cEHH/Daa69tMT/e+vXradmyZbWP07Rp01qVT8pTl7GIiIg0iBEjRrDLLrvwySefcMABB9CyZUsmTJgAwHPPPcdhhx1G9+7dycnJoV+/flxzzTUUFxeX20f8OYRz587FzLjpppu477776NevHzk5Oey999589NFH9fn2GhW1EIqIiEiDWbFiBYcccgjHHXccJ554Il26dAHgwQcfJDc3l4svvpjc3FzeeOMN/vSnP5Gfn8/f/v
a3Kvc7ceJE1q5dy1lnnYWZceONN3LUUUcxe/ZstSomoEAoIiKShi68EKZPT+0x3LMoLs4lOzuLqsY6HjQIbr217suwePFi7rnnHs4666xy8ydOnEiLFi1KX5999tmcffbZ3H333Vx77bXk5ORUut958+bx3Xff0b59ewD69+/PkUceySuvvMLo0aPr/o00cgqEIiIiaWj6dHjzzVQfxWjoKJCTk8Npp522xfzYMLh27VqKiooYNmwY9957L7NmzWK33XardL/HHntsaRgEGDZsGACzZ8+uo5JvXRQIRURE0tCgQak/hrtTXFxcrXspp6o8PXr0oFmzZlvM//LLL7niiit44403yM/PL7dszZo1Ve63d+/e5V5Hw+GqVauSKO3WS4FQREQkDaWiezZecXEJa9euo3Xr1mRnZ6f+gAnEtgRGrV69muHDh9OmTRuuvvpq+vXrR/Pmzfn000+57LLLqjXMTEXvx92TLvPWSIFQRERE0sq0adNYsWIFTz/9NAcccEDp/Dlz5jRgqbZuGnZGRERE0kq0dS+2NW/jxo3cfffdDVWkrZ5aCEVERCStDB06lPbt23PKKacwfvx4zIxHHnlE3b0ppBZCERERSSsdOnTgxRdfpFu3blxxxRXcdNNNHHzwwdx4440NXbStlloIRUREJOXuvPNO7rzzznLzpk2bVuH6Q4cO5f33399ifnwr4YMPPljudZ8+fSpsSVQLY8XUQigiIiKS4RQIRURERDKcAqGIiIhIhlMgFBEREclwCoQiIiIiGU6BUERERCTDKRCKiIiIZDgFQhEREZEMp0AoIiIikuEUCEVEREQynAKhiIiISIZTIBQREZG0N3fuXMys3L2Lr7rqKsysWtubGVdddVWdlmnEiBGMGDGiTvfZUBQIRUREpM4dccQRtGzZkrVr11a4zrhx42jWrBkrVqyox5LVzFdffcVVV13F3LlzG7ooKaVAKCIiInVu3LhxFBYW8swzzyRcvn79ep577jl++ctf0qFDh1od44orrqCwsDCZYlbpq6++4i9/+UvCQPjqq6/y6quvpvT49UWBUEREROrcEUccQevWrZk4cWLC5c899xwFBQWMGzeu1sdo0qQJzZs3r/X2yWrWrBnNmjVrsOPXJQVCERERqXMtWrTgqKOO4vXXX2fp0qVbLJ84cSKtW7dm//3355JLLmHgwIHk5ubSpk0bDjnkED7//PMqj5HoHMKioiIuuugiOnXqROvWrTniiCNYsGDBFtv++OOPnHvuufTv358WLVrQoUMHxo4dW64l8MEHH2Ts2LEAjBw5EjPDzJg2bRqQ+BzCpUuXcsYZZ9ClSxeaN2/ObrvtxkMPPVRunej5kDfddBP33Xcf/fr1Iycnh7333puPPvqoyvedCk0a5KgiIiKy1Rs3bhwPPfQQTz75JOedd17p/JUrV/LKK69w/PHHs2jRIp599lnGjh1L3759WbJkCffeey/Dhw/nq6++onv37jU65m9+8xseffRRTjjhBIYOHcobb7zBYYcdtsV6H330Ee+99x7HHXccPXv2ZO7cufzjH/9gxIgRfPXVV7Rs2ZIDDjiA8ePHc/vttzNhwgQGDBgAUPoYr7CwkBEjRvD9999z3nnn0bdvXyZNmsSpp57K6tWrueCCC8qtP3HiRNauXctZZ52FmXHjjTdy1FFHMXv2bJo2bVqj950sBUIRERFJiVGjRtGtWzcmTpxYLhBOmjSJTZs2MW7cOAYOHMi3335LVlZZp+VJJ53ETjvtxP3338+VV15Z7eN9/vnnPProo5x77rncddddAPzud79j3LhxzJgxo9y6hx12GGPGjCk37/DDD2fIkCE89dRTnHTSSWy33XYMGzaM22+/nYMPPrjKK4rvu+8+vv76ax599NHSrvCzzz6b4cOHc8UVV3D66afTunXr0vXnzZvHd999R/v27QHo378/Rx55JK+88gqjR4+u9vuuCwqEIiIi6ejCC2H69JQeIsud3OJisr
KzoarhWwYNgltvrdH+s7OzOe6447jllluYO3cuffr0AULLWJcuXTjwwAPJzs4uXb+4uJjVq1eTm5tL//79+fTTT2t
|
||
|
"text/plain": [
|
||
|
"<Figure size 720x480 with 1 Axes>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {
|
||
|
"needs_background": "light"
|
||
|
},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# one curve per split: (history key, legend label, RGB colour)\n",
|
||
|
"curves = [(\"accuracy\", \"Train\", (0, 0, 1)),\n",
|
||
|
"          (\"val_accuracy\", \"Validation\", (1, 0, 0))]\n",
|
||
|
"\n",
|
||
|
"fig, ax = plt.subplots()\n",
|
||
|
"for metric_key, curve_label, rgb in curves:\n",
|
||
|
"    # with only y-values given, the x-axis is the epoch index 0..n-1\n",
|
||
|
"    ax.plot(history.history[metric_key], label=curve_label, c=rgb)\n",
|
||
|
"\n",
|
||
|
"ax.set(title=\"Training vs. validation accuracy\",\n",
|
||
|
"       xlabel=\"Epochs\",\n",
|
||
|
"       ylabel=\"Accuracy\",\n",
|
||
|
"       ylim=(0, 1))\n",
|
||
|
"\n",
|
||
|
"ax.grid()\n",
|
||
|
"ax.legend()\n",
|
||
|
"plt.show()"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"## Test\n",
|
||
|
"\n",
|
||
|
"The `evaluate` call below returns the loss and the accuracy on the held-out test set."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 229,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"10/10 [==============================] - 0s 857us/step - loss: 0.6952 - accuracy: 0.7792\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"[0.6951839327812195, 0.7791798114776611]"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 229,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# evaluate() returns [loss, accuracy]: the compiled loss followed by the single 'accuracy' metric\n",
|
||
|
"model.evaluate(data_test.to_numpy(), labels_test)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"Get raw predictions from test data to generate a confusion matrix"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 230,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAkgAAAGgCAYAAABR4ZjdAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAABJ0AAASdAHeZh94AAB6MklEQVR4nO3dd3xTVRvA8d/TzSh7la0sRfYUUBHFjcpUEFkOVPR1IIpbFBUHCKKioiCgDGUPcYCCLKGAbEX2bmnZFNrScd4/blqaNCltkjZteL587if03HvuPblNkydnijEGpZRSSil1UYCvC6CUUkopld9ogKSUUkop5UADJKWUUkopBxogKaWUUko50ABJKaWUUsqBBkhKKaWUUg40QFJKKaWUcqABklJKKaWUAw2QlFJKKaUcaICklFJKKeVAAySllFJKKQdBvi5AdqUwWReNy2XBAY/4ugh+z5hkXxdBKa+oUPRaXxfB70WdXS55eT1vfc4G0jNPy51bCkyApJRSSqnck5qa4pXzBPpJ25SfPA2llFJKKe/RGiSllFJKaRcABxogKaWUUgpjvNPE5i+0iU0ppZRSyoHWICmllFKKVG1is6MBklJKKaW0D5IDDZCUUkoppQGSA+2DpJRSSinlQGuQlFJKKYVJ1RqkjDRAUkoppRRoE5sdbWJTSimllHKgNUhKKaWU0k7aDjRAUkoppRSkJvm6BPmKNrEppZRSSjnQGiSllFJKaRObAw2QlFJKKQU6zN+OBkhKKaWU0gDJgfZBUkoppZRyoDVISimllNKJIh1ogKSUUkopRJvY7GgTm1JKKaWUA61BUkoppZR20nagAZJSSimlNEByoAGSUkoppRDtpG1H+yAppZRSSjnQGiSllFJKQWqKr0uQr2gNUg5cuJDMiI8W0/a6j2nc4D3u7/YNq1buzlbehT9tpUunsTSq/y5trh3Oa6/M4+SJ87lc4vwjJCSE999/h0OH9nDu3An++msZ7dvflK28FStWZNq07zlxIopTp44ye/aPXHFF9UzHPf74o/zww2T27dtBamo848ePdXq+ChUqMGzYUH7//RdOn44hNTWetm2v9+Tp+R3r9/Uehw/v5/z5M6xevZL27W/2dbH8it5j10JCgnn17cfZsGM2e2IW89MfX3FDu2bZylshogxfTXyL7QcXsuPwL3w77T2qVo/IdFyZsiUZ+cXLbNkzjz0xi/lt+Tg6dLwx03F33nMDX04YwurNP7Dn6CKW/z2ZN997kmLFi3r6NPMdSU32yuYvNEDKgVdemsvECavpcHd9Xn71NgIDA3i8/1TWrzuQZb5pU9YxaOAsihcvxOCXbqXrfY35eeE2Hur7HYmJ/vNiysq3337Nc889zZQp03j22UGkpKTw009zaNOmdZb5ihQpwh9//ELbttcxbNhHDBnyDo0bN2Lp0kWUKlXK7tgXX3yem25qy7Zt/5CUlOTynHXq1GLw4EFUqlSRLVu2eeX5+ZsJE8YxcOCzTJ48lWeeGUhKSgoLF86nTZs2vi6a39B77NqoL1/hsafuZ9aPi3hj8CekpKbw/cyPaNGqfpb5ChcpxIyFo2l1XSNGj/ie4e+Np36D2sz6+VNKliqWflzR8MLM/e1z7rqnLd+Nn8vbr35OXNx5vv5uKJ26tbc750ejX6BWnerM/OE3XnvxE5YuWkO//p1Z8PsXhIWF5MrzV/mDGGN8XYZsSWGyTwu6efNhuncbx6AX2/PQw9aHemJiMvd0+ILSpYswZdpDTvNduJDCDW1GULtOOSZ+1wcRAWDpkh0MeHwar7x2Ow/2apFnzyMrwQGP5Mp5mzdvxpo1y3nhhZcZMWIUAKGhoWzZsp6YmFiuu66dy7wvvDCQDz54lxYtrmPduvUA1KlTmy1b1vPRRx/z6qtvph9btWpVDhywgtUzZ2KZMWM2Dz3UP9M5ixYtSnBwMCdPnqRLl05Mnz6Fdu1u5c8/l3vxWTtXEFbLbt68OZ
GRqxg06EVGjBgJWL+vrVs3EhMTS5s2N/i4hAWfP9zjCkWvzZXzNmp6NT8vHctbr37Ol6OnARAaGsKSNRM5duwk97Qf4DLvgGcf4PWhT3B720fZ9Pd2AGrWrsqSNRMZM2oqw96yapWfeKYHb7wzgK53PcPKZX8DICL89MeXVKxcnuZ1u5KUZP2ttrquEX+t2Gh3nW49bmP02Nd4/qkPmDJxgbdvQbqos8sl107uxLl9Xb3yOVuk+ow8LXducasGSUT6i8i/IpIgIodFZKSIhHq7cPnJb7/8Q2CgcN/9TdPTQkOD6NK1MRs3HCIq6rTTfLt2xnDmTAJ33HFNenAEcGO72hQuHMLPP23N9bL7WteunUhOTmbs2HHpaYmJiYwfP4HWra+lcuXKLvN26dKJyMh16cERwH//7eD335fQrVsXu2PTgqNLiYuL4+TJkzl8FpePrl07235f36SnJSYmMm7ct7Ru3SrL35fKHr3HrnXoeCPJycl8/+289LTExAtM/e4nmresT8VK5VznvbctG9b9kx4cAezacYAVS//m7k4Xv4i1bN2AY7En04MjAGMM82YvoXyF0rS6rlF6umNwBLBw/jIAatWp5s5TzL9Sk72z+YkcB0gi0hH4EqgMbLad42lbmt/6999oqlUvTdGi9nFg/QYVAdj+b7TTfBcuWC+W0LDgTPvCwoL4999oUlMLRi2euxo1asiOHTs5e/asXXpk5Drb/gZO84kIDRrUY/369Zn2rV27jpo1a1C0qP/1A/C1xo0bsWPHDie/r7WA9ftUntF77Fq9BrXYs+sQcWft+2huWPcvANc0qOk0n4hwdb0abNrwX6Z9G9b/wxU1KlOkaCEAQkNCSIhPzHRc/PkEABo0rpNlGcuVLw3AiePOvxgXVJKa4pXNX7hTgzQQ2A3UNMa0AKoAPwI9RaRYljkLsNjYOMqWzfxhXLZsOAAxMXFO81WrVhoR2PD3Qbv0vXuOceLEeRISkjlzOt77Bc5HIiIqEBWVOYBMS6tYMXMHSoBSpUoRFhbmVl7lvkv/virmdZH8jt5j18pXKM3R6OOZ0mNsaRUqlHGar2SpYoSFhaYfl1Ha+SpEWHl37TxARKWyVK5S3u64lq0b2o4rm2UZn3yuJ8nJySyYszTrJ6NyRERCReQDETkiIvEiskZEbslm3vYiskREjonIKRGJFJFenpTHnQCpDvCVMeYogLE6VQzDmjLgak8Kk58lJiQTEpJ5VoTQ0CDbfuedgkuWKsztd1zD3Dmb+Hb8Xxw8eJJ16/Yz8LmZBAVbtz/BzztqFypUiMTEzN/WEhIS0ve7ygdW9XpO8yr3ufv7Utmn99i1sLBQLjj7m7fdr7BCzntzhIVZ6YkXMr8Xp72HpB0zZeICUlJS+Wri2zRrWY9qV1Tkf88/yB13X5/lNQA6dWtPzz4d+PLTH9i7+1AOnlkBkJrinc19E7AqYSYDzwApwEIRuS6rTCJyD/AbEAIMAV4F4oFJIvKcu4VxZx6kssARh7TDtsfC7hYkjYiUs13DzpGjn1OuXHFPT++20LCg9OayjNJGoTlrQksz5O27SEhI4qMPFvHRB4sAuPue+lStWpJFv22ncGH/HgkRHx9PaGjmN5ywsLD0/a7ygdVBM6d5lfvc/X2p7NN77FpCQiIhzv7mbffLWdNYWj6A0JDM78Vp7yFpx/y7bTcDHnqbDz95nvmLvwCsWqY3B3/KB58M4nyc8ylYWrZuwIjPX2LJojW8/9bXOXxm+Z8vm8dEpAXQHXjBGDPcljYJ2Ap8CGQ15PkpIAq4yRiTaMv7FbAd6AuMdKdM7k4UmZudZgYAbzomfjFmMW8O6eLk8LxRtmxRjh49myk9NtZKK1fOdV+Y8PAwPv+iO0eOnObI4VNEVCxOpUoleKD7eEqVKkyxYmG5Vu78ICoqmkqVMjcZRERUAODIkSin+U6cOEFCQkL6cTnJq9x36d+X4/cjlVN6j107Gn2ciIqZm7jKVbD6/URHH3
Oa7+SJMyQkJKYfl1H5tLxRF/P+NHcpvy1cQd36NQkMDGDLxh20vr4xALt3Hcx0jrr1ajBh2vv8988eHun1Oikp/tP
|
||
|
"text/plain": [
|
||
|
"<Figure size 720x480 with 2 Axes>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {
|
||
|
"needs_background": "light"
|
||
|
},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"predictions = model(data_test.to_numpy())\n",
|
||
|
"\n",
|
||
|
"# rows are the true classes, columns the predicted classes; argmax undoes the one-hot encoding\n",
|
||
|
"conf = tf.math.confusion_matrix(tf.math.argmax(labels_test, axis=1),\n",
|
||
|
"                                tf.math.argmax(predictions, axis=1),\n",
|
||
|
"                                num_classes=len(filtered_playlists)).numpy()\n",
|
||
|
"\n",
|
||
|
"# normalise each row to proportions of that true class; a class with no test rows keeps an\n",
|
||
|
"# all-zero row instead of dividing by zero\n",
|
||
|
"row_totals = conf.sum(axis=1, keepdims=True)\n",
|
||
|
"normalised_conf = np.divide(conf, row_totals,\n",
|
||
|
"                            out=np.zeros(conf.shape, dtype=float),\n",
|
||
|
"                            where=row_totals != 0)\n",
|
||
|
"\n",
|
||
|
"ax = sns.heatmap(normalised_conf,\n",
|
||
|
"                 annot=True,\n",
|
||
|
"                 xticklabels=playlist_names, yticklabels=playlist_names,\n",
|
||
|
"                 cmap='inferno')\n",
|
||
|
"ax.set(xlabel=\"Predicted playlist\", ylabel=\"True playlist\")\n",
|
||
|
"plt.show()"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
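|
"# Imports & Setup\n",
|
||
|
"\n",
|
||
|
"Run the cells in this section first: the code cells above depend on the imports, `spotnet` and `scrobbles` defined here."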
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 60,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"from datetime import datetime\n",
|
||
|
"import os\n",
|
||
|
"\n",
|
||
|
"from google.cloud import bigquery\n",
|
||
|
"import matplotlib.pyplot as plt\n",
|
||
|
"import matplotlib as mpl\n",
|
||
|
"mpl.rcParams['figure.dpi'] = 120\n",
|
||
|
"import seaborn as sns\n",
|
||
|
"\n",
|
||
|
"from analysis.net import get_spotnet, get_playlist, track_frame\n",
|
||
|
"from analysis.query import *\n",
|
||
|
"from analysis import spotify_descriptor_headers, float_headers, days_since\n",
|
||
|
"\n",
|
||
|
"import numpy as np\n",
|
||
|
"import pandas as pd\n",
|
||
|
"\n",
|
||
|
"from sklearn import svm\n",
|
||
|
"from sklearn.model_selection import train_test_split\n",
|
||
|
"from sklearn.metrics import plot_confusion_matrix\n",
|
||
|
"from sklearn.utils import class_weight\n",
|
||
|
"\n",
|
||
|
"import tensorflow as tf\n",
|
||
|
"\n",
|
||
|
"# NOTE(review): client, first_day, sig_max and c_max are not referenced by any cell in this notebook -- confirm whether they are still needed\n",
|
||
|
"client = bigquery.Client()\n",
|
||
|
"# passed to get_playlist() in the playlist-loading cell\n",
|
||
|
"spotnet = get_spotnet()\n",
|
||
|
"# path handed to get_query(cache=...) and written by the 'Write Scrobble Frame' cell\n",
|
||
|
"cache = 'query.csv'\n",
|
||
|
"first_day = datetime(year=2017, month=11, day=3)\n",
|
||
|
"sig_max, c_max = 0.5, 20"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"## Read Scrobble Frame"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 2,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# scrobbles is the listening-history frame that each playlist is merged against on ['track', 'artist']\n",
|
||
|
"# NOTE(review): get_query is not imported by name -- presumably it comes from `from analysis.query import *`; confirm\n",
|
||
|
"scrobbles = get_query(cache=cache)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"## Write Scrobble Frame"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 6,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# NOTE(review): writes tab-separated text to the path in `cache` ('query.csv'), so the extension does not match the delimiter;\n",
|
||
|
"# presumably get_query reads the cache back with the same separator -- confirm\n",
|
||
|
"scrobbles.reset_index().to_csv(cache, sep='\\t')"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"metadata": {
|
||
|
"kernelspec": {
|
||
|
"display_name": "Python 3",
|
||
|
"language": "python",
|
||
|
"name": "python3"
|
||
|
},
|
||
|
"language_info": {
|
||
|
"codemirror_mode": {
|
||
|
"name": "ipython",
|
||
|
"version": 3
|
||
|
},
|
||
|
"file_extension": ".py",
|
||
|
"mimetype": "text/x-python",
|
||
|
"name": "python",
|
||
|
"nbconvert_exporter": "python",
|
||
|
"pygments_lexer": "ipython3",
|
||
|
"version": "3.8.9"
|
||
|
}
|
||
|
},
|
||
|
"nbformat": 4,
|
||
|
"nbformat_minor": 4
|
||
|
}
|