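"""Player detection and pose-estimation utilities.

Detects persons with RT-DETR, estimates COCO-17 keypoints with ViTPose, and
extracts an average torso color per detected player (useful e.g. for telling
teams apart). Models are lazy-loaded on first use.
"""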
import torch
import numpy as np
import cv2
from PIL import Image
from transformers import AutoProcessor, RTDetrForObjectDetection, VitPoseForPoseEstimation
from pathlib import Path

# --- Global variables for models and processor (lazy loading) ---
person_processor = None
person_model = None
pose_processor = None
pose_model = None
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Pose Estimator: Using device: {device}")

# --- Color and drawing constants ---
# Colors are BGR tuples (OpenCV convention)
DEFAULT_MARKER_COLOR = (255, 255, 255) # White
MIN_PIXELS_FOR_COLOR = 20 # Minimum number of valid ROI pixels required to compute a color
CONFIDENCE_THRESHOLD_KEYPOINTS = 0.3 # Threshold above which a keypoint is considered reliable (for the ROI and for drawing)
SKELETON_THICKNESS = 2

# Skeleton edges (COCO keypoint indices 0-16)
# 0:Nose, 1:L_Eye, 2:R_Eye, 3:L_Ear, 4:R_Ear, 5:L_Shoulder, 6:R_Shoulder, 
# 7:L_Elbow, 8:R_Elbow, 9:L_Wrist, 10:R_Wrist, 11:L_Hip, 12:R_Hip, 
# 13:L_Knee, 14:R_Knee, 15:L_Ankle, 16:R_Ankle
SKELETON_EDGES = [
    # Head
    (0, 1), (0, 2), (1, 3), (2, 4),
    # Torso
    (5, 6), (5, 11), (6, 12), (11, 12),
    # Left arm
    (5, 7), (7, 9),
    # Right arm
    (6, 8), (8, 10),
    # Left leg
    (11, 13), (13, 15),
    # Right leg
    (12, 14), (14, 16)
]

# Keypoint indices for the torso and ankles
TORSO_KP_INDICES = [5, 6, 11, 12] # Shoulders, hips
LEFT_ANKLE_KP_INDEX = 15
RIGHT_ANKLE_KP_INDEX = 16
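
# Illustrative sketch (not part of the original module): SKELETON_EDGES and
# SKELETON_THICKNESS are defined above but never consumed in this file, so the
# helper below shows one plausible way downstream drawing code could use them.
# Only edges whose two endpoints are reliable keypoints are drawn.
def draw_skeleton(image_bgr, keypoints, scores, color=DEFAULT_MARKER_COLOR):
    """Draws the COCO skeleton on image_bgr in place (hypothetical helper)."""
    for kp_a, kp_b in SKELETON_EDGES:
        if scores[kp_a] > CONFIDENCE_THRESHOLD_KEYPOINTS and scores[kp_b] > CONFIDENCE_THRESHOLD_KEYPOINTS:
            pt_a = tuple(map(int, keypoints[kp_a]))
            pt_b = tuple(map(int, keypoints[kp_b]))
            cv2.line(image_bgr, pt_a, pt_b, color, SKELETON_THICKNESS, cv2.LINE_AA)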

def _load_models():
    """Loads the models if they haven't been loaded yet."""
    global person_processor, person_model, pose_processor, pose_model
    
    if person_processor is None:
        print("Loading RTDetr person detector model...")
        person_processor = AutoProcessor.from_pretrained("PekingU/rtdetr_r50vd_coco_o365")
        person_model = RTDetrForObjectDetection.from_pretrained("PekingU/rtdetr_r50vd_coco_o365", device_map=device)
        print("✓ RTDetr loaded.")
        
    if pose_processor is None:
        print("Loading ViTPose pose estimation model...")
        pose_processor = AutoProcessor.from_pretrained("usyd-community/vitpose-base-simple")
        pose_model = VitPoseForPoseEstimation.from_pretrained("usyd-community/vitpose-base-simple", device_map=device)
        print("✓ ViTPose loaded.")

def _is_color_greenish(bgr_pixel, threshold=10):
    """Returns True if green clearly dominates (e.g. grass/pitch pixels)."""
    b, g, r = map(int, bgr_pixel)  # cast to int to avoid uint8 overflow in the comparisons
    return g > b + threshold and g > r + threshold

def _is_color_grayscale(bgr_pixel, tolerance=30):
    """Returns True for very dark, very light, or low-saturation (gray) pixels."""
    b, g, r = map(int, bgr_pixel)  # cast to int to avoid uint8 wrap-around
    min_val, max_val = min(b, g, r), max(b, g, r)
    is_dark = max_val < 50
    is_light = min_val > 200
    is_low_saturation = (max_val - min_val) < tolerance
    return is_dark or is_light or is_low_saturation

def _get_average_color(roi_bgr):
    """Calcule la couleur moyenne d'une ROI après filtrage."""
    if roi_bgr is None or roi_bgr.size == 0:
        return DEFAULT_MARKER_COLOR

    try:
        pixels = roi_bgr.reshape(-1, 3)
        valid_pixels = []
        for pixel in pixels:
            if not _is_color_greenish(pixel) and not _is_color_grayscale(pixel):
                valid_pixels.append(pixel)
        
        if len(valid_pixels) < MIN_PIXELS_FOR_COLOR:
            return DEFAULT_MARKER_COLOR

        avg_color = np.mean(valid_pixels, axis=0)
        return tuple(map(int, avg_color))

    except Exception as e:
        print(f"  Error computing average color: {e}. Falling back to default color.")
        return DEFAULT_MARKER_COLOR
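
# Editor's sketch (not part of the original module): an equivalent vectorized
# variant of _get_average_color. It applies the same greenish/grayscale filters
# with NumPy boolean masks, avoiding the per-pixel Python loop on large ROIs.
# The 10/50/200/30 thresholds mirror the defaults of the two filters above.
def _get_average_color_vectorized(roi_bgr):
    if roi_bgr is None or roi_bgr.size == 0:
        return DEFAULT_MARKER_COLOR
    pixels = roi_bgr.reshape(-1, 3).astype(int)
    b, g, r = pixels[:, 0], pixels[:, 1], pixels[:, 2]
    greenish = (g > b + 10) & (g > r + 10)
    min_val, max_val = pixels.min(axis=1), pixels.max(axis=1)
    grayscale = (max_val < 50) | (min_val > 200) | ((max_val - min_val) < 30)
    valid = pixels[~greenish & ~grayscale]
    if len(valid) < MIN_PIXELS_FOR_COLOR:
        return DEFAULT_MARKER_COLOR
    return tuple(map(int, valid.mean(axis=0)))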

def get_player_data(image_bgr: np.ndarray) -> list:
    """
    Detects persons, estimates pose, calculates average torso color, 
    and returns a list of data for each player.
    
    Args:
        image_bgr: The input image in BGR format (NumPy array).

    Returns:
        A list of dictionaries, each containing:
        {
            'keypoints': np.ndarray (17, 2), 
            'scores': np.ndarray (17,),
            'bbox': np.ndarray (4,) [x1, y1, x2, y2],
            'avg_color': tuple (b, g, r)
        }
        Returns an empty list if no persons are detected or an error occurs.
    """
    _load_models() 
    player_list = []
    height, width = image_bgr.shape[:2]

    try:
        image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
        image_pil = Image.fromarray(image_rgb)

        # --- Stage 1: Detect humans ---
        inputs_det = person_processor(images=image_pil, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs_det = person_model(**inputs_det)
        results_det = person_processor.post_process_object_detection(
            outputs_det, target_sizes=torch.tensor([(height, width)]), threshold=0.5
        )
        result_det = results_det[0]
        # Label 0 corresponds to "person" for this RT-DETR checkpoint
        person_boxes = result_det["boxes"][result_det["labels"] == 0].cpu().numpy()

        if len(person_boxes) == 0:
            print("No persons detected.")
            return player_list

        # Convert boxes from (x1, y1, x2, y2) to COCO (x, y, w, h), as expected by the pose processor
        person_boxes_coco = person_boxes.copy()
        person_boxes_coco[:, 2] = person_boxes_coco[:, 2] - person_boxes_coco[:, 0]
        person_boxes_coco[:, 3] = person_boxes_coco[:, 3] - person_boxes_coco[:, 1]

        # --- Stage 2: Detect keypoints ---
        inputs_pose = pose_processor(image_pil, boxes=[person_boxes_coco], return_tensors="pt").to(device)
        with torch.no_grad():
            outputs_pose = pose_model(**inputs_pose)
        pose_results = pose_processor.post_process_pose_estimation(outputs_pose, boxes=[person_boxes_coco])
        image_pose_result = pose_results[0] 

        if not image_pose_result:
            print("Pose estimation did not return results.")
            return player_list
        
        # --- Stage 3: Process each person --- 
        for i, person_box_xyxy in enumerate(person_boxes):
            if i >= len(image_pose_result):
                continue
            
            pose_result = image_pose_result[i]
            xy = pose_result['keypoints'].cpu().numpy()
            scores = pose_result['scores'].cpu().numpy()

            # Ensure xy shape is correct before proceeding
            if xy.shape != (17, 2):
                print(f"Person {i}: Unexpected keypoints shape {xy.shape}, skipping.")
                continue

            # -- Define Torso ROI --
            reliable_torso_keypoints = xy[TORSO_KP_INDICES][scores[TORSO_KP_INDICES] > CONFIDENCE_THRESHOLD_KEYPOINTS]
            x1_box, y1_box, x2_box, y2_box = map(int, person_box_xyxy)
            box_h = y2_box - y1_box
            box_w = x2_box - x1_box
            if len(reliable_torso_keypoints) >= 3:
                min_x_kp = int(np.min(reliable_torso_keypoints[:, 0]))
                max_x_kp = int(np.max(reliable_torso_keypoints[:, 0]))
                min_y_kp = int(np.min(reliable_torso_keypoints[:, 1]))
                max_y_kp = int(np.max(reliable_torso_keypoints[:, 1]))
                roi_x1 = max(x1_box, min_x_kp - 5)
                roi_y1 = max(y1_box, min_y_kp - 5)
                roi_x2 = min(x2_box, max_x_kp + 5)
                roi_y2 = min(y2_box, max_y_kp + 5)
            else:
                # Fallback: too few reliable torso keypoints, use the upper part of the bbox
                roi_x1 = x1_box
                roi_y1 = y1_box + int(0.1 * box_h)
                roi_x2 = x2_box
                roi_y2 = y1_box + int(0.6 * box_h)
            # Clamp the ROI to the image bounds
            roi_x1 = max(0, roi_x1)
            roi_y1 = max(0, roi_y1)
            roi_x2 = min(width, roi_x2)
            roi_y2 = min(height, roi_y2)

            # -- Extract Average Color --
            avg_color = DEFAULT_MARKER_COLOR
            if roi_y2 > roi_y1 and roi_x2 > roi_x1: 
                torso_roi = image_bgr[roi_y1:roi_y2, roi_x1:roi_x2]
                avg_color = _get_average_color(torso_roi)
            # If the ROI is invalid, the default marker color is kept; no message needed.
            
            # -- Store player data --
            player_data = {
                'keypoints': xy,
                'scores': scores,
                'bbox': person_box_xyxy, # Keep the original xyxy bbox
                'avg_color': avg_color
            }
            player_list.append(player_data)

    except Exception as e:
        print(f"Error during player data extraction: {e}")
        import traceback
        traceback.print_exc()
        # Return an empty list on a major error
        return []

    return player_list

# Example usage (optional, for testing the module directly)
if __name__ == '__main__':
    test_image_path = 'img3.png' 
    
    if not Path(test_image_path).exists():
        print(f"Test image not found: {test_image_path}")
    else:
        print(f"Testing player data extraction with image: {test_image_path}")
        input_img = cv2.imread(test_image_path)
        
        if input_img is None:
            print(f"Failed to load test image: {test_image_path}")
        else:
            print("Getting player data...")
            players = get_player_data(input_img)
            print(f"✓ Found data for {len(players)} players.")

            # --- Draw markers and info on original image for testing --- 
            output_img_test = input_img.copy()
            for idx, p_data in enumerate(players):
                kps = p_data['keypoints']
                scores = p_data['scores']
                bbox = p_data['bbox']
                color = p_data['avg_color']

                # Determine reference point (ankles or bbox bottom mid)
                l_ankle_pt = kps[LEFT_ANKLE_KP_INDEX]
                r_ankle_pt = kps[RIGHT_ANKLE_KP_INDEX]
                l_ankle_score = scores[LEFT_ANKLE_KP_INDEX]
                r_ankle_score = scores[RIGHT_ANKLE_KP_INDEX]
                
                ref_point = None
                if l_ankle_score > CONFIDENCE_THRESHOLD_KEYPOINTS and r_ankle_score > CONFIDENCE_THRESHOLD_KEYPOINTS:
                    ref_point = tuple(map(int, (l_ankle_pt + r_ankle_pt) / 2))
                elif l_ankle_score > CONFIDENCE_THRESHOLD_KEYPOINTS:
                    ref_point = tuple(map(int, l_ankle_pt))
                elif r_ankle_score > CONFIDENCE_THRESHOLD_KEYPOINTS:
                    ref_point = tuple(map(int, r_ankle_pt))
                else:
                    x1, y1, x2, y2 = map(int, bbox)
                    ref_point = (int((x1 + x2) / 2), y2) 

                # Draw marker at reference point
                if ref_point:
                    cv2.circle(output_img_test, ref_point, 8, color, -1, cv2.LINE_AA)
                    cv2.circle(output_img_test, ref_point, 8, (0,0,0), 1, cv2.LINE_AA) # Black outline
                    # Draw player index
                    cv2.putText(output_img_test, str(idx), (ref_point[0]+5, ref_point[1]-5), 
                                cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0,0,0), 2, cv2.LINE_AA)
                    cv2.putText(output_img_test, str(idx), (ref_point[0]+5, ref_point[1]-5), 
                                cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255,255,255), 1, cv2.LINE_AA)
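
                # Optionally overlay the skeleton as well, via the illustrative
                # draw_skeleton helper sketched above:
                # draw_skeleton(output_img_test, kps, scores, color)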
            
            cv2.imshow("Original Image", input_img)
            cv2.imshow("Player Markers Test", output_img_test)
            print("Displaying test results. Press any key to exit.")
            cv2.waitKey(0)
            cv2.destroyAllWindows()