Johannes Kolbe committed
Commit ff2b8e3
1 Parent(s): 3b72cdb

add original sefa files back in

SessionState.py ADDED
@@ -0,0 +1,129 @@
"""Adds pre-session state to StreamLit.

This file is borrowed from
https://gist.github.com/tvst/036da038ab3e999a64497f42de966a92
"""

# pylint: disable=protected-access

try:
    import streamlit.ReportThread as ReportThread
    from streamlit.server.Server import Server
except ModuleNotFoundError:
    # Streamlit >= 0.65.0
    import streamlit.report_thread as ReportThread
    from streamlit.server.server import Server


class SessionState(object):
    """Hack to add per-session state to Streamlit.

    Usage
    -----

    >>> import SessionState
    >>>
    >>> session_state = SessionState.get(user_name='', favorite_color='black')
    >>> session_state.user_name
    ''
    >>> session_state.user_name = 'Mary'
    >>> session_state.favorite_color
    'black'

    Since you set user_name above, next time your script runs this will be the
    result:
    >>> session_state = get(user_name='', favorite_color='black')
    >>> session_state.user_name
    'Mary'

    """

    def __init__(self, **kwargs):
        """A new SessionState object.

        Parameters
        ----------
        **kwargs : any
            Default values for the session state.

        Example
        -------
        >>> session_state = SessionState(user_name='', favorite_color='black')
        >>> session_state.user_name = 'Mary'
        ''
        >>> session_state.favorite_color
        'black'

        """
        for key, val in kwargs.items():
            setattr(self, key, val)


def get(**kwargs):
    """Gets a SessionState object for the current session.

    Creates a new object if necessary.

    Parameters
    ----------
    **kwargs : any
        Default values you want to add to the session state, if we're creating a
        new one.

    Example
    -------
    >>> session_state = get(user_name='', favorite_color='black')
    >>> session_state.user_name
    ''
    >>> session_state.user_name = 'Mary'
    >>> session_state.favorite_color
    'black'

    Since you set user_name above, next time your script runs this will be the
    result:
    >>> session_state = get(user_name='', favorite_color='black')
    >>> session_state.user_name
    'Mary'

    """
    # Hack to get the session object from Streamlit.

    ctx = ReportThread.get_report_ctx()

    this_session = None

    current_server = Server.get_current()
    if hasattr(current_server, '_session_infos'):
        # Streamlit < 0.56
        session_infos = Server.get_current()._session_infos.values()
    else:
        session_infos = Server.get_current()._session_info_by_id.values()

    for session_info in session_infos:
        s = session_info.session
        if (
                # Streamlit < 0.54.0
                (hasattr(s, '_main_dg') and s._main_dg == ctx.main_dg)
                or
                # Streamlit >= 0.54.0
                (not hasattr(s, '_main_dg') and s.enqueue == ctx.enqueue)
                or
                # Streamlit >= 0.65.2
                (not hasattr(s, '_main_dg') and
                 s._uploaded_file_mgr == ctx.uploaded_file_mgr)
        ):
            this_session = s

    if this_session is None:
        raise RuntimeError(
            "Oh noes. Couldn't get your Streamlit Session object. "
            'Are you doing something fancy with threads?')

    # Got the session object! Now let's attach some state into it.

    if not hasattr(this_session, '_custom_session_state'):
        this_session._custom_session_state = SessionState(**kwargs)

    return this_session._custom_session_state

# pylint: enable=protected-access
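
Note: SessionState.py reaches into pre-1.0 Streamlit internals (ReportThread and Server), which is why it needs the version-dependent import fallback above. As a hedged aside, not part of this commit, the same per-session defaults can be expressed with the built-in st.session_state API on newer Streamlit releases; a minimal sketch, assuming Streamlit >= 1.0:

# Sketch only (assumes Streamlit >= 1.0 with st.session_state); interface.py
# below still uses SessionState.get() from the file above.
import streamlit as st

# Seed per-session defaults once, mirroring SessionState.get(user_name='', ...).
if 'user_name' not in st.session_state:
    st.session_state.user_name = ''
if 'favorite_color' not in st.session_state:
    st.session_state.favorite_color = 'black'

# Values persist across reruns of the script within one session.
st.session_state.user_name = 'Mary'
st.write(st.session_state.user_name, st.session_state.favorite_color)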
interface.py ADDED
@@ -0,0 +1,128 @@
# python 3.7
"""Demo."""

import numpy as np
import torch
import streamlit as st
import SessionState

from models import parse_gan_type
from utils import to_tensor
from utils import postprocess
from utils import load_generator
from utils import factorize_weight


@st.cache(allow_output_mutation=True, show_spinner=False)
def get_model(model_name):
    """Gets model by name."""
    return load_generator(model_name)


@st.cache(allow_output_mutation=True, show_spinner=False)
def factorize_model(model, layer_idx):
    """Factorizes semantics from target layers of the given model."""
    return factorize_weight(model, layer_idx)


def sample(model, gan_type, num=1):
    """Samples latent codes."""
    codes = torch.randn(num, model.z_space_dim).cuda()
    if gan_type == 'pggan':
        codes = model.layer0.pixel_norm(codes)
    elif gan_type == 'stylegan':
        codes = model.mapping(codes)['w']
        codes = model.truncation(codes,
                                 trunc_psi=0.7,
                                 trunc_layers=8)
    elif gan_type == 'stylegan2':
        codes = model.mapping(codes)['w']
        codes = model.truncation(codes,
                                 trunc_psi=0.5,
                                 trunc_layers=18)
    codes = codes.detach().cpu().numpy()
    return codes


@st.cache(allow_output_mutation=True, show_spinner=False)
def synthesize(model, gan_type, code):
    """Synthesizes an image with the give code."""
    if gan_type == 'pggan':
        image = model(to_tensor(code))['image']
    elif gan_type in ['stylegan', 'stylegan2']:
        image = model.synthesis(to_tensor(code))['image']
    image = postprocess(image)[0]
    return image


def main():
    """Main function (loop for StreamLit)."""
    st.title('Closed-Form Factorization of Latent Semantics in GANs')
    st.sidebar.title('Options')
    reset = st.sidebar.button('Reset')

    model_name = st.sidebar.selectbox(
        'Model to Interpret',
        ['stylegan_animeface512', 'stylegan_car512', 'stylegan_cat256',
         'pggan_celebahq1024'])

    model = get_model(model_name)
    gan_type = parse_gan_type(model)
    layer_idx = st.sidebar.selectbox(
        'Layers to Interpret',
        ['all', '0-1', '2-5', '6-13'])
    layers, boundaries, eigen_values = factorize_model(model, layer_idx)

    num_semantics = st.sidebar.number_input(
        'Number of semantics', value=10, min_value=0, max_value=None, step=1)
    steps = {sem_idx: 0 for sem_idx in range(num_semantics)}
    if gan_type == 'pggan':
        max_step = 5.0
    elif gan_type == 'stylegan':
        max_step = 2.0
    elif gan_type == 'stylegan2':
        max_step = 15.0
    for sem_idx in steps:
        eigen_value = eigen_values[sem_idx]
        steps[sem_idx] = st.sidebar.slider(
            f'Semantic {sem_idx:03d} (eigen value: {eigen_value:.3f})',
            value=0.0,
            min_value=-max_step,
            max_value=max_step,
            step=0.04 * max_step if not reset else 0.0)

    image_placeholder = st.empty()
    button_placeholder = st.empty()

    try:
        base_codes = np.load(f'latent_codes/{model_name}_latents.npy')
    except FileNotFoundError:
        base_codes = sample(model, gan_type)

    state = SessionState.get(model_name=model_name,
                             code_idx=0,
                             codes=base_codes[0:1])
    if state.model_name != model_name:
        state.model_name = model_name
        state.code_idx = 0
        state.codes = base_codes[0:1]

    if button_placeholder.button('Random', key=0):
        state.code_idx += 1
        if state.code_idx < base_codes.shape[0]:
            state.codes = base_codes[state.code_idx][np.newaxis]
        else:
            state.codes = sample(model, gan_type)

    code = state.codes.copy()
    for sem_idx, step in steps.items():
        if gan_type == 'pggan':
            code += boundaries[sem_idx:sem_idx + 1] * step
        elif gan_type in ['stylegan', 'stylegan2']:
            code[:, layers, :] += boundaries[sem_idx:sem_idx + 1] * step
    image = synthesize(model, gan_type, code)
    image_placeholder.image(image / 255.0)


if __name__ == '__main__':
    main()
models/__init__.py ADDED
@@ -0,0 +1,114 @@
# python3.7
"""Collects all available models together."""

from .model_zoo import MODEL_ZOO
from .pggan_generator import PGGANGenerator
from .pggan_discriminator import PGGANDiscriminator
from .stylegan_generator import StyleGANGenerator
from .stylegan_discriminator import StyleGANDiscriminator
from .stylegan2_generator import StyleGAN2Generator
from .stylegan2_discriminator import StyleGAN2Discriminator

__all__ = [
    'MODEL_ZOO', 'PGGANGenerator', 'PGGANDiscriminator', 'StyleGANGenerator',
    'StyleGANDiscriminator', 'StyleGAN2Generator', 'StyleGAN2Discriminator',
    'build_generator', 'build_discriminator', 'build_model'
]

_GAN_TYPES_ALLOWED = ['pggan', 'stylegan', 'stylegan2']
_MODULES_ALLOWED = ['generator', 'discriminator']


def build_generator(gan_type, resolution, **kwargs):
    """Builds generator by GAN type.

    Args:
        gan_type: GAN type to which the generator belong.
        resolution: Synthesis resolution.
        **kwargs: Additional arguments to build the generator.

    Raises:
        ValueError: If the `gan_type` is not supported.
        NotImplementedError: If the `gan_type` is not implemented.
    """
    if gan_type not in _GAN_TYPES_ALLOWED:
        raise ValueError(f'Invalid GAN type: `{gan_type}`!\n'
                         f'Types allowed: {_GAN_TYPES_ALLOWED}.')

    if gan_type == 'pggan':
        return PGGANGenerator(resolution, **kwargs)
    if gan_type == 'stylegan':
        return StyleGANGenerator(resolution, **kwargs)
    if gan_type == 'stylegan2':
        return StyleGAN2Generator(resolution, **kwargs)
    raise NotImplementedError(f'Unsupported GAN type `{gan_type}`!')


def build_discriminator(gan_type, resolution, **kwargs):
    """Builds discriminator by GAN type.

    Args:
        gan_type: GAN type to which the discriminator belong.
        resolution: Synthesis resolution.
        **kwargs: Additional arguments to build the discriminator.

    Raises:
        ValueError: If the `gan_type` is not supported.
        NotImplementedError: If the `gan_type` is not implemented.
    """
    if gan_type not in _GAN_TYPES_ALLOWED:
        raise ValueError(f'Invalid GAN type: `{gan_type}`!\n'
                         f'Types allowed: {_GAN_TYPES_ALLOWED}.')

    if gan_type == 'pggan':
        return PGGANDiscriminator(resolution, **kwargs)
    if gan_type == 'stylegan':
        return StyleGANDiscriminator(resolution, **kwargs)
    if gan_type == 'stylegan2':
        return StyleGAN2Discriminator(resolution, **kwargs)
    raise NotImplementedError(f'Unsupported GAN type `{gan_type}`!')


def build_model(gan_type, module, resolution, **kwargs):
    """Builds a GAN module (generator/discriminator/etc).

    Args:
        gan_type: GAN type to which the model belong.
        module: GAN module to build, such as generator or discrimiantor.
        resolution: Synthesis resolution.
        **kwargs: Additional arguments to build the discriminator.

    Raises:
        ValueError: If the `module` is not supported.
        NotImplementedError: If the `module` is not implemented.
    """
    if module not in _MODULES_ALLOWED:
        raise ValueError(f'Invalid module: `{module}`!\n'
                         f'Modules allowed: {_MODULES_ALLOWED}.')

    if module == 'generator':
        return build_generator(gan_type, resolution, **kwargs)
    if module == 'discriminator':
        return build_discriminator(gan_type, resolution, **kwargs)
    raise NotImplementedError(f'Unsupported module `{module}`!')


def parse_gan_type(module):
    """Parses GAN type of a given module.

    Args:
        module: The module to parse GAN type from.

    Returns:
        A string, indicating the GAN type.

    Raises:
        ValueError: If the GAN type is unknown.
    """
    if isinstance(module, (PGGANGenerator, PGGANDiscriminator)):
        return 'pggan'
    if isinstance(module, (StyleGANGenerator, StyleGANDiscriminator)):
        return 'stylegan'
    if isinstance(module, (StyleGAN2Generator, StyleGAN2Discriminator)):
        return 'stylegan2'
    raise ValueError(f'Unable to parse GAN type from type `{type(module)}`!')
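
The factory functions above are the generic entry point to the model zoo (interface.py itself only calls parse_gan_type directly and loads the generator through utils.load_generator). A small usage sketch of the API defined in this file; the GAN type and resolution are illustrative values, not ones mandated by this commit:

# Sketch of the factory API from models/__init__.py; 'stylegan2' and 256 are
# example arguments, any type listed in _GAN_TYPES_ALLOWED together with an
# allowed resolution of the corresponding model file works the same way.
from models import build_model, parse_gan_type

generator = build_model(gan_type='stylegan2', module='generator', resolution=256)
discriminator = build_model(gan_type='stylegan2', module='discriminator', resolution=256)

assert parse_gan_type(generator) == 'stylegan2'
assert parse_gan_type(discriminator) == 'stylegan2'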
models/pggan_discriminator.py ADDED
@@ -0,0 +1,402 @@
# python3.7
"""Contains the implementation of discriminator described in PGGAN.

Paper: https://arxiv.org/pdf/1710.10196.pdf

Official TensorFlow implementation:
https://github.com/tkarras/progressive_growing_of_gans
"""

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = ['PGGANDiscriminator']

# Resolutions allowed.
_RESOLUTIONS_ALLOWED = [8, 16, 32, 64, 128, 256, 512, 1024]

# Initial resolution.
_INIT_RES = 4

# Default gain factor for weight scaling.
_WSCALE_GAIN = np.sqrt(2.0)


class PGGANDiscriminator(nn.Module):
    """Defines the discriminator network in PGGAN.

    NOTE: The discriminator takes images with `RGB` channel order and pixel
    range [-1, 1] as inputs.

    Settings for the network:

    (1) resolution: The resolution of the input image.
    (2) image_channels: Number of channels of the input image. (default: 3)
    (3) label_size: Size of the additional label for conditional generation.
        (default: 0)
    (4) fused_scale: Whether to fused `conv2d` and `downsample` together,
        resulting in `conv2d` with strides. (default: False)
    (5) use_wscale: Whether to use weight scaling. (default: True)
    (6) minibatch_std_group_size: Group size for the minibatch standard
        deviation layer. 0 means disable. (default: 16)
    (7) fmaps_base: Factor to control number of feature maps for each layer.
        (default: 16 << 10)
    (8) fmaps_max: Maximum number of feature maps in each layer. (default: 512)
    """

    def __init__(self,
                 resolution,
                 image_channels=3,
                 label_size=0,
                 fused_scale=False,
                 use_wscale=True,
                 minibatch_std_group_size=16,
                 fmaps_base=16 << 10,
                 fmaps_max=512):
        """Initializes with basic settings.

        Raises:
            ValueError: If the `resolution` is not supported.
        """
        super().__init__()

        if resolution not in _RESOLUTIONS_ALLOWED:
            raise ValueError(f'Invalid resolution: `{resolution}`!\n'
                             f'Resolutions allowed: {_RESOLUTIONS_ALLOWED}.')

        self.init_res = _INIT_RES
        self.init_res_log2 = int(np.log2(self.init_res))
        self.resolution = resolution
        self.final_res_log2 = int(np.log2(self.resolution))
        self.image_channels = image_channels
        self.label_size = label_size
        self.fused_scale = fused_scale
        self.use_wscale = use_wscale
        self.minibatch_std_group_size = minibatch_std_group_size
        self.fmaps_base = fmaps_base
        self.fmaps_max = fmaps_max

        # Level of detail (used for progressive training).
        self.register_buffer('lod', torch.zeros(()))
        self.pth_to_tf_var_mapping = {'lod': 'lod'}

        for res_log2 in range(self.final_res_log2, self.init_res_log2 - 1, -1):
            res = 2 ** res_log2
            block_idx = self.final_res_log2 - res_log2

            # Input convolution layer for each resolution.
            self.add_module(
                f'input{block_idx}',
                ConvBlock(in_channels=self.image_channels,
                          out_channels=self.get_nf(res),
                          kernel_size=1,
                          padding=0,
                          use_wscale=self.use_wscale))
            self.pth_to_tf_var_mapping[f'input{block_idx}.weight'] = (
                f'FromRGB_lod{block_idx}/weight')
            self.pth_to_tf_var_mapping[f'input{block_idx}.bias'] = (
                f'FromRGB_lod{block_idx}/bias')

            # Convolution block for each resolution (except the last one).
            if res != self.init_res:
                self.add_module(
                    f'layer{2 * block_idx}',
                    ConvBlock(in_channels=self.get_nf(res),
                              out_channels=self.get_nf(res),
                              use_wscale=self.use_wscale))
                tf_layer0_name = 'Conv0'
                self.add_module(
                    f'layer{2 * block_idx + 1}',
                    ConvBlock(in_channels=self.get_nf(res),
                              out_channels=self.get_nf(res // 2),
                              downsample=True,
                              fused_scale=self.fused_scale,
                              use_wscale=self.use_wscale))
                tf_layer1_name = 'Conv1_down' if self.fused_scale else 'Conv1'

            # Convolution block for last resolution.
            else:
                self.add_module(
                    f'layer{2 * block_idx}',
                    ConvBlock(
                        in_channels=self.get_nf(res),
                        out_channels=self.get_nf(res),
                        use_wscale=self.use_wscale,
                        minibatch_std_group_size=self.minibatch_std_group_size))
                tf_layer0_name = 'Conv'
                self.add_module(
                    f'layer{2 * block_idx + 1}',
                    DenseBlock(in_channels=self.get_nf(res) * res * res,
                               out_channels=self.get_nf(res // 2),
                               use_wscale=self.use_wscale))
                tf_layer1_name = 'Dense0'

            self.pth_to_tf_var_mapping[f'layer{2 * block_idx}.weight'] = (
                f'{res}x{res}/{tf_layer0_name}/weight')
            self.pth_to_tf_var_mapping[f'layer{2 * block_idx}.bias'] = (
                f'{res}x{res}/{tf_layer0_name}/bias')
            self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 1}.weight'] = (
                f'{res}x{res}/{tf_layer1_name}/weight')
            self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 1}.bias'] = (
                f'{res}x{res}/{tf_layer1_name}/bias')

        # Final dense block.
        self.add_module(
            f'layer{2 * block_idx + 2}',
            DenseBlock(in_channels=self.get_nf(res // 2),
                       out_channels=1 + self.label_size,
                       use_wscale=self.use_wscale,
                       wscale_gain=1.0,
                       activation_type='linear'))
        self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 2}.weight'] = (
            f'{res}x{res}/Dense1/weight')
        self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 2}.bias'] = (
            f'{res}x{res}/Dense1/bias')

        self.downsample = DownsamplingLayer()

    def get_nf(self, res):
        """Gets number of feature maps according to current resolution."""
        return min(self.fmaps_base // res, self.fmaps_max)

    def forward(self, image, lod=None, **_unused_kwargs):
        expected_shape = (self.image_channels, self.resolution, self.resolution)
        if image.ndim != 4 or image.shape[1:] != expected_shape:
            raise ValueError(f'The input tensor should be with shape '
                             f'[batch_size, channel, height, width], where '
                             f'`channel` equals to {self.image_channels}, '
                             f'`height`, `width` equal to {self.resolution}!\n'
                             f'But `{image.shape}` is received!')

        lod = self.lod.cpu().tolist() if lod is None else lod
        if lod + self.init_res_log2 > self.final_res_log2:
            raise ValueError(f'Maximum level-of-detail (lod) is '
                             f'{self.final_res_log2 - self.init_res_log2}, '
                             f'but `{lod}` is received!')

        lod = self.lod.cpu().tolist()
        for res_log2 in range(self.final_res_log2, self.init_res_log2 - 1, -1):
            block_idx = current_lod = self.final_res_log2 - res_log2
            if current_lod <= lod < current_lod + 1:
                x = self.__getattr__(f'input{block_idx}')(image)
            elif current_lod - 1 < lod < current_lod:
                alpha = lod - np.floor(lod)
                x = (self.__getattr__(f'input{block_idx}')(image) * alpha +
                     x * (1 - alpha))
            if lod < current_lod + 1:
                x = self.__getattr__(f'layer{2 * block_idx}')(x)
                x = self.__getattr__(f'layer{2 * block_idx + 1}')(x)
            if lod > current_lod:
                image = self.downsample(image)
        x = self.__getattr__(f'layer{2 * block_idx + 2}')(x)
        return x


class MiniBatchSTDLayer(nn.Module):
    """Implements the minibatch standard deviation layer."""

    def __init__(self, group_size=16, epsilon=1e-8):
        super().__init__()
        self.group_size = group_size
        self.epsilon = epsilon

    def forward(self, x):
        if self.group_size <= 1:
            return x
        group_size = min(self.group_size, x.shape[0])  # [NCHW]
        y = x.view(group_size, -1, x.shape[1], x.shape[2], x.shape[3])  # [GMCHW]
        y = y - torch.mean(y, dim=0, keepdim=True)  # [GMCHW]
        y = torch.mean(y ** 2, dim=0)  # [MCHW]
        y = torch.sqrt(y + self.epsilon)  # [MCHW]
        y = torch.mean(y, dim=[1, 2, 3], keepdim=True)  # [M111]
        y = y.repeat(group_size, 1, x.shape[2], x.shape[3])  # [N1HW]
        return torch.cat([x, y], dim=1)


class DownsamplingLayer(nn.Module):
    """Implements the downsampling layer.

    Basically, this layer can be used to downsample feature maps with average
    pooling.
    """

    def __init__(self, scale_factor=2):
        super().__init__()
        self.scale_factor = scale_factor

    def forward(self, x):
        if self.scale_factor <= 1:
            return x
        return F.avg_pool2d(x,
                            kernel_size=self.scale_factor,
                            stride=self.scale_factor,
                            padding=0)


class ConvBlock(nn.Module):
    """Implements the convolutional block.

    Basically, this block executes minibatch standard deviation layer (if
    needed), convolutional layer, activation layer, and downsampling layer (
    if needed) in sequence.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=1,
                 padding=1,
                 add_bias=True,
                 downsample=False,
                 fused_scale=False,
                 use_wscale=True,
                 wscale_gain=_WSCALE_GAIN,
                 activation_type='lrelu',
                 minibatch_std_group_size=0):
        """Initializes with block settings.

        Args:
            in_channels: Number of channels of the input tensor.
            out_channels: Number of channels of the output tensor.
            kernel_size: Size of the convolutional kernels. (default: 3)
            stride: Stride parameter for convolution operation. (default: 1)
            padding: Padding parameter for convolution operation. (default: 1)
            add_bias: Whether to add bias onto the convolutional result.
                (default: True)
            downsample: Whether to downsample the result after convolution.
                (default: False)
            fused_scale: Whether to fused `conv2d` and `downsample` together,
                resulting in `conv2d` with strides. (default: False)
            use_wscale: Whether to use weight scaling. (default: True)
            wscale_gain: Gain factor for weight scaling. (default: _WSCALE_GAIN)
            activation_type: Type of activation. Support `linear` and `lrelu`.
                (default: `lrelu`)
            minibatch_std_group_size: Group size for the minibatch standard
                deviation layer. 0 means disable. (default: 0)

        Raises:
            NotImplementedError: If the `activation_type` is not supported.
        """
        super().__init__()

        if minibatch_std_group_size > 1:
            in_channels = in_channels + 1
            self.mbstd = MiniBatchSTDLayer(group_size=minibatch_std_group_size)
        else:
            self.mbstd = nn.Identity()

        if downsample and not fused_scale:
            self.downsample = DownsamplingLayer()
        else:
            self.downsample = nn.Identity()

        if downsample and fused_scale:
            self.use_stride = True
            self.stride = 2
            self.padding = 1
        else:
            self.use_stride = False
            self.stride = stride
            self.padding = padding

        weight_shape = (out_channels, in_channels, kernel_size, kernel_size)
        fan_in = kernel_size * kernel_size * in_channels
        wscale = wscale_gain / np.sqrt(fan_in)
        if use_wscale:
            self.weight = nn.Parameter(torch.randn(*weight_shape))
            self.wscale = wscale
        else:
            self.weight = nn.Parameter(torch.randn(*weight_shape) * wscale)
            self.wscale = 1.0

        if add_bias:
            self.bias = nn.Parameter(torch.zeros(out_channels))
        else:
            self.bias = None

        if activation_type == 'linear':
            self.activate = nn.Identity()
        elif activation_type == 'lrelu':
            self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        else:
            raise NotImplementedError(f'Not implemented activation function: '
                                      f'`{activation_type}`!')

    def forward(self, x):
        x = self.mbstd(x)
        weight = self.weight * self.wscale
        if self.use_stride:
            weight = F.pad(weight, (1, 1, 1, 1, 0, 0, 0, 0), 'constant', 0.0)
            weight = (weight[:, :, 1:, 1:] + weight[:, :, :-1, 1:] +
                      weight[:, :, 1:, :-1] + weight[:, :, :-1, :-1]) * 0.25
        x = F.conv2d(x,
                     weight=weight,
                     bias=self.bias,
                     stride=self.stride,
                     padding=self.padding)
        x = self.activate(x)
        x = self.downsample(x)
        return x


class DenseBlock(nn.Module):
    """Implements the dense block.

    Basically, this block executes fully-connected layer, and activation layer.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 add_bias=True,
                 use_wscale=True,
                 wscale_gain=_WSCALE_GAIN,
                 activation_type='lrelu'):
        """Initializes with block settings.

        Args:
            in_channels: Number of channels of the input tensor.
            out_channels: Number of channels of the output tensor.
            add_bias: Whether to add bias onto the fully-connected result.
                (default: True)
            use_wscale: Whether to use weight scaling. (default: True)
            wscale_gain: Gain factor for weight scaling. (default: _WSCALE_GAIN)
            activation_type: Type of activation. Support `linear` and `lrelu`.
                (default: `lrelu`)

        Raises:
            NotImplementedError: If the `activation_type` is not supported.
        """
        super().__init__()
        weight_shape = (out_channels, in_channels)
        wscale = wscale_gain / np.sqrt(in_channels)
        if use_wscale:
            self.weight = nn.Parameter(torch.randn(*weight_shape))
            self.wscale = wscale
        else:
            self.weight = nn.Parameter(torch.randn(*weight_shape) * wscale)
            self.wscale = 1.0

        if add_bias:
            self.bias = nn.Parameter(torch.zeros(out_channels))
        else:
            self.bias = None

        if activation_type == 'linear':
            self.activate = nn.Identity()
        elif activation_type == 'lrelu':
            self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        else:
            raise NotImplementedError(f'Not implemented activation function: '
                                      f'`{activation_type}`!')

    def forward(self, x):
        if x.ndim != 2:
            x = x.view(x.shape[0], -1)
        x = F.linear(x, weight=self.weight * self.wscale, bias=self.bias)
        x = self.activate(x)
        return x
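
For orientation, a short sketch of running this discriminator at full level of detail; the resolution and batch size are illustrative, and the input range follows the class docstring above:

# Usage sketch for PGGANDiscriminator (example values only).
import torch
from models.pggan_discriminator import PGGANDiscriminator

D = PGGANDiscriminator(resolution=128)
images = torch.randn(2, 3, 128, 128)  # RGB images expected in [-1, 1]
scores = D(images)                    # `lod` defaults to the registered buffer (0)
print(scores.shape)                   # torch.Size([2, 1]) when label_size == 0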
models/pggan_generator.py ADDED
@@ -0,0 +1,338 @@
# python3.7
"""Contains the implementation of generator described in PGGAN.

Paper: https://arxiv.org/pdf/1710.10196.pdf

Official TensorFlow implementation:
https://github.com/tkarras/progressive_growing_of_gans
"""

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = ['PGGANGenerator']

# Resolutions allowed.
_RESOLUTIONS_ALLOWED = [8, 16, 32, 64, 128, 256, 512, 1024]

# Initial resolution.
_INIT_RES = 4

# Default gain factor for weight scaling.
_WSCALE_GAIN = np.sqrt(2.0)


class PGGANGenerator(nn.Module):
    """Defines the generator network in PGGAN.

    NOTE: The synthesized images are with `RGB` channel order and pixel range
    [-1, 1].

    Settings for the network:

    (1) resolution: The resolution of the output image.
    (2) z_space_dim: The dimension of the latent space, Z. (default: 512)
    (3) image_channels: Number of channels of the output image. (default: 3)
    (4) final_tanh: Whether to use `tanh` to control the final pixel range.
        (default: False)
    (5) label_size: Size of the additional label for conditional generation.
        (default: 0)
    (6) fused_scale: Whether to fused `upsample` and `conv2d` together,
        resulting in `conv2d_transpose`. (default: False)
    (7) use_wscale: Whether to use weight scaling. (default: True)
    (8) fmaps_base: Factor to control number of feature maps for each layer.
        (default: 16 << 10)
    (9) fmaps_max: Maximum number of feature maps in each layer. (default: 512)
    """

    def __init__(self,
                 resolution,
                 z_space_dim=512,
                 image_channels=3,
                 final_tanh=False,
                 label_size=0,
                 fused_scale=False,
                 use_wscale=True,
                 fmaps_base=16 << 10,
                 fmaps_max=512):
        """Initializes with basic settings.

        Raises:
            ValueError: If the `resolution` is not supported.
        """
        super().__init__()

        if resolution not in _RESOLUTIONS_ALLOWED:
            raise ValueError(f'Invalid resolution: `{resolution}`!\n'
                             f'Resolutions allowed: {_RESOLUTIONS_ALLOWED}.')

        self.init_res = _INIT_RES
        self.init_res_log2 = int(np.log2(self.init_res))
        self.resolution = resolution
        self.final_res_log2 = int(np.log2(self.resolution))
        self.z_space_dim = z_space_dim
        self.image_channels = image_channels
        self.final_tanh = final_tanh
        self.label_size = label_size
        self.fused_scale = fused_scale
        self.use_wscale = use_wscale
        self.fmaps_base = fmaps_base
        self.fmaps_max = fmaps_max

        # Number of convolutional layers.
        self.num_layers = (self.final_res_log2 - self.init_res_log2 + 1) * 2

        # Level of detail (used for progressive training).
        self.register_buffer('lod', torch.zeros(()))
        self.pth_to_tf_var_mapping = {'lod': 'lod'}

        for res_log2 in range(self.init_res_log2, self.final_res_log2 + 1):
            res = 2 ** res_log2
            block_idx = res_log2 - self.init_res_log2

            # First convolution layer for each resolution.
            if res == self.init_res:
                self.add_module(
                    f'layer{2 * block_idx}',
                    ConvBlock(in_channels=self.z_space_dim + self.label_size,
                              out_channels=self.get_nf(res),
                              kernel_size=self.init_res,
                              padding=self.init_res - 1,
                              use_wscale=self.use_wscale))
                tf_layer_name = 'Dense'
            else:
                self.add_module(
                    f'layer{2 * block_idx}',
                    ConvBlock(in_channels=self.get_nf(res // 2),
                              out_channels=self.get_nf(res),
                              upsample=True,
                              fused_scale=self.fused_scale,
                              use_wscale=self.use_wscale))
                tf_layer_name = 'Conv0_up' if self.fused_scale else 'Conv0'
            self.pth_to_tf_var_mapping[f'layer{2 * block_idx}.weight'] = (
                f'{res}x{res}/{tf_layer_name}/weight')
            self.pth_to_tf_var_mapping[f'layer{2 * block_idx}.bias'] = (
                f'{res}x{res}/{tf_layer_name}/bias')

            # Second convolution layer for each resolution.
            self.add_module(
                f'layer{2 * block_idx + 1}',
                ConvBlock(in_channels=self.get_nf(res),
                          out_channels=self.get_nf(res),
                          use_wscale=self.use_wscale))
            tf_layer_name = 'Conv' if res == self.init_res else 'Conv1'
            self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 1}.weight'] = (
                f'{res}x{res}/{tf_layer_name}/weight')
            self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 1}.bias'] = (
                f'{res}x{res}/{tf_layer_name}/bias')

            # Output convolution layer for each resolution.
            self.add_module(
                f'output{block_idx}',
                ConvBlock(in_channels=self.get_nf(res),
                          out_channels=self.image_channels,
                          kernel_size=1,
                          padding=0,
                          use_wscale=self.use_wscale,
                          wscale_gain=1.0,
                          activation_type='linear'))
            self.pth_to_tf_var_mapping[f'output{block_idx}.weight'] = (
                f'ToRGB_lod{self.final_res_log2 - res_log2}/weight')
            self.pth_to_tf_var_mapping[f'output{block_idx}.bias'] = (
                f'ToRGB_lod{self.final_res_log2 - res_log2}/bias')

        self.upsample = UpsamplingLayer()
        self.final_activate = nn.Tanh() if self.final_tanh else nn.Identity()

    def get_nf(self, res):
        """Gets number of feature maps according to current resolution."""
        return min(self.fmaps_base // res, self.fmaps_max)

    def forward(self, z, label=None, lod=None, **_unused_kwargs):
        if z.ndim != 2 or z.shape[1] != self.z_space_dim:
            raise ValueError(f'Input latent code should be with shape '
                             f'[batch_size, latent_dim], where '
                             f'`latent_dim` equals to {self.z_space_dim}!\n'
                             f'But `{z.shape}` is received!')
        z = self.layer0.pixel_norm(z)
        if self.label_size:
            if label is None:
                raise ValueError(f'Model requires an additional label '
                                 f'(with size {self.label_size}) as input, '
                                 f'but no label is received!')
            if label.ndim != 2 or label.shape != (z.shape[0], self.label_size):
                raise ValueError(f'Input label should be with shape '
                                 f'[batch_size, label_size], where '
                                 f'`batch_size` equals to that of '
                                 f'latent codes ({z.shape[0]}) and '
                                 f'`label_size` equals to {self.label_size}!\n'
                                 f'But `{label.shape}` is received!')
            z = torch.cat((z, label), dim=1)

        lod = self.lod.cpu().tolist() if lod is None else lod
        if lod + self.init_res_log2 > self.final_res_log2:
            raise ValueError(f'Maximum level-of-detail (lod) is '
                             f'{self.final_res_log2 - self.init_res_log2}, '
                             f'but `{lod}` is received!')

        x = z.view(z.shape[0], self.z_space_dim + self.label_size, 1, 1)
        for res_log2 in range(self.init_res_log2, self.final_res_log2 + 1):
            current_lod = self.final_res_log2 - res_log2
            if lod < current_lod + 1:
                block_idx = res_log2 - self.init_res_log2
                x = self.__getattr__(f'layer{2 * block_idx}')(x)
                x = self.__getattr__(f'layer{2 * block_idx + 1}')(x)
            if current_lod - 1 < lod <= current_lod:
                image = self.__getattr__(f'output{block_idx}')(x)
            elif current_lod < lod < current_lod + 1:
                alpha = np.ceil(lod) - lod
                image = (self.__getattr__(f'output{block_idx}')(x) * alpha +
                         self.upsample(image) * (1 - alpha))
            elif lod >= current_lod + 1:
                image = self.upsample(image)
        image = self.final_activate(image)

        results = {
            'z': z,
            'label': label,
            'image': image,
        }
        return results


class PixelNormLayer(nn.Module):
    """Implements pixel-wise feature vector normalization layer."""

    def __init__(self, epsilon=1e-8):
        super().__init__()
        self.eps = epsilon

    def forward(self, x):
        norm = torch.sqrt(torch.mean(x ** 2, dim=1, keepdim=True) + self.eps)
        return x / norm


class UpsamplingLayer(nn.Module):
    """Implements the upsampling layer.

    Basically, this layer can be used to upsample feature maps with nearest
    neighbor interpolation.
    """

    def __init__(self, scale_factor=2):
        super().__init__()
        self.scale_factor = scale_factor

    def forward(self, x):
        if self.scale_factor <= 1:
            return x
        return F.interpolate(x, scale_factor=self.scale_factor, mode='nearest')


class ConvBlock(nn.Module):
    """Implements the convolutional block.

    Basically, this block executes pixel-wise normalization layer, upsampling
    layer (if needed), convolutional layer, and activation layer in sequence.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=1,
                 padding=1,
                 add_bias=True,
                 upsample=False,
                 fused_scale=False,
                 use_wscale=True,
                 wscale_gain=_WSCALE_GAIN,
                 activation_type='lrelu'):
        """Initializes with block settings.

        Args:
            in_channels: Number of channels of the input tensor.
            out_channels: Number of channels of the output tensor.
            kernel_size: Size of the convolutional kernels. (default: 3)
            stride: Stride parameter for convolution operation. (default: 1)
            padding: Padding parameter for convolution operation. (default: 1)
            add_bias: Whether to add bias onto the convolutional result.
                (default: True)
            upsample: Whether to upsample the input tensor before convolution.
                (default: False)
            fused_scale: Whether to fused `upsample` and `conv2d` together,
                resulting in `conv2d_transpose`. (default: False)
            use_wscale: Whether to use weight scaling. (default: True)
            wscale_gain: Gain factor for weight scaling. (default: _WSCALE_GAIN)
            activation_type: Type of activation. Support `linear` and `lrelu`.
                (default: `lrelu`)

        Raises:
            NotImplementedError: If the `activation_type` is not supported.
        """
        super().__init__()

        self.pixel_norm = PixelNormLayer()

        if upsample and not fused_scale:
            self.upsample = UpsamplingLayer()
        else:
            self.upsample = nn.Identity()

        if upsample and fused_scale:
            self.use_conv2d_transpose = True
            weight_shape = (in_channels, out_channels, kernel_size, kernel_size)
            self.stride = 2
            self.padding = 1
        else:
            self.use_conv2d_transpose = False
            weight_shape = (out_channels, in_channels, kernel_size, kernel_size)
            self.stride = stride
            self.padding = padding

        fan_in = kernel_size * kernel_size * in_channels
        wscale = wscale_gain / np.sqrt(fan_in)
        if use_wscale:
            self.weight = nn.Parameter(torch.randn(*weight_shape))
            self.wscale = wscale
        else:
            self.weight = nn.Parameter(torch.randn(*weight_shape) * wscale)
            self.wscale = 1.0

        if add_bias:
            self.bias = nn.Parameter(torch.zeros(out_channels))
        else:
            self.bias = None

        if activation_type == 'linear':
            self.activate = nn.Identity()
        elif activation_type == 'lrelu':
            self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        else:
            raise NotImplementedError(f'Not implemented activation function: '
                                      f'`{activation_type}`!')

    def forward(self, x):
        x = self.pixel_norm(x)
        x = self.upsample(x)
        weight = self.weight * self.wscale
        if self.use_conv2d_transpose:
            weight = F.pad(weight, (1, 1, 1, 1, 0, 0, 0, 0), 'constant', 0.0)
            weight = (weight[:, :, 1:, 1:] + weight[:, :, :-1, 1:] +
                      weight[:, :, 1:, :-1] + weight[:, :, :-1, :-1])
            x = F.conv_transpose2d(x,
                                   weight=weight,
                                   bias=self.bias,
                                   stride=self.stride,
                                   padding=self.padding)
        else:
            x = F.conv2d(x,
                         weight=weight,
                         bias=self.bias,
                         stride=self.stride,
                         padding=self.padding)
        x = self.activate(x)
        return x
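
A matching sketch for the generator side; again the resolution and batch size are only examples, and the dict keys follow the forward() above:

# Usage sketch for PGGANGenerator (example values only).
import torch
from models.pggan_generator import PGGANGenerator

G = PGGANGenerator(resolution=128)
z = torch.randn(4, G.z_space_dim)  # latent codes of shape [batch, 512]
outputs = G(z)                     # dict with 'z', 'label' and 'image'
image = outputs['image']           # RGB images in [-1, 1], shape [4, 3, 128, 128]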
models/stylegan2_discriminator.py ADDED
@@ -0,0 +1,468 @@
1
+ # python3.7
2
+ """Contains the implementation of discriminator described in StyleGAN2.
3
+
4
+ Compared to that of StyleGAN, the discriminator in StyleGAN2 mainly adds skip
5
+ connections, increases model size and disables progressive growth. This script
6
+ ONLY supports config F in the original paper.
7
+
8
+ Paper: https://arxiv.org/pdf/1912.04958.pdf
9
+
10
+ Official TensorFlow implementation: https://github.com/NVlabs/stylegan2
11
+ """
12
+
13
+ import numpy as np
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+
19
+ __all__ = ['StyleGAN2Discriminator']
20
+
21
+ # Resolutions allowed.
22
+ _RESOLUTIONS_ALLOWED = [8, 16, 32, 64, 128, 256, 512, 1024]
23
+
24
+ # Initial resolution.
25
+ _INIT_RES = 4
26
+
27
+ # Architectures allowed.
28
+ _ARCHITECTURES_ALLOWED = ['resnet', 'skip', 'origin']
29
+
30
+ # Default gain factor for weight scaling.
31
+ _WSCALE_GAIN = 1.0
32
+
33
+
34
+ class StyleGAN2Discriminator(nn.Module):
35
+ """Defines the discriminator network in StyleGAN2.
36
+
37
+ NOTE: The discriminator takes images with `RGB` channel order and pixel
38
+ range [-1, 1] as inputs.
39
+
40
+ Settings for the network:
41
+
42
+ (1) resolution: The resolution of the input image.
43
+ (2) image_channels: Number of channels of the input image. (default: 3)
44
+ (3) label_size: Size of the additional label for conditional generation.
45
+ (default: 0)
46
+ (4) architecture: Type of architecture. Support `origin`, `skip`, and
47
+ `resnet`. (default: `resnet`)
48
+ (5) use_wscale: Whether to use weight scaling. (default: True)
49
+ (6) minibatch_std_group_size: Group size for the minibatch standard
50
+ deviation layer. 0 means disable. (default: 4)
51
+ (7) minibatch_std_channels: Number of new channels after the minibatch
52
+ standard deviation layer. (default: 1)
53
+ (8) fmaps_base: Factor to control number of feature maps for each layer.
54
+ (default: 32 << 10)
55
+ (9) fmaps_max: Maximum number of feature maps in each layer. (default: 512)
56
+ """
57
+
58
+ def __init__(self,
59
+ resolution,
60
+ image_channels=3,
61
+ label_size=0,
62
+ architecture='resnet',
63
+ use_wscale=True,
64
+ minibatch_std_group_size=4,
65
+ minibatch_std_channels=1,
66
+ fmaps_base=32 << 10,
67
+ fmaps_max=512):
68
+ """Initializes with basic settings.
69
+
70
+ Raises:
71
+ ValueError: If the `resolution` is not supported, or `architecture`
72
+ is not supported.
73
+ """
74
+ super().__init__()
75
+
76
+ if resolution not in _RESOLUTIONS_ALLOWED:
77
+ raise ValueError(f'Invalid resolution: `{resolution}`!\n'
78
+ f'Resolutions allowed: {_RESOLUTIONS_ALLOWED}.')
79
+ if architecture not in _ARCHITECTURES_ALLOWED:
80
+ raise ValueError(f'Invalid architecture: `{architecture}`!\n'
81
+ f'Architectures allowed: '
82
+ f'{_ARCHITECTURES_ALLOWED}.')
83
+
84
+ self.init_res = _INIT_RES
85
+ self.init_res_log2 = int(np.log2(self.init_res))
86
+ self.resolution = resolution
87
+ self.final_res_log2 = int(np.log2(self.resolution))
88
+ self.image_channels = image_channels
89
+ self.label_size = label_size
90
+ self.architecture = architecture
91
+ self.use_wscale = use_wscale
92
+ self.minibatch_std_group_size = minibatch_std_group_size
93
+ self.minibatch_std_channels = minibatch_std_channels
94
+ self.fmaps_base = fmaps_base
95
+ self.fmaps_max = fmaps_max
96
+
97
+ self.pth_to_tf_var_mapping = {}
98
+ for res_log2 in range(self.final_res_log2, self.init_res_log2 - 1, -1):
99
+ res = 2 ** res_log2
100
+ block_idx = self.final_res_log2 - res_log2
101
+
102
+ # Input convolution layer for each resolution (if needed).
103
+ if res_log2 == self.final_res_log2 or self.architecture == 'skip':
104
+ self.add_module(
105
+ f'input{block_idx}',
106
+ ConvBlock(in_channels=self.image_channels,
107
+ out_channels=self.get_nf(res),
108
+ kernel_size=1,
109
+ use_wscale=self.use_wscale))
110
+ self.pth_to_tf_var_mapping[f'input{block_idx}.weight'] = (
111
+ f'{res}x{res}/FromRGB/weight')
112
+ self.pth_to_tf_var_mapping[f'input{block_idx}.bias'] = (
113
+ f'{res}x{res}/FromRGB/bias')
114
+
115
+ # Convolution block for each resolution (except the last one).
116
+ if res != self.init_res:
117
+ self.add_module(
118
+ f'layer{2 * block_idx}',
119
+ ConvBlock(in_channels=self.get_nf(res),
120
+ out_channels=self.get_nf(res),
121
+ use_wscale=self.use_wscale))
122
+ tf_layer0_name = 'Conv0'
123
+ self.add_module(
124
+ f'layer{2 * block_idx + 1}',
125
+ ConvBlock(in_channels=self.get_nf(res),
126
+ out_channels=self.get_nf(res // 2),
127
+ scale_factor=2,
128
+ use_wscale=self.use_wscale))
129
+ tf_layer1_name = 'Conv1_down'
130
+
131
+ if self.architecture == 'resnet':
132
+ layer_name = f'skip_layer{block_idx}'
133
+ self.add_module(
134
+ layer_name,
135
+ ConvBlock(in_channels=self.get_nf(res),
136
+ out_channels=self.get_nf(res // 2),
137
+ kernel_size=1,
138
+ add_bias=False,
139
+ scale_factor=2,
140
+ use_wscale=self.use_wscale,
141
+ activation_type='linear'))
142
+ self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = (
143
+ f'{res}x{res}/Skip/weight')
144
+
145
+ # Convolution block for last resolution.
146
+ else:
147
+ self.add_module(
148
+ f'layer{2 * block_idx}',
149
+ ConvBlock(in_channels=self.get_nf(res),
150
+ out_channels=self.get_nf(res),
151
+ use_wscale=self.use_wscale,
152
+ minibatch_std_group_size=minibatch_std_group_size,
153
+ minibatch_std_channels=minibatch_std_channels))
154
+ tf_layer0_name = 'Conv'
155
+ self.add_module(
156
+ f'layer{2 * block_idx + 1}',
157
+ DenseBlock(in_channels=self.get_nf(res) * res * res,
158
+ out_channels=self.get_nf(res // 2),
159
+ use_wscale=self.use_wscale))
160
+ tf_layer1_name = 'Dense0'
161
+
162
+ self.pth_to_tf_var_mapping[f'layer{2 * block_idx}.weight'] = (
163
+ f'{res}x{res}/{tf_layer0_name}/weight')
164
+ self.pth_to_tf_var_mapping[f'layer{2 * block_idx}.bias'] = (
165
+ f'{res}x{res}/{tf_layer0_name}/bias')
166
+ self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 1}.weight'] = (
167
+ f'{res}x{res}/{tf_layer1_name}/weight')
168
+ self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 1}.bias'] = (
169
+ f'{res}x{res}/{tf_layer1_name}/bias')
170
+
171
+ # Final dense block.
172
+ self.add_module(
173
+ f'layer{2 * block_idx + 2}',
174
+ DenseBlock(in_channels=self.get_nf(res // 2),
175
+ out_channels=max(self.label_size, 1),
176
+ use_wscale=self.use_wscale,
177
+ activation_type='linear'))
178
+ self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 2}.weight'] = (
179
+ f'Output/weight')
180
+ self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 2}.bias'] = (
181
+ f'Output/bias')
182
+
183
+ if self.architecture == 'skip':
184
+ self.downsample = DownsamplingLayer()
185
+
186
+ def get_nf(self, res):
187
+ """Gets number of feature maps according to current resolution."""
188
+ return min(self.fmaps_base // res, self.fmaps_max)
189
+
190
+ def forward(self, image, label=None, **_unused_kwargs):
191
+ expected_shape = (self.image_channels, self.resolution, self.resolution)
192
+ if image.ndim != 4 or image.shape[1:] != expected_shape:
193
+ raise ValueError(f'The input tensor should be with shape '
194
+ f'[batch_size, channel, height, width], where '
195
+ f'`channel` equals to {self.image_channels}, '
196
+ f'`height`, `width` equal to {self.resolution}!\n'
197
+ f'But `{image.shape}` is received!')
198
+ if self.label_size:
199
+ if label is None:
200
+ raise ValueError(f'Model requires an additional label '
201
+ f'(with size {self.label_size}) as inputs, '
202
+ f'but no label is received!')
203
+ batch_size = image.shape[0]
204
+ if label.ndim != 2 or label.shape != (batch_size, self.label_size):
205
+ raise ValueError(f'Input label should be with shape '
206
+ f'[batch_size, label_size], where '
207
+ f'`batch_size` equals to that of '
208
+ f'images ({image.shape[0]}) and '
209
+ f'`label_size` equals to {self.label_size}!\n'
210
+ f'But `{label.shape}` is received!')
211
+
212
+ x = self.input0(image)
213
+ for res_log2 in range(self.final_res_log2, self.init_res_log2 - 1, -1):
214
+ block_idx = self.final_res_log2 - res_log2
215
+ if self.architecture == 'skip' and block_idx > 0:
216
+ image = self.downsample(image)
217
+ x = x + self.__getattr__(f'input{block_idx}')(image)
218
+ if self.architecture == 'resnet' and res_log2 != self.init_res_log2:
219
+ residual = self.__getattr__(f'skip_layer{block_idx}')(x)
220
+ x = self.__getattr__(f'layer{2 * block_idx}')(x)
221
+ x = self.__getattr__(f'layer{2 * block_idx + 1}')(x)
222
+ if self.architecture == 'resnet' and res_log2 != self.init_res_log2:
223
+ x = (x + residual) / np.sqrt(2.0)
224
+ x = self.__getattr__(f'layer{2 * block_idx + 2}')(x)
225
+
226
+ if self.label_size:
227
+ x = torch.sum(x * label, dim=1, keepdim=True)
228
+ return x
229
+
230
+
231
+ class MiniBatchSTDLayer(nn.Module):
232
+ """Implements the minibatch standard deviation layer."""
233
+
234
+ def __init__(self, group_size=4, new_channels=1, epsilon=1e-8):
235
+ super().__init__()
236
+ self.group_size = group_size
237
+ self.new_channels = new_channels
238
+ self.epsilon = epsilon
239
+
240
+ def forward(self, x):
241
+ if self.group_size <= 1:
242
+ return x
243
+ ng = min(self.group_size, x.shape[0])
244
+ nc = self.new_channels
245
+ temp_c = x.shape[1] // nc # [NCHW]
246
+ y = x.view(ng, -1, nc, temp_c, x.shape[2], x.shape[3]) # [GMncHW]
247
+ y = y - torch.mean(y, dim=0, keepdim=True) # [GMncHW]
248
+ y = torch.mean(y ** 2, dim=0) # [MncHW]
249
+ y = torch.sqrt(y + self.epsilon) # [MncHW]
250
+ y = torch.mean(y, dim=[2, 3, 4], keepdim=True) # [Mn111]
251
+ y = torch.mean(y, dim=2) # [Mn11]
252
+ y = y.repeat(ng, 1, x.shape[2], x.shape[3]) # [NnHW]
253
+ return torch.cat([x, y], dim=1)
254
+
255
+
256
+ class DownsamplingLayer(nn.Module):
257
+ """Implements the downsampling layer.
258
+
259
+ This layer can also be used as filtering by setting `scale_factor` as 1.
260
+ """
261
+
262
+ def __init__(self, scale_factor=2, kernel=(1, 3, 3, 1), extra_padding=0):
263
+ super().__init__()
264
+ assert scale_factor >= 1
265
+ self.scale_factor = scale_factor
266
+
267
+ if extra_padding != 0:
268
+ assert scale_factor == 1
269
+
270
+ if kernel is None:
271
+ kernel = np.ones((scale_factor), dtype=np.float32)
272
+ else:
273
+ kernel = np.array(kernel, dtype=np.float32)
274
+ assert kernel.ndim == 1
275
+ kernel = np.outer(kernel, kernel)
276
+ kernel = kernel / np.sum(kernel)
277
+ assert kernel.ndim == 2
278
+ assert kernel.shape[0] == kernel.shape[1]
279
+ kernel = kernel[np.newaxis, np.newaxis]
280
+ self.register_buffer('kernel', torch.from_numpy(kernel))
281
+ self.kernel = self.kernel.flip(0, 1)
282
+ padding = kernel.shape[2] - scale_factor + extra_padding
283
+ self.padding = ((padding + 1) // 2, padding // 2,
284
+ (padding + 1) // 2, padding // 2)
285
+
286
+ def forward(self, x):
287
+ assert x.ndim == 4
288
+ channels = x.shape[1]
289
+ x = x.view(-1, 1, x.shape[2], x.shape[3])
290
+ x = F.pad(x, self.padding, mode='constant', value=0)
291
+ x = F.conv2d(x, self.kernel, stride=self.scale_factor)
292
+ x = x.view(-1, channels, x.shape[2], x.shape[3])
293
+ return x
294
+
295
+
296
+ class ConvBlock(nn.Module):
297
+ """Implements the convolutional block.
298
+
299
+ Basically, this block executes minibatch standard deviation layer (if
300
+ needed), filtering layer (if needed), convolutional layer, and activation
301
+ layer in sequence.
302
+ """
303
+
304
+ def __init__(self,
305
+ in_channels,
306
+ out_channels,
307
+ kernel_size=3,
308
+ add_bias=True,
309
+ scale_factor=1,
310
+ filtering_kernel=(1, 3, 3, 1),
311
+ use_wscale=True,
312
+ wscale_gain=_WSCALE_GAIN,
313
+ lr_mul=1.0,
314
+ activation_type='lrelu',
315
+ minibatch_std_group_size=0,
316
+ minibatch_std_channels=1):
317
+ """Initializes with block settings.
318
+
319
+ Args:
320
+ in_channels: Number of channels of the input tensor.
321
+ out_channels: Number of channels of the output tensor.
322
+ kernel_size: Size of the convolutional kernels. (default: 3)
323
+ add_bias: Whether to add bias onto the convolutional result.
324
+ (default: True)
325
+ scale_factor: Scale factor for downsampling. `1` means skip
326
+ downsampling. (default: 1)
327
+ filtering_kernel: Kernel used for filtering before downsampling.
328
+ (default: (1, 3, 3, 1))
329
+ use_wscale: Whether to use weight scaling. (default: True)
330
+ wscale_gain: Gain factor for weight scaling. (default: _WSCALE_GAIN)
331
+ lr_mul: Learning multiplier for both weight and bias. (default: 1.0)
332
+ activation_type: Type of activation. Support `linear` and `lrelu`.
333
+ (default: `lrelu`)
334
+ minibatch_std_group_size: Group size for the minibatch standard
335
+ deviation layer. 0 means disable. (default: 0)
336
+ minibatch_std_channels: Number of new channels after the minibatch
337
+ standard deviation layer. (default: 1)
338
+
339
+ Raises:
340
+ NotImplementedError: If the `activation_type` is not supported.
341
+ """
342
+ super().__init__()
343
+
344
+ if minibatch_std_group_size > 1:
345
+ in_channels = in_channels + minibatch_std_channels
346
+ self.mbstd = MiniBatchSTDLayer(group_size=minibatch_std_group_size,
347
+ new_channels=minibatch_std_channels)
348
+ else:
349
+ self.mbstd = nn.Identity()
350
+
351
+ if scale_factor > 1:
352
+ extra_padding = kernel_size - scale_factor
353
+ self.filter = DownsamplingLayer(scale_factor=1,
354
+ kernel=filtering_kernel,
355
+ extra_padding=extra_padding)
356
+ self.stride = scale_factor
357
+ self.padding = 0 # Padding is done in `DownsamplingLayer`.
358
+ else:
359
+ self.filter = nn.Identity()
360
+ assert kernel_size % 2 == 1
361
+ self.stride = 1
362
+ self.padding = kernel_size // 2
363
+
364
+ weight_shape = (out_channels, in_channels, kernel_size, kernel_size)
365
+ fan_in = kernel_size * kernel_size * in_channels
366
+ wscale = wscale_gain / np.sqrt(fan_in)
367
+ if use_wscale:
368
+ self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul)
369
+ self.wscale = wscale * lr_mul
370
+ else:
371
+ self.weight = nn.Parameter(
372
+ torch.randn(*weight_shape) * wscale / lr_mul)
373
+ self.wscale = lr_mul
374
+
375
+ if add_bias:
376
+ self.bias = nn.Parameter(torch.zeros(out_channels))
377
+ else:
378
+ self.bias = None
379
+ self.bscale = lr_mul
380
+
381
+ if activation_type == 'linear':
382
+ self.activate = nn.Identity()
383
+ self.activate_scale = 1.0
384
+ elif activation_type == 'lrelu':
385
+ self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
386
+ self.activate_scale = np.sqrt(2.0)
387
+ else:
388
+ raise NotImplementedError(f'Not implemented activation function: '
389
+ f'`{activation_type}`!')
390
+
391
+ def forward(self, x):
392
+ x = self.mbstd(x)
393
+ x = self.filter(x)
394
+ weight = self.weight * self.wscale
395
+ bias = self.bias * self.bscale if self.bias is not None else None
396
+ x = F.conv2d(x,
397
+ weight=weight,
398
+ bias=bias,
399
+ stride=self.stride,
400
+ padding=self.padding)
401
+ x = self.activate(x) * self.activate_scale
402
+ return x
403
+
404
+
405
+ class DenseBlock(nn.Module):
406
+ """Implements the dense block.
407
+
408
+ Basically, this block executes fully-connected layer and activation layer.
409
+ """
410
+
411
+ def __init__(self,
412
+ in_channels,
413
+ out_channels,
414
+ add_bias=True,
415
+ use_wscale=True,
416
+ wscale_gain=_WSCALE_GAIN,
417
+ lr_mul=1.0,
418
+ activation_type='lrelu'):
419
+ """Initializes with block settings.
420
+
421
+ Args:
422
+ in_channels: Number of channels of the input tensor.
423
+ out_channels: Number of channels of the output tensor.
424
+ add_bias: Whether to add bias onto the fully-connected result.
425
+ (default: True)
426
+ use_wscale: Whether to use weight scaling. (default: True)
427
+ wscale_gain: Gain factor for weight scaling. (default: _WSCALE_GAIN)
428
+ lr_mul: Learning multiplier for both weight and bias. (default: 1.0)
429
+ activation_type: Type of activation. Support `linear` and `lrelu`.
430
+ (default: `lrelu`)
431
+
432
+ Raises:
433
+ NotImplementedError: If the `activation_type` is not supported.
434
+ """
435
+ super().__init__()
436
+ weight_shape = (out_channels, in_channels)
437
+ wscale = wscale_gain / np.sqrt(in_channels)
438
+ if use_wscale:
439
+ self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul)
440
+ self.wscale = wscale * lr_mul
441
+ else:
442
+ self.weight = nn.Parameter(
443
+ torch.randn(*weight_shape) * wscale / lr_mul)
444
+ self.wscale = lr_mul
445
+
446
+ if add_bias:
447
+ self.bias = nn.Parameter(torch.zeros(out_channels))
448
+ else:
449
+ self.bias = None
450
+ self.bscale = lr_mul
451
+
452
+ if activation_type == 'linear':
453
+ self.activate = nn.Identity()
454
+ self.activate_scale = 1.0
455
+ elif activation_type == 'lrelu':
456
+ self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
457
+ self.activate_scale = np.sqrt(2.0)
458
+ else:
459
+ raise NotImplementedError(f'Not implemented activation function: '
460
+ f'`{activation_type}`!')
461
+
462
+ def forward(self, x):
463
+ if x.ndim != 2:
464
+ x = x.view(x.shape[0], -1)
465
+ bias = self.bias * self.bscale if self.bias is not None else None
466
+ x = F.linear(x, weight=self.weight * self.wscale, bias=bias)
467
+ x = self.activate(x) * self.activate_scale
468
+ return x
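A minimal sanity-check sketch for the `DenseBlock` defined above (illustrative only; tensor sizes are assumptions, and the class is taken from this file): it flattens any non-2D input before the fully-connected layer.

    import torch
    # assuming DenseBlock from this file is in scope
    fc = DenseBlock(in_channels=512 * 4 * 4, out_channels=512)  # assumed sizes
    feat = torch.randn(2, 512, 4, 4)   # [N, C, H, W] feature map
    out = fc(feat)                     # flattened to [2, 8192] internally
    assert out.shape == (2, 512)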
models/stylegan2_generator.py ADDED
@@ -0,0 +1,996 @@
1
+ # python3.7
2
+ """Contains the implementation of generator described in StyleGAN2.
3
+
4
+ Compared to that of StyleGAN, the generator in StyleGAN2 mainly introduces style
5
+ demodulation, adds skip connections, increases model size, and disables
6
+ progressive growth. This script ONLY supports config F in the original paper.
7
+
8
+ Paper: https://arxiv.org/pdf/1912.04958.pdf
9
+
10
+ Official TensorFlow implementation: https://github.com/NVlabs/stylegan2
11
+ """
12
+
13
+ import numpy as np
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+
19
+ from .sync_op import all_gather
20
+
21
+ __all__ = ['StyleGAN2Generator']
22
+
23
+ # Resolutions allowed.
24
+ _RESOLUTIONS_ALLOWED = [8, 16, 32, 64, 128, 256, 512, 1024]
25
+
26
+ # Initial resolution.
27
+ _INIT_RES = 4
28
+
29
+ # Architectures allowed.
30
+ _ARCHITECTURES_ALLOWED = ['resnet', 'skip', 'origin']
31
+
32
+ # Default gain factor for weight scaling.
33
+ _WSCALE_GAIN = 1.0
34
+
35
+
36
+ class StyleGAN2Generator(nn.Module):
37
+ """Defines the generator network in StyleGAN2.
38
+
39
+ NOTE: The synthesized images are with `RGB` channel order and pixel range
40
+ [-1, 1].
41
+
42
+ Settings for the mapping network:
43
+
44
+ (1) z_space_dim: Dimension of the input latent space, Z. (default: 512)
45
+ (2) w_space_dim: Dimension of the output latent space, W. (default: 512)
46
+ (3) label_size: Size of the additional label for conditional generation.
47
+ (default: 0)
48
+ (4) mapping_layers: Number of layers of the mapping network. (default: 8)
49
+ (5) mapping_fmaps: Number of hidden channels of the mapping network.
50
+ (default: 512)
51
+ (6) mapping_lr_mul: Learning rate multiplier for the mapping network.
52
+ (default: 0.01)
53
+ (7) repeat_w: Whether to repeat the w-code for different layers. (default: True)
54
+
55
+ Settings for the synthesis network:
56
+
57
+ (1) resolution: The resolution of the output image.
58
+ (2) image_channels: Number of channels of the output image. (default: 3)
59
+ (3) final_tanh: Whether to use `tanh` to control the final pixel range.
60
+ (default: False)
61
+ (4) const_input: Whether to use a constant in the first convolutional layer.
62
+ (default: True)
63
+ (5) architecture: Type of architecture. Support `origin`, `skip`, and
64
+ `resnet`. (default: `skip`)
65
+ (6) fused_modulate: Whether to fuse `style_modulate` and `conv2d` together.
66
+ (default: True)
67
+ (7) demodulate: Whether to perform style demodulation. (default: True)
68
+ (8) use_wscale: Whether to use weight scaling. (default: True)
69
+ (9) fmaps_base: Factor to control number of feature maps for each layer.
70
+ (default: 32 << 10)
71
+ (10) fmaps_max: Maximum number of feature maps in each layer. (default: 512)
72
+ """
73
+
74
+ def __init__(self,
75
+ resolution,
76
+ z_space_dim=512,
77
+ w_space_dim=512,
78
+ label_size=0,
79
+ mapping_layers=8,
80
+ mapping_fmaps=512,
81
+ mapping_lr_mul=0.01,
82
+ repeat_w=True,
83
+ image_channels=3,
84
+ final_tanh=False,
85
+ const_input=True,
86
+ architecture='skip',
87
+ fused_modulate=True,
88
+ demodulate=True,
89
+ use_wscale=True,
90
+ fmaps_base=32 << 10,
91
+ fmaps_max=512):
92
+ """Initializes with basic settings.
93
+
94
+ Raises:
95
+ ValueError: If the `resolution` is not supported, or `architecture`
96
+ is not supported.
97
+ """
98
+ super().__init__()
99
+
100
+ if resolution not in _RESOLUTIONS_ALLOWED:
101
+ raise ValueError(f'Invalid resolution: `{resolution}`!\n'
102
+ f'Resolutions allowed: {_RESOLUTIONS_ALLOWED}.')
103
+ if architecture not in _ARCHITECTURES_ALLOWED:
104
+ raise ValueError(f'Invalid architecture: `{architecture}`!\n'
105
+ f'Architectures allowed: '
106
+ f'{_ARCHITECTURES_ALLOWED}.')
107
+
108
+ self.init_res = _INIT_RES
109
+ self.resolution = resolution
110
+ self.z_space_dim = z_space_dim
111
+ self.w_space_dim = w_space_dim
112
+ self.label_size = label_size
113
+ self.mapping_layers = mapping_layers
114
+ self.mapping_fmaps = mapping_fmaps
115
+ self.mapping_lr_mul = mapping_lr_mul
116
+ self.repeat_w = repeat_w
117
+ self.image_channels = image_channels
118
+ self.final_tanh = final_tanh
119
+ self.const_input = const_input
120
+ self.architecture = architecture
121
+ self.fused_modulate = fused_modulate
122
+ self.demodulate = demodulate
123
+ self.use_wscale = use_wscale
124
+ self.fmaps_base = fmaps_base
125
+ self.fmaps_max = fmaps_max
126
+
127
+ self.num_layers = int(np.log2(self.resolution // self.init_res * 2)) * 2
128
+
129
+ if self.repeat_w:
130
+ self.mapping_space_dim = self.w_space_dim
131
+ else:
132
+ self.mapping_space_dim = self.w_space_dim * self.num_layers
133
+ self.mapping = MappingModule(input_space_dim=self.z_space_dim,
134
+ hidden_space_dim=self.mapping_fmaps,
135
+ final_space_dim=self.mapping_space_dim,
136
+ label_size=self.label_size,
137
+ num_layers=self.mapping_layers,
138
+ use_wscale=self.use_wscale,
139
+ lr_mul=self.mapping_lr_mul)
140
+
141
+ self.truncation = TruncationModule(w_space_dim=self.w_space_dim,
142
+ num_layers=self.num_layers,
143
+ repeat_w=self.repeat_w)
144
+
145
+ self.synthesis = SynthesisModule(resolution=self.resolution,
146
+ init_resolution=self.init_res,
147
+ w_space_dim=self.w_space_dim,
148
+ image_channels=self.image_channels,
149
+ final_tanh=self.final_tanh,
150
+ const_input=self.const_input,
151
+ architecture=self.architecture,
152
+ fused_modulate=self.fused_modulate,
153
+ demodulate=self.demodulate,
154
+ use_wscale=self.use_wscale,
155
+ fmaps_base=self.fmaps_base,
156
+ fmaps_max=self.fmaps_max)
157
+
158
+ self.pth_to_tf_var_mapping = {}
159
+ for key, val in self.mapping.pth_to_tf_var_mapping.items():
160
+ self.pth_to_tf_var_mapping[f'mapping.{key}'] = val
161
+ for key, val in self.truncation.pth_to_tf_var_mapping.items():
162
+ self.pth_to_tf_var_mapping[f'truncation.{key}'] = val
163
+ for key, val in self.synthesis.pth_to_tf_var_mapping.items():
164
+ self.pth_to_tf_var_mapping[f'synthesis.{key}'] = val
165
+
166
+ def forward(self,
167
+ z,
168
+ label=None,
169
+ w_moving_decay=0.995,
170
+ style_mixing_prob=0.9,
171
+ trunc_psi=None,
172
+ trunc_layers=None,
173
+ randomize_noise=False,
174
+ **_unused_kwargs):
175
+ mapping_results = self.mapping(z, label)
176
+ w = mapping_results['w']
177
+
178
+ if self.training and w_moving_decay < 1:
179
+ batch_w_avg = all_gather(w).mean(dim=0)
180
+ self.truncation.w_avg.copy_(
181
+ self.truncation.w_avg * w_moving_decay +
182
+ batch_w_avg * (1 - w_moving_decay))
183
+
184
+ if self.training and style_mixing_prob > 0:
185
+ new_z = torch.randn_like(z)
186
+ new_w = self.mapping(new_z, label)['w']
187
+ if np.random.uniform() < style_mixing_prob:
188
+ mixing_cutoff = np.random.randint(1, self.num_layers)
189
+ w = self.truncation(w)
190
+ new_w = self.truncation(new_w)
191
+ w[:, :mixing_cutoff] = new_w[:, :mixing_cutoff]
192
+
193
+ wp = self.truncation(w, trunc_psi, trunc_layers)
194
+ synthesis_results = self.synthesis(wp, randomize_noise)
195
+
196
+ return {**mapping_results, **synthesis_results}
197
+
198
+
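A minimal usage sketch for `StyleGAN2Generator` (illustrative; the resolution and batch size are assumptions, and eval mode avoids the distributed `all_gather` path that only runs during training):

    import torch
    # assuming StyleGAN2Generator from this file is in scope
    G = StyleGAN2Generator(resolution=256)
    G.eval()
    z = torch.randn(4, G.z_space_dim)
    with torch.no_grad():
        outputs = G(z, trunc_psi=0.7, trunc_layers=8)
    images = outputs['image']   # [4, 3, 256, 256], RGB, roughly in [-1, 1]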
199
+ class MappingModule(nn.Module):
200
+ """Implements the latent space mapping module.
201
+
202
+ Basically, this module executes several dense layers in sequence.
203
+ """
204
+
205
+ def __init__(self,
206
+ input_space_dim=512,
207
+ hidden_space_dim=512,
208
+ final_space_dim=512,
209
+ label_size=0,
210
+ num_layers=8,
211
+ normalize_input=True,
212
+ use_wscale=True,
213
+ lr_mul=0.01):
214
+ super().__init__()
215
+
216
+ self.input_space_dim = input_space_dim
217
+ self.hidden_space_dim = hidden_space_dim
218
+ self.final_space_dim = final_space_dim
219
+ self.label_size = label_size
220
+ self.num_layers = num_layers
221
+ self.normalize_input = normalize_input
222
+ self.use_wscale = use_wscale
223
+ self.lr_mul = lr_mul
224
+
225
+ self.norm = PixelNormLayer() if self.normalize_input else nn.Identity()
226
+
227
+ self.pth_to_tf_var_mapping = {}
228
+ for i in range(num_layers):
229
+ dim_mul = 2 if label_size else 1
230
+ in_channels = (input_space_dim * dim_mul if i == 0 else
231
+ hidden_space_dim)
232
+ out_channels = (final_space_dim if i == (num_layers - 1) else
233
+ hidden_space_dim)
234
+ self.add_module(f'dense{i}',
235
+ DenseBlock(in_channels=in_channels,
236
+ out_channels=out_channels,
237
+ use_wscale=self.use_wscale,
238
+ lr_mul=self.lr_mul))
239
+ self.pth_to_tf_var_mapping[f'dense{i}.weight'] = f'Dense{i}/weight'
240
+ self.pth_to_tf_var_mapping[f'dense{i}.bias'] = f'Dense{i}/bias'
241
+ if label_size:
242
+ self.label_weight = nn.Parameter(
243
+ torch.randn(label_size, input_space_dim))
244
+ self.pth_to_tf_var_mapping['label_weight'] = 'LabelConcat/weight'
245
+
246
+ def forward(self, z, label=None):
247
+ if z.ndim != 2 or z.shape[1] != self.input_space_dim:
248
+ raise ValueError(f'Input latent code should be with shape '
249
+ f'[batch_size, input_dim], where '
250
+ f'`input_dim` equals to {self.input_space_dim}!\n'
251
+ f'But `{z.shape}` is received!')
252
+ if self.label_size:
253
+ if label is None:
254
+ raise ValueError(f'Model requires an additional label '
255
+ f'(with size {self.label_size}) as input, '
256
+ f'but no label is received!')
257
+ if label.ndim != 2 or label.shape != (z.shape[0], self.label_size):
258
+ raise ValueError(f'Input label should be with shape '
259
+ f'[batch_size, label_size], where '
260
+ f'`batch_size` equals to that of '
261
+ f'latent codes ({z.shape[0]}) and '
262
+ f'`label_size` equals to {self.label_size}!\n'
263
+ f'But `{label.shape}` is received!')
264
+ embedding = torch.matmul(label, self.label_weight)
265
+ z = torch.cat((z, embedding), dim=1)
266
+
267
+ z = self.norm(z)
268
+ w = z
269
+ for i in range(self.num_layers):
270
+ w = self.__getattr__(f'dense{i}')(w)
271
+ results = {
272
+ 'z': z,
273
+ 'label': label,
274
+ 'w': w,
275
+ }
276
+ if self.label_size:
277
+ results['embedding'] = embedding
278
+ return results
279
+
280
+
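An illustrative standalone call of `MappingModule` (sizes are assumptions): it normalizes z, runs the dense stack, and returns a dict containing the w-code.

    import torch
    # assuming MappingModule from this file is in scope
    mapping = MappingModule(input_space_dim=512, final_space_dim=512)
    z = torch.randn(8, 512)
    w = mapping(z)['w']   # [8, 512]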
281
+ class TruncationModule(nn.Module):
282
+ """Implements the truncation module.
283
+
284
+ Truncation is executed as follows:
285
+
286
+ For layers in range [0, truncation_layers), the truncated w-code is computed
287
+ as
288
+
289
+ w_new = w_avg + (w - w_avg) * truncation_psi
290
+
291
+ To disable truncation, please set
292
+ (1) truncation_psi = 1.0 (None) OR
293
+ (2) truncation_layers = 0 (None)
294
+
295
+ NOTE: The returned tensor is layer-wise style codes.
296
+ """
297
+
298
+ def __init__(self, w_space_dim, num_layers, repeat_w=True):
299
+ super().__init__()
300
+
301
+ self.num_layers = num_layers
302
+ self.w_space_dim = w_space_dim
303
+ self.repeat_w = repeat_w
304
+
305
+ if self.repeat_w:
306
+ self.register_buffer('w_avg', torch.zeros(w_space_dim))
307
+ else:
308
+ self.register_buffer('w_avg', torch.zeros(num_layers * w_space_dim))
309
+ self.pth_to_tf_var_mapping = {'w_avg': 'dlatent_avg'}
310
+
311
+ def forward(self, w, trunc_psi=None, trunc_layers=None):
312
+ if w.ndim == 2:
313
+ if self.repeat_w and w.shape[1] == self.w_space_dim:
314
+ w = w.view(-1, 1, self.w_space_dim)
315
+ wp = w.repeat(1, self.num_layers, 1)
316
+ else:
317
+ assert w.shape[1] == self.w_space_dim * self.num_layers
318
+ wp = w.view(-1, self.num_layers, self.w_space_dim)
319
+ else:
320
+ wp = w
321
+ assert wp.ndim == 3
322
+ assert wp.shape[1:] == (self.num_layers, self.w_space_dim)
323
+
324
+ trunc_psi = 1.0 if trunc_psi is None else trunc_psi
325
+ trunc_layers = 0 if trunc_layers is None else trunc_layers
326
+ if trunc_psi < 1.0 and trunc_layers > 0:
327
+ layer_idx = np.arange(self.num_layers).reshape(1, -1, 1)
328
+ coefs = np.ones_like(layer_idx, dtype=np.float32)
329
+ coefs[layer_idx < trunc_layers] *= trunc_psi
330
+ coefs = torch.from_numpy(coefs).to(wp)
331
+ w_avg = self.w_avg.view(1, -1, self.w_space_dim)
332
+ wp = w_avg + (wp - w_avg) * coefs
333
+ return wp
334
+
335
+
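A small sketch of the truncation trick implemented above (layer count and psi are illustrative): right after initialization `w_avg` is all zeros, so truncated layers are simply scaled by `trunc_psi`.

    import torch
    # assuming TruncationModule from this file is in scope
    trunc = TruncationModule(w_space_dim=512, num_layers=14)
    w = torch.randn(2, 512)
    wp = trunc(w)                                   # repeated to [2, 14, 512]
    wp_trunc = trunc(w, trunc_psi=0.5, trunc_layers=8)
    assert torch.allclose(wp_trunc[:, :8], wp[:, :8] * 0.5)   # truncated layers
    assert torch.allclose(wp_trunc[:, 8:], wp[:, 8:])         # untouched layers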
336
+ class SynthesisModule(nn.Module):
337
+ """Implements the image synthesis module.
338
+
339
+ Basically, this module executes several convolutional layers in sequence.
340
+ """
341
+
342
+ def __init__(self,
343
+ resolution=1024,
344
+ init_resolution=4,
345
+ w_space_dim=512,
346
+ image_channels=3,
347
+ final_tanh=False,
348
+ const_input=True,
349
+ architecture='skip',
350
+ fused_modulate=True,
351
+ demodulate=True,
352
+ use_wscale=True,
353
+ fmaps_base=32 << 10,
354
+ fmaps_max=512):
355
+ super().__init__()
356
+
357
+ self.init_res = init_resolution
358
+ self.init_res_log2 = int(np.log2(self.init_res))
359
+ self.resolution = resolution
360
+ self.final_res_log2 = int(np.log2(self.resolution))
361
+ self.w_space_dim = w_space_dim
362
+ self.image_channels = image_channels
363
+ self.final_tanh = final_tanh
364
+ self.const_input = const_input
365
+ self.architecture = architecture
366
+ self.fused_modulate = fused_modulate
367
+ self.demodulate = demodulate
368
+ self.use_wscale = use_wscale
369
+ self.fmaps_base = fmaps_base
370
+ self.fmaps_max = fmaps_max
371
+
372
+ self.num_layers = (self.final_res_log2 - self.init_res_log2 + 1) * 2
373
+
374
+ self.pth_to_tf_var_mapping = {}
375
+ for res_log2 in range(self.init_res_log2, self.final_res_log2 + 1):
376
+ res = 2 ** res_log2
377
+ block_idx = res_log2 - self.init_res_log2
378
+
379
+ # First convolution layer for each resolution.
380
+ if res == self.init_res:
381
+ if self.const_input:
382
+ self.add_module(f'early_layer',
383
+ InputBlock(init_resolution=self.init_res,
384
+ channels=self.get_nf(res)))
385
+ self.pth_to_tf_var_mapping[f'early_layer.const'] = (
386
+ f'{res}x{res}/Const/const')
387
+ else:
388
+ self.add_module(f'early_layer',
389
+ DenseBlock(in_channels=self.w_space_dim,
390
+ out_channels=self.get_nf(res),
391
+ use_wscale=self.use_wscale))
392
+ self.pth_to_tf_var_mapping[f'early_layer.weight'] = (
393
+ f'{res}x{res}/Dense/weight')
394
+ self.pth_to_tf_var_mapping[f'early_layer.bias'] = (
395
+ f'{res}x{res}/Dense/bias')
396
+ else:
397
+ layer_name = f'layer{2 * block_idx - 1}'
398
+ self.add_module(
399
+ layer_name,
400
+ ModulateConvBlock(in_channels=self.get_nf(res // 2),
401
+ out_channels=self.get_nf(res),
402
+ resolution=res,
403
+ w_space_dim=self.w_space_dim,
404
+ scale_factor=2,
405
+ fused_modulate=self.fused_modulate,
406
+ demodulate=self.demodulate,
407
+ use_wscale=self.use_wscale))
408
+ self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = (
409
+ f'{res}x{res}/Conv0_up/weight')
410
+ self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = (
411
+ f'{res}x{res}/Conv0_up/bias')
412
+ self.pth_to_tf_var_mapping[f'{layer_name}.style.weight'] = (
413
+ f'{res}x{res}/Conv0_up/mod_weight')
414
+ self.pth_to_tf_var_mapping[f'{layer_name}.style.bias'] = (
415
+ f'{res}x{res}/Conv0_up/mod_bias')
416
+ self.pth_to_tf_var_mapping[f'{layer_name}.noise_strength'] = (
417
+ f'{res}x{res}/Conv0_up/noise_strength')
418
+ self.pth_to_tf_var_mapping[f'{layer_name}.noise'] = (
419
+ f'noise{2 * block_idx - 1}')
420
+
421
+ if self.architecture == 'resnet':
422
+ layer_name = f'skip_layer{block_idx - 1}'
423
+ self.add_module(
424
+ layer_name,
425
+ ConvBlock(in_channels=self.get_nf(res // 2),
426
+ out_channels=self.get_nf(res),
427
+ kernel_size=1,
428
+ add_bias=False,
429
+ scale_factor=2,
430
+ use_wscale=self.use_wscale,
431
+ activation_type='linear'))
432
+ self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = (
433
+ f'{res}x{res}/Skip/weight')
434
+
435
+ # Second convolution layer for each resolution.
436
+ layer_name = f'layer{2 * block_idx}'
437
+ self.add_module(
438
+ layer_name,
439
+ ModulateConvBlock(in_channels=self.get_nf(res),
440
+ out_channels=self.get_nf(res),
441
+ resolution=res,
442
+ w_space_dim=self.w_space_dim,
443
+ fused_modulate=self.fused_modulate,
444
+ demodulate=self.demodulate,
445
+ use_wscale=self.use_wscale))
446
+ tf_layer_name = 'Conv' if res == self.init_res else 'Conv1'
447
+ self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = (
448
+ f'{res}x{res}/{tf_layer_name}/weight')
449
+ self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = (
450
+ f'{res}x{res}/{tf_layer_name}/bias')
451
+ self.pth_to_tf_var_mapping[f'{layer_name}.style.weight'] = (
452
+ f'{res}x{res}/{tf_layer_name}/mod_weight')
453
+ self.pth_to_tf_var_mapping[f'{layer_name}.style.bias'] = (
454
+ f'{res}x{res}/{tf_layer_name}/mod_bias')
455
+ self.pth_to_tf_var_mapping[f'{layer_name}.noise_strength'] = (
456
+ f'{res}x{res}/{tf_layer_name}/noise_strength')
457
+ self.pth_to_tf_var_mapping[f'{layer_name}.noise'] = (
458
+ f'noise{2 * block_idx}')
459
+
460
+ # Output convolution layer for each resolution (if needed).
461
+ if res_log2 == self.final_res_log2 or self.architecture == 'skip':
462
+ layer_name = f'output{block_idx}'
463
+ self.add_module(
464
+ layer_name,
465
+ ModulateConvBlock(in_channels=self.get_nf(res),
466
+ out_channels=image_channels,
467
+ resolution=res,
468
+ w_space_dim=self.w_space_dim,
469
+ kernel_size=1,
470
+ fused_modulate=self.fused_modulate,
471
+ demodulate=False,
472
+ use_wscale=self.use_wscale,
473
+ add_noise=False,
474
+ activation_type='linear'))
475
+ self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = (
476
+ f'{res}x{res}/ToRGB/weight')
477
+ self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = (
478
+ f'{res}x{res}/ToRGB/bias')
479
+ self.pth_to_tf_var_mapping[f'{layer_name}.style.weight'] = (
480
+ f'{res}x{res}/ToRGB/mod_weight')
481
+ self.pth_to_tf_var_mapping[f'{layer_name}.style.bias'] = (
482
+ f'{res}x{res}/ToRGB/mod_bias')
483
+
484
+ if self.architecture == 'skip':
485
+ self.upsample = UpsamplingLayer()
486
+ self.final_activate = nn.Tanh() if final_tanh else nn.Identity()
487
+
488
+ def get_nf(self, res):
489
+ """Gets number of feature maps according to current resolution."""
490
+ return min(self.fmaps_base // res, self.fmaps_max)
491
+
492
+ def forward(self, wp, randomize_noise=False):
493
+ if wp.ndim != 3 or wp.shape[1:] != (self.num_layers, self.w_space_dim):
494
+ raise ValueError(f'Input tensor should be with shape '
495
+ f'[batch_size, num_layers, w_space_dim], where '
496
+ f'`num_layers` equals to {self.num_layers}, and '
497
+ f'`w_space_dim` equals to {self.w_space_dim}!\n'
498
+ f'But `{wp.shape}` is received!')
499
+
500
+ results = {'wp': wp}
501
+ x = self.early_layer(wp[:, 0])
502
+ if self.architecture == 'origin':
503
+ for layer_idx in range(self.num_layers - 1):
504
+ x, style = self.__getattr__(f'layer{layer_idx}')(
505
+ x, wp[:, layer_idx], randomize_noise)
506
+ results[f'style{layer_idx:02d}'] = style
507
+ image, style = self.__getattr__(f'output{layer_idx // 2}')(
508
+ x, wp[:, layer_idx + 1])
509
+ results[f'output_style{layer_idx // 2}'] = style
510
+ elif self.architecture == 'skip':
511
+ for layer_idx in range(self.num_layers - 1):
512
+ x, style = self.__getattr__(f'layer{layer_idx}')(
513
+ x, wp[:, layer_idx], randomize_noise)
514
+ results[f'style{layer_idx:02d}'] = style
515
+ if layer_idx % 2 == 0:
516
+ temp, style = self.__getattr__(f'output{layer_idx // 2}')(
517
+ x, wp[:, layer_idx + 1])
518
+ results[f'output_style{layer_idx // 2}'] = style
519
+ if layer_idx == 0:
520
+ image = temp
521
+ else:
522
+ image = temp + self.upsample(image)
523
+ elif self.architecture == 'resnet':
524
+ x, style = self.layer0(x, wp[:, 0], randomize_noise)
525
+ results[f'style00'] = style
526
+ for layer_idx in range(1, self.num_layers - 1, 2):
527
+ residual = self.__getattr__(f'skip_layer{layer_idx // 2}')(x)
528
+ x, style = self.__getattr__(f'layer{layer_idx}')(
529
+ x, wp[:, layer_idx], randomize_noise)
530
+ results[f'style{layer_idx:02d}'] = style
531
+ x, style = self.__getattr__(f'layer{layer_idx + 1}')(
532
+ x, wp[:, layer_idx + 1], randomize_noise)
533
+ results[f'style{layer_idx + 1:02d}'] = style
534
+ x = (x + residual) / np.sqrt(2.0)
535
+ image, style = self.__getattr__(f'output{layer_idx // 2 + 1}')(
536
+ x, wp[:, layer_idx + 2])
537
+ results[f'output_style{layer_idx // 2 + 1}'] = style
538
+ results['image'] = self.final_activate(image)
539
+ return results
540
+
541
+
542
+ class PixelNormLayer(nn.Module):
543
+ """Implements pixel-wise feature vector normalization layer."""
544
+
545
+ def __init__(self, dim=1, epsilon=1e-8):
546
+ super().__init__()
547
+ self.dim = dim
548
+ self.eps = epsilon
549
+
550
+ def forward(self, x):
551
+ norm = torch.sqrt(
552
+ torch.mean(x ** 2, dim=self.dim, keepdim=True) + self.eps)
553
+ return x / norm
554
+
555
+
556
+ class UpsamplingLayer(nn.Module):
557
+ """Implements the upsampling layer.
558
+
559
+ This layer can also be used as filtering by setting `scale_factor` as 1.
560
+ """
561
+
562
+ def __init__(self,
563
+ scale_factor=2,
564
+ kernel=(1, 3, 3, 1),
565
+ extra_padding=0,
566
+ kernel_gain=None):
567
+ super().__init__()
568
+ assert scale_factor >= 1
569
+ self.scale_factor = scale_factor
570
+
571
+ if extra_padding != 0:
572
+ assert scale_factor == 1
573
+
574
+ if kernel is None:
575
+ kernel = np.ones((scale_factor), dtype=np.float32)
576
+ else:
577
+ kernel = np.array(kernel, dtype=np.float32)
578
+ assert kernel.ndim == 1
579
+ kernel = np.outer(kernel, kernel)
580
+ kernel = kernel / np.sum(kernel)
581
+ if kernel_gain is None:
582
+ kernel = kernel * (scale_factor ** 2)
583
+ else:
584
+ assert kernel_gain > 0
585
+ kernel = kernel * (kernel_gain ** 2)
586
+ assert kernel.ndim == 2
587
+ assert kernel.shape[0] == kernel.shape[1]
588
+ kernel = kernel[np.newaxis, np.newaxis]
589
+ self.register_buffer('kernel', torch.from_numpy(kernel))
590
+ self.kernel = self.kernel.flip(0, 1)
591
+
592
+ self.upsample_padding = (0, scale_factor - 1, # Width padding.
593
+ 0, 0, # Width.
594
+ 0, scale_factor - 1, # Height padding.
595
+ 0, 0, # Height.
596
+ 0, 0, # Channel.
597
+ 0, 0) # Batch size.
598
+
599
+ padding = kernel.shape[2] - scale_factor + extra_padding
600
+ self.padding = ((padding + 1) // 2 + scale_factor - 1, padding // 2,
601
+ (padding + 1) // 2 + scale_factor - 1, padding // 2)
602
+
603
+ def forward(self, x):
604
+ assert x.ndim == 4
605
+ channels = x.shape[1]
606
+ if self.scale_factor > 1:
607
+ x = x.view(-1, channels, x.shape[2], 1, x.shape[3], 1)
608
+ x = F.pad(x, self.upsample_padding, mode='constant', value=0)
609
+ x = x.view(-1, channels, x.shape[2] * self.scale_factor,
610
+ x.shape[4] * self.scale_factor)
611
+ x = x.view(-1, 1, x.shape[2], x.shape[3])
612
+ x = F.pad(x, self.padding, mode='constant', value=0)
613
+ x = F.conv2d(x, self.kernel, stride=1)
614
+ x = x.view(-1, channels, x.shape[2], x.shape[3])
615
+ return x
616
+
617
+
618
+ class InputBlock(nn.Module):
619
+ """Implements the input block.
620
+
621
+ Basically, this block starts from a const input, which is with shape
622
+ `(channels, init_resolution, init_resolution)`.
623
+ """
624
+
625
+ def __init__(self, init_resolution, channels):
626
+ super().__init__()
627
+ self.const = nn.Parameter(
628
+ torch.randn(1, channels, init_resolution, init_resolution))
629
+
630
+ def forward(self, w):
631
+ x = self.const.repeat(w.shape[0], 1, 1, 1)
632
+ return x
633
+
634
+
635
+ class ConvBlock(nn.Module):
636
+ """Implements the convolutional block (no style modulation).
637
+
638
+ Basically, this block executes, convolutional layer, filtering layer (if
639
+ needed), and activation layer in sequence.
640
+
641
+ NOTE: This block is particularly used for skip-connection branch in the
642
+ `resnet` structure.
643
+ """
644
+
645
+ def __init__(self,
646
+ in_channels,
647
+ out_channels,
648
+ kernel_size=3,
649
+ add_bias=True,
650
+ scale_factor=1,
651
+ filtering_kernel=(1, 3, 3, 1),
652
+ use_wscale=True,
653
+ wscale_gain=_WSCALE_GAIN,
654
+ lr_mul=1.0,
655
+ activation_type='lrelu'):
656
+ """Initializes with block settings.
657
+
658
+ Args:
659
+ in_channels: Number of channels of the input tensor.
660
+ out_channels: Number of channels of the output tensor.
661
+ kernel_size: Size of the convolutional kernels. (default: 3)
662
+ add_bias: Whether to add bias onto the convolutional result.
663
+ (default: True)
664
+ scale_factor: Scale factor for upsampling. `1` means skip
665
+ upsampling. (default: 1)
666
+ filtering_kernel: Kernel used for filtering after upsampling.
667
+ (default: (1, 3, 3, 1))
668
+ use_wscale: Whether to use weight scaling. (default: True)
669
+ wscale_gain: Gain factor for weight scaling. (default: _WSCALE_GAIN)
670
+ lr_mul: Learning multiplier for both weight and bias. (default: 1.0)
671
+ activation_type: Type of activation. Support `linear` and `lrelu`.
672
+ (default: `lrelu`)
673
+
674
+ Raises:
675
+ NotImplementedError: If the `activation_type` is not supported.
676
+ """
677
+ super().__init__()
678
+
679
+ if scale_factor > 1:
680
+ self.use_conv2d_transpose = True
681
+ extra_padding = scale_factor - kernel_size
682
+ self.filter = UpsamplingLayer(scale_factor=1,
683
+ kernel=filtering_kernel,
684
+ extra_padding=extra_padding,
685
+ kernel_gain=scale_factor)
686
+ self.stride = scale_factor
687
+ self.padding = 0 # Padding is done in `UpsamplingLayer`.
688
+ else:
689
+ self.use_conv2d_transpose = False
690
+ assert kernel_size % 2 == 1
691
+ self.stride = 1
692
+ self.padding = kernel_size // 2
693
+
694
+ weight_shape = (out_channels, in_channels, kernel_size, kernel_size)
695
+ fan_in = kernel_size * kernel_size * in_channels
696
+ wscale = wscale_gain / np.sqrt(fan_in)
697
+ if use_wscale:
698
+ self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul)
699
+ self.wscale = wscale * lr_mul
700
+ else:
701
+ self.weight = nn.Parameter(
702
+ torch.randn(*weight_shape) * wscale / lr_mul)
703
+ self.wscale = lr_mul
704
+
705
+ if add_bias:
706
+ self.bias = nn.Parameter(torch.zeros(out_channels))
707
+ else:
708
+ self.bias = None
709
+ self.bscale = lr_mul
710
+
711
+ if activation_type == 'linear':
712
+ self.activate = nn.Identity()
713
+ self.activate_scale = 1.0
714
+ elif activation_type == 'lrelu':
715
+ self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
716
+ self.activate_scale = np.sqrt(2.0)
717
+ else:
718
+ raise NotImplementedError(f'Not implemented activation function: '
719
+ f'`{activation_type}`!')
720
+
721
+ def forward(self, x):
722
+ weight = self.weight * self.wscale
723
+ bias = self.bias * self.bscale if self.bias is not None else None
724
+ if self.use_conv2d_transpose:
725
+ weight = weight.permute(1, 0, 2, 3).flip(2, 3)
726
+ x = F.conv_transpose2d(x,
727
+ weight=weight,
728
+ bias=bias,
729
+ stride=self.stride,
730
+ padding=self.padding)
731
+ x = self.filter(x)
732
+ else:
733
+ x = F.conv2d(x,
734
+ weight=weight,
735
+ bias=bias,
736
+ stride=self.stride,
737
+ padding=self.padding)
738
+ x = self.activate(x) * self.activate_scale
739
+ return x
740
+
741
+
742
+ class ModulateConvBlock(nn.Module):
743
+ """Implements the convolutional block with style modulation."""
744
+
745
+ def __init__(self,
746
+ in_channels,
747
+ out_channels,
748
+ resolution,
749
+ w_space_dim,
750
+ kernel_size=3,
751
+ add_bias=True,
752
+ scale_factor=1,
753
+ filtering_kernel=(1, 3, 3, 1),
754
+ fused_modulate=True,
755
+ demodulate=True,
756
+ use_wscale=True,
757
+ wscale_gain=_WSCALE_GAIN,
758
+ lr_mul=1.0,
759
+ add_noise=True,
760
+ activation_type='lrelu',
761
+ epsilon=1e-8):
762
+ """Initializes with block settings.
763
+
764
+ Args:
765
+ in_channels: Number of channels of the input tensor.
766
+ out_channels: Number of channels of the output tensor.
767
+ resolution: Resolution of the output tensor.
768
+ w_space_dim: Dimension of W space for style modulation.
769
+ kernel_size: Size of the convolutional kernels. (default: 3)
770
+ add_bias: Whether to add bias onto the convolutional result.
771
+ (default: True)
772
+ scale_factor: Scale factor for upsampling. `1` means skip
773
+ upsampling. (default: 1)
774
+ filtering_kernel: Kernel used for filtering after upsampling.
775
+ (default: (1, 3, 3, 1))
776
+ fused_modulate: Whether to fuse `style_modulate` and `conv2d`
777
+ together. (default: True)
778
+ demodulate: Whether to perform style demodulation. (default: True)
779
+ use_wscale: Whether to use weight scaling. (default: True)
780
+ wscale_gain: Gain factor for weight scaling. (default: _WSCALE_GAIN)
781
+ lr_mul: Learning multiplier for both weight and bias. (default: 1.0)
782
+ add_noise: Whether to add noise onto the output tensor. (default:
783
+ True)
784
+ activation_type: Type of activation. Support `linear` and `lrelu`.
785
+ (default: `lrelu`)
786
+ epsilon: Small number to avoid `divide by zero`. (default: 1e-8)
787
+
788
+ Raises:
789
+ NotImplementedError: If the `activation_type` is not supported.
790
+ """
791
+ super().__init__()
792
+
793
+ self.res = resolution
794
+ self.in_c = in_channels
795
+ self.out_c = out_channels
796
+ self.ksize = kernel_size
797
+ self.eps = epsilon
798
+
799
+ if scale_factor > 1:
800
+ self.use_conv2d_transpose = True
801
+ extra_padding = scale_factor - kernel_size
802
+ self.filter = UpsamplingLayer(scale_factor=1,
803
+ kernel=filtering_kernel,
804
+ extra_padding=extra_padding,
805
+ kernel_gain=scale_factor)
806
+ self.stride = scale_factor
807
+ self.padding = 0 # Padding is done in `UpsamplingLayer`.
808
+ else:
809
+ self.use_conv2d_transpose = False
810
+ assert kernel_size % 2 == 1
811
+ self.stride = 1
812
+ self.padding = kernel_size // 2
813
+
814
+ weight_shape = (out_channels, in_channels, kernel_size, kernel_size)
815
+ fan_in = kernel_size * kernel_size * in_channels
816
+ wscale = wscale_gain / np.sqrt(fan_in)
817
+ if use_wscale:
818
+ self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul)
819
+ self.wscale = wscale * lr_mul
820
+ else:
821
+ self.weight = nn.Parameter(
822
+ torch.randn(*weight_shape) * wscale / lr_mul)
823
+ self.wscale = lr_mul
824
+
825
+ self.style = DenseBlock(in_channels=w_space_dim,
826
+ out_channels=in_channels,
827
+ additional_bias=1.0,
828
+ use_wscale=use_wscale,
829
+ activation_type='linear')
830
+
831
+ self.fused_modulate = fused_modulate
832
+ self.demodulate = demodulate
833
+
834
+ if add_bias:
835
+ self.bias = nn.Parameter(torch.zeros(out_channels))
836
+ else:
837
+ self.bias = None
838
+ self.bscale = lr_mul
839
+
840
+ if activation_type == 'linear':
841
+ self.activate = nn.Identity()
842
+ self.activate_scale = 1.0
843
+ elif activation_type == 'lrelu':
844
+ self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
845
+ self.activate_scale = np.sqrt(2.0)
846
+ else:
847
+ raise NotImplementedError(f'Not implemented activation function: '
848
+ f'`{activation_type}`!')
849
+
850
+ self.add_noise = add_noise
851
+ if self.add_noise:
852
+ self.register_buffer('noise', torch.randn(1, 1, self.res, self.res))
853
+ self.noise_strength = nn.Parameter(torch.zeros(()))
854
+
855
+ def forward(self, x, w, randomize_noise=False):
856
+ batch = x.shape[0]
857
+
858
+ weight = self.weight * self.wscale
859
+ weight = weight.permute(2, 3, 1, 0)
860
+
861
+ # Style modulation.
862
+ style = self.style(w)
863
+ _weight = weight.view(1, self.ksize, self.ksize, self.in_c, self.out_c)
864
+ _weight = _weight * style.view(batch, 1, 1, self.in_c, 1)
865
+
866
+ # Style demodulation.
867
+ if self.demodulate:
868
+ _weight_norm = torch.sqrt(
869
+ torch.sum(_weight ** 2, dim=[1, 2, 3]) + self.eps)
870
+ _weight = _weight / _weight_norm.view(batch, 1, 1, 1, self.out_c)
871
+
872
+ if self.fused_modulate:
873
+ x = x.view(1, batch * self.in_c, x.shape[2], x.shape[3])
874
+ weight = _weight.permute(1, 2, 3, 0, 4).reshape(
875
+ self.ksize, self.ksize, self.in_c, batch * self.out_c)
876
+ else:
877
+ x = x * style.view(batch, self.in_c, 1, 1)
878
+
879
+ if self.use_conv2d_transpose:
880
+ weight = weight.flip(0, 1)
881
+ if self.fused_modulate:
882
+ weight = weight.view(
883
+ self.ksize, self.ksize, self.in_c, batch, self.out_c)
884
+ weight = weight.permute(0, 1, 4, 3, 2)
885
+ weight = weight.reshape(
886
+ self.ksize, self.ksize, self.out_c, batch * self.in_c)
887
+ weight = weight.permute(3, 2, 0, 1)
888
+ else:
889
+ weight = weight.permute(2, 3, 0, 1)
890
+ x = F.conv_transpose2d(x,
891
+ weight=weight,
892
+ bias=None,
893
+ stride=self.stride,
894
+ padding=self.padding,
895
+ groups=(batch if self.fused_modulate else 1))
896
+ x = self.filter(x)
897
+ else:
898
+ weight = weight.permute(3, 2, 0, 1)
899
+ x = F.conv2d(x,
900
+ weight=weight,
901
+ bias=None,
902
+ stride=self.stride,
903
+ padding=self.padding,
904
+ groups=(batch if self.fused_modulate else 1))
905
+
906
+ if self.fused_modulate:
907
+ x = x.view(batch, self.out_c, self.res, self.res)
908
+ elif self.demodulate:
909
+ x = x / _weight_norm.view(batch, self.out_c, 1, 1)
910
+
911
+ if self.add_noise:
912
+ if randomize_noise:
913
+ noise = torch.randn(x.shape[0], 1, self.res, self.res).to(x)
914
+ else:
915
+ noise = self.noise
916
+ x = x + noise * self.noise_strength.view(1, 1, 1, 1)
917
+
918
+ bias = self.bias * self.bscale if self.bias is not None else None
919
+ if bias is not None:
920
+ x = x + bias.view(1, -1, 1, 1)
921
+ x = self.activate(x) * self.activate_scale
922
+ return x, style
923
+
924
+
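An illustrative standalone call of `ModulateConvBlock` (channel counts and resolution are assumptions): with `scale_factor=2` it doubles the spatial size and returns both the feature map and the per-sample style vector.

    import torch
    # assuming ModulateConvBlock from this file is in scope
    conv = ModulateConvBlock(in_channels=512, out_channels=512,
                             resolution=8, w_space_dim=512, scale_factor=2)
    x = torch.randn(2, 512, 4, 4)
    w = torch.randn(2, 512)
    y, style = conv(x, w)   # y: [2, 512, 8, 8], style: [2, 512]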
925
+ class DenseBlock(nn.Module):
926
+ """Implements the dense block.
927
+
928
+ Basically, this block executes fully-connected layer and activation layer.
929
+
930
+ NOTE: This layer supports adding an additional bias beyond the trainable
931
+ bias parameter. This is specially used for the mapping from the w code to
932
+ the style code.
933
+ """
934
+
935
+ def __init__(self,
936
+ in_channels,
937
+ out_channels,
938
+ add_bias=True,
939
+ additional_bias=0,
940
+ use_wscale=True,
941
+ wscale_gain=_WSCALE_GAIN,
942
+ lr_mul=1.0,
943
+ activation_type='lrelu'):
944
+ """Initializes with block settings.
945
+
946
+ Args:
947
+ in_channels: Number of channels of the input tensor.
948
+ out_channels: Number of channels of the output tensor.
949
+ add_bias: Whether to add bias onto the fully-connected result.
950
+ (default: True)
951
+ additional_bias: The additional bias, which is independent from the
952
+ bias parameter. (default: 0.0)
953
+ use_wscale: Whether to use weight scaling. (default: True)
954
+ wscale_gain: Gain factor for weight scaling. (default: _WSCALE_GAIN)
955
+ lr_mul: Learning multiplier for both weight and bias. (default: 1.0)
956
+ activation_type: Type of activation. Support `linear` and `lrelu`.
957
+ (default: `lrelu`)
958
+
959
+ Raises:
960
+ NotImplementedError: If the `activation_type` is not supported.
961
+ """
962
+ super().__init__()
963
+ weight_shape = (out_channels, in_channels)
964
+ wscale = wscale_gain / np.sqrt(in_channels)
965
+ if use_wscale:
966
+ self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul)
967
+ self.wscale = wscale * lr_mul
968
+ else:
969
+ self.weight = nn.Parameter(
970
+ torch.randn(*weight_shape) * wscale / lr_mul)
971
+ self.wscale = lr_mul
972
+
973
+ if add_bias:
974
+ self.bias = nn.Parameter(torch.zeros(out_channels))
975
+ else:
976
+ self.bias = None
977
+ self.bscale = lr_mul
978
+ self.additional_bias = additional_bias
979
+
980
+ if activation_type == 'linear':
981
+ self.activate = nn.Identity()
982
+ self.activate_scale = 1.0
983
+ elif activation_type == 'lrelu':
984
+ self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
985
+ self.activate_scale = np.sqrt(2.0)
986
+ else:
987
+ raise NotImplementedError(f'Not implemented activation function: '
988
+ f'`{activation_type}`!')
989
+
990
+ def forward(self, x):
991
+ if x.ndim != 2:
992
+ x = x.view(x.shape[0], -1)
993
+ bias = self.bias * self.bscale if self.bias is not None else None
994
+ x = F.linear(x, weight=self.weight * self.wscale, bias=bias)
995
+ x = self.activate(x + self.additional_bias) * self.activate_scale
996
+ return x
models/stylegan_discriminator.py ADDED
@@ -0,0 +1,530 @@
1
+ # python3.7
2
+ """Contains the implementation of discriminator described in StyleGAN.
3
+
4
+ Paper: https://arxiv.org/pdf/1812.04948.pdf
5
+
6
+ Official TensorFlow implementation: https://github.com/NVlabs/stylegan
7
+ """
8
+
9
+ import numpy as np
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+
15
+ __all__ = ['StyleGANDiscriminator']
16
+
17
+ # Resolutions allowed.
18
+ _RESOLUTIONS_ALLOWED = [8, 16, 32, 64, 128, 256, 512, 1024]
19
+
20
+ # Initial resolution.
21
+ _INIT_RES = 4
22
+
23
+ # Fused-scale options allowed.
24
+ _FUSED_SCALE_ALLOWED = [True, False, 'auto']
25
+
26
+ # Minimal resolution for `auto` fused-scale strategy.
27
+ _AUTO_FUSED_SCALE_MIN_RES = 128
28
+
29
+ # Default gain factor for weight scaling.
30
+ _WSCALE_GAIN = np.sqrt(2.0)
31
+
32
+
33
+ class StyleGANDiscriminator(nn.Module):
34
+ """Defines the discriminator network in StyleGAN.
35
+
36
+ NOTE: The discriminator takes images with `RGB` channel order and pixel
37
+ range [-1, 1] as inputs.
38
+
39
+ Settings for the network:
40
+
41
+ (1) resolution: The resolution of the input image.
42
+ (2) image_channels: Number of channels of the input image. (default: 3)
43
+ (3) label_size: Size of the additional label for conditional generation.
44
+ (default: 0)
45
+ (4) fused_scale: Whether to fused `conv2d` and `downsample` together,
46
+ resulting in `conv2d` with strides. (default: `auto`)
47
+ (5) use_wscale: Whether to use weight scaling. (default: True)
48
+ (6) minibatch_std_group_size: Group size for the minibatch standard
49
+ deviation layer. 0 means disable. (default: 4)
50
+ (7) minibatch_std_channels: Number of new channels after the minibatch
51
+ standard deviation layer. (default: 1)
52
+ (8) fmaps_base: Factor to control number of feature maps for each layer.
53
+ (default: 16 << 10)
54
+ (9) fmaps_max: Maximum number of feature maps in each layer. (default: 512)
55
+ """
56
+
57
+ def __init__(self,
58
+ resolution,
59
+ image_channels=3,
60
+ label_size=0,
61
+ fused_scale='auto',
62
+ use_wscale=True,
63
+ minibatch_std_group_size=4,
64
+ minibatch_std_channels=1,
65
+ fmaps_base=16 << 10,
66
+ fmaps_max=512):
67
+ """Initializes with basic settings.
68
+
69
+ Raises:
70
+ ValueError: If the `resolution` is not supported, or `fused_scale`
71
+ is not supported.
72
+ """
73
+ super().__init__()
74
+
75
+ if resolution not in _RESOLUTIONS_ALLOWED:
76
+ raise ValueError(f'Invalid resolution: `{resolution}`!\n'
77
+ f'Resolutions allowed: {_RESOLUTIONS_ALLOWED}.')
78
+ if fused_scale not in _FUSED_SCALE_ALLOWED:
79
+ raise ValueError(f'Invalid fused-scale option: `{fused_scale}`!\n'
80
+ f'Options allowed: {_FUSED_SCALE_ALLOWED}.')
81
+
82
+ self.init_res = _INIT_RES
83
+ self.init_res_log2 = int(np.log2(self.init_res))
84
+ self.resolution = resolution
85
+ self.final_res_log2 = int(np.log2(self.resolution))
86
+ self.image_channels = image_channels
87
+ self.label_size = label_size
88
+ self.fused_scale = fused_scale
89
+ self.use_wscale = use_wscale
90
+ self.minibatch_std_group_size = minibatch_std_group_size
91
+ self.minibatch_std_channels = minibatch_std_channels
92
+ self.fmaps_base = fmaps_base
93
+ self.fmaps_max = fmaps_max
94
+
95
+ # Level of detail (used for progressive training).
96
+ self.register_buffer('lod', torch.zeros(()))
97
+ self.pth_to_tf_var_mapping = {'lod': 'lod'}
98
+
99
+ for res_log2 in range(self.final_res_log2, self.init_res_log2 - 1, -1):
100
+ res = 2 ** res_log2
101
+ block_idx = self.final_res_log2 - res_log2
102
+
103
+ # Input convolution layer for each resolution.
104
+ self.add_module(
105
+ f'input{block_idx}',
106
+ ConvBlock(in_channels=self.image_channels,
107
+ out_channels=self.get_nf(res),
108
+ kernel_size=1,
109
+ padding=0,
110
+ use_wscale=self.use_wscale))
111
+ self.pth_to_tf_var_mapping[f'input{block_idx}.weight'] = (
112
+ f'FromRGB_lod{block_idx}/weight')
113
+ self.pth_to_tf_var_mapping[f'input{block_idx}.bias'] = (
114
+ f'FromRGB_lod{block_idx}/bias')
115
+
116
+ # Convolution block for each resolution (except the last one).
117
+ if res != self.init_res:
118
+ if self.fused_scale == 'auto':
119
+ fused_scale = (res >= _AUTO_FUSED_SCALE_MIN_RES)
120
+ else:
121
+ fused_scale = self.fused_scale
122
+ self.add_module(
123
+ f'layer{2 * block_idx}',
124
+ ConvBlock(in_channels=self.get_nf(res),
125
+ out_channels=self.get_nf(res),
126
+ use_wscale=self.use_wscale))
127
+ tf_layer0_name = 'Conv0'
128
+ self.add_module(
129
+ f'layer{2 * block_idx + 1}',
130
+ ConvBlock(in_channels=self.get_nf(res),
131
+ out_channels=self.get_nf(res // 2),
132
+ downsample=True,
133
+ fused_scale=fused_scale,
134
+ use_wscale=self.use_wscale))
135
+ tf_layer1_name = 'Conv1_down'
136
+
137
+ # Convolution block for last resolution.
138
+ else:
139
+ self.add_module(
140
+ f'layer{2 * block_idx}',
141
+ ConvBlock(in_channels=self.get_nf(res),
142
+ out_channels=self.get_nf(res),
143
+ use_wscale=self.use_wscale,
144
+ minibatch_std_group_size=minibatch_std_group_size,
145
+ minibatch_std_channels=minibatch_std_channels))
146
+ tf_layer0_name = 'Conv'
147
+ self.add_module(
148
+ f'layer{2 * block_idx + 1}',
149
+ DenseBlock(in_channels=self.get_nf(res) * res * res,
150
+ out_channels=self.get_nf(res // 2),
151
+ use_wscale=self.use_wscale))
152
+ tf_layer1_name = 'Dense0'
153
+
154
+ self.pth_to_tf_var_mapping[f'layer{2 * block_idx}.weight'] = (
155
+ f'{res}x{res}/{tf_layer0_name}/weight')
156
+ self.pth_to_tf_var_mapping[f'layer{2 * block_idx}.bias'] = (
157
+ f'{res}x{res}/{tf_layer0_name}/bias')
158
+ self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 1}.weight'] = (
159
+ f'{res}x{res}/{tf_layer1_name}/weight')
160
+ self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 1}.bias'] = (
161
+ f'{res}x{res}/{tf_layer1_name}/bias')
162
+
163
+ # Final dense block.
164
+ self.add_module(
165
+ f'layer{2 * block_idx + 2}',
166
+ DenseBlock(in_channels=self.get_nf(res // 2),
167
+ out_channels=max(self.label_size, 1),
168
+ use_wscale=self.use_wscale,
169
+ wscale_gain=1.0,
170
+ activation_type='linear'))
171
+ self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 2}.weight'] = (
172
+ f'{res}x{res}/Dense1/weight')
173
+ self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 2}.bias'] = (
174
+ f'{res}x{res}/Dense1/bias')
175
+
176
+ self.downsample = DownsamplingLayer()
177
+
178
+ def get_nf(self, res):
179
+ """Gets number of feature maps according to current resolution."""
180
+ return min(self.fmaps_base // res, self.fmaps_max)
181
+
182
+ def forward(self, image, label=None, lod=None, **_unused_kwargs):
183
+ expected_shape = (self.image_channels, self.resolution, self.resolution)
184
+ if image.ndim != 4 or image.shape[1:] != expected_shape:
185
+ raise ValueError(f'The input tensor should be with shape '
186
+ f'[batch_size, channel, height, width], where '
187
+ f'`channel` equals to {self.image_channels}, '
188
+ f'`height`, `width` equal to {self.resolution}!\n'
189
+ f'But `{image.shape}` is received!')
190
+
191
+ lod = self.lod.cpu().tolist() if lod is None else lod
192
+ if lod + self.init_res_log2 > self.final_res_log2:
193
+ raise ValueError(f'Maximum level-of-detail (lod) is '
194
+ f'{self.final_res_log2 - self.init_res_log2}, '
195
+ f'but `{lod}` is received!')
196
+
197
+ if self.label_size:
198
+ if label is None:
199
+ raise ValueError(f'Model requires an additional label '
200
+ f'(with size {self.label_size}) as input, '
201
+ f'but no label is received!')
202
+ batch_size = image.shape[0]
203
+ if label.ndim != 2 or label.shape != (batch_size, self.label_size):
204
+ raise ValueError(f'Input label should be with shape '
205
+ f'[batch_size, label_size], where '
206
+ f'`batch_size` equals to that of '
207
+ f'images ({image.shape[0]}) and '
208
+ f'`label_size` equals to {self.label_size}!\n'
209
+ f'But `{label.shape}` is received!')
210
+
211
+ for res_log2 in range(self.final_res_log2, self.init_res_log2 - 1, -1):
212
+ block_idx = current_lod = self.final_res_log2 - res_log2
213
+ if current_lod <= lod < current_lod + 1:
214
+ x = self.__getattr__(f'input{block_idx}')(image)
215
+ elif current_lod - 1 < lod < current_lod:
216
+ alpha = lod - np.floor(lod)
217
+ x = (self.__getattr__(f'input{block_idx}')(image) * alpha +
218
+ x * (1 - alpha))
219
+ if lod < current_lod + 1:
220
+ x = self.__getattr__(f'layer{2 * block_idx}')(x)
221
+ x = self.__getattr__(f'layer{2 * block_idx + 1}')(x)
222
+ if lod > current_lod:
223
+ image = self.downsample(image)
224
+ x = self.__getattr__(f'layer{2 * block_idx + 2}')(x)
225
+
226
+ if self.label_size:
227
+ x = torch.sum(x * label, dim=1, keepdim=True)
228
+
229
+ return x
230
+
231
+
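A minimal usage sketch for `StyleGANDiscriminator` (resolution and batch size are assumptions): the default `lod` of 0 scores full-resolution images, while a larger `lod` evaluates the image as if training had only progressed to a lower resolution.

    import torch
    # assuming StyleGANDiscriminator from this file is in scope
    D = StyleGANDiscriminator(resolution=256)
    img = torch.randn(4, 3, 256, 256)
    score = D(img)             # [4, 1] raw realness scores (lod = 0)
    score_low = D(img, lod=1)  # scored through the 128x128 pathway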
232
+ class MiniBatchSTDLayer(nn.Module):
233
+ """Implements the minibatch standard deviation layer."""
234
+
235
+ def __init__(self, group_size=4, new_channels=1, epsilon=1e-8):
236
+ super().__init__()
237
+ self.group_size = group_size
238
+ self.new_channels = new_channels
239
+ self.epsilon = epsilon
240
+
241
+ def forward(self, x):
242
+ if self.group_size <= 1:
243
+ return x
244
+ ng = min(self.group_size, x.shape[0])
245
+ nc = self.new_channels
246
+ temp_c = x.shape[1] // nc # [NCHW]
247
+ y = x.view(ng, -1, nc, temp_c, x.shape[2], x.shape[3]) # [GMncHW]
248
+ y = y - torch.mean(y, dim=0, keepdim=True) # [GMncHW]
249
+ y = torch.mean(y ** 2, dim=0) # [MncHW]
250
+ y = torch.sqrt(y + self.epsilon) # [MncHW]
251
+ y = torch.mean(y, dim=[2, 3, 4], keepdim=True) # [Mn111]
252
+ y = torch.mean(y, dim=2) # [Mn11]
253
+ y = y.repeat(ng, 1, x.shape[2], x.shape[3]) # [NnHW]
254
+ return torch.cat([x, y], dim=1)
255
+
256
+
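A quick shape check for `MiniBatchSTDLayer` (sizes are assumptions): it appends `new_channels` statistics channels computed over groups of samples.

    import torch
    # assuming MiniBatchSTDLayer from this file is in scope
    mbstd = MiniBatchSTDLayer(group_size=4, new_channels=1)
    x = torch.randn(8, 512, 4, 4)
    y = mbstd(x)
    assert y.shape == (8, 513, 4, 4)   # one extra std-statistics channel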
257
+ class DownsamplingLayer(nn.Module):
258
+ """Implements the downsampling layer.
259
+
260
+ Basically, this layer can be used to downsample feature maps with average
261
+ pooling.
262
+ """
263
+
264
+ def __init__(self, scale_factor=2):
265
+ super().__init__()
266
+ self.scale_factor = scale_factor
267
+
268
+ def forward(self, x):
269
+ if self.scale_factor <= 1:
270
+ return x
271
+ return F.avg_pool2d(x,
272
+ kernel_size=self.scale_factor,
273
+ stride=self.scale_factor,
274
+ padding=0)
275
+
276
+
277
+ class Blur(torch.autograd.Function):
278
+ """Defines blur operation with customized gradient computation."""
279
+
280
+ @staticmethod
281
+ def forward(ctx, x, kernel):
282
+ ctx.save_for_backward(kernel)
283
+ y = F.conv2d(input=x,
284
+ weight=kernel,
285
+ bias=None,
286
+ stride=1,
287
+ padding=1,
288
+ groups=x.shape[1])
289
+ return y
290
+
291
+ @staticmethod
292
+ def backward(ctx, dy):
293
+ kernel, = ctx.saved_tensors
294
+ dx = BlurBackPropagation.apply(dy, kernel)
295
+ return dx, None, None
296
+
297
+
298
+ class BlurBackPropagation(torch.autograd.Function):
299
+ """Defines the back propagation of blur operation.
300
+
301
+ NOTE: This is used to speed up the backward of gradient penalty.
302
+ """
303
+
304
+ @staticmethod
305
+ def forward(ctx, dy, kernel):
306
+ ctx.save_for_backward(kernel)
307
+ dx = F.conv2d(input=dy,
308
+ weight=kernel.flip((2, 3)),
309
+ bias=None,
310
+ stride=1,
311
+ padding=1,
312
+ groups=dy.shape[1])
313
+ return dx
314
+
315
+ @staticmethod
316
+ def backward(ctx, ddx):
317
+ kernel, = ctx.saved_tensors
318
+ ddy = F.conv2d(input=ddx,
319
+ weight=kernel,
320
+ bias=None,
321
+ stride=1,
322
+ padding=1,
323
+ groups=ddx.shape[1])
324
+ return ddy, None, None
325
+
326
+
327
+ class BlurLayer(nn.Module):
328
+ """Implements the blur layer."""
329
+
330
+ def __init__(self,
331
+ channels,
332
+ kernel=(1, 2, 1),
333
+ normalize=True):
334
+ super().__init__()
335
+ kernel = np.array(kernel, dtype=np.float32).reshape(1, -1)
336
+ kernel = kernel.T.dot(kernel)
337
+ if normalize:
338
+ kernel = kernel / np.sum(kernel)
339
+ kernel = kernel[np.newaxis, np.newaxis]
340
+ kernel = np.tile(kernel, [channels, 1, 1, 1])
341
+ self.register_buffer('kernel', torch.from_numpy(kernel))
342
+
343
+ def forward(self, x):
344
+ return Blur.apply(x, self.kernel)
345
+
346
+
347
+ class ConvBlock(nn.Module):
348
+ """Implements the convolutional block.
349
+
350
+ Basically, this block executes minibatch standard deviation layer (if
351
+ needed), convolutional layer, activation layer, and downsampling layer (
352
+ if needed) in sequence.
353
+ """
354
+
355
+ def __init__(self,
356
+ in_channels,
357
+ out_channels,
358
+ kernel_size=3,
359
+ stride=1,
360
+ padding=1,
361
+ add_bias=True,
362
+ downsample=False,
363
+ fused_scale=False,
364
+ use_wscale=True,
365
+ wscale_gain=_WSCALE_GAIN,
366
+ lr_mul=1.0,
367
+ activation_type='lrelu',
368
+ minibatch_std_group_size=0,
369
+ minibatch_std_channels=1):
370
+ """Initializes with block settings.
371
+
372
+ Args:
373
+ in_channels: Number of channels of the input tensor.
374
+ out_channels: Number of channels of the output tensor.
375
+ kernel_size: Size of the convolutional kernels. (default: 3)
376
+ stride: Stride parameter for convolution operation. (default: 1)
377
+ padding: Padding parameter for convolution operation. (default: 1)
378
+ add_bias: Whether to add bias onto the convolutional result.
379
+ (default: True)
380
+ downsample: Whether to downsample the result after convolution.
381
+ (default: False)
382
+ fused_scale: Whether to fused `conv2d` and `downsample` together,
383
+ resulting in `conv2d` with strides. (default: False)
384
+ use_wscale: Whether to use weight scaling. (default: True)
385
+ wscale_gain: Gain factor for weight scaling. (default: _WSCALE_GAIN)
386
+ lr_mul: Learning multiplier for both weight and bias. (default: 1.0)
387
+ activation_type: Type of activation. Support `linear` and `lrelu`.
388
+ (default: `lrelu`)
389
+ minibatch_std_group_size: Group size for the minibatch standard
390
+ deviation layer. 0 means disable. (default: 0)
391
+ minibatch_std_channels: Number of new channels after the minibatch
392
+ standard deviation layer. (default: 1)
393
+
394
+ Raises:
395
+ NotImplementedError: If the `activation_type` is not supported.
396
+ """
397
+ super().__init__()
398
+
399
+ if minibatch_std_group_size > 1:
400
+ in_channels = in_channels + minibatch_std_channels
401
+ self.mbstd = MiniBatchSTDLayer(group_size=minibatch_std_group_size,
402
+ new_channels=minibatch_std_channels)
403
+ else:
404
+ self.mbstd = nn.Identity()
405
+
406
+ if downsample:
407
+ self.blur = BlurLayer(channels=in_channels)
408
+ else:
409
+ self.blur = nn.Identity()
410
+
411
+ if downsample and not fused_scale:
412
+ self.downsample = DownsamplingLayer()
413
+ else:
414
+ self.downsample = nn.Identity()
415
+
416
+ if downsample and fused_scale:
417
+ self.use_stride = True
418
+ self.stride = 2
419
+ self.padding = 1
420
+ else:
421
+ self.use_stride = False
422
+ self.stride = stride
423
+ self.padding = padding
424
+
425
+ weight_shape = (out_channels, in_channels, kernel_size, kernel_size)
426
+ fan_in = kernel_size * kernel_size * in_channels
427
+ wscale = wscale_gain / np.sqrt(fan_in)
428
+ if use_wscale:
429
+ self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul)
430
+ self.wscale = wscale * lr_mul
431
+ else:
432
+ self.weight = nn.Parameter(
433
+ torch.randn(*weight_shape) * wscale / lr_mul)
434
+ self.wscale = lr_mul
435
+
436
+ if add_bias:
437
+ self.bias = nn.Parameter(torch.zeros(out_channels))
438
+ self.bscale = lr_mul
439
+ else:
440
+ self.bias = None
441
+
442
+ if activation_type == 'linear':
443
+ self.activate = nn.Identity()
444
+ elif activation_type == 'lrelu':
445
+ self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
446
+ else:
447
+ raise NotImplementedError(f'Not implemented activation function: '
448
+ f'`{activation_type}`!')
449
+
450
+ def forward(self, x):
451
+ x = self.mbstd(x)
452
+ x = self.blur(x)
453
+ weight = self.weight * self.wscale
454
+ bias = self.bias * self.bscale if self.bias is not None else None
455
+ if self.use_stride:
456
+ weight = F.pad(weight, (1, 1, 1, 1, 0, 0, 0, 0), 'constant', 0.0)
457
+ weight = (weight[:, :, 1:, 1:] + weight[:, :, :-1, 1:] +
458
+ weight[:, :, 1:, :-1] + weight[:, :, :-1, :-1]) * 0.25
459
+ x = F.conv2d(x,
460
+ weight=weight,
461
+ bias=bias,
462
+ stride=self.stride,
463
+ padding=self.padding)
464
+ x = self.downsample(x)
465
+ x = self.activate(x)
466
+ return x
467
+
468
+
469
+ class DenseBlock(nn.Module):
470
+ """Implements the dense block.
471
+
472
+ Basically, this block executes fully-connected layer and activation layer.
473
+ """
474
+
475
+ def __init__(self,
476
+ in_channels,
477
+ out_channels,
478
+ add_bias=True,
479
+ use_wscale=True,
480
+ wscale_gain=_WSCALE_GAIN,
481
+ lr_mul=1.0,
482
+ activation_type='lrelu'):
483
+ """Initializes with block settings.
484
+
485
+ Args:
486
+ in_channels: Number of channels of the input tensor.
487
+ out_channels: Number of channels of the output tensor.
488
+ add_bias: Whether to add bias onto the fully-connected result.
489
+ (default: True)
490
+ use_wscale: Whether to use weight scaling. (default: True)
491
+ wscale_gain: Gain factor for weight scaling. (default: _WSCALE_GAIN)
492
+ lr_mul: Learning multiplier for both weight and bias. (default: 1.0)
493
+ activation_type: Type of activation. Support `linear` and `lrelu`.
494
+ (default: `lrelu`)
495
+
496
+ Raises:
497
+ NotImplementedError: If the `activation_type` is not supported.
498
+ """
499
+ super().__init__()
500
+ weight_shape = (out_channels, in_channels)
501
+ wscale = wscale_gain / np.sqrt(in_channels)
502
+ if use_wscale:
503
+ self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul)
504
+ self.wscale = wscale * lr_mul
505
+ else:
506
+ self.weight = nn.Parameter(
507
+ torch.randn(*weight_shape) * wscale / lr_mul)
508
+ self.wscale = lr_mul
509
+
510
+ if add_bias:
511
+ self.bias = nn.Parameter(torch.zeros(out_channels))
512
+ self.bscale = lr_mul
513
+ else:
514
+ self.bias = None
515
+
516
+ if activation_type == 'linear':
517
+ self.activate = nn.Identity()
518
+ elif activation_type == 'lrelu':
519
+ self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
520
+ else:
521
+ raise NotImplementedError(f'Not implemented activation function: '
522
+ f'`{activation_type}`!')
523
+
524
+ def forward(self, x):
525
+ if x.ndim != 2:
526
+ x = x.view(x.shape[0], -1)
527
+ bias = self.bias * self.bscale if self.bias is not None else None
528
+ x = F.linear(x, weight=self.weight * self.wscale, bias=bias)
529
+ x = self.activate(x)
530
+ return x
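Both ConvBlock and DenseBlock above store their parameters at unit scale and multiply by a precomputed `wscale` inside `forward()` (the equalized-learning-rate trick). A minimal standalone sketch of that arithmetic, with hypothetical layer sizes and no bias term, not part of the committed files:

import numpy as np
import torch
import torch.nn.functional as F

wscale_gain = np.sqrt(2.0)                          # same value as _WSCALE_GAIN
in_channels, out_channels, lr_mul = 512, 256, 1.0   # hypothetical sizes
weight = torch.randn(out_channels, in_channels) / lr_mul      # stored parameter
wscale = float(wscale_gain / np.sqrt(in_channels) * lr_mul)   # runtime multiplier
x = torch.randn(4, in_channels)
y = F.linear(x, weight * wscale)   # mirrors DenseBlock.forward without the bias term
print(y.shape)                     # torch.Size([4, 256])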
models/stylegan_generator.py ADDED
@@ -0,0 +1,869 @@
1
+ # python3.7
2
+ """Contains the implementation of generator described in StyleGAN.
3
+
4
+ Paper: https://arxiv.org/pdf/1812.04948.pdf
5
+
6
+ Official TensorFlow implementation: https://github.com/NVlabs/stylegan
7
+ """
8
+
9
+ import numpy as np
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+
15
+ from .sync_op import all_gather
16
+
17
+ __all__ = ['StyleGANGenerator']
18
+
19
+ # Resolutions allowed.
20
+ _RESOLUTIONS_ALLOWED = [8, 16, 32, 64, 128, 256, 512, 1024]
21
+
22
+ # Initial resolution.
23
+ _INIT_RES = 4
24
+
25
+ # Fused-scale options allowed.
26
+ _FUSED_SCALE_ALLOWED = [True, False, 'auto']
27
+
28
+ # Minimal resolution for `auto` fused-scale strategy.
29
+ _AUTO_FUSED_SCALE_MIN_RES = 128
30
+
31
+ # Default gain factor for weight scaling.
32
+ _WSCALE_GAIN = np.sqrt(2.0)
33
+ _STYLEMOD_WSCALE_GAIN = 1.0
34
+
35
+
36
+ class StyleGANGenerator(nn.Module):
37
+ """Defines the generator network in StyleGAN.
38
+
39
+ NOTE: The synthesized images are with `RGB` channel order and pixel range
40
+ [-1, 1].
41
+
42
+ Settings for the mapping network:
43
+
44
+ (1) z_space_dim: Dimension of the input latent space, Z. (default: 512)
45
+ (2) w_space_dim: Dimension of the output latent space, W. (default: 512)
46
+ (3) label_size: Size of the additional label for conditional generation.
47
+ (default: 0)
48
+ (4) mapping_layers: Number of layers of the mapping network. (default: 8)
49
+ (5) mapping_fmaps: Number of hidden channels of the mapping network.
50
+ (default: 512)
51
+ (6) mapping_lr_mul: Learning rate multiplier for the mapping network.
52
+ (default: 0.01)
53
+ (7) repeat_w: Whether to repeat the w-code for different layers. (default: True)
54
+
55
+ Settings for the synthesis network:
56
+
57
+ (1) resolution: The resolution of the output image.
58
+ (2) image_channels: Number of channels of the output image. (default: 3)
59
+ (3) final_tanh: Whether to use `tanh` to control the final pixel range.
60
+ (default: False)
61
+ (4) const_input: Whether to use a constant in the first convolutional layer.
62
+ (default: True)
63
+ (5) fused_scale: Whether to fuse `upsample` and `conv2d` together,
64
+ resulting in `conv2d_transpose`. (default: `auto`)
65
+ (6) use_wscale: Whether to use weight scaling. (default: True)
66
+ (7) fmaps_base: Factor to control number of feature maps for each layer.
67
+ (default: 16 << 10)
68
+ (8) fmaps_max: Maximum number of feature maps in each layer. (default: 512)
69
+ """
70
+
71
+ def __init__(self,
72
+ resolution,
73
+ z_space_dim=512,
74
+ w_space_dim=512,
75
+ label_size=0,
76
+ mapping_layers=8,
77
+ mapping_fmaps=512,
78
+ mapping_lr_mul=0.01,
79
+ repeat_w=True,
80
+ image_channels=3,
81
+ final_tanh=False,
82
+ const_input=True,
83
+ fused_scale='auto',
84
+ use_wscale=True,
85
+ fmaps_base=16 << 10,
86
+ fmaps_max=512):
87
+ """Initializes with basic settings.
88
+
89
+ Raises:
90
+ ValueError: If the `resolution` is not supported, or `fused_scale`
91
+ is not supported.
92
+ """
93
+ super().__init__()
94
+
95
+ if resolution not in _RESOLUTIONS_ALLOWED:
96
+ raise ValueError(f'Invalid resolution: `{resolution}`!\n'
97
+ f'Resolutions allowed: {_RESOLUTIONS_ALLOWED}.')
98
+ if fused_scale not in _FUSED_SCALE_ALLOWED:
99
+ raise ValueError(f'Invalid fused-scale option: `{fused_scale}`!\n'
100
+ f'Options allowed: {_FUSED_SCALE_ALLOWED}.')
101
+
102
+ self.init_res = _INIT_RES
103
+ self.resolution = resolution
104
+ self.z_space_dim = z_space_dim
105
+ self.w_space_dim = w_space_dim
106
+ self.label_size = label_size
107
+ self.mapping_layers = mapping_layers
108
+ self.mapping_fmaps = mapping_fmaps
109
+ self.mapping_lr_mul = mapping_lr_mul
110
+ self.repeat_w = repeat_w
111
+ self.image_channels = image_channels
112
+ self.final_tanh = final_tanh
113
+ self.const_input = const_input
114
+ self.fused_scale = fused_scale
115
+ self.use_wscale = use_wscale
116
+ self.fmaps_base = fmaps_base
117
+ self.fmaps_max = fmaps_max
118
+
119
+ self.num_layers = int(np.log2(self.resolution // self.init_res * 2)) * 2
120
+
121
+ if self.repeat_w:
122
+ self.mapping_space_dim = self.w_space_dim
123
+ else:
124
+ self.mapping_space_dim = self.w_space_dim * self.num_layers
125
+ self.mapping = MappingModule(input_space_dim=self.z_space_dim,
126
+ hidden_space_dim=self.mapping_fmaps,
127
+ final_space_dim=self.mapping_space_dim,
128
+ label_size=self.label_size,
129
+ num_layers=self.mapping_layers,
130
+ use_wscale=self.use_wscale,
131
+ lr_mul=self.mapping_lr_mul)
132
+
133
+ self.truncation = TruncationModule(w_space_dim=self.w_space_dim,
134
+ num_layers=self.num_layers,
135
+ repeat_w=self.repeat_w)
136
+
137
+ self.synthesis = SynthesisModule(resolution=self.resolution,
138
+ init_resolution=self.init_res,
139
+ w_space_dim=self.w_space_dim,
140
+ image_channels=self.image_channels,
141
+ final_tanh=self.final_tanh,
142
+ const_input=self.const_input,
143
+ fused_scale=self.fused_scale,
144
+ use_wscale=self.use_wscale,
145
+ fmaps_base=self.fmaps_base,
146
+ fmaps_max=self.fmaps_max)
147
+
148
+ self.pth_to_tf_var_mapping = {}
149
+ for key, val in self.mapping.pth_to_tf_var_mapping.items():
150
+ self.pth_to_tf_var_mapping[f'mapping.{key}'] = val
151
+ for key, val in self.truncation.pth_to_tf_var_mapping.items():
152
+ self.pth_to_tf_var_mapping[f'truncation.{key}'] = val
153
+ for key, val in self.synthesis.pth_to_tf_var_mapping.items():
154
+ self.pth_to_tf_var_mapping[f'synthesis.{key}'] = val
155
+
156
+ def forward(self,
157
+ z,
158
+ label=None,
159
+ lod=None,
160
+ w_moving_decay=0.995,
161
+ style_mixing_prob=0.9,
162
+ trunc_psi=None,
163
+ trunc_layers=None,
164
+ randomize_noise=False,
165
+ **_unused_kwargs):
166
+ mapping_results = self.mapping(z, label)
167
+ w = mapping_results['w']
168
+
169
+ if self.training and w_moving_decay < 1:
170
+ batch_w_avg = all_gather(w).mean(dim=0)
171
+ self.truncation.w_avg.copy_(
172
+ self.truncation.w_avg * w_moving_decay +
173
+ batch_w_avg * (1 - w_moving_decay))
174
+
175
+ if self.training and style_mixing_prob > 0:
176
+ new_z = torch.randn_like(z)
177
+ new_w = self.mapping(new_z, label)['w']
178
+ lod = self.synthesis.lod.cpu().tolist() if lod is None else lod
179
+ current_layers = self.num_layers - int(lod) * 2
180
+ if np.random.uniform() < style_mixing_prob:
181
+ mixing_cutoff = np.random.randint(1, current_layers)
182
+ w = self.truncation(w)
183
+ new_w = self.truncation(new_w)
184
+ w[:, mixing_cutoff:] = new_w[:, mixing_cutoff:]
185
+
186
+ wp = self.truncation(w, trunc_psi, trunc_layers)
187
+ synthesis_results = self.synthesis(wp, lod, randomize_noise)
188
+
189
+ return {**mapping_results, **synthesis_results}
190
+
191
+
192
+ class MappingModule(nn.Module):
193
+ """Implements the latent space mapping module.
194
+
195
+ Basically, this module executes several dense layers in sequence.
196
+ """
197
+
198
+ def __init__(self,
199
+ input_space_dim=512,
200
+ hidden_space_dim=512,
201
+ final_space_dim=512,
202
+ label_size=0,
203
+ num_layers=8,
204
+ normalize_input=True,
205
+ use_wscale=True,
206
+ lr_mul=0.01):
207
+ super().__init__()
208
+
209
+ self.input_space_dim = input_space_dim
210
+ self.hidden_space_dim = hidden_space_dim
211
+ self.final_space_dim = final_space_dim
212
+ self.label_size = label_size
213
+ self.num_layers = num_layers
214
+ self.normalize_input = normalize_input
215
+ self.use_wscale = use_wscale
216
+ self.lr_mul = lr_mul
217
+
218
+ self.norm = PixelNormLayer() if self.normalize_input else nn.Identity()
219
+
220
+ self.pth_to_tf_var_mapping = {}
221
+ for i in range(num_layers):
222
+ dim_mul = 2 if label_size else 1
223
+ in_channels = (input_space_dim * dim_mul if i == 0 else
224
+ hidden_space_dim)
225
+ out_channels = (final_space_dim if i == (num_layers - 1) else
226
+ hidden_space_dim)
227
+ self.add_module(f'dense{i}',
228
+ DenseBlock(in_channels=in_channels,
229
+ out_channels=out_channels,
230
+ use_wscale=self.use_wscale,
231
+ lr_mul=self.lr_mul))
232
+ self.pth_to_tf_var_mapping[f'dense{i}.weight'] = f'Dense{i}/weight'
233
+ self.pth_to_tf_var_mapping[f'dense{i}.bias'] = f'Dense{i}/bias'
234
+ if label_size:
235
+ self.label_weight = nn.Parameter(
236
+ torch.randn(label_size, input_space_dim))
237
+ self.pth_to_tf_var_mapping[f'label_weight'] = f'LabelConcat/weight'
238
+
239
+ def forward(self, z, label=None):
240
+ if z.ndim != 2 or z.shape[1] != self.input_space_dim:
241
+ raise ValueError(f'Input latent code should be with shape '
242
+ f'[batch_size, input_dim], where '
243
+ f'`input_dim` equals to {self.input_space_dim}!\n'
244
+ f'But `{z.shape}` is received!')
245
+ if self.label_size:
246
+ if label is None:
247
+ raise ValueError(f'Model requires an additional label '
248
+ f'(with size {self.label_size}) as input, '
249
+ f'but no label is received!')
250
+ if label.ndim != 2 or label.shape != (z.shape[0], self.label_size):
251
+ raise ValueError(f'Input label should be with shape '
252
+ f'[batch_size, label_size], where '
253
+ f'`batch_size` equals to that of '
254
+ f'latent codes ({z.shape[0]}) and '
255
+ f'`label_size` equals to {self.label_size}!\n'
256
+ f'But `{label.shape}` is received!')
257
+ embedding = torch.matmul(label, self.label_weight)
258
+ z = torch.cat((z, embedding), dim=1)
259
+
260
+ z = self.norm(z)
261
+ w = z
262
+ for i in range(self.num_layers):
263
+ w = self.__getattr__(f'dense{i}')(w)
264
+ results = {
265
+ 'z': z,
266
+ 'label': label,
267
+ 'w': w,
268
+ }
269
+ if self.label_size:
270
+ results['embedding'] = embedding
271
+ return results
272
+
273
+
274
+ class TruncationModule(nn.Module):
275
+ """Implements the truncation module.
276
+
277
+ Truncation is executed as follows:
278
+
279
+ For layers in range [0, truncation_layers), the truncated w-code is computed
280
+ as
281
+
282
+ w_new = w_avg + (w - w_avg) * truncation_psi
283
+
284
+ To disable truncation, please set
285
+ (1) truncation_psi = 1.0 (None) OR
286
+ (2) truncation_layers = 0 (None)
287
+
288
+ NOTE: The returned tensor is layer-wise style codes.
289
+ """
290
+
291
+ def __init__(self, w_space_dim, num_layers, repeat_w=True):
292
+ super().__init__()
293
+
294
+ self.num_layers = num_layers
295
+ self.w_space_dim = w_space_dim
296
+ self.repeat_w = repeat_w
297
+
298
+ if self.repeat_w:
299
+ self.register_buffer('w_avg', torch.zeros(w_space_dim))
300
+ else:
301
+ self.register_buffer('w_avg', torch.zeros(num_layers * w_space_dim))
302
+ self.pth_to_tf_var_mapping = {'w_avg': 'dlatent_avg'}
303
+
304
+ def forward(self, w, trunc_psi=None, trunc_layers=None):
305
+ if w.ndim == 2:
306
+ if self.repeat_w and w.shape[1] == self.w_space_dim:
307
+ w = w.view(-1, 1, self.w_space_dim)
308
+ wp = w.repeat(1, self.num_layers, 1)
309
+ else:
310
+ assert w.shape[1] == self.w_space_dim * self.num_layers
311
+ wp = w.view(-1, self.num_layers, self.w_space_dim)
312
+ else:
313
+ wp = w
314
+ assert wp.ndim == 3
315
+ assert wp.shape[1:] == (self.num_layers, self.w_space_dim)
316
+
317
+ trunc_psi = 1.0 if trunc_psi is None else trunc_psi
318
+ trunc_layers = 0 if trunc_layers is None else trunc_layers
319
+ if trunc_psi < 1.0 and trunc_layers > 0:
320
+ layer_idx = np.arange(self.num_layers).reshape(1, -1, 1)
321
+ coefs = np.ones_like(layer_idx, dtype=np.float32)
322
+ coefs[layer_idx < trunc_layers] *= trunc_psi
323
+ coefs = torch.from_numpy(coefs).to(wp)
324
+ w_avg = self.w_avg.view(1, -1, self.w_space_dim)
325
+ wp = w_avg + (wp - w_avg) * coefs
326
+ return wp
327
+
328
+
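To make the truncation rule documented above concrete, here is a hedged standalone sketch (assuming 18 layers, psi = 0.7 and 8 truncated layers); it mirrors TruncationModule.forward without the shape handling:

import torch

num_layers, w_dim = 18, 512            # e.g. a 1024x1024 generator with repeat_w=True
trunc_psi, trunc_layers = 0.7, 8
w_avg = torch.zeros(w_dim)             # stands in for the `w_avg` buffer above
wp = torch.randn(2, num_layers, w_dim)
coefs = torch.ones(1, num_layers, 1)
coefs[:, :trunc_layers] *= trunc_psi   # only the first `trunc_layers` layers are truncated
wp_trunc = w_avg.view(1, 1, -1) + (wp - w_avg.view(1, 1, -1)) * coefs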
329
+ class SynthesisModule(nn.Module):
330
+ """Implements the image synthesis module.
331
+
332
+ Basically, this module executes several convolutional layers in sequence.
333
+ """
334
+
335
+ def __init__(self,
336
+ resolution=1024,
337
+ init_resolution=4,
338
+ w_space_dim=512,
339
+ image_channels=3,
340
+ final_tanh=False,
341
+ const_input=True,
342
+ fused_scale='auto',
343
+ use_wscale=True,
344
+ fmaps_base=16 << 10,
345
+ fmaps_max=512):
346
+ super().__init__()
347
+
348
+ self.init_res = init_resolution
349
+ self.init_res_log2 = int(np.log2(self.init_res))
350
+ self.resolution = resolution
351
+ self.final_res_log2 = int(np.log2(self.resolution))
352
+ self.w_space_dim = w_space_dim
353
+ self.image_channels = image_channels
354
+ self.final_tanh = final_tanh
355
+ self.const_input = const_input
356
+ self.fused_scale = fused_scale
357
+ self.use_wscale = use_wscale
358
+ self.fmaps_base = fmaps_base
359
+ self.fmaps_max = fmaps_max
360
+
361
+ self.num_layers = (self.final_res_log2 - self.init_res_log2 + 1) * 2
362
+
363
+ # Level of detail (used for progressive training).
364
+ self.register_buffer('lod', torch.zeros(()))
365
+ self.pth_to_tf_var_mapping = {'lod': 'lod'}
366
+
367
+ for res_log2 in range(self.init_res_log2, self.final_res_log2 + 1):
368
+ res = 2 ** res_log2
369
+ block_idx = res_log2 - self.init_res_log2
370
+
371
+ # First convolution layer for each resolution.
372
+ layer_name = f'layer{2 * block_idx}'
373
+ if res == self.init_res:
374
+ if self.const_input:
375
+ self.add_module(layer_name,
376
+ ConvBlock(in_channels=self.get_nf(res),
377
+ out_channels=self.get_nf(res),
378
+ resolution=self.init_res,
379
+ w_space_dim=self.w_space_dim,
380
+ position='const_init',
381
+ use_wscale=self.use_wscale))
382
+ tf_layer_name = 'Const'
383
+ self.pth_to_tf_var_mapping[f'{layer_name}.const'] = (
384
+ f'{res}x{res}/{tf_layer_name}/const')
385
+ else:
386
+ self.add_module(layer_name,
387
+ ConvBlock(in_channels=self.w_space_dim,
388
+ out_channels=self.get_nf(res),
389
+ resolution=self.init_res,
390
+ w_space_dim=self.w_space_dim,
391
+ kernel_size=self.init_res,
392
+ padding=self.init_res - 1,
393
+ use_wscale=self.use_wscale))
394
+ tf_layer_name = 'Dense'
395
+ self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = (
396
+ f'{res}x{res}/{tf_layer_name}/weight')
397
+ else:
398
+ if self.fused_scale == 'auto':
399
+ fused_scale = (res >= _AUTO_FUSED_SCALE_MIN_RES)
400
+ else:
401
+ fused_scale = self.fused_scale
402
+ self.add_module(layer_name,
403
+ ConvBlock(in_channels=self.get_nf(res // 2),
404
+ out_channels=self.get_nf(res),
405
+ resolution=res,
406
+ w_space_dim=self.w_space_dim,
407
+ upsample=True,
408
+ fused_scale=fused_scale,
409
+ use_wscale=self.use_wscale))
410
+ tf_layer_name = 'Conv0_up'
411
+ self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = (
412
+ f'{res}x{res}/{tf_layer_name}/weight')
413
+ self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = (
414
+ f'{res}x{res}/{tf_layer_name}/bias')
415
+ self.pth_to_tf_var_mapping[f'{layer_name}.style.weight'] = (
416
+ f'{res}x{res}/{tf_layer_name}/StyleMod/weight')
417
+ self.pth_to_tf_var_mapping[f'{layer_name}.style.bias'] = (
418
+ f'{res}x{res}/{tf_layer_name}/StyleMod/bias')
419
+ self.pth_to_tf_var_mapping[f'{layer_name}.apply_noise.weight'] = (
420
+ f'{res}x{res}/{tf_layer_name}/Noise/weight')
421
+ self.pth_to_tf_var_mapping[f'{layer_name}.apply_noise.noise'] = (
422
+ f'noise{2 * block_idx}')
423
+
424
+ # Second convolution layer for each resolution.
425
+ layer_name = f'layer{2 * block_idx + 1}'
426
+ self.add_module(layer_name,
427
+ ConvBlock(in_channels=self.get_nf(res),
428
+ out_channels=self.get_nf(res),
429
+ resolution=res,
430
+ w_space_dim=self.w_space_dim,
431
+ use_wscale=self.use_wscale))
432
+ tf_layer_name = 'Conv' if res == self.init_res else 'Conv1'
433
+ self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = (
434
+ f'{res}x{res}/{tf_layer_name}/weight')
435
+ self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = (
436
+ f'{res}x{res}/{tf_layer_name}/bias')
437
+ self.pth_to_tf_var_mapping[f'{layer_name}.style.weight'] = (
438
+ f'{res}x{res}/{tf_layer_name}/StyleMod/weight')
439
+ self.pth_to_tf_var_mapping[f'{layer_name}.style.bias'] = (
440
+ f'{res}x{res}/{tf_layer_name}/StyleMod/bias')
441
+ self.pth_to_tf_var_mapping[f'{layer_name}.apply_noise.weight'] = (
442
+ f'{res}x{res}/{tf_layer_name}/Noise/weight')
443
+ self.pth_to_tf_var_mapping[f'{layer_name}.apply_noise.noise'] = (
444
+ f'noise{2 * block_idx + 1}')
445
+
446
+ # Output convolution layer for each resolution.
447
+ self.add_module(f'output{block_idx}',
448
+ ConvBlock(in_channels=self.get_nf(res),
449
+ out_channels=self.image_channels,
450
+ resolution=res,
451
+ w_space_dim=self.w_space_dim,
452
+ position='last',
453
+ kernel_size=1,
454
+ padding=0,
455
+ use_wscale=self.use_wscale,
456
+ wscale_gain=1.0,
457
+ activation_type='linear'))
458
+ self.pth_to_tf_var_mapping[f'output{block_idx}.weight'] = (
459
+ f'ToRGB_lod{self.final_res_log2 - res_log2}/weight')
460
+ self.pth_to_tf_var_mapping[f'output{block_idx}.bias'] = (
461
+ f'ToRGB_lod{self.final_res_log2 - res_log2}/bias')
462
+
463
+ self.upsample = UpsamplingLayer()
464
+ self.final_activate = nn.Tanh() if final_tanh else nn.Identity()
465
+
466
+ def get_nf(self, res):
467
+ """Gets number of feature maps according to current resolution."""
468
+ return min(self.fmaps_base // res, self.fmaps_max)
469
+
470
+ def forward(self, wp, lod=None, randomize_noise=False):
471
+ if wp.ndim != 3 or wp.shape[1:] != (self.num_layers, self.w_space_dim):
472
+ raise ValueError(f'Input tensor should be with shape '
473
+ f'[batch_size, num_layers, w_space_dim], where '
474
+ f'`num_layers` equals to {self.num_layers}, and '
475
+ f'`w_space_dim` equals to {self.w_space_dim}!\n'
476
+ f'But `{wp.shape}` is received!')
477
+
478
+ lod = self.lod.cpu().tolist() if lod is None else lod
479
+ if lod + self.init_res_log2 > self.final_res_log2:
480
+ raise ValueError(f'Maximum level-of-detail (lod) is '
481
+ f'{self.final_res_log2 - self.init_res_log2}, '
482
+ f'but `{lod}` is received!')
483
+
484
+ results = {'wp': wp}
485
+ for res_log2 in range(self.init_res_log2, self.final_res_log2 + 1):
486
+ current_lod = self.final_res_log2 - res_log2
487
+ if lod < current_lod + 1:
488
+ block_idx = res_log2 - self.init_res_log2
489
+ if block_idx == 0:
490
+ if self.const_input:
491
+ x, style = self.layer0(None, wp[:, 0], randomize_noise)
492
+ else:
493
+ x = wp[:, 0].view(-1, self.w_space_dim, 1, 1)
494
+ x, style = self.layer0(x, wp[:, 0], randomize_noise)
495
+ else:
496
+ x, style = self.__getattr__(f'layer{2 * block_idx}')(
497
+ x, wp[:, 2 * block_idx])
498
+ results[f'style{2 * block_idx:02d}'] = style
499
+ x, style = self.__getattr__(f'layer{2 * block_idx + 1}')(
500
+ x, wp[:, 2 * block_idx + 1])
501
+ results[f'style{2 * block_idx + 1:02d}'] = style
502
+ if current_lod - 1 < lod <= current_lod:
503
+ image = self.__getattr__(f'output{block_idx}')(x, None)
504
+ elif current_lod < lod < current_lod + 1:
505
+ alpha = np.ceil(lod) - lod
506
+ image = (self.__getattr__(f'output{block_idx}')(x, None) * alpha
507
+ + self.upsample(image) * (1 - alpha))
508
+ elif lod >= current_lod + 1:
509
+ image = self.upsample(image)
510
+ results['image'] = self.final_activate(image)
511
+ return results
512
+
513
+
514
+ class PixelNormLayer(nn.Module):
515
+ """Implements pixel-wise feature vector normalization layer."""
516
+
517
+ def __init__(self, epsilon=1e-8):
518
+ super().__init__()
519
+ self.eps = epsilon
520
+
521
+ def forward(self, x):
522
+ norm = torch.sqrt(torch.mean(x ** 2, dim=1, keepdim=True) + self.eps)
523
+ return x / norm
524
+
525
+
526
+ class InstanceNormLayer(nn.Module):
527
+ """Implements instance normalization layer."""
528
+
529
+ def __init__(self, epsilon=1e-8):
530
+ super().__init__()
531
+ self.eps = epsilon
532
+
533
+ def forward(self, x):
534
+ if x.ndim != 4:
535
+ raise ValueError(f'The input tensor should be with shape '
536
+ f'[batch_size, channel, height, width], '
537
+ f'but `{x.shape}` is received!')
538
+ x = x - torch.mean(x, dim=[2, 3], keepdim=True)
539
+ norm = torch.sqrt(
540
+ torch.mean(x ** 2, dim=[2, 3], keepdim=True) + self.eps)
541
+ return x / norm
542
+
543
+
544
+ class UpsamplingLayer(nn.Module):
545
+ """Implements the upsampling layer.
546
+
547
+ Basically, this layer can be used to upsample feature maps with nearest
548
+ neighbor interpolation.
549
+ """
550
+
551
+ def __init__(self, scale_factor=2):
552
+ super().__init__()
553
+ self.scale_factor = scale_factor
554
+
555
+ def forward(self, x):
556
+ if self.scale_factor <= 1:
557
+ return x
558
+ return F.interpolate(x, scale_factor=self.scale_factor, mode='nearest')
559
+
560
+
561
+ class Blur(torch.autograd.Function):
562
+ """Defines blur operation with customized gradient computation."""
563
+
564
+ @staticmethod
565
+ def forward(ctx, x, kernel):
566
+ ctx.save_for_backward(kernel)
567
+ y = F.conv2d(input=x,
568
+ weight=kernel,
569
+ bias=None,
570
+ stride=1,
571
+ padding=1,
572
+ groups=x.shape[1])
573
+ return y
574
+
575
+ @staticmethod
576
+ def backward(ctx, dy):
577
+ kernel, = ctx.saved_tensors
578
+ dx = F.conv2d(input=dy,
579
+ weight=kernel.flip((2, 3)),
580
+ bias=None,
581
+ stride=1,
582
+ padding=1,
583
+ groups=dy.shape[1])
584
+ return dx, None, None
585
+
586
+
587
+ class BlurLayer(nn.Module):
588
+ """Implements the blur layer."""
589
+
590
+ def __init__(self,
591
+ channels,
592
+ kernel=(1, 2, 1),
593
+ normalize=True):
594
+ super().__init__()
595
+ kernel = np.array(kernel, dtype=np.float32).reshape(1, -1)
596
+ kernel = kernel.T.dot(kernel)
597
+ if normalize:
598
+ kernel /= np.sum(kernel)
599
+ kernel = kernel[np.newaxis, np.newaxis]
600
+ kernel = np.tile(kernel, [channels, 1, 1, 1])
601
+ self.register_buffer('kernel', torch.from_numpy(kernel))
602
+
603
+ def forward(self, x):
604
+ return Blur.apply(x, self.kernel)
605
+
606
+
607
+ class NoiseApplyingLayer(nn.Module):
608
+ """Implements the noise applying layer."""
609
+
610
+ def __init__(self, resolution, channels):
611
+ super().__init__()
612
+ self.res = resolution
613
+ self.register_buffer('noise', torch.randn(1, 1, self.res, self.res))
614
+ self.weight = nn.Parameter(torch.zeros(channels))
615
+
616
+ def forward(self, x, randomize_noise=False):
617
+ if x.ndim != 4:
618
+ raise ValueError(f'The input tensor should be with shape '
619
+ f'[batch_size, channel, height, width], '
620
+ f'but `{x.shape}` is received!')
621
+ if randomize_noise:
622
+ noise = torch.randn(x.shape[0], 1, self.res, self.res).to(x)
623
+ else:
624
+ noise = self.noise
625
+ return x + noise * self.weight.view(1, -1, 1, 1)
626
+
627
+
628
+ class StyleModLayer(nn.Module):
629
+ """Implements the style modulation layer."""
630
+
631
+ def __init__(self,
632
+ w_space_dim,
633
+ out_channels,
634
+ use_wscale=True):
635
+ super().__init__()
636
+ self.w_space_dim = w_space_dim
637
+ self.out_channels = out_channels
638
+
639
+ weight_shape = (self.out_channels * 2, self.w_space_dim)
640
+ wscale = _STYLEMOD_WSCALE_GAIN / np.sqrt(self.w_space_dim)
641
+ if use_wscale:
642
+ self.weight = nn.Parameter(torch.randn(*weight_shape))
643
+ self.wscale = wscale
644
+ else:
645
+ self.weight = nn.Parameter(torch.randn(*weight_shape) * wscale)
646
+ self.wscale = 1.0
647
+
648
+ self.bias = nn.Parameter(torch.zeros(self.out_channels * 2))
649
+
650
+ def forward(self, x, w):
651
+ if w.ndim != 2 or w.shape[1] != self.w_space_dim:
652
+ raise ValueError(f'The input tensor should be with shape '
653
+ f'[batch_size, w_space_dim], where '
654
+ f'`w_space_dim` equals to {self.w_space_dim}!\n'
655
+ f'But `{w.shape}` is received!')
656
+ style = F.linear(w, weight=self.weight * self.wscale, bias=self.bias)
657
+ style_split = style.view(-1, 2, self.out_channels, 1, 1)
658
+ x = x * (style_split[:, 0] + 1) + style_split[:, 1]
659
+ return x, style
660
+
661
+
662
+ class ConvBlock(nn.Module):
663
+ """Implements the normal convolutional block.
664
+
665
+ Basically, this block executes upsampling layer (if needed), convolutional
666
+ layer, blurring layer, noise applying layer, activation layer, instance
667
+ normalization layer, and style modulation layer in sequence.
668
+ """
669
+
670
+ def __init__(self,
671
+ in_channels,
672
+ out_channels,
673
+ resolution,
674
+ w_space_dim,
675
+ position=None,
676
+ kernel_size=3,
677
+ stride=1,
678
+ padding=1,
679
+ add_bias=True,
680
+ upsample=False,
681
+ fused_scale=False,
682
+ use_wscale=True,
683
+ wscale_gain=_WSCALE_GAIN,
684
+ lr_mul=1.0,
685
+ activation_type='lrelu'):
686
+ """Initializes with block settings.
687
+
688
+ Args:
689
+ in_channels: Number of channels of the input tensor.
690
+ out_channels: Number of channels of the output tensor.
691
+ resolution: Resolution of the output tensor.
692
+ w_space_dim: Dimension of W space for style modulation.
693
+ position: Position of the layer. `const_init`, `last` would lead to
694
+ different behavior. (default: None)
695
+ kernel_size: Size of the convolutional kernels. (default: 3)
696
+ stride: Stride parameter for convolution operation. (default: 1)
697
+ padding: Padding parameter for convolution operation. (default: 1)
698
+ add_bias: Whether to add bias onto the convolutional result.
699
+ (default: True)
700
+ upsample: Whether to upsample the input tensor before convolution.
701
+ (default: False)
702
+ fused_scale: Whether to fuse `upsample` and `conv2d` together,
703
+ resulting in `conv2d_transpose`. (default: False)
704
+ use_wscale: Whether to use weight scaling. (default: True)
705
+ wscale_gain: Gain factor for weight scaling. (default: _WSCALE_GAIN)
706
+ lr_mul: Learning multiplier for both weight and bias. (default: 1.0)
707
+ activation_type: Type of activation. Support `linear` and `lrelu`.
708
+ (default: `lrelu`)
709
+
710
+ Raises:
711
+ NotImplementedError: If the `activation_type` is not supported.
712
+ """
713
+ super().__init__()
714
+
715
+ self.position = position
716
+
717
+ if add_bias:
718
+ self.bias = nn.Parameter(torch.zeros(out_channels))
719
+ self.bscale = lr_mul
720
+ else:
721
+ self.bias = None
722
+
723
+ if activation_type == 'linear':
724
+ self.activate = nn.Identity()
725
+ elif activation_type == 'lrelu':
726
+ self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
727
+ else:
728
+ raise NotImplementedError(f'Not implemented activation function: '
729
+ f'`{activation_type}`!')
730
+
731
+ if self.position != 'last':
732
+ self.apply_noise = NoiseApplyingLayer(resolution, out_channels)
733
+ self.normalize = InstanceNormLayer()
734
+ self.style = StyleModLayer(w_space_dim, out_channels, use_wscale)
735
+
736
+ if self.position == 'const_init':
737
+ self.const = nn.Parameter(
738
+ torch.ones(1, in_channels, resolution, resolution))
739
+ return
740
+
741
+ self.blur = BlurLayer(out_channels) if upsample else nn.Identity()
742
+
743
+ if upsample and not fused_scale:
744
+ self.upsample = UpsamplingLayer()
745
+ else:
746
+ self.upsample = nn.Identity()
747
+
748
+ if upsample and fused_scale:
749
+ self.use_conv2d_transpose = True
750
+ self.stride = 2
751
+ self.padding = 1
752
+ else:
753
+ self.use_conv2d_transpose = False
754
+ self.stride = stride
755
+ self.padding = padding
756
+
757
+ weight_shape = (out_channels, in_channels, kernel_size, kernel_size)
758
+ fan_in = kernel_size * kernel_size * in_channels
759
+ wscale = wscale_gain / np.sqrt(fan_in)
760
+ if use_wscale:
761
+ self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul)
762
+ self.wscale = wscale * lr_mul
763
+ else:
764
+ self.weight = nn.Parameter(
765
+ torch.randn(*weight_shape) * wscale / lr_mul)
766
+ self.wscale = lr_mul
767
+
768
+ def forward(self, x, w, randomize_noise=False):
769
+ if self.position != 'const_init':
770
+ x = self.upsample(x)
771
+ weight = self.weight * self.wscale
772
+ if self.use_conv2d_transpose:
773
+ weight = F.pad(weight, (1, 1, 1, 1, 0, 0, 0, 0), 'constant', 0)
774
+ weight = (weight[:, :, 1:, 1:] + weight[:, :, :-1, 1:] +
775
+ weight[:, :, 1:, :-1] + weight[:, :, :-1, :-1])
776
+ weight = weight.permute(1, 0, 2, 3)
777
+ x = F.conv_transpose2d(x,
778
+ weight=weight,
779
+ bias=None,
780
+ stride=self.stride,
781
+ padding=self.padding)
782
+ else:
783
+ x = F.conv2d(x,
784
+ weight=weight,
785
+ bias=None,
786
+ stride=self.stride,
787
+ padding=self.padding)
788
+ x = self.blur(x)
789
+ else:
790
+ x = self.const.repeat(w.shape[0], 1, 1, 1)
791
+
792
+ bias = self.bias * self.bscale if self.bias is not None else None
793
+
794
+ if self.position == 'last':
795
+ if bias is not None:
796
+ x = x + bias.view(1, -1, 1, 1)
797
+ return x
798
+
799
+ x = self.apply_noise(x, randomize_noise)
800
+ if bias is not None:
801
+ x = x + bias.view(1, -1, 1, 1)
802
+ x = self.activate(x)
803
+ x = self.normalize(x)
804
+ x, style = self.style(x, w)
805
+ return x, style
806
+
807
+
808
+ class DenseBlock(nn.Module):
809
+ """Implements the dense block.
810
+
811
+ Basically, this block executes fully-connected layer and activation layer.
812
+ """
813
+
814
+ def __init__(self,
815
+ in_channels,
816
+ out_channels,
817
+ add_bias=True,
818
+ use_wscale=True,
819
+ wscale_gain=_WSCALE_GAIN,
820
+ lr_mul=1.0,
821
+ activation_type='lrelu'):
822
+ """Initializes with block settings.
823
+
824
+ Args:
825
+ in_channels: Number of channels of the input tensor.
826
+ out_channels: Number of channels of the output tensor.
827
+ add_bias: Whether to add bias onto the fully-connected result.
828
+ (default: True)
829
+ use_wscale: Whether to use weight scaling. (default: True)
830
+ wscale_gain: Gain factor for weight scaling. (default: _WSCALE_GAIN)
831
+ lr_mul: Learning multiplier for both weight and bias. (default: 1.0)
832
+ activation_type: Type of activation. Support `linear` and `lrelu`.
833
+ (default: `lrelu`)
834
+
835
+ Raises:
836
+ NotImplementedError: If the `activation_type` is not supported.
837
+ """
838
+ super().__init__()
839
+ weight_shape = (out_channels, in_channels)
840
+ wscale = wscale_gain / np.sqrt(in_channels)
841
+ if use_wscale:
842
+ self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul)
843
+ self.wscale = wscale * lr_mul
844
+ else:
845
+ self.weight = nn.Parameter(
846
+ torch.randn(*weight_shape) * wscale / lr_mul)
847
+ self.wscale = lr_mul
848
+
849
+ if add_bias:
850
+ self.bias = nn.Parameter(torch.zeros(out_channels))
851
+ self.bscale = lr_mul
852
+ else:
853
+ self.bias = None
854
+
855
+ if activation_type == 'linear':
856
+ self.activate = nn.Identity()
857
+ elif activation_type == 'lrelu':
858
+ self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
859
+ else:
860
+ raise NotImplementedError(f'Not implemented activation function: '
861
+ f'`{activation_type}`!')
862
+
863
+ def forward(self, x):
864
+ if x.ndim != 2:
865
+ x = x.view(x.shape[0], -1)
866
+ bias = self.bias * self.bscale if self.bias is not None else None
867
+ x = F.linear(x, weight=self.weight * self.wscale, bias=bias)
868
+ x = self.activate(x)
869
+ return x
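A quick smoke-test sketch for the generator defined above (not part of the diff; it assumes the `models` package layout of this commit and runs on CPU):

import torch
from models.stylegan_generator import StyleGANGenerator

G = StyleGANGenerator(resolution=256)   # any value from _RESOLUTIONS_ALLOWED
G.eval()
with torch.no_grad():
    z = torch.randn(1, G.z_space_dim)
    out = G(z, trunc_psi=0.7, trunc_layers=8)
print(out['image'].shape)               # expected: torch.Size([1, 3, 256, 256])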
models/sync_op.py ADDED
@@ -0,0 +1,18 @@
1
+ # python3.7
2
+ """Contains the synchronizing operator."""
3
+
4
+ import torch
5
+ import torch.distributed as dist
6
+
7
+ __all__ = ['all_gather']
8
+
9
+
10
+ def all_gather(tensor):
11
+ """Gathers tensor from all devices and does averaging."""
12
+ if not dist.is_initialized():
13
+ return tensor
14
+
15
+ world_size = dist.get_world_size()
16
+ tensor_list = [torch.ones_like(tensor) for _ in range(world_size)]
17
+ dist.all_gather(tensor_list, tensor, async_op=False)
18
+ return torch.mean(torch.stack(tensor_list, dim=0), dim=0)
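Outside an initialized process group, `all_gather` above is a pass-through, which is what lets the generator run unchanged on a single device; under `torch.distributed` it returns the element-wise mean across ranks (used to update `w_avg` during training). A small hedged sketch of the single-process case:

import torch
import torch.distributed as dist
from models.sync_op import all_gather   # path as added in this commit

w = torch.randn(4, 512)                 # a batch of w codes
out = all_gather(w)
assert not dist.is_initialized() and torch.equal(out, w)   # single-process: no-op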
sefa.py ADDED
@@ -0,0 +1,145 @@
1
+ """SeFa: discovers semantics by factorizing pre-trained GAN weights."""
2
+
3
+ import os
4
+ import argparse
5
+ from tqdm import tqdm
6
+ import numpy as np
7
+
8
+ import torch
9
+
10
+ from models import parse_gan_type
11
+ from utils import to_tensor
12
+ from utils import postprocess
13
+ from utils import load_generator
14
+ from utils import factorize_weight
15
+ from utils import HtmlPageVisualizer
16
+
17
+
18
+ def parse_args():
19
+ """Parses arguments."""
20
+ parser = argparse.ArgumentParser(
21
+ description='Discover semantics from the pre-trained weights.')
22
+ parser.add_argument('model_name', type=str,
23
+ help='Name of the pre-trained model.')
24
+ parser.add_argument('--save_dir', type=str, default='results',
25
+ help='Directory to save the visualization pages. '
26
+ '(default: %(default)s)')
27
+ parser.add_argument('-L', '--layer_idx', type=str, default='all',
28
+ help='Indices of layers to interpret. '
29
+ '(default: %(default)s)')
30
+ parser.add_argument('-N', '--num_samples', type=int, default=5,
31
+ help='Number of samples used for visualization. '
32
+ '(default: %(default)s)')
33
+ parser.add_argument('-K', '--num_semantics', type=int, default=5,
34
+ help='Number of semantic boundaries corresponding to '
35
+ 'the top-k eigen values. (default: %(default)s)')
36
+ parser.add_argument('--start_distance', type=float, default=-3.0,
37
+ help='Start point for manipulation on each semantic. '
38
+ '(default: %(default)s)')
39
+ parser.add_argument('--end_distance', type=float, default=3.0,
40
+ help='Ending point for manipulation on each semantic. '
41
+ '(default: %(default)s)')
42
+ parser.add_argument('--step', type=int, default=11,
43
+ help='Number of manipulation steps on each semantic. '
44
+ '(default: %(default)s)')
45
+ parser.add_argument('--viz_size', type=int, default=256,
46
+ help='Size of images to visualize on the HTML page. '
47
+ '(default: %(default)s)')
48
+ parser.add_argument('--trunc_psi', type=float, default=0.7,
49
+ help='Psi factor used for truncation. This is '
50
+ 'particularly applicable to StyleGAN (v1/v2). '
51
+ '(default: %(default)s)')
52
+ parser.add_argument('--trunc_layers', type=int, default=8,
53
+ help='Number of layers to perform truncation. This is '
54
+ 'particularly applicable to StyleGAN (v1/v2). '
55
+ '(default: %(default)s)')
56
+ parser.add_argument('--seed', type=int, default=0,
57
+ help='Seed for sampling. (default: %(default)s)')
58
+ parser.add_argument('--gpu_id', type=str, default='0',
59
+ help='GPU(s) to use. (default: %(default)s)')
60
+ return parser.parse_args()
61
+
62
+
63
+ def main():
64
+ """Main function."""
65
+ args = parse_args()
66
+ os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id
67
+ os.makedirs(args.save_dir, exist_ok=True)
68
+
69
+ # Factorize weights.
70
+ generator = load_generator(args.model_name)
71
+ gan_type = parse_gan_type(generator)
72
+ layers, boundaries, values = factorize_weight(generator, args.layer_idx)
73
+
74
+ # Set random seed.
75
+ np.random.seed(args.seed)
76
+ torch.manual_seed(args.seed)
77
+
78
+ # Prepare codes.
79
+ codes = torch.randn(args.num_samples, generator.z_space_dim).cuda()
80
+ if gan_type == 'pggan':
81
+ codes = generator.layer0.pixel_norm(codes)
82
+ elif gan_type in ['stylegan', 'stylegan2']:
83
+ codes = generator.mapping(codes)['w']
84
+ codes = generator.truncation(codes,
85
+ trunc_psi=args.trunc_psi,
86
+ trunc_layers=args.trunc_layers)
87
+ codes = codes.detach().cpu().numpy()
88
+
89
+ # Generate visualization pages.
90
+ distances = np.linspace(args.start_distance, args.end_distance, args.step)
91
+ num_sam = args.num_samples
92
+ num_sem = args.num_semantics
93
+ vizer_1 = HtmlPageVisualizer(num_rows=num_sem * (num_sam + 1),
94
+ num_cols=args.step + 1,
95
+ viz_size=args.viz_size)
96
+ vizer_2 = HtmlPageVisualizer(num_rows=num_sam * (num_sem + 1),
97
+ num_cols=args.step + 1,
98
+ viz_size=args.viz_size)
99
+
100
+ headers = [''] + [f'Distance {d:.2f}' for d in distances]
101
+ vizer_1.set_headers(headers)
102
+ vizer_2.set_headers(headers)
103
+ for sem_id in range(num_sem):
104
+ value = values[sem_id]
105
+ vizer_1.set_cell(sem_id * (num_sam + 1), 0,
106
+ text=f'Semantic {sem_id:03d}<br>({value:.3f})',
107
+ highlight=True)
108
+ for sam_id in range(num_sam):
109
+ vizer_1.set_cell(sem_id * (num_sam + 1) + sam_id + 1, 0,
110
+ text=f'Sample {sam_id:03d}')
111
+ for sam_id in range(num_sam):
112
+ vizer_2.set_cell(sam_id * (num_sem + 1), 0,
113
+ text=f'Sample {sam_id:03d}',
114
+ highlight=True)
115
+ for sem_id in range(num_sem):
116
+ value = values[sem_id]
117
+ vizer_2.set_cell(sam_id * (num_sem + 1) + sem_id + 1, 0,
118
+ text=f'Semantic {sem_id:03d}<br>({value:.3f})')
119
+
120
+ for sam_id in tqdm(range(num_sam), desc='Sample ', leave=False):
121
+ code = codes[sam_id:sam_id + 1]
122
+ for sem_id in tqdm(range(num_sem), desc='Semantic ', leave=False):
123
+ boundary = boundaries[sem_id:sem_id + 1]
124
+ for col_id, d in enumerate(distances, start=1):
125
+ temp_code = code.copy()
126
+ if gan_type == 'pggan':
127
+ temp_code += boundary * d
128
+ image = generator(to_tensor(temp_code))['image']
129
+ elif gan_type in ['stylegan', 'stylegan2']:
130
+ temp_code[:, layers, :] += boundary * d
131
+ image = generator.synthesis(to_tensor(temp_code))['image']
132
+ image = postprocess(image)[0]
133
+ vizer_1.set_cell(sem_id * (num_sam + 1) + sam_id + 1, col_id,
134
+ image=image)
135
+ vizer_2.set_cell(sam_id * (num_sem + 1) + sem_id + 1, col_id,
136
+ image=image)
137
+
138
+ prefix = (f'{args.model_name}_'
139
+ f'N{num_sam}_K{num_sem}_L{args.layer_idx}_seed{args.seed}')
140
+ vizer_1.save(os.path.join(args.save_dir, f'{prefix}_sample_first.html'))
141
+ vizer_2.save(os.path.join(args.save_dir, f'{prefix}_semantic_first.html'))
142
+
143
+
144
+ if __name__ == '__main__':
145
+ main()
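The factorization itself is delegated to `factorize_weight()` in utils.py; the following is a minimal self-contained sketch of that step on a random stand-in matrix (the real code concatenates the style-modulation weights of the selected layers):

import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((512, 1024)).astype(np.float32)   # stand-in for concatenated weights
A /= np.linalg.norm(A, axis=0, keepdims=True)             # unit-norm columns, as in factorize_weight()
values, vectors = np.linalg.eig(A.dot(A.T))
boundaries = vectors.T                                     # one candidate semantic direction per row
# Moving a w code along boundaries[k] * distance edits the k-th discovered semantic.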
utils.py ADDED
@@ -0,0 +1,509 @@
1
+ """Utility functions."""
2
+
3
+ import base64
4
+ import os
5
+ import subprocess
6
+ import cv2
7
+ import numpy as np
8
+
9
+ import torch
10
+
11
+ from models import MODEL_ZOO
12
+ from models import build_generator
13
+ from models import parse_gan_type
14
+
15
+ __all__ = ['postprocess', 'load_generator', 'factorize_weight',
16
+ 'HtmlPageVisualizer']
17
+
18
+ CHECKPOINT_DIR = 'checkpoints'
19
+
20
+
21
+ def to_tensor(array):
22
+ """Converts a `numpy.ndarray` to `torch.Tensor`.
23
+
24
+ Args:
25
+ array: The input array to convert.
26
+
27
+ Returns:
28
+ A `torch.Tensor` with dtype `torch.FloatTensor` on cuda device.
29
+ """
30
+ assert isinstance(array, np.ndarray)
31
+ return torch.from_numpy(array).type(torch.FloatTensor).cuda()
32
+
33
+
34
+ def postprocess(images, min_val=-1.0, max_val=1.0):
35
+ """Post-processes images from `torch.Tensor` to `numpy.ndarray`.
36
+
37
+ Args:
38
+ images: A `torch.Tensor` with shape `NCHW` to process.
39
+ min_val: The minimum value of the input tensor. (default: -1.0)
40
+ max_val: The maximum value of the input tensor. (default: 1.0)
41
+
42
+ Returns:
43
+ A `numpy.ndarray` with shape `NHWC` and pixel range [0, 255].
44
+ """
45
+ assert isinstance(images, torch.Tensor)
46
+ images = images.detach().cpu().numpy()
47
+ images = (images - min_val) * 255 / (max_val - min_val)
48
+ images = np.clip(images + 0.5, 0, 255).astype(np.uint8)
49
+ images = images.transpose(0, 2, 3, 1)
50
+ return images
51
+
52
+
53
+ def load_generator(model_name):
54
+ """Loads pre-trained generator.
55
+
56
+ Args:
57
+ model_name: Name of the model. Should be a key in `models.MODEL_ZOO`.
58
+
59
+ Returns:
60
+ A generator, which is a `torch.nn.Module`, with pre-trained weights
61
+ loaded.
62
+
63
+ Raises:
64
+ KeyError: If the input `model_name` is not in `models.MODEL_ZOO`.
65
+ """
66
+ if model_name not in MODEL_ZOO:
67
+ raise KeyError(f'Unknown model name `{model_name}`!')
68
+
69
+ model_config = MODEL_ZOO[model_name].copy()
70
+ url = model_config.pop('url') # URL to download model if needed.
71
+
72
+ # Build generator.
73
+ print(f'Building generator for model `{model_name}` ...')
74
+ generator = build_generator(**model_config)
75
+ print(f'Finish building generator.')
76
+
77
+ # Load pre-trained weights.
78
+ os.makedirs(CHECKPOINT_DIR, exist_ok=True)
79
+ checkpoint_path = os.path.join(CHECKPOINT_DIR, model_name + '.pth')
80
+ print(f'Loading checkpoint from `{checkpoint_path}` ...')
81
+ if not os.path.exists(checkpoint_path):
82
+ print(f' Downloading checkpoint from `{url}` ...')
83
+ subprocess.call(['wget', '--quiet', '-O', checkpoint_path, url])
84
+ print(f' Finish downloading checkpoint.')
85
+ checkpoint = torch.load(checkpoint_path, map_location='cpu')
86
+ if 'generator_smooth' in checkpoint:
87
+ generator.load_state_dict(checkpoint['generator_smooth'])
88
+ else:
89
+ generator.load_state_dict(checkpoint['generator'])
90
+ generator = generator.cuda()
91
+ generator.eval()
92
+ print(f'Finish loading checkpoint.')
93
+ return generator
94
+
95
+
96
+ def parse_indices(obj, min_val=None, max_val=None):
97
+ """Parses indices.
98
+
99
+ The input can be a list, a tuple, or a string, which is either a
100
+ comma-separated list of numbers 'a, b, c' or a dash-separated range 'a - c'.
101
+ Space in the string will be ignored.
102
+
103
+ Args:
104
+ obj: The input object to parse indices from.
105
+ min_val: If not `None`, this function will check that all indices are
106
+ equal to or larger than this value. (default: None)
107
+ max_val: If not `None`, this function will check that all indices are
108
+ equal to or smaller than this value. (default: None)
109
+
110
+ Returns:
111
+ A list of integers.
112
+
113
+ Raises:
114
+ ValueError: If the input is invalid, i.e., neither a list, a tuple, nor a string.
115
+ """
116
+ if obj is None or obj == '':
117
+ indices = []
118
+ elif isinstance(obj, int):
119
+ indices = [obj]
120
+ elif isinstance(obj, (list, tuple, np.ndarray)):
121
+ indices = list(obj)
122
+ elif isinstance(obj, str):
123
+ indices = []
124
+ splits = obj.replace(' ', '').split(',')
125
+ for split in splits:
126
+ numbers = list(map(int, split.split('-')))
127
+ if len(numbers) == 1:
128
+ indices.append(numbers[0])
129
+ elif len(numbers) == 2:
130
+ indices.extend(list(range(numbers[0], numbers[1] + 1)))
131
+ else:
132
+ raise ValueError(f'Unable to parse the input!')
133
+
134
+ else:
135
+ raise ValueError(f'Invalid type of input: `{type(obj)}`!')
136
+
137
+ assert isinstance(indices, list)
138
+ indices = sorted(list(set(indices)))
139
+ for idx in indices:
140
+ assert isinstance(idx, int)
141
+ if min_val is not None:
142
+ assert idx >= min_val, f'{idx} is smaller than min val `{min_val}`!'
143
+ if max_val is not None:
144
+ assert idx <= max_val, f'{idx} is larger than max val `{max_val}`!'
145
+
146
+ return indices
147
+
148
+
149
+ def factorize_weight(generator, layer_idx='all'):
150
+ """Factorizes the generator weight to get semantics boundaries.
151
+
152
+ Args:
153
+ generator: Generator to factorize.
154
+ layer_idx: Indices of layers to interpret, especially for StyleGAN and
155
+ StyleGAN2. (default: `all`)
156
+
157
+ Returns:
158
+ A tuple of (layers_to_interpret, semantic_boundaries, eigen_values).
159
+
160
+ Raises:
161
+ ValueError: If the generator type is not supported.
162
+ """
163
+ # Get GAN type.
164
+ gan_type = parse_gan_type(generator)
165
+
166
+ # Get layers.
167
+ if gan_type == 'pggan':
168
+ layers = [0]
169
+ elif gan_type in ['stylegan', 'stylegan2']:
170
+ if layer_idx == 'all':
171
+ layers = list(range(generator.num_layers))
172
+ else:
173
+ layers = parse_indices(layer_idx,
174
+ min_val=0,
175
+ max_val=generator.num_layers - 1)
176
+
177
+ # Factorize semantics from weight.
178
+ weights = []
179
+ for idx in layers:
180
+ layer_name = f'layer{idx}'
181
+ if gan_type == 'stylegan2' and idx == generator.num_layers - 1:
182
+ layer_name = f'output{idx // 2}'
183
+ if gan_type == 'pggan':
184
+ weight = generator.__getattr__(layer_name).weight
185
+ weight = weight.flip(2, 3).permute(1, 0, 2, 3).flatten(1)
186
+ elif gan_type in ['stylegan', 'stylegan2']:
187
+ weight = generator.synthesis.__getattr__(layer_name).style.weight.T
188
+ weights.append(weight.cpu().detach().numpy())
189
+ weight = np.concatenate(weights, axis=1).astype(np.float32)
190
+ weight = weight / np.linalg.norm(weight, axis=0, keepdims=True)
191
+ eigen_values, eigen_vectors = np.linalg.eig(weight.dot(weight.T))
192
+
193
+ return layers, eigen_vectors.T, eigen_values
194
+
195
+
196
+ def get_sortable_html_header(column_name_list, sort_by_ascending=False):
197
+ """Gets header for sortable html page.
198
+
199
+ Basically, the html page contains a sortable table, where user can sort the
200
+ rows by a particular column by clicking the column head.
201
+
202
+ Example:
203
+
204
+ column_name_list = [name_1, name_2, name_3]
205
+ header = get_sortable_html_header(column_name_list)
206
+ footer = get_sortable_html_footer()
207
+ sortable_table = ...
208
+ html_page = header + sortable_table + footer
209
+
210
+ Args:
211
+ column_name_list: List of column header names.
212
+ sort_by_ascending: Default sorting order. If set as `True`, the html
213
+ page will be sorted by ascending order when the header is clicked
214
+ for the first time.
215
+
216
+ Returns:
217
+ A string, which represents for the header for a sortable html page.
218
+ """
219
+ header = '\n'.join([
220
+ '<script type="text/javascript">',
221
+ 'var column_idx;',
222
+ 'var sort_by_ascending = ' + str(sort_by_ascending).lower() + ';',
223
+ '',
224
+ 'function sorting(tbody, column_idx){',
225
+ ' this.column_idx = column_idx;',
226
+ ' Array.from(tbody.rows)',
227
+ ' .sort(compareCells)',
228
+ ' .forEach(function(row) { tbody.appendChild(row); })',
229
+ ' sort_by_ascending = !sort_by_ascending;',
230
+ '}',
231
+ '',
232
+ 'function compareCells(row_a, row_b) {',
233
+ ' var val_a = row_a.cells[column_idx].innerText;',
234
+ ' var val_b = row_b.cells[column_idx].innerText;',
235
+ ' var flag = sort_by_ascending ? 1 : -1;',
236
+ ' return flag * (val_a > val_b ? 1 : -1);',
237
+ '}',
238
+ '</script>',
239
+ '',
240
+ '<html>',
241
+ '',
242
+ '<head>',
243
+ '<style>',
244
+ ' table {',
245
+ ' border-spacing: 0;',
246
+ ' border: 1px solid black;',
247
+ ' }',
248
+ ' th {',
249
+ ' cursor: pointer;',
250
+ ' }',
251
+ ' th, td {',
252
+ ' text-align: left;',
253
+ ' vertical-align: middle;',
254
+ ' border-collapse: collapse;',
255
+ ' border: 0.5px solid black;',
256
+ ' padding: 8px;',
257
+ ' }',
258
+ ' tr:nth-child(even) {',
259
+ ' background-color: #d2d2d2;',
260
+ ' }',
261
+ '</style>',
262
+ '</head>',
263
+ '',
264
+ '<body>',
265
+ '',
266
+ '<table>',
267
+ '<thead>',
268
+ '<tr>',
269
+ ''])
270
+ for idx, name in enumerate(column_name_list):
271
+ header += f' <th onclick="sorting(tbody, {idx})">{name}</th>\n'
272
+ header += '</tr>\n'
273
+ header += '</thead>\n'
274
+ header += '<tbody id="tbody">\n'
275
+
276
+ return header
277
+
278
+
279
+ def get_sortable_html_footer():
280
+ """Gets footer for sortable html page.
281
+
282
+ Check function `get_sortable_html_header()` for more details.
283
+ """
284
+ return '</tbody>\n</table>\n\n</body>\n</html>\n'
285
+
286
+
287
+ def parse_image_size(obj):
288
+ """Parses object to a pair of image size, i.e., (width, height).
289
+
290
+ Args:
291
+ obj: The input object to parse image size from.
292
+
293
+ Returns:
294
+ A two-element tuple, indicating image width and height respectively.
295
+
296
+ Raises:
297
+ ValueError: If the input is invalid, i.e., neither a list, a tuple, nor a string.
298
+ """
299
+ if obj is None or obj == '':
300
+ width = height = 0
301
+ elif isinstance(obj, int):
302
+ width = height = obj
303
+ elif isinstance(obj, (list, tuple, np.ndarray)):
304
+ numbers = tuple(obj)
305
+ if len(numbers) == 0:
306
+ width = height = 0
307
+ elif len(numbers) == 1:
308
+ width = height = numbers[0]
309
+ elif len(numbers) == 2:
310
+ width = numbers[0]
311
+ height = numbers[1]
312
+ else:
313
+ raise ValueError(f'At most two elements for image size.')
314
+ elif isinstance(obj, str):
315
+ splits = obj.replace(' ', '').split(',')
316
+ numbers = tuple(map(int, splits))
317
+ if len(numbers) == 0:
318
+ width = height = 0
319
+ elif len(numbers) == 1:
320
+ width = height = numbers[0]
321
+ elif len(numbers) == 2:
322
+ width = numbers[0]
323
+ height = numbers[1]
324
+ else:
325
+ raise ValueError(f'At most two elements for image size.')
326
+ else:
327
+ raise ValueError(f'Invalid type of input: {type(obj)}!')
328
+
329
+ return (max(0, width), max(0, height))
330
+
331
+
332
+ def encode_image_to_html_str(image, image_size=None):
333
+ """Encodes an image to html language.
334
+ NOTE: Input image is always assumed to be with `RGB` channel order.
335
+ Args:
336
+ image: The input image to encode. Should be with `RGB` channel order.
337
+ image_size: This field is used to resize the image before encoding. `0`
338
+ disables resizing. (default: None)
339
+ Returns:
340
+ A string which represents the encoded image.
341
+ """
342
+ if image is None:
343
+ return ''
344
+
345
+ assert image.ndim == 3 and image.shape[2] in [1, 3]
346
+
347
+ # Change channel order to `BGR`, which is opencv-friendly.
348
+ image = image[:, :, ::-1]
349
+
350
+ # Resize the image if needed.
351
+ width, height = parse_image_size(image_size)
352
+ if height or width:
353
+ height = height or image.shape[0]
354
+ width = width or image.shape[1]
355
+ image = cv2.resize(image, (width, height))
356
+
357
+ # Encode the image to html-format string.
358
+ encoded_image = cv2.imencode('.jpg', image)[1].tobytes()
359
+ encoded_image_base64 = base64.b64encode(encoded_image).decode('utf-8')
360
+ html_str = f'<img src="data:image/jpeg;base64, {encoded_image_base64}"/>'
361
+
362
+ return html_str
363
+
364
+
365
+ def get_grid_shape(size, row=0, col=0, is_portrait=False):
366
+ """Gets the shape of a grid based on the size.
367
+
368
+ This function makes greatest effort on making the output grid square if
369
+ neither `row` nor `col` is set. If `is_portrait` is set as `False`, the
370
+ height will always be equal to or smaller than the width. For example, if
371
+ input `size = 16`, output shape will be `(4, 4)`; if input `size = 15`,
372
+ output shape will be (3, 5). Otherwise, the height will always be equal to
373
+ or larger than the width.
374
+
375
+ Args:
376
+ size: Size (height * width) of the target grid.
377
+ is_portrait: Whether to return a portrait size or a landscape size.
378
+ (default: False)
379
+
380
+ Returns:
381
+ A two-element tuple, representing height and width respectively.
382
+ """
383
+ assert isinstance(size, int)
384
+ assert isinstance(row, int)
385
+ assert isinstance(col, int)
386
+ if size == 0:
387
+ return (0, 0)
388
+
389
+ if row > 0 and col > 0 and row * col != size:
390
+ row = 0
391
+ col = 0
392
+
393
+ if row > 0 and size % row == 0:
394
+ return (row, size // row)
395
+ if col > 0 and size % col == 0:
396
+ return (size // col, col)
397
+
398
+ row = int(np.sqrt(size))
399
+ while row > 0:
400
+ if size % row == 0:
401
+ col = size // row
402
+ break
403
+ row = row - 1
404
+
405
+ return (col, row) if is_portrait else (row, col)
406
+
407
+
408
+ class HtmlPageVisualizer(object):
409
+ """Defines the html page visualizer.
410
+
411
+ This class can be used to visualize image results as html page. Basically,
412
+ it is based on an html-format sorted table with helper functions
413
+ `get_sortable_html_header()`, `get_sortable_html_footer()`, and
414
+ `encode_image_to_html_str()`. To simplify the usage, specifying the
415
+ following fields are enough to create a visualization page:
416
+
417
+ (1) num_rows: Number of rows of the table (header-row exclusive).
418
+ (2) num_cols: Number of columns of the table.
419
+ (3) header contents (optional): Title of each column.
420
+
421
+ NOTE: `grid_size` can be used to assign `num_rows` and `num_cols`
422
+ automatically.
423
+
424
+ Example:
425
+
426
+ html = HtmlPageVisualizer(num_rows, num_cols)
427
+ html.set_headers([...])
428
+ for i in range(num_rows):
429
+ for j in range(num_cols):
430
+ html.set_cell(i, j, text=..., image=..., highlight=False)
431
+ html.save('visualize.html')
432
+ """
433
+
434
+ def __init__(self,
435
+ num_rows=0,
436
+ num_cols=0,
437
+ grid_size=0,
438
+ is_portrait=True,
439
+ viz_size=None):
440
+ if grid_size > 0:
441
+ num_rows, num_cols = get_grid_shape(
442
+ grid_size, row=num_rows, col=num_cols, is_portrait=is_portrait)
443
+ assert num_rows > 0 and num_cols > 0
444
+
445
+ self.num_rows = num_rows
446
+ self.num_cols = num_cols
447
+ self.viz_size = parse_image_size(viz_size)
448
+ self.headers = ['' for _ in range(self.num_cols)]
449
+ self.cells = [[{
450
+ 'text': '',
451
+ 'image': '',
452
+ 'highlight': False,
453
+ } for _ in range(self.num_cols)] for _ in range(self.num_rows)]
454
+
455
+ def set_header(self, col_idx, content):
456
+ """Sets the content of a particular header by column index."""
457
+ self.headers[col_idx] = content
458
+
459
+ def set_headers(self, contents):
460
+ """Sets the contents of all headers."""
461
+ if isinstance(contents, str):
462
+ contents = [contents]
463
+ assert isinstance(contents, (list, tuple))
464
+ assert len(contents) == self.num_cols
465
+ for col_idx, content in enumerate(contents):
466
+ self.set_header(col_idx, content)
467
+
468
+ def set_cell(self, row_idx, col_idx, text='', image=None, highlight=False):
469
+ """Sets the content of a particular cell.
470
+
471
+ Basically, a cell contains some text as well as an image. Both text and
472
+ image can be empty.
473
+
474
+ Args:
475
+ row_idx: Row index of the cell to edit.
476
+ col_idx: Column index of the cell to edit.
477
+ text: Text to add into the target cell. (default: None)
478
+ image: Image to show in the target cell. Should be with `RGB`
479
+ channel order. (default: None)
480
+ highlight: Whether to highlight this cell. (default: False)
481
+ """
482
+ self.cells[row_idx][col_idx]['text'] = text
483
+ self.cells[row_idx][col_idx]['image'] = encode_image_to_html_str(
484
+ image, self.viz_size)
485
+ self.cells[row_idx][col_idx]['highlight'] = bool(highlight)
486
+
487
+ def save(self, save_path):
488
+ """Saves the html page."""
489
+ html = ''
490
+ for i in range(self.num_rows):
491
+ html += f'<tr>\n'
492
+ for j in range(self.num_cols):
493
+ text = self.cells[i][j]['text']
494
+ image = self.cells[i][j]['image']
495
+ if self.cells[i][j]['highlight']:
496
+ color = ' bgcolor="#FF8888"'
497
+ else:
498
+ color = ''
499
+ if text:
500
+ html += f' <td{color}>{text}<br><br>{image}</td>\n'
501
+ else:
502
+ html += f' <td{color}>{image}</td>\n'
503
+ html += f'</tr>\n'
504
+
505
+ header = get_sortable_html_header(self.headers)
506
+ footer = get_sortable_html_footer()
507
+
508
+ with open(save_path, 'w') as f:
509
+ f.write(header + html + footer)
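A hedged end-to-end sketch tying the helpers above together (random tensors stand in for generator output; opencv-python is required, as utils.py already assumes):

import torch
from utils import postprocess, HtmlPageVisualizer

fake = torch.rand(2, 3, 64, 64) * 2 - 1    # stand-in for images in [-1, 1], NCHW
imgs = postprocess(fake)                    # -> (2, 64, 64, 3) uint8, RGB
page = HtmlPageVisualizer(num_rows=1, num_cols=2, viz_size=64)
page.set_headers(['sample 0', 'sample 1'])
for j in range(2):
    page.set_cell(0, j, text=f'image {j}', image=imgs[j])
page.save('preview.html')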