thomen commited on
Commit
c9ab25f
1 Parent(s): 0e7ca3e

Update pokemon-deploy.py

Browse files
Files changed (1) hide show
  1. pokemon-deploy.py +16 -34
pokemon-deploy.py CHANGED
@@ -3,56 +3,38 @@ import tensorflow as tf
3
  import numpy as np
4
  from PIL import Image
5
 
6
- # Load the model
7
  model_path = "pokemon-predict-model_transferlearning.keras"
8
  model = tf.keras.models.load_model(model_path)
9
 
10
  # Define the core prediction function
11
  def predict_pokemon(image):
12
  # Preprocess image
 
13
  image = Image.fromarray(image.astype('uint8')) # Convert numpy array to PIL image
14
- image = image.resize((150, 150)) # Resize the image to 150x150
15
  image = np.array(image)
16
- image = np.expand_dims(image, axis=0) # Same as image[None, ...]
17
-
18
  # Predict
19
  prediction = model.predict(image)
20
-
21
  # Apply softmax to get probabilities for each class
22
  prediction = tf.nn.softmax(prediction)
23
-
24
  # Create a dictionary with the probabilities for each Pokemon
25
- pokemon_classes = [
26
- 'Abra', 'Aerodactyl', 'Alakazam', 'Arbok', 'Arcanine', 'Articuno', 'Beedrill', 'Bellsprout',
27
- 'Blastoise', 'Bulbasaur', 'Butterfree', 'Caterpie', 'Chansey', 'Charizard', 'Charmander',
28
- 'Charmeleon', 'Clefable', 'Clefairy', 'Cloyster', 'Cubone', 'Dewgong', 'Diglett', 'Ditto',
29
- 'Dodrio', 'Doduo', 'Dragonair', 'Dragonite', 'Dratini', 'Drowzee', 'Dugtrio', 'Eevee', 'Ekans',
30
- 'Electabuzz', 'Electrode', 'Exeggcute', 'Exeggutor', 'Farfetchd', 'Fearow', 'Flareon', 'Gastly',
31
- 'Gengar', 'Geodude', 'Gloom', 'Golbat', 'Goldeen', 'Golduck', 'Graveler', 'Grimer', 'Growlithe',
32
- 'Gyarados', 'Haunter', 'Hitmonchan', 'Hitmonlee', 'Horsea', 'Hypno', 'Ivysaur', 'Jigglypuff',
33
- 'Jolteon', 'Jynx', 'Kabutops', 'Kadabra', 'Kakuna', 'Kangaskhan', 'Kingler', 'Koffing', 'Lapras',
34
- 'Lickitung', 'Machamp', 'Machoke', 'Machop', 'Magikarp', 'Magmar', 'Magnemite', 'Magneton', 'Mankey',
35
- 'Marowak', 'Meowth', 'Metapod', 'Mew', 'Mewtwo', 'Moltres', 'Mr. Mime', 'MrMime', 'Nidoking', 'Nidoqueen',
36
- 'Nidorina', 'Nidorino', 'Ninetales', 'Oddish', 'Omanyte', 'Omastar', 'Parasect', 'Pidgeot', 'Pidgeotto',
37
- 'Pidgey', 'Pikachu', 'Pinsir', 'Poliwag', 'Poliwhirl', 'Poliwrath', 'Ponyta', 'Porygon', 'Primeape',
38
- 'Psyduck', 'Raichu', 'Rapidash', 'Raticate', 'Rattata', 'Rhydon', 'Rhyhorn', 'Sandshrew', 'Sandslash',
39
- 'Scyther', 'Seadra', 'Seaking', 'Seel', 'Shellder', 'Slowbro', 'Slowpoke', 'Snorlax', 'Spearow', 'Squirtle',
40
- 'Starmie', 'Staryu', 'Tangela', 'Tauros', 'Tentacool', 'Tentacruel', 'Vaporeon', 'Venomoth', 'Venonat',
41
- 'Venusaur', 'Victreebel', 'Vileplume', 'Voltorb', 'Vulpix', 'Wartortle', 'Weedle', 'Weepinbell', 'Weezing',
42
- 'Wigglytuff', 'Zapdos', 'Zubat'
43
- ]
44
-
45
- probabilities = [np.round(float(prediction[0][i]), 2) for i in range(len(pokemon_classes))]
46
- pokemon_probabilities = dict(zip(pokemon_classes, probabilities))
47
 
48
- return pokemon_probabilities
49
 
50
- # Interface setup
51
  input_image = gr.Image()
52
  iface = gr.Interface(
53
  fn=predict_pokemon,
54
- inputs=input_image,
55
  outputs=gr.Label(),
56
- description="A simple MLP classification model for image classification using the Pokémon dataset."
57
- )
58
- iface.launch(share=True)
 
3
  import numpy as np
4
  from PIL import Image
5
 
6
+
7
  model_path = "pokemon-predict-model_transferlearning.keras"
8
  model = tf.keras.models.load_model(model_path)
9
 
10
  # Define the core prediction function
11
  def predict_pokemon(image):
12
  # Preprocess image
13
+ print(type(image))
14
  image = Image.fromarray(image.astype('uint8')) # Convert numpy array to PIL image
15
+ image = image.resize((150, 150)) #resize the image to 150x150
16
  image = np.array(image)
17
+ image = np.expand_dims(image, axis=0) # same as image[None, ...]
18
+
19
  # Predict
20
  prediction = model.predict(image)
21
+
22
  # Apply softmax to get probabilities for each class
23
  prediction = tf.nn.softmax(prediction)
24
+
25
  # Create a dictionary with the probabilities for each Pokemon
26
+ evee = np.round(float(prediction[0][0]), 2)
27
+ farfetched = np.round(float(prediction[0][1]), 2)
28
+ graveler = np.round(float(prediction[0][2]), 2)
29
+ venonta = np.round(float(prediction[0][3]), 2)
30
+
31
+ return {'Evee': evee, 'Farfetched': farfetched, 'Graveler': graveler, 'Venonta': venonta}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
 
33
 
 
34
  input_image = gr.Image()
35
  iface = gr.Interface(
36
  fn=predict_pokemon,
37
+ inputs=input_image,
38
  outputs=gr.Label(),
39
+ description="A simple mlp classification model for image classification using the mnist dataset.")
40
+ iface.launch(share=True)