natexcvi
Add fecnet
0edd049 unverified
raw
history blame
No virus
2.36 kB
import os
import sys
from importlib import import_module, invalidate_caches
from importlib.util import module_from_spec, spec_from_file_location
from tempfile import TemporaryDirectory
import cv2
import numpy as np
import plotly.express as px
import requests
import torch
from git import Repo
from huggingface_hub import hf_hub_download
class FECNetModel:
    """CPU wrapper around the FECNet model from github.com/AmirSh15/FECNet.

    On construction the upstream repository is cloned into a temporary
    directory, its ``models/FECNet.py`` is patched so every ``"cuda"``
    substring becomes ``"cpu"``, the module is imported from that file, and
    pretrained weights (``fecnet.pt``) are fetched from the Hugging Face Hub
    repository ``natexcvi/pretrained-fecnet``.
    """

    def __init__(self, hf_token: str) -> None:
        """Clone FECNet, import it for CPU use and load the pretrained weights.

        Args:
            hf_token: Hugging Face token forwarded to ``hf_hub_download``.
        """
        self.hf_token = hf_token
        # Keep the TemporaryDirectory referenced by the instance. As a plain
        # local it would be finalized (and the clone deleted) as soon as
        # __init__ returned, leaving a dangling path on sys.path.
        self._repo_dir = TemporaryDirectory()
        repo_path = self._repo_dir.name
        Repo.clone_from(
            "https://github.com/AmirSh15/FECNet.git",
            repo_path,
        )
        # Presumably needed so FECNet.py can import sibling modules of the
        # cloned repo — confirm against upstream.
        sys.path.append(repo_path)
        invalidate_caches()
        fecnet_module = self.__import_fecnet_module(repo_path)
        self.model = self.__load_model(
            self.__download_weights(repo_path), fecnet_module.FECNet
        )

    @staticmethod
    def __import_fecnet_module(repo_path: str):
        """Patch ``models/FECNet.py`` under *repo_path* for CPU and import it.

        Returns:
            The executed module object (exposes the ``FECNet`` class).

        Raises:
            ImportError: If no import spec/loader can be built for the file.
        """
        module_path = os.path.join(repo_path, "models", "FECNet.py")
        with open(module_path, "r", encoding="utf-8") as f:
            source = f.read()
        # Crude textual patch: every "cuda" substring is rewritten to "cpu",
        # presumably so the upstream code never targets a GPU device.
        source = source.replace("cuda", "cpu")
        with open(module_path, "w", encoding="utf-8") as f:
            f.write(source)
        spec = spec_from_file_location("FECNet", module_path)
        if spec is None or spec.loader is None:
            raise ImportError(f"cannot load FECNet module from {module_path}")
        module = module_from_spec(spec)
        spec.loader.exec_module(module)
        return module

    def __download_weights(self, model_dir: str) -> str:
        """Fetch ``fecnet.pt`` from the ``natexcvi/pretrained-fecnet`` hub repo.

        Args:
            model_dir: Unused; retained for signature compatibility.

        Returns:
            Local path of the downloaded weights file.
        """
        return hf_hub_download(
            "natexcvi/pretrained-fecnet",
            "fecnet.pt",
            token=self.hf_token,
        )

    def __load_model(self, model_path: str, model_class):
        """Instantiate *model_class*, load weights from *model_path*, return it.

        The model is built with ``pretrained=False``, has the state dict
        loaded onto the CPU, is switched to eval mode and converted to
        float64 (``.double()``), matching the double tensors fed to it by
        ``embed_image``.
        """
        model = model_class(pretrained=False)
        # NOTE(review): torch.load unpickles the file and the checkpoint comes
        # from a remote hub repo; consider weights_only=True (torch >= 1.13)
        # if the file is a plain state_dict.
        state_dict = torch.load(model_path, map_location=torch.device("cpu"))
        model.load_state_dict(state_dict)
        model.eval()
        return model.double()

    def predict(self, image: torch.Tensor):
        """Run the model on an already-preprocessed input tensor.

        Args:
            image: Batched input tensor; ``embed_image`` passes a float64
                tensor of shape (1, 3, 224, 224).

        Returns:
            The raw model output tensor.
        """
        # Call the module rather than .forward() so registered hooks run.
        return self.model(image)

    @staticmethod
    def distance(a, b):
        """Return the Euclidean (L2) norm of ``a - b``.

        Declared static so both ``FECNetModel.distance(a, b)`` and
        ``instance.distance(a, b)`` work.
        """
        return np.linalg.norm(a - b)

    def embed_image(self, image) -> np.ndarray:
        """Decode an encoded image and return its FECNet embedding.

        Args:
            image: Encoded image data accepted by ``cv2.imdecode``
                (presumably a 1-D uint8 NumPy buffer — confirm with callers).

        Returns:
            The model output for a batch of one image, as a NumPy array.

        Raises:
            ValueError: If OpenCV cannot decode *image*.
        """
        decoded = cv2.imdecode(image, cv2.IMREAD_COLOR)
        if decoded is None:
            raise ValueError("could not decode image")
        # NOTE(review): OpenCV yields BGR channel order and raw 0-255 values
        # with no normalization; confirm this matches the pretrained model.
        resized = cv2.resize(decoded, (224, 224))
        # HWC -> CHW, then prepend a batch axis of size 1.
        batch = np.expand_dims(np.transpose(resized, (2, 0, 1)), axis=0)
        tensor = torch.from_numpy(batch).double()
        # The result is detached anyway, so skip autograd bookkeeping.
        with torch.no_grad():
            pred = self.predict(tensor)
        return pred.detach().numpy()