himipo's picture
first
11aa70b
"""
DEIM: DETR with Improved Matching for Fast Convergence
Copyright (c) 2024 The DEIM Authors. All Rights Reserved.
---------------------------------------------------------------------------------
Modified from D-FINE (https://github.com/Peterande/D-FINE)
Copyright (c) 2023 . All Rights Reserved.
"""
import math
from typing import List
import torch
import torch.nn as nn
import torch.nn.functional as F
def inverse_sigmoid(x: torch.Tensor, eps: float=1e-5) -> torch.Tensor:
    """Return the logit (inverse sigmoid) of ``x``, guarded against log(0).

    ``x`` is first limited to the closed interval [0, 1]. Both ``x`` and
    ``1 - x`` are then floored at ``eps`` before the ratio is formed, so the
    result stays finite at the interval end points.

    Args:
        x: Input tensor; values outside [0, 1] are clipped.
        eps: Lower bound applied to the numerator and the denominator.

    Returns:
        Tensor with the same shape as ``x`` holding ``log(x / (1 - x))``
        computed on the clipped values.
    """
    bounded = x.clip(min=0., max=1.)
    numerator = bounded.clip(min=eps)
    denominator = (1 - bounded).clip(min=eps)
    return torch.log(numerator / denominator)
def bias_init_with_prob(prior_prob=0.01):
    """Return the bias value whose sigmoid equals ``prior_prob``.

    Used to initialise a conv/fc bias from a prior probability: the result
    is ``-log((1 - prior_prob) / prior_prob)`` as a Python float.
    """
    inverse_odds = (1 - prior_prob) / prior_prob
    return float(-math.log(inverse_odds))
def deformable_attention_core_func(value, value_spatial_shapes, sampling_locations, attention_weights):
    """Multi-scale deformable attention: sample values bilinearly and mix them.

    Args:
        value (Tensor): [bs, value_length, n_head, c], all levels concatenated
            along ``value_length`` in the order of ``value_spatial_shapes``.
        value_spatial_shapes (Tensor|List): [n_levels, 2], one (h, w) per level.
        sampling_locations (Tensor): [bs, query_length, n_head, n_levels, n_points, 2].
            They are mapped with ``2 * loc - 1`` before sampling, i.e. they are
            treated as normalised coordinates in [0, 1].
        attention_weights (Tensor): [bs, query_length, n_head, n_levels, n_points]

    Returns:
        output (Tensor): [bs, query_length, n_head * c]
    """
    batch, _, heads, head_dim = value.shape
    _, num_queries, _, num_levels, num_points, _ = sampling_locations.shape

    # Cut the concatenated value sequence back into one chunk per level.
    level_sizes = [h * w for h, w in value_spatial_shapes]
    per_level_value = value.split(level_sizes, dim=1)

    # grid_sample works on coordinates in [-1, 1].
    grids = 2 * sampling_locations - 1

    sampled_per_level = []
    for lvl, (h, w) in enumerate(value_spatial_shapes):
        # (bs, h*w, heads, c) -> (bs, h*w, heads*c) -> (bs, heads*c, h*w)
        # -> (bs*heads, c, h, w)
        feature_map = per_level_value[lvl].flatten(2)
        feature_map = feature_map.permute(0, 2, 1)
        feature_map = feature_map.reshape(batch * heads, head_dim, h, w)

        # (bs, Lq, heads, P, 2) -> (bs, heads, Lq, P, 2) -> (bs*heads, Lq, P, 2)
        level_grid = grids[:, :, :, lvl]
        level_grid = level_grid.permute(0, 2, 1, 3, 4).flatten(0, 1)

        # (bs*heads, c, Lq, P)
        sampled_per_level.append(
            F.grid_sample(
                feature_map,
                level_grid,
                mode='bilinear',
                padding_mode='zeros',
                align_corners=False,
            )
        )

    # (bs, Lq, heads, L, P) -> (bs, heads, Lq, L, P) -> (bs*heads, 1, Lq, L*P)
    weights = attention_weights.permute(0, 2, 1, 3, 4)
    weights = weights.reshape(batch * heads, 1, num_queries, num_levels * num_points)

    # (bs*heads, c, Lq, L, P) -> (bs*heads, c, Lq, L*P), level-major like the weights.
    samples = torch.stack(sampled_per_level, dim=-2).flatten(-2)

    mixed = (samples * weights).sum(-1)
    mixed = mixed.reshape(batch, heads * head_dim, num_queries)
    return mixed.permute(0, 2, 1)
def deformable_attention_core_func_v2(
    value: torch.Tensor,
    value_spatial_shapes,
    sampling_locations: torch.Tensor,
    attention_weights: torch.Tensor,
    num_points_list: List[int],
    method='default',
    value_shape='default',
):
    """Multi-scale deformable attention with a per-level number of points.

    Args:
        value: Layout depends on ``value_shape``.
            * ``'default'``: a sequence with one tensor per level, each
              [bs, n_head, c, h * w].
            * ``'reshape'``: a single tensor [bs, value_length, n_head, c]
              with all levels concatenated along ``value_length``.
        value_spatial_shapes (Tensor|List): [n_levels, 2], one (h, w) per level.
        sampling_locations (Tensor): [bs, query_length, n_head, sum(num_points_list), 2].
        attention_weights (Tensor): [bs, query_length, n_head, sum(num_points_list)]
        num_points_list: Number of sampling points for each level; used to
            split the point axis of ``sampling_locations``.
        method: ``'default'`` samples bilinearly with ``F.grid_sample`` after
            mapping locations with ``2 * loc - 1``; ``'discrete'`` scales the
            locations by (w, h), converts them to integer pixel indices and
            gathers the values directly.
        value_shape: ``'default'`` or ``'reshape'`` (see ``value``).

    Returns:
        output (Tensor): [bs, query_length, n_head * c]

    Raises:
        ValueError: If ``method`` or ``value_shape`` is not a supported option.
    """
    if method not in ('default', 'discrete'):
        raise ValueError(
            f"Unsupported method {method!r}; expected 'default' or 'discrete'.")

    if value_shape == 'default':
        bs, n_head, c, _ = value[0].shape
    elif value_shape == 'reshape':  # reshape following RT-DETR
        bs, _, n_head, c = value.shape
        split_shape = [h * w for h, w in value_spatial_shapes]
        # (bs, L, heads, c) -> (bs, heads, c, L) -> (bs*heads, c, L) -> per-level chunks
        value = value.permute(0, 2, 3, 1).flatten(0, 1).split(split_shape, dim=-1)
    else:
        raise ValueError(
            f"Unsupported value_shape {value_shape!r}; expected 'default' or 'reshape'.")

    _, Len_q, _, _, _ = sampling_locations.shape

    if method == 'default':
        # grid_sample works on coordinates in [-1, 1].
        sampling_grids = 2 * sampling_locations - 1
    else:
        sampling_grids = sampling_locations

    # (bs, Lq, heads, P_total, 2) -> (bs*heads, Lq, P_total, 2) -> per-level chunks
    sampling_grids = sampling_grids.permute(0, 2, 1, 3, 4).flatten(0, 1)
    sampling_locations_list = sampling_grids.split(num_points_list, dim=-2)

    sampling_value_list = []
    for level, (h, w) in enumerate(value_spatial_shapes):
        n_points = num_points_list[level]
        value_l = value[level].reshape(bs * n_head, c, h, w)
        sampling_grid_l: torch.Tensor = sampling_locations_list[level]

        if method == 'default':
            sampling_value_l = F.grid_sample(
                value_l,
                sampling_grid_l,
                mode='bilinear',
                padding_mode='zeros',
                align_corners=False)
        else:
            # Scale to pixel units and truncate to integer indices.
            scale = torch.tensor([[w, h]], device=value_l.device)
            sampling_coord = (sampling_grid_l * scale + 0.5).to(torch.int64)

            # Clamp each axis against its own extent. Clamping both axes to
            # h - 1 is only correct for square maps: it drops columns when
            # w > h and indexes out of range when h > w.
            x_idx = sampling_coord[..., 0].clamp(0, w - 1).reshape(bs * n_head, Len_q * n_points)
            y_idx = sampling_coord[..., 1].clamp(0, h - 1).reshape(bs * n_head, Len_q * n_points)

            # (N, 1) row index broadcasts against the (N, Lq*P) pixel indices.
            batch_idx = torch.arange(bs * n_head, device=value_l.device).unsqueeze(-1)
            sampling_value_l = value_l[batch_idx, :, y_idx, x_idx]  # (N, Lq*P, c)
            sampling_value_l = sampling_value_l.permute(0, 2, 1).reshape(
                bs * n_head, c, Len_q, n_points)

        sampling_value_list.append(sampling_value_l)

    # (bs, Lq, heads, P_total) -> (bs*heads, 1, Lq, P_total)
    attn_weights = attention_weights.permute(0, 2, 1, 3).reshape(
        bs * n_head, 1, Len_q, sum(num_points_list))
    weighted_sample_locs = torch.cat(sampling_value_list, dim=-1) * attn_weights
    output = weighted_sample_locs.sum(-1).reshape(bs, n_head * c, Len_q)

    return output.permute(0, 2, 1)
def get_activation(act: str, inpace: bool=True):
    """Build an activation module from its name.

    Args:
        act: Case-insensitive activation name: 'silu', 'swish' (alias of
            'silu'), 'relu', 'leaky_relu', 'gelu' or 'hardsigmoid'.
            ``None`` yields ``nn.Identity()``; an ``nn.Module`` instance is
            returned unchanged.
        inpace: Value assigned to the module's ``inplace`` attribute when
            the module has one. The spelling is kept as-is because it is
            part of the public signature.

    Returns:
        The activation as an ``nn.Module``.

    Raises:
        RuntimeError: If ``act`` is a name that is not supported.
    """
    if act is None:
        return nn.Identity()
    if isinstance(act, nn.Module):
        return act

    factories = {
        'silu': nn.SiLU,
        'swish': nn.SiLU,
        'relu': nn.ReLU,
        'leaky_relu': nn.LeakyReLU,
        'gelu': nn.GELU,
        'hardsigmoid': nn.Hardsigmoid,
    }

    factory = factories.get(act.lower())
    if factory is None:
        raise RuntimeError(
            f'Unsupported activation {act!r}; expected one of {sorted(factories)}.')

    m = factory()
    # Not every activation exposes `inplace` (e.g. GELU), hence the check.
    if hasattr(m, 'inplace'):
        m.inplace = inpace

    return m