| from pathlib import Path
|
| from typing import List, Tuple, Dict
|
| import sys
|
| import os
|
|
|
| from numpy import ndarray
|
| from pydantic import BaseModel
|
| sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
| from keypoint_helper import run_keypoints_post_processing
|
|
|
| from ultralytics import YOLO
|
| from team_cluster import TeamClassifier
|
| from utils import (
|
| BoundingBox,
|
| Constants,
|
| )
|
| from inference import predict_batch
|
| import time
|
| import torch
|
| import gc
|
| from pitch import process_batch_input, get_cls_net
|
| import yaml
|
# Select PyTorch's "expandable_segments" mode for the CUDA caching allocator
# (overrides any value already present in the environment).
# NOTE(review): this runs after ``import torch``; it only takes effect if no
# import above has already initialised the CUDA allocator — confirm.
os.environ.update({"PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True"})
|
|
|
|
|
class BoundingBox(BaseModel):
    """A single detection: integer box corners, a class id and a confidence.

    This definition rebinds the module-level name ``BoundingBox``, shadowing
    the ``BoundingBox`` imported from ``utils`` at the top of this module.
    Field order is significant (pydantic serializes fields in declaration
    order), so it must stay x1, y1, x2, y2, cls_id, conf.
    """

    # Box corners. Presumably (x1, y1) is top-left and (x2, y2) bottom-right,
    # in pixels — confirm against the producer of these boxes.
    x1: int
    y1: int
    x2: int
    y2: int

    # Detector class index (presumably; mapping defined elsewhere).
    cls_id: int
    # Detection confidence score (range not enforced here).
    conf: float
|
|
|
|
|
class TVFrameResult(BaseModel):
    """Inference output for one frame: its id, detected boxes and keypoints."""

    # Identifier of the frame this result belongs to.
    frame_id: int
    # Detections for the frame.
    boxes: List[BoundingBox]
    # Integer 2-D points; presumably (x, y) pixel positions — confirm.
    keypoints: List[Tuple[int, int]]
|
|
|
|
|
class Miner:
    """Hold every model needed for batched frame inference.

    ``__init__`` loads the models from a local repository directory and never
    raises: any exception is caught, printed, and recorded in ``self.health``.
    ``self.health == "healthy"`` means every model loaded successfully.
    ``predict_batch`` delegates the work to ``inference.predict_batch``.
    """

    # Values re-exported from ``utils.Constants``. Nothing in this class reads
    # them; presumably other code accesses them via ``Miner`` — verify before
    # removing.
    SMALL_CONTAINED_IOA = Constants.SMALL_CONTAINED_IOA
    SMALL_RATIO_MAX = Constants.SMALL_RATIO_MAX
    SINGLE_PLAYER_HUE_PIVOT = Constants.SINGLE_PLAYER_HUE_PIVOT
    CORNER_INDICES = Constants.CORNER_INDICES
    KEYPOINTS_CONFIDENCE = Constants.KEYPOINTS_CONFIDENCE
    CORNER_CONFIDENCE = Constants.CORNER_CONFIDENCE
    GOALKEEPER_POSITION_MARGIN = Constants.GOALKEEPER_POSITION_MARGIN

    # Sample-count bounds for fitting the team classifier (not read in this
    # class; presumably used by callers).
    MIN_SAMPLES_FOR_FIT = 16
    MAX_SAMPLES_FOR_FIT = 600

    def __init__(self, path_hf_repo: Path) -> None:
        """Load all models from ``path_hf_repo``.

        Args:
            path_hf_repo: Directory containing ``football_object_detection.onnx``,
                ``ball_detection.pt``, ``osnet_model.pth.tar-100``, ``keypoint``
                and ``hrnetv2_w48.yaml``.

        Never raises. On failure ``self.health`` is set to a message starting
        with ``"❌ Miner initialization failed: "`` and the models that had not
        been loaded yet are left unset.
        """
        try:
            device = "cuda" if torch.cuda.is_available() else "cpu"
            self._load_bbox_model(path_hf_repo)
            self._load_ball_model(path_hf_repo, device)
            self._load_team_classifier(path_hf_repo, device)
            self._load_keypoints_model(path_hf_repo, device)
            self.health = "healthy"
            print("✅ Keypoints Model Loaded")
        except Exception as e:
            # Deliberately best-effort: record the failure instead of raising so
            # the caller can inspect ``health`` / ``repr(miner)``.
            self.health = "❌ Miner initialization failed: " + str(e)
            print(self.health)

    def _load_bbox_model(self, path_hf_repo: Path) -> None:
        """Load the ONNX detection model into ``self.bbox_model``."""
        # NOTE(review): unlike the ball model there is no explicit ``.to(device)``
        # here; presumably the ONNX backend picks the device — confirm.
        self.bbox_model = YOLO(path_hf_repo / "football_object_detection.onnx")
        print("BBox Model Loaded")

    def _load_ball_model(self, path_hf_repo: Path, device: str) -> None:
        """Load ``ball_detection.pt`` onto ``device`` and run 3 warm-up passes."""
        self.ball_model = YOLO(path_hf_repo / "ball_detection.pt")
        self.ball_model.to(device)
        # Warm-up on an all-zero (16, 3, 1024, 1024) batch, presumably to absorb
        # one-time setup cost before real inference. The tensor is built once
        # and, being local to this helper, is released when it returns rather
        # than staying alive while the remaining models load.
        dummy_input = torch.zeros(16, 3, 1024, 1024, device=device)
        for _ in range(3):
            self.ball_model(dummy_input)
        print("Ball Model Loaded")

    def _load_team_classifier(self, path_hf_repo: Path, device: str) -> None:
        """Create the team classifier and reset its fitting state."""
        team_model_path = path_hf_repo / "osnet_model.pth.tar-100"
        self.team_classifier = TeamClassifier(
            device=device,
            batch_size=32,
            model_name=str(team_model_path),
        )
        print("Team Classifier Loaded")
        # Fitting state; not read in this class (presumably used by callers).
        self.team_classifier_fitted = False
        self.player_crops_for_fit = []

    def _load_keypoints_model(self, path_hf_repo: Path, device: str) -> None:
        """Build the keypoint network from its YAML config and load its weights."""
        config_kp_path = path_hf_repo / "hrnetv2_w48.yaml"
        model_kp_path = path_hf_repo / "keypoint"
        # ``with`` closes the config file; the original left the handle open.
        with open(config_kp_path, "r") as config_file:
            cfg_kp = yaml.safe_load(config_file)

        # NOTE(review): ``torch.load`` may unpickle arbitrary objects; only load
        # checkpoints from a trusted source (consider ``weights_only=True``).
        loaded_state_kp = torch.load(model_kp_path, map_location=device)
        model = get_cls_net(cfg_kp)
        model.load_state_dict(loaded_state_kp)
        model.to(device)
        model.eval()

        self.keypoints_model = model
        # Settings forwarded to ``inference.predict_batch`` by ``predict_batch``.
        self.kp_threshold = 0.1
        self.pitch_batch_size = 4

    def __repr__(self) -> str:
        """Summarise the loaded models, or return the failure message."""
        if self.health == 'healthy':
            return (
                f"health: {self.health}\n"
                f"BBox Model: {type(self.bbox_model).__name__}\n"
                f"Keypoints Model: {type(self.keypoints_model).__name__}"
            )
        return self.health

    def predict_batch(self, batch_images: List[ndarray], offset: int, n_keypoints: int) -> List[TVFrameResult]:
        """Run inference on a batch of frames.

        Args:
            batch_images: Frames to process.
            offset: Forwarded unchanged (presumably the frame index of
                ``batch_images[0]`` — confirm in ``inference.predict_batch``).
            n_keypoints: Forwarded unchanged.

        Returns:
            The result of ``inference.predict_batch`` (annotated as
            ``List[TVFrameResult]``).

        Raises:
            RuntimeError: If initialization failed. Previously this surfaced
                as an opaque ``AttributeError`` on a missing model attribute.
        """
        if self.health != "healthy":
            raise RuntimeError(f"Miner is not ready: {self.health}")
        # Inside a method, the bare name ``predict_batch`` resolves to the
        # module-level function imported from ``inference`` (class scope is not
        # searched), so this is not a recursive call.
        return predict_batch(
            self.bbox_model,
            self.ball_model,
            self.team_classifier,
            self.keypoints_model,
            batch_images,
            offset,
            n_keypoints,
            self.pitch_batch_size,
            self.kp_threshold,
        )