# MambaMate-Micro / train_bygame.py
import os
import time
import math
import pickle
from contextlib import nullcontext
import numpy as np
import torch
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
from mamba_lm import MambaLM, MambaLMConfig, from_pretrained
import pyarrow.parquet as pq
import random
from torch.utils.data import Dataset, DataLoader
import glob
# -----------------------------------------------------------------------------
# default config values designed for Mamba model training
# I/O
out_dir = 'out'
eval_interval = 2000
log_interval = 1
eval_iters = 5
eval_only = False
always_save_checkpoint = True
init_from = 'resume' # 'scratch', 'resume', 'anneal', or Mamba model name
# wandb logging
wandb_log = False
wandb_project = 'mamba'
wandb_run_name = 'mamba_run' # modify as needed
# data
dataset = 'chess' # specify your dataset
gradient_accumulation_steps = 5 * 8
batch_size = 12
base_batch_size = batch_size
effective_batch_size = batch_size
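# effective_batch_size is what games_seen advances by each iteration; it is also used to
# rescale iter_num when resuming a checkpoint trained with a different effective batch size.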
max_seq_len = 1024 # A training-only parameter for controlling VRAM
train_file_update_interval = 7
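# get_batch switches to a randomly chosen train*.parquet file every train_file_update_interval batches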
# model
n_layer = 12
d_model = 768
dt_rank = 'auto'
d_state = 16
expand_factor = 2
bias = False
conv_bias = True
pscan = True
vocab_size = 32000
move_num_in_gamestate = True
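# When move_num_in_gamestate is False, a reduced 28-token vocab is assumed (see the meta handling below)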
# optimizer settings
learning_rate = 6e-4
max_iters = 600000
weight_decay = 1e-1
beta1 = 0.9
beta2 = 0.95
grad_clip = 1.0
auto_clip = False
grad_clip_start_size = 100
grad_clip_max_size = 500
grad_clip_percentile = 10
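# auto_clip: once grad_clip_start_size gradient norms have been collected, grad_clip is periodically
# reset to the grad_clip_percentile-th percentile of the most recent grad_clip_max_size norms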
# learning rate decay settings
decay_lr = True
warmup_iters = 2000
lr_decay_iters = 600000
min_lr = 6e-5
# DDP settings
backend = 'nccl'
# system
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float32'
compile = False # set to True if using PyTorch 2.0
# -----------------------------------------------------------------------------
config_keys = [k for k, v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
exec(open('configurator.py').read()) # overrides from command line or config file
config = {k: globals()[k] for k in config_keys} # will be useful for logging
# -----------------------------------------------------------------------------
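# Anneal (the decay phase of the WSD schedule): when init_from == 'anneal', training resumes from
# anneal_checkpoint and the learning rate is decayed linearly to min_lr (see get_lr below).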
anneal_checkpoint = 'anneal/ckpt.pt' #'anneal_me.pt'
anneal_dir = os.path.join(out_dir, 'anneal/')
anneal_start_iters = None # Set at init
anneal_decay_iters = None # Set at init
mamba_config = MambaLMConfig(
    d_model=d_model, # adjust as needed
    n_layers=n_layer, # adjust as needed
    dt_rank=dt_rank,
    d_state=d_state,
    expand_factor=expand_factor,
    bias=bias,
    conv_bias=conv_bias,
    pscan=pscan,
    vocab_size=vocab_size # adjust based on your dataset
)
# DDP and other initializations
ddp = int(os.environ.get('RANK', -1)) != -1
if ddp:
    init_process_group(backend=backend)
    ddp_rank = int(os.environ['RANK'])
    ddp_local_rank = int(os.environ['LOCAL_RANK'])
    ddp_world_size = int(os.environ['WORLD_SIZE'])
    device = f'cuda:{ddp_local_rank}'
    torch.cuda.set_device(device)
    master_process = ddp_rank == 0
    seed_offset = ddp_rank
    assert gradient_accumulation_steps % ddp_world_size == 0
    gradient_accumulation_steps //= ddp_world_size
else:
    master_process = True
    seed_offset = 0
    ddp_world_size = 1
tokens_per_iter = gradient_accumulation_steps * ddp_world_size * batch_size * max_seq_len
if master_process:
    os.makedirs(out_dir, exist_ok=True)
    os.makedirs(anneal_dir, exist_ok=True)
torch.manual_seed(1337 + seed_offset)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
device_type = 'cuda' if 'cuda' in device else 'cpu'
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
# poor man's data loader
data_dir = os.path.join('data', dataset)
current_train_file_index = 0
train_files = glob.glob(os.path.join(data_dir, 'train*.parquet'))
train_datasets = []
for f in train_files:
    dataset = pq.read_table(f).to_pandas()
    dataset = dataset[dataset['tokenized'].apply(len) >= 8]
    train_datasets.append(dataset)
#val_data = pq.read_table(os.path.join(data_dir, 'val.parquet')).to_pandas()
#val_data = val_data[val_data['tokenized'].apply(len) >= 8]
truncated_games_count = 0
total_games_count = 0
games_seen = 0
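# Sample batch_size games from the currently selected train parquet, truncating any game longer
# than max_seq_len (with a randomly chosen truncation strategy) and zero-padding the rest.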
def get_batch(split):
    global truncated_games_count, total_games_count, current_train_file_index
    # Randomly select batch_size games
    dataset = train_datasets[current_train_file_index] if split == 'train' else None # else val_data # Use the correct DataFrame based on the split
    sample_df = dataset.sample(batch_size)
    games = sample_df['tokenized'].tolist()
    # Prepare sequences tensor for the batch
    max_length_in_batch = min(max(len(game) for game in games), max_seq_len)
    sequences = torch.zeros((batch_size, max_length_in_batch), dtype=torch.int64)
    for i, game in enumerate(games):
        total_games_count += 1
        if len(game) > max_seq_len:
            truncated_games_count += 1
            # Randomly decide truncation strategy
            truncation_choice = random.choice(['beginning', 'end', 'end2', 'random'])
            if truncation_choice == 'beginning':
                # Truncate the beginning (keep the end of the game)
                truncated_game = game[-max_seq_len:]
            elif truncation_choice.startswith('end'):
                # Truncate the end (keep the beginning of the game)
                truncated_game = game[:max_seq_len]
            else:
                # Random start index (truncate beginning and end)
                start_idx = random.randint(0, len(game) - max_seq_len)
                truncated_game = game[start_idx:start_idx + max_seq_len]
            sequences[i, :len(truncated_game)] = torch.tensor(truncated_game, dtype=torch.int64)
            # Report the percentage of truncated games
            if truncated_games_count > 0 and truncated_games_count % 50 == 0:
                truncated_percentage = (truncated_games_count / total_games_count) * 100
                print(f"Percentage of truncated games: {truncated_percentage:.2f}%\t\t({truncated_games_count}/{total_games_count})")
        else:
            sequences[i, :len(game)] = torch.tensor(game, dtype=torch.int64)
    if (total_games_count // batch_size) % train_file_update_interval == 0:
        current_train_file_index = random.randint(0, len(train_files) - 1)
        # print(f"Switched to file: {train_files[current_train_file_index]}")
    if device_type == 'cuda':
        sequences = sequences.pin_memory().to(device, non_blocking=True)
    else:
        sequences = sequences.to(device)
    return sequences
# init these up here, can override if init_from='resume' (i.e. from a checkpoint)
iter_num = 0
best_val_loss = 1e9
# attempt to derive vocab_size from the dataset
meta_path = os.path.join(data_dir, 'meta.pkl')
meta_vocab_size = None
if not move_num_in_gamestate:
    meta_vocab_size = 28
elif os.path.exists(meta_path):
    with open(meta_path, 'rb') as f:
        meta = pickle.load(f)
    meta_vocab_size = meta['vocab_size']
    print(f"found vocab_size = {meta_vocab_size} (inside {meta_path})")
# Model initialization
if init_from == 'scratch':
    print("Initializing a new Mamba model from scratch")
    if meta_vocab_size is None:
        print(f"defaulting to vocab_size of {vocab_size}")
    else:
        mamba_config.vocab_size = meta_vocab_size
    model = MambaLM(mamba_config)
    if auto_clip:
        grad_clip = 0
        config['grad_clip'] = 0
        grad_norm_history = []
elif init_from == 'resume' or init_from == 'anneal':
    print(f"Resuming training from {out_dir}")
    if init_from == 'anneal':
        ckpt_path = os.path.join(out_dir, anneal_checkpoint)
    else:
        ckpt_path = os.path.join(out_dir, 'ckpt.pt')
    checkpoint = torch.load(ckpt_path, map_location=device)
    mamba_config = checkpoint['model_args']
    model = MambaLM(mamba_config)
    state_dict = checkpoint['model']
    # fix the keys of the state dictionary :(
    # honestly no idea how checkpoints sometimes get this prefix, have to debug more
    unwanted_prefix = '_orig_mod.'
    for k,v in list(state_dict.items()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict)
    if 'effective_batch_size' not in checkpoint['config']:
        print("Checkpoint was saved without `effective_batch_size`, assuming current value (will save with next checkpoint). This is used for correcting `iter_num` when the effective batch size is changed.")
        checkpoint['config']['effective_batch_size'] = effective_batch_size
    iter_num = int(round(checkpoint['iter_num'] * (checkpoint['config']['effective_batch_size'] / effective_batch_size)))
    if 'games_seen' in checkpoint:
        games_seen = checkpoint['games_seen']
    else:
        games_seen = checkpoint['config']['effective_batch_size'] * checkpoint['iter_num']
        checkpoint['games_seen'] = games_seen
        print(f"Checkpoint was saved without `games_seen`, assuming checkpoint's effective batch size * iters (will save with next checkpoint). {games_seen}")
    best_val_loss = checkpoint['best_val_loss']
    print(f"Best val loss: {best_val_loss}")
    if auto_clip:
        grad_clip = checkpoint['config']['grad_clip']
        config['grad_clip'] = grad_clip
        #grad_norm_history = [t.item() if torch.is_tensor(t) else t for t in checkpoint.get('grad_norm_history', [])]
        grad_norm_history = checkpoint.get('grad_norm_history', [])
    if init_from == 'anneal':
        print(f"\n\nANNEAL STARTING/RESUMING FROM ITERNUM: {iter_num} ({games_seen} games)\n\n")
        anneal_start_iters = iter_num if 'anneal_start_iters' not in checkpoint else checkpoint['anneal_start_iters']
        anneal_decay_iters = iter_num / 7.0 if 'anneal_decay_iters' not in checkpoint else checkpoint['anneal_decay_iters'] # / 9 is og
        print(anneal_start_iters)
        print(anneal_decay_iters)
        if 'anneal_start_iters' not in checkpoint:
            grad_clip = 0
            config['grad_clip'] = 0
            grad_norm_history = []
            print(f"Starting anneal. Resumed from {anneal_checkpoint}, will now decay learning rate for {anneal_decay_iters} iters / until iter_num {anneal_start_iters + anneal_decay_iters}.")
        out_dir = anneal_dir
        weight_decay = weight_decay / 10.0 # / 17.0
        beta2 = np.sqrt(beta2) * beta2
        auto_clip = True
        grad_clip_percentile = 6.3333 # 6.75
elif init_from.startswith('state-spaces'):
    print(f"Initializing from Mamba pre-trained weights: {init_from}")
    model = from_pretrained(init_from)
    mamba_config = model.config
else:
    raise ValueError("Invalid init_from value")
model.to(device)
print(f'Model with {sum([p.numel() for p in model.parameters()])} parameters loaded.')
# Optimizer and GradScaler
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2))
scaler = torch.cuda.amp.GradScaler(enabled=dtype == 'float16')
if init_from == 'resume':
    optimizer.load_state_dict(checkpoint['optimizer'])
checkpoint = None
# Compile the model if using PyTorch 2.0
if compile:
    print("compiling the model... (takes a ~minute)")
    model = torch.compile(model)
# Wrap model in DDP container if necessary
if ddp:
    model = DDP(model, device_ids=[ddp_local_rank])
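# Estimate mean loss over eval_iters random batches. No separate val split is loaded here; train
# batches are sampled and the result is reported under the 'val' key (see the hack below).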
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train']: #['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            tokens = get_batch(split) # Fetch tokens in the correct format
            logits = model(tokens[:, :-1]) # Predict next tokens (ignore last token)
            # The targets are the tokens shifted by one position
            targets = tokens[:, 1:].reshape(-1) # Flatten targets for cross-entropy
            # Compute cross-entropy loss between logits and targets
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets)
            losses[k] = loss.item()
        split = 'val' # Temporary hack
        out[split] = losses.mean()
    model.train()
    return out
# WSD scheduler
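# Warmup-Stable-Decay: linear warmup to learning_rate, hold it flat, then (in the 'anneal' phase)
# decay linearly to min_lr over anneal_decay_iters iterations.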
def get_lr(it):
    if init_from == 'anneal':
        # Linear decay from max LR to min LR over anneal_decay_iters iterations
        decay_ratio = min(it - anneal_start_iters, anneal_decay_iters) / anneal_decay_iters
        return learning_rate - decay_ratio * (learning_rate - min_lr)
    if it < warmup_iters:
        # Warmup
        return learning_rate * it / warmup_iters
    # Stable max LR
    return learning_rate
# Logging setup
if wandb_log and master_process:
    import wandb
    wandb.init(project=wandb_project, name=wandb_run_name, config=config)
# Training loop
local_iter_num = 0 # Number of iterations in the lifetime of this process
last_crossed_multiple = 0
save_every_n_games = 150000
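# Besides the rolling ckpt.pt, a uniquely named checkpoint is kept roughly every save_every_n_games
# games, or whenever a new best 'val' loss is reached.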
raw_model = model.module if ddp else model # Unwrap DDP container if needed
t0 = time.time()
while True:
    # Determine and set the learning rate for this iteration
    lr = get_lr(iter_num) if decay_lr else learning_rate
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    # Evaluate the loss on train/val sets and write checkpoints
    if iter_num % eval_interval == 0 and master_process:
        losses = estimate_loss()
        print(f"\ngame {games_seen} ({iter_num}, {(iter_num / max_iters)*100.0:.3f}%): 'val' loss {losses['val']:.4f}") # Temporary hack
        #print(f"game {games_seen} ({iter_num}): train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
        if auto_clip and len(grad_norm_history) >= grad_clip_start_size:
            grad_clip = np.percentile(grad_norm_history, grad_clip_percentile)
            config['grad_clip'] = grad_clip
            print(f"Auto adjusted grad_clip to {grad_clip}")
        if wandb_log:
            wandb.log({
                "iter": iter_num,
                "games": games_seen,
                #"train/loss": losses['train'], # Temporary hack
                "grad_clip": grad_clip,
                "val/loss": losses['val'],
                "lr": lr,
            })
        if losses['val'] < best_val_loss or always_save_checkpoint:
            if iter_num > 0:
                checkpoint = {
                    'model': raw_model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'model_args': mamba_config,
                    'iter_num': iter_num,
                    "games_seen": games_seen,
                    'best_val_loss': min(best_val_loss, losses['val']),
                    'config': config,
                }
                checkpoint['grad_norm_history'] = grad_norm_history
                if init_from == 'anneal':
                    checkpoint['anneal_start_iters'] = anneal_start_iters
                    checkpoint['anneal_decay_iters'] = anneal_decay_iters
                print(f"saving checkpoint to {out_dir}\n")
                torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))
                current_nearest_multiple = (games_seen // save_every_n_games) * save_every_n_games
                if losses['val'] < best_val_loss: # Temporary / only good after it's settled
                    best_val_loss = losses['val']
                    torch.save(checkpoint, os.path.join(out_dir, f'ckpt_{int(games_seen)}b.pt'))
                elif current_nearest_multiple != last_crossed_multiple: # elif so we don't double up
                    last_crossed_multiple = current_nearest_multiple
                    torch.save(checkpoint, os.path.join(out_dir, f'ckpt_{int(games_seen)}.pt'))
    if iter_num == 0 and eval_only:
        break
    # Forward and backward pass
    for micro_step in range(gradient_accumulation_steps):
        if ddp:
            model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1)
        sequences = get_batch('train') # Fetch the training data
        with ctx:
            logits = model(sequences[:, :-1]) # Forward pass, exclude last token for input
            # Compute loss (assuming next token prediction task)
            targets = sequences[:, 1:].reshape(-1) # Shifted by one for next token prediction
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets)
            loss = loss / gradient_accumulation_steps
        scaler.scale(loss).backward()
        #print('.', end='')
    # clip the gradient
    if grad_clip != 0.0 or auto_clip:
        scaler.unscale_(optimizer)
        total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip if grad_clip != 0.0 else 999.9) # The 0 check is for auto_clip enabled but not enough history
        grad_norm_history.append(total_norm.item())
        grad_norm_history = grad_norm_history[-grad_clip_max_size:]
    # step the optimizer and scaler if training in fp16
    scaler.step(optimizer)
    scaler.update()
    # flush the gradients as soon as we can, no need for this memory anymore
    optimizer.zero_grad(set_to_none=True)
    # timing and logging
    t1 = time.time()
    dt = t1 - t0
    t0 = t1
    if iter_num % log_interval == 0 and master_process:
        # get loss as float. note: this is a CPU-GPU sync point
        # scale up to undo the division above, approximating the true total loss (exact would have been a sum)
        lossf = loss.item() * gradient_accumulation_steps
        print(f"game {games_seen} ({iter_num}, {(iter_num / max_iters)*100.0:.3f}%): loss {lossf:.4f}, time {dt*1000:.2f}ms")
        if wandb_log:
            wandb.log({
                "iter": iter_num,
                "games": games_seen,
                "grad_norm": grad_norm_history[-1] if grad_norm_history else 0,
                "train/loss": lossf,
                "lr": lr,
            })
    iter_num += 1
    local_iter_num += 1
    games_seen += effective_batch_size
    # termination conditions
    if iter_num > max_iters:
        checkpoint = {
            'model': raw_model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'model_args': mamba_config,
            'iter_num': iter_num,
            "games_seen": games_seen,
            'best_val_loss': best_val_loss,
            'config': config,
        }
        checkpoint['grad_norm_history'] = grad_norm_history
        if init_from == 'anneal':
            checkpoint['anneal_start_iters'] = anneal_start_iters
            checkpoint['anneal_decay_iters'] = anneal_decay_iters
        print(f"Max_iters reached. Saving checkpoint to {out_dir}")
        torch.save(checkpoint, os.path.join(out_dir, 'ckpt_final.pt'))
        break
    if init_from == 'anneal' and iter_num >= anneal_start_iters + anneal_decay_iters:
        checkpoint = {
            'model': raw_model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'model_args': mamba_config,
            'iter_num': iter_num,
            "games_seen": games_seen,
            'best_val_loss': best_val_loss,
            'config': config,
        }
        checkpoint['grad_norm_history'] = grad_norm_history
        if init_from == 'anneal':
            checkpoint['anneal_start_iters'] = anneal_start_iters
            checkpoint['anneal_decay_iters'] = anneal_decay_iters
        print(f"Anneal complete. Saving checkpoint to {out_dir}")
        torch.save(checkpoint, os.path.join(out_dir, 'anneal_complete.pt'))
        break
if ddp:
    destroy_process_group()