{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from keras.datasets import cifar10\n", "from keras.utils import set_random_seed, to_categorical\n", "\n", "# Seed Python, NumPy and the Keras backend so weight init, shuffling, dropout,\n", "# augmentation and the random test-image pick repeat across runs\n", "# (some GPU kernels can still be nondeterministic).\n", "SEED = 42\n", "set_random_seed(SEED)\n", "\n", "(x_train, y_train), (x_test, y_test) = cifar10.load_data()\n", "\n", "# Normalize pixel values\n", "x_train = x_train.astype('float32') / 255.0\n", "x_test = x_test.astype('float32') / 255.0\n", "\n", "# One-hot encode the labels\n", "y_train = to_categorical(y_train, num_classes=10)\n", "y_test = to_categorical(y_test, num_classes=10)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from keras.models import Sequential\n", "from keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout\n", "\n", "model = Sequential()\n", "model.add(Conv2D(32, (3, 3), activation='relu', padding='same', input_shape=(32, 32, 3)))\n", "model.add(Conv2D(32, (3, 3), activation='relu', padding='same'))\n", "model.add(MaxPooling2D((2, 2)))\n", "model.add(Dropout(0.25))\n", "\n", "model.add(Conv2D(64, (3, 3), activation='relu', padding='same'))\n", "model.add(Conv2D(64, (3, 3), activation='relu', padding='same'))\n", "model.add(MaxPooling2D((2, 2)))\n", "model.add(Dropout(0.25))\n", "\n", "model.add(Flatten())\n", "model.add(Dense(512, activation='relu'))\n", "model.add(Dropout(0.5))\n", "model.add(Dense(10, activation='softmax'))" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "model.compile(optimizer='adam',\n", " loss='categorical_crossentropy',\n", " metrics=['accuracy'])" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/50\n", "352/352 [==============================] - 69s 195ms/step - loss: 1.6485 - accuracy: 0.3958 - val_loss: 1.2859 - val_accuracy: 0.5282\n", "Epoch 2/50\n", "352/352 [==============================] - 68s 193ms/step - loss: 1.2111 - accuracy: 0.5655 - val_loss: 0.9605 - val_accuracy: 
0.6640\n", "Epoch 3/50\n", "352/352 [==============================] - 73s 208ms/step - loss: 1.0086 - accuracy: 0.6435 - val_loss: 0.8589 - val_accuracy: 0.6952\n", "Epoch 4/50\n", "352/352 [==============================] - 101s 288ms/step - loss: 0.8924 - accuracy: 0.6857 - val_loss: 0.7665 - val_accuracy: 0.7336\n", "Epoch 5/50\n", "352/352 [==============================] - 69s 197ms/step - loss: 0.8080 - accuracy: 0.7136 - val_loss: 0.7567 - val_accuracy: 0.7392\n", "Epoch 6/50\n", "352/352 [==============================] - 69s 197ms/step - loss: 0.7429 - accuracy: 0.7398 - val_loss: 0.6749 - val_accuracy: 0.7666\n", "Epoch 7/50\n", "352/352 [==============================] - 71s 202ms/step - loss: 0.6849 - accuracy: 0.7584 - val_loss: 0.7045 - val_accuracy: 0.7512\n", "Epoch 8/50\n", "352/352 [==============================] - 75s 214ms/step - loss: 0.6489 - accuracy: 0.7703 - val_loss: 0.6511 - val_accuracy: 0.7778\n", "Epoch 9/50\n", "352/352 [==============================] - 76s 216ms/step - loss: 0.6011 - accuracy: 0.7880 - val_loss: 0.6165 - val_accuracy: 0.7866\n", "Epoch 10/50\n", "352/352 [==============================] - 76s 215ms/step - loss: 0.5679 - accuracy: 0.8005 - val_loss: 0.6104 - val_accuracy: 0.7870\n", "Epoch 11/50\n", "352/352 [==============================] - 76s 216ms/step - loss: 0.5338 - accuracy: 0.8103 - val_loss: 0.6394 - val_accuracy: 0.7878\n", "Epoch 12/50\n", "352/352 [==============================] - 76s 216ms/step - loss: 0.5041 - accuracy: 0.8208 - val_loss: 0.5970 - val_accuracy: 0.7942\n", "Epoch 13/50\n", "352/352 [==============================] - 76s 216ms/step - loss: 0.4798 - accuracy: 0.8284 - val_loss: 0.5972 - val_accuracy: 0.8032\n", "Epoch 14/50\n", "352/352 [==============================] - 76s 216ms/step - loss: 0.4508 - accuracy: 0.8380 - val_loss: 0.5943 - val_accuracy: 0.8000\n", "Epoch 15/50\n", "352/352 [==============================] - 76s 216ms/step - loss: 0.4290 - accuracy: 0.8470 - val_loss: 
0.6002 - val_accuracy: 0.8002\n", "Epoch 16/50\n", "352/352 [==============================] - 76s 216ms/step - loss: 0.4112 - accuracy: 0.8536 - val_loss: 0.6042 - val_accuracy: 0.8038\n", "Epoch 17/50\n", "352/352 [==============================] - 76s 216ms/step - loss: 0.3918 - accuracy: 0.8590 - val_loss: 0.6220 - val_accuracy: 0.7956\n", "Epoch 18/50\n", "352/352 [==============================] - 76s 216ms/step - loss: 0.3798 - accuracy: 0.8650 - val_loss: 0.6229 - val_accuracy: 0.8052\n", "Epoch 19/50\n", "352/352 [==============================] - 76s 216ms/step - loss: 0.3636 - accuracy: 0.8699 - val_loss: 0.6045 - val_accuracy: 0.8090\n", "Epoch 20/50\n", "352/352 [==============================] - 76s 216ms/step - loss: 0.3456 - accuracy: 0.8763 - val_loss: 0.6035 - val_accuracy: 0.8086\n", "Epoch 21/50\n", "352/352 [==============================] - 76s 216ms/step - loss: 0.3341 - accuracy: 0.8798 - val_loss: 0.6066 - val_accuracy: 0.8118\n", "Epoch 22/50\n", "352/352 [==============================] - 76s 215ms/step - loss: 0.3196 - accuracy: 0.8854 - val_loss: 0.6200 - val_accuracy: 0.8054\n", "Epoch 23/50\n", "352/352 [==============================] - 76s 215ms/step - loss: 0.3076 - accuracy: 0.8889 - val_loss: 0.6121 - val_accuracy: 0.8036\n", "Epoch 24/50\n", "352/352 [==============================] - 76s 216ms/step - loss: 0.3041 - accuracy: 0.8927 - val_loss: 0.6166 - val_accuracy: 0.8058\n", "Epoch 25/50\n", "352/352 [==============================] - 76s 215ms/step - loss: 0.2892 - accuracy: 0.8978 - val_loss: 0.6202 - val_accuracy: 0.8082\n", "Epoch 26/50\n", "352/352 [==============================] - 76s 215ms/step - loss: 0.2808 - accuracy: 0.8987 - val_loss: 0.6370 - val_accuracy: 0.8126\n", "Epoch 27/50\n", "352/352 [==============================] - 76s 215ms/step - loss: 0.2791 - accuracy: 0.9014 - val_loss: 0.6416 - val_accuracy: 0.8132\n", "Epoch 28/50\n", "352/352 [==============================] - 76s 216ms/step - loss: 0.2661 - 
accuracy: 0.9062 - val_loss: 0.6206 - val_accuracy: 0.8070\n", "Epoch 29/50\n", "352/352 [==============================] - 76s 216ms/step - loss: 0.2627 - accuracy: 0.9073 - val_loss: 0.6165 - val_accuracy: 0.8148\n", "Epoch 30/50\n", "352/352 [==============================] - 76s 215ms/step - loss: 0.2551 - accuracy: 0.9108 - val_loss: 0.6796 - val_accuracy: 0.8084\n", "Epoch 31/50\n", "352/352 [==============================] - 76s 216ms/step - loss: 0.2521 - accuracy: 0.9110 - val_loss: 0.6634 - val_accuracy: 0.8148\n", "Epoch 32/50\n", "352/352 [==============================] - 76s 216ms/step - loss: 0.2434 - accuracy: 0.9137 - val_loss: 0.6582 - val_accuracy: 0.8082\n", "Epoch 33/50\n", "352/352 [==============================] - 76s 215ms/step - loss: 0.2403 - accuracy: 0.9138 - val_loss: 0.6734 - val_accuracy: 0.8084\n", "Epoch 34/50\n", "352/352 [==============================] - 76s 215ms/step - loss: 0.2346 - accuracy: 0.9172 - val_loss: 0.6654 - val_accuracy: 0.8066\n", "Epoch 35/50\n", "352/352 [==============================] - 76s 215ms/step - loss: 0.2287 - accuracy: 0.9198 - val_loss: 0.6993 - val_accuracy: 0.8088\n", "Epoch 36/50\n", "352/352 [==============================] - 76s 216ms/step - loss: 0.2340 - accuracy: 0.9176 - val_loss: 0.6680 - val_accuracy: 0.8082\n", "Epoch 37/50\n", "352/352 [==============================] - 76s 215ms/step - loss: 0.2196 - accuracy: 0.9228 - val_loss: 0.6933 - val_accuracy: 0.8144\n", "Epoch 38/50\n", "352/352 [==============================] - 76s 216ms/step - loss: 0.2199 - accuracy: 0.9228 - val_loss: 0.6392 - val_accuracy: 0.8126\n", "Epoch 39/50\n", "352/352 [==============================] - 76s 215ms/step - loss: 0.2190 - accuracy: 0.9236 - val_loss: 0.6644 - val_accuracy: 0.8172\n", "Epoch 40/50\n", "352/352 [==============================] - 77s 220ms/step - loss: 0.2114 - accuracy: 0.9254 - val_loss: 0.6753 - val_accuracy: 0.8140\n", "Epoch 41/50\n", "352/352 [==============================] - 74s 
210ms/step - loss: 0.2058 - accuracy: 0.9286 - val_loss: 0.6879 - val_accuracy: 0.8130\n", "Epoch 42/50\n", "352/352 [==============================] - 73s 208ms/step - loss: 0.2038 - accuracy: 0.9280 - val_loss: 0.6891 - val_accuracy: 0.8178\n", "Epoch 43/50\n", "352/352 [==============================] - 75s 214ms/step - loss: 0.2085 - accuracy: 0.9265 - val_loss: 0.6986 - val_accuracy: 0.8160\n", "Epoch 44/50\n", "352/352 [==============================] - 77s 219ms/step - loss: 0.2026 - accuracy: 0.9286 - val_loss: 0.6906 - val_accuracy: 0.8046\n", "Epoch 45/50\n", "352/352 [==============================] - 104s 296ms/step - loss: 0.1975 - accuracy: 0.9297 - val_loss: 0.7136 - val_accuracy: 0.8130\n", "Epoch 46/50\n", "352/352 [==============================] - 72s 205ms/step - loss: 0.2015 - accuracy: 0.9296 - val_loss: 0.6834 - val_accuracy: 0.8092\n", "Epoch 47/50\n", "352/352 [==============================] - 72s 205ms/step - loss: 0.1904 - accuracy: 0.9335 - val_loss: 0.6789 - val_accuracy: 0.8156\n", "Epoch 48/50\n", "352/352 [==============================] - 75s 212ms/step - loss: 0.1879 - accuracy: 0.9347 - val_loss: 0.7115 - val_accuracy: 0.8162\n", "Epoch 49/50\n", "352/352 [==============================] - 77s 218ms/step - loss: 0.1856 - accuracy: 0.9346 - val_loss: 0.7351 - val_accuracy: 0.8122\n", "Epoch 50/50\n", "352/352 [==============================] - 74s 211ms/step - loss: 0.1834 - accuracy: 0.9361 - val_loss: 0.7103 - val_accuracy: 0.8200\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "batch_size = 128\n", "epochs = 50\n", "\n", "# Baseline run: train the freshly compiled model, holding out 10% of the training data for validation\n", "fit_options = dict(batch_size=batch_size, epochs=epochs, validation_split=0.1)\n", "model.fit(x_train, y_train, **fit_options)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Test loss: 0.7982190847396851\n", "Test accuracy: 0.7968000173568726\n" ] } ], "source": 
[ "# Baseline score on the test set (returns [loss, accuracy] because of metrics=['accuracy'])\n", "score = model.evaluate(x_test, y_test, verbose=0)\n", "test_loss, test_accuracy = score\n", "print(f'Test loss: {test_loss}')\n", "print(f'Test accuracy: {test_accuracy}')" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/50\n", "390/390 [==============================] - 97s 247ms/step - loss: 0.9107 - accuracy: 0.6928 - val_loss: 0.6933 - val_accuracy: 0.7650\n", "Epoch 2/50\n", "390/390 [==============================] - 100s 257ms/step - loss: 0.8106 - accuracy: 0.7224 - val_loss: 0.6537 - val_accuracy: 0.7749\n", "Epoch 3/50\n", "390/390 [==============================] - 96s 246ms/step - loss: 0.7845 - accuracy: 0.7299 - val_loss: 0.6504 - val_accuracy: 0.7822\n", "Epoch 4/50\n", "390/390 [==============================] - 93s 237ms/step - loss: 0.7672 - accuracy: 0.7350 - val_loss: 0.6664 - val_accuracy: 0.7761\n", "Epoch 5/50\n", "390/390 [==============================] - 96s 246ms/step - loss: 0.7436 - accuracy: 0.7443 - val_loss: 0.6713 - val_accuracy: 0.7691\n", "Epoch 6/50\n", "390/390 [==============================] - 95s 242ms/step - loss: 0.7276 - accuracy: 0.7493 - val_loss: 0.6087 - val_accuracy: 0.7958\n", "Epoch 7/50\n", "390/390 [==============================] - 93s 238ms/step - loss: 0.7221 - accuracy: 0.7492 - val_loss: 0.6521 - val_accuracy: 0.7838\n", "Epoch 8/50\n", "390/390 [==============================] - 99s 253ms/step - loss: 0.7092 - accuracy: 0.7545 - val_loss: 0.6414 - val_accuracy: 0.7822\n", "Epoch 9/50\n", "390/390 [==============================] - 97s 249ms/step - loss: 0.6963 - accuracy: 0.7605 - val_loss: 0.5902 - val_accuracy: 0.7986\n", "Epoch 10/50\n", "390/390 [==============================] - 90s 231ms/step - loss: 0.6905 - accuracy: 0.7616 - val_loss: 0.6284 - val_accuracy: 0.7890\n", "Epoch 11/50\n", "390/390 [==============================] - 85s 217ms/step - loss: 0.6846 - accuracy: 0.7606 - val_loss: 0.5959 - val_accuracy: 0.8034\n", "Epoch 12/50\n", "390/390 
[==============================] - 86s 221ms/step - loss: 0.6759 - accuracy: 0.7647 - val_loss: 0.6337 - val_accuracy: 0.7914\n", "Epoch 13/50\n", "390/390 [==============================] - 89s 227ms/step - loss: 0.6675 - accuracy: 0.7686 - val_loss: 0.6242 - val_accuracy: 0.7947\n", "Epoch 14/50\n", "390/390 [==============================] - 91s 234ms/step - loss: 0.6705 - accuracy: 0.7676 - val_loss: 0.6022 - val_accuracy: 0.7946\n", "Epoch 15/50\n", "390/390 [==============================] - 88s 226ms/step - loss: 0.6559 - accuracy: 0.7724 - val_loss: 0.6387 - val_accuracy: 0.7818\n", "Epoch 16/50\n", "390/390 [==============================] - 88s 227ms/step - loss: 0.6446 - accuracy: 0.7756 - val_loss: 0.6323 - val_accuracy: 0.7935\n", "Epoch 17/50\n", "390/390 [==============================] - 88s 227ms/step - loss: 0.6514 - accuracy: 0.7748 - val_loss: 0.5713 - val_accuracy: 0.8084\n", "Epoch 18/50\n", "390/390 [==============================] - 89s 227ms/step - loss: 0.6500 - accuracy: 0.7753 - val_loss: 0.6053 - val_accuracy: 0.7970\n", "Epoch 19/50\n", "390/390 [==============================] - 88s 226ms/step - loss: 0.6371 - accuracy: 0.7797 - val_loss: 0.5728 - val_accuracy: 0.8037\n", "Epoch 20/50\n", "390/390 [==============================] - 90s 230ms/step - loss: 0.6281 - accuracy: 0.7828 - val_loss: 0.5671 - val_accuracy: 0.8131\n", "Epoch 21/50\n", "390/390 [==============================] - 87s 222ms/step - loss: 0.6257 - accuracy: 0.7799 - val_loss: 0.5519 - val_accuracy: 0.8133\n", "Epoch 22/50\n", "390/390 [==============================] - 87s 223ms/step - loss: 0.6203 - accuracy: 0.7845 - val_loss: 0.5605 - val_accuracy: 0.8143\n", "Epoch 23/50\n", "390/390 [==============================] - 89s 229ms/step - loss: 0.6205 - accuracy: 0.7851 - val_loss: 0.5704 - val_accuracy: 0.8075\n", "Epoch 24/50\n", "390/390 [==============================] - 89s 228ms/step - loss: 0.6113 - accuracy: 0.7870 - val_loss: 0.5804 - val_accuracy: 
0.8053\n", "Epoch 25/50\n", "390/390 [==============================] - 89s 228ms/step - loss: 0.6066 - accuracy: 0.7885 - val_loss: 0.5350 - val_accuracy: 0.8210\n", "Epoch 26/50\n", "390/390 [==============================] - 89s 228ms/step - loss: 0.6116 - accuracy: 0.7879 - val_loss: 0.5985 - val_accuracy: 0.8009\n", "Epoch 27/50\n", "390/390 [==============================] - 88s 227ms/step - loss: 0.6057 - accuracy: 0.7903 - val_loss: 0.5751 - val_accuracy: 0.8096\n", "Epoch 28/50\n", "390/390 [==============================] - 89s 228ms/step - loss: 0.6032 - accuracy: 0.7922 - val_loss: 0.5915 - val_accuracy: 0.8031\n", "Epoch 29/50\n", "390/390 [==============================] - 89s 227ms/step - loss: 0.6055 - accuracy: 0.7878 - val_loss: 0.6113 - val_accuracy: 0.8034\n", "Epoch 30/50\n", "390/390 [==============================] - 89s 227ms/step - loss: 0.5959 - accuracy: 0.7921 - val_loss: 0.5415 - val_accuracy: 0.8230\n", "Epoch 31/50\n", "390/390 [==============================] - 89s 228ms/step - loss: 0.5904 - accuracy: 0.7971 - val_loss: 0.5385 - val_accuracy: 0.8208\n", "Epoch 32/50\n", "390/390 [==============================] - 89s 227ms/step - loss: 0.5876 - accuracy: 0.7952 - val_loss: 0.5533 - val_accuracy: 0.8160\n", "Epoch 33/50\n", "390/390 [==============================] - 89s 228ms/step - loss: 0.5902 - accuracy: 0.7945 - val_loss: 0.5819 - val_accuracy: 0.8108\n", "Epoch 34/50\n", "390/390 [==============================] - 89s 228ms/step - loss: 0.5878 - accuracy: 0.7952 - val_loss: 0.5956 - val_accuracy: 0.8055\n", "Epoch 35/50\n", "390/390 [==============================] - 89s 228ms/step - loss: 0.5837 - accuracy: 0.7982 - val_loss: 0.5445 - val_accuracy: 0.8203\n", "Epoch 36/50\n", "390/390 [==============================] - 89s 227ms/step - loss: 0.5773 - accuracy: 0.7998 - val_loss: 0.5764 - val_accuracy: 0.8081\n", "Epoch 37/50\n", "390/390 [==============================] - 89s 227ms/step - loss: 0.5763 - accuracy: 0.7982 - 
val_loss: 0.5113 - val_accuracy: 0.8290\n", "Epoch 38/50\n", "390/390 [==============================] - 89s 227ms/step - loss: 0.5695 - accuracy: 0.8028 - val_loss: 0.5255 - val_accuracy: 0.8279\n", "Epoch 39/50\n", "390/390 [==============================] - 89s 228ms/step - loss: 0.5743 - accuracy: 0.8021 - val_loss: 0.5300 - val_accuracy: 0.8269\n", "Epoch 40/50\n", "390/390 [==============================] - 89s 229ms/step - loss: 0.5691 - accuracy: 0.8035 - val_loss: 0.5385 - val_accuracy: 0.8244\n", "Epoch 41/50\n", "390/390 [==============================] - 89s 228ms/step - loss: 0.5675 - accuracy: 0.8013 - val_loss: 0.5401 - val_accuracy: 0.8199\n", "Epoch 42/50\n", "390/390 [==============================] - 89s 228ms/step - loss: 0.5690 - accuracy: 0.8019 - val_loss: 0.5816 - val_accuracy: 0.8117\n", "Epoch 43/50\n", "390/390 [==============================] - 89s 227ms/step - loss: 0.5611 - accuracy: 0.8069 - val_loss: 0.5632 - val_accuracy: 0.8143\n", "Epoch 44/50\n", "390/390 [==============================] - 89s 228ms/step - loss: 0.5647 - accuracy: 0.8025 - val_loss: 0.5038 - val_accuracy: 0.8345\n", "Epoch 45/50\n", "390/390 [==============================] - 88s 227ms/step - loss: 0.5617 - accuracy: 0.8048 - val_loss: 0.5606 - val_accuracy: 0.8207\n", "Epoch 46/50\n", "390/390 [==============================] - 88s 226ms/step - loss: 0.5575 - accuracy: 0.8044 - val_loss: 0.6000 - val_accuracy: 0.8049\n", "Epoch 47/50\n", "390/390 [==============================] - 88s 227ms/step - loss: 0.5603 - accuracy: 0.8076 - val_loss: 0.5683 - val_accuracy: 0.8095\n", "Epoch 48/50\n", "390/390 [==============================] - 82s 211ms/step - loss: 0.5580 - accuracy: 0.8070 - val_loss: 0.5356 - val_accuracy: 0.8258\n", "Epoch 49/50\n", "390/390 [==============================] - 84s 215ms/step - loss: 0.5531 - accuracy: 0.8088 - val_loss: 0.5358 - val_accuracy: 0.8207\n", "Epoch 50/50\n", "390/390 [==============================] - 88s 226ms/step - loss: 
0.5531 - accuracy: 0.8101 - val_loss: 0.5430 - val_accuracy: 0.8236\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Continue training the already-fitted model with on-the-fly data augmentation\n", "from keras.preprocessing.image import ImageDataGenerator\n", "\n", "# Validate on the last 10% of the training data (the same slice validation_split=0.1\n", "# reserved in the baseline run) so the test set is used only for the final evaluation,\n", "# never to monitor training.\n", "val_size = int(len(x_train) * 0.1)\n", "split = len(x_train) - val_size\n", "x_fit, x_val = x_train[:split], x_train[split:]\n", "y_fit, y_val = y_train[:split], y_train[split:]\n", "\n", "datagen = ImageDataGenerator(\n", "    rotation_range=15,\n", "    width_shift_range=0.1,\n", "    height_shift_range=0.1,\n", "    horizontal_flip=True\n", ")\n", "\n", "# No featurewise statistics are enabled, so datagen.fit() is unnecessary.\n", "# steps_per_epoch is left to Keras so the final partial batch is not dropped each epoch.\n", "model.fit(datagen.flow(x_fit, y_fit, batch_size=batch_size),\n", "          epochs=epochs,\n", "          validation_data=(x_val, y_val))" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Test loss: 0.5430349111557007\n", "Test accuracy: 0.8235999941825867\n" ] } ], "source": [ "# Final evaluation on the test set, which played no part in training or validation\n", "score = model.evaluate(x_test, y_test, verbose=0)\n", "print(f'Test loss: {score[0]}')\n", "print(f'Test accuracy: {score[1]}')" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "# Save the trained model\n", "model.save('cifar10_cnn.keras')" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1/1 [==============================] - 0s 99ms/step\n", "The model predicted the image as: ship\n" ] }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Pick a random test image, predict its class and compare with the true label\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "\n", "# Class names; list position must match the integer CIFAR-10 label\n", "labels = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']\n", "\n", "index = np.random.randint(0, x_test.shape[0])\n", "img = x_test[index]\n", "true_label = labels[np.argmax(y_test[index])]  # y_test was one-hot encoded above\n", "\n", "# The model expects a batch, so add a leading batch dimension of size 1\n", "probs = model.predict(np.expand_dims(img, axis=0), verbose=0)[0]\n", "predicted_label = labels[np.argmax(probs)]\n", "\n", "fig, ax = plt.subplots()\n", "ax.imshow(img)\n", "ax.set_title(f'Predicted: {predicted_label} ({probs.max():.1%}) | True: {true_label}')\n", "ax.axis('off')\n", "plt.show()\n", "\n", "print(f'The model predicted the image as: {predicted_label}')" ] } ], "metadata": { "kernelspec": { "display_name": "ml", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.13" } }, "nbformat": 4, "nbformat_minor": 2 }