HuskyDoge commited on
Commit
172a1e4
1 Parent(s): fc602f9
Files changed (33) hide show
  1. .DS_Store +0 -0
  2. Gomoku_MCTS/.DS_Store +0 -0
  3. Gomoku_MCTS/__init__.py +142 -0
  4. Gomoku_MCTS/__pycache__/__init__.cpython-310.pyc +0 -0
  5. Gomoku_MCTS/__pycache__/dueling_net.cpython-310.pyc +0 -0
  6. Gomoku_MCTS/__pycache__/game.cpython-310.pyc +0 -0
  7. Gomoku_MCTS/__pycache__/mcts_alphaZero.cpython-310.pyc +0 -0
  8. Gomoku_MCTS/__pycache__/mcts_pure.cpython-310.pyc +0 -0
  9. Gomoku_MCTS/checkpoints/best_policy_8_8_5_2torch.pth +3 -0
  10. Gomoku_MCTS/config/config.yaml +10 -0
  11. Gomoku_MCTS/config/options.py +74 -0
  12. Gomoku_MCTS/config/utils.py +54 -0
  13. Gomoku_MCTS/dueling_net.py +155 -0
  14. Gomoku_MCTS/game.py +281 -0
  15. Gomoku_MCTS/main_worker.py +334 -0
  16. Gomoku_MCTS/mcts_alphaZero.py +250 -0
  17. Gomoku_MCTS/mcts_pure.py +246 -0
  18. Gomoku_MCTS/policy_value_net_pytorch.py +159 -0
  19. Gomoku_MCTS/visualization/_blip_uni_cross_mu_bs512_lr0.002/events.out.tfevents.1700183498.LAPTOP-5AN2UHOO +3 -0
  20. Gomoku_MCTS/visualization/_blip_uni_cross_mu_bs512_lr0.002/events.out.tfevents.1700183516.LAPTOP-5AN2UHOO +3 -0
  21. Gomoku_MCTS/visualization/_blip_uni_cross_mu_bs512_lr0.002/events.out.tfevents.1700183568.LAPTOP-5AN2UHOO +3 -0
  22. Gomoku_MCTS/visualization/_blip_uni_cross_mu_bs512_lr0.002/events.out.tfevents.1700183629.LAPTOP-5AN2UHOO +3 -0
  23. Gomoku_MCTS/visualization/_blip_uni_cross_mu_bs512_lr0.002/events.out.tfevents.1700183640.LAPTOP-5AN2UHOO +3 -0
  24. Gomoku_MCTS/visualization/_blip_uni_cross_mu_bs512_lr0.002/events.out.tfevents.1700183667.LAPTOP-5AN2UHOO +3 -0
  25. Gomoku_MCTS/visualization/_blip_uni_cross_mu_bs512_lr0.002/events.out.tfevents.1700183756.LAPTOP-5AN2UHOO +3 -0
  26. Gomoku_MCTS/visualization/_blip_uni_cross_mu_bs512_lr0.002/events.out.tfevents.1700183820.LAPTOP-5AN2UHOO +3 -0
  27. Gomoku_MCTS/visualization/_blip_uni_cross_mu_bs512_lr0.002/events.out.tfevents.1700184097.LAPTOP-5AN2UHOO +3 -0
  28. README.md +3 -3
  29. app.py +56 -0
  30. assets/favicon_circle.png +0 -0
  31. const.py +58 -0
  32. pages/Player_VS_AI.py +409 -0
  33. requirements.txt +7 -0
.DS_Store ADDED
Binary file (6.15 kB). View file
 
Gomoku_MCTS/.DS_Store ADDED
Binary file (6.15 kB). View file
 
Gomoku_MCTS/__init__.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .mcts_pure import MCTSPlayer as MCTSpure
2
+ from .mcts_alphaZero import MCTSPlayer as alphazero
3
+ from .dueling_net import PolicyValueNet
4
+ import numpy as np
5
+
6
+
7
+ class Board(object):
8
+ """board for the game"""
9
+
10
+ def __init__(self, **kwargs):
11
+ self.last_move = None
12
+ self.availables = None
13
+ self.current_player = None
14
+ self.width = int(kwargs.get('width', 8)) # if no width, default 8
15
+ self.height = int(kwargs.get('height', 8))
16
+ self.board_map = np.zeros(shape=(self.width, self.height), dtype=int)
17
+ # board states stored as a dict,
18
+ # key: move as location on the board,
19
+ # value: player as pieces type
20
+ self.states = {}
21
+ # need how many pieces in a row to win
22
+ self.n_in_row = int(kwargs.get('n_in_row', 5))
23
+ self.players = kwargs.get('players', [1, 2]) # player1 and player2
24
+ self.init_board(0)
25
+
26
+ def init_board(self, start_player=0):
27
+ if self.width < self.n_in_row or self.height < self.n_in_row:
28
+ raise Exception('board width and height can not be '
29
+ 'less than {}'.format(self.n_in_row))
30
+ self.current_player = self.players[start_player] # start player
31
+ # keep available moves in a list
32
+ self.availables = list(range(self.width * self.height))
33
+ self.states = {}
34
+ self.last_move = -1
35
+
36
+ def move_to_location(self, move: int):
37
+ """
38
+ 3*3 board's moves like:
39
+ 6 7 8
40
+ 3 4 5
41
+ 0 1 2
42
+ and move 5's location is (1,2)
43
+ """
44
+ h = move // self.width
45
+ w = move % self.width
46
+ return [h, w]
47
+
48
+ def location_to_move(self, location):
49
+ if len(location) != 2:
50
+ return -1
51
+ h = location[0]
52
+ w = location[1]
53
+ move = h * self.width + w
54
+ if move not in range(self.width * self.height):
55
+ return -1
56
+ return move
57
+
58
+ def current_state(self):
59
+ """
60
+ return the board state from the perspective of the current player.
61
+ state shape: 4*width*height
62
+ 这个状态数组具有四个通道:
63
+ 第一个通道表示当前玩家的棋子位置,第二个通道表示对手的棋子位置,第三个通道表示最后一步移动的位置。
64
+ 第四个通道是一个指示符,用于表示当前轮到哪个玩家(如果棋盘上的总移动次数是偶数,那么这个通道的所有元素都为1,表示是第一个玩家的回合;否则,所有元素都为0,表示是第二个玩家的回合)。
65
+ 每个通道都是一个 width x height 的二维数组,代表着棋盘的布局。对于第一个和第二个通道,如果一个位置上有当前玩家或对手的棋子,那么该位置的值为 1,否则为0。
66
+ 对于第三个通道,只有最后一步移动的位置是1,其余位置都为0。对于第四个通道,如果是第一个玩家的回合,那么所有的位置都是1,否则都是0。
67
+ 最后,状态数组在垂直方向上翻转,以匹配棋盘的实际布局。
68
+ """
69
+
70
+ square_state = np.zeros((4, self.width, self.height))
71
+ if self.states:
72
+ moves, players = np.array(list(zip(*self.states.items())))
73
+ move_curr = moves[players == self.current_player]
74
+ move_oppo = moves[players != self.current_player]
75
+ square_state[0][move_curr // self.width,
76
+ move_curr % self.height] = 1.0
77
+ square_state[1][move_oppo // self.width,
78
+ move_oppo % self.height] = 1.0
79
+ # indicate the last move location
80
+ square_state[2][self.last_move // self.width,
81
+ self.last_move % self.height] = 1.0
82
+ if len(self.states) % 2 == 0:
83
+ square_state[3][:, :] = 1.0 # indicate the colour to play
84
+ return square_state[:, ::-1, :]
85
+
86
+ def do_move(self, move):
87
+ self.states[move] = self.current_player
88
+ # get (x,y) of this move
89
+ x, y = self.move_to_location(move)
90
+ self.board_map[x][y] = self.current_player
91
+
92
+ self.availables.remove(move)
93
+ self.current_player = (
94
+ self.players[0] if self.current_player == self.players[1]
95
+ else self.players[1]
96
+ )
97
+ self.last_move = move
98
+
99
+ def has_a_winner(self):
100
+ width = self.width
101
+ height = self.height
102
+ states = self.states
103
+ n = self.n_in_row
104
+
105
+ moved = list(set(range(width * height)) - set(self.availables))
106
+ if len(moved) < self.n_in_row * 2 - 1:
107
+ return False, -1
108
+
109
+ for m in moved:
110
+ h = m // width
111
+ w = m % width
112
+ player = states[m]
113
+
114
+ if (w in range(width - n + 1) and
115
+ len(set(states.get(i, -1) for i in range(m, m + n))) == 1):
116
+ return True, player
117
+
118
+ if (h in range(height - n + 1) and
119
+ len(set(states.get(i, -1) for i in range(m, m + n * width, width))) == 1):
120
+ return True, player
121
+
122
+ if (w in range(width - n + 1) and h in range(height - n + 1) and
123
+ len(set(states.get(i, -1) for i in range(m, m + n * (width + 1), width + 1))) == 1):
124
+ return True, player
125
+
126
+ if (w in range(n - 1, width) and h in range(height - n + 1) and
127
+ len(set(states.get(i, -1) for i in range(m, m + n * (width - 1), width - 1))) == 1):
128
+ return True, player
129
+
130
+ return False, -1
131
+
132
+ def game_end(self):
133
+ """Check whether the game is ended or not"""
134
+ win, winner = self.has_a_winner()
135
+ if win:
136
+ return True, winner
137
+ elif not len(self.availables):
138
+ return True, -1
139
+ return False, -1
140
+
141
+ def get_current_player(self):
142
+ return self.current_player
Gomoku_MCTS/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (5.41 kB). View file
 
Gomoku_MCTS/__pycache__/dueling_net.cpython-310.pyc ADDED
Binary file (4.71 kB). View file
 
Gomoku_MCTS/__pycache__/game.cpython-310.pyc ADDED
Binary file (8.97 kB). View file
 
Gomoku_MCTS/__pycache__/mcts_alphaZero.cpython-310.pyc ADDED
Binary file (8.05 kB). View file
 
Gomoku_MCTS/__pycache__/mcts_pure.cpython-310.pyc ADDED
Binary file (8.73 kB). View file
 
Gomoku_MCTS/checkpoints/best_policy_8_8_5_2torch.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:878aace7c41962e0817fe8298a1f260b3b83e71c24d7d8c3558ccd6c4996d4f8
3
+ size 481383
Gomoku_MCTS/config/config.yaml ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ # ckpt/logger options(dynamic)
2
+ checkpoint_base: checkpoint
3
+ visual_base: visualization
4
+ log_base: log
5
+
6
+ # dataset
7
+ data_base: dataset
8
+
9
+
10
+
Gomoku_MCTS/config/options.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ import argparse
4
+ import yaml
5
+
6
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
7
+
8
+ # basic settings
9
+ parser.add_argument('--seed', default=1234, type=int)
10
+ parser.add_argument('--savepath', type=str, default="blip_uni_cross_mu", help='')
11
+
12
+
13
+ # board settings
14
+ parser.add_argument("--board_width", type=int,default=9)
15
+ parser.add_argument("--board_height", type=int,default=9)
16
+ parser.add_argument("--n_in_row", type=int,default=5,help="the condition of winning")
17
+
18
+
19
+ # device settings
20
+ parser.add_argument('--config', type=str, default='config/config.yaml', help='Path to the config file.')
21
+ parser.add_argument('--gpu_num', type=int, default=1)
22
+ parser.add_argument('--gpu_id', type=str, default='5')
23
+
24
+
25
+ # save options
26
+ parser.add_argument('--clear_visualizer', dest='clear_visualizer', action='store_true')
27
+ parser.add_argument('--std_log', dest='std_log', action='store_true')
28
+
29
+
30
+ # mode settings
31
+ parser.add_argument("--split",type=str,default="train",help="the mode of woker")
32
+
33
+
34
+ # train settings
35
+ parser.add_argument("--expri",type=str, default="",help="the name of experiment")
36
+ parser.add_argument("--learn_rate", type=float,default=2e-3)
37
+ parser.add_argument("--l2_const",type=float,default=1e-4)
38
+ # ???
39
+ parser.add_argument("--lr_multiplier", type=float,default= 1.0 ,help="adaptively adjust the learning rate based on KL")
40
+ parser.add_argument("--buffer_size",type=int,default=10000,help="The size of collection of game data ")
41
+ parser.add_argument("--batch_size",type=int,default=512)
42
+ parser.add_argument("--play_batch_size",type=int, default=1,help="The time of selfplaying when collect the data")
43
+ parser.add_argument("--epochs",type=int,default=5,help="num of train_steps for each update")
44
+ parser.add_argument("--kl_targ",type=float,default=0.02,help="the target kl distance between the old decision function and the new decision function ")
45
+ parser.add_argument("--check_freq",type=int,default=50,help='the frequence of the checking the win ratio when training')
46
+ parser.add_argument("--game_batch_num",type=int,default=1500,help = "the total training times")
47
+
48
+
49
+ # parser.add_argument("--l2_const",type=float,default=1e-4,help=" coef of l2 penalty")
50
+ parser.add_argument("--distributed",type=bool,default=False)
51
+
52
+ # preload_model setting
53
+ parser.add_argument("--preload_model",type=str, default="")
54
+
55
+
56
+ # Alphazero agent setting
57
+ parser.add_argument("--temp", type=float,default= 1.0 ,help="the temperature parameter when calculate the decision function getting the next action")
58
+ parser.add_argument("--n_playout",type=int, default=200, help="num of simulations for each move ")
59
+ parser.add_argument("--c_puct",type=int, default=5, help= "the balance parameter between exploration and exploitative ")
60
+
61
+ # prue_mcts agent setting
62
+ parser.add_argument("--pure_mcts_playout_num",type=int, default=200)
63
+
64
+ # test settings
65
+ parser.add_argument('--test_ckpt', type=str, default=None, help='ckpt absolute path')
66
+
67
+
68
+ opts = parser.parse_args()
69
+
70
+ # additional parameters
71
+ current_path = os.path.abspath(__file__)
72
+ grandfather_path = os.path.abspath(os.path.dirname(os.path.dirname(current_path)) + os.path.sep + ".")
73
+ with open(os.path.join(grandfather_path, opts.config), 'r') as stream:
74
+ config = yaml.full_load(stream)
Gomoku_MCTS/config/utils.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, shutil
2
+ import torch
3
+ from tensorboardX import SummaryWriter
4
+ from config.options import *
5
+ import torch.distributed as dist
6
+ import time
7
+
8
+ """ ==================== Save ======================== """
9
+
10
+ def make_path():
11
+ return "{}_{}_bs{}_lr{}".format(opts.expri,opts.savepath,opts.batch_size,opts.learn_rate)
12
+
13
+
14
+
15
+
16
+ def save_model(model,name):
17
+ save_path = make_path()
18
+ if not os.path.isdir(os.path.join(config['checkpoint_base'], save_path)):
19
+ os.makedirs(os.path.join(config['checkpoint_base'], save_path), exist_ok=True)
20
+ model_name = os.path.join(config['checkpoint_base'], save_path, name)
21
+ torch.save(model.state_dict(), model_name)
22
+
23
+
24
+
25
+
26
+ """ ==================== Tools ======================== """
27
+ def is_dist_avail_and_initialized():
28
+ if not dist.is_available():
29
+ return False
30
+ if not dist.is_initialized():
31
+ return False
32
+ return True
33
+
34
+ def get_rank():
35
+ if not is_dist_avail_and_initialized():
36
+ return 0
37
+ return dist.get_rank()
38
+
39
+
40
+ def makedir(path):
41
+ if not os.path.exists(path):
42
+ os.makedirs(path, 0o777)
43
+
44
+
45
+ def visualizer():
46
+ if get_rank() == 0:
47
+ # filewriter_path = config['visual_base']+opts.savepath+'/'
48
+ save_path = make_path()
49
+ filewriter_path = os.path.join(config['visual_base'], save_path)
50
+ if opts.clear_visualizer and os.path.exists(filewriter_path): # 删掉以前的summary,以免重合
51
+ shutil.rmtree(filewriter_path)
52
+ makedir(filewriter_path)
53
+ writer = SummaryWriter(filewriter_path, comment='visualizer')
54
+ return writer
Gomoku_MCTS/dueling_net.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ import torch.nn.functional as F
5
+ from torch.autograd import Variable
6
+ import numpy as np
7
+
8
+ def set_learning_rate(optimizer, lr):
9
+ """Sets the learning rate to the given value"""
10
+ for param_group in optimizer.param_groups:
11
+ param_group['lr'] = lr
12
+
13
+ class DuelingDQNNet(nn.Module):
14
+ """Dueling DQN network module"""
15
+ def __init__(self, board_width, board_height):
16
+ super(DuelingDQNNet, self).__init__()
17
+
18
+ self.board_width = board_width
19
+ self.board_height = board_height
20
+ # common layers
21
+ self.conv1 = nn.Conv2d(4, 32, kernel_size=3, padding=1)
22
+ self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
23
+ self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
24
+ # advantage layers
25
+ self.adv_conv1 = nn.Conv2d(128, 4, kernel_size=1)
26
+ self.adv_fc1 = nn.Linear(4*board_width*board_height,
27
+ board_width*board_height)
28
+ # value layers
29
+ self.val_conv1 = nn.Conv2d(128, 2, kernel_size=1)
30
+ self.val_fc1 = nn.Linear(2*board_width*board_height, 64)
31
+ self.val_fc2 = nn.Linear(64, 1)
32
+
33
+ def forward(self, state_input):
34
+ # common layers
35
+ x = F.relu(self.conv1(state_input))
36
+ x = F.relu(self.conv2(x))
37
+ x = F.relu(self.conv3(x))
38
+
39
+ # advantage stream
40
+ adv = F.relu(self.adv_conv1(x))
41
+ adv = adv.view(-1, 4*self.board_width*self.board_height)
42
+ adv = self.adv_fc1(adv)
43
+
44
+ # value stream
45
+ val = F.relu(self.val_conv1(x))
46
+ val = val.view(-1, 2*self.board_width*self.board_height)
47
+ val = F.relu(self.val_fc1(val))
48
+ val = self.val_fc2(val)
49
+
50
+ q_values = val + adv - adv.mean(dim=1, keepdim=True)
51
+
52
+ return F.log_softmax(q_values, dim=1), val
53
+
54
+ class PolicyValueNet():
55
+ """policy-value network """
56
+ def __init__(self, board_width, board_height,
57
+ model_file=None, use_gpu=False):
58
+ self.use_gpu = use_gpu
59
+ self.board_width = board_width
60
+ self.board_height = board_height
61
+ self.l2_const = 1e-4 # coef of l2 penalty
62
+ # the policy value net module
63
+ if self.use_gpu:
64
+ self.policy_value_net = DuelingDQNNet(board_width, board_height).cuda()
65
+ else:
66
+ self.policy_value_net = DuelingDQNNet(board_width, board_height)
67
+ self.optimizer = optim.Adam(self.policy_value_net.parameters(),
68
+ weight_decay=self.l2_const)
69
+
70
+ if model_file:
71
+ net_params = torch.load(model_file)
72
+ self.policy_value_net.load_state_dict(net_params, strict=False)
73
+
74
+ def policy_value(self, state_batch):
75
+ """
76
+ input: a batch of states
77
+ output: a batch of action probabilities and state values
78
+ """
79
+ if self.use_gpu:
80
+ state_batch = Variable(torch.FloatTensor(state_batch).cuda())
81
+ log_act_probs, value = self.policy_value_net(state_batch)
82
+ act_probs = np.exp(log_act_probs.data.cpu().numpy())
83
+ return act_probs, value.data.cpu().numpy()
84
+ else:
85
+ state_batch = Variable(torch.FloatTensor(state_batch))
86
+ log_act_probs, value = self.policy_value_net(state_batch)
87
+ act_probs = np.exp(log_act_probs.data.numpy())
88
+ return act_probs, value.data.numpy()
89
+
90
+ def policy_value_fn(self, board):
91
+ """
92
+ input: board
93
+ output: a list of (action, probability) tuples for each available
94
+ action and the score of the board state
95
+ """
96
+ legal_positions = board.availables
97
+ current_state = np.ascontiguousarray(board.current_state().reshape(
98
+ -1, 4, self.board_width, self.board_height))
99
+ if self.use_gpu:
100
+ log_act_probs, value = self.policy_value_net(
101
+ Variable(torch.from_numpy(current_state)).cuda().float())
102
+ act_probs = np.exp(log_act_probs.data.cpu().numpy().flatten())
103
+ else:
104
+ log_act_probs, value = self.policy_value_net(
105
+ Variable(torch.from_numpy(current_state)).float())
106
+ act_probs = np.exp(log_act_probs.data.numpy().flatten())
107
+ act_probs = zip(legal_positions, act_probs[legal_positions])
108
+ value = value.data[0][0]
109
+ return act_probs, value
110
+
111
+ def train_step(self, state_batch, mcts_probs, winner_batch, lr):
112
+ """perform a training step"""
113
+
114
+ # self.use_gpu = True
115
+ # wrap in Variable
116
+ if self.use_gpu:
117
+ state_batch = Variable(torch.FloatTensor(state_batch).cuda())
118
+ mcts_probs = Variable(torch.FloatTensor(mcts_probs).cuda())
119
+ winner_batch = Variable(torch.FloatTensor(winner_batch).cuda())
120
+ else:
121
+ state_batch = Variable(torch.FloatTensor(state_batch))
122
+ mcts_probs = Variable(torch.FloatTensor(mcts_probs))
123
+ winner_batch = Variable(torch.FloatTensor(winner_batch))
124
+
125
+ # zero the parameter gradients
126
+ self.optimizer.zero_grad()
127
+ # set learning rate
128
+ set_learning_rate(self.optimizer, lr)
129
+
130
+ # forward
131
+ log_act_probs, value = self.policy_value_net(state_batch)
132
+ # define the loss = (z - v)^2 - pi^T * log(p) + c||theta||^2
133
+ # Note: the L2 penalty is incorporated in optimizer
134
+ value_loss = F.mse_loss(value.view(-1), winner_batch)
135
+ policy_loss = -torch.mean(torch.sum(mcts_probs*log_act_probs, 1))
136
+ loss = value_loss + policy_loss
137
+ # backward and optimize
138
+ loss.backward()
139
+ self.optimizer.step()
140
+ # calc policy entropy, for monitoring only
141
+ entropy = -torch.mean(
142
+ torch.sum(torch.exp(log_act_probs) * log_act_probs, 1)
143
+ )
144
+ # return loss.data[0], entropy.data[0]
145
+ #for pytorch version >= 0.5 please use the following line instead.
146
+ return loss.item(), entropy.item()
147
+
148
+ def get_policy_param(self):
149
+ net_params = self.policy_value_net.state_dict()
150
+ return net_params
151
+
152
+ def save_model(self, model_file):
153
+ """ save model params to file """
154
+ net_params = self.get_policy_param() # get model params
155
+ torch.save(net_params, model_file)
Gomoku_MCTS/game.py ADDED
@@ -0,0 +1,281 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FileName: game.py
3
+ Author: Jiaxin Li
4
+ Create Date: yyyy/mm/dd
5
+ Description: to be completed
6
+ Edit History:
7
+ - 2023/11/18, Sat, Edited by Hbh (hbh001098hbh@sjtu.edu.cn)
8
+ - added some comments and optimize import and some structures
9
+ """
10
+
11
+ import numpy as np
12
+ from mcts_pure import MCTSPlayer as MCTS_Pure
13
+ from mcts_pure import Human_Player
14
+ from collections import defaultdict
15
+ from typing import Optional
16
+
17
+
18
+ class Board(object):
19
+ """board for the game"""
20
+
21
+ def __init__(self, **kwargs):
22
+ self.last_move = None
23
+ self.availables = None
24
+ self.current_player = None
25
+ self.width = int(kwargs.get('width', 8)) # if no width, default 8
26
+ self.height = int(kwargs.get('height', 8))
27
+ # board states stored as a dict,
28
+ # key: move as location on the board,
29
+ # value: player as pieces type
30
+ self.states = {}
31
+ # need how many pieces in a row to win
32
+ self.n_in_row = int(kwargs.get('n_in_row', 5))
33
+ self.players = [1, 2] # player1 and player2
34
+
35
+ def init_board(self, start_player=0):
36
+ if self.width < self.n_in_row or self.height < self.n_in_row:
37
+ raise Exception('board width and height can not be '
38
+ 'less than {}'.format(self.n_in_row))
39
+ self.current_player = self.players[start_player] # start player
40
+ # keep available moves in a list
41
+ self.availables = list(range(self.width * self.height))
42
+ self.states = {}
43
+ self.last_move = -1
44
+
45
+ def move_to_location(self, move: int):
46
+ """
47
+ 3*3 board's moves like:
48
+ 6 7 8
49
+ 3 4 5
50
+ 0 1 2
51
+ and move 5's location is (1,2)
52
+ """
53
+ h = move // self.width
54
+ w = move % self.width
55
+ return [h, w]
56
+
57
+ def location_to_move(self, location):
58
+ if len(location) != 2:
59
+ return -1
60
+ h = location[0]
61
+ w = location[1]
62
+ move = h * self.width + w
63
+ if move not in range(self.width * self.height):
64
+ return -1
65
+ return move
66
+
67
+ def current_state(self):
68
+ """
69
+ return the board state from the perspective of the current player.
70
+ state shape: 4*width*height
71
+ 这个状态数组具有四个通道:
72
+ 第一个通道表示当前玩家的棋子位置,第二个通道表示对手的棋子位置,第三个通道表示最后一步移动的位置。
73
+ 第四个通道是一个指示符,用于表示当前轮到哪个玩家(如果棋盘上的总移动次数是偶数,那么这个通道的所有元素都为1,表示是第一个玩家的回合;否则,所有元素都为0,表示是第二个玩家的回合)。
74
+ 每个通道都是一个 width x height 的二维数组,代表着棋盘的布局。对于第一个和第二个通道,如果一个位置上有当前玩家或对手的棋子,那么该位置的值为 1,否则为0。
75
+ 对于第三个通道,只有最后一步移动的位置是1,其余位置都为0。对于第四个通道,如果是第一个玩家的回合,那么所有的位置都是1,否则都是0。
76
+ 最后,状态数组在垂直方向上翻转,以匹配棋盘的实际布局。
77
+ """
78
+
79
+ square_state = np.zeros((4, self.width, self.height))
80
+ if self.states:
81
+ moves, players = np.array(list(zip(*self.states.items())))
82
+ move_curr = moves[players == self.current_player]
83
+ move_oppo = moves[players != self.current_player]
84
+ square_state[0][move_curr // self.width,
85
+ move_curr % self.height] = 1.0
86
+ square_state[1][move_oppo // self.width,
87
+ move_oppo % self.height] = 1.0
88
+ # indicate the last move location
89
+ square_state[2][self.last_move // self.width,
90
+ self.last_move % self.height] = 1.0
91
+ if len(self.states) % 2 == 0:
92
+ square_state[3][:, :] = 1.0 # indicate the colour to play
93
+ return square_state[:, ::-1, :]
94
+
95
+ def do_move(self, move):
96
+ self.states[move] = self.current_player
97
+ self.availables.remove(move)
98
+ self.current_player = (
99
+ self.players[0] if self.current_player == self.players[1]
100
+ else self.players[1]
101
+ )
102
+ self.last_move = move
103
+
104
+ def has_a_winner(self):
105
+ width = self.width
106
+ height = self.height
107
+ states = self.states
108
+ n = self.n_in_row
109
+
110
+ moved = list(set(range(width * height)) - set(self.availables))
111
+ if len(moved) < self.n_in_row * 2 - 1:
112
+ return False, -1
113
+
114
+ for m in moved:
115
+ h = m // width
116
+ w = m % width
117
+ player = states[m]
118
+
119
+ if (w in range(width - n + 1) and
120
+ len(set(states.get(i, -1) for i in range(m, m + n))) == 1):
121
+ return True, player
122
+
123
+ if (h in range(height - n + 1) and
124
+ len(set(states.get(i, -1) for i in range(m, m + n * width, width))) == 1):
125
+ return True, player
126
+
127
+ if (w in range(width - n + 1) and h in range(height - n + 1) and
128
+ len(set(states.get(i, -1) for i in range(m, m + n * (width + 1), width + 1))) == 1):
129
+ return True, player
130
+
131
+ if (w in range(n - 1, width) and h in range(height - n + 1) and
132
+ len(set(states.get(i, -1) for i in range(m, m + n * (width - 1), width - 1))) == 1):
133
+ return True, player
134
+
135
+ return False, -1
136
+
137
+ def game_end(self):
138
+ """Check whether the game is ended or not"""
139
+ win, winner = self.has_a_winner()
140
+ if win:
141
+ return True, winner
142
+ elif not len(self.availables):
143
+ return True, -1
144
+ return False, -1
145
+
146
+ def get_current_player(self):
147
+ return self.current_player
148
+
149
+
150
+ class Game(object):
151
+ """game server"""
152
+
153
+ def __init__(self, board, **kwargs):
154
+ self.board = board
155
+ self.pure_mcts_playout_num = 100 # simulation time
156
+
157
+ def graphic(self, board, player1, player2):
158
+ """Draw the board and show game info"""
159
+ width = board.width
160
+ height = board.height
161
+
162
+ print("Player", player1, "with X".rjust(3))
163
+ print("Player", player2, "with O".rjust(3))
164
+ print()
165
+ for x in range(width):
166
+ print("{0:8}".format(x), end='')
167
+ print('\r\n')
168
+ for i in range(height - 1, -1, -1):
169
+ print("{0:4d}".format(i), end='')
170
+ for j in range(width):
171
+ loc = i * width + j
172
+ p = board.states.get(loc, -1)
173
+ if p == player1:
174
+ print('X'.center(8), end='')
175
+ elif p == player2:
176
+ print('O'.center(8), end='')
177
+ else:
178
+ print('_'.center(8), end='')
179
+ print('\r\n\r\n')
180
+
181
+ def start_play(self, player1, player2, start_player=0, is_shown=1):
182
+ """start a game between two players"""
183
+ if start_player not in (0, 1):
184
+ raise Exception('start_player should be either 0 (player1 first) '
185
+ 'or 1 (player2 f1irst)')
186
+ self.board.init_board(start_player)
187
+ p1, p2 = self.board.players
188
+ player1.set_player_ind(p1)
189
+ player2.set_player_ind(p2)
190
+ players = {p1: player1, p2: player2}
191
+ if is_shown:
192
+ self.graphic(self.board, player1.player, player2.player)
193
+ while True:
194
+ current_player = self.board.get_current_player()
195
+ player_in_turn = players[current_player]
196
+ move = player_in_turn.get_action(self.board)
197
+ self.board.do_move(move)
198
+ if is_shown:
199
+ self.graphic(self.board, player1.player, player2.player)
200
+ end, winner = self.board.game_end()
201
+ if end:
202
+ if is_shown:
203
+ if winner != -1:
204
+ print("Game end. Winner is", players[winner])
205
+ else:
206
+ print("Game end. Tie")
207
+ return winner
208
+
209
+ def start_self_play(self, player, is_shown=0, temp=1e-3):
210
+ """
211
+ start a self-play game using a MCTS player, reuse the search tree,
212
+ and store the self-play data: (state, mcts_probs, z) for training
213
+ """
214
+ self.board.init_board()
215
+ p1, p2 = self.board.players
216
+ states, mcts_probs, current_players = [], [], []
217
+ while True:
218
+ move, move_probs = player.get_action(self.board,
219
+ temp=temp,
220
+ return_prob=1)
221
+ # store the data
222
+ states.append(self.board.current_state())
223
+ mcts_probs.append(move_probs)
224
+ current_players.append(self.board.current_player)
225
+ # perform a move
226
+ self.board.do_move(move)
227
+ if is_shown:
228
+ self.graphic(self.board, p1, p2)
229
+ end, winner = self.board.game_end()
230
+ if end:
231
+ # winner from the perspective of the current player of each state
232
+ winners_z = np.zeros(len(current_players))
233
+ if winner != -1:
234
+ winners_z[np.array(current_players) == winner] = 1.0
235
+ winners_z[np.array(current_players) != winner] = -1.0
236
+ # reset MCTS root node
237
+ player.reset_player()
238
+ if is_shown:
239
+ if winner != -1:
240
+ print("Game end. Winner is player:", winner)
241
+ else:
242
+ print("Game end. Tie")
243
+ return winner, zip(states, mcts_probs, winners_z)
244
+
245
+ # 多了下面这一串测试代码
246
+
247
+ def policy_evaluate(self, n_games=10):
248
+ """
249
+ Evaluate the trained policy by playing against the pure MCTS player
250
+ Note: this is only for monitoring the progress of training
251
+ """
252
+ current_mcts_player = MCTS_Pure(c_puct=5,
253
+ n_playout=self.pure_mcts_playout_num)
254
+
255
+ # pure_mcts_player = MCTS_Pure(c_puct=5,
256
+ # n_playout=self.pure_mcts_playout_num)
257
+
258
+ pure_mcts_player = Human_Player()
259
+ win_cnt = defaultdict(int)
260
+ for i in range(n_games):
261
+ winner = self.start_play(current_mcts_player,
262
+ pure_mcts_player,
263
+ start_player=i % 2,
264
+ is_shown=1)
265
+ win_cnt[winner] += 1
266
+ win_ratio = 1.0 * (win_cnt[1] + 0.5 * win_cnt[-1]) / n_games
267
+ print("num_playouts:{}, win: {}, lose: {}, tie:{}".format(
268
+ self.pure_mcts_playout_num,
269
+ win_cnt[1], win_cnt[2], win_cnt[-1]))
270
+ return win_ratio
271
+
272
+
273
+ if __name__ == '__main__':
274
+ board_width = 8
275
+ board_height = 8
276
+ n_in_row = 5
277
+ board = Board(width=board_width,
278
+ height=board_height,
279
+ n_in_row=n_in_row)
280
+ task = Game(board)
281
+ task.policy_evaluate(n_games=10)
Gomoku_MCTS/main_worker.py ADDED
@@ -0,0 +1,334 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import print_function
2
+ import random
3
+ import numpy as np
4
+ from collections import defaultdict, deque
5
+ from game import Board, Game
6
+ from mcts_pure import MCTSPlayer as MCTS_Pure
7
+ from mcts_alphaZero import MCTSPlayer
8
+ import torch.optim as optim
9
+ # from policy_value_net import PolicyValueNet # Theano and Lasagne
10
+ # from policy_value_net_pytorch import PolicyValueNet # Pytorch
11
+ from dueling_net import PolicyValueNet
12
+ # from policy_value_net_tensorflow import PolicyValueNet # Tensorflow
13
+ # from policy_value_net_keras import PolicyValueNet # Keras
14
+ # import joblib
15
+ from torch.autograd import Variable
16
+ import torch.nn.functional as F
17
+
18
+
19
+ from config.options import *
20
+ import sys
21
+ from config.utils import *
22
+ from torch.backends import cudnn
23
+
24
+ import torch
25
+
26
+ from tqdm import *
27
+ from torch.utils.tensorboard import SummaryWriter
28
+
29
+ from multiprocessing import Pool
30
+
31
+ def set_learning_rate(optimizer, lr):
32
+ """Sets the learning rate to the given value"""
33
+ for param_group in optimizer.param_groups:
34
+ param_group['lr'] = lr
35
+
36
+ def std_log():
37
+ if get_rank() == 0:
38
+ save_path = make_path()
39
+ makedir(config['log_base'])
40
+ sys.stdout = open(os.path.join(config['log_base'], "{}.txt".format(save_path)), "w")
41
+
42
+
43
+ def init_seeds(seed, cuda_deterministic=True):
44
+ torch.manual_seed(seed)
45
+ if cuda_deterministic: # slower, more reproducible
46
+ cudnn.deterministic = True
47
+ cudnn.benchmark = False
48
+ else: # faster, less reproducible
49
+ cudnn.deterministic = False
50
+ cudnn.benchmark = True
51
+
52
+
53
+
54
+
55
+ class MainWorker():
56
+ def __init__(self,device):
57
+
58
+ #--- init the set of pipeline -------
59
+ self.board_width = opts.board_width
60
+ self.board_height = opts.board_height
61
+ self.n_in_row = opts.n_in_row
62
+ self.learn_rate = opts.learn_rate
63
+ self.lr_multiplier = opts.lr_multiplier
64
+ self.temp = opts.temp
65
+ self.n_playout = opts.n_playout
66
+ self.c_puct = opts.c_puct
67
+ self.buffer_size = opts.buffer_size
68
+ self.batch_size = opts.batch_size
69
+ self.play_batch_size = opts.play_batch_size
70
+ self.epochs = opts.epochs
71
+ self.kl_targ = opts.kl_targ
72
+ self.check_freq = opts.check_freq
73
+ self.game_batch_num = opts.game_batch_num
74
+ self.pure_mcts_playout_num = opts.pure_mcts_playout_num
75
+
76
+ self.device = device
77
+ self.use_gpu = torch.device("cuda") == self.device
78
+
79
+ self.board = Board(width=self.board_width,
80
+ height=self.board_height,
81
+ n_in_row=self.n_in_row)
82
+ self.game = Game(self.board)
83
+
84
+ # The data collection of the history of games
85
+ self.data_buffer = deque(maxlen=self.buffer_size)
86
+
87
+
88
+ # The best win ratio of the training agent
89
+ self.best_win_ratio = 0.0
90
+
91
+
92
+ if opts.preload_model:
93
+ # start training from an initial policy-value net
94
+ self.policy_value_net = PolicyValueNet(self.board_width,
95
+ self.board_height,
96
+ model_file=opts.preload_model,
97
+ use_gpu=(self.device == "cuda"))
98
+
99
+ else:
100
+ # start training from a new policy-value net
101
+ self.policy_value_net = PolicyValueNet(self.board_width,
102
+ self.board_height,
103
+ use_gpu=(self.device == "cuda"))
104
+ self.mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
105
+ c_puct=self.c_puct,
106
+ n_playout=self.n_playout,
107
+ is_selfplay=1)
108
+
109
+ # The set of optimizer
110
+ self.optimizer = optim.Adam(self.policy_value_net.policy_value_net.parameters(),
111
+ weight_decay=opts.l2_const)
112
+ # set learning rate
113
+ set_learning_rate(self.optimizer, self.learn_rate*self.lr_multiplier)
114
+
115
+
116
+
117
+
118
+ def get_equi_data(self, play_data):
119
+ """augment the data set by rotation and flipping
120
+ play_data: [(state, mcts_prob, winner_z), ..., ...]
121
+ """
122
+ extend_data = []
123
+ for state, mcts_porb, winner in play_data:
124
+ for i in [1, 2, 3, 4]:
125
+ # rotate counterclockwise
126
+ equi_state = np.array([np.rot90(s, i) for s in state])
127
+ equi_mcts_prob = np.rot90(np.flipud(
128
+ mcts_porb.reshape(self.board_height, self.board_width)), i)
129
+ extend_data.append((equi_state,
130
+ np.flipud(equi_mcts_prob).flatten(),
131
+ winner))
132
+ # flip horizontally
133
+ equi_state = np.array([np.fliplr(s) for s in equi_state])
134
+ equi_mcts_prob = np.fliplr(equi_mcts_prob)
135
+ extend_data.append((equi_state,
136
+ np.flipud(equi_mcts_prob).flatten(),
137
+ winner))
138
+ return extend_data
139
+
140
+ def job(self, i):
141
+ game = self.game
142
+ player = self.mcts_player
143
+ winner, play_data = game.start_self_play(player,
144
+ temp=self.temp)
145
+ play_data = list(play_data)[:]
146
+ play_data = self.get_equi_data(play_data)
147
+
148
+ return play_data
149
+
150
+ def collect_selfplay_data(self, n_games=1):
151
+ """collect self-play data for training"""
152
+ # print("[STAGE] Collecting self-play data for training")
153
+
154
+ # collection_bar = tqdm( range(n_games))
155
+ collection_bar = range(n_games)
156
+ with Pool(4) as p:
157
+ play_data = p.map(self.job, collection_bar, chunksize=1)
158
+ self.data_buffer.extend(play_data)
159
+ # print('\n', 'data buffer size:', len(self.data_buffer))
160
+
161
+ def policy_update(self):
162
+ """update the policy-value net"""
163
+ mini_batch = random.sample(self.data_buffer, self.batch_size)
164
+ state_batch = [data[0] for data in mini_batch]
165
+ mcts_probs_batch = [data[1] for data in mini_batch]
166
+ winner_batch = [data[2] for data in mini_batch]
167
+ old_probs, old_v = self.policy_value_net.policy_value(state_batch)
168
+
169
+ epoch_bar = tqdm(range(self.epochs))
170
+
171
+ for i in epoch_bar:
172
+ """perform a training step"""
173
+ # wrap in Variable
174
+ if self.use_gpu:
175
+ state_batch = Variable(torch.FloatTensor(state_batch).cuda())
176
+ mcts_probs = Variable(torch.FloatTensor(mcts_probs_batch).cuda())
177
+ winner_batch = Variable(torch.FloatTensor(winner_batch).cuda())
178
+ else:
179
+ state_batch = Variable(torch.FloatTensor(state_batch))
180
+ mcts_probs = Variable(torch.FloatTensor(mcts_probs_batch))
181
+ winner_batch = Variable(torch.FloatTensor(winner_batch))
182
+
183
+ # zero the parameter gradients
184
+ self.optimizer.zero_grad()
185
+
186
+ # forward
187
+ log_act_probs, value = self.policy_value_net.policy_value_net(state_batch)
188
+ # define the loss = (z - v)^2 - pi^T * log(p) + c||theta||^2
189
+ # Note: the L2 penalty is incorporated in optimizer
190
+ value_loss = F.mse_loss(value.view(-1), winner_batch)
191
+ policy_loss = -torch.mean(torch.sum(mcts_probs*log_act_probs, 1))
192
+ loss = value_loss + policy_loss
193
+ # backward and optimize
194
+ loss.backward()
195
+ self.optimizer.step()
196
+ # calc policy entropy, for monitoring only
197
+ entropy = -torch.mean(
198
+ torch.sum(torch.exp(log_act_probs) * log_act_probs, 1)
199
+ )
200
+ loss = loss.item()
201
+ entropy = entropy.item()
202
+
203
+ new_probs, new_v = self.policy_value_net.policy_value(state_batch)
204
+ kl = np.mean(np.sum(old_probs * (
205
+ np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)),
206
+ axis=1)
207
+ )
208
+ if kl > self.kl_targ * 4: # early stopping if D_KL diverges badly
209
+ break
210
+
211
+ epoch_bar.set_description(f"training epoch {i}")
212
+ epoch_bar.set_postfix( new_v =new_v, kl = kl)
213
+
214
+ # adaptively adjust the learning rate
215
+ if kl > self.kl_targ * 2 and self.lr_multiplier > 0.1:
216
+ self.lr_multiplier /= 1.5
217
+ elif kl < self.kl_targ / 2 and self.lr_multiplier < 10:
218
+ self.lr_multiplier *= 1.5
219
+
220
+
221
+
222
+ explained_var_old = (1 -
223
+ np.var(np.array(winner_batch) - old_v.flatten()) /
224
+ np.var(np.array(winner_batch)))
225
+ explained_var_new = (1 -
226
+ np.var(np.array(winner_batch) - new_v.flatten()) /
227
+ np.var(np.array(winner_batch)))
228
+
229
+
230
+
231
+
232
+ return kl, loss, entropy,explained_var_old, explained_var_new
233
+
234
+ def policy_evaluate(self, n_games=10):
235
+ """
236
+ Evaluate the trained policy by playing against the pure MCTS player
237
+ Note: this is only for monitoring the progress of training
238
+ """
239
+ current_mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
240
+ c_puct=self.c_puct,
241
+ n_playout=self.n_playout)
242
+ pure_mcts_player = MCTS_Pure(c_puct=5,
243
+ n_playout=self.pure_mcts_playout_num)
244
+ win_cnt = defaultdict(int)
245
+ for i in range(n_games):
246
+
247
+ winner = self.game.start_play(
248
+ pure_mcts_player,current_mcts_player,
249
+ start_player=i % 2,
250
+ is_shown=0)
251
+ win_cnt[winner] += 1
252
+ print(f" {i}_th winner:" , winner)
253
+ win_ratio = 1.0*(win_cnt[1] + 0.5*win_cnt[-1]) / n_games
254
+ print("num_playouts:{}, win: {}, lose: {}, tie:{}".format(
255
+ self.pure_mcts_playout_num,
256
+ win_cnt[1], win_cnt[2], win_cnt[-1]))
257
+ return win_ratio
258
+
259
+ def run(self):
260
+ """run the training pipeline"""
261
+ try:
262
+
263
+ batch_bar = tqdm(range(self.game_batch_num))
264
+ for i in batch_bar:
265
+ self.collect_selfplay_data(self.play_batch_size)
266
+
267
+ if len(self.data_buffer) > self.batch_size:
268
+ kl, loss, entropy,explained_var_old, explained_var_new = self.policy_update()
269
+
270
+ writer.add_scalar("policy_update/kl", kl ,i )
271
+ writer.add_scalar("policy_update/loss", loss ,i)
272
+ writer.add_scalar("policy_update/entropy", entropy ,i)
273
+ writer.add_scalar("policy_update/explained_var_old", explained_var_old,i)
274
+ writer.add_scalar("policy_update/explained_var_new ", explained_var_new ,i)
275
+
276
+
277
+ batch_bar.set_description(f"game batch num {i}")
278
+
279
+ # check the performance of the current model,
280
+ # and save the model params
281
+ if (i+1) % self.check_freq == 0:
282
+ win_ratio = self.policy_evaluate()
283
+
284
+ batch_bar.set_description(f"game batch num {i+1}")
285
+ writer.add_scalar("evaluate/explained_var_new ", win_ratio ,i)
286
+ batch_bar.set_postfix(loss= loss, entropy= entropy,win_ratio =win_ratio)
287
+
288
+ save_model(self.policy_value_net,"current_policy.model")
289
+ if win_ratio > self.best_win_ratio:
290
+ print("New best policy!!!!!!!!")
291
+ self.best_win_ratio = win_ratio
292
+ # update the best_policy
293
+ save_model(self.policy_value_net,"best_policy.model")
294
+ if (self.best_win_ratio == 1.0 and
295
+ self.pure_mcts_playout_num < 5000):
296
+ self.pure_mcts_playout_num += 1000
297
+ self.best_win_ratio = 0.0
298
+ except KeyboardInterrupt:
299
+ print('\n\rquit')
300
+
301
+
302
+ if __name__ == "__main__":
303
+ print("START train....")
304
+
305
+ # ------init set-----------
306
+
307
+ if opts.std_log:
308
+ std_log()
309
+ writer = visualizer()
310
+
311
+
312
+ if opts.distributed:
313
+ torch.distributed.init_process_group(backend="nccl")
314
+ local_rank = torch.distributed.get_rank()
315
+ torch.cuda.set_device(local_rank)
316
+ device = torch.device("cuda", local_rank)
317
+ init_seeds(opts.seed + local_rank)
318
+
319
+ else:
320
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
321
+ init_seeds(opts.seed)
322
+
323
+ print("seed: ",opts.seed )
324
+ print("device:" , device)
325
+
326
+
327
+ if opts.split == "train":
328
+ training_pipeline = MainWorker(device)
329
+ training_pipeline.run()
330
+
331
+ if get_rank() == 0 and opts.split == "test":
332
+ training_pipeline = MainWorker(device)
333
+ training_pipeline.policy_value_net()
334
+
Gomoku_MCTS/mcts_alphaZero.py ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Monte Carlo Tree Search in AlphaGo Zero style, which uses a policy-value
4
+ network to guide the tree search and evaluate the leaf nodes
5
+
6
+ @author: Junxiao Song
7
+ """
8
+
9
+ import numpy as np
10
+ import copy
11
+ import time
12
+ from concurrent.futures import ThreadPoolExecutor
13
+ import threading
14
+
15
+
16
+ def softmax(x):
17
+ probs = np.exp(x - np.max(x))
18
+ probs /= np.sum(probs)
19
+ return probs
20
+
21
+
22
+ class TreeNode(object):
23
+ """A node in the MCTS tree.
24
+
25
+ Each node keeps track of its own value Q, prior probability P, and
26
+ its visit-count-adjusted prior score u.
27
+ """
28
+
29
+ def __init__(self, parent, prior_p):
30
+ self._parent = parent
31
+ self._children = {} # a map from action to TreeNode
32
+ self._n_visits = 0
33
+ self._Q = 0
34
+ self._u = 0
35
+ self._P = prior_p
36
+
37
+ def expand(self, action_priors):
38
+ """Expand tree by creating new children.
39
+ action_priors: a list of tuples of actions and their prior probability
40
+ according to the policy function.
41
+ """
42
+ for action, prob in action_priors:
43
+ if action not in self._children:
44
+ self._children[action] = TreeNode(self, prob)
45
+
46
+ def select(self, c_puct):
47
+ """Select action among children that gives maximum action value Q
48
+ plus bonus u(P).
49
+ Return: A tuple of (action, next_node)
50
+ """
51
+ return max(self._children.items(),
52
+ key=lambda act_node: act_node[1].get_value(c_puct))
53
+
54
+ def update(self, leaf_value):
55
+ """Update node values from leaf evaluation.
56
+ leaf_value: the value of subtree evaluation from the current player's
57
+ perspective.
58
+ """
59
+ # Count visit.
60
+ self._n_visits += 1
61
+ # Update Q, a running average of values for all visits.
62
+ self._Q += 1.0*(leaf_value - self._Q) / self._n_visits
63
+
64
+ def update_recursive(self, leaf_value):
65
+ """Like a call to update(), but applied recursively for all ancestors.
66
+ """
67
+ # If it is not root, this node's parent should be updated first.
68
+ if self._parent:
69
+ self._parent.update_recursive(-leaf_value)
70
+ self.update(leaf_value)
71
+
72
+ def get_value(self, c_puct):
73
+ """Calculate and return the value for this node.
74
+ It is a combination of leaf evaluations Q, and this node's prior
75
+ adjusted for its visit count, u.
76
+ c_puct: a number in (0, inf) controlling the relative impact of
77
+ value Q, and prior probability P, on this node's score.
78
+ """
79
+ self._u = (c_puct * self._P *
80
+ np.sqrt(self._parent._n_visits) / (1 + self._n_visits))
81
+ return self._Q + self._u
82
+
83
+ def is_leaf(self):
84
+ """Check if leaf node (i.e. no nodes below this have been expanded)."""
85
+ return self._children == {}
86
+
87
+ def is_root(self):
88
+ return self._parent is None
89
+
90
+
91
+ class MCTS(object):
92
+ """An implementation of Monte Carlo Tree Search."""
93
+
94
+ def __init__(self, policy_value_fn, c_puct=5, n_playout=10000):
95
+ """
96
+ policy_value_fn: a function that takes in a board state and outputs
97
+ a list of (action, probability) tuples and also a score in [-1, 1]
98
+ (i.e. the expected value of the end game score from the current
99
+ player's perspective) for the current player.
100
+ c_puct: a number in (0, inf) that controls how quickly exploration
101
+ converges to the maximum-value policy. A higher value means
102
+ relying on the prior more.
103
+ """
104
+ self._root = TreeNode(None, 1.0)
105
+ self._policy = policy_value_fn
106
+ self._c_puct = c_puct
107
+ self._n_playout = n_playout
108
+
109
+ def _playout(self, state, lock=None):
110
+ """Run a single playout from the root to the leaf, getting a value at
111
+ the leaf and propagating it back through its parents.
112
+ State is modified in-place, so a copy must be provided.
113
+ """
114
+ node = self._root
115
+ if lock is not None:
116
+ lock.acquire()
117
+ while(1):
118
+ if node.is_leaf():
119
+ break
120
+ # Greedily select next move.
121
+ action, node = node.select(self._c_puct)
122
+ state.do_move(action)
123
+ if lock is not None:
124
+ lock.release()
125
+ # Evaluate the leaf using a network which outputs a list of
126
+ # (action, probability) tuples p and also a score v in [-1, 1]
127
+ # for the current player.
128
+ action_probs, leaf_value = self._policy(state)
129
+ # Check for end of game.
130
+ end, winner = state.game_end()
131
+ if lock is not None:
132
+ lock.acquire()
133
+ if not end:
134
+ node.expand(action_probs)
135
+ else:
136
+ # for end state,return the "true" leaf_value
137
+ if winner == -1: # tie
138
+ leaf_value = 0.0
139
+ else:
140
+ leaf_value = (
141
+ 1.0 if winner == state.get_current_player() else -1.0
142
+ )
143
+
144
+ # Update value and visit count of nodes in this traversal.
145
+ node.update_recursive(-leaf_value)
146
+ if lock is not None:
147
+ lock.release()
148
+
149
+ def get_move_probs(self, state, temp=1e-3):
150
+ """Run all playouts sequentially and return the available actions and
151
+ their corresponding probabilities.
152
+ state: the current game state
153
+ temp: temperature parameter in (0, 1] controls the level of exploration
154
+ """
155
+
156
+ start_time_averge = 0
157
+
158
+ ### test multi-thread
159
+ lock = threading.Lock()
160
+ with ThreadPoolExecutor(max_workers=4) as executor:
161
+ for n in range(self._n_playout):
162
+ start_time = time.time()
163
+
164
+ state_copy = copy.deepcopy(state)
165
+ executor.submit(self._playout, state_copy, lock)
166
+ start_time_averge += (time.time() - start_time)
167
+ ### end test multi-thread
168
+
169
+ # t = time.time()
170
+ # for n in range(self._n_playout):
171
+ # start_time = time.time()
172
+
173
+ # state_copy = copy.deepcopy(state)
174
+ # self._playout(state_copy)
175
+ # start_time_averge += (time.time() - start_time)
176
+ # print('!!time!!:', time.time() - t)
177
+
178
+ # print(f" My MCTS sum_time: {start_time_averge }, total_simulation: {self._n_playout}")
179
+
180
+
181
+ # calc the move probabilities based on visit counts at the root node
182
+ act_visits = [(act, node._n_visits)
183
+ for act, node in self._root._children.items()]
184
+ acts, visits = zip(*act_visits)
185
+ act_probs = softmax(1.0/temp * np.log(np.array(visits) + 1e-10))
186
+
187
+ return acts, act_probs
188
+
189
+ def update_with_move(self, last_move):
190
+ """Step forward in the tree, keeping everything we already know
191
+ about the subtree.
192
+ """
193
+ if last_move in self._root._children:
194
+ self._root = self._root._children[last_move]
195
+ self._root._parent = None
196
+ else:
197
+ self._root = TreeNode(None, 1.0)
198
+
199
+ def __str__(self):
200
+ return "MCTS"
201
+
202
+
203
+ class MCTSPlayer(object):
204
+ """AI player based on MCTS"""
205
+
206
+ def __init__(self, policy_value_function,
207
+ c_puct=5, n_playout=2000, is_selfplay=0):
208
+ self.mcts = MCTS(policy_value_function, c_puct, n_playout)
209
+ self._is_selfplay = is_selfplay
210
+
211
+ def set_player_ind(self, p):
212
+ self.player = p
213
+
214
+ def reset_player(self):
215
+ self.mcts.update_with_move(-1)
216
+
217
+ def get_action(self, board, temp=1e-3, return_prob=0):
218
+ sensible_moves = board.availables
219
+ # the pi vector returned by MCTS as in the alphaGo Zero paper
220
+ move_probs = np.zeros(board.width*board.height)
221
+ if len(sensible_moves) > 0:
222
+ acts, probs = self.mcts.get_move_probs(board, temp)
223
+ move_probs[list(acts)] = probs
224
+ if self._is_selfplay:
225
+ # add Dirichlet Noise for exploration (needed for
226
+ # self-play training)
227
+ move = np.random.choice(
228
+ acts,
229
+ p=0.75*probs + 0.25*np.random.dirichlet(0.3*np.ones(len(probs)))
230
+ )
231
+ # update the root node and reuse the search tree
232
+ self.mcts.update_with_move(move)
233
+ else:
234
+ # with the default temp=1e-3, it is almost equivalent
235
+ # to choosing the move with the highest prob
236
+ move = np.random.choice(acts, p=probs)
237
+ # reset the root node
238
+ self.mcts.update_with_move(-1)
239
+ # location = board.move_to_location(move)
240
+ # print("AI move: %d,%d\n" % (location[0], location[1]))
241
+
242
+ if return_prob:
243
+ return move, move_probs
244
+ else:
245
+ return move
246
+ else:
247
+ print("WARNING: the board is full")
248
+
249
+ def __str__(self):
250
+ return "MCTS {}".format(self.player)
Gomoku_MCTS/mcts_pure.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import numpy as np
4
+ import copy
5
+ from operator import itemgetter
6
+ import time
7
+
8
+
9
+ def rollout_policy_fn(board):
10
+ """a coarse, fast version of policy_fn used in the rollout phase."""
11
+ # rollout randomly
12
+ action_probs = np.random.rand(len(board.availables))
13
+ return zip(board.availables, action_probs)
14
+
15
+ # 决策价值函数
16
+ def policy_value_fn(board):
17
+ """a function that takes in a state and outputs a list of (action, probability)
18
+ tuples and a score for the state"""
19
+ # return uniform probabilities and 0 score for pure MCTS
20
+ action_probs = np.ones(len(board.availables))/len(board.availables)
21
+ return zip(board.availables, action_probs), 0
22
+
23
+
24
+ class TreeNode(object):
25
+ """A node in the MCTS tree. Each node keeps track of its own value Q,
26
+ prior probability P, and its visit-count-adjusted prior score u.
27
+ """
28
+
29
+ def __init__(self, parent, prior_p):
30
+ self._parent = parent
31
+ self._children = {} # a map from action to TreeNode
32
+ self._n_visits = 0
33
+ self._Q = 0
34
+ self._u = 0
35
+ self._P = prior_p
36
+
37
+ def expand(self, action_priors):
38
+ """Expand tree by creating new children.
39
+ action_priors: a list of tuples of actions and their prior probability
40
+ according to the policy function.
41
+ """
42
+ for action, prob in action_priors:
43
+ if action not in self._children:
44
+ self._children[action] = TreeNode(self, prob)
45
+
46
+ def select(self, c_puct):
47
+ """Select action among children that gives maximum action value Q
48
+ plus bonus u(P).
49
+ Return: A tuple of (action, next_node)
50
+ """
51
+ return max(self._children.items(),
52
+ key=lambda act_node: act_node[1].get_value(c_puct))
53
+
54
+ def update(self, leaf_value):
55
+ """Update node values from leaf evaluation.
56
+ leaf_value: the value of subtree evaluation from the current player's
57
+ perspective.
58
+ """
59
+ # Count visit.
60
+ self._n_visits += 1
61
+ # Update Q, a running average of values for all visits.
62
+ # print("=====================================")
63
+ # print("Before, Q: {}, visits: {}, leaf_value: {}".format(self._Q, self._n_visits,leaf_value))
64
+ self._Q += 1.0*(leaf_value - self._Q) / self._n_visits
65
+ # print("After, Q: {}, visits: {}, leaf_value: {}".format(self._Q, self._n_visits,leaf_value))
66
+
67
+
68
+ def update_recursive(self, leaf_value):
69
+ """Like a call to update(), but applied recursively for all ancestors.
70
+ """
71
+ # If it is not root, this node's parent should be updated first.
72
+ if self._parent:
73
+ self._parent.update_recursive(-leaf_value)
74
+ self.update(leaf_value)
75
+
76
+ def get_value(self, c_puct):
77
+ """Calculate and return the value for this node.
78
+ It is a combination of leaf evaluations Q, and this node's prior
79
+ adjusted for its visit count, u.
80
+ c_puct: a number in (0, inf) controlling the relative impact of
81
+ value Q, and prior probability P, on this node's score.
82
+ """
83
+ self._u = (c_puct * self._P *
84
+ np.sqrt(self._parent._n_visits) / (1 + self._n_visits))
85
+ return self._Q + self._u
86
+
87
+ def is_leaf(self):
88
+ """Check if leaf node (i.e. no nodes below this have been expanded).
89
+ """
90
+ return self._children == {}
91
+
92
+ def is_root(self):
93
+ return self._parent is None
94
+
95
+
96
+ class MCTS(object):
97
+ """A simple implementation of Monte Carlo Tree Search."""
98
+
99
+ def __init__(self, policy_value_fn, c_puct=5, n_playout=2000):
100
+ """
101
+ policy_value_fn: a function that takes in a board state and outputs
102
+ a list of (action, probability) tuples and also a score in [-1, 1]
103
+ (i.e. the expected value of the end game score from the current
104
+ player's perspective) for the current player.
105
+ c_puct: a number in (0, inf) that controls how quickly exploration
106
+ converges to the maximum-value policy. A higher value means
107
+ relying on the prior more. ???
108
+ """
109
+ self._root = TreeNode(None, 1.0)
110
+ self._policy = policy_value_fn
111
+ self._c_puct = c_puct
112
+ self._n_playout = n_playout
113
+
114
+ def _playout(self, state):
115
+ """Run a single playout from the root to the leaf, getting a value at
116
+ the leaf and propagating it back through its parents.
117
+ State is modified in-place, so a copy must be provided.
118
+ """
119
+ node = self._root
120
+ while(1):
121
+ if node.is_leaf():
122
+
123
+ break
124
+ # Greedily select next move.
125
+ action, node = node.select(self._c_puct)
126
+ state.do_move(action)
127
+
128
+ action_probs, _ = self._policy(state)
129
+ # Check for end of game
130
+ end, winner = state.game_end()
131
+ if not end:
132
+ node.expand(action_probs)
133
+ # Evaluate the leaf node by random rollout
134
+ leaf_value = self._evaluate_rollout(state)
135
+ # Update value and visit count of nodes in this traversal.
136
+ node.update_recursive(-leaf_value)
137
+
138
+ def _evaluate_rollout(self, state, limit=1000):
139
+ """Use the rollout policy to play until the end of the game,
140
+ returning +1 if the current player wins, -1 if the opponent wins,
141
+ and 0 if it is a tie.
142
+ """
143
+ player = state.get_current_player()
144
+ for i in range(limit):
145
+ end, winner = state.game_end()
146
+ if end:
147
+ break
148
+ action_probs = rollout_policy_fn(state)
149
+ max_action = max(action_probs, key=itemgetter(1))[0]
150
+ state.do_move(max_action)
151
+ else:
152
+ # If no break from the loop, issue a warning.
153
+ print("WARNING: rollout reached move limit")
154
+ if winner == -1: # tie
155
+ return 0
156
+ else:
157
+ return 1 if winner == player else -1
158
+
159
+ def get_move(self, state):
160
+ """Runs all playouts sequentially and returns the most visited action.
161
+ state: the current game state
162
+
163
+ Return: the selected action
164
+ """
165
+ start_time = time.time()
166
+ # n_playout 探索的次数
167
+ for n in range(self._n_playout):
168
+ state_copy = copy.deepcopy(state)
169
+ self._playout(state_copy)
170
+
171
+ need_time = time.time() - start_time
172
+
173
+ print(f" PureMCTS sum_time: {need_time / self._n_playout }, total_simulation: {self._n_playout}")
174
+
175
+ return max(self._root._children.items(),key=lambda act_node: act_node[1]._n_visits)[0], need_time / self._n_playout
176
+
177
+ def update_with_move(self, last_move):
178
+ """Step forward in the tree, keeping everything we already know
179
+ about the subtree.
180
+ """
181
+ if last_move in self._root._children:
182
+ self._root = self._root._children[last_move]
183
+ self._root._parent = None
184
+ else:
185
+ self._root = TreeNode(None, 1.0)
186
+
187
+ def __str__(self):
188
+ return "MCTS"
189
+
190
+
191
+
192
+ class MCTSPlayer(object):
193
+ """AI player based on MCTS"""
194
+ def __init__(self, c_puct=5, n_playout=2000):
195
+ self.mcts = MCTS(policy_value_fn, c_puct, n_playout)
196
+
197
+ def set_player_ind(self, p):
198
+ self.player = p
199
+
200
+ def reset_player(self):
201
+ self.mcts.update_with_move(-1)
202
+
203
+ def get_action(self, board):
204
+ sensible_moves = board.availables
205
+ if len(sensible_moves) > 0:
206
+ move, simul_mean_time = self.mcts.get_move(board)
207
+ self.mcts.update_with_move(-1)
208
+ print("MCTS move:", move)
209
+ return move, simul_mean_time
210
+ else:
211
+ print("WARNING: the board is full")
212
+
213
+
214
+ def __str__(self):
215
+ return "MCTS {}".format(self.player)
216
+
217
+
218
+ # 多了下面这一串代码
219
+
220
+ class Human_Player(object):
221
+ def __init__(self):
222
+ pass
223
+
224
+
225
+ def set_player_ind(self, p):
226
+ self.player = p
227
+
228
+
229
+ def get_action(self, board):
230
+
231
+
232
+ sensible_moves = board.availables
233
+ if len(sensible_moves) > 0:
234
+ # print(sensible_moves)
235
+
236
+ move = int(input("Input the move:"))
237
+ while (move not in sensible_moves ):
238
+ print(sensible_moves)
239
+ move = int(input("Input the move again:"))
240
+ return move
241
+ else:
242
+ print("WARNING: the board is full")
243
+
244
+ def __str__(self):
245
+ return "Human {}".format(self.player)
246
+
Gomoku_MCTS/policy_value_net_pytorch.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ An implementation of the policyValueNet in PyTorch
4
+ Tested in PyTorch 0.2.0 and 0.3.0
5
+
6
+ @author: Junxiao Song
7
+ """
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.optim as optim
12
+ import torch.nn.functional as F
13
+ from torch.autograd import Variable
14
+ import numpy as np
15
+
16
+
17
+
18
+
19
+
20
+ class Net(nn.Module):
21
+ """policy-value network module"""
22
+ def __init__(self, board_width, board_height):
23
+ super(Net, self).__init__()
24
+
25
+ self.board_width = board_width
26
+ self.board_height = board_height
27
+ # common layers
28
+ self.conv1 = nn.Conv2d(4, 32, kernel_size=3, padding=1)
29
+ self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
30
+ self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
31
+ # action policy layers
32
+ self.act_conv1 = nn.Conv2d(128, 4, kernel_size=1)
33
+ self.act_fc1 = nn.Linear(4*board_width*board_height,
34
+ board_width*board_height)
35
+ # state value layers
36
+ self.val_conv1 = nn.Conv2d(128, 2, kernel_size=1)
37
+ self.val_fc1 = nn.Linear(2*board_width*board_height, 64)
38
+ self.val_fc2 = nn.Linear(64, 1)
39
+
40
+ def forward(self, state_input):
41
+ # common layers
42
+ x = F.relu(self.conv1(state_input))
43
+ x = F.relu(self.conv2(x))
44
+ x = F.relu(self.conv3(x))
45
+ # action policy layers
46
+ x_act = F.relu(self.act_conv1(x))
47
+ x_act = x_act.view(-1, 4*self.board_width*self.board_height)
48
+ x_act = F.log_softmax(self.act_fc1(x_act))
49
+ # state value layers
50
+ x_val = F.relu(self.val_conv1(x))
51
+ x_val = x_val.view(-1, 2*self.board_width*self.board_height)
52
+ x_val = F.relu(self.val_fc1(x_val))
53
+ x_val = F.tanh(self.val_fc2(x_val))
54
+ return x_act, x_val
55
+
56
+
57
+ class PolicyValueNet():
58
+ """policy-value network """
59
+ def __init__(self, board_width, board_height,
60
+ model_file=None, use_gpu=False):
61
+ self.use_gpu = use_gpu
62
+ self.board_width = board_width
63
+ self.board_height = board_height
64
+
65
+ # the policy value net module
66
+ if self.use_gpu:
67
+ self.policy_value_net = Net(board_width, board_height).cuda()
68
+ else:
69
+ self.policy_value_net = Net(board_width, board_height)
70
+
71
+ if model_file:
72
+ net_params = torch.load(model_file)
73
+ self.policy_value_net.load_state_dict(net_params)
74
+
75
+ def policy_value(self, state_batch):
76
+ """
77
+ input: a batch of states
78
+ output: a batch of action probabilities and state values
79
+ """
80
+ if self.use_gpu:
81
+ state_batch = Variable(torch.FloatTensor(state_batch).cuda())
82
+ log_act_probs, value = self.policy_value_net(state_batch)
83
+ act_probs = np.exp(log_act_probs.data.cpu().numpy())
84
+ return act_probs, value.data.cpu().numpy()
85
+ else:
86
+ state_batch = Variable(torch.FloatTensor(state_batch))
87
+ log_act_probs, value = self.policy_value_net(state_batch)
88
+ act_probs = np.exp(log_act_probs.data.numpy())
89
+ return act_probs, value.data.numpy()
90
+
91
+ def policy_value_fn(self, board):
92
+ """
93
+ input: board
94
+ output: a list of (action, probability) tuples for each available
95
+ action and the score of the board state
96
+ """
97
+ legal_positions = board.availables
98
+ current_state = np.ascontiguousarray(board.current_state().reshape(
99
+ -1, 4, self.board_width, self.board_height))
100
+ if self.use_gpu:
101
+ log_act_probs, value = self.policy_value_net(
102
+ Variable(torch.from_numpy(current_state)).cuda().float())
103
+ act_probs = np.exp(log_act_probs.data.cpu().numpy().flatten())
104
+ else:
105
+ log_act_probs, value = self.policy_value_net(
106
+ Variable(torch.from_numpy(current_state)).float())
107
+ act_probs = np.exp(log_act_probs.data.numpy().flatten())
108
+ act_probs = zip(legal_positions, act_probs[legal_positions])
109
+ value = value.data[0][0]
110
+ return act_probs, value
111
+
112
+
113
+ # 搬到main_worker
114
+
115
+ def train_step(self, state_batch, mcts_probs, winner_batch, lr):
116
+ """perform a training step"""
117
+
118
+ # self.use_gpu = True
119
+ # wrap in Variable
120
+ if self.use_gpu:
121
+ state_batch = Variable(torch.FloatTensor(state_batch).cuda())
122
+ mcts_probs = Variable(torch.FloatTensor(mcts_probs).cuda())
123
+ winner_batch = Variable(torch.FloatTensor(winner_batch).cuda())
124
+ else:
125
+ state_batch = Variable(torch.FloatTensor(state_batch))
126
+ mcts_probs = Variable(torch.FloatTensor(mcts_probs))
127
+ winner_batch = Variable(torch.FloatTensor(winner_batch))
128
+
129
+ # zero the parameter gradients
130
+ self.optimizer.zero_grad()
131
+ # set learning rate
132
+ set_learning_rate(self.optimizer, lr)
133
+
134
+ # forward
135
+ log_act_probs, value = self.policy_value_net(state_batch)
136
+ # define the loss = (z - v)^2 - pi^T * log(p) + c||theta||^2
137
+ # Note: the L2 penalty is incorporated in optimizer
138
+ value_loss = F.mse_loss(value.view(-1), winner_batch)
139
+ policy_loss = -torch.mean(torch.sum(mcts_probs*log_act_probs, 1))
140
+ loss = value_loss + policy_loss
141
+ # backward and optimize
142
+ loss.backward()
143
+ self.optimizer.step()
144
+ # calc policy entropy, for monitoring only
145
+ entropy = -torch.mean(
146
+ torch.sum(torch.exp(log_act_probs) * log_act_probs, 1)
147
+ )
148
+ # return loss.data[0], entropy.data[0]
149
+ #for pytorch version >= 0.5 please use the following line instead.
150
+ return loss.item(), entropy.item()
151
+
152
+ # def get_policy_param(self):
153
+ # net_params = self.policy_value_net.state_dict()
154
+ # return net_params
155
+
156
+ # def save_model(self, model_file):
157
+ # """ save model params to file """
158
+ # net_params = self.get_policy_param() # get model params
159
+ # torch.save(net_params, model_file)
Gomoku_MCTS/visualization/_blip_uni_cross_mu_bs512_lr0.002/events.out.tfevents.1700183498.LAPTOP-5AN2UHOO ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f64431c679947fac92ef87e7f3d3b6a75c0cdf82e6fd0383451a98d778b7b21e
3
+ size 40
Gomoku_MCTS/visualization/_blip_uni_cross_mu_bs512_lr0.002/events.out.tfevents.1700183516.LAPTOP-5AN2UHOO ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aaa7025d5d1daa88dce58231e0fba4d7a04391612c696e4c2e23292ad4169d80
3
+ size 40
Gomoku_MCTS/visualization/_blip_uni_cross_mu_bs512_lr0.002/events.out.tfevents.1700183568.LAPTOP-5AN2UHOO ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:440fac7c819d3368da1126c35e1a146b4ec3a3e614cb3c6e7e10063f9f0ced3c
3
+ size 40
Gomoku_MCTS/visualization/_blip_uni_cross_mu_bs512_lr0.002/events.out.tfevents.1700183629.LAPTOP-5AN2UHOO ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e45fb4fd64ac1d0a3ec9f5376d2122b48aa9c0a56e01ccfdc0a4ea0ed22188ed
3
+ size 40
Gomoku_MCTS/visualization/_blip_uni_cross_mu_bs512_lr0.002/events.out.tfevents.1700183640.LAPTOP-5AN2UHOO ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:59b89d395abdaa5c2b8cb0922f4a465a9f06c59a429697c4d138e58033e6e1a0
3
+ size 40
Gomoku_MCTS/visualization/_blip_uni_cross_mu_bs512_lr0.002/events.out.tfevents.1700183667.LAPTOP-5AN2UHOO ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e2ffadd87d44bba87f7bcd80fb424959536ff24e7a4e52a67238200c691befac
3
+ size 40
Gomoku_MCTS/visualization/_blip_uni_cross_mu_bs512_lr0.002/events.out.tfevents.1700183756.LAPTOP-5AN2UHOO ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a91a4478ed17986eaf70acf7a0fb3fe0db11cbcbe8eedf7655bfad9e6a4a9650
3
+ size 40
Gomoku_MCTS/visualization/_blip_uni_cross_mu_bs512_lr0.002/events.out.tfevents.1700183820.LAPTOP-5AN2UHOO ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2641a889839f3eaf83a9ee90b4bdf0073488a9416044507d321a7bfc8bbad83f
3
+ size 40
Gomoku_MCTS/visualization/_blip_uni_cross_mu_bs512_lr0.002/events.out.tfevents.1700184097.LAPTOP-5AN2UHOO ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c33c3e165e65aac3f1f75ddf5a1a4a3fc6e494e5be728a3a455ff453c7a40100
3
+ size 3726
README.md CHANGED
@@ -1,8 +1,8 @@
1
  ---
2
- title: Demo
3
- emoji: 🐢
4
  colorFrom: green
5
- colorTo: green
6
  sdk: streamlit
7
  sdk_version: 1.28.2
8
  app_file: app.py
 
1
  ---
2
+ title: Gomoku Zero
3
+ emoji: 📉
4
  colorFrom: green
5
+ colorTo: blue
6
  sdk: streamlit
7
  sdk_version: 1.28.2
8
  app_file: app.py
app.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ # 设置页面配置
3
+ st.set_page_config(
4
+ page_title="AI 3603 Gomoku Project",
5
+ page_icon="👋",
6
+ layout="wide",
7
+ initial_sidebar_state="collapsed"
8
+ )
9
+ # 大标题
10
+ st.write('<h1 style="text-align: center; color: black; font-weight: bold;">AI 3603 Gomoku Project 👋</h1>', unsafe_allow_html=True)
11
+ # 项目参与者
12
+ st.write('<p style="text-align: center; font-size: 20px;"><a href="https://github.com" style="color: blue; font-weight: normal; margin-right: 20px; text-decoration: none;">Jiaxin Li</a> \
13
+ <a href="https://github.com" style="color: blue; font-weight: normal; margin-right: 20px; text-decoration: none;">Junzhe Shen</a> \
14
+ <a href="https://github.com" style="color: blue; font-weight: normal; text-decoration: none;">Benhao Huang</a></p>', unsafe_allow_html=True)
15
+ # 标签
16
+ st.markdown("""
17
+ <div style="text-align: center;">
18
+ <a href="#" style="background-color: #343a40; color: white; font-size: 15px; padding: 10px 15px; margin: 5px; border-radius: 15px; text-decoration: none;">📄 Report</a>
19
+ <a href="#" style="background-color: #343a40; color: white; font-size: 15px; padding: 10px 15px; margin: 5px; border-radius: 15px; text-decoration: none;">💻 Code</a>
20
+ <a href="#" style="background-color: #343a40; color: white; font-size: 15px; padding: 10px 15px; margin: 5px; border-radius: 15px; text-decoration: none;">🌐 Space</a>
21
+ <a href="#" style="background-color: #343a40; color: white; font-size: 15px; padding: 10px 15px; margin: 5px; border-radius: 15px; text-decoration: none;">📊 PPT</a>
22
+ </div>
23
+ </br>
24
+ </br>
25
+ """, unsafe_allow_html=True)
26
+ # 项目介绍
27
+ st.markdown("""
28
+ <div style='color: black; font-size:18px'>Gomoku is an abstract strategy board game. Also called <span style='color:red;'>Gobang</span> or <span style='color:red;'>Five in a Row</span>,
29
+ it is traditionally played with Go pieces (black and white stones)
30
+ on a Go board. It is straightforward and fun, but also full of strategy and challenge.
31
+ Our project is aiming to apply Machine Learning techniques to build a powerful Gomoku AI.</div>
32
+ """,
33
+ unsafe_allow_html=True)
34
+ # 创新点和图片展示
35
+ st.write("<h2 style='text-align: center; color: black; font-weight: bold;'>Innovations We Made 👍</h2>", unsafe_allow_html=True)
36
+ col1, col2, col3 = st.columns(3)
37
+ with col1:
38
+ st.image("assets/favicon_circle.png", width=50) # 替换为你的图片 URL
39
+ st.caption("Innovation 1")
40
+ with col2:
41
+ st.image("assets/favicon_circle.png", width=50) # 替换为你的图片 URL
42
+ st.caption("Innovation 2")
43
+ with col3:
44
+ st.image("assets/favicon_circle.png", width=50) # 替换为你的图片 URL
45
+ st.caption("Innovation 3")
46
+ # 代码框架阐述和代码组件
47
+ st.write("<h2 style='text-align: center; color: black; font-weight: bold;'>Code Structure 🛠️</h2>", unsafe_allow_html=True)
48
+ st.code("""
49
+ import os
50
+ import streamlit as st
51
+ def main():
52
+ # your code here
53
+ if __name__ == "__main__":
54
+ main()
55
+ """, language="python")
56
+ st.markdown("---")
assets/favicon_circle.png ADDED
const.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+ _BOARD_SIZE = 8
4
+ _BOARD_SIZE_1D = _BOARD_SIZE * _BOARD_SIZE
5
+ _BLANK = 0
6
+ _BLACK = 1
7
+ _WHITE = 2
8
+ _PLAYER_SYMBOL = {
9
+ _WHITE: "⚪",
10
+ _BLANK: "➕",
11
+ _BLACK: "⚫",
12
+ }
13
+ _PLAYER_COLOR = {
14
+ _WHITE: "AI",
15
+ _BLANK: "Blank",
16
+ _BLACK: "YOU HUMAN",
17
+ }
18
+ _HORIZONTAL = np.array(
19
+ [
20
+ [0, 0, 0, 0, 0],
21
+ [0, 0, 0, 0, 0],
22
+ [1, 1, 1, 1, 1],
23
+ [0, 0, 0, 0, 0],
24
+ [0, 0, 0, 0, 0],
25
+ ]
26
+ )
27
+ _VERTICAL = np.array(
28
+ [
29
+ [0, 0, 1, 0, 0],
30
+ [0, 0, 1, 0, 0],
31
+ [0, 0, 1, 0, 0],
32
+ [0, 0, 1, 0, 0],
33
+ [0, 0, 1, 0, 0],
34
+ ]
35
+ )
36
+ _DIAGONAL_UP_LEFT = np.array(
37
+ [
38
+ [1, 0, 0, 0, 0],
39
+ [0, 1, 0, 0, 0],
40
+ [0, 0, 1, 0, 0],
41
+ [0, 0, 0, 1, 0],
42
+ [0, 0, 0, 0, 1],
43
+ ]
44
+ )
45
+ _DIAGONAL_UP_RIGHT = np.array(
46
+ [
47
+ [0, 0, 0, 0, 1],
48
+ [0, 0, 0, 1, 0],
49
+ [0, 0, 1, 0, 0],
50
+ [0, 1, 0, 0, 0],
51
+ [1, 0, 0, 0, 0],
52
+ ]
53
+ )
54
+
55
+ _ROOM_COLOR = {
56
+ True: _BLACK,
57
+ False: _WHITE,
58
+ }
pages/Player_VS_AI.py ADDED
@@ -0,0 +1,409 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FileName: app.py
3
+ Author: Benhao Huang
4
+ Create Date: 2023/11/18
5
+ Description: this file is used to display our project and add visualization elements to the game, using Streamlit
6
+ """
7
+
8
+ import time
9
+ import pandas as pd
10
+ from copy import deepcopy
11
+
12
+ # import torch
13
+ import numpy as np
14
+ import streamlit as st
15
+ from scipy.signal import convolve # this is used to check if any player wins
16
+ from streamlit import session_state
17
+ from streamlit_server_state import server_state, server_state_lock
18
+ from Gomoku_MCTS import MCTSpure, alphazero, Board, PolicyValueNet
19
+ import matplotlib.pyplot as plt
20
+
21
+ from const import (
22
+ _BLACK, # 1, for human
23
+ _WHITE, # 2 , for AI
24
+ _BLANK,
25
+ _PLAYER_COLOR,
26
+ _PLAYER_SYMBOL,
27
+ _ROOM_COLOR,
28
+ _VERTICAL,
29
+ _HORIZONTAL,
30
+ _DIAGONAL_UP_LEFT,
31
+ _DIAGONAL_UP_RIGHT,
32
+ _BOARD_SIZE,
33
+ _BOARD_SIZE_1D
34
+ )
35
+
36
+
37
+ # Utils
38
+ class Room:
39
+ def __init__(self, room_id) -> None:
40
+ self.ROOM_ID = room_id
41
+ # self.BOARD = np.zeros(shape=(_BOARD_SIZE, _BOARD_SIZE), dtype=int)
42
+ self.BOARD = Board(width=_BOARD_SIZE, height=_BOARD_SIZE, n_in_row=5, players=[_BLACK, _WHITE])
43
+ self.PLAYER = _BLACK
44
+ self.TURN = self.PLAYER
45
+ self.HISTORY = (0, 0)
46
+ self.WINNER = _BLANK
47
+ self.TIME = time.time()
48
+ self.MCTS = MCTSpure(c_puct=5, n_playout=10)
49
+ self.COORDINATE_1D = [_BOARD_SIZE_1D + 1]
50
+ self.current_move = -1
51
+ self.simula_time_list = []
52
+
53
+
54
+ def change_turn(cur):
55
+ return cur % 2 + 1
56
+
57
+
58
+ # Initialize the game
59
+ if "ROOM" not in session_state:
60
+ session_state.ROOM = Room("local")
61
+ if "OWNER" not in session_state:
62
+ session_state.OWNER = False
63
+
64
+ # Check server health
65
+ if "ROOMS" not in server_state:
66
+ with server_state_lock["ROOMS"]:
67
+ server_state.ROOMS = {}
68
+
69
+ # # Layout
70
+ # Main
71
+ TITLE = st.empty()
72
+ TITLE.header("🤖 AI 3603 Gomoku")
73
+ ROUND_INFO = st.empty()
74
+ st.markdown("<br>", unsafe_allow_html=True)
75
+ BOARD_PLATE = [
76
+ [cell.empty() for cell in st.columns([1 for _ in range(_BOARD_SIZE)])] for _ in range(_BOARD_SIZE)
77
+ ]
78
+ LOG = st.empty()
79
+
80
+ # Sidebar
81
+ SCORE_TAG = st.sidebar.empty()
82
+ SCORE_PLATE = st.sidebar.columns(2)
83
+ # History scores
84
+ SCORE_TAG.subheader("Scores")
85
+
86
+ PLAY_MODE_INFO = st.sidebar.container()
87
+ MULTIPLAYER_TAG = st.sidebar.empty()
88
+ with st.sidebar.container():
89
+ ANOTHER_ROUND = st.empty()
90
+ RESTART = st.empty()
91
+ EXIT = st.empty()
92
+ GAME_INFO = st.sidebar.container()
93
+ message = st.empty()
94
+ PLAY_MODE_INFO.write("---\n\n**You are Black, AI agent is White.**")
95
+ GAME_INFO.markdown(
96
+ """
97
+ ---
98
+ # <span style="color:black;">Freestyle Gomoku game. 🎲</span>
99
+ - no restrictions 🚫
100
+ - no regrets 😎
101
+ - swap players after one round is over 🔁
102
+ Powered by an AlphaZero approach with our own improvements! 🚀 For the specific details, please check out our <a href="insert_report_link_here" style="color:blue;">report</a>.
103
+ ##### Adapted and improved by us! 🌟 <a href="https://github.com/Lijiaxin0111/AI_3603_BIGHOME" style="color:blue;">Our Github repo</a>
104
+ """,
105
+ unsafe_allow_html=True,
106
+ )
107
+
108
+
109
+ def restart() -> None:
110
+ """
111
+ Restart the game.
112
+ """
113
+ session_state.ROOM = Room(session_state.ROOM.ROOM_ID)
114
+
115
+
116
+ RESTART.button(
117
+ "Reset",
118
+ on_click=restart,
119
+ help="Clear the board as well as the scores",
120
+ )
121
+
122
+
123
+ # Draw the board
124
+ def gomoku():
125
+ """
126
+ Draw the board.
127
+ Handle the main logic.
128
+ """
129
+
130
+ # Restart the game
131
+
132
+ # Continue new round
133
+ def another_round() -> None:
134
+ """
135
+ Continue new round.
136
+ """
137
+ session_state.ROOM = deepcopy(session_state.ROOM)
138
+ session_state.ROOM.BOARD = Board(width=_BOARD_SIZE, height=_BOARD_SIZE, n_in_row=5)
139
+ session_state.ROOM.PLAYER = session_state.ROOM.PLAYER % 2 + 1
140
+ session_state.ROOM.TURN = session_state.ROOM.PLAYER
141
+ session_state.ROOM.WINNER = _BLANK # 0
142
+ session_state.ROOM.COORDINATE_1D = [_BOARD_SIZE_1D + 1]
143
+
144
+ # Room status sync
145
+ def sync_room() -> bool:
146
+ room_id = session_state.ROOM.ROOM_ID
147
+ if room_id not in server_state.ROOMS.keys():
148
+ session_state.ROOM = Room("local")
149
+ return False
150
+ elif server_state.ROOMS[room_id].TIME == session_state.ROOM.TIME:
151
+ return False
152
+ elif server_state.ROOMS[room_id].TIME < session_state.ROOM.TIME:
153
+ # Only acquire the lock when writing to the server state
154
+ with server_state_lock["ROOMS"]:
155
+ server_rooms = server_state.ROOMS
156
+ server_rooms[room_id] = session_state.ROOM
157
+ server_state.ROOMS = server_rooms
158
+ return True
159
+ else:
160
+ session_state.ROOM = server_state.ROOMS[room_id]
161
+ return True
162
+
163
+ # Check if winner emerge from move
164
+ def check_win() -> int:
165
+ """
166
+ Use convolution to check if any player wins.
167
+ """
168
+ vertical = convolve(
169
+ session_state.ROOM.BOARD.board_map,
170
+ _VERTICAL,
171
+ mode="same",
172
+ )
173
+ horizontal = convolve(
174
+ session_state.ROOM.BOARD.board_map,
175
+ _HORIZONTAL,
176
+ mode="same",
177
+ )
178
+ diagonal_up_left = convolve(
179
+ session_state.ROOM.BOARD.board_map,
180
+ _DIAGONAL_UP_LEFT,
181
+ mode="same",
182
+ )
183
+ diagonal_up_right = convolve(
184
+ session_state.ROOM.BOARD.board_map,
185
+ _DIAGONAL_UP_RIGHT,
186
+ mode="same",
187
+ )
188
+ if (
189
+ np.max(
190
+ [
191
+ np.max(vertical),
192
+ np.max(horizontal),
193
+ np.max(diagonal_up_left),
194
+ np.max(diagonal_up_right),
195
+ ]
196
+ )
197
+ == 5 * _BLACK
198
+ ):
199
+ winner = _BLACK
200
+ elif (
201
+ np.min(
202
+ [
203
+ np.min(vertical),
204
+ np.min(horizontal),
205
+ np.min(diagonal_up_left),
206
+ np.min(diagonal_up_right),
207
+ ]
208
+ )
209
+ == 5 * _WHITE
210
+ ):
211
+ winner = _WHITE
212
+ else:
213
+ winner = _BLANK
214
+ return winner
215
+
216
+ # Triggers the board response on click
217
+ def handle_click(x, y):
218
+ """
219
+ Controls whether to pass on / continue current board / may start new round
220
+ """
221
+ if session_state.ROOM.BOARD.board_map[x][y] != _BLANK:
222
+ pass
223
+ elif (
224
+ session_state.ROOM.ROOM_ID in server_state.ROOMS.keys()
225
+ and _ROOM_COLOR[session_state.OWNER]
226
+ != server_state.ROOMS[session_state.ROOM.ROOM_ID].TURN
227
+ ):
228
+ sync_room()
229
+
230
+ # normal play situation
231
+ elif session_state.ROOM.WINNER == _BLANK:
232
+ # session_state.ROOM = deepcopy(session_state.ROOM)
233
+ print("View of human player: ", session_state.ROOM.BOARD.board_map)
234
+ move = session_state.ROOM.BOARD.location_to_move((x, y))
235
+ session_state.ROOM.current_move = move
236
+ session_state.ROOM.BOARD.do_move(move)
237
+ session_state.ROOM.BOARD.board_map[x][y] = session_state.ROOM.TURN
238
+ session_state.ROOM.COORDINATE_1D.append(x * _BOARD_SIZE + y)
239
+
240
+ session_state.ROOM.TURN = change_turn(session_state.ROOM.TURN)
241
+ win, winner = session_state.ROOM.BOARD.game_end()
242
+ if win:
243
+ session_state.ROOM.WINNER = winner
244
+ session_state.ROOM.HISTORY = (
245
+ session_state.ROOM.HISTORY[0]
246
+ + int(session_state.ROOM.WINNER == _WHITE),
247
+ session_state.ROOM.HISTORY[1]
248
+ + int(session_state.ROOM.WINNER == _BLACK),
249
+ )
250
+ session_state.ROOM.TIME = time.time()
251
+
252
+ def forbid_click(x, y):
253
+ # st.warning('This posistion has been occupied!!!!', icon="⚠️")
254
+ st.error("({}, {}) has been occupied!!)".format(x, y), icon="🚨")
255
+ print("asdas")
256
+
257
+ # Draw board
258
+ def draw_board(response: bool):
259
+ """construct each buttons for all cells of the board"""
260
+
261
+ if response and session_state.ROOM.TURN == _BLACK: # human turn
262
+ print("Your turn")
263
+ # construction of clickable buttons
264
+ for i, row in enumerate(session_state.ROOM.BOARD.board_map):
265
+ # print("row:", row)
266
+ for j, cell in enumerate(row):
267
+ if (
268
+ i * _BOARD_SIZE + j
269
+ in (session_state.ROOM.COORDINATE_1D)
270
+ ):
271
+ # disable click for GPT choices
272
+ BOARD_PLATE[i][j].button(
273
+ _PLAYER_SYMBOL[cell],
274
+ key=f"{i}:{j}",
275
+ args=(i, j),
276
+ on_click=forbid_click
277
+ )
278
+ else:
279
+ # enable click for other cells available for human choices
280
+ BOARD_PLATE[i][j].button(
281
+ _PLAYER_SYMBOL[cell],
282
+ key=f"{i}:{j}",
283
+ on_click=handle_click,
284
+ args=(i, j),
285
+ )
286
+
287
+
288
+ elif response and session_state.ROOM.TURN == _WHITE: # AI turn
289
+ message.empty()
290
+ with st.spinner('🔮✨ Waiting for AI response... ⏳🚀'):
291
+ time.sleep(0.1)
292
+ print("AI's turn")
293
+ print("Below are current board under AI's view")
294
+ print(session_state.ROOM.BOARD.board_map)
295
+ move, simul_time = session_state.ROOM.MCTS.get_action(session_state.ROOM.BOARD)
296
+ session_state.ROOM.simula_time_list.append(simul_time)
297
+ print("AI takes move: ", move)
298
+ session_state.ROOM.current_move = move
299
+ gpt_response = move
300
+ gpt_i, gpt_j = gpt_response // _BOARD_SIZE, gpt_response % _BOARD_SIZE
301
+ print("AI's move is located at ({}, {}) :".format(gpt_i, gpt_j))
302
+ move = session_state.ROOM.BOARD.location_to_move((gpt_i, gpt_j))
303
+ print("Location to move: ", move)
304
+ session_state.ROOM.BOARD.do_move(move)
305
+ # session_state.ROOM.BOARD[gpt_i][gpt_j] = session_state.ROOM.TURN
306
+ session_state.ROOM.COORDINATE_1D.append(gpt_i * _BOARD_SIZE + gpt_j)
307
+
308
+ # construction of clickable buttons
309
+ for i, row in enumerate(session_state.ROOM.BOARD.board_map):
310
+ # print("row:", row)
311
+ for j, cell in enumerate(row):
312
+ if (
313
+ i * _BOARD_SIZE + j
314
+ in (session_state.ROOM.COORDINATE_1D)
315
+ ):
316
+ # disable click for GPT choices
317
+ BOARD_PLATE[i][j].button(
318
+ _PLAYER_SYMBOL[cell],
319
+ key=f"{i}:{j}",
320
+ args=(i, j),
321
+ on_click=forbid_click
322
+ )
323
+ else:
324
+ # enable click for other cells available for human choices
325
+ BOARD_PLATE[i][j].button(
326
+ _PLAYER_SYMBOL[cell],
327
+ key=f"{i}:{j}",
328
+ on_click=handle_click,
329
+ args=(i, j),
330
+ )
331
+
332
+ message.markdown(
333
+ 'AI agent has calculated its strategy, which takes <span style="color: blue; font-size: 20px;">{:.3e}</span>s per simulation.'.format(
334
+ simul_time),
335
+ unsafe_allow_html=True
336
+ )
337
+ LOG.subheader("Logs")
338
+ # change turn
339
+ session_state.ROOM.TURN = change_turn(session_state.ROOM.TURN)
340
+ # session_state.ROOM.WINNER = check_win()
341
+
342
+ win, winner = session_state.ROOM.BOARD.game_end()
343
+ if win:
344
+ session_state.ROOM.WINNER = winner
345
+
346
+ session_state.ROOM.HISTORY = (
347
+ session_state.ROOM.HISTORY[0]
348
+ + int(session_state.ROOM.WINNER == _WHITE),
349
+ session_state.ROOM.HISTORY[1]
350
+ + int(session_state.ROOM.WINNER == _BLACK),
351
+ )
352
+ session_state.ROOM.TIME = time.time()
353
+
354
+ if not response or session_state.ROOM.WINNER != _BLANK:
355
+ print("Game over")
356
+ for i, row in enumerate(session_state.ROOM.BOARD.board_map):
357
+ for j, cell in enumerate(row):
358
+ BOARD_PLATE[i][j].write(
359
+ _PLAYER_SYMBOL[cell],
360
+ key=f"{i}:{j}",
361
+ )
362
+
363
+ # Game process control
364
+ def game_control():
365
+ if session_state.ROOM.WINNER != _BLANK:
366
+ draw_board(False)
367
+ else:
368
+ draw_board(True)
369
+ if session_state.ROOM.WINNER != _BLANK or 0 not in session_state.ROOM.BOARD.board_map:
370
+ ANOTHER_ROUND.button(
371
+ "Play Next round!",
372
+ on_click=another_round,
373
+ help="Clear board and swap first player",
374
+ )
375
+
376
+ # Infos
377
+ def update_info() -> None:
378
+ # Additional information
379
+ SCORE_PLATE[0].metric("Gomoku-Agent", session_state.ROOM.HISTORY[0])
380
+ SCORE_PLATE[1].metric("Black", session_state.ROOM.HISTORY[1])
381
+ if session_state.ROOM.WINNER != _BLANK:
382
+ st.balloons()
383
+ ROUND_INFO.write(
384
+ f"#### **{_PLAYER_COLOR[session_state.ROOM.WINNER]} WIN!**\n**Click buttons on the left for more plays.**"
385
+ )
386
+
387
+ # elif 0 not in session_state.ROOM.BOARD.board_map:
388
+ # ROUND_INFO.write("#### **Tie**")
389
+ # else:
390
+ # ROUND_INFO.write(
391
+ # f"#### **{_PLAYER_SYMBOL[session_state.ROOM.TURN]} {_PLAYER_COLOR[session_state.ROOM.TURN]}'s turn...**"
392
+ # )
393
+
394
+ # draw the plot for simulation time
395
+ # 创建一个 DataFrame
396
+
397
+ print(session_state.ROOM.simula_time_list)
398
+ st.markdown("<br>", unsafe_allow_html=True)
399
+ st.markdown("<br>", unsafe_allow_html=True)
400
+ chart_data = pd.DataFrame(session_state.ROOM.simula_time_list, columns=["Simulation Time"])
401
+ st.line_chart(chart_data)
402
+
403
+ # The main game loop
404
+ game_control()
405
+ update_info()
406
+
407
+
408
+ if __name__ == "__main__":
409
+ gomoku()
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ pandas~=2.1.3
2
+ numpy~=1.26.2
3
+ streamlit~=1.28.2
4
+ matplotlib~=3.8.2
5
+ scipy~=1.11.3
6
+ torch~=2.1.1
7
+ streamlit-server-state==0.17.1