shallow-training/nncw.ipynb

2187 lines
2.0 MiB
Plaintext
Raw Normal View History

2021-03-19 17:21:00 +00:00
{
"cells": [
{
"cell_type": "code",
2021-03-26 20:01:05 +00:00
"execution_count": 1,
2021-03-19 17:21:00 +00:00
"metadata": {
"executionInfo": {
"elapsed": 2450,
"status": "ok",
"timestamp": 1615991459232,
"user": {
"displayName": "Andy Pack",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjA4K4ZhdArHXAFbAGr4n0aCv2HmyUpx4cy6zcUq34=s64",
"userId": "16615063155528027547"
},
"user_tz": 0
},
"id": "TGIxH9Tmt5zp"
},
2021-03-26 20:01:05 +00:00
"outputs": [],
2021-03-19 17:21:00 +00:00
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import tensorflow as tf\n",
"tf.get_logger().setLevel('ERROR')\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib as mpl\n",
"import seaborn as sns\n",
"import random\n",
"import pickle\n",
"import json\n",
2021-03-22 20:49:29 +00:00
"import math\n",
2021-03-19 17:21:00 +00:00
"\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
2021-03-26 20:01:05 +00:00
"fig_dpi = 70"
2021-03-19 17:21:00 +00:00
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fksHv5rXACEX"
},
"source": [
"# Neural Network Training\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "l4zqVWyRAM0Z"
},
"source": [
"## Load Dataset\n",
"\n",
"Read CSVs dumped from MatLab and parse into Pandas DataFrames"
]
},
{
"cell_type": "code",
2021-03-26 20:01:05 +00:00
"execution_count": 4,
2021-03-19 17:21:00 +00:00
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 331
},
"executionInfo": {
"elapsed": 2441,
"status": "ok",
"timestamp": 1615991459234,
"user": {
"displayName": "Andy Pack",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjA4K4ZhdArHXAFbAGr4n0aCv2HmyUpx4cy6zcUq34=s64",
"userId": "16615063155528027547"
},
"user_tz": 0
},
"id": "Hj5l_tdZuYh7",
"outputId": "fbfa9406-f662-4ebc-8ba2-67950714627c"
},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Clump thickness</th>\n",
" <th>Uniformity of cell size</th>\n",
" <th>Uniformity of cell shape</th>\n",
" <th>Marginal adhesion</th>\n",
" <th>Single epithelial cell size</th>\n",
" <th>Bare nuclei</th>\n",
" <th>Bland chomatin</th>\n",
" <th>Normal nucleoli</th>\n",
" <th>Mitoses</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>count</th>\n",
" <td>699.000000</td>\n",
" <td>699.000000</td>\n",
" <td>699.000000</td>\n",
" <td>699.000000</td>\n",
" <td>699.000000</td>\n",
" <td>699.000000</td>\n",
" <td>699.000000</td>\n",
" <td>699.000000</td>\n",
" <td>699.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>mean</th>\n",
" <td>0.441774</td>\n",
" <td>0.313448</td>\n",
" <td>0.320744</td>\n",
" <td>0.280687</td>\n",
" <td>0.321602</td>\n",
" <td>0.354363</td>\n",
" <td>0.343777</td>\n",
" <td>0.286695</td>\n",
" <td>0.158941</td>\n",
" </tr>\n",
" <tr>\n",
" <th>std</th>\n",
" <td>0.281574</td>\n",
" <td>0.305146</td>\n",
" <td>0.297191</td>\n",
" <td>0.285538</td>\n",
" <td>0.221430</td>\n",
" <td>0.360186</td>\n",
" <td>0.243836</td>\n",
" <td>0.305363</td>\n",
" <td>0.171508</td>\n",
" </tr>\n",
" <tr>\n",
" <th>min</th>\n",
" <td>0.100000</td>\n",
" <td>0.100000</td>\n",
" <td>0.100000</td>\n",
" <td>0.100000</td>\n",
" <td>0.100000</td>\n",
" <td>0.100000</td>\n",
" <td>0.100000</td>\n",
" <td>0.100000</td>\n",
" <td>0.100000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>25%</th>\n",
" <td>0.200000</td>\n",
" <td>0.100000</td>\n",
" <td>0.100000</td>\n",
" <td>0.100000</td>\n",
" <td>0.200000</td>\n",
" <td>0.100000</td>\n",
" <td>0.200000</td>\n",
" <td>0.100000</td>\n",
" <td>0.100000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>50%</th>\n",
" <td>0.400000</td>\n",
" <td>0.100000</td>\n",
" <td>0.100000</td>\n",
" <td>0.100000</td>\n",
" <td>0.200000</td>\n",
" <td>0.100000</td>\n",
" <td>0.300000</td>\n",
" <td>0.100000</td>\n",
" <td>0.100000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>75%</th>\n",
" <td>0.600000</td>\n",
" <td>0.500000</td>\n",
" <td>0.500000</td>\n",
" <td>0.400000</td>\n",
" <td>0.400000</td>\n",
" <td>0.500000</td>\n",
" <td>0.500000</td>\n",
" <td>0.400000</td>\n",
" <td>0.100000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>max</th>\n",
" <td>1.000000</td>\n",
" <td>1.000000</td>\n",
" <td>1.000000</td>\n",
" <td>1.000000</td>\n",
" <td>1.000000</td>\n",
" <td>1.000000</td>\n",
" <td>1.000000</td>\n",
" <td>1.000000</td>\n",
" <td>1.000000</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Clump thickness Uniformity of cell size Uniformity of cell shape \\\n",
"count 699.000000 699.000000 699.000000 \n",
"mean 0.441774 0.313448 0.320744 \n",
"std 0.281574 0.305146 0.297191 \n",
"min 0.100000 0.100000 0.100000 \n",
"25% 0.200000 0.100000 0.100000 \n",
"50% 0.400000 0.100000 0.100000 \n",
"75% 0.600000 0.500000 0.500000 \n",
"max 1.000000 1.000000 1.000000 \n",
"\n",
" Marginal adhesion Single epithelial cell size Bare nuclei \\\n",
"count 699.000000 699.000000 699.000000 \n",
"mean 0.280687 0.321602 0.354363 \n",
"std 0.285538 0.221430 0.360186 \n",
"min 0.100000 0.100000 0.100000 \n",
"25% 0.100000 0.200000 0.100000 \n",
"50% 0.100000 0.200000 0.100000 \n",
"75% 0.400000 0.400000 0.500000 \n",
"max 1.000000 1.000000 1.000000 \n",
"\n",
" Bland chomatin Normal nucleoli Mitoses \n",
"count 699.000000 699.000000 699.000000 \n",
"mean 0.343777 0.286695 0.158941 \n",
"std 0.243836 0.305363 0.171508 \n",
"min 0.100000 0.100000 0.100000 \n",
"25% 0.200000 0.100000 0.100000 \n",
"50% 0.300000 0.100000 0.100000 \n",
"75% 0.500000 0.400000 0.100000 \n",
"max 1.000000 1.000000 1.000000 "
]
},
2021-03-26 20:01:05 +00:00
"execution_count": 4,
2021-03-19 17:21:00 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data = pd.read_csv('features.csv', header=None).T\n",
"data.columns = ['Clump thickness', 'Uniformity of cell size', 'Uniformity of cell shape', 'Marginal adhesion', 'Single epithelial cell size', 'Bare nuclei', 'Bland chomatin', 'Normal nucleoli', 'Mitoses']\n",
"labels = pd.read_csv('targets.csv', header=None).T\n",
"labels.columns = ['Benign', 'Malignant']\n",
"data.describe()"
]
},
{
"cell_type": "code",
2021-03-26 20:01:05 +00:00
"execution_count": 31,
2021-03-19 17:21:00 +00:00
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 204
},
"executionInfo": {
"elapsed": 2436,
"status": "ok",
"timestamp": 1615991459236,
"user": {
"displayName": "Andy Pack",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjA4K4ZhdArHXAFbAGr4n0aCv2HmyUpx4cy6zcUq34=s64",
"userId": "16615063155528027547"
},
"user_tz": 0
},
"id": "qc1Mku6h041u",
"outputId": "94e38c34-0419-4a02-ac8c-17bbc83f777b"
},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Benign</th>\n",
" <th>Malignant</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Benign Malignant\n",
"0 1 0\n",
"1 1 0\n",
"2 1 0\n",
"3 0 1\n",
"4 1 0"
]
},
2021-03-26 20:01:05 +00:00
"execution_count": 31,
2021-03-19 17:21:00 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"labels.head()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "h9QsJjWEMbLu"
},
"source": [
"### Explore Dataset\n",
"\n",
"The classes are uneven in their occurences, stratify when splitting later on"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"executionInfo": {
"elapsed": 2430,
"status": "ok",
"timestamp": 1615991459237,
"user": {
"displayName": "Andy Pack",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjA4K4ZhdArHXAFbAGr4n0aCv2HmyUpx4cy6zcUq34=s64",
"userId": "16615063155528027547"
},
"user_tz": 0
},
"id": "rjjiSYAZMa4k",
"outputId": "ae0c3772-00be-4f2b-80d2-9cd62a6b6e08"
},
"outputs": [
{
"data": {
"text/plain": [
"Benign 458\n",
"Malignant 241\n",
"dtype: int64"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"labels.astype(bool).sum(axis=0)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "E9lVYI14AUMf"
},
"source": [
"## Split Dataset\n",
"\n",
"Using a 50/50 split"
]
},
{
"cell_type": "code",
2021-03-26 20:01:05 +00:00
"execution_count": 5,
2021-03-19 17:21:00 +00:00
"metadata": {
"executionInfo": {
"elapsed": 2604,
"status": "ok",
"timestamp": 1615991459418,
"user": {
"displayName": "Andy Pack",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjA4K4ZhdArHXAFbAGr4n0aCv2HmyUpx4cy6zcUq34=s64",
"userId": "16615063155528027547"
},
"user_tz": 0
},
"id": "L83Ae5l9wM35"
},
"outputs": [],
"source": [
"data_train, data_test, labels_train, labels_test = train_test_split(data, labels, test_size=0.5, stratify=labels)"
2021-03-19 17:21:00 +00:00
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Qf2U199fNjmJ"
},
"source": [
"## Generate & Retrieve Model\n",
"\n",
"Get a shallow model with a single hidden layer of varying nodes"
]
},
{
"cell_type": "code",
2021-03-26 20:01:05 +00:00
"execution_count": 6,
2021-03-19 17:21:00 +00:00
"metadata": {
"executionInfo": {
"elapsed": 2598,
"status": "ok",
"timestamp": 1615991459419,
"user": {
"displayName": "Andy Pack",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjA4K4ZhdArHXAFbAGr4n0aCv2HmyUpx4cy6zcUq34=s64",
"userId": "16615063155528027547"
},
"user_tz": 0
},
"id": "SgoQ-NjWB0T5"
},
"outputs": [],
"source": [
"def get_model(hidden_nodes=9, activation=lambda: 'sigmoid', weight_init=lambda: 'glorot_uniform'):\n",
" layers = [tf.keras.layers.InputLayer(input_shape=(9,)), \n",
" tf.keras.layers.Dense(hidden_nodes, activation=activation(), kernel_initializer=weight_init()), \n",
" tf.keras.layers.Dense(2, activation='softmax', kernel_initializer=weight_init())]\n",
"\n",
" model = tf.keras.models.Sequential(layers)\n",
" return model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QT5B9PTUN3pj"
},
"source": [
"# Example Training"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mQGAUtIPAd6e"
},
"source": [
"## Define Model\n",
"\n",
"Variable number of hidden nodes. All using 9D outputs except the last layer which is 2D for binary classification"
]
},
{
"cell_type": "code",
2021-03-26 20:01:05 +00:00
"execution_count": 60,
2021-03-19 17:21:00 +00:00
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"executionInfo": {
"elapsed": 7889,
"status": "ok",
"timestamp": 1615991464716,
"user": {
"displayName": "Andy Pack",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjA4K4ZhdArHXAFbAGr4n0aCv2HmyUpx4cy6zcUq34=s64",
"userId": "16615063155528027547"
},
"user_tz": 0
},
"id": "fYA34P0Vu_pX",
"outputId": "aded18b8-aa7f-4362-a614-837c8a0f526f"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2021-03-26 20:01:05 +00:00
"Model: \"sequential_1\"\n",
2021-03-19 17:21:00 +00:00
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
2021-03-26 20:01:05 +00:00
"dense_2 (Dense) (None, 9) 90 \n",
2021-03-19 17:21:00 +00:00
"_________________________________________________________________\n",
2021-03-26 20:01:05 +00:00
"dense_3 (Dense) (None, 2) 20 \n",
2021-03-19 17:21:00 +00:00
"=================================================================\n",
"Total params: 110\n",
"Trainable params: 110\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"model = get_model(9)\n",
"model.compile('sgd', loss='categorical_crossentropy', metrics=['accuracy'])\n",
"model.summary()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KZSwFe-AAs1y"
},
"source": [
"## Train Model\n",
"\n",
"Example 10 epochs"
]
},
{
"cell_type": "code",
2021-03-26 20:01:05 +00:00
"execution_count": 61,
2021-03-19 17:21:00 +00:00
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"executionInfo": {
"elapsed": 11304,
"status": "ok",
"timestamp": 1615991468137,
"user": {
"displayName": "Andy Pack",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjA4K4ZhdArHXAFbAGr4n0aCv2HmyUpx4cy6zcUq34=s64",
"userId": "16615063155528027547"
},
"user_tz": 0
},
"id": "s8U9Atu3zelS",
"outputId": "8439e8d2-7a5d-495f-a192-a34f01e95bfa"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/5\n",
2021-03-26 20:01:05 +00:00
"11/11 [==============================] - 1s 2ms/step - loss: 0.6257 - accuracy: 0.6607\n",
"Epoch 2/5\n",
2021-03-26 20:01:05 +00:00
"11/11 [==============================] - 0s 3ms/step - loss: 0.6226 - accuracy: 0.6651\n",
"Epoch 3/5\n",
2021-03-26 20:01:05 +00:00
"11/11 [==============================] - 0s 2ms/step - loss: 0.6326 - accuracy: 0.6424\n",
"Epoch 4/5\n",
2021-03-26 20:01:05 +00:00
"11/11 [==============================] - 0s 3ms/step - loss: 0.6158 - accuracy: 0.6696\n",
"Epoch 5/5\n",
2021-03-26 20:01:05 +00:00
"11/11 [==============================] - 0s 2ms/step - loss: 0.6228 - accuracy: 0.6534\n"
2021-03-19 17:21:00 +00:00
]
},
{
"data": {
"text/plain": [
2021-03-26 20:01:05 +00:00
"<tensorflow.python.keras.callbacks.History at 0x2cd249f3400>"
2021-03-19 17:21:00 +00:00
]
},
2021-03-26 20:01:05 +00:00
"execution_count": 61,
2021-03-19 17:21:00 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
2021-03-26 20:01:05 +00:00
"model.fit(data_train.to_numpy(), labels_train.to_numpy(), epochs=5)"
2021-03-19 17:21:00 +00:00
]
},
{
"cell_type": "code",
2021-03-26 20:01:05 +00:00
"execution_count": 62,
2021-03-19 17:21:00 +00:00
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"executionInfo": {
"elapsed": 11294,
"status": "ok",
"timestamp": 1615991468137,
"user": {
"displayName": "Andy Pack",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjA4K4ZhdArHXAFbAGr4n0aCv2HmyUpx4cy6zcUq34=s64",
"userId": "16615063155528027547"
},
"user_tz": 0
},
"id": "VnUEJdXovzi-",
"outputId": "02075086-352c-4a23-fac5-ad54d11e0e35"
},
"outputs": [
{
"data": {
"text/plain": [
"['loss', 'accuracy']"
]
},
2021-03-26 20:01:05 +00:00
"execution_count": 62,
2021-03-19 17:21:00 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.metrics_names"
]
},
{
"cell_type": "code",
2021-03-26 20:01:05 +00:00
"execution_count": 63,
2021-03-19 17:21:00 +00:00
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"executionInfo": {
"elapsed": 11285,
"status": "ok",
"timestamp": 1615991468138,
"user": {
"displayName": "Andy Pack",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjA4K4ZhdArHXAFbAGr4n0aCv2HmyUpx4cy6zcUq34=s64",
"userId": "16615063155528027547"
},
"user_tz": 0
},
"id": "r0vxP3Ah42ib",
"outputId": "061113ba-52db-4fbe-c7f9-b5d3d85438ed"
},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor: shape=(), dtype=float32, numpy=0.6561605>"
]
},
2021-03-26 20:01:05 +00:00
"execution_count": 63,
2021-03-19 17:21:00 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.metrics[1].result()"
]
},
{
"cell_type": "markdown",
"metadata": {
2021-03-26 20:01:05 +00:00
"id": "z7bn8pKTAynt",
"tags": [
"exp1"
]
2021-03-19 17:21:00 +00:00
},
"source": [
"# Experiment 1\n",
"\n",
"The below function runs an iteration of layer/epoch investigations.\n",
"Returns the amount of layers/epochs used as well as the results and the model.\n",
"\n",
"Using cancer dataset (as in E2) and trainscg or an optimiser of your choice, vary nodes and epochs (that is using early stopping for epochs) over suitable range, to find optimal choice in terms of classification test error rate of node/epochs for 50/50% random train/test split (no validation set). It is suggested that you initially try epochs = [ 1 2 4 8 16 32 64], nodes = [2 8 32], so there would be 21 node/epoch combinations. \n",
"\n",
"(Hint1: from the advanced scriptin E2, nodes can be changed to xx, with hiddenLayerSize = xx; and epochs changed to xx by addingnet. trainParam.epochs = xx; placed afternet = patternnet(hiddenLayerSize, trainFcn); --see trainscg help documentation for changing epochs). \n",
"\n",
"Repeat each of the 21 node/epoch combinations at least thirty times, with different 50/50 split and take average and report classification error rate and standard deviation (std). Graph classification train and test error rate and std as node-epoch changes, that is plot error rate vs epochs for different number of nodes. Report the optimal value for test error rate and associated node/epoch values. \n",
"\n",
2021-03-22 20:49:29 +00:00
"(Hint2: as epochs increases you can expect the test error rate to reach a minimum and then start increasing, you may need to set the stopping criteria to achieve the desired number of epochs Hint 3: to find classification error rates for train and test set, you need to check the code from E2, to determine how you may obtain the train and test set patterns)\n"
2021-03-19 17:21:00 +00:00
]
},
{
"cell_type": "code",
2021-03-26 20:01:05 +00:00
"execution_count": 64,
2021-03-19 17:21:00 +00:00
"metadata": {
"executionInfo": {
"elapsed": 11274,
"status": "ok",
"timestamp": 1615991468138,
"user": {
"displayName": "Andy Pack",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjA4K4ZhdArHXAFbAGr4n0aCv2HmyUpx4cy6zcUq34=s64",
"userId": "16615063155528027547"
},
"user_tz": 0
},
2021-03-26 20:01:05 +00:00
"id": "mYWhCSW4A57V",
"tags": [
"exp1",
"exp-func"
]
2021-03-19 17:21:00 +00:00
},
"outputs": [],
"source": [
2021-03-26 20:01:05 +00:00
"# hidden_nodes = [2, 8, 16, 24, 32]\n",
"# epochs = [1, 2, 4, 8, 16, 32, 64, 100, 150, 200]\n",
"hidden_nodes = [2, 8, 16]\n",
"epochs = [1, 2, 4, 8]\n",
2021-03-19 17:21:00 +00:00
"\n",
"def evaluate_parameters(hidden_nodes=hidden_nodes, \n",
" epochs=epochs, \n",
" batch_size=128,\n",
" optimizer=lambda: 'sgd',\n",
" loss=lambda: 'categorical_crossentropy',\n",
" metrics=['accuracy'],\n",
" callbacks=None,\n",
" validation_split=None,\n",
"\n",
" verbose=0,\n",
" print_params=True,\n",
" return_model=True,\n",
2021-03-26 20:01:05 +00:00
" run_eagerly=False,\n",
2021-03-19 17:21:00 +00:00
" \n",
" dtrain=data_train,\n",
" dtest=data_test,\n",
" ltrain=labels_train,\n",
" ltest=labels_test):\n",
" for idx1, hn in enumerate(hidden_nodes):\n",
" for idx2, e in enumerate(epochs):\n",
" if print_params:\n",
" print(f\"Nodes: {hn}, Epochs: {e}\")\n",
"\n",
" model = get_model(hn)\n",
" model.compile(\n",
" optimizer=optimizer(),\n",
" loss=loss(),\n",
2021-03-26 20:01:05 +00:00
" metrics=metrics,\n",
" run_eagerly=run_eagerly\n",
2021-03-19 17:21:00 +00:00
" )\n",
" \n",
" response = {\"nodes\": hn, \n",
" \"epochs\": e,\n",
2021-03-26 20:01:05 +00:00
" ##############\n",
" ## TRAIN\n",
" ##############\n",
" \"history\": model.fit(dtrain.to_numpy(), \n",
" ltrain.to_numpy(), \n",
" epochs=e, \n",
" verbose=verbose,\n",
" \n",
" callbacks=callbacks,\n",
" validation_split=validation_split).history,\n",
2021-03-26 20:01:05 +00:00
" ##############\n",
" ## TEST\n",
" ##############\n",
" \"results\": model.evaluate(dtest.to_numpy(), \n",
" ltest.to_numpy(), \n",
2021-03-19 17:21:00 +00:00
" batch_size=batch_size, \n",
" verbose=verbose),\n",
" \"optimizer\": model.optimizer.get_config(),\n",
" \"loss\": model.loss,\n",
" \"model_config\": json.loads(model.to_json())\n",
" }\n",
2021-03-19 17:21:00 +00:00
"\n",
" if return_model:\n",
" response[\"model\"] = model\n",
"\n",
" yield response"
]
},
{
"cell_type": "markdown",
"metadata": {
2021-03-26 20:01:05 +00:00
"id": "r-63V9qb-i4w",
"tags": [
"exp1"
]
2021-03-19 17:21:00 +00:00
},
"source": [
"## Single Iteration\n",
"Run a single iteration of epoch/layer investigations"
]
},
{
"cell_type": "code",
2021-03-26 20:01:05 +00:00
"execution_count": 65,
2021-03-19 17:21:00 +00:00
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"executionInfo": {
"elapsed": 313592,
"status": "ok",
"timestamp": 1615991770468,
"user": {
"displayName": "Andy Pack",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjA4K4ZhdArHXAFbAGr4n0aCv2HmyUpx4cy6zcUq34=s64",
"userId": "16615063155528027547"
},
"user_tz": 0
},
"id": "ZmGFkE9y8E4H",
2021-03-26 20:01:05 +00:00
"outputId": "243fb136-bc07-4438-afb7-f2d21758168d",
"tags": [
"exp1"
]
2021-03-19 17:21:00 +00:00
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Nodes: 2, Epochs: 1\n",
"Nodes: 2, Epochs: 2\n",
"Nodes: 2, Epochs: 4\n",
"Nodes: 2, Epochs: 8\n",
"Nodes: 8, Epochs: 1\n",
"Nodes: 8, Epochs: 2\n",
"Nodes: 8, Epochs: 4\n",
"Nodes: 8, Epochs: 8\n",
"Nodes: 16, Epochs: 1\n",
"Nodes: 16, Epochs: 2\n",
"Nodes: 16, Epochs: 4\n",
2021-03-26 20:01:05 +00:00
"Nodes: 16, Epochs: 8\n"
2021-03-19 17:21:00 +00:00
]
}
],
"source": [
2021-03-26 20:01:05 +00:00
"# es = tf.keras.callbacks.EarlyStopping(monitor='val_loss', mode='min', patience = 5)\n",
"single_results = list(evaluate_parameters(return_model=False, validation_split=0.2\n",
" , optimizer = lambda: tf.keras.optimizers.SGD(learning_rate=0.5, momentum=0.5)\n",
"# , callbacks=[es]\n",
" ))"
2021-03-19 17:21:00 +00:00
]
},
{
"cell_type": "markdown",
"metadata": {
2021-03-26 20:01:05 +00:00
"id": "mdWK3-M6SK8_",
"tags": [
"exp1"
]
2021-03-19 17:21:00 +00:00
},
"source": [
"### Train/Test Curves\n",
"\n",
"For a single test from the set"
2021-03-19 17:21:00 +00:00
]
},
{
"cell_type": "code",
2021-03-26 20:01:05 +00:00
"execution_count": 68,
2021-03-19 17:21:00 +00:00
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 517
},
"executionInfo": {
"elapsed": 314527,
"status": "ok",
"timestamp": 1615991771417,
"user": {
"displayName": "Andy Pack",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjA4K4ZhdArHXAFbAGr4n0aCv2HmyUpx4cy6zcUq34=s64",
"userId": "16615063155528027547"
},
"user_tz": 0
},
"id": "F9Xre1o6SesD",
2021-03-26 20:01:05 +00:00
"outputId": "d6b817aa-58cd-4510-807f-e5e6bcf62f18",
"tags": [
"exp1"
]
2021-03-19 17:21:00 +00:00
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2021-03-26 20:01:05 +00:00
"Nodes: 2, Epochs: 4\n"
2021-03-19 17:21:00 +00:00
]
},
{
"data": {
2021-03-26 20:01:05 +00:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA1gAAAGuCAYAAACNy6eFAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAArEAAAKxAFmbYLUAACAaUlEQVR4nOzddZxU1f/H8dfZJpalY5eU7u62SJVUsQMVMb4WtvwwsMUEJFQMREJQJERBlqW7uxtZuhd29/z+uAMsuMDCxp2ZfT8fj3kwc2s+Z2aYu585536OsdYiIiIiIiIiqRfgdgAiIiIiIiL+QgmWiIiIiIhIGlGCJSIiIiIikkaUYImIiIiIiKQRJVgiIiIiIiJpRAmWiIiIiIhIGlGCJV7BGPO1MebFFG670hhTP71jyijGmOLGmPgkjycaY+64xLbNjDEbrvF5Ghtjll5rnCIicp7OWzpviVyK0TxYkhrGmGNJHmYDTgBnP1QVrLXbMj6qjGeMaQBMAApaa08lWZ4d+Beob61ddol9iwMbrLVBKXieZsBga22pFGy7BbjHWjsjBU24ZsaYITjxv5OezyMikhZ03nJk5vNWkucbCDwAFLLW7s+I55TMQT1YkirW2uxnb0AcUDHJsm0AxuHXnzVr7SxgP9DmolXtgY2XOkmJiEjG0nnLkdnPW8aYUKATcAxItvctnZ73ikmp+D6//vIQ9xhjhhhjvjLG/IPz62BJY8xDxph1xpijxphlnl+1km7/uuf+A8aYf4wx/Y0xR4wxq4wxNZJsu8UY0yjJfl8YY6Z4jvuXMSZ3km0fMcbsMMbs8dy3xpjCycQ72Bjzfxct22SMaWSMyecZ/nDIGLPPGDPsEs0eCtx10bK7gZ+MMbmMMX969o81xgz0fLkn99pFG2Pu8dwPNMZ8bozZb4xZC9S7aNsvjTG7PLH9ZYwperY9QFHgL2PMMWPM3RcP0zDGVDTGTPfsu9AY0/Ci1/h5Y8xqz/qvLtHmSzLGhBlj+npe+23GmJ5n/2AxxtQzxiz2vL87jTHPXm65iEh603nrnMxy3moDHAfeB+65KMYSxpjxnjbsNsY87VkeZIx5yxiz1Rhz2BgT7Vn+n2GQSd83T2wvGmNWAxsu9zpc6vmNMYU9n61sSbZ70Bjz1xXaKS5QgiXp6U6gBxAObAH2ADcAOYEvgV8u9WUNNAZigFzAaODTyzzP7cCzQD4gEPgfgDGmMvAR0A4oATS4zDGGe46DZ986QDAwE3ge2AzkBaI8sSfnJ6C1MSbCc4wCwPXAzzj/1/p69q8C1AIev0w8Zz0GNAcqev7tctH6mUB5oBCwA/gCwFrbFdgG3Oz5VXZo0p2MMSHAH8BInNftQ+APY0yuJJu1w3kfKgG3G2OapyDepN7wxF0eaIRzArvPs+4z4GNrbQ7P8aOvsFxEJCPovJV5zlv3ACOAX4C6xpjrPM8TBIwH5nvaXtYTM8BLQEvPc+QGel7x1TivA9DM03a4xOtwqee31u4AFgC3JjlmF5z3SryMEixJT79aaxdaa+OttWestROstduttQnW2kE4Y95LX2LfNdbaYdbaBJwvj6qXeZ6R1tplnjHkvybZtiMw2lq7wFp7ErjcNUL/APmMMRU9j2/3HNcCZ3C+AItYa+M8wyr+w1q7DljueV5wTtQzrLU7rLX7rbV/ePbfDQzASTqupDPQx1q7x1q7i4tOktbaX6y1hz3t+yCFxwSoCwRYa7/wvDfDgbU4J46zPrPW7vN8qUdz+fcgOXcCb1prD3qG3XzC+RPtGaCUMSa3Z/3iKywXEckIOm9lgvOWMSYn0BoY7jk/zeF8L1ZdnAT7LWvtKWvtEWvtQs+6B4DXrLXbPJ+JmBTGDvC5tfZfT7sv9zpc7vl/wnMe9STDDXCSefEySrAkPe1I+sAY084Ys8jTHX4IyA/kucS+/ya5fwLIfpnnudS2BS+K4YJ4kvKcEEcBdxhjDM4JYrhn9Uc4v6pNM8
asMcY8fJlYfuL8cIu7PY8xxoQbY37wDPs4AvTh0m1PqhCwPcnjpPcxxrxmjNngOea8FB4TIPLiYwFbPcvPupr34FLPkfRi8aTH74rz6+YGY8wMc7661qWWi4hkBJ23Msd563Zgl7V2nufxLzhtBygMbLXWJiazX2GcnsFrcfFn61Kvw+WefxTQ1NNr1xn401p75BrjkXSkBEvS07kSlZ4hFcOA14A81tqcwF7ApOPz78HpXj/rP2PYL3J2uEU9INFaOxfA8+vR/6y1RXF+vfry7FCCZPwCNDLGNMUZojDKs/w5nCEN1TzD354jZW3fDRRJ8vjcfc9zdMf5FS4CqHPRvpcrEbrrouOCM/Z9VwpiSqldnmP+5/jW2rXW2ttx/lj5BeezccnlIiIZROetzHHeugeINM51bnuAXkAZzzDL7UAxT9J6se1A8WSWHweynH3g6V26WNLP1uVeh0s+vyeZmoTT49gF5xo68UJKsCSjhAIhOCcnjDH/w/niTk9jgI7GmBrGmDDg1StsPx2nW/4dnHHZABhj2hhjrvN82R3G+ZJMSO4A1tq9OMM2fgDGJfllKRzn17TDxphiOF+sKTEKeNYYU8AYUwh4Msm6cJxhIPtwSg2/ftG+e0n+RAAw19O2Jz0X7XbGGQv+ZwrjuliQcYpanL0F4Zz43zDOhdJFcE7Ov3ie925jTB5rbTxwFM/reanlIiIu0HnLD89bnrY0xLk+rJrnVhHnuqd7cHqTjuKcv8KMMTmMMTU9uw8B3jHGFDFOMY8mnuXrgFzGmKaexPyNK4Rxudfhcs8PTg/js562T7iatkvGUYIlGcLzhd0D55eXPThd4dc08eBVPOdS4GWci2K3AGfHMMddYvtEnItnryfJiQooA0zF+cIbDzxjrd16maf+CedXtZ+SLPscZ9jEQZzx9mNS2IwBOBdNr8YZT/5LknV/4lwkuxVnDP3FY+w/AN73DG25oEqUtfY0zoWyXXDK9L4C3GqtPZjCuC72f8DJJLevgbdxxsevAWZ7Yv/es31rYK0x5ijwNOeLX1xquYhIhtJ5y2/PW3cD0621sz3Xie2x1u4BvuJ8ufa2ONc37cY5j50drv4RMMUT936cni+stYdxCpWMwBlCOP8KMVzydfD8wHip5weYCBQAxlhrk/1ciPs00bBkGsaYssAyIMzqgy8iIl5O5y1JjjFmBfA/a+0Ut2OR5KkHS/yaMaatp4s9AngPGKuTlIiIeCudt+RyjDE3AVlxeijFSynBEn93B05VoS04n/enXI1GRHyOMWaMMeagMWbUJdbXMcas9FQEu5p5cUSSo/OWJMsYMxxnyOXTl6gyKF5CQwRFREQuwxjTDOei9PuttZ2SWT8feBhYiXNdxSPW2uUZGaOIiHgP9WCJiIhchrU2GqdYwH8YYyKBIM+ksQk4vy63zcDwRETEywS5HUBSBQoUsCVKlEjVMY4dO0b27Fc7H6rvUPt8m9rn2/y9fZD6Ns6dO/dfa23BNAzJ20UCO5M83gk0TW5DY0xXnMm0yZo1a93ixYun6okTEhIIDAxM1TG8mdrn29Q+36b2XdmqVasueb7zqgSrRIkSzJkzJ1XHiImJoUmTJlfe0Eepfb5N7fNt/t4+SH0bjTFb0i4a/2KtHQwMBqhXr57V+e7y1D7fpvb5NrXvyi53vtMQQRERkWu3C4hK8jjKs0xERDIpJVgiIiLXyFq7C0gwxlQxxgQCd+JMEisiIpmUVw0RFBER8TbGmMlAVSCbMWYH0Bl4A+jqSbCeBIYBYcCPqiAoIuIF4uPh+HE4duw/t3zz50OJElCkSLo8tRIsEclU4uPj2bFjB6dOnbrqfcPDw1mzZk06ROU9rqaNhQoVIiIiIp0jcp+19sZkFrdOsn4OUDG1z3O1n01//zxea/vCwsIoXLgwQUH6E0fEZ5w5899E6OjRZJOjFN9Onrzk05UHqFhRCZaISFrYsWMH4eHhFCtWDGPMVe179OhRwsPD0yky75
DSNp46dYodO3ZkigQro1ztZ9PfP4/X0j5rLQcOHGDHjh2ktkqjiCTDWoiLS3mSk9Ik6fTptI0zMBAiIiB7ducWHn7
2021-03-19 17:21:00 +00:00
"text/plain": [
2021-03-26 20:01:05 +00:00
"<Figure size 1050x490 with 2 Axes>"
2021-03-19 17:21:00 +00:00
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
2021-03-26 20:01:05 +00:00
"single_result = random.choice([i for i in single_results if i[\"epochs\"] > 1])\n",
2021-03-19 17:21:00 +00:00
"single_history = single_result[\"history\"]\n",
"\n",
"fig, axes = plt.subplots(1, 2, figsize=(15,7))\n",
"fig.set_dpi(fig_dpi)\n",
"\n",
2021-03-26 20:01:05 +00:00
"################\n",
"## LOSS\n",
"################\n",
2021-03-19 17:21:00 +00:00
"ax = axes[0]\n",
"ax.set_title(\"Training vs Validation Loss\")\n",
"ax.plot(single_history['loss'], label=\"train\", lw=2)\n",
"ax.plot(single_history['val_loss'], label=\"validation\", lw=2, c=(1,0,0))\n",
2021-03-19 17:21:00 +00:00
"ax.set_xlabel(\"Epochs\")\n",
"ax.grid()\n",
2021-03-19 17:21:00 +00:00
"ax.legend()\n",
"\n",
2021-03-26 20:01:05 +00:00
"################\n",
"## ACCURACY\n",
"################\n",
2021-03-19 17:21:00 +00:00
"ax = axes[1]\n",
"ax.set_title(\"Training vs Validation Accuracy\")\n",
"ax.plot(single_history['accuracy'], label=\"train\", lw=2)\n",
"ax.plot(single_history['val_accuracy'], label=\"validation\", lw=2, c=(1,0,0))\n",
2021-03-19 17:21:00 +00:00
"ax.set_xlabel(\"Epochs\")\n",
"ax.set_ylim(0, 1)\n",
"ax.grid()\n",
2021-03-19 17:21:00 +00:00
"ax.legend()\n",
"\n",
"print(f\"Nodes: {single_result['nodes']}, Epochs: {single_result['epochs']}\")\n",
"# plt.tight_layout()\n",
2021-03-26 20:01:05 +00:00
"# plt.savefig('fig.png')\n",
2021-03-19 17:21:00 +00:00
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
2021-03-26 20:01:05 +00:00
"id": "0IQ7HfJCSDud",
"tags": [
"exp1"
]
2021-03-19 17:21:00 +00:00
},
"source": [
"### Accuracy Surface"
]
},
{
"cell_type": "code",
2021-03-26 20:01:05 +00:00
"execution_count": 69,
2021-03-19 17:21:00 +00:00
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 705
},
"executionInfo": {
"elapsed": 315450,
"status": "ok",
"timestamp": 1615991772345,
"user": {
"displayName": "Andy Pack",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjA4K4ZhdArHXAFbAGr4n0aCv2HmyUpx4cy6zcUq34=s64",
"userId": "16615063155528027547"
},
"user_tz": 0
},
"id": "X3MWHLxJElbc",
2021-03-26 20:01:05 +00:00
"outputId": "134671d0-bfd3-4ee6-aa02-1a2a5b23f3ca",
"tags": [
"exp1"
]
2021-03-19 17:21:00 +00:00
},
"outputs": [
{
"data": {
2021-03-26 20:01:05 +00:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAfUAAAFWCAYAAABwwARRAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAArEAAAKxAFmbYLUAADBx0lEQVR4nOydd3hcV5n/P2eapJlR771asiz3RkISkpCErA0YQkJICLBLCyVh+UEgECCh7bILIbALoSxkIUBCQm8hgVASlpLYVrcl925JttXbSKMp5/fHzB2PRjOjO9KMms/neeaxNbede+fe+z3ve97zvkJKiUKhUCgUiuWPYbEboFAoFAqFIj4oUVcoFAqFYoWgRF2hUCgUihWCEnWFQqFQKFYIStQVCoVCoVghKFFXKBQKhWKFsCiiLoR4VAjxCR3rXSOEOLoQbVIsLkKICiGEO+jvDiHE5f7/CyHED4QQQ0KIX/i/+7wQYkAI0bRYbU4EQoirhBBtOtf9mBDi4US3aT4IIe4QQvxmjts+L4R4UwLaFPUaCyH+RQjxx3nsXwohSvz/f0YI8YagZdPuWyHE+4QQF4QQfXM93nJGrxYsJAv1bhFCnBRCXBnv/c4q6v4DjwshbEHfWYUQo0KIk/FukF5CRWAe+5nXA6xIDFLKBinlC/4/rwKuAAqklDcJIcqA9wI1UsotC9UmPffcfIVISvlXKeUGnet+Tkp591yPtRBIKR+XUr56sdsRTOg1TtTL1X+sHVLKH/mPM+2+FUJYgC8AV0opcxJx/EgEdzwUF1msd0s80WupdwGvDfr7NUBP3FujmIYQwrTYbYgHcTiPMuC4lHIy6O/zUsqBRWjLvFjs4y8kl9K56iT0vs0DzFLKw7HuSF3b6cTxesz53bJkkFJG/QAngU8DTwd991vgk8DJoO8agL8CQ0ATcEXQsmrg78Ao8DPgR8AngpbfBRwB+oDvATb/99cARyO06zAggTH/pwww+tt6CjgPPASY/OtfBrQAI/g6KR8AqoBJwO3fR0eEY30V6Paf27NAWdCySv/16MfX0flX//cm4DP+tgwDz0c6J/95lARd73uBA9r1jfX4QIn/PG1B670VeDbC+YX97YA3Ac+FrPtd7bfzX3Pt2AeAfwq5b6adR5jjfsz/O530t9sdsv2VwB0hv9EngQnA6//7If/6twAdwADwayAv+Hr774s+4N+BFOBh/zU9C3w06LiPAl8B/oTvfn0WyIp0z4Wcz/2Ax9/eMf/5hTt+NfB//uvdDXwuaB/T7g//8d4DnPBvf1/Qsk8Bj4Sc5wP+a3ASuDFo3VXAC0R4BkPO43ngs/iel0F8z2RK0PJYrvW/AH8M2vYq/36HgL8A9UHLtgHt+O7db/qv0Zv8y14FHPK3/yRwW5h23wD8I+Qd8R3//22AA7AGX2PgEXz3ksP/m93hb/OfgW/429IJbI7yjnwbcAY4B9zJ9Of5eXzP0VVMv2+/C4xz8X76iX/9q/E9g0P+bav931fgewbeje/99RjR33efAh4HfuK/ZruBSv+yZ/3HHfcf+6oI98CngUb/NfgRkORfFvqbVjD92ZX43ukn/efxLuCl/us4CNwf8rx9Fd+9MAL8Bv/zNofrUQv8zb+f88CDEX6vZOBr/t/rNL5nxhDmN3oozLaPEuH94F/+OnzvvEHgKaA4aNkOfM/HAH79xOelgejvpFnv/WltjLYw6OV6Nb6bNtf/OYPPHaqJjgU4ju/FbAbe4G94pn/5HuBz/vVeC7i4KAyvB/YB5f4T+yHwxUgCGOlG8n/3YXwPYy6QATwH3O1f9iJwh///mcCmcDdohGPdBqT72/cd4Jf+7034btRP+W+UNGCLf9nH/eetdTZeFumcmCnqLwL5+F+kczz+n4Hbg47xLPAvYc4t4m8HpOK7ufOD1h3EJxAGoM2/nQm4HOgNWnfGeYQcdye+h7HW/1v9kTCiHuElMu0aAtvxPQjr/OfwBeCnQe
u68T24Zv81/Bq++8wOFOETqFcFPbTngPX+a/on4NOR7rkIL8M3hbQ19PjV+J4pk/9angZeG+HcJL4Xsx1Yi6/DoL3YPsV0UXcDH/Hv907gVNB+GvG9pM3Aq4Epoov6qaDf5jngs3O81oHfDsjGd/+8zr/8w/g68yZ899YZfB0YM/A+/740UT/Hxc5mAbAmTLvt+MQ5Bd99dwI45F92HfBihGt8Ev+9FnS/uYDb8T27/wb8JcK1WotPRF7iP+73CSPqEY5bwfR7vhTfM3SV/7jvA/YGrSvxdXaS/ceK9r77lP9aXOu/vt8HvhfunRPlHjiA772cge8ZeWuE5zH0PCTwJL4O1LX4hPJnQBaw2v93VdDzNhR0/X4I/GCO1+NJ4D5A4OvEbY9wbv/uP79MfO/nw/jfjaG/UZhtHyXy+2E1PgPuSiAJn/g/51+Wi0+UX4XvXv8Cvvtbe8dFeyfNeu9Pa2O0hcE3PPBl4G7/57/wWb6aqF9FiDWGzyq43X9TTALJQcv+xkVR/x3wxpCHRNtvxAsceiP5vzsIvDTo71dx0UL+K77eUVbINv/CLKIesn4d0Of//xX4XkSGMOsdAW4I8/2Mc2KmqN8eh+O/Dfi1///5+MQ5Lcx6EX87//9/DtwVdD2b/f+/DDgcst1PufhwzHYe3wU+FfT39cxd1L8JfCzo71R8L2WTf91xLlowAt/Lriho/buBR4Me2q8GLXsvFztRM+65MOf1PDNFPXD8CNv8BxE6sv57Y0vQ33u42AH4FNNFfVi7F/C9UCW+F3IFvhdpUtB+/kp0UQ/9bQ7Heq1DfzvgzQSJI76OYRe+DuHVwImgZQLfva0J4hngHYB9luu/F5+Q3ILP8tmP7/7/VJRrfJKZor4v6O81wFCE431Su3f8f9cwd1H/KPCtkP33+ter8O+3MGhZtPfdp4DfBC3bCbSG3Fezifo9QX9/AfivCM9j6HlIgjwb+Kzm1wX9vZuL9/CjYa7fpP/3j/V6/ADf/VkY6bz86x0Drg36+13A78P9RmG2fZTI74f7md5xsuN7NgqBf9Z+m6Dncwqfts72TtJ172ufWKLfHwfeiM899XjIsiL/gYM55f++EOiVF8dDCVm3DPgff2TzED7Bz42hXcGUAc8E7etxfONW4LsoDcBRIcTftMhqPQghPi6EOCqEGMH3Us32LyrBZw15w2xWgs9SmAtn43D8nwJXCyEy8XlDfielHAmzXrTfDny9Xy169w343HDgu9aV2rX2X+9/wvd7hz2PEApDjhvahlgoAz4e1I4z+HrBBf7l56SUWoBbLr5efWfQ+p/D9+LXOB/0fwe+h3M+BB8fIUSxEOIXQohzQohh4P9x8TcNh9729Gr3gpTS4f/Oju869EopnUHrRvttYOZvo/2usVzrUIrweSXwt9Hr3157T5wNWiZD2ngLcBNwVgjxOyFEfYRj/BVfR/UqfO+S4L//GuV8Q9F7zeN9H7855JmyAcX+5V4pZU/I+pHedzD/+3g+218I+v9EmL+D9xV6/ZLwWfWxXo978VnBrUKIFiFEpADNafch0993eoh0XULv7zF8Q5Pa/X0maJnDvwxmfyfpvfeBGKa0SSkb8V3oTCnl3pDF3fhcJcGU+b/vAXKEEMlBy4LX7QL+WUqZEfSxMTsyzHdd+Hpg2n7SpZRr/O0/JKW8Fd9N/yTwRJT9BBBCXI2vN7YTnwt8e9DiM0C5EEKE2fQMvt5kKOP4fkBt//lh1pFBy+d0fL+A/x64GZ/HJLQjphHttwPfuNBGIUQVPrftj/3fdwEHQn43u5TyP8KdRxh6Qo4b2oZY6MI3ThfclhQppSYKwe3oA5z43H/aumlSyh06jhP1XomyTuh3/4bPDV0rpUzH5/kKdw/Fi3NArj/aWmO2yOfQ30Z7ecZyrUPpxndvAb6piv59a++J0DYF/pZS7pZSvhLfi64N33h3OP4KvAyfBfRX/+dafM/N3yNso+d3jUS87+Nvh1xbq5RSa3
doOyO+7xLMtHcY0zvEcyH0+jnxDQHGdD2klD1Syrfh62B+CvhxiO5oTLsPmf6+mw+h97cNX2ddu79Lg5alcLEjH/W
2021-03-19 17:21:00 +00:00
"text/plain": [
2021-03-26 20:01:05 +00:00
"<Figure size 560x350 with 2 Axes>"
2021-03-19 17:21:00 +00:00
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"X, Y = np.meshgrid(epochs, hidden_nodes)\n",
"\n",
"shaped_result = np.reshape([r[\"results\"][1] for r in single_results], \n",
" (len(hidden_nodes), len(epochs)))\n",
"\n",
2021-03-22 20:49:29 +00:00
"fig = plt.figure(figsize=(8, 5))\n",
2021-03-19 17:21:00 +00:00
"fig.set_dpi(fig_dpi)\n",
"ax = plt.axes(projection='3d')\n",
"\n",
"surf = ax.plot_surface(X, Y, shaped_result, cmap='viridis')\n",
"ax.set_title('Model test accuracy over different training periods with different numbers of nodes')\n",
"ax.set_xlabel('Epochs')\n",
"ax.set_ylabel('Hidden Nodes')\n",
"ax.set_zlabel('Accuracy')\n",
"ax.view_init(30, -110)\n",
"ax.set_zlim([0, 1])\n",
"fig.colorbar(surf, shrink=0.3, aspect=6)\n",
"\n",
"plt.tight_layout()\n",
2021-03-26 20:01:05 +00:00
"# plt.savefig('fig.png')\n",
2021-03-19 17:21:00 +00:00
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
2021-03-26 20:01:05 +00:00
"id": "C793_RHvSGai",
"tags": [
"exp1"
]
2021-03-19 17:21:00 +00:00
},
"source": [
"### Error Rate Curves"
]
},
{
"cell_type": "code",
2021-03-26 20:01:05 +00:00
"execution_count": 70,
2021-03-19 17:21:00 +00:00
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 668
},
"executionInfo": {
"elapsed": 316211,
"status": "ok",
"timestamp": 1615991773109,
"user": {
"displayName": "Andy Pack",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjA4K4ZhdArHXAFbAGr4n0aCv2HmyUpx4cy6zcUq34=s64",
"userId": "16615063155528027547"
},
"user_tz": 0
},
"id": "tpClZMptrq-q",
2021-03-26 20:01:05 +00:00
"outputId": "f9fe93f9-7b67-4772-83e4-9e3567fd9318",
"tags": [
"exp1"
]
2021-03-19 17:21:00 +00:00
},
"outputs": [
{
"data": {
2021-03-26 20:01:05 +00:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAf4AAAFECAYAAADGPlw2AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAArEAAAKxAFmbYLUAABd10lEQVR4nO3dd3wc1bXA8d/Rqjd3reQq44otuWjdcaMZbExLKAmmOHl2Qgu9B/IgAfLoJZTQguPQITQbGwy4CBuMG+4FbFxwlauKJVntvj9mJK+kXWkl72ql1fl+PvuRdubOzLmzM3tm7szeEWMMSimllGoewoIdgFJKKaUajiZ+pZRSqhnRxK+UUko1I5r4lVJKqWZEE79SSinVjGjiV0oppZoRTfzNhIhcJCK7RCRPRJIaQTz3iMhzfpjPZBH5qh7TdRaRIye6/PryV/3ruMxkEflWRHJF5BYfys8Xkcvt/yvFW3V7EpFRIrLFfj84kPVozETkfhF5NdhxlBORaSJyr5dxFZ+vh3GjRGRVDfOtadp67ZMNRUTGisjmIC37nyJyh49l14nI8EDEER6ImdZGRPLc3sYB+UB5hwJ9jDE76jCvacBmY8yD/osweAJYn0eB3xlj5vh5vvVijHk4yMvfAbQsfy8i84FXjTFv+HtZIpKK9ZlW7G9Bqv8fgG3GmBF1ndBDvJW2JxF5B3jEGPPyiYfpu0B+bs2ZMeYboH+w42hMPO3HdWWMuboOZfvWdzm1CUriN8bEl/8vIoVAX2PMtmDE4k5Ewo0xJd7e13V6D+MFEGNM2QmGWh+dgfV1naiu66A5akLrqF7bgI/z0u1LNWtB/n6vG2NMUF9AIZBq/98aeAvIAn4GrnIr93tgO5ALbALGAlcBxcAxIA/4p5dlXASsAw4BnwJJ9vCxwGbgAeAA8BAwH/gbsMyebwTwK2ADcBiYCXSwp08FSoCrgV3AGx6WPQ14DpgLFAA97Lr8aNdlNTDWLuuxPkA6kGkvfzkwyB4eBjxrx34EWAq09RBDHlaLylFgiT1sFPCDPd0C4GS38ga4HtgKfONhft3seI4Au4GHa/h8J9qfVy6wDfiNPfx+rDM198/hL/ZntA04y20ePYDv7Hn8F3gXuNceNxn4yq2sx3XlIa5UoMT+/z6gFGtbzAPusYePsedxxN4uunn73IFWwOf2Z7EfeBmIssv/aK/TPPvV2b3+dpnatrGpwB5gL277hYd69QW+sWNeDpxiD/8nlbetgR6mHYy1PebY5TOByz18XpW2J6x9qwyr5W6vXaYz8Blw0K7X2W7L2QbcYQ/f5uM+Wm3b8Pa5eajXdcBP9mfzbyDObduZC7xi13klMKC2dWmPa8fx76oDwP+5rac3gfexttfvga72uBjgbbseh/Cwb9nlatqWvK4Pt31zER72FQ/LmY/13bfMrv+7VZfj47YRZ9f5CLAC63vUfZ+sbT/ydds2wDVY30sHgLvdxkUDz9vz2GGvnzB7nAN4Bmtb3ATcVaVuvn5neNqPp+Hj97tbPnD/7poLvGiv1/VARpX9ZKTbdM8CX9vznQO0dis7Fdhp13+qHWdHr+vS24iGelE58X8GPA5EAb2xkko/e8PKAXrY5bpwfGeqWJFe5j/EXiHpWEn8UeADt427xN5IIrB2zPlYO1Z3e2M6GcgGRtpxPQvMc9twDdaOEA3EeFj+NHsjdWG1sEQAE4BO9gY51f6wojzVB4jHSi6/tstfgLVhRwNnY+20ifa4DCC+hp2mo/1/G6yN/Fd2PLdjfTGGu5X9xJ6vpzp1w9qZw7E29B3ABV6Wu5fjyScZ61IOVE/8JcCd9jz/AGx3m8cyrC+oCOBcoAgPib+mdeUhrlTsxO/2JXi52/tOWF+6o+x5/QlY6u1zt9fpufY2koL1BXiTp2V5qH9vat/GngEigXFYXzoJHuoUiXXAfIO9ri7FSg6tattX7Gl/wfpijbDrW4KHxF91e/LwJRUGrLLjCA
eG2+vS6VZ2MeC0150v+6i3baPS5+ahXhcDa7C+M2KwkvXjbttOSZU6/2wvp7Z1OQd4FUiw5zvcbT3lA6fa85kO/NsedzXWQU2MPW6Ul5hr2pZqWx9LgIft+C/AOtirKfFvsNdNS6wDr9+5LWezj9vGo8BXWN8Xve2y5fukL/tRrdu22zb3PtZ+noaVO8oPIspP2lphJeQfgcn2uGuxEnAy0B5r2yyvW72/M070+x1r+ysGfmuXfRBY4GWfmmbPpx/Wd87XwAP2uHSsg6pBWNvW6zSVxG9/KEeBCLdxj2PtSHFYX4wXlK/AKiu+psT/T9zOBLB21GL7QxprLzO8ys7gfiR5H/aO67ahFGPtkKn2Ck6pYfnTgJdqWQd7gDRP9QF+A8ypUn6ZHfvpWEewQ7CamGpahnviv6LKBhaGtfEPdys7vA6f4d+xv0w9jPsFmEKVAxKqJ/5sjh+hx9oxtLTXcYH75451FuYp8XtdV7XtxFRP/HcBL1eZZr89nS+f+x85nrwqLctD/X3Zxtq4jc/C7czUbfgo7DNot2HfAb+tbV/BOpDb6vZe7M+uPol/GPBjlfl/wPEv4m3lMfm4j3rcNjx9bh7q9Tlwmdv7NI63Mkz2UucRNa1LoAPWwWech+XdD8xwez8BWGn//z9YZ+N9fd23PGxLXtcHVgIvxC1pAQtr+MznA7e6vX8UeNptOeXJsbZtYyuVz2of5Pg+6ct+VOu27bbNudzeL8E+4QC2AKdWWWdf2P/PK9/27PdT3OpW7+8Mt32qXt/v9va3xq1cH+CIl31qGvAPt3HXAh+7bXP/chvXjVoSf2O6q78z1pHMfhE5Yt9x/Ucg2RhzFGuHuwHYJyLvi0j7Osz3z27z/AXraDXZHr/XVL/GuNPt//ZYR4AAGGPysJqMypdfZozZU0sM7vNDRC4QkRVuMSVhHeV7i39MeVm7/MlAe2PM11hfmi8De0TkcRGJqCUWT3Uqw1ov7ut0Z9WJ3OLvICIficheEckGbqoh/ouAC4GdIvK5iJzspdx+Ow6MMfn2sHisz2m/MeaYD7F5XVfe6lKDzsAVVeYVh/WlD1U+dxFJEJHpIrJTRHKAJ/G+TqqqbRsrNcYcdCufj7VuPM3nlyrDtuNb/VNwW6/G+gbxug3UojPQtcq6O9teRrmdVcrXtI962zZ8jeUlt3kvxGqmrxaHW51TqHlddgSy7O8lT/a5/e/+Wf0H60ztIxHZLiJ3e5rYh23J2/pIsccVupWtWgdfY3VX27aRUmU57v/Xth/5um3XFm+lfYjK231t8Z3od8aJfL/7sv5rK5tcJYZa99vGlPh3YTXztDLGtLRfCca+C9IYM8sYcxrWTncMqzkLjv8aoKb53uc2z5bGmBhjTPnK8TS9+7DdWBsHACISh/Uh7vZx+ZXKiEgU1nW+P2Md6bbEOsoVL/PbhXXk6h5/nDHmLQBjzFPGmAFY1+DOAib5EE/VOglW09RutzI11etBrEsFPY0xLYCn3eKvxBjzvTHmHKxm3VVY17PqYi/QTkQi3YZ19FK2xnVVC0/r/ZUq84o1xizyUv4WrIQywBiTaL/39plWVds25qvdWJ+ju84+zmcP1dert/Vcm13AhirrLt4Y83e3MqZK+Zr20Zr4sv9fVXWbcBvvqc57qHld/oK1Tcb6EN/xQI0pMsb8xRjTE2tfvVFExnooWtO2VJM9QFsRiXYbVrUO9VHbtrGnynLc/69tP/KXSvsQlbf72uLz9TvD27ZWl+/3QNjL8QMp8GG/bTSJ3xizC6sp7UERiRWRcBHJEJE+IuIUkYkiEoOV9POxbuoBa6Wm1jDr14HrRaQ/gIi0FpHz6xDaB8AFIjLCTj4PAt/6cJbvTRTW9awsO54bqXwGUrU+M4GB9lFkuIjEiMjZItJCRAaJyGARCce64aOY4+ulJrOB/iJyvj3tzVjN6ct8rEOCvbw8EUkDvP2eN1JELhORRDu2PB/jq2CsX3tsAO4SkQgROQcY6q
W413Xlw6Kqrve3gIvF+j1zmH0WdlEN0ydgbZfZItIFqymu3AEgTES87ZD+2sa+BxCR6+36X4x19vK5D9N+B0SIyB/
2021-03-19 17:21:00 +00:00
"text/plain": [
2021-03-26 20:01:05 +00:00
"<Figure size 560x350 with 1 Axes>"
2021-03-19 17:21:00 +00:00
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
2021-03-22 20:49:29 +00:00
"fig = plt.figure(figsize=(8, 5))\n",
2021-03-19 17:21:00 +00:00
"fig.set_dpi(fig_dpi)\n",
"\n",
"# one curve per hidden-node count; error rate = 1 - test accuracy (results[1]),\n",
"# taken in the order the runs appear in single_results\n",
"for layer in hidden_nodes:\n",
"    plt.plot(epochs,\n",
"             1 - np.array([i[\"results\"][1] for i in single_results if i[\"nodes\"] == layer]),\n",
"             label=f'{layer} Nodes')\n",
"\n",
"plt.title(\"Test error rates for a single iteration of different epochs and hidden node training\")\n",
"plt.xlabel(\"Epochs\")\n",
"plt.ylabel(\"Error Rate\")\n",
"plt.grid()\n",
"plt.legend()\n",
"\n",
2021-03-26 20:01:05 +00:00
"# plt.savefig('fig.png')\n",
2021-03-19 17:21:00 +00:00
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
2021-03-26 20:01:05 +00:00
"id": "7mJaKjlCxEkt",
"tags": [
"exp1"
]
2021-03-19 17:21:00 +00:00
},
"source": [
"## Multiple Iterations\n",
"\n",
2021-03-22 20:49:29 +00:00
"Run multiple iterations of the epoch/hidden-node investigation and average the results\n",
"\n",
2021-03-26 20:01:05 +00:00
"### CSV Results\n",
2021-03-22 20:49:29 +00:00
"\n",
"| test | learning rate | momentum | batch size | hidden nodes | epochs |\n",
"| --- | --- | --- | --- | --- | --- |\n",
"|1|0.01|0|128|2, 8, 12, 16, 24, 32, 64, 128, 256|1, 2, 4, 8, 16, 32, 64, 100, 150, 200|\n",
"|2|0.5|0.1|128|2, 8, 12, 16, 24, 32, 64, 128|1, 2, 4, 8, 16, 32, 64, 100|\n",
"|3|0.2|0.05|128|2, 8, 12, 16, 24, 32, 64, 128|1, 2, 4, 8, 16, 32, 64, 100|\n",
"|4|0.08|0.04|128|2, 8, 12, 16, 24, 32, 64, 128|1, 2, 4, 8, 16, 32, 64, 100|\n",
"|5|0.08|0|128|2, 8, 12, 16, 24, 32, 64, 128|1, 2, 4, 8, 16, 32, 64, 100|\n",
"|6|0.06|0|128|1, 2, 3, 4, 5, 6, 7, 8|1, 2, 4, 8, 16, 32, 64, 100|\n",
"|7|0.06|0|35|2, 8, 12, 16, 24, 32, 64, 128|1, 2, 4, 8, 16, 32, 64, 100|\n",
"\n",
2021-03-26 20:01:05 +00:00
"### Pickle Results\n",
2021-03-22 20:49:29 +00:00
"\n",
"| test | learning rate | momentum | batch size | hidden nodes | epochs |\n",
"| --- | --- | --- | --- | --- | --- |\n",
"|1|0.01|0|128|2, 8, 12, 16, 24, 32, 64, 128, 256|1, 2, 4, 8, 16, 32, 64, 100, 150, 200|\n",
"|2|0.5|0.1|128|2, 8, 12, 16, 24, 32, 64, 128|1, 2, 4, 8, 16, 32, 64, 100|\n",
"|3|1|0.3|20|2, 8, 12, 16, 24, 32, 64, 128|1, 2, 4, 8, 16, 32, 64, 100|\n",
"|4|0.6|0.1|20|2, 8, 16, 24, 32|1, 2, 4, 8, 16, 32, 64, 100, 150, 200|\n",
"|5|0.05|0.01|20|2, 8, 16, 24, 32|1, 2, 4, 8, 16, 32, 64, 100, 150, 200|\n",
"|6|1.5|0.5|20|2, 8, 16, 24, 32|1, 2, 4, 8, 16, 32, 64, 100, 150, 200|"
2021-03-19 17:21:00 +00:00
]
},
{
"cell_type": "code",
2021-03-22 20:49:29 +00:00
"execution_count": 30,
2021-03-19 17:21:00 +00:00
"metadata": {
2021-03-26 20:01:05 +00:00
"id": "-lsKo4BCP3yw",
"tags": [
"exp1"
]
2021-03-19 17:21:00 +00:00
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Iteration 1/30\n",
"Iteration 2/30\n",
"Iteration 3/30\n",
"Iteration 4/30\n",
"Iteration 5/30\n",
"Iteration 6/30\n",
"Iteration 7/30\n",
"Iteration 8/30\n",
"Iteration 9/30\n",
"Iteration 10/30\n",
"Iteration 11/30\n",
"Iteration 12/30\n",
"Iteration 13/30\n",
"Iteration 14/30\n",
"Iteration 15/30\n",
"Iteration 16/30\n",
"Iteration 17/30\n",
"Iteration 18/30\n",
"Iteration 19/30\n",
"Iteration 20/30\n",
"Iteration 21/30\n",
"Iteration 22/30\n",
"Iteration 23/30\n",
"Iteration 24/30\n",
"Iteration 25/30\n",
"Iteration 26/30\n",
"Iteration 27/30\n",
"Iteration 28/30\n",
"Iteration 29/30\n",
"Iteration 30/30\n"
]
}
],
2021-03-19 17:21:00 +00:00
"source": [
"multi_param_results = list()\n",
2021-03-19 17:21:00 +00:00
"multi_iterations = 30\n",
"# hyper-parameters for this run; keep in sync with the results tables above\n",
"multi_learning_rate = 1.5\n",
"multi_momentum = 0.5\n",
"multi_batch_size = 20\n",
"\n",
"for i in range(multi_iterations):\n",
"    print(f\"Iteration {i+1}/{multi_iterations}\")\n",
"    # fresh stratified 50/50 split per iteration; seeding with the iteration number keeps\n",
"    # the splits different from each other but reproducible across re-runs\n",
"    data_train, data_test, labels_train, labels_test = train_test_split(data, labels, test_size=0.5, stratify=labels, random_state=i)\n",
"    multi_param_results.append(list(evaluate_parameters(dtrain=data_train, \n",
2021-03-19 17:21:00 +00:00
"                                                        dtest=data_test, \n",
"                                                        ltrain=labels_train, \n",
"                                                        ltest=labels_test,\n",
2021-03-22 20:49:29 +00:00
"                                                        optimizer=lambda: tf.keras.optimizers.SGD(learning_rate=multi_learning_rate, momentum=multi_momentum),\n",
2021-03-19 17:21:00 +00:00
"                                                        return_model=False,\n",
"                                                        print_params=False,\n",
2021-03-22 20:49:29 +00:00
"                                                        batch_size=multi_batch_size)))"
2021-03-19 17:21:00 +00:00
]
},
{
"cell_type": "markdown",
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp1"
]
},
2021-03-19 17:21:00 +00:00
"source": [
"### Accuracy Tensor\n",
"\n",
"Create a tensor for holding the accuracy results\n",
"\n",
2021-03-22 20:49:29 +00:00
"(Iterations x [Test/Train] x Number of nodes x Number of epochs)"
2021-03-19 17:21:00 +00:00
]
},
{
"cell_type": "code",
2021-03-22 20:49:29 +00:00
"execution_count": 173,
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp1"
]
},
2021-03-22 20:49:29 +00:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"30 Tests\n",
"Nodes: [2, 8, 16, 24, 32]\n",
"Epochs: [1, 2, 4, 8, 16, 32, 64, 100, 150, 200]\n",
"\n",
"Loss: categorical_crossentropy\n",
"LR: 0.05000000074505806\n",
"Momentum: 0.009999999776482582\n"
]
}
],
2021-03-19 17:21:00 +00:00
"source": [
"multi_param_epochs = sorted(list({i[\"epochs\"] for i in multi_param_results[0]}))\n",
"multi_param_nodes = sorted(list({i[\"nodes\"] for i in multi_param_results[0]}))\n",
"multi_param_iter = len(multi_param_results)\n",
"\n",
2021-03-22 20:49:29 +00:00
"# axis 1: index 0 = test accuracy (results[1]), index 1 = final training accuracy (history)\n",
"# NaN-initialised so a node/epoch combination missing from an iteration is caught below\n",
"accuracy_tensor = np.full((multi_param_iter, 2, len(multi_param_nodes), len(multi_param_epochs)), np.nan)\n",
"for iter_idx, iteration in enumerate(multi_param_results):\n",
2021-03-19 17:21:00 +00:00
"    for single_test in iteration:\n",
2021-03-26 20:01:05 +00:00
"        accuracy_tensor[iter_idx, :,\n",
"                        multi_param_nodes.index(single_test['nodes']), \n",
2021-03-26 20:01:05 +00:00
"                        multi_param_epochs.index(single_test['epochs'])] = [single_test[\"results\"][1], \n",
"                                                                            single_test[\"history\"][\"accuracy\"][-1]]\n",
2021-03-22 20:49:29 +00:00
"\n",
"# fail loudly instead of averaging over unfilled slots\n",
"assert not np.isnan(accuracy_tensor).any(), 'some iterations are missing node/epoch results'\n",
"\n",
"mean_param_accuracy = np.mean(accuracy_tensor, axis=0)\n",
"std_param_accuracy = np.std(accuracy_tensor, axis=0)\n",
"\n",
"print(f'{multi_param_iter} Tests')\n",
"print(f'Nodes: {multi_param_nodes}')\n",
"print(f'Epochs: {multi_param_epochs}')\n",
"print()\n",
"print(f'Loss: {multi_param_results[0][0][\"loss\"]}')\n",
2021-03-26 20:01:05 +00:00
"print(f'LR: {multi_param_results[0][0][\"optimizer\"][\"learning_rate\"]:.3}')\n",
"print(f'Momentum: {multi_param_results[0][0][\"optimizer\"][\"momentum\"]:.3}')"
2021-03-19 17:21:00 +00:00
]
},
{
"cell_type": "markdown",
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp1"
]
},
2021-03-19 17:21:00 +00:00
"source": [
"#### Export/Import Test Sets\n",
"\n",
"Export the results, and the mean and standard deviation of the accuracy, for later retrieval and visualisation"
]
},
{
"cell_type": "code",
2021-03-22 20:49:29 +00:00
"execution_count": 31,
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp1"
]
},
2021-03-19 17:21:00 +00:00
"outputs": [],
"source": [
"# the context manager flushes the pickle and closes the handle even if dumping fails\n",
"with open(\"result.p\", \"wb\") as result_file:\n",
"    pickle.dump(multi_param_results, result_file)"
]
},
{
2021-03-22 20:49:29 +00:00
"cell_type": "code",
"execution_count": 172,
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp1"
]
},
2021-03-22 20:49:29 +00:00
"outputs": [],
"source": [
2021-03-22 20:49:29 +00:00
"exp1_testname = 'exp1-test5'\n",
"# only load pickles produced by this notebook: unpickling untrusted files can execute code\n",
"with open(f\"results/{exp1_testname}.p\", \"rb\") as result_file:\n",
"    multi_param_results = pickle.load(result_file)"
]
},
{
"cell_type": "raw",
"metadata": {},
2021-03-19 17:21:00 +00:00
"source": [
2021-03-22 20:49:29 +00:00
"# np.savetxt only writes 1D/2D arrays; the (test/train, nodes, epochs) arrays are flattened to\n",
"# (2 * nodes) rows x epochs columns, test rows first - read back with .reshape(2, -1, n_epochs)\n",
"np.savetxt(\"exp1-mean.csv\", mean_param_accuracy.reshape(-1, mean_param_accuracy.shape[-1]), delimiter=',')\n",
"np.savetxt(\"exp1-std.csv\", std_param_accuracy.reshape(-1, std_param_accuracy.shape[-1]), delimiter=',')"
2021-03-19 17:21:00 +00:00
]
},
{
"cell_type": "raw",
"metadata": {},
"source": [
2021-03-22 20:49:29 +00:00
"# NOTE(review): these CSVs may predate the test/train axis - confirm their shape before using the plots below\n",
"mean_param_accuracy, std_param_accuracy = (np.loadtxt(f\"results/test1-exp1-{stat}.csv\", delimiter=',')\n",
"                                           for stat in (\"mean\", \"std\"))\n",
2021-03-19 17:21:00 +00:00
"# multi_iterations = 30"
]
},
{
"cell_type": "markdown",
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp1"
]
},
2021-03-19 17:21:00 +00:00
"source": [
2021-03-22 20:49:29 +00:00
"### Best Results"
2021-03-19 17:21:00 +00:00
]
},
{
"cell_type": "code",
2021-03-22 20:49:29 +00:00
"execution_count": 141,
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp1"
]
},
2021-03-22 20:49:29 +00:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Nodes: 24, Epochs: 8, 96.9% Accurate\n"
]
}
],
"source": [
"# argmax over the test-accuracy slice only; prefix 0 so the index addresses the full (2, nodes, epochs) array\n",
"best_param_accuracy_idx = (0, *np.unravel_index(np.argmax(mean_param_accuracy[0]), mean_param_accuracy[0].shape))\n",
"best_param_accuracy = mean_param_accuracy[best_param_accuracy_idx]\n",
"best_param_accuracy_nodes = multi_param_nodes[best_param_accuracy_idx[1]]\n",
"best_param_accuracy_epochs = multi_param_epochs[best_param_accuracy_idx[2]]\n",
"\n",
2021-03-26 20:01:05 +00:00
"# '.1f' = one decimal place; plain '.1' is one significant figure and prints e.g. '1e+02'\n",
"print(f'Nodes: {best_param_accuracy_nodes}, Epochs: {best_param_accuracy_epochs}, {best_param_accuracy * 100:.1f}% Accurate')"
2021-03-22 20:49:29 +00:00
]
},
{
"cell_type": "markdown",
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp1"
]
},
2021-03-22 20:49:29 +00:00
"source": [
"### Test Accuracy Surface"
]
},
{
"cell_type": "code",
"execution_count": 174,
2021-03-19 17:21:00 +00:00
"metadata": {
"executionInfo": {
"elapsed": 2653358,
"status": "aborted",
"timestamp": 1615994110345,
"user": {
"displayName": "Andy Pack",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjA4K4ZhdArHXAFbAGr4n0aCv2HmyUpx4cy6zcUq34=s64",
"userId": "16615063155528027547"
},
"user_tz": 0
},
2021-03-26 20:01:05 +00:00
"id": "ZGJVhz6iJU-7",
"tags": [
"exp1"
]
2021-03-19 17:21:00 +00:00
},
"outputs": [
{
"data": {
2021-03-22 20:49:29 +00:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAACRsAAAfcCAYAAAA/09pSAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAEzlAABM5QF1zvCVAAEAAElEQVR4nOzdeZysV10n/s/p9XbfLJAAIYR9JyH7JQkKCDLuK4IKAgIjLjM/Rx3HbURHHR1nXGYc910RARdUHHdxQ0XNnhCyAwlrQkJC1tt9ez2/P57qpu/t7qrq6qrq5b7fr1e/+t4+p85zuurpqrr9fO73W2qtAQAAAAAAAAAA6GRkpzcAAAAAAAAAAADsDcJGAAAAAAAAAABAV4SNAAAAAAAAAACArggbAQAAAAAAAAAAXRE2AgAAAAAAAAAAuiJsBAAAAAAAAAAAdEXYCAAAAAAAAAAA6IqwEQAAAAAAAAAA0BVhIwAAAAAAAAAAoCvCRgAAAAAAAAAAQFeEjQAAAAAAAAAAgK4IGwEAAAAAAAAAAF0RNgIAAAAAAAAAALoibAQAAAAAAAAAAHRF2AgAAAAAAAAAAOiKsBEAAAAAAAAAANAVYSMAAAAAAAAAAKArwkYAAAAAAAAAAEBXhI0AAAAAAAAAAICuCBsBAAAAAAAAAABdETYCAAAAAAAAAAC6ImwEAAAAAAAAAAB0RdgIAAAAAAAAAADoirARAAAAAAAAAADQFWEjAAAAAAAAAACgK8JGAAAAAAAAAABAV4SNAAAAAAAAAACArggbAQAAAAAAAAAAXRE2AgAAAAAAAAAAuiJsBAAAAAAAAAAAdEXYCAAAAAAAAAAA6IqwEQAAAAAAAAAA0BVhIwAAAAAAAAAAoCvCRgAAAAAAAAAAQFeEjQAAAAAAAAAAgK4IGwEAAAAAAAAAAF0RNgIAAAAAAAAAALoibAQAAAAAAAAAAHRF2AgAAAAAAAAAAOiKsBEAAAAAAAAAANAVYSMAAAAAAAAAAKArwkYAAAAAAAAAAEBXhI0AAAAAAAAAAICuCBsBAAAAAAAAAABdETYCAAAAAAAAAAC6ImwEAAAAAAAAAAB0RdgIAAAAAAAAAADoirARAAAAAAAAAADQFWEjAAAAAAAAAACgK8JGAAAAAAAAAABAV4SNAAAAAAAAAACArozt9AYAANjfSilPTPLvu5j61lrrBwa9HwC6V0oZSfJfk4x3mPqeWuvfDmFLAAAAAMAOEzYCAGDQvinNhepOHpHk2wa6EwC26t8l+ZEu5r03yXmD3QoAAAAAsBtoowYAwMC0KmJ8bZfTX11K6VQ5A4DhekOX884tpZw/0J0AAAAAALuCsBEAAIP0uUnO6HLuo5J86QD3AsAWlFIekeTLt3CTblpmAgAAAAB7nLARAACDtNULz91W0ABg8L4myYGtzC+lTA5qMwAAAADA7iBsBADAQJRSTsnWKxV9finl9EHsB4At22oAtJfnfQAAAABgjxE2AgBgUL4myVYrXIwmee0A9gLAFpRSnpvkUA83VaEOAAAAAPY5YSMAAAZlqy3UVrhQDbDzen0O/9xSyuP6uhMAAAAAYFcRNgIAoO9KKecmOb/Hmz+7lPL8fu4HgO6VUsaTvKbHm48meV0ftwMAAAAA7DLCRgAADEKnihh/12FcdSOAnfPFSR7dZvwfktQ246/v624AAAAAgF1F2AgAgL4qpUwkeXWbKTNJXtv6vJmvLqVM9XVjAHSrU+Dzh5O8p834M0spn9nH/QAAAAAAu4iwEQAA/fYlSU5tM/7OWuudSd7ZZs5JSV7e110B0FEp5bQkX9BmykeTvDvJWzospUIdAAAAAOxTwkYAAPRbpxZqbznm82ZcqAYYvq9NMtZm/K211prkHUmOtJn3VaWU6b7uDAAAAADYFYSNAADom1LK6Uk+r82UO5L8bevPf9v6+2ZeUkp5cp+2BkB3OgU935IktdYHkvy/NvNOTPKKfm0KAAAAAN
g9hI0AAOinr00y2mb8bbXW5SRpfX5bm7klyev7tzVgM6XxxFLKS0spX1lKeX0p5ZWllC8spZxXSjmw03tk8EopFyd5TpspV9Rab17z904V6jpVugMAAAAA9qDSVD8HAIDtK6XcnORZbaacXWu9fs385yZ5X5v5H0ry1OpN67aVUh6b5AVJnp/kmUmemuTRSQ4mOZBkNsnhJB9L8sEkVyX55ySX1VqXdmLP3SqlTCZ5YZKXpAlKPDPJo5KckOZ7ezjJ3Wm+r2+stX6kx+OUJM9tfTy79fG0JI9IU8XlpCTj+fR9eUea+/P6JFcneXet9Z5ejj0IrQDRK5J8eZKXpvk+NrOc5IYkf5bkHbXWa9qse2GSR24yfF+t9ape9tuLVnW0z0xySZKnJ3lKklPTnPcTaR6rh5J8JM35cUWSf6q1Xj2sPe4mpZRfTvINbab8p1rrz62ZP5rk40lO22R+TfL0Wutt/dvl9pRSxpNcmOTiJOcneXKSJ6Y5/6fTtJB7qPXxQJLbktyY5KY0z4e3DH3TSUopU0kuSrPvc9Ls+wlpnnem0wR0V/Z9f5L3p9nzTUn+pdfnPTprvb5+XprHZuX19eQ0r0FJcx59NM1rwH/ZxnEOJjmU5nVu5TXojDTnwEmt4y2leV67P83P5u1Jrk3z3PavtdbFXo8/aKWUM9Pch4fSvLY+Kc1r+XSSySQzac7vB9O8tt6U5mfz2jRByC29VymlvCTJZ7WZ8he11su39l30rpTy1DSh/c3cUGt9x7D2AwAAAJ0IGwEA0BellM9I8i9tplxbaz1/g9tdneaC72ZeWmv9+x73dCjNBbbN/G2t9XN6WbvH/Tw9zQXgzVxVaz3Ux+OdmOR1SV6VJmRUeljmE0l+L8lP11pv79fe1iql/GCSH9hsvNa64b5LKecm+U9pvr/pLg93fq312i3s7VlJPidNkOmz0gRVerWc5LIkv5nk7bXWw9tYq2et0MJ/TvJfkpzS4zJ/l+QHaq3rfuZLKe/O5hdw/7HW+uIej9mVUsqjk3xdklcmObfHZT6U5K1Jfq7Weleftrartc6LO9OEJDaykORxxwbmSin/J835tJkfrrX+t/7ssjetoODnJnl1ki9J+2BdJx9O8tdJ/iTJXw0yjNkKc70szXPc56f757mN3Jxm3++stf5jj/t5fZrnr808pdb6oV7W7uLYb07zeraRD9dan7yNtT+UJtiykd+qtb5+g9uMJ3l5km9OE2jsxntrredtYV9jaZ5LX9L6eF6aQGuvHkwTGP2lWus/b2OdvmmFzv99kq/I5o9BN+5P0573L9IEYh/u4tid3iO+q9barjVwX5VSfibNe5rNvK7W2qmaHAAAAAyNNmoAAPRLp3Y5m10g6XTh5A097CVJUmu9Ms3/et/MZ5dSzuh1/R60+x/rSfLmfhyklHKwlPLf01Rr+dkkn5HegkZJ8tgk35rk1lLKr5dSHtWPPW5HKeW0Uspvp6lm8HXZ3gX4jdZ/cinle0op16a5QP+zaS6EbidolDT//np+kl9J8uFSyre2LlgPTSnlBWl+Jv5Heg8aJU0lpH8upfxsK6Sy40opj2pdrP1wkv+Z3oNGSVM15vuSfKiU8hOtiiL73cuzedAoSf5yk8pcnZ7DX1dK2ZHfPbTaA746TXWxv0ry2mwvaJQ0gYhvSBPa+FAp5QdKKZtVdupJKWWslPLNaaptvSPN8892n+eenea5/N2llJtLKf+llNLu8WYTrYo41yX5nXQfNOp27ZFSyktaVcY+kSZA86Y0r+Pbfb04KcnXJPmnUsq/tkLiO6KUckkp5V1pqlv+52wvaJQ0P9evSPIbSe4spfxyKeWcdjdovUdsV2nvc1rV8QauVWnwNW2m3JfmuQAAAAB2DWEjAAC2rZQyneSr2kxZTPL2Tcbe3hrfzFeUUk7qdW9JfqvN2EjaX9zpm1Zljde2mTKf5sLldo/zxWlai3x/tn9Rfa2xNIGym0spX9rHdbeklPKiNO
28BvK4tYIJt2f7YZVOTk3yf5P8WynlGQM8zqpSyrckeXeaIE1flkxT1eNvSynbDWJtbyOlvCHJLWmqQvQz/HQgyXc
2021-03-19 17:21:00 +00:00
"text/plain": [
"<Figure size 3000x2000 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# mesh rows follow hidden nodes and columns follow epochs, matching mean_param_accuracy[0] (test accuracy)\n",
"X, Y = np.meshgrid(multi_param_epochs, multi_param_nodes)\n",
2021-03-19 17:21:00 +00:00
"\n",
2021-03-22 20:49:29 +00:00
"fig = plt.figure()\n",
2021-03-19 17:21:00 +00:00
"fig.set_dpi(fig_dpi)\n",
"ax = fig.add_subplot(projection='3d')\n",
"ax.view_init(30, -110)\n",
"\n",
2021-03-22 20:49:29 +00:00
"surf = ax.plot_surface(X, Y, mean_param_accuracy[0], cmap='coolwarm')\n",
"ax.set(title='Average Accuracy', xlabel='Epochs', ylabel='Hidden Nodes', zlabel='Accuracy', zlim=(0, 1))\n",
2021-03-19 17:21:00 +00:00
"fig.colorbar(surf, shrink=0.3, aspect=6)\n",
"\n",
2021-03-22 20:49:29 +00:00
"plt.tight_layout()\n",
2021-03-26 20:01:05 +00:00
"# plt.savefig(f'graphs/{exp1_testname}-acc-surf.png')\n",
2021-03-19 17:21:00 +00:00
"plt.show()"
]
},
{
"cell_type": "markdown",
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp1"
]
},
2021-03-19 17:21:00 +00:00
"source": [
2021-03-22 20:49:29 +00:00
"### Test Error Rate Curves"
2021-03-19 17:21:00 +00:00
]
},
{
"cell_type": "code",
2021-03-22 20:49:29 +00:00
"execution_count": 175,
2021-03-19 17:21:00 +00:00
"metadata": {
"executionInfo": {
"elapsed": 2653349,
"status": "aborted",
"timestamp": 1615994110347,
"user": {
"displayName": "Andy Pack",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjA4K4ZhdArHXAFbAGr4n0aCv2HmyUpx4cy6zcUq34=s64",
"userId": "16615063155528027547"
},
"user_tz": 0
},
2021-03-26 20:01:05 +00:00
"id": "Jrn3hKQAlGcc",
"tags": [
"exp1"
]
2021-03-19 17:21:00 +00:00
},
"outputs": [
{
"data": {
2021-03-22 20:49:29 +00:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAC4YAAAeeCAYAAAAResutAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAEzlAABM5QF1zvCVAAEAAElEQVR4nOzdd/gdRb348fcnjZKEHoq0CAKCgIIIolItiFwVsIGoF+tPuXbUq2ABC1bs3XsVOxbEjkhVRBGQ3qRopEknlRSSzO+PObmE8D27p+ye9n2/nmcfHrL7nZk9Z8/szOxnZyKlhCRJkiRJkiRJkiRJkiRJkiRpeE3odwEkSZIkSZIkSZIkSZIkSZIkSd0xMFySJEmSJEmSJEmSJEmSJEmShpyB4ZIkSZIkSZIkSZIkSZIkSZI05AwMlyRJkiRJkiRJkiRJkiRJkqQhZ2C4JEmSJEmSJEmSJEmSJEmSJA05A8MlSZIkSZIkSZIkSZIkSZIkacgZGC5JkiRJkiRJkiRJkiRJkiRJQ87AcEmSJEmSJEmSJEmSJEmSJEkacgaGS5IkSZIkSZIkSZIkSZIkSdKQMzBckiRJkiRJkiRJkiRJkiRJkoacgeGSJEmSJEmSJEmSJEmSJEmSNOQMDJckSZIkSZIkSZIkSZIkSZKkIWdguCRJkiRJkiRJkiRJkiRJkiQNOQPDJUmSJEmSJEmSJEmSJEmSJGnIGRguSZIkSZIkSZIkSZIkSZIkSUPOwHBJkiRJkiRJkiRJkiRJkiRJGnIGhkuSJEmSJEmSJEmSJEmSJEnSkDMwXJIkSZIkSZIkSZIkSZIkSZKGnIHhkiRJkiRJkiRJkiRJkiRJkjTkDAyXJEmSJEmSJEmSJEmSJEmSpCFnYLgkSZIkSZIkSZIkSZIkSZIkDTkDwyVJkiRJkiRJkiRJkiRJkiRpyBkYLkmSJEmSJEmSJEmSJEmSJElDzsBwSZIkSZIkSZIkSZIkSZIkSRpyBoZLkiRJkiRJkiRJkiRJkiRJ0pAzMFySJEmSJEmSJEmSJEmSJEmShpyB4ZIkSZIkSZIkSZIkSZIkSZI05AwMlyRJkiRJkiRJkiRJkiRJkqQhZ2C4JEmSJEmSJEmSJEmSJEmSJA05A8MlSZIkSZIkSZIkSZIkSZIkacgZGC5JkiRJkiRJkiRJkiRJkiRJQ87AcEmSJEmSJEmSJEmSJEmSJEkacgaGS5IkSZIkSZIkSZIkSZIkSdKQMzBckiRJkiRJkiRJkiRJkiRJkoacgeGSJEmSJEmSJEmSJEmSJEmSNOQMDJckSZIkSZIkSZIkSZIkSZKkIWdguCRJkiRJkiRJkiRJkiRJkiQNOQPDJUmSJEmSJEmSJEmSJEmSJGnIGRguSZIkSZIkSZIkSZIkSZIkSUPOwHBJkiRJkiRJkiRJkiRJkiRJGnIGhkuSJEmSJEmSJEmSJEmSJEnSkDMwXJIkSZIkSZIkSZIkSZIkSZKGnIHhkiRJkiRJkiRJkiRJkiRJkjTkDAyXJEmSJEmSJEmSJEmSJEmSpCFnYLgkSZIkSZIkSZIkSZIkSZIkDTkDwyVJkiRJkiRJkiRJkiRJkiRpyBkYLkmSJEmSJEmSJEmSJEmSJElDzsBwSZIkSZIkSZIkSZIkSZIkSRpyBoZLkiRJkiRJkiRJkiRJkiRJ0pAzMFySJEmSJGlERcQuEXF0RHwvIi6KiFsjYkFELI2I1Gzrd7k1WiJiVsH1dtKgpt1i/htHxBER8cWIODsiboiI2RGxuOg3FhFHdpHn1Ig4KCI+FhG/joirI+KeiFgUEcv7+XlIkjRMIuLcgvvmuf0un6R6RcSRJW32mf0uY5mImFlXv6Of+t3Pq0NEnFRwTrP6Xb5OjOI5SRoOtuMlqdykfhdAkiRJkiRJ1YmINYHXA0cBW/e5ONJIioj9gXcBz6RHk29ExLbAfwMvAab2Ik9JkiRJkiRJkjRcDAyXJEmSJEkaERGxL3ASsGVfCyKNqI
hYH/g6cGgP85wIfAB4D47nSpIkSZIkSZKkAj5IkCSpRhGxCXBQv8tRoZNTSvP7XQhJkiQ9UkS8DPgWjvdItYiIzYBzgMf0MM/JwCnAc3uVpyRJkiRJkiRJGl4+KJQkqV7bAd/odyEqdCZgYLgkSdKAiYh9MChcqk1ErA78lh4GhTd8CYPCJUmSJEmSJElSiyb0uwCSJEmCiDgpIlKTbVa/yydJal1E7FtQp6eI2LffZdRoiYjVMChcqtv7gZ16mWFEPAt4bS/zlMaDknbacf0unyRJkiRJkiR1wweGkiRJkiRJw+01wKNLjknAX4GLgVuBOcDSmssljYSImAG8pYVDbwH+BFxL/o09UHDs+S2k95EWjpkLnAtcCdxLXuEpNTn2+hbSkyRJkiRJkiRJQ8zAcEmSJEmSpOH2+pL9ZwGvTynd2IvCSCPoFcCaBfvvBt4AnJpSWl5FhhHxRGC3gkOWAh8APptSKgpAlyRJkiRJkiRJ44iB4ZIkSZIkSUMqIrYBdiw45ELg2SklZweXOndowb4HgQNTSn/rYZ4AR6eUPl9xnpIkSZIkSZIkachN6HcBJEmSJEmS1LF9S/a/36BwqXMRsQawe8Ehp9QQFA7Fv+3bgC/WkKckSZIkSZIkSRpyzhguSVKNUkrnAlF1uhFxEvCfBYd8O6V0ZNX5SpIkaeA8vmDffOCMXhVE6oeU0syas3gcxWOoP6s6w4gIYOeCQ36RUlpedb6SJEmSNAh60M+TJEmSRpozhkuSJEmSJA2vrQr2XWnwqNS1ot8YwOU15DkDmNbjPCVJkiRJkiRJ0ggwMFySJEmSJGl4bVSw746elUIaXUW/Majnd9aPPCVJkiRJkiRJ0ggwMFySJEmSJGl4Fc0qPL9npZBGV9FvDOr5nfUjT0mSJEmSJEmSNAIMDJckSZIkSRpeUwr2Le9ZKaTRVfQbI6VUx++sME/8bUuSJEmSJEmSpCYMDJckSZIkSRpe0e8CSCOuH78xf9eSJEmSJEmSJKkjk/pdAEmSpH6KiAnA1sD6wNrAWsBUYDHwQGObD9wC3F7TjICSJKlL3tMlSZJUJiImAo8BHg1MJ7cXFwJzgBuAf9pO1HgQEesCM8n9phX9pwk81Hd6ALgHmJVSmt+nYkqSJEmqWURMIz9bWdEvWAuYTO4rr+gb3Av8K6U0u0/FlNQmA8MlSVLHGp2EPYGnAo8jP1TblPxQbU3gQWABcCfwD+BK4E/AH/v1QCEiJgHPBp4D7ALs3ChrK5ZExC3A9cAFje2vKaU5dZRVWUSsDuxOvs52Jl9nmwPTyN/dMvJ1dg/5OruGfJ2dm1K6vx9lbkdEbA88C3gCsB353FY8nF4C3A/MAr6RUvp2DfkH+bfwTGCnRhkeRf58Vzwgvw/4J/DxlNJpXea3GbA38BRgG/L3uX4jr0nk73IecDNwE/A34Dzg0mF4ON94sPpM8vltRw44WIf8nU4A5gK3AVellI7oUzErFREbAwcAewDbAluRB4+mNQ6ZQw7EPTeldHQX+UwFdgO2Bx7b2DbloUGqaeT6YCEwm/w5/xO4DLgI+HNKaWmn+Y867+mDfU+PiNcU7J5WsG+bkr99hJTS/7RzfJGI2BHYi3wf3xrYkocC9hP5mppN/q3eSP4e/phSuqmqMtQpImaS678nkuv8meT6fkV9NBv4F3BKSumTfSlkByJiNfK9+gDg8eS6fV3y7+sB8n36FuBq4ELg1yml2/pT2mpExN7k8xzLriV/29ZvDPhDSumGiNgG2KfJMduVpHFQRDymjTyvTyn9sY3jCzXq313J94xdyPf+LXioDZvI18qKNuR1wPnAOSmlO6oqR116/dseB/2drYDnkdtxjwM24aEHjCuCDv8JXAL8ETgjpbS4P6VVo23/NHK7cEXbfgb5t706ua29ALiVh/fX/ppSWtaPMrciIrYEXgI8F3gSsFrB4fMj4mzg58DJKaWF9ZewXETMIN+fn0qum7YCNiB/N1N46CXIW8
l1xaXkuuLClNKD/ShzkYiYTK5n9yDfS2aS7yXrkOu+SeQ2xzxyn3JF/Xct+Xr7e88LPYZGP+rZ5Hv6jjwUTLHiBdW
2021-03-19 17:21:00 +00:00
"text/plain": [
"<Figure size 3000x2000 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
2021-03-22 20:49:29 +00:00
"# set to True to draw +/- one standard deviation across iterations as error bars\n",
"show_error_bars = False\n",
"\n",
"fig = plt.figure()\n",
2021-03-19 17:21:00 +00:00
"fig.set_dpi(fig_dpi)\n",
"\n",
2021-03-22 20:49:29 +00:00
"# each row of mean_param_accuracy[0] is one hidden-node count (test accuracy per epoch)\n",
"for idx, layer in enumerate(mean_param_accuracy[0, :, :]):\n",
"    if show_error_bars:\n",
"        plt.errorbar(multi_param_epochs, 1 - layer, yerr=std_param_accuracy[0, idx, :], label=f'{multi_param_nodes[idx]} Nodes', lw=2, capsize=3)\n",
"    else:\n",
"        plt.plot(multi_param_epochs, 1 - layer, '-', label=f'{multi_param_nodes[idx]} Nodes', lw=2)\n",
2021-03-19 17:21:00 +00:00
"\n",
"plt.legend()\n",
"plt.grid()\n",
2021-03-22 20:49:29 +00:00
"plt.title(\"Test error rates for different epochs and hidden nodes\")\n",
2021-03-19 17:21:00 +00:00
"plt.xlabel(\"Epochs\")\n",
"plt.ylabel(\"Error Rate\")\n",
2021-03-22 20:49:29 +00:00
"plt.ylim(0)\n",
"\n",
"plt.tight_layout()\n",
2021-03-26 20:01:05 +00:00
"# plt.savefig(f'graphs/{exp1_testname}-error-rate-curves.png')\n",
2021-03-19 17:21:00 +00:00
"plt.show()"
]
},
2021-03-22 20:49:29 +00:00
{
"cell_type": "markdown",
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp1"
]
},
2021-03-22 20:49:29 +00:00
"source": [
"### Test/Train Error Over Nodes"
]
},
{
"cell_type": "code",
"execution_count": 169,
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp1"
]
},
2021-03-22 20:49:29 +00:00
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAD20AAA9uCAYAAABI3BnXAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAEzlAABM5QF1zvCVAAEAAElEQVR4nOzdd5jlVX0/8Pdnd1nKLqhUCyqCvQZUFCt2I0ZFibFG7FEw9h4jxJ+iMYk9sYOJwViCBVsUG0VRI1jQWAELIArS27K75/fHnVVYZu69M7fNnXm9nuc+6pxzz/nM7Mz9fj3ve86t1loAAAAAAAAAAAAAAAAAAABYmBWTLgAAAAAAAAAAAAAAAAAAAGCa2bQNAAAAAAAAAAAAAAAAAAAwAJu2AQAAAAAAAAAAAAAAAAAABmDTNgAAAAAAAAAAAAAAAAAAwABs2gYAAAAAAAAAAAAAAAAAABiATdsAAAAAAAAAAAAAAAAAAAADsGkbAAAAAAAAAAAAAAAAAABgADZtAwAAAAAAAAAAAAAAAAAADMCmbQAAAAAAAAAAAAAAAAAAgAHYtA0AAAAAAAAAAAAAAAAAADAAm7YBAAAAAAAAAAAAAAAAAAAGYNM2AAAAAAAAAAAAAAAAAADAAGzaBgAAAAAAAAAAAAAAAAAAGIBN2wAAAAAAAAAAAAAAAAAAAAOwaRsAAAAAAAAAAAAAAAAAAGAANm0DAAAAAAAAAAAAAAAAAAAMwKZtAAAAAAAAAAAAAAAAAACAAdi0DQAAAAAAAAAAAAAAAAAAMACbtgEAAAAAAAAAAAAAAAAAAAZg0zYAAAAAAAAAAAAAAAAAAMAAbNoGAAAAAAAAAAAAAAAAAAAYgE3bAAAAAAAAAAAAAAAAAAAAA7BpGwAAAAAAAAAAAAAAAAAAYAA2bQMAAAAAAAAAAAAAAAAAAAzApm0AAAAAAAAAAAAAAAAAAIAB2LQNAAAAAAAAAAAAAAAAAAAwAJu2AQAAAAAAAAAAAAAAAAAABmDTNgAAAAAAAAAAAAAAAAAAwABs2gYAAAAAAAAAAAAAAAAAABiATdsAAAAAAAAAAAAAAAAAAAADsGkbAAAAAAAAAAAAAAAAAABgADZtAwAAAAAAAAAAAAAAAAAADMCmbQAAAAAAAAAAAAAAAAAAgAHYtA0AAAAAAAAAAAAAAAAAADAAm7YBAAAAAAAAAAAAAAAAAAAGYNM2AAAAAAAAAAAAAAAAAADAAGzaBgAAAAAAAAAAAAAAAAAAGIBN2wAAAAAAAAAAAAAAAAAAAAOwaRsAAAAAAAAAAAAAAAAAAGAANm0DAAAAAAAAAAAAAAAAAAAMwKZtAAAAAAAAAAAAAAAAAACAAdi0DQAAAAAAAAAAAAAAAAAAMACbtgEAAAAAAAAAAAAAAAAAAAZg0zYAAAAAAAAAAAAAAAAAAMAAbNoGAAAAAAAAAAAAAAAAAAAYgE3bAAAAAAAAAAAAAAAAAAAAA7BpGwAAAAAAAAAAAAAAAAAAYAA2bQMAAAAAAAAAAAAAAAAAAAzApm0AAAAAAAAAAAAAAAAAAIAB2LQNAAAAAMDIVdWBVdW6PHabdI0AAAAAAAAAjJ78GABYqlZNugAAAAAAAAAAAAAAAIBpUVVbJblJkusm2THJVunsz7giyaVJfp/kzCS/bK1tnFSdAADAeNm0DQAAAABXMROu32bmcduZ/7xekmvPPLZLcmWSy5Kcm07Q/vMk30/yzSTfbq1tGHfdAAAAAAAAAIxGVV0nyV8kuW+SuyXZI8mKPp56WVX9MMlxSY5J8sXW2pUjKxQAAJgom7YBgGWhqk5PcuNJ1zGLt7TWnjfpImASqqoNcbgrZh4XJfldkrOS/DTJj5KcmOSU1tow5wOGrKr2S2dT7GJwVGvtD5MugvGpqi2S7J3kfukE7PskWd3jaSvTOSn9OklumuReV2m7oKqOTvKBJF9aategqjowyeF9dj83ye6ttQtHV1FHH3V9oL
V24KjrAAAAAGDxkx/D4iM/Bq5KfsxiUlX7JHlhOhu2e+XIs9k6yZ1mHs9Pcm5V/XuSN7XWfj20QidEfgwAAFdn0zYAALAUbDnz2C7JDZLsmeQhV2k/r6o+leRDSb4ggIdF6cVJ7j3pImb8bxKh+xI382naD0ny6CT7JVk7xOGvleQJM48fVNVrWmsfHeL402SHdN7A8OpJFwIAAAAAwLIhP4bpJz9m4qrqpknenE6ePEw7pLN5++CqelOS17TWLh7yHIuV/BgAgCVvxaQLAAAAGIPrJHlSks8n+b+q+uuqqgnXNBWqat+qal0e+066RoD5qKp9qurIJL9P8t9J/irD3bC9udsl+UhVfamqdh/hPIvZC6pqp0kXAQAAAAAAM+THCyQ/BpaLmU9p/m6Gv2H7qrZI8pIkJ1XVn41wnsVGfgwAwJJm0zYAALDc3CLJB5J8var2mHQxAIzdk5I8NqPdqD2b+6YTtj98zPMuBmuTvGLSRQAAAAAAwCzkxwBcTVW9LsnhSdaMacqbJTm+qh44pvkmTX4MAMCSZtM2AACwXN01nc1zyyXwAGDyrpXkqKp62qQLmYBnVdUNJ10EAAAAAADMQX4MQKrq0CQvn8DUa5J8qqruMYG5J0F+DADAkrVq0gUAAABM0HbpBB4Pa619YdLFALAorUvyvSQ/TXJakvOSXJpOaL5Dkl2T3CvJjfscb0WSd1fVxa21/xp+uYvWlkkOSfLUCdcBAAAAAABzkR8DLGNV9Ygkfz+Pp3wtyReTnJDkl0n+kE6WfK0k2ye5bZK7J3l4kj36GG/LJB+vqtu21s6eRx3TSH4MAMCSZdM2AEDH4Um+PoF5T5nAnDAt/inJT/roV+mE59dOskuSOyW5Xfr//ztbJvlwVd25tfbzBdQJwNLznSSfSfL5JN9pra3r9YSq2j3JwUmenmRtr+5Jjqiqn7TWTh602CnypKp6Y2vtx5MuBAAAAAB6kB/D4iM/BmBkqupaSd7VZ/dPJvn71tr352g/Z+bx0yRHVdWLkzwsyT8muVmPsXdM8q9JHtVnLdNMfgwAwJJk0zYAQMexrbUjJl0EcDWfaa19dSFPrKptkvxVkuck2bOPp1w7yfuqat/WWlvInMDIHdpaO2TSRbCk/TbJe5P8R2vtp/N9cmvt1CQvqKrXJXlfOqF7N1ums3H7Tq21K+dd7XRameQ1Sf5y0oUAAAAAQA/yY1h85MfAVcmPGbYXJtm5R5/1Sf62tfZv8xm4tbYxySeq6pgk707y2B5PeWRV3bW1duJ85plC8mMAAJakFZMuAAAAYNhaa5e21g5vre2V5KAkl/XxtHtFCACwHJ2cTih+o9baqxayYfuqWmvntNYenuSVfXS/fZKnDDLfFHpUVd1x0kUAAAAAALB8yI8B6KaqtkjyrD66HjTfDdtX1Vq7OMkTkny8j+7PW+g8U0Z+DADAkmPTNgAAsKS11v41yb2TXNpH95eMuBwAFo//TfIXrbW9Wmv/NexPu26tvS7JP/bR9eVVtZzW6CrJ6yZdBAAAAAAAy5P8GIBZPCjJjj36/Gdr7d2DTjTzqduPT/LrHl0fXlVrBp1vCsiPAQBYcpbTG0IBAIBlqrX27XQCj17uWFW3HXU9AEzUqel8svberbVPj3iulyX5ao8+N05y3xHXMW5fTbKuS/sDq2rfsVQCAAAAAACbkR8DsJn79Wi/MsmrhjVZa+2yJK/u0W2rJPcY1pwT9tXIjwEAWEZs2gYAAJaF1tonkny+j677j7gUACaotfaPM5+s3cYwV0vywiS95nrEqGsZs18meWePPk5LBwAAAABgYuTHAFzFXj3av9xaO23Ic/5nkkt69LnTkOecFPkxAADLik3bAADAcvIPffRZKqfUArAItNZOSvLlHt32HUMp4/baJBd3ad+nqh42rmIAAAAAAGAW8mMAkmSPHu1fGPaErbV16XwCdTc3Hfa8EyQ/BgBg2Vg16QIAAADG6MQkv0+yU5c+dxxTLT
CnqlqRTii4Q5JrJdkuyZokVyS5dOZxcZJfJzmztbZxQqUC/fl0kvt1ab9FVW3RWrtyXAWNWmvtd1X15iR/16Xba6v
"text/plain": [
"<Figure size 4000x4000 with 6 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# set to True to draw +/- one standard deviation across iterations as error bars\n",
"show_error_bars = False\n",
"\n",
"# two panels per row; a spare panel is left over when the node count is odd\n",
"n_rows = math.ceil(len(multi_param_nodes) / 2)\n",
"fig, axes = plt.subplots(n_rows, 2, figsize=(8, 8 * n_rows / 3), squeeze=False)\n",
"fig.set_dpi(fig_dpi)\n",
"\n",
"# shared y-limit: worst mean error (test or train) plus its std, padded and rounded to 0.1\n",
"y_max = np.round(np.max(1 - mean_param_accuracy + std_param_accuracy) + 0.05, 1)\n",
"\n",
"for idx, (nodes, ax) in enumerate(zip(multi_param_nodes, axes.flatten())):\n",
"    ax.set_title(f'Error Rates For {nodes} Nodes')\n",
"    # axis 0 of mean_param_accuracy: 0 = test, 1 = train\n",
"    for split, label, colour in ((0, 'Test', (0, 0, 1)), (1, 'Train', (1, 0, 0))):\n",
"        if show_error_bars:\n",
"            ax.errorbar(multi_param_epochs, 1 - mean_param_accuracy[split, idx, :], fmt='x', ls='-', yerr=std_param_accuracy[split, idx, :], markersize=4, lw=1, label=label, capsize=4, c=colour, ecolor=(*colour, 0.5))\n",
"        else:\n",
"            ax.plot(multi_param_epochs, 1 - mean_param_accuracy[split, idx, :], 'x', ls='-', lw=1, label=label, c=colour)\n",
"    ax.set_ylim(0, y_max)\n",
"    ax.legend()\n",
"    ax.grid()\n",
"\n",
"# hide spare panels so no empty axes are rendered\n",
"for spare_ax in axes.flatten()[len(multi_param_nodes):]:\n",
"    spare_ax.set_visible(False)\n",
"\n",
"fig.tight_layout()\n",
2021-03-26 20:01:05 +00:00
"# fig.savefig(f'graphs/{exp1_testname}-test-train-error-rate.png')"
2021-03-22 20:49:29 +00:00
]
},
{
"cell_type": "code",
"execution_count": 170,
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp1"
]
},
2021-03-22 20:49:29 +00:00
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAD20AAA9uCAYAAABI3BnXAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAEzlAABM5QF1zvCVAAEAAElEQVR4nOzddfwlVf3H8ddni9ilG0HCFqVMUBFMFBUVExWwMH7Y3d2NnYBiYIGCioRSKhYYgICAICIluSyxsPv5/TF3YVm+d+bGzM3X8/G4D3TPueec7425M+c9cyYyE0mSJEmSJEmSJEmSJEmSJEmSJEmSJElSb2YNewCSJEmSJEmSJEmSJEmSJEmSJEmSJEmSNM68aFuSJEmSJEmSJEmSJEmSJEmSJEmSJEmS+uBF25IkSZIkSZIkSZIkSZIkSZIkSZIkSZLUBy/aliRJkiRJkiRJkiRJkiRJkiRJkiRJkqQ+eNG2JEmSJEmSJEmSJEmSJEmSJEmSJEmSJPXBi7YlSZIkSZIkSZIkSZIkSZIkSZIkSZIkqQ9etC1JkiRJkiRJkiRJkiRJkiRJkiRJkiRJffCibUmSJEmSJEmSJEmSJEmSJEmSJEmSJEnqgxdtS5IkSZIkSZIkSZIkSZIkSZIkSZIkSVIfvGhbkiRJkiRJkiRJkiRJkiRJkiRJkiRJkvrgRduSJEmSJEmSJEmSJEmSJEmSJEmSJEmS1Acv2pYkSZIkSZIkSZIkSZIkSZIkSZIkSZKkPnjRtiRJkiRJkiRJkiRJkiRJkiRJkiRJkiT1wYu2JUmSJEmSJEmSJEmSJEmSJEmSJEmSJKkPXrQtSZIkSZIkSZIkSZIkSZIkSZIkSZIkSX3wom1JkiRJkiRJkiRJkiRJkiRJkiRJkiRJ6oMXbUuSJEmSJEmSJEmSJEmSJEmSJEmSJElSH7xoW5IkSZIkSZIkSZIkSZIkSZIkSZIkSZL64EXbkiRJkiRJkiRJkiRJkiRJkiRJkiRJktQHL9qWJEmSJEmSJEmSJEmSJEmSJEmSJEmSpD540bYkSZIkSZIkSZIkSZIkSZIkSZIkSZIk9cGLtiVJkiRJkiRJkiRJkiRJkiRJkiRJkiSpD160LUmSJEmSJEmSJEmSJEmSJEmSJEmSJEl98KJtSZIkSZIkSZIkSZIkSZIkSZIkSZIkSeqDF21LkiRJkiRJkiRJkiRJkiRJkiRJkiRJUh+8aFuSJEmSJEmSJEmSJEmSJEmSJEmSJEmS+uBF25IkSZIkSZIkSZIkSZIkSZIkSZIkSZLUBy/aliRJkiRJkiRJkiRJkiRJkiRJkiRJkqQ+eNG2JEmSJEmSJEmSJEmSJEmSJEmSJEmSJPXBi7YlSZIkSZIkSZIkSZIkSZIkSZIkSZIkqQ9etC1JkiRJkiRJkiRJkiRJkiRJkiRJkiRJffCibUmSJEmSJEmSJEmSJEmSJEmSJEmSJEnqgxdtS5IkSZIkSZIkSZIkSZIkSZIkSZIkSVIfvGhbkiRJkiRJkiRJkiRJkiRJkiRJkiRJkvrgRduSJEmSJEmSJEmSJEmSJEmSJEmSJEmS1Acv2pYkSZIkSZIkSZIkSZIkSZIkSZIkSZKkPnjRtiRJkiRJkiRJkiRJkiRJkiRJkiRJkiT1wYu2JUmSJEmSJEmSJEmSJEmSJEmSJEmSJKkPXrQtSZIkSZIkSZIkSZIkSZIkSZIkSZIkSX3wom1JkiRJkiRJkiRJkiRJkiRJkiRJkiRJ6oMXbUuSJEmSJEmSJEmSJEmSJEmSJEmSJElSH7xoW5IkSZIkSZIkSZIkSZIkSZIkSZIkSZL64EXbkiRJkiRJkiRJkiRJkiRJkiRJkiRJktQHL9qWJEmSJEmSJEmSJEmSJEmSJEmSJEmSpD540bYkSZIkSZIkSZIkSZIkSZIkSZIkSZIk9cGLtiVJkiRJkiRJkiRJkiRJkiRJkiRJkiSpD160LUmSJEmSJEmSJEmSJEmSJEmSJEmSJEl98K
JtSZIkSZIkSZIkSZIkSZIkSZIkSZIkSeqDF21LkiRJkiRJkiRJkiRJkiRJkiRJkiRJUh+8aFuSJEmSpIZFxPkRkW0eBw57fJIkSapXROxTsv+XEbH5sMcoSZIkSZIkaTjMjyVJkqaL+bEkTZc5wx6AJEmSJKk+EbE6sCmwEbAmsDIwF7geWNR6XAdcBFycmTmckUqSJEmSJEmSJEmSmmR+LEmS6hARKwNbABsC61LsU8wBbqLYr7gc+C9wQWYuHdY4JUmSRoEXbUuSJEnSGIuIbYFHAQ8HtgU26eLpN0XE+cD5wD+APwF/BP5pGC9J0nRohetbtR73af132cl7awKrAzcDNwBXUATt5wB/A34P/DEzlwx63JIkSZIkSZKkOzI/liRJdYiItYAnAo8AdgTuAszq4Kk3RMTpwInAMcDRmXlzYwOVJEkaQV60LUkaC61AYLNhj2MGn8nMVw97ENIwRESdgdxNrcdC4DLgYuBs4AzgZOA0A0DpNq1J8ZcC+wB376OplYB7tB6PXe7fr4mIPwFHA0dm5l/76EMTIiI2B/5VQ1M3c9t2f9m2/1KK7f9lwL+BMylOBDkvM2+poU9ppETEbhQXxY6CH2fmlcMehAYnIuYCDwQeSRGw7wDMq3jabIqV0tcC7grstFzZNRFxOHAQcOyk7bdHxD7AAR1WvwLYMjOvbW5EhQ7GdVBm7tP0OCRJkiRpWMyPpdFjfiwNj/mxhsH8WKqP+bFGSUTsALyO4oLtqhx5JqsA9289XgNcERHfBD6VmRfWNtAhMT+WJEmd8KJtSZIkjYKVWo/VgTsB2wGPX678qoj4KfBd4CgDeE2riFgFeDPwWmBBg12tQXEh1yOBD0fEf4FfAm/JzEsb7FfTYW7rsfxn+B5t6i6OiD9TrL57InDCIIIMaQDeQHGHi1HwJ8DQfcK17qb9eOAZwG7Uux+xBvDc1uPvEfG+zPxBje2Pk3UoTmB417AHIkmSJEmSNObMj6UOmB9rQpgfS+bHGgERcVfg0xR5cp3Wobh4e7+I+BTwvsy8ruY+RpX5sSRJU2rWsAcgSZIkdWAtYG/gSOAfEbFXRMSQxzQWImLniMiSx87DHqM6ExE7Utw94J00G7jPZGPg+cCm/TQSEceVfBaPq2WkmjTzKO4A+0bgcOCyiDgsIvaMiPnDHZokjb6I2CEivgNcDvwIeCbN7kfcF/h+RBwbEVs22M8oe21ErDfsQUiSJEmSJE048+MemR9PDvNjTSnzY0lqQOsuzX+h/gu2lzeXYvt9SkRs22A/o8b8WJKkKeRF25IkSRo39wAOAn4bEXcZ9mCkQYiIlwHHA5sPeSjSsK0E7A58G/hPRHw4IjYe8pgkaZTtDTybwZ+w9wiKsH33Afc7ChYAbx32ICRJkiRJkqaI+bGmjvmxdCvzY0nqU0R8EDgAGNTiF3cDToqIxwyov2EzP5YkaQp50bYkSZLG1YMpLgSZlsk7TamIeDPwBWDOsMcijZg1gTcB50fEJyNi9SGPR5J0e2sAP46IFw17IEPwsojo6w4zkiRJkiRJ6pr5saaC+bHU1pqYH0tSVyLiPcBbhtD1fOCnEfHQIfQ9DObHkiRNGSdtJEmSNM5Wp5i8e1JmHjXswUh1i4jnAR/q4ikLgV8ApwJ/B85p/du1wPXAKsBqwMbAZsB9gG2AhwAb1jZwabDmAq8Bnh0Rr87MQ4Y9IEkaY4uBvwJnA/8CrqLYh5gPrANsAuxEsR/RiVnAVyLiusz8Xv3DHVkrAe8GXjjkcUiSJEmSJE0b82NNNPNjqSPmx5LUgYh4MvDOLp5yPHA08BvgAuBKiv2JNYC1KfYjHgLsDtylg/ZWAg6NiPtk5qVdjGMcmR9LkjRlvGhbkjQpDgB+O4R+TxtCn9K4+DhwVgf1giI8XxPYALg/cF8631ddCTgkIh6Qmef0ME5pJEXE3YEvd1j9D8BHgZ9l5o0l9Ra1HpcApwCHLtffPYAnAE8CHkbx3ZS69Wbgig
7qzaY4CWRlihNBNqI48eOuFMHNrB763hD4XkTsBLwmMxf30IYkTaM/Az8DjgT+3Mn2MyK2BPYDXgwsqKoOHBgRZ2X
"text/plain": [
"<Figure size 4000x4000 with 6 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# one panel per hidden-node setting, two panels per row\n",
"n_rows = math.ceil(len(multi_param_nodes) / 2)\n",
"fig, axes = plt.subplots(n_rows, 2, figsize=(8, 8 * n_rows / 3))\n",
"fig.set_dpi(fig_dpi)\n",
"\n",
"# std dev of accuracy equals std dev of error rate (error = 1 - accuracy)\n",
"std_ylim = np.round(np.max(std_param_accuracy) + 0.05, 1)\n",
"\n",
"for idx, (nodes, ax) in enumerate(zip(multi_param_nodes, axes.flatten())):\n",
"    ax.set_title(f'Error Rate Std Dev. For {nodes} Nodes')\n",
"    ax.plot(multi_param_epochs, std_param_accuracy[0, idx, :], 'x', ls='-', lw=1, label='Test', c=(0, 0, 1))\n",
"    ax.plot(multi_param_epochs, std_param_accuracy[1, idx, :], 'x', ls='-', lw=1, label='Train', c=(1, 0, 0))\n",
"    ax.set_ylim(0, std_ylim)\n",
"    ax.set_xlabel('Epochs')\n",
"    ax.set_ylabel('Error Rate Std Dev.')\n",
"    ax.legend()\n",
"    ax.grid()\n",
"\n",
"# an odd number of node settings leaves an empty trailing panel; hide it\n",
"for ax in axes.flatten()[len(multi_param_nodes):]:\n",
"    ax.set_visible(False)\n",
"\n",
"fig.tight_layout()\n",
2021-03-26 20:01:05 +00:00
"# fig.savefig(f'graphs/{exp1_testname}-test-train-error-rate-std.png')"
2021-03-22 20:49:29 +00:00
]
},
2021-03-19 17:21:00 +00:00
{
"cell_type": "markdown",
"metadata": {
2021-03-26 20:01:05 +00:00
"id": "eUPJuxUtVUc3",
"tags": [
"exp2"
]
2021-03-19 17:21:00 +00:00
},
"source": [
"# Experiment 2\n",
"\n",
"For the cancer dataset, choose an appropriate value of nodes and epochs, based on Exp 1) and use an ensemble of individual (base) classifiers with random starting weights and Majority Vote to see if performance improves - repeat the majority vote ensemble at least thirty times with different 50/50 splits and average and graph (each classifier in the ensemble sees the same training patterns). Repeat for a different odd number (prevents tied votes) of individual classifiers between 3 and 25, and comment on the result of individual classifier accuracy vs ensemble accuracy as the number of base classifiers varies. Consider changing the number of nodes/epochs (both less complex and more complex) to see if you obtain better performance, and comment on the result with respect to why the optimal node/epoch combination may be different for an ensemble compared with the base classifier, as in Exp 1).\n",
"\n",
2021-03-22 20:49:29 +00:00
"(Hint 4: to implement majority vote you need to determine the predicted class labels; it is probably easier to implement this yourself rather than use the ensemble MATLAB functions)\n"
2021-03-19 17:21:00 +00:00
]
},
{
"cell_type": "code",
2021-03-26 20:01:05 +00:00
"execution_count": 89,
"metadata": {
"tags": [
"exp2",
"exp-func"
]
},
2021-03-19 17:21:00 +00:00
"outputs": [],
"source": [
2021-03-26 20:01:05 +00:00
"# larger sweep used in earlier runs (see results tables below): [1, 3, 9, 15, 25]\n",
"num_models = [1, 3, 9]\n",
"\n",
"def evaluate_ensemble_vote(hidden_nodes=16,\n",
"                           epochs=50,\n",
"                           batch_size=128,\n",
"                           optimizer=lambda: 'sgd',\n",
"                           weight_init=lambda: 'glorot_uniform',\n",
"                           loss=lambda: 'categorical_crossentropy',\n",
"                           metrics=['accuracy'],\n",
"                           callbacks=None,\n",
"                           validation_split=None,\n",
"                           round_predictions=True,\n",
"\n",
"                           nmodels=num_models,\n",
"\n",
"                           verbose=0,\n",
"                           print_params=True,\n",
"                           return_model=True,\n",
"\n",
"                           dtrain=data_train,\n",
"                           dtest=data_test,\n",
"                           ltrain=labels_train,\n",
"                           ltest=labels_test):\n",
"    \"\"\"Train majority-vote ensembles of each size in `nmodels`, yielding one result dict per size.\n",
"\n",
"    For every ensemble size m, m fresh models are built with `get_model(hidden_nodes, weight_init=...)`,\n",
"    compiled, trained on (dtrain, ltrain) and evaluated on (dtest, ltest).\n",
"\n",
"    Args:\n",
"        epochs: an int (every model trains for that many epochs) or a (low, high) tuple, in which\n",
"            case the m models' epoch counts are spread evenly over [low, high] with np.linspace\n",
"            (a single model uses the midpoint).\n",
"        batch_size: forwarded to `model.fit`.\n",
"        optimizer, loss: zero-argument callables, invoked once per model at compile time.\n",
"        weight_init: forwarded unchanged to `get_model`.\n",
"        round_predictions: if True the ensemble votes with rounded outputs (majority vote);\n",
"            otherwise it sums the raw outputs.\n",
"        nmodels: iterable of ensemble sizes to evaluate.\n",
"        print_params: print the ensemble size before each run.\n",
"        return_model: include the trained models in each result dict.\n",
"\n",
"    Yields:\n",
"        dict with keys nodes, epochs, num_models, predictions, history, optimizer, model_config,\n",
"        loss, round_predictions, accuracy, agreement, individual_accuracy (and models if\n",
"        return_model). accuracy, agreement and individual_accuracy are plain Python floats;\n",
"        agreement is NaN when the ensemble classified no test pattern correctly.\n",
"    \"\"\"\n",
"    for m in nmodels:\n",
"        if print_params:\n",
"            print(f\"Models: {m}\")\n",
"\n",
"        models = [get_model(hidden_nodes, weight_init=weight_init) for _ in range(m)]\n",
"        for model in models:\n",
"            model.compile(\n",
"                optimizer=optimizer(),\n",
"                loss=loss(),\n",
"                metrics=metrics\n",
"            )\n",
"\n",
"        response = {\"nodes\": hidden_nodes,\n",
"                    \"epochs\": list(),\n",
"                    \"num_models\": m}\n",
"\n",
"        ###################\n",
"        ## TRAIN MODELS\n",
"        ###################\n",
"        histories = list()\n",
"        for idx, model in enumerate(models):\n",
"            if isinstance(epochs, tuple):\n",
"                if m == 1:\n",
"                    e = (epochs[0] + epochs[1]) / 2  # average, not lower bound if single model\n",
"                else:\n",
"                    e = np.linspace(epochs[0], epochs[1], num=m)[idx]\n",
"                e = int(e)\n",
"            else:\n",
"                e = epochs\n",
"\n",
"            history = model.fit(dtrain.to_numpy(),\n",
"                                ltrain.to_numpy(),\n",
"                                epochs=e,\n",
"                                batch_size=batch_size,  # previously accepted but never forwarded\n",
"                                verbose=verbose,\n",
"                                callbacks=callbacks,\n",
"                                validation_split=validation_split)\n",
"            histories.append(history.history)\n",
"            response[\"epochs\"].append(e)\n",
"\n",
"        ########################\n",
"        ## FEEDFORWARD TEST\n",
"        ########################\n",
"        response[\"predictions\"] = [model(dtest.to_numpy()) for model in models]\n",
"\n",
"        ########################\n",
"        ## ENSEMBLE ACCURACY\n",
"        ########################\n",
"        # ensem_sum_rounded sums the rounded per-model outputs (the votes); ensem_sum sums the raw outputs.\n",
"        # argmax of either gives the ensemble's predicted class.\n",
"        # NOTE(review): treating rounded outputs as 0/1 votes assumes model outputs lie in [0, 1]\n",
"        # (e.g. softmax) - get_model is not visible here, confirm\n",
"        ensem_sum_rounded = sum(tf.math.round(pred) for pred in response[\"predictions\"])\n",
"        ensem_sum = sum(pred for pred in response[\"predictions\"])\n",
"\n",
"        ltest_tensor = tf.constant(ltest.to_numpy())  # transform test labels into tensor\n",
"        correct = 0  # number of correct ensemble predictions\n",
"        correct_num_models = 0.0  # summed over correct ensemble predictions: proportion of models also correct\n",
"        individual_accuracy = 0.0  # summed over all test patterns: proportion of models correct\n",
"\n",
"        # pc = summed prediction, pcr = summed rounded prediction (votes), gt = ground-truth label vector\n",
"        for pc, pcr, gt in zip(ensem_sum, ensem_sum_rounded, ltest_tensor):\n",
"            gt_argmax = tf.math.argmax(gt)\n",
"            pred_val = pcr if round_predictions else pc\n",
"            # float() turns the scalar tensor into a plain Python number so stored metrics are floats\n",
"            models_correct = float(pcr[gt_argmax]) / m\n",
"\n",
"            individual_accuracy += models_correct\n",
"\n",
"            if tf.math.argmax(pred_val) == gt_argmax:  # ENSEMBLE EVALUATE HERE\n",
"                correct += 1\n",
"                correct_num_models += models_correct\n",
"\n",
"        ########################\n",
"        ## RESULTS\n",
"        ########################\n",
"        response.update({\n",
"            \"history\": histories,\n",
"            \"optimizer\": model.optimizer.get_config(),\n",
"            \"model_config\": json.loads(model.to_json()),\n",
"            \"loss\": model.loss,\n",
"            \"round_predictions\": round_predictions,\n",
"\n",
"            # proportion of test patterns the ensemble classified correctly\n",
"            \"accuracy\": correct / len(ltest),\n",
"            # when the ensemble is correct, average proportion of models that were also correct;\n",
"            # NaN (undefined) rather than ZeroDivisionError if the ensemble got nothing right\n",
"            \"agreement\": correct_num_models / correct if correct else np.nan,\n",
"            # average proportion of individual models classifying correctly\n",
"            \"individual_accuracy\": individual_accuracy / len(ltest)\n",
"        })\n",
"\n",
"        if return_model:\n",
"            response[\"models\"] = models\n",
"\n",
"        yield response"
]
},
{
"cell_type": "markdown",
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp2"
]
},
2021-03-19 17:21:00 +00:00
"source": [
"## Single Iteration\n",
"Run a single iteration of ensemble model investigations"
]
},
{
"cell_type": "code",
2021-03-26 20:01:05 +00:00
"execution_count": 90,
"metadata": {
"tags": [
"exp2"
]
},
2021-03-19 17:21:00 +00:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Models: 1\n",
"Models: 3\n",
2021-03-26 20:01:05 +00:00
"Models: 9\n"
2021-03-19 17:21:00 +00:00
]
}
],
"source": [
"# consume the generator eagerly: one result dict per ensemble size in num_models\n",
"single_ensem_results = list(evaluate_ensemble_vote(\n",
"    epochs=(5, 300),\n",
"    optimizer=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),\n",
"))"
]
},
{
"cell_type": "code",
2021-03-26 20:01:05 +00:00
"execution_count": 91,
"metadata": {
"tags": [
"exp2"
]
},
2021-03-19 17:21:00 +00:00
"outputs": [
{
"data": {
2021-03-26 20:01:05 +00:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAeQAAAFECAYAAAD2sk0XAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAArEAAAKxAFmbYLUAABFvUlEQVR4nO3deXxU1f3/8dcnG0sIS0CQfVEW2fdNhQDWutS1glZU0K9L+3WpVdx+1q3V1lqq1tb2q7UWV4y71LVlCahsAqIsiqgEwiIgIHsgy/n9cW8mk2QmmYQMM0nez8fjPjJ3mXs/ZzIznznn3nuOOecQERGR2EqIdQAiIiKihCwiIhIXlJBFRETigBKyiIhIHFBCFhERiQNKyCIiInFACVmkGpjnOTP7wczeiHU8pZlZBzP7IdZxFDGzyWY28ygf05lZuwi2m2Zmvz4aMVUnM7vXzJ4qZ31E5ZfYUUKuo8xsX9DkzGx/0HyHSu4roi8wMxvpH+uGqkcet04GTgSOdc6dd6Q7M7MsM7uk1LIqJzHn3AbnXNMjjSscM+tkZvnVtK8M/33yTKnll/nL762O41SFn/TySn1+5scqHqldlJDrKOdco6IJOAT0Clq2IUqHvQTY5f89asws8SgcpgPwrXMut7JPNLOkKMRz1PYfJVuAH5lZg6BlE4G1MYon2DPBnx/n3MhYByS1gxKylGBm6Wb2opltM7NvzWxS0LorzGy9me01szV+TWYS3hflXX5t4f/C7DcZmADcAAw0s25B65LM7Df+vnebWVbQurFmtsTM9pjZWjM72V+ebWYnBW0XqKX7tZjpZvaame0DxprZT8xshR/7WjMbX9Hxzew/ZjY5aLtU//klmv3MbCLwFJDhvwY3mFmCmd1nZjlmtsXMHjOzev72k81sjpk9YWa7gcsr918KHPdkM/vUbyafa2YnBK1zZnadma0D5gTXYM1sRKka3qGgMjfzX7vvzewbM7um1Gv8mJnN8l+H/5hZur/6P0BicCuLmQ0zs0/8/916M7u+EsU7AMwBzvKPfSzQA8gq9Rr8r/8+3W5mz5tZk6B1V/iv/3dmdnWp54V9n1eV/3+dbWZ/98u82swG+usS/Nfue///9YmZtfDXdTCzd8xsh5l9YWanBe0z28xu8Zfv9d+n3f3PxG4r+3lLNbM3/W3nmVmnMLE2MLO/mtlmM9toZrcHrfuJeZ/vvf7xLzrS10Yi5JzTVMcnIBfo5D9+B5gK1MP7AtwM9AVSgT1AV3+7jkBn//E04NcVHONsYCuQCMwCfhO07k5gMV4tMxEY5S/v4h/zJ/7yDsDx/rps4KSgfQRiAO71y3Qq3o/O+sBooLs/fwawD695ubzjXwZ8EHSMi4GsMOWbDMwMmr8KWAW0A5oDHwP3Bm2b7/9NABqE2F8WcEm4Y/j73AWcDyQDt+DVHpP89Q54C2gMNAA6AfkhjtMQ+Az4hT//AjDdX94X2A6MDnqNv/OX1/f/j/f568rsHxjoTwnAYGA3MCDU61XqeRnA1/7/6S1/2Y3AH4H/C3odf4RXk+6J9/58DZjmr+uN994Z5pf/Wf81aVfe+7yi9zPee+upct4DecDP8N5H9wNz/XWnAUv8/0ei/7o08l+bz/B+qCYBI/zXvFXQ+zwLSPfjzMX78dMeONb/f4wJiu0w3uclBXgImBcUX3D5Hwde9GNog/de/Ym/7jvgRP/xsUDPWH9H1ZUp5gFoiv3kf8g7+R++/UBy0Lqp/gc91f9CPReoV+r5Yb/AgrZ5Gfib//gq4JugdWuBH4V4zp3AC2H2l035CfmDCuJZEPQFFO74jfwyt/Tn/w1cHWZ/kymZkGcBlwfN/xhYE7Ttmgriy/L/Fz8ETQcoTsiX4n/Z+/MJwCZghD/vih77850InZBfBP7lP070v9A7B63/PfBE0Gv8l6B1/wu8Wd7+Sx1rOnBdqNer1HYZeAk5CdiIl4yWAP0pmZD/if+DwJ/vDhwEDL
gHPzn76473X5N2lPM+r+j97L+3DpX6vzwRVKYVQdv2BH7wH48D1gBDAQvaZjjwValjvApMDnqfnx+0bhFwU9B8JnBjUGxZQesa+v/PNkHviXb+63OgaLm/7jqKf8zkAFcCjcr7f2qq/klN1hKsA17NZ7vfrPYDcA1eTXI/3i//G4CtZvaKmbWJZKdm1hiv6THTX/Qa0N7Mis69tQPWhXhquOWR2FgqhpPM7GMz2+mXazBeLTPscZxz+/BqUuP9ptkxeF+WkWgDBJ+LX+8vCxlfGNc455oWTXgJMOT+nXOFeF+kER/DzG4GugG/8Be1wKttlxf31qDHB/B+tITbfy8z+6/fnLwbrzbfPNz2pTnn8oE3gV8D9Z1zy0ttEuo1ro+XwFvjvR5Fgh+HfZ9HGNpzwf8X59w1QetCvj7OuVl4PyaeBLaY2VTzTuN0ADoXxeHHcpoff5FtQY8PhpgP/h8EyumcOwDsKLUvgGPwWg1WBx3zd0Arf/0FwHnARjN7P/hUiESXErIE24TXlNss6MsmzTn3cwDn3LvOubF4CewQ3ocYvF/e5bkA7wvwZTP7DliN994rurgrB6+GVVq45eDVcIIv+GlVan3pmJ4DnsH7cdEUr8ZlERznebwfIj8FZjnndobZrrTNeF+2RTr4y8LFV1kl9m9mhteMGdExzGwsMAWv9lV0Idr3eE2u5cUdTqhj/RWvJaKDc64J8DrFr3mkXsRrrn4hxLpQr3EusBOvKbt90Lrgx+W+z6PFOfeIc64/MASvxWSiH8sXpRJ8I+fc76t4mEA5zbsgrjneaxHse7zPb5egYzZ2zp3ux7nIOXcm3mfqM+DvVYxFKkkJWQKcc5vwvkDvN7OG5l3sNNDMeppZK/9ijwZ4H+YDQIH/1G2ET2jgJd5HgX54zY798S5kmuDXEqb5x2xvZolmNsp/3nTgLDM7w78opr2ZHeev+8x/fqKZnYLXzFmeNLzaQp6Z/RQYFLQu3PHBO193PF5SeLGCYwTLBG42s7Z+7fou4KVKPL8i7wH9zOwc866i/hVebWlJRU8077a2F4CJLuiKeudcAV4LQNH/vzfwPxHG/T2QYCUveEvDa9LNNe9ivDMjKlkQ59x8vHPFoS4WzASuMrMTzCwVeAB42Tnn8FphzjezIf579tdB+wz7Pq9sfJEys8F+LEnAXrwfPgV4TdAJZvYLM0vxp5OtkrceBhnpf15S8JrtFzvnSvyg8ltTngH+ZGZN/c/WCWY21D/+xX6rVh7eD5eCMkeRqFBCltIm4tWAv8VLtI/i1UQTgFvxmuS2AW0p/pJ7GhjmN3/9LXhn/hf0ycBjzrnviia8ZJsLnI53sc4sYD5e0rwXwDm3Dq9m+gDeudxZFDe/3QMMwPvC/x+8C5jKcz3wGN6FUD8G5gatC3l8P4Z8vPPfHYAZFRwj2D+BN/AuFluN9wOiqrWeMpxz3+Odz78PL+bzgHOdc3kRPH0s0BKYYcVXRb/nr7sO79xtDl5573XOzYkgnv3Ag8By/33QAbgNuBbv4qobqdzrF7zvWc65XSGW/wfvNX0Xr7k6zz8OzrmVwE14/4NsvP9tsHDv80hMspJXqWdH8JwmeJ+TH/DOJX8MvOi/v87Ee09uwqv130nVv5tfA67GayU4Ce9ag1B+hfeZWuFv+yzQzF83Ce/13IX3Y+i6KsYilWTej0kRCcfMpgD9nHPhvtxERI6Yasgi5TCzNLwa+D9jHYuI1G5RS8hm9oaZ7TKzkFel+ucrVpnZ12Z2d7TiEKkqMzsb757MBc65rBiHIyK1XNSarM0sA++ijknOuQtCrP8Er+axCu98ylXOuRVRCUZERCTORa2G7Nco9oZa59+/muSc+9y/svMlvN5lRERE6qRYdTrfBu+KwiKb8Lo2LMPMrsTrNYaGDRsO69Sp0xEf3DnHpsObyPMvSE2yJBomNKRhYkPqW328WzprhoKCAhITj8bYCdFTG8oAtaMcKkP8qA3lUBnKWr169V
bnXMhOaOJ+FBjn3FN4HfczfPhwt3DhwiPe59KtS5kydwpD6g3howMf0at5L1btWMXew3tJTU7lpLYnkdE+g5Pbnky
2021-03-19 17:21:00 +00:00
"text/plain": [
2021-03-26 20:01:05 +00:00
"<Figure size 560x350 with 1 Axes>"
2021-03-19 17:21:00 +00:00
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
2021-03-22 20:49:29 +00:00
"fig, ax = plt.subplots(figsize=(8, 5))\n",
"fig.set_dpi(fig_dpi)\n",
"\n",
"ensem_x = [result[\"num_models\"] for result in single_ensem_results]\n",
"\n",
"# one line per metric recorded by evaluate_ensemble_vote, in the same order as before\n",
"for key, label in [(\"accuracy\", 'Ensemble Accuracy'),\n",
"                   (\"individual_accuracy\", 'Individual Accuracy'),\n",
"                   (\"agreement\", 'Agreement')]:\n",
"    ax.plot(ensem_x, [result[key] for result in single_ensem_results], 'x-', label=label)\n",
"\n",
"ax.set_title(\"Test Accuracy for Horizontal Model Ensembles\")\n",
"ax.set_ylim(0, 1)\n",
"ax.grid()\n",
"ax.legend()\n",
"ax.set_ylabel(\"Accuracy\")\n",
"ax.set_xlabel(\"Number of Models\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp2"
]
},
2021-03-19 17:21:00 +00:00
"source": [
"## Multiple Iterations\n",
2021-03-22 20:49:29 +00:00
"Run multiple iterations of the ensemble model investigations and average\n",
"\n",
2021-03-26 20:01:05 +00:00
"### CSV Results\n",
2021-03-22 20:49:29 +00:00
"\n",
"| test | learning rate | momentum | batch size | hidden nodes | epochs | models |\n",
"| --- | --- | --- | --- | --- | --- | --- |\n",
"|1|0.06|0|128|16|50|1, 3, 9, 15, 25|\n",
"|2|0.06|0|35|16|1 - 100|1, 3, 9, 15, 25|\n",
"\n",
2021-03-26 20:01:05 +00:00
"### Pickle Results\n",
2021-03-22 20:49:29 +00:00
"\n",
"| test | learning rate | momentum | batch size | hidden nodes | epochs | models |\n",
"| --- | --- | --- | --- | --- | --- | --- |\n",
"|3|0.06|0.05|35|16|1 - 300|1, 3, 9, 15, 25|"
2021-03-19 17:21:00 +00:00
]
},
{
"cell_type": "code",
2021-03-26 20:01:05 +00:00
"execution_count": 116,
"metadata": {
"tags": [
"exp2"
]
},
2021-03-22 20:49:29 +00:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2021-03-26 20:01:05 +00:00
"Iteration 1/3\n",
"Iteration 2/3\n",
"Iteration 3/3\n"
2021-03-22 20:49:29 +00:00
]
}
],
2021-03-19 17:21:00 +00:00
"source": [
"multi_ensem_results = list()\n",
"# NOTE: the experiment spec asks for at least thirty repeats; 3 keeps a quick run short\n",
"multi_ensem_iterations = 3\n",
"for i in range(multi_ensem_iterations):\n",
"    print(f\"Iteration {i+1}/{multi_ensem_iterations}\")\n",
"    # fresh stratified 50/50 split every iteration; seeding with the iteration index keeps the\n",
"    # splits different between iterations but reproducible on a re-run\n",
"    # (model weight initialisation is still unseeded)\n",
"    data_train, data_test, labels_train, labels_test = train_test_split(data, labels, test_size=0.5, stratify=labels, random_state=i)\n",
"    multi_ensem_results.append(list(evaluate_ensemble_vote(epochs=(1, 100),\n",
"                                                          hidden_nodes=16,\n",
"                                                          nmodels=[1, 3, 5, 7, 9],\n",
"                                                          optimizer=lambda: tf.keras.optimizers.SGD(learning_rate=0.05, momentum=0.02),\n",
"                                                          weight_init=lambda: 'random_uniform',\n",
"                                                          batch_size=35,\n",
"                                                          dtrain=data_train,\n",
"                                                          dtest=data_test,\n",
"                                                          ltrain=labels_train,\n",
"                                                          ltest=labels_test,\n",
"                                                          return_model=False,\n",
"                                                          print_params=False)))"
]
},
{
"cell_type": "markdown",
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp2"
]
},
2021-03-19 17:21:00 +00:00
"source": [
"### Accuracy Tensor\n",
"\n",
"Create a tensor for holding the accuracy results\n",
"\n",
2021-03-26 20:01:05 +00:00
"(Iterations x Param x Number of models)\n",
"\n",
"#### Params\n",
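"0. Ensemble Test Accuracy\n",
"1. Train Accuracy (final-epoch training accuracy, averaged over the ensemble's models)\n",
"2. Individual Test Accuracy (average proportion of individual models classifying correctly)\n",
"3. Agreement (when the ensemble is correct, average proportion of models that were also correct)"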
2021-03-19 17:21:00 +00:00
]
},
{
"cell_type": "code",
2021-03-26 20:01:05 +00:00
"execution_count": 117,
"metadata": {
"tags": [
"exp2"
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"3 Tests\n",
"Models: [1, 3, 5, 7, 9]\n",
"\n",
"Loss: categorical_crossentropy\n",
"LR: 0.05\n",
"Momentum: 0.02\n"
]
}
],
2021-03-19 17:21:00 +00:00
"source": [
"# ensemble sizes tested, read from the first iteration (assumes every iteration used the same sizes)\n",
"multi_ensem_models = sorted({result[\"num_models\"] for result in multi_ensem_results[0]})\n",
"multi_ensem_iter = len(multi_ensem_results)\n",
"\n",
"# (iterations x param x number of models); params: test, train, individual, agreement\n",
"accuracy_ensem_tensor = np.zeros((multi_ensem_iter, 4, len(multi_ensem_models)))\n",
"for iter_idx, iteration in enumerate(multi_ensem_results):\n",
"    for single_test in iteration:\n",
"        col = multi_ensem_models.index(single_test['num_models'])\n",
"        # train accuracy = final-epoch training accuracy averaged over the ensemble's models\n",
"        final_train_acc = np.mean([hist[\"accuracy\"][-1] for hist in single_test[\"history\"]])\n",
"        accuracy_ensem_tensor[iter_idx, :, col] = [single_test[\"accuracy\"],\n",
"                                                   final_train_acc,\n",
"                                                   single_test[\"individual_accuracy\"],\n",
"                                                   single_test[\"agreement\"]]\n",
"\n",
"mean_ensem_accuracy = accuracy_ensem_tensor.mean(axis=0)\n",
"std_ensem_accuracy = accuracy_ensem_tensor.std(axis=0)\n",
"\n",
"first_run = multi_ensem_results[0][0]\n",
"print(f'{multi_ensem_iter} Tests')\n",
"print(f'Models: {multi_ensem_models}')\n",
"print()\n",
"print(f'Loss: {first_run[\"loss\"]}')\n",
"print(f'LR: {first_run[\"optimizer\"][\"learning_rate\"]:.3}')\n",
"print(f'Momentum: {first_run[\"optimizer\"][\"momentum\"]:.3}')"
2021-03-19 17:21:00 +00:00
]
},
{
"cell_type": "markdown",
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp2"
]
},
2021-03-19 17:21:00 +00:00
"source": [
"#### Export/Import Test Sets\n",
"\n",
"Export the raw results (pickle) and the mean and standard deviation matrices (CSV) for later retrieval and visualisation"
]
},
{
"cell_type": "code",
2021-03-26 20:01:05 +00:00
"execution_count": 76,
"metadata": {
"tags": [
"exp2"
]
},
2021-03-19 17:21:00 +00:00
"outputs": [],
"source": [
"# context manager guarantees the file is flushed and closed\n",
"with open(\"result.p\", \"wb\") as result_file:\n",
"    pickle.dump(multi_ensem_results, result_file)"
]
},
{
"cell_type": "raw",
"metadata": {},
"source": [
"# only load pickles you produced yourself: unpickling can execute arbitrary code\n",
"with open(\"results/exp2-test1.p\", \"rb\") as result_file:\n",
"    multi_ensem_results = pickle.load(result_file)"
]
},
{
"cell_type": "raw",
"metadata": {},
2021-03-19 17:21:00 +00:00
"source": [
"# persist the (param x number of models) mean and std matrices as CSV\n",
"for csv_path, matrix in ((\"exp2-mean.csv\", mean_ensem_accuracy),\n",
"                         (\"exp2-std.csv\", std_ensem_accuracy)):\n",
"    np.savetxt(csv_path, matrix, delimiter=',')"
]
},
{
"cell_type": "raw",
"metadata": {},
"source": [
"# reload previously exported (param x number of models) mean and std matrices\n",
"mean_ensem_accuracy = np.loadtxt(fname=\"results/test1-exp2-mean.csv\", delimiter=',')\n",
"std_ensem_accuracy = np.loadtxt(fname=\"results/test1-exp2-std.csv\", delimiter=',')"
2021-03-19 17:21:00 +00:00
]
},
2021-03-22 20:49:29 +00:00
{
"cell_type": "markdown",
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp2"
]
},
2021-03-22 20:49:29 +00:00
"source": [
"### Best Results"
]
},
2021-03-19 17:21:00 +00:00
{
"cell_type": "code",
2021-03-26 20:01:05 +00:00
"execution_count": 118,
"metadata": {
"tags": [
"exp2"
]
},
2021-03-22 20:49:29 +00:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2021-03-26 20:01:05 +00:00
"Models: 1, 8e+01% Accurate\n"
2021-03-22 20:49:29 +00:00
]
}
],
"source": [
"# row 0 of mean_ensem_accuracy is the ensemble test accuracy; pick the best ensemble size\n",
"best_ensem_col = int(np.argmax(mean_ensem_accuracy[0, :]))\n",
"best_ensem_accuracy_idx = (0, best_ensem_col)\n",
"best_ensem_accuracy = mean_ensem_accuracy[best_ensem_accuracy_idx]\n",
"best_ensem_accuracy_models = multi_ensem_models[best_ensem_col]\n",
"\n",
"# .1f = one decimal place (.1 alone means one significant figure and printed '8e+01%')\n",
"print(f'Models: {best_ensem_accuracy_models}, {best_ensem_accuracy * 100:.1f}% Accurate')"
2021-03-22 20:49:29 +00:00
]
},
{
"cell_type": "markdown",
2021-03-26 20:01:05 +00:00
"metadata": {
"tags": [
"exp2"
]
},
2021-03-22 20:49:29 +00:00
"source": [
2021-03-26 20:01:05 +00:00
"### Test/Train Error Rate Over Number of Models"
2021-03-22 20:49:29 +00:00
]
},
{
"cell_type": "code",
2021-03-26 20:01:05 +00:00
"execution_count": 139,
"metadata": {
"tags": [
"exp2"
]
},
2021-03-19 17:21:00 +00:00
"outputs": [
{
"data": {
2021-03-26 20:01:05 +00:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAeQAAAFECAYAAAD2sk0XAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAArEAAAKxAFmbYLUAABJOUlEQVR4nO3dd5xcdb3/8ddnyvbdVBIIEEJvqZCQgEKASBNQQFAuoKCilIuKSuDarpV7BRELKMQfINIiigJBQLgEQkAgGHpNQgkhtPRkS3Z3yuf3xzk7O1szu9nJzGbfz8djHnP6+XxnZ+dzvt9TvubuiIiISGFFCh2AiIiIKCGLiIgUBSVkERGRIqCELCIiUgSUkEVERIqAErKIiEgRUEIW6YIFbjazdWZ2Z6Hjac/MRpvZukLH0cLMzjKzhwodR0+Z2aFm9kY38+eZ2RlbMqaubCrWdssuNbOP5zsm6TtKyLJJ4T92g5nVZb3O34L7dzOrD/e7zMy+38N1d+jlrg8GPgZs6+4n9nIb2bF0+GHfnCTm7svcffDmxtUVMxtjZsk+2tahZpZu9x2qM7NBfbH9YmRmPwq/f19sN/2GcPqhhYlMilWs0AFIv3Gkuz/e3QJmFnP35KamdbO+Aebu6U5m7+nuy81sf2C+mT3t7g/mHH3vjAbecvfGnq7Yk3L3Rr63nydvuftuhQ5iC1sCnAb8EcDMyoCjgfcKGZQUJ9WQZbOEtb6fmtlCoN7MjjCzN8zsx2a2CvixmQ0xs9lmtsrM3jSzc7LWv9HMrjazh4EGYNfu9ufuzwCvABOztvF3M1thZmvM7K9mNjSc3pKwF4W1sYPD6f9pZkvCeP5kZpWdlOt04Drg0HDdr5tZJCzXu2b2gZn91sxKw+XPMrNHzGyWma0Hvth+mzl+ngeb2XNhM/mjZrZ31jw3swvM7G3gkewarJkd2K7m2WRm88J5m/r8f2tmc82s1swebPn8gAeBaNY2R5vZVDP7t5ltMLN3zOxrvSlnuzKPMbOkmX0l/Fw/NLMzs+Z/KdxXrZktaqlZmll5+N1538yWm9l/dVKuR8LY7zGzbczsjjD2R8xsWLs4fmpma81ssZkd0U28m/z+ZHkM2N3MtgvHjwceBeqytldmZr8Ly73MzP7bzCLhvKiZ/cbMVpvZImBau1jGmdn8MO5nzGxyFzEfF352tRa0eJ3aTcxSIErI0hf+AzgVGAQkgTFACtgO+BlwdbjcaOBE4GdmNj1r/VOBmUA1sLS7HZnZVGAs8GbW5L8DO4evauC/Adz9yHD+nu5e5e6PmdkpwLnAJ4AdgTjw4/b7cfdbw+Xmhev+FvgycDJwYBjD/sB3slY7GHgSGALc0l05uijbMGAO8FNgG+AfwBwzy27JOgKYAByZva67PxnGWQWMAF4Hbg9nb+rz/yzwzXCfUeAb4fQjgVTLdt19GZAAzgEGA58JtzWpp2XtRJTgM90J+ALwOzOrDpPdr4FPuHt1GNM74TpXAEOBPYADgM+b2XFZ2zwFuIDgezgGeBz4DTA8LMfXs5YdE8YwErgY+KuZDWkfZK7fnywO/IXgOw5wOh2/Gz8A9gX2Bj4OnBF+BhB81oeF8w8j+F9riaUK+GdWmX4K/N2CWnh71wFfCj/DacCL3cQsheLueunV7YsgSdYC67Je08N584DvZC17KFAPxMLxKNAM7Jy1zP8Cs8LhG1uGu9m/A+sJatAOXAVEulj2KGBhu3V3yBr/J3Ba1vhYYGkX2zoLeChrfC7wxXb7WpS17KJNlGNe+Nlkf44NLfsAPg88mrV8hKBp88CsshyYNX8MkOxkP7cBf+zB539V1rzzgbu62367fc0GLujs82q33KEEB2nZZX8laz8ODMtafgVBK0hl+Lc/ASjNmm/hZzcqa9oFwI1Z5fpt1rzLgDlZ4+dllfNQoBEoy5r/eMv3JPy7ndGL78+PCBLhBGAhwcHDewSnCl8HDg2XexM4LGu9c4AHwu
FHgLOy5p0NvBEOnwo82G6fC7O2uxT4eDj8brhuVV/9LujV9y/VkCVXx7j74KzXo1nzlrdb9kNvPb85nKAWsSxr/jvAqG7W78y+QBXBj+4h4TYxs5iZ/Tps0twA3AEM63ozjAZmhU3C6wh+eLfJYf+EMW9uOc7J/hwJEmCn2/fgXPq7PdmHmX2boMZ4Xjgpl8//o6zhBoLPuavt72tm/2dmK8Om+ZPo/vPO9na779C+WfNS7r66fRzuXk9QK/w68JEFpyRGEfzNyoFXs/6W/0NQw22xImt4Yyfj2eVc6W2vFXiXoGbdXo+/P+7+AlAGfJ/gIKD9uf/uvlfbhbFkx5Udy/SWWMJ49qbt37bFyQStI8vN7J/Zp0KkeCghS19o32VY9vgqgubB0VnTRgPvd7N+5ztxT7v778JttiSc0wlqOAe5ew3BD491s5n3gDPbJYbuzgFme58+KEeu2zczI2gWzWkfZnY4cBFwUlZyyeXz70pn+7qaoFl+tLsPIjhd0N3nvdnc/T53PxzYAWgiSLyrwuFdsv6ONe5+TC93M7xdU++OwAedLNfb789twIV0fiqju+/VB2Es2XFlx/JA+1jc/bb2O3D3Be5+LMEBywvANTnELFuYErLklbunCGqtPzOzCjMbS3Au9s+bsdlfABeZWQnBOeNGYK2ZDSdISNlWEDSJtrgB+K6Z7QpgZtuZ2dE57vd24Ntmtn144dMP2LxytHc/MMHMPh2eN/4mQU1u4aZWNLPRwK3A6R6c6wU2+/NfBUSs7W1j1QTNzY0WXCR3bE4l6yUzGxlekFROkIAbCGrTaeBPwC/NbLAFF9ztbWYH9HJXceD7ZhY3s08RNEXf38lyvf3+XAMc4e5PdjLvduAHFlx8tyPwLVr/PncA3ww/h+0IWoha/AOYZGYnhC1F5WZ2tLW7lczMSszsNDOrITg4qyM4fSBFRglZcvWgtb2K99IerHsBwXmzdwkuWvqRuz/S20Dc/Z8ESeHzwE3AWoJm18cIzvFl+wnwt7BJ7+PuPhu4Hrg3bOJ+FNgnx11fD9wJPA28SlDT+N/elqM9d19FcK70x8BqgibGE9w9kcPqhxNczDUn62/UklB69fmHzcU/B54PP7/RwCXAfwIbCGp8c3IvIbtYx/uQN9V0GiG4yOojgoOr7QmafiE4YFkPvASsIfgudLgQK0dLCVoEVgC/BD7n7mvbL9Tb74+7r3X3uV3M/imwiOC88pMEyfhP4bxZwHzgNYJz2ZkDKXdfT3BA9LUw7qXAV7vYx5kETeFrCS4MvKCL5aSAzH1zW9lERERkc6mGLCIiUgTylpDN7M7wZvU7uph/gJm9YsFDJP47X3GIiIj0B/msIf+G1pvbO/M7gtsZ9gQ+aWbj8hiLiIhIUctbQnb3eQQPk+ggvI8w5u4vhleB/hk4rrNlRUREBoJCdS4xirYPV38PmN7ZgmZ2NsETZqioqJg6ZsyYPgsilUoRjUb7bHuFoDIUj62hHCpD8dgayqEydPTqq69+5O7bdjav6Ht7cvfrCB4/x7Rp0/ypp57qs23Pnz+fQw45pM+2VwgqQ/HYGsqhMhSPraEcKkNHZra0q3mFusr6fYL7CVtsT25PDhIREdkqFSQhu/v7QMrMxptZlOAh6fcUIhYREZFikM/bnh4C/kpwBfVyC/prvS+8oAuCJ8XMBhYD/3T3l/IVi4iISLHL2zlkd/9EJ5M/mTX/KYIefEREpAvJZJLly5fT2Ni46YWLTHV1Na+//nqhw9gsvS1DWVkZO+ywA7FY7mm26C/qEhEZyJYvX051dTU77bQTQQdg/UdtbS3V1dWFDmOz9KYM7s6aNWtYvnw5PbkzSI/OFBEpYo2NjQwdOrTfJeOBzMwYOnRoj1s1lJBFRIqcknH/05u/mZqsRUS2ElfNXcJVD79Byp2oGV87fDe+NmP3QoclOVINWURkK/G1Gbtz21emkko7t31lap8k41gsxsSJEzOvm266qQ8izd3SpUuZPH
lyp/PGjBlDXV1dTtv5xCc+wcSJExk9ejQjRozIlGfdunU5rX/XXXexePHiXMPuFdWQRUSkS4MHD+b5558vdBib7aG
2021-03-19 17:21:00 +00:00
"text/plain": [
2021-03-26 20:01:05 +00:00
"<Figure size 560x350 with 1 Axes>"
2021-03-19 17:21:00 +00:00
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
2021-03-22 20:49:29 +00:00
"fig, ax = plt.subplots(figsize=(8, 5))\n",
"fig.set_dpi(fig_dpi)\n",
"\n",
"# (param row in mean/std_ensem_accuracy, legend label); error rate = 1 - accuracy\n",
"for param_idx, label in [(0, 'Ensemble Test'),\n",
"                         (2, 'Individual Test'),\n",
"                         (1, 'Individual Train'),\n",
"                         (3, 'Anti-agreement')]:\n",
"    ax.errorbar(multi_ensem_models, 1 - mean_ensem_accuracy[param_idx, :], yerr=std_ensem_accuracy[param_idx, :], capsize=2, label=label)\n",
"\n",
"ax.set_title(\"Error Rate for Horizontal Ensemble Models\")\n",
"ax.set_ylim(0, 1)\n",
"ax.grid()\n",
"ax.legend()\n",
"ax.set_xlabel(\"Number of Models\")\n",
"ax.set_ylabel(\"Error Rate\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
2021-03-26 20:01:05 +00:00
"tags": [
"exp2"
]
},
"source": [
"### Ensemble Model Statistics"
]
},
{
"cell_type": "code",
"execution_count": 140,
"metadata": {
"tags": [
"exp2"
]
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAeQAAAFECAYAAAD2sk0XAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAArEAAAKxAFmbYLUAABK7klEQVR4nO3deZgU1b3/8fe3txlmYwcBRVwQUBBQENQoKkmMOzESF4wao1Fz0WuMxpjlJt4svxs16o1JkMQkRqPGqHFJXJIrirhijDuKQBRkUTZh9q27z++Pqu7p6dl6hunpnpnP63n66dq66pyenv7UOVVdZc45REREJLcCuS6AiIiIKJBFRETyggJZREQkDyiQRURE8oACWUREJA8okEVERPKAAlkkjZmdZmYbzazKzEbkujzpzOxxMzs91+VIMDNnZrv34PZuN7PvZrDcUWa2pifK1J3MbJyZRduZn1H9pfdRIEuXmNl7ZvZqrsuRJdcBX3bOlTjntuzKiszsPDN7spXpXQ4x59xxzrl7d6Vc7TGzpWZ2djeta62ZVZtZccq0IjOrNLO13bGNLpZrnP83qEp7TMlVmUQUyNJpZnYIMBqYbGaTenC7oR7a1Fjgnc6+KNvlM09v/J/dCMxLGT8F+Cg3RWkm5u90pT7eynWhpP/qjf/ckntnAw8D/wd8KXWGmR1jZq+YWYWZrTazI/zpw83sbjPbYmbbzOx//Ok/MLPbUl6f7GZMdN2Z2cVmthG43cwGm9kT/jq2mtmvzaygve2b2VlmtjStnH8ws2+nV8zMqoAg8J6ZvexPO8LMXjOznWb2TOpOiN/KWmhmHwBPd+XN9Ot0j1+nf5vZRSnzbjezX5jZU0ANsE9qC9bM3khr4TkzG+fP+5qZve+/T380s4Gp77GZ/ZeZfeK3Yo/1530POAK4zV/ft/3pf/H/dp+Y2X1mNqQTVbwHWJAyfjZwV9p7cICZPeu/x/8ys8NT5u1jZs/7reoHgAFpr/0P/2+9zf+7FrOL/PfxEjP7wF/vNSnzTjSvh6jSf+/O8KcHzexaM1tnZpvN7GeJnTT/c363mT3gv6/Pm9luZnarmZWb2atmtndaGS7x1/OhtdNjYd4hlhX+3+YR8w+zmPc/97j/nm4zs3t29X2RLHPO6aFHxg8gBGwGTsIL47WA+fP2BiqAE/FCbSywrz/vH8BtQCneF+qh/vQfALelrP8oYI0/PA5wwK1Aof+6of62C4BRwKvA5e1tHygCyoEx/nKF/vhebdTRAbv7w0OBHcCpQBi4ClgNhFKWfRgoAwa0sq7zgCc72MZdeKFVBBwIbAXm+PNuB7YBB/vvfRhYCpzdyjq/CzznL/MZvFbo/kAx8ABwe8p7HAWu9tf5VWBdynparB8vRIuBgcATwM2t1aWVMq0F5gDrgeH+Yz1wOLDWXyYCvA9c5pf9dOATYLA//2XgJ/5y84BG4Lv+vPnAW8Ce/ufjbuCG9M9SK+UaB0Tb+Zw74D6gBJgM1AH7+PM+Bg73h3cD9veHrwKe8us4CG8HbWHK57wab2cngvf/8AHwRf9v8FvgD2mf+9/jfVYPw/tc75fymUjU/xBgAzDFf++uA+735/0P8Ct//QXAYbn+/tCj/UfOC6BH73oAx+MFVAQvXGuBI/153wHuauU1Y4AGoLiVeT+g40Ae1U55Lkr5Amp1+/68O4Ar/OEvAC+0s87UsPwS8EzKvABeF+yhKcse2s66zsMLkJ1pDwfsjrfj0EDKzgHw/4DF/vDtieGU+UtpGZjH+V/Mu/njvwWuTZk/wf9bmf8elwMBf16RX55Bba0/bVvHAq+09n61suxa4FPATcBC/3EzMJumQD4iMZzyuheBM/GCtg4oTJn3HE2B9ARwVsq8ySnrTX6WWilX4rOV/ncJptTp4JTlXwbm+cPrgQuAkrR1riQl9PB2DJemfM4fSZl3CfBm2t/v9bSyjUuZ/0fg2ymfiUT9b01M98dL8T5vIe
CHwIO0seOpR/491GUtnXU28JBzrsE5Vwk85k8DL2A+aOU1uwNbnHPVXdhe3DmXPN5oZqVmdoeZbTCzCuBGvFZse9sH7wvtTH/4TLyWVCZGAx8mRpxzcbwv5NEpy2zoYB3POOcGpT5S5g3Da9l8mDJtXWfWb2b74H1Jf9E593Fr5fbXWQgkupq3+nXBOVfjTytpY/0hM7vZ74qtAO6n6T3P1F3AWXhd13elzRuN956mSrwHo/yy1qXMS112LLDY75bdiRfWwzMsUyz97+Kci6XM35wyXEPT+3Ma8Hlgg3mHTxKHMMYCj6eU5S4g9Sz91BMEa1sZT3//16cNj2qlDmOB76Rscz1e78duwPV4n4FnzGylmX2llddLHlEgS8bMrATvhJwvmNnHZvYx8FngNPOO467H27tPtx4YbmZFrcyrpvkxwZFp89NvR3YF3hfuNOdcmT9uKdtpbfsAS4AxZnYQXgsv07OUN+F96QHeiVXAHv70tsrYGdvwWjRjU6aNzXT9/vHSh4AfOOdeSJnVrNz+cB1eV3BH0re3AK+1eZj/np9G03ueEefcK3g7A4Odc/9Mm70J7z1NlXgPPgKGmVlhyrzUZTcC56aF6i4fQ26Pc265c+4EvM/qG8CilLIcnVKOgc65/XdhU3ukDbd2ItxG4Htp9R/gnNvgnKtwzv2nc24sXk/NLenHqSW/KJClM07F666eAEzzHxPx9shPwDsOepKZHW9mATPbw8z2cc5tAp4BbjKzEjMbYGaz/XW+ARzln+AyAvjPDspQitdaKTezPYGvpcxrdfsAfsvnT3hd188557ZmWOfHgalmdop/gs7X8Vozr2T4+nb55bof+JF5PweaDHzFL2smfgf80zm3KG36vcCFZjbJD+0fA392zmWy87CF5js2pXhhvsPMhgFXZli2dKf6j3TLAcw7OS5kZvOBScATzrl1wNvAd80sbGYn4x03Tfgd8O3E39nMRpnZ57pYvg6ZWcS8kwTL8HakqoBEq/p3eH/HUeYZZ2ZzdmFz3zOzQv9/5WS88wDS/R5YaGZT/fINMbNT/OETzGxvfyeyHG9HK9bKOiRPKJClM87GO977kXPuY/+xEe+L6Gzn3Ad4x2d/jPcFsISmbrYFeCe6rMXrRjvZn/5/wN/wjr89TetfOqn+11/nDn/ZBxMzOtg+eN3WB5B5dzXOuW14JxJdC2zH66qc55xrzHQdGViId8xvPfAIXmv36Qxf+0XgDGt+pvVY59w/8I5FP4bX/dsIXJ7hOm8BzvO7Qb+FtxOzA68L91m847ad5px7xznX4udkzrkGvM/DmXjv8TXAyc65Hf4iZwFz8Vr359H8b34P3vHyR/3u9GfwTmTLRNBa/g45kzA/F+893YF38txCf/r1eMe+n8f7/P2Vli3/TMXwdvo+xDu57FLn3HvpC/m9IlcCd/j1fxXvhDmA/fD+pyqBR/FOflzXxfJID7DMdphFej8zG453jHk351xVrssjIpJKLWTpF/xuu8vwzshWGItI3slaIJvZg2a2w8zub2P+If6P2deY2X9lqxwivo/wund/kONyiIi0Kpst5P8Fzmln/i/xjhlNAI43XUNWssg5t5tzboJzbm2uyyIi0pqsBbJzbineyQQtmNlovCsdvZly9uuJ2SqLiIhIvuupi/WnG433+7mEjXiX12vBzC7AuyoORUVFs8aNG7fLGw9+9BFWW9vBQkFcKNT8ORiEUCj5TCAA1qmfY3a7WCxGMBjMaRl2VV+oA/SNeqgO+aMv1EN1aOmdd97Z7JzbrbV5uQrkjDnnbsO7BjKzZ892L730Urete9myZRxxxBHEq6qIbtmSfDRu2UJ0y9Zm06JbtuAa037pEggSGjaM0IgRhEYMJzxihDc83H/2pwcHDcKyFNzLli3jyCOPzMq6e0pfqAP0jXqoDvmjL9RDdWjJ2rntaK4CeRPe9Y0TxtD8ykQ9xswIlpYSLC2lYJ992lzOOUds587mQb3VH97qhX
jdineIbt0K0eb3FrdwmNDw4SkhPaJliI8YQaC0NGvBLSIi+S0ngeyc22RmMTM7EFgBnAFcmIuyZMrMCA0eTGjwYJi
"text/plain": [
"<Figure size 560x350 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Mean accuracy (± 1 std across repeats) against ensemble size for horizontal ensembles.\n",
"fig, ax = plt.subplots(figsize=(8, 5), dpi=fig_dpi)\n",
"\n",
"# (row of mean_ensem_accuracy / std_ensem_accuracy, legend label), in plotting/legend order\n",
"ensem_series = [\n",
"    (0, 'Ensemble Test'),\n",
"    (2, 'Individual Test'),\n",
"    (1, 'Individual Train'),\n",
"    (3, 'Agreement'),\n",
"]\n",
"for row, label in ensem_series:\n",
"    ax.errorbar(multi_ensem_models, mean_ensem_accuracy[row, :], yerr=std_ensem_accuracy[row, :], capsize=2, label=label)\n",
"\n",
"ax.set_title(\"Accuracy for Horizontal Model Ensembles\")\n",
"ax.set_ylim(0, 1)\n",
"ax.grid()\n",
"ax.legend()\n",
"ax.set_ylabel(\"Accuracy\")\n",
"ax.set_xlabel(\"Number of Models\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FSZq1mNiVZq_",
"tags": [
"ex3"
]
2021-03-19 17:21:00 +00:00
},
"source": [
"# Experiment 3\n",
"\n",
"Repeat Experiment 2 for the cancer dataset with two different optimisers of your choice (e.g. `trainlm` and `trainrp`). Comment on and discuss the results, and decide which is the more appropriate training algorithm for the problem. In your discussion, include a detailed account of how the training algorithms (optimisers) work."
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"authorship_tag": "ABX9TyNAMGLKzaoWaq1wvQ+w0w8h",
"collapsed_sections": [],
"name": "nncw.ipynb",
"provenance": [],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
},
"toc-autonumbering": false,
2021-03-26 20:01:05 +00:00
"toc-showcode": true,
2021-03-22 20:49:29 +00:00
"toc-showmarkdowntxt": false,
2021-03-26 20:01:05 +00:00
"toc-showtags": true
2021-03-19 17:21:00 +00:00
},
"nbformat": 4,
"nbformat_minor": 4
}