# train_redpajama.py
import os
import math
import glob
import time
from functools import partial
from pathlib import Path
from typing import Tuple, Optional

import lightning as L
import torch
from lightning.fabric.strategies import FSDPStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torch.utils.data import DataLoader

from lit_llama.model import Block, LLaMA, LLaMAConfig
from lit_llama.packed_dataset import PackedDataset, CombinedDataset
from lit_llama.utils import save_model_checkpoint
out_dir = "out/training"
save_interval = 1000
eval_interval = 1000
eval_iters = 100
log_interval = 1
# compile = False

# Hyperparameters
learning_rate = 6e-4
batch_size = 125
micro_batch_size = 5
max_iters = 600000 # num_epochs * epoch_size // devices
weight_decay = 1e-1
beta1 = 0.9
beta2 = 0.95
grad_clip = 1.0
decay_lr = True
warmup_iters = 2000
lr_decay_iters = max_iters
min_lr = 6e-5

# Data proportions from https://arxiv.org/pdf/2302.13971.pdf Table 1
data_config = [
("arxiv", 2.5),
("book", 4.5),
("c4", 15.0),
("cc", 67.0),
("github", 4.5),
("stackexchange", 2.0),
("wikipedia", 4.5),
]
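# The weights above are relative sampling proportions (percentages); they are
# normalized into probabilities inside create_dataloader().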


def main(
    devices: int = 4,
    train_data_dir: Path = Path("data/lit-redpajama"),
    val_data_dir: Optional[Path] = None,
) -> None:
auto_wrap_policy = partial(
transformer_auto_wrap_policy, transformer_layer_cls={Block}
)
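    # wrap each transformer Block as its own FSDP unit and checkpoint activations
    # at the same granularity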
strategy = FSDPStrategy(
auto_wrap_policy=auto_wrap_policy, activation_checkpointing=Block
)
fabric = L.Fabric(
accelerator="cuda", devices=devices, precision="bf16-mixed", strategy=strategy
)
fabric.launch()
fabric.seed_everything(1337)
if fabric.global_rank == 0:
os.makedirs(out_dir, exist_ok=True)
config = LLaMAConfig.from_name("7B")
train_dataloader, val_dataloader = create_dataloaders(
batch_size=micro_batch_size,
block_size=config.block_size,
fabric=fabric,
train_data_dir=train_data_dir,
val_data_dir=val_data_dir,
seed=1338,
)
    if val_dataloader is None:
        train_dataloader = fabric.setup_dataloaders(train_dataloader)
    else:
        train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader)
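    # create the model weights in bfloat16 directly on the device, then restore
    # the float32 default dtype for everything created afterwards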
with fabric.device:
torch.set_default_dtype(torch.bfloat16)
model = LLaMA(config)
model.apply(model._init_weights)
torch.set_default_dtype(torch.float32)
# if compile:
# model = torch.compile(model)
optimizer = torch.optim.AdamW(
model.parameters(),
lr=learning_rate,
weight_decay=weight_decay,
betas=(beta1, beta2),
)
model, optimizer = fabric.setup(model, optimizer)
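    # batch_size is the global batch across all devices; each device runs
    # grad_accum_steps micro-batches of micro_batch_size per optimizer step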
process_batch_size = batch_size // devices
grad_accum_steps = process_batch_size // micro_batch_size
train(fabric, model, optimizer, train_dataloader, val_dataloader, grad_accum_steps, devices)


def train(
fabric: L.Fabric,
model: torch.nn.Module,
optimizer: torch.optim.Optimizer,
train_dataloader: DataLoader,
val_dataloader: Optional[DataLoader],
grad_accum_steps: int,
devices: int,
) -> None:
"""The training loop.
Loosely based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT.
"""
step_count = 0
step_time = 0.0
tokens = 0
tokens_sec = 0.0
prev_t1 = time.time()
for iter_num, train_data in enumerate(train_dataloader):
t0 = time.time()
# determine and set the learning rate for this iteration
lr = get_lr(iter_num) if decay_lr else learning_rate
for param_group in optimizer.param_groups:
param_group["lr"] = lr
input_ids = train_data[:, 0 : model.config.block_size].contiguous()
targets = train_data[:, 1 : model.config.block_size + 1].contiguous()
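        # while accumulating, skip the cross-device gradient sync; only the final
        # micro-batch of each step pays the all-reduce cost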
is_accumulating = (iter_num + 1) % grad_accum_steps != 0
with fabric.no_backward_sync(model, enabled=is_accumulating):
logits = model(input_ids)
loss = torch.nn.functional.cross_entropy(
logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1
)
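            # scale the loss so the accumulated gradient averages over micro-batches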
fabric.backward(loss / grad_accum_steps)
t1 = time.time()
if not is_accumulating:
fabric.clip_gradients(model, optimizer, max_norm=grad_clip)
optimizer.step()
optimizer.zero_grad()
step_count += 1
t1 = time.time()
if val_dataloader is not None and step_count % eval_interval == 0:
val_loss = validate(fabric, model, val_dataloader)
fabric.print(f"step {iter_num}: val loss {val_loss:.4f}")
fabric.barrier()
fabric.log_dict(
{"iter": iter_num, "val_loss": val_loss, "step": step_count, "lr": lr}
)
if step_count % save_interval == 0:
fabric.print(f"Saving checkpoint to {out_dir}")
save_model_checkpoint(
fabric, model, os.path.join(out_dir, f"iter-{iter_num:06d}-ckpt.pth")
)
dt = t1 - t0
tokens += micro_batch_size * model.config.block_size
step_time += t1 - prev_t1
prev_t1 = t1
if iter_num % log_interval == 0:
tokens_sec_str = f"{tokens / step_time:.0f}" if not is_accumulating else "-"
fabric.log_dict(
{"iter": iter_num, "train_loss": loss, "step": step_count, "lr": lr}
)
fabric.print(
f"iter {iter_num}: loss {loss.item():.4f}, time: {dt*1000:.2f}ms, speed: {tokens_sec_str} toks/s/device"
)
if not is_accumulating:
tokens = 0
step_time = 0.0
if iter_num > max_iters:
break


@torch.no_grad()
def validate(
fabric: L.Fabric, model: torch.nn.Module, val_dataloader: DataLoader
) -> torch.Tensor:
fabric.print("Validating ...")
model.eval()
losses = torch.zeros(eval_iters)
    for k, val_data in enumerate(val_dataloader):
        # `losses` holds eval_iters entries; stop once they are filled
        if k >= eval_iters:
            break
input_ids = val_data[:, 0 : model.config.block_size].contiguous()
targets = val_data[:, 1 : model.config.block_size + 1].contiguous()
logits = model(input_ids)
loss = torch.nn.functional.cross_entropy(
logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1
)
losses[k] = loss.item()
out = losses.mean()
model.train()
return out


def create_dataloader(
batch_size: int,
block_size: int,
data_dir: str,
fabric,
shuffle: bool = True,
seed: int = 12345,
) -> DataLoader:
datasets = []
for prefix, _ in data_config:
filenames = glob.glob(os.path.join(data_dir, prefix + "*"))
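        # each rank streams from its own subset of the packed files, selected via
        # process_rank / num_processes below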
dataset = PackedDataset(
filenames, n_chunks=4, block_size=block_size, shuffle=shuffle, seed=seed,
num_processes=fabric.world_size, process_rank=fabric.global_rank,
)
datasets.append(dataset)
if not datasets:
raise RuntimeError(
f"No data found at {data_dir}. Make sure you ran prepare_redpajama.py to create the dataset."
)
weights = [weight for _, weight in data_config]
sum_weights = sum(weights)
weights = [el / sum_weights for el in weights]
combined_dataset = CombinedDataset(datasets=datasets, seed=seed, weights=weights)
return DataLoader(combined_dataset, batch_size=batch_size, shuffle=False, pin_memory=True)


def create_dataloaders(
batch_size: int,
block_size: int,
fabric,
train_data_dir: str = "data/lit-redpajama",
val_data_dir: Optional[str] = None,
seed: int = 12345,
) -> Tuple[DataLoader, DataLoader]:
    # Increase by one because we also need the next token as the target
effective_block_size = block_size + 1
train_dataloader = create_dataloader(
batch_size=batch_size,
block_size=effective_block_size,
fabric=fabric,
data_dir=train_data_dir,
shuffle=True,
seed=seed,
)
val_dataloader = (
create_dataloader(
batch_size=batch_size,
block_size=effective_block_size,
fabric=fabric,
data_dir=val_data_dir,
shuffle=False,
seed=seed,
)
if val_data_dir
else None
)
return train_dataloader, val_dataloader


# learning rate decay scheduler (cosine with warmup)
def get_lr(it):
# 1) linear warmup for warmup_iters steps
if it < warmup_iters:
return learning_rate * it / warmup_iters
# 2) if it > lr_decay_iters, return min learning rate
if it > lr_decay_iters:
return min_lr
# 3) in between, use cosine decay down to min learning rate
decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
assert 0 <= decay_ratio <= 1
coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
return min_lr + coeff * (learning_rate - min_lr)
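# With the defaults above, the learning rate ramps linearly from 0 to 6e-4 over the
# first 2000 iterations, then follows a cosine down to 6e-5 at iteration 600000.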


if __name__ == "__main__":
# Uncomment this line if you see an error: "Expected is_sm80 to be true, but got false"
# torch.backends.cuda.enable_flash_sdp(False)
torch.set_float32_matmul_precision("high")
from jsonargparse.cli import CLI
CLI(main)
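
# Example launch on a single 4-GPU node (paths are illustrative; point
# --train_data_dir at the output of prepare_redpajama.py):
#   python train_redpajama.py --devices 4 --train_data_dir data/lit-redpajama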