# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import argparse

import torch
from torch import nn
import torch.distributed as dist
import torch.backends.cudnn as cudnn
from torchvision import datasets
from torchvision import transforms as pth_transforms
from torchvision import models as torchvision_models

import utils
import vision_transformer as vits


def extract_feature_pipeline(args):
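    """Build the data loaders, load the pretrained backbone and extract features.

    Features are computed for the "train" and "val" folders under
    ``args.data_path``, L2-normalized on rank 0, and optionally saved to
    ``args.dump_features`` together with the corresponding labels.
    """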
    # ============ preparing data ... ============
    transform = pth_transforms.Compose(
        [
            pth_transforms.Resize(256, interpolation=3),
            pth_transforms.CenterCrop(224),
            pth_transforms.ToTensor(),
            pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ]
    )
    dataset_train = ReturnIndexDataset(
        os.path.join(args.data_path, "train"), transform=transform
    )
    dataset_val = ReturnIndexDataset(
        os.path.join(args.data_path, "val"), transform=transform
    )
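    # the distributed sampler assigns each rank a disjoint, unshuffled shard of
    # the train set; features are later written back to their global positions
    # by index in extract_features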
    sampler = torch.utils.data.DistributedSampler(dataset_train, shuffle=False)
    data_loader_train = torch.utils.data.DataLoader(
        dataset_train,
        sampler=sampler,
        batch_size=args.batch_size_per_gpu,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=False,
    )
    data_loader_val = torch.utils.data.DataLoader(
        dataset_val,
        batch_size=args.batch_size_per_gpu,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=False,
    )
    print(
        f"Data loaded with {len(dataset_train)} train and {len(dataset_val)} val imgs."
    )
    # ============ building network ... ============
    if "vit" in args.arch:
        model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0)
        print(f"Model {args.arch} {args.patch_size}x{args.patch_size} built.")
    elif "xcit" in args.arch:
        model = torch.hub.load("facebookresearch/xcit:main", args.arch, num_classes=0)
    elif args.arch in torchvision_models.__dict__.keys():
        model = torchvision_models.__dict__[args.arch](num_classes=0)
        model.fc = nn.Identity()
    else:
        print(f"Architecture {args.arch} not supported")
        sys.exit(1)
    model.cuda()
    utils.load_pretrained_weights(
        model, args.pretrained_weights, args.checkpoint_key, args.arch, args.patch_size
    )
    model.eval()
    # ============ extract features ... ============
    print("Extracting features for train set...")
    train_features = extract_features(model, data_loader_train, args.use_cuda)
    print("Extracting features for val set...")
    test_features = extract_features(model, data_loader_val, args.use_cuda)

    if utils.get_rank() == 0:
        train_features = nn.functional.normalize(train_features, dim=1, p=2)
        test_features = nn.functional.normalize(test_features, dim=1, p=2)

    train_labels = torch.tensor([s[-1] for s in dataset_train.samples]).long()
    test_labels = torch.tensor([s[-1] for s in dataset_val.samples]).long()

    # save features and labels
    if args.dump_features and dist.get_rank() == 0:
        torch.save(
            train_features.cpu(), os.path.join(args.dump_features, "trainfeat.pth")
        )
        torch.save(
            test_features.cpu(), os.path.join(args.dump_features, "testfeat.pth")
        )
        torch.save(
            train_labels.cpu(), os.path.join(args.dump_features, "trainlabels.pth")
        )
        torch.save(
            test_labels.cpu(), os.path.join(args.dump_features, "testlabels.pth")
        )
    return train_features, test_features, train_labels, test_labels


@torch.no_grad()
def extract_features(model, data_loader, use_cuda=True, multiscale=False):
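    """Forward every batch of ``data_loader`` through ``model`` and gather features.

    Each rank processes its own shard; per-batch features and sample indices
    are all-gathered across ranks so that rank 0 can fill a single
    ``(len(dataset), feat_dim)`` matrix in dataset order. Returns that matrix
    on rank 0 and ``None`` on every other rank.
    """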
    metric_logger = utils.MetricLogger(delimiter=" ")
    features = None
    for samples, index in metric_logger.log_every(data_loader, 10):
        samples = samples.cuda(non_blocking=True)
        index = index.cuda(non_blocking=True)
        if multiscale:
            feats = utils.multi_scale(samples, model)
        else:
            feats = model(samples).clone()

        # init storage feature matrix
        if dist.get_rank() == 0 and features is None:
            features = torch.zeros(len(data_loader.dataset), feats.shape[-1])
            if use_cuda:
                features = features.cuda(non_blocking=True)
            print(f"Storing features into tensor of shape {features.shape}")

        # get indexes from all processes
        y_all = torch.empty(
            dist.get_world_size(), index.size(0), dtype=index.dtype, device=index.device
        )
        y_l = list(y_all.unbind(0))
        y_all_reduce = torch.distributed.all_gather(y_l, index, async_op=True)
        y_all_reduce.wait()
        index_all = torch.cat(y_l)

        # share features between processes
        feats_all = torch.empty(
            dist.get_world_size(),
            feats.size(0),
            feats.size(1),
            dtype=feats.dtype,
            device=feats.device,
        )
        output_l = list(feats_all.unbind(0))
        output_all_reduce = torch.distributed.all_gather(output_l, feats, async_op=True)
        output_all_reduce.wait()

        # update storage feature matrix
        if dist.get_rank() == 0:
            if use_cuda:
                features.index_copy_(0, index_all, torch.cat(output_l))
            else:
                features.index_copy_(0, index_all.cpu(), torch.cat(output_l).cpu())
    return features


@torch.no_grad()
def knn_classifier(
    train_features, train_labels, test_features, test_labels, k, T, num_classes=1000
):
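    """Weighted k-NN classification over precomputed features.

    Test features are processed in chunks. For each test sample, the ``k``
    most similar train features (dot product, i.e. cosine similarity since the
    features are L2-normalized upstream) vote for their class with weight
    ``exp(similarity / T)``; votes are summed per class. Returns top-1 and
    top-5 accuracy in percent.
    """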
    top1, top5, total = 0.0, 0.0, 0
    train_features = train_features.t()
    num_test_images, num_chunks = test_labels.shape[0], 100
    imgs_per_chunk = num_test_images // num_chunks
    retrieval_one_hot = torch.zeros(k, num_classes).to(train_features.device)
    for idx in range(0, num_test_images, imgs_per_chunk):
        # get the features for test images
        features = test_features[idx : min((idx + imgs_per_chunk), num_test_images), :]
        targets = test_labels[idx : min((idx + imgs_per_chunk), num_test_images)]
        batch_size = targets.shape[0]

        # calculate the dot product and compute top-k neighbors
        similarity = torch.mm(features, train_features)
        distances, indices = similarity.topk(k, largest=True, sorted=True)
        candidates = train_labels.view(1, -1).expand(batch_size, -1)
        retrieved_neighbors = torch.gather(candidates, 1, indices)

        retrieval_one_hot.resize_(batch_size * k, num_classes).zero_()
        retrieval_one_hot.scatter_(1, retrieved_neighbors.view(-1, 1), 1)
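        # each retrieved neighbor votes for its class with weight exp(similarity / T);
        # the per-class votes are summed below to form the prediction scores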
        distances_transform = distances.clone().div_(T).exp_()
        probs = torch.sum(
            torch.mul(
                retrieval_one_hot.view(batch_size, -1, num_classes),
                distances_transform.view(batch_size, -1, 1),
            ),
            1,
        )
        _, predictions = probs.sort(1, True)

        # find the predictions that match the target
        correct = predictions.eq(targets.data.view(-1, 1))
        top1 = top1 + correct.narrow(1, 0, 1).sum().item()
        top5 = (
            top5 + correct.narrow(1, 0, min(5, k)).sum().item()
        )  # top5 does not make sense if k < 5
        total += targets.size(0)
    top1 = top1 * 100.0 / total
    top5 = top5 * 100.0 / total
    return top1, top5


class ReturnIndexDataset(datasets.ImageFolder):
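    """ImageFolder variant that returns (image, dataset index) instead of (image, label).

    The index is what ``extract_features`` uses to scatter features back to
    their positions in the full feature matrix; labels are read separately
    from ``.samples`` in ``extract_feature_pipeline``.
    """
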
    def __getitem__(self, idx):
        img, lab = super(ReturnIndexDataset, self).__getitem__(idx)
        return img, idx
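

# Example launch (a sketch; assumes this file is saved as eval_knn.py and that
# ImageNet is laid out as data_path/{train,val}/<class>/<image> for ImageFolder;
# exact flags depend on your environment):
#
#   python -m torch.distributed.launch --nproc_per_node=1 eval_knn.py \
#       --pretrained_weights /path/to/checkpoint.pth \
#       --arch vit_small --patch_size 16 \
#       --data_path /path/to/imagenet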
if __name__ == "__main__":
    parser = argparse.ArgumentParser("Evaluation with weighted k-NN on ImageNet")
    parser.add_argument(
        "--batch_size_per_gpu", default=128, type=int, help="Per-GPU batch size"
    )
    parser.add_argument(
        "--nb_knn",
        default=[10, 20, 100, 200],
        nargs="+",
        type=int,
        help="Number of nearest neighbors to use. 20 usually works best.",
    )
    parser.add_argument(
        "--temperature",
        default=0.07,
        type=float,
        help="Temperature used in the voting coefficient",
    )
    parser.add_argument(
        "--pretrained_weights",
        default="",
        type=str,
        help="Path to pretrained weights to evaluate.",
    )
    parser.add_argument(
        "--use_cuda",
        default=True,
        type=utils.bool_flag,
        help="Store the features on GPU? We recommend setting this to False if you encounter OOM.",
    )
    parser.add_argument("--arch", default="vit_small", type=str, help="Architecture")
    parser.add_argument(
        "--patch_size", default=16, type=int, help="Patch resolution of the model."
    )
    parser.add_argument(
        "--checkpoint_key",
        default="teacher",
        type=str,
        help='Key to use in the checkpoint (example: "teacher")',
    )
    parser.add_argument(
        "--dump_features",
        default=None,
        help="Path where to save computed features; leave empty to skip saving.",
    )
    parser.add_argument(
        "--load_features",
        default=None,
        help="If the features have already been computed, where to find them.",
    )
    parser.add_argument(
        "--num_workers",
        default=10,
        type=int,
        help="Number of data loading workers per GPU.",
    )
    parser.add_argument(
        "--dist_url",
        default="env://",
        type=str,
        help="URL used to set up distributed training; see https://pytorch.org/docs/stable/distributed.html",
    )
    parser.add_argument(
        "--local_rank",
        default=0,
        type=int,
        help="Please ignore and do not set this argument.",
    )
    parser.add_argument("--data_path", default="/path/to/imagenet/", type=str)
    args = parser.parse_args()

    utils.init_distributed_mode(args)
    print("git:\n {}\n".format(utils.get_sha()))
    print(
        "\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items()))
    )
    cudnn.benchmark = True

    if args.load_features:
        train_features = torch.load(os.path.join(args.load_features, "trainfeat.pth"))
        test_features = torch.load(os.path.join(args.load_features, "testfeat.pth"))
        train_labels = torch.load(os.path.join(args.load_features, "trainlabels.pth"))
        test_labels = torch.load(os.path.join(args.load_features, "testlabels.pth"))
    else:
        # need to extract features!
        (
            train_features,
            test_features,
            train_labels,
            test_labels,
        ) = extract_feature_pipeline(args)

    if utils.get_rank() == 0:
        if args.use_cuda:
            train_features = train_features.cuda()
            test_features = test_features.cuda()
            train_labels = train_labels.cuda()
            test_labels = test_labels.cuda()

        print("Features are ready!\nStart the k-NN classification.")
        for k in args.nb_knn:
            top1, top5 = knn_classifier(
                train_features,
                train_labels,
                test_features,
                test_labels,
                k,
                args.temperature,
            )
            print(f"{k}-NN classifier result: Top1: {top1}, Top5: {top5}")
    dist.barrier()