""" Sin-cos, fourier, rotary position embedding modules and functions Hacked together by / Copyright 2022 Ross Wightman """ import math from typing import List, Tuple, Optional, Union import torch from torch import nn as nn from .trace_utils import _assert def pixel_freq_bands( num_bands: int, max_freq: float = 224., linear_bands: bool = True, dtype: torch.dtype = torch.float32, device: Optional[torch.device] = None, ): if linear_bands: bands = torch.linspace(1.0, max_freq / 2, num_bands, dtype=dtype, device=device) else: bands = 2 ** torch.linspace(0, math.log(max_freq, 2) - 1, num_bands, dtype=dtype, device=device) return bands * torch.pi def freq_bands( num_bands: int, temperature: float = 10000., step: int = 2, dtype: torch.dtype = torch.float32, device: Optional[torch.device] = None, ) -> torch.Tensor: bands = 1. / (temperature ** (torch.arange(0, num_bands, step, dtype=dtype, device=device) / num_bands)) return bands def build_sincos2d_pos_embed( feat_shape: List[int], dim: int = 64, temperature: float = 10000., reverse_coord: bool = False, interleave_sin_cos: bool = False, dtype: torch.dtype = torch.float32, device: Optional[torch.device] = None ) -> torch.Tensor: """ Args: feat_shape: dim: temperature: reverse_coord: stack grid order W, H instead of H, W interleave_sin_cos: sin, cos, sin, cos stack instead of sin, sin, cos, cos dtype: device: Returns: """ assert dim % 4 == 0, 'Embed dimension must be divisible by 4 for sin-cos 2D position embedding' pos_dim = dim // 4 bands = freq_bands(pos_dim, temperature=temperature, step=1, dtype=dtype, device=device) if reverse_coord: feat_shape = feat_shape[::-1] # stack W, H instead of H, W grid = torch.stack(torch.meshgrid( [torch.arange(s, device=device, dtype=dtype) for s in feat_shape])).flatten(1).transpose(0, 1) pos2 = grid.unsqueeze(-1) * bands.unsqueeze(0) # FIXME add support for unflattened spatial dim? stack_dim = 2 if interleave_sin_cos else 1 # stack sin, cos, sin, cos instead of sin sin cos cos pos_emb = torch.stack([torch.sin(pos2), torch.cos(pos2)], dim=stack_dim).flatten(1) return pos_emb def build_fourier_pos_embed( feat_shape: List[int], bands: Optional[torch.Tensor] = None, num_bands: int = 64, max_res: int = 224, temperature: float = 10000., linear_bands: bool = False, include_grid: bool = False, in_pixels: bool = True, ref_feat_shape: Optional[List[int]] = None, dtype: torch.dtype = torch.float32, device: Optional[torch.device] = None, ) -> List[torch.Tensor]: """ Args: feat_shape: Feature shape for embedding. bands: Pre-calculated frequency bands. num_bands: Number of frequency bands (determines output dim). max_res: Maximum resolution for pixel based freq. temperature: Temperature for non-pixel freq. linear_bands: Linear band spacing for pixel based freq. include_grid: Include the spatial grid in output. in_pixels: Output in pixel freq. ref_feat_shape: Reference feature shape for resize / fine-tune. dtype: Output dtype. device: Output device. 
Returns: """ if bands is None: if in_pixels: bands = pixel_freq_bands( num_bands, float(max_res), linear_bands=linear_bands, dtype=dtype, device=device, ) else: bands = freq_bands( num_bands, temperature=temperature, step=1, dtype=dtype, device=device, ) else: if device is None: device = bands.device if dtype is None: dtype = bands.dtype if in_pixels: t = [torch.linspace(-1., 1., steps=s, device=device, dtype=dtype) for s in feat_shape] else: t = [torch.arange(s, device=device, dtype=dtype) for s in feat_shape] if ref_feat_shape is not None: # eva's scheme for resizing rope embeddings (ref shape = pretrain) t = [x / f * r for x, f, r in zip(t, feat_shape, ref_feat_shape)] grid = torch.stack(torch.meshgrid(t), dim=-1) grid = grid.unsqueeze(-1) pos = grid * bands pos_sin, pos_cos = pos.sin(), pos.cos() out = [grid, pos_sin, pos_cos] if include_grid else [pos_sin, pos_cos] return out class FourierEmbed(nn.Module): def __init__( self, max_res: int = 224, num_bands: int = 64, concat_grid=True, keep_spatial=False, ): super().__init__() self.max_res = max_res self.num_bands = num_bands self.concat_grid = concat_grid self.keep_spatial = keep_spatial self.register_buffer( 'bands', pixel_freq_bands(max_res, num_bands), persistent=False, ) def forward(self, x): B, C = x.shape[:2] feat_shape = x.shape[2:] emb = build_fourier_pos_embed( feat_shape, self.bands, include_grid=self.concat_grid, dtype=x.dtype, device=x.device, ) emb = torch.cat(emb, dim=-1) emb = emb.transpose(-1, -2).flatten(len(feat_shape)) batch_expand = (B,) + (-1,) * (x.ndim - 1) # FIXME support nD if self.keep_spatial: x = torch.cat([x, emb.unsqueeze(0).expand(batch_expand).permute(0, 3, 1, 2)], dim=1) else: x = torch.cat([x.permute(0, 2, 3, 1), emb.unsqueeze(0).expand(batch_expand)], dim=-1) x = x.reshape(B, feat_shape.numel(), -1) return x def rot(x): return torch.stack([-x[..., 1::2], x[..., ::2]], -1).reshape(x.shape) def apply_rot_embed(x: torch.Tensor, sin_emb, cos_emb): if sin_emb.ndim == 3: return x * cos_emb.unsqueeze(1).expand_as(x) + rot(x) * sin_emb.unsqueeze(1).expand_as(x) return x * cos_emb + rot(x) * sin_emb def apply_rot_embed_list(x: List[torch.Tensor], sin_emb, cos_emb): if isinstance(x, torch.Tensor): x = [x] return [t * cos_emb + rot(t) * sin_emb for t in x] def apply_rot_embed_cat(x: torch.Tensor, emb): sin_emb, cos_emb = emb.tensor_split(2, -1) if sin_emb.ndim == 3: return x * cos_emb.unsqueeze(1).expand_as(x) + rot(x) * sin_emb.unsqueeze(1).expand_as(x) return x * cos_emb + rot(x) * sin_emb def apply_keep_indices_nlc(x, pos_embed, keep_indices): pos_embed = pos_embed.unsqueeze(0).expand(x.shape[0], -1, -1) pos_embed = pos_embed.gather(1, keep_indices.unsqueeze(-1).expand(-1, -1, pos_embed.shape[-1])) return pos_embed def build_rotary_pos_embed( feat_shape: List[int], bands: Optional[torch.Tensor] = None, dim: int = 64, max_res: int = 224, temperature: float = 10000., linear_bands: bool = False, in_pixels: bool = True, ref_feat_shape: Optional[List[int]] = None, dtype: torch.dtype = torch.float32, device: Optional[torch.device] = None, ): """ Args: feat_shape: Spatial shape of the target tensor for embedding. bands: Optional pre-generated frequency bands dim: Output dimension of embedding tensor. max_res: Maximum resolution for pixel mode. temperature: Temperature (inv freq) for non-pixel mode linear_bands: Linearly (instead of log) spaced bands for pixel mode in_pixels: Pixel vs language (inv freq) mode. dtype: Output dtype. device: Output device. 
Returns: """ sin_emb, cos_emb = build_fourier_pos_embed( feat_shape, bands=bands, num_bands=dim // 4, max_res=max_res, temperature=temperature, linear_bands=linear_bands, in_pixels=in_pixels, ref_feat_shape=ref_feat_shape, device=device, dtype=dtype, ) num_spatial_dim = 1 # this would be much nicer as a .numel() call to torch.Size(), but torchscript sucks for x in feat_shape: num_spatial_dim *= x sin_emb = sin_emb.reshape(num_spatial_dim, -1).repeat_interleave(2, -1) cos_emb = cos_emb.reshape(num_spatial_dim, -1).repeat_interleave(2, -1) return sin_emb, cos_emb class RotaryEmbedding(nn.Module): """ Rotary position embedding NOTE: This is my initial attempt at impl rotary embedding for spatial use, it has not been well tested, and will likely change. It will be moved to its own file. The following impl/resources were referenced for this impl: * https://github.com/lucidrains/vit-pytorch/blob/6f3a5fcf0bca1c5ec33a35ef48d97213709df4ba/vit_pytorch/rvt.py * https://blog.eleuther.ai/rotary-embeddings/ """ def __init__( self, dim, max_res=224, temperature=10000, in_pixels=True, linear_bands: bool = False, feat_shape: Optional[List[int]] = None, ref_feat_shape: Optional[List[int]] = None, ): super().__init__() self.dim = dim self.max_res = max_res self.temperature = temperature self.in_pixels = in_pixels self.feat_shape = feat_shape self.ref_feat_shape = ref_feat_shape if feat_shape is None: # only cache bands if in_pixels: bands = pixel_freq_bands( dim // 4, float(max_res), linear_bands=linear_bands, ) else: bands = freq_bands( dim // 4, temperature=temperature, step=1, ) print(bands) self.register_buffer( 'bands', bands, persistent=False, ) self.pos_embed_sin = None self.pos_embed_cos = None else: # cache full sin/cos embeddings if shape provided up front emb_sin, emb_cos = build_rotary_pos_embed( feat_shape=feat_shape, dim=dim, max_res=max_res, linear_bands=linear_bands, in_pixels=in_pixels, ref_feat_shape=self.ref_feat_shape, ) self.bands = None self.register_buffer( 'pos_embed_sin', emb_sin, persistent=False, ) self.register_buffer( 'pos_embed_cos', emb_cos, persistent=False, ) def get_embed(self, shape: Optional[List[int]] = None): if self.bands is not None: # rebuild embeddings every call, use if target shape changes assert shape is not None return build_rotary_pos_embed( shape, self.bands, in_pixels=self.in_pixels, ) else: return self.pos_embed_sin, self.pos_embed_cos def forward(self, x): # assuming channel-first tensor where spatial dim are >= 2 sin_emb, cos_emb = self.get_embed(x.shape[2:]) return apply_rot_embed(x, sin_emb, cos_emb) class RotaryEmbeddingCat(nn.Module): """ Rotary position embedding w/ concatenatd sin & cos The following impl/resources were referenced for this impl: * https://github.com/lucidrains/vit-pytorch/blob/6f3a5fcf0bca1c5ec33a35ef48d97213709df4ba/vit_pytorch/rvt.py * https://blog.eleuther.ai/rotary-embeddings/ """ def __init__( self, dim, max_res=224, temperature=10000, in_pixels=True, linear_bands: bool = False, feat_shape: Optional[List[int]] = None, ref_feat_shape: Optional[List[int]] = None, ): super().__init__() self.dim = dim self.max_res = max_res self.temperature = temperature self.in_pixels = in_pixels self.feat_shape = feat_shape self.ref_feat_shape = ref_feat_shape if feat_shape is None: # only cache bands if in_pixels: bands = pixel_freq_bands( dim // 4, float(max_res), linear_bands=linear_bands, ) else: bands = freq_bands( dim // 4, temperature=temperature, step=1, ) print(bands) self.register_buffer( 'bands', bands, persistent=False, ) self.embed = 
class RotaryEmbeddingCat(nn.Module):
    """ Rotary position embedding w/ concatenated sin & cos

    The following impl/resources were referenced for this impl:
    * https://github.com/lucidrains/vit-pytorch/blob/6f3a5fcf0bca1c5ec33a35ef48d97213709df4ba/vit_pytorch/rvt.py
    * https://blog.eleuther.ai/rotary-embeddings/
    """

    def __init__(
            self,
            dim,
            max_res=224,
            temperature=10000,
            in_pixels=True,
            linear_bands: bool = False,
            feat_shape: Optional[List[int]] = None,
            ref_feat_shape: Optional[List[int]] = None,
    ):
        super().__init__()
        self.dim = dim
        self.max_res = max_res
        self.temperature = temperature
        self.in_pixels = in_pixels
        self.feat_shape = feat_shape
        self.ref_feat_shape = ref_feat_shape

        if feat_shape is None:
            # only cache bands
            if in_pixels:
                bands = pixel_freq_bands(
                    dim // 4,
                    float(max_res),
                    linear_bands=linear_bands,
                )
            else:
                bands = freq_bands(
                    dim // 4,
                    temperature=temperature,
                    step=1,
                )
            self.register_buffer(
                'bands',
                bands,
                persistent=False,
            )
            self.embed = None
        else:
            # cache full sin/cos embeddings if shape provided up front
            embeds = build_rotary_pos_embed(
                feat_shape=feat_shape,
                dim=dim,
                max_res=max_res,
                linear_bands=linear_bands,
                in_pixels=in_pixels,
                ref_feat_shape=self.ref_feat_shape,
            )
            self.bands = None
            self.register_buffer(
                'pos_embed',
                torch.cat(embeds, -1),
                persistent=False,
            )

    def get_embed(self, shape: Optional[List[int]] = None):
        if self.bands is not None:
            # rebuild embeddings every call, use if target shape changes
            _assert(shape is not None, 'valid shape needed')
            embeds = build_rotary_pos_embed(
                shape,
                self.bands,
                in_pixels=self.in_pixels,
            )
            return torch.cat(embeds, -1)
        else:
            return self.pos_embed

    def forward(self, x):
        # assuming channel-first tensor where spatial dim are >= 2
        pos_embed = self.get_embed(x.shape[2:])
        return apply_rot_embed_cat(x, pos_embed)
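
# Usage sketch (illustrative, not part of the original module): with feat_shape
# given up front, the concatenated sin/cos table is cached as a buffer and
# get_embed() returns it without a shape argument.
def _demo_rotary_cat_module():
    rope = RotaryEmbeddingCat(dim=32, feat_shape=[4, 4])
    emb = rope.get_embed()  # (16, 64): sin and cos halves on the last dim
    q = torch.randn(2, 16, 32)  # (batch, H * W tokens, dim)
    q_rot = apply_rot_embed_cat(q, emb)
    assert q_rot.shape == q.shape
    return q_rot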