| """ |
| Database models and initialization for Universal Model Trainer |
| """ |
|
|
| from sqlalchemy import ( |
| Column, Integer, String, Text, Float, Boolean, DateTime, |
| ForeignKey, JSON, Enum, create_engine |
| ) |
| from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine |
| from sqlalchemy.orm import sessionmaker, relationship, declarative_base |
| from datetime import datetime |
| import enum |
| import os |
|
|
| from app.config import settings |
|
|
| |
| DATABASE_PATH = settings.DATABASE_URL.replace("sqlite:///./", "").replace("sqlite://", "") |
| os.makedirs(os.path.dirname(DATABASE_PATH) if os.path.dirname(DATABASE_PATH) else ".", exist_ok=True) |
|
|
| |
| ASYNC_DB_URL = settings.DATABASE_URL.replace("sqlite://", "sqlite+aiosqlite://") |
| engine = create_async_engine(ASYNC_DB_URL, echo=settings.DEBUG) |
|
|
| AsyncSessionLocal = sessionmaker( |
| engine, class_=AsyncSession, expire_on_commit=False |
| ) |
|
|
| Base = declarative_base() |
|
|
|
|
class JobStatus(str, enum.Enum):
    """Training job lifecycle states.

    Subclasses ``str`` so members compare equal to their plain string
    values; note that ``TrainingJob.status`` stores the ``.value`` string,
    not the enum itself.
    """
    PENDING = "pending"        # created, not yet picked up
    QUEUED = "queued"          # waiting for a worker slot
    RUNNING = "running"        # actively training
    COMPLETED = "completed"    # finished successfully
    FAILED = "failed"          # terminated with an error
    CANCELLED = "cancelled"    # stopped by user request
    PAUSED = "paused"          # suspended, resumable
|
|
|
|
class TaskType(str, enum.Enum):
    """Supported task types.

    String-valued enum of the training objectives the service accepts;
    ``TrainingJob.task_type`` stores the ``.value`` string.
    """
    # Text tasks
    CAUSAL_LM = "causal-lm"
    SEQ2SEQ = "seq2seq"
    TOKEN_CLASSIFICATION = "token-classification"
    SEQUENCE_CLASSIFICATION = "sequence-classification"
    QUESTION_ANSWERING = "question-answering"
    SUMMARIZATION = "summarization"
    TRANSLATION = "translation"
    TEXT_CLASSIFICATION = "text-classification"
    MASKED_LM = "masked-lm"
    # Vision tasks
    VISION_CLASSIFICATION = "vision-classification"
    VISION_SEGMENTATION = "vision-segmentation"
    # Audio tasks
    AUDIO_CLASSIFICATION = "audio-classification"
    AUDIO_TRANSCRIPTION = "audio-transcription"
|
|
|
|
class TrainingJob(Base):
    """A single training run and its full lifecycle state.

    One row per submitted job. Progress, metrics, and error fields are
    mutated in place as the run advances. Child ``Checkpoint`` and
    ``TrainingLog`` rows cascade-delete with the job.
    """
    __tablename__ = "training_jobs"

    # --- Identity ---
    id = Column(Integer, primary_key=True, index=True)
    # External identifier used by the API; 36 chars = canonical UUID string.
    job_id = Column(String(36), unique=True, index=True, nullable=False)
    name = Column(String(255), nullable=False)
    description = Column(Text, nullable=True)

    # --- Model selection (task_type holds a TaskType.value string) ---
    task_type = Column(String(50), nullable=False)
    base_model = Column(String(255), nullable=False)
    output_model_name = Column(String(255), nullable=True)

    # --- Dataset configuration ---
    dataset_source = Column(String(50), default="huggingface")
    dataset_name = Column(String(255), nullable=True)
    dataset_config = Column(String(100), nullable=True)
    dataset_split = Column(String(50), default="train")
    custom_dataset_path = Column(String(512), nullable=True)

    # --- Training configuration (free-form JSON blobs) ---
    training_args = Column(JSON, default=dict)
    peft_config = Column(JSON, nullable=True)
    deepspeed_config = Column(JSON, nullable=True)

    # --- Live status / progress (status holds a JobStatus.value string) ---
    status = Column(String(20), default=JobStatus.PENDING.value)
    # Units of `progress` (fraction vs percent) are set by the trainer --
    # confirm with the writer before displaying.
    progress = Column(Float, default=0.0)
    current_epoch = Column(Integer, default=0)
    total_epochs = Column(Integer, default=3)
    current_step = Column(Integer, default=0)
    total_steps = Column(Integer, default=0)

    # --- Latest reported metrics ---
    train_loss = Column(Float, nullable=True)
    eval_loss = Column(Float, nullable=True)
    learning_rate = Column(Float, nullable=True)
    metrics = Column(JSON, default=dict)

    # --- Outputs ---
    output_path = Column(String(512), nullable=True)
    hub_model_id = Column(String(255), nullable=True)
    model_card = Column(Text, nullable=True)

    # --- Error handling / retries ---
    error_message = Column(Text, nullable=True)
    traceback = Column(Text, nullable=True)
    retry_count = Column(Integer, default=0)
    max_retries = Column(Integer, default=3)

    # --- Timestamps (naive UTC) ---
    # NOTE(review): datetime.utcnow is deprecated since Python 3.12; a
    # tz-aware replacement would change stored values, so left as-is.
    created_at = Column(DateTime, default=datetime.utcnow)
    started_at = Column(DateTime, nullable=True)
    completed_at = Column(DateTime, nullable=True)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)

    # --- Metadata ---
    created_by = Column(String(100), nullable=True)
    tags = Column(JSON, default=list)

    # --- Relationships (children are deleted with the job) ---
    checkpoints = relationship("Checkpoint", back_populates="job", cascade="all, delete-orphan")
    logs = relationship("TrainingLog", back_populates="job", cascade="all, delete-orphan")

    def to_dict(self):
        """Return a JSON-serializable snapshot of the job.

        Datetimes are rendered as ISO-8601 strings (or ``None``); not every
        column is included -- config blobs and retry bookkeeping are omitted.
        """
        return {
            "id": self.id,
            "job_id": self.job_id,
            "name": self.name,
            "description": self.description,
            "task_type": self.task_type,
            "base_model": self.base_model,
            "output_model_name": self.output_model_name,
            "dataset_name": self.dataset_name,
            "status": self.status,
            "progress": self.progress,
            "current_epoch": self.current_epoch,
            "total_epochs": self.total_epochs,
            "current_step": self.current_step,
            "total_steps": self.total_steps,
            "train_loss": self.train_loss,
            "eval_loss": self.eval_loss,
            "metrics": self.metrics,
            "output_path": self.output_path,
            "hub_model_id": self.hub_model_id,
            "error_message": self.error_message,
            "created_at": self.created_at.isoformat() if self.created_at else None,
            "started_at": self.started_at.isoformat() if self.started_at else None,
            "completed_at": self.completed_at.isoformat() if self.completed_at else None,
            "tags": self.tags
        }
|
|
|
|
class Checkpoint(Base):
    """A saved model checkpoint produced during a training job.

    Rows are children of ``TrainingJob`` (deleted with the job via the
    parent's cascade).
    """
    __tablename__ = "checkpoints"

    id = Column(Integer, primary_key=True, index=True)
    # FK to training_jobs.id (the integer PK, not the UUID job_id string).
    job_id = Column(Integer, ForeignKey("training_jobs.id"), nullable=False)
    checkpoint_name = Column(String(255), nullable=False)
    checkpoint_path = Column(String(512), nullable=False)
    step = Column(Integer, nullable=False)
    # Float: checkpoints may fall at a fractional epoch.
    epoch = Column(Float, nullable=False)
    loss = Column(Float, nullable=True)
    metrics = Column(JSON, default=dict)
    # presumably marks the best checkpoint per eval metric -- confirm with writer
    is_best = Column(Boolean, default=False)
    created_at = Column(DateTime, default=datetime.utcnow)
    size_mb = Column(Float, nullable=True)

    # Parent job (see TrainingJob.checkpoints).
    job = relationship("TrainingJob", back_populates="checkpoints")
|
|
|
|
class TrainingLog(Base):
    """A single log line emitted during a training job.

    Step/epoch/loss/learning_rate are optional so both plain status
    messages and per-step metric records share one table.
    """
    __tablename__ = "training_logs"

    id = Column(Integer, primary_key=True, index=True)
    # FK to training_jobs.id (the integer PK, not the UUID job_id string).
    job_id = Column(Integer, ForeignKey("training_jobs.id"), nullable=False)
    # Log severity, e.g. "INFO"/"ERROR" (10 chars max).
    level = Column(String(10), default="INFO")
    message = Column(Text, nullable=False)
    step = Column(Integer, nullable=True)
    epoch = Column(Float, nullable=True)
    loss = Column(Float, nullable=True)
    learning_rate = Column(Float, nullable=True)
    metrics = Column(JSON, nullable=True)
    created_at = Column(DateTime, default=datetime.utcnow)

    # Parent job (see TrainingJob.logs).
    job = relationship("TrainingJob", back_populates="logs")
|
|
|
|
class ModelRegistry(Base):
    """Registry of trained and available models.

    Covers both locally trained artifacts (``is_trained``/``local_path``,
    optionally linked back to the producing job) and externally sourced
    models referenced by ``hub_url``.
    """
    __tablename__ = "model_registry"

    id = Column(Integer, primary_key=True, index=True)
    # Human-facing unique name within this registry.
    name = Column(String(255), unique=True, nullable=False)
    # Model identifier (e.g. a hub repo id) -- not unique here.
    model_id = Column(String(255), nullable=False)
    task_type = Column(String(50), nullable=False)
    description = Column(Text, nullable=True)
    tags = Column(JSON, default=list)
    # Free-form size label (string, e.g. "7B") rather than a count.
    parameters = Column(String(20), nullable=True)
    is_local = Column(Boolean, default=False)
    local_path = Column(String(512), nullable=True)
    hub_url = Column(String(512), nullable=True)
    is_trained = Column(Boolean, default=False)
    # Optional link to the job that produced this model.
    training_job_id = Column(Integer, ForeignKey("training_jobs.id"), nullable=True)
    created_at = Column(DateTime, default=datetime.utcnow)
    last_used = Column(DateTime, nullable=True)
|
|
|
|
class DatasetCache(Base):
    """Cache bookkeeping for downloaded datasets.

    One row per cached dataset (unique by ``name``); ``last_accessed``
    presumably supports eviction -- confirm with the cache consumer.
    """
    __tablename__ = "dataset_cache"

    id = Column(Integer, primary_key=True, index=True)
    name = Column(String(255), unique=True, nullable=False)
    # Dataset config/subset name, when the source dataset has one.
    config = Column(String(100), nullable=True)
    split = Column(String(50), nullable=True)
    local_path = Column(String(512), nullable=False)
    size_mb = Column(Float, nullable=True)
    num_samples = Column(Integer, nullable=True)
    # Feature/schema description of the cached dataset (JSON blob).
    features = Column(JSON, nullable=True)
    created_at = Column(DateTime, default=datetime.utcnow)
    last_accessed = Column(DateTime, default=datetime.utcnow)
|
|
|
|
async def init_db():
    """Create all tables declared on ``Base`` if they do not already exist.

    ``Base.metadata.create_all`` is a sync API, so it is bridged into the
    async engine via ``run_sync`` inside a transactional ``engine.begin()``
    block. Existing tables are left untouched (no migrations).
    """
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
|
|
|
|
async def get_db():
    """Yield an ``AsyncSession`` (async-generator dependency, e.g. for FastAPI).

    The ``async with`` block already closes the session on exit, so the
    previous explicit ``session.close()`` in a ``finally`` was redundant
    and has been removed. If the consumer raises, the session is rolled
    back before the exception propagates, so a failed request never leaks
    a dirty transaction back to the pool.
    """
    async with AsyncSessionLocal() as session:
        try:
            yield session
        except Exception:
            await session.rollback()
            raise