# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import gc
import logging
import sys
import time
from typing import List, Optional

from cuml.linear_model import LogisticRegression
import torch
import torch.backends.cudnn as cudnn
import torch.distributed
from torch import nn
from torch.utils.data import TensorDataset
from torchmetrics import MetricTracker

from dinov2.data import make_dataset
from dinov2.data.transforms import make_classification_eval_transform
from dinov2.distributed import get_global_rank, get_global_size
from dinov2.eval.metrics import MetricType, build_metric
from dinov2.eval.setup import get_args_parser as get_setup_args_parser
from dinov2.eval.setup import setup_and_build_model
from dinov2.eval.utils import evaluate, extract_features
from dinov2.utils.dtype import as_torch_dtype


logger = logging.getLogger("dinov2")

# Default cap on solver iterations for each logistic-regression fit.
DEFAULT_MAX_ITER = 1_000
# Exponents swept during model selection: the candidate C values are 10**p for
# 45 evenly spaced p in [-6, 5].
C_POWER_RANGE = torch.linspace(-6, 5, 45)
_CPU_DEVICE = torch.device("cpu")


def get_args_parser(
    description: Optional[str] = None,
    parents: Optional[List[argparse.ArgumentParser]] = None,
    add_help: bool = True,
):
    """Build the CLI parser for this evaluation, layered on top of the shared setup parser.

    Args:
        description: Text shown in ``--help``.
        parents: Extra parent parsers forwarded to the setup parser.
        add_help: Whether the returned parser adds ``-h/--help``.

    Returns:
        An ``argparse.ArgumentParser`` with dataset, metric and training options plus their defaults.
    """
    parents = parents or []
    setup_args_parser = get_setup_args_parser(parents=parents, add_help=False)
    parents = [setup_args_parser]
    parser = argparse.ArgumentParser(
        description=description,
        parents=parents,
        add_help=add_help,
    )
    parser.add_argument(
        "--train-dataset",
        dest="train_dataset_str",
        type=str,
        help="Training dataset",
    )
    parser.add_argument(
        "--val-dataset",
        dest="val_dataset_str",
        type=str,
        help="Validation dataset",
    )
    parser.add_argument(
        "--finetune-dataset-str",
        dest="finetune_dataset_str",
        type=str,
        help="Fine-tuning dataset",
    )
    parser.add_argument(
        "--finetune-on-val",
        action="store_true",
        help="If there is no finetune dataset, whether to choose the "
        "hyperparameters on the val set instead of 10%% of the train dataset",
    )
    parser.add_argument(
        "--metric-type",
        type=MetricType,
        choices=list(MetricType),
        help="Metric type",
    )
    parser.add_argument(
        "--train-features-device",
        type=str,
        help="Device to gather train features (cpu, cuda, cuda:0, etc.), default: %(default)s",
    )
    parser.add_argument(
        "--train-dtype",
        type=str,
        help="Data type to convert the train features to (default: %(default)s)",
    )
    parser.add_argument(
        "--max-train-iters",
        type=int,
        help="Maximum number of train iterations (default: %(default)s)",
    )
    parser.set_defaults(
        train_dataset_str="ImageNet:split=TRAIN",
        val_dataset_str="ImageNet:split=VAL",
        finetune_dataset_str=None,
        metric_type=MetricType.MEAN_ACCURACY,
        train_features_device="cpu",
        train_dtype="float64",
        max_train_iters=DEFAULT_MAX_ITER,
        finetune_on_val=False,
    )
    return parser


class LogRegModule(nn.Module):
    """Wrap a cuML ``LogisticRegression`` so it can serve as a post-processor for ``evaluate``.

    Inputs are cast to ``dtype`` and moved to ``device`` before both fitting and prediction. When
    ``device`` is the CPU they are also converted to numpy arrays first.
    """

    def __init__(
        self,
        C,
        max_iter=DEFAULT_MAX_ITER,
        dtype=torch.float64,
        device=_CPU_DEVICE,
    ):
        super().__init__()
        self.dtype = dtype
        self.device = device
        self.estimator = LogisticRegression(
            penalty="l2",
            C=C,
            max_iter=max_iter,
            output_type="numpy",
            tol=1e-12,
            linesearch_max_iter=50,
        )

    def forward(self, samples, targets):
        """Return predicted class probabilities for ``samples`` together with the untouched ``targets``.

        The probabilities are moved back to the device ``samples`` arrived on, so the metric sees
        predictions and targets on the same device.
        """
        samples_device = samples.device
        samples = samples.to(dtype=self.dtype, device=self.device)
        if self.device == _CPU_DEVICE:
            samples = samples.numpy()
        probas = self.estimator.predict_proba(samples)
        return {"preds": torch.from_numpy(probas).to(samples_device), "target": targets}

    def fit(self, train_features, train_labels):
        """Fit the wrapped estimator; both features and labels are cast to ``self.dtype``."""
        train_features = train_features.to(dtype=self.dtype, device=self.device)
        train_labels = train_labels.to(dtype=self.dtype, device=self.device)
        if self.device == _CPU_DEVICE:
            # both cuML and sklearn only work with numpy arrays on CPU
            train_features = train_features.numpy()
            train_labels = train_labels.numpy()
        self.estimator.fit(train_features, train_labels)


def evaluate_model(*, logreg_model, logreg_metric, test_data_loader, device):
    """Score an already-fitted ``LogRegModule`` on ``test_data_loader`` with ``logreg_metric``.

    The loader is expected to yield precomputed features, hence the ``nn.Identity`` backbone.
    Returns whatever ``dinov2.eval.utils.evaluate`` returns; callers in this file read ``[1]["metrics"]``.
    """
    postprocessors = {"metrics": logreg_model}
    metrics = {"metrics": logreg_metric}
    return evaluate(nn.Identity(), test_data_loader, postprocessors, metrics, device)


def train_for_C(*, C, max_iter, train_features, train_labels, dtype=torch.float64, device=_CPU_DEVICE):
    """Build and fit a ``LogRegModule`` with inverse regularization strength ``C``."""
    logreg_model = LogRegModule(C, max_iter=max_iter, dtype=dtype, device=device)
    logreg_model.fit(train_features, train_labels)
    return logreg_model


def train_and_evaluate(
    *,
    C,
    max_iter,
    train_features,
    train_labels,
    logreg_metric,
    test_data_loader,
    train_dtype=torch.float64,
    train_features_device,
    eval_device,
):
    """Fit a model for ``C`` on the training features, then evaluate it on ``test_data_loader``."""
    logreg_model = train_for_C(
        C=C,
        max_iter=max_iter,
        train_features=train_features,
        train_labels=train_labels,
        dtype=train_dtype,
        device=train_features_device,
    )
    return evaluate_model(
        logreg_model=logreg_model,
        logreg_metric=logreg_metric,
        test_data_loader=test_data_loader,
        device=eval_device,
    )


def sweep_C_values(
    *,
    train_features,
    train_labels,
    test_data_loader,
    metric_type,
    num_classes,
    train_dtype=torch.float64,
    train_features_device=_CPU_DEVICE,
    max_train_iters=DEFAULT_MAX_ITER,
):
    """Pick the regularization strength ``C`` that maximizes the metric on ``test_data_loader``.

    Fitting is sharded across distributed ranks: each rank trains a strided subset of the candidate C
    values. The fitted estimators are then all-gathered so every rank evaluates the full sweep.

    Returns:
        ``(best_stats, best_C)``. ``best_stats`` maps metric names to the best value seen in the sweep;
        ``best_C`` is the candidate at the step where the ``"top-1"`` metric peaked.
    """
    if metric_type == MetricType.PER_CLASS_ACCURACY:
        # If we want to output per-class accuracy, we select the hyperparameters with mean per class
        metric_type = MetricType.MEAN_PER_CLASS_ACCURACY
    logreg_metric = build_metric(metric_type, num_classes=num_classes)
    metric_tracker = MetricTracker(logreg_metric, maximize=True)
    all_C = 10**C_POWER_RANGE
    logreg_models = {}

    train_features = train_features.to(dtype=train_dtype, device=train_features_device)
    train_labels = train_labels.to(device=train_features_device)

    # Round-robin the candidates across ranks so each rank fits roughly len(all_C) / world_size models.
    for i in range(get_global_rank(), len(all_C), get_global_size()):
        C = all_C[i].item()
        logger.info(
            f"Training for C = {C:.5f}, dtype={train_dtype}, "
            f"features: {train_features.shape}, {train_features.dtype}, "
            f"labels: {train_labels.shape}, {train_labels.dtype}"
        )
        logreg_models[C] = train_for_C(
            C=C,
            max_iter=max_train_iters,
            train_features=train_features,
            train_labels=train_labels,
            dtype=train_dtype,
            device=train_features_device,
        )

    # Every rank needs every fitted model for the evaluation pass below.
    gather_list = [None for _ in range(get_global_size())]
    torch.distributed.all_gather_object(gather_list, logreg_models)

    logreg_models_gathered = {}
    for logreg_dict in gather_list:
        logreg_models_gathered.update(logreg_dict)

    for i in range(len(all_C)):
        # One tracker step per candidate, so tracker step index == candidate index.
        metric_tracker.increment()
        C = all_C[i].item()
        evals = evaluate_model(
            logreg_model=logreg_models_gathered[C],
            logreg_metric=metric_tracker,
            test_data_loader=test_data_loader,
            device=torch.cuda.current_device(),
        )
        logger.info(f"Trained for C = {C:.5f}, accuracies = {evals}")

        best_stats, which_epoch = metric_tracker.best_metric(return_step=True)
        best_stats_100 = {k: 100.0 * v for k, v in best_stats.items()}
        # Look the winner up by the tracker's best step so best_C always matches best_stats.
        best_C = all_C[int(which_epoch["top-1"])].item()
        logger.info(f"Sweep best {best_stats_100}, best C = {best_C:.6f}")

    return best_stats, best_C


def eval_log_regression(
    *,
    model,
    train_dataset,
    val_dataset,
    finetune_dataset,
    metric_type,
    batch_size,
    num_workers,
    finetune_on_val=False,
    train_dtype=torch.float64,
    train_features_device=_CPU_DEVICE,
    max_train_iters=DEFAULT_MAX_ITER,
):
    """
    Implements the "standard" process for log regression evaluation:
    The value of C is chosen by training on train_dataset and evaluating on
    finetune_dataset. Then, the final model is trained on a concatenation of
    train_dataset and finetune_dataset, and is evaluated on val_dataset.
    If there is no finetune_dataset, the value of C is the one that yields
    the best results on a random 10% subset of the train dataset, or on the
    val dataset when ``finetune_on_val`` is set (in which case the val
    features are not added to the final training set).

    Returns:
        The metrics of the final model on ``val_dataset``, plus a ``"best_C"`` entry.
    """

    start = time.time()
    gather_on_cpu = train_features_device == _CPU_DEVICE

    train_features, train_labels = extract_features(
        model, train_dataset, batch_size, num_workers, gather_on_cpu=gather_on_cpu
    )
    val_features, val_labels = extract_features(
        model, val_dataset, batch_size, num_workers, gather_on_cpu=gather_on_cpu
    )
    val_data_loader = torch.utils.data.DataLoader(
        TensorDataset(val_features, val_labels),
        batch_size=batch_size,
        drop_last=False,
        num_workers=0,
        persistent_workers=False,
    )

    # finetune_on_val only has an effect when no dedicated finetune dataset is provided.
    tuned_on_val = finetune_dataset is None and finetune_on_val
    if tuned_on_val:
        logger.info("Choosing hyperparameters on the val dataset")
        finetune_features, finetune_labels = val_features, val_labels
    elif finetune_dataset is None:
        logger.info("Choosing hyperparameters on 10% of the train dataset")
        torch.manual_seed(0)
        indices = torch.randperm(len(train_features), device=train_features.device)
        split = len(train_features) // 10
        finetune_index = indices[:split]
        train_index = indices[split:]
        finetune_features, finetune_labels = train_features[finetune_index], train_labels[finetune_index]
        train_features, train_labels = train_features[train_index], train_labels[train_index]
    else:
        logger.info("Choosing hyperparameters on the finetune dataset")
        finetune_features, finetune_labels = extract_features(
            model, finetune_dataset, batch_size, num_workers, gather_on_cpu=gather_on_cpu
        )

    # Drop this function's reference to the model before the memory-heavy sweep.
    # NOTE(review): callers still hold their own reference, so GPU memory is only reclaimed
    # if nothing else keeps the model alive.
    del model
    gc.collect()
    torch.cuda.empty_cache()

    finetune_data_loader = torch.utils.data.DataLoader(
        TensorDataset(finetune_features, finetune_labels),
        batch_size=batch_size,
        drop_last=False,
    )

    if len(train_labels.shape) > 1:
        num_classes = train_labels.shape[1]
    else:
        # Plain int rather than a 0-d tensor, so the metric builder receives a Python integer.
        num_classes = int(train_labels.max()) + 1

    logger.info("Using cuML for logistic regression")

    best_stats, best_C = sweep_C_values(
        train_features=train_features,
        train_labels=train_labels,
        test_data_loader=finetune_data_loader,
        metric_type=metric_type,
        num_classes=num_classes,
        train_dtype=train_dtype,
        train_features_device=train_features_device,
        max_train_iters=max_train_iters,
    )

    # Never train the final model on val features; otherwise fold the tuning split back in.
    if not tuned_on_val:
        logger.info("Best parameter found, concatenating features")
        train_features = torch.cat((train_features, finetune_features))
        train_labels = torch.cat((train_labels, finetune_labels))

    logger.info("Training final model")
    logreg_metric = build_metric(metric_type, num_classes=num_classes)
    evals = train_and_evaluate(
        C=best_C,
        max_iter=max_train_iters,
        train_features=train_features,
        train_labels=train_labels,
        logreg_metric=logreg_metric.clone(),
        test_data_loader=val_data_loader,
        eval_device=torch.cuda.current_device(),
        train_dtype=train_dtype,
        train_features_device=train_features_device,
    )

    best_stats = evals[1]["metrics"]

    best_stats["best_C"] = best_C

    logger.info(f"Log regression evaluation done in {int(time.time() - start)}s")
    return best_stats


def eval_log_regression_with_model(
    model,
    train_dataset_str="ImageNet:split=TRAIN",
    val_dataset_str="ImageNet:split=VAL",
    finetune_dataset_str=None,
    autocast_dtype=torch.float,
    finetune_on_val=False,
    metric_type=MetricType.MEAN_ACCURACY,
    train_dtype=torch.float64,
    train_features_device=_CPU_DEVICE,
    max_train_iters=DEFAULT_MAX_ITER,
):
    """Build the datasets from their spec strings and run ``eval_log_regression`` under autocast.

    Returns:
        A dict with ``"top-1"`` and ``"top-5"`` (scaled to percent; top-5 falls back to 0 when the
        metric does not report it) and ``"best_C"``.
    """
    cudnn.benchmark = True

    transform = make_classification_eval_transform(resize_size=224)
    target_transform = None

    train_dataset = make_dataset(dataset_str=train_dataset_str, transform=transform, target_transform=target_transform)
    val_dataset = make_dataset(dataset_str=val_dataset_str, transform=transform, target_transform=target_transform)
    if finetune_dataset_str is not None:
        finetune_dataset = make_dataset(
            dataset_str=finetune_dataset_str, transform=transform, target_transform=target_transform
        )
    else:
        finetune_dataset = None

    with torch.cuda.amp.autocast(dtype=autocast_dtype):
        results_dict_logreg = eval_log_regression(
            model=model,
            train_dataset=train_dataset,
            val_dataset=val_dataset,
            finetune_dataset=finetune_dataset,
            metric_type=metric_type,
            batch_size=256,
            num_workers=0,  # 5,
            finetune_on_val=finetune_on_val,
            train_dtype=train_dtype,
            train_features_device=train_features_device,
            max_train_iters=max_train_iters,
        )

    results_dict = {
        "top-1": results_dict_logreg["top-1"].cpu().numpy() * 100.0,
        "top-5": results_dict_logreg.get("top-5", torch.tensor(0.0)).cpu().numpy() * 100.0,
        "best_C": results_dict_logreg["best_C"],
    }
    logger.info(
        "\n".join(
            [
                "Training of the supervised logistic regression on frozen features completed.\n"
                "Top-1 test accuracy: {acc:.1f}".format(acc=results_dict["top-1"]),
                "Top-5 test accuracy: {acc:.1f}".format(acc=results_dict["top-5"]),
                "obtained for C = {c:.6f}".format(c=results_dict["best_C"]),
            ]
        )
    )

    torch.distributed.barrier()
    return results_dict


def main(args):
    """CLI entry point: build the model from ``args`` and run the evaluation; returns the exit code."""
    model, autocast_dtype = setup_and_build_model(args)
    eval_log_regression_with_model(
        model=model,
        train_dataset_str=args.train_dataset_str,
        val_dataset_str=args.val_dataset_str,
        finetune_dataset_str=args.finetune_dataset_str,
        autocast_dtype=autocast_dtype,
        finetune_on_val=args.finetune_on_val,
        metric_type=args.metric_type,
        train_dtype=as_torch_dtype(args.train_dtype),
        train_features_device=torch.device(args.train_features_device),
        max_train_iters=args.max_train_iters,
    )
    return 0


if __name__ == "__main__":
    description = "DINOv2 logistic regression evaluation"
    args_parser = get_args_parser(description=description)
    args = args_parser.parse_args()
    sys.exit(main(args))