import math
from typing import Dict, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from mmcv.cnn.bricks import DropPath
from torch import Tensor

try:
    from transformers import BertConfig, BertPreTrainedModel
    from transformers.modeling_utils import apply_chunking_to_forward
    from transformers.models.bert.modeling_bert import \
        BertAttention as HFBertAttention
    from transformers.models.bert.modeling_bert import \
        BertIntermediate as HFBertIntermediate
    from transformers.models.bert.modeling_bert import \
        BertOutput as HFBertOutput
except ImportError:
    BertConfig = None
    BertPreTrainedModel = object
    apply_chunking_to_forward = None
    HFBertAttention = object
    HFBertIntermediate = object
    HFBertOutput = object

MAX_CLAMP_VALUE = 50000


def permute_and_flatten(layer: Tensor, N: int, A: int, C: int, H: int,
                        W: int) -> Tensor:
    """Permute and then flatten a tensor, from size (N, A, C, H, W) to
    (N, H * W * A, C).

    Args:
        layer (Tensor): Tensor of shape (N, A * C, H, W).
        N (int): Batch size.
        A (int): Number of anchors per location, or -1 to infer it.
        C (int): Number of channels.
        H (int): Height of the feature map.
        W (int): Width of the feature map.

    Returns:
        Tensor: A Tensor of shape (N, H * W * A, C).
    """
    layer = layer.view(N, A, C, H, W)
    layer = layer.permute(0, 3, 4, 1, 2)
    layer = layer.reshape(N, -1, C)
    return layer
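
# Shape sanity check for ``permute_and_flatten`` (illustrative values, not
# part of the original code): a (2, 256, 20, 20) map with A inferred as 1
# flattens to (2, 20 * 20 * 1, 256).
#
#     >>> feat = torch.rand(2, 256, 20, 20)  # (N, C, H, W)
#     >>> permute_and_flatten(feat, 2, -1, 256, 20, 20).shape
#     torch.Size([2, 400, 256])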


def clamp_values(vector: Tensor) -> Tensor:
    """Clamp the values of a vector to the range [-MAX_CLAMP_VALUE,
    MAX_CLAMP_VALUE].

    Args:
        vector (Tensor): Tensor of shape (N, C, H, W).

    Returns:
        Tensor: A Tensor of shape (N, C, H, W) with clamped values.
    """
    vector = torch.clamp(vector, min=-MAX_CLAMP_VALUE, max=MAX_CLAMP_VALUE)
    return vector
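
# For example: ``clamp_values(torch.tensor([1e9, -1e9]))`` returns
# ``tensor([ 50000., -50000.])``.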


class BiMultiHeadAttention(nn.Module):
    """Bidirectional fusion Multi-Head Attention layer.

    Args:
        v_dim (int): The dimension of the vision input.
        l_dim (int): The dimension of the language input.
        embed_dim (int): The embedding dimension for the attention operation.
        num_heads (int): The number of attention heads.
        dropout (float, optional): The dropout probability. Defaults to 0.1.
    """

    def __init__(self,
                 v_dim: int,
                 l_dim: int,
                 embed_dim: int,
                 num_heads: int,
                 dropout: float = 0.1):
        super().__init__()

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.v_dim = v_dim
        self.l_dim = l_dim

        assert (
            self.head_dim * self.num_heads == self.embed_dim
        ), 'embed_dim must be divisible by num_heads ' \
           f'(got `embed_dim`: {self.embed_dim} ' \
           f'and `num_heads`: {self.num_heads}).'
        self.scale = self.head_dim**(-0.5)
        self.dropout = dropout

        self.v_proj = nn.Linear(self.v_dim, self.embed_dim)
        self.l_proj = nn.Linear(self.l_dim, self.embed_dim)
        self.values_v_proj = nn.Linear(self.v_dim, self.embed_dim)
        self.values_l_proj = nn.Linear(self.l_dim, self.embed_dim)

        self.out_v_proj = nn.Linear(self.embed_dim, self.v_dim)
        self.out_l_proj = nn.Linear(self.embed_dim, self.l_dim)

        self.stable_softmax_2d = False
        self.clamp_min_for_underflow = True
        self.clamp_max_for_overflow = True

        self._reset_parameters()

    def _shape(self, tensor: Tensor, seq_len: int, bsz: int):
        # (bsz, seq_len, embed_dim) -> (bsz, num_heads, seq_len, head_dim)
        return tensor.view(bsz, seq_len, self.num_heads,
                           self.head_dim).transpose(1, 2).contiguous()

    def _reset_parameters(self):
        nn.init.xavier_uniform_(self.v_proj.weight)
        self.v_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.l_proj.weight)
        self.l_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.values_v_proj.weight)
        self.values_v_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.values_l_proj.weight)
        self.values_l_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.out_v_proj.weight)
        self.out_v_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.out_l_proj.weight)
        self.out_l_proj.bias.data.fill_(0)

    def forward(
        self,
        vision: Tensor,
        lang: Tensor,
        attention_mask_v: Optional[Tensor] = None,
        attention_mask_l: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor]:
        bsz, tgt_len, _ = vision.size()

        query_states = self.v_proj(vision) * self.scale
        key_states = self._shape(self.l_proj(lang), -1, bsz)
        value_v_states = self._shape(self.values_v_proj(vision), -1, bsz)
        value_l_states = self._shape(self.values_l_proj(lang), -1, bsz)

        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len,
                                   bsz).view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_v_states = value_v_states.view(*proj_shape)
        value_l_states = value_l_states.view(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f'Attention weights should be of '
                f'size {(bsz * self.num_heads, tgt_len, src_len)}, '
                f'but is {attn_weights.size()}')

        if self.stable_softmax_2d:
            attn_weights = attn_weights - attn_weights.max()

        if self.clamp_min_for_underflow:
            # Clamp to avoid underflow when running in half precision.
            attn_weights = torch.clamp(attn_weights, min=-MAX_CLAMP_VALUE)
        if self.clamp_max_for_overflow:
            # Clamp to avoid overflow when running in half precision.
            attn_weights = torch.clamp(attn_weights, max=MAX_CLAMP_VALUE)

        attn_weights_T = attn_weights.transpose(1, 2)
        attn_weights_l = (
            attn_weights_T -
            torch.max(attn_weights_T, dim=-1, keepdim=True)[0])
        if self.clamp_min_for_underflow:
            attn_weights_l = torch.clamp(attn_weights_l, min=-MAX_CLAMP_VALUE)
        if self.clamp_max_for_overflow:
            attn_weights_l = torch.clamp(attn_weights_l, max=MAX_CLAMP_VALUE)

        if attention_mask_v is not None:
            attention_mask_v = (
                attention_mask_v[:, None,
                                 None, :].repeat(1, self.num_heads, 1,
                                                 1).flatten(0, 1))
            attn_weights_l.masked_fill_(attention_mask_v, float('-inf'))

        attn_weights_l = attn_weights_l.softmax(dim=-1)

        if attention_mask_l is not None:
            assert attention_mask_l.dim() == 2
            attention_mask = attention_mask_l.unsqueeze(1).unsqueeze(1)
            attention_mask = attention_mask.expand(bsz, 1, tgt_len, src_len)
            attention_mask = attention_mask.masked_fill(
                attention_mask == 0, -9e15)

            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError('Attention mask should be of '
                                 f'size {(bsz, 1, tgt_len, src_len)}')
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len,
                                             src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len,
                                             src_len)

        attn_weights_v = nn.functional.softmax(attn_weights, dim=-1)

        attn_probs_v = F.dropout(
            attn_weights_v, p=self.dropout, training=self.training)
        attn_probs_l = F.dropout(
            attn_weights_l, p=self.dropout, training=self.training)

        # Vision queries attend to language values, and vice versa.
        attn_output_v = torch.bmm(attn_probs_v, value_l_states)
        attn_output_l = torch.bmm(attn_probs_l, value_v_states)

        if attn_output_v.size() != (bsz * self.num_heads, tgt_len,
                                    self.head_dim):
            raise ValueError(
                '`attn_output_v` should be of '
                f'size {(bsz * self.num_heads, tgt_len, self.head_dim)}, '
                f'but is {attn_output_v.size()}')

        if attn_output_l.size() != (bsz * self.num_heads, src_len,
                                    self.head_dim):
            raise ValueError(
                '`attn_output_l` should be of size '
                f'{(bsz * self.num_heads, src_len, self.head_dim)}, '
                f'but is {attn_output_l.size()}')

        attn_output_v = attn_output_v.view(bsz, self.num_heads, tgt_len,
                                           self.head_dim)
        attn_output_v = attn_output_v.transpose(1, 2)
        attn_output_v = attn_output_v.reshape(bsz, tgt_len, self.embed_dim)

        attn_output_l = attn_output_l.view(bsz, self.num_heads, src_len,
                                           self.head_dim)
        attn_output_l = attn_output_l.transpose(1, 2)
        attn_output_l = attn_output_l.reshape(bsz, src_len, self.embed_dim)

        attn_output_v = self.out_v_proj(attn_output_v)
        attn_output_l = self.out_l_proj(attn_output_l)

        return attn_output_v, attn_output_l
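
# Usage sketch for ``BiMultiHeadAttention`` (hypothetical sizes): both
# modalities are projected to the shared embedding space, attend to each
# other, and are projected back to their input dimensions.
#
#     >>> fusion = BiMultiHeadAttention(
#     ...     v_dim=256, l_dim=768, embed_dim=2048, num_heads=8)
#     >>> v = torch.rand(2, 400, 256)  # (bsz, num_vision_tokens, v_dim)
#     >>> l = torch.rand(2, 20, 768)   # (bsz, num_text_tokens, l_dim)
#     >>> new_v, new_l = fusion(v, l)
#     >>> new_v.shape, new_l.shape
#     (torch.Size([2, 400, 256]), torch.Size([2, 20, 768]))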


class BiAttentionBlock(nn.Module):
    """BiAttentionBlock Module.

    First, the multi-level visual features are concatenated; then the
    concatenated visual feature and the language feature are fused by
    attention; finally, the fused visual feature is split back into
    multiple levels.

    Args:
        v_dim (int): The dimension of the visual features.
        l_dim (int): The dimension of the language feature.
        embed_dim (int): The embedding dimension for the attention operation.
        num_heads (int): The number of attention heads.
        dropout (float, optional): The dropout probability. Defaults to 0.1.
        drop_path (float, optional): The drop path probability.
            Defaults to 0.0.
        init_values (float, optional): The initial value for the scaling
            parameter. Defaults to 1e-4.
    """

    def __init__(self,
                 v_dim: int,
                 l_dim: int,
                 embed_dim: int,
                 num_heads: int,
                 dropout: float = 0.1,
                 drop_path: float = .0,
                 init_values: float = 1e-4):
        super().__init__()

        # pre-layer norm
        self.layer_norm_v = nn.LayerNorm(v_dim)
        self.layer_norm_l = nn.LayerNorm(l_dim)
        self.attn = BiMultiHeadAttention(
            v_dim=v_dim,
            l_dim=l_dim,
            embed_dim=embed_dim,
            num_heads=num_heads,
            dropout=dropout)

        # layer scale for training stability
        self.drop_path = DropPath(
            drop_path) if drop_path > 0. else nn.Identity()
        self.gamma_v = nn.Parameter(
            init_values * torch.ones(v_dim), requires_grad=True)
        self.gamma_l = nn.Parameter(
            init_values * torch.ones(l_dim), requires_grad=True)

    def forward(self,
                vf0: Tensor,
                vf1: Tensor,
                vf2: Tensor,
                vf3: Tensor,
                vf4: Tensor,
                lang_feature: Tensor,
                attention_mask_l=None):
        visual_features = [vf0, vf1, vf2, vf3, vf4]
        size_per_level, visual_features_flatten = [], []
        for feat_per_level in visual_features:
            bs, c, h, w = feat_per_level.shape
            size_per_level.append([h, w])
            feat = permute_and_flatten(feat_per_level, bs, -1, c, h, w)
            visual_features_flatten.append(feat)
        visual_features_flatten = torch.cat(visual_features_flatten, dim=1)
        new_v, new_lang_feature = self.single_attention_call(
            visual_features_flatten,
            lang_feature,
            attention_mask_l=attention_mask_l)
        # (bs, N, C) -> (bs, C, N)
        new_v = new_v.transpose(1, 2).contiguous()

        start = 0
        # fvfs: fused visual features, split back per level
        fvfs = []
        for (h, w) in size_per_level:
            new_v_per_level = new_v[:, :,
                                    start:start + h * w].view(bs, -1, h,
                                                              w).contiguous()
            fvfs.append(new_v_per_level)
            start += h * w

        return fvfs[0], fvfs[1], fvfs[2], fvfs[3], fvfs[4], new_lang_feature

    def single_attention_call(
        self,
        visual: Tensor,
        lang: Tensor,
        attention_mask_v: Optional[Tensor] = None,
        attention_mask_l: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor]:
        """Perform a single attention call between the visual and language
        inputs.

        Args:
            visual (Tensor): The visual input tensor.
            lang (Tensor): The language input tensor.
            attention_mask_v (Optional[Tensor]):
                An optional attention mask tensor for the visual input.
            attention_mask_l (Optional[Tensor]):
                An optional attention mask tensor for the language input.

        Returns:
            Tuple[Tensor, Tensor]: A tuple containing the updated
            visual and language tensors after the attention call.
        """
        visual = self.layer_norm_v(visual)
        lang = self.layer_norm_l(lang)
        delta_v, delta_l = self.attn(
            visual,
            lang,
            attention_mask_v=attention_mask_v,
            attention_mask_l=attention_mask_l)
        # residual connections, scaled by the learnable gammas
        visual = visual + self.drop_path(self.gamma_v * delta_v)
        lang = lang + self.drop_path(self.gamma_l * delta_l)
        return visual, lang
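
# ``BiAttentionBlock`` sketch (illustrative level sizes): five feature maps
# are flattened, fused with the text feature, and split back per level.
#
#     >>> block = BiAttentionBlock(
#     ...     v_dim=256, l_dim=768, embed_dim=2048, num_heads=8)
#     >>> vfs = [torch.rand(2, 256, s, s) for s in (100, 50, 25, 13, 7)]
#     >>> lang = torch.rand(2, 20, 768)
#     >>> *new_vfs, new_lang = block(*vfs, lang)
#     >>> [tuple(v.shape[-2:]) for v in new_vfs]
#     [(100, 100), (50, 50), (25, 25), (13, 13), (7, 7)]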


class SingleScaleBiAttentionBlock(BiAttentionBlock):
    """This is a single-scale implementation of `BiAttentionBlock`.

    The only difference between it and `BiAttentionBlock` is that the
    `forward` function of `SingleScaleBiAttentionBlock` only accepts a
    single flattened visual feature map, while the `forward` function in
    `BiAttentionBlock` accepts multiple visual feature maps.
    """

    def forward(self,
                visual_feature: Tensor,
                lang_feature: Tensor,
                attention_mask_v=None,
                attention_mask_l=None):
        """Single-scale forward pass.

        Args:
            visual_feature (Tensor): The visual input tensor. Tensor of
                shape (bs, patch_len, ch).
            lang_feature (Tensor): The language input tensor. Tensor of
                shape (bs, text_len, ch).
            attention_mask_v (Tensor, optional): Visual feature attention
                mask. Defaults to None.
            attention_mask_l (Tensor, optional): Language feature attention
                mask. Defaults to None.

        Returns:
            Tuple[Tensor, Tensor]: The updated visual and language features.
        """
        new_v, new_lang_feature = self.single_attention_call(
            visual_feature,
            lang_feature,
            attention_mask_v=attention_mask_v,
            attention_mask_l=attention_mask_l)
        return new_v, new_lang_feature
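
# Single-scale sketch (hypothetical shapes): one already-flattened visual
# feature map only.
#
#     >>> block = SingleScaleBiAttentionBlock(
#     ...     v_dim=256, l_dim=768, embed_dim=1024, num_heads=8)
#     >>> new_v, new_l = block(torch.rand(2, 400, 256),
#     ...                      torch.rand(2, 20, 768))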


class VLFuse(nn.Module):
    """Early Fusion Module.

    Args:
        v_dim (int): Dimension of visual features.
        l_dim (int): Dimension of language features.
        embed_dim (int): The embedding dimension for the attention operation.
        num_heads (int): Number of attention heads.
        dropout (float): Dropout probability.
        drop_path (float): Drop path probability.
        use_checkpoint (bool): Whether to use PyTorch's checkpoint function.
    """

    def __init__(self,
                 v_dim: int = 256,
                 l_dim: int = 768,
                 embed_dim: int = 2048,
                 num_heads: int = 8,
                 dropout: float = 0.1,
                 drop_path: float = 0.0,
                 use_checkpoint: bool = False):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.b_attn = BiAttentionBlock(
            v_dim=v_dim,
            l_dim=l_dim,
            embed_dim=embed_dim,
            num_heads=num_heads,
            dropout=dropout,
            drop_path=drop_path,
            init_values=1.0 / 6.0)

    def forward(self, x: dict) -> dict:
        """Forward pass of the VLFuse module."""
        visual_features = x['visual']
        language_dict_features = x['lang']

        if self.use_checkpoint:
            # Gradient checkpointing trades compute for memory: activations
            # of the fusion block are recomputed during the backward pass.
            vf0, vf1, vf2, vf3, vf4, language_features = \
                checkpoint.checkpoint(self.b_attn, *visual_features,
                                      language_dict_features['hidden'],
                                      language_dict_features['masks'])
        else:
            vf0, vf1, vf2, vf3, vf4, language_features = self.b_attn(
                *visual_features, language_dict_features['hidden'],
                language_dict_features['masks'])

        language_dict_features['hidden'] = language_features
        fused_language_dict_features = language_dict_features

        features_dict = {
            'visual': [vf0, vf1, vf2, vf3, vf4],
            'lang': fused_language_dict_features
        }

        return features_dict
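
# ``VLFuse`` passes features around as a dict (keys as used above); an
# illustrative call with made-up shapes:
#
#     >>> fuse = VLFuse(v_dim=256, l_dim=768, embed_dim=2048, num_heads=8)
#     >>> x = {
#     ...     'visual': [torch.rand(2, 256, s, s)
#     ...                for s in (100, 50, 25, 13, 7)],
#     ...     'lang': {'hidden': torch.rand(2, 20, 768), 'masks': None},
#     ... }
#     >>> out = fuse(x)  # same structure, with fused 'visual' and 'lang'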


class BertEncoderLayer(BertPreTrainedModel):
    """A modified version of the `BertLayer` class from the
    `transformers.models.bert.modeling_bert` module.

    Args:
        config (:class:`~transformers.BertConfig`):
            The configuration object that
            contains various parameters for the model.
        clamp_min_for_underflow (bool, optional):
            Whether to clamp the minimum value of the hidden states
            to prevent underflow. Defaults to `False`.
        clamp_max_for_overflow (bool, optional):
            Whether to clamp the maximum value of the hidden states
            to prevent overflow. Defaults to `False`.
    """

    def __init__(self,
                 config: BertConfig,
                 clamp_min_for_underflow: bool = False,
                 clamp_max_for_overflow: bool = False):
        super().__init__(config)
        self.config = config
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1

        self.attention = BertAttention(config, clamp_min_for_underflow,
                                       clamp_max_for_overflow)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(
        self, inputs: Dict[str, Dict[str, torch.Tensor]]
    ) -> Dict[str, Dict[str, torch.Tensor]]:
        """Applies the BertEncoderLayer to the input features."""
        language_dict_features = inputs['lang']
        hidden_states = language_dict_features['hidden']
        attention_mask = language_dict_features['masks']

        device = hidden_states.device
        input_shape = hidden_states.size()[:-1]
        extended_attention_mask = self.get_extended_attention_mask(
            attention_mask, input_shape, device)

        self_attention_outputs = self.attention(
            hidden_states,
            extended_attention_mask,
            None,
            output_attentions=False,
            past_key_value=None)
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:]
        layer_output = apply_chunking_to_forward(self.feed_forward_chunk,
                                                 self.chunk_size_feed_forward,
                                                 self.seq_len_dim,
                                                 attention_output)
        outputs = (layer_output, ) + outputs
        hidden_states = outputs[0]

        language_dict_features['hidden'] = hidden_states

        features_dict = {
            'visual': inputs['visual'],
            'lang': language_dict_features
        }

        return features_dict

    def feed_forward_chunk(self, attention_output: Tensor) -> Tensor:
        """Applies the intermediate and output layers of the BertEncoderLayer
        to a chunk of the input sequence."""
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output
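
# ``BertEncoderLayer`` sketch (requires ``transformers``; the checkpoint
# name is illustrative):
#
#     >>> config = BertConfig.from_pretrained('bert-base-uncased')
#     >>> layer = BertEncoderLayer(config, clamp_min_for_underflow=True,
#     ...                          clamp_max_for_overflow=True)
#     >>> inputs = {'visual': [],
#     ...           'lang': {'hidden': torch.rand(2, 20, config.hidden_size),
#     ...                    'masks': torch.ones(2, 20)}}
#     >>> out = layer(inputs)  # same dict structure, updated 'hidden'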


# The classes below mirror their Huggingface counterparts; the only
# difference is the additional clamp operations.
class BertSelfAttention(nn.Module):
    """BERT self-attention layer from Huggingface transformers.

    Compared to the Huggingface BertSelfAttention, this version only adds
    the clamp.

    Args:
        config (:class:`~transformers.BertConfig`):
            The configuration object that
            contains various parameters for the model.
        clamp_min_for_underflow (bool, optional):
            Whether to clamp the minimum value of the hidden states
            to prevent underflow. Defaults to `False`.
        clamp_max_for_overflow (bool, optional):
            Whether to clamp the maximum value of the hidden states
            to prevent overflow. Defaults to `False`.
    """

    def __init__(self,
                 config: BertConfig,
                 clamp_min_for_underflow: bool = False,
                 clamp_max_for_overflow: bool = False):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and \
                not hasattr(config, 'embedding_size'):
            raise ValueError(f'The hidden size ({config.hidden_size}) is '
                             'not a multiple of the number of attention '
                             f'heads ({config.num_attention_heads})')

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size /
                                       config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * \
            self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = getattr(config,
                                               'position_embedding_type',
                                               'absolute')
        if self.position_embedding_type in ('relative_key',
                                            'relative_key_query'):
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(
                2 * config.max_position_embeddings - 1,
                self.attention_head_size)
        self.clamp_min_for_underflow = clamp_min_for_underflow
        self.clamp_max_for_overflow = clamp_max_for_overflow

        self.is_decoder = config.is_decoder

    def transpose_for_scores(self, x: Tensor) -> Tensor:
        """Reshape (bsz, seq_len, all_head_size) to
        (bsz, num_heads, seq_len, head_size)."""
        new_x_shape = x.size()[:-1] + (self.num_attention_heads,
                                       self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: Tensor,
        attention_mask: Optional[Tensor] = None,
        head_mask: Optional[Tensor] = None,
        encoder_hidden_states: Optional[Tensor] = None,
        encoder_attention_mask: Optional[Tensor] = None,
        past_key_value: Optional[Tuple[Tensor, Tensor]] = None,
        output_attentions: bool = False,
    ) -> Tuple[Tensor, ...]:
        """Perform a forward pass through the BERT self-attention layer."""
        mixed_query_layer = self.query(hidden_states)

        # If this is instantiated as a cross-attention module, the keys and
        # values come from an encoder; the attention mask needs to be such
        # that the encoder's padding tokens are not attended to.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention and past_key_value is not None:
            # reuse k, v from the cross-attention cache
            key_layer = past_key_value[0]
            value_layer = past_key_value[1]
            attention_mask = encoder_attention_mask
        elif is_cross_attention:
            key_layer = self.transpose_for_scores(
                self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(
                self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        query_layer = self.transpose_for_scores(mixed_query_layer)

        if self.is_decoder:
            past_key_value = (key_layer, value_layer)

        # Take the dot product between "query" and "key" to get the raw
        # attention scores.
        attention_scores = torch.matmul(query_layer,
                                        key_layer.transpose(-1, -2))

        if self.position_embedding_type in ('relative_key',
                                            'relative_key_query'):
            seq_length = hidden_states.size()[1]
            position_ids_l = torch.arange(
                seq_length, dtype=torch.long,
                device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(
                seq_length, dtype=torch.long,
                device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r
            positional_embedding = self.distance_embedding(
                distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(
                dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == 'relative_key':
                relative_position_scores = torch.einsum(
                    'bhld,lrd->bhlr', query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == 'relative_key_query':
                relative_position_scores_query = torch.einsum(
                    'bhld,lrd->bhlr', query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum(
                    'bhrd,lrd->bhlr', key_layer, positional_embedding)
                attention_scores = attention_scores + \
                    relative_position_scores_query + \
                    relative_position_scores_key

        attention_scores = attention_scores / math.sqrt(
            self.attention_head_size)

        if self.clamp_min_for_underflow:
            # Clamp to avoid underflow when running in half precision.
            attention_scores = torch.clamp(
                attention_scores, min=-MAX_CLAMP_VALUE)
        if self.clamp_max_for_overflow:
            # Clamp to avoid overflow when running in half precision.
            attention_scores = torch.clamp(
                attention_scores, max=MAX_CLAMP_VALUE)

        if attention_mask is not None:
            # Apply the attention mask precomputed for all layers in the
            # BertModel forward() function.
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        # This is actually dropping out entire tokens to attend to, which
        # might seem a bit unusual, but is taken from the original
        # Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to.
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (
            self.all_head_size, )
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (context_layer,
                   attention_probs) if output_attentions else (context_layer, )

        if self.is_decoder:
            outputs = outputs + (past_key_value, )
        return outputs
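
# ``BertSelfAttention`` returns a tuple; a minimal self-attention call
# (hypothetical config values):
#
#     >>> config = BertConfig(hidden_size=768, num_attention_heads=12)
#     >>> attn = BertSelfAttention(config, clamp_min_for_underflow=True,
#     ...                          clamp_max_for_overflow=True)
#     >>> (context, ) = attn(torch.rand(2, 20, 768))
#     >>> context.shape
#     torch.Size([2, 20, 768])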


class BertAttention(HFBertAttention):
    """BertAttention is made up of self-attention and output layers.

    Compared to the Huggingface BertAttention, this version only adds the
    clamp.

    Args:
        config (:class:`~transformers.BertConfig`):
            The configuration object that
            contains various parameters for the model.
        clamp_min_for_underflow (bool, optional):
            Whether to clamp the minimum value of the hidden states
            to prevent underflow. Defaults to `False`.
        clamp_max_for_overflow (bool, optional):
            Whether to clamp the maximum value of the hidden states
            to prevent overflow. Defaults to `False`.
    """

    def __init__(self,
                 config: BertConfig,
                 clamp_min_for_underflow: bool = False,
                 clamp_max_for_overflow: bool = False):
        super().__init__(config)
        self.self = BertSelfAttention(config, clamp_min_for_underflow,
                                      clamp_max_for_overflow)


class BertIntermediate(HFBertIntermediate):
    """Modified from transformers.models.bert.modeling_bert.BertIntermediate.

    Compared to the Huggingface BertIntermediate, this version only adds the
    clamp.
    """

    def forward(self, hidden_states: Tensor) -> Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = clamp_values(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        hidden_states = clamp_values(hidden_states)
        return hidden_states


class BertOutput(HFBertOutput):
    """Modified from transformers.models.bert.modeling_bert.BertOutput.

    Compared to the Huggingface BertOutput, this version only adds the
    clamp.
    """

    def forward(self, hidden_states: Tensor, input_tensor: Tensor) -> Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = clamp_values(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        hidden_states = clamp_values(hidden_states)
        return hidden_states