Fabsi
commited on
Commit
•
3557ad4
1
Parent(s):
82d3f2e
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import tensorflow as tf
|
3 |
+
from PIL import Image
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
# Load the pre-trained Pokémon model
|
7 |
+
model_path = "pokemon_classifier_model.keras"
|
8 |
+
model = tf.keras.models.load_model(model_path)
|
9 |
+
|
10 |
+
# Define the Pokémon classes
|
11 |
+
classes = ['Aerodactyl', 'Alakazam', 'Beedrill'] # Adjust these as per your model's classes
|
12 |
+
|
13 |
+
# Define the image classification function
|
14 |
+
def classify_image(image):
|
15 |
+
try:
|
16 |
+
# Ensure the image is in RGB and normalize it
|
17 |
+
if image.ndim == 2: # Check if the image is grayscale
|
18 |
+
image = np.stack((image,)*3, axis=-1) # Convert grayscale to RGB by repeating the gray channel
|
19 |
+
elif image.shape[2] == 4: # Check if the image has an alpha channel
|
20 |
+
image = image[:, :, :3] # Drop the alpha channel
|
21 |
+
image = Image.fromarray(image.astype('uint8'), 'RGB') # Convert to PIL Image to resize
|
22 |
+
image = image.resize((150, 150)) # Resize to match the model's input size
|
23 |
+
|
24 |
+
image_array = np.array(image) / 255.0 # Convert to array and normalize
|
25 |
+
image_array = np.expand_dims(image_array, axis=0) # Add batch dimension
|
26 |
+
|
27 |
+
# Predict using the model
|
28 |
+
prediction = model.predict(image_array)
|
29 |
+
predicted_class = classes[np.argmax(prediction)]
|
30 |
+
confidence = np.max(prediction)
|
31 |
+
|
32 |
+
return f"Predicted Pokémon: {predicted_class}, Confidence: {np.round(confidence * 100, 2)}%"
|
33 |
+
except Exception as e:
|
34 |
+
return str(e) # Return the error message if something goes wrong
|
35 |
+
|
36 |
+
# Create Gradio interface
|
37 |
+
input_image = gr.Image() # Using Gradio's Image component correctly
|
38 |
+
output_label = gr.Label()
|
39 |
+
|
40 |
+
interface = gr.Interface(fn=classify_image,
|
41 |
+
inputs=input_image,
|
42 |
+
outputs=output_label,
|
43 |
+
examples=["pokemon/aerodactyl.png", "pokemon/alakazam.png", "pokemon/beedrill.png"],
|
44 |
+
description="Upload an image of a Pokémon (Aerodactyl, Alakazam or Beedrill) to classify!")
|
45 |
+
|
46 |
+
interface.launch()
|