# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
import random
import subprocess
from urllib.parse import urlparse

import numpy as np
import torch
from torch import nn


logger = logging.getLogger("dinov2")

# Key prefixes added by wrappers around the model (`module.` and the
# `backbone.` prefix induced by the multicrop wrapper).
_WRAPPER_PREFIXES = ("module.", "backbone.")


def _strip_wrapper_prefixes(key):
    """Return `key` with every leading wrapper prefix removed.

    Only prefixes at the start of the key are stripped, in any nesting order;
    occurrences in the middle of a key (e.g. `proj_module.weight`) are kept.
    """
    stripped = True
    while stripped:
        stripped = False
        for prefix in _WRAPPER_PREFIXES:
            if key.startswith(prefix):
                key = key[len(prefix):]
                stripped = True
    return key


def load_pretrained_weights(model, pretrained_weights, checkpoint_key):
    """Load a checkpoint into `model`, tolerating missing/unexpected keys.

    Args:
        model: module whose `load_state_dict` is called with `strict=False`.
        pretrained_weights: local path, or a URL (anything `urlparse` finds a
            scheme in) that is fetched through `torch.hub`.
        checkpoint_key: optional key selecting a nested state dict inside the
            checkpoint; ignored when `None` or absent from the checkpoint.

    The result of `load_state_dict` (missing / unexpected keys) is logged.
    """
    if urlparse(pretrained_weights).scheme:  # If it looks like an URL
        state_dict = torch.hub.load_state_dict_from_url(pretrained_weights, map_location="cpu")
    else:
        # NOTE: torch.load unpickles the file; only load checkpoints from
        # trusted sources.
        state_dict = torch.load(pretrained_weights, map_location="cpu")

    if checkpoint_key is not None and checkpoint_key in state_dict:
        logger.info(f"Take key {checkpoint_key} in provided checkpoint dict")
        state_dict = state_dict[checkpoint_key]

    # Strip wrapper prefixes from the front of each key only. A plain
    # str.replace would also rewrite keys such as `proj_module.weight`, which
    # strict=False would then silently skip.
    state_dict = {_strip_wrapper_prefixes(k): v for k, v in state_dict.items()}

    msg = model.load_state_dict(state_dict, strict=False)
    logger.info("Pretrained weights found at {} and loaded with msg: {}".format(pretrained_weights, msg))


def fix_random_seeds(seed=31):
    """
    Fix random seeds.

    Seeds torch (CPU and all CUDA devices), NumPy's global generator and
    Python's `random` module with the same value.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)


def get_sha():
    """Describe the git state of the directory containing this file.

    Returns:
        A string of the form `"sha: <sha>, status: <status>, branch: <branch>"`.
        Fields keep their defaults (`N/A` / `clean`) if a git command fails;
        this function is best-effort and does not raise for git errors.
    """
    cwd = os.path.dirname(os.path.abspath(__file__))

    def _run(command):
        return subprocess.check_output(command, cwd=cwd).decode("ascii").strip()

    sha = "N/A"
    diff = "clean"
    branch = "N/A"
    try:
        sha = _run(["git", "rev-parse", "HEAD"])
        # Output is discarded on purpose; presumably this refreshes the index
        # so that `diff-index` below does not report stale entries.
        subprocess.check_output(["git", "diff"], cwd=cwd)
        diff = _run(["git", "diff-index", "HEAD"])
        diff = "has uncommitted changes" if diff else "clean"
        branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"])
    except Exception:
        # Best-effort: not being in a git checkout (or not having git) must
        # not break the caller, but leave a trace for debugging.
        logger.debug("Could not query git state in %s", cwd, exc_info=True)
    message = f"sha: {sha}, status: {diff}, branch: {branch}"
    return message


class CosineScheduler(object):
    """Precomputed per-iteration schedule: freeze, linear warmup, cosine decay.

    The schedule is the concatenation of:
      * `freeze_iters` zeros,
      * `warmup_iters` values going linearly from `start_warmup_value` to
        `base_value`,
      * a half-cosine going from `base_value` towards `final_value` over the
        remaining iterations.

    Indexing at or past `total_iters` returns `final_value`.
    """

    def __init__(self, base_value, final_value, total_iters, warmup_iters=0, start_warmup_value=0, freeze_iters=0):
        super().__init__()
        self.final_value = final_value
        self.total_iters = total_iters

        freeze_schedule = np.zeros((freeze_iters))

        warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)

        # Iterations left for the cosine phase once freeze and warmup are done.
        iters = np.arange(total_iters - warmup_iters - freeze_iters)
        schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters)))
        self.schedule = np.concatenate((freeze_schedule, warmup_schedule, schedule))

        # Fails when warmup_iters + freeze_iters exceeds total_iters.
        assert len(self.schedule) == self.total_iters

    def __getitem__(self, it):
        """Return the scheduled value at iteration `it`."""
        if it >= self.total_iters:
            return self.final_value
        else:
            return self.schedule[it]


def has_batchnorms(model):
    """Return True if `model` contains any BatchNorm or SyncBatchNorm module."""
    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
    return any(isinstance(module, bn_types) for module in model.modules())