"""Evaluate a saved VAE checkpoint with PyTorch Lightning's ``Trainer.test``."""

import os

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger

from models import vae_models
from config import config

# Checkpoint evaluated when this file is run as a script.
CHECKPOINT_PATH = "./saved_models/vae_alpha_1024_dim_128.ckpt"


def _resolve_model_class(model_type):
    """Return the entry of ``vae_models`` registered under *model_type*.

    Raises:
        NotImplementedError: if *model_type* is not a key of ``vae_models``.
    """
    if model_type not in vae_models:
        raise NotImplementedError("Model Architecture not implemented")
    return vae_models[model_type]


def make_model(config):
    """Instantiate the model described by *config*.

    Args:
        config: object exposing ``model_type`` (a key of ``vae_models``) and
            ``model_config`` (an object whose ``.dict()`` result is passed as
            keyword arguments to the model constructor). Note that this
            parameter shadows the module-level ``config`` import.

    Returns:
        A new instance of ``vae_models[config.model_type]``.

    Raises:
        NotImplementedError: if ``config.model_type`` is not registered in
            ``vae_models``.
    """
    model_cls = _resolve_model_class(config.model_type)
    return model_cls(**config.model_config.dict())


def main():
    """Load the checkpointed model and run the Lightning test loop on it."""
    # Validate the configured model type up front so an unknown type fails
    # with the same NotImplementedError as make_model, not a bare KeyError.
    model_cls = _resolve_model_class(config.model_type)
    model = model_cls.load_from_checkpoint(CHECKPOINT_PATH)

    logger = TensorBoardLogger(**config.log_config.dict())
    # NOTE(review): the ``gpus`` argument is only accepted by Lightning < 2.0;
    # newer releases use ``accelerator="gpu", devices=1`` -- confirm against
    # the pinned pytorch_lightning version before changing.
    trainer = Trainer(gpus=1, logger=logger)
    trainer.test(model)


if __name__ == "__main__":
    main()