# team-classifier / inference.py
# 7gonzalodm's picture
# Update inference.py
# 0d964c1 verified
import torch
import numpy as np
from sports.common.team import TeamClassifier
# Load the trained TeamClassifier from file.
#
# The checkpoint is used as a whole pickled object (its ``predict`` method is
# called below), not as a plain state_dict of tensors. Since PyTorch 2.6,
# ``torch.load`` defaults to ``weights_only=True``, which refuses to unpickle
# arbitrary Python objects, so the flag is passed explicitly here.
# (The ``weights_only`` keyword requires PyTorch >= 1.13.)
#
# NOTE(security): ``weights_only=False`` performs full unpickling, which can
# execute arbitrary code. Only load checkpoints obtained from a trusted source.
model = torch.load(
    "team_classifier.pth",
    map_location="cpu",
    weights_only=False,
)
# Example: assign a team to every player crop in a list (numpy arrays)
def predict_teams(crops):
    """Run the loaded team classifier on a batch of player crops.

    Args:
        crops (List[np.ndarray]): Player crops, one numpy array per player.

    Returns:
        np.ndarray: Predicted team labels (0 or 1), exactly as returned by
            the module-level ``model.predict``.
    """
    team_labels = model.predict(crops)
    return team_labels
if __name__ == "__main__":
    # Demonstration only: an all-zero 224x224, 3-channel uint8 array stands in
    # for a real player crop. Replace it with your own image loading logic.
    dummy_crop = np.zeros(shape=(224, 224, 3), dtype=np.uint8)

    # The classifier works on a list of crops, so wrap the single example.
    crops = [dummy_crop]
    preds = predict_teams(crops)

    print("Predicted team labels:", preds)