Spaces:
Paused
Paused
Create app.py
Browse files
app.py
ADDED
|
@@ -0,0 +1,388 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import re
|
| 3 |
+
import json
|
| 4 |
+
import pickle
|
| 5 |
+
import threading
|
| 6 |
+
import traceback
|
| 7 |
+
import requests
|
| 8 |
+
import numpy as np
|
| 9 |
+
from typing import *
|
| 10 |
+
from datetime import datetime
|
| 11 |
+
|
| 12 |
+
# Web 服务与 HF Hub 依赖
|
| 13 |
+
from fastapi import FastAPI
|
| 14 |
+
import uvicorn
|
| 15 |
+
from huggingface_hub import HfApi, hf_hub_download
|
| 16 |
+
|
| 17 |
+
# 底层特征引擎 (Teacher)
|
| 18 |
+
from libriichi3p.mjai import Bot as RiichiBot
|
| 19 |
+
from libriichi3p.consts import ACTION_SPACE
|
| 20 |
+
|
| 21 |
+
# 底层特征引擎 (Student)
|
| 22 |
+
try:
|
| 23 |
+
from libriichiSanma import state as sanma_state
|
| 24 |
+
except ImportError:
|
| 25 |
+
import libriichi as sanma_state
|
| 26 |
+
|
| 27 |
+
# ==========================================
# [Configuration & environment variables]
# ==========================================
# Credentials / targets are injected through the Space's environment.
HF_TOKEN = os.environ.get("HF_TOKEN", "")
DATASET_REPO = os.environ.get("DATASET_REPO", "ffzeroHua/tenhou-scc")
URL_LIST_FILE = os.environ.get("URL_LIST_FILE", "urls_better.txt")

# Action vocabulary for the 3-player (sanma) model: 34 tile kinds, the
# three red fives, then non-discard actions. The index order defines the
# label encoding used by FeatureEncoder.action_to_mask.
MASK_3P = [
    "1m", "2m", "3m", "4m", "5m", "6m", "7m", "8m", "9m",
    "1p", "2p", "3p", "4p", "5p", "6p", "7p", "8p", "9p",
    "1s", "2s", "3s", "4s", "5s", "6s", "7s", "8s", "9s",
    "E", "S", "W", "N", "P", "F", "C",
    '5mr', '5pr', '5sr',
    'reach', 'pon', 'kan', 'nukidora', 'hora', 'ryukyoku', 'none'
]

# Frequently used label codes, resolved once at import time.
NONE_CODE = MASK_3P.index('none')
KAN_CODE = MASK_3P.index('kan')
# Per-thread scratch slot used to intercept teacher features
# (written by DummyFeatureEngine.react_batch, read in FeatureEncoder.process_game).
_thread_local = threading.local()

# Mutable progress dict served by the FastAPI "/" endpoint; the worker
# thread mutates it as it mines.
worker_status = {
    "status": "Starting up...",
    "urls_processed": 0,
    "total_chunks_uploaded": 0,
    "total_records_extracted": 0,
    "current_target": "",
    "errors": 0
}
|
| 55 |
+
|
| 56 |
+
# ==========================================
|
| 57 |
+
# [解析器] 保持不变
|
| 58 |
+
# ==========================================
|
| 59 |
+
class TenhouParser:
    """Converts Tenhou's compact JSON replay logs into mjai-style event streams."""

    @staticmethod
    def tile_name(x):
        """Map a Tenhou numeric tile code to its mjai string name.

        51/52/53 are the red fives; codes 1x/2x/3x are man/pin/sou suits;
        4x covers the seven honor tiles E S W N P F C.
        """
        if x in (51, 52, 53): return ['5mr', '5pr', '5sr'][x - 51]
        num, suit = x % 10, x // 10
        if suit in (1, 2, 3): return str(num) + 'mps'[suit - 1]
        if suit == 4: return 'ESWNPFC'[num - 1]
        return '?'  # unknown code; surfaces as a parse artifact downstream

    @classmethod
    def get_meld_tiles(cls, actor, s):
        """Decode a Tenhou meld string (e.g. "p121212") into an mjai call event.

        The letter marks the call type and its position inside the string
        encodes which opponent the tile was taken from; the two-digit groups
        are tile codes. The called tile goes into 'pai', the rest into
        'consumed'.
        """
        i, player = 0, 0
        result = {'pai': [], 'consumed': [], 'actor': actor}
        while i < len(s):
            player += 1
            tile_type = 'consumed'
            if s[i] in 'cpmakf':
                tile_type = 'pai'
                result['type'] = ['chi', 'pon', 'daiminkan', 'ankan', 'kakan', 'nukidora']['cpmakf'.index(s[i])]
                # chi/pon/daiminkan take a discard from another seat; the
                # letter's position identifies the victim relative to actor.
                if s[i] in 'cpm':
                    result['target'] = (4 - player + actor) % 4
                i += 1
            result[tile_type].append(cls.tile_name(int(s[i:i+2])))
            i += 2
        result['pai'] = result['pai'][0]
        # ankan: the "called" tile is also part of the concealed set.
        if result.get('type') == 'ankan': result['consumed'].append(result['pai'])
        return result

    @classmethod
    def parse_events(cls, actor, income, outcome):
        """Expand one player's draw list and discard list into mjai events.

        income: draws (ints) or call strings; outcome: discards (ints, 60 =
        tsumogiri), call strings, or 'r'-prefixed riichi declarations.
        Returns (incoming, outcoming) event queues for this seat.
        """
        incoming, outcoming = [], []
        for i, event in enumerate(income):
            if type(event) is str: incoming.append(cls.get_meld_tiles(actor, event))
            else: incoming.append({'type': 'tsumo', 'pai': cls.tile_name(event), 'actor': actor})
        for i, event in enumerate(outcome):
            if type(event) is str and event[0] != 'r':
                outcoming.append(cls.get_meld_tiles(actor, event))
            else:
                # 0 is a placeholder slot (no discard this turn, e.g. after a kan).
                if event == 0:
                    outcoming.append({'type': 'empty'})
                    continue
                reach = False
                if type(event) is str and event[0] == 'r':
                    # 'r<tile>': riichi declaration followed by its discard.
                    reach, event = True, int(event[1:])
                    outcoming.append({'type': 'reach', 'actor': actor})
                # 60 means "discard the tile just drawn" (tsumogiri).
                outcoming.append({'type': 'dahai', 'pai': cls.tile_name(event if event != 60 else income[i]), 'actor': actor, 'tsumogiri': event == 60})
                if reach: outcoming.append({'type': 'reach_accepted', 'actor': actor})
        return incoming, outcoming

    @classmethod
    def merge_events(cls, oya, events, dora_markers):
        """Interleave the four per-seat event queues into one ordered stream.

        Simulates turn order starting from the dealer (oya), handing the turn
        to a caller whenever another seat's next incoming event claims the
        discarded tile, and inserting 'dora' reveals after kans.
        """
        current, result = oya, []
        def finished(x): return all(len(i[0]) == 0 and len(i[1]) == 0 for i in x)
        while not finished(events):
            income, outcome = events[current]
            nuki = False
            if len(income):
                result.append(income.pop(0))
                if result[-1]['type'] == 'daiminkan':
                    # Open kan: reveal a new dora, drop the placeholder discard
                    # slot, and let the same player act again (replacement draw).
                    result.append({'type': 'dora', 'dora_marker': cls.tile_name(dora_markers.pop(0))})
                    outcome.pop(0)
                    continue
            if len(outcome):
                result.append(outcome.pop(0))
                pai, t = result[-1].get('pai'), result[-1]['type']
                if t == 'reach':
                    # reach is followed by its dahai and reach_accepted.
                    result.append(outcome.pop(0))
                    pai = result[-1].get('pai')
                    result.append(outcome.pop(0))
                nuki = False
                # Does any other seat's next incoming event call this discard?
                # (chi is only legal from the seat to the discarder's left.)
                for actor, x in enumerate(events):
                    if actor == current or len(x[1]) == 0: continue
                    if x[0][0]['type'] != 'tsumo' and x[0][0].get('pai') == pai and not (x[0][0]['type'] == 'chi' and not (x[0][0]['actor'] + 3) % 4 == actor):
                        nuki, current = True, actor
                        break
                if t in ('ankan', 'kakan', 'nukidora'):
                    # Self-calls keep the turn with the same player; closed/added
                    # kans also reveal a dora (nukidora does not).
                    if t != 'nukidora' and len(dora_markers) > 0: result.append({'type': 'dora', 'dora_marker': cls.tile_name(dora_markers.pop(0))})
                    nuki = True
            if not nuki: current = (current + 1) % 4
        return result

    @classmethod
    def parse_single_round(cls, data):
        """Parse one round entry of a Tenhou log into a full mjai event list.

        data layout: [round_info, scores, dora_markers, uradora,
        hand0, draws0, discards0, hand1, ..., result_info].
        """
        round_info, scores, dora_markers, uradora, result_info = data[0], data[1], data[2], data[3], data[-1]
        oya = round_info[0] % 4
        # Defensive: pad a malformed starting hand to 13 tiles.
        patch = lambda arr: arr if len(arr) >= 13 else [0] * 13
        events = [{
            'type': 'start_kyoku', 'bakaze': 'ESWN'[round_info[0] // 4], 'kyoku': oya + 1,
            'honba': round_info[1], 'kyotaku': round_info[2], 'oya': oya,
            'dora_marker': cls.tile_name(dora_markers.pop(0)), 'scores': scores,
            'tehais': [[cls.tile_name(i) for i in patch(data[k])] for k in [4, 7, 10, 13]]
        }]
        e_list = [cls.parse_events(i, data[5+i*3], data[6+i*3]) for i in range(4)]
        events += cls.merge_events(oya, e_list, dora_markers)
        # Close the round: tsumo win, ron win, or exhaustive/abortive draw.
        # ('和了' = win, '九牌' = nine-terminals abort.)
        last_type = events[-1]['type']
        if last_type == 'tsumo' and result_info[0] == '和了': events.append({'type': 'hora', 'actor': events[-1]['actor'], 'target': events[-1]['actor']})
        elif result_info[0] == '和了':
            # Ron: the winner is the seat with a positive score delta.
            actor = next(i for i, x in enumerate(result_info[1]) if x > 0)
            events.append({'type': 'hora', 'actor': actor, 'target': actor})
        elif last_type == 'tsumo' or '九牌' in result_info[0]: events.append({'type': 'ryukyoku', 'actor': events[-1]['actor']})
        return events

    @classmethod
    def parse_log(cls, log):
        """Parse a whole Tenhou log into a list of per-round mjai event lists.

        Each round is prefixed with a start_game event carrying the observed
        seat id ('私' = "me" in the name list, -1 if absent) and uniform
        per-seat weights.
        """
        scores = log.get('sc', [])
        weights = [1.0, 1.0, 1.0]
        seat = log['name'].index('私') if '私' in log['name'] else -1
        parsed_rounds = []
        for i in log['log'][:]:
            round_events = [{"type": "start_game", "id": seat, "weight": weights}] + cls.parse_single_round(i)
            parsed_rounds.append(round_events)
        return parsed_rounds
|
| 171 |
+
|
| 172 |
+
# ==========================================
|
| 173 |
+
# [特征拦截假引擎 (Teacher)]
|
| 174 |
+
# ==========================================
|
| 175 |
+
class DummyFeatureEngine:
    """Stand-in policy engine whose real job is feature interception.

    RiichiBot drives it like a genuine model, but react_batch merely stashes
    the encoded (obs, masks, invisible_obs) triple into thread-local storage
    for the caller to harvest, then answers with the first legal action.
    """

    def __init__(self):
        # Metadata fields the RiichiBot wrapper expects on a real engine.
        self.engine_type = 'mortal'
        self.name = 'DataMiner'
        self.version = 4
        self.is_oracle = False
        self.enable_quick_eval = True
        self.enable_rule_based_agari_guard = True

    def react_batch(self, obs, masks, invisible_obs):
        """Record the batch for interception and return trivial picks.

        Returns (actions, q_values, masks-as-lists, is_greedy flags); the
        chosen action is simply the first legal one (or 0 when none is).
        """
        # Expose the raw teacher features to the driving thread.
        _thread_local.interception = (obs, masks, invisible_obs)

        chosen, q_values, plain_masks = [], [], []
        for mask in masks:
            as_list = mask.tolist() if hasattr(mask, 'tolist') else list(mask)
            plain_masks.append(as_list)
            chosen.append(as_list.index(True) if True in as_list else 0)
            # Flat zero Q-values: this engine never evaluates anything.
            q_values.append([0.0] * len(as_list))
        return chosen, q_values, plain_masks, [True] * len(obs)
|
| 197 |
+
|
| 198 |
+
# ==========================================
|
| 199 |
+
# [双重特征打包架构 (Distillation)]
|
| 200 |
+
# ==========================================
|
| 201 |
+
class FeatureEncoder:
    """Accumulates paired (teacher, student) features and ships them to HF Hub.

    Records buffer in memory, get pickled to disk in chunks of ``chunk_size``
    records, and the local pool directory is batch-uploaded to the dataset
    repo once it holds ``pool_size`` chunk files.
    """

    def __init__(self, chunk_size=2048, pool_size=8):
        # Records per pickle chunk file.
        self.chunk_size = chunk_size
        # Chunk files to accumulate locally before a batch upload.
        self.pool_size = pool_size
        # Parallel buffers: one entry per harvested decision point.
        self.inputs, self.outputs, self.weights = [], [], []
        self.chunk_count = 0
        # Upload path is disabled (None) when no token is configured.
        self.hf_api = HfApi(token=HF_TOKEN) if HF_TOKEN else None

        self.local_pool_dir = "local_chunks_pool"
        os.makedirs(self.local_pool_dir, exist_ok=True)

    @staticmethod
    def action_to_mask(who, action):
        """Encode the event taken by player ``who`` as a MASK_3P label code.

        Returns NONE_CODE when the event belongs to another player or is a
        draw; raises for event types outside the known vocabulary.
        """
        if action is None: return NONE_CODE
        if type(action) is str: action = json.loads(action)
        if action.get('actor') != who or action.get('type') == 'tsumo': return NONE_CODE
        if action['type'] == 'dahai': return MASK_3P.index(action['pai'])
        # All three kan varieties collapse onto the single 'kan' label.
        if action['type'] in ('daiminkan', 'ankan', 'kakan'): return KAN_CODE
        if action['type'] in MASK_3P: return MASK_3P.index(action['type'])
        raise Exception(f"Unknown action map: {action}")

    def save_and_check_upload(self):
        """Flush in-memory buffers to a chunk file; upload if the pool is full."""
        if not self.inputs: return

        filename = f"chunk_distill_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{self.chunk_count}.pkl"
        filepath = os.path.join(self.local_pool_dir, filename)

        with open(filepath, 'wb') as f:
            pickle.dump({'inputs': self.inputs, 'outputs': self.outputs, 'weights': self.weights}, f)

        # NOTE(review): this log line prints a literal "(unknown)" where the
        # chunk path presumably belongs — looks like a lost f-string
        # placeholder; confirm against the original source.
        print(f"📦 已生成蒸馏缓存: (unknown) ({len(self.inputs)} records).")

        self.chunk_count += 1
        self.inputs.clear()
        self.outputs.clear()
        self.weights.clear()

        current_files = os.listdir(self.local_pool_dir)
        if len(current_files) >= self.pool_size:
            self.upload_pool()

    def upload_pool(self):
        """Batch-upload every pooled chunk file to the dataset repo, with retries."""
        current_files = os.listdir(self.local_pool_dir)
        if not current_files or not self.hf_api or not DATASET_REPO: return

        import time
        print(f"🚀 本地池满,正在批量上传 {len(current_files)} 个文件...")

        # Up to 6 attempts with exponential backoff (5s, 10s, 20s, ...).
        for attempt in range(6):
            try:
                self.hf_api.upload_folder(
                    folder_path=self.local_pool_dir,
                    path_in_repo="distill_chunks",
                    repo_id=DATASET_REPO,
                    repo_type="dataset"
                )
                print(f"✅ 上传成功 (Attempt {attempt + 1}).")
                worker_status["total_chunks_uploaded"] += len(current_files)
                # Delete local copies only after a confirmed upload.
                for f in current_files: os.remove(os.path.join(self.local_pool_dir, f))
                break
            except Exception as e:
                wait_time = 5 * (2 ** attempt)
                print(f"⚠️ Upload failed: {e}. Waiting {wait_time}s...")
                time.sleep(wait_time)

    def process_game(self, events):
        """Replay one round's mjai events through both engines, harvesting features.

        At every decision point where the student state says the player can
        act and the teacher saw more than one legal action, stores a record
        pairing student obs/mask with the intercepted teacher obs/mask,
        labelled by the action the player actually took next.
        """
        who = -1
        current_weight = 1.0

        ps_student = None
        bot_teacher = None

        for i, event in enumerate(events):
            if event.get('type') == 'start_game':
                who = event['id']
                weights_list = event.get('weight', [1.0, 1.0, 1.0])
                current_weight = weights_list[who]

                # Initialize the twin state machines (student + teacher).
                ps_student = sanma_state.PlayerState(who)
                bot_teacher = RiichiBot(DummyFeatureEngine(), who)

            if ps_student is None or bot_teacher is None:
                continue

            if event.get('type') == 'end_game':
                continue

            # The label is the next "real" event: dora reveals and reach
            # confirmations are bookkeeping, not decisions.
            next_event = None
            for j in range(i + 1, len(events)):
                if events[j].get('type') not in ('dora', 'reach_accepted'):
                    next_event = events[j]; break

            event_str = json.dumps(event, separators=(",", ":"))

            # --- 1. Teacher update & feature interception ---
            # DummyFeatureEngine writes the encoded batch into _thread_local
            # as a side effect of react().
            _thread_local.interception = None
            bot_teacher.react(event_str)
            intercepted = getattr(_thread_local, 'interception', None)

            # --- 2. Student update & feature generation ---
            cans = ps_student.update(event_str)

            if intercepted is None or not cans.can_act:
                continue

            obs_t, masks_t, _ = intercepted
            obs_s, mask_s = ps_student.encode_obs(4, False)

            # Forced moves (only one legal action) carry no training signal.
            valid_actions_count = int(np.count_nonzero(masks_t[0]))
            if valid_actions_count <= 1:
                continue

            try:
                output_code = self.action_to_mask(who, next_event)

                # Store as dicts to decouple old and new data formats.
                self.inputs.append({
                    "obs_student": obs_s,
                    "mask_student": mask_s,
                    "obs_teacher": obs_t[0],  # strip the batch dimension
                    "mask_teacher": masks_t[0]
                })
                self.outputs.append(output_code)
                self.weights.append(current_weight)

                worker_status["total_records_extracted"] += 1
            except Exception: pass  # best-effort: unmappable actions are skipped

            if len(self.inputs) >= self.chunk_size:
                self.save_and_check_upload()
|
| 332 |
+
|
| 333 |
+
# ==========================================
|
| 334 |
+
# [数据挖掘总管线]
|
| 335 |
+
# ==========================================
|
| 336 |
+
def worker_pipeline():
    """Background mining loop: download Tenhou logs, distill features, upload chunks.

    Fetches the URL list from the HF dataset repo, converts every game log
    to mjai events via TenhouParser, feeds each non-excluded sanma seat
    through a FeatureEncoder, and finally flushes/uploads whatever chunks
    remain. Progress and failures are reported via the module-level
    ``worker_status`` dict served by the web endpoint.
    """
    # Hard requirement: without credentials we can neither read the URL list
    # nor upload results.
    if not HF_TOKEN or not DATASET_REPO:
        worker_status["status"] = "Error: HF_TOKEN or DATASET_REPO missing!"
        return

    worker_status["status"] = "Fetching target URL list..."
    try:
        url_file_path = hf_hub_download(repo_id=DATASET_REPO, filename=URL_LIST_FILE, repo_type="dataset", token=HF_TOKEN)
        with open(url_file_path, 'r') as f:
            target_urls = [line.strip() for line in f if line.strip()]
    except Exception as e:
        worker_status["status"] = f"Failed to fetch {URL_LIST_FILE}: {e}"
        return

    headers = {"User-Agent": "Mozilla/5.0"}
    encoder = FeatureEncoder(chunk_size=2048, pool_size=8)
    worker_status["status"] = "Mining..."

    for url in target_urls:
        worker_status["current_target"] = url
        # Tenhou replay URLs look like ...?log=<id>&tw=<seat>.
        log_match = re.search(r'log=([^&]+)', url)
        tw_match = re.search(r'tw=(\d+)', url)
        if not log_match:
            continue

        # Seat to skip during extraction (-1 = keep all three sanma seats).
        tw = int(tw_match.group(1)) if tw_match else -1
        log_id = log_match.group(1)

        try:
            res = requests.get(f"https://tenhou.net/5/mjlog2json.cgi?{log_id}", headers=headers, timeout=30)
            parsed_games = TenhouParser.parse_log(res.json())

            for game in parsed_games:
                # Extract features from each of the three sanma seats,
                # re-tagging the start_game event with the seat to observe.
                for seat in range(3):
                    if seat == tw:
                        continue
                    game[0]['id'] = seat
                    encoder.process_game(game)

            worker_status["urls_processed"] += 1
        except Exception:
            # Keep mining on per-URL failures, but leave a trace in the
            # Space logs instead of swallowing the error silently.
            worker_status["errors"] += 1
            print(f"⚠️ Failed to process {url}:")
            traceback.print_exc()

    # Flush the partially-filled chunk and force-upload anything pooled.
    encoder.save_and_check_upload()
    encoder.upload_pool()
    worker_status["status"] = "Finished! All URLs processed."
    worker_status["current_target"] = "Idle"
|
| 380 |
+
|
| 381 |
+
# Minimal web layer: the Space stays alive by serving HTTP while the mining
# worker runs in a background daemon thread.
app = FastAPI()

# Root endpoint: report live mining progress as JSON.
@app.get("/")
def read_status(): return worker_status

if __name__ == '__main__':
    # Daemon thread so the server owns the main thread and the miner dies
    # with the process; 7860 is the standard HF Spaces port.
    thread = threading.Thread(target=worker_pipeline, daemon=True)
    thread.start()
    uvicorn.run(app, host="0.0.0.0", port=7860)
|