""" This script is a placeholder for training LLaMA from scratch. Currently, it just trains on the Shakespeare dataset. """ import os import time from functools import partial from typing import Tuple import lightning as L from lightning.fabric.strategies import FSDPStrategy import torch from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy import numpy as np from lit_llama.model import Block, LLaMA, LLaMAConfig from lit_llama.utils import save_model_checkpoint out_dir = "out/training" eval_interval = 2000 eval_iters = 200 log_interval = 1 # compilation fails as it does not support torch.complex64 for RoPE # compile = False # Hyperparameters learning_rate = 6e-4 batch_size = 2 max_iters = 600000 weight_decay = 1e-1 beta1 = 0.9 beta2 = 0.95 grad_clip = 1.0 # For shakespeare, choose smaller block size than vanilla LLaMA block_size = 1024 def main() -> None: auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block}) strategy = FSDPStrategy(auto_wrap_policy=auto_wrap_policy, activation_checkpointing=Block) fabric = L.Fabric(accelerator="cuda", devices=4, precision="bf16-mixed", strategy=strategy) fabric.launch() fabric.seed_everything(1337 + fabric.global_rank) if fabric.global_rank == 0: os.makedirs(out_dir, exist_ok=True) train_data, val_data = load_datasets() config = LLaMAConfig.from_name("7B") config.block_size = block_size config.vocab_size = 100 # from prepare_shakespeare.py with fabric.device: model = LLaMA(config) # if compile: # model = torch.compile(model) model = fabric.setup_module(model) optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2)) optimizer = fabric.setup_optimizers(optimizer) train(fabric, model, optimizer, train_data, val_data) def train( fabric: L.Fabric, model: torch.nn.Module, optimizer: torch.optim.Optimizer, train_data: np.ndarray, val_data: np.ndarray, ) -> None: """The training loop. Loosely based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT. """ iter_num = 0 while True: # TODO: add learning rate scheduling # evaluate the loss on train/val sets and write checkpoints if iter_num > 0 and iter_num % eval_interval == 0: val_loss = validate(fabric, model, val_data) fabric.print(f"step {iter_num}: val loss {val_loss:.4f}") fabric.print(f"Saving checkpoint to {out_dir}") save_model_checkpoint(fabric, model, os.path.join(out_dir, f"iter-{iter_num:06d}-ckpt.pth")) t0 = time.time() input_ids, targets = get_batch( fabric, train_data, block_size=model.config.block_size, # type: ignore[union-attr,arg-type] ) logits = model(input_ids) loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) fabric.backward(loss) # TODO: Gradient clipping # if grad_clip != 0.0: # fabric.clip_gradients(model, optimizer, max_norm=grad_clip) optimizer.step() optimizer.zero_grad() dt = time.time() - t0 if iter_num % log_interval == 0: fabric.print(f"iter {iter_num}: loss {loss.item():.4f}, time: {dt*1000:.2f}ms") iter_num += 1 if iter_num > max_iters: break @torch.no_grad() def validate(fabric: L.Fabric, model: torch.nn.Module, val_data: np.ndarray) -> torch.Tensor: fabric.print("Validating ...") model.eval() losses = torch.zeros(eval_iters) for k in range(eval_iters): input_ids, targets = get_batch( fabric, val_data, block_size=model.config.block_size, # type: ignore[union-attr,arg-type] ) logits = model(input_ids) loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) losses[k] = loss.item() out = losses.mean() model.train() return out def get_batch(fabric: L.Fabric, data: np.ndarray, block_size: int) -> Tuple[torch.Tensor, torch.Tensor]: ix = torch.randint(len(data) - block_size, (batch_size,)) x = torch.stack([torch.from_numpy((data[i : i + block_size]).astype(np.int64)) for i in ix]) y = torch.stack([torch.from_numpy((data[i + 1 : i + 1 + block_size]).astype(np.int64)) for i in ix]) x, y = fabric.to_device((x.pin_memory(), y.pin_memory())) return x, y def load_datasets(data_dir: str = "data/shakespeare") -> Tuple[np.ndarray, np.ndarray]: train_data = np.memmap(os.path.join(data_dir, "train.bin"), dtype=np.uint16, mode="r") val_data = np.memmap(os.path.join(data_dir, "val.bin"), dtype=np.uint16, mode="r") return train_data, val_data if __name__ == "__main__": torch.set_float32_matmul_precision("high") main()