# Spaces: Running  (page-header residue from the hosting site; not part of the module)
import chess | |
import torch | |
import numpy as np | |
from o2_model import board_to_tensor | |
class MCTSNode:
    """A single position in the search tree together with its statistics.

    Attributes:
        board: Private copy of the position this node represents.
        parent: Node this one was reached from (``None`` by default).
        move: Move that led from ``parent`` to this node (``None`` by default).
        children: Dict of child nodes, empty until the node is expanded.
        N: Visit count.
        W: Total (accumulated) value.
        Q: Mean value.
        P: Prior probability.
    """

    def __init__(self, board, parent=None, move=None):
        """Create an unvisited node for ``board``.

        Args:
            board: Position object exposing ``copy()``; the node stores the
                copy, so later changes to the caller's object do not affect it.
            parent: Optional parent node.
            move: Optional move that produced this position.
        """
        # Links into the tree.
        self.parent, self.move = parent, move
        self.children = {}

        # Snapshot of the position owned by this node.
        self.board = board.copy()

        # Statistics all start at zero: visits, total value, mean value, prior.
        self.N = self.W = self.Q = self.P = 0
class MCTS:
    """Monte Carlo tree search guided by a policy/value model (PUCT selection).

    Value convention used throughout the search: the number added to a node's
    ``W`` is the outcome from the point of view of the player who *moved into*
    that node. This is what lets a parent pick the child with the highest
    ``Q`` and what the sign flip in back-propagation maintains.
    """

    # Length of the policy vector the model is expected to emit.
    POLICY_SIZE = 4672

    def __init__(self, model, simulations=100, c_puct=1.5):
        """Store the search configuration.

        Args:
            model: Callable taking a batched board tensor and returning
                ``(policy_logits, value)``.
            simulations: Number of select/expand/back-propagate iterations
                performed by each ``run`` call.
            c_puct: Exploration constant in the PUCT score.
        """
        self.model = model
        self.simulations = simulations
        self.c_puct = c_puct

    def run(self, board, temperature=0.0):
        """Search from ``board`` and return the chosen move.

        Args:
            board: Position to search from (copied by the root node).
            temperature: If > 0, sample the move with probability proportional
                to ``visits ** (1 / temperature)``; otherwise return the most
                visited move.

        Returns:
            One of the keys of the root's children (a legal move).

        Raises:
            ValueError: If the position has no moves to choose from
                (for example, the game is already over).
        """
        root = MCTSNode(board)
        self._expand(root)
        if not root.children:
            raise ValueError("Cannot search a position with no legal moves")

        for _ in range(self.simulations):
            node = root
            search_path = [node]

            # Selection: walk down until reaching an unexpanded node.
            while node.children:
                node = self._select_child(node)
                search_path.append(node)

            # Expansion / evaluation of the leaf.
            value = self._expand(node)

            # Back-propagation along the visited path.
            self._backpropagate(search_path, value)

        return self._pick_move(root, temperature)

    def _select_child(self, node):
        """Return the child of ``node`` with the highest PUCT score."""
        # The root is expanded before any visit is recorded, so its N is 0 on
        # the first simulation; clamping to 1 keeps the prior term non-zero
        # there. Every other expanded node already has N >= 1.
        sqrt_parent_visits = max(node.N, 1) ** 0.5
        c_puct = self.c_puct

        def puct(child):
            return child.Q + c_puct * child.P * sqrt_parent_visits / (1 + child.N)

        # max() keeps the first of equally scored children, like a strict '>' scan.
        return max(node.children.values(), key=puct)

    @staticmethod
    def _backpropagate(search_path, value):
        """Add ``value`` to every node on the path, flipping sign each ply."""
        for visited in reversed(search_path):
            visited.N += 1
            visited.W += value
            visited.Q = visited.W / visited.N
            value = -value  # Switch perspective for the node one ply up

    def _pick_move(self, root, temperature):
        """Choose the final move from the root's visit counts."""
        moves = list(root.children)
        visits = np.array([root.children[m].N for m in moves], dtype=np.float64)

        # Temperature-based sampling for opening diversity.
        if temperature and temperature > 0 and visits.sum() > 0:
            # Work in log-space: visits ** (1 / T) overflows for small T.
            visited = visits > 0
            logits = np.log(visits[visited]) / temperature
            weights = np.zeros_like(visits)
            weights[visited] = np.exp(logits - logits.max())
            probs = weights / weights.sum()
            # Sample an index so the original move object is returned untouched.
            return moves[int(np.random.choice(len(moves), p=probs))]

        # Choose the move with the highest visit count (first one on ties).
        return max(moves, key=lambda m: root.children[m].N)

    def _expand(self, node):
        """Evaluate ``node`` and, if the game is not over, create its children.

        Returns:
            The value to back-propagate from this node: a terminal result for
            finished games, otherwise the model's value output.

        Raises:
            ValueError: If the model's policy vector has an unexpected length.
        """
        if node.board.is_game_over():
            return self._terminal_value(node.board)

        tensor = torch.tensor(board_to_tensor(node.board)).unsqueeze(0)
        with torch.no_grad():
            policy, value = self.model(tensor)
            policy = torch.softmax(policy, dim=1).cpu().numpy()[0]

        if len(policy) != self.POLICY_SIZE:
            raise ValueError(
                f"Policy size mismatch: expected {self.POLICY_SIZE}, got {len(policy)}"
            )

        legal_moves = list(node.board.legal_moves)
        priors = self._legal_priors(policy, legal_moves)

        for move, prior in zip(legal_moves, priors):
            child_board = node.board.copy()
            child_board.push(move)
            child = MCTSNode(child_board, parent=node, move=move)
            child.P = prior
            node.children[move] = child

        # NOTE(review): the search treats this number as "good for the player
        # who moved into this node". If the model scores the side to move
        # instead, it needs to be negated here -- confirm against the model's
        # training targets.
        return value.item()

    @staticmethod
    def _terminal_value(board):
        """Return the finished game's result for the player who moved last.

        ``board.result()`` is White-relative ('1-0' / '0-1'); the search needs
        it relative to whoever made the final move so both colours value
        delivering mate equally.
        """
        result = board.result()
        if result == '1-0':
            white_score = 1
        elif result == '0-1':
            white_score = -1
        else:
            return 0
        # board.turn is the side to move *after* the last move; in python-chess
        # a truthy value means White, so a falsy turn means White moved last.
        white_moved_last = not board.turn
        return white_score if white_moved_last else -white_score

    def _legal_priors(self, policy, legal_moves):
        """Return priors for ``legal_moves`` that sum to 1.

        Policy mass on the legal moves is renormalised; if that mass is
        (numerically) zero or not finite, a uniform distribution is used.
        """
        count = len(legal_moves)
        if count == 0:
            return []

        raw = np.zeros(count, dtype=np.float64)
        for i, move in enumerate(legal_moves):
            idx = self.move_to_index(move)
            if 0 <= idx < len(policy):  # Out-of-range slots keep a zero prior
                raw[i] = policy[idx]

        total = raw.sum()
        if not np.isfinite(total) or total < 1e-8:
            return [1.0 / count] * count
        return (raw / total).tolist()

    def move_to_index(self, move):
        """Map a move to its slot in the policy vector.

        Non-promotion moves use ``from_square * 64 + to_square`` (0..4095).
        The result is always clamped to the last valid slot.

        NOTE(review): python-chess promotion codes are knight=2, bishop=3,
        rook=4, queen=5, so the promotion offset below is at least
        4096 + 1024 and every promotion ends up clamped to the final slot.
        The formula is kept as is because the layout presumably has to match
        the one the model was trained with -- confirm before changing it.
        """
        from_square = move.from_square
        to_square = move.to_square
        promotion = move.promotion if move.promotion else 0

        # Base index for normal moves.
        idx = from_square * 64 + to_square

        if promotion:
            # Promotions are mapped past the normal-move range (4096 onwards).
            idx = 4096 + ((promotion - 1) * 64 * 64 // 4) + (from_square * 8 + to_square // 8)

        # Never return an index beyond the end of the policy vector.
        return min(idx, self.POLICY_SIZE - 1)