Penguins example
We will use the penguin dataset to train a neural network that can classify which species a penguin belongs to, based on its physical characteristics.
(inspired by https://carpentries-incubator.github.io/deep-learning-intro/02-keras/index.html and mcnakhaee/palmerpenguins)
pip install palmerpenguins
Collecting palmerpenguins
Downloading palmerpenguins-0.1.4-py3-none-any.whl (17 kB)
Requirement already satisfied: pandas in /srv/conda/envs/notebook/lib/python3.11/site-packages (from palmerpenguins) (2.1.4)
Requirement already satisfied: numpy in /srv/conda/envs/notebook/lib/python3.11/site-packages (from palmerpenguins) (1.26.2)
Requirement already satisfied: python-dateutil>=2.8.2 in /srv/conda/envs/notebook/lib/python3.11/site-packages (from pandas->palmerpenguins) (2.8.2)
Requirement already satisfied: pytz>=2020.1 in /srv/conda/envs/notebook/lib/python3.11/site-packages (from pandas->palmerpenguins) (2023.3.post1)
Requirement already satisfied: tzdata>=2022.1 in /srv/conda/envs/notebook/lib/python3.11/site-packages (from pandas->palmerpenguins) (2023.3)
Requirement already satisfied: six>=1.5 in /srv/conda/envs/notebook/lib/python3.11/site-packages (from python-dateutil>=2.8.2->pandas->palmerpenguins) (1.16.0)
Installing collected packages: palmerpenguins
Successfully installed palmerpenguins-0.1.4
Note: you may need to restart the kernel to use updated packages.
import pandas as pd
import seaborn as sns
from palmerpenguins import load_penguins
# Use seaborn's white grid theme for the plots drawn below.
sns.set_style('whitegrid')
# Load the penguin table provided by the palmerpenguins package.
penguins = load_penguins()
# Check what kind of object the loader returned.
type(penguins)
pandas.core.frame.DataFrame
penguins.head()
| species | island | bill_length_mm | bill_depth_mm | flipper_length_mm | body_mass_g | sex | year | |
|---|---|---|---|---|---|---|---|---|
| 0 | Adelie | Torgersen | 39.1 | 18.7 | 181.0 | 3750.0 | male | 2007 |
| 1 | Adelie | Torgersen | 39.5 | 17.4 | 186.0 | 3800.0 | female | 2007 |
| 2 | Adelie | Torgersen | 40.3 | 18.0 | 195.0 | 3250.0 | female | 2007 |
| 3 | Adelie | Torgersen | NaN | NaN | NaN | NaN | NaN | 2007 |
| 4 | Adelie | Torgersen | 36.7 | 19.3 | 193.0 | 3450.0 | female | 2007 |
penguins.describe()
| bill_length_mm | bill_depth_mm | flipper_length_mm | body_mass_g | year | |
|---|---|---|---|---|---|
| count | 342.000000 | 342.000000 | 342.000000 | 342.000000 | 344.000000 |
| mean | 43.921930 | 17.151170 | 200.915205 | 4201.754386 | 2008.029070 |
| std | 5.459584 | 1.974793 | 14.061714 | 801.954536 | 0.818356 |
| min | 32.100000 | 13.100000 | 172.000000 | 2700.000000 | 2007.000000 |
| 25% | 39.225000 | 15.600000 | 190.000000 | 3550.000000 | 2007.000000 |
| 50% | 44.450000 | 17.300000 | 197.000000 | 4050.000000 | 2008.000000 |
| 75% | 48.500000 | 18.700000 | 213.000000 | 4750.000000 | 2009.000000 |
| max | 59.600000 | 21.500000 | 231.000000 | 6300.000000 | 2009.000000 |
penguins["species"].unique()
array(['Adelie', 'Gentoo', 'Chinstrap'], dtype=object)
penguins["species"].describe()
count 344
unique 3
top Adelie
freq 152
Name: species, dtype: object
sns.pairplot(penguins, hue="species")
<seaborn.axisgrid.PairGrid at 0x7f1edc44e850>
# Store the species label as a pandas categorical instead of plain strings.
penguins['species'] = penguins['species'].astype('category')
# Leave out the island and sex columns, then discard every row that still
# contains a missing value.
penguins_filtered = (
    penguins
    .drop(['island', 'sex'], axis='columns')
    .dropna()
)
penguins_filtered
| species | bill_length_mm | bill_depth_mm | flipper_length_mm | body_mass_g | year | |
|---|---|---|---|---|---|---|
| 0 | Adelie | 39.1 | 18.7 | 181.0 | 3750.0 | 2007 |
| 1 | Adelie | 39.5 | 17.4 | 186.0 | 3800.0 | 2007 |
| 2 | Adelie | 40.3 | 18.0 | 195.0 | 3250.0 | 2007 |
| 4 | Adelie | 36.7 | 19.3 | 193.0 | 3450.0 | 2007 |
| 5 | Adelie | 39.3 | 20.6 | 190.0 | 3650.0 | 2007 |
| ... | ... | ... | ... | ... | ... | ... |
| 339 | Chinstrap | 55.8 | 19.8 | 207.0 | 4000.0 | 2009 |
| 340 | Chinstrap | 43.5 | 18.1 | 202.0 | 3400.0 | 2009 |
| 341 | Chinstrap | 49.6 | 18.2 | 193.0 | 3775.0 | 2009 |
| 342 | Chinstrap | 50.8 | 19.0 | 210.0 | 4100.0 | 2009 |
| 343 | Chinstrap | 50.2 | 18.7 | 198.0 | 3775.0 | 2009 |
342 rows × 6 columns
# Extract columns corresponding to features: everything except the `species`
# label that the network is meant to predict.
# NOTE(review): `year` stays in the feature set and nothing is rescaled here;
# the very large early training losses further down may stem from that —
# confirm whether dropping `year` and standardising the columns is intended.
penguins_features = penguins_filtered.drop(columns=['species'])
penguins_features
| bill_length_mm | bill_depth_mm | flipper_length_mm | body_mass_g | year | |
|---|---|---|---|---|---|
| 0 | 39.1 | 18.7 | 181.0 | 3750.0 | 2007 |
| 1 | 39.5 | 17.4 | 186.0 | 3800.0 | 2007 |
| 2 | 40.3 | 18.0 | 195.0 | 3250.0 | 2007 |
| 4 | 36.7 | 19.3 | 193.0 | 3450.0 | 2007 |
| 5 | 39.3 | 20.6 | 190.0 | 3650.0 | 2007 |
| ... | ... | ... | ... | ... | ... |
| 339 | 55.8 | 19.8 | 207.0 | 4000.0 | 2009 |
| 340 | 43.5 | 18.1 | 202.0 | 3400.0 | 2009 |
| 341 | 49.6 | 18.2 | 193.0 | 3775.0 | 2009 |
| 342 | 50.8 | 19.0 | 210.0 | 4100.0 | 2009 |
| 343 | 50.2 | 18.7 | 198.0 | 3775.0 | 2009 |
342 rows × 5 columns
import pandas as pd
# One-hot encode the species label: one boolean column per species.
target = pd.get_dummies(penguins_filtered['species'])
target.head() # print out the top 5 to see what it looks like.
| Adelie | Chinstrap | Gentoo | |
|---|---|---|---|
| 0 | True | False | False |
| 1 | True | False | False |
| 2 | True | False | False |
| 4 | True | False | False |
| 5 | True | False | False |
from sklearn.model_selection import train_test_split
# Hold out 10% of the penguins for testing. Stratifying on the one-hot target
# keeps the species proportions similar in the training and test parts.
X_train, X_test, y_train, y_test = train_test_split(
    penguins_features,
    target,
    test_size=0.1,
    random_state=0,
    shuffle=True,
    stratify=target,
)
X_train
| bill_length_mm | bill_depth_mm | flipper_length_mm | body_mass_g | year | |
|---|---|---|---|---|---|
| 245 | 49.5 | 16.1 | 224.0 | 5650.0 | 2009 |
| 338 | 45.7 | 17.0 | 195.0 | 3650.0 | 2009 |
| 89 | 38.9 | 18.8 | 190.0 | 3600.0 | 2008 |
| 262 | 50.5 | 15.2 | 216.0 | 5000.0 | 2009 |
| 48 | 36.0 | 17.9 | 190.0 | 3450.0 | 2007 |
| ... | ... | ... | ... | ... | ... |
| 285 | 51.3 | 19.9 | 198.0 | 3700.0 | 2007 |
| 268 | 44.5 | 15.7 | 217.0 | 4875.0 | 2009 |
| 218 | 46.2 | 14.4 | 214.0 | 4650.0 | 2008 |
| 234 | 47.4 | 14.6 | 212.0 | 4725.0 | 2009 |
| 187 | 48.4 | 16.3 | 220.0 | 5400.0 | 2008 |
307 rows × 5 columns
X_test
| bill_length_mm | bill_depth_mm | flipper_length_mm | body_mass_g | year | |
|---|---|---|---|---|---|
| 333 | 49.3 | 19.9 | 203.0 | 4050.0 | 2009 |
| 112 | 39.7 | 17.7 | 193.0 | 3200.0 | 2009 |
| 239 | 51.3 | 14.2 | 218.0 | 5300.0 | 2009 |
| 269 | 48.8 | 16.2 | 222.0 | 6000.0 | 2009 |
| 311 | 47.5 | 16.8 | 199.0 | 3900.0 | 2008 |
| 12 | 41.1 | 17.6 | 182.0 | 3200.0 | 2007 |
| 219 | 49.5 | 16.2 | 229.0 | 5800.0 | 2008 |
| 278 | 51.3 | 19.2 | 193.0 | 3650.0 | 2007 |
| 161 | 46.8 | 15.4 | 215.0 | 5150.0 | 2007 |
| 84 | 37.3 | 17.8 | 191.0 | 3350.0 | 2008 |
| 173 | 45.1 | 14.5 | 215.0 | 5000.0 | 2007 |
| 119 | 41.1 | 18.6 | 189.0 | 3325.0 | 2009 |
| 235 | 50.0 | 15.9 | 224.0 | 5350.0 | 2009 |
| 80 | 34.6 | 17.2 | 189.0 | 3200.0 | 2008 |
| 42 | 36.0 | 18.5 | 186.0 | 3100.0 | 2007 |
| 255 | 49.1 | 15.0 | 228.0 | 5500.0 | 2009 |
| 122 | 40.2 | 17.0 | 176.0 | 3450.0 | 2009 |
| 72 | 39.6 | 17.2 | 196.0 | 3550.0 | 2008 |
| 14 | 34.6 | 21.1 | 198.0 | 4400.0 | 2007 |
| 216 | 45.8 | 14.2 | 219.0 | 4700.0 | 2008 |
| 199 | 50.5 | 15.9 | 225.0 | 5400.0 | 2008 |
| 299 | 50.6 | 19.4 | 193.0 | 3800.0 | 2007 |
| 32 | 39.5 | 17.8 | 188.0 | 3300.0 | 2007 |
| 19 | 46.0 | 21.5 | 194.0 | 4200.0 | 2007 |
| 334 | 50.2 | 18.8 | 202.0 | 3800.0 | 2009 |
| 188 | 42.6 | 13.7 | 213.0 | 4950.0 | 2008 |
| 222 | 47.7 | 15.0 | 216.0 | 4750.0 | 2008 |
| 105 | 39.7 | 18.9 | 184.0 | 3550.0 | 2009 |
| 310 | 49.7 | 18.6 | 195.0 | 3600.0 | 2008 |
| 41 | 40.8 | 18.4 | 195.0 | 3900.0 | 2007 |
| 196 | 50.5 | 15.9 | 222.0 | 5550.0 | 2008 |
| 71 | 39.7 | 18.4 | 190.0 | 3900.0 | 2008 |
| 335 | 45.6 | 19.4 | 194.0 | 3525.0 | 2009 |
| 118 | 35.7 | 17.0 | 189.0 | 3350.0 | 2009 |
| 270 | 47.2 | 13.7 | 214.0 | 4925.0 | 2009 |
y_train
| Adelie | Chinstrap | Gentoo | |
|---|---|---|---|
| 245 | False | False | True |
| 338 | False | True | False |
| 89 | True | False | False |
| 262 | False | False | True |
| 48 | True | False | False |
| ... | ... | ... | ... |
| 285 | False | True | False |
| 268 | False | False | True |
| 218 | False | False | True |
| 234 | False | False | True |
| 187 | False | False | True |
307 rows × 3 columns
y_test
| Adelie | Chinstrap | Gentoo | |
|---|---|---|---|
| 333 | False | True | False |
| 112 | True | False | False |
| 239 | False | False | True |
| 269 | False | False | True |
| 311 | False | True | False |
| 12 | True | False | False |
| 219 | False | False | True |
| 278 | False | True | False |
| 161 | False | False | True |
| 84 | True | False | False |
| 173 | False | False | True |
| 119 | True | False | False |
| 235 | False | False | True |
| 80 | True | False | False |
| 42 | True | False | False |
| 255 | False | False | True |
| 122 | True | False | False |
| 72 | True | False | False |
| 14 | True | False | False |
| 216 | False | False | True |
| 199 | False | False | True |
| 299 | False | True | False |
| 32 | True | False | False |
| 19 | True | False | False |
| 334 | False | True | False |
| 188 | False | False | True |
| 222 | False | False | True |
| 105 | True | False | False |
| 310 | False | True | False |
| 41 | True | False | False |
| 196 | False | False | True |
| 71 | True | False | False |
| 335 | False | True | False |
| 118 | True | False | False |
| 270 | False | False | True |
from tensorflow import keras
2024-01-09 07:27:03.676328: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
# Fix the NumPy and TensorFlow random seeds so that repeated runs start from
# the same random state.
from numpy.random import seed
seed(1)
from tensorflow.random import set_seed
set_seed(2)
# The input layer takes one value per feature column of the training data.
# `shape` is documented as a shape tuple; a bare integer is only tolerated by
# older tf.keras releases, so the single dimension is wrapped in a tuple.
inputs = keras.Input(shape=(X_train.shape[1],))
inputs
<KerasTensor: shape=(None, 5) dtype=float32 (created by layer 'input_1')>
# Hidden layer: 10 fully connected units with ReLU activation, fed by the inputs.
hidden_layer = keras.layers.Dense(10, activation="relu")(inputs)
2024-01-09 07:27:06.498989: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1639] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 20953 MB memory: -> device: 0, name: NVIDIA A10, pci bus id: 0000:61:00.0, compute capability: 8.6
# Output layer: 3 softmax units, one per species column of the target.
output_layer = keras.layers.Dense(3, activation="softmax")(hidden_layer)
# Tie the input and output tensors together into a model.
model = keras.Model(inputs=inputs, outputs=output_layer)
# Print the layer-by-layer overview with parameter counts.
model.summary()
Model: "model"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_1 (InputLayer) [(None, 5)] 0
dense (Dense) (None, 10) 60
dense_1 (Dense) (None, 3) 33
=================================================================
Total params: 93 (372.00 Byte)
Trainable params: 93 (372.00 Byte)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________
import tensorflow as tf
# Categorical cross-entropy pairs with the one-hot encoded target; the Adam
# optimizer is created with its default arguments.
model.compile(optimizer=tf.keras.optimizers.legacy.Adam(), loss=keras.losses.CategoricalCrossentropy())
# Train for 100 epochs and keep the returned history for plotting the loss.
history = model.fit(X_train, y_train, epochs=100)
Epoch 1/100
10/10 [==============================] - 0s 2ms/step - loss: 1015.4874
Epoch 2/100
10/10 [==============================] - 0s 2ms/step - loss: 884.8992
Epoch 3/100
10/10 [==============================] - 0s 2ms/step - loss: 761.2500
Epoch 4/100
10/10 [==============================] - 0s 2ms/step - loss: 643.1030
Epoch 5/100
10/10 [==============================] - 0s 2ms/step - loss: 535.1355
Epoch 6/100
10/10 [==============================] - 0s 2ms/step - loss: 457.2358
Epoch 7/100
10/10 [==============================] - 0s 2ms/step - loss: 378.4028
Epoch 8/100
1/10 [==>...........................] - ETA: 0s - loss: 378.2369
2024-01-09 07:27:07.202523: I tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:606] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once.
10/10 [==============================] - 0s 2ms/step - loss: 303.1219
Epoch 9/100
10/10 [==============================] - 0s 2ms/step - loss: 226.0870
Epoch 10/100
10/10 [==============================] - 0s 2ms/step - loss: 154.4714
Epoch 11/100
10/10 [==============================] - 0s 2ms/step - loss: 82.2054
Epoch 12/100
10/10 [==============================] - 0s 2ms/step - loss: 26.8942
Epoch 13/100
10/10 [==============================] - 0s 2ms/step - loss: 29.9371
Epoch 14/100
10/10 [==============================] - 0s 2ms/step - loss: 21.5699
Epoch 15/100
10/10 [==============================] - 0s 2ms/step - loss: 20.9316
Epoch 16/100
10/10 [==============================] - 0s 2ms/step - loss: 19.9883
Epoch 17/100
10/10 [==============================] - 0s 2ms/step - loss: 18.9656
Epoch 18/100
10/10 [==============================] - 0s 2ms/step - loss: 18.5297
Epoch 19/100
10/10 [==============================] - 0s 2ms/step - loss: 17.7819
Epoch 20/100
10/10 [==============================] - 0s 2ms/step - loss: 16.8236
Epoch 21/100
10/10 [==============================] - 0s 2ms/step - loss: 16.1105
Epoch 22/100
10/10 [==============================] - 0s 2ms/step - loss: 15.7117
Epoch 23/100
10/10 [==============================] - 0s 2ms/step - loss: 14.7285
Epoch 24/100
10/10 [==============================] - 0s 2ms/step - loss: 13.9431
Epoch 25/100
10/10 [==============================] - 0s 2ms/step - loss: 13.1817
Epoch 26/100
10/10 [==============================] - 0s 2ms/step - loss: 12.4109
Epoch 27/100
10/10 [==============================] - 0s 2ms/step - loss: 12.4272
Epoch 28/100
10/10 [==============================] - 0s 2ms/step - loss: 11.2461
Epoch 29/100
10/10 [==============================] - 0s 2ms/step - loss: 10.7755
Epoch 30/100
10/10 [==============================] - 0s 2ms/step - loss: 9.7845
Epoch 31/100
10/10 [==============================] - 0s 2ms/step - loss: 9.1328
Epoch 32/100
10/10 [==============================] - 0s 2ms/step - loss: 8.9986
Epoch 33/100
10/10 [==============================] - 0s 2ms/step - loss: 7.6808
Epoch 34/100
10/10 [==============================] - 0s 2ms/step - loss: 7.3601
Epoch 35/100
10/10 [==============================] - 0s 2ms/step - loss: 6.5697
Epoch 36/100
10/10 [==============================] - 0s 2ms/step - loss: 6.0516
Epoch 37/100
10/10 [==============================] - 0s 2ms/step - loss: 5.4636
Epoch 38/100
10/10 [==============================] - 0s 2ms/step - loss: 5.0098
Epoch 39/100
10/10 [==============================] - 0s 2ms/step - loss: 4.7890
Epoch 40/100
10/10 [==============================] - 0s 2ms/step - loss: 4.3679
Epoch 41/100
10/10 [==============================] - 0s 2ms/step - loss: 4.5915
Epoch 42/100
10/10 [==============================] - 0s 2ms/step - loss: 4.2618
Epoch 43/100
10/10 [==============================] - 0s 2ms/step - loss: 3.8259
Epoch 44/100
10/10 [==============================] - 0s 2ms/step - loss: 3.7209
Epoch 45/100
10/10 [==============================] - 0s 2ms/step - loss: 3.4403
Epoch 46/100
10/10 [==============================] - 0s 2ms/step - loss: 3.4103
Epoch 47/100
10/10 [==============================] - 0s 2ms/step - loss: 3.8025
Epoch 48/100
10/10 [==============================] - 0s 2ms/step - loss: 3.6355
Epoch 49/100
10/10 [==============================] - 0s 2ms/step - loss: 3.1812
Epoch 50/100
10/10 [==============================] - 0s 2ms/step - loss: 2.6628
Epoch 51/100
10/10 [==============================] - 0s 2ms/step - loss: 2.4035
Epoch 52/100
10/10 [==============================] - 0s 2ms/step - loss: 2.3838
Epoch 53/100
10/10 [==============================] - 0s 2ms/step - loss: 3.0629
Epoch 54/100
10/10 [==============================] - 0s 2ms/step - loss: 2.3382
Epoch 55/100
10/10 [==============================] - 0s 2ms/step - loss: 2.1033
Epoch 56/100
10/10 [==============================] - 0s 2ms/step - loss: 1.6672
Epoch 57/100
10/10 [==============================] - 0s 2ms/step - loss: 1.5632
Epoch 58/100
10/10 [==============================] - 0s 2ms/step - loss: 1.5092
Epoch 59/100
10/10 [==============================] - 0s 2ms/step - loss: 1.3471
Epoch 60/100
10/10 [==============================] - 0s 2ms/step - loss: 1.2847
Epoch 61/100
10/10 [==============================] - 0s 2ms/step - loss: 1.5229
Epoch 62/100
10/10 [==============================] - 0s 2ms/step - loss: 1.5827
Epoch 63/100
10/10 [==============================] - 0s 2ms/step - loss: 1.2878
Epoch 64/100
10/10 [==============================] - 0s 2ms/step - loss: 1.0977
Epoch 65/100
10/10 [==============================] - 0s 2ms/step - loss: 1.1016
Epoch 66/100
10/10 [==============================] - 0s 2ms/step - loss: 1.2771
Epoch 67/100
10/10 [==============================] - 0s 2ms/step - loss: 1.5543
Epoch 68/100
10/10 [==============================] - 0s 2ms/step - loss: 1.1844
Epoch 69/100
10/10 [==============================] - 0s 2ms/step - loss: 1.1441
Epoch 70/100
10/10 [==============================] - 0s 2ms/step - loss: 1.1843
Epoch 71/100
10/10 [==============================] - 0s 2ms/step - loss: 1.1437
Epoch 72/100
10/10 [==============================] - 0s 2ms/step - loss: 1.0442
Epoch 73/100
10/10 [==============================] - 0s 2ms/step - loss: 1.1147
Epoch 74/100
10/10 [==============================] - 0s 2ms/step - loss: 1.0953
Epoch 75/100
10/10 [==============================] - 0s 2ms/step - loss: 1.1173
Epoch 76/100
10/10 [==============================] - 0s 2ms/step - loss: 1.0520
Epoch 77/100
10/10 [==============================] - 0s 2ms/step - loss: 1.0242
Epoch 78/100
10/10 [==============================] - 0s 2ms/step - loss: 1.0637
Epoch 79/100
10/10 [==============================] - 0s 2ms/step - loss: 1.1284
Epoch 80/100
10/10 [==============================] - 0s 2ms/step - loss: 1.1324
Epoch 81/100
10/10 [==============================] - 0s 2ms/step - loss: 1.1172
Epoch 82/100
10/10 [==============================] - 0s 2ms/step - loss: 1.2402
Epoch 83/100
10/10 [==============================] - 0s 2ms/step - loss: 1.1240
Epoch 84/100
10/10 [==============================] - 0s 1ms/step - loss: 1.0794
Epoch 85/100
10/10 [==============================] - 0s 1ms/step - loss: 1.0693
Epoch 86/100
10/10 [==============================] - 0s 1ms/step - loss: 1.0898
Epoch 87/100
10/10 [==============================] - 0s 1ms/step - loss: 1.0458
Epoch 88/100
10/10 [==============================] - 0s 1ms/step - loss: 1.2224
Epoch 89/100
10/10 [==============================] - 0s 1ms/step - loss: 1.0337
Epoch 90/100
10/10 [==============================] - 0s 1ms/step - loss: 1.1267
Epoch 91/100
10/10 [==============================] - 0s 1ms/step - loss: 2.1266
Epoch 92/100
10/10 [==============================] - 0s 1ms/step - loss: 1.9268
Epoch 93/100
10/10 [==============================] - 0s 1ms/step - loss: 1.1867
Epoch 94/100
10/10 [==============================] - 0s 1ms/step - loss: 0.9464
Epoch 95/100
10/10 [==============================] - 0s 990us/step - loss: 1.4309
Epoch 96/100
10/10 [==============================] - 0s 1ms/step - loss: 1.1865
Epoch 97/100
10/10 [==============================] - 0s 996us/step - loss: 1.6076
Epoch 98/100
10/10 [==============================] - 0s 1ms/step - loss: 1.6465
Epoch 99/100
10/10 [==============================] - 0s 1ms/step - loss: 1.3923
Epoch 100/100
10/10 [==============================] - 0s 1ms/step - loss: 1.0014
# Plot the training loss recorded for each epoch.
sns.lineplot(x=history.epoch, y=history.history['loss'])
<Axes: >
# Predict on the held-out penguins; each row holds one score per species.
y_pred = model.predict(X_test)
# Label the score columns with the species names of the one-hot target.
# NOTE(review): the frame gets a fresh 0..n-1 index, not X_test's row labels.
prediction = pd.DataFrame(y_pred, columns=target.columns)
prediction
2/2 [==============================] - 0s 7ms/step
| Adelie | Chinstrap | Gentoo | |
|---|---|---|---|
| 0 | 7.648391e-01 | 1.715673e-01 | 6.359363e-02 |
| 1 | 5.454711e-01 | 4.545289e-01 | 2.100461e-08 |
| 2 | 1.833452e-09 | 6.493035e-10 | 1.000000e+00 |
| 3 | 7.966254e-15 | 1.784904e-15 | 1.000000e+00 |
| 4 | 8.872245e-01 | 1.109848e-01 | 1.790675e-03 |
| 5 | 9.873828e-01 | 1.261715e-02 | 4.766879e-10 |
| 6 | 1.018278e-14 | 5.001138e-14 | 1.000000e+00 |
| 7 | 9.829051e-01 | 1.709081e-02 | 3.938773e-06 |
| 8 | 1.619196e-08 | 9.007781e-09 | 1.000000e+00 |
| 9 | 7.116187e-01 | 2.883812e-01 | 1.857964e-07 |
| 10 | 1.152847e-07 | 1.505111e-07 | 9.999998e-01 |
| 11 | 9.103811e-01 | 8.961879e-02 | 3.948410e-08 |
| 12 | 5.516306e-11 | 1.907833e-10 | 1.000000e+00 |
| 13 | 6.351340e-01 | 3.648659e-01 | 1.353383e-08 |
| 14 | 8.191870e-01 | 1.808130e-01 | 1.281623e-09 |
| 15 | 1.306449e-12 | 1.319104e-11 | 1.000000e+00 |
| 16 | 9.992987e-01 | 7.012586e-04 | 1.666273e-09 |
| 17 | 5.740452e-01 | 4.259438e-01 | 1.094981e-05 |
| 18 | 2.500658e-02 | 6.073019e-03 | 9.689205e-01 |
| 19 | 2.341845e-06 | 3.067315e-05 | 9.999670e-01 |
| 20 | 1.879179e-11 | 7.223661e-11 | 1.000000e+00 |
| 21 | 9.881677e-01 | 1.179230e-02 | 4.009461e-05 |
| 22 | 9.027959e-01 | 9.720411e-02 | 2.563040e-08 |
| 23 | 9.342888e-01 | 1.061412e-02 | 5.509710e-02 |
| 24 | 7.657505e-01 | 2.333840e-01 | 8.654961e-04 |
| 25 | 3.820047e-07 | 4.642012e-07 | 9.999992e-01 |
| 26 | 4.629333e-06 | 1.199403e-05 | 9.999834e-01 |
| 27 | 9.900614e-01 | 9.938368e-03 | 2.153488e-07 |
| 28 | 9.464169e-01 | 5.357845e-02 | 4.640924e-06 |
| 29 | 8.868090e-01 | 1.114146e-01 | 1.776450e-03 |
| 30 | 6.899412e-12 | 5.361637e-12 | 1.000000e+00 |
| 31 | 9.746879e-01 | 2.498482e-02 | 3.272409e-04 |
| 32 | 8.940583e-01 | 1.059394e-01 | 2.274549e-06 |
| 33 | 7.915920e-01 | 2.084079e-01 | 1.082515e-07 |
| 34 | 9.376542e-07 | 6.750374e-07 | 9.999983e-01 |
# For every row, pick the species whose column holds the highest score.
predicted_species = prediction.idxmax(axis="columns")
predicted_species
0 Adelie
1 Adelie
2 Gentoo
3 Gentoo
4 Adelie
5 Adelie
6 Gentoo
7 Adelie
8 Gentoo
9 Adelie
10 Gentoo
11 Adelie
12 Gentoo
13 Adelie
14 Adelie
15 Gentoo
16 Adelie
17 Adelie
18 Gentoo
19 Gentoo
20 Gentoo
21 Adelie
22 Adelie
23 Adelie
24 Adelie
25 Gentoo
26 Gentoo
27 Adelie
28 Adelie
29 Adelie
30 Gentoo
31 Adelie
32 Adelie
33 Adelie
34 Gentoo
dtype: category
Categories (3, object): ['Adelie', 'Chinstrap', 'Gentoo']
from sklearn.metrics import confusion_matrix
# Recover the true species name from the one-hot encoded test labels.
true_species = y_test.idxmax(axis="columns")
# Pass the class order explicitly so the rows and columns of the matrix always
# follow y_test.columns, even if a species were missing from this split;
# otherwise scikit-learn infers the labels from the values it happens to see.
matrix = confusion_matrix(true_species, predicted_species, labels=list(y_test.columns))
print(matrix)
[[14 0 1]
[ 7 0 0]
[ 0 0 13]]
# Wrap the confusion matrix in a DataFrame labelled with the species names.
# Naming both axes gives the heatmap readable x and y axis titles.
confusion_df = pd.DataFrame(
    matrix,
    index=y_test.columns.values,
    columns=y_test.columns.values,
).rename_axis(index='True Label', columns='Predicted Label')
sns.heatmap(confusion_df, annot=True)
<Axes: xlabel='Predicted Label', ylabel='True Label'>
# Save the trained model to disk under the name 'my_first_model'.
model.save('my_first_model')
INFO:tensorflow:Assets written to: my_first_model/assets
INFO:tensorflow:Assets written to: my_first_model/assets
# Load the model back from the location it was saved to.
pretrained_model = keras.models.load_model('my_first_model')
# Use the reloaded model to predict on the same test features.
y_pretrained_pred = pretrained_model.predict(X_test)
# Label the score columns with the species names of the one-hot target.
pretrained_prediction = pd.DataFrame(y_pretrained_pred, columns=target.columns.values)
# idxmax selects, for each row, the column holding the highest value.
pretrained_predicted_species = pretrained_prediction.idxmax(axis="columns")
print(pretrained_predicted_species)
2/2 [==============================] - 0s 2ms/step
0 Adelie
1 Adelie
2 Gentoo
3 Gentoo
4 Adelie
5 Adelie
6 Gentoo
7 Adelie
8 Gentoo
9 Adelie
10 Gentoo
11 Adelie
12 Gentoo
13 Adelie
14 Adelie
15 Gentoo
16 Adelie
17 Adelie
18 Gentoo
19 Gentoo
20 Gentoo
21 Adelie
22 Adelie
23 Adelie
24 Adelie
25 Gentoo
26 Gentoo
27 Adelie
28 Adelie
29 Adelie
30 Gentoo
31 Adelie
32 Adelie
33 Adelie
34 Gentoo
dtype: category
Categories (3, object): ['Adelie', 'Chinstrap', 'Gentoo']