|
import logging

import gradio as gr
import numpy as np
from PIL import Image
from tensorflow.keras.applications.resnet50 import preprocess_input as resnet_preprocess_input
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image as keras_image
|
|
|
|
|
# Exercise categories the classifiers can predict. predict_exercise pairs
# class_labels[i] with output unit i of the model's prediction, so this order
# is assumed to match the label order used at training time (it is
# alphabetical here — presumably from a sorted class-directory listing; confirm).
class_labels = [
    'bench_press',
    'deadlift',
    'hip_thrust',
    'lat_pulldown',
    'pull_up',
    'squat',
    'tricep_dips',
]
|
|
|
|
|
# Load both trained Keras models once, at import time, so every prediction
# request reuses the same in-memory models. The paths are relative, so they
# resolve against the current working directory. The ResNet50-based model is
# loaded first, then the plain CNN.
resnet_model, cnn_model = (
    load_model(model_path)
    for model_path in (
        'exercise_classification_model_resnet50.h5',
        'exercise_classification_model.h5',
    )
)
|
|
|
def predict_exercise(img, model_type, target_size=(150, 150)):
    """Classify an exercise photo with the selected model.

    Args:
        img: Image array from the ``gr.Image(type="numpy")`` input, or
            ``None`` when nothing has been uploaded.
        model_type: Either ``'ResNet50'`` or ``'CNN'``; any other value
            (including ``None`` from an unselected radio) is rejected.
        target_size: ``(width, height)`` passed to ``PIL.Image.resize``
            before inference. Defaults to the original hard-coded 150x150;
            it must match the input size the chosen model was built for.

    Returns:
        ``{label: score}`` for every entry of ``class_labels`` (scores taken
        from the first row of the model output), or ``{"error": message}``
        when the input is missing/invalid, the output shape does not match
        ``class_labels``, or any step of inference raises.

        NOTE(review): the result feeds a ``gr.Label``, which expects numeric
        confidences; confirm the ``{"error": ...}`` dict renders acceptably,
        or switch to raising ``gr.Error`` for user-facing failures.
    """
    # Reject bad requests before touching PIL/Keras. Previously a missing
    # upload surfaced as "'NoneType' object has no attribute 'astype'".
    if img is None:
        return {"error": "No image provided"}
    if model_type not in ('ResNet50', 'CNN'):
        return {"error": "Invalid model type selected"}

    try:
        # Let PIL infer the mode from the array shape, then convert: this
        # accepts grayscale (H, W) and RGBA (H, W, 4) arrays instead of
        # forcing 'RGB' onto a buffer with a different channel count.
        pil_img = Image.fromarray(np.asarray(img).astype('uint8')).convert('RGB')
        pil_img = pil_img.resize(target_size)

        # Add the leading batch axis the Keras models are called with.
        batch = np.expand_dims(keras_image.img_to_array(pil_img), axis=0)

        # Each model gets its own preprocessing (presumably the one it was
        # trained with): ResNet50's preprocess_input vs. plain 0-1 scaling.
        if model_type == 'ResNet50':
            prediction = resnet_model.predict(resnet_preprocess_input(batch))
        else:
            prediction = cnn_model.predict(batch / 255.0)

        logging.debug("Prediction shape: %s", prediction.shape)

        # Expect a (batch, num_classes) score matrix aligned with class_labels.
        if prediction.ndim != 2 or prediction.shape[1] != len(class_labels):
            return {"error": f"Unexpected prediction shape: {prediction.shape}"}

        return {label: float(score) for label, score in zip(class_labels, prediction[0])}
    except Exception as e:
        # Top-level UI boundary: keep the user-facing message short, but
        # preserve the full traceback in the server log.
        logging.exception("Exercise prediction failed")
        return {"error": str(e)}
|
|
|
|
|
# Extra stylesheet passed to gr.Interface(css=...): a dark page background
# with light text, and orange accents for the title, labels, radio buttons,
# and the button hover state.
# NOTE(review): Gradio presumably renders its footer as a <footer> element;
# confirm the `.footer` class selector actually matches in the installed
# Gradio version, otherwise the footer is not hidden.
# NOTE(review): Gradio themes style their own container; verify the `body`
# rule is not overridden by the active theme.
custom_css = """

body {background-color: #121212; color: #e0e0e0; font-family: 'Arial', sans-serif;}

h1 {color: #ff5722;}

label {color: #ff9800;}

input[type=radio] {accent-color: #ff5722;}

button:hover {background-color: #e64a19;}

.footer {display: none !important;}

"""
|
|
|
|
|
# Gradio UI: an image upload and a model selector are passed to
# predict_exercise, and the returned {label: score} dict is shown as a ranked
# gr.Label with every class visible.
interface = gr.Interface(
    fn=predict_exercise,
    inputs=[
        gr.Image(type="numpy", label="Upload an image of the exercise"),
        # Pre-select a model: an untouched Radio otherwise sends None as
        # model_type, which predict_exercise rejects as an invalid model.
        gr.Radio(['ResNet50', 'CNN'], value='ResNet50', label="Choose Model"),
    ],
    outputs=gr.Label(num_top_classes=len(class_labels), label="Prediction"),
    title="Exercise Classifier",
    description="Upload an image of an exercise and the classifier will predict the exercise.",
    css=custom_css,
)

# Start the (blocking) web server only when run as a script, so importing this
# module, e.g. from tests, does not launch it.
if __name__ == "__main__":
    interface.launch()
|
|