Spaces:
Running
Running
File size: 3,329 Bytes
bcc23fe |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 |
import numpy as np
import chess
import torch
from o2_model import O2Net, board_to_tensor
from mcts import MCTS
import random
# Optional: endgame tablebase and opening book integration (placeholders).
# python-chess supports both through chess.syzygy and chess.polyglot.
# Example for endgame tablebases:
#     import chess.syzygy
#     with chess.syzygy.open_tablebase('/path/to/syzygy') as tb:
#         wdl = tb.probe_wdl(board)  # raises KeyError if no table covers the position
#         # Use the tablebase result to choose a move
# Example for opening book:
#     from chess.polyglot import open_reader
#     with open_reader('book.bin') as reader:
#         entry = reader.find(board)  # raises IndexError if the position is not in the book
#         move = entry.move
class O2Agent:
    """Chess agent that picks moves with an ``O2Net`` model.

    Moves are chosen either by running ``MCTS`` on top of the model or by
    reading the model's policy output directly for the legal moves.
    """

    # Modulus applied to every policy index in move_to_index.
    # NOTE(review): presumably the length of the policy head -- confirm
    # against O2Net.
    _POLICY_SIZE = 4672
    # Indices below this value are plain from_square * 64 + to_square codes.
    _PROMOTION_BASE = 4096
    # Spacing between the offsets of consecutive promotion piece types.
    _PROMOTION_STRIDE = 256

    def __init__(self, model_path=None):
        """Build the network and optionally load weights.

        Args:
            model_path: Path of a state-dict checkpoint readable by
                ``torch.load``. When falsy, the freshly constructed
                ``O2Net`` is used as is.
        """
        self.model = O2Net()
        if model_path:
            # map_location="cpu" lets a checkpoint saved on a GPU load on a
            # CPU-only host; load_state_dict copies into the model's tensors.
            # NOTE: torch.load unpickles the file -- only load checkpoints
            # from a trusted source.
            state_dict = torch.load(model_path, map_location="cpu")
            self.model.load_state_dict(state_dict)
        self.model.eval()

    def select_move(self, board, use_mcts=True, simulations=100, temperature=0.0):
        """Choose a move for the side to move on ``board``.

        Args:
            board: Position to move in; must expose ``legal_moves``.
            use_mcts: When true, delegate the choice to ``MCTS.run``.
            simulations: Simulation count handed to ``MCTS``.
            temperature: Passed to ``MCTS.run`` in MCTS mode. Without MCTS,
                a value > 0 samples from a softmax over the policy scores
                of the legal moves (scores divided by ``temperature``);
                otherwise the highest-scoring legal move is returned.

        Returns:
            The selected move.
        """
        if use_mcts:
            search = MCTS(self.model, simulations=simulations)
            return search.run(board, temperature=temperature)

        batch = torch.tensor(board_to_tensor(board)).unsqueeze(0)
        with torch.no_grad():
            policy, _ = self.model(batch)

        legal_moves = list(board.legal_moves)
        move_scores = [
            policy[0, self.move_to_index(move)].item() for move in legal_moves
        ]

        if temperature and temperature > 0:
            scaled = np.asarray(move_scores, dtype=np.float64) / temperature
            # Shift by the maximum before exponentiating: the resulting
            # distribution is unchanged, but exp() can no longer overflow
            # to inf (which would yield NaN probabilities) at small
            # temperatures.
            scaled -= scaled.max()
            weights = np.exp(scaled)
            probs = weights / weights.sum()
            # Sample a position in the list instead of handing the move
            # objects to numpy, so numpy never has to coerce them.
            choice = np.random.choice(len(legal_moves), p=probs)
            return legal_moves[choice]

        return legal_moves[int(torch.tensor(move_scores).argmax())]

    def move_to_index(self, move):
        """Map a move to the policy index used to score it.

        Non-promotion moves map to ``from_square * 64 + to_square``
        (0..4095). Promotions add ``4096 + (promotion - 1) * 256`` on top
        of that, and the sum is reduced modulo 4672.

        python-chess numbers promotion pieces KNIGHT=2, BISHOP=3, ROOK=4,
        QUEEN=5, so the added offset is 4352..5120 and the modulo folds
        large sums back into range. The mapping is therefore not
        one-to-one: a promotion can share an index with another move.
        The encoding is left as is because changing it would change which
        policy entry each move reads.

        Args:
            move: Object with ``from_square``, ``to_square`` and
                ``promotion`` attributes.

        Returns:
            Integer index in ``range(4672)``.
        """
        index = move.from_square * 64 + move.to_square
        if move.promotion:
            index += (
                self._PROMOTION_BASE
                + (move.promotion - 1) * self._PROMOTION_STRIDE
            )
        # Keep the index within bounds.
        return index % self._POLICY_SIZE

    def index_to_move(self, board, index):
        """Return the legal move on ``board`` that ``move_to_index`` maps to ``index``.

        Decoding searches the legal moves rather than doing arithmetic on
        ``index``, so it stays consistent with ``move_to_index`` including
        its modulo wrap-around. When several legal moves share the index,
        a non-promotion move is preferred, then the first promotion found.

        Args:
            board: Position whose ``legal_moves`` are searched.
            index: Policy index to decode.

        Returns:
            The matching legal move, or a random legal move when none
            matches.

        Raises:
            IndexError: If nothing matches and ``board`` has no legal moves
                (raised by ``random.choice`` on an empty list).
        """
        legal_moves = list(board.legal_moves)
        promotion_match = None
        for move in legal_moves:
            if self.move_to_index(move) != index:
                continue
            if not move.promotion:
                return move
            if promotion_match is None:
                promotion_match = move
        if promotion_match is not None:
            return promotion_match
        # Fallback: pick a random legal move
        return random.choice(legal_moves)
if __name__ == "__main__":
    # Smoke run: ask an agent with freshly constructed weights for a move
    # from the standard starting position and print it.
    start_position = chess.Board()
    chosen = O2Agent().select_move(start_position)
    print("O2 selects:", chosen)
|