souranil3d committed · Commit 16906c1
Parent(s): 5f3dfcf

First commit for VAE space

Browse files:
- .gitignore +11 -0
- Dockerfile +9 -0
- README.md +20 -12
- app.py +55 -0
- config.yaml +20 -0
- config/__init__.py +2 -0
- config/config.py +59 -0
- inference.py +24 -0
- models/__init__.py +11 -0
- models/conv_vae.py +239 -0
- models/vae.py +213 -0
- requirements.txt +8 -0
- test.py +22 -0
- train.py +37 -0
- utils.py +63 -0
.gitignore
ADDED
@@ -0,0 +1,11 @@
+medium/
+data/
+.vscode/
+__pycache__/
+.ipynb_checkpoints/
+lightning_logs/
+log_images/
+logs/
+*.env
+.idea/
+saved_models/
Dockerfile
ADDED
@@ -0,0 +1,9 @@
+FROM python:3.8-slim-buster
+WORKDIR /app
+EXPOSE $PORT
+
+COPY requirements.txt /
+RUN pip3 install -r /requirements.txt
+COPY . /app
+
+CMD streamlit run app.py --server.port $PORT
README.md
CHANGED
@@ -1,12 +1,20 @@
+### VAE with Pytorch-Lightning
+
+This is inspired by vae-playground. It is an example project where we test out VAE and Conv-VAE models with multiple
+datasets such as MNIST, CelebA and Fashion-MNIST.
+
+It also comes with an example Streamlit app, deployed on Hugging Face.
+
+
+## Model Training
+
+You can train the VAE models by running `train.py` and editing the `config.yaml` file. \
+Hyperparameters to change are:
+- model_type [vae|conv_vae]
+- alpha
+- hidden_dim
+- dataset [celeba|mnist|fashion-mnist]
+
+There are other configurations that can be changed if required, like height, width and channels. The file also contains the pytorch-lightning configs.
+
+
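For reference, here is a minimal sketch of what a training run looks like with this layout. It mirrors `train.py` further down in this commit and assumes the defaults from `config.yaml`; the checkpoint name at the end is only an example of what `train.py` would generate.

```python
# Minimal training sketch: build the model selected in config.yaml and fit it.
from pytorch_lightning import Trainer
from models import vae_models   # {"vae": VAE, "conv-vae": Conv_VAE}
from config import config       # parsed from config.yaml at import time

model = vae_models[config.model_type](**config.model_config.dict())
trainer = Trainer(**config.train_config.dict())
trainer.fit(model)
trainer.save_checkpoint("saved_models/conv-vae_alpha_1024_dim_4096.ckpt")  # example name
```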
app.py
ADDED
@@ -0,0 +1,55 @@
+import streamlit as st
+from streamlit_drawable_canvas import st_canvas
+import os
+import utils
+from PIL import Image
+
+
+st.set_page_config("VAE MNIST Pytorch Lightning")
+st.title("VAE Playground")
+# title_img = Image.open("images/title_img.jpg")
+
+# st.image(title_img)
+st.markdown(
+    "This is a simple streamlit app to showcase how a simple VAE works."
+)
+
+def load_model_files():
+    files = os.listdir("./saved_models/")
+    # Docker creates some whiteout files which might show up here
+    files = [i for i in files if ".ckpt" in i]
+    clean_names = [utils.parse_model_file_name(name) for name in files]
+    return {k: v for k, v in zip(clean_names, files)}
+
+
+file_name_map = load_model_files()
+files = list(file_name_map.keys())
+
+st.header("🖼️ Image Reconstruction", "recon")
+
+with st.form("reconstruction"):
+    model_name = st.selectbox("Choose Model:", files,
+                              key="recon_model_select")
+    recon_model_name = file_name_map[model_name]
+    recon_canvas = st_canvas(
+        # Fixed fill color with some opacity
+        fill_color="rgba(255, 165, 0, 0.3)",
+        stroke_width=8,
+        stroke_color="#FFFFFF",
+        background_color="#000000",
+        update_streamlit=True,
+        height=150,
+        width=150,
+        drawing_mode="freedraw",
+        key="recon_canvas",
+    )
+    submit = st.form_submit_button("Perform Reconstruction")
+    if submit:
+        recon_model = utils.load_model(recon_model_name)
+        inp_tens = utils.canvas_to_tensor(recon_canvas)
+        _, _, out = recon_model(inp_tens)
+        out = (out+1)/2
+        out_img = utils.resize_img(utils.tensor_to_img(out), 150, 150)
+if submit:
+    st.image(out_img)
+
config.yaml
ADDED
@@ -0,0 +1,20 @@
+training_params:
+  max_epochs: 30
+  auto_lr_find: false
+  gpus: 1
+model_params:
+  model_type: conv-vae # vae or conv-vae
+  lr: 0.005
+  batch_size: 1
+  hidden_size: 4096
+  latent_size: 128
+  alpha: 1024
+  dataset: "fashion-mnist"
+  save_images: true
+  save_path: "log_images/"
+  channels: 1
+  height: 64
+  width: 64
+logger_params:
+  name: "conv-vae"
+  save_dir: "logs/"
config/__init__.py
ADDED
@@ -0,0 +1,2 @@
+from .config import config
+__all__ = ["config"]
config/config.py
ADDED
@@ -0,0 +1,59 @@
+from pydantic import BaseModel
+from typing import Optional, Union
+import yaml
+
+
+class TrainConfig(BaseModel):
+    max_epochs: int
+    auto_lr_find: Union[bool, int]
+    gpus: int
+
+
+class VAEConfig(BaseModel):
+    model_type: str
+    hidden_size: int
+    latent_size: int
+    alpha: int
+    dataset: str
+    batch_size: Optional[int] = 64
+    save_images: Optional[bool] = False
+    lr: Optional[float] = None
+    save_path: Optional[str] = None
+
+
+class ConvVAEConfig(VAEConfig):
+    channels: int
+    height: int
+    width: int
+
+
+class LoggerConfig(BaseModel):
+    name: str
+    save_dir: str
+
+
+class Config(BaseModel):
+    model_config: Union[VAEConfig, ConvVAEConfig]
+    train_config: TrainConfig
+    model_type: str
+    log_config: LoggerConfig
+
+
+def load_config(path="config.yaml"):
+    config = yaml.load(open(path), yaml.SafeLoader)
+    model_type = config['model_params']['model_type']
+    if model_type == "vae":
+        model_config = VAEConfig(**config["model_params"])
+    elif model_type == "conv-vae":
+        model_config = ConvVAEConfig(**config["model_params"])
+    else:
+        raise NotImplementedError(f"Model {model_type} is not implemented")
+    train_config = TrainConfig(**config["training_params"])
+    log_config = LoggerConfig(**config["logger_params"])
+    config = Config(model_config=model_config, train_config=train_config,
+                    model_type=model_type, log_config=log_config)
+
+    return config
+
+
+config = load_config()
inference.py
ADDED
@@ -0,0 +1,24 @@
+from models import vae_models
+from config import config
+from PIL import Image
+from torchvision.transforms import Resize, ToPILImage, Compose
+
+from utils import load_model, tensor_to_img, resize_img, export_to_onnx
+
+
+def predict(model_ckpt="vae_alpha_1024_dim_128.ckpt"):
+    model_type = config.model_type
+    model = vae_models[model_type].load_from_checkpoint(f"./saved_models/{model_ckpt}")
+    model.eval()
+    test_iter = iter(model.test_dataloader())
+    d, _ = next(test_iter)
+    _, _, out = model(d)
+    out_img = tensor_to_img(out)
+    return out_img
+
+
+if __name__ == "__main__":
+    predict()
+    # export_to_onnx("./saved_models/vae.ckpt")
models/__init__.py
ADDED
@@ -0,0 +1,11 @@
+from .vae import VAE, Flatten, Stack  # noqa: F401
+from .conv_vae import Conv_VAE  # noqa: F401
+
+__all__ = [
+    'VAE', 'Flatten', 'Stack',
+    'Conv_VAE',
+]
+vae_models = {
+    "conv-vae": Conv_VAE,
+    "vae": VAE
+}
models/conv_vae.py
ADDED
@@ -0,0 +1,239 @@
+from .vae import VAE, Flatten, Stack
+import torch.nn as nn
+import pytorch_lightning as pl
+import torch
+import os
+import random
+from typing import Optional
+import torchvision.transforms as transforms
+from torchvision.datasets import MNIST, FashionMNIST, CelebA
+from torch.utils.data import DataLoader
+from torchvision.utils import save_image
+from torch.optim import Adam
+from torch.optim.lr_scheduler import ReduceLROnPlateau
+
+
+class PrintShape(nn.Module):
+    def __init__(self):
+        super(PrintShape, self).__init__()
+
+    def forward(self, x):
+        # Do your print / debug stuff here
+        # print(f"Shape: {x.shape}")
+        return x
+
+
+class UnFlatten(nn.Module):
+    def forward(self, input, size=4096):
+        # Reshape (B, size) back to (B, size, 1, 1) for the transposed convolutions
+        return input.view(input.size(0), size, 1, 1)
+
+
+class Flatten(nn.Module):
+    def forward(self, input):
+        return input.view(input.size(0), -1)
+
+
+class Conv_VAE(pl.LightningModule):
+    def __init__(self, channels: int, height: int, width: int, lr: int,
+                 latent_size: int, hidden_size: int, alpha: int, batch_size: int,
+                 dataset: Optional[str] = None,
+                 save_images: Optional[bool] = None,
+                 save_path: Optional[str] = None, **kwargs):
+        super().__init__()
+        self.latent_size = latent_size
+        self.hidden_size = hidden_size
+        if save_images:
+            self.save_path = f'{save_path}/{kwargs["model_type"]}_images/'
+        self.save_hyperparameters()
+        self.save_images = save_images
+        self.lr = lr
+        self.batch_size = batch_size
+        self.alpha = alpha
+        self.dataset = dataset
+        assert not height % 4 and not width % 4, "Choose height and width to "\
+            "be divisible by 4"
+        self.channels = channels
+        self.height = height
+        self.width = width
+
+        self.data_transform = transforms.Compose([
+            transforms.Resize(64),
+            transforms.CenterCrop((64, 64)),
+            transforms.ToTensor()
+        ])
+
+        self.encoder = nn.Sequential(
+            PrintShape(),
+            nn.Conv2d(self.channels, 32, kernel_size=3, stride=2, padding=1),
+            nn.BatchNorm2d(32),
+            nn.LeakyReLU(),
+            PrintShape(),
+            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
+            nn.BatchNorm2d(64),
+            nn.LeakyReLU(),
+            PrintShape(),
+            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
+            nn.BatchNorm2d(128),
+            nn.LeakyReLU(),
+            PrintShape(),
+            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
+            nn.BatchNorm2d(256),
+            nn.LeakyReLU(),
+            PrintShape(),
+            Flatten(),
+            PrintShape(),
+        )
+
+        self.fc1 = nn.Linear(self.hidden_size, self.latent_size)
+        self.fc2 = nn.Linear(self.latent_size, self.hidden_size)
+
+        self.decoder = nn.Sequential(
+            PrintShape(),
+            # nn.Linear(self.hidden_size, self.hidden_size),
+            # PrintShape(),
+            # nn.BatchNorm1d(self.hidden_size),
+            UnFlatten(),
+            PrintShape(),
+            nn.ConvTranspose2d(self.hidden_size, 256, kernel_size=6, stride=2, padding=1),
+            PrintShape(),
+            nn.LeakyReLU(),
+            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
+            nn.BatchNorm2d(128),
+            PrintShape(),
+            nn.LeakyReLU(),
+            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
+            nn.BatchNorm2d(64),
+            PrintShape(),
+            nn.LeakyReLU(),
+            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
+            nn.BatchNorm2d(32),
+            PrintShape(),
+            nn.LeakyReLU(),
+            nn.ConvTranspose2d(32, self.channels, kernel_size=4, stride=2, padding=1),
+            nn.BatchNorm2d(self.channels),
+            PrintShape(),
+            nn.Sigmoid(),
+        )
+
+    def encode(self, x):
+        hidden = self.encoder(x)
+        # Note: mu and log_var share the same projection layer (fc1) here
+        mu, log_var = self.fc1(hidden), self.fc1(hidden)
+        return mu, log_var
+
+    def decode(self, z):
+        z = self.fc2(z)
+        x = self.decoder(z)
+        return x
+
+    def reparametrize(self, mu, log_var):
+        # Reparametrization Trick to allow gradients to backpropagate from the
+        # stochastic part of the model
+        sigma = torch.exp(0.5*log_var)
+        z = torch.randn_like(sigma)
+        return mu + sigma*z
+
+    def training_step(self, batch, batch_idx):
+        x, _ = batch
+        mu, log_var, x_out = self.forward(x)
+        kl_loss = (-0.5*(1+log_var - mu**2 -
+                         torch.exp(log_var)).sum(dim=1)).mean(dim=0)
+        recon_loss_criterion = nn.MSELoss()
+        recon_loss = recon_loss_criterion(x, x_out)
+        # print(kl_loss.item(), recon_loss.item())
+        loss = recon_loss*self.alpha + kl_loss
+
+        self.log('train_loss', loss, on_step=False,
+                 on_epoch=True, prog_bar=True)
+        return loss
+
+    def validation_step(self, batch, batch_idx):
+        x, _ = batch
+        mu, log_var, x_out = self.forward(x)
+
+        kl_loss = (-0.5*(1+log_var - mu**2 -
+                         torch.exp(log_var)).sum(dim=1)).mean(dim=0)
+        recon_loss_criterion = nn.MSELoss()
+        recon_loss = recon_loss_criterion(x, x_out)
+        loss = recon_loss*self.alpha + kl_loss
+        self.log('val_kl_loss', kl_loss, on_step=False, on_epoch=True)
+        self.log('val_recon_loss', recon_loss, on_step=False, on_epoch=True)
+        self.log('val_loss', loss, on_step=False, on_epoch=True)
+        return x_out, loss
+
+    def validation_epoch_end(self, outputs):
+        if not self.save_images:
+            return
+        if not os.path.exists(self.save_path):
+            os.makedirs(self.save_path)
+        choice = random.choice(outputs)
+        output_sample = choice[0]
+        output_sample = output_sample.reshape(-1, 1, self.width, self.height)
+        # output_sample = self.scale_image(output_sample)
+        save_image(
+            output_sample,
+            f"{self.save_path}/epoch_{self.current_epoch+1}.png",
+            # value_range=(-1, 1)
+        )
+
+    def configure_optimizers(self):
+        optimizer = Adam(self.parameters(), lr=(self.lr or self.learning_rate))
+        lr_scheduler = ReduceLROnPlateau(optimizer,)
+        return {
+            "optimizer": optimizer, "lr_scheduler": lr_scheduler,
+            "monitor": "val_loss"
+        }
+
+    def forward(self, x):
+        mu, log_var = self.encode(x)
+        hidden = self.reparametrize(mu, log_var)
+        output = self.decode(hidden)
+        return mu, log_var, output
+
+    # Functions for dataloading
+    def train_dataloader(self):
+        if self.dataset == "mnist":
+            train_set = MNIST('data/', download=True,
+                              train=True, transform=self.data_transform)
+        elif self.dataset == "fashion-mnist":
+            train_set = FashionMNIST(
+                'data/', download=True, train=True,
+                transform=self.data_transform)
+        elif self.dataset == "celeba":
+            train_set = CelebA('data/', download=False, split="train", transform=self.data_transform)
+        return DataLoader(train_set, batch_size=self.batch_size, shuffle=True)
+
+    def val_dataloader(self):
+        if self.dataset == "mnist":
+            val_set = MNIST('data/', download=True, train=False,
+                            transform=self.data_transform)
+        elif self.dataset == "fashion-mnist":
+            val_set = FashionMNIST(
+                'data/', download=True, train=False,
+                transform=self.data_transform)
+        elif self.dataset == "celeba":
+            val_set = CelebA('data/', download=False, split="valid", transform=self.data_transform)
+        return DataLoader(val_set, batch_size=self.batch_size)
+
+    def test_dataloader(self):
+        if self.dataset == "mnist":
+            test_set = MNIST('data/', download=True, train=False,
+                             transform=self.data_transform)
+        elif self.dataset == "fashion-mnist":
+            test_set = FashionMNIST(
+                'data/', download=True, train=False,
+                transform=self.data_transform)
+        elif self.dataset == "celeba":
+            test_set = CelebA('data/', download=False, split="test", transform=self.data_transform)
+        return DataLoader(test_set, batch_size=self.batch_size)
+
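As a quick sanity check on the shapes flowing through the convolutional encoder/decoder above, here is a hedged sketch; the constructor arguments simply repeat the `config.yaml` defaults in this commit, and nothing here is part of the committed code:

```python
# Shape check for Conv_VAE with 64x64 single-channel inputs (config.yaml defaults).
import torch
from models import Conv_VAE

model = Conv_VAE(channels=1, height=64, width=64, lr=0.005,
                 latent_size=128, hidden_size=4096, alpha=1024,
                 batch_size=1, dataset="fashion-mnist", model_type="conv-vae")
model.eval()

x = torch.randn(2, 1, 64, 64)                 # fake batch of two 64x64 images
mu, log_var, x_hat = model(x)
print(mu.shape, log_var.shape, x_hat.shape)   # (2, 128), (2, 128), (2, 1, 64, 64)
```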
models/vae.py
ADDED
@@ -0,0 +1,213 @@
+import torch
+import torch.nn as nn
+import pytorch_lightning as pl
+import random
+from torchvision.datasets import MNIST, FashionMNIST, CelebA
+import torchvision.transforms as transforms
+from torch.utils.data import DataLoader
+from torchvision.utils import save_image
+from torch.optim import Adam
+from torch.optim.lr_scheduler import ReduceLROnPlateau
+
+import os
+from typing import Optional
+
+
+class Flatten(nn.Module):
+    def forward(self, x):
+        return x.view(x.size(0), -1)
+
+
+class Stack(nn.Module):
+    def __init__(self, channels, height, width):
+        super(Stack, self).__init__()
+        self.channels = channels
+        self.height = height
+        self.width = width
+
+    def forward(self, x):
+        return x.view(x.size(0), self.channels, self.height, self.width)
+
+
+class VAE(pl.LightningModule):
+    def __init__(self, latent_size: int, hidden_size: int, alpha: int, lr: float,
+                 batch_size: int,
+                 dataset: Optional[str] = None,
+                 save_images: Optional[bool] = None,
+                 save_path: Optional[str] = None, **kwargs):
+        """Init function for the VAE
+
+        Args:
+            latent_size (int): Latent Hidden Size
+            alpha (int): Hyperparameter to control the importance of
+                reconstruction loss vs KL-Divergence Loss
+            lr (float): Learning Rate, will not be used if auto_lr_find is used.
+            dataset (Optional[str]): Dataset to use
+            save_images (Optional[bool]): Boolean to decide whether to save images
+            save_path (Optional[str]): Path to save images
+        """
+
+        super().__init__()
+        self.latent_size = latent_size
+        self.hidden_size = hidden_size
+        if save_images:
+            self.save_path = f'{save_path}/{kwargs["model_type"]}_images/'
+        self.save_hyperparameters()
+        self.save_images = save_images
+        self.lr = lr
+        self.batch_size = batch_size
+        self.encoder = nn.Sequential(
+            Flatten(),
+            nn.Linear(784, 392), nn.BatchNorm1d(392), nn.LeakyReLU(0.1),
+            nn.Linear(392, 196), nn.BatchNorm1d(196), nn.LeakyReLU(0.1),
+            nn.Linear(196, 128), nn.BatchNorm1d(128), nn.LeakyReLU(0.1),
+            nn.Linear(128, latent_size)
+        )
+        self.hidden2mu = nn.Linear(latent_size, latent_size)
+        self.hidden2log_var = nn.Linear(latent_size, latent_size)
+        self.alpha = alpha
+        self.decoder = nn.Sequential(
+            nn.Linear(latent_size, 128), nn.BatchNorm1d(128), nn.LeakyReLU(0.1),
+            nn.Linear(128, 196), nn.BatchNorm1d(196), nn.LeakyReLU(0.1),
+            nn.Linear(196, 392), nn.BatchNorm1d(392), nn.LeakyReLU(0.1),
+            nn.Linear(392, 784),
+            Stack(1, 28, 28),
+            nn.Tanh()
+        )
+        self.height = kwargs.get("height")
+        self.width = kwargs.get("width")
+        self.data_transform = transforms.Compose([
+            transforms.ToTensor(),
+            transforms.Lambda(lambda x: 2*x - 1.)])
+        self.dataset = dataset
+
+    def encode(self, x):
+        hidden = self.encoder(x)
+        mu = self.hidden2mu(hidden)
+        log_var = self.hidden2log_var(hidden)
+        return mu, log_var
+
+    def decode(self, x):
+        x = self.decoder(x)
+        return x
+
+    def reparametrize(self, mu, log_var):
+        # Reparametrization Trick to allow gradients to backpropagate from the
+        # stochastic part of the model
+        sigma = torch.exp(0.5*log_var)
+        z = torch.randn_like(sigma)
+        return mu + sigma*z
+
+    def training_step(self, batch, batch_idx):
+        x, _ = batch
+        mu, log_var, x_out = self.forward(x)
+        kl_loss = (-0.5*(1+log_var - mu**2 -
+                         torch.exp(log_var)).sum(dim=1)).mean(dim=0)
+        recon_loss_criterion = nn.MSELoss()
+        recon_loss = recon_loss_criterion(x, x_out)
+        # print(kl_loss.item(), recon_loss.item())
+        loss = recon_loss*self.alpha + kl_loss
+
+        self.log('train_loss', loss, on_step=False,
+                 on_epoch=True, prog_bar=True)
+        return loss
+
+    def validation_step(self, batch, batch_idx):
+        x, _ = batch
+        mu, log_var, x_out = self.forward(x)
+
+        kl_loss = (-0.5*(1+log_var - mu**2 -
+                         torch.exp(log_var)).sum(dim=1)).mean(dim=0)
+        recon_loss_criterion = nn.MSELoss()
+        recon_loss = recon_loss_criterion(x, x_out)
+        # print(kl_loss.item(), recon_loss.item())
+        loss = recon_loss*self.alpha + kl_loss
+        self.log('val_kl_loss', kl_loss, on_step=False, on_epoch=True)
+        self.log('val_recon_loss', recon_loss, on_step=False, on_epoch=True)
+        self.log('val_loss', loss, on_step=False, on_epoch=True)
+        # print(x.mean(), x_out.mean())
+        return x_out, loss
+
+    def validation_epoch_end(self, outputs):
+        if not self.save_images:
+            return
+        if not os.path.exists(self.save_path):
+            os.makedirs(self.save_path)
+        choice = random.choice(outputs)
+        output_sample = choice[0]
+        output_sample = output_sample.reshape(-1, 1, self.width, self.height)
+        # output_sample = self.scale_image(output_sample)
+        save_image(
+            output_sample,
+            f"{self.save_path}/epoch_{self.current_epoch+1}.png",
+            # value_range=(-1, 1)
+        )
+
+    def configure_optimizers(self):
+        optimizer = Adam(self.parameters(), lr=(self.lr or self.learning_rate))
+        lr_scheduler = ReduceLROnPlateau(optimizer,)
+        return {
+            "optimizer": optimizer, "lr_scheduler": lr_scheduler,
+            "monitor": "val_loss"
+        }
+
+    def forward(self, x):
+        mu, log_var = self.encode(x)
+        hidden = self.reparametrize(mu, log_var)
+        output = self.decode(hidden)
+        return mu, log_var, output
+
+    # Functions for dataloading
+    def train_dataloader(self):
+        if self.dataset == "mnist":
+            train_set = MNIST('data/', download=True,
+                              train=True, transform=self.data_transform)
+        elif self.dataset == "fashion-mnist":
+            train_set = FashionMNIST(
+                'data/', download=True, train=True,
+                transform=self.data_transform)
+        elif self.dataset == "celeba":
+            train_set = CelebA('data/', download=False, split="train", transform=self.data_transform)
+        return DataLoader(train_set, batch_size=self.batch_size, shuffle=True)
+
+    def val_dataloader(self):
+        if self.dataset == "mnist":
+            val_set = MNIST('data/', download=True, train=False,
+                            transform=self.data_transform)
+        elif self.dataset == "fashion-mnist":
+            val_set = FashionMNIST(
+                'data/', download=True, train=False,
+                transform=self.data_transform)
+        elif self.dataset == "celeba":
+            val_set = CelebA('data/', download=False, split="valid", transform=self.data_transform)
+        return DataLoader(val_set, batch_size=self.batch_size)
+
+    def scale_image(self, img):
+        out = (img + 1) / 2
+        return out
+
+    def interpolate(self, x1, x2):
+        assert x1.shape == x2.shape, "Inputs must be of the same shape"
+        if x1.dim() == 3:
+            x1 = x1.unsqueeze(0)
+        if x2.dim() == 3:
+            x2 = x2.unsqueeze(0)
+        if self.training:
+            raise Exception(
+                "This function should not be called when model is still "
+                "in training mode. Use model.eval() before calling the "
+                "function")
+        mu1, lv1 = self.encode(x1)
+        mu2, lv2 = self.encode(x2)
+        z1 = self.reparametrize(mu1, lv1)
+        z2 = self.reparametrize(mu2, lv2)
+        weights = torch.arange(0.1, 0.9, 0.1)
+        intermediate = [self.decode(z1)]
+        for wt in weights:
+            inter = (1.-wt)*z1 + wt*z2
+            intermediate.append(self.decode(inter))
+        intermediate.append(self.decode(z2))
+        out = torch.stack(intermediate, dim=0).squeeze(1)
+        return out, (mu1, lv1), (mu2, lv2)
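The `interpolate` helper above is not exercised anywhere else in this commit. A hedged sketch of how it could be used to morph between two MNIST test digits (the checkpoint name is an assumption; `train.py` would produce it with `model_type: vae`):

```python
# Latent-space interpolation between two MNIST digits using VAE.interpolate.
import torch
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torchvision.utils import save_image
from models import VAE

model = VAE.load_from_checkpoint("saved_models/vae_alpha_1024_dim_128.ckpt")  # assumed
model.eval()

# Same normalisation the model uses for training: scale [0, 1] -> [-1, 1].
tfm = transforms.Compose([transforms.ToTensor(),
                          transforms.Lambda(lambda x: 2 * x - 1.)])
ds = MNIST("data/", train=False, download=True, transform=tfm)
x1, _ = ds[0]
x2, _ = ds[1]

with torch.no_grad():
    frames, _, _ = model.interpolate(x1, x2)      # (10, 1, 28, 28), values in [-1, 1]
save_image((frames + 1) / 2, "interpolation.png", nrow=frames.size(0))
```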
requirements.txt
ADDED
@@ -0,0 +1,8 @@
+torch==1.8.0
+python-box==5.3.0
+tensorboardX==2.1
+pydantic
+streamlit==0.82.0
+streamlit-drawable-canvas==0.8.0
+pytorch_lightning>=1.1.1
+torchvision==0.9.0
test.py
ADDED
@@ -0,0 +1,22 @@
+from pytorch_lightning import Trainer
+from models import vae_models
+from config import config
+from pytorch_lightning.loggers import TensorBoardLogger
+import os
+
+def make_model(config):
+    model_type = config.model_type
+    model_config = config.model_config
+
+    if model_type not in vae_models.keys():
+        raise NotImplementedError("Model Architecture not implemented")
+    else:
+        return vae_models[model_type](**model_config.dict())
+
+
+if __name__ == "__main__":
+    model_type = config.model_type
+    model = vae_models[model_type].load_from_checkpoint("./saved_models/vae_alpha_1024_dim_128.ckpt")
+    logger = TensorBoardLogger(**config.log_config.dict())
+    trainer = Trainer(gpus=1, logger=logger)
+    trainer.test(model)
train.py
ADDED
@@ -0,0 +1,37 @@
+from pytorch_lightning import Trainer
+from models import vae_models
+from config import config
+from pytorch_lightning.callbacks import LearningRateMonitor
+from pytorch_lightning.loggers import TensorBoardLogger
+import os
+os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
+
+
+def make_model(config):
+    model_type = config.model_type
+    model_config = config.model_config
+
+    if model_type not in vae_models.keys():
+        raise NotImplementedError("Model Architecture not implemented")
+    else:
+        return vae_models[model_type](**model_config.dict())
+
+
+if __name__ == "__main__":
+    model = make_model(config)
+    train_config = config.train_config
+    logger = TensorBoardLogger(**config.log_config.dict())
+    trainer = Trainer(**train_config.dict(), logger=logger,
+                      callbacks=LearningRateMonitor())
+    if train_config.auto_lr_find:
+        lr_finder = trainer.tuner.lr_find(model)
+        new_lr = lr_finder.suggestion()
+        print("Learning Rate Chosen:", new_lr)
+        model.lr = new_lr
+        trainer.fit(model)
+    else:
+        trainer.fit(model)
+    if not os.path.isdir("./saved_models"):
+        os.mkdir("./saved_models")
+    trainer.save_checkpoint(
+        f"saved_models/{config.model_type}_alpha_{config.model_config.alpha}_dim_{config.model_config.hidden_size}.ckpt")
utils.py
ADDED
@@ -0,0 +1,63 @@
+from pytorch_lightning import Trainer
+from torchvision.utils import save_image
+from models import vae_models
+from config import config
+from PIL import Image
+from pytorch_lightning.loggers import TensorBoardLogger
+import torch
+from torch.nn.functional import interpolate
+from torchvision.transforms import Resize, ToPILImage, Compose
+from torchvision.utils import make_grid
+
+def load_model(ckpt, model_type="vae"):
+    model = vae_models[model_type].load_from_checkpoint(f"./saved_models/{ckpt}")
+    model.eval()
+    return model
+
+def parse_model_file_name(file_name):
+    # Hard-coded parsing based on the filenames produced by train.py
+    substrings = file_name.split(".")[0].split("_")
+    name, alpha, dim = substrings[0], substrings[2], substrings[4]
+    new_name = ""
+    if name == "vae":
+        new_name += "Vanilla VAE"
+    new_name += f" | alpha={alpha}"
+    new_name += f" | dim={dim}"
+    return new_name
+
+def tensor_to_img(tsr):
+    if tsr.ndim == 4:
+        tsr = tsr.squeeze(0)
+
+    transform = Compose([
+        ToPILImage()
+    ])
+    img = transform(tsr)
+    return img
+
+
+def resize_img(img, w, h):
+    return img.resize((w, h))
+
+def canvas_to_tensor(canvas):
+    """
+    Convert an RGBA canvas image to a single-channel B/W array and then from
+    a numpy array to a PyTorch tensor of shape [1, 1, 28, 28].
+    """
+    img = canvas.image_data
+    img = img[:, :, :-1]  # Ignore alpha channel
+    img = img.mean(axis=2)
+    img = img/255
+    img = img*2 - 1.
+    img = torch.FloatTensor(img)
+    tens = img.unsqueeze(0).unsqueeze(0)
+    tens = interpolate(tens, (28, 28))
+    return tens
+
+
+def export_to_onnx(ckpt):
+    model = load_model(ckpt)
+    filepath = "model.onnx"
+    test_iter = iter(model.test_dataloader())
+    sample, _ = next(test_iter)
+    model.to_onnx(filepath, sample, export_params=True)
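Taken together, these helpers are what `app.py` wires into the Streamlit form. A small end-to-end sketch, using a synthetic RGBA array in place of the `st_canvas` result and an assumed checkpoint name; it is an illustration, not part of the committed code:

```python
# Run the canvas -> tensor -> model -> image pipeline from utils.py end to end.
import numpy as np
import torch
import utils

class FakeCanvas:
    # Mimics the only attribute app.py reads from the streamlit canvas object.
    image_data = np.zeros((150, 150, 4), dtype=np.uint8)

tens = utils.canvas_to_tensor(FakeCanvas())                 # -> [1, 1, 28, 28]
model = utils.load_model("vae_alpha_1024_dim_128.ckpt")     # assumed checkpoint
with torch.no_grad():
    _, _, out = model(tens)                                 # (mu, log_var, reconstruction)
img = utils.resize_img(utils.tensor_to_img((out + 1) / 2), 150, 150)
img.save("reconstruction.png")
```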