import copy

import torch
from torch import nn
import torch.nn.functional as F

from motion.model.layer_norm_fp16 import LayerNorm, RMSNorm

class SwiGLU(nn.Module):
    '''
    SwiGLU feed-forward block following the structure used in LLaMA:
    out = w2(silu(w1(x)) * w3(x)).
    '''
    def __init__(self, dim, hidden_dim, multiple_of=256):
        super().__init__()
        # Scale the hidden width by 2/3 and round it up to a multiple of
        # `multiple_of`, as in the LLaMA feed-forward block.
        hidden_dim = int(2 * hidden_dim / 3)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
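
# Worked example of the SwiGLU hidden-width rounding above (illustrative
# numbers, not taken from the original configuration): hidden_dim=2048 ->
# int(2 * 2048 / 3) = 1365 -> rounded up to the next multiple of 256 -> 1536
# hidden units.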

def _get_activation_fn(activation: str):
    if activation.lower() == "relu":
        return F.relu
    elif activation.lower() == "gelu":
        return F.gelu

    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))

def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

class RefinedLayer(nn.Module):
    __constants__ = ['batch_first', 'norm_first']

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation=F.relu, layer_norm_eps=1e-5, device=None, dtype=None,
                 max_seq_len=196, position_type="static", word_tokens=False,
                 norm_type="rmsnorm", attention_type="torch"):
        factory_kwargs = {'device': device, 'dtype': dtype, 'bias': False}
        super().__init__()
        if norm_type.lower() == "rmsnorm":
            Norm = RMSNorm
        elif norm_type.lower() == "layer":
            Norm = LayerNorm
        else:
            raise ValueError(f"norm_type should be 'rmsnorm' or 'layer', not {norm_type}")

        self.attention_type = attention_type
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout,
                                               batch_first=False, **factory_kwargs)

        if word_tokens:
            # Optional cross-attention over word tokens (e.g. text embeddings).
            self.cross_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout,
                                                    batch_first=False, **factory_kwargs)
            self.norm3 = Norm(d_model, layer_norm_eps)
            self.dropout3 = nn.Dropout(dropout)
        self.word_tokens = word_tokens

        self.norm1 = Norm(d_model, layer_norm_eps)
        self.norm2 = Norm(d_model, layer_norm_eps)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        # Feed-forward network: a classic two-layer MLP when `activation` is a
        # callable or the string "relu"/"gelu", or a SwiGLU block when
        # `activation` is the string "swiglu".
        if isinstance(activation, str) and activation.lower() == "swiglu":
            self.ffn = SwiGLU(d_model, dim_feedforward)
        else:
            if isinstance(activation, str):
                # Legacy string support for the activation function.
                activation = _get_activation_fn(activation)
            self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
            self.dropout = nn.Dropout(dropout)
            self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)
            self.ffn = self._ff_block

        self.activation = activation

    def forward(
            self,
            src,
            word_tokens=None,
            src_mask=None,
            src_key_padding_mask=None):
        # Pre-norm residual blocks: self-attention, optional cross-attention
        # over the word tokens, then the feed-forward network.
        x = src
        x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
        if self.word_tokens:
            x = x + self._csa_block(self.norm3(x), word_tokens)
        x = x + self.dropout2(self.ffn(self.norm2(x)))
        return x

    # self-attention block
    def _sa_block(self, x, attn_mask, key_padding_mask):
        x = self.self_attn(x, x, x,
                           attn_mask=attn_mask,
                           key_padding_mask=key_padding_mask,
                           need_weights=False)[0]
        return self.dropout1(x)

    # cross-attention block (queries from x, keys/values from mem)
    def _csa_block(self, x, mem, attn_mask=None, key_padding_mask=None):
        x = self.cross_attn(x, mem, mem,
                            attn_mask=attn_mask,
                            key_padding_mask=key_padding_mask,
                            need_weights=False)[0]
        return self.dropout3(x)

    # feed-forward block (used for relu/gelu activations)
    def _ff_block(self, x):
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return x

class Refined_Transformer(nn.Module):
    def __init__(self, refined_layer, num_layers):
        super().__init__()
        self.layers = _get_clones(refined_layer, num_layers)
        self.num_layers = num_layers

    def forward(
            self,
            src,
            word_tokens=None,
            src_mask=None,
            src_key_padding_mask=None):
        # Run the input through the stack of RefinedLayer clones.
        output = src
        for mod in self.layers:
            output = mod(output, word_tokens=word_tokens, src_mask=src_mask,
                         src_key_padding_mask=src_key_padding_mask)
        return output
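
# Minimal usage sketch (illustrative only, not part of the original module):
# builds a small Refined_Transformer with cross-attention over word tokens and
# runs a dummy forward pass. Shapes assume batch_first=False, i.e.
# (seq_len, batch, d_model); the dimensions below are made-up example values.
if __name__ == "__main__":
    d_model, nhead, num_layers = 256, 4, 2
    layer = RefinedLayer(d_model, nhead, dim_feedforward=1024, dropout=0.1,
                         activation="swiglu", word_tokens=True,
                         norm_type="rmsnorm")
    encoder = Refined_Transformer(layer, num_layers)

    seq_len, batch, num_words = 196, 2, 20
    src = torch.randn(seq_len, batch, d_model)                     # motion tokens
    words = torch.randn(num_words, batch, d_model)                 # word tokens
    padding_mask = torch.zeros(batch, seq_len, dtype=torch.bool)   # no padding

    out = encoder(src, word_tokens=words, src_key_padding_mask=padding_mask)
    print(out.shape)  # expected: torch.Size([196, 2, 256])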