"""Gradio app that classifies an exercise photo with one of two Keras models."""

import logging

import gradio as gr
import numpy as np
from PIL import Image
from tensorflow.keras.applications.resnet50 import preprocess_input as resnet_preprocess_input
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image as keras_image

logger = logging.getLogger(__name__)

# Class labels (replace with your specific classes). The i-th label is paired
# with the i-th output score of whichever model is selected.
class_labels = ['bench_press', 'deadlift', 'hip_thrust', 'lat_pulldown', 'pull_up', 'squat', 'tricep_dips']

# (width, height) every uploaded image is resized to before prediction.
IMAGE_SIZE = (150, 150)

# Load your trained models
resnet_model = load_model('exercise_classification_model_resnet50.h5')
cnn_model = load_model('exercise_classification_model.h5')


def _scale_to_unit_range(img_array):
    """Divide pixel values by 255 (the normalization used for the CNN model)."""
    return img_array / 255.0


# Maps each model choice offered in the UI to (model, preprocessing function).
MODELS = {
    'ResNet50': (resnet_model, resnet_preprocess_input),
    'CNN': (cnn_model, _scale_to_unit_range),
}


def predict_exercise(img, model_type):
    """Classify an exercise image with the selected model.

    Args:
        img: The uploaded image as a NumPy array, as delivered by
            ``gr.Image(type="numpy")``; ``None`` when nothing was uploaded.
        model_type: A key of ``MODELS`` (``'ResNet50'`` or ``'CNN'``).

    Returns:
        Dict mapping every label in ``class_labels`` to its predicted score.

    Raises:
        gr.Error: If no image was given, the model type is unknown, prediction
            fails, or the model output does not match ``class_labels``.
            Gradio displays the message in the UI instead of feeding an
            ``{"error": ...}`` dict into the ``gr.Label`` output.
    """
    if img is None:
        raise gr.Error("Please upload an image of the exercise.")
    if model_type not in MODELS:
        raise gr.Error("Invalid model type selected")
    model, preprocess = MODELS[model_type]

    try:
        # convert('RGB') also copes with grayscale (H, W) and RGBA (H, W, 4)
        # arrays, which Image.fromarray(..., 'RGB') would misinterpret.
        pil_img = Image.fromarray(img.astype('uint8')).convert('RGB')
        pil_img = pil_img.resize(IMAGE_SIZE)
        # Add a leading batch axis so the array matches the model's input shape.
        batch = np.expand_dims(keras_image.img_to_array(pil_img), axis=0)
        # verbose=0 keeps Keras from printing a progress bar on every request.
        prediction = model.predict(preprocess(batch), verbose=0)
    except Exception as exc:  # UI boundary: report any failure to the user.
        logger.exception("Prediction with model %s failed", model_type)
        raise gr.Error(f"Prediction failed: {exc}") from exc

    logger.debug("Prediction shape: %s", prediction.shape)
    # Expect exactly one score per class for the single image in the batch.
    if prediction.ndim != 2 or prediction.shape[1] != len(class_labels):
        raise gr.Error(f"Unexpected prediction shape: {prediction.shape}")
    return {label: float(score) for label, score in zip(class_labels, prediction[0])}


# Custom CSS for Dark Mode and Elegant Design
custom_css = """
body {background-color: #121212; color: #e0e0e0; font-family: 'Arial', sans-serif;}
h1 {color: #ff5722;}
label {color: #ff9800;}
input[type=radio] {accent-color: #ff5722;}
button:hover {background-color: #e64a19;}
.footer {display: none !important;}
"""

# Define the Gradio interface
interface = gr.Interface(
    fn=predict_exercise,
    inputs=[
        gr.Image(type="numpy", label="Upload an image of the exercise"),
        # Default selection so the model choice is never empty on submit.
        gr.Radio(list(MODELS), value='ResNet50', label="Choose Model"),
    ],
    outputs=gr.Label(num_top_classes=len(class_labels), label="Prediction"),  # Show all classes
    title="Exercise Classifier",
    description="Upload an image of an exercise and the classifier will predict the exercise.",
    css=custom_css,  # Apply the custom CSS
)

if __name__ == "__main__":
    # Launch the interface
    interface.launch()