step 70000
Browse files- .gitattributes +1 -0
- gptesla_checkpoint_training.py +315 -0
- log/debug_0.log +0 -0
- model.safetensors +1 -1
- runs/Jul26_05-42-20_lab/1721972540.4948585/events.out.tfevents.1721972540.lab.4339.1 +3 -0
- runs/Jul26_05-42-20_lab/events.out.tfevents.1721972540.lab.4339.0 +3 -0
- torch_checkpoint/latest_checkpoint.pth +2 -2
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
log/debug_0.log filter=lfs diff=lfs merge=lfs -text
|
gptesla_checkpoint_training.py
ADDED
@@ -0,0 +1,315 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This file is constructed so that one can easily continue pre training their checkpointed model on HF repo.
|
3 |
+
So that even in the event of a model crash, one can easily continue training based on the current state! Very convenient!
|
4 |
+
|
5 |
+
How to use:
|
6 |
+
1. git clone the repo
|
7 |
+
2. git checkout to current branch
|
8 |
+
3. accelerate config, then accelerate run!
|
9 |
+
"""
|
10 |
+
|
11 |
+
import os
|
12 |
+
|
13 |
+
import datasets, transformers
|
14 |
+
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, set_seed
|
15 |
+
from transformers.optimization import get_scheduler
|
16 |
+
from datasets import load_dataset, DownloadConfig
|
17 |
+
|
18 |
+
import torch
|
19 |
+
from torch.utils.data import IterableDataset
|
20 |
+
from torch.utils.data.dataloader import DataLoader
|
21 |
+
from torch.utils.tensorboard import SummaryWriter
|
22 |
+
from torch.optim import AdamW
|
23 |
+
|
24 |
+
import logging
|
25 |
+
import wandb
|
26 |
+
from huggingface_hub import Repository, create_branch
|
27 |
+
from accelerate import Accelerator
|
28 |
+
from argparse import Namespace
|
29 |
+
|
30 |
+
|
31 |
+
# Set the API token as an environment variable
|
32 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
33 |
+
|
34 |
+
|
35 |
+
def save_checkpoint_state():
|
36 |
+
|
37 |
+
dir_name = "./torch_checkpoint"
|
38 |
+
os.makedirs(dir_name, exist_ok=True)
|
39 |
+
|
40 |
+
checkpoint = {
|
41 |
+
"lr_scheduler": lr_scheduler.state_dict(),
|
42 |
+
"completed_steps": completed_steps,
|
43 |
+
"run_name": run_name,
|
44 |
+
"optimizer": optimizer.state_dict(),
|
45 |
+
"run_id": wandb_id
|
46 |
+
}
|
47 |
+
torch.save(checkpoint, f"torch_checkpoint/latest_checkpoint.pth")
|
48 |
+
|
49 |
+
|
50 |
+
def load_checkpoint_torch(lr_scheduler, completed_steps, run_name, optimizer, wandb_id):
|
51 |
+
|
52 |
+
checkpoint = torch.load(f"torch_checkpoint/latest_checkpoint.pth")
|
53 |
+
lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
|
54 |
+
completed_steps = checkpoint["completed_steps"]
|
55 |
+
run_name = checkpoint["run_name"]
|
56 |
+
optimizer.load_state_dict(checkpoint["optimizer"])
|
57 |
+
wandb_id = checkpoint["run_id"]
|
58 |
+
|
59 |
+
return lr_scheduler, completed_steps, run_name, optimizer, wandb_id
|
60 |
+
|
61 |
+
|
62 |
+
class ConstantLengthDataset(IterableDataset):
|
63 |
+
|
64 |
+
def __init__(
|
65 |
+
self,
|
66 |
+
tokenizer,
|
67 |
+
dataset,
|
68 |
+
seq_length=1024,
|
69 |
+
num_of_sequences=1024,
|
70 |
+
chars_per_token=3.6,
|
71 |
+
):
|
72 |
+
self.tokenizer = tokenizer
|
73 |
+
self.concat_token_id = tokenizer.eos_token_id
|
74 |
+
self.dataset = dataset
|
75 |
+
self.seq_length = seq_length
|
76 |
+
self.input_characters = seq_length * chars_per_token * num_of_sequences
|
77 |
+
|
78 |
+
def __iter__(self):
|
79 |
+
iterator = iter(self.dataset)
|
80 |
+
more_examples = True
|
81 |
+
while more_examples:
|
82 |
+
buffer, buffer_len = [], 0
|
83 |
+
while True:
|
84 |
+
if buffer_len >= self.input_characters:
|
85 |
+
m = f"Buffer full: {buffer_len}>={self.input_characters:.0f}"
|
86 |
+
# print(m)
|
87 |
+
break
|
88 |
+
try:
|
89 |
+
m = f"Fill buffer: {buffer_len}<{self.input_characters:.0f}"
|
90 |
+
# print(m)
|
91 |
+
buffer.append(next(iterator)["content"])
|
92 |
+
buffer_len += len(buffer[-1])
|
93 |
+
except StopIteration:
|
94 |
+
# iterator = iter(self.dataset)
|
95 |
+
more_examples = False
|
96 |
+
break
|
97 |
+
|
98 |
+
all_token_ids = []
|
99 |
+
tokenized_inputs = self.tokenizer(buffer, truncation=False)
|
100 |
+
for tokenized_input in tokenized_inputs["input_ids"]:
|
101 |
+
all_token_ids.extend(tokenized_input + [self.concat_token_id])
|
102 |
+
|
103 |
+
for i in range(0, len(all_token_ids), self.seq_length):
|
104 |
+
input_ids = all_token_ids[i : i + self.seq_length]
|
105 |
+
if len(input_ids) == self.seq_length:
|
106 |
+
yield torch.tensor(input_ids)
|
107 |
+
|
108 |
+
|
109 |
+
def continue_logging(project_name, run_id):
|
110 |
+
logger = logging.getLogger(__name__)
|
111 |
+
|
112 |
+
dir_name = "./log"
|
113 |
+
if not os.path.exists(dir_name):
|
114 |
+
os.makedirs(dir_name)
|
115 |
+
print(f"Directory '{dir_name}' was created.")
|
116 |
+
else:
|
117 |
+
print(f"Directory '{dir_name}' already exists.")
|
118 |
+
|
119 |
+
# setting up log directory
|
120 |
+
logging.basicConfig(
|
121 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
122 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
123 |
+
level=logging.INFO,
|
124 |
+
handlers=[
|
125 |
+
logging.FileHandler(f"log/debug_{accelerator.process_index}.log"),
|
126 |
+
logging.StreamHandler(),
|
127 |
+
],
|
128 |
+
)
|
129 |
+
|
130 |
+
if accelerator.is_main_process: # We only want to set up logging once
|
131 |
+
#wandb.init(project=project_name, config=args, dir="./../")
|
132 |
+
wandb.init(project=project_name, id=run_id, resume="must", config=args, dir='./../')
|
133 |
+
run_name = wandb.run.name
|
134 |
+
tb_writer = SummaryWriter()
|
135 |
+
tb_writer.add_hparams(vars(args), {"0": 0})
|
136 |
+
logger.setLevel(logging.INFO)
|
137 |
+
datasets.utils.logging.set_verbosity_debug()
|
138 |
+
transformers.utils.logging.set_verbosity_info()
|
139 |
+
else:
|
140 |
+
tb_writer = None
|
141 |
+
run_name = ""
|
142 |
+
logger.setLevel(logging.ERROR)
|
143 |
+
datasets.utils.logging.set_verbosity_error()
|
144 |
+
transformers.utils.logging.set_verbosity_error()
|
145 |
+
|
146 |
+
return logger, tb_writer, run_name
|
147 |
+
|
148 |
+
def create_dataloaders(dataset_name):
|
149 |
+
train_data = load_dataset(dataset_name + "-train", split="train", streaming=True)
|
150 |
+
train_data = train_data.shuffle(buffer_size=args.shuffle_buffer, seed=args.seed)
|
151 |
+
valid_data = load_dataset(dataset_name + "-valid", split="validation", streaming=True)
|
152 |
+
|
153 |
+
train_dataset = ConstantLengthDataset(tokenizer, train_data, seq_length=args.seq_length)
|
154 |
+
valid_dataset = ConstantLengthDataset(tokenizer, valid_data, seq_length=args.seq_length)
|
155 |
+
|
156 |
+
train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size, num_workers=96)
|
157 |
+
eval_dataloader = DataLoader(valid_dataset, batch_size=args.valid_batch_size, num_workers=1)
|
158 |
+
return train_dataloader, eval_dataloader
|
159 |
+
|
160 |
+
|
161 |
+
def log_metrics(step, metrics):
|
162 |
+
logger.info(f"Step {step}: {metrics}")
|
163 |
+
if accelerator.is_main_process:
|
164 |
+
wandb.log(metrics)
|
165 |
+
[tb_writer.add_scalar(k, v, step) for k, v in metrics.items()]
|
166 |
+
|
167 |
+
def get_grouped_params(model, no_decay=["bias", "LayerNorm.weight"]):
|
168 |
+
params_with_wd, params_without_wd = [], []
|
169 |
+
for n, p in model.named_parameters():
|
170 |
+
if any(nd in n for nd in no_decay):
|
171 |
+
params_without_wd.append(p)
|
172 |
+
else:
|
173 |
+
params_with_wd.append(p)
|
174 |
+
return [
|
175 |
+
{"params": params_with_wd, "weight_decay": args.weight_decay},
|
176 |
+
{"params": params_without_wd, "weight_decay": 0.0},
|
177 |
+
]
|
178 |
+
|
179 |
+
def evaluate():
|
180 |
+
model.eval()
|
181 |
+
losses = []
|
182 |
+
for step, batch in enumerate(eval_dataloader):
|
183 |
+
with torch.no_grad():
|
184 |
+
outputs = model(batch, labels=batch)
|
185 |
+
loss = outputs.loss.repeat(args.valid_batch_size)
|
186 |
+
losses.append(accelerator.gather(loss))
|
187 |
+
if args.max_eval_steps > 0 and step >= args.max_eval_steps:
|
188 |
+
break
|
189 |
+
loss = torch.mean(torch.cat(losses))
|
190 |
+
|
191 |
+
try:
|
192 |
+
perplexity = torch.exp(loss)
|
193 |
+
except OverflowError:
|
194 |
+
perplexity = torch.tensor(float("inf"))
|
195 |
+
|
196 |
+
return loss.item(), perplexity.item()
|
197 |
+
|
198 |
+
|
199 |
+
# Accelerator
|
200 |
+
accelerator = Accelerator(dispatch_batches=True)
|
201 |
+
acc_state = {str(k): str(v) for k, v in accelerator.state.__dict__.items()}
|
202 |
+
|
203 |
+
project_name = "shng2025/gptesla-small"
|
204 |
+
dataset_name = "shng2025/gptesla"
|
205 |
+
|
206 |
+
# GPTesla - 111M param setup in comment. Modification to make lighter training requirement needed
|
207 |
+
config = {
|
208 |
+
"train_batch_size": 12, # 12
|
209 |
+
"valid_batch_size": 12, # 12
|
210 |
+
"weight_decay": 0.1,
|
211 |
+
"shuffle_buffer": 1000,
|
212 |
+
"learning_rate": 5e-4, # 5e-4
|
213 |
+
"lr_scheduler_type": "cosine",
|
214 |
+
"num_warmup_steps": 700, # 2000
|
215 |
+
"gradient_accumulation_steps": 1, # 1
|
216 |
+
"max_train_steps": 150000, # 150000
|
217 |
+
"max_eval_steps": 10,
|
218 |
+
"seq_length": 1024,
|
219 |
+
"seed": 1,
|
220 |
+
"save_checkpoint_steps": 10000,
|
221 |
+
} # 15000
|
222 |
+
|
223 |
+
args = Namespace(**config, **acc_state)
|
224 |
+
samples_per_step = accelerator.state.num_processes * args.train_batch_size
|
225 |
+
set_seed(args.seed)
|
226 |
+
|
227 |
+
|
228 |
+
model = AutoModelForCausalLM.from_pretrained("./") # , gradient_checkpointing=True)
|
229 |
+
tokenizer = AutoTokenizer.from_pretrained("./")
|
230 |
+
|
231 |
+
# Load dataset and dataloader
|
232 |
+
train_dataloader, eval_dataloader = create_dataloaders(dataset_name)
|
233 |
+
|
234 |
+
|
235 |
+
# Loading torch checkpoint
|
236 |
+
optimizer = AdamW(get_grouped_params(model), lr=args.learning_rate)
|
237 |
+
lr_scheduler = get_scheduler(
|
238 |
+
name=args.lr_scheduler_type,
|
239 |
+
optimizer=optimizer,
|
240 |
+
num_warmup_steps=args.num_warmup_steps,
|
241 |
+
num_training_steps=args.max_train_steps,
|
242 |
+
)
|
243 |
+
completed_steps = 0
|
244 |
+
run_name = ""
|
245 |
+
wandb_id = ""
|
246 |
+
lr_scheduler, completed_steps, run_name, optimizer, wandb_id = load_checkpoint_torch(lr_scheduler, completed_steps, run_name, optimizer, wandb_id)
|
247 |
+
|
248 |
+
logger, tb_writer, run_name = continue_logging(project_name.split("/")[1], wandb_id)
|
249 |
+
|
250 |
+
|
251 |
+
# Load model and tokenizer
|
252 |
+
if accelerator.is_main_process:
|
253 |
+
hf_repo = Repository("./", clone_from=project_name, revision=run_name)
|
254 |
+
|
255 |
+
def get_lr():
|
256 |
+
return optimizer.param_groups[0]["lr"]
|
257 |
+
|
258 |
+
|
259 |
+
# advancing dataloader to correct position
|
260 |
+
for i, _ in enumerate(train_dataloader):
|
261 |
+
if i >= completed_steps:
|
262 |
+
break
|
263 |
+
for i, _ in enumerate(eval_dataloader):
|
264 |
+
if i >= (completed_steps // args.save_checkpoint_steps) * args.max_eval_steps:
|
265 |
+
break
|
266 |
+
|
267 |
+
# Prepare everything with our `accelerator` (order of args is not important)
|
268 |
+
model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
|
269 |
+
model, optimizer, train_dataloader, eval_dataloader
|
270 |
+
)
|
271 |
+
|
272 |
+
# Train model
|
273 |
+
model.train()
|
274 |
+
for step, batch in enumerate(train_dataloader, start=completed_steps + 1):
|
275 |
+
loss = model(batch, labels=batch).loss
|
276 |
+
log_metrics(
|
277 |
+
step,
|
278 |
+
{
|
279 |
+
"lr": get_lr(),
|
280 |
+
"samples": step * samples_per_step,
|
281 |
+
"steps": completed_steps,
|
282 |
+
"loss/train": loss.item(),
|
283 |
+
},
|
284 |
+
)
|
285 |
+
loss = loss / args.gradient_accumulation_steps
|
286 |
+
accelerator.backward(loss)
|
287 |
+
if step % args.gradient_accumulation_steps == 0:
|
288 |
+
optimizer.step()
|
289 |
+
lr_scheduler.step()
|
290 |
+
optimizer.zero_grad()
|
291 |
+
completed_steps += 1
|
292 |
+
if step % args.save_checkpoint_steps == 0:
|
293 |
+
logger.info("Evaluating and saving model checkpoint")
|
294 |
+
eval_loss, perplexity = evaluate()
|
295 |
+
log_metrics(step, {"loss/eval": eval_loss, "perplexity": perplexity})
|
296 |
+
accelerator.wait_for_everyone()
|
297 |
+
unwrapped_model = accelerator.unwrap_model(model)
|
298 |
+
if accelerator.is_main_process:
|
299 |
+
save_checkpoint_state()
|
300 |
+
unwrapped_model.save_pretrained("./")
|
301 |
+
hf_repo.push_to_hub(commit_message=f"step {step}")
|
302 |
+
model.train()
|
303 |
+
if completed_steps >= args.max_train_steps:
|
304 |
+
break
|
305 |
+
|
306 |
+
|
307 |
+
# Evaluate and save the last checkpoint
|
308 |
+
logger.info("Evaluating and saving model after training")
|
309 |
+
eval_loss, perplexity = evaluate()
|
310 |
+
log_metrics(step, {"loss/eval": eval_loss, "perplexity": perplexity})
|
311 |
+
accelerator.wait_for_everyone()
|
312 |
+
unwrapped_model = accelerator.unwrap_model(model)
|
313 |
+
if accelerator.is_main_process:
|
314 |
+
unwrapped_model.save_pretrained("./")
|
315 |
+
hf_repo.push_to_hub(commit_message="final model")
|
log/debug_0.log
CHANGED
The diff for this file is too large to render.
See raw diff
|
|
model.safetensors
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 444048000
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:11de17d222f5b1546c373114e7e6d47ed779e41a49951d89c57616611704c0cc
|
3 |
size 444048000
|
runs/Jul26_05-42-20_lab/1721972540.4948585/events.out.tfevents.1721972540.lab.4339.1
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:2f4c4f60cc7e9fa817fcc42e2cf7e93cf80d5e4dd2009603d9dacfe13d7e1a1d
|
3 |
+
size 1702
|
runs/Jul26_05-42-20_lab/events.out.tfevents.1721972540.lab.4339.0
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d7f68642110473ee1af15c6854a625af03908dc925bc39787912dee9aed1bf1a
|
3 |
+
size 1840187
|
torch_checkpoint/latest_checkpoint.pth
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:959808934631d8d815e6f7e939714b002883d6bcc4297148768f99716e762d57
|
3 |
+
size 888195962
|