deep_privacy2 / dp2 /generator /stylegan_unet.py
haakohu's picture
fix
44539fc
import torch
import numpy as np
from dp2.layers import Sequential
from dp2.layers.sg2_layers import Conv2d, FullyConnectedLayer, ResidualBlock
from .base import BaseStyleGAN
from typing import List, Tuple
from .utils import spatial_embed_keypoints, mask_output
def get_chsize(imsize, cnum, max_imsize, max_cnum_mul):
n = int(np.log2(max_imsize) - np.log2(imsize))
mul = min(2**n, max_cnum_mul)
ch = cnum * mul
return int(ch)
class StyleGANUnet(BaseStyleGAN):
def __init__(
self,
scale_grad: bool,
im_channels: int,
min_fmap_resolution: int,
imsize: List[int],
cnum: int,
max_cnum_mul: int,
mask_output: bool,
conv_clamp: int,
input_cse: bool,
cse_nc: int,
n_middle_blocks: int,
input_keypoints: bool,
n_keypoints: int,
input_keypoint_indices: Tuple[int],
fix_errors: bool,
**kwargs
) -> None:
super().__init__(**kwargs)
self.n_keypoints = n_keypoints
self.input_keypoint_indices = list(input_keypoint_indices)
self.input_keypoints = input_keypoints
assert not (input_cse and input_keypoints)
cse_nc = 0 if cse_nc is None else cse_nc
self.imsize = imsize
self._cnum = cnum
self._max_cnum_mul = max_cnum_mul
self._min_fmap_resolution = min_fmap_resolution
self._image_channels = im_channels
self._max_imsize = max(imsize)
self.input_cse = input_cse
self.gain_unet = np.sqrt(1/3)
n_levels = int(np.log2(self._max_imsize) - np.log2(min_fmap_resolution))+1
encoder_layers = []
self.from_rgb = Conv2d(
im_channels + 1 + input_cse*(cse_nc+1) + input_keypoints*len(self.input_keypoint_indices),
cnum, 1
)
for i in range(n_levels): # Encoder layers
resolution = [x//2**i for x in imsize]
in_ch = get_chsize(max(resolution), cnum, self._max_imsize, max_cnum_mul)
second_ch = in_ch
out_ch = get_chsize(max(resolution)//2, cnum, self._max_imsize, max_cnum_mul)
down = 2
if i == 0: # first (lowest) block. Downsampling is performed at the start of the block
down = 1
if i == n_levels - 1:
out_ch = second_ch
block = ResidualBlock(in_ch, out_ch, down=down, conv_clamp=conv_clamp, fix_residual=fix_errors)
encoder_layers.append(block)
self._encoder_out_shape = [
get_chsize(min_fmap_resolution, cnum, self._max_imsize, max_cnum_mul),
*resolution]
self.encoder = torch.nn.ModuleList(encoder_layers)
# initialize decoder
decoder_layers = []
for i in range(n_levels):
resolution = [x//2**(n_levels-1-i) for x in imsize]
in_ch = get_chsize(max(resolution)//2, cnum, self._max_imsize, max_cnum_mul)
out_ch = get_chsize(max(resolution), cnum, self._max_imsize, max_cnum_mul)
if i == 0: # first (lowest) block
in_ch = get_chsize(max(resolution), cnum, self._max_imsize, max_cnum_mul)
up = 1
if i != n_levels - 1:
up = 2
block = ResidualBlock(
in_ch, out_ch, conv_clamp=conv_clamp, gain_out=np.sqrt(1/3),
w_dim=self.style_net.w_dim, norm=True, up=up,
fix_residual=fix_errors
)
decoder_layers.append(block)
if i != 0:
unet_block = Conv2d(
in_ch, in_ch, kernel_size=1, conv_clamp=conv_clamp, norm=True,
gain=np.sqrt(1/3) if fix_errors else np.sqrt(.5))
setattr(self, f"unet_block{i}", unet_block)
# Initialize "middle blocks" that do not have down/up sample
middle_blocks = []
for i in range(n_middle_blocks):
ch = get_chsize(min_fmap_resolution, cnum, self._max_imsize, max_cnum_mul)
block = ResidualBlock(
ch, ch, conv_clamp=conv_clamp, gain_out=np.sqrt(.5) if fix_errors else np.sqrt(1/3),
w_dim=self.style_net.w_dim, norm=True,
)
middle_blocks.append(block)
if n_middle_blocks != 0:
self.middle_blocks = Sequential(*middle_blocks)
self.decoder = torch.nn.ModuleList(decoder_layers)
self.to_rgb = Conv2d(cnum, im_channels, 1, activation="linear", conv_clamp=conv_clamp)
# Initialize "middle blocks" that do not have down/up sample
self.decoder = torch.nn.ModuleList(decoder_layers)
self.scale_grad = scale_grad
self.mask_output = mask_output
def forward_dec(self, x, w, unet_features, condition, mask, s, **kwargs):
for i, layer in enumerate(self.decoder):
if i != 0:
unet_layer = getattr(self, f"unet_block{i}")
x = x + unet_layer(unet_features[-i])
x = layer(x, w=w, s=s)
x = self.to_rgb(x)
if self.mask_output:
x = mask_output(True, condition, x, mask)
return dict(img=x)
def forward_enc(self, condition, mask, embedding, keypoints, E_mask, **kwargs):
if self.input_cse:
x = torch.cat((condition, mask, embedding, E_mask), dim=1)
else:
x = torch.cat((condition, mask), dim=1)
if self.input_keypoints:
keypoints = keypoints[:, self.input_keypoint_indices]
one_hot_pose = spatial_embed_keypoints(keypoints, x)
x = torch.cat((x, one_hot_pose), dim=1)
x = self.from_rgb(x)
unet_features = []
for i, layer in enumerate(self.encoder):
x = layer(x)
if i != len(self.encoder)-1:
unet_features.append(x)
if hasattr(self, "middle_blocks"):
for layer in self.middle_blocks:
x = layer(x)
return x, unet_features
def forward(
self, condition, mask,
z=None, embedding=None, w=None, update_emas=False, x=None,
s=None,
keypoints=None,
unet_features=None,
E_mask=None,
**kwargs):
# Used to skip sampling from encoder in inference. E.g. for w projection.
if x is not None and unet_features is not None:
assert not self.training
else:
x, unet_features = self.forward_enc(condition, mask, embedding, keypoints, E_mask, **kwargs)
if w is None:
if z is None:
z = self.get_z(condition)
w = self.get_w(z, update_emas=update_emas)
return self.forward_dec(x, w, unet_features, condition, mask, s, **kwargs)
class ComodStyleUNet(StyleGANUnet):
def __init__(self, min_comod_res=4, lr_multiplier_comod=1, **kwargs) -> None:
super().__init__(**kwargs)
min_fmap = min(self._encoder_out_shape[1:])
enc_out_ch = self._encoder_out_shape[0]
n_down = int(np.ceil(np.log2(min_fmap) - np.log2(min_comod_res)))
comod_layers = []
in_ch = enc_out_ch
for i in range(n_down):
comod_layers.append(Conv2d(enc_out_ch, 256, kernel_size=3, down=2, lr_multiplier=lr_multiplier_comod))
in_ch = 256
if n_down == 0:
comod_layers = [Conv2d(in_ch, 256, kernel_size=3)]
comod_layers.append(torch.nn.Flatten())
out_res = [x//2**n_down for x in self._encoder_out_shape[1:]]
in_ch_fc = np.prod(out_res) * 256
comod_layers.append(FullyConnectedLayer(in_ch_fc, 512, lr_multiplier=lr_multiplier_comod))
self.comod_block = Sequential(*comod_layers)
self.comod_fc = FullyConnectedLayer(
512+self.style_net.w_dim, self.style_net.w_dim, lr_multiplier=lr_multiplier_comod)
def forward_dec(self, x, w, unet_features, condition, mask, **kwargs):
y = self.comod_block(x)
y = torch.cat((y, w), dim=1)
y = self.comod_fc(y)
for i, layer in enumerate(self.decoder):
if i != 0:
unet_layer = getattr(self, f"unet_block{i}")
x = x + unet_layer(unet_features[-i], gain=np.sqrt(.5))
x = layer(x, w=y)
x = self.to_rgb(x)
if self.mask_output:
x = mask_output(True, condition, x, mask)
return dict(img=x)
def get_comod_y(self, batch, w):
x, unet_features = self.forward_enc(**batch)
y = self.comod_block(x)
y = torch.cat((y, w), dim=1)
y = self.comod_fc(y)
return y