import os
import time
import math
import pickle
from contextlib import nullcontext
import numpy as np
import torch
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group

# from_pretrained is assumed to be exported by mamba_lm alongside MambaLM
# (it is called below when init_from starts with 'state-spaces')
from mamba_lm import MambaLM, MambaLMConfig, from_pretrained
import pyarrow.parquet as pq
import random
from torch.utils.data import Dataset, DataLoader
import glob

# -----------------------------------------------------------------------------
# default config values designed for Mamba model training
# I/O
out_dir = 'out'
eval_interval = 2000
log_interval = 1
eval_iters = 5
eval_only = False
always_save_checkpoint = True
init_from = 'resume'  # 'scratch', 'resume', 'anneal', or Mamba model name
# wandb logging
wandb_log = False
wandb_project = 'mamba'
wandb_run_name = 'mamba_run'  # modify as needed
# data
dataset = 'chess'  # specify your dataset
gradient_accumulation_steps = 5 * 8
batch_size = 12
base_batch_size = batch_size
effective_batch_size = batch_size
max_seq_len = 1024  # a training-only parameter for controlling VRAM
train_file_update_interval = 7
# model
n_layer = 12
d_model = 768
dt_rank = 'auto'
d_state = 16
expand_factor = 2
bias = False
conv_bias = True
pscan = True
vocab_size = 32000
move_num_in_gamestate = True
# optimizer settings
learning_rate = 6e-4
max_iters = 600000
weight_decay = 1e-1
beta1 = 0.9
beta2 = 0.95
grad_clip = 1.0
auto_clip = False
grad_clip_start_size = 100
grad_clip_max_size = 500
grad_clip_percentile = 10
# learning rate decay settings
decay_lr = True
warmup_iters = 2000
lr_decay_iters = 600000
min_lr = 6e-5
# DDP settings
backend = 'nccl'
# system
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dtype = 'bfloat16' if torch.cuda.is_bf16_supported() else 'float32'
compile = False  # set to True if using PyTorch 2.0
# -----------------------------------------------------------------------------
config_keys = [k for k, v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
exec(open('configurator.py').read())  # overrides from command line or config file
config = {k: globals()[k] for k in config_keys}  # will be useful for logging
# -----------------------------------------------------------------------------

anneal_checkpoint = 'anneal/ckpt.pt'  # 'anneal_me.pt'
anneal_dir = os.path.join(out_dir, 'anneal/')
anneal_start_iters = None  # set at init
anneal_decay_iters = None  # set at init

mamba_config = MambaLMConfig(
    d_model=d_model,  # adjust as needed
    n_layers=n_layer,  # adjust as needed
    dt_rank=dt_rank,
    d_state=d_state,
    expand_factor=expand_factor,
    bias=bias,
    conv_bias=conv_bias,
    pscan=pscan,
    vocab_size=vocab_size  # adjust based on your dataset
)

# DDP and other initializations
ddp = int(os.environ.get('RANK', -1)) != -1  # is this a DDP run?
if ddp:
    init_process_group(backend=backend)
    ddp_rank = int(os.environ['RANK'])
    ddp_local_rank = int(os.environ['LOCAL_RANK'])
    ddp_world_size = int(os.environ['WORLD_SIZE'])
    device = f'cuda:{ddp_local_rank}'
    torch.cuda.set_device(device)
    master_process = ddp_rank == 0  # this process handles logging and checkpointing
    seed_offset = ddp_rank  # each process gets a different seed
    assert gradient_accumulation_steps % ddp_world_size == 0
    gradient_accumulation_steps //= ddp_world_size
else:
    master_process = True
    seed_offset = 0
    ddp_world_size = 1

tokens_per_iter = gradient_accumulation_steps * ddp_world_size * batch_size * max_seq_len

if master_process:
    os.makedirs(out_dir, exist_ok=True)
    os.makedirs(anneal_dir, exist_ok=True)

torch.manual_seed(1337 + seed_offset)
torch.backends.cuda.matmul.allow_tf32 = True  # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True  # allow tf32 on cudnn
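# Illustrative arithmetic, assuming the single-GPU defaults above: with
# gradient_accumulation_steps = 40, batch_size = 12 and max_seq_len = 1024,
# one optimizer step covers 40 * 12 = 480 games and at most
# 40 * 12 * 1024 = 491,520 tokens; tokens_per_iter is an upper bound, since
# each batch is only padded to its longest (possibly truncated) game.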
device_type = 'cuda' if 'cuda' in device else 'cpu'
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16}[dtype]
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)

# poor man's data loader
data_dir = os.path.join('data', dataset)
current_train_file_index = 0
train_files = glob.glob(os.path.join(data_dir, 'train*.parquet'))
train_datasets = []
for f in train_files:
    df = pq.read_table(f).to_pandas()
    df = df[df['tokenized'].apply(len) >= 8]  # drop games that are too short
    train_datasets.append(df)
# val_data = pq.read_table(os.path.join(data_dir, 'val.parquet')).to_pandas()
# val_data = val_data[val_data['tokenized'].apply(len) >= 8]

truncated_games_count = 0
total_games_count = 0
games_seen = 0

def get_batch(split):
    global truncated_games_count, total_games_count, current_train_file_index
    # Use the correct DataFrame based on the split, then randomly select batch_size games
    df = train_datasets[current_train_file_index] if split == 'train' else None  # else val_data
    sample_df = df.sample(batch_size)
    games = sample_df['tokenized'].tolist()

    # Prepare the zero-padded sequences tensor for the batch
    max_length_in_batch = min(max(len(game) for game in games), max_seq_len)
    sequences = torch.zeros((batch_size, max_length_in_batch), dtype=torch.int64)

    for i, game in enumerate(games):
        total_games_count += 1
        if len(game) > max_seq_len:
            truncated_games_count += 1
            # Randomly decide truncation strategy
            truncation_choice = random.choice(['beginning', 'end', 'end2', 'random'])
            if truncation_choice == 'beginning':
                # Truncate the beginning (keep the end of the game)
                truncated_game = game[-max_seq_len:]
            elif truncation_choice.startswith('end'):
                # Truncate the end (keep the beginning of the game)
                truncated_game = game[:max_seq_len]
            else:
                # Random start index (truncate both beginning and end)
                start_idx = random.randint(0, len(game) - max_seq_len)
                truncated_game = game[start_idx:start_idx + max_seq_len]
            sequences[i, :len(truncated_game)] = torch.tensor(truncated_game, dtype=torch.int64)

            # Report the percentage of truncated games
            if truncated_games_count > 0 and truncated_games_count % 50 == 0:
                truncated_percentage = (truncated_games_count / total_games_count) * 100
                print(f"Percentage of truncated games: {truncated_percentage:.2f}%\t\t({truncated_games_count}/{total_games_count})")
        else:
            sequences[i, :len(game)] = torch.tensor(game, dtype=torch.int64)

    # Periodically hop to a different training file
    if (total_games_count // batch_size) % train_file_update_interval == 0:
        current_train_file_index = random.randint(0, len(train_files) - 1)
        # print(f"Switched to file: {train_files[current_train_file_index]}")

    if device_type == 'cuda':
        sequences = sequences.pin_memory().to(device, non_blocking=True)
    else:
        sequences = sequences.to(device)
    return sequences
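# Note on the truncation policy above: games longer than max_seq_len get a
# strategy drawn uniformly from ['beginning', 'end', 'end2', 'random'], so
# roughly half of the long games keep their opening ('end' and 'end2' both cut
# the tail), a quarter keep their ending, and a quarter keep a random
# max_seq_len window. Shorter games are left-aligned and zero-padded to the
# longest game in the batch.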
# init these up here, can override if init_from='resume' (i.e. from a checkpoint)
iter_num = 0
best_val_loss = 1e9
grad_norm_history = []  # always defined so the clipping code below can append to it

# attempt to derive vocab_size from the dataset
meta_path = os.path.join(data_dir, 'meta.pkl')
meta_vocab_size = None
if not move_num_in_gamestate:
    meta_vocab_size = 28
elif os.path.exists(meta_path):
    with open(meta_path, 'rb') as f:
        meta = pickle.load(f)
    meta_vocab_size = meta['vocab_size']
    print(f"found vocab_size = {meta_vocab_size} (inside {meta_path})")

# Model initialization
if init_from == 'scratch':
    print("Initializing a new Mamba model from scratch")
    if meta_vocab_size is None:
        print(f"defaulting to vocab_size of {vocab_size}")
    else:
        mamba_config.vocab_size = meta_vocab_size
    model = MambaLM(mamba_config)
    if auto_clip:
        grad_clip = 0
        config['grad_clip'] = 0
        grad_norm_history = []
elif init_from == 'resume' or init_from == 'anneal':
    print(f"Resuming training from {out_dir}")
    if init_from == 'anneal':
        ckpt_path = os.path.join(out_dir, anneal_checkpoint)
    else:
        ckpt_path = os.path.join(out_dir, 'ckpt.pt')
    checkpoint = torch.load(ckpt_path, map_location=device)
    mamba_config = checkpoint['model_args']
    model = MambaLM(mamba_config)
    state_dict = checkpoint['model']
    # fix the keys of the state dictionary :(
    # honestly no idea how checkpoints sometimes get this prefix, have to debug more
    unwanted_prefix = '_orig_mod.'
    for k, v in list(state_dict.items()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict)
    if 'effective_batch_size' not in checkpoint['config']:
        print("Checkpoint was saved without `effective_batch_size`, assuming current value (will save with next checkpoint). This is used for correcting `iter_num` when the effective batch size is changed.")
        checkpoint['config']['effective_batch_size'] = effective_batch_size
    iter_num = int(round(checkpoint['iter_num'] * (checkpoint['config']['effective_batch_size'] / effective_batch_size)))
    if 'games_seen' in checkpoint:
        games_seen = checkpoint['games_seen']
    else:
        games_seen = checkpoint['config']['effective_batch_size'] * checkpoint['iter_num']
        checkpoint['games_seen'] = games_seen
        print(f"Checkpoint was saved without `games_seen`, assuming checkpoint's effective batch size * iters (will save with next checkpoint). {games_seen}")
    best_val_loss = checkpoint['best_val_loss']
    print(f"Best val loss: {best_val_loss}")
    if auto_clip:
        grad_clip = checkpoint['config']['grad_clip']
        config['grad_clip'] = grad_clip
        # grad_norm_history = [t.item() if torch.is_tensor(t) else t for t in checkpoint.get('grad_norm_history', [])]
        grad_norm_history = checkpoint.get('grad_norm_history', [])
    if init_from == 'anneal':
        print(f"\n\nANNEAL STARTING/RESUMING FROM ITERNUM: {iter_num} ({games_seen} games)\n\n")
        anneal_start_iters = iter_num if 'anneal_start_iters' not in checkpoint else checkpoint['anneal_start_iters']
        anneal_decay_iters = iter_num / 7.0 if 'anneal_decay_iters' not in checkpoint else checkpoint['anneal_decay_iters']  # / 9 is og
        print(anneal_start_iters)
        print(anneal_decay_iters)
        if 'anneal_start_iters' not in checkpoint:
            grad_clip = 0
            config['grad_clip'] = 0
            grad_norm_history = []
            print(f"Starting anneal. Resumed from {anneal_checkpoint}, will now decay the learning rate for {anneal_decay_iters} iters, i.e. until iter_num {anneal_start_iters + anneal_decay_iters}.")
        out_dir = anneal_dir
        weight_decay = weight_decay / 10.0  # / 17.0
        beta2 = np.sqrt(beta2) * beta2
        auto_clip = True
        grad_clip_percentile = 6.3333  # 6.75
elif init_from.startswith('state-spaces'):
    print(f"Initializing from Mamba pre-trained weights: {init_from}")
    model = from_pretrained(init_from)
    mamba_config = model.config
else:
    raise ValueError("Invalid init_from value")

model.to(device)
print(f'Model with {sum(p.numel() for p in model.parameters())} parameters loaded.')
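# Note on resuming with a changed batch size (illustrative numbers): iter_num is
# rescaled by old/new effective_batch_size so that schedule progress tracks games
# seen rather than raw optimizer steps. For example, a checkpoint saved at iter
# 10,000 with effective_batch_size 480, resumed with effective_batch_size 960,
# restarts at iter 5,000 while games_seen is unchanged.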
# Optimizer and GradScaler
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2))
scaler = torch.cuda.amp.GradScaler(enabled=dtype == 'float16')
if init_from == 'resume':
    optimizer.load_state_dict(checkpoint['optimizer'])
checkpoint = None  # free up memory

# Compile the model if using PyTorch 2.0
if compile:
    print("compiling the model... (takes a ~minute)")
    model = torch.compile(model)

# Wrap model in DDP container if necessary
if ddp:
    model = DDP(model, device_ids=[ddp_local_rank])

@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train']:  # ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            tokens = get_batch(split)  # Fetch tokens in the correct format
            logits = model(tokens[:, :-1])  # Predict next tokens (ignore last token)
            # The targets are the tokens shifted by one position
            targets = tokens[:, 1:].reshape(-1)  # Flatten targets for cross-entropy
            # Compute cross-entropy loss between logits and targets
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets)
            losses[k] = loss.item()
        split = 'val'  # Temporary hack
        out[split] = losses.mean()
    model.train()
    return out

# WSD scheduler
def get_lr(it):
    if init_from == 'anneal':
        # Linear decay from max LR to min LR over anneal_decay_iters iterations
        decay_ratio = min(it - anneal_start_iters, anneal_decay_iters) / anneal_decay_iters
        return learning_rate - decay_ratio * (learning_rate - min_lr)
    if it < warmup_iters:
        # Warmup
        return learning_rate * it / warmup_iters
    # Stable max LR
    return learning_rate
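# The schedule above is warmup-stable-decay (WSD): the LR ramps linearly from 0
# to learning_rate over warmup_iters, then stays flat for the rest of the main
# run; the decay leg only runs in 'anneal' mode, where the LR falls linearly
# from learning_rate to min_lr over anneal_decay_iters iterations starting at
# anneal_start_iters.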
# Logging setup
if wandb_log and master_process:
    import wandb
    wandb.init(project=wandb_project, name=wandb_run_name, config=config)

# Training loop
local_iter_num = 0  # Number of iterations in the lifetime of this process
last_crossed_multiple = 0
save_every_n_games = 150000
raw_model = model.module if ddp else model  # Unwrap DDP container if needed
t0 = time.time()

while True:
    # Determine and set the learning rate for this iteration
    lr = get_lr(iter_num) if decay_lr else learning_rate
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    # Evaluate the loss on train/val sets and write checkpoints
    if iter_num % eval_interval == 0 and master_process:
        losses = estimate_loss()
        print(f"\ngame {games_seen} ({iter_num}, {(iter_num / max_iters)*100.0:.3f}%): 'val' loss {losses['val']:.4f}")  # Temporary hack
        # print(f"game {games_seen} ({iter_num}): train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
        if auto_clip and len(grad_norm_history) >= grad_clip_start_size:
            grad_clip = np.percentile(grad_norm_history, grad_clip_percentile)
            config['grad_clip'] = grad_clip
            print(f"Auto adjusted grad_clip to {grad_clip}")
        if wandb_log:
            wandb.log({
                "iter": iter_num,
                "games": games_seen,
                # "train/loss": losses['train'],  # Temporary hack
                "grad_clip": grad_clip,
                "val/loss": losses['val'],
                "lr": lr,
            })
        if losses['val'] < best_val_loss or always_save_checkpoint:
            if iter_num > 0:
                checkpoint = {
                    'model': raw_model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'model_args': mamba_config,
                    'iter_num': iter_num,
                    'games_seen': games_seen,
                    'best_val_loss': min(best_val_loss, losses['val']),
                    'config': config,
                }
                checkpoint['grad_norm_history'] = grad_norm_history
                if init_from == 'anneal':
                    checkpoint['anneal_start_iters'] = anneal_start_iters
                    checkpoint['anneal_decay_iters'] = anneal_decay_iters
                print(f"saving checkpoint to {out_dir}\n")
                torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))
                current_nearest_multiple = (games_seen // save_every_n_games) * save_every_n_games
                if losses['val'] < best_val_loss:  # Temporary / only good after it's settled
                    best_val_loss = losses['val']
                    torch.save(checkpoint, os.path.join(out_dir, f'ckpt_{int(games_seen)}b.pt'))
                elif current_nearest_multiple != last_crossed_multiple:  # elif so we don't double up
                    last_crossed_multiple = current_nearest_multiple
                    torch.save(checkpoint, os.path.join(out_dir, f'ckpt_{int(games_seen)}.pt'))
    if iter_num == 0 and eval_only:
        break

    # Forward and backward pass
    for micro_step in range(gradient_accumulation_steps):
        if ddp:
            # Only sync gradients on the last micro step
            model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1)
        sequences = get_batch('train')  # Fetch the training data
        with ctx:
            logits = model(sequences[:, :-1])  # Forward pass, exclude last token for input
            # Compute loss for next-token prediction
            targets = sequences[:, 1:].reshape(-1)  # Targets are shifted by one position
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets)
            loss = loss / gradient_accumulation_steps
        scaler.scale(loss).backward()
        # print('.', end='')

    # clip the gradient
    if grad_clip != 0.0 or auto_clip:
        scaler.unscale_(optimizer)
        # the 0 check is for auto_clip enabled but not enough history collected yet
        total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip if grad_clip != 0.0 else 999.9)
        grad_norm_history.append(total_norm.item())
        grad_norm_history = grad_norm_history[-grad_clip_max_size:]
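    # Note on auto_clip: at each eval (above), once grad_clip_start_size norms have
    # been collected, grad_clip is reset to the grad_clip_percentile-th percentile of
    # the last grad_clip_max_size gradient norms; until then grad_clip stays 0 and
    # the 999.9 threshold effectively disables clipping while norms are still recorded.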
    # step the optimizer and scaler if training in fp16
    scaler.step(optimizer)
    scaler.update()
    # flush the gradients as soon as we can, no need for this memory anymore
    optimizer.zero_grad(set_to_none=True)

    # timing and logging
    t1 = time.time()
    dt = t1 - t0
    t0 = t1
    if iter_num % log_interval == 0 and master_process:
        # get loss as float. note: this is a CPU-GPU sync point
        # scale up to undo the division above, approximating the true total loss (exact would have been a sum)
        lossf = loss.item() * gradient_accumulation_steps
        print(f"game {games_seen} ({iter_num}, {(iter_num / max_iters)*100.0:.3f}%): loss {lossf:.4f}, time {dt*1000:.2f}ms")
        if wandb_log:
            wandb.log({
                "iter": iter_num,
                "games": games_seen,
                "grad_norm": grad_norm_history[-1] if grad_norm_history else 0,
                "train/loss": lossf,
                "lr": lr,
            })
    iter_num += 1
    local_iter_num += 1
    games_seen += effective_batch_size

    # termination conditions
    if iter_num > max_iters:
        checkpoint = {
            'model': raw_model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'model_args': mamba_config,
            'iter_num': iter_num,
            'games_seen': games_seen,
            'best_val_loss': best_val_loss,
            'config': config,
        }
        checkpoint['grad_norm_history'] = grad_norm_history
        if init_from == 'anneal':
            checkpoint['anneal_start_iters'] = anneal_start_iters
            checkpoint['anneal_decay_iters'] = anneal_decay_iters
        print(f"Max_iters reached. Saving checkpoint to {out_dir}")
        torch.save(checkpoint, os.path.join(out_dir, 'ckpt_final.pt'))
        break
    if init_from == 'anneal' and iter_num >= anneal_start_iters + anneal_decay_iters:
        checkpoint = {
            'model': raw_model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'model_args': mamba_config,
            'iter_num': iter_num,
            'games_seen': games_seen,
            'best_val_loss': best_val_loss,
            'config': config,
        }
        checkpoint['grad_norm_history'] = grad_norm_history
        checkpoint['anneal_start_iters'] = anneal_start_iters
        checkpoint['anneal_decay_iters'] = anneal_decay_iters
        print(f"Anneal complete. Saving checkpoint to {out_dir}")
        torch.save(checkpoint, os.path.join(out_dir, 'anneal_complete.pt'))
        break

if ddp:
    destroy_process_group()
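# Example launches (illustrative; the script name is a placeholder, assuming a
# nanoGPT-style configurator.py that accepts --key=value overrides):
#   single GPU:  python train_mamba.py --init_from=scratch --batch_size=12
#   DDP, 8 GPUs: torchrun --standalone --nproc_per_node=8 train_mamba.py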