KyanChen's picture
Upload 303 files
4d0eb62
raw
history blame
12.8 kB
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Callable, List, Optional, Union
import mmengine.dist as dist
import torch
import torch.nn as nn
from mmengine.runner import Runner
from torch.utils.data import DataLoader
from mmpretrain.registry import MODELS
from mmpretrain.structures import DataSample
from mmpretrain.utils import track_on_main_process
from .base import BaseRetriever
@MODELS.register_module()
class ImageToImageRetriever(BaseRetriever):
    """Image To Image Retriever for supervised retrieval task.

    Args:
        image_encoder (Union[dict, List[dict]]): Encoder for extracting
            features.
        prototype (Union[DataLoader, dict, str, torch.Tensor]): Database to be
            retrieved. The following four types are supported.

            - DataLoader: The original dataloader serves as the prototype.
            - dict: The configuration to construct Dataloader.
            - str: The path of the saved vector.
            - torch.Tensor: The saved tensor whose dimension should be dim.

        head (dict, optional): The head module to calculate loss from
            processed features. See :mod:`mmpretrain.models.heads`. Notice
            that if the head is not set, `loss` method cannot be used.
            Defaults to None.
        pretrained (str, optional): Accepted in the signature but not used
            by this constructor. Defaults to None.
        similarity_fn (Union[str, Callable]): The way that the similarity
            is calculated. If `similarity_fn` is callable, it is used
            directly as the measure function. If it is a string, the
            appropriate method will be used. The larger the calculated
            value, the greater the similarity. Defaults to
            "cosine_similarity".
        train_cfg (dict, optional): The training setting. The acceptable
            fields are:

            - augments (List[dict]): The batch augmentation methods to use.
              More details can be found in
              :mod:`mmpretrain.model.utils.augment`.

            Defaults to None.
        data_preprocessor (dict, optional): The config for preprocessing input
            data. If None or no specified type, it will use
            "ClsDataPreprocessor" as type. See :class:`ClsDataPreprocessor` for
            more details. Defaults to None.
        topk (int): Return the topk of the retrieval result. `-1` means
            return all. Defaults to -1.
        init_cfg (dict, optional): the config to control the initialization.
            Defaults to None.
    """

    def __init__(self,
                 image_encoder: Union[dict, List[dict]],
                 prototype: Union[DataLoader, dict, str, torch.Tensor],
                 head: Optional[dict] = None,
                 pretrained: Optional[str] = None,
                 similarity_fn: Union[str, Callable] = 'cosine_similarity',
                 train_cfg: Optional[dict] = None,
                 data_preprocessor: Optional[dict] = None,
                 topk: int = -1,
                 init_cfg: Optional[dict] = None):

        if data_preprocessor is None:
            data_preprocessor = {}
        # The build process is in MMEngine, so we need to add scope here.
        data_preprocessor.setdefault('type', 'mmpretrain.ClsDataPreprocessor')

        if train_cfg is not None and 'augments' in train_cfg:
            # Set batch augmentations by `train_cfg`
            data_preprocessor['batch_augments'] = train_cfg

        super(ImageToImageRetriever, self).__init__(
            init_cfg=init_cfg, data_preprocessor=data_preprocessor)

        # Configs are built through the registry; ready-made modules are
        # used as they are.
        if not isinstance(image_encoder, nn.Module):
            image_encoder = MODELS.build(image_encoder)
        if head is not None and not isinstance(head, nn.Module):
            head = MODELS.build(head)

        self.image_encoder = image_encoder
        self.head = head

        # Stored raw; resolved to a function by the `similarity_fn` property.
        self.similarity = similarity_fn

        assert isinstance(prototype, (str, torch.Tensor, dict, DataLoader)), (
            'The `prototype` in  `ImageToImageRetriever` must be a path, '
            'a torch.Tensor, a dataloader or a dataloader dict format config.')
        self.prototype = prototype
        # The prototype vectors are built lazily by `prepare_prototype`.
        self.prototype_inited = False
        self.topk = topk

    @property
    def similarity_fn(self):
        """Returns a function that calculates the similarity.

        Raises:
            RuntimeError: If ``self.similarity`` is neither callable nor a
                supported method name.
        """
        # If self.similarity is callable, return it directly
        if callable(self.similarity):
            return self.similarity

        if self.similarity == 'cosine_similarity':
            # a is a tensor with shape (N, C)
            # b is a tensor with shape (M, C)
            # "cosine_similarity" will get the matrix of similarity
            # with shape (N, M).
            # The higher the score is, the more similar is
            return lambda a, b: torch.cosine_similarity(
                a.unsqueeze(1), b.unsqueeze(0), dim=-1)
        else:
            # Report the stored value. Formatting `self.similarity_fn` here
            # would re-enter this property and recurse without end.
            raise RuntimeError(f'Invalid function "{self.similarity}".')

    def forward(self,
                inputs: torch.Tensor,
                data_samples: Optional[List[DataSample]] = None,
                mode: str = 'tensor'):
        """The unified entry for a forward process in both training and test.

        The method should accept three modes: "tensor", "predict" and "loss":

        - "tensor": Forward the whole network and return tensor without any
          post-processing, same as a common nn.Module.
        - "predict": Forward and return the predictions, which are fully
          processed to a list of :obj:`DataSample`.
        - "loss": Forward and return a dict of losses according to the given
          inputs and data samples.

        Note that this method handles neither back propagation nor
        optimizer updating, which are done in the :meth:`train_step`.

        Args:
            inputs (torch.Tensor, tuple): The input tensor with shape
                (N, C, ...) in general.
            data_samples (List[DataSample], optional): The annotation
                data of every samples. It's required if ``mode="loss"``.
                Defaults to None.
            mode (str): Return what kind of value. Defaults to 'tensor'.

        Returns:
            The return type depends on ``mode``.

            - If ``mode="tensor"``, return a tensor.
            - If ``mode="predict"``, return a list of
              :obj:`mmpretrain.structures.DataSample`.
            - If ``mode="loss"``, return a dict of tensor.

        Raises:
            RuntimeError: If ``mode`` is not one of the three modes above.
        """
        if mode == 'tensor':
            return self.extract_feat(inputs)
        elif mode == 'loss':
            return self.loss(inputs, data_samples)
        elif mode == 'predict':
            return self.predict(inputs, data_samples)
        else:
            raise RuntimeError(f'Invalid mode "{mode}".')

    def extract_feat(self, inputs):
        """Extract features from the input tensor with shape (N, C, ...).

        Args:
            inputs (Tensor): A batch of inputs. The shape of it should be
                ``(num_samples, num_channels, *img_shape)``.

        Returns:
            Tensor: The output of encoder.
        """
        feat = self.image_encoder(inputs)
        return feat

    def loss(self, inputs: torch.Tensor,
             data_samples: List[DataSample]) -> dict:
        """Calculate losses from a batch of inputs and data samples.

        Args:
            inputs (torch.Tensor): The input tensor with shape
                (N, C, ...) in general.
            data_samples (List[DataSample]): The annotation data of
                every samples.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        feats = self.extract_feat(inputs)
        # Requires `head` to have been set at construction time.
        return self.head.loss(feats, data_samples)

    def matching(self, inputs: torch.Tensor):
        """Compare the prototype and calculate the similarity.

        Args:
            inputs (torch.Tensor): The input tensor with shape (N, C).

        Returns:
            dict: A dictionary with three entries:

            - ``score``: the similarity returned by ``similarity_fn``.
            - ``pred_label``: the indices that sort ``score`` in descending
              order along the last dimension.
            - ``pred_score``: ``score`` sorted in that same order.
        """
        sim = self.similarity_fn(inputs, self.prototype_vecs)
        sorted_sim, indices = torch.sort(sim, descending=True, dim=-1)
        predictions = dict(
            score=sim, pred_label=indices, pred_score=sorted_sim)
        return predictions

    def predict(self,
                inputs: torch.Tensor,
                data_samples: Optional[List[DataSample]] = None,
                **kwargs) -> List[DataSample]:
        """Predict retrieval results for a batch of inputs.

        Args:
            inputs (torch.Tensor): The batch of inputs, which is passed to
                :meth:`extract_feat`.
            data_samples (List[DataSample], optional): The annotation
                data of every samples. Defaults to None.
            **kwargs: Additional keyword arguments; not used by this method.

        Returns:
            List[DataSample]: the raw data_samples with
            the predicted results
        """
        # Build the prototype vectors on first use.
        if not self.prototype_inited:
            self.prepare_prototype()

        feats = self.extract_feat(inputs)
        # When the encoder returns a tuple, only its last element is matched.
        if isinstance(feats, tuple):
            feats = feats[-1]

        # Matching of similarity
        result = self.matching(feats)
        return self._get_predictions(result, data_samples)

    def _get_predictions(self, result, data_samples):
        """Post-process the output of retriever.

        Args:
            result (dict): The output of :meth:`matching`.
            data_samples (List[DataSample], optional): Samples to attach the
                predictions to. If None, new ``DataSample`` objects are
                created.

        Returns:
            List[DataSample]: Samples carrying ``pred_score`` and
            ``pred_label``.
        """
        pred_scores = result['score']
        pred_labels = result['pred_label']
        if self.topk != -1:
            # Never ask for more labels than there are scores per sample.
            topk = min(self.topk, pred_scores.size()[-1])
            pred_labels = pred_labels[:, :topk]

        if data_samples is not None:
            for data_sample, score, label in zip(data_samples, pred_scores,
                                                 pred_labels):
                data_sample.set_pred_score(score).set_pred_label(label)
        else:
            data_samples = []
            for score, label in zip(pred_scores, pred_labels):
                data_samples.append(
                    DataSample().set_pred_score(score).set_pred_label(label))
        return data_samples

    def _get_prototype_vecs_from_dataloader(self, data_loader):
        """get prototype_vecs from dataloader.

        Args:
            data_loader (DataLoader): Loader over the prototype dataset. Each
                data sample must provide a ``sample_idx``.

        Returns:
            torch.Tensor: A tensor with one row per dataset item, where each
            extracted feature is stored at the row given by its
            ``sample_idx``.
        """
        # Note: this switches the model to eval mode and does not restore it.
        self.eval()
        num = len(data_loader.dataset)

        prototype_vecs = None
        for data_batch in track_on_main_process(data_loader,
                                                'Prepare prototype'):
            data = self.data_preprocessor(data_batch, False)
            feat = self(**data)
            if isinstance(feat, tuple):
                feat = feat[-1]

            # The feature dimension is only known after the first batch.
            if prototype_vecs is None:
                dim = feat.shape[-1]
                prototype_vecs = torch.zeros(num, dim)
            for i, data_sample in enumerate(data_batch['data_samples']):
                sample_idx = data_sample.get('sample_idx')
                prototype_vecs[sample_idx] = feat[i]

        assert prototype_vecs is not None
        # Rows not filled by this process are still zero at this point; the
        # all-reduce combines the tensors of all processes.
        dist.all_reduce(prototype_vecs)
        return prototype_vecs

    def _get_prototype_vecs_from_path(self, proto_path):
        """get prototype_vecs from prototype path.

        The file is read on the main process only and then broadcast to the
        other processes.

        Args:
            proto_path (str): Path of the saved prototype vectors.

        Returns:
            The object stored in ``proto_path``.
        """
        data = [None]
        if dist.is_main_process():
            # Load onto CPU so that a file saved from a GPU can be read on
            # any host; `prepare_prototype` moves the result to the model
            # device afterwards.
            # NOTE: `torch.load` unpickles the file, so only load prototype
            # files from trusted sources.
            data[0] = torch.load(proto_path, map_location='cpu')
        dist.broadcast_object_list(data, src=0)
        prototype_vecs = data[0]
        assert prototype_vecs is not None
        return prototype_vecs

    @torch.no_grad()
    def prepare_prototype(self):
        """Used in meta testing. This function will be called before the meta
        testing. Obtain the vector based on the prototype.

        - torch.Tensor: The prototype vector is the prototype
        - str: The path of the extracted feature path, parse data structure,
          and generate the prototype feature vector set
        - Dataloader or config: Extract and save the feature vectors according
          to the dataloader

        The result is registered as the non-persistent buffer
        ``prototype_vecs`` on the device of the image encoder's parameters.
        """
        device = next(self.image_encoder.parameters()).device
        if isinstance(self.prototype, torch.Tensor):
            prototype_vecs = self.prototype
        elif isinstance(self.prototype, str):
            prototype_vecs = self._get_prototype_vecs_from_path(self.prototype)
        elif isinstance(self.prototype, (dict, DataLoader)):
            loader = Runner.build_dataloader(self.prototype)
            prototype_vecs = self._get_prototype_vecs_from_dataloader(loader)

        self.register_buffer(
            'prototype_vecs', prototype_vecs.to(device), persistent=False)
        self.prototype_inited = True

    def dump_prototype(self, path):
        """Save the features extracted from the prototype to specific path.

        Args:
            path (str): Path to save feature.
        """
        # Build the prototype vectors first if that has not happened yet.
        if not self.prototype_inited:
            self.prepare_prototype()
        torch.save(self.prototype_vecs, path)