import cv2
import os
import numpy as np
from ultralytics import YOLO
import time
import matplotlib.pyplot as plt
import pyttsx3
import datetime
import gradio as gr
import torch
import tempfile
#from fastrtc import Stream
# Global Initializations
device = "cuda" if torch.cuda.is_available() else "cpu"
model = YOLO("yolo11n-pose.pt").to(device)
# Text-to-speech engine.
# Note: on headless hosts (e.g. Hugging Face Spaces) there is often no audio
# backend, so pyttsx3.init() can raise at startup.
engine = pyttsx3.init()
voices = engine.getProperty('voices')
# Guard against systems that expose only a single voice
engine.setProperty('voice', voices[1].id if len(voices) > 1 else voices[0].id)
engine.setProperty('rate', 150)
engine.setProperty('volume', 0.9)
engine.say("Welcome to the Pop-Up Trainer.")
engine.runAndWait()
# Tracking state
state = {
    'warning_timestamp_hip': 0,
    'warning_timestamp_head': 0,
    'hip_warning_given': False,
    'head_warning_given': False,
    'repetitions': 0,
    'position': "down",
    'back_angles': [],
    'head_angles': [],
    'left_knee_angles': [],
    'right_knee_angles': [],
    'frame_numbers': []
}
colors = {
    'default': (255, 0, 0), 'circle': (0, 0, 0),
    'back': (0, 255, 0), 'head': (0, 165, 255)
}
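
# reset_state() was referenced by the original video pipeline but never defined.
# A minimal sketch, assuming the intent is to clear all per-run tracking so
# angle histories and repetition counts do not leak between uploads.
def reset_state():
    state.update({
        'warning_timestamp_hip': 0,
        'warning_timestamp_head': 0,
        'hip_warning_given': False,
        'head_warning_given': False,
        'repetitions': 0,
        'position': "down",
        'back_angles': [],
        'head_angles': [],
        'left_knee_angles': [],
        'right_knee_angles': [],
        'frame_numbers': []
    })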
# Paths to example images & videos
image_folder = "images"
video_folder = "videos"
# Get list of example images & videos, tolerating missing folders so the app still launches
example_images = [os.path.join(image_folder, f) for f in os.listdir(image_folder)
                  if f.lower().endswith((".jpg", ".png", ".jpeg"))] if os.path.isdir(image_folder) else []
example_videos = [os.path.join(video_folder, f) for f in os.listdir(video_folder)
                  if f.lower().endswith((".mp4", ".avi", ".mov"))] if os.path.isdir(video_folder) else []
def get_date_time():
    return datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
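# Illustrative output: get_date_time() -> "20250101_120000", handy for unique filenames.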
def calculate_angle(p1, p2, p3):
    # Angle at vertex p2 formed by segments p2->p1 and p2->p3, in degrees
    if not all(p != (0, 0) for p in [p1, p2, p3]):
        return 0
    v1 = np.array(p1) - np.array(p2)
    v2 = np.array(p3) - np.array(p2)
    cosine_angle = np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))
    return np.degrees(np.arccos(np.clip(cosine_angle, -1.0, 1.0)))
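# Sanity check (illustrative): the vectors from vertex (0, 0) to (0, 1) and to
# (1, 0) are perpendicular, so calculate_angle((0, 1), (0, 0), (1, 0)) ~= 90.0.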
#def process_frame(frame, frame_width, frame_height, fps, conf_threshold):
def process_frame(frame, fps, conf_threshold): #put fps back when plotting
    # stream=True returns a generator, which would break the results[0]
    # indexing below, so single-frame inference runs in the default list mode
    results = model(frame,
                    device=device,
                    imgsz=(480, 640),
                    half=(device == "cuda"), # half precision is CUDA-only
                    conf=conf_threshold,
                    max_det=1,
                    show=False, #turn off for faster inference
                    show_conf=False,
                    save=False,
                    show_boxes=False,
                    save_crop=False
                    )
    processed_frame = results[0].plot()
    # Draw skeleton on image
    #skeleton_image = draw_skeleton(frame, results[0].keypoints.data)
    # Plot angles
    #plot_image = plot_angles(fps)
    # Combine visualizations
    #combined = np.hstack((skeleton_image, plot_image))
    # Stack frames vertically
    #min_width = min(pose_frame.shape[1], combined.shape[1])
    #pose_frame = cv2.resize(pose_frame, (min_width, pose_frame.shape[0]))
    #combined = cv2.resize(combined, (min_width, combined.shape[0]))
    #final_frame = np.vstack((pose_frame, combined))
    return processed_frame
def process_image(image_path, conf):
    image = cv2.imread(image_path)
    if image is None:
        raise gr.Error("Could not read the uploaded image.")
    # Reuse the shared inference path instead of duplicating the model call
    processed_image = process_frame(image, fps=0, conf_threshold=conf)
    return cv2.cvtColor(processed_image, cv2.COLOR_BGR2RGB)
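# Illustrative call (hypothetical path): process_image("images/sample.jpg", 0.3)
# returns an RGB numpy array ready for display in a gr.Image component.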
def process_webcam(conf_threshold):
    cap = cv2.VideoCapture(0)
    if not cap.isOpened():
        raise ValueError("Error opening webcam")
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        processed_frame = process_frame(frame, fps=30, conf_threshold=conf_threshold)
        yield cv2.cvtColor(processed_frame, cv2.COLOR_BGR2RGB)
    cap.release()
def draw_skeleton(image, keypoints):
    skeleton_image = np.full_like(image, 255)
    connections = [(3, 1), (1, 0), (0, 2), (2, 4), (1, 2), (4, 6), (3, 5),
                   (5, 6), (5, 7), (7, 9), (6, 8), (8, 10), (11, 12), (11, 13),
                   (13, 15), (12, 14), (14, 16), (5, 11), (6, 12)]
    keypoints = keypoints.cpu().numpy()
    if keypoints.shape[0] == 0:
        return skeleton_image # no person detected in this frame
    keypoints = keypoints[0]
    # Prefer right-side landmarks when the right hip (keypoint 12) is confident,
    # otherwise fall back to the left side
    hip_idx = 12 if keypoints[12][2] > 0.5 else 11
    shoulder_idx = 6 if keypoints[12][2] > 0.5 else 5
    ear_idx = 4 if keypoints[12][2] > 0.5 else 3
    eye_idx = 2 if keypoints[12][2] > 0.5 else 1
    knee_idx = 14 if keypoints[12][2] > 0.5 else 13
    ankle_idx = 16 if keypoints[12][2] > 0.5 else 15
    for i, (x, y, conf) in enumerate(keypoints):
        if conf > 0.5:
            cv2.circle(skeleton_image, (int(x), int(y)), 12, colors['circle'], -1)
            cv2.putText(skeleton_image, f'{i}', (int(x), int(y)-10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.4, colors['circle'], 1)
    for part_a, part_b in connections:
        x1, y1, conf1 = keypoints[part_a]
        x2, y2, conf2 = keypoints[part_b]
        if conf1 > 0.5 and conf2 > 0.5:
            color = colors['default']
            if (part_a, part_b) == (shoulder_idx, hip_idx):
                color = colors['back']
            elif (part_a, part_b) == (eye_idx, ear_idx):
                color = colors['head']
            cv2.line(skeleton_image, (int(x1), int(y1)), (int(x2), int(y2)), color, 6)
    angles = process_pose(keypoints, hip_idx, shoulder_idx, ear_idx, eye_idx, knee_idx, ankle_idx, skeleton_image)
    update_visualization(skeleton_image, angles, image.shape[0])
    return skeleton_image
def process_pose(keypoints, hip_idx, shoulder_idx, ear_idx, eye_idx, knee_idx, ankle_idx, skeleton_image):
    global state
    angles = {'back': 0, 'head': 0, 'right_knee': 0, 'left_knee': 0}
    # Head angle calculation
    eye, ear = keypoints[eye_idx], keypoints[ear_idx]
    if eye[2] > 0.5 and ear[2] > 0.5:
        angles['head'] = -np.degrees(np.arctan2(eye[1] - ear[1], abs(eye[0] - ear[0])))
        if angles['head'] < 0 and not state['head_warning_given']:
            state['warning_timestamp_head'] = time.time()
            engine.say("Keep your head up")
            engine.runAndWait()
            state['head_warning_given'] = True
    # Back angle calculation
    shoulder, hip = keypoints[shoulder_idx], keypoints[hip_idx]
    if shoulder[2] > 0.5 and hip[2] > 0.5:
        angles['back'] = -np.degrees(np.arctan2(shoulder[1] - hip[1], abs(shoulder[0] - hip[0])))
        if angles['back'] < 0 and not state['hip_warning_given']:
            state['warning_timestamp_hip'] = time.time()
            engine.say("Lower your hips")
            engine.runAndWait()
            state['hip_warning_given'] = True
    # Knee angles calculation
    for side, indices in [('right', (12, 14, 16)), ('left', (11, 13, 15))]:
        hip, knee, ankle = [keypoints[i] for i in indices]
        if all(k[2] > 0.5 for k in [hip, knee, ankle]):
            angles[f'{side}_knee'] = calculate_angle(
                (hip[0], hip[1]), (knee[0], knee[1]), (ankle[0], ankle[1]))
    # Repetition counting: a rep is counted on the up -> down transition.
    # "down" = legs extended (either knee beyond 150 degrees);
    # "up"   = crouched (both knees bent below 110 degrees).
    if angles['right_knee'] > 150 or angles['left_knee'] > 150:
        if state['position'] == "up":
            state['repetitions'] += 1
            state['position'] = "down"
            engine.say(f"Repetitions: {state['repetitions']}")
            engine.runAndWait()
    elif angles['right_knee'] < 110 and angles['left_knee'] < 110:
        state['position'] = "up"
    # Update state with new angles
    for key, value in angles.items():
        state[f'{key}_angles'].append(value)
    state['frame_numbers'].append(len(state['frame_numbers']))
    return angles
def update_visualization(image, angles, height):
    global state
    current_time = time.time()
    # Display warnings
    if current_time - state['warning_timestamp_hip'] < 5:
        cv2.putText(image, 'Lower your hips', (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 0, 255), 3)
    if current_time - state['warning_timestamp_head'] < 5:
        cv2.putText(image, 'Lift your head up', (10, 120), cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 0, 255), 3)
    # Display angles
    y_pos = 180
    for name, value in angles.items():
        cv2.putText(image, f'{name.capitalize()} Angle: {value:.2f}', (10, y_pos),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
        y_pos += 60
    # Display repetitions
    cv2.putText(image, f'Repetitions: {state["repetitions"]}', (10, y_pos),
                cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
    # Add ground line
    ground_y = height - 50
    cv2.line(image, (0, ground_y), (image.shape[1], ground_y), (0, 0, 0), 2)
    overlay = image.copy()
    overlay[ground_y:, :] = (192, 192, 192)
    cv2.addWeighted(overlay, 0.5, image, 0.5, 0, image)
    return cv2.copyMakeBorder(image, 10, 10, 10, 10, cv2.BORDER_CONSTANT, value=(0, 0, 0))
def plot_angles(fps):
    global state
    times = [frame / fps for frame in state['frame_numbers']]
    plt.figure(figsize=(5, 5))
    # Plot all angles
    for angle_type in ['back', 'head', 'left_knee', 'right_knee']:
        angles = state[f'{angle_type}_angles'][:len(times)]
        plt.plot(times, angles, label=f'{angle_type.capitalize()} Angle')
    plt.xlabel('Time (seconds)')
    plt.ylabel('Angle (degrees)')
    plt.legend()
    plt.title('Angles')
    plt.grid(True)
    plt.tight_layout()
    # output_folder was never defined; write to the system temp directory instead
    plot_path = os.path.join(tempfile.gettempdir(), 'plot.png')
    plt.savefig(plot_path)
    plt.close()
    return cv2.imread(plot_path)
def process_video(video_path, conf_threshold):
    reset_state() # clear angle history and rep count from any previous run
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise gr.Error("Could not open video file.")
    frame_width, frame_height = int(cap.get(3)), int(cap.get(4))
    fps = int(cap.get(cv2.CAP_PROP_FPS)) or 30 # fall back if FPS metadata is missing
    temp_output = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
    output_video_path = temp_output.name
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(output_video_path, fourcc, fps, (frame_width, frame_height))
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        # Reuse the shared inference path instead of duplicating the model call
        processed_frame = process_frame(frame, fps=fps, conf_threshold=conf_threshold)
        out.write(processed_frame)
        yield cv2.cvtColor(processed_frame, cv2.COLOR_BGR2RGB) # Stream frame-by-frame
    cap.release()
    out.release()
    # Note: a return value inside a generator is not delivered to Gradio; exposing
    # output_video_path as a download would need a separate gr.File output.
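# Note (assumption about the playback environment): browsers often cannot decode
# "mp4v"-encoded MP4s; if the saved video will not play in a web UI, re-encoding
# to H.264 (e.g. the "avc1" fourcc, where an encoder is available) is a common fix.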
def create_gui():
    with gr.Blocks(title="Pose Trainer") as demo:
        gr.Markdown("# 🏄 Pose Estimation App with Repetition Counter")
        gr.Markdown(
            "**A repetition counter. A repetition starts in the down position when either knee is greater than 150 degrees.**"
            " **Halfway is UP when both knees are less than 110 degrees.**"
            " **The repetition is complete when a knee passes 150 degrees again (down).**"
        )
        with gr.Row():
            with gr.Column():
                with gr.Tab("Image"):
                    #image_input = gr.Image(label="Upload Image", sources="upload", type="image", height=480, width=640, autoplay=True)
                    img_input = gr.File(label="Upload an Image", type="filepath")
                    img_conf = gr.Slider(label="Confidence Threshold", minimum=0.0, maximum=1.0, step=0.05, value=0.30)
                    img_output = gr.Image(label="Processed Image")
                    gr.Examples(examples=example_images, inputs=img_input)
                    img_button = gr.Button("Process Image")
                    img_button.click(fn=process_image, inputs=[img_input, img_conf], outputs=img_output)
                with gr.Tab("Upload Video"):
                    vid_input = gr.File(label="Upload a Video", type="filepath")
                    vid_conf = gr.Slider(label="Confidence Threshold", minimum=0.0, maximum=1.0, step=0.05, value=0.30)
                    # Frames are yielded one at a time, so they render in an Image component
                    vid_output = gr.Image(label="Processed Video Stream")
                    gr.Examples(examples=example_videos, inputs=vid_input)
                    vid_button = gr.Button("Process Video")
                    vid_button.click(process_video, inputs=[vid_input, vid_conf], outputs=vid_output)
                with gr.Tab("Use Webcam"):
                    #webcam_input = gr.Video(label="Webcam", sources="webcam", interactive=True, streaming=True, format="mp4", height=480, width=640, autoplay=True, show_download_button=True)
                    webcam_conf = gr.Slider(label="Confidence Threshold", minimum=0.0, maximum=1.0, step=0.05, value=0.30)
                    webcam_output = gr.Image(label="Live Processed Webcam Stream")
                    webcam_button = gr.Button("Start Webcam")
                    webcam_button.click(process_webcam, inputs=webcam_conf, outputs=webcam_output)
        gr.Markdown("Developed for **Hugging Face Spaces** | YOLOv11n-Pose + OpenCV | **GPU Acceleration Supported**")
    return demo
if __name__ == "__main__":
    gui = create_gui()
    gui.launch()