|
|
|
|
|
|
|
|
|
|
| import logging
|
| import os
|
| import random
|
| import subprocess
|
| from typing import Callable, List, Optional, Tuple
|
|
|
| import numpy as np
|
| import torch
|
| from torch import Tensor, nn
|
|
|
| logger = logging.getLogger("dinov3")
|
|
|
|
|
def cat_keep_shapes(x_list: List[Tensor]) -> Tuple[Tensor, List[Tuple[int]], List[int]]:
    """Flatten and concatenate tensors, keeping the metadata needed to undo it.

    Every tensor is flattened over all dimensions except the last, then all are
    concatenated along dim 0. Returns the merged tensor plus the original shapes
    and per-tensor token counts, as consumed by ``uncat_with_shapes``.
    """
    original_shapes = [t.shape for t in x_list]
    # Tokens per tensor = product of every dim but the last (select drops dim -1).
    token_counts = [t.select(dim=-1, index=0).numel() for t in x_list]
    merged = torch.cat([t.flatten(0, -2) for t in x_list])
    return merged, original_shapes, token_counts
|
|
|
|
|
def uncat_with_shapes(flattened: Tensor, shapes: List[Tuple[int]], num_tokens: List[int]) -> List[Tensor]:
    """Inverse of ``cat_keep_shapes``: split ``flattened`` and restore shapes.

    The trailing dimension is taken from ``flattened`` itself (it may differ
    from the original last dim, e.g. after a linear projection).
    """
    pieces = torch.split_with_sizes(flattened, num_tokens, dim=0)
    last_dim = torch.Size([flattened.shape[-1]])
    return [piece.reshape(shape[:-1] + last_dim) for piece, shape in zip(pieces, shapes)]
|
|
|
|
|
def named_replace(
    fn: Callable,
    module: nn.Module,
    name: str = "",
    depth_first: bool = True,
    include_root: bool = False,
) -> nn.Module:
    """Recursively rebuild *module* by applying *fn* to every submodule.

    ``fn(module=..., name=...)`` must return the (possibly replaced) module; the
    result is re-attached under the child's original attribute name. With
    ``depth_first`` the root is transformed after its children (post-order),
    otherwise before them. ``include_root`` controls whether *fn* touches the
    top-level module at all; descendants are always visited.
    """
    pre_order = not depth_first
    if pre_order and include_root:
        module = fn(module=module, name=name)
    # Snapshot children up front since setattr mutates the module as we go.
    for attr_name, child in list(module.named_children()):
        qualified = f"{name}.{attr_name}" if name else attr_name
        replacement = named_replace(
            fn=fn,
            module=child,
            name=qualified,
            depth_first=depth_first,
            include_root=True,  # children are always eligible for replacement
        )
        setattr(module, attr_name, replacement)
    if depth_first and include_root:
        module = fn(module=module, name=name)
    return module
|
|
|
|
|
def named_apply(
    fn: Callable,
    module: nn.Module,
    name: str = "",
    depth_first: bool = True,
    include_root: bool = False,
) -> nn.Module:
    """Recursively call ``fn(module=..., name=...)`` on submodules for side effects.

    Unlike ``named_replace``, the return value of *fn* is ignored and the module
    tree is left untouched. ``depth_first`` selects post-order (children first)
    versus pre-order traversal; ``include_root`` controls whether the top-level
    module itself is visited. Descendants are always visited.
    """
    visit_root_first = (not depth_first) and include_root
    if visit_root_first:
        fn(module=module, name=name)
    for child_name, child in module.named_children():
        full_name = f"{name}.{child_name}" if name else child_name
        named_apply(
            fn=fn,
            module=child,
            name=full_name,
            depth_first=depth_first,
            include_root=True,  # all descendants are visited
        )
    if depth_first and include_root:
        fn(module=module, name=name)
    return module
|
|
|
|
|
def fix_random_seeds(seed: int = 31):
    """Seed every RNG in use (Python, NumPy, torch CPU and all CUDA devices).

    Makes subsequent draws reproducible across the stdlib, NumPy and torch.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # no-op when CUDA is unavailable
|
|
|
|
|
def get_sha() -> str:
    """Return a one-line summary of the git state of this file's directory.

    Best-effort: if any git command fails (git missing, not a repository, ...)
    the corresponding field keeps its placeholder value.
    """
    repo_dir = os.path.dirname(os.path.abspath(__file__))

    def _git(command):
        return subprocess.check_output(command, cwd=repo_dir).decode("ascii").strip()

    sha, diff, branch = "N/A", "clean", "N/A"
    try:
        sha = _git(["git", "rev-parse", "HEAD"])
        subprocess.check_output(["git", "diff"], cwd=repo_dir)
        dirty = _git(["git", "diff-index", "HEAD"])
        diff = "has uncommited changes" if dirty else "clean"
        branch = _git(["git", "rev-parse", "--abbrev-ref", "HEAD"])
    except Exception:
        # Deliberate best-effort: fall back to placeholders on any failure.
        pass
    return f"sha: {sha}, status: {diff}, branch: {branch}"
|
|
|
|
|
def get_conda_env() -> Tuple[Optional[str], Optional[str]]:
    """Return ``(name, prefix path)`` of the active conda env, ``None`` when unset."""
    environ = os.environ
    return environ.get("CONDA_DEFAULT_ENV"), environ.get("CONDA_PREFIX")
|
|
|
|
|
def count_parameters(module: nn.Module) -> int:
    """Return the total number of elements across all parameters of *module*.

    Counts every element of every tensor yielded by ``module.parameters()``
    (recursing into submodules); returns 0 for a parameter-free module.
    """
    # Generator + sum() replaces the manual accumulator loop (same result).
    return sum(p.numel() for p in module.parameters())
|
|
|
|
|
def has_batchnorms(model: nn.Module) -> bool:
    """Return True if *model* contains any batch-norm layer (incl. SyncBatchNorm).

    Checks the model itself and all of its submodules.
    """
    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
    # modules() instead of named_modules(): the names were fetched and discarded.
    return any(isinstance(module, bn_types) for module in model.modules())
|
|
|