# VAE / train.py
# Author: souranil3d
# Commit 16906c1: "First commit for VAE space" (1.31 kB)
from pytorch_lightning import Trainer
from models import vae_models
from config import config
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger
import os
# Allow the Intel OpenMP runtime to be loaded more than once in this process.
# NOTE(review): presumably a workaround for the "libiomp5 already initialized"
# abort seen when several libraries bundle their own OpenMP runtime; the
# OpenMP runtime itself describes this setting as unsafe — confirm it is still needed.
os.environ.update(KMP_DUPLICATE_LIB_OK="True")
def make_model(config):
    """Instantiate the VAE architecture selected by ``config``.

    Args:
        config: Project configuration object. ``config.model_type`` is used
            as a key into ``vae_models``; ``config.model_config`` is converted
            with ``.dict()`` and expanded into the model constructor's
            keyword arguments.

    Returns:
        A new instance of ``vae_models[config.model_type]``.

    Raises:
        NotImplementedError: If ``config.model_type`` is not a key of
            ``vae_models``. The message names the requested type and the
            registered alternatives so a misconfiguration is easy to fix.
    """
    model_type = config.model_type
    model_config = config.model_config
    if model_type not in vae_models:
        available = ", ".join(map(str, vae_models))
        raise NotImplementedError(
            f"Model Architecture not implemented: {model_type!r} "
            f"(available: {available})"
        )
    # NOTE(review): `.dict()` is the pydantic v1 API (deprecated in v2 in
    # favour of `model_dump()`); presumably model_config is a pydantic model.
    return vae_models[model_type](**model_config.dict())
if __name__ == "__main__":
    # Build the model, logger and trainer from the shared `config` object.
    model = make_model(config)
    train_config = config.train_config
    logger = TensorBoardLogger(**config.log_config.dict())
    trainer = Trainer(**train_config.dict(), logger=logger,
                      callbacks=[LearningRateMonitor()])

    if train_config.auto_lr_find:
        # Run Lightning's learning-rate range test before training.
        lr_finder = trainer.tuner.lr_find(model)
        # lr_find() may return None (e.g. under fast_dev_run), and
        # suggestion() returns None when it cannot compute a value. In those
        # cases keep the configured learning rate rather than setting
        # model.lr to None, which would break the optimizer.
        new_lr = lr_finder.suggestion() if lr_finder is not None else None
        if new_lr is not None:
            print("Learning Rate Chosen:", new_lr)
            # NOTE(review): assumes the model's optimizer reads `self.lr` — confirm.
            model.lr = new_lr
        else:
            print("LR finder gave no suggestion; keeping the configured learning rate.")
    trainer.fit(model)

    # Save the final checkpoint. exist_ok avoids the check-then-create race of
    # a separate isdir()/mkdir() pair.
    save_dir = "saved_models"
    os.makedirs(save_dir, exist_ok=True)
    trainer.save_checkpoint(os.path.join(
        save_dir,
        f"{config.model_type}_alpha_{config.model_config.alpha}"
        f"_dim_{config.model_config.hidden_size}.ckpt",
    ))