File size: 1,333 Bytes
9ae1b1e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import itertools
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

# Root URL on dl.fbaipublicfiles.com for DINOv2; presumably where released
# assets are hosted (not referenced in the visible part of this file).
_DINOV2_BASE_URL = 'https://dl.fbaipublicfiles.com/dinov2'


def _make_dinov2_model_name(
    arch_name: str, patch_size: int, num_register_tokens: int = 0
) -> str:
    compact_arch_name = arch_name.replace('_', '')[:4]
    registers_suffix = (
        f'_reg{num_register_tokens}' if num_register_tokens else ''
    )
    return f'dinov2_{compact_arch_name}{patch_size}{registers_suffix}'


class CenterPadding(nn.Module):
    """Pad trailing dimensions of a tensor up to a multiple of a given size.

    Every dimension after the first two is padded so that its length becomes
    the next multiple of ``multiple``. The padding is split across both
    sides; when the total is odd, the extra element goes on the trailing
    side.
    """

    def __init__(self, multiple):
        """Store the size that padded dimensions must be a multiple of."""
        super().__init__()
        self.multiple = multiple

    def _get_pad(self, size):
        """Return ``(before, after)`` padding amounts for one dimension.

        ``before + after`` is the gap between ``size`` and the smallest
        multiple of ``self.multiple`` that is not less than ``size``.
        """
        target = math.ceil(size / self.multiple) * self.multiple
        # divmod splits the gap in half; an odd remainder goes to the end.
        before, leftover = divmod(target - size, 2)
        return before, before + leftover

    @torch.inference_mode()
    def forward(self, x):
        """Pad ``x`` with ``F.pad`` defaults; runs under inference mode.

        Dimensions 0 and 1 are left untouched.
        """
        # F.pad expects (before, after) pairs starting from the LAST
        # dimension, so walk the trailing dimensions in reverse order.
        pads = []
        for dim_size in reversed(x.shape[2:]):
            pads.extend(self._get_pad(dim_size))
        return F.pad(x, pads)