2021-03-19 17:21:00 +00:00
{
"cells": [
{
"cell_type": "code",
2023-05-27 23:29:17 +01:00
"execution_count": 1,
2021-03-19 17:21:00 +00:00
"metadata": {
"executionInfo": {
"elapsed": 2450,
"status": "ok",
"timestamp": 1615991459232,
"user": {
"displayName": "Andy Pack",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjA4K4ZhdArHXAFbAGr4n0aCv2HmyUpx4cy6zcUq34=s64",
"userId": "16615063155528027547"
},
"user_tz": 0
},
"id": "TGIxH9Tmt5zp"
},
2021-03-26 20:01:05 +00:00
"outputs": [],
2021-03-19 17:21:00 +00:00
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import tensorflow as tf\n",
2021-03-29 18:34:04 +01:00
"import tensorflow.keras.optimizers as tf_optim\n",
2021-03-19 17:21:00 +00:00
"tf.get_logger().setLevel('ERROR')\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib as mpl\n",
"import seaborn as sns\n",
2021-03-21 09:56:27 +00:00
"import random\n",
"import pickle\n",
"import json\n",
2021-03-22 20:49:29 +00:00
"import math\n",
2021-03-29 18:34:04 +01:00
"import datetime\n",
"import os\n",
2021-05-04 15:24:37 +01:00
"\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"# Display resolution applied to figures via fig.set_dpi\n",
"fig_dpi = 200"
2021-03-19 17:21:00 +00:00
]
},
{
2023-05-27 23:29:17 +01:00
"attachments": {},
2021-03-19 17:21:00 +00:00
"cell_type": "markdown",
"metadata": {
"id": "fksHv5rXACEX"
},
"source": [
"# Neural Network Training\n"
]
},
{
2023-05-27 23:29:17 +01:00
"attachments": {},
2021-03-19 17:21:00 +00:00
"cell_type": "markdown",
"metadata": {
"id": "l4zqVWyRAM0Z"
},
"source": [
"## Load Dataset\n",
"\n",
"Read CSVs dumped from MATLAB and parse them into pandas DataFrames"
]
},
{
"cell_type": "code",
2023-05-27 23:29:17 +01:00
"execution_count": 2,
2021-03-19 17:21:00 +00:00
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 331
},
"executionInfo": {
"elapsed": 2441,
"status": "ok",
"timestamp": 1615991459234,
"user": {
"displayName": "Andy Pack",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjA4K4ZhdArHXAFbAGr4n0aCv2HmyUpx4cy6zcUq34=s64",
"userId": "16615063155528027547"
},
"user_tz": 0
},
"id": "Hj5l_tdZuYh7",
"outputId": "fbfa9406-f662-4ebc-8ba2-67950714627c"
},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Clump thickness</th>\n",
" <th>Uniformity of cell size</th>\n",
" <th>Uniformity of cell shape</th>\n",
" <th>Marginal adhesion</th>\n",
" <th>Single epithelial cell size</th>\n",
" <th>Bare nuclei</th>\n",
" <th>Bland chomatin</th>\n",
" <th>Normal nucleoli</th>\n",
" <th>Mitoses</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>count</th>\n",
" <td>699.000000</td>\n",
" <td>699.000000</td>\n",
" <td>699.000000</td>\n",
" <td>699.000000</td>\n",
" <td>699.000000</td>\n",
" <td>699.000000</td>\n",
" <td>699.000000</td>\n",
" <td>699.000000</td>\n",
" <td>699.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>mean</th>\n",
" <td>0.441774</td>\n",
" <td>0.313448</td>\n",
" <td>0.320744</td>\n",
" <td>0.280687</td>\n",
" <td>0.321602</td>\n",
" <td>0.354363</td>\n",
" <td>0.343777</td>\n",
" <td>0.286695</td>\n",
" <td>0.158941</td>\n",
" </tr>\n",
" <tr>\n",
" <th>std</th>\n",
" <td>0.281574</td>\n",
" <td>0.305146</td>\n",
" <td>0.297191</td>\n",
" <td>0.285538</td>\n",
" <td>0.221430</td>\n",
" <td>0.360186</td>\n",
" <td>0.243836</td>\n",
" <td>0.305363</td>\n",
" <td>0.171508</td>\n",
" </tr>\n",
" <tr>\n",
" <th>min</th>\n",
" <td>0.100000</td>\n",
" <td>0.100000</td>\n",
" <td>0.100000</td>\n",
" <td>0.100000</td>\n",
" <td>0.100000</td>\n",
" <td>0.100000</td>\n",
" <td>0.100000</td>\n",
" <td>0.100000</td>\n",
" <td>0.100000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>25%</th>\n",
" <td>0.200000</td>\n",
" <td>0.100000</td>\n",
" <td>0.100000</td>\n",
" <td>0.100000</td>\n",
" <td>0.200000</td>\n",
" <td>0.100000</td>\n",
" <td>0.200000</td>\n",
" <td>0.100000</td>\n",
" <td>0.100000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>50%</th>\n",
" <td>0.400000</td>\n",
" <td>0.100000</td>\n",
" <td>0.100000</td>\n",
" <td>0.100000</td>\n",
" <td>0.200000</td>\n",
" <td>0.100000</td>\n",
" <td>0.300000</td>\n",
" <td>0.100000</td>\n",
" <td>0.100000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>75%</th>\n",
" <td>0.600000</td>\n",
" <td>0.500000</td>\n",
" <td>0.500000</td>\n",
" <td>0.400000</td>\n",
" <td>0.400000</td>\n",
" <td>0.500000</td>\n",
" <td>0.500000</td>\n",
" <td>0.400000</td>\n",
" <td>0.100000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>max</th>\n",
" <td>1.000000</td>\n",
" <td>1.000000</td>\n",
" <td>1.000000</td>\n",
" <td>1.000000</td>\n",
" <td>1.000000</td>\n",
" <td>1.000000</td>\n",
" <td>1.000000</td>\n",
" <td>1.000000</td>\n",
" <td>1.000000</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
2023-05-27 23:29:17 +01:00
" Clump thickness Uniformity of cell size Uniformity of cell shape \n",
"count 699.000000 699.000000 699.000000 \\\n",
2021-03-19 17:21:00 +00:00
"mean 0.441774 0.313448 0.320744 \n",
"std 0.281574 0.305146 0.297191 \n",
"min 0.100000 0.100000 0.100000 \n",
"25% 0.200000 0.100000 0.100000 \n",
"50% 0.400000 0.100000 0.100000 \n",
"75% 0.600000 0.500000 0.500000 \n",
"max 1.000000 1.000000 1.000000 \n",
"\n",
2023-05-27 23:29:17 +01:00
" Marginal adhesion Single epithelial cell size Bare nuclei \n",
"count 699.000000 699.000000 699.000000 \\\n",
2021-03-19 17:21:00 +00:00
"mean 0.280687 0.321602 0.354363 \n",
"std 0.285538 0.221430 0.360186 \n",
"min 0.100000 0.100000 0.100000 \n",
"25% 0.100000 0.200000 0.100000 \n",
"50% 0.100000 0.200000 0.100000 \n",
"75% 0.400000 0.400000 0.500000 \n",
"max 1.000000 1.000000 1.000000 \n",
"\n",
" Bland chomatin Normal nucleoli Mitoses \n",
"count 699.000000 699.000000 699.000000 \n",
"mean 0.343777 0.286695 0.158941 \n",
"std 0.243836 0.305363 0.171508 \n",
"min 0.100000 0.100000 0.100000 \n",
"25% 0.200000 0.100000 0.100000 \n",
"50% 0.300000 0.100000 0.100000 \n",
"75% 0.500000 0.400000 0.100000 \n",
"max 1.000000 1.000000 1.000000 "
]
},
2023-05-27 23:29:17 +01:00
"execution_count": 2,
2021-03-19 17:21:00 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# The CSVs store one sample per column, so transpose to one sample per row and name the columns\n",
"data = pd.read_csv('features.csv', header=None).T.set_axis(\n",
"    ['Clump thickness', 'Uniformity of cell size', 'Uniformity of cell shape', 'Marginal adhesion',\n",
"     'Single epithelial cell size', 'Bare nuclei', 'Bland chomatin', 'Normal nucleoli', 'Mitoses'],\n",
"    axis=1,\n",
")\n",
"labels = pd.read_csv('targets.csv', header=None).T.set_axis(['Benign', 'Malignant'], axis=1)\n",
"data.describe()"
]
},
{
"cell_type": "code",
2023-05-27 23:29:17 +01:00
"execution_count": 3,
2021-03-19 17:21:00 +00:00
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 204
},
"executionInfo": {
"elapsed": 2436,
"status": "ok",
"timestamp": 1615991459236,
"user": {
"displayName": "Andy Pack",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjA4K4ZhdArHXAFbAGr4n0aCv2HmyUpx4cy6zcUq34=s64",
"userId": "16615063155528027547"
},
"user_tz": 0
},
"id": "qc1Mku6h041u",
"outputId": "94e38c34-0419-4a02-ac8c-17bbc83f777b"
},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Benign</th>\n",
" <th>Malignant</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Benign Malignant\n",
"0 1 0\n",
"1 1 0\n",
"2 1 0\n",
"3 0 1\n",
"4 1 0"
]
},
2023-05-27 23:29:17 +01:00
"execution_count": 3,
2021-03-19 17:21:00 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Targets look one-hot encoded: one column per class\n",
"labels.head(5)"
]
},
{
2023-05-27 23:29:17 +01:00
"attachments": {},
2021-03-19 17:21:00 +00:00
"cell_type": "markdown",
"metadata": {
"id": "h9QsJjWEMbLu"
},
"source": [
"### Explore Dataset\n",
"\n",
"The classes are imbalanced (more benign than malignant samples), so stratify when splitting later on"
]
},
{
"cell_type": "code",
2023-05-27 23:29:17 +01:00
"execution_count": 4,
2021-03-19 17:21:00 +00:00
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"executionInfo": {
"elapsed": 2430,
"status": "ok",
"timestamp": 1615991459237,
"user": {
"displayName": "Andy Pack",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjA4K4ZhdArHXAFbAGr4n0aCv2HmyUpx4cy6zcUq34=s64",
"userId": "16615063155528027547"
},
"user_tz": 0
},
"id": "rjjiSYAZMa4k",
"outputId": "ae0c3772-00be-4f2b-80d2-9cd62a6b6e08"
},
"outputs": [
{
"data": {
"text/plain": [
"Benign 458\n",
"Malignant 241\n",
"dtype: int64"
]
},
2023-05-27 23:29:17 +01:00
"execution_count": 4,
2021-03-19 17:21:00 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Number of samples in each class (non-zero entries per label column)\n",
"(labels != 0).sum(axis=0)"
]
},
{
2023-05-27 23:29:17 +01:00
"attachments": {},
2021-03-19 17:21:00 +00:00
"cell_type": "markdown",
"metadata": {
"id": "E9lVYI14AUMf"
},
"source": [
"## Split Dataset\n",
"\n",
"Using a 50/50 split"
]
},
{
"cell_type": "code",
2023-05-27 23:29:17 +01:00
"execution_count": 5,
2021-03-19 17:21:00 +00:00
"metadata": {
"executionInfo": {
"elapsed": 2604,
"status": "ok",
"timestamp": 1615991459418,
"user": {
"displayName": "Andy Pack",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjA4K4ZhdArHXAFbAGr4n0aCv2HmyUpx4cy6zcUq34=s64",
"userId": "16615063155528027547"
},
"user_tz": 0
},
"id": "L83Ae5l9wM35"
},
"outputs": [],
"source": [
2021-04-29 22:53:26 +01:00
"# Stratify on the one-hot labels so both halves keep the benign/malignant ratio\n",
"data_train, data_test, labels_train, labels_test = train_test_split(\n",
"    data, labels, test_size=0.5, stratify=labels\n",
")"
2021-03-19 17:21:00 +00:00
]
},
{
2023-05-27 23:29:17 +01:00
"attachments": {},
2021-03-19 17:21:00 +00:00
"cell_type": "markdown",
"metadata": {
"id": "Qf2U199fNjmJ"
},
"source": [
"## Generate & Retrieve Model\n",
"\n",
"Get a shallow model with a single hidden layer whose number of nodes can be varied"
]
},
{
"cell_type": "code",
2023-05-27 23:29:17 +01:00
"execution_count": 6,
2021-03-19 17:21:00 +00:00
"metadata": {
"executionInfo": {
"elapsed": 2598,
"status": "ok",
"timestamp": 1615991459419,
"user": {
"displayName": "Andy Pack",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjA4K4ZhdArHXAFbAGr4n0aCv2HmyUpx4cy6zcUq34=s64",
"userId": "16615063155528027547"
},
"user_tz": 0
},
"id": "SgoQ-NjWB0T5"
},
"outputs": [],
"source": [
"def get_model(hidden_nodes=9, activation=lambda: 'sigmoid', weight_init=lambda: 'glorot_uniform'):\n",
"    \"\"\"Build an uncompiled shallow classifier: 9 inputs -> one hidden Dense layer -> 2-way softmax.\n",
"\n",
"    Args:\n",
"        hidden_nodes: number of units in the hidden layer.\n",
"        activation: zero-argument factory returning the hidden layer's activation.\n",
"        weight_init: zero-argument factory returning a kernel initializer; it is called\n",
"            separately for each Dense layer.\n",
"\n",
"    Returns:\n",
"        A tf.keras Sequential model (call compile() before training).\n",
"    \"\"\"\n",
"    input_layer = tf.keras.layers.InputLayer(input_shape=(9,), name='Input')\n",
"    hidden_layer = tf.keras.layers.Dense(hidden_nodes, activation=activation(),\n",
"                                         kernel_initializer=weight_init(), name='Hidden')\n",
"    output_layer = tf.keras.layers.Dense(2, activation='softmax',\n",
"                                         kernel_initializer=weight_init(), name='Output')\n",
"    return tf.keras.models.Sequential([input_layer, hidden_layer, output_layer])"
]
},
2021-03-29 18:34:04 +01:00
{
2023-05-27 23:29:17 +01:00
"attachments": {},
2021-03-29 18:34:04 +01:00
"cell_type": "markdown",
"metadata": {},
"source": [
"Get a Keras TensorBoard callback that logs training data for later analysis"
]
},
{
"cell_type": "code",
2023-05-27 23:29:17 +01:00
"execution_count": 7,
2021-03-29 18:34:04 +01:00
"metadata": {},
"outputs": [],
"source": [
"def tensorboard_callback(path='tensorboard-logs', prefix=''):\n",
"    \"\"\"Return a TensorBoard callback logging to `<path>/<prefix><timestamp>`.\n",
"\n",
"    The timestamp includes microseconds (%f) so that short runs started within the\n",
"    same second (e.g. repeated sweeps using the same prefix) get separate log\n",
"    directories instead of writing into one shared directory.\n",
"\n",
"    Args:\n",
"        path: root directory for TensorBoard logs.\n",
"        prefix: string prepended to the timestamp, e.g. to tag the run's parameters.\n",
"\n",
"    Returns:\n",
"        tf.keras.callbacks.TensorBoard with histogram_freq=1 (weight histograms every epoch).\n",
"    \"\"\"\n",
"    timestamp = datetime.datetime.now().strftime(\"%Y%m%d-%H%M%S-%f\")\n",
"    log_dir = os.path.normpath(os.path.join(path, prefix + timestamp))\n",
"    return tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)"
2021-03-29 18:34:04 +01:00
]
},
2021-03-19 17:21:00 +00:00
{
2023-05-27 23:29:17 +01:00
"attachments": {},
2021-03-19 17:21:00 +00:00
"cell_type": "markdown",
"metadata": {
"id": "QT5B9PTUN3pj"
},
"source": [
"# Example Training"
]
},
{
2023-05-27 23:29:17 +01:00
"attachments": {},
2021-03-19 17:21:00 +00:00
"cell_type": "markdown",
"metadata": {
"id": "mQGAUtIPAd6e"
},
"source": [
"## Define Model\n",
"\n",
"Variable number of hidden nodes (9 in this example). The input is 9-dimensional (one value per feature) and the output layer is 2D with a softmax for binary classification"
]
},
{
"cell_type": "code",
2023-05-27 23:29:17 +01:00
"execution_count": 8,
2021-03-19 17:21:00 +00:00
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"executionInfo": {
"elapsed": 7889,
"status": "ok",
"timestamp": 1615991464716,
"user": {
"displayName": "Andy Pack",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjA4K4ZhdArHXAFbAGr4n0aCv2HmyUpx4cy6zcUq34=s64",
"userId": "16615063155528027547"
},
"user_tz": 0
},
"id": "fYA34P0Vu_pX",
"outputId": "aded18b8-aa7f-4362-a614-837c8a0f526f"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2023-05-27 23:29:17 +01:00
"Model: \"sequential\"\n",
2021-03-19 17:21:00 +00:00
"_________________________________________________________________\n",
2023-05-27 23:29:17 +01:00
" Layer (type) Output Shape Param # \n",
2021-03-19 17:21:00 +00:00
"=================================================================\n",
2023-05-27 23:29:17 +01:00
" Hidden (Dense) (None, 9) 90 \n",
" \n",
" Output (Dense) (None, 2) 20 \n",
" \n",
2021-03-19 17:21:00 +00:00
"=================================================================\n",
"Total params: 110\n",
"Trainable params: 110\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"# Plain SGD with categorical cross-entropy over the 2-class softmax output\n",
"model = get_model(hidden_nodes=9)\n",
"model.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy'])\n",
"model.summary()"
]
},
{
2023-05-27 23:29:17 +01:00
"attachments": {},
2021-03-19 17:21:00 +00:00
"cell_type": "markdown",
"metadata": {
"id": "KZSwFe-AAs1y"
},
"source": [
"## Train Model\n",
"\n",
"Example 5 epochs"
]
},
{
"cell_type": "code",
2023-05-27 23:29:17 +01:00
"execution_count": 9,
2021-03-19 17:21:00 +00:00
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"executionInfo": {
"elapsed": 11304,
"status": "ok",
"timestamp": 1615991468137,
"user": {
"displayName": "Andy Pack",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjA4K4ZhdArHXAFbAGr4n0aCv2HmyUpx4cy6zcUq34=s64",
"userId": "16615063155528027547"
},
"user_tz": 0
},
"id": "s8U9Atu3zelS",
"outputId": "8439e8d2-7a5d-495f-a192-a34f01e95bfa"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2023-05-27 23:29:17 +01:00
"Epoch 1/5\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2023-05-27 23:23:33.467785: W tensorflow/tsl/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"11/11 [==============================] - 0s 1ms/step - loss: 0.7494 - accuracy: 0.6390\n",
2021-03-21 09:56:27 +00:00
"Epoch 2/5\n",
2023-05-27 23:29:17 +01:00
"11/11 [==============================] - 0s 856us/step - loss: 0.7457 - accuracy: 0.6390\n",
2021-03-21 09:56:27 +00:00
"Epoch 3/5\n",
2023-05-27 23:29:17 +01:00
"11/11 [==============================] - 0s 721us/step - loss: 0.7423 - accuracy: 0.6390\n",
2021-03-21 09:56:27 +00:00
"Epoch 4/5\n",
2023-05-27 23:29:17 +01:00
"11/11 [==============================] - 0s 863us/step - loss: 0.7389 - accuracy: 0.6390\n",
2021-03-21 09:56:27 +00:00
"Epoch 5/5\n",
2023-05-27 23:29:17 +01:00
"11/11 [==============================] - 0s 716us/step - loss: 0.7357 - accuracy: 0.6390\n"
2021-03-19 17:21:00 +00:00
]
},
{
"data": {
"text/plain": [
2023-05-27 23:29:17 +01:00
"<keras.callbacks.History at 0x290434bd0>"
2021-03-19 17:21:00 +00:00
]
},
2023-05-27 23:29:17 +01:00
"execution_count": 9,
2021-03-19 17:21:00 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
2021-03-26 20:01:05 +00:00
"# Quick sanity-check training run on the training split\n",
"model.fit(x=data_train.to_numpy(), y=labels_train.to_numpy(), epochs=5)"
2021-03-19 17:21:00 +00:00
]
},
{
"cell_type": "code",
2023-05-27 23:29:17 +01:00
"execution_count": 10,
2021-03-19 17:21:00 +00:00
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"executionInfo": {
"elapsed": 11294,
"status": "ok",
"timestamp": 1615991468137,
"user": {
"displayName": "Andy Pack",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjA4K4ZhdArHXAFbAGr4n0aCv2HmyUpx4cy6zcUq34=s64",
"userId": "16615063155528027547"
},
"user_tz": 0
},
"id": "VnUEJdXovzi-",
"outputId": "02075086-352c-4a23-fac5-ad54d11e0e35"
},
"outputs": [
{
"data": {
"text/plain": [
"['loss', 'accuracy']"
]
},
2023-05-27 23:29:17 +01:00
"execution_count": 10,
2021-03-19 17:21:00 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Names of the quantities the compiled model reports\n",
"list(model.metrics_names)"
]
},
{
"cell_type": "code",
2023-05-27 23:29:17 +01:00
"execution_count": 11,
2021-03-19 17:21:00 +00:00
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"executionInfo": {
"elapsed": 11285,
"status": "ok",
"timestamp": 1615991468138,
"user": {
"displayName": "Andy Pack",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjA4K4ZhdArHXAFbAGr4n0aCv2HmyUpx4cy6zcUq34=s64",
"userId": "16615063155528027547"
},
"user_tz": 0
},
"id": "r0vxP3Ah42ib",
"outputId": "061113ba-52db-4fbe-c7f9-b5d3d85438ed"
},
"outputs": [
{
"data": {
"text/plain": [
2023-05-27 23:29:17 +01:00
"<tf.Tensor: shape=(), dtype=float32, numpy=0.63896847>"
2021-03-19 17:21:00 +00:00
]
},
2023-05-27 23:29:17 +01:00
"execution_count": 11,
2021-03-19 17:21:00 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Look the accuracy metric up by name instead of relying on its position in model.metrics\n",
"{metric.name: metric for metric in model.metrics}['accuracy'].result()"
]
},
{
2023-05-27 23:29:17 +01:00
"attachments": {},
2021-03-19 17:21:00 +00:00
"cell_type": "markdown",
"metadata": {
2021-03-26 20:01:05 +00:00
"id": "z7bn8pKTAynt",
"tags": [
"exp1"
]
2021-03-19 17:21:00 +00:00
},
"source": [
"# Experiment 1\n",
"\n",
"The function below runs one sweep over hidden-node/epoch combinations.\n",
"For each combination it yields the number of nodes and epochs used, the training history, the test results and (optionally) the model.\n",
"\n",
2021-03-29 19:17:14 +01:00
"Using the cancer dataset (as in E2) and 'trainscg' or an optimiser of your choice, vary nodes and epochs (that is, using early stopping for epochs) over a suitable range, to find the optimal choice of nodes/epochs in terms of classification test error rate for a 50/50 random train/test split (no validation set). It is suggested that you initially try epochs = [1 2 4 8 16 32 64], nodes = [2 8 32], so there would be 21 node/epoch combinations. \n",
2021-03-19 17:21:00 +00:00
"\n",
2021-03-29 19:17:14 +01:00
"(Hint 1: from the 'advanced script' in E2, nodes can be changed to xx with hiddenLayerSize = xx; and epochs changed to xx by adding net.trainParam.epochs = xx; placed after net = patternnet(hiddenLayerSize, trainFcn); -- see the 'trainscg' help documentation for changing epochs). \n",
2021-03-19 17:21:00 +00:00
"\n",
"Repeat each of the 21 node/epoch combinations at least thirty times, with different 50/50 split and take average and report classification error rate and standard deviation (std). Graph classification train and test error rate and std as node-epoch changes, that is plot error rate vs epochs for different number of nodes. Report the optimal value for test error rate and associated node/epoch values. \n",
"\n",
2021-03-29 19:17:14 +01:00
"(Hint 2: as the number of epochs increases you can expect the test error rate to reach a minimum and then start increasing; you may need to set the stopping criteria to achieve the desired number of epochs. Hint 3: to find classification error rates for the train and test sets, check the code from E2 to determine how to obtain the train and test set patterns.)\n"
2021-03-19 17:21:00 +00:00
]
},
{
"cell_type": "code",
2023-05-27 23:29:17 +01:00
"execution_count": 18,
2021-03-19 17:21:00 +00:00
"metadata": {
"executionInfo": {
"elapsed": 11274,
"status": "ok",
"timestamp": 1615991468138,
"user": {
"displayName": "Andy Pack",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjA4K4ZhdArHXAFbAGr4n0aCv2HmyUpx4cy6zcUq34=s64",
"userId": "16615063155528027547"
},
"user_tz": 0
},
2021-03-26 20:01:05 +00:00
"id": "mYWhCSW4A57V",
"tags": [
"exp1",
"exp-func"
]
2021-03-19 17:21:00 +00:00
},
"outputs": [],
"source": [
2021-03-26 20:01:05 +00:00
"# Sweep grid of hidden-layer sizes and training epochs.\n",
"# Larger (slower) sweep: hidden_nodes = [2, 8, 16, 24, 32], epochs = [1, 2, 4, 8, 16, 32, 64, 100, 150, 200]\n",
"hidden_nodes = [2, 8, 16]\n",
"epochs = [1, 2, 4, 8]\n",
"\n",
"def evaluate_parameters(hidden_nodes=hidden_nodes,\n",
"                        epochs=epochs,\n",
"                        batch_size=128,\n",
"                        optimizer=lambda: 'sgd',\n",
"                        weight_init=lambda: 'glorot_uniform',\n",
"                        loss=lambda: 'categorical_crossentropy',\n",
"                        metrics=['accuracy'],\n",
"                        callbacks=None,\n",
"                        validation_split=None,\n",
"\n",
"                        verbose=0,\n",
"                        print_params=True,\n",
"                        return_model=True,\n",
"                        run_eagerly=False,\n",
"                        tboard=True,\n",
"\n",
"                        dtrain=data_train,\n",
"                        dtest=data_test,\n",
"                        ltrain=labels_train,\n",
"                        ltest=labels_test):\n",
"    \"\"\"Train and test a fresh model for every (hidden nodes, epochs) combination.\n",
"\n",
"    Generator: yields one result dict per combination, looping over epochs\n",
"    inside hidden_nodes.\n",
"\n",
"    Args:\n",
"        hidden_nodes: iterable of hidden-layer sizes to try.\n",
"        epochs: iterable of epoch counts to try.\n",
"        batch_size: batch size for model.evaluate only; model.fit uses the Keras\n",
"            default. NOTE(review): confirm this is intended.\n",
"        optimizer, loss: zero-argument factories, called once per combination so\n",
"            each model can get its own optimizer/loss object.\n",
"        weight_init: zero-argument initializer factory forwarded to get_model.\n",
"        metrics: metrics list passed to model.compile.\n",
"        callbacks: optional iterable of zero-argument callback *factories*, each\n",
"            called once per combination (e.g. lambda: EarlyStopping(...)).\n",
"        validation_split: forwarded to model.fit.\n",
"        verbose: Keras verbosity for fit and evaluate.\n",
"        print_params: print the combination before training it.\n",
"        return_model: include the trained model under the \"model\" key.\n",
"        run_eagerly: forwarded to model.compile.\n",
"        tboard: additionally attach a TensorBoard logging callback.\n",
"        dtrain, dtest, ltrain, ltest: train/test features and labels (pandas\n",
"            objects, converted with .to_numpy()). The defaults are bound when\n",
"            this cell runs, so re-run it after re-splitting the data.\n",
"\n",
"    Yields:\n",
"        dict with keys \"nodes\", \"epochs\", \"history\" (History.history from fit),\n",
"        \"results\" (return value of evaluate), \"optimizer\" (optimizer config),\n",
"        \"loss\", \"model_config\" (parsed model JSON) and, if return_model,\n",
"        \"model\".\n",
"    \"\"\"\n",
"    for hn in hidden_nodes:\n",
"        for e in epochs:\n",
"            if print_params:\n",
"                print(f\"Nodes: {hn}, Epochs: {e}\")\n",
"\n",
"            model = get_model(hn, weight_init=weight_init)\n",
"            model.compile(\n",
"                optimizer=optimizer(),\n",
"                loss=loss(),\n",
"                metrics=metrics,\n",
"                run_eagerly=run_eagerly\n",
"            )\n",
"\n",
"            # Always build the callback list so user callbacks are honoured (and\n",
"            # `cb` is defined) whether or not TensorBoard logging is enabled.\n",
"            cb = [make_callback() for make_callback in callbacks] if callbacks is not None else []\n",
"            if tboard:\n",
"                cb.append(tensorboard_callback(prefix=f'exp1-{hn}-{e}-'))\n",
"\n",
"            # TRAIN\n",
"            history = model.fit(dtrain.to_numpy(),\n",
"                                ltrain.to_numpy(),\n",
"                                epochs=e,\n",
"                                verbose=verbose,\n",
"                                callbacks=cb,\n",
"                                validation_split=validation_split).history\n",
"            # TEST\n",
"            results = model.evaluate(dtest.to_numpy(),\n",
"                                     ltest.to_numpy(),\n",
"                                     callbacks=cb,\n",
"                                     batch_size=batch_size,\n",
"                                     verbose=verbose)\n",
"\n",
"            response = {\"nodes\": hn,\n",
"                        \"epochs\": e,\n",
"                        \"history\": history,\n",
"                        \"results\": results,\n",
"                        \"optimizer\": model.optimizer.get_config(),\n",
"                        \"loss\": model.loss,\n",
"                        \"model_config\": json.loads(model.to_json())}\n",
"\n",
"            if return_model:\n",
"                response[\"model\"] = model\n",
"\n",
"            yield response"
]
},
{
2023-05-27 23:29:17 +01:00
"attachments": {},
2021-03-19 17:21:00 +00:00
"cell_type": "markdown",
"metadata": {
2021-03-26 20:01:05 +00:00
"id": "r-63V9qb-i4w",
"tags": [
"exp1"
]
2021-03-19 17:21:00 +00:00
},
"source": [
"## Single Iteration\n",
"Run a single iteration of the node/epoch investigation"
]
},
{
"cell_type": "code",
2023-05-27 23:29:17 +01:00
"execution_count": 19,
2021-03-19 17:21:00 +00:00
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"executionInfo": {
"elapsed": 313592,
"status": "ok",
"timestamp": 1615991770468,
"user": {
"displayName": "Andy Pack",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjA4K4ZhdArHXAFbAGr4n0aCv2HmyUpx4cy6zcUq34=s64",
"userId": "16615063155528027547"
},
"user_tz": 0
},
"id": "ZmGFkE9y8E4H",
2021-03-26 20:01:05 +00:00
"outputId": "243fb136-bc07-4438-afb7-f2d21758168d",
"tags": [
"exp1"
]
2021-03-19 17:21:00 +00:00
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Nodes: 2, Epochs: 1\n",
"Nodes: 2, Epochs: 2\n",
"Nodes: 2, Epochs: 4\n",
"Nodes: 2, Epochs: 8\n",
"Nodes: 8, Epochs: 1\n",
"Nodes: 8, Epochs: 2\n",
"Nodes: 8, Epochs: 4\n",
"Nodes: 8, Epochs: 8\n",
"Nodes: 16, Epochs: 1\n",
"Nodes: 16, Epochs: 2\n",
"Nodes: 16, Epochs: 4\n",
2021-03-26 20:01:05 +00:00
"Nodes: 16, Epochs: 8\n"
2021-03-19 17:21:00 +00:00
]
}
],
"source": [
2021-03-26 20:01:05 +00:00
"# To add early stopping, pass callback *factories* (evaluate_parameters calls each one), e.g.\n",
"# callbacks=[lambda: tf.keras.callbacks.EarlyStopping(monitor='val_loss', mode='min', patience=5)]\n",
"single_results = list(\n",
"    evaluate_parameters(\n",
"        return_model=False,\n",
"        validation_split=0.2,\n",
"        optimizer=lambda: tf.keras.optimizers.legacy.SGD(learning_rate=0.5, momentum=0.5),\n",
"    )\n",
")"
2021-03-19 17:21:00 +00:00
]
},
{
2023-05-27 23:29:17 +01:00
"attachments": {},
2021-03-19 17:21:00 +00:00
"cell_type": "markdown",
"metadata": {
2021-03-26 20:01:05 +00:00
"id": "mdWK3-M6SK8_",
"tags": [
"exp1"
]
2021-03-19 17:21:00 +00:00
},
"source": [
2021-03-21 09:56:27 +00:00
"### Train/Test Curves\n",
"\n",
"For a single randomly chosen run (trained for more than one epoch) from the set"
2021-03-19 17:21:00 +00:00
]
},
{
"cell_type": "code",
2023-05-27 23:29:17 +01:00
"execution_count": 20,
2021-03-19 17:21:00 +00:00
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 517
},
"executionInfo": {
"elapsed": 314527,
"status": "ok",
"timestamp": 1615991771417,
"user": {
"displayName": "Andy Pack",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjA4K4ZhdArHXAFbAGr4n0aCv2HmyUpx4cy6zcUq34=s64",
"userId": "16615063155528027547"
},
"user_tz": 0
},
"id": "F9Xre1o6SesD",
2021-03-26 20:01:05 +00:00
"outputId": "d6b817aa-58cd-4510-807f-e5e6bcf62f18",
"tags": [
"exp1"
]
2021-03-19 17:21:00 +00:00
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2023-05-27 23:29:17 +01:00
"Nodes: 2, Epochs: 8\n"
2021-03-19 17:21:00 +00:00
]
},
{
"data": {
2023-05-27 23:29:17 +01:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAACXwAAATfCAYAAAC7lK1oAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAB7CAAAewgFu0HU+AAEAAElEQVR4nOzdd3hUVf7H8c+k95CEXhOagPQSpIMKigoKFhSUIta14K5t0V0XdVfUXVSs+9OliCLYEFQsiPQaekeKoYeW3svM/f0x5DpDephkkvB+PU+e3Ln33HPPnZyZmznzvd9jMQzDEAAAAAAAAAAAAAAAAACgyvNwdwMAAAAAAAAAAAAAAAAAAKVDwBcAAAAAAAAAAAAAAAAAVBMEfAEAAAAAAAAAAAAAAABANUHAFwAAAAAAAAAAAAAAAABUEwR8AQAAAAAAAAAAAAAAAEA1QcAXAAAAAAAAAAAAAAAAAFQTBHwBAAAAAAAAAAAAAAAAQDVBwBcAAAAAAAAAAAAAAAAAVBMEfAEAAAAAAAAAAAAAAABANUHAFwAAAAAAAAAAAAAAAABUEwR8AQAAAAAAAAAAAAAAAEA1QcAXAAAAAAAAAAAAAAAAAFQTBHwBAAAAAAAAAAAAAAAAQDVBwBcAAAAAAAAAAAAAAAAAVBMEfAEAAAAAAAAAAAAAAABANUHAFwAAAAAAAAAAAAAAAABUEwR8AQAAAAAAAAAAAAAAAEA1QcAXAAAAAAAAAAAAAAAAAFQTBHwBAAAAAAAAAAAAAAAAQDVBwBcA4LKxYsUKWSwWWSwWDRw4sFKPHRkZaR77yJEjlXps1Fyl7dP5ZSwWi8uOPX78eLPO2bNnu6xeVzty5IjZzsjISHc3BwAAAACAKoXxMtQ0jJeVjPEyAABqBgK+AABFcvzg56qfKVOmuPu0gGI99dRTTgMehmGUq56EhAT5+vpWi0Ee1DxTpkxx24A9AAAAANRkjJfhcsR4GWqy//3vf07vyUOHDnV3kwAAKBUCvgAAAByMGzfOXD569KhWrlxZrnrmz5+vnJwcSVJgYKBuu+02l7TvcsbdhwAAAAAAAJWP8bKqi/GyS/fxxx87Pf7ll18UFxfnptYAAFB6Xu5uAACg6goJCdEjjzxSbJmYmBht2rRJktSwYUONGDGi2PLR0dEuax9QETp06KAuXbpo27ZtkqQ5c+aUK0PSnDlzzOVbb71VQUFBrmoiAAAAAABwE8bLcDlivAw11eHDh7VmzRqndVarVZ9++qmefvppN7UKAIDSIeALAFCk8PBwvfvuu8WWmTJlijmA1apVqxLLu9PAgQPLnW78Uh05csQtx0X5jBs3zhzA+vrrr/Xee+/J39+/1PsfOHBAGzdudKrPndzV76uCS5lmAAAAAACAizFe5jqMl1UvjJfVHIyX/cExCNHf31+ZmZmS7Fm/CPgCAFR1TOkIAABwkdGjR8vb21uSlJKSooULF5Zpf8eBgqZNm2rQoEGubB4AAAAAAABQqRgvQ01jGIZTv3z11Vfl4+MjSdqzZ4+2bNnirqYBAFAqBHwBAABcpE6dOho6dKj52PGDf0kMw9Cnn35qPr7nnntksVhc2j4AAAAAAACgMjFehppm1apVZqbBwMBATZw4UTfccIO5/eOPP3ZTywAAKB0CvgAAFW7KlCmyWCyyWCyaMmWKJCkzM1MzZszQkCFD1LRpU/n4+MhisWj79u1O+yYnJ2vevHl68MEH1bNnT9WuXVs+Pj4KCQlRixYtdNddd+mLL76QzWYrsR0rVqww2zFw4MBCyxw5csQsExkZaa7fvHmz7rvvPrVu3VoBAQEKCwtTdHS0XnnlFaWnp5d47MjISLPeotLVDxw40CyzYsUKSVJCQoJee+019ejRQ7Vr15a/v7+aN2+uiRMnavfu3SUe11FMTIwmTpyo5s2by9/fX3Xq1FF0dLRee+01xcfHS5Jmz55ttmH8+PFlqt/R448/btbz4IMPlnq/zz
77zNzvyiuvLLTM8ePH9eKLL6p///6qV6+efH195ePjo4iICHXq1EmjR4/WBx98oNOnT5e7/ZJzWvlffvml1PWtXLlSR48eNR+PHTvWabsr+3Rp5T+npR1IW7hwoW6++WY1atRIvr6+aty4sQYPHqxPPvlEeXl5ZTp2ZmamFi5cqMcff1x9+/ZVvXr15OPjo6CgIEVGRmrEiBGaMWOGcnJyiqwjv19GRUWZ644ePep0XkWdY1Gv6eJs2LBBjz76qK688kqFhYXJz89PjRs31vXXX6933323VK/5wt738vLyNGfOHF177bXmc9ugQQPdcsst+v7770vVNnfas2ePnn76aXXp0kW1a9eWr6+vGjZsqIEDBzq9j5TGsmXLNHHiRHXo0EG1atWSl5eXAgIC1LhxY/Xr109PPPGEvv/++2L7RWW9FwAAAABARWC8jPGy0mK8jPGywjBedukcA7pGjhypwMBA3XPPPea6efPmKTc3t8z1rlmzRpMmTVKXLl1Ut25deXt7KyQkRB06dNC4ceM0b948c+rI4pw5c0avv/66Bg8erKZNm8rf31/+/v5q2rSphg4dqtdff73I987SvL86Gj9+vFl+9uzZpS6TlJSk6dOnq3///mrUqJG8vLxksViUlJTktO/Zs2c1a9YsjRs3Tl26dFF4eLi8vb1Vq1YttWnTRhMmTNDPP/9cYjsLU9bnOycnR3Xq1DHPZf369aU+1oABA8z93n777XK1FwBcygAA4BL84x//MCQZkowBAwaUWOYf//iHsXfvXuPKK6801zn+bNu2zdzv66+/Nnx9fQstd/FPp06djN9//73Yti5fvrzEtsbGxpplmjVrZthsNuOFF14wPDw8ijx2VFSUcfjw4WKP3axZM7N8bGxsoWUGDBhgllm+fLmxZs0ao1GjRkUe19PT0/jwww+LPa5hGIbNZjOeeuqpYs+hUaNGxvr1641Zs2aZ68aNG1di3UXZuHGjWU9YWJiRnZ1dqv2GDh1q7vfKK68U2P5///d/hr+/f6n6RJ8+fcrdfsMwjOzsbCM8PNysb9q0aaXab8KECeY+vXr1ctrmjj5tGIZTvcVJTU01brjhhmLb1bdvXyMuLs4YN26cuW7WrFmF1rdhwwYjKCioVOcbGRlpbN26tdB6HPtlaX4cXfyaLk5aWpoxatSoEutv0KCB8cMPPxRb18XveydOnDB69+5dbL0TJkwwrFZrsfWWVmnem0srNzfXeOyxxwxPT89i21+rVi1j9uzZxdaVlpZmDB8+vNR/y48++qjQeirzvQAAAAAAyorxMsbLCsN4GeNlhsF4meP7XmWPlzlKT083goODzeMsWbLEMAx7Hw8LCzPXf/PNN6Wu8/jx48bgwYNL9ffo2bNnkfVYrVbjxRdfNAICAkqsx8PDw9izZ0+BOkrz/uqoNH334jJr1qwxmjRpUmi7EhMTzf2mT59e4rhi/s/VV19tnD9/vsT2GsalPd9PPvmkue2+++4r1fEOHDhg7uPr62vEx8eXaj8AqEheAgCgEsXHx+v666/XsWPH5Ofnp759+6pZs2ZKS0vThg0bnMqePXtW2dnZkqTGjRurXbt2ql+/vgICApSWlqZ9+/Zp69atMgxDO3bsUP/+/bV9+3ZFRES4rL0vvviiXnrpJUlS586d1aFDB3l7e2v79u3aunWrJCk2Nla33HKLtm7dKi8v11xad+/ercmTJystLU1169ZVv379FBERoZMnT2rZsmXKzMyU1WrVQw89pA4dOuiqq64qsq4nn3xSb775pvk4KChIgwYNUv369XXmzBktX75cJ0+e1I033qgnnnjCJe2Pjo5W69atdeDAASUmJuqHH37QLbfcUuw+586d0y+//CLJfofdmDFjnLYvXLjQ6e7HkJAQ9erVS40bN5aXl5eSk5N14MAB7d69u9i730rLx8dHd911l9577z1J9jT1f/nLX4rdJzMzU1999ZX52PGuR6lq9Omi5Obm6sYbb9SqVa
vMdfXr11f//v0VHBysQ4cOac2aNVqzZo1GjBih5s2bl1hnYmKi0tLSJEl169bVlVdeqcaNGyswMFAZGRk6dOiQYmJ
2021-03-19 17:21:00 +00:00
"text/plain": [
2023-05-27 23:29:17 +01:00
"<Figure size 3000x1400 with 2 Axes>"
2021-03-19 17:21:00 +00:00
]
},
2023-05-27 23:29:17 +01:00
"metadata": {},
2021-03-19 17:21:00 +00:00
"output_type": "display_data"
}
],
"source": [
2021-03-26 20:01:05 +00:00
"single_result = random.choice([run for run in single_results if run[\"epochs\"] > 1])\n",
"single_history = single_result[\"history\"]\n",
"\n",
"fig, axes = plt.subplots(1, 2, figsize=(15,7))\n",
"fig.set_dpi(fig_dpi)\n",
"\n",
"# One panel per quantity: (training history key, validation history key, panel title)\n",
"panels = [(\"loss\", \"val_loss\", \"Training vs Validation Loss\"),\n",
"          (\"accuracy\", \"val_accuracy\", \"Training vs Validation Accuracy\")]\n",
"for ax, (train_key, val_key, title) in zip(axes, panels):\n",
"    ax.set_title(title)\n",
"    ax.plot(single_history[train_key], label=\"train\", lw=2)\n",
"    ax.plot(single_history[val_key], label=\"validation\", lw=2, c=(1,0,0))\n",
"    ax.set_xlabel(\"Epochs\")\n",
"    ax.grid()\n",
"    ax.legend()\n",
"\n",
"print(f\"Nodes: {single_result['nodes']}, Epochs: {single_result['epochs']}\")\n",
"plt.show()"
]
},
{
2023-05-27 23:29:17 +01:00
"attachments": {},
2021-03-19 17:21:00 +00:00
"cell_type": "markdown",
"metadata": {
2021-03-26 20:01:05 +00:00
"id": "0IQ7HfJCSDud",
"tags": [
"exp1"
]
2021-03-19 17:21:00 +00:00
},
"source": [
"### Accuracy Surface"
]
},
{
"cell_type": "code",
2023-05-27 23:29:17 +01:00
"execution_count": 21,
2021-03-19 17:21:00 +00:00
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 705
},
"executionInfo": {
"elapsed": 315450,
"status": "ok",
"timestamp": 1615991772345,
"user": {
"displayName": "Andy Pack",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjA4K4ZhdArHXAFbAGr4n0aCv2HmyUpx4cy6zcUq34=s64",
"userId": "16615063155528027547"
},
"user_tz": 0
},
"id": "X3MWHLxJElbc",
2021-03-26 20:01:05 +00:00
"outputId": "134671d0-bfd3-4ee6-aa02-1a2a5b23f3ca",
"tags": [
"exp1"
]
2021-03-19 17:21:00 +00:00
},
"outputs": [
{
"data": {
2023-05-27 23:29:17 +01:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAABZoAAAPUCAYAAAAkEozUAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAB7CAAAewgFu0HU+AAEAAElEQVR4nOzdd3gU1f4G8He2pBdC76ETem/SUUSsIFxQVLBee7sWbBdRr3qvHbsootj1p2JBURCkKSUh9A4BkhASAgmpW2b2/P6IO84m20t2k7yf58lD2J2dObOZnfLume+RhBACRERERERERERERER+0oW7AURERERERERERERUtzFoJiIiIiIiIiIiIqKAMGgmIiIiIiIiIiIiooAwaCYiIiIiIiIiIiKigDBoJiIiIiIiIiIiIqKAMGgmIiIiIiIiIiIiooAwaCYiIiIiIiIiIiKigDBoJiIiIiIiIiIiIqKAMGgmIiIiIiIiIiIiooAwaCYiIiIiIiIiIiKigDBoJiIiIiIiIiIiIqKAMGgmIiIiIiIiIiIiooAwaCYiIiIiIiIiIiKigDBoJiIiIiIiIiIiIqKAMGgmIiIiIiIiIiIiooAwaCYiIiIiIiIiIiKigDBoJiIiIiIiIiIiIqKAMGgmIiIiIiIiIiIiooAwaCYiIiIiIiIiIiKigDBoDrL58+dDkiRIkoT58+fX2nI7dOigLvfo0aO1tlwioobu6NGj6v63Q4cOLqezTyNJksd5VlRU4KWXXsKYMWPQtGlTGAwGt8eW7OxsPPDAA+jfvz+Sk5Oh0+nU6X///Xf/V46omg8++EDdtq699tpaXbYvnyGqG37//Xf1bzpu3LhwN8elcG73kSJUn7+6dg3jTXvHjRvn0zH4q6++wiWXXII2bdogOjra7WfC1/MDoroiXDkK+W7FihWYMWMGUlNTERsb69V1UENTV85vQsUQypmPGzcOa9ascXhs6dKluOyyy7yexwMPPIAXXnjB4bHHH3+cOx8iIqqXzpw5gzFjxmD37t1eTb9p0yZccMEFKC4uDm3DiIiIKGiEELj66qvx6aefejW9r+cHRETB9vDDD+O///1vuJtBES6kQbMzS5Ys8TpoVhQFn3zySYhbRM5oeysIIcLYksAcPXoUHTt2BACkpqbWiZ4SRNSwzZ07V72INBgMOPfcc5Gamgqj0QgAGDp0qDqtEAKzZ89WQ+ZGjRphwoQJaNGiBXS6qpuW2rRpU7srUA+E+tgxf/58PPHEEwD45TkR1R/aTkarV69ukL24fPHpp586hMxDhw5Fz549ER8fDwDo2rWrw/S+nB+Q9zp06IBjx44BALKystgrk8iFDRs2OITMvXr1woABA5CcnAwAaNKkSbiaRhGm1oPmH3/8EUVFRUhJSfE47YoVK5CXl1cLrSIiIgo/WZbx2Wefqf//7bffMGbMGJfTb9q0CQcOHAAANGvWDHv27EHTpk1D3k4iIiIKzEcffaT+/sQTT2DevHkup/X1/ICIKNi0+6wbbrgB7777LsuZkVO1VqO5Z8+eAACLxYLPP//cq9csWbKkxuuJiIjqIiGE+uPKgQMHUF5eDgDo0qWLx4vIrVu3qr9fdtllDJkp5K699lp1O/7ggw9qddnefIaobhk3bpz6N2U9+cjGz5/3fv/9d/W9cterW3sMv+GGG9zO09fzAyKiYNPus6677jqGzORSrQXNV1xxhXpbjzZAdqWkpARLly4FAPTv3x99+vQJZfOIiIjCrqioSP29VatWQZ+eiIiIIoMvx3Ae74ko3LgfIm/VWtDcrFkzTJ48GQCwceNGHDx40O30X331FSorKwEAc+bMCXn7iIiIws1qtaq/22ssB3N6IiIiigyyLKu/ezqG83hPROHG/RB5q1a3jtmzZ6u/e+rVbH/eYDBg1qxZPi9LCIGvvvoKV155JTp37oyEhAQkJCSgc+fOmDVrFv7v//7P51u/Vq9ejVmzZiE1NRUxMT
Fo1aoVRo8ejTfffBMVFRU+t9Hut99+wy233IJevXqhcePGiI6ORuvWrTFp0iS8/vrrauAear///jskSapxC4T9seo/7gZHys7OxlNPPYXRo0ejdevWiI6ORuPGjTFgwADcf//9ak1RT6xWKz7++GNcfvnl6NSpExISEmAwGJCYmIguXbpg0qRJmDdvHjZv3uzwug8++ACSJKmDOQHAsWPHXK5LICorK7F06VLcddddGDVqFFq0aIGoqCgkJCSgQ4cOmDp1KhYtWgSLxeLzvI8cOYL58+djzJgxaNOmDWJiYhAXF4dOnTphypQpeO2111BQUOBxPiaTCe+//z5mzJiBzp07IykpCVFRUWjevDlGjx6Nhx56CJs2bXL62muvvVZ9n7y5Tdr+3kuShGuvvdbraRRFweeff47LLrsMnTp1QmxsLCRJUu9ssIu09/uyyy5T1+XZZ5/1elmPP/64+rrLL7/c57Y6U1ZWhldffRWTJk1C27ZtERMTg5SUFPTu3Rt33HGHy7+xXd++fdU2aesAevLPf/5Tfd3tt9/udtq9e/fikUcewdChQ9W/XbNmzTBs2DDMmzcPJ06c8Li8cePGqcuz32qdl5eHZ555BkOHDkXLli2h1+vRqFEjr9fBlby8PDz66KPo27cvkpKSkJSUhF69euHee+/F/v37fZqXq33O0aNH1cfHjx+vPr5mzZoa+6px48Y5fH7sA8oBVfUdq0/vbpC5LVu24N5770X//v3RrFkzREVFoWXLlhg7diz+97//OfRacKVDhw41jgmHDx/Go48+igEDBqBZs2bQ6XTo37+/09eXl5fjrbfewiWXXILU1FTExcUhMTERXbt2xfXXX49Vq1Z5bIOrfc63336LSy65BO3bt0d0dDSaN2+O888/Hx9//LHL43+ojx32bdfT383Z/tPZvri4uBgLFixQ91kGgwGSJKmDQ9oVFBRg8eLFmDNnDgYMGIDGjRvDaDSiUaNGSEtLw3XXXYdffvnFq3XwZh+vPZ/Q3ja+atUqXHHFFejUqRNiYmLQpEkTjBkzBq+//rrDxYsr3rz3zrbJnJwc/Pvf/0a/fv3QqFEjxMfHIy0tDXfeeac6AJS3VqxYgSuuuALt27d3OBd844031Nva58+f79Vn0FvO1mnfvn2455570LNnT3Xf1LdvXzz22GM4efKkT/MXQuDbb7/FnDlz0K1bNyQnJyMmJgbt2rXDlClT8OGHHzoEYs5o92PagbTWr1+PG2+8EWlpaUhOToYkSbjnnnvU511tK+788ssvuP7669GtWzckJSUhNjYWqampmDp1Kj744AOvtiWtrVu34qabblLPPZo1a4ahQ4fiueeew5kzZ3yaF1C1nd9www3o06cPGjVqBIPBgLi4OLRt2xajR4/GPffcgx9//NGv8xS7tLQ09X3bt2+fy+m0+w1JkrBr1y6X0954443qdIsWLarxvLvPn/1x+0CAADB+/Hin+zZvziWD+Zn1ltlsxmuvvYbRo0ejWbNmiI2NVa8bV69e7dO8nJ2n2Gk/z1qurrV8OT9wJRjXZYEcg4DQ7mfS09Nx4403olu3boiLi0NKSgqGDh2KZ555Rt0vu5uXdpvq2LGj07+Fv6V9gnmO4mmerrh637yZZt26dbjmmmvQpUsXxMXFITk5GePGjcOnn37qtJ2rV6/GP/7xD3Tt2hWxsbFo3rw5LrroIvz8888e2+lMeXk53njjDYwePRotW7ZETEwMUlNTcdVVVznsb7wV7uuQ7OxsPPHEExgzZgxatGiB6OhoREVFoUmTJujXrx9mzZqFt956y+fjuCcbN27EHXfcgV69eiElJQUxMTFo27YtLrjgArz++usuPyfV19fTZ8WfwbOD+bl2JZD1d+bgwYO45557kJaWhvj4eDRu3Bj9+/fHvHnzkJOT49O8tIKxfQJVmcDbb7+Niy66CO3bt0dcXByMRiOSk5ORlpaGSy65BM8884zbc4KAiRAaO3asACAAiL
feekuYzWaRkpIiAIgOHToIm83m9HVZWVlCkiQBQFx00UVCCCFmzpypzuvxxx93u9wDBw6IAQMGqNO7+hk0aJA4fPi
2021-03-19 17:21:00 +00:00
"text/plain": [
2023-05-27 23:29:17 +01:00
"<Figure size 1600x1000 with 2 Axes>"
2021-03-19 17:21:00 +00:00
]
},
2023-05-27 23:29:17 +01:00
"metadata": {},
2021-03-19 17:21:00 +00:00
"output_type": "display_data"
}
],
"source": [
"# Index test accuracy (results[1]) by (nodes, epochs) so the grid does not depend on the order of single_results\n",
"test_acc_by_param = {(r[\"nodes\"], r[\"epochs\"]): r[\"results\"][1] for r in single_results}\n",
"# Rows follow hidden_nodes and columns follow epochs, matching np.meshgrid(epochs, hidden_nodes)\n",
"shaped_result = np.array([[test_acc_by_param[(n, e)] for e in epochs] for n in hidden_nodes])\n",
"X, Y = np.meshgrid(epochs, hidden_nodes)\n",
"\n",
"fig = plt.figure(figsize=(8, 5))\n",
"fig.set_dpi(fig_dpi)\n",
"ax = plt.axes(projection='3d')\n",
"\n",
"surf = ax.plot_surface(X, Y, shaped_result, cmap='viridis')\n",
"ax.set_title('Model test accuracy over different training periods with different numbers of nodes')\n",
"ax.set_xlabel('Epochs')\n",
"ax.set_ylabel('Hidden Nodes')\n",
"ax.set_zlabel('Accuracy')\n",
"ax.view_init(30, -110)\n",
"fig.colorbar(surf, shrink=0.3, aspect=6)\n",
"\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
2023-05-27 23:29:17 +01:00
"attachments": {},
2021-03-19 17:21:00 +00:00
"cell_type": "markdown",
"metadata": {
2021-03-26 20:01:05 +00:00
"id": "C793_RHvSGai",
"tags": [
"exp1"
]
2021-03-19 17:21:00 +00:00
},
"source": [
"### Error Rate Curves"
]
},
{
"cell_type": "code",
2023-05-27 23:29:17 +01:00
"execution_count": 22,
2021-03-19 17:21:00 +00:00
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 668
},
"executionInfo": {
"elapsed": 316211,
"status": "ok",
"timestamp": 1615991773109,
"user": {
"displayName": "Andy Pack",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjA4K4ZhdArHXAFbAGr4n0aCv2HmyUpx4cy6zcUq34=s64",
"userId": "16615063155528027547"
},
"user_tz": 0
},
"id": "tpClZMptrq-q",
2021-03-26 20:01:05 +00:00
"outputId": "f9fe93f9-7b67-4772-83e4-9e3567fd9318",
"tags": [
"exp1"
]
2021-03-19 17:21:00 +00:00
},
"outputs": [
{
"data": {
2023-05-27 23:29:17 +01:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAABaIAAAOrCAYAAAClQvEMAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAB7CAAAewgFu0HU+AAEAAElEQVR4nOzdd1gU1/4/8PfSe1OkixW7goKKFQuW2Ev0aowtlvjVmBiNJSZGUzRer+1qLLGhMTFq7BpjVxRFAREbdrECSgfpML8/+DF3F9gGuyzC+/U8PM+WM2fODDNnZj9z5jMSQRAEEBERERERERERERFpiZ6uG0BERERERERERERElRsD0URERERERERERESkVQxEExEREREREREREZFWMRBNRERERERERERERFrFQDQRERERERERERERaRUD0URERERERERERESkVQxEExEREREREREREZFWMRBNRERERERERERERFrFQDQRERERERERERERaRUD0URERERERERERESkVQxEExEREREREREREZFWMRBNRERERERERERERFrFQDQRERERERERERERaRUD0URERERERERERESkVQxEExEREREREREREZFWMRBNRERERERERERERFrFQDQRERERERERERERaRUD0URERERERERERESkVQxEExEREREREREREZFWMRBNRKSivXv3ol+/fnBxcYGxsTEkEgkkEgn8/Px03bQqbeHCheL/YuHChbpuTpnVqlVLXJ6oqChdN6eYyra+K6qxY8eK6zkgIEDXzXnv5ebmYtOmTfD394eDgwOMjIzE9Tt27Fitzz8qKkqcX61ateSWKywjkUiU1pmeno4VK1agU6dOqF69OgwMDBTumy9evMBXX30FT09PWFtbQ09PTyx//vz50i8cUQWlzv5E5UfV/lBV2jgvOX/+vFbO87VxjsfzhcqD59jao+l+Rx1+fn483yrCQNcN0KWoqCjUrl1bo3V+99137DSIKhlBEDBq1Cj88ccfum4KERGpKSsrC7169apUJ/8JCQno1KkT7ty5o1L5q1evolevXkhKStJuw4iIiIiIFKjSgejK7Pz58+jSpQsAoHPnzpXqxxfpjvSIEkEQdNiS8vXHH3/IBKFbt26Nxo0bw9zcHABQv359XTWNqELiMUjW2LFjsX37dgDAtm3bymUELv3Pf/7zH5ltsHPnzqhXrx5MTEwAAG3bttVRy0pvzpw5YhDawMAA3bp1g7u7OwwNDQEUHKcKCYKA0aNHi0FoGxsbdO3aFQ4ODtDTK7g50sXFpXwXoBKQHtDi7u5eIe9gISIi0hSe35OmVOlAtJWVFaZOnaqwzLVr1xASEgIAcHZ2xqBBgxSWlz7xJ6LK4bfffhNfL1q0CAsWLNBha4iISB3Sffj27dsxevRoHbam7HJzc7Fr1y7x/ZkzZ9CpUye55a9evYoHDx4AAOzt7XH37l1Ur15d6+0kIiIiIiqqSgei7ezssHbtWoVlFi5cKAai69evr7Q8EVU+169fF19/8sknOmwJlWThwoVMiVSOuL7LR0BAAHM9akB6ejru378PADAyMsKoUaN03CLFVLnb6MGDB3j37h0AoF69egqD0IDsMWzAgAEMQhMRKeHn51el7v4kquxq1aqls32aI8eL48MKiYiUSExMFF87OTnpsCVERKQO6f5bOhXF+0zdYxKPYURERERUUbz/Z+NERFqWm5srvq4MQQwioqoiJydHfF1Z+m91l6kyrgMiIiIiej/xbFTDIiMj8fXXX6N169ZwcHCAkZER7O3t0aZNGyxYsACvX79WqZ60tDRs2LABffr0Qc2aNWFmZgZDQ0NYW1ujYcOG6NevHxYvXozbt2/LTLdw4UJIJBIxiTwAXLhwARKJpNhfrVq1NLLMOTk5+O233zBs2DDUqVMHlpaWMDc3R+3atTFixAgcOHBA6W0Q58+fF9vl5+cnfv73339jxIgRqF+/PiwsLC
CRSLBq1SoABQ+JKWlZLl26hAkTJqBhw4awtraGRCLBF198UeJ8T5w4gfHjx8PDwwNWVlYwNTWFu7s7Bg0ahICAAJkfb/KMHTtWbEfhbdRJSUlYvXo1OnXqBBcXFxgYGEAikZTqafWlWTeFcnJycOLECcyePRtdunSBs7MzTExMYGpqCldXV/Tu3RurVq1CWlqaSvOXVtI2JZFIFD6s58WLF/jhhx/QsWNHODs7w9jYGHZ2dvDy8sKsWbPEHJbK5OTkYOfOnRg8eDDq1KkDCwsLGBgYwNLSEvXq1UPPnj2xYMECXLt2TaX6SlKrVi2Vl1ueO3fu4KuvvoKXlxeqV68OY2NjODs7w8/PD0uXLkV8fLzSdgQEBIjzKXzAWV5eHv78808MGDAAderUgampKSQSCQ4ePFjq5X327BnWr1+PESNGoGnTprC2toahoSGqVauGZs2aYcqUKQgODi51/Yq8ePECixYtQqdOneDg4ABjY2MYGRmhWrVqaNGiBUaOHIn169cjJiamxOkL+z2JRCI3ZURJ6xEADhw4gH79+qFmzZowNjZGjRo10KNHD+zcuVOt27cSExPx448/wtvbG7a2trCwsECDBg0wYcIEMb0TAJW2G3Vpog9Wh6L1rcljUEhICGbMmAFPT0/Y29vDyMgIjo6O6Ny5M5YuXSozylMe6f24sG96/Pgx5s+fDy8vL9jb20NPTw+enp7Fpo2MjMTKlSsxePBgNGjQAJaWljA0NIS9vT28vb0xY8YM3L17V6X5Fz6oEADGjRtX4vooui5LOrYokpaWhv/+97/o2bMnXF1dYWJiAltbWzRt2hTTpk3D1atXldYBlLyN3r9/H1988QUaNWoECwsLWFlZoUWLFpg3bx7i4uJUqlddgiBg7969GDFiBOrWrQsLCwtYWFigbt26GDlyJP766y+F23XhMhQ+TA4o6Oe0cS4UHR2N+fPno3nz5rCysoKVlRWaNGmCGTNmiGlBVCWvj5A+51G2f/n5+cn0eYsWLRLLL1q0SOm2J03X+yEAvHv3DuvXr0e/fv3g7u4OMzMzWFpaon79+hg/fjzOnj2rtA2aPAYU1qVs29JkXy8IAg4cOIAxY8bAw8MD1tbWMDExgZubGwYOHIjt27fLXDQvibzz5jNnzoj7mampKezt7dGxY0esXbsWWVlZarVTE+c9Rd28eRNz585FmzZt4OjoCCMjI/EYO3z4cGzZsgXJyckq16eJ/iwuLg7/+c9/0L17d/Hc2tDQEDY2NmjSpAmGDh2KFStW4OnTp2ovb0nevHmDbdu2YcyYMfDy8oKdnZ04v4YNG2LcuHE4ceKESnWVdAzPzc3Fjh070L17d7i4uMDY2BhOTk4YOHAgjh49qlZbNdkflkVZl0ne7y95srKysGbNGnTs2BH29vYwNTUVj1fnzp0r9XIcPHgQAwYMEJfB1dUV/v7++O2335Tu8/Josz8JDQ3FhAkT4OHhATMzM9ja2qJ169ZYvHixmFJKE7R9Xl/Wc5CSnDt3DiNHjoS7uztMTEzg5OSEjh07Yt26dUhPT1erLmlnzpzBp59+iiZNmsDOzk7sd3v27Im1a9ciIyOj1HUXKsv5fVl+y2ZkZODgwYOYPn06OnToIMbXLCwsUKtWLQwaNAhbtmxBdna20mWQt80WpY3zYD8/P7FOeWk6SjrvT09Px7p168RlNzY2hpubG0aMGIGgoCCl85V27949TJ8+HQ0aNIC5uTns7Ozg6emJb7/9Fi9evACgfr9XJgIp9N133wkABABC586d5ZbLzMwUJk+eLOjr64vlS/ozNTUV1qxZo3Cely9fFlxcXBTWI/2Xk5NTYnuV/bm7u5d5/Zw7d06oW7eu0nm1bdtWePnypcJ6pNdzUlKSMGjQoBLrWrlypSAIgvD06VOZZcnKyhImT55c4jSff/65zPxiY2OFbt26KW13/fr1hZCQEIXrYMyYMWL5bdu2CZcuXRLc3NxKrC8xMbFU61jddSMIgv
D8+XOhWrVqKm0L1apVE06ePKl0/qr8PX36tFgdeXl5wrfffiuYmJgonNbAwED4+uuvhfz8fLnr4/79+0KjRo1Ubs/
2021-03-19 17:21:00 +00:00
"text/plain": [
2023-05-27 23:29:17 +01:00
"<Figure size 1600x1000 with 1 Axes>"
2021-03-19 17:21:00 +00:00
]
},
2023-05-27 23:29:17 +01:00
"metadata": {},
2021-03-19 17:21:00 +00:00
"output_type": "display_data"
}
],
"source": [
2021-03-22 20:49:29 +00:00
"# Look results up by (nodes, epochs) so each curve lines up with the epochs axis regardless of result order\n",
"test_acc_by_param = {(r[\"nodes\"], r[\"epochs\"]): r[\"results\"][1] for r in single_results}\n",
"\n",
"fig, ax = plt.subplots(figsize=(8, 5))\n",
"fig.set_dpi(fig_dpi)\n",
"\n",
"for n_nodes in hidden_nodes:\n",
"    error_rate = 1 - np.array([test_acc_by_param[(n_nodes, e)] for e in epochs])\n",
"    ax.plot(epochs, error_rate, label=f'{n_nodes} Nodes')\n",
"\n",
"ax.legend()\n",
"ax.grid()\n",
"ax.set_title(\"Test error rates for a single iteration of different epochs and hidden node training\")\n",
"ax.set_xlabel(\"Epochs\")\n",
"ax.set_ylabel(\"Error Rate\")\n",
"ax.set_ylim(0)\n",
"\n",
"plt.show()"
]
},
{
2023-05-27 23:29:17 +01:00
"attachments": {},
2021-03-19 17:21:00 +00:00
"cell_type": "markdown",
"metadata": {
2021-03-26 20:01:05 +00:00
"id": "7mJaKjlCxEkt",
"tags": [
"exp1"
]
2021-03-19 17:21:00 +00:00
},
"source": [
"## Multiple Iterations\n",
"\n",
2021-03-22 20:49:29 +00:00
"Run multiple iterations of the epoch/layer investigations and average the results across iterations\n",
"\n",
2021-03-26 20:01:05 +00:00
"### CSV Results\n",
2021-03-22 20:49:29 +00:00
"\n",
"| test | learning rate | momentum | batch size | hidden nodes | epochs |\n",
"| --- | --- | --- | --- | --- | --- |\n",
"|1|0.01|0|128|2, 8, 12, 16, 24, 32, 64, 128, 256|1, 2, 4, 8, 16, 32, 64, 100, 150, 200|\n",
"|2|0.5|0.1|128|2, 8, 12, 16, 24, 32, 64, 128|1, 2, 4, 8, 16, 32, 64, 100|\n",
"|3|0.2|0.05|128|2, 8, 12, 16, 24, 32, 64, 128|1, 2, 4, 8, 16, 32, 64, 100|\n",
"|4|0.08|0.04|128|2, 8, 12, 16, 24, 32, 64, 128|1, 2, 4, 8, 16, 32, 64, 100|\n",
"|5|0.08|0|128|2, 8, 12, 16, 24, 32, 64, 128|1, 2, 4, 8, 16, 32, 64, 100|\n",
"|6|0.06|0|128|1, 2, 3, 4, 5, 6, 7, 8|1, 2, 4, 8, 16, 32, 64, 100|\n",
"|7|0.06|0|35|2, 8, 12, 16, 24, 32, 64, 128|1, 2, 4, 8, 16, 32, 64, 100|\n",
"\n",
2021-03-26 20:01:05 +00:00
"### Pickle Results\n",
2021-03-22 20:49:29 +00:00
"\n",
2021-04-29 22:53:26 +01:00
"| test | learning rate | momentum | batch size | hidden nodes | epochs | stratified |\n",
"| --- | --- | --- | --- | --- | --- | --- |\n",
"|1|0.01|0|128|2, 8, 12, 16, 24, 32, 64, 128, 256|1, 2, 4, 8, 16, 32, 64, 100, 150, 200| |\n",
"|2|0.5|0.1|128|2, 8, 12, 16, 24, 32, 64, 128|1, 2, 4, 8, 16, 32, 64, 100| |\n",
"|3|1|0.3|20|2, 8, 12, 16, 24, 32, 64, 128|1, 2, 4, 8, 16, 32, 64, 100| |\n",
"|4|0.6|0.1|20|2, 8, 16, 24, 32|1, 2, 4, 8, 16, 32, 64, 100, 150, 200| |\n",
"|5|0.05|0.01|20|2, 8, 16, 24, 32|1, 2, 4, 8, 16, 32, 64, 100, 150, 200| |\n",
"|6|1.5|0.5|20|2, 8, 16, 24, 32|1, 2, 4, 8, 16, 32, 64, 100, 150, 200| |\n",
2021-04-30 20:51:04 +01:00
"|2-1|0.01|0|35|2, 8, 16, 24, 32|1, 2, 4, 8, 16, 32, 64, 100, 150, 200| n |\n",
"|2-2|0.1|0|35|2, 16, 32|1, 2, 4, 8, 16, 32, 64, 100| n |\n",
"|2-3|0.15|0|35|2, 16, 32|1, 2, 4, 8, 16, 32, 64, 100| n |\n",
"|2-4|0.08|0.9|35|1, 2, 8, 16, 32, 64|1, 2, 4, 8, 16, 32, 64, 100| n |\n",
"|2-5|0.08|0.2|35|1, 2, 8, 16, 32, 64|1, 2, 4, 8, 16, 32, 64, 100| n |\n",
2021-05-03 17:48:15 +01:00
"|2-6|0.01|0.1|35|2, 8, 16, 24, 32|1, 2, 4, 8, 16, 32, 64, 100, 150, 200| n |\n",
"|2-7|0.01|0.9|35|1, 2, 8, 16, 32, 64|1, 2, 4, 8, 16, 32, 64, 100| n |\n",
"|2-8|0.01|0.5|35|1, 2, 8, 16, 32, 64|1, 2, 4, 8, 16, 32, 64, 100| n |\n",
"|2-9|0.01|0.3|35|1, 2, 8, 16, 32, 64|1, 2, 4, 8, 16, 32, 64, 100| n |\n",
"|2-10|0.01|0.7|35|1, 2, 8, 16, 32, 64|1, 2, 4, 8, 16, 32, 64, 100| n |\n",
"|2-11|0.01|0.0|35|1, 2, 8, 16, 32, 64|1, 2, 4, 8, 16, 32, 64, 100| n |\n",
"|2-12|0.1|0.0|35|1, 2, 8, 16, 32, 64|1, 2, 4, 8, 16, 32, 64, 100| y |\n",
"|2-13|0.5|0.0|35|1, 2, 8, 16, 32, 64|1, 2, 4, 8, 16, 32, 64, 100| y |\n",
"|2-14|0.05|0.0|35|1, 2, 8, 16, 32, 64|1, 2, 4, 8, 16, 32, 64, 100| y |"
2021-03-19 17:21:00 +00:00
]
},
{
"cell_type": "code",
2023-05-27 23:29:17 +01:00
"execution_count": null,
2021-03-19 17:21:00 +00:00
"metadata": {
2021-03-26 20:01:05 +00:00
"id": "-lsKo4BCP3yw",
2023-05-27 23:29:17 +01:00
"tags": []
2021-03-19 17:21:00 +00:00
},
2023-05-27 23:29:17 +01:00
"outputs": [],
2021-03-19 17:21:00 +00:00
"source": [
2021-03-21 09:56:27 +00:00
"# Hyper-parameters for this run (matches row 2-3 of the pickle results table)\n",
"multi_iterations = 30\n",
"STRATIFY_SPLIT = False  # set True to stratify the train/test split on the labels\n",
"\n",
"multi_param_results = list()\n",
"for i in range(multi_iterations):\n",
"    print(f\"Iteration {i+1}/{multi_iterations}\")\n",
"    # Seed each split with the iteration index: every iteration differs, but the run is reproducible\n",
"    data_train, data_test, labels_train, labels_test = train_test_split(\n",
"        data, labels, test_size=0.5, random_state=i,\n",
"        stratify=labels if STRATIFY_SPLIT else None)\n",
"    multi_param_results.append(list(evaluate_parameters(\n",
"        dtrain=data_train,\n",
"        dtest=data_test,\n",
"        ltrain=labels_train,\n",
"        ltest=labels_test,\n",
"        hidden_nodes=[2, 16, 32],\n",
"        epochs=[1, 2, 4, 8, 16, 32, 64, 100],\n",
"        optimizer=lambda: tf.keras.optimizers.legacy.SGD(learning_rate=0.15, momentum=0.0),\n",
"        weight_init=lambda: 'random_uniform',\n",
"        return_model=False,\n",
"        print_params=False,\n",
"        batch_size=35)))"
2021-03-19 17:21:00 +00:00
]
},
{
2023-05-27 23:29:17 +01:00
"attachments": {},
2021-03-19 17:21:00 +00:00
"cell_type": "markdown",
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp1"
]
},
2021-03-19 17:21:00 +00:00
"source": [
"### Accuracy Tensor\n",
"\n",
"Create a tensor for holding the accuracy results\n",
"\n",
2021-03-22 20:49:29 +00:00
"(Iterations x [Test/Train] x Number of nodes x Number of epochs)"
2021-03-19 17:21:00 +00:00
]
},
{
"cell_type": "code",
2021-05-04 15:24:37 +01:00
"execution_count": 268,
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp1"
]
},
2021-03-22 20:49:29 +00:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"30 Tests\n",
2021-05-03 17:48:15 +01:00
"Nodes: [1, 2, 8, 16, 32, 64]\n",
"Epochs: [1, 2, 4, 8, 16, 32, 64, 100]\n",
2021-03-22 20:49:29 +00:00
"\n",
"Loss: categorical_crossentropy\n",
2021-05-03 17:48:15 +01:00
"LR: 0.05\n",
2021-04-28 21:57:13 +01:00
"Momentum: 0.0\n"
2021-03-22 20:49:29 +00:00
]
}
],
2021-03-19 17:21:00 +00:00
"source": [
2021-03-21 09:56:27 +00:00
"# The sweep axes come from the first iteration; every iteration runs the same grid\n",
"multi_param_epochs = sorted({run[\"epochs\"] for run in multi_param_results[0]})\n",
"multi_param_nodes = sorted({run[\"nodes\"] for run in multi_param_results[0]})\n",
"multi_param_iter = len(multi_param_results)\n",
"epoch_pos = {e: k for k, e in enumerate(multi_param_epochs)}\n",
"node_pos = {n: k for k, n in enumerate(multi_param_nodes)}\n",
"\n",
"# accuracy_tensor[iteration, 0 = test / 1 = last-epoch train, node index, epoch index]\n",
"accuracy_tensor = np.zeros((multi_param_iter, 2, len(multi_param_nodes), len(multi_param_epochs)))\n",
"for iter_idx, iteration in enumerate(multi_param_results):\n",
"    for run in iteration:\n",
"        target = (iter_idx, slice(None), node_pos[run['nodes']], epoch_pos[run['epochs']])\n",
"        accuracy_tensor[target] = [run[\"results\"][1], run[\"history\"][\"accuracy\"][-1]]\n",
"\n",
"# Aggregate over iterations -> shape (2, nodes, epochs)\n",
"mean_param_accuracy = accuracy_tensor.mean(axis=0)\n",
"std_param_accuracy = accuracy_tensor.std(axis=0)\n",
"\n",
"first_run = multi_param_results[0][0]\n",
"print(f'{multi_param_iter} Tests')\n",
"print(f'Nodes: {multi_param_nodes}')\n",
"print(f'Epochs: {multi_param_epochs}')\n",
"print()\n",
"print(f'Loss: {first_run[\"loss\"]}')\n",
"print(f'LR: {first_run[\"optimizer\"][\"learning_rate\"]:.3}')\n",
"print(f'Momentum: {first_run[\"optimizer\"][\"momentum\"]:.3}')"
2021-03-19 17:21:00 +00:00
]
},
{
2023-05-27 23:29:17 +01:00
"attachments": {},
2021-03-19 17:21:00 +00:00
"cell_type": "markdown",
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp1"
]
},
2021-03-19 17:21:00 +00:00
"source": [
"#### Export/Import Test Sets\n",
"\n",
"Export the mean and standard deviation accuracy tensors for later retrieval and visualisation"
]
},
{
2021-04-30 20:51:04 +01:00
"cell_type": "code",
"execution_count": 215,
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp1"
]
},
2021-04-30 20:51:04 +01:00
"outputs": [],
2021-03-21 09:56:27 +00:00
"source": [
2021-04-30 20:51:04 +01:00
"# Use a context manager so the results file is flushed and closed deterministically\n",
"with open(\"results/exp1-test2-3.p\", \"wb\") as results_file:\n",
"    pickle.dump(multi_param_results, results_file)"
2021-03-21 09:56:27 +00:00
]
},
{
2021-04-06 17:29:15 +01:00
"cell_type": "code",
2021-05-04 15:24:37 +01:00
"execution_count": 267,
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp1"
]
},
2021-04-06 17:29:15 +01:00
"outputs": [],
2021-03-21 09:56:27 +00:00
"source": [
2021-05-03 17:48:15 +01:00
"exp1_testname = 'exp1-test2-14'\n",
"# Only load result files produced by this notebook: unpickling can execute arbitrary code\n",
"with open(f\"results/{exp1_testname}.p\", \"rb\") as results_file:\n",
"    multi_param_results = pickle.load(results_file)"
2021-03-21 09:56:27 +00:00
]
},
{
"cell_type": "raw",
"metadata": {},
2021-03-19 17:21:00 +00:00
"source": [
2021-03-22 20:49:29 +00:00
"# np.savetxt only writes 1D/2D arrays, so export the test-set slice ([0]) of the (2, nodes, epochs) tensors\n",
"np.savetxt(\"exp1-mean.csv\", mean_param_accuracy[0], delimiter=',')\n",
"np.savetxt(\"exp1-std.csv\", std_param_accuracy[0], delimiter=',')"
2021-03-19 17:21:00 +00:00
]
},
{
"cell_type": "raw",
"metadata": {},
"source": [
2021-03-22 20:49:29 +00:00
"# loadtxt returns at most a 2D grid; add a leading axis so mean_param_accuracy[0, :, :] keeps working downstream\n",
"mean_param_accuracy = np.loadtxt(\"results/test1-exp1-mean.csv\", delimiter=',', ndmin=2)[np.newaxis]\n",
"std_param_accuracy = np.loadtxt(\"results/test1-exp1-std.csv\", delimiter=',', ndmin=2)[np.newaxis]\n",
"# multi_iterations = 30"
]
},
{
2023-05-27 23:29:17 +01:00
"attachments": {},
2021-03-19 17:21:00 +00:00
"cell_type": "markdown",
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp1"
]
},
2021-03-19 17:21:00 +00:00
"source": [
2021-03-22 20:49:29 +00:00
"### Best Results"
2021-03-19 17:21:00 +00:00
]
},
{
"cell_type": "code",
2021-05-03 17:48:15 +01:00
"execution_count": 166,
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp1"
]
},
2021-03-22 20:49:29 +00:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2021-05-03 17:48:15 +01:00
"Nodes: 32, Epochs: 100, 96.1% Accurate\n"
2021-03-22 20:49:29 +00:00
]
}
],
"source": [
"# Best mean *test* accuracy: search slice 0 of the [test/train] axis using that slice's own 2D shape\n",
"test_mean_accuracy = mean_param_accuracy[0]\n",
"best_node_idx, best_epoch_idx = np.unravel_index(np.argmax(test_mean_accuracy), test_mean_accuracy.shape)\n",
"best_param_accuracy_idx = (0, best_node_idx, best_epoch_idx)  # index into the full (2, nodes, epochs) tensor\n",
"best_param_accuracy = mean_param_accuracy[best_param_accuracy_idx]\n",
"best_param_accuracy_nodes = multi_param_nodes[best_node_idx]\n",
"best_param_accuracy_epochs = multi_param_epochs[best_epoch_idx]\n",
"\n",
"print(f'Nodes: {best_param_accuracy_nodes}, Epochs: {best_param_accuracy_epochs}, {best_param_accuracy * 100:.3}% Accurate')"
2021-03-22 20:49:29 +00:00
]
},
{
2023-05-27 23:29:17 +01:00
"attachments": {},
2021-03-22 20:49:29 +00:00
"cell_type": "markdown",
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp1"
]
},
2021-03-22 20:49:29 +00:00
"source": [
"### Test Accuracy Surface"
]
},
{
"cell_type": "code",
2021-05-04 15:24:37 +01:00
"execution_count": 269,
2021-03-19 17:21:00 +00:00
"metadata": {
"executionInfo": {
"elapsed": 2653358,
"status": "aborted",
"timestamp": 1615994110345,
"user": {
"displayName": "Andy Pack",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjA4K4ZhdArHXAFbAGr4n0aCv2HmyUpx4cy6zcUq34=s64",
"userId": "16615063155528027547"
},
"user_tz": 0
},
2021-03-26 20:01:05 +00:00
"id": "ZGJVhz6iJU-7",
"tags": [
"exp1"
]
2021-03-19 17:21:00 +00:00
},
"outputs": [
{
"data": {
2023-05-27 23:29:17 +01:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA6QAAAMlCAYAAABkUz6gAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAB7CAAAewgFu0HU+AAEAAElEQVR4nOzdd5xcVf3/8deZvi3Jpm56QnovuylIDQpfQVGwACpSBAQhCH5RELBgRyn+gu0rvViiX0WRbxRBxSTEtN30sum9Z5Nts9Pn/P6YvZO7s1N3Z2e2fJ6Pxzwyu3Pn3jOzk5n7nnPO5yitNUIIIYQQQgghRK5Z8t0AIYQQQgghhBA9kwRSIYQQQgghhBB5IYFUCCGEEEIIIUReSCAVQgghhBBCCJEXEkiFEEIIIYQQQuSFBFIhhBBCCCGEEHkhgVQIIYQQQgghRF5IIBVCCCGEEEIIkRcSSIUQQgghhBBC5IUEUiGEEEIIIYQQeSGBVAghhBBCCCFEXkggFUIIIYQQQgiRFxJIhRBCCCGEEELkhQRSIYQQQgghhBB5IYFUCCGEEEIIIUReSCAVQgghhBBCCJEXEkiFEEIIIYQQQuSFBFIhhBBCCCGEEHkhgVQIIYQQQgghRF5IIBVCCCGEEEIIkRcSSIUQaVNK/UoppU2Xh/LdJtE1KaV6KaXcptdSUCk1JN/tEkIIIURuSSAVQqRFKVUCXBvz65vz0RbRLXwSKDT9bAVuzFNbhBBCCJEnEkiFEOmKDRAAk5RSc/LRGNHlxfsyQ77gEEIIIXoYCaRCiHSZw4Inwe+FSEkpdR5wYfOPYcDffH2yUqoiP60SQgghRD5IIBVCpKSUGg1c1PyjBr5suvlTSilH7lslurCbANV8/V3gTdNt8gWHEEII0YNIIBVCpMMcIJYCzwKnmn/uC3w4H40SXY9SShF5PRlea74Y5AsOIYQQogeRQCqESCpegNBaB4HFpt/F7dVSStmVUqdNlVTPz+C4b5vu95UU285RSv1YKbVBKXVKKeVXSh1XSi1VSj2klCpN43j7Tccb1fy7MUqp7yml1jfvN6yU2hDnvpOUUl9SSr2ulNqhlGpQSgWa71PZ3LbJ6T52034vV0otVkodVEp5lVLHlFLLlVL3KKWKmrd5zNTux9Lc7/uVUv+jlNqqlDqjlPIppY4qpf6ulFqolCrItK0ZuAgY3Xy9Cfgj8Fegpvl3/YAPZbpTpdQgpdSDSql3mp8vT/PloFLqb823jUpjP1al1HVKqVeb/5Znm/+WNUqp1UqpRc3Pn4pz34z+FkqpS03b/zuTbZRSVymlfquU2qWUamy+/f6Y+9qVUv+llPqRUurd5r+xt/l5Odz8vNyvlCpO1dY47cr4+W5+7ozH8ssMjnWr6X7rMm2rEEKITk5rLRe5yEUuCS9EAoRuvniAXs2/n2P6vR8YkOD+Pzdt97M0jzkYCDbfJwQMTbBdKfAH0/4TXc4Cn0hxzP2m7UcBn29+vLH72hBzv9+ncXxNZK7kjwFrGo/fQaTXMNn+tgETgMdMv3ssxX6HExkim6qtR4CLOuj19ILpOL82/f5npt//OYP9WYBvAO40HlcImJzitb4jzb/n43Hun/bfonn7S03b/zudbYDewOsJ2nR/zN/6dJqP5TRweUc/38BU0211QGGax3zPdL+7O+J1KRe5yEUucsnfxYYQQiRn7v18Q2tdD6C1XquUqgYmAnbg08CiOPf/FfCF5uvXKaXu05Ee1mRuILIMCMC7WusjsRsopcqAfwGTTL/eCmwEGoGBRAJGP6AP8Hul1Ge11r9OcWyIVBT+UfP1o8AKIifQQ4gMUTYb0fxvkEhI3AXUEjkZH0gkuA8lMuT5fsAJ3J3i+L8FPmb6+QyRMHKGSNC4hMjjXgL8JY3Hg1JqEvBPImEfIif365rb7Glu48VASfPjfEcpdaXW+t109p9mGwqJPLcG81DdVzn3vFyllOqvtT6dYn9W4H9puR
yRH1hJ5AuGAFAGlBN53BYiYT/evm5oboPd9OudwHoif/tewJTmiwVwJWtbB1FE/j99mMjfr5LI309xLuwZioi89iHyhcxW4ACR/xsOIr3U84k8jn7AX5VSl2it/5Pw4O18vrXWW5RSK4HziTyfnyDynCd+wEpNAC5o/tEDpPP/VwghRFeS70QsF7nIpfNegAIiJ+NG78SHYm5/xHTbuiT72WPa7sNpHLfKtP0tcW63EAmjxjargVlxtnMB3yTSO6mJnIyPTnDM/ab9BQAfcAegYrZzxvz8AyIhq1eC/SrgauCkaf8XJnnst9Gyl+nJOMccSGSYqwa8pm0fS7DPIiLBxdjur8CYONv1omWP9lGgdxZfT58x7fsYMb3FRAKgcfsX09jf4zHP1U+Afgm2nQu8AkyJc9ssWvaGrwPmJdhPGZGiXg/Gue2xVH+LmO0vNW3/7zS2CTT/uwmYFmdbp+n6SOCZ5sdtSbDvXs2vL2P/OxJtm63nG7jFdP+laTxHPzJt/2q2XotykYtc5CKXznPJewPkIhe5dN4LkV5P42TwJGCLuX0k58KejneS3Lzdt0zb/DbFMSeatm0CSuJs81nTNiuBghT7NAeFXyTYZn/MyfZnsvxczjPt+3cJtrECh03b/U+S/TmANTFtfizBtl83bfN6stDRvP3Lpu0fyuJz8LZpv0+naGdVin2NJ9ILbWz/1Xa0yzwkdC1Q3Mb9mF9ncf8WMdtfatr+32lsYwT5/ll+bf7CtP8rO/L5JrKWca1pP+OSbGsDjpu2vTibj1sucpGLXOTSOS5S1EgIkczNpuu/1TFDbbXWB4BlCbY3+5Xp+keUUiVJjnmj6fobWuuGONv8t+n6XVprT5xtzB4nchIMkSquqd771uj0hvamTWu9Gtje/OP7E2z2QSJDZyEyR++rSfbnp+XyO3EppezAwuYffUSer3CKuxk93xDp1Ww3pdRQWj7u1+Js9ivTcWcrpaYm2eWXOFeYbxXwwza2ax7nhoRq4GatdWNb9pUj39YphjK3wUum6x9IsE1Wnm+tdRMth91+LsnmHwYGNV/fqbVelmRbIYQQXZTMIRVCxNUcIMwnp/ECBETmgF3SfP0zSqmHtNYh8wZa611KqTVEhvEVEpmDlmju2KdN138Ve6NSajAws/nHbVrrjckeR/Pxvc1z164kUhRmKpFhj4ksTnJbQkqp8UAFMKb5OE7OLZdD8+8A+imlhmutD8Xs4lLT9SVa69pkx9NaL1NKHeTcPNZ4KogM8QX4p9b6ZNIHEdnv0eb5wZOAqUqp3lrrulT3S+GznAs0W7XW6+Mcd59S6j3OrXl7M5CowvIHTdd/qrXWCbZLxbyff2qtt7VxP7nyu0zv0PylxDxgBpEhxyW0/Pw3f0E0M8FusvV8Q2TZKGO+8M1Kqa/Fvmc0u810/YV2HE8IIUQnJoFUCJHIjZwLENVa68oE2/2BSIVUF5GT3f8iMkcx1q+IBFJj360CqVLqAs4tCXIK+Huc/ZiXjilQSv00yWMwG2O6PpzkgbQqzX0CoJT6EPAdInMR09UfiA2kM03XV6e5nzUkD6Tm52tYBs9Xn+Z/FTCMyFzi9jD3nif6csO4zQikn1FKfTU2rCilBhGphGx4tx3tmp+l/eTCPq31mXQ3bl6+5xHgLiKvt3S02i7Lzzda642mL6gGA1cBb8YccwiRL5AgMnf2lfYcUwghROclgVQIkUhaAUJrXa+UegO43nS/eIF0MfA0kfedy5RSZVrr4zHbmIeHLo4dItxsiOn6aOCeRG1LItW6pKfS3VHzepPfbEMb4g1bHmC6HhtWEzmc4nbz8zW9+ZKplOu4JqOUmktkbjBE5hwnGw79v0SK5TiJhJUrgL/FbDPIdN2ntT7ajuaZ97W3HfvJhUxel6VECn/NzPAY8V6X2Xy+Dc9y7guq24gJpETeR4xK2/+ntT6RhWMKIYTohGQOqRCiFaXUHM4tp6JJvdSCObB+RCnVJ3YDrbW5x9
MKfCrmmHbgugT7NOud4PeZSPVlXKo5qQAopS6nZRhdSWT90llEeppcWmtlXIClpm3jvf8Wm643pdMGIpWDk8nF85W
2021-03-19 17:21:00 +00:00
"text/plain": [
2021-04-30 20:51:04 +01:00
"<Figure size 1200x800 with 2 Axes>"
2021-03-19 17:21:00 +00:00
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
2021-03-21 09:56:27 +00:00
"# 3D surface of the mean test accuracy over the (epochs, hidden nodes) sweep\n",
"X, Y = np.meshgrid(multi_param_epochs, multi_param_nodes)\n",
"\n",
"fig = plt.figure(dpi=fig_dpi)\n",
"ax = fig.add_subplot(projection='3d')\n",
"\n",
"surf = ax.plot_surface(X, Y, mean_param_accuracy[0], cmap='coolwarm')\n",
"ax.set(title='Average Accuracy', xlabel='Epochs', ylabel='Hidden Nodes', zlabel='Accuracy')\n",
"ax.view_init(elev=30, azim=-110)\n",
"fig.colorbar(surf, shrink=0.3, aspect=6)\n",
"\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
2023-05-27 23:29:17 +01:00
"attachments": {},
2021-03-19 17:21:00 +00:00
"cell_type": "markdown",
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp1"
]
},
2021-03-19 17:21:00 +00:00
"source": [
2021-03-22 20:49:29 +00:00
"### Test Error Rate Curves"
2021-03-19 17:21:00 +00:00
]
},
{
"cell_type": "code",
2021-05-04 15:24:37 +01:00
"execution_count": 270,
2021-03-19 17:21:00 +00:00
"metadata": {
"executionInfo": {
"elapsed": 2653349,
"status": "aborted",
"timestamp": 1615994110347,
"user": {
"displayName": "Andy Pack",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjA4K4ZhdArHXAFbAGr4n0aCv2HmyUpx4cy6zcUq34=s64",
"userId": "16615063155528027547"
},
"user_tz": 0
},
2021-03-26 20:01:05 +00:00
"id": "Jrn3hKQAlGcc",
"tags": [
"exp1"
]
2021-03-19 17:21:00 +00:00
},
"outputs": [
{
"data": {
2023-05-27 23:29:17 +01:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA9QAAAMMCAYAAACyue/GAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAB7CAAAewgFu0HU+AAEAAElEQVR4nOzdebzN1f7H8dfijIZjnp2iQbkpRKGEMnRFbiRdmpQp1W2iQW7SZCrK1TVcEn4VETKlRGYyJykihMxDOM7hOOes3x/ffbZ9hj2cfUbO+/l47Mf57u93fdda+7u/e5/9+a71XctYaxERERERERGRjCmQ2xUQERERERERuRQpoBYREREREREJggJqERERERERkSAooBYREREREREJggJqERERERERkSAooBYREREREREJggJqERERERERkSAooBYREREREREJggJqERERERERkSAooBYREREREREJggJqERERERERkSAooBYREREREREJggJqERERERERkSAooBYREREREREJggJqERERERERkSAooBYREREREREJggJqERERERERkSAooBYREREREREJggJqERERERERkSAooBYRERHJAcaY/sYY63r0z6I8m3jkuSQr8nTlu8cj3ypZlOcEjzw7Z0Wekjuy41wWuVQpoBbJh4wxVTz+EWbVo39uvy4RERERkZykgFpELivZ1Voj+ZvnxaPcrouIiIjkHSG5XQERyRWngf/6SXMrcItr+QAw00/6tZmtlIiIiIjIpUQBtUg+ZK09ATzjK42rC3dyQL3DWuszvYiI5Dxr7RLA5HY9RETyK3X5FhEREREREQmCAmoRERERERGRICigFpEsYYypbowZYIxZa4w5bIyJN8YcNcasMca8ZYypGGA+RYwxTxpj5hlj9hpjYo0xF4wxp4wx24wxc4wxrxljaqTar79rwKjFHqsbexmRfE8WveZQY8wjxpipxphdxpgzxpizxpjdxpjJxpi2xhifXTG9DaJmjLnHlccOY0yMa/vzrm1V0nstxpiGxphxruN0yrX9Qy/l3m2MGW+M+c0Yc9oYE2eM+cMYM9MY09kYExrA608zBY4xprgx5jljzDJjzJ/GmATX9uL+j2jmj41HmlDXaxxijFlsjDlgjDnnep37jTHzjTHPG2OKBFJ+qvXeRrqv4iOvaGPM68aY5a66nDfGnDDGbDLGvG+MqRbgMQk1xjxsjJnhOudiXMf4jDFmpzHmW9fn7dZA8guUMaa0MeZVY8xSY8xBV/2Puer/njHmb37qfMzjODXIQLkLPPZ7yU/aW4wxHxhjfjTOd0+8MeaQq86vGGNKBFBemqmijDFXG2Pedb3Wo8aYJGPMj4G+hgDKDDHGPGqMWej6zJx3HeOvjDGtA9g/QwMxGmPCjTH/cp2LR12fid+NMZ8bY+7MxOu4zxgzy+M17DfGfGec78igbjE0jrbGmInG+a465foc73Mdn8f85W28f1/WNc735W/G+T9z0jj/v14zxhQOpr5eyu/sUf4Ej/VtjfP/bK/reB1xne8PG+P7/0aq/I0x5gHjfCf+7vpOiPF4T9tnJD9Xnne69v3DdbwPus6Xp4wxhTKSV6p8mxpjRhtjthrn+++8cb4PvzXGPGOMiQwwn2hjzBvG+T9z2JVPvDHmuDFms6vuPY0x5YOtq0iGWGv10EMPPdI8gP6AdT2W+EgXDowGEjzSp/eIBZ7xU2YDYL+ffDwfIV7q6++xJwuOTxNgZwBlrQYq+cnHfZyBYsAML3k979qniudrAcJc70F6+3yYqryywMIA6v0bUNfPMZjgkb4zcDuw10t+xYM8xhk6Nq79ooFjAZ4Lx4DmAZQfyKNKOnkUAN4C4vzsewF4FzA+jkc14JcM1OeaLPoueAL4y09ZCcAHQEEveYz0SPvfAMutwMXvlUS8fI6AEsCXARyPk0B7P2Xu8Xw/ge5e3rsfgzyW/T3y6A9UAlb6qfd4oECgnx
M/5VcHtvkpbxQQmvpY+MizCDDPT57LgfKk+s7wU9ebgE0BvK/bgL/5yKeKR9o9OPebv+k6p7zluQu4Kos+P5098p2A8z02y89rmg9EBpD3tcDGAI7R+kBeD87YSh/7yWsrcB2pzmU/+UbjXOz2V88/gTv85NUd5/eEv7wssCIr3kM99PD30KBkIhI011X8b3ECqWS/AxtwfryWdG2rCEQCI4wxUdbaAenkFe3Kq6hr1QVgHU7QGgsUxvlhVBOISqc6a3FGLq8E3Oda52108uMBvsR0GWMeAD7D+dEJzg/uH3B+rCXhBD4NcH6c1AdWG2NusdYe9pc18CnQmos/gn5xra/hWpeeD4AeruUtwGac41fNVZ/kepfD+fF+tce+vwNrgPPA34B6rvXXAouNMX+31q70U2+Aa4APcX4sngGW4Rz/EkCjAPb3JyPHpjBQyrV8EucH4B9ADM7Fh6o470uEK93XxpjG1tpVqcr8k4uj4T/tsd7bCPmnU1TYmILAF8D9qfJcCxzFCUTq4bwfIcBrQBmcH4wpX7wxRXEuhES7ViXhBBu/ul5XIZxzvyZQ2kv9MswY0xt4z2PVeWApzoWTEsCdOJ/zgsDzwBXGmPbW2tTn6qdAT9dyB2PMc9baBD/F/9OVL8Bia+2f6dSvPPA9TqCYbCvOZyAG5wLSHTjvc3FgqjHmEWvtZ37KBngAGOJaPoDz2TmF831WMoD9/SkCfINz/sbiBJ37cL4D73TVHeBxYDswODOFGWOuBBbhXKhItpWLAdnNrro86apPIHmG4gTTnp/xQzif/zM43wsNXY+ZOIFqIPk2AuZw8bs++f/BDtdyFVeeETjB3SpjTANr7a8BZP8G0M+1/CPOd+YFoBbOMQDnO+IrY8zNAZynGRECTAeaAvHAKpzv4Aic8/QKV7q/A8O4+JlJwxhTHeezWMZj9Rac12SB2sCNrvV1cI5RI2vtbz7qNwno6PH8L5xA+Lirbk1w/k98Dcz2kU/qenqedxbnnPsF539nJZzzpyjOZ+s7Y0xLa+3idPK6Dxjjseo0zkXr/TgX34rh/N+rgfNdL5Izcjui10MPPfLmgwBaqIGJHmm2A03SSVMQ50fBOS62ZDVIJ90HHnktAyp6KTMEaIzzAz1NaxgZaK0J8rjcwMWr40k4wUbxdNJdhfMDObkuX3vJz7O+F1x/fwJuTCdtuOtvFY99klvw9pLOlf3kfVzLX3vsFwP8M530dXF+4FmPfNO8PlfaCenU/SOgSKp0ofhoYfNxrDN8bFzLVwL/wZn6Ld1ycX6ov5/q/PXVCuhu9chA/d/y2O8g0I50WqBxAre/PNJ2SCfNcx7btwLXeSnT4IzOPxKIzuS5fhspe558DZRLfdxxgk7PVqEXveTneV61DqD8DR7pO6ezvQBOMJ2cZg1QO510EThBVJLHuV/VS5l7Up1z54Fuqd83z/Mtg8e0v0f+yd+JE4CSqdIVAj73SHsGKBzA52SJj7I9e6b8ld57ALQETrjSxHukr+Ilz9c90iThXBQqmCpNNS4Geed9vaeu9OWBwx7pJgIV0klXjpQ9Vn5KXbYrXRWPNOdd9dwJ3Orls+j5uh/NzGfIlWfndN7zr0nV4wLnf9t7qY6nt+Me5nFMret4NUsnXQuci3fJ6TYAoV7yfMQjnQVGkKqVHCcoXpTOe9nfS56FSdmr5mvg6nTSRZGyF8sBoFg66Talql8hL+UWcb2XgzL7/umhRyCPXK+AHnrokTcf+Amoca6mJ2/fCZT2k5/nj4r56Wxf77E96K6qZH9Avcgj/xf8pC2ME/wkp6/np77JgZe/Y1kl1T5ngWp+9rkz1T6t/OT/l0fafl7STUiV59gsPtYZPjZBlDHKI/+WPtK56xFgvlW4GIweT+9HpI/35xfSBnCeXZrT/HDOjgdO61dymSuBMB9ph3ukPQUUTSfNmx5pJvsp+3qPtLFe8vMMAFbjp4ssKb/TRnlJsyfVOfdQFh/T/qny/9xH2g
hS3kLxoJd0np+TJV7SNPdIkwTc6aPcO7h48SH5USWddMVwvnuS07zhI88yOEGSZ56dvaT17HI83M/xLEjK7+Q0x4i
2021-03-19 17:21:00 +00:00
"text/plain": [
2021-05-03 17:48:15 +01:00
"<Figure size 1000x800 with 1 Axes>"
2021-03-19 17:21:00 +00:00
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
2021-05-03 17:48:15 +01:00
"# mean test error rate (1 - accuracy) against epochs, one curve per hidden-node count\n",
"fig, ax = plt.subplots(figsize=(5, 4), dpi=fig_dpi)\n",
"\n",
"test_mean_acc = mean_param_accuracy[0, :, :]  # first index 0 = test split\n",
"for idx, node_acc in enumerate(test_mean_acc):\n",
"    # alternative: ax.errorbar(..., yerr=std_param_accuracy[0, idx], capsize=4) to show the spread\n",
"    ax.plot(multi_param_epochs, 1 - node_acc, '-', lw=2, label=f'{multi_param_nodes[idx]} Nodes')\n",
"\n",
"ax.legend()\n",
"ax.grid()\n",
"ax.set_title('Test error rates over hidden nodes')\n",
"ax.set_xlabel('Epochs')\n",
"ax.set_ylabel('Error Rate')\n",
"ax.set_ylim(0, 0.6)\n",
"\n",
"fig.tight_layout()\n",
"# fig.savefig(f'graphs/{exp1_testname}-error-rate-curves.png')\n",
"plt.show()"
]
},
2021-05-03 17:48:15 +01:00
{
"cell_type": "code",
2021-05-04 15:24:37 +01:00
"execution_count": 271,
2021-05-03 17:48:15 +01:00
"metadata": {},
"outputs": [
{
"data": {
2023-05-27 23:29:17 +01:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA9QAAAMMCAYAAACyue/GAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAB7CAAAewgFu0HU+AAEAAElEQVR4nOzdd3gU1f7H8feBQELvnQgoVUXpRRFQiiJFQMULFlAERbyKYPcnFxuISrso6EURvXhBEJAmiiBdkCqgiIKIgBTpLbSQ8/tjdpdJsrvZbLIhkM/refbZ2ZkzZ87Ozs7ud86Zc4y1FhERERERERFJnWwXuwAiIiIiIiIilyIF1CIiIiIiIiJhUEAtIiIiIiIiEgYF1CIiIiIiIiJhUEAtIiIiIiIiEgYF1CIiIiIiIiJhUEAtIiIiIiIiEgYF1CIiIiIiIiJhUEAtIiIiIiIiEgYF1CIiIiIiIiJhUEAtIiIiIiIiEgYF1CIiIiIiIiJhUEAtIiIiIiIiEgYF1CIiIiIiIiJhUEAtIiIiIiIiEgYF1CIiIiIiIiJhUEAtIiIiIiIiEgYF1CIiIiIiIiJhUEAtIiIiIiIiEgYF1CIiIpmYMaabMcZ6HuMudnmCMcYsdJW16cUuj1x6InEMGWMGuPIckE55lnfluT098vTku92Vb/n0ylcyns6HWYcCapEAkvxYptdjwMV+XyIiIiIikj4UUItkQcaYpq4gf+HFLo9cHtwXjy52WTJKJGq+RERE5NIRdbELIJKJHQPeSyFNPaCuZ3o3MC2F9CvTWigREREREckcFFCLBGCtPQQ8HiyNp0bKG1BvsdYGTS8iIiIZy1o7ABhwkYshIpcpNfkWERERERERCYMCahEREREREZEwKKAWyUDGmGrGmIHGmJXGmH3GmLPGmP3GmB+MMa8aY0qHmE9eY8yjxpjZxpgdxpg4Y8w5Y8xRY8xmY8xMY8yLxphrk6w3wNNh1ALX7CYBeiTfnk7vOYcx5n5jzCRjzDZjzHFjzEljzB/GmAnGmA7GGJNCHn47UTPG3O7JY4sx5oRneR/PMr9DmhhjGhljPvTsp6Oe5cMDbPdWY8xYY8xvxphjxphTxpg/jTHTjDOUUY4Q3v84Vzm6eeYVNMY8aYxZbIz5yxgT71leMOU9mvZ940qTw/Me3zLGLDDG7DbGnPa8z13GmDnGmD7GmLyhbD/J/EA93ZcPklesMeZlY8wST1nOGGMOGWPWGWPeMcZUTu3+SYlnm//yfBb7PNs8a4w5aIxZb4z5nzGmlzGmZJL1Fnre879cs/8V4D2PC7L9WsaYMZ7vxinjnA9WGmOeNcYUTu/3Gy5jTDZjTFdjzLfGmL2e42S7MWa6MaZ9GvIN+/xgjClunPOeNcacN8aUScV2f3V9PneHW/4keV5jjHnbc7we8BxLuz3HynPGmCJB1q3lKs9RY0xMiNuMMRfOY9YYUzdIWuPZn58Y55x21PM57jTGfOn5fIPeCmjS4bwaDmNMYc8+XOXZt6c8x8tHJsnvXID1U9V5oDGmlDHmDWPMBuOc+48ZY342xgwzxlQJ8z1EG2P+aZzz237Pe/jdc465OZw8Pfnm8ZyjZhrn9ynO8z3aYpzfr1tCyMPv0Hye42Wmcf5nnDHG/G2MmWuMuS/Q9zLM9+DvdzK3MeYxY8xSc+HcvNNzXrgxlfnnNcY8YYz5xji/baeNMYeNMT8ZY941xtRPZX6Z7nzoJ6+6nve21vNe4z3H3B5jzApjzGhjTCdjTJ5wyysu1lo99NAjzAfOPVnW81gYJF008D4Q70rv7xEHPJ7CNhsCu1LIx/2IClDelB7b02H/NAW2hrCt5UCZFPLx7WegADA1QF59POuUd78XIKfnM/C3zvAk2ysOzAuh3L8BdVLYB+Nc6bsBNwI7AuRXMMx9nKp941kvFjgQ4rFwAGgRwvZDeZT3k0c24FXgVArrngPeAEw6fX974nznQin30iTrLkzFex4XYP
uvE/ycsBNo4DluguYV4fNcSWBFCu9xKpAvyX5pGunzA/CVK90zIb6feq51jgAxadw/UcC/U/gsLXAY6Bokn02utHeHuO1OrnU2B0l3HbAuhH29Gbg6SD7lXWm3k4rzair2Z6JjCOecGew3Lx7okUKeA1zpB6SQtoPnswq0vdPAw0n3RQp5VvPs22D7fjSQw7NfA54vk+R7N7AnhM91JlAgSD7dXGnH4fyOTE8hzzlArrSeXzzbH+fKtxtwNYm/D/4er4SYd5sQ99FnQO4Q8su050PX+eiDEPLxPl5Pj88wqz/UKZlIhHmu/n2D86fA63dgDc6PdmHPstJALmCkMSa/tXagn7xiPXnl88w6B6zCOQnHAXlwfuSvB/L7Kc5KnJ7LywDtPfMC9U5+MMS36Jdxan0+w/mDAE6wtALnz0ICUBnn4kAUTtCw3BhT11q7L6WsgfE4P5IWWI3zw2uAaz3z/BkGPOKZ3gisx9l/lT3l8Za7BLAMuMq17u/AD8AZnB9679XsSsACY8xt1tplKZQboCIwHOfPynFgMc7+LwQ0DmH9lKRm3+QBvDVmh4GfgT+BEzh/kivgfC4xnnRfGWOaWGu/T7LNv7jQG35v1/xAPeQfS1RgY7IDnwN3JslzJbAfyIuzv6/COVZeBIrhBMNh89QifJCkXMtx/rjH43xGlXH2W04/WUwDfiJxT/+r8N+T/wo/2x8IvOCaFQd8h/PHryRwC1AWJ2AcHtKbigDjtJr4Dicg8PoDZ1+dAa7B2QcdcH2PQsg3vc4P44FWnul7gbdD2Py9rukvrLWnQy13UsaYbMAUoJ1r9iGcP9KHcC5c3YxzDBUExhljClprR/jJbjzOBSNvGSeHUAT3exkfoIyNcQIq72+C93dji2e6PNAI57teBfjeGNPQWvtLCNsP6byaBtcCg3DOA38DS3B+m8rgfEdyAdmB940xG621yb5rqWGMaQ1M4kKnvQk4vwe/ecrQGCgFjAGeCDHPcsB8z3pePwNrcc7JtXDe56M454FQy/oUMATn/A6Jz2HZcb6bdTzL2wALjTE3WmtT2kYUzjHdDDgLfI/zGxgD3ARc4Ul3GzAU6BVqmUNUGueCdimcC15LgL1AUZzPvIAnXX9jzCZr7eeBMjLG3INznsnumXUeWIrznykvzvvxtgzsAlQwxtwS6JxwCZwPwTkHun8f3b+n2XB+z6/G+a5LernYEb0eelzKD0KooQY+caX5FT9XKXFO9r1wrnx7r7g39JNumCuvxUDpANuMAprg/MHK7md505TKncb9cg0Xav4ScE7wBf2kuxLnx9Jblq8C5Ocu7znP8wagup+00Z7n8q51vDVHO4CbAq3jmXbXeJ0A/uEnfR2cPxjWlW+y9+dJO85P2d8F8iZJlwPIFsa+TvW+8UyXw6lVqxdouzh/wN9JcvwGLKMrnU1F+V91rbcH6IifGmicmpgjrrSd0niMrnPlNZIANRM4f7ruBt4MsHyAK58BIW67sed74V1vMlAoSZoCwATP8jOutOPS8r7D2E8fubZ9BnjIT5p6XKhVc5e1aYA80+38AOTGuTjlTXNNCu8nO7AvpTKmYv886z7ucYK/nEnSlMS5EOr+ntb3k1c513FxBiicwrYL4wQ83v1YwU+akkne7ydAKT/pSpC4ZcsG/P92lHelCfm8msp9utC1jdOe7fTF1drKky4WJ4j3pv0uSJ4pfk9xAg33vtoAVEuSJpvnM09IcqxvD7Jtd2unI0AbP2la4VyAsa7P1BKghhon2D3vOlaew885DKiBE7x78xsVIL9uSfa5xfktLJMkXRTO99WbNiFQGVP5mY/zs/03k74nzzE/35X2dwK0WMK5COs+N/wAVPTzefZ17UsL/DtIOTP7+bAIF/4HxANdg+yfUsA/ge5p/fz0sAqo9dAjLQ9SCKhxrn56l28FiqaQn/tHbY6f5atdyyumodxNg5U7HfaL+wfvqRTS5knyg+/vj6a7vBYn8EppX5ZPss
5JoHIK69ycZJ3WKeR/xJW2f4B045LkOSad93Wq900Y2xjtyr9VkHS+coSYb3ku/Ck/CFyVis9nU6A/CiFsN68rnx3
2021-05-03 17:48:15 +01:00
"text/plain": [
"<Figure size 1000x800 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# std. dev. of the test error rate against epochs, one curve per hidden-node count\n",
"fig, ax = plt.subplots(figsize=(5, 4), dpi=fig_dpi)\n",
"\n",
"test_std_acc = std_param_accuracy[0, :, :]  # first index 0 = test split\n",
"for idx, node_std in enumerate(test_std_acc):\n",
"    ax.plot(multi_param_epochs, node_std, 'x-', lw=2, label=f'{multi_param_nodes[idx]} Nodes')\n",
"\n",
"ax.legend()\n",
"ax.grid()\n",
"ax.set_title('Test error rate std. dev over hidden nodes')\n",
"ax.set_xlabel('Epochs')\n",
"ax.set_ylabel('Standard Deviation')\n",
"ax.set_ylim(0, 0.1)\n",
"\n",
"fig.tight_layout()\n",
"# fig.savefig(f'graphs/{exp1_testname}-error-rate-std.png')\n",
"plt.show()"
]
},
2021-03-22 20:49:29 +00:00
{
2023-05-27 23:29:17 +01:00
"attachments": {},
2021-03-22 20:49:29 +00:00
"cell_type": "markdown",
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp1"
]
},
2021-03-22 20:49:29 +00:00
"source": [
"### Test/Train Error Over Nodes"
]
},
{
"cell_type": "code",
2021-05-04 15:24:37 +01:00
"execution_count": 272,
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp1"
]
},
2021-03-22 20:49:29 +00:00
"outputs": [
{
"data": {
2023-05-27 23:29:17 +01:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAABJkAAAScCAYAAAAoB16YAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAB7CAAAewgFu0HU+AAEAAElEQVR4nOzdd3gU5drH8e+dQgiEjggCggWP2AsqYCgKir1jPWLvXVQ4NrCDHo+9o2Jvr4hdUVEgKAoodtCAIIqg9BZKkuf9YzbJJGySTXY3s7v5fa5rrszOPPPMvTs7u3fumZ0x5xwiIiIiIiIiIiLRSAs6ABERERERERERSX4qMomIiIiIiIiISNRUZBIRERERERERkaipyCQiIiIiIiIiIlFTkUlERERERERERKKmIpOIiIiIiIiIiERNRSYREREREREREYmaikwiIiIiIiIiIhI1FZlERERERERERCRqKjKJiIiIiIiIiEjUVGQSEREREREREZGoqcgkIiIiIiIiIiJRU5FJRERERERERESipiKTiIiIiIiIiIhETUUmERERERERERGJmopMIiIiIiIiIiISNRWZREREREREREQkaioyiYiIiIiIiIhI1FRkEhERERERERGRqKnIJCIiEgAzcyVD0LGIiIiIJAozm+vLkzoHHY/UjIpMUsrMPvP/01PDYXTQ8ScbMxtdzWu6xswWmNl4M7vVzLYLOub6xMz61nJf+Czo2GPFzLYwsyND778PzGxJhefaOU7rHV5hPUVmtlOEy/q328J4xCci9Y9ypLqlHCmx1fccycwam9lhZnaPmU0ws7/MbH3ofTnPzMaa2Xlm1jgO6z49zOt6SITLdq6wXMNYxycCKjKJJLJGQDtgP+A6YKaZPRTEF4LpaEK9Y2YLgD+BsXjvvwFAy4DCSQNuDmjdIiKSeJQjSSDM7Hngb+Bt4HKgN9AWaID3vtwSOBJ4FJhrZsfUQVi3mJnVwXpEIpIRdACSsKYCX9Wg/ZR4BVJPzAQ+qTAtB+gK7AVYaLgQaGdmxzrn9BObuvVQhO1+jWsUdadd0AFUcLSZdXPOTQs6EBGp95Qj1S3lSImvPuVIxwFZvsfL8D4PFuC9D7sCe4fGWwOvm9kFzrlH4xjTHsAxwOtxXIdIxFRkksq855wbHnQQ9ciXzrmLw80wsx2Al4BdQpOORl8kda6y7ZPi1gEz8P6hmgosAj4MMJ5bgYMCXL+ICChHqmvKkRJcPcyR1uK9754Cpjjniv0zzWxH4Hlgt9CkB81ssnPu+zjGdLOZvVExFpEg6OdyIgnOOfcTcDCwxjf5vIDCkfpjD6CJc66Hc+5S59xzwC8BxDEdKAyNDzCz3ABiEBGRBKQcSQLwALC1c+5s59zn4Yo6zrkfgf2BeaFJ6cC1cYhlFrA8NL4DcHIc1iFSYyoyiSQB59wC4FXfpFz99lriyTn3jXOusPqWcZcPjPY9vi2gOEREJAEpR5K65Jy72jm3KIJ2y4CRvkkRXZy7hpYD//U9Hm5m+qWSBE5FJomrcBdDNLNtzOw2M/vGzP4xs2Izm+Fbxn8Hl76hae3M7Foz+8rMFobuNrW8knV2MrObzWyKmS0ysw2hv1PM7CYz6xhB3P67Znzmm36Imb1kZr+a2erQ/MujeIlqYoZvPBtoUVlDM0szs16h12Gcmf1uZmtDd774K3Q3luvMrHUVfZTegQLo5Jv1WyV3DOlbRV+tzGywmX1kZvPNbJ2ZLTezn0IX6uwWyQtgnqPM7EUzm2VmK0PvhTWh99p4MxtpZvuZWcJ9vplZjpldamYfmtkfoddhmZn9YGYPmtk+EfZT+rr7pu1qZveF+loamj82bk+mbt0CbAiN9zazA2PZeSw+Myr018zM/mNmU0Pbd3Xo/fqEme0ZRZzaj0RSiClHiqUZvnHlSEn42W6pmSNN9o03NbN43EDlPuCf0Pg2wJmx7NzMWpvZUCt/F73Foc+ou8z7yWpN+ssys0
vMbFLoM67AzGaH3rf7RRFnYzO7wMzeNu8Of2vNbFXo8+gpM9u/Bn3tb2ZPmtn3oX2xMNTfH6G47zXv7oINahtvynPOadCAcw7gM8CFhuEx6nOur8/OwLlAgW9ayTCjkjj64t2hYWmYZZaHWd91lfTvHwqAIdXE3dfX/jOgGTCmkv4ur+VrM9rXx+gI2p9TYb1bVNIuE/ijmtegZFgN/LuSfjpH2Efptqqkn4vwjrRUtWwx8CTQoIrnvznweQ3i6R/le9f/HnAx2BcOA/6KIO4XgEbV9FUuLmA43k/KKvY1Nhb7cRXvic6x7N+3nuG+dbwcmna/b9qXEW63hRGsKyafGb7+cvHuyldZX0XAjeG2YzX9JuV+pEFDqgwoR/L34/+c/QzlSJEMfSvpJyk/21GOFOnz2rnCOjeLQZ+n+/qbEpp2pW/a70BWhO/bhtWs68wI3p+FwD1AegSxd8W7kH9V/T2Ct5/O9U3rXE2/AyN8/7wNNKuin8bAmzXYj86O93soWQedTid1aSBwZ2h8AV51fwWwBZXfGr0n3pdDJrAEmAgsBtoAu/sbmtmDeF/WJVYDnwIL8W4tuh/e3UgaAiPMrK1z7ooI4ja8i/cdhveBMg34KTR9p9C0urCFb7wI7/UIJx1oHxpfDfwIzAFW4r2OHYDuQFO8D9PnzGyjc+6VCv2spOxuIYOAJqHxZ4FVYdb7Z8UJZnYvcJlv0mLgC7xt0hBvG+6E91qeCWxhZoe6TS+gmA68C/jPAvkhNCwP9dUW2JXEuysaZnYCXmKUHppUBOTh/RQsB+hF2fY9GdjKzPZ3zq2LoO+rgWGhh7Px7nCyFi+R2Bijp5AIbgfOwrs98N5mdqRz7s1oOoz1Z4Z5Zym9H1qmxDTge7xbG3fHO8p4k5ktq0Gc96L9SCTVKUeKjnKkJP1sT/EcaWffeAHeNo6Hh/EKTe2BjsD5eGc41ZqZXQXc5Zu0HpiAV8RqgfeZ0RJvu10ObGlmx7lQtSZMf53w7hLpfw/+CHyN9zmxB977/Xy8bRRpnFcAd+PtJ+Dtm1/gFZPTgR2BbqH5hwGfmdm+zrlw63geOML3OB/4Bq+QnwlshrdNO0caX70VdJVLQ+IMxP8o3Ua8D6hzAKvQLquSODbiHcG5HsisYpnjKV9ZfhpoWqF9U+C5Cu2OqSTuvhVicMB3wM5h2oY9WhDBazPat47REbT3H6GaWkW7Bnh3u+hb8TXzxwxc7Xtuy4CcCLdj5wif35m+ZVYAZ4eLB+9Lyn9U8ZowbY70zV8A7FPFencERgB7R/ne9b8HXBT9bIOXcJb09SWwbYU2aXjJQZGv3f1V9Ol/D2/ESyKPitV7s4r1dq6w7ojeC7VYz3DfOl72TR/pm/4dFT5Hwmy3Ss9kIvafGQ3w/rEqafc70CNMu0F4d+1bH8n7K9n3Iw0aUmVAOZK/bd8KMShHUo5U235SJkeqJJZxvljeiVGfp/v6nOKbfoFv+kLCnPFFhGcy4RWw/Wd/vQdsHmY/ubNCf1dWEffHvnbLgcPCtDmYsjMzN1S3XwH9fO+L9cCQSp73bngFrZL+Hg7TZlff/FXAwVU8l63xzgw9PN7voWQdAg9AQ+IMlE9cvgIerMHQspI+51b48DmlhnE44Lpq2qfhHYUqaf8qYf75DLU1YKyvbT6QFqZd3wox/AW0jvHrPdrX/+hq2v67QjznxCiGIb4+L6iinX87do6g3yZ4SVnJh36lCU+ofVfKTuFfXPELAu+ihiXrr5NTU8O8ByLZDzbZLsAzvj5+perTdK/wtS0CtqqknavQrncdvSadK6y72vdCLdcz3LcOf5GpJV4yXjLvxGq2W9giU5w+M/w/1SgAtq/i+Z1S4XV0lbRL+v1Ig4ZUGVCO5G/Xt0IMypGUI9X7HClMHIdViKXSokUN+z3d16e/yJRZYV8fGmbZzhViqqzINMHXZjJV/1TzPl/bFXh3Jq7Y5gBfm2Jgvyr66xVqU2W+iffZ9ouvzdHVvG
5t8YpvDq+A1aHC/It9fd0axHsmlYbAA9CQOAObJi41GTpX0udcX5tKr6NSRRx/AhnVtD/I13490Laa9u0pXx0fEKZ
2021-03-22 20:49:29 +00:00
"text/plain": [
2021-05-03 17:48:15 +01:00
"<Figure size 1200x1200 with 6 Axes>"
2021-03-22 20:49:29 +00:00
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
2021-04-30 20:51:04 +01:00
"n_nodes = len(multi_param_nodes)\n",
"n_rows = math.ceil(n_nodes / 2)\n",
"# squeeze=False keeps `axes` 2-D even for a single row, so .flatten() always applies\n",
"fig, axes = plt.subplots(n_rows, 2, figsize=(6, 6 * n_rows / 3), squeeze=False)\n",
"fig.set_dpi(fig_dpi)\n",
"\n",
"# shared y-limit for every panel; loop-invariant, so computed once\n",
"error_ylim = np.round(np.max(1 - mean_param_accuracy + std_param_accuracy) + 0.05, 1)\n",
"\n",
"for idx, (nodes, ax) in enumerate(zip(multi_param_nodes, axes.flatten())):\n",
"    ax.set_title(f'Error Rates For {nodes} Nodes')\n",
"    # first index of mean_param_accuracy: 0 = test, 1 = train\n",
"    ax.plot(multi_param_epochs, 1 - mean_param_accuracy[0, idx, :], 'x', ls='-', lw=1, label='Test', c=(0, 0, 1))\n",
"    ax.plot(multi_param_epochs, 1 - mean_param_accuracy[1, idx, :], 'x', ls='-', lw=1, label='Train', c=(1, 0, 0))\n",
"    ax.set_ylim(0, error_ylim)\n",
"    ax.legend()\n",
"    ax.grid()\n",
"\n",
"# with an odd number of node settings the last panel has no data; hide it\n",
"for ax in axes.flatten()[n_nodes:]:\n",
"    ax.axis('off')\n",
"\n",
"fig.tight_layout()\n",
"# fig.savefig(f'graphs/{exp1_testname}-test-train-error-rate.png')"
2021-03-22 20:49:29 +00:00
]
},
{
"cell_type": "code",
2021-05-04 15:24:37 +01:00
"execution_count": 273,
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp1"
]
},
2021-03-22 20:49:29 +00:00
"outputs": [
{
"data": {
2023-05-27 23:29:17 +01:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAABMcAAAScCAYAAACP2uhjAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAB7CAAAewgFu0HU+AAEAAElEQVR4nOzdd5wT1frH8c+ztF2aIqAoqKCg2BsqKiB2sTcs1957L+hVfmAHu9d+1SuWay+o115AwYoKNlBABAUFQUR62z2/P85kM5tNNslu2m6+79crr51kzpw5yaQ8+8yZc8w5h4iIiIiIiIiISDEqyXcDRERERERERERE8kXJMRERERERERERKVpKjomIiIiIiIiISNFSckxERERERERERIqWkmMiIiIiIiIiIlK0lBwTEREREREREZGipeSYiIiIiIiIiIgULSXHRERERERERESkaCk5JiIiIiIiIiIiRUvJMRERERERERERKVpKjomIiIiIiIiISNFSckxERERERERERIqWkmMiIiIiIiIiIlK0lBwTEREREREREZGipeSYiIiIiIiIiIgULSXHRERERERERESkaCk5JiIiIiIiIiIiRUvJMRERERERERERKVpKjomIiIiIiIiISNFSckxEGgQzO8HMXHAblu/2iEhumdng0HfA4Hy3R0SkUChGEilu+g5IjZJjGWJmI0NvuHRvw/Ld/vrGzIYleU0XmdlvZva+mV1nZhvku831kZltY2a3mdnHZjbLzJaZ2VIzm2Nm48zsOTO73Mx6m1njfLc3k2L+0Y69rTSzP81sipl9GbwfzzGzrfLd7kJWh+/JvvlueyaYWVMz29bMzjSz/5jZt8F7Keu/BXFe02fT2DZ83M7IVhul4VKMlFuKkXJDMZJipExSjGQbmdl5ZvasmY03s7/NbEXwefrCzO7M1nvIzKbGvKafpbFt+Pt2SDbaJ7nToL6oRUKaB7c1gV2Af5rZfcDFzrmluWyImU0F1g3udnHOTc3l/mvDzDoBDwD7JCjSDGgLbAEcFjz2t5m1dc6Vx6lvMDAouHu1c25wRhuce42A1YJbF2Br4HgAMxsH3A88FO+1kOJkZucAtwJN892WwGFmtoVz7ut8N0REck4xUh0oRkpKMZKkzMz2AG4HNklQpG1w2wY4z8yeA85wzs3NYrO2M7MDnHOvZHEfUoCUHMuOMcDnaZT/NFsNKRI/AO/FPNYS2AjYFrDgdhawppkd6pxzuW1i/WFm6wIfAuuEHp6Hf0/PAFbgA57u+Ne4UVBmFfzr3BD9BrwU81grYFWgM7Ap0Z64W+IDvxPM7Fjn3OTcNLHeSed7ckY2G5Ij7SicxBj4z+p1wP75bogUHcVIuaUYKYMUI8WlGCnziilG2oaqiTEHfANMBP4C2gO9gr8A/YGNzWxn59yfWWzXtWb2qr4Pi4uSY9nxegM461OffOacOyfeCjPbGHgK2Dx46GDgEOCFHLWtPvoP0aBvPnA+8F/n3IrYgmbWGtgP+AfQL2ctzL1Jid5jAGbWCv86XAT0CB7uCXxuZj2dcxNz0Mb6pli/J3/FB72RwPdMoj0Lcm2/4P2p5IPkUrF+9vNFMVJmKUaqTjFS5hXj9+Q44EHgmdikl5k1BS4ErscnnDcB7gWOyGJ7Ng/qfzqL+5ACozHHpEFzzo3HBySLQg+fnqfmFDwz2xbYNbjrgP2dc8PiBX0Azrn5zrknnXP7Ad2Aouwi75xb4Jx7CtgO+CfR16EN8D8zWyVvjZNC8QjQwTm3jnPuUOfcEOfc+1T9bsqVcDLsujzsX0QKgGKk9ChGqh3FSJLEROAg59xWzrl74/UGc84td84NBS4NPXy4mXXPQnvCMdJgM2uUsKQ0OEqOSYPnnPsNCA8+3cvMGmrX9rraM7T8mXPuw1Q3dM5NKfaux867Ebgi9HA3/JllKWLOuWnOuVn5bkfgGvylPwC7mdku+WyMiOSPYqS0KEaqA8VIEo9z7k
Xn3MspFv8X/jLeiETj/tXF7UBkPLMNgeOysA8pUEqOFbCYmTM6B4+tb2bXm9lYM5ttZhXB4JaRbcIznfQNHlvTzP5pZp+b2UwzKzezeQn2ua6ZXWNmn5qfeWd58PdTM7vazNZOod19Q20YGXp8HzN7yswmmdnCYP0FdXiJ0jEutFyGP1sVl5mVmJ9Z6Boze9vMfjGzxeZnIfrd/OxOV5pZuxrq6Bx5DYgONAvwc5xZZmqcacbM2prZxWb2jpn9an4mpHnmZ3K5x8x6JNq2FjqGlqfVtbLI+5HoQLMAgxK8BsNqqGdrM3vQ/MxHS4L3/udmdpmZrVbXdmbBLcBHofvnm1nLZBvV5Vib2SGh1/LHVBtqZp2C74TIDFMdUt22EJhZTzO728y+N7O/gtdsupm9aX52rBYp1BGedWtw8FiZmZ0c+g5YHqzfMtvPKQd+Bh4O3c9o7zHz+gff9z8F3/cLg+Unzewws/T++TazXYJtpwXH+HczG2VmZ5lZ8zq0dTczuz94/8wNvud/M7O3gvdPWYr1rG1mg8zsQ4vOWrfc/IxtXwdtP7O+fb4KmSlGyqRxoWXFSIkpRsoMxUg5Yg0sRgomcQjPJNk5C7uZD9wUuv9/5i/rzBgz28v8rOUTzWx+8LmdZmYvmdkJZtYkzfqy8h1gfob1283Pvjs7OM4zzewDMxtgZgl/K2LqaWdml5jZu+ZjrKXmZyKdF7w3nzezi8ysS23bmjHOOd0ycANG4rtYO2BwhuqcGqqzM3AasCT0WOQ2LkE7+gIH4rPfsdvMi7O/KxPUH74tAQYkaXffUPmR+EFIX0xQ3wW1fG2GheoYlkL5U2P2u1aCck2A6Uleg8htIXBMgno6p1hH5bFKUM/Z+IFea9q2Av+PbtMMvOfuDtX7WYY/F8lucY8j/h/3lTVs9yt+7IoT0nlPpND2weH3cS22PySmnQcnKV+nY42fHeuvUNltU2znZaFt3qrr65bG+2FwHetqgR8HItn76jegXxrHejB+EOXvEtS3ZYZfk2GZfN/WsJ/wc+iO/ycv/H2/T4rH7Ywk++kGfJXCcfkCWC+FdjcO3vM11fU9/uxuleOYpN61gREptHMG0DtJXacBi1OoywGjs/kZK9RbJj/7oTqnhursjGKk8D6GheoYlkJ5xUipva6KkaL7HRx+H9die8VINb8fBtexrgYRIyVozwuh/d2ToTqnhurcGz+b78zQY2fVsO2wULkhSfazOvBuCsdlItAjxbZn/DsAf4Lk+RTa+RdwWJK6DiT+72y82/Rsv3+S3TQgf/3Rn2gW+zf8GZe/gbXws+LEsyP+S6wJ8Cd+dp05+A/mVuGCZnY3/ocnYiH+H4eZQAf8VN8tgVJgiJl1cM5dmEK7DXgCPxCnw/9DND54fNPgsVxYK7Rcjn894mlE9MzgQvw/XVPwZxGaAJ3wXzCt8T88j5vZCufcMzH1zAfuCZaPw8/aA/AYsCDOfqvNNGNmd1C1q/kc4BP8MSnFH8NN8a/lScBaZravc64iwXNLxU+h5R5mtptzLnaWq3S8hP8B3Q4/KxYknoGn2qDgZnYDVbvfLwbeB37Hvy93xR+T14E76tDObHgVWIo/VgC9qT6bE5CZY+2cW2Z+eutTg4eOxr/WyRwdWn48hfJ5Z76n0Pv491XEb8Ao/Oe2K35mo0bAmsArZnaUc+75FKpvC7yJH3B5KTAa30OgJf6z3yA452aY2b34AZIBrjOzN1wQydSGmW0EfEB0RimAb/G9Uhz+fbxZ8Pg2wMdm1sfVPBjzY8BRofvz8L9Nf+KPUV9gY/x3QEpTrgftfA//3iBo21f436Yl+N+APvjv7bWAd8ysn3NuRJy6DgIeCD00H//ZnY4PVlcBNsB/fgtpttKGRjFS3ShGSo1ipMxRjJQlRRAjbRZa/jUbO3DOLTaz6/GXcQJcZWaPOOeW1LZOM1sD/9u0fujhn/A94ZbhY5ntg8e7ASPMbG/n3EckkI3vgKB35Pv4JG
jE98DX+PfP6vjPa1v8jLTPmp959r9x6uqBT7JFck5L8N9lU4Pn3Br/emyGT0jmX76zcw3lRvbPiq7Av4lOBSymXLM
2021-03-22 20:49:29 +00:00
"text/plain": [
2021-05-03 17:48:15 +01:00
"<Figure size 1200x1200 with 6 Axes>"
2021-03-22 20:49:29 +00:00
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
2021-04-30 20:51:04 +01:00
"n_nodes = len(multi_param_nodes)\n",
"n_rows = math.ceil(n_nodes / 2)\n",
"# squeeze=False keeps `axes` 2-D even for a single row, so .flatten() always applies\n",
"fig, axes = plt.subplots(n_rows, 2, figsize=(6, 6 * n_rows / 3), squeeze=False)\n",
"fig.set_dpi(fig_dpi)\n",
"\n",
"# shared y-limit for every panel; loop-invariant, so computed once\n",
"std_ylim = np.round(np.max(std_param_accuracy) + 0.05, 1)\n",
"\n",
"for idx, (nodes, ax) in enumerate(zip(multi_param_nodes, axes.flatten())):\n",
"    ax.set_title(f'Error Rate Std Dev. For {nodes} Nodes')\n",
"    # first index of std_param_accuracy: 0 = test, 1 = train\n",
"    ax.plot(multi_param_epochs, std_param_accuracy[0, idx, :], 'x', ls='-', lw=1, label='Test', c=(0, 0, 1))\n",
"    ax.plot(multi_param_epochs, std_param_accuracy[1, idx, :], 'x', ls='-', lw=1, label='Train', c=(1, 0, 0))\n",
"    ax.set_ylim(0, std_ylim)\n",
"    ax.legend()\n",
"    ax.grid()\n",
"\n",
"# with an odd number of node settings the last panel has no data; hide it\n",
"for ax in axes.flatten()[n_nodes:]:\n",
"    ax.axis('off')\n",
"\n",
"fig.tight_layout()\n",
"# fig.savefig(f'graphs/{exp1_testname}-test-train-error-rate-std.png')"
2021-03-22 20:49:29 +00:00
]
},
2021-03-19 17:21:00 +00:00
{
2023-05-27 23:29:17 +01:00
"attachments": {},
2021-03-19 17:21:00 +00:00
"cell_type": "markdown",
"metadata": {
2021-03-26 20:01:05 +00:00
"id": "eUPJuxUtVUc3",
"tags": [
"exp2"
]
2021-03-19 17:21:00 +00:00
},
"source": [
"# Experiment 2\n",
"\n",
"For cancer dataset, choose appropriate values of nodes and epochs based on Exp 1) and use an ensemble of individual (base) classifiers with random starting weights and Majority Vote to see if performance improves - repeat the majority vote ensemble at least thirty times with different 50/50 splits and average and graph (each classifier in the ensemble sees the same training patterns). Repeat for a different odd number (prevents tied votes) of individual classifiers between 3 and 25, and comment on the result of individual classifier accuracy vs ensemble accuracy as the number of base classifiers varies. Consider changing the number of nodes/epochs (both less complex and more complex) to see if you obtain better performance, and comment on the result with respect to why the optimal node/epoch combination may be different for an ensemble compared with the base classifier, as in Exp 1). \n",
"\n",
2021-03-29 19:17:14 +01:00
"(Hint4: to implement majority vote you need to determine the predicted class labels -probably easier to implement yourself rather than use the ensemble matlab functions)\n"
2021-03-19 17:21:00 +00:00
]
},
{
"cell_type": "code",
2021-05-04 15:24:37 +01:00
"execution_count": 249,
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp2",
"exp-func"
]
},
2021-03-19 17:21:00 +00:00
"outputs": [],
"source": [
2021-03-27 16:29:31 +00:00
"num_models = [1, 3, 9, 15, 25]  # ensemble sizes evaluated by default\n",
"\n",
"def evaluate_ensemble_vote(hidden_nodes=16,\n",
"                           epochs=50,\n",
"                           batch_size=128,\n",
"                           learning_rates=None,\n",
"                           rand_ranges=False,\n",
"                           optimizer=lambda: 'sgd',\n",
"                           weight_init=lambda: 'glorot_uniform',\n",
"                           loss=lambda: 'categorical_crossentropy',\n",
"                           metrics=['accuracy'],\n",
"                           callbacks=None,\n",
"                           validation_split=None,\n",
"                           round_predictions=True,\n",
"\n",
"                           nmodels=num_models,\n",
"                           tboard=True,\n",
"                           exp='2',\n",
"\n",
"                           verbose=0,\n",
"                           print_params=True,\n",
"                           return_model=True,\n",
"\n",
"                           dtrain=data_train,\n",
"                           dtest=data_test,\n",
"                           ltrain=labels_train,\n",
"                           ltest=labels_test):\n",
"    \"\"\"Train and score a majority-vote ensemble for each ensemble size in `nmodels`.\n",
"\n",
"    This is a generator: one `response` dict is yielded per ensemble size.\n",
"\n",
"    Args:\n",
"        hidden_nodes: int for a fixed node count; a (low, high) tuple to spread node\n",
"            counts over the ensemble (evenly via linspace, or uniformly at random when\n",
"            `rand_ranges` is True); or 'm' to give the models 1..m nodes.\n",
"        epochs: int, or a (low, high) tuple handled like `hidden_nodes`.\n",
"        batch_size: NOTE(review): not currently forwarded to `model.fit`, so Keras'\n",
"            default batch size is used.\n",
"        learning_rates: None keeps the optimizer's own rate; a (low, high) tuple spreads\n",
"            rates over the ensemble; '+' gives 0.01 * (model index + 1). Any other value\n",
"            leaves the optimizer's rate unchanged.\n",
"        rand_ranges: sample tuple ranges at random instead of evenly.\n",
"        optimizer, weight_init, loss: zero-argument factories, called per model.\n",
"        metrics: metrics passed to `model.compile`.\n",
"        callbacks: optional iterable of zero-argument callback factories.\n",
"        validation_split, verbose: forwarded to `model.fit`.\n",
"        round_predictions: vote with the summed rounded model outputs when True,\n",
"            otherwise with the summed raw outputs.\n",
"        nmodels: ensemble sizes to evaluate.\n",
"        tboard: also attach a TensorBoard callback (prefix 'exp{exp}-{m}-').\n",
"        print_params: print each ensemble size before training.\n",
"        return_model: include the trained models in the response.\n",
"        dtrain, dtest, ltrain, ltest: train/test data and labels; anything with\n",
"            `.to_numpy()` (presumably DataFrames).\n",
"\n",
"    Yields:\n",
"        dict with per-model nodes/epochs, training histories, predictions and the\n",
"        ensemble 'accuracy', 'agreement' and 'individual_accuracy' scores.\n",
"    \"\"\"\n",
"    for m in nmodels:  # iterate over different ensemble sizes\n",
"        if print_params:\n",
"            print(f\"Models: {m}\")\n",
"\n",
"        # response dict object for test stats\n",
"        response = {\"epochs\": list(),\n",
"                    \"num_models\": m}\n",
"\n",
"        ###################\n",
"        ## GET MODELS\n",
"        ###################\n",
"        if isinstance(hidden_nodes, tuple):  # for range of hidden nodes, calculate value per model\n",
"            if m == 1:\n",
"                if not rand_ranges:\n",
"                    # just average provided range\n",
"                    node_val = int(np.mean(hidden_nodes))\n",
"                else:\n",
"                    # get random val\n",
"                    node_val = random.randint(*hidden_nodes)\n",
"                models = [get_model(node_val, weight_init=weight_init)]\n",
"                response[\"nodes\"] = [node_val]\n",
"            else:\n",
"                if not rand_ranges:\n",
"                    # use linspace to generate equally spaced nodes throughout range\n",
"                    node_vals = [int(i) for i in np.linspace(*hidden_nodes, num=m)]\n",
"                else:\n",
"                    # use random to generate nodes throughout range\n",
"                    node_vals = [random.randint(*hidden_nodes) for _ in range(m)]\n",
"                models = [get_model(i, weight_init=weight_init) for i in node_vals]\n",
"                response[\"nodes\"] = node_vals\n",
"\n",
"        elif hidden_nodes == 'm':\n",
"            # incrementing mode, number of nodes ranges from 1 to m\n",
"            # more nodes in larger ensembles\n",
"            models = [get_model(i + 1, weight_init=weight_init) for i in range(m)]\n",
"            response[\"nodes\"] = [i + 1 for i in range(m)]\n",
"        else:\n",
"            # not a range of nodes, just set to given value\n",
"            models = [get_model(hidden_nodes, weight_init=weight_init) for _ in range(m)]\n",
"            response[\"nodes\"] = hidden_nodes\n",
"\n",
"        ######################\n",
"        ## COMPILE MODELS\n",
"        ######################\n",
"        if learning_rates is None:\n",
"            # default, just load optimiser\n",
"            for model in models:\n",
"                model.compile(\n",
"                    optimizer=optimizer(),\n",
"                    loss=loss(),\n",
"                    metrics=metrics\n",
"                )\n",
"        else:\n",
"            # equally spaced rates do not depend on the model, so compute them once\n",
"            if isinstance(learning_rates, tuple) and not rand_ranges:\n",
"                lr_grid = np.linspace(*learning_rates, num=m)\n",
"            for idx, model in enumerate(models):\n",
"                optim = optimizer()\n",
"\n",
"                # generate learning rate either randomly or linearly\n",
"                if isinstance(learning_rates, tuple):\n",
"                    if not rand_ranges:\n",
"                        # get equal spaced learning rates\n",
"                        optim.learning_rate = lr_grid[idx]\n",
"                    else:\n",
"                        # get random learning rate\n",
"                        optim.learning_rate = random.uniform(*learning_rates)\n",
"                elif learning_rates == '+':\n",
"                    # incrementing mode, scale with size of ensemble\n",
"                    optim.learning_rate = 0.01 * (idx + 1)\n",
"                # any other value keeps the optimizer's own learning rate\n",
"\n",
"                model.compile(\n",
"                    optimizer=optim,\n",
"                    loss=loss(),\n",
"                    metrics=metrics\n",
"                )\n",
"\n",
"        # build the callback list unconditionally so training also works with tboard=False\n",
"        cb = [make_cb() for make_cb in callbacks] if callbacks is not None else []\n",
"        if tboard:\n",
"            # include a tensorboard callback to dump stats for later analysis\n",
"            cb.append(tensorboard_callback(prefix=f'exp{exp}-{m}-'))\n",
"\n",
"        ###################\n",
"        ## TRAIN MODELS\n",
"        ###################\n",
"        # equally spaced epoch counts do not depend on the model, so compute them once\n",
"        if isinstance(epochs, tuple) and not rand_ranges:\n",
"            epoch_grid = np.linspace(*epochs, num=m)\n",
"        histories = list()\n",
"        for idx, model in enumerate(models):\n",
"            if isinstance(epochs, tuple):\n",
"                # for range of epochs, calculate value per model\n",
"                if not rand_ranges:\n",
"                    # average, not lower bound, if single model\n",
"                    e = int(np.mean(epochs) if m == 1 else epoch_grid[idx])\n",
"                else:\n",
"                    e = random.randint(*epochs)\n",
"            else:\n",
"                # not a range of epochs, just set to given value\n",
"                e = epochs\n",
"\n",
"            # NOTE(review): batch_size is not passed here, so Keras' default batch size applies\n",
"            history = model.fit(dtrain.to_numpy(),\n",
"                                ltrain.to_numpy(),\n",
"                                epochs=e,\n",
"                                verbose=verbose,\n",
"                                callbacks=cb,\n",
"                                validation_split=validation_split)\n",
"            histories.append(history.history)\n",
"            response[\"epochs\"].append(e)\n",
"\n",
"        ############################\n",
"        ## FEEDFORWARD TEST DATA\n",
"        ############################\n",
"        # TEST DATA PREDICTIONS\n",
"        response[\"predictions\"] = [model(dtest.to_numpy()) for model in models]\n",
"        # TEST LABEL TENSOR\n",
"        ltest_tensor = tf.constant(ltest.to_numpy())\n",
"\n",
"        ########################\n",
"        ## ENSEMBLE ACCURACY\n",
"        ########################\n",
"        # round predictions to onehot vectors and sum over all ensemble models\n",
"        # take argmax for ensemble predicted class\n",
"        ensem_sum_rounded = sum(tf.math.round(pred) for pred in response[\"predictions\"])\n",
"        ensem_sum = sum(response[\"predictions\"])\n",
"\n",
"        correct = 0  # number of correct ensemble predictions\n",
"        correct_num_models = 0  # when the ensemble is correct, number of models correctly classifying\n",
"        individual_accuracy = 0  # proportion of models correctly classifying\n",
"\n",
"        # pc = predicted class, pcr = rounded predicted class, gt = ground truth\n",
"        for pc, pcr, gt in zip(ensem_sum, ensem_sum_rounded, ltest_tensor):\n",
"            gt_argmax = tf.math.argmax(gt)\n",
"            pred_val = pcr if round_predictions else pc\n",
"\n",
"            correct_models = pcr[gt_argmax] / m  # use rounded value so will divide nicely\n",
"            individual_accuracy += correct_models\n",
"\n",
"            if tf.math.argmax(pred_val) == gt_argmax:  # ENSEMBLE EVALUATE HERE\n",
"                correct += 1\n",
"                correct_num_models += correct_models\n",
"\n",
"        ########################\n",
"        ## RESULTS\n",
"        ########################\n",
"        response.update({\n",
"            \"history\": histories,\n",
"            \"optimizer\": model.optimizer.get_config(),\n",
"            \"model_config\": json.loads(model.to_json()),\n",
"            \"loss\": model.loss,\n",
"            \"round_predictions\": round_predictions,\n",
"\n",
"            \"accuracy\": correct / len(ltest),  # average number of correct ensemble predictions\n",
"            # when the ensemble is correct, average proportion of models correctly classifying;\n",
"            # undefined (NaN) if the ensemble never predicted correctly, instead of ZeroDivisionError\n",
"            \"agreement\": correct_num_models / correct if correct else float('nan'),\n",
"            \"individual_accuracy\": individual_accuracy / len(ltest)  # average proportion of individual models correctly classifying\n",
"        })\n",
"\n",
"        if return_model:\n",
"            response[\"models\"] = models\n",
"\n",
"        yield response"
]
},
{
2023-05-27 23:29:17 +01:00
"attachments": {},
2021-03-19 17:21:00 +00:00
"cell_type": "markdown",
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp2"
]
},
2021-03-19 17:21:00 +00:00
"source": [
"## Single Iteration\n",
"Run a single iteration of ensemble model investigations"
]
},
{
"cell_type": "code",
2021-05-04 15:24:37 +01:00
"execution_count": 250,
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp2"
]
},
2021-03-19 17:21:00 +00:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Models: 1\n",
2021-05-04 15:24:37 +01:00
"[6] [3]\n",
2021-03-19 17:21:00 +00:00
"Models: 3\n",
2021-05-04 15:24:37 +01:00
"[5, 19, 1] [3, 18, 9]\n",
2021-03-29 18:34:04 +01:00
"Models: 9\n",
2021-05-04 15:24:37 +01:00
"[10, 18, 16, 5, 7, 14, 9, 20, 19] [3, 16, 7, 18, 15, 11, 15, 5, 11]\n",
2021-03-29 18:34:04 +01:00
"Models: 15\n",
2021-05-04 15:24:37 +01:00
"[9, 19, 2, 4, 20, 6, 17, 13, 19, 3, 3, 9, 13, 8, 16] [14, 7, 8, 9, 16, 1, 13, 13, 16, 3, 19, 9, 8, 3, 7]\n",
2021-04-06 17:29:15 +01:00
"Models: 25\n",
2021-05-04 15:24:37 +01:00
"[9, 13, 16, 9, 6, 14, 4, 6, 8, 18, 15, 11, 17, 11, 19, 7, 1, 19, 16, 4, 1, 10, 15, 14, 14] [20, 11, 4, 6, 20, 15, 13, 1, 4, 2, 12, 15, 6, 9, 15, 6, 6, 5, 5, 16, 15, 6, 16, 12, 15]\n"
2021-03-19 17:21:00 +00:00
]
}
],
"source": [
"# one run over every ensemble size: node counts, epochs and learning rates are each\n",
"# sampled at random per model (the SGD rate of 0.02 is overridden by learning_rates)\n",
"# alternative: evaluate_ensemble_vote(epochs=(5, 300), optimizer=lambda: tf.keras.optimizers.SGD(learning_rate=0.02))\n",
"single_ensem_results = []\n",
"for test in evaluate_ensemble_vote(\n",
"        hidden_nodes=(1, 20),\n",
"        epochs=(1, 20),\n",
"        rand_ranges=True,\n",
"        learning_rates=(0.01, 0.5),\n",
"        optimizer=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),\n",
"):\n",
"    single_ensem_results.append(test)\n",
"    print(test[\"nodes\"], test[\"epochs\"])"
2021-03-19 17:21:00 +00:00
]
},
{
"cell_type": "code",
2021-05-04 15:24:37 +01:00
"execution_count": 251,
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp2"
]
},
2021-03-19 17:21:00 +00:00
"outputs": [
{
"data": {
2023-05-27 23:29:17 +01:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAABXgAAAOcCAYAAAD0DtmZAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAB7CAAAewgFu0HU+AAEAAElEQVR4nOzdd3gU1foH8O+bXiCN3kOk9xJAOoINQa6AFRsIV8XrVRHLtVzFn14biF30CgiKKAiocEFFpBcFgiBFpCX0HkIIJKTs+f0xu5uZzbZsdrPZ7PfzPPMwc+bMmXd3Z5bdN2fPEaUUiIiIiIiIiIiIiCjwhPg7ACIiIiIiIiIiIiLyDBO8RERERERERERERAGKCV4iIiIiIiIiIiKiAMUELxEREREREREREVGAYoKXiIiIiIiIiIiIKEAxwUtEREREREREREQUoJjgJSIiIiIiIiIiIgpQTPASERERERERERERBSgmeImIiIiIiIiIiIgCFBO8RERERERERERERAGKCV4iIiIiIiIiIiKiAMUELxEREREREREREVGAYoKXiIiIiIiIiIiIKEAxwUtEREREREREREQUoJjgJSIiIiIiIiIiIgpQTPASERERERERERERBSgmeImIiIiIiIiIiIgCFBO8RERERERERERERAGKCV4iIiIiIiIiIiKiAMUELxEREVEQEJGWIvKRiOwSkQsionRLsr/jI9dEpJ/uNVvp73hIIyIzdK/LSH/HE4hEZKTuOZxRjuddqTtvv/I6L1VMvnqPFZEM/n9LRL7GBC8RUQUkIsk2yRdvLBP8/bjKg4hMKMNzlOHv+AONzZdye8tlETklIpvMycU+/o45GInIEAC/AxgLoCWAKv6NyHtsrsGMUh5r+16b7JsoidxnkzC2LI+Wso1/2Gljho9CJjObhHFplxn+jp+IiAIXE7xERFSu2AMt6EQAqAEgFVpycZWIrBCRBuUdSLD2shORKgBmAIg0Fx0H8A2AjwB8aF6y/RIcURnpE2T+jsXH7ill/Xt9EgURERFVSGH+DoCIiOzKhpZ0caYrgC7m9WMAvnVRf2NZgwpA7jwvemd9FUiQuADgc5uyKACNAPQEEG0u6wdguYhcqZTic+57gwEkmtd3AuiilMr1YzxEVHqdRKS1Umqnq4oi0gLFnw/IfzahdJ+9fvVVIEREVPkxwUtEVAEppTIBPOysjnnIBcsXuL1KKaf1gxSfl/KV6ej5FpFq0HqM3mouagJgAoB/lk9oQa2Tbv0rJncDl1JqJQDxdxxUrnYBaGVevwfA024co+/tqz+eytcSpdQEfwdBRETBgUM0EBERkc+Ze+qOgNajyWKUiIT7KaRgkqhbP+63KIjIEz8COG1ev1NEnH5/M++/y7x5GsAPPoyNiIiIKggmeImIiKhcKKWKAHygK4oF0NlP4QQTfRLd5LcoiMgTBQC+Mq/XAzDARf2rAFjGOJ8NoNBHcREREVEFwgQvEVGQEJGWIvKqiGwUkZMiki8ip0XkNxH5PxGp62Y7VUTkQRFZLCKHROSSiBSIyHkR2S0ii0TkWRFpY3PcBPMkOCt0xX0dzCSd4cWHXmb2JucSkQQReVREVovIUREpNO9PMO+foDtmgrksWkRGi8hS83OXb97fwc45q4jIIyLyk4gcEZE8ETknIjtE5AMR6eZm7CUmIBKR9iLyrrmtTPP+78r8RLlnq8220+vOfN2OE5EFIvKXiFwwX2+nRWSziLwtIk5/fiwiGebHr5906DMH194EJ+2Ei8jdIjJXRA6YY7koIuki8pWIDBURt34+LyJdzK/jFvPrWigiuSJyXER+FZEpInKriMS6056d9q3XnxuPe6SDNirrNVgmIlJdRP4lIqvMr9dlETkjIr+LyERX16O5jWR773ci0ktEpor2XnrevP8d3X6nk1SKyEgH17WrJcO2LZt2G4n2/8SvUvz/x0nz9kvixqSJjmIXkf4i8rX5nsoTkbOiva
8+LA56+Ovbsil39PiS7bRR5veWcqQf29zVZGv6/bZjortFRK4TkekiskdEss3vTQdF5FvzNVaqX16ISCcR+dT8Guean+ONIvKUiCR5EqO53S7m12mruc18ETlhvjefFpFE161UfGL+P0x/LYtIfRF5WUS2iUiWaP8X7RaR90WkkZvtevR5zkl7ZX49HDzWJqK9t+4wx5ZrftzPikiMnTaai/Z/1Hbz9Ztlfq/6h4iEuvNYbNoTERkmIgvN90Ge+XEtFZF7xEWvek9IAH3eIKIKRCnFhQsXLlwCcIE2fqkyLyud1IsE8DG0XjzKyXIJwMMuztkdwBEX7eiXMAfxuloyfP28lLLNGbo2R0KbMOyQg9gT7MQxAUBLADscHNPB5nyDof2U3tXz9CWAGBexW+vr4rJ3LXzn4XMzsjSvG4CmNucd4aTuXDevFxOAtwGEOmgnoxTX3gQHbfQDsM+N4zcAqOfkMYUB+KQU8bzihfvA1TLSzvGV9hq0OTbZJoZkF/XvA5Dl4jkpdHY92jlvBoAIaO/T9tp7x+Y6tJSvdPFclGZx+LwBeA5ArovjcwE87eK5M8Rufsz/ddFuGoDqLtpyZ0m2Od4r7y269mbojilxP3lwTevbe91cZvn/IwdAFQfHxUKb7FIB2GEue13X1gwX560JYJkbz8seAKluPpZX4Pzzx2EAV9pcu67iTAQwz404zwG42UVbK3X1+3nhtdO3N6Gs7ZnbzNBfywBugvP3oUsABrlo0+PPcz5+PWwf610ALjppcwuARN3xzwMoclJ/BZz8n4WS71NVAXzn4nGtB1CzNI/LRd1+CKDPG1y4cKk4CydZIyKqxMx/jf8JWkLSYj+0L83nACSZ99UFEA3gfRGJU0q9aqetBua2qpqLCqCNp7oP2peJWGgfxtsDiLMTzkYAH0L7ielN5rJjAL61U/esmw/RH5oAeAdAPLQv0quhPY5EAH0cHFMN2jiKDQHkAVgL4CCAKtC+2FqJyG3QkmaWXiZF5vr7zPV7o7jX6wgAjUWkv1Iqz1XgIvIkgBfNm/uhvSaXoL1uBa6O9xLbHrsnndRtaP63ENpEQXuhfaktgpaI6ALtehIAj0H7Y8ZDdtqZCe01GACghbnsFwC77dQtMeO5iNwC7TWx9FrLhTbbeQa0BFAzaF+Ww6C9nhtEpItSyt5jmwjgft32UfM5T0P7ZVU1aBMiNbdzbGlY7jfA9eP+U78RBNegR0TkCWivn8VlAKug/bEnEdpP45OgPW+PAWgoIjcrpZQbzb8N4AHz+nYA26A9H81QumE1/kTx6+7KQAAp5nW7MYrIBwD+oSvKgZYgOQGgNrTHXAVAFIDXRaS2Umqcm+f/L7Te5SYAv0G7LkOg3UOW678TtF6oN9gcexTFj1Mfn6PHnm2z7a33lvL0BbRkbSyA4dDe12wNh/Z6AKXsvSsitQCsA3CFrng/tNfmMrT3JUuv/aYAVojI9UqpdU7afBXAM7qiSwCWQ/vjUW0A/QHUB7AE2v+r7sRZ29xGS13xTmj3TA601683tPfSBABzReRupdSX7rQfAK6G9segUGjvPRugXd+NoSUGw6B9npsrIm2UUum2DXjh85y+LV++HgOhDesUAu0e3QjtM1Q7FE8y3BHA1wCuE5FnALxsLv/DHEMhgK4AWpvL+wGYDOBBN84PaH9s+Ru098iN0N4vIgH0gPYcAdr//7+ISE+llO17TakE6OcNIqoo/J1h5sKFCxcuni1wo6cqtC+Aljp/wU7vFGhfEsZC+9CsoH0Y7m6n3tu6tlYDqOvgnGEA+gKYBTu9nuCiB1p5PC8etDlD12aB+d8PYNOLCtoH8hA7cViO+QZADZtjQgCEm9evQHHvKwXti3UTO/Ufh7GHyntOYlc2cWQBuMlOvUgPn5uRuvYz3Kj/qq5+PnQ9b+zUfQ3ALQDiHOwXADcCOKVrs5ebr+NINx9fa2hfeBW0L1cTYe6lbVMvBcAaXftL7NSpprsWCqElt8
TBeesA+CeA0V6+fp0+7mC4Bm2OTbaJL9lBvR4w9kJcAqCWbfwA3rRp73E3zmtp9xCA3s6eF3jp/RPA9TaP5wE7dW6
2021-03-19 17:21:00 +00:00
"text/plain": [
2021-04-30 20:51:04 +01:00
"<Figure size 1600x1000 with 1 Axes>"
2021-03-19 17:21:00 +00:00
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
2021-03-22 20:49:29 +00:00
"fig = plt.figure(figsize=(8, 5))\n",
2021-03-19 17:21:00 +00:00
"fig.set_dpi(fig_dpi)\n",
"\n",
2021-03-26 20:01:05 +00:00
"ensem_x = [i[\"num_models\"] for i in single_ensem_results]\n",
"\n",
2021-04-06 17:29:15 +01:00
"plt.plot(ensem_x, 1 - np.array([i[\"accuracy\"] for i in single_ensem_results]), 'x-', label='Ensemble Test')\n",
"plt.plot(ensem_x, 1 - np.array([i[\"individual_accuracy\"] for i in single_ensem_results]), 'x-', label='Individual Test')\n",
"plt.plot(ensem_x, 1 - np.array([i[\"agreement\"] for i in single_ensem_results]), 'x-', label='Disagreement')\n",
2021-03-26 20:01:05 +00:00
"\n",
2021-04-06 17:29:15 +01:00
"plt.title(\"Test Error Rates for Horizontal Model Ensembles\")\n",
"plt.ylim(0)\n",
2021-03-19 17:21:00 +00:00
"plt.grid()\n",
2021-03-26 20:01:05 +00:00
"plt.legend()\n",
2021-04-06 17:29:15 +01:00
"plt.ylabel(\"Error Rate\")\n",
2021-03-19 17:21:00 +00:00
"plt.xlabel(\"Number of Models\")\n",
"plt.show()"
]
},
{
2023-05-27 23:29:17 +01:00
"attachments": {},
2021-03-19 17:21:00 +00:00
"cell_type": "markdown",
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp2"
]
},
2021-03-19 17:21:00 +00:00
"source": [
"## Multiple Iterations\n",
2021-03-22 20:49:29 +00:00
"Run multiple iterations of the ensemble model investigation and average the results across iterations\n",
"\n",
2021-03-26 20:01:05 +00:00
"### CSV Results\n",
2021-03-22 20:49:29 +00:00
"\n",
"| test | learning rate | momentum | batch size | hidden nodes | epochs | models |\n",
"| --- | --- | --- | --- | --- | --- | --- |\n",
"|1|0.06|0|128|16|50|1, 3, 9, 15, 25|\n",
"|2|0.06|0|35|16|1 - 100|1, 3, 9, 15, 25|\n",
"\n",
2021-03-26 20:01:05 +00:00
"### Pickle Results\n",
2021-03-22 20:49:29 +00:00
"\n",
2021-04-30 20:51:04 +01:00
"| test | learning rate | momentum | batch size | hidden nodes | epochs | models | stratify |\n",
"| --- | --- | --- | --- | --- | --- | --- | --- |\n",
"|3|0.06|0.05|35|16|1 - 300|1, 3, 9, 15, 25| |\n",
"|4|0.06|0.05|35|1 - 50|50|1, 3, 9, 15, 25| |\n",
"|5|0.06|0.05|35|1 - 300|50|1, 3, 9, 15, 25| |\n",
"|6|0.001|0.01|35|1 - 400|50|1, 3, 9, 15, 25| |\n",
"|7|0.01|0.01|35|1 - 400|30 - 150|1, 3, 9, 15, 25| |\n",
"|8|0.03|0.01|35|1 - 400|5 - 100|1, 3, 9, 15, 25| |\n",
"|9|0.1|0.01|35|1 - 400|20|1, 3, 9, 15, 25| |\n",
"|10|0.15|0.01|35|1 - 400|20|1, 3, 9, 15, 25, 35, 45| |\n",
"|11|0.15|0.01|35|1 - 400|10|1, 3, 9, 15, 25, 35, 45| |\n",
"|12|0.02|0.01|35|m|50|1, 3, 9, 15, 25, 35, 45| |\n",
"|13|0.01 exp 0.98, 1|0.01|35|1 - 200|50|1, 3, 9, 15, 25, 35, 45| n |\n",
"|14|0.01|0.01|35|1 - 200|50|1, 3, 9, 15, 25, 35, 45| n |\n",
"|15|0.01|0.9|35|50 - 100|50|1, 3, 5, 7, 9, 15, 25, 35, 45| n |\n",
"|16|0.01|0.1|35|50 - 100|50|1, 3, 5, 7, 9, 15, 25, 35, 45| n |\n",
2021-05-04 15:24:37 +01:00
"|17|0.1|0.1|35|50 - 100|50 - 100|1, 3, 5, 7, 9, 15, 25, 35, 45| n |\n",
"|18 (r)|0.01 - 1|0.0|35|1 - 50|20 - 70|1, 3, 5, 7, 9, 15, 25, 35| n |\n",
"|19 (r)|0.01 - 1|0.0|35|1 - 100|10 - 70|1, 3, 5, 7, 9, 15, 25| n |"
2021-03-19 17:21:00 +00:00
]
},
2021-04-28 12:10:25 +01:00
{
"cell_type": "code",
2021-04-30 20:51:04 +01:00
"execution_count": 335,
2021-04-28 12:10:25 +01:00
"metadata": {},
2021-04-30 20:51:04 +01:00
"outputs": [
{
"data": {
2023-05-27 23:29:17 +01:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAY4AAAEGCAYAAABy53LJAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAAqyUlEQVR4nO3deXxU1fnH8c+ThLATIIEECDthX0QQkEUCiCIiuKBgtbVWpQu21KWt9lertdoWrVq1Wou11eKC1qXEiggFgoLsCrJLWATCjsiiItvz+2MuNE1ZZmImk0y+79eLFzNn7r15jsR8c++59xxzd0RERMKVEOsCRESkbFFwiIhIRBQcIiISEQWHiIhERMEhIiIRSYp1ASUhLS3NmzRpUqR9P//8c6pWrVq8BZVy6nP8K2/9BfU5UosWLdrl7nVO9lm5CI4mTZqwcOHCIu2bm5tLdnZ28RZUyqnP8a+89RfU50iZ2Sen+kyXqkREJCIKDhERiUhUg8PMBpnZajPLM7M7TvJ5RTN7Ofh8npk1CdpTzWyGmR0wsz8W2qeLmS0N9nnMzCyafRARkf8WteAws0TgCeAioC1wtZm1LbTZDcAed28BPAKMDdoPAncBt5/k0H8CbgKygj+Dir96ERE5lWiecXQD8tx9nbsfAiYAwwptMwx4Lnj9KjDAzMzdP3f3WYQC5AQzqwfUcPe5Hppk6+/ApdEo/tgx5wcvLGJ2/mE0n5eIyH9E866qBsCmAu83A91PtY27HzGzvUAqsOs0x9xc6JgNTrahmY0CRgGkp6eTm5sbUfGfH3byNh9k0p5jfPiHd7iuXUWqJ5ePq2IHDhyI+L9XWVfe+lze+gvqc3GK29tx3X0cMA6ga9euXpRb0gYNcO58dipvrD3CvQuO8sAVHenXum4xV1r66LbF+Ffe+gvqc3GK5qWqfKBhgfeZQdtJtzGzJCAF2H2GY2ae4ZjFJjHBGNwsmYmje5NaNZnrn13Az99YyudfHYnWlxQRKfWiGRwLgCwza2pmycBIIKfQNjnAdcHr4cB0P82AgrtvBfaZWY/gbqpvAROLv/T/1rZ+DSbe3IvvnteMl+ZvZPBj77Hok0+j/WVFREqlqAWHux8BbgbeAVYCr7j7cjO718yGBps9A6SaWR5wK3Dill0z2wA8DHzbzDYXuCPrB8BfgDxgLfB2tPpQUMWkRO4c3IYJN/XgyFHnyqfm8Nu3V3Lw8NGS+PIiIqVGVMc43H0SMKlQ2y8LvD4IXHmKfZucon0h0L74qoxM92apTP5xH34zaSV/nrmOGat28NCVZ9EhMyVWJYmIlCg9OV4E1StV4LeXd+Rv15/D3i8Pc+mTs3lk6sccPnos1qWJiESdguNr6NeqLlN+3Jehnerz6LQ1XPrEbFZv2x/rskREokrB8TWlVKnAIyPO4qlru7Bt70GGPP4ef5y+hiM6+xCROKXgKCaD2mcw5ZbzuKBdBr+f8jGXPfk+q7bti3VZIiLFTsFRjFKrVeSJb5zNk9eczZbPvuSSx2fx+LQ1GvsQkbii4IiCwR3qMfXWvgxqX4+Hpn7MZU/OZuVWnX2ISHxQcERJ7arJPH51Z5669my27T3IJY/P4uGpH3PoiM4+RKRsU3BE2aD29Zh6S18u6VSfx6atYcjj77F402exLktEpMgUHCWgVtVkHhlxFn/9dlf2HzzC5U/O5v63VvDlIT11LiJlj4KjBPVvnc6UW85jZLdGPP3eegY9+i5z1p5uTkcRkdJHwVHCqleqwG8u68CLN4WWJrn66bnc8dpH7P3icIwrExEJj4IjRno2T2PymPP47nnNeGXhJgY8PJNJS7dqtUERKfUUHDFUOTk0427Ozb1Jr1GRH7zwATf9fRFb934Z69JERE5JwVEKtG+QwsTRvbjzotbMytvJwIffZfycDRw7prMPESl9FBylRFJiAt/t25x3fnweZzWsyV0Tl3PFU5q2RERKHwVHKdM4tSrjb+jGw1
d14pPdXzDksVn87u1VunVXREoNBUcpZGZcfnYm027ty+VnN+CpmWu54A8zyV29I9aliYgoOEqzWlWTeWB4JyaM6kFyYgLf/tsCbn7xA3bsOxjr0kSkHFNwlAE9mqUyaUwfbjm/JVNWbGfAQzN5dvZ6jmrwXERiQMFRRlRMSmTM+VmhwfNGNbnnzRUMe2IWSzTvlYiUMAVHGdM0rSp//043/viNzuzY9xWXPjmbu/65jL1f6slzESkZCo4yyMwY0rE+027ry7d7NuGFeZ8w4KFcXlu0WU+ei0jUKTjKsOqVKnD3Je3Iubk3mbWqcNs/lnDVn+do0SgRiSoFRxxo3yCF17/fkweu6MjanZ8z5PFZ/OrN5ew7qMtXIlL8FBxxIiHBuOqchky/rS9Xd2vIs+9voP/vZ/LGh7p8JSLFS8ERZ2pWSea+SzswcXQvGtSqzC0vL+HKp+awLH9vrEsTkTih4IhTHTNr8sb3ezL2ig6s3/U5Q/84i/97Yyl7Pj8U69JEpIxTcMSxhARjxDmNmH57Nt86twkTFmyi30O5jJ/7iR4eFJEiU3CUAymVK3DP0HZM+lEf2mTU4K5/LuOSx2cxb52WrRWRyCk4ypFWGdV58abuPPGNs/nsi0OMGDeX0S98wOY9X8S6NBEpQxQc5YyZcXHHeky7LZtbzm/JtFWhua8enrKaLw4diXV5IlIGKDjKqcrJobmvpt+WzYXtMnhseh79fz+Tf36Yr9t3ReS0FBzlXP2alXns6s7843vnklY9mR+/vJhfzz3Iok/2xLo0ESmlohocZjbIzFabWZ6Z3XGSzyua2cvB5/PMrEmBz+4M2leb2YUF2m8xs+VmtszMXjKzStHsQ3lxTpPa5IzuzQPDO7L7oHPFn97nhy99qPEPEfkfUQsOM0sEngAuAtoCV5tZ20Kb3QDscfcWwCPA2GDftsBIoB0wCHjSzBLNrAHwI6Cru7cHEoPtpBgkJBhXdW3I2D6V+VH/FkxZvo3+D83kgcmrOPCVxj9EJCSaZxzdgDx3X+fuh4AJwLBC2wwDngtevwoMMDML2ie4+1fuvh7IC44HkARUNrMkoAqwJYp9KJcqJRm3XtCKGbdnM7h9Bk/mriX7wVxemr9Rz3+ICBatgVAzGw4Mcvcbg/ffBLq7+80FtlkWbLM5eL8W6A7cA8x19+eD9meAt939VTMbA9wPfAlMcfdrTvH1RwGjANLT07tMmDChSP04cOAA1apVK9K+ZVXhPq/97CgvrTpE3mfHyKxmXNUqmQ5piYQyPj6Ut3/n8tZfUJ8j1a9fv0Xu3vVknyV9rapKmJnVInQ20hT4DPiHmV17PGAKcvdxwDiArl27enZ2dpG+Zm5uLkXdt6wq3Ods4DvuTF62jd9NXsXDi76gd4s07hzcmnb1U2JVZrEqb//O5a2/oD4Xp2heqsoHGhZ4nxm0nXSb4NJTCrD7NPueD6x3953ufhh4HegZlerlv5gZF3Wox9Rb+vLLIW1ZtmUvQx6fxe3/WMLWvV/GujwRKUHRDI4FQJaZNTWzZEKD2DmFtskBrgteDweme+jaWQ4wMrjrqimQBcwHNgI9zKxKMBYyAFgZxT5IIclJCXynd1Nm3t6Pm/o0I2fxFrIfzGXs5FVavlaknIhacLj7EeBm4B1CP9xfcfflZnavmQ0NNnsGSDWzPOBW4I5g3+XAK8AKYDIw2t2Puvs8QoPoHwBLg/rHRasPcmopVSrw88FtmHZbXy5qn8GfctfS98EZ/OW9dXx15GisyxORKIrqGIe7TwImFWr7ZYHXB4ErT7Hv/YQGwQu33w3cXbyVSlE1rF2FP4zszI19mjF28irue2slz76/gdsvaMXQTvVJSIifAXQRCdGT41Is2jdIYfwN3fn7d7pRo1IFfvzyYoY8Povc1Ts0hYlInFFwSLE6r2Ud/vXD3jwyohP7Dh7m239bwNVPz+WDjZrCRCReKDik2CUkGJd1zmT6bdn8amg78nYc4PIn3+e74x
eSt2N/rMsTka9JwSFRk5yUwHU9mzDzJ/24dWBLZuft5oJH3uX2fyzRHFgiZZiCQ6KuasUkfjQgi3d/2o/rezUlZ8k
2021-04-30 20:51:04 +01:00
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
2021-04-28 12:10:25 +01:00
"source": [
"batch_size=35\n",
"test_size=0.5\n",
2021-04-30 20:51:04 +01:00
"epochs=50\n",
"lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(0.01,\n",
" decay_steps=1,\n",
" decay_rate=0.98)\n",
"\n",
"plt.plot(range(epochs+1), [lr_schedule(i) for i in range(epochs+1)])\n",
"plt.grid()\n",
"plt.ylim(0)\n",
"plt.xlabel('Epochs')\n",
"plt.ylabel('Learning Rate')\n",
"plt.show()"
2021-04-28 12:10:25 +01:00
]
},
2021-03-19 17:21:00 +00:00
{
"cell_type": "code",
2021-04-30 20:51:04 +01:00
"execution_count": 357,
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp2"
]
},
2021-03-22 20:49:29 +00:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2021-04-06 17:29:15 +01:00
"Iteration 1/30\n",
"Iteration 2/30\n",
"Iteration 3/30\n",
"Iteration 4/30\n",
"Iteration 5/30\n",
"Iteration 6/30\n",
"Iteration 7/30\n",
"Iteration 8/30\n",
"Iteration 9/30\n",
"Iteration 10/30\n",
"Iteration 11/30\n",
"Iteration 12/30\n",
"Iteration 13/30\n",
"Iteration 14/30\n",
"Iteration 15/30\n",
"Iteration 16/30\n",
"Iteration 17/30\n",
"Iteration 18/30\n",
"Iteration 19/30\n",
"Iteration 20/30\n",
"Iteration 21/30\n",
"Iteration 22/30\n",
"Iteration 23/30\n",
"Iteration 24/30\n",
"Iteration 25/30\n",
"Iteration 26/30\n",
"Iteration 27/30\n",
"Iteration 28/30\n",
"Iteration 29/30\n",
"Iteration 30/30\n"
2021-03-22 20:49:29 +00:00
]
}
],
2021-03-19 17:21:00 +00:00
"source": [
"multi_ensem_results = list()\n",
2021-04-06 17:29:15 +01:00
"multi_ensem_iterations = 30\n",
2021-03-19 17:21:00 +00:00
"for i in range(multi_ensem_iterations):\n",
" print(f\"Iteration {i+1}/{multi_ensem_iterations}\")\n",
2021-04-30 20:51:04 +01:00
" data_train, data_test, labels_train, labels_test = train_test_split(data, labels, test_size=test_size, \n",
"# stratify=labels\n",
" )\n",
" multi_ensem_results.append(list(evaluate_ensemble_vote(epochs=(50, 100),\n",
" hidden_nodes=(50, 100),\n",
" nmodels=[1, 3, 5, 7, 9, 15, 25, 35, 45],\n",
" optimizer=lambda: tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.1),\n",
2021-03-19 17:21:00 +00:00
" weight_init=lambda: 'random_uniform',\n",
2021-04-28 12:10:25 +01:00
" batch_size=batch_size,\n",
2021-03-19 17:21:00 +00:00
" dtrain=data_train, \n",
" dtest=data_test, \n",
" ltrain=labels_train, \n",
" ltest=labels_test,\n",
" return_model=False,\n",
" print_params=False)))"
]
},
{
2023-05-27 23:29:17 +01:00
"attachments": {},
2021-03-19 17:21:00 +00:00
"cell_type": "markdown",
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp2"
]
},
2021-03-19 17:21:00 +00:00
"source": [
"### Accuracy Tensor\n",
"\n",
"Create a tensor for holding the accuracy results\n",
"\n",
2021-03-26 20:01:05 +00:00
"(Iterations x Param x Number of models)\n",
"\n",
"#### Params\n",
"0. Test Accuracy\n",
"1. Train Accuracy\n",
"2. Individual Accuracy\n",
"3. Agreement"
2021-03-19 17:21:00 +00:00
]
},
2021-03-30 16:31:10 +01:00
{
"cell_type": "code",
2021-05-04 15:24:37 +01:00
"execution_count": 253,
2021-03-30 16:31:10 +01:00
"metadata": {},
"outputs": [],
"source": [
"def test_tensor_data(test):\n",
" return [test[\"accuracy\"], \n",
" np.mean([i[\"accuracy\"][-1] for i in test[\"history\"]]), # avg train acc\n",
" test[\"individual_accuracy\"], \n",
" test[\"agreement\"]]"
]
},
2021-03-19 17:21:00 +00:00
{
"cell_type": "code",
2021-05-04 15:24:37 +01:00
"execution_count": 354,
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp2"
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2021-04-06 17:29:15 +01:00
"30 Tests\n",
2021-05-04 15:24:37 +01:00
"Models: [1, 3, 5, 7, 9, 15, 25]\n",
2021-03-26 20:01:05 +00:00
"\n",
"Loss: categorical_crossentropy\n",
2021-05-04 15:24:37 +01:00
"LR: 0.786\n",
"Momentum: 0.0\n"
2021-03-26 20:01:05 +00:00
]
}
],
2021-03-19 17:21:00 +00:00
"source": [
2021-03-21 09:56:27 +00:00
"multi_ensem_models = sorted(list({i[\"num_models\"] for i in multi_ensem_results[0]}))\n",
"multi_ensem_iter = len(multi_ensem_results)\n",
"\n",
2021-03-26 20:01:05 +00:00
"accuracy_ensem_tensor = np.zeros((multi_ensem_iter, 4, len(multi_ensem_models)))\n",
2021-03-19 17:21:00 +00:00
"for iter_idx, iteration in enumerate(multi_ensem_results):\n",
" for single_test in iteration:\n",
2021-03-22 20:49:29 +00:00
" \n",
2021-03-26 20:01:05 +00:00
" ensem_models_idx = multi_ensem_models.index(single_test['num_models'])\n",
2021-03-30 16:31:10 +01:00
" accuracy_ensem_tensor[iter_idx, :, ensem_models_idx] = test_tensor_data(single_test)\n",
2021-03-19 17:21:00 +00:00
" \n",
"mean_ensem_accuracy = np.mean(accuracy_ensem_tensor, axis=0)\n",
2021-03-22 20:49:29 +00:00
"std_ensem_accuracy = np.std(accuracy_ensem_tensor, axis=0)\n",
"\n",
"print(f'{multi_ensem_iter} Tests')\n",
"print(f'Models: {multi_ensem_models}')\n",
"print()\n",
2021-03-26 20:01:05 +00:00
"print(f'Loss: {multi_ensem_results[0][0][\"loss\"]}')\n",
"print(f'LR: {multi_ensem_results[0][0][\"optimizer\"][\"learning_rate\"]:.3}')\n",
"print(f'Momentum: {multi_ensem_results[0][0][\"optimizer\"][\"momentum\"]:.3}')"
2021-03-19 17:21:00 +00:00
]
},
{
2023-05-27 23:29:17 +01:00
"attachments": {},
2021-03-19 17:21:00 +00:00
"cell_type": "markdown",
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp2"
]
},
2021-03-19 17:21:00 +00:00
"source": [
"#### Export/Import Test Sets\n",
"\n",
"Export mean and standard deviations for retrieval and visualisation "
]
},
{
2021-04-25 15:25:41 +01:00
"cell_type": "code",
2021-04-30 20:51:04 +01:00
"execution_count": 358,
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp2"
]
},
2021-04-25 15:25:41 +01:00
"outputs": [],
2021-03-21 09:56:27 +00:00
"source": [
2021-04-30 20:51:04 +01:00
"exp2_testname = 'exp2-test17'\n",
"# use a context manager so the file is closed (and fully flushed) after pickling\n",
"with open(f\"results/{exp2_testname}.p\", \"wb\") as f:\n",
"    pickle.dump(multi_ensem_results, f)"
2021-03-21 09:56:27 +00:00
]
},
{
2021-03-29 18:34:04 +01:00
"cell_type": "code",
2021-05-04 15:24:37 +01:00
"execution_count": 353,
2021-03-21 09:56:27 +00:00
"metadata": {},
2021-03-29 18:34:04 +01:00
"outputs": [],
2021-03-21 09:56:27 +00:00
"source": [
2021-05-04 15:24:37 +01:00
"exp2_testname = 'exp2-test19'\n",
2021-04-06 17:29:15 +01:00
"# only load locally produced result files: unpickling can execute arbitrary code\n",
"with open(f\"results/{exp2_testname}.p\", \"rb\") as f:\n",
"    multi_ensem_results = pickle.load(f)"
2021-03-21 09:56:27 +00:00
]
},
{
"cell_type": "raw",
"metadata": {},
2021-03-19 17:21:00 +00:00
"source": [
"np.savetxt(\"exp2-mean.csv\", mean_ensem_accuracy, delimiter=',')\n",
"np.savetxt(\"exp2-std.csv\", std_ensem_accuracy, delimiter=',')"
]
},
{
"cell_type": "raw",
"metadata": {},
"source": [
"mean_ensem_accuracy = np.loadtxt(\"results/test1-exp2-mean.csv\", delimiter=',')\n",
2021-03-26 20:01:05 +00:00
"std_ensem_accuracy = np.loadtxt(\"results/test1-exp2-std.csv\", delimiter=',')"
2021-03-19 17:21:00 +00:00
]
},
2021-03-22 20:49:29 +00:00
{
2023-05-27 23:29:17 +01:00
"attachments": {},
2021-03-22 20:49:29 +00:00
"cell_type": "markdown",
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp2"
]
},
2021-03-22 20:49:29 +00:00
"source": [
"### Best Results"
]
},
2021-03-19 17:21:00 +00:00
{
"cell_type": "code",
2021-05-04 15:24:37 +01:00
"execution_count": 355,
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp2"
]
},
2021-03-22 20:49:29 +00:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2021-05-04 15:24:37 +01:00
"Models: 9, 96.5% Accurate\n"
2021-03-22 20:49:29 +00:00
]
}
],
"source": [
"# argmax over row 0 (ensemble test accuracy) is already a column index into the models axis\n",
"best_ensem_models_idx = int(np.argmax(mean_ensem_accuracy[0, :]))\n",
"best_ensem_accuracy_idx = (0, best_ensem_models_idx)\n",
"best_ensem_accuracy = mean_ensem_accuracy[best_ensem_accuracy_idx]\n",
"best_ensem_accuracy_models = multi_ensem_models[best_ensem_models_idx]\n",
"\n",
2021-03-27 16:29:31 +00:00
"print(f'Models: {best_ensem_accuracy_models}, {best_ensem_accuracy * 100:.3}% Accurate')"
2021-03-22 20:49:29 +00:00
]
},
{
2023-05-27 23:29:17 +01:00
"attachments": {},
2021-03-22 20:49:29 +00:00
"cell_type": "markdown",
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp2"
]
},
2021-03-22 20:49:29 +00:00
"source": [
2021-03-27 16:29:31 +00:00
"### Test/Train Error Over Model Numbers"
2021-03-22 20:49:29 +00:00
]
},
{
"cell_type": "code",
2021-05-04 15:24:37 +01:00
"execution_count": 356,
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp2"
]
},
2021-03-19 17:21:00 +00:00
"outputs": [
{
"data": {
2023-05-27 23:29:17 +01:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA9QAAAMMCAYAAACyue/GAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAB7CAAAewgFu0HU+AAEAAElEQVR4nOzdd3gU1f7H8fdJCCT03pGiCCpYaCrSBDuKYhcbig0riB0VsF2vFPWKYvmBoCgXRFERVC5KE1C6UkSaNOm9hZLk/P6Y3c1ssj3ZJJDP63n22dmZM2fO7M7O7nfOmXOMtRYRERERERERiU5CfhdARERERERE5HikgFpEREREREQkBgqoRURERERERGKggFpEREREREQkBgqoRURERERERGKggFpEREREREQkBgqoRURERERERGKggFpEREREREQkBgqoRURERERERGKggFpEREREREQkBgqoRURERERERGKggFpEREREREQkBgqoRURERERERGKggFpEREREREQkBgqoRURERERERGKggFpEREREREQkBgqoRURERERERGKggFpEREREREQkBgqoRURERERERGKggFpEpJAwxhQ3xjxujJlujNlhjEkzxljPo29+l08iY4xZ6/rc6uR3eQSMMe1cn8nU/C6P5D9jzFTXMdEul/Ls6spzeG7kKTlnjBnu+ly65uF2vdu0ebVNCUwBtUgQWX4Mo30Mz+/yH2+y/CAFehw0xmwyxvxsjHnFGHNqfpf5eGKMKQ/MAQYCrYEKQGK+FioXZTlW2kW5rvu73jcuBRSJQpYAPZZHnfzeBymcslzw8z7OijKPLwLk0TVORRbJMQXUInK8KA5UAy4EegPLjTHvGmOS87ogx2kN4b+BMzzTacCPwIfAu57HnHwql0iO5FftkIhE7I5IExpjygFXxbEsIrmuSH4XQOQ4MZfoAo5f41WQQmI58FOWeSWB04DmgPE8HgSqGWOus9aqyVMQxpgiwC2uWR2stdPzqzwiErV3o0y/Ly6lEIlNF2PMU9ba9AjS3gQUi3eBRHKTAmqRyEy01vbN70IUIr9Zax8OtMAYczowCjjTM6szcC3wZR6V7Xh0KlDCM71KwfTxzVpbJ7/LIHkr2PlQpIBbBpwOVAUuAb6PYB1vbfZRYD1wSnyKJpJ71ORbRI4r1tplwOXAQdfs+/OpOMeLcq7pzflWChERKUw+dU2HbfZtjKkPnO95ORHYGY9CieQ2BdQictyx1m4CxrhmtTLGmPwqz3EgyTWdkW+lEBGRwmQR8Idn+mpjTOkw6d1B94i4lEgkDhRQi+ShQJ1ZGWNONsa8aoxZaIzZbozJMMYscq2TbegNY0w1Y8xzxpg5xpgtxph0Y8yeINusbYx5yRjzqzFmqzHmqOf5V2NMP2NMrQjKHXBIGGPMFcaYUcaYlcaYA57lPXLwFkVjkWs6Bf9aWD/GmARjTGvP+zDJGLPeGHPIGHPEGLPZ03N4b2NMxRB51PG+B0Bt16K/g/Sy2y5EXhWMMb2MMf8zxmwwxhw2xuwxxizzdLTWLOJ3IbLyTnEtahugrFND5HOpMWaYMWaFMWafMSbVGLPOGDPOOEO4JAVb15VHtk6jjDFljTGPGWcIr39M5hBeZXO463FnHDd4jv3VnmP/gGf6c2PM9ZFc4In1ux3oPJJleaw9Q7cLUdYkY8xdxpivPZ9/qud4+MsYM9QYc3GE712gc2BNY8zLxpjfPd+Dg8aY5caYd4wxtcPlBdzpmv1xkH3rG2D9MsaYW4wxHxhjfjPOUHJHPfu12vP53miMOSH+K5kgQy4ZYzobY8Yb57x4xBizzTjnydsiOY49eTQ3xgw2xiwwxuz2fJ9TjXN+/dUYM8TzXpaIIK8SxpjunjKtM865er9xfmeGGWPax7Kvxvkd6GKM+d44590jxvkt/NIYc36APIoaY243xvxkMs/T640xI4wxp0XyvgTIs5ZxRqn43Rizy3Wsv2mMiU
vTZmPMacaY14xzTvH+B9juOeZfMsZUj8d2XT7xPKcAN4QopwFu87zcCUyIdkO5dZ7Kkuc1xphvjPM7dcQYs9E4v923G6d/kqgZR2fPsbTCGLPXc3xt8JT9zljzDrK9isaYJ4wxk40zWsphY8wx45xvlxpjxhpnSM26ubXNQsdaq4ceegR4AFMB63n0zaU817ryrAPcB6S65nkfi4KUox1wNbArwDp7Amyvd5D83Y9U4Okw5W7nSj8VKAN8FSS/HjG+N8NdeQyPIP29WbZbPUi6JGBjmPfA+zgA3BYknzoR5uH7rILk8xCwJ8y6GcBQoGgOjrVoyjs1wPqVgckRrLsCaBbFZ9sVuADn3rhA+ZWNcX/Dvve58V0H6gMLInhf5gH1othuOyL8bpPlPBLmvciNY/ZcYFUE608CKobZZ7+yA9cQ+vtwCOgYQV7hHn2zrHstcDjCdRcBdcPsVztX+mzfpxiOZ3d+Nqf5efLs6spzOM55/Jsw+/49kBIizyLAB1F8Dq+EKeMNOLelhMtnPFAmin2tiNPRZbD8MoC7XOufgnP/b7D0R4Brovx+dyL8sX5fNJ9hmLTFgPdxRnUI9V4eAh7OjWMswPfyMpz7p71lmBZivbau9QZ75v3qmtc1zHZz7Tzlya8kTlAfKq8Znv0bHkU5zwQWRlDO5cDpYfIKe44g+O9KoMfG3DoOCttDnZKJ5J8bgDc805uAmcBeoDpQPsg6LYG+OEHiTmA6sAMn+DnHndAYMxgnePM6gFNTuQXnB+BCnB+MZOB1Y0xVa23PCMptgJHAlTgn4Hk4fzwM0MgzLy+4r6qnE/xeq0Sghmf6ALAUWIPTC24SUBM4DyiN03HXp8aYY9ba0Vny2UdmT7t3AKU8058A+wNs95+sM4wxbwGPuWbtAGbjfCbJOJ9hI5z38m6gujGmo7U2lmba7vLWwAlcwDnWxmVJuzJLOavgHI8nu2avBn7D+SN5Os6fF3ACzCnGmMustTMjKNcpwFs4f+b34xzDm3BaGLSJYP1846mVmgZUcs1ejBNwWZzPr7FnflNgljGmjbV2RQTZR/zdjkCkPULXATq6Xmf77hpj2uAEVcVdaebgfOeL4nx3vMfJxcBMY0wra+32CLZ/Ec4f/kScCyyzcY7bujgBSBGcWq0xxphG1tq/s6w/Amc89Q5AQ8+8n3D+iGaVdZSGymT2JLzRsz9bcIIL74gCTXC+i2cB040xZ1trT5R7OovgdOTYAafzp1k43/FknHHqT/KkuwwYBHQPkk9/nAvDXv/gvNfbcVpBVsA5XzQIVyBjTE9gIM57Ds6xMBvn80nEGfavmWf5lcBUY8wF1tpDEezrV579OozzHV6P8zvbASjryfP/jDErcS4S/gzU8pRhOk6QXwXnmC2Oc+x/bow5I8BxGUgz4FXPejtxgu3dON/Btjjf+xTgA2NMurV2aAR5BuVpCfAjzsVLr9XAfM92y3uWVfds9x1jTGlr7Ws52W4g1totxpj/4RxLrY0xday1awMkvdM1/UmA5UHl9nnKOC2vJuD/m7QF51jYj/M71srzGIfznyLSco7H+b8BcAxnFJmVnuk6njyTcb4zs4wx51tr/4wk/wDbawaMJbMT6lScixRrcX7LS+O8L43JfO8kFvkd0euhR0F9EP8a6mM4J7R7AZMlXbEg5TiGcyX9eSApxDo34n/V8WOgdJb0pXE6DHGnuzZIudtlKYPFuS+qcYC0xSJ9P7KsN9y1jeERpJ/lSj83RLqiwDDPPiQFSVMMeNK1b7uBkhF+jnUi3L+7XevsBe4JVB6cCx3uGvWncuG4c39+UyNIP9GV/gBwc4A0zXD+oHnTrSdI7XKWz9b7Hg/O+h7j/KlMiHEf3cdxuyjXdX/H+oY4jha50m0FLgqQ7hKcYMKbbn6I48693Yi+27EefwG2XQb/GrhpAbZbLsuxuAJoGiCvW3ECUW+6byP87hz2HF+3kf0ceEaWbQ8Lkaf7+Ooa4f5fBTwDnB
IiTV3gB1fe/5db37EIyufOz+Y0P0+eXbO89xbnu14jS7oiOIGyN21GoOMMJ1j2fp/TcAIiE2Tb1YBHgG5BlnfAuTB
2021-03-19 17:21:00 +00:00
"text/plain": [
2021-05-04 15:24:37 +01:00
"<Figure size 1000x800 with 1 Axes>"
2021-03-19 17:21:00 +00:00
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
2021-05-04 15:24:37 +01:00
"fig = plt.figure(figsize=(5, 4))\n",
2021-03-19 17:21:00 +00:00
"fig.set_dpi(fig_dpi)\n",
"\n",
2021-05-04 15:24:37 +01:00
"plt.plot(multi_ensem_models, 1 - mean_ensem_accuracy[0, :], 'x-', label='Ensemble Test')\n",
"plt.plot(multi_ensem_models, 1 - mean_ensem_accuracy[2, :], 'x-', label='Individual Test')\n",
"plt.plot(multi_ensem_models, 1 - mean_ensem_accuracy[1, :], 'x-', label='Individual Train')\n",
"plt.plot(multi_ensem_models, 1 - mean_ensem_accuracy[3, :], 'x-', label='Disagreement')\n",
2021-03-27 16:29:31 +00:00
"\n",
2021-05-04 15:24:37 +01:00
"# plt.errorbar(multi_ensem_models, 1 - mean_ensem_accuracy[0, :], yerr=std_ensem_accuracy[0, :], capsize=4, label='Ensemble Test')\n",
"# plt.errorbar(multi_ensem_models, 1 - mean_ensem_accuracy[2, :], yerr=std_ensem_accuracy[2, :], capsize=4, label='Individual Test')\n",
"# plt.errorbar(multi_ensem_models, 1 - mean_ensem_accuracy[1, :], yerr=std_ensem_accuracy[1, :], capsize=4, label='Individual Train')\n",
"# plt.errorbar(multi_ensem_models, 1 - mean_ensem_accuracy[3, :], yerr=std_ensem_accuracy[3, :], capsize=4, label='Disagreement')\n",
2021-03-19 17:21:00 +00:00
"\n",
2021-03-26 20:01:05 +00:00
"plt.title(f\"Error Rate for Horizontal Ensemble Models\")\n",
2021-05-04 15:24:37 +01:00
"plt.ylim(0, 0.1)\n",
2021-04-30 20:51:04 +01:00
"# plt.ylim(0, np.max(1 - mean_ensem_accuracy + std_ensem_accuracy) + 0.05)\n",
2021-03-19 17:21:00 +00:00
"plt.grid()\n",
2021-03-22 20:49:29 +00:00
"plt.legend()\n",
2021-03-19 17:21:00 +00:00
"plt.xlabel(\"Number of Models\")\n",
2021-03-22 20:49:29 +00:00
"plt.ylabel(\"Error Rate\")\n",
2021-04-30 20:51:04 +01:00
"\n",
"plt.tight_layout()\n",
2021-05-04 15:24:37 +01:00
"# plt.savefig(f'graphs/{exp2_testname}-error-rate-curves.png')\n",
2021-04-30 20:51:04 +01:00
"\n",
2021-03-19 17:21:00 +00:00
"plt.show()"
]
},
2021-05-04 15:24:37 +01:00
{
"cell_type": "code",
"execution_count": 305,
"metadata": {},
"outputs": [
{
"data": {
2023-05-27 23:29:17 +01:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA9QAAAMMCAYAAACyue/GAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAB7CAAAewgFu0HU+AAEAAElEQVR4nOzdd3wUdf7H8dc3AZLQexOkSbEgUsUKCnogggUsYMN+6Nnbqb9TPO/UE1Q8CyqHAooIqAiIIiIgSpGqIogUpUnvBBJCku/vj9ndTJLtyW4S8n4+HvvI7Mx3vt/vTGZn9zPf73zHWGsRERERERERkcgkFHUFREREREREREoiBdQiIiIiIiIiUVBALSIiIiIiIhIFBdQiIiIiIiIiUVBALSIiIiIiIhIFBdQiIiIiIiIiUVBALSIiIiIiIhIFBdQiIiIiIiIiUVBALSIiIiIiIhIFBdQiIiIiIiIiUVBALSIiIiIiIhIFBdQiIiIiIiIiUVBALSIiIiIiIhIFBdQiIiIiIiIiUVBALSIiIiIiIhIFBdQiIiIiIiIiUVBALSIiIiIiIhIFBdQiIiIiIiIiUVBALSIiIiIiIhIFBdQiIiJ5GGMGGmOs5zWqqOsTjDFmjquuXYu6PiISX8aYUa5zwMBCyrOrK885hZGnFL5Y/O/DLNdbpo1XmcWZAmopVowxjd0f0kJ6DS7q7RIRERERkeOPAmqRONHVXomF0niV2BgzWBfMREREpDgoU9QVEMnjIPBGiDSdgI6e6a3ApBDpFxW0UiIiIiIiInkpoJZixVq7F/hbsDSeFilvQL3WWhs0vYiIiIiISCyoy7eIiIiIiIhIFBRQi4iIiIiIiERBAbUc94wxJxtjnjPGLDLG7DDGZBhjdhljfjDG/NMYUz/MfCoaY/5qjJlmjNlkjDlijDlmjDlgjFltjJlqjHnCGHNanvUGewaMmu2a3SXAiOQbCmmbyxpjbjDGTDDG/G6MOWSMOWyM+cMYM84Yc4UxxoTIw+8gasaYSzx5rDXGpHqW3+9Z1tjfthhjzjXG/M+znw54lg8LUO5fjDHvGmPWGGMOGmPSjDEbjTGTPI8yKhvG9ud7jIQxpqox5j5jzFxjzJ/GmEzP8qqh92jB940rTVnPNr5ojJltjNlqjEn3bOcWY8yXxpj7jTEVwyk/z/xAI903DpJXQ2PMP4wx33nqctQYs9cYs9wYM9QY0yLS/ROKp8ynPf+LHZ4yM4wxe4wxPxljPjTGDDLG1M2z3hzPNj/tmv10gG0eFaT8dsaYEZ7PRppxzgeLjDGPGmOqF/b2RssYk2CMuckY87UxZrvnONlgjJlsjLm8APlGfX4wxtQ2znnPGmOyjDEnRFDub67/z1XR1j9PnqcaY4Z4jtfdnmNpq+dYecwYUyPIuu1c9TlgjEkOs8xkk3Mes8aYjkHSGs/+HG2cc9oBz/9xszHmM8//N+jtd6YQzquRKMjx4coj0DnyQmPMR5580z2f+bnGmL+ZMM7tnjxaGef8udDzP8/w5LXTGLPUGPOeZ79WK+JtvdQY86nnM+vd1i+NMZf4ySPBGHOZMeZzT9npxphtxpiJxpjO4ewXP3nW8HwGFhnnHJdmjFlvjHnHGNM2mjzDKDOu3yeu/W5d884wxgw3zvkm1fP6wRhzl7/PmjGmg3F+M/zq+d/vMc5383UR1qWsMeZmz+d6o2d/H/TUY6Qx5qIotu9y45zv//Tsyy3G+T64wd+2hJmnMQU8J0VYXk1jzMPGmJkm5/fOMWPMfmPMSmPMx8aYB40xTQqrzCJhrdVLrxL1AgYD1vOaEyRdEvAWkOlK7+91BPhbiDLPAraEyMf9KhOgvqFeGwph/3QF1oVR1gLghBD5+PYzUAX4NEBe93vWaezeFqCc53/gb51hecqrDcwMo95rgA4h9sEoV/qBwDnApgD5VY1yH0e0bzzrNQR2h3ks7AYuCqP8cF6N/eSRAPwTSAux7jHg34AppM
/vHTifuXDq/X2ededEsM2jApT/L4KfEzYDnT3HTdC8YnyeqwssDLGNnwKV8uyXrrE+PwBfuNI9Eub2dHKtsx9ILuD+KQP8N8T/0gL7gJuC5LPKlfaqMMu+2rXO6iDpTgeWh7GvVwOnBMmnsSvtBiI4r0axXwt8fLjy8aad46nzOyHyXArUDFG/wWH8z72vD4poW8sD40LkOdi1fi1gXpC02YT+jTLKlX4gzm+WP4PkmemuQzjbFSJtUX2f+PL2vH80xPExHUjypE0E3gxR33FAYhj1ODPMY2kGIY5xT34VgWkh8voO53si1/8+RL6Fck7yt+8DpLkM2BtGeRbYUhjHRFG9NCiZHJeMMRWAr3ACKa/1OF/Y+4DqnmX1gRTgNWNMZWvtc37yaujJq5Jn1jFgMc7J8whQAecHTxugsp/qLMIZufwE4HLPvECjk+8JcxP9Mk6rz1jAe6U/DedH+QacL+UWOF+0ZXCChgXGmI7W2h2hsgY+AC7FOfEtwfkhaoDTPPP8eQW40zO9AvgJZ/+18NTHW+86OD8omrnWXQ/8ABwFTsH5wgJoDsw2xvSw1s4LUW+Ak4BhOEHvIWAuzv6vBpwfxvqhRLJvKgDeFrN9wEpgI5CK84OzCc7/JdmT7gtjTBdr7fw8Zf5Jzmj4d7vmBxoh/2CuChuTCIwH+ubJcxGwC+fL/Eyc/0cZ4AmcH313BMg/LMZpVX07T70W4FysysT5H7XA2W/l/GQxCfiF3CP9L8b/SP4L/ZT/HPC4a9YRYBawDeeHyYVAA5yAcVhYGxUDxuk1MQs42TX7D5x9dRQ4FWcfXIHrcxRGvoV1fvgA6OmZvg4YEkbx7taej6216eHWOy9jTALwCdDHNXsvTjCzF+fC1QU4x1BVYJQxpqq19lU/2X2A8wPfW8eJYVTBvS0fBKjj+cBUcr4TvN8baz3TjYFzcT7rLYH5xpizrLW/hlF+WOfVSMXw+wOcYPomTz4/4PxgT/Dk09KTph0wBsjXguup333k7p2y21O/bTjn2epAK5zPTWIRbutI4Fqcc9o8nN8K5XHOL3U8aZ42xvwGfIYTZJ0BpON8P23COW674XxPGeC/xpil1toFYZTfCHjZs24qzrlkB87vnQs8dUn01CHBWvtUGHkGVFTfJ37qcSfwH8/bn4EfgSxP2ad45v8F50LcnTjB9B04/+/FwK84x+R5ON/F4PwffwJeCFLu+cCXOPsVnGNxEc7vgHI4x4/3t81FwDxjzLnW2l0B8iuLE0y7f59sxzk2DuH8pjnX85oE/B6obn7qGatzkr/yOgAfkzMAtvszdtRTj2ZAa3L2XclV1BG9XnpF+iKMFmpgtCvNb/hptcH5QhmE8yXmvWJ7lp90r7jymgvUD1BmGaALzg+sfFc0ieBqb5T75VRyWv6ycX7kVvWTrinOlU1vXb4IkJ+7vsc8f38GWvtJ673i29i1jvcq8SbgvEDreKbdLV6pwLV+0nfACbKtK9982+dJO8pP3V8HKuZJVxZIiGJfR7xvPNONcL7MOwUqF+dLZmie4zdgHV3pbAT1/6drvW3AlfhpMQCuwmlN9Ka9uoDHqPvK+GtA+QDpKnrKfiHA8sGufAaHWfb5ns+Fd72JQLU8aaqQ07J01JV2VEG2O4r9NNJV9lHgFj9pOuH8MMlb164B8iy08wPOj59DrjSnhtieRJwf9EHrGMH+edR93APPA+XypKmLcyHU/Tk9009ejVzHxVGgeoiyqwMZrv3YxE+aunm2dzRQz0+6OuTu2fIz/r87GrvShH1ejXCfxvL7w/sduwholSedAe7L8/88309+ZXCCM2+avwNlg/yPbgYeLcJtnQc0zZMuBZjgSrsG5/vAeo6D2nnSVwO+daWfFeT/N8qVzns++ACo7CfPT1xps4Czw9iuOUHKLpLvE0+e7uMm3VN+vvML8JAr3THgAc/0KqBNnrSJ5P7ddwioEKD8auTuvbgGaO8n3XXk7pk1Jcg2/cOVLhvnAk
RinjQtcC4Y5D3/DwyQZ6Gek/Lu+wDLJ7nSfEye71tXumSci2hvFfR4KMpXkVdAL70ifREioMa5uuhdvo7QXcgGutJ
2021-05-04 15:24:37 +01:00
"text/plain": [
"<Figure size 1000x800 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig = plt.figure(figsize=(5, 4))\n",
"fig.set_dpi(fig_dpi)\n",
"\n",
"# std of each metric across iterations; std(agreement) == std(1 - agreement), so row 3 is\n",
"# labelled 'Disagreement' to match the error-rate figure above\n",
"plt.plot(multi_ensem_models, std_ensem_accuracy[0, :], 'x-', label='Ensemble Test', lw=2)\n",
"plt.plot(multi_ensem_models, std_ensem_accuracy[1, :], 'x-', label='Individual Train', lw=2)\n",
"plt.plot(multi_ensem_models, std_ensem_accuracy[2, :], 'x-', label='Individual Test', lw=2)\n",
"plt.plot(multi_ensem_models, std_ensem_accuracy[3, :], 'x-', label='Disagreement', lw=2)\n",
"\n",
"plt.legend()\n",
"plt.grid()\n",
"# train and disagreement curves are shown too, so the title is not limited to test error\n",
"plt.title(\"Error rate std. dev over ensemble models\")\n",
"plt.xlabel(\"Number of Models\")\n",
"plt.ylabel(\"Standard Deviation\")\n",
"plt.ylim(0, 0.08)\n",
"\n",
"plt.tight_layout()\n",
"# plt.savefig(f'graphs/{exp2_testname}-error-rate-std.png')\n",
"plt.show()"
]
},
2021-03-26 20:01:05 +00:00
{
2023-05-27 23:29:17 +01:00
"attachments": {},
2021-03-26 20:01:05 +00:00
"cell_type": "markdown",
"metadata": {
"id": "FSZq1mNiVZq_",
"tags": [
"ex3"
]
2021-03-19 17:21:00 +00:00
},
"source": [
"# Experiment 3\n",
"\n",
2021-03-29 19:17:14 +01:00
"Repeat Exp 2 for the cancer dataset with two different optimisers of your choice, e.g. 'trainlm' and 'trainrp'. Comment on and discuss the results, and decide which is the more appropriate training algorithm for the problem. In your discussion, include a detailed account of how the training algorithms (optimisers) work."
2021-03-19 17:21:00 +00:00
]
2021-03-29 18:34:04 +01:00
},
{
"cell_type": "code",
2021-04-30 20:51:04 +01:00
"execution_count": 127,
2021-03-29 18:34:04 +01:00
"metadata": {},
"outputs": [],
"source": [
"def evaluate_optimisers(optimizers=[(lambda: 'sgd', 'sgd'), \n",
" (lambda: 'adam', 'adam'), \n",
" (lambda: 'rmsprop', 'rmsprop')],\n",
" weight_init=lambda: 'glorot_uniform',\n",
" print_params=True,\n",
" **kwargs\n",
" ):\n",
" for o in optimizers:\n",
" \n",
" if print_params:\n",
" print(f'Optimiser: {o[1]}')\n",
" \n",
" yield list(evaluate_ensemble_vote(optimizer=o[0],\n",
" weight_init=weight_init,\n",
" exp=f'3-{o[1]}',\n",
" print_params=print_params,\n",
" **kwargs\n",
" ))"
]
},
{
2023-05-27 23:29:17 +01:00
"attachments": {},
2021-03-29 18:34:04 +01:00
"cell_type": "markdown",
"metadata": {},
"source": [
"## Single Iteration"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Optimiser: sgd\n",
"Models: 1\n",
"Models: 3\n",
"Models: 5\n",
"Optimiser: adam\n",
"Models: 1\n",
"Models: 3\n",
"Models: 5\n",
"Optimiser: rmsprop\n",
"Models: 1\n",
"Models: 3\n",
"Models: 5\n"
]
}
],
"source": [
"single_optim_results = list()\n",
"for test in evaluate_optimisers(epochs=(5, 300), nmodels=[1, 3, 5]):\n",
" single_optim_results.append(test)"
]
},
{
2023-05-27 23:29:17 +01:00
"attachments": {},
2021-03-29 18:34:04 +01:00
"cell_type": "markdown",
"metadata": {},
"source": [
"## Multiple Iterations\n",
"\n",
"### Pickle Results\n",
"\n",
2021-04-29 22:53:26 +01:00
"| test | optim1 | optim2 | optim3 | lr | momentum | epsilon | batch size | hidden nodes | epochs | models | stratified |\n",
"| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |\n",
"| 1 | SGD | Adam | RMSprop | 0.1 | 0.0 | 1e-7 | 35 | 16 | 1 - 100 | 1, 3, 9, 15, 25 | y |\n",
"| 2 | SGD | Adam | RMSprop | 0.05 | 0.01 | 1e-7 | 35 | 16 | 1 - 100 | 1, 3, 9, 15, 25 | y |\n",
"| 3 | SGD | Adam | RMSprop | 0.1 | 0.01 | 1e-7 | 35 | 1 - 400 | 20 | 1, 3, 9, 15, 25, 35, 45 | y |\n",
"| 4 | SGD | Adam | RMSprop | 0.075 | 0.01 | 1e-7 | 35 | 1 - 400 | 20 | 1, 3, 9, 15, 25, 35, 45 | y |\n",
2021-04-30 20:51:04 +01:00
"| 5 | SGD | Adam | RMSprop | 0.05 | 0.01 | 1e7 | 35 | 1 - 400 | 20 | 1, 3, 9, 15, 25, 35, 45 | n |\n",
"| 6 | SGD | Adam | RMSprop | 0.02 | 0.01 | 1e7 | 35 | m | 50 | 1, 3, 9, 15, 25, 35, 45 | n |\n",
"| 7 | SGD | Adam | RMSprop | 0.1 | 0.9 | 1e-8 | 35 | 1 - 400 | 50 - 100 | 1, 3, 5, 7, 9, 15, 25 | n |\n",
2021-05-04 15:24:37 +01:00
"| 8 | SGD | Adam | RMSprop | 0.05 | 0.9 | 1e-8 | 35 | 1 - 400 | 50 - 100 | 1, 3, 5, 7, 9, 15, 25 | n |\n",
"| 9 (r) | SGD | Adam | RMSprop | 0.01 - 1 | 0.0 | 1e-7 | 35 | 1 - 100 | 10 - 70 | 1, 5, 9, 15, 25 | n |\n",
"| 10 (r) | SGD | Adam | RMSprop | 0.01 - 1 | 0.0 | 1e-7 | 35 | 1 - 100 | 1 - 70 | 1, 5, 9, 15, 25 | n |"
2021-03-29 18:34:04 +01:00
]
},
{
"cell_type": "code",
2021-04-29 22:53:26 +01:00
"execution_count": 27,
2021-03-29 18:34:04 +01:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2021-03-30 16:31:10 +01:00
"Iteration 1/30\n",
"Iteration 2/30\n",
"Iteration 3/30\n",
"Iteration 4/30\n",
"Iteration 5/30\n",
"Iteration 6/30\n",
"Iteration 7/30\n",
"Iteration 8/30\n",
"Iteration 9/30\n",
"Iteration 10/30\n",
"Iteration 11/30\n",
"Iteration 12/30\n",
"Iteration 13/30\n",
"Iteration 14/30\n",
"Iteration 15/30\n",
"Iteration 16/30\n",
"Iteration 17/30\n",
"Iteration 18/30\n",
"Iteration 19/30\n",
"Iteration 20/30\n",
"Iteration 21/30\n",
"Iteration 22/30\n",
"Iteration 23/30\n",
"Iteration 24/30\n",
"Iteration 25/30\n",
"Iteration 26/30\n",
"Iteration 27/30\n",
"Iteration 28/30\n",
"Iteration 29/30\n",
"Iteration 30/30\n"
2021-03-29 18:34:04 +01:00
]
}
],
"source": [
"# Repeat the optimiser comparison multi_optim_iterations times; results are\n",
"# collected as multi_optim_results[iteration][optimiser]\n",
"multi_optim_results = list()\n",
2021-03-30 16:31:10 +01:00
"multi_optim_iterations = 30\n",
2021-03-29 18:34:04 +01:00
"\n",
"# Shared hyper-parameters: SGD uses lr + momentum, Adam uses lr + epsilon,\n",
"# RMSprop uses all three\n",
2021-04-29 22:53:26 +01:00
"multi_optim_lr = 0.05\n",
2021-04-09 12:42:18 +01:00
"multi_optim_mom = 0.01\n",
2021-03-29 18:34:04 +01:00
"multi_optim_eps = 1e-07\n",
"multi_optims = [(lambda: tf_optim.SGD(learning_rate=multi_optim_lr, \n",
" momentum=multi_optim_mom), 'sgd'), \n",
" (lambda: tf_optim.Adam(learning_rate=multi_optim_lr, \n",
" epsilon=multi_optim_eps), 'adam'), \n",
" (lambda: tf_optim.RMSprop(learning_rate=multi_optim_lr, \n",
" momentum=multi_optim_mom, \n",
" epsilon=multi_optim_eps), 'rmsprop')]\n",
"\n",
"for i in range(multi_optim_iterations):\n",
" print(f\"Iteration {i+1}/{multi_optim_iterations}\")\n",
"    # fresh random 50/50 split every iteration; stratify is commented out, so splits are not stratified\n",
2021-04-29 22:53:26 +01:00
" data_train, data_test, labels_train, labels_test = train_test_split(data, labels, test_size=0.5, \n",
"# stratify=labels\n",
" )\n",
"    # append one list per iteration, holding one entry per optimiser in multi_optims\n",
2021-04-30 20:51:04 +01:00
" multi_optim_results.append(list(evaluate_optimisers(epochs=(50, 100),\n",
2021-04-28 12:10:25 +01:00
" hidden_nodes=(1, 400),\n",
2021-04-30 20:51:04 +01:00
" nmodels=[1, 3, 9, 15, 25],\n",
2021-03-29 18:34:04 +01:00
" optimizers=multi_optims,\n",
" weight_init=lambda: 'random_uniform',\n",
" batch_size=35,\n",
" dtrain=data_train, \n",
" dtest=data_test, \n",
" ltrain=labels_train, \n",
" ltest=labels_test,\n",
" return_model=False,\n",
" print_params=False)))"
]
},
{
2023-05-27 23:29:17 +01:00
"attachments": {},
2021-03-29 18:34:04 +01:00
"cell_type": "markdown",
"metadata": {},
"source": [
"### Accuracy Tensor\n",
"\n",
"Create a tensor for holding the accuracy results\n",
"\n",
"(Iterations x Param x Number of models)\n",
"\n",
"#### Params\n",
"0. Test Accuracy\n",
"1. Train Accuracy\n",
"2. Individual Accuracy\n",
"3. Agreement"
]
},
{
"cell_type": "code",
2021-05-04 15:24:37 +01:00
"execution_count": 339,
2021-03-29 18:34:04 +01:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2021-03-30 16:31:10 +01:00
"30 Tests\n",
2021-03-29 18:34:04 +01:00
"Optimisers: ['SGD', 'Adam', 'RMSprop']\n",
2021-05-04 15:24:37 +01:00
"Models: [1, 5, 9, 15, 25]\n",
2021-03-29 18:34:04 +01:00
"\n",
"Loss: categorical_crossentropy\n"
]
}
],
"source": [
"multi_optim_results_dict = dict() # indexed by optimiser name\n",
"multi_optim_iter = len(multi_optim_results) # number of iterations\n",
"\n",
"#####################################\n",
"## INDIVIDUAL RESULTS TO DICTIONARY\n",
"#####################################\n",
"# multi_optim_results[iteration][optimiser] holds the single tests that\n",
"# evaluate_optimisers yielded for that optimiser; regroup them by optimiser name\n",
"for iter_idx, iteration in enumerate(multi_optim_results):\n",
"    for optim_tests in iteration:\n",
"        for single_optim_test in optim_tests:\n",
"            single_optim_name = single_optim_test[\"optimizer\"][\"name\"]\n",
"            if single_optim_name not in multi_optim_results_dict:\n",
"                # one (initially empty) list of tests per iteration\n",
"                multi_optim_results_dict[single_optim_name] = [[] for _ in range(multi_optim_iter)]\n",
"\n",
"            multi_optim_results_dict[single_optim_name][iter_idx].append(single_optim_test)\n",
"\n",
"# sorted ensemble sizes used in the tests (read from the first iteration's first optimiser)\n",
"multi_optim_models = sorted({test[\"num_models\"] for test in multi_optim_results[0][0]})\n",
"\n",
"##################################\n",
"## DICTIONARY TO RESULTS TENSORS\n",
"##################################\n",
"# per optimiser: tensor of shape (iterations, 4 params, number of ensemble sizes),\n",
"# plus its mean and std. dev. over the iteration axis\n",
"optim_tensors = dict()\n",
"for optim, optim_results in multi_optim_results_dict.items():\n",
"\n",
"    accuracy_optim_tensor = np.zeros((multi_optim_iter, 4, len(multi_optim_models)))\n",
"    for iter_idx, iteration in enumerate(optim_results):\n",
"        for single_test in iteration:\n",
"            optim_models_idx = multi_optim_models.index(single_test['num_models'])\n",
"            accuracy_optim_tensor[iter_idx, :, optim_models_idx] = test_tensor_data(single_test)\n",
"\n",
"    optim_tensors[optim] = {\n",
"        \"accuracy\": accuracy_optim_tensor,\n",
"        \"mean\": np.mean(accuracy_optim_tensor, axis=0),\n",
"        \"std\": np.std(accuracy_optim_tensor, axis=0)\n",
"    }\n",
"\n",
"print(f'{multi_optim_iter} Tests')\n",
"print(f'Optimisers: {list(multi_optim_results_dict.keys())}')\n",
"print(f'Models: {multi_optim_models}')\n",
"print()\n",
"print(f'Loss: {multi_optim_results[0][0][0][\"loss\"]}')"
]
},
{
2023-05-27 23:29:17 +01:00
"attachments": {},
2021-03-29 18:34:04 +01:00
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Export/Import Test Sets\n",
"\n",
"Pickle the raw per-iteration results (`multi_optim_results`) so they can be reloaded later for analysis and visualisation"
]
},
{
2021-03-30 16:31:10 +01:00
"cell_type": "code",
2021-04-29 22:53:26 +01:00
"execution_count": 28,
2021-03-29 18:34:04 +01:00
"metadata": {},
2021-03-30 16:31:10 +01:00
"outputs": [],
2021-03-29 18:34:04 +01:00
"source": [
2021-04-29 22:53:26 +01:00
"# Context manager guarantees the pickle is flushed and the file handle closed\n",
"with open(\"results/exp3-test5.p\", \"wb\") as results_file:\n",
"    pickle.dump(multi_optim_results, results_file)"
2021-03-29 18:34:04 +01:00
]
},
{
2021-04-06 17:29:15 +01:00
"cell_type": "code",
2021-05-04 15:24:37 +01:00
"execution_count": 338,
2021-03-29 18:34:04 +01:00
"metadata": {},
2021-04-06 17:29:15 +01:00
"outputs": [],
2021-03-29 18:34:04 +01:00
"source": [
2021-05-04 15:24:37 +01:00
"exp3_testname = 'exp3-test10'\n",
"# Only unpickle trusted files: pickle.load can execute arbitrary code\n",
"with open(f\"results/{exp3_testname}.p\", \"rb\") as results_file:\n",
"    multi_optim_results = pickle.load(results_file)"
2021-03-29 18:34:04 +01:00
]
},
{
2023-05-27 23:29:17 +01:00
"attachments": {},
2021-03-29 18:34:04 +01:00
"cell_type": "markdown",
"metadata": {},
"source": [
"### Best Results"
]
},
{
"cell_type": "code",
2021-05-04 15:24:37 +01:00
"execution_count": 340,
2021-03-29 18:34:04 +01:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2021-05-04 15:24:37 +01:00
"SGD: 25 Models, 96.5% Accurate\n",
"Adam: 15 Models, 96.5% Accurate\n",
"RMSprop: 25 Models, 96.6% Accurate\n"
2021-03-29 18:34:04 +01:00
]
}
],
"source": [
"# For each optimiser, report the ensemble size with the highest mean test accuracy\n",
"for optim, optim_results in optim_tensors.items():\n",
"    test_accuracy_mean = optim_results[\"mean\"][0, :]  # param 0 (test accuracy), one value per ensemble size\n",
"    best_models_idx = int(np.argmax(test_accuracy_mean))\n",
"    best_optim_accuracy = test_accuracy_mean[best_models_idx]\n",
"    best_optim_accuracy_models = multi_optim_models[best_models_idx]\n",
"\n",
"    # '.1f' not '.3': '.3' keeps 3 significant digits, so 100% would print as '1e+02'\n",
"    print(f'{optim}: {best_optim_accuracy_models} Models, {best_optim_accuracy * 100:.1f}% Accurate')"
]
},
{
2023-05-27 23:29:17 +01:00
"attachments": {},
2021-03-29 18:34:04 +01:00
"cell_type": "markdown",
"metadata": {},
"source": [
"### Optimiser Error Rates"
]
},
{
"cell_type": "code",
2021-05-04 15:24:37 +01:00
"execution_count": 343,
2021-03-29 18:34:04 +01:00
"metadata": {},
"outputs": [
{
"data": {
2023-05-27 23:29:17 +01:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAACXUAAAJECAYAAABT8FUYAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAB7CAAAewgFu0HU+AAEAAElEQVR4nOzdd5gV1f3H8feXDgKCiNhFFDtWwFhZS4yKYjeWqBi7xm40iUnEmGhU7EY0+aFgVAIiFqyEyIoiSlODGhVUVEQp0qWz5/fHmbs79+4tc9vu3buf1/PcZ+/OnDlz5k473zlnZsw5h4iIiIiIiIiIiIiIiIiIiIiIiJSGJvVdABEREREREREREREREREREREREamhTl0iIiIiIiIiIiIiIiIiIiIiIiIlRJ26RERERERERERERERERERERERESog6dYmIiIiIiIiIiIiIiIiIiIiIiJQQdeoSEREREREREREREREREREREREpIerUJSIiIiIiIiIiIiIiIiIiIiIiUkLUqUtERERERERERERERERERERERKSEqFOXiIiIiIiIiIiIiIiIiIiIiIhICVGnLhERERERERERERERERERERERkRKiTl0iIiIiIiIiIiIiIiIiIiIiIiIlRJ26RERERERERERERERERERERERESog6dYmIiIiIiIiIiIiIiIiIiIiIiJQQdeoSEREREREREREREREREREREREpIerUJSIiIiIiIiIiIiIiIiIiIiIiUkLUqUtERERERERERERERERERERERKSEqFOXiIiIiIiIiIiIiIiIiIiIiIhICVGnLhERERERERERERERERERERERkRKiTl0iIiIiIiIiIiIiIiIiIiIiIiIlRJ26RERERERERERERERERERERERESog6dYmISFkys95mNtTMZprZj2bmYp/6Lps0bsXYFs1sSCjf/oXKV/JjZrNC66VrHc2zIjTPyrqYp4ikZmYDQvvkgPouj4gUn5k1M7MLzOzfZjbXzNaEjgND6rt80niZWf9ibIuKtUtPfcUEiktFREQaHjPrZGY3mdm7ZrbIzNbrfC6loBh1S107L031df20PtpvcqVOXVIWgoPwQ2Y2xczmBxdNV5rZvGDYU2Z2tZn1NDPLMm8L8v+zmb1hZp+b2eJgHgvM7DMzG2VmvzezvSPmGT4RJX7WBMsw08zeMbNHzOxCM9sxt18nUnkq0pQnyqdrscpWjsysa4bfc52ZLTSzj83sSTM7zcxa1He5GxIzuwSYCJwNbAe0qd8SFU5C5Sbbz6z6Lr80TimOe/PMrFkWeTQ1s+90DhKRfJjZEwnHkBvqu0yNXYbYKNOnsr7L39BEqEuuNLPvzewtM7vLzPaq7zI3JGbWEvg38HfgcGAToHm9FqqALP6CZ7afAfVdfmmcUhz37skyj746B4mUDjOrzHDOWWZmX5vZq2b2OzPbIou86/WYYUVs55DCUH2obhVzfxcws27AB8AAoDfQgTLpO2CZ2+EyfSrqexmkcUpx3Dsuyzzu1DmovERuSBMpRWa2M/Ao8JMko5sDrYDOwD7A6cHwj4DdIuZ/KvBHYNcUSToFn+7ACcAtZvYFcA/wD+fc6mhLUqvcGwef7YB9Q+V5E/ibc254DvlKw9EU6Bh8dgbOAGaZ2dnOuTfrsiDme78/Fvw71DnXvy7nnwsz2wa4n5rg4wvgXWBhvRVKRJLpDBwFjI6Y/mfApsUrjoiUOzNrh6+zh50D3F4PxREpVa2CTxfgAOAaM3sauMg5t6guCxI0fvYJ/j3EOVdZl/PP0XVARej/N4CZwKrg/3fqukAiktTpZvZr59y6iOnPKWppRKTQ2gafrfDXEgaY2V+APznncnmqYNGPGcVu5xApY4Xe33NmVvPUUudcQ+l0+QgQ6wi3EhgLfAusD4b9rz4KJSK1nA08HyWhmTUFzixucaSuqVOXNFjm7xh+Hd9zPGYuMAX4HnD4Dle7AdsDsU
pUOH2qvFsDg6kJkGJWAJOD/JcEeW2CD6baBWm6AQ8APwWi9Jz9BPhPePZA+yDvHfAdxmJlPwg4yMzOBn7pnJsbIf9c/C3L9EuLUorG43FgWej/ZtQ0onQOhnUFXjOzQ51zaghI73Rqzm9jgL5ZXHRpaOYAz2aR/odiFUQkR2cTvVPX2cUsiIg0CqdQ++mdO5tZL+fc5PookNSSGBtlMqNYBWkkktUl2+BvLtqPmidMnQJsGcQiq5B0zgp9P8c593i9laT4/oPfZ6OaVKyCiOSgC77h96VMCc2sA3BssQskIjmbTO1zzIbAHkCP4P/m+KfQdACuzmEeRT1mFLOdQ4pO9aG6VRf7e6NhZpvhny4MsBrYwzlXzjF2YjtcJt8WqyAiOTjGzDpGvNnup8BmxS6Q1C116pIGycyaA09RE7jMAS4DXnDOVSVJ3xnfweosfKerdHm3wL8u4YDQ4EnAn4B/O+fWJJmmGf4umvPwT1VqAWwQcXHedc79Kk15NsJfRL8aiL2C8WjgHTPr7ZybH3E+kaUrjxTFTc65WYkDg23xauA2fLDeGnjEzPas67tMGpjwa1AfL+MOXQAztL9KA/UxsAtwrJl1cM4tTpfYzDakpqN0bFoRkWyF75hfia9bxYarU1dpSBsbScGlrEua2Vb4i94VwaD98DH3XXVTtIbHzNpQE7OvAZ6ox+LUhSecc0PquxAiWQrHEmcToYMGcCr+CTmJ04tIaXjZOTcg2Qgz2x8YBmwdDLrKzJ50zk2JmHfRjxnFbOeQOqH6UN0q5v7eGO0V+v5mmXfoghTtcCIlLlaXaAGcBgyKME345njFL2WiLN6LK43S8cBOwfeV+FcxPJcs0AFwzs13zv2fc64P8a9CSOZ+4jt0/cU5t69z7qVkHbqC/Nc5595yzp0LbAuMymJZ0nLOLXTOPYK/0+CB0KiuwLNBhzIpQ865Nc6524F7Q4N3xzeoSGodQ9+/q7dSiEg6/wz+tgR+HiF9+KJoOT/xQkSKxMy2xT/1Fvyd7teFRp8edKYXkYBz7hv8Uya+CQ2+qJ6K01CE45C5qa5PiEi9mg58EHzvF9w8kkmsU/hafGOxiDQQzrm38R2gwjfHXphFFnVxzDie4rVziDQaBdjfGyO1o4iUvn/h6xQQ4U0mZtYeX7cAeB9fl5EyoE5d0lAdEfr+vHPus6gTOuc+TzXOzPoQf6H6Pufc77MpmHNujnPuJOD6bKaLkO9a59wVwEOhwQfgnwwm5e2+hP8PrpdSNBzNQ9/VkCJSmp4CYk/Ri/JaxViatcG0IiLZOpua15S8AfwdiD3xdiPgmPoolEgpc84tB/4vNKi7mW1aX+VpABSHiDQMQ4O/rfA3j6RkZtsB+wf/vgwsKGK5RKQInHPvA5WhQdleVy32MaMo7RwijVEB9vfGRvGLSOmbD7wSfP+JmXXPkP4Uat5MMDRdQmlY1KlLGqotQt+/KmC+vwt9/xL4Ta4ZOeem5V+cpK7Bly3mt2ZWcvuymfU3Mxd8hgTDmprZaWb2vJl9YWYrg/HHB+MrQtNUhvI62syGmdkMM1sejL8qyTzNzE4J0n4epF0efH/KzE42M0ucLkk+laFyVATDNjOz35nZJDP73szWm9niAvxUGTnnvgLC70nePF16M9vEzM41s6Fm9p6ZLTSztWa22Mw+MbPHzOxnGfIYYmYOeCw0+JzQ7xL+VGbI6zAze9jMPgrKstrM5pjZa2b2KzNrnW76KGLlDcrcJzRqXJLyVqTIY2Mz+42ZvWFm3wXlXBD8hneaWcZHlJpZ19B8ZoWGH2hm/xf8/kuC8ffmt9SFE/79zKx/MKyNmV1qZm+Z2dzg9/gm2L8OyJBlLF8zs+OD/e9TM1sa7Ds/mtksM3vdzG43s0MswnHMzLYysz+Y2ZvBNrQ62KbeM7OBZrZDjsvawcyuNbN3zGyema0xf4x6yPyrhxLz6BRsK5PMbL6ZrTCz/5nZX82sY62ZRvutegXbyGfB77MwyP
+35u+uKLi62DfTmAe8Gnzf3/yFz1Tl3JaaJ2i+Sk0njMgKsX8n5NfSzC4PtsX55s9nsXPNIdmWL5TvBmZ2iZmNNrO
2021-03-29 18:34:04 +01:00
"text/plain": [
2021-04-30 20:51:04 +01:00
"<Figure size 2400x600 with 3 Axes>"
2021-03-29 18:34:04 +01:00
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
2021-04-30 20:51:04 +01:00
"fig, axes = plt.subplots(1, 3, figsize=(12, 3))\n",
"fig.set_dpi(fig_dpi)\n",
"\n",
"show_error_bars = False  # True: draw each curve with +/- 1 std. dev. error bars\n",
"save_figure = False      # True: also write the figure to graphs/\n",
"\n",
"# (row of the accuracy tensor, legend label); each curve plots 1 - mean of that row\n",
"error_curves = [(0, 'Ensemble Test'),\n",
"                (2, 'Individual Test'),\n",
"                (1, 'Individual Train'),\n",
"                (3, 'Disagreement')]\n",
"\n",
"for (optimiser_name, tensors_dict), ax in zip(optim_tensors.items(), axes.flatten()):\n",
"    for param_idx, label in error_curves:\n",
"        error_rate = 1 - tensors_dict[\"mean\"][param_idx, :]\n",
"        if show_error_bars:\n",
"            ax.errorbar(multi_optim_models, error_rate, yerr=tensors_dict[\"std\"][param_idx, :], capsize=4, label=label)\n",
"        else:\n",
"            ax.plot(multi_optim_models, error_rate, 'x-', label=label)\n",
"\n",
"    ax.set_title(f\"{optimiser_name} Error Rate for Ensemble Models\")\n",
"    ax.set_ylim(0, 0.15)\n",
"    ax.grid()\n",
"    ax.set_xlabel(\"Number of Models\")\n",
"    ax.set_ylabel(\"Error Rate\")\n",
"\n",
"# The legend is only drawn on the second and third panels\n",
"axes[1].legend()\n",
"axes[2].legend()\n",
"\n",
"plt.tight_layout()\n",
"if save_figure:\n",
"    plt.savefig(f'graphs/{exp3_testname}-error-rate-curves.png')\n",
"\n",
"plt.show()"
]
},
2021-05-04 15:24:37 +01:00
{
"cell_type": "code",
"execution_count": 345,
"metadata": {},
"outputs": [
{
"data": {
2023-05-27 23:29:17 +01:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAACV0AAAJECAYAAAAPehY8AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAB7CAAAewgFu0HU+AAEAAElEQVR4nOzdd7gU1f3H8feX3osIqKAgihUUlRIrWGIJBruxRMVY0dgwahJ/iRgTNbHETqJRwSgoIorEhoVrRQErggVUVESa9F7u+f1xZu+d3bu93vJ5Pc88d3fmzJkzO3N3z3fmzDnmnENERERERERERERERERERERERETSU6/UBRAREREREREREREREREREREREalJ1OhKREREREREREREREREREREREQkA2p0JSIiIiIiIiIiIiIiIiIiIiIikgE1uhIREREREREREREREREREREREcmAGl2JiIiIiIiIiIiIiIiIiIiIiIhkQI2uREREREREREREREREREREREREMqBGVyIiIiIiIiIiIiIiIiIiIiIiIhlQoysREREREREREREREREREREREZEMqNGViIiIiIiIiIiIiIiIiIiIiIhIBtToSkREREREREREREREREREREREJANqdCUiIiIiIiIiIiIiIiIiIiIiIpIBNboSERERERERERERERERERERERHJgBpdiYiIiIiIiIiIiIiIiIiIiIiIZECNrkRERERERERERERERERERERERDKgRlciIiIiIiIiIiIiIiIiIiIiIiIZUKMrERERERERERERERERERERERGRDKjRlYiIiIiIiIiIiIiIiIiIiIiISAbU6EpERERERERERERERERERERERCQDanQlIiIiIiIiIiIiIiIiIiIiIiKSATW6EhGRWs/MBpuZC6YRecw3kqfLV54iyZjZiNB5N7jU5UlG/x8iIpKKmQ0L/V4MK3V5RPLNzMpC5/iAPOVZkNhGJJmaUrfX/4eIiIhIbgpR76tJ17SldjCzAaFzrqzU5UlG/x+1gxpdidQRwQ/MfWY2zcwWmdkGM1trZguDeaPM7Aoz621mlmHeFuT/VzN73cy+MrNlwTYWm9mXZjbOzP7PzPZOM8/wj0zstCHYh9lm9q6Z/dvMzjeznbP7dERERERExMwejal3X1PqMomIiIiISGZiGn7Hm1aa2Xdm9qKZ/dHMOmWQ97A4+f0zw/INjJNHWZrrFuw+h4iIiEg21OhKpJYzs13NbDIwCRgC7ANsCTQEmgDtg3mnArcDU4HpGeR/cpB+EnAtcBDQDWgdbKMd0B04DrgBeD9olPVbM2uc5W41DPZhB6AfcD7wb+BzM3vDzH6VZb4itUptaiFvBeglIN9qQhlFREQSMbOW+Dp72FmlKIuI1GyFeDq+FMysa2hf5pS6PPHUhDKKiEi11ALYFjgC+BvwjZldl0MjpVPNrEEG6TOOMwp9n0NE6i6rQb0ipWI1oOfXmlBGkUxlUgkSkRrGzPYCXgPahGYvAKYB8wGHbxTVA9gRiARV4fSJ8m4KPIgPYsLW4AOa+cDyIK8O+ICnZZCmG3A38HPgmDR25XPg1fDmgVZB3jvhG3VFyn4gcKCZnQn8xjm3II38RURERETqupOAZjHzdjWzPs65qaUokIiIiIiI5GwqMCVmXmtgT6Bn8L4hMAx/vf2KLLbREd+A67lUCc2sDfDLTDIv5H0OERERkVyp0ZVILWVmDYFRVAYW84CLgWedc+Vx0rfHN4A6A98oKlnejYCXgf1Ds6cAfwFeds5tiLNOA+BnwDnAaUAjoHmau/Oec+63ScqzBf4m0RVAZIjBXwDvmllf59yiNLcjIiIiIlJXhZ82Xws0Dc1XoysRERERkZrpeefcsHgLzGw/YDSwXTDrcjN7zDk3Lc28ZwK7Ba/PJI1GV8DJ+J6pYtePq5D3OURERETyQcMLitRexwK7BK/XAgc7556JF4gAOOcWOef+45zrDwxIkf
ddRDe4+ptzrp9z7rl4Da6C/Dc5595yzp0NbA+My2BfknLOLXHO/Rv/ZM7doUVdgacz7NpYRERERKROMbPt8T3Ggn9K/HehxacGD12IiIiIiEgt4px7B99AKTwk8PkZZDEd+Dh4PcjMWqexTuRhj434Bl+pHEvh7nOIiIiI5EyNrkRqr8NDr8c7575Md0Xn3FeJlplZf+CC0Kw7nXP/l0nBnHPznHMnAFdnsl4a+W50zl0K3BeavT++Zy0REREREYnvTCqH4HgduB+I9Ba7BXB0KQolIiIiIiKF5Zz7CCgLzToowyxGBn+b4HuxSsjMdgD2C94+DyxOI/+C3OcQERERyRc1uhKpvTqFXn+bx3z/GHr9DfD7bDNyzn2Qe3HiGoovW8QfzCyv33dmdqiZ/cvMZpjZEjNbb2bzzOwlM/utmTVNIw8XmULzdjazO8zsMzNbZWYrzOxjM7vJzLZMs2xbmtnvzOyVoEzrzGyjmS0LyjvWzIYGPRqUcl97mdlwM/si2NdVZvaemV0Ur3cyM+ttZiOCz2a1mf1kZpPM7PR09iNOfg3N7Cwze9nM5gb7NdfMnjGzY7LJM41tNjezIWY2wcy+NbM1ZrbSzGaZ2UNmdkietjMn+KzDwwQ9HD4OoWlYknwamtkZZjbGzL4OyrrazL4xs9FmdpyZWaL1Y/LqY2b3mNkHZrbUzDaZ2Voz+9HM3g3OhZPNrHnMepHzpn9o9qQE+zI4/U+p+pXRzI41s/Fm9kPofHw5OAbVpsc+M2ttZn8ws6nB57Qq+D9+wMz2ySHfdmZ2ZbDP35v/7lpmZjPN7F4z651k3eNDn/EXGWyzs5ltDtbbZGZbZVt+EZGaKvgtPzM067/OuU3A46F5Z5EhMzvYzEYFdZ51we/pm+bres0yzKuhmR1hZv8I6n+ROu7a4PfyBTO73MxapJFX19BvxpzQ/APN7L9mNjuooy03szIzOy1efSfYvyfN1+PWmtlCM3vOzI7KZN8yYXmoS5rZ4ND+jwjNPy7I97ugHrLQzCaa2a/j7X+CvLOqSxVzX82sXnBMXwjqG+vNbIGZPWVm+8bJo5H5utirofrJd2Y20sx2TedziZPntmb2V/Nx3hLz9evPzeyfZrZjNnmmsc1dzexGM5sS7O8GM1tkPv76i5ltk4dtDLCYmC+YH69O7Mysa5K8tjWzP5n/zpgXHKclZvahmd1qZjulWaaGwTk8znw8syo4L1cG/+svBfvfN2a9wcF+hK8rdEm0Lxl8TNWujFagun0hmNneQbm+Dr5bFgXn9NVmtkUO+fYJ/v8+CvLcYGbzzex1M7vGzNomWffZ0Of8hwy2+cfQeukMxSUiUmgfhV5nWi8YBWwKXp+ZLGHM8kfSzL8g9znM1/Uj38UDgnk51dPMbFgoz2HBvKZmdo75uvV3we+MM7NecdZvYWaXBr//c83XPZea2afm69n90ty3eNfi+5jZf8zsy2C/lgS/o38ws1bpfWqZy0e9zvw9gahrq2bWzHxs+Zb5+u168/X10Wa2f4osI/ma+euxo4L6zwrz1wlXm7++/pqZ/d187Jfy/lIB97WN+Wum75qP0zYE9aH7zGzbOHm0M7PfB8d3kflY6jMzuzlZvSZFuYp+/lge7kulsY1hwf/JpNDs/ha/Tj0nRV5Z1ylj8snq/l7k3AEeDs0+K8G+lKX3CVXPMppZd/P3UT8PnY8fmY+bOueyb/lmBbjnY95x5q9LfGn++tU689+Bz5i/7xk3b/Px5+LQ51zlOkiS7U4MrXdVNmWv1ZxzmjRpqoUT8D98t8AOeCJPeW4fytMBVxWw/CNC2xmRxfpDY8q6V57KtS2+AuZSTD8AB6bIqyJ98P5CYF2SPBcDvVPkeQywJI3yOWBuCff1anwwnijPF4HGQdr6+N7LkpVhNFA/yfYHh88nYGvg7RR5Pgs0z2S/UqQ9Cfgxjc9zAtA6x/N0TprngAOGJchjADA7jfUnA52SlKUB8O8MyvPXRJ9xGt
PgLD+vkpYRaAE8l2K9N4GtiP5uzGp/czy3DsD/zycq52bgz1n8f1wMLEvxGZQDDwKN4qzfGFgaStsnzf25OrTOS8X
2021-05-04 15:24:37 +01:00
"text/plain": [
"<Figure size 2400x600 with 3 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Spread (std. dev. across iterations) of each metric. Curves are drawn in the same\n",
"# order and with the same labels as the error-rate figure, so each metric gets the\n",
"# same default colour in both. std(1 - x) == std(x), so row 3 (agreement) also\n",
"# gives the spread of the disagreement.\n",
"save_figure = False  # True: also write the figure to graphs/\n",
"\n",
"fig, axes = plt.subplots(1, 3, figsize=(12, 3))\n",
"fig.set_dpi(fig_dpi)\n",
"\n",
"std_curves = [(0, 'Ensemble Test'),\n",
"              (2, 'Individual Test'),\n",
"              (1, 'Individual Train'),\n",
"              (3, 'Disagreement')]\n",
"\n",
"for (optimiser_name, tensors_dict), ax in zip(optim_tensors.items(), axes.flatten()):\n",
"    for param_idx, label in std_curves:\n",
"        ax.plot(multi_optim_models, tensors_dict[\"std\"][param_idx, :], 'x-', label=label, lw=2)\n",
"\n",
"    ax.legend()\n",
"    ax.grid()\n",
"    ax.set_title(f\"{optimiser_name} Std. Dev. Across Iterations\")\n",
"    ax.set_xlabel(\"Number of Models\")\n",
"    ax.set_ylabel(\"Standard Deviation\")\n",
"    ax.set_ylim(0, 0.15)\n",
"\n",
"plt.tight_layout()\n",
"if save_figure:\n",
"    plt.savefig(f'graphs/{exp3_testname}-errors-rate-std.png')\n",
"plt.show()"
]
},
2021-03-29 18:34:04 +01:00
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
2021-03-19 17:21:00 +00:00
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"authorship_tag": "ABX9TyNAMGLKzaoWaq1wvQ+w0w8h",
"collapsed_sections": [],
"name": "nncw.ipynb",
"provenance": [],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
2023-05-27 23:29:17 +01:00
"version": "3.11.3"
2021-03-21 09:56:27 +00:00
},
"toc-autonumbering": false,
2021-03-27 16:29:31 +00:00
"toc-showcode": false,
2021-03-22 20:49:29 +00:00
"toc-showmarkdowntxt": false,
2021-03-27 16:29:31 +00:00
"toc-showtags": false
2021-03-19 17:21:00 +00:00
},
"nbformat": 4,
"nbformat_minor": 4
}