|
import os |
|
import time |
|
import math |
|
import pickle |
|
from contextlib import nullcontext |
|
import numpy as np |
|
import torch |
|
import torch.nn.functional as F |
|
from torch.nn.parallel import DistributedDataParallel as DDP |
|
from torch.distributed import init_process_group, destroy_process_group |
|
import pyarrow.parquet as pq |
|
import random |
|
from torch.utils.data import Dataset, DataLoader |
|
import glob |
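
# default config values for training; any of these can be overridden at runtime
# via configurator.py (exec'd further down)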
|
|
|
|
|
|
|
|
|
out_dir = 'out' |
|
eval_interval = 2000 |
|
log_interval = 1 |
|
eval_iters = 5 |
|
eval_only = False |
|
always_save_checkpoint = True |
|
init_from = 'resume' |
|
|
|
wandb_log = False |
|
wandb_project = 'mamba' |
|
wandb_run_name = 'mamba_run' |
|
|
|
dataset = 'chess' |
|
gradient_accumulation_steps = 5 * 8 |
|
batch_size = 12 |
|
base_batch_size = batch_size |
|
effective_batch_size = batch_size |
|
max_seq_len = 1024 |
|
train_file_update_interval = 7 |
|
|
|
|
|
model_type = 'mamba' |
|
|
|
n_layer = 12 |
|
d_model = 768 |
|
dt_rank = 'auto' |
|
d_state = 16 |
|
expand_factor = 2 |
|
bias = False |
|
conv_bias = True |
|
pscan = True |
|
vocab_size = 32 |
|
move_num_in_gamestate = True |
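
# the next three settings only apply to the 'xformer' (GPT) baseline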
|
|
|
n_head = 12 |
|
n_embd = 768 |
|
dropout = 0.0 |
|
|
|
|
|
learning_rate = 6e-4 |
|
max_iters = 600000 |
|
weight_decay = 1e-1 |
|
beta1 = 0.9 |
|
beta2 = 0.95 |
|
grad_clip = 0.5 |
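
# auto_clip: when enabled, grad_clip is periodically re-derived from a percentile of
# recent gradient norms (see the eval block in the training loop), clamped to
# [auto_clip_min, auto_clip_max]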
|
auto_clip = False |
|
auto_clip_max = 0.5 |
|
auto_clip_min = 3.333e-3 |
|
grad_clip_start_size = 100 |
|
grad_clip_max_size = 500 |
|
grad_clip_percentile = 10 |
|
|
|
decay_lr = True |
|
warmup_iters = 2000 |
|
min_lr = 6e-5 |
|
|
|
backend = 'nccl' |
|
|
|
device = 'cuda' if torch.cuda.is_available() else 'cpu' |
|
dtype = 'bfloat16' if torch.cuda.is_bf16_supported() else 'float32' |
|
compile = False |
|
|
|
|
|
config_keys = [k for k, v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))] |
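# configurator.py (nanoGPT-style) is assumed to apply a config file and --key=value
# command-line overrides to the globals above, e.g.
#   python train.py config/my_config.py --batch_size=8
# (config/my_config.py is just a hypothetical example path)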
|
exec(open('configurator.py').read()) |
|
config = {k: globals()[k] for k in config_keys} |
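
# anneal phase: a final stretch of training resumed from a designated checkpoint, with the
# learning rate decayed linearly; its checkpoints live under out_dir/anneal/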
|
|
|
|
|
|
|
anneal_checkpoint = 'anneal/ckpt.pt' |
|
anneal_dir = os.path.join(out_dir, 'anneal/') |
|
anneal_start_iters = None |
|
anneal_decay_iters = None |
|
|
|
if model_type == 'mamba': |
|
    from mamba_lm import MambaLM, MambaLMConfig, from_pretrained  # from_pretrained (assumed exported by mamba_lm) is needed for the 'state-spaces/*' init path below
|
from mamba_ssm import MambaLMHeadModel |
|
model_config = MambaLMConfig( |
|
d_model=d_model, |
|
|
|
n_layer=n_layer, |
|
ssm_cfg={ |
|
'dt_rank': dt_rank, |
|
'd_state': d_state, |
|
|
|
'bias': bias, |
|
'conv_bias':conv_bias, |
|
|
|
}, |
|
vocab_size=vocab_size, |
|
pad_vocab_size_multiple=1 |
|
).to_mamba_config() |
|
elif model_type == 'xformer': |
|
from xformer import GPTConfig, GPT |
|
model_config = GPTConfig( |
|
n_layer=n_layer, |
|
n_head=n_head, |
|
n_embd=n_embd, |
|
block_size=max_seq_len, |
|
bias=bias, |
|
vocab_size=vocab_size, |
|
dropout=dropout) |
|
else: |
|
print(f"Unknown model_type {model_type}.") |
|
exit() |
|
|
|
|
|
ddp = int(os.environ.get('RANK', -1)) != -1  # RANK is set by torchrun/torch.distributed launchers; -1 means a single-process run
|
if ddp: |
|
init_process_group(backend=backend) |
|
ddp_rank = int(os.environ['RANK']) |
|
ddp_local_rank = int(os.environ['LOCAL_RANK']) |
|
ddp_world_size = int(os.environ['WORLD_SIZE']) |
|
device = f'cuda:{ddp_local_rank}' |
|
torch.cuda.set_device(device) |
|
master_process = ddp_rank == 0 |
|
seed_offset = ddp_rank |
|
assert gradient_accumulation_steps % ddp_world_size == 0 |
|
gradient_accumulation_steps //= ddp_world_size |
|
else: |
|
master_process = True |
|
seed_offset = 0 |
|
ddp_world_size = 1 |
|
|
|
if master_process: |
|
os.makedirs(out_dir, exist_ok=True) |
|
os.makedirs(anneal_dir, exist_ok=True) |
|
torch.manual_seed(1337 + seed_offset) |
|
torch.backends.cuda.matmul.allow_tf32 = True |
|
torch.backends.cudnn.allow_tf32 = True |
|
device_type = 'cuda' if 'cuda' in device else 'cpu' |
|
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
|
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype) |
|
|
|
|
|
data_dir = os.path.join('data', dataset) |
|
current_train_file_index = 0 |
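# the dataset is stored as multiple parquet files of pre-tokenized games; all matching
# files are loaded into memory up front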
|
train_files = (glob.glob(os.path.join(data_dir, 'train*.parquet'))
               + glob.glob(os.path.join(data_dir, 'stable*.parquet'))
               + glob.glob(os.path.join(data_dir, 'anneal*.parquet')))
|
train_datasets = [] |
|
print("Loading dataset...") |
|
for f in train_files:
    df = pq.read_table(f).to_pandas()
    df = df[df['tokenized'].apply(len) >= 8]  # drop games too short to be useful
    train_datasets.append(df)
    print('.', end='', flush=True)
|
print("\nLoaded.") |
|
|
|
|
|
truncated_games_count = 0 |
|
total_games_count = 0 |
|
games_seen = 0 |
|
tokens_seen = 0 |
|
tokens_seen_padded = 0 |
|
def get_batch(split): |
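    # sample batch_size games from the currently selected training file, left-align them in a
    # zero-padded int64 tensor (truncated to at most max_seq_len), and roughly every
    # train_file_update_interval batches hop to a different randomly chosen training file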
|
global truncated_games_count, total_games_count, current_train_file_index, tokens_seen, tokens_seen_padded |
|
|
|
|
|
dataset = train_datasets[current_train_file_index] if split == 'train' else None |
|
sample_df = dataset.sample(batch_size) |
|
games = sample_df['tokenized'].tolist() |
|
|
|
|
|
max_length_in_batch = min(max(len(game) for game in games), max_seq_len) |
|
pad_to = max_length_in_batch |
|
sequences = torch.zeros((batch_size, pad_to), dtype=torch.int64) |
|
|
|
for i, game in enumerate(games): |
|
total_games_count += 1 |
|
game_len = min(len(game), pad_to) |
|
tokens_seen += game_len |
|
tokens_seen_padded += pad_to |
|
sequences[i, :game_len] = torch.tensor(game[:game_len], dtype=torch.int64) |
|
|
|
if (total_games_count // batch_size) % train_file_update_interval == 0: |
|
current_train_file_index = random.randint(0, len(train_files) - 1) |
|
|
|
|
|
if device_type == 'cuda': |
|
sequences = sequences.pin_memory().to(device, non_blocking=True) |
|
else: |
|
sequences = sequences.to(device) |
|
|
|
return sequences, max_length_in_batch |
|
|
|
|
|
iter_num = 0
best_val_loss = 1e9
grad_norm_history = []  # always defined here; the resume paths below overwrite it from the checkpoint
|
|
|
|
|
meta_path = os.path.join(data_dir, 'meta.pkl') |
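# vocab size: fixed at 28 when move numbers are not included in the game state, otherwise
# read from the dataset's meta.pkl if present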
|
meta_vocab_size = None |
|
if not move_num_in_gamestate: |
|
meta_vocab_size = 28 |
|
elif os.path.exists(meta_path): |
|
with open(meta_path, 'rb') as f: |
|
meta = pickle.load(f) |
|
meta_vocab_size = meta['vocab_size'] |
|
print(f"found vocab_size = {meta_vocab_size} (inside {meta_path})") |
|
|
|
|
|
if init_from == 'scratch': |
|
print(f"Initializing a new {model_type} model from scratch") |
|
if meta_vocab_size is None: |
|
print(f"defaulting to vocab_size of {vocab_size}") |
|
else: |
|
model_config.vocab_size = meta_vocab_size |
|
if model_type == 'mamba': |
|
|
|
model = MambaLMHeadModel(model_config) |
|
else: |
|
model = GPT(model_config) |
|
if auto_clip: |
|
grad_clip = 0 |
|
config['grad_clip'] = 0 |
|
grad_norm_history = [] |
|
elif init_from == 'resume' or init_from == 'anneal': |
|
print(f"Resuming training from {out_dir}") |
|
if init_from == 'anneal': |
|
ckpt_path = os.path.join(out_dir, anneal_checkpoint) |
|
else: |
|
ckpt_path = os.path.join(out_dir, 'ckpt.pt') |
|
checkpoint = torch.load(ckpt_path, map_location=device) |
|
model_config = checkpoint['model_args'] |
|
if model_type == 'mamba': |
|
|
|
model = MambaLMHeadModel(model_config) |
|
else: |
|
model = GPT(model_config) |
|
state_dict = checkpoint['model'] |
|
|
|
|
|
unwanted_prefix = '_orig_mod.' |
|
for k,v in list(state_dict.items()): |
|
if k.startswith(unwanted_prefix): |
|
state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k) |
|
model.load_state_dict(state_dict) |
|
if 'effective_batch_size' not in checkpoint['config']: |
|
print("Checkpoint was saved without `effective_batch_size`, assuming current value (will save with next checkpoint). This is used for correcting `iter_num` when the effetive batch size is changed.") |
|
checkpoint['config']['effective_batch_size'] = effective_batch_size |
|
iter_num = int(round(checkpoint['iter_num'] * (checkpoint['config']['effective_batch_size'] / effective_batch_size))) |
|
if 'games_seen' in checkpoint: |
|
games_seen = checkpoint['games_seen'] |
|
else: |
|
games_seen = checkpoint['config']['effective_batch_size'] * checkpoint['iter_num'] |
|
checkpoint['games_seen'] = games_seen |
|
print(f"Checkpoint was saved without `games_seen`, assuming checkpoint's effective batch size * iters (will save with next checkpoint). {games_seen}") |
|
tokens_seen = checkpoint.get('tokens_seen', 0) |
|
tokens_seen_padded = checkpoint.get('tokens_seen_padded', 0) |
|
best_val_loss = checkpoint['best_val_loss'] |
|
print(f"Best val loss: {best_val_loss}") |
|
if auto_clip: |
|
grad_clip = checkpoint['config']['grad_clip'] |
|
config['grad_clip'] = grad_clip |
|
|
|
grad_norm_history = checkpoint.get('grad_norm_history', []) |
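    # an 'anneal' run resumes from the designated anneal checkpoint and linearly decays the
    # learning rate over anneal_decay_iters, with slightly tweaked optimizer/clipping settings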
|
if init_from == 'anneal': |
|
print(f"\n\nANNEAL STARTING/RESUMING FROM ITERNUM: {iter_num} ({games_seen} games)\n\n") |
|
anneal_start_iters = iter_num if 'anneal_start_iters' not in checkpoint else checkpoint['anneal_start_iters'] |
|
anneal_decay_iters = iter_num / 8 if 'anneal_decay_iters' not in checkpoint else checkpoint['anneal_decay_iters'] |
|
print(anneal_start_iters) |
|
print(anneal_decay_iters) |
|
if 'anneal_start_iters' not in checkpoint: |
|
grad_clip = 0 |
|
config['grad_clip'] = 0 |
|
grad_norm_history = [] |
|
print(f"Starting anneal. Resumed from {anneal_checkpoint}, will now decay learning rate for {anneal_decay_iters} / until iter_num {anneal_start_iters + anneal_decay_iters}.") |
|
out_dir = anneal_dir |
|
weight_decay = weight_decay / 12.5 |
|
beta2 = np.sqrt(beta2) * beta2 |
|
auto_clip = True |
|
grad_clip_percentile = 6.75 |
|
elif init_from.startswith('state-spaces'): |
|
print(f"Initializing from Mamba pre-trained weights: {init_from}") |
|
model = from_pretrained(init_from) |
|
model_config = model.config |
|
else: |
|
raise ValueError("Invalid init_from value") |
|
|
|
model.to(device) |
|
|
|
print(f'Model with {sum([p.numel() for p in model.parameters()])} parameters loaded.') |
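
# optimizer: plain AdamW over all parameters; the GradScaler is only active for float16
# (it is a no-op under bfloat16/float32)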
|
|
|
|
|
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2)) |
|
scaler = torch.cuda.amp.GradScaler(enabled=dtype == 'float16') |
|
if init_from == 'resume': |
|
optimizer.load_state_dict(checkpoint['optimizer']) |
|
checkpoint = None |
|
|
|
|
|
if compile: |
|
print("compiling the model... (takes a ~minute)") |
|
model = torch.compile(model) |
|
|
|
|
|
if ddp: |
|
model = DDP(model, device_ids=[ddp_local_rank]) |
|
|
|
|
|
def batch_to_loss(sequences, max_length_in_batch): |
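    # next-token prediction: inputs are sequences[:, :-1] and targets sequences[:, 1:];
    # the mamba path computes cross-entropy here, the xformer path gets the loss from the model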
|
if model_type == 'mamba': |
|
logits = model(sequences[:, :-1]).logits |
|
|
|
targets = sequences[:, 1:].reshape(-1) |
|
return F.cross_entropy(logits.view(-1, logits.size(-1)), targets) |
|
|
|
else: |
|
inputs = sequences[:, :-1] |
|
targets = sequences[:, 1:].reshape(-1) |
|
_, loss = model(inputs, targets) |
|
return loss |
|
|
|
|
|
@torch.no_grad() |
|
def estimate_loss(): |
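    # there is no separate validation split: eval batches are drawn from the same training
    # files and the mean loss is reported under the 'val' key; token counters are restored
    # afterwards so evaluation does not inflate the totals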
|
global tokens_seen, tokens_seen_padded |
|
out = {} |
|
model.eval() |
|
tokens_seen_b4 = tokens_seen |
|
tokens_seen_padded_b4 = tokens_seen_padded |
|
for split in ['train']: |
|
losses = torch.zeros(eval_iters) |
|
for k in range(eval_iters): |
|
loss = batch_to_loss(*get_batch(split)) |
|
losses[k] = loss.item() |
|
|
|
split = 'val' |
|
out[split] = losses.mean() |
|
tokens_seen = tokens_seen_b4 |
|
tokens_seen_padded = tokens_seen_padded_b4 |
|
model.train() |
|
return out |
|
|
|
|
|
|
|
def get_lr(it): |
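    # schedule: linear warmup to learning_rate, then constant; during an anneal run the LR
    # instead decays linearly from learning_rate down to min_lr over anneal_decay_iters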
|
if init_from == 'anneal': |
|
|
|
decay_ratio = min(it - anneal_start_iters, anneal_decay_iters) / anneal_decay_iters |
|
return learning_rate - decay_ratio * (learning_rate - min_lr) |
|
|
|
if it < warmup_iters: |
|
|
|
return learning_rate * it / warmup_iters |
|
|
|
|
|
return learning_rate |
|
|
|
|
|
if wandb_log and master_process: |
|
import wandb |
|
wandb.init(project=wandb_project, name=wandb_run_name, config=config) |
|
|
|
|
|
local_iter_num = 0 |
|
last_crossed_multiple = 0 |
|
save_every_n_games = 150000 |
|
raw_model = model.module if ddp else model |
|
|
|
|
|
if init_from == 'scratch': |
|
checkpoint = { |
|
'model': raw_model.state_dict(), |
|
'optimizer': optimizer.state_dict(), |
|
'model_args': model_config, |
|
'iter_num': 0, |
|
"games_seen": 0, |
|
"tokens_seen": 0, |
|
"tokens_seen_padded": 0, |
|
'best_val_loss': best_val_loss, |
|
'config': config, |
|
} |
|
checkpoint['grad_norm_history'] = grad_norm_history |
|
print(f"saving checkpoint to {out_dir}\n") |
|
torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt')) |
|
|
|
t0 = time.time() |
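
# main training loop: evaluate and checkpoint every eval_interval iters, accumulate
# gradients over gradient_accumulation_steps micro-batches, then take one optimizer step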
|
while True: |
|
|
|
lr = get_lr(iter_num) if decay_lr else learning_rate |
|
for param_group in optimizer.param_groups: |
|
param_group['lr'] = lr |
|
|
|
|
|
if iter_num % eval_interval == 0 and master_process and local_iter_num > 0: |
|
torch.cuda.empty_cache() |
|
losses = estimate_loss() |
|
if init_from == 'anneal': |
|
print(f"\ngame {games_seen} ({iter_num}, {(iter_num-anneal_start_iters) / anneal_decay_iters:.3%}): 'val' loss {losses['val']:.4f}") |
|
else: |
|
print(f"\ngame {games_seen} ({iter_num}, {iter_num / max_iters:.3%}): 'val' loss {losses['val']:.4f}") |
|
if auto_clip and len(grad_norm_history) >= grad_clip_start_size: |
|
grad_clip_prev = grad_clip |
|
grad_clip = np.percentile(grad_norm_history, grad_clip_percentile) |
|
grad_clip = max(min(grad_clip, auto_clip_max), auto_clip_min) |
|
|
|
grad_clip = (grad_clip*9.0 + grad_clip_prev*4.0) / 13.0 |
|
grad_clip = max(min(grad_clip, auto_clip_max), auto_clip_min) |
|
config['grad_clip'] = grad_clip |
|
print(f"Auto adjusted grad_clip to {grad_clip}") |
|
torch.cuda.empty_cache() |
|
if wandb_log: |
|
wandb.log({ |
|
"etc/iter": iter_num, |
|
"etc/games": games_seen, |
|
"etc/tokens_seen": tokens_seen, |
|
"etc/tokens_seen_padded": tokens_seen_padded, |
|
"etc/grad_clip": grad_clip, |
|
"etc/lr": lr, |
|
"val/loss": losses['val'], |
|
|
|
}) |
|
if losses['val'] < best_val_loss or always_save_checkpoint: |
|
if iter_num > 0: |
|
checkpoint = { |
|
'model': raw_model.state_dict(), |
|
'optimizer': optimizer.state_dict(), |
|
'model_args': model_config, |
|
'iter_num': iter_num, |
|
"games_seen": games_seen, |
|
"tokens_seen": tokens_seen, |
|
"tokens_seen_padded": tokens_seen_padded, |
|
'best_val_loss': min(best_val_loss, losses['val']), |
|
'config': config, |
|
} |
|
checkpoint['grad_norm_history'] = grad_norm_history |
|
if init_from == 'anneal': |
|
checkpoint['anneal_start_iters'] = anneal_start_iters |
|
checkpoint['anneal_decay_iters'] = anneal_decay_iters |
|
print(f"saving checkpoint to {out_dir}\n") |
|
torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt')) |
|
current_nearest_multiple = (games_seen // save_every_n_games) * save_every_n_games |
|
if losses['val'] < best_val_loss: |
|
best_val_loss = losses['val'] |
|
torch.save(checkpoint, os.path.join(out_dir, f'ckpt_{int(games_seen)}b.pt')) |
|
elif current_nearest_multiple != last_crossed_multiple: |
|
last_crossed_multiple = current_nearest_multiple |
|
torch.save(checkpoint, os.path.join(out_dir, f'ckpt_{int(games_seen)}.pt')) |
|
|
|
if iter_num == 0 and eval_only: |
|
break |
|
|
|
|
|
for micro_step in range(gradient_accumulation_steps): |
|
if ddp: |
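            # only sync gradients across DDP ranks on the last micro-step; earlier
            # micro-steps accumulate locally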
|
model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1) |
|
|
|
sequences, max_length_in_batch = get_batch('train') |
|
with ctx: |
|
loss = batch_to_loss(sequences, max_length_in_batch) |
|
loss = loss / gradient_accumulation_steps |
|
|
|
scaler.scale(loss).backward() |
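    # clip gradients and record the total norm; while grad_clip is 0 (auto_clip warm-up)
    # a very large threshold (999.9) is used so norms are collected without real clipping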
|
|
|
|
|
|
|
if grad_clip != 0.0 or auto_clip: |
|
scaler.unscale_(optimizer) |
|
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip if grad_clip != 0.0 else 999.9) |
|
grad_norm_history.append(total_norm.item()) |
|
grad_norm_history = grad_norm_history[-grad_clip_max_size:] |
|
|
|
|
|
scaler.step(optimizer) |
|
scaler.update() |
|
|
|
optimizer.zero_grad(set_to_none=True) |
|
|
|
|
|
|
|
t1 = time.time() |
|
dt = t1 - t0 |
|
t0 = t1 |
|
if iter_num % log_interval == 0 and master_process: |
|
|
|
|
|
lossf = loss.item() * gradient_accumulation_steps |
|
if init_from == 'anneal': |
|
print(f"game {games_seen} ({iter_num}, {(iter_num-anneal_start_iters) / anneal_decay_iters:.3%}): loss {lossf:.4f}, time {dt*1000:.2f}ms") |
|
else: |
|
print(f"game {games_seen} ({iter_num}, {iter_num / max_iters:.3%}): loss {lossf:.4f}, time {dt*1000:.2f}ms") |
|
if wandb_log: |
|
wandb.log({ |
|
"etc/iter": iter_num, |
|
"etc/games": games_seen, |
|
"etc/tokens_seen": tokens_seen, |
|
"etc/tokens_seen_padded": tokens_seen_padded, |
|
"etc/grad_norm": grad_norm_history[-1] if grad_norm_history else 0, |
|
"etc/lr": lr, |
|
"train/loss": lossf, |
|
}) |
|
iter_num += 1 |
|
local_iter_num += 1 |
|
games_seen += effective_batch_size |
|
|
|
|
|
    if iter_num > max_iters and init_from != 'anneal':
|
checkpoint = { |
|
'model': raw_model.state_dict(), |
|
'optimizer': optimizer.state_dict(), |
|
'model_args': model_config, |
|
'iter_num': iter_num, |
|
"games_seen": games_seen, |
|
"tokens_seen": tokens_seen, |
|
"tokens_seen_padded": tokens_seen_padded, |
|
'best_val_loss': best_val_loss, |
|
'config': config, |
|
} |
|
checkpoint['grad_norm_history'] = grad_norm_history |
|
if init_from == 'anneal': |
|
checkpoint['anneal_start_iters'] = anneal_start_iters |
|
checkpoint['anneal_decay_iters'] = anneal_decay_iters |
|
print(f"Max_iters reached. Saving pre-anneal checkpoint to {anneal_checkpoint}") |
|
torch.save(checkpoint, os.path.join(out_dir, anneal_checkpoint)) |
|
break |
|
if init_from == 'anneal' and iter_num >= anneal_start_iters + anneal_decay_iters: |
|
checkpoint = { |
|
'model': raw_model.state_dict(), |
|
'optimizer': optimizer.state_dict(), |
|
'model_args': model_config, |
|
'iter_num': iter_num, |
|
"games_seen": games_seen, |
|
"tokens_seen": tokens_seen, |
|
"tokens_seen_padded": tokens_seen_padded, |
|
'best_val_loss': best_val_loss, |
|
'config': config, |
|
} |
|
checkpoint['grad_norm_history'] = grad_norm_history |
|
if init_from == 'anneal': |
|
checkpoint['anneal_start_iters'] = anneal_start_iters |
|
checkpoint['anneal_decay_iters'] = anneal_decay_iters |
|
print(f"Anneal complete. Saving checkpoint to {out_dir}") |
|
torch.save(checkpoint, os.path.join(out_dir, 'anneal_complete.pt')) |
|
break |
|
|
|
|
|
|
|
if ddp: |
|
destroy_process_group() |
|
|
|
|