DeepNAPSI / backend.py
lfolle's picture
Added napsi sum, small refactoring.
761b08f
raw
history blame
No virus
957 Bytes
import torch
import cv2
import numpy as np
from huggingface_hub import hf_hub_download
from nail_detection.main import get_nails
from DummyModel import load_dummy_model
class Infer:
    """Run nail detection followed by per-nail classification on one image.

    Wraps the model returned by ``load_dummy_model`` and flattens its output
    into a list: a total score followed by (nail, score) pairs.
    """

    def __init__(self, DEBUG):
        """Load the classification model.

        Args:
            DEBUG: Forwarded unchanged to ``load_dummy_model``.
        """
        self.model = load_dummy_model(DEBUG)

    def predict(self, data):
        """Detect nails in ``data`` and score each one.

        Args:
            data: Image array in RGB channel order; it is converted to BGR
                before being passed to ``get_nails``.

        Returns:
            A flat list ``[total, nail_1, score_1, nail_2, score_2, ...]``
            where ``total`` and every ``score_i`` are Python ``int``s and
            ``total`` is the sum of the per-nail scores. When ``get_nails``
            returns ``None``, the list holds ``-1`` for the total and for each
            score, with blank 64x64x3 placeholder images in 5 nail slots.
        """
        nails = get_nails(cv2.cvtColor(data, cv2.COLOR_RGB2BGR))
        if nails is None:
            return self._empty_predictions()
        # Pure inference: disable autograd so no graph/activations are kept.
        with torch.no_grad():
            logits = self.model(nails)
        # Class index along dim 1, presumably the per-nail NAPSI score.
        # NOTE(review): assumes the model returns (num_nails, num_classes)
        # scores — TODO confirm against load_dummy_model.
        napsi_predictions = torch.argmax(logits, 1)
        return self._collect_predictions(napsi_predictions, nails)

    @staticmethod
    def _empty_predictions(num_nails=5, image_size=64):
        """Build the placeholder output used when no nails were detected.

        Mirrors the layout of ``_collect_predictions``: a ``-1`` total, then
        ``num_nails`` pairs of (blank ``image_size`` x ``image_size`` x 3
        image, ``-1`` score).
        """
        predictions = [-1]
        for _ in range(num_nails):
            predictions.append(np.zeros((image_size, image_size, 3)))
            predictions.append(-1)
        return predictions

    @staticmethod
    def _collect_predictions(napsi_predictions, nails):
        """Interleave nails with their scores, prefixed by the total score.

        All scores are converted to Python ``int`` (``int()`` on a tensor
        copies the value to the host), so callers never receive tensors.
        """
        predictions = [int(napsi_predictions.sum())]
        for napsi_prediction, nail in zip(napsi_predictions, nails):
            predictions.append(nail)
            predictions.append(int(napsi_prediction))
        return predictions