# MambaMate-Micro / train_rl.py
# Author: HaileyStorm. Uploaded in "Upload 8 files" (commit 816f85a, verified).
import math
import os
import time
import pickle
from contextlib import nullcontext
import numpy as np
import torch
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
from mamba_lm import MambaLM, MambaLMConfig
import random
import chess
from lczero.backends import Weights, Backend, GameState
# Default config values.
# Every module-level int/float/bool/str defined up to `config_keys` below is
# captured into `config`, which is logged to wandb and stored in checkpoints.
out_dir = 'out/play'
save_interval = 50  # games between periodic checkpoints (and grad_clip re-tuning)
init_from = 'resume'  # 'scratch', 'resume', 'anneal', or Mamba model name
# Model parameters
n_layer = 15
d_model = 256
dt_rank = 'auto'
d_state = 16
vocab_size = 28
move_num_in_gamestate = False  # include move numbers ("1.e4") in the transcript
# wandb logging (previously also pre-assigned 'chess-training' / 'lc0-training',
# which were immediately overridden by the values below)
wandb_log = True
wandb_project = 'mamba-rl'
wandb_run_name = 'mamba_run'
# Load openings file: one opening line per row after the header row.
with open("openings.csv", "r", encoding="utf-8") as file:
    lines = file.readlines()[1:]  # Skip header
opening_lines = lines
# Optimizer settings (trailing comments are previously used values)
learning_rate = 1e-7  # 7.25e-7
min_lr = 1e-8  # 1.75e-8
warmup_iters = 600
lr_decay_iters = len(opening_lines)  # decay over one pass through the openings
weight_decay = 1e-2  # 5e-3
beta1 = 0.905  # 0.915
beta2 = 0.965  # 0.95
grad_clip = 0.5  # 0.25
min_grad_clip = 1e-3  # 1e-3
max_grad_clip = 0.45  # 0.45
auto_clip = True  # re-derive grad_clip from a percentile of recent grad norms
grad_clip_start_size = 150  # history length required before auto-clipping kicks in
grad_clip_max_size = 600  # max number of grad norms kept in the history
grad_clip_percentile = 9
# Game play / loss calculation settings
top_k = 2  # 2
top_k_adj_moves = 40  # 999 #35; at this move count top_k is lowered by one for the rest of the game
max_illegal_moves = 8  # 2
max_moves = 87
update_freq = 3  # 1 # How often to do a backward pass
flush_every = 1
move_reward_scale_factor = 4.0  # 2.125 # scales down the move reward so it's not so dramatic / so that illegal moves (reward -1) are more dramatic by comparison to bad moves
decrease_factor = 0.75  # Bonus for winning (1/x is penalty for losing)
window_size = 300  # number of recent games in the win-rate windows
# DDP settings
backend = 'nccl'
# System
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Check CUDA availability first: on CPU-only hosts older PyTorch versions raise
# from torch.cuda.is_bf16_supported() instead of returning False.
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float32'
compile = False  # Set to True if using PyTorch 2.0
config_keys = [k for k, v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
# exec(open('configurator.py').read()) # overrides from command line or config file
config = {k: globals()[k] for k in config_keys}  # will be useful for logging
# Initialize lc0 engines
# The opponent backend (32x4 net) is used for Black's replies and for
# "pinch-hit" moves that replace unusable model moves.
lc0_weights_opponent = Weights("./lc0/build/release/11258-32x4-se.pb.gz")
lc0_backend_opponent = Backend(weights=lc0_weights_opponent)
# The 48x5 weights are loaded, but as configured the evaluator (used to score
# positions before/after each model move) reuses the opponent backend. Swap in
# the commented Backend(...) to evaluate with the larger net instead.
lc0_weights_evaluator = Weights("./lc0/build/release/11258-48x5-se.pb.gz")
lc0_backend_evaluator = lc0_backend_opponent #Backend(weights=lc0_weights_evaluator)
# Tokenizer: character-level encode/decode between transcripts and token ids.
if move_num_in_gamestate:
    # Transcripts with move numbers use the vocabulary shipped with the dataset.
    meta_path = os.path.join('data', 'chess', 'meta.pkl')
    with open(meta_path, "rb") as f:
        meta = pickle.load(f)
    stoi, itos = meta["stoi"], meta["itos"]
    vocab_size = meta['vocab_size']

    def encode(s):
        """Map each character of s to its token id."""
        return [stoi[c] for c in s]

    def decode(ids):
        """Map token ids back to a string."""
        return "".join([itos[i] for i in ids])
else:
    # Built-in 28-symbol vocabulary: ids are positions in this string.
    stoi = {ch: idx for idx, ch in enumerate(' .abcdefgh12345678BNRQKOx+#=')}
    itos = {idx: ch for ch, idx in stoi.items()}
    assert all(itos[stoi[ch]] == ch for ch in stoi)
    vocab_size = len(stoi)
    print(f"Vocab size {vocab_size}")

    def encode(s):
        """Encode s, dropping '-' (castling "O-O" is stored as "OO")."""
        return [stoi[c] for c in s.replace('-', '')]

    def decode(ids):
        """Decode ids, re-inserting the castling hyphens that encode() removed."""
        text = "".join([itos[i] for i in ids])
        return text.replace("OOO", "O-O-O").replace("OO", "O-O")
# Build the Mamba language model from the hyper-parameters above.
mamba_config = MambaLMConfig(
    vocab_size=vocab_size,  # Adjust based on your dataset
    d_model=d_model,
    d_state=d_state,
    dt_rank=dt_rank,
    n_layers=n_layer,
)
model = MambaLM(mamba_config)
model.to(device)

# AdamW optimizer. The GradScaler is only enabled for float16; as configured
# above dtype is 'bfloat16' or 'float32', so loss scaling is a pass-through.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=learning_rate,
    betas=(beta1, beta2),
    weight_decay=weight_decay,
)
scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))

# Optional graph compilation (requires PyTorch 2.0+).
if compile:
    print("compiling the model... (takes a ~minute)")
    model = torch.compile(model)
# Distributed mode is detected from the RANK variable that torchrun sets.
ddp = int(os.environ.get('RANK', -1)) != -1
# Wrap model in DDP container if necessary
if ddp:
    # Previously ddp_local_rank was used without being defined (NameError) and the
    # process group was never initialised. Both are required before DDP(...).
    ddp_local_rank = int(os.environ['LOCAL_RANK'])
    torch.cuda.set_device(ddp_local_rank)  # 'cuda' now resolves to this rank's GPU
    init_process_group(backend=backend)
    # The model was created on the default GPU; move it to this rank's device.
    # Module.to() updates parameters in place, so the optimizer stays valid.
    model.to(f'cuda:{ddp_local_rank}')
    model = DDP(model, device_ids=[ddp_local_rank])
    # NOTE(review): the game/training loop below is not rank-aware (every rank
    # plays every opening and writes checkpoints); confirm before using DDP.
# Rolling windows of recent game scores (restored from the checkpoint on resume).
win_rate_window = []
win_only_rate_window = []
# Load checkpoint if resuming training
if init_from == 'resume':
    print(f"Resuming training from {out_dir}")
    ckpt_path = os.path.join(out_dir, 'ckpt.pt')
    # weights_only=False is needed because the checkpoint pickles the model config
    # object ('model_args'); PyTorch 2.6+ defaults to weights_only=True and would
    # refuse it. This unpickles arbitrary objects -- only resume from checkpoints
    # you produced yourself.
    checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)
    mamba_config = checkpoint['model_args']
    state_dict = checkpoint['model']
    # Strip the '_orig_mod.' key prefix. Presumably it appears when a
    # torch.compile'd model (compile=True) is saved, since the compiled wrapper
    # nests the original module under that attribute name.
    unwanted_prefix = '_orig_mod.'
    for k in list(state_dict):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    # Load into the innermost module. By now `model` may be wrapped by DDP
    # (.module) and/or torch.compile (._orig_mod), and the stripped keys only
    # match the unwrapped model.
    target_model = model.module if ddp else model
    target_model = getattr(target_model, '_orig_mod', target_model)
    target_model.load_state_dict(state_dict)
    optimizer.load_state_dict(checkpoint['optimizer'])
    iter_num = checkpoint['iter_num']
    games_played = checkpoint['games_seen']
    opening_line_index = checkpoint.get('opening_line_index', 0)
    win_rate_window = checkpoint.get('win_rate_window', [])
    win_only_rate_window = checkpoint.get('win_only_rate_window', [])
    best_wr = checkpoint.get('best_wr', 0.0)
    best_wor = checkpoint.get('best_wor', 0.0)
    if auto_clip:
        # Continue with the auto-tuned clip value and the norm history behind it.
        grad_clip = checkpoint['config']['grad_clip']
        config['grad_clip'] = grad_clip
        grad_norm_history = checkpoint.get('grad_norm_history', [])
    else:
        grad_norm_history = []
else:
    best_wr = 0.0
    best_wor = 0.0
    grad_norm_history = []
    games_played = 0
    iter_num = 0
    opening_line_index = 0
    if auto_clip:
        # 0 means "not tuned yet"; backward_pass clips at a fallback value until
        # enough grad-norm history exists.
        grad_clip = 0
        config['grad_clip'] = 0
def get_model_move(game_state, top_k):
    """Sample the model's next move for the transcript game_state.

    Tokens are sampled autoregressively (at most 8) with the model in train
    mode and autograd enabled, so the returned logits can feed the policy loss.

    Args:
        game_state: Game transcript so far. The caller ends it with the '.' that
            prompts a White move.
        top_k: If not None and smaller than the vocabulary, sample only from the
            top_k logits (used mainly to limit VRAM).

    Returns:
        (move, logits). move is the first whitespace-delimited token generated
        after game_state (text before any ';'), or None if nothing usable was
        generated. logits stacks, per sampled token, the logits it was sampled
        from (top_k of them when top-k is active), or is None if no token was kept.
    """
    model.train()  # Ensure the model is in training mode
    encoded_prompt = encode(game_state)
    input_ids = torch.tensor([encoded_prompt], dtype=torch.long, device=device)
    # After the first token, a space or '.' ends the move. Look the ids up rather
    # than hard-coding them: in the built-in vocabulary ' ' is 0 and '.' is 1, and
    # the old hard-coded 4 is 'c' there, which cut moves like "Nc3" down to "N".
    # (4 was presumably '.' in the meta.pkl vocabulary.)
    stop_ids = {stoi[ch] for ch in (' ', '.') if ch in stoi}
    logits_list = []  # Collect logits for the loss calculation
    for _ in range(8):
        logits = model(input_ids)[0, -1, :]  # Logits for the last predicted token
        # We're using top-k more as a VRAM control, not a decision enhancing tool
        if top_k is not None and top_k < logits.size(-1):
            logits, indices = torch.topk(logits, top_k)
            probs = F.softmax(logits, dim=-1)
            next_token_id = indices[torch.multinomial(probs, 1)]
        else:
            probs = F.softmax(logits, dim=-1)
            next_token_id = torch.multinomial(probs, num_samples=1)
        # The first sampled token is always kept; afterwards a separator ends the move.
        if logits_list and next_token_id.item() in stop_ids:
            break
        input_ids = torch.cat([input_ids, next_token_id.unsqueeze(0)], dim=1)
        logits_list.append(logits)
        del logits, probs
    # Decode the sequence to extract the move
    model_response = decode(input_ids.squeeze(0).tolist())
    try:
        move = model_response[len(game_state):].split(";")[0].split()[0]  # Extract the first move
    except IndexError:
        move = None
    return move, torch.stack(logits_list) if logits_list else None
def get_lc0_move(board, backend):
    """Return lc0's highest-policy move for board, or None if there is none.

    Args:
        board: chess.Board holding the current position.
        backend: lc0 Backend used to evaluate the position.

    Returns:
        chess.Move with the largest softmaxed policy probability among the legal
        moves lc0 reports, or None when lc0 reports no moves or no policy values.
    """
    gamestate = GameState(fen=board.fen())
    moves = gamestate.moves()
    # Previously a bare `except:` around argmax() covered this case (argmax of an
    # empty array raises ValueError); check explicitly instead, before evaluating.
    if not moves:
        return None
    input_planes = gamestate.as_input(backend)
    result = backend.evaluate(input_planes)[0]
    policy_indices = gamestate.policy_indices()
    move_probs = np.array(result.p_softmax(*policy_indices))
    if move_probs.size == 0:
        return None
    best_move = moves[int(move_probs.argmax())]  # UCI string
    return chess.Move.from_uci(best_move)
def evaluate_position(fen, backend):
    """Return lc0's Q value for the position described by fen.

    The training loop negates this value after the model's move, which suggests
    Q is from the side to move's perspective (confirm against lc0 docs).
    """
    planes = GameState(fen=fen).as_input(backend)
    evaluations = backend.evaluate(planes)
    return evaluations[0].q()
def reward_from_eval(before_eval, after_eval, scale_factor=None):
    """Squash the change in evaluation caused by a move into a bounded reward.

    reward = diff / (scale_factor + |diff|), where diff = after_eval - before_eval.
    For scale_factor > 0 this is a soft-sign that lies strictly in (-1, 1), so
    even a very bad legal move scores above the -1 given to illegal moves.

    Args:
        before_eval: Evaluation before the move.
        after_eval: Evaluation after the move, from the same perspective.
        scale_factor: Softness of the squashing; larger values shrink rewards.
            Defaults to the module-level move_reward_scale_factor.

    Returns:
        The reward as a float.
    """
    if scale_factor is None:
        scale_factor = move_reward_scale_factor
    diff = after_eval - before_eval
    return diff / (scale_factor + abs(diff))
def backward_pass(loss):
    """Backpropagate loss, optionally clip gradients, and step the optimizer.

    Clipping runs whenever grad_clip is non-zero or auto_clip is on. The
    pre-clip total gradient norm is appended to the global grad_norm_history,
    which keeps only the last grad_clip_max_size entries. Gradients are cleared
    afterwards.
    """
    global grad_norm_history
    scaler.scale(loss).backward()
    should_clip = auto_clip or grad_clip != 0.0
    if should_clip:
        scaler.unscale_(optimizer)  # clip on real (unscaled) gradient values
        # grad_clip is 0 when auto_clip is on but has too little history; use 0.1 then.
        clip_value = 0.1 if grad_clip == 0.0 else grad_clip
        norm = torch.nn.utils.clip_grad_norm_(model.parameters(), clip_value)
        grad_norm_history.append(norm.item())
        grad_norm_history = grad_norm_history[-grad_clip_max_size:]
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad(set_to_none=True)
def _pinch_hit(board, game_state):
    """Play lc0's move for White in place of an unusable model move.

    Pushes the opponent backend's move onto board and returns game_state with
    its SAN plus a trailing space appended. Returns None (board unchanged) when
    lc0 supplies no move.
    """
    pinch_hit_move = get_lc0_move(board, lc0_backend_opponent)
    if pinch_hit_move is None:
        print("Failed game (lc0 couldn't find pinch-hit move)")
        return None
    game_state += f"{board.san(pinch_hit_move)} "
    board.push(pinch_hit_move)
    return game_state


def play_game():
    """Play one game as White against lc0 from the current opening and train on it.

    Replays the global opening_line (set by the training loop) on a fresh board.
    Then, each turn:
      * the model samples a move;
      * if the move is missing, unparsable, or leaves an invalid position, it
        counts as illegal and lc0 plays White's move instead (a "pinch hit");
      * lc0 replies as Black.
    Each turn adds sum(log_softmax(logits)) * move_reward to the loss. A
    backward pass runs every update_freq moves, plus a final result-scaled one.
    The global top_k is lowered by one once the game reaches top_k_adj_moves
    moves, and restored before returning.

    Returns:
        (avg_reward / scale_factor, result, illegal_moves, move_count).
        result is a PGN result string, or "*" for a failed game. A White win that
        needed pinch hits is reported as "1/2-1/2" (at most max_illegal_moves
        illegal moves) or "0-1" (more than that).
    """
    global top_k
    optimizer.zero_grad(set_to_none=True)
    torch.cuda.empty_cache()
    board = chess.Board()
    total_loss = 0
    illegal_moves = 0
    move_count = 0
    moves_since_backward = 0
    tot_reward = 0
    top_k_reduced = False  # whether this game lowered the global top_k
    fail = False

    # Replay the opening; "1.e4"-style tokens lose their move-number prefix.
    tokens = [m.split(".")[-1] if "." in m else m for m in opening_line.split()]
    for san in tokens:
        board.push_san(san)
    if move_num_in_gamestate:
        game_state = opening_line.rstrip() + " "
    else:
        # ".e4 e5 .Nf3 Nc6 " -- the trailing space matches what the loop appends
        # after every move; without it the first prompt ended "Nc6." not "Nc6 .".
        game_state = ' '.join('.' + m.split(".")[-1] if "." in m else m for m in opening_line.split()) + " "

    while not board.is_game_over():
        before_eval = evaluate_position(board.fen(), lc0_backend_evaluator)
        game_state += f"{board.fullmove_number if move_num_in_gamestate else ''}."
        model_move, logits = get_model_move(game_state, top_k)
        move_reward = -1  # reward for an unusable move; replaced if the move is legal

        # Keep the model's move only if it parses as legal SAN and leaves a valid
        # position. Its text is appended only once accepted, so a rejected move
        # never leaks into the transcript ahead of the pinch-hit move.
        model_move_ok = False
        if model_move is not None and logits is not None:
            try:
                board.push(board.parse_san(model_move))
            except ValueError:  # python-chess signals bad/illegal/ambiguous SAN this way
                pass
            else:
                if board.is_valid():
                    model_move_ok = True
                else:
                    board.pop()

        if model_move_ok:
            game_state += f"{model_move} "
            after_eval = -evaluate_position(board.fen(), lc0_backend_evaluator)
            move_reward = reward_from_eval(before_eval, after_eval)
            tot_reward += move_reward
        else:
            illegal_moves += 1
            # NOTE(review): a rescued illegal move adds nothing to tot_reward (only
            # the failed-game path below adds the -1); confirm this is intended.
            new_state = _pinch_hit(board, game_state)
            if new_state is None:
                fail = True
                tot_reward += move_reward
                move_count += 1
                break
            game_state = new_state

        # lc0 replies as Black.
        if not board.is_game_over():
            black_move = get_lc0_move(board, lc0_backend_opponent)
            if black_move is None:
                print("Failed game (lc0 couldn't find black move)")
                fail = True
                move_count += 1
                break
            game_state += f"{board.san(black_move)} "
            board.push(black_move)

        have_logits = logits is not None
        if have_logits:
            total_loss += torch.sum(F.log_softmax(logits, dim=-1) * move_reward)
        del logits
        moves_since_backward += 1
        if move_count % update_freq == 0 and not board.is_game_over() and have_logits:
            backward_pass(total_loss / moves_since_backward)
            total_loss = 0.0  # Reset cumulative loss after update
            moves_since_backward = 0
        move_count += 1
        if move_count == top_k_adj_moves:
            top_k -= 1
            top_k_reduced = True
        if move_count >= max_moves:
            break
        if move_count % flush_every == 0:
            torch.cuda.empty_cache()

    # Restore top_k only if this game lowered it. Previously it was restored
    # whenever move_count >= top_k_adj_moves, which drifted top_k upward when a
    # failed-game break crossed the threshold without the decrement.
    if top_k_reduced:
        top_k += 1

    # Scale loss based on game result and illegal moves
    avg_reward = tot_reward / move_count if move_count else 0.0
    scale_factor = 1.0
    if move_count >= max_moves:
        result = "1/2-1/2"
    elif fail:
        result = "*"
    else:
        result = board.result()

    # moves_since_backward is 0 when a failed-game break follows a backward pass
    # (or no move was played); total_loss is then a plain number with nothing to average.
    if moves_since_backward:
        total_loss = total_loss / moves_since_backward

    if result == "0-1":  # Black wins
        # Increase the loss for a loss, if the reward is negative (the loss is positive)
        if avg_reward < 0 and illegal_moves <= max_illegal_moves:
            scale_factor = 1.0 / decrease_factor
    elif result == "1-0":  # White wins
        wdf = decrease_factor / 2.0 if avg_reward <= 0 else 1.0 / decrease_factor
        # Don't update as much for (real) wins. Also change the result of wins that
        # needed pinch hits so our win_rate isn't inflated.
        if illegal_moves == 0:
            scale_factor = wdf
        elif illegal_moves <= max_illegal_moves:
            scale_factor = (1 + wdf) / 2
            result = "1/2-1/2"
        else:
            result = "0-1"  # No adjustment to scale_factor

    # Only a tensor carries a graph to backpropagate; previously a float
    # total_loss crashed on .numel().
    if torch.is_tensor(total_loss):
        try:
            backward_pass(total_loss * scale_factor)
        except RuntimeError:
            print("Failed game (final backward pass, result not effected)")
    return avg_reward / scale_factor, result, illegal_moves, move_count
def get_lr(it):
    """Learning rate for iteration `it`.

    Schedule: linear warmup from 0 to learning_rate over warmup_iters, cosine
    decay down to min_lr by lr_decay_iters, then constant min_lr.
    """
    if it < warmup_iters:
        return learning_rate * it / warmup_iters
    if it > lr_decay_iters:
        return min_lr
    progress = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    assert 0 <= progress <= 1
    cosine_weight = 0.5 * (1.0 + math.cos(math.pi * progress))  # goes from 1 to 0
    return min_lr + (learning_rate - min_lr) * cosine_weight
def _make_checkpoint():
    """Snapshot everything needed to resume training (read from module globals)."""
    raw_model = model.module if ddp else model
    return {
        'model': raw_model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'model_args': mamba_config,
        'iter_num': iter_num,
        "games_seen": games_played,
        'config': config,
        'opening_line_index': opening_line_index,
        'grad_norm_history': grad_norm_history,
        'win_rate_window': win_rate_window,
        # Previously missing from the best-model checkpoints, so resuming from one
        # silently reset this window.
        'win_only_rate_window': win_only_rate_window,
        'best_wr': best_wr,
        'best_wor': best_wor,
    }


def _save_checkpoint(filename):
    """Save a checkpoint as out_dir/filename, creating out_dir if needed."""
    print(f"saving checkpoint to {out_dir}\n")
    os.makedirs(out_dir, exist_ok=True)
    torch.save(_make_checkpoint(), os.path.join(out_dir, filename))


# Training loop: one game per opening line, until the openings are exhausted.
if wandb_log:
    import wandb
    wandb.init(project=wandb_project, name=wandb_run_name, config=config)
while True:
    t0 = time.time()
    lr = get_lr(iter_num)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    opening_line = opening_lines[opening_line_index]  # read by play_game()
    if iter_num > 0 and iter_num % save_interval == 0:
        if auto_clip and len(grad_norm_history) >= grad_clip_start_size:
            # Re-tune the clip threshold to a low percentile of recent grad norms.
            grad_clip = max(min(np.percentile(grad_norm_history, grad_clip_percentile), max_grad_clip), min_grad_clip)
            config['grad_clip'] = grad_clip
            print(f"Auto adjusted grad_clip to {grad_clip}")
        if wandb_log:
            wandb.log({
                "etc/iter": iter_num,
                "etc/lr": lr,
                "etc/grad_clip": grad_clip,
                "etc/games_played": games_played,
            })
        _save_checkpoint('ckpt.pt')
    # Play a game against lc0 engine (backward passes happen inside play_game)
    game_reward, result, illegal_moves, move_count = play_game()
    games_played += 1
    dt = time.time() - t0
    # Log game result and update win rate windows
    score = 0.5
    if result == "1-0":
        score = 1
    elif result == "0-1":
        score = 0
    if result != "*":  # failed games don't enter the windows
        win_rate_window.append(score)
        win_rate_window = win_rate_window[-window_size:]
        win_only_rate_window.append(int(score))  # int to discard draws
        win_only_rate_window = win_only_rate_window[-window_size:]
    # Computed unconditionally, so a failed first game can't leave the rates undefined.
    win_rate = sum(win_rate_window) / len(win_rate_window) if win_rate_window else 0.0
    win_only_rate = float(sum(win_only_rate_window)) / len(win_only_rate_window) if win_only_rate_window else 0.0
    if win_rate > best_wr:
        best_wr = win_rate
        _save_checkpoint(f'ckpt_{games_played}g_wr{best_wr}.pt')
    elif win_only_rate > best_wor:
        best_wor = win_only_rate
        _save_checkpoint(f'ckpt_{games_played}g_wor{best_wor}.pt')
    best_wor = max(best_wor, win_only_rate)
    illegal_move_pct = illegal_moves / move_count if move_count else 0.0
    print(f"Game {games_played} ({iter_num}, {(iter_num / len(opening_lines)) * 100.0:.3f}%): Score {score}, Reward {game_reward:.4f}, Illegal moves {illegal_moves} ({illegal_move_pct:.3%}), Total moves {move_count}, Win rate {win_rate:.3f}, Win only rate {win_only_rate:.3f}, time {dt * 1000:.2f}ms")
    if wandb_log:
        wandb.log({
            "etc/iter": iter_num,
            "etc/lr": lr,
            "etc/grad_norm_mean": np.mean(grad_norm_history) if grad_norm_history else -1,
            # Compare as an array: `list == 0` is a single False, so this was always 0.
            "etc/grad_zero_pct": float(np.count_nonzero(np.asarray(grad_norm_history) == 0)) / len(grad_norm_history) if grad_norm_history else -1,
            "etc/games_played": games_played,
            "eval/game_reward": game_reward,
            "eval/illegal_move_pct": illegal_move_pct,
            "eval/move_ct": move_count,
            "eval/win_rate": win_rate,
            "eval/win_only_rate": win_only_rate,
        })
    iter_num += 1
    opening_line_index += 1
    # Termination condition
    if opening_line_index >= len(opening_lines):
        break
if ddp:
    destroy_process_group()