liumaolin
commited on
Commit
·
e054d0c
1
Parent(s):
e43edbb
feat(api): implement local training MVP with adapter pattern
Browse files- Add adapter base classes (TaskQueue, Progress, Storage, Database)
- Implement local adapters (AsyncTrainingManager, SQLite, LocalStorage)
- Add Pydantic schemas for tasks, experiments, files, and stages
- Implement service layer (TaskService, ExperimentService, FileService)
- Add API endpoints for Quick Mode (/tasks) and Advanced Mode (/experiments)
- Add adapter factory with dependency injection support
- Update architecture design document with implementation status
- api_server/app/adapters/base.py +454 -1
- api_server/app/adapters/local/__init__.py +16 -2
- api_server/app/adapters/local/database.py +683 -0
- api_server/app/adapters/local/progress.py +238 -0
- api_server/app/adapters/local/storage.py +342 -0
- api_server/app/adapters/local/task_queue.py +73 -2
- api_server/app/api/__init__.py +9 -0
- api_server/app/api/deps.py +96 -0
- api_server/app/api/v1/__init__.py +9 -0
- api_server/app/api/v1/endpoints/__init__.py +17 -0
- api_server/app/api/v1/endpoints/experiments.py +393 -0
- api_server/app/api/v1/endpoints/files.py +222 -0
- api_server/app/api/v1/endpoints/stages.py +247 -0
- api_server/app/api/v1/endpoints/tasks.py +228 -0
- api_server/app/api/v1/router.py +39 -0
- api_server/app/core/adapters.py +180 -0
- api_server/app/main.py +155 -0
- api_server/app/models/__init__.py +72 -1
- api_server/app/models/schemas/__init__.py +80 -0
- api_server/app/models/schemas/common.py +95 -0
- api_server/app/models/schemas/experiment.py +556 -0
- api_server/app/models/schemas/file.py +159 -0
- api_server/app/models/schemas/task.py +232 -0
- api_server/app/scripts/run_pipeline.py +16 -2
- api_server/app/services/__init__.py +20 -0
- api_server/app/services/experiment_service.py +513 -0
- api_server/app/services/file_service.py +277 -0
- api_server/app/services/task_service.py +322 -0
api_server/app/adapters/base.py
CHANGED
|
@@ -5,7 +5,10 @@
|
|
| 5 |
"""
|
| 6 |
|
| 7 |
from abc import ABC, abstractmethod
|
| 8 |
-
from typing import Dict, Optional, AsyncGenerator
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
|
| 11 |
class TaskQueueAdapter(ABC):
|
|
@@ -138,3 +141,453 @@ class ProgressAdapter(ABC):
|
|
| 138 |
进度信息字典
|
| 139 |
"""
|
| 140 |
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
"""
|
| 6 |
|
| 7 |
from abc import ABC, abstractmethod
|
| 8 |
+
from typing import TYPE_CHECKING, Dict, List, Optional, AsyncGenerator, Any
|
| 9 |
+
|
| 10 |
+
if TYPE_CHECKING:
|
| 11 |
+
from ..models.domain import Task
|
| 12 |
|
| 13 |
|
| 14 |
class TaskQueueAdapter(ABC):
|
|
|
|
| 141 |
进度信息字典
|
| 142 |
"""
|
| 143 |
pass
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
class StorageAdapter(ABC):
|
| 147 |
+
"""
|
| 148 |
+
存储适配器抽象基类
|
| 149 |
+
|
| 150 |
+
定义文件存储的通用接口,支持本地文件系统和
|
| 151 |
+
对象存储(S3/MinIO)两种实现方式。
|
| 152 |
+
|
| 153 |
+
Example:
|
| 154 |
+
>>> adapter = LocalStorageAdapter(base_path="./data/files")
|
| 155 |
+
>>> file_id = await adapter.upload_file(data, "audio.wav", {"purpose": "training"})
|
| 156 |
+
>>> content = await adapter.download_file(file_id)
|
| 157 |
+
>>> await adapter.delete_file(file_id)
|
| 158 |
+
"""
|
| 159 |
+
|
| 160 |
+
@abstractmethod
|
| 161 |
+
async def upload_file(
|
| 162 |
+
self,
|
| 163 |
+
file_data: bytes,
|
| 164 |
+
filename: str,
|
| 165 |
+
metadata: Dict[str, Any]
|
| 166 |
+
) -> str:
|
| 167 |
+
"""
|
| 168 |
+
上传文件
|
| 169 |
+
|
| 170 |
+
Args:
|
| 171 |
+
file_data: 文件二进制数据
|
| 172 |
+
filename: 原始文件名
|
| 173 |
+
metadata: 文件元数据,可包含:
|
| 174 |
+
- content_type: MIME类型
|
| 175 |
+
- purpose: 文件用途 (training, reference, output)
|
| 176 |
+
- 其他自定义字段
|
| 177 |
+
|
| 178 |
+
Returns:
|
| 179 |
+
file_id: 文件唯一标识
|
| 180 |
+
|
| 181 |
+
Raises:
|
| 182 |
+
IOError: 存储失败时抛出
|
| 183 |
+
"""
|
| 184 |
+
pass
|
| 185 |
+
|
| 186 |
+
@abstractmethod
|
| 187 |
+
async def download_file(self, file_id: str) -> bytes:
|
| 188 |
+
"""
|
| 189 |
+
下载文件
|
| 190 |
+
|
| 191 |
+
Args:
|
| 192 |
+
file_id: 文件唯一标识
|
| 193 |
+
|
| 194 |
+
Returns:
|
| 195 |
+
文件二进制数据
|
| 196 |
+
|
| 197 |
+
Raises:
|
| 198 |
+
FileNotFoundError: 文件不存在时抛出
|
| 199 |
+
"""
|
| 200 |
+
pass
|
| 201 |
+
|
| 202 |
+
@abstractmethod
|
| 203 |
+
async def delete_file(self, file_id: str) -> bool:
|
| 204 |
+
"""
|
| 205 |
+
删除文件
|
| 206 |
+
|
| 207 |
+
Args:
|
| 208 |
+
file_id: 文件唯一标识
|
| 209 |
+
|
| 210 |
+
Returns:
|
| 211 |
+
是否成功删除
|
| 212 |
+
"""
|
| 213 |
+
pass
|
| 214 |
+
|
| 215 |
+
@abstractmethod
|
| 216 |
+
async def get_file_metadata(self, file_id: str) -> Optional[Dict[str, Any]]:
|
| 217 |
+
"""
|
| 218 |
+
获取文件元数据
|
| 219 |
+
|
| 220 |
+
Args:
|
| 221 |
+
file_id: 文件唯一标识
|
| 222 |
+
|
| 223 |
+
Returns:
|
| 224 |
+
文件元数据字典,包含:
|
| 225 |
+
- id: 文件ID
|
| 226 |
+
- filename: 原始文件名
|
| 227 |
+
- content_type: MIME类型
|
| 228 |
+
- size_bytes: 文件大小
|
| 229 |
+
- purpose: 文件用途
|
| 230 |
+
- uploaded_at: 上传时间
|
| 231 |
+
- 音频文件额外包含: duration_seconds, sample_rate
|
| 232 |
+
|
| 233 |
+
文件不存在时返回 None
|
| 234 |
+
"""
|
| 235 |
+
pass
|
| 236 |
+
|
| 237 |
+
@abstractmethod
|
| 238 |
+
async def list_files(
|
| 239 |
+
self,
|
| 240 |
+
purpose: Optional[str] = None,
|
| 241 |
+
limit: int = 50,
|
| 242 |
+
offset: int = 0
|
| 243 |
+
) -> List[Dict[str, Any]]:
|
| 244 |
+
"""
|
| 245 |
+
列出文件
|
| 246 |
+
|
| 247 |
+
Args:
|
| 248 |
+
purpose: 按用途筛选 (training, reference, output)
|
| 249 |
+
limit: 返回数量限制
|
| 250 |
+
offset: 偏移量
|
| 251 |
+
|
| 252 |
+
Returns:
|
| 253 |
+
文件元数据列表
|
| 254 |
+
"""
|
| 255 |
+
pass
|
| 256 |
+
|
| 257 |
+
@abstractmethod
|
| 258 |
+
async def file_exists(self, file_id: str) -> bool:
|
| 259 |
+
"""
|
| 260 |
+
检查文件是否存在
|
| 261 |
+
|
| 262 |
+
Args:
|
| 263 |
+
file_id: 文件唯一标识
|
| 264 |
+
|
| 265 |
+
Returns:
|
| 266 |
+
文件是否存在
|
| 267 |
+
"""
|
| 268 |
+
pass
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
class DatabaseAdapter(ABC):
|
| 272 |
+
"""
|
| 273 |
+
数据库适配器抽象基类
|
| 274 |
+
|
| 275 |
+
定义数据持久化的通用接口,支持 SQLite 和
|
| 276 |
+
PostgreSQL 两种实现方式。
|
| 277 |
+
|
| 278 |
+
管理以下实体:
|
| 279 |
+
- Task: Quick Mode 一键训练任务
|
| 280 |
+
- Experiment: Advanced Mode 实验
|
| 281 |
+
- Stage: 实验中的各个阶段
|
| 282 |
+
- File: 上传的文件记录(可选,与StorageAdapter配合)
|
| 283 |
+
|
| 284 |
+
Example:
|
| 285 |
+
>>> adapter = SQLiteAdapter(db_path="./data/app.db")
|
| 286 |
+
>>> task = await adapter.create_task(task_data)
|
| 287 |
+
>>> task = await adapter.get_task(task_id)
|
| 288 |
+
>>> await adapter.update_task(task_id, {"status": "completed"})
|
| 289 |
+
"""
|
| 290 |
+
|
| 291 |
+
# ============================================================
|
| 292 |
+
# Task CRUD (Quick Mode)
|
| 293 |
+
# ============================================================
|
| 294 |
+
|
| 295 |
+
@abstractmethod
|
| 296 |
+
async def create_task(self, task: "Task") -> "Task":
|
| 297 |
+
"""
|
| 298 |
+
创建任务
|
| 299 |
+
|
| 300 |
+
Args:
|
| 301 |
+
task: Task 领域模型实例
|
| 302 |
+
|
| 303 |
+
Returns:
|
| 304 |
+
创建后的 Task 实例(包含生成的字段如 created_at)
|
| 305 |
+
"""
|
| 306 |
+
pass
|
| 307 |
+
|
| 308 |
+
@abstractmethod
|
| 309 |
+
async def get_task(self, task_id: str) -> Optional["Task"]:
|
| 310 |
+
"""
|
| 311 |
+
获取任务
|
| 312 |
+
|
| 313 |
+
Args:
|
| 314 |
+
task_id: 任务唯一标识
|
| 315 |
+
|
| 316 |
+
Returns:
|
| 317 |
+
Task 实例���不存在则返回 None
|
| 318 |
+
"""
|
| 319 |
+
pass
|
| 320 |
+
|
| 321 |
+
@abstractmethod
|
| 322 |
+
async def update_task(self, task_id: str, updates: Dict[str, Any]) -> Optional["Task"]:
|
| 323 |
+
"""
|
| 324 |
+
更新任务
|
| 325 |
+
|
| 326 |
+
Args:
|
| 327 |
+
task_id: 任务唯一标识
|
| 328 |
+
updates: 要更新的字段字典
|
| 329 |
+
|
| 330 |
+
Returns:
|
| 331 |
+
更新后的 Task 实例,不存在则返回 None
|
| 332 |
+
"""
|
| 333 |
+
pass
|
| 334 |
+
|
| 335 |
+
@abstractmethod
|
| 336 |
+
async def list_tasks(
|
| 337 |
+
self,
|
| 338 |
+
status: Optional[str] = None,
|
| 339 |
+
limit: int = 50,
|
| 340 |
+
offset: int = 0
|
| 341 |
+
) -> List["Task"]:
|
| 342 |
+
"""
|
| 343 |
+
查询任务列表
|
| 344 |
+
|
| 345 |
+
Args:
|
| 346 |
+
status: 按状态筛选
|
| 347 |
+
limit: 返回数量限制
|
| 348 |
+
offset: 偏移量
|
| 349 |
+
|
| 350 |
+
Returns:
|
| 351 |
+
Task 实例列表
|
| 352 |
+
"""
|
| 353 |
+
pass
|
| 354 |
+
|
| 355 |
+
@abstractmethod
|
| 356 |
+
async def delete_task(self, task_id: str) -> bool:
|
| 357 |
+
"""
|
| 358 |
+
删除任务
|
| 359 |
+
|
| 360 |
+
Args:
|
| 361 |
+
task_id: 任务唯一标识
|
| 362 |
+
|
| 363 |
+
Returns:
|
| 364 |
+
是否成功删除
|
| 365 |
+
"""
|
| 366 |
+
pass
|
| 367 |
+
|
| 368 |
+
@abstractmethod
|
| 369 |
+
async def count_tasks(self, status: Optional[str] = None) -> int:
|
| 370 |
+
"""
|
| 371 |
+
统计任务数量
|
| 372 |
+
|
| 373 |
+
Args:
|
| 374 |
+
status: 按状态筛选
|
| 375 |
+
|
| 376 |
+
Returns:
|
| 377 |
+
任务数量
|
| 378 |
+
"""
|
| 379 |
+
pass
|
| 380 |
+
|
| 381 |
+
@abstractmethod
|
| 382 |
+
async def get_task_by_exp_name(self, exp_name: str) -> Optional["Task"]:
|
| 383 |
+
"""
|
| 384 |
+
根据实验名称获取任务
|
| 385 |
+
|
| 386 |
+
用于检查 exp_name 是否已存在。
|
| 387 |
+
|
| 388 |
+
Args:
|
| 389 |
+
exp_name: 实验名称
|
| 390 |
+
|
| 391 |
+
Returns:
|
| 392 |
+
Task 实例,不存在则返回 None
|
| 393 |
+
"""
|
| 394 |
+
pass
|
| 395 |
+
|
| 396 |
+
# ============================================================
|
| 397 |
+
# Experiment CRUD (Advanced Mode)
|
| 398 |
+
# ============================================================
|
| 399 |
+
|
| 400 |
+
@abstractmethod
|
| 401 |
+
async def create_experiment(self, experiment: Dict[str, Any]) -> Dict[str, Any]:
|
| 402 |
+
"""
|
| 403 |
+
创建实验
|
| 404 |
+
|
| 405 |
+
Args:
|
| 406 |
+
experiment: 实验数据字典
|
| 407 |
+
|
| 408 |
+
Returns:
|
| 409 |
+
创建后的实验数据
|
| 410 |
+
"""
|
| 411 |
+
pass
|
| 412 |
+
|
| 413 |
+
@abstractmethod
|
| 414 |
+
async def get_experiment(self, exp_id: str) -> Optional[Dict[str, Any]]:
|
| 415 |
+
"""
|
| 416 |
+
获取实验
|
| 417 |
+
|
| 418 |
+
Args:
|
| 419 |
+
exp_id: 实验唯一标识
|
| 420 |
+
|
| 421 |
+
Returns:
|
| 422 |
+
实验数据字典,不存在则返回 None
|
| 423 |
+
"""
|
| 424 |
+
pass
|
| 425 |
+
|
| 426 |
+
@abstractmethod
|
| 427 |
+
async def update_experiment(
|
| 428 |
+
self,
|
| 429 |
+
exp_id: str,
|
| 430 |
+
updates: Dict[str, Any]
|
| 431 |
+
) -> Optional[Dict[str, Any]]:
|
| 432 |
+
"""
|
| 433 |
+
更新实验
|
| 434 |
+
|
| 435 |
+
Args:
|
| 436 |
+
exp_id: 实验唯一标识
|
| 437 |
+
updates: 要更新的字段字典
|
| 438 |
+
|
| 439 |
+
Returns:
|
| 440 |
+
更新后的实验数据,不存在则返回 None
|
| 441 |
+
"""
|
| 442 |
+
pass
|
| 443 |
+
|
| 444 |
+
@abstractmethod
|
| 445 |
+
async def list_experiments(
|
| 446 |
+
self,
|
| 447 |
+
status: Optional[str] = None,
|
| 448 |
+
limit: int = 50,
|
| 449 |
+
offset: int = 0
|
| 450 |
+
) -> List[Dict[str, Any]]:
|
| 451 |
+
"""
|
| 452 |
+
查询实验列表
|
| 453 |
+
|
| 454 |
+
Args:
|
| 455 |
+
status: 按状态筛选
|
| 456 |
+
limit: 返回数量限制
|
| 457 |
+
offset: 偏移量
|
| 458 |
+
|
| 459 |
+
Returns:
|
| 460 |
+
实验数据列表
|
| 461 |
+
"""
|
| 462 |
+
pass
|
| 463 |
+
|
| 464 |
+
@abstractmethod
|
| 465 |
+
async def delete_experiment(self, exp_id: str) -> bool:
|
| 466 |
+
"""
|
| 467 |
+
删除实验
|
| 468 |
+
|
| 469 |
+
Args:
|
| 470 |
+
exp_id: 实验唯一标识
|
| 471 |
+
|
| 472 |
+
Returns:
|
| 473 |
+
是否成功删除
|
| 474 |
+
"""
|
| 475 |
+
pass
|
| 476 |
+
|
| 477 |
+
# ============================================================
|
| 478 |
+
# Stage 操作 (Advanced Mode)
|
| 479 |
+
# ============================================================
|
| 480 |
+
|
| 481 |
+
@abstractmethod
|
| 482 |
+
async def update_stage(
|
| 483 |
+
self,
|
| 484 |
+
exp_id: str,
|
| 485 |
+
stage_type: str,
|
| 486 |
+
updates: Dict[str, Any]
|
| 487 |
+
) -> Optional[Dict[str, Any]]:
|
| 488 |
+
"""
|
| 489 |
+
更新阶段状态
|
| 490 |
+
|
| 491 |
+
Args:
|
| 492 |
+
exp_id: 实验唯一标识
|
| 493 |
+
stage_type: 阶段类型
|
| 494 |
+
updates: 要更新的字段字典
|
| 495 |
+
|
| 496 |
+
Returns:
|
| 497 |
+
更新后的阶段数据,不存在则返回 None
|
| 498 |
+
"""
|
| 499 |
+
pass
|
| 500 |
+
|
| 501 |
+
@abstractmethod
|
| 502 |
+
async def get_stage(
|
| 503 |
+
self,
|
| 504 |
+
exp_id: str,
|
| 505 |
+
stage_type: str
|
| 506 |
+
) -> Optional[Dict[str, Any]]:
|
| 507 |
+
"""
|
| 508 |
+
获取阶段状态
|
| 509 |
+
|
| 510 |
+
Args:
|
| 511 |
+
exp_id: 实验唯一标识
|
| 512 |
+
stage_type: 阶段类型
|
| 513 |
+
|
| 514 |
+
Returns:
|
| 515 |
+
阶段数据字典,不存在则返回 None
|
| 516 |
+
"""
|
| 517 |
+
pass
|
| 518 |
+
|
| 519 |
+
@abstractmethod
|
| 520 |
+
async def get_all_stages(self, exp_id: str) -> List[Dict[str, Any]]:
|
| 521 |
+
"""
|
| 522 |
+
获取实验的所有阶段状态
|
| 523 |
+
|
| 524 |
+
Args:
|
| 525 |
+
exp_id: 实验唯一标识
|
| 526 |
+
|
| 527 |
+
Returns:
|
| 528 |
+
阶段数据列表
|
| 529 |
+
"""
|
| 530 |
+
pass
|
| 531 |
+
|
| 532 |
+
# ============================================================
|
| 533 |
+
# File 记录 (可选,与 StorageAdapter 配合)
|
| 534 |
+
# ============================================================
|
| 535 |
+
|
| 536 |
+
@abstractmethod
|
| 537 |
+
async def create_file_record(self, file_data: Dict[str, Any]) -> Dict[str, Any]:
|
| 538 |
+
"""
|
| 539 |
+
创建文件记录
|
| 540 |
+
|
| 541 |
+
Args:
|
| 542 |
+
file_data: 文件元数据
|
| 543 |
+
|
| 544 |
+
Returns:
|
| 545 |
+
创建后的文件记录
|
| 546 |
+
"""
|
| 547 |
+
pass
|
| 548 |
+
|
| 549 |
+
@abstractmethod
|
| 550 |
+
async def get_file_record(self, file_id: str) -> Optional[Dict[str, Any]]:
|
| 551 |
+
"""
|
| 552 |
+
获取文件记录
|
| 553 |
+
|
| 554 |
+
Args:
|
| 555 |
+
file_id: 文件唯一标识
|
| 556 |
+
|
| 557 |
+
Returns:
|
| 558 |
+
文件记录,不存在则返回 None
|
| 559 |
+
"""
|
| 560 |
+
pass
|
| 561 |
+
|
| 562 |
+
@abstractmethod
|
| 563 |
+
async def delete_file_record(self, file_id: str) -> bool:
|
| 564 |
+
"""
|
| 565 |
+
删除文件记录
|
| 566 |
+
|
| 567 |
+
Args:
|
| 568 |
+
file_id: 文件唯一标识
|
| 569 |
+
|
| 570 |
+
Returns:
|
| 571 |
+
是否成功删除
|
| 572 |
+
"""
|
| 573 |
+
pass
|
| 574 |
+
|
| 575 |
+
@abstractmethod
|
| 576 |
+
async def list_file_records(
|
| 577 |
+
self,
|
| 578 |
+
purpose: Optional[str] = None,
|
| 579 |
+
limit: int = 50,
|
| 580 |
+
offset: int = 0
|
| 581 |
+
) -> List[Dict[str, Any]]:
|
| 582 |
+
"""
|
| 583 |
+
查询文件记录列表
|
| 584 |
+
|
| 585 |
+
Args:
|
| 586 |
+
purpose: 按用途筛选
|
| 587 |
+
limit: 返回数量限制
|
| 588 |
+
offset: 偏移量
|
| 589 |
+
|
| 590 |
+
Returns:
|
| 591 |
+
文件记录列表
|
| 592 |
+
"""
|
| 593 |
+
pass
|
api_server/app/adapters/local/__init__.py
CHANGED
|
@@ -1,9 +1,23 @@
|
|
| 1 |
"""
|
| 2 |
本地适配器模块
|
| 3 |
|
| 4 |
-
提供基于 SQLite 和 asyncio.subprocess
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
"""
|
| 6 |
|
| 7 |
from .task_queue import AsyncTrainingManager
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
-
__all__ = [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
"""
|
| 2 |
本地适配器模块
|
| 3 |
|
| 4 |
+
提供基于 SQLite 和 asyncio.subprocess 的本地实现。
|
| 5 |
+
|
| 6 |
+
适配器列表:
|
| 7 |
+
- AsyncTrainingManager: 任务队列适配器(基于 asyncio.subprocess)
|
| 8 |
+
- LocalStorageAdapter: 文件存储适配器(基于本地文件系统)
|
| 9 |
+
- SQLiteAdapter: 数据库适配器(基于 SQLite)
|
| 10 |
+
- LocalProgressAdapter: 进度管理适配器(基于内存队列)
|
| 11 |
"""
|
| 12 |
|
| 13 |
from .task_queue import AsyncTrainingManager
|
| 14 |
+
from .storage import LocalStorageAdapter
|
| 15 |
+
from .database import SQLiteAdapter
|
| 16 |
+
from .progress import LocalProgressAdapter
|
| 17 |
|
| 18 |
+
__all__ = [
|
| 19 |
+
"AsyncTrainingManager",
|
| 20 |
+
"LocalStorageAdapter",
|
| 21 |
+
"SQLiteAdapter",
|
| 22 |
+
"LocalProgressAdapter",
|
| 23 |
+
]
|
api_server/app/adapters/local/database.py
ADDED
|
@@ -0,0 +1,683 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
SQLite 数据库适配器
|
| 3 |
+
|
| 4 |
+
基于 SQLite + aiosqlite 实现的数据库适配器,适用于 macOS 本地训练场景。
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import json
|
| 8 |
+
import sqlite3
|
| 9 |
+
import uuid
|
| 10 |
+
from datetime import datetime
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
from typing import Any, Dict, List, Optional
|
| 13 |
+
|
| 14 |
+
import aiosqlite
|
| 15 |
+
|
| 16 |
+
from ..base import DatabaseAdapter
|
| 17 |
+
from ...core.config import settings
|
| 18 |
+
from ...models.domain import Task, TaskStatus
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# 阶段类型列表
|
| 22 |
+
STAGE_TYPES = [
|
| 23 |
+
"audio_slice",
|
| 24 |
+
"asr",
|
| 25 |
+
"text_feature",
|
| 26 |
+
"hubert_feature",
|
| 27 |
+
"semantic_token",
|
| 28 |
+
"sovits_train",
|
| 29 |
+
"gpt_train",
|
| 30 |
+
]
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class SQLiteAdapter(DatabaseAdapter):
|
| 34 |
+
"""
|
| 35 |
+
SQLite 数据库适配器
|
| 36 |
+
|
| 37 |
+
特点:
|
| 38 |
+
1. 使用 aiosqlite 实现异步数据库操作
|
| 39 |
+
2. 支持 Task (Quick Mode) 和 Experiment (Advanced Mode) 管理
|
| 40 |
+
3. 自动初始化数据库表结构
|
| 41 |
+
|
| 42 |
+
表结构:
|
| 43 |
+
- tasks: Quick Mode 任务
|
| 44 |
+
- experiments: Advanced Mode 实验
|
| 45 |
+
- stages: 实验阶段状态
|
| 46 |
+
- files: 文件记录
|
| 47 |
+
|
| 48 |
+
Example:
|
| 49 |
+
>>> adapter = SQLiteAdapter()
|
| 50 |
+
>>> task = Task(id="task-123", exp_name="my_voice", config={})
|
| 51 |
+
>>> await adapter.create_task(task)
|
| 52 |
+
>>> task = await adapter.get_task("task-123")
|
| 53 |
+
"""
|
| 54 |
+
|
| 55 |
+
def __init__(self, db_path: Optional[str] = None):
|
| 56 |
+
"""
|
| 57 |
+
初始化 SQLite 适配器
|
| 58 |
+
|
| 59 |
+
Args:
|
| 60 |
+
db_path: 数据库文件路径,默认使用 settings.SQLITE_PATH
|
| 61 |
+
"""
|
| 62 |
+
if db_path:
|
| 63 |
+
self.db_path = db_path
|
| 64 |
+
else:
|
| 65 |
+
self.db_path = str(settings.SQLITE_PATH)
|
| 66 |
+
|
| 67 |
+
# 确保目录存在
|
| 68 |
+
Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)
|
| 69 |
+
|
| 70 |
+
# 同步初始化数据库
|
| 71 |
+
self._init_db_sync()
|
| 72 |
+
|
| 73 |
+
def _init_db_sync(self) -> None:
|
| 74 |
+
"""同步初始化数据库表结构"""
|
| 75 |
+
with sqlite3.connect(self.db_path) as conn:
|
| 76 |
+
# Tasks 表 (Quick Mode)
|
| 77 |
+
conn.execute('''
|
| 78 |
+
CREATE TABLE IF NOT EXISTS tasks (
|
| 79 |
+
id TEXT PRIMARY KEY,
|
| 80 |
+
job_id TEXT,
|
| 81 |
+
exp_name TEXT NOT NULL,
|
| 82 |
+
status TEXT NOT NULL DEFAULT 'queued',
|
| 83 |
+
config TEXT,
|
| 84 |
+
current_stage TEXT,
|
| 85 |
+
progress REAL DEFAULT 0,
|
| 86 |
+
stage_progress REAL DEFAULT 0,
|
| 87 |
+
message TEXT,
|
| 88 |
+
error_message TEXT,
|
| 89 |
+
created_at TEXT NOT NULL,
|
| 90 |
+
started_at TEXT,
|
| 91 |
+
completed_at TEXT
|
| 92 |
+
)
|
| 93 |
+
''')
|
| 94 |
+
conn.execute('CREATE INDEX IF NOT EXISTS idx_tasks_status ON tasks(status)')
|
| 95 |
+
conn.execute('CREATE INDEX IF NOT EXISTS idx_tasks_created ON tasks(created_at)')
|
| 96 |
+
|
| 97 |
+
# Experiments 表 (Advanced Mode)
|
| 98 |
+
conn.execute('''
|
| 99 |
+
CREATE TABLE IF NOT EXISTS experiments (
|
| 100 |
+
id TEXT PRIMARY KEY,
|
| 101 |
+
exp_name TEXT NOT NULL,
|
| 102 |
+
version TEXT NOT NULL DEFAULT 'v2',
|
| 103 |
+
exp_root TEXT DEFAULT 'logs',
|
| 104 |
+
gpu_numbers TEXT DEFAULT '0',
|
| 105 |
+
is_half INTEGER DEFAULT 1,
|
| 106 |
+
audio_file_id TEXT,
|
| 107 |
+
status TEXT NOT NULL DEFAULT 'created',
|
| 108 |
+
created_at TEXT NOT NULL,
|
| 109 |
+
updated_at TEXT
|
| 110 |
+
)
|
| 111 |
+
''')
|
| 112 |
+
conn.execute('CREATE INDEX IF NOT EXISTS idx_experiments_status ON experiments(status)')
|
| 113 |
+
conn.execute('CREATE INDEX IF NOT EXISTS idx_experiments_created ON experiments(created_at)')
|
| 114 |
+
|
| 115 |
+
# Stages 表 (Advanced Mode 阶段状态)
|
| 116 |
+
conn.execute('''
|
| 117 |
+
CREATE TABLE IF NOT EXISTS stages (
|
| 118 |
+
id TEXT PRIMARY KEY,
|
| 119 |
+
experiment_id TEXT NOT NULL,
|
| 120 |
+
stage_type TEXT NOT NULL,
|
| 121 |
+
status TEXT DEFAULT 'pending',
|
| 122 |
+
progress REAL DEFAULT 0,
|
| 123 |
+
message TEXT,
|
| 124 |
+
job_id TEXT,
|
| 125 |
+
config TEXT,
|
| 126 |
+
outputs TEXT,
|
| 127 |
+
started_at TEXT,
|
| 128 |
+
completed_at TEXT,
|
| 129 |
+
error_message TEXT,
|
| 130 |
+
FOREIGN KEY (experiment_id) REFERENCES experiments(id) ON DELETE CASCADE,
|
| 131 |
+
UNIQUE (experiment_id, stage_type)
|
| 132 |
+
)
|
| 133 |
+
''')
|
| 134 |
+
conn.execute('CREATE INDEX IF NOT EXISTS idx_stages_experiment ON stages(experiment_id)')
|
| 135 |
+
conn.execute('CREATE INDEX IF NOT EXISTS idx_stages_status ON stages(status)')
|
| 136 |
+
|
| 137 |
+
# Files 表 (文件记录)
|
| 138 |
+
conn.execute('''
|
| 139 |
+
CREATE TABLE IF NOT EXISTS files (
|
| 140 |
+
id TEXT PRIMARY KEY,
|
| 141 |
+
filename TEXT NOT NULL,
|
| 142 |
+
content_type TEXT,
|
| 143 |
+
size_bytes INTEGER DEFAULT 0,
|
| 144 |
+
purpose TEXT DEFAULT 'training',
|
| 145 |
+
duration_seconds REAL,
|
| 146 |
+
sample_rate INTEGER,
|
| 147 |
+
storage_path TEXT,
|
| 148 |
+
uploaded_at TEXT NOT NULL
|
| 149 |
+
)
|
| 150 |
+
''')
|
| 151 |
+
conn.execute('CREATE INDEX IF NOT EXISTS idx_files_purpose ON files(purpose)')
|
| 152 |
+
conn.execute('CREATE INDEX IF NOT EXISTS idx_files_uploaded ON files(uploaded_at)')
|
| 153 |
+
|
| 154 |
+
conn.commit()
|
| 155 |
+
|
| 156 |
+
# ============================================================
|
| 157 |
+
# Task CRUD (Quick Mode)
|
| 158 |
+
# ============================================================
|
| 159 |
+
|
| 160 |
+
async def create_task(self, task: Task) -> Task:
|
| 161 |
+
"""创建任务"""
|
| 162 |
+
async with aiosqlite.connect(self.db_path) as db:
|
| 163 |
+
await db.execute(
|
| 164 |
+
'''INSERT INTO tasks
|
| 165 |
+
(id, job_id, exp_name, status, config, current_stage,
|
| 166 |
+
progress, stage_progress, message, error_message,
|
| 167 |
+
created_at, started_at, completed_at)
|
| 168 |
+
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)''',
|
| 169 |
+
(
|
| 170 |
+
task.id,
|
| 171 |
+
task.job_id,
|
| 172 |
+
task.exp_name,
|
| 173 |
+
task.status.value if isinstance(task.status, TaskStatus) else task.status,
|
| 174 |
+
json.dumps(task.config, ensure_ascii=False) if task.config else None,
|
| 175 |
+
task.current_stage,
|
| 176 |
+
task.progress,
|
| 177 |
+
task.stage_progress,
|
| 178 |
+
task.message,
|
| 179 |
+
task.error_message,
|
| 180 |
+
task.created_at.isoformat() if task.created_at else datetime.utcnow().isoformat(),
|
| 181 |
+
task.started_at.isoformat() if task.started_at else None,
|
| 182 |
+
task.completed_at.isoformat() if task.completed_at else None,
|
| 183 |
+
)
|
| 184 |
+
)
|
| 185 |
+
await db.commit()
|
| 186 |
+
|
| 187 |
+
return task
|
| 188 |
+
|
| 189 |
+
async def get_task(self, task_id: str) -> Optional[Task]:
|
| 190 |
+
"""获取任务"""
|
| 191 |
+
async with aiosqlite.connect(self.db_path) as db:
|
| 192 |
+
db.row_factory = aiosqlite.Row
|
| 193 |
+
async with db.execute(
|
| 194 |
+
"SELECT * FROM tasks WHERE id = ?", (task_id,)
|
| 195 |
+
) as cursor:
|
| 196 |
+
row = await cursor.fetchone()
|
| 197 |
+
if row:
|
| 198 |
+
return self._row_to_task(dict(row))
|
| 199 |
+
return None
|
| 200 |
+
|
| 201 |
+
async def update_task(self, task_id: str, updates: Dict[str, Any]) -> Optional[Task]:
|
| 202 |
+
"""更新任务"""
|
| 203 |
+
if not updates:
|
| 204 |
+
return await self.get_task(task_id)
|
| 205 |
+
|
| 206 |
+
# 处理特殊字段
|
| 207 |
+
processed = {}
|
| 208 |
+
for key, value in updates.items():
|
| 209 |
+
if key == "status" and isinstance(value, TaskStatus):
|
| 210 |
+
processed[key] = value.value
|
| 211 |
+
elif key == "config" and isinstance(value, dict):
|
| 212 |
+
processed[key] = json.dumps(value, ensure_ascii=False)
|
| 213 |
+
elif key in ("created_at", "started_at", "completed_at") and isinstance(value, datetime):
|
| 214 |
+
processed[key] = value.isoformat()
|
| 215 |
+
else:
|
| 216 |
+
processed[key] = value
|
| 217 |
+
|
| 218 |
+
async with aiosqlite.connect(self.db_path) as db:
|
| 219 |
+
set_clause = ", ".join(f"{k} = ?" for k in processed.keys())
|
| 220 |
+
values = list(processed.values()) + [task_id]
|
| 221 |
+
|
| 222 |
+
await db.execute(
|
| 223 |
+
f"UPDATE tasks SET {set_clause} WHERE id = ?",
|
| 224 |
+
values
|
| 225 |
+
)
|
| 226 |
+
await db.commit()
|
| 227 |
+
|
| 228 |
+
return await self.get_task(task_id)
|
| 229 |
+
|
| 230 |
+
async def list_tasks(
|
| 231 |
+
self,
|
| 232 |
+
status: Optional[str] = None,
|
| 233 |
+
limit: int = 50,
|
| 234 |
+
offset: int = 0
|
| 235 |
+
) -> List[Task]:
|
| 236 |
+
"""查询任务列表"""
|
| 237 |
+
async with aiosqlite.connect(self.db_path) as db:
|
| 238 |
+
db.row_factory = aiosqlite.Row
|
| 239 |
+
|
| 240 |
+
if status:
|
| 241 |
+
query = """
|
| 242 |
+
SELECT * FROM tasks
|
| 243 |
+
WHERE status = ?
|
| 244 |
+
ORDER BY created_at DESC
|
| 245 |
+
LIMIT ? OFFSET ?
|
| 246 |
+
"""
|
| 247 |
+
params = (status, limit, offset)
|
| 248 |
+
else:
|
| 249 |
+
query = """
|
| 250 |
+
SELECT * FROM tasks
|
| 251 |
+
ORDER BY created_at DESC
|
| 252 |
+
LIMIT ? OFFSET ?
|
| 253 |
+
"""
|
| 254 |
+
params = (limit, offset)
|
| 255 |
+
|
| 256 |
+
async with db.execute(query, params) as cursor:
|
| 257 |
+
rows = await cursor.fetchall()
|
| 258 |
+
return [self._row_to_task(dict(row)) for row in rows]
|
| 259 |
+
|
| 260 |
+
async def delete_task(self, task_id: str) -> bool:
|
| 261 |
+
"""删除任务"""
|
| 262 |
+
async with aiosqlite.connect(self.db_path) as db:
|
| 263 |
+
cursor = await db.execute(
|
| 264 |
+
"DELETE FROM tasks WHERE id = ?", (task_id,)
|
| 265 |
+
)
|
| 266 |
+
await db.commit()
|
| 267 |
+
return cursor.rowcount > 0
|
| 268 |
+
|
| 269 |
+
async def count_tasks(self, status: Optional[str] = None) -> int:
|
| 270 |
+
"""统计任务数量"""
|
| 271 |
+
async with aiosqlite.connect(self.db_path) as db:
|
| 272 |
+
if status:
|
| 273 |
+
async with db.execute(
|
| 274 |
+
"SELECT COUNT(*) FROM tasks WHERE status = ?", (status,)
|
| 275 |
+
) as cursor:
|
| 276 |
+
row = await cursor.fetchone()
|
| 277 |
+
else:
|
| 278 |
+
async with db.execute("SELECT COUNT(*) FROM tasks") as cursor:
|
| 279 |
+
row = await cursor.fetchone()
|
| 280 |
+
|
| 281 |
+
return row[0] if row else 0
|
| 282 |
+
|
| 283 |
+
async def get_task_by_exp_name(self, exp_name: str) -> Optional[Task]:
|
| 284 |
+
"""根据实验名称获取任务"""
|
| 285 |
+
async with aiosqlite.connect(self.db_path) as db:
|
| 286 |
+
db.row_factory = aiosqlite.Row
|
| 287 |
+
async with db.execute(
|
| 288 |
+
"SELECT * FROM tasks WHERE exp_name = ? LIMIT 1", (exp_name,)
|
| 289 |
+
) as cursor:
|
| 290 |
+
row = await cursor.fetchone()
|
| 291 |
+
if row:
|
| 292 |
+
return self._row_to_task(dict(row))
|
| 293 |
+
return None
|
| 294 |
+
|
| 295 |
+
def _row_to_task(self, row: Dict[str, Any]) -> Task:
|
| 296 |
+
"""将数据库行转换为 Task 对象"""
|
| 297 |
+
# 解析 config JSON
|
| 298 |
+
config = row.get("config")
|
| 299 |
+
if config and isinstance(config, str):
|
| 300 |
+
try:
|
| 301 |
+
config = json.loads(config)
|
| 302 |
+
except json.JSONDecodeError:
|
| 303 |
+
config = {}
|
| 304 |
+
|
| 305 |
+
return Task.from_dict({
|
| 306 |
+
"id": row["id"],
|
| 307 |
+
"job_id": row.get("job_id"),
|
| 308 |
+
"exp_name": row["exp_name"],
|
| 309 |
+
"status": row.get("status", "queued"),
|
| 310 |
+
"config": config or {},
|
| 311 |
+
"current_stage": row.get("current_stage"),
|
| 312 |
+
"progress": row.get("progress", 0.0),
|
| 313 |
+
"stage_progress": row.get("stage_progress", 0.0),
|
| 314 |
+
"message": row.get("message"),
|
| 315 |
+
"error_message": row.get("error_message"),
|
| 316 |
+
"created_at": row.get("created_at"),
|
| 317 |
+
"started_at": row.get("started_at"),
|
| 318 |
+
"completed_at": row.get("completed_at"),
|
| 319 |
+
})
|
| 320 |
+
|
| 321 |
+
# ============================================================
|
| 322 |
+
# Experiment CRUD (Advanced Mode)
|
| 323 |
+
# ============================================================
|
| 324 |
+
|
| 325 |
+
async def create_experiment(self, experiment: Dict[str, Any]) -> Dict[str, Any]:
|
| 326 |
+
"""创建实验"""
|
| 327 |
+
exp_id = experiment.get("id") or f"exp-{uuid.uuid4().hex[:8]}"
|
| 328 |
+
now = datetime.utcnow().isoformat()
|
| 329 |
+
|
| 330 |
+
async with aiosqlite.connect(self.db_path) as db:
|
| 331 |
+
# 创建实验记录
|
| 332 |
+
await db.execute(
|
| 333 |
+
'''INSERT INTO experiments
|
| 334 |
+
(id, exp_name, version, exp_root, gpu_numbers, is_half,
|
| 335 |
+
audio_file_id, status, created_at, updated_at)
|
| 336 |
+
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)''',
|
| 337 |
+
(
|
| 338 |
+
exp_id,
|
| 339 |
+
experiment["exp_name"],
|
| 340 |
+
experiment.get("version", "v2"),
|
| 341 |
+
experiment.get("exp_root", "logs"),
|
| 342 |
+
experiment.get("gpu_numbers", "0"),
|
| 343 |
+
1 if experiment.get("is_half", True) else 0,
|
| 344 |
+
experiment.get("audio_file_id"),
|
| 345 |
+
experiment.get("status", "created"),
|
| 346 |
+
now,
|
| 347 |
+
now,
|
| 348 |
+
)
|
| 349 |
+
)
|
| 350 |
+
|
| 351 |
+
# 创建所有阶段的初始状态
|
| 352 |
+
for stage_type in STAGE_TYPES:
|
| 353 |
+
stage_id = f"{exp_id}-{stage_type}"
|
| 354 |
+
await db.execute(
|
| 355 |
+
'''INSERT INTO stages
|
| 356 |
+
(id, experiment_id, stage_type, status)
|
| 357 |
+
VALUES (?, ?, ?, 'pending')''',
|
| 358 |
+
(stage_id, exp_id, stage_type)
|
| 359 |
+
)
|
| 360 |
+
|
| 361 |
+
await db.commit()
|
| 362 |
+
|
| 363 |
+
return await self.get_experiment(exp_id)
|
| 364 |
+
|
| 365 |
+
async def get_experiment(self, exp_id: str) -> Optional[Dict[str, Any]]:
|
| 366 |
+
"""获取实验"""
|
| 367 |
+
async with aiosqlite.connect(self.db_path) as db:
|
| 368 |
+
db.row_factory = aiosqlite.Row
|
| 369 |
+
|
| 370 |
+
# 获取实验基本信息
|
| 371 |
+
async with db.execute(
|
| 372 |
+
"SELECT * FROM experiments WHERE id = ?", (exp_id,)
|
| 373 |
+
) as cursor:
|
| 374 |
+
row = await cursor.fetchone()
|
| 375 |
+
if not row:
|
| 376 |
+
return None
|
| 377 |
+
|
| 378 |
+
experiment = dict(row)
|
| 379 |
+
experiment["is_half"] = bool(experiment.get("is_half", 1))
|
| 380 |
+
|
| 381 |
+
# 获取所有阶段状态
|
| 382 |
+
stages = {}
|
| 383 |
+
async with db.execute(
|
| 384 |
+
"SELECT * FROM stages WHERE experiment_id = ?", (exp_id,)
|
| 385 |
+
) as cursor:
|
| 386 |
+
stage_rows = await cursor.fetchall()
|
| 387 |
+
for stage_row in stage_rows:
|
| 388 |
+
stage = dict(stage_row)
|
| 389 |
+
stage_type = stage["stage_type"]
|
| 390 |
+
|
| 391 |
+
# 解析 JSON 字段
|
| 392 |
+
for json_field in ("config", "outputs"):
|
| 393 |
+
if stage.get(json_field) and isinstance(stage[json_field], str):
|
| 394 |
+
try:
|
| 395 |
+
stage[json_field] = json.loads(stage[json_field])
|
| 396 |
+
except json.JSONDecodeError:
|
| 397 |
+
stage[json_field] = None
|
| 398 |
+
|
| 399 |
+
stages[stage_type] = stage
|
| 400 |
+
|
| 401 |
+
experiment["stages"] = stages
|
| 402 |
+
return experiment
|
| 403 |
+
|
| 404 |
+
async def update_experiment(
|
| 405 |
+
self,
|
| 406 |
+
exp_id: str,
|
| 407 |
+
updates: Dict[str, Any]
|
| 408 |
+
) -> Optional[Dict[str, Any]]:
|
| 409 |
+
"""更新实验"""
|
| 410 |
+
if not updates:
|
| 411 |
+
return await self.get_experiment(exp_id)
|
| 412 |
+
|
| 413 |
+
# 处理 is_half 布尔值
|
| 414 |
+
processed = {}
|
| 415 |
+
for key, value in updates.items():
|
| 416 |
+
if key == "is_half":
|
| 417 |
+
processed[key] = 1 if value else 0
|
| 418 |
+
elif key == "updated_at" and isinstance(value, datetime):
|
| 419 |
+
processed[key] = value.isoformat()
|
| 420 |
+
elif key != "stages": # stages 单独处理
|
| 421 |
+
processed[key] = value
|
| 422 |
+
|
| 423 |
+
# 添加更新时间
|
| 424 |
+
if "updated_at" not in processed:
|
| 425 |
+
processed["updated_at"] = datetime.utcnow().isoformat()
|
| 426 |
+
|
| 427 |
+
async with aiosqlite.connect(self.db_path) as db:
|
| 428 |
+
if processed:
|
| 429 |
+
set_clause = ", ".join(f"{k} = ?" for k in processed.keys())
|
| 430 |
+
values = list(processed.values()) + [exp_id]
|
| 431 |
+
|
| 432 |
+
await db.execute(
|
| 433 |
+
f"UPDATE experiments SET {set_clause} WHERE id = ?",
|
| 434 |
+
values
|
| 435 |
+
)
|
| 436 |
+
await db.commit()
|
| 437 |
+
|
| 438 |
+
return await self.get_experiment(exp_id)
|
| 439 |
+
|
| 440 |
+
async def list_experiments(
|
| 441 |
+
self,
|
| 442 |
+
status: Optional[str] = None,
|
| 443 |
+
limit: int = 50,
|
| 444 |
+
offset: int = 0
|
| 445 |
+
) -> List[Dict[str, Any]]:
|
| 446 |
+
"""查询实验列表"""
|
| 447 |
+
async with aiosqlite.connect(self.db_path) as db:
|
| 448 |
+
db.row_factory = aiosqlite.Row
|
| 449 |
+
|
| 450 |
+
if status:
|
| 451 |
+
query = """
|
| 452 |
+
SELECT * FROM experiments
|
| 453 |
+
WHERE status = ?
|
| 454 |
+
ORDER BY created_at DESC
|
| 455 |
+
LIMIT ? OFFSET ?
|
| 456 |
+
"""
|
| 457 |
+
params = (status, limit, offset)
|
| 458 |
+
else:
|
| 459 |
+
query = """
|
| 460 |
+
SELECT * FROM experiments
|
| 461 |
+
ORDER BY created_at DESC
|
| 462 |
+
LIMIT ? OFFSET ?
|
| 463 |
+
"""
|
| 464 |
+
params = (limit, offset)
|
| 465 |
+
|
| 466 |
+
async with db.execute(query, params) as cursor:
|
| 467 |
+
rows = await cursor.fetchall()
|
| 468 |
+
|
| 469 |
+
results = []
|
| 470 |
+
for row in rows:
|
| 471 |
+
exp = dict(row)
|
| 472 |
+
exp["is_half"] = bool(exp.get("is_half", 1))
|
| 473 |
+
# 简化列表,不包含完整的 stages
|
| 474 |
+
results.append(exp)
|
| 475 |
+
|
| 476 |
+
return results
|
| 477 |
+
|
| 478 |
+
async def delete_experiment(self, exp_id: str) -> bool:
|
| 479 |
+
"""删除实验及其阶段"""
|
| 480 |
+
async with aiosqlite.connect(self.db_path) as db:
|
| 481 |
+
# 先删除阶段
|
| 482 |
+
await db.execute(
|
| 483 |
+
"DELETE FROM stages WHERE experiment_id = ?", (exp_id,)
|
| 484 |
+
)
|
| 485 |
+
|
| 486 |
+
# 再删除实验
|
| 487 |
+
cursor = await db.execute(
|
| 488 |
+
"DELETE FROM experiments WHERE id = ?", (exp_id,)
|
| 489 |
+
)
|
| 490 |
+
await db.commit()
|
| 491 |
+
return cursor.rowcount > 0
|
| 492 |
+
|
| 493 |
+
# ============================================================
|
| 494 |
+
# Stage 操作 (Advanced Mode)
|
| 495 |
+
# ============================================================
|
| 496 |
+
|
| 497 |
+
async def update_stage(
|
| 498 |
+
self,
|
| 499 |
+
exp_id: str,
|
| 500 |
+
stage_type: str,
|
| 501 |
+
updates: Dict[str, Any]
|
| 502 |
+
) -> Optional[Dict[str, Any]]:
|
| 503 |
+
"""更新阶段状态"""
|
| 504 |
+
if not updates:
|
| 505 |
+
return await self.get_stage(exp_id, stage_type)
|
| 506 |
+
|
| 507 |
+
# 处理 JSON 字段
|
| 508 |
+
processed = {}
|
| 509 |
+
for key, value in updates.items():
|
| 510 |
+
if key in ("config", "outputs") and isinstance(value, dict):
|
| 511 |
+
processed[key] = json.dumps(value, ensure_ascii=False)
|
| 512 |
+
elif key in ("started_at", "completed_at") and isinstance(value, datetime):
|
| 513 |
+
processed[key] = value.isoformat()
|
| 514 |
+
else:
|
| 515 |
+
processed[key] = value
|
| 516 |
+
|
| 517 |
+
async with aiosqlite.connect(self.db_path) as db:
|
| 518 |
+
set_clause = ", ".join(f"{k} = ?" for k in processed.keys())
|
| 519 |
+
values = list(processed.values()) + [exp_id, stage_type]
|
| 520 |
+
|
| 521 |
+
await db.execute(
|
| 522 |
+
f"UPDATE stages SET {set_clause} WHERE experiment_id = ? AND stage_type = ?",
|
| 523 |
+
values
|
| 524 |
+
)
|
| 525 |
+
await db.commit()
|
| 526 |
+
|
| 527 |
+
# 同时更新实验的 updated_at
|
| 528 |
+
await self.update_experiment(exp_id, {})
|
| 529 |
+
|
| 530 |
+
return await self.get_stage(exp_id, stage_type)
|
| 531 |
+
|
| 532 |
+
async def get_stage(
|
| 533 |
+
self,
|
| 534 |
+
exp_id: str,
|
| 535 |
+
stage_type: str
|
| 536 |
+
) -> Optional[Dict[str, Any]]:
|
| 537 |
+
"""获取阶段状态"""
|
| 538 |
+
async with aiosqlite.connect(self.db_path) as db:
|
| 539 |
+
db.row_factory = aiosqlite.Row
|
| 540 |
+
|
| 541 |
+
async with db.execute(
|
| 542 |
+
"SELECT * FROM stages WHERE experiment_id = ? AND stage_type = ?",
|
| 543 |
+
(exp_id, stage_type)
|
| 544 |
+
) as cursor:
|
| 545 |
+
row = await cursor.fetchone()
|
| 546 |
+
if not row:
|
| 547 |
+
return None
|
| 548 |
+
|
| 549 |
+
stage = dict(row)
|
| 550 |
+
|
| 551 |
+
# 解析 JSON 字段
|
| 552 |
+
for json_field in ("config", "outputs"):
|
| 553 |
+
if stage.get(json_field) and isinstance(stage[json_field], str):
|
| 554 |
+
try:
|
| 555 |
+
stage[json_field] = json.loads(stage[json_field])
|
| 556 |
+
except json.JSONDecodeError:
|
| 557 |
+
stage[json_field] = None
|
| 558 |
+
|
| 559 |
+
return stage
|
| 560 |
+
|
| 561 |
+
async def get_all_stages(self, exp_id: str) -> List[Dict[str, Any]]:
|
| 562 |
+
"""获取实验的所有阶段状态"""
|
| 563 |
+
async with aiosqlite.connect(self.db_path) as db:
|
| 564 |
+
db.row_factory = aiosqlite.Row
|
| 565 |
+
|
| 566 |
+
async with db.execute(
|
| 567 |
+
"SELECT * FROM stages WHERE experiment_id = ? ORDER BY id",
|
| 568 |
+
(exp_id,)
|
| 569 |
+
) as cursor:
|
| 570 |
+
rows = await cursor.fetchall()
|
| 571 |
+
|
| 572 |
+
results = []
|
| 573 |
+
for row in rows:
|
| 574 |
+
stage = dict(row)
|
| 575 |
+
|
| 576 |
+
# 解析 JSON 字段
|
| 577 |
+
for json_field in ("config", "outputs"):
|
| 578 |
+
if stage.get(json_field) and isinstance(stage[json_field], str):
|
| 579 |
+
try:
|
| 580 |
+
stage[json_field] = json.loads(stage[json_field])
|
| 581 |
+
except json.JSONDecodeError:
|
| 582 |
+
stage[json_field] = None
|
| 583 |
+
|
| 584 |
+
results.append(stage)
|
| 585 |
+
|
| 586 |
+
return results
|
| 587 |
+
|
| 588 |
+
# ============================================================
|
| 589 |
+
# File 记录
|
| 590 |
+
# ============================================================
|
| 591 |
+
|
| 592 |
+
async def create_file_record(self, file_data: Dict[str, Any]) -> Dict[str, Any]:
|
| 593 |
+
"""创建文件记录"""
|
| 594 |
+
file_id = file_data.get("id") or str(uuid.uuid4())
|
| 595 |
+
now = datetime.utcnow().isoformat()
|
| 596 |
+
|
| 597 |
+
async with aiosqlite.connect(self.db_path) as db:
|
| 598 |
+
await db.execute(
|
| 599 |
+
'''INSERT INTO files
|
| 600 |
+
(id, filename, content_type, size_bytes, purpose,
|
| 601 |
+
duration_seconds, sample_rate, storage_path, uploaded_at)
|
| 602 |
+
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)''',
|
| 603 |
+
(
|
| 604 |
+
file_id,
|
| 605 |
+
file_data["filename"],
|
| 606 |
+
file_data.get("content_type"),
|
| 607 |
+
file_data.get("size_bytes", 0),
|
| 608 |
+
file_data.get("purpose", "training"),
|
| 609 |
+
file_data.get("duration_seconds"),
|
| 610 |
+
file_data.get("sample_rate"),
|
| 611 |
+
file_data.get("storage_path"),
|
| 612 |
+
file_data.get("uploaded_at", now),
|
| 613 |
+
)
|
| 614 |
+
)
|
| 615 |
+
await db.commit()
|
| 616 |
+
|
| 617 |
+
return await self.get_file_record(file_id)
|
| 618 |
+
|
| 619 |
+
async def get_file_record(self, file_id: str) -> Optional[Dict[str, Any]]:
|
| 620 |
+
"""获取文件记录"""
|
| 621 |
+
async with aiosqlite.connect(self.db_path) as db:
|
| 622 |
+
db.row_factory = aiosqlite.Row
|
| 623 |
+
|
| 624 |
+
async with db.execute(
|
| 625 |
+
"SELECT * FROM files WHERE id = ?", (file_id,)
|
| 626 |
+
) as cursor:
|
| 627 |
+
row = await cursor.fetchone()
|
| 628 |
+
if row:
|
| 629 |
+
return dict(row)
|
| 630 |
+
return None
|
| 631 |
+
|
| 632 |
+
async def delete_file_record(self, file_id: str) -> bool:
|
| 633 |
+
"""删除文件记录"""
|
| 634 |
+
async with aiosqlite.connect(self.db_path) as db:
|
| 635 |
+
cursor = await db.execute(
|
| 636 |
+
"DELETE FROM files WHERE id = ?", (file_id,)
|
| 637 |
+
)
|
| 638 |
+
await db.commit()
|
| 639 |
+
return cursor.rowcount > 0
|
| 640 |
+
|
| 641 |
+
async def list_file_records(
|
| 642 |
+
self,
|
| 643 |
+
purpose: Optional[str] = None,
|
| 644 |
+
limit: int = 50,
|
| 645 |
+
offset: int = 0
|
| 646 |
+
) -> List[Dict[str, Any]]:
|
| 647 |
+
"""查询文件记录列表"""
|
| 648 |
+
async with aiosqlite.connect(self.db_path) as db:
|
| 649 |
+
db.row_factory = aiosqlite.Row
|
| 650 |
+
|
| 651 |
+
if purpose:
|
| 652 |
+
query = """
|
| 653 |
+
SELECT * FROM files
|
| 654 |
+
WHERE purpose = ?
|
| 655 |
+
ORDER BY uploaded_at DESC
|
| 656 |
+
LIMIT ? OFFSET ?
|
| 657 |
+
"""
|
| 658 |
+
params = (purpose, limit, offset)
|
| 659 |
+
else:
|
| 660 |
+
query = """
|
| 661 |
+
SELECT * FROM files
|
| 662 |
+
ORDER BY uploaded_at DESC
|
| 663 |
+
LIMIT ? OFFSET ?
|
| 664 |
+
"""
|
| 665 |
+
params = (limit, offset)
|
| 666 |
+
|
| 667 |
+
async with db.execute(query, params) as cursor:
|
| 668 |
+
rows = await cursor.fetchall()
|
| 669 |
+
return [dict(row) for row in rows]
|
| 670 |
+
|
| 671 |
+
async def count_file_records(self, purpose: Optional[str] = None) -> int:
|
| 672 |
+
"""统计文件记录数量"""
|
| 673 |
+
async with aiosqlite.connect(self.db_path) as db:
|
| 674 |
+
if purpose:
|
| 675 |
+
async with db.execute(
|
| 676 |
+
"SELECT COUNT(*) FROM files WHERE purpose = ?", (purpose,)
|
| 677 |
+
) as cursor:
|
| 678 |
+
row = await cursor.fetchone()
|
| 679 |
+
else:
|
| 680 |
+
async with db.execute("SELECT COUNT(*) FROM files") as cursor:
|
| 681 |
+
row = await cursor.fetchone()
|
| 682 |
+
|
| 683 |
+
return row[0] if row else 0
|
api_server/app/adapters/local/progress.py
ADDED
|
@@ -0,0 +1,238 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
本地进度管理适配器
|
| 3 |
+
|
| 4 |
+
基于内存队列实现的进度管理适配器,适用于本地单实例场景。
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import asyncio
|
| 8 |
+
from collections import defaultdict
|
| 9 |
+
from datetime import datetime
|
| 10 |
+
from typing import Any, AsyncGenerator, Dict, List, Optional
|
| 11 |
+
|
| 12 |
+
from ..base import ProgressAdapter
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class LocalProgressAdapter(ProgressAdapter):
|
| 16 |
+
"""
|
| 17 |
+
本地内存进度管理适配器
|
| 18 |
+
|
| 19 |
+
特点:
|
| 20 |
+
1. 使用内存字典存储最新进度
|
| 21 |
+
2. 使用 asyncio.Queue 实现订阅者模式
|
| 22 |
+
3. 支持多订阅者同时订阅同一任务
|
| 23 |
+
4. 与 AsyncTrainingManager 的进度推送机制兼容
|
| 24 |
+
|
| 25 |
+
注意:
|
| 26 |
+
- 进程重启后进度数据会丢失
|
| 27 |
+
- 仅适用于单实例部署
|
| 28 |
+
- 服务器模式应使用 RedisProgressAdapter
|
| 29 |
+
|
| 30 |
+
Example:
|
| 31 |
+
>>> adapter = LocalProgressAdapter()
|
| 32 |
+
>>> await adapter.update_progress("task-123", {
|
| 33 |
+
... "stage": "sovits_train",
|
| 34 |
+
... "progress": 0.5,
|
| 35 |
+
... "message": "Epoch 8/16"
|
| 36 |
+
... })
|
| 37 |
+
>>>
|
| 38 |
+
>>> # 订阅进度
|
| 39 |
+
>>> async for progress in adapter.subscribe("task-123"):
|
| 40 |
+
... print(f"{progress['stage']}: {progress['progress']*100:.1f}%")
|
| 41 |
+
"""
|
| 42 |
+
|
| 43 |
+
def __init__(self):
|
| 44 |
+
"""初始化本地进度适配器"""
|
| 45 |
+
# 存储每个任务的最新进度
|
| 46 |
+
self.progress_store: Dict[str, Dict[str, Any]] = {}
|
| 47 |
+
|
| 48 |
+
# 存储每个任务的订阅者队列列表
|
| 49 |
+
self.subscribers: Dict[str, List[asyncio.Queue]] = defaultdict(list)
|
| 50 |
+
|
| 51 |
+
# 锁,用于保护订阅者列表的并发访问
|
| 52 |
+
self._lock = asyncio.Lock()
|
| 53 |
+
|
| 54 |
+
async def update_progress(self, task_id: str, progress: Dict[str, Any]) -> None:
|
| 55 |
+
"""
|
| 56 |
+
更新进度
|
| 57 |
+
|
| 58 |
+
Args:
|
| 59 |
+
task_id: 任务ID
|
| 60 |
+
progress: 进度信息字典,可包含:
|
| 61 |
+
- type: 消息类型 ("progress", "log", "error", "heartbeat")
|
| 62 |
+
- stage: 当前阶段
|
| 63 |
+
- progress: 阶段进度 (0.0-1.0)
|
| 64 |
+
- overall_progress: 总体进度 (0.0-1.0)
|
| 65 |
+
- message: 进度消息
|
| 66 |
+
- status: 状态 ("running", "completed", "failed", "cancelled")
|
| 67 |
+
"""
|
| 68 |
+
# 添加时间戳
|
| 69 |
+
if "timestamp" not in progress:
|
| 70 |
+
progress["timestamp"] = datetime.utcnow().isoformat()
|
| 71 |
+
|
| 72 |
+
# 存储最新进度
|
| 73 |
+
self.progress_store[task_id] = progress
|
| 74 |
+
|
| 75 |
+
# 通知所有订阅者
|
| 76 |
+
async with self._lock:
|
| 77 |
+
if task_id in self.subscribers:
|
| 78 |
+
for queue in self.subscribers[task_id]:
|
| 79 |
+
try:
|
| 80 |
+
await queue.put(progress)
|
| 81 |
+
except asyncio.QueueFull:
|
| 82 |
+
# 队列满了,跳过(避免阻塞)
|
| 83 |
+
pass
|
| 84 |
+
|
| 85 |
+
async def get_progress(self, task_id: str) -> Optional[Dict[str, Any]]:
|
| 86 |
+
"""
|
| 87 |
+
获取当前进度
|
| 88 |
+
|
| 89 |
+
Args:
|
| 90 |
+
task_id: 任务ID
|
| 91 |
+
|
| 92 |
+
Returns:
|
| 93 |
+
最新进度信息,不存在则返回 None
|
| 94 |
+
"""
|
| 95 |
+
return self.progress_store.get(task_id)
|
| 96 |
+
|
| 97 |
+
async def subscribe(self, task_id: str) -> AsyncGenerator[Dict[str, Any], None]:
|
| 98 |
+
"""
|
| 99 |
+
订阅进度更新
|
| 100 |
+
|
| 101 |
+
创建一个异步生成器,持续接收指定任务的进度更新。
|
| 102 |
+
当任务进入终态(completed, failed, cancelled)时自动结束。
|
| 103 |
+
|
| 104 |
+
Args:
|
| 105 |
+
task_id: 任务ID
|
| 106 |
+
|
| 107 |
+
Yields:
|
| 108 |
+
进度信息字典
|
| 109 |
+
|
| 110 |
+
Example:
|
| 111 |
+
>>> async for progress in adapter.subscribe("task-123"):
|
| 112 |
+
... print(progress)
|
| 113 |
+
... if progress.get("status") == "completed":
|
| 114 |
+
... break
|
| 115 |
+
"""
|
| 116 |
+
# 创建订阅者队列
|
| 117 |
+
queue: asyncio.Queue = asyncio.Queue(maxsize=100)
|
| 118 |
+
|
| 119 |
+
async with self._lock:
|
| 120 |
+
self.subscribers[task_id].append(queue)
|
| 121 |
+
|
| 122 |
+
try:
|
| 123 |
+
# 首先发送当前进度(如果有)
|
| 124 |
+
current = self.progress_store.get(task_id)
|
| 125 |
+
if current:
|
| 126 |
+
yield current
|
| 127 |
+
# 如果已经是终态,直接返回
|
| 128 |
+
if current.get("status") in ("completed", "failed", "cancelled"):
|
| 129 |
+
return
|
| 130 |
+
|
| 131 |
+
# 持续接收更新
|
| 132 |
+
while True:
|
| 133 |
+
try:
|
| 134 |
+
# 30秒超时,发送心跳
|
| 135 |
+
progress = await asyncio.wait_for(queue.get(), timeout=30.0)
|
| 136 |
+
yield progress
|
| 137 |
+
|
| 138 |
+
# 检查是否为终态
|
| 139 |
+
if progress.get("status") in ("completed", "failed", "cancelled"):
|
| 140 |
+
break
|
| 141 |
+
|
| 142 |
+
except asyncio.TimeoutError:
|
| 143 |
+
# 发送心跳保持连接
|
| 144 |
+
yield {
|
| 145 |
+
"type": "heartbeat",
|
| 146 |
+
"timestamp": datetime.utcnow().isoformat(),
|
| 147 |
+
}
|
| 148 |
+
|
| 149 |
+
finally:
|
| 150 |
+
# 清理订阅者
|
| 151 |
+
async with self._lock:
|
| 152 |
+
if task_id in self.subscribers:
|
| 153 |
+
try:
|
| 154 |
+
self.subscribers[task_id].remove(queue)
|
| 155 |
+
except ValueError:
|
| 156 |
+
pass
|
| 157 |
+
|
| 158 |
+
# 如果没有订阅者了,清理列表
|
| 159 |
+
if not self.subscribers[task_id]:
|
| 160 |
+
del self.subscribers[task_id]
|
| 161 |
+
|
| 162 |
+
async def clear_progress(self, task_id: str) -> None:
|
| 163 |
+
"""
|
| 164 |
+
清除任务进度数据
|
| 165 |
+
|
| 166 |
+
Args:
|
| 167 |
+
task_id: 任务ID
|
| 168 |
+
"""
|
| 169 |
+
self.progress_store.pop(task_id, None)
|
| 170 |
+
|
| 171 |
+
async with self._lock:
|
| 172 |
+
self.subscribers.pop(task_id, None)
|
| 173 |
+
|
| 174 |
+
async def get_subscriber_count(self, task_id: str) -> int:
|
| 175 |
+
"""
|
| 176 |
+
获取任务的订阅者数量
|
| 177 |
+
|
| 178 |
+
Args:
|
| 179 |
+
task_id: 任务ID
|
| 180 |
+
|
| 181 |
+
Returns:
|
| 182 |
+
订阅者数量
|
| 183 |
+
"""
|
| 184 |
+
async with self._lock:
|
| 185 |
+
return len(self.subscribers.get(task_id, []))
|
| 186 |
+
|
| 187 |
+
async def broadcast_to_all(self, message: Dict[str, Any]) -> int:
|
| 188 |
+
"""
|
| 189 |
+
向所有任务的订阅者广播消息
|
| 190 |
+
|
| 191 |
+
用于系统级通知,如服务器关闭警告等。
|
| 192 |
+
|
| 193 |
+
Args:
|
| 194 |
+
message: 消息内容
|
| 195 |
+
|
| 196 |
+
Returns:
|
| 197 |
+
发送成功的订阅者数量
|
| 198 |
+
"""
|
| 199 |
+
if "timestamp" not in message:
|
| 200 |
+
message["timestamp"] = datetime.utcnow().isoformat()
|
| 201 |
+
|
| 202 |
+
count = 0
|
| 203 |
+
async with self._lock:
|
| 204 |
+
for task_id, queues in self.subscribers.items():
|
| 205 |
+
for queue in queues:
|
| 206 |
+
try:
|
| 207 |
+
await queue.put(message)
|
| 208 |
+
count += 1
|
| 209 |
+
except asyncio.QueueFull:
|
| 210 |
+
pass
|
| 211 |
+
|
| 212 |
+
return count
|
| 213 |
+
|
| 214 |
+
def get_active_tasks(self) -> List[str]:
|
| 215 |
+
"""
|
| 216 |
+
获取有活跃订阅者的任务列表
|
| 217 |
+
|
| 218 |
+
Returns:
|
| 219 |
+
任务ID列表
|
| 220 |
+
"""
|
| 221 |
+
return list(self.subscribers.keys())
|
| 222 |
+
|
| 223 |
+
def get_stats(self) -> Dict[str, Any]:
|
| 224 |
+
"""
|
| 225 |
+
获取适配器统计信息
|
| 226 |
+
|
| 227 |
+
Returns:
|
| 228 |
+
统计信息字典
|
| 229 |
+
"""
|
| 230 |
+
total_subscribers = sum(
|
| 231 |
+
len(queues) for queues in self.subscribers.values()
|
| 232 |
+
)
|
| 233 |
+
|
| 234 |
+
return {
|
| 235 |
+
"stored_progress_count": len(self.progress_store),
|
| 236 |
+
"active_tasks": len(self.subscribers),
|
| 237 |
+
"total_subscribers": total_subscribers,
|
| 238 |
+
}
|
api_server/app/adapters/local/storage.py
ADDED
|
@@ -0,0 +1,342 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
本地文件存储适配器
|
| 3 |
+
|
| 4 |
+
基于本地文件系统实现的存储适配器,适用于 macOS 本地训练场景。
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import json
|
| 8 |
+
import mimetypes
|
| 9 |
+
import uuid
|
| 10 |
+
from datetime import datetime
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
from typing import Any, Dict, List, Optional
|
| 13 |
+
|
| 14 |
+
import aiofiles
|
| 15 |
+
|
| 16 |
+
from ..base import StorageAdapter
|
| 17 |
+
from ...core.config import settings
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class LocalStorageAdapter(StorageAdapter):
|
| 21 |
+
"""
|
| 22 |
+
本地文件系统存储适配器
|
| 23 |
+
|
| 24 |
+
特点:
|
| 25 |
+
1. 使用 aiofiles 进行异步文件读写
|
| 26 |
+
2. 元数据存储在 .meta.json 文件中
|
| 27 |
+
3. 支持音频文件信息提取(时长、采样率等)
|
| 28 |
+
|
| 29 |
+
目录结构:
|
| 30 |
+
```
|
| 31 |
+
base_path/
|
| 32 |
+
├── {file_id} # 实际文件
|
| 33 |
+
└── {file_id}.meta.json # 元数据文件
|
| 34 |
+
```
|
| 35 |
+
|
| 36 |
+
Example:
|
| 37 |
+
>>> adapter = LocalStorageAdapter()
|
| 38 |
+
>>> file_id = await adapter.upload_file(
|
| 39 |
+
... file_data=b"...",
|
| 40 |
+
... filename="audio.wav",
|
| 41 |
+
... metadata={"purpose": "training"}
|
| 42 |
+
... )
|
| 43 |
+
>>> content = await adapter.download_file(file_id)
|
| 44 |
+
>>> metadata = await adapter.get_file_metadata(file_id)
|
| 45 |
+
"""
|
| 46 |
+
|
| 47 |
+
def __init__(self, base_path: Optional[str] = None):
|
| 48 |
+
"""
|
| 49 |
+
初始化本地存储适配器
|
| 50 |
+
|
| 51 |
+
Args:
|
| 52 |
+
base_path: 文件存储根目录,默认使用 settings.DATA_DIR / "files"
|
| 53 |
+
"""
|
| 54 |
+
if base_path:
|
| 55 |
+
self.base_path = Path(base_path)
|
| 56 |
+
else:
|
| 57 |
+
self.base_path = settings.DATA_DIR / "files"
|
| 58 |
+
|
| 59 |
+
# 确保目录存在
|
| 60 |
+
self.base_path.mkdir(parents=True, exist_ok=True)
|
| 61 |
+
|
| 62 |
+
def _get_file_path(self, file_id: str) -> Path:
|
| 63 |
+
"""获取文件存储路径"""
|
| 64 |
+
return self.base_path / file_id
|
| 65 |
+
|
| 66 |
+
def _get_meta_path(self, file_id: str) -> Path:
|
| 67 |
+
"""获取元数据文件路径"""
|
| 68 |
+
return self.base_path / f"{file_id}.meta.json"
|
| 69 |
+
|
| 70 |
+
async def upload_file(
|
| 71 |
+
self,
|
| 72 |
+
file_data: bytes,
|
| 73 |
+
filename: str,
|
| 74 |
+
metadata: Dict[str, Any]
|
| 75 |
+
) -> str:
|
| 76 |
+
"""
|
| 77 |
+
上传文件到本地文件系统
|
| 78 |
+
|
| 79 |
+
Args:
|
| 80 |
+
file_data: 文件二进制数据
|
| 81 |
+
filename: 原始文件名
|
| 82 |
+
metadata: 文件元数据
|
| 83 |
+
|
| 84 |
+
Returns:
|
| 85 |
+
file_id: 生成的文件唯一标识
|
| 86 |
+
"""
|
| 87 |
+
# 生成文件ID
|
| 88 |
+
file_id = str(uuid.uuid4())
|
| 89 |
+
|
| 90 |
+
# 确定文件扩展名
|
| 91 |
+
suffix = Path(filename).suffix
|
| 92 |
+
if suffix:
|
| 93 |
+
file_id = f"{file_id}{suffix}"
|
| 94 |
+
|
| 95 |
+
file_path = self._get_file_path(file_id)
|
| 96 |
+
meta_path = self._get_meta_path(file_id)
|
| 97 |
+
|
| 98 |
+
# 写入文件
|
| 99 |
+
async with aiofiles.open(file_path, 'wb') as f:
|
| 100 |
+
await f.write(file_data)
|
| 101 |
+
|
| 102 |
+
# 猜测 MIME 类型
|
| 103 |
+
content_type = metadata.get("content_type")
|
| 104 |
+
if not content_type:
|
| 105 |
+
content_type, _ = mimetypes.guess_type(filename)
|
| 106 |
+
content_type = content_type or "application/octet-stream"
|
| 107 |
+
|
| 108 |
+
# 构建元数据
|
| 109 |
+
file_metadata = {
|
| 110 |
+
"id": file_id,
|
| 111 |
+
"filename": filename,
|
| 112 |
+
"content_type": content_type,
|
| 113 |
+
"size_bytes": len(file_data),
|
| 114 |
+
"purpose": metadata.get("purpose", "training"),
|
| 115 |
+
"uploaded_at": datetime.utcnow().isoformat(),
|
| 116 |
+
**{k: v for k, v in metadata.items() if k not in ("content_type", "purpose")}
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
# 尝试提取音频信息
|
| 120 |
+
if content_type and content_type.startswith("audio/"):
|
| 121 |
+
audio_info = await self._extract_audio_info(file_path)
|
| 122 |
+
if audio_info:
|
| 123 |
+
file_metadata.update(audio_info)
|
| 124 |
+
|
| 125 |
+
# 写入元数据
|
| 126 |
+
async with aiofiles.open(meta_path, 'w', encoding='utf-8') as f:
|
| 127 |
+
await f.write(json.dumps(file_metadata, ensure_ascii=False, indent=2))
|
| 128 |
+
|
| 129 |
+
return file_id
|
| 130 |
+
|
| 131 |
+
async def download_file(self, file_id: str) -> bytes:
|
| 132 |
+
"""
|
| 133 |
+
下载文件
|
| 134 |
+
|
| 135 |
+
Args:
|
| 136 |
+
file_id: 文件唯一标识
|
| 137 |
+
|
| 138 |
+
Returns:
|
| 139 |
+
文件二进制数据
|
| 140 |
+
|
| 141 |
+
Raises:
|
| 142 |
+
FileNotFoundError: 文件不存在时抛出
|
| 143 |
+
"""
|
| 144 |
+
file_path = self._get_file_path(file_id)
|
| 145 |
+
|
| 146 |
+
if not file_path.exists():
|
| 147 |
+
raise FileNotFoundError(f"File not found: {file_id}")
|
| 148 |
+
|
| 149 |
+
async with aiofiles.open(file_path, 'rb') as f:
|
| 150 |
+
return await f.read()
|
| 151 |
+
|
| 152 |
+
async def delete_file(self, file_id: str) -> bool:
|
| 153 |
+
"""
|
| 154 |
+
删除文件及其元数据
|
| 155 |
+
|
| 156 |
+
Args:
|
| 157 |
+
file_id: 文件唯一标识
|
| 158 |
+
|
| 159 |
+
Returns:
|
| 160 |
+
是否成功删除
|
| 161 |
+
"""
|
| 162 |
+
file_path = self._get_file_path(file_id)
|
| 163 |
+
meta_path = self._get_meta_path(file_id)
|
| 164 |
+
|
| 165 |
+
deleted = False
|
| 166 |
+
|
| 167 |
+
# 删除文件
|
| 168 |
+
if file_path.exists():
|
| 169 |
+
file_path.unlink()
|
| 170 |
+
deleted = True
|
| 171 |
+
|
| 172 |
+
# 删除元数据
|
| 173 |
+
if meta_path.exists():
|
| 174 |
+
meta_path.unlink()
|
| 175 |
+
deleted = True
|
| 176 |
+
|
| 177 |
+
return deleted
|
| 178 |
+
|
| 179 |
+
async def get_file_metadata(self, file_id: str) -> Optional[Dict[str, Any]]:
|
| 180 |
+
"""
|
| 181 |
+
获取文件元数据
|
| 182 |
+
|
| 183 |
+
Args:
|
| 184 |
+
file_id: 文件唯一标识
|
| 185 |
+
|
| 186 |
+
Returns:
|
| 187 |
+
文件元数据字典,不存在则返回 None
|
| 188 |
+
"""
|
| 189 |
+
meta_path = self._get_meta_path(file_id)
|
| 190 |
+
|
| 191 |
+
if not meta_path.exists():
|
| 192 |
+
return None
|
| 193 |
+
|
| 194 |
+
try:
|
| 195 |
+
async with aiofiles.open(meta_path, 'r', encoding='utf-8') as f:
|
| 196 |
+
content = await f.read()
|
| 197 |
+
return json.loads(content)
|
| 198 |
+
except (json.JSONDecodeError, IOError):
|
| 199 |
+
return None
|
| 200 |
+
|
| 201 |
+
async def list_files(
|
| 202 |
+
self,
|
| 203 |
+
purpose: Optional[str] = None,
|
| 204 |
+
limit: int = 50,
|
| 205 |
+
offset: int = 0
|
| 206 |
+
) -> List[Dict[str, Any]]:
|
| 207 |
+
"""
|
| 208 |
+
列出文件
|
| 209 |
+
|
| 210 |
+
Args:
|
| 211 |
+
purpose: 按用途筛选
|
| 212 |
+
limit: 返回数量限制
|
| 213 |
+
offset: 偏移量
|
| 214 |
+
|
| 215 |
+
Returns:
|
| 216 |
+
文件元数据列表
|
| 217 |
+
"""
|
| 218 |
+
results = []
|
| 219 |
+
|
| 220 |
+
# 遍历所有 .meta.json 文件
|
| 221 |
+
meta_files = sorted(
|
| 222 |
+
self.base_path.glob("*.meta.json"),
|
| 223 |
+
key=lambda p: p.stat().st_mtime,
|
| 224 |
+
reverse=True # 最新的在前
|
| 225 |
+
)
|
| 226 |
+
|
| 227 |
+
for meta_path in meta_files:
|
| 228 |
+
try:
|
| 229 |
+
async with aiofiles.open(meta_path, 'r', encoding='utf-8') as f:
|
| 230 |
+
content = await f.read()
|
| 231 |
+
metadata = json.loads(content)
|
| 232 |
+
|
| 233 |
+
# 按用途筛选
|
| 234 |
+
if purpose and metadata.get("purpose") != purpose:
|
| 235 |
+
continue
|
| 236 |
+
|
| 237 |
+
results.append(metadata)
|
| 238 |
+
except (json.JSONDecodeError, IOError):
|
| 239 |
+
continue
|
| 240 |
+
|
| 241 |
+
# 应用分页
|
| 242 |
+
return results[offset:offset + limit]
|
| 243 |
+
|
| 244 |
+
async def file_exists(self, file_id: str) -> bool:
|
| 245 |
+
"""
|
| 246 |
+
检查文件是否存在
|
| 247 |
+
|
| 248 |
+
Args:
|
| 249 |
+
file_id: 文件唯一标识
|
| 250 |
+
|
| 251 |
+
Returns:
|
| 252 |
+
文件是否存在
|
| 253 |
+
"""
|
| 254 |
+
file_path = self._get_file_path(file_id)
|
| 255 |
+
return file_path.exists()
|
| 256 |
+
|
| 257 |
+
async def count_files(self, purpose: Optional[str] = None) -> int:
|
| 258 |
+
"""
|
| 259 |
+
统计文件数量
|
| 260 |
+
|
| 261 |
+
Args:
|
| 262 |
+
purpose: 按用途筛选
|
| 263 |
+
|
| 264 |
+
Returns:
|
| 265 |
+
文件数量
|
| 266 |
+
"""
|
| 267 |
+
if not purpose:
|
| 268 |
+
# 直接计数 meta 文件
|
| 269 |
+
return len(list(self.base_path.glob("*.meta.json")))
|
| 270 |
+
|
| 271 |
+
# 需要筛选时读取元数据
|
| 272 |
+
count = 0
|
| 273 |
+
for meta_path in self.base_path.glob("*.meta.json"):
|
| 274 |
+
try:
|
| 275 |
+
async with aiofiles.open(meta_path, 'r', encoding='utf-8') as f:
|
| 276 |
+
content = await f.read()
|
| 277 |
+
metadata = json.loads(content)
|
| 278 |
+
if metadata.get("purpose") == purpose:
|
| 279 |
+
count += 1
|
| 280 |
+
except (json.JSONDecodeError, IOError):
|
| 281 |
+
continue
|
| 282 |
+
|
| 283 |
+
return count
|
| 284 |
+
|
| 285 |
+
async def _extract_audio_info(self, file_path: Path) -> Optional[Dict[str, Any]]:
|
| 286 |
+
"""
|
| 287 |
+
提取音频文件信息(时长、采样率等)
|
| 288 |
+
|
| 289 |
+
Args:
|
| 290 |
+
file_path: 音频文件路径
|
| 291 |
+
|
| 292 |
+
Returns:
|
| 293 |
+
音频信息字典,提取失败返回 None
|
| 294 |
+
"""
|
| 295 |
+
try:
|
| 296 |
+
# 尝试使用 soundfile(如果可用)
|
| 297 |
+
import soundfile as sf
|
| 298 |
+
|
| 299 |
+
info = sf.info(str(file_path))
|
| 300 |
+
return {
|
| 301 |
+
"duration_seconds": info.duration,
|
| 302 |
+
"sample_rate": info.samplerate,
|
| 303 |
+
"channels": info.channels,
|
| 304 |
+
}
|
| 305 |
+
except ImportError:
|
| 306 |
+
# soundfile 不可用,尝试使用 wave 模块处理 WAV 文件
|
| 307 |
+
if file_path.suffix.lower() == '.wav':
|
| 308 |
+
try:
|
| 309 |
+
import wave
|
| 310 |
+
|
| 311 |
+
with wave.open(str(file_path), 'rb') as wf:
|
| 312 |
+
frames = wf.getnframes()
|
| 313 |
+
rate = wf.getframerate()
|
| 314 |
+
channels = wf.getnchannels()
|
| 315 |
+
duration = frames / float(rate) if rate > 0 else 0
|
| 316 |
+
|
| 317 |
+
return {
|
| 318 |
+
"duration_seconds": duration,
|
| 319 |
+
"sample_rate": rate,
|
| 320 |
+
"channels": channels,
|
| 321 |
+
}
|
| 322 |
+
except Exception:
|
| 323 |
+
pass
|
| 324 |
+
except Exception:
|
| 325 |
+
pass
|
| 326 |
+
|
| 327 |
+
return None
|
| 328 |
+
|
| 329 |
+
async def get_file_path(self, file_id: str) -> Optional[Path]:
|
| 330 |
+
"""
|
| 331 |
+
获取文件的本地路径(供其他模块直接访问文件使用)
|
| 332 |
+
|
| 333 |
+
Args:
|
| 334 |
+
file_id: 文件唯一标识
|
| 335 |
+
|
| 336 |
+
Returns:
|
| 337 |
+
文件路径,不存在则返回 None
|
| 338 |
+
"""
|
| 339 |
+
file_path = self._get_file_path(file_id)
|
| 340 |
+
if file_path.exists():
|
| 341 |
+
return file_path
|
| 342 |
+
return None
|
api_server/app/adapters/local/task_queue.py
CHANGED
|
@@ -13,13 +13,16 @@ import sys
|
|
| 13 |
import uuid
|
| 14 |
from datetime import datetime
|
| 15 |
from pathlib import Path
|
| 16 |
-
from typing import Dict, Optional, AsyncGenerator, List
|
| 17 |
|
| 18 |
import aiosqlite
|
| 19 |
|
| 20 |
from ..base import TaskQueueAdapter
|
| 21 |
from ...core.config import settings, PROJECT_ROOT, get_pythonpath
|
| 22 |
|
|
|
|
|
|
|
|
|
|
| 23 |
# 进度消息标识符(与 run_pipeline.py 保持一致)
|
| 24 |
PROGRESS_PREFIX = "##PROGRESS##"
|
| 25 |
PROGRESS_SUFFIX = "##"
|
|
@@ -47,22 +50,32 @@ class AsyncTrainingManager(TaskQueueAdapter):
|
|
| 47 |
>>> await manager.cancel(job_id)
|
| 48 |
"""
|
| 49 |
|
| 50 |
-
def __init__(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
"""
|
| 52 |
初始化任务管理器
|
| 53 |
|
| 54 |
Args:
|
| 55 |
db_path: SQLite 数据库路径,默认使用 settings.SQLITE_PATH
|
| 56 |
max_concurrent: 最大并发任务数(本地通常为1)
|
|
|
|
| 57 |
"""
|
| 58 |
self.db_path = db_path or str(settings.SQLITE_PATH)
|
| 59 |
self.max_concurrent = max_concurrent
|
|
|
|
| 60 |
|
| 61 |
# 运行时状态
|
| 62 |
self.running_processes: Dict[str, asyncio.subprocess.Process] = {} # task_id -> Process
|
| 63 |
self.progress_channels: Dict[str, asyncio.Queue] = {} # task_id -> Queue
|
| 64 |
self._running_count = 0
|
| 65 |
self._queue_lock = asyncio.Lock()
|
|
|
|
|
|
|
|
|
|
| 66 |
|
| 67 |
# 初始化数据库
|
| 68 |
self._init_db_sync()
|
|
@@ -123,6 +136,9 @@ class AsyncTrainingManager(TaskQueueAdapter):
|
|
| 123 |
)
|
| 124 |
await db.commit()
|
| 125 |
|
|
|
|
|
|
|
|
|
|
| 126 |
# 创建进度队列
|
| 127 |
self.progress_channels[task_id] = asyncio.Queue()
|
| 128 |
|
|
@@ -381,6 +397,8 @@ class AsyncTrainingManager(TaskQueueAdapter):
|
|
| 381 |
"""
|
| 382 |
更新任务状态
|
| 383 |
|
|
|
|
|
|
|
| 384 |
Args:
|
| 385 |
job_id: 作业ID
|
| 386 |
**kwargs: 要更新的字段
|
|
@@ -388,6 +406,8 @@ class AsyncTrainingManager(TaskQueueAdapter):
|
|
| 388 |
if not kwargs:
|
| 389 |
return
|
| 390 |
|
|
|
|
|
|
|
| 391 |
async with aiosqlite.connect(self.db_path) as db:
|
| 392 |
updates = []
|
| 393 |
values = []
|
|
@@ -403,6 +423,57 @@ class AsyncTrainingManager(TaskQueueAdapter):
|
|
| 403 |
values
|
| 404 |
)
|
| 405 |
await db.commit()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 406 |
|
| 407 |
async def get_status(self, job_id: str) -> Dict:
|
| 408 |
"""
|
|
|
|
| 13 |
import uuid
|
| 14 |
from datetime import datetime
|
| 15 |
from pathlib import Path
|
| 16 |
+
from typing import TYPE_CHECKING, Dict, Optional, AsyncGenerator, List
|
| 17 |
|
| 18 |
import aiosqlite
|
| 19 |
|
| 20 |
from ..base import TaskQueueAdapter
|
| 21 |
from ...core.config import settings, PROJECT_ROOT, get_pythonpath
|
| 22 |
|
| 23 |
+
if TYPE_CHECKING:
|
| 24 |
+
from ..base import DatabaseAdapter
|
| 25 |
+
|
| 26 |
# 进度消息标识符(与 run_pipeline.py 保持一致)
|
| 27 |
PROGRESS_PREFIX = "##PROGRESS##"
|
| 28 |
PROGRESS_SUFFIX = "##"
|
|
|
|
| 50 |
>>> await manager.cancel(job_id)
|
| 51 |
"""
|
| 52 |
|
| 53 |
+
def __init__(
|
| 54 |
+
self,
|
| 55 |
+
db_path: str = None,
|
| 56 |
+
max_concurrent: int = 1,
|
| 57 |
+
database_adapter: "DatabaseAdapter" = None
|
| 58 |
+
):
|
| 59 |
"""
|
| 60 |
初始化任务管理器
|
| 61 |
|
| 62 |
Args:
|
| 63 |
db_path: SQLite 数据库路径,默认使用 settings.SQLITE_PATH
|
| 64 |
max_concurrent: 最大并发任务数(本地通常为1)
|
| 65 |
+
database_adapter: 数据库适配器,用于同步更新 tasks 表
|
| 66 |
"""
|
| 67 |
self.db_path = db_path or str(settings.SQLITE_PATH)
|
| 68 |
self.max_concurrent = max_concurrent
|
| 69 |
+
self._database_adapter = database_adapter
|
| 70 |
|
| 71 |
# 运行时状态
|
| 72 |
self.running_processes: Dict[str, asyncio.subprocess.Process] = {} # task_id -> Process
|
| 73 |
self.progress_channels: Dict[str, asyncio.Queue] = {} # task_id -> Queue
|
| 74 |
self._running_count = 0
|
| 75 |
self._queue_lock = asyncio.Lock()
|
| 76 |
+
|
| 77 |
+
# task_id 到 job_id 的映射缓存
|
| 78 |
+
self._task_job_mapping: Dict[str, str] = {}
|
| 79 |
|
| 80 |
# 初始化数据库
|
| 81 |
self._init_db_sync()
|
|
|
|
| 136 |
)
|
| 137 |
await db.commit()
|
| 138 |
|
| 139 |
+
# 缓存 task_id -> job_id 映射
|
| 140 |
+
self._task_job_mapping[task_id] = job_id
|
| 141 |
+
|
| 142 |
# 创建进度队列
|
| 143 |
self.progress_channels[task_id] = asyncio.Queue()
|
| 144 |
|
|
|
|
| 397 |
"""
|
| 398 |
更新任务状态
|
| 399 |
|
| 400 |
+
同时更新 task_queue 表和 tasks 表(通过 DatabaseAdapter)。
|
| 401 |
+
|
| 402 |
Args:
|
| 403 |
job_id: 作业ID
|
| 404 |
**kwargs: 要更新的字段
|
|
|
|
| 406 |
if not kwargs:
|
| 407 |
return
|
| 408 |
|
| 409 |
+
# 1. 更新 task_queue 表
|
| 410 |
+
task_id = None
|
| 411 |
async with aiosqlite.connect(self.db_path) as db:
|
| 412 |
updates = []
|
| 413 |
values = []
|
|
|
|
| 423 |
values
|
| 424 |
)
|
| 425 |
await db.commit()
|
| 426 |
+
|
| 427 |
+
# 获取 task_id 用于同步更新 tasks 表
|
| 428 |
+
async with db.execute(
|
| 429 |
+
"SELECT task_id FROM task_queue WHERE job_id = ?", (job_id,)
|
| 430 |
+
) as cursor:
|
| 431 |
+
row = await cursor.fetchone()
|
| 432 |
+
if row:
|
| 433 |
+
task_id = row[0]
|
| 434 |
+
|
| 435 |
+
# 2. 同步更新 tasks 表(通过 DatabaseAdapter)
|
| 436 |
+
if self._database_adapter and task_id:
|
| 437 |
+
await self._sync_to_tasks_table(task_id, kwargs)
|
| 438 |
+
|
| 439 |
+
async def _sync_to_tasks_table(self, task_id: str, updates: Dict) -> None:
|
| 440 |
+
"""
|
| 441 |
+
同步状态更新到 tasks 表
|
| 442 |
+
|
| 443 |
+
字段映射:
|
| 444 |
+
- task_queue.progress -> tasks.stage_progress
|
| 445 |
+
- task_queue.overall_progress -> tasks.progress
|
| 446 |
+
- 其他字段直接映射
|
| 447 |
+
|
| 448 |
+
Args:
|
| 449 |
+
task_id: 任务ID
|
| 450 |
+
updates: 要更新的字段字典
|
| 451 |
+
"""
|
| 452 |
+
if not self._database_adapter:
|
| 453 |
+
return
|
| 454 |
+
|
| 455 |
+
# 字段映射
|
| 456 |
+
tasks_updates = {}
|
| 457 |
+
|
| 458 |
+
for key, value in updates.items():
|
| 459 |
+
if key == 'progress':
|
| 460 |
+
# task_queue.progress -> tasks.stage_progress
|
| 461 |
+
tasks_updates['stage_progress'] = value
|
| 462 |
+
elif key == 'overall_progress':
|
| 463 |
+
# task_queue.overall_progress -> tasks.progress
|
| 464 |
+
tasks_updates['progress'] = value
|
| 465 |
+
elif key in ('status', 'current_stage', 'message', 'error_message',
|
| 466 |
+
'started_at', 'completed_at'):
|
| 467 |
+
# 直接映射的字段
|
| 468 |
+
tasks_updates[key] = value
|
| 469 |
+
|
| 470 |
+
if tasks_updates:
|
| 471 |
+
try:
|
| 472 |
+
await self._database_adapter.update_task(task_id, tasks_updates)
|
| 473 |
+
except Exception as e:
|
| 474 |
+
# 记录错误但不中断主流程
|
| 475 |
+
import logging
|
| 476 |
+
logging.warning(f"Failed to sync task status to tasks table: {e}")
|
| 477 |
|
| 478 |
async def get_status(self, job_id: str) -> Dict:
|
| 479 |
"""
|
api_server/app/api/__init__.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
API 模块
|
| 3 |
+
|
| 4 |
+
包含所有 API 路由和端点
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from .v1.router import api_router
|
| 8 |
+
|
| 9 |
+
__all__ = ["api_router"]
|
api_server/app/api/deps.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
依赖注入模块
|
| 3 |
+
|
| 4 |
+
提供 FastAPI 依赖注入函数,用于获取服务和适配器实例
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from functools import lru_cache
|
| 8 |
+
from typing import Generator
|
| 9 |
+
|
| 10 |
+
from ..services.task_service import TaskService
|
| 11 |
+
from ..services.experiment_service import ExperimentService
|
| 12 |
+
from ..services.file_service import FileService
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# ============================================================
|
| 16 |
+
# 服务依赖
|
| 17 |
+
# ============================================================
|
| 18 |
+
|
| 19 |
+
@lru_cache()
|
| 20 |
+
def get_task_service() -> TaskService:
|
| 21 |
+
"""
|
| 22 |
+
获取 TaskService 实例
|
| 23 |
+
|
| 24 |
+
使用 lru_cache 确保单例
|
| 25 |
+
|
| 26 |
+
Returns:
|
| 27 |
+
TaskService 实例
|
| 28 |
+
|
| 29 |
+
Example:
|
| 30 |
+
>>> @router.post("/tasks")
|
| 31 |
+
... async def create_task(
|
| 32 |
+
... request: QuickModeRequest,
|
| 33 |
+
... service: TaskService = Depends(get_task_service)
|
| 34 |
+
... ):
|
| 35 |
+
... return await service.create_quick_task(request)
|
| 36 |
+
"""
|
| 37 |
+
return TaskService()
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
@lru_cache()
|
| 41 |
+
def get_experiment_service() -> ExperimentService:
|
| 42 |
+
"""
|
| 43 |
+
获取 ExperimentService 实例
|
| 44 |
+
|
| 45 |
+
Returns:
|
| 46 |
+
ExperimentService 实例
|
| 47 |
+
"""
|
| 48 |
+
return ExperimentService()
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
@lru_cache()
|
| 52 |
+
def get_file_service() -> FileService:
|
| 53 |
+
"""
|
| 54 |
+
获取 FileService 实例
|
| 55 |
+
|
| 56 |
+
Returns:
|
| 57 |
+
FileService 实例
|
| 58 |
+
"""
|
| 59 |
+
return FileService()
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
# ============================================================
|
| 63 |
+
# 通用依赖
|
| 64 |
+
# ============================================================
|
| 65 |
+
|
| 66 |
+
async def get_pagination_params(
|
| 67 |
+
limit: int = 50,
|
| 68 |
+
offset: int = 0
|
| 69 |
+
) -> dict:
|
| 70 |
+
"""
|
| 71 |
+
分页参数依赖
|
| 72 |
+
|
| 73 |
+
Args:
|
| 74 |
+
limit: 每页数量,默认 50,最大 100
|
| 75 |
+
offset: 偏移量,默认 0
|
| 76 |
+
|
| 77 |
+
Returns:
|
| 78 |
+
分页参数字典
|
| 79 |
+
"""
|
| 80 |
+
# 限制最大值
|
| 81 |
+
if limit > 100:
|
| 82 |
+
limit = 100
|
| 83 |
+
if limit < 1:
|
| 84 |
+
limit = 1
|
| 85 |
+
if offset < 0:
|
| 86 |
+
offset = 0
|
| 87 |
+
|
| 88 |
+
return {"limit": limit, "offset": offset}
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
__all__ = [
|
| 92 |
+
"get_task_service",
|
| 93 |
+
"get_experiment_service",
|
| 94 |
+
"get_file_service",
|
| 95 |
+
"get_pagination_params",
|
| 96 |
+
]
|
api_server/app/api/v1/__init__.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
API v1 模块
|
| 3 |
+
|
| 4 |
+
包含 v1 版本的所有 API 端点
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from .router import api_router
|
| 8 |
+
|
| 9 |
+
__all__ = ["api_router"]
|
api_server/app/api/v1/endpoints/__init__.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
API v1 端点模块
|
| 3 |
+
|
| 4 |
+
包含所有 API 端点实现
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from . import tasks
|
| 8 |
+
from . import experiments
|
| 9 |
+
from . import files
|
| 10 |
+
from . import stages
|
| 11 |
+
|
| 12 |
+
__all__ = [
|
| 13 |
+
"tasks",
|
| 14 |
+
"experiments",
|
| 15 |
+
"files",
|
| 16 |
+
"stages",
|
| 17 |
+
]
|
api_server/app/api/v1/endpoints/experiments.py
ADDED
|
@@ -0,0 +1,393 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Advanced Mode 实验 API
|
| 3 |
+
|
| 4 |
+
专家用户分阶段训练 API 端点
|
| 5 |
+
|
| 6 |
+
API 列表:
|
| 7 |
+
- POST /experiments 创建实验
|
| 8 |
+
- GET /experiments 获取实验列表
|
| 9 |
+
- GET /experiments/{exp_id} 获取实验详情
|
| 10 |
+
- PATCH /experiments/{exp_id} 更新实验配置
|
| 11 |
+
- DELETE /experiments/{exp_id} 删除实验
|
| 12 |
+
- POST /experiments/{exp_id}/stages/{stage_type} 执行阶段
|
| 13 |
+
- GET /experiments/{exp_id}/stages 获取所有阶段状态
|
| 14 |
+
- GET /experiments/{exp_id}/stages/{stage_type} 获取阶段详情
|
| 15 |
+
- DELETE /experiments/{exp_id}/stages/{stage_type} 取消阶段
|
| 16 |
+
- GET /experiments/{exp_id}/stages/{stage_type}/progress SSE 阶段进度
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import json
|
| 20 |
+
from typing import Any, Dict, Optional
|
| 21 |
+
|
| 22 |
+
from fastapi import APIRouter, Body, Depends, HTTPException, Query
|
| 23 |
+
from fastapi.responses import StreamingResponse
|
| 24 |
+
|
| 25 |
+
from ....models.schemas.experiment import (
|
| 26 |
+
ExperimentCreate,
|
| 27 |
+
ExperimentUpdate,
|
| 28 |
+
ExperimentResponse,
|
| 29 |
+
ExperimentListResponse,
|
| 30 |
+
StageStatus,
|
| 31 |
+
StageExecuteResponse,
|
| 32 |
+
StagesListResponse,
|
| 33 |
+
STAGE_DEPENDENCIES,
|
| 34 |
+
get_stage_params_class,
|
| 35 |
+
)
|
| 36 |
+
from ....models.schemas.common import SuccessResponse, ErrorResponse
|
| 37 |
+
from ....services.experiment_service import ExperimentService
|
| 38 |
+
from ...deps import get_experiment_service
|
| 39 |
+
|
| 40 |
+
router = APIRouter()
|
| 41 |
+
|
| 42 |
+
# 有效的阶段类型
|
| 43 |
+
VALID_STAGE_TYPES = list(STAGE_DEPENDENCIES.keys())
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
@router.post(
|
| 47 |
+
"",
|
| 48 |
+
response_model=ExperimentResponse,
|
| 49 |
+
summary="创建实验",
|
| 50 |
+
description="""
|
| 51 |
+
创建实验(专家用户)。
|
| 52 |
+
|
| 53 |
+
创建实验但不立即执行,用户可以逐阶段控制训练流程。
|
| 54 |
+
实验创建后,所有阶段状态为 `pending`,需要手动触发执行。
|
| 55 |
+
|
| 56 |
+
**训练阶段**:
|
| 57 |
+
- `audio_slice`: 音频切片
|
| 58 |
+
- `asr`: 语音识别
|
| 59 |
+
- `text_feature`: 文本特征提取
|
| 60 |
+
- `hubert_feature`: HuBERT 特征提取
|
| 61 |
+
- `semantic_token`: 语义 Token 提取
|
| 62 |
+
- `sovits_train`: SoVITS 训练
|
| 63 |
+
- `gpt_train`: GPT 训练
|
| 64 |
+
""",
|
| 65 |
+
)
|
| 66 |
+
async def create_experiment(
|
| 67 |
+
request: ExperimentCreate,
|
| 68 |
+
service: ExperimentService = Depends(get_experiment_service),
|
| 69 |
+
) -> ExperimentResponse:
|
| 70 |
+
"""
|
| 71 |
+
创建实验
|
| 72 |
+
"""
|
| 73 |
+
return await service.create_experiment(request)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
@router.get(
|
| 77 |
+
"",
|
| 78 |
+
response_model=ExperimentListResponse,
|
| 79 |
+
summary="获取实验列表",
|
| 80 |
+
description="获取所有实验列表,支持按状态筛选和分页。",
|
| 81 |
+
)
|
| 82 |
+
async def list_experiments(
|
| 83 |
+
status: Optional[str] = Query(None, description="按状态筛选"),
|
| 84 |
+
limit: int = Query(50, ge=1, le=100, description="每页数量"),
|
| 85 |
+
offset: int = Query(0, ge=0, description="偏移量"),
|
| 86 |
+
service: ExperimentService = Depends(get_experiment_service),
|
| 87 |
+
) -> ExperimentListResponse:
|
| 88 |
+
"""
|
| 89 |
+
获取实验列表
|
| 90 |
+
"""
|
| 91 |
+
return await service.list_experiments(status=status, limit=limit, offset=offset)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
@router.get(
|
| 95 |
+
"/{exp_id}",
|
| 96 |
+
response_model=ExperimentResponse,
|
| 97 |
+
summary="获取实验详情",
|
| 98 |
+
description="获取指定实验的详细信息,包括所有阶段状态。",
|
| 99 |
+
responses={
|
| 100 |
+
404: {"model": ErrorResponse, "description": "实验不存在"},
|
| 101 |
+
},
|
| 102 |
+
)
|
| 103 |
+
async def get_experiment(
|
| 104 |
+
exp_id: str,
|
| 105 |
+
service: ExperimentService = Depends(get_experiment_service),
|
| 106 |
+
) -> ExperimentResponse:
|
| 107 |
+
"""
|
| 108 |
+
获取实验详情
|
| 109 |
+
"""
|
| 110 |
+
experiment = await service.get_experiment(exp_id)
|
| 111 |
+
if not experiment:
|
| 112 |
+
raise HTTPException(status_code=404, detail="实验不存在")
|
| 113 |
+
return experiment
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
@router.patch(
|
| 117 |
+
"/{exp_id}",
|
| 118 |
+
response_model=ExperimentResponse,
|
| 119 |
+
summary="更新实验配置",
|
| 120 |
+
description="更新实验的基础配置(非阶段参数)。",
|
| 121 |
+
responses={
|
| 122 |
+
404: {"model": ErrorResponse, "description": "实验不存在"},
|
| 123 |
+
},
|
| 124 |
+
)
|
| 125 |
+
async def update_experiment(
|
| 126 |
+
exp_id: str,
|
| 127 |
+
request: ExperimentUpdate,
|
| 128 |
+
service: ExperimentService = Depends(get_experiment_service),
|
| 129 |
+
) -> ExperimentResponse:
|
| 130 |
+
"""
|
| 131 |
+
更新实验配置
|
| 132 |
+
"""
|
| 133 |
+
experiment = await service.update_experiment(exp_id, request)
|
| 134 |
+
if not experiment:
|
| 135 |
+
raise HTTPException(status_code=404, detail="实验不存在")
|
| 136 |
+
return experiment
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
@router.delete(
|
| 140 |
+
"/{exp_id}",
|
| 141 |
+
response_model=SuccessResponse,
|
| 142 |
+
summary="删除实验",
|
| 143 |
+
description="删除实验及其所有阶段数据。如果有正在运行的阶段,会先取消执行。",
|
| 144 |
+
responses={
|
| 145 |
+
404: {"model": ErrorResponse, "description": "实验不存在"},
|
| 146 |
+
},
|
| 147 |
+
)
|
| 148 |
+
async def delete_experiment(
|
| 149 |
+
exp_id: str,
|
| 150 |
+
service: ExperimentService = Depends(get_experiment_service),
|
| 151 |
+
) -> SuccessResponse:
|
| 152 |
+
"""
|
| 153 |
+
删除实验
|
| 154 |
+
"""
|
| 155 |
+
success = await service.delete_experiment(exp_id)
|
| 156 |
+
if not success:
|
| 157 |
+
raise HTTPException(status_code=404, detail="实验不存在")
|
| 158 |
+
return SuccessResponse(message="实验已删除")
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
@router.post(
|
| 162 |
+
"/{exp_id}/stages/{stage_type}",
|
| 163 |
+
response_model=StageExecuteResponse,
|
| 164 |
+
summary="执行阶段",
|
| 165 |
+
description="""
|
| 166 |
+
执行指定阶段。
|
| 167 |
+
|
| 168 |
+
**阶段依赖关系**:
|
| 169 |
+
- `audio_slice`: 无依赖
|
| 170 |
+
- `asr`: 依赖 audio_slice
|
| 171 |
+
- `text_feature`: 依赖 asr
|
| 172 |
+
- `hubert_feature`: 依赖 audio_slice
|
| 173 |
+
- `semantic_token`: 依赖 hubert_feature
|
| 174 |
+
- `sovits_train`: 依赖 text_feature, semantic_token
|
| 175 |
+
- `gpt_train`: 依赖 text_feature, semantic_token
|
| 176 |
+
|
| 177 |
+
如果依赖阶段未完成,会返回 400 错误。
|
| 178 |
+
如果阶段已完成,会重新执行(返回 `rerun: true`)。
|
| 179 |
+
""",
|
| 180 |
+
responses={
|
| 181 |
+
400: {"model": ErrorResponse, "description": "阶段类型无效或依赖未满足"},
|
| 182 |
+
404: {"model": ErrorResponse, "description": "实验不存在"},
|
| 183 |
+
},
|
| 184 |
+
)
|
| 185 |
+
async def execute_stage(
|
| 186 |
+
exp_id: str,
|
| 187 |
+
stage_type: str,
|
| 188 |
+
params: Dict[str, Any] = Body(default={}),
|
| 189 |
+
service: ExperimentService = Depends(get_experiment_service),
|
| 190 |
+
) -> StageExecuteResponse:
|
| 191 |
+
"""
|
| 192 |
+
执行阶段
|
| 193 |
+
"""
|
| 194 |
+
# 验证阶段类型
|
| 195 |
+
if stage_type not in VALID_STAGE_TYPES:
|
| 196 |
+
raise HTTPException(
|
| 197 |
+
status_code=400,
|
| 198 |
+
detail=f"无效的阶段类型: {stage_type}。有效类型: {', '.join(VALID_STAGE_TYPES)}"
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
# 检查实验是否存在
|
| 202 |
+
experiment = await service.get_experiment(exp_id)
|
| 203 |
+
if not experiment:
|
| 204 |
+
raise HTTPException(status_code=404, detail="实验不存在")
|
| 205 |
+
|
| 206 |
+
# 检查依赖
|
| 207 |
+
deps = await service.check_stage_dependencies(exp_id, stage_type)
|
| 208 |
+
if not deps["satisfied"]:
|
| 209 |
+
raise HTTPException(
|
| 210 |
+
status_code=400,
|
| 211 |
+
detail=f"依赖阶段未完成: {', '.join(deps['missing'])}"
|
| 212 |
+
)
|
| 213 |
+
|
| 214 |
+
# 验证并解析参数
|
| 215 |
+
try:
|
| 216 |
+
params_class = get_stage_params_class(stage_type)
|
| 217 |
+
validated_params = params_class(**params)
|
| 218 |
+
params = validated_params.model_dump(exclude_unset=True)
|
| 219 |
+
except ValueError as e:
|
| 220 |
+
raise HTTPException(status_code=400, detail=str(e))
|
| 221 |
+
|
| 222 |
+
# 执行阶段
|
| 223 |
+
result = await service.execute_stage(exp_id, stage_type, params)
|
| 224 |
+
if not result:
|
| 225 |
+
raise HTTPException(status_code=404, detail="实验不存在")
|
| 226 |
+
|
| 227 |
+
return result
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
@router.get(
|
| 231 |
+
"/{exp_id}/stages",
|
| 232 |
+
response_model=StagesListResponse,
|
| 233 |
+
summary="获取所有阶段状态",
|
| 234 |
+
description="获取实验的所有阶段状态列表。",
|
| 235 |
+
responses={
|
| 236 |
+
404: {"model": ErrorResponse, "description": "实验不存在"},
|
| 237 |
+
},
|
| 238 |
+
)
|
| 239 |
+
async def get_all_stages(
|
| 240 |
+
exp_id: str,
|
| 241 |
+
service: ExperimentService = Depends(get_experiment_service),
|
| 242 |
+
) -> StagesListResponse:
|
| 243 |
+
"""
|
| 244 |
+
获取所有阶段状态
|
| 245 |
+
"""
|
| 246 |
+
result = await service.get_all_stages(exp_id)
|
| 247 |
+
if not result:
|
| 248 |
+
raise HTTPException(status_code=404, detail="实验不存在")
|
| 249 |
+
return result
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
@router.get(
|
| 253 |
+
"/{exp_id}/stages/{stage_type}",
|
| 254 |
+
response_model=StageStatus,
|
| 255 |
+
summary="获取阶段详情",
|
| 256 |
+
description="获取指定阶段的详细状态和结果。",
|
| 257 |
+
responses={
|
| 258 |
+
400: {"model": ErrorResponse, "description": "阶段类型无效"},
|
| 259 |
+
404: {"model": ErrorResponse, "description": "实验或阶段不存在"},
|
| 260 |
+
},
|
| 261 |
+
)
|
| 262 |
+
async def get_stage(
|
| 263 |
+
exp_id: str,
|
| 264 |
+
stage_type: str,
|
| 265 |
+
service: ExperimentService = Depends(get_experiment_service),
|
| 266 |
+
) -> StageStatus:
|
| 267 |
+
"""
|
| 268 |
+
获取阶段详情
|
| 269 |
+
"""
|
| 270 |
+
# 验证阶段类型
|
| 271 |
+
if stage_type not in VALID_STAGE_TYPES:
|
| 272 |
+
raise HTTPException(
|
| 273 |
+
status_code=400,
|
| 274 |
+
detail=f"无效的阶段类型: {stage_type}"
|
| 275 |
+
)
|
| 276 |
+
|
| 277 |
+
stage = await service.get_stage(exp_id, stage_type)
|
| 278 |
+
if not stage:
|
| 279 |
+
raise HTTPException(status_code=404, detail="实验或阶段不存在")
|
| 280 |
+
return stage
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
@router.delete(
|
| 284 |
+
"/{exp_id}/stages/{stage_type}",
|
| 285 |
+
response_model=SuccessResponse,
|
| 286 |
+
summary="取消阶段",
|
| 287 |
+
description="取消正在执行的阶段。只有运行中的阶段可以取消。",
|
| 288 |
+
responses={
|
| 289 |
+
400: {"model": ErrorResponse, "description": "阶段未运行或无法取消"},
|
| 290 |
+
404: {"model": ErrorResponse, "description": "实验或阶段不存在"},
|
| 291 |
+
},
|
| 292 |
+
)
|
| 293 |
+
async def cancel_stage(
|
| 294 |
+
exp_id: str,
|
| 295 |
+
stage_type: str,
|
| 296 |
+
service: ExperimentService = Depends(get_experiment_service),
|
| 297 |
+
) -> SuccessResponse:
|
| 298 |
+
"""
|
| 299 |
+
取消阶段
|
| 300 |
+
"""
|
| 301 |
+
# 验证阶段类型
|
| 302 |
+
if stage_type not in VALID_STAGE_TYPES:
|
| 303 |
+
raise HTTPException(
|
| 304 |
+
status_code=400,
|
| 305 |
+
detail=f"无效的阶段类型: {stage_type}"
|
| 306 |
+
)
|
| 307 |
+
|
| 308 |
+
success = await service.cancel_stage(exp_id, stage_type)
|
| 309 |
+
if not success:
|
| 310 |
+
raise HTTPException(
|
| 311 |
+
status_code=400,
|
| 312 |
+
detail="阶段未运行或无法取消"
|
| 313 |
+
)
|
| 314 |
+
|
| 315 |
+
return SuccessResponse(message=f"阶段 {stage_type} 已取消")
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
@router.get(
|
| 319 |
+
"/{exp_id}/stages/{stage_type}/progress",
|
| 320 |
+
summary="SSE 阶段进度订阅",
|
| 321 |
+
description="""
|
| 322 |
+
订阅阶段进度更新(Server-Sent Events)。
|
| 323 |
+
|
| 324 |
+
返回的事件流格式:
|
| 325 |
+
```
|
| 326 |
+
event: progress
|
| 327 |
+
data: {"epoch": 8, "total_epochs": 16, "progress": 0.50, "loss": 0.034}
|
| 328 |
+
|
| 329 |
+
event: checkpoint
|
| 330 |
+
data: {"epoch": 8, "model_path": "logs/my_voice/sovits_e8.pth"}
|
| 331 |
+
|
| 332 |
+
event: completed
|
| 333 |
+
data: {"status": "completed", "final_loss": 0.023}
|
| 334 |
+
```
|
| 335 |
+
""",
|
| 336 |
+
responses={
|
| 337 |
+
400: {"model": ErrorResponse, "description": "阶段类型无效"},
|
| 338 |
+
404: {"model": ErrorResponse, "description": "实验或阶段不存在"},
|
| 339 |
+
},
|
| 340 |
+
)
|
| 341 |
+
async def subscribe_stage_progress(
|
| 342 |
+
exp_id: str,
|
| 343 |
+
stage_type: str,
|
| 344 |
+
service: ExperimentService = Depends(get_experiment_service),
|
| 345 |
+
) -> StreamingResponse:
|
| 346 |
+
"""
|
| 347 |
+
SSE 阶段进度订阅
|
| 348 |
+
"""
|
| 349 |
+
# 验证阶段类型
|
| 350 |
+
if stage_type not in VALID_STAGE_TYPES:
|
| 351 |
+
raise HTTPException(
|
| 352 |
+
status_code=400,
|
| 353 |
+
detail=f"无效的阶段类型: {stage_type}"
|
| 354 |
+
)
|
| 355 |
+
|
| 356 |
+
# 检查实验是否存在
|
| 357 |
+
experiment = await service.get_experiment(exp_id)
|
| 358 |
+
if not experiment:
|
| 359 |
+
raise HTTPException(status_code=404, detail="实验不存在")
|
| 360 |
+
|
| 361 |
+
async def event_generator():
|
| 362 |
+
"""生成 SSE 事件流"""
|
| 363 |
+
async for progress in service.subscribe_stage_progress(exp_id, stage_type):
|
| 364 |
+
# 确定事件类型
|
| 365 |
+
event_type = progress.get("type", "progress")
|
| 366 |
+
status = progress.get("status")
|
| 367 |
+
|
| 368 |
+
if status == "completed":
|
| 369 |
+
event_type = "completed"
|
| 370 |
+
elif status == "failed":
|
| 371 |
+
event_type = "failed"
|
| 372 |
+
elif status == "cancelled":
|
| 373 |
+
event_type = "cancelled"
|
| 374 |
+
elif progress.get("model_path"):
|
| 375 |
+
event_type = "checkpoint"
|
| 376 |
+
|
| 377 |
+
# 构建 SSE 格式
|
| 378 |
+
data = json.dumps(progress, ensure_ascii=False)
|
| 379 |
+
yield f"event: {event_type}\ndata: {data}\n\n"
|
| 380 |
+
|
| 381 |
+
# 如果是终态,结束流
|
| 382 |
+
if status in ("completed", "failed", "cancelled"):
|
| 383 |
+
break
|
| 384 |
+
|
| 385 |
+
return StreamingResponse(
|
| 386 |
+
event_generator(),
|
| 387 |
+
media_type="text/event-stream",
|
| 388 |
+
headers={
|
| 389 |
+
"Cache-Control": "no-cache",
|
| 390 |
+
"Connection": "keep-alive",
|
| 391 |
+
"X-Accel-Buffering": "no",
|
| 392 |
+
},
|
| 393 |
+
)
|
api_server/app/api/v1/endpoints/files.py
ADDED
|
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
文件管理 API
|
| 3 |
+
|
| 4 |
+
文件上传、下载和管理 API 端点
|
| 5 |
+
|
| 6 |
+
API 列表:
|
| 7 |
+
- POST /files 上传文件
|
| 8 |
+
- GET /files 获取文件列表
|
| 9 |
+
- GET /files/{file_id} 下载文件(或获取元数据)
|
| 10 |
+
- DELETE /files/{file_id} 删除文件
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from typing import Optional
|
| 14 |
+
|
| 15 |
+
from fastapi import APIRouter, Depends, File, HTTPException, Query, UploadFile
|
| 16 |
+
from fastapi.responses import Response
|
| 17 |
+
|
| 18 |
+
from ....models.schemas.file import (
|
| 19 |
+
FileMetadata,
|
| 20 |
+
FileUploadResponse,
|
| 21 |
+
FileListResponse,
|
| 22 |
+
FileDeleteResponse,
|
| 23 |
+
)
|
| 24 |
+
from ....models.schemas.common import ErrorResponse
|
| 25 |
+
from ....services.file_service import FileService
|
| 26 |
+
from ...deps import get_file_service
|
| 27 |
+
|
| 28 |
+
router = APIRouter()
|
| 29 |
+
|
| 30 |
+
# 允许的音频 MIME 类型
|
| 31 |
+
ALLOWED_AUDIO_TYPES = {
|
| 32 |
+
"audio/wav",
|
| 33 |
+
"audio/wave",
|
| 34 |
+
"audio/x-wav",
|
| 35 |
+
"audio/mpeg",
|
| 36 |
+
"audio/mp3",
|
| 37 |
+
"audio/mp4",
|
| 38 |
+
"audio/aac",
|
| 39 |
+
"audio/ogg",
|
| 40 |
+
"audio/flac",
|
| 41 |
+
"audio/x-flac",
|
| 42 |
+
"audio/webm",
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
# 最大文件大小 (500MB)
|
| 46 |
+
MAX_FILE_SIZE = 500 * 1024 * 1024
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
@router.post(
|
| 50 |
+
"",
|
| 51 |
+
response_model=FileUploadResponse,
|
| 52 |
+
summary="上传文件",
|
| 53 |
+
description="""
|
| 54 |
+
上传音频文件用于训练。
|
| 55 |
+
|
| 56 |
+
**支持的音频格式**:
|
| 57 |
+
- WAV
|
| 58 |
+
- MP3
|
| 59 |
+
- FLAC
|
| 60 |
+
- OGG
|
| 61 |
+
- AAC
|
| 62 |
+
- WebM
|
| 63 |
+
|
| 64 |
+
**文件大小限制**: 500MB
|
| 65 |
+
|
| 66 |
+
**用途类型**:
|
| 67 |
+
- `training`: 训练音频(默认)
|
| 68 |
+
- `reference`: 参考音频
|
| 69 |
+
- `output`: 输出文件
|
| 70 |
+
""",
|
| 71 |
+
responses={
|
| 72 |
+
200: {"model": FileUploadResponse, "description": "文件上传成功"},
|
| 73 |
+
400: {"model": ErrorResponse, "description": "文件格式或大小不合法"},
|
| 74 |
+
},
|
| 75 |
+
)
|
| 76 |
+
async def upload_file(
|
| 77 |
+
file: UploadFile = File(..., description="要上传的音频文件"),
|
| 78 |
+
purpose: str = Query(
|
| 79 |
+
"training",
|
| 80 |
+
description="文件用途: training, reference, output"
|
| 81 |
+
),
|
| 82 |
+
service: FileService = Depends(get_file_service),
|
| 83 |
+
) -> FileUploadResponse:
|
| 84 |
+
"""
|
| 85 |
+
上传文件
|
| 86 |
+
"""
|
| 87 |
+
# 验证用途
|
| 88 |
+
if purpose not in ("training", "reference", "output"):
|
| 89 |
+
raise HTTPException(
|
| 90 |
+
status_code=400,
|
| 91 |
+
detail="无效的用途类型,有效值: training, reference, output"
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
# 验证文件类型(可选,允许不明确类型的文件)
|
| 95 |
+
content_type = file.content_type
|
| 96 |
+
if content_type and content_type not in ALLOWED_AUDIO_TYPES:
|
| 97 |
+
# 检查文件扩展名
|
| 98 |
+
filename = file.filename or ""
|
| 99 |
+
ext = filename.lower().split(".")[-1] if "." in filename else ""
|
| 100 |
+
allowed_exts = {"wav", "mp3", "flac", "ogg", "aac", "webm", "m4a"}
|
| 101 |
+
if ext not in allowed_exts:
|
| 102 |
+
raise HTTPException(
|
| 103 |
+
status_code=400,
|
| 104 |
+
detail=f"不支持的文件类型: {content_type}。支持的格式: WAV, MP3, FLAC, OGG, AAC, WebM"
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
# 读取文件内容
|
| 108 |
+
file_data = await file.read()
|
| 109 |
+
|
| 110 |
+
# 验证文件大小
|
| 111 |
+
if len(file_data) > MAX_FILE_SIZE:
|
| 112 |
+
raise HTTPException(
|
| 113 |
+
status_code=400,
|
| 114 |
+
detail=f"文件过大,最大允许 {MAX_FILE_SIZE // (1024*1024)}MB"
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
# 验证文件不为空
|
| 118 |
+
if len(file_data) == 0:
|
| 119 |
+
raise HTTPException(
|
| 120 |
+
status_code=400,
|
| 121 |
+
detail="文件为空"
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
# 上传文件
|
| 125 |
+
return await service.upload_file(
|
| 126 |
+
file_data=file_data,
|
| 127 |
+
filename=file.filename or "audio",
|
| 128 |
+
content_type=content_type,
|
| 129 |
+
purpose=purpose,
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
@router.get(
|
| 134 |
+
"",
|
| 135 |
+
response_model=FileListResponse,
|
| 136 |
+
summary="获取文件列表",
|
| 137 |
+
description="获取已上传的文件列表,支持按用途筛选和分页。",
|
| 138 |
+
)
|
| 139 |
+
async def list_files(
|
| 140 |
+
purpose: Optional[str] = Query(
|
| 141 |
+
None,
|
| 142 |
+
description="按用途筛选: training, reference, output"
|
| 143 |
+
),
|
| 144 |
+
limit: int = Query(50, ge=1, le=100, description="每页数量"),
|
| 145 |
+
offset: int = Query(0, ge=0, description="偏移量"),
|
| 146 |
+
service: FileService = Depends(get_file_service),
|
| 147 |
+
) -> FileListResponse:
|
| 148 |
+
"""
|
| 149 |
+
获取文件列表
|
| 150 |
+
"""
|
| 151 |
+
return await service.list_files(purpose=purpose, limit=limit, offset=offset)
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
@router.get(
|
| 155 |
+
"/{file_id}",
|
| 156 |
+
summary="下载文件或获取元数据",
|
| 157 |
+
description="""
|
| 158 |
+
根据请求类型返回文件内容或元数据。
|
| 159 |
+
|
| 160 |
+
- 默认返回文件内容(用于下载)
|
| 161 |
+
- 添加 `?metadata=true` 参数只返回元数据
|
| 162 |
+
""",
|
| 163 |
+
responses={
|
| 164 |
+
200: {
|
| 165 |
+
"description": "文件内容(下载时)或元数据(metadata=true 时)",
|
| 166 |
+
},
|
| 167 |
+
404: {"model": ErrorResponse, "description": "文件不存在"},
|
| 168 |
+
},
|
| 169 |
+
)
|
| 170 |
+
async def get_file(
|
| 171 |
+
file_id: str,
|
| 172 |
+
metadata: bool = Query(False, description="只返回元数据"),
|
| 173 |
+
service: FileService = Depends(get_file_service),
|
| 174 |
+
):
|
| 175 |
+
"""
|
| 176 |
+
下载文件或获取元数据
|
| 177 |
+
"""
|
| 178 |
+
if metadata:
|
| 179 |
+
# 只返回元数据
|
| 180 |
+
file_metadata = await service.get_file(file_id)
|
| 181 |
+
if not file_metadata:
|
| 182 |
+
raise HTTPException(status_code=404, detail="文件不存在")
|
| 183 |
+
return file_metadata
|
| 184 |
+
else:
|
| 185 |
+
# 下载文件
|
| 186 |
+
result = await service.download_file(file_id)
|
| 187 |
+
if not result:
|
| 188 |
+
raise HTTPException(status_code=404, detail="文件不存在")
|
| 189 |
+
|
| 190 |
+
file_data, filename, content_type = result
|
| 191 |
+
|
| 192 |
+
return Response(
|
| 193 |
+
content=file_data,
|
| 194 |
+
media_type=content_type,
|
| 195 |
+
headers={
|
| 196 |
+
"Content-Disposition": f'attachment; filename="{filename}"',
|
| 197 |
+
"Content-Length": str(len(file_data)),
|
| 198 |
+
},
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
@router.delete(
|
| 203 |
+
"/{file_id}",
|
| 204 |
+
response_model=FileDeleteResponse,
|
| 205 |
+
summary="删除文件",
|
| 206 |
+
description="删除指定的文件。",
|
| 207 |
+
responses={
|
| 208 |
+
200: {"model": FileDeleteResponse, "description": "删除结果"},
|
| 209 |
+
404: {"model": ErrorResponse, "description": "文件不存在"},
|
| 210 |
+
},
|
| 211 |
+
)
|
| 212 |
+
async def delete_file(
|
| 213 |
+
file_id: str,
|
| 214 |
+
service: FileService = Depends(get_file_service),
|
| 215 |
+
) -> FileDeleteResponse:
|
| 216 |
+
"""
|
| 217 |
+
删除文件
|
| 218 |
+
"""
|
| 219 |
+
result = await service.delete_file(file_id)
|
| 220 |
+
if not result.success:
|
| 221 |
+
raise HTTPException(status_code=404, detail="文件不存在或已删除")
|
| 222 |
+
return result
|
api_server/app/api/v1/endpoints/stages.py
ADDED
|
@@ -0,0 +1,247 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
阶段模板 API
|
| 3 |
+
|
| 4 |
+
阶段预设和参数模板 API 端点
|
| 5 |
+
|
| 6 |
+
API 列表:
|
| 7 |
+
- GET /stages/presets 获取阶段预设列表
|
| 8 |
+
- GET /stages/{stage_type}/schema 获取阶段参数模板
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from typing import Any, Dict, List
|
| 12 |
+
|
| 13 |
+
from fastapi import APIRouter, HTTPException
|
| 14 |
+
|
| 15 |
+
from ....models.schemas.experiment import (
|
| 16 |
+
STAGE_DEPENDENCIES,
|
| 17 |
+
STAGE_PARAMS_MAP,
|
| 18 |
+
AudioSliceParams,
|
| 19 |
+
ASRParams,
|
| 20 |
+
TextFeatureParams,
|
| 21 |
+
HubertFeatureParams,
|
| 22 |
+
SemanticTokenParams,
|
| 23 |
+
SoVITSTrainParams,
|
| 24 |
+
GPTTrainParams,
|
| 25 |
+
)
|
| 26 |
+
from ....models.schemas.common import ErrorResponse
|
| 27 |
+
|
| 28 |
+
router = APIRouter()
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# ============================================================
|
| 32 |
+
# 阶段预设定义
|
| 33 |
+
# ============================================================
|
| 34 |
+
|
| 35 |
+
STAGE_PRESETS = [
|
| 36 |
+
{
|
| 37 |
+
"id": "full_training",
|
| 38 |
+
"name": "完整训练流程",
|
| 39 |
+
"description": "包含所有阶段的标准训练,从音频切片到模型训练",
|
| 40 |
+
"stages": [
|
| 41 |
+
"audio_slice",
|
| 42 |
+
"asr",
|
| 43 |
+
"text_feature",
|
| 44 |
+
"hubert_feature",
|
| 45 |
+
"semantic_token",
|
| 46 |
+
"sovits_train",
|
| 47 |
+
"gpt_train",
|
| 48 |
+
],
|
| 49 |
+
},
|
| 50 |
+
{
|
| 51 |
+
"id": "retrain_sovits",
|
| 52 |
+
"name": "重训 SoVITS",
|
| 53 |
+
"description": "跳过预处理,仅重新训练 SoVITS 模型",
|
| 54 |
+
"stages": ["sovits_train"],
|
| 55 |
+
},
|
| 56 |
+
{
|
| 57 |
+
"id": "retrain_gpt",
|
| 58 |
+
"name": "重训 GPT",
|
| 59 |
+
"description": "跳过预处理,仅重新训练 GPT 模型",
|
| 60 |
+
"stages": ["gpt_train"],
|
| 61 |
+
},
|
| 62 |
+
{
|
| 63 |
+
"id": "retrain_both",
|
| 64 |
+
"name": "重训两个模型",
|
| 65 |
+
"description": "跳过预处理,重新训练 SoVITS 和 GPT 模型",
|
| 66 |
+
"stages": ["sovits_train", "gpt_train"],
|
| 67 |
+
},
|
| 68 |
+
{
|
| 69 |
+
"id": "feature_extraction",
|
| 70 |
+
"name": "特征提取",
|
| 71 |
+
"description": "仅执行音频切片和特征提取,不进行训练",
|
| 72 |
+
"stages": [
|
| 73 |
+
"audio_slice",
|
| 74 |
+
"asr",
|
| 75 |
+
"text_feature",
|
| 76 |
+
"hubert_feature",
|
| 77 |
+
"semantic_token",
|
| 78 |
+
],
|
| 79 |
+
},
|
| 80 |
+
{
|
| 81 |
+
"id": "audio_preprocessing",
|
| 82 |
+
"name": "音频预处理",
|
| 83 |
+
"description": "仅执行音频切片和语音识别",
|
| 84 |
+
"stages": ["audio_slice", "asr"],
|
| 85 |
+
},
|
| 86 |
+
]
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
# ============================================================
|
| 90 |
+
# 阶段信息定义
|
| 91 |
+
# ============================================================
|
| 92 |
+
|
| 93 |
+
STAGE_INFO = {
|
| 94 |
+
"audio_slice": {
|
| 95 |
+
"name": "音频切片",
|
| 96 |
+
"description": "将长音频切分为短片段,便于后续处理",
|
| 97 |
+
"dependencies": [],
|
| 98 |
+
},
|
| 99 |
+
"asr": {
|
| 100 |
+
"name": "语音识别",
|
| 101 |
+
"description": "识别音频中的文本内容",
|
| 102 |
+
"dependencies": ["audio_slice"],
|
| 103 |
+
},
|
| 104 |
+
"text_feature": {
|
| 105 |
+
"name": "文本特征提取",
|
| 106 |
+
"description": "使用 BERT 模型提取文本特征",
|
| 107 |
+
"dependencies": ["asr"],
|
| 108 |
+
},
|
| 109 |
+
"hubert_feature": {
|
| 110 |
+
"name": "HuBERT 特征提取",
|
| 111 |
+
"description": "使用 HuBERT 模型提取音频特征",
|
| 112 |
+
"dependencies": ["audio_slice"],
|
| 113 |
+
},
|
| 114 |
+
"semantic_token": {
|
| 115 |
+
"name": "语义 Token 提取",
|
| 116 |
+
"description": "从 HuBERT 特征中提取语义 Token",
|
| 117 |
+
"dependencies": ["hubert_feature"],
|
| 118 |
+
},
|
| 119 |
+
"sovits_train": {
|
| 120 |
+
"name": "SoVITS 训练",
|
| 121 |
+
"description": "训练 SoVITS 声码器模型",
|
| 122 |
+
"dependencies": ["text_feature", "semantic_token"],
|
| 123 |
+
},
|
| 124 |
+
"gpt_train": {
|
| 125 |
+
"name": "GPT 训练",
|
| 126 |
+
"description": "训练 GPT 语言模型",
|
| 127 |
+
"dependencies": ["text_feature", "semantic_token"],
|
| 128 |
+
},
|
| 129 |
+
}
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def get_parameter_schema(params_class: type) -> Dict[str, Any]:
|
| 133 |
+
"""
|
| 134 |
+
从 Pydantic 模型生成参数 Schema
|
| 135 |
+
|
| 136 |
+
Args:
|
| 137 |
+
params_class: Pydantic 模型类
|
| 138 |
+
|
| 139 |
+
Returns:
|
| 140 |
+
参数 Schema 字典
|
| 141 |
+
"""
|
| 142 |
+
schema = params_class.model_json_schema()
|
| 143 |
+
properties = schema.get("properties", {})
|
| 144 |
+
|
| 145 |
+
parameters = {}
|
| 146 |
+
for name, prop in properties.items():
|
| 147 |
+
param_info = {
|
| 148 |
+
"type": prop.get("type", "string"),
|
| 149 |
+
"description": prop.get("description", ""),
|
| 150 |
+
}
|
| 151 |
+
|
| 152 |
+
# 添加默认值
|
| 153 |
+
if "default" in prop:
|
| 154 |
+
param_info["default"] = prop["default"]
|
| 155 |
+
|
| 156 |
+
# 添加范围限制
|
| 157 |
+
if "minimum" in prop:
|
| 158 |
+
param_info["min"] = prop["minimum"]
|
| 159 |
+
if "maximum" in prop:
|
| 160 |
+
param_info["max"] = prop["maximum"]
|
| 161 |
+
|
| 162 |
+
# 处理枚举
|
| 163 |
+
if "enum" in prop:
|
| 164 |
+
param_info["enum"] = prop["enum"]
|
| 165 |
+
|
| 166 |
+
parameters[name] = param_info
|
| 167 |
+
|
| 168 |
+
return parameters
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
@router.get(
|
| 172 |
+
"/presets",
|
| 173 |
+
summary="获取阶段预设列表",
|
| 174 |
+
description="""
|
| 175 |
+
获取预定义的训练流程预设。
|
| 176 |
+
|
| 177 |
+
每个预设包含一组阶段,用户可以选择预设快速配置训练流程。
|
| 178 |
+
""",
|
| 179 |
+
response_model=Dict[str, List[Dict[str, Any]]],
|
| 180 |
+
)
|
| 181 |
+
async def get_presets() -> Dict[str, List[Dict[str, Any]]]:
|
| 182 |
+
"""
|
| 183 |
+
获取阶段预设列表
|
| 184 |
+
"""
|
| 185 |
+
return {"presets": STAGE_PRESETS}
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
@router.get(
|
| 189 |
+
"/{stage_type}/schema",
|
| 190 |
+
summary="获取阶段参数模板",
|
| 191 |
+
description="""
|
| 192 |
+
获取指定阶段的参数模板,包含参数定义、默认值和取值范围。
|
| 193 |
+
|
| 194 |
+
前端可以使用此接口动态生成参数配置表单。
|
| 195 |
+
""",
|
| 196 |
+
responses={
|
| 197 |
+
200: {"description": "阶段参数模板"},
|
| 198 |
+
404: {"model": ErrorResponse, "description": "阶段类型无效"},
|
| 199 |
+
},
|
| 200 |
+
)
|
| 201 |
+
async def get_stage_schema(stage_type: str) -> Dict[str, Any]:
|
| 202 |
+
"""
|
| 203 |
+
获取阶段参数模板
|
| 204 |
+
"""
|
| 205 |
+
# 验证阶段类型
|
| 206 |
+
if stage_type not in STAGE_PARAMS_MAP:
|
| 207 |
+
raise HTTPException(
|
| 208 |
+
status_code=404,
|
| 209 |
+
detail=f"无效的阶段类型: {stage_type}。有效类型: {', '.join(STAGE_PARAMS_MAP.keys())}"
|
| 210 |
+
)
|
| 211 |
+
|
| 212 |
+
# 获取阶段信息
|
| 213 |
+
stage_info = STAGE_INFO.get(stage_type, {})
|
| 214 |
+
params_class = STAGE_PARAMS_MAP[stage_type]
|
| 215 |
+
|
| 216 |
+
# 生成参数 schema
|
| 217 |
+
parameters = get_parameter_schema(params_class)
|
| 218 |
+
|
| 219 |
+
return {
|
| 220 |
+
"type": stage_type,
|
| 221 |
+
"name": stage_info.get("name", stage_type),
|
| 222 |
+
"description": stage_info.get("description", ""),
|
| 223 |
+
"dependencies": STAGE_DEPENDENCIES.get(stage_type, []),
|
| 224 |
+
"parameters": parameters,
|
| 225 |
+
}
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
@router.get(
|
| 229 |
+
"",
|
| 230 |
+
summary="获取所有阶段信息",
|
| 231 |
+
description="获取所有训练阶段的信息和依赖关系。",
|
| 232 |
+
)
|
| 233 |
+
async def get_all_stages() -> Dict[str, Any]:
|
| 234 |
+
"""
|
| 235 |
+
获取所有阶段信息
|
| 236 |
+
"""
|
| 237 |
+
stages = []
|
| 238 |
+
for stage_type in STAGE_PARAMS_MAP.keys():
|
| 239 |
+
stage_info = STAGE_INFO.get(stage_type, {})
|
| 240 |
+
stages.append({
|
| 241 |
+
"type": stage_type,
|
| 242 |
+
"name": stage_info.get("name", stage_type),
|
| 243 |
+
"description": stage_info.get("description", ""),
|
| 244 |
+
"dependencies": STAGE_DEPENDENCIES.get(stage_type, []),
|
| 245 |
+
})
|
| 246 |
+
|
| 247 |
+
return {"stages": stages}
|
api_server/app/api/v1/endpoints/tasks.py
ADDED
|
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Quick Mode 任务 API
|
| 3 |
+
|
| 4 |
+
小白用户一键训练 API 端点
|
| 5 |
+
|
| 6 |
+
API 列表:
|
| 7 |
+
- POST /tasks 创建一键训练任务
|
| 8 |
+
- GET /tasks 获取任务列表
|
| 9 |
+
- GET /tasks/{task_id} 获取任务详情
|
| 10 |
+
- DELETE /tasks/{task_id} 取消任务
|
| 11 |
+
- GET /tasks/{task_id}/progress SSE 进度订阅
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import json
|
| 15 |
+
from typing import Optional
|
| 16 |
+
|
| 17 |
+
from fastapi import APIRouter, Depends, HTTPException, Query
|
| 18 |
+
from fastapi.responses import StreamingResponse
|
| 19 |
+
|
| 20 |
+
from ....models.schemas.task import (
|
| 21 |
+
QuickModeRequest,
|
| 22 |
+
TaskResponse,
|
| 23 |
+
TaskListResponse,
|
| 24 |
+
)
|
| 25 |
+
from ....models.schemas.common import SuccessResponse, ErrorResponse
|
| 26 |
+
from ....services.task_service import TaskService
|
| 27 |
+
from ...deps import get_task_service
|
| 28 |
+
|
| 29 |
+
router = APIRouter()
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
@router.post(
|
| 33 |
+
"",
|
| 34 |
+
response_model=TaskResponse,
|
| 35 |
+
summary="创建一键训练任务",
|
| 36 |
+
description="""
|
| 37 |
+
创建一键训练任务(小白用户)。
|
| 38 |
+
|
| 39 |
+
上传音频文件后,系统自动配置所有参数并执行完整训练流程:
|
| 40 |
+
`audio_slice -> asr -> text_feature -> hubert_feature -> semantic_token -> sovits_train -> gpt_train`
|
| 41 |
+
|
| 42 |
+
**质量预设**:
|
| 43 |
+
- `fast`: SoVITS 4 epochs, GPT 8 epochs, 约10分钟
|
| 44 |
+
- `standard`: SoVITS 8 epochs, GPT 15 epochs, 约20分钟
|
| 45 |
+
- `high`: SoVITS 16 epochs, GPT 30 epochs, 约40分钟
|
| 46 |
+
""",
|
| 47 |
+
responses={
|
| 48 |
+
200: {"model": TaskResponse, "description": "任务创建成功"},
|
| 49 |
+
400: {"model": ErrorResponse, "description": "请求参数错误"},
|
| 50 |
+
404: {"model": ErrorResponse, "description": "音频文件不存在"},
|
| 51 |
+
409: {"model": ErrorResponse, "description": "实验名称已存在"},
|
| 52 |
+
},
|
| 53 |
+
)
|
| 54 |
+
async def create_task(
|
| 55 |
+
request: QuickModeRequest,
|
| 56 |
+
service: TaskService = Depends(get_task_service),
|
| 57 |
+
) -> TaskResponse:
|
| 58 |
+
"""
|
| 59 |
+
创建一键训练任务
|
| 60 |
+
"""
|
| 61 |
+
# 验证 exp_name 是否已存在
|
| 62 |
+
if await service.check_exp_name_exists(request.exp_name):
|
| 63 |
+
raise HTTPException(
|
| 64 |
+
status_code=409,
|
| 65 |
+
detail=f"实验名称 '{request.exp_name}' 已存在,请使用不同的名称"
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
# 验证音频文件是否存在
|
| 69 |
+
file_exists, audio_path = await service.validate_audio_file(request.audio_file_id)
|
| 70 |
+
if not file_exists:
|
| 71 |
+
raise HTTPException(
|
| 72 |
+
status_code=404,
|
| 73 |
+
detail=f"音频文件不存在: {request.audio_file_id}"
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
return await service.create_quick_task(request)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
@router.get(
|
| 80 |
+
"",
|
| 81 |
+
response_model=TaskListResponse,
|
| 82 |
+
summary="获取任务列表",
|
| 83 |
+
description="获取所有训练任务列表,支持按状态筛选和分页。",
|
| 84 |
+
)
|
| 85 |
+
async def list_tasks(
|
| 86 |
+
status: Optional[str] = Query(
|
| 87 |
+
None,
|
| 88 |
+
description="按状态筛选: queued, running, completed, failed, cancelled, interrupted"
|
| 89 |
+
),
|
| 90 |
+
limit: int = Query(50, ge=1, le=100, description="每页数量"),
|
| 91 |
+
offset: int = Query(0, ge=0, description="偏移量"),
|
| 92 |
+
service: TaskService = Depends(get_task_service),
|
| 93 |
+
) -> TaskListResponse:
|
| 94 |
+
"""
|
| 95 |
+
获取任务列表
|
| 96 |
+
"""
|
| 97 |
+
return await service.list_tasks(status=status, limit=limit, offset=offset)
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
@router.get(
|
| 101 |
+
"/{task_id}",
|
| 102 |
+
response_model=TaskResponse,
|
| 103 |
+
summary="获取任务详情",
|
| 104 |
+
description="获取指定任务的详细状态信息。",
|
| 105 |
+
responses={
|
| 106 |
+
200: {"model": TaskResponse, "description": "任务详情"},
|
| 107 |
+
404: {"model": ErrorResponse, "description": "任务不存在"},
|
| 108 |
+
},
|
| 109 |
+
)
|
| 110 |
+
async def get_task(
|
| 111 |
+
task_id: str,
|
| 112 |
+
service: TaskService = Depends(get_task_service),
|
| 113 |
+
) -> TaskResponse:
|
| 114 |
+
"""
|
| 115 |
+
获取任务详情
|
| 116 |
+
"""
|
| 117 |
+
task = await service.get_task(task_id)
|
| 118 |
+
if not task:
|
| 119 |
+
raise HTTPException(status_code=404, detail="任务不存在")
|
| 120 |
+
return task
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
@router.delete(
|
| 124 |
+
"/{task_id}",
|
| 125 |
+
response_model=SuccessResponse,
|
| 126 |
+
summary="取消任务",
|
| 127 |
+
description="取消排队中或运行中的任务。已完成、失败或已取消的任务无法取消。",
|
| 128 |
+
responses={
|
| 129 |
+
200: {"model": SuccessResponse, "description": "任务取消成功"},
|
| 130 |
+
400: {"model": ErrorResponse, "description": "任务无法取消"},
|
| 131 |
+
404: {"model": ErrorResponse, "description": "任务不存在"},
|
| 132 |
+
},
|
| 133 |
+
)
|
| 134 |
+
async def cancel_task(
|
| 135 |
+
task_id: str,
|
| 136 |
+
service: TaskService = Depends(get_task_service),
|
| 137 |
+
) -> SuccessResponse:
|
| 138 |
+
"""
|
| 139 |
+
取消任务
|
| 140 |
+
"""
|
| 141 |
+
# 先检查任务是否存在
|
| 142 |
+
task = await service.get_task(task_id)
|
| 143 |
+
if not task:
|
| 144 |
+
raise HTTPException(status_code=404, detail="任务不存在")
|
| 145 |
+
|
| 146 |
+
success = await service.cancel_task(task_id)
|
| 147 |
+
if not success:
|
| 148 |
+
raise HTTPException(status_code=400, detail="任务无法取消(可能已完成或已取消)")
|
| 149 |
+
|
| 150 |
+
return SuccessResponse(message="任务已取消")
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
@router.get(
|
| 154 |
+
"/{task_id}/progress",
|
| 155 |
+
summary="SSE 进度订阅",
|
| 156 |
+
description="""
|
| 157 |
+
订阅任务进度更新(Server-Sent Events)。
|
| 158 |
+
|
| 159 |
+
返回的事件流格式:
|
| 160 |
+
```
|
| 161 |
+
event: progress
|
| 162 |
+
data: {"stage": "sovits_train", "progress": 0.45, "message": "Epoch 8/16"}
|
| 163 |
+
|
| 164 |
+
event: progress
|
| 165 |
+
data: {"stage": "sovits_train", "progress": 0.50, "message": "Epoch 9/16"}
|
| 166 |
+
|
| 167 |
+
event: completed
|
| 168 |
+
data: {"status": "completed", "message": "训练完成"}
|
| 169 |
+
```
|
| 170 |
+
|
| 171 |
+
可能的事件类型:
|
| 172 |
+
- `progress`: 进度更新
|
| 173 |
+
- `log`: 日志消息
|
| 174 |
+
- `heartbeat`: 心跳(保持连接)
|
| 175 |
+
- `completed`: 任务完成
|
| 176 |
+
- `failed`: 任务失败
|
| 177 |
+
- `cancelled`: 任务取消
|
| 178 |
+
""",
|
| 179 |
+
responses={
|
| 180 |
+
200: {"description": "SSE 事件流"},
|
| 181 |
+
404: {"model": ErrorResponse, "description": "任务不存在"},
|
| 182 |
+
},
|
| 183 |
+
)
|
| 184 |
+
async def subscribe_progress(
|
| 185 |
+
task_id: str,
|
| 186 |
+
service: TaskService = Depends(get_task_service),
|
| 187 |
+
) -> StreamingResponse:
|
| 188 |
+
"""
|
| 189 |
+
SSE 进度订阅
|
| 190 |
+
"""
|
| 191 |
+
# 先检查任务是否存在
|
| 192 |
+
task = await service.get_task(task_id)
|
| 193 |
+
if not task:
|
| 194 |
+
raise HTTPException(status_code=404, detail="任务不存在")
|
| 195 |
+
|
| 196 |
+
async def event_generator():
|
| 197 |
+
"""生成 SSE 事件流"""
|
| 198 |
+
async for progress in service.subscribe_progress(task_id):
|
| 199 |
+
# 确定事件类型
|
| 200 |
+
event_type = progress.get("type", "progress")
|
| 201 |
+
status = progress.get("status")
|
| 202 |
+
|
| 203 |
+
if status == "completed":
|
| 204 |
+
event_type = "completed"
|
| 205 |
+
elif status == "failed":
|
| 206 |
+
event_type = "failed"
|
| 207 |
+
elif status == "cancelled":
|
| 208 |
+
event_type = "cancelled"
|
| 209 |
+
elif event_type == "heartbeat":
|
| 210 |
+
event_type = "heartbeat"
|
| 211 |
+
|
| 212 |
+
# 构建 SSE 格式
|
| 213 |
+
data = json.dumps(progress, ensure_ascii=False)
|
| 214 |
+
yield f"event: {event_type}\ndata: {data}\n\n"
|
| 215 |
+
|
| 216 |
+
# 如果是终态,结束流
|
| 217 |
+
if status in ("completed", "failed", "cancelled"):
|
| 218 |
+
break
|
| 219 |
+
|
| 220 |
+
return StreamingResponse(
|
| 221 |
+
event_generator(),
|
| 222 |
+
media_type="text/event-stream",
|
| 223 |
+
headers={
|
| 224 |
+
"Cache-Control": "no-cache",
|
| 225 |
+
"Connection": "keep-alive",
|
| 226 |
+
"X-Accel-Buffering": "no", # Nginx 禁用缓冲
|
| 227 |
+
},
|
| 228 |
+
)
|
api_server/app/api/v1/router.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
API v1 路由注册
|
| 3 |
+
|
| 4 |
+
统一注册所有 v1 版本的 API 路由
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from fastapi import APIRouter
|
| 8 |
+
|
| 9 |
+
from .endpoints import tasks, experiments, files, stages
|
| 10 |
+
|
| 11 |
+
api_router = APIRouter()
|
| 12 |
+
|
| 13 |
+
# Quick Mode API - 一键训练任务
|
| 14 |
+
api_router.include_router(
|
| 15 |
+
tasks.router,
|
| 16 |
+
prefix="/tasks",
|
| 17 |
+
tags=["Quick Mode - 任务管理"],
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
# Advanced Mode API - 专家模式实验
|
| 21 |
+
api_router.include_router(
|
| 22 |
+
experiments.router,
|
| 23 |
+
prefix="/experiments",
|
| 24 |
+
tags=["Advanced Mode - 实验管理"],
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
# 文件管理 API
|
| 28 |
+
api_router.include_router(
|
| 29 |
+
files.router,
|
| 30 |
+
prefix="/files",
|
| 31 |
+
tags=["文件管理"],
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
# 阶段模板 API
|
| 35 |
+
api_router.include_router(
|
| 36 |
+
stages.router,
|
| 37 |
+
prefix="/stages",
|
| 38 |
+
tags=["阶段模板"],
|
| 39 |
+
)
|
api_server/app/core/adapters.py
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
适配器工厂模块
|
| 3 |
+
|
| 4 |
+
根据 DEPLOYMENT_MODE 配置自动选择本地或服务器适配器。
|
| 5 |
+
|
| 6 |
+
Example:
|
| 7 |
+
>>> from app.core.adapters import get_database_adapter, get_storage_adapter
|
| 8 |
+
>>> db = get_database_adapter()
|
| 9 |
+
>>> storage = get_storage_adapter()
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from functools import lru_cache
|
| 13 |
+
from typing import TYPE_CHECKING
|
| 14 |
+
|
| 15 |
+
from .config import settings
|
| 16 |
+
|
| 17 |
+
if TYPE_CHECKING:
|
| 18 |
+
from ..adapters.base import (
|
| 19 |
+
DatabaseAdapter,
|
| 20 |
+
ProgressAdapter,
|
| 21 |
+
StorageAdapter,
|
| 22 |
+
TaskQueueAdapter,
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class AdapterFactory:
|
| 27 |
+
"""
|
| 28 |
+
适配器工厂
|
| 29 |
+
|
| 30 |
+
根据 DEPLOYMENT_MODE 配置创建对应的适配器实例。
|
| 31 |
+
|
| 32 |
+
- local 模式: SQLite + 本地文件系统 + asyncio.subprocess
|
| 33 |
+
- server 模式: PostgreSQL + S3/MinIO + Celery (Phase 2)
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
@staticmethod
|
| 37 |
+
def create_storage_adapter() -> "StorageAdapter":
|
| 38 |
+
"""
|
| 39 |
+
创建存储适配器
|
| 40 |
+
|
| 41 |
+
Returns:
|
| 42 |
+
本地模式返回 LocalStorageAdapter
|
| 43 |
+
服务器模式返回 S3StorageAdapter (Phase 2)
|
| 44 |
+
"""
|
| 45 |
+
if settings.DEPLOYMENT_MODE == "local":
|
| 46 |
+
from ..adapters.local.storage import LocalStorageAdapter
|
| 47 |
+
return LocalStorageAdapter(base_path=str(settings.DATA_DIR / "files"))
|
| 48 |
+
else:
|
| 49 |
+
# Phase 2: 服务器模式
|
| 50 |
+
raise NotImplementedError("Server mode storage adapter not implemented yet")
|
| 51 |
+
|
| 52 |
+
@staticmethod
|
| 53 |
+
def create_database_adapter() -> "DatabaseAdapter":
|
| 54 |
+
"""
|
| 55 |
+
创建数据库适配器
|
| 56 |
+
|
| 57 |
+
Returns:
|
| 58 |
+
本地模式返回 SQLiteAdapter
|
| 59 |
+
服务器模式返回 PostgreSQLAdapter (Phase 2)
|
| 60 |
+
"""
|
| 61 |
+
if settings.DEPLOYMENT_MODE == "local":
|
| 62 |
+
from ..adapters.local.database import SQLiteAdapter
|
| 63 |
+
return SQLiteAdapter(db_path=str(settings.SQLITE_PATH))
|
| 64 |
+
else:
|
| 65 |
+
# Phase 2: 服务器模式
|
| 66 |
+
raise NotImplementedError("Server mode database adapter not implemented yet")
|
| 67 |
+
|
| 68 |
+
@staticmethod
|
| 69 |
+
def create_task_queue_adapter(database_adapter: "DatabaseAdapter" = None) -> "TaskQueueAdapter":
|
| 70 |
+
"""
|
| 71 |
+
创建任务队列适配器
|
| 72 |
+
|
| 73 |
+
Args:
|
| 74 |
+
database_adapter: 数据库适配器,用于同步任务状态到 tasks 表。
|
| 75 |
+
如果未提供,将自动创建一个实例。
|
| 76 |
+
|
| 77 |
+
Returns:
|
| 78 |
+
本地模式返回 AsyncTrainingManager
|
| 79 |
+
服务器模式返回 CeleryTaskQueueAdapter (Phase 2)
|
| 80 |
+
"""
|
| 81 |
+
if settings.DEPLOYMENT_MODE == "local":
|
| 82 |
+
from ..adapters.local.task_queue import AsyncTrainingManager
|
| 83 |
+
from ..adapters.local.database import SQLiteAdapter
|
| 84 |
+
|
| 85 |
+
# 如果未提供 database_adapter,创建一个新实例用于状态同步
|
| 86 |
+
if database_adapter is None:
|
| 87 |
+
database_adapter = SQLiteAdapter(db_path=str(settings.SQLITE_PATH))
|
| 88 |
+
|
| 89 |
+
return AsyncTrainingManager(
|
| 90 |
+
db_path=str(settings.SQLITE_PATH),
|
| 91 |
+
database_adapter=database_adapter
|
| 92 |
+
)
|
| 93 |
+
else:
|
| 94 |
+
# Phase 2: 服务器模式
|
| 95 |
+
raise NotImplementedError("Server mode task queue adapter not implemented yet")
|
| 96 |
+
|
| 97 |
+
@staticmethod
|
| 98 |
+
def create_progress_adapter() -> "ProgressAdapter":
|
| 99 |
+
"""
|
| 100 |
+
创建进度管理适配器
|
| 101 |
+
|
| 102 |
+
Returns:
|
| 103 |
+
本地模式返回 LocalProgressAdapter
|
| 104 |
+
服务器模式返回 RedisProgressAdapter (Phase 2)
|
| 105 |
+
"""
|
| 106 |
+
if settings.DEPLOYMENT_MODE == "local":
|
| 107 |
+
from ..adapters.local.progress import LocalProgressAdapter
|
| 108 |
+
return LocalProgressAdapter()
|
| 109 |
+
else:
|
| 110 |
+
# Phase 2: 服务器模式
|
| 111 |
+
raise NotImplementedError("Server mode progress adapter not implemented yet")
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
# ============================================================
|
| 115 |
+
# 全局单例获取函数(使用 lru_cache 缓存实例)
|
| 116 |
+
# ============================================================
|
| 117 |
+
|
| 118 |
+
@lru_cache()
|
| 119 |
+
def get_storage_adapter() -> "StorageAdapter":
|
| 120 |
+
"""
|
| 121 |
+
获取存储适配器单例
|
| 122 |
+
|
| 123 |
+
Returns:
|
| 124 |
+
StorageAdapter 实例
|
| 125 |
+
"""
|
| 126 |
+
return AdapterFactory.create_storage_adapter()
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
@lru_cache()
|
| 130 |
+
def get_database_adapter() -> "DatabaseAdapter":
|
| 131 |
+
"""
|
| 132 |
+
获取数据库适配器单例
|
| 133 |
+
|
| 134 |
+
Returns:
|
| 135 |
+
DatabaseAdapter 实例
|
| 136 |
+
"""
|
| 137 |
+
return AdapterFactory.create_database_adapter()
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
@lru_cache()
|
| 141 |
+
def get_task_queue_adapter() -> "TaskQueueAdapter":
|
| 142 |
+
"""
|
| 143 |
+
获取任务队列适配器单例
|
| 144 |
+
|
| 145 |
+
使用共享的数据库适配器实例来确保状态同步一致性。
|
| 146 |
+
|
| 147 |
+
Returns:
|
| 148 |
+
TaskQueueAdapter 实例
|
| 149 |
+
"""
|
| 150 |
+
# 使用共享的数据库适配器实例
|
| 151 |
+
db_adapter = get_database_adapter()
|
| 152 |
+
return AdapterFactory.create_task_queue_adapter(database_adapter=db_adapter)
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
@lru_cache()
|
| 156 |
+
def get_progress_adapter() -> "ProgressAdapter":
|
| 157 |
+
"""
|
| 158 |
+
获���进度管理适配器单例
|
| 159 |
+
|
| 160 |
+
Returns:
|
| 161 |
+
ProgressAdapter 实例
|
| 162 |
+
"""
|
| 163 |
+
return AdapterFactory.create_progress_adapter()
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
# ============================================================
|
| 167 |
+
# 便捷别名(向后兼容)
|
| 168 |
+
# ============================================================
|
| 169 |
+
|
| 170 |
+
# 延迟初始化的全局变量,在首次访问时创建
|
| 171 |
+
# 注意:这些是函数调用的结果,不是直接的实例引用
|
| 172 |
+
# 如果需要在模块级别使用,请调用对应的 get_*_adapter() 函数
|
| 173 |
+
|
| 174 |
+
__all__ = [
|
| 175 |
+
"AdapterFactory",
|
| 176 |
+
"get_storage_adapter",
|
| 177 |
+
"get_database_adapter",
|
| 178 |
+
"get_task_queue_adapter",
|
| 179 |
+
"get_progress_adapter",
|
| 180 |
+
]
|
api_server/app/main.py
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
FastAPI 应用入口
|
| 3 |
+
|
| 4 |
+
GPT-SoVITS 音色训练 HTTP API 服务
|
| 5 |
+
|
| 6 |
+
启动方式:
|
| 7 |
+
uvicorn api_server.app.main:app --host 0.0.0.0 --port 8000 --reload
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from contextlib import asynccontextmanager
|
| 11 |
+
from typing import AsyncGenerator
|
| 12 |
+
|
| 13 |
+
from fastapi import FastAPI
|
| 14 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 15 |
+
|
| 16 |
+
from .api.v1.router import api_router
|
| 17 |
+
from .core.config import settings, ensure_data_dirs
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@asynccontextmanager
|
| 21 |
+
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
| 22 |
+
"""
|
| 23 |
+
应用生命周期管理
|
| 24 |
+
|
| 25 |
+
启动时:
|
| 26 |
+
- 确保数据目录存在
|
| 27 |
+
- 恢复中断的任务(可选)
|
| 28 |
+
|
| 29 |
+
关闭时:
|
| 30 |
+
- 清理资源
|
| 31 |
+
"""
|
| 32 |
+
# 启动时执行
|
| 33 |
+
print(f"Starting GPT-SoVITS Training API in {settings.DEPLOYMENT_MODE.upper()} mode")
|
| 34 |
+
print(f" Project Root: {settings.PROJECT_ROOT}")
|
| 35 |
+
print(f" Data Directory: {settings.DATA_DIR}")
|
| 36 |
+
print(f" SQLite Path: {settings.SQLITE_PATH}")
|
| 37 |
+
|
| 38 |
+
# 确保数据目录存在
|
| 39 |
+
ensure_data_dirs()
|
| 40 |
+
|
| 41 |
+
# 恢复中断的任务(可选)
|
| 42 |
+
if settings.DEPLOYMENT_MODE == "local":
|
| 43 |
+
try:
|
| 44 |
+
from .core.adapters import get_task_queue_adapter
|
| 45 |
+
queue = get_task_queue_adapter()
|
| 46 |
+
# 检查是否有 recover_pending_tasks 方法
|
| 47 |
+
if hasattr(queue, 'recover_pending_tasks'):
|
| 48 |
+
count = await queue.recover_pending_tasks()
|
| 49 |
+
if count > 0:
|
| 50 |
+
print(f" Recovered {count} pending tasks")
|
| 51 |
+
except Exception as e:
|
| 52 |
+
print(f" Warning: Failed to recover tasks: {e}")
|
| 53 |
+
|
| 54 |
+
print(" API Server ready!")
|
| 55 |
+
print(f" Docs: http://{settings.API_HOST}:{settings.API_PORT}/docs")
|
| 56 |
+
|
| 57 |
+
yield
|
| 58 |
+
|
| 59 |
+
# 关闭时执行
|
| 60 |
+
print("Shutting down GPT-SoVITS Training API...")
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
# 创建 FastAPI 应用
|
| 64 |
+
app = FastAPI(
|
| 65 |
+
title="GPT-SoVITS Training API",
|
| 66 |
+
description="""
|
| 67 |
+
GPT-SoVITS 音色训练 HTTP API 服务
|
| 68 |
+
|
| 69 |
+
## 功能概述
|
| 70 |
+
|
| 71 |
+
提供两种训练模式:
|
| 72 |
+
|
| 73 |
+
### Quick Mode(小白用户)
|
| 74 |
+
- 上传音频即可训练,系统自动配置所有参数
|
| 75 |
+
- 适合个人开发者、快速验证
|
| 76 |
+
|
| 77 |
+
### Advanced Mode(专家用户)
|
| 78 |
+
- 分阶段控制训练流程
|
| 79 |
+
- 精细调整每个阶段的参数
|
| 80 |
+
- 适合需要深度定制的用户
|
| 81 |
+
|
| 82 |
+
## API 分组
|
| 83 |
+
|
| 84 |
+
- **Quick Mode - 任务管理**: `/api/v1/tasks`
|
| 85 |
+
- **Advanced Mode - 实验管理**: `/api/v1/experiments`
|
| 86 |
+
- **文件管理**: `/api/v1/files`
|
| 87 |
+
- **阶段模板**: `/api/v1/stages`
|
| 88 |
+
""",
|
| 89 |
+
version="1.0.0",
|
| 90 |
+
lifespan=lifespan,
|
| 91 |
+
docs_url="/docs",
|
| 92 |
+
redoc_url="/redoc",
|
| 93 |
+
openapi_url="/openapi.json",
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
# 配置 CORS
|
| 97 |
+
app.add_middleware(
|
| 98 |
+
CORSMiddleware,
|
| 99 |
+
allow_origins=["*"], # 生产环境应该限制来源
|
| 100 |
+
allow_credentials=True,
|
| 101 |
+
allow_methods=["*"],
|
| 102 |
+
allow_headers=["*"],
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
# 注册 API 路由
|
| 106 |
+
app.include_router(api_router, prefix=settings.API_V1_PREFIX)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
# ============================================================
|
| 110 |
+
# 根路由和健康检查
|
| 111 |
+
# ============================================================
|
| 112 |
+
|
| 113 |
+
@app.get("/", tags=["Root"])
|
| 114 |
+
async def root():
|
| 115 |
+
"""
|
| 116 |
+
根路由
|
| 117 |
+
|
| 118 |
+
返回 API 基本信息
|
| 119 |
+
"""
|
| 120 |
+
return {
|
| 121 |
+
"name": "GPT-SoVITS Training API",
|
| 122 |
+
"version": "1.0.0",
|
| 123 |
+
"mode": settings.DEPLOYMENT_MODE,
|
| 124 |
+
"docs": "/docs",
|
| 125 |
+
"health": "/health",
|
| 126 |
+
}
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
@app.get("/health", tags=["Health"])
|
| 130 |
+
async def health_check():
|
| 131 |
+
"""
|
| 132 |
+
健康检查端点
|
| 133 |
+
|
| 134 |
+
用于容器编排和负载均衡器健康检查
|
| 135 |
+
"""
|
| 136 |
+
return {
|
| 137 |
+
"status": "healthy",
|
| 138 |
+
"mode": settings.DEPLOYMENT_MODE,
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
# ============================================================
|
| 143 |
+
# 开发模式直接运行
|
| 144 |
+
# ============================================================
|
| 145 |
+
|
| 146 |
+
if __name__ == "__main__":
|
| 147 |
+
import uvicorn
|
| 148 |
+
|
| 149 |
+
uvicorn.run(
|
| 150 |
+
"api_server.app.main:app",
|
| 151 |
+
host=settings.API_HOST,
|
| 152 |
+
port=settings.API_PORT,
|
| 153 |
+
reload=True,
|
| 154 |
+
reload_dirs=[str(settings.API_SERVER_ROOT)],
|
| 155 |
+
)
|
api_server/app/models/__init__.py
CHANGED
|
@@ -6,4 +6,75 @@
|
|
| 6 |
|
| 7 |
from .domain import Task, TaskStatus, ProgressInfo
|
| 8 |
|
| 9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
from .domain import Task, TaskStatus, ProgressInfo
|
| 8 |
|
| 9 |
+
# Pydantic Schemas
|
| 10 |
+
from .schemas import (
|
| 11 |
+
# Common
|
| 12 |
+
SuccessResponse,
|
| 13 |
+
ErrorResponse,
|
| 14 |
+
PaginatedResponse,
|
| 15 |
+
# Task (Quick Mode)
|
| 16 |
+
QuickModeOptions,
|
| 17 |
+
QuickModeRequest,
|
| 18 |
+
TaskResponse,
|
| 19 |
+
TaskListResponse,
|
| 20 |
+
# Experiment (Advanced Mode)
|
| 21 |
+
StageType,
|
| 22 |
+
ExperimentCreate,
|
| 23 |
+
ExperimentUpdate,
|
| 24 |
+
StageStatus,
|
| 25 |
+
ExperimentResponse,
|
| 26 |
+
ExperimentListResponse,
|
| 27 |
+
StageExecuteRequest,
|
| 28 |
+
AudioSliceParams,
|
| 29 |
+
ASRParams,
|
| 30 |
+
TextFeatureParams,
|
| 31 |
+
HubertFeatureParams,
|
| 32 |
+
SemanticTokenParams,
|
| 33 |
+
SoVITSTrainParams,
|
| 34 |
+
GPTTrainParams,
|
| 35 |
+
StageExecuteResponse,
|
| 36 |
+
StagesListResponse,
|
| 37 |
+
# File
|
| 38 |
+
FileUploadResponse,
|
| 39 |
+
FileMetadata,
|
| 40 |
+
FileListResponse,
|
| 41 |
+
FileDeleteResponse,
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
__all__ = [
|
| 45 |
+
# Domain models
|
| 46 |
+
"Task",
|
| 47 |
+
"TaskStatus",
|
| 48 |
+
"ProgressInfo",
|
| 49 |
+
# Common schemas
|
| 50 |
+
"SuccessResponse",
|
| 51 |
+
"ErrorResponse",
|
| 52 |
+
"PaginatedResponse",
|
| 53 |
+
# Task schemas (Quick Mode)
|
| 54 |
+
"QuickModeOptions",
|
| 55 |
+
"QuickModeRequest",
|
| 56 |
+
"TaskResponse",
|
| 57 |
+
"TaskListResponse",
|
| 58 |
+
# Experiment schemas (Advanced Mode)
|
| 59 |
+
"StageType",
|
| 60 |
+
"ExperimentCreate",
|
| 61 |
+
"ExperimentUpdate",
|
| 62 |
+
"StageStatus",
|
| 63 |
+
"ExperimentResponse",
|
| 64 |
+
"ExperimentListResponse",
|
| 65 |
+
"StageExecuteRequest",
|
| 66 |
+
"AudioSliceParams",
|
| 67 |
+
"ASRParams",
|
| 68 |
+
"TextFeatureParams",
|
| 69 |
+
"HubertFeatureParams",
|
| 70 |
+
"SemanticTokenParams",
|
| 71 |
+
"SoVITSTrainParams",
|
| 72 |
+
"GPTTrainParams",
|
| 73 |
+
"StageExecuteResponse",
|
| 74 |
+
"StagesListResponse",
|
| 75 |
+
# File schemas
|
| 76 |
+
"FileUploadResponse",
|
| 77 |
+
"FileMetadata",
|
| 78 |
+
"FileListResponse",
|
| 79 |
+
"FileDeleteResponse",
|
| 80 |
+
]
|
api_server/app/models/schemas/__init__.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Pydantic Schema 模块
|
| 3 |
+
|
| 4 |
+
包含 API 请求/响应的数据验证模型
|
| 5 |
+
|
| 6 |
+
- common: 通用响应模型
|
| 7 |
+
- task: Quick Mode 任务模型
|
| 8 |
+
- experiment: Advanced Mode 实验/阶段模型
|
| 9 |
+
- file: 文件管理模型
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from .common import (
|
| 13 |
+
SuccessResponse,
|
| 14 |
+
ErrorResponse,
|
| 15 |
+
PaginatedResponse,
|
| 16 |
+
)
|
| 17 |
+
from .task import (
|
| 18 |
+
QuickModeOptions,
|
| 19 |
+
QuickModeRequest,
|
| 20 |
+
TaskResponse,
|
| 21 |
+
TaskListResponse,
|
| 22 |
+
)
|
| 23 |
+
from .experiment import (
|
| 24 |
+
StageType,
|
| 25 |
+
ExperimentCreate,
|
| 26 |
+
ExperimentUpdate,
|
| 27 |
+
StageStatus,
|
| 28 |
+
ExperimentResponse,
|
| 29 |
+
ExperimentListResponse,
|
| 30 |
+
StageExecuteRequest,
|
| 31 |
+
AudioSliceParams,
|
| 32 |
+
ASRParams,
|
| 33 |
+
TextFeatureParams,
|
| 34 |
+
HubertFeatureParams,
|
| 35 |
+
SemanticTokenParams,
|
| 36 |
+
SoVITSTrainParams,
|
| 37 |
+
GPTTrainParams,
|
| 38 |
+
StageExecuteResponse,
|
| 39 |
+
StagesListResponse,
|
| 40 |
+
)
|
| 41 |
+
from .file import (
|
| 42 |
+
FileUploadResponse,
|
| 43 |
+
FileMetadata,
|
| 44 |
+
FileListResponse,
|
| 45 |
+
FileDeleteResponse,
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
__all__ = [
|
| 49 |
+
# Common
|
| 50 |
+
"SuccessResponse",
|
| 51 |
+
"ErrorResponse",
|
| 52 |
+
"PaginatedResponse",
|
| 53 |
+
# Task (Quick Mode)
|
| 54 |
+
"QuickModeOptions",
|
| 55 |
+
"QuickModeRequest",
|
| 56 |
+
"TaskResponse",
|
| 57 |
+
"TaskListResponse",
|
| 58 |
+
# Experiment (Advanced Mode)
|
| 59 |
+
"StageType",
|
| 60 |
+
"ExperimentCreate",
|
| 61 |
+
"ExperimentUpdate",
|
| 62 |
+
"StageStatus",
|
| 63 |
+
"ExperimentResponse",
|
| 64 |
+
"ExperimentListResponse",
|
| 65 |
+
"StageExecuteRequest",
|
| 66 |
+
"AudioSliceParams",
|
| 67 |
+
"ASRParams",
|
| 68 |
+
"TextFeatureParams",
|
| 69 |
+
"HubertFeatureParams",
|
| 70 |
+
"SemanticTokenParams",
|
| 71 |
+
"SoVITSTrainParams",
|
| 72 |
+
"GPTTrainParams",
|
| 73 |
+
"StageExecuteResponse",
|
| 74 |
+
"StagesListResponse",
|
| 75 |
+
# File
|
| 76 |
+
"FileUploadResponse",
|
| 77 |
+
"FileMetadata",
|
| 78 |
+
"FileListResponse",
|
| 79 |
+
"FileDeleteResponse",
|
| 80 |
+
]
|
api_server/app/models/schemas/common.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
通用响应模型
|
| 3 |
+
|
| 4 |
+
定义 API 通用的响应结构
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from typing import Any, Generic, List, Optional, TypeVar
|
| 8 |
+
from pydantic import BaseModel, Field
|
| 9 |
+
|
| 10 |
+
# 泛型类型变量,用于分页响应
|
| 11 |
+
T = TypeVar("T")
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class SuccessResponse(BaseModel):
|
| 15 |
+
"""
|
| 16 |
+
通用成功响应
|
| 17 |
+
|
| 18 |
+
Example:
|
| 19 |
+
>>> response = SuccessResponse(message="操作成功")
|
| 20 |
+
>>> response.model_dump()
|
| 21 |
+
{'success': True, 'message': '操作成功'}
|
| 22 |
+
"""
|
| 23 |
+
success: bool = Field(default=True, description="是否成功")
|
| 24 |
+
message: str = Field(default="操作成功", description="响应消息")
|
| 25 |
+
|
| 26 |
+
model_config = {
|
| 27 |
+
"json_schema_extra": {
|
| 28 |
+
"examples": [
|
| 29 |
+
{"success": True, "message": "操作成功"}
|
| 30 |
+
]
|
| 31 |
+
}
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class ErrorResponse(BaseModel):
|
| 36 |
+
"""
|
| 37 |
+
错误响应
|
| 38 |
+
|
| 39 |
+
Example:
|
| 40 |
+
>>> response = ErrorResponse(message="任务不存在", code="TASK_NOT_FOUND")
|
| 41 |
+
>>> response.model_dump()
|
| 42 |
+
{'success': False, 'message': '任务不存在', 'code': 'TASK_NOT_FOUND', 'details': None}
|
| 43 |
+
"""
|
| 44 |
+
success: bool = Field(default=False, description="是否成功")
|
| 45 |
+
message: str = Field(..., description="错误消息")
|
| 46 |
+
code: Optional[str] = Field(default=None, description="错误代码")
|
| 47 |
+
details: Optional[Any] = Field(default=None, description="错误详情")
|
| 48 |
+
|
| 49 |
+
model_config = {
|
| 50 |
+
"json_schema_extra": {
|
| 51 |
+
"examples": [
|
| 52 |
+
{
|
| 53 |
+
"success": False,
|
| 54 |
+
"message": "任务不存在",
|
| 55 |
+
"code": "TASK_NOT_FOUND",
|
| 56 |
+
"details": None
|
| 57 |
+
}
|
| 58 |
+
]
|
| 59 |
+
}
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class PaginatedResponse(BaseModel, Generic[T]):
|
| 64 |
+
"""
|
| 65 |
+
分页响应基类
|
| 66 |
+
|
| 67 |
+
泛型参数 T 表示列表项的类型
|
| 68 |
+
|
| 69 |
+
Example:
|
| 70 |
+
>>> from typing import List
|
| 71 |
+
>>> class TaskListResponse(PaginatedResponse[TaskResponse]):
|
| 72 |
+
... pass
|
| 73 |
+
"""
|
| 74 |
+
items: List[T] = Field(default_factory=list, description="数据列表")
|
| 75 |
+
total: int = Field(default=0, ge=0, description="总数量")
|
| 76 |
+
limit: int = Field(default=50, ge=1, le=100, description="每页数量")
|
| 77 |
+
offset: int = Field(default=0, ge=0, description="偏移量")
|
| 78 |
+
|
| 79 |
+
@property
|
| 80 |
+
def has_more(self) -> bool:
|
| 81 |
+
"""是否有更多数据"""
|
| 82 |
+
return self.offset + len(self.items) < self.total
|
| 83 |
+
|
| 84 |
+
model_config = {
|
| 85 |
+
"json_schema_extra": {
|
| 86 |
+
"examples": [
|
| 87 |
+
{
|
| 88 |
+
"items": [],
|
| 89 |
+
"total": 0,
|
| 90 |
+
"limit": 50,
|
| 91 |
+
"offset": 0
|
| 92 |
+
}
|
| 93 |
+
]
|
| 94 |
+
}
|
| 95 |
+
}
|
api_server/app/models/schemas/experiment.py
ADDED
|
@@ -0,0 +1,556 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Advanced Mode 实验/阶段 Schema
|
| 3 |
+
|
| 4 |
+
专家用户分阶段训练模式的请求/响应模型
|
| 5 |
+
|
| 6 |
+
参考文档: development.md 4.6.2
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from datetime import datetime
|
| 10 |
+
from enum import Enum
|
| 11 |
+
from typing import Any, Dict, List, Literal, Optional
|
| 12 |
+
from pydantic import BaseModel, Field
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# ============================================================
|
| 16 |
+
# 枚举类型
|
| 17 |
+
# ============================================================
|
| 18 |
+
|
| 19 |
+
class StageType(str, Enum):
|
| 20 |
+
"""
|
| 21 |
+
训练阶段类型枚举
|
| 22 |
+
|
| 23 |
+
定义了完整训练流程中的所有阶段
|
| 24 |
+
"""
|
| 25 |
+
AUDIO_SLICE = "audio_slice" # 音频切片
|
| 26 |
+
ASR = "asr" # 语音识别
|
| 27 |
+
TEXT_FEATURE = "text_feature" # 文本特征提取
|
| 28 |
+
HUBERT_FEATURE = "hubert_feature" # HuBERT 特征提取
|
| 29 |
+
SEMANTIC_TOKEN = "semantic_token" # 语义 Token 提取
|
| 30 |
+
SOVITS_TRAIN = "sovits_train" # SoVITS 训练
|
| 31 |
+
GPT_TRAIN = "gpt_train" # GPT 训练
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
# 阶段依赖关系
|
| 35 |
+
STAGE_DEPENDENCIES: Dict[str, List[str]] = {
|
| 36 |
+
"audio_slice": [],
|
| 37 |
+
"asr": ["audio_slice"],
|
| 38 |
+
"text_feature": ["asr"],
|
| 39 |
+
"hubert_feature": ["audio_slice"],
|
| 40 |
+
"semantic_token": ["hubert_feature"],
|
| 41 |
+
"sovits_train": ["text_feature", "semantic_token"],
|
| 42 |
+
"gpt_train": ["text_feature", "semantic_token"],
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
# ============================================================
|
| 47 |
+
# 实验管理
|
| 48 |
+
# ============================================================
|
| 49 |
+
|
| 50 |
+
class ExperimentCreate(BaseModel):
|
| 51 |
+
"""
|
| 52 |
+
创建实验请求
|
| 53 |
+
|
| 54 |
+
创建实验但不立即执行,用户可以逐阶段控制训练流程
|
| 55 |
+
|
| 56 |
+
Attributes:
|
| 57 |
+
exp_name: 实验名称
|
| 58 |
+
version: 模型版本
|
| 59 |
+
gpu_numbers: GPU 编号
|
| 60 |
+
is_half: 是否使用半精度
|
| 61 |
+
audio_file_id: 音频文件 ID
|
| 62 |
+
"""
|
| 63 |
+
exp_name: str = Field(
|
| 64 |
+
...,
|
| 65 |
+
min_length=1,
|
| 66 |
+
max_length=100,
|
| 67 |
+
description="实验名称"
|
| 68 |
+
)
|
| 69 |
+
version: Literal["v1", "v2", "v2Pro", "v3", "v4"] = Field(
|
| 70 |
+
default="v2",
|
| 71 |
+
description="模型版本"
|
| 72 |
+
)
|
| 73 |
+
gpu_numbers: str = Field(
|
| 74 |
+
default="0",
|
| 75 |
+
description="GPU 编号,多个 GPU 用逗号分隔,如 '0,1'"
|
| 76 |
+
)
|
| 77 |
+
is_half: bool = Field(
|
| 78 |
+
default=True,
|
| 79 |
+
description="是否使用半精度(FP16),可节省显存"
|
| 80 |
+
)
|
| 81 |
+
audio_file_id: str = Field(
|
| 82 |
+
...,
|
| 83 |
+
description="已上传音频文件的 ID"
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
model_config = {
|
| 87 |
+
"json_schema_extra": {
|
| 88 |
+
"examples": [
|
| 89 |
+
{
|
| 90 |
+
"exp_name": "my_voice_custom",
|
| 91 |
+
"version": "v2",
|
| 92 |
+
"gpu_numbers": "0",
|
| 93 |
+
"is_half": True,
|
| 94 |
+
"audio_file_id": "550e8400-e29b-41d4-a716-446655440000"
|
| 95 |
+
}
|
| 96 |
+
]
|
| 97 |
+
}
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
class ExperimentUpdate(BaseModel):
|
| 102 |
+
"""
|
| 103 |
+
更新实验请求
|
| 104 |
+
|
| 105 |
+
用于更新实验的基础配置(非阶段参数)
|
| 106 |
+
"""
|
| 107 |
+
exp_name: Optional[str] = Field(
|
| 108 |
+
default=None,
|
| 109 |
+
min_length=1,
|
| 110 |
+
max_length=100,
|
| 111 |
+
description="实验名称"
|
| 112 |
+
)
|
| 113 |
+
gpu_numbers: Optional[str] = Field(
|
| 114 |
+
default=None,
|
| 115 |
+
description="GPU 编号"
|
| 116 |
+
)
|
| 117 |
+
is_half: Optional[bool] = Field(
|
| 118 |
+
default=None,
|
| 119 |
+
description="是否使用半精度"
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
class StageStatus(BaseModel):
|
| 124 |
+
"""
|
| 125 |
+
阶段状态
|
| 126 |
+
|
| 127 |
+
描述单个阶段的执行状态和结果
|
| 128 |
+
"""
|
| 129 |
+
stage_type: str = Field(..., description="阶段类型")
|
| 130 |
+
status: Literal["pending", "running", "completed", "failed", "cancelled"] = Field(
|
| 131 |
+
default="pending",
|
| 132 |
+
description="阶段状态"
|
| 133 |
+
)
|
| 134 |
+
progress: Optional[float] = Field(
|
| 135 |
+
default=None,
|
| 136 |
+
ge=0.0,
|
| 137 |
+
le=1.0,
|
| 138 |
+
description="阶段进度 (0.0-1.0)"
|
| 139 |
+
)
|
| 140 |
+
message: Optional[str] = Field(
|
| 141 |
+
default=None,
|
| 142 |
+
description="状态消息"
|
| 143 |
+
)
|
| 144 |
+
started_at: Optional[datetime] = Field(
|
| 145 |
+
default=None,
|
| 146 |
+
description="开始时间"
|
| 147 |
+
)
|
| 148 |
+
completed_at: Optional[datetime] = Field(
|
| 149 |
+
default=None,
|
| 150 |
+
description="完成时间"
|
| 151 |
+
)
|
| 152 |
+
config: Optional[Dict[str, Any]] = Field(
|
| 153 |
+
default=None,
|
| 154 |
+
description="阶段配置参数"
|
| 155 |
+
)
|
| 156 |
+
outputs: Optional[Dict[str, Any]] = Field(
|
| 157 |
+
default=None,
|
| 158 |
+
description="阶段输出结果"
|
| 159 |
+
)
|
| 160 |
+
error_message: Optional[str] = Field(
|
| 161 |
+
default=None,
|
| 162 |
+
description="错误消息(失败时)"
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
model_config = {
|
| 166 |
+
"json_schema_extra": {
|
| 167 |
+
"examples": [
|
| 168 |
+
{
|
| 169 |
+
"stage_type": "sovits_train",
|
| 170 |
+
"status": "completed",
|
| 171 |
+
"progress": 1.0,
|
| 172 |
+
"message": "训练完成",
|
| 173 |
+
"started_at": "2024-01-01T10:30:00Z",
|
| 174 |
+
"completed_at": "2024-01-01T11:00:00Z",
|
| 175 |
+
"config": {"batch_size": 8, "total_epoch": 16},
|
| 176 |
+
"outputs": {
|
| 177 |
+
"model_path": "logs/my_voice/sovits_e16.pth",
|
| 178 |
+
"metrics": {"final_loss": 0.023}
|
| 179 |
+
}
|
| 180 |
+
}
|
| 181 |
+
]
|
| 182 |
+
}
|
| 183 |
+
}
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
class ExperimentResponse(BaseModel):
|
| 187 |
+
"""
|
| 188 |
+
实验响应
|
| 189 |
+
|
| 190 |
+
包含实验的完整信息和所有阶段状态
|
| 191 |
+
"""
|
| 192 |
+
id: str = Field(..., description="实验唯一标识")
|
| 193 |
+
exp_name: str = Field(..., description="实验名称")
|
| 194 |
+
version: str = Field(..., description="模型版本")
|
| 195 |
+
status: str = Field(..., description="实验状态")
|
| 196 |
+
gpu_numbers: str = Field(default="0", description="GPU 编号")
|
| 197 |
+
is_half: bool = Field(default=True, description="是否使用半精度")
|
| 198 |
+
audio_file_id: str = Field(..., description="音频文件 ID")
|
| 199 |
+
stages: Dict[str, StageStatus] = Field(
|
| 200 |
+
default_factory=dict,
|
| 201 |
+
description="各阶段状态"
|
| 202 |
+
)
|
| 203 |
+
created_at: datetime = Field(..., description="创建时间")
|
| 204 |
+
updated_at: Optional[datetime] = Field(default=None, description="更新时间")
|
| 205 |
+
|
| 206 |
+
model_config = {
|
| 207 |
+
"json_schema_extra": {
|
| 208 |
+
"examples": [
|
| 209 |
+
{
|
| 210 |
+
"id": "exp-abc123",
|
| 211 |
+
"exp_name": "my_voice_custom",
|
| 212 |
+
"version": "v2",
|
| 213 |
+
"status": "created",
|
| 214 |
+
"gpu_numbers": "0",
|
| 215 |
+
"is_half": True,
|
| 216 |
+
"audio_file_id": "550e8400-e29b-41d4-a716-446655440000",
|
| 217 |
+
"stages": {
|
| 218 |
+
"audio_slice": {"stage_type": "audio_slice", "status": "pending"},
|
| 219 |
+
"asr": {"stage_type": "asr", "status": "pending"},
|
| 220 |
+
"sovits_train": {"stage_type": "sovits_train", "status": "pending"}
|
| 221 |
+
},
|
| 222 |
+
"created_at": "2024-01-01T10:00:00Z"
|
| 223 |
+
}
|
| 224 |
+
]
|
| 225 |
+
}
|
| 226 |
+
}
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
class ExperimentListResponse(BaseModel):
|
| 230 |
+
"""
|
| 231 |
+
实验列表响应
|
| 232 |
+
"""
|
| 233 |
+
items: List[ExperimentResponse] = Field(
|
| 234 |
+
default_factory=list,
|
| 235 |
+
description="实验列表"
|
| 236 |
+
)
|
| 237 |
+
total: int = Field(default=0, ge=0, description="总数量")
|
| 238 |
+
limit: int = Field(default=50, ge=1, le=100, description="每页数量")
|
| 239 |
+
offset: int = Field(default=0, ge=0, description="偏移量")
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
# ============================================================
|
| 243 |
+
# 阶段执行参数
|
| 244 |
+
# ============================================================
|
| 245 |
+
|
| 246 |
+
class StageExecuteRequest(BaseModel):
|
| 247 |
+
"""
|
| 248 |
+
阶段执行请求基类
|
| 249 |
+
|
| 250 |
+
允许传入任意额外参数
|
| 251 |
+
"""
|
| 252 |
+
model_config = {
|
| 253 |
+
"extra": "allow" # 允许额外字段(阶段特定参数)
|
| 254 |
+
}
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
class AudioSliceParams(StageExecuteRequest):
|
| 258 |
+
"""
|
| 259 |
+
音频切片参数
|
| 260 |
+
|
| 261 |
+
将长音频切分为短片段
|
| 262 |
+
|
| 263 |
+
参考文档: development.md 4.5.2
|
| 264 |
+
"""
|
| 265 |
+
threshold: int = Field(
|
| 266 |
+
default=-34,
|
| 267 |
+
ge=-60,
|
| 268 |
+
le=0,
|
| 269 |
+
description="静音检测阈值 (dB)"
|
| 270 |
+
)
|
| 271 |
+
min_length: int = Field(
|
| 272 |
+
default=4000,
|
| 273 |
+
ge=1000,
|
| 274 |
+
le=10000,
|
| 275 |
+
description="最小切片长度 (ms)"
|
| 276 |
+
)
|
| 277 |
+
min_interval: int = Field(
|
| 278 |
+
default=300,
|
| 279 |
+
ge=100,
|
| 280 |
+
le=1000,
|
| 281 |
+
description="最小静音间隔 (ms)"
|
| 282 |
+
)
|
| 283 |
+
hop_size: int = Field(
|
| 284 |
+
default=10,
|
| 285 |
+
ge=5,
|
| 286 |
+
le=50,
|
| 287 |
+
description="检测步长 (ms)"
|
| 288 |
+
)
|
| 289 |
+
max_sil_kept: int = Field(
|
| 290 |
+
default=500,
|
| 291 |
+
ge=100,
|
| 292 |
+
le=2000,
|
| 293 |
+
description="切片保留的最大静音长度 (ms)"
|
| 294 |
+
)
|
| 295 |
+
|
| 296 |
+
model_config = {
|
| 297 |
+
"json_schema_extra": {
|
| 298 |
+
"examples": [
|
| 299 |
+
{
|
| 300 |
+
"threshold": -34,
|
| 301 |
+
"min_length": 4000,
|
| 302 |
+
"min_interval": 300,
|
| 303 |
+
"hop_size": 10,
|
| 304 |
+
"max_sil_kept": 500
|
| 305 |
+
}
|
| 306 |
+
]
|
| 307 |
+
}
|
| 308 |
+
}
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
class ASRParams(StageExecuteRequest):
|
| 312 |
+
"""
|
| 313 |
+
ASR 语音识别参数
|
| 314 |
+
"""
|
| 315 |
+
model: str = Field(
|
| 316 |
+
default="达摩 ASR (中文)",
|
| 317 |
+
description="ASR 模型名称"
|
| 318 |
+
)
|
| 319 |
+
language: str = Field(
|
| 320 |
+
default="zh",
|
| 321 |
+
description="识别语言"
|
| 322 |
+
)
|
| 323 |
+
|
| 324 |
+
model_config = {
|
| 325 |
+
"json_schema_extra": {
|
| 326 |
+
"examples": [
|
| 327 |
+
{"model": "达摩 ASR (中文)", "language": "zh"}
|
| 328 |
+
]
|
| 329 |
+
}
|
| 330 |
+
}
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
class TextFeatureParams(StageExecuteRequest):
|
| 334 |
+
"""
|
| 335 |
+
文本特征提取参数
|
| 336 |
+
"""
|
| 337 |
+
bert_pretrained_dir: Optional[str] = Field(
|
| 338 |
+
default=None,
|
| 339 |
+
description="BERT 预训练模型目录,为空使用默认"
|
| 340 |
+
)
|
| 341 |
+
|
| 342 |
+
model_config = {
|
| 343 |
+
"json_schema_extra": {
|
| 344 |
+
"examples": [
|
| 345 |
+
{"bert_pretrained_dir": None}
|
| 346 |
+
]
|
| 347 |
+
}
|
| 348 |
+
}
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
class HubertFeatureParams(StageExecuteRequest):
|
| 352 |
+
"""
|
| 353 |
+
HuBERT 特征提取参数
|
| 354 |
+
"""
|
| 355 |
+
ssl_pretrained_dir: Optional[str] = Field(
|
| 356 |
+
default=None,
|
| 357 |
+
description="SSL 预训练模型目录,为空使用默认"
|
| 358 |
+
)
|
| 359 |
+
|
| 360 |
+
model_config = {
|
| 361 |
+
"json_schema_extra": {
|
| 362 |
+
"examples": [
|
| 363 |
+
{"ssl_pretrained_dir": None}
|
| 364 |
+
]
|
| 365 |
+
}
|
| 366 |
+
}
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
class SemanticTokenParams(StageExecuteRequest):
|
| 370 |
+
"""
|
| 371 |
+
语义 Token 提取参数
|
| 372 |
+
"""
|
| 373 |
+
# 当前阶段无特殊参数,保留扩展性
|
| 374 |
+
pass
|
| 375 |
+
|
| 376 |
+
|
| 377 |
+
class SoVITSTrainParams(StageExecuteRequest):
|
| 378 |
+
"""
|
| 379 |
+
SoVITS 训练参数
|
| 380 |
+
|
| 381 |
+
参考文档: development.md 4.5.2
|
| 382 |
+
"""
|
| 383 |
+
batch_size: int = Field(
|
| 384 |
+
default=4,
|
| 385 |
+
ge=1,
|
| 386 |
+
le=32,
|
| 387 |
+
description="批次大小,显存不足时减小"
|
| 388 |
+
)
|
| 389 |
+
total_epoch: int = Field(
|
| 390 |
+
default=8,
|
| 391 |
+
ge=1,
|
| 392 |
+
le=100,
|
| 393 |
+
description="训练总轮数"
|
| 394 |
+
)
|
| 395 |
+
save_every_epoch: int = Field(
|
| 396 |
+
default=4,
|
| 397 |
+
ge=1,
|
| 398 |
+
description="每 N 轮保存一次模型"
|
| 399 |
+
)
|
| 400 |
+
pretrained_s2G: Optional[str] = Field(
|
| 401 |
+
default=None,
|
| 402 |
+
description="预训练生成器模型路径,为空使用默认"
|
| 403 |
+
)
|
| 404 |
+
pretrained_s2D: Optional[str] = Field(
|
| 405 |
+
default=None,
|
| 406 |
+
description="预训练判别器模型路径,为空使用默认"
|
| 407 |
+
)
|
| 408 |
+
|
| 409 |
+
model_config = {
|
| 410 |
+
"json_schema_extra": {
|
| 411 |
+
"examples": [
|
| 412 |
+
{
|
| 413 |
+
"batch_size": 8,
|
| 414 |
+
"total_epoch": 16,
|
| 415 |
+
"save_every_epoch": 4,
|
| 416 |
+
"pretrained_s2G": None,
|
| 417 |
+
"pretrained_s2D": None
|
| 418 |
+
}
|
| 419 |
+
]
|
| 420 |
+
}
|
| 421 |
+
}
|
| 422 |
+
|
| 423 |
+
|
| 424 |
+
class GPTTrainParams(StageExecuteRequest):
|
| 425 |
+
"""
|
| 426 |
+
GPT 训练参数
|
| 427 |
+
"""
|
| 428 |
+
batch_size: int = Field(
|
| 429 |
+
default=4,
|
| 430 |
+
ge=1,
|
| 431 |
+
le=32,
|
| 432 |
+
description="批次大小"
|
| 433 |
+
)
|
| 434 |
+
total_epoch: int = Field(
|
| 435 |
+
default=15,
|
| 436 |
+
ge=1,
|
| 437 |
+
le=100,
|
| 438 |
+
description="训练总轮数"
|
| 439 |
+
)
|
| 440 |
+
save_every_epoch: int = Field(
|
| 441 |
+
default=5,
|
| 442 |
+
ge=1,
|
| 443 |
+
description="每 N 轮保存一次模型"
|
| 444 |
+
)
|
| 445 |
+
pretrained_s1: Optional[str] = Field(
|
| 446 |
+
default=None,
|
| 447 |
+
description="预训练模型路径,为空使用默认"
|
| 448 |
+
)
|
| 449 |
+
|
| 450 |
+
model_config = {
|
| 451 |
+
"json_schema_extra": {
|
| 452 |
+
"examples": [
|
| 453 |
+
{
|
| 454 |
+
"batch_size": 4,
|
| 455 |
+
"total_epoch": 15,
|
| 456 |
+
"save_every_epoch": 5,
|
| 457 |
+
"pretrained_s1": None
|
| 458 |
+
}
|
| 459 |
+
]
|
| 460 |
+
}
|
| 461 |
+
}
|
| 462 |
+
|
| 463 |
+
|
| 464 |
+
class StageExecuteResponse(BaseModel):
|
| 465 |
+
"""
|
| 466 |
+
阶段执行响应
|
| 467 |
+
"""
|
| 468 |
+
exp_id: str = Field(..., description="实验 ID")
|
| 469 |
+
stage_type: str = Field(..., description="阶段类型")
|
| 470 |
+
status: Literal["running", "queued"] = Field(..., description="执行状态")
|
| 471 |
+
job_id: str = Field(..., description="作业 ID")
|
| 472 |
+
config: Dict[str, Any] = Field(
|
| 473 |
+
default_factory=dict,
|
| 474 |
+
description="阶段配置"
|
| 475 |
+
)
|
| 476 |
+
rerun: bool = Field(
|
| 477 |
+
default=False,
|
| 478 |
+
description="是否为重新执行"
|
| 479 |
+
)
|
| 480 |
+
previous_run: Optional[Dict[str, Any]] = Field(
|
| 481 |
+
default=None,
|
| 482 |
+
description="上次执行的信息(重新执行时)"
|
| 483 |
+
)
|
| 484 |
+
started_at: datetime = Field(..., description="开始时间")
|
| 485 |
+
|
| 486 |
+
model_config = {
|
| 487 |
+
"json_schema_extra": {
|
| 488 |
+
"examples": [
|
| 489 |
+
{
|
| 490 |
+
"exp_id": "exp-abc123",
|
| 491 |
+
"stage_type": "sovits_train",
|
| 492 |
+
"status": "running",
|
| 493 |
+
"job_id": "job-xyz789",
|
| 494 |
+
"config": {"batch_size": 8, "total_epoch": 16},
|
| 495 |
+
"rerun": False,
|
| 496 |
+
"started_at": "2024-01-01T10:30:00Z"
|
| 497 |
+
}
|
| 498 |
+
]
|
| 499 |
+
}
|
| 500 |
+
}
|
| 501 |
+
|
| 502 |
+
|
| 503 |
+
class StagesListResponse(BaseModel):
|
| 504 |
+
"""
|
| 505 |
+
所有阶段状态响应
|
| 506 |
+
"""
|
| 507 |
+
exp_id: str = Field(..., description="实验 ID")
|
| 508 |
+
stages: List[StageStatus] = Field(
|
| 509 |
+
default_factory=list,
|
| 510 |
+
description="阶段状态列表"
|
| 511 |
+
)
|
| 512 |
+
|
| 513 |
+
model_config = {
|
| 514 |
+
"json_schema_extra": {
|
| 515 |
+
"examples": [
|
| 516 |
+
{
|
| 517 |
+
"exp_id": "exp-abc123",
|
| 518 |
+
"stages": [
|
| 519 |
+
{"stage_type": "audio_slice", "status": "completed"},
|
| 520 |
+
{"stage_type": "asr", "status": "completed"},
|
| 521 |
+
{"stage_type": "sovits_train", "status": "running", "progress": 0.45}
|
| 522 |
+
]
|
| 523 |
+
}
|
| 524 |
+
]
|
| 525 |
+
}
|
| 526 |
+
}
|
| 527 |
+
|
| 528 |
+
|
| 529 |
+
# 阶段参数类型映射
|
| 530 |
+
STAGE_PARAMS_MAP: Dict[str, type] = {
|
| 531 |
+
"audio_slice": AudioSliceParams,
|
| 532 |
+
"asr": ASRParams,
|
| 533 |
+
"text_feature": TextFeatureParams,
|
| 534 |
+
"hubert_feature": HubertFeatureParams,
|
| 535 |
+
"semantic_token": SemanticTokenParams,
|
| 536 |
+
"sovits_train": SoVITSTrainParams,
|
| 537 |
+
"gpt_train": GPTTrainParams,
|
| 538 |
+
}
|
| 539 |
+
|
| 540 |
+
|
| 541 |
+
def get_stage_params_class(stage_type: str) -> type:
|
| 542 |
+
"""
|
| 543 |
+
获取阶段对应的参数类
|
| 544 |
+
|
| 545 |
+
Args:
|
| 546 |
+
stage_type: 阶段类型
|
| 547 |
+
|
| 548 |
+
Returns:
|
| 549 |
+
对应的参数 Pydantic 类
|
| 550 |
+
|
| 551 |
+
Raises:
|
| 552 |
+
ValueError: 无效的阶段类型
|
| 553 |
+
"""
|
| 554 |
+
if stage_type not in STAGE_PARAMS_MAP:
|
| 555 |
+
raise ValueError(f"Invalid stage type: {stage_type}")
|
| 556 |
+
return STAGE_PARAMS_MAP[stage_type]
|
api_server/app/models/schemas/file.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
文件管理 Schema
|
| 3 |
+
|
| 4 |
+
文件上传/下载相关的请求/响应模型
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from datetime import datetime
|
| 8 |
+
from typing import List, Literal, Optional
|
| 9 |
+
from pydantic import BaseModel, Field
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class FileMetadata(BaseModel):
|
| 13 |
+
"""
|
| 14 |
+
文件元数据
|
| 15 |
+
|
| 16 |
+
描述已上传文件的详细信息
|
| 17 |
+
"""
|
| 18 |
+
id: str = Field(..., description="文件唯一标识")
|
| 19 |
+
filename: str = Field(..., description="原始文件名")
|
| 20 |
+
content_type: Optional[str] = Field(
|
| 21 |
+
default=None,
|
| 22 |
+
description="MIME 类型,如 'audio/wav', 'audio/mp3'"
|
| 23 |
+
)
|
| 24 |
+
size_bytes: int = Field(
|
| 25 |
+
default=0,
|
| 26 |
+
ge=0,
|
| 27 |
+
description="文件大小(字节)"
|
| 28 |
+
)
|
| 29 |
+
purpose: Optional[Literal["training", "reference", "output"]] = Field(
|
| 30 |
+
default="training",
|
| 31 |
+
description="文件用途:training(训练), reference(参考音频), output(输出模型)"
|
| 32 |
+
)
|
| 33 |
+
duration_seconds: Optional[float] = Field(
|
| 34 |
+
default=None,
|
| 35 |
+
ge=0,
|
| 36 |
+
description="音频时长(秒),仅音频文件有效"
|
| 37 |
+
)
|
| 38 |
+
sample_rate: Optional[int] = Field(
|
| 39 |
+
default=None,
|
| 40 |
+
ge=0,
|
| 41 |
+
description="采样率(Hz),仅音频文件有效"
|
| 42 |
+
)
|
| 43 |
+
uploaded_at: datetime = Field(..., description="上传时间")
|
| 44 |
+
|
| 45 |
+
model_config = {
|
| 46 |
+
"json_schema_extra": {
|
| 47 |
+
"examples": [
|
| 48 |
+
{
|
| 49 |
+
"id": "550e8400-e29b-41d4-a716-446655440000",
|
| 50 |
+
"filename": "my_voice.wav",
|
| 51 |
+
"content_type": "audio/wav",
|
| 52 |
+
"size_bytes": 15728640,
|
| 53 |
+
"purpose": "training",
|
| 54 |
+
"duration_seconds": 120.5,
|
| 55 |
+
"sample_rate": 44100,
|
| 56 |
+
"uploaded_at": "2024-01-01T10:00:00Z"
|
| 57 |
+
}
|
| 58 |
+
]
|
| 59 |
+
}
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class FileUploadResponse(BaseModel):
|
| 64 |
+
"""
|
| 65 |
+
文件上传响应
|
| 66 |
+
|
| 67 |
+
上传成功后返回文件信息
|
| 68 |
+
"""
|
| 69 |
+
success: bool = Field(default=True, description="是否成功")
|
| 70 |
+
message: str = Field(default="文件上传成功", description="响应消息")
|
| 71 |
+
file: FileMetadata = Field(..., description="文件元数据")
|
| 72 |
+
|
| 73 |
+
model_config = {
|
| 74 |
+
"json_schema_extra": {
|
| 75 |
+
"examples": [
|
| 76 |
+
{
|
| 77 |
+
"success": True,
|
| 78 |
+
"message": "文件上传成功",
|
| 79 |
+
"file": {
|
| 80 |
+
"id": "550e8400-e29b-41d4-a716-446655440000",
|
| 81 |
+
"filename": "my_voice.wav",
|
| 82 |
+
"content_type": "audio/wav",
|
| 83 |
+
"size_bytes": 15728640,
|
| 84 |
+
"purpose": "training",
|
| 85 |
+
"uploaded_at": "2024-01-01T10:00:00Z"
|
| 86 |
+
}
|
| 87 |
+
}
|
| 88 |
+
]
|
| 89 |
+
}
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class FileListResponse(BaseModel):
|
| 94 |
+
"""
|
| 95 |
+
文件列表响应
|
| 96 |
+
"""
|
| 97 |
+
items: List[FileMetadata] = Field(
|
| 98 |
+
default_factory=list,
|
| 99 |
+
description="文件列表"
|
| 100 |
+
)
|
| 101 |
+
total: int = Field(
|
| 102 |
+
default=0,
|
| 103 |
+
ge=0,
|
| 104 |
+
description="总数量"
|
| 105 |
+
)
|
| 106 |
+
limit: int = Field(
|
| 107 |
+
default=50,
|
| 108 |
+
ge=1,
|
| 109 |
+
le=100,
|
| 110 |
+
description="每页数量"
|
| 111 |
+
)
|
| 112 |
+
offset: int = Field(
|
| 113 |
+
default=0,
|
| 114 |
+
ge=0,
|
| 115 |
+
description="偏移量"
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
model_config = {
|
| 119 |
+
"json_schema_extra": {
|
| 120 |
+
"examples": [
|
| 121 |
+
{
|
| 122 |
+
"items": [
|
| 123 |
+
{
|
| 124 |
+
"id": "file-123",
|
| 125 |
+
"filename": "voice_1.wav",
|
| 126 |
+
"content_type": "audio/wav",
|
| 127 |
+
"size_bytes": 5242880,
|
| 128 |
+
"purpose": "training",
|
| 129 |
+
"uploaded_at": "2024-01-01T10:00:00Z"
|
| 130 |
+
}
|
| 131 |
+
],
|
| 132 |
+
"total": 1,
|
| 133 |
+
"limit": 50,
|
| 134 |
+
"offset": 0
|
| 135 |
+
}
|
| 136 |
+
]
|
| 137 |
+
}
|
| 138 |
+
}
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
class FileDeleteResponse(BaseModel):
|
| 142 |
+
"""
|
| 143 |
+
文件删除响应
|
| 144 |
+
"""
|
| 145 |
+
success: bool = Field(default=True, description="是否成功")
|
| 146 |
+
message: str = Field(default="文件删除成功", description="响应消息")
|
| 147 |
+
file_id: str = Field(..., description="已删除的文件 ID")
|
| 148 |
+
|
| 149 |
+
model_config = {
|
| 150 |
+
"json_schema_extra": {
|
| 151 |
+
"examples": [
|
| 152 |
+
{
|
| 153 |
+
"success": True,
|
| 154 |
+
"message": "文件删除成功",
|
| 155 |
+
"file_id": "550e8400-e29b-41d4-a716-446655440000"
|
| 156 |
+
}
|
| 157 |
+
]
|
| 158 |
+
}
|
| 159 |
+
}
|
api_server/app/models/schemas/task.py
ADDED
|
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Quick Mode 任务 Schema
|
| 3 |
+
|
| 4 |
+
小白用户一键训练模式的请求/响应模型
|
| 5 |
+
|
| 6 |
+
参考文档: development.md 4.6.1 + 4.6.3
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from datetime import datetime
|
| 10 |
+
from typing import List, Literal, Optional
|
| 11 |
+
from pydantic import BaseModel, Field
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class QuickModeOptions(BaseModel):
|
| 15 |
+
"""
|
| 16 |
+
Quick Mode 训练选项
|
| 17 |
+
|
| 18 |
+
用于一键训练时的简化参数配置
|
| 19 |
+
|
| 20 |
+
Attributes:
|
| 21 |
+
version: 模型版本
|
| 22 |
+
language: 训练语言
|
| 23 |
+
quality: 训练质量预设
|
| 24 |
+
|
| 25 |
+
质量预设说明:
|
| 26 |
+
- fast: SoVITS 4 epochs, GPT 8 epochs, ~10分钟
|
| 27 |
+
- standard: SoVITS 8 epochs, GPT 15 epochs, ~20分钟
|
| 28 |
+
- high: SoVITS 16 epochs, GPT 30 epochs, ~40分钟
|
| 29 |
+
"""
|
| 30 |
+
version: Literal["v1", "v2", "v2Pro", "v3", "v4"] = Field(
|
| 31 |
+
default="v2",
|
| 32 |
+
description="模型版本"
|
| 33 |
+
)
|
| 34 |
+
language: str = Field(
|
| 35 |
+
default="zh",
|
| 36 |
+
description="训练语言,如 'zh', 'en', 'ja' 等"
|
| 37 |
+
)
|
| 38 |
+
quality: Literal["fast", "standard", "high"] = Field(
|
| 39 |
+
default="standard",
|
| 40 |
+
description="训练质量预设:fast(快速)、standard(标准)、high(高质量)"
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
model_config = {
|
| 44 |
+
"json_schema_extra": {
|
| 45 |
+
"examples": [
|
| 46 |
+
{"version": "v2", "language": "zh", "quality": "standard"}
|
| 47 |
+
]
|
| 48 |
+
}
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class QuickModeRequest(BaseModel):
|
| 53 |
+
"""
|
| 54 |
+
小白用户一键训练请求
|
| 55 |
+
|
| 56 |
+
创建一键训练任务,系统自动配置所有参数并执行完整流程:
|
| 57 |
+
audio_slice -> asr -> text_feature -> hubert_feature -> semantic_token -> sovits_train -> gpt_train
|
| 58 |
+
|
| 59 |
+
Attributes:
|
| 60 |
+
exp_name: 实验名称(用于标识训练任务)
|
| 61 |
+
audio_file_id: 已上传音频文件的 ID
|
| 62 |
+
options: 训练选项
|
| 63 |
+
"""
|
| 64 |
+
exp_name: str = Field(
|
| 65 |
+
...,
|
| 66 |
+
min_length=1,
|
| 67 |
+
max_length=100,
|
| 68 |
+
description="实验名称,用于标识训练任务和生成的模型"
|
| 69 |
+
)
|
| 70 |
+
audio_file_id: str = Field(
|
| 71 |
+
...,
|
| 72 |
+
description="已上传音频文件的 ID"
|
| 73 |
+
)
|
| 74 |
+
options: QuickModeOptions = Field(
|
| 75 |
+
default_factory=QuickModeOptions,
|
| 76 |
+
description="训练选项"
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
model_config = {
|
| 80 |
+
"json_schema_extra": {
|
| 81 |
+
"examples": [
|
| 82 |
+
{
|
| 83 |
+
"exp_name": "my_voice",
|
| 84 |
+
"audio_file_id": "550e8400-e29b-41d4-a716-446655440000",
|
| 85 |
+
"options": {
|
| 86 |
+
"version": "v2",
|
| 87 |
+
"language": "zh",
|
| 88 |
+
"quality": "standard"
|
| 89 |
+
}
|
| 90 |
+
}
|
| 91 |
+
]
|
| 92 |
+
}
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class TaskResponse(BaseModel):
|
| 97 |
+
"""
|
| 98 |
+
任务响应(Quick Mode)
|
| 99 |
+
|
| 100 |
+
返回任务的完整状态信息,包括进度、当前阶段等
|
| 101 |
+
|
| 102 |
+
Attributes:
|
| 103 |
+
id: 任务唯一标识
|
| 104 |
+
exp_name: 实验名称
|
| 105 |
+
status: 任务状态
|
| 106 |
+
current_stage: 当前执行的阶段
|
| 107 |
+
progress: 当前阶段进度 (0.0-1.0)
|
| 108 |
+
overall_progress: 总体进度 (0.0-1.0)
|
| 109 |
+
message: 最新状态消息
|
| 110 |
+
error_message: 错误消息(失败时)
|
| 111 |
+
created_at: 任务创建时间
|
| 112 |
+
started_at: 任务开始执行时间
|
| 113 |
+
completed_at: 任务完成时间
|
| 114 |
+
"""
|
| 115 |
+
id: str = Field(..., description="任务唯一标识")
|
| 116 |
+
exp_name: str = Field(..., description="实验名称")
|
| 117 |
+
status: Literal["queued", "running", "completed", "failed", "cancelled", "interrupted"] = Field(
|
| 118 |
+
...,
|
| 119 |
+
description="任务状态"
|
| 120 |
+
)
|
| 121 |
+
current_stage: Optional[str] = Field(
|
| 122 |
+
default=None,
|
| 123 |
+
description="当前执行的阶段,如 'audio_slice', 'sovits_train' 等"
|
| 124 |
+
)
|
| 125 |
+
progress: float = Field(
|
| 126 |
+
default=0.0,
|
| 127 |
+
ge=0.0,
|
| 128 |
+
le=1.0,
|
| 129 |
+
description="当前阶段进度 (0.0-1.0)"
|
| 130 |
+
)
|
| 131 |
+
overall_progress: float = Field(
|
| 132 |
+
default=0.0,
|
| 133 |
+
ge=0.0,
|
| 134 |
+
le=1.0,
|
| 135 |
+
description="总体进度 (0.0-1.0)"
|
| 136 |
+
)
|
| 137 |
+
message: Optional[str] = Field(
|
| 138 |
+
default=None,
|
| 139 |
+
description="最新状态消息"
|
| 140 |
+
)
|
| 141 |
+
error_message: Optional[str] = Field(
|
| 142 |
+
default=None,
|
| 143 |
+
description="错误消息(失败时)"
|
| 144 |
+
)
|
| 145 |
+
created_at: Optional[datetime] = Field(
|
| 146 |
+
default=None,
|
| 147 |
+
description="任务创建时间"
|
| 148 |
+
)
|
| 149 |
+
started_at: Optional[datetime] = Field(
|
| 150 |
+
default=None,
|
| 151 |
+
description="任务开始执行时间"
|
| 152 |
+
)
|
| 153 |
+
completed_at: Optional[datetime] = Field(
|
| 154 |
+
default=None,
|
| 155 |
+
description="任务完成时间"
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
model_config = {
|
| 159 |
+
"from_attributes": True,
|
| 160 |
+
"json_schema_extra": {
|
| 161 |
+
"examples": [
|
| 162 |
+
{
|
| 163 |
+
"id": "task-550e8400-e29b-41d4-a716-446655440000",
|
| 164 |
+
"exp_name": "my_voice",
|
| 165 |
+
"status": "running",
|
| 166 |
+
"current_stage": "sovits_train",
|
| 167 |
+
"progress": 0.45,
|
| 168 |
+
"overall_progress": 0.72,
|
| 169 |
+
"message": "SoVITS 训练中 Epoch 8/16",
|
| 170 |
+
"error_message": None,
|
| 171 |
+
"created_at": "2024-01-01T10:00:00Z",
|
| 172 |
+
"started_at": "2024-01-01T10:00:05Z",
|
| 173 |
+
"completed_at": None
|
| 174 |
+
}
|
| 175 |
+
]
|
| 176 |
+
}
|
| 177 |
+
}
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
class TaskListResponse(BaseModel):
|
| 181 |
+
"""
|
| 182 |
+
任务列表响应
|
| 183 |
+
|
| 184 |
+
Attributes:
|
| 185 |
+
items: 任务列表
|
| 186 |
+
total: 总数量
|
| 187 |
+
limit: 每页数量
|
| 188 |
+
offset: 偏移量
|
| 189 |
+
"""
|
| 190 |
+
items: List[TaskResponse] = Field(
|
| 191 |
+
default_factory=list,
|
| 192 |
+
description="任务列表"
|
| 193 |
+
)
|
| 194 |
+
total: int = Field(
|
| 195 |
+
default=0,
|
| 196 |
+
ge=0,
|
| 197 |
+
description="总数量"
|
| 198 |
+
)
|
| 199 |
+
limit: int = Field(
|
| 200 |
+
default=50,
|
| 201 |
+
ge=1,
|
| 202 |
+
le=100,
|
| 203 |
+
description="每页数量"
|
| 204 |
+
)
|
| 205 |
+
offset: int = Field(
|
| 206 |
+
default=0,
|
| 207 |
+
ge=0,
|
| 208 |
+
description="偏移量"
|
| 209 |
+
)
|
| 210 |
+
|
| 211 |
+
model_config = {
|
| 212 |
+
"json_schema_extra": {
|
| 213 |
+
"examples": [
|
| 214 |
+
{
|
| 215 |
+
"items": [
|
| 216 |
+
{
|
| 217 |
+
"id": "task-123",
|
| 218 |
+
"exp_name": "voice_1",
|
| 219 |
+
"status": "completed",
|
| 220 |
+
"current_stage": None,
|
| 221 |
+
"progress": 1.0,
|
| 222 |
+
"overall_progress": 1.0,
|
| 223 |
+
"message": "训练完成"
|
| 224 |
+
}
|
| 225 |
+
],
|
| 226 |
+
"total": 1,
|
| 227 |
+
"limit": 50,
|
| 228 |
+
"offset": 0
|
| 229 |
+
}
|
| 230 |
+
]
|
| 231 |
+
}
|
| 232 |
+
}
|
api_server/app/scripts/run_pipeline.py
CHANGED
|
@@ -238,9 +238,23 @@ def build_pipeline(config: Dict[str, Any]):
|
|
| 238 |
}
|
| 239 |
|
| 240 |
# 按顺序添加阶段
|
|
|
|
|
|
|
|
|
|
| 241 |
stages = config.get("stages", [])
|
| 242 |
-
for
|
| 243 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 244 |
if stage_type in stage_builders:
|
| 245 |
stage = stage_builders[stage_type](stage_config)
|
| 246 |
pipeline.add_stage(stage)
|
|
|
|
| 238 |
}
|
| 239 |
|
| 240 |
# 按顺序添加阶段
|
| 241 |
+
# stages 可以是:
|
| 242 |
+
# 1. 字符串列表: ["audio_slice", "asr", ...]
|
| 243 |
+
# 2. 字典列表: [{"type": "audio_slice", "threshold": -30}, ...]
|
| 244 |
stages = config.get("stages", [])
|
| 245 |
+
for stage_item in stages:
|
| 246 |
+
# 判断是字符串还是字典
|
| 247 |
+
if isinstance(stage_item, str):
|
| 248 |
+
stage_type = stage_item
|
| 249 |
+
stage_config = config # 使用全局配置作为阶段配置
|
| 250 |
+
elif isinstance(stage_item, dict):
|
| 251 |
+
stage_type = stage_item.get("type")
|
| 252 |
+
# 合并全局配置和阶段特定配置
|
| 253 |
+
stage_config = {**config, **stage_item}
|
| 254 |
+
else:
|
| 255 |
+
emit_log("warning", f"无效的阶段配置类型: {type(stage_item)}")
|
| 256 |
+
continue
|
| 257 |
+
|
| 258 |
if stage_type in stage_builders:
|
| 259 |
stage = stage_builders[stage_type](stage_config)
|
| 260 |
pipeline.add_stage(stage)
|
api_server/app/services/__init__.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
服务层模块
|
| 3 |
+
|
| 4 |
+
业务逻辑层,封装适配器调用,提供高级业务操作。
|
| 5 |
+
|
| 6 |
+
服务列表:
|
| 7 |
+
- TaskService: Quick Mode 任务服务
|
| 8 |
+
- ExperimentService: Advanced Mode 实验服务
|
| 9 |
+
- FileService: 文件管理服务
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from .task_service import TaskService
|
| 13 |
+
from .experiment_service import ExperimentService
|
| 14 |
+
from .file_service import FileService
|
| 15 |
+
|
| 16 |
+
__all__ = [
|
| 17 |
+
"TaskService",
|
| 18 |
+
"ExperimentService",
|
| 19 |
+
"FileService",
|
| 20 |
+
]
|
api_server/app/services/experiment_service.py
ADDED
|
@@ -0,0 +1,513 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Advanced Mode 实验服务
|
| 3 |
+
|
| 4 |
+
处理专家模式分阶段训练的业务逻辑
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import uuid
|
| 8 |
+
from datetime import datetime
|
| 9 |
+
from typing import AsyncGenerator, Dict, List, Optional, Any
|
| 10 |
+
|
| 11 |
+
from ..core.adapters import (
|
| 12 |
+
get_database_adapter,
|
| 13 |
+
get_task_queue_adapter,
|
| 14 |
+
get_progress_adapter,
|
| 15 |
+
)
|
| 16 |
+
from ..models.schemas.experiment import (
|
| 17 |
+
ExperimentCreate,
|
| 18 |
+
ExperimentUpdate,
|
| 19 |
+
ExperimentResponse,
|
| 20 |
+
ExperimentListResponse,
|
| 21 |
+
StageStatus,
|
| 22 |
+
StageExecuteResponse,
|
| 23 |
+
StagesListResponse,
|
| 24 |
+
STAGE_DEPENDENCIES,
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# 阶段类型列表(按执行顺序)
|
| 29 |
+
STAGE_TYPES = [
|
| 30 |
+
"audio_slice",
|
| 31 |
+
"asr",
|
| 32 |
+
"text_feature",
|
| 33 |
+
"hubert_feature",
|
| 34 |
+
"semantic_token",
|
| 35 |
+
"sovits_train",
|
| 36 |
+
"gpt_train",
|
| 37 |
+
]
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class ExperimentService:
|
| 41 |
+
"""
|
| 42 |
+
Advanced Mode 实验服务
|
| 43 |
+
|
| 44 |
+
提供专家模式的分阶段训练管理:
|
| 45 |
+
- 创建实验
|
| 46 |
+
- 查询实验/阶段状态
|
| 47 |
+
- 执行/取消单个阶段
|
| 48 |
+
- 检查阶段依赖
|
| 49 |
+
|
| 50 |
+
Example:
|
| 51 |
+
>>> service = ExperimentService()
|
| 52 |
+
>>> exp = await service.create_experiment(request)
|
| 53 |
+
>>> await service.execute_stage(exp.id, "audio_slice", {})
|
| 54 |
+
>>> stages = await service.get_all_stages(exp.id)
|
| 55 |
+
"""
|
| 56 |
+
|
| 57 |
+
def __init__(self):
|
| 58 |
+
"""初始化服务"""
|
| 59 |
+
self._db = None
|
| 60 |
+
self._queue = None
|
| 61 |
+
self._progress = None
|
| 62 |
+
|
| 63 |
+
@property
|
| 64 |
+
def db(self):
|
| 65 |
+
"""延迟获取数据库适配器"""
|
| 66 |
+
if self._db is None:
|
| 67 |
+
self._db = get_database_adapter()
|
| 68 |
+
return self._db
|
| 69 |
+
|
| 70 |
+
@property
|
| 71 |
+
def queue(self):
|
| 72 |
+
"""延迟获取任务队列适配器"""
|
| 73 |
+
if self._queue is None:
|
| 74 |
+
self._queue = get_task_queue_adapter()
|
| 75 |
+
return self._queue
|
| 76 |
+
|
| 77 |
+
@property
|
| 78 |
+
def progress_adapter(self):
|
| 79 |
+
"""延迟获取进度适配器"""
|
| 80 |
+
if self._progress is None:
|
| 81 |
+
self._progress = get_progress_adapter()
|
| 82 |
+
return self._progress
|
| 83 |
+
|
| 84 |
+
async def create_experiment(self, request: ExperimentCreate) -> ExperimentResponse:
|
| 85 |
+
"""
|
| 86 |
+
创建实验
|
| 87 |
+
|
| 88 |
+
创建实验但不立即执行,用户可以逐阶段控制训练流程。
|
| 89 |
+
|
| 90 |
+
Args:
|
| 91 |
+
request: 创建实验请求
|
| 92 |
+
|
| 93 |
+
Returns:
|
| 94 |
+
ExperimentResponse
|
| 95 |
+
"""
|
| 96 |
+
exp_id = f"exp-{uuid.uuid4().hex[:8]}"
|
| 97 |
+
|
| 98 |
+
experiment_data = {
|
| 99 |
+
"id": exp_id,
|
| 100 |
+
"exp_name": request.exp_name,
|
| 101 |
+
"version": request.version,
|
| 102 |
+
"gpu_numbers": request.gpu_numbers,
|
| 103 |
+
"is_half": request.is_half,
|
| 104 |
+
"audio_file_id": request.audio_file_id,
|
| 105 |
+
"status": "created",
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
# 创建实验(会自动创建所有阶段)
|
| 109 |
+
experiment = await self.db.create_experiment(experiment_data)
|
| 110 |
+
|
| 111 |
+
return self._experiment_to_response(experiment)
|
| 112 |
+
|
| 113 |
+
async def get_experiment(self, exp_id: str) -> Optional[ExperimentResponse]:
|
| 114 |
+
"""
|
| 115 |
+
获取实验详情
|
| 116 |
+
|
| 117 |
+
Args:
|
| 118 |
+
exp_id: 实验ID
|
| 119 |
+
|
| 120 |
+
Returns:
|
| 121 |
+
ExperimentResponse 或 None
|
| 122 |
+
"""
|
| 123 |
+
experiment = await self.db.get_experiment(exp_id)
|
| 124 |
+
if not experiment:
|
| 125 |
+
return None
|
| 126 |
+
return self._experiment_to_response(experiment)
|
| 127 |
+
|
| 128 |
+
async def list_experiments(
|
| 129 |
+
self,
|
| 130 |
+
status: Optional[str] = None,
|
| 131 |
+
limit: int = 50,
|
| 132 |
+
offset: int = 0
|
| 133 |
+
) -> ExperimentListResponse:
|
| 134 |
+
"""
|
| 135 |
+
获取实验列表
|
| 136 |
+
|
| 137 |
+
Args:
|
| 138 |
+
status: 按状态筛选
|
| 139 |
+
limit: 每页数量
|
| 140 |
+
offset: 偏移量
|
| 141 |
+
|
| 142 |
+
Returns:
|
| 143 |
+
ExperimentListResponse
|
| 144 |
+
"""
|
| 145 |
+
experiments = await self.db.list_experiments(
|
| 146 |
+
status=status, limit=limit, offset=offset
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
# 获取每个实验的完整信息(包含 stages)
|
| 150 |
+
full_experiments = []
|
| 151 |
+
for exp in experiments:
|
| 152 |
+
full_exp = await self.db.get_experiment(exp["id"])
|
| 153 |
+
if full_exp:
|
| 154 |
+
full_experiments.append(full_exp)
|
| 155 |
+
|
| 156 |
+
return ExperimentListResponse(
|
| 157 |
+
items=[self._experiment_to_response(e) for e in full_experiments],
|
| 158 |
+
total=len(experiments), # TODO: 添加 count 方法
|
| 159 |
+
limit=limit,
|
| 160 |
+
offset=offset,
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
async def update_experiment(
|
| 164 |
+
self,
|
| 165 |
+
exp_id: str,
|
| 166 |
+
request: ExperimentUpdate
|
| 167 |
+
) -> Optional[ExperimentResponse]:
|
| 168 |
+
"""
|
| 169 |
+
更新实验基础配置
|
| 170 |
+
|
| 171 |
+
Args:
|
| 172 |
+
exp_id: 实验ID
|
| 173 |
+
request: 更新请求
|
| 174 |
+
|
| 175 |
+
Returns:
|
| 176 |
+
ExperimentResponse 或 None
|
| 177 |
+
"""
|
| 178 |
+
updates = {}
|
| 179 |
+
if request.exp_name is not None:
|
| 180 |
+
updates["exp_name"] = request.exp_name
|
| 181 |
+
if request.gpu_numbers is not None:
|
| 182 |
+
updates["gpu_numbers"] = request.gpu_numbers
|
| 183 |
+
if request.is_half is not None:
|
| 184 |
+
updates["is_half"] = request.is_half
|
| 185 |
+
|
| 186 |
+
if not updates:
|
| 187 |
+
return await self.get_experiment(exp_id)
|
| 188 |
+
|
| 189 |
+
experiment = await self.db.update_experiment(exp_id, updates)
|
| 190 |
+
if not experiment:
|
| 191 |
+
return None
|
| 192 |
+
|
| 193 |
+
return self._experiment_to_response(experiment)
|
| 194 |
+
|
| 195 |
+
async def delete_experiment(self, exp_id: str) -> bool:
|
| 196 |
+
"""
|
| 197 |
+
删除实验
|
| 198 |
+
|
| 199 |
+
Args:
|
| 200 |
+
exp_id: 实验ID
|
| 201 |
+
|
| 202 |
+
Returns:
|
| 203 |
+
是否成功删除
|
| 204 |
+
"""
|
| 205 |
+
# 先取消所有运行中的阶段
|
| 206 |
+
stages = await self.db.get_all_stages(exp_id)
|
| 207 |
+
for stage in stages:
|
| 208 |
+
if stage.get("status") == "running" and stage.get("job_id"):
|
| 209 |
+
await self.queue.cancel(stage["job_id"])
|
| 210 |
+
|
| 211 |
+
return await self.db.delete_experiment(exp_id)
|
| 212 |
+
|
| 213 |
+
async def check_stage_dependencies(
|
| 214 |
+
self,
|
| 215 |
+
exp_id: str,
|
| 216 |
+
stage_type: str
|
| 217 |
+
) -> Dict[str, Any]:
|
| 218 |
+
"""
|
| 219 |
+
检查阶段依赖是否满足
|
| 220 |
+
|
| 221 |
+
Args:
|
| 222 |
+
exp_id: 实验ID
|
| 223 |
+
stage_type: 阶段类型
|
| 224 |
+
|
| 225 |
+
Returns:
|
| 226 |
+
{"satisfied": bool, "missing": List[str]}
|
| 227 |
+
"""
|
| 228 |
+
experiment = await self.db.get_experiment(exp_id)
|
| 229 |
+
if not experiment:
|
| 230 |
+
return {"satisfied": False, "missing": [], "error": "实验不存在"}
|
| 231 |
+
|
| 232 |
+
dependencies = STAGE_DEPENDENCIES.get(stage_type, [])
|
| 233 |
+
stages = experiment.get("stages", {})
|
| 234 |
+
|
| 235 |
+
missing = []
|
| 236 |
+
for dep in dependencies:
|
| 237 |
+
dep_stage = stages.get(dep, {})
|
| 238 |
+
if dep_stage.get("status") != "completed":
|
| 239 |
+
missing.append(dep)
|
| 240 |
+
|
| 241 |
+
return {
|
| 242 |
+
"satisfied": len(missing) == 0,
|
| 243 |
+
"missing": missing,
|
| 244 |
+
}
|
| 245 |
+
|
| 246 |
+
async def execute_stage(
|
| 247 |
+
self,
|
| 248 |
+
exp_id: str,
|
| 249 |
+
stage_type: str,
|
| 250 |
+
params: Dict[str, Any]
|
| 251 |
+
) -> Optional[StageExecuteResponse]:
|
| 252 |
+
"""
|
| 253 |
+
执行指定阶段
|
| 254 |
+
|
| 255 |
+
Args:
|
| 256 |
+
exp_id: 实验ID
|
| 257 |
+
stage_type: 阶段类型
|
| 258 |
+
params: 阶段参数
|
| 259 |
+
|
| 260 |
+
Returns:
|
| 261 |
+
StageExecuteResponse 或 None
|
| 262 |
+
"""
|
| 263 |
+
# 获取实验
|
| 264 |
+
experiment = await self.db.get_experiment(exp_id)
|
| 265 |
+
if not experiment:
|
| 266 |
+
return None
|
| 267 |
+
|
| 268 |
+
stages = experiment.get("stages", {})
|
| 269 |
+
current_stage = stages.get(stage_type, {})
|
| 270 |
+
|
| 271 |
+
# 检查是否是重新执行
|
| 272 |
+
is_rerun = current_stage.get("status") == "completed"
|
| 273 |
+
previous_run = None
|
| 274 |
+
if is_rerun:
|
| 275 |
+
previous_run = {
|
| 276 |
+
"completed_at": current_stage.get("completed_at"),
|
| 277 |
+
"outputs": current_stage.get("outputs"),
|
| 278 |
+
}
|
| 279 |
+
|
| 280 |
+
# 构建阶段配置
|
| 281 |
+
stage_config = {
|
| 282 |
+
"exp_id": exp_id,
|
| 283 |
+
"exp_name": experiment["exp_name"],
|
| 284 |
+
"version": experiment.get("version", "v2"),
|
| 285 |
+
"gpu_numbers": experiment.get("gpu_numbers", "0"),
|
| 286 |
+
"is_half": experiment.get("is_half", True),
|
| 287 |
+
"audio_file_id": experiment.get("audio_file_id"),
|
| 288 |
+
"stage_type": stage_type,
|
| 289 |
+
"params": params,
|
| 290 |
+
# 只执行单个阶段
|
| 291 |
+
"stages": [stage_type],
|
| 292 |
+
}
|
| 293 |
+
|
| 294 |
+
# 生成任务ID(用于进度追踪)
|
| 295 |
+
task_id = f"{exp_id}-{stage_type}-{uuid.uuid4().hex[:4]}"
|
| 296 |
+
|
| 297 |
+
# 入队执行
|
| 298 |
+
job_id = await self.queue.enqueue(task_id, stage_config)
|
| 299 |
+
|
| 300 |
+
# 更新阶段状态
|
| 301 |
+
now = datetime.utcnow()
|
| 302 |
+
await self.db.update_stage(exp_id, stage_type, {
|
| 303 |
+
"status": "running",
|
| 304 |
+
"config": params,
|
| 305 |
+
"job_id": job_id,
|
| 306 |
+
"started_at": now,
|
| 307 |
+
"completed_at": None,
|
| 308 |
+
"error_message": None,
|
| 309 |
+
"outputs": None,
|
| 310 |
+
"progress": 0.0,
|
| 311 |
+
})
|
| 312 |
+
|
| 313 |
+
# 更新实验状态为运行中
|
| 314 |
+
await self.db.update_experiment(exp_id, {"status": "running"})
|
| 315 |
+
|
| 316 |
+
return StageExecuteResponse(
|
| 317 |
+
exp_id=exp_id,
|
| 318 |
+
stage_type=stage_type,
|
| 319 |
+
status="running",
|
| 320 |
+
job_id=job_id,
|
| 321 |
+
config=params,
|
| 322 |
+
rerun=is_rerun,
|
| 323 |
+
previous_run=previous_run,
|
| 324 |
+
started_at=now,
|
| 325 |
+
)
|
| 326 |
+
|
| 327 |
+
async def get_stage(
|
| 328 |
+
self,
|
| 329 |
+
exp_id: str,
|
| 330 |
+
stage_type: str
|
| 331 |
+
) -> Optional[StageStatus]:
|
| 332 |
+
"""
|
| 333 |
+
获取阶段状态
|
| 334 |
+
|
| 335 |
+
Args:
|
| 336 |
+
exp_id: 实验ID
|
| 337 |
+
stage_type: 阶段类型
|
| 338 |
+
|
| 339 |
+
Returns:
|
| 340 |
+
StageStatus 或 None
|
| 341 |
+
"""
|
| 342 |
+
stage = await self.db.get_stage(exp_id, stage_type)
|
| 343 |
+
if not stage:
|
| 344 |
+
return None
|
| 345 |
+
return self._stage_to_status(stage)
|
| 346 |
+
|
| 347 |
+
async def get_all_stages(self, exp_id: str) -> Optional[StagesListResponse]:
|
| 348 |
+
"""
|
| 349 |
+
获取所有阶段状态
|
| 350 |
+
|
| 351 |
+
Args:
|
| 352 |
+
exp_id: 实验ID
|
| 353 |
+
|
| 354 |
+
Returns:
|
| 355 |
+
StagesListResponse 或 None
|
| 356 |
+
"""
|
| 357 |
+
stages = await self.db.get_all_stages(exp_id)
|
| 358 |
+
if not stages:
|
| 359 |
+
# 检查实验是否存在
|
| 360 |
+
experiment = await self.db.get_experiment(exp_id)
|
| 361 |
+
if not experiment:
|
| 362 |
+
return None
|
| 363 |
+
stages = []
|
| 364 |
+
|
| 365 |
+
return StagesListResponse(
|
| 366 |
+
exp_id=exp_id,
|
| 367 |
+
stages=[self._stage_to_status(s) for s in stages],
|
| 368 |
+
)
|
| 369 |
+
|
| 370 |
+
async def cancel_stage(self, exp_id: str, stage_type: str) -> bool:
|
| 371 |
+
"""
|
| 372 |
+
取消正在执行的阶段
|
| 373 |
+
|
| 374 |
+
Args:
|
| 375 |
+
exp_id: 实验ID
|
| 376 |
+
stage_type: 阶段类型
|
| 377 |
+
|
| 378 |
+
Returns:
|
| 379 |
+
是否成功取消
|
| 380 |
+
"""
|
| 381 |
+
stage = await self.db.get_stage(exp_id, stage_type)
|
| 382 |
+
if not stage:
|
| 383 |
+
return False
|
| 384 |
+
|
| 385 |
+
# 只有运行中的阶段可以取消
|
| 386 |
+
if stage.get("status") != "running":
|
| 387 |
+
return False
|
| 388 |
+
|
| 389 |
+
# 取消任务
|
| 390 |
+
job_id = stage.get("job_id")
|
| 391 |
+
if job_id:
|
| 392 |
+
await self.queue.cancel(job_id)
|
| 393 |
+
|
| 394 |
+
# 更新状态
|
| 395 |
+
await self.db.update_stage(exp_id, stage_type, {
|
| 396 |
+
"status": "cancelled",
|
| 397 |
+
"completed_at": datetime.utcnow(),
|
| 398 |
+
"message": "阶段已取消",
|
| 399 |
+
})
|
| 400 |
+
|
| 401 |
+
return True
|
| 402 |
+
|
| 403 |
+
async def subscribe_stage_progress(
|
| 404 |
+
self,
|
| 405 |
+
exp_id: str,
|
| 406 |
+
stage_type: str
|
| 407 |
+
) -> AsyncGenerator[Dict[str, Any], None]:
|
| 408 |
+
"""
|
| 409 |
+
订阅阶段进度(SSE 流)
|
| 410 |
+
|
| 411 |
+
Args:
|
| 412 |
+
exp_id: 实验ID
|
| 413 |
+
stage_type: 阶段类型
|
| 414 |
+
|
| 415 |
+
Yields:
|
| 416 |
+
进度信息字典
|
| 417 |
+
"""
|
| 418 |
+
# 获取阶段信息
|
| 419 |
+
stage = await self.db.get_stage(exp_id, stage_type)
|
| 420 |
+
if not stage:
|
| 421 |
+
yield {"type": "error", "message": "阶段不存在"}
|
| 422 |
+
return
|
| 423 |
+
|
| 424 |
+
# 如果阶段已结束,直接返回最终状态
|
| 425 |
+
if stage.get("status") in ("completed", "failed", "cancelled"):
|
| 426 |
+
yield {
|
| 427 |
+
"type": "final",
|
| 428 |
+
"status": stage.get("status"),
|
| 429 |
+
"message": stage.get("message") or stage.get("error_message"),
|
| 430 |
+
"progress": stage.get("progress", 0.0),
|
| 431 |
+
"outputs": stage.get("outputs"),
|
| 432 |
+
}
|
| 433 |
+
return
|
| 434 |
+
|
| 435 |
+
# 如果阶段未开始
|
| 436 |
+
if stage.get("status") == "pending":
|
| 437 |
+
yield {"type": "info", "message": "阶段尚未开始"}
|
| 438 |
+
return
|
| 439 |
+
|
| 440 |
+
# 使用任务ID订阅进度
|
| 441 |
+
# 任务ID格式: {exp_id}-{stage_type}-{random}
|
| 442 |
+
# 由于我们不知道确切的任务ID,使用 job_id
|
| 443 |
+
job_id = stage.get("job_id")
|
| 444 |
+
if not job_id:
|
| 445 |
+
yield {"type": "error", "message": "无法获取任务ID"}
|
| 446 |
+
return
|
| 447 |
+
|
| 448 |
+
# 订阅进度
|
| 449 |
+
# 注意:这里需要根据实际的进度适配器实现来调整
|
| 450 |
+
# 当前使用 task_id 格式为 "{exp_id}-{stage_type}"
|
| 451 |
+
task_id = f"{exp_id}-{stage_type}"
|
| 452 |
+
|
| 453 |
+
async for progress in self.progress_adapter.subscribe(task_id):
|
| 454 |
+
yield progress
|
| 455 |
+
|
| 456 |
+
# 检查是否为终态
|
| 457 |
+
if progress.get("status") in ("completed", "failed", "cancelled"):
|
| 458 |
+
break
|
| 459 |
+
|
| 460 |
+
def _experiment_to_response(self, experiment: Dict[str, Any]) -> ExperimentResponse:
|
| 461 |
+
"""将实验数据转换为响应模型"""
|
| 462 |
+
stages_data = experiment.get("stages", {})
|
| 463 |
+
stages = {}
|
| 464 |
+
|
| 465 |
+
for stage_type, stage_info in stages_data.items():
|
| 466 |
+
stages[stage_type] = self._stage_to_status(stage_info)
|
| 467 |
+
|
| 468 |
+
# 解析日期时间
|
| 469 |
+
created_at = experiment.get("created_at")
|
| 470 |
+
if isinstance(created_at, str):
|
| 471 |
+
created_at = datetime.fromisoformat(created_at)
|
| 472 |
+
elif created_at is None:
|
| 473 |
+
created_at = datetime.utcnow()
|
| 474 |
+
|
| 475 |
+
updated_at = experiment.get("updated_at")
|
| 476 |
+
if isinstance(updated_at, str):
|
| 477 |
+
updated_at = datetime.fromisoformat(updated_at)
|
| 478 |
+
|
| 479 |
+
return ExperimentResponse(
|
| 480 |
+
id=experiment["id"],
|
| 481 |
+
exp_name=experiment["exp_name"],
|
| 482 |
+
version=experiment.get("version", "v2"),
|
| 483 |
+
status=experiment.get("status", "created"),
|
| 484 |
+
gpu_numbers=experiment.get("gpu_numbers", "0"),
|
| 485 |
+
is_half=experiment.get("is_half", True),
|
| 486 |
+
audio_file_id=experiment.get("audio_file_id", ""),
|
| 487 |
+
stages=stages,
|
| 488 |
+
created_at=created_at,
|
| 489 |
+
updated_at=updated_at,
|
| 490 |
+
)
|
| 491 |
+
|
| 492 |
+
def _stage_to_status(self, stage: Dict[str, Any]) -> StageStatus:
|
| 493 |
+
"""将阶段数据转换为状态模型"""
|
| 494 |
+
# 解析日期时间
|
| 495 |
+
started_at = stage.get("started_at")
|
| 496 |
+
if isinstance(started_at, str):
|
| 497 |
+
started_at = datetime.fromisoformat(started_at)
|
| 498 |
+
|
| 499 |
+
completed_at = stage.get("completed_at")
|
| 500 |
+
if isinstance(completed_at, str):
|
| 501 |
+
completed_at = datetime.fromisoformat(completed_at)
|
| 502 |
+
|
| 503 |
+
return StageStatus(
|
| 504 |
+
stage_type=stage.get("stage_type", ""),
|
| 505 |
+
status=stage.get("status", "pending"),
|
| 506 |
+
progress=stage.get("progress"),
|
| 507 |
+
message=stage.get("message"),
|
| 508 |
+
started_at=started_at,
|
| 509 |
+
completed_at=completed_at,
|
| 510 |
+
config=stage.get("config"),
|
| 511 |
+
outputs=stage.get("outputs"),
|
| 512 |
+
error_message=stage.get("error_message"),
|
| 513 |
+
)
|
api_server/app/services/file_service.py
ADDED
|
@@ -0,0 +1,277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
文件管理服务
|
| 3 |
+
|
| 4 |
+
处理文件上传、下载和管理的业务逻辑
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from datetime import datetime
|
| 8 |
+
from typing import List, Optional, Tuple
|
| 9 |
+
|
| 10 |
+
from ..core.adapters import get_database_adapter, get_storage_adapter
|
| 11 |
+
from ..models.schemas.file import (
|
| 12 |
+
FileMetadata,
|
| 13 |
+
FileUploadResponse,
|
| 14 |
+
FileListResponse,
|
| 15 |
+
FileDeleteResponse,
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class FileService:
|
| 20 |
+
"""
|
| 21 |
+
文件管理服务
|
| 22 |
+
|
| 23 |
+
提供文件的完整生命周期管理:
|
| 24 |
+
- 上传文件
|
| 25 |
+
- 下载文件
|
| 26 |
+
- 获取元数据
|
| 27 |
+
- 列出文件
|
| 28 |
+
- 删除文件
|
| 29 |
+
|
| 30 |
+
Example:
|
| 31 |
+
>>> service = FileService()
|
| 32 |
+
>>> result = await service.upload_file(data, "audio.wav", "audio/wav", "training")
|
| 33 |
+
>>> content = await service.download_file(result.file.id)
|
| 34 |
+
>>> await service.delete_file(result.file.id)
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
def __init__(self):
|
| 38 |
+
"""初始化服务"""
|
| 39 |
+
self._db = None
|
| 40 |
+
self._storage = None
|
| 41 |
+
|
| 42 |
+
@property
|
| 43 |
+
def db(self):
|
| 44 |
+
"""延迟获取数据库适配器"""
|
| 45 |
+
if self._db is None:
|
| 46 |
+
self._db = get_database_adapter()
|
| 47 |
+
return self._db
|
| 48 |
+
|
| 49 |
+
@property
|
| 50 |
+
def storage(self):
|
| 51 |
+
"""延迟获取存储适配器"""
|
| 52 |
+
if self._storage is None:
|
| 53 |
+
self._storage = get_storage_adapter()
|
| 54 |
+
return self._storage
|
| 55 |
+
|
| 56 |
+
async def upload_file(
|
| 57 |
+
self,
|
| 58 |
+
file_data: bytes,
|
| 59 |
+
filename: str,
|
| 60 |
+
content_type: Optional[str] = None,
|
| 61 |
+
purpose: str = "training"
|
| 62 |
+
) -> FileUploadResponse:
|
| 63 |
+
"""
|
| 64 |
+
上传文件
|
| 65 |
+
|
| 66 |
+
Args:
|
| 67 |
+
file_data: 文件二进制数据
|
| 68 |
+
filename: 原始文件名
|
| 69 |
+
content_type: MIME 类型
|
| 70 |
+
purpose: 文件用途 (training, reference, output)
|
| 71 |
+
|
| 72 |
+
Returns:
|
| 73 |
+
FileUploadResponse
|
| 74 |
+
"""
|
| 75 |
+
# 构建元数据
|
| 76 |
+
metadata = {
|
| 77 |
+
"content_type": content_type,
|
| 78 |
+
"purpose": purpose,
|
| 79 |
+
"size_bytes": len(file_data),
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
# 上传到存储
|
| 83 |
+
file_id = await self.storage.upload_file(file_data, filename, metadata)
|
| 84 |
+
|
| 85 |
+
# 获取完整元数据(包含音频信息)
|
| 86 |
+
full_metadata = await self.storage.get_file_metadata(file_id)
|
| 87 |
+
|
| 88 |
+
# 保存到数据库
|
| 89 |
+
file_record = {
|
| 90 |
+
"id": file_id,
|
| 91 |
+
"filename": filename,
|
| 92 |
+
"content_type": content_type,
|
| 93 |
+
"size_bytes": len(file_data),
|
| 94 |
+
"purpose": purpose,
|
| 95 |
+
"duration_seconds": full_metadata.get("duration_seconds") if full_metadata else None,
|
| 96 |
+
"sample_rate": full_metadata.get("sample_rate") if full_metadata else None,
|
| 97 |
+
"uploaded_at": datetime.utcnow().isoformat(),
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
await self.db.create_file_record(file_record)
|
| 101 |
+
|
| 102 |
+
# 构建响应
|
| 103 |
+
file_metadata = FileMetadata(
|
| 104 |
+
id=file_id,
|
| 105 |
+
filename=filename,
|
| 106 |
+
content_type=content_type,
|
| 107 |
+
size_bytes=len(file_data),
|
| 108 |
+
purpose=purpose,
|
| 109 |
+
duration_seconds=file_record.get("duration_seconds"),
|
| 110 |
+
sample_rate=file_record.get("sample_rate"),
|
| 111 |
+
uploaded_at=datetime.utcnow(),
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
return FileUploadResponse(
|
| 115 |
+
success=True,
|
| 116 |
+
message="文件上传成功",
|
| 117 |
+
file=file_metadata,
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
async def download_file(self, file_id: str) -> Optional[Tuple[bytes, str, str]]:
|
| 121 |
+
"""
|
| 122 |
+
下载文件
|
| 123 |
+
|
| 124 |
+
Args:
|
| 125 |
+
file_id: 文件ID
|
| 126 |
+
|
| 127 |
+
Returns:
|
| 128 |
+
(文件数据, 文件名, 内容类型) 或 None
|
| 129 |
+
"""
|
| 130 |
+
# 检查文件是否存在
|
| 131 |
+
if not await self.storage.file_exists(file_id):
|
| 132 |
+
return None
|
| 133 |
+
|
| 134 |
+
# 获取元数据
|
| 135 |
+
metadata = await self.storage.get_file_metadata(file_id)
|
| 136 |
+
if not metadata:
|
| 137 |
+
return None
|
| 138 |
+
|
| 139 |
+
# 下载文件
|
| 140 |
+
file_data = await self.storage.download_file(file_id)
|
| 141 |
+
|
| 142 |
+
return (
|
| 143 |
+
file_data,
|
| 144 |
+
metadata.get("filename", "file"),
|
| 145 |
+
metadata.get("content_type", "application/octet-stream"),
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
async def get_file(self, file_id: str) -> Optional[FileMetadata]:
|
| 149 |
+
"""
|
| 150 |
+
获取文件元数据
|
| 151 |
+
|
| 152 |
+
Args:
|
| 153 |
+
file_id: 文件ID
|
| 154 |
+
|
| 155 |
+
Returns:
|
| 156 |
+
FileMetadata 或 None
|
| 157 |
+
"""
|
| 158 |
+
# 从数据库获取
|
| 159 |
+
record = await self.db.get_file_record(file_id)
|
| 160 |
+
if record:
|
| 161 |
+
return self._record_to_metadata(record)
|
| 162 |
+
|
| 163 |
+
# 尝试从存储获取
|
| 164 |
+
metadata = await self.storage.get_file_metadata(file_id)
|
| 165 |
+
if metadata:
|
| 166 |
+
return self._storage_metadata_to_file_metadata(metadata)
|
| 167 |
+
|
| 168 |
+
return None
|
| 169 |
+
|
| 170 |
+
async def list_files(
|
| 171 |
+
self,
|
| 172 |
+
purpose: Optional[str] = None,
|
| 173 |
+
limit: int = 50,
|
| 174 |
+
offset: int = 0
|
| 175 |
+
) -> FileListResponse:
|
| 176 |
+
"""
|
| 177 |
+
获取文件列表
|
| 178 |
+
|
| 179 |
+
Args:
|
| 180 |
+
purpose: 按用途筛选
|
| 181 |
+
limit: 每页数量
|
| 182 |
+
offset: 偏移量
|
| 183 |
+
|
| 184 |
+
Returns:
|
| 185 |
+
FileListResponse
|
| 186 |
+
"""
|
| 187 |
+
# 从数据库获取
|
| 188 |
+
records = await self.db.list_file_records(
|
| 189 |
+
purpose=purpose, limit=limit, offset=offset
|
| 190 |
+
)
|
| 191 |
+
total = await self.db.count_file_records(purpose=purpose)
|
| 192 |
+
|
| 193 |
+
return FileListResponse(
|
| 194 |
+
items=[self._record_to_metadata(r) for r in records],
|
| 195 |
+
total=total,
|
| 196 |
+
limit=limit,
|
| 197 |
+
offset=offset,
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
async def delete_file(self, file_id: str) -> FileDeleteResponse:
|
| 201 |
+
"""
|
| 202 |
+
删除文件
|
| 203 |
+
|
| 204 |
+
Args:
|
| 205 |
+
file_id: 文件ID
|
| 206 |
+
|
| 207 |
+
Returns:
|
| 208 |
+
FileDeleteResponse
|
| 209 |
+
"""
|
| 210 |
+
# 从存储删除
|
| 211 |
+
storage_deleted = await self.storage.delete_file(file_id)
|
| 212 |
+
|
| 213 |
+
# 从数据库删除
|
| 214 |
+
db_deleted = await self.db.delete_file_record(file_id)
|
| 215 |
+
|
| 216 |
+
if storage_deleted or db_deleted:
|
| 217 |
+
return FileDeleteResponse(
|
| 218 |
+
success=True,
|
| 219 |
+
message="文件删除成功",
|
| 220 |
+
file_id=file_id,
|
| 221 |
+
)
|
| 222 |
+
else:
|
| 223 |
+
return FileDeleteResponse(
|
| 224 |
+
success=False,
|
| 225 |
+
message="文件不存在或已删除",
|
| 226 |
+
file_id=file_id,
|
| 227 |
+
)
|
| 228 |
+
|
| 229 |
+
async def file_exists(self, file_id: str) -> bool:
|
| 230 |
+
"""
|
| 231 |
+
检查文件是否存在
|
| 232 |
+
|
| 233 |
+
Args:
|
| 234 |
+
file_id: 文件ID
|
| 235 |
+
|
| 236 |
+
Returns:
|
| 237 |
+
是否存在
|
| 238 |
+
"""
|
| 239 |
+
return await self.storage.file_exists(file_id)
|
| 240 |
+
|
| 241 |
+
def _record_to_metadata(self, record: dict) -> FileMetadata:
|
| 242 |
+
"""将数据库记录转换为 FileMetadata"""
|
| 243 |
+
uploaded_at = record.get("uploaded_at")
|
| 244 |
+
if isinstance(uploaded_at, str):
|
| 245 |
+
uploaded_at = datetime.fromisoformat(uploaded_at)
|
| 246 |
+
elif uploaded_at is None:
|
| 247 |
+
uploaded_at = datetime.utcnow()
|
| 248 |
+
|
| 249 |
+
return FileMetadata(
|
| 250 |
+
id=record["id"],
|
| 251 |
+
filename=record["filename"],
|
| 252 |
+
content_type=record.get("content_type"),
|
| 253 |
+
size_bytes=record.get("size_bytes", 0),
|
| 254 |
+
purpose=record.get("purpose", "training"),
|
| 255 |
+
duration_seconds=record.get("duration_seconds"),
|
| 256 |
+
sample_rate=record.get("sample_rate"),
|
| 257 |
+
uploaded_at=uploaded_at,
|
| 258 |
+
)
|
| 259 |
+
|
| 260 |
+
def _storage_metadata_to_file_metadata(self, metadata: dict) -> FileMetadata:
|
| 261 |
+
"""将存储元数据转换为 FileMetadata"""
|
| 262 |
+
uploaded_at = metadata.get("uploaded_at")
|
| 263 |
+
if isinstance(uploaded_at, str):
|
| 264 |
+
uploaded_at = datetime.fromisoformat(uploaded_at)
|
| 265 |
+
elif uploaded_at is None:
|
| 266 |
+
uploaded_at = datetime.utcnow()
|
| 267 |
+
|
| 268 |
+
return FileMetadata(
|
| 269 |
+
id=metadata.get("id", ""),
|
| 270 |
+
filename=metadata.get("filename", ""),
|
| 271 |
+
content_type=metadata.get("content_type"),
|
| 272 |
+
size_bytes=metadata.get("size_bytes", 0),
|
| 273 |
+
purpose=metadata.get("purpose", "training"),
|
| 274 |
+
duration_seconds=metadata.get("duration_seconds"),
|
| 275 |
+
sample_rate=metadata.get("sample_rate"),
|
| 276 |
+
uploaded_at=uploaded_at,
|
| 277 |
+
)
|
api_server/app/services/task_service.py
ADDED
|
@@ -0,0 +1,322 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Quick Mode 任务服务
|
| 3 |
+
|
| 4 |
+
处理一键训练任务的业务逻辑
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import uuid
|
| 8 |
+
from datetime import datetime
|
| 9 |
+
from typing import AsyncGenerator, Dict, List, Optional, Any
|
| 10 |
+
|
| 11 |
+
from ..core.adapters import get_database_adapter, get_task_queue_adapter, get_storage_adapter
|
| 12 |
+
from ..core.config import settings
|
| 13 |
+
from ..models.domain import Task, TaskStatus
|
| 14 |
+
from ..models.schemas.task import (
|
| 15 |
+
QuickModeRequest,
|
| 16 |
+
TaskResponse,
|
| 17 |
+
TaskListResponse,
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# 质量预设配置
|
| 22 |
+
QUALITY_PRESETS = {
|
| 23 |
+
"fast": {
|
| 24 |
+
"sovits_epochs": 4,
|
| 25 |
+
"gpt_epochs": 8,
|
| 26 |
+
"description": "快速训练,约10分钟",
|
| 27 |
+
},
|
| 28 |
+
"standard": {
|
| 29 |
+
"sovits_epochs": 8,
|
| 30 |
+
"gpt_epochs": 15,
|
| 31 |
+
"description": "标准训练,约20分钟",
|
| 32 |
+
},
|
| 33 |
+
"high": {
|
| 34 |
+
"sovits_epochs": 16,
|
| 35 |
+
"gpt_epochs": 30,
|
| 36 |
+
"description": "高质量训练,约40分钟",
|
| 37 |
+
},
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class TaskService:
|
| 42 |
+
"""
|
| 43 |
+
Quick Mode 任务服务
|
| 44 |
+
|
| 45 |
+
提供一键训练任务的完整生命周期管理:
|
| 46 |
+
- 创建任务
|
| 47 |
+
- 查询任务状态
|
| 48 |
+
- 取消任务
|
| 49 |
+
- 订阅进度更新
|
| 50 |
+
|
| 51 |
+
Example:
|
| 52 |
+
>>> service = TaskService()
|
| 53 |
+
>>> task = await service.create_quick_task(request)
|
| 54 |
+
>>> status = await service.get_task(task.id)
|
| 55 |
+
>>> await service.cancel_task(task.id)
|
| 56 |
+
"""
|
| 57 |
+
|
| 58 |
+
def __init__(self):
|
| 59 |
+
"""初始化服务"""
|
| 60 |
+
self._db = None
|
| 61 |
+
self._queue = None
|
| 62 |
+
self._storage = None
|
| 63 |
+
|
| 64 |
+
@property
|
| 65 |
+
def db(self):
|
| 66 |
+
"""延迟获取数据库适配器"""
|
| 67 |
+
if self._db is None:
|
| 68 |
+
self._db = get_database_adapter()
|
| 69 |
+
return self._db
|
| 70 |
+
|
| 71 |
+
@property
|
| 72 |
+
def queue(self):
|
| 73 |
+
"""延迟获取任务队列适配器"""
|
| 74 |
+
if self._queue is None:
|
| 75 |
+
self._queue = get_task_queue_adapter()
|
| 76 |
+
return self._queue
|
| 77 |
+
|
| 78 |
+
@property
|
| 79 |
+
def storage(self):
|
| 80 |
+
"""延迟获取存储适配器"""
|
| 81 |
+
if self._storage is None:
|
| 82 |
+
self._storage = get_storage_adapter()
|
| 83 |
+
return self._storage
|
| 84 |
+
|
| 85 |
+
async def check_exp_name_exists(self, exp_name: str) -> bool:
|
| 86 |
+
"""
|
| 87 |
+
检查实验名称是否已存在
|
| 88 |
+
|
| 89 |
+
Args:
|
| 90 |
+
exp_name: 实验名称
|
| 91 |
+
|
| 92 |
+
Returns:
|
| 93 |
+
如果存在返回 True,否则返回 False
|
| 94 |
+
"""
|
| 95 |
+
existing_task = await self.db.get_task_by_exp_name(exp_name)
|
| 96 |
+
return existing_task is not None
|
| 97 |
+
|
| 98 |
+
async def validate_audio_file(self, audio_file_id: str) -> tuple[bool, str]:
|
| 99 |
+
"""
|
| 100 |
+
验证音频文件是否存在
|
| 101 |
+
|
| 102 |
+
Args:
|
| 103 |
+
audio_file_id: 音频文件 ID 或路径
|
| 104 |
+
|
| 105 |
+
Returns:
|
| 106 |
+
(是否存在, 实际文件路径)
|
| 107 |
+
"""
|
| 108 |
+
import os
|
| 109 |
+
|
| 110 |
+
# 尝试获取文件元数据
|
| 111 |
+
file_metadata = await self.storage.get_file_metadata(audio_file_id)
|
| 112 |
+
|
| 113 |
+
if file_metadata:
|
| 114 |
+
# 文件存储在 storage.base_path / file_id
|
| 115 |
+
audio_file_path = str(self.storage.base_path / audio_file_id)
|
| 116 |
+
exists = os.path.exists(audio_file_path)
|
| 117 |
+
return exists, audio_file_path
|
| 118 |
+
else:
|
| 119 |
+
# 如果找不到元数据,将 audio_file_id 当作路径
|
| 120 |
+
exists = os.path.exists(audio_file_id)
|
| 121 |
+
return exists, audio_file_id
|
| 122 |
+
|
| 123 |
+
async def create_quick_task(self, request: QuickModeRequest) -> TaskResponse:
|
| 124 |
+
"""
|
| 125 |
+
创建一键训练任务
|
| 126 |
+
|
| 127 |
+
根据请求参数和质量预设,自动配置训练参数并创建任务。
|
| 128 |
+
|
| 129 |
+
Args:
|
| 130 |
+
request: 快速模式请求
|
| 131 |
+
|
| 132 |
+
Returns:
|
| 133 |
+
TaskResponse: 任务响应
|
| 134 |
+
"""
|
| 135 |
+
# 生成任务ID
|
| 136 |
+
task_id = f"task-{uuid.uuid4().hex[:12]}"
|
| 137 |
+
|
| 138 |
+
# 获取质量预设
|
| 139 |
+
quality = request.options.quality
|
| 140 |
+
preset = QUALITY_PRESETS.get(quality, QUALITY_PRESETS["standard"])
|
| 141 |
+
|
| 142 |
+
# 验证并解析音频文件路径
|
| 143 |
+
audio_file_id = request.audio_file_id
|
| 144 |
+
_, audio_file_path = await self.validate_audio_file(audio_file_id)
|
| 145 |
+
|
| 146 |
+
# 构建任务配置
|
| 147 |
+
config = {
|
| 148 |
+
"exp_name": request.exp_name,
|
| 149 |
+
"audio_file_id": audio_file_id,
|
| 150 |
+
"input_path": audio_file_path, # 音频文件的实际路径
|
| 151 |
+
"version": request.options.version,
|
| 152 |
+
"language": request.options.language,
|
| 153 |
+
"quality": quality,
|
| 154 |
+
# 训练参数
|
| 155 |
+
"total_epoch": preset["sovits_epochs"], # SoVITS epoch
|
| 156 |
+
"sovits_epochs": preset["sovits_epochs"],
|
| 157 |
+
"gpt_epochs": preset["gpt_epochs"],
|
| 158 |
+
# 预训练模型路径
|
| 159 |
+
"bert_pretrained_dir": str(settings.BERT_PRETRAINED_DIR),
|
| 160 |
+
"ssl_pretrained_dir": str(settings.SSL_PRETRAINED_DIR),
|
| 161 |
+
"pretrained_s2G": str(settings.PRETRAINED_S2G),
|
| 162 |
+
"pretrained_s2D": str(settings.PRETRAINED_S2D),
|
| 163 |
+
"pretrained_s1": str(settings.PRETRAINED_S1),
|
| 164 |
+
# 执行完整流程
|
| 165 |
+
"stages": [
|
| 166 |
+
"audio_slice",
|
| 167 |
+
"asr",
|
| 168 |
+
"text_feature",
|
| 169 |
+
"hubert_feature",
|
| 170 |
+
"semantic_token",
|
| 171 |
+
"sovits_train",
|
| 172 |
+
"gpt_train",
|
| 173 |
+
],
|
| 174 |
+
}
|
| 175 |
+
|
| 176 |
+
# 创建 Task 领域模型
|
| 177 |
+
task = Task(
|
| 178 |
+
id=task_id,
|
| 179 |
+
exp_name=request.exp_name,
|
| 180 |
+
config=config,
|
| 181 |
+
status=TaskStatus.QUEUED,
|
| 182 |
+
created_at=datetime.utcnow(),
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
# 保存到数据库
|
| 186 |
+
await self.db.create_task(task)
|
| 187 |
+
|
| 188 |
+
# 入队执行
|
| 189 |
+
job_id = await self.queue.enqueue(task_id, config)
|
| 190 |
+
|
| 191 |
+
# 更新 job_id
|
| 192 |
+
await self.db.update_task(task_id, {"job_id": job_id})
|
| 193 |
+
task.job_id = job_id
|
| 194 |
+
|
| 195 |
+
return self._task_to_response(task)
|
| 196 |
+
|
| 197 |
+
async def get_task(self, task_id: str) -> Optional[TaskResponse]:
|
| 198 |
+
"""
|
| 199 |
+
获取任务详情
|
| 200 |
+
|
| 201 |
+
Args:
|
| 202 |
+
task_id: 任务ID
|
| 203 |
+
|
| 204 |
+
Returns:
|
| 205 |
+
TaskResponse 或 None(不存在时)
|
| 206 |
+
"""
|
| 207 |
+
task = await self.db.get_task(task_id)
|
| 208 |
+
if not task:
|
| 209 |
+
return None
|
| 210 |
+
return self._task_to_response(task)
|
| 211 |
+
|
| 212 |
+
async def list_tasks(
|
| 213 |
+
self,
|
| 214 |
+
status: Optional[str] = None,
|
| 215 |
+
limit: int = 50,
|
| 216 |
+
offset: int = 0
|
| 217 |
+
) -> TaskListResponse:
|
| 218 |
+
"""
|
| 219 |
+
获取任务列表
|
| 220 |
+
|
| 221 |
+
Args:
|
| 222 |
+
status: 按状态筛选
|
| 223 |
+
limit: 每页数量
|
| 224 |
+
offset: 偏移量
|
| 225 |
+
|
| 226 |
+
Returns:
|
| 227 |
+
TaskListResponse
|
| 228 |
+
"""
|
| 229 |
+
tasks = await self.db.list_tasks(status=status, limit=limit, offset=offset)
|
| 230 |
+
total = await self.db.count_tasks(status=status)
|
| 231 |
+
|
| 232 |
+
return TaskListResponse(
|
| 233 |
+
items=[self._task_to_response(t) for t in tasks],
|
| 234 |
+
total=total,
|
| 235 |
+
limit=limit,
|
| 236 |
+
offset=offset,
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
async def cancel_task(self, task_id: str) -> bool:
|
| 240 |
+
"""
|
| 241 |
+
取消任务
|
| 242 |
+
|
| 243 |
+
Args:
|
| 244 |
+
task_id: 任务ID
|
| 245 |
+
|
| 246 |
+
Returns:
|
| 247 |
+
是否成功取消
|
| 248 |
+
"""
|
| 249 |
+
# 获取任务
|
| 250 |
+
task = await self.db.get_task(task_id)
|
| 251 |
+
if not task:
|
| 252 |
+
return False
|
| 253 |
+
|
| 254 |
+
# 只有排队中或运行中的任务可以取消
|
| 255 |
+
if task.status not in (TaskStatus.QUEUED, TaskStatus.RUNNING):
|
| 256 |
+
return False
|
| 257 |
+
|
| 258 |
+
# 如果有 job_id,尝试取消队列任务
|
| 259 |
+
if task.job_id:
|
| 260 |
+
await self.queue.cancel(task.job_id)
|
| 261 |
+
|
| 262 |
+
# 更新状态
|
| 263 |
+
await self.db.update_task(task_id, {
|
| 264 |
+
"status": TaskStatus.CANCELLED,
|
| 265 |
+
"completed_at": datetime.utcnow(),
|
| 266 |
+
"message": "任务已取消",
|
| 267 |
+
})
|
| 268 |
+
|
| 269 |
+
return True
|
| 270 |
+
|
| 271 |
+
async def subscribe_progress(
|
| 272 |
+
self,
|
| 273 |
+
task_id: str
|
| 274 |
+
) -> AsyncGenerator[Dict[str, Any], None]:
|
| 275 |
+
"""
|
| 276 |
+
订阅任务进度(SSE 流)
|
| 277 |
+
|
| 278 |
+
Args:
|
| 279 |
+
task_id: 任务ID
|
| 280 |
+
|
| 281 |
+
Yields:
|
| 282 |
+
进度信息字典
|
| 283 |
+
"""
|
| 284 |
+
# 检查任务是否存在
|
| 285 |
+
task = await self.db.get_task(task_id)
|
| 286 |
+
if not task:
|
| 287 |
+
yield {"type": "error", "message": "任务不存在"}
|
| 288 |
+
return
|
| 289 |
+
|
| 290 |
+
# 如果任务已结束,直接返回最终状态
|
| 291 |
+
if task.status in (TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED):
|
| 292 |
+
yield {
|
| 293 |
+
"type": "final",
|
| 294 |
+
"status": task.status.value,
|
| 295 |
+
"message": task.message or task.error_message,
|
| 296 |
+
"progress": task.progress,
|
| 297 |
+
}
|
| 298 |
+
return
|
| 299 |
+
|
| 300 |
+
# 订阅进度更新
|
| 301 |
+
async for progress in self.queue.subscribe_progress(task_id):
|
| 302 |
+
yield progress
|
| 303 |
+
|
| 304 |
+
# 检查是否为终态
|
| 305 |
+
if progress.get("status") in ("completed", "failed", "cancelled"):
|
| 306 |
+
break
|
| 307 |
+
|
| 308 |
+
def _task_to_response(self, task: Task) -> TaskResponse:
|
| 309 |
+
"""将 Task 领域模型转换为 TaskResponse"""
|
| 310 |
+
return TaskResponse(
|
| 311 |
+
id=task.id,
|
| 312 |
+
exp_name=task.exp_name,
|
| 313 |
+
status=task.status.value if isinstance(task.status, TaskStatus) else task.status,
|
| 314 |
+
current_stage=task.current_stage,
|
| 315 |
+
progress=task.stage_progress,
|
| 316 |
+
overall_progress=task.progress,
|
| 317 |
+
message=task.message,
|
| 318 |
+
error_message=task.error_message,
|
| 319 |
+
created_at=task.created_at,
|
| 320 |
+
started_at=task.started_at,
|
| 321 |
+
completed_at=task.completed_at,
|
| 322 |
+
)
|