"""Typed configuration models and a YAML loader for VAE training runs."""
from typing import Optional, Union

import yaml
from pydantic import BaseModel


class TrainConfig(BaseModel):
    """Options read from the ``training_params`` section of the config file."""

    max_epochs: int
    auto_lr_find: Union[bool, int]
    gpus: int


class VAEConfig(BaseModel):
    """Options read from the ``model_params`` section for ``model_type: vae``."""

    model_type: str
    hidden_size: int
    latent_size: int
    alpha: int
    dataset: str
    batch_size: Optional[int] = 64
    save_images: Optional[bool] = False
    lr: Optional[float] = None
    save_path: Optional[str] = None


class ConvVAEConfig(VAEConfig):
    """``VAEConfig`` plus input geometry; used for ``model_type: conv-vae``."""

    channels: int
    height: int
    width: int


class LoggerConfig(BaseModel):
    """Options read from the ``logger_params`` section of the config file."""

    name: str
    save_dir: str


class Config(BaseModel):
    """Top-level container bundling the model, training and logger sections."""

    # Most specific type first: if this Union is ever validated from a plain
    # dict, trying VAEConfig first would succeed and silently drop the
    # conv-only fields (channels/height/width).
    # NOTE(review): ``model_config`` is a reserved attribute name in pydantic
    # v2 (and ``model_*`` fields hit its protected namespace); this module
    # appears to target pydantic v1 -- confirm before upgrading.
    model_config: Union[ConvVAEConfig, VAEConfig]
    train_config: TrainConfig
    model_type: str
    log_config: LoggerConfig


def load_config(path="config.yaml"):
    """Load a YAML config file and validate it into a ``Config``.

    The file must contain ``model_params``, ``training_params`` and
    ``logger_params`` mappings. ``model_params.model_type`` selects the model
    config class: ``"vae"`` -> ``VAEConfig``, ``"conv-vae"`` -> ``ConvVAEConfig``.

    Args:
        path: Path to the YAML file; relative paths resolve against the
            current working directory.

    Returns:
        The validated ``Config``.

    Raises:
        OSError: If the file cannot be opened.
        KeyError: If a required section or ``model_type`` is missing.
        NotImplementedError: If ``model_type`` is not a supported model.
        pydantic.ValidationError: If a section fails field validation.
    """
    # safe_load only builds plain Python types (dicts, lists, scalars), and the
    # context manager guarantees the file handle is closed.
    with open(path, encoding="utf-8") as stream:
        raw = yaml.safe_load(stream)

    model_params = raw["model_params"]
    model_type = model_params["model_type"]
    if model_type == "vae":
        model_config = VAEConfig(**model_params)
    elif model_type == "conv-vae":
        model_config = ConvVAEConfig(**model_params)
    else:
        raise NotImplementedError(f"Model {model_type} is not implemented")

    train_config = TrainConfig(**raw["training_params"])
    log_config = LoggerConfig(**raw["logger_params"])
    return Config(
        model_config=model_config,
        train_config=train_config,
        model_type=model_type,
        log_config=log_config,
    )


# Loaded eagerly at import time, so importing this module requires
# "config.yaml" to exist relative to the current working directory.
config = load_config()