sprinala commited on
Commit
2c96251
·
verified ·
1 Parent(s): b98a78d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -10
app.py CHANGED
@@ -10,28 +10,22 @@ model = tf.keras.models.load_model('gym_equipment_transferlearning.keras')
10
  class_names = ['benchPress', 'dumbBell', 'kettleBell', 'treadMill']
11
 
12
  def classify_image(image):
13
- # Convert the input image to a PIL image
14
  image = Image.fromarray(image.astype('uint8'), 'RGB')
15
- # Resize the image to the input size expected by the model
16
  img = image.resize((150, 150))
17
- # Convert the image to a numpy array and expand dimensions to create a batch
18
  img_array = tf.keras.preprocessing.image.img_to_array(img)
19
- img_array = tf.expand_dims(img_array, 0)
20
- # Make predictions
21
  predictions = model.predict(img_array)
22
- # Get the predicted class and confidence
23
  predicted_class = class_names[np.argmax(predictions[0])]
24
  confidence = np.max(predictions[0])
25
  return {predicted_class: float(confidence)}
26
 
27
- # Define the Gradio interface components
28
  image_input = gr.Image() # Entferne den `shape` Parameter
29
  label = gr.Label(num_top_classes=3)
30
 
31
- # Create the Gradio interface
32
  iface = gr.Interface(
33
- fn=classify_image,
34
- inputs=image_input,
35
  outputs=label,
36
  title='Gym Equipment Classifier',
37
  description='Upload an image of gym equipment and the classifier will tell you which one it is and the confidence level of the prediction.'
 
10
  class_names = ['benchPress', 'dumbBell', 'kettleBell', 'treadMill']
11
 
12
  def classify_image(image):
 
13
  image = Image.fromarray(image.astype('uint8'), 'RGB')
 
14
  img = image.resize((150, 150))
 
15
  img_array = tf.keras.preprocessing.image.img_to_array(img)
16
+ img_array = tf.expand_dims(img_array, 0) # Erstelle einen Batch
 
17
  predictions = model.predict(img_array)
 
18
  predicted_class = class_names[np.argmax(predictions[0])]
19
  confidence = np.max(predictions[0])
20
  return {predicted_class: float(confidence)}
21
 
22
+
23
  image_input = gr.Image() # Entferne den `shape` Parameter
24
  label = gr.Label(num_top_classes=3)
25
 
 
26
  iface = gr.Interface(
27
+ fn=classify_image,
28
+ inputs=image_input,
29
  outputs=label,
30
  title='Gym Equipment Classifier',
31
  description='Upload an image of gym equipment and the classifier will tell you which one it is and the confidence level of the prediction.'