ffzeroHua commited on
Commit
ed0d2b6
·
verified ·
1 Parent(s): 6a169e9

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +388 -0
app.py ADDED
@@ -0,0 +1,388 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import json
4
+ import pickle
5
+ import threading
6
+ import traceback
7
+ import requests
8
+ import numpy as np
9
+ from typing import *
10
+ from datetime import datetime
11
+
12
+ # Web 服务与 HF Hub 依赖
13
+ from fastapi import FastAPI
14
+ import uvicorn
15
+ from huggingface_hub import HfApi, hf_hub_download
16
+
17
+ # 底层特征引擎 (Teacher)
18
+ from libriichi3p.mjai import Bot as RiichiBot
19
+ from libriichi3p.consts import ACTION_SPACE
20
+
21
+ # 底层特征引擎 (Student)
22
+ try:
23
+ from libriichiSanma import state as sanma_state
24
+ except ImportError:
25
+ import libriichi as sanma_state
26
+
27
# ==========================================
# [Configuration & environment variables]
# ==========================================
# Hugging Face access token; when empty, mining still runs but uploads are disabled.
HF_TOKEN = os.environ.get("HF_TOKEN", "")
# Dataset repository that receives the pickled feature chunks.
DATASET_REPO = os.environ.get("DATASET_REPO", "ffzeroHua/tenhou-scc")
# File inside DATASET_REPO listing the Tenhou log URLs to mine.
URL_LIST_FILE = os.environ.get("URL_LIST_FILE", "urls_better.txt")

# Action vocabulary for 3-player (sanma) riichi: 34 tile names, then the
# three red fives, then meta-actions. An action's index in this list is
# the integer label stored in the training data.
MASK_3P = [
    "1m", "2m", "3m", "4m", "5m", "6m", "7m", "8m", "9m",
    "1p", "2p", "3p", "4p", "5p", "6p", "7p", "8p", "9p",
    "1s", "2s", "3s", "4s", "5s", "6s", "7s", "8s", "9s",
    "E", "S", "W", "N", "P", "F", "C",
    '5mr', '5pr', '5sr',
    'reach', 'pon', 'kan', 'nukidora', 'hora', 'ryukyoku', 'none'
]

NONE_CODE = MASK_3P.index('none')  # label meaning "this seat takes no action"
KAN_CODE = MASK_3P.index('kan')    # single label shared by all kan variants
# Per-thread hand-off slot: DummyFeatureEngine stores intercepted feature
# tensors here for the surrounding code to pick up after each react() call.
_thread_local = threading.local()

# Mutable progress/status dict served verbatim by the FastAPI "/" endpoint.
worker_status = {
    "status": "Starting up...",
    "urls_processed": 0,
    "total_chunks_uploaded": 0,
    "total_records_extracted": 0,
    "current_target": "",
    "errors": 0
}
55
+
56
# ==========================================
# [Parser] unchanged
# ==========================================
class TenhouParser:
    """Converts Tenhou's compact JSON game logs into mjai-style event dicts.

    Tenhou encodes each round as a flat array (round info, scores, dora
    markers, per-seat draw/discard streams, result); this class decodes
    tiles and melds and merges the per-seat streams into one ordered
    event list resembling the mjai protocol.
    """

    @staticmethod
    def tile_name(x):
        """Map a Tenhou numeric tile code to its string name.

        51/52/53 are the red fives; otherwise the tens digit selects the
        suit (1=m, 2=p, 3=s, 4=honors) and the units digit the rank.
        Unknown codes return '?'.
        """
        if x in (51, 52, 53): return ['5mr', '5pr', '5sr'][x - 51]
        num, suit = x % 10, x // 10
        if suit in (1, 2, 3): return str(num) + 'mps'[suit - 1]
        if suit == 4: return 'ESWNPFC'[num - 1]
        return '?'

    @classmethod
    def get_meld_tiles(cls, actor, s):
        """Decode a Tenhou meld string (e.g. 'c252627') into an mjai call dict.

        A letter marks the called tile and the call type
        (c=chi, p=pon, m=daiminkan, a=ankan, k=kakan, f=nukidora); its
        position within the string encodes which seat was called from
        (stored as 'target' for chi/pon/daiminkan). Two-digit groups are
        tile codes decoded via tile_name().
        """
        i, player = 0, 0
        result = {'pai': [], 'consumed': [], 'actor': actor}
        while i < len(s):
            player += 1
            tile_type = 'consumed'
            if s[i] in 'cpmakf':
                tile_type = 'pai'
                result['type'] = ['chi', 'pon', 'daiminkan', 'ankan', 'kakan', 'nukidora']['cpmakf'.index(s[i])]
                if s[i] in 'cpm':
                    # Relative position of the letter → absolute seat called from.
                    result['target'] = (4 - player + actor) % 4
                i += 1
            result[tile_type].append(cls.tile_name(int(s[i:i+2])))
            i += 2
        # 'pai' collected at most one tile; unwrap it from the list.
        result['pai'] = result['pai'][0]
        # An ankan consumes four copies of the same tile; the called slot
        # doubles as a consumed tile.
        if result.get('type') == 'ankan': result['consumed'].append(result['pai'])
        return result

    @classmethod
    def parse_events(cls, actor, income, outcome):
        """Decode one seat's draw stream (`income`) and discard stream (`outcome`).

        income entries: int tile code (tsumo) or meld string (call).
        outcome entries: int tile code (dahai; 60 = tsumogiri, i.e. discard
        the tile just drawn), 'r'-prefixed string (riichi declaration +
        discard), other string (meld), or 0 (placeholder → 'empty').
        Returns (incoming, outcoming) lists of mjai-style dicts.
        """
        incoming, outcoming = [], []
        for i, event in enumerate(income):
            if type(event) is str: incoming.append(cls.get_meld_tiles(actor, event))
            else: incoming.append({'type': 'tsumo', 'pai': cls.tile_name(event), 'actor': actor})
        for i, event in enumerate(outcome):
            if type(event) is str and event[0] != 'r':
                outcoming.append(cls.get_meld_tiles(actor, event))
            else:
                if event == 0:
                    outcoming.append({'type': 'empty'})
                    continue
                reach = False
                if type(event) is str and event[0] == 'r':
                    # 'r<tile>' = declare riichi, then discard <tile>.
                    reach, event = True, int(event[1:])
                    outcoming.append({'type': 'reach', 'actor': actor})
                # 60 means "discard the drawn tile" → take the tile from the
                # matching income slot.
                outcoming.append({'type': 'dahai', 'pai': cls.tile_name(event if event != 60 else income[i]), 'actor': actor, 'tsumogiri': event == 60})
                if reach: outcoming.append({'type': 'reach_accepted', 'actor': actor})
        return incoming, outcoming

    @classmethod
    def merge_events(cls, oya, events, dora_markers):
        """Interleave the four per-seat (income, outcome) streams into one
        globally ordered event list, starting from the dealer (`oya`).

        Consumes the per-seat lists destructively (pop(0)). After each
        discard it scans the other seats for a call on that tile and, if
        found, jumps `current` to the caller instead of rotating; ankan /
        kakan / nukidora keep the turn with the same seat. Kan events
        also emit the next dora reveal. NOTE(review): the turn-jump and
        chi-direction condition is intricate — treat as trusted legacy code.
        """
        current, result = oya, []
        def finished(x): return all(len(i[0]) == 0 and len(i[1]) == 0 for i in x)
        while not finished(events):
            income, outcome = events[current]
            nuki = False
            if len(income):
                result.append(income.pop(0))
                if result[-1]['type'] == 'daiminkan':
                    # Daiminkan reveals a new dora and consumes the seat's
                    # paired (empty) discard slot.
                    result.append({'type': 'dora', 'dora_marker': cls.tile_name(dora_markers.pop(0))})
                    outcome.pop(0)
                    continue
            if len(outcome):
                result.append(outcome.pop(0))
                pai, t = result[-1].get('pai'), result[-1]['type']
                if t == 'reach':
                    # Riichi is encoded as three events: reach, dahai,
                    # reach_accepted — emit all three here.
                    result.append(outcome.pop(0))
                    pai = result[-1].get('pai')
                    result.append(outcome.pop(0))
                nuki = False
                for actor, x in enumerate(events):
                    if actor == current or len(x[1]) == 0: continue
                    # Another seat's next income claims this discard → the
                    # turn jumps to that seat (chi only from the left).
                    if x[0][0]['type'] != 'tsumo' and x[0][0].get('pai') == pai and not (x[0][0]['type'] == 'chi' and not (x[0][0]['actor'] + 3) % 4 == actor):
                        nuki, current = True, actor
                        break
                if t in ('ankan', 'kakan', 'nukidora'):
                    # These keep the turn; closed/added kan also reveals dora.
                    if t != 'nukidora' and len(dora_markers) > 0: result.append({'type': 'dora', 'dora_marker': cls.tile_name(dora_markers.pop(0))})
                    nuki = True
            if not nuki: current = (current + 1) % 4
        return result

    @classmethod
    def parse_single_round(cls, data):
        """Parse one round array into an ordered mjai event list.

        Layout: data[0]=round info (kyoku index, honba, kyotaku),
        data[1]=scores, data[2]=dora markers, data[3]=uradora (unused),
        data[4/7/10/13]=starting hands, data[5+3i]/data[6+3i]=per-seat
        income/outcome streams, data[-1]=result info.
        """
        round_info, scores, dora_markers, uradora, result_info = data[0], data[1], data[2], data[3], data[-1]
        oya = round_info[0] % 4
        # Pad truncated hands (e.g. the absent 4th seat in sanma) to 13 tiles.
        patch = lambda arr: arr if len(arr) >= 13 else [0] * 13
        events = [{
            'type': 'start_kyoku', 'bakaze': 'ESWN'[round_info[0] // 4], 'kyoku': oya + 1,
            'honba': round_info[1], 'kyotaku': round_info[2], 'oya': oya,
            'dora_marker': cls.tile_name(dora_markers.pop(0)), 'scores': scores,
            'tehais': [[cls.tile_name(i) for i in patch(data[k])] for k in [4, 7, 10, 13]]
        }]
        e_list = [cls.parse_events(i, data[5+i*3], data[6+i*3]) for i in range(4)]
        events += cls.merge_events(oya, e_list, dora_markers)
        last_type = events[-1]['type']
        # Round termination: tsumo win, ron win (winner = first seat with a
        # positive score delta), or draw / nine-terminals abort.
        if last_type == 'tsumo' and result_info[0] == '和了': events.append({'type': 'hora', 'actor': events[-1]['actor'], 'target': events[-1]['actor']})
        elif result_info[0] == '和了':
            actor = next(i for i, x in enumerate(result_info[1]) if x > 0)
            events.append({'type': 'hora', 'actor': actor, 'target': actor})
        elif last_type == 'tsumo' or '九牌' in result_info[0]: events.append({'type': 'ryukyoku', 'actor': events[-1]['actor']})
        return events

    @classmethod
    def parse_log(cls, log):
        """Parse a full Tenhou game log into a list of per-round event lists.

        Each round list is prefixed with a 'start_game' event carrying the
        logged player's seat (index of '私' in the name list, or -1) and a
        uniform per-seat weight vector.
        """
        scores = log.get('sc', [])
        weights = [1.0, 1.0, 1.0]
        seat = log['name'].index('私') if '私' in log['name'] else -1
        parsed_rounds = []
        for i in log['log'][:]:
            round_events = [{"type": "start_game", "id": seat, "weight": weights}] + cls.parse_single_round(i)
            parsed_rounds.append(round_events)
        return parsed_rounds
171
+
172
# ==========================================
# [Feature-intercepting dummy engine (Teacher)]
# ==========================================
class DummyFeatureEngine:
    """Stand-in model whose real job is to capture the feature tensors the
    libriichi3p Bot computes.

    It exposes the attribute surface the Bot expects from a Mortal-style
    engine, stashes every (obs, masks, invisible_obs) batch into
    `_thread_local.interception`, and always answers with the first legal
    action so the game replay can proceed.
    """

    def __init__(self):
        self.engine_type = 'mortal'
        self.name = 'DataMiner'
        self.version = 4
        self.is_oracle = False
        self.enable_quick_eval = True
        self.enable_rule_based_agari_guard = True

    def react_batch(self, obs, masks, invisible_obs):
        # Hand the raw feature tensors to whoever drove the Bot on this thread.
        _thread_local.interception = (obs, masks, invisible_obs)

        chosen, q_values, plain_masks = [], [], []
        for mask in masks:
            flat = mask.tolist() if hasattr(mask, 'tolist') else list(mask)
            plain_masks.append(flat)
            # Pick the first legal action; fall back to 0 if none is marked.
            try:
                pick = flat.index(True)
            except ValueError:
                pick = 0
            chosen.append(pick)
            q_values.append([0.0] * len(flat))

        return chosen, q_values, plain_masks, [True] * len(obs)
197
+
198
# ==========================================
# [Dual feature packing architecture (Distillation)]
# ==========================================
class FeatureEncoder:
    """Accumulates paired (student, teacher) features and ships them to the Hub.

    Records are buffered in memory; every `chunk_size` records they are
    pickled into a local pool directory, and once `pool_size` chunk files
    have accumulated the whole pool is batch-uploaded to DATASET_REPO.
    """

    def __init__(self, chunk_size=2048, pool_size=8):
        self.chunk_size = chunk_size  # records per pickle chunk
        self.pool_size = pool_size    # chunk files per batched upload
        self.inputs, self.outputs, self.weights = [], [], []
        self.chunk_count = 0
        # Without a token we still mine locally but never upload.
        self.hf_api = HfApi(token=HF_TOKEN) if HF_TOKEN else None

        self.local_pool_dir = "local_chunks_pool"
        os.makedirs(self.local_pool_dir, exist_ok=True)

    @staticmethod
    def action_to_mask(who, action):
        """Encode the next game event as a MASK_3P label for seat `who`.

        Returns NONE_CODE when the event is absent, belongs to another seat,
        or is a draw. All kan variants collapse to KAN_CODE.
        Raises ValueError for events that map to no label.
        """
        if action is None: return NONE_CODE
        if type(action) is str: action = json.loads(action)
        if action.get('actor') != who or action.get('type') == 'tsumo': return NONE_CODE
        if action['type'] == 'dahai': return MASK_3P.index(action['pai'])
        if action['type'] in ('daiminkan', 'ankan', 'kakan'): return KAN_CODE
        if action['type'] in MASK_3P: return MASK_3P.index(action['type'])
        # Fix: raise the conventional exception type; still caught by the
        # existing `except Exception` in process_game.
        raise ValueError(f"Unknown action map: {action}")

    def save_and_check_upload(self):
        """Flush buffered records into a pickle chunk; upload when the pool is full."""
        if not self.inputs: return

        filename = f"chunk_distill_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{self.chunk_count}.pkl"
        filepath = os.path.join(self.local_pool_dir, filename)

        with open(filepath, 'wb') as f:
            pickle.dump({'inputs': self.inputs, 'outputs': self.outputs, 'weights': self.weights}, f)

        # Fix: report the actual chunk filename instead of the "(unknown)" placeholder.
        print(f"📦 已生成蒸馏缓存: {filename} ({len(self.inputs)} records).")

        self.chunk_count += 1
        self.inputs.clear()
        self.outputs.clear()
        self.weights.clear()

        current_files = os.listdir(self.local_pool_dir)
        if len(current_files) >= self.pool_size:
            self.upload_pool()

    def upload_pool(self):
        """Batch-upload every chunk file in the local pool, with retries.

        Exponential backoff: 5s * 2**attempt, up to 6 attempts. Uploaded
        files are deleted locally; on total failure they are kept on disk
        so a later call can retry them.
        """
        current_files = os.listdir(self.local_pool_dir)
        if not current_files or not self.hf_api or not DATASET_REPO: return

        import time
        print(f"🚀 本地池满,正在批量上传 {len(current_files)} 个文件...")

        for attempt in range(6):
            try:
                self.hf_api.upload_folder(
                    folder_path=self.local_pool_dir,
                    path_in_repo="distill_chunks",
                    repo_id=DATASET_REPO,
                    repo_type="dataset"
                )
                print(f"✅ 上传成功 (Attempt {attempt + 1}).")
                worker_status["total_chunks_uploaded"] += len(current_files)
                for f in current_files: os.remove(os.path.join(self.local_pool_dir, f))
                break
            except Exception as e:
                wait_time = 5 * (2 ** attempt)
                print(f"⚠️ Upload failed: {e}. Waiting {wait_time}s...")
                time.sleep(wait_time)
        else:
            # Fix: surface retry exhaustion instead of failing silently; the
            # pool files remain on disk for the next upload attempt.
            print("❌ Upload failed after 6 attempts; keeping local pool for retry.")

    def process_game(self, events):
        """Replay one round through both engines, harvesting training records.

        `events` is an mjai-style event list whose leading 'start_game'
        entry names the seat to mine (`id`) and per-seat weights. For each
        decision point where the teacher produced features, the student can
        act, and more than one action is legal, one record is appended:
        student obs/mask, teacher obs/mask (batch dim removed), the label
        derived from the next real event, and the seat weight.
        """
        who = -1
        current_weight = 1.0

        ps_student = None
        bot_teacher = None

        for i, event in enumerate(events):
            if event.get('type') == 'start_game':
                who = event['id']
                weights_list = event.get('weight', [1.0, 1.0, 1.0])
                current_weight = weights_list[who]

                # Initialize both state machines (student + teacher).
                ps_student = sanma_state.PlayerState(who)
                bot_teacher = RiichiBot(DummyFeatureEngine(), who)

            if ps_student is None or bot_teacher is None:
                continue

            if event.get('type') == 'end_game':
                continue

            # Label source: the next "real" event, skipping passive
            # dora reveals and reach confirmations.
            next_event = None
            for j in range(i + 1, len(events)):
                if events[j].get('type') not in ('dora', 'reach_accepted'):
                    next_event = events[j]; break

            event_str = json.dumps(event, separators=(",", ":"))

            # --- 1. Teacher update & feature interception ---
            _thread_local.interception = None
            bot_teacher.react(event_str)
            intercepted = getattr(_thread_local, 'interception', None)

            # --- 2. Student update & feature generation ---
            cans = ps_student.update(event_str)

            if intercepted is None or not cans.can_act:
                continue

            obs_t, masks_t, _ = intercepted
            obs_s, mask_s = ps_student.encode_obs(4, False)

            # Skip forced moves: nothing to distill when only one action is legal.
            valid_actions_count = int(np.count_nonzero(masks_t[0]))
            if valid_actions_count <= 1:
                continue

            try:
                output_code = self.action_to_mask(who, next_event)

                # Store dicts so old and new data formats stay decoupled.
                self.inputs.append({
                    "obs_student": obs_s,
                    "mask_student": mask_s,
                    "obs_teacher": obs_t[0],  # drop the batch dimension
                    "mask_teacher": masks_t[0]
                })
                self.outputs.append(output_code)
                self.weights.append(current_weight)

                worker_status["total_records_extracted"] += 1
            except Exception:
                # Best effort: unmappable next-events are skipped, not fatal.
                pass

            if len(self.inputs) >= self.chunk_size:
                self.save_and_check_upload()
332
+
333
# ==========================================
# [Data-mining main pipeline]
# ==========================================
def worker_pipeline():
    """Background worker: download Tenhou logs, extract features, upload chunks.

    Reads the URL list from DATASET_REPO/URL_LIST_FILE, mines every listed
    game from each seat except the logged player's own, and reports
    progress/failures through the module-level `worker_status` dict served
    by the FastAPI endpoint.
    """
    if not HF_TOKEN or not DATASET_REPO:
        worker_status["status"] = "Error: HF_TOKEN or DATASET_REPO missing!"
        return

    worker_status["status"] = "Fetching target URL list..."
    try:
        url_file_path = hf_hub_download(repo_id=DATASET_REPO, filename=URL_LIST_FILE, repo_type="dataset", token=HF_TOKEN)
        with open(url_file_path, 'r') as f: target_urls = [line.strip() for line in f if line.strip()]
    except Exception as e:
        worker_status["status"] = f"Failed to fetch {URL_LIST_FILE}: {e}"
        return

    headers = {"User-Agent": "Mozilla/5.0"}
    encoder = FeatureEncoder(chunk_size=2048, pool_size=8)
    worker_status["status"] = "Mining..."

    for url in target_urls:
        worker_status["current_target"] = url
        log_match = re.search(r'log=([^&]+)', url)
        tw_match = re.search(r'tw=(\d+)', url)
        if not log_match: continue

        # tw = the logged player's seat in the URL; that seat is skipped below.
        tw = int(tw_match.group(1)) if tw_match else -1
        log_id = log_match.group(1)

        try:
            res = requests.get(f"https://tenhou.net/5/mjlog2json.cgi?{log_id}", headers=headers, timeout=30)
            # Fix: treat HTTP errors as failures instead of feeding an error
            # page to the JSON parser.
            res.raise_for_status()
            parsed_games = TenhouParser.parse_log(res.json())

            # Mine every seat's perspective except the original player's own.
            for game in parsed_games:
                for j in range(3):
                    if j == tw: continue
                    game[0]['id'] = j
                    encoder.process_game(game)

            worker_status["urls_processed"] += 1
        except Exception:
            # Fix: log the failure (module-level `traceback` import) instead
            # of discarding it silently; the counter alone hid every cause.
            worker_status["errors"] += 1
            traceback.print_exc()

    # Flush any partial chunk and force a final upload of the pool.
    encoder.save_and_check_upload()
    encoder.upload_pool()
    worker_status["status"] = "Finished! All URLs processed."
    worker_status["current_target"] = "Idle"
380
+
381
# FastAPI app exposing the worker's live status for health checks / monitoring.
app = FastAPI()
@app.get("/")
def read_status(): return worker_status

if __name__ == '__main__':
    # Run the miner in a daemon thread so the web server stays responsive
    # and the process exits cleanly; port 7860 is the HF Spaces default.
    thread = threading.Thread(target=worker_pipeline, daemon=True)
    thread.start()
    uvicorn.run(app, host="0.0.0.0", port=7860)