| import torch
|
| import torch.nn.functional as F
|
| from model import OthelloNet
|
| from bitboard import get_bit, make_input_planes
|
| import numpy as np
|
|
|
def load_dualist(
    model_path="dualist_model.pth",
    device="cpu",
    *,
    num_res_blocks=10,
    num_channels=256,
):
    """Build an ``OthelloNet`` and load the Dualist weights into it.

    Args:
        model_path: Path of the checkpoint file handed to ``torch.load``.
        device: Device the checkpoint tensors are mapped to and the model is
            moved to.
        num_res_blocks: Forwarded to ``OthelloNet``. The default (10) is the
            value this loader previously hard-coded.
        num_channels: Forwarded to ``OthelloNet``. The default (256) is the
            value this loader previously hard-coded.

    Returns:
        The model, moved to ``device`` and switched to evaluation mode.

    Raises:
        FileNotFoundError: Propagated from ``torch.load`` when ``model_path``
            does not exist.
    """
    model = OthelloNet(num_res_blocks=num_res_blocks, num_channels=num_channels)

    # NOTE(security): depending on the installed PyTorch version's
    # ``weights_only`` default, torch.load may unpickle arbitrary objects.
    # Only load checkpoint files that come from a trusted source.
    checkpoint = torch.load(model_path, map_location=device)

    # Accept either a dict that stores the weights under "model_state_dict"
    # or an object that is itself the state dict.
    if "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict)

    model.to(device)
    model.eval()
    return model
|
|
|
def get_best_move(model, player_bb, opponent_bb, legal_moves_bb, device="cpu"):
    """Return the legal move that the model's policy output rates highest.

    Args:
        model: Callable network; ``model(x)`` must return a
            ``(policy_logits, value)`` pair.
        player_bb: Bitboard of the side to move, passed to
            ``make_input_planes``.
        opponent_bb: Bitboard of the other side, passed to
            ``make_input_planes``.
        legal_moves_bb: Bitmask of candidate squares; only squares whose
            ``get_bit`` mask intersects it are considered.
        device: Device the input planes are moved to before the forward pass.

    Returns:
        The ``get_bit`` mask of the selected square, or ``0`` when no square
        was selected (for example when ``legal_moves_bb`` is empty).
    """

    def mask_for(policy_idx):
        # Policy index i corresponds to square number 63 - i, which is then
        # split into the (row, col) pair that get_bit expects.
        row, col = divmod(63 - policy_idx, 8)
        return get_bit(row, col)

    planes = make_input_planes(player_bb, opponent_bb).to(device)

    # Inference only: no autograd bookkeeping is needed for the forward pass.
    with torch.no_grad():
        policy_logits, _value = model(planes)

    # NOTE(review): the exp() suggests the policy head emits
    # log-probabilities -- confirm against OthelloNet.
    move_probs = torch.exp(policy_logits).squeeze(0).cpu().numpy()

    chosen_idx = -1
    top_prob = -1.0
    for idx in range(64):
        if not (legal_moves_bb & mask_for(idx)):
            continue
        # Strict comparison: on ties the lowest policy index wins.
        if move_probs[idx] > top_prob:
            top_prob = move_probs[idx]
            chosen_idx = idx

    if chosen_idx == -1:
        return 0
    return mask_for(chosen_idx)
|
|
|
def _run_inference_demo():
    """Load the default checkpoint and print the move chosen for one position."""
    net = load_dualist()
    print("Model loaded successfully!")

    # Hard-coded sample position and its legal-move mask.
    # NOTE(review): this appears to be the standard Othello opening with
    # black to move -- confirm against the bitboard layout.
    black_bb = 0x0000000810000000
    white_bb = 0x0000001008000000
    legal_moves = 0x0000102004080000

    chosen = get_best_move(net, black_bb, white_bb, legal_moves)
    print(f"Best move found: {hex(chosen)}")


if __name__ == "__main__":
    print("Dualist Inference Test")
    try:
        _run_inference_demo()
    except FileNotFoundError:
        print("Error: dualist_model.pth not found. Ensure it's in the same directory.")
    except Exception as err:
        # Top-level boundary of a demo script: report the failure and exit
        # normally instead of showing a traceback.
        print(f"An error occurred: {err}")
|
|
|