"""
CourtSide-CV - Tennis Analysis Space

Hugging Face Gradio App
"""

import logging
import os
import subprocess
import tempfile

import cv2
import gradio as gr
import numpy as np
from scipy import interpolate
from ultralytics import YOLO

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


class BallTrackerLinkedIn:
    """Tracker tuned for tennis ball detection."""

    def __init__(self, model_path):
        self.ball_model = YOLO(model_path)
        self.tracks = {}
        self.frame_idx = 0
        self.all_positions = []
        self.conf_thresh = 0.05        # low threshold: the ball is small, fast and often blurred
        self.smooth_window = 5         # median-filter window size (frames)
        self.max_interpolate_gap = 30  # never bridge detection holes longer than this (frames)

    def process_batch(self, frames, progress_callback=None):
        """Run ball detection/tracking over a batch of frames."""
        positions = []

        for i, frame in enumerate(frames):
            if progress_callback:
                progress_callback((i + 1) / len(frames), desc=f"Detecting ball... {i+1}/{len(frames)}")

            self.frame_idx = i
            results = self.ball_model.track(
                source=frame,
                conf=self.conf_thresh,
                classes=[0],  # class 0 is assumed to be the ball in the (fine-tuned) detector
                imgsz=640,
                iou=0.5,
                persist=True,
                verbose=False
            )

            # Keep only the highest-confidence detection per frame
            ball_pos = None
            if results[0].boxes is not None and len(results[0].boxes) > 0:
                best_idx = results[0].boxes.conf.argmax()
                x1, y1, x2, y2 = results[0].boxes.xyxy[best_idx].tolist()
                cx = (x1 + x2) / 2
                cy = (y1 + y2) / 2
                conf = float(results[0].boxes.conf[best_idx])
                ball_pos = (cx, cy, conf)

            positions.append((i, ball_pos))

        return positions
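
    # A sketch of the structure process_batch returns (numbers are illustrative,
    # not real output): one (frame_index, position) pair per frame, where the
    # position is (cx, cy, confidence) or None when nothing was detected.
    #
    #   [(0, (642.3, 311.8, 0.41)),   # frame 0: ball centre + detector confidence
    #    (1, None),                   # frame 1: no detection above conf_thresh
    #    (2, (655.0, 305.2, 0.12))]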

    def interpolate_missing(self, positions):
        """Linearly interpolate missing ball positions across short gaps."""
        detected_frames = []
        detected_x = []
        detected_y = []

        for frame_idx, pos in positions:
            if pos is not None:
                detected_frames.append(frame_idx)
                detected_x.append(pos[0])
                detected_y.append(pos[1])

        if len(detected_frames) < 2:
            return positions

        fx = interpolate.interp1d(detected_frames, detected_x, kind='linear', fill_value='extrapolate')
        fy = interpolate.interp1d(detected_frames, detected_y, kind='linear', fill_value='extrapolate')

        interpolated = []
        for frame_idx, pos in positions:
            if pos is None:
                # ±inf (rather than sentinel values like ±999, which break on
                # videos longer than ~1000 frames) when there is no neighbour
                prev_detected = max((f for f in detected_frames if f < frame_idx), default=float('-inf'))
                next_detected = min((f for f in detected_frames if f > frame_idx), default=float('inf'))

                # Only bridge the hole if both neighbouring detections are close enough
                if (frame_idx - prev_detected <= self.max_interpolate_gap and
                        next_detected - frame_idx <= self.max_interpolate_gap):
                    ix = float(fx(frame_idx))
                    iy = float(fy(frame_idx))
                    interpolated.append((frame_idx, (ix, iy, 0.0)))  # conf 0.0 marks an interpolated point
                else:
                    interpolated.append((frame_idx, None))
            else:
                interpolated.append((frame_idx, pos))

        return interpolated
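
    # Worked example (hypothetical numbers): with max_interpolate_gap = 30,
    # detections at frame 10 -> (100, 200) and frame 14 -> (140, 220) fill
    # frames 11-13 linearly, e.g. frame 12 -> (120.0, 210.0, 0.0).
    # A 60-frame hole stays None rather than being invented.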

    def smooth_trajectory(self, positions):
        """Smooth the trajectory with a sliding median filter."""
        smoothed = []

        for i, (frame_idx, pos) in enumerate(positions):
            if pos is None:
                smoothed.append((frame_idx, None))
                continue

            window_start = max(0, i - self.smooth_window // 2)
            window_end = min(len(positions), i + self.smooth_window // 2 + 1)

            window_x = []
            window_y = []
            for j in range(window_start, window_end):
                if positions[j][1] is not None:
                    window_x.append(positions[j][1][0])
                    window_y.append(positions[j][1][1])

            if window_x:
                # Median rather than mean: a single wild detection in the window
                # (e.g. x-values [100, 102, 180, 104, 106] -> median 104) is ignored
                smooth_x = np.median(window_x)
                smooth_y = np.median(window_y)
                conf = pos[2] if len(pos) > 2 else 0.0
                smoothed.append((frame_idx, (smooth_x, smooth_y, conf)))
            else:
                smoothed.append((frame_idx, pos))

        return smoothed
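

# A minimal end-to-end sketch of the tracker on its own, assuming `frames` is a
# list of BGR numpy arrays and "ball_model.pt" is a placeholder weights path:
#
#   tracker = BallTrackerLinkedIn("ball_model.pt")
#   raw = tracker.process_batch(frames)
#   filled = tracker.interpolate_missing(raw)    # bridge short detection gaps
#   smooth = tracker.smooth_trajectory(filled)   # median-filter out jitter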


class VideoProcessorLinkedIn:
    """Video processor for Gradio."""

    def __init__(self, ball_model_path):
        self.tracker = BallTrackerLinkedIn(ball_model_path)
        self.person_model = YOLO('yolov8m.pt')
        self.pose_model = YOLO('yolov8m-pose.pt')

        # Bone pairs over the COCO-17 keypoint layout used by YOLOv8-pose
        # (0 nose, 1-4 eyes/ears, 5-6 shoulders, 7-10 elbows/wrists,
        #  11-12 hips, 13-16 knees/ankles)
        self.skeleton_connections = [
            (5, 6), (5, 7), (7, 9), (6, 8), (8, 10),            # shoulders and arms
            (5, 11), (6, 12), (11, 12), (11, 13), (13, 15),     # torso and left leg
            (12, 14), (14, 16), (0, 1), (0, 2), (1, 3), (2, 4)  # right leg and face
        ]

    def draw_skeleton(self, frame, keypoints, conf_threshold=0.5):
        """Draw the pose skeleton on the frame."""
        joint_color = (0, 255, 0)
        bone_color = (0, 255, 255)

        # Bones first, so joints are drawn on top
        for connection in self.skeleton_connections:
            kp1_idx, kp2_idx = connection
            if kp1_idx < len(keypoints) and kp2_idx < len(keypoints):
                kp1 = keypoints[kp1_idx]
                kp2 = keypoints[kp2_idx]

                if len(kp1) > 2 and len(kp2) > 2:
                    if kp1[2] > conf_threshold and kp2[2] > conf_threshold:
                        pt1 = (int(kp1[0]), int(kp1[1]))
                        pt2 = (int(kp2[0]), int(kp2[1]))
                        cv2.line(frame, pt1, pt2, bone_color, 2, cv2.LINE_AA)

        for keypoint in keypoints:
            if len(keypoint) > 2 and keypoint[2] > conf_threshold:
                x, y = int(keypoint[0]), int(keypoint[1])
                cv2.circle(frame, (x, y), 4, joint_color, -1, cv2.LINE_AA)
                cv2.circle(frame, (x, y), 4, (255, 255, 255), 1, cv2.LINE_AA)
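
    # Note on the expected input: `results[0].keypoints.data` from ultralytics
    # is a (num_persons, 17, 3) tensor, each row being (x, y, confidence) —
    # process_video below converts it to plain lists before calling this method.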

    def process_video(self, video_path, player1_name="PLAYER 1", player2_name="PLAYER 2",
                      max_duration=30, progress=gr.Progress(track_tqdm=True)):
        """Process the video and return the annotated version."""
        # Player names are accepted for future overlays; they are not drawn yet

        if video_path is None:
            return None, "❌ Please upload a video"

        try:
            logger.info(f"Processing video: {video_path}")

            cap = cv2.VideoCapture(video_path)
            if not cap.isOpened():
                return None, "❌ Could not open the video"

            fps = int(cap.get(cv2.CAP_PROP_FPS))
            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

            # Cap the workload: never process more than max_duration seconds
            max_frames = min(total_frames, int(fps * max_duration))

            progress(0, desc="Loading video...")
            frames = []
            for i in range(max_frames):
                ret, frame = cap.read()
                if not ret:
                    break
                frames.append(frame)
            cap.release()

            if len(frames) == 0:
                return None, "❌ No frames could be read"

            logger.info(f"Loaded {len(frames)} frames ({width}x{height} @ {fps}fps)")

            # 1) Detect, 2) interpolate short gaps, 3) median-smooth
            progress(0.1, desc="Tracking ball...")
            positions = self.tracker.process_batch(
                frames,
                # map the tracker's 0-1 progress into the 0.1-0.4 slice of the bar
                progress_callback=lambda frac, desc="": progress(0.1 + 0.3 * frac, desc=desc)
            )

            progress(0.4, desc="Interpolating missing positions...")
            positions = self.tracker.interpolate_missing(positions)

            progress(0.5, desc="Smoothing trajectory...")
            positions = self.tracker.smooth_trajectory(positions)

            # Coverage counts detected plus interpolated points
            detected = sum(1 for _, p in positions if p is not None)
            coverage = (detected / len(positions)) * 100

            progress(0.6, desc="Rendering annotated video...")

            # First pass: write raw annotated frames with OpenCV (mp4v);
            # re-encoded to H.264 below for browser playback
            temp_output = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            out = cv2.VideoWriter(temp_output, fourcc, fps, (width, height))

            trail_length = 15  # comet-trail length in frames
            trail_positions = []

            for frame_idx, (_, ball_pos) in enumerate(positions):
                progress(0.6 + 0.3 * (frame_idx / len(frames)),
                         desc=f"Rendering... {frame_idx+1}/{len(frames)}")

                annotated = frames[frame_idx].copy()

                if ball_pos is not None:
                    x, y, conf = ball_pos
                    trail_positions.append((int(x), int(y)))
                    if len(trail_positions) > trail_length:
                        trail_positions.pop(0)

                    # Fading comet trail: segments thicken towards the newest point
                    for i in range(1, len(trail_positions)):
                        alpha = i / len(trail_positions)
                        thickness = int(2 + alpha * 2)
                        cv2.line(annotated, trail_positions[i-1], trail_positions[i],
                                 (0, 255, 255), thickness, cv2.LINE_AA)

                    # Ball marker: yellow halo, green core, white ring
                    radius = 8
                    cv2.circle(annotated, (int(x), int(y)), radius + 3,
                               (0, 255, 255), -1, cv2.LINE_AA)
                    cv2.circle(annotated, (int(x), int(y)), radius,
                               (0, 255, 0), -1, cv2.LINE_AA)
                    cv2.circle(annotated, (int(x), int(y)), radius,
                               (255, 255, 255), 2, cv2.LINE_AA)

                # Player pose overlay (at most two skeletons per frame)
                pose_results = self.pose_model(frames[frame_idx], conf=0.3, verbose=False)
                if pose_results[0].keypoints is not None:
                    for keypoints in pose_results[0].keypoints.data[:2]:
                        keypoints_np = keypoints.cpu().numpy()
                        keypoints_with_conf = [[kp[0], kp[1], kp[2]] for kp in keypoints_np]
                        self.draw_skeleton(annotated, keypoints_with_conf, conf_threshold=0.3)

                # Bottom status bar
                cv2.rectangle(annotated, (0, height-45), (width, height), (0, 0, 0), -1)
                cv2.putText(annotated, "CourtSide-CV", (15, height-15),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 255), 2, cv2.LINE_AA)

                if ball_pos is not None:
                    # conf == 0.0 marks interpolated points, so anything above the
                    # threshold came from the detector
                    status = "TRACKING" if conf > 0.1 else "PREDICTED"
                    color_status = (0, 255, 255) if conf > 0.1 else (255, 200, 0)
                    cv2.putText(annotated, f"Ball: {status}", (width//2 - 60, height-15),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.6, color_status, 2, cv2.LINE_AA)

                out.write(annotated)

            out.release()
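
            # OpenCV's mp4v output often won't play inline in browsers; re-encode
            # with H.264 + yuv420p and faststart so the Gradio player can stream it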
            progress(0.95, desc="Finalizing video...")
            final_output = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name

            # Global flags (-y, -loglevel) go before -i: ffmpeg ignores trailing
            # options placed after the output file
            cmd = [
                'ffmpeg', '-y', '-loglevel', 'error',
                '-i', temp_output,
                '-c:v', 'libx264', '-preset', 'fast', '-crf', '22',
                '-pix_fmt', 'yuv420p', '-movflags', '+faststart',
                final_output
            ]
            subprocess.run(cmd, check=True)
            os.remove(temp_output)

            message = f"""
✅ **Video processed successfully!**

📊 **Statistics:**
- Frames processed: {len(frames)}
- Ball coverage: {coverage:.1f}%
- Resolution: {width}x{height}
- FPS: {fps}

🎾 Ready for LinkedIn!
"""

            logger.info(f"✅ Processing complete: {final_output}")
            return final_output, message

        except Exception as e:
            logger.error(f"Error processing video: {e}", exc_info=True)
            return None, f"❌ Error: {str(e)}"


processor = None


def get_processor():
    """Lazily initialise the processor."""
    global processor
    if processor is None:
        logger.info("Initializing processor...")

        # Download/warm the YOLO weights once
        logger.info("Downloading YOLO models...")
        _ = YOLO('yolov8m.pt')
        _ = YOLO('yolov8m-pose.pt')
        logger.info("✅ Models ready!")

        # NOTE: stock yolov8m.pt is a generic COCO detector; for real ball
        # tracking, point this at fine-tuned tennis-ball weights instead
        ball_model_path = 'yolov8m.pt'
        processor = VideoProcessorLinkedIn(ball_model_path)
    return processor
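
# A hedged sketch of swapping in fine-tuned ball weights hosted on the Hub
# (repo id and filename below are placeholders, not a real artifact):
#
#   from huggingface_hub import hf_hub_download
#   ball_model_path = hf_hub_download(repo_id="your-user/tennis-ball-yolov8",
#                                     filename="best.pt")
#   processor = VideoProcessorLinkedIn(ball_model_path)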


def process_video_gradio(video, player1, player2, max_duration, progress=gr.Progress(track_tqdm=True)):
    """Gradio wrapper."""
    proc = get_processor()
    return proc.process_video(video, player1, player2, max_duration, progress)


with gr.Blocks(title="🎾 CourtSide-CV - Tennis Analysis", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🎾 CourtSide-CV - Tennis Analysis

    Analyse your tennis matches with AI! This app uses computer vision to:
    - 🎯 **Track the ball** in real time with smart gap interpolation
    - 🤸 **Detect player poses** with skeleton overlays
    - 📊 **Analyse trajectories** with advanced smoothing

    ---
    """)

    with gr.Row():
        with gr.Column():
            video_input = gr.Video(label="📹 Upload your tennis video")

            with gr.Row():
                player1_input = gr.Textbox(
                    label="👤 Player 1 name (left)",
                    value="PLAYER 1",
                    max_lines=1
                )
                player2_input = gr.Textbox(
                    label="👤 Player 2 name (right)",
                    value="PLAYER 2",
                    max_lines=1
                )

            max_duration_input = gr.Slider(
                minimum=5,
                maximum=60,
                value=30,
                step=5,
                label="⏱️ Maximum duration (seconds)",
                info="Keep clips short for reasonable processing times"
            )

            submit_btn = gr.Button("🚀 Analyse the video", variant="primary", size="lg")

        with gr.Column():
            video_output = gr.Video(label="🎬 Annotated video")
            status_output = gr.Markdown(label="📊 Results")

    gr.Markdown("""
    ---
    ### 💡 Tips
    - Use **good-quality** footage for best results
    - The **ball should be visible** in most frames
    - For best performance, keep clips under **30 seconds**

    ### 🔧 Technologies
    - **YOLOv8** for object and pose detection
    - **ByteTrack** for object tracking
    - **OpenCV** for video processing
    - **SciPy** for interpolation

    ---
    Made with ❤️ by CourtSide-CV
    """)

    submit_btn.click(
        fn=process_video_gradio,
        inputs=[video_input, player1_input, player2_input, max_duration_input],
        outputs=[video_output, status_output]
    )


if __name__ == "__main__":
    demo.launch()