liumaolin committed on
Commit
e43edbb
·
1 Parent(s): f458b69

feat(api): implement AsyncTrainingManager MVP with SQLite persistence

Browse files

- Add api_server module with adapter pattern architecture
- Implement AsyncTrainingManager using asyncio.subprocess + SQLite
- Add TaskQueueAdapter abstract base class for future server mode
- Create domain models: Task, TaskStatus, ProgressInfo
- Add run_pipeline.py wrapper script for subprocess execution
- Create config module for centralized environment variables
- Add aiosqlite dependency to pyproject.toml
- Include test config files for pipeline validation

The AsyncTrainingManager provides:
- Async task queue with SQLite persistence
- Real-time progress tracking via stdout JSON parsing
- Task cancellation and status querying
- Progress subscription for SSE streaming
- Application restart recovery support

api_server/app/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ """
2
+ GPT-SoVITS 训练 API Server
3
+ """
4
+
5
+ __version__ = "0.1.0"
api_server/app/adapters/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 适配器模块
3
+
4
+ 提供不同环境下的存储、任务队列等适配器实现
5
+ """
6
+
7
+ from .base import TaskQueueAdapter
8
+
9
+ __all__ = ["TaskQueueAdapter"]
api_server/app/adapters/base.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 适配器抽象基类模块
3
+
4
+ 定义任务队列、存储、数据库等适配器的抽象接口
5
+ """
6
+
7
+ from abc import ABC, abstractmethod
8
+ from typing import Dict, Optional, AsyncGenerator
9
+
10
+
11
class TaskQueueAdapter(ABC):
    """Abstract base class for task-queue adapters.

    Declares the queue operations shared by the local backend
    (asyncio.subprocess) and the server backend (Celery).

    Example:
        >>> adapter = AsyncTrainingManager(db_path="./data/tasks.db")
        >>> job_id = await adapter.enqueue("task-123", {"exp_name": "test"})
        >>> status = await adapter.get_status(job_id)
        >>> async for progress in adapter.subscribe_progress("task-123"):
        ...     print(progress)
    """

    @abstractmethod
    async def enqueue(self, task_id: str, config: Dict, priority: str = "normal") -> str:
        """Put a task on the queue.

        Args:
            task_id: unique task identifier.
            config: task configuration dict holding every training parameter.
            priority: task priority ("low", "normal", "high").

        Returns:
            The job id assigned by the queue.

        Raises:
            ValueError: if the configuration is invalid.
        """
        ...

    @abstractmethod
    async def get_status(self, job_id: str) -> Dict:
        """Fetch the current state of a job.

        Args:
            job_id: job identifier.

        Returns:
            A status dict containing:
            - status: task state (queued, running, completed, failed, cancelled)
            - progress: progress value (0.0-1.0)
            - current_stage: name of the stage currently running
            - message: latest status message
            - error_message: error details when the task failed
        """
        ...

    @abstractmethod
    async def cancel(self, job_id: str) -> bool:
        """Cancel a job.

        Args:
            job_id: job identifier.

        Returns:
            True if the job was cancelled, False otherwise.
        """
        ...

    @abstractmethod
    async def subscribe_progress(self, task_id: str) -> AsyncGenerator[Dict, None]:
        """Subscribe to a task's progress stream (intended for SSE).

        Args:
            task_id: task identifier.

        Yields:
            Progress dicts containing:
            - type: message type ("progress", "log", "heartbeat")
            - stage: current stage
            - progress: progress value
            - message: progress message
            - status: state (running, completed, failed, cancelled)

        Note:
            The generator terminates automatically once a terminal
            status is observed.
        """
        ...
95
+
96
+
97
class ProgressAdapter(ABC):
    """Abstract base class for progress-management adapters.

    Used to publish and consume task progress; backed either by an
    in-memory queue (local mode) or Redis Pub/Sub (server mode).
    """

    @abstractmethod
    async def update_progress(self, task_id: str, progress: Dict) -> None:
        """Publish a progress update.

        Args:
            task_id: task identifier.
            progress: progress payload dict.
        """
        ...

    @abstractmethod
    async def get_progress(self, task_id: str) -> Optional[Dict]:
        """Return the latest known progress.

        Args:
            task_id: task identifier.

        Returns:
            The most recent progress payload, or None when unknown.
        """
        ...

    @abstractmethod
    async def subscribe(self, task_id: str) -> AsyncGenerator[Dict, None]:
        """Stream progress updates as they arrive.

        Args:
            task_id: task identifier.

        Yields:
            Progress payload dicts.
        """
        ...
api_server/app/adapters/local/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 本地适配器模块
3
+
4
+ 提供基于 SQLite 和 asyncio.subprocess 的本地实现
5
+ """
6
+
7
+ from .task_queue import AsyncTrainingManager
8
+
9
+ __all__ = ["AsyncTrainingManager"]
api_server/app/adapters/local/task_queue.py ADDED
@@ -0,0 +1,695 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 本地异步任务管理器
3
+
4
+ 基于 asyncio.subprocess + SQLite 的本地任务队列实现。
5
+ 适用于 macOS 本地训练和 Electron 集成场景。
6
+ """
7
+
8
+ import asyncio
9
+ import json
10
+ import os
11
+ import sqlite3
12
+ import sys
13
+ import uuid
14
+ from datetime import datetime
15
+ from pathlib import Path
16
+ from typing import Dict, Optional, AsyncGenerator, List
17
+
18
+ import aiosqlite
19
+
20
+ from ..base import TaskQueueAdapter
21
+ from ...core.config import settings, PROJECT_ROOT, get_pythonpath
22
+
23
# Sentinel markers that frame progress JSON on the child's stdout
# (must stay in sync with run_pipeline.py, which emits them)
PROGRESS_PREFIX = "##PROGRESS##"
PROGRESS_SUFFIX = "##"
26
+
27
+
28
class AsyncTrainingManager(TaskQueueAdapter):
    """
    Asynchronous task manager built on ``asyncio.subprocess``.

    Design:
    1. Training runs in a child process started via asyncio.create_subprocess_exec()
    2. Fully non-blocking, so it composes cleanly with FastAPI's async model
    3. Task state is persisted in SQLite and survives application restarts
    4. Progress is obtained by parsing the child process's stdout in real time

    Example:
        >>> manager = AsyncTrainingManager(db_path="./data/tasks.db")
        >>> job_id = await manager.enqueue("task-123", {"exp_name": "test", ...})
        >>>
        >>> # subscribe to progress
        >>> async for progress in manager.subscribe_progress("task-123"):
        ...     print(f"{progress['stage']}: {progress['progress']*100:.1f}%")
        >>>
        >>> # cancel the task
        >>> await manager.cancel(job_id)
    """

    def __init__(self, db_path: str = None, max_concurrent: int = 1):
        """
        Initialize the task manager.

        Args:
            db_path: SQLite database path; defaults to settings.SQLITE_PATH.
            max_concurrent: maximum number of concurrent tasks (usually 1 locally).
        """
        self.db_path = db_path or str(settings.SQLITE_PATH)
        self.max_concurrent = max_concurrent

        # In-memory runtime state (rebuilt after a restart)
        self.running_processes: Dict[str, asyncio.subprocess.Process] = {}  # task_id -> Process
        self.progress_channels: Dict[str, asyncio.Queue] = {}  # task_id -> Queue
        self._running_count = 0
        self._queue_lock = asyncio.Lock()
        # NOTE(review): max_concurrent and _queue_lock are stored but never
        # consulted anywhere in this class — the concurrency limit appears
        # unenforced; confirm whether that is intentional for the MVP.

        # Create the database schema up front
        self._init_db_sync()

    def _init_db_sync(self) -> None:
        """Synchronously create the SQLite schema (called once at startup)."""
        Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)

        with sqlite3.connect(self.db_path) as conn:
            conn.execute('''
                CREATE TABLE IF NOT EXISTS task_queue (
                    job_id TEXT PRIMARY KEY,
                    task_id TEXT NOT NULL UNIQUE,
                    exp_name TEXT NOT NULL,
                    config TEXT NOT NULL,
                    status TEXT DEFAULT 'queued',
                    current_stage TEXT,
                    progress REAL DEFAULT 0,
                    overall_progress REAL DEFAULT 0,
                    message TEXT,
                    error_message TEXT,
                    created_at TEXT NOT NULL,
                    started_at TEXT,
                    completed_at TEXT
                )
            ''')
            conn.execute('CREATE INDEX IF NOT EXISTS idx_task_queue_status ON task_queue(status)')
            conn.execute('CREATE INDEX IF NOT EXISTS idx_task_queue_task_id ON task_queue(task_id)')
            conn.execute('CREATE INDEX IF NOT EXISTS idx_task_queue_created ON task_queue(created_at)')
            conn.commit()

    async def enqueue(self, task_id: str, config: Dict, priority: str = "normal") -> str:
        """
        Enqueue a task and start it asynchronously.

        Args:
            task_id: unique task identifier. task_id carries a UNIQUE
                constraint, so enqueueing the same id twice raises an
                sqlite integrity error.
            config: task configuration, expected to contain:
                - exp_name: experiment name
                - version: model version
                - stages: list of stage configurations
            priority: priority hint — ignored by this local implementation.

        Returns:
            job_id: the generated job id (UUID4).
        """
        job_id = str(uuid.uuid4())
        exp_name = config.get("exp_name", "unknown")

        # Persist first so the task survives an application restart
        async with aiosqlite.connect(self.db_path) as db:
            await db.execute(
                '''INSERT INTO task_queue
                   (job_id, task_id, exp_name, config, status, created_at)
                   VALUES (?, ?, ?, ?, 'queued', ?)''',
                (job_id, task_id, exp_name, json.dumps(config, ensure_ascii=False),
                 datetime.utcnow().isoformat())
            )
            await db.commit()

        # Per-task in-memory channel for progress subscribers
        self.progress_channels[task_id] = asyncio.Queue()

        # Fire-and-forget: run the pipeline in the background
        asyncio.create_task(self._run_training_async(job_id, task_id, config))

        return job_id

    async def _run_training_async(self, job_id: str, task_id: str, config: Dict) -> None:
        """
        Run the training pipeline as a child process and record its outcome.

        Args:
            job_id: job id
            task_id: task id
            config: task configuration
        """
        config_path = None

        try:
            # Mark the task as running
            await self._update_status(
                job_id,
                status='running',
                started_at=datetime.utcnow().isoformat()
            )
            await self._send_progress(task_id, {
                "type": "progress",
                "status": "running",
                "message": "训练任务启动中...",
                "progress": 0.0,
                "overall_progress": 0.0,
            })

            # Dump the config to a temporary JSON file for the child
            config_path = await self._write_config_file(task_id, config)

            # Locate the run_pipeline.py wrapper script
            script_path = self._get_pipeline_script_path()

            # Child environment: inherit ours, plus the project PYTHONPATH
            env = os.environ.copy()
            env['PYTHONPATH'] = get_pythonpath()

            # Spawn the child process with piped stdout/stderr
            process = await asyncio.create_subprocess_exec(
                sys.executable, script_path,
                '--config', config_path,
                '--task-id', task_id,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
                env=env,
                cwd=str(PROJECT_ROOT),
            )

            self.running_processes[task_id] = process
            self._running_count += 1

            # Stream both pipes until EOF, forwarding progress/log messages
            await self._monitor_process_output(task_id, job_id, process)

            # Pipes are closed; collect the exit code
            returncode = await process.wait()

            if returncode == 0:
                await self._update_status(
                    job_id,
                    status='completed',
                    progress=1.0,
                    overall_progress=1.0,
                    message='训练完成',
                    completed_at=datetime.utcnow().isoformat()
                )
                await self._send_progress(task_id, {
                    "type": "progress",
                    "status": "completed",
                    "message": "训练完成",
                    "progress": 1.0,
                    "overall_progress": 1.0,
                })
            else:
                # Try to read any remaining stderr.
                # NOTE(review): _monitor_process_output has already drained
                # stderr to EOF, so this read() will almost always return
                # b"" and the exit-code fallback message is used — consider
                # collecting stderr lines inside the monitor instead.
                stderr_data = await process.stderr.read()
                error_msg = stderr_data.decode() if stderr_data else f"进程退出码: {returncode}"

                await self._update_status(
                    job_id,
                    status='failed',
                    error_message=error_msg,
                    completed_at=datetime.utcnow().isoformat()
                )
                await self._send_progress(task_id, {
                    "type": "progress",
                    "status": "failed",
                    "message": f"训练失败: {error_msg[:200]}",
                    "error": error_msg,
                })

        except asyncio.CancelledError:
            # Fires when THIS coroutine's asyncio task is cancelled —
            # not when cancel() terminates the child process.
            await self._update_status(
                job_id,
                status='cancelled',
                message='任务已取消',
                completed_at=datetime.utcnow().isoformat()
            )
            await self._send_progress(task_id, {
                "type": "progress",
                "status": "cancelled",
                "message": "任务已取消",
            })

        except Exception as e:
            error_msg = str(e)
            await self._update_status(
                job_id,
                status='failed',
                error_message=error_msg,
                completed_at=datetime.utcnow().isoformat()
            )
            await self._send_progress(task_id, {
                "type": "progress",
                "status": "failed",
                "message": f"任务执行出错: {error_msg}",
                "error": error_msg,
            })

        finally:
            # Drop runtime bookkeeping for this task.
            # NOTE(review): progress_channels entries are never removed here
            # or elsewhere, so completed tasks leak one Queue each.
            self.running_processes.pop(task_id, None)
            self._running_count = max(0, self._running_count - 1)

            # Remove the temporary config file
            if config_path:
                await self._cleanup_config_file(config_path)

    async def _monitor_process_output(
        self,
        task_id: str,
        job_id: str,
        process: asyncio.subprocess.Process
    ) -> None:
        """
        Monitor the child's output streams and extract progress messages.

        Args:
            task_id: task id
            job_id: job id
            process: the child process
        """
        async def read_stdout():
            """Read stdout; parse ##PROGRESS##...## frames, log the rest."""
            while True:
                line = await process.stdout.readline()
                if not line:
                    break

                text = line.decode('utf-8', errors='replace').strip()
                if not text:
                    continue

                # A framed progress message?
                if text.startswith(PROGRESS_PREFIX) and text.endswith(PROGRESS_SUFFIX):
                    json_str = text[len(PROGRESS_PREFIX):-len(PROGRESS_SUFFIX)]
                    try:
                        progress_info = json.loads(json_str)
                        await self._handle_progress(task_id, job_id, progress_info)
                    except json.JSONDecodeError as e:
                        # Malformed frame: demote to a warning log entry
                        await self._send_progress(task_id, {
                            "type": "log",
                            "level": "warning",
                            "message": f"进度解析失败: {e}",
                        })
                else:
                    # Plain output: forward as an info log entry
                    await self._send_progress(task_id, {
                        "type": "log",
                        "level": "info",
                        "message": text,
                    })

        async def read_stderr():
            """Read stderr and forward lines as error log entries."""
            while True:
                line = await process.stderr.readline()
                if not line:
                    break

                text = line.decode('utf-8', errors='replace').strip()
                if text:
                    await self._send_progress(task_id, {
                        "type": "log",
                        "level": "error",
                        "message": text,
                    })

        # Drain stdout and stderr concurrently until both hit EOF
        await asyncio.gather(
            read_stdout(),
            read_stderr(),
            return_exceptions=True
        )

    async def _handle_progress(
        self,
        task_id: str,
        job_id: str,
        progress_info: Dict
    ) -> None:
        """
        Fan a parsed progress frame out to subscribers and the database.

        Args:
            task_id: task id
            job_id: job id
            progress_info: parsed progress payload
        """
        # Notify subscribers first
        await self._send_progress(task_id, progress_info)

        # Mirror the relevant fields into the task_queue row
        updates = {}

        if 'stage' in progress_info:
            updates['current_stage'] = progress_info['stage']
        if 'progress' in progress_info:
            updates['progress'] = progress_info['progress']
        if 'overall_progress' in progress_info:
            updates['overall_progress'] = progress_info['overall_progress']
        if 'message' in progress_info:
            updates['message'] = progress_info['message']
        if 'status' in progress_info:
            updates['status'] = progress_info['status']
        if 'error' in progress_info:
            updates['error_message'] = progress_info['error']

        if updates:
            await self._update_status(job_id, **updates)

    async def _send_progress(self, task_id: str, progress_info: Dict) -> None:
        """
        Push a progress payload onto the task's subscription queue.

        Silently drops the payload when no channel exists for the task.

        Args:
            task_id: task id
            progress_info: progress payload (mutated: a timestamp is added)
        """
        if task_id in self.progress_channels:
            # Stamp the payload if the producer didn't
            if 'timestamp' not in progress_info:
                progress_info['timestamp'] = datetime.utcnow().isoformat()

            await self.progress_channels[task_id].put(progress_info)

    async def _update_status(self, job_id: str, **kwargs) -> None:
        """
        Update columns of the task's row.

        Args:
            job_id: job id
            **kwargs: column -> value pairs to write. Column names are
                interpolated into the SQL, so callers must only pass
                trusted, internal field names (all current call sites do).
        """
        if not kwargs:
            return

        async with aiosqlite.connect(self.db_path) as db:
            updates = []
            values = []

            for key, value in kwargs.items():
                updates.append(f"{key} = ?")
                values.append(value)

            values.append(job_id)

            await db.execute(
                f"UPDATE task_queue SET {', '.join(updates)} WHERE job_id = ?",
                values
            )
            await db.commit()

    async def get_status(self, job_id: str) -> Dict:
        """
        Look up a task's persisted state by job id.

        Args:
            job_id: job id

        Returns:
            The full row as a dict, or a {"status": "not_found", ...}
            sentinel when the job does not exist.
        """
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row
            async with db.execute(
                "SELECT * FROM task_queue WHERE job_id = ?", (job_id,)
            ) as cursor:
                row = await cursor.fetchone()
                if row:
                    return dict(row)

        return {"status": "not_found", "message": "任务不存在"}

    async def get_status_by_task_id(self, task_id: str) -> Dict:
        """
        Look up a task's persisted state by task id.

        Args:
            task_id: task id

        Returns:
            The full row as a dict, or a {"status": "not_found", ...}
            sentinel when the task does not exist.
        """
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row
            async with db.execute(
                "SELECT * FROM task_queue WHERE task_id = ?", (task_id,)
            ) as cursor:
                row = await cursor.fetchone()
                if row:
                    return dict(row)

        return {"status": "not_found", "message": "任务不存在"}

    async def cancel(self, job_id: str) -> bool:
        """
        Cancel a job.

        Args:
            job_id: job id

        Returns:
            True when the job was cancelled; False when it does not
            exist or has already reached a terminal state.
        """
        # Resolve the task_id for the job
        async with aiosqlite.connect(self.db_path) as db:
            async with db.execute(
                "SELECT task_id, status FROM task_queue WHERE job_id = ?", (job_id,)
            ) as cursor:
                row = await cursor.fetchone()
                if not row:
                    return False
                task_id, status = row

        # Already terminal: nothing to cancel
        if status in ('completed', 'failed', 'cancelled'):
            return False

        # Running: terminate the child process
        if task_id in self.running_processes:
            process = self.running_processes[task_id]

            # Try a graceful SIGTERM first
            process.terminate()

            try:
                # Give it a grace period to exit
                await asyncio.wait_for(process.wait(), timeout=5.0)
            except asyncio.TimeoutError:
                # Still alive: force-kill
                process.kill()
                await process.wait()

            # NOTE(review): this branch does not write 'cancelled' to the DB;
            # _run_training_async will then see a nonzero exit code and record
            # the job as 'failed' rather than 'cancelled' — confirm intended.
            return True

        # Not running (still queued): just flip the persisted state
        await self._update_status(
            job_id,
            status='cancelled',
            message='任务已取消',
            completed_at=datetime.utcnow().isoformat()
        )

        # Notify any subscribers
        if task_id in self.progress_channels:
            await self._send_progress(task_id, {
                "type": "progress",
                "status": "cancelled",
                "message": "任务已取消",
            })

        return True

    async def subscribe_progress(self, task_id: str) -> AsyncGenerator[Dict, None]:
        """
        Subscribe to a task's progress (for SSE streaming).

        Args:
            task_id: task id

        Yields:
            Progress payload dicts; a heartbeat dict is emitted after 30s
            of silence. The generator ends once a terminal status arrives
            on the live channel.
        """
        # Make sure a channel exists even if enqueue() hasn't created one yet
        if task_id not in self.progress_channels:
            self.progress_channels[task_id] = asyncio.Queue()

        queue = self.progress_channels[task_id]

        # First emit a snapshot of the persisted state
        status = await self.get_status_by_task_id(task_id)
        if status.get("status") != "not_found":
            yield {
                "type": "progress",
                "status": status.get("status"),
                "stage": status.get("current_stage"),
                "progress": status.get("progress", 0.0),
                "overall_progress": status.get("overall_progress", 0.0),
                "message": status.get("message"),
                "timestamp": datetime.utcnow().isoformat(),
            }
        # NOTE(review): if the snapshot already shows a terminal status, the
        # loop below still waits for live updates that will never come and
        # heartbeats forever — confirm callers close the stream themselves.

        # Then stream live updates
        while True:
            try:
                # 30s of silence -> heartbeat to keep the SSE connection open
                progress = await asyncio.wait_for(queue.get(), timeout=30.0)
                yield progress

                # Stop on a terminal state
                if progress.get('status') in ('completed', 'failed', 'cancelled'):
                    break

            except asyncio.TimeoutError:
                # Keep-alive for the client
                yield {
                    "type": "heartbeat",
                    "timestamp": datetime.utcnow().isoformat(),
                }

    async def list_tasks(
        self,
        status: Optional[str] = None,
        limit: int = 50,
        offset: int = 0
    ) -> List[Dict]:
        """
        List tasks, newest first.

        Args:
            status: optional status filter
            limit: page size
            offset: page offset

        Returns:
            List of task rows as dicts.
        """
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row

            if status:
                query = """
                    SELECT * FROM task_queue
                    WHERE status = ?
                    ORDER BY created_at DESC
                    LIMIT ? OFFSET ?
                """
                params = (status, limit, offset)
            else:
                query = """
                    SELECT * FROM task_queue
                    ORDER BY created_at DESC
                    LIMIT ? OFFSET ?
                """
                params = (limit, offset)

            async with db.execute(query, params) as cursor:
                rows = await cursor.fetchall()
                return [dict(row) for row in rows]

    async def recover_pending_tasks(self) -> int:
        """
        Recover unfinished tasks after an application restart.

        Marks tasks that were 'running' as 'interrupted', then restarts
        every task still in the 'queued' state.

        Returns:
            Number of queued tasks that were restarted.
        """
        async with aiosqlite.connect(self.db_path) as db:
            # A restart killed any running child, so flag those rows
            await db.execute(
                """UPDATE task_queue
                   SET status = 'interrupted',
                       message = '应用重启导致任务中断'
                   WHERE status = 'running'"""
            )
            await db.commit()

            # Collect tasks that never started
            db.row_factory = aiosqlite.Row
            async with db.execute(
                "SELECT * FROM task_queue WHERE status = 'queued' ORDER BY created_at"
            ) as cursor:
                queued_tasks = await cursor.fetchall()

        # Relaunch queued tasks in submission order
        recovered = 0
        for task in queued_tasks:
            task_id = task['task_id']
            job_id = task['job_id']
            config = json.loads(task['config'])

            self.progress_channels[task_id] = asyncio.Queue()
            asyncio.create_task(self._run_training_async(job_id, task_id, config))
            recovered += 1

        return recovered

    async def cleanup_old_tasks(self, days: int = 7) -> int:
        """
        Delete old terminal-state task rows.

        Args:
            days: retention window in days

        Returns:
            Number of rows deleted.
        """
        from datetime import timedelta

        cutoff = (datetime.utcnow() - timedelta(days=days)).isoformat()

        async with aiosqlite.connect(self.db_path) as db:
            cursor = await db.execute(
                """DELETE FROM task_queue
                   WHERE status IN ('completed', 'failed', 'cancelled')
                   AND completed_at < ?""",
                (cutoff,)
            )
            deleted = cursor.rowcount
            await db.commit()

        return deleted

    def _get_pipeline_script_path(self) -> str:
        """Return the path of the run_pipeline.py wrapper script."""
        return str(settings.PIPELINE_SCRIPT_PATH)

    async def _write_config_file(self, task_id: str, config: Dict) -> str:
        """
        Write the task config to a temporary JSON file.

        Args:
            task_id: task id (used as the file name)
            config: configuration dict

        Returns:
            Path of the written config file.
        """
        config_path = settings.CONFIGS_DIR / f"{task_id}.json"

        with open(config_path, 'w', encoding='utf-8') as f:
            json.dump(config, f, ensure_ascii=False, indent=2)

        return str(config_path)

    async def _cleanup_config_file(self, config_path: str) -> None:
        """
        Remove a temporary config file, ignoring any failure.

        Args:
            config_path: path of the config file
        """
        try:
            path = Path(config_path)
            if path.exists():
                path.unlink()
        except Exception:
            pass  # best-effort cleanup; deliberately ignore errors
api_server/app/core/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 核心模块
3
+
4
+ 包含配置、枚举等核心组件
5
+ """
6
+
7
+ from .config import settings, PROJECT_ROOT, API_SERVER_ROOT
8
+
9
+ __all__ = ["settings", "PROJECT_ROOT", "API_SERVER_ROOT"]
api_server/app/core/config.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 环境变量和配置模块
3
+
4
+ 统一管理项目路径、环境配置等
5
+ """
6
+
7
+ import os
8
+ from pathlib import Path
9
+ from typing import Literal
10
+
11
+ # ============================================================
12
+ # 路径常量
13
+ # ============================================================
14
+
15
+ USER_HOME_ROOT = Path.home()
16
+
17
+ # api_server/app/core/config.py -> api_server/app/core -> api_server/app -> api_server -> 项目根目录
18
+ API_SERVER_ROOT = Path(__file__).parent.parent.parent.resolve()
19
+ PROJECT_ROOT = API_SERVER_ROOT.parent.resolve()
20
+
21
+ # GPT_SoVITS 模块路径
22
+ GPT_SOVITS_ROOT = PROJECT_ROOT / "GPT_SoVITS"
23
+
24
+ # 默认数据目录
25
+ DEFAULT_DATA_DIR = USER_HOME_ROOT / '.moyoyo-tts' / "data"
26
+
27
+ # 预训练模型目录
28
+ PRETRAINED_MODELS_DIR = GPT_SOVITS_ROOT / "pretrained_models"
29
+
30
+ # 日志目录
31
+ LOGS_DIR = PROJECT_ROOT / "logs"
32
+
33
+
34
# ============================================================
# Configuration class
# ============================================================

class Settings:
    """
    API server configuration.

    Values are read from environment variables once, at import time,
    with sensible defaults.

    Example:
        >>> from api_server.app.core.config import settings
        >>> print(settings.PROJECT_ROOT)
        >>> print(settings.DEPLOYMENT_MODE)
    """

    # Deployment mode: "local" (asyncio.subprocess) or "server" (Celery).
    # NOTE(review): the env value is not validated against the Literal —
    # an unexpected DEPLOYMENT_MODE string is accepted silently.
    DEPLOYMENT_MODE: Literal["local", "server"] = os.getenv("DEPLOYMENT_MODE", "local")

    # API settings
    API_V1_PREFIX: str = os.getenv("API_V1_PREFIX", "/api/v1")
    API_HOST: str = os.getenv("API_HOST", "0.0.0.0")
    API_PORT: int = int(os.getenv("API_PORT", "8000"))

    # Path settings (overridable via environment variables).
    # These class attributes deliberately shadow the module-level
    # PROJECT_ROOT / API_SERVER_ROOT constants of the same names.
    PROJECT_ROOT: Path = Path(os.getenv("PROJECT_ROOT", str(PROJECT_ROOT)))
    API_SERVER_ROOT: Path = Path(os.getenv("API_SERVER_ROOT", str(API_SERVER_ROOT)))
    DATA_DIR: Path = Path(os.getenv("DATA_DIR", str(DEFAULT_DATA_DIR)))

    # SQLite database file used by the local task queue
    SQLITE_PATH: Path = Path(os.getenv("SQLITE_PATH", str(DEFAULT_DATA_DIR / "tasks.db")))

    # Task settings
    LOCAL_MAX_WORKERS: int = int(os.getenv("LOCAL_MAX_WORKERS", "1"))

    # Pretrained model paths
    BERT_PRETRAINED_DIR: str = os.getenv(
        "BERT_PRETRAINED_DIR",
        str(PRETRAINED_MODELS_DIR / "chinese-roberta-wwm-ext-large")
    )
    SSL_PRETRAINED_DIR: str = os.getenv(
        "SSL_PRETRAINED_DIR",
        str(PRETRAINED_MODELS_DIR / "chinese-hubert-base")
    )
    PRETRAINED_S2G: str = os.getenv(
        "PRETRAINED_S2G",
        str(PRETRAINED_MODELS_DIR / "gsv-v2final-pretrained" / "s2G2333k.pth")
    )
    PRETRAINED_S2D: str = os.getenv(
        "PRETRAINED_S2D",
        str(PRETRAINED_MODELS_DIR / "gsv-v2final-pretrained" / "s2D2333k.pth")
    )
    PRETRAINED_S1: str = os.getenv(
        "PRETRAINED_S1",
        str(PRETRAINED_MODELS_DIR / "gsv-v2final-pretrained" / "s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt")
    )

    # Pipeline script path
    @property
    def PIPELINE_SCRIPT_PATH(self) -> Path:
        """Path of the run_pipeline.py wrapper script."""
        return self.API_SERVER_ROOT / "app" / "scripts" / "run_pipeline.py"

    # Temporary config-file directory
    @property
    def CONFIGS_DIR(self) -> Path:
        """Temporary config-file directory; created on first access (side effect)."""
        path = self.DATA_DIR / "configs"
        path.mkdir(parents=True, exist_ok=True)
        return path

    def __repr__(self) -> str:
        return (
            f"Settings(\n"
            f"  DEPLOYMENT_MODE={self.DEPLOYMENT_MODE!r},\n"
            f"  PROJECT_ROOT={self.PROJECT_ROOT},\n"
            f"  API_SERVER_ROOT={self.API_SERVER_ROOT},\n"
            f"  DATA_DIR={self.DATA_DIR},\n"
            f"  SQLITE_PATH={self.SQLITE_PATH},\n"
            f")"
        )


# Global configuration singleton
settings = Settings()
119
+
120
+
121
def get_pythonpath() -> str:
    """Return the PYTHONPATH value for spawned training subprocesses.

    Joins the project root and the GPT_SoVITS package root with the
    platform path separator so child processes can import both trees.

    Returns:
        The PYTHONPATH string.
    """
    return os.pathsep.join((str(PROJECT_ROOT), str(GPT_SOVITS_ROOT)))
135
+
136
+
137
def ensure_data_dirs() -> None:
    """Create the data directories used by the API server if missing."""
    for directory in (settings.DATA_DIR, settings.CONFIGS_DIR):
        directory.mkdir(parents=True, exist_ok=True)
api_server/app/models/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 模型模块
3
+
4
+ 包含领域模型和 Pydantic Schema
5
+ """
6
+
7
+ from .domain import Task, TaskStatus, ProgressInfo
8
+
9
+ __all__ = ["Task", "TaskStatus", "ProgressInfo"]
api_server/app/models/domain.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 领域模型模块
3
+
4
+ 定义训练任务相关的核心数据结构
5
+ """
6
+
7
+ from dataclasses import dataclass, field
8
+ from datetime import datetime
9
+ from enum import Enum
10
+ from typing import Dict, Optional, Any
11
+
12
+
13
class TaskStatus(Enum):
    """Lifecycle states of a training task."""

    QUEUED = "queued"            # enqueued, waiting to run
    RUNNING = "running"          # currently executing
    COMPLETED = "completed"      # finished successfully
    FAILED = "failed"            # finished with an error
    CANCELLED = "cancelled"      # cancelled by the user
    INTERRUPTED = "interrupted"  # was running when the application restarted
21
+
22
+
23
@dataclass
class Task:
    """
    Domain model for a training task.

    Attributes:
        id: unique task identifier.
        job_id: queue job id (assigned by the task queue).
        exp_name: experiment name.
        status: current task status.
        config: task configuration (all training parameters).
        current_stage: stage currently being executed.
        progress: overall progress in [0.0, 1.0].
        stage_progress: progress of the current stage in [0.0, 1.0].
        message: latest status message.
        error_message: error details when the task failed.
        created_at: creation timestamp.
        started_at: execution start timestamp.
        completed_at: completion timestamp.

    Example:
        >>> task = Task(
        ...     id="task-123",
        ...     exp_name="my_voice",
        ...     config={"version": "v2", "batch_size": 4}
        ... )
        >>> task.status
        <TaskStatus.QUEUED: 'queued'>
    """
    id: str
    exp_name: str
    config: Dict[str, Any]
    job_id: Optional[str] = None
    status: TaskStatus = TaskStatus.QUEUED
    current_stage: Optional[str] = None
    progress: float = 0.0
    stage_progress: float = 0.0
    message: Optional[str] = None
    error_message: Optional[str] = None
    created_at: datetime = field(default_factory=datetime.utcnow)
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict (datetimes as ISO strings, status as its value)."""
        def iso(ts: Optional[datetime]) -> Optional[str]:
            # Truthiness check mirrors "ts.isoformat() if ts else None"
            return ts.isoformat() if ts else None

        return {
            "id": self.id,
            "job_id": self.job_id,
            "exp_name": self.exp_name,
            "status": self.status.value,
            "config": self.config,
            "current_stage": self.current_stage,
            "progress": self.progress,
            "stage_progress": self.stage_progress,
            "message": self.message,
            "error_message": self.error_message,
            "created_at": iso(self.created_at),
            "started_at": iso(self.started_at),
            "completed_at": iso(self.completed_at),
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Task":
        """Build a Task from a plain dict (inverse of to_dict)."""
        # Accept the status either as a string or as an enum member
        raw_status = data.get("status", "queued")
        status = TaskStatus(raw_status) if isinstance(raw_status, str) else raw_status

        def as_datetime(value):
            # Pass through None and datetime; parse ISO strings
            if value is None or isinstance(value, datetime):
                return value
            return datetime.fromisoformat(value)

        return cls(
            id=data["id"],
            job_id=data.get("job_id"),
            exp_name=data["exp_name"],
            status=status,
            config=data.get("config", {}),
            current_stage=data.get("current_stage"),
            progress=data.get("progress", 0.0),
            stage_progress=data.get("stage_progress", 0.0),
            message=data.get("message"),
            error_message=data.get("error_message"),
            created_at=as_datetime(data.get("created_at")),
            started_at=as_datetime(data.get("started_at")),
            completed_at=as_datetime(data.get("completed_at")),
        )
115
+
116
+
117
@dataclass
class ProgressInfo:
    """
    Progress payload exchanged between the child process and the manager.

    Attributes:
        type: message type ("progress", "log", "error", "heartbeat").
        stage: current stage name.
        stage_index: index of the current stage.
        total_stages: total number of stages.
        progress: progress within the current stage, in [0.0, 1.0].
        overall_progress: overall progress, in [0.0, 1.0].
        message: human-readable progress message.
        status: task status string.
        data: extra payload.
    """
    type: str = "progress"
    stage: Optional[str] = None
    stage_index: Optional[int] = None
    total_stages: Optional[int] = None
    progress: float = 0.0
    overall_progress: float = 0.0
    message: Optional[str] = None
    status: Optional[str] = None
    data: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict (keys in declaration order)."""
        keys = ("type", "stage", "stage_index", "total_stages", "progress",
                "overall_progress", "message", "status", "data")
        return {key: getattr(self, key) for key in keys}

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ProgressInfo":
        """Build a ProgressInfo from a plain dict, applying field defaults."""
        return cls(
            type=data.get("type", "progress"),
            stage=data.get("stage"),
            stage_index=data.get("stage_index"),
            total_stages=data.get("total_stages"),
            progress=data.get("progress", 0.0),
            overall_progress=data.get("overall_progress", 0.0),
            message=data.get("message"),
            status=data.get("status"),
            data=data.get("data", {}),
        )
api_server/app/scripts/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ """
2
+ 脚本模块
3
+
4
+ 包含用于子进程执行的独立脚本
5
+ """
api_server/app/scripts/run_pipeline.py ADDED
@@ -0,0 +1,368 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Pipeline 包装脚本
4
+
5
+ 此脚本作为独立子进程运行,执行 TrainingPipeline 并将进度以 JSON 格式输出到 stdout。
6
+ 主进程(AsyncTrainingManager)通过解析 stdout 来获取实时进度。
7
+
8
+ 进度消息格式:
9
+ ##PROGRESS##{"type": "progress", "stage": "...", ...}##
10
+
11
+ Usage:
12
+ python run_pipeline.py --config /path/to/config.json --task-id task-123
13
+ """
14
+
15
+ import argparse
16
+ import json
17
+ import sys
18
+ import os
19
+ import traceback
20
+ from datetime import datetime
21
+ from typing import Dict, Any
22
+
23
+ # 确保可以导入项目模块(在导入其他模块之前)
24
+ from pathlib import Path
25
+ _SCRIPT_DIR = Path(__file__).parent.resolve()
26
+ _API_SERVER_ROOT = _SCRIPT_DIR.parent.parent
27
+ _PROJECT_ROOT = _API_SERVER_ROOT.parent
28
+ sys.path.insert(0, str(_PROJECT_ROOT))
29
+
30
+ # 导入配置模块
31
+ from api_server.app.core.config import settings, PROJECT_ROOT, get_pythonpath
32
+
33
+
34
# Markers framing each JSON progress line so the parent process
# (AsyncTrainingManager) can extract payloads from otherwise noisy stdout.
PROGRESS_PREFIX = "##PROGRESS##"
PROGRESS_SUFFIX = "##"


def emit_progress(progress_info: Dict[str, Any]) -> None:
    """
    Write a progress payload to stdout as a framed JSON line.

    The line has the form ``##PROGRESS##{json}##`` and is flushed
    immediately so the parent process sees updates in real time.

    Args:
        progress_info: Progress fields; an ISO-8601 UTC ``timestamp``
            is added when the caller did not supply one.
    """
    from datetime import timezone  # local import: module only imports `datetime`

    # datetime.utcnow() is deprecated and returns a naive datetime;
    # emit an explicitly timezone-aware UTC timestamp instead.
    if "timestamp" not in progress_info:
        progress_info["timestamp"] = datetime.now(timezone.utc).isoformat()

    json_str = json.dumps(progress_info, ensure_ascii=False)
    print(f"{PROGRESS_PREFIX}{json_str}{PROGRESS_SUFFIX}", flush=True)
52
+
53
+
54
def emit_log(level: str, message: str, **extra) -> None:
    """
    Emit a log-type progress message to stdout.

    Args:
        level: Log level ("info", "warning", "error").
        message: Human-readable log text.
        **extra: Additional fields merged into the payload.
    """
    payload = {"type": "log", "level": level, "message": message}
    payload.update(extra)
    emit_progress(payload)
69
+
70
+
71
def load_config(config_path: str) -> Dict[str, Any]:
    """
    Read and parse a JSON configuration file.

    Args:
        config_path: Path to the JSON config file.

    Returns:
        The parsed configuration dictionary.
    """
    raw = Path(config_path).read_text(encoding="utf-8")
    return json.loads(raw)
83
+
84
+
85
def build_pipeline(config: Dict[str, Any]):
    """
    Construct a TrainingPipeline from a configuration dictionary.

    Args:
        config: Configuration containing:
            - exp_name: experiment name (required)
            - version: model version (string or ModelVersion, default "v2")
            - stages: ordered list of stage configs, each with a "type" key
            - per-stage parameters (see the builders below for defaults)

    Returns:
        A TrainingPipeline with the requested stages added in order.
    """
    from training_pipeline import (
        TrainingPipeline,
        ModelVersion,
        # config classes
        AudioSliceConfig,
        ASRConfig,
        DenoiseConfig,
        FeatureExtractionConfig,
        SoVITSTrainConfig,
        GPTTrainConfig,
        InferenceConfig,
        # stage classes
        AudioSliceStage,
        ASRStage,
        DenoiseStage,
        TextFeatureStage,
        HuBERTFeatureStage,
        SemanticTokenStage,
        SoVITSTrainStage,
        GPTTrainStage,
        InferenceStage,
    )

    pipeline = TrainingPipeline()

    exp_name = config["exp_name"]
    version_str = config.get("version", "v2")
    version = ModelVersion(version_str) if isinstance(version_str, str) else version_str

    # Parameters shared by every stage config.
    base_params = {
        "exp_name": exp_name,
        "exp_root": config.get("exp_root", "logs"),
        "gpu_numbers": config.get("gpu_numbers", "0"),
        "is_half": config.get("is_half", True),
    }

    # The three feature-extraction stages take an identically-shaped config;
    # build it in one place so the defaults cannot drift apart.
    def _feature_config(cfg: Dict[str, Any]) -> "FeatureExtractionConfig":
        return FeatureExtractionConfig(
            **base_params,
            version=version,
            bert_pretrained_dir=cfg.get(
                "bert_pretrained_dir",
                "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"),
            ssl_pretrained_dir=cfg.get(
                "ssl_pretrained_dir",
                "GPT_SoVITS/pretrained_models/chinese-hubert-base"),
            pretrained_s2G=cfg.get(
                "pretrained_s2G",
                "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth"),
        )

    # Maps stage-type names to builders that turn a stage config dict
    # into a configured stage instance.
    stage_builders = {
        "audio_slice": lambda cfg: AudioSliceStage(AudioSliceConfig(
            **base_params,
            input_path=cfg.get("input_path", ""),
            threshold=cfg.get("threshold", -34),
            min_length=cfg.get("min_length", 4000),
            min_interval=cfg.get("min_interval", 300),
            hop_size=cfg.get("hop_size", 10),
            max_sil_kept=cfg.get("max_sil_kept", 500),
            max_amp=cfg.get("max_amp", 0.9),
            alpha=cfg.get("alpha", 0.25),
            n_parts=cfg.get("n_parts", 4),
        )),

        "asr": lambda cfg: ASRStage(ASRConfig(
            **base_params,
            model=cfg.get("model", "达摩 ASR (中文)"),
            model_size=cfg.get("model_size", "large"),
            language=cfg.get("language", "zh"),
            precision=cfg.get("precision", "float32"),
        )),

        "denoise": lambda cfg: DenoiseStage(DenoiseConfig(
            **base_params,
            input_dir=cfg.get("input_dir", ""),
            output_dir=cfg.get("output_dir", "output/denoise_opt"),
        )),

        "text_feature": lambda cfg: TextFeatureStage(_feature_config(cfg)),

        "hubert_feature": lambda cfg: HuBERTFeatureStage(_feature_config(cfg)),

        "semantic_token": lambda cfg: SemanticTokenStage(_feature_config(cfg)),

        "sovits_train": lambda cfg: SoVITSTrainStage(SoVITSTrainConfig(
            **base_params,
            version=version,
            batch_size=cfg.get("batch_size", 4),
            total_epoch=cfg.get("total_epoch", 8),
            text_low_lr_rate=cfg.get("text_low_lr_rate", 0.4),
            save_every_epoch=cfg.get("save_every_epoch", 4),
            if_save_latest=cfg.get("if_save_latest", True),
            if_save_every_weights=cfg.get("if_save_every_weights", True),
            pretrained_s2G=cfg.get(
                "pretrained_s2G",
                "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth"),
            pretrained_s2D=cfg.get(
                "pretrained_s2D",
                "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2D2333k.pth"),
            if_grad_ckpt=cfg.get("if_grad_ckpt", False),
            lora_rank=cfg.get("lora_rank", 32),
        )),

        "gpt_train": lambda cfg: GPTTrainStage(GPTTrainConfig(
            **base_params,
            version=version,
            batch_size=cfg.get("batch_size", 4),
            total_epoch=cfg.get("total_epoch", 15),
            save_every_epoch=cfg.get("save_every_epoch", 5),
            if_save_latest=cfg.get("if_save_latest", True),
            if_save_every_weights=cfg.get("if_save_every_weights", True),
            if_dpo=cfg.get("if_dpo", False),
            pretrained_s1=cfg.get(
                "pretrained_s1",
                "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt"),
        )),

        "inference": lambda cfg: InferenceStage(InferenceConfig(
            **base_params,
            version=version,
            gpt_path=cfg.get("gpt_path", ""),
            sovits_path=cfg.get("sovits_path", ""),
            ref_text=cfg.get("ref_text", ""),
            ref_audio_path=cfg.get("ref_audio_path", ""),
            target_text=cfg.get("target_text", ""),
            text_split_method=cfg.get("text_split_method", "cut1"),
        )),
    }

    # Add the requested stages in the order given by the config.
    for stage_config in config.get("stages", []):
        stage_type = stage_config.get("type")
        builder = stage_builders.get(stage_type)
        if builder is None:
            emit_log("warning", f"未知阶段类型: {stage_type}")
            continue
        stage = builder(stage_config)
        pipeline.add_stage(stage)
        emit_log("info", f"已添加阶段: {stage.name}")

    return pipeline
252
+
253
+
254
def run_pipeline(config: Dict[str, Any], task_id: str) -> bool:
    """
    Run the configured training pipeline, streaming progress to stdout.

    Args:
        config: Pipeline configuration dictionary.
        task_id: Task identifier echoed in every progress message.

    Returns:
        True when the pipeline completed successfully, False otherwise.
    """
    emit_progress({
        "type": "progress",
        "status": "running",
        "message": "正在初始化训练流水线...",
        "task_id": task_id,
        "progress": 0.0,
        "overall_progress": 0.0,
    })

    try:
        pipeline = build_pipeline(config)
        stages = pipeline.get_stages()

        # Guard: a pipeline with no stages is a configuration error.
        if not stages:
            emit_progress({
                "type": "progress",
                "status": "failed",
                "message": "没有配置任何训练阶段",
                "task_id": task_id,
            })
            return False

        emit_log("info", f"训练流水线已初始化,共 {len(stages)} 个阶段")

        for update in pipeline.run():
            # Relay each pipeline update in the wire format the parent expects.
            emit_progress({
                "type": "progress",
                "status": "running",
                "stage": update.get("stage"),
                "stage_index": update.get("stage_index"),
                "total_stages": update.get("total_stages"),
                "progress": update.get("progress", 0.0),
                "overall_progress": update.get("overall_progress", 0.0),
                "message": update.get("message"),
                "task_id": task_id,
                "data": update.get("data", {}),
            })

            # A failed stage aborts the run with an explicit failure message.
            if update.get("status") != "failed":
                continue
            emit_progress({
                "type": "progress",
                "status": "failed",
                "stage": update.get("stage"),
                "message": update.get("message", "阶段执行失败"),
                "task_id": task_id,
            })
            return False

        emit_progress({
            "type": "progress",
            "status": "completed",
            "message": "训练流水线执行完成",
            "task_id": task_id,
            "progress": 1.0,
            "overall_progress": 1.0,
        })
        return True

    except Exception as exc:
        emit_progress({
            "type": "progress",
            "status": "failed",
            "message": f"执行出错: {exc}",
            "error": str(exc),
            "traceback": traceback.format_exc(),
            "task_id": task_id,
        })
        return False
339
+
340
+
341
def main():
    """CLI entry point: parse arguments, load the config, run the pipeline."""
    parser = argparse.ArgumentParser(description="执行 GPT-SoVITS 训练流水线")
    parser.add_argument("--config", required=True, help="配置文件路径 (JSON)")
    parser.add_argument("--task-id", required=True, help="任务ID")
    args = parser.parse_args()

    emit_log("info", f"启动训练任务: {args.task_id}")
    emit_log("info", f"配置文件: {args.config}")

    try:
        config = load_config(args.config)
    except Exception as exc:
        # Surface the failure as a progress message before exiting non-zero,
        # so the parent process can record a reason for the failed task.
        emit_progress({
            "type": "progress",
            "status": "failed",
            "message": f"加载配置文件失败: {exc}",
            "task_id": args.task_id,
        })
        sys.exit(1)

    ok = run_pipeline(config, args.task_id)
    sys.exit(0 if ok else 1)


if __name__ == "__main__":
    main()