shng2025 committed on
Commit 10f91fd
1 Parent(s): e7ccce4

step 70000

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ log/debug_0.log filter=lfs diff=lfs merge=lfs -text
gptesla_checkpoint_training.py ADDED
@@ -0,0 +1,315 @@
"""
This script lets you resume pre-training a checkpointed model straight from its HF repo,
so that even after a crash you can pick up training from the last saved state. Very convenient!

How to use:
1. git clone the repo
2. git checkout the current run's branch
3. accelerate config, then accelerate launch!
"""

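# The steps above as shell commands (a rough sketch; the repo URL follows project_name below,
# and the branch name is a placeholder for whichever run branch you are resuming):
#   git clone https://huggingface.co/shng2025/gptesla-small && cd gptesla-small
#   git checkout <run-branch>
#   accelerate config
#   accelerate launch gptesla_checkpoint_training.py
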
import os

import datasets, transformers
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, set_seed
from transformers.optimization import get_scheduler
from datasets import load_dataset, DownloadConfig

import torch
from torch.utils.data import IterableDataset
from torch.utils.data.dataloader import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch.optim import AdamW

import logging
import wandb
from huggingface_hub import Repository, create_branch
from accelerate import Accelerator
from argparse import Namespace


# Disable tokenizers parallelism to avoid fork-related warnings
os.environ["TOKENIZERS_PARALLELISM"] = "false"


def save_checkpoint_state():

    dir_name = "./torch_checkpoint"
    os.makedirs(dir_name, exist_ok=True)

    checkpoint = {
        "lr_scheduler": lr_scheduler.state_dict(),
        "completed_steps": completed_steps,
        "run_name": run_name,
        "optimizer": optimizer.state_dict(),
        "run_id": wandb_id
    }
    torch.save(checkpoint, f"torch_checkpoint/latest_checkpoint.pth")


def load_checkpoint_torch(lr_scheduler, completed_steps, run_name, optimizer, wandb_id):

    checkpoint = torch.load(f"torch_checkpoint/latest_checkpoint.pth")
    lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
    completed_steps = checkpoint["completed_steps"]
    run_name = checkpoint["run_name"]
    optimizer.load_state_dict(checkpoint["optimizer"])
    wandb_id = checkpoint["run_id"]

    return lr_scheduler, completed_steps, run_name, optimizer, wandb_id

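# Note: save_checkpoint_state()/load_checkpoint_torch() only persist the training state
# (optimizer, lr scheduler, completed steps, run name, wandb id) to
# torch_checkpoint/latest_checkpoint.pth; the model weights themselves are saved separately
# below via save_pretrained() and pushed to the Hub.
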
class ConstantLengthDataset(IterableDataset):

    def __init__(
        self,
        tokenizer,
        dataset,
        seq_length=1024,
        num_of_sequences=1024,
        chars_per_token=3.6,
    ):
        self.tokenizer = tokenizer
        self.concat_token_id = tokenizer.eos_token_id
        self.dataset = dataset
        self.seq_length = seq_length
        self.input_characters = seq_length * chars_per_token * num_of_sequences

    def __iter__(self):
        iterator = iter(self.dataset)
        more_examples = True
        while more_examples:
            buffer, buffer_len = [], 0
            while True:
                if buffer_len >= self.input_characters:
                    m = f"Buffer full: {buffer_len}>={self.input_characters:.0f}"
                    # print(m)
                    break
                try:
                    m = f"Fill buffer: {buffer_len}<{self.input_characters:.0f}"
                    # print(m)
                    buffer.append(next(iterator)["content"])
                    buffer_len += len(buffer[-1])
                except StopIteration:
                    # iterator = iter(self.dataset)
                    more_examples = False
                    break

            all_token_ids = []
            tokenized_inputs = self.tokenizer(buffer, truncation=False)
            for tokenized_input in tokenized_inputs["input_ids"]:
                all_token_ids.extend(tokenized_input + [self.concat_token_id])

            for i in range(0, len(all_token_ids), self.seq_length):
                input_ids = all_token_ids[i : i + self.seq_length]
                if len(input_ids) == self.seq_length:
                    yield torch.tensor(input_ids)


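# ConstantLengthDataset buffers roughly seq_length * chars_per_token * num_of_sequences characters
# from the streaming dataset, tokenizes the buffer, joins the examples with an EOS token, and
# yields fixed seq_length blocks of token ids (a short tail block is dropped).
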
def continue_logging(project_name, run_id):
    logger = logging.getLogger(__name__)

    dir_name = "./log"
    if not os.path.exists(dir_name):
        os.makedirs(dir_name)
        print(f"Directory '{dir_name}' was created.")
    else:
        print(f"Directory '{dir_name}' already exists.")

    # setting up log directory
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
        handlers=[
            logging.FileHandler(f"log/debug_{accelerator.process_index}.log"),
            logging.StreamHandler(),
        ],
    )

    if accelerator.is_main_process:  # We only want to set up logging once
        # wandb.init(project=project_name, config=args, dir="./../")
        wandb.init(project=project_name, id=run_id, resume="must", config=args, dir='./../')
        run_name = wandb.run.name
        tb_writer = SummaryWriter()
        tb_writer.add_hparams(vars(args), {"0": 0})
        logger.setLevel(logging.INFO)
        datasets.utils.logging.set_verbosity_debug()
        transformers.utils.logging.set_verbosity_info()
    else:
        tb_writer = None
        run_name = ""
        logger.setLevel(logging.ERROR)
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    return logger, tb_writer, run_name


def create_dataloaders(dataset_name):
    train_data = load_dataset(dataset_name + "-train", split="train", streaming=True)
    train_data = train_data.shuffle(buffer_size=args.shuffle_buffer, seed=args.seed)
    valid_data = load_dataset(dataset_name + "-valid", split="validation", streaming=True)

    train_dataset = ConstantLengthDataset(tokenizer, train_data, seq_length=args.seq_length)
    valid_dataset = ConstantLengthDataset(tokenizer, valid_data, seq_length=args.seq_length)

    train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size, num_workers=96)
    eval_dataloader = DataLoader(valid_dataset, batch_size=args.valid_batch_size, num_workers=1)
    return train_dataloader, eval_dataloader


def log_metrics(step, metrics):
    logger.info(f"Step {step}: {metrics}")
    if accelerator.is_main_process:
        wandb.log(metrics)
        [tb_writer.add_scalar(k, v, step) for k, v in metrics.items()]

def get_grouped_params(model, no_decay=["bias", "LayerNorm.weight"]):
    params_with_wd, params_without_wd = [], []
    for n, p in model.named_parameters():
        if any(nd in n for nd in no_decay):
            params_without_wd.append(p)
        else:
            params_with_wd.append(p)
    return [
        {"params": params_with_wd, "weight_decay": args.weight_decay},
        {"params": params_without_wd, "weight_decay": 0.0},
    ]

def evaluate():
    model.eval()
    losses = []
    for step, batch in enumerate(eval_dataloader):
        with torch.no_grad():
            outputs = model(batch, labels=batch)
        loss = outputs.loss.repeat(args.valid_batch_size)
        losses.append(accelerator.gather(loss))
        if args.max_eval_steps > 0 and step >= args.max_eval_steps:
            break
    loss = torch.mean(torch.cat(losses))

    try:
        perplexity = torch.exp(loss)
    except OverflowError:
        perplexity = torch.tensor(float("inf"))

    return loss.item(), perplexity.item()


# Accelerator
accelerator = Accelerator(dispatch_batches=True)
acc_state = {str(k): str(v) for k, v in accelerator.state.__dict__.items()}

project_name = "shng2025/gptesla-small"
dataset_name = "shng2025/gptesla"

# GPTesla - the 111M-param setup is kept in the inline comments; some values are reduced here to lighten the training requirements
config = {
    "train_batch_size": 12,  # 12
    "valid_batch_size": 12,  # 12
    "weight_decay": 0.1,
    "shuffle_buffer": 1000,
    "learning_rate": 5e-4,  # 5e-4
    "lr_scheduler_type": "cosine",
    "num_warmup_steps": 700,  # 2000
    "gradient_accumulation_steps": 1,  # 1
    "max_train_steps": 150000,  # 150000
    "max_eval_steps": 10,
    "seq_length": 1024,
    "seed": 1,
    "save_checkpoint_steps": 10000,
}  # 15000

args = Namespace(**config, **acc_state)
samples_per_step = accelerator.state.num_processes * args.train_batch_size
set_seed(args.seed)


model = AutoModelForCausalLM.from_pretrained("./")  # , gradient_checkpointing=True)
tokenizer = AutoTokenizer.from_pretrained("./")

# Load dataset and dataloader
train_dataloader, eval_dataloader = create_dataloaders(dataset_name)


# Loading torch checkpoint
optimizer = AdamW(get_grouped_params(model), lr=args.learning_rate)
lr_scheduler = get_scheduler(
    name=args.lr_scheduler_type,
    optimizer=optimizer,
    num_warmup_steps=args.num_warmup_steps,
    num_training_steps=args.max_train_steps,
)
completed_steps = 0
run_name = ""
wandb_id = ""
lr_scheduler, completed_steps, run_name, optimizer, wandb_id = load_checkpoint_torch(lr_scheduler, completed_steps, run_name, optimizer, wandb_id)

logger, tb_writer, run_name = continue_logging(project_name.split("/")[1], wandb_id)


# Attach the working directory to the HF repo on this run's branch (for pushing checkpoints)
if accelerator.is_main_process:
    hf_repo = Repository("./", clone_from=project_name, revision=run_name)

def get_lr():
    return optimizer.param_groups[0]["lr"]


# advancing dataloader to correct position
for i, _ in enumerate(train_dataloader):
    if i >= completed_steps:
        break
for i, _ in enumerate(eval_dataloader):
    if i >= (completed_steps // args.save_checkpoint_steps) * args.max_eval_steps:
        break

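# (Streaming datasets cannot be seeked directly, so the loops above simply iterate past the
# batches that were presumably already consumed before the checkpoint was written, letting
# training resume roughly where it left off.)
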
# Prepare everything with our `accelerator` (order of args is not important)
model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader, eval_dataloader
)

# Train model
model.train()
for step, batch in enumerate(train_dataloader, start=completed_steps + 1):
    loss = model(batch, labels=batch).loss
    log_metrics(
        step,
        {
            "lr": get_lr(),
            "samples": step * samples_per_step,
            "steps": completed_steps,
            "loss/train": loss.item(),
        },
    )
    loss = loss / args.gradient_accumulation_steps
    accelerator.backward(loss)
    if step % args.gradient_accumulation_steps == 0:
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        completed_steps += 1
    if step % args.save_checkpoint_steps == 0:
        logger.info("Evaluating and saving model checkpoint")
        eval_loss, perplexity = evaluate()
        log_metrics(step, {"loss/eval": eval_loss, "perplexity": perplexity})
        accelerator.wait_for_everyone()
        unwrapped_model = accelerator.unwrap_model(model)
        if accelerator.is_main_process:
            save_checkpoint_state()
            unwrapped_model.save_pretrained("./")
            hf_repo.push_to_hub(commit_message=f"step {step}")
        model.train()
    if completed_steps >= args.max_train_steps:
        break


# Evaluate and save the last checkpoint
logger.info("Evaluating and saving model after training")
eval_loss, perplexity = evaluate()
log_metrics(step, {"loss/eval": eval_loss, "perplexity": perplexity})
accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model)
if accelerator.is_main_process:
    unwrapped_model.save_pretrained("./")
    hf_repo.push_to_hub(commit_message="final model")
log/debug_0.log CHANGED
The diff for this file is too large to render. See raw diff
 
model.safetensors CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:1e320c0d6e200e8b4fce80cb62d51ba129f03a127dfcaf25bb7fffd31cb5ed45
+ oid sha256:11de17d222f5b1546c373114e7e6d47ed779e41a49951d89c57616611704c0cc
  size 444048000
runs/Jul26_05-42-20_lab/1721972540.4948585/events.out.tfevents.1721972540.lab.4339.1 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2f4c4f60cc7e9fa817fcc42e2cf7e93cf80d5e4dd2009603d9dacfe13d7e1a1d
+ size 1702
runs/Jul26_05-42-20_lab/events.out.tfevents.1721972540.lab.4339.0 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d7f68642110473ee1af15c6854a625af03908dc925bc39787912dee9aed1bf1a
+ size 1840187
torch_checkpoint/latest_checkpoint.pth CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:bbf6589a836ecca8da580bd32cc349eba2b67e6d2bb6f38b1f90bab1da619353
- size 888193914
+ oid sha256:959808934631d8d815e6f7e939714b002883d6bcc4297148768f99716e762d57
+ size 888195962