liumaolin committed on
Commit
e054d0c
·
1 Parent(s): e43edbb

feat(api): implement local training MVP with adapter pattern

Browse files

- Add adapter base classes (TaskQueue, Progress, Storage, Database)
- Implement local adapters (AsyncTrainingManager, SQLite, LocalStorage)
- Add Pydantic schemas for tasks, experiments, files, and stages
- Implement service layer (TaskService, ExperimentService, FileService)
- Add API endpoints for Quick Mode (/tasks) and Advanced Mode (/experiments)
- Add adapter factory with dependency injection support
- Update architecture design document with implementation status

api_server/app/adapters/base.py CHANGED
@@ -5,7 +5,10 @@
5
  """
6
 
7
  from abc import ABC, abstractmethod
8
- from typing import Dict, Optional, AsyncGenerator
 
 
 
9
 
10
 
11
  class TaskQueueAdapter(ABC):
@@ -138,3 +141,453 @@ class ProgressAdapter(ABC):
138
  进度信息字典
139
  """
140
  pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  """
6
 
7
  from abc import ABC, abstractmethod
8
+ from typing import TYPE_CHECKING, Dict, List, Optional, AsyncGenerator, Any
9
+
10
+ if TYPE_CHECKING:
11
+ from ..models.domain import Task
12
 
13
 
14
  class TaskQueueAdapter(ABC):
 
141
  进度信息字典
142
  """
143
  pass
144
+
145
+
146
class StorageAdapter(ABC):
    """
    Abstract base class for file-storage adapters.

    Declares the common file-storage interface. It is meant to be
    implemented both by a local-filesystem backend and by an object-storage
    backend (S3/MinIO).

    Example:
        >>> adapter = LocalStorageAdapter(base_path="./data/files")
        >>> file_id = await adapter.upload_file(data, "audio.wav", {"purpose": "training"})
        >>> content = await adapter.download_file(file_id)
        >>> await adapter.delete_file(file_id)
    """

    @abstractmethod
    async def upload_file(
        self,
        file_data: bytes,
        filename: str,
        metadata: Dict[str, Any],
    ) -> str:
        """
        Upload a file.

        Args:
            file_data: Raw binary content of the file.
            filename: Original file name.
            metadata: File metadata. May contain:
                - content_type: MIME type
                - purpose: what the file is for (training, reference, output)
                - any other custom fields

        Returns:
            The unique identifier (file_id) of the stored file.

        Raises:
            IOError: If the file could not be stored.
        """

    @abstractmethod
    async def download_file(self, file_id: str) -> bytes:
        """
        Download a file.

        Args:
            file_id: Unique identifier of the file.

        Returns:
            Raw binary content of the file.

        Raises:
            FileNotFoundError: If the file does not exist.
        """

    @abstractmethod
    async def delete_file(self, file_id: str) -> bool:
        """
        Delete a file.

        Args:
            file_id: Unique identifier of the file.

        Returns:
            Whether the file was deleted successfully.
        """

    @abstractmethod
    async def get_file_metadata(self, file_id: str) -> Optional[Dict[str, Any]]:
        """
        Fetch the metadata of a file.

        Args:
            file_id: Unique identifier of the file.

        Returns:
            A metadata dictionary containing:
                - id: file ID
                - filename: original file name
                - content_type: MIME type
                - size_bytes: file size
                - purpose: what the file is for
                - uploaded_at: upload time
                - audio files additionally carry: duration_seconds, sample_rate

            ``None`` when the file does not exist.
        """

    @abstractmethod
    async def list_files(
        self,
        purpose: Optional[str] = None,
        limit: int = 50,
        offset: int = 0,
    ) -> List[Dict[str, Any]]:
        """
        List files.

        Args:
            purpose: Filter by purpose (training, reference, output).
            limit: Maximum number of entries to return.
            offset: Number of entries to skip.

        Returns:
            A list of file-metadata dictionaries.
        """

    @abstractmethod
    async def file_exists(self, file_id: str) -> bool:
        """
        Check whether a file exists.

        Args:
            file_id: Unique identifier of the file.

        Returns:
            Whether the file exists.
        """
269
+
270
+
271
class DatabaseAdapter(ABC):
    """
    Abstract base class for database adapters.

    Declares the common persistence interface. It is meant to be implemented
    both by a SQLite backend and by a PostgreSQL backend.

    Entities managed:
        - Task: Quick Mode one-click training task
        - Experiment: Advanced Mode experiment
        - Stage: an individual stage of an experiment
        - File: record of an uploaded file (optional, used together with
          StorageAdapter)

    Example:
        >>> adapter = SQLiteAdapter(db_path="./data/app.db")
        >>> task = await adapter.create_task(task_data)
        >>> task = await adapter.get_task(task_id)
        >>> await adapter.update_task(task_id, {"status": "completed"})
    """

    # ============================================================
    # Task CRUD (Quick Mode)
    # ============================================================

    @abstractmethod
    async def create_task(self, task: "Task") -> "Task":
        """
        Create a task.

        Args:
            task: Task domain-model instance.

        Returns:
            The created Task instance (including generated fields such as
            created_at).
        """

    @abstractmethod
    async def get_task(self, task_id: str) -> Optional["Task"]:
        """
        Fetch a task.

        Args:
            task_id: Unique identifier of the task.

        Returns:
            The Task instance, or ``None`` if it does not exist.
        """

    @abstractmethod
    async def update_task(
        self,
        task_id: str,
        updates: Dict[str, Any],
    ) -> Optional["Task"]:
        """
        Update a task.

        Args:
            task_id: Unique identifier of the task.
            updates: Mapping of field names to their new values.

        Returns:
            The updated Task instance, or ``None`` if it does not exist.
        """

    @abstractmethod
    async def list_tasks(
        self,
        status: Optional[str] = None,
        limit: int = 50,
        offset: int = 0,
    ) -> List["Task"]:
        """
        Query the task list.

        Args:
            status: Filter by status.
            limit: Maximum number of entries to return.
            offset: Number of entries to skip.

        Returns:
            A list of Task instances.
        """

    @abstractmethod
    async def delete_task(self, task_id: str) -> bool:
        """
        Delete a task.

        Args:
            task_id: Unique identifier of the task.

        Returns:
            Whether the task was deleted successfully.
        """

    @abstractmethod
    async def count_tasks(self, status: Optional[str] = None) -> int:
        """
        Count tasks.

        Args:
            status: Filter by status.

        Returns:
            The number of tasks.
        """

    @abstractmethod
    async def get_task_by_exp_name(self, exp_name: str) -> Optional["Task"]:
        """
        Fetch a task by its experiment name.

        Used to check whether an exp_name is already taken.

        Args:
            exp_name: Experiment name.

        Returns:
            The Task instance, or ``None`` if it does not exist.
        """

    # ============================================================
    # Experiment CRUD (Advanced Mode)
    # ============================================================

    @abstractmethod
    async def create_experiment(self, experiment: Dict[str, Any]) -> Dict[str, Any]:
        """
        Create an experiment.

        Args:
            experiment: Experiment data dictionary.

        Returns:
            The created experiment data.
        """

    @abstractmethod
    async def get_experiment(self, exp_id: str) -> Optional[Dict[str, Any]]:
        """
        Fetch an experiment.

        Args:
            exp_id: Unique identifier of the experiment.

        Returns:
            The experiment data dictionary, or ``None`` if it does not exist.
        """

    @abstractmethod
    async def update_experiment(
        self,
        exp_id: str,
        updates: Dict[str, Any],
    ) -> Optional[Dict[str, Any]]:
        """
        Update an experiment.

        Args:
            exp_id: Unique identifier of the experiment.
            updates: Mapping of field names to their new values.

        Returns:
            The updated experiment data, or ``None`` if it does not exist.
        """

    @abstractmethod
    async def list_experiments(
        self,
        status: Optional[str] = None,
        limit: int = 50,
        offset: int = 0,
    ) -> List[Dict[str, Any]]:
        """
        Query the experiment list.

        Args:
            status: Filter by status.
            limit: Maximum number of entries to return.
            offset: Number of entries to skip.

        Returns:
            A list of experiment data dictionaries.
        """

    @abstractmethod
    async def delete_experiment(self, exp_id: str) -> bool:
        """
        Delete an experiment.

        Args:
            exp_id: Unique identifier of the experiment.

        Returns:
            Whether the experiment was deleted successfully.
        """

    # ============================================================
    # Stage operations (Advanced Mode)
    # ============================================================

    @abstractmethod
    async def update_stage(
        self,
        exp_id: str,
        stage_type: str,
        updates: Dict[str, Any],
    ) -> Optional[Dict[str, Any]]:
        """
        Update the state of a stage.

        Args:
            exp_id: Unique identifier of the experiment.
            stage_type: Stage type.
            updates: Mapping of field names to their new values.

        Returns:
            The updated stage data, or ``None`` if it does not exist.
        """

    @abstractmethod
    async def get_stage(
        self,
        exp_id: str,
        stage_type: str,
    ) -> Optional[Dict[str, Any]]:
        """
        Fetch the state of a stage.

        Args:
            exp_id: Unique identifier of the experiment.
            stage_type: Stage type.

        Returns:
            The stage data dictionary, or ``None`` if it does not exist.
        """

    @abstractmethod
    async def get_all_stages(self, exp_id: str) -> List[Dict[str, Any]]:
        """
        Fetch the state of every stage of an experiment.

        Args:
            exp_id: Unique identifier of the experiment.

        Returns:
            A list of stage data dictionaries.
        """

    # ============================================================
    # File records (optional, used together with StorageAdapter)
    # ============================================================

    @abstractmethod
    async def create_file_record(self, file_data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Create a file record.

        Args:
            file_data: File metadata.

        Returns:
            The created file record.
        """

    @abstractmethod
    async def get_file_record(self, file_id: str) -> Optional[Dict[str, Any]]:
        """
        Fetch a file record.

        Args:
            file_id: Unique identifier of the file.

        Returns:
            The file record, or ``None`` if it does not exist.
        """

    @abstractmethod
    async def delete_file_record(self, file_id: str) -> bool:
        """
        Delete a file record.

        Args:
            file_id: Unique identifier of the file.

        Returns:
            Whether the record was deleted successfully.
        """

    @abstractmethod
    async def list_file_records(
        self,
        purpose: Optional[str] = None,
        limit: int = 50,
        offset: int = 0,
    ) -> List[Dict[str, Any]]:
        """
        Query the list of file records.

        Args:
            purpose: Filter by purpose.
            limit: Maximum number of entries to return.
            offset: Number of entries to skip.

        Returns:
            A list of file records.
        """
api_server/app/adapters/local/__init__.py CHANGED
@@ -1,9 +1,23 @@
1
  """
2
  本地适配器模块
3
 
4
- 提供基于 SQLite 和 asyncio.subprocess 的本地实现
 
 
 
 
 
 
5
  """
6
 
7
  from .task_queue import AsyncTrainingManager
 
 
 
8
 
9
- __all__ = ["AsyncTrainingManager"]
 
 
 
 
 
 
1
  """
2
  本地适配器模块
3
 
4
+ 提供基于 SQLite 和 asyncio.subprocess 的本地实现。
5
+
6
+ 适配器列表:
7
+ - AsyncTrainingManager: 任务队列适配器(基于 asyncio.subprocess)
8
+ - LocalStorageAdapter: 文件存储适配器(基于本地文件系统)
9
+ - SQLiteAdapter: 数据库适配器(基于 SQLite)
10
+ - LocalProgressAdapter: 进度管理适配器(基于内存队列)
11
  """
12
 
13
  from .task_queue import AsyncTrainingManager
14
+ from .storage import LocalStorageAdapter
15
+ from .database import SQLiteAdapter
16
+ from .progress import LocalProgressAdapter
17
 
18
+ __all__ = [
19
+ "AsyncTrainingManager",
20
+ "LocalStorageAdapter",
21
+ "SQLiteAdapter",
22
+ "LocalProgressAdapter",
23
+ ]
api_server/app/adapters/local/database.py ADDED
@@ -0,0 +1,683 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SQLite 数据库适配器
3
+
4
+ 基于 SQLite + aiosqlite 实现的数据库适配器,适用于 macOS 本地训练场景。
5
+ """
6
+
7
+ import json
8
+ import sqlite3
9
+ import uuid
10
+ from datetime import datetime
11
+ from pathlib import Path
12
+ from typing import Any, Dict, List, Optional
13
+
14
+ import aiosqlite
15
+
16
+ from ..base import DatabaseAdapter
17
+ from ...core.config import settings
18
+ from ...models.domain import Task, TaskStatus
19
+
20
+
21
# List of stage types.
STAGE_TYPES = [
    "audio_slice", "asr",
    "text_feature", "hubert_feature", "semantic_token",
    "sovits_train", "gpt_train",
]
31
+
32
+
33
class SQLiteAdapter(DatabaseAdapter):
    """
    SQLite database adapter.

    Implements ``DatabaseAdapter`` on top of SQLite. The schema is created
    synchronously with ``sqlite3`` in the constructor; every CRUD method
    opens its own ``aiosqlite`` connection for the duration of the call.

    Tables:
        - tasks: Quick Mode tasks
        - experiments: Advanced Mode experiments
        - stages: per-experiment stage state
        - files: file records

    Example:
        >>> adapter = SQLiteAdapter()
        >>> task = Task(id="task-123", exp_name="my_voice", config={})
        >>> await adapter.create_task(task)
        >>> task = await adapter.get_task("task-123")
    """

    def __init__(self, db_path: Optional[str] = None):
        """
        Initialise the adapter and make sure the schema exists.

        Args:
            db_path: Path of the database file. Falls back to
                ``settings.SQLITE_PATH`` when omitted or empty.
        """
        if db_path:
            self.db_path = db_path
        else:
            self.db_path = str(settings.SQLITE_PATH)

        # Make sure the containing directory exists.
        Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)

        # Create the tables synchronously.
        self._init_db_sync()

    # ============================================================
    # Internal helpers
    # ============================================================

    @staticmethod
    def _quote_identifier(name: Any) -> str:
        """Return *name* as a double-quoted SQLite identifier.

        Embedded double quotes are doubled, so a dictionary key can never
        terminate the identifier and smuggle extra SQL into the statement.
        """
        return '"' + str(name).replace('"', '""') + '"'

    @classmethod
    def _set_clause(cls, columns) -> str:
        """Build the ``col = ?, col = ?`` part of an UPDATE statement."""
        return ", ".join(f"{cls._quote_identifier(col)} = ?" for col in columns)

    @staticmethod
    def _decode_stage_json(stage: Dict[str, Any]) -> Dict[str, Any]:
        """Decode the JSON text columns of a stage row in place.

        ``config`` and ``outputs`` are stored as JSON text; a value that
        cannot be decoded is replaced by ``None``.
        """
        for json_field in ("config", "outputs"):
            raw = stage.get(json_field)
            if raw and isinstance(raw, str):
                try:
                    stage[json_field] = json.loads(raw)
                except json.JSONDecodeError:
                    stage[json_field] = None
        return stage

    def _init_db_sync(self) -> None:
        """Create all tables and indexes synchronously if they are missing."""
        conn = sqlite3.connect(self.db_path)
        try:
            # ``with conn`` only scopes the transaction (commit/rollback);
            # it does not close the connection, hence the explicit close().
            with conn:
                # Tasks table (Quick Mode)
                conn.execute('''
                    CREATE TABLE IF NOT EXISTS tasks (
                        id TEXT PRIMARY KEY,
                        job_id TEXT,
                        exp_name TEXT NOT NULL,
                        status TEXT NOT NULL DEFAULT 'queued',
                        config TEXT,
                        current_stage TEXT,
                        progress REAL DEFAULT 0,
                        stage_progress REAL DEFAULT 0,
                        message TEXT,
                        error_message TEXT,
                        created_at TEXT NOT NULL,
                        started_at TEXT,
                        completed_at TEXT
                    )
                ''')
                conn.execute('CREATE INDEX IF NOT EXISTS idx_tasks_status ON tasks(status)')
                conn.execute('CREATE INDEX IF NOT EXISTS idx_tasks_created ON tasks(created_at)')

                # Experiments table (Advanced Mode)
                conn.execute('''
                    CREATE TABLE IF NOT EXISTS experiments (
                        id TEXT PRIMARY KEY,
                        exp_name TEXT NOT NULL,
                        version TEXT NOT NULL DEFAULT 'v2',
                        exp_root TEXT DEFAULT 'logs',
                        gpu_numbers TEXT DEFAULT '0',
                        is_half INTEGER DEFAULT 1,
                        audio_file_id TEXT,
                        status TEXT NOT NULL DEFAULT 'created',
                        created_at TEXT NOT NULL,
                        updated_at TEXT
                    )
                ''')
                conn.execute('CREATE INDEX IF NOT EXISTS idx_experiments_status ON experiments(status)')
                conn.execute('CREATE INDEX IF NOT EXISTS idx_experiments_created ON experiments(created_at)')

                # Stages table (Advanced Mode stage state)
                conn.execute('''
                    CREATE TABLE IF NOT EXISTS stages (
                        id TEXT PRIMARY KEY,
                        experiment_id TEXT NOT NULL,
                        stage_type TEXT NOT NULL,
                        status TEXT DEFAULT 'pending',
                        progress REAL DEFAULT 0,
                        message TEXT,
                        job_id TEXT,
                        config TEXT,
                        outputs TEXT,
                        started_at TEXT,
                        completed_at TEXT,
                        error_message TEXT,
                        FOREIGN KEY (experiment_id) REFERENCES experiments(id) ON DELETE CASCADE,
                        UNIQUE (experiment_id, stage_type)
                    )
                ''')
                conn.execute('CREATE INDEX IF NOT EXISTS idx_stages_experiment ON stages(experiment_id)')
                conn.execute('CREATE INDEX IF NOT EXISTS idx_stages_status ON stages(status)')

                # Files table (file records)
                conn.execute('''
                    CREATE TABLE IF NOT EXISTS files (
                        id TEXT PRIMARY KEY,
                        filename TEXT NOT NULL,
                        content_type TEXT,
                        size_bytes INTEGER DEFAULT 0,
                        purpose TEXT DEFAULT 'training',
                        duration_seconds REAL,
                        sample_rate INTEGER,
                        storage_path TEXT,
                        uploaded_at TEXT NOT NULL
                    )
                ''')
                conn.execute('CREATE INDEX IF NOT EXISTS idx_files_purpose ON files(purpose)')
                conn.execute('CREATE INDEX IF NOT EXISTS idx_files_uploaded ON files(uploaded_at)')
        finally:
            conn.close()

    # ============================================================
    # Task CRUD (Quick Mode)
    # ============================================================

    async def create_task(self, task: Task) -> Task:
        """
        Insert a task row.

        ``config`` is stored as JSON text and datetimes as ISO-8601 strings.
        When ``task.created_at`` is unset the current UTC time is stored.

        Args:
            task: Task to persist.

        Returns:
            The same ``task`` object that was passed in.
        """
        async with aiosqlite.connect(self.db_path) as db:
            await db.execute(
                '''INSERT INTO tasks
                (id, job_id, exp_name, status, config, current_stage,
                progress, stage_progress, message, error_message,
                created_at, started_at, completed_at)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)''',
                (
                    task.id,
                    task.job_id,
                    task.exp_name,
                    task.status.value if isinstance(task.status, TaskStatus) else task.status,
                    json.dumps(task.config, ensure_ascii=False) if task.config else None,
                    task.current_stage,
                    task.progress,
                    task.stage_progress,
                    task.message,
                    task.error_message,
                    task.created_at.isoformat() if task.created_at else datetime.utcnow().isoformat(),
                    task.started_at.isoformat() if task.started_at else None,
                    task.completed_at.isoformat() if task.completed_at else None,
                )
            )
            await db.commit()

        return task

    async def get_task(self, task_id: str) -> Optional[Task]:
        """
        Fetch a task by id.

        Args:
            task_id: Unique identifier of the task.

        Returns:
            The Task, or ``None`` if no row matches.
        """
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row
            async with db.execute(
                "SELECT * FROM tasks WHERE id = ?", (task_id,)
            ) as cursor:
                row = await cursor.fetchone()
                if row:
                    return self._row_to_task(dict(row))
                return None

    async def update_task(self, task_id: str, updates: Dict[str, Any]) -> Optional[Task]:
        """
        Update columns of a task.

        ``TaskStatus`` values, ``config`` dicts and datetime values of the
        timestamp columns are converted to their stored text form first.

        Args:
            task_id: Unique identifier of the task.
            updates: Mapping of column names to new values. An empty mapping
                performs no write.

        Returns:
            The task as stored after the update, or ``None`` if it does not
            exist.
        """
        if not updates:
            return await self.get_task(task_id)

        # Convert values that need a text representation.
        processed: Dict[str, Any] = {}
        for key, value in updates.items():
            if key == "status" and isinstance(value, TaskStatus):
                processed[key] = value.value
            elif key == "config" and isinstance(value, dict):
                processed[key] = json.dumps(value, ensure_ascii=False)
            elif key in ("created_at", "started_at", "completed_at") and isinstance(value, datetime):
                processed[key] = value.isoformat()
            else:
                processed[key] = value

        async with aiosqlite.connect(self.db_path) as db:
            # Column names come from dict keys, so they are quoted rather
            # than interpolated raw.
            await db.execute(
                f"UPDATE tasks SET {self._set_clause(processed)} WHERE id = ?",
                [*processed.values(), task_id]
            )
            await db.commit()

        return await self.get_task(task_id)

    async def list_tasks(
        self,
        status: Optional[str] = None,
        limit: int = 50,
        offset: int = 0
    ) -> List[Task]:
        """
        List tasks, newest first.

        Args:
            status: Only return tasks with this status (no filter if falsy).
            limit: Maximum number of tasks to return.
            offset: Number of tasks to skip.

        Returns:
            A list of Task instances.
        """
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row

            if status:
                query = """
                    SELECT * FROM tasks
                    WHERE status = ?
                    ORDER BY created_at DESC
                    LIMIT ? OFFSET ?
                """
                params = (status, limit, offset)
            else:
                query = """
                    SELECT * FROM tasks
                    ORDER BY created_at DESC
                    LIMIT ? OFFSET ?
                """
                params = (limit, offset)

            async with db.execute(query, params) as cursor:
                rows = await cursor.fetchall()
                return [self._row_to_task(dict(row)) for row in rows]

    async def delete_task(self, task_id: str) -> bool:
        """
        Delete a task.

        Args:
            task_id: Unique identifier of the task.

        Returns:
            ``True`` if a row was deleted.
        """
        async with aiosqlite.connect(self.db_path) as db:
            cursor = await db.execute(
                "DELETE FROM tasks WHERE id = ?", (task_id,)
            )
            await db.commit()
            return cursor.rowcount > 0

    async def count_tasks(self, status: Optional[str] = None) -> int:
        """
        Count tasks.

        Args:
            status: Only count tasks with this status (no filter if falsy).

        Returns:
            The number of matching tasks.
        """
        async with aiosqlite.connect(self.db_path) as db:
            if status:
                async with db.execute(
                    "SELECT COUNT(*) FROM tasks WHERE status = ?", (status,)
                ) as cursor:
                    row = await cursor.fetchone()
            else:
                async with db.execute("SELECT COUNT(*) FROM tasks") as cursor:
                    row = await cursor.fetchone()

            return row[0] if row else 0

    async def get_task_by_exp_name(self, exp_name: str) -> Optional[Task]:
        """
        Fetch one task with the given experiment name.

        Args:
            exp_name: Experiment name to look up.

        Returns:
            A matching Task, or ``None`` if there is none.
        """
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row
            async with db.execute(
                "SELECT * FROM tasks WHERE exp_name = ? LIMIT 1", (exp_name,)
            ) as cursor:
                row = await cursor.fetchone()
                if row:
                    return self._row_to_task(dict(row))
                return None

    def _row_to_task(self, row: Dict[str, Any]) -> Task:
        """Convert a ``tasks`` row (as a dict) into a Task object.

        ``config`` is decoded from JSON text; undecodable or missing config
        becomes an empty dict.
        """
        config = row.get("config")
        if config and isinstance(config, str):
            try:
                config = json.loads(config)
            except json.JSONDecodeError:
                config = {}

        return Task.from_dict({
            "id": row["id"],
            "job_id": row.get("job_id"),
            "exp_name": row["exp_name"],
            "status": row.get("status", "queued"),
            "config": config or {},
            "current_stage": row.get("current_stage"),
            "progress": row.get("progress", 0.0),
            "stage_progress": row.get("stage_progress", 0.0),
            "message": row.get("message"),
            "error_message": row.get("error_message"),
            "created_at": row.get("created_at"),
            "started_at": row.get("started_at"),
            "completed_at": row.get("completed_at"),
        })

    # ============================================================
    # Experiment CRUD (Advanced Mode)
    # ============================================================

    async def create_experiment(self, experiment: Dict[str, Any]) -> Dict[str, Any]:
        """
        Create an experiment together with a pending row for every stage.

        Args:
            experiment: Experiment data. ``exp_name`` is required; ``id`` is
                generated as ``exp-<8 hex chars>`` when absent.

        Returns:
            The stored experiment, as returned by ``get_experiment``.
        """
        exp_id = experiment.get("id") or f"exp-{uuid.uuid4().hex[:8]}"
        now = datetime.utcnow().isoformat()

        async with aiosqlite.connect(self.db_path) as db:
            # Insert the experiment row.
            await db.execute(
                '''INSERT INTO experiments
                (id, exp_name, version, exp_root, gpu_numbers, is_half,
                audio_file_id, status, created_at, updated_at)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)''',
                (
                    exp_id,
                    experiment["exp_name"],
                    experiment.get("version", "v2"),
                    experiment.get("exp_root", "logs"),
                    experiment.get("gpu_numbers", "0"),
                    1 if experiment.get("is_half", True) else 0,
                    experiment.get("audio_file_id"),
                    experiment.get("status", "created"),
                    now,
                    now,
                )
            )

            # Seed the initial state of every stage.
            for stage_type in STAGE_TYPES:
                stage_id = f"{exp_id}-{stage_type}"
                await db.execute(
                    '''INSERT INTO stages
                    (id, experiment_id, stage_type, status)
                    VALUES (?, ?, ?, 'pending')''',
                    (stage_id, exp_id, stage_type)
                )

            await db.commit()

        return await self.get_experiment(exp_id)

    async def get_experiment(self, exp_id: str) -> Optional[Dict[str, Any]]:
        """
        Fetch an experiment including its stages.

        Args:
            exp_id: Unique identifier of the experiment.

        Returns:
            The experiment row as a dict with ``is_half`` converted to bool
            and a ``stages`` dict keyed by stage type, or ``None`` if the
            experiment does not exist.
        """
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row

            # Basic experiment information.
            async with db.execute(
                "SELECT * FROM experiments WHERE id = ?", (exp_id,)
            ) as cursor:
                row = await cursor.fetchone()
                if not row:
                    return None

                experiment = dict(row)
                experiment["is_half"] = bool(experiment.get("is_half", 1))

            # State of every stage.
            stages = {}
            async with db.execute(
                "SELECT * FROM stages WHERE experiment_id = ?", (exp_id,)
            ) as cursor:
                for stage_row in await cursor.fetchall():
                    stage = self._decode_stage_json(dict(stage_row))
                    stages[stage["stage_type"]] = stage

            experiment["stages"] = stages
            return experiment

    async def update_experiment(
        self,
        exp_id: str,
        updates: Dict[str, Any]
    ) -> Optional[Dict[str, Any]]:
        """
        Update columns of an experiment.

        ``updated_at`` is set to the current UTC time unless supplied. A
        ``stages`` key is ignored here (stages are updated separately).

        Args:
            exp_id: Unique identifier of the experiment.
            updates: Mapping of column names to new values. An empty mapping
                performs no write.

        Returns:
            The experiment as stored after the update, or ``None`` if it
            does not exist.
        """
        if not updates:
            return await self.get_experiment(exp_id)

        processed: Dict[str, Any] = {}
        for key, value in updates.items():
            if key == "is_half":
                # Booleans are stored as 0/1.
                processed[key] = 1 if value else 0
            elif key == "updated_at" and isinstance(value, datetime):
                processed[key] = value.isoformat()
            elif key != "stages":  # stages are handled separately
                processed[key] = value

        # Stamp the modification time.
        if "updated_at" not in processed:
            processed["updated_at"] = datetime.utcnow().isoformat()

        async with aiosqlite.connect(self.db_path) as db:
            await db.execute(
                f"UPDATE experiments SET {self._set_clause(processed)} WHERE id = ?",
                [*processed.values(), exp_id]
            )
            await db.commit()

        return await self.get_experiment(exp_id)

    async def list_experiments(
        self,
        status: Optional[str] = None,
        limit: int = 50,
        offset: int = 0
    ) -> List[Dict[str, Any]]:
        """
        List experiments, newest first, without their stages.

        Args:
            status: Only return experiments with this status (no filter if
                falsy).
            limit: Maximum number of experiments to return.
            offset: Number of experiments to skip.

        Returns:
            A list of experiment dicts with ``is_half`` converted to bool.
        """
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row

            if status:
                query = """
                    SELECT * FROM experiments
                    WHERE status = ?
                    ORDER BY created_at DESC
                    LIMIT ? OFFSET ?
                """
                params = (status, limit, offset)
            else:
                query = """
                    SELECT * FROM experiments
                    ORDER BY created_at DESC
                    LIMIT ? OFFSET ?
                """
                params = (limit, offset)

            async with db.execute(query, params) as cursor:
                rows = await cursor.fetchall()

            results = []
            for row in rows:
                exp = dict(row)
                exp["is_half"] = bool(exp.get("is_half", 1))
                # Listing is kept light: full stages are not included.
                results.append(exp)

            return results

    async def delete_experiment(self, exp_id: str) -> bool:
        """
        Delete an experiment and its stages.

        Stages are deleted explicitly before the experiment row.

        Args:
            exp_id: Unique identifier of the experiment.

        Returns:
            ``True`` if an experiment row was deleted.
        """
        async with aiosqlite.connect(self.db_path) as db:
            await db.execute(
                "DELETE FROM stages WHERE experiment_id = ?", (exp_id,)
            )
            cursor = await db.execute(
                "DELETE FROM experiments WHERE id = ?", (exp_id,)
            )
            await db.commit()
            return cursor.rowcount > 0

    # ============================================================
    # Stage operations (Advanced Mode)
    # ============================================================

    async def update_stage(
        self,
        exp_id: str,
        stage_type: str,
        updates: Dict[str, Any]
    ) -> Optional[Dict[str, Any]]:
        """
        Update the state of one stage and touch the parent experiment.

        ``config``/``outputs`` dicts are stored as JSON text and datetime
        values of ``started_at``/``completed_at`` as ISO-8601 strings.

        Args:
            exp_id: Unique identifier of the experiment.
            stage_type: Stage type.
            updates: Mapping of column names to new values. An empty mapping
                performs no write.

        Returns:
            The stage as stored after the update, or ``None`` if it does
            not exist.
        """
        if not updates:
            return await self.get_stage(exp_id, stage_type)

        processed: Dict[str, Any] = {}
        for key, value in updates.items():
            if key in ("config", "outputs") and isinstance(value, dict):
                processed[key] = json.dumps(value, ensure_ascii=False)
            elif key in ("started_at", "completed_at") and isinstance(value, datetime):
                processed[key] = value.isoformat()
            else:
                processed[key] = value

        async with aiosqlite.connect(self.db_path) as db:
            await db.execute(
                f"UPDATE stages SET {self._set_clause(processed)} "
                "WHERE experiment_id = ? AND stage_type = ?",
                [*processed.values(), exp_id, stage_type]
            )
            # Refresh the experiment's updated_at in the same transaction.
            # Calling update_experiment(exp_id, {}) would not do this: an
            # empty update returns early without writing anything.
            await db.execute(
                "UPDATE experiments SET updated_at = ? WHERE id = ?",
                (datetime.utcnow().isoformat(), exp_id)
            )
            await db.commit()

        return await self.get_stage(exp_id, stage_type)

    async def get_stage(
        self,
        exp_id: str,
        stage_type: str
    ) -> Optional[Dict[str, Any]]:
        """
        Fetch the state of one stage.

        Args:
            exp_id: Unique identifier of the experiment.
            stage_type: Stage type.

        Returns:
            The stage row as a dict with JSON columns decoded, or ``None``
            if it does not exist.
        """
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row

            async with db.execute(
                "SELECT * FROM stages WHERE experiment_id = ? AND stage_type = ?",
                (exp_id, stage_type)
            ) as cursor:
                row = await cursor.fetchone()
                if not row:
                    return None
                return self._decode_stage_json(dict(row))

    async def get_all_stages(self, exp_id: str) -> List[Dict[str, Any]]:
        """
        Fetch the state of every stage of an experiment.

        Rows are ordered by the stage ``id`` column.

        Args:
            exp_id: Unique identifier of the experiment.

        Returns:
            A list of stage dicts with JSON columns decoded.
        """
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row

            async with db.execute(
                "SELECT * FROM stages WHERE experiment_id = ? ORDER BY id",
                (exp_id,)
            ) as cursor:
                rows = await cursor.fetchall()

            return [self._decode_stage_json(dict(row)) for row in rows]

    # ============================================================
    # File records
    # ============================================================

    async def create_file_record(self, file_data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Create a file record.

        Args:
            file_data: File metadata. ``filename`` is required; ``id`` is
                generated as a UUID4 string when absent and ``uploaded_at``
                defaults to the current UTC time.

        Returns:
            The stored record, as returned by ``get_file_record``.
        """
        file_id = file_data.get("id") or str(uuid.uuid4())
        now = datetime.utcnow().isoformat()

        async with aiosqlite.connect(self.db_path) as db:
            await db.execute(
                '''INSERT INTO files
                (id, filename, content_type, size_bytes, purpose,
                duration_seconds, sample_rate, storage_path, uploaded_at)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)''',
                (
                    file_id,
                    file_data["filename"],
                    file_data.get("content_type"),
                    file_data.get("size_bytes", 0),
                    file_data.get("purpose", "training"),
                    file_data.get("duration_seconds"),
                    file_data.get("sample_rate"),
                    file_data.get("storage_path"),
                    file_data.get("uploaded_at", now),
                )
            )
            await db.commit()

        return await self.get_file_record(file_id)

    async def get_file_record(self, file_id: str) -> Optional[Dict[str, Any]]:
        """
        Fetch a file record.

        Args:
            file_id: Unique identifier of the file.

        Returns:
            The record as a dict, or ``None`` if it does not exist.
        """
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row

            async with db.execute(
                "SELECT * FROM files WHERE id = ?", (file_id,)
            ) as cursor:
                row = await cursor.fetchone()
                if row:
                    return dict(row)
                return None

    async def delete_file_record(self, file_id: str) -> bool:
        """
        Delete a file record.

        Args:
            file_id: Unique identifier of the file.

        Returns:
            ``True`` if a row was deleted.
        """
        async with aiosqlite.connect(self.db_path) as db:
            cursor = await db.execute(
                "DELETE FROM files WHERE id = ?", (file_id,)
            )
            await db.commit()
            return cursor.rowcount > 0

    async def list_file_records(
        self,
        purpose: Optional[str] = None,
        limit: int = 50,
        offset: int = 0
    ) -> List[Dict[str, Any]]:
        """
        List file records, most recently uploaded first.

        Args:
            purpose: Only return records with this purpose (no filter if
                falsy).
            limit: Maximum number of records to return.
            offset: Number of records to skip.

        Returns:
            A list of file-record dicts.
        """
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row

            if purpose:
                query = """
                    SELECT * FROM files
                    WHERE purpose = ?
                    ORDER BY uploaded_at DESC
                    LIMIT ? OFFSET ?
                """
                params = (purpose, limit, offset)
            else:
                query = """
                    SELECT * FROM files
                    ORDER BY uploaded_at DESC
                    LIMIT ? OFFSET ?
                """
                params = (limit, offset)

            async with db.execute(query, params) as cursor:
                rows = await cursor.fetchall()
                return [dict(row) for row in rows]

    async def count_file_records(self, purpose: Optional[str] = None) -> int:
        """
        Count file records.

        Args:
            purpose: Only count records with this purpose (no filter if
                falsy).

        Returns:
            The number of matching records.
        """
        async with aiosqlite.connect(self.db_path) as db:
            if purpose:
                async with db.execute(
                    "SELECT COUNT(*) FROM files WHERE purpose = ?", (purpose,)
                ) as cursor:
                    row = await cursor.fetchone()
            else:
                async with db.execute("SELECT COUNT(*) FROM files") as cursor:
                    row = await cursor.fetchone()

            return row[0] if row else 0
api_server/app/adapters/local/progress.py ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 本地进度管理适配器
3
+
4
+ 基于内存队列实现的进度管理适配器,适用于本地单实例场景。
5
+ """
6
+
7
+ import asyncio
8
+ from collections import defaultdict
9
+ from datetime import datetime
10
+ from typing import Any, AsyncGenerator, Dict, List, Optional
11
+
12
+ from ..base import ProgressAdapter
13
+
14
+
15
class LocalProgressAdapter(ProgressAdapter):
    """
    In-memory progress adapter for local, single-instance use.

    Characteristics:
    1. The latest progress message of every task is kept in a plain dict.
    2. Subscribers are implemented with one bounded ``asyncio.Queue`` each.
    3. Several subscribers may follow the same task at the same time.
    4. Intended to be compatible with the progress push mechanism of
       AsyncTrainingManager.

    Caveats:
    - All state lives in process memory and is lost on restart.
    - Only suitable for a single-instance deployment.
    - Server mode should use RedisProgressAdapter instead.

    Example:
        >>> adapter = LocalProgressAdapter()
        >>> await adapter.update_progress("task-123", {
        ...     "stage": "sovits_train",
        ...     "progress": 0.5,
        ...     "message": "Epoch 8/16"
        ... })
        >>>
        >>> # Subscribe to progress
        >>> async for progress in adapter.subscribe("task-123"):
        ...     print(f"{progress['stage']}: {progress['progress']*100:.1f}%")
    """

    # Status values after which a subscription ends on its own.
    _TERMINAL_STATUSES = ("completed", "failed", "cancelled")

    def __init__(self):
        """Create an adapter with no stored progress and no subscribers."""
        # task_id -> most recent progress message
        self.progress_store: Dict[str, Dict[str, Any]] = {}

        # task_id -> queues of the subscribers currently following that task
        self.subscribers: Dict[str, List[asyncio.Queue]] = defaultdict(list)

        # Guards every read/write of ``self.subscribers``.
        self._lock = asyncio.Lock()

    @staticmethod
    def _offer(queue: asyncio.Queue, item: Dict[str, Any]) -> bool:
        """
        Enqueue *item* without waiting.

        ``put_nowait`` is used on purpose: ``await queue.put()`` never raises
        ``QueueFull`` on a bounded queue, it waits for free space. Waiting
        here would stall the publisher while it holds ``self._lock``, which
        also blocks the subscriber cleanup that needs the same lock.

        Returns:
            True if the item was enqueued, False if the queue was full and
            the item was dropped.
        """
        try:
            queue.put_nowait(item)
        except asyncio.QueueFull:
            return False
        return True

    async def update_progress(self, task_id: str, progress: Dict[str, Any]) -> None:
        """
        Store *progress* as the latest state of a task and notify subscribers.

        A ``timestamp`` key (naive UTC, ISO 8601) is added to *progress* in
        place when the caller did not supply one. Subscribers whose queue is
        full do not receive this message; the publisher never waits for them.

        Args:
            task_id: Task identifier.
            progress: Progress message. Keys used by callers include:
                - type: message type ("progress", "log", "error", "heartbeat")
                - stage: current stage
                - progress: stage progress (0.0-1.0)
                - overall_progress: overall progress (0.0-1.0)
                - message: human-readable message
                - status: "running", "completed", "failed", "cancelled"
        """
        if "timestamp" not in progress:
            progress["timestamp"] = datetime.utcnow().isoformat()

        self.progress_store[task_id] = progress

        async with self._lock:
            # .get() rather than indexing: indexing the defaultdict would
            # create an empty subscriber list for tasks nobody follows.
            for queue in self.subscribers.get(task_id, ()):
                self._offer(queue, progress)

    async def get_progress(self, task_id: str) -> Optional[Dict[str, Any]]:
        """
        Return the most recent progress message of a task.

        Args:
            task_id: Task identifier.

        Returns:
            The latest stored message, or None if nothing is stored.
        """
        return self.progress_store.get(task_id)

    async def subscribe(self, task_id: str) -> AsyncGenerator[Dict[str, Any], None]:
        """
        Follow the progress of a task as an async generator.

        The stored progress (if any) is yielded first, then every later
        update. The generator finishes after yielding a message whose
        ``status`` is "completed", "failed" or "cancelled". When no update
        arrives for 30 seconds a ``{"type": "heartbeat", ...}`` message is
        yielded instead so the consumer can keep its connection alive.

        Args:
            task_id: Task identifier.

        Yields:
            Progress message dicts.

        Example:
            >>> async for progress in adapter.subscribe("task-123"):
            ...     print(progress)
            ...     if progress.get("status") == "completed":
            ...         break
        """
        queue: asyncio.Queue = asyncio.Queue(maxsize=100)

        async with self._lock:
            self.subscribers[task_id].append(queue)

        try:
            # Replay the current state so a late subscriber is not left blank.
            current = self.progress_store.get(task_id)
            if current:
                yield current
                if current.get("status") in self._TERMINAL_STATUSES:
                    return

            while True:
                try:
                    progress = await asyncio.wait_for(queue.get(), timeout=30.0)
                except asyncio.TimeoutError:
                    # Nothing happened for 30 s: emit a heartbeat.
                    yield {
                        "type": "heartbeat",
                        "timestamp": datetime.utcnow().isoformat(),
                    }
                    continue

                yield progress

                if progress.get("status") in self._TERMINAL_STATUSES:
                    break

        finally:
            # Runs on normal completion, on consumer break and on cancellation.
            async with self._lock:
                if task_id in self.subscribers:
                    try:
                        self.subscribers[task_id].remove(queue)
                    except ValueError:
                        # Queue was already removed (e.g. by clear_progress).
                        pass

                    # Drop the entry once the last subscriber is gone.
                    if not self.subscribers[task_id]:
                        del self.subscribers[task_id]

    async def clear_progress(self, task_id: str) -> None:
        """
        Forget the stored progress and the subscriber list of a task.

        Args:
            task_id: Task identifier.
        """
        self.progress_store.pop(task_id, None)

        async with self._lock:
            self.subscribers.pop(task_id, None)

    async def get_subscriber_count(self, task_id: str) -> int:
        """
        Return how many subscribers currently follow a task.

        Args:
            task_id: Task identifier.

        Returns:
            Number of registered subscriber queues.
        """
        async with self._lock:
            return len(self.subscribers.get(task_id, []))

    async def broadcast_to_all(self, message: Dict[str, Any]) -> int:
        """
        Send one message to every subscriber of every task.

        Meant for system-level notices such as a shutdown warning. A
        ``timestamp`` key is added to *message* in place when missing.
        Subscribers whose queue is full are skipped.

        Args:
            message: Message to deliver.

        Returns:
            Number of subscribers the message was actually enqueued for.
        """
        if "timestamp" not in message:
            message["timestamp"] = datetime.utcnow().isoformat()

        delivered = 0
        async with self._lock:
            for queues in self.subscribers.values():
                for queue in queues:
                    if self._offer(queue, message):
                        delivered += 1

        return delivered

    def get_active_tasks(self) -> List[str]:
        """
        Return the ids of all tasks that have a subscriber entry.

        Returns:
            List of task ids.
        """
        return list(self.subscribers.keys())

    def get_stats(self) -> Dict[str, Any]:
        """
        Return counters describing the adapter's current state.

        Returns:
            Dict with ``stored_progress_count``, ``active_tasks`` and
            ``total_subscribers``.
        """
        total_subscribers = sum(
            len(queues) for queues in self.subscribers.values()
        )

        return {
            "stored_progress_count": len(self.progress_store),
            "active_tasks": len(self.subscribers),
            "total_subscribers": total_subscribers,
        }
api_server/app/adapters/local/storage.py ADDED
@@ -0,0 +1,342 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 本地文件存储适配器
3
+
4
+ 基于本地文件系统实现的存储适配器,适用于 macOS 本地训练场景。
5
+ """
6
+
7
+ import json
8
+ import mimetypes
9
+ import uuid
10
+ from datetime import datetime
11
+ from pathlib import Path
12
+ from typing import Any, Dict, List, Optional
13
+
14
+ import aiofiles
15
+
16
+ from ..base import StorageAdapter
17
+ from ...core.config import settings
18
+
19
+
20
class LocalStorageAdapter(StorageAdapter):
    """
    Storage adapter backed by the local file system.

    Characteristics:
    1. File contents and metadata are read and written with aiofiles.
    2. Metadata is stored next to each file in a ``.meta.json`` file.
    3. For audio uploads, duration / sample rate / channel count are
       extracted when possible.

    Directory layout:
    ```
    base_path/
    ├── {file_id}              # the stored file
    └── {file_id}.meta.json    # its metadata
    ```

    Example:
        >>> adapter = LocalStorageAdapter()
        >>> file_id = await adapter.upload_file(
        ...     file_data=b"...",
        ...     filename="audio.wav",
        ...     metadata={"purpose": "training"}
        ... )
        >>> content = await adapter.download_file(file_id)
        >>> metadata = await adapter.get_file_metadata(file_id)
    """

    def __init__(self, base_path: Optional[str] = None):
        """
        Create the adapter and make sure the storage directory exists.

        Args:
            base_path: Root directory for stored files. Defaults to
                ``settings.DATA_DIR / "files"``.
        """
        if base_path:
            self.base_path = Path(base_path)
        else:
            self.base_path = settings.DATA_DIR / "files"

        self.base_path.mkdir(parents=True, exist_ok=True)

    @staticmethod
    def _is_valid_file_id(file_id: str) -> bool:
        """
        Tell whether *file_id* is a single path component.

        Ids reach this adapter from callers, so anything that could resolve
        outside ``base_path`` (empty string, ".", "..", or a value containing
        a path separator) is rejected before it is joined onto the base path.
        """
        return (
            isinstance(file_id, str)
            and file_id not in ("", ".", "..")
            and "\x00" not in file_id
            and Path(file_id).name == file_id
        )

    def _get_file_path(self, file_id: str) -> Path:
        """Return the path under ``base_path`` where the file content lives."""
        return self.base_path / file_id

    def _get_meta_path(self, file_id: str) -> Path:
        """Return the path of the ``.meta.json`` file belonging to *file_id*."""
        return self.base_path / f"{file_id}.meta.json"

    @staticmethod
    async def _read_metadata_file(meta_path: Path) -> Optional[Dict[str, Any]]:
        """Load one metadata file; return None if it is unreadable or not valid JSON."""
        try:
            async with aiofiles.open(meta_path, 'r', encoding='utf-8') as f:
                content = await f.read()
            return json.loads(content)
        except (json.JSONDecodeError, IOError):
            return None

    async def upload_file(
        self,
        file_data: bytes,
        filename: str,
        metadata: Dict[str, Any]
    ) -> str:
        """
        Store a file and its metadata on the local file system.

        The id is a fresh UUID4 with the suffix of *filename* appended, so the
        stored file keeps its extension.

        Args:
            file_data: Raw file content.
            filename: Original file name (used for the suffix and MIME guess).
            metadata: Caller-supplied metadata. ``content_type`` and
                ``purpose`` are read explicitly; every other key is copied
                into the stored metadata as-is.

        Returns:
            The generated file id.
        """
        file_id = str(uuid.uuid4())

        suffix = Path(filename).suffix
        if suffix:
            file_id = f"{file_id}{suffix}"

        file_path = self._get_file_path(file_id)
        meta_path = self._get_meta_path(file_id)

        async with aiofiles.open(file_path, 'wb') as f:
            await f.write(file_data)

        # Prefer the caller's content type, then guess from the name.
        content_type = metadata.get("content_type")
        if not content_type:
            content_type, _ = mimetypes.guess_type(filename)
            content_type = content_type or "application/octet-stream"

        # NOTE(review): caller keys are spread last, so a caller-supplied
        # "id"/"filename"/"size_bytes"/"uploaded_at" overrides the value
        # computed here. Kept as-is; confirm this is intended.
        file_metadata = {
            "id": file_id,
            "filename": filename,
            "content_type": content_type,
            "size_bytes": len(file_data),
            "purpose": metadata.get("purpose", "training"),
            "uploaded_at": datetime.utcnow().isoformat(),
            **{k: v for k, v in metadata.items() if k not in ("content_type", "purpose")}
        }

        # Best effort: audio details are added only when extraction succeeds.
        if content_type and content_type.startswith("audio/"):
            audio_info = await self._extract_audio_info(file_path)
            if audio_info:
                file_metadata.update(audio_info)

        async with aiofiles.open(meta_path, 'w', encoding='utf-8') as f:
            await f.write(json.dumps(file_metadata, ensure_ascii=False, indent=2))

        return file_id

    async def download_file(self, file_id: str) -> bytes:
        """
        Read the content of a stored file.

        Args:
            file_id: File identifier.

        Returns:
            The file content as bytes.

        Raises:
            FileNotFoundError: If the id is not a valid single path component
                or no such file exists.
        """
        if not self._is_valid_file_id(file_id):
            raise FileNotFoundError(f"File not found: {file_id}")

        file_path = self._get_file_path(file_id)

        if not file_path.exists():
            raise FileNotFoundError(f"File not found: {file_id}")

        async with aiofiles.open(file_path, 'rb') as f:
            return await f.read()

    async def delete_file(self, file_id: str) -> bool:
        """
        Delete a file and its metadata file.

        Args:
            file_id: File identifier.

        Returns:
            True if the file or its metadata existed and was removed, False
            if neither existed or the id is invalid.
        """
        if not self._is_valid_file_id(file_id):
            return False

        deleted = False
        for path in (self._get_file_path(file_id), self._get_meta_path(file_id)):
            if path.exists():
                path.unlink()
                deleted = True

        return deleted

    async def get_file_metadata(self, file_id: str) -> Optional[Dict[str, Any]]:
        """
        Load the stored metadata of a file.

        Args:
            file_id: File identifier.

        Returns:
            The metadata dict, or None if the id is invalid, the metadata file
            is missing, or it cannot be read or parsed.
        """
        if not self._is_valid_file_id(file_id):
            return None

        meta_path = self._get_meta_path(file_id)

        if not meta_path.exists():
            return None

        return await self._read_metadata_file(meta_path)

    async def list_files(
        self,
        purpose: Optional[str] = None,
        limit: int = 50,
        offset: int = 0
    ) -> List[Dict[str, Any]]:
        """
        List stored files, most recently modified metadata first.

        Args:
            purpose: When truthy, keep only files with this ``purpose``.
            limit: Maximum number of entries to return.
            offset: Number of matching entries to skip.

        Returns:
            List of metadata dicts. Unreadable metadata files are skipped.
        """
        results = []

        meta_files = sorted(
            self.base_path.glob("*.meta.json"),
            key=lambda p: p.stat().st_mtime,
            reverse=True  # newest first
        )

        for meta_path in meta_files:
            metadata = await self._read_metadata_file(meta_path)
            if metadata is None:
                continue

            if purpose and metadata.get("purpose") != purpose:
                continue

            results.append(metadata)

        # Paging is applied after filtering so offsets refer to matches.
        return results[offset:offset + limit]

    async def file_exists(self, file_id: str) -> bool:
        """
        Tell whether the content file of *file_id* exists.

        Args:
            file_id: File identifier.

        Returns:
            True if the file exists, False otherwise (including invalid ids).
        """
        if not self._is_valid_file_id(file_id):
            return False

        return self._get_file_path(file_id).exists()

    async def count_files(self, purpose: Optional[str] = None) -> int:
        """
        Count stored files.

        Args:
            purpose: When truthy, count only files with this ``purpose``.

        Returns:
            Number of files.
        """
        if not purpose:
            # No filter: counting the metadata files is enough.
            return len(list(self.base_path.glob("*.meta.json")))

        # With a filter each metadata file has to be read.
        count = 0
        for meta_path in self.base_path.glob("*.meta.json"):
            metadata = await self._read_metadata_file(meta_path)
            if metadata is not None and metadata.get("purpose") == purpose:
                count += 1

        return count

    async def _extract_audio_info(self, file_path: Path) -> Optional[Dict[str, Any]]:
        """
        Extract duration, sample rate and channel count from an audio file.

        ``soundfile`` is used when it can be imported. If the import fails
        and the file has a ``.wav`` suffix, the standard ``wave`` module is
        used instead. Any failure is treated as "no information".

        Args:
            file_path: Path of the audio file.

        Returns:
            Dict with ``duration_seconds``, ``sample_rate`` and ``channels``,
            or None when extraction fails.
        """
        try:
            import soundfile as sf

            info = sf.info(str(file_path))
            return {
                "duration_seconds": info.duration,
                "sample_rate": info.samplerate,
                "channels": info.channels,
            }
        except ImportError:
            # soundfile is not installed: fall back to ``wave`` for WAV files.
            if file_path.suffix.lower() == '.wav':
                try:
                    import wave

                    with wave.open(str(file_path), 'rb') as wf:
                        frames = wf.getnframes()
                        rate = wf.getframerate()
                        channels = wf.getnchannels()
                        duration = frames / float(rate) if rate > 0 else 0

                    return {
                        "duration_seconds": duration,
                        "sample_rate": rate,
                        "channels": channels,
                    }
                except Exception:
                    # Deliberate best effort: a bad WAV must not fail the upload.
                    pass
        except Exception:
            # Deliberate best effort: unsupported/corrupt audio yields no info.
            pass

        return None

    async def get_file_path(self, file_id: str) -> Optional[Path]:
        """
        Return the local path of a stored file for direct access.

        Args:
            file_id: File identifier.

        Returns:
            The file path, or None if the id is invalid or the file is missing.
        """
        if not self._is_valid_file_id(file_id):
            return None

        file_path = self._get_file_path(file_id)
        if file_path.exists():
            return file_path
        return None
+ return None
api_server/app/adapters/local/task_queue.py CHANGED
@@ -13,13 +13,16 @@ import sys
13
  import uuid
14
  from datetime import datetime
15
  from pathlib import Path
16
- from typing import Dict, Optional, AsyncGenerator, List
17
 
18
  import aiosqlite
19
 
20
  from ..base import TaskQueueAdapter
21
  from ...core.config import settings, PROJECT_ROOT, get_pythonpath
22
 
 
 
 
23
  # 进度消息标识符(与 run_pipeline.py 保持一致)
24
  PROGRESS_PREFIX = "##PROGRESS##"
25
  PROGRESS_SUFFIX = "##"
@@ -47,22 +50,32 @@ class AsyncTrainingManager(TaskQueueAdapter):
47
  >>> await manager.cancel(job_id)
48
  """
49
 
50
- def __init__(self, db_path: str = None, max_concurrent: int = 1):
 
 
 
 
 
51
  """
52
  初始化任务管理器
53
 
54
  Args:
55
  db_path: SQLite 数据库路径,默认使用 settings.SQLITE_PATH
56
  max_concurrent: 最大并发任务数(本地通常为1)
 
57
  """
58
  self.db_path = db_path or str(settings.SQLITE_PATH)
59
  self.max_concurrent = max_concurrent
 
60
 
61
  # 运行时状态
62
  self.running_processes: Dict[str, asyncio.subprocess.Process] = {} # task_id -> Process
63
  self.progress_channels: Dict[str, asyncio.Queue] = {} # task_id -> Queue
64
  self._running_count = 0
65
  self._queue_lock = asyncio.Lock()
 
 
 
66
 
67
  # 初始化数据库
68
  self._init_db_sync()
@@ -123,6 +136,9 @@ class AsyncTrainingManager(TaskQueueAdapter):
123
  )
124
  await db.commit()
125
 
 
 
 
126
  # 创建进度队列
127
  self.progress_channels[task_id] = asyncio.Queue()
128
 
@@ -381,6 +397,8 @@ class AsyncTrainingManager(TaskQueueAdapter):
381
  """
382
  更新任务状态
383
 
 
 
384
  Args:
385
  job_id: 作业ID
386
  **kwargs: 要更新的字段
@@ -388,6 +406,8 @@ class AsyncTrainingManager(TaskQueueAdapter):
388
  if not kwargs:
389
  return
390
 
 
 
391
  async with aiosqlite.connect(self.db_path) as db:
392
  updates = []
393
  values = []
@@ -403,6 +423,57 @@ class AsyncTrainingManager(TaskQueueAdapter):
403
  values
404
  )
405
  await db.commit()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
406
 
407
  async def get_status(self, job_id: str) -> Dict:
408
  """
 
13
  import uuid
14
  from datetime import datetime
15
  from pathlib import Path
16
+ from typing import TYPE_CHECKING, Dict, Optional, AsyncGenerator, List
17
 
18
  import aiosqlite
19
 
20
  from ..base import TaskQueueAdapter
21
  from ...core.config import settings, PROJECT_ROOT, get_pythonpath
22
 
23
+ if TYPE_CHECKING:
24
+ from ..base import DatabaseAdapter
25
+
26
  # 进度消息标识符(与 run_pipeline.py 保持一致)
27
  PROGRESS_PREFIX = "##PROGRESS##"
28
  PROGRESS_SUFFIX = "##"
 
50
  >>> await manager.cancel(job_id)
51
  """
52
 
53
+ def __init__(
54
+ self,
55
+ db_path: str = None,
56
+ max_concurrent: int = 1,
57
+ database_adapter: "DatabaseAdapter" = None
58
+ ):
59
  """
60
  初始化任务管理器
61
 
62
  Args:
63
  db_path: SQLite 数据库路径,默认使用 settings.SQLITE_PATH
64
  max_concurrent: 最大并发任务数(本地通常为1)
65
+ database_adapter: 数据库适配器,用于同步更新 tasks 表
66
  """
67
  self.db_path = db_path or str(settings.SQLITE_PATH)
68
  self.max_concurrent = max_concurrent
69
+ self._database_adapter = database_adapter
70
 
71
  # 运行时状态
72
  self.running_processes: Dict[str, asyncio.subprocess.Process] = {} # task_id -> Process
73
  self.progress_channels: Dict[str, asyncio.Queue] = {} # task_id -> Queue
74
  self._running_count = 0
75
  self._queue_lock = asyncio.Lock()
76
+
77
+ # task_id 到 job_id 的映射缓存
78
+ self._task_job_mapping: Dict[str, str] = {}
79
 
80
  # 初始化数据库
81
  self._init_db_sync()
 
136
  )
137
  await db.commit()
138
 
139
+ # 缓存 task_id -> job_id 映射
140
+ self._task_job_mapping[task_id] = job_id
141
+
142
  # 创建进度队列
143
  self.progress_channels[task_id] = asyncio.Queue()
144
 
 
397
  """
398
  更新任务状态
399
 
400
+ 同时更新 task_queue 表和 tasks 表(通过 DatabaseAdapter)。
401
+
402
  Args:
403
  job_id: 作业ID
404
  **kwargs: 要更新的字段
 
406
  if not kwargs:
407
  return
408
 
409
+ # 1. 更新 task_queue 表
410
+ task_id = None
411
  async with aiosqlite.connect(self.db_path) as db:
412
  updates = []
413
  values = []
 
423
  values
424
  )
425
  await db.commit()
426
+
427
+ # 获取 task_id 用于同步更新 tasks 表
428
+ async with db.execute(
429
+ "SELECT task_id FROM task_queue WHERE job_id = ?", (job_id,)
430
+ ) as cursor:
431
+ row = await cursor.fetchone()
432
+ if row:
433
+ task_id = row[0]
434
+
435
+ # 2. 同步更新 tasks 表(通过 DatabaseAdapter)
436
+ if self._database_adapter and task_id:
437
+ await self._sync_to_tasks_table(task_id, kwargs)
438
+
439
async def _sync_to_tasks_table(self, task_id: str, updates: Dict) -> None:
    """
    Mirror a status update onto the ``tasks`` table via the database adapter.

    Column mapping:
    - task_queue.progress -> tasks.stage_progress
    - task_queue.overall_progress -> tasks.progress
    - status, current_stage, message, error_message, started_at and
      completed_at keep their names
    Any other key in *updates* is ignored.

    Args:
        task_id: Task identifier passed to the adapter.
        updates: Field name -> new value, using task_queue column names.
    """
    adapter = self._database_adapter
    if not adapter:
        return

    # task_queue column -> tasks column
    column_map = {
        'progress': 'stage_progress',
        'overall_progress': 'progress',
        'status': 'status',
        'current_stage': 'current_stage',
        'message': 'message',
        'error_message': 'error_message',
        'started_at': 'started_at',
        'completed_at': 'completed_at',
    }
    tasks_updates = {
        column_map[field]: new_value
        for field, new_value in updates.items()
        if field in column_map
    }
    if not tasks_updates:
        return

    try:
        await adapter.update_task(task_id, tasks_updates)
    except Exception as e:
        # A failed mirror write is logged but must not break the main flow.
        import logging
        logging.warning(f"Failed to sync task status to tasks table: {e}")
477
 
478
  async def get_status(self, job_id: str) -> Dict:
479
  """
api_server/app/api/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ API 模块
3
+
4
+ 包含所有 API 路由和端点
5
+ """
6
+
7
+ from .v1.router import api_router
8
+
9
+ __all__ = ["api_router"]
api_server/app/api/deps.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 依赖注入模块
3
+
4
+ 提供 FastAPI 依赖注入函数,用于获取服务和适配器实例
5
+ """
6
+
7
+ from functools import lru_cache
8
+ from typing import Generator
9
+
10
+ from ..services.task_service import TaskService
11
+ from ..services.experiment_service import ExperimentService
12
+ from ..services.file_service import FileService
13
+
14
+
15
+ # ============================================================
16
+ # 服务依赖
17
+ # ============================================================
18
+
19
@lru_cache()
def get_task_service() -> TaskService:
    """
    Provide the shared TaskService instance.

    The function takes no arguments, so ``lru_cache`` makes every call after
    the first return the same object.

    Returns:
        The cached TaskService instance.

    Example:
        >>> @router.post("/tasks")
        ... async def create_task(
        ...     request: QuickModeRequest,
        ...     service: TaskService = Depends(get_task_service)
        ... ):
        ...     return await service.create_quick_task(request)
    """
    task_service = TaskService()
    return task_service
38
+
39
+
40
@lru_cache()
def get_experiment_service() -> ExperimentService:
    """
    Provide the shared ExperimentService instance.

    The result is cached, so repeated calls return the same object.

    Returns:
        The cached ExperimentService instance.
    """
    experiment_service = ExperimentService()
    return experiment_service
49
+
50
+
51
@lru_cache()
def get_file_service() -> FileService:
    """
    Provide the shared FileService instance.

    The result is cached, so repeated calls return the same object.

    Returns:
        The cached FileService instance.
    """
    file_service = FileService()
    return file_service
60
+
61
+
62
+ # ============================================================
63
+ # 通用依赖
64
+ # ============================================================
65
+
66
async def get_pagination_params(
    limit: int = 50,
    offset: int = 0
) -> dict:
    """
    Normalise paging parameters.

    Args:
        limit: Page size; clamped into the range 1..100 (default 50).
        offset: Number of items to skip; negative values become 0.

    Returns:
        Dict with the normalised ``limit`` and ``offset``.
    """
    bounded_limit = max(1, min(limit, 100))
    bounded_offset = max(offset, 0)
    return {"limit": bounded_limit, "offset": bounded_offset}
89
+
90
+
91
+ __all__ = [
92
+ "get_task_service",
93
+ "get_experiment_service",
94
+ "get_file_service",
95
+ "get_pagination_params",
96
+ ]
api_server/app/api/v1/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ API v1 模块
3
+
4
+ 包含 v1 版本的所有 API 端点
5
+ """
6
+
7
+ from .router import api_router
8
+
9
+ __all__ = ["api_router"]
api_server/app/api/v1/endpoints/__init__.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ API v1 端点模块
3
+
4
+ 包含所有 API 端点实现
5
+ """
6
+
7
+ from . import tasks
8
+ from . import experiments
9
+ from . import files
10
+ from . import stages
11
+
12
+ __all__ = [
13
+ "tasks",
14
+ "experiments",
15
+ "files",
16
+ "stages",
17
+ ]
api_server/app/api/v1/endpoints/experiments.py ADDED
@@ -0,0 +1,393 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Advanced Mode 实验 API
3
+
4
+ 专家用户分阶段训练 API 端点
5
+
6
+ API 列表:
7
+ - POST /experiments 创建实验
8
+ - GET /experiments 获取实验列表
9
+ - GET /experiments/{exp_id} 获取实验详情
10
+ - PATCH /experiments/{exp_id} 更新实验配置
11
+ - DELETE /experiments/{exp_id} 删除实验
12
+ - POST /experiments/{exp_id}/stages/{stage_type} 执行阶段
13
+ - GET /experiments/{exp_id}/stages 获取所有阶段状态
14
+ - GET /experiments/{exp_id}/stages/{stage_type} 获取阶段详情
15
+ - DELETE /experiments/{exp_id}/stages/{stage_type} 取消阶段
16
+ - GET /experiments/{exp_id}/stages/{stage_type}/progress SSE 阶段进度
17
+ """
18
+
19
+ import json
20
+ from typing import Any, Dict, Optional
21
+
22
+ from fastapi import APIRouter, Body, Depends, HTTPException, Query
23
+ from fastapi.responses import StreamingResponse
24
+
25
+ from ....models.schemas.experiment import (
26
+ ExperimentCreate,
27
+ ExperimentUpdate,
28
+ ExperimentResponse,
29
+ ExperimentListResponse,
30
+ StageStatus,
31
+ StageExecuteResponse,
32
+ StagesListResponse,
33
+ STAGE_DEPENDENCIES,
34
+ get_stage_params_class,
35
+ )
36
+ from ....models.schemas.common import SuccessResponse, ErrorResponse
37
+ from ....services.experiment_service import ExperimentService
38
+ from ...deps import get_experiment_service
39
+
40
router = APIRouter()

# Stage names accepted by the stage endpoints: the keys of STAGE_DEPENDENCIES.
VALID_STAGE_TYPES = list(STAGE_DEPENDENCIES)
44
+
45
+
46
+ @router.post(
47
+ "",
48
+ response_model=ExperimentResponse,
49
+ summary="创建实验",
50
+ description="""
51
+ 创建实验(专家用户)。
52
+
53
+ 创建实验但不立即执行,用户可以逐阶段控制训练流程。
54
+ 实验创建后,所有阶段状态为 `pending`,需要手动触发执行。
55
+
56
+ **训练阶段**:
57
+ - `audio_slice`: 音频切片
58
+ - `asr`: 语音识别
59
+ - `text_feature`: 文本特征提取
60
+ - `hubert_feature`: HuBERT 特征提取
61
+ - `semantic_token`: 语义 Token 提取
62
+ - `sovits_train`: SoVITS 训练
63
+ - `gpt_train`: GPT 训练
64
+ """,
65
+ )
66
async def create_experiment(
    request: ExperimentCreate,
    service: ExperimentService = Depends(get_experiment_service),
) -> ExperimentResponse:
    """
    Create an experiment by delegating to ``service.create_experiment``.
    """
    created = await service.create_experiment(request)
    return created
74
+
75
+
76
@router.get(
    "",
    response_model=ExperimentListResponse,
    summary="获取实验列表",
    description="获取所有实验列表,支持按状态筛选和分页。",
)
async def list_experiments(
    status: Optional[str] = Query(None, description="按状态筛选"),
    limit: int = Query(50, ge=1, le=100, description="每页数量"),
    offset: int = Query(0, ge=0, description="偏移量"),
    service: ExperimentService = Depends(get_experiment_service),
) -> ExperimentListResponse:
    """
    List experiments, passing the status filter and paging through to the service.
    """
    listing = await service.list_experiments(
        status=status,
        limit=limit,
        offset=offset,
    )
    return listing
92
+
93
+
94
@router.get(
    "/{exp_id}",
    response_model=ExperimentResponse,
    summary="获取实验详情",
    description="获取指定实验的详细信息,包括所有阶段状态。",
    responses={
        404: {"model": ErrorResponse, "description": "实验不存在"},
    },
)
async def get_experiment(
    exp_id: str,
    service: ExperimentService = Depends(get_experiment_service),
) -> ExperimentResponse:
    """
    Return one experiment; respond with 404 when the service finds none.
    """
    found = await service.get_experiment(exp_id)
    if found:
        return found
    raise HTTPException(status_code=404, detail="实验不存在")
114
+
115
+
116
@router.patch(
    "/{exp_id}",
    response_model=ExperimentResponse,
    summary="更新实验配置",
    description="更新实验的基础配置(非阶段参数)。",
    responses={
        404: {"model": ErrorResponse, "description": "实验不存在"},
    },
)
async def update_experiment(
    exp_id: str,
    request: ExperimentUpdate,
    service: ExperimentService = Depends(get_experiment_service),
) -> ExperimentResponse:
    """
    Apply an update to an experiment; respond with 404 when the service
    returns nothing for the given id.
    """
    updated = await service.update_experiment(exp_id, request)
    if updated:
        return updated
    raise HTTPException(status_code=404, detail="实验不存在")
137
+
138
+
139
@router.delete(
    "/{exp_id}",
    response_model=SuccessResponse,
    summary="删除实验",
    description="删除实验及其所有阶段数据。如果有正在运行的阶段,会先取消执行。",
    responses={
        404: {"model": ErrorResponse, "description": "实验不存在"},
    },
)
async def delete_experiment(
    exp_id: str,
    service: ExperimentService = Depends(get_experiment_service),
) -> SuccessResponse:
    """
    Delete an experiment; respond with 404 when the service reports failure.
    """
    removed = await service.delete_experiment(exp_id)
    if removed:
        return SuccessResponse(message="实验已删除")
    raise HTTPException(status_code=404, detail="实验不存在")
159
+
160
+
161
+ @router.post(
162
+ "/{exp_id}/stages/{stage_type}",
163
+ response_model=StageExecuteResponse,
164
+ summary="执行阶段",
165
+ description="""
166
+ 执行指定阶段。
167
+
168
+ **阶段依赖关系**:
169
+ - `audio_slice`: 无依赖
170
+ - `asr`: 依赖 audio_slice
171
+ - `text_feature`: 依赖 asr
172
+ - `hubert_feature`: 依赖 audio_slice
173
+ - `semantic_token`: 依赖 hubert_feature
174
+ - `sovits_train`: 依赖 text_feature, semantic_token
175
+ - `gpt_train`: 依赖 text_feature, semantic_token
176
+
177
+ 如果依赖阶段未完成,会返回 400 错误。
178
+ 如果阶段已完成,会重新执行(返回 `rerun: true`)。
179
+ """,
180
+ responses={
181
+ 400: {"model": ErrorResponse, "description": "阶段类型无效或依赖未满足"},
182
+ 404: {"model": ErrorResponse, "description": "实验不存在"},
183
+ },
184
+ )
185
async def execute_stage(
    exp_id: str,
    stage_type: str,
    params: Dict[str, Any] = Body(default={}),
    service: ExperimentService = Depends(get_experiment_service),
) -> StageExecuteResponse:
    """
    Run one stage of an experiment.

    Checks, in order: the stage type is known (400), the experiment exists
    (404), the stage's dependencies are satisfied (400), and the request
    body validates against the stage's parameter model (400). Only then is
    the stage handed to the service.
    """
    # 1. Stage type must be one of the known stages.
    if stage_type not in VALID_STAGE_TYPES:
        valid_names = ', '.join(VALID_STAGE_TYPES)
        raise HTTPException(
            status_code=400,
            detail=f"无效的阶段类型: {stage_type}。有效类型: {valid_names}"
        )

    # 2. The experiment must exist.
    if not await service.get_experiment(exp_id):
        raise HTTPException(status_code=404, detail="实验不存在")

    # 3. All prerequisite stages must be done.
    dependency_check = await service.check_stage_dependencies(exp_id, stage_type)
    if not dependency_check["satisfied"]:
        missing_names = ', '.join(dependency_check['missing'])
        raise HTTPException(
            status_code=400,
            detail=f"依赖阶段未完成: {missing_names}"
        )

    # 4. Validate the body; only explicitly supplied fields are forwarded.
    try:
        model_cls = get_stage_params_class(stage_type)
        stage_params = model_cls(**params).model_dump(exclude_unset=True)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))

    # 5. Execute.
    outcome = await service.execute_stage(exp_id, stage_type, stage_params)
    if outcome:
        return outcome
    raise HTTPException(status_code=404, detail="实验不存在")
228
+
229
+
230
@router.get(
    "/{exp_id}/stages",
    response_model=StagesListResponse,
    summary="获取所有阶段状态",
    description="获取实验的所有阶段状态列表。",
    responses={
        404: {"model": ErrorResponse, "description": "实验不存在"},
    },
)
async def get_all_stages(
    exp_id: str,
    service: ExperimentService = Depends(get_experiment_service),
) -> StagesListResponse:
    """
    Return the status of every stage of an experiment; respond with 404 when
    the service returns nothing.
    """
    stages = await service.get_all_stages(exp_id)
    if stages:
        return stages
    raise HTTPException(status_code=404, detail="实验不存在")
250
+
251
+
252
@router.get(
    "/{exp_id}/stages/{stage_type}",
    response_model=StageStatus,
    summary="获取阶段详情",
    description="获取指定阶段的详细状态和结果。",
    responses={
        400: {"model": ErrorResponse, "description": "阶段类型无效"},
        404: {"model": ErrorResponse, "description": "实验或阶段不存在"},
    },
)
async def get_stage(
    exp_id: str,
    stage_type: str,
    service: ExperimentService = Depends(get_experiment_service),
) -> StageStatus:
    """
    Return one stage of an experiment.

    Responds with 400 for an unknown stage type and 404 when the service
    returns nothing.
    """
    known_stage = stage_type in VALID_STAGE_TYPES
    if not known_stage:
        raise HTTPException(
            status_code=400,
            detail=f"无效的阶段类型: {stage_type}"
        )

    stage_status = await service.get_stage(exp_id, stage_type)
    if stage_status:
        return stage_status
    raise HTTPException(status_code=404, detail="实验或阶段不存在")
281
+
282
+
283
@router.delete(
    "/{exp_id}/stages/{stage_type}",
    response_model=SuccessResponse,
    summary="取消阶段",
    description="取消正在执行的阶段。只有运行中的阶段可以取消。",
    responses={
        400: {"model": ErrorResponse, "description": "阶段未运行或无法取消"},
        404: {"model": ErrorResponse, "description": "实验或阶段不存在"},
    },
)
async def cancel_stage(
    exp_id: str,
    stage_type: str,
    service: ExperimentService = Depends(get_experiment_service),
) -> SuccessResponse:
    """
    Cancel a stage.

    Responds with 400 both for an unknown stage type and when the service
    reports that the cancellation did not succeed.
    """
    known_stage = stage_type in VALID_STAGE_TYPES
    if not known_stage:
        raise HTTPException(
            status_code=400,
            detail=f"无效的阶段类型: {stage_type}"
        )

    cancelled = await service.cancel_stage(exp_id, stage_type)
    if cancelled:
        return SuccessResponse(message=f"阶段 {stage_type} 已取消")

    raise HTTPException(
        status_code=400,
        detail="阶段未运行或无法取消"
    )
316
+
317
+
318
+ @router.get(
319
+ "/{exp_id}/stages/{stage_type}/progress",
320
+ summary="SSE 阶段进度订阅",
321
+ description="""
322
+ 订阅阶段进度更新(Server-Sent Events)。
323
+
324
+ 返回的事件流格式:
325
+ ```
326
+ event: progress
327
+ data: {"epoch": 8, "total_epochs": 16, "progress": 0.50, "loss": 0.034}
328
+
329
+ event: checkpoint
330
+ data: {"epoch": 8, "model_path": "logs/my_voice/sovits_e8.pth"}
331
+
332
+ event: completed
333
+ data: {"status": "completed", "final_loss": 0.023}
334
+ ```
335
+ """,
336
+ responses={
337
+ 400: {"model": ErrorResponse, "description": "阶段类型无效"},
338
+ 404: {"model": ErrorResponse, "description": "实验或阶段不存在"},
339
+ },
340
+ )
341
async def subscribe_stage_progress(
    exp_id: str,
    stage_type: str,
    service: ExperimentService = Depends(get_experiment_service),
) -> StreamingResponse:
    """
    Stream the progress of a stage as Server-Sent Events.

    Responds with 400 for an unknown stage type and 404 when the experiment
    does not exist; otherwise returns a ``text/event-stream`` response that
    ends after the first message with a terminal status.
    """
    if stage_type not in VALID_STAGE_TYPES:
        raise HTTPException(
            status_code=400,
            detail=f"无效的阶段类型: {stage_type}"
        )

    if not await service.get_experiment(exp_id):
        raise HTTPException(status_code=404, detail="实验不存在")

    terminal_statuses = ("completed", "failed", "cancelled")

    async def event_generator():
        """Turn each progress message into one SSE frame."""
        async for progress in service.subscribe_stage_progress(exp_id, stage_type):
            status = progress.get("status")

            # Event name: a terminal status wins, then a checkpoint marker,
            # then the message's own type (default "progress").
            if status in terminal_statuses:
                event_type = status
            elif progress.get("model_path"):
                event_type = "checkpoint"
            else:
                event_type = progress.get("type", "progress")

            data = json.dumps(progress, ensure_ascii=False)
            yield f"event: {event_type}\ndata: {data}\n\n"

            # Stop streaming once the stage has reached a terminal state.
            if status in terminal_statuses:
                break

    sse_headers = {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "X-Accel-Buffering": "no",
    }
    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers=sse_headers,
    )
api_server/app/api/v1/endpoints/files.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 文件管理 API
3
+
4
+ 文件上传、下载和管理 API 端点
5
+
6
+ API 列表:
7
+ - POST /files 上传文件
8
+ - GET /files 获取文件列表
9
+ - GET /files/{file_id} 下载文件(或获取元数据)
10
+ - DELETE /files/{file_id} 删除文件
11
+ """
12
+
13
+ from typing import Optional
14
+
15
+ from fastapi import APIRouter, Depends, File, HTTPException, Query, UploadFile
16
+ from fastapi.responses import Response
17
+
18
+ from ....models.schemas.file import (
19
+ FileMetadata,
20
+ FileUploadResponse,
21
+ FileListResponse,
22
+ FileDeleteResponse,
23
+ )
24
+ from ....models.schemas.common import ErrorResponse
25
+ from ....services.file_service import FileService
26
+ from ...deps import get_file_service
27
+
28
# Router carrying the file-management endpoints defined in this module.
router = APIRouter()

# MIME types accepted for uploads, grouped by container/codec family.
ALLOWED_AUDIO_TYPES = {
    # WAV
    "audio/wav", "audio/wave", "audio/x-wav",
    # MP3
    "audio/mpeg", "audio/mp3",
    # MP4 / AAC
    "audio/mp4", "audio/aac",
    # OGG
    "audio/ogg",
    # FLAC
    "audio/flac", "audio/x-flac",
    # WebM
    "audio/webm",
}

# Upload size ceiling in bytes (500 MB).
MAX_FILE_SIZE = 500 * 1024 * 1024
47
+
48
+
49
@router.post(
    "",
    response_model=FileUploadResponse,
    summary="上传文件",
    description="""
上传音频文件用于训练。

**支持的音频格式**:
- WAV
- MP3
- FLAC
- OGG
- AAC
- WebM

**文件大小限制**: 500MB

**用途类型**:
- `training`: 训练音频(默认)
- `reference`: 参考音频
- `output`: 输出文件
""",
    responses={
        200: {"model": FileUploadResponse, "description": "文件上传成功"},
        400: {"model": ErrorResponse, "description": "文件格式或大小不合法"},
    },
)
async def upload_file(
    file: UploadFile = File(..., description="要上传的音频文件"),
    purpose: str = Query(
        "training",
        description="文件用途: training, reference, output"
    ),
    service: FileService = Depends(get_file_service),
) -> FileUploadResponse:
    """
    Validate an uploaded audio file and hand it to the file service.

    Args:
        file: The multipart upload.
        purpose: One of ``training``, ``reference`` or ``output``.
        service: File service injected by FastAPI.

    Returns:
        The result of ``service.upload_file``.

    Raises:
        HTTPException: 400 when the purpose is unknown, the declared content
            type and the file extension are both unsupported, the payload
            exceeds ``MAX_FILE_SIZE``, or the payload is empty.
    """
    if purpose not in ("training", "reference", "output"):
        raise HTTPException(
            status_code=400,
            detail="无效的用途类型,有效值: training, reference, output"
        )

    # A missing content type is tolerated; an unrecognised one is accepted
    # only when the file extension looks like a supported audio format.
    content_type = file.content_type
    if content_type and content_type not in ALLOWED_AUDIO_TYPES:
        filename = file.filename or ""
        ext = filename.lower().split(".")[-1] if "." in filename else ""
        allowed_exts = {"wav", "mp3", "flac", "ogg", "aac", "webm", "m4a"}
        if ext not in allowed_exts:
            raise HTTPException(
                status_code=400,
                detail=f"不支持的文件类型: {content_type}。支持的格式: WAV, MP3, FLAC, OGG, AAC, WebM"
            )

    # Read at most one byte past the limit: enough to detect an oversized
    # upload without pulling an arbitrarily large body into memory.
    file_data = await file.read(MAX_FILE_SIZE + 1)

    if len(file_data) > MAX_FILE_SIZE:
        raise HTTPException(
            status_code=400,
            detail=f"文件过大,最大允许 {MAX_FILE_SIZE // (1024*1024)}MB"
        )

    if not file_data:
        raise HTTPException(
            status_code=400,
            detail="文件为空"
        )

    return await service.upload_file(
        file_data=file_data,
        filename=file.filename or "audio",
        content_type=content_type,
        purpose=purpose,
    )
131
+
132
+
133
@router.get(
    "",
    response_model=FileListResponse,
    summary="获取文件列表",
    description="获取已上传的文件列表,支持按用途筛选和分页。",
)
async def list_files(
    purpose: Optional[str] = Query(
        None,
        description="按用途筛选: training, reference, output"
    ),
    limit: int = Query(50, ge=1, le=100, description="每页数量"),
    offset: int = Query(0, ge=0, description="偏移量"),
    service: FileService = Depends(get_file_service),
) -> FileListResponse:
    """
    Return one page of uploaded files, optionally filtered by purpose.

    The filter and paging values are forwarded unchanged to the file service.
    """
    page = await service.list_files(purpose=purpose, limit=limit, offset=offset)
    return page
152
+
153
+
154
@router.get(
    "/{file_id}",
    summary="下载文件或获取元数据",
    description="""
根据请求类型返回文件内容或元数据。

- 默认返回文件内容(用于下载)
- 添加 `?metadata=true` 参数只返回元数据
""",
    responses={
        200: {
            "description": "文件内容(下载时)或元数据(metadata=true 时)",
        },
        404: {"model": ErrorResponse, "description": "文件不存在"},
    },
)
async def get_file(
    file_id: str,
    metadata: bool = Query(False, description="只返回元数据"),
    service: FileService = Depends(get_file_service),
):
    """
    Download a file, or return only its metadata when ``metadata`` is true.

    Args:
        file_id: Identifier of the stored file.
        metadata: When true, return the metadata record instead of the bytes.
        service: File service injected by FastAPI.

    Returns:
        The metadata object from ``service.get_file`` or a ``Response``
        carrying the file bytes as an attachment.

    Raises:
        HTTPException: 404 when the service reports no such file.
    """
    # Function-scope import keeps the module's import block untouched.
    from urllib.parse import quote

    if metadata:
        file_metadata = await service.get_file(file_id)
        if not file_metadata:
            raise HTTPException(status_code=404, detail="文件不存在")
        return file_metadata

    result = await service.download_file(file_id)
    if not result:
        raise HTTPException(status_code=404, detail="文件不存在")

    file_data, filename, content_type = result

    # Header values must be latin-1 encodable and must not contain quotes or
    # line breaks. Send a sanitised ASCII fallback plus the exact name as an
    # RFC 5987 ``filename*`` parameter so non-ASCII names survive.
    # NOTE(review): the filename presumably originates from the uploader.
    raw_name = (filename or "download").replace("\r", "").replace("\n", "")
    ascii_name = raw_name.encode("ascii", "ignore").decode("ascii")
    ascii_name = ascii_name.replace("\\", "_").replace('"', "_") or "download"
    encoded_name = quote(raw_name, safe="")
    disposition = (
        f'attachment; filename="{ascii_name}"; '
        f"filename*=UTF-8''{encoded_name}"
    )

    return Response(
        content=file_data,
        media_type=content_type,
        headers={
            "Content-Disposition": disposition,
            "Content-Length": str(len(file_data)),
        },
    )
200
+
201
+
202
@router.delete(
    "/{file_id}",
    response_model=FileDeleteResponse,
    summary="删除文件",
    description="删除指定的文件。",
    responses={
        200: {"model": FileDeleteResponse, "description": "删除结果"},
        404: {"model": ErrorResponse, "description": "文件不存在"},
    },
)
async def delete_file(
    file_id: str,
    service: FileService = Depends(get_file_service),
) -> FileDeleteResponse:
    """
    Delete the file identified by ``file_id``.

    Returns the service's deletion result when it reports success; any
    unsuccessful result is surfaced as a 404.
    """
    outcome = await service.delete_file(file_id)
    if outcome.success:
        return outcome
    raise HTTPException(status_code=404, detail="文件不存在或已删除")
api_server/app/api/v1/endpoints/stages.py ADDED
@@ -0,0 +1,247 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 阶段模板 API
3
+
4
+ 阶段预设和参数模板 API 端点
5
+
6
+ API 列表:
7
+ - GET /stages/presets 获取阶段预设列表
8
+ - GET /stages/{stage_type}/schema 获取阶段参数模板
9
+ """
10
+
11
+ from typing import Any, Dict, List
12
+
13
+ from fastapi import APIRouter, HTTPException
14
+
15
+ from ....models.schemas.experiment import (
16
+ STAGE_DEPENDENCIES,
17
+ STAGE_PARAMS_MAP,
18
+ AudioSliceParams,
19
+ ASRParams,
20
+ TextFeatureParams,
21
+ HubertFeatureParams,
22
+ SemanticTokenParams,
23
+ SoVITSTrainParams,
24
+ GPTTrainParams,
25
+ )
26
+ from ....models.schemas.common import ErrorResponse
27
+
28
router = APIRouter()


def _preset(preset_id: str, name: str, description: str, stages: List[str]) -> Dict[str, Any]:
    """Build one preset entry with the key order id, name, description, stages."""
    return {
        "id": preset_id,
        "name": name,
        "description": description,
        "stages": list(stages),
    }


def _stage(name: str, description: str, dependencies: List[str]) -> Dict[str, Any]:
    """Build one stage-info entry with the key order name, description, dependencies."""
    return {
        "name": name,
        "description": description,
        "dependencies": list(dependencies),
    }


# ============================================================
# Stage presets: named, ordered selections of pipeline stages
# ============================================================

STAGE_PRESETS = [
    _preset(
        "full_training",
        "完整训练流程",
        "包含所有阶段的标准训练,从音频切片到模型训练",
        [
            "audio_slice",
            "asr",
            "text_feature",
            "hubert_feature",
            "semantic_token",
            "sovits_train",
            "gpt_train",
        ],
    ),
    _preset(
        "retrain_sovits",
        "重训 SoVITS",
        "跳过预处理,仅重新训练 SoVITS 模型",
        ["sovits_train"],
    ),
    _preset(
        "retrain_gpt",
        "重训 GPT",
        "跳过预处理,仅重新训练 GPT 模型",
        ["gpt_train"],
    ),
    _preset(
        "retrain_both",
        "重训两个模型",
        "跳过预处理,重新训练 SoVITS 和 GPT 模型",
        ["sovits_train", "gpt_train"],
    ),
    _preset(
        "feature_extraction",
        "特征提取",
        "仅执行音频切片和特征提取,不进行训练",
        [
            "audio_slice",
            "asr",
            "text_feature",
            "hubert_feature",
            "semantic_token",
        ],
    ),
    _preset(
        "audio_preprocessing",
        "音频预处理",
        "仅执行音频切片和语音识别",
        ["audio_slice", "asr"],
    ),
]


# ============================================================
# Stage info: display name, description and prerequisite stages
# ============================================================

STAGE_INFO = {
    "audio_slice": _stage(
        "音频切片",
        "将长音频切分为短片段,便于后续处理",
        [],
    ),
    "asr": _stage(
        "语音识别",
        "识别音频中的文本内容",
        ["audio_slice"],
    ),
    "text_feature": _stage(
        "文本特征提取",
        "使用 BERT 模型提取文本特征",
        ["asr"],
    ),
    "hubert_feature": _stage(
        "HuBERT 特征提取",
        "使用 HuBERT 模型提取音频特征",
        ["audio_slice"],
    ),
    "semantic_token": _stage(
        "语义 Token 提取",
        "从 HuBERT 特征中提取语义 Token",
        ["hubert_feature"],
    ),
    "sovits_train": _stage(
        "SoVITS 训练",
        "训练 SoVITS 声码器模型",
        ["text_feature", "semantic_token"],
    ),
    "gpt_train": _stage(
        "GPT 训练",
        "训练 GPT 语言模型",
        ["text_feature", "semantic_token"],
    ),
}
130
+
131
+
132
def _resolve_schema_node(prop: Dict[str, Any], definitions: Dict[str, Any]) -> Dict[str, Any]:
    """Return the sub-schema holding a property's concrete type and constraints.

    A property that already declares ``type`` is returned unchanged. Otherwise
    the first non-null branch found via ``$ref`` / ``allOf`` / ``anyOf`` /
    ``oneOf`` is returned, with ``#/$defs/...`` references looked up in
    *definitions*. If nothing resolves, *prop* itself is returned.
    """
    if "type" in prop:
        return prop

    branches = [prop] if "$ref" in prop else []
    for key in ("allOf", "anyOf", "oneOf"):
        branches.extend(b for b in prop.get(key, []) if isinstance(b, dict))

    prefix = "#/$defs/"
    for branch in branches:
        ref = branch.get("$ref")
        if isinstance(ref, str) and ref.startswith(prefix):
            branch = definitions.get(ref[len(prefix):], {})
        if branch.get("type") not in (None, "null"):
            return branch
    return prop


def get_parameter_schema(params_class: type) -> Dict[str, Any]:
    """
    Build a simplified parameter description from a Pydantic model class.

    Args:
        params_class: A class exposing ``model_json_schema()``.

    Returns:
        A dict keyed by field name. Each value has ``type`` and
        ``description`` and, when present in the schema, ``default``,
        ``min``, ``max`` and ``enum``. Fields whose type is wrapped in
        ``anyOf`` / ``allOf`` / ``$ref`` (optional and enum fields) report
        the wrapped type and its constraints. ``"string"`` is used only
        when no type can be determined.
    """
    schema = params_class.model_json_schema()
    properties = schema.get("properties", {})
    definitions = schema.get("$defs", {})

    parameters = {}
    for name, prop in properties.items():
        # Type and constraints may sit on a nested branch; description and
        # default are always read from the property itself.
        detail = _resolve_schema_node(prop, definitions)

        param_info = {
            "type": detail.get("type", "string"),
            "description": prop.get("description", ""),
        }

        if "default" in prop:
            param_info["default"] = prop["default"]

        if "minimum" in detail:
            param_info["min"] = detail["minimum"]
        if "maximum" in detail:
            param_info["max"] = detail["maximum"]

        if "enum" in detail:
            param_info["enum"] = detail["enum"]

        parameters[name] = param_info

    return parameters
169
+
170
+
171
@router.get(
    "/presets",
    summary="获取阶段预设列表",
    description="""
获取预定义的训练流程预设。

每个预设包含一组阶段,用户可以选择预设快速配置训练流程。
""",
    response_model=Dict[str, List[Dict[str, Any]]],
)
async def get_presets() -> Dict[str, List[Dict[str, Any]]]:
    """
    Return the predefined stage presets under the ``presets`` key.
    """
    payload = {"presets": STAGE_PRESETS}
    return payload
186
+
187
+
188
@router.get(
    "/{stage_type}/schema",
    summary="获取阶段参数模板",
    description="""
获取指定阶段的参数模板,包含参数定义、默认值和取值范围。

前端可以使用此接口动态生成参数配置表单。
""",
    responses={
        200: {"description": "阶段参数模板"},
        404: {"model": ErrorResponse, "description": "阶段类型无效"},
    },
)
async def get_stage_schema(stage_type: str) -> Dict[str, Any]:
    """
    Describe a single stage: display info, dependencies and parameters.

    Raises:
        HTTPException: 404 when ``stage_type`` is not a key of
            ``STAGE_PARAMS_MAP``.
    """
    # Reject unknown stage types, listing the valid ones in the error.
    if stage_type not in STAGE_PARAMS_MAP:
        valid_types = ', '.join(STAGE_PARAMS_MAP.keys())
        raise HTTPException(
            status_code=404,
            detail=f"无效的阶段类型: {stage_type}。有效类型: {valid_types}"
        )

    # Display info is optional; fall back to the raw type name / empty text.
    info = STAGE_INFO.get(stage_type, {})
    parameters = get_parameter_schema(STAGE_PARAMS_MAP[stage_type])

    response: Dict[str, Any] = {"type": stage_type}
    response["name"] = info.get("name", stage_type)
    response["description"] = info.get("description", "")
    response["dependencies"] = STAGE_DEPENDENCIES.get(stage_type, [])
    response["parameters"] = parameters
    return response
226
+
227
+
228
@router.get(
    "",
    summary="获取所有阶段信息",
    description="获取所有训练阶段的信息和依赖关系。",
)
async def get_all_stages() -> Dict[str, Any]:
    """
    List every stage in ``STAGE_PARAMS_MAP`` with its display info and dependencies.
    """

    def describe(stage_type: str) -> Dict[str, Any]:
        # Display info is optional; fall back to the type name / empty text.
        info = STAGE_INFO.get(stage_type, {})
        return {
            "type": stage_type,
            "name": info.get("name", stage_type),
            "description": info.get("description", ""),
            "dependencies": STAGE_DEPENDENCIES.get(stage_type, []),
        }

    return {"stages": [describe(stage_type) for stage_type in STAGE_PARAMS_MAP]}
api_server/app/api/v1/endpoints/tasks.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Quick Mode 任务 API
3
+
4
+ 小白用户一键训练 API 端点
5
+
6
+ API 列表:
7
+ - POST /tasks 创建一键训练任务
8
+ - GET /tasks 获取任务列表
9
+ - GET /tasks/{task_id} 获取任务详情
10
+ - DELETE /tasks/{task_id} 取消任务
11
+ - GET /tasks/{task_id}/progress SSE 进度订阅
12
+ """
13
+
14
+ import json
15
+ from typing import Optional
16
+
17
+ from fastapi import APIRouter, Depends, HTTPException, Query
18
+ from fastapi.responses import StreamingResponse
19
+
20
+ from ....models.schemas.task import (
21
+ QuickModeRequest,
22
+ TaskResponse,
23
+ TaskListResponse,
24
+ )
25
+ from ....models.schemas.common import SuccessResponse, ErrorResponse
26
+ from ....services.task_service import TaskService
27
+ from ...deps import get_task_service
28
+
29
+ router = APIRouter()
30
+
31
+
32
@router.post(
    "",
    response_model=TaskResponse,
    summary="创建一键训练任务",
    description="""
创建一键训练任务(小白用户)。

上传音频文件后,系统自动配置所有参数并执行完整训练流程:
`audio_slice -> asr -> text_feature -> hubert_feature -> semantic_token -> sovits_train -> gpt_train`

**质量预设**:
- `fast`: SoVITS 4 epochs, GPT 8 epochs, 约10分钟
- `standard`: SoVITS 8 epochs, GPT 15 epochs, 约20分钟
- `high`: SoVITS 16 epochs, GPT 30 epochs, 约40分钟
""",
    responses={
        200: {"model": TaskResponse, "description": "任务创建成功"},
        400: {"model": ErrorResponse, "description": "请求参数错误"},
        404: {"model": ErrorResponse, "description": "音频文件不存在"},
        409: {"model": ErrorResponse, "description": "实验名称已存在"},
    },
)
async def create_task(
    request: QuickModeRequest,
    service: TaskService = Depends(get_task_service),
) -> TaskResponse:
    """
    Create a Quick Mode training task.

    Raises:
        HTTPException: 409 when the experiment name is already taken,
            404 when the referenced audio file cannot be found.
    """
    # The experiment name must be unused.
    name_taken = await service.check_exp_name_exists(request.exp_name)
    if name_taken:
        raise HTTPException(
            status_code=409,
            detail=f"实验名称 '{request.exp_name}' 已存在,请使用不同的名称"
        )

    # The audio file must exist; the resolved path is not needed here.
    file_exists, _ = await service.validate_audio_file(request.audio_file_id)
    if not file_exists:
        raise HTTPException(
            status_code=404,
            detail=f"音频文件不存在: {request.audio_file_id}"
        )

    created = await service.create_quick_task(request)
    return created
77
+
78
+
79
@router.get(
    "",
    response_model=TaskListResponse,
    summary="获取任务列表",
    description="获取所有训练任务列表,支持按状态筛选和分页。",
)
async def list_tasks(
    status: Optional[str] = Query(
        None,
        description="按状态筛选: queued, running, completed, failed, cancelled, interrupted"
    ),
    limit: int = Query(50, ge=1, le=100, description="每页数量"),
    offset: int = Query(0, ge=0, description="偏移量"),
    service: TaskService = Depends(get_task_service),
) -> TaskListResponse:
    """
    Return one page of tasks, optionally filtered by status.

    The filter and paging values are forwarded unchanged to the task service.
    """
    page = await service.list_tasks(status=status, limit=limit, offset=offset)
    return page
98
+
99
+
100
@router.get(
    "/{task_id}",
    response_model=TaskResponse,
    summary="获取任务详情",
    description="获取指定任务的详细状态信息。",
    responses={
        200: {"model": TaskResponse, "description": "任务详情"},
        404: {"model": ErrorResponse, "description": "任务不存在"},
    },
)
async def get_task(
    task_id: str,
    service: TaskService = Depends(get_task_service),
) -> TaskResponse:
    """
    Return the task identified by ``task_id``, or respond 404 if the service finds none.
    """
    found = await service.get_task(task_id)
    if found:
        return found
    raise HTTPException(status_code=404, detail="任务不存在")
121
+
122
+
123
@router.delete(
    "/{task_id}",
    response_model=SuccessResponse,
    summary="取消任务",
    description="取消排队中或运行中的任务。已完成、失败或已取消的任务无法取消。",
    responses={
        200: {"model": SuccessResponse, "description": "任务取消成功"},
        400: {"model": ErrorResponse, "description": "任务无法取消"},
        404: {"model": ErrorResponse, "description": "任务不存在"},
    },
)
async def cancel_task(
    task_id: str,
    service: TaskService = Depends(get_task_service),
) -> SuccessResponse:
    """
    Cancel a task.

    Raises:
        HTTPException: 404 when the task does not exist, 400 when the
            service reports that the cancellation did not succeed.
    """
    # Look the task up first so "missing" and "not cancellable" stay distinct.
    existing = await service.get_task(task_id)
    if not existing:
        raise HTTPException(status_code=404, detail="任务不存在")

    cancelled = await service.cancel_task(task_id)
    if cancelled:
        return SuccessResponse(message="任务已取消")

    raise HTTPException(status_code=400, detail="任务无法取消(可能已完成或已取消)")
151
+
152
+
153
@router.get(
    "/{task_id}/progress",
    summary="SSE 进度订阅",
    description="""
订阅任务进度更新(Server-Sent Events)。

返回的事件流格式:
```
event: progress
data: {"stage": "sovits_train", "progress": 0.45, "message": "Epoch 8/16"}

event: progress
data: {"stage": "sovits_train", "progress": 0.50, "message": "Epoch 9/16"}

event: completed
data: {"status": "completed", "message": "训练完成"}
```

可能的事件类型:
- `progress`: 进度更新
- `log`: 日志消息
- `heartbeat`: 心跳(保持连接)
- `completed`: 任务完成
- `failed`: 任务失败
- `cancelled`: 任务取消
""",
    responses={
        200: {"description": "SSE 事件流"},
        404: {"model": ErrorResponse, "description": "任务不存在"},
    },
)
async def subscribe_progress(
    task_id: str,
    service: TaskService = Depends(get_task_service),
) -> StreamingResponse:
    """
    Stream progress updates for a task as Server-Sent Events.

    Raises:
        HTTPException: 404 when the task does not exist.
    """
    if not await service.get_task(task_id):
        raise HTTPException(status_code=404, detail="任务不存在")

    terminal_states = ("completed", "failed", "cancelled")

    async def event_generator():
        """Yield one SSE frame per progress update, stopping after a terminal status."""
        async for update in service.subscribe_progress(task_id):
            status = update.get("status")
            finished = status in terminal_states

            # A terminal status names the event; otherwise the update's own
            # "type" is used, defaulting to "progress".
            event_name = status if finished else update.get("type", "progress")

            payload = json.dumps(update, ensure_ascii=False)
            yield f"event: {event_name}\ndata: {payload}\n\n"

            if finished:
                return

    sse_headers = {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "X-Accel-Buffering": "no",  # ask Nginx not to buffer the stream
    }
    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers=sse_headers,
    )
api_server/app/api/v1/router.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ API v1 路由注册
3
+
4
+ 统一注册所有 v1 版本的 API 路由
5
+ """
6
+
7
+ from fastapi import APIRouter
8
+
9
+ from .endpoints import tasks, experiments, files, stages
10
+
11
api_router = APIRouter()

# (router, URL prefix, OpenAPI tag) for each endpoint group, in mount order:
# Quick Mode tasks, Advanced Mode experiments, file management, stage templates.
_ROUTE_TABLE = (
    (tasks.router, "/tasks", "Quick Mode - 任务管理"),
    (experiments.router, "/experiments", "Advanced Mode - 实验管理"),
    (files.router, "/files", "文件管理"),
    (stages.router, "/stages", "阶段模板"),
)

for _sub_router, _prefix, _tag in _ROUTE_TABLE:
    api_router.include_router(_sub_router, prefix=_prefix, tags=[_tag])
api_server/app/core/adapters.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 适配器工厂模块
3
+
4
+ 根据 DEPLOYMENT_MODE 配置自动选择本地或服务器适配器。
5
+
6
+ Example:
7
+ >>> from app.core.adapters import get_database_adapter, get_storage_adapter
8
+ >>> db = get_database_adapter()
9
+ >>> storage = get_storage_adapter()
10
+ """
11
+
12
+ from functools import lru_cache
13
+ from typing import TYPE_CHECKING
14
+
15
+ from .config import settings
16
+
17
+ if TYPE_CHECKING:
18
+ from ..adapters.base import (
19
+ DatabaseAdapter,
20
+ ProgressAdapter,
21
+ StorageAdapter,
22
+ TaskQueueAdapter,
23
+ )
24
+
25
+
26
class AdapterFactory:
    """
    Factory that builds adapter instances according to ``settings.DEPLOYMENT_MODE``.

    Only ``"local"`` mode is implemented. Every factory method raises
    ``NotImplementedError`` for any other mode (server mode is planned for
    Phase 2).
    """

    @staticmethod
    def create_storage_adapter() -> "StorageAdapter":
        """
        Build the storage adapter.

        Returns:
            A ``LocalStorageAdapter`` rooted at ``settings.DATA_DIR / "files"``.

        Raises:
            NotImplementedError: when the deployment mode is not ``"local"``.
        """
        if settings.DEPLOYMENT_MODE != "local":
            # Phase 2: server mode
            raise NotImplementedError("Server mode storage adapter not implemented yet")

        from ..adapters.local.storage import LocalStorageAdapter

        files_root = str(settings.DATA_DIR / "files")
        return LocalStorageAdapter(base_path=files_root)

    @staticmethod
    def create_database_adapter() -> "DatabaseAdapter":
        """
        Build the database adapter.

        Returns:
            A ``SQLiteAdapter`` bound to ``settings.SQLITE_PATH``.

        Raises:
            NotImplementedError: when the deployment mode is not ``"local"``.
        """
        if settings.DEPLOYMENT_MODE != "local":
            # Phase 2: server mode
            raise NotImplementedError("Server mode database adapter not implemented yet")

        from ..adapters.local.database import SQLiteAdapter

        return SQLiteAdapter(db_path=str(settings.SQLITE_PATH))

    @staticmethod
    def create_task_queue_adapter(database_adapter: "DatabaseAdapter" = None) -> "TaskQueueAdapter":
        """
        Build the task queue adapter.

        Args:
            database_adapter: Database adapter passed on to the queue. When
                omitted, a new ``SQLiteAdapter`` is created for it.

        Returns:
            An ``AsyncTrainingManager`` bound to ``settings.SQLITE_PATH``.

        Raises:
            NotImplementedError: when the deployment mode is not ``"local"``.
        """
        if settings.DEPLOYMENT_MODE != "local":
            # Phase 2: server mode
            raise NotImplementedError("Server mode task queue adapter not implemented yet")

        from ..adapters.local.task_queue import AsyncTrainingManager
        from ..adapters.local.database import SQLiteAdapter

        sqlite_path = str(settings.SQLITE_PATH)

        # Without a caller-supplied adapter, create a dedicated one.
        if database_adapter is None:
            database_adapter = SQLiteAdapter(db_path=sqlite_path)

        return AsyncTrainingManager(
            db_path=sqlite_path,
            database_adapter=database_adapter
        )

    @staticmethod
    def create_progress_adapter() -> "ProgressAdapter":
        """
        Build the progress adapter.

        Returns:
            A ``LocalProgressAdapter``.

        Raises:
            NotImplementedError: when the deployment mode is not ``"local"``.
        """
        if settings.DEPLOYMENT_MODE != "local":
            # Phase 2: server mode
            raise NotImplementedError("Server mode progress adapter not implemented yet")

        from ..adapters.local.progress import LocalProgressAdapter

        return LocalProgressAdapter()
112
+
113
+
114
+ # ============================================================
115
+ # 全局单例获取函数(使用 lru_cache 缓存实例)
116
+ # ============================================================
117
+
118
@lru_cache()
def get_storage_adapter() -> "StorageAdapter":
    """
    Return the storage adapter, built on first call and cached afterwards.
    """
    adapter = AdapterFactory.create_storage_adapter()
    return adapter
127
+
128
+
129
@lru_cache()
def get_database_adapter() -> "DatabaseAdapter":
    """
    Return the database adapter, built on first call and cached afterwards.
    """
    adapter = AdapterFactory.create_database_adapter()
    return adapter
138
+
139
+
140
@lru_cache()
def get_task_queue_adapter() -> "TaskQueueAdapter":
    """
    Return the task queue adapter, built on first call and cached afterwards.

    The queue is given the cached database adapter from
    ``get_database_adapter()`` rather than a fresh one, so both share a
    single instance.
    """
    return AdapterFactory.create_task_queue_adapter(
        database_adapter=get_database_adapter()
    )
153
+
154
+
155
@lru_cache()
def get_progress_adapter() -> "ProgressAdapter":
    """
    Return the progress adapter, built on first call and cached afterwards.
    """
    adapter = AdapterFactory.create_progress_adapter()
    return adapter
164
+
165
+
166
+ # ============================================================
167
+ # 便捷别名(向后兼容)
168
+ # ============================================================
169
+
170
+ # 延迟初始化的全局变量,在首次访问时创建
171
+ # 注意:这些是函数调用的结果,不是直接的实例引用
172
+ # 如果需要在模块级别使用,请调用对应的 get_*_adapter() 函数
173
+
174
+ __all__ = [
175
+ "AdapterFactory",
176
+ "get_storage_adapter",
177
+ "get_database_adapter",
178
+ "get_task_queue_adapter",
179
+ "get_progress_adapter",
180
+ ]
api_server/app/main.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FastAPI 应用入口
3
+
4
+ GPT-SoVITS 音色训练 HTTP API 服务
5
+
6
+ 启动方式:
7
+ uvicorn api_server.app.main:app --host 0.0.0.0 --port 8000 --reload
8
+ """
9
+
10
+ from contextlib import asynccontextmanager
11
+ from typing import AsyncGenerator
12
+
13
+ from fastapi import FastAPI
14
+ from fastapi.middleware.cors import CORSMiddleware
15
+
16
+ from .api.v1.router import api_router
17
+ from .core.config import settings, ensure_data_dirs
18
+
19
+
20
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    """
    Application lifespan hook.

    On startup: prints the active configuration, ensures the data
    directories exist and, in local mode, asks the task queue to recover
    pending tasks when it offers ``recover_pending_tasks``. Recovery is
    best-effort: a failure is printed as a warning and startup continues.

    On shutdown: prints a shutdown message.
    """
    # --- startup ---
    banner = (
        f"Starting GPT-SoVITS Training API in {settings.DEPLOYMENT_MODE.upper()} mode",
        f" Project Root: {settings.PROJECT_ROOT}",
        f" Data Directory: {settings.DATA_DIR}",
        f" SQLite Path: {settings.SQLITE_PATH}",
    )
    for line in banner:
        print(line)

    ensure_data_dirs()

    if settings.DEPLOYMENT_MODE == "local":
        try:
            from .core.adapters import get_task_queue_adapter

            task_queue = get_task_queue_adapter()
            # Recovery is optional; only queues exposing the method are asked.
            if hasattr(task_queue, 'recover_pending_tasks'):
                recovered = await task_queue.recover_pending_tasks()
                if recovered > 0:
                    print(f" Recovered {recovered} pending tasks")
        except Exception as exc:
            print(f" Warning: Failed to recover tasks: {exc}")

    print(" API Server ready!")
    print(f" Docs: http://{settings.API_HOST}:{settings.API_PORT}/docs")

    yield

    # --- shutdown ---
    print("Shutting down GPT-SoVITS Training API...")
61
+
62
+
63
+ # 创建 FastAPI 应用
64
# Markdown shown at the top of the generated OpenAPI docs.
_API_DESCRIPTION = """
GPT-SoVITS 音色训练 HTTP API 服务

## 功能概述

提供两种训练模式:

### Quick Mode(小白用户)
- 上传音频即可训练,系统自动配置所有参数
- 适合个人开发者、快速验证

### Advanced Mode(专家用户)
- 分阶段控制训练流程
- 精细调整每个阶段的参数
- 适合需要深度定制的用户

## API 分组

- **Quick Mode - 任务管理**: `/api/v1/tasks`
- **Advanced Mode - 实验管理**: `/api/v1/experiments`
- **文件管理**: `/api/v1/files`
- **阶段模板**: `/api/v1/stages`
"""

app = FastAPI(
    title="GPT-SoVITS Training API",
    description=_API_DESCRIPTION,
    version="1.0.0",
    lifespan=lifespan,
    docs_url="/docs",
    redoc_url="/redoc",
    openapi_url="/openapi.json",
)

# CORS configuration.
# NOTE(review): a wildcard origin combined with allow_credentials=True lets
# any site issue credentialed requests; restrict the origins before
# deploying to production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Mount every v1 endpoint under the configured API prefix.
app.include_router(api_router, prefix=settings.API_V1_PREFIX)
107
+
108
+
109
+ # ============================================================
110
+ # 根路由和健康检查
111
+ # ============================================================
112
+
113
@app.get("/", tags=["Root"])
async def root():
    """
    Root endpoint: return basic information about the API.
    """
    info = {"name": "GPT-SoVITS Training API", "version": "1.0.0"}
    info["mode"] = settings.DEPLOYMENT_MODE
    info["docs"] = "/docs"
    info["health"] = "/health"
    return info
127
+
128
+
129
@app.get("/health", tags=["Health"])
async def health_check():
    """
    Health-check endpoint for orchestrators and load balancers.

    Always reports ``healthy`` together with the current deployment mode.
    """
    mode = settings.DEPLOYMENT_MODE
    return {"status": "healthy", "mode": mode}
140
+
141
+
142
+ # ============================================================
143
+ # 开发模式直接运行
144
+ # ============================================================
145
+
146
if __name__ == "__main__":
    import uvicorn

    # Development entry point: serve with auto-reload, watching the API
    # server source tree.
    watched_dirs = [str(settings.API_SERVER_ROOT)]
    uvicorn.run(
        "api_server.app.main:app",
        host=settings.API_HOST,
        port=settings.API_PORT,
        reload=True,
        reload_dirs=watched_dirs,
    )
api_server/app/models/__init__.py CHANGED
@@ -6,4 +6,75 @@
6
 
7
  from .domain import Task, TaskStatus, ProgressInfo
8
 
9
- __all__ = ["Task", "TaskStatus", "ProgressInfo"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  from .domain import Task, TaskStatus, ProgressInfo
8
 
9
+ # Pydantic Schemas
10
+ from .schemas import (
11
+ # Common
12
+ SuccessResponse,
13
+ ErrorResponse,
14
+ PaginatedResponse,
15
+ # Task (Quick Mode)
16
+ QuickModeOptions,
17
+ QuickModeRequest,
18
+ TaskResponse,
19
+ TaskListResponse,
20
+ # Experiment (Advanced Mode)
21
+ StageType,
22
+ ExperimentCreate,
23
+ ExperimentUpdate,
24
+ StageStatus,
25
+ ExperimentResponse,
26
+ ExperimentListResponse,
27
+ StageExecuteRequest,
28
+ AudioSliceParams,
29
+ ASRParams,
30
+ TextFeatureParams,
31
+ HubertFeatureParams,
32
+ SemanticTokenParams,
33
+ SoVITSTrainParams,
34
+ GPTTrainParams,
35
+ StageExecuteResponse,
36
+ StagesListResponse,
37
+ # File
38
+ FileUploadResponse,
39
+ FileMetadata,
40
+ FileListResponse,
41
+ FileDeleteResponse,
42
+ )
43
+
44
+ __all__ = [
45
+ # Domain models
46
+ "Task",
47
+ "TaskStatus",
48
+ "ProgressInfo",
49
+ # Common schemas
50
+ "SuccessResponse",
51
+ "ErrorResponse",
52
+ "PaginatedResponse",
53
+ # Task schemas (Quick Mode)
54
+ "QuickModeOptions",
55
+ "QuickModeRequest",
56
+ "TaskResponse",
57
+ "TaskListResponse",
58
+ # Experiment schemas (Advanced Mode)
59
+ "StageType",
60
+ "ExperimentCreate",
61
+ "ExperimentUpdate",
62
+ "StageStatus",
63
+ "ExperimentResponse",
64
+ "ExperimentListResponse",
65
+ "StageExecuteRequest",
66
+ "AudioSliceParams",
67
+ "ASRParams",
68
+ "TextFeatureParams",
69
+ "HubertFeatureParams",
70
+ "SemanticTokenParams",
71
+ "SoVITSTrainParams",
72
+ "GPTTrainParams",
73
+ "StageExecuteResponse",
74
+ "StagesListResponse",
75
+ # File schemas
76
+ "FileUploadResponse",
77
+ "FileMetadata",
78
+ "FileListResponse",
79
+ "FileDeleteResponse",
80
+ ]
api_server/app/models/schemas/__init__.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Pydantic Schema 模块
3
+
4
+ 包含 API 请求/响应的数据验证模型
5
+
6
+ - common: 通用响应模型
7
+ - task: Quick Mode 任务模型
8
+ - experiment: Advanced Mode 实验/阶段模型
9
+ - file: 文件管理模型
10
+ """
11
+
12
+ from .common import (
13
+ SuccessResponse,
14
+ ErrorResponse,
15
+ PaginatedResponse,
16
+ )
17
+ from .task import (
18
+ QuickModeOptions,
19
+ QuickModeRequest,
20
+ TaskResponse,
21
+ TaskListResponse,
22
+ )
23
+ from .experiment import (
24
+ StageType,
25
+ ExperimentCreate,
26
+ ExperimentUpdate,
27
+ StageStatus,
28
+ ExperimentResponse,
29
+ ExperimentListResponse,
30
+ StageExecuteRequest,
31
+ AudioSliceParams,
32
+ ASRParams,
33
+ TextFeatureParams,
34
+ HubertFeatureParams,
35
+ SemanticTokenParams,
36
+ SoVITSTrainParams,
37
+ GPTTrainParams,
38
+ StageExecuteResponse,
39
+ StagesListResponse,
40
+ )
41
+ from .file import (
42
+ FileUploadResponse,
43
+ FileMetadata,
44
+ FileListResponse,
45
+ FileDeleteResponse,
46
+ )
47
+
48
+ __all__ = [
49
+ # Common
50
+ "SuccessResponse",
51
+ "ErrorResponse",
52
+ "PaginatedResponse",
53
+ # Task (Quick Mode)
54
+ "QuickModeOptions",
55
+ "QuickModeRequest",
56
+ "TaskResponse",
57
+ "TaskListResponse",
58
+ # Experiment (Advanced Mode)
59
+ "StageType",
60
+ "ExperimentCreate",
61
+ "ExperimentUpdate",
62
+ "StageStatus",
63
+ "ExperimentResponse",
64
+ "ExperimentListResponse",
65
+ "StageExecuteRequest",
66
+ "AudioSliceParams",
67
+ "ASRParams",
68
+ "TextFeatureParams",
69
+ "HubertFeatureParams",
70
+ "SemanticTokenParams",
71
+ "SoVITSTrainParams",
72
+ "GPTTrainParams",
73
+ "StageExecuteResponse",
74
+ "StagesListResponse",
75
+ # File
76
+ "FileUploadResponse",
77
+ "FileMetadata",
78
+ "FileListResponse",
79
+ "FileDeleteResponse",
80
+ ]
api_server/app/models/schemas/common.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 通用响应模型
3
+
4
+ 定义 API 通用的响应结构
5
+ """
6
+
7
+ from typing import Any, Generic, List, Optional, TypeVar
8
+ from pydantic import BaseModel, Field
9
+
10
+ # 泛型类型变量,用于分页响应
11
+ T = TypeVar("T")
12
+
13
+
14
+ class SuccessResponse(BaseModel):
15
+ """
16
+ 通用成功响应
17
+
18
+ Example:
19
+ >>> response = SuccessResponse(message="操作成功")
20
+ >>> response.model_dump()
21
+ {'success': True, 'message': '操作成功'}
22
+ """
23
+ success: bool = Field(default=True, description="是否成功")
24
+ message: str = Field(default="操作成功", description="响应消息")
25
+
26
+ model_config = {
27
+ "json_schema_extra": {
28
+ "examples": [
29
+ {"success": True, "message": "操作成功"}
30
+ ]
31
+ }
32
+ }
33
+
34
+
35
+ class ErrorResponse(BaseModel):
36
+ """
37
+ 错误响应
38
+
39
+ Example:
40
+ >>> response = ErrorResponse(message="任务不存在", code="TASK_NOT_FOUND")
41
+ >>> response.model_dump()
42
+ {'success': False, 'message': '任务不存在', 'code': 'TASK_NOT_FOUND', 'details': None}
43
+ """
44
+ success: bool = Field(default=False, description="是否成功")
45
+ message: str = Field(..., description="错误消息")
46
+ code: Optional[str] = Field(default=None, description="错误代码")
47
+ details: Optional[Any] = Field(default=None, description="错误详情")
48
+
49
+ model_config = {
50
+ "json_schema_extra": {
51
+ "examples": [
52
+ {
53
+ "success": False,
54
+ "message": "任务不存在",
55
+ "code": "TASK_NOT_FOUND",
56
+ "details": None
57
+ }
58
+ ]
59
+ }
60
+ }
61
+
62
+
63
+ class PaginatedResponse(BaseModel, Generic[T]):
64
+ """
65
+ 分页响应基类
66
+
67
+ 泛型参数 T 表示列表项的类型
68
+
69
+ Example:
70
+ >>> from typing import List
71
+ >>> class TaskListResponse(PaginatedResponse[TaskResponse]):
72
+ ... pass
73
+ """
74
+ items: List[T] = Field(default_factory=list, description="数据列表")
75
+ total: int = Field(default=0, ge=0, description="总数量")
76
+ limit: int = Field(default=50, ge=1, le=100, description="每页数量")
77
+ offset: int = Field(default=0, ge=0, description="偏移量")
78
+
79
+ @property
80
+ def has_more(self) -> bool:
81
+ """是否有更多数据"""
82
+ return self.offset + len(self.items) < self.total
83
+
84
+ model_config = {
85
+ "json_schema_extra": {
86
+ "examples": [
87
+ {
88
+ "items": [],
89
+ "total": 0,
90
+ "limit": 50,
91
+ "offset": 0
92
+ }
93
+ ]
94
+ }
95
+ }
api_server/app/models/schemas/experiment.py ADDED
@@ -0,0 +1,556 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Advanced Mode 实验/阶段 Schema
3
+
4
+ 专家用户分阶段训练模式的请求/响应模型
5
+
6
+ 参考文档: development.md 4.6.2
7
+ """
8
+
9
+ from datetime import datetime
10
+ from enum import Enum
11
+ from typing import Any, Dict, List, Literal, Optional
12
+ from pydantic import BaseModel, Field
13
+
14
+
15
+ # ============================================================
16
+ # 枚举类型
17
+ # ============================================================
18
+
19
+ class StageType(str, Enum):
20
+ """
21
+ 训练阶段类型枚举
22
+
23
+ 定义了完整训练流程中的所有阶段
24
+ """
25
+ AUDIO_SLICE = "audio_slice" # 音频切片
26
+ ASR = "asr" # 语音识别
27
+ TEXT_FEATURE = "text_feature" # 文本特征提取
28
+ HUBERT_FEATURE = "hubert_feature" # HuBERT 特征提取
29
+ SEMANTIC_TOKEN = "semantic_token" # 语义 Token 提取
30
+ SOVITS_TRAIN = "sovits_train" # SoVITS 训练
31
+ GPT_TRAIN = "gpt_train" # GPT 训练
32
+
33
+
34
+ # 阶段依赖关系
35
+ STAGE_DEPENDENCIES: Dict[str, List[str]] = {
36
+ "audio_slice": [],
37
+ "asr": ["audio_slice"],
38
+ "text_feature": ["asr"],
39
+ "hubert_feature": ["audio_slice"],
40
+ "semantic_token": ["hubert_feature"],
41
+ "sovits_train": ["text_feature", "semantic_token"],
42
+ "gpt_train": ["text_feature", "semantic_token"],
43
+ }
44
+
45
+
46
+ # ============================================================
47
+ # 实验管理
48
+ # ============================================================
49
+
50
+ class ExperimentCreate(BaseModel):
51
+ """
52
+ 创建实验请求
53
+
54
+ 创建实验但不立即执行,用户可以逐阶段控制训练流程
55
+
56
+ Attributes:
57
+ exp_name: 实验名称
58
+ version: 模型版本
59
+ gpu_numbers: GPU 编号
60
+ is_half: 是否使用半精度
61
+ audio_file_id: 音频文件 ID
62
+ """
63
+ exp_name: str = Field(
64
+ ...,
65
+ min_length=1,
66
+ max_length=100,
67
+ description="实验名称"
68
+ )
69
+ version: Literal["v1", "v2", "v2Pro", "v3", "v4"] = Field(
70
+ default="v2",
71
+ description="模型版本"
72
+ )
73
+ gpu_numbers: str = Field(
74
+ default="0",
75
+ description="GPU 编号,多个 GPU 用逗号分隔,如 '0,1'"
76
+ )
77
+ is_half: bool = Field(
78
+ default=True,
79
+ description="是否使用半精度(FP16),可节省显存"
80
+ )
81
+ audio_file_id: str = Field(
82
+ ...,
83
+ description="已上传音频文件的 ID"
84
+ )
85
+
86
+ model_config = {
87
+ "json_schema_extra": {
88
+ "examples": [
89
+ {
90
+ "exp_name": "my_voice_custom",
91
+ "version": "v2",
92
+ "gpu_numbers": "0",
93
+ "is_half": True,
94
+ "audio_file_id": "550e8400-e29b-41d4-a716-446655440000"
95
+ }
96
+ ]
97
+ }
98
+ }
99
+
100
+
101
+ class ExperimentUpdate(BaseModel):
102
+ """
103
+ 更新实验请求
104
+
105
+ 用于更新实验的基础配置(非阶段参数)
106
+ """
107
+ exp_name: Optional[str] = Field(
108
+ default=None,
109
+ min_length=1,
110
+ max_length=100,
111
+ description="实验名称"
112
+ )
113
+ gpu_numbers: Optional[str] = Field(
114
+ default=None,
115
+ description="GPU 编号"
116
+ )
117
+ is_half: Optional[bool] = Field(
118
+ default=None,
119
+ description="是否使用半精度"
120
+ )
121
+
122
+
123
+ class StageStatus(BaseModel):
124
+ """
125
+ 阶段状态
126
+
127
+ 描述单个阶段的执行状态和结果
128
+ """
129
+ stage_type: str = Field(..., description="阶段类型")
130
+ status: Literal["pending", "running", "completed", "failed", "cancelled"] = Field(
131
+ default="pending",
132
+ description="阶段状态"
133
+ )
134
+ progress: Optional[float] = Field(
135
+ default=None,
136
+ ge=0.0,
137
+ le=1.0,
138
+ description="阶段进度 (0.0-1.0)"
139
+ )
140
+ message: Optional[str] = Field(
141
+ default=None,
142
+ description="状态消息"
143
+ )
144
+ started_at: Optional[datetime] = Field(
145
+ default=None,
146
+ description="开始时间"
147
+ )
148
+ completed_at: Optional[datetime] = Field(
149
+ default=None,
150
+ description="完成时间"
151
+ )
152
+ config: Optional[Dict[str, Any]] = Field(
153
+ default=None,
154
+ description="阶段配置参数"
155
+ )
156
+ outputs: Optional[Dict[str, Any]] = Field(
157
+ default=None,
158
+ description="阶段输出结果"
159
+ )
160
+ error_message: Optional[str] = Field(
161
+ default=None,
162
+ description="错误消息(失败时)"
163
+ )
164
+
165
+ model_config = {
166
+ "json_schema_extra": {
167
+ "examples": [
168
+ {
169
+ "stage_type": "sovits_train",
170
+ "status": "completed",
171
+ "progress": 1.0,
172
+ "message": "训练完成",
173
+ "started_at": "2024-01-01T10:30:00Z",
174
+ "completed_at": "2024-01-01T11:00:00Z",
175
+ "config": {"batch_size": 8, "total_epoch": 16},
176
+ "outputs": {
177
+ "model_path": "logs/my_voice/sovits_e16.pth",
178
+ "metrics": {"final_loss": 0.023}
179
+ }
180
+ }
181
+ ]
182
+ }
183
+ }
184
+
185
+
186
+ class ExperimentResponse(BaseModel):
187
+ """
188
+ 实验响应
189
+
190
+ 包含实验的完整信息和所有阶段状态
191
+ """
192
+ id: str = Field(..., description="实验唯一标识")
193
+ exp_name: str = Field(..., description="实验名称")
194
+ version: str = Field(..., description="模型版本")
195
+ status: str = Field(..., description="实验状态")
196
+ gpu_numbers: str = Field(default="0", description="GPU 编号")
197
+ is_half: bool = Field(default=True, description="是否使用半精度")
198
+ audio_file_id: str = Field(..., description="音频文件 ID")
199
+ stages: Dict[str, StageStatus] = Field(
200
+ default_factory=dict,
201
+ description="各阶段状态"
202
+ )
203
+ created_at: datetime = Field(..., description="创建时间")
204
+ updated_at: Optional[datetime] = Field(default=None, description="更新时间")
205
+
206
+ model_config = {
207
+ "json_schema_extra": {
208
+ "examples": [
209
+ {
210
+ "id": "exp-abc123",
211
+ "exp_name": "my_voice_custom",
212
+ "version": "v2",
213
+ "status": "created",
214
+ "gpu_numbers": "0",
215
+ "is_half": True,
216
+ "audio_file_id": "550e8400-e29b-41d4-a716-446655440000",
217
+ "stages": {
218
+ "audio_slice": {"stage_type": "audio_slice", "status": "pending"},
219
+ "asr": {"stage_type": "asr", "status": "pending"},
220
+ "sovits_train": {"stage_type": "sovits_train", "status": "pending"}
221
+ },
222
+ "created_at": "2024-01-01T10:00:00Z"
223
+ }
224
+ ]
225
+ }
226
+ }
227
+
228
+
229
+ class ExperimentListResponse(BaseModel):
230
+ """
231
+ 实验列表响应
232
+ """
233
+ items: List[ExperimentResponse] = Field(
234
+ default_factory=list,
235
+ description="实验列表"
236
+ )
237
+ total: int = Field(default=0, ge=0, description="总数量")
238
+ limit: int = Field(default=50, ge=1, le=100, description="每页数量")
239
+ offset: int = Field(default=0, ge=0, description="偏移量")
240
+
241
+
242
+ # ============================================================
243
+ # 阶段执行参数
244
+ # ============================================================
245
+
246
+ class StageExecuteRequest(BaseModel):
247
+ """
248
+ 阶段执行请求基类
249
+
250
+ 允许传入任意额外参数
251
+ """
252
+ model_config = {
253
+ "extra": "allow" # 允许额外字段(阶段特定参数)
254
+ }
255
+
256
+
257
+ class AudioSliceParams(StageExecuteRequest):
258
+ """
259
+ 音频切片参数
260
+
261
+ 将长音频切分为短片段
262
+
263
+ 参考文档: development.md 4.5.2
264
+ """
265
+ threshold: int = Field(
266
+ default=-34,
267
+ ge=-60,
268
+ le=0,
269
+ description="静音检测阈值 (dB)"
270
+ )
271
+ min_length: int = Field(
272
+ default=4000,
273
+ ge=1000,
274
+ le=10000,
275
+ description="最小切片长度 (ms)"
276
+ )
277
+ min_interval: int = Field(
278
+ default=300,
279
+ ge=100,
280
+ le=1000,
281
+ description="最小静音间隔 (ms)"
282
+ )
283
+ hop_size: int = Field(
284
+ default=10,
285
+ ge=5,
286
+ le=50,
287
+ description="检测步长 (ms)"
288
+ )
289
+ max_sil_kept: int = Field(
290
+ default=500,
291
+ ge=100,
292
+ le=2000,
293
+ description="切片保留的最大静音长度 (ms)"
294
+ )
295
+
296
+ model_config = {
297
+ "json_schema_extra": {
298
+ "examples": [
299
+ {
300
+ "threshold": -34,
301
+ "min_length": 4000,
302
+ "min_interval": 300,
303
+ "hop_size": 10,
304
+ "max_sil_kept": 500
305
+ }
306
+ ]
307
+ }
308
+ }
309
+
310
+
311
+ class ASRParams(StageExecuteRequest):
312
+ """
313
+ ASR 语音识别参数
314
+ """
315
+ model: str = Field(
316
+ default="达摩 ASR (中文)",
317
+ description="ASR 模型名称"
318
+ )
319
+ language: str = Field(
320
+ default="zh",
321
+ description="识别语言"
322
+ )
323
+
324
+ model_config = {
325
+ "json_schema_extra": {
326
+ "examples": [
327
+ {"model": "达摩 ASR (中文)", "language": "zh"}
328
+ ]
329
+ }
330
+ }
331
+
332
+
333
+ class TextFeatureParams(StageExecuteRequest):
334
+ """
335
+ 文本特征提取参数
336
+ """
337
+ bert_pretrained_dir: Optional[str] = Field(
338
+ default=None,
339
+ description="BERT 预训练模型目录,为空使用默认"
340
+ )
341
+
342
+ model_config = {
343
+ "json_schema_extra": {
344
+ "examples": [
345
+ {"bert_pretrained_dir": None}
346
+ ]
347
+ }
348
+ }
349
+
350
+
351
+ class HubertFeatureParams(StageExecuteRequest):
352
+ """
353
+ HuBERT 特征提取参数
354
+ """
355
+ ssl_pretrained_dir: Optional[str] = Field(
356
+ default=None,
357
+ description="SSL 预训练模型目录,为空使用默认"
358
+ )
359
+
360
+ model_config = {
361
+ "json_schema_extra": {
362
+ "examples": [
363
+ {"ssl_pretrained_dir": None}
364
+ ]
365
+ }
366
+ }
367
+
368
+
369
+ class SemanticTokenParams(StageExecuteRequest):
370
+ """
371
+ 语义 Token 提取参数
372
+ """
373
+ # 当前阶段无特殊参数,保留扩展性
374
+ pass
375
+
376
+
377
+ class SoVITSTrainParams(StageExecuteRequest):
378
+ """
379
+ SoVITS 训练参数
380
+
381
+ 参考文档: development.md 4.5.2
382
+ """
383
+ batch_size: int = Field(
384
+ default=4,
385
+ ge=1,
386
+ le=32,
387
+ description="批次大小,显存不足时减小"
388
+ )
389
+ total_epoch: int = Field(
390
+ default=8,
391
+ ge=1,
392
+ le=100,
393
+ description="训练总轮数"
394
+ )
395
+ save_every_epoch: int = Field(
396
+ default=4,
397
+ ge=1,
398
+ description="每 N 轮保存一次模型"
399
+ )
400
+ pretrained_s2G: Optional[str] = Field(
401
+ default=None,
402
+ description="预训练生成器模型路径,为空使用默认"
403
+ )
404
+ pretrained_s2D: Optional[str] = Field(
405
+ default=None,
406
+ description="预训练判别器模型路径,为空使用默认"
407
+ )
408
+
409
+ model_config = {
410
+ "json_schema_extra": {
411
+ "examples": [
412
+ {
413
+ "batch_size": 8,
414
+ "total_epoch": 16,
415
+ "save_every_epoch": 4,
416
+ "pretrained_s2G": None,
417
+ "pretrained_s2D": None
418
+ }
419
+ ]
420
+ }
421
+ }
422
+
423
+
424
+ class GPTTrainParams(StageExecuteRequest):
425
+ """
426
+ GPT 训练参数
427
+ """
428
+ batch_size: int = Field(
429
+ default=4,
430
+ ge=1,
431
+ le=32,
432
+ description="批次大小"
433
+ )
434
+ total_epoch: int = Field(
435
+ default=15,
436
+ ge=1,
437
+ le=100,
438
+ description="训练总轮数"
439
+ )
440
+ save_every_epoch: int = Field(
441
+ default=5,
442
+ ge=1,
443
+ description="每 N 轮保存一次模型"
444
+ )
445
+ pretrained_s1: Optional[str] = Field(
446
+ default=None,
447
+ description="预训练模型路径,为空使用默认"
448
+ )
449
+
450
+ model_config = {
451
+ "json_schema_extra": {
452
+ "examples": [
453
+ {
454
+ "batch_size": 4,
455
+ "total_epoch": 15,
456
+ "save_every_epoch": 5,
457
+ "pretrained_s1": None
458
+ }
459
+ ]
460
+ }
461
+ }
462
+
463
+
464
+ class StageExecuteResponse(BaseModel):
465
+ """
466
+ 阶段执行响应
467
+ """
468
+ exp_id: str = Field(..., description="实验 ID")
469
+ stage_type: str = Field(..., description="阶段类型")
470
+ status: Literal["running", "queued"] = Field(..., description="执行状态")
471
+ job_id: str = Field(..., description="作业 ID")
472
+ config: Dict[str, Any] = Field(
473
+ default_factory=dict,
474
+ description="阶段配置"
475
+ )
476
+ rerun: bool = Field(
477
+ default=False,
478
+ description="是否为重新执行"
479
+ )
480
+ previous_run: Optional[Dict[str, Any]] = Field(
481
+ default=None,
482
+ description="上次执行的信息(重新执行时)"
483
+ )
484
+ started_at: datetime = Field(..., description="开始时间")
485
+
486
+ model_config = {
487
+ "json_schema_extra": {
488
+ "examples": [
489
+ {
490
+ "exp_id": "exp-abc123",
491
+ "stage_type": "sovits_train",
492
+ "status": "running",
493
+ "job_id": "job-xyz789",
494
+ "config": {"batch_size": 8, "total_epoch": 16},
495
+ "rerun": False,
496
+ "started_at": "2024-01-01T10:30:00Z"
497
+ }
498
+ ]
499
+ }
500
+ }
501
+
502
+
503
+ class StagesListResponse(BaseModel):
504
+ """
505
+ 所有阶段状态响应
506
+ """
507
+ exp_id: str = Field(..., description="实验 ID")
508
+ stages: List[StageStatus] = Field(
509
+ default_factory=list,
510
+ description="阶段状态列表"
511
+ )
512
+
513
+ model_config = {
514
+ "json_schema_extra": {
515
+ "examples": [
516
+ {
517
+ "exp_id": "exp-abc123",
518
+ "stages": [
519
+ {"stage_type": "audio_slice", "status": "completed"},
520
+ {"stage_type": "asr", "status": "completed"},
521
+ {"stage_type": "sovits_train", "status": "running", "progress": 0.45}
522
+ ]
523
+ }
524
+ ]
525
+ }
526
+ }
527
+
528
+
529
+ # 阶段参数类型映射
530
+ STAGE_PARAMS_MAP: Dict[str, type] = {
531
+ "audio_slice": AudioSliceParams,
532
+ "asr": ASRParams,
533
+ "text_feature": TextFeatureParams,
534
+ "hubert_feature": HubertFeatureParams,
535
+ "semantic_token": SemanticTokenParams,
536
+ "sovits_train": SoVITSTrainParams,
537
+ "gpt_train": GPTTrainParams,
538
+ }
539
+
540
+
541
+ def get_stage_params_class(stage_type: str) -> type:
542
+ """
543
+ 获取阶段对应的参数类
544
+
545
+ Args:
546
+ stage_type: 阶段类型
547
+
548
+ Returns:
549
+ 对应的参数 Pydantic 类
550
+
551
+ Raises:
552
+ ValueError: 无效的阶段类型
553
+ """
554
+ if stage_type not in STAGE_PARAMS_MAP:
555
+ raise ValueError(f"Invalid stage type: {stage_type}")
556
+ return STAGE_PARAMS_MAP[stage_type]
api_server/app/models/schemas/file.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 文件管理 Schema
3
+
4
+ 文件上传/下载相关的请求/响应模型
5
+ """
6
+
7
+ from datetime import datetime
8
+ from typing import List, Literal, Optional
9
+ from pydantic import BaseModel, Field
10
+
11
+
12
+ class FileMetadata(BaseModel):
13
+ """
14
+ 文件元数据
15
+
16
+ 描述已上传文件的详细信息
17
+ """
18
+ id: str = Field(..., description="文件唯一标识")
19
+ filename: str = Field(..., description="原始文件名")
20
+ content_type: Optional[str] = Field(
21
+ default=None,
22
+ description="MIME 类型,如 'audio/wav', 'audio/mpeg'"
23
+ )
24
+ size_bytes: int = Field(
25
+ default=0,
26
+ ge=0,
27
+ description="文件大小(字节)"
28
+ )
29
+ purpose: Optional[Literal["training", "reference", "output"]] = Field(
30
+ default="training",
31
+ description="文件用途:training(训练), reference(参考音频), output(输出模型)"
32
+ )
33
+ duration_seconds: Optional[float] = Field(
34
+ default=None,
35
+ ge=0,
36
+ description="音频时长(秒),仅音频文件有效"
37
+ )
38
+ sample_rate: Optional[int] = Field(
39
+ default=None,
40
+ ge=0,
41
+ description="采样率(Hz),仅音频文件有效"
42
+ )
43
+ uploaded_at: datetime = Field(..., description="上传时间")
44
+
45
+ model_config = {
46
+ "json_schema_extra": {
47
+ "examples": [
48
+ {
49
+ "id": "550e8400-e29b-41d4-a716-446655440000",
50
+ "filename": "my_voice.wav",
51
+ "content_type": "audio/wav",
52
+ "size_bytes": 15728640,
53
+ "purpose": "training",
54
+ "duration_seconds": 120.5,
55
+ "sample_rate": 44100,
56
+ "uploaded_at": "2024-01-01T10:00:00Z"
57
+ }
58
+ ]
59
+ }
60
+ }
61
+
62
+
63
+ class FileUploadResponse(BaseModel):
64
+ """
65
+ 文件上传响应
66
+
67
+ 上传成功后返回文件信息
68
+ """
69
+ success: bool = Field(default=True, description="是否成功")
70
+ message: str = Field(default="文件上传成功", description="响应消息")
71
+ file: FileMetadata = Field(..., description="文件元数据")
72
+
73
+ model_config = {
74
+ "json_schema_extra": {
75
+ "examples": [
76
+ {
77
+ "success": True,
78
+ "message": "文件上传成功",
79
+ "file": {
80
+ "id": "550e8400-e29b-41d4-a716-446655440000",
81
+ "filename": "my_voice.wav",
82
+ "content_type": "audio/wav",
83
+ "size_bytes": 15728640,
84
+ "purpose": "training",
85
+ "uploaded_at": "2024-01-01T10:00:00Z"
86
+ }
87
+ }
88
+ ]
89
+ }
90
+ }
91
+
92
+
93
+ class FileListResponse(BaseModel):
94
+ """
95
+ 文件列表响应
96
+ """
97
+ items: List[FileMetadata] = Field(
98
+ default_factory=list,
99
+ description="文件列表"
100
+ )
101
+ total: int = Field(
102
+ default=0,
103
+ ge=0,
104
+ description="总数量"
105
+ )
106
+ limit: int = Field(
107
+ default=50,
108
+ ge=1,
109
+ le=100,
110
+ description="每页数量"
111
+ )
112
+ offset: int = Field(
113
+ default=0,
114
+ ge=0,
115
+ description="偏移量"
116
+ )
117
+
118
+ model_config = {
119
+ "json_schema_extra": {
120
+ "examples": [
121
+ {
122
+ "items": [
123
+ {
124
+ "id": "file-123",
125
+ "filename": "voice_1.wav",
126
+ "content_type": "audio/wav",
127
+ "size_bytes": 5242880,
128
+ "purpose": "training",
129
+ "uploaded_at": "2024-01-01T10:00:00Z"
130
+ }
131
+ ],
132
+ "total": 1,
133
+ "limit": 50,
134
+ "offset": 0
135
+ }
136
+ ]
137
+ }
138
+ }
139
+
140
+
141
+ class FileDeleteResponse(BaseModel):
142
+ """
143
+ 文件删除响应
144
+ """
145
+ success: bool = Field(default=True, description="是否成功")
146
+ message: str = Field(default="文件删除成功", description="响应消息")
147
+ file_id: str = Field(..., description="已删除的文件 ID")
148
+
149
+ model_config = {
150
+ "json_schema_extra": {
151
+ "examples": [
152
+ {
153
+ "success": True,
154
+ "message": "文件删除成功",
155
+ "file_id": "550e8400-e29b-41d4-a716-446655440000"
156
+ }
157
+ ]
158
+ }
159
+ }
api_server/app/models/schemas/task.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Quick Mode 任务 Schema
3
+
4
+ 小白用户一键训练模式的请求/响应模型
5
+
6
+ 参考文档: development.md 4.6.1 + 4.6.3
7
+ """
8
+
9
+ from datetime import datetime
10
+ from typing import List, Literal, Optional
11
+ from pydantic import BaseModel, Field
12
+
13
+
14
+ class QuickModeOptions(BaseModel):
15
+ """
16
+ Quick Mode 训练选项
17
+
18
+ 用于一键训练时的简化参数配置
19
+
20
+ Attributes:
21
+ version: 模型版本
22
+ language: 训练语言
23
+ quality: 训练质量预设
24
+
25
+ 质量预设说明:
26
+ - fast: SoVITS 4 epochs, GPT 8 epochs, ~10分钟
27
+ - standard: SoVITS 8 epochs, GPT 15 epochs, ~20分钟
28
+ - high: SoVITS 16 epochs, GPT 30 epochs, ~40分钟
29
+ """
30
+ version: Literal["v1", "v2", "v2Pro", "v3", "v4"] = Field(
31
+ default="v2",
32
+ description="模型版本"
33
+ )
34
+ language: str = Field(
35
+ default="zh",
36
+ description="训练语言,如 'zh', 'en', 'ja' 等"
37
+ )
38
+ quality: Literal["fast", "standard", "high"] = Field(
39
+ default="standard",
40
+ description="训练质量预设:fast(快速)、standard(标准)、high(高质量)"
41
+ )
42
+
43
+ model_config = {
44
+ "json_schema_extra": {
45
+ "examples": [
46
+ {"version": "v2", "language": "zh", "quality": "standard"}
47
+ ]
48
+ }
49
+ }
50
+
51
+
52
+ class QuickModeRequest(BaseModel):
53
+ """
54
+ 小白用户一键训练请求
55
+
56
+ 创建一键训练任务,系统自动配置所有参数并执行完整流程:
57
+ audio_slice -> asr -> text_feature -> hubert_feature -> semantic_token -> sovits_train -> gpt_train
58
+
59
+ Attributes:
60
+ exp_name: 实验名称(用于标识训练任务)
61
+ audio_file_id: 已上传音频文件的 ID
62
+ options: 训练选项
63
+ """
64
+ exp_name: str = Field(
65
+ ...,
66
+ min_length=1,
67
+ max_length=100,
68
+ description="实验名称,用于标识训练任务和生成的模型"
69
+ )
70
+ audio_file_id: str = Field(
71
+ ...,
72
+ description="已上传音频文件的 ID"
73
+ )
74
+ options: QuickModeOptions = Field(
75
+ default_factory=QuickModeOptions,
76
+ description="训练选项"
77
+ )
78
+
79
+ model_config = {
80
+ "json_schema_extra": {
81
+ "examples": [
82
+ {
83
+ "exp_name": "my_voice",
84
+ "audio_file_id": "550e8400-e29b-41d4-a716-446655440000",
85
+ "options": {
86
+ "version": "v2",
87
+ "language": "zh",
88
+ "quality": "standard"
89
+ }
90
+ }
91
+ ]
92
+ }
93
+ }
94
+
95
+
96
+ class TaskResponse(BaseModel):
97
+ """
98
+ 任务响应(Quick Mode)
99
+
100
+ 返回任务的完整状态信息,包括进度、当前阶段等
101
+
102
+ Attributes:
103
+ id: 任务唯一标识
104
+ exp_name: 实验名称
105
+ status: 任务状态
106
+ current_stage: 当前执行的阶段
107
+ progress: 当前阶段进度 (0.0-1.0)
108
+ overall_progress: 总体进度 (0.0-1.0)
109
+ message: 最新状态消息
110
+ error_message: 错误消息(失败时)
111
+ created_at: 任务创建时间
112
+ started_at: 任务开始执行时间
113
+ completed_at: 任务完成时间
114
+ """
115
+ id: str = Field(..., description="任务唯一标识")
116
+ exp_name: str = Field(..., description="实验名称")
117
+ status: Literal["queued", "running", "completed", "failed", "cancelled", "interrupted"] = Field(
118
+ ...,
119
+ description="任务状态"
120
+ )
121
+ current_stage: Optional[str] = Field(
122
+ default=None,
123
+ description="当前执行的阶段,如 'audio_slice', 'sovits_train' 等"
124
+ )
125
+ progress: float = Field(
126
+ default=0.0,
127
+ ge=0.0,
128
+ le=1.0,
129
+ description="当前阶段进度 (0.0-1.0)"
130
+ )
131
+ overall_progress: float = Field(
132
+ default=0.0,
133
+ ge=0.0,
134
+ le=1.0,
135
+ description="总体进度 (0.0-1.0)"
136
+ )
137
+ message: Optional[str] = Field(
138
+ default=None,
139
+ description="最新状态消息"
140
+ )
141
+ error_message: Optional[str] = Field(
142
+ default=None,
143
+ description="错误消息(失败时)"
144
+ )
145
+ created_at: Optional[datetime] = Field(
146
+ default=None,
147
+ description="任务创建时间"
148
+ )
149
+ started_at: Optional[datetime] = Field(
150
+ default=None,
151
+ description="任务开始执行时间"
152
+ )
153
+ completed_at: Optional[datetime] = Field(
154
+ default=None,
155
+ description="任务完成时间"
156
+ )
157
+
158
+ model_config = {
159
+ "from_attributes": True,
160
+ "json_schema_extra": {
161
+ "examples": [
162
+ {
163
+ "id": "task-550e8400-e29b-41d4-a716-446655440000",
164
+ "exp_name": "my_voice",
165
+ "status": "running",
166
+ "current_stage": "sovits_train",
167
+ "progress": 0.45,
168
+ "overall_progress": 0.72,
169
+ "message": "SoVITS 训练中 Epoch 8/16",
170
+ "error_message": None,
171
+ "created_at": "2024-01-01T10:00:00Z",
172
+ "started_at": "2024-01-01T10:00:05Z",
173
+ "completed_at": None
174
+ }
175
+ ]
176
+ }
177
+ }
178
+
179
+
180
+ class TaskListResponse(BaseModel):
181
+ """
182
+ 任务列表响应
183
+
184
+ Attributes:
185
+ items: 任务列表
186
+ total: 总数量
187
+ limit: 每页数量
188
+ offset: 偏移量
189
+ """
190
+ items: List[TaskResponse] = Field(
191
+ default_factory=list,
192
+ description="任务列表"
193
+ )
194
+ total: int = Field(
195
+ default=0,
196
+ ge=0,
197
+ description="总数量"
198
+ )
199
+ limit: int = Field(
200
+ default=50,
201
+ ge=1,
202
+ le=100,
203
+ description="每页数量"
204
+ )
205
+ offset: int = Field(
206
+ default=0,
207
+ ge=0,
208
+ description="偏移量"
209
+ )
210
+
211
+ model_config = {
212
+ "json_schema_extra": {
213
+ "examples": [
214
+ {
215
+ "items": [
216
+ {
217
+ "id": "task-123",
218
+ "exp_name": "voice_1",
219
+ "status": "completed",
220
+ "current_stage": None,
221
+ "progress": 1.0,
222
+ "overall_progress": 1.0,
223
+ "message": "训练完成"
224
+ }
225
+ ],
226
+ "total": 1,
227
+ "limit": 50,
228
+ "offset": 0
229
+ }
230
+ ]
231
+ }
232
+ }
api_server/app/scripts/run_pipeline.py CHANGED
@@ -238,9 +238,23 @@ def build_pipeline(config: Dict[str, Any]):
238
  }
239
 
240
  # 按顺序添加阶段
 
 
 
241
  stages = config.get("stages", [])
242
- for stage_config in stages:
243
- stage_type = stage_config.get("type")
 
 
 
 
 
 
 
 
 
 
 
244
  if stage_type in stage_builders:
245
  stage = stage_builders[stage_type](stage_config)
246
  pipeline.add_stage(stage)
 
238
  }
239
 
240
  # 按顺序添加阶段
241
+ # stages 可以是:
242
+ # 1. 字符串列表: ["audio_slice", "asr", ...]
243
+ # 2. 字典列表: [{"type": "audio_slice", "threshold": -30}, ...]
244
  stages = config.get("stages", [])
245
+ for stage_item in stages:
246
+ # 判断是字符串还是字典
247
+ if isinstance(stage_item, str):
248
+ stage_type = stage_item
249
+ stage_config = config # 使用全局配置作为阶段配置
250
+ elif isinstance(stage_item, dict):
251
+ stage_type = stage_item.get("type")
252
+ # 合并全局配置和阶段特定配置
253
+ stage_config = {**config, **stage_item}
254
+ else:
255
+ emit_log("warning", f"无效的阶段配置类型: {type(stage_item)}")
256
+ continue
257
+
258
  if stage_type in stage_builders:
259
  stage = stage_builders[stage_type](stage_config)
260
  pipeline.add_stage(stage)
api_server/app/services/__init__.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 服务层模块
3
+
4
+ 业务逻辑层,封装适配器调用,提供高级业务操作。
5
+
6
+ 服务列表:
7
+ - TaskService: Quick Mode 任务服务
8
+ - ExperimentService: Advanced Mode 实验服务
9
+ - FileService: 文件管理服务
10
+ """
11
+
12
+ from .task_service import TaskService
13
+ from .experiment_service import ExperimentService
14
+ from .file_service import FileService
15
+
16
+ __all__ = [
17
+ "TaskService",
18
+ "ExperimentService",
19
+ "FileService",
20
+ ]
api_server/app/services/experiment_service.py ADDED
@@ -0,0 +1,513 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Advanced Mode 实验服务
3
+
4
+ 处理专家模式分阶段训练的业务逻辑
5
+ """
6
+
7
+ import uuid
8
+ from datetime import datetime
9
+ from typing import AsyncGenerator, Dict, List, Optional, Any
10
+
11
+ from ..core.adapters import (
12
+ get_database_adapter,
13
+ get_task_queue_adapter,
14
+ get_progress_adapter,
15
+ )
16
+ from ..models.schemas.experiment import (
17
+ ExperimentCreate,
18
+ ExperimentUpdate,
19
+ ExperimentResponse,
20
+ ExperimentListResponse,
21
+ StageStatus,
22
+ StageExecuteResponse,
23
+ StagesListResponse,
24
+ STAGE_DEPENDENCIES,
25
+ )
26
+
27
+
28
# Stage types, listed in execution order.
STAGE_TYPES = [
    "audio_slice",
    "asr",
    "text_feature",
    "hubert_feature",
    "semantic_token",
    "sovits_train",
    "gpt_train",
]
38
+
39
+
40
class ExperimentService:
    """Advanced Mode experiment service.

    Business logic for expert-mode, stage-by-stage training:

    - create experiments
    - query experiment and stage state
    - execute or cancel a single stage
    - check stage dependencies

    The database, task-queue and progress adapters are resolved lazily on
    first access and cached on the instance.

    Example:
        >>> service = ExperimentService()
        >>> exp = await service.create_experiment(request)
        >>> await service.execute_stage(exp.id, "audio_slice", {})
        >>> stages = await service.get_all_stages(exp.id)
    """

    # Stage states after which no further progress is expected.
    _TERMINAL_STATES = ("completed", "failed", "cancelled")

    def __init__(self):
        """Initialise the service; adapters are looked up on first use."""
        self._db = None
        self._queue = None
        self._progress = None

    @property
    def db(self):
        """Database adapter, fetched lazily and cached."""
        if self._db is None:
            self._db = get_database_adapter()
        return self._db

    @property
    def queue(self):
        """Task-queue adapter, fetched lazily and cached."""
        if self._queue is None:
            self._queue = get_task_queue_adapter()
        return self._queue

    @property
    def progress_adapter(self):
        """Progress adapter, fetched lazily and cached."""
        if self._progress is None:
            self._progress = get_progress_adapter()
        return self._progress

    @staticmethod
    def _coerce_datetime(value):
        """Parse ISO-8601 strings into ``datetime``; pass anything else through."""
        if isinstance(value, str):
            return datetime.fromisoformat(value)
        return value

    async def create_experiment(self, request: ExperimentCreate) -> ExperimentResponse:
        """Create an experiment without starting any stage.

        The experiment is only persisted; stages are run later, one at a
        time, through :meth:`execute_stage`.

        Args:
            request: Experiment creation request.

        Returns:
            The created experiment as an ``ExperimentResponse``.
        """
        experiment_data = {
            "id": f"exp-{uuid.uuid4().hex[:8]}",
            "exp_name": request.exp_name,
            "version": request.version,
            "gpu_numbers": request.gpu_numbers,
            "is_half": request.is_half,
            "audio_file_id": request.audio_file_id,
            "status": "created",
        }

        # The original author notes that the adapter also creates all stage
        # records as part of this call.
        experiment = await self.db.create_experiment(experiment_data)
        return self._experiment_to_response(experiment)

    async def get_experiment(self, exp_id: str) -> Optional[ExperimentResponse]:
        """Return the experiment with the given id, or ``None`` if not found.

        Args:
            exp_id: Experiment id.
        """
        experiment = await self.db.get_experiment(exp_id)
        if not experiment:
            return None
        return self._experiment_to_response(experiment)

    async def list_experiments(
        self,
        status: Optional[str] = None,
        limit: int = 50,
        offset: int = 0
    ) -> ExperimentListResponse:
        """Return one page of experiments.

        Args:
            status: Optional status filter.
            limit: Page size.
            offset: Page offset.

        Returns:
            ``ExperimentListResponse`` for the requested page.
        """
        summaries = await self.db.list_experiments(
            status=status, limit=limit, offset=offset
        )

        # Re-fetch every experiment individually to obtain the full record
        # (including its stages).
        detailed = []
        for summary in summaries:
            full_record = await self.db.get_experiment(summary["id"])
            if full_record:
                detailed.append(full_record)

        return ExperimentListResponse(
            items=[self._experiment_to_response(e) for e in detailed],
            # TODO: add a count method; this is the size of the returned
            # page, not the total number of matching experiments.
            total=len(summaries),
            limit=limit,
            offset=offset,
        )

    async def update_experiment(
        self,
        exp_id: str,
        request: ExperimentUpdate
    ) -> Optional[ExperimentResponse]:
        """Update the basic configuration of an experiment.

        Only fields that are not ``None`` on the request are written. When
        nothing is set, the current experiment is returned unchanged.

        Args:
            exp_id: Experiment id.
            request: Update request.

        Returns:
            The updated ``ExperimentResponse``, or ``None`` when the adapter
            returns no experiment.
        """
        candidates = {
            "exp_name": request.exp_name,
            "gpu_numbers": request.gpu_numbers,
            "is_half": request.is_half,
        }
        updates = {
            field: value
            for field, value in candidates.items()
            if value is not None
        }

        if not updates:
            return await self.get_experiment(exp_id)

        experiment = await self.db.update_experiment(exp_id, updates)
        if not experiment:
            return None
        return self._experiment_to_response(experiment)

    async def delete_experiment(self, exp_id: str) -> bool:
        """Delete an experiment, cancelling its running stages first.

        Args:
            exp_id: Experiment id.

        Returns:
            Whatever the database adapter reports for the deletion.
        """
        stages = await self.db.get_all_stages(exp_id)
        # ``get_all_stages`` may yield nothing for an unknown experiment;
        # treat a falsy result as "no stages" instead of iterating ``None``.
        for stage in stages or []:
            if stage.get("status") == "running" and stage.get("job_id"):
                await self.queue.cancel(stage["job_id"])

        return await self.db.delete_experiment(exp_id)

    async def check_stage_dependencies(
        self,
        exp_id: str,
        stage_type: str
    ) -> Dict[str, Any]:
        """Check whether all prerequisites of a stage are completed.

        Args:
            exp_id: Experiment id.
            stage_type: Stage to check.

        Returns:
            ``{"satisfied": bool, "missing": List[str]}``; an additional
            ``"error"`` entry is present when the experiment does not exist.
        """
        experiment = await self.db.get_experiment(exp_id)
        if not experiment:
            return {"satisfied": False, "missing": [], "error": "实验不存在"}

        stages = experiment.get("stages") or {}
        missing = [
            dep
            for dep in STAGE_DEPENDENCIES.get(stage_type, [])
            if (stages.get(dep) or {}).get("status") != "completed"
        ]

        return {
            "satisfied": not missing,
            "missing": missing,
        }

    async def execute_stage(
        self,
        exp_id: str,
        stage_type: str,
        params: Dict[str, Any]
    ) -> Optional[StageExecuteResponse]:
        """Enqueue a single stage of an experiment for execution.

        Args:
            exp_id: Experiment id.
            stage_type: Stage to run.
            params: Stage parameters.

        Returns:
            ``StageExecuteResponse``, or ``None`` if the experiment does not
            exist.
        """
        experiment = await self.db.get_experiment(exp_id)
        if not experiment:
            return None

        stages = experiment.get("stages") or {}
        current_stage = stages.get(stage_type) or {}

        # A stage that already completed is being re-run; report what the
        # previous run produced, because the record is reset below.
        is_rerun = current_stage.get("status") == "completed"
        previous_run = None
        if is_rerun:
            previous_run = {
                "completed_at": current_stage.get("completed_at"),
                "outputs": current_stage.get("outputs"),
            }

        # NOTE(review): a stage that is currently "running" is enqueued again
        # and its stored job_id overwritten; confirm callers guard against it.
        stage_config = {
            "exp_id": exp_id,
            "exp_name": experiment["exp_name"],
            "version": experiment.get("version", "v2"),
            "gpu_numbers": experiment.get("gpu_numbers", "0"),
            "is_half": experiment.get("is_half", True),
            "audio_file_id": experiment.get("audio_file_id"),
            "stage_type": stage_type,
            "params": params,
            # Run this one stage only.
            "stages": [stage_type],
        }

        # Task id used for progress tracking; the short random suffix gives
        # every run of the same stage a distinct id.
        task_id = f"{exp_id}-{stage_type}-{uuid.uuid4().hex[:4]}"
        job_id = await self.queue.enqueue(task_id, stage_config)

        # One timestamp for both the stored record and the response.
        now = datetime.utcnow()
        await self.db.update_stage(exp_id, stage_type, {
            "status": "running",
            "config": params,
            "job_id": job_id,
            "started_at": now,
            "completed_at": None,
            "error_message": None,
            "outputs": None,
            "progress": 0.0,
        })

        await self.db.update_experiment(exp_id, {"status": "running"})

        return StageExecuteResponse(
            exp_id=exp_id,
            stage_type=stage_type,
            status="running",
            job_id=job_id,
            config=params,
            rerun=is_rerun,
            previous_run=previous_run,
            started_at=now,
        )

    async def get_stage(
        self,
        exp_id: str,
        stage_type: str
    ) -> Optional[StageStatus]:
        """Return the status of one stage, or ``None`` if it is not found.

        Args:
            exp_id: Experiment id.
            stage_type: Stage type.
        """
        stage = await self.db.get_stage(exp_id, stage_type)
        if not stage:
            return None
        return self._stage_to_status(stage)

    async def get_all_stages(self, exp_id: str) -> Optional[StagesListResponse]:
        """Return the status of every stage of an experiment.

        Args:
            exp_id: Experiment id.

        Returns:
            ``StagesListResponse``, or ``None`` if the experiment does not
            exist.
        """
        stages = await self.db.get_all_stages(exp_id)
        if not stages:
            # No stages: distinguish "unknown experiment" from "no stages".
            experiment = await self.db.get_experiment(exp_id)
            if not experiment:
                return None
            stages = []

        return StagesListResponse(
            exp_id=exp_id,
            stages=[self._stage_to_status(s) for s in stages],
        )

    async def cancel_stage(self, exp_id: str, stage_type: str) -> bool:
        """Cancel a stage that is currently running.

        Args:
            exp_id: Experiment id.
            stage_type: Stage type.

        Returns:
            ``True`` if the stage was running and has been marked cancelled,
            ``False`` if it does not exist or is not running.
        """
        stage = await self.db.get_stage(exp_id, stage_type)
        # Only a running stage can be cancelled.
        if not stage or stage.get("status") != "running":
            return False

        job_id = stage.get("job_id")
        if job_id:
            await self.queue.cancel(job_id)

        await self.db.update_stage(exp_id, stage_type, {
            "status": "cancelled",
            "completed_at": datetime.utcnow(),
            "message": "阶段已取消",
        })
        return True

    async def subscribe_stage_progress(
        self,
        exp_id: str,
        stage_type: str
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """Stream progress events of a stage (intended for an SSE endpoint).

        Args:
            exp_id: Experiment id.
            stage_type: Stage type.

        Yields:
            Progress dictionaries. A single ``error``, ``final`` or ``info``
            event is yielded when the stage is missing, already finished or
            not yet started.
        """
        stage = await self.db.get_stage(exp_id, stage_type)
        if not stage:
            yield {"type": "error", "message": "阶段不存在"}
            return

        status = stage.get("status")

        # Already finished: report the final state and stop.
        if status in self._TERMINAL_STATES:
            yield {
                "type": "final",
                "status": status,
                "message": stage.get("message") or stage.get("error_message"),
                "progress": stage.get("progress", 0.0),
                "outputs": stage.get("outputs"),
            }
            return

        if status == "pending":
            yield {"type": "info", "message": "阶段尚未开始"}
            return

        # A stage without a job id was never enqueued by execute_stage.
        if not stage.get("job_id"):
            yield {"type": "error", "message": "无法获取任务ID"}
            return

        # NOTE(review): execute_stage enqueues under
        # "{exp_id}-{stage_type}-{random}", while this subscribes to
        # "{exp_id}-{stage_type}". Confirm the progress adapter resolves this
        # key, or persist the exact task id and subscribe to it instead.
        task_id = f"{exp_id}-{stage_type}"

        async for progress in self.progress_adapter.subscribe(task_id):
            yield progress

            # Stop streaming once a terminal state has been reported.
            if progress.get("status") in self._TERMINAL_STATES:
                break

    def _experiment_to_response(self, experiment: Dict[str, Any]) -> ExperimentResponse:
        """Convert an experiment record into an ``ExperimentResponse``."""
        stages = {}
        for stage_type, stage_info in (experiment.get("stages") or {}).items():
            # The mapping key is the stage type; use it when the stage
            # record does not carry the type itself.
            if not stage_info.get("stage_type"):
                stage_info = {**stage_info, "stage_type": stage_type}
            stages[stage_type] = self._stage_to_status(stage_info)

        created_at = self._coerce_datetime(experiment.get("created_at"))
        if created_at is None:
            created_at = datetime.utcnow()

        return ExperimentResponse(
            id=experiment["id"],
            exp_name=experiment["exp_name"],
            version=experiment.get("version", "v2"),
            status=experiment.get("status", "created"),
            gpu_numbers=experiment.get("gpu_numbers", "0"),
            is_half=experiment.get("is_half", True),
            audio_file_id=experiment.get("audio_file_id", ""),
            stages=stages,
            created_at=created_at,
            updated_at=self._coerce_datetime(experiment.get("updated_at")),
        )

    def _stage_to_status(self, stage: Dict[str, Any]) -> StageStatus:
        """Convert a stage record into a ``StageStatus``."""
        return StageStatus(
            stage_type=stage.get("stage_type", ""),
            status=stage.get("status", "pending"),
            progress=stage.get("progress"),
            message=stage.get("message"),
            started_at=self._coerce_datetime(stage.get("started_at")),
            completed_at=self._coerce_datetime(stage.get("completed_at")),
            config=stage.get("config"),
            outputs=stage.get("outputs"),
            error_message=stage.get("error_message"),
        )
api_server/app/services/file_service.py ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 文件管理服务
3
+
4
+ 处理文件上传、下载和管理的业务逻辑
5
+ """
6
+
7
+ from datetime import datetime
8
+ from typing import List, Optional, Tuple
9
+
10
+ from ..core.adapters import get_database_adapter, get_storage_adapter
11
+ from ..models.schemas.file import (
12
+ FileMetadata,
13
+ FileUploadResponse,
14
+ FileListResponse,
15
+ FileDeleteResponse,
16
+ )
17
+
18
+
19
class FileService:
    """File management service.

    Covers the life cycle of a stored file:

    - upload
    - download
    - metadata lookup
    - listing
    - deletion

    The database and storage adapters are resolved lazily on first access
    and cached on the instance.

    Example:
        >>> service = FileService()
        >>> result = await service.upload_file(data, "audio.wav", "audio/wav", "training")
        >>> content = await service.download_file(result.file.id)
        >>> await service.delete_file(result.file.id)
    """

    def __init__(self):
        """Initialise the service; adapters are looked up on first use."""
        self._db = None
        self._storage = None

    @property
    def db(self):
        """Database adapter, fetched lazily and cached."""
        if self._db is None:
            self._db = get_database_adapter()
        return self._db

    @property
    def storage(self):
        """Storage adapter, fetched lazily and cached."""
        if self._storage is None:
            self._storage = get_storage_adapter()
        return self._storage

    @staticmethod
    def _coerce_uploaded_at(value):
        """Normalise an ``uploaded_at`` value.

        ISO-8601 strings are parsed, ``None`` becomes the current UTC time,
        and any other value is returned unchanged.
        """
        if isinstance(value, str):
            return datetime.fromisoformat(value)
        if value is None:
            return datetime.utcnow()
        return value

    async def upload_file(
        self,
        file_data: bytes,
        filename: str,
        content_type: Optional[str] = None,
        purpose: str = "training"
    ) -> FileUploadResponse:
        """Store a file and record it in the database.

        Args:
            file_data: Raw file content.
            filename: Original file name.
            content_type: MIME type, if known.
            purpose: What the file is for (training, reference, output).

        Returns:
            ``FileUploadResponse`` describing the stored file.
        """
        size_bytes = len(file_data)
        # Single timestamp, so the persisted record and the response agree.
        uploaded_at = datetime.utcnow()

        file_id = await self.storage.upload_file(
            file_data,
            filename,
            {
                "content_type": content_type,
                "purpose": purpose,
                "size_bytes": size_bytes,
            },
        )

        # Read the metadata back from storage to pick up the audio details
        # (duration, sample rate) it may have recorded.
        stored = await self.storage.get_file_metadata(file_id) or {}

        file_record = {
            "id": file_id,
            "filename": filename,
            "content_type": content_type,
            "size_bytes": size_bytes,
            "purpose": purpose,
            "duration_seconds": stored.get("duration_seconds"),
            "sample_rate": stored.get("sample_rate"),
            "uploaded_at": uploaded_at.isoformat(),
        }
        await self.db.create_file_record(file_record)

        return FileUploadResponse(
            success=True,
            message="文件上传成功",
            file=FileMetadata(
                id=file_id,
                filename=filename,
                content_type=content_type,
                size_bytes=size_bytes,
                purpose=purpose,
                duration_seconds=file_record["duration_seconds"],
                sample_rate=file_record["sample_rate"],
                uploaded_at=uploaded_at,
            ),
        )

    async def download_file(self, file_id: str) -> Optional[Tuple[bytes, str, str]]:
        """Fetch a file's content together with its name and content type.

        Args:
            file_id: File id.

        Returns:
            ``(data, filename, content_type)``, or ``None`` when the file or
            its metadata is missing.
        """
        if not await self.storage.file_exists(file_id):
            return None

        metadata = await self.storage.get_file_metadata(file_id)
        if not metadata:
            return None

        file_data = await self.storage.download_file(file_id)

        # ``or`` instead of a ``get`` default: the keys can be present with
        # a ``None`` value (content_type is optional at upload time), and
        # the declared return type promises strings.
        return (
            file_data,
            metadata.get("filename") or "file",
            metadata.get("content_type") or "application/octet-stream",
        )

    async def get_file(self, file_id: str) -> Optional[FileMetadata]:
        """Return a file's metadata, or ``None`` if it is unknown.

        The database record is preferred; storage metadata is the fallback.

        Args:
            file_id: File id.
        """
        record = await self.db.get_file_record(file_id)
        if record:
            return self._record_to_metadata(record)

        metadata = await self.storage.get_file_metadata(file_id)
        if metadata:
            return self._storage_metadata_to_file_metadata(metadata)

        return None

    async def list_files(
        self,
        purpose: Optional[str] = None,
        limit: int = 50,
        offset: int = 0
    ) -> FileListResponse:
        """Return one page of file records from the database.

        Args:
            purpose: Optional purpose filter.
            limit: Page size.
            offset: Page offset.

        Returns:
            ``FileListResponse`` with the page and the total count.
        """
        records = await self.db.list_file_records(
            purpose=purpose, limit=limit, offset=offset
        )
        total = await self.db.count_file_records(purpose=purpose)

        return FileListResponse(
            items=[self._record_to_metadata(r) for r in records],
            total=total,
            limit=limit,
            offset=offset,
        )

    async def delete_file(self, file_id: str) -> FileDeleteResponse:
        """Delete a file from storage and from the database.

        Both deletions are always attempted; the call counts as successful
        when either of them removed something.

        Args:
            file_id: File id.

        Returns:
            ``FileDeleteResponse`` reporting the outcome.
        """
        storage_deleted = await self.storage.delete_file(file_id)
        db_deleted = await self.db.delete_file_record(file_id)

        if not (storage_deleted or db_deleted):
            return FileDeleteResponse(
                success=False,
                message="文件不存在或已删除",
                file_id=file_id,
            )

        return FileDeleteResponse(
            success=True,
            message="文件删除成功",
            file_id=file_id,
        )

    async def file_exists(self, file_id: str) -> bool:
        """Return whether the storage adapter has the file.

        Args:
            file_id: File id.
        """
        return await self.storage.file_exists(file_id)

    def _record_to_metadata(self, record: dict) -> FileMetadata:
        """Convert a database record into ``FileMetadata``."""
        return FileMetadata(
            id=record["id"],
            filename=record["filename"],
            content_type=record.get("content_type"),
            size_bytes=record.get("size_bytes", 0),
            purpose=record.get("purpose", "training"),
            duration_seconds=record.get("duration_seconds"),
            sample_rate=record.get("sample_rate"),
            uploaded_at=self._coerce_uploaded_at(record.get("uploaded_at")),
        )

    def _storage_metadata_to_file_metadata(self, metadata: dict) -> FileMetadata:
        """Convert storage-adapter metadata into ``FileMetadata``."""
        return FileMetadata(
            id=metadata.get("id", ""),
            filename=metadata.get("filename", ""),
            content_type=metadata.get("content_type"),
            size_bytes=metadata.get("size_bytes", 0),
            purpose=metadata.get("purpose", "training"),
            duration_seconds=metadata.get("duration_seconds"),
            sample_rate=metadata.get("sample_rate"),
            uploaded_at=self._coerce_uploaded_at(metadata.get("uploaded_at")),
        )
api_server/app/services/task_service.py ADDED
@@ -0,0 +1,322 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Quick Mode 任务服务
3
+
4
+ 处理一键训练任务的业务逻辑
5
+ """
6
+
7
+ import uuid
8
+ from datetime import datetime
9
+ from typing import AsyncGenerator, Dict, List, Optional, Any
10
+
11
+ from ..core.adapters import get_database_adapter, get_task_queue_adapter, get_storage_adapter
12
+ from ..core.config import settings
13
+ from ..models.domain import Task, TaskStatus
14
+ from ..models.schemas.task import (
15
+ QuickModeRequest,
16
+ TaskResponse,
17
+ TaskListResponse,
18
+ )
19
+
20
+
21
# Quality presets: epoch counts for SoVITS and GPT training per quality
# level, plus a human-readable description of each level.
QUALITY_PRESETS = {
    "fast": {
        "sovits_epochs": 4,
        "gpt_epochs": 8,
        "description": "快速训练,约10分钟",
    },
    "standard": {
        "sovits_epochs": 8,
        "gpt_epochs": 15,
        "description": "标准训练,约20分钟",
    },
    "high": {
        "sovits_epochs": 16,
        "gpt_epochs": 30,
        "description": "高质量训练,约40分钟",
    },
}
39
+
40
+
41
class TaskService:
    """Quick Mode task service.

    Life-cycle management for one-click training tasks:

    - create a task
    - query task state
    - cancel a task
    - subscribe to progress updates

    The database, task-queue and storage adapters are resolved lazily on
    first access and cached on the instance.

    Example:
        >>> service = TaskService()
        >>> task = await service.create_quick_task(request)
        >>> status = await service.get_task(task.id)
        >>> await service.cancel_task(task.id)
    """

    def __init__(self):
        """Initialise the service; adapters are looked up on first use."""
        self._db = None
        self._queue = None
        self._storage = None

    @property
    def db(self):
        """Database adapter, fetched lazily and cached."""
        if self._db is None:
            self._db = get_database_adapter()
        return self._db

    @property
    def queue(self):
        """Task-queue adapter, fetched lazily and cached."""
        if self._queue is None:
            self._queue = get_task_queue_adapter()
        return self._queue

    @property
    def storage(self):
        """Storage adapter, fetched lazily and cached."""
        if self._storage is None:
            self._storage = get_storage_adapter()
        return self._storage

    async def check_exp_name_exists(self, exp_name: str) -> bool:
        """Return whether a task with this experiment name already exists.

        Args:
            exp_name: Experiment name.
        """
        existing_task = await self.db.get_task_by_exp_name(exp_name)
        return existing_task is not None

    async def validate_audio_file(self, audio_file_id: str) -> tuple[bool, str]:
        """Resolve an audio file reference and check that it exists on disk.

        Args:
            audio_file_id: Id of an uploaded file, or a filesystem path.

        Returns:
            ``(exists, path)`` where ``path`` is the resolved file path.
        """
        import os

        file_metadata = await self.storage.get_file_metadata(audio_file_id)

        if file_metadata:
            # Known to storage: the file lives at storage.base_path / file_id.
            audio_file_path = str(self.storage.base_path / audio_file_id)
        else:
            # Unknown to storage: treat the value itself as a path.
            audio_file_path = audio_file_id

        return os.path.exists(audio_file_path), audio_file_path

    async def create_quick_task(self, request: QuickModeRequest) -> TaskResponse:
        """Create and enqueue a one-click training task.

        Training parameters are derived from the request and its quality
        preset; an unknown quality falls back to the "standard" preset.

        Args:
            request: Quick Mode request.

        Returns:
            ``TaskResponse`` for the newly created task.

        Raises:
            Exception: Whatever the queue adapter raises when enqueueing
                fails; the task is marked as failed before re-raising.
        """
        task_id = f"task-{uuid.uuid4().hex[:12]}"

        quality = request.options.quality
        preset = QUALITY_PRESETS.get(quality, QUALITY_PRESETS["standard"])

        # Only the resolved path is used here; the existence flag is
        # deliberately ignored (callers can use validate_audio_file first).
        audio_file_id = request.audio_file_id
        _, audio_file_path = await self.validate_audio_file(audio_file_id)

        config = {
            "exp_name": request.exp_name,
            "audio_file_id": audio_file_id,
            "input_path": audio_file_path,  # resolved path of the audio file
            "version": request.options.version,
            "language": request.options.language,
            "quality": quality,
            # Training parameters
            "total_epoch": preset["sovits_epochs"],  # SoVITS epoch
            "sovits_epochs": preset["sovits_epochs"],
            "gpt_epochs": preset["gpt_epochs"],
            # Pretrained model paths
            "bert_pretrained_dir": str(settings.BERT_PRETRAINED_DIR),
            "ssl_pretrained_dir": str(settings.SSL_PRETRAINED_DIR),
            "pretrained_s2G": str(settings.PRETRAINED_S2G),
            "pretrained_s2D": str(settings.PRETRAINED_S2D),
            "pretrained_s1": str(settings.PRETRAINED_S1),
            # Run the complete pipeline
            "stages": [
                "audio_slice",
                "asr",
                "text_feature",
                "hubert_feature",
                "semantic_token",
                "sovits_train",
                "gpt_train",
            ],
        }

        task = Task(
            id=task_id,
            exp_name=request.exp_name,
            config=config,
            status=TaskStatus.QUEUED,
            created_at=datetime.utcnow(),
        )

        # Persist first, then enqueue.
        await self.db.create_task(task)

        try:
            job_id = await self.queue.enqueue(task_id, config)
        except Exception as exc:
            # Without this the record would stay QUEUED forever although
            # nothing was enqueued, and keep the experiment name taken.
            await self.db.update_task(task_id, {
                "status": TaskStatus.FAILED,
                "completed_at": datetime.utcnow(),
                "message": f"任务入队失败: {exc}",
            })
            raise

        await self.db.update_task(task_id, {"job_id": job_id})
        task.job_id = job_id

        return self._task_to_response(task)

    async def get_task(self, task_id: str) -> Optional[TaskResponse]:
        """Return the task with the given id, or ``None`` if not found.

        Args:
            task_id: Task id.
        """
        task = await self.db.get_task(task_id)
        if not task:
            return None
        return self._task_to_response(task)

    async def list_tasks(
        self,
        status: Optional[str] = None,
        limit: int = 50,
        offset: int = 0
    ) -> TaskListResponse:
        """Return one page of tasks.

        Args:
            status: Optional status filter.
            limit: Page size.
            offset: Page offset.

        Returns:
            ``TaskListResponse`` with the page and the total count.
        """
        tasks = await self.db.list_tasks(status=status, limit=limit, offset=offset)
        total = await self.db.count_tasks(status=status)

        return TaskListResponse(
            items=[self._task_to_response(t) for t in tasks],
            total=total,
            limit=limit,
            offset=offset,
        )

    async def cancel_task(self, task_id: str) -> bool:
        """Cancel a queued or running task.

        Args:
            task_id: Task id.

        Returns:
            ``True`` if the task was cancelled, ``False`` if it does not
            exist or is neither queued nor running.
        """
        task = await self.db.get_task(task_id)
        # Only queued or running tasks can be cancelled.
        if not task or task.status not in (TaskStatus.QUEUED, TaskStatus.RUNNING):
            return False

        # A task without a job id was never handed to the queue.
        if task.job_id:
            await self.queue.cancel(task.job_id)

        await self.db.update_task(task_id, {
            "status": TaskStatus.CANCELLED,
            "completed_at": datetime.utcnow(),
            "message": "任务已取消",
        })
        return True

    async def subscribe_progress(
        self,
        task_id: str
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """Stream progress events of a task (intended for an SSE endpoint).

        Args:
            task_id: Task id.

        Yields:
            Progress dictionaries. A single ``error`` or ``final`` event is
            yielded when the task is missing or already finished.
        """
        task = await self.db.get_task(task_id)
        if not task:
            yield {"type": "error", "message": "任务不存在"}
            return

        # Already finished: report the final state and stop.
        if task.status in (TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED):
            yield {
                "type": "final",
                "status": task.status.value,
                "message": task.message or task.error_message,
                "progress": task.progress,
            }
            return

        async for progress in self.queue.subscribe_progress(task_id):
            yield progress

            # Stop streaming once a terminal state has been reported.
            if progress.get("status") in ("completed", "failed", "cancelled"):
                break

    def _task_to_response(self, task: Task) -> TaskResponse:
        """Convert a ``Task`` domain model into a ``TaskResponse``."""
        # The status may be a TaskStatus member or an already plain value.
        status = task.status.value if isinstance(task.status, TaskStatus) else task.status
        return TaskResponse(
            id=task.id,
            exp_name=task.exp_name,
            status=status,
            current_stage=task.current_stage,
            progress=task.stage_progress,
            overall_progress=task.progress,
            message=task.message,
            error_message=task.error_message,
            created_at=task.created_at,
            started_at=task.started_at,
            completed_at=task.completed_at,
        )