|
import sys |
|
from pathlib import Path |
|
import torch |
|
from zipfile import ZipFile |
|
import os |
|
import sklearn |
|
import gdown |
|
|
|
from ..utils.base_model import BaseModel |
|
|
|
# dirtorch is not pip-installed; it is imported from the checkout under
# third_party/. The path must be registered before the `dirtorch` imports below.
sys.path.append(
    os.fspath(
        Path(__file__).parent.joinpath("../../third_party/deep-image-retrieval")
    )
)
# Defined before importing dirtorch, which presumably expects DB_ROOT to be
# present in the environment (TODO confirm); this module never reads it.
os.environ.update(DB_ROOT="")
|
|
|
from dirtorch.utils import common |
|
from dirtorch.extract_features import load_model |
|
|
|
|
|
|
|
|
|
|
|
# Compatibility shim: the DIR checkpoints unpickled by `load_model` presumably
# reference the old module path `sklearn.decomposition.pca`, which newer
# scikit-learn releases only provide as the private `sklearn.decomposition._pca`.
# Registering the old name in sys.modules lets unpickling resolve it.
# `import sklearn` alone does not load subpackages, so import
# `sklearn.decomposition` explicitly instead of relying on another module
# (e.g. dirtorch) having imported it as a side effect.
import sklearn.decomposition  # noqa: E402

sys.modules["sklearn.decomposition.pca"] = sklearn.decomposition._pca
|
|
|
|
|
class DIR(BaseModel):
    """Global image descriptor from a deep-image-retrieval (dirtorch) network.

    On first use the checkpoint is downloaded (as a zip archive) into
    ``<torch hub dir>/dirtorch/<model_name>.pt``. The descriptor produced by
    the network can optionally be whitened with one of the PCA transforms
    stored on the loaded network (``self.net.pca``), selected by
    ``whiten_name``.
    """

    default_conf = {
        "model_name": "Resnet-101-AP-GeM",
        # Key into ``self.net.pca``; a falsy value disables whitening.
        "whiten_name": "Landmarks_clean",
        # Forwarded as keyword arguments to ``common.whiten_features``.
        "whiten_params": {
            "whitenp": 0.25,
            "whitenv": None,
            "whitenm": 1.0,
        },
        # Neither "pooling" nor "gemp" is read anywhere in this class.
        "pooling": "gem",
        "gemp": 3,
    }
    required_inputs = ["image"]

    # Google Drive download links for the zipped checkpoints, keyed by
    # ``model_name``.
    dir_models = {
        "Resnet-101-AP-GeM": "https://docs.google.com/uc?export=download&id=1UWJGDuHtzaQdFhSMojoYVQjmCXhIwVvy",
    }

    def _init(self, conf):
        """Download the checkpoint if needed, load the network, check whitening.

        Raises:
            KeyError: the checkpoint is not cached locally and ``model_name``
                has no entry in ``dir_models``.
            FileNotFoundError: the downloaded archive did not contain
                ``<model_name>.pt``.
            AssertionError: ``whiten_name`` is set but is not one of the PCA
                transforms available on the loaded network.
        """
        checkpoint = Path(
            torch.hub.get_dir(), "dirtorch", conf["model_name"] + ".pt"
        )
        if not checkpoint.exists():
            checkpoint.parent.mkdir(exist_ok=True, parents=True)
            link = self.dir_models[conf["model_name"]]
            archive = checkpoint.with_name(checkpoint.name + ".zip")
            try:
                gdown.download(str(link), str(archive), quiet=False)
                # Context manager closes the archive even if extraction fails.
                with ZipFile(archive, "r") as zf:
                    zf.extractall(checkpoint.parent)
            finally:
                # Never leave a (possibly partial or corrupt) archive behind;
                # the next run will simply download it again.
                if archive.exists():
                    archive.unlink()
            if not checkpoint.exists():
                raise FileNotFoundError(
                    f"Archive for model {conf['model_name']!r} did not "
                    f"contain the expected checkpoint {checkpoint.name!r}."
                )

        # NOTE(review): the second argument presumably selects the device /
        # GPU ids (False -> CPU); confirm against dirtorch's load_model.
        self.net = load_model(checkpoint, False)
        if conf["whiten_name"]:
            assert conf["whiten_name"] in self.net.pca, (
                f"Unknown whitening {conf['whiten_name']!r}; "
                f"available: {list(self.net.pca)}"
            )

    def _forward(self, data):
        """Compute the global descriptor of ``data["image"]``.

        The image is normalized with the mean/std stored in
        ``self.net.preprocess`` before being passed to the network.

        Returns:
            ``{"global_descriptor": desc}`` where ``desc`` has a leading batch
            dimension. When whitening is enabled, the (numpy-based) whitening
            runs on the CPU and the result is moved back to the input device.
        """
        image = data["image"]
        # Channels are on dim 1 — presumably (B, 3, H, W) RGB; TODO confirm.
        assert image.shape[1] == 3
        mean = image.new_tensor(self.net.preprocess["mean"])[:, None, None]
        std = image.new_tensor(self.net.preprocess["std"])[:, None, None]
        image = (image - mean) / std

        desc = self.net(image)
        # The network appears to drop the batch axis for a single image and
        # return a 1-D vector; restore it. Only add the axis when it is
        # missing so an already-batched (B, D) output is left untouched.
        if desc.dim() == 1:
            desc = desc.unsqueeze(0)
        if self.conf["whiten_name"]:
            pca = self.net.pca[self.conf["whiten_name"]]
            # whiten_features works on numpy arrays; detach so this also works
            # when the caller has autograd enabled.
            whitened = common.whiten_features(
                desc.detach().cpu().numpy(), pca, **self.conf["whiten_params"]
            )
            # Move back so the output device does not depend on whether
            # whitening is enabled.
            desc = torch.from_numpy(whitened).to(desc.device)

        return {
            "global_descriptor": desc,
        }
|
|