from functools import partial

import numpy as np
import torch
from torch import nn
from torch.nn.init import trunc_normal_


def get_2d_sincos_pos_embed(embed_dim, image_size):
    """Build a fixed 2D sine-cosine positional embedding table.

    Args:
        embed_dim: Embedding size per position. Half of it encodes each grid
            axis, and each half is split into sine and cosine parts, so it
            must be divisible by 4.
        image_size: A single integer (square grid, Python or NumPy integer)
            or a pair ``(image_height, image_width)``.

    Returns:
        ``np.ndarray`` of shape ``(image_height, image_width, embed_dim)``.
    """
    if isinstance(image_size, (int, np.integer)):
        grid_h_size = grid_w_size = int(image_size)
    else:
        grid_h_size, grid_w_size = image_size[0], image_size[1]

    grid_h = np.arange(grid_h_size, dtype=np.float32)
    grid_w = np.arange(grid_w_size, dtype=np.float32)
    # With meshgrid's default 'xy' indexing both outputs have shape (H, W).
    # Passing w first makes grid[0] the column (w) coordinate map and grid[1]
    # the row (h) coordinate map.
    grid = np.stack(np.meshgrid(grid_w, grid_h), axis=0)  # (2, H, W)
    return get_2d_sincos_pos_embed_from_grid(embed_dim, grid)


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Encode two stacked coordinate maps with 1D sine-cosine embeddings.

    Args:
        embed_dim: Output embedding size; must be even. Each half is passed to
            the 1D encoder, which also requires an even size.
        grid: Array where ``grid[0]`` and ``grid[1]`` are ``(H, W)`` coordinate
            maps.

    Returns:
        Array of shape ``(H, W, embed_dim)``. The first half of the last axis
        encodes ``grid[0]`` and the second half encodes ``grid[1]``. (When
        called from ``get_2d_sincos_pos_embed``, ``grid[0]`` holds the w
        coordinates.)
    """
    assert embed_dim % 2 == 0
    emb_first = get_1d_sincos_pos_embed_from_grid_new(embed_dim // 2, grid[0])  # (H, W, D/2)
    emb_second = get_1d_sincos_pos_embed_from_grid_new(embed_dim // 2, grid[1])  # (H, W, D/2)
    return np.concatenate([emb_first, emb_second], axis=-1)  # (H, W, D)


def get_1d_sincos_pos_embed_from_grid_new(embed_dim, pos):
    """Sine-cosine embedding of an array of scalar positions.

    Args:
        embed_dim: Output size for each position; must be even.
        pos: Array of positions of any shape, e.g. ``(H, W)``.

    Returns:
        Array of shape ``pos.shape + (embed_dim,)``. The first half of the last
        axis is ``sin(pos * omega)`` and the second half is
        ``cos(pos * omega)``, with ``omega_i = 1 / 10000 ** (2 * i / embed_dim)``.
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float32)
    omega /= embed_dim / 2.
    omega = 1. / 10000 ** omega  # (D/2,)

    # Outer product over the trailing axis; identical to 'hw,d->hwd' for 2D input.
    out = np.einsum('...,d->...d', pos, omega)  # (..., D/2)
    emb_sin = np.sin(out)  # (..., D/2)
    emb_cos = np.cos(out)  # (..., D/2)
    return np.concatenate([emb_sin, emb_cos], axis=-1)  # (..., D)


class Resampler(nn.Module):
    """2D perceiver-resampler with one cross-attention layer.

    A fixed set of learnable queries cross-attends to a padded batch of patch
    features. The keys receive a 2D sine-cosine positional embedding, looked up
    from a cached table using each sample's ``(h, w)`` patch grid.

    Outputs:
        A tensor with the shape of ``(batch_size, num_queries, embed_dim)``.
    """

    def __init__(
            self,
            num_queries,
            embed_dim,
            num_heads,
            kv_dim=None,
            norm_layer=partial(nn.LayerNorm, eps=1e-6),
            adaptive=False,
            max_size=(70, 70),
    ):
        """
        Args:
            num_queries: Number of learnable query vectors (output tokens).
            embed_dim: Width of the queries, the attention and the output.
            num_heads: Number of heads of the ``nn.MultiheadAttention`` layer.
            kv_dim: Width of the incoming features. When given and different
                from ``embed_dim``, a bias-free linear layer projects it to
                ``embed_dim``; otherwise the features are used unchanged.
            norm_layer: Factory called with ``embed_dim`` to build the three
                normalization layers.
            adaptive: Stored on the instance; not read by this class.
            max_size: Initial ``(height, width)`` of the cached positional
                embedding table; ``forward`` grows it on demand.
        """
        super().__init__()
        self.num_queries = num_queries
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.adaptive = adaptive
        self.max_size = max_size

        self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
        trunc_normal_(self.query, std=.02)

        if kv_dim is not None and kv_dim != embed_dim:
            self.kv_proj = nn.Linear(kv_dim, embed_dim, bias=False)
        else:
            self.kv_proj = nn.Identity()

        self.attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.ln_q = norm_layer(embed_dim)
        self.ln_kv = norm_layer(embed_dim)

        self.ln_post = norm_layer(embed_dim)
        self.proj = nn.Parameter((embed_dim ** -0.5) * torch.randn(embed_dim, embed_dim))

        self._set_2d_pos_cache(self.max_size)
        self.apply(self._init_weights)

    def _set_2d_pos_cache(self, max_size, device='cpu'):
        """(Re)build the non-persistent ``pos_embed`` buffer of shape (H, W, embed_dim)."""
        pos_embed = torch.from_numpy(get_2d_sincos_pos_embed(self.embed_dim, max_size)).float().to(device)
        self.register_buffer("pos_embed", pos_embed, persistent=False)

    def _adjust_pos_cache(self, tgt_sizes, device):
        """Grow the positional-embedding cache if any grid in ``tgt_sizes`` exceeds it.

        The maxima are converted to Python ints so ``self.max_size`` never holds
        tensors. NumPy cannot consume accelerator tensors when the table is
        rebuilt.
        """
        max_h = int(tgt_sizes[:, 0].max())
        max_w = int(tgt_sizes[:, 1].max())
        if max_h > self.max_size[0] or max_w > self.max_size[1]:
            self.max_size = [max(max_h, self.max_size[0]), max(max_w, self.max_size[1])]
            self._set_2d_pos_cache(self.max_size, device)

    def _init_weights(self, m):
        """Initializer passed to ``self.apply`` for Linear and LayerNorm submodules."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            # Affine parameters can be absent (elementwise_affine=False / bias=False).
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    def forward(self, x, tgt_sizes=None):
        """Resample a padded batch of patch features into ``num_queries`` tokens.

        Args:
            x: Features of shape ``(batch, seq_len, kv_dim)``. Sample ``i`` has
                its ``h_i * w_i`` real patches first; the remaining rows are
                padding and are masked out of attention. ``seq_len`` must equal
                the largest ``h_i * w_i`` in the batch.
            tgt_sizes: Integer tensor of shape ``(batch, 2)`` holding each
                sample's patch grid as ``(h, w)``. Required despite the
                ``None`` default.

        Returns:
            Tensor of shape ``(batch, num_queries, embed_dim)``.

        Raises:
            ValueError: If ``tgt_sizes`` is missing or its batch size differs
                from that of ``x``.
        """
        if tgt_sizes is None:
            raise ValueError("Resampler.forward requires tgt_sizes, a (batch, 2) tensor of (h, w) patch grids")
        if x.shape[0] != tgt_sizes.shape[0]:
            raise ValueError(
                f"batch size mismatch: x has {x.shape[0]} samples, tgt_sizes has {tgt_sizes.shape[0]}")

        bs = x.shape[0]
        device = x.device
        dtype = x.dtype

        self._adjust_pos_cache(tgt_sizes, device=device)

        patch_len = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).to(device)
        max_patch_len = int(patch_len.max())
        # True marks padding (positions at or beyond each sample's patch
        # count); nn.MultiheadAttention ignores those keys.
        key_padding_mask = (
            torch.arange(max_patch_len, device=device).unsqueeze(0)
            >= patch_len.unsqueeze(1)
        )

        pos_embed = [
            self.pos_embed[:h, :w, :].reshape(h * w, -1).to(dtype)  # (h*w, D)
            for h, w in tgt_sizes.tolist()
        ]
        pos_embed = torch.nn.utils.rnn.pad_sequence(
            pos_embed, batch_first=True, padding_value=0.0).permute(1, 0, 2)  # (B, L, D) -> (L, B, D)

        x = self.kv_proj(x)  # (B, L, D)
        x = self.ln_kv(x).permute(1, 0, 2)  # (L, B, D)

        q = self.ln_q(self.query)  # (Q, D)

        out = self.attn(
            self._repeat(q, bs),  # (Q, B, D)
            x + pos_embed,  # keys: features + positional embedding, (L, B, D)
            x,  # values: features only, (L, B, D)
            key_padding_mask=key_padding_mask)[0]
        # out: (Q, B, D)
        x = out.permute(1, 0, 2)  # (B, Q, D)

        x = self.ln_post(x)
        x = x @ self.proj
        return x

    def _repeat(self, query, N: int):
        """Tile a ``(Q, D)`` query to ``(Q, N, D)`` (sequence-first batch layout)."""
        return query.unsqueeze(1).repeat(1, N, 1)