reiflja1 commited on
Commit
8a22930
1 Parent(s): 85bbc81

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -19
app.py CHANGED
@@ -3,25 +3,30 @@ import tensorflow as tf
3
  from PIL import Image
4
  import numpy as np
5
 
6
- # Load the model
7
- model = tf.keras.models.load_model('pokemon_classifier_model.keras')
 
8
 
9
- def predict(image):
10
- img = tf.keras.preprocessing.image.img_to_array(image)
11
- img = tf.keras.preprocessing.image.smart_resize(img, (224, 224))
12
- img = tf.expand_dims(img, 0) # Make batch of one
13
-
14
- pred = model.predict(img)
15
- pred_label = tf.argmax(pred, axis=1).numpy()[0] # get the index of the max logit
16
- pred_class = class_names[pred_label] # use the index to get the corresponding class name
17
- confidence = tf.nn.softmax(pred)[0][pred_label] # softmax to get the confidence
18
-
19
- print(f"Predicted: {pred_class}, Confidence: {confidence:.4f}")
20
- return pred_class
21
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
- # Setup Gradio interface
24
- iface = gr.Interface(fn=predict, inputs=gr.Image(), outputs="text", title="Pokémon Classifier")
25
-
26
- # Run the interface
27
- iface.launch()
 
 
 
 
 
3
  from PIL import Image
4
  import numpy as np
5
 
6
+ # Load your custom regression model
7
+ model_path = "pokemon_model_tl.keras"
8
+ model = tf.keras.models.load_model(model_path)
9
 
10
+ labels = ['Wartortle', 'Weedle', 'Weepinbell', 'Weezing']
 
 
 
 
 
 
 
 
 
 
 
11
 
12
+ # Define regression function
13
+ def predict_regression(image):
14
+ # Preprocess image
15
+ image = Image.fromarray(image.astype('uint8')) # Convert numpy array to PIL image
16
+ image = image.resize((150, 150)).convert('RGB') #resize the image to 28x28 and converts it to gray scale
17
+ image = np.array(image)
18
+ print(image.shape)
19
+ # Predict
20
+ prediction = model.predict(image[None, ...]) # Assuming single regression value
21
+ confidences = {labels[i]: np.round(float(prediction[0][i]), 2) for i in range(len(labels))}
22
+ return confidences
23
 
24
+ # Create Gradio interface
25
+ input_image = gr.Image()
26
+ output_text = gr.Textbox(label="Predicted Value")
27
+ interface = gr.Interface(fn=predict_regression,
28
+ inputs=input_image,
29
+ outputs=gr.Label(),
30
+ examples=["wartortle.jpg", "weedle.jpg", "weepinbell.jpg", "weezing.jpg"],
31
+ description="A simple mlp classification model for image classification using the pokemon dataset.")
32
+ interface.launch()