File size: 4,863 Bytes
cc24eeb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5342d4d
cc24eeb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bcc23fe
cc24eeb
5342d4d
 
 
 
 
 
 
 
cc24eeb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bcc23fe
 
cc24eeb
bcc23fe
cc24eeb
bcc23fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cc24eeb
bcc23fe
 
 
 
 
 
 
 
 
 
 
 
cc24eeb
 
 
 
 
bcc23fe
 
 
cc24eeb
bcc23fe
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import chess
import torch
import numpy as np
from o2_model import board_to_tensor

class MCTSNode:
    """One position in the search tree together with its search statistics.

    Attributes:
        board: private copy of the position this node represents.
        parent: the node this one was expanded from (``None`` for the root).
        move: the move played from ``parent`` to reach this node (``None`` by default).
        children: mapping filled in by the search, keyed by move.
        N: visit count.
        W: total (accumulated) value.
        Q: mean value, ``W / N``.
        P: prior probability.
    """

    def __init__(self, board, parent=None, move=None):
        # Copy so later changes to the caller's board cannot leak into the tree.
        self.board = board.copy()
        self.parent, self.move = parent, move
        self.children = {}
        self.N, self.W, self.Q, self.P = 0, 0, 0, 0

class MCTS:
    """Monte Carlo tree search guided by a policy/value network.

    Args:
        model: callable taking a batched board tensor and returning
            ``(policy_logits, value)``; the policy must have ``POLICY_SIZE``
            entries per position.
        simulations: number of select / expand / back-up iterations per ``run``.
        c_puct: exploration constant of the PUCT selection rule.
    """

    # 64*64 from/to slots for ordinary moves; the remaining 576 slots hold promotions.
    POLICY_SIZE = 4672

    def __init__(self, model, simulations=100, c_puct=1.5):
        self.model = model
        self.simulations = simulations
        self.c_puct = c_puct

    def run(self, board, temperature=0.0):
        """Search from *board* and return the chosen move.

        Args:
            board: position to search from. It is copied, never mutated.
            temperature: if > 0, sample a root move with probability
                proportional to ``N ** (1 / temperature)`` (opening diversity);
                otherwise return the most-visited root move.

        Raises:
            ValueError: if the game at *board* is already over (no moves to choose).
        """
        root = MCTSNode(board)
        self._expand(root)
        if not root.children:
            raise ValueError("MCTS.run() called on a position where the game is over")
        for _ in range(self.simulations):
            node = root
            search_path = [node]
            # Selection: descend until reaching a node that has not been expanded.
            while node.children:
                node = self._select_child(node)
                search_path.append(node)
            # Expansion / evaluation, then back-up along the visited path.
            value = self._expand(node)
            self._backpropagate(search_path, value)
        if temperature and temperature > 0:
            moves = list(root.children)
            probs = self._visit_distribution(
                [root.children[m].N for m in moves], temperature)
            return moves[np.random.choice(len(moves), p=probs)]
        # Most-visited move; ties go to the first child in insertion order.
        return max(root.children.items(), key=lambda item: item[1].N)[0]

    def _select_child(self, node):
        """Return the child of *node* with the highest PUCT score."""
        # max(.., 1) keeps the prior term alive on a parent's first selection;
        # with sqrt(0) every child would score equally and priors would be ignored.
        sqrt_parent = np.sqrt(max(node.N, 1))
        return max(
            node.children.values(),
            key=lambda child: child.Q + self.c_puct * child.P * sqrt_parent / (1 + child.N),
        )

    @staticmethod
    def _backpropagate(search_path, value):
        """Add one visit and a White-positive *value* to every node on the path.

        Each node stores value from the point of view of the player who made the
        move into it (the side NOT to move at that node), because that is what
        the parent maximises over in ``_select_child``.
        """
        for n in search_path:
            n.N += 1
            n.W += value if n.board.turn == chess.BLACK else -value
            n.Q = n.W / n.N

    @staticmethod
    def _terminal_value(board):
        """Score a finished game from White's perspective: +1, -1 or 0 (draw)."""
        return {'1-0': 1, '0-1': -1}.get(board.result(), 0)

    @staticmethod
    def _visit_distribution(visits, temperature):
        """Turn root visit counts into sampling probabilities at *temperature*."""
        visits = np.asarray(visits, dtype=np.float64)
        peak = visits.max() if visits.size else 0.0
        if peak <= 0:
            # No visits recorded (e.g. simulations=0): sample uniformly.
            return np.full(visits.size, 1.0 / visits.size)
        # Normalise by the peak before exponentiating so small temperatures
        # cannot overflow to inf (which would turn the distribution into NaNs).
        probs = (visits / peak) ** (1.0 / temperature)
        return probs / probs.sum()

    def _legal_priors(self, policy, moves):
        """Return the policy mass of each move in *moves*, normalised to sum to 1.

        Falls back to a uniform distribution when the network assigns
        (almost) no probability to the legal moves.
        """
        if not moves:
            return np.zeros(0)
        raw = np.array([policy[self.move_to_index(m)] for m in moves], dtype=np.float64)
        total = raw.sum()
        if not np.isfinite(total) or total < 1e-8:
            return np.full(len(moves), 1.0 / len(moves))
        return raw / total

    def _expand(self, node):
        """Create *node*'s children with network priors and return its value.

        Returns:
            For a finished game, the White-positive result (+1 / -1 / 0) and no
            children are created. Otherwise the model's value output, which is
            assumed to follow the same White-positive convention.
            NOTE(review): confirm the value-head convention against o2_model.

        Raises:
            ValueError: if the model's policy does not have ``POLICY_SIZE`` entries.
        """
        if node.board.is_game_over():
            return self._terminal_value(node.board)
        tensor = torch.tensor(board_to_tensor(node.board)).unsqueeze(0)
        with torch.no_grad():
            policy_logits, value = self.model(tensor)
        policy = torch.softmax(policy_logits, dim=1).numpy()[0]
        if len(policy) != self.POLICY_SIZE:
            raise ValueError(
                f"Policy size mismatch: expected {self.POLICY_SIZE}, got {len(policy)}")

        legal_moves = list(node.board.legal_moves)
        priors = self._legal_priors(policy, legal_moves)
        for move, prior in zip(legal_moves, priors):
            # Push/pop on the node's own board so MCTSNode's copy is the only copy.
            node.board.push(move)
            try:
                child = MCTSNode(node.board, parent=node, move=move)
            finally:
                node.board.pop()
            child.P = float(prior)
            node.children[move] = child
        return value.item()

    def move_to_index(self, move):
        """Map *move* to its slot in the ``POLICY_SIZE``-entry policy vector.

        Ordinary moves use ``from_square * 64 + to_square`` (0-4095). Each
        promotion gets its own slot in 4096-4607. There is one 64-entry block
        per (promotion piece, colour) pair, addressed by ``from_file * 8 + to_file``.

        NOTE(review): a checkpoint trained with the previous mapping, where every
        promotion was clamped onto index 4671, will see different promotion
        priors. Confirm this against the training/target-generation code.

        Raises:
            ValueError: for a promotion piece other than knight, bishop, rook or queen.
        """
        from_square = move.from_square
        to_square = move.to_square
        if not move.promotion:
            return from_square * 64 + to_square
        # python-chess numbers the promotion pieces KNIGHT=2 .. QUEEN=5.
        piece_offset = move.promotion - chess.KNIGHT
        if not 0 <= piece_offset < 4:
            raise ValueError(f"unsupported promotion piece: {move.promotion!r}")
        # White promotes onto the last rank (index 7), Black onto the first (0).
        colour = 1 if to_square // 8 == 7 else 0
        block = piece_offset * 2 + colour
        return 4096 + block * 64 + (from_square % 8) * 8 + (to_square % 8)