# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import math
from math import pi

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy import interpolate

try:
    # The rotary helpers below use plain torch ops, so einops is optional.
    # The names are still imported when available so that anything that
    # previously pulled ``rearrange`` / ``repeat`` from this module keeps working.
    from einops import rearrange, repeat  # noqa: F401
except ImportError:  # pragma: no cover - depends on the environment
    pass

__all__ = [
    "window_partition",
    "window_unpartition",
    "add_decomposed_rel_pos",
    "get_abs_pos",
    "PatchEmbed",
    "VisionRotaryEmbeddingFast",
]


def window_partition(x, window_size):
    """
    Partition into non-overlapping windows with padding if needed.

    Args:
        x (tensor): input tokens with [B, H, W, C].
        window_size (int): window size.

    Returns:
        windows: windows after partition with [B * num_windows, window_size, window_size, C].
        (Hp, Wp): padded height and width before partition
    """
    B, H, W, C = x.shape

    # Amount needed to round H and W up to the next multiple of window_size.
    pad_h = (window_size - H % window_size) % window_size
    pad_w = (window_size - W % window_size) % window_size
    if pad_h > 0 or pad_w > 0:
        # F.pad takes pairs starting from the last dim: (C, C, W, W, H, H).
        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
    Hp, Wp = H + pad_h, W + pad_w

    x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows, (Hp, Wp)


def window_unpartition(windows, window_size, pad_hw, hw):
    """
    Window unpartition into original sequences and removing padding.

    Args:
        windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
        window_size (int): window size.
        pad_hw (Tuple): padded height and width (Hp, Wp).
        hw (Tuple): original height and width (H, W) before padding.

    Returns:
        x: unpartitioned sequences with [B, H, W, C].
    """
    Hp, Wp = pad_hw
    H, W = hw
    # Recover the batch size from the number of windows per (padded) image.
    B = windows.shape[0] // (Hp * Wp // window_size // window_size)
    x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)

    if Hp > H or Wp > W:
        # Drop the rows / columns that were added as padding.
        x = x[:, :H, :W, :].contiguous()
    return x


def _resize_rel_pos_linear(rel_pos, dst_size):
    """Resize a (L, C) relative position table to (dst_size, C) with linear interpolation."""
    resized = F.interpolate(
        rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
        size=dst_size,
        mode="linear",
    )
    return resized.reshape(-1, dst_size).permute(1, 0)


def _resize_rel_pos_geometric(rel_pos, dst_size):
    """
    Resize a (L, C) relative position table to (dst_size, C) with cubic
    interpolation over geometrically spaced source coordinates.

    The source entries are placed at 0, +-1, +-(1 + q), +-(1 + q + q^2), ...
    and sampled at unit-spaced target coordinates centred on 0.

    NOTE: the values round-trip through NumPy / SciPy, so the result is a new
    float32 tensor that is not connected to ``rel_pos`` in the autograd graph.
    """
    src_size = rel_pos.shape[0]

    # q = 1.13492
    ratio = 1.0903078

    # Positive-side source coordinates; mirrored below for the negative side.
    offsets = []
    cur = 1
    for i in range(src_size // 2):
        offsets.append(cur)
        cur += ratio ** (i + 1)
    src_coords = [-d for d in reversed(offsets)] + [0] + offsets

    half = dst_size // 2.0
    # "+ 0.1" makes the upper bound inclusive.
    dst_coords = np.arange(-half, half + 0.1, 1.0)

    columns = []
    for i in range(rel_pos.shape[1]):
        z = rel_pos[:, i].view(src_size).cpu().float().numpy()
        f = interpolate.interp1d(src_coords, z, kind='cubic', fill_value="extrapolate")
        columns.append(
            torch.Tensor(f(dst_coords)).contiguous().view(-1, 1).to(rel_pos.device)
        )
    return torch.cat(columns, dim=-1)


def get_rel_pos(q_size, k_size, rel_pos):
    """
    Get relative positional embeddings according to the relative positions of
    query and key sizes.

    Args:
        q_size (int): size of query q.
        k_size (int): size of key k.
        rel_pos (Tensor): relative position embeddings (L, C).

    Returns:
        Extracted positional embeddings according to relative positions.
    """
    max_rel_dist = int(2 * max(q_size, k_size) - 1)
    # Hard-wired choice between the two resize strategies; the linear one is
    # kept for reference.
    use_log_interpolation = True

    # Interpolate rel pos if needed.
    if rel_pos.shape[0] == max_rel_dist:
        rel_pos_resized = rel_pos
    elif use_log_interpolation:
        rel_pos_resized = _resize_rel_pos_geometric(rel_pos, max_rel_dist)
    else:
        rel_pos_resized = _resize_rel_pos_linear(rel_pos, max_rel_dist)

    # Scale the coords with short length if shapes for q and k are different.
    q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
    k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
    relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)

    return rel_pos_resized[relative_coords.long()]


def add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size):
    """
    Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
    https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py   # noqa B950

    Args:
        attn (Tensor): attention map.
        q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
        rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
        rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
        q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
        k_size (Tuple): spatial sequence size of key k with (k_h, k_w).

    Returns:
        attn (Tensor): attention map with added relative positional embeddings.
    """
    q_h, q_w = q_size
    k_h, k_w = k_size
    Rh = get_rel_pos(q_h, k_h, rel_pos_h)
    Rw = get_rel_pos(q_w, k_w, rel_pos_w)

    B, _, dim = q.shape
    r_q = q.reshape(B, q_h, q_w, dim)
    # Per-axis bias: dot product of each query with the embedding of its
    # relative offset to every key row (rel_h) / key column (rel_w).
    rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
    rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)

    attn = (
        attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
    ).view(B, q_h * q_w, k_h * k_w)

    return attn


def get_abs_pos(abs_pos, has_cls_token, hw):
    """
    Calculate absolute positional embeddings. If needed, resize embeddings and
    remove cls_token dimension for the original embeddings.

    Args:
        abs_pos (Tensor): absolute positional embeddings with (1, num_position, C).
        has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token.
        hw (Tuple): size of input image tokens.

    Returns:
        Absolute positional embeddings after processing with shape (1, H, W, C)
    """
    h, w = hw
    if has_cls_token:
        abs_pos = abs_pos[:, 1:]
    xy_num = abs_pos.shape[1]
    size = int(math.sqrt(xy_num))
    # The stored embeddings must form a square grid.
    assert size * size == xy_num

    if size == h and size == w:
        return abs_pos.reshape(1, h, w, -1)

    new_abs_pos = F.interpolate(
        abs_pos.reshape(1, size, size, -1).permute(0, 3, 1, 2),
        size=(h, w),
        mode="bicubic",
        align_corners=False,
    )
    return new_abs_pos.permute(0, 2, 3, 1)


class PatchEmbed(nn.Module):
    """
    Image to Patch Embedding.
    """

    def __init__(
        self, kernel_size=(16, 16), stride=(16, 16), padding=(0, 0), in_chans=3, embed_dim=768
    ):
        """
        Args:
            kernel_size (Tuple): kernel size of the projection layer.
            stride (Tuple): stride of the projection layer.
            padding (Tuple): padding size of the projection layer.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
        """
        super().__init__()

        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
        )

    def forward(self, x):
        """Project the image with the conv layer and return it channels-last."""
        x = self.proj(x)
        # B C H W -> B H W C
        x = x.permute(0, 2, 3, 1)
        return x


def broadcat(tensors, dim=-1):
    """
    Concatenate tensors along ``dim`` after broadcasting every other axis.

    Args:
        tensors (Sequence[Tensor]): tensors with the same number of dimensions.
        dim (int): axis to concatenate along; negative values count from the end.

    Returns:
        Tensor: each input expanded to the per-axis maximum size on all axes
        except ``dim``, then concatenated along ``dim``.
    """
    ranks = {len(t.shape) for t in tensors}
    assert len(ranks) == 1, 'tensors must all have the same number of dimensions'
    rank = ranks.pop()
    dim = (dim + rank) if dim < 0 else dim

    # sizes_per_axis[a] holds the size of axis ``a`` for every input tensor.
    sizes_per_axis = list(zip(*(tuple(t.shape) for t in tensors)))
    assert all(
        len(set(sizes)) <= 2 for axis, sizes in enumerate(sizes_per_axis) if axis != dim
    ), 'invalid dimensions for broadcastable concatentation'

    expanded = []
    for t in tensors:
        # Keep the tensor's own size on the concat axis; broadcast the rest.
        target = [
            t.shape[axis] if axis == dim else max(sizes)
            for axis, sizes in enumerate(sizes_per_axis)
        ]
        expanded.append(t.expand(*target))
    return torch.cat(expanded, dim=dim)


def rotate_half(x):
    """
    Rotate consecutive feature pairs by 90 degrees: (x1, x2) -> (-x2, x1).

    The last dimension is read as interleaved pairs and must have even length.
    """
    pairs = x.reshape(*x.shape[:-1], -1, 2)
    x1, x2 = pairs.unbind(dim=-1)
    return torch.stack((-x2, x1), dim=-1).flatten(-2)


def _rope_base_freqs(dim, custom_freqs, freqs_for, theta, max_freq, num_freqs):
    """Return the 1-D frequency vector selected by the rotary-embedding options."""
    # Compare against None explicitly: the truth value of a tensor with more
    # than one element is ambiguous and raises.
    if custom_freqs is not None:
        return torch.as_tensor(custom_freqs)
    if freqs_for == 'lang':
        return 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
    if freqs_for == 'pixel':
        return torch.linspace(1., max_freq / 2, dim // 2) * pi
    if freqs_for == 'constant':
        return torch.ones(num_freqs).float()
    raise ValueError(f'unknown modality {freqs_for}')


def _rope_axial_angles(freqs, pt_seq_len, ft_seq_len):
    """
    Build the (ft_seq_len, ft_seq_len, 2 * 2 * len(freqs)) rotation-angle grid.

    Positions are ``ft_seq_len`` points rescaled to the range [0, pt_seq_len).
    Each frequency is duplicated so both members of a feature pair share an
    angle; the row-axis angles are followed by the column-axis angles.
    """
    if ft_seq_len is None:
        ft_seq_len = pt_seq_len
    t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len

    angles = torch.einsum('..., f -> ... f', t, freqs)
    angles = angles.repeat_interleave(2, dim=-1)
    return broadcat((angles[:, None, :], angles[None, :, :]), dim=-1)


class VisionRotaryEmbedding(nn.Module):
    """
    2-D rotary position embedding with cos / sin tables kept on a square grid.

    Buffers ``freqs_cos`` / ``freqs_sin`` have shape
    (ft_seq_len, ft_seq_len, rot_dim).
    """

    def __init__(
        self,
        dim,
        pt_seq_len,
        ft_seq_len=None,
        custom_freqs=None,
        freqs_for='lang',
        theta=10000,
        max_freq=10,
        num_freqs=1,
    ):
        """
        Args:
            dim (int): per-axis rotary dimension used by the 'lang' and 'pixel' modes.
            pt_seq_len (int): grid side length that positions are scaled to.
            ft_seq_len (int | None): actual grid side length; defaults to ``pt_seq_len``.
            custom_freqs (Tensor | None): 1-D frequencies overriding ``freqs_for``.
            freqs_for (str): one of 'lang', 'pixel' or 'constant'.
            theta (float): base of the 'lang' frequencies.
            max_freq (float): upper bound parameter of the 'pixel' frequencies.
            num_freqs (int): number of frequencies in 'constant' mode.

        Raises:
            ValueError: if ``freqs_for`` is not a known mode and no custom
                frequencies are given.
        """
        super().__init__()
        freqs = _rope_base_freqs(dim, custom_freqs, freqs_for, theta, max_freq, num_freqs)
        freqs = _rope_axial_angles(freqs, pt_seq_len, ft_seq_len)

        self.register_buffer("freqs_cos", freqs.cos())
        self.register_buffer("freqs_sin", freqs.sin())

        print('======== shape of rope freq', self.freqs_cos.shape, '========')

    def forward(self, t, start_index=0):
        """
        Rotate the feature slice ``[start_index, start_index + rot_dim)`` of ``t``.

        Features outside that slice are passed through unchanged. ``t`` must
        broadcast against the (ft_seq_len, ft_seq_len, rot_dim) tables.
        """
        rot_dim = self.freqs_cos.shape[-1]
        end_index = start_index + rot_dim
        assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'
        t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
        t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin)
        return torch.cat((t_left, t, t_right), dim=-1)


class VisionRotaryEmbeddingFast(nn.Module):
    """
    2-D rotary position embedding with cos / sin tables flattened over the grid.

    Buffers ``freqs_cos`` / ``freqs_sin`` have shape
    (ft_seq_len * ft_seq_len, rot_dim).
    """

    def __init__(
        self,
        dim,
        pt_seq_len=16,
        ft_seq_len=None,
        custom_freqs=None,
        freqs_for='lang',
        theta=10000,
        max_freq=10,
        num_freqs=1,
    ):
        """
        Args:
            dim (int): per-axis rotary dimension used by the 'lang' and 'pixel' modes.
            pt_seq_len (int): grid side length that positions are scaled to.
            ft_seq_len (int | None): actual grid side length; defaults to ``pt_seq_len``.
            custom_freqs (Tensor | None): 1-D frequencies overriding ``freqs_for``.
            freqs_for (str): one of 'lang', 'pixel' or 'constant'.
            theta (float): base of the 'lang' frequencies.
            max_freq (float): upper bound parameter of the 'pixel' frequencies.
            num_freqs (int): number of frequencies in 'constant' mode.

        Raises:
            ValueError: if ``freqs_for`` is not a known mode and no custom
                frequencies are given.
        """
        super().__init__()
        freqs = _rope_base_freqs(dim, custom_freqs, freqs_for, theta, max_freq, num_freqs)
        freqs = _rope_axial_angles(freqs, pt_seq_len, ft_seq_len)

        # Flatten the (row, col) grid into a single token axis.
        freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
        freqs_sin = freqs.sin().view(-1, freqs.shape[-1])

        self.register_buffer("freqs_cos", freqs_cos)
        self.register_buffer("freqs_sin", freqs_sin)

        print('======== shape of rope freq', self.freqs_cos.shape, '========')

    def forward(self, t):
        """
        Apply the rotation to ``t``, whose axis 2 indexes tokens.

        When ``t`` has a different number of tokens than the table, only the
        first ``t.shape[2]`` table rows are used.
        """
        # Slicing with the full length is a no-op, so one expression covers
        # both the matching and the shorter-sequence case.
        t_len = t.shape[2]
        return t * self.freqs_cos[:t_len] + rotate_half(t) * self.freqs_sin[:t_len]