Spaces:
Running
Running
import os | |
import sys | |
from pathlib import Path | |
from zipfile import ZipFile | |
import gdown | |
import sklearn | |
import torch | |
from ..utils.base_model import BaseModel | |
# dirtorch is vendored under third_party/ rather than installed, so its
# checkout is put on sys.path before anything is imported from it.
sys.path.append(
    os.fspath(
        Path(__file__).parent.joinpath("../../third_party/deep-image-retrieval")
    )
)
# dirtorch expects DB_ROOT to be defined; it is set (to an empty value)
# before the dirtorch imports below.
os.environ.update(DB_ROOT="")
from dirtorch.extract_features import load_model  # noqa: E402
from dirtorch.utils import common  # noqa: E402
# The DIR model checkpoints (pickle files) reference the module path
# `sklearn.decomposition.pca`, which was deprecated in scikit-learn 0.22 and
# no longer exists as of 0.24 (the implementation now lives in the private
# module `sklearn.decomposition._pca`). Aliasing the old path in sys.modules
# lets those pickles be loaded with newer scikit-learn versions.
# This is a hacky workaround to maintain forward compatibility.
# `import sklearn` alone does not guarantee the `decomposition` subpackage is
# loaded, so import it explicitly before touching its attributes.
import sklearn.decomposition  # noqa: E402

# Older scikit-learn releases have no `_pca` but still ship the real
# `sklearn.decomposition.pca` module, so only install the alias when needed.
if hasattr(sklearn.decomposition, "_pca"):
    sys.modules["sklearn.decomposition.pca"] = sklearn.decomposition._pca
class DIR(BaseModel):
    """Global image descriptor computed with a Deep Image Retrieval (dirtorch) network.

    Configuration (``default_conf``):
        model_name: key into ``dir_models``; also used as the stem of the
            cached checkpoint file.
        whiten_name: name of an entry in the loaded network's ``pca`` mapping
            used for whitening; a falsy value disables whitening.
        whiten_params: keyword arguments forwarded to
            ``dirtorch.utils.common.whiten_features``.
        pooling, gemp: present in the configuration but not read by this class.
    """

    default_conf = {
        "model_name": "Resnet-101-AP-GeM",
        "whiten_name": "Landmarks_clean",
        "whiten_params": {
            "whitenp": 0.25,
            "whitenv": None,
            "whitenm": 1.0,
        },
        "pooling": "gem",
        "gemp": 3,
    }
    required_inputs = ["image"]
    # Download links (Google Drive) for the zipped checkpoints, keyed by model name.
    dir_models = {
        "Resnet-101-AP-GeM": "https://docs.google.com/uc?export=download&id=1UWJGDuHtzaQdFhSMojoYVQjmCXhIwVvy",
    }

    def _init(self, conf):
        """Fetch the checkpoint if it is not cached, then load the network.

        The checkpoint is cached at ``<torch hub dir>/dirtorch/<model_name>.pt``.
        When it is missing, the zip archive listed in ``dir_models`` is
        downloaded with gdown, extracted into that directory, and deleted.

        Raises:
            KeyError: if the checkpoint must be downloaded and ``model_name``
                is not in ``dir_models``.
            AssertionError: if ``whiten_name`` is set but not present in the
                network's ``pca`` entries.
        """
        checkpoint = Path(
            torch.hub.get_dir(), "dirtorch", conf["model_name"] + ".pt"
        )
        if not checkpoint.exists():
            checkpoint.parent.mkdir(exist_ok=True, parents=True)
            link = self.dir_models[conf["model_name"]]
            archive = Path(str(checkpoint) + ".zip")
            gdown.download(str(link), str(archive), quiet=False)
            # The context manager closes the archive even if extraction fails.
            with ZipFile(archive, "r") as zf:
                zf.extractall(checkpoint.parent)
            os.remove(archive)
        self.net = load_model(checkpoint, False)  # first load on CPU
        if conf["whiten_name"]:
            assert conf["whiten_name"] in self.net.pca, (
                f"whitening {conf['whiten_name']!r} not found in the "
                "checkpoint's PCA entries"
            )

    def _forward(self, data):
        """Compute the (optionally whitened) global descriptor of ``data["image"]``.

        The image is normalized with the mean/std stored in
        ``self.net.preprocess`` before being passed through the network.
        Presumably ``data["image"]`` is a batch shaped (B, 3, H, W). This is
        inferred from the channel check on dim 1 and the (C, 1, 1) broadcast
        below; TODO confirm with callers.

        Returns:
            dict with ``"global_descriptor"``: a 2-D tensor with one row per
            input image, on the same device as the network output.
        """
        image = data["image"]
        assert image.shape[1] == 3
        mean = self.net.preprocess["mean"]
        std = self.net.preprocess["std"]
        image = image - image.new_tensor(mean)[:, None, None]
        image = image / image.new_tensor(std)[:, None, None]

        desc = self.net(image)
        # For a single image the network output lacks the batch dimension
        # (the previous code re-added it with unsqueeze(0)). Reshaping to
        # (batch, -1) is identical for one image and, assuming the network
        # returns one descriptor per image, keeps one row per image for
        # larger batches instead of producing (1, B, D).
        desc = desc.reshape(image.shape[0], -1)
        if self.conf["whiten_name"]:
            pca = self.net.pca[self.conf["whiten_name"]]
            device = desc.device
            whitened = common.whiten_features(
                desc.detach().cpu().numpy(), pca, **self.conf["whiten_params"]
            )
            # Whitening is done in numpy; move the result back so both
            # branches return the descriptor on the same device.
            desc = torch.from_numpy(whitened).to(device)

        return {
            "global_descriptor": desc,
        }