Spaces:
Running
Running
# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math | |
import numpy as np | |
import torch | |
from torch import nn | |
from torch.nn import functional as F | |
from torch.nn.init import normal_ | |
def get_abs_pos(abs_pos, tgt_size):
    """Resize a square table of absolute position embeddings to a new token count.

    Args:
        abs_pos: tensor of shape (L, C). Its rows are treated as a
            side x side grid, with side = floor(sqrt(L)).
        tgt_size: desired number of tokens M. It is treated as a
            side_t x side_t grid, with side_t = floor(sqrt(M)).

    Returns:
        ``abs_pos`` itself when both side lengths agree (note that this
        compares floor-sqrt sides, not raw sizes). Otherwise, a tensor of
        shape (side_t * side_t, C). It is produced by bicubic interpolation
        carried out in float32 and is cast back to ``abs_pos.dtype``.
    """
    src_side = int(math.sqrt(abs_pos.size(0)))
    tgt_side = int(math.sqrt(tgt_size))
    if src_side == tgt_side:
        return abs_pos

    orig_dtype = abs_pos.dtype
    # (L, C) -> (1, C, side, side): channels-first layout expected by F.interpolate.
    grid = abs_pos.float().reshape(1, src_side, src_side, -1).permute(0, 3, 1, 2)
    resized = F.interpolate(
        grid,
        size=(tgt_side, tgt_side),
        mode="bicubic",
        align_corners=False,
    )
    # (1, C, side_t, side_t) -> (side_t * side_t, C), restoring the caller's dtype.
    return resized.permute(0, 2, 3, 1).flatten(0, 2).to(dtype=orig_dtype)
# Adapted from MAE: https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """Build a fixed 2D sine-cosine positional embedding table for a square grid.

    Args:
        embed_dim: width of each embedding row. The helper functions assert
            that embed_dim and embed_dim // 2 are both even, so it must be a
            multiple of 4.
        grid_size: int, the grid height and width.
        cls_token: when True, prepend one all-zero row (for a class token).

    Returns:
        numpy array of shape [grid_size*grid_size, embed_dim] without a class
        token, or [1 + grid_size*grid_size, embed_dim] with one. Grid rows are
        ordered row-major over (h, w). The dtype is the same (float32) in both
        cases.
    """
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)
    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        # np.zeros defaults to float64, which would silently upcast the whole
        # table on this path only. Match the table's dtype instead.
        cls_row = np.zeros([1, embed_dim], dtype=pos_embed.dtype)
        pos_embed = np.concatenate([cls_row, pos_embed], axis=0)
    return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Concatenate 1D sin-cos embeddings of the two coordinate planes in ``grid``.

    Args:
        embed_dim: total output width; must be even. Each plane gets
            embed_dim // 2 channels.
        grid: indexable with grid[0] and grid[1], each an array of positions
            (flattened by the 1D helper).

    Returns:
        Array of shape (num_positions, embed_dim). The first half of the
        channels encodes grid[0] and the second half encodes grid[1].
    """
    assert embed_dim % 2 == 0
    half_dim = embed_dim // 2
    # NOTE: when called from get_2d_sincos_pos_embed, grid[0] holds the
    # coordinates built from grid_w (meshgrid's default "xy" indexing). The
    # first half therefore encodes the column index, despite the upstream MAE
    # code naming it "emb_h".
    halves = [
        get_1d_sincos_pos_embed_from_grid(half_dim, plane)  # (H*W, D/2)
        for plane in (grid[0], grid[1])
    ]
    return np.concatenate(halves, axis=1)  # (H*W, D)
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """Encode scalar positions with fixed sine/cosine features.

    Args:
        embed_dim: number of output features per position; must be even.
        pos: numpy array of positions of any shape; it is flattened to (M,).

    Returns:
        Array of shape (M, embed_dim). Column k (k < embed_dim/2) is
        sin(pos * w_k), and column embed_dim/2 + k is cos(pos * w_k), where
        w_k = 1 / 10000 ** (k / (embed_dim / 2)).
    """
    assert embed_dim % 2 == 0
    exponents = np.arange(embed_dim // 2, dtype=np.float32) / (embed_dim / 2.)
    inv_freq = 1. / 10000 ** exponents  # (D/2,)
    # Outer product via broadcasting: angles[m, k] = pos[m] * inv_freq[k].
    angles = pos.reshape(-1)[:, None] * inv_freq[None, :]  # (M, D/2)
    return np.concatenate((np.sin(angles), np.cos(angles)), axis=1)  # (M, D)
class Resampler(nn.Module):
    """Cross-attention resampler with grid_size**2 learnable queries.

    A single ``nn.MultiheadAttention`` layer lets a fixed set of learned
    queries attend over an input feature sequence. Fixed 2D sin-cos position
    embeddings are added to both the queries and the keys. The module emits
    grid_size**2 tokens regardless of the input length.

    Args:
        grid_size: side of the query grid; the module emits grid_size**2 tokens.
        embed_dim: width of the queries and of the attention layer.
        num_heads: number of attention heads.
        kv_dim: feature width of the input. When given and different from
            embed_dim, the input is first mapped with a bias-free nn.Linear.
        norm_layer: normalization factory applied to queries and to keys/values.

    Outputs:
        A tensor with the shape of (batch, grid_size**2, embed_dim).
    """
    def __init__(
        self,
        grid_size,
        embed_dim,
        num_heads,
        kv_dim=None,
        norm_layer=nn.LayerNorm
    ):
        super().__init__()
        self.num_queries = grid_size ** 2
        self.embed_dim = embed_dim
        self.num_heads = num_heads

        # Fixed (non-trainable) 2D sin-cos table, one row per query. It is
        # registered as a Parameter, so it is part of state_dict and follows
        # .to()/.half() casts of the module.
        sincos_table = get_2d_sincos_pos_embed(embed_dim, grid_size)
        self.pos_embed = nn.Parameter(
            torch.from_numpy(sincos_table).float(), requires_grad=False
        )

        # Learnable queries drawn from N(0, 0.02).
        self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
        normal_(self.query, std=.02)

        needs_projection = kv_dim is not None and kv_dim != embed_dim
        self.kv_proj = (
            nn.Linear(kv_dim, embed_dim, bias=False) if needs_projection else nn.Identity()
        )

        self.attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.ln_q = norm_layer(embed_dim)
        self.ln_kv = norm_layer(embed_dim)

        # Re-initialise every nn.Linear / nn.LayerNorm submodule.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise Linear as N(0, 0.02) with zero bias, and LayerNorm as identity."""
        if isinstance(m, nn.Linear):
            # Also reaches nn.MultiheadAttention.out_proj, which subclasses nn.Linear.
            normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x, key_padding_mask=None):
        """Resample ``x`` into ``self.num_queries`` output tokens.

        Args:
            x: input features. The permute into nn.MultiheadAttention's default
                sequence-first layout implies (batch, seq_len, kv_dim).
                seq_len should be a perfect square so the interpolated key
                position table has seq_len rows.
            key_padding_mask: forwarded unchanged to nn.MultiheadAttention.

        Returns:
            Tensor of shape (batch, num_queries, embed_dim).
        """
        key_pos = get_abs_pos(self.pos_embed, x.size(1))
        kv = self.ln_kv(self.kv_proj(x)).permute(1, 0, 2)  # (seq_len, batch, C)
        batch = kv.shape[1]
        queries = self.ln_q(self.query)
        attended, _ = self.attn(
            self._repeat(queries, batch) + self.pos_embed.unsqueeze(1),
            kv + key_pos.unsqueeze(1),
            kv,
            key_padding_mask=key_padding_mask,
        )
        return attended.permute(1, 0, 2)

    def _repeat(self, query, N: int):
        """Tile a (num_queries, C) query table to (num_queries, N, C)."""
        return query[:, None].repeat(1, N, 1)