File size: 1,313 Bytes
16906c1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from pytorch_lightning import Trainer
from models import vae_models
from config import config
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger
import os
# KMP_DUPLICATE_LIB_OK is an Intel OpenMP runtime setting. Setting it to "True"
# is the usual workaround for the "OMP: Error #15" abort raised when two OpenMP
# runtimes end up loaded in one process (assumed to be the reason here — confirm).
# NOTE(review): the imports above (pytorch_lightning, models) likely load torch
# already, so this may be too late in some environments; consider setting it
# before those imports.
os.environ.update(KMP_DUPLICATE_LIB_OK="True")


def make_model(config):
    """Instantiate the VAE architecture selected by ``config.model_type``.

    Args:
        config: Object exposing ``model_type`` (a key into ``vae_models``) and
            ``model_config``, whose ``.dict()`` result is passed as keyword
            arguments to the selected model class.

    Returns:
        A new instance of ``vae_models[config.model_type]``.

    Raises:
        NotImplementedError: If ``config.model_type`` is not a key of
            ``vae_models``.
    """
    model_type = config.model_type
    if model_type not in vae_models:
        # Name the bad key and the valid options so misconfiguration is obvious.
        raise NotImplementedError(
            f"Model Architecture not implemented: {model_type!r} "
            f"(available: {list(vae_models)})"
        )
    return vae_models[model_type](**config.model_config.dict())


def main():
    """Build the configured model, train it, and save a checkpoint.

    Everything is read from the module-level ``config``:

    * ``config.train_config`` is expanded into ``Trainer`` keyword arguments;
      its ``auto_lr_find`` flag triggers a learning-rate search before fitting.
    * ``config.log_config`` is expanded into ``TensorBoardLogger`` arguments.
    * ``config.model_type`` and ``config.model_config.alpha`` /
      ``hidden_size`` form the checkpoint file name under ``saved_models/``.
    """
    model = make_model(config)
    train_config = config.train_config
    logger = TensorBoardLogger(**config.log_config.dict())
    trainer = Trainer(**train_config.dict(), logger=logger,
                      callbacks=[LearningRateMonitor()])

    if train_config.auto_lr_find:
        lr_finder = trainer.tuner.lr_find(model)
        new_lr = lr_finder.suggestion()
        # suggestion() returns None when it cannot compute a value (e.g. too
        # few points); don't overwrite the model's lr with None in that case.
        if new_lr is None:
            print("Learning rate finder gave no suggestion; keeping the model's lr")
        else:
            print("Learning Rate Chosen:", new_lr)
            model.lr = new_lr

    trainer.fit(model)

    save_dir = "saved_models"
    # exist_ok avoids the check-then-create race of isdir() followed by mkdir().
    os.makedirs(save_dir, exist_ok=True)
    ckpt_name = (f"{config.model_type}_alpha_{config.model_config.alpha}"
                 f"_dim_{config.model_config.hidden_size}.ckpt")
    trainer.save_checkpoint(os.path.join(save_dir, ckpt_name))


if __name__ == "__main__":
    main()