Spaces:
Running
Running
import numpy as np | |
import chess | |
import torch | |
from o2_model import O2Net, board_to_tensor | |
from mcts import MCTS | |
import random | |
# Optional: endgame tablebase and opening book integration (placeholders).
# python-chess ships Syzygy tablebase and Polyglot opening-book readers that
# could be consulted before falling back to the network.
# Example for an endgame tablebase:
#     import chess.syzygy
#     with chess.syzygy.open_tablebase('/path/to/syzygy') as tb:
#         wdl = tb.get_wdl(board)  # None if the position is not in the tables
#         if wdl is not None:
#             ...  # use the tablebase result to pick a move
# Example for an opening book:
#     import chess.polyglot
#     with chess.polyglot.open_reader('book.bin') as reader:
#         entry = reader.get(board)  # None if the position is not in the book
#         if entry is not None:
#             move = entry.move
class O2Agent:
    """Chess move selector driven by an ``O2Net`` model.

    A move is chosen either by an MCTS search that is handed the model, or,
    when MCTS is disabled, directly from the model's policy output restricted
    to the legal moves of the position.
    """

    # Number of plain (from_square, to_square) pairs: 64 * 64.
    _BASE_MOVES = 4096
    # Index distance between consecutive promotion piece types.
    _PROMOTION_STRIDE = 256
    # Size of the index space; larger raw indices wrap around modulo this.
    _POLICY_SIZE = 4672

    def __init__(self, model_path=None):
        """Create the network and optionally load weights into it.

        Args:
            model_path: Path to a ``state_dict`` checkpoint. When falsy, the
                network keeps whatever weights ``O2Net()`` starts with.
        """
        self.model = O2Net()
        if model_path:
            # The model is never moved off the CPU in this class, so remap
            # the checkpoint to CPU; otherwise a GPU-saved file fails to load
            # on a CPU-only host.
            # NOTE(review): torch.load unpickles the file - only load
            # checkpoints from a trusted source.
            state = torch.load(model_path, map_location="cpu")
            self.model.load_state_dict(state)
        self.model.eval()

    def select_move(self, board, use_mcts=True, simulations=100, temperature=0.0):
        """Pick a move for the position on ``board``.

        Args:
            board: Position to move from; it must expose ``legal_moves``.
            use_mcts: When true, delegate the choice to ``MCTS``.
            simulations: Simulation count handed to ``MCTS``.
            temperature: With MCTS, forwarded to ``MCTS.run``. Without MCTS,
                a positive value samples a legal move from
                ``softmax(score / temperature)``; any other value picks the
                highest-scoring legal move.

        Returns:
            The result of ``MCTS.run`` when MCTS is used, otherwise one of
            the entries of ``board.legal_moves``.

        Raises:
            ValueError: Without MCTS, if the position has no legal moves.
        """
        if use_mcts:
            search = MCTS(self.model, simulations=simulations)
            return search.run(board, temperature=temperature)

        legal_moves = list(board.legal_moves)
        if not legal_moves:
            # Fail with a clear message instead of an opaque argmax/choice
            # error on an empty sequence.
            raise ValueError("select_move called on a position with no legal moves")

        inputs = torch.tensor(board_to_tensor(board)).unsqueeze(0)
        with torch.no_grad():
            policy, _ = self.model(inputs)

        scores = [
            policy[0, self.move_to_index(move)].item() for move in legal_moves
        ]

        if temperature and temperature > 0:
            probs = self._softmax(scores, temperature)
            # Sample a position in the list rather than the move objects, so
            # numpy never has to coerce moves into an array.
            choice = np.random.choice(len(legal_moves), p=probs)
            return legal_moves[int(choice)]
        return legal_moves[int(np.argmax(scores))]

    @staticmethod
    def _softmax(scores, temperature):
        """Return ``softmax(scores / temperature)`` as a 1-D float array."""
        scaled = np.asarray(scores, dtype=np.float64) / temperature
        # Subtracting the maximum does not change the result but keeps exp()
        # from overflowing to inf, which would yield NaN probabilities.
        weights = np.exp(scaled - scaled.max())
        return weights / weights.sum()

    def move_to_index(self, move):
        """Map a move to the index used to read its score from the policy.

        Plain moves map to ``from_square * 64 + to_square``. Promotions add
        ``4096 + (promotion - 1) * 256``, where ``promotion`` is the
        python-chess piece type (2=knight, 3=bishop, 4=rook, 5=queen). The
        result is reduced modulo 4672.

        NOTE(review): promotion indices routinely exceed 4671 and wrap, so a
        promotion can share its index with an unrelated from/to pair. The
        mapping is kept as is because existing weights may depend on it -
        confirm before changing.

        Args:
            move: Object with ``from_square``, ``to_square`` and
                ``promotion`` attributes.

        Returns:
            An integer in ``range(4672)``.
        """
        index = move.from_square * 64 + move.to_square
        if move.promotion:
            index += self._BASE_MOVES + (move.promotion - 1) * self._PROMOTION_STRIDE
        return index % self._POLICY_SIZE

    def index_to_move(self, board, index):
        """Return a legal move on ``board`` that ``move_to_index`` maps to ``index``.

        ``move_to_index`` is not one-to-one (promotion indices wrap around),
        so instead of decoding arithmetically every legal move is encoded
        and compared. A non-promotion match is preferred over a promotion
        match. If no legal move matches, a random legal move is returned.

        Args:
            board: Position whose ``legal_moves`` are searched.
            index: Index as produced by ``move_to_index``.

        Returns:
            One of the entries of ``board.legal_moves``.

        Raises:
            IndexError: If the position has no legal moves (raised by
                ``random.choice`` on the empty list).
        """
        legal_moves = list(board.legal_moves)
        promotion_match = None
        for move in legal_moves:
            if self.move_to_index(move) != index:
                continue
            if not move.promotion:
                return move
            if promotion_match is None:
                promotion_match = move
        if promotion_match is not None:
            return promotion_match
        # Fallback: nothing legal corresponds to this index.
        return random.choice(legal_moves)
if __name__ == "__main__":
    # Smoke run: an agent with no checkpoint loaded picks a move from the
    # standard starting position using select_move's defaults.
    start_position = chess.Board()
    chosen = O2Agent().select_move(start_position)
    print("O2 selects:", chosen)