import torch
from torch import nn


class VariationalAutoEncoder(nn.Module):
    # Input image -> hidden dim -> mean, std -> reparameterization trick -> decoder -> output image
    def __init__(self, input_dim, h_dim=200, z_dim=20):
        super().__init__()
        # encoder
        self.img_2hid = nn.Linear(input_dim, h_dim)
        self.hid_2mu = nn.Linear(h_dim, z_dim)
        self.hid_2sigma = nn.Linear(h_dim, z_dim)

        # decoder
        self.z_2hid = nn.Linear(z_dim, h_dim)
        self.hid_2img = nn.Linear(h_dim, input_dim)

        self.relu = nn.ReLU()

    def encode(self, x):
        # q_phi(z|x)
        h = self.relu(self.img_2hid(x))
        mu, sigma = self.hid_2mu(h), self.hid_2sigma(h)
        return mu, sigma

    def decode(self, z):
        # p_theta(x|z)
        h = self.relu(self.z_2hid(z))
        x = self.hid_2img(h)
        return torch.sigmoid(x)  # image values should be between zero and one

    def forward(self, x):
        mu, sigma = self.encode(x)
        # reparameterization trick: z = mu + sigma * epsilon keeps the sampling step
        # differentiable with respect to mu and sigma
        epsilon = torch.randn_like(sigma)  # same shape as sigma, sampled from N(0, 1)
        z_reparametrized = mu + sigma * epsilon
        x_reconstructed = self.decode(z_reparametrized)
        return x_reconstructed, mu, sigma


# The loss has 2 parts: 1) push q_phi(z|x) toward the standard normal prior (KL term);
# 2) x_reconstructed should match x (reconstruction term).
if __name__ == "__main__":
    x = torch.randn(4, 28 * 28)
    vae = VariationalAutoEncoder(input_dim=784)
    x_reconstructed, mu, sigma = vae(x)
    print(x_reconstructed.shape)
    print(mu.shape)
    print(sigma.shape)
    print(torch.mean(mu))
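
# A minimal sketch of the two-part loss described above; this is an assumption, not
# part of the original model. The reconstruction term uses binary cross-entropy,
# which fits the sigmoid decoder output but expects targets x in [0, 1] (e.g. MNIST
# pixels), unlike the torch.randn demo input above. The KL term uses the closed form
# for a diagonal Gaussian q_phi(z|x) against the N(0, I) prior:
#   KL = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2).
# Since hid_2sigma is a plain nn.Linear, sigma can be negative or near zero; sigma^2
# handles the sign and a small eps guards the log.
def vae_loss(x, x_reconstructed, mu, sigma, eps=1e-8):
    # reconstruction term: how well the decoder rebuilds x
    recon = nn.functional.binary_cross_entropy(x_reconstructed, x, reduction="sum")
    # KL term: pushes the learned posterior toward the standard normal prior
    kl = -0.5 * torch.sum(1 + torch.log(sigma.pow(2) + eps) - mu.pow(2) - sigma.pow(2))
    return recon + kl


# Hypothetical usage in a training step, assuming x holds pixel values in [0, 1]:
#   x_reconstructed, mu, sigma = vae(x)
#   loss = vae_loss(x, x_reconstructed, mu, sigma)
#   loss.backward()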