# kia_gym / app.py
# Author: Abedinho. Last commit: "Update app.py" (6449c65, verified)
import gradio as gr
from PIL import Image
import numpy as np
from tensorflow.keras.preprocessing import image as keras_image
from tensorflow.keras.applications.resnet50 import preprocess_input as resnet_preprocess_input
from tensorflow.keras.models import load_model
# Class labels (replace with your own classes). Position i must match
# output unit i of both models, because predictions are mapped by index.
class_labels = [
    'bench_press',
    'deadlift',
    'hip_thrust',
    'lat_pulldown',
    'pull_up',
    'squat',
    'tricep_dips',
]
# Load both trained Keras models once, at import time, so that every call to
# predict_exercise reuses them. The ResNet50-based model is loaded first.
resnet_model, cnn_model = map(
    load_model,
    (
        'exercise_classification_model_resnet50.h5',
        'exercise_classification_model.h5',
    ),
)
_INPUT_SIZE = (150, 150)  # (width, height) passed to PIL's Image.resize for both models


def _to_batch(img):
    """Turn a Gradio numpy image into a float32 batch holding one RGB image."""
    # Let PIL infer the mode from the array shape (grayscale, RGB or RGBA) and
    # then normalise to RGB. Forcing mode='RGB' onto a 2-D or 4-channel buffer
    # can fail or misread the pixel data.
    pil_img = Image.fromarray(np.asarray(img).astype('uint8')).convert('RGB')
    pil_img = pil_img.resize(_INPUT_SIZE)
    # img_to_array gives (height, width, channels); add the batch axis.
    return np.expand_dims(keras_image.img_to_array(pil_img), axis=0)


def _scale_to_unit_range(batch):
    """Preprocessing for the custom CNN: map 0..255 pixel values to 0..1."""
    return batch / 255.0


def predict_exercise(img, model_type):
    """Classify an exercise image with the selected model.

    Args:
        img: Image array as delivered by ``gr.Image(type="numpy")``; ``None``
            when the user submits without uploading an image.
        model_type: ``'ResNet50'`` or ``'CNN'`` (the radio-button choices);
            ``None`` when nothing is selected.

    Returns:
        dict mapping every entry of ``class_labels`` to its predicted
        probability as a Python float, the format ``gr.Label`` renders.

    Raises:
        gr.Error: on a missing image, an unknown model type, an inference
            failure, or an output whose width does not match
            ``class_labels``. Gradio shows the message in the UI. Returning
            ``{"error": "..."}`` instead cannot be rendered by ``gr.Label``,
            since its value is not a number.
    """
    if img is None:
        raise gr.Error("Please upload an image first.")

    # Pick the model together with the preprocessing it was trained with.
    if model_type == 'ResNet50':
        model, preprocess = resnet_model, resnet_preprocess_input
    elif model_type == 'CNN':
        model, preprocess = cnn_model, _scale_to_unit_range
    else:
        raise gr.Error("Invalid model type selected")

    try:
        batch = preprocess(_to_batch(img))
        # verbose=0 suppresses the per-call progress bar in the server log.
        prediction = np.asarray(model.predict(batch, verbose=0))
    except Exception as e:  # UI boundary: surface any failure to the user
        raise gr.Error(str(e)) from e

    # Expect one row of per-class scores, one score per label.
    if prediction.ndim != 2 or prediction.shape[1] != len(class_labels):
        raise gr.Error(f"Unexpected prediction shape: {prediction.shape}")
    return {label: float(score) for label, score in zip(class_labels, prediction[0])}
# Custom CSS for the dark theme and accent colours, applied via gr.Interface(css=...).
# The last rule hides the page footer. The `footer` element selector is listed
# next to the original `.footer` class because Gradio's built-in footer appears
# to be a <footer> element without that class.
# NOTE(review): confirm against the deployed Gradio version.
custom_css = """
body {background-color: #121212; color: #e0e0e0; font-family: 'Arial', sans-serif;}
h1 {color: #ff5722;}
label {color: #ff9800;}
input[type=radio] {accent-color: #ff5722;}
button:hover {background-color: #e64a19;}
footer, .footer {display: none !important;}
"""
# Gradio UI: an image upload plus a model selector feed predict_exercise, and
# the per-class probabilities are shown in a Label.
interface = gr.Interface(
    fn=predict_exercise,
    inputs=[
        gr.Image(type="numpy", label="Upload an image of the exercise"),
        # Preselect a model. Otherwise the radio starts unselected, and a first
        # submit sends None, which predict_exercise rejects as an invalid model type.
        gr.Radio(['ResNet50', 'CNN'], value='ResNet50', label="Choose Model"),
    ],
    outputs=gr.Label(num_top_classes=len(class_labels), label="Prediction"),  # show all classes
    title="Exercise Classifier",
    description="Upload an image of an exercise and the classifier will predict the exercise.",
    css=custom_css,  # apply the custom dark theme
)

# Launch the interface.
interface.launch()