# NOTE: the following is page metadata captured together with this file, not
# code: author Vincentqyw; commit 8320ccc ("update: ci"); size 2.65 kB.
import os
import sys
from pathlib import Path
from zipfile import ZipFile
import gdown
import sklearn
import torch
from ..utils.base_model import BaseModel
# Put the vendored deep-image-retrieval checkout on the import path so that
# the `dirtorch` package imported below can be found.
sys.path.append(
    os.fspath(
        Path(__file__).parent.joinpath("../../third_party/deep-image-retrieval")
    )
)
# dirtorch requires DB_ROOT to be defined; it is set to an empty string here,
# before the dirtorch imports that follow.
os.environ.update(DB_ROOT="")
from dirtorch.extract_features import load_model # noqa: E402
from dirtorch.utils import common # noqa: E402
# The DIR model checkpoints (pickle files) reference the module path
# `sklearn.decomposition.pca`, which no longer exists in recent scikit-learn
# releases (the implementation lives in the private `sklearn.decomposition._pca`
# module; the public class is `from sklearn.decomposition import PCA`).
# This is a hacky workaround to maintain forward compatibility: register the
# private module under the old name so unpickling can resolve it.
# Import the submodule explicitly: a bare `import sklearn` is not guaranteed
# to load `sklearn.decomposition`, so the attribute access below could fail.
import sklearn.decomposition._pca  # noqa: E402

sys.modules["sklearn.decomposition.pca"] = sklearn.decomposition._pca
class DIR(BaseModel):
    """Global image descriptor backed by a dirtorch (Deep Image Retrieval) model.

    The checkpoint is downloaded on first use and cached under
    ``torch.hub.get_dir()/dirtorch``. If ``whiten_name`` is set, descriptors
    are whitened with the PCA of that name stored in the loaded network.
    """

    default_conf = {
        "model_name": "Resnet-101-AP-GeM",
        "whiten_name": "Landmarks_clean",
        "whiten_params": {
            "whitenp": 0.25,
            "whitenv": None,
            "whitenm": 1.0,
        },
        "pooling": "gem",
        "gemp": 3,
    }
    required_inputs = ["image"]

    # Download links for the zipped checkpoints, keyed by ``model_name``.
    dir_models = {
        "Resnet-101-AP-GeM": "https://docs.google.com/uc?export=download&id=1UWJGDuHtzaQdFhSMojoYVQjmCXhIwVvy",
    }

    def _init(self, conf):
        """Load the network, downloading its checkpoint if it is not cached.

        Raises:
            KeyError: if the checkpoint is not cached and ``conf["model_name"]``
                has no entry in ``dir_models``.
            ValueError: if ``conf["whiten_name"]`` is set but the loaded
                network has no PCA with that name.
        """
        checkpoint = Path(
            torch.hub.get_dir(), "dirtorch", conf["model_name"] + ".pt"
        )
        if not checkpoint.exists():
            self._download_checkpoint(conf["model_name"], checkpoint)
        self.net = load_model(checkpoint, False)  # first load on CPU

        whiten_name = conf["whiten_name"]
        if whiten_name and whiten_name not in self.net.pca:
            raise ValueError(
                f"Unknown whiten_name {whiten_name!r}; "
                f"available: {sorted(self.net.pca)}"
            )

    def _download_checkpoint(self, model_name, checkpoint):
        """Download the zipped checkpoint and extract it into ``checkpoint.parent``."""
        checkpoint.parent.mkdir(exist_ok=True, parents=True)
        link = self.dir_models[model_name]
        archive = checkpoint.with_name(checkpoint.name + ".zip")
        gdown.download(link, str(archive), quiet=False)
        # The context manager closes the archive even if extraction fails.
        # NOTE(review): assumes the archive contains `<model_name>.pt` at its
        # top level, as `_init` expects — confirm.
        with ZipFile(archive, "r") as zf:
            zf.extractall(checkpoint.parent)
        archive.unlink()

    def _forward(self, data):
        """Compute the global descriptor of ``data["image"]``.

        Args:
            data: dict holding ``"image"``, a tensor whose dimension 1 must be
                the 3 colour channels (presumably (B, 3, H, W) — TODO confirm).

        Returns:
            ``{"global_descriptor": desc}`` on the same device as the image.

        Raises:
            ValueError: if the image does not have 3 channels.
        """
        image = data["image"]
        if image.shape[1] != 3:
            raise ValueError(
                f"Expected 3 channels in dimension 1, got shape {tuple(image.shape)}"
            )
        # Normalize per channel with the statistics the network was trained with.
        mean = image.new_tensor(self.net.preprocess["mean"])[:, None, None]
        std = image.new_tensor(self.net.preprocess["std"])[:, None, None]
        image = (image - mean) / std

        desc = self.net(image)
        # NOTE(review): the network output appears to lack a batch dimension
        # (presumably squeezed for a single image), so one is added here; this
        # looks like it assumes batch size 1 — confirm.
        desc = desc.unsqueeze(0)

        whiten_name = self.conf["whiten_name"]
        if whiten_name:
            pca = self.net.pca[whiten_name]
            # Whitening runs in NumPy: detach (numpy() rejects tensors that
            # require grad) and move to CPU, then return the result to the
            # input's device so both code paths yield the same output device.
            whitened = common.whiten_features(
                desc.detach().cpu().numpy(), pca, **self.conf["whiten_params"]
            )
            desc = torch.from_numpy(whitened).to(image.device)
        return {
            "global_descriptor": desc,
        }