liumaolin
committed on
Commit
·
e43edbb
1
Parent(s):
f458b69
feat(api): implement AsyncTrainingManager MVP with SQLite persistence
Browse files- Add api_server module with adapter pattern architecture
- Implement AsyncTrainingManager using asyncio.subprocess + SQLite
- Add TaskQueueAdapter abstract base class for future server mode
- Create domain models: Task, TaskStatus, ProgressInfo
- Add run_pipeline.py wrapper script for subprocess execution
- Create config module for centralized environment variables
- Add aiosqlite dependency to pyproject.toml
- Include test config files for pipeline validation
The AsyncTrainingManager provides:
- Async task queue with SQLite persistence
- Real-time progress tracking via stdout JSON parsing
- Task cancellation and status querying
- Progress subscription for SSE streaming
- Application restart recovery support
- api_server/app/__init__.py +5 -0
- api_server/app/adapters/__init__.py +9 -0
- api_server/app/adapters/base.py +140 -0
- api_server/app/adapters/local/__init__.py +9 -0
- api_server/app/adapters/local/task_queue.py +695 -0
- api_server/app/core/__init__.py +9 -0
- api_server/app/core/config.py +142 -0
- api_server/app/models/__init__.py +9 -0
- api_server/app/models/domain.py +172 -0
- api_server/app/scripts/__init__.py +5 -0
- api_server/app/scripts/run_pipeline.py +368 -0
api_server/app/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
GPT-SoVITS 训练 API Server
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
__version__ = "0.1.0"
|
api_server/app/adapters/__init__.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
适配器模块
|
| 3 |
+
|
| 4 |
+
提供不同环境下的存储、任务队列等适配器实现
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from .base import TaskQueueAdapter
|
| 8 |
+
|
| 9 |
+
__all__ = ["TaskQueueAdapter"]
|
api_server/app/adapters/base.py
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
适配器抽象基类模块
|
| 3 |
+
|
| 4 |
+
定义任务队列、存储、数据库等适配器的抽象接口
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from abc import ABC, abstractmethod
|
| 8 |
+
from typing import Dict, Optional, AsyncGenerator
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class TaskQueueAdapter(ABC):
    """
    Abstract base class for task-queue adapters.

    Declares the queue operations shared by every backend so that the
    local implementation (asyncio.subprocess) and a server implementation
    (Celery) remain interchangeable behind one interface.

    Example:
        >>> adapter = AsyncTrainingManager(db_path="./data/tasks.db")
        >>> job_id = await adapter.enqueue("task-123", {"exp_name": "test"})
        >>> status = await adapter.get_status(job_id)
        >>> async for progress in adapter.subscribe_progress("task-123"):
        ...     print(progress)
    """

    @abstractmethod
    async def enqueue(self, task_id: str, config: Dict, priority: str = "normal") -> str:
        """
        Put a task onto the queue.

        Args:
            task_id: Unique identifier of the task.
            config: Configuration dict carrying every parameter the
                training run needs.
            priority: Scheduling priority ("low", "normal", "high").

        Returns:
            The job id assigned inside the queue.

        Raises:
            ValueError: If the configuration is invalid.
        """
        ...

    @abstractmethod
    async def get_status(self, job_id: str) -> Dict:
        """
        Look up the state of a job.

        Args:
            job_id: Job identifier returned by :meth:`enqueue`.

        Returns:
            A status dict containing:
            - status: one of queued, running, completed, failed, cancelled
            - progress: fraction in [0.0, 1.0]
            - current_stage: name of the stage in flight
            - message: human-readable status message
            - error_message: failure details, when applicable
        """
        ...

    @abstractmethod
    async def cancel(self, job_id: str) -> bool:
        """
        Request cancellation of a job.

        Args:
            job_id: Job identifier.

        Returns:
            True when the cancellation took effect, False otherwise.
        """
        ...

    @abstractmethod
    async def subscribe_progress(self, task_id: str) -> AsyncGenerator[Dict, None]:
        """
        Stream progress updates for a task (intended for SSE).

        Args:
            task_id: Task identifier.

        Yields:
            Progress dicts containing:
            - type: message kind ("progress", "log", "heartbeat")
            - stage: stage currently running
            - progress: progress value
            - message: progress message
            - status: running, completed, failed or cancelled

        Note:
            The generator terminates on its own once a terminal status
            is observed.
        """
        ...
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
class ProgressAdapter(ABC):
    """
    Abstract base class for progress-management adapters.

    Used to publish and subscribe to task progress; backed locally by an
    in-memory queue and on the server by Redis Pub/Sub.
    """

    @abstractmethod
    async def update_progress(self, task_id: str, progress: Dict) -> None:
        """
        Publish a progress update.

        Args:
            task_id: Task identifier.
            progress: Progress payload dict.
        """
        ...

    @abstractmethod
    async def get_progress(self, task_id: str) -> Optional[Dict]:
        """
        Fetch the most recent progress snapshot.

        Args:
            task_id: Task identifier.

        Returns:
            The latest progress payload, or None when nothing was recorded.
        """
        ...

    @abstractmethod
    async def subscribe(self, task_id: str) -> AsyncGenerator[Dict, None]:
        """
        Subscribe to a task's progress stream.

        Args:
            task_id: Task identifier.

        Yields:
            Progress payload dicts as they arrive.
        """
        ...
|
api_server/app/adapters/local/__init__.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
本地适配器模块
|
| 3 |
+
|
| 4 |
+
提供基于 SQLite 和 asyncio.subprocess 的本地实现
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from .task_queue import AsyncTrainingManager
|
| 8 |
+
|
| 9 |
+
__all__ = ["AsyncTrainingManager"]
|
api_server/app/adapters/local/task_queue.py
ADDED
|
@@ -0,0 +1,695 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
本地异步任务管理器
|
| 3 |
+
|
| 4 |
+
基于 asyncio.subprocess + SQLite 的本地任务队列实现。
|
| 5 |
+
适用于 macOS 本地训练和 Electron 集成场景。
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import asyncio
|
| 9 |
+
import json
|
| 10 |
+
import os
|
| 11 |
+
import sqlite3
|
| 12 |
+
import sys
|
| 13 |
+
import uuid
|
| 14 |
+
from datetime import datetime
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
from typing import Dict, Optional, AsyncGenerator, List
|
| 17 |
+
|
| 18 |
+
import aiosqlite
|
| 19 |
+
|
| 20 |
+
from ..base import TaskQueueAdapter
|
| 21 |
+
from ...core.config import settings, PROJECT_ROOT, get_pythonpath
|
| 22 |
+
|
| 23 |
+
# Sentinel markers wrapped around JSON progress lines on the subprocess
# stdout (must stay in sync with run_pipeline.py, which emits them).
PROGRESS_PREFIX = "##PROGRESS##"
PROGRESS_SUFFIX = "##"
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class AsyncTrainingManager(TaskQueueAdapter):
    """
    Async task manager built on asyncio.subprocess.

    Key properties:
    1. Training runs are spawned with asyncio.create_subprocess_exec()
    2. Fully non-blocking, so it composes cleanly with FastAPI's async model
    3. Task state is persisted in SQLite and survives application restarts
    4. Progress is obtained by parsing the subprocess stdout in real time

    Example:
        >>> manager = AsyncTrainingManager(db_path="./data/tasks.db")
        >>> job_id = await manager.enqueue("task-123", {"exp_name": "test", ...})
        >>>
        >>> # Subscribe to progress updates
        >>> async for progress in manager.subscribe_progress("task-123"):
        ...     print(f"{progress['stage']}: {progress['progress']*100:.1f}%")
        >>>
        >>> # Cancel the task
        >>> await manager.cancel(job_id)
    """

    def __init__(self, db_path: Optional[str] = None, max_concurrent: int = 1):
        """
        Initialize the task manager.

        Args:
            db_path: SQLite database path; defaults to settings.SQLITE_PATH.
            max_concurrent: Maximum number of concurrent tasks (usually 1
                for local training).
        """
        self.db_path = db_path or str(settings.SQLITE_PATH)
        self.max_concurrent = max_concurrent

        # Runtime state (in-memory only; rebuilt after a restart).
        self.running_processes: Dict[str, asyncio.subprocess.Process] = {}  # task_id -> Process
        self.progress_channels: Dict[str, asyncio.Queue] = {}  # task_id -> Queue
        self._running_count = 0
        self._queue_lock = asyncio.Lock()
        # NOTE(review): max_concurrent, _running_count and _queue_lock are
        # stored/maintained but never used to gate enqueue/start below —
        # concurrency limiting appears to be unimplemented; confirm intent.

        # Create the schema synchronously before any async work starts.
        self._init_db_sync()

    def _init_db_sync(self) -> None:
        """Synchronously initialize the database (called once at startup)."""
        Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)

        with sqlite3.connect(self.db_path) as conn:
            conn.execute('''
                CREATE TABLE IF NOT EXISTS task_queue (
                    job_id TEXT PRIMARY KEY,
                    task_id TEXT NOT NULL UNIQUE,
                    exp_name TEXT NOT NULL,
                    config TEXT NOT NULL,
                    status TEXT DEFAULT 'queued',
                    current_stage TEXT,
                    progress REAL DEFAULT 0,
                    overall_progress REAL DEFAULT 0,
                    message TEXT,
                    error_message TEXT,
                    created_at TEXT NOT NULL,
                    started_at TEXT,
                    completed_at TEXT
                )
            ''')
            conn.execute('CREATE INDEX IF NOT EXISTS idx_task_queue_status ON task_queue(status)')
            conn.execute('CREATE INDEX IF NOT EXISTS idx_task_queue_task_id ON task_queue(task_id)')
            conn.execute('CREATE INDEX IF NOT EXISTS idx_task_queue_created ON task_queue(created_at)')
            conn.commit()

    async def enqueue(self, task_id: str, config: Dict, priority: str = "normal") -> str:
        """
        Enqueue a task and start it asynchronously.

        Args:
            task_id: Unique task identifier.
            config: Task configuration; expected to contain:
                - exp_name: experiment name
                - version: model version
                - stages: list of stage configurations
            priority: Priority hint (ignored by this implementation).

        Returns:
            job_id: The job identifier.
        """
        # NOTE(review): task_id has a UNIQUE constraint in the schema, so a
        # duplicate enqueue raises aiosqlite.IntegrityError here — confirm
        # callers guard against re-submitting the same task_id.
        job_id = str(uuid.uuid4())
        exp_name = config.get("exp_name", "unknown")

        # Persist the queued task to SQLite.
        async with aiosqlite.connect(self.db_path) as db:
            await db.execute(
                '''INSERT INTO task_queue
                   (job_id, task_id, exp_name, config, status, created_at)
                   VALUES (?, ?, ?, ?, 'queued', ?)''',
                (job_id, task_id, exp_name, json.dumps(config, ensure_ascii=False),
                 datetime.utcnow().isoformat())
            )
            await db.commit()

        # Create the per-task progress queue for subscribers.
        self.progress_channels[task_id] = asyncio.Queue()

        # Fire-and-forget the training coroutine on the running loop.
        asyncio.create_task(self._run_training_async(job_id, task_id, config))

        return job_id

    async def _run_training_async(self, job_id: str, task_id: str, config: Dict) -> None:
        """
        Execute the training pipeline asynchronously.

        Spawns run_pipeline.py as a subprocess, streams its output, then
        records the terminal state in SQLite and notifies subscribers.

        Args:
            job_id: Job identifier.
            task_id: Task identifier.
            config: Task configuration.
        """
        config_path = None

        try:
            # Transition to running.
            await self._update_status(
                job_id,
                status='running',
                started_at=datetime.utcnow().isoformat()
            )
            await self._send_progress(task_id, {
                "type": "progress",
                "status": "running",
                "message": "训练任务启动中...",
                "progress": 0.0,
                "overall_progress": 0.0,
            })

            # Write the config to a temporary file for the subprocess.
            config_path = await self._write_config_file(task_id, config)

            # Resolve the run_pipeline.py script path.
            script_path = self._get_pipeline_script_path()

            # Build the child environment.
            env = os.environ.copy()
            env['PYTHONPATH'] = get_pythonpath()

            # Spawn the training subprocess.
            process = await asyncio.create_subprocess_exec(
                sys.executable, script_path,
                '--config', config_path,
                '--task-id', task_id,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
                env=env,
                cwd=str(PROJECT_ROOT),
            )

            self.running_processes[task_id] = process
            self._running_count += 1

            # Drain/parse stdout and stderr until both pipes close.
            await self._monitor_process_output(task_id, job_id, process)

            # Reap the process and collect its exit code.
            returncode = await process.wait()

            if returncode == 0:
                await self._update_status(
                    job_id,
                    status='completed',
                    progress=1.0,
                    overall_progress=1.0,
                    message='训练完成',
                    completed_at=datetime.utcnow().isoformat()
                )
                await self._send_progress(task_id, {
                    "type": "progress",
                    "status": "completed",
                    "message": "训练完成",
                    "progress": 1.0,
                    "overall_progress": 1.0,
                })
            else:
                # Try to read any remaining stderr.
                # NOTE(review): _monitor_process_output already drained stderr
                # line-by-line, so this read() usually returns b"" and the
                # exit-code fallback message is used — confirm whether the
                # intent was to accumulate stderr in the monitor instead.
                stderr_data = await process.stderr.read()
                error_msg = stderr_data.decode() if stderr_data else f"进程退出码: {returncode}"

                await self._update_status(
                    job_id,
                    status='failed',
                    error_message=error_msg,
                    completed_at=datetime.utcnow().isoformat()
                )
                await self._send_progress(task_id, {
                    "type": "progress",
                    "status": "failed",
                    "message": f"训练失败: {error_msg[:200]}",
                    "error": error_msg,
                })

        except asyncio.CancelledError:
            # The wrapping task was cancelled; record the terminal state.
            await self._update_status(
                job_id,
                status='cancelled',
                message='任务已取消',
                completed_at=datetime.utcnow().isoformat()
            )
            await self._send_progress(task_id, {
                "type": "progress",
                "status": "cancelled",
                "message": "任务已取消",
            })

        except Exception as e:
            # Any other failure (spawn error, DB error, ...) marks the job failed.
            error_msg = str(e)
            await self._update_status(
                job_id,
                status='failed',
                error_message=error_msg,
                completed_at=datetime.utcnow().isoformat()
            )
            await self._send_progress(task_id, {
                "type": "progress",
                "status": "failed",
                "message": f"任务执行出错: {error_msg}",
                "error": error_msg,
            })

        finally:
            # Drop runtime bookkeeping for this task.
            self.running_processes.pop(task_id, None)
            self._running_count = max(0, self._running_count - 1)

            # Remove the temporary config file.
            if config_path:
                await self._cleanup_config_file(config_path)

    async def _monitor_process_output(
        self,
        task_id: str,
        job_id: str,
        process: asyncio.subprocess.Process
    ) -> None:
        """
        Monitor the subprocess output streams and parse progress markers.

        Args:
            task_id: Task identifier.
            job_id: Job identifier.
            process: The subprocess handle.
        """
        async def read_stdout():
            """Read stdout line-by-line and extract progress payloads."""
            while True:
                line = await process.stdout.readline()
                if not line:
                    break

                text = line.decode('utf-8', errors='replace').strip()
                if not text:
                    continue

                # Lines wrapped in the progress markers carry JSON payloads.
                if text.startswith(PROGRESS_PREFIX) and text.endswith(PROGRESS_SUFFIX):
                    json_str = text[len(PROGRESS_PREFIX):-len(PROGRESS_SUFFIX)]
                    try:
                        progress_info = json.loads(json_str)
                        await self._handle_progress(task_id, job_id, progress_info)
                    except json.JSONDecodeError as e:
                        # Malformed payload: degrade to a warning log entry.
                        await self._send_progress(task_id, {
                            "type": "log",
                            "level": "warning",
                            "message": f"进度解析失败: {e}",
                        })
                else:
                    # Anything else is forwarded as a plain log line.
                    await self._send_progress(task_id, {
                        "type": "log",
                        "level": "info",
                        "message": text,
                    })

        async def read_stderr():
            """Forward stderr lines as error-level log entries."""
            while True:
                line = await process.stderr.readline()
                if not line:
                    break

                text = line.decode('utf-8', errors='replace').strip()
                if text:
                    await self._send_progress(task_id, {
                        "type": "log",
                        "level": "error",
                        "message": text,
                    })

        # Drain both pipes concurrently; this returns only once both close.
        await asyncio.gather(
            read_stdout(),
            read_stderr(),
            return_exceptions=True
        )

    async def _handle_progress(
        self,
        task_id: str,
        job_id: str,
        progress_info: Dict
    ) -> None:
        """
        Handle one parsed progress payload: fan out and persist.

        Args:
            task_id: Task identifier.
            job_id: Job identifier.
            progress_info: Parsed progress payload.
        """
        # Forward to subscribers first.
        await self._send_progress(task_id, progress_info)

        # Mirror the recognized fields into the SQLite row.
        updates = {}

        if 'stage' in progress_info:
            updates['current_stage'] = progress_info['stage']
        if 'progress' in progress_info:
            updates['progress'] = progress_info['progress']
        if 'overall_progress' in progress_info:
            updates['overall_progress'] = progress_info['overall_progress']
        if 'message' in progress_info:
            updates['message'] = progress_info['message']
        if 'status' in progress_info:
            updates['status'] = progress_info['status']
        if 'error' in progress_info:
            updates['error_message'] = progress_info['error']

        if updates:
            await self._update_status(job_id, **updates)

    async def _send_progress(self, task_id: str, progress_info: Dict) -> None:
        """
        Push a progress payload onto the task's subscriber queue.

        Args:
            task_id: Task identifier.
            progress_info: Progress payload (mutated in place to stamp a
                timestamp when missing).
        """
        if task_id in self.progress_channels:
            # Stamp the payload so subscribers can order events.
            if 'timestamp' not in progress_info:
                progress_info['timestamp'] = datetime.utcnow().isoformat()

            await self.progress_channels[task_id].put(progress_info)

    async def _update_status(self, job_id: str, **kwargs) -> None:
        """
        Update columns of the task row.

        Args:
            job_id: Job identifier.
            **kwargs: Column/value pairs to set.

        NOTE(review): column names are interpolated into the SQL via an
        f-string; values are parameterized, but the keys must only ever
        come from trusted internal callers (they do, in this file).
        """
        if not kwargs:
            return

        async with aiosqlite.connect(self.db_path) as db:
            updates = []
            values = []

            for key, value in kwargs.items():
                updates.append(f"{key} = ?")
                values.append(value)

            values.append(job_id)

            await db.execute(
                f"UPDATE task_queue SET {', '.join(updates)} WHERE job_id = ?",
                values
            )
            await db.commit()

    async def get_status(self, job_id: str) -> Dict:
        """
        Fetch the task row by job id.

        Args:
            job_id: Job identifier.

        Returns:
            The row as a dict, or a not_found placeholder dict.
        """
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row
            async with db.execute(
                "SELECT * FROM task_queue WHERE job_id = ?", (job_id,)
            ) as cursor:
                row = await cursor.fetchone()
                if row:
                    return dict(row)

        return {"status": "not_found", "message": "任务不存在"}

    async def get_status_by_task_id(self, task_id: str) -> Dict:
        """
        Fetch the task row by task id.

        Args:
            task_id: Task identifier.

        Returns:
            The row as a dict, or a not_found placeholder dict.
        """
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row
            async with db.execute(
                "SELECT * FROM task_queue WHERE task_id = ?", (task_id,)
            ) as cursor:
                row = await cursor.fetchone()
                if row:
                    return dict(row)

        return {"status": "not_found", "message": "任务不存在"}

    async def cancel(self, job_id: str) -> bool:
        """
        Cancel a job.

        Args:
            job_id: Job identifier.

        Returns:
            True when cancellation took effect, False when the job does
            not exist or is already in a terminal state.
        """
        # Resolve task_id and current status.
        async with aiosqlite.connect(self.db_path) as db:
            async with db.execute(
                "SELECT task_id, status FROM task_queue WHERE job_id = ?", (job_id,)
            ) as cursor:
                row = await cursor.fetchone()
                if not row:
                    return False
                task_id, status = row

        # Terminal states cannot be cancelled.
        if status in ('completed', 'failed', 'cancelled'):
            return False

        # A live process gets SIGTERM, then SIGKILL after a grace period.
        if task_id in self.running_processes:
            process = self.running_processes[task_id]

            # Graceful shutdown first.
            process.terminate()

            try:
                # Give it up to five seconds to exit.
                await asyncio.wait_for(process.wait(), timeout=5.0)
            except asyncio.TimeoutError:
                # Escalate to a hard kill.
                process.kill()
                await process.wait()

            # NOTE(review): this branch returns without writing
            # status='cancelled' to the DB; _run_training_async will observe
            # a nonzero exit code and record the task as 'failed' instead —
            # confirm whether that is the intended terminal state.
            return True

        # Not running (still queued): flip the row directly.
        await self._update_status(
            job_id,
            status='cancelled',
            message='任务已取消',
            completed_at=datetime.utcnow().isoformat()
        )

        # Notify any subscribers.
        if task_id in self.progress_channels:
            await self._send_progress(task_id, {
                "type": "progress",
                "status": "cancelled",
                "message": "任务已取消",
            })

        return True

    async def subscribe_progress(self, task_id: str) -> AsyncGenerator[Dict, None]:
        """
        Stream progress updates for a task (intended for SSE).

        Args:
            task_id: Task identifier.

        Yields:
            Progress payload dicts; a heartbeat is emitted every 30 s of
            silence, and the stream ends on a terminal status.
        """
        # Make sure a subscriber queue exists.
        if task_id not in self.progress_channels:
            self.progress_channels[task_id] = asyncio.Queue()

        queue = self.progress_channels[task_id]

        # Emit the current persisted state first so late subscribers sync up.
        status = await self.get_status_by_task_id(task_id)
        if status.get("status") != "not_found":
            yield {
                "type": "progress",
                "status": status.get("status"),
                "stage": status.get("current_stage"),
                "progress": status.get("progress", 0.0),
                "overall_progress": status.get("overall_progress", 0.0),
                "message": status.get("message"),
                "timestamp": datetime.utcnow().isoformat(),
            }

        # Then relay live updates.
        while True:
            try:
                # 30 s of silence triggers a heartbeat instead.
                progress = await asyncio.wait_for(queue.get(), timeout=30.0)
                yield progress

                # Stop on any terminal status.
                if progress.get('status') in ('completed', 'failed', 'cancelled'):
                    break

            except asyncio.TimeoutError:
                # Keep the SSE connection alive.
                yield {
                    "type": "heartbeat",
                    "timestamp": datetime.utcnow().isoformat(),
                }

    async def list_tasks(
        self,
        status: Optional[str] = None,
        limit: int = 50,
        offset: int = 0
    ) -> List[Dict]:
        """
        List persisted tasks, newest first.

        Args:
            status: Optional status filter.
            limit: Maximum number of rows.
            offset: Pagination offset.

        Returns:
            Task rows as dicts.
        """
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row

            if status:
                query = """
                    SELECT * FROM task_queue
                    WHERE status = ?
                    ORDER BY created_at DESC
                    LIMIT ? OFFSET ?
                """
                params = (status, limit, offset)
            else:
                query = """
                    SELECT * FROM task_queue
                    ORDER BY created_at DESC
                    LIMIT ? OFFSET ?
                """
                params = (limit, offset)

            async with db.execute(query, params) as cursor:
                rows = await cursor.fetchall()
                return [dict(row) for row in rows]

    async def recover_pending_tasks(self) -> int:
        """
        Recover unfinished tasks after an application restart.

        Marks tasks stuck in 'running' as interrupted, then restarts every
        task still in 'queued'.

        Returns:
            The number of restarted tasks.
        """
        async with aiosqlite.connect(self.db_path) as db:
            # A restart means any 'running' row is orphaned — flag it.
            await db.execute(
                """UPDATE task_queue
                   SET status = 'interrupted',
                       message = '应用重启导致任务中断'
                   WHERE status = 'running'"""
            )
            await db.commit()

            # Collect the still-queued tasks, oldest first.
            db.row_factory = aiosqlite.Row
            async with db.execute(
                "SELECT * FROM task_queue WHERE status = 'queued' ORDER BY created_at"
            ) as cursor:
                queued_tasks = await cursor.fetchall()

        # Relaunch each queued task.
        recovered = 0
        for task in queued_tasks:
            task_id = task['task_id']
            job_id = task['job_id']
            config = json.loads(task['config'])

            self.progress_channels[task_id] = asyncio.Queue()
            asyncio.create_task(self._run_training_async(job_id, task_id, config))
            recovered += 1

        return recovered

    async def cleanup_old_tasks(self, days: int = 7) -> int:
        """
        Delete terminal task rows older than the retention window.

        Args:
            days: Retention period in days.

        Returns:
            The number of deleted rows.
        """
        from datetime import timedelta

        cutoff = (datetime.utcnow() - timedelta(days=days)).isoformat()

        async with aiosqlite.connect(self.db_path) as db:
            cursor = await db.execute(
                """DELETE FROM task_queue
                   WHERE status IN ('completed', 'failed', 'cancelled')
                   AND completed_at < ?""",
                (cutoff,)
            )
            deleted = cursor.rowcount
            await db.commit()

        return deleted

    def _get_pipeline_script_path(self) -> str:
        """Return the run_pipeline.py script path from settings."""
        return str(settings.PIPELINE_SCRIPT_PATH)

    async def _write_config_file(self, task_id: str, config: Dict) -> str:
        """
        Write the task config to a temporary JSON file.

        Args:
            task_id: Task identifier (used as the file stem).
            config: Configuration dict.

        Returns:
            The config file path as a string.
        """
        config_path = settings.CONFIGS_DIR / f"{task_id}.json"

        # NOTE(review): synchronous file I/O inside an async method; fine for
        # small configs, but it does block the event loop briefly.
        with open(config_path, 'w', encoding='utf-8') as f:
            json.dump(config, f, ensure_ascii=False, indent=2)

        return str(config_path)

    async def _cleanup_config_file(self, config_path: str) -> None:
        """
        Remove a temporary config file, best-effort.

        Args:
            config_path: Config file path.
        """
        try:
            path = Path(config_path)
            if path.exists():
                path.unlink()
        except Exception:
            pass  # ignore cleanup errors — deletion is best-effort
|
api_server/app/core/__init__.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
核心模块
|
| 3 |
+
|
| 4 |
+
包含配置、枚举等核心组件
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from .config import settings, PROJECT_ROOT, API_SERVER_ROOT
|
| 8 |
+
|
| 9 |
+
__all__ = ["settings", "PROJECT_ROOT", "API_SERVER_ROOT"]
|
api_server/app/core/config.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
环境变量和配置模块
|
| 3 |
+
|
| 4 |
+
统一管理项目路径、环境配置等
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Literal
|
| 10 |
+
|
| 11 |
+
# ============================================================
|
| 12 |
+
# 路径常量
|
| 13 |
+
# ============================================================
|
| 14 |
+
|
| 15 |
+
USER_HOME_ROOT = Path.home()
|
| 16 |
+
|
| 17 |
+
# api_server/app/core/config.py -> api_server/app/core -> api_server/app -> api_server -> 项目根目录
|
| 18 |
+
API_SERVER_ROOT = Path(__file__).parent.parent.parent.resolve()
|
| 19 |
+
PROJECT_ROOT = API_SERVER_ROOT.parent.resolve()
|
| 20 |
+
|
| 21 |
+
# GPT_SoVITS 模块路径
|
| 22 |
+
GPT_SOVITS_ROOT = PROJECT_ROOT / "GPT_SoVITS"
|
| 23 |
+
|
| 24 |
+
# 默认数据目录
|
| 25 |
+
DEFAULT_DATA_DIR = USER_HOME_ROOT / '.moyoyo-tts' / "data"
|
| 26 |
+
|
| 27 |
+
# 预训练模型目录
|
| 28 |
+
PRETRAINED_MODELS_DIR = GPT_SOVITS_ROOT / "pretrained_models"
|
| 29 |
+
|
| 30 |
+
# 日志目录
|
| 31 |
+
LOGS_DIR = PROJECT_ROOT / "logs"
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
# ============================================================
|
| 35 |
+
# 配置类
|
| 36 |
+
# ============================================================
|
| 37 |
+
|
| 38 |
+
class Settings:
|
| 39 |
+
"""
|
| 40 |
+
API Server 配置
|
| 41 |
+
|
| 42 |
+
支持从环境变量读取配置,提供合理的默认值
|
| 43 |
+
|
| 44 |
+
Example:
|
| 45 |
+
>>> from api_server.app.core.config import settings
|
| 46 |
+
>>> print(settings.PROJECT_ROOT)
|
| 47 |
+
>>> print(settings.DEPLOYMENT_MODE)
|
| 48 |
+
"""
|
| 49 |
+
|
| 50 |
+
# 部署模式
|
| 51 |
+
DEPLOYMENT_MODE: Literal["local", "server"] = os.getenv("DEPLOYMENT_MODE", "local")
|
| 52 |
+
|
| 53 |
+
# API 配置
|
| 54 |
+
API_V1_PREFIX: str = os.getenv("API_V1_PREFIX", "/api/v1")
|
| 55 |
+
API_HOST: str = os.getenv("API_HOST", "0.0.0.0")
|
| 56 |
+
API_PORT: int = int(os.getenv("API_PORT", "8000"))
|
| 57 |
+
|
| 58 |
+
# 路径配置(可通过环境变量覆盖)
|
| 59 |
+
PROJECT_ROOT: Path = Path(os.getenv("PROJECT_ROOT", str(PROJECT_ROOT)))
|
| 60 |
+
API_SERVER_ROOT: Path = Path(os.getenv("API_SERVER_ROOT", str(API_SERVER_ROOT)))
|
| 61 |
+
DATA_DIR: Path = Path(os.getenv("DATA_DIR", str(DEFAULT_DATA_DIR)))
|
| 62 |
+
|
| 63 |
+
# SQLite 数据库路径
|
| 64 |
+
SQLITE_PATH: Path = Path(os.getenv("SQLITE_PATH", str(DEFAULT_DATA_DIR / "tasks.db")))
|
| 65 |
+
|
| 66 |
+
# 任务配置
|
| 67 |
+
LOCAL_MAX_WORKERS: int = int(os.getenv("LOCAL_MAX_WORKERS", "1"))
|
| 68 |
+
|
| 69 |
+
# 预训练模型路径
|
| 70 |
+
BERT_PRETRAINED_DIR: str = os.getenv(
|
| 71 |
+
"BERT_PRETRAINED_DIR",
|
| 72 |
+
str(PRETRAINED_MODELS_DIR / "chinese-roberta-wwm-ext-large")
|
| 73 |
+
)
|
| 74 |
+
SSL_PRETRAINED_DIR: str = os.getenv(
|
| 75 |
+
"SSL_PRETRAINED_DIR",
|
| 76 |
+
str(PRETRAINED_MODELS_DIR / "chinese-hubert-base")
|
| 77 |
+
)
|
| 78 |
+
PRETRAINED_S2G: str = os.getenv(
|
| 79 |
+
"PRETRAINED_S2G",
|
| 80 |
+
str(PRETRAINED_MODELS_DIR / "gsv-v2final-pretrained" / "s2G2333k.pth")
|
| 81 |
+
)
|
| 82 |
+
PRETRAINED_S2D: str = os.getenv(
|
| 83 |
+
"PRETRAINED_S2D",
|
| 84 |
+
str(PRETRAINED_MODELS_DIR / "gsv-v2final-pretrained" / "s2D2333k.pth")
|
| 85 |
+
)
|
| 86 |
+
PRETRAINED_S1: str = os.getenv(
|
| 87 |
+
"PRETRAINED_S1",
|
| 88 |
+
str(PRETRAINED_MODELS_DIR / "gsv-v2final-pretrained" / "s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt")
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
# Pipeline 脚本路径
|
| 92 |
+
@property
|
| 93 |
+
def PIPELINE_SCRIPT_PATH(self) -> Path:
|
| 94 |
+
"""Pipeline 执行脚本路径"""
|
| 95 |
+
return self.API_SERVER_ROOT / "app" / "scripts" / "run_pipeline.py"
|
| 96 |
+
|
| 97 |
+
# 临时配置文件目录
|
| 98 |
+
@property
|
| 99 |
+
def CONFIGS_DIR(self) -> Path:
|
| 100 |
+
"""临时配置文件目录"""
|
| 101 |
+
path = self.DATA_DIR / "configs"
|
| 102 |
+
path.mkdir(parents=True, exist_ok=True)
|
| 103 |
+
return path
|
| 104 |
+
|
| 105 |
+
def __repr__(self) -> str:
|
| 106 |
+
return (
|
| 107 |
+
f"Settings(\n"
|
| 108 |
+
f" DEPLOYMENT_MODE={self.DEPLOYMENT_MODE!r},\n"
|
| 109 |
+
f" PROJECT_ROOT={self.PROJECT_ROOT},\n"
|
| 110 |
+
f" API_SERVER_ROOT={self.API_SERVER_ROOT},\n"
|
| 111 |
+
f" DATA_DIR={self.DATA_DIR},\n"
|
| 112 |
+
f" SQLITE_PATH={self.SQLITE_PATH},\n"
|
| 113 |
+
f")"
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
# 全局配置实例
|
| 118 |
+
settings = Settings()
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def get_pythonpath() -> str:
|
| 122 |
+
"""
|
| 123 |
+
获取 PYTHONPATH 环境变量值
|
| 124 |
+
|
| 125 |
+
用于子进程启动时设置正确的模块搜索路径
|
| 126 |
+
|
| 127 |
+
Returns:
|
| 128 |
+
PYTHONPATH 字符串
|
| 129 |
+
"""
|
| 130 |
+
paths = [
|
| 131 |
+
str(PROJECT_ROOT),
|
| 132 |
+
str(GPT_SOVITS_ROOT),
|
| 133 |
+
]
|
| 134 |
+
return os.pathsep.join(paths)
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def ensure_data_dirs() -> None:
|
| 138 |
+
"""
|
| 139 |
+
确保必要的数据目录存在
|
| 140 |
+
"""
|
| 141 |
+
settings.DATA_DIR.mkdir(parents=True, exist_ok=True)
|
| 142 |
+
settings.CONFIGS_DIR.mkdir(parents=True, exist_ok=True)
|
api_server/app/models/__init__.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
模型模块
|
| 3 |
+
|
| 4 |
+
包含领域模型和 Pydantic Schema
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from .domain import Task, TaskStatus, ProgressInfo
|
| 8 |
+
|
| 9 |
+
__all__ = ["Task", "TaskStatus", "ProgressInfo"]
|
api_server/app/models/domain.py
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
领域模型模块
|
| 3 |
+
|
| 4 |
+
定义训练任务相关的核心数据结构
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from dataclasses import dataclass, field
|
| 8 |
+
from datetime import datetime
|
| 9 |
+
from enum import Enum
|
| 10 |
+
from typing import Dict, Optional, Any
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class TaskStatus(Enum):
|
| 14 |
+
"""任务状态枚举"""
|
| 15 |
+
QUEUED = "queued" # 已入队,等待执行
|
| 16 |
+
RUNNING = "running" # 执行中
|
| 17 |
+
COMPLETED = "completed" # 已完成
|
| 18 |
+
FAILED = "failed" # 失败
|
| 19 |
+
CANCELLED = "cancelled" # 已取消
|
| 20 |
+
INTERRUPTED = "interrupted" # 被中断(应用重启时运行中的任务)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@dataclass
|
| 24 |
+
class Task:
|
| 25 |
+
"""
|
| 26 |
+
训练任务领域模型
|
| 27 |
+
|
| 28 |
+
Attributes:
|
| 29 |
+
id: 任务唯一标识
|
| 30 |
+
job_id: 队列作业ID(由任务队列生成)
|
| 31 |
+
exp_name: 实验名称
|
| 32 |
+
status: 任务状态
|
| 33 |
+
config: 任务配置(包含所有训练参数)
|
| 34 |
+
current_stage: 当前执行阶段
|
| 35 |
+
progress: 总体进度 (0.0-1.0)
|
| 36 |
+
stage_progress: 当前阶段进度 (0.0-1.0)
|
| 37 |
+
message: 最新状态消息
|
| 38 |
+
error_message: 错误信息(失败时)
|
| 39 |
+
created_at: 创建时间
|
| 40 |
+
started_at: 开始执行时间
|
| 41 |
+
completed_at: 完成时间
|
| 42 |
+
|
| 43 |
+
Example:
|
| 44 |
+
>>> task = Task(
|
| 45 |
+
... id="task-123",
|
| 46 |
+
... exp_name="my_voice",
|
| 47 |
+
... config={"version": "v2", "batch_size": 4}
|
| 48 |
+
... )
|
| 49 |
+
>>> task.status
|
| 50 |
+
<TaskStatus.QUEUED: 'queued'>
|
| 51 |
+
"""
|
| 52 |
+
id: str
|
| 53 |
+
exp_name: str
|
| 54 |
+
config: Dict[str, Any]
|
| 55 |
+
job_id: Optional[str] = None
|
| 56 |
+
status: TaskStatus = TaskStatus.QUEUED
|
| 57 |
+
current_stage: Optional[str] = None
|
| 58 |
+
progress: float = 0.0
|
| 59 |
+
stage_progress: float = 0.0
|
| 60 |
+
message: Optional[str] = None
|
| 61 |
+
error_message: Optional[str] = None
|
| 62 |
+
created_at: datetime = field(default_factory=datetime.utcnow)
|
| 63 |
+
started_at: Optional[datetime] = None
|
| 64 |
+
completed_at: Optional[datetime] = None
|
| 65 |
+
|
| 66 |
+
def to_dict(self) -> Dict[str, Any]:
|
| 67 |
+
"""转换为字典"""
|
| 68 |
+
return {
|
| 69 |
+
"id": self.id,
|
| 70 |
+
"job_id": self.job_id,
|
| 71 |
+
"exp_name": self.exp_name,
|
| 72 |
+
"status": self.status.value,
|
| 73 |
+
"config": self.config,
|
| 74 |
+
"current_stage": self.current_stage,
|
| 75 |
+
"progress": self.progress,
|
| 76 |
+
"stage_progress": self.stage_progress,
|
| 77 |
+
"message": self.message,
|
| 78 |
+
"error_message": self.error_message,
|
| 79 |
+
"created_at": self.created_at.isoformat() if self.created_at else None,
|
| 80 |
+
"started_at": self.started_at.isoformat() if self.started_at else None,
|
| 81 |
+
"completed_at": self.completed_at.isoformat() if self.completed_at else None,
|
| 82 |
+
}
|
| 83 |
+
|
| 84 |
+
@classmethod
|
| 85 |
+
def from_dict(cls, data: Dict[str, Any]) -> "Task":
|
| 86 |
+
"""从字典创建实例"""
|
| 87 |
+
# 处理状态枚举
|
| 88 |
+
status = data.get("status", "queued")
|
| 89 |
+
if isinstance(status, str):
|
| 90 |
+
status = TaskStatus(status)
|
| 91 |
+
|
| 92 |
+
# 处理日期时间
|
| 93 |
+
def parse_datetime(value):
|
| 94 |
+
if value is None:
|
| 95 |
+
return None
|
| 96 |
+
if isinstance(value, datetime):
|
| 97 |
+
return value
|
| 98 |
+
return datetime.fromisoformat(value)
|
| 99 |
+
|
| 100 |
+
return cls(
|
| 101 |
+
id=data["id"],
|
| 102 |
+
job_id=data.get("job_id"),
|
| 103 |
+
exp_name=data["exp_name"],
|
| 104 |
+
status=status,
|
| 105 |
+
config=data.get("config", {}),
|
| 106 |
+
current_stage=data.get("current_stage"),
|
| 107 |
+
progress=data.get("progress", 0.0),
|
| 108 |
+
stage_progress=data.get("stage_progress", 0.0),
|
| 109 |
+
message=data.get("message"),
|
| 110 |
+
error_message=data.get("error_message"),
|
| 111 |
+
created_at=parse_datetime(data.get("created_at")),
|
| 112 |
+
started_at=parse_datetime(data.get("started_at")),
|
| 113 |
+
completed_at=parse_datetime(data.get("completed_at")),
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
@dataclass
|
| 118 |
+
class ProgressInfo:
|
| 119 |
+
"""
|
| 120 |
+
进度信息数据结构
|
| 121 |
+
|
| 122 |
+
用于在子进程和主进程之间传递进度更新
|
| 123 |
+
|
| 124 |
+
Attributes:
|
| 125 |
+
type: 消息类型 ("progress", "log", "error", "heartbeat")
|
| 126 |
+
stage: 当前阶段名称
|
| 127 |
+
stage_index: 当前阶段索引
|
| 128 |
+
total_stages: 总阶段数
|
| 129 |
+
progress: 阶段内进度 (0.0-1.0)
|
| 130 |
+
overall_progress: 总体进度 (0.0-1.0)
|
| 131 |
+
message: 进度消息
|
| 132 |
+
status: 状态
|
| 133 |
+
data: 附加数据
|
| 134 |
+
"""
|
| 135 |
+
type: str = "progress"
|
| 136 |
+
stage: Optional[str] = None
|
| 137 |
+
stage_index: Optional[int] = None
|
| 138 |
+
total_stages: Optional[int] = None
|
| 139 |
+
progress: float = 0.0
|
| 140 |
+
overall_progress: float = 0.0
|
| 141 |
+
message: Optional[str] = None
|
| 142 |
+
status: Optional[str] = None
|
| 143 |
+
data: Dict[str, Any] = field(default_factory=dict)
|
| 144 |
+
|
| 145 |
+
def to_dict(self) -> Dict[str, Any]:
|
| 146 |
+
"""转换为字典"""
|
| 147 |
+
return {
|
| 148 |
+
"type": self.type,
|
| 149 |
+
"stage": self.stage,
|
| 150 |
+
"stage_index": self.stage_index,
|
| 151 |
+
"total_stages": self.total_stages,
|
| 152 |
+
"progress": self.progress,
|
| 153 |
+
"overall_progress": self.overall_progress,
|
| 154 |
+
"message": self.message,
|
| 155 |
+
"status": self.status,
|
| 156 |
+
"data": self.data,
|
| 157 |
+
}
|
| 158 |
+
|
| 159 |
+
@classmethod
|
| 160 |
+
def from_dict(cls, data: Dict[str, Any]) -> "ProgressInfo":
|
| 161 |
+
"""从字典创建实例"""
|
| 162 |
+
return cls(
|
| 163 |
+
type=data.get("type", "progress"),
|
| 164 |
+
stage=data.get("stage"),
|
| 165 |
+
stage_index=data.get("stage_index"),
|
| 166 |
+
total_stages=data.get("total_stages"),
|
| 167 |
+
progress=data.get("progress", 0.0),
|
| 168 |
+
overall_progress=data.get("overall_progress", 0.0),
|
| 169 |
+
message=data.get("message"),
|
| 170 |
+
status=data.get("status"),
|
| 171 |
+
data=data.get("data", {}),
|
| 172 |
+
)
|
api_server/app/scripts/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
脚本模块
|
| 3 |
+
|
| 4 |
+
包含用于子进程执行的独立脚本
|
| 5 |
+
"""
|
api_server/app/scripts/run_pipeline.py
ADDED
|
@@ -0,0 +1,368 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Pipeline 包装脚本
|
| 4 |
+
|
| 5 |
+
此脚本作为独立子进程运行,执行 TrainingPipeline 并将进度以 JSON 格式输出到 stdout。
|
| 6 |
+
主进程(AsyncTrainingManager)通过解析 stdout 来获取实时进度。
|
| 7 |
+
|
| 8 |
+
进度消息格式:
|
| 9 |
+
##PROGRESS##{"type": "progress", "stage": "...", ...}##
|
| 10 |
+
|
| 11 |
+
Usage:
|
| 12 |
+
python run_pipeline.py --config /path/to/config.json --task-id task-123
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import argparse
|
| 16 |
+
import json
|
| 17 |
+
import sys
|
| 18 |
+
import os
|
| 19 |
+
import traceback
|
| 20 |
+
from datetime import datetime
|
| 21 |
+
from typing import Dict, Any
|
| 22 |
+
|
| 23 |
+
# 确保可以导入项目模块(在导入其他模块之前)
|
| 24 |
+
from pathlib import Path
|
| 25 |
+
_SCRIPT_DIR = Path(__file__).parent.resolve()
|
| 26 |
+
_API_SERVER_ROOT = _SCRIPT_DIR.parent.parent
|
| 27 |
+
_PROJECT_ROOT = _API_SERVER_ROOT.parent
|
| 28 |
+
sys.path.insert(0, str(_PROJECT_ROOT))
|
| 29 |
+
|
| 30 |
+
# 导入配置模块
|
| 31 |
+
from api_server.app.core.config import settings, PROJECT_ROOT, get_pythonpath
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
# 进度消息前缀和后缀,用于主进程解析
|
| 35 |
+
PROGRESS_PREFIX = "##PROGRESS##"
|
| 36 |
+
PROGRESS_SUFFIX = "##"
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def emit_progress(progress_info: Dict[str, Any]) -> None:
|
| 40 |
+
"""
|
| 41 |
+
输出进度消息到 stdout
|
| 42 |
+
|
| 43 |
+
Args:
|
| 44 |
+
progress_info: 进度信息字典
|
| 45 |
+
"""
|
| 46 |
+
# 确保有时间戳
|
| 47 |
+
if "timestamp" not in progress_info:
|
| 48 |
+
progress_info["timestamp"] = datetime.utcnow().isoformat()
|
| 49 |
+
|
| 50 |
+
json_str = json.dumps(progress_info, ensure_ascii=False)
|
| 51 |
+
print(f"{PROGRESS_PREFIX}{json_str}{PROGRESS_SUFFIX}", flush=True)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def emit_log(level: str, message: str, **extra) -> None:
|
| 55 |
+
"""
|
| 56 |
+
输出日志消息
|
| 57 |
+
|
| 58 |
+
Args:
|
| 59 |
+
level: 日志级别 (info, warning, error)
|
| 60 |
+
message: 日志消息
|
| 61 |
+
**extra: 额外数据
|
| 62 |
+
"""
|
| 63 |
+
emit_progress({
|
| 64 |
+
"type": "log",
|
| 65 |
+
"level": level,
|
| 66 |
+
"message": message,
|
| 67 |
+
**extra
|
| 68 |
+
})
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def load_config(config_path: str) -> Dict[str, Any]:
|
| 72 |
+
"""
|
| 73 |
+
加载配置文件
|
| 74 |
+
|
| 75 |
+
Args:
|
| 76 |
+
config_path: 配置文件路径
|
| 77 |
+
|
| 78 |
+
Returns:
|
| 79 |
+
配置字典
|
| 80 |
+
"""
|
| 81 |
+
with open(config_path, 'r', encoding='utf-8') as f:
|
| 82 |
+
return json.load(f)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def build_pipeline(config: Dict[str, Any]):
|
| 86 |
+
"""
|
| 87 |
+
根据配置构建 TrainingPipeline
|
| 88 |
+
|
| 89 |
+
Args:
|
| 90 |
+
config: 配置字典,包含:
|
| 91 |
+
- exp_name: 实验名称
|
| 92 |
+
- version: 模型版本
|
| 93 |
+
- stages: 要执行的阶段列表
|
| 94 |
+
- 各阶段的具体配置
|
| 95 |
+
|
| 96 |
+
Returns:
|
| 97 |
+
TrainingPipeline 实例
|
| 98 |
+
"""
|
| 99 |
+
from training_pipeline import (
|
| 100 |
+
TrainingPipeline,
|
| 101 |
+
ModelVersion,
|
| 102 |
+
# 配置类
|
| 103 |
+
AudioSliceConfig,
|
| 104 |
+
ASRConfig,
|
| 105 |
+
DenoiseConfig,
|
| 106 |
+
FeatureExtractionConfig,
|
| 107 |
+
SoVITSTrainConfig,
|
| 108 |
+
GPTTrainConfig,
|
| 109 |
+
InferenceConfig,
|
| 110 |
+
# 阶段类
|
| 111 |
+
AudioSliceStage,
|
| 112 |
+
ASRStage,
|
| 113 |
+
DenoiseStage,
|
| 114 |
+
TextFeatureStage,
|
| 115 |
+
HuBERTFeatureStage,
|
| 116 |
+
SemanticTokenStage,
|
| 117 |
+
SoVITSTrainStage,
|
| 118 |
+
GPTTrainStage,
|
| 119 |
+
InferenceStage,
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
pipeline = TrainingPipeline()
|
| 123 |
+
|
| 124 |
+
exp_name = config["exp_name"]
|
| 125 |
+
version_str = config.get("version", "v2")
|
| 126 |
+
version = ModelVersion(version_str) if isinstance(version_str, str) else version_str
|
| 127 |
+
|
| 128 |
+
# 通用配置参数
|
| 129 |
+
base_params = {
|
| 130 |
+
"exp_name": exp_name,
|
| 131 |
+
"exp_root": config.get("exp_root", "logs"),
|
| 132 |
+
"gpu_numbers": config.get("gpu_numbers", "0"),
|
| 133 |
+
"is_half": config.get("is_half", True),
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
# 阶段配置映射
|
| 137 |
+
stage_builders = {
|
| 138 |
+
"audio_slice": lambda cfg: AudioSliceStage(AudioSliceConfig(
|
| 139 |
+
**base_params,
|
| 140 |
+
input_path=cfg.get("input_path", ""),
|
| 141 |
+
threshold=cfg.get("threshold", -34),
|
| 142 |
+
min_length=cfg.get("min_length", 4000),
|
| 143 |
+
min_interval=cfg.get("min_interval", 300),
|
| 144 |
+
hop_size=cfg.get("hop_size", 10),
|
| 145 |
+
max_sil_kept=cfg.get("max_sil_kept", 500),
|
| 146 |
+
max_amp=cfg.get("max_amp", 0.9),
|
| 147 |
+
alpha=cfg.get("alpha", 0.25),
|
| 148 |
+
n_parts=cfg.get("n_parts", 4),
|
| 149 |
+
)),
|
| 150 |
+
|
| 151 |
+
"asr": lambda cfg: ASRStage(ASRConfig(
|
| 152 |
+
**base_params,
|
| 153 |
+
model=cfg.get("model", "达摩 ASR (中文)"),
|
| 154 |
+
model_size=cfg.get("model_size", "large"),
|
| 155 |
+
language=cfg.get("language", "zh"),
|
| 156 |
+
precision=cfg.get("precision", "float32"),
|
| 157 |
+
)),
|
| 158 |
+
|
| 159 |
+
"denoise": lambda cfg: DenoiseStage(DenoiseConfig(
|
| 160 |
+
**base_params,
|
| 161 |
+
input_dir=cfg.get("input_dir", ""),
|
| 162 |
+
output_dir=cfg.get("output_dir", "output/denoise_opt"),
|
| 163 |
+
)),
|
| 164 |
+
|
| 165 |
+
"text_feature": lambda cfg: TextFeatureStage(FeatureExtractionConfig(
|
| 166 |
+
**base_params,
|
| 167 |
+
version=version,
|
| 168 |
+
bert_pretrained_dir=cfg.get("bert_pretrained_dir",
|
| 169 |
+
"GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"),
|
| 170 |
+
ssl_pretrained_dir=cfg.get("ssl_pretrained_dir",
|
| 171 |
+
"GPT_SoVITS/pretrained_models/chinese-hubert-base"),
|
| 172 |
+
pretrained_s2G=cfg.get("pretrained_s2G",
|
| 173 |
+
"GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth"),
|
| 174 |
+
)),
|
| 175 |
+
|
| 176 |
+
"hubert_feature": lambda cfg: HuBERTFeatureStage(FeatureExtractionConfig(
|
| 177 |
+
**base_params,
|
| 178 |
+
version=version,
|
| 179 |
+
bert_pretrained_dir=cfg.get("bert_pretrained_dir",
|
| 180 |
+
"GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"),
|
| 181 |
+
ssl_pretrained_dir=cfg.get("ssl_pretrained_dir",
|
| 182 |
+
"GPT_SoVITS/pretrained_models/chinese-hubert-base"),
|
| 183 |
+
pretrained_s2G=cfg.get("pretrained_s2G",
|
| 184 |
+
"GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth"),
|
| 185 |
+
)),
|
| 186 |
+
|
| 187 |
+
"semantic_token": lambda cfg: SemanticTokenStage(FeatureExtractionConfig(
|
| 188 |
+
**base_params,
|
| 189 |
+
version=version,
|
| 190 |
+
bert_pretrained_dir=cfg.get("bert_pretrained_dir",
|
| 191 |
+
"GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"),
|
| 192 |
+
ssl_pretrained_dir=cfg.get("ssl_pretrained_dir",
|
| 193 |
+
"GPT_SoVITS/pretrained_models/chinese-hubert-base"),
|
| 194 |
+
pretrained_s2G=cfg.get("pretrained_s2G",
|
| 195 |
+
"GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth"),
|
| 196 |
+
)),
|
| 197 |
+
|
| 198 |
+
"sovits_train": lambda cfg: SoVITSTrainStage(SoVITSTrainConfig(
|
| 199 |
+
**base_params,
|
| 200 |
+
version=version,
|
| 201 |
+
batch_size=cfg.get("batch_size", 4),
|
| 202 |
+
total_epoch=cfg.get("total_epoch", 8),
|
| 203 |
+
text_low_lr_rate=cfg.get("text_low_lr_rate", 0.4),
|
| 204 |
+
save_every_epoch=cfg.get("save_every_epoch", 4),
|
| 205 |
+
if_save_latest=cfg.get("if_save_latest", True),
|
| 206 |
+
if_save_every_weights=cfg.get("if_save_every_weights", True),
|
| 207 |
+
pretrained_s2G=cfg.get("pretrained_s2G",
|
| 208 |
+
"GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth"),
|
| 209 |
+
pretrained_s2D=cfg.get("pretrained_s2D",
|
| 210 |
+
"GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2D2333k.pth"),
|
| 211 |
+
if_grad_ckpt=cfg.get("if_grad_ckpt", False),
|
| 212 |
+
lora_rank=cfg.get("lora_rank", 32),
|
| 213 |
+
)),
|
| 214 |
+
|
| 215 |
+
"gpt_train": lambda cfg: GPTTrainStage(GPTTrainConfig(
|
| 216 |
+
**base_params,
|
| 217 |
+
version=version,
|
| 218 |
+
batch_size=cfg.get("batch_size", 4),
|
| 219 |
+
total_epoch=cfg.get("total_epoch", 15),
|
| 220 |
+
save_every_epoch=cfg.get("save_every_epoch", 5),
|
| 221 |
+
if_save_latest=cfg.get("if_save_latest", True),
|
| 222 |
+
if_save_every_weights=cfg.get("if_save_every_weights", True),
|
| 223 |
+
if_dpo=cfg.get("if_dpo", False),
|
| 224 |
+
pretrained_s1=cfg.get("pretrained_s1",
|
| 225 |
+
"GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt"),
|
| 226 |
+
)),
|
| 227 |
+
|
| 228 |
+
"inference": lambda cfg: InferenceStage(InferenceConfig(
|
| 229 |
+
**base_params,
|
| 230 |
+
version=version,
|
| 231 |
+
gpt_path=cfg.get("gpt_path", ""),
|
| 232 |
+
sovits_path=cfg.get("sovits_path", ""),
|
| 233 |
+
ref_text=cfg.get("ref_text", ""),
|
| 234 |
+
ref_audio_path=cfg.get("ref_audio_path", ""),
|
| 235 |
+
target_text=cfg.get("target_text", ""),
|
| 236 |
+
text_split_method=cfg.get("text_split_method", "cut1"),
|
| 237 |
+
)),
|
| 238 |
+
}
|
| 239 |
+
|
| 240 |
+
# 按顺序添加阶段
|
| 241 |
+
stages = config.get("stages", [])
|
| 242 |
+
for stage_config in stages:
|
| 243 |
+
stage_type = stage_config.get("type")
|
| 244 |
+
if stage_type in stage_builders:
|
| 245 |
+
stage = stage_builders[stage_type](stage_config)
|
| 246 |
+
pipeline.add_stage(stage)
|
| 247 |
+
emit_log("info", f"已添加阶段: {stage.name}")
|
| 248 |
+
else:
|
| 249 |
+
emit_log("warning", f"未知阶段类型: {stage_type}")
|
| 250 |
+
|
| 251 |
+
return pipeline
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
def run_pipeline(config: Dict[str, Any], task_id: str) -> bool:
|
| 255 |
+
"""
|
| 256 |
+
执行 Pipeline
|
| 257 |
+
|
| 258 |
+
Args:
|
| 259 |
+
config: 配置字典
|
| 260 |
+
task_id: 任务ID
|
| 261 |
+
|
| 262 |
+
Returns:
|
| 263 |
+
是否成功完成
|
| 264 |
+
"""
|
| 265 |
+
emit_progress({
|
| 266 |
+
"type": "progress",
|
| 267 |
+
"status": "running",
|
| 268 |
+
"message": "正在初始化训练流水线...",
|
| 269 |
+
"task_id": task_id,
|
| 270 |
+
"progress": 0.0,
|
| 271 |
+
"overall_progress": 0.0,
|
| 272 |
+
})
|
| 273 |
+
|
| 274 |
+
try:
|
| 275 |
+
pipeline = build_pipeline(config)
|
| 276 |
+
|
| 277 |
+
stages = pipeline.get_stages()
|
| 278 |
+
if not stages:
|
| 279 |
+
emit_progress({
|
| 280 |
+
"type": "progress",
|
| 281 |
+
"status": "failed",
|
| 282 |
+
"message": "没有配置任何训练阶段",
|
| 283 |
+
"task_id": task_id,
|
| 284 |
+
})
|
| 285 |
+
return False
|
| 286 |
+
|
| 287 |
+
emit_log("info", f"训练流水线已初始化,共 {len(stages)} 个阶段")
|
| 288 |
+
|
| 289 |
+
# 执行 Pipeline
|
| 290 |
+
for progress in pipeline.run():
|
| 291 |
+
# 转换进度格式
|
| 292 |
+
emit_progress({
|
| 293 |
+
"type": "progress",
|
| 294 |
+
"status": "running",
|
| 295 |
+
"stage": progress.get("stage"),
|
| 296 |
+
"stage_index": progress.get("stage_index"),
|
| 297 |
+
"total_stages": progress.get("total_stages"),
|
| 298 |
+
"progress": progress.get("progress", 0.0),
|
| 299 |
+
"overall_progress": progress.get("overall_progress", 0.0),
|
| 300 |
+
"message": progress.get("message"),
|
| 301 |
+
"task_id": task_id,
|
| 302 |
+
"data": progress.get("data", {}),
|
| 303 |
+
})
|
| 304 |
+
|
| 305 |
+
# 检查是否失败
|
| 306 |
+
if progress.get("status") == "failed":
|
| 307 |
+
emit_progress({
|
| 308 |
+
"type": "progress",
|
| 309 |
+
"status": "failed",
|
| 310 |
+
"stage": progress.get("stage"),
|
| 311 |
+
"message": progress.get("message", "阶段执行失败"),
|
| 312 |
+
"task_id": task_id,
|
| 313 |
+
})
|
| 314 |
+
return False
|
| 315 |
+
|
| 316 |
+
# 完成
|
| 317 |
+
emit_progress({
|
| 318 |
+
"type": "progress",
|
| 319 |
+
"status": "completed",
|
| 320 |
+
"message": "训练流水线执行完成",
|
| 321 |
+
"task_id": task_id,
|
| 322 |
+
"progress": 1.0,
|
| 323 |
+
"overall_progress": 1.0,
|
| 324 |
+
})
|
| 325 |
+
return True
|
| 326 |
+
|
| 327 |
+
except Exception as e:
|
| 328 |
+
error_msg = str(e)
|
| 329 |
+
error_trace = traceback.format_exc()
|
| 330 |
+
emit_progress({
|
| 331 |
+
"type": "progress",
|
| 332 |
+
"status": "failed",
|
| 333 |
+
"message": f"执行出错: {error_msg}",
|
| 334 |
+
"error": error_msg,
|
| 335 |
+
"traceback": error_trace,
|
| 336 |
+
"task_id": task_id,
|
| 337 |
+
})
|
| 338 |
+
return False
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
def main():
|
| 342 |
+
"""主函数"""
|
| 343 |
+
parser = argparse.ArgumentParser(description="执行 GPT-SoVITS 训练流水线")
|
| 344 |
+
parser.add_argument("--config", required=True, help="配置文件路径 (JSON)")
|
| 345 |
+
parser.add_argument("--task-id", required=True, help="任务ID")
|
| 346 |
+
|
| 347 |
+
args = parser.parse_args()
|
| 348 |
+
|
| 349 |
+
emit_log("info", f"启动训练任务: {args.task_id}")
|
| 350 |
+
emit_log("info", f"配置文件: {args.config}")
|
| 351 |
+
|
| 352 |
+
try:
|
| 353 |
+
config = load_config(args.config)
|
| 354 |
+
except Exception as e:
|
| 355 |
+
emit_progress({
|
| 356 |
+
"type": "progress",
|
| 357 |
+
"status": "failed",
|
| 358 |
+
"message": f"加载配置文件失败: {e}",
|
| 359 |
+
"task_id": args.task_id,
|
| 360 |
+
})
|
| 361 |
+
sys.exit(1)
|
| 362 |
+
|
| 363 |
+
success = run_pipeline(config, args.task_id)
|
| 364 |
+
sys.exit(0 if success else 1)
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
if __name__ == "__main__":
|
| 368 |
+
main()
|