# DeepNAPSI / backend.py
# Author: Folle, Lukas
# Commit 14e27af: "Added first working version of application, NAPSI is dummy."
import os

import torch
import cv2
import numpy as np
from huggingface_hub import hf_hub_download

from nail_detection.main import get_nails
from DummyModel import DummyModel
def load_model(DEBUG):
    """Build the NAPSI model, optionally loading its weights from the Hub.

    Args:
        DEBUG: When truthy, skip the download and return a freshly
            constructed ``DummyModel``. When falsy, download
            ``dummy_model.pth`` from the ``lfolle/DeepNAPSIModel`` Hub repo
            and load it into the model as a state dict.

    Returns:
        The ``DummyModel`` instance, switched to evaluation mode.

    Raises:
        KeyError: If ``DEBUG`` is falsy and the ``DeepNAPSIModel``
            environment variable (used as the Hub auth token) is not set.
    """
    model = DummyModel()
    if not DEBUG:
        file_path = hf_hub_download(
            "lfolle/DeepNAPSIModel",
            "dummy_model.pth",
            use_auth_token=os.environ["DeepNAPSIModel"],
        )
        # map_location="cpu" lets a checkpoint saved on a GPU machine load on
        # a CPU-only host; nothing in this module moves the model to a GPU.
        model.load_state_dict(torch.load(file_path, map_location="cpu"))
    # The model is only used for inference (see Infer.predict). eval() turns
    # off training-only layers such as dropout; this presumes DummyModel is a
    # torch.nn.Module, as its use of load_state_dict suggests.
    model.eval()
    return model
class Infer:
    """Detect nails in an image and predict a class index for each one.

    The model is built once, at construction time, via ``load_model``.
    """

    def __init__(self, DEBUG):
        """Create the predictor.

        Args:
            DEBUG: Forwarded unchanged to ``load_model``; when truthy the
                Hub weight download is skipped.
        """
        self.model = load_model(DEBUG)

    def predict(self, data):
        """Run nail detection on ``data`` and score every detected nail.

        Args:
            data: An image in RGB channel order; it is converted to BGR
                before being passed to ``get_nails``. Presumably an
                ``HxWx3`` array (TODO confirm against the caller).

        Returns:
            A flat list alternating ``[nail_0, pred_0, nail_1, pred_1, ...]``,
            where each ``pred_i`` is ``int(torch.argmax(model(nail_i)))``.
            If ``get_nails`` returns ``None``, five placeholder pairs of
            ``np.zeros((64, 64, 3))`` and ``-1`` are returned instead.
        """
        nails = get_nails(cv2.cvtColor(data, cv2.COLOR_RGB2BGR))
        predictions = []
        if nails is None:
            # Fixed-size placeholder output. The caller presumably expects
            # five image/score pairs (TODO confirm).
            for _ in range(5):
                predictions.extend((np.zeros((64, 64, 3)), -1))
            return predictions
        # Pure inference: don't build an autograd graph for every nail.
        with torch.no_grad():
            for nail in nails:
                predictions.extend((nail, int(torch.argmax(self.model(nail)))))
        return predictions