from torch import nn


class LatentCodesDiscriminator(nn.Module):
    """MLP that maps a ``style_dim``-sized vector to a single output value.

    The network is ``n_mlp - 1`` repetitions of
    ``Linear(style_dim, style_dim) -> LeakyReLU(0.2)`` followed by a final
    ``Linear(style_dim, 1)``.

    Args:
        style_dim: Size of the last dimension of the input.
        n_mlp: Total number of linear layers, counting the output layer.
            Values of 1 or less yield only the output layer.
    """

    def __init__(self, style_dim: int, n_mlp: int) -> None:
        super().__init__()

        self.style_dim = style_dim

        layers = []
        for _ in range(n_mlp - 1):
            layers.append(nn.Linear(style_dim, style_dim))
            layers.append(nn.LeakyReLU(0.2))
        # The output layer must take style_dim inputs (previously hard-coded
        # to 512, which broke every other style_dim at forward time).
        layers.append(nn.Linear(style_dim, 1))
        self.mlp = nn.Sequential(*layers)

    def forward(self, w):
        """Run ``w`` through the MLP.

        Args:
            w: Tensor whose last dimension is ``style_dim``.

        Returns:
            Tensor with the same leading dimensions as ``w`` and a last
            dimension of size 1.
        """
        return self.mlp(w)