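# Snapshot metadata captured from the hosting page (not part of the module):
#   author: Vincentqyw
#   commit: e15a186 ("update: limit keypoints number")
#   file size: 2.64 kB; page links: raw / history / blame; scan: "No virus"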
import sys
from pathlib import Path
import torch
from zipfile import ZipFile
import os
import sklearn
import gdown
from ..utils.base_model import BaseModel
# Make the vendored deep-image-retrieval checkout importable as `dirtorch`.
sys.path.append(
    str(Path(__file__).parent / "../../third_party/deep-image-retrieval")
)
os.environ["DB_ROOT"] = ""  # required by dirtorch
from dirtorch.utils import common  # noqa: E402
from dirtorch.extract_features import load_model  # noqa: E402

# The DIR model checkpoints (pickle files) reference the module path
# `sklearn.decomposition.pca`. scikit-learn 0.22 made that module private
# (`sklearn.decomposition._pca`) and 0.24 removed the old public path, so
# unpickling would fail. Registering an alias in `sys.modules` is a hacky
# workaround to keep the old checkpoints loadable.
# `import sklearn` alone is not guaranteed to load the `decomposition`
# subpackage, so import the private module explicitly before aliasing it.
import sklearn.decomposition._pca  # noqa: E402

sys.modules["sklearn.decomposition.pca"] = sklearn.decomposition._pca
class DIR(BaseModel):
    """Global image descriptor computed with a pretrained dirtorch network.

    The network is loaded from a checkpoint cached under the torch hub
    directory (downloaded on first use). If ``whiten_name`` is set, the
    matching PCA whitening stored in the checkpoint is applied to the
    descriptor via ``dirtorch.utils.common.whiten_features``.
    """

    default_conf = {
        # Key into `dir_models`; also the stem of the cached checkpoint file.
        "model_name": "Resnet-101-AP-GeM",
        # Key into `self.net.pca`; a falsy value disables whitening.
        "whiten_name": "Landmarks_clean",
        # Forwarded as keyword arguments to `common.whiten_features`.
        "whiten_params": {
            "whitenp": 0.25,
            "whitenv": None,
            "whitenm": 1.0,
        },
        # NOTE(review): "pooling" and "gemp" are not read by this class;
        # presumably kept for reference / consumed elsewhere — confirm.
        "pooling": "gem",
        "gemp": 3,
    }
    required_inputs = ["image"]
    # Google Drive download links for the zipped checkpoints, keyed by name.
    dir_models = {
        "Resnet-101-AP-GeM": "https://docs.google.com/uc?export=download&id=1UWJGDuHtzaQdFhSMojoYVQjmCXhIwVvy",
    }

    def _init(self, conf):
        """Load the dirtorch network, downloading its checkpoint if needed.

        The checkpoint is cached at
        ``<torch hub dir>/dirtorch/<model_name>.pt``.

        Raises:
            KeyError: if the checkpoint is not cached and
                ``conf["model_name"]`` has no entry in ``dir_models``.
            AssertionError: if ``conf["whiten_name"]`` is set but is not one
                of the whitenings stored in the checkpoint.
        """
        checkpoint = Path(
            torch.hub.get_dir(), "dirtorch", conf["model_name"] + ".pt"
        )
        if not checkpoint.exists():
            link = self.dir_models[conf["model_name"]]
            self._download_checkpoint(link, checkpoint)
        self.net = load_model(checkpoint, False)  # first load on CPU
        if conf["whiten_name"]:
            assert conf["whiten_name"] in self.net.pca, (
                f"Unknown whitening {conf['whiten_name']!r}; "
                f"available: {list(self.net.pca)}"
            )

    @staticmethod
    def _download_checkpoint(link, checkpoint):
        """Download the zipped checkpoint from `link` and unpack it beside `checkpoint`."""
        checkpoint.parent.mkdir(exist_ok=True, parents=True)
        archive = checkpoint.with_name(checkpoint.name + ".zip")
        gdown.download(str(link), str(archive), quiet=False)
        # The context manager closes the archive even if extraction fails.
        with ZipFile(str(archive), "r") as zf:
            zf.extractall(checkpoint.parent)
        archive.unlink()

    def _forward(self, data):
        """Compute the global descriptor of ``data["image"]``.

        Returns:
            dict with key ``"global_descriptor"``, on the same device as the
            input image in both the whitened and unwhitened cases.
        """
        image = data["image"]
        assert image.shape[1] == 3, "DIR expects 3-channel images"
        # Normalize with the per-channel statistics stored in the checkpoint.
        mean = image.new_tensor(self.net.preprocess["mean"])[:, None, None]
        std = image.new_tensor(self.net.preprocess["std"])[:, None, None]
        image = (image - mean) / std

        desc = self.net(image)
        # Re-add a leading batch dimension; assumes the network returns an
        # unbatched descriptor (i.e. a batch of one image) — TODO confirm.
        desc = desc.unsqueeze(0)

        if self.conf["whiten_name"]:
            pca = self.net.pca[self.conf["whiten_name"]]
            # whiten_features is fed a NumPy array: detach first so this also
            # works with autograd enabled, then return to the input's device
            # so both branches agree on where the descriptor lives.
            desc = common.whiten_features(
                desc.detach().cpu().numpy(), pca, **self.conf["whiten_params"]
            )
            desc = torch.from_numpy(desc).to(image.device)
        return {
            "global_descriptor": desc,
        }