liumaolin
feat(api): implement AsyncTrainingManager MVP with SQLite persistence
e43edbb
"""
领域模型模块
定义训练任务相关的核心数据结构
"""
from dataclasses import dataclass, field
from datetime import datetime, timezone
from enum import Enum
from typing import Dict, Optional, Any
class TaskStatus(Enum):
    """
    Lifecycle states of a training task.

    Each member's value is the lowercase string identifier of that state.
    """

    # Enqueued and waiting to be executed.
    QUEUED = "queued"
    # Currently executing.
    RUNNING = "running"
    # Finished.
    COMPLETED = "completed"
    # Failed.
    FAILED = "failed"
    # Cancelled.
    CANCELLED = "cancelled"
    # Interrupted: the task was running when the application restarted.
    INTERRUPTED = "interrupted"
@dataclass
class Task:
    """
    Training task domain model.

    Attributes:
        id: Unique task identifier.
        job_id: Queue job ID (generated by the task queue).
        exp_name: Experiment name.
        status: Task status.
        config: Task configuration (contains all training parameters).
        current_stage: Stage currently being executed.
        progress: Overall progress (0.0-1.0).
        stage_progress: Progress within the current stage (0.0-1.0).
        message: Latest status message.
        error_message: Error message (when the task failed).
        created_at: Creation time; defaults to the current UTC time as a
            naive datetime.
        started_at: Time execution started.
        completed_at: Completion time.

    Example:
        >>> task = Task(
        ...     id="task-123",
        ...     exp_name="my_voice",
        ...     config={"version": "v2", "batch_size": 4}
        ... )
        >>> task.status
        <TaskStatus.QUEUED: 'queued'>
    """

    id: str
    exp_name: str
    config: Dict[str, Any]
    job_id: Optional[str] = None
    status: TaskStatus = TaskStatus.QUEUED
    current_stage: Optional[str] = None
    progress: float = 0.0
    stage_progress: float = 0.0
    message: Optional[str] = None
    error_message: Optional[str] = None
    # Naive UTC "now": same value as datetime.utcnow(), which is deprecated
    # since Python 3.12. Kept naive so it stays comparable with previously
    # stored timestamps. May be None when restored via from_dict() without
    # a created_at value.
    created_at: datetime = field(
        default_factory=lambda: datetime.now(timezone.utc).replace(tzinfo=None)
    )
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize the task to a plain dictionary.

        ``status`` is emitted as its string value and datetimes as ISO 8601
        strings (None when unset). ``config`` is returned by reference, not
        copied.
        """
        return {
            "id": self.id,
            "job_id": self.job_id,
            "exp_name": self.exp_name,
            "status": self.status.value,
            "config": self.config,
            "current_stage": self.current_stage,
            "progress": self.progress,
            "stage_progress": self.stage_progress,
            "message": self.message,
            "error_message": self.error_message,
            "created_at": self.created_at.isoformat() if self.created_at else None,
            "started_at": self.started_at.isoformat() if self.started_at else None,
            "completed_at": self.completed_at.isoformat() if self.completed_at else None,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Task":
        """
        Build a Task from a dictionary such as the one produced by to_dict().

        ``id`` and ``exp_name`` are required. For ``status``, ``config``,
        ``progress`` and ``stage_progress``, a key that is missing *or*
        explicitly None falls back to the field default, so the resulting
        Task always holds values of the declared types. ``created_at`` is
        kept as None when absent rather than inventing a timestamp.

        Args:
            data: Source dictionary. ``status`` may be a TaskStatus or its
                string value; datetime fields may be datetime objects or ISO
                8601 strings.

        Returns:
            A new Task instance.

        Raises:
            KeyError: If ``id`` or ``exp_name`` is missing.
            ValueError: If ``status`` is not a valid TaskStatus value, or a
                datetime string is not in ISO 8601 format.
        """

        def value_or(key: str, default: Any) -> Any:
            # Treat an explicit None the same as an absent key.
            value = data.get(key)
            return default if value is None else value

        def parse_datetime(value: Any) -> Optional[datetime]:
            if value is None:
                return None
            if isinstance(value, datetime):
                return value
            return datetime.fromisoformat(value)

        # Accept either a TaskStatus member or its serialized string value.
        status = value_or("status", TaskStatus.QUEUED)
        if isinstance(status, str):
            status = TaskStatus(status)

        return cls(
            id=data["id"],
            job_id=data.get("job_id"),
            exp_name=data["exp_name"],
            status=status,
            config=value_or("config", {}),
            current_stage=data.get("current_stage"),
            progress=value_or("progress", 0.0),
            stage_progress=value_or("stage_progress", 0.0),
            message=data.get("message"),
            error_message=data.get("error_message"),
            created_at=parse_datetime(data.get("created_at")),
            started_at=parse_datetime(data.get("started_at")),
            completed_at=parse_datetime(data.get("completed_at")),
        )
@dataclass
class ProgressInfo:
    """
    Progress information record.

    Used to pass progress updates between the child process and the main
    process.

    Attributes:
        type: Message type ("progress", "log", "error", "heartbeat").
        stage: Name of the current stage.
        stage_index: Index of the current stage.
        total_stages: Total number of stages.
        progress: Progress within the current stage (0.0-1.0).
        overall_progress: Overall progress (0.0-1.0).
        message: Progress message.
        status: Status.
        data: Additional payload.
    """

    type: str = "progress"
    stage: Optional[str] = None
    stage_index: Optional[int] = None
    total_stages: Optional[int] = None
    progress: float = 0.0
    overall_progress: float = 0.0
    message: Optional[str] = None
    status: Optional[str] = None
    data: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dictionary (``data`` is returned by reference)."""
        return {
            "type": self.type,
            "stage": self.stage,
            "stage_index": self.stage_index,
            "total_stages": self.total_stages,
            "progress": self.progress,
            "overall_progress": self.overall_progress,
            "message": self.message,
            "status": self.status,
            "data": self.data,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ProgressInfo":
        """
        Build a ProgressInfo from a dictionary such as the one from to_dict().

        For ``type``, ``progress``, ``overall_progress`` and ``data``, a key
        that is missing *or* explicitly None falls back to the field default,
        so those attributes always hold values of their declared types. All
        other keys are optional and default to None.

        Args:
            data: Source dictionary.

        Returns:
            A new ProgressInfo instance.
        """

        def value_or(key: str, default: Any) -> Any:
            # Treat an explicit None the same as an absent key.
            value = data.get(key)
            return default if value is None else value

        return cls(
            type=value_or("type", "progress"),
            stage=data.get("stage"),
            stage_index=data.get("stage_index"),
            total_stages=data.get("total_stages"),
            progress=value_or("progress", 0.0),
            overall_progress=value_or("overall_progress", 0.0),
            message=data.get("message"),
            status=data.get("status"),
            data=value_or("data", {}),
        )