import torch

from models import vae_models
from config import config
from PIL import Image
from torchvision.transforms import Resize, ToPILImage, Compose

from utils import load_model, tensor_to_img, resize_img, export_to_onnx


def predict(model_ckpt="vae_alpha_1024_dim_128.ckpt"):
    # Restore the trained VAE from its checkpoint and switch to eval mode.
    model_type = config.model_type
    model = vae_models[model_type].load_from_checkpoint(f"./saved_models/{model_ckpt}")
    model.eval()

    # Pull one batch from the test dataloader and run it through the model;
    # the reconstruction is the last element of the model's output tuple.
    test_iter = iter(model.test_dataloader())
    d, _ = next(test_iter)
    with torch.no_grad():
        _, _, out = model(d)

    # Convert the reconstructed tensor back into an image.
    out_img = tensor_to_img(out)
    return out_img


if __name__ == "__main__":
    predict()
    # export_to_onnx("./saved_models/vae.ckpt")
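    # Usage sketch (assumption: tensor_to_img returns a PIL.Image, so the
    # reconstruction can be written straight to disk):
    # out_img = predict()
    # out_img.save("reconstruction.png")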