Alexandru Gherghescu committed
Commit f5fd4e7
1 Parent(s): a80c40f

Fix modeling and configuration scripts


Replace modeling with an actual implementation, instead of using
PyTorch's built-in Transformer. This should be easier to configure and
hack.

Files changed (2)
  1. configuration_gpt1.py +1 -1
  2. modeling_gpt1.py +188 -36
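
The core of the new modeling file is a hand-rolled multi-head causal attention (GPT1Attention in the modeling_gpt1.py diff below): project Q/K/V, split heads, scale by sqrt(head_dim), add an additive -inf mask above the diagonal, softmax, then recombine heads. As a quick sanity check that this manual pipeline matches PyTorch's fused kernel, one could run a sketch like the following (not part of the commit; it assumes PyTorch >= 2.0 for torch.nn.functional.scaled_dot_product_attention, uses made-up shapes, and omits dropout so the comparison is deterministic):

import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
bs, n_heads, seq_len, head_dim = 2, 4, 16, 32

# per-head tensors, already shaped (batch_size, n_heads, seq_len, head_dim)
q = torch.randn(bs, n_heads, seq_len, head_dim)
k = torch.randn(bs, n_heads, seq_len, head_dim)
v = torch.randn(bs, n_heads, seq_len, head_dim)

# additive causal mask: -inf strictly above the diagonal, 0 elsewhere
# (same construction as the causal_mask buffer registered in GPT1Model)
mask = torch.triu(torch.full((seq_len, seq_len), float('-inf')), diagonal=1)

# manual attention, mirroring GPT1Attention.forward (dropout omitted)
scores = (q @ k.transpose(2, 3)) / math.sqrt(head_dim)
scores = scores + mask
manual = F.softmax(scores, dim=-1) @ v

# PyTorch's fused implementation of the same computation
fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(torch.allclose(manual, fused, atol=1e-5))  # expected: True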
configuration_gpt1.py CHANGED
@@ -1,6 +1,6 @@
 """ GPT1 model configuration """
 
-from transformers import PretrainedConfig
+from transformers.configuration_utils import PretrainedConfig
 
 
 class GPT1Config(PretrainedConfig):
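
The configuration script only changes its import path. For reference, the fields the new modeling code reads off GPT1Config are vocab_size, hidden_size, intermediate_size, num_hidden_layers, num_attention_heads, max_position_embeddings, hidden_act, embd_pdrop, resid_pdrop, attention_dropout and layer_norm_eps. A standalone instantiation might therefore look like the sketch below; the values are illustrative (roughly following the original GPT-1 paper) and it assumes GPT1Config accepts them as keyword arguments, since its full definition is not part of this diff:

from configuration_gpt1 import GPT1Config

# illustrative hyperparameters (assumed, not taken from this commit)
config = GPT1Config(
    vocab_size=40478,            # BPE vocabulary size of the original GPT-1
    hidden_size=768,
    intermediate_size=3072,
    num_hidden_layers=12,
    num_attention_heads=12,
    max_position_embeddings=512,
    hidden_act="gelu",
    embd_pdrop=0.1,
    resid_pdrop=0.1,
    attention_dropout=0.1,
    layer_norm_eps=1e-5,
)
print(config)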
modeling_gpt1.py CHANGED
@@ -1,10 +1,137 @@
+""" PyTorch GPT1 model."""
+
+import math
+
 import torch
 from torch import nn
 from transformers import PreTrainedModel
+from transformers.modeling_outputs import (
+    BaseModelOutput,
+    CausalLMOutput,
+)
+from transformers.activations import ACT2FN
 
 from configuration_gpt1 import GPT1Config
 
 
+class GPT1RMSNorm(nn.Module):
+    def __init__(self, config: GPT1Config):
+        super().__init__()
+        self.config = config
+        self.weight = nn.Parameter(torch.ones(config.hidden_size))
+
+    def _norm(self, x):
+        std = x.pow(2).mean(dim=-1, keepdim=True).sqrt()
+        return x / (std + self.config.layer_norm_eps)
+
+    def forward(self, hidden_state):
+        input_dtype = hidden_state.dtype
+        # compute in float32, not in fp16, since normalization needs to be accurate
+        hidden_state = hidden_state.float()
+        output = self._norm(hidden_state).to(input_dtype)
+        return output * self.weight
+
+
+class GPT1MLP(nn.Module):
+    def __init__(self, config: GPT1Config):
+        super().__init__()
+        self.activation_fn = ACT2FN[config.hidden_act]
+        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
+        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+
+    def forward(self, hidden_state):
+        hidden_state = self.fc1(hidden_state)
+        hidden_state = self.activation_fn(hidden_state)
+        hidden_state = self.fc2(hidden_state)
+        return hidden_state
+
+
+class GPT1Attention(nn.Module):
+    def __init__(self, config: GPT1Config):
+        """
+        Multi-head attention layer.
+        """
+        super().__init__()
+
+        assert config.hidden_size % config.num_attention_heads == 0
+        self.hidden_size = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_dim = self.hidden_size // self.num_heads
+        self.attn_dropout = nn.Dropout(p=config.attention_dropout)
+
+        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim)
+        self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim)
+        self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim)
+        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size)
+
+    def forward(self, hidden_state, attn_mask):
+        bs, seq_len, _ = hidden_state.size()  # (batch_size, seq_len, dim)
+
+        # linearly project the inputs
+        Q = self.q_proj(hidden_state)  # (batch_size, seq_len, n_heads * head_dim)
+        K = self.k_proj(hidden_state)
+        V = self.v_proj(hidden_state)
+
+        # split into n_heads to compute attention
+        queries = Q.view(bs, seq_len, self.num_heads, self.head_dim).transpose(1, 2)  # (batch_size, n_heads, seq_len, head_dim)
+        keys = K.view(bs, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+        values = V.view(bs, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+        # compute attention matmul
+        keys = keys.transpose(2, 3)  # (batch_size, n_heads, head_dim, seq_len)
+        attn_scores = queries @ keys  # (batch_size, n_heads, seq_len, seq_len)
+
+        # scale
+        attn_scores = attn_scores / math.sqrt(self.head_dim)
+
+        # mask
+        if attn_mask is not None:
+            attn_scores = attn_scores + attn_mask
+
+        # softmax (attention probabilities) + dropout
+        attn_probs = nn.functional.softmax(attn_scores, dim=-1, dtype=torch.float32).to(Q.dtype)
+        attn_probs = self.attn_dropout(attn_probs)
+
+        # matmul
+        attn_output = attn_probs @ values  # (batch_size, n_heads, seq_len, head_dim)
+
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(bs, seq_len, self.hidden_size)  # (batch_size, seq_len, n_heads * head_dim)
+
+        # final linear
+        attn_output = self.o_proj(attn_output)
+        return attn_output
+
+
+class GPT1DecoderLayer(nn.Module):
+    def __init__(self, config: GPT1Config):
+        super().__init__()
+        self.attention = GPT1Attention(config)
+        self.mlp = GPT1MLP(config)
+
+        self.attention_norm = GPT1RMSNorm(config)
+        self.mlp_norm = GPT1RMSNorm(config)
+
+        self.res_dropout = nn.Dropout(p=config.resid_pdrop)
+
+    def forward(self, hidden_state, attn_mask):
+        # attention
+        residual = hidden_state
+        hidden_state = self.attention_norm(hidden_state)
+        hidden_state = self.attention(hidden_state, attn_mask)
+        hidden_state = self.res_dropout(hidden_state)
+        hidden_state = residual + hidden_state
+
+        # feed forward fully connected
+        residual = hidden_state
+        hidden_state = self.mlp_norm(hidden_state)
+        hidden_state = self.mlp(hidden_state)
+        hidden_state = self.res_dropout(hidden_state)
+        hidden_state = residual + hidden_state
+
+        return hidden_state
+
+
 class GPT1PreTrainedModel(PreTrainedModel):
     config_class = GPT1Config
     supports_gradient_checkpointing = False
@@ -27,50 +154,55 @@ class GPT1Model(GPT1PreTrainedModel):
         super().__init__(config)
 
         # embeddings
-        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
-        self.embeddings_dropout = nn.Dropout(config.embd_pdrop)
+        self.embs = nn.Embedding(config.vocab_size, config.hidden_size)
+        self.embs_dropout = nn.Dropout(p=config.embd_pdrop)
 
         # positional encoding (learned)
         self.pos_emb = nn.Embedding(config.max_position_embeddings,
                                     config.hidden_size)
 
-        dec_layers = nn.TransformerEncoderLayer(d_model=config.hidden_size,
-                                                nhead=config.num_attention_heads,
-                                                dim_feedforward=config.intermediate_size,
-                                                dropout=config.attention_dropout,
-                                                activation=config.hidden_act,
-                                                layer_norm_eps=config.layer_norm_eps,
-                                                batch_first=True)
+        self.layers = nn.ModuleList(
+            [GPT1DecoderLayer(config) for _ in range(config.num_hidden_layers)]
+        )
+
+        self.norm = GPT1RMSNorm(config)
 
-        self.layers = nn.TransformerEncoder(dec_layers, config.num_hidden_layers)
+        causal_mask = torch.full((1, config.max_position_embeddings, config.max_position_embeddings),
                                 fill_value=float('-inf'))
+        self.register_buffer('causal_mask',
                             torch.triu(causal_mask, diagonal=1),
                             persistent=False)
 
         self.post_init()
 
-    def forward(
-        self,
-        input_ids,
-        attention_mask
-    ):
+    def get_input_embeddings(self):
+        return self.embs
+
+    def set_input_embeddings(self, value):
+        self.embs = value
+
+    def forward(self, input_ids, *args, **kwargs):
        position_ids = torch.arange(input_ids.size()[-1],
                                    dtype=torch.long,
                                    device=input_ids.device)
 
-        input_embeds = self.embeddings(input_ids)  # (bs, seq_len, dim)
+        input_embeds = self.embs(input_ids)  # (bs, seq_len, dim)
         position_embeds = self.pos_emb(position_ids)
-        hidden_state = input_embeds + position_embeds
-
-        hidden_state = self.embeddings_dropout(hidden_state)
-
-        _, seq_len, _ = hidden_state.shape  # (bs, seq_len, dim)
+        hidden_state = self.embs_dropout(input_embeds) + position_embeds
 
-        attention_mask = nn.Transformer.generate_square_subsequent_mask(seq_len).to(hidden_state.device)
+        causal_mask = self.causal_mask[..., :input_ids.size(-1), :input_ids.size(-1)].to(
+            dtype=input_embeds.dtype, device=input_embeds.device)
+        for layer in self.layers:
+            hidden_state = layer(hidden_state, attn_mask=causal_mask)
 
-        output = self.layers(hidden_state, attention_mask, is_causal=True)
+        hidden_state = self.norm(hidden_state)
 
-        return output
+        return BaseModelOutput(
+            last_hidden_state=hidden_state
+        )
 
 
-class GPT1ModelForCausalLM(GPT1PreTrainedModel):
+class GPT1ForCausalLM(GPT1PreTrainedModel):
     _tied_weights_keys = ["lm_head.weight"]
 
     def __init__(self, config: GPT1Config):
@@ -80,15 +212,29 @@ class GPT1ModelForCausalLM(GPT1PreTrainedModel):
 
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
 
+        # initialize weights and apply final processing
         self.post_init()
 
-    def forward(
-        self,
-        input_ids,
-        attention_mask,
-        labels
-    ):
-        output = self.model(input_ids, attention_mask)
+    def get_input_embeddings(self):
+        return self.model.embs
+
+    def set_input_embeddings(self, value):
+        self.model.embs = value
+
+    def get_output_embeddings(self):
+        return self.lm_head
+
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head = new_embeddings
+
+    def get_decoder(self):
+        return self.model
+
+    def set_decoder(self, decoder):
+        self.model = decoder
+
+    def forward(self, input_ids, labels=None, *args, **kwargs):
+        output = self.model(input_ids)[0]
 
         logits = self.lm_head(output).float()
 
@@ -97,9 +243,15 @@ class GPT1ModelForCausalLM(GPT1PreTrainedModel):
             shift_logits = logits[..., :-1, :].contiguous()
             shift_labels = labels[..., 1:].contiguous()
 
-            loss_fct = torch.nn.CrossEntropyLoss()
+            loss_fn = torch.nn.CrossEntropyLoss()
             shift_logits = shift_logits.view(-1, self.config.vocab_size)
             shift_labels = shift_labels.view(-1)
-            loss = loss_fct(shift_logits, shift_labels)
-            return { "loss": loss, "logits": logits }
-        return { "logits": logits }
+            loss = loss_fn(shift_logits, shift_labels)
+
+        return CausalLMOutput(
+            loss=loss,
+            logits=logits
+        )
+
+    def prepare_inputs_for_generation(self, input_ids, *args, **kwargs):
+        return {'input_ids': input_ids}
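
As a quick smoke test of the new classes, a forward pass with random token ids could look like the minimal sketch below. It is not part of the commit: the hyperparameters are assumed (kept small for speed) and, as above, it assumes GPT1Config accepts them as keyword arguments, since its full definition is not shown in this diff.

import torch

from configuration_gpt1 import GPT1Config
from modeling_gpt1 import GPT1ForCausalLM

# assumed, illustrative hyperparameters (kept small for a quick test)
config = GPT1Config(
    vocab_size=1000,
    hidden_size=64,
    intermediate_size=256,
    num_hidden_layers=2,
    num_attention_heads=4,
    max_position_embeddings=128,
    hidden_act="gelu",
    embd_pdrop=0.1,
    resid_pdrop=0.1,
    attention_dropout=0.1,
    layer_norm_eps=1e-5,
)

model = GPT1ForCausalLM(config)
model.eval()  # disable dropout for a deterministic check

input_ids = torch.randint(0, config.vocab_size, (2, 16))  # (batch_size, seq_len)

with torch.no_grad():
    out = model(input_ids, labels=input_ids)

print(out.logits.shape)  # expected: torch.Size([2, 16, 1000])
print(out.loss)          # cross-entropy of the shifted next-token prediction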