shallow-training/nncw.ipynb

2678 lines
1.5 MiB
Plaintext
Raw Normal View History

2021-03-19 17:21:00 +00:00
{
"cells": [
{
"cell_type": "code",
2021-04-29 22:53:26 +01:00
"execution_count": 2,
2021-03-19 17:21:00 +00:00
"metadata": {
"executionInfo": {
"elapsed": 2450,
"status": "ok",
"timestamp": 1615991459232,
"user": {
"displayName": "Andy Pack",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjA4K4ZhdArHXAFbAGr4n0aCv2HmyUpx4cy6zcUq34=s64",
"userId": "16615063155528027547"
},
"user_tz": 0
},
"id": "TGIxH9Tmt5zp"
},
2021-03-26 20:01:05 +00:00
"outputs": [],
2021-03-19 17:21:00 +00:00
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import tensorflow as tf\n",
"import tensorflow.keras.optimizers as tf_optim\n",
2021-03-19 17:21:00 +00:00
"tf.get_logger().setLevel('ERROR')\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib as mpl\n",
"import seaborn as sns\n",
"import random\n",
"import pickle\n",
"import json\n",
2021-03-22 20:49:29 +00:00
"import math\n",
"import datetime\n",
"import os\n",
2021-03-19 17:21:00 +00:00
"\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
2021-03-26 20:01:05 +00:00
"fig_dpi = 70"
2021-03-19 17:21:00 +00:00
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fksHv5rXACEX"
},
"source": [
"# Neural Network Training\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "l4zqVWyRAM0Z"
},
"source": [
"## Load Dataset\n",
"\n",
"Read CSVs dumped from MatLab and parse into Pandas DataFrames"
]
},
{
"cell_type": "code",
2021-04-29 22:53:26 +01:00
"execution_count": 3,
2021-03-19 17:21:00 +00:00
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 331
},
"executionInfo": {
"elapsed": 2441,
"status": "ok",
"timestamp": 1615991459234,
"user": {
"displayName": "Andy Pack",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjA4K4ZhdArHXAFbAGr4n0aCv2HmyUpx4cy6zcUq34=s64",
"userId": "16615063155528027547"
},
"user_tz": 0
},
"id": "Hj5l_tdZuYh7",
"outputId": "fbfa9406-f662-4ebc-8ba2-67950714627c"
},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Clump thickness</th>\n",
" <th>Uniformity of cell size</th>\n",
" <th>Uniformity of cell shape</th>\n",
" <th>Marginal adhesion</th>\n",
" <th>Single epithelial cell size</th>\n",
" <th>Bare nuclei</th>\n",
" <th>Bland chomatin</th>\n",
" <th>Normal nucleoli</th>\n",
" <th>Mitoses</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>count</th>\n",
" <td>699.000000</td>\n",
" <td>699.000000</td>\n",
" <td>699.000000</td>\n",
" <td>699.000000</td>\n",
" <td>699.000000</td>\n",
" <td>699.000000</td>\n",
" <td>699.000000</td>\n",
" <td>699.000000</td>\n",
" <td>699.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>mean</th>\n",
" <td>0.441774</td>\n",
" <td>0.313448</td>\n",
" <td>0.320744</td>\n",
" <td>0.280687</td>\n",
" <td>0.321602</td>\n",
" <td>0.354363</td>\n",
" <td>0.343777</td>\n",
" <td>0.286695</td>\n",
" <td>0.158941</td>\n",
" </tr>\n",
" <tr>\n",
" <th>std</th>\n",
" <td>0.281574</td>\n",
" <td>0.305146</td>\n",
" <td>0.297191</td>\n",
" <td>0.285538</td>\n",
" <td>0.221430</td>\n",
" <td>0.360186</td>\n",
" <td>0.243836</td>\n",
" <td>0.305363</td>\n",
" <td>0.171508</td>\n",
" </tr>\n",
" <tr>\n",
" <th>min</th>\n",
" <td>0.100000</td>\n",
" <td>0.100000</td>\n",
" <td>0.100000</td>\n",
" <td>0.100000</td>\n",
" <td>0.100000</td>\n",
" <td>0.100000</td>\n",
" <td>0.100000</td>\n",
" <td>0.100000</td>\n",
" <td>0.100000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>25%</th>\n",
" <td>0.200000</td>\n",
" <td>0.100000</td>\n",
" <td>0.100000</td>\n",
" <td>0.100000</td>\n",
" <td>0.200000</td>\n",
" <td>0.100000</td>\n",
" <td>0.200000</td>\n",
" <td>0.100000</td>\n",
" <td>0.100000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>50%</th>\n",
" <td>0.400000</td>\n",
" <td>0.100000</td>\n",
" <td>0.100000</td>\n",
" <td>0.100000</td>\n",
" <td>0.200000</td>\n",
" <td>0.100000</td>\n",
" <td>0.300000</td>\n",
" <td>0.100000</td>\n",
" <td>0.100000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>75%</th>\n",
" <td>0.600000</td>\n",
" <td>0.500000</td>\n",
" <td>0.500000</td>\n",
" <td>0.400000</td>\n",
" <td>0.400000</td>\n",
" <td>0.500000</td>\n",
" <td>0.500000</td>\n",
" <td>0.400000</td>\n",
" <td>0.100000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>max</th>\n",
" <td>1.000000</td>\n",
" <td>1.000000</td>\n",
" <td>1.000000</td>\n",
" <td>1.000000</td>\n",
" <td>1.000000</td>\n",
" <td>1.000000</td>\n",
" <td>1.000000</td>\n",
" <td>1.000000</td>\n",
" <td>1.000000</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Clump thickness Uniformity of cell size Uniformity of cell shape \\\n",
"count 699.000000 699.000000 699.000000 \n",
"mean 0.441774 0.313448 0.320744 \n",
"std 0.281574 0.305146 0.297191 \n",
"min 0.100000 0.100000 0.100000 \n",
"25% 0.200000 0.100000 0.100000 \n",
"50% 0.400000 0.100000 0.100000 \n",
"75% 0.600000 0.500000 0.500000 \n",
"max 1.000000 1.000000 1.000000 \n",
"\n",
" Marginal adhesion Single epithelial cell size Bare nuclei \\\n",
"count 699.000000 699.000000 699.000000 \n",
"mean 0.280687 0.321602 0.354363 \n",
"std 0.285538 0.221430 0.360186 \n",
"min 0.100000 0.100000 0.100000 \n",
"25% 0.100000 0.200000 0.100000 \n",
"50% 0.100000 0.200000 0.100000 \n",
"75% 0.400000 0.400000 0.500000 \n",
"max 1.000000 1.000000 1.000000 \n",
"\n",
" Bland chomatin Normal nucleoli Mitoses \n",
"count 699.000000 699.000000 699.000000 \n",
"mean 0.343777 0.286695 0.158941 \n",
"std 0.243836 0.305363 0.171508 \n",
"min 0.100000 0.100000 0.100000 \n",
"25% 0.200000 0.100000 0.100000 \n",
"50% 0.300000 0.100000 0.100000 \n",
"75% 0.500000 0.400000 0.100000 \n",
"max 1.000000 1.000000 1.000000 "
]
},
2021-04-29 22:53:26 +01:00
"execution_count": 3,
2021-03-19 17:21:00 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data = pd.read_csv('features.csv', header=None).T\n",
"data.columns = ['Clump thickness', 'Uniformity of cell size', 'Uniformity of cell shape', 'Marginal adhesion', 'Single epithelial cell size', 'Bare nuclei', 'Bland chomatin', 'Normal nucleoli', 'Mitoses']\n",
"labels = pd.read_csv('targets.csv', header=None).T\n",
"labels.columns = ['Benign', 'Malignant']\n",
"data.describe()"
]
},
{
"cell_type": "code",
2021-03-26 20:01:05 +00:00
"execution_count": 31,
2021-03-19 17:21:00 +00:00
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 204
},
"executionInfo": {
"elapsed": 2436,
"status": "ok",
"timestamp": 1615991459236,
"user": {
"displayName": "Andy Pack",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjA4K4ZhdArHXAFbAGr4n0aCv2HmyUpx4cy6zcUq34=s64",
"userId": "16615063155528027547"
},
"user_tz": 0
},
"id": "qc1Mku6h041u",
"outputId": "94e38c34-0419-4a02-ac8c-17bbc83f777b"
},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Benign</th>\n",
" <th>Malignant</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Benign Malignant\n",
"0 1 0\n",
"1 1 0\n",
"2 1 0\n",
"3 0 1\n",
"4 1 0"
]
},
2021-03-26 20:01:05 +00:00
"execution_count": 31,
2021-03-19 17:21:00 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"labels.head()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "h9QsJjWEMbLu"
},
"source": [
"### Explore Dataset\n",
"\n",
"The classes are uneven in their occurences, stratify when splitting later on"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"executionInfo": {
"elapsed": 2430,
"status": "ok",
"timestamp": 1615991459237,
"user": {
"displayName": "Andy Pack",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjA4K4ZhdArHXAFbAGr4n0aCv2HmyUpx4cy6zcUq34=s64",
"userId": "16615063155528027547"
},
"user_tz": 0
},
"id": "rjjiSYAZMa4k",
"outputId": "ae0c3772-00be-4f2b-80d2-9cd62a6b6e08"
},
"outputs": [
{
"data": {
"text/plain": [
"Benign 458\n",
"Malignant 241\n",
"dtype: int64"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"labels.astype(bool).sum(axis=0)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "E9lVYI14AUMf"
},
"source": [
"## Split Dataset\n",
"\n",
"Using a 50/50 split"
]
},
{
"cell_type": "code",
2021-04-29 22:53:26 +01:00
"execution_count": 4,
2021-03-19 17:21:00 +00:00
"metadata": {
"executionInfo": {
"elapsed": 2604,
"status": "ok",
"timestamp": 1615991459418,
"user": {
"displayName": "Andy Pack",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjA4K4ZhdArHXAFbAGr4n0aCv2HmyUpx4cy6zcUq34=s64",
"userId": "16615063155528027547"
},
"user_tz": 0
},
"id": "L83Ae5l9wM35"
},
"outputs": [],
"source": [
2021-04-29 22:53:26 +01:00
"data_train, data_test, labels_train, labels_test = train_test_split(data, labels, test_size=0.5\n",
"# , stratify=labels\n",
" )"
2021-03-19 17:21:00 +00:00
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Qf2U199fNjmJ"
},
"source": [
"## Generate & Retrieve Model\n",
"\n",
"Get a shallow model with a single hidden layer of varying nodes"
]
},
{
"cell_type": "code",
2021-04-29 22:53:26 +01:00
"execution_count": 5,
2021-03-19 17:21:00 +00:00
"metadata": {
"executionInfo": {
"elapsed": 2598,
"status": "ok",
"timestamp": 1615991459419,
"user": {
"displayName": "Andy Pack",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjA4K4ZhdArHXAFbAGr4n0aCv2HmyUpx4cy6zcUq34=s64",
"userId": "16615063155528027547"
},
"user_tz": 0
},
"id": "SgoQ-NjWB0T5"
},
"outputs": [],
"source": [
"def get_model(hidden_nodes=9, activation=lambda: 'sigmoid', weight_init=lambda: 'glorot_uniform'):\n",
" layers = [tf.keras.layers.InputLayer(input_shape=(9,), name='Input'), \n",
" tf.keras.layers.Dense(hidden_nodes, activation=activation(), kernel_initializer=weight_init(), name='Hidden'), \n",
" tf.keras.layers.Dense(2, activation='softmax', kernel_initializer=weight_init(), name='Output')]\n",
2021-03-19 17:21:00 +00:00
"\n",
" model = tf.keras.models.Sequential(layers)\n",
" return model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Get a Keras Tensorboard callback for dumping data for later analysis"
]
},
{
"cell_type": "code",
2021-04-29 22:53:26 +01:00
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"def tensorboard_callback(path='tensorboard-logs', prefix=''):\n",
2021-03-30 16:31:10 +01:00
" return tf.keras.callbacks.TensorBoard(\n",
" log_dir=os.path.normpath(os.path.join(path, prefix + datetime.datetime.now().strftime(\"%Y%m%d-%H%M%S\"))), histogram_freq=1\n",
" ) "
]
},
2021-03-19 17:21:00 +00:00
{
"cell_type": "markdown",
"metadata": {
"id": "QT5B9PTUN3pj"
},
"source": [
"# Example Training"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mQGAUtIPAd6e"
},
"source": [
"## Define Model\n",
"\n",
"Variable number of hidden nodes. All using 9D outputs except the last layer which is 2D for binary classification"
]
},
{
"cell_type": "code",
2021-03-26 20:01:05 +00:00
"execution_count": 60,
2021-03-19 17:21:00 +00:00
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"executionInfo": {
"elapsed": 7889,
"status": "ok",
"timestamp": 1615991464716,
"user": {
"displayName": "Andy Pack",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjA4K4ZhdArHXAFbAGr4n0aCv2HmyUpx4cy6zcUq34=s64",
"userId": "16615063155528027547"
},
"user_tz": 0
},
"id": "fYA34P0Vu_pX",
"outputId": "aded18b8-aa7f-4362-a614-837c8a0f526f"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2021-03-26 20:01:05 +00:00
"Model: \"sequential_1\"\n",
2021-03-19 17:21:00 +00:00
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
2021-03-26 20:01:05 +00:00
"dense_2 (Dense) (None, 9) 90 \n",
2021-03-19 17:21:00 +00:00
"_________________________________________________________________\n",
2021-03-26 20:01:05 +00:00
"dense_3 (Dense) (None, 2) 20 \n",
2021-03-19 17:21:00 +00:00
"=================================================================\n",
"Total params: 110\n",
"Trainable params: 110\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"model = get_model(9)\n",
"model.compile('sgd', loss='categorical_crossentropy', metrics=['accuracy'])\n",
"model.summary()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KZSwFe-AAs1y"
},
"source": [
"## Train Model\n",
"\n",
"Example 10 epochs"
]
},
{
"cell_type": "code",
2021-03-26 20:01:05 +00:00
"execution_count": 61,
2021-03-19 17:21:00 +00:00
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"executionInfo": {
"elapsed": 11304,
"status": "ok",
"timestamp": 1615991468137,
"user": {
"displayName": "Andy Pack",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjA4K4ZhdArHXAFbAGr4n0aCv2HmyUpx4cy6zcUq34=s64",
"userId": "16615063155528027547"
},
"user_tz": 0
},
"id": "s8U9Atu3zelS",
"outputId": "8439e8d2-7a5d-495f-a192-a34f01e95bfa"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/5\n",
2021-03-26 20:01:05 +00:00
"11/11 [==============================] - 1s 2ms/step - loss: 0.6257 - accuracy: 0.6607\n",
"Epoch 2/5\n",
2021-03-26 20:01:05 +00:00
"11/11 [==============================] - 0s 3ms/step - loss: 0.6226 - accuracy: 0.6651\n",
"Epoch 3/5\n",
2021-03-26 20:01:05 +00:00
"11/11 [==============================] - 0s 2ms/step - loss: 0.6326 - accuracy: 0.6424\n",
"Epoch 4/5\n",
2021-03-26 20:01:05 +00:00
"11/11 [==============================] - 0s 3ms/step - loss: 0.6158 - accuracy: 0.6696\n",
"Epoch 5/5\n",
2021-03-26 20:01:05 +00:00
"11/11 [==============================] - 0s 2ms/step - loss: 0.6228 - accuracy: 0.6534\n"
2021-03-19 17:21:00 +00:00
]
},
{
"data": {
"text/plain": [
2021-03-26 20:01:05 +00:00
"<tensorflow.python.keras.callbacks.History at 0x2cd249f3400>"
2021-03-19 17:21:00 +00:00
]
},
2021-03-26 20:01:05 +00:00
"execution_count": 61,
2021-03-19 17:21:00 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
2021-03-26 20:01:05 +00:00
"model.fit(data_train.to_numpy(), labels_train.to_numpy(), epochs=5)"
2021-03-19 17:21:00 +00:00
]
},
{
"cell_type": "code",
2021-03-26 20:01:05 +00:00
"execution_count": 62,
2021-03-19 17:21:00 +00:00
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"executionInfo": {
"elapsed": 11294,
"status": "ok",
"timestamp": 1615991468137,
"user": {
"displayName": "Andy Pack",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjA4K4ZhdArHXAFbAGr4n0aCv2HmyUpx4cy6zcUq34=s64",
"userId": "16615063155528027547"
},
"user_tz": 0
},
"id": "VnUEJdXovzi-",
"outputId": "02075086-352c-4a23-fac5-ad54d11e0e35"
},
"outputs": [
{
"data": {
"text/plain": [
"['loss', 'accuracy']"
]
},
2021-03-26 20:01:05 +00:00
"execution_count": 62,
2021-03-19 17:21:00 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.metrics_names"
]
},
{
"cell_type": "code",
2021-03-26 20:01:05 +00:00
"execution_count": 63,
2021-03-19 17:21:00 +00:00
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"executionInfo": {
"elapsed": 11285,
"status": "ok",
"timestamp": 1615991468138,
"user": {
"displayName": "Andy Pack",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjA4K4ZhdArHXAFbAGr4n0aCv2HmyUpx4cy6zcUq34=s64",
"userId": "16615063155528027547"
},
"user_tz": 0
},
"id": "r0vxP3Ah42ib",
"outputId": "061113ba-52db-4fbe-c7f9-b5d3d85438ed"
},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor: shape=(), dtype=float32, numpy=0.6561605>"
]
},
2021-03-26 20:01:05 +00:00
"execution_count": 63,
2021-03-19 17:21:00 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.metrics[1].result()"
]
},
{
"cell_type": "markdown",
"metadata": {
2021-03-26 20:01:05 +00:00
"id": "z7bn8pKTAynt",
"tags": [
"exp1"
]
2021-03-19 17:21:00 +00:00
},
"source": [
"# Experiment 1\n",
"\n",
"The below function runs an iteration of layer/epoch investigations.\n",
"Returns the amount of layers/epochs used as well as the results and the model.\n",
"\n",
"Using cancer dataset (as in E2) and 'trainscg' or an optimiser of your choice, vary nodes and epochs (that is using early stopping for epochs) over suitable range, to find optimal choice in terms of classification test error rate of node/epochs for 50/50% random train/test split (no validation set). It is suggested that you initially try epochs = [ 1 2 4 8 16 32 64], nodes = [2 8 32], so there would be 21 node/epoch combinations. \n",
2021-03-19 17:21:00 +00:00
"\n",
"(Hint1: from the 'advanced script' in E2, nodes can be changed to xx, with hiddenLayerSize = xx; and epochs changed to xx by addingnet. trainParam.epochs = xx; placed afternet = patternnet(hiddenLayerSize, trainFcn); --see 'trainscg' help documentation for changing epochs). \n",
2021-03-19 17:21:00 +00:00
"\n",
"Repeat each of the 21 node/epoch combinations at least thirty times, with different 50/50 split and take average and report classification error rate and standard deviation (std). Graph classification train and test error rate and std as node-epoch changes, that is plot error rate vs epochs for different number of nodes. Report the optimal value for test error rate and associated node/epoch values. \n",
"\n",
"(Hint2: as epochs increases you can expect the test error rate to reach a minimum and then start increasing, you may need to set the stopping criteria to achieve the desired number of epochs - Hint 3: to find classification error rates for train and test set, you need to check the code from E2, to determine how you may obtain the train and test set patterns)\n"
2021-03-19 17:21:00 +00:00
]
},
{
"cell_type": "code",
2021-04-30 20:51:04 +01:00
"execution_count": 194,
2021-03-19 17:21:00 +00:00
"metadata": {
"executionInfo": {
"elapsed": 11274,
"status": "ok",
"timestamp": 1615991468138,
"user": {
"displayName": "Andy Pack",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjA4K4ZhdArHXAFbAGr4n0aCv2HmyUpx4cy6zcUq34=s64",
"userId": "16615063155528027547"
},
"user_tz": 0
},
2021-03-26 20:01:05 +00:00
"id": "mYWhCSW4A57V",
"tags": [
"exp1",
"exp-func"
]
2021-03-19 17:21:00 +00:00
},
"outputs": [],
"source": [
2021-03-26 20:01:05 +00:00
"# hidden_nodes = [2, 8, 16, 24, 32]\n",
"# epochs = [1, 2, 4, 8, 16, 32, 64, 100, 150, 200]\n",
"hidden_nodes = [2, 8, 16]\n",
"epochs = [1, 2, 4, 8]\n",
2021-03-19 17:21:00 +00:00
"\n",
"def evaluate_parameters(hidden_nodes=hidden_nodes, \n",
" epochs=epochs, \n",
" batch_size=128,\n",
" optimizer=lambda: 'sgd',\n",
2021-04-29 22:53:26 +01:00
" weight_init=lambda: 'glorot_uniform',\n",
2021-03-19 17:21:00 +00:00
" loss=lambda: 'categorical_crossentropy',\n",
" metrics=['accuracy'],\n",
" callbacks=None,\n",
" validation_split=None,\n",
"\n",
" verbose=0,\n",
" print_params=True,\n",
" return_model=True,\n",
2021-03-26 20:01:05 +00:00
" run_eagerly=False,\n",
" tboard=True,\n",
2021-03-19 17:21:00 +00:00
" \n",
" dtrain=data_train,\n",
" dtest=data_test,\n",
" ltrain=labels_train,\n",
" ltest=labels_test):\n",
" for idx1, hn in enumerate(hidden_nodes):\n",
" for idx2, e in enumerate(epochs):\n",
" if print_params:\n",
" print(f\"Nodes: {hn}, Epochs: {e}\")\n",
"\n",
2021-04-29 22:53:26 +01:00
" model = get_model(hn, weight_init=weight_init)\n",
2021-03-19 17:21:00 +00:00
" model.compile(\n",
" optimizer=optimizer(),\n",
" loss=loss(),\n",
2021-03-26 20:01:05 +00:00
" metrics=metrics,\n",
" run_eagerly=run_eagerly\n",
2021-03-19 17:21:00 +00:00
" )\n",
" \n",
" if tboard:\n",
" if callbacks is not None:\n",
" cb = [i() for i in callbacks] + [tensorboard_callback(prefix=f'exp1-{hn}-{e}-')]\n",
" else:\n",
" cb = [tensorboard_callback(prefix=f'exp1-{hn}-{e}-')]\n",
" \n",
2021-03-19 17:21:00 +00:00
" response = {\"nodes\": hn, \n",
" \"epochs\": e,\n",
2021-03-26 20:01:05 +00:00
" ##############\n",
" ## TRAIN\n",
" ##############\n",
" \"history\": model.fit(dtrain.to_numpy(), \n",
" ltrain.to_numpy(), \n",
" epochs=e, \n",
" verbose=verbose,\n",
" \n",
" callbacks=cb,\n",
" validation_split=validation_split).history,\n",
2021-03-26 20:01:05 +00:00
" ##############\n",
" ## TEST\n",
" ##############\n",
" \"results\": model.evaluate(dtest.to_numpy(), \n",
" ltest.to_numpy(),\n",
" callbacks=cb,\n",
2021-03-19 17:21:00 +00:00
" batch_size=batch_size, \n",
" verbose=verbose),\n",
" \"optimizer\": model.optimizer.get_config(),\n",
" \"loss\": model.loss,\n",
" \"model_config\": json.loads(model.to_json())\n",
" }\n",
2021-03-19 17:21:00 +00:00
"\n",
" if return_model:\n",
" response[\"model\"] = model\n",
"\n",
" yield response"
]
},
{
"cell_type": "markdown",
"metadata": {
2021-03-26 20:01:05 +00:00
"id": "r-63V9qb-i4w",
"tags": [
"exp1"
]
2021-03-19 17:21:00 +00:00
},
"source": [
"## Single Iteration\n",
"Run a single iteration of epoch/layer investigations"
]
},
{
"cell_type": "code",
"execution_count": 17,
2021-03-19 17:21:00 +00:00
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"executionInfo": {
"elapsed": 313592,
"status": "ok",
"timestamp": 1615991770468,
"user": {
"displayName": "Andy Pack",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjA4K4ZhdArHXAFbAGr4n0aCv2HmyUpx4cy6zcUq34=s64",
"userId": "16615063155528027547"
},
"user_tz": 0
},
"id": "ZmGFkE9y8E4H",
2021-03-26 20:01:05 +00:00
"outputId": "243fb136-bc07-4438-afb7-f2d21758168d",
"tags": [
"exp1"
]
2021-03-19 17:21:00 +00:00
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Nodes: 2, Epochs: 1\n",
"Nodes: 2, Epochs: 2\n",
"Nodes: 2, Epochs: 4\n",
"Nodes: 2, Epochs: 8\n",
"Nodes: 8, Epochs: 1\n",
"Nodes: 8, Epochs: 2\n",
"Nodes: 8, Epochs: 4\n",
"Nodes: 8, Epochs: 8\n",
"Nodes: 16, Epochs: 1\n",
"Nodes: 16, Epochs: 2\n",
"Nodes: 16, Epochs: 4\n",
2021-03-26 20:01:05 +00:00
"Nodes: 16, Epochs: 8\n"
2021-03-19 17:21:00 +00:00
]
}
],
"source": [
2021-03-26 20:01:05 +00:00
"# es = tf.keras.callbacks.EarlyStopping(monitor='val_loss', mode='min', patience = 5)\n",
"single_results = list(evaluate_parameters(return_model=False, validation_split=0.2\n",
" , optimizer = lambda: tf.keras.optimizers.SGD(learning_rate=0.5, momentum=0.5)\n",
"# , callbacks=[es]\n",
" ))"
2021-03-19 17:21:00 +00:00
]
},
{
"cell_type": "markdown",
"metadata": {
2021-03-26 20:01:05 +00:00
"id": "mdWK3-M6SK8_",
"tags": [
"exp1"
]
2021-03-19 17:21:00 +00:00
},
"source": [
"### Train/Test Curves\n",
"\n",
"For a single test from the set"
2021-03-19 17:21:00 +00:00
]
},
{
"cell_type": "code",
2021-03-26 20:01:05 +00:00
"execution_count": 68,
2021-03-19 17:21:00 +00:00
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 517
},
"executionInfo": {
"elapsed": 314527,
"status": "ok",
"timestamp": 1615991771417,
"user": {
"displayName": "Andy Pack",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjA4K4ZhdArHXAFbAGr4n0aCv2HmyUpx4cy6zcUq34=s64",
"userId": "16615063155528027547"
},
"user_tz": 0
},
"id": "F9Xre1o6SesD",
2021-03-26 20:01:05 +00:00
"outputId": "d6b817aa-58cd-4510-807f-e5e6bcf62f18",
"tags": [
"exp1"
]
2021-03-19 17:21:00 +00:00
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2021-03-26 20:01:05 +00:00
"Nodes: 2, Epochs: 4\n"
2021-03-19 17:21:00 +00:00
]
},
{
"data": {
2021-03-26 20:01:05 +00:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA1gAAAGuCAYAAACNy6eFAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAArEAAAKxAFmbYLUAACAaUlEQVR4nOzddZxU1f/H8dfZJpalY5eU7u62SJVUsQMVMb4WtvwwsMUEJFQMREJQJERBlqW7uxtZuhd29/z+uAMsuMDCxp2ZfT8fj3kwc2s+Z2aYu585536OsdYiIiIiIiIiqRfgdgAiIiIiIiL+QgmWiIiIiIhIGlGCJSIiIiIikkaUYImIiIiIiKQRJVgiIiIiIiJpRAmWiIiIiIhIGlGCJV7BGPO1MebFFG670hhTP71jyijGmOLGmPgkjycaY+64xLbNjDEbrvF5Ghtjll5rnCIicp7OWzpviVyK0TxYkhrGmGNJHmYDTgBnP1QVrLXbMj6qjGeMaQBMAApaa08lWZ4d+Beob61ddol9iwMbrLVBKXieZsBga22pFGy7BbjHWjsjBU24ZsaYITjxv5OezyMikhZ03nJk5vNWkucbCDwAFLLW7s+I55TMQT1YkirW2uxnb0AcUDHJsm0AxuHXnzVr7SxgP9DmolXtgY2XOkmJiEjG0nnLkdnPW8aYUKATcAxItvctnZ73ikmp+D6//vIQ9xhjhhhjvjLG/IPz62BJY8xDxph1xpijxphlnl+1km7/uuf+A8aYf4wx/Y0xR4wxq4wxNZJsu8UY0yjJfl8YY6Z4jvuXMSZ3km0fMcbsMMbs8dy3xpjCycQ72Bjzfxct22SMaWSMyecZ/nDIGLPPGDPsEs0eCtx10bK7gZ+MMbmMMX969o81xgz0fLkn99pFG2Pu8dwPNMZ8bozZb4xZC9S7aNsvjTG7PLH9ZYwperY9QFHgL2PMMWPM3RcP0zDGVDTGTPfsu9AY0/Ci1/h5Y8xqz/qvLtHmSzLGhBlj+npe+23GmJ5n/2AxxtQzxiz2vL87jTHPXm65iEh603nrnMxy3moDHAfeB+65KMYSxpjxnjbsNsY87VkeZIx5yxiz1Rhz2BgT7Vn+n2GQSd83T2wvGmNWAxsu9zpc6vmNMYU9n61sSbZ70Bjz1xXaKS5QgiXp6U6gBxAObAH2ADcAOYEvgV8u9WUNNAZigFzAaODTyzzP7cCzQD4gEPgfgDGmMvAR0A4oATS4zDGGe46DZ986QDAwE3ge2AzkBaI8sSfnJ6C1MSbCc4wCwPXAzzj/1/p69q8C1AIev0w8Zz0GNAcqev7tctH6mUB5oBCwA/gCwFrbFdgG3Oz5VXZo0p2MMSHAH8BInNftQ+APY0yuJJu1w3kfKgG3G2OapyDepN7wxF0eaIRzArvPs+4z4GNrbQ7P8aOvsFxEJCPovJV5zlv3ACOAX4C6xpjrPM8TBIwH5nvaXtYTM8BLQEvPc+QGel7x1TivA9DM03a4xOtwqee31u4AFgC3JjlmF5z3SryMEixJT79aaxdaa+OttWestROstduttQnW2kE4Y95LX2LfNdbaYdbaBJwvj6qXeZ6R1tplnjHkvybZtiMw2lq7wFp7ErjcNUL/APmMMRU9j2/3HNcCZ3C+AItYa+M8wyr+w1q7DljueV5wTtQzrLU7rLX7rbV/ePbfDQzASTqupDPQx1q7x1q7i4tOktbaX6y1hz3t+yCFxwSoCwRYa7/wvDfDgbU4J46zPrPW7vN8qUdz+fcgOXcCb1prD3qG3XzC+RPtGaCUMSa3Z/3iKywXEckIOm9lgvOWMSYn0BoY7jk/zeF8L1ZdnAT7LWvtKWvtEWvtQs+6B4DXrLXbPJ+JmBTGDvC5tfZfT7sv9zpc7vl/wnMe9STDDXCSefEySrAkPe1I+sAY084Ys8jTHX4IyA/kucS+/ya5fwLIfpnnudS2BS+K4YJ4kvKcEEcBdxhjDM4JYrhn9Uc4v6pNM8
asMcY8fJlYfuL8cIu7PY8xxoQbY37wDPs4AvTh0m1PqhCwPcnjpPcxxrxmjNngOea8FB4TIPLiYwFbPcvPupr34FLPkfRi8aTH74rz6+YGY8wMc7661qWWi4hkBJ23Msd563Zgl7V2nufxLzhtBygMbLXWJiazX2GcnsFrcfFn61Kvw+WefxTQ1NNr1xn401p75BrjkXSkBEvS07kSlZ4hFcOA14A81tqcwF7ApOPz78HpXj/rP2PYL3J2uEU9INFaOxfA8+vR/6y1RXF+vfry7FCCZPwCNDLGNMUZojDKs/w5nCEN1TzD354jZW3fDRRJ8vjcfc9zdMf5FS4CqHPRvpcrEbrrouOCM/Z9VwpiSqldnmP+5/jW2rXW2ttx/lj5BeezccnlIiIZROetzHHeugeINM51bnuAXkAZzzDL7UAxT9J6se1A8WSWHweynH3g6V26WNLP1uVeh0s+vyeZmoTT49gF5xo68UJKsCSjhAIhOCcnjDH/w/niTk9jgI7GmBrGmDDg1StsPx2nW/4dnHHZABhj2hhjrvN82R3G+ZJMSO4A1tq9OMM2fgDGJfllKRzn17TDxphiOF+sKTEKeNYYU8AYUwh4Msm6cJxhIPtwSg2/ftG+e0n+RAAw19O2Jz0X7XbGGQv+ZwrjuliQcYpanL0F4Zz43zDOhdJFcE7Ov3ie925jTB5rbTxwFM/reanlIiIu0HnLD89bnrY0xLk+rJrnVhHnuqd7cHqTjuKcv8KMMTmMMTU9uw8B3jHGFDFOMY8mnuXrgFzGmKaexPyNK4Rxudfhcs8PTg/js562T7iatkvGUYIlGcLzhd0D55eXPThd4dc08eBVPOdS4GWci2K3AGfHMMddYvtEnItnryfJiQooA0zF+cIbDzxjrd16maf+CedXtZ+SLPscZ9jEQZzx9mNS2IwBOBdNr8YZT/5LknV/4lwkuxVnDP3FY+w/AN73DG25oEqUtfY0zoWyXXDK9L4C3GqtPZjCuC72f8DJJLevgbdxxsevAWZ7Yv/es31rYK0x5ijwNOeLX1xquYhIhtJ5y2/PW3cD0621sz3Xie2x1u4BvuJ8ufa2ONc37cY5j50drv4RMMUT936cni+stYdxCpWMwBlCOP8KMVzydfD8wHip5weYCBQAxlhrk/1ciPs00bBkGsaYssAyIMzqgy8iIl5O5y1JjjFmBfA/a+0Ut2OR5KkHS/yaMaatp4s9AngPGKuTlIiIeCudt+RyjDE3AVlxeijFSynBEn93B05VoS04n/enXI1GRHyOMWaMMeagMWbUJdbXMcas9FQEu5p5cUSSo/OWJMsYMxxnyOXTl6gyKF5CQwRFREQuwxjTDOei9PuttZ2SWT8feBhYiXNdxSPW2uUZGaOIiHgP9WCJiIhchrU2GqdYwH8YYyKBIM+ksQk4vy63zcDwRETEywS5HUBSBQoUsCVKlEjVMY4dO0b27Fc7H6rvUPt8m9rn2/y9fZD6Ns6dO/dfa23BNAzJ20UCO5M83gk0TW5DY0xXnMm0yZo1a93ixYun6okTEhIIDAxM1TG8mdrn29Q+36b2XdmqVasueb7zqgSrRIkSzJkzJ1XHiImJoUmTJlfe0Eepfb5N7fNt/t4+SH0bjTFb0i4a/2KtHQwMBqhXr57V+e7y1D7fpvb5NrXvyi53vtMQQRERkWu3C4hK8jjKs0xERDIpJVgiIiLXyFq7C0gwxlQxxgQCd+JMEisiIpmUVw0RFBER8TbGmMlAVSCbMWYH0Bl4A+jqSbCeBIYBYcCPqiAoIuIF4uPh+HE4duw/t3zz50OJElCkSLo8tRIsEclU4uPj2bFjB6dOnbrqfcPDw1mzZk06ROU9rqaNhQoVIiIiIp0jcp+19sZkFrdOsn4OUDG1z3O1n01//zxea/vCwsIoXLgwQUH6E0fEZ5w5899E6OjRZJOjFN9Onrzk05UHqFhRCZaISFrYsWMH4eHhFCtWDGPMVe179OhRwsPD0yky75
DSNp46dYodO3ZkigQro1ztZ9PfP4/X0j5rLQcOHGDHjh2ktkqjiCTDWoiLS3mSk9Ik6fTptI0zMBAiIiB7ducWHn7
2021-03-19 17:21:00 +00:00
"text/plain": [
2021-03-26 20:01:05 +00:00
"<Figure size 1050x490 with 2 Axes>"
2021-03-19 17:21:00 +00:00
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
2021-03-26 20:01:05 +00:00
"single_result = random.choice([i for i in single_results if i[\"epochs\"] > 1])\n",
2021-03-19 17:21:00 +00:00
"single_history = single_result[\"history\"]\n",
"\n",
"fig, axes = plt.subplots(1, 2, figsize=(15,7))\n",
"fig.set_dpi(fig_dpi)\n",
"\n",
2021-03-26 20:01:05 +00:00
"################\n",
"## LOSS\n",
"################\n",
2021-03-19 17:21:00 +00:00
"ax = axes[0]\n",
"ax.set_title(\"Training vs Validation Loss\")\n",
"ax.plot(single_history['loss'], label=\"train\", lw=2)\n",
"ax.plot(single_history['val_loss'], label=\"validation\", lw=2, c=(1,0,0))\n",
2021-03-19 17:21:00 +00:00
"ax.set_xlabel(\"Epochs\")\n",
"ax.grid()\n",
2021-03-19 17:21:00 +00:00
"ax.legend()\n",
"\n",
2021-03-26 20:01:05 +00:00
"################\n",
"## ACCURACY\n",
"################\n",
2021-03-19 17:21:00 +00:00
"ax = axes[1]\n",
"ax.set_title(\"Training vs Validation Accuracy\")\n",
"ax.plot(single_history['accuracy'], label=\"train\", lw=2)\n",
"ax.plot(single_history['val_accuracy'], label=\"validation\", lw=2, c=(1,0,0))\n",
2021-03-19 17:21:00 +00:00
"ax.set_xlabel(\"Epochs\")\n",
2021-04-06 17:29:15 +01:00
"# ax.set_ylim(0, 1)\n",
"ax.grid()\n",
2021-03-19 17:21:00 +00:00
"ax.legend()\n",
"\n",
"print(f\"Nodes: {single_result['nodes']}, Epochs: {single_result['epochs']}\")\n",
"# plt.tight_layout()\n",
2021-03-26 20:01:05 +00:00
"# plt.savefig('fig.png')\n",
2021-03-19 17:21:00 +00:00
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
2021-03-26 20:01:05 +00:00
"id": "0IQ7HfJCSDud",
"tags": [
"exp1"
]
2021-03-19 17:21:00 +00:00
},
"source": [
"### Accuracy Surface"
]
},
{
"cell_type": "code",
2021-03-26 20:01:05 +00:00
"execution_count": 69,
2021-03-19 17:21:00 +00:00
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 705
},
"executionInfo": {
"elapsed": 315450,
"status": "ok",
"timestamp": 1615991772345,
"user": {
"displayName": "Andy Pack",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjA4K4ZhdArHXAFbAGr4n0aCv2HmyUpx4cy6zcUq34=s64",
"userId": "16615063155528027547"
},
"user_tz": 0
},
"id": "X3MWHLxJElbc",
2021-03-26 20:01:05 +00:00
"outputId": "134671d0-bfd3-4ee6-aa02-1a2a5b23f3ca",
"tags": [
"exp1"
]
2021-03-19 17:21:00 +00:00
},
"outputs": [
{
"data": {
2021-03-26 20:01:05 +00:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAfUAAAFWCAYAAABwwARRAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAArEAAAKxAFmbYLUAADBx0lEQVR4nOydd3hcV5n/P2eapJlR771asiz3RkISkpCErA0YQkJICLBLCyVh+UEgECCh7bILIbALoSxkIUBCQm8hgVASlpLYVrcl925JttXbSKMp5/fHzB2PRjOjO9KMms/neeaxNbede+fe+z3ve97zvkJKiUKhUCgUiuWPYbEboFAoFAqFIj4oUVcoFAqFYoWgRF2hUCgUihWCEnWFQqFQKFYIStQVCoVCoVghKFFXKBQKhWKFsCiiLoR4VAjxCR3rXSOEOLoQbVIsLkKICiGEO+jvDiHE5f7/CyHED4QQQ0KIX/i/+7wQYkAI0bRYbU4EQoirhBBtOtf9mBDi4US3aT4IIe4QQvxmjts+L4R4UwLaFPUaCyH+RQjxx3nsXwohSvz/f0YI8YagZdPuWyHE+4QQF4QQfXM93nJGrxYsJAv1bhFCnBRCXBnv/c4q6v4DjwshbEHfWYUQo0KIk/FukF5CRWAe+5nXA6xIDFLKBinlC/4/rwKuAAqklDcJIcqA9wI1UsotC9UmPffcfIVISvlXKeUGnet+Tkp591yPtRBIKR+XUr56sdsRTOg1TtTL1X+sHVLKH/mPM+2+FUJYgC8AV0opcxJx/EgEdzwUF1msd0s80WupdwGvDfr7NUBP3FujmIYQwrTYbYgHcTiPMuC4lHIy6O/zUsqBRWjLvFjs4y8kl9K56iT0vs0DzFLKw7HuSF3b6cTxesz53bJkkFJG/QAngU8DTwd991vgk8DJoO8agL8CQ0ATcEXQsmrg78Ao8DPgR8AngpbfBRwB+oDvATb/99cARyO06zAggTH/pwww+tt6CjgPPASY/OtfBrQAI/g6KR8AqoBJwO3fR0eEY30V6Paf27NAWdCySv/16MfX0flX//cm4DP+tgwDz0c6J/95lARd73uBA9r1jfX4QIn/PG1B670VeDbC+YX97YA3Ac+FrPtd7bfzX3Pt2AeAfwq5b6adR5jjfsz/O530t9sdsv2VwB0hv9EngQnA6//7If/6twAdwADwayAv+Hr774s+4N+BFOBh/zU9C3w06LiPAl8B/oTvfn0WyIp0z4Wcz/2Ax9/eMf/5hTt+NfB//uvdDXwuaB/T7g//8d4DnPBvf1/Qsk8Bj4Sc5wP+a3ASuDFo3VXAC0R4BkPO43ngs/iel0F8z2RK0PJYrvW/AH8M2vYq/36HgL8A9UHLtgHt+O7db/qv0Zv8y14FHPK3/yRwW5h23wD8I+Qd8R3//22AA7AGX2PgEXz3ksP/m93hb/OfgW/429IJbI7yjnwbcAY4B9zJ9Of5eXzP0VVMv2+/C4xz8X76iX/9q/E9g0P+bav931fgewbeje/99RjR33efAh4HfuK/ZruBSv+yZ/3HHfcf+6oI98CngUb/NfgRkORfFvqbVjD92ZX43ukn/efxLuCl/us4CNwf8rx9Fd+9MAL8Bv/zNofrUQv8zb+f88CDEX6vZOBr/t/rNL5nxhDmN3oozLaPEuH94F/+OnzvvEHgKaA4aNkOfM/HAH79xOelgejvpFnv/WltjLYw6OV6Nb6bNtf/OYPPHaqJjgU4ju/FbAbe4G94pn/5HuBz/vVeC7i4KAyvB/YB5f4T+yHwxUgCGOlG8n/3YXwPYy6QATwH3O1f9iJwh///mcCmcDdohGPdBqT72/cd4Jf+7034btRP+W+UNGCLf9nH/eetdTZeFumcmCnqLwL5+F+kczz+n4Hbg47xLPAvYc4t4m8HpOK7ufOD1h3EJxAGoM2/nQm4HOgNWnfGeYQcdye+h7HW/1v9kTCiHuElMu0aAtvxPQjr/OfwBeCnQe
u68T24Zv81/Bq++8wOFOETqFcFPbTngPX+a/on4NOR7rkIL8M3hbQ19PjV+J4pk/9angZeG+HcJL4Xsx1Yi6/DoL3YPsV0UXcDH/Hv907gVNB+GvG9pM3Aq4Epoov6qaDf5jngs3O81oHfDsjGd/+8zr/8w/g68yZ899YZfB0YM/A+/740UT/Hxc5mAbAmTLvt+MQ5Bd99dwI45F92HfBihGt8Ev+9FnS/uYDb8T27/wb8JcK1WotPRF7iP+73CSPqEY5bwfR7vhTfM3SV/7jvA/YGrSvxdXaS/ceK9r77lP9aXOu/vt8HvhfunRPlHjiA772cge8ZeWuE5zH0PCTwJL4O1LX4hPJnQBaw2v93VdDzNhR0/X4I/GCO1+NJ4D5A4OvEbY9wbv/uP79MfO/nw/jfjaG/UZhtHyXy+2E1PgPuSiAJn/g/51+Wi0+UX4XvXv8Cvvtbe8dFeyfNeu9Pa2O0hcE3PPBl4G7/57/wWb6aqF9FiDWGzyq43X9TTALJQcv+xkVR/x3wxpCHRNtvxAsceiP5vzsIvDTo71dx0UL+K77eUVbINv/CLKIesn4d0Of//xX4XkSGMOsdAW4I8/2Mc2KmqN8eh+O/Dfi1///5+MQ5Lcx6EX87//9/DtwVdD2b/f+/DDgcst1PufhwzHYe3wU+FfT39cxd1L8JfCzo71R8L2WTf91xLlowAt/Lriho/buBR4Me2q8GLXsvFztRM+65MOf1PDNFPXD8CNv8BxE6sv57Y0vQ33u42AH4FNNFfVi7F/C9UCW+F3IFvhdpUtB+/kp0UQ/9bQ7Heq1DfzvgzQSJI76OYRe+DuHVwImgZQLfva0J4hngHYB9luu/F5+Q3ILP8tmP7/7/VJRrfJKZor4v6O81wFCE431Su3f8f9cwd1H/KPCtkP33+ter8O+3MGhZtPfdp4DfBC3bCbSG3Fezifo9QX9/AfivCM9j6HlIgjwb+Kzm1wX9vZuL9/CjYa7fpP/3j/V6/ADf/VkY6bz86x0Drg36+13A78P9RmG2fZTI74f7md5xsuN7NgqBf9Z+m6Dncwqfts72TtJ172ufWKLfHwfeiM899XjIsiL/gYM55f++EOiVF8dDCVm3DPgff2TzED7Bz42hXcGUAc8E7etxfONW4LsoDcBRIcTftMhqPQghPi6EOCqEGMH3Us32LyrBZw15w2xWgs9SmAtn43D8nwJXCyEy8XlDfielHAmzXrTfDny9Xy169w343HDgu9aV2rX2X+9/wvd7hz2PEApDjhvahlgoAz4e1I4z+HrBBf7l56SUWoBbLr5efWfQ+p/D9+LXOB/0fwe+h3M+BB8fIUSxEOIXQohzQohh4P9x8TcNh9729Gr3gpTS4f/Oju869EopnUHrRvttYOZvo/2usVzrUIrweSXwt9Hr3157T5wNWiZD2ngLcBNwVgjxOyFEfYRj/BVfR/UqfO+S4L//GuV8Q9F7zeN9H7855JmyAcX+5V4pZU/I+pHedzD/+3g+218I+v9EmL+D9xV6/ZLwWfWxXo978VnBrUKIFiFEpADNafch0993eoh0XULv7zF8Q5Pa/X0maJnDvwxmfyfpvfeBGKa0SSkb8V3oTCnl3pDF3fhcJcGU+b/vAXKEEMlBy4LX7QL+WUqZEfSxMTsyzHdd+Hpg2n7SpZRr/O0/JKW8Fd9N/yTwRJT9BBBCXI2vN7YTnwt8e9DiM0C5EEKE2fQMvt5kKOP4fkBt//lh1pFBy+d0fL+A/x64GZ/HJLQjphHttwPfuNBGIUQVPrftj/3fdwEHQn43u5TyP8KdRxh6Qo4b2oZY6MI3ThfclhQppSYKwe3oA5z43H/aumlSyh06jhP1XomyTuh3/4bPDV0rpUzH5/kKdw/Fi3NArj/aWmO2yOfQ30Z7ecZyrUPpxndvAb6piv59a++J0DYF/pZS7pZSvhLfi64N33h3OP4KvAyfBfRX/+dafM/N3yNso+d3jUS87+Nvh1xbq5RSa3
doOyO+7xLMtHcY0zvEcyH0+jnxDQHGdD2klD1Syrfh62B+CvhxiO5oTLsPmf6+mw+h97cNX2ddu79Lg5alcLEjH/W
2021-03-19 17:21:00 +00:00
"text/plain": [
2021-03-26 20:01:05 +00:00
"<Figure size 560x350 with 2 Axes>"
2021-03-19 17:21:00 +00:00
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"X, Y = np.meshgrid(epochs, hidden_nodes)\n",
"\n",
"shaped_result = np.reshape([r[\"results\"][1] for r in single_results], \n",
" (len(hidden_nodes), len(epochs)))\n",
"\n",
2021-03-22 20:49:29 +00:00
"fig = plt.figure(figsize=(8, 5))\n",
2021-03-19 17:21:00 +00:00
"fig.set_dpi(fig_dpi)\n",
"ax = plt.axes(projection='3d')\n",
"\n",
"surf = ax.plot_surface(X, Y, shaped_result, cmap='viridis')\n",
"ax.set_title('Model test accuracy over different training periods with different numbers of nodes')\n",
"ax.set_xlabel('Epochs')\n",
"ax.set_ylabel('Hidden Nodes')\n",
"ax.set_zlabel('Accuracy')\n",
"ax.view_init(30, -110)\n",
2021-04-06 17:29:15 +01:00
"# ax.set_zlim([0, 1])\n",
2021-03-19 17:21:00 +00:00
"fig.colorbar(surf, shrink=0.3, aspect=6)\n",
"\n",
"plt.tight_layout()\n",
2021-03-26 20:01:05 +00:00
"# plt.savefig('fig.png')\n",
2021-03-19 17:21:00 +00:00
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
2021-03-26 20:01:05 +00:00
"id": "C793_RHvSGai",
"tags": [
"exp1"
]
2021-03-19 17:21:00 +00:00
},
"source": [
"### Error Rate Curves"
]
},
{
"cell_type": "code",
2021-03-26 20:01:05 +00:00
"execution_count": 70,
2021-03-19 17:21:00 +00:00
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 668
},
"executionInfo": {
"elapsed": 316211,
"status": "ok",
"timestamp": 1615991773109,
"user": {
"displayName": "Andy Pack",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjA4K4ZhdArHXAFbAGr4n0aCv2HmyUpx4cy6zcUq34=s64",
"userId": "16615063155528027547"
},
"user_tz": 0
},
"id": "tpClZMptrq-q",
2021-03-26 20:01:05 +00:00
"outputId": "f9fe93f9-7b67-4772-83e4-9e3567fd9318",
"tags": [
"exp1"
]
2021-03-19 17:21:00 +00:00
},
"outputs": [
{
"data": {
2021-03-26 20:01:05 +00:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAf4AAAFECAYAAADGPlw2AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAArEAAAKxAFmbYLUAABd10lEQVR4nO3dd3wc1bXA8d/Rqjd3reQq44otuWjdcaMZbExLKAmmOHl2Qgu9B/IgAfLoJZTQguPQITQbGwy4CBuMG+4FbFxwlauKJVntvj9mJK+kXWkl72ql1fl+PvuRdubOzLmzM3tm7szeEWMMSimllGoewoIdgFJKKaUajiZ+pZRSqhnRxK+UUko1I5r4lVJKqWZEE79SSinVjGjiV0oppZoRTfzNhIhcJCK7RCRPRJIaQTz3iMhzfpjPZBH5qh7TdRaRIye6/PryV/3ruMxkEflWRHJF5BYfys8Xkcvt/yvFW3V7EpFRIrLFfj84kPVozETkfhF5NdhxlBORaSJyr5dxFZ+vh3GjRGRVDfOtadp67ZMNRUTGisjmIC37nyJyh49l14nI8EDEER6ImdZGRPLc3sYB+UB5hwJ9jDE76jCvacBmY8yD/osweAJYn0eB3xlj5vh5vvVijHk4yMvfAbQsfy8i84FXjTFv+HtZIpKK9ZlW7G9Bqv8fgG3GmBF1ndBDvJW2JxF5B3jEGPPyiYfpu0B+bs2ZMeYboH+w42hMPO3HdWWMuboOZfvWdzm1CUriN8bEl/8vIoVAX2PMtmDE4k5Ewo0xJd7e13V6D+MFEGNM2QmGWh+dgfV1naiu66A5akLrqF7bgI/z0u1LNWtB/n6vG2NMUF9AIZBq/98aeAvIAn4GrnIr93tgO5ALbALGAlcBxcAxIA/4p5dlXASsAw4BnwJJ9vCxwGbgAeAA8BAwH/gbsMyebwTwK2ADcBiYCXSwp08FSoCrgV3AGx6WPQ14DpgLFAA97Lr8aNdlNTDWLuuxPkA6kGkvfzkwyB4eBjxrx34EWAq09RBDHlaLylFgiT1sFPCDPd0C4GS38ga4HtgKfONhft3seI4Au4GHa/h8J9qfVy6wDfiNPfx+rDM198/hL/ZntA04y20ePYDv7Hn8F3gXuNceNxn4yq2sx3XlIa5UoMT+/z6gFGtbzAPusYePsedxxN4uunn73IFWwOf2Z7EfeBmIssv/aK/TPPvV2b3+dpnatrGpwB5gL277hYd69QW+sWNeDpxiD/8nlbetgR6mHYy1PebY5TOByz18XpW2J6x9qwyr5W6vXaYz8Blw0K7X2W7L2QbcYQ/f5uM+Wm3b8Pa5eajXdcBP9mfzbyDObduZC7xi13klMKC2dWmPa8fx76oDwP+5rac3gfexttfvga72uBjgbbseh/Cwb9nlatqWvK4Pt31zER72FQ/LmY/13bfMrv+7VZfj47YRZ9f5CLAC63vUfZ+sbT/ydds2wDVY30sHgLvdxkUDz9vz2GGvnzB7nAN4Bmtb3ATcVaVuvn5neNqPp+Hj97tbPnD/7poLvGiv1/VARpX9ZKTbdM8CX9vznQO0dis7Fdhp13+qHWdHr+vS24iGelE58X8GPA5EAb2xkko/e8PKAXrY5bpwfGeqWJFe5j/EXiHpWEn8UeADt427xN5IIrB2zPlYO1Z3e2M6GcgGRtpxPQvMc9twDdaOEA3EeFj+NHsjdWG1sEQAE4BO9gY51f6wojzVB4jHSi6/tstfgLVhRwNnY+20ifa4DCC+hp2mo/1/G6yN/Fd2PLdjfTGGu5X9xJ6vpzp1w9qZw7E29B3ABV6Wu5fjyScZ61IOVE/8JcCd9jz/AGx3m8cyrC+oCOBcoAgPib+mdeUhrlTsxO/2JXi52/tOWF+6o+x5/QlY6u1zt9fpufY2koL1BXiTp2V5qH9vat/GngEigXFYXzoJHuoUiXXAfIO9ri7FSg6tattX7Gl/wfpijbDrW4KHxF91e/LwJRUGrLLjCA
eG2+vS6VZ2MeC0150v+6i3baPS5+ahXhcDa7C+M2KwkvXjbttOSZU6/2wvp7Z1OQd4FUiw5zvcbT3lA6fa85kO/NsedzXWQU2MPW6Ul5hr2pZqWx9LgIft+C/AOtirKfFvsNdNS6wDr9+5LWezj9vGo8BXWN8Xve2y5fukL/tRrdu22zb3PtZ+noaVO8oPIspP2lphJeQfgcn2uGuxEnAy0B5r2yyvW72/M070+x1r+ysGfmuXfRBY4GWfmmbPpx/Wd87XwAP2uHSsg6pBWNvW6zSVxG9/KEeBCLdxj2PtSHFYX4wXlK/AKiu+psT/T9zOBLB21GL7QxprLzO8ys7gfiR5H/aO67ahFGPtkKn2Ck6pYfnTgJdqWQd7gDRP9QF+A8ypUn6ZHfvpWEewQ7CamGpahnviv6LKBhaGtfEPdys7vA6f4d+xv0w9jPsFmEKVAxKqJ/5sjh+hx9oxtLTXcYH75451FuYp8XtdV7XtxFRP/HcBL1eZZr89nS+f+x85nrwqLctD/X3Zxtq4jc/C7czUbfgo7DNot2HfAb+tbV/BOpDb6vZe7M+uPol/GPBjlfl/wPEv4m3lMfm4j3rcNjx9bh7q9Tlwmdv7NI63Mkz2UucRNa1LoAPWwWech+XdD8xwez8BWGn//z9YZ+N9fd23PGxLXtcHVgIvxC1pAQtr+MznA7e6vX8UeNptOeXJsbZtYyuVz2of5Pg+6ct+VOu27bbNudzeL8E+4QC2AKdWWWdf2P/PK9/27PdT3OpW7+8Mt32qXt/v9va3xq1cH+CIl31qGvAPt3HXAh+7bXP/chvXjVoSf2O6q78z1pHMfhE5Yt9x/Ucg2RhzFGuHuwHYJyLvi0j7Osz3z27z/AXraDXZHr/XVL/GuNPt//ZYR4AAGGPysJqMypdfZozZU0sM7vNDRC4QkRVuMSVhHeV7i39MeVm7/MlAe2PM11hfmi8De0TkcRGJqCUWT3Uqw1ov7ut0Z9WJ3OLvICIficheEckGbqoh/ouAC4GdIvK5iJzspdx+Ow6MMfn2sHisz2m/MeaYD7F5XVfe6lKDzsAVVeYVh/WlD1U+dxFJEJHpIrJTRHKAJ/G+TqqqbRsrNcYcdCufj7VuPM3nlyrDtuNb/VNwW6/G+gbxug3UojPQtcq6O9teRrmdVcrXtI962zZ8jeUlt3kvxGqmrxaHW51TqHlddgSy7O8lT/a5/e/+Wf0H60ztIxHZLiJ3e5rYh23J2/pIsccVupWtWgdfY3VX27aRUmU57v/Xth/5um3XFm+lfYjK231t8Z3od8aJfL/7sv5rK5tcJYZa99vGlPh3YTXztDLGtLRfCca+C9IYM8sYcxrWTncMqzkLjv8aoKb53uc2z5bGmBhjTPnK8TS9+7DdWBsHACISh/Uh7vZx+ZXKiEgU1nW+P2Md6bbEOsoVL/PbhXXk6h5/nDHmLQBjzFPGmAFY1+DOAib5EE/VOglW09RutzI11etBrEsFPY0xLYCn3eKvxBjzvTHmHKxm3VVY17PqYi/QTkQi3YZ19FK2xnVVC0/r/ZUq84o1xizyUv4WrIQywBiTaL/39plWVds25qvdWJ+ju84+zmcP1dert/Vcm13AhirrLt4Y83e3MqZK+Zr20Zr4sv9fVXWbcBvvqc57qHld/oK1Tcb6EN/xQI0pMsb8xRjTE2tfvVFExnooWtO2VJM9QFsRiXYbVrUO9VHbtrGnynLc/69tP/KXSvsQlbf72uLz9TvD27ZWl+/3QNjL8QMp8GG/bTSJ3xizC6sp7UERiRWRcBHJEJE+IuIUkYkiEoOV9POxbuoBa6Wm1jDr14HrRaQ/gIi0FpHz6xDaB8AFIjLCTj4PAt/6cJbvTRTW9awsO54bqXwGUrU+M4GB9lFkuIjEiMjZItJCRAaJyGARCce64aOY4+ulJrOB/iJyvj3tzVjN6ct8rEOCvbw8EUkDvP2eN1JELhORRDu2PB/jq2CsX3tsAO4SkQgROQcY6q
W413Xlw6Kqrve3gIvF+j1zmH0WdlEN0ydgbZfZItIFqymu3AEgTES87ZD+2sa+BxCR6+36X4x19vK5D9N+B0SIyB/
2021-03-19 17:21:00 +00:00
"text/plain": [
2021-03-26 20:01:05 +00:00
"<Figure size 560x350 with 1 Axes>"
2021-03-19 17:21:00 +00:00
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
2021-03-22 20:49:29 +00:00
"fig = plt.figure(figsize=(8, 5))\n",
2021-03-19 17:21:00 +00:00
"fig.set_dpi(fig_dpi)\n",
"\n",
"for layer in hidden_nodes:\n",
" plt.plot(epochs, \n",
" 1 - np.array([i[\"results\"][1] \n",
" for i in single_results \n",
" if i[\"nodes\"] == layer]), \n",
" label=f'{layer} Nodes')\n",
"\n",
"plt.legend()\n",
"plt.grid()\n",
"plt.title(\"Test error rates for a single iteration of different epochs and hidden node training\")\n",
"plt.xlabel(\"Epochs\")\n",
"plt.ylabel(\"Error Rate\")\n",
"plt.ylim(0)\n",
"\n",
2021-03-26 20:01:05 +00:00
"# plt.savefig('fig.png')\n",
2021-03-19 17:21:00 +00:00
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
2021-03-26 20:01:05 +00:00
"id": "7mJaKjlCxEkt",
"tags": [
"exp1"
]
2021-03-19 17:21:00 +00:00
},
"source": [
"## Multiple Iterations\n",
"\n",
2021-03-22 20:49:29 +00:00
"Run multiple iterations of the epoch/layer investigations and average\n",
"\n",
2021-03-26 20:01:05 +00:00
"### CSV Results\n",
2021-03-22 20:49:29 +00:00
"\n",
"| test | learning rate | momentum | batch size | hidden nodes | epochs |\n",
"| --- | --- | --- | --- | --- | --- |\n",
"|1|0.01|0|128|2, 8, 12, 16, 24, 32, 64, 128, 256|1, 2, 4, 8, 16, 32, 64, 100, 150, 200|\n",
"|2|0.5|0.1|128|2, 8, 12, 16, 24, 32, 64, 128|1, 2, 4, 8, 16, 32, 64, 100|\n",
"|3|0.2|0.05|128|2, 8, 12, 16, 24, 32, 64, 128|1, 2, 4, 8, 16, 32, 64, 100|\n",
"|4|0.08|0.04|128|2, 8, 12, 16, 24, 32, 64, 128|1, 2, 4, 8, 16, 32, 64, 100|\n",
"|5|0.08|0|128|2, 8, 12, 16, 24, 32, 64, 128|1, 2, 4, 8, 16, 32, 64, 100|\n",
"|6|0.06|0|128|1, 2, 3, 4, 5, 6, 7, 8|1, 2, 4, 8, 16, 32, 64, 100|\n",
"|7|0.06|0|35|2, 8, 12, 16, 24, 32, 64, 128|1, 2, 4, 8, 16, 32, 64, 100|\n",
"\n",
2021-03-26 20:01:05 +00:00
"### Pickle Results\n",
2021-03-22 20:49:29 +00:00
"\n",
2021-04-29 22:53:26 +01:00
"| test | learning rate | momentum | batch size | hidden nodes | epochs | statified |\n",
"| --- | --- | --- | --- | --- | --- | --- |\n",
"|1|0.01|0|128|2, 8, 12, 16, 24, 32, 64, 128, 256|1, 2, 4, 8, 16, 32, 64, 100, 150, 200| |\n",
"|2|0.5|0.1|128|2, 8, 12, 16, 24, 32, 64, 128|1, 2, 4, 8, 16, 32, 64, 100| |\n",
"|3|1|0.3|20|2, 8, 12, 16, 24, 32, 64, 128|1, 2, 4, 8, 16, 32, 64, 100| |\n",
"|4|0.6|0.1|20|2, 8, 16, 24, 32|1, 2, 4, 8, 16, 32, 64, 100, 150, 200| |\n",
"|5|0.05|0.01|20|2, 8, 16, 24, 32|1, 2, 4, 8, 16, 32, 64, 100, 150, 200| |\n",
"|6|1.5|0.5|20|2, 8, 16, 24, 32|1, 2, 4, 8, 16, 32, 64, 100, 150, 200| |\n",
2021-04-30 20:51:04 +01:00
"|2-1|0.01|0|35|2, 8, 16, 24, 32|1, 2, 4, 8, 16, 32, 64, 100, 150, 200| n |\n",
"|2-2|0.1|0|35|2, 16, 32|1, 2, 4, 8, 16, 32, 64, 100| n |\n",
"|2-3|0.15|0|35|2, 16, 32|1, 2, 4, 8, 16, 32, 64, 100| n |\n",
"|2-4|0.08|0.9|35|1, 2, 8, 16, 32, 64|1, 2, 4, 8, 16, 32, 64, 100| n |\n",
"|2-5|0.08|0.2|35|1, 2, 8, 16, 32, 64|1, 2, 4, 8, 16, 32, 64, 100| n |\n",
"|2-6|0.01|0.1|35|2, 8, 16, 24, 32|1, 2, 4, 8, 16, 32, 64, 100, 150, 200| n |"
2021-03-19 17:21:00 +00:00
]
},
{
"cell_type": "code",
2021-04-30 20:51:04 +01:00
"execution_count": 214,
2021-03-19 17:21:00 +00:00
"metadata": {
2021-03-26 20:01:05 +00:00
"id": "-lsKo4BCP3yw",
"tags": [
"exp1"
]
2021-03-19 17:21:00 +00:00
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Iteration 1/30\n",
"Iteration 2/30\n",
"Iteration 3/30\n",
"Iteration 4/30\n",
"Iteration 5/30\n",
"Iteration 6/30\n",
"Iteration 7/30\n",
"Iteration 8/30\n",
"Iteration 9/30\n",
"Iteration 10/30\n",
"Iteration 11/30\n",
"Iteration 12/30\n",
"Iteration 13/30\n",
"Iteration 14/30\n",
"Iteration 15/30\n",
"Iteration 16/30\n",
"Iteration 17/30\n",
"Iteration 18/30\n",
"Iteration 19/30\n",
"Iteration 20/30\n",
"Iteration 21/30\n",
"Iteration 22/30\n",
"Iteration 23/30\n",
"Iteration 24/30\n",
"Iteration 25/30\n",
"Iteration 26/30\n",
"Iteration 27/30\n",
"Iteration 28/30\n",
"Iteration 29/30\n",
"Iteration 30/30\n"
]
}
],
2021-03-19 17:21:00 +00:00
"source": [
"multi_param_results = list()\n",
2021-03-19 17:21:00 +00:00
"multi_iterations = 30\n",
"for i in range(multi_iterations):\n",
" print(f\"Iteration {i+1}/{multi_iterations}\")\n",
2021-04-29 22:53:26 +01:00
" data_train, data_test, labels_train, labels_test = train_test_split(data, labels, test_size=0.5\n",
"# , stratify=labels\n",
" )\n",
" multi_param_results.append(list(evaluate_parameters(dtrain=data_train, \n",
2021-03-19 17:21:00 +00:00
" dtest=data_test, \n",
" ltrain=labels_train, \n",
" ltest=labels_test,\n",
2021-04-30 20:51:04 +01:00
" hidden_nodes=[2, 16, 32],\n",
2021-04-29 22:53:26 +01:00
" epochs=[1, 2, 4, 8, 16, 32, 64, 100],\n",
2021-04-30 20:51:04 +01:00
" optimizer=lambda: tf.keras.optimizers.SGD(learning_rate=0.15, momentum=0.0),\n",
2021-04-29 22:53:26 +01:00
" weight_init=lambda: 'random_uniform',\n",
2021-03-19 17:21:00 +00:00
" return_model=False,\n",
" print_params=False,\n",
2021-04-30 20:51:04 +01:00
" batch_size=35)))"
2021-03-19 17:21:00 +00:00
]
},
{
"cell_type": "markdown",
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp1"
]
},
2021-03-19 17:21:00 +00:00
"source": [
"### Accuracy Tensor\n",
"\n",
"Create a tensor for holding the accuracy results\n",
"\n",
2021-03-22 20:49:29 +00:00
"(Iterations x [Test/Train] x Number of nodes x Number of epochs)"
2021-03-19 17:21:00 +00:00
]
},
{
"cell_type": "code",
2021-04-30 20:51:04 +01:00
"execution_count": 301,
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp1"
]
},
2021-03-22 20:49:29 +00:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"30 Tests\n",
2021-04-28 21:57:13 +01:00
"Nodes: [2, 8, 12, 16, 24, 32, 64, 128, 256]\n",
2021-03-22 20:49:29 +00:00
"Epochs: [1, 2, 4, 8, 16, 32, 64, 100, 150, 200]\n",
"\n",
"Loss: categorical_crossentropy\n",
2021-04-28 21:57:13 +01:00
"LR: 0.01\n",
"Momentum: 0.0\n"
2021-03-22 20:49:29 +00:00
]
}
],
2021-03-19 17:21:00 +00:00
"source": [
"multi_param_epochs = sorted(list({i[\"epochs\"] for i in multi_param_results[0]}))\n",
"multi_param_nodes = sorted(list({i[\"nodes\"] for i in multi_param_results[0]}))\n",
"multi_param_iter = len(multi_param_results)\n",
"\n",
2021-03-22 20:49:29 +00:00
"accuracy_tensor = np.zeros((multi_param_iter, 2, len(multi_param_nodes), len(multi_param_epochs)))\n",
"for iter_idx, iteration in enumerate(multi_param_results):\n",
2021-03-19 17:21:00 +00:00
" for single_test in iteration:\n",
2021-03-26 20:01:05 +00:00
" accuracy_tensor[iter_idx, :,\n",
" multi_param_nodes.index(single_test['nodes']), \n",
2021-03-26 20:01:05 +00:00
" multi_param_epochs.index(single_test['epochs'])] = [single_test[\"results\"][1], \n",
" single_test[\"history\"][\"accuracy\"][-1]]\n",
2021-03-22 20:49:29 +00:00
" \n",
"mean_param_accuracy = np.mean(accuracy_tensor, axis=0)\n",
"std_param_accuracy = np.std(accuracy_tensor, axis=0)\n",
"\n",
"print(f'{multi_param_iter} Tests')\n",
"print(f'Nodes: {multi_param_nodes}')\n",
"print(f'Epochs: {multi_param_epochs}')\n",
"print()\n",
"print(f'Loss: {multi_param_results[0][0][\"loss\"]}')\n",
2021-03-26 20:01:05 +00:00
"print(f'LR: {multi_param_results[0][0][\"optimizer\"][\"learning_rate\"]:.3}')\n",
2021-04-30 20:51:04 +01:00
"print(f'Momentum: {multi_param_results[0][0][\"optimizer\"][\"momentum\"]:}')"
2021-03-19 17:21:00 +00:00
]
},
{
"cell_type": "markdown",
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp1"
]
},
2021-03-19 17:21:00 +00:00
"source": [
"#### Export/Import Test Sets\n",
"\n",
"Export mean and standard deviations for retrieval and visualisation "
]
},
{
2021-04-30 20:51:04 +01:00
"cell_type": "code",
"execution_count": 215,
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp1"
]
},
2021-04-30 20:51:04 +01:00
"outputs": [],
"source": [
2021-04-30 20:51:04 +01:00
"pickle.dump(multi_param_results, open(\"results/exp1-test2-3.p\", \"wb\"))"
]
},
{
2021-04-06 17:29:15 +01:00
"cell_type": "code",
2021-04-30 20:51:04 +01:00
"execution_count": 300,
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp1"
]
},
2021-04-06 17:29:15 +01:00
"outputs": [],
"source": [
2021-04-28 21:57:13 +01:00
"exp1_testname = 'exp1-test1'\n",
2021-03-22 20:49:29 +00:00
"multi_param_results = pickle.load(open(f\"results/{exp1_testname}.p\", \"rb\"))"
]
},
{
"cell_type": "raw",
"metadata": {},
2021-03-19 17:21:00 +00:00
"source": [
2021-03-22 20:49:29 +00:00
"np.savetxt(\"exp1-mean.csv\", mean_param_accuracy, delimiter=',')\n",
"np.savetxt(\"exp1-std.csv\", std_param_accuracy, delimiter=',')"
2021-03-19 17:21:00 +00:00
]
},
{
"cell_type": "raw",
"metadata": {},
"source": [
2021-03-22 20:49:29 +00:00
"mean_param_accuracy = np.loadtxt(\"results/test1-exp1-mean.csv\", delimiter=',')\n",
"std_param_accuracy = np.loadtxt(\"results/test1-exp1-std.csv\", delimiter=',')\n",
2021-03-19 17:21:00 +00:00
"# multi_iterations = 30"
]
},
{
"cell_type": "markdown",
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp1"
]
},
2021-03-19 17:21:00 +00:00
"source": [
2021-03-22 20:49:29 +00:00
"### Best Results"
2021-03-19 17:21:00 +00:00
]
},
{
"cell_type": "code",
2021-04-30 20:51:04 +01:00
"execution_count": 302,
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp1"
]
},
2021-03-22 20:49:29 +00:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2021-04-28 21:57:13 +01:00
"Nodes: 256, Epochs: 200, 1e+02% Accurate\n"
2021-03-22 20:49:29 +00:00
]
}
],
"source": [
"best_param_accuracy_idx = np.unravel_index(np.argmax(mean_param_accuracy[0, :, :]), mean_param_accuracy.shape)\n",
"best_param_accuracy = mean_param_accuracy[best_param_accuracy_idx]\n",
"best_param_accuracy_nodes = multi_param_nodes[best_param_accuracy_idx[1]]\n",
"best_param_accuracy_epochs = multi_param_epochs[best_param_accuracy_idx[2]]\n",
"\n",
2021-03-26 20:01:05 +00:00
"print(f'Nodes: {best_param_accuracy_nodes}, Epochs: {best_param_accuracy_epochs}, {best_param_accuracy * 100:.1}% Accurate')"
2021-03-22 20:49:29 +00:00
]
},
{
"cell_type": "markdown",
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp1"
]
},
2021-03-22 20:49:29 +00:00
"source": [
"### Test Accuracy Surface"
]
},
{
"cell_type": "code",
2021-04-30 20:51:04 +01:00
"execution_count": 303,
2021-03-19 17:21:00 +00:00
"metadata": {
"executionInfo": {
"elapsed": 2653358,
"status": "aborted",
"timestamp": 1615994110345,
"user": {
"displayName": "Andy Pack",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjA4K4ZhdArHXAFbAGr4n0aCv2HmyUpx4cy6zcUq34=s64",
"userId": "16615063155528027547"
},
"user_tz": 0
},
2021-03-26 20:01:05 +00:00
"id": "ZGJVhz6iJU-7",
"tags": [
"exp1"
]
2021-03-19 17:21:00 +00:00
},
"outputs": [
{
"data": {
2021-04-30 20:51:04 +01:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA6QAAAMlCAYAAABkUz6gAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAB7CAAAewgFu0HU+AAEAAElEQVR4nOzdd3xkZ333/c81fdRW0vZevestLmt7ixvuxsa0OwVIQigpNyRASB7yJAGS3DxJaEkgMSHADTY2BoIBU4wxxcaAsY2Nvb33Vd2VtLvqmj5zPX+MRjuSRtJIOzNH5ft+vealkebMOdeMpJn5nt9VjLUWERERERERkVJzOd0AERERERERmZkUSEVERERERMQRCqQiIiIiIiLiCAVSERERERERcYQCqYiIiIiIiDhCgVREREREREQcoUAqIiIiIiIijlAgFREREREREUcokIqIiIiIiIgjFEhFRERERETEEQqkIiIiIiIi4ggFUhEREREREXGEAqmIiIiIiIg4QoFUREREREREHKFAKiIiIiIiIo5QIBURERERERFHKJCKiIiIiIiIIxRIRURERERExBEKpCIiIiIiIuIIBVIRERERERFxhAKpiOTNGPM1Y4zNuvyt022SqckYU2WM6cv6W0oYYxY53S4REREpLQVSEcmLMaYS+F9Dfvx2J9oi08LvAmVZ37uBtzrUFhEREXGIAqmI5GtogABYb4zZ4kRjZMrLdTJDJzhERERmGAVSEclXdlgIj/BzkTEZY1YBN/V/mwJi/dc3GGOuc6ZVIiIi4gQFUhEZkzFmJXBz/7cW+Ousm3/PGOMrfatkCnsbYPqv/wJ4Ius2neAQERGZQRRIRSQf2QHiWeCLwLn+72uB1zrRKJl6jDGG9N9Txlf7Lxk6wSEiIjKDKJCKyKhyBQhrbQJ4NOtnOataxhivMeZ81kyq14/juE9l3e//HWPbLcaY/zDG7DHGnDPGxIwxLcaYZ40xf2uMqcnjeHVZx1vR/7PVxpiPGmN29+83ZYzZk+O+640xf2WM+a4x5qgxpscYE++/z47+tm3I97Fn7fcuY8yjxpgGY0zEGHPWGPOcMeY9xpjy/m0+ktXuj+S53zuMMV8wxhw0xrQbY6LGmDPGmJ8aY95rjAmOt63jcDOwsv96CPgO8CPgQv/PZgP3jXenxpj5xpi/McY83f98hfsvDcaYH/fftiKP/biNMW8yxjzS/7vs6P9dXjDG/MYYc3//82dy3HdcvwtjzK1Z2/9yPNsYY15jjPmGMea4Maa3//a/HHJfrzHm1caYfzXG/KL/dxzpf16a+p+XvzTGVIzV1hztGvfz3f/cZR7L/x3Hsd6Zdb9d422riIhMctZaXXTRRZcRL6QDhO2/hIGq/p9vyfp5DJg7wv0/l7Xdf+d5zIVAov8+SWDxCNvVAI9l7X+kSwfwO2Mcsy5r+xXA/+5/vEP3tWfI/b6Vx/Et6bGS/wG483j8PtJVw9H2dwhYB3wk62cfGWO/S0l3kR2rrc3AzUX6e3ow6zhfz/r5f2f9/Pvj2J8L+EegL4/HlQQ2jPG3fjTP3+cnctw/799F//a3Zm3/y3y2AWYB3x2hTX855Hd9Ps/Hch64q9jPN7Ap67YuoCzPYz6fdb8/L8bfpS666KKLLs5dPIiIjC67+vm4tbYbwFr7ijHmCHA54AV+H7g/x/2/BvxZ//U3GWPeb9MV1tG8hfQyIAC/sNY2D93AGLMA+DmwPuvHB4G9QC8wj3TAmA1UA98yxvyhtfbrYxwb0jMK/2v/9TPAC6Q/QC8i3UU527L+rwnSIfE40En6w/g80sF9Mekuz38J+IE/H+P43wB+K+v7dtJhpJ100LiF9ON+EvhBHo8HY8x64BnSYR/SH+539bc53N/GVwGV/Y/zaWPMvdbaX+Sz/zzbUEb6uc3I7qr7CBefl9cYY+ZYa8+PsT838G0GL0cUA14kfYIhDiwAriX9uF2kw36ufb2lvw3erB
8fA3aT/t1XARv7Ly4gMFrbisSQ/n96Lenf3w7Svz/DxbCXUU76bx/SJ2QOAvWk/zd8pKvU20k/jtnAj4wxt1hrfz3iwS/x+bbWHjDGvAhcT/r5/B3Sz/nID9iYdcCN/d+GgXz+f0VEZCpxOhHroosuk/cCBEl/GM9UJ+4bcvuHsm7bNcp+TmZt99o8jrsza/t35LjdRTqMZrb5DbA5x3YB4P+Qrk5a0h/GV45wzLqs/cWBKPCngBmynX/I9x8nHbKqRtivAV4HtGXt/6ZRHvsfM7jK9O85jjmPdDdXC0Sytv3ICPssJx1cMtv9CFidY7sqBle0zwCzCvj39AdZ+z7LkGox6QCYuf0v8tjfJ4Y8V/8FzB5h263AV4CNOW7bzOBq+C5g2wj7WUB6Uq+/yXHbR8b6XQzZ/tas7X+Zxzbx/q/7gCtybOvPur4c+Ez/43aNsO+q/r+vzP6PjrRtoZ5v4B1Z9382j+foX7O2f6RQf4u66KKLLrpMnovjDdBFF10m74V01TPzYbAN8Ay5fTkXw57N9SG5f7v/L2ubb4xxzMuztg0BlTm2+cOsbV4EgmPsMzsofH6EbeqGfNj+gwI/l9uy9v3NEbZxA01Z231hlP35gJeHtPkjI2z7D1nbfHe00NG//cNZ2/9tAZ+Dp7L2++kx2rlzjH2tJV2Fzmz/d5fQruwuoa8AFRPcT/bfWc7fxZDtb83a/pd5bJMJ8nMK/Lf5+az931vM55v0WsadWfu5bJRtPUBL1ravKuTj1kUXXXTRZXJcNKmRiIzm7VnXv2GHdLW11tYDvxph+2xfy7r+emNM5SjHfGvW9cettT05tvl/sq6/21obzrFNtk+Q/hAM6Vlcx3rte9nm17U3b9ba3wCH+7+9Y4TN7iHddRbSY/T+bpT9xRi8/E5Oxhgv8N7+b6Okn6/UGHfLVL4hXdW8ZMaYxQx+3F/NsdnXso57jTFm0yi7/CsuTsz3EvDJCbZrGxe7hFrg7dba3onsq0T+yY7RlXkCHsq6fucI2xTk+bbWhhjc7faPRtn8tcD8/uvHrLW/GmVbERGZojSGVERy6g8Q2R9OcwUISI8Bu6X/+h8YY/7WWpvM3sBae9wY8zLpbnxlpMegjTR27Pezrn9t6I3GmIXA1f3fHrLW7h3tcfQfP9I/du1e0pPCbCLd7XEkj45y24iMMWuB64DV/cfxc3G5HPp/BjDbGLPUWts4ZBe3Zl1/0lrbOdrxrLW/MsY0cHEcay7Xke7iC/CMtbZt1AeR3u+Z/vHB64FNxphZ1tquse43hj/kYqA5aK3dneO4p40xz3Nxzdu3AyPNsHxP1vXPWmvtCNuNJXs/z1hrD01wP6XyzfHeof+kxDbgKtJdjisZ/P6ffYLo6hF2U6jnG9LLRmXGC7/dGPP3Q18z+v1x1vUHL+F4IiIyiSmQishI3srFAHHEWrtjhO0eIz1DaoD0h91Xkx6jONTXSAfSzL6HBVJjzI1cXBLkHPDTHPvJXjomaIz57CiPIdvqrOtLGT2Q7sxznwAYY+4D/pn0WMR8zQGGBtKrs67/Js/9vMzogTT7+Voyjueruv+rAZaQHkt8KbKr5yOd3Mjclgmkf2CM+buhYcUYM5/0TMgZv7iEdm0v0H5K4bS1tj3fjfuX7/kQ8G7Sf2/5GLZdgZ9vrLV7s05QLQReAzwx5JiLSJ9AgvTY2a9cyjFFRGTyUiAVkZHkFSCstd3GmMeBN2fdL1cgfRT4NOnXnduNMQustS1DtsnuHvro0C7C/RZlXV8JvGekto1irHVJz+W7o/71Jv/PBNqQq9vy3KzrQ8PqSJrGuD37+bqy/zJeY67jOhpjzFbSY4MhPeZ4tO7Q3yY9WY6fdFi5G/jxkG3mZ12PWmvPXELzsvd16hL2Uwrj+busIT3x19XjPEauv8tCPt8ZX+TiCao/ZkggJf06kplp+4fW2tYCHFNERCYhjSEVkWGMMVu4uJyKZeylFrID6+uNMdVDN7DWZlc83c
DvDTmmF3jTCPvMNmuEn4/HWCfjxhqTCoAx5i4Gh9EXSa9fupl0pSlgrTWZC/Bs1ra5Xn8rsq6H8mkD6ZmDR1OK52s
2021-03-19 17:21:00 +00:00
"text/plain": [
2021-04-30 20:51:04 +01:00
"<Figure size 1200x800 with 2 Axes>"
2021-03-19 17:21:00 +00:00
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"X, Y = np.meshgrid(multi_param_epochs, multi_param_nodes)\n",
2021-03-19 17:21:00 +00:00
"\n",
2021-03-22 20:49:29 +00:00
"# fig = plt.figure(figsize=(10, 5))\n",
"fig = plt.figure()\n",
2021-03-19 17:21:00 +00:00
"fig.set_dpi(fig_dpi)\n",
"ax = plt.axes(projection='3d')\n",
"\n",
2021-03-22 20:49:29 +00:00
"surf = ax.plot_surface(X, Y, mean_param_accuracy[0, :, :], cmap='coolwarm')\n",
"ax.set_title(f'Average Accuracy')\n",
2021-03-19 17:21:00 +00:00
"ax.set_xlabel('Epochs')\n",
"ax.set_ylabel('Hidden Nodes')\n",
"ax.set_zlabel('Accuracy')\n",
"ax.view_init(30, -110)\n",
2021-04-06 17:29:15 +01:00
"# ax.set_zlim([0, 1])\n",
2021-03-19 17:21:00 +00:00
"fig.colorbar(surf, shrink=0.3, aspect=6)\n",
"\n",
2021-03-22 20:49:29 +00:00
"plt.tight_layout()\n",
2021-03-26 20:01:05 +00:00
"# plt.savefig(f'graphs/{exp1_testname}-acc-surf.png')\n",
2021-03-19 17:21:00 +00:00
"plt.show()"
]
},
{
"cell_type": "markdown",
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp1"
]
},
2021-03-19 17:21:00 +00:00
"source": [
2021-03-22 20:49:29 +00:00
"### Test Error Rate Curves"
2021-03-19 17:21:00 +00:00
]
},
{
"cell_type": "code",
2021-04-30 20:51:04 +01:00
"execution_count": 313,
2021-03-19 17:21:00 +00:00
"metadata": {
"executionInfo": {
"elapsed": 2653349,
"status": "aborted",
"timestamp": 1615994110347,
"user": {
"displayName": "Andy Pack",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjA4K4ZhdArHXAFbAGr4n0aCv2HmyUpx4cy6zcUq34=s64",
"userId": "16615063155528027547"
},
"user_tz": 0
},
2021-03-26 20:01:05 +00:00
"id": "Jrn3hKQAlGcc",
"tags": [
"exp1"
]
2021-03-19 17:21:00 +00:00
},
"outputs": [
{
"data": {
2021-04-30 20:51:04 +01:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAABJwAAAPUCAYAAADogiNTAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAB7CAAAewgFu0HU+AAEAAElEQVR4nOzdd3wUdf7H8dc3nV6lBBCwUGyICIgFAQUPARVEPPTE0EU9RfEUf5wY8QQb2AFPQeD0kH40EQTpKk0EpYlAREjoSAikZ35/zGbZJNuSbArJ+/l47COTme985zs7szOzn/0WY1kWIiIiIiIiIiIigRJU1AUQEREREREREZGSRQEnEREREREREREJKAWcREREREREREQkoBRwEhERERERERGRgFLASUREREREREREAkoBJxERERERERERCSgFnEREREREREREJKAUcBIRERERERERkYBSwElERERERERERAJKAScREREREREREQkoBZxERERERERERCSgFHASEREREREREZGAUsBJREREREREREQCSgEnEREREREREREJKAWcREREREREREQkoBRwEhERERERERGRgFLASUREREREREREAkoBJxERERERERERCSgFnEREREREREREJKAUcBIREfHBGPOAMWahMeawMSbZGGM5XquKumyFyRizymXf23lIM8UlTZQfeXY0xsw0xvxujEl0WTfGTdoQY8xAY8w3xpijxpgUl/RT8rl7IsVObj9PUnj8uR7mIq8G3q59+cg3xiXfBgHKM9olz+gA5dmutN5XS5qCOpfl4hVS1AUQKQqOm+6BAGf7imVZ0QHOU0SKkDHGAJ8DDxV1WUoiY8wYYLifacOBr4F2BVkmEREREQkMBZxEShjHr2wrHf+utiyrXZEVRkoMY4yVOW1ZlinKshSyh8gabNoI7ATOOf7fW+glKiGMMbeQNdi0A9gKnHH8fzLbKs+RNdi0GvgNSHL8/0PgS1nyOWoT3O74t71lWauKrjQiIiJSkijgJKVVPPCRjzStgJaO6Vhgno/0G/NbKBEpdh5xmX7ZsqxRRVaSksf1vZ0EDLQsy/KUOFv6Ry3LmlYwxRIRERGRQFDASUoly7JOAU96S+Nol54ZcNprWZbX9CJSIt3gMj2pyEpxEbEsKwqI8iOp63v7mbdgkzGmLNDY8W8KdjNHERHxwrKsBkVdBhEp3dRpuIiIiGdVXKbjiqwUJVNu3lvXtEcty8oogPKIiIiISAAp4CQiIuKZsyawghwBF+oy7eu9zU1aERERESkGFHASCRBjTFNjzGhjzEaX4bqPG2M2GGNGGWMi/cynvDHmMWPMYmPMQWPMeWNMqjHmjDFmt2No9v8zxlyTbb1oR8fOK11m3+4yNKnrKyZA+xxqjHnEMaT5fmPMWWPMOWPMAWPMdGNMd8coX97ycDsUrjHmbkcee40xCY7lQx3L3A65aoy51RjzqeN9OuNY/q6H7d5ljJlsjPnVGBPvGI79d2PMPGNMlDEm1N162fLIMVy1MaayMeZpY8waY8xhY0yaY3ll3+9o/t8blzShjn180xiz0hgTa4xJcuznIWPMEmPMUGNMeX+2n22+u3PK65DLxph6xpiXjDFrHWVJNsacMsZsNca8bYxp5Od7EmqM+ZsxZq7jnEtwvMdnjTG/GWOWOj5vrfzJz8M2Yvzdby95XG2Mecuxfycc+xtr7GG0XzDGVPOjHFEu25rimBdsjPmrMWa+Y/8THcvvy+v+OvINMsY8aoz5xhhzxHGuxDi2k6u83X0uXJY5hxEH6rssOuDufHJJ6zqqaH03aWO8lKeaMWaYY9/+cOzbn8aYncaYj4wxN/qxTzmGATfGlDHG9DfGLDP2tTrFsfx6D3ncYYyZaIzZ4Tj3M8+JpcaYJ40xZfwoR45zzxjT2BjzrjFml+PzEG+M2WaMGWOMqe4rLy50GA6w0sNnO8pX2fxhAnCfNG6Gojf29eVfjv0+Ze
z70G5jzDvGmCtyWcbyxpinHMflkON8OW2M+cUY86ExpnUe9ruiMebvxr5/xziOU+bxX2GMedkYc3Uu8itrjHncGLPO8T4mO87t6cbujN+fPIwx5j5jzH+NMXsc5026472LMcZ8a4x5wxjT3hiT7+8LJgD3JJe8PD0D3GjsZ4Bfjf3sdNpxrv2fMaZcLsoasOthfuV3n4zL/cx4uUe7pG/vOCd+d+x3nLHv248bu2lzXvbhSmNfo3Y7zq9TxpifHJ/5unnJ05FvQV1Pqhr7Pr3J2PfvRGPfbyeZbM/e+WUCeE33kH++n0Wy5VfbGPOaMWa7o1zxxr6nvWOMaew7B4/5XjTPiJIHlmXppZdebl5ANGA5Xqu8pAsHJgJpLundvc4DT/rYZhvgkI98XF8hHsrr6xUTgPenHfYIUb629T1Qx0c+zvcZqATM9ZDXUMc6DVz3BQhzHAN367ybbXs1gOV+lPtX4EYf78EUl/RRwC3AQQ/5Vc7je5yr98axXj3ghJ/nwgmgox/b9+fVwE0eQcAoINHHuqnAa4Dx8n40wh4hzt/yXJHHczvG3224WTcEeB/f14PT2B1feytHlEv6KUAksNZDfvfl47NcC3uEN2/lnQtUcJyHmfPa+fO5yLZslY/tZDmfcpHW7TUNeAL408e6Gdj9c4V5eY+iXdJHA02BXzzkd322deth/xDgax8OA7f5OFZZzj3gMexR+rx9vt1ex3Lx3uY4jnk4xwJ5n8xyDgL3+DjG54FBfpazK3bzTl/vxxdAWT/zfAw45ef7/BdfnyfgKnxfB1/xUaaawHe5OP535vP4B+Se5JJfA5f0MYABXgHSveS7H7issK+HuXiPCmSfyHo/a+AlXQj2ddDbfu/A7ksv2mVetI/tP473+/9poBvZnneK8HpyC96fw9OwB7jI87HOtn1n3i7Xizxd090cz4A8i7jk2d2R3lNeScCA7OeyjzwvumdEvXL/UqfhIvng+HVpKfYNKtM+YAv2RbmqY1kkUAb4wBhT0bKs0W7yqufIq4JjViqwCTuocx4oh30RbwZUdFOcjdgj79UB7nPM8zS6XvbhxnPFGPMA9gN3Zi2gROwHtBjsL2+NsINnIcBNwPfGmJaWZR31lTV2Z8BdsW8Gm7FvHga4xjHPnXeAwY7pn4Ft2O9fI1ya3xhjagLrgctd1t0HbACSsR/kM3+9vhL71/6/WJa13ke5Aa4A3sUOCp0F1mC//1WAtn6s70tu3ptyQOavVqexHxJ/BxKwg3MNsY9LhCPdV8aY2y3L+i7bNg9zYTTHJ1zmexrhMT5LgY0JBmYA92fLcyNwHCiP/X5fjn2u/B9wCTAox84bUwE7UFjPMSsD2ArscuxXWexzvxmQq18A3ZjKhffPn/3OLGMQMAf7C3CmU9gPtKewy94e+xhUBqYYYypblvWeH2UKBxYALbAfIL/DPnfDydr5dq4Yu+bdt9gBlEwHsAPFycDV2CN2dicwTdnmYQdqAPpw4Xo3Dftz48p1NNEKjvQ40mUfoS7HNc3YtRufdpl1Anu/jmCf+82xPzsG6AdEGmO6WL6bTlYDvgYuxX7AXof9+SqP/blyLUNTYAVQ2zHLAn7E/uwmYp+zbR37Fwl8Y4zpbFnWSh9lwNi1jiY4/t2DfU1IBJpg33eMo6wLjDFNLcs6ky2LzPe2u2PbAP/D/oxmt8tXebyUM2D3STduxP4SEoZ9Dqxy5NkAu+ZWqCPPj40x6ZZleez43xjzIPZ9LdgxKx372P6GfWxv48L79BDQ0BjTwbKsJC95vg/83WVWOvZ9fS/2uXMJcL2jvGCfl95EYl8Ha2MH2dZin8/VgQ7Y9x+AkcaYnZZlzXBTpmBgMfa1JNMvjtefjjLUwr6W1s6+fh4F6p7kycvASMf0T9jPAanY723m9bEh8D9jzA2WZaW5y6QIrofeBGSfcmEa0Nvl/z+xA+Unsa917bCfkb7Cvhf5ZIwZTNb7Zir2Z/R37M99O8ff2dj3f3
/yLMjryTXAGOzP+zHsz9dJ7Ot0B0d+wcBEY8zPlmX94E+Z/RWAa3pmPgF/FjHGdAFmcqGbgQzs5+lfsd+vttjXi0+
2021-03-19 17:21:00 +00:00
"text/plain": [
2021-04-30 20:51:04 +01:00
"<Figure size 1200x1000 with 1 Axes>"
2021-03-19 17:21:00 +00:00
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
2021-04-30 20:51:04 +01:00
"fig = plt.figure(figsize=(6, 5))\n",
"# fig = plt.figure()\n",
2021-03-19 17:21:00 +00:00
"fig.set_dpi(fig_dpi)\n",
"\n",
2021-04-30 20:51:04 +01:00
"for idx, (layer, std) in enumerate(zip(mean_param_accuracy[0, :, :], std_param_accuracy[0, :, :])):\n",
"# plt.errorbar(multi_param_epochs, 1 - layer, yerr=std, capsize=4, label=f'{multi_param_nodes[idx]} Nodes')\n",
2021-03-22 20:49:29 +00:00
" plt.plot(multi_param_epochs, 1 - layer, '-', label=f'{multi_param_nodes[idx]} Nodes', lw=2)\n",
2021-03-19 17:21:00 +00:00
"\n",
"plt.legend()\n",
"plt.grid()\n",
2021-03-22 20:49:29 +00:00
"plt.title(f\"Test error rates for different epochs and hidden nodes\")\n",
2021-03-19 17:21:00 +00:00
"plt.xlabel(\"Epochs\")\n",
"plt.ylabel(\"Error Rate\")\n",
2021-04-30 20:51:04 +01:00
"plt.ylim(0, 0.6)\n",
2021-03-22 20:49:29 +00:00
"\n",
"plt.tight_layout()\n",
2021-04-30 20:51:04 +01:00
"plt.savefig(f'graphs/{exp1_testname}-error-rate-curves.png')\n",
2021-03-19 17:21:00 +00:00
"plt.show()"
]
},
2021-03-22 20:49:29 +00:00
{
"cell_type": "markdown",
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp1"
]
},
2021-03-22 20:49:29 +00:00
"source": [
"### Test/Train Error Over Nodes"
]
},
{
"cell_type": "code",
2021-04-30 20:51:04 +01:00
"execution_count": 314,
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp1"
]
},
2021-03-22 20:49:29 +00:00
"outputs": [
{
"data": {
2021-04-30 20:51:04 +01:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAABJwAAAe8CAYAAAD4aop3AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAB7CAAAewgFu0HU+AAEAAElEQVR4nOzdd3gU1f7H8fdJIyHUANKbIMWCCtggQBBUsIHYr17sveu9LlaCXDXBa/up2JWLXr0WELGhiAYJoCKKKIIIUqSXQGgJkOT8/pjdZDbZJJtkN/Xzep59Mjt75szZnS3ffOfMOcZai4iIiIiIiIiISKhEVHUDRERERERERESkdlHCSUREREREREREQkoJJxERERERERERCSklnEREREREREREJKSUcBIRERERERERkZBSwklEREREREREREJKCScREREREREREQkpJZxERERERERERCSklHASEREREREREZGQUsJJRERERERERERCSgknEREREREREREJKSWcREREREREREQkpJRwEhERERERERGRkFLCSUREREREREREQkoJJxERERERERERCSklnEREREREREREJKSUcBIRERERERERkZBSwklEREREREREREJKCScREREREREREQkpJZxERETCwBiz2hhjvbdOVd0eERERkerAGNPJFSOtrur2SPgo4VSHGGPSXB/sst4mVXX7axpjzKRSXtO9xpgNxpivjDH/MsZ0q+o21yXGmKRyfhbSqrrtoWCMiTfGnGmMedIYM9sYs9EYs9/7vlxjjJlmjLnOGBMfhn1fHuB1PT3IbTsV2i421O0TkbpJcVLlUpxUvdX1OMnHGNPKGOMxxnzhfT9me28bjTGzjDEPGGM6hHifhT8be4wxLYLc1h1jfRvKdomUhxJOIlWnPtAaGAzcBywzxjxXFf9AG/XEqFOMMW8CW4CPgNuBgUArIAbnfdkBGAG8AKw2xoyqhGaNN8aYStiPiIjUDIqTpEoZY24FVgIpwCk478d63lsr4GTgIeAPY8wDYWxKPHBPGOsXCZuoqm6AVJkFwPdlKK8MecUsA2YVWtcA6AkcBxjv7UagtTHmXGutrdwm1nnPBVnuj7C2onKchxMs+ezA+T7YgPM+7Akc711uDkwxxtxgrX0hjG3qDYwCpoRxHyIiwVKcVLkUJ1V/dSlOwhhzD/CIa1UuznfCKu/9Q3Hem5E4J+weMsY0tdbeGaYm3WCMedxauz5M9YuEhRJOdden1trkqm5EHfKdtfbmQA8YYw4H3gZ6eVedg/7xrnTFHZ9abB/O++414FtrbZ77QWPMEcCbwDHeVc8aY+Zaa38JY5seMsZ8ULgtIiJVQHFS5VKcVM3VpTjJGHMYMM61ahZwo7V2eaFyPXF6gw/0rrrdGPNfa+3CMDQrFngAuD4MdYuEjS6pE6li1trfgOHAXtfq66qoOVI3PAMcaq292lo7L1CCx1q7BKer+Brvqkjg3jC05Xdgp3f5cOBvYdiHiIjUUIqTpApcAkR7lzcAIwonmwCstUuBM4FN3lWG0Mcx7t6TVxpjOoe4fpGwUsJJpBqw1m4A3nWtStR4NhIu1tp/Wms3B1FuB5DqWhXUwN5ltBP4t+t+sjFGvW9FRCSf4iSpZL1cy9OttXuLK2it3Q1Md60K9eD2s4EvvcvRQHKI6xcJKyWcpEICDaJojOlijHnYGPOTMWarMSbPGLPItY17Fpgk77rWxph7jTHfG2M2GWNyjTE7i9lnR2PMQ8aYb40xm40xB7x/vzXGjDPGtA+i3e6ZN9Jc6083xrxtjPnDOyOENcbcXoGXqCwWuZbjgKbFFTTGRBhjBnhfhy+MMWuNMfuMM8vYRu+MLvcZY5qXUEf+bF9AR9dDq4qZdSSphLqaGWPuMsbMNMb85Z29Y6cx5jfvAJ99g3kBjGOkMeYtY8zvxphd3vfCXu977StjTKoxZrAxptp9fxljGhhjbjXGfG6MWed9HX
YYY341xjxrjDkhyHryX3fXuqONMU9768rwPj4tbE+mwFzXciNjTEIY9vE0sNW73AW4MpSVG2OaG2PGGP/Z+LZ5v6Me816uUZb66hljbjHGzPF+x2UZY1Z637eDK9DOeGPMDcaYj4wzU+A+Y8xu7/fRa8aYk8tQ18nGmFeNMb94P4s53vrWedv9lHFmKYwpb3tFpHRGcVIoLXItK05SnBTOOKmBa3lHEOUzXMvheN3vdy1fYozpEcrKjTFHGCce+sk48dF+48zIl2acGfqalbG+1t7vuMXe9+guY8wS48yM3L0C7WxvnFkB53jbt997rH8yxvzbBDmTpTEm2hhzqTFmqjHmT+93WY5xYq4V3vfnQ8aY48vbVnGx1upWR25AGmC9t+QQ1bnaVWcn4Fogy7XOd1tUTDuScGbDygiwzc4A+7uvmPrdtyzAU0q7k1zl04DGwNRi6ru9nK/NJFcdk4Iof02h/bYpplw0sK6U18B32wNcWkw9nYKsI/9YFVPPTTi9VEraNg94FYgp4fm3BOaVoT1DK/jedb8HbAg+C2cCG4No93+B+qXU5dcunLNZOQHqmhaKz3EpbTmq0D5bhKDOy131fetdd6dr3VqgXpDv29hS9nVlEO/PHOBJIDKItvfEGdy2pPqex/mcrnat61RKvecH+f75CGhcQj3xwIdl+BxdHe73kG661ZQbipPc9SS5yqehOCmYW1Ix9ShOcuqrMXES8Iqrjg+DKP+Rq/wjIXit3J+NFO+66a5175aw7eWuct+Wsp8o4P+Kee3ctx3AZUG2/Rxv+eLqygauLvT5Wl1KnRE4swGW9t12EHgYMCXU1Q34rQyfo64VPZ51/abLFiSUzgcmeJc34PSMyATaAMX1iuiH8yMRDWwHvgG2AYcAx7oLGmOexfnR9tkDfI1z3XQrnGlzG+AMqpdijGllrb0jiHYbnMGRz8T5YvkB54vIAEd611WGNq7lXJzXI5BIoK13eQ+wBPgT2IXzOrYDTgQa4fzz+YYx5qC19p1C9eyiYMaR0UBD7/JkYHeA/RaZFcMY8xRwm2vVNmA+zjGJxTmGR+K8llcCbYwxZ9iiA1RHAp8AfVyrf/XednrragUcjTMlbbVijLkQJ0CK9K7KBdKBFTjvyQEUHN+/AZ2NMSdba7ODqPufwFjv3ZU4M6Tsw/mhPhiip1CSo1zLWTjHOBwm4iSd2gLtcQbFfLoiFRpj/gE85lq1H6dr+lqcM+ODcb6bIoHbgQ7GmPOsNyIJUF9HnIFD3e/BJcCPON8TvXHe79fjHKNg23kH8DjO5wScz+Z8nH+YIoEjgL7ex88E0owx/a21gfbxJnC26/4K4Cecf1ajgRY4x7RTsO0TkZBRnFQxipMUJwWqOxxx0nTgKu/yGd7f3LmBChqnZ7NvyIH9OInDcHgA5zNogPOMMcdYaxeVtzJvD7gp+McMGTgJ5gycWGwwzgx8TYBJxpgm1tpiYzNjzBk4l776cgx5ON9zy3GO80Cc9+fLwK1BtjMSeAc417V6Pc5x3uqt9wScHvJROOONtsBJ7heuqyHO5Ym+Xp55ODHSUpzvivo43x1H48zSLKFQ1Rkv3SrvRvjP3B3E+aK9hkKZZVy9FQq14yDOh/1+ILqEbS7AP9v8OtCoUPlGwBuFyo0qpt1JhdpggcXAUQHKBuxpEcRrM8m1j0lBlHeftVpQQrkYnJnFkgq/Zu42A/90PbcdQIMgj2OnIJ/fla5tMnHOVhRpD86PlftM490ByoxwPb4BOKGE/R4BpADHV/C9634P2ArU0wUn8PTV9R2FzobgnJm5EyfA8pX7vxLqLHy2ZicwMlTvzTI+vy9cbfk4RHVe7qrzW9f6G1zrNxHgDCdB9nDC+SfNfcbuU6BlgM/JhEL13VlCu790ldsJnBmgzHAKeiIcKO1zBQxxvS/2A55invcxOP80+eqbGKDM0a7HdwPDS3guh+L0hDgr3O8h3XSrKTcUJ7nLJhVqg+IkxUnlrafGxU
k4SR13/JOFc2KoN07ypYl3+WmcHjvW+/dvFXnNi/lspLjWv+NaHzAmI8geTsDdhV7HRynUww4nifl5odc64HsPaAZ
2021-03-22 20:49:29 +00:00
"text/plain": [
2021-04-30 20:51:04 +01:00
"<Figure size 1200x2000 with 10 Axes>"
2021-03-22 20:49:29 +00:00
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
2021-04-30 20:51:04 +01:00
"# one row of two panels per pair of hidden-node settings\n",
"n_rows = math.ceil(len(multi_param_nodes) / 2)\n",
"fig, axes = plt.subplots(n_rows, 2, figsize=(6, 6*n_rows/3))\n",
"fig.set_dpi(fig_dpi)\n",
"\n",
"# one shared y-limit for every panel (loop-invariant, so computed once),\n",
"# with headroom for mean error plus one standard deviation\n",
"error_ylim = np.round(np.max(1 - mean_param_accuracy + std_param_accuracy) + 0.05, 1)\n",
"\n",
"for idx, (nodes, ax) in enumerate(zip(multi_param_nodes, axes.flatten())):\n",
"    ax.set_title(f'Error Rates For {nodes} Nodes')\n",
"    # index 0 of the first axis holds test accuracy, index 1 train accuracy\n",
"    ax.plot(multi_param_epochs, 1 - mean_param_accuracy[0, idx, :], 'x', ls='-', lw=1, label='Test', c=(0, 0, 1))\n",
"    ax.plot(multi_param_epochs, 1 - mean_param_accuracy[1, idx, :], 'x', ls='-', lw=1, label='Train', c=(1, 0, 0))\n",
"    ax.set_ylim(0, error_ylim)\n",
"    ax.set_xlabel('Epochs')\n",
"    ax.set_ylabel('Error Rate')\n",
"    ax.legend()\n",
"    ax.grid()\n",
"\n",
"# hide the unused panel left over when there is an odd number of node settings\n",
"for spare_ax in axes.flatten()[len(multi_param_nodes):]:\n",
"    spare_ax.set_visible(False)\n",
"\n",
"fig.tight_layout()\n",
"fig.savefig(f'graphs/{exp1_testname}-test-train-error-rate.png')"
2021-03-22 20:49:29 +00:00
]
},
{
"cell_type": "code",
2021-04-30 20:51:04 +01:00
"execution_count": 315,
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp1"
]
},
2021-03-22 20:49:29 +00:00
"outputs": [
{
"data": {
2021-04-30 20:51:04 +01:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAABLoAAAe8CAYAAAC6KfngAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAB7CAAAewgFu0HU+AAEAAElEQVR4nOzdeXxU1f3/8dfJBmEJOwYCAiqyubC4RA2WuqJWq9StdcPWLn7b6rfan0m1SJCqia3Wfltb22pF3KpW3Nq6K0pQcENFFgFlkbDvWyDb+f1xZjJ3JjPJJJnJZCbv5+Mxj9y5c+69Z+beufPJ5557jrHWIiIiIiIiIiIikuzSEl0BERERERERERGRWFCiS0REREREREREUoISXSIiIiIiIiIikhKU6BIRERERERERkZSgRJeIiIiIiIiIiKQEJbpERERERERERCQlKNElIiIiIiIiIiIpQYkuERERERERERFJCUp0iYiIiIiIiIhISlCiS0REREREREREUoISXSIiIiIiIiIikhKU6BIRERERERERkZSgRJeIiIiIiIiIiKQEJbpERERERERERCQlKNElIiIiIiIiIiIpQYkuERERERERERFJCUp0iYiIiIiIiIhISlCiS0REREREREREUoISXSIiIiIiIiIikhKU6BKRuDLGTDbGWN9jRqLrIyKtS+cAEZHIdI4Uad+MMTM854DJia5Pqkj5RJcxZrbnwGnqY0ai659sQr6o4R57jTHrjDFvGmN+Y4w5PNF1TkbGmHHGmHuMMe8aYzYaYw4YY/YbY7YYYz4xxjxtjCkyxow3xmQkur6xZIwpbuD4qjbGbDXGfGWM+ch3PP7MGDMm0fVuy1pwnpyQ6LrHgjFmhDHmOmPMU8aYxcaYncaYKt/36UNjzB/idQwZY1aFfKbzm7Cs93xbEo/6SepTnNS6FCe1DsVJipNiqb3HSQDGOcMY8w9jzOfGmB2+42mHL3Z6xBhznjEmPYbbHBzmM72rCct7Y6yJsaqXJIeUT3RJm9MJ6Ad8E7gFWGqMuc8Y07G1KxJy8hvc2ttvDmPMAGPMf4APgV8AJwB9gSygA9ALOBq4ELgTeAfYEulHJyQYKm6FtxBv6UBPYAgwFrgK+CPwsTFmgTHmx7H8AZbkZow53RjzObAY+ANwETACyAEycN+nccB1uGPoKWNMzzhX6zhjzHlx3oaItF2Kk1pAcVKjFCdJk/m+/+8ArwBXA6OAbrjjqRsudroceB543xgzIo7V+akxJjeO65cUkVJXMKLwAfB+E8rPi1dF2omlwBsh87rgTobHAsb3+B+gnzHmO9Za27pVTB7GmEG4H5mDPbN34I7pcqAKF7wMx33G/kClG+5zTkXrgGdD5nUFugODgSMIJPRHA/cDk40xV1hrV7ROFZNOU86T5fGsSCsYhwvW/CzwGbAM2A70AQp8f8ElwkYaY75hrd0ax3pNN8a8qPOhJIDipNalOCmGFCeFpTgp9tpTnIQvqTQbGOSZvRYXL60H+uOSx/19r40F3jHGHG+t/SoOVeqEuwjw8zisW1JIe0t0/ddaW5zoSrQj8621Pwv3gjFmJPAEcJRv1gXAJOCZVqpbMvoHgeBtF3A98Ji1tiq0oDEmB/gW8D3grFarYetbHukYAzDGdMV9DjcAx/hm5+OuNuVba5e1Qh2TTXs8T34C/B14MjSBZYzJwrUKuB33T9Eo4M/AJXGsz1G+9f8zjtsQCac9fv8TSXFSbClOqk9xUuy1t/NkCYEk1wFcK/d/WGur/QWMMZnAj4B7cK0newP3AvFqof4jY8xvrbVr4rR+SQG6dVESwlq7GBdY7PXM/nGCqtPmGWOOBU7xPbXAudbaGeGCNwBr7S5r7ePW2m8BQ4GaVqpqm2Kt3W2tfQI4DriZwOfQA/i3MaZbwionbcEy4Hxr7Rhr7Z/DtdKy1lZaa0uB/+eZfbExZngc6uNtHVOs20dE2i/FSU2jOKl5FCdJQ4
wxnYCLPbOKrLV/8ya5AKy1Vdba+3AtrfzONsb0iHGV/HFSFnBrjNctKUaJLkkYa+064CnPrAJjTKo2HW+pMzzT862170S7oLX2q/Z+q4N17gR+5Zk9FHe1V9opa+0sa+3zURb/P9wtIH5nx6FKvwe2+aaHAVfGYRsikiQUJzWJ4qQWUJwkERwGZHueP9FI+cc80+nAITGuz68901cZY4bGeP2SQpToaoZwnXMaYw41xtzu68hxszGm1hjziWcZ72gdE3zz+hljbjbGvG+M2WCMqTHG7IiwzUHGmNuMMfOMGz2m0vd3njFmmjFmYBT1nuCpw2zP/LONMU8YY5YbY/b4Xv/fFnxETfGJZzobdwUpLGNMmnGj49xmjHnVGLPGGLPPuJF01hs3QtEtxpjeDayjbvQOgu81X2maOFqKMaaXMeZGY8xrxpivjRvNxz/yyH3GmGMiLdsMeZ7p1S1dmf94BKZ6Zk+N8BnMaGA9Y40xfzdu9J4K37H/vjHmJhP/Trub43fAXM/z640xXRpbqCX72hgzyfNZfhFtRY3rULfGBEZJSqqON40x+caYPxljFhljtvs+s7XGmJeNG+GpcxTrqNcJsDEm2xjzA885oNL3+uh4vh9rbQ3gHRFxcBw2swvwjiZ0q3G3TsaMMeZM40ZMWmaM2eX73q42xjxr3BD3mU1cX1zOAcaYY40xvzdudLTNvv28wRjztjGm0ER5ldgY09sY80tjzOvGjWS337gRNXf4js1/GWNuMMYMaW5dpT6jOCmWPvFMK06KTHFSbChOaiUmOeKk0H2/vZHy20KexzrXMBd42TedAUyL5cqNMZnGmKuNMc8ZFxtVGBcrfWGMedAYc3oz1nm+MeZ5Y0y571y81vc9ucI0c8RX41xgjHnYuHhup+/4+dpX96uiXbcxZrgx5i7jfiu3+I6X/caYTcaNzPqQb32xbp0Xf9balH7gOs+zvkdxjNa5yrPOwbh7kis88/yPTyLUYwLwbdzJIHSZHWG2d0uE9XsfFUBhI/We4Ck/G9f55qwI6/vfZn42MzzrmBFF+R+GbLd/hHKZuI4PG/oM/I89wOUR1jM4ynXU7asI6/kproPThpatBR4EsmJwzP3Js975Mf5eNPYIux+B3wDVDSz3Na6fh8lNOSaiqHux9zhuxvKTQup5QSPlW7SvcaM8bfeUPTbKet7kWeaVln5uTTgeilu4rs64vqUaO67WAWc1YV8X4zoP/jzC+kbH8zPy1ecZz/bui9E6V3nWORHXyeoGz7z/aWDZGZ5yJY1spy/wehT7ZRlwTJR1j/k5APdP/L+iqOd24MJG1vVtwv/Ohnusjffx01Yfsfz+RziuB6M4ybsN7/e2we+Dr7zipOg+V8VJge0We4/jZiyvOKnh46G4hetKmjgJGBCyjjGNlB/nKVsJdGnhZzU4ZPsdQ7ZRAxzRwPKrPGUnNrKt44EVUeyXV4HeUdS9C/CfRtY1B8gl+HdhciPrPQpYEEU9lwIjozh+GjpHeR+PxvM7Fo9He+uMPh4uInAFfh0u07wTN/JEpCs1J+IOrExgK76hjXH/iIzxFjTG/An3A+K3B3gL949QLm746S64L36JMSbXWvuLKOptgEdxHVBa3DDMi33zj/DNaw39PdM1uM8jnHQCV+v2AIuAr3AtIDJxJ+J8IAf3A/KIMabKWvtkyHp2Aff5pq/EjTwDMBPYHWa79UZLMcbcS3BT7i3Ae7h90hG3D4/AfZbfB/obY86x1tZGeG/R+NIzfYwx5lRrbehITU3xLO6H8DjcyE4QeRSZeqNqGWPuILh5+z7gTdzoK7m4fjIGAP/FdUbZlrwI7MftK4Dx1B+RCIjNvrbWHjDGPI37ZwXgMtxn3ZjLPNOPRFE+4Yzry+FN3HHltw73Q74H1wS+APd97ge8YIz5rrX2X1GsvhfuKt7BuP1Xhrtq3wX33W8NR3qmv47HBqy1+4wxt+NulQT4tTHmIWttRXPXaYw5CPfbdKhn9p
e4FmoHgJG4AA/crSpvGWMmWmvnEkE8zgG+q/Fv4gJ1v0XAp7jjpy/u+9oLN2LYU8aNDPZYyKrwtR74F4FBdypw57J
2021-03-22 20:49:29 +00:00
"text/plain": [
2021-04-30 20:51:04 +01:00
"<Figure size 1200x2000 with 10 Axes>"
2021-03-22 20:49:29 +00:00
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
2021-04-30 20:51:04 +01:00
"# one row of two panels per pair of hidden-node settings\n",
"n_rows = math.ceil(len(multi_param_nodes) / 2)\n",
"fig, axes = plt.subplots(n_rows, 2, figsize=(6, 6*n_rows/3))\n",
"fig.set_dpi(fig_dpi)\n",
"\n",
"# std of accuracy equals std of error rate; one shared y-limit, computed once\n",
"std_ylim = np.round(np.max(std_param_accuracy) + 0.05, 1)\n",
"\n",
"for idx, (nodes, ax) in enumerate(zip(multi_param_nodes, axes.flatten())):\n",
"    ax.set_title(f'Error Rate Std Dev. For {nodes} Nodes')\n",
"    ax.plot(multi_param_epochs, std_param_accuracy[0, idx, :], 'x', ls='-', lw=1, label='Test', c=(0, 0, 1))\n",
"    ax.plot(multi_param_epochs, std_param_accuracy[1, idx, :], 'x', ls='-', lw=1, label='Train', c=(1, 0, 0))\n",
"    ax.set_ylim(0, std_ylim)\n",
"    ax.set_xlabel('Epochs')\n",
"    ax.set_ylabel('Error Rate Std Dev.')\n",
"    ax.legend()\n",
"    ax.grid()\n",
"\n",
"# hide the unused panel left over when there is an odd number of node settings\n",
"for spare_ax in axes.flatten()[len(multi_param_nodes):]:\n",
"    spare_ax.set_visible(False)\n",
"\n",
"fig.tight_layout()\n",
"fig.savefig(f'graphs/{exp1_testname}-test-train-error-rate-std.png')"
2021-03-22 20:49:29 +00:00
]
},
2021-03-19 17:21:00 +00:00
{
"cell_type": "markdown",
"metadata": {
2021-03-26 20:01:05 +00:00
"id": "eUPJuxUtVUc3",
"tags": [
"exp2"
]
2021-03-19 17:21:00 +00:00
},
"source": [
"# Experiment 2\n",
"\n",
"For the cancer dataset, choose an appropriate number of nodes and epochs based on Exp 1), and use an ensemble of individual (base) classifiers with random starting weights and majority vote to see if performance improves. Repeat the majority-vote ensemble at least thirty times with a different 50/50 split each time, then average and graph the results (each classifier in the ensemble sees the same training patterns). Repeat for different odd numbers (an odd number prevents tied votes) of individual classifiers between 3 and 25, and comment on individual classifier accuracy vs. ensemble accuracy as the number of base classifiers varies. Consider changing the number of nodes/epochs (both less complex and more complex) to see if you obtain better performance, and comment on why the optimal node/epoch combination may be different for an ensemble compared with the base classifier, as in Exp 1). \n",
"\n",
"(Hint 4: to implement majority vote you need to determine the predicted class labels - it is probably easier to implement this yourself rather than use the MATLAB ensemble functions)\n"
2021-03-19 17:21:00 +00:00
]
},
{
"cell_type": "code",
2021-04-30 20:51:04 +01:00
"execution_count": 113,
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp2",
"exp-func"
]
},
2021-03-19 17:21:00 +00:00
"outputs": [],
"source": [
2021-03-27 16:29:31 +00:00
"num_models=[1, 3, 9, 15, 25]\n",
"\n",
"def evaluate_ensemble_vote(hidden_nodes=16, \n",
"                           epochs=50, \n",
"                           batch_size=128,\n",
"                           optimizer=lambda: 'sgd',\n",
"                           weight_init=lambda: 'glorot_uniform',\n",
"                           loss=lambda: 'categorical_crossentropy',\n",
"                           metrics=['accuracy'],\n",
"                           callbacks=None,\n",
"                           validation_split=None,\n",
"                           round_predictions=True,\n",
"\n",
"                           nmodels=num_models,\n",
"                           tboard=True,\n",
"                           exp='2',\n",
"\n",
"                           verbose=0,\n",
"                           print_params=True,\n",
"                           return_model=True,\n",
"\n",
"                           dtrain=data_train,\n",
"                           dtest=data_test,\n",
"                           ltrain=labels_train,\n",
"                           ltest=labels_test):\n",
"    \"\"\"Train majority-vote ensembles of each size in `nmodels` and evaluate them on the test set.\n",
"\n",
"    This is a generator: one result dict is yielded per ensemble size m.\n",
"\n",
"    hidden_nodes: int (every member gets this width), a (low, high) tuple (widths\n",
"        spread with np.linspace; the mean is used when m == 1), or 'm' (member i\n",
"        gets i+1 nodes).\n",
"    epochs: int, or a (low, high) tuple spread the same way as hidden_nodes.\n",
"    batch_size: mini-batch size forwarded to model.fit.\n",
"    optimizer, loss: zero-argument factories, called once per model.\n",
"    weight_init: passed through to get_model.\n",
"    callbacks: optional list of zero-argument callback factories.\n",
"    round_predictions: vote with the summed rounded member outputs when True,\n",
"        otherwise with the summed raw outputs.\n",
"    tboard: also attach a TensorBoard callback prefixed 'exp{exp}-{m}-'.\n",
"    dtrain/dtest/ltrain/ltest: converted with .to_numpy(); ltest rows are assumed\n",
"        one-hot (argmax gives the class) -- TODO confirm against the loader.\n",
"\n",
"    Each yielded dict holds: epochs, num_models, nodes, predictions, history,\n",
"    optimizer/model_config/loss (taken from the last member), round_predictions,\n",
"    accuracy (ensemble test accuracy), agreement (over correctly voted samples, the\n",
"    mean fraction of members whose rounded output marks the true class; NaN if the\n",
"    ensemble got nothing right), individual_accuracy (that fraction averaged over\n",
"    all test samples) and, if return_model, models.\n",
"    \"\"\"\n",
"    # convert once; the data do not change between ensemble sizes\n",
"    x_train, y_train = dtrain.to_numpy(), ltrain.to_numpy()\n",
"    x_test = dtest.to_numpy()\n",
"    # TEST LABEL TENSOR\n",
"    ltest_tensor = tf.constant(ltest.to_numpy())\n",
"\n",
"    for m in nmodels:\n",
"        if print_params:\n",
"            print(f\"Models: {m}\")\n",
"\n",
"        response = {\"epochs\": list(),\n",
"                    \"num_models\": m}\n",
"\n",
"        ###################\n",
"        ## GET MODELS\n",
"        ###################\n",
"        if isinstance(hidden_nodes, tuple): # range of hidden nodes, calculate value per model\n",
"            if m == 1: # average, not lower bound, if single model\n",
"                node_counts = [int(np.mean(hidden_nodes))]\n",
"            else:\n",
"                node_counts = [int(i) for i in np.linspace(*hidden_nodes, num=m)]\n",
"            response[\"nodes\"] = node_counts\n",
"        elif hidden_nodes == 'm': # member i gets i+1 hidden nodes\n",
"            node_counts = [i+1 for i in range(m)]\n",
"            response[\"nodes\"] = node_counts\n",
"        else: # not a range of nodes, just set to given value\n",
"            node_counts = [hidden_nodes] * m\n",
"            response[\"nodes\"] = hidden_nodes\n",
"        models = [get_model(n, weight_init=weight_init) for n in node_counts]\n",
"\n",
"        for model in models:\n",
"            model.compile(\n",
"                optimizer=optimizer(),\n",
"                loss=loss(),\n",
"                metrics=metrics\n",
"            )\n",
"\n",
"        # build callbacks independently of tboard so user callbacks are never\n",
"        # dropped and cb is always bound\n",
"        cb = [i() for i in callbacks] if callbacks is not None else []\n",
"        if tboard:\n",
"            cb.append(tensorboard_callback(prefix=f'exp{exp}-{m}-'))\n",
"\n",
"        ###################\n",
"        ## TRAIN MODELS\n",
"        ###################\n",
"        histories = list()\n",
"        for idx, model in enumerate(models):\n",
"            if isinstance(epochs, tuple): # range of epochs, calculate value per model\n",
"                # average, not lower bound, if single model\n",
"                e = int(np.mean(epochs) if m == 1 else np.linspace(*epochs, num=m)[idx])\n",
"            else: # not a range of epochs, just set to given value\n",
"                e = epochs\n",
"\n",
"            history = model.fit(x_train, \n",
"                                y_train, \n",
"                                epochs=e, \n",
"                                batch_size=batch_size, # forward it; Keras falls back to 32 otherwise\n",
"                                verbose=verbose,\n",
"                                callbacks=cb,\n",
"                                validation_split=validation_split)\n",
"            histories.append(history.history)\n",
"            response[\"epochs\"].append(e)\n",
"\n",
"        ########################\n",
"        ## FEEDFORWARD TEST\n",
"        ########################\n",
"        # TEST DATA PREDICTIONS\n",
"        response[\"predictions\"] = [model(x_test) for model in models]\n",
"\n",
"        ########################\n",
"        ## ENSEMBLE ACCURACY\n",
"        ########################\n",
"        # round predictions to onehot vectors and sum over all ensemble models;\n",
"        # argmax of the sum is the ensemble's predicted class\n",
"        ensem_sum_rounded = sum(tf.math.round(pred) for pred in response[\"predictions\"])\n",
"        ensem_sum = sum(response[\"predictions\"])\n",
"\n",
"        correct = 0 # number of correct ensemble predictions\n",
"        correct_num_models = 0 # summed member agreement over correctly voted samples\n",
"        individual_accuracy = 0 # summed proportion of models correctly classifying\n",
"\n",
"        # pc = summed prediction, pcr = summed rounded prediction, gt = ground truth\n",
"        for pc, pcr, gt in zip(ensem_sum, ensem_sum_rounded, ltest_tensor):\n",
"            gt_argmax = tf.math.argmax(gt)\n",
"            pred_val = pcr if round_predictions else pc\n",
"\n",
"            correct_models = pcr[gt_argmax] / m # use rounded value so will divide nicely\n",
"            individual_accuracy += correct_models\n",
"\n",
"            if tf.math.argmax(pred_val) == gt_argmax: # ENSEMBLE EVALUATE HERE\n",
"                correct += 1\n",
"                correct_num_models += correct_models\n",
"\n",
"        ########################\n",
"        ## RESULTS\n",
"        ########################\n",
"        last_model = models[-1] # config/optimizer/loss are recorded from the last member\n",
"        response.update({\n",
"            \"history\": histories,\n",
"            \"optimizer\": last_model.optimizer.get_config(),\n",
"            \"model_config\": json.loads(last_model.to_json()),\n",
"            \"loss\": last_model.loss,\n",
"            \"round_predictions\": round_predictions,\n",
"\n",
"            \"accuracy\": correct / len(ltest), # average number of correct ensemble predictions\n",
"            # undefined (NaN) rather than ZeroDivisionError when the ensemble got nothing right\n",
"            \"agreement\": correct_num_models / correct if correct else np.nan,\n",
"            \"individual_accuracy\": individual_accuracy / len(ltest) # average proportion of individual models correctly classifying\n",
"        })\n",
"\n",
"        if return_model:\n",
"            response[\"models\"] = models\n",
"\n",
"        yield response"
]
},
{
"cell_type": "markdown",
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp2"
]
},
2021-03-19 17:21:00 +00:00
"source": [
"## Single Iteration\n",
"Run a single iteration of ensemble model investigations"
]
},
{
"cell_type": "code",
2021-04-30 20:51:04 +01:00
"execution_count": 224,
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp2"
]
},
2021-03-19 17:21:00 +00:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Models: 1\n",
2021-04-30 20:51:04 +01:00
"[200] [20]\n",
2021-03-19 17:21:00 +00:00
"Models: 3\n",
2021-04-30 20:51:04 +01:00
"[1, 200, 400] [20, 20, 20]\n",
"Models: 9\n",
2021-04-30 20:51:04 +01:00
"[1, 50, 100, 150, 200, 250, 300, 350, 400] [20, 20, 20, 20, 20, 20, 20, 20, 20]\n",
"Models: 15\n",
2021-04-30 20:51:04 +01:00
"[1, 29, 58, 86, 115, 143, 172, 200, 229, 257, 286, 314, 343, 371, 400] [20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20]\n",
2021-04-06 17:29:15 +01:00
"Models: 25\n",
2021-04-30 20:51:04 +01:00
"[1, 17, 34, 50, 67, 84, 100, 117, 134, 150, 167, 183, 200, 217, 233, 250, 267, 283, 300, 316, 333, 350, 366, 383, 400] [20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20]\n"
2021-03-19 17:21:00 +00:00
]
}
],
"source": [
"# Horizontal ensembles: member widths spread linearly over 1-400 hidden nodes, 20 epochs each\n",
"single_ensem_results = []\n",
"for test in evaluate_ensemble_vote(\n",
"        hidden_nodes=(1, 400),\n",
"        epochs=20,\n",
"        optimizer=lambda: tf.keras.optimizers.SGD(learning_rate=0.02)):\n",
"    # echo this ensemble size's node/epoch allocation, then keep its results\n",
"    print(test[\"nodes\"], test[\"epochs\"])\n",
"    single_ensem_results += [test]"
2021-03-19 17:21:00 +00:00
]
},
{
"cell_type": "code",
2021-04-30 20:51:04 +01:00
"execution_count": 225,
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp2"
]
},
2021-03-19 17:21:00 +00:00
"outputs": [
{
"data": {
2021-04-30 20:51:04 +01:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAABXgAAAOcCAYAAAD0DtmZAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAB7CAAAewgFu0HU+AAEAAElEQVR4nOzdd3xb1fnH8c+Rd5wdO3s4iWM7ISE7IdNhhkDYZSWUsjootBQoBbp+dDFKC2UVWqCMsguFssLKcHYIGYSQ2M7ew07iLG/7/P64siNbki07tmVZ3/frdV6W7j336tH01aNzn2OstYiIiIiIiIiIiIhI6HEFOwARERERERERERERqR8leEVERERERERERERClBK8IiIiIiIiIiIiIiFKCV4RERERERERERGREKUEr4iIiIiIiIiIiEiIUoJXREREREREREREJEQpwSsiIiIiIiIiIiISopTgFREREREREREREQlRSvCKiIiIiIiIiIiIhCgleEVERERERERERERClBK8IiIiIiIiIiIiIiFKCV4RERERERERERGREKUEr4iIiIiIiIiIiEiIUoJXREREREREREREJEQpwSsiIiIiIiIiIiISopTgFREREREREREREQlRSvCKiIiIiIiIiIiIhCgleEVERERERERERERClBK8IiIiIiIiIiIiIiFKCV4RERGRMGCMGWiM+bsxZp0x5qgxxnq0pGDHJ7UzxkzxeM7mBTsecRhjXvR4Xq4LdjyhyBhzncdj+GIT3u48j9ud0lS3K81TY33GGmO26v+tiDQ2JXhFRJohY0xSteRLQ7T7gn2/moIx5r6TeIy2Bjv+UFPtS7mvVmSM2W+MWe5OLk4OdszhyBhzIbAKuBkYCLQObkQNp9prcGsdt63+WZvUOFGKBK5awrii3VbHfdziYx8vNlLI4lYtYVzX9mKw4xcRkdClBK+IiDQpjUALO9FAIjAKJ7mYYYyZa4zp1dSBhOsoO2NMa+BFIMa9aA/wH+DvwFPudiQowYmcJM8EWbBjaWTX1rH/9xolChEREWmWIoMdgIiI+HQEJ+lSkzHAaPfl3cC7tfT/8mSDCkGBPC6eDjRWIGHiKPBytWWxQB9gAhDnXjYFmGOMOc1aq8e88U0HOrgvfwuMttYWBDEeEam7EcaYU6y139bW0RiTxonjAwme5dTt2GtpYwUiIiItnxK8IiLNkLX2IHBrTX3cJRcqvsBtsNbW2D9M6XFpWgf9Pd7GmE44I0avcC9KBu4DftI0oYW1ER6XX1dyN3RZa+cBJthxSJNaBwxyX74WuDuAbTxH+3puL03rY2vtfcEOQkREwoNKNIiIiEijc4/UnYEzoqnC9caYqCCFFE46eFzeE7QoRKQ+PgFy3JdnGmNq/P7mXn+N+2oOMKsRYxMREZFmQgleERERaRLW2jLgSY9F8cDIIIUTTjyT6OVBi0JE6qMEeN19uQdwZi39Twcqapy/BpQ2UlwiIiLSjCjBKyISJowxA40x9xtjvjTG7DPGFBtjcowxy4wxvzfGdA9wP62NMT8yxnxkjNlujMk3xpQYYw4bYzKNMR8YY35pjBlcbbv73JPgzPVYnO5nJumtDXjXT5qvybmMMe2NMbcZY+YbY3YZY0rd69u719/nsc197mVxxpgbjTGfuR+7Yvf6YT5us7Ux5qfGmE+NMTuNMYXGmEPGmLXGmCeNMWMDjN1rAiJjzFBjzGPufR10r3/vpB+owKyudr3G1537dXu7Mea/xpgsY8xR9+stxxjzlTHmUWNMjacfG2O2uu+/56RDL/h57d1Xw36ijDHfNca8ZYzZ7I7luDFmizHmdWPMJcaYgE6fN8aMdj+PK93Pa6kxpsAYs8cYs9QY87Qx5gpjTHwg+/Ox/8rXXwD3+zo/+2ipr8GTYoxJMMbcY4zJcD9fRcaYXGPMKmPMw7W9Ht37SPL1eWeMmWiMec44n6WH3ev/5rG+xkkqjTHX+Xld19a2Vt
9Xtf32Mc7/iaXmxP+Pfe7rvzMBTJroL3ZjzBnGmDfc76lCY8wB43yu3mr8jPD33Fe15f7uX5KPfZz0Z0sT8qxtXttka57rq9dED4gxZqox5l/GmGxjzBH3Z9M2Y8y77tdYnc68MMaMMMY8636OC9yP8ZfGmF8YYzrWJ0b3fke7n6fV7n0WG2P2ut+bdxtjOtS+l+bPuP+Heb6WjTE9jTF/MMZ8bYzJM87/okxjzBPGmD4B7rdex3M17O+knw8/9zXZOJ+ta92xFbjv9y+NMa187CPVOP+jvnG/fvPcn1W3GGMiArkv1fZnjDGXGmPed78PCt336zNjzLWmllH19WFC6HhDRJoRa62ampqaWgg2nPql1t3m1dAvBngGZxSPraHlA7fWcpvjgJ217MezRfqJt7a2tbEflzru80WPfV6HM2HYdj+xt/cRx33AQGCtn22GVbu96Tin0tf2OL0KtKol9sr+HnH5ei28V8/H5rq6PG/AgGq3O6OGvm8F+HopBx4FIvzsZ2sdXnv3+dnHFGBjANsvAXrUcJ8igX/UIZ4/NsD7oLZ2nY/tW+xrsNq2SdViSKql/w1AXi2PSWlNr0cft7sViMb5nPa1v79Vex1WLJ9Xy2NRl+b3cQN+BRTUsn0BcHctj12V2N33+Z+17HcFkFDLvgJpSdW2b5DPFo/9veixjdf7qR6vac/9PeheVvH/4xjQ2s928TiTXVpgrXvZgx77erGW2+0MfBHA45INjArwvvyRmo8/dgCnVXvt1hZnB+DtAOI8BHynln3N8+g/pQGeO8/93Xey+3Pvc6vnaxm4mJo/h/KB82vZZ72P5xr5+ah+X68Bjtewz5VAB4/tfw2U1dB/LjX8z8L7c6oN8F4t92sx0Lku96uWvlMIoeMNNTW15tM0yZqISAvm/jX+U5yEZIVNOF+aDwEd3eu6A3HAE8aYttba+33sq5d7X23ci0pw6qluxPkyEY9zMD4UaOsjnC+Bp3BOMb3YvWw38K6PvgcCvIvBkAz8DWiH80V6Ps796ABM9rNNJ5w6ir2BQmAhsA1ojfPFtpIx5kqcpFnFKJMyd/+N7v6TODHqdQbQ1xhzhrW2sLbAjTF3Af/nvroJ5znJx3neSmrbvoFUH7G7r4a+vd1/S3EmCtqA86W2DCcRMRrn9WSAn+H8mPFjH/t5Cec5OBNIcy+bDWT66Os147kx5nKc56Ri1FoBzmznW3ESQCk4X5YjcZ7PJcaY0dZaX/ftYeAHHtd3uW8zB+fMqk44EyKl+ti2Lireb1D7/V7veSUMXoP1Yoz5Oc7zV6EIyMD5sacDzqnxHXEet58BvY0x37HW2gB2/yjwQ/flb4CvcR6PFOpWVmM9J5732kwD+rkv+4zRGPMkcIvHomM4CZK9QFec+9waiAUeNMZ0tdbeHuDt/xNndHk5sAzndenCeQ9VvP5H4IxCPa/atrs4cT894/N3349Uu95Qny1N6d84ydp44DKcz7XqLsN5PqCOo3eNMV2ARUB/j8WbcJ6bIpzPpYpR+wOAucaYc621i2rY5/3AvR6L8oE5OD8edQXOAHoCH+P8Xw0kzq7ufQz0WPwtznvmGM7zNwnns7Q98JYx5rvW2lcD2X8IOAvnx6AInM+eJTiv7744icFInOO5t4wxg621W6rvoAGO5zz31ZjPxzScsk4unPfolzjHUKdyYpLh4cAbwFRjzL3AH9zL17hjKAXGAKe4l08BHgF+FMDtg/Njy0U4n5Ff4nxexADjcR4jcP7/zzbGTLDWVv+sqZMQPd4QkeYi2BlmNTU1NbX6NQIYqYrzBbCiTxY+RqfgfEm4Geeg2eIcDI/z0e9Rj33NB7r7uc1IIB14BR+jnqhlBFpTPC712OeLHvsscf99kmqjqHAOyF0+4qjY5j9AYrVtXECU+3J/Toy+sjhfrJN99L+DqiNUHq8hdlstjjzgYh/9Yur52Fznsf+tAfS/36N/MR4jb3z0fQC4HGjrZ70BLgD2e+xzYo
DP43UB3r9TcL7wWpwvVw/jHqVdrV8/YIHH/j/20aeTx2uhFCe5ZfzcbjfgJ8CNDfz6rfF+h8NrsNq2SdXiS/LTbzx
2021-03-19 17:21:00 +00:00
"text/plain": [
2021-04-30 20:51:04 +01:00
"<Figure size 1600x1000 with 1 Axes>"
2021-03-19 17:21:00 +00:00
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
2021-03-22 20:49:29 +00:00
"fig, ax = plt.subplots(figsize=(8, 5))\n",
"fig.set_dpi(fig_dpi)\n",
"\n",
"ensem_x = [i[\"num_models\"] for i in single_ensem_results]\n",
"\n",
"# error rate = 1 - accuracy; the agreement curve is shown as disagreement (1 - agreement)\n",
"for result_key, curve_label in [(\"accuracy\", 'Ensemble Test'),\n",
"                                (\"individual_accuracy\", 'Individual Test'),\n",
"                                (\"agreement\", 'Disagreement')]:\n",
"    ax.plot(ensem_x, 1 - np.array([i[result_key] for i in single_ensem_results]), 'x-', label=curve_label)\n",
"\n",
"ax.set_title(\"Test Error Rates for Horizontal Model Ensembles\")\n",
"ax.set_ylim(0)\n",
"ax.grid()\n",
"ax.legend()\n",
"ax.set_ylabel(\"Error Rate\")\n",
"ax.set_xlabel(\"Number of Models\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp2"
]
},
2021-03-19 17:21:00 +00:00
"source": [
"## Multiple Iterations\n",
2021-03-22 20:49:29 +00:00
"Run multiple iterations of the ensemble model investigations and average\n",
"\n",
2021-03-26 20:01:05 +00:00
"### CSV Results\n",
2021-03-22 20:49:29 +00:00
"\n",
"| test | learning rate | momentum | batch size | hidden nodes | epochs | models |\n",
"| --- | --- | --- | --- | --- | --- | --- |\n",
"|1|0.06|0|128|16|50|1, 3, 9, 15, 25|\n",
"|2|0.06|0|35|16|1 - 100|1, 3, 9, 15, 25|\n",
"\n",
2021-03-26 20:01:05 +00:00
"### Pickle Results\n",
2021-03-22 20:49:29 +00:00
"\n",
2021-04-30 20:51:04 +01:00
"| test | learning rate | momentum | batch size | hidden nodes | epochs | models | stratify |\n",
"| --- | --- | --- | --- | --- | --- | --- | --- |\n",
"|3|0.06|0.05|35|16|1 - 300|1, 3, 9, 15, 25| |\n",
"|4|0.06|0.05|35|1 - 50|50|1, 3, 9, 15, 25| |\n",
"|5|0.06|0.05|35|1 - 300|50|1, 3, 9, 15, 25| |\n",
"|6|0.001|0.01|35|1 - 400|50|1, 3, 9, 15, 25| |\n",
"|7|0.01|0.01|35|1 - 400|30 - 150|1, 3, 9, 15, 25| |\n",
"|8|0.03|0.01|35|1 - 400|5 - 100|1, 3, 9, 15, 25| |\n",
"|9|0.1|0.01|35|1 - 400|20|1, 3, 9, 15, 25| |\n",
"|10|0.15|0.01|35|1 - 400|20|1, 3, 9, 15, 25, 35, 45| |\n",
"|11|0.15|0.01|35|1 - 400|10|1, 3, 9, 15, 25, 35, 45| |\n",
"|12|0.02|0.01|35|m|50|1, 3, 9, 15, 25, 35, 45| |\n",
"|13|0.01 exp 0.98, 1|0.01|35|1 - 200|50|1, 3, 9, 15, 25, 35, 45| n |\n",
"|14|0.01|0.01|35|1 - 200|50|1, 3, 9, 15, 25, 35, 45| n |\n",
"|15|0.01|0.9|35|50 - 100|50|1, 3, 5, 7, 9, 15, 25, 35, 45| n |\n",
"|16|0.01|0.1|35|50 - 100|50|1, 3, 5, 7, 9, 15, 25, 35, 45| n |\n",
"|17|0.1|0.1|35|50 - 100|50 - 100|1, 3, 5, 7, 9, 15, 25, 35, 45| n |"
2021-03-19 17:21:00 +00:00
]
},
2021-04-28 12:10:25 +01:00
{
"cell_type": "code",
2021-04-30 20:51:04 +01:00
"execution_count": 335,
2021-04-28 12:10:25 +01:00
"metadata": {},
2021-04-30 20:51:04 +01:00
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAY4AAAEGCAYAAABy53LJAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAAqyUlEQVR4nO3deXxU1fnH8c+ThLATIIEECDthX0QQkEUCiCIiuKBgtbVWpQu21KWt9lertdoWrVq1Wou11eKC1qXEiggFgoLsCrJLWATCjsiiItvz+2MuNE1ZZmImk0y+79eLFzNn7r15jsR8c++59xxzd0RERMKVEOsCRESkbFFwiIhIRBQcIiISEQWHiIhERMEhIiIRSYp1ASUhLS3NmzRpUqR9P//8c6pWrVq8BZVy6nP8K2/9BfU5UosWLdrl7nVO9lm5CI4mTZqwcOHCIu2bm5tLdnZ28RZUyqnP8a+89RfU50iZ2Sen+kyXqkREJCIKDhERiUhUg8PMBpnZajPLM7M7TvJ5RTN7Ofh8npk1CdpTzWyGmR0wsz8W2qeLmS0N9nnMzCyafRARkf8WteAws0TgCeAioC1wtZm1LbTZDcAed28BPAKMDdoPAncBt5/k0H8CbgKygj+Dir96ERE5lWiecXQD8tx9nbsfAiYAwwptMwx4Lnj9KjDAzMzdP3f3WYQC5AQzqwfUcPe5Hppk6+/ApdEo/tgx5wcvLGJ2/mE0n5eIyH9E866qBsCmAu83A91PtY27HzGzvUAqsOs0x9xc6JgNTrahmY0CRgGkp6eTm5sbUfGfH3byNh9k0p5jfPiHd7iuXUWqJ5ePq2IHDhyI+L9XWVfe+lze+gvqc3GK29tx3X0cMA6ga9euXpRb0gYNcO58dipvrD3CvQuO8sAVHenXum4xV1r66LbF+Ffe+gvqc3GK5qWqfKBhgfeZQdtJtzGzJCAF2H2GY2ae4ZjFJjHBGNwsmYmje5NaNZnrn13Az99YyudfHYnWlxQRKfWiGRwLgCwza2pmycBIIKfQNjnAdcHr4cB0P82AgrtvBfaZWY/gbqpvAROLv/T/1rZ+DSbe3IvvnteMl+ZvZPBj77Hok0+j/WVFREqlqAWHux8BbgbeAVYCr7j7cjO718yGBps9A6SaWR5wK3Dill0z2wA8DHzbzDYXuCPrB8BfgDxgLfB2tPpQUMWkRO4c3IYJN/XgyFHnyqfm8Nu3V3Lw8NGS+PIiIqVGVMc43H0SMKlQ2y8LvD4IXHmKfZucon0h0L74qoxM92apTP5xH34zaSV/nrmOGat28NCVZ9EhMyVWJYmIlCg9OV4E1StV4LeXd+Rv15/D3i8Pc+mTs3lk6sccPnos1qWJiESdguNr6NeqLlN+3Jehnerz6LQ1XPrEbFZv2x/rskREokrB8TWlVKnAIyPO4qlru7Bt70GGPP4ef5y+hiM6+xCROKXgKCaD2mcw5ZbzuKBdBr+f8jGXPfk+q7bti3VZIiLFTsFRjFKrVeSJb5zNk9eczZbPvuSSx2fx+LQ1GvsQkbii4IiCwR3qMfXWvgxqX4+Hpn7MZU/OZuVWnX2ISHxQcERJ7arJPH51Z5669my27T3IJY/P4uGpH3PoiM4+RKRsU3BE2aD29Zh6S18u6VSfx6atYcjj77F402exLktEpMgUHCWgVtVkHhlxFn/9dlf2HzzC5U/O5v63VvDlIT11LiJlj4KjBPVvnc6UW85jZLdGPP3eegY9+i5z1p5uTkcRkdJHwVHCqleqwG8u68CLN4WWJrn66bnc8dpH7P3icIwrExEJj4IjRno2T2PymPP47nnNeGXhJgY8PJNJS7dqtUERKfUUHDFUOTk0427Ozb1Jr1GRH7zwATf9fRFb934Z69JERE5JwVEKtG+QwsTRvbjzotbMytvJwIffZfycDRw7prMPESl9FBylRFJiAt/t25x3fnweZzWsyV0Tl3PFU5q2RERKHwVHKdM4tSrjb+jGw1
d14pPdXzDksVn87u1VunVXREoNBUcpZGZcfnYm027ty+VnN+CpmWu54A8zyV29I9aliYgoOEqzWlWTeWB4JyaM6kFyYgLf/tsCbn7xA3bsOxjr0kSkHFNwlAE9mqUyaUwfbjm/JVNWbGfAQzN5dvZ6jmrwXERiQMFRRlRMSmTM+VmhwfNGNbnnzRUMe2IWSzTvlYiUMAVHGdM0rSp//043/viNzuzY9xWXPjmbu/65jL1f6slzESkZCo4yyMwY0rE+027ry7d7NuGFeZ8w4KFcXlu0WU+ei0jUKTjKsOqVKnD3Je3Iubk3mbWqcNs/lnDVn+do0SgRiSoFRxxo3yCF17/fkweu6MjanZ8z5PFZ/OrN5ew7qMtXIlL8FBxxIiHBuOqchky/rS9Xd2vIs+9voP/vZ/LGh7p8JSLFS8ERZ2pWSea+SzswcXQvGtSqzC0vL+HKp+awLH9vrEsTkTih4IhTHTNr8sb3ezL2ig6s3/U5Q/84i/97Yyl7Pj8U69JEpIxTcMSxhARjxDmNmH57Nt86twkTFmyi30O5jJ/7iR4eFJEiU3CUAymVK3DP0HZM+lEf2mTU4K5/LuOSx2cxb52WrRWRyCk4ypFWGdV58abuPPGNs/nsi0OMGDeX0S98wOY9X8S6NBEpQxQc5YyZcXHHeky7LZtbzm/JtFWhua8enrKaLw4diXV5IlIGKDjKqcrJobmvpt+WzYXtMnhseh79fz+Tf36Yr9t3ReS0FBzlXP2alXns6s7843vnklY9mR+/vJhfzz3Iok/2xLo0ESmlohocZjbIzFabWZ6Z3XGSzyua2cvB5/PMrEmBz+4M2leb2YUF2m8xs+VmtszMXjKzStHsQ3lxTpPa5IzuzQPDO7L7oHPFn97nhy99qPEPEfkfUQsOM0sEngAuAtoCV5tZ20Kb3QDscfcWwCPA2GDftsBIoB0wCHjSzBLNrAHwI6Cru7cHEoPtpBgkJBhXdW3I2D6V+VH/FkxZvo3+D83kgcmrOPCVxj9EJCSaZxzdgDx3X+fuh4AJwLBC2wwDngtevwoMMDML2ie4+1fuvh7IC44HkARUNrMkoAqwJYp9KJcqJRm3XtCKGbdnM7h9Bk/mriX7wVxemr9Rz3+ICBatgVAzGw4Mcvcbg/ffBLq7+80FtlkWbLM5eL8W6A7cA8x19+eD9meAt939VTMbA9wPfAlMcfdrTvH1RwGjANLT07tMmDChSP04cOAA1apVK9K+ZVXhPq/97CgvrTpE3mfHyKxmXNUqmQ5piYQyPj6Ut3/n8tZfUJ8j1a9fv0Xu3vVknyV9rapKmJnVInQ20hT4DPiHmV17PGAKcvdxwDiArl27enZ2dpG+Zm5uLkXdt6wq3Ods4DvuTF62jd9NXsXDi76gd4s07hzcmnb1U2JVZrEqb//O5a2/oD4Xp2heqsoHGhZ4nxm0nXSb4NJTCrD7NPueD6x3953ufhh4HegZlerlv5gZF3Wox9Rb+vLLIW1ZtmUvQx6fxe3/WMLWvV/GujwRKUHRDI4FQJaZNTWzZEKD2DmFtskBrgteDweme+jaWQ4wMrjrqimQBcwHNgI9zKxKMBYyAFgZxT5IIclJCXynd1Nm3t6Pm/o0I2fxFrIfzGXs5FVavlaknIhacLj7EeBm4B1CP9xfcfflZnavmQ0NNnsGSDWzPOBW4I5g3+XAK8AKYDIw2t2Puvs8QoPoHwBLg/rHRasPcmopVSrw88FtmHZbXy5qn8GfctfS98EZ/OW9dXx15GisyxORKIrqGIe7TwImFWr7ZYHXB4ErT7Hv/YQGwQu33w3cXbyVSlE1rF2FP4zszI19mjF28irue2slz76/gdsvaMXQTvVJSIifAXQRCdGT41Is2jdIYfwN3fn7d7pRo1IFfvzyYoY8Povc1Ts0hYlInFFwSLE6r2Ud/vXD3jwyohP7Dh7m239bwNVPz+WDjZrCRCReKDik2CUkGJd1zmT6bdn8amg78nYc4PIn3+e74x
eSt2N/rMsTka9JwSFRk5yUwHU9mzDzJ/24dWBLZuft5oJH3uX2fyzRHFgiZZiCQ6KuasUkfjQgi3d/2o/rezUlZ8k
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
2021-04-28 12:10:25 +01:00
"source": [
"batch_size=35\n",
"test_size=0.5\n",
"epochs=50\n",
"# lr(step) = 0.01 * 0.98 ** step; Keras advances `step` once per optimizer update\n",
"# (i.e. per batch), not once per epoch\n",
"lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(0.01,\n",
"                                                             decay_steps=1,\n",
"                                                             decay_rate=0.98)\n",
"\n",
"# convert epochs to optimizer steps so the x-axis really is in epochs; the train\n",
"# size mirrors train_test_split (which rounds the test share up) and assumes no\n",
"# validation_split is carved out during fit\n",
"train_size = len(data) - math.ceil(len(data) * test_size)\n",
"steps_per_epoch = math.ceil(train_size / batch_size)\n",
"\n",
"plt.plot(range(epochs+1), [lr_schedule(i * steps_per_epoch) for i in range(epochs+1)])\n",
"plt.grid()\n",
"plt.ylim(0)\n",
"plt.xlabel('Epochs')\n",
"plt.ylabel('Learning Rate')\n",
"plt.show()"
2021-04-28 12:10:25 +01:00
]
},
2021-03-19 17:21:00 +00:00
{
"cell_type": "code",
2021-04-30 20:51:04 +01:00
"execution_count": 357,
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp2"
]
},
2021-03-22 20:49:29 +00:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2021-04-06 17:29:15 +01:00
"Iteration 1/30\n",
"Iteration 2/30\n",
"Iteration 3/30\n",
"Iteration 4/30\n",
"Iteration 5/30\n",
"Iteration 6/30\n",
"Iteration 7/30\n",
"Iteration 8/30\n",
"Iteration 9/30\n",
"Iteration 10/30\n",
"Iteration 11/30\n",
"Iteration 12/30\n",
"Iteration 13/30\n",
"Iteration 14/30\n",
"Iteration 15/30\n",
"Iteration 16/30\n",
"Iteration 17/30\n",
"Iteration 18/30\n",
"Iteration 19/30\n",
"Iteration 20/30\n",
"Iteration 21/30\n",
"Iteration 22/30\n",
"Iteration 23/30\n",
"Iteration 24/30\n",
"Iteration 25/30\n",
"Iteration 26/30\n",
"Iteration 27/30\n",
"Iteration 28/30\n",
"Iteration 29/30\n",
"Iteration 30/30\n"
2021-03-22 20:49:29 +00:00
]
}
],
2021-03-19 17:21:00 +00:00
"source": [
"multi_ensem_results = list()\n",
"multi_ensem_iterations = 30\n",
"stratify_splits = False # set True to stratify each split on `labels`\n",
"for i in range(multi_ensem_iterations):\n",
"    print(f\"Iteration {i+1}/{multi_ensem_iterations}\")\n",
"    # a fresh train/test split per iteration, seeded by the iteration index so the\n",
"    # splits are reproducible (model weight initialisation is still unseeded)\n",
"    data_train, data_test, labels_train, labels_test = train_test_split(data, labels, test_size=test_size,\n",
"                                                                        stratify=labels if stratify_splits else None,\n",
"                                                                        random_state=i)\n",
"    multi_ensem_results.append(list(evaluate_ensemble_vote(epochs=(50, 100),\n",
"                                                          hidden_nodes=(50, 100),\n",
"                                                          nmodels=[1, 3, 5, 7, 9, 15, 25, 35, 45],\n",
"                                                          optimizer=lambda: tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.1),\n",
"                                                          weight_init=lambda: 'random_uniform',\n",
"                                                          batch_size=batch_size,\n",
"                                                          dtrain=data_train, \n",
"                                                          dtest=data_test, \n",
"                                                          ltrain=labels_train, \n",
"                                                          ltest=labels_test,\n",
"                                                          return_model=False,\n",
"                                                          print_params=False)))"
]
},
{
"cell_type": "markdown",
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp2"
]
},
2021-03-19 17:21:00 +00:00
"source": [
"### Accuracy Tensor\n",
"\n",
"Create a tensor for holding the accuracy results\n",
"\n",
2021-03-26 20:01:05 +00:00
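"Shape: (iterations x params x number of ensemble sizes)\n",
"\n",
"#### Params\n",
"0. Ensemble Test Accuracy\n",
"1. Train Accuracy (mean final-epoch training accuracy of the individual models)\n",
"2. Individual Test Accuracy\n",
"3. Agreement (plotted as disagreement, i.e. 1 - agreement)"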
2021-03-19 17:21:00 +00:00
]
},
2021-03-30 16:31:10 +01:00
{
"cell_type": "code",
2021-04-30 20:51:04 +01:00
"execution_count": 322,
2021-03-30 16:31:10 +01:00
"metadata": {},
"outputs": [],
"source": [
"def test_tensor_data(test):\n",
"    \"\"\"Reduce one ensemble test result to the four accuracy-tensor params.\n",
"\n",
"    Returns [ensemble test accuracy, mean final-epoch train accuracy of the\n",
"    individual models, individual accuracy, agreement], in the param order\n",
"    listed in the Accuracy Tensor section.\n",
"    \"\"\"\n",
"    final_train_accuracies = [history[\"accuracy\"][-1] for history in test[\"history\"]]\n",
"    return [\n",
"        test[\"accuracy\"],\n",
"        np.mean(final_train_accuracies),\n",
"        test[\"individual_accuracy\"],\n",
"        test[\"agreement\"],\n",
"    ]"
]
},
2021-03-19 17:21:00 +00:00
{
"cell_type": "code",
2021-04-30 20:51:04 +01:00
"execution_count": 362,
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp2"
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2021-04-06 17:29:15 +01:00
"30 Tests\n",
2021-04-30 20:51:04 +01:00
"Models: [1, 3, 5, 7, 9, 15, 25, 35, 45]\n",
2021-03-26 20:01:05 +00:00
"\n",
"Loss: categorical_crossentropy\n",
2021-04-30 20:51:04 +01:00
"LR: 0.1\n",
"Momentum: 0.1\n"
2021-03-26 20:01:05 +00:00
]
}
],
2021-03-19 17:21:00 +00:00
"source": [
"multi_ensem_models = sorted({run[\"num_models\"] for run in multi_ensem_results[0]})\n",
"multi_ensem_iter = len(multi_ensem_results)\n",
"\n",
"# Shape: (iterations, 4 params, number of ensemble sizes)\n",
"accuracy_ensem_tensor = np.zeros((multi_ensem_iter, 4, len(multi_ensem_models)))\n",
"for iter_idx, iteration in enumerate(multi_ensem_results):\n",
"    for single_test in iteration:\n",
"        model_col = multi_ensem_models.index(single_test['num_models'])\n",
"        accuracy_ensem_tensor[iter_idx, :, model_col] = test_tensor_data(single_test)\n",
"\n",
"# Aggregate over iterations -> (4 params, number of ensemble sizes)\n",
"mean_ensem_accuracy = accuracy_ensem_tensor.mean(axis=0)\n",
"std_ensem_accuracy = accuracy_ensem_tensor.std(axis=0)\n",
"\n",
"first_test = multi_ensem_results[0][0]\n",
"print(f'{multi_ensem_iter} Tests')\n",
"print(f'Models: {multi_ensem_models}')\n",
"print()\n",
"print(f'Loss: {first_test[\"loss\"]}')\n",
"print(f'LR: {first_test[\"optimizer\"][\"learning_rate\"]:.3}')\n",
"print(f'Momentum: {first_test[\"optimizer\"][\"momentum\"]:.3}')"
]
},
{
"cell_type": "markdown",
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp2"
]
},
2021-03-19 17:21:00 +00:00
"source": [
"#### Export/Import Test Sets\n",
"\n",
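"Pickle the raw per-iteration results for later retrieval and visualisation, or reload a previously saved test set. The raw (non-executed) cells below export/import the mean and standard deviation arrays as CSV."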
]
},
{
2021-04-25 15:25:41 +01:00
"cell_type": "code",
2021-04-30 20:51:04 +01:00
"execution_count": 358,
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp2"
]
},
2021-04-25 15:25:41 +01:00
"outputs": [],
"source": [
"exp2_testname = 'exp2-test17'\n",
"# Context manager guarantees the pickle is flushed and the handle closed\n",
"with open(f\"results/{exp2_testname}.p\", \"wb\") as results_file:\n",
"    pickle.dump(multi_ensem_results, results_file)"
]
},
{
"cell_type": "code",
2021-04-30 20:51:04 +01:00
"execution_count": 349,
"metadata": {},
"outputs": [],
"source": [
"# Optionally analyse a previously pickled test set instead of this session's results.\n",
"# Off by default: under Restart & Run All this cell runs after the accuracy tensor\n",
"# cell above, so an unconditional reload would only relabel the graphs below with\n",
"# another test's name while they still plot the freshly computed data.\n",
"# To analyse an old test set: name it here, run this cell, then re-run the\n",
"# accuracy tensor cell before the Best Results / plotting cells.\n",
"exp2_load_testname = None  # e.g. 'exp2-test16'\n",
"if exp2_load_testname is not None:\n",
"    exp2_testname = exp2_load_testname\n",
"    # NOTE: only unpickle trusted, locally produced result files\n",
"    with open(f\"results/{exp2_testname}.p\", \"rb\") as results_file:\n",
"        multi_ensem_results = pickle.load(results_file)"
]
},
{
"cell_type": "raw",
"metadata": {},
2021-03-19 17:21:00 +00:00
"source": [
"np.savetxt(\"exp2-mean.csv\", mean_ensem_accuracy, delimiter=',')\n",
"np.savetxt(\"exp2-std.csv\", std_ensem_accuracy, delimiter=',')"
]
},
{
"cell_type": "raw",
"metadata": {},
"source": [
"mean_ensem_accuracy = np.loadtxt(\"results/test1-exp2-mean.csv\", delimiter=',')\n",
2021-03-26 20:01:05 +00:00
"std_ensem_accuracy = np.loadtxt(\"results/test1-exp2-std.csv\", delimiter=',')"
2021-03-19 17:21:00 +00:00
]
},
2021-03-22 20:49:29 +00:00
{
"cell_type": "markdown",
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp2"
]
},
2021-03-22 20:49:29 +00:00
"source": [
"### Best Results"
]
},
2021-03-19 17:21:00 +00:00
{
"cell_type": "code",
2021-04-30 20:51:04 +01:00
"execution_count": 363,
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp2"
]
},
2021-03-22 20:49:29 +00:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2021-04-30 20:51:04 +01:00
"Models: 5, 96.4% Accurate\n"
2021-03-22 20:49:29 +00:00
]
}
],
"source": [
"# Column (ensemble size) with the highest mean ensemble test accuracy (param row 0)\n",
"best_ensem_col = int(np.argmax(mean_ensem_accuracy[0, :]))\n",
"best_ensem_accuracy_idx = (0, best_ensem_col)\n",
"best_ensem_accuracy = mean_ensem_accuracy[best_ensem_accuracy_idx]\n",
"best_ensem_accuracy_models = multi_ensem_models[best_ensem_col]\n",
"\n",
"# Fixed-point formatting: '.3' would render scores of 99.95% and above as '1e+02'\n",
"print(f'Models: {best_ensem_accuracy_models}, {best_ensem_accuracy * 100:.1f}% Accurate')"
]
},
{
"cell_type": "markdown",
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp2"
]
},
2021-03-22 20:49:29 +00:00
"source": [
2021-03-27 16:29:31 +00:00
"### Test/Train Error Over Model Numbers"
2021-03-22 20:49:29 +00:00
]
},
{
"cell_type": "code",
2021-04-30 20:51:04 +01:00
"execution_count": 364,
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp2"
]
},
2021-03-19 17:21:00 +00:00
"outputs": [
{
"data": {
2021-04-30 20:51:04 +01:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAABJwAAAMMCAYAAAAW770RAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAB7CAAAewgFu0HU+AADnjklEQVR4nOzdd5zc1Ln/8c8zW13WvWFsMAZMNQQwECAGU0IgprcESMCUQCi5hJLkJiTBXNJ+oYQAoYVic+k4NFMSX4diAwZjCMSYgLGNce9t1946c35/aGZXMzt1VzPbvu/X6rUa6eicoxlJIz1zdGTOOURERERERERERIISausKiIiIiIiIiIhI56KAk4iIiIiIiIiIBEoBJxERERERERERCZQCTiIiIiIiIiIiEigFnEREREREREREJFAKOImIiIiIiIiISKAUcBIRERERERERkUAp4CQiIiIiIiIiIoFSwElERERERERERAKlgJOIiIiIiIiIiARKAScREREREREREQmUAk4iIiIiIiIiIhIoBZxERERERERERCRQCjiJiIiIiIiIiEigFHASEREREREREZFAKeAkIiIiIiIiIiKBUsBJREREREREREQCpYCTiIiIiIiIiIgESgEnEREREREREREJlAJOIiIibcDMupvZNWY2w8zWmVmDmbnoMLGt6yfZMbPFvs9tRFvXR8DMxvk+kzfauj7S9szsDd82MS6gPCf48pwURJ7SemY2yfe5TChgubEyXaHKFOkIFHASkQ4h4WQx12FSW9e/o0k4YUs2bDWzFWb2mpn9xsxGtXWdOxIz6wfMBm4FxgL9gaI2rVSAEraVcTku69/XJ+algiI5SAhgtWQY0dbrIF1TQkA8NuybYx7PJMljQp6qLCKdjAJOIiLSEt2B7YAjgeuBz8zsL2ZWXuiKdNAWJv8P2Cs63gD8A7gf+Et0mN1G9RJplbZqXSAiWTsv24Rm1hc4MY91EZFOrritKyAi0gLvk9sF+bv5qkgX8Rnwz4RpPYE9gAMBiw6XA9uZ2enOOTUpT8HMioGzfZOOds7NaKv6iEjO/pJj+i15qYVIy5xjZj91zoWzSPsdoCzfFRKRzksBJxHpiF5xzk1s60p0Ie85565MNsPM9gSeAPaJTjoVOA34W4Hq1hGNAnpExxco2NSxOedGtHUdpLBSHQ9F2rlPgT2BIcCxwKtZLBNrDVUHLAF2yU/VRKSz0i11IiLSYs65T4Hjga2+yZe2UXU6ir6+8ZVtVgsREelK/tc3nvG2OjPbFTgk+vIVYH0+KiUinZsCTiIi0irOuRXA075J3zAza6v6dAAlvvFIm9VCRES6ko+Af0fHTzazXhnS+4NSk/NSIxHp9BRwEpEuK1ln02a2s5n91sz+ZWZrzSxiZh/5lmn2aGUz287MfmFms81slZmFzWxTijJ3NLP/MbN3zWy1mdVF/79rZjea2fAs6p30kd9m9m0ze8LMvjCzquj8H7fiLcrFR77xbsS34oljZiEzGxt9H6aZ2RIz22ZmtWa2Mvrku+vNbECaPEbE3gNgR9+sL1M8JWpcmrz6m9m1ZvZ/ZrbUzGrMbJOZfRrtCH1M1u9CdvV93TfriCR1fSNNPt8ys4fMbL6ZbTGzajP7ysyeM+8R3SWplvXl0axTZzPrY2ZXmdkMM1tuZg3R+X1auep5Z54zo9v+wui2XxUdf9zMzsgmANrSfTvZcSRhfkufbDYuTV1LzOwCM3s++vlXR7eHz83sQTP7ZpbvXbJj4DAzu8nMPo7uB1vN7DMzu9PMdsyUF3C+b/LDKdZtYpLle5vZ2WZ2n5m9Z2brzDs+bol+lk+Y2Vlm1inOXaP7a7MnqZrZqWY21bzjYq2ZrTHvOPm9bLbjaB4HmtldZvahmW2M7s/V5h1f3zWze6LvZY8s8uphZpdF6/SVecfqSvO+Zx4ys6Nasq7mfQ+cY2avmnfcrTXvu/BvZnZIkjxKzez7ZvZPazpOLz
GzyWa2RzbvS5I8h5v3lNWPzWyDb1v/k5nl5dYxM9vDzH5n3jEldg6wNrrN/4+ZDc1HuT6PRP93A85MU08Dvhd9uR54OdeCgjpOJeR5ipm9YN73VK2ZLTPvu/v75vWPmDPznBrdluab2ebo9rU0WvfzW5p3ivIGmNl1ZjbdvKf91phZvXnH23lmNsXMrjGznYIqU6RNOec0aNCgod0PwBuAiw4TA8pzsS/PEcAlQLVvWmz4KEU9xgEnAxuSLLMpSXnXp8jfP1QDP8tQ73G+9G8AvYFnU+T34xa+N5N8eUzKIv0PEsodmiJdCbAsw3sQG6qA76XIZ0SWeTR+VinyuQLYlGHZCPAgUNqKbS2X+r6RZPlBwPQslp0PjMnhs50AHIbXN0ey/Pq0cH0zvvdB7OvArsCHWbwvc4CROZQ7jiz3bRKOIxneiyC22YOBBVksPw0YkGGd4+oOnEL6/WEbMD6LvDINExOWPQ2oyXLZj4CdMqzXOF/6ZvtTC7Znf36utflF85zgy3MS3nH8hQzr/irQLU2excB9OXwOv8lQxzPxbvvNlM9UoHcO6zoA70EUqfKLABf4lt8Fr/+hVOlrgVNy3L9PIvO2fkkun2GGtGXAvXhPJU33Xm4DrgxiG0uyXx6H139TrA5vplnuCN9yd0WnveubNiFDuYEdp6L59cQLeqXLa2Z0/SblUM99gH9lUc/PgD0z5JXxGEHq75Vkw7KgtgMNGtpyUKfhIiKeM4E/RsdXAG8Dm4GhQL8UyxwKTMQLoqwHZgDr8IID+/kTmtldeMGNmCq8li6r8E6QjsQ7oSoH/mBmQ5xzV2dRbwMeBU7AO0GZg3dibsDe0WmF4P9VNkzqvh6KgO2j41XAPGAR3lOcSoBhwNeBXngda/+vmdU7555KyGcLTU+KOg+oiI4/AlQmKXd54gQzux24yjdpHTAL7zMpx/sM98Z7Ly8EhprZeOdcS26D89d3e7wLe/C2tecS0n6RUM/BeNvjzr7JC4H38C609sQ7uQcvAPO6mR3nnHs7i3rtAtyOd7FbibcNr8BroXZ4Fsu3mWirhjeBgb7Jc/ECEg7v8xsdnX4A8I6ZHe6cm59F9lnv21nI9olmI4DxvtfN9l0zOxwv6NDdl2Y23j5firfvxLaTbwJvm9k3nHNrsyj/GLwL4iK8AOQsvO12J7wL9GK8VhFPm9nezrkvE5afDPQHjgZ2j077J96FWqLEp4wOoulJWMui67MK7+I79kTM/fH2xX2BGWb2NedcZ+lTphjvQQtH43XO/A7ePl4OjAV2iKY7DrgNuCxFPjfj/XASsxzvvV6Ld1dDf7zjxW6ZKmRmVwO34r3n4G0Ls/A+nyJgL2BMdP4JwBtmdphzblsW6/psdL1q8PbhJXjfs0cDfaJ5PmBmX+AF0V8DhkfrMAMvCDYYb5vtjrftP25meyXZLpMZA/w2utx6vGDURrx98Ai8/b4bcJ+ZhZ1zD2aRZ0rRlmT/wAvuxywEPoiW2y86b2i03DvNrJdz7netKTcZ59wqM/s/vG1prJmNcM4tTpL0fN/4I0nmpxT0ccq8lrsvE/+dtApvW6jE+x77RnR4Du+cItt6TsU73wCox3sK8hfR8RHRPMvx9pl3zOwQ59x/ssk/SXljgCk0PbSrGi+Itxjvu7wX3vsymqb3TqTja+uIlwYNGjRkM5D/Fk71eF/4PwAsIV1ZinrU4/0S+0ugJM0yZxH/q9XDQK+E9L3wOvT0pzstRb3HJdTB4fXLMDpJ2rJs34+E5Sb5ypiURfp3fOnfT5OuFHgoug4lKdKUAT/xrdtGoGeWn+OILNfvQt8ym4GLk9UHLxDob5H10wC2O//n90YW6V/xpa8CvpskzRi8C5hYuiWkaJ2U8NnG3uO7Et9jvIuuUAvX0b8dj8txWf8+NjHNdvSRL91q4Jgk6Y7Fu9iOpfsgzXbnLzerfbul21+SsnsT34LjzSTl9k3YFucDByTJ61y8QE0s3YtZ7js10e3rez
Q/Bu6VUPZDafL0b18Tslz/E4H/BnZJk2Yn4O++vB8Iah/Lon7+/Fxr84vmOSHhvXd4+/r2CemK8QJJsbSRZNsZXjA
2021-03-19 17:21:00 +00:00
"text/plain": [
2021-04-30 20:51:04 +01:00
"<Figure size 1200x800 with 1 Axes>"
2021-03-19 17:21:00 +00:00
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig, ax = plt.subplots(figsize=(6, 4))\n",
"fig.set_dpi(fig_dpi)\n",
"\n",
"# (param row of the accuracy tensor, legend label); error = 1 - accuracy\n",
"error_series = [(0, 'Ensemble Test'),\n",
"                (2, 'Individual Test'),\n",
"                (1, 'Individual Train'),\n",
"                (3, 'Disagreement')]\n",
"for row, label in error_series:\n",
"    ax.errorbar(multi_ensem_models, 1 - mean_ensem_accuracy[row, :],\n",
"                yerr=std_ensem_accuracy[row, :], capsize=4, label=label)\n",
"\n",
"ax.set_title(\"Error Rate for Horizontal Ensemble Models\")\n",
"ax.grid()\n",
"ax.legend()\n",
"ax.set_xlabel(\"Number of Models\")\n",
"ax.set_ylabel(\"Error Rate\")\n",
"\n",
"fig.tight_layout()\n",
"fig.savefig(f'graphs/{exp2_testname}-error-rate-curves.png')\n",
"\n",
"plt.show()"
]
},
2021-03-26 20:01:05 +00:00
{
"cell_type": "markdown",
"metadata": {
"id": "FSZq1mNiVZq_",
"tags": [
"ex3"
]
2021-03-19 17:21:00 +00:00
},
"source": [
"# Experiment 3\n",
"\n",
"Repeat Exp 2 for the cancer dataset with two different optimisers of your choice, e.g. 'trainlm' and 'trainrp'. Comment on and discuss the results, and decide which is the more appropriate training algorithm for the problem. In your discussion, include a detailed account of how the training algorithms (optimisers) work."
2021-03-19 17:21:00 +00:00
]
},
{
"cell_type": "code",
2021-04-30 20:51:04 +01:00
"execution_count": 127,
"metadata": {},
"outputs": [],
"source": [
"def evaluate_optimisers(optimizers=[(lambda: 'sgd', 'sgd'),\n",
"                                    (lambda: 'adam', 'adam'),\n",
"                                    (lambda: 'rmsprop', 'rmsprop')],\n",
"                        weight_init=lambda: 'glorot_uniform',\n",
"                        print_params=True,\n",
"                        **kwargs\n",
"                        ):\n",
"    \"\"\"Run the ensemble-vote sweep once per optimiser.\n",
"\n",
"    Args:\n",
"        optimizers: sequence of (factory, name) pairs; the factory (e.g.\n",
"            `lambda: 'sgd'`) is forwarded as evaluate_ensemble_vote's\n",
"            `optimizer`, the name is printed and used in the `exp` tag '3-<name>'.\n",
"        weight_init: forwarded unchanged as `weight_init`.\n",
"        print_params: print each optimiser's name; also forwarded.\n",
"        **kwargs: forwarded unchanged to evaluate_ensemble_vote.\n",
"\n",
"    Yields:\n",
"        One list of evaluate_ensemble_vote results per optimiser, in order.\n",
"    \"\"\"\n",
"    for optimiser in optimizers:\n",
"        make_optimiser, optimiser_name = optimiser[0], optimiser[1]\n",
"        if print_params:\n",
"            print(f'Optimiser: {optimiser_name}')\n",
"\n",
"        results = evaluate_ensemble_vote(optimizer=make_optimiser,\n",
"                                         weight_init=weight_init,\n",
"                                         exp=f'3-{optimiser_name}',\n",
"                                         print_params=print_params,\n",
"                                         **kwargs)\n",
"        yield list(results)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Single Iteration"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Optimiser: sgd\n",
"Models: 1\n",
"Models: 3\n",
"Models: 5\n",
"Optimiser: adam\n",
"Models: 1\n",
"Models: 3\n",
"Models: 5\n",
"Optimiser: rmsprop\n",
"Models: 1\n",
"Models: 3\n",
"Models: 5\n"
]
}
],
"source": [
"# Quick smoke run: default optimisers, small ensembles\n",
"single_optim_results = list(evaluate_optimisers(epochs=(5, 300), nmodels=[1, 3, 5]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Multiple Iterations\n",
"\n",
"### Pickle Results\n",
"\n",
2021-04-29 22:53:26 +01:00
"| test | optim1 | optim2 | optim3 | lr | momentum | epsilon | batch size | hidden nodes | epochs | models | stratified |\n",
"| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |\n",
"| 1 | SGD | Adam | RMSprop | 0.1 | 0.0 | 1e-7 | 35 | 16 | 1 - 100 | 1, 3, 9, 15, 25 | y |\n",
"| 2 | SGD | Adam | RMSprop | 0.05 | 0.01 | 1e-7 | 35 | 16 | 1 - 100 | 1, 3, 9, 15, 25 | y |\n",
"| 3 | SGD | Adam | RMSprop | 0.1 | 0.01 | 1e-7 | 35 | 1 - 400 | 20 | 1, 3, 9, 15, 25, 35, 45 | y |\n",
"| 4 | SGD | Adam | RMSprop | 0.075 | 0.01 | 1e-7 | 35 | 1 - 400 | 20 | 1, 3, 9, 15, 25, 35, 45 | y |\n",
2021-04-30 20:51:04 +01:00
"| 5 | SGD | Adam | RMSprop | 0.05 | 0.01 | 1e-7 | 35 | 1 - 400 | 20 | 1, 3, 9, 15, 25, 35, 45 | n |\n",
"| 6 | SGD | Adam | RMSprop | 0.02 | 0.01 | 1e-7 | 35 | m | 50 | 1, 3, 9, 15, 25, 35, 45 | n |\n",
"| 7 | SGD | Adam | RMSprop | 0.1 | 0.9 | 1e-8 | 35 | 1 - 400 | 50 - 100 | 1, 3, 5, 7, 9, 15, 25 | n |\n",
"| 8 | SGD | Adam | RMSprop | 0.05 | 0.9 | 1e-8 | 35 | 1 - 400 | 50 - 100 | 1, 3, 5, 7, 9, 15, 25 | n |"
]
},
{
"cell_type": "code",
2021-04-29 22:53:26 +01:00
"execution_count": 27,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2021-03-30 16:31:10 +01:00
"Iteration 1/30\n",
"Iteration 2/30\n",
"Iteration 3/30\n",
"Iteration 4/30\n",
"Iteration 5/30\n",
"Iteration 6/30\n",
"Iteration 7/30\n",
"Iteration 8/30\n",
"Iteration 9/30\n",
"Iteration 10/30\n",
"Iteration 11/30\n",
"Iteration 12/30\n",
"Iteration 13/30\n",
"Iteration 14/30\n",
"Iteration 15/30\n",
"Iteration 16/30\n",
"Iteration 17/30\n",
"Iteration 18/30\n",
"Iteration 19/30\n",
"Iteration 20/30\n",
"Iteration 21/30\n",
"Iteration 22/30\n",
"Iteration 23/30\n",
"Iteration 24/30\n",
"Iteration 25/30\n",
"Iteration 26/30\n",
"Iteration 27/30\n",
"Iteration 28/30\n",
"Iteration 29/30\n",
"Iteration 30/30\n"
]
}
],
"source": [
"multi_optim_results = list()\n",
"multi_optim_iterations = 30\n",
"\n",
"multi_optim_lr = 0.05\n",
"multi_optim_mom = 0.01\n",
"multi_optim_eps = 1e-07\n",
"# (factory, name) pairs; factories let a new optimiser instance be built on each call\n",
"multi_optims = [(lambda: tf_optim.SGD(learning_rate=multi_optim_lr,\n",
"                                      momentum=multi_optim_mom), 'sgd'),\n",
"                (lambda: tf_optim.Adam(learning_rate=multi_optim_lr,\n",
"                                       epsilon=multi_optim_eps), 'adam'),\n",
"                (lambda: tf_optim.RMSprop(learning_rate=multi_optim_lr,\n",
"                                          momentum=multi_optim_mom,\n",
"                                          epsilon=multi_optim_eps), 'rmsprop')]\n",
"\n",
"for i in range(multi_optim_iterations):\n",
"    print(f\"Iteration {i+1}/{multi_optim_iterations}\")\n",
"    # Fresh non-stratified split; test_size / batch_size come from the shared\n",
"    # config cell so Experiment 3 stays consistent with Experiment 2\n",
"    data_train, data_test, labels_train, labels_test = train_test_split(data, labels, test_size=test_size)\n",
"    multi_optim_results.append(list(evaluate_optimisers(epochs=(50, 100),\n",
"                                                        hidden_nodes=(1, 400),\n",
"                                                        nmodels=[1, 3, 9, 15, 25],\n",
"                                                        optimizers=multi_optims,\n",
"                                                        weight_init=lambda: 'random_uniform',\n",
"                                                        batch_size=batch_size,\n",
"                                                        dtrain=data_train,\n",
"                                                        dtest=data_test,\n",
"                                                        ltrain=labels_train,\n",
"                                                        ltest=labels_test,\n",
"                                                        return_model=False,\n",
"                                                        print_params=False)))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Accuracy Tensor\n",
"\n",
"Create a tensor for holding the accuracy results\n",
"\n",
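"Shape: (iterations x params x number of ensemble sizes), one tensor per optimiser\n",
"\n",
"#### Params\n",
"0. Ensemble Test Accuracy\n",
"1. Train Accuracy (mean final-epoch training accuracy of the individual models)\n",
"2. Individual Test Accuracy\n",
"3. Agreement (plotted as disagreement, i.e. 1 - agreement)"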
]
},
{
"cell_type": "code",
2021-04-30 20:51:04 +01:00
"execution_count": 467,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2021-03-30 16:31:10 +01:00
"30 Tests\n",
"Optimisers: ['SGD', 'Adam', 'RMSprop']\n",
2021-04-30 20:51:04 +01:00
"Models: [1, 3, 5, 7, 9, 15, 25]\n",
"\n",
"Loss: categorical_crossentropy\n"
]
}
],
"source": [
"multi_optim_results_dict = dict()  # indexed by optimiser name\n",
"multi_optim_iter = len(multi_optim_results)  # number of iterations (30)\n",
"\n",
"#####################################\n",
"## INDIVIDUAL RESULTS TO DICTIONARY\n",
"#####################################\n",
"# Regroup iteration -> optimiser -> tests into optimiser name -> iteration -> tests\n",
"for iter_idx, iteration in enumerate(multi_optim_results):\n",
"    for optimiser_tests in iteration:\n",
"        for single_optim_test in optimiser_tests:\n",
"            single_optim_name = single_optim_test[\"optimizer\"][\"name\"]\n",
"            if single_optim_name not in multi_optim_results_dict:\n",
"                multi_optim_results_dict[single_optim_name] = [[] for _ in range(multi_optim_iter)]\n",
"            multi_optim_results_dict[single_optim_name][iter_idx].append(single_optim_test)\n",
"\n",
"# Sorted ensemble sizes, taken from the first optimiser of the first iteration\n",
"multi_optim_models = sorted({test[\"num_models\"] for test in multi_optim_results[0][0]})\n",
"\n",
"##################################\n",
"## DICTIONARY TO RESULTS TENSORS\n",
"##################################\n",
"optim_tensors = dict()\n",
"for optim, optim_results in multi_optim_results_dict.items():\n",
"    # Shape: (iterations, 4 params, number of ensemble sizes)\n",
"    accuracy_optim_tensor = np.zeros((multi_optim_iter, 4, len(multi_optim_models)))\n",
"    for iter_idx, iteration in enumerate(optim_results):\n",
"        for single_test in iteration:\n",
"            model_col = multi_optim_models.index(single_test['num_models'])\n",
"            accuracy_optim_tensor[iter_idx, :, model_col] = test_tensor_data(single_test)\n",
"\n",
"    optim_tensors[optim] = {\n",
"        \"accuracy\": accuracy_optim_tensor,\n",
"        \"mean\": accuracy_optim_tensor.mean(axis=0),\n",
"        \"std\": accuracy_optim_tensor.std(axis=0),\n",
"    }\n",
"\n",
"print(f'{multi_optim_iter} Tests')\n",
"print(f'Optimisers: {list(multi_optim_results_dict)}')\n",
"print(f'Models: {multi_optim_models}')\n",
"print()\n",
"print(f'Loss: {multi_optim_results[0][0][0][\"loss\"]}')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Export/Import Test Sets\n",
"\n",
"Pickle the raw per-iteration results for later retrieval and visualisation, or reload a previously saved test set."
]
},
{
2021-03-30 16:31:10 +01:00
"cell_type": "code",
2021-04-29 22:53:26 +01:00
"execution_count": 28,
"metadata": {},
2021-03-30 16:31:10 +01:00
"outputs": [],
"source": [
"# Context manager guarantees the pickle is flushed and the handle closed.\n",
"# NOTE(review): the target name is fixed and not tied to the run parameters\n",
"# above; check it against the results table before re-running so an earlier\n",
"# test set is not overwritten.\n",
"with open(\"results/exp3-test5.p\", \"wb\") as results_file:\n",
"    pickle.dump(multi_optim_results, results_file)"
]
},
{
2021-04-06 17:29:15 +01:00
"cell_type": "code",
2021-04-30 20:51:04 +01:00
"execution_count": 466,
"metadata": {},
2021-04-06 17:29:15 +01:00
"outputs": [],
"source": [
"exp3_testname = 'exp3-test8'\n",
"# NOTE: this reload runs after the accuracy tensor cell above; re-run that cell\n",
"# afterwards so the Best Results / plotting cells reflect the loaded test set.\n",
"# Only unpickle trusted, locally produced result files.\n",
"with open(f\"results/{exp3_testname}.p\", \"rb\") as results_file:\n",
"    multi_optim_results = pickle.load(results_file)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Best Results"
]
},
{
"cell_type": "code",
2021-04-30 20:51:04 +01:00
"execution_count": 468,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2021-04-30 20:51:04 +01:00
"SGD: 9 Models, 96.5% Accurate\n",
"Adam: 7 Models, 96.3% Accurate\n",
"RMSprop: 9 Models, 96.3% Accurate\n"
]
}
],
"source": [
"for optim, optim_results in optim_tensors.items():\n",
"    mean_accuracy = optim_results[\"mean\"]\n",
"    # Column (ensemble size) with the highest mean ensemble test accuracy (param row 0)\n",
"    best_optim_col = int(np.argmax(mean_accuracy[0, :]))\n",
"    best_optim_accuracy_idx = (0, best_optim_col)\n",
"    best_optim_accuracy = mean_accuracy[best_optim_accuracy_idx]\n",
"    best_optim_accuracy_models = multi_optim_models[best_optim_col]\n",
"\n",
"    # Fixed-point formatting: '.3' would render scores of 99.95% and above as '1e+02'\n",
"    print(f'{optim}: {best_optim_accuracy_models} Models, {best_optim_accuracy * 100:.1f}% Accurate')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Optimiser Error Rates"
]
},
{
"cell_type": "code",
2021-04-30 20:51:04 +01:00
"execution_count": 469,
"metadata": {},
"outputs": [
{
"data": {
2021-04-30 20:51:04 +01:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAACWQAAAJECAYAAACbnL/OAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAB7CAAAewgFu0HU+AAEAAElEQVR4nOzdd5gV1f3H8c+X3ouKFQVR7FhBY2UtsWGwm6hRMcaCXTSaxEQxJpqIvZv8UDA2EHtHIwsWFBALoiioqIiKSO9lz++PM3d37t1b5rbdu3ffr+e5z96dOXPmzJ12vjNnzphzTgAAAAAAAAAAAAAAAACA/DWp7wIAAAAAAAAAAAAAAAAAQLmgQRYAAAAAAAAAAAAAAAAAFAgNsgAAAAAAAAAAAAAAAACgQGiQBQAAAAAAAAAAAAAAAAAFQoMsAAAAAAAAAAAAAAAAACgQGmQBAAAAAAAAAAAAAAAAQIHQIAsAAAAAAAAAAAAAAAAACoQGWQAAAAAAAAAAAAAAAABQIDTIAgAAAAAAAAAAAAAAAIACoUEWAAAAAAAAAAAAAAAAABQIDbIAAAAAAAAAAAAAAAAAoEBokAUAAAAAAAAAAAAAAAAABUKDLAAAAAAAAAAAAAAAAAAoEBpkAQAAAAAAAAAAAAAAAECB0CALAAAAAAAAAAAAAAAAAAqEBlkAAAAAAAAAAAAAAAAAUCA0yAIAAAAAAAAAAAAAAACAAqFBFgAAAAAAAAAAAAAAAAAUCA2yAAAAAAAAAAAAAAAAAKBAaJAFACg5Zra7mQ03sxlmttTMXOxT32VD41aMbdHMhoXyHVCofJEfM5sZWi/d62ieFaF5VtbFPAEkZ2aDQ/vj4PouD4DiM7NmZnammb1qZj+a2arQcWBYfZcPjZeZDSjGtkicXXrqKx4gJgUAoGEys3XN7Goze9fM5pvZWs7pKAXFqF9y7bw01dc11Pq4f5MrGmSh3gUH0LvNbJKZ/RRc9FxuZnOCYY+Y2SVm1tvMLMu8Lcj/72Y21sy+MLMFwTzmmtnnZvakmf3FzHaNmGf4JJL4WRUswwwze8fM7jOzs8xs69x+nUjlqUhTniif7sUqWzkys+4Zfs81ZjbPzD4xs4fN7Ddm1qK+y92QmNlASeMlnSppC0lt6rdEhZNQMcn2M7O+y4/GKcVxb46ZNcsij6Zm9j3nIAC5MrOHEo4fV9R3mRq7DHFRpk9lfZe/oYlQj1xuZj+Y2ZtmdpOZ7VLfZW5IzKylpFcl/VvSQZLWl9S8XgtVQBZ/oTLbz+D6Lj8apxTHvVuyzKMf5yCgtJhZZYbzzmIz+8bMXjazP5vZJlnkXa/HDSvifQ4UBnWiulXM/R2SmfWQ9KGkwZJ2l9RJZXLf3zLfh8v0qajvZUDjlOK4d2SWeQzhHFReIt9IAwrNzLaVdL+kXyQZ3VxSK0ldJO0m6cRg+FRJO0TM/wRJV0naPkWSdYNPT0lHS7rWzL6UdIuk/zjnVkZbklrlXi/4bCFpj1B53pB0l3NuRA75ouFoKqlz8NlW0kmSZprZqc65N+qyIOZbnT8Q/DvcOTegLuefCzPrJul21QQOX0p6V9K8eisUgGS6SDpM0nMR0x8iacPiFQdAOTOz9vL19bDTJP2rHooDlKpWwWcDSXtLGmRmj0s62zk3vy4LEty07Bv8u79zrrIu55+jyyRVhP4fK2mGpBXB/+/UdYEAJHWimf3BObcmYvrTiloaAMXQLvhsKn8tYbCZ/UPS35xzufToV/TjRrHvcwBlrND7e87ManoMdc41lAaT90mKNWJbLuk1Sd9JWhsM+7Q+CgWgllMlPRMloZk1lXRycYuDukaDLNQL80/rvi7fYjvmR0mTJP0gyck3ltpB0paSYhWgcPpUebeWNFQ1wU3MMkkTg/wXBnmtLx8ItQ/S9JB0h6RfSorSYnWapP+FZy+pQ5D3VvKNvWJl31fSvmZ2qqTfOed+jJB/Lu7KMv2iopSi8XhQ0u
LQ/81UcxOkSzCsu6RXzOwA5xwX8tM7UTXnptGS+mVxwaShmS3pqSzS/1ysggA5OlXRG2SdWsyCACh7x6t2j5nbmlkf59zE+igQakmMizKZXqyCNBLJ6pFt5B8K2lM1PTsdL6lrEIesENI5JfT9NOfcg/VWkuL7n/w+G9WEYhUEyMEG8jdsX8iU0Mw6SfpVsQsEIC8TVfs801HSTpJ6Bf83l+/9pZOkS3KYR1GPG8W8z4Gio05Ut+pif280zGwj+Z59JWmlpJ2cc+UcZyfeh8vku2IVBMjBEWbWOeLDcr+UtFGxC4S6RYMs1Dkzay7pEdUEHbMlnSfpWedcVZL0XeQbR50i32AqXd4t5F8zsHdo8ARJf5P0qnNuVZJpmsk/vXKGfG9GLSS1jbg47zrnzk9TnnXkL4JfIin22sLDJb1jZrs7536KOJ/I0pUHRXG1c25m4sBgW7xE0vXygXZrSfeZ2c51/XRHAxN+deiDZdwYS5Kms7+igfpE0naSfmVmnZxzC9IlNrOOqmnkHJsWALIRfkp9uXy9KjacBlmlIW1chIJLWY80s03lL1ZXBIP2lI+3b6qbojU8ZtZGNfH6KkkP1WNx6sJDzrlh9V0IIEvhOOJURWhYIekE+V5pEqcHUDpedM4NTjbCzPaS9KikzYJBF5vZw865SRHzLvpxo5j3OVAnqBPVrWLu741R+BX1b5R5YywpxX04oMTF6hItJP1G0j0Rpgk/2E4MUybK4l2yaHCOkrRN8H25/CsMnk4WpEiSc+4n59z/Oef6Kv4VAsncrvjGWP9wzu3hnHshWWOsIP81zrk3nXOnS9pc0pNZLEtazrl5zrn75Fv43xEa1V3SU0FjMJQh59wq59y/JN0aGryj/A0RpNY59P37eisFgHT+G/xtKenXEdKHL2iWc28TAIrAzDaX72lW8k+XXxYafWLQCB5AwDn3rXzPDt+GBp9dT8VpKMIxyI+prk0AqFdTJH0YfO8fPPSRSaxB92r5m7wAGhDn3NvyjZfCD7aelUUWdXHcOErFu88BNBoF2N8bI+6jAKXvMfk6hRThDSJm1kG+biFJH8jXZVAGaJCF+nBw6PszzrnPo07onPsi1Tgz66v4C823Oef+kk3BnHOznXPHSro8m+ki5LvaOXehpLtDg/eW75EL5e22hP/3q5dSNBzNQ9+5EQKUpkckxXqvi/Iqwlia1cG0AJCNU1XzWo+xkv4tKdbL7DqSjqiPQgGlzDm3RNL/hQb1NLMN66s8DQAxCNAwDA/+tpJ/6CMlM9tC0l7Bvy9KmlvEcgEoEufcB5IqQ4Oyva5a7ONGUe5zAI1RAfb3xoYYBih9P0l6Kfj+CzPrmSH98ap5K8DwdAnRsNAgC/Vhk9D3rwuY759D37+S9MdcM3LOTc6/OEkNki9bzJ/MrOT2QzMbYGYu+AwLhjU1s9+Y2TNm9qWZLQ/GHxWMrwhNUxnK63Aze9TMppvZkmD8xUnmaWZ2fJD2iyDtkuD7I2Z2nJlZ4nRJ8qkMlaMiGLaRmf3ZzCaY2Q9mttbMFhTgp8rIOfe1pPB7gTdOl97M1jez081suJm9b2bzzGy1mS0ws2lm9oCZHZIhj2Fm5iQ9EBp8Wuh3CX8qM+R1oJnda2ZTg7KsNLPZZvaKmZ1vZq3TTR9FrLxBmfuGRo1JUt6KFHmsZ2Z/NLOxZvZ9UM65wW84xMwydutpZt1D85kZGr6Pmf1f8PsvDMbfmt9SF0749zOzAcGwNmZ2rpm9aWY/Br/Ht8H+tXeGLGP5mpkdFex/n5nZomDfWWpmM83sdTP7l5ntbxGOY2a2qZn91czeCLahlcE29b6Z3WhmW+W4rJ3M7FIze8fM5pjZKvPHqLvNv7InMY91g21lgpn9ZGbLzOxTM/unmXWuNdNov1WfYBv5PPh95gX5/8n8Uw0FVxf7ZhpzJL0cfN/L/EXLVOXcXDU9V76smkYUkRVi/07Ir6WZXRBsiz+ZP5/FzjX7Z1u+UL5tzWygmT
1nZl8H29Zi8+e/+83sgFzzTjG/bczshmDbnxts+yuC/eA98+eL03LdroFSYGam+Iaf/w1eZ/xYaNhpylJw7nok2Fd
"text/plain": [
2021-04-30 20:51:04 +01:00
"<Figure size 2400x600 with 3 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig, axes = plt.subplots(1, 3, figsize=(12, 3))\n",
"fig.set_dpi(fig_dpi)\n",
"\n",
"# (param row of the accuracy tensor, legend label); error = 1 - accuracy\n",
"error_series = [(0, 'Ensemble Test'),\n",
"                (2, 'Individual Test'),\n",
"                (1, 'Individual Train'),\n",
"                (3, 'Disagreement')]\n",
"\n",
"# One panel per optimiser; zip ignores optimisers beyond the panel count.\n",
"# Each populated panel gets its own legend, so empty panels get none.\n",
"for ax, (optimiser_name, tensors_dict) in zip(axes.flatten(), optim_tensors.items()):\n",
"    for row, label in error_series:\n",
"        ax.plot(multi_optim_models, 1 - tensors_dict[\"mean\"][row, :], 'x-', label=label)\n",
"\n",
"    ax.set_title(f\"{optimiser_name} Error Rate for Ensemble Models\")\n",
"    ax.set_ylim(0, 0.1)\n",
"    ax.grid()\n",
"    ax.legend()\n",
"    ax.set_xlabel(\"Number of Models\")\n",
"    ax.set_ylabel(\"Error Rate\")\n",
"\n",
"fig.tight_layout()\n",
"fig.savefig(f'graphs/{exp3_testname}-error-rate-curves.png')\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
2021-03-19 17:21:00 +00:00
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"authorship_tag": "ABX9TyNAMGLKzaoWaq1wvQ+w0w8h",
"collapsed_sections": [],
"name": "nncw.ipynb",
"provenance": [],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
2021-04-30 20:51:04 +01:00
"version": "3.8.9"
},
"toc-autonumbering": false,
2021-03-27 16:29:31 +00:00
"toc-showcode": false,
2021-03-22 20:49:29 +00:00
"toc-showmarkdowntxt": false,
2021-03-27 16:29:31 +00:00
"toc-showtags": false
2021-03-19 17:21:00 +00:00
},
"nbformat": 4,
"nbformat_minor": 4
}