import copy
import math
import os

import torch
import torch.nn.functional as F
from torch import nn, Tensor


class MLP(nn.Module):
    """Very simple multi-layer perceptron (also called FFN)."""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x
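
# Illustrative usage (not part of the original file): DETR-style models typically use a
# 3-layer MLP of this form as the box head, mapping d_model features to 4 box coordinates:
#   bbox_embed = MLP(input_dim=256, hidden_dim=256, output_dim=4, num_layers=3)
#   out = bbox_embed(torch.randn(2, 100, 256))   # -> shape (2, 100, 4)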


def inverse_sigmoid(x, eps=1e-5):
    x = x.clamp(min=0, max=1)
    x1 = x.clamp(min=eps)
    x2 = (1 - x).clamp(min=eps)
    return torch.log(x1 / x2)
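
# Illustrative note: inverse_sigmoid maps normalized coordinates in [0, 1] back to logit
# space, so predicted offsets can be added before re-applying sigmoid, e.g.:
#   x = torch.tensor([0.25, 0.5, 0.75])
#   torch.sigmoid(inverse_sigmoid(x))   # ~ [0.25, 0.5, 0.75] (round trip, up to eps clamping)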


def gen_encoder_output_proposals(memory: Tensor, memory_padding_mask: Tensor, spatial_shapes: Tensor):
    r"""
    Input:
        - memory: bs, \sum{hw}, d_model
        - memory_padding_mask: bs, \sum{hw}
        - spatial_shapes: nlevel, 2
    Output:
        - output_memory: bs, \sum{hw}, d_model
        - output_proposals: bs, \sum{hw}, 4
    """
    N_, S_, C_ = memory.shape
    base_scale = 4.0  # unused
    proposals = []
    _cur = 0
    for lvl, (H_, W_) in enumerate(spatial_shapes):
        mask_flatten_ = memory_padding_mask[:, _cur:(_cur + H_ * W_)].view(N_, H_, W_, 1)
        # Number of valid (non-padded) rows and columns per image at this level.
        valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
        valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1)

        # Pixel-center grid, normalized by the valid extent of each image.
        grid_y, grid_x = torch.meshgrid(
            torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device),
            torch.linspace(0, W_ - 1, W_, dtype=torch.float32, device=memory.device),
            indexing="ij",  # same as the default behavior; silences the deprecation warning
        )
        grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)
        scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(N_, 1, 1, 2)
        grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale

        # Proposal width/height grows with the feature level (coarser level -> larger boxes).
        wh = torch.ones_like(grid) * 0.05 * (2.0 ** lvl)
        proposal = torch.cat((grid, wh), -1).view(N_, -1, 4)
        proposals.append(proposal)
        _cur += (H_ * W_)

    output_proposals = torch.cat(proposals, 1)
    # Keep only proposals whose coordinates stay safely inside (0, 1), then move them to logit space.
    output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True)
    output_proposals = torch.log(output_proposals / (1 - output_proposals))
    output_proposals = output_proposals.masked_fill(memory_padding_mask.unsqueeze(-1), float('inf'))
    output_proposals = output_proposals.masked_fill(~output_proposals_valid, float('inf'))

    output_memory = memory
    output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float(0))
    output_memory = output_memory.masked_fill(~output_proposals_valid, float(0))
    return output_memory, output_proposals
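
# Illustrative shapes (assumed values, for reference): with bs=2, two feature levels of
# 32x32 and 16x16, and d_model=256:
#   memory:              (2, 32*32 + 16*16, 256)
#   memory_padding_mask: (2, 32*32 + 16*16)          # True at padded positions
#   spatial_shapes:      tensor([[32, 32], [16, 16]])
#   -> output_memory: (2, 1280, 256), output_proposals: (2, 1280, 4) in logit space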


def gen_sineembed_for_position(pos_tensor):
    # pos_tensor: (n_query, bs, 2) or (n_query, bs, 4) with normalized (x, y[, w, h]) in [0, 1].
    scale = 2 * math.pi
    dim_t = torch.arange(128, dtype=torch.float32, device=pos_tensor.device)
    dim_t = 10000 ** (2 * (dim_t // 2) / 128)
    x_embed = pos_tensor[:, :, 0] * scale
    y_embed = pos_tensor[:, :, 1] * scale
    pos_x = x_embed[:, :, None] / dim_t
    pos_y = y_embed[:, :, None] / dim_t
    pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2)
    pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2)
    if pos_tensor.size(-1) == 2:
        pos = torch.cat((pos_y, pos_x), dim=2)
    elif pos_tensor.size(-1) == 4:
        w_embed = pos_tensor[:, :, 2] * scale
        pos_w = w_embed[:, :, None] / dim_t
        pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2)
        h_embed = pos_tensor[:, :, 3] * scale
        pos_h = h_embed[:, :, None] / dim_t
        pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2)
        pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2)
    else:
        raise ValueError("Unknown pos_tensor shape(-1): {}".format(pos_tensor.size(-1)))
    return pos
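
# Illustrative shapes: dim_t has length 128, so each coordinate contributes a 128-dim embedding.
# An (n_query, bs, 2) input yields (n_query, bs, 256); an (n_query, bs, 4) input yields
# (n_query, bs, 512), e.g.:
#   emb = gen_sineembed_for_position(torch.rand(300, 2, 4))   # -> (300, 2, 512)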


def _get_activation_fn(activation):
    """Return an activation function given a string."""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    if activation == "prelu":
        return nn.PReLU()
    if activation == "selu":
        return F.selu
    raise RuntimeError(f"activation should be relu/gelu/glu/prelu/selu, not {activation}.")
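
# Illustrative note: _get_activation_fn("gelu") returns F.gelu; "prelu" returns an nn.PReLU()
# module (which carries a learnable parameter) rather than a stateless function.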


def _get_clones(module, N, layer_share=False):
    if layer_share:
        return nn.ModuleList([module for i in range(N)])
    else:
        return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
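
# Illustrative usage: _get_clones(decoder_layer, 6) builds 6 independent deep copies of a layer,
# while layer_share=True reuses the same module object (shared weights) N times.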


def _get_clones_advanced(module, N, N_valid):
    # First N_valid entries are real (deep-copied) modules; the remaining N - N_valid are no-op Identity layers.
    assert N_valid <= N
    layers = []
    for i in range(N):
        if i < N_valid:
            layers.append(copy.deepcopy(module))
        else:
            layers.append(nn.Identity())
    return nn.ModuleList(layers)
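

if __name__ == "__main__":
    # Minimal smoke test (illustrative only; not part of the original utilities).
    layers = _get_clones_advanced(nn.Linear(8, 8), N=4, N_valid=2)
    assert isinstance(layers[0], nn.Linear) and isinstance(layers[3], nn.Identity)
    mlp = MLP(input_dim=8, hidden_dim=16, output_dim=4, num_layers=3)
    print(mlp(torch.randn(2, 8)).shape)  # expected: torch.Size([2, 4])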