# Copyright (c) OpenMMLab. All rights reserved.
from typing import Callable, List, Optional, Union

import mmengine.dist as dist
import torch
import torch.nn as nn
from mmengine.runner import Runner
from torch.utils.data import DataLoader

from mmpretrain.registry import MODELS
from mmpretrain.structures import DataSample
from mmpretrain.utils import track_on_main_process
from .base import BaseRetriever


@MODELS.register_module()
class ImageToImageRetriever(BaseRetriever):
    """Image To Image Retriever for supervised retrieval task.

    Args:
        image_encoder (Union[dict, List[dict]]): Encoder for extracting
            features.
        prototype (Union[DataLoader, dict, str, torch.Tensor]): Database to be
            retrieved. The following four types are supported.

            - DataLoader: The original dataloader serves as the prototype.
            - dict: The configuration to construct Dataloader.
            - str: The path of the saved vector.
            - torch.Tensor: The saved feature tensor, with shape
              ``(num_prototypes, dim)``.

        head (dict, optional): The head module to calculate loss from
            processed features. See :mod:`mmpretrain.models.heads`. Notice
            that if the head is not set, `loss` method cannot be used.
            Defaults to None.
        similarity_fn (Union[str, Callable]): The way that the similarity
            is calculated. If `similarity_fn` is callable, it is used
            directly as the measure function. If it is a string, the
            corresponding built-in method is used; only
            "cosine_similarity" is currently supported. The larger the
            calculated value, the greater the similarity. Defaults to
            "cosine_similarity".
        train_cfg (dict, optional): The training setting. The acceptable
            fields are:

            - augments (List[dict]): The batch augmentation methods to use.
              More details can be found in
              :mod:`mmpretrain.models.utils.batch_augments`.

            Defaults to None.
        data_preprocessor (dict, optional): The config for preprocessing input
            data. If None or no specified type, it will use
            "ClsDataPreprocessor" as type. See :class:`ClsDataPreprocessor` for
            more details. Defaults to None.
        topk (int): Return the topk of the retrieval result. `-1` means
            return all. Defaults to -1.
        init_cfg (dict, optional): the config to control the initialization.
            Defaults to None.
    """

    def __init__(self,
                 image_encoder: Union[dict, List[dict]],
                 prototype: Union[DataLoader, dict, str, torch.Tensor],
                 head: Optional[dict] = None,
                 pretrained: Optional[str] = None,
                 similarity_fn: Union[str, Callable] = 'cosine_similarity',
                 train_cfg: Optional[dict] = None,
                 data_preprocessor: Optional[dict] = None,
                 topk: int = -1,
                 init_cfg: Optional[dict] = None):

        if data_preprocessor is None:
            data_preprocessor = {}
        # The build process is in MMEngine, so we need to add scope here.
        data_preprocessor.setdefault('type', 'mmpretrain.ClsDataPreprocessor')

        if train_cfg is not None and 'augments' in train_cfg:
            # Set batch augmentations by `train_cfg`
            data_preprocessor['batch_augments'] = train_cfg

        super(ImageToImageRetriever, self).__init__(
            init_cfg=init_cfg, data_preprocessor=data_preprocessor)

        if not isinstance(image_encoder, nn.Module):
            image_encoder = MODELS.build(image_encoder)
        if head is not None and not isinstance(head, nn.Module):
            head = MODELS.build(head)

        self.image_encoder = image_encoder
        self.head = head

        self.similarity = similarity_fn

        assert isinstance(prototype, (str, torch.Tensor, dict, DataLoader)), (
            'The `prototype` in  `ImageToImageRetriever` must be a path, '
            'a torch.Tensor, a dataloader or a dataloader dict format config.')
        self.prototype = prototype
        self.prototype_inited = False
        self.topk = topk

    @property
    def similarity_fn(self):
        """Returns a function that calculates the similarity."""
        # If self.similarity is callable, return it directly
        if isinstance(self.similarity, Callable):
            return self.similarity

        if self.similarity == 'cosine_similarity':
            # a is a tensor with shape (N, C)
            # b is a tensor with shape (M, C)
            # "cosine_similarity" will get the matrix of similarity
            # with shape (N, M).
            # The higher the score, the more similar the pair is.
            return lambda a, b: torch.cosine_similarity(
                a.unsqueeze(1), b.unsqueeze(0), dim=-1)
        else:
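            # Use `self.similarity` here: reading `self.similarity_fn` inside
            # this property would recurse until a RecursionError is raised.
            raise RuntimeError(f'Invalid function "{self.similarity}".')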

    def forward(self,
                inputs: torch.Tensor,
                data_samples: Optional[List[DataSample]] = None,
                mode: str = 'tensor'):
        """The unified entry for a forward process in both training and test.

        The method should accept three modes: "tensor", "predict" and "loss":

        - "tensor": Forward the whole network and return tensor without any
          post-processing, same as a common nn.Module.
        - "predict": Forward and return the predictions, which are fully
          processed to a list of :obj:`DataSample`.
        - "loss": Forward and return a dict of losses according to the given
          inputs and data samples.

        Note that this method handles neither back propagation nor
        optimizer updating, which are done in :meth:`train_step`.

        Args:
            inputs (torch.Tensor, tuple): The input tensor with shape
                (N, C, ...) in general.
            data_samples (List[DataSample], optional): The annotation
                data of every sample. It's required if ``mode="loss"``.
                Defaults to None.
            mode (str): Return what kind of value. Defaults to 'tensor'.

        Returns:
            The return type depends on ``mode``.

            - If ``mode="tensor"``, return a tensor.
            - If ``mode="predict"``, return a list of
              :obj:`mmpretrain.structures.DataSample`.
            - If ``mode="loss"``, return a dict of tensors.
        """
        if mode == 'tensor':
            return self.extract_feat(inputs)
        elif mode == 'loss':
            return self.loss(inputs, data_samples)
        elif mode == 'predict':
            return self.predict(inputs, data_samples)
        else:
            raise RuntimeError(f'Invalid mode "{mode}".')

    def extract_feat(self, inputs):
        """Extract features from the input tensor with shape (N, C, ...).

        Args:
            inputs (Tensor): A batch of inputs. The shape of it should be
                ``(num_samples, num_channels, *img_shape)``.
        Returns:
            Tensor: The output of encoder.
        """

        feat = self.image_encoder(inputs)
        return feat

    def loss(self, inputs: torch.Tensor,
             data_samples: List[DataSample]) -> dict:
        """Calculate losses from a batch of inputs and data samples.

        Args:
            inputs (torch.Tensor): The input tensor with shape
                (N, C, ...) in general.
            data_samples (List[DataSample]): The annotation data of
                every sample.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        feats = self.extract_feat(inputs)
        return self.head.loss(feats, data_samples)

    def matching(self, inputs: torch.Tensor):
        """Compare the prototype and calculate the similarity.

        Args:
            inputs (torch.Tensor): The input tensor with shape (N, C).
        Returns:
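            dict: A dictionary with the similarity matrix (``score``), the
                prototype indices sorted by descending similarity
                (``pred_label``) and the sorted similarities
                (``pred_score``).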
        """
        sim = self.similarity_fn(inputs, self.prototype_vecs)
        sorted_sim, indices = torch.sort(sim, descending=True, dim=-1)
        predictions = dict(
            score=sim, pred_label=indices, pred_score=sorted_sim)
        return predictions

    def predict(self,
                inputs: torch.Tensor,
                data_samples: Optional[List[DataSample]] = None,
                **kwargs) -> List[DataSample]:
        """Predict retrieval results for a batch of inputs.

        Args:
            inputs (torch.Tensor): The input tensor with shape
                (N, C, ...) in general.
            data_samples (List[DataSample], optional): The annotation
                data of every sample. Defaults to None.
            **kwargs: Other keyword arguments. Currently unused by this
                method.
        Returns:
            List[DataSample]: the raw data_samples with
                the predicted results
        """
        if not self.prototype_inited:
            self.prepare_prototype()

        feats = self.extract_feat(inputs)
        if isinstance(feats, tuple):
            feats = feats[-1]

        # Matching of similarity
        result = self.matching(feats)
        return self._get_predictions(result, data_samples)

    def _get_predictions(self, result, data_samples):
        """Post-process the output of retriever."""
        pred_scores = result['score']
        pred_labels = result['pred_label']
        if self.topk != -1:
            topk = min(self.topk, pred_scores.size()[-1])
            pred_labels = pred_labels[:, :topk]

        if data_samples is not None:
            for data_sample, score, label in zip(data_samples, pred_scores,
                                                 pred_labels):
                data_sample.set_pred_score(score).set_pred_label(label)
        else:
            data_samples = []
            for score, label in zip(pred_scores, pred_labels):
                data_samples.append(
                    DataSample().set_pred_score(score).set_pred_label(label))
        return data_samples

    def _get_prototype_vecs_from_dataloader(self, data_loader):
        """Get prototype_vecs from a dataloader."""
        self.eval()
        num = len(data_loader.dataset)

        prototype_vecs = None
        for data_batch in track_on_main_process(data_loader,
                                                'Prepare prototype'):
            data = self.data_preprocessor(data_batch, False)
            feat = self(**data)
            if isinstance(feat, tuple):
                feat = feat[-1]

            if prototype_vecs is None:
                dim = feat.shape[-1]
                prototype_vecs = torch.zeros(num, dim)
            for i, data_sample in enumerate(data_batch['data_samples']):
                sample_idx = data_sample.get('sample_idx')
                prototype_vecs[sample_idx] = feat[i]

        assert prototype_vecs is not None
        dist.all_reduce(prototype_vecs)
        return prototype_vecs

    def _get_prototype_vecs_from_path(self, proto_path):
        """Get prototype_vecs from a prototype path."""
        data = [None]
        if dist.is_main_process():
            data[0] = torch.load(proto_path)
        dist.broadcast_object_list(data, src=0)
        prototype_vecs = data[0]
        assert prototype_vecs is not None
        return prototype_vecs

    @torch.no_grad()
    def prepare_prototype(self):
        """Prepare the prototype vectors used for retrieval.

        This is called before the first prediction. How the vectors are
        obtained depends on the type of ``self.prototype``:

        - torch.Tensor: The tensor is used directly as the prototype
          vectors.
        - str: The path of the saved features, which are loaded as the
          prototype vectors.
        - DataLoader or config dict: The feature vectors are extracted
          from the samples of the dataloader.
        """
        device = next(self.image_encoder.parameters()).device
        if isinstance(self.prototype, torch.Tensor):
            prototype_vecs = self.prototype
        elif isinstance(self.prototype, str):
            prototype_vecs = self._get_prototype_vecs_from_path(self.prototype)
        elif isinstance(self.prototype, (dict, DataLoader)):
            loader = Runner.build_dataloader(self.prototype)
            prototype_vecs = self._get_prototype_vecs_from_dataloader(loader)

        self.register_buffer(
            'prototype_vecs', prototype_vecs.to(device), persistent=False)
        self.prototype_inited = True

    def dump_prototype(self, path):
        """Save the features extracted from the prototype to specific path.

        Args:
            path (str): Path to save feature.
        """
        if not self.prototype_inited:
            self.prepare_prototype()
        torch.save(self.prototype_vecs, path)