File size: 1,147 Bytes
14e27af
 
 
 
 
 
761b08f
ec42e29
14e27af
 
 
 
ec42e29
 
14e27af
 
 
 
 
 
 
 
c02063c
14e27af
ec42e29
 
 
761b08f
 
14e27af
c02063c
 
14e27af
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import torch
import cv2
import numpy as np
from huggingface_hub import hf_hub_download
from nail_detection.main import get_nails

from DummyModel import load_dummy_model
from Model import Model


class Infer():
    """Detect nails in an image and score each one with the model.

    The model is created once in ``__init__`` and reused by every call to
    ``predict``.
    """

    def __init__(self, DEBUG):
        """Create the scoring model.

        Args:
            DEBUG: Passed through unchanged to ``Model``.
        """
        # self.model = load_dummy_model(DEBUG)
        self.model = Model(DEBUG)

    def predict(self, data):
        """Return per-nail crops and scores followed by the summed score.

        Args:
            data: Image array in RGB channel order; it is converted to BGR
                before being handed to ``get_nails``.

        Returns:
            A flat list laid out as
            ``[nail_0, score_0, nail_1, score_1, ..., total]``.

            When ``get_nails`` returns ``None`` the list holds five
            placeholder pairs (a zero-filled ``(64, 64, 3)`` array and ``-1``)
            followed by the string ``"-1"`` in place of the total.
        """
        nails = get_nails(cv2.cvtColor(data, cv2.COLOR_RGB2BGR))

        if nails is None:
            # Return right away: the summed score only exists on the
            # detection path, so falling through would hit an unbound
            # local instead of delivering the placeholder result.
            predictions = []
            for _ in range(5):
                predictions.append(np.zeros((64, 64, 3)))
                predictions.append(-1)
            predictions.append("-1")
            return predictions

        # The second return value (uncertainty) is not used here.
        model_prediction, _uncertainty = self.model(nails)
        # NOTE(review): index 0 is presumably the batch/first output -
        # confirm against Model's return contract.
        model_prediction = model_prediction[0]
        # Highest-scoring class index along dim 1, one entry per nail.
        napsi_predictions = torch.argmax(model_prediction, 1)
        napsi_sum = int(napsi_predictions.sum().detach().cpu())

        predictions = []
        for napsi_prediction, nail in zip(napsi_predictions, nails):
            predictions.append(nail)
            predictions.append(int(napsi_prediction.detach().cpu()))
        predictions.append(napsi_sum)
        return predictions