import os
import sys
from importlib import import_module, invalidate_caches
from importlib.util import module_from_spec, spec_from_file_location
from tempfile import TemporaryDirectory
from typing import Union

import cv2
import numpy as np
import plotly.express as px
import requests
import torch
from git import Repo
from huggingface_hub import hf_hub_download


class FECNetModel:
    """CPU-only wrapper around the FECNet model.

    On construction the FECNet sources are cloned from GitHub into a temporary
    directory, every ``"cuda"`` substring in ``models/FECNet.py`` is rewritten
    to ``"cpu"``, the module is imported, and pretrained weights downloaded
    from the Hugging Face Hub are loaded into it.
    """

    def __init__(self, hf_token: str) -> None:
        """Clone FECNet, import its model class and load pretrained weights.

        Args:
            hf_token: Hugging Face access token forwarded to ``hf_hub_download``.
        """
        self.hf_token = hf_token
        # Keep the TemporaryDirectory referenced for the lifetime of this
        # object. It deletes the directory as soon as it is garbage-collected,
        # and its path is appended to sys.path below; holding it only in a
        # local would delete the clone at the end of __init__ and leave a
        # dangling sys.path entry (breaking any later import from the repo).
        self._repo_dir = TemporaryDirectory()
        repo_path = self._repo_dir.name
        Repo.clone_from(
            "https://github.com/AmirSh15/FECNet.git",
            repo_path,
        )
        invalidate_caches()
        # Presumably needed so FECNet.py can import sibling modules from the
        # repository root — NOTE(review): confirm against the upstream repo.
        sys.path.append(repo_path)

        fecnet_module_path = os.path.join(repo_path, "models", "FECNet.py")
        self._force_cpu(fecnet_module_path)

        spec = spec_from_file_location("FECNet", fecnet_module_path)
        fecnet_module = module_from_spec(spec)  # type: ignore
        spec.loader.exec_module(fecnet_module)  # type: ignore

        self.model = self.__load_model(
            self.__download_weights(repo_path), fecnet_module.FECNet
        )

    @staticmethod
    def _force_cpu(module_path: str) -> None:
        """Rewrite every ``"cuda"`` substring in *module_path* to ``"cpu"`` in place."""
        # Blunt textual patch: it replaces *all* occurrences, including any in
        # comments or identifiers. NOTE(review): acceptable only as long as the
        # upstream file uses "cuda" solely for device selection.
        with open(module_path, "r", encoding="utf-8") as f:
            content = f.read()
        content = content.replace(
            "cuda",
            "cpu",
        )
        with open(module_path, "w", encoding="utf-8") as f:
            f.write(content)

    def __download_weights(self, model_dir: str) -> str:
        """Download ``fecnet.pt`` from ``natexcvi/pretrained-fecnet`` and return its local path.

        Args:
            model_dir: Currently unused; kept so the call site stays unchanged.
        """
        model_path = hf_hub_download(
            "natexcvi/pretrained-fecnet",
            "fecnet.pt",
            token=self.hf_token,
        )
        return model_path

    def __load_model(self, model_path: str, model_class):
        """Instantiate *model_class*, load the state dict at *model_path* and prepare it for inference.

        Returns:
            The model in eval mode with parameters converted to float64.
        """
        model = model_class(pretrained=False)
        # SECURITY NOTE(review): torch.load unpickles the file, which can
        # execute arbitrary code. The file comes from a remote hub repo;
        # consider weights_only=True if the installed torch supports it.
        model_weights = torch.load(model_path, map_location=torch.device("cpu"))
        model.load_state_dict(model_weights)
        model.eval()
        return model.double()

    def predict(self, image: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
        """Run the model on *image* and return its raw output tensor.

        Args:
            image: Model input. A NumPy array is converted to a float64 tensor
                first, matching the dtype the model was cast to in
                ``__load_model``. ``embed_image`` passes a (1, 3, 224, 224)
                tensor.
        """
        if isinstance(image, np.ndarray):
            image = torch.from_numpy(image).double()
        # Call the module rather than .forward() so registered hooks run.
        return self.model(image)

    @staticmethod
    def distance(a, b):
        """Return the Euclidean norm of ``a - b`` (e.g. between two embeddings).

        Declared static so it works both as ``FECNetModel.distance(a, b)`` and
        as ``instance.distance(a, b)``; without the decorator the instance form
        raised TypeError.
        """
        return np.linalg.norm(a - b)

    def embed_image(self, image) -> np.ndarray:
        """Decode an encoded image buffer and return the model output as a NumPy array.

        Args:
            image: Encoded image data (e.g. JPEG/PNG file contents), either as a
                uint8 NumPy buffer or as ``bytes``/``bytearray``/``memoryview``.

        Raises:
            ValueError: If OpenCV cannot decode *image*.
        """
        if isinstance(image, (bytes, bytearray, memoryview)):
            image = np.frombuffer(image, dtype=np.uint8)
        # OpenCV decodes to BGR channel order. NOTE(review): confirm FECNet
        # was trained on BGR rather than RGB input.
        decoded = cv2.imdecode(image, cv2.IMREAD_COLOR)
        if decoded is None:
            raise ValueError("cv2.imdecode could not decode the image buffer")
        resized = cv2.resize(decoded, (224, 224))
        # HWC -> CHW, then add a leading batch axis: (1, 3, 224, 224).
        chw = np.transpose(resized, (2, 0, 1))
        batch = torch.from_numpy(np.expand_dims(chw, axis=0)).double()
        # Inference only: skip building an autograd graph.
        with torch.no_grad():
            pred = self.predict(batch)
        return pred.detach().numpy()