Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright (c) OpenMMLab. All rights reserved. | |
| from typing import List, Optional, Sequence, Tuple, Union | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch import Tensor | |
| from mmpose.models.utils.ops import resize | |
| from mmpose.registry import MODELS | |
@MODELS.register_module()
class FeatureMapProcessor(nn.Module):
    """A PyTorch module for selecting, concatenating, and rescaling feature
    maps.

    ``forward`` applies its steps in a fixed order: selection, optional
    concatenation, optional ReLU, and finally optional rescaling.

    Args:
        select_index (Optional[Union[int, Tuple[int, ...]]], optional): Index
            or indices of feature maps to select. Defaults to None, which
            means all feature maps are used.
        concat (bool, optional): Whether to concatenate the selected feature
            maps along dim 1. Every map is first resized to the last two
            dimensions (``shape[-2:]``) of the first selected map.
            Defaults to False.
        scale_factor (float, optional): The scaling factor to apply to the
            feature maps. Must be positive. Defaults to 1.0.
        apply_relu (bool, optional): Whether to apply ReLU on input feature
            maps. Defaults to False.
        align_corners (bool, optional): Whether to align corners when resizing
            the feature maps. Defaults to False.

    Raises:
        AssertionError: If ``scale_factor`` is not positive.
    """

    def __init__(
        self,
        select_index: Optional[Union[int, Tuple[int, ...]]] = None,
        concat: bool = False,
        scale_factor: float = 1.0,
        apply_relu: bool = False,
        align_corners: bool = False,
    ):
        super().__init__()

        # Normalize a single index to a 1-tuple so ``forward`` can always
        # iterate over ``select_index``.
        if isinstance(select_index, int):
            select_index = (select_index, )
        self.select_index = select_index
        self.concat = concat

        # Explicit check instead of ``assert`` so validation survives
        # ``python -O``. AssertionError is kept for backward compatibility
        # with callers/tests that expect it. ``not x > 0`` (rather than
        # ``x <= 0``) also rejects NaN, matching the original assert.
        if not scale_factor > 0:
            raise AssertionError(
                'the argument `scale_factor` must be positive, '
                f'but got {scale_factor}')
        self.scale_factor = scale_factor
        self.apply_relu = apply_relu
        self.align_corners = align_corners

    def forward(self, inputs: Union[Tensor, Sequence[Tensor]]
                ) -> Union[Tensor, List[Tensor]]:
        """Select, optionally concatenate, activate and rescale feature maps.

        Args:
            inputs (Tensor | Sequence[Tensor]): A single feature map, or a
                tuple/list of feature maps. Only ``tuple`` and ``list`` are
                treated as sequences; anything else is treated as a single
                tensor.

        Returns:
            Tensor | List[Tensor]: A single tensor when ``inputs`` was a
            single tensor, otherwise a list of tensors.
        """
        # Wrap a lone tensor in a list so the steps below handle both cases
        # uniformly; it is unwrapped again before returning.
        sequential_input = isinstance(inputs, (tuple, list))
        if not sequential_input:
            inputs = [inputs]

        if self.select_index is not None:
            inputs = [inputs[i] for i in self.select_index]

        if self.concat:
            inputs = self._concat(inputs)

        if self.apply_relu:
            inputs = [F.relu(x) for x in inputs]

        # Skip the resize entirely for the identity scale.
        if self.scale_factor != 1.0:
            inputs = self._rescale(inputs)

        if not sequential_input:
            inputs = inputs[0]

        return inputs

    def _concat(self, inputs: Sequence[Tensor]) -> List[Tensor]:
        """Resize all maps to the last two dims of ``inputs[0]`` and
        concatenate them along dim 1, returning a one-element list."""
        size = inputs[0].shape[-2:]
        resized_inputs = [
            resize(
                x,
                size=size,
                mode='bilinear',
                align_corners=self.align_corners) for x in inputs
        ]
        return [torch.cat(resized_inputs, dim=1)]

    def _rescale(self, inputs: Sequence[Tensor]) -> List[Tensor]:
        """Bilinearly resize every map by ``self.scale_factor``."""
        rescaled_inputs = [
            resize(
                x,
                scale_factor=self.scale_factor,
                mode='bilinear',
                align_corners=self.align_corners,
            ) for x in inputs
        ]
        return rescaled_inputs