Penguins example

We will use the penguin dataset to train a neural network that classifies which species a penguin belongs to, based on its physical characteristics.

(Inspired by https://carpentries-incubator.github.io/deep-learning-intro/02-keras/index.html and mcnakhaee/palmerpenguins.)

pip install palmerpenguins
Collecting palmerpenguins
  Downloading palmerpenguins-0.1.4-py3-none-any.whl (17 kB)
Requirement already satisfied: pandas in /srv/conda/envs/notebook/lib/python3.11/site-packages (from palmerpenguins) (2.1.4)
Requirement already satisfied: numpy in /srv/conda/envs/notebook/lib/python3.11/site-packages (from palmerpenguins) (1.26.2)
Requirement already satisfied: python-dateutil>=2.8.2 in /srv/conda/envs/notebook/lib/python3.11/site-packages (from pandas->palmerpenguins) (2.8.2)
Requirement already satisfied: pytz>=2020.1 in /srv/conda/envs/notebook/lib/python3.11/site-packages (from pandas->palmerpenguins) (2023.3.post1)
Requirement already satisfied: tzdata>=2022.1 in /srv/conda/envs/notebook/lib/python3.11/site-packages (from pandas->palmerpenguins) (2023.3)
Requirement already satisfied: six>=1.5 in /srv/conda/envs/notebook/lib/python3.11/site-packages (from python-dateutil>=2.8.2->pandas->palmerpenguins) (1.16.0)
Installing collected packages: palmerpenguins
Successfully installed palmerpenguins-0.1.4
Note: you may need to restart the kernel to use updated packages.
import pandas as pd
import seaborn as sns 
from palmerpenguins import load_penguins
# Use seaborn's 'whitegrid' style for the plots drawn in this notebook
sns.set_style('whitegrid')
# Load the Palmer penguins dataset and check which type of object it is
penguins = load_penguins()
type(penguins)
pandas.core.frame.DataFrame
# Preview the first rows of the dataset
penguins.head()
species island bill_length_mm bill_depth_mm flipper_length_mm body_mass_g sex year
0 Adelie Torgersen 39.1 18.7 181.0 3750.0 male 2007
1 Adelie Torgersen 39.5 17.4 186.0 3800.0 female 2007
2 Adelie Torgersen 40.3 18.0 195.0 3250.0 female 2007
3 Adelie Torgersen NaN NaN NaN NaN NaN 2007
4 Adelie Torgersen 36.7 19.3 193.0 3450.0 female 2007
# Summary statistics for the numeric columns
penguins.describe()
bill_length_mm bill_depth_mm flipper_length_mm body_mass_g year
count 342.000000 342.000000 342.000000 342.000000 344.000000
mean 43.921930 17.151170 200.915205 4201.754386 2008.029070
std 5.459584 1.974793 14.061714 801.954536 0.818356
min 32.100000 13.100000 172.000000 2700.000000 2007.000000
25% 39.225000 15.600000 190.000000 3550.000000 2007.000000
50% 44.450000 17.300000 197.000000 4050.000000 2008.000000
75% 48.500000 18.700000 213.000000 4750.000000 2009.000000
max 59.600000 21.500000 231.000000 6300.000000 2009.000000
# List the distinct species present in the data
penguins["species"].unique()
array(['Adelie', 'Gentoo', 'Chinstrap'], dtype=object)
# Summarise the species column: count, number of unique values and most frequent value
penguins["species"].describe()
count        344
unique         3
top       Adelie
freq         152
Name: species, dtype: object
# Plot the pairwise relationships between the numeric columns, coloured by species
sns.pairplot(penguins, hue="species")
<seaborn.axisgrid.PairGrid at 0x7f1edc44e850>
_images/de407d08e3cbac9a0ff7b4ba1b6c5e34cb159a9340fe117bf3692ce18d9190c7.png
# Convert the species column to a categorical dtype
penguins['species'] = penguins['species'].astype('category')
# Drop the 'island' and 'sex' columns, then drop the rows that still contain NaN values
penguins_filtered = penguins.drop(columns=['island', 'sex']).dropna()
penguins_filtered
species bill_length_mm bill_depth_mm flipper_length_mm body_mass_g year
0 Adelie 39.1 18.7 181.0 3750.0 2007
1 Adelie 39.5 17.4 186.0 3800.0 2007
2 Adelie 40.3 18.0 195.0 3250.0 2007
4 Adelie 36.7 19.3 193.0 3450.0 2007
5 Adelie 39.3 20.6 190.0 3650.0 2007
... ... ... ... ... ... ...
339 Chinstrap 55.8 19.8 207.0 4000.0 2009
340 Chinstrap 43.5 18.1 202.0 3400.0 2009
341 Chinstrap 49.6 18.2 193.0 3775.0 2009
342 Chinstrap 50.8 19.0 210.0 4100.0 2009
343 Chinstrap 50.2 18.7 198.0 3775.0 2009

342 rows × 6 columns

# Extract columns corresponding to features (everything except the species label)
# NOTE(review): this keeps `year` as an input feature alongside the physical
# measurements - confirm that this is intended.
penguins_features = penguins_filtered.drop(columns=['species'])
penguins_features
bill_length_mm bill_depth_mm flipper_length_mm body_mass_g year
0 39.1 18.7 181.0 3750.0 2007
1 39.5 17.4 186.0 3800.0 2007
2 40.3 18.0 195.0 3250.0 2007
4 36.7 19.3 193.0 3450.0 2007
5 39.3 20.6 190.0 3650.0 2007
... ... ... ... ... ...
339 55.8 19.8 207.0 4000.0 2009
340 43.5 18.1 202.0 3400.0 2009
341 49.6 18.2 193.0 3775.0 2009
342 50.8 19.0 210.0 4100.0 2009
343 50.2 18.7 198.0 3775.0 2009

342 rows × 5 columns

import pandas as pd

# One-hot encode the species label: one boolean column per species
target = pd.get_dummies(penguins_filtered['species'])
target.head() # print out the top 5 to see what it looks like.
Adelie Chinstrap Gentoo
0 True False False
1 True False False
2 True False False
4 True False False
5 True False False
from sklearn.model_selection import train_test_split

# Hold out 10% of the rows as a test set. The split is shuffled with a fixed
# random_state and stratified on the one-hot encoded species columns.
X_train, X_test, y_train, y_test = train_test_split(
    penguins_features,
    target,
    test_size=0.1,
    random_state=0,
    shuffle=True,
    stratify=target,
)
X_train
bill_length_mm bill_depth_mm flipper_length_mm body_mass_g year
245 49.5 16.1 224.0 5650.0 2009
338 45.7 17.0 195.0 3650.0 2009
89 38.9 18.8 190.0 3600.0 2008
262 50.5 15.2 216.0 5000.0 2009
48 36.0 17.9 190.0 3450.0 2007
... ... ... ... ... ...
285 51.3 19.9 198.0 3700.0 2007
268 44.5 15.7 217.0 4875.0 2009
218 46.2 14.4 214.0 4650.0 2008
234 47.4 14.6 212.0 4725.0 2009
187 48.4 16.3 220.0 5400.0 2008

307 rows × 5 columns

# Inspect the features of the held-out test set
X_test
bill_length_mm bill_depth_mm flipper_length_mm body_mass_g year
333 49.3 19.9 203.0 4050.0 2009
112 39.7 17.7 193.0 3200.0 2009
239 51.3 14.2 218.0 5300.0 2009
269 48.8 16.2 222.0 6000.0 2009
311 47.5 16.8 199.0 3900.0 2008
12 41.1 17.6 182.0 3200.0 2007
219 49.5 16.2 229.0 5800.0 2008
278 51.3 19.2 193.0 3650.0 2007
161 46.8 15.4 215.0 5150.0 2007
84 37.3 17.8 191.0 3350.0 2008
173 45.1 14.5 215.0 5000.0 2007
119 41.1 18.6 189.0 3325.0 2009
235 50.0 15.9 224.0 5350.0 2009
80 34.6 17.2 189.0 3200.0 2008
42 36.0 18.5 186.0 3100.0 2007
255 49.1 15.0 228.0 5500.0 2009
122 40.2 17.0 176.0 3450.0 2009
72 39.6 17.2 196.0 3550.0 2008
14 34.6 21.1 198.0 4400.0 2007
216 45.8 14.2 219.0 4700.0 2008
199 50.5 15.9 225.0 5400.0 2008
299 50.6 19.4 193.0 3800.0 2007
32 39.5 17.8 188.0 3300.0 2007
19 46.0 21.5 194.0 4200.0 2007
334 50.2 18.8 202.0 3800.0 2009
188 42.6 13.7 213.0 4950.0 2008
222 47.7 15.0 216.0 4750.0 2008
105 39.7 18.9 184.0 3550.0 2009
310 49.7 18.6 195.0 3600.0 2008
41 40.8 18.4 195.0 3900.0 2007
196 50.5 15.9 222.0 5550.0 2008
71 39.7 18.4 190.0 3900.0 2008
335 45.6 19.4 194.0 3525.0 2009
118 35.7 17.0 189.0 3350.0 2009
270 47.2 13.7 214.0 4925.0 2009
# Inspect the one-hot encoded training labels
y_train
Adelie Chinstrap Gentoo
245 False False True
338 False True False
89 True False False
262 False False True
48 True False False
... ... ... ...
285 False True False
268 False False True
218 False False True
234 False False True
187 False False True

307 rows × 3 columns

# Inspect the one-hot encoded test labels
y_test
Adelie Chinstrap Gentoo
333 False True False
112 True False False
239 False False True
269 False False True
311 False True False
12 True False False
219 False False True
278 False True False
161 False False True
84 True False False
173 False False True
119 True False False
235 False False True
80 True False False
42 True False False
255 False False True
122 True False False
72 True False False
14 True False False
216 False False True
199 False False True
299 False True False
32 True False False
19 True False False
334 False True False
188 False False True
222 False False True
105 True False False
310 False True False
41 True False False
196 False False True
71 True False False
335 False True False
118 True False False
270 False False True
from tensorflow import keras
2024-01-09 07:27:03.676328: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
# Seed the NumPy and TensorFlow random number generators with fixed values
from numpy.random import seed
seed(1)
from tensorflow.random import set_seed
set_seed(2)
# Symbolic input with one entry per feature column of X_train.
# `shape` is passed as a tuple that excludes the batch dimension, which is the
# form the Keras API documents; a bare integer is not a valid shape in newer
# Keras releases.
inputs = keras.Input(shape=(X_train.shape[1],))
inputs
<KerasTensor: shape=(None, 5) dtype=float32 (created by layer 'input_1')>
# Fully connected hidden layer with 10 units and ReLU activation, applied to the inputs
hidden_layer = keras.layers.Dense(10, activation="relu")(inputs)
2024-01-09 07:27:06.498989: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1639] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 20953 MB memory:  -> device: 0, name: NVIDIA A10, pci bus id: 0000:61:00.0, compute capability: 8.6
# Output layer with 3 units (the one-hot target has three species columns) and softmax activation
output_layer = keras.layers.Dense(3, activation="softmax")(hidden_layer)
# Tie the input and output tensors together into a model and print its layer summary
model = keras.Model(inputs=inputs, outputs=output_layer)
model.summary()
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_1 (InputLayer)        [(None, 5)]               0         
                                                                 
 dense (Dense)               (None, 10)                60        
                                                                 
 dense_1 (Dense)             (None, 3)                 33        
                                                                 
=================================================================
Total params: 93 (372.00 Byte)
Trainable params: 93 (372.00 Byte)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________
import tensorflow as tf
# Compile with the legacy Adam optimizer (default settings) and categorical
# cross-entropy loss, which takes the one-hot encoded targets.
model.compile(optimizer=tf.keras.optimizers.legacy.Adam(), loss=keras.losses.CategoricalCrossentropy())
# NOTE(review): the features are passed in unscaled (body_mass_g is in the
# thousands). In the recorded run the loss starts above 1000 and the trained
# model predicts no Chinstrap on the test set - likely related; consider
# standardizing the features before training.
# Train for 100 epochs; `history` records the loss of every epoch.
history = model.fit(X_train, y_train, epochs=100)
Epoch 1/100
10/10 [==============================] - 0s 2ms/step - loss: 1015.4874
Epoch 2/100
10/10 [==============================] - 0s 2ms/step - loss: 884.8992
Epoch 3/100
10/10 [==============================] - 0s 2ms/step - loss: 761.2500
Epoch 4/100
10/10 [==============================] - 0s 2ms/step - loss: 643.1030
Epoch 5/100
10/10 [==============================] - 0s 2ms/step - loss: 535.1355
Epoch 6/100
10/10 [==============================] - 0s 2ms/step - loss: 457.2358
Epoch 7/100
10/10 [==============================] - 0s 2ms/step - loss: 378.4028
Epoch 8/100
 1/10 [==>...........................] - ETA: 0s - loss: 378.2369
2024-01-09 07:27:07.202523: I tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:606] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once.
10/10 [==============================] - 0s 2ms/step - loss: 303.1219
Epoch 9/100
10/10 [==============================] - 0s 2ms/step - loss: 226.0870
Epoch 10/100
10/10 [==============================] - 0s 2ms/step - loss: 154.4714
Epoch 11/100
10/10 [==============================] - 0s 2ms/step - loss: 82.2054
Epoch 12/100
10/10 [==============================] - 0s 2ms/step - loss: 26.8942
Epoch 13/100
10/10 [==============================] - 0s 2ms/step - loss: 29.9371
Epoch 14/100
10/10 [==============================] - 0s 2ms/step - loss: 21.5699
Epoch 15/100
10/10 [==============================] - 0s 2ms/step - loss: 20.9316
Epoch 16/100
10/10 [==============================] - 0s 2ms/step - loss: 19.9883
Epoch 17/100
10/10 [==============================] - 0s 2ms/step - loss: 18.9656
Epoch 18/100
10/10 [==============================] - 0s 2ms/step - loss: 18.5297
Epoch 19/100
10/10 [==============================] - 0s 2ms/step - loss: 17.7819
Epoch 20/100
10/10 [==============================] - 0s 2ms/step - loss: 16.8236
Epoch 21/100
10/10 [==============================] - 0s 2ms/step - loss: 16.1105
Epoch 22/100
10/10 [==============================] - 0s 2ms/step - loss: 15.7117
Epoch 23/100
10/10 [==============================] - 0s 2ms/step - loss: 14.7285
Epoch 24/100
10/10 [==============================] - 0s 2ms/step - loss: 13.9431
Epoch 25/100
10/10 [==============================] - 0s 2ms/step - loss: 13.1817
Epoch 26/100
10/10 [==============================] - 0s 2ms/step - loss: 12.4109
Epoch 27/100
10/10 [==============================] - 0s 2ms/step - loss: 12.4272
Epoch 28/100
10/10 [==============================] - 0s 2ms/step - loss: 11.2461
Epoch 29/100
10/10 [==============================] - 0s 2ms/step - loss: 10.7755
Epoch 30/100
10/10 [==============================] - 0s 2ms/step - loss: 9.7845
Epoch 31/100
10/10 [==============================] - 0s 2ms/step - loss: 9.1328
Epoch 32/100
10/10 [==============================] - 0s 2ms/step - loss: 8.9986
Epoch 33/100
10/10 [==============================] - 0s 2ms/step - loss: 7.6808
Epoch 34/100
10/10 [==============================] - 0s 2ms/step - loss: 7.3601
Epoch 35/100
10/10 [==============================] - 0s 2ms/step - loss: 6.5697
Epoch 36/100
10/10 [==============================] - 0s 2ms/step - loss: 6.0516
Epoch 37/100
10/10 [==============================] - 0s 2ms/step - loss: 5.4636
Epoch 38/100
10/10 [==============================] - 0s 2ms/step - loss: 5.0098
Epoch 39/100
10/10 [==============================] - 0s 2ms/step - loss: 4.7890
Epoch 40/100
10/10 [==============================] - 0s 2ms/step - loss: 4.3679
Epoch 41/100
10/10 [==============================] - 0s 2ms/step - loss: 4.5915
Epoch 42/100
10/10 [==============================] - 0s 2ms/step - loss: 4.2618
Epoch 43/100
10/10 [==============================] - 0s 2ms/step - loss: 3.8259
Epoch 44/100
10/10 [==============================] - 0s 2ms/step - loss: 3.7209
Epoch 45/100
10/10 [==============================] - 0s 2ms/step - loss: 3.4403
Epoch 46/100
10/10 [==============================] - 0s 2ms/step - loss: 3.4103
Epoch 47/100
10/10 [==============================] - 0s 2ms/step - loss: 3.8025
Epoch 48/100
10/10 [==============================] - 0s 2ms/step - loss: 3.6355
Epoch 49/100
10/10 [==============================] - 0s 2ms/step - loss: 3.1812
Epoch 50/100
10/10 [==============================] - 0s 2ms/step - loss: 2.6628
Epoch 51/100
10/10 [==============================] - 0s 2ms/step - loss: 2.4035
Epoch 52/100
10/10 [==============================] - 0s 2ms/step - loss: 2.3838
Epoch 53/100
10/10 [==============================] - 0s 2ms/step - loss: 3.0629
Epoch 54/100
10/10 [==============================] - 0s 2ms/step - loss: 2.3382
Epoch 55/100
10/10 [==============================] - 0s 2ms/step - loss: 2.1033
Epoch 56/100
10/10 [==============================] - 0s 2ms/step - loss: 1.6672
Epoch 57/100
10/10 [==============================] - 0s 2ms/step - loss: 1.5632
Epoch 58/100
10/10 [==============================] - 0s 2ms/step - loss: 1.5092
Epoch 59/100
10/10 [==============================] - 0s 2ms/step - loss: 1.3471
Epoch 60/100
10/10 [==============================] - 0s 2ms/step - loss: 1.2847
Epoch 61/100
10/10 [==============================] - 0s 2ms/step - loss: 1.5229
Epoch 62/100
10/10 [==============================] - 0s 2ms/step - loss: 1.5827
Epoch 63/100
10/10 [==============================] - 0s 2ms/step - loss: 1.2878
Epoch 64/100
10/10 [==============================] - 0s 2ms/step - loss: 1.0977
Epoch 65/100
10/10 [==============================] - 0s 2ms/step - loss: 1.1016
Epoch 66/100
10/10 [==============================] - 0s 2ms/step - loss: 1.2771
Epoch 67/100
10/10 [==============================] - 0s 2ms/step - loss: 1.5543
Epoch 68/100
10/10 [==============================] - 0s 2ms/step - loss: 1.1844
Epoch 69/100
10/10 [==============================] - 0s 2ms/step - loss: 1.1441
Epoch 70/100
10/10 [==============================] - 0s 2ms/step - loss: 1.1843
Epoch 71/100
10/10 [==============================] - 0s 2ms/step - loss: 1.1437
Epoch 72/100
10/10 [==============================] - 0s 2ms/step - loss: 1.0442
Epoch 73/100
10/10 [==============================] - 0s 2ms/step - loss: 1.1147
Epoch 74/100
10/10 [==============================] - 0s 2ms/step - loss: 1.0953
Epoch 75/100
10/10 [==============================] - 0s 2ms/step - loss: 1.1173
Epoch 76/100
10/10 [==============================] - 0s 2ms/step - loss: 1.0520
Epoch 77/100
10/10 [==============================] - 0s 2ms/step - loss: 1.0242
Epoch 78/100
10/10 [==============================] - 0s 2ms/step - loss: 1.0637
Epoch 79/100
10/10 [==============================] - 0s 2ms/step - loss: 1.1284
Epoch 80/100
10/10 [==============================] - 0s 2ms/step - loss: 1.1324
Epoch 81/100
10/10 [==============================] - 0s 2ms/step - loss: 1.1172
Epoch 82/100
10/10 [==============================] - 0s 2ms/step - loss: 1.2402
Epoch 83/100
10/10 [==============================] - 0s 2ms/step - loss: 1.1240
Epoch 84/100
10/10 [==============================] - 0s 1ms/step - loss: 1.0794
Epoch 85/100
10/10 [==============================] - 0s 1ms/step - loss: 1.0693
Epoch 86/100
10/10 [==============================] - 0s 1ms/step - loss: 1.0898
Epoch 87/100
10/10 [==============================] - 0s 1ms/step - loss: 1.0458
Epoch 88/100
10/10 [==============================] - 0s 1ms/step - loss: 1.2224
Epoch 89/100
10/10 [==============================] - 0s 1ms/step - loss: 1.0337
Epoch 90/100
10/10 [==============================] - 0s 1ms/step - loss: 1.1267
Epoch 91/100
10/10 [==============================] - 0s 1ms/step - loss: 2.1266
Epoch 92/100
10/10 [==============================] - 0s 1ms/step - loss: 1.9268
Epoch 93/100
10/10 [==============================] - 0s 1ms/step - loss: 1.1867
Epoch 94/100
10/10 [==============================] - 0s 1ms/step - loss: 0.9464
Epoch 95/100
10/10 [==============================] - 0s 990us/step - loss: 1.4309
Epoch 96/100
10/10 [==============================] - 0s 1ms/step - loss: 1.1865
Epoch 97/100
10/10 [==============================] - 0s 996us/step - loss: 1.6076
Epoch 98/100
10/10 [==============================] - 0s 1ms/step - loss: 1.6465
Epoch 99/100
10/10 [==============================] - 0s 1ms/step - loss: 1.3923
Epoch 100/100
10/10 [==============================] - 0s 1ms/step - loss: 1.0014
# Plot the training loss against the epoch number
sns.lineplot(x=history.epoch, y=history.history['loss'])
<Axes: >
_images/927a575f84ddd64bb11d0f62f0d1a08f2bcbc1cd768f6f648e5ca6d11c472fc2.png
# Run the model on the test features; each output row holds one softmax value per species
y_pred = model.predict(X_test)
# Wrap the raw output in a DataFrame whose columns are labelled with the species names
prediction = pd.DataFrame(y_pred, columns=target.columns)
prediction
2/2 [==============================] - 0s 7ms/step
Adelie Chinstrap Gentoo
0 7.648391e-01 1.715673e-01 6.359363e-02
1 5.454711e-01 4.545289e-01 2.100461e-08
2 1.833452e-09 6.493035e-10 1.000000e+00
3 7.966254e-15 1.784904e-15 1.000000e+00
4 8.872245e-01 1.109848e-01 1.790675e-03
5 9.873828e-01 1.261715e-02 4.766879e-10
6 1.018278e-14 5.001138e-14 1.000000e+00
7 9.829051e-01 1.709081e-02 3.938773e-06
8 1.619196e-08 9.007781e-09 1.000000e+00
9 7.116187e-01 2.883812e-01 1.857964e-07
10 1.152847e-07 1.505111e-07 9.999998e-01
11 9.103811e-01 8.961879e-02 3.948410e-08
12 5.516306e-11 1.907833e-10 1.000000e+00
13 6.351340e-01 3.648659e-01 1.353383e-08
14 8.191870e-01 1.808130e-01 1.281623e-09
15 1.306449e-12 1.319104e-11 1.000000e+00
16 9.992987e-01 7.012586e-04 1.666273e-09
17 5.740452e-01 4.259438e-01 1.094981e-05
18 2.500658e-02 6.073019e-03 9.689205e-01
19 2.341845e-06 3.067315e-05 9.999670e-01
20 1.879179e-11 7.223661e-11 1.000000e+00
21 9.881677e-01 1.179230e-02 4.009461e-05
22 9.027959e-01 9.720411e-02 2.563040e-08
23 9.342888e-01 1.061412e-02 5.509710e-02
24 7.657505e-01 2.333840e-01 8.654961e-04
25 3.820047e-07 4.642012e-07 9.999992e-01
26 4.629333e-06 1.199403e-05 9.999834e-01
27 9.900614e-01 9.938368e-03 2.153488e-07
28 9.464169e-01 5.357845e-02 4.640924e-06
29 8.868090e-01 1.114146e-01 1.776450e-03
30 6.899412e-12 5.361637e-12 1.000000e+00
31 9.746879e-01 2.498482e-02 3.272409e-04
32 8.940583e-01 1.059394e-01 2.274549e-06
33 7.915920e-01 2.084079e-01 1.082515e-07
34 9.376542e-07 6.750374e-07 9.999983e-01
# For each row, pick the species whose column holds the highest predicted value
predicted_species = prediction.idxmax(axis="columns")
predicted_species
0     Adelie
1     Adelie
2     Gentoo
3     Gentoo
4     Adelie
5     Adelie
6     Gentoo
7     Adelie
8     Gentoo
9     Adelie
10    Gentoo
11    Adelie
12    Gentoo
13    Adelie
14    Adelie
15    Gentoo
16    Adelie
17    Adelie
18    Gentoo
19    Gentoo
20    Gentoo
21    Adelie
22    Adelie
23    Adelie
24    Adelie
25    Gentoo
26    Gentoo
27    Adelie
28    Adelie
29    Adelie
30    Gentoo
31    Adelie
32    Adelie
33    Adelie
34    Gentoo
dtype: category
Categories (3, object): ['Adelie', 'Chinstrap', 'Gentoo']
from sklearn.metrics import confusion_matrix

# Recover the true species name from each one-hot encoded row of y_test
true_species = y_test.idxmax(axis="columns")

# Count predictions per class: rows are the true species, columns the predicted species
matrix = confusion_matrix(true_species, predicted_species)
print(matrix)
[[14  0  1]
 [ 7  0  0]
 [ 0  0 13]]
# Convert to a pandas dataframe, labelling both rows and columns with the species names
confusion_df = pd.DataFrame(matrix, index=y_test.columns.values, columns=y_test.columns.values)

# Set the names of the x and y axes; this helps with the readability of the heatmap.
confusion_df.index.name = 'True Label'
confusion_df.columns.name = 'Predicted Label'
# Draw the confusion matrix as a heatmap with the count written in each cell
sns.heatmap(confusion_df, annot=True)
<Axes: xlabel='Predicted Label', ylabel='True Label'>
_images/cc078bcfe89ee551aea2658c1592e12aa139108f8538e2f759ea59f7f243c096.png
# Save the trained model under the path 'my_first_model'
model.save('my_first_model')
INFO:tensorflow:Assets written to: my_first_model/assets
INFO:tensorflow:Assets written to: my_first_model/assets
# Load the model back from the path 'my_first_model'
pretrained_model = keras.models.load_model('my_first_model')
# Use the reloaded model to predict on the test features
y_pretrained_pred = pretrained_model.predict(X_test)
# Label the prediction columns with the species names
pretrained_prediction = pd.DataFrame(y_pretrained_pred, columns=target.columns.values)

# idxmax will select the column for each row with the highest value
pretrained_predicted_species = pretrained_prediction.idxmax(axis="columns")
print(pretrained_predicted_species)
2/2 [==============================] - 0s 2ms/step
0     Adelie
1     Adelie
2     Gentoo
3     Gentoo
4     Adelie
5     Adelie
6     Gentoo
7     Adelie
8     Gentoo
9     Adelie
10    Gentoo
11    Adelie
12    Gentoo
13    Adelie
14    Adelie
15    Gentoo
16    Adelie
17    Adelie
18    Gentoo
19    Gentoo
20    Gentoo
21    Adelie
22    Adelie
23    Adelie
24    Adelie
25    Gentoo
26    Gentoo
27    Adelie
28    Adelie
29    Adelie
30    Gentoo
31    Adelie
32    Adelie
33    Adelie
34    Gentoo
dtype: category
Categories (3, object): ['Adelie', 'Chinstrap', 'Gentoo']