ffzeroHua committed on
Commit
201fbbb
·
verified ·
1 Parent(s): 496825e

Upload 7 files

Browse files
Files changed (8) hide show
  1. .gitattributes +2 -0
  2. Dockerfile +29 -0
  3. app.py +330 -0
  4. libriichi3p.so +3 -0
  5. libriichiSanma.so +3 -0
  6. model3pLOCAL.py +452 -0
  7. model3pNEW.py +445 -0
  8. requirements.txt +7 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ libriichi3p.so filter=lfs diff=lfs merge=lfs -text
37
+ libriichiSanma.so filter=lfs diff=lfs merge=lfs -text
Dockerfile ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 1. 使用轻量级的 Python 3.12 基础镜像
2
+ FROM python:3.12-slim
3
+
4
+ # 2. 设置环境变量,防止 python 缓冲 stdout 导致日志延迟
5
+ ENV PYTHONUNBUFFERED=1
6
+
7
+ # 3. [针对 Hugging Face 空间的特殊设置]
8
+ # 创建一个非 root 用户 user,UID 设置为 1000
9
+ RUN useradd -m -u 1000 user
10
+ USER user
11
+ ENV PATH="/home/user/.local/bin:$PATH"
12
+
13
+ # 4. 设置工作目录
14
+ WORKDIR /app
15
+
16
+ # 5. 复制 requirements.txt 并安装依赖
17
+ # (先复制这个文件可以利用 Docker 的缓存机制,加快后续构建速度)
18
+ COPY --chown=user:user requirements.txt /app/
19
+ # 强烈建议安装 CPU 版本的 PyTorch 以大幅缩减镜像体积
20
+ RUN pip install --no-cache-dir --extra-index-url https://download.pytorch.org/whl/cpu -r requirements.txt
21
+
22
+ # 6. 复制所有项目文件到工作目录下
23
+ COPY --chown=user:user . /app/
24
+
25
+ # 7. 暴露 Gradio 默认端口
26
+ EXPOSE 7860
27
+
28
+ # 8. 启动应用
29
+ CMD ["python", "app.py"]
app.py ADDED
@@ -0,0 +1,330 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import orjson
3
+ import concurrent.futures
4
+ import random
5
+ import torch
6
+ import threading
7
+ import time
8
+ import uuid
9
+ import glob
10
+ import gradio as gr
11
+ import pandas as pd
12
+ import matplotlib.pyplot as plt
13
+ from huggingface_hub import snapshot_download, hf_hub_download, HfApi
14
+
15
+ from riichienv import RiichiEnv, GameRule
16
+
17
+ # 分别导入两个不同架构的加载函数,防止命名冲突
18
+ from model3pLOCAL import load_model as load_model_local
19
+ from model3pNEW import load_model as load_model_new
20
+
21
+ # ==========================================
22
+ # 0. 核心对抗配置开关 (在这里切换模式)
23
+ # ==========================================
24
+ # True: 1个 NEW架构(TEST_MODEL) VS 2个 LOCAL架构(EXAMINER_MODEL)
25
+ # False: 1个 LOCAL架构(TEST_MODEL) VS 2个 NEW架构(EXAMINER_MODEL)
26
ONE_NEW_VS_TWO_LOCAL = True

# ==========================================
# 0. Multi-node and Hub persistence settings
# ==========================================
DATA_REPO_ID = "ffzeroHua/mj-eval-results"  # dataset repo holding the result shards
MODEL_REPO_ID = "ffzeroHua/Riichi-Model-Repo"  # model repo holding the weights
HF_TOKEN = os.environ.get("HF_TOKEN")

# Each node gets its own id: WORKER_ID from the environment, otherwise the
# first six characters of a random UUID.
_random_worker_id = str(uuid.uuid4())[:6]
WORKER_ID = os.environ.get("WORKER_ID", _random_worker_id)

# The report prefix depends on which side is the challenger.
BASE_REPORT_PREFIX = 'Step40800P42998_vs_9070_eval_report'
REPORT_FILE_PREFIX = (
    BASE_REPORT_PREFIX if ONE_NEW_VS_TWO_LOCAL else f"inverse_{BASE_REPORT_PREFIX}"
)

REPORT_FILE = f"{REPORT_FILE_PREFIX}_{WORKER_ID}.txt"

api = HfApi()
EVAL_RUNNING = True

# Checkpoints pulled from the Hub and pitted against each other.
TEST_MODEL = "Elite3P_Step40800_P42998.pth"
EXAMINER_MODEL = "Elite4z9070.pth"
53
+
54
def sync_models_from_hub():
    """Download both contestants' weight files from MODEL_REPO_ID into the cwd.

    Nothing is downloaded when HF_TOKEN is unset or MODEL_REPO_ID still
    contains the template placeholder; a notice is printed instead. Download
    errors are printed, not raised.
    """
    if not (HF_TOKEN and "你的用户名" not in MODEL_REPO_ID):
        print("⚠️ 未配置有效 HF_TOKEN 或未修改 MODEL_REPO_ID,将尝试使用本地已存在的模型文件。")
        return

    def _pull(filename):
        hf_hub_download(
            repo_id=MODEL_REPO_ID,
            filename=filename,
            repo_type="model",
            local_dir=".",
            token=HF_TOKEN,
        )

    print(f"☁️ 正在从模型仓库 [{MODEL_REPO_ID}] 拉取评估模型...")
    try:
        _pull(TEST_MODEL)
        print(f"✅ 成功拉取测试模型: {TEST_MODEL}")

        _pull(EXAMINER_MODEL)
        print(f"✅ 成功拉取考官模型: {EXAMINER_MODEL}")

        print("🎉 模型环境准备完毕!")
    except Exception as e:
        print(f"❌ 拉取模型失败,请检查文件名或仓库权限: {e}")
70
+
71
def sync_data_from_hub():
    """Fetch every node's result shard for the current prefix from DATA_REPO_ID.

    Silently does nothing when HF_TOKEN is unset or DATA_REPO_ID still
    contains the template placeholder. Download errors are printed, not raised.
    """
    hub_configured = HF_TOKEN and "你的用户名" not in DATA_REPO_ID
    if not hub_configured:
        return

    shard_pattern = REPORT_FILE_PREFIX + "_*.txt"
    try:
        print(f"🔄 正在从 Hub 拉取全局历史战绩数据 (前缀匹配: {REPORT_FILE_PREFIX})...")
        snapshot_download(
            repo_id=DATA_REPO_ID,
            repo_type="dataset",
            local_dir=".",
            allow_patterns=shard_pattern,
            token=HF_TOKEN,
        )
        print("✅ 历史数据拉取完成。")
    except Exception as e:
        print(f"⚠️ 拉取历史战绩失败: {e}")
86
+
87
def sync_data_to_hub():
    """Upload this node's REPORT_FILE to DATA_REPO_ID under the same name.

    Silently does nothing when HF_TOKEN is unset or DATA_REPO_ID still
    contains the template placeholder. Upload errors are printed, not raised.
    """
    hub_configured = HF_TOKEN and "你的用户名" not in DATA_REPO_ID
    if not hub_configured:
        return

    upload_args = {
        "path_or_fileobj": REPORT_FILE,
        "path_in_repo": REPORT_FILE,
        "repo_id": DATA_REPO_ID,
        "repo_type": "dataset",
        "token": HF_TOKEN,
    }
    try:
        api.upload_file(**upload_args)
        print(f"☁️ 节点 {WORKER_ID} 战绩已同步至 Hub: {time.strftime('%H:%M:%S')}")
    except Exception as e:
        print(f"❌ 同步失败: {e}")
101
+
102
+ # ==========================================
103
+ # 1. 高频及模型加载逻辑
104
+ # ==========================================
105
def patch_event_fast(event_str):
    """Rewrite one serialized event before it is handed to the model.

    * every ``"kita"`` token becomes ``"nukidora"``;
    * ``start_kyoku`` events get ``scores`` and ``tehais`` padded to four
      entries (``0`` and thirteen ``"?"`` respectively);
    * any event carrying ``deltas`` gets that list padded to four entries.

    Events mentioning neither ``"start_kyoku"`` nor ``"deltas"`` skip JSON
    parsing entirely and are returned as (possibly renamed) text.
    """
    # str.replace is a no-op when the token is absent, so no pre-check needed.
    event_str = event_str.replace('"kita"', '"nukidora"')

    if '"start_kyoku"' not in event_str and '"deltas"' not in event_str:
        return event_str

    event = orjson.loads(event_str)

    if event.get('type') == 'start_kyoku':
        scores = event.setdefault('scores', [])
        scores.extend([0] * (4 - len(scores)))
        tehais = event.setdefault('tehais', [])
        tehais.extend(["?"] * 13 for _ in range(4 - len(tehais)))

    if 'deltas' in event:
        deltas = event['deltas']
        deltas.extend([0] * (4 - len(deltas)))

    return orjson.dumps(event).decode('utf-8')
121
+
122
def patch_resp_fast(resp_str):
    """Map ``"nukidora"`` back to ``"kita"`` in a model response.

    Falsy responses (``None``, ``""``) are returned unchanged.
    """
    if resp_str:
        return resp_str.replace('"nukidora"', '"kita"')
    return resp_str
125
+
126
_MODEL_CACHE = {}


def get_cached_model(player_id: int, model_file: str, arch_type: str):
    """Return the model for (player_id, model_file, arch_type), loading it once.

    ``arch_type == 'new'`` selects ``load_model_new``; any other value selects
    ``load_model_local``. On a cache miss torch is limited to a single
    intra-op thread before the loader runs.
    """
    cache_key = (player_id, model_file, arch_type)
    try:
        return _MODEL_CACHE[cache_key]
    except KeyError:
        pass

    torch.set_num_threads(1)
    loader = load_model_new if arch_type == 'new' else load_model_local
    _MODEL_CACHE[cache_key] = loader(player_id, model_file)
    return _MODEL_CACHE[cache_key]
138
+
139
class MortalAgent:
    """Seat-bound player that forwards environment events to a cached model."""

    def __init__(self, player_id: int, model_file: str, arch_type: str):
        self.player_id = player_id
        self.arch_type = arch_type
        self.model = get_cached_model(player_id, model_file, arch_type)

    def act(self, obs):
        """Feed every new event to the model and act on its last response.

        Each event is patched on the way in and each response on the way out.
        Only the response to the final event is turned into an action.
        """
        resp = None
        for raw_event in obs.new_events():
            reply = self.model.react(patch_event_fast(raw_event))
            resp = patch_resp_fast(reply)

        action = obs.select_action_from_mjai(resp)
        assert action is not None, "Mortal must return a legal action"
        return action
153
+
154
+ # ==========================================
155
+ # 2. 核心对局任务
156
+ # ==========================================
157
def play_one_game(game_index):
    """Play one three-player game and report the challenger's result.

    The challenger (TEST_MODEL) takes a random seat and the other two seats
    use EXAMINER_MODEL. ONE_NEW_VS_TWO_LOCAL decides which architecture loader
    each side uses. ``game_index`` is accepted but not used.

    Returns:
        ``(rank, score)`` of the challenger's seat.
    """
    env = RiichiEnv(game_mode="3p-red-half", rule=GameRule.default_tenhou())
    new_seat = random.randrange(3)

    if ONE_NEW_VS_TWO_LOCAL:
        challenger_arch, examiner_arch = 'new', 'local'
    else:
        challenger_arch, examiner_arch = 'local', 'new'

    agents = {}
    for seat in range(3):
        is_challenger = seat == new_seat
        agents[seat] = MortalAgent(
            seat,
            TEST_MODEL if is_challenger else EXAMINER_MODEL,
            challenger_arch if is_challenger else examiner_arch,
        )

    obs_dict = env.reset()
    while not env.done():
        obs_dict = env.step(
            {pid: agents[pid].act(obs) for pid, obs in obs_dict.items()}
        )

    scores = env.scores()
    ranks = env.ranks()
    return ranks[new_seat], scores[new_seat]
182
+
183
+ # ==========================================
184
+ # 3. 后台独立评估线程
185
+ # ==========================================
186
def background_eval_loop():
    """Run games in a process pool, append results locally and sync to the Hub.

    Flow:
      1. Pull model weights and existing result shards from the Hub.
      2. Keep ``NUM_WORKERS * 2`` games in flight while EVAL_RUNNING is true.
      3. Append each finished game as ``"<rank> <score>"`` to REPORT_FILE.
      4. After every 50 recorded games, upload this node's shard and re-pull
         the shards of all nodes.

    A failed game is logged and still replaced by a new submission.
    """
    sync_models_from_hub()  # pull the contestants' weights
    sync_data_from_hub()    # pull historical result shards

    NUM_WORKERS = 1

    mode_str = "1只 NEW 挑战 2只 LOCAL" if ONE_NEW_VS_TWO_LOCAL else "1只 LOCAL 挑战 2只 NEW"
    print(f"🚀 节点 [{WORKER_ID}] 后台对战线程已启动: 模式为 [{mode_str}]")

    # Append mode creates the file when missing and never truncates one that
    # was just pulled from the Hub.
    with open(REPORT_FILE, 'a'):
        pass

    games_since_last_sync = 0
    games_completed = 0
    games_submitted = 0

    with concurrent.futures.ProcessPoolExecutor(max_workers=NUM_WORKERS) as executor:
        futures = set()
        for _ in range(NUM_WORKERS * 2):
            futures.add(executor.submit(play_one_game, games_submitted))
            games_submitted += 1

        while EVAL_RUNNING and futures:
            done, futures = concurrent.futures.wait(
                futures, return_when=concurrent.futures.FIRST_COMPLETED
            )

            with open(REPORT_FILE, "a") as f:
                for future in done:
                    try:
                        rank, score = future.result()
                        f.write(f"{rank} {score}\n")
                        f.flush()
                        games_completed += 1
                        games_since_last_sync += 1
                        print(f"[节点 {WORKER_ID}] 完成 {games_completed} 局: 顺位 {rank}, 得点 {score}")
                    except Exception as e:
                        print(f"对局异常: {e}")

            if EVAL_RUNNING:
                # wait() may hand back several finished futures at once.
                # Replace each of them, otherwise the number of in-flight
                # games shrinks over time.
                for _ in done:
                    futures.add(executor.submit(play_one_game, games_submitted))
                    games_submitted += 1

            if games_since_last_sync >= 50:
                sync_data_to_hub()
                sync_data_from_hub()
                games_since_last_sync = 0
228
+
229
+ # ==========================================
230
+ # 4. 前端 Gradio 实时展示面板 (全局汇总)
231
+ # ==========================================
232
def read_and_analyze():
    """Aggregate all local result shards into a Markdown summary and a figure.

    Every ``<REPORT_FILE_PREFIX>_*.txt`` file in the working directory is
    read; lines with exactly two whitespace-separated fields are parsed as
    ``rank score``.

    Returns:
        ``(markdown_text, figure)``. ``figure`` is ``None`` while no result
        exists yet or when parsing/plotting fails; in the failure case the
        markdown text carries the error message.
    """
    # glob order is arbitrary; sorting keeps the concatenated score series
    # (and therefore the trend plot) stable between refreshes.
    all_files = sorted(glob.glob(f"{REPORT_FILE_PREFIX}_*.txt"))

    main_arch = "NEW架构" if ONE_NEW_VS_TWO_LOCAL else "LOCAL架构"
    opp_arch = "LOCAL架构" if ONE_NEW_VS_TWO_LOCAL else "NEW架构"

    if not all_files:
        return f"⏳ 正在拉取模型并等待 [{main_arch}] `{TEST_MODEL}` VS [{opp_arch}] `{EXAMINER_MODEL}` 第一局完成...", None

    ranks, scores = [], []
    fig = None
    try:
        for file in all_files:
            with open(file, "r") as f:
                for line in f:
                    parts = line.split()
                    if len(parts) == 2:
                        ranks.append(int(float(parts[0])))
                        scores.append(float(parts[1]))

        total = len(ranks)
        if total == 0:
            return "⏳ 模型已就绪,正在进行第一局对抗...", None

        avg_rank = sum(ranks) / total
        avg_score = sum(scores) / total
        rank1_rate = ranks.count(1) / total * 100
        rank2_rate = ranks.count(2) / total * 100
        rank3_rate = ranks.count(3) / total * 100

        last_update = time.strftime('%Y-%m-%d %H:%M:%S')

        md_text = f"""
### 📊 对战简报
- ⚔️ **对抗阵容:** 1只 `{TEST_MODEL}` ({main_arch}) **VS** 2只 `{EXAMINER_MODEL}` ({opp_arch})
- 🧮 **总对局数:** {total} 局 (跨节点全局汇集)
- 🏆 **平均顺位:** {avg_rank:.3f}
- 💰 **平均得点:** {avg_score:.0f}
---
- 🥇 **一位率:** {rank1_rate:.1f}%
- 🥈 **二位率:** {rank2_rate:.1f}%
- 🥉 **三位率:** {rank3_rate:.1f}%
---
- 🌐 **当前节点 ID:** `{WORKER_ID}`
- 🕒 **刷新时间:** {last_update}
"""

        rates = [rank1_rate, rank2_rate, rank3_rate]
        fig = plt.figure(figsize=(10, 4))

        # Left panel: rank distribution.
        ax1 = fig.add_subplot(121)
        ax1.bar(['1st', '2nd', '3rd'], rates, color=['#FFD700', '#C0C0C0', '#CD7F32'])
        ax1.set_title(f'Rank Distribution for {TEST_MODEL}')
        ax1.set_ylim(0, max(100, max(rates + [0]) + 10))
        for i, v in enumerate(rates):
            ax1.text(i, v + 2, f"{v:.1f}%", ha='center')

        # Right panel: raw scores plus a moving average of up to 10 games.
        ax2 = fig.add_subplot(122)
        df = pd.DataFrame({'score': scores})
        df['ma'] = df['score'].rolling(window=min(10, max(1, len(df))), min_periods=1).mean()
        ax2.plot(df['score'], alpha=0.3, color='gray', label='Raw Score')
        ax2.plot(df['ma'], color='crimson', linewidth=2, label='Moving Avg (10)')
        ax2.set_title('Score Trend')
        ax2.legend()

        fig.tight_layout()
        return md_text, fig

    except Exception as e:
        return f"❌ 数据解析出错: {e}", None
    finally:
        # This function runs on every timer tick. Without this, pyplot keeps a
        # reference to every figure ever created and memory grows without
        # bound. Closing only unregisters the figure from pyplot; the Figure
        # object returned above can still be rendered by the caller.
        if fig is not None:
            plt.close(fig)
300
+
301
+ # ==========================================
302
+ # 5. 启动 Gradio 应用
303
+ # ==========================================
304
with gr.Blocks() as demo:
    gr.Markdown("# 🀄 Mahjong AI 基准评估舱")

    if ONE_NEW_VS_TWO_LOCAL:
        header_main, header_opp = "NEW架构", "LOCAL架构"
    else:
        header_main, header_opp = "LOCAL架构", "NEW架构"

    gr.Markdown(f"当前正在评估: 1名 **{TEST_MODEL} ({header_main})** 单挑 2名 **{EXAMINER_MODEL} ({header_opp})**。启动时会自动拉取权重。")

    with gr.Row():
        with gr.Column(scale=1):
            stats_output = gr.Markdown("🚀 正在初始化基准环境并连接模型仓库...")
            refresh_btn = gr.Button("🔄 手动刷新全局战绩")
        with gr.Column(scale=2):
            plot_output = gr.Plot()

    def _wire_refresh(register):
        """Bind read_and_analyze to an event, targeting the stats and plot."""
        register(fn=read_and_analyze, inputs=None, outputs=[stats_output, plot_output])

    # Refresh on page load, every 15 seconds, and on button click.
    _wire_refresh(demo.load)

    timer = gr.Timer(15)
    _wire_refresh(timer.tick)

    _wire_refresh(refresh_btn.click)

if __name__ == "__main__":
    # The evaluation loop runs in a daemon thread alongside the web UI.
    t = threading.Thread(target=background_eval_loop, daemon=True)
    t.start()

    demo.queue().launch(server_name="0.0.0.0", server_port=7860, theme=gr.themes.Soft())
libriichi3p.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:03900834051021f662fec35c6e9608f4d4c5aa61b4c4ce37b49fa2e861bf619b
3
+ size 1873424
libriichiSanma.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c35adaace110bde0dc896f742b1e1b3ad50213cf7dbafb858a774e46f5b5cf32
3
+ size 3631184
model3pLOCAL.py ADDED
@@ -0,0 +1,452 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import gzip
3
+ import torch
4
+ import pathlib
5
+ import requests
6
+ import traceback
7
+ import numpy as np
8
+
9
+ from torch import nn, Tensor
10
+ from torch.nn import functional as F
11
+ from torch.nn.utils.rnn import pack_padded_sequence, pad_sequence
12
+ from torch.distributions import Normal, Categorical
13
+ from typing import *
14
+ from functools import partial
15
+ from itertools import permutations
16
try:
    from libriichi3p.mjai import Bot
    from libriichi3p.consts import obs_shape, oracle_obs_shape, ACTION_SPACE, GRP_SIZE
except ImportError:
    # Fallback: load the native extension from an explicit file path.
    import importlib.util
    import sys
    import os

    # Absolute path of the shared object in the Colab/Drive layout.
    # Adjust it when the file lives elsewhere or has a different name.
    SO_FILE_PATH = "/content/drive/MyDrive/MahjongTest/libriichi3p.so"

    # 1. Fail early with a clear error when the file is missing.
    if not os.path.exists(SO_FILE_PATH):
        print(f"❌ 致命错误:在路径 {SO_FILE_PATH} 下根本找不到文件!请检查路径拼写。")
        raise ImportError(f"libriichi3p is not importable and {SO_FILE_PATH} does not exist")
    print(f"✅ 找到文件: {SO_FILE_PATH},正在尝试强行加载...")

    try:
        # 2. Build an import spec for the file under the name "libriichi3p".
        spec = importlib.util.spec_from_file_location("libriichi3p", SO_FILE_PATH)

        # 3. Create the module object.
        libriichi3p_module = importlib.util.module_from_spec(spec)

        # 4. Register it so later `import libriichi3p` statements resolve.
        sys.modules["libriichi3p"] = libriichi3p_module

        # 5. Run the extension's initialisation.
        spec.loader.exec_module(libriichi3p_module)

        print("🎉 强行导入成功!现在可以在代码里正常使用了。")
    except Exception as e:
        print(f"❌ 导入失败,暴露出真实报错: {e}")
        # Do not leave a half-initialised module registered.
        sys.modules.pop("libriichi3p", None)
        raise

    # The rest of this file uses these names directly, so they must be bound
    # on the fallback path as well.
    from libriichi3p.mjai import Bot
    from libriichi3p.consts import obs_shape, oracle_obs_shape, ACTION_SPACE, GRP_SIZE
53
+ # ========== Online Server =========== #
54
OT_REQUEST_TIMEOUT = 2
ot_settings = {
    "server": "http://example.com",
    "online": False,
    "api_key": "example_api_key",
}
is_online = False

def online_settings_init():
    """Overlay ``ot_settings.json`` (next to this file) onto the defaults.

    Keys missing from the file keep their default values, so a partial
    settings file cannot remove ``online``, ``server`` or ``api_key``.
    Without the file the defaults stay in effect.
    """
    settings_path = pathlib.Path(__file__).parent / 'ot_settings.json'
    if settings_path.exists():
        with open(settings_path, 'r', encoding='utf-8') as f:
            # Update in place so the defaults survive a partial file.
            ot_settings.update(json.load(f))

online_settings_init()
70
+ # ==================================== #
71
+
72
class ChannelAttention(nn.Module):
    """Channel gate computed from mean- and max-pooled features.

    A shared two-layer MLP (bottleneck width ``channels // ratio``) is applied
    to the mean and to the max over the last axis. The sigmoid of their sum
    scales the input per channel.
    """

    def __init__(self, channels, ratio=16, actv_builder=nn.ReLU, bias=True):
        super().__init__()
        hidden = channels // ratio
        squeeze = nn.Linear(channels, hidden, bias=bias)
        excite = nn.Linear(hidden, channels, bias=bias)
        self.shared_mlp = nn.Sequential(squeeze, actv_builder(), excite)

        if not bias:
            return
        # Linear biases start at zero.
        for mod in self.modules():
            if isinstance(mod, nn.Linear):
                nn.init.zeros_(mod.bias)

    def forward(self, x: Tensor):
        pooled_mean = self.shared_mlp(x.mean(-1))
        pooled_max = self.shared_mlp(x.amax(-1))
        gate = torch.sigmoid(pooled_mean + pooled_max)
        return gate.unsqueeze(-1) * x
91
+
92
class ResBlock(nn.Module):
    """Residual unit of two k=3 Conv1d layers followed by channel attention.

    ``pre_actv=True`` places norm and activation before each convolution and
    skips the activation after the residual sum. Otherwise norm/activation
    follow the convolutions and the sum is activated.
    """

    def __init__(
        self,
        channels,
        *,
        norm_builder = nn.Identity,
        actv_builder = nn.ReLU,
        pre_actv = False,
    ):
        super().__init__()
        self.pre_actv = pre_actv

        def conv3():
            return nn.Conv1d(channels, channels, kernel_size=3, padding=1, bias=False)

        if pre_actv:
            stages = [
                norm_builder(), actv_builder(), conv3(),
                norm_builder(), actv_builder(), conv3(),
            ]
        else:
            stages = [
                conv3(), norm_builder(), actv_builder(),
                conv3(), norm_builder(),
            ]
        self.res_unit = nn.Sequential(*stages)
        if not pre_actv:
            self.actv = actv_builder()
        self.ca = ChannelAttention(channels, actv_builder=actv_builder, bias=True)

    def forward(self, x):
        out = self.ca(self.res_unit(x)) + x
        if self.pre_actv:
            return out
        return self.actv(out)
131
+
132
class ResNet(nn.Module):
    """Conv1d stem, ``num_blocks`` residual blocks, and a projection to 1024.

    The projection head flattens a 32-channel map of length 34, so inputs are
    expected to have last dimension 34.
    """

    def __init__(
        self,
        in_channels,
        conv_channels,
        num_blocks,
        *,
        norm_builder = nn.Identity,
        actv_builder = nn.ReLU,
        pre_actv = False,
    ):
        super().__init__()

        # Blocks are built before the stem, in the order the layers were
        # originally constructed.
        blocks = [
            ResBlock(
                conv_channels,
                norm_builder = norm_builder,
                actv_builder = actv_builder,
                pre_actv = pre_actv,
            )
            for _ in range(num_blocks)
        ]
        stem = nn.Conv1d(in_channels, conv_channels, kernel_size=3, padding=1, bias=False)

        if pre_actv:
            trunk = [stem, *blocks, norm_builder(), actv_builder()]
        else:
            trunk = [stem, norm_builder(), actv_builder(), *blocks]

        head = [
            nn.Conv1d(conv_channels, 32, kernel_size=3, padding=1),
            actv_builder(),
            nn.Flatten(),
            nn.Linear(32 * 34, 1024),
        ]
        self.net = nn.Sequential(*trunk, *head)

    def forward(self, x):
        return self.net(x)
169
+
170
class Brain(nn.Module):
    """Observation encoder built on ResNet, with version-specific outputs.

    Version 1 returns ``(mu, logsig)`` from a latent head. Versions 2, 3 and 4
    return the activated 1024-wide encoder output. Any other version raises
    ``ValueError``.
    """

    def __init__(self, *, conv_channels, num_blocks, is_oracle=False, version=1):
        super().__init__()
        self.is_oracle = is_oracle
        self.version = version

        in_channels = obs_shape(version)[0]
        if is_oracle:
            in_channels += oracle_obs_shape(version)[0]

        bn_options = {'momentum': 0.01}
        actv_builder = partial(nn.Mish, inplace=True)
        pre_actv = True

        if version == 1:
            actv_builder = partial(nn.ReLU, inplace=True)
            pre_actv = False
            self.latent_net = nn.Sequential(
                nn.Linear(1024, 512),
                nn.ReLU(inplace=True),
            )
            self.mu_head = nn.Linear(512, 512)
            self.logsig_head = nn.Linear(512, 512)
        elif version == 2:
            pass
        elif version in (3, 4):
            bn_options['eps'] = 1e-3
        else:
            raise ValueError(f'Unexpected version {self.version}')

        norm_builder = partial(nn.BatchNorm1d, conv_channels, **bn_options)

        self.encoder = ResNet(
            in_channels = in_channels,
            conv_channels = conv_channels,
            num_blocks = num_blocks,
            norm_builder = norm_builder,
            actv_builder = actv_builder,
            pre_actv = pre_actv,
        )
        self.actv = actv_builder()

        # When True, BatchNorm layers stay in eval mode even during training.
        self._freeze_bn = False

    def forward(self, obs: Tensor, invisible_obs: Optional[Tensor] = None) -> Union[Tuple[Tensor, Tensor], Tensor]:
        if self.is_oracle:
            assert invisible_obs is not None
            obs = torch.cat((obs, invisible_obs), dim=1)

        phi = F.dropout(self.encoder(obs), p=0.1, training=self.training)

        if self.version == 1:
            latent_out = self.latent_net(phi)
            return self.mu_head(latent_out), self.logsig_head(latent_out)
        if self.version in (2, 3, 4):
            return self.actv(phi)
        raise ValueError(f'Unexpected version {self.version}')

    def _batch_norms(self):
        """Yield every BatchNorm1d submodule."""
        return (mod for mod in self.modules() if isinstance(mod, nn.BatchNorm1d))

    def train(self, mode=True):
        super().train(mode)
        if self._freeze_bn:
            for bn in self._batch_norms():
                bn.eval()
        return self

    def reset_running_stats(self):
        for bn in self._batch_norms():
            bn.reset_running_stats()

    def freeze_bn(self, value: bool):
        self._freeze_bn = value
        # Re-apply the current mode so the new flag takes effect immediately.
        return self.train(self.training)
249
+
250
class AuxNet(nn.Module):
    """Bias-free linear layer on 1024 features whose output is split by ``dims``."""

    def __init__(self, dims=None):
        super().__init__()
        self.dims = dims
        total_out = sum(dims)
        self.net = nn.Linear(1024, total_out, bias=False)

    def forward(self, x):
        projected = self.net(x)
        return projected.split(self.dims, dim=-1)
258
+
259
class DQN(nn.Module):
    """Dueling Q head: ``q = v + a - mean(a over legal actions)``.

    Illegal actions (``mask`` False) are excluded from the mean and receive
    ``-1e9`` in the output.
    """

    def __init__(self, *, version=1):
        super().__init__()
        self.version = version

        if version == 1:
            self.v_head = nn.Linear(512, 1)
            self.a_head = nn.Linear(512, ACTION_SPACE)
        elif version in (2, 3):
            hidden_size = 512 if version == 2 else 256
            self.v_head = nn.Sequential(
                nn.Linear(1024, hidden_size),
                nn.Mish(inplace=True),
                nn.Linear(hidden_size, 1),
            )
            self.a_head = nn.Sequential(
                nn.Linear(1024, hidden_size),
                nn.Mish(inplace=True),
                nn.Linear(hidden_size, ACTION_SPACE),
            )
        elif version == 4:
            # One fused layer produces the value and all advantages.
            self.net = nn.Linear(1024, 1 + ACTION_SPACE)
            nn.init.constant_(self.net.bias, 0)

    def forward(self, phi, mask):
        if self.version == 4:
            v, a = self.net(phi).split((1, ACTION_SPACE), dim=-1)
        else:
            v = self.v_head(phi)
            a = self.a_head(phi)

        illegal = ~mask
        legal_count = mask.sum(-1, keepdim=True)
        a_mean = a.masked_fill(illegal, 0.).sum(-1, keepdim=True) / legal_count
        return (v + a - a_mean).masked_fill(illegal, -1e9)
294
+
295
+
296
+ class MortalEngine:
297
+ def __init__(
298
+ self,
299
+ brain,
300
+ dqn,
301
+ is_oracle,
302
+ version,
303
+ device = None,
304
+ stochastic_latent = False,
305
+ enable_amp = False,
306
+ enable_quick_eval = True,
307
+ enable_rule_based_agari_guard = False,
308
+ name = 'NoName',
309
+ boltzmann_epsilon = 0,
310
+ boltzmann_temp = 1,
311
+ top_p = 1,
312
+ ):
313
+ self.engine_type = 'mortal'
314
+ self.device = device or torch.device('cpu')
315
+ assert isinstance(self.device, torch.device)
316
+ self.brain = brain.to(self.device).eval()
317
+ self.dqn = dqn.to(self.device).eval()
318
+ self.is_oracle = is_oracle
319
+ self.version = version
320
+ self.stochastic_latent = stochastic_latent
321
+
322
+ self.enable_amp = enable_amp
323
+ self.enable_quick_eval = enable_quick_eval
324
+ self.enable_rule_based_agari_guard = enable_rule_based_agari_guard
325
+ self.name = name
326
+
327
+ self.boltzmann_epsilon = boltzmann_epsilon
328
+ self.boltzmann_temp = boltzmann_temp
329
+ self.top_p = top_p
330
+
331
+ def react_batch(self, obs, masks, invisible_obs):
332
+ # ========== Online Server =========== #
333
+ global ot_settings, is_online
334
+ # print('Reacting Batch')
335
+ if ot_settings['online']:
336
+ try:
337
+ list_obs = [o.tolist() for o in obs]
338
+ list_masks = [m.tolist() for m in masks]
339
+ post_data = {
340
+ 'obs': list_obs,
341
+ 'masks': list_masks,
342
+ }
343
+ data = json.dumps(post_data, separators=(',', ':'))
344
+ compressed_data = gzip.compress(data.encode('utf-8'))
345
+ headers = {
346
+ 'Authorization': ot_settings['api_key'],
347
+ 'Content-Encoding': 'gzip',
348
+ }
349
+ r = requests.post(
350
+ f'{ot_settings["server"]}/react_batch_3p',
351
+ headers=headers,
352
+ data=compressed_data,
353
+ timeout=OT_REQUEST_TIMEOUT
354
+ )
355
+ assert r.status_code == 200
356
+ is_online = True
357
+ r_json = r.json()
358
+ return r_json['actions'], r_json['q_out'], r_json['masks'], r_json['is_greedy']
359
+ except:
360
+ is_online = False
361
+ pass
362
+ # ==================================== #
363
+ try:
364
+ with (
365
+ torch.autocast(self.device.type, enabled=self.enable_amp),
366
+ torch.inference_mode(),
367
+ ):
368
+ return self._react_batch(obs, masks, invisible_obs)
369
+ except Exception as ex:
370
+ raise Exception(f'{ex}\n{traceback.format_exc()}')
371
+
372
+ def _react_batch(self, obs, masks, invisible_obs):
373
+ obs = torch.as_tensor(np.stack(obs, axis=0), device=self.device)
374
+ masks = torch.as_tensor(np.stack(masks, axis=0), device=self.device)
375
+ invisible_obs = None
376
+ if self.is_oracle:
377
+ invisible_obs = torch.as_tensor(np.stack(invisible_obs, axis=0), device=self.device)
378
+ batch_size = obs.shape[0]
379
+
380
+ match self.version:
381
+ case 1:
382
+ mu, logsig = self.brain(obs, invisible_obs)
383
+ if self.stochastic_latent:
384
+ latent = Normal(mu, logsig.exp() + 1e-6).sample()
385
+ else:
386
+ latent = mu
387
+ q_out = self.dqn(latent, masks)
388
+ case 2 | 3 | 4:
389
+ phi = self.brain(obs)
390
+ q_out = self.dqn(phi, masks)
391
+
392
+ if self.boltzmann_epsilon > 0:
393
+ is_greedy = torch.full((batch_size,), 1-self.boltzmann_epsilon, device=self.device).bernoulli().to(torch.bool)
394
+ logits = (q_out / self.boltzmann_temp).masked_fill(~masks, -torch.inf)
395
+ sampled = sample_top_p(logits, self.top_p)
396
+ actions = torch.where(is_greedy, q_out.argmax(-1), sampled)
397
+ else:
398
+ is_greedy = torch.ones(batch_size, dtype=torch.bool, device=self.device)
399
+ actions = q_out.argmax(-1)
400
+ return actions.tolist(), q_out.tolist(), masks.tolist(), is_greedy.tolist()
401
+
402
def sample_top_p(logits, p):
    """Sample one index per row using nucleus (top-p) sampling.

    ``p >= 1`` samples from the full categorical distribution and ``p <= 0``
    returns the argmax. Otherwise an entry is dropped when the probability
    mass ranked strictly above it already exceeds ``p``, and the sample is
    drawn from the remaining entries.
    """
    if p <= 0:
        return logits.argmax(-1)
    if p >= 1:
        return Categorical(logits=logits).sample()

    ranked_probs, ranked_idx = logits.softmax(-1).sort(-1, descending=True)
    # Mass accumulated by strictly higher-ranked entries.
    mass_before = ranked_probs.cumsum(-1) - ranked_probs
    ranked_probs = ranked_probs.masked_fill(mass_before > p, 0.)
    choice = ranked_probs.multinomial(1)
    return ranked_idx.gather(-1, choice).squeeze(-1)
414
+
415
def load_model(seat: int, model: str) -> Bot:
    """Build a ``Bot`` for ``seat`` from a checkpoint stored next to this file.

    Args:
        seat: Seat index handed to ``Bot``.
        model: Checkpoint file name; anything after the first ``?`` is
            dropped. ``None`` selects ``'Elite4zWeightedBest5.pth'``.

    Returns:
        A ``Bot`` wrapping a ``MortalEngine`` on CUDA when available, else CPU.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Default checkpoint when none is given.
    if model is None:
        model = 'Elite4zWeightedBest5.pth'
    model = str(model).split('?')[0]

    # Checkpoints are resolved relative to this file's directory.
    control_state_file = pathlib.Path(__file__).parent / model
    state = torch.load(control_state_file, map_location=device)

    config = state['config']
    version = config['control']['version']

    mortal = Brain(
        version=version,
        conv_channels=config['resnet']['conv_channels'],
        num_blocks=config['resnet']['num_blocks'],
    ).eval()
    dqn = DQN(version=version).eval()
    mortal.load_state_dict(state['mortal'])
    dqn.load_state_dict(state['current_dqn'])
    # Report success only after the weights have actually been loaded.
    print(model, 'loaded')

    engine = MortalEngine(
        mortal,
        dqn,
        is_oracle = False,
        version = version,
        device = device,
        enable_amp = False,
        enable_quick_eval = False,
        enable_rule_based_agari_guard = True,
        name = 'mortal',
        top_p = 1,
    )

    return Bot(engine, seat)
model3pNEW.py ADDED
@@ -0,0 +1,445 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import gzip
3
+ import torch
4
+ import pathlib
5
+ import requests
6
+ import traceback
7
+ import numpy as np
8
+
9
+ from torch import nn, Tensor
10
+ from torch.nn import functional as F
11
+ from torch.nn.utils.rnn import pack_padded_sequence, pad_sequence
12
+ from torch.distributions import Normal, Categorical
13
+ from typing import *
14
+ from functools import partial
15
+ from itertools import permutations
16
try:
    from libriichi.mjai import Bot
    from libriichi.consts import obs_shape, oracle_obs_shape, ACTION_SPACE, GRP_SIZE
except ImportError:
    # libriichi is not importable as a regular module: try to load the
    # compiled extension directly from a fixed path instead.
    import importlib.util
    import sys
    import os
    SO_FILE_PATH = "/content/drive/MyDrive/MahjongTest/libriichi.so"

    # 1. Check whether the file exists at all.
    if not os.path.exists(SO_FILE_PATH):
        print(f"❌ 致命错误:在路径 {SO_FILE_PATH} 下根本找不到文件!请检查路径拼写。")
    else:
        print(f"✅ 找到文件: {SO_FILE_PATH},正在尝试强行加载...")

        try:
            # 2. Build a module spec from the absolute path. The first
            #    argument is the module name, the second the file path.
            spec = importlib.util.spec_from_file_location("libriichi", SO_FILE_PATH)

            # 3. Instantiate the module.
            libriichi_module = importlib.util.module_from_spec(spec)

            # 4. Register it in sys.modules so later `import libriichi`
            #    statements resolve to this module.
            sys.modules["libriichi"] = libriichi_module

            # 5. Execute the module so its contents are initialised.
            spec.loader.exec_module(libriichi_module)

            print("🎉 强行导入成功!现在可以在代码里正常使用了。")

        except Exception as e:
            # Do not leave a half-initialised module registered.
            sys.modules.pop("libriichi", None)
            print(f"❌ 导入失败,暴露出真实报错: {e}")

    # Bind the names the rest of this file relies on. If the fallback did
    # not succeed, this raises ImportError here rather than a NameError
    # further down the file.
    from libriichi.mjai import Bot
    from libriichi.consts import obs_shape, oracle_obs_shape, ACTION_SPACE, GRP_SIZE
49
# ========== Online Server =========== #
# Timeout passed to the HTTP request made against the online server.
OT_REQUEST_TIMEOUT = 2
# Defaults; values from ot_settings.json (if present) override them.
ot_settings = {
    "server": "http://example.com",
    "online": False,
    "api_key": "example_api_key",
}
# Updated by MortalEngine.react_batch after each online attempt.
is_online = False

def online_settings_init():
    """Overlay ``ot_settings.json`` (next to this file) onto ``ot_settings``.

    Keys missing from the file keep their current values, so a partial
    settings file cannot remove keys that other code reads. Does nothing
    when the file does not exist.
    """
    global ot_settings
    settings_path = pathlib.Path(__file__).parent / 'ot_settings.json'
    try:
        with open(settings_path, 'r', encoding='utf-8') as f:
            loaded = json.load(f)
    except FileNotFoundError:
        # No settings file: keep the built-in defaults.
        return
    ot_settings = {**ot_settings, **loaded}

online_settings_init()
# ==================================== #
67
+
68
class ChannelAttention(nn.Module):
    """Per-channel gate computed from mean- and max-pooled descriptors.

    The input is pooled over its last axis in two ways (mean and max), both
    descriptors go through the same two-layer MLP, and the sigmoid of their
    sum rescales the input channel by channel.

    Args:
        channels: Size of the channel axis (the MLP input/output width).
        ratio: Reduction factor for the MLP's hidden width.
        actv_builder: Zero-argument factory for the hidden activation.
        bias: Whether the linear layers have biases (initialised to zero).
    """

    def __init__(self, channels, ratio=16, actv_builder=nn.ReLU, bias=True):
        super().__init__()
        hidden = channels // ratio
        self.shared_mlp = nn.Sequential(
            nn.Linear(channels, hidden, bias=bias),
            actv_builder(),
            nn.Linear(hidden, channels, bias=bias),
        )
        if not bias:
            return
        # Start every linear layer with a zero bias.
        linears = (m for m in self.modules() if isinstance(m, nn.Linear))
        for linear in linears:
            nn.init.zeros_(linear.bias)

    def forward(self, x: Tensor):
        """Return ``x`` scaled by the per-channel gate."""
        from_mean = self.shared_mlp(x.mean(-1))
        from_max = self.shared_mlp(x.amax(-1))
        gate = torch.sigmoid(from_mean + from_max)
        # Broadcast the gate over the pooled (last) axis.
        return gate.unsqueeze(-1) * x
87
+
88
class ResBlock(nn.Module):
    """Residual block of two 3-wide 1-D convolutions plus channel attention.

    Args:
        channels: Number of input and output channels.
        norm_builder: Zero-argument factory for normalisation layers.
        actv_builder: Zero-argument factory for activation layers.
        pre_actv: When true, norm and activation precede each convolution
            and no activation follows the skip connection. Otherwise they
            follow the convolutions and an activation is applied after the
            skip connection.
    """

    def __init__(
        self,
        channels,
        *,
        norm_builder = nn.Identity,
        actv_builder = nn.ReLU,
        pre_actv = False,
    ):
        super().__init__()
        self.pre_actv = pre_actv

        def conv3():
            # Length-preserving, bias-free convolution.
            return nn.Conv1d(channels, channels, kernel_size=3, padding=1, bias=False)

        if pre_actv:
            stages = [
                norm_builder(), actv_builder(), conv3(),
                norm_builder(), actv_builder(), conv3(),
            ]
        else:
            stages = [
                conv3(), norm_builder(), actv_builder(),
                conv3(), norm_builder(),
            ]
        self.res_unit = nn.Sequential(*stages)
        if not pre_actv:
            # Activation applied after the skip connection.
            self.actv = actv_builder()
        self.ca = ChannelAttention(channels, actv_builder=actv_builder, bias=True)

    def forward(self, x):
        """Return the attention-weighted residual branch added to ``x``."""
        branch = self.ca(self.res_unit(x))
        summed = branch + x
        return summed if self.pre_actv else self.actv(summed)
127
+
128
class ResNet(nn.Module):
    """1-D residual encoder producing a 1024-wide feature vector.

    The network is a stem convolution, ``num_blocks`` ``ResBlock`` layers,
    a convolution down to 32 channels, and a linear layer over the flattened
    result. The final linear layer expects 32 * 34 inputs, so the input's
    last axis must have length 34.

    Args:
        in_channels: Channels of the input tensor.
        conv_channels: Channels used inside the residual tower.
        num_blocks: Number of residual blocks.
        norm_builder: Zero-argument factory for normalisation layers.
        actv_builder: Zero-argument factory for activation layers.
        pre_actv: Forwarded to each ``ResBlock``; also decides whether the
            tower-level norm/activation comes after or before the blocks.
    """

    def __init__(
        self,
        in_channels,
        conv_channels,
        num_blocks,
        *,
        norm_builder = nn.Identity,
        actv_builder = nn.ReLU,
        pre_actv = False,
    ):
        super().__init__()

        tower = [
            ResBlock(
                conv_channels,
                norm_builder = norm_builder,
                actv_builder = actv_builder,
                pre_actv = pre_actv,
            )
            for _ in range(num_blocks)
        ]

        stem = nn.Conv1d(in_channels, conv_channels, kernel_size=3, padding=1, bias=False)
        if pre_actv:
            trunk = [stem, *tower, norm_builder(), actv_builder()]
        else:
            trunk = [stem, norm_builder(), actv_builder(), *tower]
        head = [
            nn.Conv1d(conv_channels, 32, kernel_size=3, padding=1),
            actv_builder(),
            nn.Flatten(),
            nn.Linear(32 * 34, 1024),
        ]
        self.net = nn.Sequential(*trunk, *head)

    def forward(self, x):
        """Encode ``x`` into 1024 features per sample."""
        return self.net(x)
165
+
166
class Brain(nn.Module):
    """Observation encoder whose layout depends on the model ``version``.

    Version 1 uses ReLU with post-activation blocks and returns ``(mu,
    logsig)`` from two 512-wide heads. Versions 2-4 use Mish with
    pre-activation blocks and return the activated 1024-wide encoding;
    versions 3 and 4 additionally use ``eps=1e-3`` in batch norm.

    Args:
        conv_channels: Channel width of the residual tower.
        num_blocks: Number of residual blocks.
        is_oracle: When true, ``forward`` also needs ``invisible_obs``,
            which is concatenated to ``obs`` along dim 1.
        version: Model version, one of 1, 2, 3 or 4.

    Raises:
        ValueError: If ``version`` is not supported.
    """

    def __init__(self, *, conv_channels, num_blocks, is_oracle=False, version=1):
        super().__init__()
        self.is_oracle = is_oracle
        self.version = version

        in_channels = obs_shape(version)[0]
        if is_oracle:
            in_channels += oracle_obs_shape(version)[0]

        # Defaults shared by versions 2 and later.
        norm_builder = partial(nn.BatchNorm1d, conv_channels, momentum=0.01)
        actv_builder = partial(nn.Mish, inplace=True)
        pre_actv = True

        if version == 1:
            actv_builder = partial(nn.ReLU, inplace=True)
            pre_actv = False
            self.latent_net = nn.Sequential(
                nn.Linear(1024, 512),
                nn.ReLU(inplace=True),
            )
            self.mu_head = nn.Linear(512, 512)
            self.logsig_head = nn.Linear(512, 512)
        elif version in (3, 4):
            norm_builder = partial(nn.BatchNorm1d, conv_channels, momentum=0.01, eps=1e-3)
        elif version != 2:
            raise ValueError(f'Unexpected version {self.version}')

        self.encoder = ResNet(
            in_channels = in_channels,
            conv_channels = conv_channels,
            num_blocks = num_blocks,
            norm_builder = norm_builder,
            actv_builder = actv_builder,
            pre_actv = pre_actv,
        )
        self.actv = actv_builder()

        # always use EMA or CMA when True
        self._freeze_bn = False

    def forward(self, obs: Tensor, invisible_obs: Optional[Tensor] = None) -> Union[Tuple[Tensor, Tensor], Tensor]:
        """Encode ``obs``; returns ``(mu, logsig)`` for v1, a tensor otherwise."""
        if self.is_oracle:
            assert invisible_obs is not None
            obs = torch.cat((obs, invisible_obs), dim=1)
        phi = self.encoder(obs)

        if self.version == 1:
            hidden = self.latent_net(phi)
            return self.mu_head(hidden), self.logsig_head(hidden)
        if self.version in (2, 3, 4):
            return self.actv(phi)
        raise ValueError(f'Unexpected version {self.version}')

    def train(self, mode=True):
        """Set the training mode, keeping batch norm in eval when frozen."""
        super().train(mode)
        if not self._freeze_bn:
            return self
        for layer in self.modules():
            if isinstance(layer, nn.BatchNorm1d):
                # Frozen batch norm keeps using its running statistics.
                layer.eval()
        return self

    def reset_running_stats(self):
        """Reset the running statistics of every batch-norm layer."""
        batch_norms = (m for m in self.modules() if isinstance(m, nn.BatchNorm1d))
        for bn in batch_norms:
            bn.reset_running_stats()

    def freeze_bn(self, value: bool):
        """Enable or disable batch-norm freezing and re-apply the mode."""
        self._freeze_bn = value
        return self.train(self.training)
245
+
246
class AuxNet(nn.Module):
    """Bias-free linear head whose output is split into several chunks.

    Args:
        dims: Sizes of the output chunks; their sum is the output width.
            The ``None`` default is not usable as-is, because the sizes are
            summed in the constructor.
    """

    def __init__(self, dims=None):
        super().__init__()
        total_width = sum(dims)
        self.dims = dims
        self.net = nn.Linear(1024, total_width, bias=False)

    def forward(self, x):
        """Project ``x`` and return one tensor per entry of ``dims``."""
        projected = self.net(x)
        return torch.split(projected, self.dims, dim=-1)
254
+
255
+ class DQN(nn.Module):
256
+ def __init__(self, *, version=1):
257
+ super().__init__()
258
+ self.version = version
259
+ match version:
260
+ case 1:
261
+ self.v_head = nn.Linear(512, 1)
262
+ self.a_head = nn.Linear(512, ACTION_SPACE)
263
+ case 2 | 3:
264
+ hidden_size = 512 if version == 2 else 256
265
+ self.v_head = nn.Sequential(
266
+ nn.Linear(1024, hidden_size),
267
+ nn.Mish(inplace=True),
268
+ nn.Linear(hidden_size, 1),
269
+ )
270
+ self.a_head = nn.Sequential(
271
+ nn.Linear(1024, hidden_size),
272
+ nn.Mish(inplace=True),
273
+ nn.Linear(hidden_size, ACTION_SPACE),
274
+ )
275
+ case 4:
276
+ self.net = nn.Linear(1024, 1 + ACTION_SPACE)
277
+ nn.init.constant_(self.net.bias, 0)
278
+
279
+ def forward(self, phi, mask):
280
+ if self.version == 4:
281
+ v, a = self.net(phi).split((1, ACTION_SPACE), dim=-1)
282
+ else:
283
+ v = self.v_head(phi)
284
+ a = self.a_head(phi)
285
+ a_sum = a.masked_fill(~mask, 0.).sum(-1, keepdim=True)
286
+ mask_sum = mask.sum(-1, keepdim=True)
287
+ a_mean = a_sum / mask_sum
288
+ q = (v + a - a_mean).masked_fill(~mask, -torch.inf)
289
+ return q
290
+
291
+
292
+ class MortalEngine:
293
+ def __init__(
294
+ self,
295
+ brain,
296
+ dqn,
297
+ is_oracle,
298
+ version,
299
+ device = None,
300
+ stochastic_latent = False,
301
+ enable_amp = False,
302
+ enable_quick_eval = True,
303
+ enable_rule_based_agari_guard = False,
304
+ name = 'NoName',
305
+ boltzmann_epsilon = 0,
306
+ boltzmann_temp = 1,
307
+ top_p = 1,
308
+ ):
309
+ self.engine_type = 'mortal'
310
+ self.device = device or torch.device('cpu')
311
+ assert isinstance(self.device, torch.device)
312
+ self.brain = brain.to(self.device).eval()
313
+ self.dqn = dqn.to(self.device).eval()
314
+ self.is_oracle = is_oracle
315
+ self.version = version
316
+ self.stochastic_latent = stochastic_latent
317
+
318
+ self.enable_amp = enable_amp
319
+ self.enable_quick_eval = enable_quick_eval
320
+ self.enable_rule_based_agari_guard = enable_rule_based_agari_guard
321
+ self.name = name
322
+
323
+ self.boltzmann_epsilon = boltzmann_epsilon
324
+ self.boltzmann_temp = boltzmann_temp
325
+ self.top_p = top_p
326
+
327
+ def react_batch(self, obs, masks, invisible_obs):
328
+ # ========== Online Server =========== #
329
+ global ot_settings, is_online
330
+ if ot_settings['online']:
331
+ try:
332
+ list_obs = [o.tolist() for o in obs]
333
+ list_masks = [m.tolist() for m in masks]
334
+ post_data = {
335
+ 'obs': list_obs,
336
+ 'masks': list_masks,
337
+ }
338
+ data = json.dumps(post_data, separators=(',', ':'))
339
+ compressed_data = gzip.compress(data.encode('utf-8'))
340
+ headers = {
341
+ 'Authorization': ot_settings['api_key'],
342
+ 'Content-Encoding': 'gzip',
343
+ }
344
+ r = requests.post(
345
+ f'{ot_settings["server"]}/react_batch',
346
+ headers=headers,
347
+ data=compressed_data,
348
+ timeout=OT_REQUEST_TIMEOUT
349
+ )
350
+ assert r.status_code == 200
351
+ is_online = True
352
+ r_json = r.json()
353
+ return r_json['actions'], r_json['q_out'], r_json['masks'], r_json['is_greedy']
354
+ except:
355
+ is_online = False
356
+ pass
357
+ # ==================================== #
358
+ try:
359
+ with (
360
+ torch.autocast(self.device.type, enabled=self.enable_amp),
361
+ torch.inference_mode(),
362
+ ):
363
+ return self._react_batch(obs, masks, invisible_obs)
364
+ except Exception as ex:
365
+ raise Exception(f'{ex}\n{traceback.format_exc()}')
366
+
367
+ def _react_batch(self, obs, masks, invisible_obs):
368
+ obs = torch.as_tensor(np.stack(obs, axis=0), device=self.device)
369
+ masks = torch.as_tensor(np.stack(masks, axis=0), device=self.device)
370
+ invisible_obs = None
371
+ if self.is_oracle:
372
+ invisible_obs = torch.as_tensor(np.stack(invisible_obs, axis=0), device=self.device)
373
+ batch_size = obs.shape[0]
374
+
375
+ match self.version:
376
+ case 1:
377
+ mu, logsig = self.brain(obs, invisible_obs)
378
+ if self.stochastic_latent:
379
+ latent = Normal(mu, logsig.exp() + 1e-6).sample()
380
+ else:
381
+ latent = mu
382
+ q_out = self.dqn(latent, masks)
383
+ case 2 | 3 | 4:
384
+ phi = self.brain(obs)
385
+ q_out = self.dqn(phi, masks)
386
+
387
+ if self.boltzmann_epsilon > 0:
388
+ is_greedy = torch.full((batch_size,), 1-self.boltzmann_epsilon, device=self.device).bernoulli().to(torch.bool)
389
+ logits = (q_out / self.boltzmann_temp).masked_fill(~masks, -torch.inf)
390
+ sampled = sample_top_p(logits, self.top_p)
391
+ actions = torch.where(is_greedy, q_out.argmax(-1), sampled)
392
+ else:
393
+ is_greedy = torch.ones(batch_size, dtype=torch.bool, device=self.device)
394
+ actions = q_out.argmax(-1)
395
+
396
+ return actions.tolist(), q_out.tolist(), masks.tolist(), is_greedy.tolist()
397
+
398
def sample_top_p(logits, p):
    """Pick one index per row with top-p truncated sampling.

    Args:
        logits: Tensor of unnormalised log-probabilities over the last axis.
        p: Cumulative-probability cutoff. Values ``>= 1`` sample from the
            untruncated distribution; values ``<= 0`` return the argmax.

    Returns:
        Tensor of chosen indices, shaped like ``logits`` without its last
        axis.
    """
    if p <= 0:
        return logits.argmax(-1)
    if p >= 1:
        return Categorical(logits=logits).sample()

    ranked_probs, ranked_idx = logits.softmax(-1).sort(-1, descending=True)
    # Exclusive cumulative sum: the mass ranked strictly above each entry.
    mass_above = ranked_probs.cumsum(-1) - ranked_probs
    # Entries outside the cutoff get zero weight; the top entry always stays.
    weights = torch.where(mass_above > p, torch.zeros_like(ranked_probs), ranked_probs)
    rank = weights.multinomial(1)
    # Translate the sampled rank back into an index of the original order.
    return ranked_idx.gather(-1, rank).squeeze(-1)
410
+
411
def load_model(seat: int, model_type) -> Bot:
    """Create a ``Bot`` for ``seat`` from the bundled checkpoint.

    Args:
        seat: Seat index forwarded to ``Bot``.
        model_type: Accepted but not used by this function.

    Returns:
        A ``Bot`` wrapping a non-oracle ``MortalEngine``.
    """
    # Prefer the GPU when present.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # latest binary model
    control_state_file = "./Elite4z-Mowang_epoch_10.pth"
    print('model.py loading', control_state_file)

    # The checkpoint is resolved relative to this file's directory.
    checkpoint_path = pathlib.Path(__file__).parent / control_state_file
    state = torch.load(checkpoint_path, map_location=device)

    version = state['config']['control']['version']
    mortal = Brain(
        version=version,
        conv_channels=state['config']['resnet']['conv_channels'],
        num_blocks=state['config']['resnet']['num_blocks'],
    ).eval()
    dqn = DQN(version=version).eval()
    mortal.load_state_dict(state['mortal'])
    dqn.load_state_dict(state['current_dqn'])

    engine_options = dict(
        is_oracle = False,
        version = version,
        device = device,
        enable_amp = False,
        enable_quick_eval = False,
        enable_rule_based_agari_guard = True,
        name = 'mortal',
    )
    return Bot(MortalEngine(mortal, dqn, **engine_options), seat)
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ torch
2
+ orjson
3
+ gradio
4
+ matplotlib
5
+ pandas
6
+ riichienv
7
+ requests