| | |
| | import math |
| | from typing import Optional |
| |
|
| | import torch |
| | import torch.nn as nn |
| | from mmengine.model import BaseModule |
| | from torch import Tensor |
| |
|
| | from mmdet.registry import MODELS |
| | from mmdet.utils import MultiConfig, OptMultiConfig |
| |
|
| |
|
@MODELS.register_module()
class SinePositionalEncoding(BaseModule):
    """Position encoding with sine and cosine functions.

    See `End-to-End Object Detection with Transformers
    <https://arxiv.org/pdf/2005.12872>`_ for details.

    Args:
        num_feats (int): The feature dimension for each position
            along x-axis or y-axis. Note the final returned dimension
            for each position is 2 times of this value. Must be even,
            since channels are split into equal sin/cos halves.
        temperature (int, optional): The temperature used for scaling
            the position embedding. Defaults to 10000.
        normalize (bool, optional): Whether to normalize the position
            embedding. Defaults to False.
        scale (float, optional): A scale factor that scales the position
            embedding. The scale will be used only when `normalize` is True.
            Defaults to 2*pi.
        eps (float, optional): A value added to the denominator for
            numerical stability. Defaults to 1e-6.
        offset (float): Offset added to the cumulative position before
            normalization. Only used when `normalize` is True.
            Defaults to 0.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 num_feats: int,
                 temperature: int = 10000,
                 normalize: bool = False,
                 scale: float = 2 * math.pi,
                 eps: float = 1e-6,
                 offset: float = 0.,
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(init_cfg=init_cfg)
        if normalize:
            assert isinstance(scale, (float, int)), 'when normalize is set,' \
                'scale should be provided and in float or int type, ' \
                f'found {type(scale)}'
        self.num_feats = num_feats
        self.temperature = temperature
        self.normalize = normalize
        self.scale = scale
        self.eps = eps
        self.offset = offset

    def forward(self,
                mask: Optional[Tensor],
                input: Optional[Tensor] = None) -> Tensor:
        """Forward function for `SinePositionalEncoding`.

        Args:
            mask (Tensor, optional): ByteTensor mask. Non-zero values
                representing ignored positions, while zero values means
                valid positions for this image. Shape [bs, h, w]. Values
                are expected to be binary (0/1 or bool). May be None when
                `input` is given, in which case every position is valid.
            input (Tensor, optional): Input image/feature Tensor. Only its
                shape and device are used, and only when `mask` is None.
                Shape [bs, c, h, w].

        Returns:
            pos (Tensor): Returned position embedding with shape
            [bs, num_feats*2, h, w].
        """
        assert not (mask is None and input is None), \
            'At least one of `mask` and `input` must be provided.'

        if mask is not None:
            B, H, W = mask.size()
            device = mask.device
            # Work in int so the logical NOT can be written as `1 - mask`;
            # this only yields 0/1 for a binary (0/1 or bool) mask.
            mask = mask.to(torch.int)
            not_mask = 1 - mask  # 1 marks a valid position
            # 1-based running count of valid positions along h and w.
            y_embed = not_mask.cumsum(1, dtype=torch.float32)
            x_embed = not_mask.cumsum(2, dtype=torch.float32)
        else:
            # No mask: every position is valid, so the running count is
            # simply 1..W (resp. 1..H), tiled over the batch.
            B, _, H, W = input.shape
            device = input.device
            x_embed = torch.arange(
                1, W + 1, dtype=torch.float32, device=device)
            x_embed = x_embed.view(1, 1, -1).repeat(B, H, 1)
            y_embed = torch.arange(
                1, H + 1, dtype=torch.float32, device=device)
            y_embed = y_embed.view(1, -1, 1).repeat(B, 1, W)
        if self.normalize:
            # Divide by the count found at the last row / column, then
            # rescale by `scale`.
            y_embed = (y_embed + self.offset) / \
                (y_embed[:, -1:, :] + self.eps) * self.scale
            x_embed = (x_embed + self.offset) / \
                (x_embed[:, :, -1:] + self.eps) * self.scale
        dim_t = torch.arange(
            self.num_feats, dtype=torch.float32, device=device)
        # Channels 2k and 2k+1 share one frequency (one sin/cos pair).
        dim_t = self.temperature**(2 * (dim_t // 2) / self.num_feats)
        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t

        # Interleave sin (even channels) with cos (odd channels). The two
        # slices only have equal size, as `torch.stack` requires, when
        # `num_feats` is even.
        pos_x = torch.stack(
            (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()),
            dim=4).view(B, H, W, -1)
        pos_y = torch.stack(
            (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()),
            dim=4).view(B, H, W, -1)
        # [B, H, W, 2F] -> [B, 2F, H, W]; y channels come first.
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos

    def __repr__(self) -> str:
        """str: a string that describes the module"""
        repr_str = self.__class__.__name__
        repr_str += f'(num_feats={self.num_feats}, '
        repr_str += f'temperature={self.temperature}, '
        repr_str += f'normalize={self.normalize}, '
        repr_str += f'scale={self.scale}, '
        repr_str += f'eps={self.eps}, '
        repr_str += f'offset={self.offset})'
        return repr_str
| |
|
| |
|
@MODELS.register_module()
class LearnedPositionalEncoding(BaseModule):
    """Position embedding built from two learnable lookup tables.

    One ``nn.Embedding`` is indexed by column (x) and another by row (y);
    the two per-axis vectors are concatenated at every spatial location.

    Args:
        num_feats (int): Embedding width of each per-axis table. The
            returned embedding has ``2 * num_feats`` channels per position.
        row_num_embed (int, optional): Number of entries in the row table,
            which bounds the feature-map height that can be encoded.
            Defaults to 50.
        col_num_embed (int, optional): Number of entries in the column
            table, which bounds the feature-map width that can be encoded.
            Defaults to 50.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to ``dict(type='Uniform', layer='Embedding')``.
    """

    def __init__(
        self,
        num_feats: int,
        row_num_embed: int = 50,
        col_num_embed: int = 50,
        init_cfg: MultiConfig = dict(type='Uniform', layer='Embedding')
    ) -> None:
        super().__init__(init_cfg=init_cfg)
        # Configuration, kept for __repr__.
        self.num_feats = num_feats
        self.row_num_embed = row_num_embed
        self.col_num_embed = col_num_embed
        # The row table is created before the column table so parameter
        # registration (and RNG consumption at construction) is unchanged.
        self.row_embed = nn.Embedding(row_num_embed, num_feats)
        self.col_embed = nn.Embedding(col_num_embed, num_feats)

    def forward(self, mask: Tensor) -> Tensor:
        """Forward function for `LearnedPositionalEncoding`.

        Args:
            mask (Tensor): ByteTensor mask. Non-zero values representing
                ignored positions, while zero values means valid positions
                for this image. Shape [bs, h, w]. Only its batch size,
                spatial size and device are used; its values are not read.

        Returns:
            pos (Tensor): Returned position embedding with shape
            [bs, num_feats*2, h, w]. Channels ``[:num_feats]`` come from
            the column table and ``[num_feats:]`` from the row table.
        """
        height, width = mask.shape[-2:]
        device = mask.device
        col_vecs = self.col_embed(torch.arange(width, device=device))
        row_vecs = self.row_embed(torch.arange(height, device=device))
        # Tile each axis' vectors across the other axis -> [h, w, c] each.
        col_grid = col_vecs.unsqueeze(0).repeat(height, 1, 1)
        row_grid = row_vecs.unsqueeze(1).repeat(1, width, 1)
        grid = torch.cat((col_grid, row_grid), dim=-1)
        # [h, w, 2c] -> [1, 2c, h, w], then copy once per batch element.
        batch_size = mask.shape[0]
        return grid.permute(2, 0, 1).unsqueeze(0).repeat(batch_size, 1, 1, 1)

    def __repr__(self) -> str:
        """str: a string that describes the module"""
        settings = ', '.join((
            f'num_feats={self.num_feats}',
            f'row_num_embed={self.row_num_embed}',
            f'col_num_embed={self.col_num_embed}',
        ))
        return f'{self.__class__.__name__}({settings})'
| |
|
| |
|
@MODELS.register_module()
class SinePositionalEncoding3D(SinePositionalEncoding):
    """Position encoding with sine and cosine functions over [t, h, w].

    Extends :class:`SinePositionalEncoding` with an extra leading ``t``
    axis (e.g. frames). The y and x encodings (``num_feats`` channels each)
    are concatenated as in the 2D case. A ``t``-axis (z) encoding with
    ``2 * num_feats`` channels is then added to that result element-wise,
    so the channel count stays ``2 * num_feats``.

    See `End-to-End Object Detection with Transformers
    <https://arxiv.org/pdf/2005.12872>`_ for details.

    Args:
        num_feats (int): The feature dimension for each position
            along x-axis or y-axis. Note the final returned dimension
            for each position is 2 times of this value. Must be even,
            since channels are split into equal sin/cos halves.
        temperature (int, optional): The temperature used for scaling
            the position embedding. Defaults to 10000.
        normalize (bool, optional): Whether to normalize the position
            embedding. Defaults to False.
        scale (float, optional): A scale factor that scales the position
            embedding. The scale will be used only when `normalize` is True.
            Defaults to 2*pi.
        eps (float, optional): A value added to the denominator for
            numerical stability. Defaults to 1e-6.
        offset (float): Offset added to the cumulative position before
            normalization. Only used when `normalize` is True.
            Defaults to 0.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def forward(self, mask: Tensor) -> Tensor:
        """Forward function for `SinePositionalEncoding3D`.

        Args:
            mask (Tensor): ByteTensor mask. Non-zero values representing
                ignored positions, while zero values means valid positions
                for this image. Shape [bs, t, h, w]. Values are expected
                to be binary (0/1 or bool).

        Returns:
            pos (Tensor): Returned position embedding with shape
            [bs, t, num_feats*2, h, w].
        """
        assert mask.dim() == 4,\
            f'{mask.shape} should be a 4-dimensional Tensor,' \
            f' got {mask.dim()}-dimensional Tensor instead '
        # Work in int so the logical NOT can be written as `1 - mask`;
        # this only yields 0/1 for a binary (0/1 or bool) mask.
        mask = mask.to(torch.int)
        not_mask = 1 - mask  # 1 marks a valid position
        # 1-based running count of valid positions along t, h and w.
        z_embed = not_mask.cumsum(1, dtype=torch.float32)
        y_embed = not_mask.cumsum(2, dtype=torch.float32)
        x_embed = not_mask.cumsum(3, dtype=torch.float32)
        if self.normalize:
            # Divide by the count at the last index of each axis, then
            # rescale by `scale`.
            z_embed = (z_embed + self.offset) / \
                (z_embed[:, -1:, :, :] + self.eps) * self.scale
            y_embed = (y_embed + self.offset) / \
                (y_embed[:, :, -1:, :] + self.eps) * self.scale
            x_embed = (x_embed + self.offset) / \
                (x_embed[:, :, :, -1:] + self.eps) * self.scale
        dim_t = torch.arange(
            self.num_feats, dtype=torch.float32, device=mask.device)
        # Channels 2k and 2k+1 share one frequency (one sin/cos pair).
        dim_t = self.temperature**(2 * (dim_t // 2) / self.num_feats)

        # The z encoding is twice as wide (2 * num_feats channels) so it
        # matches the concatenated [y, x] encoding it is added to below.
        dim_t_z = torch.arange((self.num_feats * 2),
                               dtype=torch.float32,
                               device=mask.device)
        dim_t_z = self.temperature**(2 * (dim_t_z // 2) / (self.num_feats * 2))

        pos_x = x_embed[:, :, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, :, None] / dim_t
        pos_z = z_embed[:, :, :, :, None] / dim_t_z

        B, T, H, W = mask.size()
        # Interleave sin (even channels) with cos (odd channels).
        pos_x = torch.stack(
            (pos_x[:, :, :, :, 0::2].sin(), pos_x[:, :, :, :, 1::2].cos()),
            dim=5).view(B, T, H, W, -1)
        pos_y = torch.stack(
            (pos_y[:, :, :, :, 0::2].sin(), pos_y[:, :, :, :, 1::2].cos()),
            dim=5).view(B, T, H, W, -1)
        pos_z = torch.stack(
            (pos_z[:, :, :, :, 0::2].sin(), pos_z[:, :, :, :, 1::2].cos()),
            dim=5).view(B, T, H, W, -1)
        # [B, T, H, W, 2F] -> [B, T, 2F, H, W]
        pos = (torch.cat((pos_y, pos_x), dim=4) + pos_z).permute(0, 1, 4, 2, 3)
        return pos
| |
|