Update app.py
Browse files
app.py
CHANGED
@@ -7,13 +7,17 @@ from tensorflow.keras.preprocessing import image
|
|
7 |
class_labels = ['bench_press', 'deadlift', 'hip_thrust', 'lat_pulldown', 'pull_up', 'squat', 'tricep_dips']
|
8 |
|
9 |
def preprocess(img):
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
|
|
|
|
|
|
|
|
17 |
|
18 |
# Laden der Modelle
|
19 |
try:
|
@@ -31,20 +35,24 @@ except Exception as e:
|
|
31 |
def predict_cnn(img):
|
32 |
try:
|
33 |
img_array = preprocess(img)
|
|
|
|
|
34 |
prediction = model_cnn.predict(img_array)
|
35 |
return {class_labels[i]: float(prediction[0][i]) for i in range(len(class_labels))}
|
36 |
except Exception as e:
|
37 |
print(f"Error in CNN prediction: {e}")
|
38 |
-
return {label: 0.0 for label in class_labels}
|
39 |
|
40 |
def predict_resnet50(img):
|
41 |
try:
|
42 |
img_array = preprocess(img)
|
|
|
|
|
43 |
prediction = model_resnet50.predict(img_array)
|
44 |
return {class_labels[i]: float(prediction[0][i]) for i in range(len(class_labels))}
|
45 |
except Exception as e:
|
46 |
print(f"Error in ResNet50 prediction: {e}")
|
47 |
-
return {label: 0.0 for label in class_labels}
|
48 |
|
49 |
# Gradio Interface für das CNN Modell
|
50 |
interface_cnn = gr.Interface(
|
|
|
7 |
class_labels = ['bench_press', 'deadlift', 'hip_thrust', 'lat_pulldown', 'pull_up', 'squat', 'tricep_dips']
|
8 |
|
9 |
def preprocess(img):
|
10 |
+
try:
|
11 |
+
if img is None:
|
12 |
+
raise ValueError("No image uploaded")
|
13 |
+
img = img.resize((150, 150))
|
14 |
+
img_array = image.img_to_array(img)
|
15 |
+
img_array = np.expand_dims(img_array, axis=0)
|
16 |
+
img_array /= 255.0
|
17 |
+
return img_array
|
18 |
+
except Exception as e:
|
19 |
+
print(f"Error in preprocessing: {e}")
|
20 |
+
return None
|
21 |
|
22 |
# Laden der Modelle
|
23 |
try:
|
|
|
35 |
def predict_cnn(img):
|
36 |
try:
|
37 |
img_array = preprocess(img)
|
38 |
+
if img_array is None:
|
39 |
+
return {label: 0.0 for label in class_labels}
|
40 |
prediction = model_cnn.predict(img_array)
|
41 |
return {class_labels[i]: float(prediction[0][i]) for i in range(len(class_labels))}
|
42 |
except Exception as e:
|
43 |
print(f"Error in CNN prediction: {e}")
|
44 |
+
return {label: 0.0 for label in class_labels}
|
45 |
|
46 |
def predict_resnet50(img):
|
47 |
try:
|
48 |
img_array = preprocess(img)
|
49 |
+
if img_array is None:
|
50 |
+
return {label: 0.0 for label in class_labels}
|
51 |
prediction = model_resnet50.predict(img_array)
|
52 |
return {class_labels[i]: float(prediction[0][i]) for i in range(len(class_labels))}
|
53 |
except Exception as e:
|
54 |
print(f"Error in ResNet50 prediction: {e}")
|
55 |
+
return {label: 0.0 for label in class_labels}
|
56 |
|
57 |
# Gradio Interface für das CNN Modell
|
58 |
interface_cnn = gr.Interface(
|