File size: 8,884 Bytes
f5fd4e7
 
 
 
612c8db
 
 
f5fd4e7
 
 
 
 
612c8db
 
 
 
f5fd4e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
612c8db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f5fd4e7
 
612c8db
 
 
 
 
f5fd4e7
 
 
 
 
612c8db
f5fd4e7
 
 
 
 
612c8db
 
 
f5fd4e7
 
 
 
 
 
 
612c8db
 
 
 
f5fd4e7
612c8db
f5fd4e7
612c8db
f5fd4e7
 
 
 
612c8db
f5fd4e7
612c8db
f5fd4e7
 
 
612c8db
 
f5fd4e7
612c8db
 
 
 
 
 
 
 
 
f5fd4e7
612c8db
 
f5fd4e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
612c8db
 
 
 
 
 
 
 
f5fd4e7
612c8db
 
f5fd4e7
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
""" PyTorch GPT1 model."""

import math

import torch
from torch import nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import (
    BaseModelOutput,
    CausalLMOutput,
)
from transformers.activations import ACT2FN

from configuration_gpt1 import GPT1Config


class GPT1RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension with a learned scale.

    Computes ``x / (rms(x) + eps) * weight`` where ``rms`` is taken over the
    last dimension and ``eps`` is ``config.layer_norm_eps``. ``weight`` has
    ``config.hidden_size`` elements and is initialized to ones.
    """

    def __init__(self, config: GPT1Config):
        super().__init__()
        self.config = config
        self.weight = nn.Parameter(torch.ones(config.hidden_size))

    def _norm(self, x):
        """Divide ``x`` by its root-mean-square over the last dimension.

        NOTE: eps is added to the RMS (after the square root), not to the mean
        square inside the root as in LLaMA-style RMSNorm; kept as-is so existing
        checkpoints behave identically.
        """
        std = x.pow(2).mean(dim=-1, keepdim=True).sqrt()
        return x / (std + self.config.layer_norm_eps)

    def forward(self, hidden_state):
        """Normalize ``hidden_state`` in float32, cast back, then apply ``weight``.

        Args:
            hidden_state: tensor whose last dimension is ``hidden_size``.

        Returns:
            Tensor of the same shape. The dtype follows PyTorch type promotion
            between the input dtype and ``weight``'s dtype.
        """
        input_dtype = hidden_state.dtype
        # compute in float32, not in fp16/bf16, since normalization needs to be accurate
        hidden_state = hidden_state.float()
        # ``Tensor.type_as`` expects a tensor argument, not a dtype; ``.to`` is
        # the correct way to cast back to the caller's dtype.
        output = self._norm(hidden_state).to(input_dtype)
        return output * self.weight


class GPT1MLP(nn.Module):
    """Position-wise feed-forward network: ``fc2(activation(fc1(x)))``.

    ``fc1`` maps ``hidden_size -> intermediate_size`` and ``fc2`` maps back to
    ``hidden_size``. The activation is looked up by ``config.hidden_act``.
    """

    def __init__(self, config: GPT1Config):
        super().__init__()
        # ACT2FN is a name -> activation mapping that instantiates on lookup;
        # it must be indexed, not called.
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_state):
        """Apply the two-layer MLP to the last dimension of ``hidden_state``."""
        hidden_state = self.fc1(hidden_state)
        hidden_state = self.activation_fn(hidden_state)
        hidden_state = self.fc2(hidden_state)
        return hidden_state


class GPT1Attention(nn.Module):
    def __init__(self, config: GPT1Config):
        """
        Multi-head scaled dot-product self-attention layer.

        Args:
            config: model configuration; reads ``hidden_size``,
                ``num_attention_heads`` and ``attention_dropout``.

        Raises:
            ValueError: if ``hidden_size`` is not divisible by
                ``num_attention_heads``.
        """
        super().__init__()

        # Explicit check rather than ``assert``, which is stripped under ``python -O``.
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"hidden_size ({config.hidden_size}) must be divisible by "
                f"num_attention_heads ({config.num_attention_heads})"
            )
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.attn_dropout = nn.Dropout(p=config.attention_dropout)

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim)
        self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim)
        self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size)

    def forward(self, hidden_state, attn_mask):
        """
        Compute multi-head self-attention over ``hidden_state``.

        Args:
            hidden_state: tensor of shape ``(batch_size, seq_len, hidden_size)``
                (required by the ``size()`` unpacking and ``view`` calls below).
            attn_mask: optional additive mask (0 = keep, -inf = block) that
                broadcasts against ``(batch_size, n_heads, seq_len, seq_len)``.
                It may be larger than ``seq_len`` in its last two dimensions
                (e.g. a mask precomputed for the maximum context length); only
                its top-left ``seq_len x seq_len`` block is used. ``None``
                disables masking.

        Returns:
            Tensor of shape ``(batch_size, seq_len, hidden_size)``.
        """
        bs, seq_len, _ = hidden_state.size() # (batch_size, seq_len, dim)

        # linearly project the inputs
        Q = self.q_proj(hidden_state) # (batch_size, seq_len, n_heads * head_dim)
        K = self.k_proj(hidden_state)
        V = self.v_proj(hidden_state)

        # split into n_heads to compute attention
        queries = Q.view(bs, seq_len, self.num_heads, self.head_dim).transpose(1, 2) # (batch_size, n_heads, seq_len, head_dim)
        keys = K.view(bs, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        values = V.view(bs, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # compute attention matmul
        keys = keys.transpose(2, 3) # (batch_size, n_heads, head_dim, seq_len)
        attn_scores = queries @ keys # (batch_size, n_heads, seq_len, seq_len)

        # scale
        attn_scores = attn_scores / math.sqrt(self.head_dim)

        # mask: slice to the current length so a mask built for the maximum
        # context still broadcasts; a no-op when the mask is already seq_len-sized.
        if attn_mask is not None:
            attn_scores = attn_scores + attn_mask[..., :seq_len, :seq_len]

        # softmax in float32 for stability, then back to the activation dtype; + dropout
        attn_probs = nn.functional.softmax(attn_scores, dim=-1, dtype=torch.float32).to(Q.dtype)
        attn_probs = self.attn_dropout(attn_probs)

        # weighted sum of values
        attn_output = attn_probs @ values # (batch_size, n_heads, seq_len, head_dim)

        # merge heads back: (batch_size, seq_len, n_heads * head_dim)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bs, seq_len, self.hidden_size)

        # final linear
        attn_output = self.o_proj(attn_output)
        return attn_output


class GPT1DecoderLayer(nn.Module):
    """Pre-norm transformer decoder block.

    Two residual sub-layers are applied in sequence, each of the form
    ``x + res_dropout(sublayer(norm(x)))``: first self-attention, then the MLP.
    """

    def __init__(self, config: GPT1Config):
        super().__init__()
        # Sub-modules are registered in this order on purpose: weight init and
        # state_dict key order follow registration order.
        self.attention = GPT1Attention(config)
        self.mlp = GPT1MLP(config)

        self.attention_norm = GPT1RMSNorm(config)
        self.mlp_norm = GPT1RMSNorm(config)

        self.res_dropout = nn.Dropout(p=config.resid_pdrop)

    def forward(self, hidden_state, attn_mask):
        """Run the attention and MLP residual sub-layers.

        Args:
            hidden_state: input activations, passed to both sub-layers.
            attn_mask: forwarded unchanged to ``GPT1Attention``.

        Returns:
            The updated activations, same shape as ``hidden_state``.
        """
        attn_branch = self.attention(self.attention_norm(hidden_state), attn_mask)
        after_attn = hidden_state + self.res_dropout(attn_branch)

        mlp_branch = self.mlp(self.mlp_norm(after_attn))
        return after_attn + self.res_dropout(mlp_branch)


class GPT1PreTrainedModel(PreTrainedModel):
    """Base class wiring GPT-1 models into the Hugging Face ``PreTrainedModel`` API.

    Supplies the config class and the per-module weight initialization used
    by ``post_init``.
    """

    config_class = GPT1Config
    supports_gradient_checkpointing = False

    def _init_weights(self, module):
        """Initialize one sub-module's parameters in place.

        ``nn.Linear`` and ``nn.Embedding`` weights are drawn from
        ``N(0, config.initializer_range)``. Linear biases and the embedding row
        at ``padding_idx`` (when set) are zeroed. Other module types are left
        untouched.
        """
        init_std = self.config.initializer_range

        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=init_std)
            if module.bias is not None:
                module.bias.data.zero_()
            return

        if not isinstance(module, nn.Embedding):
            return

        module.weight.data.normal_(mean=0.0, std=init_std)
        pad_row = module.padding_idx
        if pad_row is not None:
            module.weight.data[pad_row].zero_()


class GPT1Model(GPT1PreTrainedModel):
    """Decoder-only GPT-1 style transformer that returns final hidden states.

    Token embeddings plus learned absolute position embeddings are passed
    through ``num_hidden_layers`` decoder layers with a causal mask, followed
    by a final RMSNorm.
    """

    def __init__(self, config: GPT1Config):
        super().__init__(config)

        # embeddings
        self.embs = nn.Embedding(config.vocab_size, config.hidden_size)
        self.embs_dropout = nn.Dropout(p=config.embd_pdrop)

        # positional encoding (learned)
        self.pos_emb = nn.Embedding(config.max_position_embeddings,
                                    config.hidden_size)

        self.layers = nn.ModuleList(
            [GPT1DecoderLayer(config) for _ in range(config.num_hidden_layers)]
        )

        self.norm = GPT1RMSNorm(config)

        # Additive causal mask for the maximum context length, shape
        # (1, max_pos, max_pos): 0 on and below the diagonal, -inf above it
        # (future positions). forward() slices it to the actual sequence length.
        causal_mask = torch.full((1, config.max_position_embeddings, config.max_position_embeddings),
                                 fill_value=float('-inf'))
        self.register_buffer('causal_mask',
                             torch.triu(causal_mask, diagonal=1),
                             persistent=False)

        self.post_init()

    def get_input_embeddings(self):
        return self.embs

    def set_input_embeddings(self, value):
        self.embs = value

    def forward(self, input_ids, *args, **kwargs):
        """
        Encode ``input_ids`` into final hidden states.

        Args:
            input_ids: token ids; the last dimension is the sequence length.
                Presumably ``(batch_size, seq_len)``, since the attention layers
                unpack a 3-D hidden state (TODO confirm for unbatched input).
            *args, **kwargs: accepted for API compatibility and ignored.

        Returns:
            ``BaseModelOutput`` whose ``last_hidden_state`` is the normalized
            output of the last decoder layer.

        Raises:
            ValueError: if the sequence is longer than
                ``config.max_position_embeddings``.
        """
        seq_len = input_ids.size(-1)
        max_len = self.config.max_position_embeddings
        if seq_len > max_len:
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_position_embeddings ({max_len})"
            )

        position_ids = torch.arange(seq_len,
                                    dtype=torch.long,
                                    device=input_ids.device)

        input_embeds = self.embs(input_ids) # (bs, seq_len, dim)
        position_embeds = self.pos_emb(position_ids)
        # Embedding dropout applies to the summed token + position embedding,
        # as in the original GPT-1 model.
        hidden_state = self.embs_dropout(input_embeds + position_embeds)

        # The buffer covers the maximum context; slice it to the current length
        # so it broadcasts against (bs, n_heads, seq_len, seq_len) scores.
        causal_mask = self.causal_mask[:, :seq_len, :seq_len].to(dtype=input_embeds.dtype,
                                                                 device=input_embeds.device)
        for layer in self.layers:
            hidden_state = layer(hidden_state, attn_mask=causal_mask)

        hidden_state = self.norm(hidden_state)

        return BaseModelOutput(
            last_hidden_state=hidden_state
        )


class GPT1ForCausalLM(GPT1PreTrainedModel):
    """GPT-1 base model with a linear language-modeling head over the vocabulary."""

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: GPT1Config):
        super().__init__(config)
        self.model = GPT1Model(config)
        self.vocab_size = config.vocab_size

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)

        # initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embs

    def set_input_embeddings(self, value):
        self.model.embs = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def get_decoder(self):
        return self.model

    def set_decoder(self, decoder):
        self.model = decoder

    def forward(self, input_ids, labels=None, *args, **kwargs):
        """
        Compute next-token logits and, optionally, the language-modeling loss.

        Args:
            input_ids: token ids passed to the base model.
            labels: optional target ids aligned with ``input_ids``. They are
                shifted internally so position ``t`` predicts token ``t + 1``.
                Entries equal to -100 are ignored (``CrossEntropyLoss`` default).
            *args, **kwargs: accepted for API compatibility and ignored.

        Returns:
            ``CausalLMOutput`` with float32 ``logits`` of shape
            ``(*input_ids.shape, vocab_size)`` and ``loss`` (``None`` when
            ``labels`` is not given).
        """
        output = self.model(input_ids)

        # The base model returns a BaseModelOutput container; the LM head must be
        # applied to its hidden-state tensor, not to the container itself.
        logits = self.lm_head(output.last_hidden_state).float()

        loss = None
        if labels is not None:
            # Drop the last prediction and the first label so logits[t] scores labels[t + 1].
            shift_logits = logits[..., :-1, :].contiguous()
            # Align devices in case labels were created elsewhere (e.g. on CPU).
            shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)

            loss_fn = torch.nn.CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            loss = loss_fn(shift_logits, shift_labels)

        return CausalLMOutput(
            loss=loss,
            logits=logits
        )

    def prepare_inputs_for_generation(self, input_ids, *args, **kwargs):
        # The model keeps no key/value cache, so every generation step re-feeds
        # the full input_ids; other generation kwargs are ignored.
        return { 'input_ids': input_ids }