# Spaces:
# Sleeping
# Sleeping
import argparse | |
import sys | |
import cv2 | |
import numpy as np | |
from rich.console import Console | |
from rich.panel import Panel | |
from rich.align import Align | |
from rich.layout import Layout | |
from pyfiglet import Figlet | |
import mediapipe as mp | |
from PoseClassification.pose_embedding import FullBodyPoseEmbedding | |
from PoseClassification.pose_classifier import PoseClassifier | |
from PoseClassification.utils import EMADictSmoothing | |
from PoseClassification.visualize import PoseClassificationVisualizer | |
# For cross-platform compatibility | |
try: | |
import msvcrt # Windows | |
except ImportError: | |
import termios # Unix-like | |
import tty | |
def getch():
    """Read a single keypress from stdin without waiting for Enter.

    On Windows this delegates to ``msvcrt``; on Unix-like systems the
    terminal is temporarily switched to raw mode so one character can be
    read immediately, then restored.
    """
    if sys.platform == "win32":
        return msvcrt.getch().decode("utf-8")
    stdin_fd = sys.stdin.fileno()
    saved_attrs = termios.tcgetattr(stdin_fd)
    try:
        tty.setraw(stdin_fd)
        char = sys.stdin.read(1)
    finally:
        # Restore the terminal settings even if the read fails.
        termios.tcsetattr(stdin_fd, termios.TCSADRAIN, saved_attrs)
    return char
def create_ascii_title(text):
    """Render *text* as large ASCII art using the "isometric2" figlet font."""
    return Figlet(font="isometric2").renderText(text)
# Pose names shown in the results panel; these match the lowercase keys
# produced by the classifier's training-sample folders.
POSE_NAMES = ("chair", "cobra", "dog", "plank", "goddess", "tree", "warrior")
OTHER_KEYS = ("no pose detected", "fallen")


def _format_timings(pose_timings):
    """Render accumulated per-pose timings as a rich markup string."""
    def line(key):
        return f"[bold]{key.capitalize()}:[/bold] {pose_timings[key]:.2f}s"

    return "\n".join(
        [line(k) for k in POSE_NAMES] + ["---"] + [line(k) for k in OTHER_KEYS]
    )


def main(input_source, display=False, output_file=None):
    """Classify yoga poses in a video file or a live camera feed.

    Args:
        input_source: Path to a video file, or the literal string "live"
            to read from the default camera.
        display: If True, show each annotated frame in an OpenCV window
            (press 'q' to quit).
        output_file: Optional output path.
            NOTE(review): currently only echoed to the user at the end;
            no video is actually written — confirm intended behavior.
    """
    console = Console()
    layout = Layout()
    ascii_title = create_ascii_title("YOGAI")
    layout.split(
        Layout(Panel(Align.center(ascii_title), border_style="bold blue"), size=15),
        Layout(name="main"),
    )
    is_live = input_source == "live"
    status = (
        "Processing live video from camera"
        if is_live
        else f"Processing video: {input_source}"
    )
    layout["main"].update(
        Panel(status, title="Video Classification", border_style="bold blue")
    )
    console.print(layout)

    # Pose tracker, embedder, classifier and EMA smoothing filter.
    mp_pose = mp.solutions.pose
    pose_tracker = mp_pose.Pose()
    pose_embedder = FullBodyPoseEmbedding()
    pose_classifier = PoseClassifier(
        pose_samples_folder="data/yoga_poses_csvs_out",
        pose_embedder=pose_embedder,
        top_n_by_max_distance=30,
        top_n_by_mean_distance=10,
    )
    pose_classification_filter = EMADictSmoothing(window_size=10, alpha=0.2)

    # Open the video source.
    if is_live:
        video = cv2.VideoCapture(0)
        fps = 30  # Camera streams rarely report FPS; assume 30.
        total_frames = float("inf")  # Unbounded for live video.
    else:
        video = cv2.VideoCapture(input_source)
        fps = video.get(cv2.CAP_PROP_FPS)
        total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
    # Some containers report 0 FPS; fall back to 30 so the per-frame
    # increment `1 / fps` below cannot divide by zero.
    if not fps or fps <= 0:
        fps = 30

    pose_timings = {key: 0 for key in POSE_NAMES + OTHER_KEYS}
    frame_count = 0
    try:
        while True:
            ret, frame = video.read()
            if not ret:
                if is_live:
                    console.print(
                        "[bold red]Error reading from camera. Exiting...[/bold red]"
                    )
                break
            # MediaPipe expects RGB; OpenCV delivers BGR.
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            result = pose_tracker.process(image=frame_rgb)
            if result.pose_landmarks is not None:
                # Draw landmarks on the (BGR) display frame.
                mp.solutions.drawing_utils.draw_landmarks(
                    frame, result.pose_landmarks, mp_pose.POSE_CONNECTIONS
                )
                frame_height, frame_width = frame.shape[0], frame.shape[1]
                pose_landmarks = np.array(
                    [
                        [lmk.x * frame_width, lmk.y * frame_height, lmk.z * frame_width]
                        for lmk in result.pose_landmarks.landmark
                    ],
                    dtype=np.float32,
                )
                # Classify, then smooth across frames with the EMA filter.
                pose_classification = pose_classifier(pose_landmarks)
                pose_classification_filtered = pose_classification_filter(
                    pose_classification
                )
                if pose_classification_filtered:
                    # Credit the frame's duration to the most confident pose.
                    max_pose = max(
                        pose_classification_filtered,
                        key=pose_classification_filtered.get,
                    ).lower()
                    # .get guards against class names not pre-seeded above.
                    pose_timings[max_pose] = pose_timings.get(max_pose, 0) + 1 / fps
                else:
                    pose_timings["no pose detected"] += 1 / fps
            else:
                pose_timings["no pose detected"] += 1 / fps
            frame_count += 1
            if frame_count % 30 == 0:  # Refresh the console every 30 frames.
                panel_content = _format_timings(pose_timings)
                if not is_live:
                    panel_content += (
                        f"\n\nProcessed {frame_count}/{total_frames} frames"
                    )
                layout["main"].update(
                    Panel(
                        panel_content,
                        title="Classification Results",
                        border_style="bold green",
                    )
                )
                console.print(layout)
            if display:
                cv2.imshow("Video", frame)
                if cv2.waitKey(1) & 0xFF == ord("q"):
                    break
    finally:
        # Always release OS resources, even if classification raises.
        video.release()
        pose_tracker.close()
        if display:
            cv2.destroyAllWindows()

    # Final results.
    layout["main"].update(
        Panel(
            _format_timings(pose_timings),
            title="Final Classification Results",
            border_style="bold green",
        )
    )
    console.print(layout)
    if output_file:
        console.print(f"[green]Output saved to: {output_file}[/green]")
if __name__ == "__main__": | |
parser = argparse.ArgumentParser( | |
description="Classify poses in a video file or from live camera." | |
) | |
parser.add_argument("input", help="Input video file or 'live' for camera feed") | |
parser.add_argument( | |
"--display", action="store_true", help="Display the video with detected poses" | |
) | |
parser.add_argument("--output", help="Output video file") | |
if len(sys.argv) == 1: | |
parser.print_help(sys.stderr) | |
sys.exit(1) | |
args = parser.parse_args() | |
main(args.input, args.display, args.output) | |