# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import argparse

import torch
from torch import nn
import torch.distributed as dist
import torch.backends.cudnn as cudnn
from torchvision import datasets
from torchvision import transforms as pth_transforms
from torchvision import models as torchvision_models

import utils
import vision_transformer as vits


def extract_feature_pipeline(args):
    """Build the data loaders and the model, then extract train/val features.

    Args:
        args: parsed command-line namespace; this function reads
            ``data_path``, ``batch_size_per_gpu``, ``num_workers``, ``arch``,
            ``patch_size``, ``pretrained_weights``, ``checkpoint_key``,
            ``use_cuda`` and ``dump_features``.

    Returns:
        Tuple ``(train_features, test_features, train_labels, test_labels)``.
        The feature tensors are L2-normalized on rank 0 only; on other ranks
        they are whatever ``extract_features`` returned there (``None``).
        When ``args.dump_features`` is set, rank 0 also saves the four
        tensors into that directory.
    """
    # ============ preparing data ... ============
    transform = pth_transforms.Compose([
        pth_transforms.Resize(256, interpolation=3),
        pth_transforms.CenterCrop(224),
        pth_transforms.ToTensor(),
        pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    dataset_train = ReturnIndexDataset(os.path.join(args.data_path, "train"), transform=transform)
    dataset_val = ReturnIndexDataset(os.path.join(args.data_path, "val"), transform=transform)
    # Only the train set is sharded across processes; the val loader below
    # iterates the whole val set on every process.
    sampler = torch.utils.data.DistributedSampler(dataset_train, shuffle=False)
    data_loader_train = torch.utils.data.DataLoader(
        dataset_train,
        sampler=sampler,
        batch_size=args.batch_size_per_gpu,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=False,
    )
    data_loader_val = torch.utils.data.DataLoader(
        dataset_val,
        batch_size=args.batch_size_per_gpu,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=False,
    )
    print(f"Data loaded with {len(dataset_train)} train and {len(dataset_val)} val imgs.")

    # ============ building network ... ============
    if "vit" in args.arch:
        model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0)
        print(f"Model {args.arch} {args.patch_size}x{args.patch_size} built.")
    elif "xcit" in args.arch:
        model = torch.hub.load('facebookresearch/xcit', args.arch, num_classes=0)
    elif args.arch in torchvision_models.__dict__.keys():
        # NOTE(review): the model is built with num_classes=0 and its
        # classification head is never replaced here; for torchvision models
        # this may yield zero-width features -- TODO confirm.
        model = torchvision_models.__dict__[args.arch](num_classes=0)
    else:
        print(f"Architecture {args.arch} non supported")
        sys.exit(1)
    model.cuda()
    utils.load_pretrained_weights(model, args.pretrained_weights, args.checkpoint_key, args.arch, args.patch_size)
    model.eval()

    # ============ extract features ... ============
    print("Extracting features for train set...")
    train_features = extract_features(model, data_loader_train, args.use_cuda)
    print("Extracting features for val set...")
    test_features = extract_features(model, data_loader_val, args.use_cuda)

    # Features are only materialized on rank 0 (None elsewhere).
    if utils.get_rank() == 0:
        train_features = nn.functional.normalize(train_features, dim=1, p=2)
        test_features = nn.functional.normalize(test_features, dim=1, p=2)

    # The label is the last element of each ImageFolder sample tuple.
    train_labels = torch.tensor([s[-1] for s in dataset_train.samples]).long()
    test_labels = torch.tensor([s[-1] for s in dataset_val.samples]).long()
    # save features and labels
    if args.dump_features and dist.get_rank() == 0:
        torch.save(train_features.cpu(), os.path.join(args.dump_features, "trainfeat.pth"))
        torch.save(test_features.cpu(), os.path.join(args.dump_features, "testfeat.pth"))
        torch.save(train_labels.cpu(), os.path.join(args.dump_features, "trainlabels.pth"))
        torch.save(test_labels.cpu(), os.path.join(args.dump_features, "testlabels.pth"))
    return train_features, test_features, train_labels, test_labels


@torch.no_grad()
def extract_features(model, data_loader, use_cuda=True, multiscale=False):
    """Run ``model`` over ``data_loader`` and gather all features on rank 0.

    Args:
        model: callable mapping a batch of images to a 2-D feature tensor.
        data_loader: loader yielding ``(samples, index)`` batches, where
            ``index`` holds each sample's position in ``data_loader.dataset``.
        use_cuda: if True the rank-0 storage matrix is kept on the GPU,
            otherwise it stays on the CPU.
        multiscale: if True features come from ``utils.multi_scale``.

    Returns:
        On rank 0, a ``(len(data_loader.dataset), feature_dim)`` tensor whose
        row ``i`` holds the features of dataset item ``i``. On every other
        rank, ``None``.
    """
    metric_logger = utils.MetricLogger(delimiter="  ")
    features = None
    for samples, index in metric_logger.log_every(data_loader, 10):
        samples = samples.cuda(non_blocking=True)
        index = index.cuda(non_blocking=True)
        if multiscale:
            feats = utils.multi_scale(samples, model)
        else:
            feats = model(samples).clone()

        # init storage feature matrix
        if dist.get_rank() == 0 and features is None:
            features = torch.zeros(len(data_loader.dataset), feats.shape[-1])
            if use_cuda:
                features = features.cuda(non_blocking=True)
            print(f"Storing features into tensor of shape {features.shape}")

        # get indexes from all processes
        y_all = torch.empty(dist.get_world_size(), index.size(0), dtype=index.dtype, device=index.device)
        # unbind gives views into y_all, which serve as all_gather outputs
        y_l = list(y_all.unbind(0))
        y_all_reduce = torch.distributed.all_gather(y_l, index, async_op=True)
        y_all_reduce.wait()
        index_all = torch.cat(y_l)

        # share features between processes
        feats_all = torch.empty(
            dist.get_world_size(),
            feats.size(0),
            feats.size(1),
            dtype=feats.dtype,
            device=feats.device,
        )
        output_l = list(feats_all.unbind(0))
        output_all_reduce = torch.distributed.all_gather(output_l, feats, async_op=True)
        output_all_reduce.wait()

        # update storage feature matrix
        if dist.get_rank() == 0:
            if use_cuda:
                features.index_copy_(0, index_all, torch.cat(output_l))
            else:
                features.index_copy_(0, index_all.cpu(), torch.cat(output_l).cpu())
    return features


@torch.no_grad()
def knn_classifier(train_features, train_labels, test_features, test_labels, k, T, num_classes=1000):
    """Evaluate a similarity-weighted k-NN classifier.

    Each test sample votes over its ``k`` most similar train samples
    (dot-product similarity); every neighbour's vote for its label is
    weighted by ``exp(similarity / T)``.

    Args:
        train_features: ``(num_train, dim)`` tensor.
        train_labels: ``(num_train,)`` integer tensor with values in
            ``[0, num_classes)``.
        test_features: ``(num_test, dim)`` tensor.
        test_labels: ``(num_test,)`` integer tensor.
        k: number of neighbours used for voting.
        T: temperature applied to the similarities before exponentiation.
        num_classes: number of label classes.

    Returns:
        Tuple ``(top1, top5)`` of accuracies in percent. ``top5`` actually
        counts the top ``min(5, k)`` predictions.
    """
    top1, top5, total = 0.0, 0.0, 0
    train_features = train_features.t()
    num_test_images, num_chunks = test_labels.shape[0], 100
    # Clamp to 1: with fewer than ``num_chunks`` test images the integer
    # division is 0, and ``range`` rejects a zero step.
    imgs_per_chunk = max(1, num_test_images // num_chunks)
    # Allocate on the features' device rather than unconditionally on CUDA,
    # so the classifier also works when the features are kept on the CPU.
    retrieval_one_hot = torch.zeros(k, num_classes, device=train_features.device)
    for idx in range(0, num_test_images, imgs_per_chunk):
        # get the features for test images
        chunk_end = min(idx + imgs_per_chunk, num_test_images)
        features = test_features[idx:chunk_end, :]
        targets = test_labels[idx:chunk_end]
        batch_size = targets.shape[0]

        # calculate the dot product and compute top-k neighbors
        similarity = torch.mm(features, train_features)
        distances, indices = similarity.topk(k, largest=True, sorted=True)
        candidates = train_labels.view(1, -1).expand(batch_size, -1)
        retrieved_neighbors = torch.gather(candidates, 1, indices)

        # one-hot encode the label of every retrieved neighbour
        retrieval_one_hot.resize_(batch_size * k, num_classes).zero_()
        retrieval_one_hot.scatter_(1, retrieved_neighbors.view(-1, 1), 1)
        distances_transform = distances.clone().div_(T).exp_()
        # weighted vote: sum the one-hot labels scaled by exp(similarity / T)
        probs = torch.sum(
            torch.mul(
                retrieval_one_hot.view(batch_size, -1, num_classes),
                distances_transform.view(batch_size, -1, 1),
            ),
            1,
        )
        _, predictions = probs.sort(1, True)

        # find the predictions that match the target
        correct = predictions.eq(targets.data.view(-1, 1))
        top1 = top1 + correct.narrow(1, 0, 1).sum().item()
        top5 = top5 + correct.narrow(1, 0, min(5, k)).sum().item()  # top5 does not make sense if k < 5
        total += targets.size(0)
    top1 = top1 * 100.0 / total
    top5 = top5 * 100.0 / total
    return top1, top5


class ReturnIndexDataset(datasets.ImageFolder):
    """ImageFolder variant whose items are ``(image, index)`` pairs.

    The label is dropped; the dataset index is returned instead so that
    extracted features can be written back to the matching row.
    """

    def __getitem__(self, idx):
        img, _ = super(ReturnIndexDataset, self).__getitem__(idx)
        return img, idx


if __name__ == '__main__':
    parser = argparse.ArgumentParser('Evaluation with weighted k-NN on ImageNet')
    parser.add_argument('--batch_size_per_gpu', default=128, type=int, help='Per-GPU batch-size')
    parser.add_argument('--nb_knn', default=[10, 20, 100, 200], nargs='+', type=int,
        help='Number of NN to use. 20 is usually working the best.')
    parser.add_argument('--temperature', default=0.07, type=float,
        help='Temperature used in the voting coefficient')
    parser.add_argument('--pretrained_weights', default='', type=str, help="Path to pretrained weights to evaluate.")
    parser.add_argument('--use_cuda', default=True, type=utils.bool_flag,
        help="Should we store the features on GPU? We recommend setting this to False if you encounter OOM")
    parser.add_argument('--arch', default='vit_small', type=str, help='Architecture')
    parser.add_argument('--patch_size', default=16, type=int, help='Patch resolution of the model.')
    parser.add_argument("--checkpoint_key", default="teacher", type=str,
        help='Key to use in the checkpoint (example: "teacher")')
    parser.add_argument('--dump_features', default=None,
        help='Path where to save computed features, empty for no saving')
    parser.add_argument('--load_features', default=None,
        help="""If the features have already been computed, where to find them.""")
    parser.add_argument('--num_workers', default=10, type=int, help='Number of data loading workers per GPU.')
    parser.add_argument("--dist_url", default="env://", type=str,
        help="""url used to set up distributed training; see https://pytorch.org/docs/stable/distributed.html""")
    parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.")
    parser.add_argument('--data_path', default='/path/to/imagenet/', type=str)
    args = parser.parse_args()

    utils.init_distributed_mode(args)
    print("git:\n {}\n".format(utils.get_sha()))
    print("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
    cudnn.benchmark = True

    if args.load_features:
        train_features = torch.load(os.path.join(args.load_features, "trainfeat.pth"))
        test_features = torch.load(os.path.join(args.load_features, "testfeat.pth"))
        train_labels = torch.load(os.path.join(args.load_features, "trainlabels.pth"))
        test_labels = torch.load(os.path.join(args.load_features, "testlabels.pth"))
    else:
        # need to extract features !
        train_features, test_features, train_labels, test_labels = extract_feature_pipeline(args)

    # Only rank 0 runs the classification; the other ranks wait at the barrier.
    if utils.get_rank() == 0:
        if args.use_cuda:
            train_features = train_features.cuda()
            test_features = test_features.cuda()
            train_labels = train_labels.cuda()
            test_labels = test_labels.cuda()

        print("Features are ready!\nStart the k-NN classification.")
        for k in args.nb_knn:
            top1, top5 = knn_classifier(train_features, train_labels,
                test_features, test_labels, k, args.temperature)
            print(f"{k}-NN classifier result: Top1: {top1}, Top5: {top5}")
    dist.barrier()