| | from collections import deque
|
| | import random
|
| | import torch
|
| | import torch
|
| | from engine import GameState
|
| | from move_finder import find_best_move_shallow
|
| | from infer_nnue import gs_to_nnue_features
|
| | from nnue_model import NNUE
|
| | from tqdm import tqdm
|
| | from infer_nnue import NNUEInfer
|
# Fixed length of the feature vector handed to the model for each position.
NNUE_FEATURES = 32


def pad_features(feats, size=None, pad_value=0):
    """Return *feats* as a new list of exactly *size* entries.

    Shorter inputs are right-padded with *pad_value*; longer inputs are
    truncated. Any iterable is accepted (list, tuple, generator, ...) and
    the result is always a fresh list, so the caller's object is never
    aliased or mutated.

    Args:
        feats: Iterable of feature values for one position.
        size: Target length. Defaults to ``NNUE_FEATURES`` (looked up at
            call time).
        pad_value: Filler appended to short inputs. Defaults to ``0``.

    Returns:
        A list with exactly ``size`` elements.

    NOTE(review): the default filler ``0`` cannot be told apart from a
    genuine feature value of 0 — confirm the model treats 0 as padding.
    """
    if size is None:
        size = NNUE_FEATURES
    head = list(feats)[:size]
    return head + [pad_value] * (size - len(head))
|
| |
|
| | import pickle
|
| |
|
def _iter_pickled_items(path):
    """Yield every item of every pickled chunk stored back-to-back in *path*."""
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only point this at dataset files produced by a trusted pipeline.
    with open(path, "rb") as f:
        while True:
            # Keep the try body minimal: EOFError from pickle.load is the
            # only expected end-of-stream signal.
            try:
                chunk = pickle.load(f)
            except EOFError:
                return
            yield from chunk


def load_pgn_dataset(path):
    """Load a pickled position stream and group it into trajectories.

    The file is read as a sequence of pickled iterables ("chunks"). Items
    from all chunks are concatenated in order; each item must be
    subscriptable with an ``"stm"`` key. A trajectory may span chunk
    boundaries.

    Splitting rule: after an item is appended, if the trajectory holds more
    than one item and the last two items have different ``"stm"`` values,
    the trajectory (including the item just appended) is closed and a new
    one is started. Whatever is left when the file ends is emitted as a
    final trajectory.

    NOTE(review): because a trajectory is closed every time ``stm`` changes
    between neighbours, a stream in which ``stm`` alternates on every item
    is cut into 2-item trajectories. If trajectories are meant to be whole
    games, this condition should be checked against how the dataset was
    written (e.g. an explicit game marker, or splitting where alternation
    breaks). The rule is kept as-is here.

    Args:
        path: Path of the pickle file to read.

    Returns:
        A list of trajectories, each a non-empty list of items.
    """
    trajectories = []
    current_traj = []

    for item in _iter_pickled_items(path):
        current_traj.append(item)

        stm_changed = (
            len(current_traj) > 1
            and current_traj[-1]["stm"] != current_traj[-2]["stm"]
        )
        if stm_changed:
            trajectories.append(current_traj)
            current_traj = []

    # Flush the unterminated tail so no item is dropped.
    if current_traj:
        trajectories.append(current_traj)

    return trajectories
|
| |
|
| |
|
@torch.no_grad()
def td_targets_from_traj(model, traj, gamma=0.99):
    """Compute one bootstrapped TD target per position of a trajectory.

    The model is evaluated once on the whole trajectory (no gradients are
    recorded). For every position except the last, the target is
    ``-gamma * value(next position)``; the last position's target is its
    own predicted value. All targets are clamped to ``[-1.0, 1.0]``.

    NOTE(review): negating the successor value presumes the model scores a
    position from the side-to-move's point of view and that consecutive
    items alternate sides — confirm against the dataset. No game outcome
    enters the targets; the final position bootstraps from itself.

    Args:
        model: Callable ``model(feats, stm)`` taking two long tensors and
            returning one value per row.
        traj: Sequence of items with ``"features"`` and ``"stm"`` keys.
        gamma: Discount applied to the successor value.

    Returns:
        A list of floats, one per item of *traj*. An empty trajectory
        yields ``[]`` and a single-item trajectory yields ``[0.0]``.
    """
    # Empty input used to fall through and fail inside the tensor code.
    if not traj:
        return []
    if len(traj) == 1:
        return [0.0]

    # Put the inputs on the same device as the model instead of assuming
    # CUDA; fall back to "cuda" when the model exposes no parameters.
    params = getattr(model, "parameters", None)
    first = next(iter(params()), None) if callable(params) else None
    device = first.device if first is not None else "cuda"

    feats = torch.tensor(
        [pad_features(x["features"]) for x in traj],
        dtype=torch.long,
        device=device,
    )
    stm = torch.tensor([x["stm"] for x in traj], dtype=torch.long, device=device)

    values = model(feats, stm).view(-1)

    targets = torch.empty_like(values)
    targets[:-1] = -gamma * values[1:]
    # No explicit detach needed: the whole function runs under no_grad.
    targets[-1] = values[-1]

    return targets.clamp(-1.0, 1.0).cpu().tolist()
|
| |
|
| |
|
| |
|
| | from collections import deque
|
| | import random
|
| |
|
class ReplayBuffer:
    """Bounded FIFO store of ``(f, stm, t)`` training tuples.

    Once ``capacity`` entries are held, each new entry pushes out the
    oldest one.
    """

    def __init__(self, capacity=300_000):
        """Create an empty buffer holding at most *capacity* entries."""
        # deque(iterable, maxlen): the bounded deque does the eviction.
        self.buf = deque([], capacity)

    def add(self, f, stm, t):
        """Append one ``(f, stm, t)`` tuple, evicting the oldest if full."""
        entry = (f, stm, t)
        self.buf.append(entry)

    def sample(self, n):
        """Return *n* distinct stored entries chosen at random.

        ``random.sample`` raises ``ValueError`` if *n* exceeds the number
        of stored entries.
        """
        population = self.buf
        return random.sample(population, n)

    def __len__(self):
        """Return the number of entries currently stored."""
        stored = len(self.buf)
        return stored
|
| |
|
| |
|
def train_from_replay(model, optimizer, replay, batch_size):
    """Run one optimisation step on a batch sampled from *replay*.

    Nothing happens while the buffer holds fewer than *batch_size* entries.
    Otherwise a batch of ``(features, stm, target)`` tuples is sampled,
    the model's predictions are regressed onto the stored targets with a
    smooth-L1 loss, and a single optimizer step is taken.

    Args:
        model: Callable ``model(feats, stm)`` taking two long tensors and
            returning one prediction per row.
        optimizer: Optimizer exposing ``zero_grad(set_to_none=...)`` and
            ``step()``, set up over the model's parameters.
        replay: Object supporting ``len()`` and ``sample(n)``, where
            ``sample`` returns 3-tuples ``(features, stm, target)``.
        batch_size: Number of entries to sample for this step.

    Returns:
        The batch loss as a detached 0-dim tensor (no host sync is forced),
        or ``None`` when the buffer was too small and no step was taken.
    """
    if len(replay) < batch_size:
        return None

    feats, stm, targets = zip(*replay.sample(batch_size))

    # Put the batch on the same device as the model instead of assuming
    # CUDA; fall back to "cuda" when the model exposes no parameters.
    params = getattr(model, "parameters", None)
    first = next(iter(params()), None) if callable(params) else None
    device = first.device if first is not None else "cuda"

    feats = torch.tensor(feats, dtype=torch.long, device=device)
    stm = torch.tensor(stm, dtype=torch.long, device=device)
    targ = torch.tensor(targets, dtype=torch.float, device=device)

    preds = model(feats, stm).view(-1)
    loss = torch.nn.functional.smooth_l1_loss(preds, targ)

    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    return loss.detach()
|
| |
|
| |
|
| | from tqdm import tqdm
|
# ---------------------------------------------------------------------------
# TD fine-tuning script.
#
# NOTE(review): this runs at import time; there is no
# `if __name__ == "__main__":` guard. It is kept at module level so that
# `model`, `optimizer`, `replay` and `trajectories` stay importable names.
# ---------------------------------------------------------------------------

# Training hyper-parameters (previously inline literals).
NUM_EPOCHS = 3
UPDATES_PER_TRAJECTORY = 3
BATCH_SIZE = 512
LEARNING_RATE = 5e-5

device = "cuda"
model = NNUE().to(device)
# map_location makes the load independent of the device the checkpoint was
# saved from; weights_only restricts unpickling to tensor data.
model.load_state_dict(
    torch.load("nnue_model.pt", map_location=device, weights_only=True)
)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

replay = ReplayBuffer()
trajectories = load_pgn_dataset("nnue_dataset.pkl")

for epoch in range(NUM_EPOCHS):
    print(f"Epoch {epoch}")

    for traj in tqdm(trajectories):
        # Trajectories with fewer than two positions are skipped.
        if len(traj) < 2:
            continue

        # Targets come from the current model and are stored as plain
        # floats, so they are not refreshed once they sit in the buffer.
        targets = td_targets_from_traj(model, traj)

        for x, t in zip(traj, targets):
            replay.add(pad_features(x["features"]), x["stm"], t)

        for _ in range(UPDATES_PER_TRAJECTORY):
            train_from_replay(model, optimizer, replay, batch_size=BATCH_SIZE)

    # Written after every epoch to the same path, so the file always holds
    # the most recently completed epoch.
    torch.save(model.state_dict(), "nnue_model_td.pt")
|
| |
|