"""Inference wrapper around the DeepNAPSI nail classifier.

Nails are located with ``nail_detection.main.get_nails`` and each detected
nail is classified by a ``DummyModel`` whose weights are downloaded from the
Hugging Face Hub repository ``lfolle/DeepNAPSIModel`` (unless in debug mode).
"""
import os

import cv2
import numpy as np
import torch
from huggingface_hub import hf_hub_download

from nail_detection.main import get_nails
from DummyModel import DummyModel


def load_model(DEBUG):
    """Build the classifier, loading trained weights unless in debug mode.

    Args:
        DEBUG: When truthy, skip the download and return a freshly
            initialised ``DummyModel`` (random weights).

    Returns:
        The ``DummyModel`` instance, switched to evaluation mode.

    Raises:
        KeyError: If ``DEBUG`` is falsy and the ``DeepNAPSIModel``
            environment variable (used as the Hub auth token) is not set.
    """
    model = DummyModel()
    if not DEBUG:
        # NOTE(review): `use_auth_token` is deprecated in newer
        # huggingface_hub releases in favour of `token`; kept for
        # compatibility with whatever version is pinned.
        file_path = hf_hub_download(
            "lfolle/DeepNAPSIModel",
            "dummy_model.pth",
            use_auth_token=os.environ['DeepNAPSIModel'],
        )
        # Map tensors to CPU so a checkpoint saved on a GPU still loads on a
        # CPU-only host; load_state_dict then copies them into the model.
        model.load_state_dict(torch.load(file_path, map_location="cpu"))
    # Inference only: put dropout / batch-norm layers (if any) in eval mode.
    model.eval()
    return model


class Infer:
    """Detect nails in an image and classify each detected nail."""

    def __init__(self, DEBUG):
        """Load the classifier; see ``load_model`` for the meaning of DEBUG."""
        self.model = load_model(DEBUG)

    def predict(self, data):
        """Detect nails in ``data`` and classify each one.

        Args:
            data: Image array in RGB channel order. It is converted to BGR
                before being passed to ``get_nails``. Presumably an
                (H, W, 3) uint8 array; TODO confirm against the caller.

        Returns:
            A flat list alternating ``[nail_image, label, nail_image, label,
            ...]``. If ``get_nails`` returns ``None``, the list holds five
            ``np.zeros((64, 64, 3))`` placeholders, each followed by ``-1``.
            Otherwise there is one (nail, argmax-class) pair per detected
            nail.
        """
        # get_nails presumably expects OpenCV's BGR order; verify in
        # nail_detection.
        nails = get_nails(cv2.cvtColor(data, cv2.COLOR_RGB2BGR))
        predictions = []
        if nails is None:
            # No detection: emit fixed-size blank placeholders with label -1.
            for _ in range(5):
                predictions.append(np.zeros((64, 64, 3)))
                predictions.append(-1)
        else:
            # NOTE(review): the number of pairs here follows len(nails),
            # which may differ from the 5 used above; confirm downstream
            # consumers handle that.
            with torch.no_grad():  # no autograd graph needed at inference
                for nail in nails:
                    predictions.append(nail)
                    predictions.append(int(torch.argmax(self.model(nail))))
        return predictions