Abedinho commited on
Commit
4200023
1 Parent(s): 30e456b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -9
app.py CHANGED
@@ -7,13 +7,17 @@ from tensorflow.keras.preprocessing import image
7
  class_labels = ['bench_press', 'deadlift', 'hip_thrust', 'lat_pulldown', 'pull_up', 'squat', 'tricep_dips']
8
 
9
  def preprocess(img):
10
- if img is None:
11
- raise ValueError("No image uploaded")
12
- img = img.resize((150, 150))
13
- img_array = image.img_to_array(img)
14
- img_array = np.expand_dims(img_array, axis=0)
15
- img_array /= 255.0
16
- return img_array
 
 
 
 
17
 
18
  # Laden der Modelle
19
  try:
@@ -31,20 +35,24 @@ except Exception as e:
31
  def predict_cnn(img):
32
  try:
33
  img_array = preprocess(img)
 
 
34
  prediction = model_cnn.predict(img_array)
35
  return {class_labels[i]: float(prediction[0][i]) for i in range(len(class_labels))}
36
  except Exception as e:
37
  print(f"Error in CNN prediction: {e}")
38
- return {label: 0.0 for label in class_labels} # Standardkonfidenz 0.0 für alle Klassen
39
 
40
  def predict_resnet50(img):
41
  try:
42
  img_array = preprocess(img)
 
 
43
  prediction = model_resnet50.predict(img_array)
44
  return {class_labels[i]: float(prediction[0][i]) for i in range(len(class_labels))}
45
  except Exception as e:
46
  print(f"Error in ResNet50 prediction: {e}")
47
- return {label: 0.0 for label in class_labels} # Standardkonfidenz 0.0 für alle Klassen
48
 
49
  # Gradio Interface für das CNN Modell
50
  interface_cnn = gr.Interface(
 
7
  class_labels = ['bench_press', 'deadlift', 'hip_thrust', 'lat_pulldown', 'pull_up', 'squat', 'tricep_dips']
8
 
9
  def preprocess(img):
10
+ try:
11
+ if img is None:
12
+ raise ValueError("No image uploaded")
13
+ img = img.resize((150, 150))
14
+ img_array = image.img_to_array(img)
15
+ img_array = np.expand_dims(img_array, axis=0)
16
+ img_array /= 255.0
17
+ return img_array
18
+ except Exception as e:
19
+ print(f"Error in preprocessing: {e}")
20
+ return None
21
 
22
  # Laden der Modelle
23
  try:
 
35
  def predict_cnn(img):
36
  try:
37
  img_array = preprocess(img)
38
+ if img_array is None:
39
+ return {label: 0.0 for label in class_labels}
40
  prediction = model_cnn.predict(img_array)
41
  return {class_labels[i]: float(prediction[0][i]) for i in range(len(class_labels))}
42
  except Exception as e:
43
  print(f"Error in CNN prediction: {e}")
44
+ return {label: 0.0 for label in class_labels}
45
 
46
  def predict_resnet50(img):
47
  try:
48
  img_array = preprocess(img)
49
+ if img_array is None:
50
+ return {label: 0.0 for label in class_labels}
51
  prediction = model_resnet50.predict(img_array)
52
  return {class_labels[i]: float(prediction[0][i]) for i in range(len(class_labels))}
53
  except Exception as e:
54
  print(f"Error in ResNet50 prediction: {e}")
55
+ return {label: 0.0 for label in class_labels}
56
 
57
  # Gradio Interface für das CNN Modell
58
  interface_cnn = gr.Interface(