Spaces:
Sleeping
Sleeping
finish model selections
Browse files- Gomoku_MCTS/mcts_pure.py +48 -1
- const.py +2 -0
- pages/Player_VS_AI.py +61 -29
Gomoku_MCTS/mcts_pure.py
CHANGED
@@ -6,6 +6,11 @@ from operator import itemgetter
|
|
6 |
import time
|
7 |
|
8 |
|
|
|
|
|
|
|
|
|
|
|
9 |
def rollout_policy_fn(board):
|
10 |
"""a coarse, fast version of policy_fn used in the rollout phase."""
|
11 |
# rollout randomly
|
@@ -184,6 +189,48 @@ class MCTS(object):
|
|
184 |
else:
|
185 |
self._root = TreeNode(None, 1.0)
|
186 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
187 |
def __str__(self):
|
188 |
return "MCTS"
|
189 |
|
@@ -200,7 +247,7 @@ class MCTSPlayer(object):
|
|
200 |
def reset_player(self):
|
201 |
self.mcts.update_with_move(-1)
|
202 |
|
203 |
-
def get_action(self, board):
|
204 |
sensible_moves = board.availables
|
205 |
if len(sensible_moves) > 0:
|
206 |
move, simul_mean_time = self.mcts.get_move(board)
|
|
|
6 |
import time
|
7 |
|
8 |
|
9 |
+
def softmax(x):
|
10 |
+
probs = np.exp(x - np.max(x))
|
11 |
+
probs /= np.sum(probs)
|
12 |
+
return probs
|
13 |
+
|
14 |
def rollout_policy_fn(board):
|
15 |
"""a coarse, fast version of policy_fn used in the rollout phase."""
|
16 |
# rollout randomly
|
|
|
189 |
else:
|
190 |
self._root = TreeNode(None, 1.0)
|
191 |
|
192 |
+
def get_move_probs(self, state, temp=1e-3):
|
193 |
+
"""Run all playouts sequentially and return the available actions and
|
194 |
+
their corresponding probabilities.
|
195 |
+
state: the current game state
|
196 |
+
temp: temperature parameter in (0, 1] controls the level of exploration
|
197 |
+
"""
|
198 |
+
|
199 |
+
start_time_averge = 0
|
200 |
+
|
201 |
+
### test multi-thread
|
202 |
+
# lock = threading.Lock()
|
203 |
+
# with ThreadPoolExecutor(max_workers=4) as executor:
|
204 |
+
# for n in range(self._n_playout):
|
205 |
+
# start_time = time.time()
|
206 |
+
|
207 |
+
# state_copy = copy.deepcopy(state)
|
208 |
+
# executor.submit(self._playout, state_copy, lock)
|
209 |
+
# start_time_averge += (time.time() - start_time)
|
210 |
+
### end test multi-thread
|
211 |
+
|
212 |
+
t = time.time()
|
213 |
+
for n in range(self._n_playout):
|
214 |
+
start_time = time.time()
|
215 |
+
|
216 |
+
state_copy = copy.deepcopy(state)
|
217 |
+
self._playout(state_copy)
|
218 |
+
start_time_averge += (time.time() - start_time)
|
219 |
+
total_time = time.time() - t
|
220 |
+
# print('!!time!!:', time.time() - t)
|
221 |
+
|
222 |
+
print(f" My MCTS sum_time: {total_time}, total_simulation: {self._n_playout}")
|
223 |
+
|
224 |
+
# calc the move probabilities based on visit counts at the root node
|
225 |
+
act_visits = [(act, node._n_visits)
|
226 |
+
for act, node in self._root._children.items()]
|
227 |
+
|
228 |
+
acts, visits = zip(*act_visits)
|
229 |
+
|
230 |
+
act_probs = softmax(1.0 / temp * np.log(np.array(visits) + 1e-10))
|
231 |
+
|
232 |
+
return 0, acts, act_probs, total_time
|
233 |
+
|
234 |
def __str__(self):
|
235 |
return "MCTS"
|
236 |
|
|
|
247 |
def reset_player(self):
|
248 |
self.mcts.update_with_move(-1)
|
249 |
|
250 |
+
def get_action(self, board, return_time=False):
|
251 |
sensible_moves = board.availables
|
252 |
if len(sensible_moves) > 0:
|
253 |
move, simul_mean_time = self.mcts.get_move(board)
|
const.py
CHANGED
@@ -7,6 +7,8 @@ Description: Some const value for Demo
|
|
7 |
|
8 |
import numpy as np
|
9 |
|
|
|
|
|
10 |
_BOARD_SIZE = 8
|
11 |
_BOARD_SIZE_1D = _BOARD_SIZE * _BOARD_SIZE
|
12 |
_BLANK = 0
|
|
|
7 |
|
8 |
import numpy as np
|
9 |
|
10 |
+
_AI_AID_INFO = ["Use AI Aid", "Close AI Aid"]
|
11 |
+
|
12 |
_BOARD_SIZE = 8
|
13 |
_BOARD_SIZE_1D = _BOARD_SIZE * _BOARD_SIZE
|
14 |
_BLANK = 0
|
pages/Player_VS_AI.py
CHANGED
@@ -30,10 +30,13 @@ from const import (
|
|
30 |
_DIAGONAL_UP_LEFT,
|
31 |
_DIAGONAL_UP_RIGHT,
|
32 |
_BOARD_SIZE,
|
33 |
-
_BOARD_SIZE_1D
|
|
|
34 |
)
|
35 |
|
36 |
|
|
|
|
|
37 |
# Utils
|
38 |
class Room:
|
39 |
def __init__(self, room_id) -> None:
|
@@ -45,8 +48,10 @@ class Room:
|
|
45 |
self.HISTORY = (0, 0)
|
46 |
self.WINNER = _BLANK
|
47 |
self.TIME = time.time()
|
48 |
-
self.
|
49 |
-
|
|
|
|
|
50 |
self.COORDINATE_1D = [_BOARD_SIZE_1D + 1]
|
51 |
self.current_move = -1
|
52 |
self.simula_time_list = []
|
@@ -69,10 +74,41 @@ if "ROOMS" not in server_state:
|
|
69 |
with server_state_lock["ROOMS"]:
|
70 |
server_state.ROOMS = {}
|
71 |
|
72 |
-
|
73 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
74 |
TITLE = st.empty()
|
|
|
|
|
75 |
TITLE.header("๐ค AI 3603 Gomoku")
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
ROUND_INFO = st.empty()
|
77 |
st.markdown("<br>", unsafe_allow_html=True)
|
78 |
BOARD_PLATE = [
|
@@ -93,6 +129,11 @@ with st.sidebar.container():
|
|
93 |
RESTART = st.empty()
|
94 |
AIAID = st.empty()
|
95 |
EXIT = st.empty()
|
|
|
|
|
|
|
|
|
|
|
96 |
GAME_INFO = st.sidebar.container()
|
97 |
message = st.empty()
|
98 |
PLAY_MODE_INFO.write("---\n\n**You are Black, AI agent is White.**")
|
@@ -102,6 +143,7 @@ GAME_INFO.markdown(
|
|
102 |
# <span style="color:black;">Freestyle Gomoku game. ๐ฒ</span>
|
103 |
- no restrictions ๐ซ
|
104 |
- no regrets ๐
|
|
|
105 |
- swap players after one round is over ๐
|
106 |
Powered by an AlphaZero approach with our own improvements! ๐ For the specific details, please check out our <a href="insert_report_link_here" style="color:blue;">report</a>.
|
107 |
##### Adapted and improved by us! ๐ <a href="https://github.com/Lijiaxin0111/AI_3603_BIGHOME" style="color:blue;">Our Github repo</a>
|
@@ -110,6 +152,7 @@ GAME_INFO.markdown(
|
|
110 |
)
|
111 |
|
112 |
|
|
|
113 |
def restart() -> None:
|
114 |
"""
|
115 |
Restart the game.
|
@@ -217,14 +260,6 @@ def gomoku():
|
|
217 |
winner = _BLANK
|
218 |
return winner
|
219 |
|
220 |
-
def ai_aid() -> None:
|
221 |
-
"""
|
222 |
-
Use AI Aid.
|
223 |
-
"""
|
224 |
-
session_state.USE_AIAID = not session_state.USE_AIAID
|
225 |
-
print('Use AI Aid: ', session_state.USE_AIAID)
|
226 |
-
draw_board(False)
|
227 |
-
|
228 |
# Triggers the board response on click
|
229 |
def handle_click(x, y):
|
230 |
"""
|
@@ -270,11 +305,12 @@ def gomoku():
|
|
270 |
def draw_board(response: bool):
|
271 |
"""construct each buttons for all cells of the board"""
|
272 |
if session_state.USE_AIAID and session_state.ROOM.WINNER == _BLANK and session_state.ROOM.TURN == _BLACK:
|
273 |
-
|
274 |
-
|
275 |
-
|
276 |
-
|
277 |
-
|
|
|
278 |
if response and session_state.ROOM.TURN == _BLACK: # human turn
|
279 |
print("Your turn")
|
280 |
# construction of clickable buttons
|
@@ -333,11 +369,12 @@ def gomoku():
|
|
333 |
session_state.ROOM.COORDINATE_1D.append(gpt_i * _BOARD_SIZE + gpt_j)
|
334 |
|
335 |
if not session_state.ROOM.BOARD.game_end()[0]:
|
336 |
-
|
337 |
-
|
338 |
-
|
339 |
-
|
340 |
-
|
|
|
341 |
else:
|
342 |
top_five_acts = []
|
343 |
top_five_probs = []
|
@@ -449,12 +486,7 @@ def gomoku():
|
|
449 |
chart_data = pd.DataFrame(session_state.ROOM.simula_time_list, columns=["Simulation Time"])
|
450 |
st.line_chart(chart_data)
|
451 |
|
452 |
-
|
453 |
-
AIAID.button(
|
454 |
-
"Use AI Aid",
|
455 |
-
on_click=ai_aid,
|
456 |
-
help="Use AI Aid to help you make moves",
|
457 |
-
)
|
458 |
game_control()
|
459 |
update_info()
|
460 |
|
|
|
30 |
_DIAGONAL_UP_LEFT,
|
31 |
_DIAGONAL_UP_RIGHT,
|
32 |
_BOARD_SIZE,
|
33 |
+
_BOARD_SIZE_1D,
|
34 |
+
_AI_AID_INFO
|
35 |
)
|
36 |
|
37 |
|
38 |
+
|
39 |
+
|
40 |
# Utils
|
41 |
class Room:
|
42 |
def __init__(self, room_id) -> None:
|
|
|
48 |
self.HISTORY = (0, 0)
|
49 |
self.WINNER = _BLANK
|
50 |
self.TIME = time.time()
|
51 |
+
self.MCTS_dict = {'Pure MCTS': MCTSpure(c_puct=5, n_playout=10),
|
52 |
+
'AlphaZero': alphazero(PolicyValueNet(_BOARD_SIZE, _BOARD_SIZE, 'Gomoku_MCTS/checkpoints/best_policy_8_8_5_2torch.pth').policy_value_fn, c_puct=5, n_playout=100)}
|
53 |
+
self.MCTS = self.MCTS_dict['AlphaZero']
|
54 |
+
self.AID_MCTS = self.MCTS_dict['AlphaZero']
|
55 |
self.COORDINATE_1D = [_BOARD_SIZE_1D + 1]
|
56 |
self.current_move = -1
|
57 |
self.simula_time_list = []
|
|
|
74 |
with server_state_lock["ROOMS"]:
|
75 |
server_state.ROOMS = {}
|
76 |
|
77 |
+
def handle_oppo_model_selection():
|
78 |
+
TreeNode = session_state.ROOM.MCTS.mcts._root
|
79 |
+
new_mct = session_state.ROOM.MCTS_dict[st.session_state['selected_oppo_model']]
|
80 |
+
new_mct.mcts._root = deepcopy(TreeNode)
|
81 |
+
session_state.ROOM.MCTS = new_mct
|
82 |
+
return
|
83 |
+
|
84 |
+
def handle_aid_model_selection():
|
85 |
+
if st.session_state['selected_aid_model'] == 'None':
|
86 |
+
session_state.USE_AIAID = False
|
87 |
+
return
|
88 |
+
session_state.USE_AIAID = True
|
89 |
+
TreeNode = session_state.ROOM.MCTS.mcts._root # use the same tree node
|
90 |
+
new_mct = session_state.ROOM.MCTS_dict[st.session_state['selected_aid_model']]
|
91 |
+
new_mct.mcts._root = deepcopy(TreeNode)
|
92 |
+
session_state.ROOM.AID_MCTS = new_mct
|
93 |
+
return
|
94 |
+
|
95 |
+
if 'selected_oppo_model' not in st.session_state:
|
96 |
+
st.session_state['selected_oppo_model'] = 'AlphaZero' # ้ป่ฎคๅผ
|
97 |
+
|
98 |
+
if 'selected_aid_model' not in st.session_state:
|
99 |
+
st.session_state['selected_aid_model'] = 'AlphaZero' # ้ป่ฎคๅผ
|
100 |
+
|
101 |
+
# Layout
|
102 |
TITLE = st.empty()
|
103 |
+
Model_Switch = st.empty()
|
104 |
+
|
105 |
TITLE.header("๐ค AI 3603 Gomoku")
|
106 |
+
selected_oppo_option = Model_Switch.selectbox('Select Opponent Model', ['Pure MCTS', 'AlphaZero'], index=1, key='oppo_model')
|
107 |
+
|
108 |
+
if st.session_state['selected_oppo_model'] != selected_oppo_option:
|
109 |
+
st.session_state['selected_oppo_model'] = selected_oppo_option
|
110 |
+
handle_oppo_model_selection()
|
111 |
+
|
112 |
ROUND_INFO = st.empty()
|
113 |
st.markdown("<br>", unsafe_allow_html=True)
|
114 |
BOARD_PLATE = [
|
|
|
129 |
RESTART = st.empty()
|
130 |
AIAID = st.empty()
|
131 |
EXIT = st.empty()
|
132 |
+
selected_aid_option = AIAID.selectbox('Select Assistant Model', ['None', 'Pure MCTS', 'AlphaZero'], index=0, key='aid_model')
|
133 |
+
if st.session_state['selected_aid_model'] != selected_aid_option:
|
134 |
+
st.session_state['selected_aid_model'] = selected_aid_option
|
135 |
+
handle_aid_model_selection()
|
136 |
+
|
137 |
GAME_INFO = st.sidebar.container()
|
138 |
message = st.empty()
|
139 |
PLAY_MODE_INFO.write("---\n\n**You are Black, AI agent is White.**")
|
|
|
143 |
# <span style="color:black;">Freestyle Gomoku game. ๐ฒ</span>
|
144 |
- no restrictions ๐ซ
|
145 |
- no regrets ๐
|
146 |
+
- no regrets ๐
|
147 |
- swap players after one round is over ๐
|
148 |
Powered by an AlphaZero approach with our own improvements! ๐ For the specific details, please check out our <a href="insert_report_link_here" style="color:blue;">report</a>.
|
149 |
##### Adapted and improved by us! ๐ <a href="https://github.com/Lijiaxin0111/AI_3603_BIGHOME" style="color:blue;">Our Github repo</a>
|
|
|
152 |
)
|
153 |
|
154 |
|
155 |
+
|
156 |
def restart() -> None:
|
157 |
"""
|
158 |
Restart the game.
|
|
|
260 |
winner = _BLANK
|
261 |
return winner
|
262 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
263 |
# Triggers the board response on click
|
264 |
def handle_click(x, y):
|
265 |
"""
|
|
|
305 |
def draw_board(response: bool):
|
306 |
"""construct each buttons for all cells of the board"""
|
307 |
if session_state.USE_AIAID and session_state.ROOM.WINNER == _BLANK and session_state.ROOM.TURN == _BLACK:
|
308 |
+
if session_state.USE_AIAID:
|
309 |
+
copy_mcts = deepcopy(session_state.ROOM.AID_MCTS.mcts)
|
310 |
+
_, acts_aid, probs_aid, simul_mean_time_aid = copy_mcts.get_move_probs(session_state.ROOM.BOARD)
|
311 |
+
sorted_acts_probs = sorted(zip(acts_aid, probs_aid), key=lambda x: x[1], reverse=True)
|
312 |
+
top_five_acts = [act for act, prob in sorted_acts_probs[:5]]
|
313 |
+
top_five_probs = [prob for act, prob in sorted_acts_probs[:5]]
|
314 |
if response and session_state.ROOM.TURN == _BLACK: # human turn
|
315 |
print("Your turn")
|
316 |
# construction of clickable buttons
|
|
|
369 |
session_state.ROOM.COORDINATE_1D.append(gpt_i * _BOARD_SIZE + gpt_j)
|
370 |
|
371 |
if not session_state.ROOM.BOARD.game_end()[0]:
|
372 |
+
if session_state.USE_AIAID:
|
373 |
+
copy_mcts = deepcopy(session_state.ROOM.AID_MCTS.mcts)
|
374 |
+
_, acts_aid, probs_aid, simul_mean_time_aid = copy_mcts.get_move_probs(session_state.ROOM.BOARD)
|
375 |
+
sorted_acts_probs = sorted(zip(acts_aid, probs_aid), key=lambda x: x[1], reverse=True)
|
376 |
+
top_five_acts = [act for act, prob in sorted_acts_probs[:5]]
|
377 |
+
top_five_probs = [prob for act, prob in sorted_acts_probs[:5]]
|
378 |
else:
|
379 |
top_five_acts = []
|
380 |
top_five_probs = []
|
|
|
486 |
chart_data = pd.DataFrame(session_state.ROOM.simula_time_list, columns=["Simulation Time"])
|
487 |
st.line_chart(chart_data)
|
488 |
|
489 |
+
|
|
|
|
|
|
|
|
|
|
|
490 |
game_control()
|
491 |
update_info()
|
492 |
|