import torch
import torch.nn as nn
import torch.nn.functional as F
import scipy.signal

from .blurpool import BlurPool
from .official_stylegan3_model_helper import (
    SEL,
    SEL_unet_pro,
    MappingNetwork,
    FullyConnectedLayer,
    modulated_conv2d,
    SynthesisInput,
)
from third_party.stylegan3_official_ops import filtered_lrelu
from third_party.stylegan3_official_ops import upfirdn2d
from third_party.stylegan3_official_ops import bias_act


class UNetBlock(nn.Module):
    """Two (or three) style-modulated convolutions, each followed by lrelu.

    Every conv is modulated by a style vector produced from one w latent
    through its own affine (``FullyConnectedLayer``); spatial size is
    preserved (``padding = ks // 2``).

    Args:
        w_dim: dimensionality of the w latents fed to the affine layers.
        in_channel: input feature channels.
        latent_channel: channels after the first conv.
        out_channel: channels after the last conv.
        ks: square conv kernel size.
        layer_num: 2 or 3 convs; 3 inserts a middle conv at
            ``latent_channel`` width.
    """

    def __init__(self, w_dim, in_channel, latent_channel, out_channel, ks=3, layer_num=2):
        super().__init__()
        self.ks = ks
        self.layer_num = layer_num
        self.weight1 = nn.Parameter(torch.randn([latent_channel, in_channel, ks, ks]))
        self.weight2 = nn.Parameter(torch.randn([out_channel, latent_channel, ks, ks]))
        self.bias1 = nn.Parameter(torch.zeros([latent_channel]))
        self.bias2 = nn.Parameter(torch.zeros([out_channel]))
        self.affine1 = FullyConnectedLayer(w_dim, in_channel, bias_init=1)
        self.affine2 = FullyConnectedLayer(w_dim, latent_channel, bias_init=1)
        if self.layer_num == 3:
            self.weight_mid = nn.Parameter(torch.randn([latent_channel, latent_channel, ks, ks]))
            self.bias_mid = nn.Parameter(torch.zeros([latent_channel]))
            self.affine_mid = FullyConnectedLayer(w_dim, latent_channel, bias_init=1)

    def forward(self, x, *w):
        """Apply the modulated convs.

        Args:
            x: input feature map, presumably (N, in_channel, H, W).
            *w: exactly ``layer_num`` w latents, one per conv, in order.
        """
        s1 = self.affine1(w[0])
        if self.layer_num == 3:
            s_mid = self.affine_mid(w[1])
            s2 = self.affine2(w[2])
        else:
            s2 = self.affine2(w[1])
        x = modulated_conv2d(x, w=self.weight1, s=s1, padding=self.ks // 2)
        x = bias_act.bias_act(x, self.bias1.to(x.dtype), act='lrelu')
        if self.layer_num == 3:
            x = modulated_conv2d(x, w=self.weight_mid, s=s_mid, padding=self.ks // 2)
            x = bias_act.bias_act(x, self.bias_mid.to(x.dtype), act='lrelu')
        x = modulated_conv2d(x, w=self.weight2, s=s2, padding=self.ks // 2)
        x = bias_act.bias_act(x, self.bias2.to(x.dtype), act='lrelu')
        return x


class UNet(nn.Module):
    """Style-modulated U-Net synthesis network with SEL heatmap conditioning.

    Encoder path: ``block_num`` UNetBlocks, each preceded by a SEL layer
    that injects the label heatmap, each followed by 2x downsampling
    (average pool for ``sel_type == 'normal'``, anti-aliased ``upfirdn2d``
    otherwise).  Decoder path mirrors it with bilinear 2x upsampling and
    skip concatenation, and a final ``torgb`` block.

    Args:
        w_dim: latent dimensionality for all affine layers.
        in_dim: channels of the Fourier-feature input (``SynthesisInput``).
        base_dim: channel width of the first encoder block; deeper blocks
            double it per level.
        ks: conv kernel size for the UNet blocks.
        block_num: number of encoder/decoder levels.
        layer_num: convs per UNetBlock (2 or 3).
        filt_size: kept for interface compatibility (not used here).
        output_dim: channels of the synthesized output.
        label_nc: number of heatmap/label channels fed to the SEL layers.
        sel_type: 'normal' -> plain SEL + AvgPool; otherwise SEL_unet_pro
            with designed low-pass filters for anti-aliased resampling.
        img_resolution: spatial size of input/output (assumed square).
        wo_transform: forwarded to ``SynthesisInput``.
    """

    def __init__(self, w_dim, in_dim=3, base_dim=64, ks=3, block_num=3, layer_num=2,
                 filt_size=3, output_dim=3, label_nc=14, sel_type='normal',
                 img_resolution=256, wo_transform=False):
        super().__init__()
        self.block_num = block_num
        self.layer_num = layer_num
        self.sel_type = sel_type
        if self.sel_type == 'normal':
            self.pool = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
        else:
            # Per-level low-pass filters: one for the feature downsampling,
            # one handed to the SEL_unet_pro layers.  Cutoffs shrink with
            # depth following the per-level sampling rate.
            for i in range(block_num):
                self.register_buffer(
                    f'down_filter_{i}',
                    self.design_lowpass_filter(
                        numtaps=12, cutoff=2 ** ((block_num - i + 1) / 2),
                        width=None, fs=img_resolution // (2 ** i)))
                self.register_buffer(
                    f'sel_down_filter_{i}',
                    self.design_lowpass_filter(
                        numtaps=6 * 2 ** i, cutoff=2 ** ((block_num - i + 2) / 2),
                        width=None, fs=img_resolution // (2 ** (i - 1))))
        # NOTE(review): bandwidth=4 is inherited from the original code;
        # its derivation is not documented here.
        self.input = SynthesisInput(
            w_dim=w_dim, channels=in_dim, size=img_resolution,
            sampling_rate=img_resolution, bound_len=0, bandwidth=4,
            wo_transform=wo_transform)

        encoder_list, sel_enc_list, sel_dec_list, decoder_list = [], [], [], []
        for i in range(block_num):
            if i == 0:
                encoder_list.append(
                    UNetBlock(w_dim, in_dim, base_dim, base_dim, layer_num=layer_num))
            else:
                encoder_list.append(
                    UNetBlock(w_dim, base_dim * 2 ** (i - 1), base_dim * 2 ** i,
                              base_dim * 2 ** i, layer_num=layer_num))
            # Decoder input doubles (skip concat); output halves per level
            # except the last, which keeps base_dim for torgb.
            decoder_list.append(
                UNetBlock(w_dim,
                          base_dim * 2 ** (block_num - i),
                          base_dim * 2 ** (block_num - i - 1),
                          base_dim * 2 ** (block_num - i - 2) if i < block_num - 1
                          else base_dim * 2 ** (block_num - i - 1),
                          layer_num=layer_num))
            if self.sel_type == 'normal':
                sel_enc_list.append(SEL(in_dim if i == 0 else base_dim * 2 ** (i - 1), label_nc))
                sel_dec_list.append(SEL(base_dim * 2 ** (block_num - i - 1), label_nc))
            else:
                sel_enc_list.append(
                    SEL_unet_pro(in_dim if i == 0 else base_dim * 2 ** (i - 1), label_nc,
                                 down_filter=getattr(self, f'sel_down_filter_{i}')))
                sel_dec_list.append(
                    SEL_unet_pro(base_dim * 2 ** (block_num - i - 1), label_nc,
                                 down_filter=getattr(self, f'sel_down_filter_{block_num - i - 1}')))
        self.encoders = nn.ModuleList(encoder_list)
        self.decoders = nn.ModuleList(decoder_list)
        self.enc_sels = nn.ModuleList(sel_enc_list)
        self.dec_sels = nn.ModuleList(sel_dec_list)
        self.torgb = UNetBlock(w_dim, base_dim, base_dim, output_dim)

    @staticmethod
    def design_lowpass_filter(numtaps, cutoff, fs, width=None):
        """Design a 1-D FIR low-pass filter (StyleGAN3-style).

        Returns a float32 tensor of taps, or None when ``numtaps == 1``
        (identity filter).
        """
        if numtaps == 1:
            return None
        f = scipy.signal.firwin(numtaps=numtaps, cutoff=cutoff, width=width, fs=fs)
        return torch.as_tensor(f, dtype=torch.float32)

    def forward(self, ws, heatmap, **kwargs):
        """Synthesize an image from per-layer latents and a label heatmap.

        Args:
            ws: latents of shape (N, num_ws, w_dim);
                num_ws must be 2 * layer_num * block_num + 3
                (1 input + layer_num per encoder/decoder block + 2 torgb).
            heatmap: label map fed to every SEL layer.
        """
        ws = ws.unbind(1)
        x = self.input(ws[0])
        ws = ws[1:]
        enc_x = []
        for i in range(self.block_num):
            x = self.enc_sels[i](x, heatmap)  # modulate with SEL
            if self.layer_num == 2:
                x = self.encoders[i](x, ws[2 * i], ws[2 * i + 1])
            else:
                x = self.encoders[i](x, ws[3 * i], ws[3 * i + 1], ws[3 * i + 2])
            enc_x.append(x)
            if self.sel_type == 'normal':
                x = self.pool(x)
            else:
                # padding=5 with the 12-tap filter keeps the output exactly H/2.
                x = upfirdn2d.upfirdn2d(x=x, f=getattr(self, f'down_filter_{i}'),
                                        down=2, flip_filter=False, padding=5)
        ws = ws[self.layer_num * self.block_num:]
        for i in range(self.block_num):
            # Assumes square feature maps (size derived from the last dim only).
            x = F.interpolate(x, size=x.shape[-1] * 2, mode='bilinear', align_corners=False)
            x = self.dec_sels[i](x, heatmap)  # modulate with SEL
            skip = enc_x[-1 - i]
            if self.layer_num == 2:
                x = self.decoders[i](torch.cat([x, skip], 1), ws[2 * i], ws[2 * i + 1])
            else:
                x = self.decoders[i](torch.cat([x, skip], 1),
                                     ws[3 * i], ws[3 * i + 1], ws[3 * i + 2])
        ws = ws[self.layer_num * self.block_num:]
        x = self.torgb(x, ws[0], ws[1])
        return x


class Generator(nn.Module):
    """Mapping network + SEL U-Net synthesis network.

    ``num_ws = 2 * layer_num * block_num + 3``: one latent for the
    synthesis input, ``layer_num`` per encoder and decoder block, and two
    for the final torgb block.
    """

    def __init__(self, z_dim, c_dim, w_dim, img_resolution=256, img_channels=3,
                 in_dim=3, base_dim=64, ks=3, block_num=3, layer_num=2, filt_size=3,
                 output_dim=3, label_nc=14, sel_type='normal', wo_transform=False,
                 **kwargs):
        super().__init__()
        self.mapping = MappingNetwork(z_dim=z_dim, c_dim=c_dim, w_dim=w_dim,
                                      num_ws=2 * layer_num * block_num + 3)
        # Fix: base_dim / ks / filt_size were previously hard-coded to their
        # defaults here, silently ignoring the constructor arguments.
        # NOTE(review): output_dim is still superseded by img_channels, as in
        # the original code — changing that would alter caller behavior.
        self.synthesis = UNet(w_dim=w_dim, in_dim=in_dim, base_dim=base_dim,
                              ks=ks, block_num=block_num, layer_num=layer_num,
                              filt_size=filt_size, output_dim=img_channels,
                              label_nc=label_nc, sel_type=sel_type,
                              img_resolution=img_resolution,
                              wo_transform=wo_transform)

    def forward(self, z, c, heatmap, truncation_psi=1, truncation_cutoff=None,
                update_emas=False):
        """Map z (and optional conditioning c) to latents, then synthesize."""
        ws = self.mapping(z, c, truncation_psi=truncation_psi,
                          truncation_cutoff=truncation_cutoff,
                          update_emas=update_emas)
        return self.synthesis(ws, heatmap=heatmap)


if __name__ == '__main__':
    # Smoke tests: two resolutions with 4 levels and 3-conv blocks.
    g = Generator(z_dim=64, c_dim=0, w_dim=512, block_num=4, layer_num=3,
                  img_resolution=512, img_channels=32, sel_type='abn')
    hm = torch.ones([10, 14, 512, 512])
    z = torch.zeros([10, 64])
    c = None
    opt = g(z, c, hm)

    g = Generator(z_dim=64, c_dim=0, w_dim=512, block_num=4, layer_num=3,
                  img_resolution=256, img_channels=32, sel_type='abn')
    hm = torch.ones([10, 14, 256, 256])
    z = torch.zeros([10, 64])
    c = None
    opt = g(z, c, hm)