import torch
import cv2
import numpy as np
from huggingface_hub import hf_hub_download
from nail_detection.main import get_nails
from DummyModel import load_dummy_model
from Model import Model


class Infer:
    """Detect nails in an image and score each one with ``Model``."""

    # Number of (image, score) placeholder pairs emitted when no nails are found.
    NUM_NAIL_SLOTS = 5
    # Shape of the all-zeros placeholder image emitted for each missing nail.
    PLACEHOLDER_SHAPE = (64, 64, 3)

    def __init__(self, DEBUG):
        """Build the scoring model.

        Args:
            DEBUG: forwarded unchanged to ``Model``.
        """
        # An alternative previously used here: ``load_dummy_model(DEBUG)``.
        self.model = Model(DEBUG)

    def _empty_predictions(self):
        """Return the placeholder output used when ``get_nails`` finds nothing."""
        predictions = []
        for _ in range(self.NUM_NAIL_SLOTS):
            predictions.append(np.zeros(self.PLACEHOLDER_SHAPE))
            predictions.append(-1)
        # NOTE(review): the total is the *string* "-1" here but an int on the
        # success path; kept unchanged in case callers rely on it — confirm.
        predictions.append("-1")
        return predictions

    def predict(self, data):
        """Detect nails in ``data`` and return per-nail scores plus their total.

        Args:
            data: image array in RGB channel order; it is converted to BGR
                before being passed to ``get_nails``.

        Returns:
            A flat list. If ``get_nails`` returns ``None``: ``NUM_NAIL_SLOTS``
            pairs of (zeros image of ``PLACEHOLDER_SHAPE``, ``-1``) followed by
            the string ``"-1"``. Otherwise: one (nail, score) pair per nail,
            where score is the ``int`` argmax class index along dim 1 of the
            model's first output, followed by the ``int`` sum of all scores.
        """
        nails = get_nails(cv2.cvtColor(data, cv2.COLOR_RGB2BGR))
        if nails is None:
            return self._empty_predictions()

        # Inference only: avoid building an autograd graph on every request.
        with torch.no_grad():
            model_output, _uncertainty = self.model(nails)
        # NOTE(review): assumes model_output[0] is (num_nails, num_classes) — confirm.
        class_logits = model_output[0]
        # Single device->host transfer for all scores instead of one per nail.
        napsi_scores = torch.argmax(class_logits, 1).cpu().tolist()

        predictions = []
        for napsi_score, nail in zip(napsi_scores, nails):
            predictions.append(nail)
            predictions.append(napsi_score)
        # Total covers every predicted score, matching the original tensor sum.
        predictions.append(sum(napsi_scores))
        return predictions