Spaces:
Runtime error
Runtime error
import torch | |
import cv2 | |
import numpy as np | |
from huggingface_hub import hf_hub_download | |
from nail_detection.main import get_nails | |
from DummyModel import load_dummy_model | |
from Model import Model | |
class Infer:
    """Detect nails in an image and score each one with the NAPSI model.

    ``predict`` returns a flat list laid out as
    ``[nail_0, score_0, nail_1, score_1, ..., total]``, where each score is the
    argmax class index the model assigns to that nail and ``total`` is the sum
    of those indices.
    """

    # Layout of the placeholder output returned when no nails are detected:
    # NUM_PLACEHOLDER_NAILS pairs of (blank PLACEHOLDER_SIZE x PLACEHOLDER_SIZE
    # x 3 image, -1), followed by a total.
    NUM_PLACEHOLDER_NAILS = 5
    PLACEHOLDER_SIZE = 64

    def __init__(self, DEBUG):
        """Create the scoring model.

        Args:
            DEBUG: passed straight through to ``Model``; its meaning is
                defined there.
        """
        # ``load_dummy_model(DEBUG)`` (imported by this file) is an
        # alternative model factory that can be swapped in here.
        self.model = Model(DEBUG)

    def predict(self, data):
        """Detect nails in ``data`` and return per-nail NAPSI scores.

        Args:
            data: image given to ``cv2.cvtColor(..., COLOR_RGB2BGR)`` before
                detection, so presumably an RGB ``np.ndarray`` — confirm
                against the caller.

        Returns:
            list: ``[nail, score, nail, score, ..., total]``. When no nails are
            detected (``None`` or an empty result), this is
            ``NUM_PLACEHOLDER_NAILS`` pairs of ``(np.zeros(...), -1)``
            followed by the string ``"-1"``.
        """
        # get_nails receives a BGR copy of the input image.
        nails = get_nails(cv2.cvtColor(data, cv2.COLOR_RGB2BGR))
        predictions = []

        # Treat "nothing found" and an empty detection identically so the
        # model is never run on an empty batch.
        if nails is None or len(nails) == 0:
            size = self.PLACEHOLDER_SIZE
            for _ in range(self.NUM_PLACEHOLDER_NAILS):
                predictions.append(np.zeros((size, size, 3)))
                predictions.append(-1)
            # NOTE(review): the total is the *string* "-1" here but an int on
            # the detected path; kept as-is because the consumer of this list
            # may rely on it — confirm before unifying the types.
            predictions.append("-1")
            return predictions

        # Pure inference: no gradients are needed, so skip building the
        # autograd graph.
        with torch.no_grad():
            model_prediction, _uncertainty = self.model(nails)

        # assumes model_prediction[0] is (num_nails, num_classes) scores —
        # TODO confirm against Model; argmax over dim 1 picks a class per nail.
        napsi_predictions = torch.argmax(model_prediction[0], 1)

        for napsi_prediction, nail in zip(napsi_predictions, nails):
            predictions.append(nail)
            predictions.append(int(napsi_prediction.item()))
        # The total sums every predicted class index, not only those paired
        # with a nail by zip above.
        predictions.append(int(napsi_predictions.sum().item()))
        return predictions