# DeepNAPSI / Model.py
# Last commit: b56dbfa by lfolle — "Added missing import."
import os
from huggingface_hub import hf_hub_download
from nail_classification.inference import Inference
class Model:
    """Ensemble wrapper around the DeepNAPSI checkpoints.

    Resolves one checkpoint per entry in ``CHECKPOINT_VERSIONS``, either
    from a developer-local directory (``DEBUG``) or by downloading it from
    the Hugging Face Hub repository ``HUB_REPO_ID``. It then hands the
    resulting paths to ``nail_classification.inference.Inference``.
    """

    # Checkpoint versions making up the ensemble. Shared by the local and
    # the Hub code paths so the two can no longer drift apart.
    CHECKPOINT_VERSIONS = tuple(range(10, 15))
    # Hub repository that stores the ``version_<v>.ckpt`` files.
    HUB_REPO_ID = "lfolle/DeepNAPSIModel"
    # Environment variable that holds the Hub access token.
    TOKEN_ENV_VAR = "DeepNAPSIModel"
    # Developer-local checkpoint directory, only used when DEBUG is true.
    LOCAL_CHECKPOINT_DIR = (
        r"C:\Users\follels\Documents\hand-ki-model-weights"
        r"\DeepNAPSIModel\inference_checkpoints_v1"
    )

    def __init__(self, DEBUG=False):
        """Resolve the ensemble checkpoints and build the inference wrapper.

        Args:
            DEBUG: If true, read the checkpoints from ``LOCAL_CHECKPOINT_DIR``
                instead of downloading them from the Hub. Defaults to False.

        Raises:
            KeyError: If ``DEBUG`` is false and the ``TOKEN_ENV_VAR``
                environment variable is not set.
        """
        if DEBUG:
            file_paths = self._local_checkpoint_paths(self.LOCAL_CHECKPOINT_DIR)
        else:
            file_paths = self._hub_checkpoint_paths()
        self.inference = Inference(file_paths)

    @classmethod
    def _local_checkpoint_paths(cls, base):
        """Return the local checkpoint paths under ``base``, one per version."""
        # NOTE(review): local entries carry no ".ckpt" suffix while the Hub
        # files do — confirm this matches the on-disk layout.
        return [os.path.join(base, f"version_{v}") for v in cls.CHECKPOINT_VERSIONS]

    @classmethod
    def _hub_checkpoint_paths(cls):
        """Download every checkpoint from the Hub and return the local paths.

        Raises:
            KeyError: If the token environment variable is not set.
        """
        # Read the token once, up front, so a missing variable fails before
        # any download starts and with a message that says what to set.
        try:
            token = os.environ[cls.TOKEN_ENV_VAR]
        except KeyError:
            raise KeyError(
                f"Environment variable {cls.TOKEN_ENV_VAR!r} must hold a "
                f"Hugging Face access token for {cls.HUB_REPO_ID!r}"
            ) from None
        return [
            hf_hub_download(cls.HUB_REPO_ID, f"version_{v}.ckpt", use_auth_token=token)
            for v in cls.CHECKPOINT_VERSIONS
        ]

    def predict(self, x):
        """Run the ensemble on ``x`` and return ``(y_hat, uncertainty)``.

        The accepted type of ``x`` is whatever ``Inference.predict`` accepts.
        Its result is unpacked into exactly two values, so any other return
        shape raises at this call.
        """
        y_hat, uncertainty = self.inference.predict(x)
        return y_hat, uncertainty

    def __call__(self, x):
        """Alias for :meth:`predict`, so an instance can be called directly."""
        return self.predict(x)