diff --git "a/development.md" "b/development.md"
new file mode 100644
--- /dev/null
+++ "b/development.md"
@@ -0,0 +1,3339 @@
+# GPT-SoVITS 音色训练 HTTP API 服务架构设计
+
+> **文档说明**: 本文档是 API 服务的完整架构设计文档,包含设计规范和实现参考代码。
+
+## 实现进度总览
+
+| 模块 | 状态 | 说明 |
+|------|------|------|
+| **架构设计** | ✅ 完成 | 双模式 API 设计(Quick Mode + Advanced Mode) |
+| **Pydantic Schema** | ✅ 已实现 | `app/models/schemas/` - task.py, experiment.py, file.py, common.py |
+| **数据库 Schema** | ✅ 设计完成 | SQLite/PostgreSQL 表结构 |
+| **适配器基类** | ✅ 已实现 | `TaskQueueAdapter`, `ProgressAdapter`, `StorageAdapter`, `DatabaseAdapter` |
+| **AsyncTrainingManager** | ✅ 已实现 | 本地任务队列完整实现 |
+| **配置管理** | ✅ 已实现 | `app/core/config.py` |
+| **领域模型** | ✅ 已实现 | `Task`, `TaskStatus`, `ProgressInfo` |
+| **Pipeline 脚本** | ✅ 已实现 | `app/scripts/run_pipeline.py` |
+| **存储适配器** | ✅ 已实现 | `app/adapters/local/storage.py` - LocalStorageAdapter |
+| **数据库适配器** | ✅ 已实现 | `app/adapters/local/database.py` - SQLiteAdapter |
+| **进度适配器** | ✅ 已实现 | `app/adapters/local/progress.py` - LocalProgressAdapter |
+| **适配器工厂** | ✅ 已实现 | `app/core/adapters.py` - AdapterFactory |
+| **API 端点** | ✅ 已实现 | `app/api/v1/endpoints/` - tasks, experiments, files, stages |
+| **服务层** | ✅ 已实现 | `app/services/` - TaskService, ExperimentService, FileService |
+| **FastAPI 入口** | ✅ 已实现 | `app/main.py` - 应用入口和生命周期管理 |
+
+---
+
+## 一、架构总览
+
+### 1.1 两种部署场景对比
+
+| 维度 | macOS本地训练 | Linux服务器端训练 |
+|------|--------------|------------------|
+| **用户场景** | 个人开发者、小规模训练 | 生产环境、多用户、大规模训练 |
+| **并发需求** | 单用户、串行任务 | 多用户、并发任务 |
+| **资源管理** | 简单(单机GPU) | 复杂(多GPU、分布式) |
+| **持久化需求** | 轻量级(SQLite/文件) | 重量级(PostgreSQL/分布式存储) |
+| **任务队列** | 简单队列(内存/SQLite) | 分布式队列(Celery+Redis) |
+| **API复杂度** | 简化版 | 完整版 |
+
+### 1.1.1 macOS本地训练的运行模式
+
+macOS本地训练可以有三种运行方式,需要根据最终交付形态选择合适的任务管理方案:
+
+| 运行模式 | 描述 | 启动方式 | 任务管理推荐 |
+|----------|------|----------|-------------|
+| **开发模式** | 直接运行Python脚本 | `python main.py` / `uvicorn` | asyncio.subprocess ⭐ |
+| **PyInstaller打包** | 打包为独立可执行文件 | `./app` 单个可执行文件 | asyncio.subprocess ⭐ |
+| **Electron集成** | 作为Electron子进程运行 | Electron spawn Python进程 | asyncio.subprocess ⭐ |
+
+#### ⚠️ PyInstaller + Electron 场景的特殊考量
+
+当需要将训练工程通过PyInstaller打包并集成到Electron应用时,**Huey不是合适的选择**,原因如下:
+
+1. **多进程架构冲突**:Huey需要独立的`huey_consumer`进程
+2. **进程生命周期复杂**:Electron需要管理多个Python子进程
+3. **打包复杂度增加**:PyInstaller需要正确打包所有依赖
+
+**推荐方案**:使用 **`asyncio.subprocess`** 方案(见第7.1节),训练任务本身已经是子进程,无需额外的任务队列。
+
+### 1.2 架构统一设计原则
+
+**核心理念**: 使用适配器模式,统一API层和业务逻辑层,底层存储和任务执行通过适配器切换
+
+```
+┌─────────────────────────────────────────────────────┐
+│            Unified API Layer (FastAPI)              │
+│  /api/v1/tasks, /api/v1/experiments, /files, etc.   │
+└────────────────────┬────────────────────────────────┘
+                     │
+┌────────────────────▼────────────────────────────────┐
+│               Service Layer (Unified)               │
+│  TaskService, ExperimentService, FileService, etc. 
│ +└────────┬───────────────────────────────┬────────────┘ + │ │ + │ Adapter Pattern │ + │ │ + ┌────▼─────┐ ┌─────▼──────┐ + │ Local │ │ Server │ + │ Adapter │ │ Adapter │ + └────┬─────┘ └─────┬──────┘ + │ │ + ┌────▼─────────────┐ ┌────────▼────────────┐ + │ Local Backend │ │ Server Backend │ + │ - SQLite │ │ - PostgreSQL │ + │ - asyncio.subproc│ │ - Celery+Redis │ + │ - Local FS │ │ - S3/MinIO │ + └──────────────────┘ └─────────────────────┘ +``` + + +--- + +## 二、技术栈对比 + +### 2.1 macOS本地训练方案 + +```yaml +Web框架: FastAPI +数据库: SQLite (aiosqlite) +任务管理: asyncio.subprocess (推荐) - 训练脚本本身是子进程 +文件存储: 本地文件系统 +进度推送: SSE (Server-Sent Events) +缓存: 内存 (lru_cache / cachetools) +日志: Loguru +配置: YAML / .env文件 +``` + + +**优点**: +- 无需额外服务(Redis、PostgreSQL) +- 部署简单,一键启动 +- 适合个人使用 + +**缺点**: +- 不支持水平扩展 +- 单点故障 +- 任务并发能力有限 + +### 2.2 Linux服务器端训练方案 + +```yaml +Web框架: FastAPI +数据库: PostgreSQL + Alembic (数据迁移) +任务队列: Celery + Redis +文件存储: MinIO / S3 +进度推送: SSE + Redis Pub/Sub +缓存: Redis +日志: Loguru + ELK Stack (可选) +监控: Prometheus + Grafana +配置: 环境变量 + Consul/etcd (可选) +``` + + +**优点**: +- 高并发、高可用 +- 水平扩展 +- 完整的监控告警 + +**缺点**: +- 部署复杂 +- 需要额外服务依赖 + +--- + +## 三、统一架构设计 + +### 3.1 项目结构 + +> **图例**: ✅ 已实现 | [待实现] 设计完成待开发 | [Phase 2] 服务器模式后续实现 + +``` +api_server/ +├── app/ +│ ├── __init__.py # ✅ 已实现 +│ │ +│ ├── api/ # ✅ API 路由层 +│ │ ├── __init__.py # ✅ 已实现 +│ │ ├── deps.py # ✅ 已实现 - 依赖注入 +│ │ └── v1/ +│ │ ├── __init__.py # ✅ 已实现 +│ │ ├── endpoints/ +│ │ │ ├── __init__.py # ✅ 已实现 +│ │ │ ├── tasks.py # ✅ 已实现 - Quick Mode 任务管理 +│ │ │ ├── experiments.py # ✅ 已实现 - Advanced Mode 实验管理 +│ │ │ ├── stages.py # ✅ 已实现 - 阶段参数模板 +│ │ │ ├── files.py # ✅ 已实现 - 文件管理 +│ │ │ ├── models.py # [待实现] 模型管理 +│ │ │ └── inference.py # [待实现] 推理接口 +│ │ └── router.py # ✅ 已实现 - 路由注册 +│ │ +│ ├── core/ +│ │ ├── __init__.py # ✅ 已实现 +│ │ ├── config.py # ✅ 已实现 - Settings, 路径常量, get_pythonpath() +│ │ ├── adapters.py # ✅ 已实现 - 适配器工厂 +│ │ └── enums.py # [待实现] 枚举定义 +│ │ +│ ├── services/ # ✅ 业务逻辑层 +│ │ ├── __init__.py # ✅ 已实现 +│ │ ├── task_service.py # ✅ 已实现 - Quick Mode 任务服务 +│ │ ├── experiment_service.py # ✅ 已实现 - Advanced Mode 实验服务 +│ │ ├── file_service.py # ✅ 已实现 - 文件管理服务 +│ │ ├── pipeline_service.py # [待实现] +│ │ └── progress_service.py # [待实现] +│ │ +│ ├── adapters/ # 适配器层 +│ │ ├── __init__.py # ✅ 已实现 +│ │ ├── base.py # ✅ 已实现 - TaskQueueAdapter, ProgressAdapter, StorageAdapter, DatabaseAdapter +│ │ ├── local/ +│ │ │ ├── __init__.py # ✅ 已实现 +│ │ │ ├── task_queue.py # ✅ 已实现 - AsyncTrainingManager (完整) +│ │ │ ├── storage.py # ✅ 已实现 - LocalStorageAdapter +│ │ │ ├── database.py # ✅ 已实现 - SQLiteAdapter +│ │ │ └── progress.py # ✅ 已实现 - LocalProgressAdapter +│ │ └── server/ # [Phase 2] +│ │ ├── storage.py # S3/MinIO 适配器 +│ │ ├── task_queue.py # Celery 适配器 +│ │ └── database.py # PostgreSQL 适配器 +│ │ +│ ├── models/ +│ │ ├── __init__.py # ✅ 已实现 +│ │ ├── domain.py # ✅ 已实现 - Task, TaskStatus, ProgressInfo +│ │ └── schemas/ # ✅ 已实现 - Pydantic 模型 +│ │ ├── __init__.py # ✅ 已实现 - Schema 模块导出 +│ │ ├── common.py # ✅ 已实现 - 通用响应模型 +│ │ ├── task.py # ✅ 已实现 - Quick Mode 任务模型 +│ │ ├── experiment.py # ✅ 已实现 - Advanced Mode 实验/阶段模型 +│ │ ├── file.py # ✅ 已实现 - 文件上传/下载模型 +│ │ └── inference.py # [待实现] 推理相关模型 +│ │ +│ ├── scripts/ +│ │ ├── __init__.py # ✅ 已实现 +│ │ └── run_pipeline.py # ✅ 已实现 - Pipeline 子进程执行器 +│ │ +│ ├── workers/ # [待实现] 任务执行器 +│ │ ├── local_worker.py # 本地执行器 +│ │ └── celery_worker.py # [Phase 2] Celery 执行器 +│ │ +│ └── main.py # ✅ 已实现 - FastAPI 入口 +│ +├── data/ # 数据目录 +│ ├── configs/ # 任务配置文件 +│ ├── tasks.db # SQLite 数据库 +│ └── test_config.json # 测试配置 +│ +├── config/ # 
[待实现]
+│   ├── local.yaml                # 本地配置
+│   └── server.yaml               # 服务器配置
+│
+├── requirements/                 # [待实现]
+│   ├── base.txt                  # 共同依赖
+│   ├── local.txt                 # 本地额外依赖
+│   └── server.txt                # 服务器额外依赖
+│
+├── docker-compose.local.yml      # [待实现] 本地开发
+├── docker-compose.server.yml     # [Phase 2] 服务器部署
+└── README.md                     # [待实现]
+```
+
+
+### 3.2 核心适配器设计
+
+#### 3.2.1 抽象基类 ✅ 已完成
+
+> **实现状态**: 所有适配器抽象基类已在 `app/adapters/base.py` 中实现:
+> - `TaskQueueAdapter` - 任务队列接口
+> - `ProgressAdapter` - 进度管理接口
+> - `StorageAdapter` - 文件存储接口
+> - `DatabaseAdapter` - 数据库操作接口
+
+```python
+# app/adapters/base.py - ✅ 已实现部分
+
+from abc import ABC, abstractmethod
+from typing import Dict, List, Optional, AsyncGenerator
+
+
+class TaskQueueAdapter(ABC):
+    """
+    任务队列适配器抽象基类 ✅ 已实现
+
+    定义任务队列的通用接口,支持本地(asyncio.subprocess)和
+    服务器(Celery)两种实现方式。
+    """
+
+    @abstractmethod
+    async def enqueue(self, task_id: str, config: Dict, priority: str = "normal") -> str:
+        """将任务加入队列,返回job_id"""
+        pass
+
+    @abstractmethod
+    async def get_status(self, job_id: str) -> Dict:
+        """获取任务状态"""
+        pass
+
+    @abstractmethod
+    async def cancel(self, job_id: str) -> bool:
+        """取消任务"""
+        pass
+
+    @abstractmethod
+    async def subscribe_progress(self, task_id: str) -> AsyncGenerator[Dict, None]:
+        """订阅任务进度(SSE流)"""
+        pass
+
+
+class ProgressAdapter(ABC):
+    """
+    进度管理适配器抽象基类 ✅ 已实现
+
+    用于更新和订阅任务进度,支持本地(内存队列)和
+    服务器(Redis Pub/Sub)两种实现。
+    """
+
+    @abstractmethod
+    async def update_progress(self, task_id: str, progress: Dict) -> None:
+        """更新进度"""
+        pass
+
+    @abstractmethod
+    async def get_progress(self, task_id: str) -> Optional[Dict]:
+        """获取当前进度"""
+        pass
+
+    @abstractmethod
+    async def subscribe(self, task_id: str) -> AsyncGenerator[Dict, None]:
+        """订阅进度更新"""
+        pass
+```
+
+```python
+# app/adapters/base.py - ✅ 已实现部分(续)
+
+from app.models.domain import Task  # 领域模型(见 app/models/domain.py),DatabaseAdapter 的类型标注需要
+
+
+class StorageAdapter(ABC):
+    """存储适配器抽象基类 ✅ 已实现"""
+
+    @abstractmethod
+    async def upload_file(self, file_data: bytes, filename: str, metadata: Dict) -> str:
+        """上传文件,返回文件ID"""
+        pass
+
+    @abstractmethod
+    async def download_file(self, file_id: str) -> bytes:
+        """下载文件"""
+        pass
+
+    @abstractmethod
+    async def delete_file(self, file_id: str) -> bool:
+        """删除文件"""
+        pass
+
+    @abstractmethod
+    async def get_file_metadata(self, file_id: str) -> Dict:
+        """获取文件元数据"""
+        pass
+
+
+class DatabaseAdapter(ABC):
+    """数据库适配器抽象基类 ✅ 已实现"""
+
+    @abstractmethod
+    async def create_task(self, task: Task) -> Task:
+        """创建任务"""
+        pass
+
+    @abstractmethod
+    async def get_task(self, task_id: str) -> Optional[Task]:
+        """获取任务"""
+        pass
+
+    @abstractmethod
+    async def update_task(self, task_id: str, updates: Dict) -> Task:
+        """更新任务"""
+        pass
+
+    @abstractmethod
+    async def list_tasks(self, filters: Dict, limit: int, offset: int) -> List[Task]:
+        """查询任务列表"""
+        pass
+
+    @abstractmethod
+    async def delete_task(self, task_id: str) -> bool:
+        """删除任务"""
+        pass
+```
+
+
+#### 3.2.2 本地适配器实现
+
+##### AsyncTrainingManager ✅ 已完整实现
+
+> **实现文件**: `app/adapters/local/task_queue.py`
+>
+> 这是本地模式的核心组件,已完整实现以下功能:
+> - 任务入队与异步执行
+> - 子进程管理 (`asyncio.create_subprocess_exec`)
+> - 进度解析与 SSE 流推送
+> - 任务状态持久化(SQLite)
+> - 任务取消与恢复
+
+```python
+# app/adapters/local/task_queue.py - ✅ 已完整实现
+
+class AsyncTrainingManager(TaskQueueAdapter):
+    """
+    基于 asyncio.subprocess 的异步任务管理器
+
+    特点:
+    1. 使用 asyncio.create_subprocess_exec() 异步启动训练子进程
+    2. 完全非阻塞,与 FastAPI 异步模型完美契合
+    3. SQLite 持久化任务状态,支持应用重启后恢复
+    4. 
实时解析子进程输出获取进度 + """ + + def __init__(self, db_path: str = None, max_concurrent: int = 1): + self.db_path = db_path or str(settings.SQLITE_PATH) + self.max_concurrent = max_concurrent + self.running_processes: Dict[str, asyncio.subprocess.Process] = {} + self.progress_channels: Dict[str, asyncio.Queue] = {} + self._init_db_sync() + + async def enqueue(self, task_id: str, config: Dict, priority: str = "normal") -> str: + """将任务加入队列并异步启动""" + # ... 完整实现见源文件 + + async def get_status(self, job_id: str) -> Dict: + """获取任务状态""" + # ... 完整实现见源文件 + + async def get_status_by_task_id(self, task_id: str) -> Dict: + """通过 task_id 获取任务状态""" + # ... 完整实现见源文件 + + async def cancel(self, job_id: str) -> bool: + """取消任务(优雅终止 + 强制终止)""" + # ... 完整实现见源文件 + + async def subscribe_progress(self, task_id: str) -> AsyncGenerator[Dict, None]: + """订阅任务进度(用于 SSE 流)""" + # ... 完整实现见源文件 + + async def list_tasks(self, status: Optional[str] = None, limit: int = 50, offset: int = 0) -> List[Dict]: + """列出任务""" + # ... 完整实现见源文件 + + async def recover_pending_tasks(self) -> int: + """应用重启后恢复未完成的任务""" + # ... 完整实现见源文件 + + async def cleanup_old_tasks(self, days: int = 7) -> int: + """清理旧任务记录""" + # ... 完整实现见源文件 +``` + +##### LocalStorageAdapter ✅ 已实现 + +> **实现文件**: `app/adapters/local/storage.py` +> +> 基于本地文件系统的存储适配器,使用 aiofiles 实现异步 I/O。 +> 支持文件上传/下载、元数据管理、音频信息提取等功能。 + +```python +# app/adapters/local/storage.py - ✅ 已完整实现 + +class LocalStorageAdapter(StorageAdapter): + """ + 本地文件系统存储适配器 + + 特点: + 1. 使用 aiofiles 进行异步文件读写 + 2. 元数据存储在 .meta.json 文件中 + 3. 支持音频文件信息提取(时长、采样率等) + """ + + async def upload_file(self, file_data: bytes, filename: str, metadata: Dict) -> str: + """上传文件,返回 file_id""" + # ... 完整实现见源文件 + + async def download_file(self, file_id: str) -> bytes: + """下载文件""" + # ... 完整实现见源文件 + + async def delete_file(self, file_id: str) -> bool: + """删除文件及其元数据""" + # ... 完整实现见源文件 + + async def get_file_metadata(self, file_id: str) -> Optional[Dict]: + """获取文件元数据""" + # ... 完整实现见源文件 + + async def list_files(self, purpose: Optional[str] = None, limit: int = 50, offset: int = 0) -> List[Dict]: + """列出文件""" + # ... 完整实现见源文件 +``` + +##### SQLiteAdapter ✅ 已实现 + +> **实现文件**: `app/adapters/local/database.py` +> +> 基于 SQLite + aiosqlite 的数据库适配器,支持 Task 和 Experiment 的完整 CRUD 操作。 + +```python +# app/adapters/local/database.py - ✅ 已完整实现 + +class SQLiteAdapter(DatabaseAdapter): + """ + SQLite 数据库适配器 + + 特点: + 1. 使用 aiosqlite 实现异步数据库操作 + 2. 支持 Task (Quick Mode) 和 Experiment (Advanced Mode) 管理 + 3. 自动初始化数据库表结构 + """ + + # Task CRUD + async def create_task(self, task: Task) -> Task: ... + async def get_task(self, task_id: str) -> Optional[Task]: ... + async def update_task(self, task_id: str, updates: Dict) -> Optional[Task]: ... + async def list_tasks(self, status: Optional[str] = None, limit: int = 50, offset: int = 0) -> List[Task]: ... + async def delete_task(self, task_id: str) -> bool: ... + async def count_tasks(self, status: Optional[str] = None) -> int: ... + + # Experiment CRUD + async def create_experiment(self, experiment: Dict) -> Dict: ... + async def get_experiment(self, exp_id: str) -> Optional[Dict]: ... + async def update_experiment(self, exp_id: str, updates: Dict) -> Optional[Dict]: ... + async def list_experiments(self, status: Optional[str] = None, limit: int = 50, offset: int = 0) -> List[Dict]: ... + async def delete_experiment(self, exp_id: str) -> bool: ... + + # Stage 操作 + async def update_stage(self, exp_id: str, stage_type: str, updates: Dict) -> Optional[Dict]: ... 
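+    # 补充说明(假设,具体以源文件实现为准): update_stage 仅合并传入的字段;
+    # 按 6.1 节的表结构,stages.config / stages.outputs 以 JSON 文本存储,写入前需序列化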
+ async def get_stage(self, exp_id: str, stage_type: str) -> Optional[Dict]: ... + async def get_all_stages(self, exp_id: str) -> List[Dict]: ... + + # File 记录 + async def create_file_record(self, file_data: Dict) -> Dict: ... + async def get_file_record(self, file_id: str) -> Optional[Dict]: ... + async def delete_file_record(self, file_id: str) -> bool: ... + async def list_file_records(self, purpose: Optional[str] = None, limit: int = 50, offset: int = 0) -> List[Dict]: ... +``` + +##### LocalProgressAdapter ✅ 已实现 + +> **实现文件**: `app/adapters/local/progress.py` +> +> 基于内存队列的进度管理适配器,支持多订阅者模式。 + +```python +# app/adapters/local/progress.py - ✅ 已完整实现 + +class LocalProgressAdapter(ProgressAdapter): + """ + 本地内存进度管理适配器 + + 特点: + 1. 使用内存字典存储最新进度 + 2. 使用 asyncio.Queue 实现订阅者模式 + 3. 支持多订阅者同时订阅同一任务 + 4. 与 AsyncTrainingManager 的进度推送机制兼容 + """ + + async def update_progress(self, task_id: str, progress: Dict) -> None: + """更新进度并通知所有订阅者""" + # ... 完整实现见源文件 + + async def get_progress(self, task_id: str) -> Optional[Dict]: + """获取当前进度""" + # ... 完整实现见源文件 + + async def subscribe(self, task_id: str) -> AsyncGenerator[Dict, None]: + """订阅进度更新(支持心跳、自动清理)""" + # ... 完整实现见源文件 +``` + + +#### 3.2.3 服务器适配器实现 + +```python +# app/adapters/server/storage.py + +from minio import Minio +from app.adapters.base import StorageAdapter + +class S3StorageAdapter(StorageAdapter): + """MinIO/S3对象存储适配器""" + + def __init__(self, endpoint: str, access_key: str, secret_key: str, bucket: str): + self.client = Minio( + endpoint, + access_key=access_key, + secret_key=secret_key, + secure=False + ) + self.bucket = bucket + + # 确保bucket存在 + if not self.client.bucket_exists(bucket): + self.client.make_bucket(bucket) + + async def upload_file(self, file_data: bytes, filename: str, metadata: Dict) -> str: + file_id = str(uuid.uuid4()) + + # 上传文件 + self.client.put_object( + self.bucket, + file_id, + io.BytesIO(file_data), + len(file_data), + metadata=metadata + ) + + return file_id + + # ... 其他方法实现 +``` + + +```python +# app/adapters/server/database.py + +from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession +from app.adapters.base import DatabaseAdapter + +class PostgreSQLAdapter(DatabaseAdapter): + """PostgreSQL数据库适配器""" + + def __init__(self, database_url: str): + self.engine = create_async_engine(database_url) + # 使用SQLAlchemy ORM + + async def create_task(self, task: Task) -> Task: + async with AsyncSession(self.engine) as session: + db_task = TaskModel(**task.dict()) + session.add(db_task) + await session.commit() + await session.refresh(db_task) + return Task.from_orm(db_task) + + # ... 其他方法实现 +``` + + +```python +# app/adapters/server/task_queue.py + +from celery import Celery +from app.adapters.base import TaskQueueAdapter + +class CeleryTaskQueueAdapter(TaskQueueAdapter): + """Celery分布式任务队列""" + + def __init__(self, broker_url: str, backend_url: str): + self.celery_app = Celery( + 'gpt_sovits_training', + broker=broker_url, + backend=backend_url + ) + + async def enqueue(self, task_id: str, config: Dict, priority: str = "normal") -> str: + from app.workers.celery_worker import execute_training_pipeline + + result = execute_training_pipeline.apply_async( + args=[task_id, config], + queue=f'queue_{priority}', + priority=self._get_priority_value(priority) + ) + + return result.id + + async def get_status(self, job_id: str) -> Dict: + result = self.celery_app.AsyncResult(job_id) + return { + "status": result.state, + "info": result.info + } + + # ... 
其他方法实现 +``` + + +```python +# app/adapters/server/progress.py + +import redis.asyncio as redis +from app.adapters.base import ProgressAdapter + +class RedisProgressAdapter(ProgressAdapter): + """Redis进度管理""" + + def __init__(self, redis_url: str): + self.redis = redis.from_url(redis_url) + + async def update_progress(self, task_id: str, progress: Dict): + # 保存到Redis Hash + await self.redis.hset( + f"task:progress:{task_id}", + mapping={ + "data": json.dumps(progress), + "updated_at": time.time() + } + ) + + # 发布到Redis Pub/Sub + await self.redis.publish( + f"task:progress:{task_id}", + json.dumps(progress) + ) + + async def get_progress(self, task_id: str) -> Optional[Dict]: + data = await self.redis.hget(f"task:progress:{task_id}", "data") + if data: + return json.loads(data) + return None + + async def subscribe(self, task_id: str) -> AsyncGenerator[Dict, None]: + pubsub = self.redis.pubsub() + await pubsub.subscribe(f"task:progress:{task_id}") + + try: + async for message in pubsub.listen(): + if message['type'] == 'message': + progress = json.loads(message['data']) + yield progress + + if progress.get('status') in ['completed', 'failed', 'cancelled']: + break + finally: + await pubsub.unsubscribe(f"task:progress:{task_id}") +``` + + +### 3.3 适配器工厂 + +```python +# app/core/adapters.py + +from app.core.config import settings +from app.adapters.base import StorageAdapter, DatabaseAdapter, TaskQueueAdapter, ProgressAdapter + +class AdapterFactory: + """适配器工厂,根据配置创建对应的适配器""" + + @staticmethod + def create_storage_adapter() -> StorageAdapter: + if settings.DEPLOYMENT_MODE == "local": + from app.adapters.local.storage import LocalStorageAdapter + return LocalStorageAdapter(base_path=settings.LOCAL_STORAGE_PATH) + else: + from app.adapters.server.storage import S3StorageAdapter + return S3StorageAdapter( + endpoint=settings.S3_ENDPOINT, + access_key=settings.S3_ACCESS_KEY, + secret_key=settings.S3_SECRET_KEY, + bucket=settings.S3_BUCKET + ) + + @staticmethod + def create_database_adapter() -> DatabaseAdapter: + if settings.DEPLOYMENT_MODE == "local": + from app.adapters.local.database import SQLiteAdapter + return SQLiteAdapter(db_path=settings.SQLITE_PATH) + else: + from app.adapters.server.database import PostgreSQLAdapter + return PostgreSQLAdapter(database_url=settings.DATABASE_URL) + + @staticmethod + def create_task_queue_adapter() -> TaskQueueAdapter: + if settings.DEPLOYMENT_MODE == "local": + from app.adapters.local.task_queue import AsyncTrainingManager + return AsyncTrainingManager(db_path=settings.SQLITE_PATH) + else: + from app.adapters.server.task_queue import CeleryTaskQueueAdapter + return CeleryTaskQueueAdapter( + broker_url=settings.CELERY_BROKER_URL, + backend_url=settings.CELERY_RESULT_BACKEND + ) + + @staticmethod + def create_progress_adapter() -> ProgressAdapter: + if settings.DEPLOYMENT_MODE == "local": + from app.adapters.local.progress import LocalProgressAdapter + return LocalProgressAdapter() + else: + from app.adapters.server.progress import RedisProgressAdapter + return RedisProgressAdapter(redis_url=settings.REDIS_URL) + + +# 全局单例 +storage_adapter = AdapterFactory.create_storage_adapter() +database_adapter = AdapterFactory.create_database_adapter() +task_queue_adapter = AdapterFactory.create_task_queue_adapter() +progress_adapter = AdapterFactory.create_progress_adapter() +``` + + +### 3.4 统一配置管理 + +```python +# app/core/config.py + +from pydantic_settings import BaseSettings +from typing import Literal + +class Settings(BaseSettings): + # 部署模式 + DEPLOYMENT_MODE: 
Literal["local", "server"] = "local" + + # 通用配置 + API_V1_PREFIX: str = "/api/v1" + PROJECT_NAME: str = "GPT-SoVITS Training API" + + # 本地模式配置 + LOCAL_STORAGE_PATH: str = "./data/files" + SQLITE_PATH: str = "./data/app.db" + LOCAL_MAX_WORKERS: int = 1 # 本地同时运行的训练任务数 + + # 服务器模式配置 + DATABASE_URL: str = "postgresql+asyncpg://user:pass@localhost/gpt_sovits" + REDIS_URL: str = "redis://localhost:6379/0" + CELERY_BROKER_URL: str = "redis://localhost:6379/1" + CELERY_RESULT_BACKEND: str = "redis://localhost:6379/2" + + S3_ENDPOINT: str = "localhost:9000" + S3_ACCESS_KEY: str = "minioadmin" + S3_SECRET_KEY: str = "minioadmin" + S3_BUCKET: str = "gpt-sovits" + + class Config: + env_file = ".env" + case_sensitive = True + +settings = Settings() +``` + + +--- + +## 四、统一API接口(无差异) + +无论是本地还是服务器模式,API接口完全一致。 + +### 4.1 API 设计目标 + +针对不同用户群体,提供两套独立的 API 体系: + +| 用户类型 | 需求 | API 模式 | 核心概念 | API 前缀 | +|----------|------|----------|----------|----------| +| **小白用户** | 上传音频即可训练,无需了解细节 | Quick Mode | Task(任务) | `/api/v1/tasks` | +| **专家用户** | 精细控制每个阶段参数,分阶段执行 | Advanced Mode | Experiment(实验)+ Stage(阶段) | `/api/v1/experiments` | + +### 4.2 完整 API 端点列表 + +#### Quick Mode API(小白用户) + +| 方法 | 路径 | 描述 | +|------|------|------| +| `POST` | `/api/v1/tasks` | 创建一键训练任务 | +| `GET` | `/api/v1/tasks` | 获取任务列表 | +| `GET` | `/api/v1/tasks/{task_id}` | 获取任务详情 | +| `DELETE` | `/api/v1/tasks/{task_id}` | 取消任务 | +| `GET` | `/api/v1/tasks/{task_id}/progress` | SSE 进度订阅 | + +#### Advanced Mode API(专家用户) + +| 方法 | 路径 | 描述 | +|------|------|------| +| `POST` | `/api/v1/experiments` | 创建实验(不立即执行) | +| `GET` | `/api/v1/experiments` | 获取实验列表 | +| `GET` | `/api/v1/experiments/{exp_id}` | 获取实验详情 | +| `DELETE` | `/api/v1/experiments/{exp_id}` | 删除实验 | +| `PATCH` | `/api/v1/experiments/{exp_id}` | 更新实验基础配置 | +| `POST` | `/api/v1/experiments/{exp_id}/stages/{stage_type}` | 执行指定阶段 | +| `GET` | `/api/v1/experiments/{exp_id}/stages` | 获取所有阶段状态 | +| `GET` | `/api/v1/experiments/{exp_id}/stages/{stage_type}` | 获取指定阶段状态/结果 | +| `GET` | `/api/v1/experiments/{exp_id}/stages/{stage_type}/progress` | SSE 阶段进度订阅 | +| `DELETE` | `/api/v1/experiments/{exp_id}/stages/{stage_type}` | 取消正在执行的阶段 | + +#### 通用 API + +| 方法 | 路径 | 描述 | +|------|------|------| +| `POST` | `/api/v1/files` | 上传文件 | +| `GET` | `/api/v1/files` | 获取文件列表 | +| `GET` | `/api/v1/files/{file_id}` | 下载文件 | +| `DELETE` | `/api/v1/files/{file_id}` | 删除文件 | +| `GET` | `/api/v1/stages/presets` | 获取阶段预设列表 | +| `GET` | `/api/v1/stages/{stage_type}/schema` | 获取阶段参数模板 | + +--- + +## 4.3 Quick Mode API 详解(小白用户) + +### 4.3.1 创建一键训练任务 + +``` +POST /api/v1/tasks +``` + +只需上传音频文件,系统自动配置所有训练参数并执行完整流程: + +```json +{ + "exp_name": "my_voice", + "audio_file_id": "550e8400-e29b-41d4-a716-446655440000", + "options": { + "version": "v2", + "language": "zh", + "quality": "standard" + } +} +``` + +**参数说明**: + +| 字段 | 类型 | 必填 | 说明 | +|------|------|------|------| +| `exp_name` | string | 是 | 实验名称 | +| `audio_file_id` | string | 是 | 已上传音频文件的 ID | +| `options.version` | string | 否 | 模型版本,默认 `"v2"` | +| `options.language` | string | 否 | 语言,默认 `"zh"` | +| `options.quality` | string | 否 | 训练质量:`"fast"` / `"standard"` / `"high"` | + +**质量预设**: + +| quality | SoVITS epochs | GPT epochs | 训练时长 | +|---------|---------------|------------|----------| +| `fast` | 4 | 8 | ~10分钟 | +| `standard` | 8 | 15 | ~20分钟 | +| `high` | 16 | 30 | ~40分钟 | + +**系统自动执行流程**: + +``` +audio_slice -> asr -> text_feature -> hubert_feature -> semantic_token -> sovits_train -> gpt_train +``` + +**响应示例**: + +```json +{ + "id": 
"task-550e8400-e29b-41d4-a716-446655440000", + "exp_name": "my_voice", + "status": "queued", + "current_stage": null, + "progress": 0.0, + "overall_progress": 0.0, + "created_at": "2024-01-01T10:00:00Z" +} +``` + +### 4.3.2 获取任务状态 + +``` +GET /api/v1/tasks/{task_id} +``` + +**响应示例**: + +```json +{ + "id": "task-550e8400-e29b-41d4-a716-446655440000", + "exp_name": "my_voice", + "status": "running", + "current_stage": "sovits_train", + "progress": 0.45, + "overall_progress": 0.72, + "message": "SoVITS 训练中 Epoch 8/16", + "created_at": "2024-01-01T10:00:00Z", + "started_at": "2024-01-01T10:00:05Z" +} +``` + +### 4.3.3 SSE 进度订阅 + +``` +GET /api/v1/tasks/{task_id}/progress +``` + +返回 SSE 流,实时推送进度更新: + +``` +event: progress +data: {"stage": "sovits_train", "progress": 0.45, "message": "Epoch 8/16"} + +event: progress +data: {"stage": "sovits_train", "progress": 0.50, "message": "Epoch 9/16"} + +event: completed +data: {"status": "completed", "message": "��练完成"} +``` + +--- + +## 4.4 Advanced Mode API 详解(专家用户) + +Advanced Mode 引入**实验(Experiment)**概念,允许前端分阶段调用不同 API 触发训练。 + +### 4.4.1 专家模式交互流程 + +```mermaid +sequenceDiagram + participant Frontend + participant API + participant Pipeline + + Frontend->>API: POST /experiments (创建实验) + API-->>Frontend: {exp_id: "abc123"} + + Frontend->>API: POST /experiments/abc123/stages/audio_slice + API->>Pipeline: 启动音频切片 + Frontend->>API: GET .../audio_slice/progress (SSE) + Pipeline-->>Frontend: 进度更新... + Pipeline-->>Frontend: {status: "completed"} + + Note over Frontend: 用户查看切片结果,调整参数 + + Frontend->>API: POST /experiments/abc123/stages/asr + API->>Pipeline: 启动 ASR + Pipeline-->>Frontend: 进度更新... + + Note over Frontend: 继续后续阶段... +``` + +### 4.4.2 创建实验 + +``` +POST /api/v1/experiments +``` + +创建实验但不立即执行,用户可以逐阶段控制: + +```json +{ + "exp_name": "my_voice_custom", + "version": "v2", + "gpu_numbers": "0", + "is_half": true, + "audio_file_id": "550e8400-e29b-41d4-a716-446655440000" +} +``` + +**参数说明**: + +| 字段 | 类型 | 必填 | 说明 | +|------|------|------|------| +| `exp_name` | string | 是 | 实验名称 | +| `version` | string | 否 | 模型版本,默认 `"v2"` | +| `gpu_numbers` | string | 否 | GPU 编号,默认 `"0"` | +| `is_half` | bool | 否 | 是否使用半精度,默认 `true` | +| `audio_file_id` | string | 是 | 已上传音频文件的 ID | + +**响应示例**: + +```json +{ + "id": "exp-abc123", + "exp_name": "my_voice_custom", + "version": "v2", + "status": "created", + "stages": { + "audio_slice": { "status": "pending" }, + "asr": { "status": "pending" }, + "text_feature": { "status": "pending" }, + "hubert_feature": { "status": "pending" }, + "semantic_token": { "status": "pending" }, + "sovits_train": { "status": "pending" }, + "gpt_train": { "status": "pending" } + }, + "created_at": "2024-01-01T10:00:00Z" +} +``` + +### 4.4.3 执行阶段 + +``` +POST /api/v1/experiments/{exp_id}/stages/{stage_type} +``` + +触发指定阶段执行,可传入阶段特定参数覆盖默认值: + +**可用的阶段类型(stage_type)**: + +| stage_type | 描述 | 依赖阶段 | +|------------|------|----------| +| `audio_slice` | 音频切片 | 无 | +| `asr` | 语音识别 | audio_slice | +| `text_feature` | 文本特征提取 | asr | +| `hubert_feature` | HuBERT 特征提取 | audio_slice | +| `semantic_token` | 语义 token 提取 | hubert_feature | +| `sovits_train` | SoVITS 训练 | text_feature, semantic_token | +| `gpt_train` | GPT 训练 | text_feature, semantic_token | + +**请求示例(执行音频切片)**: + +``` +POST /api/v1/experiments/exp-abc123/stages/audio_slice +``` + +```json +{ + "threshold": -34, + "min_length": 4000, + "min_interval": 300, + "hop_size": 10, + "max_sil_kept": 500 +} +``` + +**请求示例(执行 SoVITS 训练)**: + +``` +POST /api/v1/experiments/exp-abc123/stages/sovits_train 
+``` + +```json +{ + "batch_size": 8, + "total_epoch": 16, + "save_every_epoch": 4, + "pretrained_s2G": "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth", + "pretrained_s2D": "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2D2333k.pth" +} +``` + +**响应示例**: + +```json +{ + "exp_id": "exp-abc123", + "stage_type": "sovits_train", + "status": "running", + "job_id": "job-xyz789", + "config": { + "batch_size": 8, + "total_epoch": 16, + "save_every_epoch": 4 + }, + "started_at": "2024-01-01T10:30:00Z" +} +``` + +### 4.4.4 获取阶段状态 + +``` +GET /api/v1/experiments/{exp_id}/stages/{stage_type} +``` + +**响应示例(已完成)**: + +```json +{ + "stage_type": "sovits_train", + "status": "completed", + "started_at": "2024-01-01T10:30:00Z", + "completed_at": "2024-01-01T11:00:00Z", + "config": { + "batch_size": 8, + "total_epoch": 16, + "save_every_epoch": 4 + }, + "outputs": { + "model_path": "logs/my_voice_custom/sovits_e16.pth", + "metrics": { + "final_loss": 0.023, + "best_epoch": 14 + } + } +} +``` + +**响应示例(运行中)**: + +```json +{ + "stage_type": "sovits_train", + "status": "running", + "started_at": "2024-01-01T10:30:00Z", + "progress": 0.45, + "message": "Epoch 8/16, Loss: 0.034" +} +``` + +### 4.4.5 获取所有阶段状态 + +``` +GET /api/v1/experiments/{exp_id}/stages +``` + +**响应示例**: + +```json +{ + "exp_id": "exp-abc123", + "stages": [ + { + "stage_type": "audio_slice", + "status": "completed", + "completed_at": "2024-01-01T10:05:00Z" + }, + { + "stage_type": "asr", + "status": "completed", + "completed_at": "2024-01-01T10:10:00Z" + }, + { + "stage_type": "text_feature", + "status": "completed", + "completed_at": "2024-01-01T10:12:00Z" + }, + { + "stage_type": "hubert_feature", + "status": "completed", + "completed_at": "2024-01-01T10:20:00Z" + }, + { + "stage_type": "semantic_token", + "status": "completed", + "completed_at": "2024-01-01T10:25:00Z" + }, + { + "stage_type": "sovits_train", + "status": "running", + "started_at": "2024-01-01T10:30:00Z", + "progress": 0.45 + }, + { + "stage_type": "gpt_train", + "status": "pending" + } + ] +} +``` + +### 4.4.6 SSE 阶段进度订阅 + +``` +GET /api/v1/experiments/{exp_id}/stages/{stage_type}/progress +``` + +返回 SSE 流,实时推送阶段进度: + +``` +event: progress +data: {"epoch": 8, "total_epochs": 16, "progress": 0.50, "loss": 0.034} + +event: progress +data: {"epoch": 9, "total_epochs": 16, "progress": 0.56, "loss": 0.031} + +event: checkpoint +data: {"epoch": 8, "model_path": "logs/my_voice/sovits_e8.pth"} + +event: completed +data: {"status": "completed", "final_loss": 0.023} +``` + +### 4.4.7 取消阶段执行 + +``` +DELETE /api/v1/experiments/{exp_id}/stages/{stage_type} +``` + +取消正在执行的阶段: + +**响应示例**: + +```json +{ + "success": true, + "message": "阶段 sovits_train 已取消", + "stage_type": "sovits_train", + "status": "cancelled" +} +``` + +### 4.4.8 重新执行阶段 + +专家用户可以对任意已完成的阶段重新执行(使用新参数): + +``` +POST /api/v1/experiments/{exp_id}/stages/sovits_train +``` + +如果阶段已完成,再次调用会重新执行。响应中会包含 `rerun: true` 标记: + +```json +{ + "exp_id": "exp-abc123", + "stage_type": "sovits_train", + "status": "running", + "rerun": true, + "previous_run": { + "completed_at": "2024-01-01T11:00:00Z", + "outputs": { "model_path": "logs/my_voice/sovits_e16.pth" } + } +} +``` + +--- + +## 4.5 阶段参数模板 API + +### 4.5.1 获取阶段预设列表 + +``` +GET /api/v1/stages/presets +``` + +**响应示例**: + +```json +{ + "presets": [ + { + "id": "full_training", + "name": "完整训练流程", + "description": "包含所有阶段的标准训练", + "stages": ["audio_slice", "asr", "text_feature", "hubert_feature", "semantic_token", "sovits_train", "gpt_train"] + }, + { + "id": 
"retrain_sovits", + "name": "重训 SoVITS", + "description": "跳过预处理,仅重新训练 SoVITS", + "stages": ["sovits_train"] + }, + { + "id": "feature_extraction", + "name": "特征提取", + "description": "仅执行音频切片和特征提取", + "stages": ["audio_slice", "asr", "text_feature", "hubert_feature", "semantic_token"] + } + ] +} +``` + +### 4.5.2 获取阶段参数模板 + +``` +GET /api/v1/stages/{stage_type}/schema +``` + +**响应示例**(`/api/v1/stages/audio_slice/schema`): + +```json +{ + "type": "audio_slice", + "name": "音频切片", + "description": "将长音频切分为短片段", + "parameters": { + "threshold": { + "type": "integer", + "default": -34, + "min": -60, + "max": 0, + "description": "静音检测阈值 (dB)" + }, + "min_length": { + "type": "integer", + "default": 4000, + "min": 1000, + "max": 10000, + "description": "最小切片长度 (ms)" + }, + "min_interval": { + "type": "integer", + "default": 300, + "min": 100, + "max": 1000, + "description": "最小静音间隔 (ms)" + }, + "hop_size": { + "type": "integer", + "default": 10, + "min": 5, + "max": 50, + "description": "检测步长 (ms)" + }, + "max_sil_kept": { + "type": "integer", + "default": 500, + "min": 100, + "max": 2000, + "description": "切片保留的最大静音长度 (ms)" + } + } +} +``` + +**响应示例**(`/api/v1/stages/sovits_train/schema`): + +```json +{ + "type": "sovits_train", + "name": "SoVITS 训练", + "description": "训练 SoVITS 声码器模型", + "parameters": { + "batch_size": { + "type": "integer", + "default": 4, + "min": 1, + "max": 32, + "description": "批次大小,显存不足时减小" + }, + "total_epoch": { + "type": "integer", + "default": 8, + "min": 1, + "max": 100, + "description": "训练总轮数" + }, + "save_every_epoch": { + "type": "integer", + "default": 4, + "min": 1, + "description": "每 N 轮保存一次模型" + }, + "pretrained_s2G": { + "type": "string", + "default": "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth", + "description": "预训练生成器模型路径" + }, + "pretrained_s2D": { + "type": "string", + "default": "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2D2333k.pth", + "description": "预训练判别器模型路径" + } + } +} +``` + +--- + +## 4.6 Pydantic Schema 设计 + +### 4.6.1 Quick Mode Schema + +```python +from typing import Literal, Optional +from pydantic import BaseModel, Field + +class QuickModeOptions(BaseModel): + version: Literal["v1", "v2", "v2Pro", "v3", "v4"] = "v2" + language: str = "zh" + quality: Literal["fast", "standard", "high"] = "standard" + +class QuickModeRequest(BaseModel): + """小白用户一键训练请求""" + exp_name: str = Field(..., min_length=1, max_length=100) + audio_file_id: str + options: QuickModeOptions = QuickModeOptions() +``` + +### 4.6.2 Advanced Mode Schema + +```python +from typing import Literal, Optional, Dict, Any, List +from pydantic import BaseModel, Field +from datetime import datetime + +# ============================================================ +# 实验管理 +# ============================================================ + +class ExperimentCreate(BaseModel): + """创建实验请求""" + exp_name: str = Field(..., min_length=1, max_length=100, description="实验名称") + version: Literal["v1", "v2", "v2Pro", "v3", "v4"] = Field(default="v2", description="模型版本") + gpu_numbers: str = Field(default="0", description="GPU 编号") + is_half: bool = Field(default=True, description="是否使用半精度") + audio_file_id: str = Field(..., description="音频文件 ID") + +class ExperimentUpdate(BaseModel): + """更新实验请求""" + exp_name: Optional[str] = Field(None, min_length=1, max_length=100) + gpu_numbers: Optional[str] = None + is_half: Optional[bool] = None + +class StageStatus(BaseModel): + """阶段状态""" + stage_type: str + status: Literal["pending", "running", "completed", "failed", 
"cancelled"] + progress: Optional[float] = None + message: Optional[str] = None + started_at: Optional[datetime] = None + completed_at: Optional[datetime] = None + config: Optional[Dict[str, Any]] = None + outputs: Optional[Dict[str, Any]] = None + +class ExperimentResponse(BaseModel): + """实验响应""" + id: str + exp_name: str + version: str + status: str + gpu_numbers: str + is_half: bool + audio_file_id: str + stages: Dict[str, StageStatus] + created_at: datetime + updated_at: Optional[datetime] = None + +# ============================================================ +# 阶段执行 +# ============================================================ + +class StageExecuteRequest(BaseModel): + """阶段执行请求基类""" + class Config: + extra = "allow" # 允许额外字段(阶段特定参数) + +class AudioSliceParams(StageExecuteRequest): + """音频切片参数""" + threshold: int = Field(default=-34, ge=-60, le=0, description="静音检测阈值 (dB)") + min_length: int = Field(default=4000, ge=1000, le=10000, description="最小切片长度 (ms)") + min_interval: int = Field(default=300, ge=100, le=1000, description="最小静音间隔 (ms)") + hop_size: int = Field(default=10, ge=5, le=50, description="检测步长 (ms)") + max_sil_kept: int = Field(default=500, ge=100, le=2000, description="保留最大静音长度 (ms)") + +class ASRParams(StageExecuteRequest): + """ASR 参数""" + model: str = Field(default="达摩 ASR (中文)", description="ASR 模型") + language: str = Field(default="zh", description="语言") + +class SoVITSTrainParams(StageExecuteRequest): + """SoVITS 训练参数""" + batch_size: int = Field(default=4, ge=1, le=32, description="批次大小") + total_epoch: int = Field(default=8, ge=1, le=100, description="训练总轮数") + save_every_epoch: int = Field(default=4, ge=1, description="保存间隔") + pretrained_s2G: Optional[str] = Field(None, description="预训练生成器路径") + pretrained_s2D: Optional[str] = Field(None, description="预训练判别器路径") + +class GPTTrainParams(StageExecuteRequest): + """GPT 训练参数""" + batch_size: int = Field(default=4, ge=1, le=32, description="批次大小") + total_epoch: int = Field(default=15, ge=1, le=100, description="训练总轮数") + save_every_epoch: int = Field(default=5, ge=1, description="保存间隔") + pretrained_s1: Optional[str] = Field(None, description="预训练模型路径") + +class StageExecuteResponse(BaseModel): + """阶段执行响应""" + exp_id: str + stage_type: str + status: Literal["running", "queued"] + job_id: str + config: Dict[str, Any] + rerun: bool = False + previous_run: Optional[Dict[str, Any]] = None + started_at: datetime +``` + +### 4.6.3 Task Schema(Quick Mode 响应) + +```python +class TaskResponse(BaseModel): + """任务响应(Quick Mode)""" + id: str = Field(..., description="任务唯一标识") + exp_name: str = Field(..., description="实验名称") + status: Literal["queued", "running", "completed", "failed", "cancelled"] + current_stage: Optional[str] = None + progress: float = Field(default=0.0, ge=0.0, le=1.0, description="当前阶段进度") + overall_progress: float = Field(default=0.0, ge=0.0, le=1.0, description="总体进度") + message: Optional[str] = None + error_message: Optional[str] = None + created_at: Optional[datetime] = None + started_at: Optional[datetime] = None + completed_at: Optional[datetime] = None + + class Config: + from_attributes = True +``` + +--- + +## 4.7 API 实现示例 + +### 4.7.1 Quick Mode API 实现 + +```python +# app/api/v1/endpoints/tasks.py + +from fastapi import APIRouter, HTTPException, Depends +from app.services.task_service import TaskService +from app.models.schemas.task import QuickModeRequest, TaskResponse + +router = APIRouter() + +@router.post("/tasks", response_model=TaskResponse) +async def create_task( + request: 
QuickModeRequest, + task_service: TaskService = Depends(get_task_service) +): + """ + 创建一键训练任务(小白用户) + + 上传音频文件后,系统自动配置参数并执行完整训练流程。 + """ + return await task_service.create_quick_task(request) + +@router.get("/tasks/{task_id}", response_model=TaskResponse) +async def get_task( + task_id: str, + task_service: TaskService = Depends(get_task_service) +): + """获取任务详情""" + task = await task_service.get_task(task_id) + if not task: + raise HTTPException(status_code=404, detail="Task not found") + return task + +@router.delete("/tasks/{task_id}") +async def cancel_task( + task_id: str, + task_service: TaskService = Depends(get_task_service) +): + """取消任务""" + success = await task_service.cancel_task(task_id) + if not success: + raise HTTPException(status_code=404, detail="Task not found or cannot be cancelled") + return {"success": True, "message": "任务已取消"} +``` + +### 4.7.2 Advanced Mode API 实现 + +```python +# app/api/v1/endpoints/experiments.py + +from fastapi import APIRouter, HTTPException, Depends, Body +from typing import Dict, Any +from app.services.experiment_service import ExperimentService +from app.models.schemas.experiment import ( + ExperimentCreate, + ExperimentResponse, + StageExecuteResponse, + StageStatus, +) + +router = APIRouter() + +@router.post("/experiments", response_model=ExperimentResponse) +async def create_experiment( + request: ExperimentCreate, + experiment_service: ExperimentService = Depends(get_experiment_service) +): + """ + 创建实验(专家用户) + + 创建实验但不立即执行,用户可以逐阶段控制训练流程。 + """ + return await experiment_service.create_experiment(request) + +@router.get("/experiments/{exp_id}", response_model=ExperimentResponse) +async def get_experiment( + exp_id: str, + experiment_service: ExperimentService = Depends(get_experiment_service) +): + """获取实验详情""" + experiment = await experiment_service.get_experiment(exp_id) + if not experiment: + raise HTTPException(status_code=404, detail="Experiment not found") + return experiment + +@router.post("/experiments/{exp_id}/stages/{stage_type}", response_model=StageExecuteResponse) +async def execute_stage( + exp_id: str, + stage_type: str, + params: Dict[str, Any] = Body(default={}), + experiment_service: ExperimentService = Depends(get_experiment_service) +): + """ + 执行指定阶段 + + 可传入阶段特定参数覆盖默认值。如果阶段已完成,会重新执行。 + """ + # 验证阶段类型 + valid_stages = ["audio_slice", "asr", "text_feature", "hubert_feature", + "semantic_token", "sovits_train", "gpt_train"] + if stage_type not in valid_stages: + raise HTTPException(status_code=400, detail=f"Invalid stage type: {stage_type}") + + # 检查依赖阶段是否完成 + dependencies = await experiment_service.check_stage_dependencies(exp_id, stage_type) + if not dependencies["satisfied"]: + raise HTTPException( + status_code=400, + detail=f"依赖阶段未完成: {', '.join(dependencies['missing'])}" + ) + + return await experiment_service.execute_stage(exp_id, stage_type, params) + +@router.get("/experiments/{exp_id}/stages", response_model=Dict[str, StageStatus]) +async def get_all_stages( + exp_id: str, + experiment_service: ExperimentService = Depends(get_experiment_service) +): + """获取所有阶段状态""" + return await experiment_service.get_all_stages(exp_id) + +@router.get("/experiments/{exp_id}/stages/{stage_type}", response_model=StageStatus) +async def get_stage( + exp_id: str, + stage_type: str, + experiment_service: ExperimentService = Depends(get_experiment_service) +): + """获取指定阶段状态和结果""" + stage = await experiment_service.get_stage(exp_id, stage_type) + if not stage: + raise HTTPException(status_code=404, detail="Stage not found") + 
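+    # 说明: 返回的 StageStatus 在 completed 状态下包含 config 与 outputs 字段(见 4.4.4 响应示例)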
return stage + +@router.delete("/experiments/{exp_id}/stages/{stage_type}") +async def cancel_stage( + exp_id: str, + stage_type: str, + experiment_service: ExperimentService = Depends(get_experiment_service) +): + """取消正在执行的阶段""" + success = await experiment_service.cancel_stage(exp_id, stage_type) + if not success: + raise HTTPException(status_code=400, detail="Stage not running or cannot be cancelled") + return {"success": True, "message": f"阶段 {stage_type} 已取消"} +``` + +### 4.7.3 服务层实现 + +```python +# app/services/experiment_service.py + +from typing import Dict, Any, Optional +from datetime import datetime +import uuid + +from app.core.adapters import database_adapter, task_queue_adapter +from app.models.schemas.experiment import ExperimentCreate, ExperimentResponse + +# 阶段依赖关系 +STAGE_DEPENDENCIES = { + "audio_slice": [], + "asr": ["audio_slice"], + "text_feature": ["asr"], + "hubert_feature": ["audio_slice"], + "semantic_token": ["hubert_feature"], + "sovits_train": ["text_feature", "semantic_token"], + "gpt_train": ["text_feature", "semantic_token"], +} + +class ExperimentService: + """实验服务(Advanced Mode)""" + + def __init__(self): + self.db = database_adapter + self.queue = task_queue_adapter + + async def create_experiment(self, request: ExperimentCreate) -> ExperimentResponse: + """创建实验""" + exp_id = f"exp-{uuid.uuid4().hex[:8]}" + + # 初始化所有阶段为 pending 状态 + stages = { + stage: {"status": "pending", "config": None, "outputs": None} + for stage in STAGE_DEPENDENCIES.keys() + } + + experiment = { + "id": exp_id, + "exp_name": request.exp_name, + "version": request.version, + "gpu_numbers": request.gpu_numbers, + "is_half": request.is_half, + "audio_file_id": request.audio_file_id, + "status": "created", + "stages": stages, + "created_at": datetime.utcnow(), + } + + await self.db.create_experiment(experiment) + return ExperimentResponse(**experiment) + + async def check_stage_dependencies(self, exp_id: str, stage_type: str) -> Dict: + """检查阶段依赖是否满足""" + experiment = await self.db.get_experiment(exp_id) + dependencies = STAGE_DEPENDENCIES.get(stage_type, []) + + missing = [] + for dep in dependencies: + if experiment["stages"][dep]["status"] != "completed": + missing.append(dep) + + return { + "satisfied": len(missing) == 0, + "missing": missing + } + + async def execute_stage( + self, + exp_id: str, + stage_type: str, + params: Dict[str, Any] + ) -> StageExecuteResponse: + """执行阶段""" + experiment = await self.db.get_experiment(exp_id) + + # 检查是否是重新执行 + current_stage = experiment["stages"][stage_type] + is_rerun = current_stage["status"] == "completed" + previous_run = current_stage if is_rerun else None + + # 构建阶段配置 + stage_config = { + "exp_id": exp_id, + "exp_name": experiment["exp_name"], + "version": experiment["version"], + "gpu_numbers": experiment["gpu_numbers"], + "is_half": experiment["is_half"], + "stage_type": stage_type, + "params": params, + } + + # 加入执行队列 + job_id = await self.queue.enqueue_stage( + exp_id=exp_id, + stage_type=stage_type, + config=stage_config + ) + + # 更新阶段状态 + await self.db.update_stage(exp_id, stage_type, { + "status": "running", + "config": params, + "started_at": datetime.utcnow(), + "job_id": job_id, + }) + + return StageExecuteResponse( + exp_id=exp_id, + stage_type=stage_type, + status="running", + job_id=job_id, + config=params, + rerun=is_rerun, + previous_run=previous_run, + started_at=datetime.utcnow(), + ) +``` + + +--- + +## 五、部署配置 + +### 5.1 本地模式 (macOS) + +**配置文件: config/local.yaml** +```yaml +deployment_mode: local +local_storage_path: 
./data/files +sqlite_path: ./data/app.db +local_max_workers: 1 # macOS单GPU,串行执行 +``` + + +**启动命令**: +```shell script +# 安装依赖 +pip install -r requirements/base.txt -r requirements/local.txt + +# 启动API服务 +uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload + +# 无需额外服务! +``` + + +**docker-compose.local.yml**: +```yaml +version: '3.8' + +services: + api: + build: . + ports: + - "8000:8000" + volumes: + - ./data:/app/data + - ./logs:/app/logs + environment: + - DEPLOYMENT_MODE=local +``` + + +### 5.2 服务器模式 (Linux) + +**配置文件: config/server.yaml** +```yaml +deployment_mode: server +database_url: postgresql+asyncpg://user:pass@postgres/gpt_sovits +redis_url: redis://redis:6379/0 +celery_broker_url: redis://redis:6379/1 +s3_endpoint: minio:9000 +``` + + +**启动命令**: +```shell script +# 使用docker-compose启动所有服务 +docker-compose -f docker-compose.server.yml up -d +``` + + +**docker-compose.server.yml**: +```yaml +version: '3.8' + +services: + api: + build: . + ports: + - "8000:8000" + depends_on: + - postgres + - redis + - minio + environment: + - DEPLOYMENT_MODE=server + - DATABASE_URL=postgresql+asyncpg://user:pass@postgres/gpt_sovits + - REDIS_URL=redis://redis:6379/0 + + celery-worker: + build: . + command: celery -A app.workers.celery_worker worker --loglevel=info --concurrency=2 + depends_on: + - redis + - postgres + environment: + - DEPLOYMENT_MODE=server + deploy: + replicas: 2 # 多个Worker + + postgres: + image: postgres:15 + volumes: + - postgres_data:/var/lib/postgresql/data + environment: + POSTGRES_PASSWORD: password + + redis: + image: redis:7-alpine + + minio: + image: minio/minio + command: server /data --console-address ":9001" + ports: + - "9000:9000" + - "9001:9001" + volumes: + - minio_data:/data + +volumes: + postgres_data: + minio_data: +``` + + +--- + +## 六、数据库方案对比 + +### 6.1 本地模式 - SQLite + +**Schema**: +```sql +-- tasks表(Quick Mode 一键训练任务) +CREATE TABLE tasks ( + id TEXT PRIMARY KEY, + exp_name TEXT NOT NULL, + version TEXT NOT NULL, + status TEXT NOT NULL, + current_stage TEXT, + overall_progress REAL, + config TEXT, -- JSON + created_at TEXT, + started_at TEXT, + completed_at TEXT, + error_message TEXT +); + +-- experiments表(Advanced Mode 实验) +CREATE TABLE experiments ( + id TEXT PRIMARY KEY, + exp_name TEXT NOT NULL, + version TEXT NOT NULL, + exp_root TEXT DEFAULT 'logs', + gpu_numbers TEXT DEFAULT '0', + is_half INTEGER DEFAULT 1, + audio_file_id TEXT NOT NULL, + status TEXT NOT NULL, + created_at TEXT, + updated_at TEXT, + FOREIGN KEY (audio_file_id) REFERENCES files(id) +); + +-- stages表(Advanced Mode 阶段状态) +CREATE TABLE stages ( + id TEXT PRIMARY KEY, + experiment_id TEXT NOT NULL, + stage_type TEXT NOT NULL, + status TEXT DEFAULT 'pending', + progress REAL DEFAULT 0, + message TEXT, + job_id TEXT, + config TEXT, -- JSON + outputs TEXT, -- JSON + started_at TEXT, + completed_at TEXT, + error_message TEXT, + FOREIGN KEY (experiment_id) REFERENCES experiments(id) +); + +-- files表 +CREATE TABLE files ( + id TEXT PRIMARY KEY, + filename TEXT NOT NULL, + storage_path TEXT NOT NULL, + purpose TEXT, + size_bytes INTEGER, + uploaded_at TEXT +); + +-- models表 +CREATE TABLE models ( + id TEXT PRIMARY KEY, + task_id TEXT, + experiment_id TEXT, + exp_name TEXT NOT NULL, + model_type TEXT NOT NULL, + storage_path TEXT NOT NULL, + epoch INTEGER, + created_at TEXT, + FOREIGN KEY (task_id) REFERENCES tasks(id), + FOREIGN KEY (experiment_id) REFERENCES experiments(id) +); + +-- 索引 +CREATE INDEX idx_tasks_status ON tasks(status); +CREATE INDEX idx_experiments_status ON 
experiments(status); +CREATE INDEX idx_stages_experiment ON stages(experiment_id); +CREATE INDEX idx_stages_status ON stages(status); +``` + + +**迁移管理**: 使用简单的版本号文件 + SQL脚本 + +### 6.2 服务器模式 - PostgreSQL + +**使用SQLAlchemy + Alembic**: + +```python +# app/models/db/models.py + +from sqlalchemy import Column, String, Float, JSON, DateTime, Boolean, ForeignKey +from sqlalchemy.orm import relationship +from sqlalchemy.ext.declarative import declarative_base + +Base = declarative_base() + +class TaskModel(Base): + """Quick Mode 任务模型""" + __tablename__ = "tasks" + + id = Column(String, primary_key=True) + exp_name = Column(String, nullable=False, index=True) + version = Column(String, nullable=False) + status = Column(String, nullable=False, index=True) + current_stage = Column(String) + overall_progress = Column(Float) + config = Column(JSON) + created_at = Column(DateTime, index=True) + started_at = Column(DateTime) + completed_at = Column(DateTime) + error_message = Column(String) + + +class ExperimentModel(Base): + """Advanced Mode 实验模型""" + __tablename__ = "experiments" + + id = Column(String, primary_key=True) + exp_name = Column(String, nullable=False, index=True) + version = Column(String, nullable=False) + exp_root = Column(String, default="logs") + gpu_numbers = Column(String, default="0") + is_half = Column(Boolean, default=True) + audio_file_id = Column(String, ForeignKey("files.id"), nullable=False) + status = Column(String, nullable=False, index=True) + created_at = Column(DateTime, index=True) + updated_at = Column(DateTime) + + # 关联 + stages = relationship("StageModel", back_populates="experiment") + + +class StageModel(Base): + """Advanced Mode 阶段模型""" + __tablename__ = "stages" + + id = Column(String, primary_key=True) + experiment_id = Column(String, ForeignKey("experiments.id"), nullable=False) + stage_type = Column(String, nullable=False) + status = Column(String, default="pending", index=True) + progress = Column(Float, default=0) + message = Column(String) + job_id = Column(String) + config = Column(JSON) + outputs = Column(JSON) + started_at = Column(DateTime) + completed_at = Column(DateTime) + error_message = Column(String) + + # 关联 + experiment = relationship("ExperimentModel", back_populates="stages") +``` + + +**迁移**: `alembic upgrade head` + +--- + +## 七、任务队列方案对比 + +### 7.0 关键发现:训练Pipeline的执行模型 + +> [!IMPORTANT] +> **训练任务实际上是通过子进程执行的!** +> +> 分析 `training_pipeline/stages/training.py` 发现,每个训练阶段都通过 `subprocess.Popen` 调用独立的Python脚本: +> ```python +> cmd = f'PYTHONPATH=.:GPT_SoVITS "{cfg.python_exec}" -s GPT_SoVITS/s2_train.py --config "{tmp_config_path}"' +> self._process = self._run_command(cmd, wait=True) +> ``` + +**这意味着**: +1. GPU密集型训练计算发生在**独立的子进程**中,不受Python GIL限制 +2. FastAPI主进程仅需要"管理"这些子进程:启动、监控、停止 +3. ThreadPoolExecutor在这里只是一个"监工",等待阻塞的subprocess调用完成 +4. 
更优雅的方案是使用 `asyncio.subprocess`,完全非阻塞 + +**进程模型图**: +``` +┌─────────────────────────────────────────────────────────────────┐ +│ FastAPI 主进程 │ +│ ┌──────────────────┐ ┌──────────────────────────────────┐ │ +│ │ AsyncIO Event │ │ AsyncTrainingManager │ │ +│ │ Loop │◄───│ - 管理子进程生命周期 │ │ +│ │ │ │ - 异步读取stdout/stderr │ │ +│ │ │ │ - 推送进度到SSE │ │ +│ └──────────────────┘ └───────────────┬──────────────────┘ │ +└─────────────────────────────────────────────┼───────────────────┘ + │ asyncio.create_subprocess_exec() + ┌─────────────────────────┼─────────────────────────┐ + ▼ ▼ ▼ + ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ + │ s2_train.py │ │ s1_train.py │ │ inference.py │ + │ (GPU训练) │ │ (GPU训练) │ │ (推理) │ + └──────────────┘ └──────────────┘ └──────────────┘ +``` + +### 7.0.1 进度追踪能力分析 + +分析 `GPT_SoVITS/s2_train.py` 发现,训练脚本的输出格式如下: + +| 输出类型 | 输出位置 | 示例 | 可追踪性 | +|---------|---------|------|---------| +| **Epoch进度** | logger → stdout | `"====> Epoch: 5"` | ✅ 可解析 | +| **训练百分比** | logger → stdout | `"Train Epoch: 1 [50.0%]"` | ✅ 可解析 | +| **Loss信息** | logger → stdout | `[0.23, 0.45, ...]` | ✅ 可解析 | +| **Batch进度条** | tqdm → stderr | `45%|████▌ | 45/100` | ⚠️ 格式不规则 | +| **模型保存** | logger → stdout | `"saving ckpt xxx_e5:..."` | ✅ 可解析 | + +**当前问题**: +1. ❌ 输出不是JSON格式,需要正则表达式解析 +2. ❌ tqdm进度条格式复杂,难以精确解析 +3. ❌ 没有统一的进度通信协议 + +**解决方案**:修改训练脚本,添加JSON格式的进度输出 + +```python +# 在训练脚本中添加进度报告函数 +import json +import sys + +def report_progress(stage: str, epoch: int, total_epochs: int, + batch: int = None, total_batches: int = None, + loss: dict = None, message: str = None): + """输出JSON格式的进度信息到stdout,供管理器解析""" + progress_info = { + "type": "progress", + "stage": stage, + "epoch": epoch, + "total_epochs": total_epochs, + "progress": epoch / total_epochs * 100, + } + if batch is not None: + progress_info["batch"] = batch + progress_info["total_batches"] = total_batches + progress_info["progress"] = (epoch - 1 + batch / total_batches) / total_epochs * 100 + if loss: + progress_info["loss"] = loss + if message: + progress_info["message"] = message + + # 使用特殊前缀标识,便于解析 + print(f"##PROGRESS##{json.dumps(progress_info)}##", flush=True) + +# 在训练循环中调用 +for epoch in range(epoch_str, hps.train.epochs + 1): + report_progress("SoVITS训练", epoch, hps.train.epochs, message=f"开始Epoch {epoch}") + for batch_idx, data in enumerate(train_loader): + # ... 训练代码 ... 
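+        # 说明: report_progress 内部以 print(..., flush=True) 输出 ##PROGRESS## 标记行,
+        # 管理器端可逐行及时读取;此处按固定 batch 间隔节流上报,避免 stdout 刷屏并降低解析开销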
+ if batch_idx % 10 == 0: # 每10个batch报告一次 + report_progress("SoVITS训练", epoch, hps.train.epochs, + batch_idx, len(train_loader), + loss={"g_total": loss_gen_all.item()}) +``` + +**管理器端解析**: + +```python +async def _monitor_process_output(self, task_id: str, process): + """解析子进程输出获取进度""" + async for line in process.stdout: + text = line.decode().strip() + + # 检测JSON进度标记 + if text.startswith("##PROGRESS##") and text.endswith("##"): + json_str = text[12:-2] # 提取JSON部分 + progress_info = json.loads(json_str) + await self._send_progress(task_id, progress_info) + + # 兼容旧格式:正则解析 + elif "Train Epoch:" in text: + match = re.search(r"Train Epoch: (\d+) \[(\d+\.?\d*)%\]", text) + if match: + epoch, percent = match.groups() + await self._send_progress(task_id, { + "stage": "SoVITS训练", + "epoch": int(epoch), + "progress": float(percent), + "message": text + }) +``` + +--- + +### 7.0.2 任务控制能力分析 + +| 操作 | 实现方式 | macOS支持 | 备注 | +|------|---------|-----------|------| +| **终止(Kill)** | `process.terminate()` | ✅ 完全支持 | 立即终止,可能丢失当前epoch | +| **强制终止** | `process.kill()` | ✅ 完全支持 | 发送SIGKILL,强制停止 | +| **暂停(Pause)** | `os.kill(pid, signal.SIGSTOP)` | ⚠️ 支持但有风险 | GPU/CUDA状态可能异常 | +| **恢复(Resume)** | `os.kill(pid, signal.SIGCONT)` | ⚠️ 需配合SIGSTOP | 同上 | +| **优雅停止** | 需要训练脚本配合 | ❌ 当前不支持 | 需要修改训练脚本 | + +**优雅停止方案**: + +需要修改训练脚本以支持信号处理: + +```python +# 在训练脚本开头添加 +import signal +import json + +should_stop = False +should_pause = False + +def handle_stop_signal(signum, frame): + """收到SIGUSR1时,完成当前epoch后停止""" + global should_stop + should_stop = True + print(json.dumps({"type": "signal", "message": "收到停止信号,将在当前epoch结束后停止"})) + +def handle_pause_signal(signum, frame): + """收到SIGUSR2时,暂停训练""" + global should_pause + should_pause = not should_pause + status = "暂停" if should_pause else "继续" + print(json.dumps({"type": "signal", "message": f"训练已{status}"})) + +signal.signal(signal.SIGUSR1, handle_stop_signal) +signal.signal(signal.SIGUSR2, handle_pause_signal) + +# 在训练循环中检查 +for epoch in range(epoch_str, hps.train.epochs + 1): + # 检查暂停 + while should_pause: + time.sleep(1) + + # 检查停止 + if should_stop: + print(json.dumps({"type": "progress", "status": "stopped", + "message": f"训练在Epoch {epoch}结束后停止"})) + # 保存checkpoint + save_checkpoint(...) + break + + # ... 正常训练 ... 
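+
+# 可移植性说明(文档补充): SIGUSR1/SIGUSR2 仅在 POSIX 系统(macOS/Linux)上可用,Windows 不支持;
+# 若需跨平台,可改用轮询控制文件或 stdin 命令等方式实现同等的"优雅停止/暂停"语义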
+```
+
+**管理器端控制**:
+
+```python
+class AsyncTrainingManager:
+    async def pause(self, task_id: str) -> bool:
+        """暂停任务"""
+        if task_id in self.running_processes:
+            process = self.running_processes[task_id]
+            os.kill(process.pid, signal.SIGUSR2)
+            return True
+        return False
+
+    async def graceful_stop(self, task_id: str) -> bool:
+        """优雅停止(完成当前epoch后停止)"""
+        if task_id in self.running_processes:
+            process = self.running_processes[task_id]
+            os.kill(process.pid, signal.SIGUSR1)
+            return True
+        return False
+
+    async def force_stop(self, task_id: str) -> bool:
+        """强制停止"""
+        if task_id in self.running_processes:
+            process = self.running_processes[task_id]
+            process.terminate()
+            try:
+                await asyncio.wait_for(process.wait(), timeout=5.0)
+            except asyncio.TimeoutError:
+                process.kill()
+            return True
+        return False
+```
+
+> [!WARNING]
+> **暂停训练的风险**:
+> - macOS上使用SIGSTOP/SIGCONT暂停进程可能导致GPU资源锁定
+> - 长时间暂停后恢复,CUDA上下文可能失效
+> - 推荐使用:保存checkpoint后终止,需要时从checkpoint恢复
+
+---
+
+### 7.1 本地模式 - 任务管理方案 ✅ 已实现
+
+> [!TIP]
+> 选择任务管理方案时,需要考虑:
+> - **执行模型**:训练已经是子进程,任务管理器只需监控
+> - **交付形态**:PyInstaller打包需要单主进程
+> - **简洁性**:asyncio.subprocess 比 ThreadPool 更简洁
+
+#### Option 1: asyncio.subprocess ⭐⭐ 推荐(所有场景)✅ 已选用并实现
+
+> **实现文件**: `app/adapters/local/task_queue.py`
+
+**核心设计思想**:
+- 利用 `asyncio.create_subprocess_exec()` 异步启动训练子进程
+- 完全非阻塞,与 FastAPI 的异步模型完美契合
+- 无需 ThreadPool,架构更简洁
+- 异步读取子进程输出,实时解析进度
+
+```python
+# 优点:
+- 纯asyncio,与FastAPI完美集成
+- 无需ThreadPool,无线程管理开销
+- 异步监控多个子进程
+- 更简洁的代码结构
+- 完全兼容PyInstaller打包
+
+# 缺点:
+- 需要修改Pipeline执行方式(从同步改为异步)
+- 进度解析需要从stdout/stderr提取
+```
+
+**完整实现**:
+
+```python
+# app/adapters/local/task_queue.py
+
+import asyncio
+import json
+import os
+import sys
+import uuid
+from datetime import datetime
+from typing import Dict, Optional, AsyncGenerator, List
+from pathlib import Path
+import aiosqlite
+
+from app.adapters.base import TaskQueueAdapter
+
+
+class AsyncTrainingManager(TaskQueueAdapter):
+    """
+    基于asyncio.subprocess的异步任务管理器。
+
+    特点:
+    1. 使用asyncio.create_subprocess_exec()异步启动训练子进程
+    2. 完全非阻塞,与FastAPI异步模型完美契合
+    3. SQLite持久化任务状态,支持应用重启后恢复
+    4. 
+    """
+
+    def __init__(self, db_path: str = "./data/tasks.db"):
+        self.db_path = db_path
+
+        # 运行时状态
+        self.running_processes: Dict[str, asyncio.subprocess.Process] = {}  # task_id -> Process
+        self.progress_channels: Dict[str, asyncio.Queue] = {}               # task_id -> Queue
+        self._bg_tasks: set = set()  # 持有后台Task引用,防止被垃圾回收
+
+        # 初始化数据库
+        self._init_db_sync()
+
+    def _init_db_sync(self):
+        """同步初始化数据库(启动时调用)"""
+        import sqlite3
+        Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)
+
+        with sqlite3.connect(self.db_path) as conn:
+            conn.execute('''
+                CREATE TABLE IF NOT EXISTS task_queue (
+                    job_id TEXT PRIMARY KEY,
+                    task_id TEXT NOT NULL,
+                    config TEXT NOT NULL,
+                    status TEXT DEFAULT 'queued',
+                    current_stage TEXT,
+                    progress REAL DEFAULT 0,
+                    created_at TEXT,
+                    started_at TEXT,
+                    completed_at TEXT,
+                    error_message TEXT
+                )
+            ''')
+            conn.execute('CREATE INDEX IF NOT EXISTS idx_task_queue_status ON task_queue(status)')
+            conn.commit()
+
+    async def enqueue(self, task_id: str, config: Dict, priority: str = "normal") -> str:
+        """将任务加入队列并异步启动"""
+        job_id = str(uuid.uuid4())
+
+        # 持久化到SQLite
+        async with aiosqlite.connect(self.db_path) as db:
+            await db.execute(
+                '''INSERT INTO task_queue (job_id, task_id, config, status, created_at)
+                   VALUES (?, ?, ?, 'queued', ?)''',
+                (job_id, task_id, json.dumps(config), datetime.utcnow().isoformat())
+            )
+            await db.commit()
+
+        # 创建进度队列
+        self.progress_channels[task_id] = asyncio.Queue()
+
+        # 异步启动训练任务(保存Task引用,避免被垃圾回收)
+        task = asyncio.create_task(self._run_training_async(job_id, task_id, config))
+        self._bg_tasks.add(task)
+        task.add_done_callback(self._bg_tasks.discard)
+
+        return job_id
+
+    async def _run_training_async(self, job_id: str, task_id: str, config: Dict):
+        """异步执行训练Pipeline"""
+        try:
+            await self._update_status(job_id, 'running', started_at=datetime.utcnow().isoformat())
+            await self._send_progress(task_id, {"status": "running", "message": "训练启动中..."})
+
+            # 构建训练脚本命令
+            # 这里调用一个包装脚本,它会执行完整的Pipeline并输出JSON格式的进度
+            script_path = self._get_pipeline_script_path()
+            config_path = await self._write_config_file(task_id, config)
+
+            # 创建子进程
+            process = await asyncio.create_subprocess_exec(
+                sys.executable, script_path,
+                '--config', config_path,
+                '--task-id', task_id,
+                stdout=asyncio.subprocess.PIPE,
+                stderr=asyncio.subprocess.PIPE,
+                env={**os.environ, 'PYTHONPATH': '.:GPT_SoVITS'}
+            )
+
+            self.running_processes[task_id] = process
+
+            # 异步读取stdout/stderr并解析进度,同时收集stderr尾部供错误报告使用
+            stderr_lines = await self._monitor_process_output(task_id, process)
+
+            # 等待进程完成
+            returncode = await process.wait()
+
+            if returncode == 0:
+                await self._update_status(job_id, 'completed', completed_at=datetime.utcnow().isoformat())
+                await self._send_progress(task_id, {"status": "completed", "progress": 100, "message": "训练完成"})
+            else:
+                # 注意:stderr已在_monitor_process_output中读尽,不能再次read(),
+                # 因此这里使用监控过程中收集到的stderr尾部构造错误信息
+                error_msg = "\n".join(stderr_lines[-20:]) or f"Process exited with code {returncode}"
+                await self._update_status(job_id, 'failed', error_message=error_msg)
+                await self._send_progress(task_id, {"status": "failed", "error": error_msg})
+
+        except asyncio.CancelledError:
+            await self._update_status(job_id, 'cancelled')
+            await self._send_progress(task_id, {"status": "cancelled", "message": "任务已取消"})
+        except Exception as e:
+            await self._update_status(job_id, 'failed', error_message=str(e))
+            await self._send_progress(task_id, {"status": "failed", "error": str(e)})
+        finally:
+            self.running_processes.pop(task_id, None)
+            # 清理临时配置文件
+            await self._cleanup_config_file(task_id)
+
+    async def _monitor_process_output(self, task_id: str, process: asyncio.subprocess.Process) -> List[str]:
+        """异步监控子进程输出并解析进度,返回监控期间收集到的stderr行"""
+        stderr_lines: List[str] = []
+
+        async def read_stream(stream, is_stderr=False):
+            while True:
+                line = await stream.readline()
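+                # readline()在子进程端关闭管道(EOF)后返回空bytes,以此作为退出条件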
+                if not line:
+                    break
+
+                text = line.decode().strip()
+                if not text:
+                    continue
+
+                # 尝试解析JSON格式的进度信息
+                if text.startswith('{') and text.endswith('}'):
+                    try:
+                        progress_info = json.loads(text)
+                        await self._send_progress(task_id, progress_info)
+
+                        # 同时更新数据库中的进度
+                        if 'progress' in progress_info or 'stage' in progress_info:
+                            await self._update_progress_in_db(task_id, progress_info)
+                    except json.JSONDecodeError:
+                        pass
+                elif is_stderr:
+                    # stderr输出作为日志,同时收集尾部供失败时构造错误信息
+                    stderr_lines.append(text)
+                    await self._send_progress(task_id, {"type": "log", "level": "error", "message": text})
+
+        # 并发读取stdout和stderr
+        await asyncio.gather(
+            read_stream(process.stdout, is_stderr=False),
+            read_stream(process.stderr, is_stderr=True)
+        )
+        return stderr_lines
+
+    async def _send_progress(self, task_id: str, progress_info: Dict):
+        """发送进度到订阅队列"""
+        if task_id in self.progress_channels:
+            await self.progress_channels[task_id].put(progress_info)
+
+    async def _update_status(self, job_id: str, status: str, **kwargs):
+        """更新任务状态"""
+        async with aiosqlite.connect(self.db_path) as db:
+            updates = ["status = ?"]
+            values = [status]
+
+            for key, value in kwargs.items():
+                updates.append(f"{key} = ?")
+                values.append(value)
+
+            values.append(job_id)
+            await db.execute(
+                f"UPDATE task_queue SET {', '.join(updates)} WHERE job_id = ?",
+                values
+            )
+            await db.commit()
+
+    async def _update_progress_in_db(self, task_id: str, progress_info: Dict):
+        """更新数据库中的进度"""
+        async with aiosqlite.connect(self.db_path) as db:
+            updates = []
+            values = []
+
+            if 'progress' in progress_info:
+                updates.append("progress = ?")
+                values.append(progress_info['progress'])
+            if 'stage' in progress_info:
+                updates.append("current_stage = ?")
+                values.append(progress_info['stage'])
+
+            if updates:
+                values.append(task_id)
+                await db.execute(
+                    f"UPDATE task_queue SET {', '.join(updates)} WHERE task_id = ?",
+                    values
+                )
+                await db.commit()
+
+    async def get_status(self, job_id: str) -> Dict:
+        """获取任务状态"""
+        async with aiosqlite.connect(self.db_path) as db:
+            db.row_factory = aiosqlite.Row
+            async with db.execute(
+                "SELECT * FROM task_queue WHERE job_id = ?", (job_id,)
+            ) as cursor:
+                row = await cursor.fetchone()
+                if row:
+                    return dict(row)
+                return {"status": "not_found"}
+
+    async def cancel(self, job_id: str) -> bool:
+        """取消任务"""
+        # 查找task_id
+        async with aiosqlite.connect(self.db_path) as db:
+            async with db.execute(
+                "SELECT task_id FROM task_queue WHERE job_id = ?", (job_id,)
+            ) as cursor:
+                row = await cursor.fetchone()
+                if not row:
+                    return False
+                task_id = row[0]
+
+        # 终止进程
+        if task_id in self.running_processes:
+            process = self.running_processes[task_id]
+            process.terminate()
+
+            # 等待进程终止
+            try:
+                await asyncio.wait_for(process.wait(), timeout=5.0)
+            except asyncio.TimeoutError:
+                process.kill()
+
+            await self._update_status(job_id, 'cancelled')
+            return True
+
+        return False
+
+    async def subscribe_progress(self, task_id: str) -> AsyncGenerator[Dict, None]:
+        """订阅任务进度(SSE流)"""
+        if task_id not in self.progress_channels:
+            self.progress_channels[task_id] = asyncio.Queue()
+
+        queue = self.progress_channels[task_id]
+
+        while True:
+            try:
+                progress = await asyncio.wait_for(queue.get(), timeout=30.0)
+                yield progress
+
+                if progress.get('status') in ['completed', 'failed', 'cancelled']:
+                    break
+            except asyncio.TimeoutError:
+                # 发送心跳保持连接
+                yield {"type": "heartbeat", "timestamp": datetime.utcnow().isoformat()}
+
+    async def recover_pending_tasks(self) -> int:
+        """
+        应用重启后恢复未完成的任务。
+
+        注意:由于子进程在应用重启后已经终止,这里只能:
+        1. 将running状态的任务标记为interrupted
+        2. 可选择重新启动queued状态的任务
+        """
+        async with aiosqlite.connect(self.db_path) as db:
+            # 将running状态的任务标记为interrupted(需要用户决定是否重试)
+            await db.execute(
+                "UPDATE task_queue SET status = 'interrupted' WHERE status = 'running'"
+            )
+            await db.commit()
+
+            # 重新启动queued状态的任务
+            db.row_factory = aiosqlite.Row
+            async with db.execute(
+                "SELECT * FROM task_queue WHERE status = 'queued' ORDER BY created_at"
+            ) as cursor:
+                queued_tasks = await cursor.fetchall()
+
+        for task in queued_tasks:
+            task_id = task['task_id']
+            config = json.loads(task['config'])
+            job_id = task['job_id']
+
+            self.progress_channels[task_id] = asyncio.Queue()
+            task_obj = asyncio.create_task(self._run_training_async(job_id, task_id, config))
+            self._bg_tasks.add(task_obj)  # 保持引用,防止Task被垃圾回收
+            task_obj.add_done_callback(self._bg_tasks.discard)
+
+        return len(queued_tasks)
+
+    def _get_pipeline_script_path(self) -> str:
+        """获取Pipeline执行脚本路径"""
+        # 这个脚本会封装TrainingPipeline,并输出JSON格式的进度
+        return os.path.join(os.path.dirname(__file__), '..', '..', 'scripts', 'run_pipeline.py')
+
+    async def _write_config_file(self, task_id: str, config: Dict) -> str:
+        """写入临时配置文件"""
+        config_dir = Path(self.db_path).parent / 'configs'
+        config_dir.mkdir(parents=True, exist_ok=True)
+        config_path = config_dir / f"{task_id}.json"
+
+        with open(config_path, 'w') as f:
+            json.dump(config, f)
+
+        return str(config_path)
+
+    async def _cleanup_config_file(self, task_id: str):
+        """清理临时配置文件"""
+        config_path = Path(self.db_path).parent / 'configs' / f"{task_id}.json"
+        if config_path.exists():
+            config_path.unlink()
+```
+
+
+#### Option 2: ThreadPoolExecutor + SQLite持久化(备选方案)
+
+如果不想修改现有的Pipeline执行方式,可以继续使用ThreadPool包装同步调用:
+
+```python
+# 优点:
+- 无需修改现有Pipeline代码
+- 标准库,依赖极少
+- 实现简单
+
+# 缺点:
+- ThreadPool线程仅用于等待阻塞的subprocess
+- 资源利用不够优雅
+- 不是真正的异步
+```
+
+> [!NOTE]
+> 此方案使用 `concurrent.futures.ThreadPoolExecutor` 将同步的 subprocess 调用包装为异步操作。
+> 虽然功能可行,但与 asyncio.subprocess 相比增加了不必要的线程开销。
+
+```python
+# 简易实现逻辑
+import uuid
+from concurrent.futures import ThreadPoolExecutor
+
+class ThreadPoolAdapter(TaskQueueAdapter):
+    def __init__(self):
+        self.executor = ThreadPoolExecutor(max_workers=1)
+
+    async def enqueue(self, task_id, config, priority="normal"):
+        job_id = str(uuid.uuid4())
+        # 在线程中执行同步的 run_pipeline
+        self.executor.submit(self._run_sync, task_id, config)
+        return job_id
+
+    def _run_sync(self, task_id, config):
+        # 同步执行 Pipeline
+        pipeline = TrainingPipeline(config)
+        pipeline.run()
+```
+
+
+#### Option 3: Huey(仅适合开发模式,不推荐用于PyInstaller打包)
+
+> [!WARNING]
+> Huey需要独立的consumer进程,**不适合**PyInstaller打包和Electron集成场景。
+> 仅在纯Python开发模式下使用。
+
+```python
+# 安装
+pip install huey
+
+# 配置
+from huey import SqliteHuey
+
+huey = SqliteHuey('gpt_sovits', filename='./data/tasks.db')
+
+@huey.task()
+def execute_training_pipeline(task_id, config):
+    # 执行训练
+    pass
+
+# 优点:
+- 轻量级(~1000行代码)
+- 支持SQLite后端(持久化)
+- 支持任务重试、定时任务
+- 支持优先级队列
+- 无需额外服务
+
+# 缺点:
+- 需要独立的huey_consumer进程
+- 不兼容PyInstaller单文件打包
+- 功能不如Celery丰富
+- 社区较小
+```
+
+
+---
+
+### 7.2 服务器模式 - Celery [Phase 2]
+
+> **注意**: 此部分为 Phase 2 服务器模式的设计,当前阶段优先实现本地模式。
+
+```python
+# app/workers/celery_worker.py
+
+from celery import Celery
+from app.core.config import settings
+
+celery_app = Celery(
+    'gpt_sovits',
+    broker=settings.CELERY_BROKER_URL,
+    backend=settings.CELERY_RESULT_BACKEND
+)
+
+celery_app.conf.update(
+    task_serializer='json',
+    accept_content=['json'],
+    result_serializer='json',
+    timezone='UTC',
+    task_routes={
+        'app.workers.celery_worker.execute_training_pipeline': {'queue': 'training'},
+        'app.workers.celery_worker.execute_inference': {'queue': 'inference'}
+    }
+)
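+
+# 长耗时训练任务的补充配置(示意):task_acks_late与worker_prefetch_multiplier均为Celery内置配置项
+celery_app.conf.update(
+    task_acks_late=True,           # 任务执行完成后才ACK,worker异常退出时任务可被重新投递
+    worker_prefetch_multiplier=1,  # 每个worker一次只预取一个任务,适合长耗时训练
+)
+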
+@celery_app.task(bind=True, max_retries=3)
+def execute_training_pipeline(self, task_id: str, config: dict):
+    """执行训练Pipeline(与Huey版本类似)"""
+    # 实现逻辑同上
+    pass
+```
+
+
+---
+
+## 八、完整对比表
+
+| 维度 | 本地开发模式 (macOS) | PyInstaller/Electron模式 | 服务器模式 (Linux) |
+|------|---------------------|--------------------------|-------------------|
+| **数据库** | SQLite (单文件) | SQLite (单文件) | PostgreSQL (集群) |
+| **任务管理** | asyncio.subprocess ⭐ | asyncio.subprocess ⭐ | Celery + Redis |
+| **执行模型** | 子进程(s2_train.py等) | 子进程(s2_train.py等) | 分布式Worker |
+| **文件存储** | 本地文件系统 | 本地文件系统 | MinIO/S3 |
+| **进度管理** | stdout解析 + asyncio.Queue | stdout解析 + asyncio.Queue | Redis Pub/Sub |
+| **并发能力** | 1-2个任务 | 1个任务(串行) | 可水平扩展 |
+| **依赖服务** | 0(all-in-one) | 0(all-in-one) | 3+ (PostgreSQL, Redis, MinIO) |
+| **启动命令** | `uvicorn app.main:app` | Electron启动Python子进程 | `docker-compose up` |
+| **适用场景** | 开发调试 | 桌面应用分发 | 生产环境、多用户 |
+| **部署复杂度** | ⭐ | ⭐⭐ | ⭐⭐⭐⭐ |
+| **打包支持** | 不需要 | PyInstaller单文件 | Docker镜像 |
+| **维护成本** | 低 | 低 | 中等 |
+
+---
+
+## 九、推荐实现路径
+
+### Phase 1: 本地模式MVP
+
+#### 1.1 架构设计与 Schema 定义 ✅ 已完成
+
+| 任务 | 状态 | 说明 |
+|------|------|------|
+| API 架构设计 | ✅ 完成 | 双模式设计(Quick Mode + Advanced Mode) |
+| Pydantic Schema 设计 | ✅ 完成 | development.md 中完整定义 |
+| 数据库 Schema 设计 | ✅ 完成 | tasks, experiments, stages 表结构 |
+| 阶段参数 Schema 设计 | ✅ 完成 | AudioSliceParams, SoVITSTrainParams 等 |
+
+#### 1.2 核心基础设施 ✅ 已完成
+
+| 任务 | 状态 | 实现文件 |
+|------|------|----------|
+| 适配器抽象基类 | ✅ 完成 | `app/adapters/base.py` - TaskQueueAdapter, ProgressAdapter |
+| AsyncTrainingManager | ✅ 完成 | `app/adapters/local/task_queue.py` - 完整实现 |
+| 配置管理模块 | ✅ 完成 | `app/core/config.py` - Settings, 路径常量 |
+| 领域模型 | ✅ 完成 | `app/models/domain.py` - Task, TaskStatus, ProgressInfo |
+| Pipeline 包装脚本 | ✅ 完成 | `app/scripts/run_pipeline.py` - 子进程执行器 |
+
+**AsyncTrainingManager 已实现功能:**
+- ✅ 任务入队与异步执行 (`enqueue`)
+- ✅ 子进程管理 (`asyncio.create_subprocess_exec`)
+- ✅ 进度解析与推送 (`_monitor_process_output`)
+- ✅ 任务状态查询 (`get_status`, `get_status_by_task_id`)
+- ✅ 任务取消 (`cancel`)
+- ✅ 进度订阅 SSE 流 (`subscribe_progress`)
+- ✅ 任务列表查询 (`list_tasks`)
+- ✅ 任务恢复机制 (`recover_pending_tasks`)
+- ✅ 旧任务清理 (`cleanup_old_tasks`)
+
+#### 1.3 Pydantic Schema 文件 ✅ 已完成
+
+| 任务 | 状态 | 说明 |
+|------|------|------|
+| `app/models/schemas/common.py` | ✅ 完成 | SuccessResponse, ErrorResponse, PaginatedResponse |
+| `app/models/schemas/task.py` | ✅ 完成 | QuickModeOptions, QuickModeRequest, TaskResponse, TaskListResponse |
+| `app/models/schemas/experiment.py` | ✅ 完成 | ExperimentCreate, StageStatus, 各阶段参数类等 |
+| `app/models/schemas/file.py` | ✅ 完成 | FileMetadata, FileUploadResponse, FileListResponse |
+
+#### 1.4 存储与数据库适配器 ✅ 已完成
+
+| 任务 | 状态 | 说明 |
+|------|------|------|
+| StorageAdapter 抽象类 | ✅ 完成 | `app/adapters/base.py` - 文件存储接口 |
+| DatabaseAdapter 抽象类 | ✅ 完成 | `app/adapters/base.py` - 数据库操作接口 |
+| LocalStorageAdapter | ✅ 完成 | `app/adapters/local/storage.py` - 本地文件系统存储 |
+| SQLiteAdapter | ✅ 完成 | `app/adapters/local/database.py` - SQLite 数据库适配器 |
+| LocalProgressAdapter | ✅ 完成 | `app/adapters/local/progress.py` - 内存进度管理 |
+
+**LocalStorageAdapter 已实现功能:**
+- ✅ 文件上传/下载 (`upload_file`, `download_file`)
+- ✅ 文件删除 (`delete_file`)
+- ✅ 元数据管理 (`.meta.json` 文件)
+- ✅ 文件列表查询 (`list_files`)
+- ✅ 音频信息提取(时长、采样率)
+
+**SQLiteAdapter 已实现功能:**
+- ✅ Task CRUD (Quick Mode)
+- ✅ Experiment CRUD (Advanced Mode)
+- ✅ Stage 状态管理
+- ✅ File 记录管理
+- ✅ 自动表结构初始化
+
+**LocalProgressAdapter 已实现功能:**
+- ✅ 进度更新与存储 (`update_progress`)
+- ✅ 订阅者模式 (`subscribe`,核心思路见下方示意)
+- ✅ 多订阅者支持
+- ✅ 心跳机制
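+
+LocalProgressAdapter 的订阅者模式可以用下面的简化示意来理解(其中 `ProgressHub` 类名与 `_latest` 缓存字段为说明用的假设,实际接口以 `app/adapters/local/progress.py` 的实现为准):
+
+```python
+# 简化示意:基于asyncio.Queue的多订阅者进度分发
+import asyncio
+from typing import AsyncGenerator, Dict, List
+
+
+class ProgressHub:
+    """每个task_id维护一组订阅者队列,update时广播给所有订阅者(示意)"""
+
+    def __init__(self):
+        self._subscribers: Dict[str, List[asyncio.Queue]] = {}
+        self._latest: Dict[str, Dict] = {}  # 缓存最新进度,新订阅者可立即获得当前状态
+
+    async def update_progress(self, task_id: str, info: Dict):
+        self._latest[task_id] = info
+        for queue in self._subscribers.get(task_id, []):
+            await queue.put(info)
+
+    async def subscribe(self, task_id: str) -> AsyncGenerator[Dict, None]:
+        queue: asyncio.Queue = asyncio.Queue()
+        self._subscribers.setdefault(task_id, []).append(queue)
+        try:
+            if task_id in self._latest:  # 先推送最近一次进度
+                yield self._latest[task_id]
+            while True:
+                try:
+                    yield await asyncio.wait_for(queue.get(), timeout=30.0)
+                except asyncio.TimeoutError:
+                    yield {"type": "heartbeat"}  # 心跳,保持SSE连接不被中断
+        finally:
+            self._subscribers[task_id].remove(queue)  # 订阅者断开时清理队列
+```
+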
+#### 1.5 API 端点 ✅ 已完成
+
+| 任务 | 状态 | 说明 |
+|------|------|------|
+| Quick Mode API (`/tasks`) | ✅ 已实现 | `app/api/v1/endpoints/tasks.py` |
+| Advanced Mode API (`/experiments`) | ✅ 已实现 | `app/api/v1/endpoints/experiments.py` |
+| 文件管理 API (`/files`) | ✅ 已实现 | `app/api/v1/endpoints/files.py` |
+| 阶段模板 API (`/stages`) | ✅ 已实现 | `app/api/v1/endpoints/stages.py` |
+| 路由注册 | ✅ 已实现 | `app/api/v1/router.py` |
+| FastAPI 入口 | ✅ 已实现 | `app/main.py` |
+| 适配器工厂 | ✅ 已实现 | `app/core/adapters.py` |
+| 依赖注入 | ✅ 已实现 | `app/api/deps.py` |
+
+**API 端点已实现功能:**
+- ✅ Quick Mode: 创建任务、任务列表、任务详情、取消任务、SSE 进度订阅
+- ✅ Advanced Mode: 创建实验、实验列表、实验详情、更新/删除实验、执行阶段、阶段状态、取消阶段、SSE 阶段进度
+- ✅ 文件管理: 上传文件、文件列表、下载文件、删除文件
+- ✅ 阶段模板: 预设列表、阶段参数模板
+
+#### 1.6 服务层 ✅ 已完成
+
+| 任务 | 状态 | 说明 |
+|------|------|------|
+| TaskService | ✅ 已实现 | `app/services/task_service.py` |
+| ExperimentService | ✅ 已实现 | `app/services/experiment_service.py` |
+| FileService | ✅ 已实现 | `app/services/file_service.py` |
+
+**服务层已实现功能:**
+- ✅ TaskService: 创建一键训练任务、质量预设配置、任务状态管理、进度订阅
+- ✅ ExperimentService: 实验 CRUD、阶段依赖检查、阶段执行/取消、进度订阅
+- ✅ FileService: 文件上传/下载、元数据管理、音频信息提取
+
+#### 1.7 测试与验证
+
+| 任务 | 状态 | 说明 |
+|------|------|------|
+| Quick Mode 端到端测试 | 🔲 待开始 | 上传音频 → 训练完成 |
+| Advanced Mode 分阶段测试 | 🔲 待开始 | 逐阶段执行 + 重新执行 |
+| 任务取消/恢复测试 | 🔲 待开始 | 验证任务生命周期管理 |
+
+---
+
+### Phase 2: Electron 集成准备
+
+| 任务 | 状态 | 说明 |
+|------|------|------|
+| 任务持久化和恢复机制 | 🔲 待开始 | 应用重启后恢复任务状态 |
+| PyInstaller 打包配置 | 🔲 待开始 | .spec 文件配置 |
+| Electron 进程管理模块 | 🔲 待开始 | spawn/kill Python 进程 |
+| IPC 通信层 | 🔲 待开始 | HTTP API 或 WebSocket |
+| macOS 签名和公证 | 🔲 待开始 | 可选,用于分发 |
+
+---
+
+### Phase 3: 服务器模式
+
+| 任务 | 状态 | 说明 |
+|------|------|------|
+| PostgreSQL 适配器 | 🔲 待开始 | SQLAlchemy + Alembic |
+| Celery 任务队列适配器 | 🔲 待开始 | 分布式任务执行 |
+| S3/MinIO 存储适配器 | 🔲 待开始 | 对象存储 |
+| Redis 进度管理适配器 | 🔲 待开始 | Pub/Sub 进度推送 |
+| 认证授权 | 🔲 待开始 | JWT / API Key |
+| 监控告警 | 🔲 待开始 | Prometheus + Grafana |
+| Docker 部署配置 | 🔲 待开始 | docker-compose.yml |
+
+---
+
+### Phase 4: 增强功能
+
+| 任务 | 状态 | 说明 |
+|------|------|------|
+| 模型版本管理 | 🔲 待开始 | 多版本模型存储和切换 |
+| 批量推理 | 🔲 待开始 | 批量 TTS 生成 |
+| 定时任务 | 🔲 待开始 | 计划训练任务 |
+| Webhook 通知 | 🔲 待开始 | 训练完成回调 |
+| 训练数据集管理 | 🔲 待开始 | 数据集版本控制 |
+
+
+---
+
+## 十、关键代码示例
+
+### 10.1 启动文件(自动识别模式)
+
+```python
+# app/main.py
+
+from fastapi import FastAPI
+from app.core.config import settings
+from app.api.v1.router import api_router
+
+app = FastAPI(title=settings.PROJECT_NAME)
+
+@app.on_event("startup")  # 注:新版FastAPI亦可改用lifespan,此处保持示例简洁
+async def startup_event():
+    print(f"Starting in {settings.DEPLOYMENT_MODE.upper()} mode")
+
+    if settings.DEPLOYMENT_MODE == "local":
+        print("Using SQLite + asyncio.subprocess + Local FileSystem")
+        # 恢复上次未完成的任务(见AsyncTrainingManager.recover_pending_tasks)
+    else:
+        print("Using PostgreSQL + Celery + MinIO")
+        # 初始化数据库连接池
+        # 预热Redis连接
+
+app.include_router(api_router, prefix=settings.API_V1_PREFIX)
+
+if __name__ == "__main__":
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=8000)
+```
+
+
+### 10.2 环境变量配置
+
+**.env.local**:
+```
+DEPLOYMENT_MODE=local
+LOCAL_STORAGE_PATH=./data/files
+SQLITE_PATH=./data/app.db
+LOCAL_MAX_WORKERS=1
+```
+
+
+**.env.server**:
+```
+DEPLOYMENT_MODE=server
+DATABASE_URL=postgresql+asyncpg://user:pass@localhost/gpt_sovits
+REDIS_URL=redis://localhost:6379/0
+CELERY_BROKER_URL=redis://localhost:6379/1
+S3_ENDPOINT=localhost:9000
+S3_ACCESS_KEY=minioadmin
+S3_SECRET_KEY=minioadmin
+```
+
+
+---
+
+## 十一、Electron集成指南
+
+### 11.1 架构概览
+
+```
+┌─────────────────────────────────────────────────────────────┐
+│                    Electron Main Process                    │
+│  ┌─────────────────┐     ┌──────────────────────────────┐   │
+│  │ Process Manager │────▶│ Python (PyInstaller Bundle)  │   │
+│  └─────────────────┘     │ ┌──────────────────────────┐ │   │
+│           │              │ │ FastAPI HTTP Server      │ │   │
+│           │              │ │ + asyncio.subprocess     │ │   │
+│           │              │ │ + SQLite Database        │ │   │
+│           │              │ └──────────────────────────┘ │   │
+│           │              └──────────────────────────────┘   │
+│           │                             │                   │
+│  ┌────────▼─────────────────────────────▼────────────────┐  │
+│  │          Renderer Process (Vue/React)                 │  │
+│  │          HTTP API / SSE Progress Subscription         │  │
+│  └───────────────────────────────────────────────────────┘  │
+└─────────────────────────────────────────────────────────────┘
+```
+
+### 11.2 Python进程管理(Electron侧)
+
+```javascript
+// electron/python-manager.js
+
+const { spawn } = require('child_process');
+const { app } = require('electron'); // 用于获取userData目录
+const path = require('path');
+const http = require('http');
+
+class PythonProcessManager {
+  constructor() {
+    this.pythonProcess = null;
+    this.apiPort = 8765;
+    this.isReady = false;
+  }
+
+  /**
+   * 启动Python后端进程
+   */
+  start() {
+    return new Promise((resolve, reject) => {
+      const pythonPath = this.getPythonPath();
+      // 开发模式以模块方式运行API入口(app/main.py中含uvicorn.run);
+      // 生产模式为PyInstaller单文件可执行程序,无需额外参数
+      const args = process.env.NODE_ENV === 'development' ? ['-m', 'app.main'] : [];
+
+      this.pythonProcess = spawn(pythonPath, args, {
+        env: {
+          ...process.env,
+          DEPLOYMENT_MODE: 'local',
+          API_PORT: this.apiPort.toString(),
+          // 使用Electron的userData目录存储数据
+          DATA_PATH: path.join(app.getPath('userData'), 'training-data')
+        },
+        stdio: ['pipe', 'pipe', 'pipe']
+      });
+
+      this.pythonProcess.stdout.on('data', (data) => {
+        console.log(`[Python] ${data}`);
+        // 检测服务启动完成
+        if (data.toString().includes('Uvicorn running on')) {
+          this.isReady = true;
+          resolve();
+        }
+      });
+
+      this.pythonProcess.stderr.on('data', (data) => {
+        console.error(`[Python Error] ${data}`);
+      });
+
+      this.pythonProcess.on('close', (code) => {
+        console.log(`Python process exited with code ${code}`);
+        this.isReady = false;
+      });
+
+      // 超时处理
+      setTimeout(() => {
+        if (!this.isReady) {
+          reject(new Error('Python server startup timeout'));
+        }
+      }, 30000);
+    });
+  }
+
+  /**
+   * 获取打包后的Python可执行文件路径
+   */
+  getPythonPath() {
+    if (process.env.NODE_ENV === 'development') {
+      return 'python'; // 开发模式使用系统Python
+    }
+
+    // 生产模式使用PyInstaller打包的可执行文件
+    const resourcesPath = process.resourcesPath;
+    if (process.platform === 'darwin') {
+      return path.join(resourcesPath, 'python', 'gpt-sovits-api');
+    } else if (process.platform === 'win32') {
+      return path.join(resourcesPath, 'python', 'gpt-sovits-api.exe');
+    }
+    return path.join(resourcesPath, 'python', 'gpt-sovits-api');
+  }
+
+  /**
+   * 等待API服务就绪
+   */
+  async waitForReady(maxRetries = 30) {
+    for (let i = 0; i < maxRetries; i++) {
+      try {
+        await this.healthCheck();
+        return true;
+      } catch {
+        await new Promise(r => setTimeout(r, 1000));
+      }
+    }
+    return false;
+  }
+
+  /**
+   * 健康检查
+   */
+  healthCheck() {
+    return new Promise((resolve, reject) => {
+      http.get(`http://localhost:${this.apiPort}/health`, (res) => {
+        res.resume(); // 消费响应体,释放socket
+        if (res.statusCode === 200) resolve();
+        else reject(new Error(`health check failed: ${res.statusCode}`));
+      }).on('error', reject);
+    });
+  }
+
+  /**
+   * 停止Python进程
+   */
+  stop() {
+    if (this.pythonProcess) {
+      this.pythonProcess.kill('SIGTERM');
+      this.pythonProcess = null;
+      this.isReady = false;
+    }
+  }
+
+  /**
+   * 获取API基础URL
+   */
+  getApiBaseUrl() {
+    return `http://localhost:${this.apiPort}`;
+  }
+}
+
+module.exports = PythonProcessManager;
+```
+
+
+### 11.3 PyInstaller打包配置
+
+```python
+# gpt-sovits-api.spec
+
+# -*- mode: python ; coding: utf-8 -*-
+
+block_cipher = None
+
+a = Analysis(
+    ['app/main.py'],
+    pathex=[],
+    binaries=[],
+    datas=[
+        # 包含预训练模型
+        ('pretrained_models', 'pretrained_models'),
+        # 包含配置文件
+        ('config', 'config'),
+    ],
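+    # 注:datas中的源路径相对于执行pyinstaller的目录;大体积预训练模型也可选择外置,
+    # 运行时通过环境变量定位,以减小可执行文件体积(示意性建议)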
+    hiddenimports=[
+        'uvicorn.logging',
+        'uvicorn.loops',
+        'uvicorn.loops.auto',
+        'uvicorn.protocols',
+        'uvicorn.protocols.http',
+        'uvicorn.protocols.http.auto',
+        'uvicorn.protocols.websockets',
+        'uvicorn.protocols.websockets.auto',
+        'uvicorn.lifespan',
+        'uvicorn.lifespan.on',
+        'aiosqlite',
+        'torch',
+        'torchaudio',
+        # 添加所有需要的隐式导入
+    ],
+    hookspath=[],
+    hooksconfig={},
+    runtime_hooks=[],
+    excludes=[
+        'tkinter',
+        'matplotlib',
+        'IPython',
+        'jupyter',
+    ],
+    win_no_prefer_redirects=False,
+    win_private_assemblies=False,
+    cipher=block_cipher,
+    noarchive=False,
+)
+
+pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)
+
+exe = EXE(
+    pyz,
+    a.scripts,
+    a.binaries,
+    a.zipfiles,
+    a.datas,
+    [],
+    name='gpt-sovits-api',
+    debug=False,
+    bootloader_ignore_signals=False,
+    strip=False,
+    upx=True,
+    upx_exclude=[],
+    runtime_tmpdir=None,
+    console=True,  # 设为False隐藏控制台
+    disable_windowed_traceback=False,
+    argv_emulation=False,
+    target_arch=None,
+    codesign_identity=None,
+    entitlements_file=None,
+)
+```
+
+
+### 11.4 适配器工厂更新(支持Electron模式)
+
+```python
+# app/core/adapters.py
+
+from app.core.config import settings
+import os
+
+class AdapterFactory:
+    @staticmethod
+    def create_task_queue_adapter():
+        # 本地/PyInstaller/Electron模式统一使用asyncio.subprocess方案
+        if settings.DEPLOYMENT_MODE == "local":
+            from app.adapters.local.task_queue import AsyncTrainingManager
+
+            # 根据环境确定数据路径(Electron模式通过DATA_PATH注入userData目录)
+            data_path = os.environ.get('DATA_PATH', './data')
+            db_path = os.path.join(data_path, 'tasks.db')
+
+            return AsyncTrainingManager(db_path=db_path)
+        else:
+            from app.adapters.server.task_queue import CeleryTaskQueueAdapter
+            return CeleryTaskQueueAdapter(
+                broker_url=settings.CELERY_BROKER_URL,
+                backend_url=settings.CELERY_RESULT_BACKEND
+            )
+```
+
+
+### 11.5 打包和分发检查清单
+
+```markdown
+## macOS打包检查清单
+
+- [ ] 签名Python可执行文件(如需分发到App Store外)
+- [ ] 处理Gatekeeper问题(首次运行需要右键打开)
+- [ ] 测试在干净的系统上启动
+- [ ] 验证模型文件正确打包
+- [ ] 测试任务恢复机制
+- [ ] 验证进度SSE流正常工作
+- [ ] 测试Electron退出时Python进程正确清理
+
+## 目录结构
+
+YourApp.app/
+├── Contents/
+│   ├── MacOS/
+│   │   └── YourApp              # Electron主程序
+│   ├── Resources/
+│   │   ├── python/
+│   │   │   └── gpt-sovits-api   # PyInstaller打包的Python
+│   │   ├── pretrained_models/   # 预训练模型
+│   │   └── ...
+│   └── Info.plist
+```
+
+---
+
+## 总结
+
+此架构设计核心思想:
+
+1. **统一接口**: API层和业务逻辑层完全统一
+2. **适配器模式**: 底层存储/队列/缓存通过适配器切换
+3. **配置驱动**: 通过环境变量控制部署模式
+4. **渐进式**: 先实现本地版本(快速验证),再扩展到服务器版本
+5. **零依赖本地部署**: 本地模式无需Docker、Redis、PostgreSQL
+6. **子进程执行模型**: 训练任务通过subprocess执行,主进程仅管理
+7. **asyncio.subprocess推荐**: 完全非阻塞,与FastAPI完美契合
+
+**推荐起步**:
+- **所有本地场景**: 使用 `asyncio.subprocess` + SQLite 方案(`AsyncTrainingManager`)
+- **Electron桌面应用**: 同上,完全兼容PyInstaller打包
+- **服务器生产环境**: 使用Celery + Redis实现分布式任务队列
+
+> [!TIP]
+> 关键洞察:既然训练Pipeline已经通过subprocess调用独立的Python脚本,
+> 那么使用 `asyncio.create_subprocess_exec()` 是最自然的选择,
+> 无需引入ThreadPool的额外复杂性。
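+
+作为对上述关键洞察的补充,下面给出一个最小示意(函数名与脚本路径均为示例假设),演示 `asyncio.create_subprocess_exec()` 如何在不阻塞事件循环的前提下启动并逐行监控一个子进程:
+
+```python
+# 最小示意:非阻塞地启动子进程并逐行读取其输出
+import asyncio
+import sys
+
+
+async def run_and_watch(script: str) -> int:
+    proc = await asyncio.create_subprocess_exec(
+        sys.executable, script,
+        stdout=asyncio.subprocess.PIPE,
+        stderr=asyncio.subprocess.STDOUT,  # 演示用:合并stderr到stdout
+    )
+    async for line in proc.stdout:  # 异步逐行读取,事件循环可同时处理HTTP请求
+        print(line.decode().rstrip())
+    return await proc.wait()
+
+
+# 用法示例(脚本路径为假设值):
+# asyncio.run(run_and_watch("app/scripts/run_pipeline.py"))
+```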