kramesab commited on
Commit
c68cac3
1 Parent(s): ad60e8e

Delete Untitled-1.py

Browse files
Files changed (1) hide show
  1. Untitled-1.py +0 -60
Untitled-1.py DELETED
@@ -1,60 +0,0 @@
1
- # %%
2
- import gradio as gr
3
- import numpy as np
4
- from tensorflow.keras.models import load_model
5
- from tensorflow.keras.applications.resnet50 import preprocess_input
6
- from PIL import Image
7
-
8
- # Load the pre-trained Keras model
9
- model = load_model('pokemon-model.keras')
10
-
11
- # Define the class labels
12
- class_labels = ['Bulbasaur', 'Glumanda', 'Pikachu'] # Ensure this matches the training order
13
-
14
- # Define the image processing and prediction function
15
- def predict_image(img):
16
- # Ensure the image is a PIL image
17
- if not isinstance(img, Image.Image):
18
- img = Image.fromarray(img)
19
-
20
- # Resize the image to the size expected by ResNet50
21
- img = img.resize((224, 224))
22
-
23
- # Convert the image to a numpy array
24
- img_array = np.array(img)
25
-
26
- # Convert the image array to a batch of size 1 (1, 224, 224, 3)
27
- img_array = np.expand_dims(img_array, axis=0)
28
-
29
- # Preprocess the image array using ResNet50's preprocessing
30
- img_array = preprocess_input(img_array)
31
-
32
- # Make prediction
33
- prediction = model.predict(img_array)
34
-
35
- # Get the label with the highest probability
36
- predicted_index = int(np.argmax(prediction))
37
- predicted_label = class_labels[predicted_index]
38
-
39
- return predicted_label
40
-
41
- # Create the Gradio interface with multiple examples
42
- iface = gr.Interface(
43
- fn=predict_image,
44
- inputs=gr.Image(image_mode='RGB'),
45
- outputs='label',
46
- examples=[['00000015.jpg'], ['20.png'], ['glumanda.jpg'], ['j67j7.png'], ['pikachu.jpg']],
47
- title="Pokémon Classification",
48
- description="Upload an image of a Pokémon to classify it using the pre-trained model."
49
- )
50
-
51
- # Launch the interface inline in the Jupyter Notebook
52
- iface.launch(inline=True)
53
-
54
-
55
- # %%
56
- # Print model summary to verify input shape
57
- print(model.summary())
58
-
59
-
60
-