import math

import numpy as np
import torch
import torch.utils.checkpoint
from torch import nn


# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float32)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb
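
# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of how the helpers above are typically combined: build a
# fixed 2D sin-cos table for a square patch grid and convert it to a torch
# tensor. The grid size (24) and embedding dim (1024) are assumptions chosen
# for illustration; the function is defined here but never called.
def _example_build_pos_embed(grid_size=24, embed_dim=1024):
    pos_embed = get_2d_sincos_pos_embed(embed_dim, grid_size)  # (grid_size**2, embed_dim)
    return torch.from_numpy(pos_embed).float()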

class CrossAttention(nn.Module):
    # Single-source cross-attention: `queries` attend to `vision_latents`.
    # Inputs are LayerNorm-ed inside the q/k/v projections.
    def __init__(self, q_dim, kv_dim, hidden_dim, num_heads, attention_bias=False):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = self.hidden_dim // self.num_heads

        if (self.head_dim * self.num_heads) != self.hidden_dim:
            raise ValueError(
                f"hidden_dim must be divisible by num_heads (got `hidden_dim`: {self.hidden_dim}"
                f" and `num_heads`: {self.num_heads})."
            )

        self.q_proj = nn.Sequential(
            nn.LayerNorm(q_dim),
            nn.Linear(q_dim, self.num_heads * self.head_dim, bias=attention_bias),
        )
        self.k_proj = nn.Sequential(
            nn.LayerNorm(kv_dim),
            nn.Linear(kv_dim, self.num_heads * self.head_dim, bias=attention_bias),
        )
        self.v_proj = nn.Sequential(
            nn.LayerNorm(kv_dim),
            nn.Linear(kv_dim, self.num_heads * self.head_dim, bias=attention_bias),
        )
        self.o_proj = nn.Linear(
            self.num_heads * self.head_dim, q_dim, bias=attention_bias
        )

    def forward(self, vision_latents, queries, attention_mask):
        bsz, q_len, _ = queries.size()
        bsz, v_len, _ = vision_latents.size()

        query_states = self.q_proj(queries)
        key_states = self.k_proj(vision_latents)
        value_states = self.v_proj(vision_latents)

        query_states = query_states.view(
            bsz, q_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        key_states = key_states.view(
            bsz, v_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        value_states = value_states.view(
            bsz, v_len, self.num_heads, self.head_dim
        ).transpose(1, 2)

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, v_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, v_len)}, but is {attention_mask.size()}"
                )

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if query_states.device.type == "cuda" and attention_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attention_mask,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_dim)

        attn_output = self.o_proj(attn_output)

        return attn_output
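
# --- Usage sketch (illustrative, not part of the original module) ---
# Shows the expected tensor shapes for CrossAttention: queries of shape
# (batch, q_len, q_dim) attend to vision_latents of shape (batch, v_len, kv_dim)
# and come back as (batch, q_len, q_dim). All sizes below are assumptions for
# illustration only; the function is defined but never called here.
def _example_cross_attention():
    attn = CrossAttention(q_dim=4096, kv_dim=1024, hidden_dim=1024, num_heads=16)
    queries = torch.randn(2, 576, 4096)
    vision_latents = torch.randn(2, 729, 1024)
    return attn(vision_latents, queries, attention_mask=None)  # (2, 576, 4096)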

class AggregationBlock(nn.Module):
    # Wraps either a CrossAttention layer (attention=True) or a plain MLP
    # projection of the vision latents (attention=False) behind one interface.
    def __init__(
        self, attention, q_dim, kv_dim, hidden_dim, num_heads, attention_bias=False
    ):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = self.hidden_dim // self.num_heads
        if (self.head_dim * self.num_heads) != self.hidden_dim:
            raise ValueError(
                f"hidden_dim must be divisible by num_heads (got `hidden_dim`: {self.hidden_dim}"
                f" and `num_heads`: {self.num_heads})."
            )

        self.attention = attention
        if attention:
            self.attention_layer = CrossAttention(
                q_dim, kv_dim, hidden_dim, num_heads, attention_bias
            )
        else:
            self.attention_layer = MLP(kv_dim, q_dim, q_dim)

    def forward(self, vision_latents, queries, attention_mask):
        if self.attention:
            queries = self.attention_layer(vision_latents, queries, attention_mask)
        else:
            queries = self.attention_layer(vision_latents)

        return queries

class MultiKVCrossAttention(nn.Module):
    # Cross-attention over several key/value sources: each source gets its own
    # key/value projection, the projected tokens are concatenated along the
    # sequence dimension, and a single SDPA call attends over all of them.
    def __init__(
        self, q_dim, kv_dim_list, hidden_dim, num_heads, attention_bias=False
    ):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = self.hidden_dim // self.num_heads

        if (self.head_dim * self.num_heads) != self.hidden_dim:
            raise ValueError(
                f"hidden_dim must be divisible by num_heads (got `hidden_dim`: {self.hidden_dim}"
                f" and `num_heads`: {self.num_heads})."
            )

        self.q_proj = nn.Sequential(
            nn.LayerNorm(q_dim),
            nn.Linear(q_dim, self.num_heads * self.head_dim, bias=attention_bias),
        )
        self.num_of_kvs = len(kv_dim_list)
        for i, kv_dim in enumerate(kv_dim_list):
            setattr(
                self,
                "k_proj_{}".format(i),
                nn.Sequential(
                    nn.LayerNorm(kv_dim),
                    nn.Linear(
                        kv_dim, self.num_heads * self.head_dim, bias=attention_bias
                    ),
                ),
            )
            setattr(
                self,
                "v_proj_{}".format(i),
                nn.Sequential(
                    nn.LayerNorm(kv_dim),
                    nn.Linear(
                        kv_dim, self.num_heads * self.head_dim, bias=attention_bias
                    ),
                ),
            )
        self.o_proj = nn.Linear(
            self.num_heads * self.head_dim, q_dim, bias=attention_bias
        )

    def forward(
        self,
        queries,
        *vision_latents_attention_mask_list,
    ):
        vision_latents_list = vision_latents_attention_mask_list[: self.num_of_kvs]
        attention_mask_list = vision_latents_attention_mask_list[self.num_of_kvs :]

        bsz, q_len, _ = queries.size()

        query_states = self.q_proj(queries)
        key_states = torch.cat(
            [
                getattr(self, "k_proj_{}".format(i))(vision_latents_list[i])
                for i in range(self.num_of_kvs)
            ],
            dim=1,
        )
        value_states = torch.cat(
            [
                getattr(self, "v_proj_{}".format(i))(vision_latents_list[i])
                for i in range(self.num_of_kvs)
            ],
            dim=1,
        )

        v_len = key_states.shape[1]

        query_states = query_states.view(
            bsz, q_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        key_states = key_states.view(
            bsz, v_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        value_states = value_states.view(
            bsz, v_len, self.num_heads, self.head_dim
        ).transpose(1, 2)

        # if kv_weight is not None:
        #     kv_weight = kv_weight.unsqueeze(1).expand(-1, self.num_heads, -1, -1)

        attention_mask = torch.cat(attention_mask_list, dim=-1)

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, v_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, v_len)}, but is {attention_mask.size()}"
                )

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if query_states.device.type == "cuda" and attention_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attention_mask,
        )
        # attn_output = spda(
        #     query_states,
        #     key_states,
        #     value_states,
        #     attn_mask=attention_mask,
        #     additional_score=kv_weight
        # )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_dim)

        attn_output = self.o_proj(attn_output)

        return attn_output
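
# --- Usage sketch (illustrative, not part of the original module) ---
# MultiKVCrossAttention takes the queries first, then all vision latents, then
# all attention masks, each already expanded to (bsz, 1, q_len, v_len_i). The
# sizes and the zero-valued additive masks (0 = attend everywhere) below are
# illustrative assumptions; the function is defined but never called here.
def _example_multi_kv_cross_attention():
    attn = MultiKVCrossAttention(
        q_dim=1024, kv_dim_list=[1024, 1152], hidden_dim=1024, num_heads=16
    )
    queries = torch.randn(2, 576, 1024)
    latents_a = torch.randn(2, 576, 1024)
    latents_b = torch.randn(2, 729, 1152)
    mask_a = torch.zeros(2, 1, 576, 576)
    mask_b = torch.zeros(2, 1, 576, 729)
    return attn(queries, latents_a, latents_b, mask_a, mask_b)  # (2, 576, 1024)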

class MLP(nn.Module):
    def __init__(self, d_in, d_hidden, d_out):
        super().__init__()
        self.linear_1 = nn.Linear(d_in, d_hidden, bias=False)
        self.act = nn.GELU()
        self.linear_2 = nn.Linear(d_hidden, d_out, bias=False)

    def forward(self, x):
        return self.linear_2(self.act(self.linear_1(x)))

class VisionCrossAttentionLayer(nn.Module):
    # "Joint" aggregation layer (used when layer_type == "joint"): queries are
    # concatenated with a projected context feature, attend to all vision
    # sources through one shared MultiKVCrossAttention, and are projected back
    # to q_dim with a residual connection.
    def __init__(
        self,
        q_dim,
        context_dim,
        kv_dim_list,
        kv_size_list,
        hidden_dim=1024,
        layer_idx=0,
    ):
        super().__init__()
        num_heads = 16
        self.num_of_kvs = len(kv_dim_list)

        self.proj_context = nn.Linear(context_dim, hidden_dim, bias=False)
        self.proj_in = nn.Linear(q_dim + hidden_dim, hidden_dim, bias=False)
        # if self.num_of_kvs > 1:
        #     self.weight_mlp = MLP(q_dim+hidden_dim, hidden_dim, self.num_of_kvs)
        #     self.tower_weight = nn.Parameter(torch.zeros((self.num_of_kvs)))
        self.proj_out = MLP(hidden_dim, hidden_dim, q_dim)

        self.norm = nn.LayerNorm(hidden_dim)

        self.cross_attn = MultiKVCrossAttention(
            hidden_dim, kv_dim_list, hidden_dim, num_heads
        )
        self.kv_size_list = kv_size_list
        for i, kv_size in enumerate(kv_size_list):
            if kv_size > 1:
                setattr(
                    self,
                    "pos_embed_{}".format(i),
                    nn.Parameter(torch.randn(kv_size**2, hidden_dim)),
                )
                # self.register_buffer("pos_embed_{}".format(i), torch.from_numpy(get_2d_sincos_pos_embed(hidden_dim, kv_size)).float(), persistent=False)

    def forward(
        self,
        queries,
        context_feature,
        *vision_latents_attention_mask_list,
    ) -> torch.FloatTensor:

        residual = queries
        # queries = self.proj_in(queries)
        context_feature = self.proj_context(context_feature)
        # queries = queries + context_feature
        queries = torch.cat([queries, context_feature], -1)

        # if self.num_of_kvs > 1:
        #     kv_weight = self.weight_mlp(queries) # B * 1 * num_tower
        #     kv_weight = kv_weight + self.tower_weight.view(1, 1, -1)
        #     kv_weight = kv_weight.softmax(-1)
        #     kv_number_list = [size**2 for size in self.kv_size_list]
        #     kv_weight = torch.repeat_interleave(kv_weight, torch.tensor(kv_number_list).to(kv_weight.device), dim=-1)
        # else:
        #     kv_weight = None
        queries = self.proj_in(queries)

        vision_latents_list = vision_latents_attention_mask_list[: self.num_of_kvs]
        attention_mask_list = vision_latents_attention_mask_list[self.num_of_kvs :]

        attention_mask_list_reshaped = []
        if attention_mask_list is not None:
            for attention_mask in attention_mask_list:
                attention_mask = attention_mask.view(attention_mask.shape[0], 1, 1, -1)
                attention_mask = attention_mask.expand(-1, -1, queries.shape[1], -1)
                attention_mask_list_reshaped.append(attention_mask)

        vision_latents_pos_list = []
        for i, vision_latents in enumerate(vision_latents_list):
            if vision_latents.shape[1] > 1:
                vision_latents_pos_list.append(
                    vision_latents
                    + getattr(self, "pos_embed_{}".format(i))[None, :, :].to(
                        vision_latents.dtype
                    )
                )
            else:
                vision_latents_pos_list.append(vision_latents)

        # Cross Attention
        attention_output = self.cross_attn(
            queries, *vision_latents_pos_list, *attention_mask_list_reshaped
        )

        # attention_output = (attention_output * combination_weight).sum(2)
        queries = queries + attention_output

        queries = self.norm(queries)

        queries = self.proj_out(queries)

        queries = queries + residual
        return queries
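
# --- Usage sketch (illustrative, not part of the original module) ---
# Unlike MultiKVCrossAttention, this layer accepts per-source masks of shape
# (bsz, v_len_i) and expands them to (bsz, 1, q_len, v_len_i) itself. The
# learned pos_embed_i has width hidden_dim, so kv dims equal to hidden_dim are
# assumed here; all sizes and the zero-valued additive masks are illustrative
# assumptions, and the function is defined but never called.
def _example_vision_cross_attention_layer():
    layer = VisionCrossAttentionLayer(
        q_dim=4096,
        context_dim=1024,
        kv_dim_list=[1024, 1024],
        kv_size_list=[24, 27],
        hidden_dim=1024,
    )
    queries = torch.randn(2, 576, 4096)
    context_feature = torch.randn(2, 576, 1024)
    latents_a = torch.randn(2, 24 * 24, 1024)
    latents_b = torch.randn(2, 27 * 27, 1024)
    mask_a = torch.zeros(2, 24 * 24)
    mask_b = torch.zeros(2, 27 * 27)
    return layer(queries, context_feature, latents_a, latents_b, mask_a, mask_b)  # (2, 576, 4096)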

class VisionAggregationLayer(nn.Module):
    # "Sep" aggregation layer (used when layer_type == "sep"): each vision
    # source is aggregated by its own AggregationBlock, and the per-source
    # results are fused with learned softmax combination weights before the
    # residual update.
    def __init__(
        self,
        q_dim,
        context_dim,
        kv_dim_list,
        kv_size_list,
        hidden_dim=1024,
        layer_idx=0,
    ):
        super().__init__()
        num_heads = 16
        self.num_of_kvs = len(kv_dim_list)

        self.proj_context = nn.Linear(context_dim, hidden_dim, bias=False)
        self.proj_in = nn.Linear(q_dim + hidden_dim, hidden_dim, bias=False)

        self.proj_out = MLP(hidden_dim, hidden_dim, q_dim)

        self.norm = nn.LayerNorm(hidden_dim)

        if self.num_of_kvs > 1:
            self.weight_mlp = MLP(q_dim + hidden_dim, hidden_dim, self.num_of_kvs)

        for i, kv_size in enumerate(kv_size_list):
            if kv_size > 1:
                setattr(
                    self,
                    "pos_embed_{}".format(i),
                    nn.Parameter(torch.randn(kv_size**2, hidden_dim)),
                )
                setattr(
                    self,
                    "aggregate_{}".format(i),
                    AggregationBlock(
                        True, hidden_dim, kv_dim_list[i], hidden_dim, num_heads
                    ),
                )
            else:
                setattr(
                    self,
                    "aggregate_{}".format(i),
                    AggregationBlock(
                        False, hidden_dim, kv_dim_list[i], hidden_dim, num_heads
                    ),
                )

    def forward(
        self,
        queries,
        context_feature,
        *vision_latents_attention_mask_list,
    ) -> torch.FloatTensor:

        residual = queries
        # queries = self.proj_in(queries)
        context_feature = self.proj_context(context_feature)
        # queries = queries + context_feature
        queries = torch.cat([queries, context_feature], -1)

        if self.num_of_kvs > 1:
            combination_weight = self.weight_mlp(queries).softmax(
                -1
            )  # B * 1 * num_tower
            combination_weight = combination_weight.unsqueeze(-1)
        else:
            combination_weight = 1

        queries = self.proj_in(queries)

        vision_latents_list = vision_latents_attention_mask_list[: self.num_of_kvs]
        attention_mask_list = vision_latents_attention_mask_list[self.num_of_kvs :]

        attention_mask_list_reshaped = []
        if attention_mask_list is not None:
            for attention_mask in attention_mask_list:
                attention_mask = attention_mask.view(attention_mask.shape[0], 1, 1, -1)
                attention_mask = attention_mask.expand(-1, -1, queries.shape[1], -1)
                attention_mask_list_reshaped.append(attention_mask)

        vision_latents_pos_list = []
        for i, vision_latents in enumerate(vision_latents_list):
            if vision_latents.shape[1] > 1:
                vision_latents_pos_list.append(
                    vision_latents
                    + getattr(self, "pos_embed_{}".format(i))[None, :, :].to(
                        vision_latents.dtype
                    )
                )
            else:
                vision_latents_pos_list.append(vision_latents)

        aggregated_vision_latents_list = []
        for i, (vision_latents, attention_mask) in enumerate(
            zip(vision_latents_pos_list, attention_mask_list_reshaped)
        ):
            aggregated_vision_latents_list.append(
                getattr(self, "aggregate_{}".format(i))(
                    vision_latents, queries, attention_mask
                )
            )

        aggregated_vision_latents = torch.stack(aggregated_vision_latents_list, 2)

        queries = queries + (aggregated_vision_latents * combination_weight).sum(2)

        queries = self.norm(queries)

        queries = self.proj_out(queries)

        queries = queries + residual
        return queries

class VisionTokenSampler(nn.Module):
    # Stacks num_of_layers aggregation layers: "joint" uses
    # VisionCrossAttentionLayer, "sep" uses VisionAggregationLayer.
    def __init__(
        self,
        q_dim,
        context_dim,
        kv_dim_list,
        kv_size_list,
        vision_hidden_size,
        num_of_layers=1,
        layer_type="joint",
    ):
        super().__init__()
        assert layer_type in ["joint", "sep"]
        if layer_type == "joint":
            self.layers = nn.ModuleList(
                [
                    VisionCrossAttentionLayer(
                        q_dim,
                        context_dim,
                        kv_dim_list,
                        kv_size_list,
                        vision_hidden_size,
                        idx,
                    )
                    for idx in range(num_of_layers)
                ]
            )
        else:
            self.layers = nn.ModuleList(
                [
                    VisionAggregationLayer(
                        q_dim,
                        context_dim,
                        kv_dim_list,
                        kv_size_list,
                        vision_hidden_size,
                        idx,
                    )
                    for idx in range(num_of_layers)
                ]
            )

    def forward(self, queries, context_feature, *vision_latents_attention_mask_list):
        for layer in self.layers:
            queries = layer(
                queries, context_feature, *vision_latents_attention_mask_list
            )
        return queries
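
# --- Usage sketch (illustrative, not part of the original module) ---
# Builds a small three-layer "joint" sampler and runs it on random inputs.
# The forward() argument order is: queries, context_feature, then all vision
# latents, then all per-source masks. Every size here is an illustrative
# assumption; the function is defined but never called.
def _example_vision_token_sampler():
    sampler = VisionTokenSampler(
        q_dim=4096,
        context_dim=1024,
        kv_dim_list=[1024, 1024],
        kv_size_list=[24, 27],
        vision_hidden_size=1024,
        num_of_layers=3,
        layer_type="joint",
    )
    queries = torch.randn(2, 576, 4096)
    context_feature = torch.randn(2, 576, 1024)
    latents = [torch.randn(2, 24 * 24, 1024), torch.randn(2, 27 * 27, 1024)]
    masks = [torch.zeros(2, 24 * 24), torch.zeros(2, 27 * 27)]
    return sampler(queries, context_feature, *latents, *masks)  # (2, 576, 4096)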