mbellan's picture
Initial deployment
c3efd49
raw
history blame contribute delete
504 Bytes
"""Model configuration classes."""
from dataclasses import dataclass
from typing import Optional
@dataclass
class ModelConfig:
    """Settings describing which voice model to load and where to run it.

    The ``device`` value is checked right after the dataclass-generated
    ``__init__`` finishes (see ``__post_init__``).
    """

    # Name of the model this configuration refers to.
    name: str
    # Execution device; only "cuda", "cpu" and "mps" pass validation.
    device: str = "cuda"
    # Optional checkpoint reference (not validated here; presumably a path or
    # hub identifier — confirm against the loader).
    checkpoint: Optional[str] = None
    # Optional cache directory (not validated here).
    cache_dir: Optional[str] = None

    # Allowed values for ``device``. Deliberately left unannotated so that
    # @dataclass does not turn it into a constructor field.
    _SUPPORTED_DEVICES = ("cuda", "cpu", "mps")

    def __post_init__(self):
        """Reject unsupported ``device`` values.

        Raises:
            ValueError: If ``device`` is not exactly one of "cuda", "cpu"
                or "mps".
        """
        if self.device in self._SUPPORTED_DEVICES:
            return
        raise ValueError(
            f"Invalid device: {self.device}. Must be 'cuda', 'cpu', or 'mps'"
        )