HaileyStorm committed
Commit
816f85a
1 Parent(s): e81ef6e

Upload 8 files

Mamba11M-move-dist.webp ADDED
Mamba11M-win-rate-detail.webp ADDED
Mamba11M-win-rate.webp ADDED
anneal_complete.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:974d6b8ee8bef4dc514705b9a673a2c26d7e3a571bc2aa486e0a5553645646d1
+ size 132459528
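The three lines above are a Git LFS pointer, not the checkpoint itself: they record only the pointer-spec version, the SHA-256 of the real file, and its size (~132 MB). A minimal sketch of fetching and inspecting the actual weights with huggingface_hub; the repo_id is a placeholder assumption, and the pickled MambaLMConfig means mamba_lm must be importable:

    # Sketch only: repo_id is hypothetical, substitute the actual repository.
    # The checkpoint pickles custom classes, so weights_only=False is needed on newer torch.
    import torch
    from huggingface_hub import hf_hub_download

    path = hf_hub_download(repo_id="HaileyStorm/chess-mamba", filename="anneal_complete.pt")
    ckpt = torch.load(path, map_location="cpu", weights_only=False)
    print(sorted(ckpt.keys()))  # e.g. ['best_val_loss', 'config', 'iter_num', 'model', ...]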
anneal_complete_pt_vs_stockfish_sweep.csv ADDED
The diff for this file is too large to render.
 
mamba_module.py ADDED
@@ -0,0 +1,143 @@
+ import os
+ import pickle
+ import torch
+ from mamba_lm import MambaLM, MambaLMConfig, from_pretrained
+ from contextlib import nullcontext
+
+ BASE_DIR = "mamba/"
+
+ class MambaPlayer:
+     def __init__(self, model_name: str):
+         self.model_name = model_name
+         # -----------------------------------------------------------------------------
+
+         init_from = "resume"  # either 'resume' or a Mamba variant (e.g. 'state-spaces/mamba-1.4b')
+         move_num_in_gamestate = True
+         out_dir = "out"  # ignored if init_from is not 'resume'
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+         #device = "cpu"
+         dtype = 'bfloat16' if torch.cuda.is_bf16_supported() else 'float32'
+         seed = 1337
+         compile = False  # set to True if using PyTorch 2.0 and Mamba supports it
+         # -----------------------------------------------------------------------------
+
+         torch.manual_seed(seed)
+         torch.cuda.manual_seed(seed)
+
+         device_type = (
+             "cuda" if "cuda" in device else "cpu"
+         )  # for later use in torch.autocast
+         ptdtype = {
+             "float32": torch.float32,
+             "bfloat16": torch.bfloat16,
+             "float16": torch.float16,
+         }[dtype]
+         ctx = (
+             nullcontext()
+             if device_type == "cpu"
+             else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
+         )
+
+         # Model initialization
+         if init_from == "resume":
+             #ckpt_path = os.path.join(BASE_DIR, out_dir, self.model_name)
+             ckpt_path = os.path.normpath(f"../../mamba.py/out/{self.model_name}")
+             checkpoint = torch.load(ckpt_path, map_location=device)
+             model_config = checkpoint["model_args"]
+             model = MambaLM(model_config)
+             model.load_state_dict(checkpoint['model'])
+         elif init_from.startswith('state-spaces'):
+             model = from_pretrained(init_from).to(device)
+         else:
+             raise ValueError("Invalid init_from value")
+
+         model.eval()
+         model.to(device)
+
+         if compile and hasattr(torch, 'compile'):
+             model = torch.compile(model)
+
+         # look for the meta pickle in case it is available in the dataset folder
+         meta_path = os.path.join(BASE_DIR, "out", "meta.pkl")
+         load_meta = os.path.exists(meta_path)
+         if move_num_in_gamestate and load_meta:
+             with open(meta_path, "rb") as f:
+                 meta = pickle.load(f)
+             stoi, itos = meta["stoi"], meta["itos"]
+             vocab_size = meta['vocab_size']
+             encode = lambda s: [stoi[c] for c in s]
+             decode = lambda l: "".join([itos[i] for i in l])
+         else:
+             stoi = {' ': 0, '.': 1, 'a': 2, 'b': 3, 'c': 4, 'd': 5, 'e': 6, 'f': 7, 'g': 8, 'h': 9, '1': 10, '2': 11, '3': 12, '4': 13, '5': 14, '6': 15, '7': 16, '8': 17, 'B': 18, 'N': 19, 'R': 20, 'Q': 21, 'K': 22, 'O': 23, 'x': 24, '+': 25, '#': 26, '=': 27}
+             itos = {0: ' ', 1: '.', 2: 'a', 3: 'b', 4: 'c', 5: 'd', 6: 'e', 7: 'f', 8: 'g', 9: 'h', 10: '1', 11: '2', 12: '3', 13: '4', 14: '5', 15: '6', 16: '7', 17: '8', 18: 'B', 19: 'N', 20: 'R', 21: 'Q', 22: 'K', 23: 'O', 24: 'x', 25: '+', 26: '#', 27: '='}
+             for s in stoi:
+                 assert itos[stoi[s]] == s
+             vocab_size = len(stoi)
+             print(f"Vocab size {vocab_size}")
+             encode = lambda s: [stoi[c] for c in s.replace('-', '')]
+             decode = lambda l: "".join([itos[i] for i in l]).replace("OOO", "O-O-O").replace("OO", "O-O")
+
+         self.encode = encode
+         self.decode = decode
+         self.model = model
+         self.ctx = ctx
+         self.device = device
+
+     def get_mamba_response(self, game_state: str, temperature: float, max_new_tokens: int, top_k: int):
+         game_state = game_state.split("\n\n")[-1].strip()
+         #game_state = ";" + game_state
+
+         # Tokenize the game state
+         encoded_prompt = self.encode(game_state)
+         input_ids = torch.tensor([encoded_prompt], dtype=torch.long, device=self.device)
+
+         self.model.eval()  # Set the model to evaluation mode
+         with torch.no_grad():
+             have_non_space = False
+             for _ in range(max_new_tokens):
+                 logits = self.model(input_ids)[0, -1, :]  # Get logits for the last token
+
+                 # Apply temperature scaling and optionally sample from the top k tokens
+                 logits = logits / temperature
+                 if top_k > 0:
+                     indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+                     logits[indices_to_remove] = -float('Inf')
+
+                 probs = torch.nn.functional.softmax(logits, dim=-1)
+                 next_token_id = torch.multinomial(probs, num_samples=1)
+                 if have_non_space and (next_token_id == 0 or next_token_id == 4):
+                     break
+                 else:
+                     have_non_space = True
+                 input_ids = torch.cat([input_ids, next_token_id.unsqueeze(0)], dim=1)
+
+         model_response = self.decode(input_ids[0].tolist())
+         model_response = model_response[len(game_state):].split(";")[0]
+         return model_response
+
+     #def encode(self, text: str):
+     #    Implement the appropriate tokenization for MambaLM
+     #    This could be a simple mapping or a more complex tokenizer
+     #    return [stoi[char] for char in text]  # Example
+
+     #def decode(self, token_ids: list):
+     #    Implement the appropriate decoding for MambaLM
+     #    return ''.join([itos[id] for id in token_ids])  # Example
+
+     def get_move_from_response(self, response: str) -> str:
+         if not response:
+             return None
+         # Parse the response to get only the first move
+         moves = response.split()
+         first_move = moves[0]
+         first_move = first_move.lstrip('.')  # A patch for a weird phase during training; doesn't seem to be an issue anymore, but it can't hurt.
+
+         return first_move
+
+     def get_move(self, board: str, game_state: str, temperature: float) -> str:
+         completion = self.get_mamba_response(game_state, temperature, 8, 32)
+         return self.get_move_from_response(completion)
+
+     def get_config(self) -> dict:
+         return {"model": self.model_name}
+
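A minimal usage sketch of the class above. This is not part of the commit: the checkpoint filename and movetext prompt are illustrative, and __init__ resolves checkpoints against the hard-coded relative path ../../mamba.py/out/, which must exist:

    # Hypothetical usage of MambaPlayer (sketch, not from this repo).
    from mamba_module import MambaPlayer

    player = MambaPlayer("anneal_complete.pt")  # loaded from ../../mamba.py/out/
    # get_move ignores `board` and completes the movetext prompt; the trailing "3."
    # asks the model for White's third move.
    move = player.get_move(None, "1.e4 e5 2.Nf3 Nc6 3.", temperature=0.2)
    print(move)  # e.g. "Bb5"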
train_bygame.py ADDED
@@ -0,0 +1,486 @@
+ import os
+ import time
+ import math
+ import pickle
+ from contextlib import nullcontext
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ from torch.nn.parallel import DistributedDataParallel as DDP
+ from torch.distributed import init_process_group, destroy_process_group
+ from mamba_lm import MambaLM, MambaLMConfig, from_pretrained
+ import pyarrow.parquet as pq
+ import random
+ from torch.utils.data import Dataset, DataLoader
+ import glob
+
+ # -----------------------------------------------------------------------------
+ # default config values designed for Mamba model training
+ # I/O
+ out_dir = 'out'
+ eval_interval = 2000
+ log_interval = 1
+ eval_iters = 5
+ eval_only = False
+ always_save_checkpoint = True
+ init_from = 'resume'  # 'scratch', 'resume', 'anneal', or Mamba model name
+ # wandb logging
+ wandb_log = False
+ wandb_project = 'mamba'
+ wandb_run_name = 'mamba_run'  # modify as needed
+ # data
+ dataset = 'chess'  # specify your dataset
+ gradient_accumulation_steps = 5 * 8
+ batch_size = 12
+ base_batch_size = batch_size
+ effective_batch_size = batch_size
+ max_seq_len = 1024  # a training-only parameter for controlling VRAM
+ train_file_update_interval = 7
+
+ # model
+ n_layer = 12
+ d_model = 768
+ dt_rank = 'auto'
+ d_state = 16
+ expand_factor = 2
+ bias = False
+ conv_bias = True
+ pscan = True
+ vocab_size = 32000
+ move_num_in_gamestate = True
+
+ # optimizer settings
+ learning_rate = 6e-4
+ max_iters = 600000
+ weight_decay = 1e-1
+ beta1 = 0.9
+ beta2 = 0.95
+ grad_clip = 1.0
+ auto_clip = False
+ grad_clip_start_size = 100
+ grad_clip_max_size = 500
+ grad_clip_percentile = 10
+ # learning rate decay settings
+ decay_lr = True
+ warmup_iters = 2000
+ lr_decay_iters = 600000
+ min_lr = 6e-5
+ # DDP settings
+ backend = 'nccl'
+ # system
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ dtype = 'bfloat16' if torch.cuda.is_bf16_supported() else 'float32'
+ compile = False  # set to True if using PyTorch 2.0
+ # -----------------------------------------------------------------------------
+
+ config_keys = [k for k, v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
+ exec(open('configurator.py').read())  # overrides from command line or config file
+ config = {k: globals()[k] for k in config_keys}  # will be useful for logging
+ # -----------------------------------------------------------------------------
+
+ anneal_checkpoint = 'anneal/ckpt.pt'  #'anneal_me.pt'
+ anneal_dir = os.path.join(out_dir, 'anneal/')
+ anneal_start_iters = None  # Set at init
+ anneal_decay_iters = None  # Set at init
+
+ mamba_config = MambaLMConfig(
+     d_model=d_model,  # adjust as needed
+     n_layers=n_layer,  # adjust as needed
+     dt_rank=dt_rank,
+     d_state=d_state,
+     expand_factor=expand_factor,
+     bias=bias,
+     conv_bias=conv_bias,
+     pscan=pscan,
+     vocab_size=vocab_size  # adjust based on your dataset
+ )
+
+ # DDP and other initializations
+ ddp = int(os.environ.get('RANK', -1)) != -1
+ if ddp:
+     init_process_group(backend=backend)
+     ddp_rank = int(os.environ['RANK'])
+     ddp_local_rank = int(os.environ['LOCAL_RANK'])
+     ddp_world_size = int(os.environ['WORLD_SIZE'])
+     device = f'cuda:{ddp_local_rank}'
+     torch.cuda.set_device(device)
+     master_process = ddp_rank == 0
+     seed_offset = ddp_rank
+     assert gradient_accumulation_steps % ddp_world_size == 0
+     gradient_accumulation_steps //= ddp_world_size
+ else:
+     master_process = True
+     seed_offset = 0
+     ddp_world_size = 1
+ tokens_per_iter = gradient_accumulation_steps * ddp_world_size * batch_size * max_seq_len
+
+ if master_process:
+     os.makedirs(out_dir, exist_ok=True)
+     os.makedirs(anneal_dir, exist_ok=True)
+ torch.manual_seed(1337 + seed_offset)
+ torch.backends.cuda.matmul.allow_tf32 = True
+ torch.backends.cudnn.allow_tf32 = True
+ device_type = 'cuda' if 'cuda' in device else 'cpu'
+ ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
+ ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
+
+ # poor man's data loader
+ data_dir = os.path.join('data', dataset)
+ current_train_file_index = 0
+ train_files = glob.glob(os.path.join(data_dir, 'train*.parquet'))
+ train_datasets = []
+ for f in train_files:
+     dataset = pq.read_table(f).to_pandas()
+     dataset = dataset[dataset['tokenized'].apply(len) >= 8]
+     train_datasets.append(dataset)
+ #val_data = pq.read_table(os.path.join(data_dir, 'val.parquet')).to_pandas()
+ #val_data = val_data[val_data['tokenized'].apply(len) >= 8]
+ truncated_games_count = 0
+ total_games_count = 0
+ games_seen = 0
+
+ def get_batch(split):
+     global truncated_games_count, total_games_count, current_train_file_index
+
+     # Randomly select batch_size games
+     dataset = train_datasets[current_train_file_index] if split == 'train' else None  # else val_data  # Use the correct DataFrame based on the split
+     sample_df = dataset.sample(batch_size)
+     games = sample_df['tokenized'].tolist()
+
+     # Prepare sequences tensor for the batch
+     max_length_in_batch = min(max(len(game) for game in games), max_seq_len)
+     sequences = torch.zeros((batch_size, max_length_in_batch), dtype=torch.int64)
+
+     for i, game in enumerate(games):
+         total_games_count += 1
+
+         if len(game) > max_seq_len:
+             truncated_games_count += 1
+             # Randomly decide truncation strategy
+             truncation_choice = random.choice(['beginning', 'end', 'end2', 'random'])
+             if truncation_choice == 'beginning':
+                 # Truncate the beginning (use from the end backward)
+                 truncated_game = game[-max_seq_len:]
+             elif truncation_choice.startswith('end'):
+                 # Truncate the end (use from the beginning forward)
+                 truncated_game = game[:max_seq_len]
+             else:
+                 # Random start index (truncate beginning and end)
+                 start_idx = random.randint(0, len(game) - max_seq_len)
+                 truncated_game = game[start_idx:start_idx + max_seq_len]
+             sequences[i, :len(truncated_game)] = torch.tensor(truncated_game, dtype=torch.int64)
+             # Report the percentage of truncated games
+             if truncated_games_count > 0 and truncated_games_count % 50 == 0:
+                 truncated_percentage = (truncated_games_count / total_games_count) * 100
+                 print(f"Percentage of truncated games: {truncated_percentage:.2f}%\t\t({truncated_games_count}/{total_games_count})")
+         else:
+             sequences[i, :len(game)] = torch.tensor(game, dtype=torch.int64)
+
+     if (total_games_count // batch_size) % train_file_update_interval == 0:
+         current_train_file_index = random.randint(0, len(train_files) - 1)
+         # print(f"Switched to file: {train_files[current_train_file_index]}")
+
+     if device_type == 'cuda':
+         sequences = sequences.pin_memory().to(device, non_blocking=True)
+     else:
+         sequences = sequences.to(device)
+
+     return sequences
+
+ # init these up here, can override if init_from='resume' (i.e. from a checkpoint)
+ iter_num = 0
+ best_val_loss = 1e9
+ grad_norm_history = []  # default; overwritten on resume, used by auto_clip
+
+ # attempt to derive vocab_size from the dataset
+ meta_path = os.path.join(data_dir, 'meta.pkl')
+ meta_vocab_size = None
+ if not move_num_in_gamestate:
+     meta_vocab_size = 28
+ elif os.path.exists(meta_path):
+     with open(meta_path, 'rb') as f:
+         meta = pickle.load(f)
+     meta_vocab_size = meta['vocab_size']
+     print(f"found vocab_size = {meta_vocab_size} (inside {meta_path})")
+
+ # Model initialization
+ if init_from == 'scratch':
+     print("Initializing a new Mamba model from scratch")
+     if meta_vocab_size is None:
+         print(f"defaulting to vocab_size of {vocab_size}")
+     else:
+         mamba_config.vocab_size = meta_vocab_size
+     model = MambaLM(mamba_config)
+     if auto_clip:
+         grad_clip = 0
+         config['grad_clip'] = 0
+         grad_norm_history = []
+ elif init_from == 'resume' or init_from == 'anneal':
+     print(f"Resuming training from {out_dir}")
+     if init_from == 'anneal':
+         ckpt_path = os.path.join(out_dir, anneal_checkpoint)
+     else:
+         ckpt_path = os.path.join(out_dir, 'ckpt.pt')
+     checkpoint = torch.load(ckpt_path, map_location=device)
+     mamba_config = checkpoint['model_args']
+     model = MambaLM(mamba_config)
+     state_dict = checkpoint['model']
+     # fix the keys of the state dictionary :(
+     # honestly no idea how checkpoints sometimes get this prefix, have to debug more
+     unwanted_prefix = '_orig_mod.'
+     for k, v in list(state_dict.items()):
+         if k.startswith(unwanted_prefix):
+             state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
+     model.load_state_dict(state_dict)
+     if 'effective_batch_size' not in checkpoint['config']:
+         print("Checkpoint was saved without `effective_batch_size`, assuming current value (will save with next checkpoint). This is used for correcting `iter_num` when the effective batch size is changed.")
+         checkpoint['config']['effective_batch_size'] = effective_batch_size
+     iter_num = int(round(checkpoint['iter_num'] * (checkpoint['config']['effective_batch_size'] / effective_batch_size)))
+     if 'games_seen' in checkpoint:
+         games_seen = checkpoint['games_seen']
+     else:
+         games_seen = checkpoint['config']['effective_batch_size'] * checkpoint['iter_num']
+         checkpoint['games_seen'] = games_seen
+         print(f"Checkpoint was saved without `games_seen`, assuming checkpoint's effective batch size * iters (will save with next checkpoint). {games_seen}")
+     best_val_loss = checkpoint['best_val_loss']
+     print(f"Best val loss: {best_val_loss}")
+     if auto_clip:
+         grad_clip = checkpoint['config']['grad_clip']
+         config['grad_clip'] = grad_clip
+         #grad_norm_history = [t.item() if torch.is_tensor(t) else t for t in checkpoint.get('grad_norm_history', [])]
+         grad_norm_history = checkpoint.get('grad_norm_history', [])
+     if init_from == 'anneal':
+         print(f"\n\nANNEAL STARTING/RESUMING FROM ITERNUM: {iter_num} ({games_seen} games)\n\n")
+         anneal_start_iters = iter_num if 'anneal_start_iters' not in checkpoint else checkpoint['anneal_start_iters']
+         anneal_decay_iters = iter_num / 7.0 if 'anneal_decay_iters' not in checkpoint else checkpoint['anneal_decay_iters']  # /9 was the original value
+         print(anneal_start_iters)
+         print(anneal_decay_iters)
+         if 'anneal_start_iters' not in checkpoint:
+             grad_clip = 0
+             config['grad_clip'] = 0
+             grad_norm_history = []
+             print(f"Starting anneal. Resumed from anneal_me.pt; will now decay the learning rate over {anneal_decay_iters} iters, until iter_num {anneal_start_iters + anneal_decay_iters}.")
+         out_dir = anneal_dir
+         weight_decay = weight_decay / 10.0  # / 17.0
+         beta2 = np.sqrt(beta2) * beta2
+         auto_clip = True
+         grad_clip_percentile = 6.3333  # 6.75
+ elif init_from.startswith('state-spaces'):
+     print(f"Initializing from Mamba pre-trained weights: {init_from}")
+     model = from_pretrained(init_from)
+     mamba_config = model.config
+ else:
+     raise ValueError("Invalid init_from value")
+
+ model.to(device)
+
+ print(f'Model with {sum([p.numel() for p in model.parameters()])} parameters loaded.')
+
+ # Optimizer and GradScaler
+ optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2))
+ scaler = torch.cuda.amp.GradScaler(enabled=dtype == 'float16')
+ if init_from == 'resume':
+     optimizer.load_state_dict(checkpoint['optimizer'])
+ checkpoint = None
+
+ # Compile the model if using PyTorch 2.0
+ if compile:
+     print("compiling the model... (takes a ~minute)")
+     model = torch.compile(model)
+
+ # Wrap model in DDP container if necessary
+ if ddp:
+     model = DDP(model, device_ids=[ddp_local_rank])
+
+
+ @torch.no_grad()
+ def estimate_loss():
+     out = {}
+     model.eval()
+     for split in ['train']:  #['train', 'val']:
+         losses = torch.zeros(eval_iters)
+         for k in range(eval_iters):
+             tokens = get_batch(split)  # Fetch tokens in the correct format
+             logits = model(tokens[:, :-1])  # Predict next tokens (ignore last token)
+
+             # The targets are the tokens shifted by one position
+             targets = tokens[:, 1:].reshape(-1)  # Flatten targets for cross-entropy
+
+             # Compute cross-entropy loss between logits and targets
+             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets)
+             losses[k] = loss.item()
+
+         split = 'val'  # Temporary hack
+         out[split] = losses.mean()
+     model.train()
+     return out
+
+
+ # WSD scheduler
+ def get_lr(it):
+     if init_from == 'anneal':
+         # Linear decay from max LR to min LR over anneal_decay_iters iters
+         decay_ratio = min(it - anneal_start_iters, anneal_decay_iters) / anneal_decay_iters
+         return learning_rate - decay_ratio * (learning_rate - min_lr)
+
+     if it < warmup_iters:
+         # Warmup
+         return learning_rate * it / warmup_iters
+
+     # Stable max LR
+     return learning_rate
+
+ # Logging setup
+ if wandb_log and master_process:
+     import wandb
+     wandb.init(project=wandb_project, name=wandb_run_name, config=config)
+
+ # Training loop
+ local_iter_num = 0  # Number of iterations in the lifetime of this process
+ last_crossed_multiple = 0
+ save_every_n_games = 150000
+ raw_model = model.module if ddp else model  # Unwrap DDP container if needed
+
+ t0 = time.time()
+ while True:
+     # Determine and set the learning rate for this iteration
+     lr = get_lr(iter_num) if decay_lr else learning_rate
+     for param_group in optimizer.param_groups:
+         param_group['lr'] = lr
+
+     # Evaluate the loss on train/val sets and write checkpoints
+     if iter_num % eval_interval == 0 and master_process:
+         losses = estimate_loss()
+         print(f"\ngame {games_seen} ({iter_num}, {(iter_num / max_iters)*100.0:.3f}%): 'val' loss {losses['val']:.4f}")  # Temporary hack
+         #print(f"game {games_seen} ({iter_num}): train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
+         if auto_clip and len(grad_norm_history) >= grad_clip_start_size:
+             grad_clip = np.percentile(grad_norm_history, grad_clip_percentile)
+             config['grad_clip'] = grad_clip
+             print(f"Auto adjusted grad_clip to {grad_clip}")
+         if wandb_log:
+             wandb.log({
+                 "iter": iter_num,
+                 "games": games_seen,
+                 #"train/loss": losses['train'],  # Temporary hack
+                 "grad_clip": grad_clip,
+                 "val/loss": losses['val'],
+                 "lr": lr,
+             })
+         if losses['val'] < best_val_loss or always_save_checkpoint:
+             if iter_num > 0:
+                 checkpoint = {
+                     'model': raw_model.state_dict(),
+                     'optimizer': optimizer.state_dict(),
+                     'model_args': mamba_config,
+                     'iter_num': iter_num,
+                     "games_seen": games_seen,
+                     'best_val_loss': min(best_val_loss, losses['val']),
+                     'config': config,
+                 }
+                 checkpoint['grad_norm_history'] = grad_norm_history
+                 if init_from == 'anneal':
+                     checkpoint['anneal_start_iters'] = anneal_start_iters
+                     checkpoint['anneal_decay_iters'] = anneal_decay_iters
+                 print(f"saving checkpoint to {out_dir}\n")
+                 torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))
+                 current_nearest_multiple = (games_seen // save_every_n_games) * save_every_n_games
+                 if losses['val'] < best_val_loss:  # Temporary / only good after it's settled
+                     best_val_loss = losses['val']
+                     torch.save(checkpoint, os.path.join(out_dir, f'ckpt_{int(games_seen)}b.pt'))
+                 elif current_nearest_multiple != last_crossed_multiple:  # elif so we don't double up
+                     last_crossed_multiple = current_nearest_multiple
+                     torch.save(checkpoint, os.path.join(out_dir, f'ckpt_{int(games_seen)}.pt'))
+
+     if iter_num == 0 and eval_only:
+         break
+
+     # Forward and backward pass
+     for micro_step in range(gradient_accumulation_steps):
+         if ddp:
+             model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1)
+
+         sequences = get_batch('train')  # Fetch the training data
+         with ctx:
+             logits = model(sequences[:, :-1])  # Forward pass, exclude last token for input
+             # Compute loss (assuming next token prediction task)
+             targets = sequences[:, 1:].reshape(-1)  # Shifted by one for next token prediction
+             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets)
+             loss = loss / gradient_accumulation_steps
+
+         scaler.scale(loss).backward()
+         #print('.', end='')
+
+     # clip the gradient
+     if grad_clip != 0.0 or auto_clip:
+         scaler.unscale_(optimizer)
+         total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip if grad_clip != 0.0 else 999.9)  # The 0 check is for auto_clip enabled but not enough history
+         grad_norm_history.append(total_norm.item())
+         grad_norm_history = grad_norm_history[-grad_clip_max_size:]
+
+     # step the optimizer and scaler if training in fp16
+     scaler.step(optimizer)
+     scaler.update()
+     # flush the gradients as soon as we can, no need for this memory anymore
+     optimizer.zero_grad(set_to_none=True)
+
+     # timing and logging
+     t1 = time.time()
+     dt = t1 - t0
+     t0 = t1
+     if iter_num % log_interval == 0 and master_process:
+         # get loss as float. note: this is a CPU-GPU sync point
+         # scale up to undo the division above, approximating the true total loss (exact would have been a sum)
+         lossf = loss.item() * gradient_accumulation_steps
+         print(f"game {games_seen} ({iter_num}, {(iter_num / max_iters)*100.0:.3f}%): loss {lossf:.4f}, time {dt*1000:.2f}ms")
+         if wandb_log:
+             wandb.log({
+                 "iter": iter_num,
+                 "games": games_seen,
+                 "grad_norm": grad_norm_history[-1] if grad_norm_history else 0,
+                 "train/loss": lossf,
+                 "lr": lr,
+             })
+     iter_num += 1
+     local_iter_num += 1
+     games_seen += effective_batch_size
+
+     # termination conditions
+     if iter_num > max_iters:
+         checkpoint = {
+             'model': raw_model.state_dict(),
+             'optimizer': optimizer.state_dict(),
+             'model_args': mamba_config,
+             'iter_num': iter_num,
+             "games_seen": games_seen,
+             'best_val_loss': best_val_loss,
+             'config': config,
+         }
+         checkpoint['grad_norm_history'] = grad_norm_history
+         if init_from == 'anneal':
+             checkpoint['anneal_start_iters'] = anneal_start_iters
+             checkpoint['anneal_decay_iters'] = anneal_decay_iters
+         print(f"max_iters reached. Saving checkpoint to {out_dir}")
+         torch.save(checkpoint, os.path.join(out_dir, 'ckpt_final.pt'))
+         break
+
+     if init_from == 'anneal' and iter_num >= anneal_start_iters + anneal_decay_iters:
+         checkpoint = {
+             'model': raw_model.state_dict(),
+             'optimizer': optimizer.state_dict(),
+             'model_args': mamba_config,
+             'iter_num': iter_num,
+             "games_seen": games_seen,
+             'best_val_loss': best_val_loss,
+             'config': config,
+         }
+         checkpoint['grad_norm_history'] = grad_norm_history
+         if init_from == 'anneal':
+             checkpoint['anneal_start_iters'] = anneal_start_iters
+             checkpoint['anneal_decay_iters'] = anneal_decay_iters
+         print(f"Anneal complete. Saving checkpoint to {out_dir}")
+         torch.save(checkpoint, os.path.join(out_dir, 'anneal_complete.pt'))
+         break
+
+
+ if ddp:
+     destroy_process_group()
+
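get_lr above is a WSD (warmup-stable-decay) schedule: linear warmup, a long plateau at the max learning rate, and, when init_from == 'anneal', a linear decay to min_lr. A standalone sketch of the three phases, with illustrative numbers rather than this run's actual iteration counts:

    # WSD schedule sketch; all numbers here are illustrative.
    learning_rate, min_lr = 6e-4, 6e-5
    warmup_iters = 2000
    anneal_start_iters, anneal_decay_iters = 100_000, 100_000 / 7.0

    def wsd_lr(it, annealing=False):
        if annealing:  # phase 3: linear decay from max LR to min_lr
            decay_ratio = min(it - anneal_start_iters, anneal_decay_iters) / anneal_decay_iters
            return learning_rate - decay_ratio * (learning_rate - min_lr)
        if it < warmup_iters:  # phase 1: linear warmup
            return learning_rate * it / warmup_iters
        return learning_rate   # phase 2: stable plateau

    print(wsd_lr(1_000))                    # 3e-4, halfway through warmup
    print(wsd_lr(50_000))                   # 6e-4, on the plateau
    print(wsd_lr(107_000, annealing=True))  # ~3.4e-4, about halfway down the decay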
train_rl.py ADDED
@@ -0,0 +1,537 @@
+ import math
+ import os
+ import time
+ import pickle
+ from contextlib import nullcontext
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ from torch.nn.parallel import DistributedDataParallel as DDP
+ from torch.distributed import init_process_group, destroy_process_group
+ from mamba_lm import MambaLM, MambaLMConfig
+ import random
+ import chess
+ from lczero.backends import Weights, Backend, GameState
+
+ # Default config values
+ out_dir = 'out/play'
+ save_interval = 50
+ wandb_project = 'chess-training'
+ wandb_run_name = 'lc0-training'
+ init_from = 'resume'  # 'scratch', 'resume', 'anneal', or Mamba model name
+
+ # Model parameters
+ n_layer = 15
+ d_model = 256
+ dt_rank = 'auto'
+ d_state = 16
+ vocab_size = 28
+ move_num_in_gamestate = False
+
+
+ # wandb logging
+ wandb_log = True
+ wandb_project = 'mamba-rl'
+ wandb_run_name = 'mamba_run'
+
+ # Load openings file
+ with open("openings.csv", "r") as file:
+     lines = file.readlines()[1:]  # Skip header
+ opening_lines = lines
+
+ # Optimizer settings
+ learning_rate = 1e-7  #7.25e-7
+ min_lr = 1e-8  # 1.75e-8
+ warmup_iters = 600
+ lr_decay_iters = len(opening_lines)
+ weight_decay = 1e-2  #5e-3
+ beta1 = 0.905  #0.915
+ beta2 = 0.965  #0.95
+ grad_clip = 0.5  #0.25
+ min_grad_clip = 1e-3  #1e-3
+ max_grad_clip = 0.45  #0.45
+ auto_clip = True
+ grad_clip_start_size = 150
+ grad_clip_max_size = 600
+ grad_clip_percentile = 9
+
+ # Game play / loss calculation settings
+ top_k = 2  # 2
+ top_k_adj_moves = 40  #999 #35
+ max_illegal_moves = 8  #2
+ max_moves = 87
+ update_freq = 3  #1  # How often to do a backward pass
+ flush_every = 1
+ move_reward_scale_factor = 4.0  # 2.125  # scales down the move reward so it's not so dramatic / so that illegal moves (reward -1) are more dramatic by comparison to bad moves
+ decrease_factor = 0.75  # Bonus for winning (1/x is penalty for losing)
+ window_size = 300
+
+
+ # DDP settings
+ backend = 'nccl'
+
+ # System
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ dtype = 'bfloat16' if torch.cuda.is_bf16_supported() else 'float32'
+ compile = False  # Set to True if using PyTorch 2.0
+
+ config_keys = [k for k, v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
+ #exec(open('configurator.py').read())  # overrides from command line or config file
+ config = {k: globals()[k] for k in config_keys}  # will be useful for logging
+
+ # Initialize lc0 engines
+ lc0_weights_opponent = Weights("./lc0/build/release/11258-32x4-se.pb.gz")
+ lc0_backend_opponent = Backend(weights=lc0_weights_opponent)
+
+ lc0_weights_evaluator = Weights("./lc0/build/release/11258-48x5-se.pb.gz")
+ lc0_backend_evaluator = lc0_backend_opponent  #Backend(weights=lc0_weights_evaluator)
+
+ # Load tokenizer and decode function
+ if move_num_in_gamestate:
+     meta_path = os.path.join(os.path.join('data', 'chess'), 'meta.pkl')
+     with open(meta_path, "rb") as f:
+         meta = pickle.load(f)
+     stoi, itos = meta["stoi"], meta["itos"]
+     vocab_size = meta['vocab_size']
+     encode = lambda s: [stoi[c] for c in s]
+     decode = lambda l: "".join([itos[i] for i in l])
+ else:
+     stoi = {' ': 0, '.': 1, 'a': 2, 'b': 3, 'c': 4, 'd': 5, 'e': 6, 'f': 7, 'g': 8, 'h': 9, '1': 10, '2': 11, '3': 12, '4': 13, '5': 14, '6': 15, '7': 16, '8': 17, 'B': 18, 'N': 19, 'R': 20, 'Q': 21, 'K': 22, 'O': 23, 'x': 24, '+': 25, '#': 26, '=': 27}
+     itos = {0: ' ', 1: '.', 2: 'a', 3: 'b', 4: 'c', 5: 'd', 6: 'e', 7: 'f', 8: 'g', 9: 'h', 10: '1', 11: '2', 12: '3', 13: '4', 14: '5', 15: '6', 16: '7', 17: '8', 18: 'B', 19: 'N', 20: 'R', 21: 'Q', 22: 'K', 23: 'O', 24: 'x', 25: '+', 26: '#', 27: '='}
+     for s in stoi:
+         assert itos[stoi[s]] == s
+     vocab_size = len(stoi)
+     print(f"Vocab size {vocab_size}")
+     encode = lambda s: [stoi[c] for c in s.replace('-', '')]
+     decode = lambda l: "".join([itos[i] for i in l]).replace("OOO", "O-O-O").replace("OO", "O-O")
+
+
+
+ # Initialize Mamba model
+ mamba_config = MambaLMConfig(
+     d_model=d_model,
+     n_layers=n_layer,
+     dt_rank=dt_rank,
+     d_state=d_state,
+     vocab_size=vocab_size  # Adjust based on your dataset
+ )
+
+ model = MambaLM(mamba_config)
+ model.to(device)
+
+ # Optimizer and GradScaler
+ optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2))
+ scaler = torch.cuda.amp.GradScaler(enabled=dtype == 'float16')
+
+ # Compile the model if using PyTorch 2.0
+ if compile:
+     print("compiling the model... (takes a ~minute)")
+     model = torch.compile(model)
+
+ ddp = int(os.environ.get('RANK', -1)) != -1
+ # Wrap model in DDP container if necessary
+ if ddp:
+     ddp_local_rank = int(os.environ['LOCAL_RANK'])  # set by torchrun; was referenced below but never defined
+     model = DDP(model, device_ids=[ddp_local_rank])
+
+ win_rate_window = []
+ win_only_rate_window = []
+ # Load checkpoint if resuming training
+ if init_from == 'resume':
+     print(f"Resuming training from {out_dir}")
+     ckpt_path = os.path.join(out_dir, 'ckpt.pt')
+     checkpoint = torch.load(ckpt_path, map_location=device)
+     mamba_config = checkpoint['model_args']
+     state_dict = checkpoint['model']
+     # fix the keys of the state dictionary :(
+     # honestly no idea how checkpoints sometimes get this prefix, have to debug more
+     unwanted_prefix = '_orig_mod.'
+     for k, v in list(state_dict.items()):
+         if k.startswith(unwanted_prefix):
+             state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
+     model.load_state_dict(state_dict)
+     optimizer.load_state_dict(checkpoint['optimizer'])
+     iter_num = checkpoint['iter_num']
+     games_played = checkpoint['games_seen']
+     opening_line_index = checkpoint.get('opening_line_index', 0)
+     win_rate_window = checkpoint.get('win_rate_window', [])
+     win_only_rate_window = checkpoint.get('win_only_rate_window', [])
+     best_wr = checkpoint.get('best_wr', 0.0)
+     best_wor = checkpoint.get('best_wor', 0.0)
+     if auto_clip:
+         grad_clip = checkpoint['config']['grad_clip']
+         config['grad_clip'] = grad_clip
+         grad_norm_history = checkpoint.get('grad_norm_history', [])
+     else:
+         grad_norm_history = []
+ else:
+     best_wr = 0.0
+     best_wor = 0.0
+     grad_norm_history = []
+     games_played = 0
+     iter_num = 0
+     opening_line_index = 0
+     if auto_clip:
+         grad_clip = 0
+         config['grad_clip'] = 0
+
+
+ def get_model_move(game_state, top_k):
+     model.train()  # Ensure the model is in training mode
+     encoded_prompt = encode(game_state)
+     input_ids = torch.tensor([encoded_prompt], dtype=torch.long, device=device)
+
+     have_non_space = False
+     logits_list = []  # Collect logits for analysis and potential loss calculation
+     for _ in range(8):
+         logits = model(input_ids)[0, -1, :]  # Logits for the last predicted token
+
+         # We're using top-k more as a VRAM control, not a decision-enhancing tool
+         if top_k is not None and top_k < logits.size(-1):
+             logits, indices = torch.topk(logits, top_k)
+             probs = torch.nn.functional.softmax(logits, dim=-1)
+             next_token_id = indices[torch.multinomial(probs, 1)]
+         else:
+             probs = torch.nn.functional.softmax(logits, dim=-1)
+             next_token_id = torch.multinomial(probs, num_samples=1)
+
+         if have_non_space and (next_token_id == 0 or next_token_id == 4):
+             break
+         else:
+             have_non_space = True
+         input_ids = torch.cat([input_ids, next_token_id.unsqueeze(0)], dim=1)
+         logits_list.append(logits)
+         del logits, probs
+
+     # Decode the sequence to extract the move
+     model_response = decode(input_ids.squeeze(0).tolist())
+     try:
+         move = model_response[len(game_state):].split(";")[0].split()[0]  # Extract the first move
+     except IndexError:
+         move = None
+
+     return move, torch.stack(logits_list) if len(logits_list) > 0 else None
+
+ def get_lc0_move(board, backend):
+     gamestate = GameState(fen=board.fen())
+     input_planes = gamestate.as_input(backend)
+     result = backend.evaluate(input_planes)[0]
+     moves = gamestate.moves()
+     policy_indices = gamestate.policy_indices()
+     move_probs = np.array(result.p_softmax(*policy_indices))
+     try:
+         best_move_idx = move_probs.argmax()
+     except Exception:
+         return None
+     best_move = moves[best_move_idx]
+     return chess.Move.from_uci(best_move)
+
+ def evaluate_position(fen, backend):
+     gamestate = GameState(fen=fen)
+     result = backend.evaluate(gamestate.as_input(backend))[0]
+     return result.q()
+
+ def reward_from_eval(before_eval, after_eval):
+     diff = after_eval - before_eval
+     return diff / (move_reward_scale_factor + abs(diff))
+
+ def backward_pass(loss):
+     global grad_norm_history
+
+     # Backward pass
+     scaler.scale(loss).backward()
+
+     # clip the gradient
+     if grad_clip != 0.0 or auto_clip:
+         scaler.unscale_(optimizer)
+         total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip if grad_clip != 0.0 else 0.1)  # The 0 check is for auto_clip enabled but not enough history
+         grad_norm_history.append(total_norm.item())
+         grad_norm_history = grad_norm_history[-grad_clip_max_size:]
+
+     scaler.step(optimizer)
+     scaler.update()
+     optimizer.zero_grad(set_to_none=True)
+
+ def play_game():
+     global top_k
+
+     optimizer.zero_grad(set_to_none=True)
+     torch.cuda.empty_cache()
+     board = chess.Board()
+     total_loss = 0
+     illegal_moves = 0
+     move_count = 0
+     moves_since_backward = 0
+     tot_reward = 0
+
+     # Load opening from openings.csv
+     tokens = [m.split(".")[-1] if "." in m else m for m in opening_line.split()]
+     [board.push_san(m) for m in tokens]
+     if move_num_in_gamestate:
+         game_state = opening_line.rstrip() + " "
+     else:
+         game_state = ' '.join(['.' + m.split(".")[-1] if "." in m else m for m in opening_line.split()])
+     fail = False
+
+     while not board.is_game_over():
+         before_eval = evaluate_position(board.fen(), lc0_backend_evaluator)
+         game_state += f"{board.fullmove_number if move_num_in_gamestate else ''}."
+         model_move, logits = get_model_move(game_state, top_k)
+         move_reward = -1
+
+         if model_move is None or logits is None:
+             illegal_moves += 1
+             pinch_hit_move = get_lc0_move(board, lc0_backend_opponent)
+             if pinch_hit_move is None:
+                 print("Failed game (lc0 couldn't find pinch-hit move)")
+                 fail = True
+                 tot_reward += move_reward
+                 move_count += 1
+                 break
+             game_state += f"{board.san(pinch_hit_move)} "
+             board.push(pinch_hit_move)
+         else:
+             try:
+                 #print(model_move)
+                 board.push(board.parse_san(model_move))
+                 game_state += f"{model_move} "
+             except Exception:
+                 illegal_moves += 1
+                 pinch_hit_move = get_lc0_move(board, lc0_backend_opponent)
+                 if pinch_hit_move is None:
+                     print("Failed game (lc0 couldn't find pinch-hit move)")
+                     fail = True
+                     tot_reward += move_reward
+                     move_count += 1
+                     break
+                 game_state += f"{board.san(pinch_hit_move)} "
+                 board.push(pinch_hit_move)
+             else:
+                 if not board.is_valid():
+                     board.pop()
+                     illegal_moves += 1
+                     pinch_hit_move = get_lc0_move(board, lc0_backend_opponent)
+                     if pinch_hit_move is None:
+                         print("Failed game (lc0 couldn't find pinch-hit move)")
+                         fail = True
+                         tot_reward += move_reward
+                         move_count += 1
+                         break
+                     game_state += f"{board.san(pinch_hit_move)} "
+                     board.push(pinch_hit_move)
+                 else:
+                     after_eval = -evaluate_position(board.fen(), lc0_backend_evaluator)
+                     move_reward = reward_from_eval(before_eval, after_eval)
+
+         tot_reward += move_reward
+         if not board.is_game_over():
+             black_move = get_lc0_move(board, lc0_backend_opponent)
+             if black_move is None:
+                 print("Failed game (lc0 couldn't find black move)")
+                 fail = True
+                 move_count += 1
+                 break
+             game_state += f"{board.san(black_move)} "
+             board.push(black_move)
+
+         if logits is not None:
+             total_loss += torch.sum(torch.nn.functional.log_softmax(logits, dim=-1) * move_reward)
+         logits_none = logits is None
+         del logits
+         moves_since_backward += 1
+         if move_count % update_freq == 0 and not board.is_game_over() and not logits_none:
+             backward_pass(total_loss / moves_since_backward)
+             total_loss = 0.0  # Reset cumulative loss after update
+             moves_since_backward = 0
+         move_count += 1
+         if move_count == top_k_adj_moves:
+             top_k = top_k - 1
+         if move_count >= max_moves:
+             break
+         if move_count % flush_every == 0:
+             torch.cuda.empty_cache()
+
+     if move_count >= top_k_adj_moves:
+         top_k = top_k + 1
+     # Scale loss based on game result and illegal moves
+     avg_reward = tot_reward / move_count
+     #print(f'Avg reward {avg_reward} = {tot_reward} / {move_count}')
+     scale_factor = torch.tensor([1.0], device=device)
+     if move_count >= max_moves:
+         result = "1/2-1/2"
+     elif fail:
+         result = "*"
+     else:
+         result = board.result()
+     if moves_since_backward:  # guard against dividing by zero right after a flush
+         total_loss = total_loss / moves_since_backward
+     if result == "0-1":  # Black wins
+         # Increase the loss for a loss, if the reward is negative (if the loss is positive)
+         scale_factor = torch.tensor([1.0 / decrease_factor], device=device) if avg_reward < 0 and illegal_moves <= max_illegal_moves else scale_factor
+         #print(f'Black win, scale factor adjusted to {scale_factor} (avg reward<0 and illegal less max {avg_reward < 0 and illegal_moves <= max_illegal_moves}), illegal vs max {illegal_moves} vs {max_illegal_moves}')
+     elif result == "1-0":  # White wins
+         wdf = decrease_factor / 2.0 if avg_reward <= 0 else 1.0 / decrease_factor
+         #print(f'White win - adjusted decrease factor {wdf}')
+         # Don't update as much for (real) wins. Also change the result so our win_rate isn't inflated.
+         if illegal_moves == 0:
+             scale_factor = torch.tensor([wdf], device=device)
+             #print(f'White win, scale factor adjusted to {scale_factor} (0 illegal moves)')
+         elif illegal_moves <= max_illegal_moves:
+             scale_factor = torch.tensor([(1 + wdf) / 2], device=device)
+             #print(f'White win, scale factor adjusted to {scale_factor} ({0 < illegal_moves <= max_illegal_moves}), illegal vs max {illegal_moves} vs {max_illegal_moves}')
+             result = "1/2-1/2"
+         else:
+             result = "0-1"
+             # No adjustment to scale_factor
+
+     if torch.is_tensor(total_loss) and total_loss.numel():  # total_loss is a plain 0.0 right after a flush
+         try:
+             backward_pass(total_loss * scale_factor)
+         except Exception:
+             print("Failed game (final backward pass, result not affected)")
+     total_loss = 0.0
+
+     #print(f'Scale factor {scale_factor.item()}')
+     return avg_reward / scale_factor.item(), result, illegal_moves, move_count
+
+
+ def get_lr(it):
+     # 1) linear warmup for warmup_iters steps
+     if it < warmup_iters:
+         return learning_rate * it / warmup_iters
+     # 2) if it > lr_decay_iters, return min learning rate
+     if it > lr_decay_iters:
+         return min_lr
+     # 3) in between, use cosine decay down to min learning rate
+     decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
+     assert 0 <= decay_ratio <= 1
+     coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # coeff ranges 0..1
+     return min_lr + coeff * (learning_rate - min_lr)
+
+ # Training loop
+ if wandb_log:
+     import wandb
+     wandb.init(project=wandb_project, name=wandb_run_name, config=config)
+
+ while True:
+     t0 = time.time()
+     lr = get_lr(iter_num)
+     for param_group in optimizer.param_groups:
+         param_group['lr'] = lr
+     opening_line = opening_lines[opening_line_index]
+
+     if iter_num > 0 and iter_num % save_interval == 0:
+         if auto_clip and len(grad_norm_history) >= grad_clip_start_size:
+             grad_clip = max(min(np.percentile(grad_norm_history, grad_clip_percentile), max_grad_clip), min_grad_clip)
+             config['grad_clip'] = grad_clip
+             print(f"Auto adjusted grad_clip to {grad_clip}")
+
+         #print(f"Game {games_played}: Loss {game_reward:.4f}, Illegal moves {illegal_moves}, Win rate {win_rate:.3f}")
+         if wandb_log:
+             wandb.log({
+                 "etc/iter": iter_num,
+                 "etc/lr": lr,
+                 "etc/grad_clip": grad_clip,
+                 "etc/games_played": games_played,
+             })
+
+         # Save checkpoint
+         raw_model = model.module if ddp else model
+         checkpoint = {
+             'model': raw_model.state_dict(),
+             'optimizer': optimizer.state_dict(),
+             'model_args': mamba_config,
+             'iter_num': iter_num,
+             "games_seen": games_played,
+             'config': config,
+             'opening_line_index': opening_line_index,
+             'grad_norm_history': grad_norm_history,
+             'win_rate_window': win_rate_window,
+             'win_only_rate_window': win_only_rate_window,
+             'best_wr': best_wr,
+             'best_wor': best_wor
+         }
+         print(f"saving checkpoint to {out_dir}\n")
+         torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))
+
+     # Play a game against lc0 engine
+     game_reward, result, illegal_moves, move_count = play_game()
+     games_played += 1
+
+     # Backward passes happen in play_game
+
+     # Log game result and update win rate window
+     t1 = time.time()
+     dt = t1 - t0
+     t0 = t1
+     score = 0.5
+     if result == "1-0":
+         score = 1
+     elif result == "0-1":
+         score = 0
+     if result != "*":
+         win_rate_window.append(score)
+         win_rate_window = win_rate_window[-window_size:]
+         win_rate = sum(win_rate_window) / len(win_rate_window)
+         win_only_rate_window.append(int(score))  # int to discard draws
+         win_only_rate_window = win_only_rate_window[-window_size:]
+         win_only_rate = float(sum(win_only_rate_window)) / len(win_only_rate_window)
+         if win_rate > best_wr:
+             best_wr = win_rate
+             raw_model = model.module if ddp else model
+             checkpoint = {
+                 'model': raw_model.state_dict(),
+                 'optimizer': optimizer.state_dict(),
+                 'model_args': mamba_config,
+                 'iter_num': iter_num,
+                 "games_seen": games_played,
+                 'config': config,
+                 'opening_line_index': opening_line_index,
+                 'grad_norm_history': grad_norm_history,
+                 'win_rate_window': win_rate_window,
+                 'best_wr': best_wr,
+                 'best_wor': best_wor
+             }
+             print(f"saving checkpoint to {out_dir}\n")
+             torch.save(checkpoint, os.path.join(out_dir, f'ckpt_{games_played}g_wr{best_wr}.pt'))
+         elif win_only_rate > best_wor:
+             best_wor = win_only_rate
+             raw_model = model.module if ddp else model
+             checkpoint = {
+                 'model': raw_model.state_dict(),
+                 'optimizer': optimizer.state_dict(),
+                 'model_args': mamba_config,
+                 'iter_num': iter_num,
+                 "games_seen": games_played,
+                 'config': config,
+                 'opening_line_index': opening_line_index,
+                 'grad_norm_history': grad_norm_history,
+                 'win_rate_window': win_rate_window,
+                 'best_wr': best_wr,
+                 'best_wor': best_wor
+             }
+             print(f"saving checkpoint to {out_dir}\n")
+             torch.save(checkpoint, os.path.join(out_dir, f'ckpt_{games_played}g_wor{best_wor}.pt'))
+         best_wor = max(best_wor, win_only_rate)
+         print(f"Game {games_played} ({iter_num}, {(iter_num / len(opening_lines)) * 100.0:.3f}%): Score {score}, Reward {game_reward:.4f}, Illegal moves {illegal_moves} ({illegal_moves / move_count:.3%}), Total moves {move_count}, Win rate {win_rate:.3f}, Win only rate {win_only_rate:.3f}, time {dt * 1000:.2f}ms")
+         if wandb_log:
+             wandb.log({
+                 "etc/iter": iter_num,
+                 "etc/lr": lr,
+                 "etc/grad_norm_mean": np.mean(grad_norm_history) if grad_norm_history else -1,
+                 "etc/grad_zero_pct": float(np.count_nonzero(np.asarray(grad_norm_history) == 0)) / len(grad_norm_history) if grad_norm_history else -1,  # asarray so the == 0 comparison is elementwise
+                 "etc/games_played": games_played,
+                 "eval/game_reward": game_reward,
+                 "eval/illegal_move_pct": illegal_moves / move_count,
+                 "eval/move_ct": move_count,
+                 "eval/win_rate": win_rate,
+                 "eval/win_only_rate": win_only_rate,
+             })
+
+     iter_num += 1
+     opening_line_index += 1
+
+     # Termination condition
+     if opening_line_index >= len(opening_lines):
+         break
+
+ if ddp:
+     destroy_process_group()
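For intuition on the reward shaping above: reward_from_eval squashes the lc0 eval swing as diff / (move_reward_scale_factor + |diff|). Since q() lies in [-1, 1], diff lies in [-2, 2] and the per-move reward stays within about ±0.33, well clear of the flat -1 assigned to illegal moves. A small worked sketch (values illustrative):

    # Worked sketch of the bounded move reward.
    move_reward_scale_factor = 4.0

    def reward_from_eval(before_eval, after_eval):
        diff = after_eval - before_eval
        return diff / (move_reward_scale_factor + abs(diff))

    for before, after in [(0.0, 0.1), (0.0, 1.0), (0.5, -0.5), (0.0, 2.0)]:
        print(f"{before:+.1f} -> {after:+.1f}: reward {reward_from_eval(before, after):+.3f}")
    # +0.024, +0.200, -0.200, +0.333 respectively; an illegal move gets -1.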