sodastar committed
Commit 0819f4e · 1 Parent(s): b7dd8b4

Deploying Othello Flask backend with Docker

Files changed (7)
  1. .DS_Store +0 -0
  2. Dockerfile +23 -0
  3. LICENSE +21 -0
  4. alphazero.py +487 -0
  5. app.py +343 -0
  6. game.py +417 -0
  7. requirements.txt +12 -0
.DS_Store ADDED
Binary file (6.15 kB).
 
Dockerfile ADDED
@@ -0,0 +1,23 @@
1
+ # othello-backend/Dockerfile
2
+
3
+ # 1. Base image: the official Python image, which ships the necessary system libraries
4
+ FROM python:3.9-slim
5
+
6
+ # 2. Working directory: everything below runs inside /app
7
+ WORKDIR /app
8
+
9
+ # 3. Copy the dependency file and install the Python packages
10
+ # Copy and install the dependencies first so the Docker layer cache can be reused
11
+ COPY requirements.txt .
12
+ RUN pip install --no-cache-dir -r requirements.txt
13
+
14
+ # 4. Copy all application code and the model
15
+ # Note: game.py, alphazero.py, app.py and checkpoint/ must all live in the othello-backend directory
16
+ COPY . .
17
+
18
+ # 5. Expose the port: Hugging Face Spaces routes HTTP traffic to port 7860 by default
19
+ EXPOSE 7860
20
+
21
+ # 6. Container start command: run the Flask app
22
+ # CMD is executed when the container starts
23
+ CMD ["python", "app.py"]
LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2024 wangxuguang
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
alphazero.py ADDED
@@ -0,0 +1,487 @@
1
+ import logging
2
+ import math
3
+ import time
4
+ import os
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import torch.optim as optim
10
+ from tqdm import tqdm
11
+ from collections import deque
12
+ from pickle import Pickler, Unpickler
13
+ from random import shuffle
14
+
15
+ from game import *
16
+ logging.basicConfig(level = logging.INFO)
17
+ log = logging.getLogger(__name__)
18
+
19
+
20
+ class MCTS():
21
+ """
22
+ This class handles the MCTS tree.
23
+ """
24
+
25
+ def __init__(self, game, nnet, args):
26
+ self.game = game
27
+ self.nnet = nnet
28
+ self.args = args
29
+ self.Qsa = {} # stores Q values for s,a (as defined in the paper)
30
+ self.Nsa = {} # stores #times edge s,a was visited
31
+ self.Ns = {} # stores #times board s was visited
32
+ self.Ps = {} # stores initial policy (returned by neural net)
33
+
34
+ self.Es = {} # stores game.getGameEnded for board s
35
+ self.Vs = {} # stores game.getValidMoves for board s
36
+
37
+ def getActionProb(self, canonicalBoard, temp=1):
38
+ """
39
+ This function performs numMCTSSims simulations of MCTS starting from
40
+ canonicalBoard.
41
+
42
+ Returns:
43
+ probs: a policy vector where the probability of the ith action is
44
+ proportional to Nsa[(s,a)]**(1./temp)
45
+ """
46
+ for i in range(self.args.numMCTSSims):
47
+ self.search(canonicalBoard)
48
+
49
+ s = self.game.stringRepresentation(canonicalBoard)
50
+ counts = [self.Nsa[(s, a)] if (s, a) in self.Nsa else 0 for a in range(self.game.getActionSize())]
51
+
52
+ if temp == 0:
53
+ bestAs = np.array(np.argwhere(counts == np.max(counts))).flatten()
54
+ bestA = np.random.choice(bestAs)
55
+ probs = [0] * len(counts)
56
+ probs[bestA] = 1
57
+ return probs
58
+
59
+ counts = [x ** (1. / temp) for x in counts]
60
+ counts_sum = float(sum(counts))
61
+ probs = [x / counts_sum for x in counts]
62
+ return probs
63
+
64
+ def search(self, canonicalBoard):
65
+ """
66
+ This function performs one iteration of MCTS. It is recursively called
67
+ till a leaf node is found. The action chosen at each node is one that
68
+ has the maximum upper confidence bound as in the paper.
69
+
70
+ Once a leaf node is found, the neural network is called to return an
71
+ initial policy P and a value v for the state. This value is propagated
72
+ up the search path. In case the leaf node is a terminal state, the
73
+ outcome is propagated up the search path. The values of Ns, Nsa, Qsa are
74
+ updated.
75
+
76
+ NOTE: Since v is in [-1,1] and if v is the value of a
77
+ state for the current player, then its value is -v for the other player.
78
+
79
+ Returns:
80
+ v: the value of the current canonicalBoard
81
+ """
82
+
83
+ s = self.game.stringRepresentation(canonicalBoard)
84
+
85
+ if s not in self.Es:
86
+ self.Es[s] = self.game.getGameEnded(canonicalBoard, 1)
87
+ if self.Es[s] is not None:
88
+ # terminal node
89
+ return self.Es[s]
90
+
91
+ if s not in self.Ps:
92
+ # leaf node
93
+ self.Ps[s], v = self.nnet.predict(canonicalBoard)
94
+ valids = self.game.getValidMoves(canonicalBoard, 1)
95
+ self.Ps[s] = self.Ps[s] * valids # masking invalid moves
96
+ sum_Ps_s = np.sum(self.Ps[s])
97
+ if sum_Ps_s > 0:
98
+ self.Ps[s] /= sum_Ps_s # renormalize
99
+ else:
100
+ # if all valid moves were masked make all valid moves equally probable
101
+ log.error("All valid moves were masked, doing a workaround.")
102
+ self.Ps[s] = self.Ps[s] + valids
103
+ self.Ps[s] /= np.sum(self.Ps[s])
104
+
105
+ self.Vs[s] = valids
106
+ self.Ns[s] = 0
107
+ return v
108
+
109
+ valids = self.Vs[s]
110
+ cur_best = -float('inf')
111
+ best_act = -1
112
+
113
+ # pick the action with the highest upper confidence bound
114
+ for a in range(self.game.getActionSize()):
115
+ if valids[a]:
116
+ u = self.Qsa.get((s, a), 0) + self.args.cpuct * self.Ps[s][a] * math.sqrt(self.Ns[s]) / (
117
+ 1 + self.Nsa.get((s, a), 0))
118
+
119
+ if u > cur_best:
120
+ cur_best = u
121
+ best_act = a
122
+
123
+ a = best_act
124
+ next_s, next_player = self.game.getNextState(canonicalBoard, 1, a)
125
+ next_s = self.game.getCanonicalForm(next_s, next_player)
126
+
127
+ v = -self.search(next_s)
128
+
129
+ if (s, a) in self.Qsa:
130
+ self.Qsa[(s, a)] = (self.Nsa[(s, a)] * self.Qsa[(s, a)] + v) / (self.Nsa[(s, a)] + 1)
131
+ self.Nsa[(s, a)] += 1
132
+
133
+ else:
134
+ self.Qsa[(s, a)] = v
135
+ self.Nsa[(s, a)] = 1
136
+
137
+ self.Ns[s] += 1
138
+ return v
139
+
140
+ class OthelloNNet(nn.Module):
141
+ def __init__(self, game, args):
142
+ # game params
143
+ self.board_x, self.board_y = game.getBoardSize()
144
+ self.action_size = game.getActionSize()
145
+ self.args = args
146
+
147
+ super(OthelloNNet, self).__init__()
148
+ self.conv1 = nn.Conv2d(1, args.num_channels, 3, stride=1, padding=1)
149
+ self.conv2 = nn.Conv2d(args.num_channels, args.num_channels, 3, stride=1, padding=1)
150
+ self.conv3 = nn.Conv2d(args.num_channels, args.num_channels, 3, stride=1)
151
+ self.conv4 = nn.Conv2d(args.num_channels, args.num_channels, 3, stride=1)
152
+
153
+ self.bn1 = nn.BatchNorm2d(args.num_channels)
154
+ self.bn2 = nn.BatchNorm2d(args.num_channels)
155
+ self.bn3 = nn.BatchNorm2d(args.num_channels)
156
+ self.bn4 = nn.BatchNorm2d(args.num_channels)
157
+
158
+ self.fc1 = nn.Linear(args.num_channels*(self.board_x-4)*(self.board_y-4), 1024)
159
+ self.fc_bn1 = nn.BatchNorm1d(1024)
160
+
161
+ self.fc2 = nn.Linear(1024, 512)
162
+ self.fc_bn2 = nn.BatchNorm1d(512)
163
+
164
+ self.fc3 = nn.Linear(512, self.action_size)
165
+ self.fc4 = nn.Linear(512, 1)
166
+
167
+ def forward(self, s):
168
+ # you can add residual to the network
169
+ # s: batch_size x board_x x board_y
170
+ s = s.view(-1, 1, self.board_x, self.board_y) # batch_size x 1 x board_x x board_y
171
+ s = F.relu(self.bn1(self.conv1(s))) # batch_size x num_channels x board_x x board_y
172
+ s = F.relu(self.bn2(self.conv2(s))) # batch_size x num_channels x board_x x board_y
173
+ s = F.relu(self.bn3(self.conv3(s))) # batch_size x num_channels x (board_x-2) x (board_y-2)
174
+ s = F.relu(self.bn4(self.conv4(s))) # batch_size x num_channels x (board_x-4) x (board_y-4)
175
+ s = s.view(-1, self.args.num_channels*(self.board_x-4)*(self.board_y-4))
176
+
177
+ s = F.dropout(F.relu(self.fc_bn1(self.fc1(s))), p=self.args.dropout, training=self.training) # batch_size x 1024
178
+ s = F.dropout(F.relu(self.fc_bn2(self.fc2(s))), p=self.args.dropout, training=self.training) # batch_size x 512
179
+
180
+ pi = self.fc3(s) # batch_size x action_size
181
+ v = self.fc4(s) # batch_size x 1
182
+
183
+ return F.log_softmax(pi, dim=1), torch.tanh(v)
184
+
185
+
186
+ class AverageMeter(object):
187
+ """From https://github.com/pytorch/examples/blob/master/imagenet/main.py"""
188
+
189
+ def __init__(self):
190
+ self.val = 0
191
+ self.avg = 0
192
+ self.sum = 0
193
+ self.count = 0
194
+
195
+ def __repr__(self):
196
+ return f'{self.avg:.2e}'
197
+
198
+ def update(self, val, n=1):
199
+ self.val = val
200
+ self.sum += val * n
201
+ self.count += n
202
+ self.avg = self.sum / self.count
203
+
204
+
205
+ class NNetWrapper():
206
+ def __init__(self, game, args):
207
+ self.nnet = OthelloNNet(game, args)
208
+ self.board_x, self.board_y = game.getBoardSize()
209
+ self.action_size = game.getActionSize()
210
+ self.args = args
211
+
212
+ if args.cuda:
213
+ self.nnet.cuda()
214
+
215
+ def train(self, examples):
216
+ """
217
+ examples: list of examples, each example is of form (board, pi, v)
218
+ """
219
+ optimizer = optim.Adam(self.nnet.parameters(), lr=self.args.lr)
220
+
221
+ for epoch in range(self.args.epochs):
222
+ print('EPOCH ::: ' + str(epoch + 1))
223
+ self.nnet.train()
224
+ pi_losses = AverageMeter()
225
+ v_losses = AverageMeter()
226
+
227
+ batch_count = int(len(examples) / self.args.batch_size)
228
+
229
+ t = tqdm(range(batch_count), desc='Training Net')
230
+ for _ in t:
231
+ sample_ids = np.random.randint(len(examples), size=self.args.batch_size)
232
+ boards, pis, vs = list(zip(*[examples[i] for i in sample_ids]))
233
+ boards = torch.FloatTensor(np.array(boards).astype(np.float32)) # cast to float32 to match the model's FloatTensor parameters
234
+ target_pis = torch.FloatTensor(np.array(pis))
235
+ target_vs = torch.FloatTensor(np.array(vs).astype(np.float32))
236
+
237
+ if self.args.cuda:
238
+ # boards, target_pis, target_vs = boards.contiguous().cuda(), target_pis.contiguous().cuda(), target_vs.contiguous().cuda()
239
+ boards, target_pis, target_vs = boards.cuda(), target_pis.cuda(), target_vs.cuda()
240
+
241
+ # compute output
242
+ out_pi, out_v = self.nnet(boards)
243
+ l_pi = self.loss_pi(target_pis, out_pi)
244
+ l_v = self.loss_v(target_vs, out_v)
245
+ total_loss = l_pi + l_v
246
+
247
+ # record loss
248
+ pi_losses.update(l_pi.item(), boards.size(0))
249
+ v_losses.update(l_v.item(), boards.size(0))
250
+ t.set_postfix(Loss_pi=pi_losses, Loss_v=v_losses)
251
+
252
+ # compute gradient and do SGD step
253
+ optimizer.zero_grad()
254
+ total_loss.backward()
255
+ optimizer.step()
256
+
257
+ def predict(self, board):
258
+ """
259
+ board: np array with board
260
+ """
261
+ # timing
262
+ # start = time.time()
263
+
264
+ # preparing input
265
+ board = torch.FloatTensor(board.astype(np.float32))
266
+ if self.args.cuda: board = board.cuda()
267
+ board = board.view(1, self.board_x, self.board_y)
268
+ self.nnet.eval()
269
+ with torch.no_grad():
270
+ pi, v = self.nnet(board)
271
+
272
+ # print('PREDICTION TIME TAKEN : {0:03f}'.format(time.time()-start))
273
+ return torch.exp(pi).data.cpu().numpy()[0], v.data.cpu().numpy()[0]
274
+
275
+ def loss_pi(self, targets, outputs):
276
+ return -torch.sum(targets * outputs) / targets.size()[0]
277
+
278
+ def loss_v(self, targets, outputs):
279
+ return torch.sum((targets - outputs.view(-1)) ** 2) / targets.size()[0]
280
+
281
+ def save_checkpoint(self, folder='checkpoint', filename='checkpoint.pth.tar'):
282
+ filepath = os.path.join(folder, filename)
283
+ if not os.path.exists(folder):
284
+ print("Checkpoint Directory does not exist! Making directory {}".format(folder))
285
+ os.mkdir(folder)
286
+ else:
287
+ print("Checkpoint Directory exists! ")
288
+ torch.save({
289
+ 'state_dict': self.nnet.state_dict(),
290
+ }, filepath)
291
+
292
+ def load_checkpoint(self, folder='checkpoint', filename='checkpoint.pth.tar'):
293
+ filepath = os.path.join(folder, filename)
294
+ if not os.path.exists(filepath):
295
+ raise ValueError("No model in path {}".format(filepath))
296
+ map_location = None if self.args.cuda else 'cpu'
297
+ checkpoint = torch.load(filepath, map_location=map_location)
298
+ self.nnet.load_state_dict(checkpoint['state_dict'])
299
+
300
+
301
+ class SelfPlay():
302
+ """
303
+ This class executes the self-play + learning.
304
+ """
305
+
306
+ def __init__(self, game, nnet, args):
307
+ self.game = game
308
+ self.nnet = nnet
309
+ self.pnet = self.nnet.__class__(self.game, args) # the competitor network
310
+ self.args = args
311
+ self.mcts = MCTS(self.game, self.nnet, self.args)
312
+ self.trainExamplesHistory = [] # history of examples from args.numItersForTrainExamplesHistory latest iterations
313
+
314
+ def executeEpisode(self):
315
+ """
316
+ This function executes one episode of self-play, starting with player 1.
317
+ As the game is played, each turn is added as a training example to
318
+ trainExamples. The game is played till the game ends. After the game
319
+ ends, the outcome of the game is used to assign values to each example
320
+ in trainExamples.
321
+
322
+ It uses a temp=1 if episodeStep < tempThreshold, and thereafter
323
+ uses temp=0.
324
+
325
+ Returns:
326
+ trainExamples: a list of examples of the form (canonicalBoard, pi, v)
327
+ """
328
+ trainExamples = []
329
+ board = self.game.getInitBoard()
330
+ self.curPlayer = 1
331
+ episodeStep = 0
332
+
333
+ while True:
334
+ episodeStep += 1
335
+ canonicalBoard = self.game.getCanonicalForm(board, self.curPlayer)
336
+ temp = int(episodeStep < self.args.tempThreshold)
337
+
338
+ pi = self.mcts.getActionProb(canonicalBoard, temp=temp)
339
+ sym = self.game.getSymmetries(canonicalBoard, pi)
340
+ for b, p in sym:
341
+ trainExamples.append([b, self.curPlayer, p, None])
342
+
343
+ action = np.random.choice(len(pi), p=pi)
344
+ board, self.curPlayer = self.game.getNextState(board, self.curPlayer, action)
345
+
346
+ r = self.game.getGameEnded(board, self.curPlayer)
347
+
348
+ if r is not None:
349
+ # r * (1 if self.curPlayer == x[1] else -1) means 1 for winner, -1 for loser, 0 for draw.
350
+ return [(x[0], x[2], r * (1 if self.curPlayer == x[1] else -1)) for x in trainExamples]
351
+
352
+ def learn(self):
353
+ """
354
+ Performs numIters iterations with numEps episodes of self-play in each
355
+ iteration. After every iteration, it retrains neural network with
356
+ examples in trainExamples (which has a maximum length of maxlenofQueue).
357
+ It then pits the new neural network against the old one and accepts it
358
+ only if it wins >= updateThreshold fraction of games.
359
+ """
360
+
361
+ for i in range(1, self.args.numIters + 1):
362
+ # bookkeeping
363
+ log.info(f'Starting Iter #{i} ...')
364
+ # examples of the iteration
365
+ iterationTrainExamples = deque([], maxlen=self.args.maxlenOfQueue)
366
+
367
+ for _ in tqdm(range(self.args.numEps), desc="Self Play"):
368
+ self.mcts = MCTS(self.game, self.nnet, self.args) # reset search tree
369
+ iterationTrainExamples += self.executeEpisode()
370
+
371
+ # save the iteration examples to the history
372
+ self.trainExamplesHistory.append(iterationTrainExamples)
373
+
374
+ if len(self.trainExamplesHistory) > self.args.numItersForTrainExamplesHistory:
375
+ log.warning(
376
+ f"Removing the oldest entry in trainExamples. len(trainExamplesHistory) = {len(self.trainExamplesHistory)}")
377
+ self.trainExamplesHistory.pop(0)
378
+
379
+ # shuffle examples before training
380
+ trainExamples = []
381
+ for e in self.trainExamplesHistory:
382
+ trainExamples.extend(e)
383
+ shuffle(trainExamples)
384
+
385
+ # training new network, keeping a copy of the old one
386
+ self.nnet.save_checkpoint(folder=self.args.checkpoint, filename='temp.pth.tar')
387
+ self.pnet.load_checkpoint(folder=self.args.checkpoint, filename='temp.pth.tar')
388
+ pmcts = MCTS(self.game, self.pnet, self.args)
389
+
390
+ self.nnet.train(trainExamples)
391
+ nmcts = MCTS(self.game, self.nnet, self.args)
392
+
393
+ log.info('PITTING AGAINST PREVIOUS VERSION')
394
+ arena = Arena(lambda x: np.argmax(pmcts.getActionProb(x, temp=0)),
395
+ lambda x: np.argmax(nmcts.getActionProb(x, temp=0)), self.game)
396
+ pwins, nwins, draws = arena.playGames(self.args.arenaCompare)
397
+
398
+ log.info('NEW/PREV WINS : %d / %d ; DRAWS : %d' % (nwins, pwins, draws))
399
+ if pwins + nwins == 0 or float(nwins) / (pwins + nwins) < self.args.updateThreshold:
400
+ log.info('REJECTING NEW MODEL')
401
+ self.nnet.load_checkpoint(folder=self.args.checkpoint, filename='temp.pth.tar')
402
+ else:
403
+ log.info('ACCEPTING NEW MODEL')
404
+ self.nnet.save_checkpoint(folder=self.args.checkpoint, filename='best.pth.tar')
405
+
406
+
407
+ class dotdict(dict):
408
+ def __getattr__(self, name):
409
+ return self[name]
410
+
411
+
412
+ args = dotdict({
413
+ 'lr': 0.001,
414
+ 'dropout': 0.1,
415
+ 'epochs': 10,
416
+ 'batch_size': 64,
417
+ 'cuda': torch.cuda.is_available(),
418
+ 'num_channels': 512,
419
+
420
+ 'numIters': 200,
421
+ 'numEps': 100, # Number of complete self-play games to simulate during a new iteration.
422
+ 'tempThreshold': 15, #
423
+ 'updateThreshold': 0.6, # During arena playoff, new neural net will be accepted if threshold ratio or more of games are won.
424
+ 'maxlenOfQueue': 200000, # Number of game examples to train the neural networks.
425
+ 'numItersForTrainExamplesHistory': 20,
426
+ 'numMCTSSims': 25, # Number of games moves for MCTS to simulate.
427
+ 'arenaCompare': 40, # Number of games to play during arena play to determine if new net will be accepted.
428
+ 'cpuct': 1,
429
+
430
+ 'checkpoint': './temp/',
431
+ 'load_model': False,
432
+ 'load_folder_file': ('./temp/','best.pth.tar'),
433
+ })
434
+
435
+ def main():
436
+ import argparse
437
+ parser = argparse.ArgumentParser()
438
+ parser.add_argument('--train', action="store_true")
439
+ parser.add_argument('--board_size', type=int, default=6)
440
+ # play arguments
441
+ parser.add_argument('--play', action="store_true")
442
+ parser.add_argument('--verbose', action="store_true")
443
+ parser.add_argument('--round', type=int, default=2)
444
+ parser.add_argument('--player1', type=str, default='human', choices=['human', 'random', 'greedy', 'alphazero'])
445
+ parser.add_argument('--player2', type=str, default='alphazero', choices=['human', 'random', 'greedy', 'alphazero'])
446
+ parser.add_argument('--ckpt_file', type=str, default='best.pth.tar')
447
+ args_input = vars(parser.parse_args())
448
+ for k,v in args_input.items():
449
+ args[k] = v
450
+
451
+ g = OthelloGame(args.board_size)
452
+
453
+ if args.train:
454
+ nnet = NNetWrapper(g, args)
455
+ if args.load_model:
456
+ log.info('Loading checkpoint "%s/%s"...', args.load_folder_file[0], args.load_folder_file[1])
457
+ nnet.load_checkpoint(args.load_folder_file[0], args.load_folder_file[1])
458
+
459
+ log.info('Loading the SelfPlay coach...')
460
+ s = SelfPlay(g, nnet, args)
461
+
462
+ log.info('Starting the learning process 🎉')
463
+ s.learn()
464
+
465
+ if args.play:
466
+ def getPlayFunc(name):
467
+ if name == 'human':
468
+ return HumanOthelloPlayer(g).play
469
+ elif name == 'random':
470
+ return RandomPlayer(g).play
471
+ elif name == 'greedy':
472
+ return GreedyOthelloPlayer(g).play
473
+ elif name == 'alphazero':
474
+ nnet = NNetWrapper(g, args)
475
+ nnet.load_checkpoint(args.checkpoint, args.ckpt_file)
476
+ mcts = MCTS(g, nnet, dotdict({'numMCTSSims': 200, 'cpuct':1.0}))
477
+ return lambda x: np.argmax(mcts.getActionProb(x, temp=0))
478
+ else:
479
+ raise ValueError('unsupported player name {}'.format(name))
480
+ player1 = getPlayFunc(args.player1)
481
+ player2 = getPlayFunc(args.player2)
482
+ arena = Arena(player1, player2, g, display=OthelloGame.display)
483
+ results = arena.playGames(args.round, verbose=args.verbose)
484
+ print("Final results: Player1 wins {}, Player2 wins {}, Draws {}".format(*results))
485
+
486
+ if __name__ == '__main__':
487
+ main()
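As a side note on the temperature handling documented in MCTS.getActionProb above: visit counts are raised to the power 1/temp and normalized, while temp=0 collapses the distribution onto the most-visited action. A minimal, self-contained sketch of that post-processing (the helper name and example counts are made up for illustration):

import numpy as np

def counts_to_probs(counts, temp=1.0):
    """Mirrors how getActionProb turns MCTS visit counts into a policy."""
    counts = np.asarray(counts, dtype=np.float64)
    if temp == 0:
        best = np.random.choice(np.flatnonzero(counts == counts.max()))
        probs = np.zeros_like(counts)
        probs[best] = 1.0
        return probs
    counts = counts ** (1.0 / temp)
    return counts / counts.sum()

# 25 simulations spread over 5 actions:
print(counts_to_probs([10, 8, 4, 2, 1], temp=1))  # [0.4, 0.32, 0.16, 0.08, 0.04]
print(counts_to_probs([10, 8, 4, 2, 1], temp=0))  # one-hot on the most-visited action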
app.py ADDED
@@ -0,0 +1,343 @@
1
+ import numpy as np
2
+ import torch
3
+ import json
4
+ import os
5
+ import logging
6
+ from flask import Flask, request, jsonify
7
+ from flask_cors import CORS
8
+ import time
9
+
10
+ # Import the game and AI modules
11
+ from game import OthelloGame #
12
+ from alphazero import NNetWrapper, MCTS, dotdict #
13
+
14
+ # Configure logging
15
+ logging.basicConfig(level=logging.INFO)
16
+ log = logging.getLogger(__name__)
17
+
18
+ app = Flask(__name__)
19
+ # Enable CORS so the frontend (usually served from a different origin/port) can call the API
20
+ CORS(app)
21
+
22
+ # --- Global state and AI initialization ---
23
+
24
+ # Default parameters (kept in sync with args in alphazero.py)
25
+ args = dotdict({
26
+ 'lr': 0.001,
27
+ 'dropout': 0.1,
28
+ 'epochs': 10,
29
+ 'batch_size': 64,
30
+ 'cuda': torch.cuda.is_available(),
31
+ 'num_channels': 512,
32
+ 'numIters': 200,
33
+ 'numEps': 100,
34
+ 'tempThreshold': 15,
35
+ 'updateThreshold': 0.6,
36
+ 'maxlenOfQueue': 200000,
37
+ 'numItersForTrainExamplesHistory': 20,
38
+ 'numMCTSSims': 25, # number of MCTS simulations used during training
39
+ 'arenaCompare': 40,
40
+ 'cpuct': 1,
41
+ 'checkpoint': './temp/',
42
+ 'load_model': True,
43
+ 'load_folder_file': ('./temp/','best.pth.tar'),
44
+ 'board_size': 8 # default 8x8
45
+ })
46
+
47
+ # Game and AI instances
48
+ game = None
49
+ nnet = None
50
+ mcts = None
51
+
52
+ # Game state
53
+ current_board = None
54
+ current_player = 1 # 1: Human (White), -1: AI (Black)
55
+ last_move_coords = None
56
+ board_size = 8
57
+
58
+ def init_game_and_ai(n):
59
+ """根据板子大小初始化游戏和 AI 模块"""
60
+ global game, nnet, mcts, board_size
61
+ board_size = n
62
+ log.info(f"Initializing game and AI for {n}x{n} board.")
63
+ game = OthelloGame(n) #
64
+
65
+ # Note: an AlphaZero model is normally trained for one fixed board size.
66
+ # If the model only supports 8x8, this case needs handling or retraining.
67
+ # Here we assume the model supports the current size n.
68
+
69
+ # Reconfigure the MCTS parameters for play mode
70
+ play_args = dotdict({
71
+ 'numMCTSSims': 200, # use more simulations when playing against a human
72
+ 'cpuct': 1.0,
73
+ 'cuda': args.cuda # inherit the CUDA setting
74
+ })
75
+
76
+ nnet = NNetWrapper(game, args) #
77
+ # Assumes the trained model file has been saved under args.load_folder_file ('./temp/best.pth.tar' by default)
78
+ try:
79
+ load_folder = args.load_folder_file[0]
80
+ load_file = args.load_folder_file[1]
81
+ nnet.load_checkpoint(folder=load_folder, filename=load_file)
82
+ log.info(f"Successfully loaded model from {load_folder}{load_file}")
83
+ except ValueError as e:
84
+ log.error(f"Failed to load model: {e}. AI will likely perform poorly.")
85
+
86
+ mcts = MCTS(game, nnet, play_args) #
87
+
88
+
89
+ def get_api_moves(board, player):
90
+ """将 getValidMoves 结果从向量转换为 {x, y} 列表"""
91
+ if game is None: return []
92
+
93
+ valids = game.getValidMoves(board, player) #
94
+ moves_list = []
95
+ # Exclude the last action (the pass action)
96
+ for i in range(len(valids) - 1): #
97
+ if valids[i] == 1:
98
+ x = i // game.n
99
+ y = i % game.n
100
+ moves_list.append({'x': int(x), 'y': int(y)})
101
+ return moves_list
102
+
103
+ def check_game_end(board, player):
104
+ """检查游戏是否结束,并返回状态信息,基于绝对的棋子数量差异。"""
105
+
106
+ # Get the relative game result (1: player wins, -1: player loses, 0: draw)
107
+ # Note: this result is relative to the player that was passed in
108
+ result = game.getGameEnded(board, player) #
109
+
110
+ status = 'Ongoing'
111
+ score_diff = 0
112
+
113
+ if result is not None:
114
+ # Count the white (1) and black (-1) pieces.
115
+ # score_diff here is (white count - black count)
116
+ white_count = np.sum(board == 1)
117
+ black_count = np.sum(board == -1)
118
+ score_diff = int(white_count - black_count)
119
+
120
+ if result == 0:
121
+ status = f"Game Over: Draw. Score: {white_count} vs {black_count}"
122
+ elif score_diff > 0:
123
+ # More white (Human) pieces: the human wins
124
+ status = f"Game Over: Human (O) Wins! Score: {white_count} vs {black_count}"
125
+ elif score_diff < 0:
126
+ # More black (AI) pieces: the AI wins
127
+ status = f"Game Over: AI (X) Wins! Score: {white_count} vs {black_count}"
128
+ else:
129
+ # In theory the score cannot be 0 when result != 0, but guard against it anyway
130
+ status = f"Game Over: Draw. Score: {white_count} vs {black_count}"
131
+
132
+ return status
133
+
134
+ @app.route('/api/game/new', methods=['POST'])
135
+ def new_game():
136
+ global current_board, current_player, last_move_coords, board_size
137
+ data = request.json
138
+ size = data.get('size', 8)
139
+
140
+ # [Added] Accept a first_player parameter, defaulting to 1 (Human)
141
+ first_player = data.get('first_player', 1)
142
+
143
+ if game is None or size != board_size:
144
+ init_game_and_ai(size)
145
+
146
+ current_board = game.getInitBoard() #
147
+ current_player = first_player # [Changed] set the current player from the received first_player
148
+ last_move_coords = None
149
+
150
+ status = check_game_end(current_board, current_player)
151
+
152
+ # [Added] If the AI moves first, trigger the AI move immediately
153
+ if current_player == -1 and status == 'Ongoing':
154
+ return ai_move_logic() # call the AI logic directly and return its result
155
+ # Flip current_board vertically to match the frontend's display orientation
156
+ current_board = np.flip(current_board, 0)
157
+
158
+ return jsonify({
159
+ 'board': current_board.tolist(),
160
+ 'legal_moves': get_api_moves(current_board, current_player),
161
+ 'current_player': current_player,
162
+ 'last_move': last_move_coords,
163
+ 'status': status,
164
+ })
165
+
166
+
167
+ # @app.route('/api/game/human_move', methods=['POST'])
168
+ # def human_move():
169
+ # """处理人类玩家移动"""
170
+ # global current_board, current_player, last_move_coords
171
+
172
+ # if current_player != 1 or check_game_end(current_board, current_player) != 'Ongoing':
173
+ # return jsonify({'error': 'Not your turn or game is over'}), 400
174
+
175
+ # data = request.json
176
+ # x = data.get('x')
177
+ # y = data.get('y')
178
+
179
+ # if x is None or y is None:
180
+ # # Check whether this is a pass action
181
+ # if data.get('action') == 'pass':
182
+ # action = game.n * game.n # Pass action is the last index
183
+ # else:
184
+ # return jsonify({'error': 'Invalid move coordinates'}), 400
185
+ # else:
186
+ # action = game.n * x + y
187
+
188
+ # valids = game.getValidMoves(current_board, 1) #
189
+ # if valids[action] == 0:
190
+ # return jsonify({'error': 'Illegal move'}), 400
191
+
192
+ # current_board, current_player = game.getNextState(current_board, 1, action) #
193
+
194
+ # if action != game.n * game.n:
195
+ # last_move_coords = {'x': x, 'y': y}
196
+
197
+ # status = check_game_end(current_board, current_player)
198
+
199
+ # # If the game is not over and it is the AI's (-1) turn
200
+ # if status == 'Ongoing' and current_player == -1:
201
+ # # Trigger the AI move here
202
+ # return ai_move_logic()
203
+
204
+ # return jsonify({
205
+ # 'board': current_board.tolist(),
206
+ # 'legal_moves': get_api_moves(current_board, current_player),
207
+ # 'current_player': current_player,
208
+ # 'last_move': last_move_coords,
209
+ # 'status': status,
210
+ # })
211
+
212
+ def ai_move_logic():
213
+ """AI 移动的逻辑封装,在 human_move 中调用"""
214
+ global current_board, current_player, last_move_coords
215
+
216
+ canonical_board = game.getCanonicalForm(current_board, -1) #
217
+
218
+ # Get the AI's best action (temp=0)
219
+ ai_action = np.argmax(mcts.getActionProb(canonical_board, temp=0)) #
220
+
221
+ # Update the game state
222
+ current_board, next_player = game.getNextState(current_board, -1, ai_action) #
223
+ current_player = next_player
224
+
225
+ # Record the AI's move coordinates
226
+ if ai_action != game.n * game.n: # if it is not a pass action
227
+ ai_x = ai_action // game.n
228
+ ai_y = ai_action % game.n
229
+ last_move_coords = {'x': int(ai_x), 'y': int(ai_y)}
230
+
231
+ status = check_game_end(current_board, current_player)
232
+
233
+ # Flip current_board vertically to match the frontend's display orientation
234
+ current_board = np.flip(current_board, 0)
235
+
236
+ return jsonify({
237
+ 'board': current_board.tolist(),
238
+ 'legal_moves': get_api_moves(current_board, current_player),
239
+ 'current_player': current_player,
240
+ 'last_move': last_move_coords,
241
+ 'status': status,
242
+ })
243
+
244
+ # app.py (under the @app.route('/api/game/human_move', methods=['POST']) route)
245
+ # Replaces the original handleHumanMove/human_move function
246
+
247
+ @app.route('/api/game/human_move', methods=['POST'])
248
+ def human_move():
249
+ """处理人类玩家移动,并返回给 AI 的中间状态"""
250
+ global current_board, current_player, last_move_coords
251
+
252
+ if current_player != 1 or check_game_end(current_board, current_player) != 'Ongoing':
253
+ return jsonify({'error': 'Not your turn or game is over'}), 400
254
+
255
+ data = request.json
256
+ x = data.get('x')
257
+ y = data.get('y')
258
+
259
+ if x is None or y is None:
260
+ # Check whether this is a pass action
261
+ if data.get('action') == 'pass':
262
+ action = game.n * game.n # Pass action is the last index
263
+ else:
264
+ return jsonify({'error': 'Invalid move coordinates'}), 400
265
+ else:
266
+ action = game.n * x + y
267
+
268
+ valids = game.getValidMoves(current_board, 1)
269
+ if valids[action] == 0:
270
+ return jsonify({'error': 'Illegal move'}), 400
271
+
272
+ # Apply the human move
273
+ current_board, current_player = game.getNextState(current_board, 1, action)
274
+
275
+ if action != game.n * game.n:
276
+ last_move_coords = {'x': x, 'y': y}
277
+
278
+ status = check_game_end(current_board, current_player)
279
+
280
+ # Flip current_board vertically to match the frontend's display orientation
281
+ # current_board = np.flip(current_board, 0)
282
+
283
+ # Note: the AI move logic is no longer included here; return the state directly
284
+ return jsonify({
285
+ 'board': current_board.tolist(),
286
+ 'legal_moves': get_api_moves(current_board, current_player),
287
+ 'current_player': current_player,
288
+ 'last_move': last_move_coords,
289
+ 'status': status,
290
+ })
291
+
292
+
293
+ # B. New ai_move route
294
+
295
+ @app.route('/api/game/ai_move', methods=['POST'])
296
+ def ai_move():
297
+ start_time = time.time()
298
+ """触发 AI 移动,并返回最终状态"""
299
+ global current_board, current_player, last_move_coords
300
+
301
+ if current_player != -1:
302
+ return jsonify({'error': 'Not AI turn'}), 400
303
+
304
+
305
+ canonical_board = game.getCanonicalForm(current_board, -1)
306
+ ai_action = np.argmax(mcts.getActionProb(canonical_board, temp=0))
307
+
308
+ # Apply the AI move
309
+ current_board, next_player = game.getNextState(current_board, -1, ai_action)
310
+ current_player = next_player
311
+
312
+ # Record the AI's move coordinates
313
+ if ai_action != game.n * game.n:
314
+ ai_x = ai_action // game.n
315
+ ai_y = ai_action % game.n
316
+ last_move_coords = {'x': int(ai_x), 'y': int(ai_y)}
317
+ else:
318
+ last_move_coords = None # AI Pass
319
+
320
+ status = check_game_end(current_board, current_player)
321
+
322
+ # Enforce a minimum AI "thinking" time of 0.5 seconds
323
+ end_time = time.time()
324
+ used_time = end_time - start_time
325
+ if used_time < 0.5:
326
+ time.sleep(0.5 - used_time) # make sure at least 0.5 seconds elapse
327
+
328
+ return jsonify({
329
+ 'board': current_board.tolist(),
330
+ 'legal_moves': get_api_moves(current_board, current_player),
331
+ 'current_player': current_player,
332
+ 'last_move': last_move_coords,
333
+ 'status': status,
334
+ })
335
+
336
+ if __name__ == '__main__':
337
+ # Initialize a default 8x8 game instance
338
+ init_game_and_ai(8)
339
+ log.info("Starting Flask server on port 7860...")
340
+
341
+ port = int(os.environ.get('PORT', 7860))
342
+ # ... (logging) ...
343
+ app.run(host='0.0.0.0', port=port)
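For reference, a minimal client for the routes above (a sketch only, assuming the server from this commit is running locally on port 7860 and the requests package is available):

import requests

BASE = "http://localhost:7860"

# Start a new 8x8 game with the AI (player -1) moving first; because of the
# first_player handling in new_game, the response already contains the board
# after the AI's opening move.
state = requests.post(f"{BASE}/api/game/new",
                      json={"size": 8, "first_player": -1}).json()
print("status:", state["status"])
print("AI opened with:", state["last_move"])
print("legal replies for the human:", state["legal_moves"])

# A human reply is then POSTed as {"x": ..., "y": ...} to /api/game/human_move,
# and the AI's answer is requested from /api/game/ai_move (only valid while
# current_player == -1); both return the same JSON shape as /api/game/new.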
game.py ADDED
@@ -0,0 +1,417 @@
1
+ import numpy as np
2
+ import logging
3
+ from tqdm import tqdm
4
+ log = logging.getLogger(__name__)
5
+
6
+ class Board():
7
+ '''
8
+ Author: Eric P. Nichols
9
+ Date: Feb 8, 2008.
10
+ Board class.
11
+ Board data:
12
+ 1=white, -1=black, 0=empty
13
+ '''
14
+
15
+ # list of all 8 directions on the board, as (x,y) offsets
16
+ __directions = [(1,1),(1,0),(1,-1),(0,-1),(-1,-1),(-1,0),(-1,1),(0,1)]
17
+
18
+ def __init__(self, n):
19
+ "Set up initial board configuration."
20
+
21
+ self.n = n
22
+ # Create the empty board array.
23
+ self.pieces = [None]*self.n
24
+ for i in range(self.n):
25
+ self.pieces[i] = [0]*self.n
26
+
27
+ # Set up the initial 4 pieces.
28
+ self.pieces[int(self.n/2)-1][int(self.n/2)] = 1
29
+ self.pieces[int(self.n/2)][int(self.n/2)-1] = 1
30
+ self.pieces[int(self.n/2)-1][int(self.n/2)-1] = -1
31
+ self.pieces[int(self.n/2)][int(self.n/2)] = -1
32
+
33
+ # add [][] indexer syntax to the Board
34
+ def __getitem__(self, index):
35
+ return self.pieces[index]
36
+
37
+ def countDiff(self, color):
38
+ """Counts the # pieces of the given color
39
+ (1 for white, -1 for black, 0 for empty spaces)"""
40
+ count = 0
41
+ for y in range(self.n):
42
+ for x in range(self.n):
43
+ if self[x][y]==color:
44
+ count += 1
45
+ if self[x][y]==-color:
46
+ count -= 1
47
+ return count
48
+
49
+ def get_legal_moves(self, color):
50
+ """Returns all the legal moves for the given color.
51
+ (1 for white, -1 for black)
52
+ """
53
+ moves = set() # stores the legal moves.
54
+
55
+ # Get all the squares with pieces of the given color.
56
+ for y in range(self.n):
57
+ for x in range(self.n):
58
+ if self[x][y]==color:
59
+ newmoves = self.get_moves_for_square((x,y))
60
+ moves.update(newmoves)
61
+ return list(moves)
62
+
63
+ def has_legal_moves(self, color):
64
+ for y in range(self.n):
65
+ for x in range(self.n):
66
+ if self[x][y]==color:
67
+ newmoves = self.get_moves_for_square((x,y))
68
+ if len(newmoves)>0:
69
+ return True
70
+ return False
71
+
72
+ def get_moves_for_square(self, square):
73
+ """Returns all the legal moves that use the given square as a base.
74
+ That is, if the given square is (3,4) and it contains a black piece,
75
+ and (3,5) and (3,6) contain white pieces, and (3,7) is empty, one
76
+ of the returned moves is (3,7) because everything from there to (3,4)
77
+ is flipped.
78
+ """
79
+ (x,y) = square
80
+
81
+ # determine the color of the piece.
82
+ color = self[x][y]
83
+
84
+ # skip empty source squares.
85
+ if color==0:
86
+ return None
87
+
88
+ # search all possible directions.
89
+ moves = []
90
+ for direction in self.__directions:
91
+ move = self._discover_move(square, direction)
92
+ if move:
93
+ moves.append(move)
94
+
95
+ # return the generated move list
96
+ return moves
97
+
98
+ def execute_move(self, move, color):
99
+ """Perform the given move on the board; flips pieces as necessary.
100
+ color gives the color of the piece to play (1=white,-1=black)
101
+ """
102
+
103
+ #Much like move generation, start at the new piece's square and
104
+ #follow it on all 8 directions to look for a piece allowing flipping.
105
+
106
+ flips = [flip for direction in self.__directions
107
+ for flip in self._get_flips(move, direction, color)]
108
+ assert len(list(flips))>0
109
+ for x, y in flips:
110
+ self[x][y] = color
111
+
112
+ def _discover_move(self, origin, direction):
113
+ """ Returns the endpoint for a legal move, starting at the given origin,
114
+ moving by the given increment."""
115
+ x, y = origin
116
+ color = self[x][y]
117
+ flips = []
118
+
119
+ for x, y in Board._increment_move(origin, direction, self.n):
120
+ if self[x][y] == 0:
121
+ if flips:
122
+ return (x, y)
123
+ else:
124
+ return None
125
+ elif self[x][y] == color:
126
+ return None
127
+ elif self[x][y] == -color:
128
+ flips.append((x, y))
129
+
130
+ def _get_flips(self, origin, direction, color):
131
+ """ Gets the list of flips for a vertex and direction to use with the
132
+ execute_move function """
133
+ #initialize variables
134
+ flips = [origin]
135
+
136
+ for x, y in Board._increment_move(origin, direction, self.n):
137
+ if self[x][y] == 0:
138
+ return []
139
+ if self[x][y] == -color:
140
+ flips.append((x, y))
141
+ elif self[x][y] == color and len(flips) > 0:
142
+ return flips
143
+
144
+ return []
145
+
146
+ @staticmethod
147
+ def _increment_move(move, direction, n):
148
+ """ Generator expression for incrementing moves """
149
+ move = list(map(sum, zip(move, direction)))
150
+ #move = (move[0]+direction[0], move[1]+direction[1])
151
+ while all(map(lambda x: 0 <= x < n, move)):
152
+ #while 0<=move[0] and move[0]<n and 0<=move[1] and move[1]<n:
153
+ yield move
154
+ move=list(map(sum,zip(move,direction)))
155
+ #move = (move[0]+direction[0],move[1]+direction[1])
156
+
157
+
158
+ class OthelloGame():
159
+ square_content = {
160
+ -1: "X",
161
+ +0: "-",
162
+ +1: "O"
163
+ }
164
+
165
+ @staticmethod
166
+ def getSquarePiece(piece):
167
+ return OthelloGame.square_content[piece]
168
+
169
+ def __init__(self, n):
170
+ self.n = n
171
+
172
+ def getInitBoard(self):
173
+ # return initial board (numpy board)
174
+ b = Board(self.n)
175
+ return np.array(b.pieces)
176
+
177
+ def getBoardSize(self):
178
+ # (a,b) tuple
179
+ return (self.n, self.n)
180
+
181
+ def getActionSize(self):
182
+ # return number of actions
183
+ return self.n*self.n + 1
184
+
185
+ def getNextState(self, board, player, action):
186
+ # if player takes action on board, return next (board,player)
187
+ # action must be a valid move
188
+ if action == self.n*self.n:
189
+ return (board, -player)
190
+ b = Board(self.n)
191
+ b.pieces = np.copy(board)
192
+ move = (int(action/self.n), action%self.n)
193
+ b.execute_move(move, player)
194
+ return (b.pieces, -player)
195
+
196
+ def getValidMoves(self, board, player):
197
+ # return a fixed size binary vector
198
+ valids = [0]*self.getActionSize()
199
+ b = Board(self.n)
200
+ b.pieces = np.copy(board)
201
+ legalMoves = b.get_legal_moves(player)
202
+ if len(legalMoves)==0:
203
+ valids[-1]=1
204
+ return np.array(valids)
205
+ for x, y in legalMoves:
206
+ valids[self.n*x+y]=1
207
+ return np.array(valids)
208
+
209
+ def getGameEnded(self, board, player):
210
+ # return None if not ended, 1 if player won, -1 if player lost, 0 if draw.
211
+ b = Board(self.n)
212
+ b.pieces = np.copy(board)
213
+ if b.has_legal_moves(player):
214
+ return None
215
+ if b.has_legal_moves(-player):
216
+ return None
217
+ if b.countDiff(player) > 0:
218
+ return 1
219
+ elif b.countDiff(player) < 0:
220
+ return -1
221
+ else:
222
+ return 0
223
+
224
+ def getCanonicalForm(self, board, player):
225
+ # return state if player==1, else return -state if player==-1
226
+ return player*board
227
+
228
+ def getSymmetries(self, board, pi):
229
+ # mirror, rotational
230
+ assert(len(pi) == self.n**2+1) # 1 for pass
231
+ pi_board = np.reshape(pi[:-1], (self.n, self.n))
232
+ l = []
233
+
234
+ for i in range(1, 5):
235
+ for j in [True, False]:
236
+ newB = np.rot90(board, i)
237
+ newPi = np.rot90(pi_board, i)
238
+ if j:
239
+ newB = np.fliplr(newB)
240
+ newPi = np.fliplr(newPi)
241
+ l += [(newB, list(newPi.ravel()) + [pi[-1]])]
242
+ return l
243
+
244
+ def stringRepresentation(self, board):
245
+ return board.tobytes()  # tostring() is deprecated in NumPy; tobytes() returns the same bytes
246
+
247
+ def stringRepresentationReadable(self, board):
248
+ board_s = "".join(self.square_content[square] for row in board for square in row)
249
+ return board_s
250
+
251
+ def getScore(self, board, player):
252
+ b = Board(self.n)
253
+ b.pieces = np.copy(board)
254
+ return b.countDiff(player)
255
+
256
+ @staticmethod
257
+ def display(board):
258
+ n = board.shape[0]
259
+ print(" ", end="")
260
+ for y in range(n):
261
+ print(y, end=" ")
262
+ print("")
263
+ print("-----------------------")
264
+ for y in range(n):
265
+ print(y, "|", end="")
266
+ for x in range(n):
267
+ piece = board[y][x]
268
+ print(OthelloGame.square_content[piece], end=" ")
269
+ print("|")
270
+
271
+ print("-----------------------")
272
+
273
+
274
+ class RandomPlayer():
275
+ def __init__(self, game):
276
+ self.game = game
277
+
278
+ def play(self, board):
279
+ a = np.random.randint(self.game.getActionSize())
280
+ valids = self.game.getValidMoves(board, 1)
281
+ while valids[a]!=1:
282
+ a = np.random.randint(self.game.getActionSize())
283
+ return a
284
+
285
+ class GreedyOthelloPlayer():
286
+ def __init__(self, game):
287
+ self.game = game
288
+
289
+ def play(self, board):
290
+ valids = self.game.getValidMoves(board, 1)
291
+ candidates = []
292
+ for a in range(self.game.getActionSize()):
293
+ if valids[a]==0:
294
+ continue
295
+ nextBoard, _ = self.game.getNextState(board, 1, a)
296
+ score = self.game.getScore(nextBoard, 1)
297
+ candidates += [(-score, a)]
298
+ candidates.sort()
299
+ return candidates[0][1]
300
+
301
+
302
+ class HumanOthelloPlayer():
303
+ def __init__(self, game):
304
+ self.game = game
305
+
306
+ def play(self, board):
307
+ # display(board)
308
+ valid = self.game.getValidMoves(board, 1)
309
+ for i in range(len(valid)):
310
+ if valid[i]:
311
+ print("[", int(i/self.game.n), int(i%self.game.n), end="] ")
312
+ while True:
313
+ input_move = input()
314
+ input_a = input_move.split(" ")
315
+ if len(input_a) == 2:
316
+ try:
317
+ x,y = [int(i) for i in input_a]
318
+ if ((0 <= x) and (x < self.game.n) and (0 <= y) and (y < self.game.n)) or \
319
+ ((x == self.game.n) and (y == 0)):
320
+ a = self.game.n * x + y
321
+ if valid[a]:
322
+ break
323
+ except ValueError:
324
+ pass  # non-integer input; fall through and re-prompt
325
+ print('Invalid move')
326
+ return a
327
+
328
+
329
+ class Arena():
330
+ """
331
+ An Arena class where any 2 agents can be pit against each other.
332
+ """
333
+
334
+ def __init__(self, player1, player2, game, display=None):
335
+ """
336
+ Input:
337
+ player 1,2: two functions that takes board as input, return action
338
+ game: Game object
339
+ display: a function that takes board as input and prints it. Is necessary for verbose
340
+ mode.
341
+ """
342
+ self.player1 = player1
343
+ self.player2 = player2
344
+ self.game = game
345
+ self.display = display
346
+
347
+ def playGame(self, verbose=False):
348
+ """
349
+ Executes one episode of a game.
350
+
351
+ Returns:
352
+ either
353
+ winner: player who won the game (1 if player1, -1 if player2, 0 if draw)
354
+ """
355
+ players = [self.player2, None, self.player1]
356
+ curPlayer = 1 # player1 goes first
357
+ board = self.game.getInitBoard()
358
+ it = 0
359
+ while self.game.getGameEnded(board, curPlayer) is None:
360
+ it += 1
361
+ if verbose:
362
+ assert self.display
363
+ print("Turn ", str(it), "Player ", str(curPlayer))
364
+ self.display(board)
365
+ action = players[curPlayer + 1](self.game.getCanonicalForm(board, curPlayer))
366
+
367
+ valids = self.game.getValidMoves(self.game.getCanonicalForm(board, curPlayer), 1)
368
+
369
+ if valids[action] == 0:
370
+ log.error(f'Action {action} is not valid!')
371
+ log.debug(f'valids = {valids}')
372
+ assert valids[action] > 0
373
+ board, curPlayer = self.game.getNextState(board, curPlayer, action)
374
+ result = curPlayer * self.game.getGameEnded(board, curPlayer)
375
+ if verbose:
376
+ assert self.display
377
+ print("Game over: Turn ", str(it), "Result ", str(result))
378
+ self.display(board)
379
+ return result
380
+
381
+ def playGames(self, num, verbose=False):
382
+ """
383
+ Plays num games in which player1 starts num/2 games and player2 starts
384
+ num/2 games.
385
+
386
+ Returns:
387
+ oneWon: games won by player1
388
+ twoWon: games won by player2
389
+ draws: games won by nobody
390
+ """
391
+
392
+ num = int(num / 2)
393
+ oneWon = 0
394
+ twoWon = 0
395
+ draws = 0
396
+ for _ in tqdm(range(num), desc="Arena.playGames (player1 go first)"):
397
+ gameResult = self.playGame(verbose=verbose)
398
+ if gameResult == 1:
399
+ oneWon += 1
400
+ elif gameResult == -1:
401
+ twoWon += 1
402
+ else:
403
+ draws += 1
404
+
405
+ self.player1, self.player2 = self.player2, self.player1
406
+
407
+ for _ in tqdm(range(num), desc="Arena.playGames (player2 go first)"):
408
+ gameResult = self.playGame(verbose=verbose)
409
+ if gameResult == -1:
410
+ oneWon += 1
411
+ elif gameResult == 1:
412
+ twoWon += 1
413
+ else:
414
+ draws += 1
415
+
416
+ return oneWon, twoWon, draws
417
+
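The game logic above can also be exercised directly, without the Flask layer. A small sketch (assuming this file is importable as game and NumPy is installed):

import numpy as np
from game import OthelloGame

g = OthelloGame(6)            # 6x6 board, the default --board_size in alphazero.py
board = g.getInitBoard()
player = 1                    # 1 = white (O), -1 = black (X)

# Pick the first legal action and apply it; actions are encoded as n*x + y,
# with index n*n reserved for "pass".
valids = g.getValidMoves(board, player)
action = int(np.flatnonzero(valids)[0])
print("playing (x, y) =", divmod(action, g.n))

board, player = g.getNextState(board, player, action)
OthelloGame.display(board)
print("ended?", g.getGameEnded(board, player))   # None while the game is still running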
requirements.txt ADDED
@@ -0,0 +1,12 @@
1
+ # othello-backend/requirements.txt
2
+ # Core web framework
3
+ flask
4
+ Flask-CORS
5
+
6
+ # AI / model dependencies
7
+ numpy
8
+ torch
9
+
10
+ # Other helper libraries
11
+ tqdm
12
+ # logging is part of the Python standard library and does not need to be pip-installed