from pytorch_lightning import Trainer | |
from models import vae_models | |
from config import config | |
from pytorch_lightning.loggers import TensorBoardLogger | |
import os | |
def make_model(config):
    """Instantiate a fresh VAE from a configuration object.

    Args:
        config: Object exposing ``model_type`` (looked up as a key in
            ``vae_models``) and ``model_config`` (an object with a ``.dict()``
            method whose entries are passed as keyword arguments to the model
            class).

    Returns:
        A new instance of ``vae_models[config.model_type]``.

    Raises:
        NotImplementedError: If ``config.model_type`` is not a key of
            ``vae_models``.
    """
    model_type = config.model_type
    model_config = config.model_config
    # Guard clause: report the offending name and the known options instead of
    # a generic message, so a config typo is easy to spot.
    if model_type not in vae_models:
        raise NotImplementedError(
            f"Model Architecture not implemented: {model_type!r} "
            f"(available: {list(vae_models)})"
        )
    return vae_models[model_type](**model_config.dict())
def main():
    """Load a saved VAE checkpoint and run ``Trainer.test`` on it.

    The checkpoint path may be given as an optional positional command-line
    argument; it defaults to ``./saved_models/vae_alpha_1024_dim_128.ckpt``.
    The model class is chosen by ``config.model_type`` and the TensorBoard
    logger is built from ``config.log_config``.

    Raises:
        NotImplementedError: If ``config.model_type`` is not a key of
            ``vae_models``.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Run Trainer.test() on a saved VAE checkpoint."
    )
    parser.add_argument(
        "checkpoint",
        nargs="?",
        default="./saved_models/vae_alpha_1024_dim_128.ckpt",
        help="Path to the checkpoint to evaluate (default: %(default)s).",
    )
    args = parser.parse_args()

    model_type = config.model_type
    # Validate up front so a bad config yields a clear error rather than a
    # bare KeyError from the registry lookup.
    if model_type not in vae_models:
        raise NotImplementedError(
            f"Model Architecture not implemented: {model_type!r} "
            f"(available: {list(vae_models)})"
        )

    model = vae_models[model_type].load_from_checkpoint(args.checkpoint)
    logger = TensorBoardLogger(**config.log_config.dict())
    # NOTE(review): `gpus=` was deprecated in PyTorch Lightning 1.7 and removed
    # in 2.0 in favour of `accelerator="gpu", devices=1`; kept as-is because the
    # installed PL version is not visible here — confirm and migrate.
    trainer = Trainer(gpus=1, logger=logger)
    trainer.test(model)


if __name__ == "__main__":
    main()