Redbuilder1433 committed
Commit 20ca364 · verified · 1 Parent(s): 7ec2393

Upload transformerdecoder.py

Files changed (1):
  transformerdecoder.py  +248  -0
transformerdecoder.py ADDED
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    PretrainedConfig,
    PreTrainedModel,
    Trainer,
    TrainingArguments,
)
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions

# Pre-training corpus; the "text" column contains the raw strings.
pretrain_data = load_dataset("Salesforce/wikitext", "wikitext-103-v1", split="train")

# Reuse the gpt-oss tokenizer. len(tokenizer) is used for the embedding size so that
# any added special tokens are covered, and a pad token is set for batch padding.
tokenizer = AutoTokenizer.from_pretrained("openai/gpt-oss-20b")
vocab_size = len(tokenizer)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
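
# (Editor's addition, illustrative only.) What a single tokenized example looks like;
# the map() call further down applies the same tokenization to the whole corpus.
_enc = tokenizer("The quick brown fox", truncation=True, max_length=512)
print(list(_enc.keys()), len(_enc["input_ids"]))  # typically ['input_ids', 'attention_mask'] and the token count
del _enc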


class PositionalEncoding(nn.Module):
    """Learned positional encoding: a small MLP that maps a scalar position index
    to a d_model-dimensional vector."""

    def __init__(self, d_model):
        super().__init__()
        self.pos_enc = nn.Sequential(
            nn.Linear(1, d_model * 4),
            nn.Tanh(),
            nn.Linear(d_model * 4, d_model),
        )

    def forward(self, seq_len, device):
        pos = torch.arange(seq_len, device=device, dtype=torch.float32).unsqueeze(-1)  # (seq_len, 1)
        pe = self.pos_enc(pos)  # (seq_len, d_model)
        return pe.unsqueeze(0)  # (1, seq_len, d_model)
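
# (Editor's addition, illustrative only.) PositionalEncoding maps positions 0..seq_len-1
# to d_model-dimensional vectors; note that the model classes below use the fixed
# sinusoidal encoding instead, so this module is not part of the training flow.
# Quick, cheap shape check with d_model=8 purely for illustration:
_pe_demo = PositionalEncoding(d_model=8)
print(_pe_demo(seq_len=4, device=torch.device("cpu")).shape)  # torch.Size([1, 4, 8])
del _pe_demo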


class AugmentedPositionGPTConfig(PretrainedConfig):
    model_type = "AugmentedPositionGPT"

    def __init__(
        self,
        vocab_size=vocab_size,
        d_model=128,
        num_heads=2,
        num_layers=1,
        max_position_embeddings=512,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.max_position_embeddings = max_position_embeddings
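
# (Editor's addition, illustrative only.) Because the config subclasses PretrainedConfig,
# the custom hyperparameters serialize with the checkpoint and round-trip through
# save_pretrained / from_pretrained. Tiny example with assumed values:
_cfg = AugmentedPositionGPTConfig(vocab_size=1000, d_model=64)
print(_cfg.model_type, _cfg.d_model, _cfg.num_heads)  # AugmentedPositionGPT 64 2
del _cfg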


class AugmentedPositionGPTBlock(nn.Module):
    """One pre-norm decoder block: causal self-attention followed by a position-wise
    feed-forward network, each wrapped in a residual connection."""

    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.multiheadattention = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.ffn1 = nn.Linear(d_model, 4 * d_model)
        self.ffn2 = nn.Linear(4 * d_model, d_model)

    def forward(self, x, causal_mask=None):
        # Attention(Q, K, V) = softmax(Q @ K.T / sqrt(d_k) + mask) @ V
        residual = x
        normx = self.norm1(x)
        attn_out, _ = self.multiheadattention(normx, normx, normx, attn_mask=causal_mask)  # (batch, seq_len, d_model)
        x = residual + attn_out

        # Feed-forward: d_model -> 4*d_model -> ReLU -> d_model, with its own residual.
        residual2 = x
        j = self.ffn1(self.norm2(x))   # (batch, seq_len, 4*d_model)
        h = self.ffn2(F.relu(j))       # (batch, seq_len, d_model)
        x = residual2 + h
        return x
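
# (Editor's addition, illustrative only.) Shape smoke test of one block on random data,
# assuming d_model=8, num_heads=2, seq_len=4; the mask is the same upper-triangular
# -inf mask the full model builds in causal_mask() below.
_blk = AugmentedPositionGPTBlock(d_model=8, num_heads=2)
_x = torch.randn(2, 4, 8)
_mask = nn.Transformer.generate_square_subsequent_mask(4)
print(_blk(_x, causal_mask=_mask).shape)  # torch.Size([2, 4, 8])
del _blk, _x, _mask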


class AugmentedPositionGPT(PreTrainedModel):
    """Decoder-only transformer backbone: token embedding plus sinusoidal positional
    encoding, a stack of causal blocks, and a final LayerNorm."""

    config_class = AugmentedPositionGPTConfig

    def __init__(self, config):
        super().__init__(config)
        self.vocab_size = config.vocab_size
        self.d_model = config.d_model
        self.num_heads = config.num_heads
        self.num_layers = config.num_layers
        self.max_position_embeddings = config.max_position_embeddings
        self.output_embedding = nn.Embedding(self.vocab_size, self.d_model)
        self.blocks = nn.ModuleList(
            [AugmentedPositionGPTBlock(self.d_model, self.num_heads) for _ in range(self.num_layers)]
        )
        self.ln_f = nn.LayerNorm(self.d_model)
        self.register_buffer(
            "position_ids",
            torch.arange(self.max_position_embeddings).unsqueeze(0),  # (1, max_position_embeddings)
            persistent=False,
        )
        self.post_init()

    def causal_mask(self, seq_len, device):
        # (seq_len, seq_len) float mask: 0 on and below the diagonal, -inf above it.
        return nn.Transformer.generate_square_subsequent_mask(seq_len, device=device)

    def positional_encoding(self, seq_len, device):
        # Sinusoidal encoding:
        #   PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
        #   PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
        d_model = self.d_model
        a = 10000
        i = torch.arange(0, d_model, 2, device=device, dtype=torch.float32)                # (d_model/2,)
        div_term = a ** (i / d_model)                                                      # (d_model/2,)
        position = torch.arange(seq_len, device=device, dtype=torch.float32).unsqueeze(1)  # (seq_len, 1)
        angles = position / div_term                                                       # (seq_len, d_model/2)
        pe = torch.zeros(seq_len, d_model, device=device, dtype=torch.float32)
        pe[:, 0::2] = torch.sin(angles)
        pe[:, 1::2] = torch.cos(angles)
        return pe.unsqueeze(0)  # (1, seq_len, d_model)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        input_embeds=None,
        output_hidden_states=False,
        return_dict=True,
    ):
        if input_ids is not None and input_embeds is not None:
            raise ValueError("You can't specify both input_ids and input_embeds")
        if input_embeds is None:
            # (Optional debug check: verify all token ids fall in [0, vocab_size) here.)
            input_embeds = self.output_embedding(input_ids)  # (batch, seq_len, d_model)
        batch, seq_len, _ = input_embeds.shape
        device = input_embeds.device

        # Token embeddings plus positional encoding.
        x = input_embeds + self.positional_encoding(seq_len, device=device)  # (batch, seq_len, d_model)
        causal_mask = self.causal_mask(seq_len, device)

        # Note: attention_mask is accepted for Trainer compatibility but padding is not
        # masked inside the attention; padded label positions are ignored by the loss instead.
        all_hidden_states = [] if output_hidden_states else None
        for block in self.blocks:
            if output_hidden_states:
                all_hidden_states.append(x)
            x = block(x, causal_mask=causal_mask)
        x = self.ln_f(x)

        if not return_dict:
            return (x, all_hidden_states)
        return {"last_hidden_state": x, "hidden_states": all_hidden_states}


class AugmentedPositionGPTForCausalLM(PreTrainedModel):
    """Causal-LM wrapper: the decoder backbone plus a linear LM head whose weights are
    tied to the token embedding."""

    config_class = AugmentedPositionGPTConfig

    def __init__(self, config):
        super().__init__(config)
        self.transformerdecoder = AugmentedPositionGPT(config)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        # Tie the LM head to the input embedding and register the tie so that
        # save/load and weight re-tying keep the two in sync.
        self.lm_head.weight = self.transformerdecoder.output_embedding.weight
        self._dynamic_tied_weights_keys = {
            "lm_head.weight": "transformerdecoder.output_embedding.weight"
        }
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        input_embeds=None,
        labels=None,
        output_hidden_states=False,
        return_dict=True,
    ):
        outputs = self.transformerdecoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            input_embeds=input_embeds,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        hidden_states = outputs["last_hidden_state"]  # (batch, seq_len, d_model)
        logits = self.lm_head(hidden_states)          # (batch, seq_len, vocab_size)

        loss = None
        if labels is not None:
            # Standard causal-LM objective: predict token t+1 from positions <= t,
            # so shift logits and labels by one position before the cross-entropy.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()  # default ignore_index=-100 skips padded positions
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )

        if not return_dict:
            output = (logits,)
            if output_hidden_states:
                output += (outputs["hidden_states"],)
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=logits,
            hidden_states=outputs["hidden_states"],
            attentions=None,
            cross_attentions=None,
        )
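
# (Editor's addition, illustrative only.) The wrapper does not wire up generate(), but
# greedy decoding can be done by hand: feed the running sequence back in and take the
# argmax at the last position. Tiny untrained config and made-up prompt ids:
def _greedy_decode(lm, prompt_ids, max_new_tokens=5):
    ids = prompt_ids
    for _ in range(max_new_tokens):
        with torch.no_grad():
            logits = lm(input_ids=ids).logits  # (batch, seq_len, vocab_size)
        next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        ids = torch.cat([ids, next_id], dim=-1)
    return ids

_tiny_lm = AugmentedPositionGPTForCausalLM(AugmentedPositionGPTConfig(vocab_size=100, d_model=16, num_heads=2, num_layers=1))
print(_greedy_decode(_tiny_lm, torch.tensor([[1, 2, 3]])).shape)  # torch.Size([1, 8])
del _tiny_lm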


# Instantiate the model from the config and set up causal-LM batching.
config = AugmentedPositionGPTConfig(vocab_size=vocab_size)
collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
model = AugmentedPositionGPTForCausalLM(config)


def tokenize(examples):
    return tokenizer(examples["text"], truncation=True, max_length=512)


pretrain_data_tok = pretrain_data.map(
    tokenize,
    batched=True,
    remove_columns=["text"],  # drop the raw text so the Trainer doesn't pass it to the model
)
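
# (Editor's addition, illustrative only.) With mlm=False the collator pads a batch and
# copies input_ids into labels, setting padded positions to -100 so the loss skips them.
_batch = collator([tokenizer("hello world"), tokenizer("hi")])
print(_batch["input_ids"].shape, _batch["labels"].shape)
del _batch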

training_args = TrainingArguments(
    output_dir="AugmentedGPT/results",
    num_train_epochs=1,
    per_device_eval_batch_size=1,  # note: no eval set is passed; the train batch size stays at the TrainingArguments default
    remove_unused_columns=False,
    gradient_accumulation_steps=8,
    fp16=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=pretrain_data_tok,
    data_collator=collator,
)
trainer.train()
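
# (Editor's addition, hedged.) Persist the run so it can be reloaded later; the output
# path below is an assumption, not something specified in the original commit.
trainer.save_model("AugmentedGPT/results/final")
tokenizer.save_pretrained("AugmentedGPT/results/final")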