“Transcendental-Programmer”
FEAT: added server coordination and model aggregation logic
754afec
"""vae.py module."""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List
class VAE(nn.Module):
def __init__(self, input_dim: int, latent_dim: int, hidden_dims: List[int]):
super(VAE, self).__init__()
# Encoder
modules = []
in_features = input_dim
for h_dim in hidden_dims:
modules.append(nn.Linear(in_features, h_dim))
modules.append(nn.ReLU())
in_features = h_dim
self.encoder = nn.Sequential(*modules)
# Latent space
self.fc_mu = nn.Linear(hidden_dims[-1], latent_dim)
self.fc_var = nn.Linear(hidden_dims[-1], latent_dim)
# Decoder
modules = []
hidden_dims.reverse()
in_features = latent_dim
for h_dim in hidden_dims:
modules.append(nn.Linear(in_features, h_dim))
modules.append(nn.ReLU())
in_features = h_dim
modules.append(nn.Linear(hidden_dims[-1], input_dim))
self.decoder = nn.Sequential(*modules)
def encode(self, x):
h = self.encoder(x)
return self.fc_mu(h), self.fc_var(h)
def decode(self, z):
return self.decoder(z)
def reparameterize(self, mu, log_var):
std = torch.exp(0.5 * log_var)
eps = torch.randn_like(std)
return mu + eps * std
def forward(self, x):
mu, log_var = self.encode(x)
z = self.reparameterize(mu, log_var)
return self.decode(z), mu, log_var