#!/usr/bin/env python
"""
CourtSide-CV - Tennis Analysis Space
Hugging Face Gradio App
"""
import os
import cv2
import gradio as gr
import numpy as np
from pathlib import Path
from ultralytics import YOLO
from collections import defaultdict
import logging
from scipy import interpolate
import tempfile
import subprocess
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
class BallTrackerLinkedIn:
"""Tracker optimisé pour détection de balle de tennis"""
def __init__(self, model_path):
self.ball_model = YOLO(model_path)
self.tracks = {}
self.frame_idx = 0
self.all_positions = []
        self.conf_thresh = 0.05        # low threshold: the ball is small and often motion-blurred
        self.smooth_window = 5         # median-filter window, in frames
        self.max_interpolate_gap = 30  # never interpolate across gaps longer than this (frames)
def process_batch(self, frames, progress_callback=None):
"""Process un batch de frames pour le tracking"""
positions = []
for i, frame in enumerate(frames):
if progress_callback:
                # Scale batch progress into the 0.1-0.4 slice this phase occupies overall
                progress_callback(0.1 + 0.3 * ((i + 1) / len(frames)),
                                  desc=f"Detecting ball... {i+1}/{len(frames)}")
self.frame_idx = i
results = self.ball_model.track(
source=frame,
conf=self.conf_thresh,
                classes=[32],  # COCO class 32 = "sports ball" (class 0 would be "person" with this model)
imgsz=640,
iou=0.5,
persist=True,
verbose=False
)
ball_pos = None
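            # Keep only the single highest-confidence box as the ball for this frame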
if results[0].boxes is not None and len(results[0].boxes) > 0:
best_idx = results[0].boxes.conf.argmax()
x1, y1, x2, y2 = results[0].boxes.xyxy[best_idx].tolist()
cx = (x1 + x2) / 2
cy = (y1 + y2) / 2
conf = float(results[0].boxes.conf[best_idx])
ball_pos = (cx, cy, conf)
positions.append((i, ball_pos))
return positions
def interpolate_missing(self, positions):
"""Interpoler les positions manquantes"""
detected_frames = []
detected_x = []
detected_y = []
for frame_idx, pos in positions:
if pos is not None:
detected_frames.append(frame_idx)
detected_x.append(pos[0])
detected_y.append(pos[1])
if len(detected_frames) < 2:
return positions
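        # Build linear interpolators over the detected frames; fill_value='extrapolate'
        # keeps interp1d from raising at the edges, and max_interpolate_gap below
        # decides whether an interpolated point is actually kept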
fx = interpolate.interp1d(detected_frames, detected_x, kind='linear', fill_value='extrapolate')
fy = interpolate.interp1d(detected_frames, detected_y, kind='linear', fill_value='extrapolate')
interpolated = []
for frame_idx, pos in positions:
if pos is None:
                prev_detected = max([f for f in detected_frames if f < frame_idx], default=float('-inf'))
                next_detected = min([f for f in detected_frames if f > frame_idx], default=float('inf'))
if (frame_idx - prev_detected <= self.max_interpolate_gap and
next_detected - frame_idx <= self.max_interpolate_gap):
ix = float(fx(frame_idx))
iy = float(fy(frame_idx))
interpolated.append((frame_idx, (ix, iy, 0.0)))
else:
interpolated.append((frame_idx, None))
else:
interpolated.append((frame_idx, pos))
return interpolated
def smooth_trajectory(self, positions):
"""Lisser la trajectoire avec filtre médian"""
smoothed = []
for i, (frame_idx, pos) in enumerate(positions):
if pos is None:
smoothed.append((frame_idx, None))
continue
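            # Median over a centered window of neighboring detections: robust to
            # single-frame outliers without flattening the arc of the trajectory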
window_start = max(0, i - self.smooth_window // 2)
window_end = min(len(positions), i + self.smooth_window // 2 + 1)
window_x = []
window_y = []
for j in range(window_start, window_end):
if positions[j][1] is not None:
window_x.append(positions[j][1][0])
window_y.append(positions[j][1][1])
if window_x:
smooth_x = np.median(window_x)
smooth_y = np.median(window_y)
conf = pos[2] if len(pos) > 2 else 0.0
smoothed.append((frame_idx, (smooth_x, smooth_y, conf)))
else:
smoothed.append((frame_idx, pos))
return smoothed
class VideoProcessorLinkedIn:
"""Processeur vidéo pour Gradio"""
def __init__(self, ball_model_path):
self.tracker = BallTrackerLinkedIn(ball_model_path)
        self.person_model = YOLO('yolov8m.pt')  # person detector (currently unused by the pipeline)
self.pose_model = YOLO('yolov8m-pose.pt')
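        # COCO-17 keypoint indices used by skeleton_connections: 0 nose, 1-2 eyes,
        # 3-4 ears, 5-6 shoulders, 7-8 elbows, 9-10 wrists, 11-12 hips,
        # 13-14 knees, 15-16 ankles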
self.skeleton_connections = [
(5, 6), (5, 7), (7, 9), (6, 8), (8, 10),
(5, 11), (6, 12), (11, 12), (11, 13), (13, 15),
(12, 14), (14, 16), (0, 1), (0, 2), (1, 3), (2, 4)
]
def draw_skeleton(self, frame, keypoints, conf_threshold=0.5):
"""Dessine le squelette sur la frame"""
joint_color = (0, 255, 0)
bone_color = (0, 255, 255)
for connection in self.skeleton_connections:
kp1_idx, kp2_idx = connection
if kp1_idx < len(keypoints) and kp2_idx < len(keypoints):
kp1 = keypoints[kp1_idx]
kp2 = keypoints[kp2_idx]
if len(kp1) > 2 and len(kp2) > 2:
if kp1[2] > conf_threshold and kp2[2] > conf_threshold:
pt1 = (int(kp1[0]), int(kp1[1]))
pt2 = (int(kp2[0]), int(kp2[1]))
cv2.line(frame, pt1, pt2, bone_color, 2, cv2.LINE_AA)
for keypoint in keypoints:
if len(keypoint) > 2 and keypoint[2] > conf_threshold:
x, y = int(keypoint[0]), int(keypoint[1])
cv2.circle(frame, (x, y), 4, joint_color, -1, cv2.LINE_AA)
cv2.circle(frame, (x, y), 4, (255, 255, 255), 1, cv2.LINE_AA)
def process_video(self, video_path, player1_name="PLAYER 1", player2_name="PLAYER 2",
max_duration=30, progress=gr.Progress(track_tqdm=True)):
"""Traiter la vidéo et retourner la version annotée"""
if video_path is None:
return None, "❌ Veuillez uploader une vidéo"
try:
logger.info(f"Processing video: {video_path}")
            # Open the video
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
return None, "❌ Impossible d'ouvrir la vidéo"
            fps = int(cap.get(cv2.CAP_PROP_FPS)) or 30  # fall back to 30 fps if metadata is missing
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            # Cap the processed duration
max_frames = min(total_frames, int(fps * max_duration))
            # Read all frames up front
progress(0, desc="Loading video...")
frames = []
for i in range(max_frames):
ret, frame = cap.read()
if not ret:
break
frames.append(frame)
cap.release()
if len(frames) == 0:
return None, "❌ Aucune frame lue"
logger.info(f"Loaded {len(frames)} frames ({width}x{height} @ {fps}fps)")
            # Phase 1: Ball tracking
progress(0.1, desc="Tracking ball...")
positions = self.tracker.process_batch(frames, progress_callback=progress)
# Phase 2: Interpolation
progress(0.4, desc="Interpolating missing positions...")
positions = self.tracker.interpolate_missing(positions)
            # Phase 3: Smoothing
progress(0.5, desc="Smoothing trajectory...")
positions = self.tracker.smooth_trajectory(positions)
# Stats
detected = sum(1 for _, p in positions if p is not None)
coverage = (detected / len(positions)) * 100
            # Phase 4: Video rendering
progress(0.6, desc="Rendering annotated video...")
            # Create a temporary output file
temp_output = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
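            # OpenCV's mp4v stream is not reliably playable in browsers; a second
            # ffmpeg pass below re-encodes it to H.264 (see "Finalizing video")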
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter(temp_output, fourcc, fps, (width, height))
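            # Rolling buffer of the most recent ball centers, drawn as a fading trail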
trail_length = 15
trail_positions = []
for frame_idx, (_, ball_pos) in enumerate(positions):
progress(0.6 + 0.3 * (frame_idx / len(frames)),
desc=f"Rendering... {frame_idx+1}/{len(frames)}")
annotated = frames[frame_idx].copy()
                # Draw the ball and its trajectory
if ball_pos is not None:
x, y, conf = ball_pos
trail_positions.append((int(x), int(y)))
if len(trail_positions) > trail_length:
trail_positions.pop(0)
# Trail
for i in range(1, len(trail_positions)):
alpha = i / len(trail_positions)
thickness = int(2 + alpha * 2)
cv2.line(annotated, trail_positions[i-1], trail_positions[i],
(0, 255, 255), thickness, cv2.LINE_AA)
                    # Ball
radius = 8
cv2.circle(annotated, (int(x), int(y)), radius + 3,
(0, 255, 255), -1, cv2.LINE_AA)
cv2.circle(annotated, (int(x), int(y)), radius,
(0, 255, 0), -1, cv2.LINE_AA)
cv2.circle(annotated, (int(x), int(y)), radius,
(255, 255, 255), 2, cv2.LINE_AA)
                # Pose detection
pose_results = self.pose_model(frames[frame_idx], conf=0.3, verbose=False)
if pose_results[0].keypoints is not None:
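                    # Draw at most two skeletons (assumed to be the two players)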
for keypoints in pose_results[0].keypoints.data[:2]:
keypoints_np = keypoints.cpu().numpy()
keypoints_with_conf = [[kp[0], kp[1], kp[2]] for kp in keypoints_np]
self.draw_skeleton(annotated, keypoints_with_conf, conf_threshold=0.3)
                # Bottom overlay bar: branding, ball status, player names
                cv2.rectangle(annotated, (0, height-45), (width, height), (0, 0, 0), -1)
                cv2.putText(annotated, "CourtSide-CV", (15, height-15),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 255), 2, cv2.LINE_AA)
                # Show the player names supplied via the UI on the right of the bar
                cv2.putText(annotated, f"{player1_name} vs {player2_name}", (width - 320, height-15),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2, cv2.LINE_AA)
                if ball_pos is not None:
                    # Interpolated points carry conf 0.0, so conf > 0.1 separates
                    # real detections (TRACKING) from interpolated ones (PREDICTED)
                    status = "TRACKING" if conf > 0.1 else "PREDICTED"
                    color_status = (0, 255, 255) if conf > 0.1 else (255, 200, 0)
                    cv2.putText(annotated, f"Ball: {status}", (width//2 - 60, height-15),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.6, color_status, 2, cv2.LINE_AA)
out.write(annotated)
out.release()
            # Final re-encode to H.264 for browser compatibility
progress(0.95, desc="Finalizing video...")
final_output = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
            # -y and -loglevel are global options and must precede the output path
            # (ffmpeg ignores trailing options); -y is required because
            # NamedTemporaryFile has already created final_output on disk
            cmd = [
                'ffmpeg', '-y', '-loglevel', 'error',
                '-i', temp_output,
                '-c:v', 'libx264', '-preset', 'fast', '-crf', '22',
                '-pix_fmt', 'yuv420p', '-movflags', '+faststart',
                final_output
            ]
            subprocess.run(cmd, check=True)
os.remove(temp_output)
message = f"""
✅ **Vidéo traitée avec succès!**
📊 **Statistiques:**
- Frames traitées: {len(frames)}
- Couverture balle: {coverage:.1f}%
- Résolution: {width}x{height}
- FPS: {fps}
🎾 Prêt pour LinkedIn!
"""
logger.info(f"✅ Processing complete: {final_output}")
return final_output, message
except Exception as e:
logger.error(f"Error processing video: {e}", exc_info=True)
return None, f"❌ Erreur: {str(e)}"
# Global processor (lazily initialized)
processor = None
def get_processor():
"""Initialise le processeur de manière paresseuse"""
global processor
if processor is None:
logger.info("Initializing processor...")
        # Pre-download the YOLO weights so the first request is fast
logger.info("Downloading YOLO models...")
_ = YOLO('yolov8m.pt')
_ = YOLO('yolov8m-pose.pt')
logger.info("✅ Models ready!")
        # The stock COCO yolov8m.pt doubles as the ball detector (tracked via classes=[32])
        ball_model_path = 'yolov8m.pt'
processor = VideoProcessorLinkedIn(ball_model_path)
return processor
# Gradio interface
def process_video_gradio(video, player1, player2, max_duration, progress=gr.Progress(track_tqdm=True)):
"""Wrapper pour Gradio"""
proc = get_processor()
return proc.process_video(video, player1, player2, max_duration, progress)
# Build the interface
with gr.Blocks(title="🎾 CourtSide-CV - Tennis Analysis", theme=gr.themes.Soft()) as demo:
gr.Markdown("""
# 🎾 CourtSide-CV - Tennis Analysis
Analysez vos matchs de tennis avec l'IA ! Cette application utilise la vision par ordinateur pour :
- 🎯 **Tracker la balle** en temps réel avec interpolation intelligente
- 🤸 **Détecter la pose** des joueurs avec visualisation du squelette
- 📊 **Analyser les trajectoires** avec lissage avancé
---
""")
with gr.Row():
with gr.Column():
            video_input = gr.Video(label="📹 Upload your tennis video")
with gr.Row():
player1_input = gr.Textbox(
label="👤 Nom Joueur 1 (gauche)",
value="PLAYER 1",
max_lines=1
)
player2_input = gr.Textbox(
label="👤 Nom Joueur 2 (droite)",
value="PLAYER 2",
max_lines=1
)
max_duration_input = gr.Slider(
minimum=5,
maximum=60,
value=30,
step=5,
label="⏱️ Durée maximale (secondes)",
info="Pour des raisons de performance, limitez la durée"
)
            submit_btn = gr.Button("🚀 Analyze video", variant="primary", size="lg")
with gr.Column():
            video_output = gr.Video(label="🎬 Annotated video")
            status_output = gr.Markdown(label="📊 Results")
gr.Markdown("""
---
### 💡 Conseils
- Utilisez des vidéos de **bonne qualité** pour de meilleurs résultats
- La **balle doit être visible** dans la majorité des frames
- Pour de meilleures performances, limitez à **30 secondes**
### 🔧 Technologies
- **YOLOv8** pour la détection d'objets et de poses
- **ByteTrack** pour le suivi d'objets
- **OpenCV** pour le traitement vidéo
- **Scipy** pour l'interpolation
---
Créé avec ❤️ par CourtSide-CV
""")
submit_btn.click(
fn=process_video_gradio,
inputs=[video_input, player1_input, player2_input, max_duration_input],
outputs=[video_output, status_output]
)
# Launch the app
if __name__ == "__main__":
demo.launch()