# Provenance (scraped from the hosting page's file view, kept as comments so
# the module stays importable):
#   Author: Jonas Becker — "1st try", commit 7f19394, 2.9 kB
"""
Module containing the decoders.
"""
import numpy as np
import torch
from torch import nn
# ALL decoders should be called Decoder<Model>
def get_decoder(model_type):
    """Return the decoder class registered for *model_type*.

    Parameters
    ----------
    model_type : str
        Case-insensitive model name, e.g. ``"burgess"`` resolves to
        ``DecoderBurgess``.

    Returns
    -------
    type
        The decoder class (not an instance).

    Raises
    ------
    ValueError
        If no class named ``Decoder<Model>`` exists in this module.
    """
    # Look the class up in the module namespace instead of eval()-ing the
    # name: same resolution behavior, but no arbitrary-code-execution path
    # and a clearer error for unknown model types.
    class_name = "Decoder{}".format(model_type.lower().capitalize())
    try:
        return globals()[class_name]
    except KeyError:
        raise ValueError("Unknown decoder model type: {!r}".format(model_type))
class DecoderBurgess(nn.Module):
    def __init__(self, img_size,
                 latent_dim=10):
        r"""Decoder of the model proposed in [1].

        Maps a latent vector to an image-shaped tensor via two hidden FC
        layers and a stack of transposed convolutions, ending in a sigmoid
        so outputs lie in (0, 1).

        Parameters
        ----------
        img_size : tuple of ints
            Size of images. E.g. (1, 32, 32) or (3, 64, 64). Only the
            32x32 and 64x64 spatial sizes are supported by this stack.
        latent_dim : int
            Dimensionality of latent output.

        Model Architecture (transposed for decoder)
        ------------
        - 4 convolutional layers (each with 32 channels), (4 x 4 kernel), (stride of 2)
        - 2 fully connected layers (each of 256 units)
        - Latent distribution:
            - 1 fully connected layer of 20 units (log variance and mean for 10 Gaussians)

        References:
            [1] Burgess, Christopher P., et al. "Understanding disentangling in
            $\beta$-VAE." arXiv preprint arXiv:1804.03599 (2018).
        """
        super().__init__()

        # Layer parameters
        hid_channels = 32
        kernel_size = 4
        hidden_dim = 256
        self.img_size = img_size
        # Shape required to start transpose convs: (C, H, W) of the first
        # feature map reconstructed from the last FC layer.
        self.reshape = (hid_channels, kernel_size, kernel_size)
        n_chan = self.img_size[0]

        # Fully connected layers
        self.lin1 = nn.Linear(latent_dim, hidden_dim)
        self.lin2 = nn.Linear(hidden_dim, hidden_dim)
        # np.prod (not the removed-in-NumPy-2.0 np.product); int() because
        # nn.Linear expects a plain Python int, not np.int64.
        self.lin3 = nn.Linear(hidden_dim, int(np.prod(self.reshape)))

        # Convolutional layers
        cnn_kwargs = dict(stride=2, padding=1)
        # If input image is 64x64 do fourth convolution
        if self.img_size[1] == self.img_size[2] == 64:
            self.convT_64 = nn.ConvTranspose2d(hid_channels, hid_channels, kernel_size, **cnn_kwargs)
        self.convT1 = nn.ConvTranspose2d(hid_channels, hid_channels, kernel_size, **cnn_kwargs)
        self.convT2 = nn.ConvTranspose2d(hid_channels, hid_channels, kernel_size, **cnn_kwargs)
        self.convT3 = nn.ConvTranspose2d(hid_channels, n_chan, kernel_size, **cnn_kwargs)

    def forward(self, z):
        """Decode latent codes *z* of shape (batch, latent_dim) into images
        of shape (batch, *img_size) with values in (0, 1)."""
        batch_size = z.size(0)

        # Fully connected layers with ReLu activations
        x = torch.relu(self.lin1(z))
        x = torch.relu(self.lin2(x))
        x = torch.relu(self.lin3(x))
        x = x.view(batch_size, *self.reshape)

        # Convolutional layers with ReLu activations
        if self.img_size[1] == self.img_size[2] == 64:
            # Extra upsampling step only needed to reach 64x64 output.
            x = torch.relu(self.convT_64(x))
        x = torch.relu(self.convT1(x))
        x = torch.relu(self.convT2(x))
        # Sigmoid activation for final conv layer
        x = torch.sigmoid(self.convT3(x))

        return x