|
import torch |
|
import numpy as np |
|
from sports.common.team import TeamClassifier |
|
|
|
|
|
# Load the team classifier once at import time; ``predict_teams`` reads this
# module-level ``model``.
# ``weights_only=False`` is required: the checkpoint holds a full pickled
# classifier object (it exposes ``.predict``), not just tensors, and since
# PyTorch 2.6 ``torch.load`` defaults to ``weights_only=True``, which refuses to
# unpickle arbitrary classes.
# SECURITY: unpickling can execute arbitrary code -- only load
# ``team_classifier.pth`` from a trusted source.
model = torch.load("team_classifier.pth", map_location="cpu", weights_only=False)
|
|
|
|
|
def predict_teams(crops):
    """Assign a team label to each player crop.

    Forwards ``crops`` unchanged to the ``predict`` method of the module-level
    ``model`` (loaded from ``team_classifier.pth``) and returns its result.

    Args:
        crops (List[np.ndarray]): Player image crops, one numpy array per player.

    Returns:
        np.ndarray: Predicted team label for each crop (0 or 1).
    """
    classifier = model
    labels = classifier.predict(crops)
    return labels
|
|
|
if __name__ == "__main__":
    # Smoke test: push a single all-zero 224x224 3-channel uint8 image
    # through the classifier and show the resulting label(s).
    blank_image = np.zeros((224, 224, 3), dtype=np.uint8)
    team_labels = predict_teams([blank_image])
    print("Predicted team labels:", team_labels)
|
|