shallow-training/nncw.ipynb

2587 lines
512 KiB
Plaintext
Raw Normal View History

2021-03-19 17:21:00 +00:00
{
"cells": [
{
"cell_type": "code",
2021-04-09 12:42:18 +01:00
"execution_count": 2,
2021-03-19 17:21:00 +00:00
"metadata": {
"executionInfo": {
"elapsed": 2450,
"status": "ok",
"timestamp": 1615991459232,
"user": {
"displayName": "Andy Pack",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjA4K4ZhdArHXAFbAGr4n0aCv2HmyUpx4cy6zcUq34=s64",
"userId": "16615063155528027547"
},
"user_tz": 0
},
"id": "TGIxH9Tmt5zp"
},
2021-03-26 20:01:05 +00:00
"outputs": [],
2021-03-19 17:21:00 +00:00
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import tensorflow as tf\n",
"import tensorflow.keras.optimizers as tf_optim\n",
2021-03-19 17:21:00 +00:00
"tf.get_logger().setLevel('ERROR')\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib as mpl\n",
"import seaborn as sns\n",
"import random\n",
"import pickle\n",
"import json\n",
2021-03-22 20:49:29 +00:00
"import math\n",
"import datetime\n",
"import os\n",
2021-03-19 17:21:00 +00:00
"\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
2021-03-26 20:01:05 +00:00
"fig_dpi = 70"
2021-03-19 17:21:00 +00:00
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fksHv5rXACEX"
},
"source": [
"# Neural Network Training\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "l4zqVWyRAM0Z"
},
"source": [
"## Load Dataset\n",
"\n",
"Read the CSVs exported from MATLAB and parse them into pandas DataFrames"
]
},
{
"cell_type": "code",
2021-04-09 12:42:18 +01:00
"execution_count": 3,
2021-03-19 17:21:00 +00:00
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 331
},
"executionInfo": {
"elapsed": 2441,
"status": "ok",
"timestamp": 1615991459234,
"user": {
"displayName": "Andy Pack",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjA4K4ZhdArHXAFbAGr4n0aCv2HmyUpx4cy6zcUq34=s64",
"userId": "16615063155528027547"
},
"user_tz": 0
},
"id": "Hj5l_tdZuYh7",
"outputId": "fbfa9406-f662-4ebc-8ba2-67950714627c"
},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Clump thickness</th>\n",
" <th>Uniformity of cell size</th>\n",
" <th>Uniformity of cell shape</th>\n",
" <th>Marginal adhesion</th>\n",
" <th>Single epithelial cell size</th>\n",
" <th>Bare nuclei</th>\n",
" <th>Bland chomatin</th>\n",
" <th>Normal nucleoli</th>\n",
" <th>Mitoses</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>count</th>\n",
" <td>699.000000</td>\n",
" <td>699.000000</td>\n",
" <td>699.000000</td>\n",
" <td>699.000000</td>\n",
" <td>699.000000</td>\n",
" <td>699.000000</td>\n",
" <td>699.000000</td>\n",
" <td>699.000000</td>\n",
" <td>699.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>mean</th>\n",
" <td>0.441774</td>\n",
" <td>0.313448</td>\n",
" <td>0.320744</td>\n",
" <td>0.280687</td>\n",
" <td>0.321602</td>\n",
" <td>0.354363</td>\n",
" <td>0.343777</td>\n",
" <td>0.286695</td>\n",
" <td>0.158941</td>\n",
" </tr>\n",
" <tr>\n",
" <th>std</th>\n",
" <td>0.281574</td>\n",
" <td>0.305146</td>\n",
" <td>0.297191</td>\n",
" <td>0.285538</td>\n",
" <td>0.221430</td>\n",
" <td>0.360186</td>\n",
" <td>0.243836</td>\n",
" <td>0.305363</td>\n",
" <td>0.171508</td>\n",
" </tr>\n",
" <tr>\n",
" <th>min</th>\n",
" <td>0.100000</td>\n",
" <td>0.100000</td>\n",
" <td>0.100000</td>\n",
" <td>0.100000</td>\n",
" <td>0.100000</td>\n",
" <td>0.100000</td>\n",
" <td>0.100000</td>\n",
" <td>0.100000</td>\n",
" <td>0.100000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>25%</th>\n",
" <td>0.200000</td>\n",
" <td>0.100000</td>\n",
" <td>0.100000</td>\n",
" <td>0.100000</td>\n",
" <td>0.200000</td>\n",
" <td>0.100000</td>\n",
" <td>0.200000</td>\n",
" <td>0.100000</td>\n",
" <td>0.100000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>50%</th>\n",
" <td>0.400000</td>\n",
" <td>0.100000</td>\n",
" <td>0.100000</td>\n",
" <td>0.100000</td>\n",
" <td>0.200000</td>\n",
" <td>0.100000</td>\n",
" <td>0.300000</td>\n",
" <td>0.100000</td>\n",
" <td>0.100000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>75%</th>\n",
" <td>0.600000</td>\n",
" <td>0.500000</td>\n",
" <td>0.500000</td>\n",
" <td>0.400000</td>\n",
" <td>0.400000</td>\n",
" <td>0.500000</td>\n",
" <td>0.500000</td>\n",
" <td>0.400000</td>\n",
" <td>0.100000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>max</th>\n",
" <td>1.000000</td>\n",
" <td>1.000000</td>\n",
" <td>1.000000</td>\n",
" <td>1.000000</td>\n",
" <td>1.000000</td>\n",
" <td>1.000000</td>\n",
" <td>1.000000</td>\n",
" <td>1.000000</td>\n",
" <td>1.000000</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Clump thickness Uniformity of cell size Uniformity of cell shape \\\n",
"count 699.000000 699.000000 699.000000 \n",
"mean 0.441774 0.313448 0.320744 \n",
"std 0.281574 0.305146 0.297191 \n",
"min 0.100000 0.100000 0.100000 \n",
"25% 0.200000 0.100000 0.100000 \n",
"50% 0.400000 0.100000 0.100000 \n",
"75% 0.600000 0.500000 0.500000 \n",
"max 1.000000 1.000000 1.000000 \n",
"\n",
" Marginal adhesion Single epithelial cell size Bare nuclei \\\n",
"count 699.000000 699.000000 699.000000 \n",
"mean 0.280687 0.321602 0.354363 \n",
"std 0.285538 0.221430 0.360186 \n",
"min 0.100000 0.100000 0.100000 \n",
"25% 0.100000 0.200000 0.100000 \n",
"50% 0.100000 0.200000 0.100000 \n",
"75% 0.400000 0.400000 0.500000 \n",
"max 1.000000 1.000000 1.000000 \n",
"\n",
" Bland chomatin Normal nucleoli Mitoses \n",
"count 699.000000 699.000000 699.000000 \n",
"mean 0.343777 0.286695 0.158941 \n",
"std 0.243836 0.305363 0.171508 \n",
"min 0.100000 0.100000 0.100000 \n",
"25% 0.200000 0.100000 0.100000 \n",
"50% 0.300000 0.100000 0.100000 \n",
"75% 0.500000 0.400000 0.100000 \n",
"max 1.000000 1.000000 1.000000 "
]
},
2021-04-09 12:42:18 +01:00
"execution_count": 3,
2021-03-19 17:21:00 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data = pd.read_csv('features.csv', header=None).T\n",
"data.columns = ['Clump thickness', 'Uniformity of cell size', 'Uniformity of cell shape', 'Marginal adhesion', 'Single epithelial cell size', 'Bare nuclei', 'Bland chomatin', 'Normal nucleoli', 'Mitoses']\n",
"labels = pd.read_csv('targets.csv', header=None).T\n",
"labels.columns = ['Benign', 'Malignant']\n",
"data.describe()"
]
},
{
"cell_type": "code",
2021-03-26 20:01:05 +00:00
"execution_count": 31,
2021-03-19 17:21:00 +00:00
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 204
},
"executionInfo": {
"elapsed": 2436,
"status": "ok",
"timestamp": 1615991459236,
"user": {
"displayName": "Andy Pack",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjA4K4ZhdArHXAFbAGr4n0aCv2HmyUpx4cy6zcUq34=s64",
"userId": "16615063155528027547"
},
"user_tz": 0
},
"id": "qc1Mku6h041u",
"outputId": "94e38c34-0419-4a02-ac8c-17bbc83f777b"
},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Benign</th>\n",
" <th>Malignant</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Benign Malignant\n",
"0 1 0\n",
"1 1 0\n",
"2 1 0\n",
"3 0 1\n",
"4 1 0"
]
},
2021-03-26 20:01:05 +00:00
"execution_count": 31,
2021-03-19 17:21:00 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"labels.head()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "h9QsJjWEMbLu"
},
"source": [
"### Explore Dataset\n",
"\n",
"The classes are imbalanced (458 benign vs 241 malignant), so stratify by label when splitting later on"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"executionInfo": {
"elapsed": 2430,
"status": "ok",
"timestamp": 1615991459237,
"user": {
"displayName": "Andy Pack",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjA4K4ZhdArHXAFbAGr4n0aCv2HmyUpx4cy6zcUq34=s64",
"userId": "16615063155528027547"
},
"user_tz": 0
},
"id": "rjjiSYAZMa4k",
"outputId": "ae0c3772-00be-4f2b-80d2-9cd62a6b6e08"
},
"outputs": [
{
"data": {
"text/plain": [
"Benign 458\n",
"Malignant 241\n",
"dtype: int64"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"labels.astype(bool).sum(axis=0)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "E9lVYI14AUMf"
},
"source": [
"## Split Dataset\n",
"\n",
"Using a stratified 50/50 train/test split"
]
},
{
"cell_type": "code",
2021-04-09 12:42:18 +01:00
"execution_count": 4,
2021-03-19 17:21:00 +00:00
"metadata": {
"executionInfo": {
"elapsed": 2604,
"status": "ok",
"timestamp": 1615991459418,
"user": {
"displayName": "Andy Pack",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjA4K4ZhdArHXAFbAGr4n0aCv2HmyUpx4cy6zcUq34=s64",
"userId": "16615063155528027547"
},
"user_tz": 0
},
"id": "L83Ae5l9wM35"
},
"outputs": [],
"source": [
"data_train, data_test, labels_train, labels_test = train_test_split(data, labels, test_size=0.5, stratify=labels)"
2021-03-19 17:21:00 +00:00
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Qf2U199fNjmJ"
},
"source": [
"## Generate & Retrieve Model\n",
"\n",
"Build a shallow model with a single hidden layer whose number of nodes can be varied"
]
},
{
"cell_type": "code",
2021-04-09 12:42:18 +01:00
"execution_count": 5,
2021-03-19 17:21:00 +00:00
"metadata": {
"executionInfo": {
"elapsed": 2598,
"status": "ok",
"timestamp": 1615991459419,
"user": {
"displayName": "Andy Pack",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjA4K4ZhdArHXAFbAGr4n0aCv2HmyUpx4cy6zcUq34=s64",
"userId": "16615063155528027547"
},
"user_tz": 0
},
"id": "SgoQ-NjWB0T5"
},
"outputs": [],
"source": [
"def get_model(hidden_nodes=9, activation=lambda: 'sigmoid', weight_init=lambda: 'glorot_uniform', input_dim=9, output_dim=2):\n",
"    \"\"\"Build an uncompiled shallow classifier with a single hidden layer.\n",
"\n",
"    Args:\n",
"        hidden_nodes: number of units in the hidden Dense layer.\n",
"        activation: zero-argument callable returning the hidden-layer activation.\n",
"        weight_init: zero-argument callable returning a kernel initializer;\n",
"            it is called once for each Dense layer.\n",
"        input_dim: number of input features (9 for this dataset).\n",
"        output_dim: number of output classes; the output layer uses softmax.\n",
"\n",
"    Returns:\n",
"        An uncompiled tf.keras Sequential model (Input -> Hidden -> Output).\n",
"    \"\"\"\n",
"    layers = [\n",
"        tf.keras.layers.InputLayer(input_shape=(input_dim,), name='Input'),\n",
"        tf.keras.layers.Dense(hidden_nodes, activation=activation(), kernel_initializer=weight_init(), name='Hidden'),\n",
"        tf.keras.layers.Dense(output_dim, activation='softmax', kernel_initializer=weight_init(), name='Output'),\n",
"    ]\n",
"    return tf.keras.models.Sequential(layers)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Get a Keras TensorBoard callback that writes logs for later analysis"
]
},
{
"cell_type": "code",
2021-04-09 12:42:18 +01:00
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"def tensorboard_callback(path='tensorboard-logs', prefix='', histogram_freq=1):\n",
"    \"\"\"Create a TensorBoard callback that logs to a timestamped run directory.\n",
"\n",
"    The log directory is `path/<prefix><YYYYmmdd-HHMMSS>`. Two calls with the same\n",
"    prefix within the same second will share a directory.\n",
"    `histogram_freq` is forwarded unchanged to tf.keras.callbacks.TensorBoard.\n",
"    \"\"\"\n",
"    run_name = prefix + datetime.datetime.now().strftime(\"%Y%m%d-%H%M%S\")\n",
"    log_dir = os.path.normpath(os.path.join(path, run_name))\n",
"    return tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=histogram_freq)"
]
},
2021-03-19 17:21:00 +00:00
{
"cell_type": "markdown",
"metadata": {
"id": "QT5B9PTUN3pj"
},
"source": [
"# Example Training"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mQGAUtIPAd6e"
},
"source": [
"## Define Model\n",
"\n",
"Variable number of hidden nodes (9 in this example). The input is 9-dimensional (one value per feature) and the output layer is 2D (softmax over Benign/Malignant) for binary classification"
]
},
{
"cell_type": "code",
2021-03-26 20:01:05 +00:00
"execution_count": 60,
2021-03-19 17:21:00 +00:00
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"executionInfo": {
"elapsed": 7889,
"status": "ok",
"timestamp": 1615991464716,
"user": {
"displayName": "Andy Pack",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjA4K4ZhdArHXAFbAGr4n0aCv2HmyUpx4cy6zcUq34=s64",
"userId": "16615063155528027547"
},
"user_tz": 0
},
"id": "fYA34P0Vu_pX",
"outputId": "aded18b8-aa7f-4362-a614-837c8a0f526f"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2021-03-26 20:01:05 +00:00
"Model: \"sequential_1\"\n",
2021-03-19 17:21:00 +00:00
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
2021-03-26 20:01:05 +00:00
"dense_2 (Dense) (None, 9) 90 \n",
2021-03-19 17:21:00 +00:00
"_________________________________________________________________\n",
2021-03-26 20:01:05 +00:00
"dense_3 (Dense) (None, 2) 20 \n",
2021-03-19 17:21:00 +00:00
"=================================================================\n",
"Total params: 110\n",
"Trainable params: 110\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"model = get_model(9)\n",
"model.compile('sgd', loss='categorical_crossentropy', metrics=['accuracy'])\n",
"model.summary()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KZSwFe-AAs1y"
},
"source": [
"## Train Model\n",
"\n",
"Example: 5 epochs"
]
},
{
"cell_type": "code",
2021-03-26 20:01:05 +00:00
"execution_count": 61,
2021-03-19 17:21:00 +00:00
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"executionInfo": {
"elapsed": 11304,
"status": "ok",
"timestamp": 1615991468137,
"user": {
"displayName": "Andy Pack",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjA4K4ZhdArHXAFbAGr4n0aCv2HmyUpx4cy6zcUq34=s64",
"userId": "16615063155528027547"
},
"user_tz": 0
},
"id": "s8U9Atu3zelS",
"outputId": "8439e8d2-7a5d-495f-a192-a34f01e95bfa"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/5\n",
2021-03-26 20:01:05 +00:00
"11/11 [==============================] - 1s 2ms/step - loss: 0.6257 - accuracy: 0.6607\n",
"Epoch 2/5\n",
2021-03-26 20:01:05 +00:00
"11/11 [==============================] - 0s 3ms/step - loss: 0.6226 - accuracy: 0.6651\n",
"Epoch 3/5\n",
2021-03-26 20:01:05 +00:00
"11/11 [==============================] - 0s 2ms/step - loss: 0.6326 - accuracy: 0.6424\n",
"Epoch 4/5\n",
2021-03-26 20:01:05 +00:00
"11/11 [==============================] - 0s 3ms/step - loss: 0.6158 - accuracy: 0.6696\n",
"Epoch 5/5\n",
2021-03-26 20:01:05 +00:00
"11/11 [==============================] - 0s 2ms/step - loss: 0.6228 - accuracy: 0.6534\n"
2021-03-19 17:21:00 +00:00
]
},
{
"data": {
"text/plain": [
2021-03-26 20:01:05 +00:00
"<tensorflow.python.keras.callbacks.History at 0x2cd249f3400>"
2021-03-19 17:21:00 +00:00
]
},
2021-03-26 20:01:05 +00:00
"execution_count": 61,
2021-03-19 17:21:00 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
2021-03-26 20:01:05 +00:00
"model.fit(data_train.to_numpy(), labels_train.to_numpy(), epochs=5)"
2021-03-19 17:21:00 +00:00
]
},
{
"cell_type": "code",
2021-03-26 20:01:05 +00:00
"execution_count": 62,
2021-03-19 17:21:00 +00:00
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"executionInfo": {
"elapsed": 11294,
"status": "ok",
"timestamp": 1615991468137,
"user": {
"displayName": "Andy Pack",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjA4K4ZhdArHXAFbAGr4n0aCv2HmyUpx4cy6zcUq34=s64",
"userId": "16615063155528027547"
},
"user_tz": 0
},
"id": "VnUEJdXovzi-",
"outputId": "02075086-352c-4a23-fac5-ad54d11e0e35"
},
"outputs": [
{
"data": {
"text/plain": [
"['loss', 'accuracy']"
]
},
2021-03-26 20:01:05 +00:00
"execution_count": 62,
2021-03-19 17:21:00 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.metrics_names"
]
},
{
"cell_type": "code",
2021-03-26 20:01:05 +00:00
"execution_count": 63,
2021-03-19 17:21:00 +00:00
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"executionInfo": {
"elapsed": 11285,
"status": "ok",
"timestamp": 1615991468138,
"user": {
"displayName": "Andy Pack",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjA4K4ZhdArHXAFbAGr4n0aCv2HmyUpx4cy6zcUq34=s64",
"userId": "16615063155528027547"
},
"user_tz": 0
},
"id": "r0vxP3Ah42ib",
"outputId": "061113ba-52db-4fbe-c7f9-b5d3d85438ed"
},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor: shape=(), dtype=float32, numpy=0.6561605>"
]
},
2021-03-26 20:01:05 +00:00
"execution_count": 63,
2021-03-19 17:21:00 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.metrics[1].result()"
]
},
{
"cell_type": "markdown",
"metadata": {
2021-03-26 20:01:05 +00:00
"id": "z7bn8pKTAynt",
"tags": [
"exp1"
]
2021-03-19 17:21:00 +00:00
},
"source": [
"# Experiment 1\n",
"\n",
"The function below runs one iteration of the hidden-node/epoch investigation.\n",
"For each node/epoch combination it yields the node and epoch counts used, the training history and test results, and optionally the trained model.\n",
"\n",
"Using cancer dataset (as in E2) and 'trainscg' or an optimiser of your choice, vary nodes and epochs (that is using early stopping for epochs) over suitable range, to find optimal choice in terms of classification test error rate of node/epochs for 50/50% random train/test split (no validation set). It is suggested that you initially try epochs = [ 1 2 4 8 16 32 64], nodes = [2 8 32], so there would be 21 node/epoch combinations. \n",
2021-03-19 17:21:00 +00:00
"\n",
"(Hint 1: in the 'advanced script' from E2, the number of nodes can be changed to xx with hiddenLayerSize = xx; the number of epochs can be changed to xx by adding net.trainParam.epochs = xx; placed after net = patternnet(hiddenLayerSize, trainFcn); see the 'trainscg' help documentation for changing epochs.)\n",
2021-03-19 17:21:00 +00:00
"\n",
"Repeat each of the 21 node/epoch combinations at least thirty times, with different 50/50 split and take average and report classification error rate and standard deviation (std). Graph classification train and test error rate and std as node-epoch changes, that is plot error rate vs epochs for different number of nodes. Report the optimal value for test error rate and associated node/epoch values. \n",
"\n",
"(Hint2: as epochs increases you can expect the test error rate to reach a minimum and then start increasing, you may need to set the stopping criteria to achieve the desired number of epochs - Hint 3: to find classification error rates for train and test set, you need to check the code from E2, to determine how you may obtain the train and test set patterns)\n"
2021-03-19 17:21:00 +00:00
]
},
{
"cell_type": "code",
"execution_count": 16,
2021-03-19 17:21:00 +00:00
"metadata": {
"executionInfo": {
"elapsed": 11274,
"status": "ok",
"timestamp": 1615991468138,
"user": {
"displayName": "Andy Pack",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjA4K4ZhdArHXAFbAGr4n0aCv2HmyUpx4cy6zcUq34=s64",
"userId": "16615063155528027547"
},
"user_tz": 0
},
2021-03-26 20:01:05 +00:00
"id": "mYWhCSW4A57V",
"tags": [
"exp1",
"exp-func"
]
2021-03-19 17:21:00 +00:00
},
"outputs": [],
"source": [
2021-03-26 20:01:05 +00:00
"# Full sweep (slow):\n",
"# hidden_nodes = [2, 8, 16, 24, 32]\n",
"# epochs = [1, 2, 4, 8, 16, 32, 64, 100, 150, 200]\n",
"hidden_nodes = [2, 8, 16]\n",
"epochs = [1, 2, 4, 8]\n",
"\n",
"def evaluate_parameters(hidden_nodes=hidden_nodes,\n",
"                        epochs=epochs,\n",
"                        batch_size=128,\n",
"                        optimizer=lambda: 'sgd',\n",
"                        loss=lambda: 'categorical_crossentropy',\n",
"                        metrics=['accuracy'],\n",
"                        callbacks=None,\n",
"                        validation_split=None,\n",
"\n",
"                        verbose=0,\n",
"                        print_params=True,\n",
"                        return_model=True,\n",
"                        run_eagerly=False,\n",
"                        tboard=True,\n",
"\n",
"                        dtrain=data_train,\n",
"                        dtest=data_test,\n",
"                        ltrain=labels_train,\n",
"                        ltest=labels_test):\n",
"    \"\"\"Train and test a fresh model for every (hidden nodes, epochs) combination.\n",
"\n",
"    Generator: yields one result dict per combination, iterating hidden_nodes in\n",
"    the outer loop and epochs in the inner loop.\n",
"\n",
"    Args:\n",
"        hidden_nodes, epochs: values to sweep.\n",
"        batch_size: batch size for model.evaluate on the test set\n",
"            (model.fit uses the Keras default).\n",
"        optimizer, loss: zero-argument factories, called once per model so each\n",
"            model gets its own instance.\n",
"        metrics: passed to model.compile.\n",
"        callbacks: optional list of zero-argument callback factories (not instances).\n",
"        validation_split: passed to model.fit.\n",
"        verbose: Keras verbosity for fit and evaluate.\n",
"        print_params: print each combination before training it.\n",
"        return_model: include the trained model under the 'model' key.\n",
"        run_eagerly: passed to model.compile.\n",
"        tboard: also attach a TensorBoard callback (prefix 'exp1-<nodes>-<epochs>-').\n",
"        dtrain, dtest, ltrain, ltest: train/test features and labels (DataFrames).\n",
"\n",
"    Yields:\n",
"        dict with keys 'nodes', 'epochs', 'history' (fit history dict), 'results'\n",
"        (model.evaluate output), 'optimizer' (optimizer config), 'loss',\n",
"        'model_config' and, if return_model, 'model'.\n",
"    \"\"\"\n",
"    for hn in hidden_nodes:\n",
"        for e in epochs:\n",
"            if print_params:\n",
"                print(f\"Nodes: {hn}, Epochs: {e}\")\n",
"\n",
"            model = get_model(hn)\n",
"            model.compile(\n",
"                optimizer=optimizer(),\n",
"                loss=loss(),\n",
"                metrics=metrics,\n",
"                run_eagerly=run_eagerly\n",
"            )\n",
"\n",
"            # Always build the callback list: previously `cb` was only bound when\n",
"            # tboard=True, so tboard=False raised NameError and ignored `callbacks`.\n",
"            cb = [make_callback() for make_callback in callbacks] if callbacks is not None else []\n",
"            if tboard:\n",
"                cb.append(tensorboard_callback(prefix=f'exp1-{hn}-{e}-'))\n",
"\n",
"            # TRAIN\n",
"            history = model.fit(dtrain.to_numpy(),\n",
"                                ltrain.to_numpy(),\n",
"                                epochs=e,\n",
"                                verbose=verbose,\n",
"                                callbacks=cb,\n",
"                                validation_split=validation_split).history\n",
"            # TEST\n",
"            results = model.evaluate(dtest.to_numpy(),\n",
"                                     ltest.to_numpy(),\n",
"                                     callbacks=cb,\n",
"                                     batch_size=batch_size,\n",
"                                     verbose=verbose)\n",
"\n",
"            response = {\"nodes\": hn,\n",
"                        \"epochs\": e,\n",
"                        \"history\": history,\n",
"                        \"results\": results,\n",
"                        \"optimizer\": model.optimizer.get_config(),\n",
"                        \"loss\": model.loss,\n",
"                        \"model_config\": json.loads(model.to_json())\n",
"                        }\n",
"\n",
"            if return_model:\n",
"                response[\"model\"] = model\n",
"\n",
"            yield response"
]
},
{
"cell_type": "markdown",
"metadata": {
2021-03-26 20:01:05 +00:00
"id": "r-63V9qb-i4w",
"tags": [
"exp1"
]
2021-03-19 17:21:00 +00:00
},
"source": [
"## Single Iteration\n",
"Run a single iteration of the hidden-node/epoch investigation"
]
},
{
"cell_type": "code",
"execution_count": 17,
2021-03-19 17:21:00 +00:00
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"executionInfo": {
"elapsed": 313592,
"status": "ok",
"timestamp": 1615991770468,
"user": {
"displayName": "Andy Pack",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjA4K4ZhdArHXAFbAGr4n0aCv2HmyUpx4cy6zcUq34=s64",
"userId": "16615063155528027547"
},
"user_tz": 0
},
"id": "ZmGFkE9y8E4H",
2021-03-26 20:01:05 +00:00
"outputId": "243fb136-bc07-4438-afb7-f2d21758168d",
"tags": [
"exp1"
]
2021-03-19 17:21:00 +00:00
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Nodes: 2, Epochs: 1\n",
"Nodes: 2, Epochs: 2\n",
"Nodes: 2, Epochs: 4\n",
"Nodes: 2, Epochs: 8\n",
"Nodes: 8, Epochs: 1\n",
"Nodes: 8, Epochs: 2\n",
"Nodes: 8, Epochs: 4\n",
"Nodes: 8, Epochs: 8\n",
"Nodes: 16, Epochs: 1\n",
"Nodes: 16, Epochs: 2\n",
"Nodes: 16, Epochs: 4\n",
2021-03-26 20:01:05 +00:00
"Nodes: 16, Epochs: 8\n"
2021-03-19 17:21:00 +00:00
]
}
],
"source": [
2021-03-26 20:01:05 +00:00
"# Early stopping (disabled). `callbacks` takes zero-argument factories, so wrap it in a lambda:\n",
"# es = lambda: tf.keras.callbacks.EarlyStopping(monitor='val_loss', mode='min', patience=5)\n",
"single_results = list(evaluate_parameters(\n",
"    return_model=False,\n",
"    validation_split=0.2,\n",
"    optimizer=lambda: tf.keras.optimizers.SGD(learning_rate=0.5, momentum=0.5),\n",
"    # callbacks=[es],\n",
"))"
2021-03-19 17:21:00 +00:00
]
},
{
"cell_type": "markdown",
"metadata": {
2021-03-26 20:01:05 +00:00
"id": "mdWK3-M6SK8_",
"tags": [
"exp1"
]
2021-03-19 17:21:00 +00:00
},
"source": [
"### Train/Test Curves\n",
"\n",
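"Training vs validation loss and accuracy for one randomly chosen node/epoch combination (with more than one epoch) from the run above"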
2021-03-19 17:21:00 +00:00
]
},
{
"cell_type": "code",
2021-03-26 20:01:05 +00:00
"execution_count": 68,
2021-03-19 17:21:00 +00:00
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 517
},
"executionInfo": {
"elapsed": 314527,
"status": "ok",
"timestamp": 1615991771417,
"user": {
"displayName": "Andy Pack",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjA4K4ZhdArHXAFbAGr4n0aCv2HmyUpx4cy6zcUq34=s64",
"userId": "16615063155528027547"
},
"user_tz": 0
},
"id": "F9Xre1o6SesD",
2021-03-26 20:01:05 +00:00
"outputId": "d6b817aa-58cd-4510-807f-e5e6bcf62f18",
"tags": [
"exp1"
]
2021-03-19 17:21:00 +00:00
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2021-03-26 20:01:05 +00:00
"Nodes: 2, Epochs: 4\n"
2021-03-19 17:21:00 +00:00
]
},
{
"data": {
2021-03-26 20:01:05 +00:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA1gAAAGuCAYAAACNy6eFAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAArEAAAKxAFmbYLUAACAaUlEQVR4nOzddZxU1f/H8dfZJpalY5eU7u62SJVUsQMVMb4WtvwwsMUEJFQMREJQJERBlqW7uxtZuhd29/z+uAMsuMDCxp2ZfT8fj3kwc2s+Z2aYu585536OsdYiIiIiIiIiqRfgdgAiIiIiIiL+QgmWiIiIiIhIGlGCJSIiIiIikkaUYImIiIiIiKQRJVgiIiIiIiJpRAmWiIiIiIhIGlGCJV7BGPO1MebFFG670hhTP71jyijGmOLGmPgkjycaY+64xLbNjDEbrvF5Ghtjll5rnCIicp7OWzpviVyK0TxYkhrGmGNJHmYDTgBnP1QVrLXbMj6qjGeMaQBMAApaa08lWZ4d+Beob61ddol9iwMbrLVBKXieZsBga22pFGy7BbjHWjsjBU24ZsaYITjxv5OezyMikhZ03nJk5vNWkucbCDwAFLLW7s+I55TMQT1YkirW2uxnb0AcUDHJsm0AxuHXnzVr7SxgP9DmolXtgY2XOkmJiEjG0nnLkdnPW8aYUKATcAxItvctnZ73ikmp+D6//vIQ9xhjhhhjvjLG/IPz62BJY8xDxph1xpijxphlnl+1km7/uuf+A8aYf4wx/Y0xR4wxq4wxNZJsu8UY0yjJfl8YY6Z4jvuXMSZ3km0fMcbsMMbs8dy3xpjCycQ72Bjzfxct22SMaWSMyecZ/nDIGLPPGDPsEs0eCtx10bK7gZ+MMbmMMX969o81xgz0fLkn99pFG2Pu8dwPNMZ8bozZb4xZC9S7aNsvjTG7PLH9ZYwperY9QFHgL2PMMWPM3RcP0zDGVDTGTPfsu9AY0/Ci1/h5Y8xqz/qvLtHmSzLGhBlj+npe+23GmJ5n/2AxxtQzxiz2vL87jTHPXm65iEh603nrnMxy3moDHAfeB+65KMYSxpjxnjbsNsY87VkeZIx5yxiz1Rhz2BgT7Vn+n2GQSd83T2wvGmNWAxsu9zpc6vmNMYU9n61sSbZ70Bjz1xXaKS5QgiXp6U6gBxAObAH2ADcAOYEvgV8u9WUNNAZigFzAaODTyzzP7cCzQD4gEPgfgDGmMvAR0A4oATS4zDGGe46DZ986QDAwE3ge2AzkBaI8sSfnJ6C1MSbCc4wCwPXAzzj/1/p69q8C1AIev0w8Zz0GNAcqev7tctH6mUB5oBCwA/gCwFrbFdgG3Oz5VXZo0p2MMSHAH8BInNftQ+APY0yuJJu1w3kfKgG3G2OapyDepN7wxF0eaIRzArvPs+4z4GNrbQ7P8aOvsFxEJCPovJV5zlv3ACOAX4C6xpjrPM8TBIwH5nvaXtYTM8BLQEvPc+QGel7x1TivA9DM03a4xOtwqee31u4AFgC3JjlmF5z3SryMEixJT79aaxdaa+OttWestROstduttQnW2kE4Y95LX2LfNdbaYdbaBJwvj6qXeZ6R1tplnjHkvybZtiMw2lq7wFp7ErjcNUL/APmMMRU9j2/3HNcCZ3C+AItYa+M8wyr+w1q7DljueV5wTtQzrLU7rLX7rbV/ePbfDQzASTqupDPQx1q7x1q7i4tOktbaX6y1hz3t+yCFxwSoCwRYa7/wvDfDgbU4J46zPrPW7vN8qUdz+fcgOXcCb1prD3qG3XzC+RPtGaCUMSa3Z/3iKywXEckIOm9lgvOWMSYn0BoY7jk/zeF8L1ZdnAT7LWvtKWvtEWvtQs+6B4DXrLXbPJ+JmBTGDvC5tfZfT7sv9zpc7vl/wnMe9STDDXCSefEySrAkPe1I+sAY084Ys8jTHX4IyA/kucS+/ya5fwLIfpnnudS2BS+K4YJ4kvKcEEcBdxhjDM4JYrhn9Uc4v6pNM8
asMcY8fJlYfuL8cIu7PY8xxoQbY37wDPs4AvTh0m1PqhCwPcnjpPcxxrxmjNngOea8FB4TIPLiYwFbPcvPupr34FLPkfRi8aTH74rz6+YGY8wMc7661qWWi4hkBJ23Msd563Zgl7V2nufxLzhtBygMbLXWJiazX2GcnsFrcfFn61Kvw+WefxTQ1NNr1xn401p75BrjkXSkBEvS07kSlZ4hFcOA14A81tqcwF7ApOPz78HpXj/rP2PYL3J2uEU9INFaOxfA8+vR/6y1RXF+vfry7FCCZPwCNDLGNMUZojDKs/w5nCEN1TzD354jZW3fDRRJ8vjcfc9zdMf5FS4CqHPRvpcrEbrrouOCM/Z9VwpiSqldnmP+5/jW2rXW2ttx/lj5BeezccnlIiIZROetzHHeugeINM51bnuAXkAZzzDL7UAxT9J6se1A8WSWHweynH3g6V26WNLP1uVeh0s+vyeZmoTT49gF5xo68UJKsCSjhAIhOCcnjDH/w/niTk9jgI7GmBrGmDDg1StsPx2nW/4dnHHZABhj2hhjrvN82R3G+ZJMSO4A1tq9OMM2fgDGJfllKRzn17TDxphiOF+sKTEKeNYYU8AYUwh4Msm6cJxhIPtwSg2/ftG+e0n+RAAw19O2Jz0X7XbGGQv+ZwrjuliQcYpanL0F4Zz43zDOhdJFcE7Ov3ie925jTB5rbTxwFM/reanlIiIu0HnLD89bnrY0xLk+rJrnVhHnuqd7cHqTjuKcv8KMMTmMMTU9uw8B3jHGFDFOMY8mnuXrgFzGmKaexPyNK4Rxudfhcs8PTg/js562T7iatkvGUYIlGcLzhd0D55eXPThd4dc08eBVPOdS4GWci2K3AGfHMMddYvtEnItnryfJiQooA0zF+cIbDzxjrd16maf+CedXtZ+SLPscZ9jEQZzx9mNS2IwBOBdNr8YZT/5LknV/4lwkuxVnDP3FY+w/AN73DG25oEqUtfY0zoWyXXDK9L4C3GqtPZjCuC72f8DJJLevgbdxxsevAWZ7Yv/es31rYK0x5ijwNOeLX1xquYhIhtJ5y2/PW3cD0621sz3Xie2x1u4BvuJ8ufa2ONc37cY5j50drv4RMMUT936cni+stYdxCpWMwBlCOP8KMVzydfD8wHip5weYCBQAxlhrk/1ciPs00bBkGsaYssAyIMzqgy8iIl5O5y1JjjFmBfA/a+0Ut2OR5KkHS/yaMaatp4s9AngPGKuTlIiIeCudt+RyjDE3AVlxeijFSynBEn93B05VoS04n/enXI1GRHyOMWaMMeagMWbUJdbXMcas9FQEu5p5cUSSo/OWJMsYMxxnyOXTl6gyKF5CQwRFREQuwxjTDOei9PuttZ2SWT8feBhYiXNdxSPW2uUZGaOIiHgP9WCJiIhchrU2GqdYwH8YYyKBIM+ksQk4vy63zcDwRETEywS5HUBSBQoUsCVKlEjVMY4dO0b27Fc7H6rvUPt8m9rn2/y9fZD6Ns6dO/dfa23BNAzJ20UCO5M83gk0TW5DY0xXnMm0yZo1a93ixYun6okTEhIIDAxM1TG8mdrn29Q+36b2XdmqVasueb7zqgSrRIkSzJkzJ1XHiImJoUmTJlfe0Eepfb5N7fNt/t4+SH0bjTFb0i4a/2KtHQwMBqhXr57V+e7y1D7fpvb5NrXvyi53vtMQQRERkWu3C4hK8jjKs0xERDIpJVgiIiLXyFq7C0gwxlQxxgQCd+JMEisiIpmUVw0RFBER8TbGmMlAVSCbMWYH0Bl4A+jqSbCeBIYBYcCPqiAoIuIF4uPh+HE4duw/t3zz50OJElCkSLo8tRIsEclU4uPj2bFjB6dOnbrqfcPDw1mzZk06ROU9rqaNhQoVIiIiIp0jcp+19sZkFrdOsn4OUDG1z3O1n01//zxea/vCwsIoXLgwQUH6E0fEZ5w5899E6OjRZJOjFN9Onrzk05UHqFhRCZaISFrYsWMH4eHhFCtWDGPMVe179OhRwsPD0yky75
DSNp46dYodO3ZkigQro1ztZ9PfP4/X0j5rLQcOHGDHjh2ktkqjiCTDWoiLS3mSk9Ik6fTptI0zMBAiIiB7ducWHn7
2021-03-19 17:21:00 +00:00
"text/plain": [
2021-03-26 20:01:05 +00:00
"<Figure size 1050x490 with 2 Axes>"
2021-03-19 17:21:00 +00:00
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
2021-03-26 20:01:05 +00:00
"# Pick one run with more than one epoch so there is a curve to plot\n",
"single_result = random.choice([r for r in single_results if r[\"epochs\"] > 1])\n",
"single_history = single_result[\"history\"]\n",
"# Keras numbers epochs from 1 (\"Epoch 1/5\"), so label the x axis the same way\n",
"epoch_numbers = range(1, len(single_history['loss']) + 1)\n",
"\n",
"fig, axes = plt.subplots(1, 2, figsize=(15,7))\n",
"fig.set_dpi(fig_dpi)\n",
"\n",
"# One panel per metric: training curve vs validation curve\n",
"for ax, metric, name in zip(axes, ['loss', 'accuracy'], ['Loss', 'Accuracy']):\n",
"    ax.set_title(f\"Training vs Validation {name}\")\n",
"    ax.plot(epoch_numbers, single_history[metric], label=\"train\", lw=2)\n",
"    ax.plot(epoch_numbers, single_history[f'val_{metric}'], label=\"validation\", lw=2, c=(1,0,0))\n",
"    ax.set_xlabel(\"Epochs\")\n",
"    ax.set_ylabel(name)\n",
"    ax.grid()\n",
"    ax.legend()\n",
"\n",
"print(f\"Nodes: {single_result['nodes']}, Epochs: {single_result['epochs']}\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
2021-03-26 20:01:05 +00:00
"id": "0IQ7HfJCSDud",
"tags": [
"exp1"
]
2021-03-19 17:21:00 +00:00
},
"source": [
"### Accuracy Surface"
]
},
{
"cell_type": "code",
2021-03-26 20:01:05 +00:00
"execution_count": 69,
2021-03-19 17:21:00 +00:00
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 705
},
"executionInfo": {
"elapsed": 315450,
"status": "ok",
"timestamp": 1615991772345,
"user": {
"displayName": "Andy Pack",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjA4K4ZhdArHXAFbAGr4n0aCv2HmyUpx4cy6zcUq34=s64",
"userId": "16615063155528027547"
},
"user_tz": 0
},
"id": "X3MWHLxJElbc",
2021-03-26 20:01:05 +00:00
"outputId": "134671d0-bfd3-4ee6-aa02-1a2a5b23f3ca",
"tags": [
"exp1"
]
2021-03-19 17:21:00 +00:00
},
"outputs": [
{
"data": {
2021-03-26 20:01:05 +00:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAfUAAAFWCAYAAABwwARRAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAArEAAAKxAFmbYLUAADBx0lEQVR4nOydd3hcV5n/P2eapJlR771asiz3RkISkpCErA0YQkJICLBLCyVh+UEgECCh7bILIbALoSxkIUBCQm8hgVASlpLYVrcl925JttXbSKMp5/fHzB2PRjOjO9KMms/neeaxNbede+fe+z3ve97zvkJKiUKhUCgUiuWPYbEboFAoFAqFIj4oUVcoFAqFYoWgRF2hUCgUihWCEnWFQqFQKFYIStQVCoVCoVghKFFXKBQKhWKFsCiiLoR4VAjxCR3rXSOEOLoQbVIsLkKICiGEO+jvDiHE5f7/CyHED4QQQ0KIX/i/+7wQYkAI0bRYbU4EQoirhBBtOtf9mBDi4US3aT4IIe4QQvxmjts+L4R4UwLaFPUaCyH+RQjxx3nsXwohSvz/f0YI8YagZdPuWyHE+4QQF4QQfXM93nJGrxYsJAv1bhFCnBRCXBnv/c4q6v4DjwshbEHfWYUQo0KIk/FukF5CRWAe+5nXA6xIDFLKBinlC/4/rwKuAAqklDcJIcqA9wI1UsotC9UmPffcfIVISvlXKeUGnet+Tkp591yPtRBIKR+XUr56sdsRTOg1TtTL1X+sHVLKH/mPM+2+FUJYgC8AV0opcxJx/EgEdzwUF1msd0s80WupdwGvDfr7NUBP3FujmIYQwrTYbYgHcTiPMuC4lHIy6O/zUsqBRWjLvFjs4y8kl9K56iT0vs0DzFLKw7HuSF3b6cTxesz53bJkkFJG/QAngU8DTwd991vgk8DJoO8agL8CQ0ATcEXQsmrg78Ao8DPgR8AngpbfBRwB+oDvATb/99cARyO06zAggTH/pwww+tt6CjgPPASY/OtfBrQAI/g6KR8AqoBJwO3fR0eEY30V6Paf27NAWdCySv/16MfX0flX//cm4DP+tgwDz0c6J/95lARd73uBA9r1jfX4QIn/PG1B670VeDbC+YX97YA3Ac+FrPtd7bfzX3Pt2AeAfwq5b6adR5jjfsz/O530t9sdsv2VwB0hv9EngQnA6//7If/6twAdwADwayAv+Hr774s+4N+BFOBh/zU9C3w06LiPAl8B/oTvfn0WyIp0z4Wcz/2Ax9/eMf/5hTt+NfB//uvdDXwuaB/T7g//8d4DnPBvf1/Qsk8Bj4Sc5wP+a3ASuDFo3VXAC0R4BkPO43ngs/iel0F8z2RK0PJYrvW/AH8M2vYq/36HgL8A9UHLtgHt+O7db/qv0Zv8y14FHPK3/yRwW5h23wD8I+Qd8R3//22AA7AGX2PgEXz3ksP/m93hb/OfgW/429IJbI7yjnwbcAY4B9zJ9Of5eXzP0VVMv2+/C4xz8X76iX/9q/E9g0P+bav931fgewbeje/99RjR33efAh4HfuK/ZruBSv+yZ/3HHfcf+6oI98CngUb/NfgRkORfFvqbVjD92ZX43ukn/efxLuCl/us4CNwf8rx9Fd+9MAL8Bv/zNofrUQv8zb+f88CDEX6vZOBr/t/rNL5nxhDmN3oozLaPEuH94F/+OnzvvEHgKaA4aNkOfM/HAH79xOelgejvpFnv/WltjLYw6OV6Nb6bNtf/OYPPHaqJjgU4ju/FbAbe4G94pn/5HuBz/vVeC7i4KAyvB/YB5f4T+yHwxUgCGOlG8n/3YXwPYy6QATwH3O1f9iJwh///mcCmcDdohGPdBqT72/cd4Jf+7034btRP+W+UNGCLf9nH/eetdTZeFumcmCnqLwL5+F+kczz+n4Hbg47xLPAvYc4t4m8HpOK7ufOD1h3EJxAGoM2/nQm4HOgNWnfGeYQcdye+h7HW/1v9kTCiHuElMu0aAtvxPQjr/OfwBeCnQe
u68T24Zv81/Bq++8wOFOETqFcFPbTngPX+a/on4NOR7rkIL8M3hbQ19PjV+J4pk/9angZeG+HcJL4Xsx1Yi6/DoL3YPsV0UXcDH/Hv907gVNB+GvG9pM3Aq4Epoov6qaDf5jngs3O81oHfDsjGd/+8zr/8w/g68yZ899YZfB0YM/A+/740UT/Hxc5mAbAmTLvt+MQ5Bd99dwI45F92HfBihGt8Ev+9FnS/uYDb8T27/wb8JcK1WotPRF7iP+73CSPqEY5bwfR7vhTfM3SV/7jvA/YGrSvxdXaS/ceK9r77lP9aXOu/vt8HvhfunRPlHjiA772cge8ZeWuE5zH0PCTwJL4O1LX4hPJnQBaw2v93VdDzNhR0/X4I/GCO1+NJ4D5A4OvEbY9wbv/uP79MfO/nw/jfjaG/UZhtHyXy+2E1PgPuSiAJn/g/51+Wi0+UX4XvXv8Cvvtbe8dFeyfNeu9Pa2O0hcE3PPBl4G7/57/wWb6aqF9FiDWGzyq43X9TTALJQcv+xkVR/x3wxpCHRNtvxAsceiP5vzsIvDTo71dx0UL+K77eUVbINv/CLKIesn4d0Of//xX4XkSGMOsdAW4I8/2Mc2KmqN8eh+O/Dfi1///5+MQ5Lcx6EX87//9/DtwVdD2b/f+/DDgcst1PufhwzHYe3wU+FfT39cxd1L8JfCzo71R8L2WTf91xLlowAt/Lriho/buBR4Me2q8GLXsvFztRM+65MOf1PDNFPXD8CNv8BxE6sv57Y0vQ33u42AH4FNNFfVi7F/C9UCW+F3IFvhdpUtB+/kp0UQ/9bQ7Heq1DfzvgzQSJI76OYRe+DuHVwImgZQLfva0J4hngHYB9luu/F5+Q3ILP8tmP7/7/VJRrfJKZor4v6O81wFCE431Su3f8f9cwd1H/KPCtkP33+ter8O+3MGhZtPfdp4DfBC3bCbSG3Fezifo9QX9/AfivCM9j6HlIgjwb+Kzm1wX9vZuL9/CjYa7fpP/3j/V6/ADf/VkY6bz86x0Drg36+13A78P9RmG2fZTI74f7md5xsuN7NgqBf9Z+m6Dncwqfts72TtJ172ufWKLfHwfeiM899XjIsiL/gYM55f++EOiVF8dDCVm3DPgff2TzED7Bz42hXcGUAc8E7etxfONW4LsoDcBRIcTftMhqPQghPi6EOCqEGMH3Us32LyrBZw15w2xWgs9SmAtn43D8nwJXCyEy8XlDfielHAmzXrTfDny9Xy169w343HDgu9aV2rX2X+9/wvd7hz2PEApDjhvahlgoAz4e1I4z+HrBBf7l56SUWoBbLr5efWfQ+p/D9+LXOB/0fwe+h3M+BB8fIUSxEOIXQohzQohh4P9x8TcNh9729Gr3gpTS4f/Oju869EopnUHrRvttYOZvo/2usVzrUIrweSXwt9Hr3157T5wNWiZD2ngLcBNwVgjxOyFEfYRj/BVfR/UqfO+S4L//GuV8Q9F7zeN9H7855JmyAcX+5V4pZU/I+pHedzD/+3g+218I+v9EmL+D9xV6/ZLwWfWxXo978VnBrUKIFiFEpADNafch0993eoh0XULv7zF8Q5Pa/X0maJnDvwxmfyfpvfeBGKa0SSkb8V3oTCnl3pDF3fhcJcGU+b/vAXKEEMlBy4LX7QL+WUqZEfSxMTsyzHdd+Hpg2n7SpZRr/O0/JKW8Fd9N/yTwRJT9BBBCXI2vN7YTnwt8e9DiM0C5EEKE2fQMvt5kKOP4fkBt//lh1pFBy+d0fL+A/x64GZ/HJLQjphHttwPfuNBGIUQVPrftj/3fdwEHQn43u5TyP8KdRxh6Qo4b2oZY6MI3ThfclhQppSYKwe3oA5z43H/aumlSyh06jhP1XomyTuh3/4bPDV0rpUzH5/kKdw/Fi3NArj/aWmO2yOfQ30Z7ecZyrUPpxndvAb6piv59a++J0DYF/pZS7pZSvhLfi64N33h3OP4KvAyfBfRX/+dafM/N3yNso+d3jUS87+Nvh1xbq5RSa3
doOyO+7xLMtHcY0zvEcyH0+jnxDQHGdD2klD1Syrfh62B+CvhxiO5oTLsPmf6+mw+h97cNX2ddu79Lg5alcLEjH/W
2021-03-19 17:21:00 +00:00
"text/plain": [
2021-03-26 20:01:05 +00:00
"<Figure size 560x350 with 2 Axes>"
2021-03-19 17:21:00 +00:00
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Test accuracy of each run, arranged on a (hidden nodes x epochs) grid\n",
"test_accuracies = [r[\"results\"][1] for r in single_results]\n",
"shaped_result = np.asarray(test_accuracies).reshape(len(hidden_nodes), len(epochs))\n",
"X, Y = np.meshgrid(epochs, hidden_nodes)\n",
"\n",
"fig = plt.figure(figsize=(8, 5))\n",
"fig.set_dpi(fig_dpi)\n",
"ax = fig.add_subplot(projection='3d')\n",
"\n",
"surf = ax.plot_surface(X, Y, shaped_result, cmap='viridis')\n",
"ax.set(title='Model test accuracy over different training periods with different numbers of nodes',\n",
"       xlabel='Epochs', ylabel='Hidden Nodes', zlabel='Accuracy')\n",
"ax.view_init(30, -110)\n",
"# ax.set_zlim([0, 1])\n",
"fig.colorbar(surf, shrink=0.3, aspect=6)\n",
"\n",
"plt.tight_layout()\n",
"# plt.savefig('fig.png')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
2021-03-26 20:01:05 +00:00
"id": "C793_RHvSGai",
"tags": [
"exp1"
]
2021-03-19 17:21:00 +00:00
},
"source": [
"### Error Rate Curves"
]
},
{
"cell_type": "code",
2021-03-26 20:01:05 +00:00
"execution_count": 70,
2021-03-19 17:21:00 +00:00
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 668
},
"executionInfo": {
"elapsed": 316211,
"status": "ok",
"timestamp": 1615991773109,
"user": {
"displayName": "Andy Pack",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjA4K4ZhdArHXAFbAGr4n0aCv2HmyUpx4cy6zcUq34=s64",
"userId": "16615063155528027547"
},
"user_tz": 0
},
"id": "tpClZMptrq-q",
2021-03-26 20:01:05 +00:00
"outputId": "f9fe93f9-7b67-4772-83e4-9e3567fd9318",
"tags": [
"exp1"
]
2021-03-19 17:21:00 +00:00
},
"outputs": [
{
"data": {
2021-03-26 20:01:05 +00:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAf4AAAFECAYAAADGPlw2AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAArEAAAKxAFmbYLUAABd10lEQVR4nO3dd3wc1bXA8d/Rqjd3reQq44otuWjdcaMZbExLKAmmOHl2Qgu9B/IgAfLoJZTQguPQITQbGwy4CBuMG+4FbFxwlauKJVntvj9mJK+kXWkl72ql1fl+PvuRdubOzLmzM3tm7szeEWMMSimllGoewoIdgFJKKaUajiZ+pZRSqhnRxK+UUko1I5r4lVJKqWZEE79SSinVjGjiV0oppZoRTfzNhIhcJCK7RCRPRJIaQTz3iMhzfpjPZBH5qh7TdRaRIye6/PryV/3ruMxkEflWRHJF5BYfys8Xkcvt/yvFW3V7EpFRIrLFfj84kPVozETkfhF5NdhxlBORaSJyr5dxFZ+vh3GjRGRVDfOtadp67ZMNRUTGisjmIC37nyJyh49l14nI8EDEER6ImdZGRPLc3sYB+UB5hwJ9jDE76jCvacBmY8yD/osweAJYn0eB3xlj5vh5vvVijHk4yMvfAbQsfy8i84FXjTFv+HtZIpKK9ZlW7G9Bqv8fgG3GmBF1ndBDvJW2JxF5B3jEGPPyiYfpu0B+bs2ZMeYboH+w42hMPO3HdWWMuboOZfvWdzm1CUriN8bEl/8vIoVAX2PMtmDE4k5Ewo0xJd7e13V6D+MFEGNM2QmGWh+dgfV1naiu66A5akLrqF7bgI/z0u1LNWtB/n6vG2NMUF9AIZBq/98aeAvIAn4GrnIr93tgO5ALbALGAlcBxcAxIA/4p5dlXASsAw4BnwJJ9vCxwGbgAeAA8BAwH/gbsMyebwTwK2ADcBiYCXSwp08FSoCrgV3AGx6WPQ14DpgLFAA97Lr8aNdlNTDWLuuxPkA6kGkvfzkwyB4eBjxrx34EWAq09RBDHlaLylFgiT1sFPCDPd0C4GS38ga4HtgKfONhft3seI4Au4GHa/h8J9qfVy6wDfiNPfx+rDM198/hL/ZntA04y20ePYDv7Hn8F3gXuNceNxn4yq2sx3XlIa5UoMT+/z6gFGtbzAPusYePsedxxN4uunn73IFWwOf2Z7EfeBmIssv/aK/TPPvV2b3+dpnatrGpwB5gL277hYd69QW+sWNeDpxiD/8nlbetgR6mHYy1PebY5TOByz18XpW2J6x9qwyr5W6vXaYz8Blw0K7X2W7L2QbcYQ/f5uM+Wm3b8Pa5eajXdcBP9mfzbyDObduZC7xi13klMKC2dWmPa8fx76oDwP+5rac3gfexttfvga72uBjgbbseh/Cwb9nlatqWvK4Pt31zER72FQ/LmY/13bfMrv+7VZfj47YRZ9f5CLAC63vUfZ+sbT/ydds2wDVY30sHgLvdxkUDz9vz2GGvnzB7nAN4Bmtb3ATcVaVuvn5neNqPp+Hj97tbPnD/7poLvGiv1/VARpX9ZKTbdM8CX9vznQO0dis7Fdhp13+qHWdHr+vS24iGelE58X8GPA5EAb2xkko/e8PKAXrY5bpwfGeqWJFe5j/EXiHpWEn8UeADt427xN5IIrB2zPlYO1Z3e2M6GcgGRtpxPQvMc9twDdaOEA3EeFj+NHsjdWG1sEQAE4BO9gY51f6wojzVB4jHSi6/tstfgLVhRwNnY+20ifa4DCC+hp2mo/1/G6yN/Fd2PLdjfTGGu5X9xJ6vpzp1w9qZw7E29B3ABV6Wu5fjyScZ61IOVE/8JcCd9jz/AGx3m8cyrC+oCOBcoAgPib+mdeUhrlTsxO/2JXi52/tOWF+6o+x5/QlY6u1zt9fpufY2koL1BXiTp2V5qH9vat/GngEigXFYXzoJHuoUiXXAfIO9ri7FSg6tattX7Gl/wfpijbDrW4KHxF91e/LwJRUGrLLjCA
eG2+vS6VZ2MeC0150v+6i3baPS5+ahXhcDa7C+M2KwkvXjbttOSZU6/2wvp7Z1OQd4FUiw5zvcbT3lA6fa85kO/NsedzXWQU2MPW6Ul5hr2pZqWx9LgIft+C/AOtirKfFvsNdNS6wDr9+5LWezj9vGo8BXWN8Xve2y5fukL/tRrdu22zb3PtZ+noaVO8oPIspP2lphJeQfgcn2uGuxEnAy0B5r2yyvW72/M070+x1r+ysGfmuXfRBY4GWfmmbPpx/Wd87XwAP2uHSsg6pBWNvW6zSVxG9/KEeBCLdxj2PtSHFYX4wXlK/AKiu+psT/T9zOBLB21GL7QxprLzO8ys7gfiR5H/aO67ahFGPtkKn2Ck6pYfnTgJdqWQd7gDRP9QF+A8ypUn6ZHfvpWEewQ7CamGpahnviv6LKBhaGtfEPdys7vA6f4d+xv0w9jPsFmEKVAxKqJ/5sjh+hx9oxtLTXcYH75451FuYp8XtdV7XtxFRP/HcBL1eZZr89nS+f+x85nrwqLctD/X3Zxtq4jc/C7czUbfgo7DNot2HfAb+tbV/BOpDb6vZe7M+uPol/GPBjlfl/wPEv4m3lMfm4j3rcNjx9bh7q9Tlwmdv7NI63Mkz2UucRNa1LoAPWwWech+XdD8xwez8BWGn//z9YZ+N9fd23PGxLXtcHVgIvxC1pAQtr+MznA7e6vX8UeNptOeXJsbZtYyuVz2of5Pg+6ct+VOu27bbNudzeL8E+4QC2AKdWWWdf2P/PK9/27PdT3OpW7+8Mt32qXt/v9va3xq1cH+CIl31qGvAPt3HXAh+7bXP/chvXjVoSf2O6q78z1pHMfhE5Yt9x/Ucg2RhzFGuHuwHYJyLvi0j7Osz3z27z/AXraDXZHr/XVL/GuNPt//ZYR4AAGGPysJqMypdfZozZU0sM7vNDRC4QkRVuMSVhHeV7i39MeVm7/MlAe2PM11hfmi8De0TkcRGJqCUWT3Uqw1ov7ut0Z9WJ3OLvICIficheEckGbqoh/ouAC4GdIvK5iJzspdx+Ow6MMfn2sHisz2m/MeaYD7F5XVfe6lKDzsAVVeYVh/WlD1U+dxFJEJHpIrJTRHKAJ/G+TqqqbRsrNcYcdCufj7VuPM3nlyrDtuNb/VNwW6/G+gbxug3UojPQtcq6O9teRrmdVcrXtI962zZ8jeUlt3kvxGqmrxaHW51TqHlddgSy7O8lT/a5/e/+Wf0H60ztIxHZLiJ3e5rYh23J2/pIsccVupWtWgdfY3VX27aRUmU57v/Xth/5um3XFm+lfYjK231t8Z3od8aJfL/7sv5rK5tcJYZa99vGlPh3YTXztDLGtLRfCca+C9IYM8sYcxrWTncMqzkLjv8aoKb53uc2z5bGmBhjTPnK8TS9+7DdWBsHACISh/Uh7vZx+ZXKiEgU1nW+P2Md6bbEOsoVL/PbhXXk6h5/nDHmLQBjzFPGmAFY1+DOAib5EE/VOglW09RutzI11etBrEsFPY0xLYCn3eKvxBjzvTHmHKxm3VVY17PqYi/QTkQi3YZ19FK2xnVVC0/r/ZUq84o1xizyUv4WrIQywBiTaL/39plWVds25qvdWJ+ju84+zmcP1dert/Vcm13AhirrLt4Y83e3MqZK+Zr20Zr4sv9fVXWbcBvvqc57qHld/oK1Tcb6EN/xQI0pMsb8xRjTE2tfvVFExnooWtO2VJM9QFsRiXYbVrUO9VHbtrGnynLc/69tP/KXSvsQlbf72uLz9TvD27ZWl+/3QNjL8QMp8GG/bTSJ3xizC6sp7UERiRWRcBHJEJE+IuIUkYkiEoOV9POxbuoBa6Wm1jDr14HrRaQ/gIi0FpHz6xDaB8AFIjLCTj4PAt/6cJbvTRTW9awsO54bqXwGUrU+M4GB9lFkuIjEiMjZItJCRAaJyGARCce64aOY4+ulJrOB/iJyvj3tzVjN6ct8rEOCvbw8EUkDvP2eN1JELhORRDu2PB/jq2CsX3tsAO4SkQgROQcY6q
W413Xlw6Kqrve3gIvF+j1zmH0WdlEN0ydgbZfZItIFqymu3AEgTES87ZD+2sa+BxCR6+36X4x19vK5D9N+B0SIyB/
2021-03-19 17:21:00 +00:00
"text/plain": [
2021-03-26 20:01:05 +00:00
"<Figure size 560x350 with 1 Axes>"
2021-03-19 17:21:00 +00:00
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
2021-03-22 20:49:29 +00:00
"fig = plt.figure(figsize=(8, 5))\n",
"fig.set_dpi(fig_dpi)\n",
"\n",
"for layer in hidden_nodes:\n",
"    # test accuracy of every single run trained with this many hidden nodes\n",
"    layer_accuracy = np.array([run[\"results\"][1] for run in single_results if run[\"nodes\"] == layer])\n",
"    plt.plot(epochs, 1 - layer_accuracy, label=f'{layer} Nodes')\n",
"\n",
"plt.legend()\n",
"plt.grid()\n",
"plt.title(\"Test error rates for a single iteration of different epochs and hidden node training\")\n",
"plt.xlabel(\"Epochs\")\n",
"plt.ylabel(\"Error Rate\")\n",
"plt.ylim(0)\n",
"\n",
"# plt.savefig('fig.png')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
2021-03-26 20:01:05 +00:00
"id": "7mJaKjlCxEkt",
"tags": [
"exp1"
]
2021-03-19 17:21:00 +00:00
},
"source": [
"## Multiple Iterations\n",
"\n",
2021-03-22 20:49:29 +00:00
"Run multiple iterations of the epoch/hidden-node investigations and average the results\n",
"\n",
2021-03-26 20:01:05 +00:00
"### CSV Results\n",
2021-03-22 20:49:29 +00:00
"\n",
"| test | learning rate | momentum | batch size | hidden nodes | epochs |\n",
"| --- | --- | --- | --- | --- | --- |\n",
"|1|0.01|0|128|2, 8, 12, 16, 24, 32, 64, 128, 256|1, 2, 4, 8, 16, 32, 64, 100, 150, 200|\n",
"|2|0.5|0.1|128|2, 8, 12, 16, 24, 32, 64, 128|1, 2, 4, 8, 16, 32, 64, 100|\n",
"|3|0.2|0.05|128|2, 8, 12, 16, 24, 32, 64, 128|1, 2, 4, 8, 16, 32, 64, 100|\n",
"|4|0.08|0.04|128|2, 8, 12, 16, 24, 32, 64, 128|1, 2, 4, 8, 16, 32, 64, 100|\n",
"|5|0.08|0|128|2, 8, 12, 16, 24, 32, 64, 128|1, 2, 4, 8, 16, 32, 64, 100|\n",
"|6|0.06|0|128|1, 2, 3, 4, 5, 6, 7, 8|1, 2, 4, 8, 16, 32, 64, 100|\n",
"|7|0.06|0|35|2, 8, 12, 16, 24, 32, 64, 128|1, 2, 4, 8, 16, 32, 64, 100|\n",
"\n",
2021-03-26 20:01:05 +00:00
"### Pickle Results\n",
2021-03-22 20:49:29 +00:00
"\n",
"| test | learning rate | momentum | batch size | hidden nodes | epochs |\n",
"| --- | --- | --- | --- | --- | --- |\n",
"|1|0.01|0|128|2, 8, 12, 16, 24, 32, 64, 128, 256|1, 2, 4, 8, 16, 32, 64, 100, 150, 200|\n",
"|2|0.5|0.1|128|2, 8, 12, 16, 24, 32, 64, 128|1, 2, 4, 8, 16, 32, 64, 100|\n",
"|3|1|0.3|20|2, 8, 12, 16, 24, 32, 64, 128|1, 2, 4, 8, 16, 32, 64, 100|\n",
"|4|0.6|0.1|20|2, 8, 16, 24, 32|1, 2, 4, 8, 16, 32, 64, 100, 150, 200|\n",
"|5|0.05|0.01|20|2, 8, 16, 24, 32|1, 2, 4, 8, 16, 32, 64, 100, 150, 200|\n",
"|6|1.5|0.5|20|2, 8, 16, 24, 32|1, 2, 4, 8, 16, 32, 64, 100, 150, 200|"
2021-03-19 17:21:00 +00:00
]
},
{
"cell_type": "code",
2021-03-22 20:49:29 +00:00
"execution_count": 30,
2021-03-19 17:21:00 +00:00
"metadata": {
2021-03-26 20:01:05 +00:00
"id": "-lsKo4BCP3yw",
"tags": [
"exp1"
]
2021-03-19 17:21:00 +00:00
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Iteration 1/30\n",
"Iteration 2/30\n",
"Iteration 3/30\n",
"Iteration 4/30\n",
"Iteration 5/30\n",
"Iteration 6/30\n",
"Iteration 7/30\n",
"Iteration 8/30\n",
"Iteration 9/30\n",
"Iteration 10/30\n",
"Iteration 11/30\n",
"Iteration 12/30\n",
"Iteration 13/30\n",
"Iteration 14/30\n",
"Iteration 15/30\n",
"Iteration 16/30\n",
"Iteration 17/30\n",
"Iteration 18/30\n",
"Iteration 19/30\n",
"Iteration 20/30\n",
"Iteration 21/30\n",
"Iteration 22/30\n",
"Iteration 23/30\n",
"Iteration 24/30\n",
"Iteration 25/30\n",
"Iteration 26/30\n",
"Iteration 27/30\n",
"Iteration 28/30\n",
"Iteration 29/30\n",
"Iteration 30/30\n"
]
}
],
2021-03-19 17:21:00 +00:00
"source": [
"multi_iterations = 30\n",
"multi_param_results = []\n",
"\n",
"for i in range(multi_iterations):\n",
"    print(f\"Iteration {i+1}/{multi_iterations}\")\n",
"    # new stratified 50/50 train/test split for every iteration\n",
"    data_train, data_test, labels_train, labels_test = train_test_split(\n",
"        data, labels, test_size=0.5, stratify=labels)\n",
"    iteration_results = evaluate_parameters(\n",
"        dtrain=data_train, dtest=data_test,\n",
"        ltrain=labels_train, ltest=labels_test,\n",
"        # optimizer is passed as a factory (lambda), not as an instance\n",
"        optimizer=lambda: tf.keras.optimizers.SGD(learning_rate=1.5, momentum=0.5),\n",
"        return_model=False,\n",
"        print_params=False,\n",
"        batch_size=20)\n",
"    multi_param_results.append(list(iteration_results))"
2021-03-19 17:21:00 +00:00
]
},
{
"cell_type": "markdown",
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp1"
]
},
2021-03-19 17:21:00 +00:00
"source": [
"### Accuracy Tensor\n",
"\n",
"Create a tensor for holding the accuracy results\n",
"\n",
2021-03-22 20:49:29 +00:00
"(Iterations x [Test/Train] x Number of nodes x Number of epochs)"
2021-03-19 17:21:00 +00:00
]
},
{
"cell_type": "code",
2021-04-06 17:29:15 +01:00
"execution_count": 121,
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp1"
]
},
2021-03-22 20:49:29 +00:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"30 Tests\n",
"Nodes: [2, 8, 16, 24, 32]\n",
"Epochs: [1, 2, 4, 8, 16, 32, 64, 100, 150, 200]\n",
"\n",
"Loss: categorical_crossentropy\n",
2021-04-06 17:29:15 +01:00
"LR: 0.6\n",
"Momentum: 0.1\n"
2021-03-22 20:49:29 +00:00
]
}
],
2021-03-19 17:21:00 +00:00
"source": [
"first_iteration = multi_param_results[0]\n",
"multi_param_epochs = sorted({run[\"epochs\"] for run in first_iteration})\n",
"multi_param_nodes = sorted({run[\"nodes\"] for run in first_iteration})\n",
"multi_param_iter = len(multi_param_results)\n",
"\n",
"# axes: iteration x [test, train] x hidden nodes x epochs\n",
"accuracy_tensor = np.zeros((multi_param_iter, 2, len(multi_param_nodes), len(multi_param_epochs)))\n",
"for iter_idx, iteration in enumerate(multi_param_results):\n",
"    for single_test in iteration:\n",
"        test_acc = single_test[\"results\"][1]\n",
"        final_train_acc = single_test[\"history\"][\"accuracy\"][-1]\n",
"        node_pos = multi_param_nodes.index(single_test['nodes'])\n",
"        epoch_pos = multi_param_epochs.index(single_test['epochs'])\n",
"        accuracy_tensor[iter_idx, :, node_pos, epoch_pos] = [test_acc, final_train_acc]\n",
"\n",
"mean_param_accuracy = accuracy_tensor.mean(axis=0)\n",
"std_param_accuracy = accuracy_tensor.std(axis=0)\n",
"\n",
"first_test = multi_param_results[0][0]\n",
"print(f'{multi_param_iter} Tests')\n",
"print(f'Nodes: {multi_param_nodes}')\n",
"print(f'Epochs: {multi_param_epochs}')\n",
"print()\n",
"print(f'Loss: {first_test[\"loss\"]}')\n",
"print(f'LR: {first_test[\"optimizer\"][\"learning_rate\"]:.3}')\n",
"print(f'Momentum: {first_test[\"optimizer\"][\"momentum\"]:.3}')"
2021-03-19 17:21:00 +00:00
]
},
{
"cell_type": "markdown",
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp1"
]
},
2021-03-19 17:21:00 +00:00
"source": [
"#### Export/Import Test Sets\n",
"\n",
"Export (and re-import) the raw results as a pickle, and the mean and standard deviation tensors as CSV, for later retrieval and visualisation"
]
},
{
"cell_type": "raw",
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp1"
]
},
"source": [
"# context manager guarantees the file is flushed and closed\n",
"with open(\"result.p\", \"wb\") as results_file:\n",
"    pickle.dump(multi_param_results, results_file)"
]
},
{
2021-04-06 17:29:15 +01:00
"cell_type": "code",
"execution_count": 112,
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp1"
]
},
2021-04-06 17:29:15 +01:00
"outputs": [],
"source": [
2021-04-06 17:29:15 +01:00
"exp1_testname = 'exp1-test4'\n",
"# pickle.load can execute arbitrary code: only load result files produced by this notebook\n",
"with open(f\"results/{exp1_testname}.p\", \"rb\") as results_file:\n",
"    multi_param_results = pickle.load(results_file)"
]
},
{
"cell_type": "raw",
"metadata": {},
2021-03-19 17:21:00 +00:00
"source": [
2021-03-22 20:49:29 +00:00
"# np.savetxt only accepts 1D/2D arrays, but these are (test/train, nodes, epochs):\n",
"# flatten to (test/train * nodes, epochs); restore with .reshape(2, -1, len(multi_param_epochs))\n",
"np.savetxt(\"exp1-mean.csv\", mean_param_accuracy.reshape(-1, mean_param_accuracy.shape[-1]), delimiter=',')\n",
"np.savetxt(\"exp1-std.csv\", std_param_accuracy.reshape(-1, std_param_accuracy.shape[-1]), delimiter=',')"
2021-03-19 17:21:00 +00:00
]
},
{
"cell_type": "raw",
"metadata": {},
"source": [
2021-03-22 20:49:29 +00:00
"mean_param_accuracy, std_param_accuracy = (np.loadtxt(f\"results/test1-exp1-{stat}.csv\", delimiter=',')\n",
"                                            for stat in (\"mean\", \"std\"))\n",
"# multi_iterations = 30"
]
},
{
"cell_type": "markdown",
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp1"
]
},
2021-03-19 17:21:00 +00:00
"source": [
2021-03-22 20:49:29 +00:00
"### Best Results"
2021-03-19 17:21:00 +00:00
]
},
{
"cell_type": "code",
2021-04-06 17:29:15 +01:00
"execution_count": 122,
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp1"
]
},
2021-03-22 20:49:29 +00:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2021-04-06 17:29:15 +01:00
"Nodes: 32, Epochs: 32, 1e+02% Accurate\n"
2021-03-22 20:49:29 +00:00
]
}
],
"source": [
"# argmax over the mean *test* accuracy (nodes x epochs), unravelled against that slice's own\n",
"# shape; prefix the test index 0 so the tuple indexes mean_param_accuracy directly\n",
"best_param_accuracy_idx = (0,) + np.unravel_index(np.argmax(mean_param_accuracy[0]), mean_param_accuracy.shape[1:])\n",
"best_param_accuracy = mean_param_accuracy[best_param_accuracy_idx]\n",
"best_param_accuracy_nodes = multi_param_nodes[best_param_accuracy_idx[1]]\n",
"best_param_accuracy_epochs = multi_param_epochs[best_param_accuracy_idx[2]]\n",
"\n",
"# '.1f' = one decimal place; the previous '.1' meant one significant digit and printed '1e+02%'\n",
"print(f'Nodes: {best_param_accuracy_nodes}, Epochs: {best_param_accuracy_epochs}, {best_param_accuracy * 100:.1f}% Accurate')"
2021-03-22 20:49:29 +00:00
]
},
{
"cell_type": "markdown",
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp1"
]
},
2021-03-22 20:49:29 +00:00
"source": [
"### Test Accuracy Surface"
]
},
{
"cell_type": "code",
2021-04-06 17:29:15 +01:00
"execution_count": 123,
2021-03-19 17:21:00 +00:00
"metadata": {
"executionInfo": {
"elapsed": 2653358,
"status": "aborted",
"timestamp": 1615994110345,
"user": {
"displayName": "Andy Pack",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjA4K4ZhdArHXAFbAGr4n0aCv2HmyUpx4cy6zcUq34=s64",
"userId": "16615063155528027547"
},
"user_tz": 0
},
2021-03-26 20:01:05 +00:00
"id": "ZGJVhz6iJU-7",
"tags": [
"exp1"
]
2021-03-19 17:21:00 +00:00
},
"outputs": [
{
"data": {
2021-04-06 17:29:15 +01:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAUUAAAEaCAYAAACGrEV/AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAArEAAAKxAFmbYLUAACa7UlEQVR4nOydd3hj5ZX/P6+aZau5d4/r2OPuGXtm6KGlUJcQCLCUkN5ISAgJJISU3XQCm2TTYfeXhBAgZEOWBNhAKh3GvY7bjD3uvUm2Ve/vD/lqJFuS5e5h9HkePeO5em+RdHV03vec8z1CkiTChAkTJowbxU5fQJgwYcLsJsJGMUyYMGG8CBvFMGHChPEibBTDhAkTxouwUQwTJkwYL8JGMUyYMGG8CBvFMGHChPEibBR3ACFEuxCidqevYzMQQvyrEEISQly509cSJsxmEDaK24wQ4hCQCpQIIQo3+dhCCLHdn+lNwNTSv9vCDr3OMKcJ4Rtr+7kJ+F/geeBmACFEhhBiTggRJQ8SQrxHCPGXpb8jhRA/FEIMCiH6hRB3e437xdJzfwPmgVwhxPuEEB1Lx2wUQpzvNT5JCPFnIcSsEOLvQogfCSEe8nr+GiFEixBiUgjxlBAiMdALWXrurcDHgSuEEEav53RCiB8vXfOUEOJhr+euE0I0L11fkxCiYGm7JIRI9xr3DyHETet8nX7Pv+Sle4/LFkJMCyG0wT60MKcPYaO4jQghVMB1wOPAY8C/CiGEJEl9QCNwmddweRzAd4FYIB84BNwshLjca+z1wGcBA9ADDAMXAdHAfwKPCSEilsb+GOgHEoHPAzd6Xd8h4HtLx0sCji6ND8T1QKMkSY8CA8A1Xs99D9gDlC2d62dL5zgb+BHwYcAEXAvMBjnH8vOF+jr9nh94BLjB65g3AH+QJGkxxGsI82ZHkqTwY5sewKW4p5oa3F/sBeC8pec+CTyx9HcMYAHiAIHbM0r1Os5twC+W/v4F8LNVzjsElAAqwAHs8XruYeChpb9/CnzB6zkDYAdUAY77BvC5pb+/Dvxt6W8FsAjs9bPPg8DXAxxPAtK9/v8P4KZ1vM5g588BxgH10v+bgLft9L0RfuyeR9hT3F5uwu2V2CRJmgOe4eRa3BPAO4QQOuBq4AVJkiaABCASaF2a5k0D38Dtycn0e59ECHGVEKLWa3wibgMbDyhxe3X+9t0D3OO1Xx9uI5q8/IUIIfKBKuC3S5seA96yNP1NACKA437eg/QA20Mh1NcZ8PySJB0DOoC3CSFKlsb+dZ3XE+ZNiGqnL+B0QQihB/4FcAohLlnarAPsQohPSJI0JISoBq4A3g08ujRmHLACOZIkTQY4vEfqaGn6+Chuw/qcJElOIcQQbo9zHHDiDvT0Le2SvnR8cBvLeyVJeiCElyQb89eEEPI2BfCvuKf7ViAL6Fq2X9/Sdn/M4/4BkEla9nyor3MsyPkBfo172twLPC5JkjPA9YQ5DQl7itvH1binzgVAxdJjH25PTF5LfAz4GHAO8AcASZJcwC+B+4UQ0UIIhRCicGn9zx8RuKfnowBCiNtxe0NIkuQAngK+JISIWDrGFV77/j/gNiFE+dK+sUKIfwlwnhuBO71eSwVwD+7prgv4FfCAECJOCKFeWktk6bV8WAhx5lIUuUAIkbL0XANwvRBCKYS4BcgLcO7VXmew84N7rfbSpdfwmyDnCHMaEjaK28dNuNfuhiRJGl56DAD/zUmv63+AM3GvzU177ftpYAb3+tck7i98jL+TSJI0izsY8WfcgYg4fL2ljwGZuL3Gb+Oe/lqX9n0Ft6H7lRBiFqgFvI0JAEKIs3BPVR/0ei3DwE+APCFEGXAHMAi0ACPAh5bO8TJw+9LrnsW9bCBHrT+N21BNApXAKwHey1Bep9/zL+07AbwEOCRJej3QOcKcnghJCovMns4IIR4FmiRJ+sZOX8t2IoT4ITAlSdK9O30tYXYXYaN4miGEKAZcQD
twIe7p9CFJkpp39MK2ESFEMu4UqDOWAi9hwngIT59PP6KBpwEz7unux04zg3gb7mn2T8IGMYw/wp5imDBhwngR9hTDhAkTxouwUQwTJkwYL7YqeTs8Jw8T5s2BWH3Im4uwpxgmTJgwXoSNYpgwYcJ4ETaKYcKECeNF2CiGCRMmjBdhoxgmTJgwXoSNYpgwYcJ4ETaKYcKECeNF2CiGCRMmjBdhoxgmTJgwXoSNYpgwYcJ4ETaKYcKECeNF2CiGCRMmjBdhoxgmTJgwXoSNYpgwYcJ4ETaKYcKECeNF2CiGCRMmjBdhoxgmTJgwXoSNYpgwYcJ4ETaKYcKECeNF2CiGOa2QJAmXy4Xdbsflcu305YTZhWxV46owYXYU2fi5XC6cTqfnIRtCl8uFVqslMjISIU673kxhghA2imFOafwZv6mpKRQKBVFRUZ5xQgiEECgUCoQQnvF2ux21Wh02jGE8hI1imF2PJEkBPT9J8u2mK4RgenqayMhIDAZD0OMqFAqcTidCCFQqVdgwhgHCRjHMLsLb+C2f8vozfrLn5+84oRo4IQQOhwMAtVq98RcR5pQnbBTDbDurrffJY2TDF8j4BTv+akZRNrLyOIfD4fEYw5zehO+AMFtGKMYPThomeb1vM84bynHkMfK/drsdIGwYT3PCn36YDbF8vc/lcuFwOJiZmcHlcmEwGPx6fVu5freW6bOMt2EUQqBUKrfi0sKcAoSNYpiQWOt6n8Viwel0Eh0dvSPXGswoulwuFhYW0Gq1PtvlfUZHR0lISAh7jKcp4U89jA/LvT5vA7ic9az3bQeyUXQ6nVgsFsxzcywsLmKxWLBarZ7rTUhIIC8vz2dfIQRdXV0YjUaUSmU4In0aEjaKpylrDXasZ8q7XQbFbrdjmZllur0Lc8cxZo52MTY4hhgaQxoaRRVjIutXD5BeUIBGo0EIgdVqpauri/7+ftLT032OJ78HVquViIiIsGE8zQgbxTc53l6fvN4n/+3Ndq33rRdJkrDb7Uz3nGC2rQtL53EWj/fh7B+GgVGkkXHwNuhLDwD79BxDn/wqcY//CBFxMrhSXFxMQ0MDERERJCQk+JxLqVQiSRI2m81jSMOcHoSN4psAf8nNDoeD4eFh4uPjfb7Qp4LxW5ieYaq1g7n2Y8x39WDr6cM1MII0OAbzC+s6rrmuhbb33knRw99DEaEB3FHmiooKqqur0Wg0mEwmz3j5vXG5XGHDeJoRNoqnELLxk6e5qwU7BgcHSUhI2JFI6vLrWY7L6WSmq4fZo13MdRxjsbsX+4kBXP0jMDkDq+y/HmZerqb9Y/ew7+ff8mxTq9VUVFRQV1dHRUXFitJAwONhh6teTg/CRnEXslnJzUKIVY3TVmObmWW6uZ3ZYyewtHSwePwEzr4hpKExsNm3/Xom/++fdN7x72R+5/OebZGRkZSWltLQ0EBlZaXPeO/kbghXvZwOhI3iDrKW5Ob1THm306uZ7x9yG7+2Liwdx1no6kHqG+bY1CwACoMOVbQex9Dotl1TIMZ+9wxCH0Xht7/g2WYwGMjPz6eurs6v1w3hqpfThfCnu8WsVcxgM9f7NttTdNkdmLt6mG5uZ669m4XOHqzH+3CdGIaFxeD7zlkQMSYUBj2uOfOmXdN6Gf3F74hMiCPncx/1bIuLi8Nms9HS0oLL5fLxvsNVL6cP4U92k/CX3Dw1NYVWq12xpred+X3rMYqOOTOzrV3MtHViPtrNQlcvtuN9SMPj4FiZrxgq1hOD6EoLWOjoAufOC7z23vczNNEm0j/0r55tKSkptLe309raSnFx8YogFbhTdSBsGN+shD/VNbKW9b6BgQHS09NXlbDaKlbzFBcHR5ht62KmtQNL+3EWj53A0TOANDG9ZddkaWrHeMZ+zNX1W3aOtdD5xftQmQwkX3eFZ1tERAQqlYru7m6/yd3Dw8M4nU7y8vJ2XeJ6mI0TNooB2AwxAznXbacQQiC5XMwfO8FMSzszLZ3MdxzH2n
0Cx4lBmA8+5d0qZl+rw3TWAebeqNuR8/sgSRz91FdRGvUkXHKBZ3NBQQGNjY0Bk7uVSiU2my2c3P0m5LQ2ioHEDLz
2021-03-19 17:21:00 +00:00
"text/plain": [
2021-04-06 17:29:15 +01:00
"<Figure size 420x280 with 2 Axes>"
2021-03-19 17:21:00 +00:00
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# mean test accuracy (index 0 of the test/train axis) over the nodes x epochs grid\n",
"mean_test_accuracy = mean_param_accuracy[0]\n",
"X, Y = np.meshgrid(multi_param_epochs, multi_param_nodes)\n",
"\n",
"fig = plt.figure()\n",
"fig.set_dpi(fig_dpi)\n",
"ax = fig.add_subplot(projection='3d')\n",
"\n",
"surf = ax.plot_surface(X, Y, mean_test_accuracy, cmap='coolwarm')\n",
"ax.set(title='Average Accuracy', xlabel='Epochs', ylabel='Hidden Nodes', zlabel='Accuracy')\n",
"ax.view_init(30, -110)\n",
"# ax.set_zlim([0, 1])\n",
"fig.colorbar(surf, shrink=0.3, aspect=6)\n",
"\n",
"plt.tight_layout()\n",
"# plt.savefig(f'graphs/{exp1_testname}-acc-surf.png')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp1"
]
},
2021-03-19 17:21:00 +00:00
"source": [
2021-03-22 20:49:29 +00:00
"### Test Error Rate Curves"
2021-03-19 17:21:00 +00:00
]
},
{
"cell_type": "code",
2021-04-06 17:29:15 +01:00
"execution_count": 124,
2021-03-19 17:21:00 +00:00
"metadata": {
"executionInfo": {
"elapsed": 2653349,
"status": "aborted",
"timestamp": 1615994110347,
"user": {
"displayName": "Andy Pack",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjA4K4ZhdArHXAFbAGr4n0aCv2HmyUpx4cy6zcUq34=s64",
"userId": "16615063155528027547"
},
"user_tz": 0
},
2021-03-26 20:01:05 +00:00
"id": "Jrn3hKQAlGcc",
"tags": [
"exp1"
]
2021-03-19 17:21:00 +00:00
},
"outputs": [
{
"data": {
2021-04-06 17:29:15 +01:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZwAAAERCAYAAABPbxE/AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAArEAAAKxAFmbYLUAABBoklEQVR4nO3deXwV5dnw8d91liSERBaRJYCC0KIIYREQFZHHWquUqrVoVRR53lftY1v32oJPF+xbtdal1mpbrbu1lta2VhGr1YopVsWlEQEBUUEIiMiShYQk55zr/WPuczIJJ+EknCWY6/vhfJh9rpkzmevc98zcI6qKMcYYk2mBXAdgjDGma7CEY4wxJiss4RhjjMkKSzjGGGOywhKOMcaYrLCEY4wxJiss4ZhOQ0RmikiFiNSISN80L3uaiKz19deISInrLhSRZ0SkSkRuF88jIrJTRP6azjj2NyKiIjIo13EAiMgQEYm0Mq7Z95tk/DMi8vUOzrtORKa0P+LsEJHFInJeruNIRSjXAXRWIlLj6+0O1ALxh5ZGqupH7VjWg8BaVf1J+iLMnQxuz8+A/1bV59K83D2oapGvdyZQBPRS1aiITAWOBfqr6u5MxxInItOAe1V1eLbW2VWo6im5jsFYwmmV/4QkIruBI1R1Xe4iSsQSUtVIa/3tnT/JeAFEVWP7GGpHHAysbO9M7d0Hrax3tapGff0fdCTZpCEWYz67VNU+e/kAu4Ehrrs38HvgE+AD4ALfdP8HWA9UA6uBacAFQCNQD9QAv2llHTOBFcB24Emgrxs+DVgLXAd8ClwPLAb+H/CGW24YOAN4F9gBLAQGuvmHABHgf4AK4HdJ1v0gcCfwT6AO+JzbljVuW5YB09y0SbcHGA2UufW/CUxwwwPAHS72ncDrQJ8kMdTglSB3AUvdsOOA/7j5XgIO902vwLeBD4F/JVleEPgFsM19F3PxSmX++QcB1wINbptqgFnu+464/svc9N8C3nPb8RDQ3Q2fA7wI3A1UAhfR9jGy2H2XbwBVwAIg38VbB8TcemtaOU4OBp522/UucLJv3Drgu+572wbcAgR838N1wAZgs/tO8n3zfh1Y7r7vd4ARvv10idvPnwLzfPPMcPu22q377FZiTnostbU/fOOvBba45V8GRFpZxzS8v5Mf4v0NrQO+1GI956V4bJzilrUd+JFb1hQ3rhve38omYCMwt8Xf0R3AC25bnwN6txLvfOBR4E9u2teAob7xbR37E91+rAJ+g/d359+26/DOQ1uAW4GQGzfZLbMK71xwZdbPpdle4f74oXnCeRrvDzkfOMwdeKV41W5VwOfcdIfEDyB3IH6/jeVPcgfvaLzk8TPgcTduGt7J74duXDf3x7MWGA4UAIfjneymuLjuAF508w/BO2n8xk3bLcn6H8Q7mRyJV+oNA9OBwe4Avgj4GHciaLk9eNVRFcDX3PSnAx+59Z2MdzI5wI0bDxS1sh8UGOS6D8RLXme4eK7BO+GHfNP+zS032TZ90/1R9gdKgLdJknBc93y8qqz4uDnA877+M/FOwoe4/f974BbftBH3f8CNT3qMuOkX4yWKQ4CeeD8y/tv3Xa9Ntm/c+IDbjsvc93Q0sBXo58avwzuhxLd5BfB/3biLXP8gt29fBua7cce67/9Yt47DgAG+/fQn9x2PwvtbGObGfQwc67r741U1J4u7rWOprf0xHe+4+rwb9zxtJ5wI8D23by4G1vvGL6bppNzqsQEchJcAZgB5eH+LEZoSzl3u+y/y7eMZvr+Lj/HOBwV4iee6VuKdj1dN/18u3oeBh/Z27LuYNuD9CAgDl7r44tt2Dd4Px4PcPnsR+LYb9yowy3X3AsZl/Vya7RXujx/3RzbEHaC7gLBv3C3u4OmOd9I/Hd8vNN+B2FbC+Q1wra+/GO8Xd8j9Ie3CnWjd+MU0/6X5g/jB6vqL3PwDaEo4A9pY/4PA3XvZB5uBUcm2BzgbeK7F9G+42L
+A9ytyEl5VXVvr8CeB84GXfOMCeCefo33THt3Gsl4E5vj6L6TjCefvwLm+/lHAOt+0q33jWj1GfN/d1b5xPwNud93TaDvhTAbWtBj2eHw78RJOy21+znW/gDuRu/4vxeMGfgtc38Z3cqSvfylwuuve4NaR9AdEisdSW/vjgfh+c/0n0nbCqaSpRFfoYu/pW0/8pNzqsYFXgl/sG1eIVwKeAghekijxjf828KDv7+KXvnHfBJ5oJd75wFO+/ulA+d6OfeB44EPfOHHfQ3zbVgHH+MbPiG8P8C+8ElvSUlc2PnaXWvscjPfLZau7g2kn8A28i8u7gHPwfn1uEZE/xe+CSnG5/+tb5ga8Xy393fiPdc/rAht93SV4JQoAVLUGr7ogvv6Yqm7eSwz+5SEip4vIW76Y+uL98mot/uPj07rpD8f7w3wBL6HeA2wWkVtEJLyXWJJtUwxvv/j36caWM/kMcNPHbWhtwhQcDNzt27YleL8gk8XR6jHim2aLr7sW7wdCqnEMbbGfT8bb1riW2xwf12x/4lW5xPflILwqs9a0Fu9M4KvARhH5u4gcnmzmFI6l1pbf3u9wqztOUNVaNyzZvm1ruc3GueVsc70H4ZVgV/q25QagXwrbkkxr07Z17A/Ad7ypl0laHn/P+OJ7FG9/g5dYjwDWisgSETm6jdgywm4aaJ8KvPr1Xu6LbkZVFwGLRKQI7yR7A94v4D2mTbLcH6jqbS1HiMjwVub3D9uEV70Wn6c73h/0Jrxqnb2tv9nyRCQfeAyvSP+cendubcb7NdVy3fH4n1XVU5MuWPXnwM9FZDCwCO9awYN7iWcT3q++eEyCVy2zKVnMSWx208cNbm3CFFTglej+0sp4bTFtq8fIXqRynLyrqqVtTNNym+M/NDbhnYziDqZpX27AKwm3i6q+BnzZHS8/Bn6NV9JISOFYaks6v8NUl7sZL4kDICLdaEqOn+JduzxUVbenKZZk2jr28/B+IPj5+yuAr6vqWy0XqqqrgbNEJIR3TfcxOvC97wsr4bSDqlYArwA/cc9uhERkvIiMFJF+IjLDHaD1eL9Y4nc9fULbX+wDwLdFZAyAiPQWkdPaEdrjwOkicoyI5AE/Af6dQqmmNfl4B/YnLp7Laf6LvuX2LATGuV+yIRHpJiIni0gPEZkgIhPdQV6NV9UXZe+eAcaIyGlu3ivxLqq/keI2PA5c6b6XAXhVHx11P3CtiAwDEJEBInJysgnbOkZSWM8nwEHuB0MyrwEBEblERPLc5zgR8SeSS33bfAXwRzd8AXC1iAwUkd541bB/cOMeAr4hIkd7jyDJCDd/q9y6zxWRA2i64SLZ97q3Y6ktjwMXisjnRKQH3g0R6dDWsbEIOFJEpru/pR/hzpOupPEQcKuI9BSRgIgcLiKT0hRXXFvH/itAWEQuFpGwiHyL5iXc+/GOvQHuuxwiIscDiMgsETnQ1ZZUk9rfYVpZwmm/WXi/KD7A+yO6Ha+YHcD7g9jihg8Evu/muR84yhVzf9Vygar6b+A7wMMiUgW8hXcBNyWq+i5ecfkBt/4RQIcfBFPVKryLj8/iXQQ9EO8mhbhm26OqlcCX8S5gfoJ3LeFiN20PN/1OvGs5L+NddN1bDJ/iXQ+7Dq9K46t41w4aU9yMu/Hu3nkXr/7+D21O3XYsjwH3AU+77+cloK0E0toxsrf1vIt3I8QGVx3ScnwEbz9/Ce+X7Cbgf2n+d/xHvLr65cA/8I4JXPx/xbsGsxLvQvmNbrkvA5fjfU9VeDcJHLC3ePGud6zHu8D9RZIk9RSOpVap6tN43+PLeBf5F6YyXwpaPTZUdSte1fgv8f6W6mheZXUl3rWid/DuYnsY7wJ82rR17KtqA97NOZe6caXAv32z34yXlF52cT5FUwluOrBaRKrxqv5npzPuVEj7S/3GmM5IRNbhXTxekutYjEnGSjjGGGOywhKOMcaYrLAqNWOMMVlhJRxjjDFZYQnHGGNMVnwmHvzs16+fDh
06tEPz1tTUUFSU6oPeudHZY+zs8YHFmC6dPcbOHh90jRhfe+21Laraf48RuWpTJ52fo446SjvqpZde6vC82dLZY+z
2021-03-19 17:21:00 +00:00
"text/plain": [
2021-04-06 17:29:15 +01:00
"<Figure size 420x280 with 1 Axes>"
2021-03-19 17:21:00 +00:00
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
2021-03-22 20:49:29 +00:00
"# Set to True to draw one standard deviation of the test error as error bars\n",
"show_error_bars = False\n",
"\n",
"fig = plt.figure()\n",
"fig.set_dpi(fig_dpi)\n",
"\n",
"for idx, layer in enumerate(mean_param_accuracy[0, :, :]):\n",
"    node_label = f'{multi_param_nodes[idx]} Nodes'\n",
"    if show_error_bars:\n",
"        # std_param_accuracy is (test/train, nodes, epochs): index 0 selects the test split\n",
"        plt.errorbar(multi_param_epochs, 1 - layer, yerr=std_param_accuracy[0, idx], label=node_label, lw=2, capsize=4)\n",
"    else:\n",
"        plt.plot(multi_param_epochs, 1 - layer, '-', label=node_label, lw=2)\n",
"\n",
"plt.legend()\n",
"plt.grid()\n",
"plt.title(\"Test error rates for different epochs and hidden nodes\")\n",
"plt.xlabel(\"Epochs\")\n",
"plt.ylabel(\"Error Rate\")\n",
"plt.ylim(0)\n",
"\n",
"plt.tight_layout()\n",
"# plt.savefig(f'graphs/{exp1_testname}-error-rate-curves.png')\n",
"plt.show()"
]
},
2021-03-22 20:49:29 +00:00
{
"cell_type": "markdown",
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp1"
]
},
2021-03-22 20:49:29 +00:00
"source": [
"### Test/Train Error Over Epochs for Each Node Count"
]
},
{
"cell_type": "code",
2021-04-06 17:29:15 +01:00
"execution_count": 125,
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp1"
]
},
2021-03-22 20:49:29 +00:00
"outputs": [
{
"data": {
2021-04-06 17:29:15 +01:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAikAAAIpCAYAAABnk6geAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAArEAAAKxAFmbYLUAACD3klEQVR4nO3deXxU9b3/8dcnGwESkEVlU8GlbihQlVDroLXuYrUuqVbb2pbuVu+t2F/Vamnr0l7R1u12ubbXWuuSWvXibhcgY9VQ0OAGuGIliigICUtCls/vj3MCk5BkJskkM2d4Px+P88jMOWfOfD+zfPI533Pme8zdEREREck2eZlugIiIiEhHVKSIiIhIVlKRIiIiIllJRYqIiIhkJRUpIiIikpVUpIiIiEhWUpEikoSZ3W5mP8x0O0Qkesxsvpmdl+l2RJWKlF4ysxVmtsnMNiRM3+7H53cz2xg+77+78880fOy4NLdntpk1tns9TkvDdr9vZq+YWZ2ZvWpmX07SBjez4xPmTTOzFb1th0g6KX9st82+yh/DzezPZrbWzN43s5vNLL+TdW83s2Yz2z9h3tlmNr+37ZDuU5GSHse5e0nC9N/tVzCzglTmdcYCnb1f+7p7CfBZ4FIzOy7llveNP7R7PR7szoM7eV0c+DywE3Am8DMz+2QXm/kI+FF3nlckQ5Q/2uqL/PFjYBCwGzARmA7M7GIz64EruvO80jdUpPShsJvvp2a2CNhoZsea2etm9mMz+xD4sZkNM7O7zexDM3vDzL6R8PjbzewWM/sHsAnYq6vnc/fFwMvA5IRt3G9mq8M9iD+b2fBw/pPhKsvDvZVYOP87ZvZa2J4/mNngcP7HzOwpM6sN90Su68Hr8W0ze9PMPjCzO81saDj/qPavSwexXefu1e7e7O4vAH8HpnXxdHOBXTtLuGZ2oJnFzWydmS1OLHjMbC8z+2fYa/MXYGC7x/bZayTSSvlju9ejx/kDGA886O4b3f0D4AnggC6e7vfACWa2XydtiZnZ82H+WGBte10OM7MXwlh/TcL/WTPLD9v5dvg6XG9hUWVBb+/z4eNqzOw/u/kS5SQVKX3vHOBsYCjQRPBlaQZGA1cBt4Tr7U6wJ3OVmR2Z8PizgUuAUmBFV09kZmUEewlvJMy+H5gQTqXAlQDu3vrPe99wbyVuZmcB3wSOIdjjKGTbF/4nwCNhHHsCf04x/ta2HUuwZzKD4DUYCNyYsMp42r4uXW2rkKBAebmL1ZqAq+mgN8XMioCHCGLYGfgv4CEzGxaucjewABgB/BE4PeGxffYaiXRA+YO05I/fAKeYWamZjQZOBP7axVOuI3htt+tNMbMRBDtBPyXIHw8Dc82sIMwt9wO/IsgfLwOHJzz8e0AMOBTYF/g4wWsG8EtgjrsPIXgf5nfRvh2Hu2vqxUTwxa8j+FC3TkeGy+YDlyasexSwESgI7+cDW4AJCetcC/wmvH176+0unt8JuiY3hbdvBvI6Wfd4YFG7x45LuP848PmE+xOBFeHtPwK/BkYnac9soCHhtVgVzv8d8OOE9fYFNgPW/nVJ4TW/CfgHYF204TagAHgTOJagqGmNJdZ6O+ExzxD8Q9gDqAeKE5Y9BfwwXa+RJk2tk/LHds/RJ/kDGBe+nk1hu2/rYt3bgR8CwwgOG+9LUOzND5d/AViQsH4eUAN8AjgSeCthmQHvAOeF95cBhycsn5Gw3TjBTtXwTH8us2lST0p6nOjuOyVMCxKWrWy37ip3bwpvjyTY2/h3wvK3gTFdPL4jBwIlwAUEx1oLITg2a2a/DLsWa4H7CKr7zuwO/CbswlxH8M9553DZ94EioDrskjyli+38MeG1GBXOG9NBnMXA8PB+4uvSKTO7FDgaONPDb3Znwu1dw/a9KWMIEkei1td9NPCBu9cnLEtcN12vkUgr5Y+2+iJ/VACLwzhHAR8zswu7WB93/wi4le17U9q0xd1bCHJEa/5YmbDMafse7A
48lvAa/QnYJVw2k+C9eD08NPaJrtq3o1CR0vfa/yNNvP8h0EjwwW21O/BuF4/v+EncW9z91nCb3wpnn0uwl3G4B12IZxJU9p2pAb7ULmEODrf/nrt/heALPhuoMLPiVNoWepft46wH1raGkGwDZvYdgi/yce6+Ntn6oT8Q7EUd064tu7Vbr/V1fw8Y2S62xHX78jUSaU/5I9Db/DEJ+K2717v7+wRFy6dTeN4bgJMJelM6bIuZGUGOaM0f7X/xlHi/BvhUwusz1N0PAHD35e5eTlC03ENw2HmHpyIlg9y9mWDv5CozG2RmE4GvEnxAe+o6YFZ4bLSU4Iv8kZmNBGa1W3c1wbHcVr8HLjOzvQDMbLSZnRDePtPMxoR7BusIkkJKCTB0L/A1M9s/PJnuaqAiWW9IKzP7InAZQYHybrL1W7l7I0FvSuJJaFXhNi8I9xbPAvYHHnf3t4GXgB+aWaGZfQaYmvDYvnyNRFKm/JF6/gAWAV8Jv9MjCAquF5M9KNwZ+hWQ2OvyGDDJzE4NT3r9T4JDT4sIDhsXmtnXw+f6DkHvSqvfE7xfoy0wvvUcIjM718xGhD1CdQTn2OzwVKSkx5PW9nf9V3fjsRcQnDvxDsHJWLPdfV5PG+LujxMkgS8AdxAcU32f4Hjn4+1W/wnwl7Dr8Qh3v5vg2O8jYffuAradAT8VWGxmGwi+tOe4e0M32vUkwfHyRwm6ahuB/+hGaD8m6DpekvA6X5biY28HNiS0ZQvwGYJzUNYAlwKfCbt3Ifip86cJ9tLOBx5IeGyfvUayw1L+SN6u3uaPrwJTCAqrpQTnAl2b4mOvJzhU1dqWD4HTCHLSGoITlk9z98Ywt5wBfDdcdjDwdMK2riMoZP5JcC7QQ2zrqT2J4NdSdQRF0Re7EV/OstQLUREREZH+o54UERERyUopFSlmNsPMllswSM92o/RZMOjQMjOrDqeBHW1HRHYsyh0i0htJD/eEJwa9AnyK4BjaYoKzvdckrDMfuMDdX+q7popIlCh3iEhvpdKTMhV42d1r3H0DwZnNmb62g4hkP+UOEemVVIqUMQS/7W5VA4ztYL27wkF6vpeWlolI1Cl3iEivpHwVzSTOdfcaCy74NNfMlrv7I4krhMejZwIMGjSobPz48Uk32tzcTH5+h1fTjrRcjQtyNzbFBa+88sr7CSOApotyRzfkalyQu7HlalyQemy9yh3Jxs0nuDjSAwn3f0nC9Rk6WP87BL/V73SbZWVlnooFCxaktF7U5Gpc7rkbm+JyB571blxzQ7kj/XI1LvfcjS1X43JPPbbu5o7EKZXDPQuBiWY21sxKCK4e+UTrwnDEzpHh7aJweVdXpxWRHYNyh4j0StLDPe7eZGYXA/MIzmH5L3dfY2aPEnTBrgeeMLNCgqtyPkQwVLOI7MCUO0Skt1I6J8Xd5xIMuZw476SEu4eks1EikhuUO0SkN9J14qzIDqGpqYmVK1dSX1+f6aakRWlpKcuWLWszr7i4mHHjxlFQoPQgki65ljugf/KHspBIN6xcuZLS0lL22GMPgiu0R1tdXR2lpaVb77s7a9euZeXKlaTyKxoRSU2u5Q7on/yha/eIdEN9fT3Dhw/PmSTTnpkxfPjwnNrbE8kGuZ47oG/yh4oUkW7K5SQDuR+fSKbsCN+tdMeowz0iEbFlyxamTp0KwKpVqygoKGDkyJEMGjSIp59+Ounjb7/9dk466SR22WWXvm6qiGSRKOcO9aSI9JE5cyAebzsvHg/m90RRURHV1dVUV1fzzW9+kx/84AdUV1enlGQgSDSrV6/u2ZOLSL9R7thGRYpIHykrg/LybckmHg/ul5Wl7zkWLVrEkUceySGHHMIpp5zC2rVrAbjkkkvYd999mTRpEldddRUPPPAAixYt4swzz+TQQw9NXwNEJO2UO7bR4R6RPhKLQUVFkFwuvBBuuim4H4ulZ/vuzsUXX8wDDzzA8OHD+f3vf8+1117LD37wA+69915WrFhBXl
4e69evZ+jQoRx66KHccsstTJw4MT0NEJE+odyxjYoUkV4YPBiam7tep6kJLrsM8vPh2GM7Xy8/HzZuTP25zYwlS5Z
2021-03-22 20:49:29 +00:00
"text/plain": [
2021-04-06 17:29:15 +01:00
"<Figure size 560x560 with 6 Axes>"
2021-03-22 20:49:29 +00:00
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# one panel per hidden-node setting, two panels per row\n",
"n_rows = math.ceil(len(multi_param_nodes) / 2)\n",
"fig, axes = plt.subplots(n_rows, 2, figsize=(8, 8 * n_rows / 3), squeeze=False)\n",
"fig.set_dpi(fig_dpi)\n",
"\n",
"# shared y-limit (mean error + std headroom) so panels are directly comparable\n",
"error_ylim = np.round(np.max(1 - mean_param_accuracy + std_param_accuracy) + 0.05, 1)\n",
"\n",
"for idx, (nodes, ax) in enumerate(zip(multi_param_nodes, axes.flatten())):\n",
"    ax.set_title(f'Error Rates For {nodes} Nodes')\n",
"    ax.plot(multi_param_epochs, 1 - mean_param_accuracy[0, idx, :], 'x', ls='-', lw=1, label='Test', c=(0, 0, 1))\n",
"    ax.plot(multi_param_epochs, 1 - mean_param_accuracy[1, idx, :], 'x', ls='-', lw=1, label='Train', c=(1, 0, 0))\n",
"    ax.set_ylim(0, error_ylim)\n",
"    ax.set_xlabel('Epochs')\n",
"    ax.set_ylabel('Error Rate')\n",
"    ax.legend()\n",
"    ax.grid()\n",
"\n",
"# with an odd number of node settings the last panel is unused; hide its empty axes\n",
"for ax in axes.flatten()[len(multi_param_nodes):]:\n",
"    ax.set_axis_off()\n",
"\n",
"fig.tight_layout()\n",
2021-03-26 20:01:05 +00:00
"# fig.savefig(f'graphs/{exp1_testname}-test-train-error-rate.png')"
2021-03-22 20:49:29 +00:00
]
},
{
"cell_type": "code",
2021-04-06 17:29:15 +01:00
"execution_count": 126,
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp1"
]
},
2021-03-22 20:49:29 +00:00
"outputs": [
{
"data": {
2021-04-06 17:29:15 +01:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAikAAAIpCAYAAABnk6geAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAArEAAAKxAFmbYLUAACEnUlEQVR4nO3deZwcVbn/8c/T3TNJJpmETBayAWFRLhIJXCATkQmIAgIBkSVXRH9Gb/SKIlwlqKBiUAS8BC4iuNyL3oCsEQgCsmvIDEImBkmAsBMDJCH7MpN1Zrqf3x9VnelMZunZ0lWT7/v1qqTq1NLndHWfefrUqVPm7oiIiIhETaLQGRARERFpjoIUERERiSQFKSIiIhJJClJEREQkkhSkiIiISCQpSBEREZFIUpAieTGzyWb2dKHz0dOZmZvZqELnQ6SrqO7ofmY22swaCp2P7rDHBSlmtsTMtpjZppzpG7vx9d3MNoev+56Z/bCd+3boD5iZDTSz28xspZnVmNkiM5scruvUBzyshBrMrDYs15tmdqOZDeroMfN83abnckUXHHOomd1jZh+Y2QYze9rMDmkjD++aWVFO2m/MbFpn8yLRorpDdUcex/2kmS0Iy/OqmZ3RyrZuZs83SXs8+95KYI8LUkInuXu/nOlXTTcws1Q+aS2xQEvv78Hu3g/4LHCZmZ2Ud8477r8JzveHgYHA54GVXXj8Z9y9NDz2OcDBwPNmVtqFr9Gc3HM5rD07tnCO+gFzgcOBQcATwJ/aOFQp8OX2vLbEluoO1R3NniMzSwL3Ebxf/YHvAHeb2cBWDnXwbjqHsbWnBinNMrNnzOynZjYf2GxmJ5rZ22Z2pZmtAa4Mf1XcbWZrzOwdM/uPnP1nmNnNZvZXYAtwYGuv5+4vAIsI/iBmj/GAma0ys3Vm9kczKwvTnww3eSOM/CvC9G+a2Vthfm4zs74tvNzRwJ3uvtHd0+6+0N0fC9c9CSRzflXsa2Z9zezOsDXhH8CH8nkP3b3e3V8CzgZKgK+E+UyG7+O74S+y680sZWZ9wl9n++W8B8eZ2dv5vF5zOnOO3H2xu9/o7ivdPQ3cBBzUxi+7/wYut5zWlCb5+YaZLTaz1WZ2h5kNyFn3FTN738xWmNnXmuxXZmZ3hZ+HxWb2pSb7vRv+YnvDzI7P/x2Srqa6Q3UHsFc43e2Bx4HNwOhWXvK/gR+3kJdEWO73LWjZvcnMeuWsvzx8P5YAZzTZd18z+7OZrTWz18zs0znrfhAer8bMXjazj7T9zhSQu+9RE7AEOLaFdc8AbwMHAb2BTwANwBVAEdAHuBO4m+BLdBiwGjgu3H8GsAY4EkgBRc28hgOjwvlygg/72TnrvwD0BQYAjwM3NrdvuHwu8DKwX5i3u4DpLZTt98AC4IvA/k3WjQYamqT9F/A0wS+CfwHeB55u4diTm1sXvh8zw/lLgb8CQwi+yLOBC8N1dwKX5uz3K+BnHT2XnT1HTY71KeCDtvIAVAJfDdN+A0wL508EPgA+Ep7X+4EZ4boxQE34OegD3N7k8/FnYDrQKzwHy8Py9A33+1C43X5Nz6mmrp9a+ryF655BdUc2bY+tO8L38d+BJDAReBfo3UIePHz/FgMnhmmPA5PD+a8SBKKjCFp1/0ZjvXIqsIygdWuv8P1uCNclgIXARWE+PxaWY++c8zEMsHB5WKG/W62eq0JnYLcXOPhw1gIbcqbsh/AZ4LKcbY8niIRT4XISqMv9ogLXAL/N+RD/to3Xd2AjQQXjwC+BRAvbngzMb7JvbkXzOPD5nOUxwJIWjlVCUGEuBNIEFdT4cN1odq1o/gkcn7N8Fe2vaK4FngrnXweOyVk3kaCZF4JfAX/PeY9XAod14Fze0BXnKGe/wcA7wJfayMOxwCfD96yInYOU3wFX5mx/MLA1rCB+TBiwhOsOyp5jgkpkMzkVIUHAMo3gD9FG4EygV6G/U3vK1MznbQOqO0ajuiP3uGcSBDIN4X
k6sY3zOYogqHk257xMDuf/Any5yTl9I5z/P8I6Jlz+FI1BynjgzSavdV/4Xh8ErCIIolOF/k7lM+2pl3tOcfe9cqY5OeuWNtl2hbtnO4YNJvgj9F7O+neBEa3s35xDCfo+XAhMCI9J2IR5Y9isWUPwwWrtMsO+wG/DZtUNwLMEvzZ24e5b3P0n7j4WGArMBx6wlq99DyeIuLPeb2G71gwH1ufk9bGcvN4Z5gOCL+ZBZnYAQeW+zoNm33zknsvv0EXnyILr4Y8B97r7bW1t7+5/Ifhl86Umq0Y0k5feQBmtv8f7htutznnP/oPgV89m4DyCX0orw6b93PJJ91HdobqjWRZ0sP8DQX+hYoLWjjvNbGQb+bgdGGlmn2qS3lzdkc1LW3XH/tn3K3zPPg0Md/e3gUuAqwnqjlvNrH8b+SuoPTVIaY23srwGqCf4EGTtS9AM39L+zb+Ie8bdbwmPeUGYfD7BF+0Yd+9P0InMWjnMMoJf+bmVZkvXlXNfey1wPcEHvayFPH8A7JOzvE8z27TIzPoQtC78LSevn8jJ5wB3/0iYnzrgQWBSON3bntdqotPnKMz7I8AL7n55O177SuBywj8coeXN5GUbsI7W3+NlwCZgYM57VuruXwdw90fd/QSCX2LbCSodKSzVHYE9te4YA7zi7lXhOXqGIMgY19qLuns9wff3x01WNVd3ZPPSVt3xWpNz28/drwlf7w/u/jGCVt3RBB18I0tBSjt40JHyPuAqMysxszEETXX3dOKw1wFTzayY4C6RbcB6MxsMTG2y7Sp27oT1e4IOmwcCmNnw3A5Suczsh2b2r2ZWZGb9CH6VL3b3NQRfzoTtfIvifeGx+5vZwcD/y6cw4fHHAH8My/J/OXm9KsyjWXDr4nE5u95DcNfAWXSiounsObKg8+v9BJVBu24vdfengBUETb5Z9wJfNbNDwo6JPyO41u7h65xlZkeHFfMPc461DHg+pxyp8Px9xMz2NrOJ4T7bCZqV0+3Jq+xeqjvaFve6g6DfzqFm9rGwPMcS9Pl4NY99ZxAEGkfnpN0LXGJmIy3oBP2jnLzcB0wxsw9Z0BH/uzn7VROckwvMrDicKsLOtAeb2fHhZ2YLQf0R6bpjTw1SnrSdxzr4WTv2vZCgM9L7wEME1wVndzQjHvQA30DQKe12gibOlUAVQVNmrp8A94dNeMe6+90EfR7+HDbxziHooNkcA+4g+AW/hKDH/ZlhHjYTXANeEB57X4JWgbVhOe8maMZszfFmtiksyyyCvhwfc/eacP11BH90/0ZwXf1hdo7+/0Lw62yFu78GEH6xNrXxus3pzDk6BjgF+AxQk/MZ2beN/bKuJPiFCYC7P0lwXftRgubaeuA/w3WvEPyKmUVwTp5rcqzzCVpKFhP8kbmRoJNjgqBSWhmmjyQnwJFupbpDdUez3P0tgh82M8ysliDIutDd38hj33qCeqIsJ/l3BO/HPIJAZ2G4De7+Z+C3BO/JSwQtv9ljNQCnEfRhWUbwg+sHBPVGL4L3cy1BK89GgjuMIsuCH3QiIiIi0bKntqSIiIhIxOUVpITXv9+wYOCfKU3WlZjZY2b2ugXDJX8rZ91gM5sd7veAmfXu6gKISHSp7hCRzmjzco8Fwzm/SnBf9UbgBYIe5GvD9SXA0e4+J+xUNR+Y6O5vm9l0gnvvb86d78byiEhEqO4Qkc7KpyVlHLDI3Ze5+yaCsSN2PGsgvId+Tji/CXiDoBMTBAPtZDtN3QGc3lUZF5HIU90hIp2ST5AygqCHcNYygrsJdmFm+xAMJfyPMGmAu29saz8R6ZFUd4hIp+T9ZM62WPDgo3sJnqOwuR37TQGmAJSUlJSPHj26zX3S6TTJZLKDOY0mlSke9uQyvfrqqyu9nU+LzYfqjs5RmeJhTy5Tp+oOb/sZB8cAs3KWbyTnmQ9hmhFUMj9skv4mwS8iCJ7W+WRrr1VeXu75mDNnTl7bxYnKFA97cpmAud6OZ26o7tg9VKZ42J
PL1N66I3fK53LPPGBMOOpdP4KBrp5oss01wBZ3v6pJ+iMEAw1B8ITOh/N4PRHpGVR3iEintBmkeDB63SUEj8deAFz
2021-03-22 20:49:29 +00:00
"text/plain": [
2021-04-06 17:29:15 +01:00
"<Figure size 560x560 with 6 Axes>"
2021-03-22 20:49:29 +00:00
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# one panel per hidden-node setting, two panels per row\n",
"n_rows = math.ceil(len(multi_param_nodes) / 2)\n",
"fig, axes = plt.subplots(n_rows, 2, figsize=(8, 8 * n_rows / 3), squeeze=False)\n",
"fig.set_dpi(fig_dpi)\n",
"\n",
"# shared y-limit so the panels are directly comparable\n",
"std_ylim = np.round(np.max(std_param_accuracy) + 0.05, 1)\n",
"\n",
"for idx, (nodes, ax) in enumerate(zip(multi_param_nodes, axes.flatten())):\n",
"    ax.set_title(f'Error Rate Std Dev. For {nodes} Nodes')\n",
"    ax.plot(multi_param_epochs, std_param_accuracy[0, idx, :], 'x', ls='-', lw=1, label='Test', c=(0, 0, 1))\n",
"    ax.plot(multi_param_epochs, std_param_accuracy[1, idx, :], 'x', ls='-', lw=1, label='Train', c=(1, 0, 0))\n",
"    ax.set_ylim(0, std_ylim)\n",
"    ax.set_xlabel('Epochs')\n",
"    ax.set_ylabel('Std Dev.')\n",
"    ax.legend()\n",
"    ax.grid()\n",
"\n",
"# with an odd number of node settings the last panel is unused; hide its empty axes\n",
"for ax in axes.flatten()[len(multi_param_nodes):]:\n",
"    ax.set_axis_off()\n",
"\n",
"fig.tight_layout()\n",
2021-03-26 20:01:05 +00:00
"# fig.savefig(f'graphs/{exp1_testname}-test-train-error-rate-std.png')"
2021-03-22 20:49:29 +00:00
]
},
2021-03-19 17:21:00 +00:00
{
"cell_type": "markdown",
"metadata": {
2021-03-26 20:01:05 +00:00
"id": "eUPJuxUtVUc3",
"tags": [
"exp2"
]
2021-03-19 17:21:00 +00:00
},
"source": [
"# Experiment 2\n",
"\n",
"For the cancer dataset, choose appropriate values for the number of nodes and epochs (based on Exp 1) and use an ensemble of individual (base) classifiers with random starting weights and a majority vote to see if performance improves. Repeat the majority-vote ensemble at least thirty times with different 50/50 splits, then average and graph the results (each classifier in the ensemble sees the same training patterns). Repeat for different odd numbers (which prevent tied votes) of individual classifiers between 3 and 25, and comment on individual classifier accuracy vs ensemble accuracy as the number of base classifiers varies. Consider changing the number of nodes/epochs (both less complex and more complex) to see if you obtain better performance, and comment on why the optimal node/epoch combination may differ for an ensemble compared with the base classifier of Exp 1.\n",
"\n",
"(Hint 4: to implement majority vote you need to determine the predicted class labels - it is probably easier to implement this yourself than to use the MATLAB ensemble functions)\n"
2021-03-19 17:21:00 +00:00
]
},
{
"cell_type": "code",
2021-04-09 12:42:18 +01:00
"execution_count": 7,
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp2",
"exp-func"
]
},
2021-03-19 17:21:00 +00:00
"outputs": [],
"source": [
2021-03-27 16:29:31 +00:00
"num_models=[1, 3, 9, 15, 25]\n",
"\n",
"def evaluate_ensemble_vote(hidden_nodes=16,\n",
"                           epochs=50,\n",
"                           batch_size=128,\n",
"                           optimizer=lambda: 'sgd',\n",
"                           weight_init=lambda: 'glorot_uniform',\n",
"                           loss=lambda: 'categorical_crossentropy',\n",
"                           metrics=['accuracy'],\n",
"                           callbacks=None,\n",
"                           validation_split=None,\n",
"                           round_predictions=True,\n",
"\n",
"                           nmodels=num_models,\n",
"                           tboard=True,\n",
"                           exp='2',\n",
"\n",
"                           verbose=0,\n",
"                           print_params=True,\n",
"                           return_model=True,\n",
"\n",
"                           # NOTE: these data defaults are bound once, when this cell runs; later\n",
"                           # re-splits of data_train etc. are only used if passed explicitly\n",
"                           dtrain=data_train,\n",
"                           dtest=data_test,\n",
"                           ltrain=labels_train,\n",
"                           ltest=labels_test):\n",
"    \"\"\"Train and evaluate majority-vote ensembles, yielding one result dict per ensemble size.\n",
"\n",
"    For each ``m`` in ``nmodels``, ``m`` models are built with ``get_model``, compiled, trained on\n",
"    ``dtrain``/``ltrain`` and evaluated on ``dtest``/``ltest``.\n",
"\n",
"    hidden_nodes / epochs: an int applies the same value to every model; a ``(low, high)`` tuple\n",
"        spreads values linearly across the ensemble (the midpoint is used when ``m == 1``).\n",
"    optimizer / loss: zero-argument factories, called once per model.\n",
"    weight_init: forwarded unchanged to ``get_model``.\n",
"    batch_size: forwarded to ``model.fit``.\n",
"    callbacks: optional list of zero-argument callback factories, instantiated once per ensemble\n",
"        size; a TensorBoard callback is appended when ``tboard`` is True.\n",
"    round_predictions: if True the ensemble decision uses the sum of element-wise rounded\n",
"        per-model outputs (a vote count), otherwise the sum of the raw outputs.\n",
"\n",
"    Yields a dict with keys: epochs, num_models, nodes, predictions, history, optimizer,\n",
"    model_config, loss, round_predictions, accuracy, agreement, individual_accuracy and,\n",
"    when ``return_model`` is True, models.\n",
"    \"\"\"\n",
"    for m in nmodels:\n",
"        if print_params:\n",
"            print(f\"Models: {m}\")\n",
"\n",
"        response = {\"epochs\": list(),\n",
"                    \"num_models\": m}\n",
"\n",
"        ###################\n",
"        ## GET MODELS\n",
"        ###################\n",
"        if isinstance(hidden_nodes, tuple):  # range of hidden nodes, spread across the ensemble\n",
"            if m == 1:\n",
"                node_counts = [int(np.mean(hidden_nodes))]  # midpoint, not lower bound, for a single model\n",
"            else:\n",
"                node_counts = [int(i) for i in np.linspace(*hidden_nodes, num=m)]\n",
"            response[\"nodes\"] = node_counts\n",
"        else:  # not a range of nodes, every model uses the given value\n",
"            node_counts = [hidden_nodes] * m\n",
"            response[\"nodes\"] = hidden_nodes\n",
"        models = [get_model(n, weight_init=weight_init) for n in node_counts]\n",
"\n",
"        for model in models:\n",
"            model.compile(\n",
"                optimizer=optimizer(),\n",
"                loss=loss(),\n",
"                metrics=metrics\n",
"            )\n",
"\n",
"        # user callbacks first, then TensorBoard; the list always exists, even when tboard=False\n",
"        cb = [make_cb() for make_cb in callbacks] if callbacks is not None else []\n",
"        if tboard:\n",
"            cb.append(tensorboard_callback(prefix=f'exp{exp}-{m}-'))\n",
"\n",
"        ###################\n",
"        ## TRAIN MODELS\n",
"        ###################\n",
"        if isinstance(epochs, tuple):  # range of epochs, spread across the ensemble\n",
"            if m == 1:\n",
"                epoch_counts = [int(np.mean(epochs))]  # midpoint, not lower bound, for a single model\n",
"            else:\n",
"                epoch_counts = [int(e) for e in np.linspace(*epochs, num=m)]\n",
"        else:  # not a range of epochs, every model uses the given value\n",
"            epoch_counts = [epochs] * m\n",
"\n",
"        histories = list()\n",
"        for model, e in zip(models, epoch_counts):\n",
"            history = model.fit(dtrain.to_numpy(),\n",
"                                ltrain.to_numpy(),\n",
"                                epochs=e,\n",
"                                batch_size=batch_size,\n",
"                                verbose=verbose,\n",
"                                callbacks=cb,\n",
"                                validation_split=validation_split)\n",
"            histories.append(history.history)\n",
"            response[\"epochs\"].append(e)\n",
"\n",
"        ########################\n",
"        ## FEEDFORWARD TEST\n",
"        ########################\n",
"        # TEST DATA PREDICTIONS\n",
"        response[\"predictions\"] = [model(dtest.to_numpy()) for model in models]\n",
"        # TEST LABEL TENSOR\n",
"        ltest_tensor = tf.constant(ltest.to_numpy())\n",
"\n",
"        ########################\n",
"        ## ENSEMBLE ACCURACY\n",
"        ########################\n",
"        # round each model's outputs and sum over the ensemble -> per-class vote counts\n",
"        ensem_sum_rounded = sum(tf.math.round(pred) for pred in response[\"predictions\"])\n",
"        # raw (unrounded) outputs summed over the ensemble\n",
"        ensem_sum = sum(response[\"predictions\"])\n",
"        # the ensemble's predicted class is the argmax of the chosen sum\n",
"\n",
"        correct = 0  # number of correct ensemble predictions\n",
"        correct_num_models = 0.0  # over correct ensemble predictions: summed proportion of models that were right\n",
"        individual_accuracy = 0.0  # over all test samples: summed proportion of models that were right\n",
"\n",
"        # pc = predicted class, pcr = rounded predicted class, gt = ground truth\n",
"        for pc, pcr, gt in zip(ensem_sum, ensem_sum_rounded, ltest_tensor):\n",
"            gt_argmax = tf.math.argmax(gt)\n",
"            pred_val = pcr if round_predictions else pc\n",
"\n",
"            # use the rounded votes so the count divides evenly by m; float() keeps results as plain numbers\n",
"            correct_models = float(pcr[gt_argmax]) / m\n",
"            individual_accuracy += correct_models\n",
"\n",
"            if tf.math.argmax(pred_val) == gt_argmax:  # ENSEMBLE EVALUATE HERE\n",
"                correct += 1\n",
"                correct_num_models += correct_models\n",
"\n",
"        ########################\n",
"        ## RESULTS\n",
"        ########################\n",
"        # optimizer/config/loss are reported from the last model (node counts may differ between models)\n",
"        last_model = models[-1]\n",
"        response.update({\n",
"            \"history\": histories,\n",
"            \"optimizer\": last_model.optimizer.get_config(),\n",
"            \"model_config\": json.loads(last_model.to_json()),\n",
"            \"loss\": last_model.loss,\n",
"            \"round_predictions\": round_predictions,\n",
"\n",
"            # fraction of test samples the ensemble classified correctly\n",
"            \"accuracy\": correct / len(ltest),\n",
"            # over correctly classified samples: mean proportion of models that were right (NaN if none correct)\n",
"            \"agreement\": correct_num_models / correct if correct else np.nan,\n",
"            # mean proportion of individual models classifying each test sample correctly\n",
"            \"individual_accuracy\": individual_accuracy / len(ltest)\n",
"        })\n",
"\n",
"        if return_model:\n",
"            response[\"models\"] = models\n",
"\n",
"        yield response"
]
},
{
"cell_type": "markdown",
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp2"
]
},
2021-03-19 17:21:00 +00:00
"source": [
"## Single Iteration\n",
"Run a single iteration of ensemble model investigations"
]
},
{
"cell_type": "code",
2021-04-06 17:29:15 +01:00
"execution_count": 20,
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp2"
]
},
2021-03-19 17:21:00 +00:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Models: 1\n",
2021-04-06 17:29:15 +01:00
"13 [50]\n",
2021-03-19 17:21:00 +00:00
"Models: 3\n",
2021-04-06 17:29:15 +01:00
"[1, 13, 25] [50, 50, 50]\n",
"Models: 9\n",
2021-04-06 17:29:15 +01:00
"[1, 4, 7, 10, 13, 16, 19, 22, 25] [50, 50, 50, 50, 50, 50, 50, 50, 50]\n",
"Models: 15\n",
2021-04-06 17:29:15 +01:00
"[1, 2, 4, 6, 7, 9, 11, 13, 14, 16, 18, 19, 21, 23, 25] [50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50]\n",
"Models: 25\n",
"[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25] [50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50]\n"
2021-03-19 17:21:00 +00:00
]
}
],
"source": [
"# Alternative configuration: vary epochs instead of nodes, e.g. epochs=(5, 300)\n",
"single_ensem_results = []\n",
"ensemble_runs = evaluate_ensemble_vote(hidden_nodes=(1, 25),\n",
"                                       optimizer=lambda: tf.keras.optimizers.SGD(learning_rate=0.02))\n",
"for run in ensemble_runs:\n",
"    single_ensem_results.append(run)\n",
"    print(run[\"nodes\"], run[\"epochs\"])"
2021-03-19 17:21:00 +00:00
]
},
{
"cell_type": "code",
2021-04-06 17:29:15 +01:00
"execution_count": 23,
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp2"
]
},
2021-03-19 17:21:00 +00:00
"outputs": [
{
"data": {
2021-04-06 17:29:15 +01:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAeQAAAFECAYAAAD2sk0XAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAArEAAAKxAFmbYLUAABIYklEQVR4nO3deXyU5b3//9dnskI2CMi+uqAgYQ2ItOLWKu7UI61VW6lVa1tr/dalP2tbbc/x9FRxqVvVelrUqrV6KuJuXRB3BEVBFFEB2UG2JEASMnP9/rjvmcxMJskkzCQT8n4+HvPI3Nd9zz3XXDOZ91zXvZlzDhEREWlfgfaugIiIiCiQRUREMoICWUREJAMokEVERDKAAllERCQDKJBFREQygAJZJI3MbLiZLTazSjP7j/auTzwz+5WZ3d7e9Qgzs7lmdk4bPt8MM3sxyWWdmQ1Id51SzcxWmtnXG5mX9OuX9FMg7+PMrCrq5sxsZ9T0oBaua5aZ/bqJ+UeZWSjuOavMrGTvX0lS9ZthZnX+c1aY2Xwz+1oLHpuOL6YrgCedc0XOuf/bmxWZ2RAzq0tQ3uoQc879t3Pu4r2pV1PM7FozuzdF65rlf4aPjit/2S8fkornaQ0/9HbFfe5/0l71kY5JgbyPc84Vhm9ADXBoVNmXaXjKL6Kf07/tiF/IzLLjps3Mkv48xj8+ylz/tZYC/wb+2ZLKp8EgYGlLH9TE60uZtniONFgOnBWeMLP+wFCgtt1qVO+4uM/9ne1dIelYFMidlJmVmtlDZrbJzL4ws3Oj5p1nZqv8YdZlfs/3XOBs4Df+r/+7Wvh8Q/ze60VmthaY5fd4bjezl4FdwAFmdoSZvW9m283sVTMbHrUOZ2YXm9kK4JWmns85Vwc8CPQzs17+4w8zs3f93vMqM/uZX74/cBdwlP/aPmpNGyV4zc8CRwP3+uvtYWYDzewZM9tmZkvN7LSo5eea2X+a2QJgp5nltKSN/XUEzOx3ZrbazNab2a1mlufPm2Fmr5jZ3Wa2A/hBdA/Wfy+ie3hBM5vhz2vuffmxma0ws6/M7Cq//CjgV8C5/vqe9ct/5bddhZm9ZWajWvASHwdOCr8mvHD+BxA55aCZdTezh/26fG5mP4qaV2BmD/qv4z3goLj2O9LMFvrz55rZAS2oW0L+en5nZgv81/xI1HsyzMxe98s3mtkNUY87w8w+MrOtZjYn6nN8lJl95r93W/12P9zMfmRm68xsrZmdGFeNyWb2qZltMbOZ1siPXzMrM7N5/udzoZmV++UB/7P0ld8275pZz71tG4njnNOtk9yAamCIf/9pYCaQBxwCrANGAQVABXCQv9xgYKh/fxbw6ybWfxTwWSPzhuB9ad4F5ANd/PV9BYwHsoHewDbgdCAHb7h3OZDtr8MBTwDFQJcEzzEDeNG/nwv8D7AVyPHLxvm3AFAO7ADGxj82an0tbqMEdZoLnBM1/Tpwg7/Oo4BK4MCoZT8DDvTbyBK0YV1TzwFcAHwEDAB6AG8A10a9xjr/b8B/D64F7m3kvVyH18PvkcT78ihQCIzE+5wd4M9rsH5/Pfv56/o9sKix9op73Czg13ijHt/yyxYBhxL72X4QeBjo6r9fm4Ej/XnXAy/ifYYOAVZT/5kZ6C97BJAF/Ax4N+r5HTCgkbqtBL7exGfgY/9z0s1/f37gz/sHcBVgeJ+riX75RGANUOa30/XAY1HvTR1wGd7/zTXAl8BNeJ/7GcCquLq9D/QB+vnP/8ME/zOFwFrgP/zXP81fbz4wFVjgt1sW3v9RYXt/p+1rt3avgG5t+Gb7X1r+P+ZO/KDy5830vzwL8IJqGpAX9/hZNB/IQWB71O0jf94Q/wutb9z67o6a/h7watR0wP+CONyfduH7jTz/DGCP/7x1/uv4RhPLPwxcHPXYF6PmtaqNEjzHXOrDcqD/HnSJq8NVUcte1cS6wm24Pe5WF/UcL+F/2fvTxwPLol7jsrh1Xk
vDwByIF8bHtOB9GR81fz4wrbH1xz1XPhDC/3InuUA+DXgML4gXxX22s/CGr4dGPe4P+J8zYAVwVNS8/6I+kP4/4J6459xMfdA3F8iVce/LkVGv6bKoZa8HbvHvP4D3I7Vv3PruAn4VNV2E99nOxvs/247/gw0Y7tet1J/u4k93i6rbjKh1nQ+8EP+5B84Ml0ctu8B/vmOBZXg/FCxRG+i29zcNWXdOg/C+CDf7w0/bgR8BfZxzO4HvApcAG83sUTPr14J1r3DOdYu6HRo1L+ScWx+3/Jqo+/3wfpED4JwL4fVg+jWyfCKvOue64fXA5gMTwjPM7FAz+7eZbfaHbE/H6/0lko426gdsds7tjipb1cLXF4xr3254ve7o54jeN6BF6/eHUv8PuMk593KidTbyvmyMur8Lr7fV2HNc4A/F7gA24PUOG3sfEnkWOBz4KV5vOFpPvB5lY23Q1697WPT9QcD3wu+3/54XAP2TrNcJce/Nq1HzGmufK/F6tYv8TQKnRNXl6qh6rMb74dXHn/+V8xMT2I33udgKEPX5in4P4l9z3wT1HwQcGff6hwP9nHMv4f1IuAdY7w97t3iTijRNgdw5rQWqgO5RXx5FzrmLAJxzzzjnjsEb9qwB/tt/3N5eGizR46PLwkOkgLejF/W9tabW0XClzm0DLgSuMrPwl8/twFvAIOdcCfAvvDBItN7WtlFT1gH7mVl+VNmg1ry+Zp4jeu/5lq7/z3g75s1sbJ2NvC+NiXk+8/aEvgU4F+iOFwyO+veh+RU6V4u3OeFC4KG42V/h9SQba4P1ft3Dou+vBf4SF6pdnXNvJFu3lnLOrXfOnYcXtNcC//Q/H2uB38TVpYtzrrkfbI2Jf83xP4zxn/P5uOcscM495Nf1ZufcGLwfucfj7VMiKaRA7oScc2vxgum/zKyrmWWb2TgzG2Fmvc3sZDPrghc0u/CGoQE24Q0LpsuzwGgzO828PYD/H96v/wWtWZlzbgXwjL8e8Ib9tgPVZnYEcFLU4puAAf7z7k0bNVWf1cB7wO/MLNfMpgCn4A2/psojwGVm1t/MSoHf4G2nbJaZ/Rjvy/aHcbP25n3ZBAz2Qxy8XlsIbyg4G/hdMnVL4Bq8IeG10YXOuSBee4bft5H+6wm3wWPAr8ys2MwOBr4f9fCHgOnm7cAWMLMiMzujlfVLir/jVj+/t7sd78eJA/4GXGxmo/3lSi1qB8BW+Jn/ue0LXEriow+eAsaa2TT/897FzKaaWYmZlZvZBP/9r8T70dPsZ15aRoHceZ2N17v7Au9L8xa8bU8BvGG0jX55f7ztdgB/BQ7zh7MaO6Rjf2t4HPLwRpaN4Zz7Cm+77O+ALcC38LZF7mnF6wu7EbjIzLoBv8Qb5qzA+1KaE7Xcy3jb2jab2Yd+WWvaqDlnAqP9x90NnOucW96aF9aI/8XbE3k+3uFWH+BtQ03Gd4BheMPw4ffu7L18Xx7DC+FtZvaUc24J3uv+EK+9V9CKQ5b8nmVjPdeL8cJ+Nd57fK1z7hV/Xvg1rMbbfv9A1DpX4L0/N+DtDPgJ3vbqZL0Q97m/LonHTAQWmlkV3ujEd51zNc65N4HLgfvNrALvh1xSx9Q34p/Aa8ASvMMB/xa/gPMOTzwJb2e2TXjvz4X+7BK8///teNuS36Dh6ITspfBOASIiItKO1EMWERHJAGkLZDN73D+4POH2MTOb6O9p+ZmZ/TZd9RAREekI0tlD/hOxO0zEuwPv0JGDgRPNrCyNdREREcloaQtk59xcvL3xGvCP2cx2zn3o7xX5D+DkdNVFREQk07XXyeX74R3zFrYWODLRgmZ2Pt6ZZejatethQ4YMASAYDJKVlZXeWnYSasvUUDumhtoxNdSOqZHqdly6dOlG51yfRPMy/movzrl7gXsBJk2a5N5++20A5s2bx5QpU9qzavsMtWVqqB1TQ+2YGmrH1Eh1O5rZysbmtdde1uuIPR1df5I764+IiMg+qV0C2Tm3Dgia2S
gzy8I7GP/J9qiLiIhIJkjnYU8v4l2S7UQzW2Pe9TqfiToJ/8V4Z8r5FHjOObc4XXURERHJdGnbhuyc+0aC4hOj5r+
2021-03-19 17:21:00 +00:00
"text/plain": [
2021-03-26 20:01:05 +00:00
"<Figure size 560x350 with 1 Axes>"
2021-03-19 17:21:00 +00:00
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
2021-03-22 20:49:29 +00:00
"fig, ax = plt.subplots(figsize=(8, 5))\n",
"fig.set_dpi(fig_dpi)\n",
"\n",
"ensem_x = [result[\"num_models\"] for result in single_ensem_results]\n",
"\n",
"# each stored proportion is plotted as an error rate, i.e. 1 - value\n",
"for key, series_label in [(\"accuracy\", 'Ensemble Test'),\n",
"                          (\"individual_accuracy\", 'Individual Test'),\n",
"                          (\"agreement\", 'Disagreement')]:\n",
"    rates = np.array([result[key] for result in single_ensem_results])\n",
"    ax.plot(ensem_x, 1 - rates, 'x-', label=series_label)\n",
"\n",
"ax.set_title(\"Test Error Rates for Horizontal Model Ensembles\")\n",
"ax.set_ylim(0)\n",
"ax.grid()\n",
"ax.legend()\n",
"ax.set_ylabel(\"Error Rate\")\n",
"ax.set_xlabel(\"Number of Models\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp2"
]
},
2021-03-19 17:21:00 +00:00
"source": [
"## Multiple Iterations\n",
2021-03-22 20:49:29 +00:00
"Run multiple iterations of the ensemble model investigations and average\n",
"\n",
2021-03-26 20:01:05 +00:00
"### CSV Results\n",
2021-03-22 20:49:29 +00:00
"\n",
"| test | learning rate | momentum | batch size | hidden nodes | epochs | models |\n",
"| --- | --- | --- | --- | --- | --- | --- |\n",
"|1|0.06|0|128|16|50|1, 3, 9, 15, 25|\n",
"|2|0.06|0|35|16|1 - 100|1, 3, 9, 15, 25|\n",
"\n",
2021-03-26 20:01:05 +00:00
"### Pickle Results\n",
2021-03-22 20:49:29 +00:00
"\n",
"| test | learning rate | momentum | batch size | hidden nodes | epochs | models |\n",
"| --- | --- | --- | --- | --- | --- | --- |\n",
2021-04-06 17:29:15 +01:00
"|3|0.06|0.05|35|16|1 - 300|1, 3, 9, 15, 25|\n",
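"|4|0.06|0.05|35|1 - 50|50|1, 3, 9, 15, 25|\n",
"\n",
"Note: `evaluate_ensemble_vote` as used for these runs did not forward `batch_size` to `model.fit`, so they likely used Keras' default batch size rather than the value listed."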
2021-03-19 17:21:00 +00:00
]
},
{
"cell_type": "code",
2021-04-06 17:29:15 +01:00
"execution_count": 37,
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp2"
]
},
2021-03-22 20:49:29 +00:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2021-04-06 17:29:15 +01:00
"Iteration 1/30\n",
"Iteration 2/30\n",
"Iteration 3/30\n",
"Iteration 4/30\n",
"Iteration 5/30\n",
"Iteration 6/30\n",
"Iteration 7/30\n",
"Iteration 8/30\n",
"Iteration 9/30\n",
"Iteration 10/30\n",
"Iteration 11/30\n",
"Iteration 12/30\n",
"Iteration 13/30\n",
"Iteration 14/30\n",
"Iteration 15/30\n",
"Iteration 16/30\n",
"Iteration 17/30\n",
"Iteration 18/30\n",
"Iteration 19/30\n",
"Iteration 20/30\n",
"Iteration 21/30\n",
"Iteration 22/30\n",
"Iteration 23/30\n",
"Iteration 24/30\n",
"Iteration 25/30\n",
"Iteration 26/30\n",
"Iteration 27/30\n",
"Iteration 28/30\n",
"Iteration 29/30\n",
"Iteration 30/30\n"
2021-03-22 20:49:29 +00:00
]
}
],
2021-03-19 17:21:00 +00:00
"source": [
"multi_ensem_results = list()\n",
"multi_ensem_iterations = 30\n",
"for i in range(multi_ensem_iterations):\n",
"    print(f\"Iteration {i+1}/{multi_ensem_iterations}\")\n",
"    # a new stratified 50/50 split each iteration; seeding with the iteration index makes the splits reproducible\n",
"    # (note: this rebinds the notebook-level data_train/data_test/labels_train/labels_test)\n",
"    data_train, data_test, labels_train, labels_test = train_test_split(data, labels, test_size=0.5, stratify=labels,\n",
"                                                                        random_state=i)\n",
"    multi_ensem_results.append(list(evaluate_ensemble_vote(epochs=50,\n",
"                                                           hidden_nodes=(1, 50),\n",
"                                                           nmodels=[1, 3, 9, 15, 25],\n",
"                                                           optimizer=lambda: tf.keras.optimizers.SGD(learning_rate=0.06, momentum=0.05),\n",
"                                                           weight_init=lambda: 'random_uniform',\n",
"                                                           batch_size=35,\n",
"                                                           dtrain=data_train,\n",
"                                                           dtest=data_test,\n",
"                                                           ltrain=labels_train,\n",
"                                                           ltest=labels_test,\n",
"                                                           return_model=False,\n",
"                                                           print_params=False)))"
]
},
{
"cell_type": "markdown",
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp2"
]
},
2021-03-19 17:21:00 +00:00
"source": [
"### Accuracy Tensor\n",
"\n",
"Create a tensor for holding the accuracy results\n",
"\n",
2021-03-26 20:01:05 +00:00
"(Iterations x Param x Number of models)\n",
"\n",
"#### Params\n",
"0. Test Accuracy\n",
"1. Train Accuracy (mean final-epoch training accuracy of the individual models)\n",
"2. Individual Accuracy\n",
"3. Agreement"
2021-03-19 17:21:00 +00:00
]
},
2021-03-30 16:31:10 +01:00
{
"cell_type": "code",
2021-04-09 12:42:18 +01:00
"execution_count": 8,
2021-03-30 16:31:10 +01:00
"metadata": {},
"outputs": [],
"source": [
"def test_tensor_data(test):\n",
"    \"\"\"Flatten one ensemble result dict into the row stored in the accuracy tensor.\n",
"\n",
"    Order: [ensemble test accuracy, mean final-epoch train accuracy of the individual models,\n",
"            individual test accuracy, agreement]\n",
"    \"\"\"\n",
"    final_train_accuracies = [history[\"accuracy\"][-1] for history in test[\"history\"]]\n",
"    return [test[\"accuracy\"],\n",
"            np.mean(final_train_accuracies),\n",
"            test[\"individual_accuracy\"],\n",
"            test[\"agreement\"]]"
]
},
2021-03-19 17:21:00 +00:00
{
"cell_type": "code",
2021-04-06 17:29:15 +01:00
"execution_count": 136,
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp2"
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2021-04-06 17:29:15 +01:00
"30 Tests\n",
"Models: [1, 3, 9, 15, 25]\n",
2021-03-26 20:01:05 +00:00
"\n",
"Loss: categorical_crossentropy\n",
2021-04-06 17:29:15 +01:00
"LR: 0.06\n",
"Momentum: 0.05\n"
2021-03-26 20:01:05 +00:00
]
}
],
2021-03-19 17:21:00 +00:00
"source": [
"multi_ensem_models = sorted({result[\"num_models\"] for result in multi_ensem_results[0]})\n",
"multi_ensem_iter = len(multi_ensem_results)\n",
"\n",
"# (iterations x 4 params x ensemble sizes), params ordered as returned by test_tensor_data\n",
"accuracy_ensem_tensor = np.zeros((multi_ensem_iter, 4, len(multi_ensem_models)))\n",
"for iter_idx, iteration in enumerate(multi_ensem_results):\n",
"    for single_test in iteration:\n",
"        ensem_models_idx = multi_ensem_models.index(single_test['num_models'])\n",
"        accuracy_ensem_tensor[iter_idx, :, ensem_models_idx] = test_tensor_data(single_test)\n",
"\n",
"# aggregate across iterations\n",
"mean_ensem_accuracy = np.mean(accuracy_ensem_tensor, axis=0)\n",
"std_ensem_accuracy = np.std(accuracy_ensem_tensor, axis=0)\n",
"\n",
"first_result = multi_ensem_results[0][0]\n",
"print(f'{multi_ensem_iter} Tests')\n",
"print(f'Models: {multi_ensem_models}')\n",
"print()\n",
"print(f'Loss: {first_result[\"loss\"]}')\n",
"print(f'LR: {first_result[\"optimizer\"][\"learning_rate\"]:.3}')\n",
"print(f'Momentum: {first_result[\"optimizer\"][\"momentum\"]:.3}')"
2021-03-19 17:21:00 +00:00
]
},
{
"cell_type": "markdown",
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp2"
]
},
2021-03-19 17:21:00 +00:00
"source": [
"#### Export/Import Test Sets\n",
"\n",
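"Export the raw results (pickle) and the mean/standard-deviation arrays (CSV) for later retrieval and visualisation. After loading pickled results, re-run the *Accuracy Tensor* cell above so that the derived means and standard deviations match the loaded data."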
]
},
{
"cell_type": "raw",
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp2"
]
},
"source": [
2021-04-06 17:29:15 +01:00
"# context manager guarantees the pickle is flushed and the file closed\n",
"with open(\"results/exp2-test4.p\", \"wb\") as results_file:\n",
"    pickle.dump(multi_ensem_results, results_file)"
]
},
{
"cell_type": "code",
2021-04-06 17:29:15 +01:00
"execution_count": 135,
"metadata": {},
"outputs": [],
"source": [
2021-04-06 17:29:15 +01:00
"exp2_testname = 'exp2-test4'\n",
"# only unpickle trusted, locally produced files: pickle.load can execute arbitrary code\n",
"with open(f\"results/{exp2_testname}.p\", \"rb\") as results_file:\n",
"    multi_ensem_results = pickle.load(results_file)\n",
"# NOTE(review): the accuracy-tensor cell above derives mean/std from multi_ensem_results;\n",
"# re-run it after this load or the plotted statistics will not match the loaded results"
]
},
{
"cell_type": "raw",
"metadata": {},
2021-03-19 17:21:00 +00:00
"source": [
"# export the per-param / per-ensemble-size summary arrays as CSV\n",
"np.savetxt(fname=\"exp2-mean.csv\", X=mean_ensem_accuracy, delimiter=',')\n",
"np.savetxt(fname=\"exp2-std.csv\", X=std_ensem_accuracy, delimiter=',')"
]
},
{
"cell_type": "raw",
"metadata": {},
"source": [
"# reload summary arrays previously exported with np.savetxt\n",
"mean_ensem_accuracy = np.loadtxt(fname=\"results/test1-exp2-mean.csv\", delimiter=',')\n",
"std_ensem_accuracy = np.loadtxt(fname=\"results/test1-exp2-std.csv\", delimiter=',')"
2021-03-19 17:21:00 +00:00
]
},
2021-03-22 20:49:29 +00:00
{
"cell_type": "markdown",
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp2"
]
},
2021-03-22 20:49:29 +00:00
"source": [
"### Best Results"
]
},
2021-03-19 17:21:00 +00:00
{
"cell_type": "code",
2021-04-06 17:29:15 +01:00
"execution_count": 137,
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp2"
]
},
2021-03-22 20:49:29 +00:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2021-04-06 17:29:15 +01:00
"Models: 15, 93.5% Accurate\n"
2021-03-22 20:49:29 +00:00
]
}
],
"source": [
"# row 0 of mean_ensem_accuracy holds the ensemble test accuracy per ensemble size\n",
"best_ensem_col = np.argmax(mean_ensem_accuracy[0, :])\n",
"best_ensem_accuracy_idx = (0, best_ensem_col)\n",
"best_ensem_accuracy = mean_ensem_accuracy[best_ensem_accuracy_idx]\n",
"best_ensem_accuracy_models = multi_ensem_models[best_ensem_col]\n",
"\n",
"print(f'Models: {best_ensem_accuracy_models}, {best_ensem_accuracy * 100:.3}% Accurate')"
2021-03-22 20:49:29 +00:00
]
},
{
"cell_type": "markdown",
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp2"
]
},
2021-03-22 20:49:29 +00:00
"source": [
2021-03-27 16:29:31 +00:00
"### Test/Train Error Over Model Numbers"
2021-03-22 20:49:29 +00:00
]
},
{
"cell_type": "code",
2021-04-06 17:29:15 +01:00
"execution_count": 141,
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp2"
]
},
2021-03-19 17:21:00 +00:00
"outputs": [
{
"data": {
2021-04-06 17:29:15 +01:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAfAAAAFECAYAAADY92yFAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAArEAAAKxAFmbYLUAABqGUlEQVR4nO3dd3hUVfrA8e87M+kBQi8BBKVJ7yCrVCkiKquCDQUVQV10saCu7urqruvP7tpRUcQuFsQKSousJqEjKE3pvQfSM3N+f9ybMAmTZJLMZFLez/PMk1vOPffcQ8h7z7l3zhFjDEoppZSqXByhLoBSSimlSk4DuFJKKVUJaQBXSimlKiEN4EoppVQlpAFcKaWUqoQ0gCullFKVkAZwpQJILO+IyDER+TzU5SlIRJqLyLFQlyOXiEwQkR9CXY6SEpGBIrKliP2LRWRceZapMMWVtUDabSJybrDLpAJDA7gKCvsPQZqInPT63FqO5zcikmqfd4eI/L2ExzYt5anPA/4ENDLG/LmUeXiX5bRAUJagZ4zZYYyJK2u5CiMiLUQkJ0B5DRQRT4HfoZMiUisQ+VdEIvJP+/fv+gLb37S3DwxNyVRF5Ap1AVSVNswYs7SoBCLiMsbkFLetiOMFEGOMx8futsaYXSLSA0gQkWRjzHy/S186zYE/jDEZJT2wJNddGsHOP0j+MMa0CnUhytlm4GrgLQARiQRGALtDWShV8WgLXJU7u1X5LxFZDqSKyFAR2SIiD4vIIeBhEaktIh+IyCER+V1EJnsdP1NEXhSRhUAacFZR5zPGrADWA1298vhMRA6IyBERmS0ideztuQF+o93aO8/e/hcR2WyX520RifFxXdcAbwAD7WNvFxGHfV07RWSviDwvIhF2+gkiskhEpovIceD6gnn6WZ/nicgqu9t+iYic7bXPiMgUEdkKLPJuIYvIOQVatpkistjeV1z9Py8iC0TkhIjMz60/YD7g9MqzuYj0EZFlIpIiIttF5LbSXGeBa24hIjkicpNdr/tEZLzX/hvsc50QkY25LVcRibJ/d/aIyC4Ruc/HdS2yy/6liNQXkU/ssi8SkboFyvEvETkqIptEZGgR5S3298fLj0BrEWlsr18ELAFOeuUXKSIv2de9Q0QeFBGHvc8pIv8VkcMishHoW6AsnUQkwS73ChHpWUiZR9l1d0KsHrUriyizCgEN4CpUrgKuBGoBOUALwA00Bv4NvGinaw78Gfi3iAzwOv5KYBpQA9hW1IlEpA/QEfjda/NnQEv7UwN4EMAYM8ze39YYE2uM+VFExgA3A+cDzYAw4OGC5zHGvGenW2wf+zxwI3A5cI5dhh7A37wOOw/4GagNvFvUdRRybXWBucC/gPrAV8BcEfHuXRsKdAGGeR9rjPnZLmcs0ADYAHxk7y6u/scCd9jndAJ/tbcPA9y5+RpjdgDZwGQgDrjMzqtbSa/VBydWnZ4BXAe8JCI17OD4HHC+MaaGXabt9jFPAXWANkBv4FoRGeWV5xhgCtbvYQtgKfBfoJ59Hbd7pW1hl6EhcA8wW0RqFyykv78/XgzwMdbvOMA1nP678Q+gA3A2cC4wzq4DsOp6kL1/ENb/tdyyxALfeV3Tv4DPxGrlF/QGcINdh32BtUWUWYWCMUY/+gn4ByuongCOeX0G2PsWA3/zSjsQSAVc9roTyAJaeqV5DJhuL8/MXS7i/AY4jtVCN8ALgKOQtMOB5QWObeq1/h1wtdd6R2BbIXlNAH7wWl8AXF/gXBu90m4s5joW23XjXY9puecArgWWeKV3YHW1nuN1Led47W8B5Pg4z/vAWyWo/xe89t0KzCkq/wLn+gCY4qu+CqQbiHVT533t673OY4C6XukPYPWyxNj/9qOBCK/9YtddE69tU4CZXtf1vNe+x4G5Xuu3eF3nQCADiPTavzT398T+dxtXit+ff2IFzi7Acqybjd1Yjzs3AAPtdL8Dg7yOmwzMs5cXARO89k0Ett
jLVwLzC5xzuVe+24Bz7eWd9rGxgfq7oJ/AfrQFroLpAmNMnNdnide+XQXS7jOnns/Ww2ql7PDavx1oUsTxvnQAYrH+SPe380REXCLynN3FmgJ8AtQtPBuaA9PtLupjWH+o6/txfuwyl/U6JnvXI1bA9Jm/sd4F2FmSc4jIXVgt0lvsTf7U/36v5TSsei4s/w4i8r2IHLQfFVxK0fXtbWuB36EOXvvcxpjDBcthjEnFanXeDuwX6xFJE6x/syjgV69/y/9gtaBzHfBaTvex7n2dB03+dx12YrXcCyrx748xZg0QCfwd66ah4LsLRf1eNbbL4l0u77IMyC2LXZ6zyf9vm+tyrN6XXSLynfejGVUxaABXoVJwGjzv9UNY3ZXNvbY1B/YUcbzvkxjjMca8ZOeZG6CuwWpB9TPG1MT6QyVFZLMbGF8gkBT1DNPbHgJwHf7mLyKC1U3r1zlEZDBwN3CpVzDyp/4L4+tcL2I9JmhujKmF9fiiqPouM2PMN8aYwUBTIBMrUB+yl8/0+nesaYy5oJSnqVeg67kZsNdHutL+/rwPTMX3o5Wifq/22mXxLpd3WeYVLIsx5v2CJzDGJBljLsS6wVkDvOJHmVU50gCuKhxjjBurVfxvEYkWkY5Yz5I/LEO2TwJ3i0g41jPvDOCoiNTDCmDeDmB10eZ6E7hfRM4CEJHGIjLCz/N+BNwlIvH2i17/oGzXUdC3QBcRucR+7n0HVktxeXEHikhz4D3gGmM9qwbKXP+HAIfk/xpeDazu7wyxXgq80K8rKyURaWi/gBWFFbDTsFrrHuBt4GkRiRPrBcOzRaR3KU8VBvxdRMJE5GKsrvFvfaQr7e/PK8BQY8zPPvZ9BPxDrJcNmwF3curf5xPgDrseGmP1QOX6CugmIqPtnqgoERkhBb6aJyLhInK1iNTEupk7ifU4Q1UgGsBVMM2X/G85P1qCY6dgPffbifWS1j+NMYtKWxBjzHdYQeRaYBZwFKsb+EesZ5TeHgE+tbsYzzXGfADMAL62u9yXAO39PPUM4HMgGfgVqyXzWGmvoyBjzCGsZ70PA4exujxHG2Oy/Th8MNbLa3O9/o1yA1Cp6t/uvv4/YLVdf82Be4G/AClYLcq5/l8hZ8rp3wMvrivXgfVS2X6sm7F4rK5osG5wjgO/AEewfhdOe/HMT9uwehwOAE8DVxhjjhZMVNrfH2PMUWPMgkJ2/wvYiPVc/Ges4P22vW86kAD8hvUsPu/GyxhzHOsG6ja73NuASYWcYzxW1/xRrBchpxSSToWIGFPWHjyllFJKlTdtgSullFKVkAZwpZRSqhLSAK6UUkpVQhrAlVJKqUqoSk9m0rBhQ9OyZcu89ZMnTxIbW+h4E8pPWo+BofUYGFqPgaN1GRiBrMekpKT9xphGvvZV6QDesmVLEhMT89YTEhLo379/CEtUNWg9BobWY2BoPQaO1mVgBLIeRWRbYfu0C10ppZSqhDSAK6WUUpWQBnCllFKqEqrSz8CVUqq6ysnJYdeuXWRkZBSf2FajRg02bNgQxFJVD6Wtx8jISJo2bYrL5V9o1gCulFJV0K5du6hRowZnnHEG1iR1xTtx4gQ1atQIcsmqvtLUozGGI0eOsGvXLlq0aOHXMdqFrpRSVVBGRgZ16tTxO3ir0BIR6tSpU6IeEw3gSilVRWnwrlxK+u+lXehKKVWNvbBgMy8s2gJY3bgeA04RsGPJbYNacduQ1iEsoSqMBnCllKrGbhvSOi9AJ/y6i+tmreGjm/vSs0WdMuXrcrno2LFj3vqdd97JddddV6Y8S2Lbtm1cfvnlLF++/LR9LVq0YN26dX6Nlnb++edz6NAhjhw5QkZGBk2aNAFg8eLFxMXFFXv8nDlzaN++PW3atCnxNRRHA7hSSqmAi4uLY/Xq1aEuRpn98MMPAMycOZN169bx1FNPlej4OXPm4HK5NIArpZQquRMZ2Wzcd6LYdJsOpAKwYW+KX/m2bVSDGpFhJSpLvXr1mDBhAv
PmzaNBgwbMnTuXmJgYnnvuOV555RUiIyPp168fr7zyCgcPHmTy5Mns2LGDsLAwXn75Zbp168aECROIjY1l2bJlHDl
2021-03-19 17:21:00 +00:00
"text/plain": [
2021-03-26 20:01:05 +00:00
"<Figure size 560x350 with 1 Axes>"
2021-03-19 17:21:00 +00:00
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
2021-03-22 20:49:29 +00:00
"fig, ax = plt.subplots(figsize=(8, 5))\n",
"fig.set_dpi(fig_dpi)\n",
"\n",
"# Error rate (1 - accuracy) vs number of models, with one-std error bars.\n",
"# (row of mean_ensem_accuracy / std_ensem_accuracy, legend label); row 3 holds agreement,\n",
"# so 1 - agreement is plotted as disagreement.\n",
"ensem_curves = [(0, 'Ensemble Test'), (2, 'Individual Test'), (1, 'Individual Train'), (3, 'Disagreement')]\n",
"for curve_row, curve_label in ensem_curves:\n",
"    ax.errorbar(multi_ensem_models, 1 - mean_ensem_accuracy[curve_row, :], yerr=std_ensem_accuracy[curve_row, :], capsize=4, label=curve_label)\n",
"\n",
"ax.set_title(\"Error Rate for Horizontal Ensemble Models\")\n",
"# top of the y-axis: highest (1 - mean + std) over the whole array, plus 0.05 headroom\n",
"ax.set_ylim(0, np.max(1 - mean_ensem_accuracy + std_ensem_accuracy) + 0.05)\n",
"ax.grid()\n",
"ax.legend()\n",
"ax.set_xlabel(\"Number of Models\")\n",
"ax.set_ylabel(\"Error Rate\")\n",
"# plt.savefig(f'graphs/{exp2_testname}-error-rate-curves.png')\n",
"plt.show()"
]
},
2021-03-26 20:01:05 +00:00
{
"cell_type": "markdown",
"metadata": {
"id": "FSZq1mNiVZq_",
"tags": [
"ex3"
]
2021-03-19 17:21:00 +00:00
},
"source": [
"# Experiment 3\n",
"\n",
"Repeat Experiment 2 for the cancer dataset with two different optimisers of your choice, e.g. 'trainlm' and 'trainrp'. Comment on and discuss the results, and decide which is the more appropriate training algorithm for the problem. In your discussion, include a detailed account of how the training algorithms (optimisers) work."
2021-03-19 17:21:00 +00:00
]
},
{
"cell_type": "code",
2021-04-09 12:42:18 +01:00
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"def evaluate_optimisers(optimizers=((lambda: 'sgd', 'sgd'),\n",
"                                    (lambda: 'adam', 'adam'),\n",
"                                    (lambda: 'rmsprop', 'rmsprop')),\n",
"                        weight_init=lambda: 'glorot_uniform',\n",
"                        print_params=True,\n",
"                        **kwargs\n",
"                       ):\n",
"    \"\"\"Run the ensemble-vote evaluation once per optimiser (generator).\n",
"\n",
"    optimizers: iterable of (factory, name) pairs. Each factory is a zero-argument\n",
"        callable forwarded as `optimizer` to evaluate_ensemble_vote; `name` is printed\n",
"        and used to build the experiment tag '3-<name>'.\n",
"    weight_init: zero-argument callable forwarded as `weight_init`.\n",
"    print_params: when True, print the optimiser name; also forwarded unchanged.\n",
"    **kwargs: forwarded unchanged to evaluate_ensemble_vote.\n",
"\n",
"    Yields one list per optimiser: the fully materialised output of\n",
"    evaluate_ensemble_vote for that optimiser.\n",
"    \"\"\"\n",
"    # the default is a tuple so the shared default object can never be mutated between calls\n",
"    for optim_factory, optim_name in optimizers:\n",
"        if print_params:\n",
"            print(f'Optimiser: {optim_name}')\n",
"\n",
"        yield list(evaluate_ensemble_vote(optimizer=optim_factory,\n",
"                                          weight_init=weight_init,\n",
"                                          exp=f'3-{optim_name}',\n",
"                                          print_params=print_params,\n",
"                                          **kwargs\n",
"                                          ))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Single Iteration"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Optimiser: sgd\n",
"Models: 1\n",
"Models: 3\n",
"Models: 5\n",
"Optimiser: adam\n",
"Models: 1\n",
"Models: 3\n",
"Models: 5\n",
"Optimiser: rmsprop\n",
"Models: 1\n",
"Models: 3\n",
"Models: 5\n"
]
}
],
"source": [
"# one run per default optimiser; each entry is that optimiser's list of results\n",
"single_optim_results = list(evaluate_optimisers(epochs=(5, 300), nmodels=[1, 3, 5]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Multiple Iterations\n",
"\n",
"### Pickle Results\n",
"\n",
2021-03-30 16:31:10 +01:00
"| test | optim1 | optim2 | optim3 | lr | momentum | epsilon | batch size | hidden nodes | epochs | models |\n",
"| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |\n",
2021-04-09 12:42:18 +01:00
"| 1 | SGD | Adam | RMSprop | 0.1 | 0.0 | 1e-7 | 35 | 16 | 1 - 100 | 1, 3, 9, 15, 25 |\n",
"| 2 | SGD | Adam | RMSprop | 0.05 | 0.01 | 1e-7 | 35 | 16 | 1 - 100 | 1, 3, 9, 15, 25 |"
]
},
{
"cell_type": "code",
2021-04-09 12:42:18 +01:00
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2021-03-30 16:31:10 +01:00
"Iteration 1/30\n",
"Iteration 2/30\n",
"Iteration 3/30\n",
"Iteration 4/30\n",
"Iteration 5/30\n",
"Iteration 6/30\n",
"Iteration 7/30\n",
"Iteration 8/30\n",
"Iteration 9/30\n",
"Iteration 10/30\n",
"Iteration 11/30\n",
"Iteration 12/30\n",
"Iteration 13/30\n",
"Iteration 14/30\n",
"Iteration 15/30\n",
"Iteration 16/30\n",
"Iteration 17/30\n",
"Iteration 18/30\n",
"Iteration 19/30\n",
"Iteration 20/30\n",
"Iteration 21/30\n",
"Iteration 22/30\n",
"Iteration 23/30\n",
"Iteration 24/30\n",
"Iteration 25/30\n",
"Iteration 26/30\n",
"Iteration 27/30\n",
"Iteration 28/30\n",
"Iteration 29/30\n",
"Iteration 30/30\n"
]
}
],
"source": [
"multi_optim_results = list()\n",
"multi_optim_iterations = 30\n",
"\n",
"multi_optim_lr = 0.05\n",
"multi_optim_mom = 0.01\n",
"multi_optim_eps = 1e-07\n",
"# Base seed for the per-iteration train/test splits (iteration i uses seed + i);\n",
"# None leaves the splits unseeded. Model training itself is not seeded in this cell.\n",
"multi_optim_split_seed = None\n",
"multi_optims = [(lambda: tf_optim.SGD(learning_rate=multi_optim_lr, \n",
"                                      momentum=multi_optim_mom), 'sgd'), \n",
"                (lambda: tf_optim.Adam(learning_rate=multi_optim_lr, \n",
"                                       epsilon=multi_optim_eps), 'adam'), \n",
"                (lambda: tf_optim.RMSprop(learning_rate=multi_optim_lr, \n",
"                                          momentum=multi_optim_mom, \n",
"                                          epsilon=multi_optim_eps), 'rmsprop')]\n",
"\n",
"for i in range(multi_optim_iterations):\n",
"    print(f\"Iteration {i+1}/{multi_optim_iterations}\")\n",
"    split_seed = None if multi_optim_split_seed is None else multi_optim_split_seed + i\n",
"    # fresh stratified 50/50 split each iteration; distinct names so any\n",
"    # data_train/... globals used elsewhere in the notebook are not overwritten\n",
"    optim_data_train, optim_data_test, optim_labels_train, optim_labels_test = train_test_split(\n",
"        data, labels, test_size=0.5, stratify=labels, random_state=split_seed)\n",
"    multi_optim_results.append(list(evaluate_optimisers(epochs=(1, 100),\n",
"                                                        hidden_nodes=16,\n",
"                                                        nmodels=[1, 3, 9, 15, 25],\n",
"                                                        optimizers=multi_optims,\n",
"                                                        weight_init=lambda: 'random_uniform',\n",
"                                                        batch_size=35,\n",
"                                                        dtrain=optim_data_train, \n",
"                                                        dtest=optim_data_test, \n",
"                                                        ltrain=optim_labels_train, \n",
"                                                        ltest=optim_labels_test,\n",
"                                                        return_model=False,\n",
"                                                        print_params=False)))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Accuracy Tensor\n",
"\n",
"Create a tensor for holding the accuracy results\n",
"\n",
"(Iterations x Param x Number of models)\n",
"\n",
"#### Params\n",
"0. Test Accuracy\n",
"1. Train Accuracy\n",
"2. Individual Accuracy\n",
"3. Agreement"
]
},
{
"cell_type": "code",
2021-04-09 12:42:18 +01:00
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2021-03-30 16:31:10 +01:00
"30 Tests\n",
"Optimisers: ['SGD', 'Adam', 'RMSprop']\n",
2021-03-30 16:31:10 +01:00
"Models: [1, 3, 9, 15, 25]\n",
"\n",
"Loss: categorical_crossentropy\n"
]
}
],
"source": [
"multi_optim_results_dict = dict()  # optimiser name -> per-iteration lists of single test results\n",
"multi_optim_iter = len(multi_optim_results)  # number of iterations\n",
"\n",
"#####################################\n",
"## INDIVIDUAL RESULTS TO DICTIONARY\n",
"#####################################\n",
"for iter_idx, iteration in enumerate(multi_optim_results):  # each iteration\n",
"    for optim_tests in iteration:  # one list of tests per optimiser\n",
"        for single_optim_test in optim_tests:  # single tests for that optimiser\n",
"            single_optim_name = single_optim_test[\"optimizer\"][\"name\"]\n",
"            if single_optim_name not in multi_optim_results_dict:\n",
"                multi_optim_results_dict[single_optim_name] = [[] for _ in range(multi_optim_iter)]\n",
"\n",
"            multi_optim_results_dict[single_optim_name][iter_idx].append(single_optim_test)\n",
"\n",
"# sorted list of the numbers of models used in the tests (taken from the first iteration/optimiser)\n",
"multi_optim_models = sorted({i[\"num_models\"] for i in multi_optim_results[0][0]})\n",
"\n",
"##################################\n",
"## DICTIONARY TO RESULTS TENSORS\n",
"##################################\n",
"# per optimiser: (iterations x 4 params x number of models) tensor, plus mean/std over iterations\n",
"optim_tensors = dict()\n",
"for optim, optim_results in multi_optim_results_dict.items():\n",
"    accuracy_optim_tensor = np.zeros((multi_optim_iter, 4, len(multi_optim_models)))\n",
"    for iter_idx, iteration in enumerate(optim_results):\n",
"        for single_test in iteration:\n",
"            optim_models_idx = multi_optim_models.index(single_test['num_models'])\n",
"            # assumed to return the 4 params in the order listed in the markdown above\n",
"            accuracy_optim_tensor[iter_idx, :, optim_models_idx] = test_tensor_data(single_test)\n",
"\n",
"    optim_tensors[optim] = {\n",
"        \"accuracy\": accuracy_optim_tensor,\n",
"        \"mean\": np.mean(accuracy_optim_tensor, axis=0),\n",
"        \"std\": np.std(accuracy_optim_tensor, axis=0)\n",
"    }\n",
"\n",
"print(f'{multi_optim_iter} Tests')\n",
"print(f'Optimisers: {list(multi_optim_results_dict.keys())}')\n",
"print(f'Models: {multi_optim_models}')\n",
"print()\n",
"print(f'Loss: {multi_optim_results[0][0][0][\"loss\"]}')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Export/Import Test Sets\n",
"\n",
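"Pickle the raw per-iteration results for later retrieval and visualisation, or load a previously saved set (re-run the Accuracy Tensor cell after loading so the means and standard deviations are recomputed)."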
]
},
{
2021-03-30 16:31:10 +01:00
"cell_type": "code",
2021-04-09 12:42:18 +01:00
"execution_count": 13,
"metadata": {},
2021-03-30 16:31:10 +01:00
"outputs": [],
"source": [
2021-04-09 12:42:18 +01:00
"# context manager guarantees the file is flushed and closed after writing\n",
"with open(\"results/exp3-test2.p\", \"wb\") as results_file:\n",
"    pickle.dump(multi_optim_results, results_file)"
]
},
{
2021-04-06 17:29:15 +01:00
"cell_type": "code",
"execution_count": 97,
"metadata": {},
2021-04-06 17:29:15 +01:00
"outputs": [],
"source": [
2021-04-06 17:29:15 +01:00
"# Load a previously pickled run. Only unpickle trusted files (pickle can execute arbitrary code).\n",
"# NOTE: optim_tensors is built by the Accuracy Tensor cell above; re-run it after loading.\n",
"exp3_testname = 'exp3-test1'\n",
"with open(f\"results/{exp3_testname}.p\", \"rb\") as results_file:\n",
"    multi_optim_results = pickle.load(results_file)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Best Results"
]
},
{
"cell_type": "code",
2021-04-09 12:42:18 +01:00
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2021-04-09 12:42:18 +01:00
"SGD: 9 Models, 72.3% Accurate\n",
"Adam: 25 Models, 96.5% Accurate\n",
"RMSprop: 1 Models, 96.7% Accurate\n"
]
}
],
"source": [
"# Number of models with the highest mean ensemble test accuracy (param row 0) per optimiser\n",
"for optim, optim_results in optim_tensors.items():\n",
"    mean_test_accuracy = optim_results[\"mean\"][0, :]\n",
"    best_models_idx = int(np.argmax(mean_test_accuracy))\n",
"    best_optim_accuracy = mean_test_accuracy[best_models_idx]\n",
"    best_optim_accuracy_models = multi_optim_models[best_models_idx]\n",
"\n",
"    # fixed-point format: the general '.3' format switches to '1e+02' at 100%\n",
"    print(f'{optim}: {best_optim_accuracy_models} Models, {best_optim_accuracy * 100:.1f}% Accurate')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Optimiser Error Rates"
]
},
{
"cell_type": "code",
2021-04-09 12:42:18 +01:00
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
2021-04-09 12:42:18 +01:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAABU4AAAFECAYAAAD4EaSNAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAArEAAAKxAFmbYLUAADOoElEQVR4nOzdd3wUxf/H8dekJySh966oIIQWQEAFFEVEmgo2VPj6RcUOCvq1YvmqXxUFK+rPggUsqIRuAcSgCAgIIkjvvZOEkJAyvz/2Ei4huUsgl0vC++njzN3u7OzsJOznZnZ21lhrEREREREREREREZETAvxdABEREREREREREZGSRh2nIiIiIiIiIiIiIrmo41REREREREREREQkF3WcioiIiIiIiIiIiOSijlMRERERERERERGRXNRxKiIiIiIiIiIiIpKLOk5FSgDj+MwYc9gYM8nf5fElY8xcY8zN+azrYoxZX9xlyo+nsuZKN8gYM6s4yiQipY8xxhpj6vi7HGWVMaafMWaHMSbJGFPN3+XxFW+xxhiz2RhzUXGWKT+FiYv69yEiUjIYY5oYY1YYYxKNMdf6uzy+YoxpYIxJ97B+nDHmieIsU368lTVX2gK1XaXw1HEq+TLGdDLGLDDGHDHGHDDGzDHGNHRbH2OMiTPGHHJ1+K0wxjxmjAlzrX/aGJPmOvEmGmP+NsY8Y4yJ8LDPzcaYZFfjJ+t1d3Ecr2v/1hhz1LXfrYU5YZ7mF/+LgQuBGtbaq08xD/eyzDXGpOSqx5dON9+SzFX/m4wxxm1ZfWNMpjFmrh+LJiJllDFmjTFmqb/LATlirvt5f34x7n+cMSbVtd8DxpjJxphahdj2dBooLwP/stZGWmv3nkY+WR1+6bnqcfvp5FnSuerfGmMuybV8jmt5Az8VTURKqFxttp3GmDeMMYFu6+e6zh8Nc223wRhj3T7HGGNmu9qTh4wx840xbYvzWHJzDeTIzBUHkowx5Ytp/+5xKMEYs8gYc2Ehtj2dwRwjgKnW2ihr7benkU9Wh5/Nox5jTiffksxV/9YY80yu5U+5lg/yU9HkNKnjVPLkCgyTcRojFYH6wJtAhmt9E+A3YCXQxFpbAegLVAXqumX1ibU2yrX8dqA78JN7YM1DN1fjJ+v1Th7lCyrIMg/HZ4wx+f39n2etjQSuBh41xnQraL6noR6w0VqbUtgNPRz34Fz1+MjpFbFUSAM6un2+CSgxI1hFpOwwxrQDagHNXDGxJPgk13m/Y16JcseNwsRPL+mfc8XPs4BwYFRh8j0N9YBVhd3Iw3HMzVWPZ8JoyHU4MRMAY0xtoCFw3G8lEpGSrpvrnH8xcC3w71zrc59XOuBqS7qZgtPmrArUBJ4CUgtbkMLGsQLYmCsORFprj3jbr5c25km8xSGgEvAT8HVhCn8aijqeZuRRjytOr4gl3gbghlzL1CYt5dRxKvk5F0i11n5nrc201iZZaydZa7e61o8EZllrH7fW7gaw1m6w1g6z1q7LnZm1NsVa+ztO52oLoGdhC+S6cvmcMWYxcNQYc7kxZr1xRrHuB54xxlQ0xnxhjNnvuqJ5p9v244wxbxlj5gDJwNme9metXYLTMdzSLY/vjDF7jTEHjTETjTGVXMt/dCVZ47qSdrFr+T3GmHWu8nxijCmXx3ENAD4Auri2vd8YE+A6rm3GmF2uq7ihrvSDjDE/G2PeM8YcAf5VyHoc58pvtnFGAv/odhxVjTEzjTOCeL8x5gu37TobY5a41s01xpztWt7AdVX0TmPMbterjzGmrzFmozFmnzEm9xep84wxfxrnyvInxpjwfMpazxgz3Tijl/4xxnT3cnhfAAPcPg8AxufK82LXvg8bY35x7/AwxrQ1xvxlnKu77+J2jjTGBLp+J1uMMXuMMa/m04Ef7vobPOh6zfNSZhEpnW7Gaez9BNzivsIYc5vr/L3bGHNHHuvWus6/fxljuritm2uMedZ1rk0yxnxojK
lpnJF/CcaYb7NiQWHkFTeMM1roYWPMP7i+zBtj7nY7b39uXKNrjGsaFeMWbz3tz9W4jCNn/MzzuI0xA3HO1U+6jvld1/IYY0y8K04sMca0yefYkoBAnPi7yLXM03neGmPuNcZsAn4uZD1m1cNTrvP7ZmPMFW7rHzdOzE4wzl0457uWVzLGTDDO94eNrmPO2qawv/MAY8z/udYtM8a0JA8FjVluJgFXue3rJuBLwH1kWEWT/3escsaY8a46Xwqck6s8eX6HyKPct7nKnGicEd1dPJRZREoAa+0GnAE1LXOtyv29/GbcvpcbY6oCDYD/s9amu9qLs6y1f7nWP+0658S5zgnxxjUC3pxofwwxxuwAxhljwowxbxsn9m51nasDvOVVGPns96Q2ZlHFIWttuqvOahnXVDTGmAuMMX+44sAWY8x9ruVnAe9yok250rU83xiU69hmApcAH7i2r2yMqWuMmWGcWLzKGNPHLX3utnlwIevSGmPuMs4dg/uNMY+6revpigGJxom1N7iW5xvbXL/jCa64mWSM+c0YU8MY865x7qBd6qoj9zLc5cpnq/Fwi7txpgRaaZzYP8V4nhZoM7DfOBfYMc4I6kM4HapZ+eXb1netf8xVrs1A71xlKVD72BjT3vU3mGCc6YyGeSizeKGOU8nPWiDEGPOBcTooo3OtvwSnwVgo1tpdwGKc29JPxY04V3DKA+k4wTYD5wrlf4G3XOnq4YwY/a8xprPb9jfg3IIQhXNSy5cx5gKgGW4nOeA7nBEYDV15POU6rqxRqee5rqTNM8b0B4YAl+GMwg0mj8amtXa8K13WCJc3cK7Y9gM6uMoQCzzqttnFwO84o4E/93Qc+bgOGIZzdTcQeMC1/CFgE1AFqI0zyhhjTF3gG2AoUBn4FqdBlSUQp7O9HvAf4P+Aa1xlvw543RgT6Zb+VuB6nHqsBzyWu4CuLzpTgR+A6sBtwGfGmOoejusboI8xJtg4jcljwBq3PCvjXNl+znXs04ApxpggY0wIzu93rOsYV5Jz9OqDOPXeBjgPaI3ze8ttIFAOp/6q5XVsIlK6ub6kXw98hXMuvMkYZ5oQY0wzYAzOObwhkHvOyd1AV6ACzjn2S5OzY+xanIuMjYCrcM5Z9+GcU87FbfROIeUVN64BugBNjDGXA0/iXNhsgDNi9HW37RuQM97myxhT0ZW3e/zM87ittZ/gNAqfc8XAIa548b1r/1VwztnfGddUQO5cI3LAib/tPJ3n3Ta7HOci7qncUdIAZzRUNeAF4H3XMTfGiQmtcL6j9AcOurb5DNiJ812gB/CiMaa5W56F+Z13ApbixKkPceolrw7RgsasLInAr67ygdPZkfv7hafvWCNxYnU9V3lvzdqoAN8hstKVw/m3c5nrbqVuwBYPZRaREsAYcy7O+WZDrlXrgCRjTCvXeaovOUdO7ndtM94Y08t1/s7tWpwBJpWBBcCnbusCcTprz8a5s/FJoCnQBCf23ozbuchLXoWRe7+Qs42ZQBHFIVf75FacjrdDrsVpwJ048fRanHNxK2vtRnK2KZu60nuLQQBYa68E5nHirsUDOJ3fK4EawN3A58aYRm6b5W6bF9alQAzOd5GR5sRFtQ+A21yxoD3wl2u5t9jWByeOVAKO4nzvmYPzO/8TJ1ZlCQTa4dxZewPwjutvOQdXB+gYV5rqwGrgpDticxnPiYsGecXTfNv6xpgewD2u42yJW8dpIdvHY4BR1tpo1z7meimzeGKt1UuvPF84/8A+w2nspOL8g49yrUsHrnBL+xpwGOcEdYtr2dPAB3nk+yXOlcW89rkZ58v7YbdXZ9e6ucCjbmm7uPYX5PociHNLWUO3NC8C77nej8t67+GYLXAE52qhxWncBeST9gpgca5t67h9/h64KVd9bs4nr0E4I3izPs/Gma/NfV9r3NKu8XIcc111416PWb+XccCbbmnvBuJc75/DGXXSMFd+/wHez7VsH04DsoHr2Cu5loe7Prd2S7
sHaOlWtqfd1l0GrHX7na53vW+ftdwt7TfAIA+/uzqu8vfEuUV0KE6Qm+tKcwvwi9s2AcAOnKDVGdjkts4A24CbXZ9
"text/plain": [
"<Figure size 1680x350 with 3 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig, axes = plt.subplots(1, 3, figsize=(24, 5))\n",
"fig.set_dpi(fig_dpi)\n",
"\n",
"# Shared y-axis top for every panel (loop-invariant, so computed once):\n",
"# highest (1 - mean + std) over all optimisers, plus 0.03 headroom.\n",
"optim_error_ylim = (np.max([np.max(1 - i[\"mean\"] + i[\"std\"]) for i in optim_tensors.values()]) + 0.03\n",
"                    if optim_tensors else None)\n",
"# (param row, legend label); row 3 holds agreement, so 1 - agreement is plotted as disagreement\n",
"optim_curves = [(0, 'Ensemble Test'), (2, 'Individual Test'), (1, 'Individual Train'), (3, 'Disagreement')]\n",
"\n",
"# one panel per optimiser: mean error rate vs number of models\n",
"for (optimiser_name, tensors_dict), ax in zip(optim_tensors.items(), axes.flatten()):\n",
"    for curve_row, curve_label in optim_curves:\n",
"        ax.plot(multi_optim_models, 1 - tensors_dict[\"mean\"][curve_row, :], 'x-', label=curve_label)\n",
"\n",
"    ax.set_title(f\"{optimiser_name} Error Rate for Ensemble Models\")\n",
"    ax.set_ylim(0, optim_error_ylim)\n",
"    ax.grid()\n",
"    ax.legend()\n",
"    ax.set_xlabel(\"Number of Models\")\n",
"    ax.set_ylabel(\"Error Rate\")\n",
"\n",
"# plt.savefig(f'graphs/{exp3_testname}-error-rate-curves.png')\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
2021-03-19 17:21:00 +00:00
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"authorship_tag": "ABX9TyNAMGLKzaoWaq1wvQ+w0w8h",
"collapsed_sections": [],
"name": "nncw.ipynb",
"provenance": [],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
},
"toc-autonumbering": false,
2021-03-27 16:29:31 +00:00
"toc-showcode": false,
2021-03-22 20:49:29 +00:00
"toc-showmarkdowntxt": false,
2021-03-27 16:29:31 +00:00
"toc-showtags": false
2021-03-19 17:21:00 +00:00
},
"nbformat": 4,
"nbformat_minor": 4
}