Abedinho commited on
Commit
bcd4512
1 Parent(s): 273d677

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -7
app.py CHANGED
@@ -4,7 +4,7 @@ import numpy as np
4
  from tensorflow.keras.preprocessing import image
5
 
6
  # Klassenlabels (ersetzen Sie durch Ihre spezifischen Klassen)
7
- class_labels = ['Bench Press', 'deadlift', 'hip_thrust', 'lat_pulldown', 'pull_up', 'squat', 'tricep_dips']
8
 
9
  def preprocess(img):
10
  img = img.resize((150, 150))
@@ -27,14 +27,22 @@ except Exception as e:
27
  print(f"Error loading ResNet50 model: {e}")
28
 
29
  def predict_cnn(img):
30
- img_array = preprocess(img)
31
- prediction = model_cnn.predict(img_array)
32
- return {class_labels[i]: float(prediction[0][i]) for i in range(len(class_labels))}
 
 
 
 
33
 
34
  def predict_resnet50(img):
35
- img_array = preprocess(img)
36
- prediction = model_resnet50.predict(img_array)
37
- return {class_labels[i]: float(prediction[0][i]) for i in range(len(class_labels))}
 
 
 
 
38
 
39
  # Gradio Interface für das CNN Modell
40
  interface_cnn = gr.Interface(
 
4
  from tensorflow.keras.preprocessing import image
5
 
6
  # Klassenlabels (ersetzen Sie durch Ihre spezifischen Klassen)
7
+ class_labels = ['bench_press', 'deadlift', 'hip_thrust', 'lat_pulldown', 'pull_up', 'squat', 'tricep_dips']
8
 
9
  def preprocess(img):
10
  img = img.resize((150, 150))
 
27
  print(f"Error loading ResNet50 model: {e}")
28
 
29
  def predict_cnn(img):
30
+ try:
31
+ img_array = preprocess(img)
32
+ prediction = model_cnn.predict(img_array)
33
+ return {class_labels[i]: float(prediction[0][i]) for i in range(len(class_labels))}
34
+ except Exception as e:
35
+ print(f"Error in CNN prediction: {e}")
36
+ return {"error": str(e)}
37
 
38
  def predict_resnet50(img):
39
+ try:
40
+ img_array = preprocess(img)
41
+ prediction = model_resnet50.predict(img_array)
42
+ return {class_labels[i]: float(prediction[0][i]) for i in range(len(class_labels))}
43
+ except Exception as e:
44
+ print(f"Error in ResNet50 prediction: {e}")
45
+ return {"error": str(e)}
46
 
47
  # Gradio Interface für das CNN Modell
48
  interface_cnn = gr.Interface(