import logging
import math

import numpy as np

EPS = 1e-8
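# EPS keeps the PUCT exploration term nonzero for states whose visit count
# is still zero (see the selection step in search()).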

log = logging.getLogger(__name__)
|


class MCTS:
    """
    This class handles the MCTS tree.
    """

    def __init__(self, game, nnet, args):
        self.game = game
        self.nnet = nnet
        self.args = args
        self.Qsa = {}  # stores Q values for (s, a) edges
        self.Nsa = {}  # stores visit counts for (s, a) edges
        self.Ns = {}   # stores visit counts for states s
        self.Ps = {}   # stores the initial policy returned by the neural net

        self.Es = {}   # caches game.getGameEnded for states s
        self.Vs = {}   # caches game.getValidMoves for states s

    def getActionProb(self, canonicalBoard, temp=1):
        """
        This function performs numMCTSSims simulations of MCTS starting from
        canonicalBoard.

        Returns:
            probs: a policy vector where the probability of the ith action is
                   proportional to Nsa[(s,a)]**(1./temp)
        """
        for i in range(self.args.numMCTSSims):
            # reset_steps is a framework-specific hook, presumably resetting
            # the game's per-simulation step counter before each search.
            self.game.reset_steps()
            self.search(canonicalBoard)

        s = self.game.stringRepresentation(canonicalBoard)
        counts = [self.Nsa[(s, a)] if (s, a) in self.Nsa else 0 for a in range(self.game.getActionSize())]

        if temp == 0:
            # Greedy choice: put all probability mass on one of the most
            # visited actions, breaking ties at random.
            bestAs = np.array(np.argwhere(counts == np.max(counts))).flatten()
            bestA = np.random.choice(bestAs)
            probs = [0] * len(counts)
            probs[bestA] = 1
            return probs

        counts = [x ** (1. / temp) for x in counts]
        counts_sum = float(sum(counts))
        if counts_sum == 0:
            # No action was ever visited; fall back to a uniform policy over
            # the valid moves rather than dividing by zero.
            log.warning("All visit counts are zero in getActionProb, returning a uniform policy.")
            valids = self.game.getValidMoves(canonicalBoard, 1)
            return list(np.asarray(valids, dtype=float) / np.sum(valids))
        probs = [x / counts_sum for x in counts]
        return probs

    def search_iterate(self, canonicalBoard):
        """
        Iterative, explicit-stack variant of search() (without its depth-based
        random fallback): performs one MCTS iteration from canonicalBoard
        without Python recursion and returns the same negated value.
        """
        stack = [(0, (canonicalBoard,))]
        results = []
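
        # Stack frames are (phase, payload) tuples:
        #   phase 0, (board,): select/expand the node for `board`; pushes a
        #     phase-1 frame plus a child phase-0 frame when descending.
        #   phase 1, (s, a): back up the child's value (taken from `results`)
        #     through edge (s, a) via search_iterate_update.
        # `results` acts as the value stack replacing the recursion's return
        # values; exactly one value remains when the loop finishes.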
        while stack:
            st, sv = stack.pop()
            if st == 0:
                result, ns = self.search_iterate_st0(sv[0])
                if result is not None:
                    # Terminal state or freshly expanded leaf: its value is
                    # ready to be backed up.
                    results.append(result)
                if ns is not None:
                    # Descend: run the backup (phase 1) after the child
                    # (phase 0) has produced its value.
                    stack.append((1, (ns[1], ns[2])))
                    stack.append((0, (ns[0],)))
            elif st == 1:
                v = results.pop()
                v = self.search_iterate_update(v, sv[0], sv[1])
                results.append(v)
            else:
                raise ValueError("Invalid state")
        return results.pop()

    def search_iterate_st0(self, canonicalBoard):
        """
        Selection/expansion step for search_iterate(). Returns (result, ns)
        where exactly one element is not None: result is a negated value to
        back up (terminal state or newly expanded leaf), and ns is a
        (next_s, s, best_act) triple describing the child to descend into.
        """
        s = self.game.stringRepresentation(canonicalBoard)

        if s not in self.Es:
            self.Es[s] = self.game.getGameEnded(canonicalBoard, 1)
        if self.Es[s] != 0:
            # Terminal node: propagate the (negated) game outcome.
            return -self.Es[s], None
        if s not in self.Ps:
            # Leaf node: query the network for a policy prior and a value.
            self.Ps[s], v = self.nnet.predict(canonicalBoard)
            valids = self.game.getValidMoves(canonicalBoard, 1)
            self.Ps[s] = self.Ps[s] * valids  # mask invalid moves
            sum_Ps_s = np.sum(self.Ps[s])
            if sum_Ps_s > 0:
                self.Ps[s] /= sum_Ps_s  # renormalize
            else:
                # All valid moves were masked; make them uniformly probable
                # (same workaround as in search()).
                log.error("All valid moves were masked, doing a workaround.")
                self.Ps[s] = self.Ps[s] + valids
                self.Ps[s] /= np.sum(self.Ps[s])

            self.Vs[s] = valids
            self.Ns[s] = 0

            return -v, None

        valids = self.Vs[s]
        cur_best = -float('inf')
        best_act = -1

        # Pick the valid action with the highest PUCT score (see the formula
        # spelled out in search() below).
        for a in range(self.game.getActionSize()):
            if valids[a]:
                if (s, a) in self.Qsa:
                    u = self.Qsa[(s, a)] + self.args.cpuct * self.Ps[s][a] * math.sqrt(self.Ns[s]) / (1 + self.Nsa[(s, a)])
                else:
                    u = self.args.cpuct * self.Ps[s][a] * math.sqrt(self.Ns[s] + EPS)

                if u > cur_best:
                    cur_best = u
                    best_act = a

        next_s, next_player = self.game.getNextState(canonicalBoard, 1, best_act)
        next_s = self.game.getCanonicalForm(next_s, next_player)

        return None, (next_s, s, best_act)

    def search_iterate_update(self, v, s, a):
        """
        Back up the value v through edge (s, a) and return the value to pass
        one level further up the search path (negated for the parent, since
        players alternate).
        """
        if (s, a) in self.Qsa:
            # Incremental mean: Q <- (N * Q + v) / (N + 1).
            self.Qsa[(s, a)] = (self.Nsa[(s, a)] * self.Qsa[(s, a)] + v) / (self.Nsa[(s, a)] + 1)
            self.Nsa[(s, a)] += 1
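            # Worked example (illustrative numbers): with Nsa = 3, Qsa = 0.5
            # and v = 0.9, the new mean is (3 * 0.5 + 0.9) / 4 = 0.6, i.e.
            # the running average of every value backed up through this edge.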
        else:
            self.Qsa[(s, a)] = v
            self.Nsa[(s, a)] = 1

        self.Ns[s] += 1
        return -v

    def search(self, canonicalBoard, depth=0):
        """
        This function performs one iteration of MCTS. It is recursively called
        till a leaf node is found. The action chosen at each node is one that
        has the maximum upper confidence bound as in the paper.

        Once a leaf node is found, the neural network is called to return an
        initial policy P and a value v for the state. This value is propagated
        up the search path. In case the leaf node is a terminal state, the
        outcome is propagated up the search path. The values of Ns, Nsa, Qsa
        are updated.

        NOTE: the return values are the negative of the value of the current
        state. This is done since v is in [-1,1] and if v is the value of a
        state for the current player, then its value is -v for the other player.

        Returns:
            v: the negative of the value of the current canonicalBoard
        """
        s = self.game.stringRepresentation(canonicalBoard)

        if s not in self.Es:
            self.Es[s] = self.game.getGameEnded(canonicalBoard, 1)
        if self.Es[s] != 0:
            # Terminal node: propagate the (negated) game outcome.
            return -self.Es[s]

        if s not in self.Ps:
            # Leaf node: query the network for a policy prior and a value.
            self.Ps[s], v = self.nnet.predict(canonicalBoard)
            valids = self.game.getValidMoves(canonicalBoard, 1)
            self.Ps[s] = self.Ps[s] * valids  # mask invalid moves
            sum_Ps_s = np.sum(self.Ps[s])
            if sum_Ps_s > 0:
                self.Ps[s] /= sum_Ps_s  # renormalize
            else:
                # If all valid moves were masked, make them all equally
                # probable. This usually points at an over- or under-fitted
                # network and is worth investigating if it fires often.
                log.error("All valid moves were masked, doing a workaround.")
                self.Ps[s] = self.Ps[s] + valids
                self.Ps[s] /= np.sum(self.Ps[s])

            self.Vs[s] = valids
            self.Ns[s] = 0
            return -v

        valids = self.Vs[s]
        cur_best = -float('inf')
        best_act = -1
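
        # Select the action with the highest PUCT upper confidence bound:
        #     U(s, a) = Q(s, a) + cpuct * P(s, a) * sqrt(N(s)) / (1 + N(s, a))
        # where Q(s, a) is the mean backed-up value of the edge, P(s, a) the
        # network's prior, N(s) the visit count of the state and N(s, a) the
        # visit count of the edge. Unvisited edges are scored with Q = 0, and
        # EPS keeps the exploration term nonzero when N(s) == 0.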
        for a in range(self.game.getActionSize()):
            if valids[a]:
                if (s, a) in self.Qsa:
                    u = self.Qsa[(s, a)] + self.args.cpuct * self.Ps[s][a] * math.sqrt(self.Ns[s]) / (
                            1 + self.Nsa[(s, a)])
                else:
                    u = self.args.cpuct * self.Ps[s][a] * math.sqrt(self.Ns[s] + EPS)

                if u > cur_best:
                    cur_best = u
                    best_act = a

        a = best_act

        if depth > 100:
            # Guard against very deep (possibly cyclic) lines of play: past
            # depth 100, pick a uniformly random valid move instead of the
            # PUCT choice, and reset the depth counter so the fallback can
            # fire again after another ~20 plies if the game has not ended.
            candidates = self.game.getValidMoves(canonicalBoard, 1)
            a = np.random.choice([i for i in range(len(candidates)) if candidates[i] == 1])
            depth = 80

        next_s, next_player = self.game.getNextState(canonicalBoard, 1, a)
        next_s = self.game.getCanonicalForm(next_s, next_player)

        v = self.search(next_s, depth=depth + 1)

        if (s, a) in self.Qsa:
            self.Qsa[(s, a)] = (self.Nsa[(s, a)] * self.Qsa[(s, a)] + v) / (self.Nsa[(s, a)] + 1)
            self.Nsa[(s, a)] += 1
        else:
            self.Qsa[(s, a)] = v
            self.Nsa[(s, a)] = 1

        self.Ns[s] += 1
        return -v
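

# Typical driver loop for this class (an illustrative sketch only: `game`,
# `nnet` and the dotdict-style `args` container are assumed to come from the
# surrounding framework, as in alpha-zero-general):
#
#     args = dotdict({'numMCTSSims': 25, 'cpuct': 1.0})
#     mcts = MCTS(game, nnet, args)
#     pi = mcts.getActionProb(game.getCanonicalForm(board, player), temp=1)
#     action = np.random.choice(len(pi), p=pi)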