diff --git a/model.py b/model.py index 0134c39..3a7826c 100755 --- a/model.py +++ b/model.py @@ -395,6 +395,7 @@ class Generator(nn.Module): style_dim, n_mlp, channel_multiplier=2, + additional_multiplier=2, blur_kernel=[1, 3, 3, 1], lr_mlp=0.01, ): @@ -426,6 +427,9 @@ class Generator(nn.Module): 512: 32 * channel_multiplier, 1024: 16 * channel_multiplier, } + if additional_multiplier > 1: + for k in list(self.channels.keys()): + self.channels[k] *= additional_multiplier self.input = ConstantInput(self.channels[4]) self.conv1 = StyledConv( @@ -518,7 +522,7 @@ class Generator(nn.Module): getattr(self.noises, f"noise_{i}") for i in range(self.num_layers) ] - if truncation < 1: + if truncation_latent is not None: style_t = [] for style in styles: