# PopUpTrainer / app.py
# %%
# %pip install mediapipe gradio plotly matplotlib gtts pygame opencv-python ultralytics torch --quiet
# Import packages
import gradio as gr
import cv2
import mediapipe as mp
from mediapipe import solutions
from ultralytics import YOLO
import numpy as np
import plotly.graph_objects as go
import matplotlib.pyplot as plt
import math
from gtts import gTTS
from pygame import mixer
import os
from datetime import datetime
import logging
import io
import torch
import time
import tempfile
# Initialize logging, then raise the root level above CRITICAL to silence all log output
logging.basicConfig(level=logging.INFO)
logging.getLogger().setLevel(logging.CRITICAL + 1)
logger = logging.getLogger(__name__)
# Check for GPU availability
device = "cuda:0" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
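# Load the YOLO pose model ('model' is not referenced again in this file; the analysis below uses MediaPipe)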
model = YOLO("yolo11/yolo11n-pose.pt").to(device)
confidence = 0.5
# Initialize MediaPipe Pose model
mp_drawing = mp.solutions.drawing_utils
mp_pose = mp.solutions.pose
mp_drawing_styles = solutions.drawing_styles
pose_image = mp_pose.Pose(
static_image_mode=True,
model_complexity=2,
smooth_landmarks=False,
min_detection_confidence=confidence,
min_tracking_confidence=confidence,
enable_segmentation=False
)
pose_video = mp_pose.Pose(
static_image_mode=False,
model_complexity=2,
smooth_landmarks=True,
min_detection_confidence=confidence,
min_tracking_confidence=confidence,
enable_segmentation=False
)
pose_webcam = mp_pose.Pose(
static_image_mode=False,
model_complexity=2,
smooth_landmarks=True,
min_detection_confidence=confidence,
min_tracking_confidence=confidence,
enable_segmentation=False
)
output_folder = "output"
if not os.path.exists(output_folder):
os.makedirs(output_folder)
# Paths to example images & videos
image_folder = "images"
video_folder = "videos"
example_images = [os.path.join(image_folder, f) for f in os.listdir(image_folder) if f.endswith((".jpg", ".png", ".jpeg"))] if os.path.exists(image_folder) else []
example_videos = [os.path.join(video_folder, f) for f in os.listdir(video_folder) if f.endswith((".mp4", ".avi", ".mov"))] if os.path.exists(video_folder) else []
# Global constants
FRAMERATE = 30
frame_reduction = 5
reduced_FRAMERATE = int(FRAMERATE / frame_reduction)
repetitions = 0
visible_side = "right"
previous_knee_forward = False
warnings_issued = {"look_up": False, "lower_hips": False}
# Initialize webcam control variables
is_webcam_running = False
webcam_paused = False
stop_webcam = False
status_text_webcam = None
start_btn_webcam = None
stop_btn_webcam = None
pause_btn_webcam = None
resume_btn_webcam = None
reps_counter_webcam = None
last_skeleton_3d = None
last_frame = None
try:
if not mixer.get_init():
mixer.init()
except Exception as e:
logging.warning(f"Failed to initialize pygame mixer: {e}")
def speak_message(message):
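    """Convert *message* to speech with gTTS and play it through the pygame mixer (blocks until playback ends)."""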
try:
# Convert text to speech using gTTS
tts = gTTS(text=message, lang="en")
# Save audio to an in-memory buffer
audio_buffer = io.BytesIO()
tts.write_to_fp(audio_buffer)
audio_buffer.seek(0) # Reset the buffer pointer to the beginning
# Load and play the audio directly from the buffer
sound = mixer.Sound(audio_buffer)
sound.play()
# Wait until the playback finishes
while mixer.get_busy():
time.sleep(0.1)
except Exception as e:
logging.warning(f"Speech synthesis failed: {e}")
def calculate_angle(a, b, c):
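    """Return the angle in degrees at vertex b, formed by the 2-D points a, b and c."""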
ba = np.array(a) - np.array(b)
bc = np.array(c) - np.array(b)
cosine_angle = np.dot(ba, bc) / (np.linalg.norm(ba) * np.linalg.norm(bc))
angle = np.arccos(np.clip(cosine_angle, -1.0, 1.0)) # Avoid invalid values
return np.degrees(angle)
def angle_with_vertical(point1, point2):
"""
Calculate the angle between the line defined by two points and the vertical axis in degrees.
"""
dx = point2[0] - point1[0]
dy = point2[1] - point1[1]
    angle_radians = math.atan2(dx, dy)
    angle_degrees = abs(math.degrees(angle_radians))  # ensure a positive angle
    return angle_degrees
def create_3d_skeleton(skeleton_coords):
"""Create interactive 3D skeleton plot"""
valid_coords = [coord for coord in skeleton_coords if len(coord) == 3]
x, y, z = zip(*valid_coords)
x = [-xi for xi in x] # Negate x-coordinates to rotate 180 degrees
z = [-zi for zi in z] # Negate z-coordinates for consistency
# Add joints with standard MediaPipe landmark color
fig = go.Figure()
fig.add_trace(go.Scatter3d(
x=x, y=y, z=z,
mode='markers',
marker=dict(
size=6,
            color='rgb(227, 34, 20)',  # landmark joints
),
name='3D Skeleton'
))
connections = mp_pose.POSE_CONNECTIONS
for connection in connections:
if connection[0] < len(x) and connection[1] < len(x):
fig.add_trace(go.Scatter3d(
x=[x[connection[0]], x[connection[1]]],
y=[y[connection[0]], y[connection[1]]],
z=[z[connection[0]], z[connection[1]]],
mode='lines',
line=dict(
color='rgb(3, 252, 244)', # connections
width=4
),
showlegend=False
))
    # Calculate axis ranges for consistent scaling (max_range is computed but not currently used)
x_values = np.array(x)
y_values = np.array(y)
z_values = np.array(z)
max_range = max(max(x_values) - min(x_values),
max(y_values) - min(y_values),
max(z_values) - min(z_values))
fig.update_layout(
scene=dict(
xaxis_title="Horizontal (X)",
yaxis_title="Vertical (Y)",
zaxis_title="Depth (Z)",
camera=dict(
up=dict(x=0, y=-1, z=0), # Define the 'up' direction (Y-axis is up)
eye=dict(
x=0,
y=0,
z=2 # Position the camera along the Z-axis
),
),
xaxis=dict(showgrid=False, showticklabels=False, visible=False),
yaxis=dict(showgrid=False, showticklabels=False, visible=False),
zaxis=dict(showgrid=False, showticklabels=False, visible=False)
),
margin=dict(l=0, r=0, b=0, t=0),
# scene_camera=dict(projection=dict(type='orthographic')),
template='plotly_dark'
)
return fig
def generate_plot_frame(data, current_frame):
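    """Plot the recorded joint angles up to *current_frame*, save the figure as a PNG and return its path."""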
global FRAMERATE
plt.figure(figsize=(10, 5))
time_seconds = np.array(data["frame"]) / FRAMERATE
plt.plot(time_seconds[:current_frame + 1], data["back_angle"][:current_frame + 1], label="Back Angle")
plt.plot(time_seconds[:current_frame + 1], data["head_angle"][:current_frame + 1], label="Head Angle")
plt.plot(time_seconds[:current_frame + 1], data["left_knee_angle"][:current_frame + 1], label="Left Knee Angle")
plt.plot(time_seconds[:current_frame + 1], data["right_knee_angle"][:current_frame + 1], label="Right Knee Angle")
plt.xlabel("Time (s)")
plt.ylabel("Angle (°)")
plt.legend()
plt.grid(True)
plt.tight_layout()
plot_path = f"output/pose_plot-{datetime.now().strftime('%Y%m%d')}.png"
plt.savefig(plot_path)
plt.close()
return plot_path
def check_repetition(knee, hip):
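    """Count one repetition each time the tracked knee moves from behind to in front of the hip along the x-axis."""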
global previous_knee_forward, repetitions, visible_side
    if visible_side == "right":
        knee_forward = knee.x > hip.x
    elif visible_side == "left":
        knee_forward = knee.x < hip.x
    else:
        knee_forward = previous_knee_forward  # unknown side: keep the previous state
if not previous_knee_forward and knee_forward:
repetitions += 1
previous_knee_forward = knee_forward
return None
def process_frame(frame, angles_data, frame_idx, media):
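    """Run pose estimation on one frame and return (annotated frame, 3-D skeleton figure, updated angles_data)."""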
global repetitions, visible_side, warnings_issued
skeleton_coords = []
try:
logging.info(f"Processing frame {frame_idx}")
if media == "image":
results = pose_image.process(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
elif media == "video":
results = pose_video.process(frame)
elif media == "webcam":
results = pose_webcam.process(frame)
else:
raise ValueError("Unsupported media type.")
black_frame = np.zeros_like(frame)
if not results.pose_landmarks:
logging.warning("No pose landmarks detected.")
return black_frame, None, None, angles_data
landmarks = results.pose_landmarks.landmark
skeleton_coords = [[lm.x, lm.y, lm.z] for lm in landmarks]
skeleton_coords = np.array(skeleton_coords)
# (safe extraction of landmarks)
left_hip = landmarks[mp_pose.PoseLandmark.LEFT_HIP.value]
left_knee = landmarks[mp_pose.PoseLandmark.LEFT_KNEE.value]
left_ankle = landmarks[mp_pose.PoseLandmark.LEFT_ANKLE.value]
right_hip = landmarks[mp_pose.PoseLandmark.RIGHT_HIP.value]
right_knee = landmarks[mp_pose.PoseLandmark.RIGHT_KNEE.value]
right_ankle = landmarks[mp_pose.PoseLandmark.RIGHT_ANKLE.value]
eye = (
landmarks[mp_pose.PoseLandmark.RIGHT_EYE.value]
if landmarks[mp_pose.PoseLandmark.RIGHT_EYE.value].visibility > 0.5
else landmarks[mp_pose.PoseLandmark.LEFT_EYE.value]
)
ear = (
landmarks[mp_pose.PoseLandmark.RIGHT_EAR.value]
if landmarks[mp_pose.PoseLandmark.RIGHT_EAR.value].visibility > 0.5
else landmarks[mp_pose.PoseLandmark.LEFT_EAR.value]
)
knee = (
right_knee
if visible_side == "right" and right_knee.visibility > 0.5
else left_knee
)
hip = right_hip if right_hip.visibility > 0.5 else left_hip
shoulder = (
landmarks[mp_pose.PoseLandmark.RIGHT_SHOULDER.value]
if landmarks[mp_pose.PoseLandmark.RIGHT_SHOULDER.value].visibility > 0.5
else landmarks[mp_pose.PoseLandmark.LEFT_SHOULDER.value]
)
# Angles
head_angle = angle_with_vertical((eye.x, eye.y), (ear.x, ear.y))
back_angle = angle_with_vertical((shoulder.x, shoulder.y), (hip.x, hip.y))
left_knee_angle = calculate_angle((left_hip.x, left_hip.y), (left_knee.x, left_knee.y), (left_ankle.x, left_ankle.y))
right_knee_angle = calculate_angle((right_hip.x, right_hip.y), (right_knee.x, right_knee.y), (right_ankle.x, right_ankle.y))
check_repetition(knee, hip)
# Draw head angle line
start_point = (int(eye.x * frame.shape[1]), int(eye.y * frame.shape[0]))
if visible_side == "right":
end_point = (
int(start_point[0] + 100 * np.cos(np.radians(90 - head_angle))),
int(start_point[1] - 100 * np.sin(np.radians(90 - head_angle))),
)
else: # visible_side == "left"
end_point = (
int(start_point[0] - 100 * np.cos(np.radians(90 - head_angle))),
int(start_point[1] + 100 * np.sin(np.radians(90 - head_angle))), # Adjusted for left side
)
cv2.arrowedLine(black_frame, start_point, end_point, (0, 0, 255), 2)
# Draw feedback text
cv2.putText(black_frame, f"Back: {back_angle:.1f}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
cv2.putText(black_frame, f"Head: {head_angle:.1f}", (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
cv2.putText(black_frame, f"Left Knee: {left_knee_angle:.1f}", (10, 90), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
cv2.putText(black_frame, f"Right Knee: {right_knee_angle:.1f}", (10, 120), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
cv2.putText(black_frame, f"Repetitions: {repetitions}", (10, 150), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
# Warnings
if head_angle > 110 and not warnings_issued["look_up"]:
speak_message("Look up")
warnings_issued["look_up"] = True
if back_angle > 90 and not warnings_issued["lower_hips"]:
speak_message("Lower your hips")
warnings_issued["lower_hips"] = True
# Draw pose
mp_drawing.draw_landmarks(
black_frame,
results.pose_landmarks,
mp_pose.POSE_CONNECTIONS,
landmark_drawing_spec=mp_drawing.DrawingSpec(color=(255, 255, 0), thickness=2, circle_radius=2)
)
# Append angle data
angles_data["frame"].append(frame_idx)
angles_data["back_angle"].append(back_angle)
angles_data["head_angle"].append(head_angle)
angles_data["left_knee_angle"].append(left_knee_angle)
angles_data["right_knee_angle"].append(right_knee_angle)
        black_frame = cv2.cvtColor(black_frame, cv2.COLOR_BGR2RGB)  # overlays were drawn in BGR; convert for display
skeleton_3d = create_3d_skeleton(skeleton_coords)
# plot_frame = generate_plot_frame(angles_data, frame_idx)
return black_frame, skeleton_3d, angles_data
except Exception as e:
logging.error(f"Exception during frame processing: {e}", exc_info=True)
return frame, None, angles_data
def process_image(image_path):
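    """Analyse a single image (file path or RGB array) and return the annotated frame and 3-D skeleton figure."""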
global warnings_issued, repetitions, previous_knee_forward
media = "image"
repetitions = 0
previous_knee_forward = False
warnings_issued = {"look_up": False, "lower_hips": False}
angles_data = {"frame": [], "back_angle": [], "head_angle": [], "left_knee_angle": [], "right_knee_angle": []}
try:
if isinstance(image_path, str):
frame = cv2.imread(image_path)
elif isinstance(image_path, np.ndarray):
frame = cv2.cvtColor(image_path, cv2.COLOR_RGB2BGR)
else:
raise ValueError("Unsupported image input type.")
if frame is None:
raise ValueError("Failed to load image.")
black_frame, skeleton_3d, _ = process_frame(frame, angles_data, 0, media)
return black_frame, skeleton_3d
except Exception as e:
logging.error(f"Error processing image: {e}", exc_info=True)
return None, None
def process_video(video_path=None):
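    """Generator: analyse a video frame by frame, yielding the 3-D skeleton and, finally, the rendered video and angle plot."""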
global fps, warnings_issued, repetitions, previous_knee_forward, frame_reduction, reduced_FRAMERATE
angles_data = {"frame": [], "back_angle": [], "head_angle": [], "left_knee_angle": [], "right_knee_angle": []}
frames = []
repetitions = 0
previous_knee_forward = False
warnings_issued = {"look_up": False, "lower_hips": False}
# Video capture setup
    cap = cv2.VideoCapture(video_path if video_path else 1)  # fall back to camera index 1 when no file is given
if video_path:
media = "video"
else:
media = "webcam"
original_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
original_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = int(cap.get(cv2.CAP_PROP_FPS)) or FRAMERATE  # fall back when the source reports no FPS
    reduced_FRAMERATE = max(1, int(fps / frame_reduction))  # reduce the frame rate
fourcc = cv2.VideoWriter_fourcc(*'mp4v') # MP4 format
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
output_path = temp_file.name
# pose_video = f"output/pose_video-{datetime.now().strftime('%Y%m%d%H%M')}.mp4"
out = cv2.VideoWriter(output_path, fourcc, reduced_FRAMERATE, (original_width, original_height))
    try:
        frame_idx = 0
        skeleton_3d, angle_plot = None, None  # defaults in case no frame is processed
while cap.isOpened():
ret, frame = cap.read()
if not ret:
break
# Process only every nth frame
if frame_idx % frame_reduction != 0:
frame_idx += 1
continue
black_frame, skeleton_3d, angles_data = process_frame(frame, angles_data, frame_idx, media)
            # Write the annotated frame to the video (VideoWriter expects BGR)
            out.write(cv2.cvtColor(black_frame, cv2.COLOR_RGB2BGR))
angle_plot = generate_plot_frame(angles_data, frame_idx)
yield (skeleton_3d, None, None)
frame_idx += 1
finally:
cap.release()
out.release()
    # Final yield with the completed video and angle plot
    yield (skeleton_3d,
           gr.update(visible=True, value=output_path),  # show the rendered video file
           gr.update(visible=True, value=angle_plot))   # show the angle plot
def start_webcam():
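    """Flag the webcam session as running and enable the pause/stop controls."""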
global is_webcam_running, webcam_paused, stop_webcam
if not is_webcam_running:
stop_webcam = False
webcam_paused = False
is_webcam_running = True
return (
gr.update(interactive=False),
gr.update(interactive=True),
gr.update(interactive=True),
gr.update(value="Webcam started")
)
return (
gr.update(interactive=False),
gr.update(interactive=True),
gr.update(interactive=True),
gr.update(value="Webcam already running")
)
def pause_webcam():
global webcam_paused
webcam_paused = True
return (
gr.update(interactive=False),
gr.update(interactive=True),
gr.update(value="Webcam paused")
)
def resume_webcam():
global webcam_paused
webcam_paused = False
return (
gr.update(interactive=True),
gr.update(interactive=False),
gr.update(value="Webcam resumed")
)
def stop_webcam_action():
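    """Signal the webcam loop to stop without clearing the existing outputs."""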
global is_webcam_running, stop_webcam
stop_webcam = True
is_webcam_running = False
# Do not reset outputs, just update the status
return (
gr.update(interactive=False), # Disable the stop button
gr.update(value="Webcam stopped") # Update status text
)
def process_webcam_feed():
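    """Generator: stream webcam frames through pose analysis until stopped, yielding live previews and a final recording."""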
global fps, warnings_issued, repetitions, previous_knee_forward, frame_reduction, reduced_FRAMERATE, is_webcam_running, webcam_paused, reps_counter_webcam, stop_webcam, last_skeleton_3d, last_frame
angles_data = {"frame": [], "back_angle": [], "head_angle": [], "left_knee_angle": [], "right_knee_angle": []}
media = "webcam"
repetitions = 0
last_skeleton_3d = None
last_frame = None
previous_knee_forward = False
warnings_issued = {"look_up": False, "lower_hips": False}
angle_plot_webcam = None
# Video capture setup
    cap = cv2.VideoCapture(1)  # camera index 1; use 0 for the default webcam
    if not cap.isOpened():
        logging.warning("Unable to open the webcam.")
        yield (None, None, None, None, 0)
        return
original_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
original_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = int(cap.get(cv2.CAP_PROP_FPS)) or FRAMERATE  # fall back when the camera reports no FPS
    reduced_FRAMERATE = max(1, int(fps / frame_reduction))
    # Video writer setup with the H.264 codec (renamed to avoid shadowing the global MediaPipe pose_video)
    pose_video_path = f"output/pose_video-{datetime.now().strftime('%Y%m%d%H%M')}.mp4"
    out = cv2.VideoWriter(pose_video_path, cv2.VideoWriter_fourcc(*'avc1'), reduced_FRAMERATE, (original_width, original_height))
try:
frame_idx = 0
while is_webcam_running and not stop_webcam:
if webcam_paused:
time.sleep(0.1)
continue
ret, frame = cap.read()
if not ret:
break
            # Process only every nth frame
            if frame_idx % frame_reduction != 0:
                frame_idx += 1
                continue
            processed_idx = len(angles_data["frame"])  # index among processed frames only
            black_frame, skeleton_3d, angles_data = process_frame(frame, angles_data, processed_idx, "webcam")
            # Write the annotated frame to the video (VideoWriter expects BGR)
            out.write(cv2.cvtColor(black_frame, cv2.COLOR_RGB2BGR))
            angle_plot_webcam = generate_plot_frame(angles_data, processed_idx)
# Store the last skeleton and frame
last_skeleton_3d = skeleton_3d # Store the figure object directly
last_frame = frame # Store the numpy array directly
            yield (cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) if isinstance(frame, np.ndarray) else None,
skeleton_3d,
None,
None,
repetitions
)
frame_idx += 1
finally:
cap.release()
out.release()
# if torch.cuda.is_available():
# torch.cuda.empty_cache()
# Final yield with completed video
if isinstance(last_frame, np.ndarray):
yield (
            cv2.cvtColor(last_frame, cv2.COLOR_BGR2RGB),
last_skeleton_3d,
            gr.update(visible=True, value=pose_video_path),
gr.update(visible=True, value=angle_plot_webcam),
repetitions
)
else:
logging.warning("No valid frames were captured during the webcam session.")
yield (
None,
None,
gr.update(visible=False),
gr.update(visible=False),
repetitions
)
# Build the Gradio interface
def build_gradio_interface():
"""
Build Gradio interface with updated logic.
"""
global status_text_webcam, start_btn_webcam, stop_btn_webcam, webcam_output_webcam, webcam_3d_view_webcam, webcam_plot_webcam, reps_counter_webcam, pause_btn_webcam, resume_btn_webcam
    with gr.Blocks(title="Surf Pop-Up Trainer") as app:
        speak_message("Welcome to the Pop-Up Trainer")  # spoken once, when the interface is built
gr.Image(type="filepath", value='./Logo/Logo.jpg', show_label=False, show_download_button=False, show_fullscreen_button=False)
gr.Markdown("<h1 style='text-align: center;'>A Surfing Pop-Up Trainer</h1>")
        gr.Markdown(
            "<h2 style='text-align: center;'>"
            "©Schmied Research And Development Pty Ltd. Dr Steven Schmied. Email sschmie@tpg.com.au<br>"
            "<i>Seeking only peace and friendship, to teach if we are called upon, to be taught if we are fortunate</i> - Voyager Golden Record<br>"
            "To analyse your pop-up, use your webcam or upload an image or video. The analysis takes a few seconds to process.<br>"
            "Hold down the left mouse button and drag to rotate the three-dimensional skeleton. Scroll the mouse wheel to zoom.<br>"
            "For a good pop-up: 1. keep your hips lower than your shoulders, 2. do not look down, and 3. bend your knees. A repetition is counted when your knee bends to less than 100 degrees.</h2>")
with gr.Tabs():
with gr.Tab("Image Analysis"):
with gr.Row():
image_input = gr.Image(type="filepath", show_label=False, show_download_button=False, show_fullscreen_button=False)
image_output_2d = gr.Image(show_label=False)
skeleton_3d_img = gr.Plot(show_label=False)
image_input.change(
fn=process_image,
inputs=image_input,
outputs=[image_output_2d, skeleton_3d_img]
)
with gr.Row():
if example_images:
gr.Examples(
examples=example_images,
inputs=image_input,
outputs=[image_output_2d, skeleton_3d_img]
)
with gr.Tab("Video Analysis"):
with gr.Row():
vid_input = gr.Video(show_label=False, sources=["upload", "webcam"])
skeleton_3d_vid = gr.Plot(show_label=False)
with gr.Row():
pose_video = gr.Video(
show_label=False,
show_download_button=True,
show_share_button=True,
format="mp4",
autoplay=True,
visible=False,
interactive=False
)
                    angle_plot = gr.Image(show_label=False, visible=False)
vid_input.change(
fn=process_video,
inputs=vid_input,
outputs=[skeleton_3d_vid, pose_video, angle_plot]
)
with gr.Row():
if example_videos:
gr.Examples(
examples=example_videos,
inputs=vid_input,
outputs=[skeleton_3d_vid, pose_video, angle_plot]
)
with gr.Tab("Live Webcam Analysis"):
with gr.Row():
with gr.Row():
webcam_output = gr.Image(show_label=False, interactive=False)
                        skeleton_3d_webcam = gr.Plot(show_label=False)
with gr.Row():
pose_video_webcam = gr.Video(
show_label=False,
show_download_button=True,
show_share_button=True,
format="mp4",
autoplay=True,
visible=False,
interactive=False
)
                    angle_plot_webcam = gr.Image(show_label=False, visible=False)
with gr.Row():
reps_counter_webcam = gr.Number(label="Repetitions", value=0, interactive=False)
status_text_webcam = gr.Textbox(label="Status", interactive=False, value="Webcam ready")
with gr.Row():
start_btn_webcam = gr.Button("Start Webcam", variant="primary")
pause_btn_webcam = gr.Button("Pause", variant="secondary", interactive=False)
resume_btn_webcam = gr.Button("Resume", variant="secondary", interactive=False)
stop_btn_webcam = gr.Button("Stop", variant="secondary", interactive=False)
start_btn_webcam.click(
start_webcam,
outputs=[start_btn_webcam, pause_btn_webcam, stop_btn_webcam, status_text_webcam]
).then(
process_webcam_feed,
outputs=[webcam_output, skeleton_3d_webcam, pose_video_webcam, angle_plot_webcam, reps_counter_webcam]
)
pause_btn_webcam.click(
pause_webcam,
outputs=[pause_btn_webcam, resume_btn_webcam, status_text_webcam]
)
resume_btn_webcam.click(
resume_webcam,
outputs=[pause_btn_webcam, resume_btn_webcam, status_text_webcam]
)
stop_btn_webcam.click(
stop_webcam_action,
outputs=[stop_btn_webcam, status_text_webcam]
)
gr.Markdown(
f"<div style='text-align: center; margin-top: 2em'>"
f"Developed with MediaPipe & Gradio | GPU Status: "
f"{'✅ Active (CUDA)' if torch.cuda.is_available() else '✅ Active (MPS)' if torch.backends.mps.is_available() else '❌ Inactive'}</div>"
)
return app
if __name__ == "__main__":
app = build_gradio_interface()
app.launch()