{ "cells": [ { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "import gradio as gr\n", "import tensorflow as tf\n", "import numpy as np\n", "from PIL import Image" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "model_path = \"transferlearning_pokemon.keras\"\n", "model = tf.keras.models.load_model(model_path)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "# Define the core prediction function\n", "def predict_pokemon(image):\n", " # Preprocess image\n", " print(type(image))\n", " image = Image.fromarray(image.astype('uint8')) # Convert numpy array to PIL image\n", " image = image.resize((150, 150)) # Resize the image to 150x150\n", " image = np.array(image)\n", " image = np.expand_dims(image, axis=0) # Expand dimensions to match the model input shape\n", " \n", " # Predict\n", " prediction = model.predict(image)\n", " \n", " # Print the shape of the prediction to debug\n", " print(f\"Prediction shape: {prediction.shape}\")\n", " \n", " # Assuming the output is already softmax probabilities\n", " probabilities = prediction[0]\n", " \n", " # Print the probabilities array to debug\n", " print(f\"Probabilities: {probabilities}\")\n", " \n", " # Assuming your model was trained with these class names\n", " class_names = ['charmander', 'eevee', 'pikachuu'] # Replace 'another_pokemon' with your third class name\n", " \n", " # Create a dictionary of class probabilities\n", " result = {class_names[i]: float(probabilities[i]) for i in range(len(class_names))}\n", " \n", " return result" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Running on local URL: http://127.0.0.1:7866\n", "\n", "To create a public link, set `share=True` in `launch()`.\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 140ms/step\n", "Prediction shape: (1, 3)\n", "Probabilities: [9.1263162e-31 1.1169604e-30 1.0000000e+00]\n", "\n", "\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 90ms/step\n", "Prediction shape: (1, 3)\n", "Probabilities: [4.4493477e-06 8.4401548e-01 1.5598010e-01]\n", "\n", "\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 70ms/step\n", "Prediction shape: (1, 3)\n", "Probabilities: [9.9999964e-01 1.0916104e-07 1.8336594e-07]\n", "\n", "\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 78ms/step\n", "Prediction shape: (1, 3)\n", "Probabilities: [5.0329237e-04 8.8987160e-01 1.0962512e-01]\n", "\n", "\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 82ms/step\n", "Prediction shape: (1, 3)\n", "Probabilities: [9.1263162e-31 1.1169604e-30 1.0000000e+00]\n", "\n", "\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 69ms/step\n", "Prediction shape: (1, 3)\n", "Probabilities: [4.4493477e-06 8.4401548e-01 1.5598010e-01]\n", "\n", "\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 68ms/step\n", "Prediction shape: (1, 3)\n", "Probabilities: [5.0329237e-04 8.8987160e-01 1.0962512e-01]\n", "\n", "\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 66ms/step\n", "Prediction shape: (1, 3)\n", "Probabilities: [5.0329237e-04 8.8987160e-01 1.0962512e-01]\n", "\n", "\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 74ms/step\n", "Prediction shape: (1, 3)\n", "Probabilities: [9.9999964e-01 1.0916104e-07 1.8336594e-07]\n", "\n", "\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 71ms/step\n", "Prediction shape: (1, 3)\n", "Probabilities: [4.0465540e-22 8.3268744e-22 1.0000000e+00]\n", "\n", "\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 75ms/step\n", "Prediction shape: (1, 3)\n", "Probabilities: [9.9999964e-01 1.0916104e-07 1.8336594e-07]\n", "\n", "\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 66ms/step\n", "Prediction shape: (1, 3)\n", "Probabilities: [5.0329237e-04 8.8987160e-01 1.0962512e-01]\n" ] } ], "source": [ "# Create the Gradio interface\n", "input_image = gr.Image()\n", "iface = gr.Interface(\n", " fn=predict_pokemon,\n", " inputs=input_image, \n", " outputs=gr.Label(),\n", " examples=[\"pokemon_examples/charmander.png\", \"pokemon_examples/charmander1.jpg\", \"pokemon_examples/eevee.png\", \"pokemon_examples/eevee1.jpg\", \"pokemon_examples/pika.png\", \"pokemon_examples/pika1.jpg\"], \n", " description=\"A simple mlp classification model for image classification using the mnist dataset.\")\n", "iface.launch()" ] } ], "metadata": { "kernelspec": { "display_name": "venv_new", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.8" } }, "nbformat": 4, "nbformat_minor": 2 }