Deploying Othello Flask backend with Docker
- .DS_Store +0 -0
- Dockerfile +23 -0
- LICENSE +21 -0
- alphazero.py +487 -0
- app.py +343 -0
- game.py +417 -0
- requirements.txt +12 -0
.DS_Store
ADDED
Binary file (6.15 kB)
Dockerfile
ADDED
@@ -0,0 +1,23 @@
# othello-backend/Dockerfile

# 1. Base image: the official Python image ships with all required system libraries
FROM python:3.9-slim

# 2. Working directory: everything runs under /app
WORKDIR /app

# 3. Copy the dependency file and install the Python libraries
# Copy and install dependencies first to take advantage of Docker layer caching
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# 4. Copy all application code and the model
# Note: game.py, alphazero.py, app.py and checkpoint/ must all live in the othello-backend directory
COPY . .

# 5. Expose the port: Hugging Face Spaces serves HTTP traffic on port 7860 by default
EXPOSE 7860

# 6. Container start command: run the Flask app
# CMD is executed when the container starts
CMD ["python", "app.py"]
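A quick way to sanity-check the image locally is to build and run it with the port published (for example: docker build -t othello-backend . followed by docker run -p 7860:7860 othello-backend) and then call the /api/game/new route that app.py defines below. The following sketch is illustrative only and is not part of this commit; it uses just the Python standard library and assumes the container is reachable at localhost:7860 via the port mapping above.

# smoke_test.py -- illustrative check against a locally running container (assumed at localhost:7860)
import json
import urllib.request

URL = "http://localhost:7860/api/game/new"  # port published via `docker run -p 7860:7860`

payload = json.dumps({"size": 8, "first_player": 1}).encode("utf-8")
req = urllib.request.Request(URL, data=payload,
                             headers={"Content-Type": "application/json"},
                             method="POST")

# The backend answers with the board, the legal moves and a status string.
with urllib.request.urlopen(req) as resp:
    state = json.load(resp)

print(state["status"])
print(len(state["legal_moves"]), "legal moves for the first player")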
LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024 wangxuguang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
alphazero.py
ADDED
@@ -0,0 +1,487 @@
import logging
import math
import time
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
from collections import deque
from pickle import Pickler, Unpickler
from random import shuffle

from game import *
logging.basicConfig(level=logging.INFO)
log = logging.getLogger(__name__)


class MCTS():
    """
    This class handles the MCTS tree.
    """

    def __init__(self, game, nnet, args):
        self.game = game
        self.nnet = nnet
        self.args = args
        self.Qsa = {}  # stores Q values for s,a (as defined in the paper)
        self.Nsa = {}  # stores #times edge s,a was visited
        self.Ns = {}   # stores #times board s was visited
        self.Ps = {}   # stores initial policy (returned by neural net)

        self.Es = {}   # stores game.getGameEnded for board s
        self.Vs = {}   # stores game.getValidMoves for board s

    def getActionProb(self, canonicalBoard, temp=1):
        """
        This function performs numMCTSSims simulations of MCTS starting from
        canonicalBoard.

        Returns:
            probs: a policy vector where the probability of the ith action is
                   proportional to Nsa[(s,a)]**(1./temp)
        """
        for i in range(self.args.numMCTSSims):
            self.search(canonicalBoard)

        s = self.game.stringRepresentation(canonicalBoard)
        counts = [self.Nsa[(s, a)] if (s, a) in self.Nsa else 0 for a in range(self.game.getActionSize())]

        if temp == 0:
            bestAs = np.array(np.argwhere(counts == np.max(counts))).flatten()
            bestA = np.random.choice(bestAs)
            probs = [0] * len(counts)
            probs[bestA] = 1
            return probs

        counts = [x ** (1. / temp) for x in counts]
        counts_sum = float(sum(counts))
        probs = [x / counts_sum for x in counts]
        return probs

    def search(self, canonicalBoard):
        """
        This function performs one iteration of MCTS. It is recursively called
        till a leaf node is found. The action chosen at each node is one that
        has the maximum upper confidence bound as in the paper.

        Once a leaf node is found, the neural network is called to return an
        initial policy P and a value v for the state. This value is propagated
        up the search path. In case the leaf node is a terminal state, the
        outcome is propagated up the search path. The values of Ns, Nsa, Qsa are
        updated.

        NOTE: Since v is in [-1,1] and if v is the value of a
        state for the current player, then its value is -v for the other player.

        Returns:
            v: the value of the current canonicalBoard
        """

        s = self.game.stringRepresentation(canonicalBoard)

        if s not in self.Es:
            self.Es[s] = self.game.getGameEnded(canonicalBoard, 1)
        if self.Es[s] is not None:
            # terminal node
            return self.Es[s]

        if s not in self.Ps:
            # leaf node
            self.Ps[s], v = self.nnet.predict(canonicalBoard)
            valids = self.game.getValidMoves(canonicalBoard, 1)
            self.Ps[s] = self.Ps[s] * valids  # masking invalid moves
            sum_Ps_s = np.sum(self.Ps[s])
            if sum_Ps_s > 0:
                self.Ps[s] /= sum_Ps_s  # renormalize
            else:
                # if all valid moves were masked make all valid moves equally probable
                log.error("All valid moves were masked, doing a workaround.")
                self.Ps[s] = self.Ps[s] + valids
                self.Ps[s] /= np.sum(self.Ps[s])

            self.Vs[s] = valids
            self.Ns[s] = 0
            return v

        valids = self.Vs[s]
        cur_best = -float('inf')
        best_act = -1

        # pick the action with the highest upper confidence bound
        for a in range(self.game.getActionSize()):
            if valids[a]:
                u = self.Qsa.get((s, a), 0) + self.args.cpuct * self.Ps[s][a] * math.sqrt(self.Ns[s]) / (
                        1 + self.Nsa.get((s, a), 0))

                if u > cur_best:
                    cur_best = u
                    best_act = a

        a = best_act
        next_s, next_player = self.game.getNextState(canonicalBoard, 1, a)
        next_s = self.game.getCanonicalForm(next_s, next_player)

        v = -self.search(next_s)

        if (s, a) in self.Qsa:
            self.Qsa[(s, a)] = (self.Nsa[(s, a)] * self.Qsa[(s, a)] + v) / (self.Nsa[(s, a)] + 1)
            self.Nsa[(s, a)] += 1

        else:
            self.Qsa[(s, a)] = v
            self.Nsa[(s, a)] = 1

        self.Ns[s] += 1
        return v

class OthelloNNet(nn.Module):
    def __init__(self, game, args):
        # game params
        self.board_x, self.board_y = game.getBoardSize()
        self.action_size = game.getActionSize()
        self.args = args

        super(OthelloNNet, self).__init__()
        self.conv1 = nn.Conv2d(1, args.num_channels, 3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(args.num_channels, args.num_channels, 3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(args.num_channels, args.num_channels, 3, stride=1)
        self.conv4 = nn.Conv2d(args.num_channels, args.num_channels, 3, stride=1)

        self.bn1 = nn.BatchNorm2d(args.num_channels)
        self.bn2 = nn.BatchNorm2d(args.num_channels)
        self.bn3 = nn.BatchNorm2d(args.num_channels)
        self.bn4 = nn.BatchNorm2d(args.num_channels)

        self.fc1 = nn.Linear(args.num_channels*(self.board_x-4)*(self.board_y-4), 1024)
        self.fc_bn1 = nn.BatchNorm1d(1024)

        self.fc2 = nn.Linear(1024, 512)
        self.fc_bn2 = nn.BatchNorm1d(512)

        self.fc3 = nn.Linear(512, self.action_size)
        self.fc4 = nn.Linear(512, 1)

    def forward(self, s):
        # you can add residual to the network
        # s: batch_size x board_x x board_y
        s = s.view(-1, 1, self.board_x, self.board_y)  # batch_size x 1 x board_x x board_y
        s = F.relu(self.bn1(self.conv1(s)))  # batch_size x num_channels x board_x x board_y
        s = F.relu(self.bn2(self.conv2(s)))  # batch_size x num_channels x board_x x board_y
        s = F.relu(self.bn3(self.conv3(s)))  # batch_size x num_channels x (board_x-2) x (board_y-2)
        s = F.relu(self.bn4(self.conv4(s)))  # batch_size x num_channels x (board_x-4) x (board_y-4)
        s = s.view(-1, self.args.num_channels*(self.board_x-4)*(self.board_y-4))

        s = F.dropout(F.relu(self.fc_bn1(self.fc1(s))), p=self.args.dropout, training=self.training)  # batch_size x 1024
        s = F.dropout(F.relu(self.fc_bn2(self.fc2(s))), p=self.args.dropout, training=self.training)  # batch_size x 512

        pi = self.fc3(s)  # batch_size x action_size
        v = self.fc4(s)   # batch_size x 1

        return F.log_softmax(pi, dim=1), torch.tanh(v)


class AverageMeter(object):
    """From https://github.com/pytorch/examples/blob/master/imagenet/main.py"""

    def __init__(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def __repr__(self):
        return f'{self.avg:.2e}'

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


class NNetWrapper():
    def __init__(self, game, args):
        self.nnet = OthelloNNet(game, args)
        self.board_x, self.board_y = game.getBoardSize()
        self.action_size = game.getActionSize()
        self.args = args

        if args.cuda:
            self.nnet.cuda()

    def train(self, examples):
        """
        examples: list of examples, each example is of form (board, pi, v)
        """
        optimizer = optim.Adam(self.nnet.parameters(), lr=self.args.lr)

        for epoch in range(self.args.epochs):
            print('EPOCH ::: ' + str(epoch + 1))
            self.nnet.train()
            pi_losses = AverageMeter()
            v_losses = AverageMeter()

            batch_count = int(len(examples) / self.args.batch_size)

            t = tqdm(range(batch_count), desc='Training Net')
            for _ in t:
                sample_ids = np.random.randint(len(examples), size=self.args.batch_size)
                boards, pis, vs = list(zip(*[examples[i] for i in sample_ids]))
                boards = torch.FloatTensor(np.array(boards).astype(np.float32))  # np.float32 or np.float64 have a difference?
                target_pis = torch.FloatTensor(np.array(pis))
                target_vs = torch.FloatTensor(np.array(vs).astype(np.float32))

                if self.args.cuda:
                    # boards, target_pis, target_vs = boards.contiguous().cuda(), target_pis.contiguous().cuda(), target_vs.contiguous().cuda()
                    boards, target_pis, target_vs = boards.cuda(), target_pis.cuda(), target_vs.cuda()

                # compute output
                out_pi, out_v = self.nnet(boards)
                l_pi = self.loss_pi(target_pis, out_pi)
                l_v = self.loss_v(target_vs, out_v)
                total_loss = l_pi + l_v

                # record loss
                pi_losses.update(l_pi.item(), boards.size(0))
                v_losses.update(l_v.item(), boards.size(0))
                t.set_postfix(Loss_pi=pi_losses, Loss_v=v_losses)

                # compute gradient and do SGD step
                optimizer.zero_grad()
                total_loss.backward()
                optimizer.step()

    def predict(self, board):
        """
        board: np array with board
        """
        # timing
        # start = time.time()

        # preparing input
        board = torch.FloatTensor(board.astype(np.float32))
        if self.args.cuda: board = board.cuda()
        board = board.view(1, self.board_x, self.board_y)
        self.nnet.eval()
        with torch.no_grad():
            pi, v = self.nnet(board)

        # print('PREDICTION TIME TAKEN : {0:03f}'.format(time.time()-start))
        return torch.exp(pi).data.cpu().numpy()[0], v.data.cpu().numpy()[0]

    def loss_pi(self, targets, outputs):
        return -torch.sum(targets * outputs) / targets.size()[0]

    def loss_v(self, targets, outputs):
        return torch.sum((targets - outputs.view(-1)) ** 2) / targets.size()[0]

    def save_checkpoint(self, folder='checkpoint', filename='checkpoint.pth.tar'):
        filepath = os.path.join(folder, filename)
        if not os.path.exists(folder):
            print("Checkpoint Directory does not exist! Making directory {}".format(folder))
            os.mkdir(folder)
        else:
            print("Checkpoint Directory exists! ")
        torch.save({
            'state_dict': self.nnet.state_dict(),
        }, filepath)

    def load_checkpoint(self, folder='checkpoint', filename='checkpoint.pth.tar'):
        filepath = os.path.join(folder, filename)
        if not os.path.exists(filepath):
            raise ValueError("No model in path {}".format(filepath))
        map_location = None if self.args.cuda else 'cpu'
        checkpoint = torch.load(filepath, map_location=map_location)
        self.nnet.load_state_dict(checkpoint['state_dict'])


class SelfPlay():
    """
    This class executes the self-play + learning.
    """

    def __init__(self, game, nnet, args):
        self.game = game
        self.nnet = nnet
        self.pnet = self.nnet.__class__(self.game, args)  # the competitor network
        self.args = args
        self.mcts = MCTS(self.game, self.nnet, self.args)
        self.trainExamplesHistory = []  # history of examples from args.numItersForTrainExamplesHistory latest iterations

    def executeEpisode(self):
        """
        This function executes one episode of self-play, starting with player 1.
        As the game is played, each turn is added as a training example to
        trainExamples. The game is played till the game ends. After the game
        ends, the outcome of the game is used to assign values to each example
        in trainExamples.

        It uses a temp=1 if episodeStep < tempThreshold, and thereafter
        uses temp=0.

        Returns:
            trainExamples: a list of examples of the form (canonicalBoard, pi, v)
        """
        trainExamples = []
        board = self.game.getInitBoard()
        self.curPlayer = 1
        episodeStep = 0

        while True:
            episodeStep += 1
            canonicalBoard = self.game.getCanonicalForm(board, self.curPlayer)
            temp = int(episodeStep < self.args.tempThreshold)

            pi = self.mcts.getActionProb(canonicalBoard, temp=temp)
            sym = self.game.getSymmetries(canonicalBoard, pi)
            for b, p in sym:
                trainExamples.append([b, self.curPlayer, p, None])

            action = np.random.choice(len(pi), p=pi)
            board, self.curPlayer = self.game.getNextState(board, self.curPlayer, action)

            r = self.game.getGameEnded(board, self.curPlayer)

            if r is not None:
                # r * (1 if self.curPlayer == x[1] else -1) means 1 for winner, -1 for loser, 0 for draw.
                return [(x[0], x[2], r * (1 if self.curPlayer == x[1] else -1)) for x in trainExamples]

    def learn(self):
        """
        Performs numIters iterations with numEps episodes of self-play in each
        iteration. After every iteration, it retrains neural network with
        examples in trainExamples (which has a maximum length of maxlenofQueue).
        It then pits the new neural network against the old one and accepts it
        only if it wins >= updateThreshold fraction of games.
        """

        for i in range(1, self.args.numIters + 1):
            # bookkeeping
            log.info(f'Starting Iter #{i} ...')
            # examples of the iteration
            iterationTrainExamples = deque([], maxlen=self.args.maxlenOfQueue)

            for _ in tqdm(range(self.args.numEps), desc="Self Play"):
                self.mcts = MCTS(self.game, self.nnet, self.args)  # reset search tree
                iterationTrainExamples += self.executeEpisode()

            # save the iteration examples to the history
            self.trainExamplesHistory.append(iterationTrainExamples)

            if len(self.trainExamplesHistory) > self.args.numItersForTrainExamplesHistory:
                log.warning(
                    f"Removing the oldest entry in trainExamples. len(trainExamplesHistory) = {len(self.trainExamplesHistory)}")
                self.trainExamplesHistory.pop(0)

            # shuffle examples before training
            trainExamples = []
            for e in self.trainExamplesHistory:
                trainExamples.extend(e)
            shuffle(trainExamples)

            # training new network, keeping a copy of the old one
            self.nnet.save_checkpoint(folder=self.args.checkpoint, filename='temp.pth.tar')
            self.pnet.load_checkpoint(folder=self.args.checkpoint, filename='temp.pth.tar')
            pmcts = MCTS(self.game, self.pnet, self.args)

            self.nnet.train(trainExamples)
            nmcts = MCTS(self.game, self.nnet, self.args)

            log.info('PITTING AGAINST PREVIOUS VERSION')
            arena = Arena(lambda x: np.argmax(pmcts.getActionProb(x, temp=0)),
                          lambda x: np.argmax(nmcts.getActionProb(x, temp=0)), self.game)
            pwins, nwins, draws = arena.playGames(self.args.arenaCompare)

            log.info('NEW/PREV WINS : %d / %d ; DRAWS : %d' % (nwins, pwins, draws))
            if pwins + nwins == 0 or float(nwins) / (pwins + nwins) < self.args.updateThreshold:
                log.info('REJECTING NEW MODEL')
                self.nnet.load_checkpoint(folder=self.args.checkpoint, filename='temp.pth.tar')
            else:
                log.info('ACCEPTING NEW MODEL')
                self.nnet.save_checkpoint(folder=self.args.checkpoint, filename='best.pth.tar')


class dotdict(dict):
    def __getattr__(self, name):
        return self[name]


args = dotdict({
    'lr': 0.001,
    'dropout': 0.1,
    'epochs': 10,
    'batch_size': 64,
    'cuda': torch.cuda.is_available(),
    'num_channels': 512,

    'numIters': 200,
    'numEps': 100,              # Number of complete self-play games to simulate during a new iteration.
    'tempThreshold': 15,
    'updateThreshold': 0.6,     # During arena playoff, new neural net will be accepted if threshold ratio or more of games are won.
    'maxlenOfQueue': 200000,    # Number of game examples to train the neural networks.
    'numItersForTrainExamplesHistory': 20,
    'numMCTSSims': 25,          # Number of games moves for MCTS to simulate.
    'arenaCompare': 40,         # Number of games to play during arena play to determine if new net will be accepted.
    'cpuct': 1,

    'checkpoint': './temp/',
    'load_model': False,
    'load_folder_file': ('./temp/', 'best.pth.tar'),
})

def main():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--train', action="store_true")
    parser.add_argument('--board_size', type=int, default=6)
    # play arguments
    parser.add_argument('--play', action="store_true")
    parser.add_argument('--verbose', action="store_true")
    parser.add_argument('--round', type=int, default=2)
    parser.add_argument('--player1', type=str, default='human', choices=['human', 'random', 'greedy', 'alphazero'])
    parser.add_argument('--player2', type=str, default='alphazero', choices=['human', 'random', 'greedy', 'alphazero'])
    parser.add_argument('--ckpt_file', type=str, default='best.pth.tar')
    args_input = vars(parser.parse_args())
    for k, v in args_input.items():
        args[k] = v

    g = OthelloGame(args.board_size)

    if args.train:
        nnet = NNetWrapper(g, args)
        if args.load_model:
            log.info('Loading checkpoint "%s/%s"...', args.load_folder_file[0], args.load_folder_file[1])
            nnet.load_checkpoint(args.load_folder_file[0], args.load_folder_file[1])

        log.info('Loading the SelfCoach...')
        s = SelfPlay(g, nnet, args)

        log.info('Starting the learning process 🎉')
        s.learn()

    if args.play:
        def getPlayFunc(name):
            if name == 'human':
                return HumanOthelloPlayer(g).play
            elif name == 'random':
                return RandomPlayer(g).play
            elif name == 'greedy':
                return GreedyOthelloPlayer(g).play
            elif name == 'alphazero':
                nnet = NNetWrapper(g, args)
                nnet.load_checkpoint(args.checkpoint, args.ckpt_file)
                mcts = MCTS(g, nnet, dotdict({'numMCTSSims': 200, 'cpuct': 1.0}))
                return lambda x: np.argmax(mcts.getActionProb(x, temp=0))
            else:
                raise ValueError('unsupported player name {}'.format(name))
        player1 = getPlayFunc(args.player1)
        player2 = getPlayFunc(args.player2)
        arena = Arena(player1, player2, g, display=OthelloGame.display)
        results = arena.playGames(args.round, verbose=args.verbose)
        print("Final results: Player1 wins {}, Player2 wins {}, Draws {}".format(*results))

if __name__ == '__main__':
    main()
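alphazero.py bundles the whole pipeline: MCTS search, the OthelloNNet policy/value network, the NNetWrapper training/inference helper, and SelfPlay for self-play plus arena evaluation. The sketch below (illustrative only, not part of this commit) shows how a single move is selected with these pieces, mirroring what main() and app.py do; it assumes a trained checkpoint already exists at ./temp/best.pth.tar, e.g. one produced by python alphazero.py --train.

# select_one_move.py -- illustrative sketch; assumes ./temp/best.pth.tar exists
import numpy as np
from game import OthelloGame
from alphazero import NNetWrapper, MCTS, dotdict, args

g = OthelloGame(8)
net = NNetWrapper(g, args)
net.load_checkpoint(args.checkpoint, 'best.pth.tar')  # assumed checkpoint location

# More simulations than during training, matching the play-mode settings in main().
searcher = MCTS(g, net, dotdict({'numMCTSSims': 200, 'cpuct': 1.0}))

board = g.getInitBoard()
player = 1
canonical = g.getCanonicalForm(board, player)                  # MCTS always sees the board from player 1's view
action = np.argmax(searcher.getActionProb(canonical, temp=0))  # temp=0 -> pick the most-visited move
board, player = g.getNextState(board, player, action)
OthelloGame.display(board)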
app.py
ADDED
@@ -0,0 +1,343 @@
import numpy as np
import torch
import json
import os
import logging
from flask import Flask, request, jsonify
from flask_cors import CORS
import time

# Import your game and AI modules
from game import OthelloGame
from alphazero import NNetWrapper, MCTS, dotdict

# Configure logging
logging.basicConfig(level=logging.INFO)
log = logging.getLogger(__name__)

app = Flask(__name__)
# Enable CORS so the frontend (usually served from a different port) can reach the API
CORS(app)

# --- Global state and AI initialization ---

# Default parameters (kept consistent with args in alphazero.py)
args = dotdict({
    'lr': 0.001,
    'dropout': 0.1,
    'epochs': 10,
    'batch_size': 64,
    'cuda': torch.cuda.is_available(),
    'num_channels': 512,
    'numIters': 200,
    'numEps': 100,
    'tempThreshold': 15,
    'updateThreshold': 0.6,
    'maxlenOfQueue': 200000,
    'numItersForTrainExamplesHistory': 20,
    'numMCTSSims': 25,  # number of MCTS simulations used during training
    'arenaCompare': 40,
    'cpuct': 1,
    'checkpoint': './temp/',
    'load_model': True,
    'load_folder_file': ('./temp/', 'best.pth.tar'),
    'board_size': 8  # 8x8 by default
})

# Game and AI instances
game = None
nnet = None
mcts = None

# Game state
current_board = None
current_player = 1  # 1: Human (White), -1: AI (Black)
last_move_coords = None
board_size = 8

def init_game_and_ai(n):
    """Initialize the game and the AI modules for the given board size."""
    global game, nnet, mcts, board_size
    board_size = n
    log.info(f"Initializing game and AI for {n}x{n} board.")
    game = OthelloGame(n)

    # Note: an AlphaZero model is usually trained for one fixed board size.
    # If your model only supports 8x8, other sizes need special handling or retraining.
    # Here we assume the model supports the current size n.

    # Reconfigure the MCTS parameters for play mode
    play_args = dotdict({
        'numMCTSSims': 200,  # use more simulations when playing against a human
        'cpuct': 1.0,
        'cuda': args.cuda  # inherit the CUDA setting
    })

    nnet = NNetWrapper(game, args)
    # Assumes your model file has been saved to './checkpoint/best.pth.tar'
    try:
        load_folder = args.load_folder_file[0]
        load_file = args.load_folder_file[1]
        nnet.load_checkpoint(folder=load_folder, filename=load_file)
        log.info(f"Successfully loaded model from {load_folder}{load_file}")
    except ValueError as e:
        log.error(f"Failed to load model: {e}. AI will likely perform poorly.")

    mcts = MCTS(game, nnet, play_args)


def get_api_moves(board, player):
    """Convert the getValidMoves vector into a list of {x, y} dicts."""
    if game is None: return []

    valids = game.getValidMoves(board, player)
    moves_list = []
    # Skip the last action (the pass action)
    for i in range(len(valids) - 1):
        if valids[i] == 1:
            x = i // game.n
            y = i % game.n
            moves_list.append({'x': int(x), 'y': int(y)})
    return moves_list

def check_game_end(board, player):
    """Check whether the game is over and build a status message from the absolute piece counts."""

    # Get the relative end-of-game result (1: player wins, -1: player loses, 0: draw).
    # Note: this result is relative to the player that was passed in.
    result = game.getGameEnded(board, player)

    status = 'Ongoing'
    score_diff = 0

    if result is not None:
        # Count the absolute scores of white (1) and black (-1).
        # score_diff here is (white count - black count).
        white_count = np.sum(board == 1)
        black_count = np.sum(board == -1)
        score_diff = int(white_count - black_count)

        if result == 0:
            status = f"Game Over: Draw. Score: {white_count} vs {black_count}"
        elif score_diff > 0:
            # White (Human) has more pieces, the human wins
            status = f"Game Over: Human (O) Wins! Score: {white_count} vs {black_count}"
        elif score_diff < 0:
            # Black (AI) has more pieces, the AI wins
            status = f"Game Over: AI (X) Wins! Score: {white_count} vs {black_count}"
        else:
            # In theory the score cannot be 0 when result != 0, but just in case
            status = f"Game Over: Draw. Score: {white_count} vs {black_count}"

    return status

@app.route('/api/game/new', methods=['POST'])
def new_game():
    global current_board, current_player, last_move_coords, board_size
    data = request.json
    size = data.get('size', 8)

    # [Added] Accept a first_player parameter, defaulting to 1 (Human)
    first_player = data.get('first_player', 1)

    if game is None or size != board_size:
        init_game_and_ai(size)

    current_board = game.getInitBoard()
    current_player = first_player  # [Changed] Use the received first_player to set the current player
    last_move_coords = None

    status = check_game_end(current_board, current_player)

    # [Added] If the AI moves first, trigger the AI move immediately
    if current_player == -1 and status == 'Ongoing':
        return ai_move_logic()  # call the AI logic directly and return its result
    # Flip current_board to match the frontend display convention
    current_board = np.flip(current_board, 0)

    return jsonify({
        'board': current_board.tolist(),
        'legal_moves': get_api_moves(current_board, current_player),
        'current_player': current_player,
        'last_move': last_move_coords,
        'status': status,
    })


# @app.route('/api/game/human_move', methods=['POST'])
# def human_move():
#     """Handle the human player's move."""
#     global current_board, current_player, last_move_coords

#     if current_player != 1 or check_game_end(current_board, current_player) != 'Ongoing':
#         return jsonify({'error': 'Not your turn or game is over'}), 400

#     data = request.json
#     x = data.get('x')
#     y = data.get('y')

#     if x is None or y is None:
#         # Check whether this is a pass action
#         if data.get('action') == 'pass':
#             action = game.n * game.n  # Pass action is the last index
#         else:
#             return jsonify({'error': 'Invalid move coordinates'}), 400
#     else:
#         action = game.n * x + y

#     valids = game.getValidMoves(current_board, 1)
#     if valids[action] == 0:
#         return jsonify({'error': 'Illegal move'}), 400

#     current_board, current_player = game.getNextState(current_board, 1, action)

#     if action != game.n * game.n:
#         last_move_coords = {'x': x, 'y': y}

#     status = check_game_end(current_board, current_player)

#     # If the game is not over and it is the AI's (-1) turn
#     if status == 'Ongoing' and current_player == -1:
#         # Trigger the AI move here
#         return ai_move_logic()

#     return jsonify({
#         'board': current_board.tolist(),
#         'legal_moves': get_api_moves(current_board, current_player),
#         'current_player': current_player,
#         'last_move': last_move_coords,
#         'status': status,
#     })

def ai_move_logic():
    """AI move logic, wrapped so it can be called from human_move."""
    global current_board, current_player, last_move_coords

    canonical_board = game.getCanonicalForm(current_board, -1)

    # Get the AI's best action (temp=0)
    ai_action = np.argmax(mcts.getActionProb(canonical_board, temp=0))

    # Update the game state
    current_board, next_player = game.getNextState(current_board, -1, ai_action)
    current_player = next_player

    # Record the AI's move coordinates
    if ai_action != game.n * game.n:  # if it is not a pass action
        ai_x = ai_action // game.n
        ai_y = ai_action % game.n
        last_move_coords = {'x': int(ai_x), 'y': int(ai_y)}

    status = check_game_end(current_board, current_player)

    # Flip current_board to match the frontend display convention
    current_board = np.flip(current_board, 0)

    return jsonify({
        'board': current_board.tolist(),
        'legal_moves': get_api_moves(current_board, current_player),
        'current_player': current_player,
        'last_move': last_move_coords,
        'status': status,
    })

# app.py (under the @app.route('/api/game/human_move', methods=['POST']) route)
# Replaces the original handleHumanMove/human_move function

@app.route('/api/game/human_move', methods=['POST'])
def human_move():
    """Handle the human player's move and return the intermediate state before the AI replies."""
    global current_board, current_player, last_move_coords

    if current_player != 1 or check_game_end(current_board, current_player) != 'Ongoing':
        return jsonify({'error': 'Not your turn or game is over'}), 400

    data = request.json
    x = data.get('x')
    y = data.get('y')

    if x is None or y is None:
        # Check whether this is a pass action
        if data.get('action') == 'pass':
            action = game.n * game.n  # Pass action is the last index
        else:
            return jsonify({'error': 'Invalid move coordinates'}), 400
    else:
        action = game.n * x + y

    valids = game.getValidMoves(current_board, 1)
    if valids[action] == 0:
        return jsonify({'error': 'Illegal move'}), 400

    # Apply the human move
    current_board, current_player = game.getNextState(current_board, 1, action)

    if action != game.n * game.n:
        last_move_coords = {'x': x, 'y': y}

    status = check_game_end(current_board, current_player)

    # Flip current_board to match the frontend display convention
    # current_board = np.flip(current_board, 0)

    # Note: the AI move logic is no longer triggered here; return directly
    return jsonify({
        'board': current_board.tolist(),
        'legal_moves': get_api_moves(current_board, current_player),
        'current_player': current_player,
        'last_move': last_move_coords,
        'status': status,
    })


# B. New ai_move route

@app.route('/api/game/ai_move', methods=['POST'])
def ai_move():
    start_time = time.time()
    """Trigger the AI move and return the resulting state."""
    global current_board, current_player, last_move_coords

    if current_player != -1:
        return jsonify({'error': 'Not AI turn'}), 400


    canonical_board = game.getCanonicalForm(current_board, -1)
    ai_action = np.argmax(mcts.getActionProb(canonical_board, temp=0))

    # Apply the AI move
    current_board, next_player = game.getNextState(current_board, -1, ai_action)
    current_player = next_player

    # Record the AI's move coordinates
    if ai_action != game.n * game.n:
        ai_x = ai_action // game.n
        ai_y = ai_action % game.n
        last_move_coords = {'x': int(ai_x), 'y': int(ai_y)}
    else:
        last_move_coords = None  # AI passed

    status = check_game_end(current_board, current_player)

    # Enforce a minimum AI "thinking" time of 0.5 seconds
    end_time = time.time()
    used_time = end_time - start_time
    if used_time < 0.5:
        time.sleep(0.5 - used_time)  # make sure we wait at least 0.5 seconds

    return jsonify({
        'board': current_board.tolist(),
        'legal_moves': get_api_moves(current_board, current_player),
        'current_player': current_player,
        'last_move': last_move_coords,
        'status': status,
    })

if __name__ == '__main__':
    # Initialize a default 8x8 game instance
    init_game_and_ai(8)
    log.info("Starting Flask server on port 7860...")

    port = int(os.environ.get('PORT', 7860))
    # ... (logging) ...
    app.run(host='0.0.0.0', port=port)
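The three routes above form a simple turn loop: POST /api/game/new starts a game, then the client alternates /api/game/human_move and /api/game/ai_move, and every response carries the same JSON fields (board, legal_moves, current_player, last_move, status). The sketch below is an illustrative client (not part of this commit) driving one exchange with the standard library; it assumes the server is reachable at http://localhost:7860 and that the human plays first.

# client_turn.py -- illustrative sketch; assumes the backend runs at localhost:7860
import json
import urllib.request

BASE = "http://localhost:7860"

def post(path, payload):
    req = urllib.request.Request(BASE + path,
                                 data=json.dumps(payload).encode("utf-8"),
                                 headers={"Content-Type": "application/json"},
                                 method="POST")
    with urllib.request.urlopen(req) as resp:
        return json.load(resp)

# Start a new 8x8 game with the human (player 1) moving first.
state = post("/api/game/new", {"size": 8, "first_player": 1})

# Play the first legal move the server reports, then ask the AI to reply.
move = state["legal_moves"][0]
state = post("/api/game/human_move", {"x": move["x"], "y": move["y"]})
if state["current_player"] == -1 and state["status"] == "Ongoing":
    state = post("/api/game/ai_move", {})

print(state["status"], state["last_move"])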
game.py
ADDED
@@ -0,0 +1,417 @@
import numpy as np
import logging
from tqdm import tqdm
log = logging.getLogger(__name__)

class Board():
    '''
    Author: Eric P. Nichols
    Date: Feb 8, 2008.
    Board class.
    Board data:
      1=white, -1=black, 0=empty
    '''

    # list of all 8 directions on the board, as (x,y) offsets
    __directions = [(1,1),(1,0),(1,-1),(0,-1),(-1,-1),(-1,0),(-1,1),(0,1)]

    def __init__(self, n):
        "Set up initial board configuration."

        self.n = n
        # Create the empty board array.
        self.pieces = [None]*self.n
        for i in range(self.n):
            self.pieces[i] = [0]*self.n

        # Set up the initial 4 pieces.
        self.pieces[int(self.n/2)-1][int(self.n/2)] = 1
        self.pieces[int(self.n/2)][int(self.n/2)-1] = 1
        self.pieces[int(self.n/2)-1][int(self.n/2)-1] = -1
        self.pieces[int(self.n/2)][int(self.n/2)] = -1

    # add [][] indexer syntax to the Board
    def __getitem__(self, index):
        return self.pieces[index]

    def countDiff(self, color):
        """Counts the # pieces of the given color
        (1 for white, -1 for black, 0 for empty spaces)"""
        count = 0
        for y in range(self.n):
            for x in range(self.n):
                if self[x][y]==color:
                    count += 1
                if self[x][y]==-color:
                    count -= 1
        return count

    def get_legal_moves(self, color):
        """Returns all the legal moves for the given color.
        (1 for white, -1 for black)
        """
        moves = set()  # stores the legal moves.

        # Get all the squares with pieces of the given color.
        for y in range(self.n):
            for x in range(self.n):
                if self[x][y]==color:
                    newmoves = self.get_moves_for_square((x,y))
                    moves.update(newmoves)
        return list(moves)

    def has_legal_moves(self, color):
        for y in range(self.n):
            for x in range(self.n):
                if self[x][y]==color:
                    newmoves = self.get_moves_for_square((x,y))
                    if len(newmoves)>0:
                        return True
        return False

    def get_moves_for_square(self, square):
        """Returns all the legal moves that use the given square as a base.
        That is, if the given square is (3,4) and it contains a black piece,
        and (3,5) and (3,6) contain white pieces, and (3,7) is empty, one
        of the returned moves is (3,7) because everything from there to (3,4)
        is flipped.
        """
        (x,y) = square

        # determine the color of the piece.
        color = self[x][y]

        # skip empty source squares.
        if color==0:
            return None

        # search all possible directions.
        moves = []
        for direction in self.__directions:
            move = self._discover_move(square, direction)
            if move:
                moves.append(move)

        # return the generated move list
        return moves

    def execute_move(self, move, color):
        """Perform the given move on the board; flips pieces as necessary.
        color gives the color of the piece to play (1=white,-1=black)
        """

        # Much like move generation, start at the new piece's square and
        # follow it on all 8 directions to look for a piece allowing flipping.

        flips = [flip for direction in self.__directions
                 for flip in self._get_flips(move, direction, color)]
        assert len(list(flips))>0
        for x, y in flips:
            self[x][y] = color

    def _discover_move(self, origin, direction):
        """ Returns the endpoint for a legal move, starting at the given origin,
        moving by the given increment."""
        x, y = origin
        color = self[x][y]
        flips = []

        for x, y in Board._increment_move(origin, direction, self.n):
            if self[x][y] == 0:
                if flips:
                    return (x, y)
                else:
                    return None
            elif self[x][y] == color:
                return None
            elif self[x][y] == -color:
                flips.append((x, y))

    def _get_flips(self, origin, direction, color):
        """ Gets the list of flips for a vertex and direction to use with the
        execute_move function """
        # initialize variables
        flips = [origin]

        for x, y in Board._increment_move(origin, direction, self.n):
            if self[x][y] == 0:
                return []
            if self[x][y] == -color:
                flips.append((x, y))
            elif self[x][y] == color and len(flips) > 0:
                return flips

        return []

    @staticmethod
    def _increment_move(move, direction, n):
        """ Generator expression for incrementing moves """
        move = list(map(sum, zip(move, direction)))
        # move = (move[0]+direction[0], move[1]+direction[1])
        while all(map(lambda x: 0 <= x < n, move)):
            # while 0<=move[0] and move[0]<n and 0<=move[1] and move[1]<n:
            yield move
            move = list(map(sum, zip(move, direction)))
            # move = (move[0]+direction[0],move[1]+direction[1])


class OthelloGame():
    square_content = {
        -1: "X",
        +0: "-",
        +1: "O"
    }

    @staticmethod
    def getSquarePiece(piece):
        return OthelloGame.square_content[piece]

    def __init__(self, n):
        self.n = n

    def getInitBoard(self):
        # return initial board (numpy board)
        b = Board(self.n)
        return np.array(b.pieces)

    def getBoardSize(self):
        # (a,b) tuple
        return (self.n, self.n)

    def getActionSize(self):
        # return number of actions
        return self.n*self.n + 1

    def getNextState(self, board, player, action):
        # if player takes action on board, return next (board,player)
        # action must be a valid move
        if action == self.n*self.n:
            return (board, -player)
        b = Board(self.n)
        b.pieces = np.copy(board)
        move = (int(action/self.n), action%self.n)
        b.execute_move(move, player)
        return (b.pieces, -player)

    def getValidMoves(self, board, player):
        # return a fixed size binary vector
        valids = [0]*self.getActionSize()
        b = Board(self.n)
        b.pieces = np.copy(board)
        legalMoves = b.get_legal_moves(player)
        if len(legalMoves)==0:
            valids[-1]=1
            return np.array(valids)
        for x, y in legalMoves:
            valids[self.n*x+y]=1
        return np.array(valids)

    def getGameEnded(self, board, player):
        # return None if not ended, 1 if player won, -1 if player lost, 0 if draw.
        b = Board(self.n)
        b.pieces = np.copy(board)
        if b.has_legal_moves(player):
            return None
        if b.has_legal_moves(-player):
            return None
        if b.countDiff(player) > 0:
            return 1
        elif b.countDiff(player) < 0:
            return -1
        else:
            return 0

    def getCanonicalForm(self, board, player):
        # return state if player==1, else return -state if player==-1
        return player*board

    def getSymmetries(self, board, pi):
        # mirror, rotational
        assert(len(pi) == self.n**2+1)  # 1 for pass
        pi_board = np.reshape(pi[:-1], (self.n, self.n))
        l = []

        for i in range(1, 5):
            for j in [True, False]:
                newB = np.rot90(board, i)
                newPi = np.rot90(pi_board, i)
                if j:
                    newB = np.fliplr(newB)
                    newPi = np.fliplr(newPi)
                l += [(newB, list(newPi.ravel()) + [pi[-1]])]
        return l

    def stringRepresentation(self, board):
        return board.tostring()

    def stringRepresentationReadable(self, board):
        board_s = "".join(self.square_content[square] for row in board for square in row)
        return board_s

    def getScore(self, board, player):
        b = Board(self.n)
        b.pieces = np.copy(board)
        return b.countDiff(player)

    @staticmethod
    def display(board):
        n = board.shape[0]
        print(" ", end="")
        for y in range(n):
            print(y, end=" ")
        print("")
        print("-----------------------")
        for y in range(n):
            print(y, "|", end="")
            for x in range(n):
                piece = board[y][x]
                print(OthelloGame.square_content[piece], end=" ")
            print("|")

        print("-----------------------")


class RandomPlayer():
    def __init__(self, game):
        self.game = game

    def play(self, board):
        a = np.random.randint(self.game.getActionSize())
        valids = self.game.getValidMoves(board, 1)
        while valids[a]!=1:
            a = np.random.randint(self.game.getActionSize())
        return a

class GreedyOthelloPlayer():
    def __init__(self, game):
        self.game = game

    def play(self, board):
        valids = self.game.getValidMoves(board, 1)
        candidates = []
        for a in range(self.game.getActionSize()):
            if valids[a]==0:
                continue
            nextBoard, _ = self.game.getNextState(board, 1, a)
            score = self.game.getScore(nextBoard, 1)
            candidates += [(-score, a)]
        candidates.sort()
        return candidates[0][1]


class HumanOthelloPlayer():
    def __init__(self, game):
        self.game = game

    def play(self, board):
        # display(board)
        valid = self.game.getValidMoves(board, 1)
        for i in range(len(valid)):
            if valid[i]:
                print("[", int(i/self.game.n), int(i%self.game.n), end="] ")
        while True:
            input_move = input()
            input_a = input_move.split(" ")
            if len(input_a) == 2:
                try:
                    x,y = [int(i) for i in input_a]
                    if ((0 <= x) and (x < self.game.n) and (0 <= y) and (y < self.game.n)) or \
                            ((x == self.game.n) and (y == 0)):
                        a = self.game.n * x + y
                        if valid[a]:
                            break
                except ValueError:
                    'Invalid integer'
            print('Invalid move')
        return a


class Arena():
    """
    An Arena class where any 2 agents can be pit against each other.
    """

    def __init__(self, player1, player2, game, display=None):
        """
        Input:
            player 1,2: two functions that takes board as input, return action
            game: Game object
            display: a function that takes board as input and prints it. Is necessary for verbose
                     mode.
        """
        self.player1 = player1
        self.player2 = player2
        self.game = game
        self.display = display

    def playGame(self, verbose=False):
        """
        Executes one episode of a game.

        Returns:
            either
                winner: player who won the game (1 if player1, -1 if player2, 0 if draw)
        """
        players = [self.player2, None, self.player1]
        curPlayer = 1  # player1 go first
        board = self.game.getInitBoard()
        it = 0
        while self.game.getGameEnded(board, curPlayer) is None:
            it += 1
            if verbose:
                assert self.display
                print("Turn ", str(it), "Player ", str(curPlayer))
                self.display(board)
            action = players[curPlayer + 1](self.game.getCanonicalForm(board, curPlayer))

            valids = self.game.getValidMoves(self.game.getCanonicalForm(board, curPlayer), 1)

            if valids[action] == 0:
                log.error(f'Action {action} is not valid!')
                log.debug(f'valids = {valids}')
                assert valids[action] > 0
            board, curPlayer = self.game.getNextState(board, curPlayer, action)
        result = curPlayer * self.game.getGameEnded(board, curPlayer)
        if verbose:
            assert self.display
            print("Game over: Turn ", str(it), "Result ", str(result))
            self.display(board)
        return result

    def playGames(self, num, verbose=False):
        """
        Plays num games in which player1 starts num/2 games and player2 starts
        num/2 games.

        Returns:
            oneWon: games won by player1
            twoWon: games won by player2
            draws: games won by nobody
        """

        num = int(num / 2)
        oneWon = 0
        twoWon = 0
        draws = 0
        for _ in tqdm(range(num), desc="Arena.playGames (player1 go first)"):
            gameResult = self.playGame(verbose=verbose)
            if gameResult == 1:
                oneWon += 1
            elif gameResult == -1:
                twoWon += 1
            else:
                draws += 1

        self.player1, self.player2 = self.player2, self.player1

        for _ in tqdm(range(num), desc="Arena.playGames (player2 go first)"):
            gameResult = self.playGame(verbose=verbose)
            if gameResult == -1:
                oneWon += 1
            elif gameResult == 1:
                twoWon += 1
            else:
                draws += 1

        return oneWon, twoWon, draws
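Besides the board logic, game.py also ships two baseline agents (RandomPlayer, GreedyOthelloPlayer) and the Arena used for evaluation. The sketch below (illustrative only, not part of this commit) pits the two baselines against each other without any neural network, which is a convenient way to check the rules engine before training a model.

# baseline_arena.py -- illustrative sketch using only the classes defined in game.py
from game import OthelloGame, RandomPlayer, GreedyOthelloPlayer, Arena

g = OthelloGame(6)  # a small board keeps the games fast
random_player = RandomPlayer(g).play
greedy_player = GreedyOthelloPlayer(g).play

# Arena.playGames(num) plays num games, swapping who goes first halfway through.
arena = Arena(random_player, greedy_player, g, display=OthelloGame.display)
random_wins, greedy_wins, draws = arena.playGames(10, verbose=False)
print(f"random {random_wins} - greedy {greedy_wins} - draws {draws}")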
requirements.txt
ADDED
@@ -0,0 +1,12 @@
# othello-backend/requirements.txt
# Core web framework
flask
Flask-CORS

# AI / model dependencies
numpy
torch

# Other helper libraries
tqdm
# Note: logging is part of the Python standard library and should not be installed from PyPI.