# AbLang_light / encoderblocks.py
# Scraped file-page header: uploaded by qilowoq, commit 9fc352f ("Upload AbLang"),
# file size 5.38 kB.
import math
from transformers import PreTrainedModel
from typing import List, Optional, Tuple
from dataclasses import dataclass
import torch
import torch.nn as nn
from fairseq.modules.multihead_attention import MultiheadAttention
from .extra_fns import ACT2FN
@dataclass
class AbRepOutput:
    """
    Container for the output of the AbRep encoder stack.

    Attributes:
        last_hidden_state: Output of the final encoder block.
        all_hidden_states: When requested, a tuple with one entry per encoder
            block (its output hidden states, in block order). The input to the
            first block is not included. ``None`` when not requested.
        attentions: When requested, a tuple with one entry per encoder block
            (that block's attention weights, in block order). ``None`` when
            not requested.
    """

    last_hidden_state: torch.FloatTensor
    # Both tuples are variable-length (one element per encoder block), hence
    # ``Tuple[X, ...]`` rather than ``Tuple[X]``, which would mean a 1-tuple.
    all_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
class EncoderBlocks(PreTrainedModel):
    """
    Sequential stack of ``config.num_hidden_layers`` EncoderBlock modules.
    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        blocks = [EncoderBlock(config) for _ in range(config.num_hidden_layers)]
        self.Layers = nn.ModuleList(blocks)

    def forward(self, hidden_states, attention_mask=None, output_attentions=False, output_hidden_states=False):
        """
        Feed ``hidden_states`` through every encoder block, in order.

        Args:
            hidden_states: Input to the first block.
            attention_mask: Passed unchanged to every block.
            output_attentions: If truthy, collect each block's attention weights.
            output_hidden_states: If truthy, collect each block's output.

        Returns:
            AbRepOutput whose ``last_hidden_state`` is the last block's output.
            ``all_hidden_states`` and ``attentions`` are per-block tuples when
            requested and ``None`` otherwise. The collected hidden states begin
            with the first block's output; the original input is not included.
        """
        collected_states = () if output_hidden_states else None
        collected_attn = () if output_attentions else None

        for block in self.Layers:
            hidden_states, block_attn = block(hidden_states, attention_mask, output_attentions)
            # Accumulators are tuples only when collection was requested.
            if collected_states is not None:
                collected_states += (hidden_states,)
            if collected_attn is not None:
                collected_attn += (block_attn,)

        return AbRepOutput(
            last_hidden_state=hidden_states,
            all_hidden_states=collected_states,
            attentions=collected_attn,
        )
class EncoderBlock(PreTrainedModel):
    """
    One encoder layer: multi-head self-attention followed by IntermediateLayer.

    The attention output goes through dropout, is added back to the block
    input and is normalised with LayerNorm. IntermediateLayer then applies
    its own residual connection and LayerNorm.
    """

    def __init__(self, config):
        super().__init__(config)
        # Registration order is kept as-is: it fixes parameter/state_dict order.
        self.MultiHeadAttention = ThirdMultiHeadAttention(config)
        self.MHADropout = nn.Dropout(config.hidden_dropout_prob)
        self.MHALayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.IntermediateLayer = IntermediateLayer(config)

    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
        """
        Args:
            hidden_states: Block input. Also used as the attention residual.
            attention_mask: Forwarded to the attention sub-layer.
            output_attentions: Whether the attention sub-layer should return weights.

        Returns:
            ``(output, attentions)``: the block output and whatever weights the
            attention sub-layer returned (may be ``None`` when not requested).
        """
        attn_out, attn_weights = self.MultiHeadAttention(
            hidden_states, attention_mask, output_attentions=output_attentions
        )
        # Residual around attention: LayerNorm(dropout(attn) + input).
        normed = self.MHALayerNorm(self.MHADropout(attn_out) + hidden_states)
        # IntermediateLayer adds its own residual internally.
        block_output = self.IntermediateLayer(normed)
        return block_output, attn_weights
class ThirdMultiHeadAttention(PreTrainedModel):
    """
    Self-attention built on fairseq's ``MultiheadAttention``, able to return
    the weights of the individual heads (not averaged).

    Dimensions 0 and 1 of the input are swapped before the fairseq call and
    swapped back afterwards. Presumably this is because fairseq expects
    (time, batch, channel) while callers pass batch-first tensors; confirm
    against the caller.
    """

    def __init__(self, config):
        super().__init__(config)
        self.Attention = MultiheadAttention(
            config.hidden_size,
            config.num_attention_heads,
            dropout=config.attention_probs_dropout_prob,
            self_attention=True,
        )

    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
        """
        Args:
            hidden_states: Attention input; used as query, key and value.
            attention_mask: Passed as fairseq's ``key_padding_mask``.
                NOTE(review): fairseq treats True/non-zero entries as padding to
                ignore; confirm callers pass that convention.
            output_attentions: Requests per-head attention weights.

        Returns:
            ``(attn_output, attn_weights)``. ``attn_output`` is swapped back to
            the input's dimension order. ``attn_weights`` is returned as
            produced by fairseq.
        """
        swapped = hidden_states.transpose(0, 1)
        # static_kv=True is deliberate: because of a fairseq bug, per-head
        # weights are only returned unaveraged when it is set.
        attn_out, attn_weights = self.Attention(
            swapped,
            swapped,
            swapped,
            key_padding_mask=attention_mask,
            static_kv=True,
            need_weights=output_attentions,
            need_head_weights=output_attentions,
        )
        return attn_out.transpose(0, 1), attn_weights
class OldMultiHeadAttention(PreTrainedModel):
    """
    Legacy self-attention built on ``torch.nn.MultiheadAttention``.

    The module is constructed with torch's default (sequence-first) layout,
    so dimensions 0 and 1 are swapped around the call.
    """

    def __init__(self, config):
        super().__init__(config)
        self.Attention = torch.nn.MultiheadAttention(
            config.hidden_size,
            config.num_attention_heads,
            dropout=config.attention_probs_dropout_prob,
        )

    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
        """
        Args:
            hidden_states: Attention input; used as query, key and value.
            attention_mask: Passed as torch's ``key_padding_mask``.
                NOTE(review): torch ignores positions marked True; confirm the
                caller's mask convention.
            output_attentions: Passed as ``need_weights``.

        Returns:
            ``(attention_output, attentions)``. The output is swapped back to
            the input's dimension order. The weights are whatever torch
            returns (``None`` when not requested).
        """
        seq_first = hidden_states.transpose(0, 1)
        attn_out, attn_weights = self.Attention(
            seq_first,
            seq_first,
            seq_first,
            key_padding_mask=attention_mask,
            need_weights=output_attentions,
        )
        return attn_out.transpose(0, 1), attn_weights
class IntermediateLayer(PreTrainedModel):
    """
    Feed-forward sub-layer with a residual connection.

    Projects hidden_size -> intermediate_size, applies ``config.hidden_act``,
    projects back to hidden_size, applies dropout, adds the input back and
    normalises with LayerNorm.
    """

    def __init__(self, config):
        super().__init__(config)
        # Registration order is kept as-is: it fixes parameter/state_dict order.
        self.expand_dense = nn.Linear(config.hidden_size, config.intermediate_size)
        self.intermediate_act_fn = ACT2FN[config.hidden_act]
        self.dense_dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        """Return LayerNorm(dropout(FFN(hidden_states)) + hidden_states)."""
        expanded = self.intermediate_act_fn(self.expand_dense(hidden_states))
        projected = self.dropout(self.dense_dense(expanded))
        return self.LayerNorm(projected + hidden_states)