Spaces:
Sleeping
Sleeping
sjz
commited on
Commit
•
9cefce7
1
Parent(s):
aae2a37
update code
Browse files- Gomoku_MCTS/mcts_Gumbel_Alphazero.py +390 -0
- Gomoku_MCTS/mcts_alphaZero.py +36 -21
- pages/Player_VS_AI.py +70 -21
Gomoku_MCTS/mcts_Gumbel_Alphazero.py
ADDED
@@ -0,0 +1,390 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
FileName: main_worker.py
|
3 |
+
Author: Jiaxin Li
|
4 |
+
Create Date: 2023/11/21
|
5 |
+
Description: The implement of Gumbel MCST
|
6 |
+
Edit History:
|
7 |
+
Debug: the dim of output: probs
|
8 |
+
"""
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
import copy
|
12 |
+
import time
|
13 |
+
|
14 |
+
from config.options import *
|
15 |
+
import sys
|
16 |
+
from config.utils import *
|
17 |
+
|
18 |
+
|
19 |
+
def softmax(x):
|
20 |
+
probs = np.exp(x - np.max(x))
|
21 |
+
probs /= np.sum(probs)
|
22 |
+
return probs
|
23 |
+
|
24 |
+
|
25 |
+
def _sigma_mano(y ,Nb):
|
26 |
+
return (50 + Nb) * 1.0 * y
|
27 |
+
|
28 |
+
|
29 |
+
class TreeNode(object):
|
30 |
+
"""A node in the MCTS tree.
|
31 |
+
|
32 |
+
Each node keeps track of its own value Q, prior probability P, and
|
33 |
+
its visit-count-adjusted prior score u.
|
34 |
+
"""
|
35 |
+
|
36 |
+
def __init__(self, parent, prior_p):
|
37 |
+
self._parent = parent
|
38 |
+
self._children = {} # a map from action to TreeNode
|
39 |
+
self._n_visits = 0
|
40 |
+
self._Q = 0
|
41 |
+
self._u = 0
|
42 |
+
self._v = 0
|
43 |
+
self._p = prior_p
|
44 |
+
|
45 |
+
|
46 |
+
|
47 |
+
def expand(self, action_priors):
|
48 |
+
"""Expand tree by creating new children.
|
49 |
+
action_priors: a list of tuples of actions and their prior probability
|
50 |
+
according to the policy function.
|
51 |
+
"""
|
52 |
+
for action, prob in action_priors:
|
53 |
+
if action not in self._children:
|
54 |
+
self._children[action] = TreeNode(self, prob)
|
55 |
+
|
56 |
+
|
57 |
+
def select(self, v_pi):
|
58 |
+
"""Select action among children that gives maximum
|
59 |
+
(pi'(a) - N(a) \ (1 + \sum_b N(b)))
|
60 |
+
Return: A tuple of (action, next_node)
|
61 |
+
"""
|
62 |
+
# if opts.split == "train":
|
63 |
+
# v_pi = v_pi.detach().numpy()
|
64 |
+
# print(v_pi)
|
65 |
+
|
66 |
+
|
67 |
+
|
68 |
+
|
69 |
+
max_N_b = np.max(np.array( [act_node[1]._n_visits for act_node in self._children.items()]))
|
70 |
+
|
71 |
+
if opts.split == "train":
|
72 |
+
pi_ = softmax( np.array( [ act_node[1].get_pi(v_pi,max_N_b) for act_node in self._children.items() ])).reshape(len(list(self._children.items())) ,-1)
|
73 |
+
else:
|
74 |
+
pi_ = softmax( np.array( [ act_node[1].get_pi(v_pi,max_N_b) for act_node in self._children.items() ])).reshape(len(list(self._children.items())) ,-1)
|
75 |
+
# print(pi_.shape)
|
76 |
+
|
77 |
+
|
78 |
+
N_a = np.array( [ act_node[1]._n_visits / (1 + self._n_visits) for act_node in self._children.items() ]).reshape(pi_.shape[0],-1)
|
79 |
+
# print(N_a.shape)
|
80 |
+
|
81 |
+
max_index= np.argmax(pi_ - N_a)
|
82 |
+
# print((pi_ - N_a).shape)
|
83 |
+
|
84 |
+
|
85 |
+
|
86 |
+
return list(self._children.items())[max_index]
|
87 |
+
|
88 |
+
|
89 |
+
def update(self, leaf_value):
|
90 |
+
"""Update node values from leaf evaluation.
|
91 |
+
leaf_value: the value of subtree evaluation from the current player's
|
92 |
+
perspective.
|
93 |
+
"""
|
94 |
+
# Count visit.
|
95 |
+
self._n_visits += 1
|
96 |
+
# Update Q, a running average of values for all visits.
|
97 |
+
if opts.split == "train":
|
98 |
+
self._Q = self._Q + (1.0*(leaf_value - self._Q ) / self._n_visits)
|
99 |
+
|
100 |
+
|
101 |
+
else:
|
102 |
+
self._Q += (1.0*(leaf_value - self._Q) / self._n_visits)
|
103 |
+
|
104 |
+
def update_recursive(self, leaf_value):
|
105 |
+
"""Like a call to update(), but applied recursively for all ancestors.
|
106 |
+
"""
|
107 |
+
# If it is not root, this node's parent should be updated first.
|
108 |
+
if self._parent:
|
109 |
+
self._parent.update_recursive(-leaf_value)
|
110 |
+
self.update(leaf_value)
|
111 |
+
|
112 |
+
def get_pi(self,v_pi,max_N_b):
|
113 |
+
if self._n_visits == 0:
|
114 |
+
Q_completed = v_pi
|
115 |
+
else:
|
116 |
+
Q_completed = self._Q
|
117 |
+
|
118 |
+
return self._p + _sigma_mano(Q_completed,max_N_b)
|
119 |
+
|
120 |
+
|
121 |
+
def get_value(self, c_puct):
|
122 |
+
"""Calculate and return the value for this node.
|
123 |
+
It is a combination of leaf evaluations Q, and this node's prior
|
124 |
+
adjusted for its visit count, u.
|
125 |
+
c_puct: a number in (0, inf) controlling the relative impact of
|
126 |
+
value Q, and prior probability P, on this node's score.
|
127 |
+
"""
|
128 |
+
self._u = (c_puct * self._P *
|
129 |
+
np.sqrt(self._parent._n_visits) / (1 + self._n_visits))
|
130 |
+
return self._Q + self._u
|
131 |
+
|
132 |
+
def is_leaf(self):
|
133 |
+
"""Check if leaf node (i.e. no nodes below this have been expanded)."""
|
134 |
+
return self._children == {}
|
135 |
+
|
136 |
+
def is_root(self):
|
137 |
+
return self._parent is None
|
138 |
+
|
139 |
+
|
140 |
+
class Gumbel_MCTS(object):
|
141 |
+
"""An implementation of Monte Carlo Tree Search."""
|
142 |
+
|
143 |
+
def __init__(self, policy_value_fn, c_puct=5, n_playout=10000):
|
144 |
+
"""
|
145 |
+
policy_value_fn: a function that takes in a board state and outputs
|
146 |
+
a list of (action, probability) tuples and also a score in [-1, 1]
|
147 |
+
(i.e. the expected value of the end game score from the current
|
148 |
+
player's perspective) for the current player.
|
149 |
+
c_puct: a number in (0, inf) that controls how quickly exploration
|
150 |
+
converges to the maximum-value policy. A higher value means
|
151 |
+
relying on the prior more.
|
152 |
+
"""
|
153 |
+
self._root = TreeNode(None, 1.0)
|
154 |
+
self._policy = policy_value_fn
|
155 |
+
self._c_puct = c_puct
|
156 |
+
self._n_playout = n_playout
|
157 |
+
|
158 |
+
|
159 |
+
|
160 |
+
|
161 |
+
def Gumbel_playout(self, child_node, child_state):
|
162 |
+
"""Run a single playout from the child of the root to the leaf, getting a value at
|
163 |
+
the leaf and propagating it back through its parents.
|
164 |
+
State is modified in-place, so a copy must be provided.
|
165 |
+
This mothod of select is a non-root selet.
|
166 |
+
"""
|
167 |
+
node = child_node
|
168 |
+
state = child_state
|
169 |
+
|
170 |
+
while(1):
|
171 |
+
if node.is_leaf():
|
172 |
+
break
|
173 |
+
# Greedily select next move.
|
174 |
+
|
175 |
+
action, node = node.select(node._v)
|
176 |
+
|
177 |
+
state.do_move(action)
|
178 |
+
|
179 |
+
|
180 |
+
|
181 |
+
# Evaluate the leaf using a network which outputs a list of
|
182 |
+
# (action, probability) tuples p and also a score v in [-1, 1]
|
183 |
+
# for the current player.
|
184 |
+
action_probs, leaf_value = self._policy(state)
|
185 |
+
|
186 |
+
leaf_value = leaf_value.detach().numpy()[0][0]
|
187 |
+
|
188 |
+
node._v = leaf_value
|
189 |
+
|
190 |
+
|
191 |
+
|
192 |
+
# Check for end of game.
|
193 |
+
end, winner = state.game_end()
|
194 |
+
if not end:
|
195 |
+
node.expand(action_probs)
|
196 |
+
else:
|
197 |
+
# for end state,return the "true" leaf_value
|
198 |
+
if winner == -1: # tie
|
199 |
+
leaf_value = 0.0
|
200 |
+
else:
|
201 |
+
leaf_value = (
|
202 |
+
1.0 if winner == state.get_current_player() else -1.0
|
203 |
+
)
|
204 |
+
|
205 |
+
# Update value and visit count of nodes in this traversal.
|
206 |
+
node.update_recursive(-leaf_value)
|
207 |
+
|
208 |
+
|
209 |
+
def top_k(self,x, k):
|
210 |
+
# print("x",x.shape)
|
211 |
+
# print("k ", k)
|
212 |
+
|
213 |
+
return np.argpartition(x, k)[..., -k:]
|
214 |
+
|
215 |
+
def sample_k(self,logits, k):
|
216 |
+
u = np.random.uniform(size=np.shape(logits))
|
217 |
+
z = -np.log(-np.log(u))
|
218 |
+
|
219 |
+
|
220 |
+
|
221 |
+
return self.top_k(logits + z, k),z
|
222 |
+
|
223 |
+
|
224 |
+
def get_move_probs(self, state, temp=1e-3,m_action = 16):
|
225 |
+
"""Run all playouts sequentially and return the available actions and
|
226 |
+
their corresponding probabilities.
|
227 |
+
state: the current game state
|
228 |
+
temp: temperature parameter in (0, 1] controls the level of exploration
|
229 |
+
"""
|
230 |
+
# 这里需要修改:1
|
231 |
+
# logits 暂定为 p
|
232 |
+
|
233 |
+
start_time = time.time()
|
234 |
+
|
235 |
+
|
236 |
+
# 对根节点进行拓展
|
237 |
+
act_probs, leaf_value = self._policy(state)
|
238 |
+
act_probs = list(act_probs)
|
239 |
+
|
240 |
+
leaf_value = leaf_value.detach().numpy()[0][0]
|
241 |
+
|
242 |
+
# print(list(act_probs))
|
243 |
+
porbs = [prob for act,prob in (act_probs)]
|
244 |
+
self._root.expand(act_probs)
|
245 |
+
|
246 |
+
|
247 |
+
n = self._n_playout
|
248 |
+
m = min( m_action,int(len( porbs) / 2))
|
249 |
+
|
250 |
+
|
251 |
+
# 先进行Gumbel 分布采样,不重复的采样前m个动作,对应选择公式 logits + g
|
252 |
+
A_topm ,g = self.sample_k(porbs , m)
|
253 |
+
|
254 |
+
# 获得state选取每个action后对应的状态,保存到一个列表中
|
255 |
+
root_childs = list(self._root._children.items())
|
256 |
+
|
257 |
+
|
258 |
+
child_state_m = []
|
259 |
+
for i in range(m):
|
260 |
+
state_copy = copy.deepcopy(state)
|
261 |
+
action,node = root_childs[A_topm[i]]
|
262 |
+
state_copy.do_move(action)
|
263 |
+
child_state_m.append(state_copy)
|
264 |
+
|
265 |
+
|
266 |
+
# 每轮对选择的动作进行的仿真次数
|
267 |
+
N = int( n /( np.log(m) * m ))
|
268 |
+
|
269 |
+
# 进行sequential halving with Gumbel
|
270 |
+
while m >= 1:
|
271 |
+
|
272 |
+
# 对每个选择的动作进行仿真
|
273 |
+
for i in range(m):
|
274 |
+
action_state = child_state_m[i]
|
275 |
+
|
276 |
+
action,node = root_childs[A_topm[i]]
|
277 |
+
|
278 |
+
for j in range(N):
|
279 |
+
action_state_copy = copy.deepcopy(action_state)
|
280 |
+
|
281 |
+
# 对选择动作进行仿真: 即找到这个子树的叶节点,然后再网络中预测v,然后往上回溯的过程
|
282 |
+
self.Gumbel_playout(node, action_state_copy)
|
283 |
+
|
284 |
+
# 每轮不重复采样的动作个数减半
|
285 |
+
m = m //2
|
286 |
+
|
287 |
+
# 不是最后一轮,单轮仿真次数加倍
|
288 |
+
if(m != 1):
|
289 |
+
n = n - N
|
290 |
+
N *= 2
|
291 |
+
# 当最后一轮时,只有一个动作,把所有仿真次数用完
|
292 |
+
else:
|
293 |
+
N = n
|
294 |
+
|
295 |
+
# 进行新的一轮不重复采样, 采样在之前的动作前一半的动作, 对应公式 g + logits + \sigma( \hat{q} )
|
296 |
+
# print([action_node[1]._Q for action_node in self._root._children.items() ])
|
297 |
+
|
298 |
+
|
299 |
+
q_hat = np.array([action_node[1]._Q for action_node in self._root._children.items() ])
|
300 |
+
|
301 |
+
|
302 |
+
assert(np.sum(q_hat[A_topm] == 0) == 0 )
|
303 |
+
|
304 |
+
A_index = self.top_k( np.array(porbs)[A_topm] + np.array(g)[A_topm] + q_hat[A_topm] , m)
|
305 |
+
A_topm = np.array(A_topm)[A_index]
|
306 |
+
child_state_m = np.array(child_state_m)[A_index]
|
307 |
+
|
308 |
+
|
309 |
+
# 最后返回对应的决策函数, 即 pi' = softmax(logits + sigma(completed Q))
|
310 |
+
|
311 |
+
max_N_b = np.max(np.array( [act_node[1]._n_visits for act_node in self._root._children.items()] ))
|
312 |
+
|
313 |
+
final_act_probs= softmax( np.array( [ act_node[1].get_pi(leaf_value, max_N_b) for act_node in self._root._children.items() ]))
|
314 |
+
action = ( np.array( [ act_node[0] for act_node in self._root._children.items() ]))
|
315 |
+
|
316 |
+
need_time = time.time() - start_time
|
317 |
+
print(f" Gumbel Alphazero sum_time: {need_time }, total_simulation: {self._n_playout}")
|
318 |
+
|
319 |
+
return np.array(list(self._root._children.items()))[A_topm][0][0], action, final_act_probs , need_time
|
320 |
+
|
321 |
+
def update_with_move(self, last_move):
|
322 |
+
"""Step forward in the tree, keeping everything we already know
|
323 |
+
about the subtree.
|
324 |
+
"""
|
325 |
+
if last_move in self._root._children:
|
326 |
+
self._root = self._root._children[last_move]
|
327 |
+
self._root._parent = None
|
328 |
+
else:
|
329 |
+
self._root = TreeNode(None, 1.0)
|
330 |
+
|
331 |
+
def __str__(self):
|
332 |
+
return "MCTS"
|
333 |
+
|
334 |
+
|
335 |
+
class Gumbel_MCTSPlayer(object):
|
336 |
+
"""AI player based on MCTS"""
|
337 |
+
|
338 |
+
def __init__(self, policy_value_function,
|
339 |
+
c_puct=5, n_playout=2000, is_selfplay=0,m_action = 16):
|
340 |
+
self.mcts = Gumbel_MCTS(policy_value_function, c_puct, n_playout)
|
341 |
+
self._is_selfplay = is_selfplay
|
342 |
+
self.m_action = m_action
|
343 |
+
|
344 |
+
|
345 |
+
def set_player_ind(self, p):
|
346 |
+
self.player = p
|
347 |
+
|
348 |
+
def reset_player(self):
|
349 |
+
self.mcts.update_with_move(-1)
|
350 |
+
|
351 |
+
|
352 |
+
def get_action(self, board, temp=1e-3, return_prob=0,return_time = False):
|
353 |
+
sensible_moves = board.availables
|
354 |
+
# the pi vector returned by MCTS as in the alphaGo Zero paper
|
355 |
+
move_probs = np.zeros(board.width*board.height)
|
356 |
+
|
357 |
+
|
358 |
+
|
359 |
+
if len(sensible_moves) > 0:
|
360 |
+
|
361 |
+
# 在搜索树中利用sequential halving with Gumbel 来进行动作选择 并且返回对应的决策函数
|
362 |
+
move, acts, probs,simul_mean_time = self.mcts.get_move_probs(board, temp,self.m_action)
|
363 |
+
|
364 |
+
|
365 |
+
|
366 |
+
# 重置搜索树
|
367 |
+
self.mcts.update_with_move(-1)
|
368 |
+
|
369 |
+
move_probs[list(acts)] = probs
|
370 |
+
|
371 |
+
|
372 |
+
if return_time:
|
373 |
+
|
374 |
+
if return_prob:
|
375 |
+
|
376 |
+
return move, move_probs,simul_mean_time
|
377 |
+
else:
|
378 |
+
return move,simul_mean_time
|
379 |
+
else:
|
380 |
+
|
381 |
+
if return_prob:
|
382 |
+
|
383 |
+
return move, move_probs
|
384 |
+
else:
|
385 |
+
return move
|
386 |
+
else:
|
387 |
+
print("WARNING: the board is full")
|
388 |
+
|
389 |
+
def __str__(self):
|
390 |
+
return "MCTS {}".format(self.player)
|
Gomoku_MCTS/mcts_alphaZero.py
CHANGED
@@ -156,35 +156,39 @@ class MCTS(object):
|
|
156 |
start_time_averge = 0
|
157 |
|
158 |
### test multi-thread
|
159 |
-
lock = threading.Lock()
|
160 |
-
with ThreadPoolExecutor(max_workers=4) as executor:
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
### end test multi-thread
|
168 |
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
|
173 |
-
|
174 |
-
|
175 |
-
|
|
|
176 |
# print('!!time!!:', time.time() - t)
|
177 |
|
178 |
-
|
179 |
|
180 |
|
181 |
# calc the move probabilities based on visit counts at the root node
|
182 |
act_visits = [(act, node._n_visits)
|
183 |
for act, node in self._root._children.items()]
|
|
|
184 |
acts, visits = zip(*act_visits)
|
|
|
185 |
act_probs = softmax(1.0/temp * np.log(np.array(visits) + 1e-10))
|
186 |
|
187 |
-
|
|
|
188 |
|
189 |
def update_with_move(self, last_move):
|
190 |
"""Step forward in the tree, keeping everything we already know
|
@@ -214,12 +218,12 @@ class MCTSPlayer(object):
|
|
214 |
def reset_player(self):
|
215 |
self.mcts.update_with_move(-1)
|
216 |
|
217 |
-
def get_action(self, board, temp=1e-3, return_prob=0):
|
218 |
sensible_moves = board.availables
|
219 |
# the pi vector returned by MCTS as in the alphaGo Zero paper
|
220 |
move_probs = np.zeros(board.width*board.height)
|
221 |
if len(sensible_moves) > 0:
|
222 |
-
acts, probs = self.mcts.get_move_probs(board, temp)
|
223 |
move_probs[list(acts)] = probs
|
224 |
if self._is_selfplay:
|
225 |
# add Dirichlet Noise for exploration (needed for
|
@@ -238,11 +242,22 @@ class MCTSPlayer(object):
|
|
238 |
self.mcts.update_with_move(-1)
|
239 |
# location = board.move_to_location(move)
|
240 |
# print("AI move: %d,%d\n" % (location[0], location[1]))
|
|
|
|
|
|
|
241 |
|
242 |
-
|
243 |
-
|
|
|
|
|
|
|
244 |
else:
|
245 |
-
|
|
|
|
|
|
|
|
|
|
|
246 |
else:
|
247 |
print("WARNING: the board is full")
|
248 |
|
|
|
156 |
start_time_averge = 0
|
157 |
|
158 |
### test multi-thread
|
159 |
+
# lock = threading.Lock()
|
160 |
+
# with ThreadPoolExecutor(max_workers=4) as executor:
|
161 |
+
# for n in range(self._n_playout):
|
162 |
+
# start_time = time.time()
|
163 |
+
|
164 |
+
# state_copy = copy.deepcopy(state)
|
165 |
+
# executor.submit(self._playout, state_copy, lock)
|
166 |
+
# start_time_averge += (time.time() - start_time)
|
167 |
### end test multi-thread
|
168 |
|
169 |
+
t = time.time()
|
170 |
+
for n in range(self._n_playout):
|
171 |
+
start_time = time.time()
|
172 |
|
173 |
+
state_copy = copy.deepcopy(state)
|
174 |
+
self._playout(state_copy)
|
175 |
+
start_time_averge += (time.time() - start_time)
|
176 |
+
total_time = time.time() - t
|
177 |
# print('!!time!!:', time.time() - t)
|
178 |
|
179 |
+
print(f" My MCTS sum_time: {total_time }, total_simulation: {self._n_playout}")
|
180 |
|
181 |
|
182 |
# calc the move probabilities based on visit counts at the root node
|
183 |
act_visits = [(act, node._n_visits)
|
184 |
for act, node in self._root._children.items()]
|
185 |
+
|
186 |
acts, visits = zip(*act_visits)
|
187 |
+
|
188 |
act_probs = softmax(1.0/temp * np.log(np.array(visits) + 1e-10))
|
189 |
|
190 |
+
|
191 |
+
return 0, acts, act_probs, total_time
|
192 |
|
193 |
def update_with_move(self, last_move):
|
194 |
"""Step forward in the tree, keeping everything we already know
|
|
|
218 |
def reset_player(self):
|
219 |
self.mcts.update_with_move(-1)
|
220 |
|
221 |
+
def get_action(self, board, temp=1e-3, return_prob=0,return_time = False):
|
222 |
sensible_moves = board.availables
|
223 |
# the pi vector returned by MCTS as in the alphaGo Zero paper
|
224 |
move_probs = np.zeros(board.width*board.height)
|
225 |
if len(sensible_moves) > 0:
|
226 |
+
_, acts, probs, simul_mean_time = self.mcts.get_move_probs(board, temp)
|
227 |
move_probs[list(acts)] = probs
|
228 |
if self._is_selfplay:
|
229 |
# add Dirichlet Noise for exploration (needed for
|
|
|
242 |
self.mcts.update_with_move(-1)
|
243 |
# location = board.move_to_location(move)
|
244 |
# print("AI move: %d,%d\n" % (location[0], location[1]))
|
245 |
+
|
246 |
+
|
247 |
+
if return_time:
|
248 |
|
249 |
+
if return_prob:
|
250 |
+
|
251 |
+
return move, move_probs,simul_mean_time
|
252 |
+
else:
|
253 |
+
return move,simul_mean_time
|
254 |
else:
|
255 |
+
|
256 |
+
if return_prob:
|
257 |
+
|
258 |
+
return move, move_probs
|
259 |
+
else:
|
260 |
+
return move
|
261 |
else:
|
262 |
print("WARNING: the board is full")
|
263 |
|
pages/Player_VS_AI.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
"""
|
2 |
FileName: app.py
|
3 |
Author: Benhao Huang
|
4 |
-
Create Date: 2023/11/
|
5 |
Description: this file is used to display our project and add visualization elements to the game, using Streamlit
|
6 |
"""
|
7 |
|
@@ -46,6 +46,7 @@ class Room:
|
|
46 |
self.WINNER = _BLANK
|
47 |
self.TIME = time.time()
|
48 |
self.MCTS = MCTSpure(c_puct=5, n_playout=10)
|
|
|
49 |
self.COORDINATE_1D = [_BOARD_SIZE_1D + 1]
|
50 |
self.current_move = -1
|
51 |
self.simula_time_list = []
|
@@ -60,6 +61,8 @@ if "ROOM" not in session_state:
|
|
60 |
session_state.ROOM = Room("local")
|
61 |
if "OWNER" not in session_state:
|
62 |
session_state.OWNER = False
|
|
|
|
|
63 |
|
64 |
# Check server health
|
65 |
if "ROOMS" not in server_state:
|
@@ -88,6 +91,7 @@ MULTIPLAYER_TAG = st.sidebar.empty()
|
|
88 |
with st.sidebar.container():
|
89 |
ANOTHER_ROUND = st.empty()
|
90 |
RESTART = st.empty()
|
|
|
91 |
EXIT = st.empty()
|
92 |
GAME_INFO = st.sidebar.container()
|
93 |
message = st.empty()
|
@@ -213,6 +217,14 @@ def gomoku():
|
|
213 |
winner = _BLANK
|
214 |
return winner
|
215 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
216 |
# Triggers the board response on click
|
217 |
def handle_click(x, y):
|
218 |
"""
|
@@ -257,7 +269,11 @@ def gomoku():
|
|
257 |
# Draw board
|
258 |
def draw_board(response: bool):
|
259 |
"""construct each buttons for all cells of the board"""
|
260 |
-
|
|
|
|
|
|
|
|
|
261 |
if response and session_state.ROOM.TURN == _BLACK: # human turn
|
262 |
print("Your turn")
|
263 |
# construction of clickable buttons
|
@@ -276,13 +292,23 @@ def gomoku():
|
|
276 |
on_click=forbid_click
|
277 |
)
|
278 |
else:
|
279 |
-
|
280 |
-
|
281 |
-
|
282 |
-
|
283 |
-
|
284 |
-
|
285 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
286 |
|
287 |
|
288 |
elif response and session_state.ROOM.TURN == _WHITE: # AI turn
|
@@ -292,7 +318,7 @@ def gomoku():
|
|
292 |
print("AI's turn")
|
293 |
print("Below are current board under AI's view")
|
294 |
print(session_state.ROOM.BOARD.board_map)
|
295 |
-
move, simul_time = session_state.ROOM.MCTS.get_action(session_state.ROOM.BOARD)
|
296 |
session_state.ROOM.simula_time_list.append(simul_time)
|
297 |
print("AI takes move: ", move)
|
298 |
session_state.ROOM.current_move = move
|
@@ -321,13 +347,24 @@ def gomoku():
|
|
321 |
on_click=forbid_click
|
322 |
)
|
323 |
else:
|
324 |
-
|
325 |
-
|
326 |
-
|
327 |
-
|
328 |
-
|
329 |
-
|
330 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
331 |
|
332 |
message.markdown(
|
333 |
'AI agent has calculated its strategy, which takes <span style="color: blue; font-size: 20px;">{:.3e}</span>s per simulation.'.format(
|
@@ -355,10 +392,17 @@ def gomoku():
|
|
355 |
print("Game over")
|
356 |
for i, row in enumerate(session_state.ROOM.BOARD.board_map):
|
357 |
for j, cell in enumerate(row):
|
358 |
-
|
359 |
-
|
360 |
-
|
361 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
362 |
|
363 |
# Game process control
|
364 |
def game_control():
|
@@ -401,6 +445,11 @@ def gomoku():
|
|
401 |
st.line_chart(chart_data)
|
402 |
|
403 |
# The main game loop
|
|
|
|
|
|
|
|
|
|
|
404 |
game_control()
|
405 |
update_info()
|
406 |
|
|
|
1 |
"""
|
2 |
FileName: app.py
|
3 |
Author: Benhao Huang
|
4 |
+
Create Date: 2023/11/19
|
5 |
Description: this file is used to display our project and add visualization elements to the game, using Streamlit
|
6 |
"""
|
7 |
|
|
|
46 |
self.WINNER = _BLANK
|
47 |
self.TIME = time.time()
|
48 |
self.MCTS = MCTSpure(c_puct=5, n_playout=10)
|
49 |
+
self.MCTS = alphazero(PolicyValueNet(_BOARD_SIZE, _BOARD_SIZE).policy_value_fn, c_puct=5, n_playout=10)
|
50 |
self.COORDINATE_1D = [_BOARD_SIZE_1D + 1]
|
51 |
self.current_move = -1
|
52 |
self.simula_time_list = []
|
|
|
61 |
session_state.ROOM = Room("local")
|
62 |
if "OWNER" not in session_state:
|
63 |
session_state.OWNER = False
|
64 |
+
if "USE_AIAID" not in session_state:
|
65 |
+
session_state.USE_AIAID = False
|
66 |
|
67 |
# Check server health
|
68 |
if "ROOMS" not in server_state:
|
|
|
91 |
with st.sidebar.container():
|
92 |
ANOTHER_ROUND = st.empty()
|
93 |
RESTART = st.empty()
|
94 |
+
AIAID = st.empty()
|
95 |
EXIT = st.empty()
|
96 |
GAME_INFO = st.sidebar.container()
|
97 |
message = st.empty()
|
|
|
217 |
winner = _BLANK
|
218 |
return winner
|
219 |
|
220 |
+
def ai_aid() -> None:
|
221 |
+
"""
|
222 |
+
Use AI Aid.
|
223 |
+
"""
|
224 |
+
session_state.USE_AIAID = not session_state.USE_AIAID
|
225 |
+
print('Use AI Aid: ', session_state.USE_AIAID)
|
226 |
+
draw_board(False)
|
227 |
+
|
228 |
# Triggers the board response on click
|
229 |
def handle_click(x, y):
|
230 |
"""
|
|
|
269 |
# Draw board
|
270 |
def draw_board(response: bool):
|
271 |
"""construct each buttons for all cells of the board"""
|
272 |
+
if session_state.USE_AIAID:
|
273 |
+
_, acts, probs, simul_mean_time = session_state.ROOM.MCTS.mcts.get_move_probs(session_state.ROOM.BOARD)
|
274 |
+
sorted_acts_probs = sorted(zip(acts, probs), key=lambda x: x[1], reverse=True)
|
275 |
+
top_five_acts = [act for act, prob in sorted_acts_probs[:5]]
|
276 |
+
top_five_probs = [prob for act, prob in sorted_acts_probs[:5]]
|
277 |
if response and session_state.ROOM.TURN == _BLACK: # human turn
|
278 |
print("Your turn")
|
279 |
# construction of clickable buttons
|
|
|
292 |
on_click=forbid_click
|
293 |
)
|
294 |
else:
|
295 |
+
if session_state.USE_AIAID and i * _BOARD_SIZE + j in top_five_acts:
|
296 |
+
# enable click for other cells available for human choices
|
297 |
+
prob = top_five_probs[top_five_acts.index(i * _BOARD_SIZE + j)]
|
298 |
+
BOARD_PLATE[i][j].button(
|
299 |
+
_PLAYER_SYMBOL[cell] + f"({round(prob, 2)})",
|
300 |
+
key=f"{i}:{j}",
|
301 |
+
on_click=handle_click,
|
302 |
+
args=(i, j),
|
303 |
+
)
|
304 |
+
else:
|
305 |
+
# enable click for other cells available for human choices
|
306 |
+
BOARD_PLATE[i][j].button(
|
307 |
+
_PLAYER_SYMBOL[cell],
|
308 |
+
key=f"{i}:{j}",
|
309 |
+
on_click=handle_click,
|
310 |
+
args=(i, j),
|
311 |
+
)
|
312 |
|
313 |
|
314 |
elif response and session_state.ROOM.TURN == _WHITE: # AI turn
|
|
|
318 |
print("AI's turn")
|
319 |
print("Below are current board under AI's view")
|
320 |
print(session_state.ROOM.BOARD.board_map)
|
321 |
+
move, simul_time = session_state.ROOM.MCTS.get_action(session_state.ROOM.BOARD, return_time=True)
|
322 |
session_state.ROOM.simula_time_list.append(simul_time)
|
323 |
print("AI takes move: ", move)
|
324 |
session_state.ROOM.current_move = move
|
|
|
347 |
on_click=forbid_click
|
348 |
)
|
349 |
else:
|
350 |
+
if session_state.USE_AIAID and i * _BOARD_SIZE + j in top_five_acts:
|
351 |
+
# enable click for other cells available for human choices
|
352 |
+
prob = top_five_probs[top_five_acts.index(i * _BOARD_SIZE + j)]
|
353 |
+
BOARD_PLATE[i][j].button(
|
354 |
+
_PLAYER_SYMBOL[cell] + f"({round(prob, 2)})",
|
355 |
+
key=f"{i}:{j}",
|
356 |
+
on_click=handle_click,
|
357 |
+
args=(i, j),
|
358 |
+
)
|
359 |
+
else:
|
360 |
+
# enable click for other cells available for human choices
|
361 |
+
BOARD_PLATE[i][j].button(
|
362 |
+
_PLAYER_SYMBOL[cell],
|
363 |
+
key=f"{i}:{j}",
|
364 |
+
on_click=handle_click,
|
365 |
+
args=(i, j),
|
366 |
+
)
|
367 |
+
|
368 |
|
369 |
message.markdown(
|
370 |
'AI agent has calculated its strategy, which takes <span style="color: blue; font-size: 20px;">{:.3e}</span>s per simulation.'.format(
|
|
|
392 |
print("Game over")
|
393 |
for i, row in enumerate(session_state.ROOM.BOARD.board_map):
|
394 |
for j, cell in enumerate(row):
|
395 |
+
if session_state.USE_AIAID and i * _BOARD_SIZE + j in top_five_acts:
|
396 |
+
prob = top_five_probs[top_five_acts.index(i * _BOARD_SIZE + j)]
|
397 |
+
BOARD_PLATE[i][j].write(
|
398 |
+
_PLAYER_SYMBOL[cell] + f"({round(prob, 2)})",
|
399 |
+
key=f"{i}:{j}",
|
400 |
+
)
|
401 |
+
else:
|
402 |
+
BOARD_PLATE[i][j].write(
|
403 |
+
_PLAYER_SYMBOL[cell],
|
404 |
+
key=f"{i}:{j}",
|
405 |
+
)
|
406 |
|
407 |
# Game process control
|
408 |
def game_control():
|
|
|
445 |
st.line_chart(chart_data)
|
446 |
|
447 |
# The main game loop
|
448 |
+
AIAID.button(
|
449 |
+
"Use AI Aid",
|
450 |
+
on_click=ai_aid,
|
451 |
+
help="Use AI Aid to help you make moves",
|
452 |
+
)
|
453 |
game_control()
|
454 |
update_info()
|
455 |
|