import math
from transformers import PreTrainedModel
from typing import List, Optional, Tuple
from dataclasses import dataclass
import torch
import torch.nn as nn
from fairseq.modules.multihead_attention import MultiheadAttention
from .extra_fns import ACT2FN


@dataclass
class AbRepOutput():
    """
    Dataclass used to store AbRep output.

    last_hidden_state: hidden states from the final EncoderBlock, shape (batch, seq_len, hidden_size).
    all_hidden_states: hidden states after each EncoderBlock; only returned when output_hidden_states=True.
    attentions: attention weights from each EncoderBlock; only returned when output_attentions=True.
    """
    last_hidden_state: torch.FloatTensor
    all_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None

        
class EncoderBlocks(PreTrainedModel):
    """
    Wrapper around a stack of one or more EncoderBlock layers.
    """
    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.Layers = nn.ModuleList([EncoderBlock(config) for _ in range(config.num_hidden_layers)])

    def forward(self, hidden_states, attention_mask=None, output_attentions=False, output_hidden_states=False):
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        for num_block, a_EncoderBlock in enumerate(self.Layers):
            hidden_states, attentions = a_EncoderBlock(hidden_states, attention_mask, output_attentions)
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,) # collect the hidden states after each EncoderBlock
            if output_attentions:
                all_self_attentions = all_self_attentions + (attentions,) # collect the attention weights of each EncoderBlock for analysis
        return AbRepOutput(last_hidden_state=hidden_states, all_hidden_states=all_hidden_states, attentions=all_self_attentions)
    
    
class EncoderBlock(PreTrainedModel):
    """
    Single EncoderBlock.
    An EncoderBlock consists of a MultiHeadAttention and an IntermediateLayer.
    """
    def __init__(self, config):
        super().__init__(config)
        self.MultiHeadAttention = ThirdMultiHeadAttention(config)
        self.MHADropout = nn.Dropout(config.hidden_dropout_prob)
        self.MHALayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.IntermediateLayer = IntermediateLayer(config)
        
    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
        MHAoutput, attentions = self.MultiHeadAttention(hidden_states, attention_mask, output_attentions=output_attentions)
        output = self.MHADropout(MHAoutput)
        output = self.MHALayerNorm(output + hidden_states) # hidden_states are added back for the residual connection
        output = self.IntermediateLayer(output) # the IntermediateLayer applies its own residual connection internally
        return output, attentions
    
    
class ThirdMultiHeadAttention(PreTrainedModel):
    """
    New MultiHeadAttention (based on fairseq's MultiheadAttention) which can return the attention weights of the individual heads.
    """
    def __init__(self, config):
        super().__init__(config)  
        self.Attention = MultiheadAttention(config.hidden_size, config.num_attention_heads, dropout=config.attention_probs_dropout_prob, self_attention=True)
        
    def forward(self, hidden_states, attention_mask=None, output_attentions=False): 
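        # fairseq's MultiheadAttention expects sequence-first inputs, (seq_len, batch, embed_dim),
        # so the batch-first hidden states are transposed on the way in and transposed back on the way out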
        hidden_states = torch.transpose(hidden_states, 0, 1)
        # static_kv is set to True only because of a current fairseq bug: the unaveraged
        # per-head attention weights are not returned unless it is True
        attn_output, attn_weights = self.Attention(hidden_states, hidden_states, hidden_states, key_padding_mask=attention_mask, static_kv=True, 
                                                   need_weights=output_attentions, need_head_weights=output_attentions)
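        # with need_head_weights=True, fairseq returns the per-head attention weights
        # (typically shaped (num_heads, batch, tgt_len, src_len)) instead of the head-averaged ones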
        return torch.transpose(attn_output, 0, 1), attn_weights


class OldMultiHeadAttention(PreTrainedModel):
    """
    Previous MultiHeadAttention, built on torch.nn.MultiheadAttention; contains a scaled dot-product attention and a linear layer.
    """
    def __init__(self, config):
        super().__init__(config)
        self.Attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, config.attention_probs_dropout_prob)

    def forward(self, hidden_states, attention_mask=None, output_attentions=False):  
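        # torch.nn.MultiheadAttention is also sequence-first by default (batch_first=False),
        # and with need_weights=True it returns attention weights averaged over the heads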
        hidden_states = torch.transpose(hidden_states, 0, 1)
        output, attentions = self.Attention(hidden_states, hidden_states, hidden_states, key_padding_mask=attention_mask, need_weights=output_attentions) 
        attention_output = torch.transpose(output, 0, 1)
        return attention_output, attentions
    

class IntermediateLayer(PreTrainedModel):
    """
    Feed-forward layer that expands to the intermediate size and projects back to the hidden size, functioning as a residual block that ends with dropout and LayerNorm.
    """
    def __init__(self, config):
        super().__init__(config)
        self.expand_dense = nn.Linear(config.hidden_size, config.intermediate_size)
        self.intermediate_act_fn = ACT2FN[config.hidden_act]
        
        self.dense_dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
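        # expand to intermediate_size, apply the activation, project back to hidden_size,
        # then dropout + residual connection + LayerNorm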
        output = self.expand_dense(hidden_states)
        output = self.intermediate_act_fn(output)
        output = self.dense_dense(output)
        output = self.dropout(output)
        output = self.LayerNorm(output + hidden_states)
        return output
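

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the model code). It assumes a
# config object exposing the attributes read above (hidden_size, num_attention_heads,
# num_hidden_layers, intermediate_size, hidden_act, hidden_dropout_prob,
# attention_probs_dropout_prob, layer_norm_eps); the hyperparameter values below are
# placeholders, and "gelu" is assumed to be a key in ACT2FN. Run it with `python -m`
# from the package root so the relative import of .extra_fns resolves.
if __name__ == "__main__":
    from transformers import PretrainedConfig

    config = PretrainedConfig(
        hidden_size=768,
        num_attention_heads=12,
        num_hidden_layers=2,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        layer_norm_eps=1e-12,
    )
    encoder = EncoderBlocks(config)
    dummy_input = torch.zeros(2, 160, config.hidden_size)  # (batch, seq_len, hidden_size)
    out = encoder(dummy_input, output_attentions=True, output_hidden_states=True)
    print(out.last_hidden_state.shape)  # torch.Size([2, 160, 768])
    print(len(out.all_hidden_states))   # one entry per EncoderBlock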