Spaces:
Sleeping
Sleeping
import os

import torch
import cv2
import numpy as np
from huggingface_hub import hf_hub_download

from nail_detection.main import get_nails
from DummyModel import DummyModel
def load_model(DEBUG):
    """Build a ``DummyModel`` and, unless debugging, load its trained weights.

    Args:
        DEBUG: When truthy, skip the Hugging Face Hub download and use the
            freshly constructed ``DummyModel`` as-is.

    Returns:
        The ``DummyModel`` instance, switched to evaluation mode.

    Raises:
        KeyError: If ``DEBUG`` is falsy and the ``DeepNAPSIModel`` environment
            variable (passed to the Hub as the access token) is not set.
    """
    model = DummyModel()
    if not DEBUG:
        # The access token for the "lfolle/DeepNAPSIModel" repo is read from the
        # ``DeepNAPSIModel`` environment variable.
        # NOTE(review): newer huggingface_hub releases deprecate ``use_auth_token``
        # in favour of ``token=``; kept here for compatibility with the installed
        # version — confirm before switching.
        file_path = hf_hub_download(
            "lfolle/DeepNAPSIModel",
            "dummy_model.pth",
            use_auth_token=os.environ["DeepNAPSIModel"],
        )
        # The model above is built on CPU, so load the checkpoint onto CPU too:
        # this also lets a checkpoint saved on a GPU load on CPU-only hosts.
        # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
        state_dict = torch.load(file_path, map_location="cpu")
        model.load_state_dict(state_dict)
    # The model is only used for inference; put layers such as dropout /
    # batch-norm into their evaluation behaviour.
    model.eval()
    return model
class Infer:
    """Detect nails in an image and classify each one with the loaded model."""

    # Number of placeholder (image, label) pairs emitted when get_nails returns None.
    # NOTE(review): 5 presumably means one entry per finger of a hand — confirm.
    _NUM_PLACEHOLDERS = 5
    # Height and width of each blank (zero-filled, 3-channel) placeholder image.
    _PLACEHOLDER_SIZE = 64
    # Label paired with a placeholder image to mark "no prediction".
    _NO_PREDICTION = -1

    def __init__(self, DEBUG):
        """Create the predictor.

        Args:
            DEBUG: Forwarded to ``load_model``; when truthy the trained weights
                are not downloaded.
        """
        self.model = load_model(DEBUG)

    def predict(self, data):
        """Detect nails in ``data`` and classify each detected nail.

        Args:
            data: Image converted with ``cv2.COLOR_RGB2BGR`` before detection,
                i.e. expected in RGB channel order.

        Returns:
            A flat list alternating ``[image, label, image, label, ...]``.
            For each nail returned by ``get_nails``, the image is that nail
            object and the label is the ``int`` argmax of the model's output.
            If ``get_nails`` returns ``None``, five ``(64, 64, 3)`` zero
            images, each paired with label ``-1``, are returned instead.
        """
        nails = get_nails(cv2.cvtColor(data, cv2.COLOR_RGB2BGR))
        predictions = []
        if nails is None:
            size = self._PLACEHOLDER_SIZE
            for _ in range(self._NUM_PLACEHOLDERS):
                # A fresh array per slot so callers can mutate them independently.
                predictions.append(np.zeros((size, size, 3)))
                predictions.append(self._NO_PREDICTION)
            return predictions
        # Pure inference: skip autograd bookkeeping on every forward pass.
        # NOTE(review): each nail is passed to the model exactly as get_nails
        # returns it; any array -> tensor conversion presumably happens inside
        # DummyModel — confirm.
        with torch.no_grad():
            for nail in nails:
                predictions.append(nail)
                predictions.append(int(torch.argmax(self.model(nail))))
        return predictions