File size: 1,798 Bytes
16906c1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from pytorch_lightning import Trainer
from torchvision.utils import save_image
from models import vae_models
from config import config
from PIL import Image 
from pytorch_lightning.loggers import TensorBoardLogger
import torch
from torch.nn.functional import interpolate
from torchvision.transforms import Resize, ToPILImage, Compose
from torchvision.utils import make_grid

def load_model(ckpt, model_type="vae", ckpt_dir="./saved_models"):
    """Load a checkpoint into the registered model class and put it in eval mode.

    Args:
        ckpt: Checkpoint file name, resolved relative to ``ckpt_dir``.
        model_type: Key into ``vae_models`` selecting the model class whose
            ``load_from_checkpoint`` is used.
        ckpt_dir: Directory containing checkpoints. Defaults to
            ``"./saved_models"`` (relative to the current working directory),
            which matches the path this function previously hard-coded.

    Returns:
        The loaded model, already switched to eval mode.
    """
    model_cls = vae_models[model_type]
    model = model_cls.load_from_checkpoint(f"{ckpt_dir}/{ckpt}")
    # Inference-only helper: switch off training-mode layer behaviour.
    model.eval()
    return model

def parse_model_file_name(file_name):
    """Build a human-readable label from a checkpoint file name.

    Hard-coded to the naming scheme ``<name>_<x>_<alpha>_<y>_<dim>[...].<ext>``;
    only the 1st, 3rd and 5th underscore-separated fields of the stem are read.
    Example: ``"vae_alpha_0.5_dim_10.ckpt"`` -> ``"Vanilla VAE | alpha=0.5 | dim=10"``.

    Args:
        file_name: Checkpoint file name. Any leading directory path is ignored,
            and only the final extension is stripped, so the name is expected
            to carry an extension when alpha/dim contain dots.

    Returns:
        A label of the form ``"<model> | alpha=<alpha> | dim=<dim>"``. Known
        short names are mapped to display names; unknown names are shown as-is.

    Raises:
        IndexError: if the stem has fewer than five underscore-separated fields.
    """
    from pathlib import PurePath

    # Strip the directory and only the *final* extension; splitting on the
    # first "." would truncate decimal values such as alpha=0.5.
    stem = PurePath(file_name).stem
    fields = stem.split("_")
    name, alpha, dim = fields[0], fields[2], fields[4]
    # Previously an unknown name was dropped, leaving a label that began with " | ".
    display_names = {"vae": "Vanilla VAE"}
    model_label = display_names.get(name, name)
    return f"{model_label} | alpha={alpha} | dim={dim}"

def tensor_to_img(tsr):
    """Convert an image tensor to a PIL image with torchvision's ToPILImage.

    If ``tsr`` is 4-D, ``squeeze(0)`` is applied first. That only removes the
    leading axis when its size is 1, so a larger batch is passed through
    unchanged (ToPILImage presumably rejects it).

    NOTE(review): ToPILImage scales float tensors by 255, so this presumably
    expects a (C, H, W) float tensor in [0, 1] — confirm against the model's
    output range (canvas inputs in this file are scaled to [-1, 1]).
    """
    if tsr.ndim == 4:
        tsr = tsr.squeeze(0)
    # A bare ToPILImage is equivalent to wrapping it in a one-element Compose.
    return ToPILImage()(tsr)


def resize_img(img, w, h):
    """Return ``img.resize((w, h))`` — presumably a PIL image resized to w x h."""
    target_size = (w, h)
    return img.resize(target_size)

def canvas_to_tensor(canvas, size=(28, 28), mode="nearest"):
    """
    Convert a canvas's RGBA image to a single-channel tensor scaled to [-1, 1].

    The last channel (alpha) is discarded, the remaining channels are averaged
    into one grayscale channel, values are mapped from [0, 255] to [-1, 1], and
    the result is resized with ``torch.nn.functional.interpolate``.

    Args:
        canvas: Object exposing ``image_data`` as an (H, W, C) array whose last
            channel is alpha (RGBA). Values are assumed to lie in [0, 255].
        size: Output spatial size (height, width). Defaults to 28x28.
        mode: Interpolation mode forwarded to ``interpolate``
            (default ``"nearest"``, matching the previous behaviour).

    Returns:
        A float32 tensor of shape ``(1, 1, *size)``.
    """
    rgb = canvas.image_data[:, :, :-1]  # Ignore alpha channel
    # Average RGB to grayscale, then rescale [0, 255] -> [0, 1] -> [-1, 1].
    gray = rgb.mean(axis=2) / 255 * 2 - 1.0
    # torch.as_tensor replaces the legacy torch.FloatTensor constructor.
    tens = torch.as_tensor(gray, dtype=torch.float32)
    tens = tens.unsqueeze(0).unsqueeze(0)  # add batch and channel axes -> (1, 1, H, W)
    return interpolate(tens, size=size, mode=mode)


def export_to_onnx(ckpt, filepath="model.onnx", model_type="vae"):
    """Export a saved checkpoint to an ONNX file.

    The checkpoint is loaded with ``load_model``. The inputs of the first batch
    from the model's own test dataloader serve as the example input for
    ``to_onnx``, and the graph is written with its parameters.

    Args:
        ckpt: Checkpoint file name, as accepted by ``load_model``.
        filepath: Destination of the ONNX file (default ``"model.onnx"``,
            the path previously hard-coded).
        model_type: Model key forwarded to ``load_model`` (default ``"vae"``).

    Returns:
        ``filepath``, the path the ONNX file was written to.
    """
    model = load_model(ckpt, model_type=model_type)
    # NOTE(review): assumes each test batch is an (inputs, targets) pair — confirm
    # against the model's test_dataloader.
    # NOTE(review): no dynamic_axes are passed, so the exported batch dimension is
    # presumably fixed to this batch's size — confirm if variable batch sizes are needed.
    sample, _ = next(iter(model.test_dataloader()))
    model.to_onnx(filepath, sample, export_params=True)
    return filepath