File size: 802 Bytes
ec42e29
b56dbfa
761b08f
 
 
 
ec42e29
 
 
 
 
 
 
761b08f
 
 
 
 
ec42e29
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import os
from huggingface_hub import hf_hub_download
from nail_classification.inference import Inference


class Model:
    def __init__(self, DEBUG):
        if DEBUG:
            base = r"C:\Users\follels\Documents\hand-ki-model-weights\DeepNAPSIModel\inference_checkpoints_v1"
            file_paths = [os.path.join(base, f"version_{v}") for v in range(10, 15)]
        else:
            file_paths = [hf_hub_download("lfolle/DeepNAPSIModel", f"version_{v}.ckpt",
                                    use_auth_token=os.environ['DeepNAPSIModel']) for v in [10, 11, 12, 13, 14]]
        self.inference = Inference(file_paths)

    def predict(self, x):
        y_hat, uncertainty = self.inference.predict(x)
        return y_hat, uncertainty

    def __call__(self, x):
        return self.predict(x)