from models import vae_models | |
from config import config | |
from PIL import Image | |
from torchvision.transforms import Resize, ToPILImage, Compose | |
from utils import load_model, tensor_to_img, resize_img, export_to_onnx | |
def predict(model_ckpt="vae_alpha_1024_dim_128.ckpt"): | |
model_type = config.model_type | |
model = vae_models[model_type].load_from_checkpoint(f"./saved_models/{model_ckpt}") | |
model.eval() | |
test_iter = iter(model.test_dataloader()) | |
d, _ = next(test_iter) | |
_, _, out = model(d) | |
out_img = tensor_to_img(out) | |
return out_img | |
if __name__ == "__main__": | |
predict() | |
# export_to_onnx("./saved_models/vae.ckpt") |