HaileyStorm commited on
Commit
f8519d1
1 Parent(s): 624e9a2

Upload 37 files

Browse files
Files changed (37) hide show
  1. chess-gpt-eval/gpt_query.py +265 -0
  2. chess-gpt-eval/llama_module.py +71 -0
  3. chess-gpt-eval/main.py +565 -0
  4. chess-gpt-eval/mamba.py +368 -0
  5. chess-gpt-eval/mamba/out/meta.pkl +3 -0
  6. chess-gpt-eval/mamba_lm.py +168 -0
  7. chess-gpt-eval/mamba_module.py +144 -0
  8. chess-gpt-eval/nanogpt/__pycache__/model.cpython-310.pyc +0 -0
  9. chess-gpt-eval/nanogpt/__pycache__/nanogpt_module.cpython-310.pyc +0 -0
  10. chess-gpt-eval/nanogpt/__pycache__/xformer.cpython-310.pyc +0 -0
  11. chess-gpt-eval/nanogpt/configurator.py +47 -0
  12. chess-gpt-eval/nanogpt/model.py +330 -0
  13. chess-gpt-eval/nanogpt/nanogpt_module.py +148 -0
  14. chess-gpt-eval/nanogpt/out/meta.pkl +3 -0
  15. chess-gpt-eval/nanogpt/out/view_ckpt.ipynb +61 -0
  16. chess-gpt-eval/openings.csv +0 -0
  17. chess-gpt-eval/pscan.py +226 -0
  18. chess-gpt-eval/requirements.txt +6 -0
  19. chess-gpt-eval/xformer.py +330 -0
  20. chess-mamba-vs-xformer/config/Mamba/11M.py +70 -0
  21. chess-mamba-vs-xformer/config/Mamba/250M.py +70 -0
  22. chess-mamba-vs-xformer/config/Mamba/29M.py +70 -0
  23. chess-mamba-vs-xformer/config/Mamba/50M.py +70 -0
  24. chess-mamba-vs-xformer/config/Mamba/6.6M.py +70 -0
  25. chess-mamba-vs-xformer/config/Xformer/11M.py +70 -0
  26. chess-mamba-vs-xformer/config/Xformer/250M.py +70 -0
  27. chess-mamba-vs-xformer/config/Xformer/29M.py +70 -0
  28. chess-mamba-vs-xformer/config/Xformer/50M.py +70 -0
  29. chess-mamba-vs-xformer/config/Xformer/6.6M.py +70 -0
  30. chess-mamba-vs-xformer/configurator.py +47 -0
  31. chess-mamba-vs-xformer/data/anneal/anneal.zip +3 -0
  32. chess-mamba-vs-xformer/mamba.py +368 -0
  33. chess-mamba-vs-xformer/mamba_lm.py +168 -0
  34. chess-mamba-vs-xformer/openings.csv +0 -0
  35. chess-mamba-vs-xformer/pscan.py +226 -0
  36. chess-mamba-vs-xformer/train_bygame.py +541 -0
  37. chess-mamba-vs-xformer/xformer.py +330 -0
chess-gpt-eval/gpt_query.py ADDED
@@ -0,0 +1,265 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import openai
2
+ import tiktoken
3
+ import json
4
+ import os
5
+
6
+ # import replicate
7
+
8
+ # for hugging face inference endpoints for codellama
9
+ import requests
10
+
11
+ from typing import Optional
12
+
13
+ from tenacity import (
14
+ retry,
15
+ stop_after_attempt,
16
+ wait_random_exponential,
17
+ ) # for exponential backoff
18
+
19
+ # system message is used in openai_request()
20
+ system_message = """Provide the next move in the chess game. Only provide the move, no move numbers."""
21
+
22
+ # dollars per 1k tokens, per openai.com/pricing
23
+ pricing_dict = {
24
+ "gpt-4": 0.03,
25
+ "gpt-4-0301": 0.03,
26
+ "gpt-4-0613": 0.03,
27
+ "gpt-3.5-turbo": 0.0015,
28
+ "gpt-3.5-turbo-0301": 0.0015,
29
+ "gpt-3.5-turbo-0613": 0.0015,
30
+ "gpt-3.5-turbo-16k": 0.003,
31
+ "babbage": 0.0005,
32
+ "gpt-3.5-turbo-instruct": 0.0015,
33
+ }
34
+
35
+ MAX_TOKENS = 10
36
+
37
+ completion_models = [
38
+ "gpt-3.5-turbo-instruct",
39
+ "babbage",
40
+ "davinci",
41
+ ]
42
+
43
+
44
+ # tenacity is to handle anytime a request fails
45
+ @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
46
+ def get_gpt_response(
47
+ prompt: str, model: str = "gpt-4", temperature: float = 0.0
48
+ ) -> Optional[str]:
49
+ try:
50
+ messages = []
51
+ # system message is used in openai_request()
52
+ # system_message_dict = {
53
+ # "role": "system",
54
+ # "content": system_message,
55
+ # }
56
+ initial_message = {"role": "user", "content": prompt}
57
+ messages.append(initial_message)
58
+
59
+ record_messages(messages, model)
60
+
61
+ # num_tokens = count_all_tokens(model, messages)
62
+ # prompt_cost = get_prompt_cost(model, num_tokens)
63
+ # print("prompt cost in $:", prompt_cost)
64
+
65
+ if model in completion_models:
66
+ response = get_completions_response(model, messages, temperature)
67
+ elif model.startswith("gpt"):
68
+ response = openai_chat_completion_request(model, messages, temperature)
69
+ elif model.startswith("openrouter"):
70
+ response = openrouter_request(model, messages, temperature)
71
+ elif model.startswith("huggingface"):
72
+ response = hugging_face_request(model, messages, temperature)
73
+ elif model.startswith("replicate"):
74
+ response = replicate_request(model, messages, temperature)
75
+ else:
76
+ raise Exception("Invalid model name")
77
+
78
+ # response_cost = get_response_cost(model, count_tokens(model, response))
79
+ # print("response cost in $:", response_cost)
80
+
81
+ messages.append({"role": "assistant", "content": response})
82
+ record_messages(messages, model)
83
+
84
+ return response
85
+ except Exception as e:
86
+ print(f"Error while getting GPT response: {e}")
87
+ return None
88
+
89
+
90
+ def openai_chat_completion_request(
91
+ model: str, messages: list[dict], temperature: float
92
+ ) -> str:
93
+ system_message_dict = {
94
+ "role": "system",
95
+ "content": system_message,
96
+ }
97
+ messages.append(system_message_dict)
98
+ completion = openai.ChatCompletion.create(
99
+ model=model,
100
+ temperature=temperature,
101
+ messages=messages,
102
+ )
103
+ response = completion.choices[0].message.content
104
+ return response
105
+
106
+
107
+ def openrouter_request(model: str, messages: list[dict], temperature: float) -> str:
108
+ if temperature == 0:
109
+ temperature = 0.001
110
+
111
+ with open("gpt_inputs/openrouter_api_key.txt", "r") as f:
112
+ openai.api_key = f.read().strip()
113
+
114
+ openai.api_base = "https://openrouter.ai/api/v1"
115
+ OPENROUTER_REFERRER = "https://github.com/adamkarvonen/nanoGPT"
116
+
117
+ model = model.replace("openrouter/", "")
118
+
119
+ completion = openai.ChatCompletion.create(
120
+ model=model,
121
+ headers={"HTTP-Referer": OPENROUTER_REFERRER},
122
+ messages=messages,
123
+ temperature=temperature,
124
+ max_tokens=MAX_TOKENS,
125
+ )
126
+ response = completion.choices[0].message.content
127
+ return response
128
+
129
+
130
+ def replicate_request(model: str, messages: list[dict], temperature: float) -> str:
131
+ if temperature == 0:
132
+ temperature = 0.001
133
+
134
+ with open("gpt_inputs/replicate_api_key.txt", "r") as f:
135
+ api_key = f.read().strip()
136
+ os.environ["REPLICATE_API_TOKEN"] = api_key
137
+
138
+ model = model.replace("replicate/", "")
139
+
140
+ messages = translate_to_string_input(messages)
141
+
142
+ output = replicate.run(
143
+ model,
144
+ input={
145
+ "prompt": messages,
146
+ "max_new_tokens": MAX_TOKENS,
147
+ "temperature": temperature,
148
+ },
149
+ )
150
+
151
+ # The meta/llama-2-7b model can stream output as it's running.
152
+ response = ""
153
+ # The predict method returns an iterator, and you can iterate over that output.
154
+ for item in output:
155
+ # https://replicate.com/meta/llama-2-7b/versions/527827021d8756c7ab79fde0abbfaac885c37a3ed5fe23c7465093f0878d55ef/api#output-schema
156
+ response += item
157
+
158
+ return response
159
+
160
+
161
+ def hugging_face_request(model: str, messages: list[dict], temperature: float) -> str:
162
+ def query(payload):
163
+ response = requests.post(API_URL, headers=headers, json=payload)
164
+ return response.json()
165
+
166
+ messages = translate_to_string_input(messages)
167
+ API_URL = "https://xxxxxxxx.us-east-1.aws.endpoints.huggingface.cloud"
168
+ headers = {
169
+ "Authorization": "Bearer xxxxx",
170
+ "Content-Type": "application/json",
171
+ }
172
+
173
+ if temperature == 0:
174
+ temperature = 0.001
175
+
176
+ output = query(
177
+ {
178
+ "inputs": messages,
179
+ "parameters": {"temperature": temperature, "max_new_tokens": MAX_TOKENS},
180
+ }
181
+ )
182
+
183
+ return output[0]["generated_text"]
184
+
185
+
186
+ def translate_to_string_input(
187
+ openai_messages: list[dict], roles_included: bool = False
188
+ ):
189
+ # Translate from OpenAI's dict to a single string input
190
+ messages = []
191
+ for message in openai_messages:
192
+ if roles_included:
193
+ messages.append(message["role"] + ": ")
194
+ messages.append(message["content"])
195
+ if roles_included:
196
+ messages.append("assistant: ")
197
+ return "\n".join(messages)
198
+
199
+
200
+ # for gpt-3 models and instruct models
201
+ def get_completions_response(
202
+ model: str,
203
+ messages: list[dict] | str,
204
+ temperature: float,
205
+ max_tokens: int = MAX_TOKENS,
206
+ ) -> str:
207
+ if not isinstance(messages, str):
208
+ prompt = translate_to_string_input(messages, roles_included=False)
209
+ else:
210
+ prompt = messages
211
+
212
+ completion = openai.Completion.create(
213
+ model=model, temperature=temperature, prompt=prompt, max_tokens=max_tokens
214
+ )
215
+
216
+ response = completion.choices[0].text
217
+ return response
218
+
219
+
220
+ def count_all_tokens(model: str, messages: list[dict[str, str]]) -> int:
221
+ total_tokens = 0
222
+ for message in messages:
223
+ total_tokens += count_tokens(model, message["content"])
224
+ return total_tokens
225
+
226
+
227
+ def count_tokens(model: str, prompt: str) -> int:
228
+ if "gpt" not in model:
229
+ model = "gpt-4"
230
+
231
+ encoding = tiktoken.encoding_for_model(model)
232
+ num_tokens = len(encoding.encode(prompt))
233
+ return num_tokens
234
+
235
+
236
+ def get_prompt_cost(model: str, num_tokens: int) -> float:
237
+ # good enough for quick evals
238
+ if model not in pricing_dict:
239
+ return num_tokens * 0.001 * pricing_dict["gpt-4"]
240
+ return num_tokens * 0.001 * pricing_dict[model]
241
+
242
+
243
+ def get_response_cost(model: str, num_tokens: int) -> float:
244
+ # good enough for quick evals
245
+ if model not in pricing_dict:
246
+ return num_tokens * 0.001 * pricing_dict["gpt-4"]
247
+
248
+ cost = num_tokens * 0.001 * pricing_dict[model]
249
+
250
+ if model == "gpt-4":
251
+ cost *= 2
252
+
253
+ return cost
254
+
255
+
256
+ def record_messages(messages: list[dict], model: str):
257
+ # create the conversation in a human-readable format
258
+ conversation_text = ""
259
+ for message in messages:
260
+ conversation_text += message["content"]
261
+
262
+ # write the conversation to the next available text file
263
+ with open(f"gpt_outputs/transcript.txt", "w") as f:
264
+ f.write(model + "\n\n")
265
+ f.write(conversation_text)
chess-gpt-eval/llama_module.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForCausalLM
2
+ from peft import PeftModel
3
+ import torch
4
+
5
+ from typing import Optional
6
+
7
+
8
+ # There are a couple non optimal parts of this code:
9
+ # 1. It doesn't inherit the Player class in main.py, which throws type checking errors
10
+ # 2. get_move_from_response() is duplicated from main.py
11
+ # However, I didn't want to add clutter and major dependencies like torch, peft, and transformers
12
+ # to those not using this class. So, this was my compromise.
13
+ class BaseLlamaPlayer:
14
+ def __init__(
15
+ self, tokenizer: AutoTokenizer, model: AutoModelForCausalLM, model_name: str
16
+ ):
17
+ self.tokenizer = tokenizer
18
+ self.model = model
19
+ self.model_name = model_name
20
+
21
+ def get_llama_response(self, game_state: str, temperature: float) -> Optional[str]:
22
+ prompt = game_state
23
+ tokenized_input = self.tokenizer(prompt, return_tensors="pt").to("cuda")
24
+ result = self.model.generate(
25
+ **tokenized_input, max_new_tokens=10, temperature=temperature
26
+ ).to("cpu")
27
+ input_ids_tensor = tokenized_input["input_ids"]
28
+ # transformers generate() returns <s> + prompt + output. This grabs only the output
29
+ res_sliced = result[:, input_ids_tensor.shape[1] :]
30
+ return self.tokenizer.batch_decode(res_sliced)[0]
31
+
32
+ def get_move_from_response(self, response: Optional[str]) -> Optional[str]:
33
+ if response is None:
34
+ return None
35
+
36
+ # Parse the response to get only the first move
37
+ moves = response.split()
38
+ first_move = moves[0] if moves else None
39
+
40
+ return first_move
41
+
42
+ def get_move(
43
+ self, board: str, game_state: str, temperature: float
44
+ ) -> Optional[str]:
45
+ completion = self.get_llama_response(game_state, temperature)
46
+ return self.get_move_from_response(completion)
47
+
48
+ def get_config(self) -> dict:
49
+ return {"model": self.model_name}
50
+
51
+
52
+ class LocalLlamaPlayer(BaseLlamaPlayer):
53
+ def __init__(self, model_name: str):
54
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
55
+ model = AutoModelForCausalLM.from_pretrained(
56
+ model_name, torch_dtype=torch.bfloat16, device_map=0
57
+ ).to("cuda")
58
+ super().__init__(tokenizer, model, model_name)
59
+
60
+
61
+ class LocalLoraLlamaPlayer(BaseLlamaPlayer):
62
+ def __init__(self, base_model_id: str, adapter_model_path: str):
63
+ tokenizer = AutoTokenizer.from_pretrained(base_model_id)
64
+ base_model = AutoModelForCausalLM.from_pretrained(base_model_id)
65
+ model = (
66
+ PeftModel.from_pretrained(base_model, adapter_model_path)
67
+ .merge_and_unload()
68
+ .to("cuda")
69
+ )
70
+
71
+ super().__init__(tokenizer, model, adapter_model_path)
chess-gpt-eval/main.py ADDED
@@ -0,0 +1,565 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import openai
2
+ import chess
3
+ import chess.engine
4
+ import os
5
+ import csv
6
+ import random
7
+ import time
8
+ import platform
9
+
10
+ # NOTE: LLAMA AND NANOGPT ARE EXPERIMENTAL PLAYERS, if not using them, comment them out
11
+ # from llama_module import BaseLlamaPlayer, LocalLlamaPlayer, LocalLoraLlamaPlayer
12
+ from nanogpt.nanogpt_module import NanoGptPlayer
13
+ from mamba_module import MambaPlayer
14
+ import gpt_query
15
+ from lczero.backends import Weights, Backend, GameState
16
+ import numpy as np
17
+
18
+ from typing import Optional, Tuple
19
+ from dataclasses import dataclass
20
+
21
+
22
+ @dataclass
23
+ class LegalMoveResponse:
24
+ move_san: Optional[str] = None
25
+ move_uci: Optional[chess.Move] = None
26
+ attempts: int = 0
27
+ is_resignation: bool = False
28
+ is_illegal_move: bool = False
29
+
30
+
31
+ # Define base Player class
32
+ class Player:
33
+ def get_move(self, board: chess.Board, game_state: str, temperature: float) -> str:
34
+ raise NotImplementedError
35
+
36
+ def get_config(self) -> dict:
37
+ raise NotImplementedError
38
+
39
+
40
+ class GPTPlayer(Player):
41
+ def __init__(self, model: str):
42
+ with open("gpt_inputs/api_key.txt", "r") as f:
43
+ openai.api_key = f.read().strip()
44
+ self.model = model
45
+
46
+ def get_move(
47
+ self, board: chess.Board, game_state: str, temperature: float
48
+ ) -> Optional[str]:
49
+ response = get_gpt_response(game_state, self.model, temperature)
50
+ return get_move_from_gpt_response(response)
51
+
52
+ def get_config(self) -> dict:
53
+ return {"model": self.model}
54
+
55
+
56
+ class LC0PLayer(Player):
57
+ # "11258-32x4-se.pb.gz" = stockfish level 0- = skill 0
58
+ # "11258-48x5-se.pb.gz" = stockfish level 0+ = skill 1
59
+ # "11258-80x7-se.pb.gz" = stockfish level 1 = skill 2
60
+ # "11258-104x9-se.pb.gz" = stockfish level 2 = skill 3
61
+ # "TK-6430 aka 128x10-BPR-64M-6430000.pb.gz" = stockfish level 3 = skill 4
62
+ # "00af53b081e80147172e6f281c01daf5ca19ada173321438914c730370aa4267" = stockfish level 4 = skill 5
63
+ # "b2ec465d0fb5b5eb39d2e1e3f74041a5d2fc92d413b71aa7ea0b6fb082ccba9c" = stockfish level 5+ = skill 6
64
+ def __init__(self, skill):
65
+ self.skill = skill
66
+ network_paths = ["./lc0/build/release/11258-32x4-se.pb.gz", "./lc0/build/release/11258-48x5-se.pb.gz", "./lc0/build/release/11258-80x7-se.pb.gz", "./lc0/build/release/11258-104x9-se.pb.gz", "./lc0/build/release/TK-6430 aka 128x10-BPR-64M-6430000.pb.gz", "./lc0/build/release/00af53b081e80147172e6f281c01daf5ca19ada173321438914c730370aa4267", "./lc0/build/release/b2ec465d0fb5b5eb39d2e1e3f74041a5d2fc92d413b71aa7ea0b6fb082ccba9c"]
67
+ print(f"\n\nLoading lc0 network: {network_paths[skill]}\n\n")
68
+ self.weights = Weights(network_paths[skill])
69
+ self.backend = Backend(weights=self.weights)
70
+ self.gamestate = GameState()
71
+
72
+ def get_move(self, board: chess.Board, game_state: str, temperature: float):
73
+ self.gamestate = GameState(fen=board.fen())
74
+ input_planes = self.gamestate.as_input(self.backend)
75
+ result = self.backend.evaluate(input_planes)[0]
76
+ moves = self.gamestate.moves()
77
+ policy_indices = self.gamestate.policy_indices()
78
+ move_probs = np.array(result.p_softmax(*policy_indices))
79
+ best_move_idx = move_probs.argmax()
80
+ best_move = moves[best_move_idx]
81
+ return board.san(chess.Move.from_uci(best_move))
82
+
83
+ def get_config(self) -> dict:
84
+ return {"network": self.weights, "skill_level": self.skill, "play_time": 0}
85
+
86
+
87
+ class StockfishPlayer(Player):
88
+
89
+ @staticmethod
90
+ def get_stockfish_path() -> str:
91
+ """
92
+ Determines the operating system and returns the appropriate path for Stockfish.
93
+
94
+ Returns:
95
+ str: Path to the Stockfish executable based on the operating system.
96
+ """
97
+ if platform.system() == 'Linux':
98
+ return "/usr/games/stockfish"
99
+ elif platform.system() == 'Darwin': # Darwin is the system name for macOS
100
+ return "stockfish"
101
+ elif platform.system() == 'Windows':
102
+ return r"C:\Users\Haile\Downloads\stockfish\stockfish-windows-x86-64-avx2.exe"
103
+ else:
104
+ raise OSError("Unsupported operating system")
105
+
106
+ def __init__(self, skill_level: int, play_time: float):
107
+ self._skill_level = skill_level
108
+ self._play_time = play_time
109
+ # If getting started, you need to run brew install stockfish
110
+ stockfish_path = StockfishPlayer.get_stockfish_path()
111
+ self._engine = chess.engine.SimpleEngine.popen_uci(stockfish_path)
112
+
113
+ def get_move(
114
+ self, board: chess.Board, game_state: str, temperature: float
115
+ ) -> Optional[str]:
116
+ if self._skill_level == -2:
117
+ legal_moves = list(board.legal_moves)
118
+ random_move = random.choice(legal_moves)
119
+ return board.san(random_move)
120
+ elif self._skill_level < 0:
121
+ self._engine.configure({"Skill Level": 0})
122
+ result = self._engine.play(
123
+ board, chess.engine.Limit(time=1e-8, depth=1, nodes=1)
124
+ )
125
+
126
+ else:
127
+ self._engine.configure({"Skill Level": self._skill_level})
128
+ result = self._engine.play(board, chess.engine.Limit(time=self._play_time))
129
+ if result.move is None:
130
+ return None
131
+ return board.san(result.move)
132
+
133
+ def get_config(self) -> dict:
134
+ return {"skill_level": self._skill_level, "play_time": self._play_time}
135
+
136
+ def close(self):
137
+ self._engine.quit()
138
+
139
+
140
+ class HumanPlayer(Player):
141
+ def get_move(self, board: chess.Board, game_state: str, temperature: float) -> str:
142
+ # Print board for human player
143
+ print(board)
144
+ while True:
145
+ move = input("Enter your move (SAN format): ")
146
+ try:
147
+ move_uci = board.parse_san(move)
148
+ if move_uci in board.legal_moves:
149
+ return move
150
+ except:
151
+ print("Illegal move, try again.")
152
+
153
+ def get_config(self) -> dict:
154
+ return {"player": "human"}
155
+
156
+
157
+ def get_gpt_response(game_state: str, model: str, temperature: float) -> Optional[str]:
158
+ # trying to prevent what I believe to be rate limit issues
159
+ if model == "gpt-4":
160
+ time.sleep(0.4)
161
+ response = gpt_query.get_gpt_response(game_state, model, temperature)
162
+ return response
163
+
164
+
165
+ def get_move_from_gpt_response(response: Optional[str]) -> Optional[str]:
166
+ if response is None:
167
+ return None
168
+
169
+ # Parse the response to get only the first move
170
+ moves = response.split()
171
+ first_move = moves[0] if moves else None
172
+
173
+ return first_move
174
+
175
+
176
+ def record_results(
177
+ board: chess.Board,
178
+ player_one: Player,
179
+ player_two: Player,
180
+ game_state: str,
181
+ player_one_illegal_moves: int,
182
+ player_two_illegal_moves: int,
183
+ player_one_legal_moves: int,
184
+ player_two_legal_moves: int,
185
+ total_time: float,
186
+ player_one_resignation: bool,
187
+ player_two_resignation: bool,
188
+ player_one_failed_to_find_legal_move: bool,
189
+ player_two_failed_to_find_legal_move: bool,
190
+ total_moves: int,
191
+ illegal_moves: int,
192
+ ):
193
+ unique_game_id = generate_unique_game_id()
194
+
195
+ (
196
+ player_one_title,
197
+ player_two_title,
198
+ player_one_time,
199
+ player_two_time,
200
+ ) = get_player_titles_and_time(player_one, player_two)
201
+
202
+ if player_one_resignation or player_one_failed_to_find_legal_move:
203
+ result = "0-1"
204
+ player_one_score = 0
205
+ player_two_score = 1
206
+ elif player_two_resignation or player_two_failed_to_find_legal_move:
207
+ result = "1-0"
208
+ player_one_score = 1
209
+ player_two_score = 0
210
+ else:
211
+ result = board.result()
212
+ # Hmmm.... debating this one. Annoying if I leave it running and it fails here for some reason, probably involving some
213
+ # resignation / failed move situation I didn't think of
214
+ # -1e10 at least ensures it doesn't fail silently
215
+ if "-" in result:
216
+ player_one_score = result.split("-")[0]
217
+ player_two_score = result.split("-")[1]
218
+ elif result == "*": # Draw due to hitting max moves
219
+ player_one_score = 0#1/2
220
+ player_two_score = 1#1/2
221
+ else:
222
+ player_one_score = -1e10
223
+ player_two_score = -1e10
224
+
225
+ info_dict = {
226
+ "game_id": unique_game_id,
227
+ "transcript": game_state,
228
+ "result": result,
229
+ "player_one": player_one_title,
230
+ "player_two": player_two_title,
231
+ "player_one_time": player_one_time,
232
+ "player_two_time": player_two_time,
233
+ "player_one_score": player_one_score,
234
+ "player_two_score": player_two_score,
235
+ "player_one_illegal_moves": player_one_illegal_moves,
236
+ "player_two_illegal_moves": player_two_illegal_moves,
237
+ "player_one_legal_moves": player_one_legal_moves,
238
+ "player_two_legal_moves": player_two_legal_moves,
239
+ "player_one_resignation": player_one_resignation,
240
+ "player_two_resignation": player_two_resignation,
241
+ "player_one_failed_to_find_legal_move": player_one_failed_to_find_legal_move,
242
+ "player_two_failed_to_find_legal_move": player_two_failed_to_find_legal_move,
243
+ "game_title": f"{player_one_title} vs. {player_two_title}",
244
+ "number_of_moves": board.fullmove_number,
245
+ "time_taken": total_time,
246
+ "total_moves": total_moves,
247
+ "illegal_moves": illegal_moves,
248
+ }
249
+
250
+ if RUN_FOR_ANALYSIS:
251
+ csv_file_path = f"logs/{player_one_recording_name}_vs_{player_two_recording_name}"
252
+ csv_file_path = csv_file_path.replace(".", "_") # Because I'm using ckpt filenames for nanogpt models
253
+ csv_file_path += ".csv"
254
+ else:
255
+ csv_file_path = recording_file
256
+
257
+
258
+
259
+ # Determine if we need to write headers (in case the file doesn't exist yet)
260
+ write_headers = not os.path.exists(csv_file_path)
261
+
262
+ # Append the results to the CSV file
263
+ with open(csv_file_path, "a", newline="") as csv_file: # THIS WAS APPEND
264
+ writer = csv.DictWriter(csv_file, fieldnames=info_dict.keys())
265
+ if write_headers:
266
+ writer.writeheader()
267
+ writer.writerow(info_dict)
268
+
269
+ with open("game.txt", "w") as f:
270
+ f.write(game_state)
271
+
272
+
273
+ def generate_unique_game_id() -> str:
274
+ timestamp = int(time.time())
275
+ random_num = random.randint(1000, 9999) # 4-digit random number
276
+ return f"{timestamp}-{random_num}"
277
+
278
+
279
+ def get_player_titles_and_time(
280
+ player_one: Player, player_two: Player
281
+ ) -> Tuple[str, str, Optional[float], Optional[float]]:
282
+ player_one_config = player_one.get_config()
283
+ player_two_config = player_two.get_config()
284
+
285
+ # For player one
286
+ if "model" in player_one_config:
287
+ player_one_title = player_one_config["model"]
288
+ player_one_time = None
289
+ else:
290
+ player_one_title = f"Stockfish {player_one_config['skill_level']}"
291
+ player_one_time = player_one_config["play_time"]
292
+
293
+ # For player two
294
+ if "model" in player_two_config:
295
+ player_two_title = player_two_config["model"]
296
+ player_two_time = None
297
+ else:
298
+ player_two_title = f"Stockfish {player_two_config['skill_level']}"
299
+ player_two_time = player_two_config["play_time"]
300
+
301
+ return (player_one_title, player_two_title, player_one_time, player_two_time)
302
+
303
+
304
+ used_openings = []
305
+ def initialize_game_with_opening(
306
+ game_state: str, board: chess.Board
307
+ ) -> Tuple[str, chess.Board]:
308
+ global used_openings
309
+ with open("openings.csv", "r") as file:
310
+ lines = file.readlines()[1:] # Skip header
311
+ moves_string = random.choice(lines)
312
+ while moves_string in used_openings:
313
+ moves_string = random.choice(lines)
314
+ used_openings.append(moves_string)
315
+ if move_num_in_gamestate:
316
+ game_state = moves_string.rstrip() + " "
317
+ else:
318
+ game_state = ' '.join(['.' + m.split(".")[-1] if "." in m else m for m in moves_string.split()])
319
+ game_state = game_state.rstrip() + " "
320
+ # Splitting the moves string on spaces
321
+ tokens = moves_string.split()
322
+
323
+ for token in tokens:
324
+ # If the token contains a period, it's a move number + move combination
325
+ if "." in token:
326
+ move = token.split(".")[-1] # Take the move part after the period
327
+ else:
328
+ move = token
329
+
330
+ board.push_san(move)
331
+ return game_state.rstrip(), board
332
+
333
+
334
+ # Return is (move_san, move_uci, attempts, is_resignation, is_illegal_move)
335
+ def get_legal_move(
336
+ player: Player,
337
+ board: chess.Board,
338
+ game_state: str,
339
+ player_one: bool,
340
+ max_attempts: int = 5,
341
+ ) -> LegalMoveResponse:
342
+ """Request a move from the player and ensure it's legal."""
343
+ move_san = None
344
+ move_uci = None
345
+
346
+ for attempt in range(max_attempts):
347
+ #print(f"get_legal_move: |{game_state}|")
348
+ move_san = player.get_move(
349
+ board, game_state, min(((attempt / max_attempts) * 1) + 0.001, 0.75)
350
+ )
351
+
352
+ # Sometimes when GPT thinks it's the end of the game, it will just output the result
353
+ # Like "1-0". If so, this really isn't an illegal move, so we'll add a check for that.
354
+ if move_san is not None:
355
+ if move_san == "1-0" or move_san == "0-1" or move_san == "1/2-1/2":
356
+ print(f"{move_san}, player has resigned")
357
+ return LegalMoveResponse(
358
+ move_san=None,
359
+ move_uci=None,
360
+ attempts=attempt,
361
+ is_resignation=True,
362
+ )
363
+
364
+ try:
365
+ move_uci = board.parse_san(move_san)
366
+ except Exception as e:
367
+ print(f"Error parsing move {move_san}: {e}")
368
+ # check if player is gpt-3.5-turbo-instruct
369
+ # only recording errors for gpt-3.5-turbo-instruct because it's errors are so rare
370
+ if player.get_config()["model"] == "gpt-3.5-turbo-instruct":
371
+ with open("gpt-3.5-turbo-instruct-illegal-moves.txt", "a") as f:
372
+ f.write(f"{game_state}\n{move_san}\n")
373
+ continue
374
+
375
+ if move_uci in board.legal_moves:
376
+ if player_one == False:
377
+ if not move_san.startswith(" "):
378
+ move_san = " " + move_san
379
+ else:
380
+ if move_san.startswith(" "):
381
+ move_san = move_san[1:]
382
+ return LegalMoveResponse(move_san, move_uci, attempt)
383
+ print(f"Illegal move: {move_san}")
384
+
385
+ # If we reach here, the player has made illegal moves for all attempts.
386
+ print(f"{player} provided illegal moves for {max_attempts} attempts.")
387
+ return LegalMoveResponse(
388
+ move_san=None, move_uci=None, attempts=max_attempts, is_illegal_move=True
389
+ )
390
+
391
+
392
+ def play_turn(
393
+ player: Player, board: chess.Board, game_state: str, player_one: bool
394
+ ) -> Tuple[str, bool, bool, int]:
395
+ result = get_legal_move(player, board, game_state, player_one, 5)
396
+ illegal_moves = result.attempts
397
+ move_san = result.move_san
398
+ move_uci = result.move_uci
399
+ resignation = result.is_resignation
400
+ failed_to_find_legal_move = result.is_illegal_move
401
+
402
+ if resignation:
403
+ print(f"{player} resigned with result: {board.result()}")
404
+ elif failed_to_find_legal_move:
405
+ print(f"Game over: 5 consecutive illegal moves from {player}")
406
+ elif move_san is None or move_uci is None:
407
+ print(f"Game over: {player} failed to find a legal move")
408
+ else:
409
+ board.push(move_uci)
410
+ game_state += move_san
411
+ print(move_san, end=" ")
412
+
413
+ return game_state, resignation, failed_to_find_legal_move, illegal_moves
414
+
415
+
416
+ def play_game(
417
+ player_one: Player,
418
+ player_two: Player,
419
+ max_games: int = 10,
420
+ random_opening_seed: bool = False,
421
+ ):
422
+ for z in range(max_games):
423
+ print(f"\nGame {z} of {max_games}\n")
424
+
425
+ with open("gpt_inputs/prompt.txt", "r") as f:
426
+ game_state = f.read()
427
+ board = chess.Board()
428
+
429
+ if random_opening_seed:
430
+ game_state, board = initialize_game_with_opening(game_state, board)
431
+ #print(f"play_gamea after init: |{game_state}|")
432
+ player_one_illegal_moves = 0
433
+ player_two_illegal_moves = 0
434
+ player_one_legal_moves = 0
435
+ player_two_legal_moves = 0
436
+ player_one_resignation = False
437
+ player_two_resignation = False
438
+ player_one_failed_to_find_legal_move = False
439
+ player_two_failed_to_find_legal_move = False
440
+ start_time = time.time()
441
+
442
+ total_moves = 0
443
+ illegal_moves = 0
444
+ print_for_human = isinstance(player_one, HumanPlayer) or isinstance(player_two, HumanPlayer)
445
+
446
+ while not board.is_game_over():
447
+ if print_for_human:
448
+ print(board)
449
+
450
+ with open("game.txt", "w") as f:
451
+ f.write(game_state)
452
+ current_move_num = f"{board.fullmove_number if move_num_in_gamestate else ''}."
453
+ total_moves += 1
454
+ # I increment legal moves here so player_two isn't penalized for the game ending before its turn
455
+ player_one_legal_moves += 1
456
+ player_two_legal_moves += 1
457
+
458
+ # this if statement may be overkill, just trying to get format to exactly match PGN notation
459
+ if board.fullmove_number != 1:
460
+ game_state += " "
461
+ game_state += current_move_num
462
+ #print(f"|{game_state}|")
463
+ #print(f"{current_move_num}", end=" ")
464
+
465
+ (
466
+ game_state,
467
+ player_one_resignation,
468
+ player_one_failed_to_find_legal_move,
469
+ illegal_moves_one,
470
+ ) = play_turn(player_one, board, game_state, player_one=True)
471
+ player_one_illegal_moves += illegal_moves_one
472
+ if illegal_moves_one != 0:
473
+ player_one_legal_moves -= 1
474
+ if (
475
+ board.is_game_over()
476
+ or player_one_resignation
477
+ or player_one_failed_to_find_legal_move
478
+ ):
479
+ break
480
+
481
+ (
482
+ game_state,
483
+ player_two_resignation,
484
+ player_two_failed_to_find_legal_move,
485
+ illegal_moves_two,
486
+ ) = play_turn(player_two, board, game_state, player_one=False)
487
+ player_two_illegal_moves += illegal_moves_two
488
+ if illegal_moves_two != 0:
489
+ player_two_legal_moves -= 1
490
+ if (
491
+ board.is_game_over()
492
+ or player_two_resignation
493
+ or player_two_failed_to_find_legal_move
494
+ ):
495
+ break
496
+
497
+ print("\n", end="")
498
+
499
+ if total_moves > MAX_MOVES:
500
+ break
501
+
502
+ end_time = time.time()
503
+ total_time = end_time - start_time
504
+ print(f"\nGame over. Total time: {total_time} seconds")
505
+ print(f"Result: {board.result()}")
506
+ print(board)
507
+ print()
508
+ record_results(
509
+ board,
510
+ player_one,
511
+ player_two,
512
+ game_state,
513
+ player_one_illegal_moves,
514
+ player_two_illegal_moves,
515
+ player_one_legal_moves,
516
+ player_two_legal_moves,
517
+ total_time,
518
+ player_one_resignation,
519
+ player_two_resignation,
520
+ player_one_failed_to_find_legal_move,
521
+ player_two_failed_to_find_legal_move,
522
+ total_moves,
523
+ illegal_moves,
524
+ )
525
+ if isinstance(player_one, StockfishPlayer):
526
+ player_one.close()
527
+ if isinstance(player_two, StockfishPlayer):
528
+ player_two.close()
529
+
530
+ # print(game_state)
531
+
532
+
533
+ RUN_FOR_ANALYSIS = True
534
+ MAX_MOVES = 999 # Due to nanogpt max input length of 1024
535
+ recording_file = "logs/determine.csv" # default recording file. Because we are using list [player_ones], recording_file is overwritten
536
+ # player_one_recording_name = "ckpt_8.pt"
537
+ #player_ones = ["ckpt_iter_20000.pt","ckpt_iter_40000.pt","ckpt_iter_60000.pt","ckpt_iter_80000.pt"] #["ckpt.pt"]
538
+ player_ones = ["Xformer/6.6M/ckpt.pt"]
539
+ player_two_recording_name = "lc0_sweep" #"stockfish_sweep"
540
+ move_num_in_gamestate = False
541
+ if __name__ == "__main__":
542
+ for nanogpt_player in player_ones:
543
+ player_one_recording_name = nanogpt_player
544
+ for i in range(2): #range(11):
545
+ num_games = 265 #265 instead of 250 for duplicates (for lc0, stockfish doesn't need it)
546
+ # player_one = GPTPlayer(model="gpt-3.5-turbo-instruct")
547
+ # player_one = LocalLlamaPlayer(model_name="meta-llama/Llama-2-7b-hf")
548
+ # player_one = LocalLoraLlamaPlayer("meta-llama/Llama-2-7b-hf", "/workspace/axolotl/lora2-out")
549
+ # player_one = GPTPlayer(model="gpt-4")
550
+ # player_one = StockfishPlayer(skill_level=-1, play_time=0.1)
551
+
552
+ player_one = NanoGptPlayer(model_name=player_one_recording_name, move_num_in_gamestate=False)
553
+ # player_one = MambaPlayer(model_name=player_one_recording_name, move_num_in_gamestate=False)
554
+ #player_two = StockfishPlayer(skill_level=i, play_time=0.1)
555
+ player_two = LC0PLayer(skill=i)
556
+
557
+ # player_two = GPTPlayer(model="gpt-4")
558
+ # player_two = GPTPlayer(model="gpt-3.5-turbo-instruct")
559
+
560
+ print(f"\n\nSTARTING GAMES AGAINST STOCKFISH LEVEL {i}\n\n")
561
+ #print(f"\n\nSTARTING GAMES AGAINST LC0 LEVEL {i}\n\n")
562
+
563
+ play_game(player_one, player_two, num_games, random_opening_seed=True)
564
+
565
+ print("\n\n\n********\nDONE!\n********\n\n\n")
chess-gpt-eval/mamba.py ADDED
@@ -0,0 +1,368 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from dataclasses import dataclass
3
+ from typing import Union
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ from pscan import pscan
10
+
11
+ """
12
+
13
+ This file closely follows the mamba_simple.py from the official Mamba implementation, and the mamba-minimal by @johnma2006.
14
+ The major differences are :
15
+ -the convolution is done with torch.nn.Conv1d
16
+ -the selective scan is done in PyTorch
17
+
18
+ A sequential version of the selective scan is also available for comparison.
19
+
20
+ - A Mamba model is composed of several layers, which are ResidualBlock.
21
+ - A ResidualBlock is composed of a MambaBlock, a normalization, and a residual connection : ResidualBlock(x) = mamba(norm(x)) + x
22
+ - This leaves us with the MambaBlock : its input x is (B, L, D) and its outputs y is also (B, L, D) (B=batch size, L=seq len, D=model dim).
23
+ First, we expand x into (B, L, 2*ED) (where E is usually 2) and split it into x and z, each (B, L, ED).
24
+ Then, we apply the short 1d conv to x, followed by an activation function (silu), then the SSM.
25
+ We then multiply it by silu(z).
26
+ See Figure 3 of the paper (page 8) for a visual representation of a MambaBlock.
27
+
28
+ """
29
+
30
+ @dataclass
31
+ class MambaConfig:
32
+ d_model: int # D
33
+ n_layers: int
34
+ dt_rank: Union[int, str] = 'auto'
35
+ d_state: int = 16 # N in paper/comments
36
+ expand_factor: int = 2 # E in paper/comments
37
+ d_conv: int = 4
38
+
39
+ dt_min: float = 0.001
40
+ dt_max: float = 0.1
41
+ dt_init: str = "random" # "random" or "constant"
42
+ dt_scale: float = 1.0
43
+ dt_init_floor = 1e-4
44
+
45
+ bias: bool = False
46
+ conv_bias: bool = True
47
+
48
+ pscan: bool = True # use parallel scan mode or sequential mode when training
49
+
50
+ def __post_init__(self):
51
+ self.d_inner = self.expand_factor * self.d_model # E*D = ED in comments
52
+
53
+ if self.dt_rank == 'auto':
54
+ self.dt_rank = math.ceil(self.d_model / 16)
55
+
56
+ class Mamba(nn.Module):
57
+ def __init__(self, config: MambaConfig):
58
+ super().__init__()
59
+
60
+ self.config = config
61
+
62
+ self.layers = nn.ModuleList([ResidualBlock(config) for _ in range(config.n_layers)])
63
+ #self.norm_f = RMSNorm(config.d_model)
64
+
65
+ def forward(self, x):
66
+ # x : (B, L, D)
67
+
68
+ # y : (B, L, D)
69
+
70
+ for layer in self.layers:
71
+ x = layer(x)
72
+
73
+ #x = self.norm_f(x)
74
+
75
+ return x
76
+
77
+ def step(self, x, caches):
78
+ # x : (B, L, D)
79
+ # caches : [cache(layer) for all layers], cache : (h, inputs)
80
+
81
+ # y : (B, L, D)
82
+ # caches : [cache(layer) for all layers], cache : (h, inputs)
83
+
84
+ for i, layer in enumerate(self.layers):
85
+ x, caches[i] = layer.step(x, caches[i])
86
+
87
+ return x, caches
88
+
89
+ class ResidualBlock(nn.Module):
90
+ def __init__(self, config: MambaConfig):
91
+ super().__init__()
92
+
93
+ self.mixer = MambaBlock(config)
94
+ self.norm = RMSNorm(config.d_model)
95
+
96
+ def forward(self, x):
97
+ # x : (B, L, D)
98
+
99
+ # output : (B, L, D)
100
+
101
+ output = self.mixer(self.norm(x)) + x
102
+ return output
103
+
104
+ def step(self, x, cache):
105
+ # x : (B, D)
106
+ # cache : (h, inputs)
107
+ # h : (B, ED, N)
108
+ # inputs: (B, ED, d_conv-1)
109
+
110
+ # output : (B, D)
111
+ # cache : (h, inputs)
112
+
113
+ output, cache = self.mixer.step(self.norm(x), cache)
114
+ output = output + x
115
+ return output, cache
116
+
117
+ class MambaBlock(nn.Module):
118
+ def __init__(self, config: MambaConfig):
119
+ super().__init__()
120
+
121
+ self.config = config
122
+
123
+ # projects block input from D to 2*ED (two branches)
124
+ self.in_proj = nn.Linear(config.d_model, 2 * config.d_inner, bias=config.bias)
125
+
126
+ self.conv1d = nn.Conv1d(in_channels=config.d_inner, out_channels=config.d_inner,
127
+ kernel_size=config.d_conv, bias=config.conv_bias,
128
+ groups=config.d_inner,
129
+ padding=config.d_conv - 1)
130
+
131
+ nn.init.kaiming_normal_(self.conv1d.weight, mode='fan_out', nonlinearity='leaky_relu')
132
+
133
+ # projects x to input-dependent Δ, B, C
134
+ self.x_proj = nn.Linear(config.d_inner, config.dt_rank + 2 * config.d_state, bias=False)
135
+
136
+ # projects Δ from dt_rank to d_inner
137
+ self.dt_proj = nn.Linear(config.dt_rank, config.d_inner, bias=True)
138
+
139
+ # dt initialization
140
+ # dt weights
141
+ dt_init_std = config.dt_rank**-0.5 * config.dt_scale
142
+ if config.dt_init == "constant":
143
+ nn.init.constant_(self.dt_proj.weight, dt_init_std)
144
+ elif config.dt_init == "random":
145
+ nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
146
+ else:
147
+ raise NotImplementedError
148
+
149
+ # dt bias
150
+ dt = torch.exp(
151
+ torch.rand(config.d_inner) * (math.log(config.dt_max) - math.log(config.dt_min)) + math.log(config.dt_min)
152
+ ).clamp(min=config.dt_init_floor)
153
+ inv_dt = dt + torch.log(-torch.expm1(-dt)) # inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
154
+ with torch.no_grad():
155
+ self.dt_proj.bias.copy_(inv_dt)
156
+ #self.dt_proj.bias._no_reinit = True # initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
157
+ # todo : explain why removed
158
+
159
+ # S4D real initialization
160
+ A = torch.arange(1, config.d_state + 1, dtype=torch.float32).repeat(config.d_inner, 1)
161
+ self.A_log = nn.Parameter(torch.log(A)) # why store A in log ? to keep A < 0 (cf -torch.exp(...)) ? for gradient stability ?
162
+ self.D = nn.Parameter(torch.ones(config.d_inner))
163
+
164
+ # projects block output from ED back to D
165
+ self.out_proj = nn.Linear(config.d_inner, config.d_model, bias=config.bias)
166
+
167
+ def forward(self, x):
168
+ # x : (B, L, D)
169
+
170
+ # y : (B, L, D)
171
+
172
+ _, L, _ = x.shape
173
+
174
+ xz = self.in_proj(x) # (B, L, 2*ED)
175
+ x, z = xz.chunk(2, dim=-1) # (B, L, ED), (B, L, ED)
176
+
177
+ # x branch
178
+ x = x.transpose(1, 2) # (B, ED, L)
179
+ x = self.conv1d(x)[:, :, :L] # depthwise convolution over time, with a short filter
180
+ x = x.transpose(1, 2) # (B, L, ED)
181
+
182
+ x = F.silu(x)
183
+ y = self.ssm(x)
184
+
185
+ # z branch
186
+ z = F.silu(z)
187
+
188
+ output = y * z
189
+ output = self.out_proj(output) # (B, L, D)
190
+
191
+ return output
192
+
193
+ def ssm(self, x):
194
+ # x : (B, L, ED)
195
+
196
+ # y : (B, L, ED)
197
+
198
+ A = -torch.exp(self.A_log.float()) # (ED, N)
199
+ D = self.D.float()
200
+ # TODO remove .float()
201
+
202
+ deltaBC = self.x_proj(x) # (B, L, dt_rank+2*N)
203
+
204
+ delta, B, C = torch.split(deltaBC, [self.config.dt_rank, self.config.d_state, self.config.d_state], dim=-1) # (B, L, dt_rank), (B, L, N), (B, L, N)
205
+ delta = F.softplus(self.dt_proj(delta)) # (B, L, ED)
206
+
207
+ if self.config.pscan:
208
+ y = self.selective_scan(x, delta, A, B, C, D)
209
+ else:
210
+ y = self.selective_scan_seq(x, delta, A, B, C, D)
211
+
212
+ return y
213
+
214
+ def selective_scan(self, x, delta, A, B, C, D):
215
+ # x : (B, L, ED)
216
+ # Δ : (B, L, ED)
217
+ # A : (ED, N)
218
+ # B : (B, L, N)
219
+ # C : (B, L, N)
220
+ # D : (ED)
221
+
222
+ # y : (B, L, ED)
223
+
224
+ deltaA = torch.exp(delta.unsqueeze(-1) * A) # (B, L, ED, N)
225
+ deltaB = delta.unsqueeze(-1) * B.unsqueeze(2) # (B, L, ED, N)
226
+
227
+ BX = deltaB * (x.unsqueeze(-1)) # (B, L, ED, N)
228
+
229
+ hs = pscan(deltaA, BX)
230
+
231
+ y = (hs @ C.unsqueeze(-1)).squeeze(3) # (B, L, ED, N) @ (B, L, N, 1) -> (B, L, ED, 1)
232
+
233
+ y = y + D * x
234
+
235
+ return y
236
+
237
+ def selective_scan_seq(self, x, delta, A, B, C, D):
238
+ # x : (B, L, ED)
239
+ # Δ : (B, L, ED)
240
+ # A : (ED, N)
241
+ # B : (B, L, N)
242
+ # C : (B, L, N)
243
+ # D : (ED)
244
+
245
+ # y : (B, L, ED)
246
+
247
+ _, L, _ = x.shape
248
+
249
+ deltaA = torch.exp(delta.unsqueeze(-1) * A) # (B, L, ED, N)
250
+ deltaB = delta.unsqueeze(-1) * B.unsqueeze(2) # (B, L, ED, N)
251
+
252
+ BX = deltaB * (x.unsqueeze(-1)) # (B, L, ED, N)
253
+
254
+ h = torch.zeros(x.size(0), self.config.d_inner, self.config.d_state, device=deltaA.device) # (B, ED, N)
255
+ hs = []
256
+
257
+ for t in range(0, L):
258
+ h = deltaA[:, t] * h + BX[:, t]
259
+ hs.append(h)
260
+
261
+ hs = torch.stack(hs, dim=1) # (B, L, ED, N)
262
+
263
+ y = (hs @ C.unsqueeze(-1)).squeeze(3) # (B, L, ED, N) @ (B, L, N, 1) -> (B, L, ED, 1)
264
+
265
+ y = y + D * x
266
+
267
+ return y
268
+
269
+ # -------------------------- inference -------------------------- #
270
+ """
271
+ Concerning auto-regressive inference
272
+
273
+ The cool part of using Mamba : inference is constant wrt to sequence length
274
+ We just have to keep in cache, for each layer, two things :
275
+ - the hidden state h (which is (B, ED, N)), as you typically would when doing inference with a RNN
276
+ - the last d_conv-1 inputs of the layer, to be able to compute the 1D conv which is a convolution over the time dimension
277
+ (d_conv is fixed so this doesn't incur a growing cache as we progress on generating the sequence)
278
+ (and d_conv is usually very small, like 4, so we just have to "remember" the last 3 inputs)
279
+
280
+ Concretely, these two quantities are put inside a cache tuple, and are named h and inputs respectively.
281
+ h is (B, ED, N), and inputs is (B, ED, d_conv-1)
282
+ The MambaBlock.step() receives this cache, and, along with outputing the output, alos outputs the updated cache for the next call.
283
+
284
+ The cache object is initialized as follows : (None, torch.zeros()).
285
+ When h is None, the selective scan function detects it and start with h=0.
286
+ The torch.zeros() isn't a problem (it's same as just feeding the input, because the conv1d is padded)
287
+
288
+ As we need one such cache variable per layer, we store a caches object, which is simply a list of cache object. (See mamba_lm.py)
289
+ """
290
+
291
+ def step(self, x, cache):
292
+ # x : (B, D)
293
+ # cache : (h, inputs)
294
+ # h : (B, ED, N)
295
+ # inputs : (B, ED, d_conv-1)
296
+
297
+ # y : (B, D)
298
+ # cache : (h, inputs)
299
+
300
+ h, inputs = cache
301
+
302
+ xz = self.in_proj(x) # (B, 2*ED)
303
+ x, z = xz.chunk(2, dim=1) # (B, ED), (B, ED)
304
+
305
+ # x branch
306
+ x_cache = x.unsqueeze(2)
307
+ x = self.conv1d(torch.cat([inputs, x_cache], dim=2))[:, :, self.config.d_conv-1] # (B, ED)
308
+
309
+ x = F.silu(x)
310
+ y, h = self.ssm_step(x, h)
311
+
312
+ # z branch
313
+ z = F.silu(z)
314
+
315
+ output = y * z
316
+ output = self.out_proj(output) # (B, D)
317
+
318
+ # prepare cache for next call
319
+ inputs = torch.cat([inputs[:, :, 1:], x_cache], dim=2) # (B, ED, d_conv-1)
320
+ cache = (h, inputs)
321
+
322
+ return output, cache
323
+
324
+ def ssm_step(self, x, h):
325
+ # x : (B, ED)
326
+ # h : (B, ED, N)
327
+
328
+ # y : (B, ED)
329
+ # h : (B, ED, N)
330
+
331
+ A = -torch.exp(self.A_log.float()) # (ED, N) # todo : ne pas le faire tout le temps, puisque c'est indépendant de la timestep
332
+ D = self.D.float()
333
+ # TODO remove .float()
334
+
335
+ deltaBC = self.x_proj(x) # (B, dt_rank+2*N)
336
+
337
+ delta, B, C = torch.split(deltaBC, [self.config.dt_rank, self.config.d_state, self.config.d_state], dim=-1) # (B, dt_rank), (B, N), (B, N)
338
+ delta = F.softplus(self.dt_proj(delta)) # (B, ED)
339
+
340
+ deltaA = torch.exp(delta.unsqueeze(-1) * A) # (B, ED, N)
341
+ deltaB = delta.unsqueeze(-1) * B.unsqueeze(1) # (B, ED, N)
342
+
343
+ BX = deltaB * (x.unsqueeze(-1)) # (B, ED, N)
344
+
345
+ if h is None:
346
+ h = torch.zeros(x.size(0), self.config.d_inner, self.config.d_state, device=deltaA.device) # (B, ED, N)
347
+
348
+ h = deltaA * h + BX # (B, ED, N)
349
+
350
+ y = (h @ C.unsqueeze(-1)).squeeze(2) # (B, ED, N) @ (B, N, 1) -> (B, ED, 1)
351
+
352
+ y = y + D * x
353
+
354
+ # todo : pq h.squeeze(1) ??
355
+ return y, h.squeeze(1)
356
+
357
+ # taken straight from https://github.com/johnma2006/mamba-minimal/blob/master/model.py
358
+ class RMSNorm(nn.Module):
359
+ def __init__(self, d_model: int, eps: float = 1e-5):
360
+ super().__init__()
361
+
362
+ self.eps = eps
363
+ self.weight = nn.Parameter(torch.ones(d_model))
364
+
365
+ def forward(self, x):
366
+ output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
367
+
368
+ return output
chess-gpt-eval/mamba/out/meta.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f1121191e401988851de5744fe27fe463c3a086fc8c9a5538ef7fc12162bfb09
3
+ size 373
chess-gpt-eval/mamba_lm.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, fields, asdict
2
+ import json
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from mamba import Mamba, MambaConfig, RMSNorm
9
+
10
+ """
11
+
12
+ Encapsulates a Mamba model as language model. It has an embedding layer, and a LM head which maps the model output to logits.
13
+
14
+ """
15
+
16
+ # TODO generate function : batch size != 1 ? (for now B=1)
17
+ # TODO generate function : top-p sampling
18
+
19
+ @dataclass
20
+ class MambaLMConfig(MambaConfig):
21
+ vocab_size: int = 32000
22
+ pad_vocab_size_multiple: int = 8
23
+
24
+ def __post_init__(self):
25
+ super().__post_init__()
26
+
27
+ #if self.vocab_size % self.pad_vocab_size_multiple != 0:
28
+ # self.vocab_size += (self.pad_vocab_size_multiple - self.vocab_size % self.pad_vocab_size_multiple)
29
+
30
+ def to_mamba_config(self) -> MambaConfig:
31
+ mamba_config_fields = {field.name for field in fields(MambaConfig)}
32
+ filtered_dict = {k: v for k, v in asdict(self).items() if k in mamba_config_fields}
33
+ return MambaConfig(**filtered_dict)
34
+
35
+ # adapted from https://github.com/johnma2006/mamba-minimal
36
+ def from_pretrained(name: str):
37
+ """
38
+ Returns a model loaded with pretrained weights pulled from HuggingFace.
39
+
40
+ Args:
41
+ name: As of now, supports
42
+ * 'state-spaces/mamba-2.8b-slimpj'
43
+ * 'state-spaces/mamba-2.8b'
44
+ * 'state-spaces/mamba-1.4b'
45
+ * 'state-spaces/mamba-790m'
46
+ * 'state-spaces/mamba-370m'
47
+ * 'state-spaces/mamba-130m'
48
+
49
+ Returns:
50
+ model: a Mamba model configured with the proper parameters and initialized with the proper weights
51
+ """
52
+
53
+ from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
54
+ from transformers.utils.hub import cached_file
55
+
56
+ def load_config_hf(model_name):
57
+ resolved_archive_file = cached_file(model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False)
58
+ return json.load(open(resolved_archive_file))
59
+
60
+ def load_state_dict_hf(model_name):
61
+ resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False)
62
+ return torch.load(resolved_archive_file, weights_only=True, map_location='cpu', mmap=True)
63
+
64
+ # copy config data
65
+ config_data = load_config_hf(name)
66
+ config = MambaLMConfig(d_model=config_data['d_model'], n_layers=config_data['n_layer'], vocab_size=config_data['vocab_size'])
67
+
68
+ model = MambaLM(config)
69
+
70
+ # copy weights
71
+ state_dict = load_state_dict_hf(name)
72
+
73
+ new_state_dict = {}
74
+ for key in state_dict:
75
+ if key == 'backbone.embedding.weight' or key == 'backbone.norm_f.weight':
76
+ new_key = key.replace('backbone.', '')
77
+ else:
78
+ new_key = key.replace('backbone', 'mamba')
79
+
80
+ new_state_dict[new_key] = state_dict[key]
81
+
82
+ model.load_state_dict(new_state_dict)
83
+
84
+ return model
85
+
86
+ class MambaLM(nn.Module):
87
+ def __init__(self, lm_config: MambaLMConfig):
88
+ super().__init__()
89
+ self.lm_config = lm_config
90
+ self.config = lm_config.to_mamba_config()
91
+
92
+ self.embedding = nn.Embedding(self.lm_config.vocab_size, self.config.d_model)
93
+ self.mamba = Mamba(self.config)
94
+ self.norm_f = RMSNorm(self.config.d_model)
95
+
96
+ self.lm_head = nn.Linear(self.config.d_model, self.lm_config.vocab_size, bias=False)
97
+ self.lm_head.weight = self.embedding.weight
98
+
99
+ def forward(self, tokens):
100
+ # tokens : (B, L)
101
+
102
+ # logits : (B, L, vocab_size)
103
+
104
+ x = self.embedding(tokens)
105
+
106
+ x = self.mamba(x)
107
+ x = self.norm_f(x)
108
+
109
+ logits = self.lm_head(x)
110
+
111
+ return logits
112
+
113
+ def step(self, token, caches):
114
+ # token : (B)
115
+ # caches : [cache(layer) for all layers], cache : (h, inputs)
116
+
117
+ # logits : (B, vocab_size)
118
+ # caches : [cache(layer) for all layers], cache : (h, inputs)
119
+
120
+ x = self.embedding(token)
121
+
122
+ x, caches = self.mamba.step(x, caches)
123
+ x = self.norm_f(x)
124
+
125
+ logits = self.lm_head(x)
126
+
127
+ return logits, caches
128
+
129
+ # TODO temperature
130
+ # TODO process prompt in parallel, and pass in sequential mode when prompt is finished ?
131
+ def generate(self, tokenizer, prompt: str, num_tokens: int = 50, sample: bool = True, top_k: int = 40):
132
+ self.eval()
133
+
134
+ input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(next(self.parameters()).device) # (1, num_tokens)
135
+
136
+ # caches is a list of cache, one per layer
137
+ # cache is composed of : the hidden state, and the last d_conv-1 inputs
138
+ # the hidden state because the update is like an RNN
139
+ # the last d_conv-1 inputs because they are used in a 1d convolution (usually d_conv=4 so this is not large)
140
+ caches = [(None, torch.zeros(1, self.config.d_inner, self.config.d_conv-1, device=input_ids.device)) for _ in range(self.config.n_layers)]
141
+
142
+ for i in range(input_ids.size(1) + num_tokens - 1):
143
+ with torch.no_grad():
144
+ # forward the new output, get new cache
145
+ next_token_logits, caches = self.step(input_ids[:, i], caches) # (1, vocab_size), caches
146
+
147
+ # sample (no sampling when the prompt is being processed)
148
+ if i+1 >= input_ids.size(1):
149
+ probs = F.softmax(next_token_logits, dim=-1) # (1, vocab_size)
150
+
151
+ if top_k is not None:
152
+ values, _ = torch.topk(probs, k=top_k) # (1, k) ordered from lowest to biggest
153
+ probs[probs < values[:, -1, None]] = 0
154
+ probs = probs / probs.sum(axis=1, keepdims=True)
155
+
156
+ if sample:
157
+ next_token = torch.multinomial(probs, num_samples=1).squeeze(1) # (1)
158
+ else:
159
+ next_token = torch.argmax(probs, dim=-1) # (1)
160
+
161
+ input_ids = torch.cat([input_ids, next_token.unsqueeze(1)], dim=1)
162
+
163
+ output = [tokenizer.decode(output.tolist()) for output in input_ids][0]
164
+
165
+ self.train()
166
+
167
+ return output
168
+
chess-gpt-eval/mamba_module.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pickle
3
+ import torch
4
+ from mamba_lm import MambaLM, MambaLMConfig, from_pretrained
5
+ from contextlib import nullcontext
6
+
7
+ BASE_DIR = "mamba/"
8
+
9
+ class MambaPlayer:
10
+ def __init__(self, model_name: str, move_num_in_gamestate: bool=False):
11
+ self.model_name = model_name
12
+ self.move_num_in_gamestate = move_num_in_gamestate
13
+ # -----------------------------------------------------------------------------
14
+
15
+ init_from = "resume" # either 'resume' or a Mamba variant (e.g. 'state-spaces/mamba-1.4b')
16
+ out_dir = "out" # ignored if init_from is not 'resume'
17
+ device = "cuda" if torch.cuda.is_available() else "cpu"
18
+ #device = "cpu"
19
+ dtype = 'bfloat16' if torch.cuda.is_bf16_supported() else 'float32'
20
+ seed = 1337
21
+ compile = False # set to True if using PyTorch 2.0 and Mamba supports it
22
+ # -----------------------------------------------------------------------------
23
+
24
+ torch.manual_seed(seed)
25
+ torch.cuda.manual_seed(seed)
26
+
27
+ device_type = (
28
+ "cuda" if "cuda" in device else "cpu"
29
+ ) # for later use in torch.autocast
30
+ ptdtype = {
31
+ "float32": torch.float32,
32
+ "bfloat16": torch.bfloat16,
33
+ "float16": torch.float16,
34
+ }[dtype]
35
+ ctx = (
36
+ nullcontext()
37
+ if device_type == "cpu"
38
+ else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
39
+ )
40
+
41
+ # Model initialization
42
+ if init_from == "resume":
43
+ #ckpt_path = os.path.join(BASE_DIR, out_dir, self.model_name)
44
+ ckpt_path = os.path.normpath(f"../../mamba.py/out/{self.model_name}")
45
+ checkpoint = torch.load(ckpt_path, map_location=device)
46
+ model_config = checkpoint["model_args"]
47
+ model = MambaLM(model_config)
48
+ model.load_state_dict(checkpoint['model'])
49
+ elif init_from.startswith('state-spaces'):
50
+ model = from_pretrained(init_from).to(device)
51
+ else:
52
+ raise ValueError("Invalid init_from value")
53
+
54
+ model.eval()
55
+ model.to(device)
56
+
57
+ if compile and hasattr(torch, 'compile'):
58
+ model = torch.compile(model)
59
+
60
+ # look for the meta pickle in case it is available in the dataset folder
61
+ meta_path = os.path.join(BASE_DIR, "out", "meta.pkl")
62
+ load_meta = os.path.exists(meta_path)
63
+ if move_num_in_gamestate and load_meta:
64
+ with open(meta_path, "rb") as f:
65
+ meta = pickle.load(f)
66
+ stoi, itos = meta["stoi"], meta["itos"]
67
+ vocab_size = meta['vocab_size']
68
+ encode = lambda s: [stoi[c] for c in s]
69
+ decode = lambda l: "".join([itos[i] for i in l])
70
+ else:
71
+ stoi = {' ': 0, '.': 1, 'a': 2, 'b': 3, 'c': 4, 'd': 5, 'e': 6, 'f': 7, 'g': 8, 'h': 9, '1': 10, '2': 11, '3': 12, '4': 13, '5': 14, '6': 15, '7': 16, '8': 17, 'B': 18, 'N': 19, 'R': 20, 'Q': 21, 'K': 22, 'O': 23, 'x': 24, '+': 25, '#': 26, '=': 27}
72
+ itos = {0: ' ', 1: '.', 2: 'a', 3: 'b', 4: 'c', 5: 'd', 6: 'e', 7: 'f', 8: 'g', 9: 'h', 10: '1', 11: '2', 12: '3', 13: '4', 14: '5', 15: '6', 16: '7', 17: '8', 18: 'B', 19: 'N', 20: 'R', 21: 'Q', 22: 'K', 23: 'O', 24: 'x', 25: '+', 26: '#', 27: '='}
73
+ for s in stoi:
74
+ assert itos[stoi[s]] == s
75
+ vocab_size = len(stoi)
76
+ print(f"Vocab size {vocab_size}")
77
+ encode = lambda s: [stoi[c] for c in s.replace('-', '')]
78
+ decode = lambda l: "".join([itos[i] for i in l]).replace("OOO", "O-O-O").replace("OO", "O-O")
79
+
80
+ self.vocab_size = vocab_size
81
+ self.encode = encode
82
+ self.decode = decode
83
+ self.model = model
84
+ self.ctx = ctx
85
+ self.device = device
86
+
87
+ def get_mamba_response(self, game_state: str, temperature: float, max_new_tokens: int, top_k: int):
88
+ game_state = game_state.split("\n\n")[-1].strip()
89
+ #game_state = ";" + game_state
90
+
91
+ # Tokenize the game state
92
+ encoded_prompt = self.encode(game_state)
93
+ input_ids = torch.tensor([encoded_prompt], dtype=torch.long, device=self.device)
94
+
95
+ self.model.eval() # Set the model to evaluation mode
96
+ with torch.no_grad():
97
+ have_non_space = False
98
+ for _ in range(max_new_tokens):
99
+ logits = self.model(input_ids)[0, -1, :] # Get logits for the last token
100
+
101
+ # Apply temperature scaling and optionally sample from top k tokens
102
+ logits = logits / temperature
103
+ if top_k > 0:
104
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
105
+ logits[indices_to_remove] = -float('Inf')
106
+
107
+ probs = torch.nn.functional.softmax(logits, dim=-1)
108
+ next_token_id = torch.multinomial(probs, num_samples=1)
109
+ if have_non_space and (next_token_id == 0 or next_token_id==4):
110
+ break
111
+ else:
112
+ have_non_space = True
113
+ input_ids = torch.cat([input_ids, next_token_id.unsqueeze(0)], dim=1)
114
+
115
+ model_response = self.decode(input_ids[0].tolist())
116
+ model_response = model_response[len(game_state):].split(";")[0]
117
+ return model_response
118
+
119
+ #def encode(self, text: str):
120
+ # Implement the appropriate tokenization for MambaLM
121
+ # This could be a simple mapping or a more complex tokenizer
122
+ # return [stoi[char] for char in text] # Example
123
+
124
+ #def decode(self, token_ids: list):
125
+ # Implement the appropriate decoding for MambaLM
126
+ # return ''.join([itos[id] for id in token_ids]) # Example
127
+
128
+ def get_move_from_response(self, response: str) -> str:
129
+ if not response:
130
+ return None
131
+ # Parse the response to get only the first move
132
+ moves = response.split()
133
+ first_move = moves[0]
134
+ first_move = first_move.lstrip('.') # A patch for a weird phase during training ... doesn't seem to be an issue anymore, but don't see the harm.
135
+
136
+ return first_move
137
+
138
+ def get_move(self, board: str, game_state: str, temperature: float) -> str:
139
+ completion = self.get_mamba_response(game_state, temperature, 8, self.vocab_size)
140
+ return self.get_move_from_response(completion)
141
+
142
+ def get_config(self) -> dict:
143
+ return {"model": self.model_name}
144
+
chess-gpt-eval/nanogpt/__pycache__/model.cpython-310.pyc ADDED
Binary file (12.5 kB). View file
 
chess-gpt-eval/nanogpt/__pycache__/nanogpt_module.cpython-310.pyc ADDED
Binary file (5.19 kB). View file
 
chess-gpt-eval/nanogpt/__pycache__/xformer.cpython-310.pyc ADDED
Binary file (12.5 kB). View file
 
chess-gpt-eval/nanogpt/configurator.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Poor Man's Configurator. Probably a terrible idea. Example usage:
3
+ $ python train.py config/override_file.py --batch_size=32
4
+ this will first run config/override_file.py, then override batch_size to 32
5
+
6
+ The code in this file will be run as follows from e.g. train.py:
7
+ >>> exec(open('configurator.py').read())
8
+
9
+ So it's not a Python module, it's just shuttling this code away from train.py
10
+ The code in this script then overrides the globals()
11
+
12
+ I know people are not going to love this, I just really dislike configuration
13
+ complexity and having to prepend config. to every single variable. If someone
14
+ comes up with a better simple Python solution I am all ears.
15
+ """
16
+
17
+ import sys
18
+ from ast import literal_eval
19
+
20
+ for arg in sys.argv[1:]:
21
+ if '=' not in arg:
22
+ # assume it's the name of a config file
23
+ assert not arg.startswith('--')
24
+ config_file = arg
25
+ print(f"Overriding config with {config_file}:")
26
+ with open(config_file) as f:
27
+ print(f.read())
28
+ exec(open(config_file).read())
29
+ else:
30
+ # assume it's a --key=value argument
31
+ assert arg.startswith('--')
32
+ key, val = arg.split('=')
33
+ key = key[2:]
34
+ if key in globals():
35
+ try:
36
+ # attempt to eval it it (e.g. if bool, number, or etc)
37
+ attempt = literal_eval(val)
38
+ except (SyntaxError, ValueError):
39
+ # if that goes wrong, just use the string
40
+ attempt = val
41
+ # ensure the types match ok
42
+ assert type(attempt) == type(globals()[key])
43
+ # cross fingers
44
+ print(f"Overriding: {key} = {attempt}")
45
+ globals()[key] = attempt
46
+ else:
47
+ raise ValueError(f"Unknown config key: {key}")
chess-gpt-eval/nanogpt/model.py ADDED
@@ -0,0 +1,330 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Full definition of a GPT Language Model, all of it in this single file.
3
+ References:
4
+ 1) the official GPT-2 TensorFlow implementation released by OpenAI:
5
+ https://github.com/openai/gpt-2/blob/master/src/model.py
6
+ 2) huggingface/transformers PyTorch implementation:
7
+ https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py
8
+ """
9
+
10
+ import math
11
+ import inspect
12
+ from dataclasses import dataclass
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ from torch.nn import functional as F
17
+
18
+ class LayerNorm(nn.Module):
19
+ """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """
20
+
21
+ def __init__(self, ndim, bias):
22
+ super().__init__()
23
+ self.weight = nn.Parameter(torch.ones(ndim))
24
+ self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
25
+
26
+ def forward(self, input):
27
+ return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
28
+
29
+ class CausalSelfAttention(nn.Module):
30
+
31
+ def __init__(self, config):
32
+ super().__init__()
33
+ assert config.n_embd % config.n_head == 0
34
+ # key, query, value projections for all heads, but in a batch
35
+ self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
36
+ # output projection
37
+ self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
38
+ # regularization
39
+ self.attn_dropout = nn.Dropout(config.dropout)
40
+ self.resid_dropout = nn.Dropout(config.dropout)
41
+ self.n_head = config.n_head
42
+ self.n_embd = config.n_embd
43
+ self.dropout = config.dropout
44
+ # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
45
+ self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
46
+ if not self.flash:
47
+ print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
48
+ # causal mask to ensure that attention is only applied to the left in the input sequence
49
+ self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
50
+ .view(1, 1, config.block_size, config.block_size))
51
+
52
+ def forward(self, x):
53
+ B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
54
+
55
+ # calculate query, key, values for all heads in batch and move head forward to be the batch dim
56
+ q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
57
+ k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
58
+ q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
59
+ v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
60
+
61
+ # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
62
+ if self.flash:
63
+ # efficient attention using Flash Attention CUDA kernels
64
+ y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)
65
+ else:
66
+ # manual implementation of attention
67
+ att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
68
+ att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
69
+ att = F.softmax(att, dim=-1)
70
+ att = self.attn_dropout(att)
71
+ y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
72
+ y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
73
+
74
+ # output projection
75
+ y = self.resid_dropout(self.c_proj(y))
76
+ return y
77
+
78
+ class MLP(nn.Module):
79
+
80
+ def __init__(self, config):
81
+ super().__init__()
82
+ self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
83
+ self.gelu = nn.GELU()
84
+ self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
85
+ self.dropout = nn.Dropout(config.dropout)
86
+
87
+ def forward(self, x):
88
+ x = self.c_fc(x)
89
+ x = self.gelu(x)
90
+ x = self.c_proj(x)
91
+ x = self.dropout(x)
92
+ return x
93
+
94
+ class Block(nn.Module):
95
+
96
+ def __init__(self, config):
97
+ super().__init__()
98
+ self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
99
+ self.attn = CausalSelfAttention(config)
100
+ self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
101
+ self.mlp = MLP(config)
102
+
103
+ def forward(self, x):
104
+ x = x + self.attn(self.ln_1(x))
105
+ x = x + self.mlp(self.ln_2(x))
106
+ return x
107
+
108
+ @dataclass
109
+ class GPTConfig:
110
+ block_size: int = 1024
111
+ vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
112
+ n_layer: int = 12
113
+ n_head: int = 12
114
+ n_embd: int = 768
115
+ dropout: float = 0.0
116
+ bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
117
+
118
+ class GPT(nn.Module):
119
+
120
+ def __init__(self, config):
121
+ super().__init__()
122
+ assert config.vocab_size is not None
123
+ assert config.block_size is not None
124
+ self.config = config
125
+
126
+ self.transformer = nn.ModuleDict(dict(
127
+ wte = nn.Embedding(config.vocab_size, config.n_embd),
128
+ wpe = nn.Embedding(config.block_size, config.n_embd),
129
+ drop = nn.Dropout(config.dropout),
130
+ h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
131
+ ln_f = LayerNorm(config.n_embd, bias=config.bias),
132
+ ))
133
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
134
+ # with weight tying when using torch.compile() some warnings get generated:
135
+ # "UserWarning: functional_call was passed multiple values for tied weights.
136
+ # This behavior is deprecated and will be an error in future versions"
137
+ # not 100% sure what this is, so far seems to be harmless. TODO investigate
138
+ self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying
139
+
140
+ # init all weights
141
+ self.apply(self._init_weights)
142
+ # apply special scaled init to the residual projections, per GPT-2 paper
143
+ for pn, p in self.named_parameters():
144
+ if pn.endswith('c_proj.weight'):
145
+ torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))
146
+
147
+ # report number of parameters
148
+ #print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))
149
+
150
+ def get_num_params(self, non_embedding=True):
151
+ """
152
+ Return the number of parameters in the model.
153
+ For non-embedding count (default), the position embeddings get subtracted.
154
+ The token embeddings would too, except due to the parameter sharing these
155
+ params are actually used as weights in the final layer, so we include them.
156
+ """
157
+ n_params = sum(p.numel() for p in self.parameters())
158
+ if non_embedding:
159
+ n_params -= self.transformer.wpe.weight.numel()
160
+ return n_params
161
+
162
+ def _init_weights(self, module):
163
+ if isinstance(module, nn.Linear):
164
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
165
+ if module.bias is not None:
166
+ torch.nn.init.zeros_(module.bias)
167
+ elif isinstance(module, nn.Embedding):
168
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
169
+
170
+ def forward(self, idx, targets=None):
171
+ device = idx.device
172
+ b, t = idx.size()
173
+ assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
174
+ pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)
175
+
176
+ # forward the GPT model itself
177
+ tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
178
+ pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
179
+ x = self.transformer.drop(tok_emb + pos_emb)
180
+ for block in self.transformer.h:
181
+ x = block(x)
182
+ x = self.transformer.ln_f(x)
183
+
184
+ if targets is not None:
185
+ # if we are given some desired targets also calculate the loss
186
+ logits = self.lm_head(x)
187
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
188
+ else:
189
+ # inference-time mini-optimization: only forward the lm_head on the very last position
190
+ logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
191
+ loss = None
192
+
193
+ return logits, loss
194
+
195
+ def crop_block_size(self, block_size):
196
+ # model surgery to decrease the block size if necessary
197
+ # e.g. we may load the GPT2 pretrained model checkpoint (block size 1024)
198
+ # but want to use a smaller block size for some smaller, simpler model
199
+ assert block_size <= self.config.block_size
200
+ self.config.block_size = block_size
201
+ self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:block_size])
202
+ for block in self.transformer.h:
203
+ if hasattr(block.attn, 'bias'):
204
+ block.attn.bias = block.attn.bias[:,:,:block_size,:block_size]
205
+
206
+ @classmethod
207
+ def from_pretrained(cls, model_type, override_args=None):
208
+ assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
209
+ override_args = override_args or {} # default to empty dict
210
+ # only dropout can be overridden see more notes below
211
+ assert all(k == 'dropout' for k in override_args)
212
+ from transformers import GPT2LMHeadModel
213
+ print("loading weights from pretrained gpt: %s" % model_type)
214
+
215
+ # n_layer, n_head and n_embd are determined from model_type
216
+ config_args = {
217
+ 'gpt2': dict(n_layer=12, n_head=12, n_embd=768), # 124M params
218
+ 'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024), # 350M params
219
+ 'gpt2-large': dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
220
+ 'gpt2-xl': dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
221
+ }[model_type]
222
+ print("forcing vocab_size=50257, block_size=1024, bias=True")
223
+ config_args['vocab_size'] = 50257 # always 50257 for GPT model checkpoints
224
+ config_args['block_size'] = 1024 # always 1024 for GPT model checkpoints
225
+ config_args['bias'] = True # always True for GPT model checkpoints
226
+ # we can override the dropout rate, if desired
227
+ if 'dropout' in override_args:
228
+ print(f"overriding dropout rate to {override_args['dropout']}")
229
+ config_args['dropout'] = override_args['dropout']
230
+ # create a from-scratch initialized minGPT model
231
+ config = GPTConfig(**config_args)
232
+ model = GPT(config)
233
+ sd = model.state_dict()
234
+ sd_keys = sd.keys()
235
+ sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')] # discard this mask / buffer, not a param
236
+
237
+ # init a huggingface/transformers model
238
+ model_hf = GPT2LMHeadModel.from_pretrained(model_type)
239
+ sd_hf = model_hf.state_dict()
240
+
241
+ # copy while ensuring all of the parameters are aligned and match in names and shapes
242
+ sd_keys_hf = sd_hf.keys()
243
+ sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')] # ignore these, just a buffer
244
+ sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')] # same, just the mask (buffer)
245
+ transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
246
+ # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear
247
+ # this means that we have to transpose these weights when we import them
248
+ assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
249
+ for k in sd_keys_hf:
250
+ if any(k.endswith(w) for w in transposed):
251
+ # special treatment for the Conv1D weights we need to transpose
252
+ assert sd_hf[k].shape[::-1] == sd[k].shape
253
+ with torch.no_grad():
254
+ sd[k].copy_(sd_hf[k].t())
255
+ else:
256
+ # vanilla copy over the other parameters
257
+ assert sd_hf[k].shape == sd[k].shape
258
+ with torch.no_grad():
259
+ sd[k].copy_(sd_hf[k])
260
+
261
+ return model
262
+
263
+ def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
264
+ # start with all of the candidate parameters
265
+ param_dict = {pn: p for pn, p in self.named_parameters()}
266
+ # filter out those that do not require grad
267
+ param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
268
+ # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.
269
+ # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
270
+ decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
271
+ nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
272
+ optim_groups = [
273
+ {'params': decay_params, 'weight_decay': weight_decay},
274
+ {'params': nodecay_params, 'weight_decay': 0.0}
275
+ ]
276
+ num_decay_params = sum(p.numel() for p in decay_params)
277
+ num_nodecay_params = sum(p.numel() for p in nodecay_params)
278
+ print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
279
+ print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
280
+ # Create AdamW optimizer and use the fused version if it is available
281
+ fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
282
+ use_fused = fused_available and device_type == 'cuda'
283
+ extra_args = dict(fused=True) if use_fused else dict()
284
+ optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
285
+ print(f"using fused AdamW: {use_fused}")
286
+
287
+ return optimizer
288
+
289
+ def estimate_mfu(self, fwdbwd_per_iter, dt):
290
+ """ estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS """
291
+ # first estimate the number of flops we do per iteration.
292
+ # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311
293
+ N = self.get_num_params()
294
+ cfg = self.config
295
+ L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head, cfg.block_size
296
+ flops_per_token = 6*N + 12*L*H*Q*T
297
+ flops_per_fwdbwd = flops_per_token * T
298
+ flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
299
+ # express our flops throughput as ratio of A100 bfloat16 peak flops
300
+ flops_achieved = flops_per_iter * (1.0/dt) # per second
301
+ flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS
302
+ mfu = flops_achieved / flops_promised
303
+ return mfu
304
+
305
+ @torch.no_grad()
306
+ def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
307
+ """
308
+ Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
309
+ the sequence max_new_tokens times, feeding the predictions back into the model each time.
310
+ Most likely you'll want to make sure to be in model.eval() mode of operation for this.
311
+ """
312
+ for _ in range(max_new_tokens):
313
+ # if the sequence context is growing too long we must crop it at block_size
314
+ idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
315
+ # forward the model to get the logits for the index in the sequence
316
+ logits, _ = self(idx_cond)
317
+ # pluck the logits at the final step and scale by desired temperature
318
+ logits = logits[:, -1, :] / temperature
319
+ # optionally crop the logits to only the top k options
320
+ if top_k is not None:
321
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
322
+ logits[logits < v[:, [-1]]] = -float('Inf')
323
+ # apply softmax to convert logits to (normalized) probabilities
324
+ probs = F.softmax(logits, dim=-1)
325
+ # sample from the distribution
326
+ idx_next = torch.multinomial(probs, num_samples=1)
327
+ # append sampled index to the running sequence and continue
328
+ idx = torch.cat((idx, idx_next), dim=1)
329
+
330
+ return idx
chess-gpt-eval/nanogpt/nanogpt_module.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Sample from a trained model
3
+ """
4
+ import os
5
+ import pickle
6
+ from contextlib import nullcontext
7
+ import torch
8
+ import tiktoken
9
+ from nanogpt.model import GPTConfig, GPT
10
+
11
+ BASE_DIR = "nanogpt/"
12
+
13
+
14
+ class NanoGptPlayer:
15
+ def __init__(self, model_name: str, move_num_in_gamestate: bool=False):
16
+ self.model_name = model_name
17
+ # -----------------------------------------------------------------------------
18
+
19
+ init_from = "resume" # either 'resume' (from an out_dir) or a gpt2 variant (e.g. 'gpt2-xl')
20
+ out_dir = "out" # ignored if init_from is not 'resume'
21
+ input_dir = "addition"
22
+ test_name = "test.txt"
23
+ start = "12+44=" # or "<|endoftext|>" or etc. Can also specify a file, use as: "FILE:prompt.txt"
24
+ num_samples = 1 # number of samples to draw
25
+ max_new_tokens = 6 # number of tokens generated in each sample
26
+ temperature = 0.01 # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions
27
+ top_k = 200 # retain only the top_k most likely tokens, clamp others to have 0 probability
28
+ seed = 1337
29
+ device = "cuda" # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
30
+ #device = "cpu"
31
+ dtype = "float16" # 'float32' or 'bfloat16' or 'float16'
32
+ compile = False # use PyTorch 2.0 to compile the model to be faster
33
+ exec(
34
+ open(f"{BASE_DIR}configurator.py").read()
35
+ ) # overrides from command line or config file
36
+ # -----------------------------------------------------------------------------
37
+
38
+ torch.manual_seed(seed)
39
+ torch.cuda.manual_seed(seed)
40
+ torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
41
+ torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
42
+ device_type = (
43
+ "cuda" if "cuda" in device else "cpu"
44
+ ) # for later use in torch.autocast
45
+ ptdtype = {
46
+ "float32": torch.float32,
47
+ "bfloat16": torch.bfloat16,
48
+ "float16": torch.float16,
49
+ }[dtype]
50
+ ctx = (
51
+ nullcontext()
52
+ if device_type == "cpu"
53
+ else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
54
+ )
55
+
56
+ # model
57
+ if init_from == "resume":
58
+ # init from a model saved in a specific directory
59
+ #ckpt_path = os.path.join(BASE_DIR, out_dir, self.model_name)
60
+ ckpt_path = os.path.normpath(f"../../mamba.py/out/{self.model_name}")
61
+ checkpoint = torch.load(ckpt_path, map_location=device)
62
+ #gptconf = GPTConfig(**checkpoint["model_args"])
63
+ #model = GPT(gptconf)
64
+ model = GPT(checkpoint["model_args"])
65
+ state_dict = checkpoint["model"]
66
+ unwanted_prefix = "_orig_mod."
67
+ for k, v in list(state_dict.items()):
68
+ if k.startswith(unwanted_prefix):
69
+ state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k)
70
+ model.load_state_dict(state_dict)
71
+ elif init_from.startswith("gpt2"):
72
+ # init from a given GPT-2 model
73
+ model = GPT.from_pretrained(init_from, dict(dropout=0.0))
74
+
75
+ model.eval()
76
+ model.to(device)
77
+ if compile:
78
+ model = torch.compile(model) # requires PyTorch 2.0 (optional)
79
+
80
+ # look for the meta pickle in case it is available in the dataset folder
81
+ meta_path = os.path.join(BASE_DIR, "out", "meta.pkl")
82
+ load_meta = os.path.exists(meta_path)
83
+ if move_num_in_gamestate and load_meta:
84
+ with open(meta_path, "rb") as f:
85
+ meta = pickle.load(f)
86
+ stoi, itos = meta["stoi"], meta["itos"]
87
+ vocab_size = meta['vocab_size']
88
+ encode = lambda s: [stoi[c] for c in s]
89
+ decode = lambda l: "".join([itos[i] for i in l])
90
+ else:
91
+ stoi = {' ': 0, '.': 1, 'a': 2, 'b': 3, 'c': 4, 'd': 5, 'e': 6, 'f': 7, 'g': 8, 'h': 9, '1': 10, '2': 11, '3': 12, '4': 13, '5': 14, '6': 15, '7': 16, '8': 17, 'B': 18, 'N': 19, 'R': 20, 'Q': 21, 'K': 22, 'O': 23, 'x': 24, '+': 25, '#': 26, '=': 27}
92
+ itos = {0: ' ', 1: '.', 2: 'a', 3: 'b', 4: 'c', 5: 'd', 6: 'e', 7: 'f', 8: 'g', 9: 'h', 10: '1', 11: '2', 12: '3', 13: '4', 14: '5', 15: '6', 16: '7', 17: '8', 18: 'B', 19: 'N', 20: 'R', 21: 'Q', 22: 'K', 23: 'O', 24: 'x', 25: '+', 26: '#', 27: '='}
93
+ for s in stoi:
94
+ assert itos[stoi[s]] == s
95
+ vocab_size = len(stoi)
96
+ print(f"Vocab size {vocab_size}")
97
+ encode = lambda s: [stoi[c] for c in s.replace('-', '')]
98
+ decode = lambda l: "".join([itos[i] for i in l]).replace("OOO", "O-O-O").replace("OO", "O-O")
99
+
100
+ self.encode = encode
101
+ self.decode = decode
102
+ self.model = model
103
+ self.ctx = ctx
104
+ self.device = device
105
+
106
+ def get_nanogpt_response(self, game_state: str, temperature: float) -> str:
107
+ num_samples = 1 # number of samples to draw
108
+ top_k = 200 # retain only the top_k most likely tokens, clamp others to have 0 probability
109
+ max_new_tokens = 8
110
+
111
+ # Remove ["stockfish elo xxx"]\n["stockfish elo xxx"]\n\n from game_state
112
+ # nanogpt was trained only on pgn transcripts
113
+ game_state = game_state.split("\n\n")[-1].strip()
114
+
115
+ # print("game_state", game_state)
116
+
117
+ #game_state = ";" + game_state
118
+
119
+ start_ids = self.encode(game_state)
120
+
121
+ x = torch.tensor(start_ids, dtype=torch.long, device=self.device)[None, ...]
122
+ with torch.no_grad():
123
+ with self.ctx:
124
+ for k in range(num_samples):
125
+ y = self.model.generate(
126
+ x, max_new_tokens, temperature=temperature, top_k=top_k
127
+ )
128
+
129
+ model_response = self.decode(y[0].tolist())
130
+
131
+ # print("model_response", model_response)
132
+ # model_response includes the input string
133
+ model_response = model_response[len(game_state):].split(";")[0]
134
+ return model_response
135
+
136
+ def get_move_from_response(self, response: str) -> str:
137
+ # Parse the response to get only the first move
138
+ moves = response.split()
139
+ first_move = moves[0]
140
+
141
+ return first_move
142
+
143
+ def get_move(self, board: str, game_state: str, temperature: float) -> str:
144
+ completion = self.get_nanogpt_response(game_state, temperature)
145
+ return self.get_move_from_response(completion)
146
+
147
+ def get_config(self) -> dict:
148
+ return {"model": self.model_name}
chess-gpt-eval/nanogpt/out/meta.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f1121191e401988851de5744fe27fe463c3a086fc8c9a5538ef7fc12162bfb09
3
+ size 373
chess-gpt-eval/nanogpt/out/view_ckpt.ipynb ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import torch\n",
10
+ "\n",
11
+ "def load_checkpoint(filepath: str) -> dict:\n",
12
+ " \"\"\"\n",
13
+ " Load a checkpoint file.\n",
14
+ "\n",
15
+ " Args:\n",
16
+ " filepath (str): Path to the .ckpt file.\n",
17
+ "\n",
18
+ " Returns:\n",
19
+ " dict: Contents of the checkpoint file.\n",
20
+ " \"\"\"\n",
21
+ " checkpoint = torch.load(filepath, map_location=torch.device('cpu'))\n",
22
+ " return checkpoint\n",
23
+ "\n",
24
+ "checkpoint_path = 'ckpt.pt'\n",
25
+ "checkpoint_data = load_checkpoint(checkpoint_path)\n",
26
+ "\n",
27
+ "# Print the keys to understand what's inside\n",
28
+ "print(checkpoint_data.keys())\n",
29
+ "\n",
30
+ "# If you want to view specific information, access it using the keys\n",
31
+ "# For example, to view the model's state_dict\n",
32
+ "model_state = checkpoint_data.get('state_dict', None)\n",
33
+ "if model_state:\n",
34
+ " print(\"Model's state dict:\", model_state)\n",
35
+ "\n",
36
+ "# To view training information like current learning rate, iterations, etc.\n",
37
+ "training_info = checkpoint_data.get('training_info', None)\n",
38
+ "if training_info:\n",
39
+ " print(\"Training Info:\", training_info)\n",
40
+ "\n",
41
+ "# To view config, if it's stored in the checkpoint\n",
42
+ "config = checkpoint_data.get('config', None)\n",
43
+ "if config:\n",
44
+ " print(\"Configurations:\", config)\n"
45
+ ]
46
+ }
47
+ ],
48
+ "metadata": {
49
+ "kernelspec": {
50
+ "display_name": "openai",
51
+ "language": "python",
52
+ "name": "python3"
53
+ },
54
+ "language_info": {
55
+ "name": "python",
56
+ "version": "3.10.13"
57
+ }
58
+ },
59
+ "nbformat": 4,
60
+ "nbformat_minor": 2
61
+ }
chess-gpt-eval/openings.csv ADDED
The diff for this file is too large to render. See raw diff
 
chess-gpt-eval/pscan.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+ """
7
+
8
+ An implementation of the parallel scan operation in PyTorch (Blelloch version).
9
+ Please see docs/pscan.ipynb for a detailed explanation of what happens here.
10
+
11
+ """
12
+
13
+ def npo2(len):
14
+ """
15
+ Returns the next power of 2 above len
16
+ """
17
+
18
+ return 2 ** math.ceil(math.log2(len))
19
+
20
+ def pad_npo2(X):
21
+ """
22
+ Pads input length dim to the next power of 2
23
+
24
+ Args:
25
+ X : (B, L, D, N)
26
+
27
+ Returns:
28
+ Y : (B, npo2(L), D, N)
29
+ """
30
+
31
+ len_npo2 = npo2(X.size(1))
32
+ pad_tuple = (0, 0, 0, 0, 0, len_npo2 - X.size(1))
33
+ return F.pad(X, pad_tuple, "constant", 0)
34
+
35
+ class PScan(torch.autograd.Function):
36
+ @staticmethod
37
+ def pscan(A, X):
38
+ # A : (B, D, L, N)
39
+ # X : (B, D, L, N)
40
+
41
+ # modifies X in place by doing a parallel scan.
42
+ # more formally, X will be populated by these values :
43
+ # H[t] = A[t] * H[t-1] + X[t] with H[0] = 0
44
+ # which are computed in parallel (2*log2(T) sequential steps (ideally), instead of T sequential steps)
45
+
46
+ # only supports L that is a power of two (mainly for a clearer code)
47
+
48
+ B, D, L, _ = A.size()
49
+ num_steps = int(math.log2(L))
50
+
51
+ # up sweep (last 2 steps unfolded)
52
+ Aa = A
53
+ Xa = X
54
+ for _ in range(num_steps-2):
55
+ T = Xa.size(2)
56
+ Aa = Aa.view(B, D, T//2, 2, -1)
57
+ Xa = Xa.view(B, D, T//2, 2, -1)
58
+
59
+ Xa[:, :, :, 1].add_(Aa[:, :, :, 1].mul(Xa[:, :, :, 0]))
60
+ Aa[:, :, :, 1].mul_(Aa[:, :, :, 0])
61
+
62
+ Aa = Aa[:, :, :, 1]
63
+ Xa = Xa[:, :, :, 1]
64
+
65
+ # we have only 4, 2 or 1 nodes left
66
+ if Xa.size(2) == 4:
67
+ Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0]))
68
+ Aa[:, :, 1].mul_(Aa[:, :, 0])
69
+
70
+ Xa[:, :, 3].add_(Aa[:, :, 3].mul(Xa[:, :, 2] + Aa[:, :, 2].mul(Xa[:, :, 1])))
71
+ elif Xa.size(2) == 2:
72
+ Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0]))
73
+ return
74
+ else:
75
+ return
76
+
77
+ # down sweep (first 2 steps unfolded)
78
+ Aa = A[:, :, 2**(num_steps-2)-1:L:2**(num_steps-2)]
79
+ Xa = X[:, :, 2**(num_steps-2)-1:L:2**(num_steps-2)]
80
+ Xa[:, :, 2].add_(Aa[:, :, 2].mul(Xa[:, :, 1]))
81
+ Aa[:, :, 2].mul_(Aa[:, :, 1])
82
+
83
+ for k in range(num_steps-3, -1, -1):
84
+ Aa = A[:, :, 2**k-1:L:2**k]
85
+ Xa = X[:, :, 2**k-1:L:2**k]
86
+
87
+ T = Xa.size(2)
88
+ Aa = Aa.view(B, D, T//2, 2, -1)
89
+ Xa = Xa.view(B, D, T//2, 2, -1)
90
+
91
+ Xa[:, :, 1:, 0].add_(Aa[:, :, 1:, 0].mul(Xa[:, :, :-1, 1]))
92
+ Aa[:, :, 1:, 0].mul_(Aa[:, :, :-1, 1])
93
+
94
+ @staticmethod
95
+ def pscan_rev(A, X):
96
+ # A : (B, D, L, N)
97
+ # X : (B, D, L, N)
98
+
99
+ # the same function as above, but in reverse
100
+ # (if you flip the input, call pscan, then flip the output, you get what this function outputs)
101
+ # it is used in the backward pass
102
+
103
+ # only supports L that is a power of two (mainly for a clearer code)
104
+
105
+ B, D, L, _ = A.size()
106
+ num_steps = int(math.log2(L))
107
+
108
+ # up sweep (last 2 steps unfolded)
109
+ Aa = A
110
+ Xa = X
111
+ for _ in range(num_steps-2):
112
+ T = Xa.size(2)
113
+ Aa = Aa.view(B, D, T//2, 2, -1)
114
+ Xa = Xa.view(B, D, T//2, 2, -1)
115
+
116
+ Xa[:, :, :, 0].add_(Aa[:, :, :, 0].mul(Xa[:, :, :, 1]))
117
+ Aa[:, :, :, 0].mul_(Aa[:, :, :, 1])
118
+
119
+ Aa = Aa[:, :, :, 0]
120
+ Xa = Xa[:, :, :, 0]
121
+
122
+ # we have only 4, 2 or 1 nodes left
123
+ if Xa.size(2) == 4:
124
+ Xa[:, :, 2].add_(Aa[:, :, 2].mul(Xa[:, :, 3]))
125
+ Aa[:, :, 2].mul_(Aa[:, :, 3])
126
+
127
+ Xa[:, :, 0].add_(Aa[:, :, 0].mul(Xa[:, :, 1].add(Aa[:, :, 1].mul(Xa[:, :, 2]))))
128
+ elif Xa.size(2) == 2:
129
+ Xa[:, :, 0].add_(Aa[:, :, 0].mul(Xa[:, :, 1]))
130
+ return
131
+ else:
132
+ return
133
+
134
+ # down sweep (first 2 steps unfolded)
135
+ Aa = A[:, :, 0:L:2**(num_steps-2)]
136
+ Xa = X[:, :, 0:L:2**(num_steps-2)]
137
+ Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 2]))
138
+ Aa[:, :, 1].mul_(Aa[:, :, 2])
139
+
140
+ for k in range(num_steps-3, -1, -1):
141
+ Aa = A[:, :, 0:L:2**k]
142
+ Xa = X[:, :, 0:L:2**k]
143
+
144
+ T = Xa.size(2)
145
+ Aa = Aa.view(B, D, T//2, 2, -1)
146
+ Xa = Xa.view(B, D, T//2, 2, -1)
147
+
148
+ Xa[:, :, :-1, 1].add_(Aa[:, :, :-1, 1].mul(Xa[:, :, 1:, 0]))
149
+ Aa[:, :, :-1, 1].mul_(Aa[:, :, 1:, 0])
150
+
151
+ @staticmethod
152
+ def forward(ctx, A_in, X_in):
153
+ """
154
+ Applies the parallel scan operation, as defined above. Returns a new tensor.
155
+ If you can, privilege sequence lengths that are powers of two.
156
+
157
+ Args:
158
+ A_in : (B, L, D, N)
159
+ X_in : (B, L, D, N)
160
+
161
+ Returns:
162
+ H : (B, L, D, N)
163
+ """
164
+
165
+ L = X_in.size(1)
166
+
167
+ # cloning is requiered because of the in-place ops
168
+ if L == npo2(L):
169
+ A = A_in.clone()
170
+ X = X_in.clone()
171
+ else:
172
+ # pad tensors (and clone btw)
173
+ A = pad_npo2(A_in) # (B, npo2(L), D, N)
174
+ X = pad_npo2(X_in) # (B, npo2(L), D, N)
175
+
176
+ # prepare tensors
177
+ A = A.transpose(2, 1) # (B, D, npo2(L), N)
178
+ X = X.transpose(2, 1) # (B, D, npo2(L), N)
179
+
180
+ # parallel scan (modifies X in-place)
181
+ PScan.pscan(A, X)
182
+
183
+ ctx.save_for_backward(A_in, X)
184
+
185
+ # slice [:, :L] (cut if there was padding)
186
+ return X.transpose(2, 1)[:, :L]
187
+
188
+ @staticmethod
189
+ def backward(ctx, grad_output_in):
190
+ """
191
+ Flows the gradient from the output to the input. Returns two new tensors.
192
+
193
+ Args:
194
+ ctx : A_in : (B, L, D, N), X : (B, D, L, N)
195
+ grad_output_in : (B, L, D, N)
196
+
197
+ Returns:
198
+ gradA : (B, L, D, N), gradX : (B, L, D, N)
199
+ """
200
+
201
+ A_in, X = ctx.saved_tensors
202
+
203
+ L = grad_output_in.size(1)
204
+
205
+ # cloning is requiered because of the in-place ops
206
+ if L == npo2(L):
207
+ grad_output = grad_output_in.clone()
208
+ # the next padding will clone A_in
209
+ else:
210
+ grad_output = pad_npo2(grad_output_in) # (B, npo2(L), D, N)
211
+ A_in = pad_npo2(A_in) # (B, npo2(L), D, N)
212
+
213
+ # prepare tensors
214
+ grad_output = grad_output.transpose(2, 1)
215
+ A_in = A_in.transpose(2, 1) # (B, D, npo2(L), N)
216
+ A = torch.nn.functional.pad(A_in[:, :, 1:], (0, 0, 0, 1)) # (B, D, npo2(L), N) shift 1 to the left (see hand derivation)
217
+
218
+ # reverse parallel scan (modifies grad_output in-place)
219
+ PScan.pscan_rev(A, grad_output)
220
+
221
+ Q = torch.zeros_like(X)
222
+ Q[:, :, 1:].add_(X[:, :, :-1] * grad_output[:, :, 1:])
223
+
224
+ return Q.transpose(2, 1)[:, :L], grad_output.transpose(2, 1)[:, :L]
225
+
226
+ pscan = PScan.apply
chess-gpt-eval/requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ openai==0.28.0
2
+ tiktoken==0.4.0
3
+ tenacity==8.2.3
4
+ python-chess==1.999
5
+ matplotlib==3.8.0
6
+ pandas==2.1.1
chess-gpt-eval/xformer.py ADDED
@@ -0,0 +1,330 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Full definition of a GPT Language Model, all of it in this single file.
3
+ References:
4
+ 1) the official GPT-2 TensorFlow implementation released by OpenAI:
5
+ https://github.com/openai/gpt-2/blob/master/src/model.py
6
+ 2) huggingface/transformers PyTorch implementation:
7
+ https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py
8
+ """
9
+
10
+ import math
11
+ import inspect
12
+ from dataclasses import dataclass
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ from torch.nn import functional as F
17
+
18
+ class LayerNorm(nn.Module):
19
+ """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """
20
+
21
+ def __init__(self, ndim, bias):
22
+ super().__init__()
23
+ self.weight = nn.Parameter(torch.ones(ndim))
24
+ self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
25
+
26
+ def forward(self, input):
27
+ return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
28
+
29
+ class CausalSelfAttention(nn.Module):
30
+
31
+ def __init__(self, config):
32
+ super().__init__()
33
+ assert config.n_embd % config.n_head == 0
34
+ # key, query, value projections for all heads, but in a batch
35
+ self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
36
+ # output projection
37
+ self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
38
+ # regularization
39
+ self.attn_dropout = nn.Dropout(config.dropout)
40
+ self.resid_dropout = nn.Dropout(config.dropout)
41
+ self.n_head = config.n_head
42
+ self.n_embd = config.n_embd
43
+ self.dropout = config.dropout
44
+ # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
45
+ self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
46
+ if not self.flash:
47
+ print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
48
+ # causal mask to ensure that attention is only applied to the left in the input sequence
49
+ self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
50
+ .view(1, 1, config.block_size, config.block_size))
51
+
52
+ def forward(self, x):
53
+ B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
54
+
55
+ # calculate query, key, values for all heads in batch and move head forward to be the batch dim
56
+ q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
57
+ k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
58
+ q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
59
+ v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
60
+
61
+ # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
62
+ if self.flash:
63
+ # efficient attention using Flash Attention CUDA kernels
64
+ y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)
65
+ else:
66
+ # manual implementation of attention
67
+ att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
68
+ att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
69
+ att = F.softmax(att, dim=-1)
70
+ att = self.attn_dropout(att)
71
+ y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
72
+ y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
73
+
74
+ # output projection
75
+ y = self.resid_dropout(self.c_proj(y))
76
+ return y
77
+
78
+ class MLP(nn.Module):
79
+
80
+ def __init__(self, config):
81
+ super().__init__()
82
+ self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
83
+ self.gelu = nn.GELU()
84
+ self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
85
+ self.dropout = nn.Dropout(config.dropout)
86
+
87
+ def forward(self, x):
88
+ x = self.c_fc(x)
89
+ x = self.gelu(x)
90
+ x = self.c_proj(x)
91
+ x = self.dropout(x)
92
+ return x
93
+
94
+ class Block(nn.Module):
95
+
96
+ def __init__(self, config):
97
+ super().__init__()
98
+ self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
99
+ self.attn = CausalSelfAttention(config)
100
+ self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
101
+ self.mlp = MLP(config)
102
+
103
+ def forward(self, x):
104
+ x = x + self.attn(self.ln_1(x))
105
+ x = x + self.mlp(self.ln_2(x))
106
+ return x
107
+
108
+ @dataclass
109
+ class GPTConfig:
110
+ block_size: int = 1024
111
+ vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
112
+ n_layer: int = 12
113
+ n_head: int = 12
114
+ n_embd: int = 768
115
+ dropout: float = 0.0
116
+ bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
117
+
118
+ class GPT(nn.Module):
119
+
120
+ def __init__(self, config):
121
+ super().__init__()
122
+ assert config.vocab_size is not None
123
+ assert config.block_size is not None
124
+ self.config = config
125
+
126
+ self.transformer = nn.ModuleDict(dict(
127
+ wte = nn.Embedding(config.vocab_size, config.n_embd),
128
+ wpe = nn.Embedding(config.block_size, config.n_embd),
129
+ drop = nn.Dropout(config.dropout),
130
+ h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
131
+ ln_f = LayerNorm(config.n_embd, bias=config.bias),
132
+ ))
133
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
134
+ # with weight tying when using torch.compile() some warnings get generated:
135
+ # "UserWarning: functional_call was passed multiple values for tied weights.
136
+ # This behavior is deprecated and will be an error in future versions"
137
+ # not 100% sure what this is, so far seems to be harmless. TODO investigate
138
+ self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying
139
+
140
+ # init all weights
141
+ self.apply(self._init_weights)
142
+ # apply special scaled init to the residual projections, per GPT-2 paper
143
+ for pn, p in self.named_parameters():
144
+ if pn.endswith('c_proj.weight'):
145
+ torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))
146
+
147
+ # report number of parameters
148
+ #print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))
149
+
150
+ def get_num_params(self, non_embedding=True):
151
+ """
152
+ Return the number of parameters in the model.
153
+ For non-embedding count (default), the position embeddings get subtracted.
154
+ The token embeddings would too, except due to the parameter sharing these
155
+ params are actually used as weights in the final layer, so we include them.
156
+ """
157
+ n_params = sum(p.numel() for p in self.parameters())
158
+ if non_embedding:
159
+ n_params -= self.transformer.wpe.weight.numel()
160
+ return n_params
161
+
162
+ def _init_weights(self, module):
163
+ if isinstance(module, nn.Linear):
164
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
165
+ if module.bias is not None:
166
+ torch.nn.init.zeros_(module.bias)
167
+ elif isinstance(module, nn.Embedding):
168
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
169
+
170
+ def forward(self, idx, targets=None):
171
+ device = idx.device
172
+ b, t = idx.size()
173
+ assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
174
+ pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)
175
+
176
+ # forward the GPT model itself
177
+ tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
178
+ pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
179
+ x = self.transformer.drop(tok_emb + pos_emb)
180
+ for block in self.transformer.h:
181
+ x = block(x)
182
+ x = self.transformer.ln_f(x)
183
+
184
+ if targets is not None:
185
+ # if we are given some desired targets also calculate the loss
186
+ logits = self.lm_head(x)
187
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
188
+ else:
189
+ # inference-time mini-optimization: only forward the lm_head on the very last position
190
+ logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
191
+ loss = None
192
+
193
+ return logits, loss
194
+
195
+ def crop_block_size(self, block_size):
196
+ # model surgery to decrease the block size if necessary
197
+ # e.g. we may load the GPT2 pretrained model checkpoint (block size 1024)
198
+ # but want to use a smaller block size for some smaller, simpler model
199
+ assert block_size <= self.config.block_size
200
+ self.config.block_size = block_size
201
+ self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:block_size])
202
+ for block in self.transformer.h:
203
+ if hasattr(block.attn, 'bias'):
204
+ block.attn.bias = block.attn.bias[:,:,:block_size,:block_size]
205
+
206
+ @classmethod
207
+ def from_pretrained(cls, model_type, override_args=None):
208
+ assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
209
+ override_args = override_args or {} # default to empty dict
210
+ # only dropout can be overridden see more notes below
211
+ assert all(k == 'dropout' for k in override_args)
212
+ from transformers import GPT2LMHeadModel
213
+ print("loading weights from pretrained gpt: %s" % model_type)
214
+
215
+ # n_layer, n_head and n_embd are determined from model_type
216
+ config_args = {
217
+ 'gpt2': dict(n_layer=12, n_head=12, n_embd=768), # 124M params
218
+ 'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024), # 350M params
219
+ 'gpt2-large': dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
220
+ 'gpt2-xl': dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
221
+ }[model_type]
222
+ print("forcing vocab_size=50257, block_size=1024, bias=True")
223
+ config_args['vocab_size'] = 50257 # always 50257 for GPT model checkpoints
224
+ config_args['block_size'] = 1024 # always 1024 for GPT model checkpoints
225
+ config_args['bias'] = True # always True for GPT model checkpoints
226
+ # we can override the dropout rate, if desired
227
+ if 'dropout' in override_args:
228
+ print(f"overriding dropout rate to {override_args['dropout']}")
229
+ config_args['dropout'] = override_args['dropout']
230
+ # create a from-scratch initialized minGPT model
231
+ config = GPTConfig(**config_args)
232
+ model = GPT(config)
233
+ sd = model.state_dict()
234
+ sd_keys = sd.keys()
235
+ sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')] # discard this mask / buffer, not a param
236
+
237
+ # init a huggingface/transformers model
238
+ model_hf = GPT2LMHeadModel.from_pretrained(model_type)
239
+ sd_hf = model_hf.state_dict()
240
+
241
+ # copy while ensuring all of the parameters are aligned and match in names and shapes
242
+ sd_keys_hf = sd_hf.keys()
243
+ sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')] # ignore these, just a buffer
244
+ sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')] # same, just the mask (buffer)
245
+ transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
246
+ # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear
247
+ # this means that we have to transpose these weights when we import them
248
+ assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
249
+ for k in sd_keys_hf:
250
+ if any(k.endswith(w) for w in transposed):
251
+ # special treatment for the Conv1D weights we need to transpose
252
+ assert sd_hf[k].shape[::-1] == sd[k].shape
253
+ with torch.no_grad():
254
+ sd[k].copy_(sd_hf[k].t())
255
+ else:
256
+ # vanilla copy over the other parameters
257
+ assert sd_hf[k].shape == sd[k].shape
258
+ with torch.no_grad():
259
+ sd[k].copy_(sd_hf[k])
260
+
261
+ return model
262
+
263
+ def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
264
+ # start with all of the candidate parameters
265
+ param_dict = {pn: p for pn, p in self.named_parameters()}
266
+ # filter out those that do not require grad
267
+ param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
268
+ # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.
269
+ # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
270
+ decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
271
+ nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
272
+ optim_groups = [
273
+ {'params': decay_params, 'weight_decay': weight_decay},
274
+ {'params': nodecay_params, 'weight_decay': 0.0}
275
+ ]
276
+ num_decay_params = sum(p.numel() for p in decay_params)
277
+ num_nodecay_params = sum(p.numel() for p in nodecay_params)
278
+ print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
279
+ print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
280
+ # Create AdamW optimizer and use the fused version if it is available
281
+ fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
282
+ use_fused = fused_available and device_type == 'cuda'
283
+ extra_args = dict(fused=True) if use_fused else dict()
284
+ optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
285
+ print(f"using fused AdamW: {use_fused}")
286
+
287
+ return optimizer
288
+
289
+ def estimate_mfu(self, fwdbwd_per_iter, dt):
290
+ """ estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS """
291
+ # first estimate the number of flops we do per iteration.
292
+ # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311
293
+ N = self.get_num_params()
294
+ cfg = self.config
295
+ L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head, cfg.block_size
296
+ flops_per_token = 6*N + 12*L*H*Q*T
297
+ flops_per_fwdbwd = flops_per_token * T
298
+ flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
299
+ # express our flops throughput as ratio of A100 bfloat16 peak flops
300
+ flops_achieved = flops_per_iter * (1.0/dt) # per second
301
+ flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS
302
+ mfu = flops_achieved / flops_promised
303
+ return mfu
304
+
305
+ @torch.no_grad()
306
+ def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
307
+ """
308
+ Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
309
+ the sequence max_new_tokens times, feeding the predictions back into the model each time.
310
+ Most likely you'll want to make sure to be in model.eval() mode of operation for this.
311
+ """
312
+ for _ in range(max_new_tokens):
313
+ # if the sequence context is growing too long we must crop it at block_size
314
+ idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
315
+ # forward the model to get the logits for the index in the sequence
316
+ logits, _ = self(idx_cond)
317
+ # pluck the logits at the final step and scale by desired temperature
318
+ logits = logits[:, -1, :] / temperature
319
+ # optionally crop the logits to only the top k options
320
+ if top_k is not None:
321
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
322
+ logits[logits < v[:, [-1]]] = -float('Inf')
323
+ # apply softmax to convert logits to (normalized) probabilities
324
+ probs = F.softmax(logits, dim=-1)
325
+ # sample from the distribution
326
+ idx_next = torch.multinomial(probs, num_samples=1)
327
+ # append sampled index to the running sequence and continue
328
+ idx = torch.cat((idx, idx_next), dim=1)
329
+
330
+ return idx
chess-mamba-vs-xformer/config/Mamba/11M.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import math
3
+
4
+ beta1 = 0.9
5
+ beta2 = 0.95
6
+ weight_decay = 4.5e-3
7
+ grad_clip = 0.5
8
+ auto_clip = True
9
+ auto_clip_max = 0.5
10
+ auto_clip_min = 3.333e-3
11
+ grad_clip_start_size = 100
12
+ grad_clip_max_size = 400
13
+ grad_clip_percentile = 10 #7.5 (try it at 10, tested @7.75)
14
+ max_seq_len = 1536
15
+
16
+ # batch size below values are based on this. When actual batch size adjusted, the below are adjusted automatically
17
+ base_batch_size = 256
18
+
19
+ batch_size = 100
20
+ gradient_accumulation_steps = 1
21
+ effective_batch_size = batch_size * gradient_accumulation_steps
22
+
23
+ always_save_checkpoint = True
24
+ eval_interval = 200
25
+ eval_iters = 33
26
+ log_interval = 33
27
+ train_file_update_interval = 10 # 23 was original ... 7 definitely crashes (maybe try 10 on Lambda)
28
+
29
+ warmup_iters = 500 # not super necessary potentially
30
+ learning_rate = 1e-3
31
+ min_lr = 6.6667e-5
32
+ # max_iters is for auto-stopping end of stable phase. Reported %complete progress is wrt this (that is, % complete doesn't include anneal).
33
+ max_iters = 400000 #~=102M games
34
+
35
+ # # # # #
36
+
37
+ warmup_iters = int(warmup_iters * (base_batch_size / effective_batch_size))
38
+ learning_rate = learning_rate * np.sqrt(effective_batch_size / base_batch_size) # with baby networks can afford to go a bit higher
39
+ max_iters = int(max_iters * (base_batch_size / effective_batch_size))
40
+ min_lr = min_lr * np.sqrt(effective_batch_size / base_batch_size) # learning_rate / 10 usually
41
+
42
+ out_dir = 'out/Mamba/11M'
43
+ eval_interval = int(eval_interval * (base_batch_size / effective_batch_size)) # keep frequent because we'll overfit
44
+ eval_iters = int(eval_iters * (base_batch_size / batch_size)) # intentionally scaled by batch_size instead of effective_batch_size
45
+ log_interval = int(math.ceil(log_interval * (base_batch_size / effective_batch_size))) # don't print too too often
46
+
47
+ print(f'warmup iters: {warmup_iters}')
48
+ print(f'Max iters: {max_iters} ({max_iters * effective_batch_size} games)')
49
+ print(f'Eval iters: {eval_iters}')
50
+ print(f'Eval interval: {eval_interval}')
51
+ print(f'Log interval: {log_interval}')
52
+
53
+ wandb_log = True # override via command line if you like
54
+ wandb_project = 'chess-mamba-v2'
55
+ wandb_run_name = 'Mamba-11M'
56
+
57
+ dataset = 'stable'
58
+
59
+ # 11M param
60
+ model_type = 'mamba'
61
+ n_layer = 20
62
+ d_model = 288
63
+ d_state = 16
64
+ dt_rank = 'auto' #ceil(d_model/16) ... 18 here
65
+ move_num_in_gamestate = False
66
+
67
+ init_from = 'scratch'
68
+
69
+ device = 'cuda' # run on cpu only
70
+ compile = False # do not torch compile the model
chess-mamba-vs-xformer/config/Mamba/250M.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import math
3
+
4
+ beta1 = 0.9
5
+ beta2 = 0.95
6
+ weight_decay = 4.5e-3
7
+ grad_clip = 0.5
8
+ auto_clip = True
9
+ auto_clip_max = 0.5
10
+ auto_clip_min = 3.333e-3
11
+ grad_clip_start_size = 100
12
+ grad_clip_max_size = 400
13
+ grad_clip_percentile = 10 #7.5 (try it at 10, tested @7.75)
14
+ max_seq_len = 1536
15
+
16
+ # batch size below values are based on this. When actual batch size adjusted, the below are adjusted automatically
17
+ base_batch_size = 256
18
+
19
+ batch_size = 10
20
+ gradient_accumulation_steps = 10
21
+ effective_batch_size = batch_size * gradient_accumulation_steps
22
+
23
+ always_save_checkpoint = True
24
+ eval_interval = 300
25
+ eval_iters = 33
26
+ log_interval = 75
27
+ train_file_update_interval = 10 # 23 was original ... 7 definitely crashes (maybe try 10 on Lambda)
28
+
29
+ warmup_iters = 500 # not super necessary potentially
30
+ learning_rate = 2.0e-3 # tested 1.5e-3 from 112k-156k, before that 3.5e-3 #8e-3
31
+ min_lr = 1.3333e-4
32
+ # max_iters is for auto-stopping end of stable phase. Reported %complete progress is wrt this (that is, % complete doesn't include anneal).
33
+ max_iters = 400000 #~=102M games
34
+
35
+ # # # # #
36
+
37
+ warmup_iters = int(warmup_iters * (base_batch_size / effective_batch_size))
38
+ learning_rate = learning_rate * np.sqrt(effective_batch_size / base_batch_size) # with baby networks can afford to go a bit higher
39
+ max_iters = int(max_iters * (base_batch_size / effective_batch_size))
40
+ min_lr = min_lr * np.sqrt(effective_batch_size / base_batch_size) # learning_rate / 10 usually
41
+
42
+ out_dir = 'out/Mamba/250M'
43
+ eval_interval = int(eval_interval * (base_batch_size / effective_batch_size)) # keep frequent because we'll overfit
44
+ eval_iters = int(eval_iters * (base_batch_size / batch_size)) # intentionally scaled by batch_size instead of effective_batch_size
45
+ log_interval = int(math.ceil(log_interval * (base_batch_size / effective_batch_size))) # don't print too too often
46
+
47
+ print(f'warmup iters: {warmup_iters}')
48
+ print(f'Max iters: {max_iters} ({max_iters * effective_batch_size} games)')
49
+ print(f'Eval iters: {eval_iters}')
50
+ print(f'Eval interval: {eval_interval}')
51
+ print(f'Log interval: {log_interval}')
52
+
53
+ wandb_log = True
54
+ wandb_project = 'chess-mamba-v2'
55
+ wandb_run_name = 'Mamba-250M'
56
+
57
+ dataset = 'stable'
58
+
59
+ # 251M param
60
+ model_type = 'mamba'
61
+ n_layer = 96
62
+ d_model = 578
63
+ d_state = 56
64
+ dt_rank = 176
65
+ move_num_in_gamestate = False
66
+
67
+ init_from = 'scratch'
68
+
69
+ device = 'cuda' # run on cpu only
70
+ compile = False # do not torch compile the model
chess-mamba-vs-xformer/config/Mamba/29M.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import math
3
+
4
+ beta1 = 0.9
5
+ beta2 = 0.95
6
+ weight_decay = 4.5e-3
7
+ grad_clip = 0.5
8
+ auto_clip = True
9
+ auto_clip_max = 0.5
10
+ auto_clip_min = 3.333e-3
11
+ grad_clip_start_size = 100
12
+ grad_clip_max_size = 400
13
+ grad_clip_percentile = 10 #7.5 (try it at 10, tested @7.75)
14
+ max_seq_len = 1536
15
+
16
+ # batch size below values are based on this. When actual batch size adjusted, the below are adjusted automatically
17
+ base_batch_size = 256
18
+
19
+ batch_size = 100
20
+ gradient_accumulation_steps = 1
21
+ effective_batch_size = batch_size * gradient_accumulation_steps
22
+
23
+ always_save_checkpoint = True
24
+ eval_interval = 250
25
+ eval_iters = 33
26
+ log_interval = 50
27
+ train_file_update_interval = 10 # 23 was original ... 7 definitely crashes (maybe try 10 on Lambda)
28
+
29
+ warmup_iters = 500 # not super necessary potentially
30
+ learning_rate = 1.25e-3
31
+ min_lr = 8.3333e-5
32
+ # max_iters is for auto-stopping end of stable phase. Reported %complete progress is wrt this (that is, % complete doesn't include anneal).
33
+ max_iters = 400000 #~=102M games
34
+
35
+ # # # # #
36
+
37
+ warmup_iters = int(warmup_iters * (base_batch_size / effective_batch_size))
38
+ learning_rate = learning_rate * np.sqrt(effective_batch_size / base_batch_size) # with baby networks can afford to go a bit higher
39
+ max_iters = int(max_iters * (base_batch_size / effective_batch_size))
40
+ min_lr = min_lr * np.sqrt(effective_batch_size / base_batch_size) # learning_rate / 10 usually
41
+
42
+ out_dir = 'out/Mamba/29M'
43
+ eval_interval = int(eval_interval * (base_batch_size / effective_batch_size)) # keep frequent because we'll overfit
44
+ eval_iters = int(eval_iters * (base_batch_size / batch_size)) # intentionally scaled by batch_size instead of effective_batch_size
45
+ log_interval = int(math.ceil(log_interval * (base_batch_size / effective_batch_size))) # don't print too too often
46
+
47
+ print(f'warmup iters: {warmup_iters}')
48
+ print(f'Max iters: {max_iters} ({max_iters * effective_batch_size} games)')
49
+ print(f'Eval iters: {eval_iters}')
50
+ print(f'Eval interval: {eval_interval}')
51
+ print(f'Log interval: {log_interval}')
52
+
53
+ wandb_log = True # override via command line if you like
54
+ wandb_project = 'chess-mamba-v2'
55
+ wandb_run_name = 'Mamba-29M'
56
+
57
+ dataset = 'stable'
58
+
59
+ # 29.3M param
60
+ model_type = 'mamba'
61
+ n_layer = 33
62
+ d_model = 360
63
+ d_state = 24
64
+ dt_rank = 36
65
+ move_num_in_gamestate = False
66
+
67
+ init_from = 'scratch'
68
+
69
+ device = 'cuda' # run on cpu only
70
+ compile = False # do not torch compile the model
chess-mamba-vs-xformer/config/Mamba/50M.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import math
3
+
4
+ beta1 = 0.9
5
+ beta2 = 0.95
6
+ weight_decay = 4.5e-3
7
+ grad_clip = 0.5
8
+ auto_clip = True
9
+ auto_clip_max = 0.5
10
+ auto_clip_min = 3.333e-3
11
+ grad_clip_start_size = 100
12
+ grad_clip_max_size = 400
13
+ grad_clip_percentile = 10 #7.5 (try it at 10, tested @7.75)
14
+ max_seq_len = 1536
15
+
16
+ # batch size below values are based on this. When actual batch size adjusted, the below are adjusted automatically
17
+ base_batch_size = 256
18
+
19
+ batch_size = 50
20
+ gradient_accumulation_steps = 2
21
+ effective_batch_size = batch_size * gradient_accumulation_steps
22
+
23
+ always_save_checkpoint = True
24
+ eval_interval = 250
25
+ eval_iters = 33
26
+ log_interval = 50
27
+ train_file_update_interval = 10 # 23 was original ... 7 definitely crashes (maybe try 10 on Lambda)
28
+
29
+ warmup_iters = 500 # not super necessary potentially
30
+ learning_rate = 1.5e-3 # tested 1.5e-3 from 112k-156k, before that 3.5e-3 #8e-3
31
+ min_lr = 1.0e-4 # was planning 8.5e-5 w/ /6.75 anneal #... before 2e-4 # 4.75e-4
32
+ # max_iters is for auto-stopping end of stable phase. Reported %complete progress is wrt this (that is, % complete doesn't include anneal).
33
+ max_iters = 400000 #~=102M games
34
+
35
+ # # # # #
36
+
37
+ warmup_iters = int(warmup_iters * (base_batch_size / effective_batch_size))
38
+ learning_rate = learning_rate * np.sqrt(effective_batch_size / base_batch_size) # with baby networks can afford to go a bit higher
39
+ max_iters = int(max_iters * (base_batch_size / effective_batch_size))
40
+ min_lr = min_lr * np.sqrt(effective_batch_size / base_batch_size) # learning_rate / 10 usually
41
+
42
+ out_dir = 'out/Mamba/50M'
43
+ eval_interval = int(eval_interval * (base_batch_size / effective_batch_size)) # keep frequent because we'll overfit
44
+ eval_iters = int(eval_iters * (base_batch_size / batch_size)) # intentionally scaled by batch_size instead of effective_batch_size
45
+ log_interval = int(math.ceil(log_interval * (base_batch_size / effective_batch_size))) # don't print too too often
46
+
47
+ print(f'warmup iters: {warmup_iters}')
48
+ print(f'Max iters: {max_iters} ({max_iters * effective_batch_size} games)')
49
+ print(f'Eval iters: {eval_iters}')
50
+ print(f'Eval interval: {eval_interval}')
51
+ print(f'Log interval: {log_interval}')
52
+
53
+ wandb_log = True
54
+ wandb_project = 'chess-mamba-v2'
55
+ wandb_run_name = 'Mamba-50M'
56
+
57
+ dataset = 'stable'
58
+
59
+ # 50.4M param
60
+ model_type = 'mamba'
61
+ n_layer = 48
62
+ d_model = 384
63
+ d_state = 32
64
+ dt_rank = 56
65
+ move_num_in_gamestate = False
66
+
67
+ init_from = 'scratch'
68
+
69
+ device = 'cuda' # run on cpu only
70
+ compile = False # do not torch compile the model
chess-mamba-vs-xformer/config/Mamba/6.6M.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import math
3
+
4
+ beta1 = 0.9
5
+ beta2 = 0.95
6
+ weight_decay = 4.5e-3
7
+ grad_clip = 0.5
8
+ auto_clip = True
9
+ auto_clip_max = 0.5
10
+ auto_clip_min = 3.333e-3
11
+ grad_clip_start_size = 100
12
+ grad_clip_max_size = 400
13
+ grad_clip_percentile = 10 #7.5 (try it at 10, tested @7.75)
14
+ max_seq_len = 1536
15
+
16
+ # batch size below values are based on this. When actual batch size adjusted, the below are adjusted automatically
17
+ base_batch_size = 256
18
+
19
+ batch_size = 100
20
+ gradient_accumulation_steps = 1
21
+ effective_batch_size = batch_size * gradient_accumulation_steps
22
+
23
+ always_save_checkpoint = True
24
+ eval_interval = 200
25
+ eval_iters = 33
26
+ log_interval = 33
27
+ train_file_update_interval = 10 # 23 was original ... 7 definitely crashes (maybe try 10 on Lambda)
28
+
29
+ warmup_iters = 500 # not super necessary potentially
30
+ learning_rate = 8.16667e-4
31
+ min_lr = 5.4444e-5
32
+ # max_iters is for auto-stopping end of stable phase. Reported %complete progress is wrt this (that is, % complete doesn't include anneal).
33
+ max_iters = 400000 #~=102M games
34
+
35
+ # # # # #
36
+
37
+ warmup_iters = int(warmup_iters * (base_batch_size / effective_batch_size))
38
+ learning_rate = learning_rate * np.sqrt(effective_batch_size / base_batch_size) # with baby networks can afford to go a bit higher
39
+ max_iters = int(max_iters * (base_batch_size / effective_batch_size))
40
+ min_lr = min_lr * np.sqrt(effective_batch_size / base_batch_size) # learning_rate / 10 usually
41
+
42
+ out_dir = 'out/Mamba/6.6M'
43
+ eval_interval = int(eval_interval * (base_batch_size / effective_batch_size)) # keep frequent because we'll overfit
44
+ eval_iters = int(eval_iters * (base_batch_size / batch_size)) # intentionally scaled by batch_size instead of effective_batch_size
45
+ log_interval = int(math.ceil(log_interval * (base_batch_size / effective_batch_size))) # don't print too too often
46
+
47
+ print(f'warmup iters: {warmup_iters}')
48
+ print(f'Max iters: {max_iters} ({max_iters * effective_batch_size} games)')
49
+ print(f'Eval iters: {eval_iters}')
50
+ print(f'Eval interval: {eval_interval}')
51
+ print(f'Log interval: {log_interval}')
52
+
53
+ wandb_log = True # override via command line if you like
54
+ wandb_project = 'chess-mamba-v2'
55
+ wandb_run_name = 'Mamba-6.6M'
56
+
57
+ dataset = 'stable'
58
+
59
+ # 6.6M param
60
+ model_type = 'mamba'
61
+ n_layer = 15
62
+ d_model = 256
63
+ d_state = 16
64
+ dt_rank = 'auto' #ceil(d_model/16) ... 16 here
65
+ move_num_in_gamestate = False
66
+
67
+ init_from = 'scratch'
68
+
69
+ device = 'cuda' # run on cpu only
70
+ compile = False # do not torch compile the model
chess-mamba-vs-xformer/config/Xformer/11M.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import math
3
+
4
+ beta1 = 0.9
5
+ beta2 = 0.95
6
+ weight_decay = 4.5e-3
7
+ grad_clip = 0.5
8
+ auto_clip = True
9
+ auto_clip_max = 0.5
10
+ auto_clip_min = 3.333e-3
11
+ grad_clip_start_size = 100
12
+ grad_clip_max_size = 400
13
+ grad_clip_percentile = 10 #7.5 (try it at 10, tested @7.75)
14
+ max_seq_len = 1536
15
+
16
+ # batch size below values are based on this. When actual batch size adjusted, the below are adjusted automatically
17
+ base_batch_size = 100
18
+
19
+ batch_size = 100
20
+ gradient_accumulation_steps = 1
21
+ effective_batch_size = batch_size * gradient_accumulation_steps
22
+
23
+ always_save_checkpoint = True
24
+ eval_interval = 600
25
+ eval_iters = 100
26
+ log_interval = 100
27
+ train_file_update_interval = 10 # 23 was original ... 7 definitely crashes (maybe try 10 on Lambda)
28
+
29
+ warmup_iters = 1280 # not super necessary potentially
30
+ learning_rate = 2e-4
31
+ min_lr = 1.33333e-5
32
+ # max_iters is for auto-stopping end of stable phase. Reported %complete progress is wrt this (that is, % complete doesn't include anneal).
33
+ max_iters = 1024000 #~=102M games
34
+
35
+ # # # # #
36
+
37
+ warmup_iters = int(warmup_iters * (base_batch_size / effective_batch_size))
38
+ learning_rate = learning_rate * np.sqrt(effective_batch_size / base_batch_size) # with baby networks can afford to go a bit higher
39
+ max_iters = int(max_iters * (base_batch_size / effective_batch_size))
40
+ min_lr = min_lr * np.sqrt(effective_batch_size / base_batch_size) # learning_rate / 10 usually
41
+
42
+ out_dir = 'out/Xformer/11M'
43
+ eval_interval = int(eval_interval * (base_batch_size / effective_batch_size)) # keep frequent because we'll overfit
44
+ eval_iters = int(eval_iters * (base_batch_size / batch_size)) # intentionally scaled by batch_size instead of effective_batch_size
45
+ log_interval = int(math.ceil(log_interval * (base_batch_size / effective_batch_size))) # don't print too too often
46
+
47
+ print(f'warmup iters: {warmup_iters}')
48
+ print(f'Max iters: {max_iters} ({max_iters * effective_batch_size} games)')
49
+ print(f'Eval iters: {eval_iters}')
50
+ print(f'Eval interval: {eval_interval}')
51
+ print(f'Log interval: {log_interval}')
52
+
53
+ wandb_log = True # override via command line if you like
54
+ wandb_project = 'chess-xformer'
55
+ wandb_run_name = 'Xformer-11M'
56
+
57
+ dataset = 'stable'
58
+
59
+ # 11.2M param
60
+ model_type = 'xformer'
61
+ n_layer = 6
62
+ n_head = 6
63
+ n_embd = 384
64
+ dropout = 0.0
65
+ move_num_in_gamestate = False
66
+
67
+ init_from = 'scratch'
68
+
69
+ device = 'cuda' # run on cpu only
70
+ compile = False # do not torch compile the model
chess-mamba-vs-xformer/config/Xformer/250M.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import math
3
+
4
+ beta1 = 0.9
5
+ beta2 = 0.95
6
+ weight_decay = 4.5e-3
7
+ grad_clip = 0.5
8
+ auto_clip = True
9
+ auto_clip_max = 0.5
10
+ auto_clip_min = 3.333e-3
11
+ grad_clip_start_size = 100
12
+ grad_clip_max_size = 400
13
+ grad_clip_percentile = 10 #7.5 (try it at 10, tested @7.75)
14
+ max_seq_len = 1536
15
+
16
+ # batch size below values are based on this. When actual batch size adjusted, the below are adjusted automatically
17
+ base_batch_size = 100
18
+
19
+ batch_size = 10
20
+ gradient_accumulation_steps = 10
21
+ effective_batch_size = batch_size * gradient_accumulation_steps
22
+
23
+ always_save_checkpoint = True
24
+ eval_interval = 900
25
+ eval_iters = 100
26
+ log_interval = 225
27
+ train_file_update_interval = 10 # 23 was original ... 7 definitely crashes (maybe try 10 on Lambda)
28
+
29
+ warmup_iters = 1280 # not super necessary potentially
30
+ learning_rate = 4e-4
31
+ min_lr = 2.6667e-5
32
+ # max_iters is for auto-stopping end of stable phase. Reported %complete progress is wrt this (that is, % complete doesn't include anneal).
33
+ max_iters = 1024000 #~=102M games
34
+
35
+ # # # # #
36
+
37
+ warmup_iters = int(warmup_iters * (base_batch_size / effective_batch_size))
38
+ learning_rate = learning_rate * np.sqrt(effective_batch_size / base_batch_size) # with baby networks can afford to go a bit higher
39
+ max_iters = int(max_iters * (base_batch_size / effective_batch_size))
40
+ min_lr = min_lr * np.sqrt(effective_batch_size / base_batch_size) # learning_rate / 10 usually
41
+
42
+ out_dir = 'out/Xformer/250M'
43
+ eval_interval = int(eval_interval * (base_batch_size / effective_batch_size)) # keep frequent because we'll overfit
44
+ eval_iters = int(eval_iters * (base_batch_size / batch_size)) # intentionally scaled by batch_size instead of effective_batch_size
45
+ log_interval = int(math.ceil(log_interval * (base_batch_size / effective_batch_size))) # don't print too too often
46
+
47
+ print(f'warmup iters: {warmup_iters}')
48
+ print(f'Max iters: {max_iters} ({max_iters * effective_batch_size} games)')
49
+ print(f'Eval iters: {eval_iters}')
50
+ print(f'Eval interval: {eval_interval}')
51
+ print(f'Log interval: {log_interval}')
52
+
53
+ wandb_log = True
54
+ wandb_project = 'chess-xformer'
55
+ wandb_run_name = 'Xformer-250M'
56
+
57
+ dataset = 'stable'
58
+
59
+ # 251.2M param
60
+ model_type = 'xformer'
61
+ n_layer = 51
62
+ n_head = 16
63
+ n_embd = 640
64
+ dropout = 0.0
65
+ move_num_in_gamestate = False
66
+
67
+ init_from = 'scratch'
68
+
69
+ device = 'cuda' # run on cpu only
70
+ compile = False # do not torch compile the model
chess-mamba-vs-xformer/config/Xformer/29M.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import math
3
+
4
+ beta1 = 0.9
5
+ beta2 = 0.95
6
+ weight_decay = 4.5e-3
7
+ grad_clip = 0.5
8
+ auto_clip = True
9
+ auto_clip_max = 0.5
10
+ auto_clip_min = 3.333e-3
11
+ grad_clip_start_size = 100
12
+ grad_clip_max_size = 400
13
+ grad_clip_percentile = 10 #7.5 (try it at 10, tested @7.75)
14
+ max_seq_len = 1536
15
+
16
+ # batch size below values are based on this. When actual batch size adjusted, the below are adjusted automatically
17
+ base_batch_size = 100
18
+
19
+ batch_size = 100
20
+ gradient_accumulation_steps = 1
21
+ effective_batch_size = batch_size * gradient_accumulation_steps
22
+
23
+ always_save_checkpoint = True
24
+ eval_interval = 750
25
+ eval_iters = 100
26
+ log_interval = 150
27
+ train_file_update_interval = 10 # 23 was original ... 7 definitely crashes (maybe try 10 on Lambda)
28
+
29
+ warmup_iters = 1280 # not super necessary potentially
30
+ learning_rate = 2.5e-4
31
+ min_lr = 1.6667e-5
32
+ # max_iters is for auto-stopping end of stable phase. Reported %complete progress is wrt this (that is, % complete doesn't include anneal).
33
+ max_iters = 1024000 #~=102M games
34
+
35
+ # # # # #
36
+
37
+ warmup_iters = int(warmup_iters * (base_batch_size / effective_batch_size))
38
+ learning_rate = learning_rate * np.sqrt(effective_batch_size / base_batch_size) # with baby networks can afford to go a bit higher
39
+ max_iters = int(max_iters * (base_batch_size / effective_batch_size))
40
+ min_lr = min_lr * np.sqrt(effective_batch_size / base_batch_size) # learning_rate / 10 usually
41
+
42
+ out_dir = 'out/Xformer/29M'
43
+ eval_interval = int(eval_interval * (base_batch_size / effective_batch_size)) # keep frequent because we'll overfit
44
+ eval_iters = int(eval_iters * (base_batch_size / batch_size)) # intentionally scaled by batch_size instead of effective_batch_size
45
+ log_interval = int(math.ceil(log_interval * (base_batch_size / effective_batch_size))) # don't print too too often
46
+
47
+ print(f'warmup iters: {warmup_iters}')
48
+ print(f'Max iters: {max_iters} ({max_iters * effective_batch_size} games)')
49
+ print(f'Eval iters: {eval_iters}')
50
+ print(f'Eval interval: {eval_interval}')
51
+ print(f'Log interval: {log_interval}')
52
+
53
+ wandb_log = True # override via command line if you like
54
+ wandb_project = 'chess-xformer'
55
+ wandb_run_name = 'Xformer-29M'
56
+
57
+ dataset = 'stable'
58
+
59
+ # 29.1M param
60
+ model_type = 'xformer'
61
+ n_layer = 9
62
+ n_head = 8
63
+ n_embd = 512
64
+ dropout = 0.0
65
+ move_num_in_gamestate = False
66
+
67
+ init_from = 'scratch'
68
+
69
+ device = 'cuda' # run on cpu only
70
+ compile = False # do not torch compile the model
chess-mamba-vs-xformer/config/Xformer/50M.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import math
3
+
4
+ beta1 = 0.9
5
+ beta2 = 0.95
6
+ weight_decay = 4.5e-3
7
+ grad_clip = 0.5
8
+ auto_clip = True
9
+ auto_clip_max = 0.5
10
+ auto_clip_min = 3.333e-3
11
+ grad_clip_start_size = 100
12
+ grad_clip_max_size = 400
13
+ grad_clip_percentile = 10 #7.5 (try it at 10, tested @7.75)
14
+ max_seq_len = 1536
15
+
16
+ # batch size below values are based on this. When actual batch size adjusted, the below are adjusted automatically
17
+ base_batch_size = 100
18
+
19
+ batch_size = 50
20
+ gradient_accumulation_steps = 2
21
+ effective_batch_size = batch_size * gradient_accumulation_steps
22
+
23
+ always_save_checkpoint = True
24
+ eval_interval = 750
25
+ eval_iters = 100
26
+ log_interval = 150
27
+ train_file_update_interval = 10 # 23 was original ... 7 definitely crashes (maybe try 10 on Lambda)
28
+
29
+ warmup_iters = 1280 # not super necessary potentially
30
+ learning_rate = 3e-4 # Mamba is 9.375e-4 (adjusting for different base_batch_size)
31
+ min_lr = 2e-5 # Same ratio min/max as w/ Mamba. It's lower than 1/10 because doing slightly long anneal.
32
+ # max_iters is for auto-stopping end of stable phase. Reported %complete progress is wrt this (that is, % complete doesn't include anneal).
33
+ max_iters = 1024000 #~=102M games
34
+
35
+ # # # # #
36
+
37
+ warmup_iters = int(warmup_iters * (base_batch_size / effective_batch_size))
38
+ learning_rate = learning_rate * np.sqrt(effective_batch_size / base_batch_size) # with baby networks can afford to go a bit higher
39
+ max_iters = int(max_iters * (base_batch_size / effective_batch_size))
40
+ min_lr = min_lr * np.sqrt(effective_batch_size / base_batch_size) # learning_rate / 10 usually
41
+
42
+ out_dir = 'out/Xformer/50M'
43
+ eval_interval = int(eval_interval * (base_batch_size / effective_batch_size)) # keep frequent because we'll overfit
44
+ eval_iters = int(eval_iters * (base_batch_size / batch_size)) # intentionally scaled by batch_size instead of effective_batch_size
45
+ log_interval = int(math.ceil(log_interval * (base_batch_size / effective_batch_size))) # don't print too too often
46
+
47
+ print(f'warmup iters: {warmup_iters}')
48
+ print(f'Max iters: {max_iters} ({max_iters * effective_batch_size} games)')
49
+ print(f'Eval iters: {eval_iters}')
50
+ print(f'Eval interval: {eval_interval}')
51
+ print(f'Log interval: {log_interval}')
52
+
53
+ wandb_log = True
54
+ wandb_project = 'chess-xformer'
55
+ wandb_run_name = 'Xformer-50M'
56
+
57
+ dataset = 'stable'
58
+
59
+ # 50.8M param
60
+ model_type = 'xformer'
61
+ n_layer = 16
62
+ n_head = 8
63
+ n_embd = 512
64
+ dropout = 0.0
65
+ move_num_in_gamestate = False
66
+
67
+ init_from = 'scratch'
68
+
69
+ device = 'cuda' # run on cpu only
70
+ compile = False # do not torch compile the model
chess-mamba-vs-xformer/config/Xformer/6.6M.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import math
3
+
4
+ beta1 = 0.9
5
+ beta2 = 0.95
6
+ weight_decay = 4.5e-3
7
+ grad_clip = 0.5
8
+ auto_clip = True
9
+ auto_clip_max = 0.5
10
+ auto_clip_min = 3.333e-3
11
+ grad_clip_start_size = 100
12
+ grad_clip_max_size = 400
13
+ grad_clip_percentile = 10 #7.5 (try it at 10, tested @7.75)
14
+ max_seq_len = 1536
15
+
16
+ # batch size below values are based on this. When actual batch size adjusted, the below are adjusted automatically
17
+ base_batch_size = 100
18
+
19
+ batch_size = 100
20
+ gradient_accumulation_steps = 1
21
+ effective_batch_size = batch_size * gradient_accumulation_steps
22
+
23
+ always_save_checkpoint = True
24
+ eval_interval = 600
25
+ eval_iters = 100
26
+ log_interval = 100
27
+ train_file_update_interval = 10 # 23 was original ... 7 definitely crashes (maybe try 10 on Lambda)
28
+
29
+ warmup_iters = 1280 # not super necessary potentially
30
+ learning_rate = 1.633333e-4
31
+ min_lr = 1.08889e-5
32
+ # max_iters is for auto-stopping end of stable phase. Reported %complete progress is wrt this (that is, % complete doesn't include anneal).
33
+ max_iters = 1024000 #~=102M games
34
+
35
+ # # # # #
36
+
37
+ warmup_iters = int(warmup_iters * (base_batch_size / effective_batch_size))
38
+ learning_rate = learning_rate * np.sqrt(effective_batch_size / base_batch_size) # with baby networks can afford to go a bit higher
39
+ max_iters = int(max_iters * (base_batch_size / effective_batch_size))
40
+ min_lr = min_lr * np.sqrt(effective_batch_size / base_batch_size) # learning_rate / 10 usually
41
+
42
+ out_dir = 'out/Xformer/6.6M'
43
+ eval_interval = int(eval_interval * (base_batch_size / effective_batch_size)) # keep frequent because we'll overfit
44
+ eval_iters = int(eval_iters * (base_batch_size / batch_size)) # intentionally scaled by batch_size instead of effective_batch_size
45
+ log_interval = int(math.ceil(log_interval * (base_batch_size / effective_batch_size))) # don't print too too often
46
+
47
+ print(f'warmup iters: {warmup_iters}')
48
+ print(f'Max iters: {max_iters} ({max_iters * effective_batch_size} games)')
49
+ print(f'Eval iters: {eval_iters}')
50
+ print(f'Eval interval: {eval_interval}')
51
+ print(f'Log interval: {log_interval}')
52
+
53
+ wandb_log = True # override via command line if you like
54
+ wandb_project = 'chess-xformer'
55
+ wandb_run_name = 'Xformer-6.6M'
56
+
57
+ dataset = 'stable'
58
+
59
+ # 6.6M param
60
+ model_type = 'xformer'
61
+ n_layer = 5
62
+ n_head = 5
63
+ n_embd = 320
64
+ dropout = 0.0
65
+ move_num_in_gamestate = False
66
+
67
+ init_from = 'resume'
68
+
69
+ device = 'cuda' # run on cpu only
70
+ compile = False # do not torch compile the model
chess-mamba-vs-xformer/configurator.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Poor Man's Configurator. Probably a terrible idea. Example usage:
3
+ $ python train.py config/override_file.py --batch_size=32
4
+ this will first run config/override_file.py, then override batch_size to 32
5
+
6
+ The code in this file will be run as follows from e.g. train.py:
7
+ >>> exec(open('configurator.py').read())
8
+
9
+ So it's not a Python module, it's just shuttling this code away from train.py
10
+ The code in this script then overrides the globals()
11
+
12
+ I know people are not going to love this, I just really dislike configuration
13
+ complexity and having to prepend config. to every single variable. If someone
14
+ comes up with a better simple Python solution I am all ears.
15
+ """
16
+
17
+ import sys
18
+ from ast import literal_eval
19
+
20
+ for arg in sys.argv[1:]:
21
+ if '=' not in arg:
22
+ # assume it's the name of a config file
23
+ assert not arg.startswith('--')
24
+ config_file = arg
25
+ print(f"Overriding config with {config_file}:")
26
+ with open(config_file) as f:
27
+ print(f.read())
28
+ exec(open(config_file).read())
29
+ else:
30
+ # assume it's a --key=value argument
31
+ assert arg.startswith('--')
32
+ key, val = arg.split('=')
33
+ key = key[2:]
34
+ if key in globals():
35
+ try:
36
+ # attempt to eval it it (e.g. if bool, number, or etc)
37
+ attempt = literal_eval(val)
38
+ except (SyntaxError, ValueError):
39
+ # if that goes wrong, just use the string
40
+ attempt = val
41
+ # ensure the types match ok
42
+ assert type(attempt) == type(globals()[key])
43
+ # cross fingers
44
+ print(f"Overriding: {key} = {attempt}")
45
+ globals()[key] = attempt
46
+ else:
47
+ raise ValueError(f"Unknown config key: {key}")
chess-mamba-vs-xformer/data/anneal/anneal.zip ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:37569f7bafca0eb2a7361e4ae29ef1b9fed64dbeb061d2653215c348586f7a7e
3
+ size 679959998
chess-mamba-vs-xformer/mamba.py ADDED
@@ -0,0 +1,368 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from dataclasses import dataclass
3
+ from typing import Union
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ from pscan import pscan
10
+
11
+ """
12
+
13
+ This file closely follows the mamba_simple.py from the official Mamba implementation, and the mamba-minimal by @johnma2006.
14
+ The major differences are :
15
+ -the convolution is done with torch.nn.Conv1d
16
+ -the selective scan is done in PyTorch
17
+
18
+ A sequential version of the selective scan is also available for comparison.
19
+
20
+ - A Mamba model is composed of several layers, which are ResidualBlock.
21
+ - A ResidualBlock is composed of a MambaBlock, a normalization, and a residual connection : ResidualBlock(x) = mamba(norm(x)) + x
22
+ - This leaves us with the MambaBlock : its input x is (B, L, D) and its outputs y is also (B, L, D) (B=batch size, L=seq len, D=model dim).
23
+ First, we expand x into (B, L, 2*ED) (where E is usually 2) and split it into x and z, each (B, L, ED).
24
+ Then, we apply the short 1d conv to x, followed by an activation function (silu), then the SSM.
25
+ We then multiply it by silu(z).
26
+ See Figure 3 of the paper (page 8) for a visual representation of a MambaBlock.
27
+
28
+ """
29
+
30
+ @dataclass
31
+ class MambaConfig:
32
+ d_model: int # D
33
+ n_layers: int
34
+ dt_rank: Union[int, str] = 'auto'
35
+ d_state: int = 16 # N in paper/comments
36
+ expand_factor: int = 2 # E in paper/comments
37
+ d_conv: int = 4
38
+
39
+ dt_min: float = 0.001
40
+ dt_max: float = 0.1
41
+ dt_init: str = "random" # "random" or "constant"
42
+ dt_scale: float = 1.0
43
+ dt_init_floor = 1e-4
44
+
45
+ bias: bool = False
46
+ conv_bias: bool = True
47
+
48
+ pscan: bool = True # use parallel scan mode or sequential mode when training
49
+
50
+ def __post_init__(self):
51
+ self.d_inner = self.expand_factor * self.d_model # E*D = ED in comments
52
+
53
+ if self.dt_rank == 'auto':
54
+ self.dt_rank = math.ceil(self.d_model / 16)
55
+
56
+ class Mamba(nn.Module):
57
+ def __init__(self, config: MambaConfig):
58
+ super().__init__()
59
+
60
+ self.config = config
61
+
62
+ self.layers = nn.ModuleList([ResidualBlock(config) for _ in range(config.n_layers)])
63
+ #self.norm_f = RMSNorm(config.d_model)
64
+
65
+ def forward(self, x):
66
+ # x : (B, L, D)
67
+
68
+ # y : (B, L, D)
69
+
70
+ for layer in self.layers:
71
+ x = layer(x)
72
+
73
+ #x = self.norm_f(x)
74
+
75
+ return x
76
+
77
+ def step(self, x, caches):
78
+ # x : (B, L, D)
79
+ # caches : [cache(layer) for all layers], cache : (h, inputs)
80
+
81
+ # y : (B, L, D)
82
+ # caches : [cache(layer) for all layers], cache : (h, inputs)
83
+
84
+ for i, layer in enumerate(self.layers):
85
+ x, caches[i] = layer.step(x, caches[i])
86
+
87
+ return x, caches
88
+
89
+ class ResidualBlock(nn.Module):
90
+ def __init__(self, config: MambaConfig):
91
+ super().__init__()
92
+
93
+ self.mixer = MambaBlock(config)
94
+ self.norm = RMSNorm(config.d_model)
95
+
96
+ def forward(self, x):
97
+ # x : (B, L, D)
98
+
99
+ # output : (B, L, D)
100
+
101
+ output = self.mixer(self.norm(x)) + x
102
+ return output
103
+
104
+ def step(self, x, cache):
105
+ # x : (B, D)
106
+ # cache : (h, inputs)
107
+ # h : (B, ED, N)
108
+ # inputs: (B, ED, d_conv-1)
109
+
110
+ # output : (B, D)
111
+ # cache : (h, inputs)
112
+
113
+ output, cache = self.mixer.step(self.norm(x), cache)
114
+ output = output + x
115
+ return output, cache
116
+
117
+ class MambaBlock(nn.Module):
118
+ def __init__(self, config: MambaConfig):
119
+ super().__init__()
120
+
121
+ self.config = config
122
+
123
+ # projects block input from D to 2*ED (two branches)
124
+ self.in_proj = nn.Linear(config.d_model, 2 * config.d_inner, bias=config.bias)
125
+
126
+ self.conv1d = nn.Conv1d(in_channels=config.d_inner, out_channels=config.d_inner,
127
+ kernel_size=config.d_conv, bias=config.conv_bias,
128
+ groups=config.d_inner,
129
+ padding=config.d_conv - 1)
130
+
131
+ nn.init.kaiming_normal_(self.conv1d.weight, mode='fan_out', nonlinearity='leaky_relu')
132
+
133
+ # projects x to input-dependent Δ, B, C
134
+ self.x_proj = nn.Linear(config.d_inner, config.dt_rank + 2 * config.d_state, bias=False)
135
+
136
+ # projects Δ from dt_rank to d_inner
137
+ self.dt_proj = nn.Linear(config.dt_rank, config.d_inner, bias=True)
138
+
139
+ # dt initialization
140
+ # dt weights
141
+ dt_init_std = config.dt_rank**-0.5 * config.dt_scale
142
+ if config.dt_init == "constant":
143
+ nn.init.constant_(self.dt_proj.weight, dt_init_std)
144
+ elif config.dt_init == "random":
145
+ nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
146
+ else:
147
+ raise NotImplementedError
148
+
149
+ # dt bias
150
+ dt = torch.exp(
151
+ torch.rand(config.d_inner) * (math.log(config.dt_max) - math.log(config.dt_min)) + math.log(config.dt_min)
152
+ ).clamp(min=config.dt_init_floor)
153
+ inv_dt = dt + torch.log(-torch.expm1(-dt)) # inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
154
+ with torch.no_grad():
155
+ self.dt_proj.bias.copy_(inv_dt)
156
+ #self.dt_proj.bias._no_reinit = True # initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
157
+ # todo : explain why removed
158
+
159
+ # S4D real initialization
160
+ A = torch.arange(1, config.d_state + 1, dtype=torch.float32).repeat(config.d_inner, 1)
161
+ self.A_log = nn.Parameter(torch.log(A)) # why store A in log ? to keep A < 0 (cf -torch.exp(...)) ? for gradient stability ?
162
+ self.D = nn.Parameter(torch.ones(config.d_inner))
163
+
164
+ # projects block output from ED back to D
165
+ self.out_proj = nn.Linear(config.d_inner, config.d_model, bias=config.bias)
166
+
167
+ def forward(self, x):
168
+ # x : (B, L, D)
169
+
170
+ # y : (B, L, D)
171
+
172
+ _, L, _ = x.shape
173
+
174
+ xz = self.in_proj(x) # (B, L, 2*ED)
175
+ x, z = xz.chunk(2, dim=-1) # (B, L, ED), (B, L, ED)
176
+
177
+ # x branch
178
+ x = x.transpose(1, 2) # (B, ED, L)
179
+ x = self.conv1d(x)[:, :, :L] # depthwise convolution over time, with a short filter
180
+ x = x.transpose(1, 2) # (B, L, ED)
181
+
182
+ x = F.silu(x)
183
+ y = self.ssm(x)
184
+
185
+ # z branch
186
+ z = F.silu(z)
187
+
188
+ output = y * z
189
+ output = self.out_proj(output) # (B, L, D)
190
+
191
+ return output
192
+
193
+ def ssm(self, x):
194
+ # x : (B, L, ED)
195
+
196
+ # y : (B, L, ED)
197
+
198
+ A = -torch.exp(self.A_log.float()) # (ED, N)
199
+ D = self.D.float()
200
+ # TODO remove .float()
201
+
202
+ deltaBC = self.x_proj(x) # (B, L, dt_rank+2*N)
203
+
204
+ delta, B, C = torch.split(deltaBC, [self.config.dt_rank, self.config.d_state, self.config.d_state], dim=-1) # (B, L, dt_rank), (B, L, N), (B, L, N)
205
+ delta = F.softplus(self.dt_proj(delta)) # (B, L, ED)
206
+
207
+ if self.config.pscan:
208
+ y = self.selective_scan(x, delta, A, B, C, D)
209
+ else:
210
+ y = self.selective_scan_seq(x, delta, A, B, C, D)
211
+
212
+ return y
213
+
214
+ def selective_scan(self, x, delta, A, B, C, D):
215
+ # x : (B, L, ED)
216
+ # Δ : (B, L, ED)
217
+ # A : (ED, N)
218
+ # B : (B, L, N)
219
+ # C : (B, L, N)
220
+ # D : (ED)
221
+
222
+ # y : (B, L, ED)
223
+
224
+ deltaA = torch.exp(delta.unsqueeze(-1) * A) # (B, L, ED, N)
225
+ deltaB = delta.unsqueeze(-1) * B.unsqueeze(2) # (B, L, ED, N)
226
+
227
+ BX = deltaB * (x.unsqueeze(-1)) # (B, L, ED, N)
228
+
229
+ hs = pscan(deltaA, BX)
230
+
231
+ y = (hs @ C.unsqueeze(-1)).squeeze(3) # (B, L, ED, N) @ (B, L, N, 1) -> (B, L, ED, 1)
232
+
233
+ y = y + D * x
234
+
235
+ return y
236
+
237
+ def selective_scan_seq(self, x, delta, A, B, C, D):
238
+ # x : (B, L, ED)
239
+ # Δ : (B, L, ED)
240
+ # A : (ED, N)
241
+ # B : (B, L, N)
242
+ # C : (B, L, N)
243
+ # D : (ED)
244
+
245
+ # y : (B, L, ED)
246
+
247
+ _, L, _ = x.shape
248
+
249
+ deltaA = torch.exp(delta.unsqueeze(-1) * A) # (B, L, ED, N)
250
+ deltaB = delta.unsqueeze(-1) * B.unsqueeze(2) # (B, L, ED, N)
251
+
252
+ BX = deltaB * (x.unsqueeze(-1)) # (B, L, ED, N)
253
+
254
+ h = torch.zeros(x.size(0), self.config.d_inner, self.config.d_state, device=deltaA.device) # (B, ED, N)
255
+ hs = []
256
+
257
+ for t in range(0, L):
258
+ h = deltaA[:, t] * h + BX[:, t]
259
+ hs.append(h)
260
+
261
+ hs = torch.stack(hs, dim=1) # (B, L, ED, N)
262
+
263
+ y = (hs @ C.unsqueeze(-1)).squeeze(3) # (B, L, ED, N) @ (B, L, N, 1) -> (B, L, ED, 1)
264
+
265
+ y = y + D * x
266
+
267
+ return y
268
+
269
+ # -------------------------- inference -------------------------- #
270
+ """
271
+ Concerning auto-regressive inference
272
+
273
+ The cool part of using Mamba : inference is constant wrt to sequence length
274
+ We just have to keep in cache, for each layer, two things :
275
+ - the hidden state h (which is (B, ED, N)), as you typically would when doing inference with a RNN
276
+ - the last d_conv-1 inputs of the layer, to be able to compute the 1D conv which is a convolution over the time dimension
277
+ (d_conv is fixed so this doesn't incur a growing cache as we progress on generating the sequence)
278
+ (and d_conv is usually very small, like 4, so we just have to "remember" the last 3 inputs)
279
+
280
+ Concretely, these two quantities are put inside a cache tuple, and are named h and inputs respectively.
281
+ h is (B, ED, N), and inputs is (B, ED, d_conv-1)
282
+ The MambaBlock.step() receives this cache, and, along with outputing the output, alos outputs the updated cache for the next call.
283
+
284
+ The cache object is initialized as follows : (None, torch.zeros()).
285
+ When h is None, the selective scan function detects it and start with h=0.
286
+ The torch.zeros() isn't a problem (it's same as just feeding the input, because the conv1d is padded)
287
+
288
+ As we need one such cache variable per layer, we store a caches object, which is simply a list of cache object. (See mamba_lm.py)
289
+ """
290
+
291
+ def step(self, x, cache):
292
+ # x : (B, D)
293
+ # cache : (h, inputs)
294
+ # h : (B, ED, N)
295
+ # inputs : (B, ED, d_conv-1)
296
+
297
+ # y : (B, D)
298
+ # cache : (h, inputs)
299
+
300
+ h, inputs = cache
301
+
302
+ xz = self.in_proj(x) # (B, 2*ED)
303
+ x, z = xz.chunk(2, dim=1) # (B, ED), (B, ED)
304
+
305
+ # x branch
306
+ x_cache = x.unsqueeze(2)
307
+ x = self.conv1d(torch.cat([inputs, x_cache], dim=2))[:, :, self.config.d_conv-1] # (B, ED)
308
+
309
+ x = F.silu(x)
310
+ y, h = self.ssm_step(x, h)
311
+
312
+ # z branch
313
+ z = F.silu(z)
314
+
315
+ output = y * z
316
+ output = self.out_proj(output) # (B, D)
317
+
318
+ # prepare cache for next call
319
+ inputs = torch.cat([inputs[:, :, 1:], x_cache], dim=2) # (B, ED, d_conv-1)
320
+ cache = (h, inputs)
321
+
322
+ return output, cache
323
+
324
+ def ssm_step(self, x, h):
325
+ # x : (B, ED)
326
+ # h : (B, ED, N)
327
+
328
+ # y : (B, ED)
329
+ # h : (B, ED, N)
330
+
331
+ A = -torch.exp(self.A_log.float()) # (ED, N) # todo : ne pas le faire tout le temps, puisque c'est indépendant de la timestep
332
+ D = self.D.float()
333
+ # TODO remove .float()
334
+
335
+ deltaBC = self.x_proj(x) # (B, dt_rank+2*N)
336
+
337
+ delta, B, C = torch.split(deltaBC, [self.config.dt_rank, self.config.d_state, self.config.d_state], dim=-1) # (B, dt_rank), (B, N), (B, N)
338
+ delta = F.softplus(self.dt_proj(delta)) # (B, ED)
339
+
340
+ deltaA = torch.exp(delta.unsqueeze(-1) * A) # (B, ED, N)
341
+ deltaB = delta.unsqueeze(-1) * B.unsqueeze(1) # (B, ED, N)
342
+
343
+ BX = deltaB * (x.unsqueeze(-1)) # (B, ED, N)
344
+
345
+ if h is None:
346
+ h = torch.zeros(x.size(0), self.config.d_inner, self.config.d_state, device=deltaA.device) # (B, ED, N)
347
+
348
+ h = deltaA * h + BX # (B, ED, N)
349
+
350
+ y = (h @ C.unsqueeze(-1)).squeeze(2) # (B, ED, N) @ (B, N, 1) -> (B, ED, 1)
351
+
352
+ y = y + D * x
353
+
354
+ # todo : pq h.squeeze(1) ??
355
+ return y, h.squeeze(1)
356
+
357
+ # taken straight from https://github.com/johnma2006/mamba-minimal/blob/master/model.py
358
+ class RMSNorm(nn.Module):
359
+ def __init__(self, d_model: int, eps: float = 1e-5):
360
+ super().__init__()
361
+
362
+ self.eps = eps
363
+ self.weight = nn.Parameter(torch.ones(d_model))
364
+
365
+ def forward(self, x):
366
+ output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
367
+
368
+ return output
chess-mamba-vs-xformer/mamba_lm.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, fields, asdict
2
+ import json
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from mamba import Mamba, MambaConfig, RMSNorm
9
+
10
+ """
11
+
12
+ Encapsulates a Mamba model as language model. It has an embedding layer, and a LM head which maps the model output to logits.
13
+
14
+ """
15
+
16
+ # TODO generate function : batch size != 1 ? (for now B=1)
17
+ # TODO generate function : top-p sampling
18
+
19
+ @dataclass
20
+ class MambaLMConfig(MambaConfig):
21
+ vocab_size: int = 32000
22
+ pad_vocab_size_multiple: int = 8
23
+
24
+ def __post_init__(self):
25
+ super().__post_init__()
26
+
27
+ #if self.vocab_size % self.pad_vocab_size_multiple != 0:
28
+ # self.vocab_size += (self.pad_vocab_size_multiple - self.vocab_size % self.pad_vocab_size_multiple)
29
+
30
+ def to_mamba_config(self) -> MambaConfig:
31
+ mamba_config_fields = {field.name for field in fields(MambaConfig)}
32
+ filtered_dict = {k: v for k, v in asdict(self).items() if k in mamba_config_fields}
33
+ return MambaConfig(**filtered_dict)
34
+
35
+ # adapted from https://github.com/johnma2006/mamba-minimal
36
+ def from_pretrained(name: str):
37
+ """
38
+ Returns a model loaded with pretrained weights pulled from HuggingFace.
39
+
40
+ Args:
41
+ name: As of now, supports
42
+ * 'state-spaces/mamba-2.8b-slimpj'
43
+ * 'state-spaces/mamba-2.8b'
44
+ * 'state-spaces/mamba-1.4b'
45
+ * 'state-spaces/mamba-790m'
46
+ * 'state-spaces/mamba-370m'
47
+ * 'state-spaces/mamba-130m'
48
+
49
+ Returns:
50
+ model: a Mamba model configured with the proper parameters and initialized with the proper weights
51
+ """
52
+
53
+ from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
54
+ from transformers.utils.hub import cached_file
55
+
56
+ def load_config_hf(model_name):
57
+ resolved_archive_file = cached_file(model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False)
58
+ return json.load(open(resolved_archive_file))
59
+
60
+ def load_state_dict_hf(model_name):
61
+ resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False)
62
+ return torch.load(resolved_archive_file, weights_only=True, map_location='cpu', mmap=True)
63
+
64
+ # copy config data
65
+ config_data = load_config_hf(name)
66
+ config = MambaLMConfig(d_model=config_data['d_model'], n_layers=config_data['n_layer'], vocab_size=config_data['vocab_size'])
67
+
68
+ model = MambaLM(config)
69
+
70
+ # copy weights
71
+ state_dict = load_state_dict_hf(name)
72
+
73
+ new_state_dict = {}
74
+ for key in state_dict:
75
+ if key == 'backbone.embedding.weight' or key == 'backbone.norm_f.weight':
76
+ new_key = key.replace('backbone.', '')
77
+ else:
78
+ new_key = key.replace('backbone', 'mamba')
79
+
80
+ new_state_dict[new_key] = state_dict[key]
81
+
82
+ model.load_state_dict(new_state_dict)
83
+
84
+ return model
85
+
86
+ class MambaLM(nn.Module):
87
+ def __init__(self, lm_config: MambaLMConfig):
88
+ super().__init__()
89
+ self.lm_config = lm_config
90
+ self.config = lm_config.to_mamba_config()
91
+
92
+ self.embedding = nn.Embedding(self.lm_config.vocab_size, self.config.d_model)
93
+ self.mamba = Mamba(self.config)
94
+ self.norm_f = RMSNorm(self.config.d_model)
95
+
96
+ self.lm_head = nn.Linear(self.config.d_model, self.lm_config.vocab_size, bias=False)
97
+ self.lm_head.weight = self.embedding.weight
98
+
99
+ def forward(self, tokens):
100
+ # tokens : (B, L)
101
+
102
+ # logits : (B, L, vocab_size)
103
+
104
+ x = self.embedding(tokens)
105
+
106
+ x = self.mamba(x)
107
+ x = self.norm_f(x)
108
+
109
+ logits = self.lm_head(x)
110
+
111
+ return logits
112
+
113
+ def step(self, token, caches):
114
+ # token : (B)
115
+ # caches : [cache(layer) for all layers], cache : (h, inputs)
116
+
117
+ # logits : (B, vocab_size)
118
+ # caches : [cache(layer) for all layers], cache : (h, inputs)
119
+
120
+ x = self.embedding(token)
121
+
122
+ x, caches = self.mamba.step(x, caches)
123
+ x = self.norm_f(x)
124
+
125
+ logits = self.lm_head(x)
126
+
127
+ return logits, caches
128
+
129
+ # TODO temperature
130
+ # TODO process prompt in parallel, and pass in sequential mode when prompt is finished ?
131
+ def generate(self, tokenizer, prompt: str, num_tokens: int = 50, sample: bool = True, top_k: int = 40):
132
+ self.eval()
133
+
134
+ input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(next(self.parameters()).device) # (1, num_tokens)
135
+
136
+ # caches is a list of cache, one per layer
137
+ # cache is composed of : the hidden state, and the last d_conv-1 inputs
138
+ # the hidden state because the update is like an RNN
139
+ # the last d_conv-1 inputs because they are used in a 1d convolution (usually d_conv=4 so this is not large)
140
+ caches = [(None, torch.zeros(1, self.config.d_inner, self.config.d_conv-1, device=input_ids.device)) for _ in range(self.config.n_layers)]
141
+
142
+ for i in range(input_ids.size(1) + num_tokens - 1):
143
+ with torch.no_grad():
144
+ # forward the new output, get new cache
145
+ next_token_logits, caches = self.step(input_ids[:, i], caches) # (1, vocab_size), caches
146
+
147
+ # sample (no sampling when the prompt is being processed)
148
+ if i+1 >= input_ids.size(1):
149
+ probs = F.softmax(next_token_logits, dim=-1) # (1, vocab_size)
150
+
151
+ if top_k is not None:
152
+ values, _ = torch.topk(probs, k=top_k) # (1, k) ordered from lowest to biggest
153
+ probs[probs < values[:, -1, None]] = 0
154
+ probs = probs / probs.sum(axis=1, keepdims=True)
155
+
156
+ if sample:
157
+ next_token = torch.multinomial(probs, num_samples=1).squeeze(1) # (1)
158
+ else:
159
+ next_token = torch.argmax(probs, dim=-1) # (1)
160
+
161
+ input_ids = torch.cat([input_ids, next_token.unsqueeze(1)], dim=1)
162
+
163
+ output = [tokenizer.decode(output.tolist()) for output in input_ids][0]
164
+
165
+ self.train()
166
+
167
+ return output
168
+
chess-mamba-vs-xformer/openings.csv ADDED
The diff for this file is too large to render. See raw diff
 
chess-mamba-vs-xformer/pscan.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+ """
7
+
8
+ An implementation of the parallel scan operation in PyTorch (Blelloch version).
9
+ Please see docs/pscan.ipynb for a detailed explanation of what happens here.
10
+
11
+ """
12
+
13
+ def npo2(len):
14
+ """
15
+ Returns the next power of 2 above len
16
+ """
17
+
18
+ return 2 ** math.ceil(math.log2(len))
19
+
20
+ def pad_npo2(X):
21
+ """
22
+ Pads input length dim to the next power of 2
23
+
24
+ Args:
25
+ X : (B, L, D, N)
26
+
27
+ Returns:
28
+ Y : (B, npo2(L), D, N)
29
+ """
30
+
31
+ len_npo2 = npo2(X.size(1))
32
+ pad_tuple = (0, 0, 0, 0, 0, len_npo2 - X.size(1))
33
+ return F.pad(X, pad_tuple, "constant", 0)
34
+
35
+ class PScan(torch.autograd.Function):
36
+ @staticmethod
37
+ def pscan(A, X):
38
+ # A : (B, D, L, N)
39
+ # X : (B, D, L, N)
40
+
41
+ # modifies X in place by doing a parallel scan.
42
+ # more formally, X will be populated by these values :
43
+ # H[t] = A[t] * H[t-1] + X[t] with H[0] = 0
44
+ # which are computed in parallel (2*log2(T) sequential steps (ideally), instead of T sequential steps)
45
+
46
+ # only supports L that is a power of two (mainly for a clearer code)
47
+
48
+ B, D, L, _ = A.size()
49
+ num_steps = int(math.log2(L))
50
+
51
+ # up sweep (last 2 steps unfolded)
52
+ Aa = A
53
+ Xa = X
54
+ for _ in range(num_steps-2):
55
+ T = Xa.size(2)
56
+ Aa = Aa.view(B, D, T//2, 2, -1)
57
+ Xa = Xa.view(B, D, T//2, 2, -1)
58
+
59
+ Xa[:, :, :, 1].add_(Aa[:, :, :, 1].mul(Xa[:, :, :, 0]))
60
+ Aa[:, :, :, 1].mul_(Aa[:, :, :, 0])
61
+
62
+ Aa = Aa[:, :, :, 1]
63
+ Xa = Xa[:, :, :, 1]
64
+
65
+ # we have only 4, 2 or 1 nodes left
66
+ if Xa.size(2) == 4:
67
+ Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0]))
68
+ Aa[:, :, 1].mul_(Aa[:, :, 0])
69
+
70
+ Xa[:, :, 3].add_(Aa[:, :, 3].mul(Xa[:, :, 2] + Aa[:, :, 2].mul(Xa[:, :, 1])))
71
+ elif Xa.size(2) == 2:
72
+ Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0]))
73
+ return
74
+ else:
75
+ return
76
+
77
+ # down sweep (first 2 steps unfolded)
78
+ Aa = A[:, :, 2**(num_steps-2)-1:L:2**(num_steps-2)]
79
+ Xa = X[:, :, 2**(num_steps-2)-1:L:2**(num_steps-2)]
80
+ Xa[:, :, 2].add_(Aa[:, :, 2].mul(Xa[:, :, 1]))
81
+ Aa[:, :, 2].mul_(Aa[:, :, 1])
82
+
83
+ for k in range(num_steps-3, -1, -1):
84
+ Aa = A[:, :, 2**k-1:L:2**k]
85
+ Xa = X[:, :, 2**k-1:L:2**k]
86
+
87
+ T = Xa.size(2)
88
+ Aa = Aa.view(B, D, T//2, 2, -1)
89
+ Xa = Xa.view(B, D, T//2, 2, -1)
90
+
91
+ Xa[:, :, 1:, 0].add_(Aa[:, :, 1:, 0].mul(Xa[:, :, :-1, 1]))
92
+ Aa[:, :, 1:, 0].mul_(Aa[:, :, :-1, 1])
93
+
94
+ @staticmethod
95
+ def pscan_rev(A, X):
96
+ # A : (B, D, L, N)
97
+ # X : (B, D, L, N)
98
+
99
+ # the same function as above, but in reverse
100
+ # (if you flip the input, call pscan, then flip the output, you get what this function outputs)
101
+ # it is used in the backward pass
102
+
103
+ # only supports L that is a power of two (mainly for a clearer code)
104
+
105
+ B, D, L, _ = A.size()
106
+ num_steps = int(math.log2(L))
107
+
108
+ # up sweep (last 2 steps unfolded)
109
+ Aa = A
110
+ Xa = X
111
+ for _ in range(num_steps-2):
112
+ T = Xa.size(2)
113
+ Aa = Aa.view(B, D, T//2, 2, -1)
114
+ Xa = Xa.view(B, D, T//2, 2, -1)
115
+
116
+ Xa[:, :, :, 0].add_(Aa[:, :, :, 0].mul(Xa[:, :, :, 1]))
117
+ Aa[:, :, :, 0].mul_(Aa[:, :, :, 1])
118
+
119
+ Aa = Aa[:, :, :, 0]
120
+ Xa = Xa[:, :, :, 0]
121
+
122
+ # we have only 4, 2 or 1 nodes left
123
+ if Xa.size(2) == 4:
124
+ Xa[:, :, 2].add_(Aa[:, :, 2].mul(Xa[:, :, 3]))
125
+ Aa[:, :, 2].mul_(Aa[:, :, 3])
126
+
127
+ Xa[:, :, 0].add_(Aa[:, :, 0].mul(Xa[:, :, 1].add(Aa[:, :, 1].mul(Xa[:, :, 2]))))
128
+ elif Xa.size(2) == 2:
129
+ Xa[:, :, 0].add_(Aa[:, :, 0].mul(Xa[:, :, 1]))
130
+ return
131
+ else:
132
+ return
133
+
134
+ # down sweep (first 2 steps unfolded)
135
+ Aa = A[:, :, 0:L:2**(num_steps-2)]
136
+ Xa = X[:, :, 0:L:2**(num_steps-2)]
137
+ Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 2]))
138
+ Aa[:, :, 1].mul_(Aa[:, :, 2])
139
+
140
+ for k in range(num_steps-3, -1, -1):
141
+ Aa = A[:, :, 0:L:2**k]
142
+ Xa = X[:, :, 0:L:2**k]
143
+
144
+ T = Xa.size(2)
145
+ Aa = Aa.view(B, D, T//2, 2, -1)
146
+ Xa = Xa.view(B, D, T//2, 2, -1)
147
+
148
+ Xa[:, :, :-1, 1].add_(Aa[:, :, :-1, 1].mul(Xa[:, :, 1:, 0]))
149
+ Aa[:, :, :-1, 1].mul_(Aa[:, :, 1:, 0])
150
+
151
+ @staticmethod
152
+ def forward(ctx, A_in, X_in):
153
+ """
154
+ Applies the parallel scan operation, as defined above. Returns a new tensor.
155
+ If you can, privilege sequence lengths that are powers of two.
156
+
157
+ Args:
158
+ A_in : (B, L, D, N)
159
+ X_in : (B, L, D, N)
160
+
161
+ Returns:
162
+ H : (B, L, D, N)
163
+ """
164
+
165
+ L = X_in.size(1)
166
+
167
+ # cloning is requiered because of the in-place ops
168
+ if L == npo2(L):
169
+ A = A_in.clone()
170
+ X = X_in.clone()
171
+ else:
172
+ # pad tensors (and clone btw)
173
+ A = pad_npo2(A_in) # (B, npo2(L), D, N)
174
+ X = pad_npo2(X_in) # (B, npo2(L), D, N)
175
+
176
+ # prepare tensors
177
+ A = A.transpose(2, 1) # (B, D, npo2(L), N)
178
+ X = X.transpose(2, 1) # (B, D, npo2(L), N)
179
+
180
+ # parallel scan (modifies X in-place)
181
+ PScan.pscan(A, X)
182
+
183
+ ctx.save_for_backward(A_in, X)
184
+
185
+ # slice [:, :L] (cut if there was padding)
186
+ return X.transpose(2, 1)[:, :L]
187
+
188
+ @staticmethod
189
+ def backward(ctx, grad_output_in):
190
+ """
191
+ Flows the gradient from the output to the input. Returns two new tensors.
192
+
193
+ Args:
194
+ ctx : A_in : (B, L, D, N), X : (B, D, L, N)
195
+ grad_output_in : (B, L, D, N)
196
+
197
+ Returns:
198
+ gradA : (B, L, D, N), gradX : (B, L, D, N)
199
+ """
200
+
201
+ A_in, X = ctx.saved_tensors
202
+
203
+ L = grad_output_in.size(1)
204
+
205
+ # cloning is requiered because of the in-place ops
206
+ if L == npo2(L):
207
+ grad_output = grad_output_in.clone()
208
+ # the next padding will clone A_in
209
+ else:
210
+ grad_output = pad_npo2(grad_output_in) # (B, npo2(L), D, N)
211
+ A_in = pad_npo2(A_in) # (B, npo2(L), D, N)
212
+
213
+ # prepare tensors
214
+ grad_output = grad_output.transpose(2, 1)
215
+ A_in = A_in.transpose(2, 1) # (B, D, npo2(L), N)
216
+ A = torch.nn.functional.pad(A_in[:, :, 1:], (0, 0, 0, 1)) # (B, D, npo2(L), N) shift 1 to the left (see hand derivation)
217
+
218
+ # reverse parallel scan (modifies grad_output in-place)
219
+ PScan.pscan_rev(A, grad_output)
220
+
221
+ Q = torch.zeros_like(X)
222
+ Q[:, :, 1:].add_(X[:, :, :-1] * grad_output[:, :, 1:])
223
+
224
+ return Q.transpose(2, 1)[:, :L], grad_output.transpose(2, 1)[:, :L]
225
+
226
+ pscan = PScan.apply
chess-mamba-vs-xformer/train_bygame.py ADDED
@@ -0,0 +1,541 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import math
4
+ import pickle
5
+ from contextlib import nullcontext
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from torch.nn.parallel import DistributedDataParallel as DDP
10
+ from torch.distributed import init_process_group, destroy_process_group
11
+ import pyarrow.parquet as pq
12
+ import random
13
+ from torch.utils.data import Dataset, DataLoader
14
+ import glob
15
+
16
+ # -----------------------------------------------------------------------------
17
+ # default config values designed for Mamba model training
18
+ # I/O
19
+ out_dir = 'out'
20
+ eval_interval = 2000
21
+ log_interval = 1
22
+ eval_iters = 5
23
+ eval_only = False
24
+ always_save_checkpoint = True
25
+ init_from = 'resume' # 'scratch', 'resume', 'anneal', or Mamba model name
26
+ # wandb logging
27
+ wandb_log = False
28
+ wandb_project = 'mamba'
29
+ wandb_run_name = 'mamba_run' # modify as needed
30
+ # data
31
+ dataset = 'chess' # specify your dataset
32
+ gradient_accumulation_steps = 5 * 8
33
+ batch_size = 12
34
+ base_batch_size = batch_size
35
+ effective_batch_size = batch_size
36
+ max_seq_len = 1024 # For xformer, this is the block size
37
+ train_file_update_interval = 7
38
+
39
+ # model
40
+ model_type = 'mamba'
41
+ # TODO: add 'xformer' type / model paramers. move model imports to after exec() (when these values finalized)
42
+ n_layer = 12
43
+ d_model = 768
44
+ dt_rank = 'auto'
45
+ d_state = 16
46
+ expand_factor = 2
47
+ bias = False
48
+ conv_bias = True
49
+ pscan = True
50
+ vocab_size = 32
51
+ move_num_in_gamestate = True
52
+ # xformer-specific params. Note that n_layer, vocab_size, move_num_in_gamestate, and bias are shared by both model types
53
+ n_head = 12
54
+ n_embd = 768
55
+ dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+
56
+
57
+ # optimizer settings
58
+ learning_rate = 6e-4
59
+ max_iters = 600000 # max_iters is for auto-stopping end of stable phase
60
+ weight_decay = 1e-1
61
+ beta1 = 0.9
62
+ beta2 = 0.95
63
+ grad_clip = 0.5
64
+ auto_clip = False
65
+ auto_clip_max = 0.5
66
+ auto_clip_min = 3.333e-3
67
+ grad_clip_start_size = 100
68
+ grad_clip_max_size = 500
69
+ grad_clip_percentile = 10
70
+ # learning rate decay settings
71
+ decay_lr = True
72
+ warmup_iters = 2000
73
+ min_lr = 6e-5
74
+ # DDP settings
75
+ backend = 'nccl'
76
+ # system
77
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
78
+ dtype = 'bfloat16' if torch.cuda.is_bf16_supported() else 'float32'
79
+ compile = False # set to True if using PyTorch 2.0
80
+ # -----------------------------------------------------------------------------
81
+
82
+ config_keys = [k for k, v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
83
+ exec(open('configurator.py').read()) # overrides from command line or config file
84
+ config = {k: globals()[k] for k in config_keys} # will be useful for logging
85
+ # -----------------------------------------------------------------------------
86
+
87
+
88
+ anneal_checkpoint = 'anneal/ckpt.pt'
89
+ anneal_dir = os.path.join(out_dir, 'anneal/')
90
+ anneal_start_iters = None # Set at init
91
+ anneal_decay_iters = None # Set at init
92
+
93
+ if model_type == 'mamba':
94
+ from mamba_lm import MambaLM, MambaLMConfig
95
+ model_config = MambaLMConfig(
96
+ d_model=d_model,
97
+ n_layers=n_layer,
98
+ dt_rank=dt_rank,
99
+ d_state=d_state,
100
+ expand_factor=expand_factor,
101
+ bias=bias,
102
+ conv_bias=conv_bias,
103
+ pscan=pscan,
104
+ vocab_size=vocab_size
105
+ )
106
+ elif model_type == 'xformer':
107
+ from xformer import GPTConfig, GPT
108
+ model_config = GPTConfig(
109
+ n_layer=n_layer,
110
+ n_head=n_head,
111
+ n_embd=n_embd,
112
+ block_size=max_seq_len,
113
+ bias=bias,
114
+ vocab_size=vocab_size,
115
+ dropout=dropout)
116
+ else:
117
+ print(f"Unknown model_type {model_type}.")
118
+ exit()
119
+
120
+ # DDP and other initializations
121
+ ddp = int(os.environ.get('RANK', -1)) != -1
122
+ if ddp:
123
+ init_process_group(backend=backend)
124
+ ddp_rank = int(os.environ['RANK'])
125
+ ddp_local_rank = int(os.environ['LOCAL_RANK'])
126
+ ddp_world_size = int(os.environ['WORLD_SIZE'])
127
+ device = f'cuda:{ddp_local_rank}'
128
+ torch.cuda.set_device(device)
129
+ master_process = ddp_rank == 0
130
+ seed_offset = ddp_rank
131
+ assert gradient_accumulation_steps % ddp_world_size == 0
132
+ gradient_accumulation_steps //= ddp_world_size
133
+ else:
134
+ master_process = True
135
+ seed_offset = 0
136
+ ddp_world_size = 1
137
+
138
+ if master_process:
139
+ os.makedirs(out_dir, exist_ok=True)
140
+ os.makedirs(anneal_dir, exist_ok=True)
141
+ torch.manual_seed(1337 + seed_offset)
142
+ torch.backends.cuda.matmul.allow_tf32 = True
143
+ torch.backends.cudnn.allow_tf32 = True
144
+ device_type = 'cuda' if 'cuda' in device else 'cpu'
145
+ ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16}[dtype]
146
+ ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
147
+
148
+ # poor man's data loader
149
+ data_dir = os.path.join('data', dataset)
150
+ current_train_file_index = 0
151
+ train_files = glob.glob(os.path.join(data_dir, 'train*.parquet'))
152
+ train_datasets = []
153
+ for f in train_files:
154
+ dataset = pq.read_table(f).to_pandas()
155
+ dataset = dataset[dataset['tokenized'].apply(len) >= 8]
156
+ train_datasets.append(dataset)
157
+ #val_data = pq.read_table(os.path.join(data_dir, 'val.parquet')).to_pandas()
158
+ #val_data = val_data[val_data['tokenized'].apply(len) >= 8]
159
+ truncated_games_count = 0
160
+ total_games_count = 0
161
+ games_seen = 0
162
+ tokens_seen = 0
163
+ tokens_seen_padded = 0
164
+ def get_batch(split):
165
+ global truncated_games_count, total_games_count, current_train_file_index, tokens_seen, tokens_seen_padded
166
+
167
+ # Randomly select batch_size games
168
+ dataset = train_datasets[current_train_file_index] if split == 'train' else None # else val_data # Use the correct DataFrame based on the split
169
+ sample_df = dataset.sample(batch_size)
170
+ games = sample_df['tokenized'].tolist()
171
+
172
+ # Prepare sequences tensor for the batch
173
+ max_length_in_batch = min(max(len(game) for game in games), max_seq_len)
174
+ pad_to = max_length_in_batch #if model_type == 'mamba' else max_seq_len
175
+ sequences = torch.zeros((batch_size, pad_to), dtype=torch.int64)
176
+
177
+ for i, game in enumerate(games):
178
+ total_games_count += 1
179
+ game_len = min(len(game), pad_to)
180
+ tokens_seen += game_len
181
+ tokens_seen_padded += pad_to
182
+ sequences[i, :game_len] = torch.tensor(game[:game_len], dtype=torch.int64)
183
+
184
+ if (total_games_count // batch_size) % train_file_update_interval == 0:
185
+ current_train_file_index = random.randint(0, len(train_files) - 1)
186
+ # print(f"Switched to file: {train_files[current_train_file_index]}")
187
+
188
+ if device_type == 'cuda':
189
+ sequences = sequences.pin_memory().to(device, non_blocking=True)
190
+ else:
191
+ sequences = sequences.to(device)
192
+
193
+ return sequences, max_length_in_batch
194
+
195
+ # init these up here, can override if init_from='resume' (i.e. from a checkpoint)
196
+ iter_num = 0
197
+ best_val_loss = 1e9
198
+
199
+ # attempt to derive vocab_size from the dataset
200
+ meta_path = os.path.join(data_dir, 'meta.pkl')
201
+ meta_vocab_size = None
202
+ if not move_num_in_gamestate:
203
+ meta_vocab_size = 28
204
+ elif os.path.exists(meta_path):
205
+ with open(meta_path, 'rb') as f:
206
+ meta = pickle.load(f)
207
+ meta_vocab_size = meta['vocab_size']
208
+ print(f"found vocab_size = {meta_vocab_size} (inside {meta_path})")
209
+
210
+ # Model initialization
211
+ if init_from == 'scratch':
212
+ print(f"Initializing a new {model_type} model from scratch")
213
+ if meta_vocab_size is None:
214
+ print(f"defaulting to vocab_size of {vocab_size}")
215
+ else:
216
+ model_config.vocab_size = meta_vocab_size
217
+ if model_type == 'mamba':
218
+ model = MambaLM(model_config)
219
+ else:
220
+ model = GPT(model_config)
221
+ if auto_clip:
222
+ grad_clip = 0
223
+ config['grad_clip'] = 0
224
+ grad_norm_history = []
225
+ elif init_from == 'resume' or init_from == 'anneal':
226
+ print(f"Resuming training from {out_dir}")
227
+ if init_from == 'anneal':
228
+ ckpt_path = os.path.join(out_dir, anneal_checkpoint)
229
+ else:
230
+ ckpt_path = os.path.join(out_dir, 'ckpt.pt')
231
+ checkpoint = torch.load(ckpt_path, map_location=device)
232
+ model_config = checkpoint['model_args']
233
+ if model_type == 'mamba':
234
+ model = MambaLM(model_config)
235
+ else:
236
+ model = GPT(model_config)
237
+ state_dict = checkpoint['model']
238
+ # fix the keys of the state dictionary :(
239
+ # honestly no idea how checkpoints sometimes get this prefix, have to debug more
240
+ unwanted_prefix = '_orig_mod.'
241
+ for k,v in list(state_dict.items()):
242
+ if k.startswith(unwanted_prefix):
243
+ state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
244
+ model.load_state_dict(state_dict)
245
+ if 'effective_batch_size' not in checkpoint['config']:
246
+ print("Checkpoint was saved without `effective_batch_size`, assuming current value (will save with next checkpoint). This is used for correcting `iter_num` when the effetive batch size is changed.")
247
+ checkpoint['config']['effective_batch_size'] = effective_batch_size
248
+ iter_num = int(round(checkpoint['iter_num'] * (checkpoint['config']['effective_batch_size'] / effective_batch_size)))
249
+ if 'games_seen' in checkpoint:
250
+ games_seen = checkpoint['games_seen']
251
+ else:
252
+ games_seen = checkpoint['config']['effective_batch_size'] * checkpoint['iter_num']
253
+ checkpoint['games_seen'] = games_seen
254
+ print(f"Checkpoint was saved without `games_seen`, assuming checkpoint's effective batch size * iters (will save with next checkpoint). {games_seen}")
255
+ tokens_seen = checkpoint.get('tokens_seen', 0)
256
+ tokens_seen_padded = checkpoint.get('tokens_seen_padded', 0)
257
+ best_val_loss = checkpoint['best_val_loss']
258
+ print(f"Best val loss: {best_val_loss}")
259
+ if auto_clip:
260
+ grad_clip = checkpoint['config']['grad_clip']
261
+ config['grad_clip'] = grad_clip
262
+ #grad_norm_history = [t.item() if torch.is_tensor(t) else t for t in checkpoint.get('grad_norm_history', [])]
263
+ grad_norm_history = checkpoint.get('grad_norm_history', [])
264
+ if init_from == 'anneal':
265
+ print(f"\n\nANNEAL STARTING/RESUMING FROM ITERNUM: {iter_num} ({games_seen} games)\n\n")
266
+ anneal_start_iters = iter_num if 'anneal_start_iters' not in checkpoint else checkpoint['anneal_start_iters']
267
+ anneal_decay_iters = iter_num / 8 if 'anneal_decay_iters' not in checkpoint else checkpoint['anneal_decay_iters'] # / 9 is og, but going deeper on lr too (can always take earlier ckpt during anneal if it doesn't keep improving)... have used 6.75
268
+ print(anneal_start_iters)
269
+ print(anneal_decay_iters)
270
+ if 'anneal_start_iters' not in checkpoint:
271
+ grad_clip = 0
272
+ config['grad_clip'] = 0
273
+ grad_norm_history = []
274
+ print(f"Starting anneal. Resumed from {anneal_checkpoint}, will now decay learning rate for {anneal_decay_iters} / until iter_num {anneal_start_iters + anneal_decay_iters}.")
275
+ out_dir = anneal_dir
276
+ weight_decay = weight_decay / 12.5 # / 17.0
277
+ beta2 = np.sqrt(beta2) * beta2
278
+ auto_clip = True
279
+ grad_clip_percentile = 6.75
280
+ elif init_from.startswith('state-spaces'):
281
+ print(f"Initializing from Mamba pre-trained weights: {init_from}")
282
+ model = from_pretrained(init_from)
283
+ model_config = model.config
284
+ else:
285
+ raise ValueError("Invalid init_from value")
286
+
287
+ model.to(device)
288
+
289
+ print(f'Model with {sum([p.numel() for p in model.parameters()])} parameters loaded.')
290
+
291
+ # Optimizer and GradScaler
292
+ optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2))
293
+ scaler = torch.cuda.amp.GradScaler(enabled=dtype == 'float16')
294
+ if init_from == 'resume':
295
+ optimizer.load_state_dict(checkpoint['optimizer'])
296
+ checkpoint = None
297
+
298
+ # Compile the model if using PyTorch 2.0
299
+ if compile:
300
+ print("compiling the model... (takes a ~minute)")
301
+ model = torch.compile(model)
302
+
303
+ # Wrap model in DDP container if necessary
304
+ if ddp:
305
+ model = DDP(model, device_ids=[ddp_local_rank])
306
+
307
+
308
+ def batch_to_loss(sequences, max_length_in_batch):
309
+ if model_type == 'mamba':
310
+ logits = model(sequences[:, :-1]) # Forward pass, exclude last token for input
311
+ # Compute loss (assuming next token prediction task)
312
+ targets = sequences[:, 1:].reshape(-1) # Shifted by one for next token prediction
313
+ return F.cross_entropy(logits.view(-1, logits.size(-1)), targets)
314
+ else:
315
+ inputs = sequences[:, :-1]
316
+ targets = sequences[:, 1:].reshape(-1)
317
+ _, loss = model(inputs, targets)
318
+ return loss
319
+
320
+
321
+ @torch.no_grad()
322
+ def estimate_loss():
323
+ global tokens_seen, tokens_seen_padded
324
+ out = {}
325
+ model.eval()
326
+ tokens_seen_b4 = tokens_seen
327
+ tokens_seen_padded_b4 = tokens_seen_padded
328
+ for split in ['train']: #['train', 'val']:
329
+ losses = torch.zeros(eval_iters)
330
+ for k in range(eval_iters):
331
+ loss = batch_to_loss(*get_batch(split))
332
+ losses[k] = loss.item()
333
+
334
+ split = 'val' # Temporary hack
335
+ out[split] = losses.mean()
336
+ tokens_seen = tokens_seen_b4
337
+ tokens_seen_padded = tokens_seen_padded_b4
338
+ model.train()
339
+ return out
340
+
341
+
342
+ # WSD scheduler
343
+ def get_lr(it):
344
+ if init_from == 'anneal':
345
+ # Linear decay from max LR to min LR over (anneal_start_iters / 9) iters
346
+ decay_ratio = min(it - anneal_start_iters, anneal_decay_iters) / anneal_decay_iters
347
+ return learning_rate - decay_ratio * (learning_rate - min_lr)
348
+
349
+ if it < warmup_iters:
350
+ # Warmup
351
+ return learning_rate * it / warmup_iters
352
+
353
+ # Stable max LR
354
+ return learning_rate
355
+
356
+ # Logging setup
357
+ if wandb_log and master_process:
358
+ import wandb
359
+ wandb.init(project=wandb_project, name=wandb_run_name, config=config)
360
+
361
+ # Training loop
362
+ local_iter_num = 0 # Number of iterations in the lifetime of this process
363
+ last_crossed_multiple = 0
364
+ save_every_n_games = 150000
365
+ raw_model = model.module if ddp else model # Unwrap DDP container if needed
366
+
367
+ # initial save
368
+ if init_from == 'scratch':
369
+ checkpoint = {
370
+ 'model': raw_model.state_dict(),
371
+ 'optimizer': optimizer.state_dict(),
372
+ 'model_args': model_config,
373
+ 'iter_num': 0,
374
+ "games_seen": 0,
375
+ "tokens_seen": 0,
376
+ "tokens_seen_padded": 0,
377
+ 'best_val_loss': best_val_loss,
378
+ 'config': config,
379
+ }
380
+ checkpoint['grad_norm_history'] = grad_norm_history
381
+ print(f"saving checkpoint to {out_dir}\n")
382
+ torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))
383
+
384
+ t0 = time.time()
385
+ while True:
386
+ # Determine and set the learning rate for this iteration
387
+ lr = get_lr(iter_num) if decay_lr else learning_rate
388
+ for param_group in optimizer.param_groups:
389
+ param_group['lr'] = lr
390
+
391
+ # Evaluate the loss on train/val sets and write checkpoints
392
+ if iter_num % eval_interval == 0 and master_process and local_iter_num > 0:
393
+ torch.cuda.empty_cache()
394
+ losses = estimate_loss()
395
+ print(f"\ngame {games_seen} ({iter_num}, {(iter_num / max_iters)*100.0:.3f}%): 'val' loss {losses['val']:.4f}")
396
+ if auto_clip and len(grad_norm_history) >= grad_clip_start_size:
397
+ grad_clip_prev = grad_clip
398
+ grad_clip = np.percentile(grad_norm_history, grad_clip_percentile)
399
+ grad_clip = max(min(grad_clip, auto_clip_max), auto_clip_min)
400
+ # Transition between grad_clips smoothly, weighed to new value
401
+ grad_clip = (grad_clip*9.0 + grad_clip_prev*4.0) / 13.0
402
+ grad_clip = max(min(grad_clip, auto_clip_max), auto_clip_min) # should never actually clip here
403
+ config['grad_clip'] = grad_clip
404
+ print(f"Auto adjusted grad_clip to {grad_clip}")
405
+ torch.cuda.empty_cache()
406
+ if wandb_log:
407
+ wandb.log({
408
+ "etc/iter": iter_num,
409
+ "etc/games": games_seen,
410
+ "etc/tokens_seen": tokens_seen,
411
+ "etc/tokens_seen_padded": tokens_seen_padded,
412
+ "etc/grad_clip": grad_clip,
413
+ "etc/lr": lr,
414
+ "val/loss": losses['val'],
415
+
416
+ })
417
+ if losses['val'] < best_val_loss or always_save_checkpoint:
418
+ if iter_num > 0:
419
+ checkpoint = {
420
+ 'model': raw_model.state_dict(),
421
+ 'optimizer': optimizer.state_dict(),
422
+ 'model_args': model_config,
423
+ 'iter_num': iter_num,
424
+ "games_seen": games_seen,
425
+ "tokens_seen": tokens_seen,
426
+ "tokens_seen_padded": tokens_seen_padded,
427
+ 'best_val_loss': min(best_val_loss, losses['val']),
428
+ 'config': config,
429
+ }
430
+ checkpoint['grad_norm_history'] = grad_norm_history
431
+ if init_from == 'anneal':
432
+ checkpoint['anneal_start_iters'] = anneal_start_iters
433
+ checkpoint['anneal_decay_iters'] = anneal_decay_iters
434
+ print(f"saving checkpoint to {out_dir}\n")
435
+ torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))
436
+ current_nearest_multiple = (games_seen // save_every_n_games) * save_every_n_games
437
+ if losses['val'] < best_val_loss: # Temporary / only good after it's settled
438
+ best_val_loss = losses['val']
439
+ torch.save(checkpoint, os.path.join(out_dir, f'ckpt_{int(games_seen)}b.pt'))
440
+ elif current_nearest_multiple != last_crossed_multiple: # elif so we don't double up
441
+ last_crossed_multiple = current_nearest_multiple
442
+ torch.save(checkpoint, os.path.join(out_dir, f'ckpt_{int(games_seen)}.pt'))
443
+
444
+ if iter_num == 0 and eval_only:
445
+ break
446
+
447
+ # Forward and backward pass
448
+ for micro_step in range(gradient_accumulation_steps):
449
+ if ddp:
450
+ model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1)
451
+
452
+ sequences, max_length_in_batch = get_batch('train') # Fetch the training data
453
+ with ctx:
454
+ loss = batch_to_loss(sequences, max_length_in_batch)
455
+ loss = loss / gradient_accumulation_steps
456
+
457
+ scaler.scale(loss).backward()
458
+ #print('.', end='')
459
+
460
+ # clip the gradient
461
+ if grad_clip != 0.0 or auto_clip:
462
+ scaler.unscale_(optimizer)
463
+ total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip if grad_clip != 0.0 else 999.9) # The 0 check is for auto_clip enabled but not enough history
464
+ grad_norm_history.append(total_norm.item())
465
+ grad_norm_history = grad_norm_history[-grad_clip_max_size:]
466
+
467
+ # step the optimizer and scaler if training in fp16
468
+ scaler.step(optimizer)
469
+ scaler.update()
470
+ # flush the gradients as soon as we can, no need for this memory anymore
471
+ optimizer.zero_grad(set_to_none=True)
472
+ torch.cuda.empty_cache()
473
+
474
+ # timing and logging
475
+ t1 = time.time()
476
+ dt = t1 - t0
477
+ t0 = t1
478
+ if iter_num % log_interval == 0 and master_process:
479
+ # get loss as float. note: this is a CPU-GPU sync point
480
+ # scale up to undo the division above, approximating the true total loss (exact would have been a sum)
481
+ lossf = loss.item() * gradient_accumulation_steps
482
+ print(f"game {games_seen} ({iter_num}, {(iter_num / max_iters)*100.0:.3f}%): loss {lossf:.4f}, time {dt*1000:.2f}ms")
483
+ if wandb_log:
484
+ wandb.log({
485
+ "etc/iter": iter_num,
486
+ "etc/games": games_seen,
487
+ "etc/tokens_seen": tokens_seen,
488
+ "etc/tokens_seen_padded": tokens_seen_padded,
489
+ "etc/grad_norm": grad_norm_history[-1] if grad_norm_history else 0,
490
+ "etc/lr": lr,
491
+ "train/loss": lossf,
492
+ })
493
+ iter_num += 1
494
+ local_iter_num += 1
495
+ games_seen += effective_batch_size
496
+
497
+ # termination conditions
498
+ if iter_num > max_iters and not init_from == 'anneal': # max iters is for auto-stopping end of stable phase
499
+ checkpoint = {
500
+ 'model': raw_model.state_dict(),
501
+ 'optimizer': optimizer.state_dict(),
502
+ 'model_args': model_config,
503
+ 'iter_num': iter_num,
504
+ "games_seen": games_seen,
505
+ "tokens_seen": tokens_seen,
506
+ "tokens_seen_padded": tokens_seen_padded,
507
+ 'best_val_loss': best_val_loss,
508
+ 'config': config,
509
+ }
510
+ checkpoint['grad_norm_history'] = grad_norm_history
511
+ if init_from == 'anneal':
512
+ checkpoint['anneal_start_iters'] = anneal_start_iters
513
+ checkpoint['anneal_decay_iters'] = anneal_decay_iters
514
+ print(f"Max_iters reached. Saving pre-anneal checkpoint to {anneal_checkpoint}")
515
+ torch.save(checkpoint, os.path.join(out_dir, anneal_checkpoint))
516
+ break
517
+ if init_from == 'anneal' and iter_num >= anneal_start_iters + anneal_decay_iters:
518
+ checkpoint = {
519
+ 'model': raw_model.state_dict(),
520
+ 'optimizer': optimizer.state_dict(),
521
+ 'model_args': model_config,
522
+ 'iter_num': iter_num,
523
+ "games_seen": games_seen,
524
+ "tokens_seen": tokens_seen,
525
+ "tokens_seen_padded": tokens_seen_padded,
526
+ 'best_val_loss': best_val_loss,
527
+ 'config': config,
528
+ }
529
+ checkpoint['grad_norm_history'] = grad_norm_history
530
+ if init_from == 'anneal':
531
+ checkpoint['anneal_start_iters'] = anneal_start_iters
532
+ checkpoint['anneal_decay_iters'] = anneal_decay_iters
533
+ print(f"Anneal complete. Saving checkpoint to {out_dir}")
534
+ torch.save(checkpoint, os.path.join(out_dir, 'anneal_complete.pt'))
535
+ break
536
+
537
+
538
+
539
+ if ddp:
540
+ destroy_process_group()
541
+
chess-mamba-vs-xformer/xformer.py ADDED
@@ -0,0 +1,330 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Full definition of a GPT Language Model, all of it in this single file.
3
+ References:
4
+ 1) the official GPT-2 TensorFlow implementation released by OpenAI:
5
+ https://github.com/openai/gpt-2/blob/master/src/model.py
6
+ 2) huggingface/transformers PyTorch implementation:
7
+ https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py
8
+ """
9
+
10
+ import math
11
+ import inspect
12
+ from dataclasses import dataclass
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ from torch.nn import functional as F
17
+
18
+ class LayerNorm(nn.Module):
19
+ """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """
20
+
21
+ def __init__(self, ndim, bias):
22
+ super().__init__()
23
+ self.weight = nn.Parameter(torch.ones(ndim))
24
+ self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
25
+
26
+ def forward(self, input):
27
+ return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
28
+
29
+ class CausalSelfAttention(nn.Module):
30
+
31
+ def __init__(self, config):
32
+ super().__init__()
33
+ assert config.n_embd % config.n_head == 0
34
+ # key, query, value projections for all heads, but in a batch
35
+ self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
36
+ # output projection
37
+ self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
38
+ # regularization
39
+ self.attn_dropout = nn.Dropout(config.dropout)
40
+ self.resid_dropout = nn.Dropout(config.dropout)
41
+ self.n_head = config.n_head
42
+ self.n_embd = config.n_embd
43
+ self.dropout = config.dropout
44
+ # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
45
+ self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
46
+ if not self.flash:
47
+ print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
48
+ # causal mask to ensure that attention is only applied to the left in the input sequence
49
+ self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
50
+ .view(1, 1, config.block_size, config.block_size))
51
+
52
+ def forward(self, x):
53
+ B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
54
+
55
+ # calculate query, key, values for all heads in batch and move head forward to be the batch dim
56
+ q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
57
+ k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
58
+ q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
59
+ v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
60
+
61
+ # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
62
+ if self.flash:
63
+ # efficient attention using Flash Attention CUDA kernels
64
+ y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)
65
+ else:
66
+ # manual implementation of attention
67
+ att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
68
+ att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
69
+ att = F.softmax(att, dim=-1)
70
+ att = self.attn_dropout(att)
71
+ y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
72
+ y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
73
+
74
+ # output projection
75
+ y = self.resid_dropout(self.c_proj(y))
76
+ return y
77
+
78
+ class MLP(nn.Module):
79
+
80
+ def __init__(self, config):
81
+ super().__init__()
82
+ self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
83
+ self.gelu = nn.GELU()
84
+ self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
85
+ self.dropout = nn.Dropout(config.dropout)
86
+
87
+ def forward(self, x):
88
+ x = self.c_fc(x)
89
+ x = self.gelu(x)
90
+ x = self.c_proj(x)
91
+ x = self.dropout(x)
92
+ return x
93
+
94
+ class Block(nn.Module):
95
+
96
+ def __init__(self, config):
97
+ super().__init__()
98
+ self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
99
+ self.attn = CausalSelfAttention(config)
100
+ self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
101
+ self.mlp = MLP(config)
102
+
103
+ def forward(self, x):
104
+ x = x + self.attn(self.ln_1(x))
105
+ x = x + self.mlp(self.ln_2(x))
106
+ return x
107
+
108
+ @dataclass
109
+ class GPTConfig:
110
+ block_size: int = 1024
111
+ vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
112
+ n_layer: int = 12
113
+ n_head: int = 12
114
+ n_embd: int = 768
115
+ dropout: float = 0.0
116
+ bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
117
+
118
+ class GPT(nn.Module):
119
+
120
+ def __init__(self, config):
121
+ super().__init__()
122
+ assert config.vocab_size is not None
123
+ assert config.block_size is not None
124
+ self.config = config
125
+
126
+ self.transformer = nn.ModuleDict(dict(
127
+ wte = nn.Embedding(config.vocab_size, config.n_embd),
128
+ wpe = nn.Embedding(config.block_size, config.n_embd),
129
+ drop = nn.Dropout(config.dropout),
130
+ h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
131
+ ln_f = LayerNorm(config.n_embd, bias=config.bias),
132
+ ))
133
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
134
+ # with weight tying when using torch.compile() some warnings get generated:
135
+ # "UserWarning: functional_call was passed multiple values for tied weights.
136
+ # This behavior is deprecated and will be an error in future versions"
137
+ # not 100% sure what this is, so far seems to be harmless. TODO investigate
138
+ self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying
139
+
140
+ # init all weights
141
+ self.apply(self._init_weights)
142
+ # apply special scaled init to the residual projections, per GPT-2 paper
143
+ for pn, p in self.named_parameters():
144
+ if pn.endswith('c_proj.weight'):
145
+ torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))
146
+
147
+ # report number of parameters
148
+ print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))
149
+
150
+ def get_num_params(self, non_embedding=True):
151
+ """
152
+ Return the number of parameters in the model.
153
+ For non-embedding count (default), the position embeddings get subtracted.
154
+ The token embeddings would too, except due to the parameter sharing these
155
+ params are actually used as weights in the final layer, so we include them.
156
+ """
157
+ n_params = sum(p.numel() for p in self.parameters())
158
+ if non_embedding:
159
+ n_params -= self.transformer.wpe.weight.numel()
160
+ return n_params
161
+
162
+ def _init_weights(self, module):
163
+ if isinstance(module, nn.Linear):
164
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
165
+ if module.bias is not None:
166
+ torch.nn.init.zeros_(module.bias)
167
+ elif isinstance(module, nn.Embedding):
168
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
169
+
170
+ def forward(self, idx, targets=None):
171
+ device = idx.device
172
+ b, t = idx.size()
173
+ assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
174
+ pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)
175
+
176
+ # forward the GPT model itself
177
+ tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
178
+ pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
179
+ x = self.transformer.drop(tok_emb + pos_emb)
180
+ for block in self.transformer.h:
181
+ x = block(x)
182
+ x = self.transformer.ln_f(x)
183
+
184
+ if targets is not None:
185
+ # if we are given some desired targets also calculate the loss
186
+ logits = self.lm_head(x)
187
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
188
+ else:
189
+ # inference-time mini-optimization: only forward the lm_head on the very last position
190
+ logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
191
+ loss = None
192
+
193
+ return logits, loss
194
+
195
+ def crop_block_size(self, block_size):
196
+ # model surgery to decrease the block size if necessary
197
+ # e.g. we may load the GPT2 pretrained model checkpoint (block size 1024)
198
+ # but want to use a smaller block size for some smaller, simpler model
199
+ assert block_size <= self.config.block_size
200
+ self.config.block_size = block_size
201
+ self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:block_size])
202
+ for block in self.transformer.h:
203
+ if hasattr(block.attn, 'bias'):
204
+ block.attn.bias = block.attn.bias[:,:,:block_size,:block_size]
205
+
206
+ @classmethod
207
+ def from_pretrained(cls, model_type, override_args=None):
208
+ assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
209
+ override_args = override_args or {} # default to empty dict
210
+ # only dropout can be overridden see more notes below
211
+ assert all(k == 'dropout' for k in override_args)
212
+ from transformers import GPT2LMHeadModel
213
+ print("loading weights from pretrained gpt: %s" % model_type)
214
+
215
+ # n_layer, n_head and n_embd are determined from model_type
216
+ config_args = {
217
+ 'gpt2': dict(n_layer=12, n_head=12, n_embd=768), # 124M params
218
+ 'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024), # 350M params
219
+ 'gpt2-large': dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
220
+ 'gpt2-xl': dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
221
+ }[model_type]
222
+ print("forcing vocab_size=50257, block_size=1024, bias=True")
223
+ config_args['vocab_size'] = 50257 # always 50257 for GPT model checkpoints
224
+ config_args['block_size'] = 1024 # always 1024 for GPT model checkpoints
225
+ config_args['bias'] = True # always True for GPT model checkpoints
226
+ # we can override the dropout rate, if desired
227
+ if 'dropout' in override_args:
228
+ print(f"overriding dropout rate to {override_args['dropout']}")
229
+ config_args['dropout'] = override_args['dropout']
230
+ # create a from-scratch initialized minGPT model
231
+ config = GPTConfig(**config_args)
232
+ model = GPT(config)
233
+ sd = model.state_dict()
234
+ sd_keys = sd.keys()
235
+ sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')] # discard this mask / buffer, not a param
236
+
237
+ # init a huggingface/transformers model
238
+ model_hf = GPT2LMHeadModel.from_pretrained(model_type)
239
+ sd_hf = model_hf.state_dict()
240
+
241
+ # copy while ensuring all of the parameters are aligned and match in names and shapes
242
+ sd_keys_hf = sd_hf.keys()
243
+ sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')] # ignore these, just a buffer
244
+ sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')] # same, just the mask (buffer)
245
+ transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
246
+ # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear
247
+ # this means that we have to transpose these weights when we import them
248
+ assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
249
+ for k in sd_keys_hf:
250
+ if any(k.endswith(w) for w in transposed):
251
+ # special treatment for the Conv1D weights we need to transpose
252
+ assert sd_hf[k].shape[::-1] == sd[k].shape
253
+ with torch.no_grad():
254
+ sd[k].copy_(sd_hf[k].t())
255
+ else:
256
+ # vanilla copy over the other parameters
257
+ assert sd_hf[k].shape == sd[k].shape
258
+ with torch.no_grad():
259
+ sd[k].copy_(sd_hf[k])
260
+
261
+ return model
262
+
263
+ def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
264
+ # start with all of the candidate parameters
265
+ param_dict = {pn: p for pn, p in self.named_parameters()}
266
+ # filter out those that do not require grad
267
+ param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
268
+ # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.
269
+ # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
270
+ decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
271
+ nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
272
+ optim_groups = [
273
+ {'params': decay_params, 'weight_decay': weight_decay},
274
+ {'params': nodecay_params, 'weight_decay': 0.0}
275
+ ]
276
+ num_decay_params = sum(p.numel() for p in decay_params)
277
+ num_nodecay_params = sum(p.numel() for p in nodecay_params)
278
+ print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
279
+ print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
280
+ # Create AdamW optimizer and use the fused version if it is available
281
+ fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
282
+ use_fused = fused_available and device_type == 'cuda'
283
+ extra_args = dict(fused=True) if use_fused else dict()
284
+ optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
285
+ print(f"using fused AdamW: {use_fused}")
286
+
287
+ return optimizer
288
+
289
+ def estimate_mfu(self, fwdbwd_per_iter, dt):
290
+ """ estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS """
291
+ # first estimate the number of flops we do per iteration.
292
+ # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311
293
+ N = self.get_num_params()
294
+ cfg = self.config
295
+ L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head, cfg.block_size
296
+ flops_per_token = 6*N + 12*L*H*Q*T
297
+ flops_per_fwdbwd = flops_per_token * T
298
+ flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
299
+ # express our flops throughput as ratio of A100 bfloat16 peak flops
300
+ flops_achieved = flops_per_iter * (1.0/dt) # per second
301
+ flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS
302
+ mfu = flops_achieved / flops_promised
303
+ return mfu
304
+
305
+ @torch.no_grad()
306
+ def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
307
+ """
308
+ Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
309
+ the sequence max_new_tokens times, feeding the predictions back into the model each time.
310
+ Most likely you'll want to make sure to be in model.eval() mode of operation for this.
311
+ """
312
+ for _ in range(max_new_tokens):
313
+ # if the sequence context is growing too long we must crop it at block_size
314
+ idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
315
+ # forward the model to get the logits for the index in the sequence
316
+ logits, _ = self(idx_cond)
317
+ # pluck the logits at the final step and scale by desired temperature
318
+ logits = logits[:, -1, :] / temperature
319
+ # optionally crop the logits to only the top k options
320
+ if top_k is not None:
321
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
322
+ logits[logits < v[:, [-1]]] = -float('Inf')
323
+ # apply softmax to convert logits to (normalized) probabilities
324
+ probs = F.softmax(logits, dim=-1)
325
+ # sample from the distribution
326
+ idx_next = torch.multinomial(probs, num_samples=1)
327
+ # append sampled index to the running sequence and continue
328
+ idx = torch.cat((idx, idx_next), dim=1)
329
+
330
+ return idx