Spaces:
Runtime error
Runtime error
File size: 1,147 Bytes
14e27af 761b08f ec42e29 14e27af ec42e29 14e27af c02063c 14e27af ec42e29 761b08f 14e27af c02063c 14e27af |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 |
import torch
import cv2
import numpy as np
from huggingface_hub import hf_hub_download
from nail_detection.main import get_nails
from DummyModel import load_dummy_model
from Model import Model
class Infer:
    """Detect nails in an image and score each one with ``Model``.

    ``predict`` returns a flat list laid out as
    ``[nail_0, score_0, nail_1, score_1, ..., total]``.
    """

    # Number of (image, score) placeholder pairs emitted when no nails are found.
    # NOTE(review): presumably matches how many nails get_nails normally yields,
    # so both branches of ``predict`` produce lists of equal length — confirm.
    _PLACEHOLDER_NAILS = 5
    # Shape of the blank (all-zero) placeholder image emitted per missing nail.
    _PLACEHOLDER_SHAPE = (64, 64, 3)

    def __init__(self, DEBUG):
        """Build the scoring model.

        Args:
            DEBUG: forwarded unchanged to ``Model``.
        """
        # Alternative stub model: self.model = load_dummy_model(DEBUG)
        self.model = Model(DEBUG)

    def predict(self, data):
        """Detect nails in ``data`` and return per-nail scores plus their total.

        Args:
            data: image array; it is converted RGB -> BGR before being passed to
                ``get_nails``, so it is expected in RGB channel order. ``None``
                is treated the same as "no nails detected".

        Returns:
            A flat list alternating ``nail, score`` for each detected nail,
            followed by the integer sum of all scores. If no nails are found,
            ``_PLACEHOLDER_NAILS`` pairs of (zero image, ``-1``) followed by the
            string ``"-1"`` are returned instead.
        """
        if data is None:
            # No image to analyse; avoid crashing inside cv2.cvtColor.
            nails = None
        else:
            nails = get_nails(cv2.cvtColor(data, cv2.COLOR_RGB2BGR))

        if nails is None:
            return self._placeholder_predictions()

        # The second model output (uncertainty) is not used here.
        # NOTE(review): inference runs with autograd enabled (outputs are
        # .detach()-ed); consider torch.no_grad() if Model does not need grads.
        model_prediction, _uncertainty = self.model(nails)
        # NOTE(review): assumes model_prediction[0] is (num_nails, num_classes),
        # so argmax over dim 1 yields one class index per nail — confirm.
        napsi_predictions = torch.argmax(model_prediction[0], 1)
        # One device->host transfer for all scores instead of one per nail.
        scores = napsi_predictions.detach().cpu().tolist()

        predictions = []
        # zip stops at the shorter of scores / nails.
        for score, nail in zip(scores, nails):
            predictions.append(nail)
            predictions.append(int(score))
        # Total covers every predicted score, as in the original tensor sum.
        predictions.append(int(sum(scores)))
        return predictions

    def _placeholder_predictions(self):
        """Return the fixed "no nails found" output list."""
        predictions = []
        for _ in range(self._PLACEHOLDER_NAILS):
            predictions.append(np.zeros(self._PLACEHOLDER_SHAPE))
            predictions.append(-1)
        # NOTE(review): the total is the string "-1" here but an int in the
        # detected-nails branch; kept unchanged in case callers rely on it.
        predictions.append("-1")
        return predictions
|