|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
class VAEEncoder(nn.Module):
    """Convolutional VAE encoder mapping an image batch to ``(z, mu, logvar)``.

    Three stride-2 convolutions (kernel 4, padding 1) halve the spatial size
    three times -- an 8x reduction -- while widening to 256 channels. The
    feature map is then adaptively average-pooled to the fixed 14x14 grid the
    linear heads were sized for (``256 * 14 * 14`` input features).

    For 112x112 images the conv output is already 14x14, so the pooling is an
    exact identity and results match the un-pooled network. Other resolutions
    (e.g. 56x56 images) previously crashed in ``fc_mu`` with a shape mismatch
    and are now accepted.

    Args:
        input_channels: Number of channels in the input images.
        latent_dim: Size of the latent vectors ``z``, ``mu`` and ``logvar``.
    """

    def __init__(self, input_channels=3, latent_dim=512):
        super().__init__()
        grid = 14  # spatial size the linear heads expect after pooling
        self.conv = nn.Sequential(
            nn.Conv2d(input_channels, 64, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
        )
        # Kept as its own attribute rather than appended to ``self.conv``.
        # It has no parameters, so existing state_dict keys/checkpoints are
        # unaffected.
        self.pool = nn.AdaptiveAvgPool2d((grid, grid))
        self.fc_mu = nn.Linear(256 * grid * grid, latent_dim)
        self.fc_logvar = nn.Linear(256 * grid * grid, latent_dim)

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``.

        Sampling happens regardless of train/eval mode.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Encode a batch of images.

        Args:
            x: Batched image tensor ``(batch, input_channels, H, W)``.

        Returns:
            Tuple ``(z, mu, logvar)``, each of shape ``(batch, latent_dim)``.
        """
        features = self.pool(self.conv(x))
        flat = features.flatten(1)
        mu = self.fc_mu(flat)
        logvar = self.fc_logvar(flat)
        z = self.reparameterize(mu, logvar)
        return z, mu, logvar
|
|
|
class GANDecoder(nn.Module):
    """Generator that decodes latent vectors into 3-channel images in [-1, 1].

    A linear layer projects the latent vector to a 1024x7x7 seed volume. Three
    transposed convolutions (kernel 4, stride 2, padding 1) each double the
    spatial size (7 -> 14 -> 28 -> 56). A final 3x3 convolution maps to 3
    channels, followed by ``tanh``.

    NOTE(review): the output is 56x56, while ``256 * 14 * 14``-feature linear
    heads elsewhere in this file correspond to 112x112 images. Confirm the
    intended image size.

    Args:
        latent_dim: Size of the input latent vectors.
    """

    def __init__(self, latent_dim=512):
        super().__init__()
        seed_numel = 1024 * 7 * 7
        self.fc = nn.Sequential(
            nn.Linear(latent_dim, seed_numel),
            # In training mode BatchNorm1d needs more than one sample per batch.
            nn.BatchNorm1d(seed_numel),
            nn.LeakyReLU(0.2),
        )

        # Same layer order and Sequential indices as a hand-written list, so
        # state_dict keys (conv.0, conv.1, ..., conv.9) stay stable.
        stages = []
        for in_ch, out_ch in ((1024, 512), (512, 256), (256, 128)):
            stages.extend(
                [
                    nn.ConvTranspose2d(in_ch, out_ch, 4, 2, 1),
                    nn.BatchNorm2d(out_ch),
                    nn.LeakyReLU(0.2),
                ]
            )
        stages.extend([nn.Conv2d(128, 3, 3, 1, 1), nn.Tanh()])
        self.conv = nn.Sequential(*stages)

    def forward(self, z):
        """Map latents ``(batch, latent_dim)`` to images ``(batch, 3, 56, 56)``."""
        seed = self.fc(z).view(-1, 1024, 7, 7)
        return self.conv(seed)
|
|
|
class Discriminator(nn.Module):
    """Convolutional discriminator scoring images as real vs. generated.

    Three stride-2 convolutions (kernel 4, padding 1) reduce the spatial size
    by 8 and widen to 256 channels. The feature map is adaptively
    average-pooled to the 14x14 grid the final linear layer was sized for
    (``256 * 14 * 14`` inputs).

    For 112x112 images the pool is an exact identity, so scores are unchanged.
    Other resolutions, such as 56x56 generator outputs, previously crashed with
    a shape mismatch and are now accepted.

    Args:
        input_channels: Number of image channels. Defaults to 3, matching the
            previously hard-coded value.
    """

    def __init__(self, input_channels=3):
        super().__init__()
        grid = 14  # spatial size the linear head expects after pooling
        self.conv = nn.Sequential(
            nn.Conv2d(input_channels, 64, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
        )
        # Parameterless, so it adds no state_dict entries.
        self.pool = nn.AdaptiveAvgPool2d((grid, grid))
        self.fc = nn.Linear(256 * grid * grid, 1)

    def forward(self, x):
        """Score a batch of images.

        Args:
            x: Batched image tensor ``(batch, input_channels, H, W)``.

        Returns:
            Tensor ``(batch, 1)`` of probabilities (sigmoid already applied).
            NOTE(review): pair with ``nn.BCELoss``. ``BCEWithLogitsLoss`` on
            raw logits would be more numerically stable, but that would
            require changing this return value.
        """
        features = self.pool(self.conv(x))
        return torch.sigmoid(self.fc(features.flatten(1)))
return torch.sigmoid(self.fc(x)) |