# VAE / test.py
# Author: souranil3d
# Commit 16906c1: "First commit for VAE space"
from pytorch_lightning import Trainer
from models import vae_models
from config import config
from pytorch_lightning.loggers import TensorBoardLogger
import os
def make_model(config):
    """Build a fresh (untrained) VAE from the ``vae_models`` registry.

    Looks up ``config.model_type`` in ``vae_models`` and instantiates the
    registered class with the keyword arguments produced by
    ``config.model_config.dict()``.

    Args:
        config: Settings object exposing ``model_type`` (expected to be a key
            of ``vae_models``) and ``model_config`` (an object with a
            ``.dict()`` method -- presumably a pydantic model; TODO confirm).

    Returns:
        A new instance of ``vae_models[config.model_type]``.

    Raises:
        NotImplementedError: If ``config.model_type`` is not a registered
            architecture.
    """
    model_type = config.model_type
    model_config = config.model_config
    # Guard clause: report the rejected name and the valid choices instead of
    # a generic message, so a config typo is easy to spot.
    if model_type not in vae_models:
        available = ", ".join(map(str, vae_models))
        raise NotImplementedError(
            f"Model Architecture not implemented: {model_type!r} "
            f"(available: {available})"
        )
    return vae_models[model_type](**model_config.dict())
def parse_args(argv=None):
    """Parse the command-line options of the evaluation script.

    The defaults reproduce the previously hard-coded run: the
    ``./saved_models/vae_alpha_1024_dim_128.ckpt`` checkpoint on one GPU.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``checkpoint`` (str) and ``gpus`` (int).
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Evaluate a saved VAE checkpoint with PyTorch Lightning."
    )
    parser.add_argument(
        "--checkpoint",
        default="./saved_models/vae_alpha_1024_dim_128.ckpt",
        help="Path to the Lightning checkpoint to load.",
    )
    parser.add_argument(
        "--gpus",
        type=int,
        default=1,
        help="Value passed to Trainer(gpus=...).",
    )
    return parser.parse_args(argv)


def main(argv=None):
    """Load a VAE checkpoint for ``config.model_type`` and run ``Trainer.test``.

    Args:
        argv: Optional argument list forwarded to ``parse_args``.

    Raises:
        NotImplementedError: If ``config.model_type`` is not a registered
            architecture (same contract as ``make_model``).
    """
    args = parse_args(argv)

    model_type = config.model_type
    # Validate like make_model does, rather than failing with a bare KeyError.
    if model_type not in vae_models:
        available = ", ".join(map(str, vae_models))
        raise NotImplementedError(
            f"Model Architecture not implemented: {model_type!r} "
            f"(available: {available})"
        )

    # NOTE(review): the default checkpoint name encodes specific
    # hyperparameters (alpha 1024, dim 128); confirm it matches the
    # architecture selected by config.model_type.
    model = vae_models[model_type].load_from_checkpoint(args.checkpoint)

    logger = TensorBoardLogger(**config.log_config.dict())
    # NOTE(review): `gpus=` was deprecated in PyTorch Lightning 1.7 and removed
    # in 2.0 (replaced by accelerator=/devices=); confirm against the
    # installed version before upgrading.
    trainer = Trainer(gpus=args.gpus, logger=logger)
    trainer.test(model)


if __name__ == "__main__":
    main()