HuskyDoge commited on
Commit
e7a440c
โ€ข
1 Parent(s): 6d4d507

finish model selections

Browse files
Files changed (3) hide show
  1. Gomoku_MCTS/mcts_pure.py +48 -1
  2. const.py +2 -0
  3. pages/Player_VS_AI.py +61 -29
Gomoku_MCTS/mcts_pure.py CHANGED
@@ -6,6 +6,11 @@ from operator import itemgetter
6
  import time
7
 
8
 
 
 
 
 
 
9
  def rollout_policy_fn(board):
10
  """a coarse, fast version of policy_fn used in the rollout phase."""
11
  # rollout randomly
@@ -184,6 +189,48 @@ class MCTS(object):
184
  else:
185
  self._root = TreeNode(None, 1.0)
186
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
  def __str__(self):
188
  return "MCTS"
189
 
@@ -200,7 +247,7 @@ class MCTSPlayer(object):
200
  def reset_player(self):
201
  self.mcts.update_with_move(-1)
202
 
203
- def get_action(self, board):
204
  sensible_moves = board.availables
205
  if len(sensible_moves) > 0:
206
  move, simul_mean_time = self.mcts.get_move(board)
 
6
  import time
7
 
8
 
9
+ def softmax(x):
10
+ probs = np.exp(x - np.max(x))
11
+ probs /= np.sum(probs)
12
+ return probs
13
+
14
  def rollout_policy_fn(board):
15
  """a coarse, fast version of policy_fn used in the rollout phase."""
16
  # rollout randomly
 
189
  else:
190
  self._root = TreeNode(None, 1.0)
191
 
192
+ def get_move_probs(self, state, temp=1e-3):
193
+ """Run all playouts sequentially and return the available actions and
194
+ their corresponding probabilities.
195
+ state: the current game state
196
+ temp: temperature parameter in (0, 1] controls the level of exploration
197
+ """
198
+
199
+ start_time_averge = 0
200
+
201
+ ### test multi-thread
202
+ # lock = threading.Lock()
203
+ # with ThreadPoolExecutor(max_workers=4) as executor:
204
+ # for n in range(self._n_playout):
205
+ # start_time = time.time()
206
+
207
+ # state_copy = copy.deepcopy(state)
208
+ # executor.submit(self._playout, state_copy, lock)
209
+ # start_time_averge += (time.time() - start_time)
210
+ ### end test multi-thread
211
+
212
+ t = time.time()
213
+ for n in range(self._n_playout):
214
+ start_time = time.time()
215
+
216
+ state_copy = copy.deepcopy(state)
217
+ self._playout(state_copy)
218
+ start_time_averge += (time.time() - start_time)
219
+ total_time = time.time() - t
220
+ # print('!!time!!:', time.time() - t)
221
+
222
+ print(f" My MCTS sum_time: {total_time}, total_simulation: {self._n_playout}")
223
+
224
+ # calc the move probabilities based on visit counts at the root node
225
+ act_visits = [(act, node._n_visits)
226
+ for act, node in self._root._children.items()]
227
+
228
+ acts, visits = zip(*act_visits)
229
+
230
+ act_probs = softmax(1.0 / temp * np.log(np.array(visits) + 1e-10))
231
+
232
+ return 0, acts, act_probs, total_time
233
+
234
  def __str__(self):
235
  return "MCTS"
236
 
 
247
  def reset_player(self):
248
  self.mcts.update_with_move(-1)
249
 
250
+ def get_action(self, board, return_time=False):
251
  sensible_moves = board.availables
252
  if len(sensible_moves) > 0:
253
  move, simul_mean_time = self.mcts.get_move(board)
const.py CHANGED
@@ -7,6 +7,8 @@ Description: Some const value for Demo
7
 
8
  import numpy as np
9
 
 
 
10
  _BOARD_SIZE = 8
11
  _BOARD_SIZE_1D = _BOARD_SIZE * _BOARD_SIZE
12
  _BLANK = 0
 
7
 
8
  import numpy as np
9
 
10
+ _AI_AID_INFO = ["Use AI Aid", "Close AI Aid"]
11
+
12
  _BOARD_SIZE = 8
13
  _BOARD_SIZE_1D = _BOARD_SIZE * _BOARD_SIZE
14
  _BLANK = 0
pages/Player_VS_AI.py CHANGED
@@ -30,10 +30,13 @@ from const import (
30
  _DIAGONAL_UP_LEFT,
31
  _DIAGONAL_UP_RIGHT,
32
  _BOARD_SIZE,
33
- _BOARD_SIZE_1D
 
34
  )
35
 
36
 
 
 
37
  # Utils
38
  class Room:
39
  def __init__(self, room_id) -> None:
@@ -45,8 +48,10 @@ class Room:
45
  self.HISTORY = (0, 0)
46
  self.WINNER = _BLANK
47
  self.TIME = time.time()
48
- self.MCTS = MCTSpure(c_puct=5, n_playout=10)
49
- self.MCTS = alphazero(PolicyValueNet(_BOARD_SIZE, _BOARD_SIZE, 'Gomoku_MCTS/checkpoints/best_policy_8_8_5_2torch.pth').policy_value_fn, c_puct=5, n_playout=100)
 
 
50
  self.COORDINATE_1D = [_BOARD_SIZE_1D + 1]
51
  self.current_move = -1
52
  self.simula_time_list = []
@@ -69,10 +74,41 @@ if "ROOMS" not in server_state:
69
  with server_state_lock["ROOMS"]:
70
  server_state.ROOMS = {}
71
 
72
- # # Layout
73
- # Main
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  TITLE = st.empty()
 
 
75
  TITLE.header("๐Ÿค– AI 3603 Gomoku")
 
 
 
 
 
 
76
  ROUND_INFO = st.empty()
77
  st.markdown("<br>", unsafe_allow_html=True)
78
  BOARD_PLATE = [
@@ -93,6 +129,11 @@ with st.sidebar.container():
93
  RESTART = st.empty()
94
  AIAID = st.empty()
95
  EXIT = st.empty()
 
 
 
 
 
96
  GAME_INFO = st.sidebar.container()
97
  message = st.empty()
98
  PLAY_MODE_INFO.write("---\n\n**You are Black, AI agent is White.**")
@@ -102,6 +143,7 @@ GAME_INFO.markdown(
102
  # <span style="color:black;">Freestyle Gomoku game. ๐ŸŽฒ</span>
103
  - no restrictions ๐Ÿšซ
104
  - no regrets ๐Ÿ˜Ž
 
105
  - swap players after one round is over ๐Ÿ”
106
  Powered by an AlphaZero approach with our own improvements! ๐Ÿš€ For the specific details, please check out our <a href="insert_report_link_here" style="color:blue;">report</a>.
107
  ##### Adapted and improved by us! ๐ŸŒŸ <a href="https://github.com/Lijiaxin0111/AI_3603_BIGHOME" style="color:blue;">Our Github repo</a>
@@ -110,6 +152,7 @@ GAME_INFO.markdown(
110
  )
111
 
112
 
 
113
  def restart() -> None:
114
  """
115
  Restart the game.
@@ -217,14 +260,6 @@ def gomoku():
217
  winner = _BLANK
218
  return winner
219
 
220
- def ai_aid() -> None:
221
- """
222
- Use AI Aid.
223
- """
224
- session_state.USE_AIAID = not session_state.USE_AIAID
225
- print('Use AI Aid: ', session_state.USE_AIAID)
226
- draw_board(False)
227
-
228
  # Triggers the board response on click
229
  def handle_click(x, y):
230
  """
@@ -270,11 +305,12 @@ def gomoku():
270
  def draw_board(response: bool):
271
  """construct each buttons for all cells of the board"""
272
  if session_state.USE_AIAID and session_state.ROOM.WINNER == _BLANK and session_state.ROOM.TURN == _BLACK:
273
- copy_mcts = deepcopy(session_state.ROOM.MCTS.mcts)
274
- _, acts, probs, simul_mean_time = copy_mcts.get_move_probs(session_state.ROOM.BOARD)
275
- sorted_acts_probs = sorted(zip(acts, probs), key=lambda x: x[1], reverse=True)
276
- top_five_acts = [act for act, prob in sorted_acts_probs[:5]]
277
- top_five_probs = [prob for act, prob in sorted_acts_probs[:5]]
 
278
  if response and session_state.ROOM.TURN == _BLACK: # human turn
279
  print("Your turn")
280
  # construction of clickable buttons
@@ -333,11 +369,12 @@ def gomoku():
333
  session_state.ROOM.COORDINATE_1D.append(gpt_i * _BOARD_SIZE + gpt_j)
334
 
335
  if not session_state.ROOM.BOARD.game_end()[0]:
336
- copy_mcts = deepcopy(session_state.ROOM.MCTS.mcts)
337
- _, acts, probs, simul_mean_time = copy_mcts.get_move_probs(session_state.ROOM.BOARD)
338
- sorted_acts_probs = sorted(zip(acts, probs), key=lambda x: x[1], reverse=True)
339
- top_five_acts = [act for act, prob in sorted_acts_probs[:5]]
340
- top_five_probs = [prob for act, prob in sorted_acts_probs[:5]]
 
341
  else:
342
  top_five_acts = []
343
  top_five_probs = []
@@ -449,12 +486,7 @@ def gomoku():
449
  chart_data = pd.DataFrame(session_state.ROOM.simula_time_list, columns=["Simulation Time"])
450
  st.line_chart(chart_data)
451
 
452
- # The main game loop
453
- AIAID.button(
454
- "Use AI Aid",
455
- on_click=ai_aid,
456
- help="Use AI Aid to help you make moves",
457
- )
458
  game_control()
459
  update_info()
460
 
 
30
  _DIAGONAL_UP_LEFT,
31
  _DIAGONAL_UP_RIGHT,
32
  _BOARD_SIZE,
33
+ _BOARD_SIZE_1D,
34
+ _AI_AID_INFO
35
  )
36
 
37
 
38
+
39
+
40
  # Utils
41
  class Room:
42
  def __init__(self, room_id) -> None:
 
48
  self.HISTORY = (0, 0)
49
  self.WINNER = _BLANK
50
  self.TIME = time.time()
51
+ self.MCTS_dict = {'Pure MCTS': MCTSpure(c_puct=5, n_playout=10),
52
+ 'AlphaZero': alphazero(PolicyValueNet(_BOARD_SIZE, _BOARD_SIZE, 'Gomoku_MCTS/checkpoints/best_policy_8_8_5_2torch.pth').policy_value_fn, c_puct=5, n_playout=100)}
53
+ self.MCTS = self.MCTS_dict['AlphaZero']
54
+ self.AID_MCTS = self.MCTS_dict['AlphaZero']
55
  self.COORDINATE_1D = [_BOARD_SIZE_1D + 1]
56
  self.current_move = -1
57
  self.simula_time_list = []
 
74
  with server_state_lock["ROOMS"]:
75
  server_state.ROOMS = {}
76
 
77
+ def handle_oppo_model_selection():
78
+ TreeNode = session_state.ROOM.MCTS.mcts._root
79
+ new_mct = session_state.ROOM.MCTS_dict[st.session_state['selected_oppo_model']]
80
+ new_mct.mcts._root = deepcopy(TreeNode)
81
+ session_state.ROOM.MCTS = new_mct
82
+ return
83
+
84
+ def handle_aid_model_selection():
85
+ if st.session_state['selected_aid_model'] == 'None':
86
+ session_state.USE_AIAID = False
87
+ return
88
+ session_state.USE_AIAID = True
89
+ TreeNode = session_state.ROOM.MCTS.mcts._root # use the same tree node
90
+ new_mct = session_state.ROOM.MCTS_dict[st.session_state['selected_aid_model']]
91
+ new_mct.mcts._root = deepcopy(TreeNode)
92
+ session_state.ROOM.AID_MCTS = new_mct
93
+ return
94
+
95
+ if 'selected_oppo_model' not in st.session_state:
96
+ st.session_state['selected_oppo_model'] = 'AlphaZero' # ้ป˜่ฎคๅ€ผ
97
+
98
+ if 'selected_aid_model' not in st.session_state:
99
+ st.session_state['selected_aid_model'] = 'AlphaZero' # ้ป˜่ฎคๅ€ผ
100
+
101
+ # Layout
102
  TITLE = st.empty()
103
+ Model_Switch = st.empty()
104
+
105
  TITLE.header("๐Ÿค– AI 3603 Gomoku")
106
+ selected_oppo_option = Model_Switch.selectbox('Select Opponent Model', ['Pure MCTS', 'AlphaZero'], index=1, key='oppo_model')
107
+
108
+ if st.session_state['selected_oppo_model'] != selected_oppo_option:
109
+ st.session_state['selected_oppo_model'] = selected_oppo_option
110
+ handle_oppo_model_selection()
111
+
112
  ROUND_INFO = st.empty()
113
  st.markdown("<br>", unsafe_allow_html=True)
114
  BOARD_PLATE = [
 
129
  RESTART = st.empty()
130
  AIAID = st.empty()
131
  EXIT = st.empty()
132
+ selected_aid_option = AIAID.selectbox('Select Assistant Model', ['None', 'Pure MCTS', 'AlphaZero'], index=0, key='aid_model')
133
+ if st.session_state['selected_aid_model'] != selected_aid_option:
134
+ st.session_state['selected_aid_model'] = selected_aid_option
135
+ handle_aid_model_selection()
136
+
137
  GAME_INFO = st.sidebar.container()
138
  message = st.empty()
139
  PLAY_MODE_INFO.write("---\n\n**You are Black, AI agent is White.**")
 
143
  # <span style="color:black;">Freestyle Gomoku game. ๐ŸŽฒ</span>
144
  - no restrictions ๐Ÿšซ
145
  - no regrets ๐Ÿ˜Ž
146
+ - no regrets ๐Ÿ˜Ž
147
  - swap players after one round is over ๐Ÿ”
148
  Powered by an AlphaZero approach with our own improvements! ๐Ÿš€ For the specific details, please check out our <a href="insert_report_link_here" style="color:blue;">report</a>.
149
  ##### Adapted and improved by us! ๐ŸŒŸ <a href="https://github.com/Lijiaxin0111/AI_3603_BIGHOME" style="color:blue;">Our Github repo</a>
 
152
  )
153
 
154
 
155
+
156
  def restart() -> None:
157
  """
158
  Restart the game.
 
260
  winner = _BLANK
261
  return winner
262
 
 
 
 
 
 
 
 
 
263
  # Triggers the board response on click
264
  def handle_click(x, y):
265
  """
 
305
  def draw_board(response: bool):
306
  """construct each buttons for all cells of the board"""
307
  if session_state.USE_AIAID and session_state.ROOM.WINNER == _BLANK and session_state.ROOM.TURN == _BLACK:
308
+ if session_state.USE_AIAID:
309
+ copy_mcts = deepcopy(session_state.ROOM.AID_MCTS.mcts)
310
+ _, acts_aid, probs_aid, simul_mean_time_aid = copy_mcts.get_move_probs(session_state.ROOM.BOARD)
311
+ sorted_acts_probs = sorted(zip(acts_aid, probs_aid), key=lambda x: x[1], reverse=True)
312
+ top_five_acts = [act for act, prob in sorted_acts_probs[:5]]
313
+ top_five_probs = [prob for act, prob in sorted_acts_probs[:5]]
314
  if response and session_state.ROOM.TURN == _BLACK: # human turn
315
  print("Your turn")
316
  # construction of clickable buttons
 
369
  session_state.ROOM.COORDINATE_1D.append(gpt_i * _BOARD_SIZE + gpt_j)
370
 
371
  if not session_state.ROOM.BOARD.game_end()[0]:
372
+ if session_state.USE_AIAID:
373
+ copy_mcts = deepcopy(session_state.ROOM.AID_MCTS.mcts)
374
+ _, acts_aid, probs_aid, simul_mean_time_aid = copy_mcts.get_move_probs(session_state.ROOM.BOARD)
375
+ sorted_acts_probs = sorted(zip(acts_aid, probs_aid), key=lambda x: x[1], reverse=True)
376
+ top_five_acts = [act for act, prob in sorted_acts_probs[:5]]
377
+ top_five_probs = [prob for act, prob in sorted_acts_probs[:5]]
378
  else:
379
  top_five_acts = []
380
  top_five_probs = []
 
486
  chart_data = pd.DataFrame(session_state.ROOM.simula_time_list, columns=["Simulation Time"])
487
  st.line_chart(chart_data)
488
 
489
+
 
 
 
 
 
490
  game_control()
491
  update_info()
492