File size: 1,466 Bytes
16906c1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
from pydantic import BaseModel
from typing import Optional, Union
import yaml


class TrainConfig(BaseModel):
    """Schema for the ``training_params`` section of the YAML config.

    ``load_config`` builds this from ``config["training_params"]``.
    NOTE(review): the field names match PyTorch Lightning ``Trainer``
    arguments, so they are presumably forwarded to a trainer -- confirm
    against the code that consumes this object.
    """

    # Required integer; presumably the maximum number of training epochs.
    max_epochs: int
    # Accepts either a boolean or an integer; the exact meaning of each form
    # is decided by the consumer, not by this module.
    auto_lr_find: Union[bool, int]
    # Required integer; presumably the number of GPUs to train on.
    gpus: int


class VAEConfig(BaseModel):
    """Schema for the ``model_params`` section when ``model_type`` is "vae".

    Also serves as the base class for ``ConvVAEConfig``. The first five
    fields are required; the remaining four are optional and fall back to
    the defaults shown.
    """

    # Model selector string; ``load_config`` dispatches on this value.
    model_type: str
    # presumably the width of the hidden layer(s) -- TODO confirm with model
    hidden_size: int
    # presumably the dimensionality of the latent space -- TODO confirm
    latent_size: int
    # NOTE(review): typed as int; if fractional weights are intended this
    # should probably be float -- confirm against the model code.
    alpha: int
    # Dataset identifier string; how it is resolved is not visible here.
    dataset: str
    # Defaults to 64 when omitted from the YAML file.
    batch_size: Optional[int] = 64
    # Defaults to False when omitted.
    save_images: Optional[bool] = False
    # No default learning rate; None when omitted.
    lr: Optional[float] = None
    # No default save location; None when omitted.
    save_path: Optional[str] = None


class ConvVAEConfig(VAEConfig):
    """Schema for ``model_params`` when ``model_type`` is "conv-vae".

    Inherits every field of ``VAEConfig`` and adds three required integers.
    NOTE(review): these presumably describe the input image shape
    (channels x height x width) -- confirm against the model code.
    """

    channels: int
    height: int
    width: int


class LoggerConfig(BaseModel):
    """Schema for the ``logger_params`` section of the YAML config.

    ``load_config`` builds this from ``config["logger_params"]``.
    """

    # Required string; presumably the logger / experiment name.
    name: str
    # Required string; presumably the directory where logs are written.
    save_dir: str


class Config(BaseModel):
    """Top-level configuration object assembled by ``load_config``.

    Bundles the validated model, training and logger sections together with
    the ``model_type`` string that selected the model schema.

    NOTE(review): pydantic v2 reserves the attribute name ``model_config``
    for model settings, so this field name would have to change before
    upgrading -- confirm the pinned pydantic version. The name is kept here
    because callers access ``config.model_config``.
    """

    # Most specific type first: ConvVAEConfig is a subclass of VAEConfig, so
    # with left-to-right union validation a raw mapping containing the conv
    # fields would otherwise match VAEConfig first and lose channels/height/
    # width. Mappings without those fields still fall through to VAEConfig.
    model_config: Union[ConvVAEConfig, VAEConfig]
    train_config: TrainConfig
    # Copy of ``model_config.model_type`` as read from the YAML file.
    model_type: str
    log_config: LoggerConfig


def load_config(path: str = "config.yaml") -> "Config":
    """Load the YAML file at *path* and return it as a validated ``Config``.

    The file must contain three mappings: ``model_params``,
    ``training_params`` and ``logger_params``. The value of
    ``model_params["model_type"]`` selects the model schema:
    ``"vae"`` -> ``VAEConfig``, ``"conv-vae"`` -> ``ConvVAEConfig``.

    Args:
        path: Location of the YAML file. Defaults to ``"config.yaml"``.

    Returns:
        A ``Config`` holding the model, training and logger sections plus
        the selected ``model_type``.

    Raises:
        OSError: If the file cannot be opened.
        KeyError: If a required section or ``model_type`` is missing.
        NotImplementedError: If ``model_type`` is not a supported value.
        pydantic.ValidationError: If a section fails schema validation.
    """
    # The context manager guarantees the file handle is closed even when
    # parsing fails; YAML text is read as UTF-8 rather than the platform
    # default encoding. safe_load is the SafeLoader shortcut.
    with open(path, encoding="utf-8") as stream:
        raw = yaml.safe_load(stream)

    model_params = raw["model_params"]
    model_type = model_params["model_type"]

    # Pick the schema class first so the constructor call is written once.
    if model_type == "vae":
        model_cls = VAEConfig
    elif model_type == "conv-vae":
        model_cls = ConvVAEConfig
    else:
        raise NotImplementedError(f"Model {model_type} is not implemented")

    model_config = model_cls(**model_params)
    train_config = TrainConfig(**raw["training_params"])
    log_config = LoggerConfig(**raw["logger_params"])

    return Config(
        model_config=model_config,
        train_config=train_config,
        model_type=model_type,
        log_config=log_config,
    )


# Module-level config instance: the default "config.yaml" (resolved against
# the current working directory) is read when this module is imported, so the
# import itself raises if that file is missing or fails validation.
config = load_config()