File size: 1,466 Bytes
16906c1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 |
from pydantic import BaseModel
from typing import Optional, Union
import yaml
class TrainConfig(BaseModel):
    """Training options, as read from the ``training_params`` YAML section."""

    max_epochs: int
    # Accepts either a flag or an integer. NOTE(review): under pydantic v1's
    # left-to-right Union matching, ``bool`` is tried first, so inputs it can
    # coerce (e.g. 0/1) become False/True rather than staying ints — confirm
    # which non-bool values the trainer actually expects here.
    auto_lr_find: Union[bool, int]
    gpus: int
class VAEConfig(BaseModel):
    """Hyperparameters for the ``"vae"`` model (the ``model_params`` section).

    Also serves as the base class of ``ConvVAEConfig``, which adds input
    dimensions on top of these fields.
    """

    # Model selector string; ``load_config`` accepts "vae" and "conv-vae".
    # NOTE(review): pydantic v2 warns about fields in the protected "model_"
    # namespace — revisit if upgrading pydantic.
    model_type: str
    hidden_size: int
    latent_size: int
    # NOTE(review): typed as int — confirm a fractional value is never needed.
    alpha: int
    # Dataset identifier; how it is interpreted is not visible here.
    dataset: str
    # Defaults to 64; ``Optional`` means an explicit None is also accepted.
    batch_size: Optional[int] = 64
    save_images: Optional[bool] = False
    # None presumably means "not set" (caller picks a learning rate) — confirm.
    lr: Optional[float] = None
    save_path: Optional[str] = None
class ConvVAEConfig(VAEConfig):
    """Hyperparameters for the ``"conv-vae"`` model.

    Extends ``VAEConfig`` with three additional required integer fields.
    """

    # Presumably the input image shape (channels x height x width) — confirm
    # against the model that consumes this config.
    channels: int
    height: int
    width: int
class LoggerConfig(BaseModel):
    """Logger options, as read from the ``logger_params`` YAML section."""

    # Both values are passed on to a logger that is not visible here.
    name: str
    save_dir: str
class Config(BaseModel):
    """Top-level configuration bundle returned by ``load_config``.

    Groups the model, training and logger sections together with the
    ``model_type`` string that chose the model section's class.
    """

    # The subclass is listed first on purpose. Pydantic tries Union members in
    # order and ignores unknown keys by default, so with the base class first a
    # plain dict carrying conv-only fields would validate as VAEConfig and
    # silently lose channels/height/width. A dict without those fields still
    # falls through to VAEConfig.
    # NOTE(review): pydantic v2 reserves the ``model_config`` attribute for
    # model configuration, so this field name ties the file to pydantic v1 —
    # confirm the pinned version before upgrading.
    model_config: Union[ConvVAEConfig, VAEConfig]
    train_config: TrainConfig
    model_type: str
    log_config: LoggerConfig
def load_config(path="config.yaml"):
    """Read a YAML config file and validate it into a ``Config``.

    The file must be a mapping with three sections:
    - ``model_params``, which must include a ``model_type`` key;
    - ``training_params``;
    - ``logger_params``.

    ``model_type`` selects the class that validates ``model_params``:
    ``"vae"`` -> ``VAEConfig`` and ``"conv-vae"`` -> ``ConvVAEConfig``.

    Args:
        path: Path of the YAML file. Defaults to ``"config.yaml"``, resolved
            relative to the current working directory.

    Returns:
        The validated ``Config`` instance.

    Raises:
        OSError: If the file cannot be opened.
        KeyError: If a required section or ``model_type`` is missing.
        NotImplementedError: If ``model_type`` is not a supported model.
        pydantic validation error: If a section has missing or invalid fields.
    """
    # A context manager closes the handle promptly instead of leaving it to
    # garbage collection. UTF-8 is YAML's default encoding, which avoids a
    # dependency on the platform locale. safe_load is equivalent to
    # yaml.load(..., yaml.SafeLoader) and only builds plain Python objects.
    with open(path, encoding="utf-8") as stream:
        raw = yaml.safe_load(stream)

    model_params = raw["model_params"]
    model_type = model_params["model_type"]
    if model_type == "vae":
        model_config = VAEConfig(**model_params)
    elif model_type == "conv-vae":
        model_config = ConvVAEConfig(**model_params)
    else:
        raise NotImplementedError(f"Model {model_type} is not implemented")

    train_config = TrainConfig(**raw["training_params"])
    log_config = LoggerConfig(**raw["logger_params"])
    return Config(
        model_config=model_config,
        train_config=train_config,
        model_type=model_type,
        log_config=log_config,
    )
# Loaded once at import time from "config.yaml" in the current working
# directory. Importing this module therefore performs file I/O and raises if
# the file is missing. NOTE(review): kept as a module-level name because other
# modules presumably import ``config`` directly — confirm before moving it.
config = load_config(path="config.yaml")
|