Spaces:
Running
Running
import chess | |
import json | |
import numpy as np | |
import os | |
import random | |
import torch | |
import torch.nn as nn | |
from flask import Flask, render_template, request, jsonify, send_from_directory | |
app = Flask(__name__) | |
app.config['SEND_FILE_MAX_AGE_DEFAULT'] = 0 | |
PIECE_TO_PLANE = { | |
'P': 0, 'N': 1, 'B': 2, 'R': 3, 'Q': 4, 'K': 5, | |
'p': 6, 'n': 7, 'b': 8, 'r': 9, 'q': 10, 'k': 11 | |
} | |
CASTLING_INDICES = {'K': 12, 'Q': 13, 'k': 14, 'q': 15} | |
def fen_to_tensor_perspective(fen: str) -> torch.Tensor: | |
board = chess.Board(fen) | |
tensor = np.zeros((17, 8, 8), dtype=np.float32) | |
if board.turn == chess.WHITE: | |
piece_to_plane = PIECE_TO_PLANE | |
flip = False | |
else: | |
piece_to_plane = { | |
'P': 6, 'N': 7, 'B': 8, 'R': 9, 'Q': 10, 'K': 11, | |
'p': 0, 'n': 1, 'b': 2, 'r': 3, 'q': 4, 'k': 5 | |
} | |
flip = True | |
parts = fen.split() | |
rows = parts[0].split('/') | |
for r, row in enumerate(rows): | |
f = 0 | |
for c in row: | |
if c.isdigit(): | |
f += int(c) | |
else: | |
plane = piece_to_plane[c] | |
tensor[plane, r, f] = 1.0 | |
f += 1 | |
if 'K' in parts[2]: tensor[12, :, :] = 1.0 | |
if 'Q' in parts[2]: tensor[13, :, :] = 1.0 | |
if 'k' in parts[2]: tensor[14, :, :] = 1.0 | |
if 'q' in parts[2]: tensor[15, :, :] = 1.0 | |
if board.ep_square is not None: | |
ep_rank = 7 - (board.ep_square // 8) | |
ep_file = board.ep_square % 8 | |
tensor[16, ep_rank, ep_file] = 1.0 | |
if flip: | |
tensor = np.flip(tensor, axis=(1, 2)).copy() | |
return torch.tensor(tensor, dtype=torch.float32) | |
class ResidualBlock(nn.Module): | |
def __init__(self, channels): | |
super().__init__() | |
self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False) | |
self.bn1 = nn.BatchNorm2d(channels) | |
self.relu = nn.ReLU(inplace=True) | |
self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False) | |
self.bn2 = nn.BatchNorm2d(channels) | |
def forward(self, x): | |
identity = x | |
out = self.relu(self.bn1(self.conv1(x))) | |
out = self.bn2(self.conv2(out)) | |
out += identity | |
return self.relu(out) | |
class ChessPolicyValueNet(nn.Module): | |
def __init__(self, input_planes, policy_size, num_blocks=12, channels=128): | |
super().__init__() | |
self.conv_in = nn.Conv2d(input_planes, channels, kernel_size=3, padding=1, bias=False) | |
self.bn_in = nn.BatchNorm2d(channels) | |
self.res_blocks = nn.Sequential(*[ResidualBlock(channels) for _ in range(num_blocks)]) | |
self.policy_conv = nn.Conv2d(channels, 2, kernel_size=1) | |
self.policy_bn = nn.BatchNorm2d(2) | |
self.policy_fc = nn.Linear(2 * 8 * 8, policy_size) | |
self.value_conv = nn.Conv2d(channels, 1, kernel_size=1) | |
self.value_bn = nn.BatchNorm2d(1) | |
self.value_fc1 = nn.Linear(8 * 8, 256) | |
self.value_fc2 = nn.Linear(256, 3) | |
self.relu = nn.ReLU(inplace=True) | |
def forward(self, x): | |
x = self.relu(self.bn_in(self.conv_in(x))) | |
x = self.res_blocks(x) | |
p = self.relu(self.policy_bn(self.policy_conv(x))).view(x.size(0), -1) | |
p = self.policy_fc(p) | |
v = self.relu(self.value_bn(self.value_conv(x))).view(x.size(0), -1) | |
v = self.relu(self.value_fc1(v)) | |
v = self.value_fc2(v) | |
return p, v | |
with open("move_to_idx.json", "r") as f: | |
move_to_idx = json.load(f) | |
idx_to_move = {int(i): m for m, i in move_to_idx.items()} | |
book = {} | |
for filename in ["book1.json", "book2.json", "book3.json", "book4.json"]: | |
with open(filename, "r") as f: | |
for line in f: | |
entry = json.loads(line) | |
book[entry["fen"]] = entry["moves"] | |
def get_positional_fen(fen: str) -> str: | |
return " ".join(fen.split(" ")[:4]) | |
def get_book_move(fen: str): | |
pos_fen = get_positional_fen(fen) | |
moves = book.get(pos_fen) | |
if not moves: | |
return None | |
uci_moves = [m[0] for m in moves] | |
weights = [m[1] for m in moves] | |
total = sum(weights) | |
if total == 0: | |
# Should never happen methinks | |
return None | |
probabilities = [w / total for w in weights] | |
return random.choices(uci_moves, weights=probabilities, k=1)[0] | |
model = ChessPolicyValueNet(input_planes=17, policy_size=len(move_to_idx)) | |
model.load_state_dict(torch.load("model.pt", map_location="cpu")['model_state_dict']) | |
model.eval() | |
class MCTSNode: | |
def __init__(self, board, parent=None, move=None, prior=0.0): | |
self.board = board | |
self.parent = parent | |
self.move = move | |
self.children = {} | |
self.visits = 0 | |
self.value_sum = 0.0 | |
self.prior = prior | |
def value(self): | |
return self.value_sum / self.visits if self.visits > 0 else 0.0 | |
def expand_node(node, top_k=10): | |
legal_moves = list(node.board.legal_moves) | |
tensor = fen_to_tensor_perspective(node.board.fen()).unsqueeze(0) | |
with torch.no_grad(): | |
logits, value = model(tensor) | |
policy = torch.softmax(logits[0], dim=0) | |
move_priors = [] | |
for move in legal_moves: | |
uci = move.uci() | |
if uci in move_to_idx: | |
prior = policy[move_to_idx[uci]].item() | |
move_priors.append((move, prior)) | |
# Keep only the top K moves by prior | |
move_priors = sorted(move_priors, key=lambda x: x[1], reverse=True)[:top_k] | |
for move, prior in move_priors: | |
board_copy = node.board.copy() | |
board_copy.push(move) | |
node.children[move.uci()] = MCTSNode(board_copy, parent=node, move=move, prior=prior) | |
return value[0] | |
def select_child(node, c_puct=1.0): | |
total_visits = sum(child.visits for child in node.children.values()) + 1 | |
best_score = -float('inf') | |
best_child = None | |
for child in node.children.values(): | |
u = child.value() + c_puct * child.prior * (total_visits ** 0.5 / (1 + child.visits)) | |
if u > best_score: | |
best_score = u | |
best_child = child | |
return best_child | |
def backpropagate(node, value): | |
current = node | |
win_prob = torch.softmax(value, dim=0)[0].item() | |
while current: | |
current.visits += 1 | |
current.value_sum += win_prob | |
current = current.parent | |
def run_mcts_or_play_book_move(root_board, simulations=1000, top_k=10, fen=None): | |
best = get_book_move(fen) | |
if best is None: | |
root = MCTSNode(root_board) | |
expand_node(root, top_k=top_k) | |
for _ in range(simulations): | |
node = root | |
while node.children: | |
node = select_child(node) | |
if node.board.is_game_over(): | |
outcome = node.board.result() | |
if outcome == "1-0": | |
value = torch.tensor([1.0, 0.0, 0.0]) | |
elif outcome == "0-1": | |
value = torch.tensor([0.0, 0.0, 1.0]) | |
else: | |
value = torch.tensor([0.0, 1.0, 0.0]) | |
else: | |
value = expand_node(node, top_k=top_k) | |
backpropagate(node, value) | |
best = max(root.children.values(), key=lambda c: c.visits) | |
return best.move.uci() | |
return best | |
def index(): | |
return render_template("index.html") | |
def start_game(): | |
board = chess.Board() | |
computer_color = random.choice(['white', 'black']) | |
move = None | |
if computer_color == 'white': | |
move = run_mcts_or_play_book_move(board, simulations=100, top_k=10, fen=board.fen()) | |
board.push(chess.Move.from_uci(move)) | |
return jsonify({ | |
"fen": board.fen(), | |
"computer_color": computer_color, | |
"move": move | |
}) | |
def model_move(): | |
data = request.get_json() | |
fen = data["fen"] | |
board = chess.Board(fen) | |
if board.is_game_over(): | |
return jsonify({"error": "Game is over."}) | |
best_move = run_mcts_or_play_book_move(board, simulations=100, top_k=10, fen=board.fen()) | |
board.push(chess.Move.from_uci(best_move)) | |
return jsonify({ | |
"fen": board.fen(), | |
"move": best_move | |
}) | |
def favicon(): | |
return send_from_directory(os.path.join(app.root_path, 'static'), 'favicon.ico', mimetype='image/vnd.microsoft.icon') | |
if __name__ == "__main__": | |
app.run(debug=True, host="0.0.0.0", port=7860) | |