""" To train a GPT from sratch """ import argparse import os import time import math import pickle from contextlib import nullcontext import numpy as np import torch from torch.nn.parallel import DistributedDataParallel from torch.distributed import init_process_group, destroy_process_group import pynvml from model import GPTConfig, GPT parser = argparse.ArgumentParser(description="Load configuration file") parser.add_argument('--config', type=str, required=True, help='Path to the configuration file') args = parser.parse_args() config_path = args.config exec(open(config_path).read()) # -----Load all global variables for logging-------------------------------------------------------- config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))] # exec(open(config_path).read()) # overrides from command line or config file config = {k: globals()[k] for k in config_keys} # ----------------------------------------------------------------------------- def log_and_write(filename, message): with open(filename, 'a') as f: f.write(message + "\n") print(message) log_and_write(log_dir,f'gradient_accumulation_steps: {gradient_accumulation_steps}, batch_size: {batch_size}, \nblock_size: {block_size}, \nn_layer: {n_layer}, n_head: {n_head}, n_embd: {n_embd}, dropout: {dropout}, bias: {bias}, \nlearning_rate: {learning_rate}, max_iters: {max_iters}, \nweight_decay: {weight_decay}, beta1: {beta1}, beta2: {beta2}, grad_clip: {grad_clip}, decay_lr: {decay_lr}, \nwarmup_iters: {warmup_iters}, lr_decay_iters: {lr_decay_iters}, \nmin_lr: {min_lr}, backend: {backend}, device: {device},\n dtype: {dtype}, compile: {compile}') log_and_write(log_dir, f'meta_vocab_size: {meta_vocab_size}') log_and_write(log_dir, f'training data: {data_dir}') # ----------------------------------------------------------------------------- ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run? if ddp: init_process_group(backend=backend) ddp_rank = int(os.environ['RANK']) ddp_local_rank = int(os.environ['LOCAL_RANK']) ddp_world_size = int(os.environ['WORLD_SIZE']) device = f'cuda:{ddp_local_rank}' torch.cuda.set_device(device) master_process = ddp_rank == 0 # this process will do logging, checkpointing etc. 
    seed_offset = ddp_rank
    assert gradient_accumulation_steps % ddp_world_size == 0
    gradient_accumulation_steps //= ddp_world_size
else:
    master_process = True
    seed_offset = 0
    ddp_world_size = 1
tokens_per_iter = gradient_accumulation_steps * ddp_world_size * batch_size * block_size
print('ddp_world_size:', ddp_world_size)
print(f"tokens per iteration will be: {tokens_per_iter:,}")

pynvml.nvmlInit()
def print_gpu_memory_usage():
    handle = pynvml.nvmlDeviceGetHandleByIndex(0)
    info = pynvml.nvmlDeviceGetMemoryInfo(handle)
    print(f"Used: {info.used / 1024**2:.2f}MB/{info.total / 1024**2:.2f}MB ({info.used / info.total * 100:.2f}%)")

if master_process:
    os.makedirs(out_dir, exist_ok=True)
torch.manual_seed(1337 + seed_offset)
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)

# data loader
train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
def get_batch(split):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
    y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
    if device_type == 'cuda':
        # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
        x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
    else:
        x, y = x.to(device), y.to(device)
    return x, y

iter_num = 0
best_val_loss = 1e9

# attempt to derive vocab_size from the dataset
meta_path = os.path.join(data_dir, 'meta.pkl')
if os.path.exists(meta_path):
    with open(meta_path, 'rb') as f:
        meta = pickle.load(f)
    meta_vocab_size = meta['vocab_size']
    print(f"found vocab_size = {meta_vocab_size}")

# model init
model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=block_size,
                  bias=bias, vocab_size=None, dropout=dropout) # start with model_args from command line
if init_from == 'scratch':
    # init a new model from scratch
    print("Initializing a new model from scratch")
    # determine the vocab size we'll use for from-scratch training
    if meta_vocab_size is None:
        print("defaulting to vocab_size of GPT-2 to 50304 (50257 rounded up for efficiency)")
    model_args['vocab_size'] = meta_vocab_size if meta_vocab_size is not None else 50304
    gptconf = GPTConfig(**model_args)
    model = GPT(gptconf)
elif init_from == 'resume':
    print(f"Resuming training from {out_dir}")
    # resume training from a checkpoint.
    # note: ckpt_path is expected to be defined in the config file (checkpoints
    # below are saved as ckpt_{iter_num}.pt in out_dir)
    checkpoint = torch.load(ckpt_path, map_location=device)
    checkpoint_model_args = checkpoint['model_args']
    # force these config attributes to be equal otherwise we can't even resume training
    # the rest of the attributes (e.g. dropout) can stay as desired from command line
    for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']:
        model_args[k] = checkpoint_model_args[k]
    # create the model
    gptconf = GPTConfig(**model_args)
    model = GPT(gptconf)
    state_dict = checkpoint['model']
    unwanted_prefix = '_orig_mod.'
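    # checkpoints saved from a torch.compile()'d model carry an '_orig_mod.' prefix
    # on every state_dict key; strip it so the weights load into the fresh model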
    for k,v in list(state_dict.items()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict)
    iter_num = checkpoint['iter_num']
    best_val_loss = checkpoint['best_val_loss']

# crop down the model block size if desired, using model surgery
if block_size < model.config.block_size:
    model.crop_block_size(block_size)
    model_args['block_size'] = block_size # so that the checkpoint will have the right value
model.to(device)

# initialize a GradScaler. If enabled=False scaler is a no-op
scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))

# optimizer
optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type)
if init_from == 'resume':
    optimizer.load_state_dict(checkpoint['optimizer'])
checkpoint = None # free up memory

# compile the model
if compile:
    print("compiling the model... (takes a ~minute)")
    unoptimized_model = model
    model = torch.compile(model) # requires PyTorch 2.0

# wrap model into DDP container
if ddp:
    model = DistributedDataParallel(model, device_ids=[ddp_local_rank])

# helps estimate an arbitrarily accurate loss over either split using many batches
@torch.no_grad()
def estimate_loss():
    out = {}
    perplexities = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            with ctx:
                logits, loss = model(X, Y)
            losses[k] = loss.item()
        avg_loss = losses.mean()
        out[split] = avg_loss
        perplexities[split] = torch.exp(avg_loss)
    model.train()
    return out, perplexities

# learning rate decay scheduler (cosine with warmup)
def get_lr(it):
    # 1) linear warmup for warmup_iters steps
    if it < warmup_iters:
        return learning_rate * it / warmup_iters
    # 2) if it > lr_decay_iters, return min learning rate
    if it > lr_decay_iters:
        return min_lr
    # 3) in between, use cosine decay down to min learning rate
    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
    return min_lr + coeff * (learning_rate - min_lr)

# training loop
X, Y = get_batch('train') # fetch the very first batch
t0 = time.time()
local_iter_num = 0 # number of iterations in the lifetime of this process
raw_model = model.module if ddp else model # unwrap DDP container if needed
running_mfu = -1.0
while True:

    # determine and set the learning rate for this iteration
    lr = get_lr(iter_num) if decay_lr else learning_rate
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    # evaluate the loss on train/val sets and write checkpoints
    if iter_num % eval_interval == 0 and master_process:
        losses, perplexities = estimate_loss()
        log_and_write(log_dir, f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}, train perplexity: {perplexities['train']:.4f}, val perplexity: {perplexities['val']:.4f}")
        if iter_num % 200 == 0:
            print_gpu_memory_usage()
        if losses['val'] < best_val_loss or always_save_checkpoint:
            best_val_loss = losses['val']
            if iter_num > 0:
                checkpoint = {
                    'model': raw_model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'model_args': model_args,
                    'iter_num': iter_num,
                    'best_val_loss': best_val_loss,
                    'config': config,
                }
                log_and_write(log_dir, f"saving checkpoint to {out_dir}")
                torch.save(checkpoint, os.path.join(out_dir, f'ckpt_{iter_num}.pt'))
    if iter_num == 0 and eval_only:
        break
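    # note: with gradient accumulation the effective batch per optimizer step is
    # batch_size * gradient_accumulation_steps * ddp_world_size sequences
    # (the same quantity reported above as tokens_per_iter / block_size)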
    # forward backward update, with optional gradient accumulation to simulate larger batch size
    # and using the GradScaler if data type is float16
    for micro_step in range(gradient_accumulation_steps):
        if ddp:
            # in DDP training we only need to sync gradients at the last micro step.
            model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1)
        with ctx:
            logits, loss = model(X, Y)
            loss = loss / gradient_accumulation_steps # scale the loss to account for gradient accumulation
        # immediately async prefetch next batch while model is doing the forward pass on the GPU
        X, Y = get_batch('train')
        # backward pass, with gradient scaling if training in fp16
        scaler.scale(loss).backward()
    # clip the gradient
    if grad_clip != 0.0:
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    # step the optimizer and scaler if training in fp16
    scaler.step(optimizer)
    scaler.update()
    # flush the gradients as soon as we can, no need for this memory anymore
    optimizer.zero_grad(set_to_none=True)

    # timing and logging
    t1 = time.time()
    dt = t1 - t0
    t0 = t1
    if iter_num % log_interval == 0 and master_process:
        # get loss as float. note: this is a CPU-GPU sync point
        # scale up to undo the division above, approximating the true total loss (exact would have been a sum)
        lossf = loss.item() * gradient_accumulation_steps
        if local_iter_num >= 5: # let the training loop settle a bit
            mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt)
            running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu
        log_and_write(log_dir, f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, lr {lr}, mfu {running_mfu*100:.2f}%")
    iter_num += 1
    local_iter_num += 1

    # termination conditions
    if iter_num > max_iters:
        break

if ddp:
    destroy_process_group()
pynvml.nvmlShutdown()
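# Example launch commands (illustrative only; the script filename and config path are
# placeholders, adjust them to this repo's actual layout):
#   single GPU : python train.py --config config/my_config.py
#   4-GPU DDP  : torchrun --standalone --nproc_per_node=4 train.py --config config/my_config.py
# torchrun sets the RANK / LOCAL_RANK / WORLD_SIZE environment variables that the
# DDP detection at the top of this script relies on.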