# PTI/models/e4e/discriminator.py
# Last commit: 2d7efb8 ("simp") by ucalyptus.
from torch import nn
class LatentCodesDiscriminator(nn.Module):
    """MLP that maps each input vector of size ``style_dim`` to one score.

    The network is ``n_mlp - 1`` repetitions of
    ``Linear(style_dim, style_dim)`` + ``LeakyReLU(0.2)``, followed by a
    final ``Linear(style_dim, 1)`` projection. Judging by the class name the
    inputs are presumably latent codes (e.g. StyleGAN W vectors) -- TODO
    confirm against callers.

    Args:
        style_dim: Size of the last dimension of the input tensor.
        n_mlp: Total number of ``Linear`` layers. Values below 1 behave like
            1, i.e. only the final projection is built.
    """

    def __init__(self, style_dim, n_mlp):
        super().__init__()

        self.style_dim = style_dim

        layers = []
        for _ in range(n_mlp - 1):
            layers.append(nn.Linear(style_dim, style_dim))
            layers.append(nn.LeakyReLU(0.2))
        # The final projection consumes whatever the previous stage emits:
        # ``style_dim`` features from the hidden layers, or the raw input
        # (also ``style_dim`` wide) when there are no hidden layers.
        layers.append(nn.Linear(style_dim, 1))
        self.mlp = nn.Sequential(*layers)

    def forward(self, w):
        """Score ``w`` with the MLP.

        Args:
            w: Tensor whose last dimension is ``style_dim``.

        Returns:
            Tensor with the same leading dimensions as ``w`` and a last
            dimension of size 1.
        """
        return self.mlp(w)