gassmdav commited on
Commit
242ec93
1 Parent(s): 7d85074

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -27
app.py CHANGED
@@ -1,44 +1,40 @@
1
- import gradio as gr
2
  import tensorflow as tf
3
  from PIL import Image
4
  import numpy as np
5
 
6
-
7
- # Laden des vortrainierten Pokémon-Modells
8
  model_path = "kia_pokemon_keras_model.h5"
9
  model = tf.keras.models.load_model(model_path)
10
 
11
- # Labels für den Pokémon Classifier
12
- labels = [
13
- 'Bulbasaur','Charmander','Squirtle'
14
- ]
15
 
16
  def predict_pokemon(image):
17
  # Preprocess image
18
- print(type(image)) # Output the type of the input image for debugging
19
- image = Image.fromarray(image.astype('uint8')) # Convert numpy array to PIL image
20
- image = image.resize((224, 224)) # Resize the image to 224x224
21
  image = np.array(image)
22
- image = np.expand_dims(image, axis=0) # same as image[None, ...]
23
-
24
  # Predict
25
  predictions = model.predict(image)
26
  prediction = np.argmax(predictions, axis=1)[0]
27
  confidence = np.max(predictions)
28
-
29
- # Vorbereiten der Ausgabe
30
- result = f"Predicted Pokémon: {labels[prediction]} with confidence: {confidence:.2f}"
31
  return result
32
 
33
- # Erstellen der Gradio-Oberfläche
34
- input_image = gr.Image()
35
- output_label = gr.Label()
36
- interface = gr.Interface(fn=predict_pokemon,
37
- inputs=input_image,
38
- outputs=output_label,
39
- examples=["images/bulbasaur.png", "images/charmander.png", "images/squirtle.png"],
40
- title="Pokémon Classifier",
41
- description="Drag and drop an image or select an example below to predict the Pokémon.")
42
-
43
- # Interface starten
44
- interface.launch()
 
1
+ import streamlit as st
2
  import tensorflow as tf
3
  from PIL import Image
4
  import numpy as np
5
 
6
+ # Load the pre-trained Pokémon model
 
7
  model_path = "kia_pokemon_keras_model.h5"
8
  model = tf.keras.models.load_model(model_path)
9
 
10
+ # Pokémon classifier labels
11
+ labels = ['Bulbasaur', 'Charmander', 'Squirtle']
 
 
12
 
13
  def predict_pokemon(image):
14
  # Preprocess image
15
+ image = Image.fromarray(np.array(image).astype('uint8')) # Convert to PIL image
16
+ image = image.resize((224, 224)) # Resize image to 224x224
 
17
  image = np.array(image)
18
+ image = np.expand_dims(image, axis=0) # Add batch dimension
19
+
20
  # Predict
21
  predictions = model.predict(image)
22
  prediction = np.argmax(predictions, axis=1)[0]
23
  confidence = np.max(predictions)
24
+
25
+ # Prepare output
26
+ result = f"Predicted Pokémon: {labels[prediction]} with confidence: {confidence:.2f}%"
27
  return result
28
 
29
+ st.title("Pokémon Classifier")
30
+
31
+ file_uploader = st.file_uploader("Upload an image of a Pokémon", type=['png', 'jpg', 'jpeg'])
32
+
33
+ if file_uploader is not None:
34
+ # Display the image
35
+ image = Image.open(file_uploader)
36
+ st.image(image, caption='Uploaded Image', use_column_width=True)
37
+
38
+ # Make prediction
39
+ result = predict_pokemon(image)
40
+ st.subheader(result)