File size: 4,630 Bytes
e3868b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import hashlib

import chess.pgn
import numpy as np
from anytree import LevelOrderGroupIter, LevelOrderIter, findall_by_attr
import torch
from loguru import logger

from src.engine.agents.base_agent import BaseAgent
from src.engine.agents.policies import beam_search, eval_board


class DLAgent(BaseAgent):
    def __init__(self,
                 model: torch.nn.Module,
                 is_white: bool,
                 beam_width: int = 5,
                 beam_depth: int = 11,
                 beam_player_strategy: str = "greedy",
                 beam_opponent_strategy: str = "greedy",
                 beam_player_top_k: int = 1,
                 beam_opponent_top_k: int = 1,
                 name: str = None
                 ) -> None:

        super().__init__(is_white)
        self.model = model
        self.model.eval()
        self.beam_width = beam_width
        self.beam_depth = beam_depth
        self.beam_player_strategy = beam_player_strategy
        self.beam_opponent_strategy = beam_opponent_strategy
        self.beam_player_top_k = beam_player_top_k
        self.beam_opponent_top_k = beam_opponent_top_k

        self.name = name

        self.set_min_max_score()

    def set_min_max_score(self):
        self.max_score = 100
        self.min_score = -100

    @logger.catch(level="DEBUG")
    def next_move(self, board: chess.Board) -> chess.Move:
        # beam = list(LevelOrderGroupIter(beam_search(self.model, board, self.beam_depth, self.beam_width, self.min_score, self.max_score)))
        beam = beam_search(model=self.model,
                           board=board,
                           depth=self.beam_depth,
                           beam_width=self.beam_width,
                           player_strategy=self.beam_player_strategy,
                           opponent_strategy=self.beam_opponent_strategy,
                           opponent_top_k=self.beam_opponent_top_k,
                           min_score=self.min_score,
                           max_score=self.max_score)

        beam_mean_score = np.mean([node.score for node in list(LevelOrderIter(beam))][1:])

        # first check if an immediate move leads to checkmate or tie
        for node in list(LevelOrderGroupIter(beam))[1]:
            if node.board.outcome():
                if self.is_white:
                    if node.board.outcome().result() == "1-0":
                        logger.info(f"Immediate checkmate for {self}")
                        return node.move
                    if node.board.outcome().result() == "1/2-1/2" and beam_mean_score < 0:
                        logger.info(f"Immediate tie for {self}")
                        return node.move
                else:
                    if node.board.outcome().result() == "0-1":
                        logger.info(f"Immediate checkmate for {self}")
                        return node.move
                    if node.board.outcome().result() == "1/2-1/2" and beam_mean_score > 0:
                        logger.info(f"Immediate tie for {self}")
                        return node.move


        # tag each leaf node with "finished" attribute if it is at max depth or checkmate for the player
        max_depth = max([node.depth for node in list(LevelOrderIter(beam))])
        for node in list(LevelOrderIter(beam)):
            if node.depth == max_depth:
                node.finished = True
            elif node.board.outcome():
                if node.board.outcome().result() == "1-0" and self.is_white:
                    node.finished = True
                elif node.board.outcome().result() == "0-1" and not self.is_white:
                    node.finished = True

        # identify moves with most finishing nodes
        nb_finished_nodes = [len(findall_by_attr(name="finished", value=True, node=node))
                             for node in list(LevelOrderGroupIter(beam))[1]]

        # TODO: implement a tie-breaker
        best_node = list(LevelOrderGroupIter(beam))[1][np.argmax(nb_finished_nodes)]

        return best_node.move

    def evaluate_board(self, board: chess.Board) -> float:
        return eval_board(self.model, board)

    def __str__(self) -> str:
        if self.name:
            return self.name
        # get list of self parameters
        hash = hashlib.md5(
            (str(self.model.model_hash()) +
             str(self.beam_width) +
             str(self.beam_depth) +
             str(self.beam_player_strategy) +
             str(self.beam_player_top_k) +
             str(self.beam_opponent_top_k)
             ).encode()
        ).hexdigest()
        return f"DLAgent-{hash}"