RamziBm's picture
feat: Intégration de la détection de ballon avec YOLO et amélioration des vérifications de modèles dans app.py et main.py.
41f1119
import argparse
import cv2
import numpy as np
import torch
from pathlib import Path
import time
import traceback
# Import YOLO
from ultralytics import YOLO
# Assurez-vous que le répertoire tvcalib est dans le PYTHONPATH
# ou exécutez depuis le répertoire tvcalib_image_processor
from tvcalib.infer.module import TvCalibInferModule
# Importer les fonctions de visualisation et les constantes de modulation
from visualizer import (
create_minimap_view,
create_minimap_with_offset_skeletons,
DYNAMIC_SCALE_MIN_MODULATION, # Importer les constantes
DYNAMIC_SCALE_MAX_MODULATION
)
# Importer la fonction d'extraction des données joueurs
from pose_estimator import get_player_data
# Constantes
IMAGE_SHAPE = (720, 1280) # Hauteur, Largeur
SEGMENTATION_MODEL_PATH = Path("models/segmentation/train_59.pt")
# Chemin vers le modèle YOLO pour la détection du ballon
YOLO_MODEL_PATH = Path("models/detection/yolo_football.pt")
# Index de classe pour le ballon (basé sur votre exemple)
BALL_CLASS_INDEX = 2
def preprocess_image_tvcalib(image_bgr):
"""Prétraite l'image BGR pour TvCalib et retourne le tenseur et l'image RGB redimensionnée."""
if image_bgr is None:
raise ValueError("Impossible de charger l'image")
# 1. Redimensionner en 720p si nécessaire
h, w = image_bgr.shape[:2]
if h != IMAGE_SHAPE[0] or w != IMAGE_SHAPE[1]:
print(f"Redimensionnement de l'image vers {IMAGE_SHAPE[1]}x{IMAGE_SHAPE[0]}")
image_bgr_resized = cv2.resize(image_bgr, (IMAGE_SHAPE[1], IMAGE_SHAPE[0]), interpolation=cv2.INTER_LINEAR)
else:
image_bgr_resized = image_bgr
# 2. Convertir en RGB (pour TvCalib ET pour la visualisation originale)
image_rgb_resized = cv2.cvtColor(image_bgr_resized, cv2.COLOR_BGR2RGB)
# 3. Normalisation spécifique pour le modèle pré-entraîné (pour TvCalib)
image_tensor = torch.from_numpy(image_rgb_resized).float()
image_tensor = image_tensor.permute(2, 0, 1) # HWC -> CHW
mean = torch.tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)
image_tensor = (image_tensor / 255.0 - mean) / std
# 4. Déplacer sur le bon device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
image_tensor = image_tensor.to(device)
# Retourner le tenseur pour TvCalib, l'image BGR et RGB redimensionnée
return image_tensor, image_bgr_resized, image_rgb_resized
def main():
parser = argparse.ArgumentParser(description="Exécute la méthode TvCalib sur une seule image.")
parser.add_argument("image_path", type=str, help="Chemin vers l'image à traiter.")
parser.add_argument("--output_homography", type=str, default=None, help="Chemin optionnel pour sauvegarder la matrice d'homographie (.npy).")
parser.add_argument("--optim_steps", type=int, default=500, help="Nombre d'étapes d'optimisation pour la calibration (l'arrêt anticipé est désactivé).")
parser.add_argument("--target_avg_scale", type=float, default=1,
help="Facteur d'échelle MOYEN CIBLE pour dessiner les squelettes sur la minimap (défaut: 0.35). Le script ajuste l'échelle de base pour tenter d'atteindre cette moyenne.")
args = parser.parse_args()
if not Path(args.image_path).exists():
print(f"Erreur : Fichier image introuvable : {args.image_path}")
return
if not SEGMENTATION_MODEL_PATH.exists():
print(f"Erreur : Modèle de segmentation introuvable : {SEGMENTATION_MODEL_PATH}")
print("Assurez-vous d'avoir copié train_59.pt dans le dossier models/segmentation/")
return
# Vérifier l'existence du modèle YOLO
if not YOLO_MODEL_PATH.exists():
print(f"Erreur : Modèle YOLO introuvable : {YOLO_MODEL_PATH}")
print(f"Assurez-vous d'avoir téléchargé {YOLO_MODEL_PATH.name} et de l'avoir placé dans {YOLO_MODEL_PATH.parent}/")
return
print("Initialisation de TvCalibInferModule...")
try:
model = TvCalibInferModule(
segmentation_checkpoint=SEGMENTATION_MODEL_PATH,
image_shape=IMAGE_SHAPE,
optim_steps=args.optim_steps,
lens_dist=False # Gardons cela simple pour l'instant
)
print(f"✓ Modèle chargé sur {next(model.model_calib.parameters()).device}")
except Exception as e:
print(f"Erreur lors de l'initialisation du modèle : {e}")
return
print(f"Traitement de l'image : {args.image_path}")
try:
# Vérification supplémentaire avant imread
image_path_obj = Path(args.image_path)
absolute_path = image_path_obj.resolve()
print(f"Tentative de lecture de l'image via cv2.imread depuis : {absolute_path}")
if not absolute_path.is_file():
print(f"ERREUR : Le chemin absolu {absolute_path} ne pointe pas vers un fichier existant juste avant imread !")
return # Arrêter ici si le fichier n'est pas trouvé à ce stade
# Charger l'image (en BGR par défaut avec OpenCV)
image_bgr_orig = cv2.imread(args.image_path)
if image_bgr_orig is None:
raise FileNotFoundError(f"Impossible de lire le fichier image: {args.image_path} (vérifié comme existant juste avant, problème avec imread)")
# Prétraiter l'image pour TvCalib (redimensionne aussi)
start_preprocess = time.time()
image_tensor, image_bgr_resized, image_rgb_resized = preprocess_image_tvcalib(image_bgr_orig)
print(f"Temps de prétraitement TvCalib : {time.time() - start_preprocess:.3f}s")
# --- Détection du ballon avec YOLO ---
print("\nChargement du modèle YOLO et détection du ballon...")
start_yolo = time.time()
ball_ref_point_img = None # Point de référence du ballon sur l'image originale redimensionnée
try:
yolo_model = YOLO(YOLO_MODEL_PATH)
# Utiliser l'image BGR redimensionnée pour YOLO
results = yolo_model.predict(image_bgr_resized, classes=[BALL_CLASS_INDEX], verbose=False)
if results and len(results[0].boxes) > 0:
# Prendre la détection avec la plus haute confiance
best_ball_box = results[0].boxes[results[0].boxes.conf.argmax()]
x1, y1, x2, y2 = map(int, best_ball_box.xyxy[0].tolist())
conf = best_ball_box.conf[0].item()
# Calculer le point de référence (centre bas de la bbox)
ball_ref_point_img = np.array([(x1 + x2) / 2, y2], dtype=np.float32)
print(f" ✓ Ballon trouvé (conf: {conf:.2f}) à la bbox [{x1},{y1},{x2},{y2}]. Point réf: {ball_ref_point_img}")
else:
print(" Aucun ballon détecté.")
except Exception as e_yolo:
print(f" Erreur pendant la détection YOLO : {e_yolo}")
print(f"Temps de détection YOLO : {time.time() - start_yolo:.3f}s")
# Exécuter la segmentation
print("Exécution de la segmentation...")
start_segment = time.time()
with torch.no_grad():
keypoints = model._segment(image_tensor)
print(f"Temps de segmentation : {time.time() - start_segment:.3f}s")
# Exécuter la calibration (optimisation)
print("Exécution de la calibration (optimisation)...")
start_calibrate = time.time()
homography = model._calibrate(keypoints)
print(f"Temps de calibration : {time.time() - start_calibrate:.3f}s")
if homography is not None:
print("\n--- Homographie Calculée ---")
if isinstance(homography, torch.Tensor):
homography_np = homography.detach().cpu().numpy()
else:
homography_np = homography
print(homography_np)
if args.output_homography:
try:
np.save(args.output_homography, homography_np)
print(f"\nHomographie sauvegardée dans : {args.output_homography}")
except Exception as e:
print(f"Erreur lors de la sauvegarde de l'homographie : {e}")
# --- Extraction des données joueurs ---
print("\nExtraction des données joueurs (pose+couleur)...")
start_pose = time.time()
player_list = get_player_data(image_bgr_resized)
print(f"Temps d'extraction données joueurs : {time.time() - start_pose:.3f}s ({len(player_list)} joueurs trouvés)")
# --- Calcul de l'échelle de base estimée ---
print("\nCalcul de l'échelle de base pour atteindre la cible...")
target_average_scale = args.target_avg_scale
# Calculer la modulation moyenne attendue (hypothèse: joueur moyen au centre Y=0.5)
# Logique inversée actuelle : MIN + (MAX - MIN) * (1.0 - norm_y)
avg_modulation_expected = DYNAMIC_SCALE_MIN_MODULATION + \
(DYNAMIC_SCALE_MAX_MODULATION - DYNAMIC_SCALE_MIN_MODULATION) * (1.0 - 0.5)
estimated_base_scale = target_average_scale # Valeur par défaut si modulation = 0
if avg_modulation_expected != 0:
estimated_base_scale = target_average_scale / avg_modulation_expected
else:
print("Avertissement : Modulation moyenne attendue nulle, impossible d'estimer l'échelle de base.")
print(f" Modulation dynamique moyenne attendue (pour Y=0.5) : {avg_modulation_expected:.3f}")
print(f" Échelle de base interne estimée pour cible {target_average_scale:.3f} : {estimated_base_scale:.3f}")
# --- Génération des DEUX minimaps ---
print("\nGénération des minimaps (Originale et Squelettes Décalés)...")
# 1. Minimap avec l'image originale (RGB)
minimap_original = create_minimap_view(image_rgb_resized, homography_np)
# 2. Minimap avec les squelettes ET LE BALLON
# Utiliser l'échelle de base ESTIMÉE et passer les coordonnées du ballon
minimap_offset_skeletons, actual_avg_scale = create_minimap_with_offset_skeletons(
player_list,
homography_np,
base_skeleton_scale=estimated_base_scale, # Utiliser l'estimation
ball_ref_point_img=ball_ref_point_img # Passer le point de référence du ballon
)
# Afficher la cible et le résultat réel
if actual_avg_scale is not None:
print(f"\nÉchelle moyenne CIBLE demandée (--target_avg_scale) : {target_average_scale:.3f}")
print(f"Échelle moyenne FINALE RÉELLEMENT appliquée (basée sur joueurs réels) : {actual_avg_scale:.3f}")
# --- Affichage des résultats ---
print("\nAffichage des résultats. Appuyez sur une touche pour quitter.")
# Afficher la minimap originale
if minimap_original is not None:
cv2.imshow("Minimap avec Projection Originale", minimap_original)
else:
print("N'a pas pu générer la minimap originale.")
# Afficher la minimap avec les squelettes décalés
if minimap_offset_skeletons is not None:
cv2.imshow("Minimap avec Squelettes Decales", minimap_offset_skeletons)
else:
print("N'a pas pu générer la minimap squelettes décalés.")
cv2.waitKey(0) # Attend qu'une touche soit pressée
else:
print("\nAucune homographie n'a pu être calculée.")
except Exception as e:
print(f"Erreur lors du traitement de l'image : {e}")
traceback.print_exc()
finally:
print("Fermeture des fenêtres OpenCV.")
cv2.destroyAllWindows()
if __name__ == "__main__":
main()