Spaces:
Build error
Build error
import os | |
import sys | |
from importlib import import_module, invalidate_caches | |
from importlib.util import module_from_spec, spec_from_file_location | |
from tempfile import TemporaryDirectory | |
import cv2 | |
import numpy as np | |
import plotly.express as px | |
import requests | |
import torch | |
from git import Repo | |
from huggingface_hub import hf_hub_download | |
class FECNetModel:
    """Wrapper around the FECNet network for producing image embeddings.

    Construction clones the upstream FECNet GitHub repository into a
    temporary directory. It then textually patches ``models/FECNet.py`` so
    every ``"cuda"`` substring becomes ``"cpu"``, and imports that module.
    Finally it downloads pretrained weights from the Hugging Face Hub and
    loads them into a float64 model in eval mode.
    """

    def __init__(self, hf_token: str) -> None:
        """Clone FECNet, patch it for CPU, and load the pretrained weights.

        Args:
            hf_token: Hugging Face access token forwarded to
                ``hf_hub_download``.

        Raises:
            ImportError: if no module spec/loader can be created for the
                patched ``FECNet.py``.
        """
        self.hf_token = hf_token
        # Keep the TemporaryDirectory referenced by the instance. Otherwise it
        # is finalized (and the checkout deleted) as soon as __init__ returns,
        # leaving sys.path and the loaded module's __file__ pointing at a
        # directory that no longer exists.
        self._repo_dir = TemporaryDirectory()
        repo_path = self._repo_dir.name
        Repo.clone_from(
            "https://github.com/AmirSh15/FECNet.git",
            repo_path,
        )
        invalidate_caches()
        # Presumably lets FECNet.py import other modules from the repo root
        # -- confirm against the upstream repository.
        sys.path.append(repo_path)
        fecnet_module_path = os.path.join(repo_path, "models", "FECNet.py")
        # Crude textual patch: every "cuda" substring in the upstream source
        # is rewritten to "cpu" so the module does not require a GPU.
        with open(fecnet_module_path, "r", encoding="utf-8") as f:
            content = f.read()
        content = content.replace(
            "cuda",
            "cpu",
        )
        with open(fecnet_module_path, "w", encoding="utf-8") as f:
            f.write(content)
        spec = spec_from_file_location("FECNet", fecnet_module_path)
        if spec is None or spec.loader is None:
            raise ImportError(
                f"cannot create a module spec for {fecnet_module_path}"
            )
        fecnet_module = module_from_spec(spec)
        spec.loader.exec_module(fecnet_module)
        self.model = self.__load_model(
            self.__download_weights(repo_path), fecnet_module.FECNet
        )

    def __download_weights(self, model_dir: str) -> str:
        """Download ``fecnet.pt`` from the Hub and return its local path.

        Args:
            model_dir: Unused. Kept only for call compatibility; the storage
                location is chosen by ``hf_hub_download``.

        Returns:
            Local filesystem path of the downloaded checkpoint.
        """
        return hf_hub_download(
            "natexcvi/pretrained-fecnet",
            "fecnet.pt",
            token=self.hf_token,
        )

    def __load_model(self, model_path: str, model_class):
        """Build ``model_class(pretrained=False)`` and load its state dict.

        Args:
            model_path: Path to a checkpoint holding a state dict.
            model_class: Model constructor accepting ``pretrained=``.

        Returns:
            The model in eval mode, converted to float64 via ``.double()``.
        """
        model = model_class(pretrained=False)
        # NOTE(review): torch.load unpickles the file, and this checkpoint
        # comes from a remote Hub repo. Only load trusted checkpoints, or use
        # weights_only=True if the installed torch version supports it.
        model_weights = torch.load(model_path, map_location=torch.device("cpu"))
        model.load_state_dict(model_weights)
        model.eval()
        return model.double()

    def predict(self, image: torch.Tensor) -> torch.Tensor:
        """Run the model on a batched float64 tensor without tracking grads.

        The model is invoked via ``__call__`` (not ``.forward``) so that any
        registered module hooks run.
        """
        with torch.no_grad():
            return self.model(image)

    @staticmethod
    def distance(a, b):
        """Return the Euclidean (L2) norm of ``a - b``.

        Declared static so both ``FECNetModel.distance(a, b)`` and
        ``instance.distance(a, b)`` work.
        """
        return np.linalg.norm(a - b)

    def embed_image(self, image) -> np.ndarray:
        """Decode an encoded image buffer and return its embedding.

        Args:
            image: Encoded image bytes, either as a ``uint8`` numpy buffer
                (what ``cv2.imdecode`` expects) or as ``bytes`` /
                ``bytearray`` / ``memoryview``.

        Returns:
            The model output for a batch of one image, as a numpy array.

        Raises:
            ValueError: if OpenCV cannot decode ``image``.
        """
        if isinstance(image, (bytes, bytearray, memoryview)):
            image = np.frombuffer(image, dtype=np.uint8)
        decoded = cv2.imdecode(image, cv2.IMREAD_COLOR)
        if decoded is None:
            raise ValueError("could not decode image")
        # IMREAD_COLOR yields 3-channel HWC in OpenCV's BGR order, and pixel
        # values are not rescaled. NOTE(review): confirm FECNet expects BGR
        # and an unnormalized 0-255 range.
        resized = cv2.resize(decoded, (224, 224))
        chw = np.transpose(resized, (2, 0, 1))
        batch = torch.from_numpy(np.expand_dims(chw, axis=0)).double()
        return self.predict(batch).detach().numpy()