# projectai / vae_gan.py
# Matthew Frazer — "Create vae_gan.py" (commit 8df7cd4, verified)
# File 1: models/vae_gan.py
import torch
import torch.nn as nn
import torch.nn.functional as F
class VAEEncoder(nn.Module):
    """Convolutional VAE encoder that maps images to a sampled latent code.

    Three stride-2 convolutions (kernel 4, padding 1) each map a spatial size
    ``H`` to ``H // 2``. An ``image_size x image_size`` input therefore becomes
    a ``256 x (image_size // 8) x (image_size // 8)`` feature map. That map is
    flattened and projected to the posterior mean and log-variance.

    Args:
        input_channels: Number of channels in the input images.
        latent_dim: Size of the latent vector.
        image_size: Height/width of the square input images. The default of
            112 reproduces the original hard-coded ``256 * 14 * 14`` feature
            size.

    Raises:
        ValueError: If ``image_size`` is smaller than 8, which would leave no
            spatial positions after the three downsampling convolutions.
    """

    def __init__(self, input_channels=3, latent_dim=512, *, image_size=112):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(input_channels, 64, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
        )
        # Each k=4, s=2, p=1 conv maps H -> H // 2, so three of them give H // 8.
        feat_hw = image_size // 8
        if feat_hw < 1:
            raise ValueError(f"image_size must be at least 8, got {image_size}")
        flat_dim = 256 * feat_hw * feat_hw
        self.fc_mu = nn.Linear(flat_dim, latent_dim)
        self.fc_logvar = nn.Linear(flat_dim, latent_dim)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``.

        Writing the sample this way keeps it differentiable with respect to
        ``mu`` and ``logvar``. Note that it samples in both train and eval
        mode, because there is no ``self.training`` check.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Encode a batch of images and draw one latent sample per image.

        Args:
            x: Image batch of shape
                ``(batch, input_channels, image_size, image_size)``.

        Returns:
            Tuple ``(z, mu, logvar)``. Each element has shape
            ``(batch, latent_dim)``.
        """
        h = self.conv(x)
        h = h.view(h.size(0), -1)
        mu, logvar = self.fc_mu(h), self.fc_logvar(h)
        z = self.reparameterize(mu, logvar)
        return z, mu, logvar
class GANDecoder(nn.Module):
    """Decoder/generator that maps a latent vector to an image.

    A fully connected layer produces a ``1024 x s x s`` seed with
    ``s = image_size // 16``. Four stride-2 transposed convolutions (kernel 4,
    padding 1) each double the spatial size. A final 3x3 convolution with
    ``Tanh`` then yields pixels in ``[-1, 1]``. Inputs are presumably
    normalized to the same range (confirm against the data pipeline).

    The original network had only three upsampling stages, so it produced
    56x56 images. Those do not match the 112x112 size implied by the encoder's
    and discriminator's default ``256 * 14 * 14`` feature maps. The fourth
    stage restores that agreement.

    Args:
        latent_dim: Size of the input latent vector.
        output_channels: Number of channels in the generated images.
        image_size: Height/width of the square output images. Must be a
            positive multiple of 16.

    Raises:
        ValueError: If ``image_size`` is not a positive multiple of 16.

    Note:
        ``BatchNorm1d`` in ``fc`` needs more than one sample per batch in
        training mode.
    """

    def __init__(self, latent_dim=512, *, output_channels=3, image_size=112):
        super().__init__()
        if image_size < 16 or image_size % 16:
            raise ValueError(
                f"image_size must be a positive multiple of 16, got {image_size}"
            )
        # Four x2 upsampling stages => the seed is image_size / 16 per side.
        self._base = image_size // 16
        seed_features = 1024 * self._base * self._base
        self.fc = nn.Sequential(
            nn.Linear(latent_dim, seed_features),
            nn.BatchNorm1d(seed_features),
            nn.LeakyReLU(0.2),
        )
        self.conv = nn.Sequential(
            nn.ConvTranspose2d(1024, 512, 4, 2, 1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(512, 256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(256, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            # Fourth upsampling stage, missing originally (output was 56x56).
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, output_channels, 3, 1, 1),
            nn.Tanh(),
        )

    def forward(self, z):
        """Decode latent vectors into images.

        Args:
            z: Latent batch of shape ``(batch, latent_dim)``.

        Returns:
            Tensor of shape ``(batch, output_channels, image_size, image_size)``
            with values in ``[-1, 1]``.
        """
        h = self.fc(z)
        h = h.view(-1, 1024, self._base, self._base)
        return self.conv(h)
class Discriminator(nn.Module):
    """Convolutional discriminator that scores images as real vs. generated.

    Three stride-2 convolutions (kernel 4, padding 1) reduce an
    ``image_size x image_size`` input to ``256 x (image_size // 8)^2``
    features. A single linear unit followed by a sigmoid then maps those
    features to a probability.

    Args:
        input_channels: Number of channels in the input images.
        image_size: Height/width of the square input images. The default of
            112 reproduces the original hard-coded ``256 * 14 * 14`` feature
            size.

    Raises:
        ValueError: If ``image_size`` is smaller than 8.
    """

    def __init__(self, input_channels=3, *, image_size=112):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(input_channels, 64, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
        )
        # Each k=4, s=2, p=1 conv maps H -> H // 2, so three of them give H // 8.
        feat_hw = image_size // 8
        if feat_hw < 1:
            raise ValueError(f"image_size must be at least 8, got {image_size}")
        self.fc = nn.Linear(256 * feat_hw * feat_hw, 1)

    def forward(self, x):
        """Return the probability that each image in ``x`` is real.

        Args:
            x: Image batch of shape
                ``(batch, input_channels, image_size, image_size)``.

        Returns:
            Tensor of shape ``(batch, 1)`` with values in ``[0, 1]``.
        """
        h = self.conv(x)
        h = h.view(h.size(0), -1)
        # NOTE(review): the output is a sigmoid probability, meant for
        # ``nn.BCELoss``. Returning raw logits with ``nn.BCEWithLogitsLoss``
        # would be more numerically stable, but callers would need updating.
        return torch.sigmoid(self.fc(h))