# Copyright (c) OpenMMLab. All rights reserved.
from typing import Callable, List, Optional, Union

import mmengine.dist as dist
import torch
import torch.nn as nn
from mmengine.runner import Runner
from torch.utils.data import DataLoader

from mmpretrain.registry import MODELS
from mmpretrain.structures import DataSample
from mmpretrain.utils import track_on_main_process
from .base import BaseRetriever
class ImageToImageRetriever(BaseRetriever):
    """Image To Image Retriever for supervised retrieval task.

    Args:
        image_encoder (Union[dict, List[dict]]): Encoder for extracting
            features.
        prototype (Union[DataLoader, dict, str, torch.Tensor]): Database to be
            retrieved. The following four types are supported.

            - DataLoader: The original dataloader serves as the prototype.
            - dict: The configuration to construct Dataloader.
            - str: The path of the saved vector.
            - torch.Tensor: The saved tensor whose dimension should be dim.

        head (dict, optional): The head module to calculate loss from
            processed features. See :mod:`mmpretrain.models.heads`. Notice
            that if the head is not set, `loss` method cannot be used.
            Defaults to None.
        pretrained (str, optional): Checkpoint path to initialize the model
            from. If set, it overrides ``init_cfg`` with a ``Pretrained``
            init config. Defaults to None.
        similarity_fn (Union[str, Callable]): The way that the similarity
            is calculated. If `similarity` is callable, it is used directly
            as the measure function. If it is a string, the appropriate
            method will be used. The larger the calculated value, the
            greater the similarity. Defaults to "cosine_similarity".
        train_cfg (dict, optional): The training setting. The acceptable
            fields are:

            - augments (List[dict]): The batch augmentation methods to use.
              More details can be found in
              :mod:`mmpretrain.model.utils.augment`.

            Defaults to None.
        data_preprocessor (dict, optional): The config for preprocessing input
            data. If None or no specified type, it will use
            "ClsDataPreprocessor" as type. See :class:`ClsDataPreprocessor` for
            more details. Defaults to None.
        topk (int): Return the topk of the retrieval result. `-1` means
            return all. Defaults to -1.
        init_cfg (dict, optional): the config to control the initialization.
            Defaults to None.
    """

    def __init__(self,
                 image_encoder: Union[dict, List[dict]],
                 prototype: Union[DataLoader, dict, str, torch.Tensor],
                 head: Optional[dict] = None,
                 pretrained: Optional[str] = None,
                 similarity_fn: Union[str, Callable] = 'cosine_similarity',
                 train_cfg: Optional[dict] = None,
                 data_preprocessor: Optional[dict] = None,
                 topk: int = -1,
                 init_cfg: Optional[dict] = None):
        if data_preprocessor is None:
            data_preprocessor = {}
        # The build process is in MMEngine, so we need to add scope here.
        data_preprocessor.setdefault('type', 'mmpretrain.ClsDataPreprocessor')

        if train_cfg is not None and 'augments' in train_cfg:
            # Set batch augmentations by `train_cfg`
            data_preprocessor['batch_augments'] = train_cfg

        if pretrained is not None:
            # `pretrained` is a shortcut for a `Pretrained` init config;
            # previously this argument was accepted but silently ignored.
            init_cfg = dict(type='Pretrained', checkpoint=pretrained)

        super(ImageToImageRetriever, self).__init__(
            init_cfg=init_cfg, data_preprocessor=data_preprocessor)

        if not isinstance(image_encoder, nn.Module):
            image_encoder = MODELS.build(image_encoder)
        if head is not None and not isinstance(head, nn.Module):
            head = MODELS.build(head)
        self.image_encoder = image_encoder
        self.head = head

        # The raw setting (string name or callable); resolved lazily by the
        # `similarity_fn` property.
        self.similarity = similarity_fn

        assert isinstance(prototype, (str, torch.Tensor, dict, DataLoader)), (
            'The `prototype` in `ImageToImageRetriever` must be a path, '
            'a torch.Tensor, a dataloader or a dataloader dict format config.')
        self.prototype = prototype
        # Prototype vectors are built on first use (see `prepare_prototype`).
        self.prototype_inited = False
        self.topk = topk

    @property
    def similarity_fn(self):
        """Returns a function that calculates the similarity.

        NOTE: This must be a property — ``matching`` invokes it as
        ``self.similarity_fn(a, b)``, i.e. it calls the *returned* function.
        Without ``@property`` that call raises a ``TypeError``.
        """
        # If self.similarity is callable, return it directly.
        if isinstance(self.similarity, Callable):
            return self.similarity

        if self.similarity == 'cosine_similarity':
            # a is a tensor with shape (N, C)
            # b is a tensor with shape (M, C)
            # "cosine_similarity" will get the matrix of similarity
            # with shape (N, M).
            # The higher the score is, the more similar is
            return lambda a, b: torch.cosine_similarity(
                a.unsqueeze(1), b.unsqueeze(0), dim=-1)
        else:
            # Report the configured value (`self.similarity`), not the
            # property itself, which would recurse infinitely.
            raise RuntimeError(f'Invalid function "{self.similarity}".')

    def forward(self,
                inputs: torch.Tensor,
                data_samples: Optional[List[DataSample]] = None,
                mode: str = 'tensor'):
        """The unified entry for a forward process in both training and test.

        The method should accept three modes: "tensor", "predict" and "loss":

        - "tensor": Forward the whole network and return tensor without any
          post-processing, same as a common nn.Module.
        - "predict": Forward and return the predictions, which are fully
          processed to a list of :obj:`DataSample`.
        - "loss": Forward and return a dict of losses according to the given
          inputs and data samples.

        Note that this method doesn't handle neither back propagation nor
        optimizer updating, which are done in the :meth:`train_step`.

        Args:
            inputs (torch.Tensor, tuple): The input tensor with shape
                (N, C, ...) in general.
            data_samples (List[DataSample], optional): The annotation
                data of every samples. It's required if ``mode="loss"``.
                Defaults to None.
            mode (str): Return what kind of value. Defaults to 'tensor'.

        Returns:
            The return type depends on ``mode``.

            - If ``mode="tensor"``, return a tensor.
            - If ``mode="predict"``, return a list of
              :obj:`mmpretrain.structures.DataSample`.
            - If ``mode="loss"``, return a dict of tensor.
        """
        if mode == 'tensor':
            return self.extract_feat(inputs)
        elif mode == 'loss':
            return self.loss(inputs, data_samples)
        elif mode == 'predict':
            return self.predict(inputs, data_samples)
        else:
            raise RuntimeError(f'Invalid mode "{mode}".')

    def extract_feat(self, inputs):
        """Extract features from the input tensor with shape (N, C, ...).

        Args:
            inputs (Tensor): A batch of inputs. The shape of it should be
                ``(num_samples, num_channels, *img_shape)``.

        Returns:
            Tensor: The output of encoder.
        """
        feat = self.image_encoder(inputs)
        return feat

    def loss(self, inputs: torch.Tensor,
             data_samples: List[DataSample]) -> dict:
        """Calculate losses from a batch of inputs and data samples.

        Args:
            inputs (torch.Tensor): The input tensor with shape
                (N, C, ...) in general.
            data_samples (List[DataSample]): The annotation data of
                every samples.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        feats = self.extract_feat(inputs)
        return self.head.loss(feats, data_samples)

    def matching(self, inputs: torch.Tensor):
        """Compare the prototype and calculate the similarity.

        Args:
            inputs (torch.Tensor): The input tensor with shape (N, C).

        Returns:
            dict: a dictionary of score and prediction label based on fn.
        """
        sim = self.similarity_fn(inputs, self.prototype_vecs)
        sorted_sim, indices = torch.sort(sim, descending=True, dim=-1)
        predictions = dict(
            score=sim, pred_label=indices, pred_score=sorted_sim)
        return predictions

    def predict(self,
                inputs: tuple,
                data_samples: Optional[List[DataSample]] = None,
                **kwargs) -> List[DataSample]:
        """Predict results from the extracted features.

        Args:
            inputs (tuple): The features extracted from the backbone.
            data_samples (List[DataSample], optional): The annotation
                data of every samples. Defaults to None.
            **kwargs: Other keyword arguments accepted by the ``predict``
                method of :attr:`head`.

        Returns:
            List[DataSample]: the raw data_samples with
            the predicted results
        """
        if not self.prototype_inited:
            self.prepare_prototype()

        feats = self.extract_feat(inputs)
        if isinstance(feats, tuple):
            feats = feats[-1]

        # Matching of similarity
        result = self.matching(feats)
        return self._get_predictions(result, data_samples)

    def _get_predictions(self, result, data_samples):
        """Post-process the output of retriever."""
        pred_scores = result['score']
        pred_labels = result['pred_label']
        if self.topk != -1:
            topk = min(self.topk, pred_scores.size()[-1])
            pred_labels = pred_labels[:, :topk]

        if data_samples is not None:
            for data_sample, score, label in zip(data_samples, pred_scores,
                                                 pred_labels):
                data_sample.set_pred_score(score).set_pred_label(label)
        else:
            data_samples = []
            for score, label in zip(pred_scores, pred_labels):
                data_samples.append(
                    DataSample().set_pred_score(score).set_pred_label(label))
        return data_samples

    def _get_prototype_vecs_from_dataloader(self, data_loader):
        """get prototype_vecs from dataloader."""
        self.eval()
        num = len(data_loader.dataset)

        prototype_vecs = None
        for data_batch in track_on_main_process(data_loader,
                                                'Prepare prototype'):
            data = self.data_preprocessor(data_batch, False)
            feat = self(**data)
            if isinstance(feat, tuple):
                feat = feat[-1]

            # Lazily allocate the buffer once the feature dim is known.
            if prototype_vecs is None:
                dim = feat.shape[-1]
                prototype_vecs = torch.zeros(num, dim)
            # Scatter features into their dataset positions so that every
            # rank fills a disjoint slice of the buffer.
            for i, data_sample in enumerate(data_batch['data_samples']):
                sample_idx = data_sample.get('sample_idx')
                prototype_vecs[sample_idx] = feat[i]

        assert prototype_vecs is not None
        # Sum the disjoint slices across ranks to obtain the full prototype.
        dist.all_reduce(prototype_vecs)
        return prototype_vecs

    def _get_prototype_vecs_from_path(self, proto_path):
        """get prototype_vecs from prototype path."""
        data = [None]
        # Only the main process touches the filesystem; the result is then
        # broadcast to all other ranks.
        if dist.is_main_process():
            data[0] = torch.load(proto_path)
        dist.broadcast_object_list(data, src=0)
        prototype_vecs = data[0]
        assert prototype_vecs is not None
        return prototype_vecs

    def prepare_prototype(self):
        """Used in meta testing. This function will be called before the meta
        testing. Obtain the vector based on the prototype.

        - torch.Tensor: The prototype vector is the prototype
        - str: The path of the extracted feature path, parse data structure,
          and generate the prototype feature vector set
        - Dataloader or config: Extract and save the feature vectors according
          to the dataloader
        """
        device = next(self.image_encoder.parameters()).device
        if isinstance(self.prototype, torch.Tensor):
            prototype_vecs = self.prototype
        elif isinstance(self.prototype, str):
            prototype_vecs = self._get_prototype_vecs_from_path(self.prototype)
        elif isinstance(self.prototype, (dict, DataLoader)):
            # `Runner.build_dataloader` returns the input unchanged when it
            # is already a DataLoader instance.
            loader = Runner.build_dataloader(self.prototype)
            prototype_vecs = self._get_prototype_vecs_from_dataloader(loader)

        # Non-persistent: the prototype is derived data, not a checkpointed
        # parameter of the model.
        self.register_buffer(
            'prototype_vecs', prototype_vecs.to(device), persistent=False)
        self.prototype_inited = True

    def dump_prototype(self, path):
        """Save the features extracted from the prototype to specific path.

        Args:
            path (str): Path to save feature.
        """
        if not self.prototype_inited:
            self.prepare_prototype()
        torch.save(self.prototype_vecs, path)