# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import argparse
from functools import partial
import json
import logging
import os
import sys
from typing import List, Optional

import torch
from torch.nn.functional import one_hot, softmax

import dinov2.distributed as distributed
from dinov2.data import SamplerType, make_data_loader, make_dataset
from dinov2.data.transforms import make_classification_eval_transform
from dinov2.eval.metrics import AccuracyAveraging, build_topk_accuracy_metric
from dinov2.eval.setup import get_args_parser as get_setup_args_parser
from dinov2.eval.setup import setup_and_build_model
from dinov2.eval.utils import ModelWithNormalize, evaluate, extract_features


logger = logging.getLogger("dinov2")


def get_args_parser(
    description: Optional[str] = None,
    parents: Optional[List[argparse.ArgumentParser]] = None,
    add_help: bool = True,
):
    """Build the command-line parser for the k-NN evaluation.

    The parser returned by ``get_setup_args_parser`` (built from ``parents``)
    is used as the single parent of the parser created here, so its options
    are inherited.

    Args:
        description: Text passed to ``argparse.ArgumentParser``.
        parents: Optional parent parsers forwarded to the setup parser.
        add_help: Whether the returned parser registers ``-h/--help``.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parents = parents or []
    setup_args_parser = get_setup_args_parser(parents=parents, add_help=False)
    parents = [setup_args_parser]
    parser = argparse.ArgumentParser(
        description=description,
        parents=parents,
        add_help=add_help,
    )
    parser.add_argument(
        "--train-dataset",
        dest="train_dataset_str",
        type=str,
        help="Training dataset",
    )
    parser.add_argument(
        "--val-dataset",
        dest="val_dataset_str",
        type=str,
        help="Validation dataset",
    )
    parser.add_argument(
        "--nb_knn",
        nargs="+",
        type=int,
        help="Number of NN to use. 20 is usually working the best.",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        help="Temperature used in the voting coefficient",
    )
    parser.add_argument(
        "--gather-on-cpu",
        action="store_true",
        # The two literals are concatenated by the parser; the separating
        # space must be part of one of them.
        help="Whether to gather the train features on cpu, slower "
        "but useful to avoid OOM for large datasets (e.g. ImageNet22k).",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        help="Batch size.",
    )
    parser.add_argument(
        "--n-per-class-list",
        nargs="+",
        type=int,
        help="Number to take per class",
    )
    parser.add_argument(
        "--n-tries",
        type=int,
        help="Number of tries",
    )
    parser.set_defaults(
        train_dataset_str="ImageNet:split=TRAIN",
        val_dataset_str="ImageNet:split=VAL",
        nb_knn=[10, 20, 100, 200],
        temperature=0.07,
        batch_size=256,
        n_per_class_list=[-1],
        n_tries=1,
    )
    return parser


class KnnModule(torch.nn.Module):
    """
    Gets knn of test features from all processes on a chunk of the train features

    Each rank gets a chunk of the train features as well as a chunk of the test features.
    In `compute_neighbors`, for each rank one after the other, its chunk of test features
    is sent to all devices, partial knns are computed with each chunk of train features
    then collated back on the original device.
    """

    def __init__(self, train_features, train_labels, nb_knn, T, device, num_classes=1000):
        """Keep this rank's chunk of the train set and the voting parameters.

        Args:
            train_features: Train features; split into ``global_size`` chunks
                along dim 0, of which only this rank's chunk is kept
                (transposed, on ``device``).
            train_labels: Train labels, chunked the same way and stored as a
                single row on ``device``.
            nb_knn: Iterable of neighbour counts ``k`` to report.
            T: Temperature dividing the similarities before the softmax.
            device: Device holding the chunk and receiving broadcasts.
            num_classes: Width of the one-hot encoding of neighbour labels.
        """
        super().__init__()

        self.global_rank = distributed.get_global_rank()
        self.global_size = distributed.get_global_size()

        self.device = device
        self.train_features_rank_T = train_features.chunk(self.global_size)[self.global_rank].T.to(self.device)
        self.candidates = train_labels.chunk(self.global_size)[self.global_rank].view(1, -1).to(self.device)

        self.nb_knn = nb_knn
        self.max_k = max(self.nb_knn)
        self.T = T
        self.num_classes = num_classes

    def _get_knn_sims_and_labels(self, similarity, train_labels):
        """Return the ``max_k`` largest similarities per row and their labels."""
        topk_sims, indices = similarity.topk(self.max_k, largest=True, sorted=True)
        neighbors_labels = torch.gather(train_labels, 1, indices)
        return topk_sims, neighbors_labels

    def _similarity_for_rank(self, features_rank, source_rank):
        """Compute the local top-k neighbours of ``source_rank``'s features.

        Every rank must call this with the same ``source_rank``: it performs
        two collective broadcasts (shape first, then data).
        """
        # Send the features from `source_rank` to all ranks
        broadcast_shape = torch.tensor(features_rank.shape).to(self.device)
        torch.distributed.broadcast(broadcast_shape, source_rank)

        broadcasted = features_rank
        if self.global_rank != source_rank:
            # Receiving ranks allocate a buffer of the announced shape.
            broadcasted = torch.zeros(*broadcast_shape, dtype=features_rank.dtype, device=self.device)
        torch.distributed.broadcast(broadcasted, source_rank)

        # Compute the neighbors for `source_rank` among `train_features_rank_T`
        similarity_rank = torch.mm(broadcasted, self.train_features_rank_T)
        candidate_labels = self.candidates.expand(len(similarity_rank), -1)
        return self._get_knn_sims_and_labels(similarity_rank, candidate_labels)

    def _gather_all_knn_for_rank(self, topk_sims, neighbors_labels, target_rank):
        """Gather every rank's partial top-k on ``target_rank`` and merge them.

        Returns:
            On ``target_rank``, the ``(topk_sims, neighbors_labels)`` pair
            selected among all gathered candidates; ``None`` on other ranks.
        """
        # Gather all neighbors for `target_rank`
        topk_sims_rank = retrieved_rank = None
        if self.global_rank == target_rank:
            topk_sims_rank = [torch.zeros_like(topk_sims) for _ in range(self.global_size)]
            retrieved_rank = [torch.zeros_like(neighbors_labels) for _ in range(self.global_size)]

        torch.distributed.gather(topk_sims, topk_sims_rank, dst=target_rank)
        torch.distributed.gather(neighbors_labels, retrieved_rank, dst=target_rank)

        if self.global_rank == target_rank:
            # Perform a second top-k on the k * global_size retrieved neighbors
            topk_sims_rank = torch.cat(topk_sims_rank, dim=1)
            retrieved_rank = torch.cat(retrieved_rank, dim=1)
            results = self._get_knn_sims_and_labels(topk_sims_rank, retrieved_rank)
            return results
        return None

    def compute_neighbors(self, features_rank):
        """Return the global top-k similarities and labels for this rank's features.

        The loop must run over every rank without early exit, because each
        iteration issues collectives that all ranks take part in. Only the
        iteration whose target is this rank yields a non-``None`` result.
        """
        for rank in range(self.global_size):
            topk_sims, neighbors_labels = self._similarity_for_rank(features_rank, rank)
            results = self._gather_all_knn_for_rank(topk_sims, neighbors_labels, rank)
            if results is not None:
                topk_sims_rank, neighbors_labels_rank = results
        return topk_sims_rank, neighbors_labels_rank

    def forward(self, features_rank):
        """
        Compute the results on all values of `self.nb_knn` neighbors from the full `self.max_k`

        Returns a dict mapping each ``k`` in ``self.nb_knn`` to the per-class sum,
        over the ``k`` nearest neighbours, of their softmax-weighted one-hot labels.
        """
        assert all(k <= self.max_k for k in self.nb_knn)

        topk_sims, neighbors_labels = self.compute_neighbors(features_rank)
        batch_size = neighbors_labels.shape[0]
        topk_sims_transform = softmax(topk_sims / self.T, 1)
        # Weight each neighbour's one-hot label by its softmax coefficient.
        matmul = torch.mul(
            one_hot(neighbors_labels, num_classes=self.num_classes),
            topk_sims_transform.view(batch_size, -1, 1),
        )
        probas_for_k = {k: torch.sum(matmul[:, :k, :], 1) for k in self.nb_knn}
        return probas_for_k


class DictKeysModule(torch.nn.Module):
    """Select one entry of a nested dict by following a fixed list of keys."""

    def __init__(self, keys):
        super().__init__()
        self.keys = keys

    def forward(self, features_dict, targets):
        """Descend ``features_dict`` along ``self.keys`` and pair the result with ``targets``."""
        for k in self.keys:
            features_dict = features_dict[k]
        return {"preds": features_dict, "target": targets}


def create_module_dict(*, module, n_per_class_list, n_tries, nb_knn, train_features, train_labels):
    """Instantiate one k-NN module per (samples-per-class, try) setting.

    Args:
        module: Callable building a module from ``train_features``,
            ``train_labels`` and ``nb_knn`` keyword arguments.
        n_per_class_list: Iterable of samples-per-class counts; a negative
            value means "use the full train set" (stored under ``"full"``).
        n_tries: Number of seeded subsamples for each non-negative count.
        nb_knn: Iterable (list or tuple) of neighbour counts.
        train_features: Train features indexed by the sampled indices.
        train_labels: Train labels, used to group indices by class.

    Returns:
        A ``ModuleDictWithForward`` keyed by ``"full"`` or ``"<n> per class"``,
        each value being a ``ModuleDictWithForward`` keyed by try id.
    """
    modules = {}
    # Built lazily: the per-class index mapping is only needed when subsampling.
    mapping = None
    for npc in n_per_class_list:
        if npc < 0:  # Only one try needed when using the full data
            full_module = module(
                train_features=train_features,
                train_labels=train_labels,
                nb_knn=nb_knn,
            )
            modules["full"] = ModuleDictWithForward({"1": full_module})
            continue
        if mapping is None:
            mapping = create_class_indices_mapping(train_labels)
        # Set union accepts both list and tuple `nb_knn` (tuple + list would
        # raise TypeError); keep only the k values not exceeding npc.
        k_list = sorted(k for k in {*nb_knn, npc} if k <= npc)
        all_tries = {}
        for t in range(n_tries):
            final_indices = filter_train(mapping, npc, seed=t)
            all_tries[str(t)] = module(
                train_features=train_features[final_indices],
                train_labels=train_labels[final_indices],
                nb_knn=k_list,
            )
        modules[f"{npc} per class"] = ModuleDictWithForward(all_tries)

    return ModuleDictWithForward(modules)


def filter_train(mapping, n_per_class, seed):
    """Randomly pick at most ``n_per_class`` train indices from every class.

    Seeds the global torch RNG with ``seed`` so a given try is reproducible.

    Args:
        mapping: Dict from class to an ``(n, 1)`` tensor of sample indices,
            as built by ``create_class_indices_mapping``.
        n_per_class: Maximum number of indices kept per class.
        seed: Value passed to ``torch.manual_seed``.

    Returns:
        Tensor of the selected indices, concatenated over classes.
    """
    torch.manual_seed(seed)
    final_indices = []
    for class_indices in mapping.values():
        index = torch.randperm(len(class_indices))[:n_per_class]
        final_indices.append(class_indices[index])
    # Drop only the column axis: a bare squeeze() would also collapse the
    # sample axis into a 0-dim tensor when a single index is selected.
    return torch.cat(final_indices).squeeze(1)


def create_class_indices_mapping(labels):
    """Map every distinct label to the positions where it occurs in ``labels``."""
    unique_labels, inverse = torch.unique(labels, return_inverse=True)
    mapping = {unique_labels[i]: (inverse == i).nonzero() for i in range(len(unique_labels))}
    return mapping


class ModuleDictWithForward(torch.nn.ModuleDict):
    """``ModuleDict`` whose forward feeds the same inputs to every child module."""

    def forward(self, *args, **kwargs):
        """Return ``{name: child(*args, **kwargs)}`` for every child module."""
        return {k: module(*args, **kwargs) for k, module in self._modules.items()}


def eval_knn(
    model,
    train_dataset,
    val_dataset,
    accuracy_averaging,
    nb_knn,
    temperature,
    batch_size,
    num_workers,
    gather_on_cpu,
    n_per_class_list=(-1,),
    n_tries=1,
):
    """Run the k-NN classification and return the metrics per setting.

    Train features are extracted once, a k-NN module is built for every
    (samples-per-class, try) setting, and all of them are evaluated on the
    validation loader. Results of the tries are then averaged.

    Args:
        model: Feature extractor; wrapped in ``ModelWithNormalize``.
        train_dataset: Dataset providing the neighbour candidates.
        val_dataset: Dataset to classify.
        accuracy_averaging: Forwarded to ``build_topk_accuracy_metric``.
        nb_knn: Iterable of neighbour counts ``k``.
        temperature: Temperature given to ``KnnModule`` as ``T``.
        batch_size: Batch size for feature extraction and validation.
        num_workers: Number of data-loading workers.
        gather_on_cpu: Forwarded to ``extract_features``.
        n_per_class_list: Samples-per-class settings; negative means full data.
        n_tries: Number of tries for each subsampled setting.

    Returns:
        Dict keyed by ``(n_per_class, k)`` whose values map each metric name
        to its mean over the tries.
    """
    model = ModelWithNormalize(model)

    logger.info("Extracting features for train set...")
    train_features, train_labels = extract_features(
        model, train_dataset, batch_size, num_workers, gather_on_cpu=gather_on_cpu
    )
    logger.info(f"Train features created, shape {train_features.shape}.")

    val_dataloader = make_data_loader(
        dataset=val_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        sampler_type=SamplerType.DISTRIBUTED,
        drop_last=False,
        shuffle=False,
        persistent_workers=True,
    )
    num_classes = train_labels.max() + 1
    metric_collection = build_topk_accuracy_metric(accuracy_averaging, num_classes=num_classes)

    device = torch.cuda.current_device()
    partial_module = partial(KnnModule, T=temperature, device=device, num_classes=num_classes)
    knn_module_dict = create_module_dict(
        module=partial_module,
        n_per_class_list=n_per_class_list,
        n_tries=n_tries,
        nb_knn=nb_knn,
        train_features=train_features,
        train_labels=train_labels,
    )
    # One postprocessor and one metric per (setting, try, k) triple.
    postprocessors, metrics = {}, {}
    for n_per_class, knn_module in knn_module_dict.items():
        for t, knn_try in knn_module.items():
            for k in knn_try.nb_knn:
                postprocessors[(n_per_class, t, k)] = DictKeysModule([n_per_class, t, k])
                metrics[(n_per_class, t, k)] = metric_collection.clone()
    model_with_knn = torch.nn.Sequential(model, knn_module_dict)

    # ============ evaluation ... ============
    logger.info("Start the k-NN classification.")
    _, results_dict = evaluate(model_with_knn, val_dataloader, postprocessors, metrics, device)

    # Averaging the results over the n tries for each value of n_per_class
    for n_per_class, knn_module in knn_module_dict.items():
        first_try = list(knn_module.keys())[0]
        k_list = knn_module[first_try].nb_knn
        for k in k_list:
            keys = results_dict[(n_per_class, first_try, k)].keys()  # keys are e.g. `top-1` and `top-5`
            results_dict[(n_per_class, k)] = {
                key: torch.mean(torch.stack([results_dict[(n_per_class, t, k)][key] for t in knn_module.keys()]))
                for key in keys
            }
            # The per-try entries are replaced by their average.
            for t in knn_module.keys():
                del results_dict[(n_per_class, t, k)]

    return results_dict


def eval_knn_with_model(
    model,
    output_dir,
    train_dataset_str="ImageNet:split=TRAIN",
    val_dataset_str="ImageNet:split=VAL",
    nb_knn=(10, 20, 100, 200),
    temperature=0.07,
    autocast_dtype=torch.float,
    accuracy_averaging=AccuracyAveraging.MEAN_ACCURACY,
    transform=None,
    gather_on_cpu=False,
    batch_size=256,
    num_workers=5,
    n_per_class_list=(-1,),
    n_tries=1,
):
    """Build the datasets, run ``eval_knn`` and record the results.

    On the main process, top-1 and top-5 accuracies (in percent) are logged
    and appended, one JSON object per line, to ``results_eval_knn.json``
    inside ``output_dir``. When distributed mode is enabled, all processes
    synchronise on a barrier before returning.

    Args:
        model: Feature extractor to evaluate.
        output_dir: Directory receiving ``results_eval_knn.json``.
        train_dataset_str: Dataset string for the train set.
        val_dataset_str: Dataset string for the validation set.
        nb_knn: Iterable of neighbour counts ``k``.
        temperature: Temperature used in the voting coefficient.
        autocast_dtype: dtype passed to ``torch.cuda.amp.autocast``.
        accuracy_averaging: Averaging mode for the accuracy metric.
        transform: Transform for both datasets; defaults to
            ``make_classification_eval_transform()``.
        gather_on_cpu: Whether train features are gathered on CPU.
        batch_size: Batch size.
        num_workers: Number of data-loading workers.
        n_per_class_list: Samples-per-class settings; negative means full data.
        n_tries: Number of tries for each subsampled setting.

    Returns:
        Dict of ``"<setting> Top 1"`` / ``"<setting> Top 5"`` percentages on
        the main process; an empty dict on the other processes.
    """
    transform = transform or make_classification_eval_transform()

    train_dataset = make_dataset(
        dataset_str=train_dataset_str,
        transform=transform,
    )
    val_dataset = make_dataset(
        dataset_str=val_dataset_str,
        transform=transform,
    )

    with torch.cuda.amp.autocast(dtype=autocast_dtype):
        results_dict_knn = eval_knn(
            model=model,
            train_dataset=train_dataset,
            val_dataset=val_dataset,
            accuracy_averaging=accuracy_averaging,
            nb_knn=nb_knn,
            temperature=temperature,
            batch_size=batch_size,
            num_workers=num_workers,
            gather_on_cpu=gather_on_cpu,
            n_per_class_list=n_per_class_list,
            n_tries=n_tries,
        )

    results_dict = {}
    if distributed.is_main_process():
        for knn_ in results_dict_knn.keys():
            top1 = results_dict_knn[knn_]["top-1"].item() * 100.0
            top5 = results_dict_knn[knn_]["top-5"].item() * 100.0
            results_dict[f"{knn_} Top 1"] = top1
            results_dict[f"{knn_} Top 5"] = top5
            logger.info(f"{knn_} classifier result: Top1: {top1:.2f} Top5: {top5:.2f}")

    metrics_file_path = os.path.join(output_dir, "results_eval_knn.json")
    # Append mode: results of successive runs accumulate in the same file.
    with open(metrics_file_path, "a") as f:
        for k, v in results_dict.items():
            f.write(json.dumps({k: v}) + "\n")

    if distributed.is_enabled():
        torch.distributed.barrier()
    return results_dict


def main(args):
    """Build the model from ``args``, run the k-NN evaluation and return 0."""
    model, autocast_dtype = setup_and_build_model(args)
    eval_knn_with_model(
        model=model,
        output_dir=args.output_dir,
        train_dataset_str=args.train_dataset_str,
        val_dataset_str=args.val_dataset_str,
        nb_knn=args.nb_knn,
        temperature=args.temperature,
        autocast_dtype=autocast_dtype,
        accuracy_averaging=AccuracyAveraging.MEAN_ACCURACY,
        transform=None,
        gather_on_cpu=args.gather_on_cpu,
        batch_size=args.batch_size,
        num_workers=5,
        n_per_class_list=args.n_per_class_list,
        n_tries=args.n_tries,
    )
    return 0


if __name__ == "__main__":
    description = "DINOv2 k-NN evaluation"
    args_parser = get_args_parser(description=description)
    args = args_parser.parse_args()
    sys.exit(main(args))