Update app.py
Browse files
app.py
CHANGED
@@ -1,87 +1,75 @@
|
|
1 |
import gradio as gr
|
2 |
-
|
3 |
import numpy as np
|
4 |
-
from tensorflow.keras.preprocessing import image
|
|
|
|
|
5 |
|
6 |
# Klassenlabels (ersetzen Sie durch Ihre spezifischen Klassen)
|
7 |
class_labels = ['bench_press', 'deadlift', 'hip_thrust', 'lat_pulldown', 'pull_up', 'squat', 'tricep_dips']
|
8 |
|
9 |
-
|
|
|
|
|
|
|
|
|
10 |
try:
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
img = img.resize((150, 150))
|
15 |
-
|
16 |
-
|
|
|
17 |
img_array = np.expand_dims(img_array, axis=0)
|
18 |
-
|
19 |
-
|
20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
except Exception as e:
|
22 |
-
|
23 |
-
return None
|
24 |
-
|
25 |
-
# Laden der Modelle
|
26 |
-
try:
|
27 |
-
model_cnn = tf.keras.models.load_model('exercise_classification_model.h5')
|
28 |
-
print("CNN model loaded successfully")
|
29 |
-
except Exception as e:
|
30 |
-
print(f"Error loading CNN model: {e}")
|
31 |
-
|
32 |
-
try:
|
33 |
-
model_resnet50 = tf.keras.models.load_model('exercise_classification_model_resnet50.h5')
|
34 |
-
print("ResNet50 model loaded successfully")
|
35 |
-
except Exception as e:
|
36 |
-
print(f"Error loading ResNet50 model: {e}")
|
37 |
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
print(f"Error in CNN prediction: {e}")
|
48 |
-
return {label: 0.0 for label in class_labels}
|
49 |
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
return {label: 0.0 for label in class_labels}
|
62 |
-
|
63 |
-
# Gradio Interface für das CNN Modell
|
64 |
-
interface_cnn = gr.Interface(
|
65 |
-
fn=predict_cnn,
|
66 |
-
inputs=gr.Image(type="pil"), # Verwenden Sie gr.Image
|
67 |
-
outputs=gr.Label(num_top_classes=3),
|
68 |
-
live=True,
|
69 |
-
title="Exercise Classification with CNN",
|
70 |
-
description="Upload an image to classify the exercise using CNN model."
|
71 |
)
|
72 |
|
73 |
-
#
|
74 |
-
|
75 |
-
fn=predict_resnet50,
|
76 |
-
inputs=gr.Image(type="pil"), # Verwenden Sie gr.Image
|
77 |
-
outputs=gr.Label(num_top_classes=3),
|
78 |
-
live=True,
|
79 |
-
title="Exercise Classification with ResNet50",
|
80 |
-
description="Upload an image to classify the exercise using ResNet50 model."
|
81 |
-
)
|
82 |
-
|
83 |
-
# Kombiniere die Interfaces in eine Gradio App
|
84 |
-
demo = gr.TabbedInterface([interface_cnn, interface_resnet50], ["CNN Model", "ResNet50 Model"])
|
85 |
-
|
86 |
-
if __name__ == "__main__":
|
87 |
-
demo.launch()
|
|
|
1 |
import gradio as gr
|
2 |
+
from PIL import Image
|
3 |
import numpy as np
|
4 |
+
from tensorflow.keras.preprocessing import image as keras_image
|
5 |
+
from tensorflow.keras.applications.resnet50 import preprocess_input as resnet_preprocess_input
|
6 |
+
from tensorflow.keras.models import load_model
|
7 |
|
8 |
# Klassenlabels (ersetzen Sie durch Ihre spezifischen Klassen)
|
9 |
class_labels = ['bench_press', 'deadlift', 'hip_thrust', 'lat_pulldown', 'pull_up', 'squat', 'tricep_dips']
|
10 |
|
11 |
+
# Load your trained models
|
12 |
+
resnet_model = load_model('exercise_classification_model_resnet50.h5')
|
13 |
+
cnn_model = load_model('exercise_classification_model.h5')
|
14 |
+
|
15 |
+
def predict_exercise(img, model_type):
|
16 |
try:
|
17 |
+
# Convert the image to RGB format
|
18 |
+
img = Image.fromarray(img.astype('uint8'), 'RGB')
|
19 |
+
# Resize the image to the required input size of the model
|
20 |
img = img.resize((150, 150))
|
21 |
+
# Convert the image to an array
|
22 |
+
img_array = keras_image.img_to_array(img)
|
23 |
+
# Expand dimensions to match the model's input shape
|
24 |
img_array = np.expand_dims(img_array, axis=0)
|
25 |
+
|
26 |
+
if model_type == 'ResNet50':
|
27 |
+
# Preprocess the input as expected by ResNet50
|
28 |
+
img_array = resnet_preprocess_input(img_array)
|
29 |
+
# Predict using the ResNet50 model
|
30 |
+
prediction = resnet_model.predict(img_array)
|
31 |
+
elif model_type == 'CNN':
|
32 |
+
# Preprocess the input as expected by the CNN
|
33 |
+
img_array /= 255.0 # Normalisierung
|
34 |
+
# Predict using the CNN model
|
35 |
+
prediction = cnn_model.predict(img_array)
|
36 |
+
else:
|
37 |
+
return {"error": "Invalid model type selected"}
|
38 |
+
|
39 |
+
# Debugging: print the prediction shape
|
40 |
+
print(f"Prediction shape: {prediction.shape}")
|
41 |
+
|
42 |
+
# Check prediction shape
|
43 |
+
if prediction.shape[1] == len(class_labels): # Ensure the prediction matches the number of classes
|
44 |
+
# Return the prediction as a dictionary with class probabilities
|
45 |
+
return {class_labels[i]: float(prediction[0][i]) for i in range(len(class_labels))}
|
46 |
+
else:
|
47 |
+
return {"error": f"Unexpected prediction shape: {prediction.shape}"}
|
48 |
except Exception as e:
|
49 |
+
return {"error": str(e)}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
|
51 |
+
# Custom CSS for Dark Mode and Elegant Design
|
52 |
+
custom_css = """
|
53 |
+
body {background-color: #121212; color: #e0e0e0; font-family: 'Arial', sans-serif;}
|
54 |
+
h1 {color: #ff5722;}
|
55 |
+
label {color: #ff9800;}
|
56 |
+
input[type=radio] {accent-color: #ff5722;}
|
57 |
+
button:hover {background-color: #e64a19;}
|
58 |
+
.footer {display: none !important;}
|
59 |
+
"""
|
|
|
|
|
60 |
|
61 |
+
# Define the Gradio interface
|
62 |
+
interface = gr.Interface(
|
63 |
+
fn=predict_exercise,
|
64 |
+
inputs=[
|
65 |
+
gr.Image(type="numpy", label="Upload an image of the exercise"),
|
66 |
+
gr.Radio(['ResNet50', 'CNN'], label="Choose Model")
|
67 |
+
],
|
68 |
+
outputs=gr.Label(num_top_classes=3, label="Prediction"),
|
69 |
+
title="Exercise Classifier",
|
70 |
+
description="Upload an image of an exercise and the classifier will predict the exercise.",
|
71 |
+
css=custom_css # Apply the custom CSS
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
72 |
)
|
73 |
|
74 |
+
# Launch the interface
|
75 |
+
interface.launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|