VAE / inference.py
souranil3d's picture
First commit for VAE space
16906c1
raw
history blame contribute delete
658 Bytes
from models import vae_models
from config import config
from PIL import Image
from torchvision.transforms import Resize, ToPILImage, Compose
from utils import load_model, tensor_to_img, resize_img, export_to_onnx
def predict(model_ckpt="vae_alpha_1024_dim_128.ckpt"):
model_type = config.model_type
model = vae_models[model_type].load_from_checkpoint(f"./saved_models/{model_ckpt}")
model.eval()
test_iter = iter(model.test_dataloader())
d, _ = next(test_iter)
_, _, out = model(d)
out_img = tensor_to_img(out)
return out_img
if __name__ == "__main__":
predict()
# export_to_onnx("./saved_models/vae.ckpt")