Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
import math | |
from functools import partial | |
from typing import Optional, Sequence, Union | |
import torch | |
import torch.nn as nn | |
from mmengine.model import BaseModule | |
from mmengine.utils import digit_version | |
from ..utils import to_2tuple | |
# ``torch.meshgrid`` accepts an ``indexing`` argument since PyTorch 1.10.0,
# and calling it without one from that version on emits an extra warning
# (https://github.com/pytorch/pytorch/issues/50276). Where the argument is
# supported, pin the matrix-style 'ij' indexing explicitly; older releases
# take no ``indexing`` argument, so the plain function is used there.
if digit_version(torch.__version__) < digit_version('1.10.0'):
    torch_meshgrid = torch.meshgrid
else:
    torch_meshgrid = partial(torch.meshgrid, indexing='ij')
class ConditionalPositionEncoding(BaseModule):
    """Conditional Position Encoding (CPE).

    Implementation of `Conditional Positional Encodings for Vision
    Transformers <https://arxiv.org/abs/2102.10882>`_. The token sequence
    is reshaped back into a feature map, passed through a 3x3 grouped
    convolution and flattened into tokens again. When ``stride == 1`` the
    convolution output is added to its input as a residual.

    Args:
        in_channels (int): Number of input channels.
        embed_dims (int): Number of output channels. Default: 768.
        stride (int): Stride of the convolution. Default: 1.
        init_cfg (dict, optional): Initialization config passed to
            ``BaseModule``. Default: None.
    """

    def __init__(self, in_channels, embed_dims=768, stride=1, init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.stride = stride
        # One group per output channel (depth-wise when
        # ``in_channels == embed_dims``).
        self.proj = nn.Conv2d(
            in_channels=in_channels,
            out_channels=embed_dims,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=True,
            groups=embed_dims)

    def forward(self, x, hw_shape):
        """Apply the conditional position encoding.

        Args:
            x (torch.Tensor): Tokens of shape ``(B, N, C)`` where
                ``N == H * W``.
            hw_shape (Sequence[int]): The ``(H, W)`` of the token grid.

        Returns:
            torch.Tensor: Tokens of shape ``(B, H' * W', embed_dims)``,
            where ``H'``/``W'`` are the convolution's output size.
        """
        batch, _, channels = x.shape
        height, width = hw_shape

        # Tokens (B, N, C) -> feature map (B, C, H, W).
        feat_map = x.transpose(1, 2).view(batch, channels, height,
                                          width).contiguous()

        out = self.proj(feat_map)
        if self.stride == 1:
            # A 3x3 kernel with padding 1 and stride 1 keeps H and W, so the
            # input can be added back as a residual.
            out = out + feat_map

        # Feature map (B, C', H', W') -> tokens (B, H' * W', C').
        return out.flatten(2).transpose(1, 2)
class PositionEncodingFourier(BaseModule):
    """The Position Encoding Fourier (PEF) module.

    Adopted from `EdgeNeXt <https://arxiv.org/abs/2206.10589>`_. Builds a
    sine/cosine encoding of normalised row and column coordinates for a
    ``(B, H, W)`` grid and maps it to ``embed_dims`` channels with a 1x1
    convolution.

    Args:
        in_channels (int): Number of Fourier features per spatial axis; the
            projection therefore receives ``2 * in_channels`` channels.
            Must be even, otherwise the sin/cos interleave in ``forward``
            stacks tensors of different sizes and fails. Default: 32.
        embed_dims (int): Number of output channels. Default: 768.
        temperature (int): Base of the frequency schedule.
            Default: 10000.
        dtype (torch.dtype): Data type used to build the encoding.
            Default: torch.float32.
        init_cfg (dict): The config dict for initializing the module.
            Default: None.
    """

    def __init__(self,
                 in_channels=32,
                 embed_dims=768,
                 temperature=10000,
                 dtype=torch.float32,
                 init_cfg=None):
        super(PositionEncodingFourier, self).__init__(init_cfg=init_cfg)
        self.proj = nn.Conv2d(in_channels * 2, embed_dims, kernel_size=1)
        self.scale = 2 * math.pi
        self.in_channels = in_channels
        self.embed_dims = embed_dims
        self.dtype = dtype

        # ``torch.div`` only accepts ``rounding_mode`` from PyTorch 1.8.0.
        if digit_version(torch.__version__) < digit_version('1.8.0'):
            floor_div = torch.floor_divide
        else:
            floor_div = partial(torch.div, rounding_mode='floor')
        dim_t = torch.arange(in_channels, dtype=self.dtype)
        # Channels 2k and 2k + 1 share the frequency
        # temperature ** (2k / in_channels). Kept as a plain attribute (not a
        # buffer), so it is not in the state dict and is moved to the input
        # device inside ``forward``.
        self.dim_t = temperature**(2 * floor_div(dim_t, 2) / in_channels)

    def forward(self, bhw_shape):
        """Generate the position encoding for a ``(B, H, W)`` grid.

        Args:
            bhw_shape (Sequence[int]): Batch size, height and width.

        Returns:
            torch.Tensor: Encoding of shape ``(B, embed_dims, H, W)``.
        """
        B, H, W = bhw_shape
        # An all-False mask (no padded positions). Build it directly as a
        # bool tensor on the projection's device instead of allocating a
        # float tensor on the CPU, casting it and copying it over.
        mask = torch.zeros(
            B, H, W, dtype=torch.bool, device=self.proj.weight.device)
        not_mask = ~mask
        eps = 1e-6

        # 1-based row / column indices, normalised to roughly (0, 2 * pi].
        y_embed = not_mask.cumsum(1, dtype=self.dtype)
        x_embed = not_mask.cumsum(2, dtype=self.dtype)
        y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
        x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        dim_t = self.dim_t.to(mask.device)
        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        # Interleave sin on even channels and cos on odd channels.
        pos_x = torch.stack(
            (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()),
            dim=4).flatten(3)
        pos_y = torch.stack(
            (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()),
            dim=4).flatten(3)

        # (B, H, W, 2 * in_channels) -> (B, 2 * in_channels, H, W).
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        # NOTE(review): ``pos`` is built in ``self.dtype``; if the projection
        # weights use a different dtype (e.g. after ``.half()`` without
        # autocast) the convolution will reject the input — confirm callers
        # handle this.
        pos = self.proj(pos)
        return pos
def build_2d_sincos_position_embedding(
        patches_resolution: Union[int, Sequence[int]],
        embed_dims: int,
        temperature: float = 10000.,
        cls_token: bool = False) -> torch.Tensor:
    """Build a fixed 2D sine-cosine position embedding for image patches.

    The embedding is split into four chunks of ``embed_dims // 4`` channels:
    ``sin`` and ``cos`` of the first grid coordinate, followed by ``sin``
    and ``cos`` of the second one. The frequencies are
    ``1 / temperature ** (k / (embed_dims // 4))`` for
    ``k = 0, ..., embed_dims // 4 - 1``.

    The grid is flattened with the ``w`` axis outermost. This matches
    ``torch.meshgrid(arange(w), arange(h), indexing='ij')``, so token ``k``
    gets the coordinates ``(k // h, k % h)``. NOTE(review): if callers
    flatten patches row-major over ``(h, w)``, this only lines up when
    ``h == w``. The ordering is kept as-is for compatibility; confirm before
    relying on non-square grids.

    Args:
        patches_resolution (int | Sequence[int]): Number of patches along
            each spatial axis as ``(h, w)``, or a single int for a square
            grid.
        embed_dims (int): The dimension of the embedding vector. Must be
            divisible by 4.
        temperature (float): The base of the frequency schedule. Defaults to
            10000.
        cls_token (bool): Whether to prepend an all-zero embedding for a
            class token. Defaults to False.

    Returns:
        torch.Tensor: A float32 tensor of shape ``(1, h * w, embed_dims)``,
        or ``(1, h * w + 1, embed_dims)`` when ``cls_token`` is True.

    Raises:
        AssertionError: If ``embed_dims`` is not divisible by 4.
    """
    if isinstance(patches_resolution, int):
        patches_resolution = (patches_resolution, patches_resolution)
    h, w = patches_resolution

    # Validate before doing any work.
    assert embed_dims % 4 == 0, \
        'Embed dimension must be divisible by 4.'
    pos_dim = embed_dims // 4

    # Same values and order as an 'ij' meshgrid of (arange(w), arange(h)),
    # built by broadcasting so it does not depend on the torch version.
    grid_w = torch.arange(w, dtype=torch.float32)[:, None].expand(w, h)
    grid_h = torch.arange(h, dtype=torch.float32)[None, :].expand(w, h)
    grid_w = grid_w.flatten()
    grid_h = grid_h.flatten()

    omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
    omega = 1. / (temperature**omega)

    # Outer products coordinate x frequency: (h * w, pos_dim).
    out_w = grid_w[:, None] * omega[None, :]
    out_h = grid_h[:, None] * omega[None, :]

    pos_emb = torch.cat(
        [
            torch.sin(out_w),
            torch.cos(out_w),
            torch.sin(out_h),
            torch.cos(out_h)
        ],
        dim=1,
    )[None, :, :]

    if cls_token:
        cls_token_pe = torch.zeros([1, 1, embed_dims], dtype=torch.float32)
        pos_emb = torch.cat([cls_token_pe, pos_emb], dim=1)
    return pos_emb
class RotaryEmbeddingFast(BaseModule):
    r"""Implements 2D rotary position embedding (RoPE) for image tokens.

    The channels are split into two halves, one per spatial axis (height
    first, then width). Along each axis the angles are

    .. math::
        Pos_{cos} = \cos\left(\frac{t}{\theta^{\frac{2i}{d}}}\right) \\
        Pos_{sin} = \sin\left(\frac{t}{\theta^{\frac{2i}{d}}}\right)

    where :math:`d` is ``embed_dims // 2``, :math:`i` indexes the
    frequencies and :math:`t = j / n \cdot d` for the ``j``-th of ``n``
    positions along that axis.

    The cos/sin tables are cached as buffers for ``patch_resolution`` and
    rebuilt in :meth:`forward` whenever a different resolution is passed.

    Args:
        embed_dims (int): The feature dimension of each head, i.e. the last
            dimension of the tensor given to :meth:`forward`. Must be a
            multiple of 4; otherwise the tables cannot be built or do not
            match that dimension.
        patch_resolution (int | tuple): The resolution of the token grid,
            in format (H, W).
        theta (float): The base of the frequency schedule.
            Defaults to 10000.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 embed_dims,
                 patch_resolution,
                 theta=10000.,
                 init_cfg=None):
        super(RotaryEmbeddingFast, self).__init__(init_cfg=init_cfg)

        self.half_dim = embed_dims // 2
        # Stored as a tuple so that the comparison in ``forward`` is not
        # fooled by a list holding the same values.
        self.patch_resolution = tuple(to_2tuple(patch_resolution))
        self.theta = theta

        freqs_cos, freqs_sin = self.compute_position_embedding()
        self.register_buffer('freqs_cos', freqs_cos)
        self.register_buffer('freqs_sin', freqs_sin)

    def compute_position_embedding(self):
        """Build the cos/sin tables for ``self.patch_resolution``.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: ``(freqs_cos, freqs_sin)``,
            each of shape ``(H * W, 2 * half_dim)`` with tokens in row-major
            ``(H, W)`` order.
        """
        # Inverse frequencies theta^(-2i / half_dim).
        frequency = self.theta**(
            torch.arange(0, self.half_dim, 2).float() / self.half_dim)
        frequency = 1. / frequency

        h, w = self.patch_resolution
        # Grid indices rescaled to [0, half_dim).
        th = torch.arange(h) / h * self.half_dim
        tw = torch.arange(w) / w * self.half_dim

        # NOTE(review): ``.repeat(1, 2)`` tiles the angles as
        # [a0, a1, ..., a0, a1, ...], while ``forward`` rotates adjacent
        # channel pairs (2k, 2k + 1). Confirm this layout matches the
        # implementation the pretrained weights were trained with.
        position_h = (th[:, None] @ frequency[None, :]).repeat(1, 2)
        position_w = (tw[:, None] @ frequency[None, :]).repeat(1, 2)

        # Broadcast over the full (h, w) grid: height angles fill the first
        # half of the channels, width angles the second half.
        height = position_h[:, None, :].expand(h, w, self.half_dim)
        width = position_w[None, :, :].expand(h, w, self.half_dim)
        position = torch.cat((height, width), dim=-1)

        freqs_cos = position.cos().view(-1, position.shape[-1])
        freqs_sin = position.sin().view(-1, position.shape[-1])
        return freqs_cos, freqs_sin

    def forward(self, x, patch_resolution):
        """Rotate ``x`` by the position-dependent angles.

        Args:
            x (torch.Tensor): Tensor of shape
                ``(batch, num_heads, num_patches, dim)`` with
                ``num_patches == H * W`` and ``dim == 2 * half_dim``.
            patch_resolution (int | tuple): The (H, W) of the token grid.
                When it differs from the cached resolution, the tables are
                rebuilt on ``x.device`` and the cache is updated in place.

        Returns:
            torch.Tensor: The rotated tensor, with the same shape as ``x``.
        """
        patch_resolution = tuple(to_2tuple(patch_resolution))
        if patch_resolution != self.patch_resolution:
            self.patch_resolution = patch_resolution
            freqs_cos, freqs_sin = self.compute_position_embedding()
            # Keep the dtype of the current tables (e.g. after ``.half()``)
            # instead of silently switching them back to float32.
            self.register_buffer(
                'freqs_cos',
                freqs_cos.to(device=x.device, dtype=self.freqs_cos.dtype))
            self.register_buffer(
                'freqs_sin',
                freqs_sin.to(device=x.device, dtype=self.freqs_sin.dtype))

        batch, num_heads, num_patches, dim = x.shape
        inputs = x
        # Rotate each adjacent channel pair: (x1, x2) -> (-x2, x1).
        x = x.reshape(batch, num_heads, num_patches, -1, 2)
        x1, x2 = x.unbind(dim=-1)
        x = torch.stack((-x2, x1), dim=-1)
        x = x.reshape(batch, num_heads, num_patches, dim)
        return inputs * self.freqs_cos + x * self.freqs_sin