# File-viewer residue (not part of the program): File size: 740 Bytes, id 16906c1,
# followed by a line-number gutter for lines 1-22.
from pytorch_lightning import Trainer
from models import vae_models
from config import config
from pytorch_lightning.loggers import TensorBoardLogger
import os
def make_model(config):
    """Instantiate the VAE architecture selected by ``config``.

    Note that the ``config`` parameter shadows the module-level ``config``
    import inside this function; only the argument is used here.

    Args:
        config: Object exposing ``model_type`` (looked up in ``vae_models``)
            and ``model_config`` (an object whose ``dict()`` result is
            unpacked as keyword arguments for the model constructor).

    Returns:
        The result of calling ``vae_models[config.model_type]`` with the
        keyword arguments from ``config.model_config.dict()``.

    Raises:
        NotImplementedError: If ``config.model_type`` is not a key of
            ``vae_models``.
    """
    model_type = config.model_type
    model_config = config.model_config

    # Guard clause: reject unknown architectures up front and name the
    # offending value so a misconfiguration is easy to spot.
    if model_type not in vae_models:
        raise NotImplementedError(
            f"Model Architecture not implemented: {model_type!r}"
        )

    return vae_models[model_type](**model_config.dict())
def main(checkpoint_path="./saved_models/vae_alpha_1024_dim_128.ckpt"):
    """Load a VAE from a checkpoint and run the Lightning test routine on it.

    The model class is chosen by ``config.model_type`` from ``vae_models`` and
    the logger is built from ``config.log_config.dict()``.

    Args:
        checkpoint_path: Path handed to ``load_from_checkpoint``. Defaults to
            the checkpoint this script originally hard-coded.

    Raises:
        NotImplementedError: If ``config.model_type`` is not a key of
            ``vae_models`` (same exception type ``make_model`` raises).
    """
    model_type = config.model_type
    # Validate before indexing so an unknown type fails the same way as in
    # make_model instead of surfacing as a bare KeyError.
    if model_type not in vae_models:
        raise NotImplementedError(
            f"Model Architecture not implemented: {model_type!r}"
        )

    model = vae_models[model_type].load_from_checkpoint(checkpoint_path)
    logger = TensorBoardLogger(**config.log_config.dict())

    # NOTE(review): ``gpus=1`` is kept as in the original. It presumably
    # requires a CUDA device, and newer pytorch_lightning releases replace
    # ``gpus`` with ``accelerator``/``devices`` - confirm against the pinned
    # Lightning version before changing.
    trainer = Trainer(gpus=1, logger=logger)
    trainer.test(model)


if __name__ == "__main__":
    main()