# Play-with-o2 / src / mcts.py
import chess
import torch
import numpy as np
from o2_model import board_to_tensor
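
# AlphaZero-style Monte Carlo Tree Search: moves are selected with the PUCT
# rule using priors from the model's policy head, and leaf evaluations from
# its value head are backed up along the visited path.
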
class MCTSNode:
    def __init__(self, board, parent=None, move=None):
        self.board = board.copy()
        self.parent = parent
        self.move = move
        self.children = {}  # move -> MCTSNode
        self.N = 0          # Visit count
        self.W = 0          # Total value
        self.Q = 0          # Mean value (W / N)
        self.P = 0          # Prior probability from the policy network

class MCTS:
    def __init__(self, model, simulations=100, c_puct=1.5):
        self.model = model
        self.simulations = simulations
        self.c_puct = c_puct
    def run(self, board, temperature=0.0):
        root = MCTSNode(board)
        self._expand(root)
        for _ in range(self.simulations):
            node = root
            search_path = [node]
            # Selection: descend with the PUCT rule until a leaf is reached.
            while node.children:
                max_ucb = -float('inf')
                best_move = None
                for move, child in node.children.items():
                    # child.Q is stored from the child's side-to-move
                    # perspective, so negate it when scoring the parent's options.
                    ucb = -child.Q + self.c_puct * child.P * np.sqrt(node.N) / (1 + child.N)
                    if ucb > max_ucb:
                        max_ucb = ucb
                        best_move = move
                node = node.children[best_move]
                search_path.append(node)
            # Expansion and evaluation
            value = self._expand(node)
            # Backpropagation: flip the sign at every ply so each node
            # accumulates value from its own side-to-move perspective.
            for n in reversed(search_path):
                n.N += 1
                n.W += value
                n.Q = n.W / n.N
                value = -value  # Switch perspective
        # Temperature-based sampling for opening diversity
        if temperature and temperature > 0:
            moves = list(root.children.keys())
            visits = np.array([root.children[m].N for m in moves], dtype=np.float32)
            probs = visits ** (1.0 / temperature)
            probs = probs / np.sum(probs)
            return moves[np.random.choice(len(moves), p=probs)]
        # Otherwise choose the move with the highest visit count
        best_move = max(root.children.items(), key=lambda item: item[1].N)[0]
        return best_move
    def _expand(self, node):
        if node.board.is_game_over():
            # Terminal position: convert the game result to the side-to-move
            # perspective, matching the sign flip used in backpropagation.
            result = node.board.result()
            if result == '1-0':
                value = 1.0
            elif result == '0-1':
                value = -1.0
            else:
                value = 0.0
            return value if node.board.turn == chess.WHITE else -value
        tensor = torch.tensor(board_to_tensor(node.board)).unsqueeze(0)
        with torch.no_grad():
            policy, value = self.model(tensor)
        policy = torch.softmax(policy, dim=1).numpy()[0]
        assert len(policy) == 4672, f"Policy size mismatch: expected 4672, got {len(policy)}"
        legal_moves = list(node.board.legal_moves)
        # Sum the prior mass the network assigns to the legal moves.
        total_p = 0.0
        for move in legal_moves:
            try:
                idx = self.move_to_index(move)
                if 0 <= idx < 4672:  # Ensure index is within bounds
                    total_p += policy[idx]
            except Exception:
                continue  # Skip moves that can't be indexed properly
        if total_p <= 1e-8:
            # The network put (almost) no mass on legal moves: fall back to a
            # uniform prior over the legal moves.
            total_p = 1.0
            for move in legal_moves:
                idx = self.move_to_index(move)
                if 0 <= idx < 4672:
                    policy[idx] = 1.0 / len(legal_moves)
        # Create child nodes only for moves with a valid index.
        for move in legal_moves:
            try:
                idx = self.move_to_index(move)
                if 0 <= idx < 4672:
                    child_board = node.board.copy()
                    child_board.push(move)
                    child = MCTSNode(child_board, parent=node, move=move)
                    child.P = policy[idx] / total_p
                    node.children[move] = child
            except Exception:
                continue  # Skip problematic moves
        return value.item()
    def move_to_index(self, move):
        from_square = move.from_square
        to_square = move.to_square
        promotion = move.promotion if move.promotion else 0
        # Base index for non-promotion moves: from_square * 64 + to_square (0..4095)
        idx = from_square * 64 + to_square
        if promotion:
            # move.promotion uses python-chess piece types
            # (knight=2, bishop=3, rook=4, queen=5); promotions are mapped to
            # the range after the 4096 normal-move indices. With these piece
            # types the result always exceeds 4671 and is clamped below.
            idx = 4096 + ((promotion - 1) * 64 * 64 // 4) + (from_square * 8 + to_square // 8)
        # Clamp to the last valid index of the 4672-entry policy head
        return min(idx, 4671)
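

# --- Usage sketch (illustrative addition, not part of the original upload) ---
# Shows the model interface this MCTS assumes: a callable mapping a batched
# board tensor to (policy logits of size 4672, scalar value). `_DummyNet` is a
# hypothetical stand-in so the search loop can be exercised without the real
# o2 weights; in practice the trained network from o2_model would be passed.
if __name__ == "__main__":
    import torch.nn as nn

    class _DummyNet(nn.Module):
        """Stand-in with the output shapes the MCTS expects from the model."""
        def forward(self, x):
            batch = x.shape[0]
            # Uniform policy logits and a neutral value for every position.
            return torch.zeros(batch, 4672), torch.zeros(batch, 1)

    mcts = MCTS(_DummyNet(), simulations=50)
    board = chess.Board()
    move = mcts.run(board, temperature=1.0)  # temperature > 0 samples by visit count
    print("Selected move:", move)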