File size: 3,251 Bytes
3d3e4e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
import math
import warnings
from collections import defaultdict
from dataclasses import field, dataclass
from typing import Any, Dict, List, Optional, Tuple, Union, Callable


import torch
import torch.nn as nn
import torchvision

import io
from PIL import Image
import numpy as np

logger = logging.getLogger(__name__)

_RESNET_MEAN = [0.485, 0.456, 0.406]
_RESNET_STD = [0.229, 0.224, 0.225]


class MultiScaleImageFeatureExtractor(nn.Module):
    """Backbone wrapper that averages image features over multiple scales.

    The input RGB image is normalized with ImageNet statistics, passed
    through the backbone once per entry in ``scale_factors`` (resized
    bilinearly for factors != 1), and the per-scale feature vectors are
    averaged.

    Args:
        modelname: a torchvision ResNet name (contains "res") or a DINO
            torch.hub model name (contains "dino"), e.g. "dino_vits16".
        freeze: if True, disable gradients for all backbone parameters.
        scale_factors: multiplicative resize factors applied to the input
            image; defaults to [1, 1/2, 1/3] when None.

    Raises:
        ValueError: if ``modelname`` matches neither supported family.
    """

    def __init__(
        self,
        modelname: str = "dino_vits16",
        freeze: bool = False,
        scale_factors: Optional[List[float]] = None,
    ):
        super().__init__()
        self.freeze = freeze
        # Fix for the mutable-default-argument pitfall: the previous
        # signature used `scale_factors: list = [1, 1/2, 1/3]`, sharing one
        # list object across every instance (and exposed to caller mutation).
        # A None sentinel plus a defensive copy preserves the same default.
        self.scale_factors = (
            [1, 1 / 2, 1 / 3] if scale_factors is None else list(scale_factors)
        )

        if "res" in modelname:
            # NOTE(review): `pretrained=` is deprecated in newer torchvision
            # in favor of `weights=`; kept as-is for compatibility with the
            # torchvision version this file targets.
            self._net = getattr(torchvision.models, modelname)(pretrained=True)
            # Feature dim = input dim of the final fc layer, which is then
            # replaced by Identity so the net returns pooled features.
            self._output_dim = self._net.fc.weight.shape[1]
            self._net.fc = nn.Identity()
        elif "dino" in modelname:
            self._net = torch.hub.load("facebookresearch/dino:main", modelname)
            self._output_dim = self._net.norm.weight.shape[0]
        else:
            raise ValueError(f"Unknown model name {modelname}")

        for name, value in (
            ("_resnet_mean", _RESNET_MEAN),
            ("_resnet_std", _RESNET_STD),
        ):
            # Non-persistent buffers: they follow .to()/.cuda() moves but
            # are excluded from the state_dict.
            self.register_buffer(
                name,
                torch.FloatTensor(value).view(1, 3, 1, 1),
                persistent=False,
            )

        if self.freeze:
            for param in self.parameters():
                param.requires_grad = False

    def get_output_dim(self) -> int:
        """Return the dimensionality of the extracted feature vectors."""
        return self._output_dim

    def forward(self, image_rgb: torch.Tensor) -> torch.Tensor:
        """Compute multi-scale features for a batch of RGB images.

        Args:
            image_rgb: float tensor of shape (B, 3, H, W); assumed to be in
                [0, 1] since ImageNet normalization is applied — confirm
                against callers.

        Returns:
            Tensor of shape (B, output_dim): features averaged over scales.
        """
        img_normed = self._resnet_normalize_image(image_rgb)
        return self._compute_multiscale_features(img_normed)

    def _resnet_normalize_image(self, img: torch.Tensor) -> torch.Tensor:
        # Standard ImageNet channel-wise normalization.
        return (img - self._resnet_mean) / self._resnet_std

    def _compute_multiscale_features(
        self, img_normed: torch.Tensor
    ) -> torch.Tensor:
        """Run the backbone at every scale and average the feature vectors.

        Raises:
            ValueError: if ``self.scale_factors`` is empty.
        """
        if not self.scale_factors:
            raise ValueError(
                f"Wrong format of self.scale_factors: {self.scale_factors}"
            )

        multiscale_features = None
        for scale_factor in self.scale_factors:
            # Skip the (no-op) resize at the native scale.
            inp = (
                img_normed
                if scale_factor == 1
                else self._resize_image(img_normed, scale_factor)
            )
            feats = self._net(inp)
            multiscale_features = (
                feats
                if multiscale_features is None
                else multiscale_features + feats
            )

        return multiscale_features / len(self.scale_factors)

    @staticmethod
    def _resize_image(image: torch.Tensor, scale_factor: float) -> torch.Tensor:
        """Bilinearly rescale a (B, C, H, W) image by ``scale_factor``."""
        return nn.functional.interpolate(
            image,
            scale_factor=scale_factor,
            mode="bilinear",
            align_corners=False,
        )