{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Playlist Neural Network\n",
    "\n",
    "Given a list of playlists, can unknown tracks be correctly classified?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# candidate playlist sets: the single uncommented line is the one used below\n",
    "# playlist_names = [\"RAP\", \"EDM\", \"ROCK\", \"METAL\", \"JAZZ\", \"POP\"] # super-genres\n",
    "# playlist_names = [\"ALL RAP\", \"EDM\", \"ROCK\", \"METAL\", \"JAZZ\", \"POP\"] # super-genres\n",
    "# playlist_names = [\"RAP\", \"EDM\", \"ROCK\", \"METAL\", \"JAZZ\"] # super-genres without POP\n",
    "# playlist_names = [\"ALL RAP\", \"EDM\", \"ROCK\", \"METAL\", \"JAZZ\"] # super-genres without POP\n",
    "playlist_names = [\"ALL RAP\", \"DNB\", \"4/4\", \"cRock\", \"METAL\", \"cJazz\"] # super-genres with decomposed EDM\n",
    "# playlist_names = [\"DNB\", \"HOUSE\", \"TECHNO\", \"GARAGE\", \"DUBSTEP\", \"BASS\"] # EDM playlists\n",
    "# playlist_names = [\"20s rap\", \"10s rap\", \"00s rap\", \"90s rap\", \"80s rap\"] # rap decades\n",
    "# playlist_names = [\"UK RAP\", \"US RAP\"] # UK/US split\n",
    "# playlist_names = [\"uk rap\", \"grime\", \"drill\", \"afro bash\"] # british rap playlists\n",
    "# playlist_names = [\"20s rap\", \"10s rap\", \"00s rap\", \"90s rap\", \"80s rap\", \"trap\", \"gangsta rap\", \"industrial rap\", \"weird rap\", \"jazz rap\", \"boom bap\", \"trap metal\"] # american rap playlists\n",
    "# playlist_names = [\"rock\", \"indie\", \"punk\", \"pop rock\", \"bluesy rock\", \"hard rock\", \"chilled rock\", \"emo\", \"pop punk\", \"stoner rock/metal\", \"post-hardcore\", \"melodic hardcore\", \"art rock\", \"post-rock\", \"classic pop punk\", \"90s rock & grunge\", \"90s indie & britpop\", \"psych\"] # rock playlists\n",
    "# playlist_names = [\"metal\", \"metalcore\", \"mathcore\", \"hardcore\", \"black metal\", \"death metal\", \"doom metal\", \"sludge metal\", \"classic metal\", \"industrial\", \"nu metal\", \"calm metal\", \"thrash metal\"] # metal playlists\n",
    "\n",
    "# columns kept from each playlist frame and fed to the network\n",
    "# headers = float_headers + [\"duration_ms\", \"mode\", \"loudness\", \"tempo\"]\n",
    "headers = float_headers + [\"mode\", \"loudness\", \"tempo\"]\n",
    "# headers = float_headers\n",
    "\n",
    "BALANCED_WEIGHTS = True"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Pull and process playlist information.\n",
    "\n",
    "1. Get live playlist track information from spotify\n",
    "2. Filter listening history for these tracks\n",
    "\n",
    "Filter out tracks without features and drop duplicates before taking only the descriptor parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "playlists = [get_playlist(name, spotnet) for name in playlist_names] # 1)\n",
    "\n",
    "# one frame per playlist, built in a single chain:\n",
    "#   2) join the playlist's tracks onto the scrobbles by track/artist name\n",
    "#   -  drop rows whose uri is null\n",
    "#   -  distinct on uri\n",
    "#   -  select only the descriptor columns listed in `headers`\n",
    "filtered_playlists = [\n",
    "    pd.merge(track_frame(playlist.tracks), scrobbles, on=['track', 'artist'])\n",
    "      .dropna(subset=['uri'])\n",
    "      .drop_duplicates(['uri'])\n",
    "      .loc[:, headers]\n",
    "    for playlist in playlists\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Construct the dataset with associated labels before splitting into a train and test set." 
] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "dataset = pd.concat(filtered_playlists)\n", "labels = [np.full(len(plst), idx) for idx, plst in enumerate(filtered_playlists)]\n", "labels = np.concatenate(labels)\n", "\n", "# stratify: maintains class proportions in test and train set\n", "data_train, data_test, labels_train, labels_test = train_test_split(dataset, labels, \n", " test_size=0.1, \n", "# random_state=70, \n", " stratify=labels\n", " )\n", "\n", "class_weights = class_weight.compute_class_weight('balanced',\n", " classes=np.unique(labels_train),\n", " y=labels_train)\n", "class_weights = {i: j for i, j in zip(range(len(filtered_playlists)), class_weights)}\n", "\n", "labels_train = tf.one_hot(labels_train, len(filtered_playlists))\n", "labels_test = tf.one_hot(labels_test, len(filtered_playlists))" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "def tensorboard_callback(path='tensorboard-logs', prefix=''):\n", " return tf.keras.callbacks.TensorBoard(\n", " log_dir=os.path.normpath(os.path.join(path, prefix + datetime.now().strftime(\"%Y%m%d-%H%M%S\"))), histogram_freq=1\n", " )" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "def get_model(hidden_nodes=128,\n", " layers=2,\n", " classes=len(filtered_playlists),\n", " activation=lambda: 'sigmoid', \n", " weight_init=lambda: 'glorot_uniform'):\n", " l = [tf.keras.layers.InputLayer(input_shape=data_train.to_numpy()[0].shape, name='Input')]\n", " \n", " for i in range(layers):\n", " l.append(\n", " tf.keras.layers.Dense(hidden_nodes, \n", " activation=activation(), \n", " kernel_initializer=weight_init(), \n", " name=f'Hidden{i+1}')\n", " )\n", " \n", " l.append(tf.keras.layers.Dense(classes, \n", " activation='softmax', \n", " kernel_initializer=weight_init(), \n", " name='Output'))\n", " \n", " model = tf.keras.models.Sequential(l)\n", " return model" ] }, { 
"cell_type": "markdown", "metadata": {}, "source": [ "# Single Model" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model: \"sequential_10\"\n", "_________________________________________________________________\n", "Layer (type) Output Shape Param # \n", "=================================================================\n", "Hidden1 (Dense) (None, 64) 704 \n", "_________________________________________________________________\n", "Output (Dense) (None, 6) 390 \n", "=================================================================\n", "Total params: 1,094\n", "Trainable params: 1,094\n", "Non-trainable params: 0\n", "_________________________________________________________________\n" ] } ], "source": [ "model = get_model(hidden_nodes=64, layers=1)\n", "\n", "model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.01), \n", "# optimizer=tf.keras.optimizers.SGD(learning_rate=0.01, momentum=0.9),\n", " loss='categorical_crossentropy', \n", " metrics=['accuracy'])\n", "model.summary()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Train" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "if BALANCED_WEIGHTS:\n", " cw = class_weights\n", "else:\n", " cw = None\n", "history = model.fit(data_train.to_numpy(), labels_train, \n", " callbacks=[tensorboard_callback()], \n", " validation_split=0.11,\n", " verbose=0,\n", " class_weight=cw,\n", " epochs=50)" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2021-05-09T23:59:37.477771\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.4.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "history.history\n", "plt.plot(range(len(history.history[\"accuracy\"])), history.history[\"accuracy\"], label=\"Train\", c=(0, 0, 1))\n", "plt.plot(range(len(history.history[\"val_accuracy\"])), history.history[\"val_accuracy\"], label=\"Validation\", c=(1, 0, 0))\n", "\n", "plt.xlabel(\"Epochs\")\n", "plt.ylabel(\"Accuracy\")\n", "plt.ylim(0, 1)\n", "\n", "plt.grid()\n", "plt.legend()\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Test\n", "\n", "Single number below from the evaluate function" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "10/10 [==============================] - 0s 2ms/step - loss: 0.7529 - accuracy: 0.7609\n" ] }, { "data": { "text/plain": [ "[0.7529351115226746, 0.7609427571296692]" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.evaluate(data_test.to_numpy(), labels_test)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Get raw predictions from test data to generate a confusion matrix" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2021-05-09T23:59:37.976238\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.4.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "predictions = model(data_test.to_numpy())\n",
    "\n",
    "# argmax over the class axis turns one-hot labels / softmax outputs into class ids\n",
    "conf = tf.math.confusion_matrix(tf.math.argmax(labels_test, axis=1),\n",
    "                                tf.math.argmax(predictions, axis=1),\n",
    "                                num_classes=len(filtered_playlists))\n",
    "\n",
    "# normalise each row (true class) to proportions; a class with no test samples\n",
    "# keeps an all-zero row instead of becoming NaN through a 0/0 division\n",
    "conf_counts = conf.numpy().astype(float)\n",
    "row_totals = conf_counts.sum(axis=1, keepdims=True)\n",
    "normalised_conf = np.divide(conf_counts, row_totals,\n",
    "                            out=np.zeros_like(conf_counts),\n",
    "                            where=row_totals != 0)\n",
    "\n",
    "ax = sns.heatmap(normalised_conf,\n",
    "                 annot=True,\n",
    "                 xticklabels=playlist_names, yticklabels=playlist_names,\n",
    "                 cmap='inferno')\n",
    "ax.set(xlabel='Predicted playlist', ylabel='True playlist')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Ensemble Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "# nine members, each with a randomly drawn width and depth\n",
    "models = [get_model(hidden_nodes=random.randint(16, 128),\n",
    "                    layers=random.randint(1, 2))\n",
    "          for _ in range(9)]\n",
    "\n",
    "for m in models:\n",
    "    m.compile(\n",
    "        optimizer=tf.keras.optimizers.Adam(learning_rate=0.01),\n",
    "#         optimizer=tf.keras.optimizers.SGD(learning_rate=0.01, momentum=0.9),\n",
    "        loss='categorical_crossentropy',\n",
    "        metrics=['accuracy'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Train the models of the meta-classifier. Get a random number of epochs from a reasonable range to introduce variation between models. *Weird?*: Randomly decide whether to regularise by class weights. Class weights change the penalisation methods such that smaller classes are treated more importantly when computing loss. I'm not sure whether it's going to help. 
" ] },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "training model 1\n",
      "training model 2\n",
      "training model 3\n",
      "training model 4\n",
      "training model 5\n",
      "training model 6\n",
      "training model 7\n",
      "training model 8\n",
      "training model 9\n"
     ]
    }
   ],
   "source": [
    "# BALANCED_WEIGHTS on: each member draws the balanced class weights three times\n",
    "# out of four and no weighting otherwise. BALANCED_WEIGHTS off: never weight.\n",
    "if BALANCED_WEIGHTS:\n",
    "    weight_options = [*([class_weights]*3), None]\n",
    "else:\n",
    "    weight_options = [None]\n",
    "\n",
    "ensem_histories = list()\n",
    "for idx, m in enumerate(models):\n",
    "    print(f'training model {idx+1}')\n",
    "    h = m.fit(data_train.to_numpy(), labels_train,\n",
    "              # prefix the run directory so ensemble members can be told apart\n",
    "              callbacks=[tensorboard_callback(prefix=f'ensemble{idx+1}-')],\n",
    "              validation_split=0.11,\n",
    "              verbose=0,\n",
    "              class_weight=random.choice(weight_options),\n",
    "              epochs=random.randint(20, 100))\n",
    "    ensem_histories.append(h)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "79.5% Accuracy, 82.7% Agreement, 67.6% Ind. Accuracy\n"
     ]
    }
   ],
   "source": [
    "ensem_results = ensem_classify(models, data_test, labels_test)\n",
    "print(f\"{ensem_results[2]*100:.3}% Accuracy, {ensem_results[3]*100:.3}% Agreement, {ensem_results[4]*100:.3}% Ind. 
Accuracy\")" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2021-05-10T00:00:22.760673\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.4.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# ensem_results[0] is passed as the predicted class ids for the test samples\n",
    "ensem_conf = tf.math.confusion_matrix(tf.math.argmax(labels_test, axis=1),\n",
    "                                      ensem_results[0],\n",
    "                                      num_classes=len(filtered_playlists))\n",
    "\n",
    "# normalise each row (true class) to proportions; a class with no test samples\n",
    "# keeps an all-zero row instead of becoming NaN through a 0/0 division\n",
    "ensem_counts = ensem_conf.numpy().astype(float)\n",
    "ensem_row_totals = ensem_counts.sum(axis=1, keepdims=True)\n",
    "normalised_ensem_conf = np.divide(ensem_counts, ensem_row_totals,\n",
    "                                  out=np.zeros_like(ensem_counts),\n",
    "                                  where=ensem_row_totals != 0)\n",
    "\n",
    "ax = sns.heatmap(normalised_ensem_conf,\n",
    "                 annot=True,\n",
    "                 xticklabels=playlist_names, yticklabels=playlist_names,\n",
    "                 cmap='inferno')\n",
    "ax.set(xlabel='Predicted playlist', ylabel='True playlist')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Imports & Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datetime import datetime\n",
    "import os\n",
    "import random\n",
    "\n",
    "from google.cloud import bigquery\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib as mpl\n",
    "mpl.rcParams['figure.dpi'] = 120\n",
    "import seaborn as sns\n",
    "\n",
    "from analysis.nn import ensem_classify\n",
    "from analysis.net import get_spotnet, get_playlist, track_frame\n",
    "from analysis.query import *\n",
    "from analysis import spotify_descriptor_headers, float_headers, days_since\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "from sklearn import svm\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics import plot_confusion_matrix\n",
    "from sklearn.utils import class_weight\n",
    "\n",
    "import tensorflow as tf\n",
    "\n",
    "client = bigquery.Client()\n",
    "spotnet = get_spotnet()\n",
    "cache = 'query.csv'\n",
    "first_day = datetime(year=2017, month=11, day=3)\n",
    "sig_max, c_max = 0.5, 20"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Read Scrobble Frame"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "scrobbles = get_query(cache=cache)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
"## Write Scrobble Frame" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "scrobbles.reset_index().to_csv(cache, sep='\\t')" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.9" }, "metadata": { "interpreter": { "hash": "bce1a3677099e73bf385a0de8ef462673e03f7df0abce93e57e7ca76e8c504a2" } } }, "nbformat": 4, "nbformat_minor": 4 }