souranil3d committed
Commit 16906c1
Parent: 5f3dfcf

First commit for VAE space

Files changed (15)
  1. .gitignore +11 -0
  2. Dockerfile +9 -0
  3. README.md +20 -12
  4. app.py +55 -0
  5. config.yaml +20 -0
  6. config/__init__.py +2 -0
  7. config/config.py +59 -0
  8. inference.py +24 -0
  9. models/__init__.py +11 -0
  10. models/conv_vae.py +239 -0
  11. models/vae.py +213 -0
  12. requirements.txt +8 -0
  13. test.py +22 -0
  14. train.py +37 -0
  15. utils.py +63 -0
.gitignore ADDED
@@ -0,0 +1,11 @@
+ medium/
+ data/
+ .vscode/
+ __pycache__/
+ .ipynb_checkpoints/
+ lightning_logs/
+ log_images/
+ logs/
+ *.env
+ .idea/
+ saved_models/
Dockerfile ADDED
@@ -0,0 +1,11 @@
+ FROM python:3.8-slim-buster
+ WORKDIR /app
+ # Default port for Streamlit; deployment platforms may override PORT (8501 is an assumed default)
+ ENV PORT=8501
+ EXPOSE $PORT
+
+ COPY requirements.txt /
+ RUN pip3 install -r /requirements.txt
+ COPY . /app
+
+ CMD streamlit run app.py --server.port $PORT
README.md CHANGED
@@ -1,12 +1,20 @@
- ---
- title: VAE
- emoji: 📊
- colorFrom: pink
- colorTo: red
- sdk: streamlit
- app_file: app.py
- pinned: false
- license: apache-2.0
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference
+ # VAE with Pytorch-Lightning
+
+ This is inspired by vae-playground. It is an example where we test out the vae and conv_vae models on multiple datasets,
+ such as MNIST, CelebA, and Fashion-MNIST.
+
+ It also comes with an example Streamlit app, deployed on Hugging Face.
+
+
+ ## Model Training
+
+ You can train the VAE models by running `train.py` after editing the `config.yaml` file. \
+ Hyperparameters to change are:
+ - model_type [vae|conv-vae]
+ - alpha
+ - hidden_size
+ - dataset [celeba|mnist|fashion-mnist]
+
+ There are other configurations that can be changed if required, like height, width, channels, etc. The file also contains the pytorch-lightning trainer configs.
+
+
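A minimal sketch of what the training entry point does with this configuration, assuming the shipped `config.yaml` defaults (`model_type: conv-vae`); all names come from the files added in this commit:

```python
# Sketch of the train.py flow under the default config.yaml.
from config import config        # parses and validates config.yaml
from models import vae_models    # {"vae": VAE, "conv-vae": Conv_VAE}

model = vae_models[config.model_type](**config.model_config.dict())
# A pytorch_lightning.Trainer then fits `model`; see train.py below.
```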
app.py ADDED
@@ -0,0 +1,55 @@
+ import streamlit as st
+ from streamlit_drawable_canvas import st_canvas
+ import os
+ import utils
+ from PIL import Image
+
+
+ st.set_page_config("VAE MNIST Pytorch Lightning")
+ st.title("VAE Playground")
+ # title_img = Image.open("images/title_img.jpg")
+
+ # st.image(title_img)
+ st.markdown(
+     "This is a simple Streamlit app to showcase what a simple VAE can do."
+ )
+
+ def load_model_files():
+     files = os.listdir("./saved_models/")
+     # Docker creates some whiteout files which might show up here
+     files = [i for i in files if ".ckpt" in i]
+     clean_names = [utils.parse_model_file_name(name) for name in files]
+     return {k: v for k, v in zip(clean_names, files)}
+
+
+ file_name_map = load_model_files()
+ files = list(file_name_map.keys())
+
+ st.header("🖼️ Image Reconstruction", "recon")
+
+ with st.form("reconstruction"):
+     model_name = st.selectbox("Choose Model:", files,
+                               key="recon_model_select")
+     recon_model_name = file_name_map[model_name]
+     recon_canvas = st_canvas(
+         # Fixed fill color with some opacity
+         fill_color="rgba(255, 165, 0, 0.3)",
+         stroke_width=8,
+         stroke_color="#FFFFFF",
+         background_color="#000000",
+         update_streamlit=True,
+         height=150,
+         width=150,
+         drawing_mode="freedraw",
+         key="recon_canvas",
+     )
+     submit = st.form_submit_button("Perform Reconstruction")
+     if submit:
+         recon_model = utils.load_model(recon_model_name)
+         inp_tens = utils.canvas_to_tensor(recon_canvas)
+         _, _, out = recon_model(inp_tens)
+         out = (out+1)/2
+         out_img = utils.resize_img(utils.tensor_to_img(out), 150, 150)
+ if submit:
+     st.image(out_img)
+
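The reconstruction flow above can also be exercised without the Streamlit UI. A hedged sketch — the checkpoint name is an assumption, and a random tensor stands in for the drawn canvas:

```python
# Hypothetical offline version of the app's reconstruction path.
import torch
import utils

model = utils.load_model("vae_alpha_1024_dim_128.ckpt")  # assumed checkpoint name
x = torch.randn(1, 1, 28, 28)   # stand-in for utils.canvas_to_tensor(canvas)
with torch.no_grad():
    _, _, out = model(x)        # forward() returns (mu, log_var, output)
img = utils.tensor_to_img((out + 1) / 2)  # map Tanh output from [-1, 1] to [0, 1]
img = utils.resize_img(img, 150, 150)
```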
config.yaml ADDED
@@ -0,0 +1,20 @@
+ training_params:
+   max_epochs: 30
+   auto_lr_find: false
+   gpus: 1
+ model_params:
+   model_type: conv-vae # vae or conv-vae
+   lr: 0.005
+   batch_size: 1
+   hidden_size: 4096
+   latent_size: 128
+   alpha: 1024
+   dataset: "fashion-mnist"
+   save_images: true
+   save_path: "log_images/"
+   channels: 1
+   height: 64
+   width: 64
+ logger_params:
+   name: "conv-vae"
+   save_dir: "logs/"
config/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .config import config
+ __all__ = ["config"]
config/config.py ADDED
@@ -0,0 +1,59 @@
+ from pydantic import BaseModel
+ from typing import Optional, Union
+ import yaml
+
+
+ class TrainConfig(BaseModel):
+     max_epochs: int
+     auto_lr_find: Union[bool, int]
+     gpus: int
+
+
+ class VAEConfig(BaseModel):
+     model_type: str
+     hidden_size: int
+     latent_size: int
+     alpha: int
+     dataset: str
+     batch_size: Optional[int] = 64
+     save_images: Optional[bool] = False
+     lr: Optional[float] = None
+     save_path: Optional[str] = None
+
+
+ class ConvVAEConfig(VAEConfig):
+     channels: int
+     height: int
+     width: int
+
+
+ class LoggerConfig(BaseModel):
+     name: str
+     save_dir: str
+
+
+ class Config(BaseModel):
+     model_config: Union[VAEConfig, ConvVAEConfig]
+     train_config: TrainConfig
+     model_type: str
+     log_config: LoggerConfig
+
+
+ def load_config(path="config.yaml"):
+     config = yaml.load(open(path), yaml.SafeLoader)
+     model_type = config['model_params']['model_type']
+     if model_type == "vae":
+         model_config = VAEConfig(**config["model_params"])
+     elif model_type == "conv-vae":
+         model_config = ConvVAEConfig(**config["model_params"])
+     else:
+         raise NotImplementedError(f"Model {model_type} is not implemented")
+     train_config = TrainConfig(**config["training_params"])
+     log_config = LoggerConfig(**config["logger_params"])
+     config = Config(model_config=model_config, train_config=train_config,
+                     model_type=model_type, log_config=log_config)
+
+     return config
+
+
+ config = load_config()
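The module-level `config` object built above gives validated, attribute-style access to `config.yaml`. For instance, with the values from the `config.yaml` in this commit:

```python
# Sketch: inspecting the validated configuration.
from config import config

print(config.model_type)                # "conv-vae"
print(config.model_config.latent_size)  # 128
print(config.model_config.alpha)        # 1024
print(config.train_config.max_epochs)   # 30
print(config.log_config.save_dir)       # "logs/"
```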
inference.py ADDED
@@ -0,0 +1,24 @@
+ from models import vae_models
+ from config import config
+ from PIL import Image
+ from torchvision.transforms import Resize, ToPILImage, Compose
+
+ from utils import load_model, tensor_to_img, resize_img, export_to_onnx
+
+
+
+ def predict(model_ckpt="vae_alpha_1024_dim_128.ckpt"):
+     model_type = config.model_type
+     model = vae_models[model_type].load_from_checkpoint(f"./saved_models/{model_ckpt}")
+     model.eval()
+     test_iter = iter(model.test_dataloader())
+     d, _ = next(test_iter)
+     _, _, out = model(d)
+     out_img = tensor_to_img(out)
+     return out_img
+
+
+
+ if __name__ == "__main__":
+     predict()
+     # export_to_onnx("./saved_models/vae.ckpt")
models/__init__.py ADDED
@@ -0,0 +1,11 @@
+ from .vae import VAE, Flatten, Stack  # noqa: F401
+ from .conv_vae import Conv_VAE  # noqa: F401
+
+ __all__ = [
+     'VAE', 'Flatten', 'Stack',
+     'Conv_VAE',
+ ]
+ vae_models = {
+     "conv-vae": Conv_VAE,
+     "vae": VAE
+ }
models/conv_vae.py ADDED
@@ -0,0 +1,239 @@
+ from .vae import VAE, Flatten, Stack
+ import torch.nn as nn
+ import pytorch_lightning as pl
+ import torch
+ import os
+ import random
+ from typing import Optional
+ import torchvision.transforms as transforms
+ from torchvision.datasets import MNIST, FashionMNIST, CelebA
+ from torch.utils.data import DataLoader
+ from torchvision.utils import save_image
+ from torch.optim import Adam
+ from torch.optim.lr_scheduler import ReduceLROnPlateau
+
+ class PrintShape(nn.Module):
+     def __init__(self):
+         super(PrintShape, self).__init__()
+
+     def forward(self, x):
+         # Do your print / debug stuff here
+         # print(f"Shape: {x.shape}")
+         return x
+
+ class UnFlatten(nn.Module):
+     def forward(self, input, size=4096):
+         # print("Unflattening")
+         return input.view(input.size(0), size, 1, 1)
+
+
+ class Flatten(nn.Module):
+     def forward(self, input):
+         # print("Flattening")
+         return input.view(input.size(0), -1)
+
+ class Conv_VAE(pl.LightningModule):
+     def __init__(self, channels: int, height: int, width: int, lr: float,
+                  latent_size: int, hidden_size: int, alpha: int, batch_size: int,
+                  dataset: Optional[str] = None,
+                  save_images: Optional[bool] = None,
+                  save_path: Optional[str] = None, **kwargs):
+         super().__init__()
+         self.latent_size = latent_size
+         self.hidden_size = hidden_size
+         if save_images:
+             self.save_path = f'{save_path}/{kwargs["model_type"]}_images/'
+         self.save_hyperparameters()
+         self.save_images = save_images
+         self.lr = lr
+         self.batch_size = batch_size
+         self.alpha = alpha
+         self.dataset = dataset
+         assert not height % 4 and not width % 4, "Choose height and width to "\
+             "be divisible by 4"
+         self.channels = channels
+         self.height = height
+         self.width = width
+
+         self.data_transform = transforms.Compose([
+             transforms.Resize(64),
+             transforms.CenterCrop((64, 64)),
+             transforms.ToTensor()
+         ])
+
+
+         self.encoder = nn.Sequential(
+             PrintShape(),
+             nn.Conv2d(self.channels, 32, kernel_size=3, stride=2, padding=1),
+             nn.BatchNorm2d(32),
+             nn.LeakyReLU(),
+             PrintShape(),
+             nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
+             nn.BatchNorm2d(64),
+             nn.LeakyReLU(),
+             PrintShape(),
+             nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
+             nn.BatchNorm2d(128),
+             nn.LeakyReLU(),
+             PrintShape(),
+             nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
+             nn.BatchNorm2d(256),
+             nn.LeakyReLU(),
+             PrintShape(),
+             Flatten(),
+             PrintShape(),
+         )
+
+         self.fc1 = nn.Linear(self.hidden_size, self.latent_size)
+         # Separate head for the log-variance; previously fc1 was reused for
+         # both mu and log_var, which forced mu == log_var
+         self.fc_log_var = nn.Linear(self.hidden_size, self.latent_size)
+         self.fc2 = nn.Linear(self.latent_size, self.hidden_size)
+
+         self.decoder = nn.Sequential(
+             PrintShape(),
+             # nn.Linear(self.hidden_size, self.hidden_size),
+             # PrintShape(),
+             # nn.BatchNorm1d(self.hidden_size),
+             UnFlatten(),
+             PrintShape(),
+             nn.ConvTranspose2d(self.hidden_size, 256, kernel_size=6, stride=2, padding=1),
+             PrintShape(),
+             nn.LeakyReLU(),
+             nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
+             nn.BatchNorm2d(128),
+             PrintShape(),
+             nn.LeakyReLU(),
+             nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
+             nn.BatchNorm2d(64),
+             PrintShape(),
+             nn.LeakyReLU(),
+             nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
+             nn.BatchNorm2d(32),
+             PrintShape(),
+             nn.LeakyReLU(),
+             nn.ConvTranspose2d(32, self.channels, kernel_size=4, stride=2, padding=1),
+             nn.BatchNorm2d(self.channels),
+             PrintShape(),
+             nn.Sigmoid(),
+         )
+
+     def encode(self, x):
+         hidden = self.encoder(x)
+         mu, log_var = self.fc1(hidden), self.fc_log_var(hidden)
+         # print("Encoded")
+         return mu, log_var
+
+     def decode(self, z):
+         # print("Decoding")
+         # f = nn.Linear(self.latent_size, self.hidden_size)
+         z = self.fc2(z)
+         # print(f"L: {z.shape}")
+         x = self.decoder(z)
+         return x
+
+     def reparametrize(self, mu, log_var):
+         # Reparametrization trick to allow gradients to backpropagate from the
+         # stochastic part of the model
+         sigma = torch.exp(0.5*log_var)
+         z = torch.randn_like(sigma)
+         return mu + sigma*z
+
+     def training_step(self, batch, batch_idx):
+         x, _ = batch
+         mu, log_var, x_out = self.forward(x)
+         kl_loss = (-0.5*(1+log_var - mu**2 -
+                    torch.exp(log_var)).sum(dim=1)).mean(dim=0)
+         recon_loss_criterion = nn.MSELoss()
+         recon_loss = recon_loss_criterion(x, x_out)
+         # print(kl_loss.item(), recon_loss.item())
+         loss = recon_loss*self.alpha + kl_loss
+
+         self.log('train_loss', loss, on_step=False,
+                  on_epoch=True, prog_bar=True)
+         return loss
+
+     def validation_step(self, batch, batch_idx):
+         x, _ = batch
+         mu, log_var, x_out = self.forward(x)
+
+         kl_loss = (-0.5*(1+log_var - mu**2 -
+                    torch.exp(log_var)).sum(dim=1)).mean(dim=0)
+         recon_loss_criterion = nn.MSELoss()
+         recon_loss = recon_loss_criterion(x, x_out)
+         # print(kl_loss.item(), recon_loss.item())
+         loss = recon_loss*self.alpha + kl_loss
+         self.log('val_kl_loss', kl_loss, on_step=False, on_epoch=True)
+         self.log('val_recon_loss', recon_loss, on_step=False, on_epoch=True)
+         self.log('val_loss', loss, on_step=False, on_epoch=True)
+         # print(x.mean(), x_out.mean())
+         return x_out, loss
+
+     def validation_epoch_end(self, outputs):
+         if not self.save_images:
+             return
+         if not os.path.exists(self.save_path):
+             os.makedirs(self.save_path)
+         choice = random.choice(outputs)
+         output_sample = choice[0]
+         output_sample = output_sample.reshape(-1, 1, self.width, self.height)
+         # output_sample = self.scale_image(output_sample)
+         save_image(
+             output_sample,
+             f"{self.save_path}/epoch_{self.current_epoch+1}.png",
+             # value_range=(-1, 1)
+         )
+
+     def configure_optimizers(self):
+         optimizer = Adam(self.parameters(), lr=(self.lr or self.learning_rate))
+         lr_scheduler = ReduceLROnPlateau(optimizer,)
+         return {
+             "optimizer": optimizer, "lr_scheduler": lr_scheduler,
+             "monitor": "val_loss"
+         }
+
+     def forward(self, x):
+         mu, log_var = self.encode(x)
+         hidden = self.reparametrize(mu, log_var)
+         output = self.decode(hidden)
+         return mu, log_var, output
+
+     # Functions for dataloading
+     def train_dataloader(self):
+         if self.dataset == "mnist":
+             train_set = MNIST('data/', download=True,
+                               train=True, transform=self.data_transform)
+         elif self.dataset == "fashion-mnist":
+             train_set = FashionMNIST(
+                 'data/', download=True, train=True,
+                 transform=self.data_transform)
+         elif self.dataset == "celeba":
+             train_set = CelebA('data/', download=False, split="train", transform=self.data_transform)
+         return DataLoader(train_set, batch_size=self.batch_size, shuffle=True)
+
+     def val_dataloader(self):
+         if self.dataset == "mnist":
+             val_set = MNIST('data/', download=True, train=False,
+                             transform=self.data_transform)
+         elif self.dataset == "fashion-mnist":
+             val_set = FashionMNIST(
+                 'data/', download=True, train=False,
+                 transform=self.data_transform)
+         elif self.dataset == "celeba":
+             val_set = CelebA('data/', download=False, split="valid", transform=self.data_transform)
+         return DataLoader(val_set, batch_size=self.batch_size)
+
+     def test_dataloader(self):
+         if self.dataset == "mnist":
+             val_set = MNIST('data/', download=True, train=False,
+                             transform=self.data_transform)
+         elif self.dataset == "fashion-mnist":
+             val_set = FashionMNIST(
+                 'data/', download=True, train=False,
+                 transform=self.data_transform)
+         elif self.dataset == "celeba":
+             val_set = CelebA('data/', download=False, split="test", transform=self.data_transform)
+         return DataLoader(val_set, batch_size=self.batch_size)
+
+
+
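For reference, `reparametrize` and the loss in `training_step` above (shared verbatim with `VAE` below) implement the standard VAE objective:

```latex
% Reparametrization trick: sample z differentiably from N(mu, sigma^2)
z = \mu + \sigma \odot \epsilon, \qquad
\sigma = \exp\!\left(\tfrac{1}{2}\log\sigma^{2}\right), \qquad
\epsilon \sim \mathcal{N}(0, I)

% Loss: alpha-weighted reconstruction (MSE) plus KL divergence to N(0, I)
\mathcal{L}_{\mathrm{KL}} =
  -\tfrac{1}{2}\sum_{j}\left(1 + \log\sigma_{j}^{2} - \mu_{j}^{2} - \sigma_{j}^{2}\right)
  \;\text{(averaged over the batch)}, \qquad
\mathcal{L} = \alpha\,\mathcal{L}_{\mathrm{recon}} + \mathcal{L}_{\mathrm{KL}}
```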
models/vae.py ADDED
@@ -0,0 +1,215 @@
+ import torch
+ import torch.nn as nn
+ import pytorch_lightning as pl
+ import random
+ from torchvision.datasets import MNIST, FashionMNIST, CelebA
+ import torchvision.transforms as transforms
+ from torch.utils.data import DataLoader
+ from torchvision.utils import save_image
+ from torch.optim import Adam
+ from torch.optim.lr_scheduler import ReduceLROnPlateau
+
+ import os
+ from typing import Optional
+
+
+ class Flatten(nn.Module):
+     def forward(self, x):
+         return x.view(x.size(0), -1)
+
+
+ class Stack(nn.Module):
+     def __init__(self, channels, height, width):
+         super(Stack, self).__init__()
+         self.channels = channels
+         self.height = height
+         self.width = width
+
+     def forward(self, x):
+         return x.view(x.size(0), self.channels, self.height, self.width)
+
+
+ class VAE(pl.LightningModule):
+     def __init__(self, latent_size: int, hidden_size: int, alpha: int, lr: float,
+                  batch_size: int,
+                  dataset: Optional[str] = None,
+                  save_images: Optional[bool] = None,
+                  save_path: Optional[str] = None, **kwargs):
+         """Init function for the VAE
+
+         Args:
+
+             latent_size (int): Latent Hidden Size
+             hidden_size (int): Hidden layer size
+             alpha (int): Hyperparameter to control the importance of
+             reconstruction loss vs KL-Divergence Loss
+             lr (float): Learning Rate, will not be used if auto_lr_find is used.
+             batch_size (int): Batch size for the dataloaders
+             dataset (Optional[str]): Dataset to be used
+             save_images (Optional[bool]): Boolean to decide whether to save images
+             save_path (Optional[str]): Path to save images
+         """
+
+         super().__init__()
+         self.latent_size = latent_size
+         self.hidden_size = hidden_size
+         if save_images:
+             self.save_path = f'{save_path}/{kwargs["model_type"]}_images/'
+         self.save_hyperparameters()
+         self.save_images = save_images
+         self.lr = lr
+         self.batch_size = batch_size
+         self.encoder = nn.Sequential(
+             Flatten(),
+             nn.Linear(784, 392), nn.BatchNorm1d(392), nn.LeakyReLU(0.1),
+             nn.Linear(392, 196), nn.BatchNorm1d(196), nn.LeakyReLU(0.1),
+             nn.Linear(196, 128), nn.BatchNorm1d(128), nn.LeakyReLU(0.1),
+             nn.Linear(128, latent_size)
+         )
+         self.hidden2mu = nn.Linear(latent_size, latent_size)
+         self.hidden2log_var = nn.Linear(latent_size, latent_size)
+         self.alpha = alpha
+         self.decoder = nn.Sequential(
+             nn.Linear(latent_size, 128), nn.BatchNorm1d(128), nn.LeakyReLU(0.1),
+             nn.Linear(128, 196), nn.BatchNorm1d(196), nn.LeakyReLU(0.1),
+             nn.Linear(196, 392), nn.BatchNorm1d(392), nn.LeakyReLU(0.1),
+             nn.Linear(392, 784),
+             Stack(1, 28, 28),
+             nn.Tanh()
+         )
+         self.height = kwargs.get("height")
+         self.width = kwargs.get("width")
+         self.data_transform = transforms.Compose([
+             transforms.ToTensor(),
+             transforms.Lambda(lambda x: 2*x - 1.)])
+         self.dataset = dataset
+
+     def encode(self, x):
+         hidden = self.encoder(x)
+         mu = self.hidden2mu(hidden)
+         log_var = self.hidden2log_var(hidden)
+         return mu, log_var
+
+     def decode(self, x):
+         x = self.decoder(x)
+         return x
+
+     def reparametrize(self, mu, log_var):
+         # Reparametrization trick to allow gradients to backpropagate from the
+         # stochastic part of the model
+         sigma = torch.exp(0.5*log_var)
+         z = torch.randn_like(sigma)
+         return mu + sigma*z
+
+     def training_step(self, batch, batch_idx):
+         x, _ = batch
+         mu, log_var, x_out = self.forward(x)
+         kl_loss = (-0.5*(1+log_var - mu**2 -
+                    torch.exp(log_var)).sum(dim=1)).mean(dim=0)
+         recon_loss_criterion = nn.MSELoss()
+         recon_loss = recon_loss_criterion(x, x_out)
+         # print(kl_loss.item(), recon_loss.item())
+         loss = recon_loss*self.alpha + kl_loss
+
+         self.log('train_loss', loss, on_step=False,
+                  on_epoch=True, prog_bar=True)
+         return loss
+
+     def validation_step(self, batch, batch_idx):
+         x, _ = batch
+         mu, log_var, x_out = self.forward(x)
+
+         kl_loss = (-0.5*(1+log_var - mu**2 -
+                    torch.exp(log_var)).sum(dim=1)).mean(dim=0)
+         recon_loss_criterion = nn.MSELoss()
+         recon_loss = recon_loss_criterion(x, x_out)
+         # print(kl_loss.item(), recon_loss.item())
+         loss = recon_loss*self.alpha + kl_loss
+         self.log('val_kl_loss', kl_loss, on_step=False, on_epoch=True)
+         self.log('val_recon_loss', recon_loss, on_step=False, on_epoch=True)
+         self.log('val_loss', loss, on_step=False, on_epoch=True)
+         # print(x.mean(), x_out.mean())
+         return x_out, loss
+
+     def validation_epoch_end(self, outputs):
+         if not self.save_images:
+             return
+         if not os.path.exists(self.save_path):
+             os.makedirs(self.save_path)
+         choice = random.choice(outputs)
+         output_sample = choice[0]
+         output_sample = output_sample.reshape(-1, 1, self.width, self.height)
+         # output_sample = self.scale_image(output_sample)
+         save_image(
+             output_sample,
+             f"{self.save_path}/epoch_{self.current_epoch+1}.png",
+             # value_range=(-1, 1)
+         )
+
+     def configure_optimizers(self):
+         optimizer = Adam(self.parameters(), lr=(self.lr or self.learning_rate))
+         lr_scheduler = ReduceLROnPlateau(optimizer,)
+         return {
+             "optimizer": optimizer, "lr_scheduler": lr_scheduler,
+             "monitor": "val_loss"
+         }
+
+     def forward(self, x):
+         mu, log_var = self.encode(x)
+         hidden = self.reparametrize(mu, log_var)
+         output = self.decode(hidden)
+         return mu, log_var, output
+
+     # Functions for dataloading
+     def train_dataloader(self):
+         if self.dataset == "mnist":
+             train_set = MNIST('data/', download=True,
+                               train=True, transform=self.data_transform)
+         elif self.dataset == "fashion-mnist":
+             train_set = FashionMNIST(
+                 'data/', download=True, train=True,
+                 transform=self.data_transform)
+         elif self.dataset == "celeba":
+             train_set = CelebA('data/', download=False, split="train", transform=self.data_transform)
+         return DataLoader(train_set, batch_size=self.batch_size, shuffle=True)
+
+     def val_dataloader(self):
+         if self.dataset == "mnist":
+             val_set = MNIST('data/', download=True, train=False,
+                             transform=self.data_transform)
+         elif self.dataset == "fashion-mnist":
+             val_set = FashionMNIST(
+                 'data/', download=True, train=False,
+                 transform=self.data_transform)
+         elif self.dataset == "celeba":
+             val_set = CelebA('data/', download=False, split="valid", transform=self.data_transform)
+         return DataLoader(val_set, batch_size=self.batch_size)
+
+     def scale_image(self, img):
+         out = (img + 1) / 2
+         return out
+
+     def interpolate(self, x1, x2):
+
+         assert x1.shape == x2.shape, "Inputs must be of the same shape"
+         if x1.dim() == 3:
+             x1 = x1.unsqueeze(0)
+         if x2.dim() == 3:
+             x2 = x2.unsqueeze(0)
+         if self.training:
+             raise Exception(
+                 "This function should not be called when model is still "
+                 "in training mode. Use model.eval() before calling the "
+                 "function")
+         mu1, lv1 = self.encode(x1)
+         mu2, lv2 = self.encode(x2)
+         z1 = self.reparametrize(mu1, lv1)
+         z2 = self.reparametrize(mu2, lv2)
+         weights = torch.arange(0.1, 0.9, 0.1)
+         intermediate = [self.decode(z1)]
+         for wt in weights:
+             inter = (1.-wt)*z1 + wt*z2
+             intermediate.append(self.decode(inter))
+         intermediate.append(self.decode(z2))
+         out = torch.stack(intermediate, dim=0).squeeze(1)
+         return out, (mu1, lv1), (mu2, lv2)
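`interpolate` above decodes a straight line in latent space between two encoded inputs. A usage sketch — the checkpoint name and inputs are assumptions; `load_model` comes from `utils.py` below:

```python
# Hypothetical latent-space interpolation between two 28x28 inputs.
import torch
from utils import load_model

model = load_model("vae_alpha_1024_dim_128.ckpt")  # assumed checkpoint; load_model calls eval()
x1 = torch.randn(1, 1, 28, 28)  # stand-ins for two real MNIST digits
x2 = torch.randn(1, 1, 28, 28)
frames, _, _ = model.interpolate(x1, x2)
print(frames.shape)  # torch.Size([10, 1, 28, 28]): z1, eight blends, z2
```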
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ torch==1.8.0
+ python-box==5.3.0
+ tensorboardX==2.1
+ pydantic
+ streamlit==0.82.0
+ streamlit-drawable-canvas==0.8.0
+ pytorch_lightning>=1.1.1
+ torchvision==0.9.0
test.py ADDED
@@ -0,0 +1,22 @@
+ from pytorch_lightning import Trainer
+ from models import vae_models
+ from config import config
+ from pytorch_lightning.loggers import TensorBoardLogger
+ import os
+
+ def make_model(config):
+     model_type = config.model_type
+     model_config = config.model_config
+
+     if model_type not in vae_models.keys():
+         raise NotImplementedError("Model Architecture not implemented")
+     else:
+         return vae_models[model_type](**model_config.dict())
+
+
+ if __name__ == "__main__":
+     model_type = config.model_type
+     model = vae_models[model_type].load_from_checkpoint("./saved_models/vae_alpha_1024_dim_128.ckpt")
+     logger = TensorBoardLogger(**config.log_config.dict())
+     trainer = Trainer(gpus=1, logger=logger)
+     trainer.test(model)
train.py ADDED
@@ -0,0 +1,37 @@
+ from pytorch_lightning import Trainer
+ from models import vae_models
+ from config import config
+ from pytorch_lightning.callbacks import LearningRateMonitor
+ from pytorch_lightning.loggers import TensorBoardLogger
+ import os
+ os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
+
+
+ def make_model(config):
+     model_type = config.model_type
+     model_config = config.model_config
+
+     if model_type not in vae_models.keys():
+         raise NotImplementedError("Model Architecture not implemented")
+     else:
+         return vae_models[model_type](**model_config.dict())
+
+
+ if __name__ == "__main__":
+     model = make_model(config)
+     train_config = config.train_config
+     logger = TensorBoardLogger(**config.log_config.dict())
+     trainer = Trainer(**train_config.dict(), logger=logger,
+                       callbacks=[LearningRateMonitor()])
+     if train_config.auto_lr_find:
+         lr_finder = trainer.tuner.lr_find(model)
+         new_lr = lr_finder.suggestion()
+         print("Learning Rate Chosen:", new_lr)
+         model.lr = new_lr
+         trainer.fit(model)
+     else:
+         trainer.fit(model)
+     if not os.path.isdir("./saved_models"):
+         os.mkdir("./saved_models")
+     trainer.save_checkpoint(
+         f"saved_models/{config.model_type}_alpha_{config.model_config.alpha}_dim_{config.model_config.hidden_size}.ckpt")
utils.py ADDED
@@ -0,0 +1,63 @@
+ from pytorch_lightning import Trainer
+ from torchvision.utils import save_image
+ from models import vae_models
+ from config import config
+ from PIL import Image
+ from pytorch_lightning.loggers import TensorBoardLogger
+ import torch
+ from torch.nn.functional import interpolate
+ from torchvision.transforms import Resize, ToPILImage, Compose
+ from torchvision.utils import make_grid
+
+ def load_model(ckpt, model_type="vae"):
+     model = vae_models[model_type].load_from_checkpoint(f"./saved_models/{ckpt}")
+     model.eval()
+     return model
+
+ def parse_model_file_name(file_name):
+     # Hard-coded parsing based on the filenames that I use
+     substrings = file_name.split(".")[0].split("_")
+     name, alpha, dim = substrings[0], substrings[2], substrings[4]
+     new_name = ""
+     if name == "vae":
+         new_name += "Vanilla VAE"
+     new_name += f" | alpha={alpha}"
+     new_name += f" | dim={dim}"
+     return new_name
+
+ def tensor_to_img(tsr):
+     if tsr.ndim == 4:
+         tsr = tsr.squeeze(0)
+
+     transform = Compose([
+         ToPILImage()
+     ])
+     img = transform(tsr)
+     return img
+
+
+ def resize_img(img, w, h):
+     return img.resize((w, h))
+
+ def canvas_to_tensor(canvas):
+     """
+     Convert an RGBA canvas image to single-channel B/W and convert the
+     numpy array to a PyTorch tensor of shape [1, 1, 28, 28]
+     """
+     img = canvas.image_data
+     img = img[:, :, :-1]  # Ignore the alpha channel
+     img = img.mean(axis=2)
+     img = img/255
+     img = img*2 - 1.
+     img = torch.FloatTensor(img)
+     tens = img.unsqueeze(0).unsqueeze(0)
+     tens = interpolate(tens, (28, 28))
+     return tens
+
+
+ def export_to_onnx(ckpt):
+     model = load_model(ckpt)
+     filepath = "model.onnx"
+     test_iter = iter(model.test_dataloader())
+     sample, _ = next(test_iter)
+     model.to_onnx(filepath, sample, export_params=True)
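A quick demonstration of the filename convention `parse_model_file_name` assumes, which matches the pattern `train.py` saves with:

```python
# Names follow "<model>_alpha_<alpha>_dim_<hidden_size>.ckpt".
from utils import parse_model_file_name

print(parse_model_file_name("vae_alpha_1024_dim_128.ckpt"))
# -> "Vanilla VAE | alpha=1024 | dim=128"
```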