Abedinho commited on
Commit
936453f
1 Parent(s): ed2e8fe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -73
app.py CHANGED
@@ -1,87 +1,75 @@
1
  import gradio as gr
2
- import tensorflow as tf
3
  import numpy as np
4
- from tensorflow.keras.preprocessing import image
 
 
5
 
6
  # Klassenlabels (ersetzen Sie durch Ihre spezifischen Klassen)
7
  class_labels = ['bench_press', 'deadlift', 'hip_thrust', 'lat_pulldown', 'pull_up', 'squat', 'tricep_dips']
8
 
9
- def preprocess(img):
 
 
 
 
10
  try:
11
- if img is None:
12
- raise ValueError("No image uploaded")
13
- print("Image received:", type(img), img)
14
  img = img.resize((150, 150))
15
- img_array = image.img_to_array(img)
16
- print("Image array shape:", img_array.shape)
 
17
  img_array = np.expand_dims(img_array, axis=0)
18
- img_array /= 255.0
19
- print("Processed image array shape:", img_array.shape)
20
- return img_array
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  except Exception as e:
22
- print(f"Error in preprocessing: {e}")
23
- return None
24
-
25
- # Laden der Modelle
26
- try:
27
- model_cnn = tf.keras.models.load_model('exercise_classification_model.h5')
28
- print("CNN model loaded successfully")
29
- except Exception as e:
30
- print(f"Error loading CNN model: {e}")
31
-
32
- try:
33
- model_resnet50 = tf.keras.models.load_model('exercise_classification_model_resnet50.h5')
34
- print("ResNet50 model loaded successfully")
35
- except Exception as e:
36
- print(f"Error loading ResNet50 model: {e}")
37
 
38
- def predict_cnn(img):
39
- try:
40
- img_array = preprocess(img)
41
- if img_array is None:
42
- return {label: 0.0 for label in class_labels}
43
- prediction = model_cnn.predict(img_array)
44
- print("Prediction:", prediction)
45
- return {class_labels[i]: float(prediction[0][i]) for i in range(len(class_labels))}
46
- except Exception as e:
47
- print(f"Error in CNN prediction: {e}")
48
- return {label: 0.0 for label in class_labels}
49
 
50
- def predict_resnet50(img):
51
- try:
52
- img_array = preprocess(img)
53
- if img_array is None:
54
- return {label: 0.0 for label in class_labels}
55
- print(f"Processed image array: {img_array}") # Debugging-Information
56
- prediction = model_resnet50.predict(img_array)
57
- print("Prediction:", prediction)
58
- return {class_labels[i]: float(prediction[0][i]) for i in range(len(class_labels))}
59
- except Exception as e:
60
- print(f"Error in ResNet50 prediction: {e}")
61
- return {label: 0.0 for label in class_labels}
62
-
63
- # Gradio Interface für das CNN Modell
64
- interface_cnn = gr.Interface(
65
- fn=predict_cnn,
66
- inputs=gr.Image(type="pil"), # Verwenden Sie gr.Image
67
- outputs=gr.Label(num_top_classes=3),
68
- live=True,
69
- title="Exercise Classification with CNN",
70
- description="Upload an image to classify the exercise using CNN model."
71
  )
72
 
73
- # Gradio Interface für das ResNet50 Modell
74
- interface_resnet50 = gr.Interface(
75
- fn=predict_resnet50,
76
- inputs=gr.Image(type="pil"), # Verwenden Sie gr.Image
77
- outputs=gr.Label(num_top_classes=3),
78
- live=True,
79
- title="Exercise Classification with ResNet50",
80
- description="Upload an image to classify the exercise using ResNet50 model."
81
- )
82
-
83
- # Kombiniere die Interfaces in eine Gradio App
84
- demo = gr.TabbedInterface([interface_cnn, interface_resnet50], ["CNN Model", "ResNet50 Model"])
85
-
86
- if __name__ == "__main__":
87
- demo.launch()
 
1
  import gradio as gr
2
+ from PIL import Image
3
  import numpy as np
4
+ from tensorflow.keras.preprocessing import image as keras_image
5
+ from tensorflow.keras.applications.resnet50 import preprocess_input as resnet_preprocess_input
6
+ from tensorflow.keras.models import load_model
7
 
8
  # Klassenlabels (ersetzen Sie durch Ihre spezifischen Klassen)
9
  class_labels = ['bench_press', 'deadlift', 'hip_thrust', 'lat_pulldown', 'pull_up', 'squat', 'tricep_dips']
10
 
11
+ # Load your trained models
12
+ resnet_model = load_model('exercise_classification_model_resnet50.h5')
13
+ cnn_model = load_model('exercise_classification_model.h5')
14
+
15
+ def predict_exercise(img, model_type):
16
  try:
17
+ # Convert the image to RGB format
18
+ img = Image.fromarray(img.astype('uint8'), 'RGB')
19
+ # Resize the image to the required input size of the model
20
  img = img.resize((150, 150))
21
+ # Convert the image to an array
22
+ img_array = keras_image.img_to_array(img)
23
+ # Expand dimensions to match the model's input shape
24
  img_array = np.expand_dims(img_array, axis=0)
25
+
26
+ if model_type == 'ResNet50':
27
+ # Preprocess the input as expected by ResNet50
28
+ img_array = resnet_preprocess_input(img_array)
29
+ # Predict using the ResNet50 model
30
+ prediction = resnet_model.predict(img_array)
31
+ elif model_type == 'CNN':
32
+ # Preprocess the input as expected by the CNN
33
+ img_array /= 255.0 # Normalisierung
34
+ # Predict using the CNN model
35
+ prediction = cnn_model.predict(img_array)
36
+ else:
37
+ return {"error": "Invalid model type selected"}
38
+
39
+ # Debugging: print the prediction shape
40
+ print(f"Prediction shape: {prediction.shape}")
41
+
42
+ # Check prediction shape
43
+ if prediction.shape[1] == len(class_labels): # Ensure the prediction matches the number of classes
44
+ # Return the prediction as a dictionary with class probabilities
45
+ return {class_labels[i]: float(prediction[0][i]) for i in range(len(class_labels))}
46
+ else:
47
+ return {"error": f"Unexpected prediction shape: {prediction.shape}"}
48
  except Exception as e:
49
+ return {"error": str(e)}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
+ # Custom CSS for Dark Mode and Elegant Design
52
+ custom_css = """
53
+ body {background-color: #121212; color: #e0e0e0; font-family: 'Arial', sans-serif;}
54
+ h1 {color: #ff5722;}
55
+ label {color: #ff9800;}
56
+ input[type=radio] {accent-color: #ff5722;}
57
+ button:hover {background-color: #e64a19;}
58
+ .footer {display: none !important;}
59
+ """
 
 
60
 
61
+ # Define the Gradio interface
62
+ interface = gr.Interface(
63
+ fn=predict_exercise,
64
+ inputs=[
65
+ gr.Image(type="numpy", label="Upload an image of the exercise"),
66
+ gr.Radio(['ResNet50', 'CNN'], label="Choose Model")
67
+ ],
68
+ outputs=gr.Label(num_top_classes=3, label="Prediction"),
69
+ title="Exercise Classifier",
70
+ description="Upload an image of an exercise and the classifier will predict the exercise.",
71
+ css=custom_css # Apply the custom CSS
 
 
 
 
 
 
 
 
 
 
72
  )
73
 
74
+ # Launch the interface
75
+ interface.launch()