uploading preprocessing script, model code, and training script
- model.py +365 -0
- preprocessing.py +113 -0
- train.py +260 -0
model.py
ADDED
@@ -0,0 +1,365 @@
#./experiments/experiment1/model.py
import inspect
import logging

logging.basicConfig(level=logging.DEBUG)

import math
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F

from prereqs.nanoGPT.model import GPTConfig, GPT, MLP

# set up logger
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

def new_rielu(x):
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))

@dataclass
class RotationallyInvariantGPTConfig:
    block_size: int = 512
    vocab_size: int = 50304  # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
    n_layer: int = 6
    n_head: int = 8
    n_embd: int = 768
    dropout: float = 0.0
    bias: bool = True  # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
    rotational_invariance: bool = True  # Set to True to enable the rotationally invariant gate layers

# Models
class RotationInvariantLayerNorm(nn.Module):
    """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """
    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
        self.rotation_gate = nn.Linear(ndim, ndim, bias=False)  # no bias needed for rotation
        self.rotation_gate.weight.data = torch.eye(ndim)

    def forward(self, input, rotation_matrix=None):
        # apply rotation
        if rotation_matrix is not None:
            input = torch.matmul(input, self.rotation_gate(rotation_matrix))

        # normalize
        return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)

class RotationallyInvariantAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        self.gate_q = nn.Linear(config.n_embd // config.n_head, 1, bias=config.bias)
        self.gate_k = nn.Linear(config.n_embd // config.n_head, 1, bias=config.bias)

    def forward(self, x, rotation_matrix=None):
        logging.debug(f'x.size(): {x.size()}')

        B, T, C = x.size()

        logging.debug(f'B: {B}, T: {T}, C: {C}')

        q, k, v = self.c_attn(x).chunk(3, dim=-1)

        logging.debug('Pre-Reshape Q, K, and V')
        logging.debug(f'q.size(): {q.size()}, k.size(): {k.size()}, v.size(): {v.size()}')

        # Reshape q, k, and v to (B, n_head, T, head_dim)
        q = q.view(B, T, self.n_head, C // self.n_head).permute(0, 2, 1, 3)
        k = k.view(B, T, self.n_head, C // self.n_head).permute(0, 2, 1, 3)
        v = v.view(B, T, self.n_head, C // self.n_head).permute(0, 2, 1, 3)

        logging.debug('Post-Reshape Q, K, and V')
        logging.debug(f'q.size(): {q.size()}, k.size(): {k.size()}, v.size(): {v.size()}')

        # Compute per-position gates in (0, 1); gate_q and gate_k have shape
        # (B, n_head, T, 1) and broadcast over the attention score matrices
        gate_q = torch.sigmoid(self.gate_q(q.view(B, self.n_head, T, -1)))
        gate_k = torch.sigmoid(self.gate_k(k.view(B, self.n_head, T, -1)))

        # Traditional dot-product attention
        qk_dot = q @ k.transpose(-2, -1)
        att_dotproduct = qk_dot / math.sqrt(self.n_embd)

        # Rotation invariant attention: negative pairwise Euclidean distance between q and k
        q_norm = torch.sum(q * q, dim=-1, keepdim=True)
        k_norm = torch.sum(k * k, dim=-1, keepdim=True)
        # clamp: squared distances can dip slightly below zero from floating-point error,
        # which would make sqrt produce NaNs
        distances = torch.clamp(q_norm + k_norm.transpose(-2, -1) - 2 * qk_dot, min=0.0)
        att_rotation = -torch.sqrt(distances)
        att_rotation = att_rotation / math.sqrt(self.n_embd)

        # Apply gating to attention scores
        mixed_att = att_dotproduct * gate_q + att_rotation * (torch.ones_like(gate_q) - gate_q)
        att_scores = mixed_att / gate_k

        if rotation_matrix is not None:
            att_scores = att_scores + rotation_matrix

        att_weights = F.softmax(att_scores, dim=-1)
        att_weights = self.attn_dropout(att_weights)  # apply attention dropout
        y = att_weights @ v
        y = y.permute(0, 2, 1, 3).contiguous().view(B, T, C)

        y = self.resid_dropout(self.c_proj(y))
        return y

class RotationallyInvariantMLP(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)
        self.rotation_gate = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)  # Added rotational gate layer
        self.rotation_gate.weight.data = torch.eye(config.n_embd)  # Initial rotation matrix is the identity

    def forward(self, x, rotation_matrix=None):
        x = self.c_fc(x)
        x = F.gelu(x)
        x = self.c_proj(x)
        x = self.dropout(x)

        # Rotational Invariance Part
        if rotation_matrix is not None:
            x = torch.matmul(x, self.rotation_gate(rotation_matrix))

        return x

class RotationallyInvariantBlock(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.ln_1 = RotationInvariantLayerNorm(config.n_embd, bias=config.bias)
        self.attn = RotationallyInvariantAttention(config)
        self.ln_2 = RotationInvariantLayerNorm(config.n_embd, bias=config.bias)
        self.mlp = RotationallyInvariantMLP(config)

    def forward(self, x, rotation_matrix=None):
        x = x + self.attn(self.ln_1(x), rotation_matrix)
        x = x + self.mlp(self.ln_2(x), rotation_matrix)
        return x

class RotationallyInvariantGPT(nn.Module):

    def __init__(self, config):
        super().__init__()
        assert config.vocab_size is not None
        assert config.block_size is not None
        self.config = config

        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            wpe = nn.Embedding(config.block_size, config.n_embd),
            drop = nn.Dropout(config.dropout),
            h = nn.ModuleList([RotationallyInvariantBlock(config) for _ in range(config.n_layer)]),
            ln_f = RotationInvariantLayerNorm(config.n_embd, bias=config.bias),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # with weight tying when using torch.compile() some warnings get generated:
        # "UserWarning: functional_call was passed multiple values for tied weights.
        # This behavior is deprecated and will be an error in future versions"
        # not 100% sure what this is, so far seems to be harmless. TODO investigate
        self.transformer.wte.weight = self.lm_head.weight  # https://paperswithcode.com/method/weight-tying

        # init all weights
        self.apply(self._init_weights)
        # apply special scaled init to the residual projections, per GPT-2 paper
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))

        # report number of parameters
        print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))

    def get_num_params(self, non_embedding=True):
        """
        Return the number of parameters in the model.
        For non-embedding count (default), the position embeddings get subtracted.
        The token embeddings would too, except due to the parameter sharing these
        params are actually used as weights in the final layer, so we include them.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.transformer.wpe.weight.numel()
        return n_params

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        device = idx.device
        b, t = idx.size()
        assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
        pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0)  # shape (1, t)

        # forward the GPT model itself
        tok_emb = self.transformer.wte(idx)  # token embeddings of shape (b, t, n_embd)
        pos_emb = self.transformer.wpe(pos)  # position embeddings of shape (1, t, n_embd)
        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)

        if targets is not None:
            # if we are given some desired targets also calculate the loss
            logits = self.lm_head(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        else:
            # inference-time mini-optimization: only forward the lm_head on the very last position
            logits = self.lm_head(x[:, [-1], :])  # note: using list [-1] to preserve the time dim
            loss = None

        return logits, loss

    def crop_block_size(self, block_size):
        # model surgery to decrease the block size if necessary
        # e.g. we may load the GPT2 pretrained model checkpoint (block size 1024)
        # but want to use a smaller block size for some smaller, simpler model
        assert block_size <= self.config.block_size
        self.config.block_size = block_size
        self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:block_size])
        for block in self.transformer.h:
            if hasattr(block.attn, 'bias'):
                block.attn.bias = block.attn.bias[:,:,:block_size,:block_size]

    @classmethod
    def from_pretrained(cls, model_type, override_args=None):
        assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
        override_args = override_args or {}  # default to empty dict
        # only dropout can be overridden see more notes below
        assert all(k == 'dropout' for k in override_args)
        from transformers import GPT2LMHeadModel
        print("loading weights from pretrained gpt: %s" % model_type)

        # n_layer, n_head and n_embd are determined from model_type
        config_args = {
            'gpt2':         dict(n_layer=12, n_head=12, n_embd=768),   # 124M params
            'gpt2-medium':  dict(n_layer=24, n_head=16, n_embd=1024),  # 350M params
            'gpt2-large':   dict(n_layer=36, n_head=20, n_embd=1280),  # 774M params
            'gpt2-xl':      dict(n_layer=48, n_head=25, n_embd=1600),  # 1558M params
        }[model_type]
        print("forcing vocab_size=50257, block_size=1024, bias=True")
        config_args['vocab_size'] = 50257  # always 50257 for GPT model checkpoints
        config_args['block_size'] = 1024  # always 1024 for GPT model checkpoints
        config_args['bias'] = True  # always True for GPT model checkpoints
        # we can override the dropout rate, if desired
        if 'dropout' in override_args:
            print(f"overriding dropout rate to {override_args['dropout']}")
            config_args['dropout'] = override_args['dropout']
        # create a from-scratch initialized minGPT model
        config = GPTConfig(**config_args)
        model = GPT(config)
        sd = model.state_dict()
        sd_keys = sd.keys()
        sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')]  # discard this mask / buffer, not a param

        # init a huggingface/transformers model
        model_hf = GPT2LMHeadModel.from_pretrained(model_type)
        sd_hf = model_hf.state_dict()

        # copy while ensuring all of the parameters are aligned and match in names and shapes
        sd_keys_hf = sd_hf.keys()
        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')]  # ignore these, just a buffer
        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')]  # same, just the mask (buffer)
        transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
        # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear
        # this means that we have to transpose these weights when we import them
        assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
        for k in sd_keys_hf:
            if any(k.endswith(w) for w in transposed):
                # special treatment for the Conv1D weights we need to transpose
                assert sd_hf[k].shape[::-1] == sd[k].shape
                with torch.no_grad():
                    sd[k].copy_(sd_hf[k].t())
            else:
                # vanilla copy over the other parameters
                assert sd_hf[k].shape == sd[k].shape
                with torch.no_grad():
                    sd[k].copy_(sd_hf[k])

        return model

    def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
        # start with all of the candidate parameters
        param_dict = {pn: p for pn, p in self.named_parameters()}
        # filter out those that do not require grad
        param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
        # create optim groups. Any parameter that is 2D will be weight decayed, otherwise no.
        # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
        optim_groups = [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': nodecay_params, 'weight_decay': 0.0}
        ]
        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)
        print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
        print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
        # Create AdamW optimizer and use the fused version if it is available
        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
        use_fused = fused_available and device_type == 'cuda'
        extra_args = dict(fused=True) if use_fused else dict()
        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
        print(f"using fused AdamW: {use_fused}")

        return optimizer

    def estimate_mfu(self, fwdbwd_per_iter, dt):
        """ estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS """
        # first estimate the number of flops we do per iteration.
        # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311
        N = self.get_num_params()
        cfg = self.config
        L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head, cfg.block_size
        flops_per_token = 6*N + 12*L*H*Q*T
        flops_per_fwdbwd = flops_per_token * T
        flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
        # express our flops throughput as ratio of A100 bfloat16 peak flops
        flops_achieved = flops_per_iter * (1.0/dt)  # per second
        flops_promised = 312e12  # A100 GPU bfloat16 peak flops is 312 TFLOPS
        mfu = flops_achieved / flops_promised
        return mfu

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """
        Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
        the sequence max_new_tokens times, feeding the predictions back into the model each time.
        Most likely you'll want to make sure to be in model.eval() mode of operation for this.
        """
        for _ in range(max_new_tokens):
            # if the sequence context is growing too long we must crop it at block_size
            idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
            # forward the model to get the logits for the index in the sequence
            logits, _ = self(idx_cond)
            # pluck the logits at the final step and scale by desired temperature
            logits = logits[:, -1, :] / temperature
            # optionally crop the logits to only the top k options
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            # apply softmax to convert logits to (normalized) probabilities
            probs = F.softmax(logits, dim=-1)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)
            # append sampled index to the running sequence and continue
            idx = torch.cat((idx, idx_next), dim=1)

        return idx
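
A minimal usage sketch for the model above, assuming model.py and its prereqs.nanoGPT dependency are importable from the working directory; the small config values are illustrative only:

# usage_sketch_model.py (hypothetical helper, not part of the commit)
import torch
from model import RotationallyInvariantGPT, RotationallyInvariantGPTConfig

# tiny illustrative config (assumption: any values satisfying n_embd % n_head == 0 work)
config = RotationallyInvariantGPTConfig(block_size=64, vocab_size=128,
                                        n_layer=2, n_head=4, n_embd=64, dropout=0.0)
model = RotationallyInvariantGPT(config)
model.eval()

idx = torch.randint(0, config.vocab_size, (1, 8))      # batch of 1, 8 token ids
logits, loss = model(idx)                              # no targets -> logits for last position only, loss is None
out = model.generate(idx, max_new_tokens=4, top_k=16)  # autoregressively sample 4 more tokens
print(logits.shape, out.shape)                         # (1, 1, 128) and (1, 12)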
preprocessing.py
ADDED
@@ -0,0 +1,113 @@
#./experiments/experiment1/preprocessing.py
import logging
import os
import sqlite3
from transformers import GPT2TokenizerFast
from datasets import load_dataset

class DatabaseInterface(object):
    def __init__(self, db_file):
        self.db_file = db_file

    def create_table(self, table_name=None):
        conn = sqlite3.connect(self.db_file)
        c = conn.cursor()
        c.execute(
            '''
            CREATE TABLE IF NOT EXISTS plain_text (
                text TEXT,
                split TEXT
            )
            '''
        )
        conn.commit()
        conn.close()

    def write_plain_text(self, example, split):
        conn = sqlite3.connect(self.db_file)
        c = conn.cursor()
        c.execute("INSERT INTO plain_text (text, split) VALUES (?, ?)",
                  (example, split))
        conn.commit()
        conn.close()


def process_and_write(example, writer, split):
    writer.write_plain_text(example, split)


def prepare_data(start_index, end_index, **kwargs):
    data_writer = kwargs['data_writer']
    train_dataset = kwargs['train_dataset']
    val_dataset = kwargs['val_dataset']

    for split, dataset in {'val': val_dataset, 'train': train_dataset}.items():
        subset = dataset[start_index:end_index]  # Select the subset based on start and end indices

        if isinstance(subset, dict):
            subset = subset["text"]  # Extract the "text" part from the subset dictionary

        for example in subset:
            process_and_write(example, data_writer, split)


if __name__ == '__main__':
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO
    )

    # Configs
    batch_size = 32
    num_processes = 4  # number of jobs to run simultaneously

    logging.info("Creating Database Interface")
    db_file_path = os.path.join('data', 'experiment1.db')

    _delete_db = True

    # Check to see if the database file already exists
    if os.path.exists(db_file_path):
        if _delete_db:
            logging.info(f"Database file {db_file_path} already exists. Deleting it.")
            os.remove(db_file_path)
            data_writer = DatabaseInterface(db_file_path)
            data_writer.create_table()
            logging.info("Database table `plain_text` created")
        else:
            logging.info(f"Database file {db_file_path} already exists. Connecting to it.")
            data_writer = DatabaseInterface(db_file_path)
    else:
        data_writer = DatabaseInterface(db_file_path)
        data_writer.create_table()
        logging.info("Database table `plain_text` created")

    # Optional local cache for the openwebtext dataset; None falls back to the default HF cache.
    cache_dir = None
    #cache_dir = os.path.join(
    #    'C:/Users/User/.cache/huggingface/datasets/openwebtext/plain_text',
    #    '1.0.0',
    #    '6f68e85c16ccc770c0dd489f4008852ea9633604995addd0cd76e293aed9e521'
    #)

    dataset = load_dataset(
        "openwebtext",
        cache_dir=cache_dir,
        num_proc=num_processes,
        save_infos=True,
        writer_batch_size=batch_size
    )

    split_dataset = dataset["train"].train_test_split(test_size=0.1, seed=42, shuffle=False)
    train_dataset = split_dataset["train"]
    val_dataset = split_dataset["test"]

    prepare_data(
        start_index=0,
        end_index=1000,
        **{
            'data_writer': data_writer,
            'train_dataset': train_dataset,
            'val_dataset': val_dataset,
        }
    )
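
A small sanity-check sketch, assuming preprocessing.py has already been run and data/experiment1.db exists; it counts the rows written to the plain_text table per split:

# db_sanity_check.py (hypothetical helper, not part of the commit)
import sqlite3

conn = sqlite3.connect('data/experiment1.db')
c = conn.cursor()
for split in ('train', 'val'):
    # count how many examples were written for each split
    c.execute("SELECT COUNT(*) FROM plain_text WHERE split=?", (split,))
    print(split, c.fetchone()[0])
conn.close()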
train.py
ADDED
@@ -0,0 +1,260 @@
#./experiments/experiment1/train.py
import logging
import pickle
import sqlite3
import torch
import torchvision
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import transformers

from model import RotationallyInvariantGPT, RotationallyInvariantGPTConfig
from prereqs.nanoGPT.model import GPTConfig, GPT, MLP
from datasets import load_from_disk
from torch.utils.data import DataLoader

from transformers import GPT2TokenizerFast

from torch.nn.utils.rnn import pad_sequence

def pad_collate(batch):
    # Separating inputs and labels
    inputs = [d['input_ids'] for d in batch]
    labels = [d['labels'] for d in batch]

    # Padding the input sequences
    input_tensor = pad_sequence(inputs, batch_first=True)

    # Padding the labels sequences
    label_tensor = pad_sequence(labels, batch_first=True)

    return {'input_ids': input_tensor, 'labels': label_tensor}

class DatabaseInterface(object):
    def __init__(self, db_file):
        self.db_file = db_file

    def read(self, split):
        conn = sqlite3.connect(self.db_file)
        c = conn.cursor()
        c.execute("SELECT * FROM plain_text WHERE split=?", (split,))
        col_names = [desc[0] for desc in c.description]  # get column names
        results = [dict(zip(col_names, row)) for row in c.fetchall()]  # convert tuples to dictionaries
        conn.close()
        return results


class PlainTextDataset(torch.utils.data.Dataset):
    def __init__(self, plain_text_dataset, tokenizer, device):
        self.plain_text_dataset = plain_text_dataset
        self.tokenizer = tokenizer
        self.device = device

    def __len__(self):
        return len(self.plain_text_dataset)

    def __getitem__(self, idx):
        item = self.plain_text_dataset[idx]
        tokens = self.tokenizer.encode_plus(item["text"], truncation=True, max_length=512, padding="max_length")
        input_ids = tokens["input_ids"]
        attention_mask = tokens["attention_mask"]
        return {
            'input_ids': torch.as_tensor(input_ids[:-1], dtype=torch.long).to(self.device),
            'attention_mask': torch.as_tensor(attention_mask[:-1], dtype=torch.long).to(self.device),
            'labels': torch.as_tensor(input_ids[1:], dtype=torch.long).to(self.device)
        }

def train(model: nn.Module, optimizer: optim.Optimizer, train_loader: DataLoader) -> float:
    model.train()
    running_loss = 0
    for i, batch in enumerate(train_loader):
        inputs, targets = batch['input_ids'].to(device), batch['labels'].to(device)
        optimizer.zero_grad()
        outputs, loss = model(inputs, targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if i % 100 == 0:
            logging.info(f"Batch {i}: Loss={loss.item()}")
    return running_loss / len(train_loader)


def evaluate(model, valid_loader) -> float:
    model.eval()
    running_loss = 0
    with torch.no_grad():
        for i, batch in enumerate(valid_loader):
            inputs, targets = batch['input_ids'].to(device), batch['labels'].to(device)
            # the model returns a (logits, loss) tuple, same as in the training loop
            outputs, loss = model(inputs, targets)
            running_loss += loss.item()
            if i % 100 == 0:
                logging.info(f"Batch {i}: Validation Loss={loss.item()}")
    return running_loss / len(valid_loader)

if __name__ == '__main__':
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO
    )
    logging.info(f"PyTorch version: {torch.__version__}")
    logging.info(f"Torchvision version: {torchvision.__version__}")
    logging.info(f"Transformers version: {transformers.__version__}")
    logging.info(f"CUDA version: {torch.version.cuda}")
    logging.info(f"cuDNN version: {torch.backends.cudnn.version()}")

    logging.info("Clearing cuda cache...")
    torch.cuda.empty_cache()

    logging.info("Setting num_threads to 1...")
    torch.set_num_threads(1)

    # Configs
    d_model = 512
    num_heads = 4
    num_layers = 1
    block_size = 512
    dropout = 0.2
    bias = True
    rotational = True
    batch_size = 32
    eval_batch_size = 64
    epochs = 10
    lr = 0.001

    vocab_size = 50304  # GPT-2 vocab size (50257) padded up to the nearest multiple of 64
    logging.info(f"Vocab size: {vocab_size}")

    logging.info(f'''
    Config:
        d_model={d_model},
        num_heads={num_heads},
        num_layers={num_layers},
        block_size={block_size},
        dropout={dropout}, bias={bias}
    '''
    )
    logging.info(
        f"Training for {epochs} epochs with a learning rate of {lr}..."
    )

    logging.info(f"Batch size: {batch_size}")
    logging.info(f"Eval batch size: {eval_batch_size}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # device = torch.device("cpu")
    logging.info(f"Device: {device}")

    logging.info("Loading tokenizer")
    tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})

    # Query the database for the plain text data
    logging.info("Querying plain text data...")

    db_file_path = "data/experiment1.db"

    plain_text_train = DatabaseInterface(db_file_path).read("train")
    #logging.debug(f"Plain text train: {plain_text_train[:10]}")

    plain_text_val = DatabaseInterface(db_file_path).read("val")
    #logging.debug(f"Plain text val: {plain_text_val[:10]}")

    # Create train/val dataset objects
    train_dataset = PlainTextDataset(plain_text_train, tokenizer, device)
    valid_dataset = PlainTextDataset(plain_text_val, tokenizer, device)

    # DEBUG
    #for idx, item in enumerate(train_dataset):
    #    input_ids = item["input_ids"]
    #    attention_mask = item["attention_mask"]
    #    if input_ids.size(0) == 0:
    #        print(f"Sample index with 0 length: {idx}")
    #        print(f"Input_ids: {input_ids}")
    #        print(f"Attention_mask: {attention_mask}")

    # Calculate the number of batches
    num_train_batches = len(train_dataset) // batch_size
    num_eval_batches = len(valid_dataset) // eval_batch_size

    logging.info(f"Number of train batches: {num_train_batches}")
    logging.info(f"Number of eval batches: {num_eval_batches}")

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=pad_collate
    )

    valid_loader = DataLoader(
        valid_dataset,
        batch_size=eval_batch_size,
        shuffle=False,
        collate_fn=pad_collate
    )

    # gpt_config = GPTConfig(
    #     vocab_size=vocab_size,
    #     n_embd=d_model,
    #     n_head=num_heads,
    #     n_layer=num_layers,
    #     block_size=block_size,
    #     dropout=dropout,
    #     bias=bias
    # )

    rigpt_config = RotationallyInvariantGPTConfig(
        vocab_size=vocab_size,
        n_embd=d_model,
        n_head=num_heads,
        n_layer=num_layers,
        block_size=block_size,
        dropout=dropout,
        bias=bias,
        rotational_invariance=rotational
    )

    logging.info("Creating models...")
    # gpt = GPT(gpt_config).to(device)
    rigpt = RotationallyInvariantGPT(rigpt_config).to(device)

    logging.info("Creating optimizers...")
    # optimizer_gpt = optim.Adam(gpt.parameters(), lr=lr)
    optimizer_rigpt = optim.Adam(rigpt.parameters(), lr=lr)

    logging.info("Training...")
    for model, optimizer, model_name in [
        # (
        #     gpt,
        #     optimizer_gpt,
        #     'GPT'
        # ),
        (
            rigpt,
            optimizer_rigpt,
            'RotationallyInvariantGPT'
        )
    ]:
        print(f"Training {model_name}")
        for epoch in range(1, epochs + 1):
            print(f"Training epoch {epoch}")
            train_loss = train(model, optimizer, train_loader)
            print(f"Validating epoch {epoch}")
            valid_loss = evaluate(model, valid_loader)
            print(
                f'''
                {model_name} -
                Epoch: {epoch},
                Train loss: {train_loss:.3f},
                Validation loss: {valid_loss:.3f}
                '''
            )

    # torch.save(gpt.state_dict(), "gpt.pt")
    torch.save(rigpt.state_dict(), "rigpt.pt")
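
An inference sketch, assuming training has completed and rigpt.pt was saved with the same config values as in train.py above; it reloads the checkpoint and samples from a short prompt:

# inference_sketch.py (hypothetical helper, not part of the commit)
import torch
from transformers import GPT2TokenizerFast
from model import RotationallyInvariantGPT, RotationallyInvariantGPTConfig

# must match the config used when the checkpoint was saved (dropout is irrelevant for loading)
config = RotationallyInvariantGPTConfig(vocab_size=50304, n_embd=512, n_head=4, n_layer=1,
                                        block_size=512, dropout=0.0, bias=True)
model = RotationallyInvariantGPT(config)
model.load_state_dict(torch.load("rigpt.pt", map_location="cpu"))
model.eval()

tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
idx = torch.tensor([tokenizer.encode("The meaning of life is")], dtype=torch.long)
out = model.generate(idx, max_new_tokens=20, temperature=0.8, top_k=50)
print(tokenizer.decode(out[0].tolist()))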