VAE / config /config.py
souranil3d's picture
First commit for VAE space
16906c1
from pydantic import BaseModel
from typing import Optional, Union
import yaml
class TrainConfig(BaseModel):
    """Training-loop settings, validated by pydantic.

    ``load_config`` builds this model from the ``training_params`` section
    of the YAML file.
    """

    # Number of epochs to train for.
    max_epochs: int

    # Accepts either a bool or an int; ``bool`` is listed first in the Union.
    # NOTE(review): presumably forwarded to a trainer's learning-rate finder
    # option -- the consumer is not visible in this file.
    auto_lr_find: Union[bool, int]

    # NOTE(review): presumably the number of GPUs to use -- confirm with the
    # code that consumes this value.
    gpus: int
class VAEConfig(BaseModel):
    """Model hyper-parameters for the ``"vae"`` model type.

    ``load_config`` builds this model from the ``model_params`` section of
    the YAML file. The first five fields are required; the rest fall back to
    the defaults declared below.
    """

    # --- required fields -------------------------------------------------
    model_type: str
    hidden_size: int
    latent_size: int
    # NOTE(review): declared as ``int``; if this is a loss weighting term a
    # fractional YAML value would not pass through unchanged (pydantic either
    # coerces or rejects non-integers depending on its version) -- confirm.
    alpha: int
    dataset: str

    # --- optional fields -------------------------------------------------
    batch_size: Optional[int] = 64
    save_images: Optional[bool] = False
    lr: Optional[float] = None
    save_path: Optional[str] = None
class ConvVAEConfig(VAEConfig):
    """Model hyper-parameters for the ``"conv-vae"`` model type.

    Inherits every field of ``VAEConfig`` and adds three required integers.
    """

    # NOTE(review): presumably the input image dimensions (channel count,
    # height, width) -- confirm against the model that reads them.
    channels: int
    height: int
    width: int
class LoggerConfig(BaseModel):
    """Logger settings, validated by pydantic.

    ``load_config`` builds this model from the ``logger_params`` section of
    the YAML file.
    """

    # Name given to the logger.
    name: str

    # Directory handed to the logger as its save location.
    save_dir: str
class Config(BaseModel):
    """Top-level container bundling the model, training and logger settings.

    ``load_config`` is the only constructor call visible in this file; it
    passes already-validated sub-model instances for every field.
    """

    # NOTE(review): ``model_config`` is a reserved attribute name in newer
    # pydantic major versions; this declaration may need attention if the
    # dependency is upgraded -- verify against the pinned pydantic version.
    model_config: Union[VAEConfig, ConvVAEConfig]
    train_config: TrainConfig

    # Copy of ``model_config.model_type`` as supplied by ``load_config``.
    model_type: str
    log_config: LoggerConfig
def load_config(path="config.yaml"):
    """Read the YAML file at *path* and build a validated ``Config``.

    The file is expected to contain three top-level mappings:
    ``model_params`` (which must include a ``model_type`` key),
    ``training_params`` and ``logger_params``.

    Args:
        path: Location of the YAML configuration file. A relative path is
            resolved against the current working directory.

    Returns:
        A ``Config`` holding the model, training and logger settings.

    Raises:
        NotImplementedError: If ``model_type`` is neither ``"vae"`` nor
            ``"conv-vae"``.
        KeyError: If one of the required sections or ``model_type`` is
            missing from the file.

    Validation errors raised by the pydantic models are not caught here and
    propagate to the caller.
    """
    # Use a context manager so the file handle is closed deterministically;
    # the handle was previously left for the garbage collector to clean up.
    with open(path) as stream:
        raw = yaml.load(stream, yaml.SafeLoader)

    model_params = raw["model_params"]
    model_type = model_params["model_type"]

    if model_type == "vae":
        model_config = VAEConfig(**model_params)
    elif model_type == "conv-vae":
        model_config = ConvVAEConfig(**model_params)
    else:
        raise NotImplementedError(f"Model {model_type} is not implemented")

    train_config = TrainConfig(**raw["training_params"])
    log_config = LoggerConfig(**raw["logger_params"])

    return Config(
        model_config=model_config,
        train_config=train_config,
        model_type=model_type,
        log_config=log_config,
    )
# Module-level configuration object. This runs at import time and reads the
# default path "config.yaml", which is relative and therefore resolved against
# the current working directory of the importing process.
config = load_config()