# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
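# Image retrieval evaluation on the revisited Oxford (roxford5k) and Paris
# (rparis6k) datasets: frozen features are extracted with a pretrained backbone,
# database images are ranked by similarity to each query, and mAP / mP@k are
# reported for the Medium and Hard protocols.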
import os
import sys
import pickle
import argparse

import torch
from torch import nn
import torch.distributed as dist
import torch.backends.cudnn as cudnn
from torchvision import models as torchvision_models
from torchvision import transforms as pth_transforms
from PIL import Image, ImageFile
import numpy as np

import utils
import vision_transformer as vits
from eval_knn import extract_features
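

# Thin wrapper around the revisited Oxford / Paris ground-truth pickle
# (gnd_{dataset}.pkl): depending on `split`, it serves either the database
# images ("imlist") or the query images ("qimlist").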
class OxfordParisDataset(torch.utils.data.Dataset):
    def __init__(self, dir_main, dataset, split, transform=None, imsize=None):
        if dataset not in ["roxford5k", "rparis6k"]:
            raise ValueError("Unknown dataset: {}!".format(dataset))
        # loading imlist, qimlist, and gnd in cfg as a dict
        gnd_fname = os.path.join(dir_main, dataset, "gnd_{}.pkl".format(dataset))
        with open(gnd_fname, "rb") as f:
            cfg = pickle.load(f)
        cfg["gnd_fname"] = gnd_fname
        cfg["ext"] = ".jpg"
        cfg["qext"] = ".jpg"
        cfg["dir_data"] = os.path.join(dir_main, dataset)
        cfg["dir_images"] = os.path.join(cfg["dir_data"], "jpg")
        cfg["n"] = len(cfg["imlist"])
        cfg["nq"] = len(cfg["qimlist"])
        cfg["im_fname"] = config_imname
        cfg["qim_fname"] = config_qimname
        cfg["dataset"] = dataset
        self.cfg = cfg
        self.samples = cfg["qimlist"] if split == "query" else cfg["imlist"]
        self.transform = transform
        self.imsize = imsize

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        path = os.path.join(self.cfg["dir_images"], self.samples[index] + ".jpg")
        ImageFile.LOAD_TRUNCATED_IMAGES = True
        with open(path, "rb") as f:
            img = Image.open(f)
            img = img.convert("RGB")
        if self.imsize is not None:
            # Pillow >= 10 removed Image.ANTIALIAS; Image.LANCZOS is the same filter
            img.thumbnail((self.imsize, self.imsize), Image.LANCZOS)
        if self.transform is not None:
            img = self.transform(img)
        return img, index
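

# Path helpers stored in cfg["im_fname"] / cfg["qim_fname"] above: given an
# index, they rebuild the full path of the i-th database / query image.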
def config_imname(cfg, i):
    return os.path.join(cfg["dir_images"], cfg["imlist"][i] + cfg["ext"])


def config_qimname(cfg, i):
    return os.path.join(cfg["dir_images"], cfg["qimlist"][i] + cfg["qext"])


if __name__ == "__main__":
    parser = argparse.ArgumentParser("Image Retrieval on revisited Paris and Oxford")
    parser.add_argument(
        "--data_path", default="/path/to/revisited_paris_oxford/", type=str
    )
    parser.add_argument(
        "--dataset", default="roxford5k", type=str, choices=["roxford5k", "rparis6k"]
    )
    parser.add_argument("--multiscale", default=False, type=utils.bool_flag)
    parser.add_argument("--imsize", default=224, type=int, help="Image size")
    parser.add_argument(
        "--pretrained_weights",
        default="",
        type=str,
        help="Path to pretrained weights to evaluate.",
    )
    parser.add_argument("--use_cuda", default=True, type=utils.bool_flag)
    parser.add_argument("--arch", default="vit_small", type=str, help="Architecture")
    parser.add_argument(
        "--patch_size", default=16, type=int, help="Patch resolution of the model."
    )
    parser.add_argument(
        "--checkpoint_key",
        default="teacher",
        type=str,
        help='Key to use in the checkpoint (example: "teacher")',
    )
    parser.add_argument(
        "--num_workers",
        default=10,
        type=int,
        help="Number of data loading workers per GPU.",
    )
    parser.add_argument(
        "--dist_url",
        default="env://",
        type=str,
        help="""url used to set up
        distributed training; see https://pytorch.org/docs/stable/distributed.html""",
    )
    parser.add_argument(
        "--local_rank",
        default=0,
        type=int,
        help="Please ignore and do not set this argument.",
    )
    args = parser.parse_args()
    utils.init_distributed_mode(args)
    print("git:\n {}\n".format(utils.get_sha()))
    print(
        "\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items()))
    )
    cudnn.benchmark = True

    # ============ preparing data ... ============
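    # ToTensor + Normalize with the standard ImageNet mean / std; resizing is
    # handled inside the dataset through the `imsize` argument.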
    transform = pth_transforms.Compose(
        [
            pth_transforms.ToTensor(),
            pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ]
    )
    dataset_train = OxfordParisDataset(
        args.data_path,
        args.dataset,
        split="train",
        transform=transform,
        imsize=args.imsize,
    )
    dataset_query = OxfordParisDataset(
        args.data_path,
        args.dataset,
        split="query",
        transform=transform,
        imsize=args.imsize,
    )
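    # batch_size=1 for both loaders: `thumbnail` preserves the aspect ratio, so
    # image shapes differ and cannot be stacked into larger batches. Only the
    # database loader uses a distributed sampler, so database images are sharded
    # across processes while every process iterates over all queries.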
    sampler = torch.utils.data.DistributedSampler(dataset_train, shuffle=False)
    data_loader_train = torch.utils.data.DataLoader(
        dataset_train,
        sampler=sampler,
        batch_size=1,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=False,
    )
    data_loader_query = torch.utils.data.DataLoader(
        dataset_query,
        batch_size=1,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=False,
    )
print(f"train: {len(dataset_train)} imgs / query: {len(dataset_query)} imgs") | |

    # ============ building network ... ============
    if "vit" in args.arch:
        model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0)
        print(f"Model {args.arch} {args.patch_size}x{args.patch_size} built.")
    elif "xcit" in args.arch:
        model = torch.hub.load("facebookresearch/xcit:main", args.arch, num_classes=0)
    elif args.arch in torchvision_models.__dict__.keys():
        model = torchvision_models.__dict__[args.arch](num_classes=0)
    else:
print(f"Architecture {args.arch} non supported") | |
        sys.exit(1)
    if args.use_cuda:
        model.cuda()
    model.eval()

    # load pretrained weights
    if os.path.isfile(args.pretrained_weights):
        state_dict = torch.load(args.pretrained_weights, map_location="cpu")
        if args.checkpoint_key is not None and args.checkpoint_key in state_dict:
            print(f"Take key {args.checkpoint_key} in provided checkpoint dict")
            state_dict = state_dict[args.checkpoint_key]
        # remove `module.` prefix
        state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
        # remove `backbone.` prefix induced by multicrop wrapper
        state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
        msg = model.load_state_dict(state_dict, strict=False)
        print(
            "Pretrained weights found at {} and loaded with msg: {}".format(
                args.pretrained_weights, msg
            )
        )
    elif args.arch == "vit_small" and args.patch_size == 16:
        print(
            "Since no pretrained weights have been provided, we load DINO weights pretrained on Google Landmarks v2."
        )
        model.load_state_dict(
            torch.hub.load_state_dict_from_url(
                url="https://dl.fbaipublicfiles.com/dino/dino_vitsmall16_googlelandmark_pretrain/dino_vitsmall16_googlelandmark_pretrain.pth"
            )
        )
    else:
        print("Warning: We use random weights.")

    ############################################################################
    # Step 1: extract features
    train_features = extract_features(
        model, data_loader_train, args.use_cuda, multiscale=args.multiscale
    )
    query_features = extract_features(
        model, data_loader_query, args.use_cuda, multiscale=args.multiscale
    )

    if utils.get_rank() == 0:  # only rank 0 will work from now on
        # normalize features
        train_features = nn.functional.normalize(train_features, dim=1, p=2)
        query_features = nn.functional.normalize(query_features, dim=1, p=2)
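        # after L2 normalization, the dot products computed below are cosine
        # similarities between database and query descriptors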

        ############################################################################
        # Step 2: similarity
        sim = torch.mm(train_features, query_features.T)
        ranks = torch.argsort(-sim, dim=0).cpu().numpy()
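        # ranks has shape (n_database, n_query): ranks[k, q] is the index of the
        # database image ranked k-th for query q, as consumed by utils.compute_map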

        ############################################################################
        # Step 3: evaluate
        gnd = dataset_train.cfg["gnd"]

        # evaluate ranks
        ks = [1, 5, 10]

        # search for easy & hard
        gnd_t = []
        for i in range(len(gnd)):
            g = {}
            g["ok"] = np.concatenate([gnd[i]["easy"], gnd[i]["hard"]])
            g["junk"] = np.concatenate([gnd[i]["junk"]])
            gnd_t.append(g)
        mapM, apsM, mprM, prsM = utils.compute_map(ranks, gnd_t, ks)

        # search for hard
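        # (Hard protocol: only hard images count as positives; easy images are
        # treated as junk on top of the junk set)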
        gnd_t = []
        for i in range(len(gnd)):
            g = {}
            g["ok"] = np.concatenate([gnd[i]["hard"]])
            g["junk"] = np.concatenate([gnd[i]["junk"], gnd[i]["easy"]])
            gnd_t.append(g)
        mapH, apsH, mprH, prsH = utils.compute_map(ranks, gnd_t, ks)

        print(
            ">> {}: mAP M: {}, H: {}".format(
                args.dataset,
                np.around(mapM * 100, decimals=2),
                np.around(mapH * 100, decimals=2),
            )
        )
        print(
            ">> {}: mP@k{} M: {}, H: {}".format(
                args.dataset,
                np.array(ks),
                np.around(mprM * 100, decimals=2),
                np.around(mprH * 100, decimals=2),
            )
        )
    dist.barrier()