# MambaMate-Micro / train_rl.py
import math
import os
import time
import pickle
from contextlib import nullcontext
import numpy as np
import torch
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
from mamba_lm import MambaLM, MambaLMConfig
import random
import chess
from lczero.backends import Weights, Backend, GameState
# Default config values
out_dir = 'out/play'
save_interval = 50
init_from = 'resume' # 'scratch', 'resume', 'anneal', or Mamba model name
# Model parameters
n_layer = 15
d_model = 256
dt_rank = 'auto'
d_state = 16
vocab_size = 28
move_num_in_gamestate = False
# wandb logging
wandb_log = True
wandb_project = 'mamba-rl'
wandb_run_name = 'mamba_run'
# Load openings file
with open("openings.csv", "r") as file:
    lines = file.readlines()[1:]  # Skip header
opening_lines = lines
# Optimizer settings
learning_rate = 1e-7 #7.25e-7
min_lr = 1e-8 # 1.75e-8
warmup_iters = 600
lr_decay_iters = len(opening_lines)
weight_decay = 1e-2 #5e-3
beta1 = 0.905 #0.915
beta2 = 0.965 #0.95
grad_clip = 0.5 #0.25
min_grad_clip = 1e-3 #1e-3
max_grad_clip = 0.45 #0.45
auto_clip = True
grad_clip_start_size = 150
grad_clip_max_size = 600
grad_clip_percentile = 9
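# Note: with auto_clip enabled, grad_clip is re-derived every save_interval iterations (see the
# training loop below) as the grad_clip_percentile'th percentile of the most recent gradient norms
# (up to grad_clip_max_size, once at least grad_clip_start_size have been recorded), clamped to
# [min_grad_clip, max_grad_clip].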
# Game play / loss calculation settings
top_k = 2 # 2
top_k_adj_moves = 40 #999 #35
max_illegal_moves = 8 #2
max_moves = 87
update_freq = 3 #1 # How often to do a backward pass
flush_every = 1
move_reward_scale_factor = 4.0 # 2.125 # scales down eval-based move rewards so that illegal moves (reward -1) stand out against merely bad moves
decrease_factor = 0.75 # Bonus for winning (1/x is penalty for losing)
window_size = 300
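# Worked example of decrease_factor (applied in play_game's final loss scaling): with
# decrease_factor = 0.75, a lost game with negative average reward (and few enough illegal moves)
# scales the loss by 1 / 0.75 ≈ 1.33; a won game with no illegal moves scales it by 0.75 / 2 = 0.375
# when the average reward is <= 0, or by 1 / 0.75 ≈ 1.33 when it is positive.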
# DDP settings
backend = 'nccl'
# System
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dtype = 'bfloat16' if torch.cuda.is_bf16_supported() else 'float32'
compile = False # Set to True if using PyTorch 2.0
config_keys = [k for k, v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
#exec(open('configurator.py').read()) # overrides from command line or config file
config = {k: globals()[k] for k in config_keys} # will be useful for logging
# Initialize lc0 engines
lc0_weights_opponent = Weights("./lc0/build/release/11258-32x4-se.pb.gz")
lc0_backend_opponent = Backend(weights=lc0_weights_opponent)
lc0_weights_evaluator = Weights("./lc0/build/release/11258-48x5-se.pb.gz")
lc0_backend_evaluator = lc0_backend_opponent #Backend(weights=lc0_weights_evaluator)
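# Note: both roles currently share the smaller 32x4 backend; the 48x5 evaluator weights are loaded
# above but unused while the Backend(weights=lc0_weights_evaluator) call stays commented out.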
# Load tokenizer and decode function
if move_num_in_gamestate:
    meta_path = os.path.join(os.path.join('data', 'chess'), 'meta.pkl')
    with open(meta_path, "rb") as f:
        meta = pickle.load(f)
    stoi, itos = meta["stoi"], meta["itos"]
    vocab_size = meta['vocab_size']
    encode = lambda s: [stoi[c] for c in s]
    decode = lambda l: "".join([itos[i] for i in l])
else:
    stoi = {' ': 0, '.': 1, 'a': 2, 'b': 3, 'c': 4, 'd': 5, 'e': 6, 'f': 7, 'g': 8, 'h': 9, '1': 10, '2': 11, '3': 12, '4': 13, '5': 14, '6': 15, '7': 16, '8': 17, 'B': 18, 'N': 19, 'R': 20, 'Q': 21, 'K': 22, 'O': 23, 'x': 24, '+': 25, '#': 26, '=': 27}
    itos = {0: ' ', 1: '.', 2: 'a', 3: 'b', 4: 'c', 5: 'd', 6: 'e', 7: 'f', 8: 'g', 9: 'h', 10: '1', 11: '2', 12: '3', 13: '4', 14: '5', 15: '6', 16: '7', 17: '8', 18: 'B', 19: 'N', 20: 'R', 21: 'Q', 22: 'K', 23: 'O', 24: 'x', 25: '+', 26: '#', 27: '='}
    for s in stoi:
        assert itos[stoi[s]] == s
    vocab_size = len(stoi)
    print(f"Vocab size {vocab_size}")
    encode = lambda s: [stoi[c] for c in s.replace('-', '')]
    decode = lambda l: "".join([itos[i] for i in l]).replace("OOO", "O-O-O").replace("OO", "O-O")
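# Quick round-trip example with the hardcoded vocab above (illustrative only):
#   encode("e4") == [6, 13]
#   encode("O-O") == [23, 23]   # hyphens are stripped before lookup
#   decode([23, 23]) == "O-O"   # castling hyphens are restored on decode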
# Initialize Mamba model
mamba_config = MambaLMConfig(
    d_model=d_model,
    n_layers=n_layer,
    dt_rank=dt_rank,
    d_state=d_state,
    vocab_size=vocab_size  # Adjust based on your dataset
)
model = MambaLM(mamba_config)
model.to(device)
# Optimizer and GradScaler
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2))
scaler = torch.cuda.amp.GradScaler(enabled=dtype == 'float16')
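# Note: dtype only controls the GradScaler here (enabled solely for float16); no autocast context
# is applied around the forward pass, so under bfloat16/float32 scaler.scale()/unscale_() act as
# pass-throughs and backward_pass below behaves like plain full-precision training.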
# Compile the model if using PyTorch 2.0
if compile:
print("compiling the model... (takes a ~minute)")
model = torch.compile(model)
ddp = int(os.environ.get('RANK', -1)) != -1
# Wrap model in DDP container if necessary (assumes a torchrun-style launch that sets RANK/LOCAL_RANK)
if ddp:
    init_process_group(backend=backend)
    ddp_local_rank = int(os.environ['LOCAL_RANK'])
    device = f'cuda:{ddp_local_rank}'
    model.to(device)
    model = DDP(model, device_ids=[ddp_local_rank])
win_rate_window = []
win_only_rate_window = []
# Load checkpoint if resuming training
if init_from == 'resume':
print(f"Resuming training from {out_dir}")
ckpt_path = os.path.join(out_dir, 'ckpt.pt')
checkpoint = torch.load(ckpt_path, map_location=device)
mamba_config = checkpoint['model_args']
state_dict = checkpoint['model']
# fix the keys of the state dictionary :(
# honestly no idea how checkpoints sometimes get this prefix, have to debug more
unwanted_prefix = '_orig_mod.'
for k, v in list(state_dict.items()):
if k.startswith(unwanted_prefix):
state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
model.load_state_dict(state_dict)
optimizer.load_state_dict(checkpoint['optimizer'])
iter_num = checkpoint['iter_num']
games_played = checkpoint['games_seen']
opening_line_index = checkpoint.get('opening_line_index', 0)
win_rate_window = checkpoint.get('win_rate_window', [])
win_only_rate_window = checkpoint.get('win_only_rate_window', [])
best_wr = checkpoint.get('best_wr', 0.0)
best_wor = checkpoint.get('best_wor', 0.0)
if auto_clip:
grad_clip = checkpoint['config']['grad_clip']
config['grad_clip'] = grad_clip
grad_norm_history = checkpoint.get('grad_norm_history', [])
else:
grad_norm_history = []
else:
best_wr = 0.0
best_wor = 0.0
grad_norm_history = []
games_played = 0
iter_num = 0
opening_line_index = 0
if auto_clip:
grad_clip = 0
config['grad_clip'] = 0
def get_model_move(game_state, top_k):
    model.train()  # Ensure the model is in training mode
    encoded_prompt = encode(game_state)
    input_ids = torch.tensor([encoded_prompt], dtype=torch.long, device=device)
    have_non_space = False
    logits_list = []  # Collect logits for analysis and potential loss calculation
    for _ in range(8):
        logits = model(input_ids)[0, -1, :]  # Logits for the last predicted token
        # We're using top-k more as a VRAM control, not a decision-enhancing tool
        if top_k is not None and top_k < logits.size(-1):
            logits, indices = torch.topk(logits, top_k)
            probs = torch.nn.functional.softmax(logits, dim=-1)
            next_token_id = indices[torch.multinomial(probs, 1)]
        else:
            probs = torch.nn.functional.softmax(logits, dim=-1)
            next_token_id = torch.multinomial(probs, num_samples=1)
        if have_non_space and (next_token_id == 0 or next_token_id == 4):
            break
        else:
            have_non_space = True
            input_ids = torch.cat([input_ids, next_token_id.unsqueeze(0)], dim=1)
            logits_list.append(logits)
            del logits, probs
    # Decode the sequence to extract the move
    model_response = decode(input_ids.squeeze(0).tolist())
    try:
        move = model_response[len(game_state):].split(";")[0].split()[0]  # Extract the first move
    except IndexError:
        move = None
    return move, torch.stack(logits_list) if len(logits_list) > 0 else None
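# The opponent / pinch-hitter move below is lc0's raw policy argmax for the position
# (a single network evaluation, no search tree).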
def get_lc0_move(board, backend):
    gamestate = GameState(fen=board.fen())
    input_planes = gamestate.as_input(backend)
    result = backend.evaluate(input_planes)[0]
    moves = gamestate.moves()
    policy_indices = gamestate.policy_indices()
    move_probs = np.array(result.p_softmax(*policy_indices))
    try:
        best_move_idx = move_probs.argmax()
    except Exception:
        return None
    best_move = moves[best_move_idx]
    return chess.Move.from_uci(best_move)
def evaluate_position(fen, backend):
    gamestate = GameState(fen=fen)
    result = backend.evaluate(gamestate.as_input(backend))[0]
    return result.q()
def reward_from_eval(before_eval, after_eval):
    diff = after_eval - before_eval
    return diff / (move_reward_scale_factor + abs(diff))
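# Worked example: lc0's q() lies in [-1, 1], so diff lies in [-2, 2]. With move_reward_scale_factor
# = 4.0, a +0.5 eval swing gives 0.5 / 4.5 ≈ 0.11 and even a full +2.0 swing only gives 2 / 6 ≈ 0.33,
# keeping per-move rewards small relative to the -1 illegal-move penalty.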
def backward_pass(loss):
    global grad_norm_history
    # Backward pass
    scaler.scale(loss).backward()
    # clip the gradient
    if grad_clip != 0.0 or auto_clip:
        scaler.unscale_(optimizer)
        total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip if grad_clip != 0.0 else 0.1)  # The 0 check is for auto_clip enabled but not enough history
        grad_norm_history.append(total_norm.item())
        grad_norm_history = grad_norm_history[-grad_clip_max_size:]
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad(set_to_none=True)
def play_game():
    global top_k
    optimizer.zero_grad(set_to_none=True)
    torch.cuda.empty_cache()
    board = chess.Board()
    total_loss = 0
    illegal_moves = 0
    move_count = 0
    moves_since_backward = 0
    tot_reward = 0
    # Load opening from openings.csv
    tokens = [m.split(".")[-1] if "." in m else m for m in opening_line.split()]
    [board.push_san(m) for m in tokens]
    if move_num_in_gamestate:
        game_state = opening_line.rstrip() + " "
    else:
        game_state = ' '.join(['.' + m.split(".")[-1] if "." in m else m for m in opening_line.split()])
    fail = False
    while not board.is_game_over():
        before_eval = evaluate_position(board.fen(), lc0_backend_evaluator)
        game_state += f"{board.fullmove_number if move_num_in_gamestate else ''}."
        model_move, logits = get_model_move(game_state, top_k)
        move_reward = -1
        if model_move is None or logits is None:
            illegal_moves += 1
            pinch_hit_move = get_lc0_move(board, lc0_backend_opponent)
            if pinch_hit_move is None:
                print("Failed game (lc0 couldn't find pinch-hit move)")
                fail = True
                tot_reward += move_reward
                move_count += 1
                break
            game_state += f"{board.san(pinch_hit_move)} "
            board.push(pinch_hit_move)
        else:
            try:
                #print(model_move)
                board.push(board.parse_san(model_move))
                game_state += f"{model_move} "
            except Exception:
                illegal_moves += 1
                pinch_hit_move = get_lc0_move(board, lc0_backend_opponent)
                if pinch_hit_move is None:
                    print("Failed game (lc0 couldn't find pinch-hit move)")
                    fail = True
                    tot_reward += move_reward
                    move_count += 1
                    break
                game_state += f"{board.san(pinch_hit_move)} "
                board.push(pinch_hit_move)
            else:
                if not board.is_valid():
                    board.pop()
                    illegal_moves += 1
                    pinch_hit_move = get_lc0_move(board, lc0_backend_opponent)
                    if pinch_hit_move is None:
                        print("Failed game (lc0 couldn't find pinch-hit move)")
                        fail = True
                        tot_reward += move_reward
                        move_count += 1
                        break
                    game_state += f"{board.san(pinch_hit_move)} "
                    board.push(pinch_hit_move)
                else:
                    # Negate: after the model's move it is the opponent's turn, so q() is from the other side's perspective
                    after_eval = -evaluate_position(board.fen(), lc0_backend_evaluator)
                    move_reward = reward_from_eval(before_eval, after_eval)
                    tot_reward += move_reward
        if not board.is_game_over():
            black_move = get_lc0_move(board, lc0_backend_opponent)
            if black_move is None:
                print("Failed game (lc0 couldn't find black move)")
                fail = True
                move_count += 1
                break
            game_state += f"{board.san(black_move)} "
            board.push(black_move)
        if logits is not None:
            # REINFORCE-style accumulation: weight the generated tokens' log-probabilities by the move reward
            total_loss += torch.sum(torch.nn.functional.log_softmax(logits, dim=-1) * move_reward)
        logits_none = logits is None
        del logits
        moves_since_backward += 1
        if move_count % update_freq == 0 and not board.is_game_over() and not logits_none:
            backward_pass(total_loss / moves_since_backward)
            total_loss = 0.0  # Reset cumulative loss after update
            moves_since_backward = 0
        move_count += 1
        if move_count == top_k_adj_moves:
            top_k = top_k - 1
        if move_count >= max_moves:
            break
        if move_count % flush_every == 0:
            torch.cuda.empty_cache()
    if move_count >= top_k_adj_moves:
        top_k = top_k + 1
    # Scale loss based on game result and illegal moves
    avg_reward = tot_reward / max(move_count, 1)  # Guard against an opening line that already ends the game
    #print(f'Avg reward {avg_reward} = {tot_reward} / {move_count}')
    scale_factor = torch.tensor([1.0], device=device)
    if move_count >= max_moves:
        result = "1/2-1/2"
    elif fail:
        result = "*"
    else:
        result = board.result()
    # Only the loss accumulated since the last in-game backward pass remains here; it may be empty
    if torch.is_tensor(total_loss) and moves_since_backward > 0:
        total_loss = total_loss / moves_since_backward
    else:
        total_loss = None
    if result == "0-1":  # Black wins
        # Increase the loss for a loss, if the reward is negative (if the loss is positive)
        scale_factor = torch.tensor([1.0 / decrease_factor], device=device) if avg_reward < 0 and illegal_moves <= max_illegal_moves else scale_factor
        #print(f'Black win, scale factor adjusted to {scale_factor} (avg award<0 and illegal less max {avg_reward < 0 and illegal_moves <= max_illegal_moves}), illegal vs max {illegal_moves} vs {max_illegal_moves}')
    elif result == "1-0":  # White wins
        wdf = decrease_factor / 2.0 if avg_reward <= 0 else 1.0 / decrease_factor
        #print(f'White win - adjusted decrease factor {wdf}')
        # Don't update as much for (real) wins. Also change the result so our win_rate isn't inflated.
        if illegal_moves == 0:
            scale_factor = torch.tensor([wdf], device=device)
            #print(f'White win, scale factor adjusted to {scale_factor} (0 illegal moves)')
        elif illegal_moves <= max_illegal_moves:
            scale_factor = torch.tensor([(1 + wdf) / 2], device=device)
            #print(f'White win, scale factor adjusted to {scale_factor} ({0 < illegal_moves <= max_illegal_moves}), illegal vs max {illegal_moves} vs {max_illegal_moves}')
            result = "1/2-1/2"
        else:
            result = "0-1"
            # No adjustment to scale_factor
    if total_loss is not None and total_loss.numel():
        try:
            backward_pass(total_loss * scale_factor)
        except Exception:
            print("Failed game (final backward pass, result not affected)")
    total_loss = 0.0
    #print(f'Scale factor {scale_factor.item()}')
    return avg_reward / scale_factor.item(), result, illegal_moves, move_count
def get_lr(it):
    # 1) linear warmup for warmup_iters steps
    if it < warmup_iters:
        return learning_rate * it / warmup_iters
    # 2) if it > lr_decay_iters, return min learning rate
    if it > lr_decay_iters:
        return min_lr
    # 3) in between, use cosine decay down to min learning rate
    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # coeff ranges 0..1
    return min_lr + coeff * (learning_rate - min_lr)
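# Example: with warmup_iters = 600, iteration 300 gets lr = 0.5 * learning_rate; halfway through
# the cosine decay (decay_ratio = 0.5, coeff = 0.5) the lr is the midpoint of learning_rate and min_lr.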
# Training loop
if wandb_log:
    import wandb
    wandb.init(project=wandb_project, name=wandb_run_name, config=config)
while True:
    t0 = time.time()
    lr = get_lr(iter_num)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    opening_line = opening_lines[opening_line_index]
    if iter_num > 0 and iter_num % save_interval == 0:
        if auto_clip and len(grad_norm_history) >= grad_clip_start_size:
            grad_clip = max(min(np.percentile(grad_norm_history, grad_clip_percentile), max_grad_clip), min_grad_clip)
            config['grad_clip'] = grad_clip
            print(f"Auto adjusted grad_clip to {grad_clip}")
        #print(f"Game {games_played}: Loss {game_reward:.4f}, Illegal moves {illegal_moves}, Win rate {win_rate:.3f}")
        if wandb_log:
            wandb.log({
                "etc/iter": iter_num,
                "etc/lr": lr,
                "etc/grad_clip": grad_clip,
                "etc/games_played": games_played,
            })
        # Save checkpoint
        raw_model = model.module if ddp else model
        checkpoint = {
            'model': raw_model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'model_args': mamba_config,
            'iter_num': iter_num,
            'games_seen': games_played,
            'config': config,
            'opening_line_index': opening_line_index,
            'grad_norm_history': grad_norm_history,
            'win_rate_window': win_rate_window,
            'win_only_rate_window': win_only_rate_window,
            'best_wr': best_wr,
            'best_wor': best_wor
        }
        print(f"saving checkpoint to {out_dir}\n")
        torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))
    # Play a game against the lc0 engine
    game_reward, result, illegal_moves, move_count = play_game()
    games_played += 1
    # Backward passes happen in play_game
    # Log game result and update win rate window
    t1 = time.time()
    dt = t1 - t0
    t0 = t1
    score = 0.5
    if result == "1-0":
        score = 1
    elif result == "0-1":
        score = 0
    if result != "*":
        win_rate_window.append(score)
        win_rate_window = win_rate_window[-window_size:]
        win_only_rate_window.append(int(score))  # int to discard draws
        win_only_rate_window = win_only_rate_window[-window_size:]
    win_rate = sum(win_rate_window) / len(win_rate_window) if win_rate_window else 0.0
    win_only_rate = float(sum(win_only_rate_window)) / len(win_only_rate_window) if win_only_rate_window else 0.0
    if win_rate > best_wr:
        best_wr = win_rate
        raw_model = model.module if ddp else model
        checkpoint = {
            'model': raw_model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'model_args': mamba_config,
            'iter_num': iter_num,
            'games_seen': games_played,
            'config': config,
            'opening_line_index': opening_line_index,
            'grad_norm_history': grad_norm_history,
            'win_rate_window': win_rate_window,
            'win_only_rate_window': win_only_rate_window,
            'best_wr': best_wr,
            'best_wor': best_wor
        }
        print(f"saving checkpoint to {out_dir}\n")
        torch.save(checkpoint, os.path.join(out_dir, f'ckpt_{games_played}g_wr{best_wr}.pt'))
    elif win_only_rate > best_wor:
        best_wor = win_only_rate
        raw_model = model.module if ddp else model
        checkpoint = {
            'model': raw_model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'model_args': mamba_config,
            'iter_num': iter_num,
            'games_seen': games_played,
            'config': config,
            'opening_line_index': opening_line_index,
            'grad_norm_history': grad_norm_history,
            'win_rate_window': win_rate_window,
            'win_only_rate_window': win_only_rate_window,
            'best_wr': best_wr,
            'best_wor': best_wor
        }
        print(f"saving checkpoint to {out_dir}\n")
        torch.save(checkpoint, os.path.join(out_dir, f'ckpt_{games_played}g_wor{best_wor}.pt'))
    best_wor = max(best_wor, win_only_rate)
    print(f"Game {games_played} ({iter_num}, {(iter_num / len(opening_lines)) * 100.0:.3f}%): Score {score}, Reward {game_reward:.4f}, Illegal moves {illegal_moves} ({illegal_moves / move_count:.3%}), Total moves {move_count}, Win rate {win_rate:.3f}, Win only rate {win_only_rate:.3f}, time {dt * 1000:.2f}ms")
    if wandb_log:
        wandb.log({
            "etc/iter": iter_num,
            "etc/lr": lr,
            "etc/grad_norm_mean": np.mean(grad_norm_history) if grad_norm_history else -1,
            "etc/grad_zero_pct": float(np.count_nonzero(np.asarray(grad_norm_history) == 0)) / len(grad_norm_history) if grad_norm_history else -1,
            "etc/games_played": games_played,
            "eval/game_reward": game_reward,
            "eval/illegal_move_pct": illegal_moves / move_count,
            "eval/move_ct": move_count,
            "eval/win_rate": win_rate,
            "eval/win_only_rate": win_only_rate,
        })
    iter_num += 1
    opening_line_index += 1
    # Termination condition
    if opening_line_index >= len(opening_lines):
        break
if ddp:
    destroy_process_group()