File size: 12,054 Bytes
bdb955e
 
 
 
 
 
 
41f1119
 
bdb955e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41f1119
 
 
 
bdb955e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41f1119
 
 
 
 
 
bdb955e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41f1119
 
 
 
 
 
 
 
bdb955e
 
 
41f1119
bdb955e
41f1119
bdb955e
 
 
 
41f1119
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bdb955e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41f1119
 
bdb955e
 
 
41f1119
 
bdb955e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
import argparse
import cv2
import numpy as np
import torch
from pathlib import Path
import time
import traceback
# Import YOLO
from ultralytics import YOLO

# Assurez-vous que le répertoire tvcalib est dans le PYTHONPATH
# ou exécutez depuis le répertoire tvcalib_image_processor
from tvcalib.infer.module import TvCalibInferModule
# Importer les fonctions de visualisation et les constantes de modulation
from visualizer import (
    create_minimap_view, 
    create_minimap_with_offset_skeletons,
    DYNAMIC_SCALE_MIN_MODULATION, # Importer les constantes
    DYNAMIC_SCALE_MAX_MODULATION
)
# Importer la fonction d'extraction des données joueurs
from pose_estimator import get_player_data

# Constantes
IMAGE_SHAPE = (720, 1280)  # Hauteur, Largeur
SEGMENTATION_MODEL_PATH = Path("models/segmentation/train_59.pt")
# Chemin vers le modèle YOLO pour la détection du ballon
YOLO_MODEL_PATH = Path("models/detection/yolo_football.pt") 
# Index de classe pour le ballon (basé sur votre exemple)
BALL_CLASS_INDEX = 2 

def preprocess_image_tvcalib(image_bgr):
    """Prétraite l'image BGR pour TvCalib et retourne le tenseur et l'image RGB redimensionnée."""
    if image_bgr is None:
        raise ValueError("Impossible de charger l'image")

    # 1. Redimensionner en 720p si nécessaire
    h, w = image_bgr.shape[:2]
    if h != IMAGE_SHAPE[0] or w != IMAGE_SHAPE[1]:
        print(f"Redimensionnement de l'image vers {IMAGE_SHAPE[1]}x{IMAGE_SHAPE[0]}")
        image_bgr_resized = cv2.resize(image_bgr, (IMAGE_SHAPE[1], IMAGE_SHAPE[0]), interpolation=cv2.INTER_LINEAR)
    else:
        image_bgr_resized = image_bgr

    # 2. Convertir en RGB (pour TvCalib ET pour la visualisation originale)
    image_rgb_resized = cv2.cvtColor(image_bgr_resized, cv2.COLOR_BGR2RGB)

    # 3. Normalisation spécifique pour le modèle pré-entraîné (pour TvCalib)
    image_tensor = torch.from_numpy(image_rgb_resized).float()
    image_tensor = image_tensor.permute(2, 0, 1)  # HWC -> CHW
    mean = torch.tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)
    image_tensor = (image_tensor / 255.0 - mean) / std

    # 4. Déplacer sur le bon device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    image_tensor = image_tensor.to(device)

    # Retourner le tenseur pour TvCalib, l'image BGR et RGB redimensionnée
    return image_tensor, image_bgr_resized, image_rgb_resized

def main():
    parser = argparse.ArgumentParser(description="Exécute la méthode TvCalib sur une seule image.")
    parser.add_argument("image_path", type=str, help="Chemin vers l'image à traiter.")
    parser.add_argument("--output_homography", type=str, default=None, help="Chemin optionnel pour sauvegarder la matrice d'homographie (.npy).")
    parser.add_argument("--optim_steps", type=int, default=500, help="Nombre d'étapes d'optimisation pour la calibration (l'arrêt anticipé est désactivé).")
    parser.add_argument("--target_avg_scale", type=float, default=1, 
                        help="Facteur d'échelle MOYEN CIBLE pour dessiner les squelettes sur la minimap (défaut: 0.35). Le script ajuste l'échelle de base pour tenter d'atteindre cette moyenne.")

    args = parser.parse_args()

    if not Path(args.image_path).exists():
        print(f"Erreur : Fichier image introuvable : {args.image_path}")
        return

    if not SEGMENTATION_MODEL_PATH.exists():
        print(f"Erreur : Modèle de segmentation introuvable : {SEGMENTATION_MODEL_PATH}")
        print("Assurez-vous d'avoir copié train_59.pt dans le dossier models/segmentation/")
        return

    # Vérifier l'existence du modèle YOLO
    if not YOLO_MODEL_PATH.exists():
        print(f"Erreur : Modèle YOLO introuvable : {YOLO_MODEL_PATH}")
        print(f"Assurez-vous d'avoir téléchargé {YOLO_MODEL_PATH.name} et de l'avoir placé dans {YOLO_MODEL_PATH.parent}/")
        return

    print("Initialisation de TvCalibInferModule...")
    try:
        model = TvCalibInferModule(
            segmentation_checkpoint=SEGMENTATION_MODEL_PATH,
            image_shape=IMAGE_SHAPE,
            optim_steps=args.optim_steps,
            lens_dist=False  # Gardons cela simple pour l'instant
        )
        print(f"✓ Modèle chargé sur {next(model.model_calib.parameters()).device}")
    except Exception as e:
        print(f"Erreur lors de l'initialisation du modèle : {e}")
        return

    print(f"Traitement de l'image : {args.image_path}")
    try:
        # Vérification supplémentaire avant imread
        image_path_obj = Path(args.image_path)
        absolute_path = image_path_obj.resolve()
        print(f"Tentative de lecture de l'image via cv2.imread depuis : {absolute_path}")
        if not absolute_path.is_file():
             print(f"ERREUR : Le chemin absolu {absolute_path} ne pointe pas vers un fichier existant juste avant imread !")
             return # Arrêter ici si le fichier n'est pas trouvé à ce stade

        # Charger l'image (en BGR par défaut avec OpenCV)
        image_bgr_orig = cv2.imread(args.image_path)
        if image_bgr_orig is None:
            raise FileNotFoundError(f"Impossible de lire le fichier image: {args.image_path} (vérifié comme existant juste avant, problème avec imread)")

        # Prétraiter l'image pour TvCalib (redimensionne aussi)
        start_preprocess = time.time()
        image_tensor, image_bgr_resized, image_rgb_resized = preprocess_image_tvcalib(image_bgr_orig)
        print(f"Temps de prétraitement TvCalib : {time.time() - start_preprocess:.3f}s")

        # --- Détection du ballon avec YOLO ---
        print("\nChargement du modèle YOLO et détection du ballon...")
        start_yolo = time.time()
        ball_ref_point_img = None # Point de référence du ballon sur l'image originale redimensionnée
        try:
            yolo_model = YOLO(YOLO_MODEL_PATH)
            # Utiliser l'image BGR redimensionnée pour YOLO
            results = yolo_model.predict(image_bgr_resized, classes=[BALL_CLASS_INDEX], verbose=False) 
            
            if results and len(results[0].boxes) > 0:
                # Prendre la détection avec la plus haute confiance
                best_ball_box = results[0].boxes[results[0].boxes.conf.argmax()]
                x1, y1, x2, y2 = map(int, best_ball_box.xyxy[0].tolist())
                conf = best_ball_box.conf[0].item()
                
                # Calculer le point de référence (centre bas de la bbox)
                ball_ref_point_img = np.array([(x1 + x2) / 2, y2], dtype=np.float32)
                print(f"  ✓ Ballon trouvé (conf: {conf:.2f}) à la bbox [{x1},{y1},{x2},{y2}]. Point réf: {ball_ref_point_img}")
            else:
                print("  Aucun ballon détecté.")
                
        except Exception as e_yolo:
             print(f"  Erreur pendant la détection YOLO : {e_yolo}")
        print(f"Temps de détection YOLO : {time.time() - start_yolo:.3f}s")

        # Exécuter la segmentation
        print("Exécution de la segmentation...")
        start_segment = time.time()
        with torch.no_grad():
            keypoints = model._segment(image_tensor)
        print(f"Temps de segmentation : {time.time() - start_segment:.3f}s")

        # Exécuter la calibration (optimisation)
        print("Exécution de la calibration (optimisation)...")
        start_calibrate = time.time()
        homography = model._calibrate(keypoints)
        print(f"Temps de calibration : {time.time() - start_calibrate:.3f}s")

        if homography is not None:
            print("\n--- Homographie Calculée ---")
            if isinstance(homography, torch.Tensor):
                homography_np = homography.detach().cpu().numpy()
            else:
                homography_np = homography
            print(homography_np)

            if args.output_homography:
                try:
                    np.save(args.output_homography, homography_np)
                    print(f"\nHomographie sauvegardée dans : {args.output_homography}")
                except Exception as e:
                    print(f"Erreur lors de la sauvegarde de l'homographie : {e}")

            # --- Extraction des données joueurs --- 
            print("\nExtraction des données joueurs (pose+couleur)...")
            start_pose = time.time()
            player_list = get_player_data(image_bgr_resized) 
            print(f"Temps d'extraction données joueurs : {time.time() - start_pose:.3f}s ({len(player_list)} joueurs trouvés)")
            
            # --- Calcul de l'échelle de base estimée --- 
            print("\nCalcul de l'échelle de base pour atteindre la cible...")
            target_average_scale = args.target_avg_scale
            
            # Calculer la modulation moyenne attendue (hypothèse: joueur moyen au centre Y=0.5)
            # Logique inversée actuelle : MIN + (MAX - MIN) * (1.0 - norm_y)
            avg_modulation_expected = DYNAMIC_SCALE_MIN_MODULATION + \
                                      (DYNAMIC_SCALE_MAX_MODULATION - DYNAMIC_SCALE_MIN_MODULATION) * (1.0 - 0.5)
            
            estimated_base_scale = target_average_scale # Valeur par défaut si modulation = 0
            if avg_modulation_expected != 0:
                estimated_base_scale = target_average_scale / avg_modulation_expected
            else:
                print("Avertissement : Modulation moyenne attendue nulle, impossible d'estimer l'échelle de base.")
                
            print(f"  Modulation dynamique moyenne attendue (pour Y=0.5) : {avg_modulation_expected:.3f}")
            print(f"  Échelle de base interne estimée pour cible {target_average_scale:.3f} : {estimated_base_scale:.3f}")

            # --- Génération des DEUX minimaps --- 
            print("\nGénération des minimaps (Originale et Squelettes Décalés)...")
            
            # 1. Minimap avec l'image originale (RGB)
            minimap_original = create_minimap_view(image_rgb_resized, homography_np)
            
            # 2. Minimap avec les squelettes ET LE BALLON
            # Utiliser l'échelle de base ESTIMÉE et passer les coordonnées du ballon
            minimap_offset_skeletons, actual_avg_scale = create_minimap_with_offset_skeletons(
                player_list, 
                homography_np, 
                base_skeleton_scale=estimated_base_scale, # Utiliser l'estimation
                ball_ref_point_img=ball_ref_point_img     # Passer le point de référence du ballon
            )

            # Afficher la cible et le résultat réel
            if actual_avg_scale is not None:
                print(f"\nÉchelle moyenne CIBLE demandée (--target_avg_scale) : {target_average_scale:.3f}")
                print(f"Échelle moyenne FINALE RÉELLEMENT appliquée (basée sur joueurs réels) : {actual_avg_scale:.3f}")
            
            # --- Affichage des résultats --- 
            print("\nAffichage des résultats. Appuyez sur une touche pour quitter.")
            
            # Afficher la minimap originale
            if minimap_original is not None:
                cv2.imshow("Minimap avec Projection Originale", minimap_original)
            else:
                 print("N'a pas pu générer la minimap originale.")

            # Afficher la minimap avec les squelettes décalés
            if minimap_offset_skeletons is not None:
                cv2.imshow("Minimap avec Squelettes Decales", minimap_offset_skeletons)
            else:
                 print("N'a pas pu générer la minimap squelettes décalés.")

            cv2.waitKey(0) # Attend qu'une touche soit pressée

        else:
            print("\nAucune homographie n'a pu être calculée.")

    except Exception as e:
        print(f"Erreur lors du traitement de l'image : {e}")
        traceback.print_exc()
    finally:
        print("Fermeture des fenêtres OpenCV.")
        cv2.destroyAllWindows()

if __name__ == "__main__":
    main()