File size: 3,329 Bytes
bcc23fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import numpy as np
import chess
import torch
from o2_model import O2Net, board_to_tensor
from mcts import MCTS
import random

# Optional: placeholders for endgame-tablebase and opening-book integration.
# python-chess provides both through its chess.syzygy and chess.polyglot modules.
# Example for endgame tablebases (Syzygy):
# import chess.syzygy
# with chess.syzygy.open_tablebase('/path/to/syzygy') as tb:
#     wdl = tb.get_wdl(board)  # None when the position is not in the tables
#     if wdl is not None:
#         # Use the tablebase result to choose a move
# Example for an opening book (Polyglot):
# import chess.polyglot
# with chess.polyglot.open_reader('book.bin') as reader:
#     entry = reader.get(board)  # None when the position is not in the book
#     if entry is not None:
#         move = entry.move

class O2Agent:
    """Chess agent that picks moves with an ``O2Net`` model, optionally via MCTS.

    The network is built on the CPU and kept in evaluation mode; no training
    happens in this class.
    """

    # Moves are folded into this many slots by ``move_to_index``.
    # NOTE(review): presumably this equals the width of O2Net's policy head;
    # confirm against o2_model.
    POLICY_SIZE = 4672
    # First slot of the region the encoding reserves for promotions.
    PROMOTION_BASE = 4096
    # Spacing between the offsets of consecutive promotion piece types.
    PROMOTION_STRIDE = 256

    def __init__(self, model_path=None):
        """Build the network and optionally load weights from ``model_path``.

        Args:
            model_path: Path to a state-dict checkpoint, or ``None`` to keep
                the freshly initialised weights.
        """
        self.model = O2Net()
        if model_path:
            # The model is never moved off the CPU here, so map the checkpoint
            # to CPU as well; otherwise a checkpoint saved from a GPU cannot
            # be loaded on a CPU-only machine.
            state_dict = torch.load(model_path, map_location="cpu")
            self.model.load_state_dict(state_dict)
        self.model.eval()

    def select_move(self, board, use_mcts=True, simulations=100, temperature=0.0):
        """Choose a move for the side to move on ``board``.

        Args:
            board: A ``chess.Board`` (anything exposing ``legal_moves`` works
                for the direct-policy path).
            use_mcts: When true, delegate the whole decision to ``MCTS``.
            simulations: Number of simulations handed to ``MCTS``.
            temperature: ``0`` (or falsy) picks the highest-scoring move;
                a positive value samples from a softmax over the legal moves'
                scores divided by ``temperature``.

        Returns:
            The selected move. With ``use_mcts`` this is whatever
            ``MCTS.run`` returns.

        Raises:
            ValueError: On the direct-policy path when the position has no
                legal moves.
        """
        if use_mcts:
            mcts = MCTS(self.model, simulations=simulations)
            return mcts.run(board, temperature=temperature)

        legal_moves = list(board.legal_moves)
        if not legal_moves:
            # Fail with a clear message instead of an opaque argmax/choice
            # error on an empty sequence.
            raise ValueError("no legal moves available in this position")

        tensor = torch.tensor(board_to_tensor(board)).unsqueeze(0)
        with torch.no_grad():
            policy, _ = self.model(tensor)
        move_scores = [
            policy[0, self.move_to_index(move)].item() for move in legal_moves
        ]

        if temperature and temperature > 0:
            scaled = np.asarray(move_scores, dtype=np.float64) / temperature
            # Subtracting the maximum leaves the softmax unchanged but keeps
            # exp() from overflowing when the temperature is small.
            scaled -= scaled.max()
            weights = np.exp(scaled)
            probs = weights / weights.sum()
            # Sample a position rather than the move objects themselves so
            # numpy never has to coerce moves into an array.
            choice = np.random.choice(len(legal_moves), p=probs)
            return legal_moves[int(choice)]

        return legal_moves[int(np.argmax(move_scores))]

    def move_to_index(self, move):
        """Map a move to its slot in the policy vector.

        The slot is ``from_square * 64 + to_square``, plus a promotion offset
        of ``4096 + (promotion - 1) * 256`` when the move promotes, folded
        into ``[0, 4672)`` with a modulo.

        Args:
            move: Object with ``from_square``, ``to_square`` and ``promotion``
                attributes (e.g. ``chess.Move``).

        Returns:
            Integer index in ``[0, 4672)``.
        """
        idx = move.from_square * 64 + move.to_square
        if move.promotion:
            # python-chess piece types: KNIGHT=2, BISHOP=3, ROOK=4, QUEEN=5.
            # NOTE(review): with those values every legal promotion lands
            # above 4671 before the modulo, so promotions wrap around and
            # share slots with ordinary moves. The numbers are kept as-is so
            # each move keeps reading the same policy output; confirm this is
            # the encoding the network was trained with.
            idx += self.PROMOTION_BASE + (move.promotion - 1) * self.PROMOTION_STRIDE
        return idx % self.POLICY_SIZE

    def index_to_move(self, board, index):
        """Return a legal move on ``board`` that encodes to ``index``.

        Because ``move_to_index`` folds promotions onto slots used by ordinary
        moves, decoding is done by encoding each legal move and comparing. A
        non-promotion match wins over a promotion match; if nothing matches, a
        random legal move is returned.

        Args:
            board: Object exposing ``legal_moves`` (e.g. ``chess.Board``).
            index: Policy index to decode.

        Returns:
            A legal move from ``board``.

        Raises:
            IndexError: If the position has no legal moves.
        """
        legal_moves = list(board.legal_moves)
        promotion_match = None
        for move in legal_moves:
            if self.move_to_index(move) != index:
                continue
            if not move.promotion:
                # Plain from/to match: the unambiguous reading of the index.
                return move
            if promotion_match is None:
                promotion_match = move
        if promotion_match is not None:
            return promotion_match
        # Fallback: pick a random legal move
        return random.choice(legal_moves)

if __name__ == "__main__":
    # Smoke run: ask an agent with freshly initialised weights for a move
    # from the standard starting position and print it.
    demo_agent = O2Agent()
    start_position = chess.Board()
    chosen = demo_agent.select_move(start_position)
    print("O2 selects:", chosen)