Your Name commited on
Commit
2f85de4
1 Parent(s): cd0130f
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. Dockerfile +30 -0
  2. app.py +267 -0
  3. models/__init__.py +63 -0
  4. models/__pycache__/__init__.cpython-37.pyc +0 -0
  5. models/__pycache__/__init__.cpython-39.pyc +0 -0
  6. models/__pycache__/bev3d_generator.cpython-37.pyc +0 -0
  7. models/__pycache__/bev3d_generator.cpython-39.pyc +0 -0
  8. models/__pycache__/eg3d_discriminator.cpython-37.pyc +0 -0
  9. models/__pycache__/eg3d_discriminator.cpython-39.pyc +0 -0
  10. models/__pycache__/eg3d_generator.cpython-37.pyc +0 -0
  11. models/__pycache__/eg3d_generator.cpython-39.pyc +0 -0
  12. models/__pycache__/eg3d_generator_fv.cpython-37.pyc +0 -0
  13. models/__pycache__/eg3d_generator_fv.cpython-39.pyc +0 -0
  14. models/__pycache__/ghfeat_encoder.cpython-37.pyc +0 -0
  15. models/__pycache__/ghfeat_encoder.cpython-39.pyc +0 -0
  16. models/__pycache__/inception_model.cpython-37.pyc +0 -0
  17. models/__pycache__/inception_model.cpython-39.pyc +0 -0
  18. models/__pycache__/perceptual_model.cpython-37.pyc +0 -0
  19. models/__pycache__/perceptual_model.cpython-39.pyc +0 -0
  20. models/__pycache__/pggan_discriminator.cpython-37.pyc +0 -0
  21. models/__pycache__/pggan_discriminator.cpython-39.pyc +0 -0
  22. models/__pycache__/pggan_generator.cpython-37.pyc +0 -0
  23. models/__pycache__/pggan_generator.cpython-39.pyc +0 -0
  24. models/__pycache__/pigan_discriminator.cpython-37.pyc +0 -0
  25. models/__pycache__/pigan_discriminator.cpython-39.pyc +0 -0
  26. models/__pycache__/pigan_generator.cpython-37.pyc +0 -0
  27. models/__pycache__/pigan_generator.cpython-39.pyc +0 -0
  28. models/__pycache__/sgbev3d_generator.cpython-37.pyc +0 -0
  29. models/__pycache__/sgbev3d_generator.cpython-39.pyc +0 -0
  30. models/__pycache__/stylegan2_discriminator.cpython-37.pyc +0 -0
  31. models/__pycache__/stylegan2_discriminator.cpython-39.pyc +0 -0
  32. models/__pycache__/stylegan2_generator.cpython-37.pyc +0 -0
  33. models/__pycache__/stylegan2_generator.cpython-39.pyc +0 -0
  34. models/__pycache__/stylegan3_generator.cpython-37.pyc +0 -0
  35. models/__pycache__/stylegan3_generator.cpython-39.pyc +0 -0
  36. models/__pycache__/stylegan_discriminator.cpython-37.pyc +0 -0
  37. models/__pycache__/stylegan_discriminator.cpython-39.pyc +0 -0
  38. models/__pycache__/stylegan_generator.cpython-37.pyc +0 -0
  39. models/__pycache__/stylegan_generator.cpython-39.pyc +0 -0
  40. models/__pycache__/volumegan_discriminator.cpython-37.pyc +0 -0
  41. models/__pycache__/volumegan_discriminator.cpython-39.pyc +0 -0
  42. models/__pycache__/volumegan_generator.cpython-37.pyc +0 -0
  43. models/__pycache__/volumegan_generator.cpython-39.pyc +0 -0
  44. models/bev3d_generator.py +301 -0
  45. models/eg3d_discriminator.py +243 -0
  46. models/eg3d_generator.py +315 -0
  47. models/eg3d_generator_fv.py +320 -0
  48. models/ghfeat_encoder.py +563 -0
  49. models/inception_model.py +562 -0
  50. models/perceptual_model.py +519 -0
Dockerfile ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM nvidia/cuda:11.1.0-devel-ubuntu22.04
2
+
3
+ ENV CUDA_HOME=/usr/local/cuda
4
+ ENV PATH=${CUDA_HOME}/bin:/home/${USER_NAME}/.local/bin:${PATH}
5
+ ENV LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH}
6
+ ENV LIBRARY_PATH=${CUDA_HOME}/lib64/stubs:${LIBRARY_PATH}
7
+
8
+ # apt install by root user
9
+ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
10
+ build-essential \
11
+ curl \
12
+ git \
13
+ python-is-python3 \
14
+ python3.7-dev \
15
+ python3-pip \
16
+ wget \
17
+ && rm -rf /var/lib/apt/lists/*
18
+
19
+ RUN pip install torch==1.9.1+cu111 torchvision==0.10.1+cu111 torchaudio==0.9.1 -f https://download.pytorch.org/whl/torch_stable.html
20
+
21
+
22
+ WORKDIR /code
23
+
24
+ COPY ./requirements.txt /code/requirements.txt
25
+
26
+ RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
27
+
28
+ COPY . .
29
+
30
+ CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860"]
app.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from models import build_model
3
+ from PIL import Image
4
+ import numpy as np
5
+ import torchvision
6
+ import ninja
7
+ import torch
8
+ from tqdm import trange
9
+ import imageio
10
+
11
+ checkpoint = '/mnt/petrelfs/zhangqihang/data/berfscene_clevr.pth'
12
+ state = torch.load(checkpoint, map_location='cpu')
13
+ G = build_model(**state['model_kwargs_init']['generator_smooth'])
14
+ o0, o1 = G.load_state_dict(state['models']['generator_smooth'], strict=False)
15
+ G.eval().cuda()
16
+ G.backbone.synthesis.input.x_offset =0
17
+ G.backbone.synthesis.input.y_offset =0
18
+ G_kwargs= dict(noise_mode='const',
19
+ fused_modulate=False,
20
+ impl='cuda',
21
+ fp16_res=None)
22
+
23
+ def trans(x, y, z, length):
24
+ w = h = length
25
+ x = 0.5 * w - 128 + 256 - (x/9 + .5) * 256
26
+ y = 0.5 * h - 128 + (y/9 + .5) * 256
27
+ z = z / 9 * 256
28
+ return x, y, z
29
+ def get_bev_from_objs(objs, length=256, scale = 6):
30
+ h, w = length, length *scale
31
+ nc = 14
32
+ canvas = np.zeros([h, w, nc])
33
+ xx = np.ones([h,w]).cumsum(0)
34
+ yy = np.ones([h,w]).cumsum(1)
35
+
36
+ for x, y, z, shape, color, material, rot in objs:
37
+ y, x, z = trans(x, y, z, length)
38
+
39
+ feat = [0] * nc
40
+ feat[0] = 1
41
+ feat[COLOR_NAME_LIST.index(color) + 1] = 1
42
+ feat[SHAPE_NAME_LIST.index(shape) + 1 + len(COLOR_NAME_LIST)] = 1
43
+ feat[MATERIAL_NAME_LIST.index(material) + 1 + len(COLOR_NAME_LIST) + len(SHAPE_NAME_LIST)] = 1
44
+ feat = np.array(feat)
45
+ rot_sin = np.sin(rot / 180 * np.pi)
46
+ rot_cos = np.cos(rot / 180 * np.pi)
47
+
48
+ if shape == 'cube':
49
+ mask = (np.abs(+rot_cos * (xx-x) + rot_sin * (yy-y)) <= z) * \
50
+ (np.abs(-rot_sin * (xx-x) + rot_cos * (yy-y)) <= z)
51
+ else:
52
+ mask = ((xx-x)**2 + (y-yy)**2) ** 0.5 <= z
53
+ canvas[mask] = feat
54
+ canvas = np.transpose(canvas, [2, 0, 1]).astype(np.float32)
55
+ rotate_angle = 0
56
+ canvas = torchvision.transforms.functional.rotate(torch.tensor(canvas), rotate_angle).numpy()
57
+ return canvas
58
+
59
+ # COLOR_NAME_LIST = ['cyan', 'green', 'purple', 'red', 'yellow', 'gray', 'brown', 'blue']
60
+ COLOR_NAME_LIST = ['cyan', 'green', 'purple', 'red', 'yellow', 'gray', 'purple', 'blue']
61
+ SHAPE_NAME_LIST = ['cube', 'sphere', 'cylinder']
62
+ MATERIAL_NAME_LIST = ['rubber', 'metal']
63
+
64
+ xy_lib = dict()
65
+ xy_lib['B'] = [
66
+ [-2, -1],
67
+ [-1, -1],
68
+ [-2, 0],
69
+ [-2, 1],
70
+ [-1, .5],
71
+ [0, 1],
72
+ [0, 0],
73
+ [0, -1],
74
+ [0, 2],
75
+ [-1, 2],
76
+ [-2, 2]
77
+ ]
78
+ xy_lib['B'] = [
79
+ [-2.5, 1.25],
80
+ [-2, 2],
81
+ [-2, 0.5],
82
+ [-2, -0.75],
83
+ [-1, -1],
84
+ [-1, 2],
85
+ [-1, 0],
86
+ [-1, 2],
87
+ [0, 1],
88
+ [0, 0],
89
+ [0, -1],
90
+ [0, 2],
91
+ # [-1, 2],
92
+
93
+ ]
94
+ xy_lib['B'] = [
95
+ [-2.5, 1.25],
96
+ [-2, 2],
97
+ [-2, 0.5],
98
+ [-2, -1],
99
+ [-1, -1.25],
100
+ [-1, 2],
101
+ [-1, 0],
102
+ [-1, 2],
103
+ [0, 1],
104
+ [0, 0],
105
+ [0, -1.25],
106
+ [0, 2],
107
+ # [-1, 2],
108
+
109
+ ]
110
+ xy_lib['R'] = [
111
+ [0, -1],
112
+ [0, 0],
113
+ [0, 1],
114
+ [0, 2],
115
+ [-1, -1],
116
+ # [-1, 2],
117
+ [-2, -1],
118
+ [-2, 0],
119
+ [-2.25, 2],
120
+ [-1, 1]
121
+ ]
122
+ xy_lib['C'] = [
123
+ [0, -1],
124
+ [0, 0],
125
+ [0, 1],
126
+ [0, 2],
127
+ [-1, -1],
128
+ [-1, 2],
129
+ [-2, -1],
130
+ # [-2, .5],
131
+ [-2, 2],
132
+ # [-1, .5]
133
+ ]
134
+ xy_lib['s'] = [
135
+ [0, -1],
136
+ [0, 0],
137
+ [0, 2],
138
+ [-1, -1],
139
+ [-1, 2],
140
+ [-2, -1],
141
+ [-2, 1],
142
+ [-2, 2],
143
+ [-1, .5]
144
+ ]
145
+
146
+ xy_lib['F'] = [
147
+ [0, -1],
148
+ [0, 0],
149
+ [0, 1],
150
+ [0, 2],
151
+ [-1, -1],
152
+ # [-1, 2],
153
+ [-2, -1],
154
+ [-2, .5],
155
+ # [-2, 2],
156
+ [-1, .5]
157
+ ]
158
+
159
+ xy_lib['c'] = [
160
+ [0.8,1],
161
+ # [-0.8,1],
162
+ [0,0.1],
163
+ [0,1.9],
164
+ ]
165
+
166
+ xy_lib['e'] = [
167
+ [0, -1],
168
+ [0, 0],
169
+ [0, 1],
170
+ [0, 2],
171
+ [-1, -1],
172
+ [-1, 2],
173
+ [-2, -1],
174
+ [-2, .5],
175
+ [-2, 2],
176
+ [-1, .5]
177
+ ]
178
+ xy_lib['n'] = [
179
+ [0,1],
180
+ [0,-1],
181
+ [0,0.1],
182
+ [0,1.9],
183
+ [-1,0],
184
+ [-2,1],
185
+ [-3,-1],
186
+ [-3,1],
187
+ [-3,0.1],
188
+ [-3,1.9],
189
+ ]
190
+ offset_x = dict(B=4, R=4, C=4, F=4, c=3, s=4, e=4, n=4.8)
191
+ s = 'BeRFsCene'
192
+ objs = []
193
+ offset = 2
194
+ for idx, c in enumerate(s):
195
+ xy = xy_lib[c]
196
+
197
+
198
+ color = np.random.choice(COLOR_NAME_LIST)
199
+ for i in range(len(xy)):
200
+ # while 1:
201
+ # is_ok = 1
202
+ # x, y =
203
+
204
+ # for prev_x, prev_y in zip(xpool, ypool):
205
+ x, y = xy[i]
206
+ y *= 1.5
207
+ y -= 0.5
208
+ x -= offset
209
+ z = 0.35
210
+ # if idx<4:
211
+ # color = np.random.choice(COLOR_NAME_LIST[:-1])
212
+ # else:
213
+ # color = 'blue'
214
+ shape = 'cube'
215
+ material = 'rubber'
216
+ rot = 0
217
+ objs.append([x, y, z, shape, color, material, rot])
218
+ offset += offset_x[c]
219
+ Image.fromarray((255 * .8 - get_bev_from_objs(objs)[0] *.8 * 255).astype(np.uint8))
220
+
221
+ batch_size = 1
222
+ code = torch.randn(1, G.z_dim).cuda()
223
+ to_pil = torchvision.transforms.ToPILImage()
224
+ large_bevs = torch.tensor(get_bev_from_objs(objs)).cuda()[None]
225
+ bevs = large_bevs[..., 0: 0+256]
226
+ RT = torch.tensor([[ -1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5000, -0.8660,
227
+ 10.3923, 0.0000, -0.8660, -0.5000, 6.0000, 0.0000, 0.0000,
228
+ 0.0000, 1.0000, 262.5000, 0.0000, 32.0000, 0.0000, 262.5000,
229
+ 32.0000, 0.0000, 0.0000, 1.0000]], device='cuda')
230
+
231
+ print('prepare finish', flush=True)
232
+
233
+ def inference(name):
234
+ print('inference', name, flush=True)
235
+ gen = G(code, RT, bevs)
236
+ rgb = gen['gen_output']['image'][0] * .5 + .5
237
+ print('inference', name, flush=True)
238
+ return np.array(to_pil(rgb))
239
+
240
+ # to_pil(rgb).save('tmp.png')
241
+ # save_path = '/mnt/petrelfs/zhangqihang/code/3d-scene-gen/tmp.png'
242
+ # return [save_path]
243
+
244
+ with gr.Blocks() as demo:
245
+ gr.HTML(
246
+ """
247
+ abc
248
+ """)
249
+
250
+ with gr.Group():
251
+ with gr.Row():
252
+ with gr.Column():
253
+ with gr.Row():
254
+ with gr.Column():
255
+ with gr.Row():
256
+ num_frames = gr.Dropdown(["24 - frames", "32 - frames", "40 - frames", "48 - frames", "56 - frames", "80 - recommended to run on local GPUs", "240 - recommended to run on local GPUs", "600 - recommended to run on local GPUs", "1200 - recommended to run on local GPUs", "10000 - recommended to run on local GPUs"], label="Number of Video Frames", info="For >56 frames use local workstation!", value="24 - frames")
257
+
258
+ with gr.Row():
259
+ with gr.Row():
260
+ btn = gr.Button("Result")
261
+
262
+ gallery = gr.Image(label='img', show_label=True, elem_id="gallery")
263
+
264
+ btn.click(fn=inference, inputs=num_frames, outputs=[gallery], postprocess=False)
265
+
266
+ demo.queue()
267
+ demo.launch(server_name='0.0.0.0', server_port=10093, debug=True, show_error=True)
models/__init__.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # python3.7
2
+ """Collects all models."""
3
+
4
+ from .pggan_generator import PGGANGenerator
5
+ from .pggan_discriminator import PGGANDiscriminator
6
+ from .stylegan_generator import StyleGANGenerator
7
+ from .stylegan_discriminator import StyleGANDiscriminator
8
+ from .stylegan2_generator import StyleGAN2Generator
9
+ from .stylegan2_discriminator import StyleGAN2Discriminator
10
+ from .stylegan3_generator import StyleGAN3Generator
11
+ from .ghfeat_encoder import GHFeatEncoder
12
+ from .perceptual_model import PerceptualModel
13
+ from .inception_model import InceptionModel
14
+ from .eg3d_generator import EG3DGenerator
15
+ from .eg3d_discriminator import DualDiscriminator
16
+ from .pigan_generator import PiGANGenerator
17
+ from .pigan_discriminator import PiGANDiscriminator
18
+ from .volumegan_generator import VolumeGANGenerator
19
+ from .volumegan_discriminator import VolumeGANDiscriminator
20
+ from .eg3d_generator_fv import EG3DGeneratorFV
21
+ from .bev3d_generator import BEV3DGenerator
22
+ from .sgbev3d_generator import SGBEV3DGenerator
23
+
24
+ __all__ = ['build_model']
25
+
26
+ _MODELS = {
27
+ 'PGGANGenerator': PGGANGenerator,
28
+ 'PGGANDiscriminator': PGGANDiscriminator,
29
+ 'StyleGANGenerator': StyleGANGenerator,
30
+ 'StyleGANDiscriminator': StyleGANDiscriminator,
31
+ 'StyleGAN2Generator': StyleGAN2Generator,
32
+ 'StyleGAN2Discriminator': StyleGAN2Discriminator,
33
+ 'StyleGAN3Generator': StyleGAN3Generator,
34
+ 'GHFeatEncoder': GHFeatEncoder,
35
+ 'PerceptualModel': PerceptualModel.build_model,
36
+ 'InceptionModel': InceptionModel.build_model,
37
+ 'EG3DGenerator': EG3DGenerator,
38
+ 'EG3DDiscriminator': DualDiscriminator,
39
+ 'PiGANGenerator': PiGANGenerator,
40
+ 'PiGANDiscriminator': PiGANDiscriminator,
41
+ 'VolumeGANGenerator': VolumeGANGenerator,
42
+ 'VolumeGANDiscriminator': VolumeGANDiscriminator,
43
+ 'EG3DGeneratorFV': EG3DGeneratorFV,
44
+ 'BEV3DGenerator': BEV3DGenerator,
45
+ 'SGBEV3DGenerator': SGBEV3DGenerator,
46
+ }
47
+
48
+
49
+ def build_model(model_type, **kwargs):
50
+ """Builds a model based on its class type.
51
+
52
+ Args:
53
+ model_type: Class type to which the model belongs, which is case
54
+ sensitive.
55
+ **kwargs: Additional arguments to build the model.
56
+
57
+ Raises:
58
+ ValueError: If the `model_type` is not supported.
59
+ """
60
+ if model_type not in _MODELS:
61
+ raise ValueError(f'Invalid model type: `{model_type}`!\n'
62
+ f'Types allowed: {list(_MODELS)}.')
63
+ return _MODELS[model_type](**kwargs)
models/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (2.06 kB). View file
 
models/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (2.08 kB). View file
 
models/__pycache__/bev3d_generator.cpython-37.pyc ADDED
Binary file (6.16 kB). View file
 
models/__pycache__/bev3d_generator.cpython-39.pyc ADDED
Binary file (6.07 kB). View file
 
models/__pycache__/eg3d_discriminator.cpython-37.pyc ADDED
Binary file (8.01 kB). View file
 
models/__pycache__/eg3d_discriminator.cpython-39.pyc ADDED
Binary file (7.73 kB). View file
 
models/__pycache__/eg3d_generator.cpython-37.pyc ADDED
Binary file (6.21 kB). View file
 
models/__pycache__/eg3d_generator.cpython-39.pyc ADDED
Binary file (6.3 kB). View file
 
models/__pycache__/eg3d_generator_fv.cpython-37.pyc ADDED
Binary file (6.35 kB). View file
 
models/__pycache__/eg3d_generator_fv.cpython-39.pyc ADDED
Binary file (6.43 kB). View file
 
models/__pycache__/ghfeat_encoder.cpython-37.pyc ADDED
Binary file (14.3 kB). View file
 
models/__pycache__/ghfeat_encoder.cpython-39.pyc ADDED
Binary file (14.1 kB). View file
 
models/__pycache__/inception_model.cpython-37.pyc ADDED
Binary file (16 kB). View file
 
models/__pycache__/inception_model.cpython-39.pyc ADDED
Binary file (15.7 kB). View file
 
models/__pycache__/perceptual_model.cpython-37.pyc ADDED
Binary file (14.3 kB). View file
 
models/__pycache__/perceptual_model.cpython-39.pyc ADDED
Binary file (14 kB). View file
 
models/__pycache__/pggan_discriminator.cpython-37.pyc ADDED
Binary file (12 kB). View file
 
models/__pycache__/pggan_discriminator.cpython-39.pyc ADDED
Binary file (11.9 kB). View file
 
models/__pycache__/pggan_generator.cpython-37.pyc ADDED
Binary file (10.6 kB). View file
 
models/__pycache__/pggan_generator.cpython-39.pyc ADDED
Binary file (10.6 kB). View file
 
models/__pycache__/pigan_discriminator.cpython-37.pyc ADDED
Binary file (8.32 kB). View file
 
models/__pycache__/pigan_discriminator.cpython-39.pyc ADDED
Binary file (8.31 kB). View file
 
models/__pycache__/pigan_generator.cpython-37.pyc ADDED
Binary file (12.7 kB). View file
 
models/__pycache__/pigan_generator.cpython-39.pyc ADDED
Binary file (12.8 kB). View file
 
models/__pycache__/sgbev3d_generator.cpython-37.pyc ADDED
Binary file (7.01 kB). View file
 
models/__pycache__/sgbev3d_generator.cpython-39.pyc ADDED
Binary file (7.04 kB). View file
 
models/__pycache__/stylegan2_discriminator.cpython-37.pyc ADDED
Binary file (17.7 kB). View file
 
models/__pycache__/stylegan2_discriminator.cpython-39.pyc ADDED
Binary file (17.7 kB). View file
 
models/__pycache__/stylegan2_generator.cpython-37.pyc ADDED
Binary file (32.9 kB). View file
 
models/__pycache__/stylegan2_generator.cpython-39.pyc ADDED
Binary file (32.9 kB). View file
 
models/__pycache__/stylegan3_generator.cpython-37.pyc ADDED
Binary file (35.8 kB). View file
 
models/__pycache__/stylegan3_generator.cpython-39.pyc ADDED
Binary file (35.7 kB). View file
 
models/__pycache__/stylegan_discriminator.cpython-37.pyc ADDED
Binary file (15.9 kB). View file
 
models/__pycache__/stylegan_discriminator.cpython-39.pyc ADDED
Binary file (15.9 kB). View file
 
models/__pycache__/stylegan_generator.cpython-37.pyc ADDED
Binary file (24.9 kB). View file
 
models/__pycache__/stylegan_generator.cpython-39.pyc ADDED
Binary file (24.9 kB). View file
 
models/__pycache__/volumegan_discriminator.cpython-37.pyc ADDED
Binary file (17.8 kB). View file
 
models/__pycache__/volumegan_discriminator.cpython-39.pyc ADDED
Binary file (17.8 kB). View file
 
models/__pycache__/volumegan_generator.cpython-37.pyc ADDED
Binary file (18.2 kB). View file
 
models/__pycache__/volumegan_generator.cpython-39.pyc ADDED
Binary file (18.2 kB). View file
 
models/bev3d_generator.py ADDED
@@ -0,0 +1,301 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # python3.8
2
+ """Contains the implementation of generator described in BEV3D."""
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from models.utils.official_stylegan2_model_helper import Generator as StyleGAN2Backbone
7
+ from models.utils.official_stylegan2_model_helper import FullyConnectedLayer
8
+ from models.utils.eg3d_superres import SuperresolutionHybrid2X
9
+ from models.utils.eg3d_superres import SuperresolutionHybrid4X
10
+ from models.utils.eg3d_superres import SuperresolutionHybrid4X_conststyle
11
+ from models.utils.eg3d_superres import SuperresolutionHybrid8XDC
12
+ from models.rendering.renderer import Renderer
13
+ from models.rendering.feature_extractor import FeatureExtractor
14
+
15
+ from models.utils.spade import SPADEGenerator
16
+
17
+ class BEV3DGenerator(nn.Module):
18
+
19
+ def __init__(
20
+ self,
21
+ z_dim,
22
+ semantic_nc,
23
+ ngf,
24
+ bev_grid_size,
25
+ aspect_ratio,
26
+ num_upsampling_layers,
27
+ not_use_vae,
28
+ norm_G,
29
+ img_resolution,
30
+ interpolate_sr,
31
+ segmask=False,
32
+ dim_seq='16,8,4,2,1',
33
+ xyz_pe=False,
34
+ hidden_dim=64,
35
+ additional_layer_num=0,
36
+ sr_num_fp16_res=0, # Number of fp16 layers of SR Network.
37
+ rendering_kwargs={}, # Arguments for rendering.
38
+ sr_kwargs={}, # Arguments for SuperResolution Network.
39
+ ):
40
+ super().__init__()
41
+
42
+ self.z_dim = z_dim
43
+ self.interpolate_sr = interpolate_sr
44
+ self.segmask = segmask
45
+
46
+ # Set up the overall renderer.
47
+ self.renderer = Renderer()
48
+
49
+ # Set up the feature extractor.
50
+ self.feature_extractor = FeatureExtractor(ref_mode='bev_plane_clevr', xyz_pe=xyz_pe)
51
+
52
+ # Set up the reference representation generator.
53
+ self.backbone = SPADEGenerator(z_dim=z_dim, semantic_nc=semantic_nc, ngf=ngf, dim_seq=dim_seq, bev_grid_size=bev_grid_size,
54
+ aspect_ratio=aspect_ratio, num_upsampling_layers=num_upsampling_layers,
55
+ not_use_vae=not_use_vae, norm_G=norm_G)
56
+ print('backbone SPADEGenerator set up!')
57
+
58
+ # Set up the post module in the feature extractor.
59
+ self.post_module = None
60
+
61
+ # Set up the post neural renderer.
62
+ self.post_neural_renderer = None
63
+ sr_kwargs_total = dict(
64
+ channels=32,
65
+ img_resolution=img_resolution,
66
+ sr_num_fp16_res=sr_num_fp16_res,
67
+ sr_antialias=rendering_kwargs['sr_antialias'],)
68
+ sr_kwargs_total.update(**sr_kwargs)
69
+ if img_resolution == 128:
70
+ self.post_neural_renderer = SuperresolutionHybrid2X(
71
+ **sr_kwargs_total)
72
+ elif img_resolution == 256:
73
+ self.post_neural_renderer = SuperresolutionHybrid4X_conststyle(
74
+ **sr_kwargs_total)
75
+ elif img_resolution == 512:
76
+ self.post_neural_renderer = SuperresolutionHybrid8XDC(
77
+ **sr_kwargs_total)
78
+ else:
79
+ raise TypeError(f'Unsupported image resolution: {img_resolution}!')
80
+
81
+ # Set up the fully-connected layer head.
82
+ self.fc_head = OSGDecoder(
83
+ 128 if xyz_pe else 64 , {
84
+ 'decoder_lr_mul': rendering_kwargs.get('decoder_lr_mul', 1),
85
+ 'decoder_output_dim': 32
86
+ },
87
+ hidden_dim=hidden_dim,
88
+ additional_layer_num=additional_layer_num
89
+ )
90
+
91
+ # Set up some rendering related arguments.
92
+ self.neural_rendering_resolution = rendering_kwargs.get(
93
+ 'resolution', 64)
94
+ self.rendering_kwargs = rendering_kwargs
95
+
96
+ def synthesis(self,
97
+ z,
98
+ c,
99
+ seg,
100
+ neural_rendering_resolution=None,
101
+ update_emas=False,
102
+ **synthesis_kwargs):
103
+ cam2world_matrix = c[:, :16].view(-1, 4, 4)
104
+ if self.rendering_kwargs.get('random_pose', False):
105
+ cam2world_matrix = None
106
+
107
+ if neural_rendering_resolution is None:
108
+ neural_rendering_resolution = self.neural_rendering_resolution
109
+ else:
110
+ self.neural_rendering_resolution = neural_rendering_resolution
111
+
112
+ xy_planes = self.backbone(z=z, input=seg)
113
+ if self.segmask:
114
+ xy_planes = xy_planes * seg[:, 0, ...][:, None, ...]
115
+
116
+ # import pdb;pdb.set_trace()
117
+
118
+ wp = z # in our case, we do not use wp.
119
+
120
+ rendering_result = self.renderer(
121
+ wp=wp,
122
+ feature_extractor=self.feature_extractor,
123
+ rendering_options=self.rendering_kwargs,
124
+ cam2world_matrix=cam2world_matrix,
125
+ position_encoder=None,
126
+ ref_representation=xy_planes,
127
+ post_module=self.post_module,
128
+ fc_head=self.fc_head)
129
+
130
+ feature_samples = rendering_result['composite_rgb']
131
+ depth_samples = rendering_result['composite_depth']
132
+
133
+ # Reshape to keep consistent with 'raw' neural-rendered image.
134
+ N = wp.shape[0]
135
+ H = W = self.neural_rendering_resolution
136
+ feature_image = feature_samples.permute(0, 2, 1).reshape(
137
+ N, feature_samples.shape[-1], H, W).contiguous()
138
+ depth_image = depth_samples.permute(0, 2, 1).reshape(N, 1, H, W)
139
+
140
+ # Run the post neural renderer to get final image.
141
+ # Here, the post neural renderer is a super-resolution network.
142
+ rgb_image = feature_image[:, :3]
143
+ if self.interpolate_sr:
144
+ sr_image = torch.nn.functional.interpolate(rgb_image, size=(256, 256), mode='bilinear', align_corners=False)
145
+ else:
146
+ sr_image = self.post_neural_renderer(
147
+ rgb_image,
148
+ feature_image,
149
+ # wp,
150
+ noise_mode=self.rendering_kwargs['superresolution_noise_mode'],
151
+ **{
152
+ k: synthesis_kwargs[k]
153
+ for k in synthesis_kwargs.keys() if k != 'noise_mode'
154
+ })
155
+
156
+ return {
157
+ 'image': sr_image,
158
+ 'image_raw': rgb_image,
159
+ 'image_depth': depth_image
160
+ }
161
+
162
+ def sample(self,
163
+ coordinates,
164
+ directions,
165
+ z,
166
+ c,
167
+ seg,
168
+ truncation_psi=1,
169
+ truncation_cutoff=None,
170
+ update_emas=False,
171
+ **synthesis_kwargs):
172
+ # Compute RGB features, density for arbitrary 3D coordinates.
173
+ # Mostly used for extracting shapes.
174
+ cam2world_matrix = c[:, :16].view(-1, 4, 4)
175
+ xy_planes = self.backbone(z=z, input=seg)
176
+ wp = z
177
+ result = self.renderer.get_sigma_rgb(
178
+ wp=wp,
179
+ points=coordinates,
180
+ feature_extractor=self.feature_extractor,
181
+ fc_head=self.fc_head,
182
+ rendering_options=self.rendering_kwargs,
183
+ ref_representation=xy_planes,
184
+ post_module=self.post_module,
185
+ ray_dirs=directions,
186
+ cam_matrix=cam2world_matrix)
187
+
188
+ return result
189
+
190
+ def sample_mixed(self,
191
+ coordinates,
192
+ directions,
193
+ z, c, seg,
194
+ truncation_psi=1,
195
+ truncation_cutoff=None,
196
+ update_emas=False,
197
+ **synthesis_kwargs):
198
+ # Same as function `self.sample()`, but expects latent vectors 'wp'
199
+ # instead of Gaussian noise 'z'.
200
+ cam2world_matrix = c[:, :16].view(-1, 4, 4)
201
+ xy_planes = self.backbone(z=z, input=seg)
202
+ wp = z
203
+ result = self.renderer.get_sigma_rgb(
204
+ wp=wp,
205
+ points=coordinates,
206
+ feature_extractor=self.feature_extractor,
207
+ fc_head=self.fc_head,
208
+ rendering_options=self.rendering_kwargs,
209
+ ref_representation=xy_planes,
210
+ post_module=self.post_module,
211
+ ray_dirs=directions,
212
+ cam_matrix=cam2world_matrix)
213
+
214
+ return result
215
+
216
+ def forward(self,
217
+ z,
218
+ c,
219
+ seg,
220
+ c_swapped=None, # `c_swapped` is swapped pose conditioning.
221
+ style_mixing_prob=0,
222
+ truncation_psi=1,
223
+ truncation_cutoff=None,
224
+ neural_rendering_resolution=None,
225
+ update_emas=False,
226
+ sample_mixed=False,
227
+ coordinates=None,
228
+ **synthesis_kwargs):
229
+
230
+ # Render a batch of generated images.
231
+ c_wp = c.clone()
232
+ if c_swapped is not None:
233
+ c_wp = c_swapped.clone()
234
+
235
+ if not sample_mixed:
236
+ gen_output = self.synthesis(
237
+ z,
238
+ c,
239
+ seg,
240
+ update_emas=update_emas,
241
+ neural_rendering_resolution=neural_rendering_resolution,
242
+ **synthesis_kwargs)
243
+
244
+ return {
245
+ 'wp': z,
246
+ 'gen_output': gen_output,
247
+ }
248
+
249
+ else:
250
+ # Only for density regularization in training process.
251
+ assert coordinates is not None
252
+ sample_sigma = self.sample_mixed(coordinates,
253
+ torch.randn_like(coordinates),
254
+ z, c, seg,
255
+ update_emas=False)['sigma']
256
+
257
+ return {
258
+ 'wp': z,
259
+ 'sample_sigma': sample_sigma
260
+ }
261
+
262
+
263
+ class OSGDecoder(nn.Module):
264
+ """Defines fully-connected layer head in EG3D."""
265
+ def __init__(self, n_features, options, hidden_dim=64, additional_layer_num=0):
266
+ super().__init__()
267
+ self.hidden_dim = hidden_dim
268
+
269
+ lst = []
270
+ lst.append(FullyConnectedLayer(n_features, self.hidden_dim, lr_multiplier=options['decoder_lr_mul']))
271
+ lst.append(nn.Softplus())
272
+ for i in range(additional_layer_num):
273
+ lst.append(FullyConnectedLayer(self.hidden_dim, self.hidden_dim, lr_multiplier=options['decoder_lr_mul']))
274
+ lst.append(nn.Softplus())
275
+ lst.append(FullyConnectedLayer(self.hidden_dim, 1+options['decoder_output_dim'], lr_multiplier=options['decoder_lr_mul']))
276
+ self.net = nn.Sequential(*lst)
277
+
278
+ # self.net = nn.Sequential(
279
+ # FullyConnectedLayer(n_features,
280
+ # self.hidden_dim,
281
+ # lr_multiplier=options['decoder_lr_mul']),
282
+ # nn.Softplus(),
283
+ # FullyConnectedLayer(self.hidden_dim,
284
+ # 1 + options['decoder_output_dim'],
285
+ # lr_multiplier=options['decoder_lr_mul']))
286
+
287
+ def forward(self, point_features, wp=None, dirs=None):
288
+ # Aggregate features
289
+ # point_features.shape: [N, R, K, C].
290
+ # Average across 'X, Y, Z' planes.
291
+
292
+ N, R, K, C = point_features.shape
293
+ x = point_features.reshape(-1, point_features.shape[-1])
294
+ x = self.net(x)
295
+ x = x.view(N, -1, x.shape[-1])
296
+
297
+ # Uses sigmoid clamping from MipNeRF
298
+ rgb = torch.sigmoid(x[..., 1:]) * (1 + 2 * 0.001) - 0.001
299
+ sigma = x[..., 0:1]
300
+
301
+ return {'rgb': rgb, 'sigma': sigma}
models/eg3d_discriminator.py ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # python 3.7
2
+ """Contains the implementation of discriminator described in EG3D."""
3
+
4
+
5
+ import numpy as np
6
+ import torch
7
+ from third_party.stylegan2_official_ops import upfirdn2d
8
+ from models.utils.official_stylegan2_model_helper import DiscriminatorBlock
9
+ from models.utils.official_stylegan2_model_helper import MappingNetwork
10
+ from models.utils.official_stylegan2_model_helper import DiscriminatorEpilogue
11
+
12
+
13
+ class SingleDiscriminator(torch.nn.Module):
14
+ def __init__(self,
15
+ c_dim, # Conditioning label (C) dimensionality.
16
+ img_resolution, # Input resolution.
17
+ img_channels, # Number of input color channels.
18
+ architecture = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
19
+ channel_base = 32768, # Overall multiplier for the number of channels.
20
+ channel_max = 512, # Maximum number of channels in any layer.
21
+ num_fp16_res = 4, # Use FP16 for the N highest resolutions.
22
+ conv_clamp = 256, # Clamp the output of convolution layers to +-X, None = disable clamping.
23
+ cmap_dim = None, # Dimensionality of mapped conditioning label, None = default.
24
+ sr_upsample_factor = 1, # Ignored for SingleDiscriminator
25
+ block_kwargs = {}, # Arguments for DiscriminatorBlock.
26
+ mapping_kwargs = {}, # Arguments for MappingNetwork.
27
+ epilogue_kwargs = {}, # Arguments for DiscriminatorEpilogue.
28
+ ):
29
+ super().__init__()
30
+ self.c_dim = c_dim
31
+ self.img_resolution = img_resolution
32
+ self.img_resolution_log2 = int(np.log2(img_resolution))
33
+ self.img_channels = img_channels
34
+ self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)]
35
+ channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]}
36
+ fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)
37
+
38
+ if cmap_dim is None:
39
+ cmap_dim = channels_dict[4]
40
+ if c_dim == 0:
41
+ cmap_dim = 0
42
+
43
+ common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp)
44
+ cur_layer_idx = 0
45
+ for res in self.block_resolutions:
46
+ in_channels = channels_dict[res] if res < img_resolution else 0
47
+ tmp_channels = channels_dict[res]
48
+ out_channels = channels_dict[res // 2]
49
+ use_fp16 = (res >= fp16_resolution)
50
+ block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res,
51
+ first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs)
52
+ setattr(self, f'b{res}', block)
53
+ cur_layer_idx += block.num_layers
54
+ if c_dim > 0:
55
+ self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs)
56
+ self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, **epilogue_kwargs, **common_kwargs)
57
+
58
+ def forward(self, img, c, update_emas=False, **block_kwargs):
59
+ img = img['image']
60
+
61
+ _ = update_emas # unused
62
+ x = None
63
+ for res in self.block_resolutions:
64
+ block = getattr(self, f'b{res}')
65
+ x, img = block(x, img, **block_kwargs)
66
+
67
+ cmap = None
68
+ if self.c_dim > 0:
69
+ cmap = self.mapping(None, c)
70
+ x = self.b4(x, img, cmap)
71
+ return x
72
+
73
+ def extra_repr(self):
74
+ return f'c_dim={self.c_dim:d}, img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d}'
75
+
76
+ #----------------------------------------------------------------------------
77
+
78
+ def filtered_resizing(image_orig_tensor, size, f, filter_mode='antialiased'):
79
+ if filter_mode == 'antialiased':
80
+ ada_filtered_64 = torch.nn.functional.interpolate(image_orig_tensor, size=(size, size), mode='bilinear', align_corners=False)
81
+ elif filter_mode == 'classic':
82
+ ada_filtered_64 = upfirdn2d.upsample2d(image_orig_tensor, f, up=2)
83
+ ada_filtered_64 = torch.nn.functional.interpolate(ada_filtered_64, size=(size * 2 + 2, size * 2 + 2), mode='bilinear', align_corners=False)
84
+ ada_filtered_64 = upfirdn2d.downsample2d(ada_filtered_64, f, down=2, flip_filter=True, padding=-1)
85
+ elif filter_mode == 'none':
86
+ ada_filtered_64 = torch.nn.functional.interpolate(image_orig_tensor, size=(size, size), mode='bilinear', align_corners=False)
87
+ elif type(filter_mode) == float:
88
+ assert 0 < filter_mode < 1
89
+
90
+ filtered = torch.nn.functional.interpolate(image_orig_tensor, size=(size, size), mode='bilinear', align_corners=False)
91
+ aliased = torch.nn.functional.interpolate(image_orig_tensor, size=(size, size), mode='bilinear', align_corners=False)
92
+ ada_filtered_64 = (1 - filter_mode) * aliased + (filter_mode) * filtered
93
+
94
+ return ada_filtered_64
95
+
96
+ #----------------------------------------------------------------------------
97
+
98
+ class DualDiscriminator(torch.nn.Module):
99
+ def __init__(self,
100
+ c_dim, # Conditioning label (C) dimensionality.
101
+ img_resolution, # Input resolution.
102
+ img_channels, # Number of input color channels.
103
+ bev_channels = 0,
104
+ architecture = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
105
+ channel_base = 32768, # Overall multiplier for the number of channels.
106
+ channel_max = 512, # Maximum number of channels in any layer.
107
+ num_fp16_res = 4, # Use FP16 for the N highest resolutions.
108
+ conv_clamp = 256, # Clamp the output of convolution layers to +-X, None = disable clamping.
109
+ cmap_dim = None, # Dimensionality of mapped conditioning label, None = default.
110
+ disc_c_noise = 0, # Corrupt camera parameters with X std dev of noise before disc. pose conditioning.
111
+ block_kwargs = {}, # Arguments for DiscriminatorBlock.
112
+ mapping_kwargs = {}, # Arguments for MappingNetwork.
113
+ epilogue_kwargs = {}, # Arguments for DiscriminatorEpilogue.
114
+ ):
115
+ super().__init__()
116
+ img_channels *= 2
117
+
118
+ self.c_dim = c_dim
119
+ self.img_resolution = img_resolution
120
+ self.img_resolution_log2 = int(np.log2(img_resolution))
121
+ self.img_channels = img_channels + bev_channels
122
+ self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)]
123
+ channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]}
124
+ fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)
125
+
126
+ if cmap_dim is None:
127
+ cmap_dim = channels_dict[4]
128
+ if c_dim == 0:
129
+ cmap_dim = 0
130
+
131
+ common_kwargs = dict(img_channels=self.img_channels, architecture=architecture, conv_clamp=conv_clamp)
132
+ cur_layer_idx = 0
133
+ for res in self.block_resolutions:
134
+ in_channels = channels_dict[res] if res < img_resolution else 0
135
+ tmp_channels = channels_dict[res]
136
+ out_channels = channels_dict[res // 2]
137
+ use_fp16 = (res >= fp16_resolution)
138
+ block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res,
139
+ first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs)
140
+ setattr(self, f'b{res}', block)
141
+ cur_layer_idx += block.num_layers
142
+ if c_dim > 0:
143
+ self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs)
144
+ self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, **epilogue_kwargs, **common_kwargs)
145
+ self.register_buffer('resample_filter', upfirdn2d.setup_filter([1,3,3,1]))
146
+ self.disc_c_noise = disc_c_noise
147
+
148
+ def forward(self, img, c, bev=None, update_emas=False, **block_kwargs):
149
+ image_raw = filtered_resizing(img['image_raw'], size=img['image'].shape[-1], f=self.resample_filter)
150
+ img = torch.cat([img['image'], image_raw], 1)
151
+ if bev is not None:
152
+ img = torch.cat([img, bev], 1)
153
+
154
+ _ = update_emas # unused
155
+ x = None
156
+ for res in self.block_resolutions:
157
+ block = getattr(self, f'b{res}')
158
+ x, img = block(x, img, **block_kwargs)
159
+
160
+ cmap = None
161
+ if self.c_dim > 0:
162
+ if self.disc_c_noise > 0: c += torch.randn_like(c) * c.std(0) * self.disc_c_noise
163
+ cmap = self.mapping(None, c)
164
+ x = self.b4(x, img, cmap)
165
+ return x
166
+
167
+ def extra_repr(self):
168
+ return f'c_dim={self.c_dim:d}, img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d}'
169
+
170
+ #----------------------------------------------------------------------------
171
+
172
+ class DummyDualDiscriminator(torch.nn.Module):
173
+ def __init__(self,
174
+ c_dim, # Conditioning label (C) dimensionality.
175
+ img_resolution, # Input resolution.
176
+ img_channels, # Number of input color channels.
177
+ architecture = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
178
+ channel_base = 32768, # Overall multiplier for the number of channels.
179
+ channel_max = 512, # Maximum number of channels in any layer.
180
+ num_fp16_res = 4, # Use FP16 for the N highest resolutions.
181
+ conv_clamp = 256, # Clamp the output of convolution layers to +-X, None = disable clamping.
182
+ cmap_dim = None, # Dimensionality of mapped conditioning label, None = default.
183
+ block_kwargs = {}, # Arguments for DiscriminatorBlock.
184
+ mapping_kwargs = {}, # Arguments for MappingNetwork.
185
+ epilogue_kwargs = {}, # Arguments for DiscriminatorEpilogue.
186
+ ):
187
+ super().__init__()
188
+ img_channels *= 2
189
+
190
+ self.c_dim = c_dim
191
+ self.img_resolution = img_resolution
192
+ self.img_resolution_log2 = int(np.log2(img_resolution))
193
+ self.img_channels = img_channels
194
+ self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)]
195
+ channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]}
196
+ fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)
197
+
198
+ if cmap_dim is None:
199
+ cmap_dim = channels_dict[4]
200
+ if c_dim == 0:
201
+ cmap_dim = 0
202
+
203
+ common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp)
204
+ cur_layer_idx = 0
205
+ for res in self.block_resolutions:
206
+ in_channels = channels_dict[res] if res < img_resolution else 0
207
+ tmp_channels = channels_dict[res]
208
+ out_channels = channels_dict[res // 2]
209
+ use_fp16 = (res >= fp16_resolution)
210
+ block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res,
211
+ first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs)
212
+ setattr(self, f'b{res}', block)
213
+ cur_layer_idx += block.num_layers
214
+ if c_dim > 0:
215
+ self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs)
216
+ self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, **epilogue_kwargs, **common_kwargs)
217
+ self.register_buffer('resample_filter', upfirdn2d.setup_filter([1,3,3,1]))
218
+
219
+ self.raw_fade = 1
220
+
221
+ def forward(self, img, c, update_emas=False, **block_kwargs):
222
+ self.raw_fade = max(0, self.raw_fade - 1/(500000/32))
223
+
224
+ image_raw = filtered_resizing(img['image_raw'], size=img['image'].shape[-1], f=self.resample_filter) * self.raw_fade
225
+ img = torch.cat([img['image'], image_raw], 1)
226
+
227
+ _ = update_emas # unused
228
+ x = None
229
+ for res in self.block_resolutions:
230
+ block = getattr(self, f'b{res}')
231
+ x, img = block(x, img, **block_kwargs)
232
+
233
+ cmap = None
234
+ if self.c_dim > 0:
235
+ cmap = self.mapping(None, c)
236
+ x = self.b4(x, img, cmap)
237
+ return x
238
+
239
+ def extra_repr(self):
240
+ return f'c_dim={self.c_dim:d}, img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d}'
241
+
242
+ #----------------------------------------------------------------------------
243
+
models/eg3d_generator.py ADDED
@@ -0,0 +1,315 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # python3.8
2
+ """Contains the implementation of generator described in EG3D."""
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from models.utils.official_stylegan2_model_helper import Generator as StyleGAN2Backbone
7
+ from models.utils.official_stylegan2_model_helper import FullyConnectedLayer
8
+ from models.utils.eg3d_superres import SuperresolutionHybrid2X
9
+ from models.utils.eg3d_superres import SuperresolutionHybrid4X
10
+ from models.utils.eg3d_superres import SuperresolutionHybrid8XDC
11
+ from models.rendering.renderer import Renderer
12
+ from models.rendering.feature_extractor import FeatureExtractor
13
+
14
+ class EG3DGenerator(nn.Module):
15
+
16
+ def __init__(
17
+ self,
18
+ z_dim, # Input latent (Z) dimensionality.
19
+ c_dim, # Conditioning label (C) dimensionality.
20
+ w_dim, # Intermediate latent (W) dimensionality.
21
+ img_resolution, # Output resolution.
22
+ img_channels, # Number of output color channels.
23
+ sr_num_fp16_res=0, # Number of fp16 layers of SR Network.
24
+ mapping_kwargs={}, # Arguments for MappingNetwork.
25
+ rendering_kwargs={}, # Arguments for rendering.
26
+ sr_kwargs={}, # Arguments for SuperResolution Network.
27
+ **synthesis_kwargs, # Arguments for SynthesisNetwork.
28
+ ):
29
+ super().__init__()
30
+ self.z_dim = z_dim
31
+ self.c_dim = c_dim
32
+ self.w_dim = w_dim
33
+ self.img_resolution = img_resolution
34
+ self.img_channels = img_channels
35
+
36
+ # Set up the overall renderer.
37
+ self.renderer = Renderer()
38
+
39
+ # Set up the feature extractor.
40
+ self.feature_extractor = FeatureExtractor(ref_mode='tri_plane')
41
+
42
+ # Set up the reference representation generator.
43
+ self.backbone = StyleGAN2Backbone(z_dim,
44
+ c_dim,
45
+ w_dim,
46
+ img_resolution=256,
47
+ img_channels=32 * 3,
48
+ mapping_kwargs=mapping_kwargs,
49
+ **synthesis_kwargs)
50
+
51
+ # Set up the post module in the feature extractor.
52
+ self.post_module = None
53
+
54
+ # Set up the post neural renderer.
55
+ self.post_neural_renderer = None
56
+ sr_kwargs_total = dict(
57
+ channels=32,
58
+ img_resolution=img_resolution,
59
+ sr_num_fp16_res=sr_num_fp16_res,
60
+ sr_antialias=rendering_kwargs['sr_antialias'],)
61
+ sr_kwargs_total.update(**sr_kwargs)
62
+ if img_resolution == 128:
63
+ self.post_neural_renderer = SuperresolutionHybrid2X(
64
+ **sr_kwargs_total)
65
+ elif img_resolution == 256:
66
+ self.post_neural_renderer = SuperresolutionHybrid4X(
67
+ **sr_kwargs_total)
68
+ elif img_resolution == 512:
69
+ self.post_neural_renderer = SuperresolutionHybrid8XDC(
70
+ **sr_kwargs_total)
71
+ else:
72
+ raise TypeError(f'Unsupported image resolution: {img_resolution}!')
73
+
74
+ # Set up the fully-connected layer head.
75
+ self.fc_head = OSGDecoder(
76
+ 32, {
77
+ 'decoder_lr_mul': rendering_kwargs.get('decoder_lr_mul', 1),
78
+ 'decoder_output_dim': 32
79
+ })
80
+
81
+ # Set up some rendering related arguments.
82
+ self.neural_rendering_resolution = rendering_kwargs.get(
83
+ 'resolution', 64)
84
+ self.rendering_kwargs = rendering_kwargs
85
+
86
+ def mapping(self,
87
+ z,
88
+ c,
89
+ truncation_psi=1,
90
+ truncation_cutoff=None,
91
+ update_emas=False):
92
+ if self.rendering_kwargs['c_gen_conditioning_zero']:
93
+ c = torch.zeros_like(c)
94
+ return self.backbone.mapping(z,
95
+ c *
96
+ self.rendering_kwargs.get('c_scale', 0),
97
+ truncation_psi=truncation_psi,
98
+ truncation_cutoff=truncation_cutoff,
99
+ update_emas=update_emas)
100
+
101
+ def synthesis(self,
102
+ wp,
103
+ c,
104
+ neural_rendering_resolution=None,
105
+ update_emas=False,
106
+ **synthesis_kwargs):
107
+ cam2world_matrix = c[:, :16].view(-1, 4, 4)
108
+ if self.rendering_kwargs.get('random_pose', False):
109
+ cam2world_matrix = None
110
+
111
+ if neural_rendering_resolution is None:
112
+ neural_rendering_resolution = self.neural_rendering_resolution
113
+ else:
114
+ self.neural_rendering_resolution = neural_rendering_resolution
115
+
116
+ tri_planes = self.backbone.synthesis(wp,
117
+ update_emas=update_emas,
118
+ **synthesis_kwargs)
119
+ tri_planes = tri_planes.view(len(tri_planes), 3, -1,
120
+ tri_planes.shape[-2],
121
+ tri_planes.shape[-1])
122
+
123
+ rendering_result = self.renderer(
124
+ wp=wp,
125
+ feature_extractor=self.feature_extractor,
126
+ rendering_options=self.rendering_kwargs,
127
+ cam2world_matrix=cam2world_matrix,
128
+ position_encoder=None,
129
+ ref_representation=tri_planes,
130
+ post_module=self.post_module,
131
+ fc_head=self.fc_head)
132
+
133
+ feature_samples = rendering_result['composite_rgb']
134
+ depth_samples = rendering_result['composite_depth']
135
+
136
+ # Reshape to keep consistent with 'raw' neural-rendered image.
137
+ N = wp.shape[0]
138
+ H = W = self.neural_rendering_resolution
139
+ feature_image = feature_samples.permute(0, 2, 1).reshape(
140
+ N, feature_samples.shape[-1], H, W).contiguous()
141
+ depth_image = depth_samples.permute(0, 2, 1).reshape(N, 1, H, W)
142
+
143
+ # Run the post neural renderer to get final image.
144
+ # Here, the post neural renderer is a super-resolution network.
145
+ rgb_image = feature_image[:, :3]
146
+ sr_image = self.post_neural_renderer(
147
+ rgb_image,
148
+ feature_image,
149
+ wp,
150
+ noise_mode=self.rendering_kwargs['superresolution_noise_mode'],
151
+ **{
152
+ k: synthesis_kwargs[k]
153
+ for k in synthesis_kwargs.keys() if k != 'noise_mode'
154
+ })
155
+
156
+ return {
157
+ 'image': sr_image,
158
+ 'image_raw': rgb_image,
159
+ 'image_depth': depth_image
160
+ }
161
+
162
+ def sample(self,
163
+ coordinates,
164
+ directions,
165
+ z,
166
+ c,
167
+ truncation_psi=1,
168
+ truncation_cutoff=None,
169
+ update_emas=False,
170
+ **synthesis_kwargs):
171
+ # Compute RGB features, density for arbitrary 3D coordinates.
172
+ # Mostly used for extracting shapes.
173
+ wp = self.mapping(z,
174
+ c,
175
+ truncation_psi=truncation_psi,
176
+ truncation_cutoff=truncation_cutoff,
177
+ update_emas=update_emas)
178
+ tri_planes = self.backbone.synthesis(wp,
179
+ update_emas=update_emas,
180
+ **synthesis_kwargs)
181
+ tri_planes = tri_planes.view(len(tri_planes), 3, -1,
182
+ tri_planes.shape[-2],
183
+ tri_planes.shape[-1])
184
+ result = self.renderer.get_sigma_rgb(
185
+ wp=wp,
186
+ points=coordinates,
187
+ feature_extractor=self.feature_extractor,
188
+ fc_head=self.fc_head,
189
+ rendering_options=self.rendering_kwargs,
190
+ ref_representation=tri_planes,
191
+ post_module=self.post_module,
192
+ ray_dirs=directions)
193
+
194
+ return result
195
+
196
+ def sample_mixed(self,
197
+ coordinates,
198
+ directions,
199
+ wp,
200
+ truncation_psi=1,
201
+ truncation_cutoff=None,
202
+ update_emas=False,
203
+ **synthesis_kwargs):
204
+ # Same as function `self.sample()`, but expects latent vectors 'wp'
205
+ # instead of Gaussian noise 'z'.
206
+ tri_planes = self.backbone.synthesis(wp,
207
+ update_emas=update_emas,
208
+ **synthesis_kwargs)
209
+ tri_planes = tri_planes.view(len(tri_planes), 3, -1,
210
+ tri_planes.shape[-2],
211
+ tri_planes.shape[-1])
212
+
213
+ result = self.renderer.get_sigma_rgb(
214
+ wp=wp,
215
+ points=coordinates,
216
+ feature_extractor=self.feature_extractor,
217
+ fc_head=self.fc_head,
218
+ rendering_options=self.rendering_kwargs,
219
+ ref_representation=tri_planes,
220
+ post_module=self.post_module,
221
+ ray_dirs=directions)
222
+
223
+ return result
224
+
225
+ def forward(self,
226
+ z,
227
+ c,
228
+ c_swapped=None, # `c_swapped` is swapped pose conditioning.
229
+ style_mixing_prob=0,
230
+ truncation_psi=1,
231
+ truncation_cutoff=None,
232
+ neural_rendering_resolution=None,
233
+ update_emas=False,
234
+ sample_mixed=False,
235
+ coordinates=None,
236
+ **synthesis_kwargs):
237
+
238
+ # Render a batch of generated images.
239
+ c_wp = c.clone()
240
+ if c_swapped is not None:
241
+ c_wp = c_swapped.clone()
242
+ wp = self.mapping(z,
243
+ c_wp,
244
+ truncation_psi=truncation_psi,
245
+ truncation_cutoff=truncation_cutoff,
246
+ update_emas=update_emas)
247
+ if style_mixing_prob > 0:
248
+ cutoff = torch.empty([], dtype=torch.int64,
249
+ device=wp.device).random_(1, wp.shape[1])
250
+ cutoff = torch.where(
251
+ torch.rand([], device=wp.device) < style_mixing_prob,
252
+ cutoff, torch.full_like(cutoff, wp.shape[1]))
253
+ wp[:, cutoff:] = self.mapping(torch.randn_like(z),
254
+ c,
255
+ update_emas=update_emas)[:, cutoff:]
256
+ if not sample_mixed:
257
+ gen_output = self.synthesis(
258
+ wp,
259
+ c,
260
+ update_emas=update_emas,
261
+ neural_rendering_resolution=neural_rendering_resolution,
262
+ **synthesis_kwargs)
263
+
264
+ return {
265
+ 'wp': wp,
266
+ 'gen_output': gen_output,
267
+ }
268
+
269
+ else:
270
+ # Only for density regularization in training process.
271
+ assert coordinates is not None
272
+ sample_sigma = self.sample_mixed(coordinates,
273
+ torch.randn_like(coordinates),
274
+ wp,
275
+ update_emas=False)['sigma']
276
+
277
+ return {
278
+ 'wp': wp,
279
+ 'sample_sigma': sample_sigma
280
+ }
281
+
282
+
283
+ class OSGDecoder(nn.Module):
284
+ """Defines fully-connected layer head in EG3D."""
285
+ def __init__(self, n_features, options):
286
+ super().__init__()
287
+ self.hidden_dim = 64
288
+
289
+ self.net = nn.Sequential(
290
+ FullyConnectedLayer(n_features,
291
+ self.hidden_dim,
292
+ lr_multiplier=options['decoder_lr_mul']),
293
+ nn.Softplus(),
294
+ FullyConnectedLayer(self.hidden_dim,
295
+ 1 + options['decoder_output_dim'],
296
+ lr_multiplier=options['decoder_lr_mul']))
297
+
298
+ def forward(self, point_features, wp=None, dirs=None):
299
+ # Aggregate features
300
+ # point_features.shape: [N, 3, M, C].
301
+ # Average across 'X, Y, Z' planes.
302
+ point_features = point_features.mean(1)
303
+ x = point_features
304
+
305
+ N, M, C = x.shape
306
+ x = x.view(N * M, C)
307
+
308
+ x = self.net(x)
309
+ x = x.view(N, M, -1)
310
+
311
+ # Uses sigmoid clamping from MipNeRF
312
+ rgb = torch.sigmoid(x[..., 1:]) * (1 + 2 * 0.001) - 0.001
313
+ sigma = x[..., 0:1]
314
+
315
+ return {'rgb': rgb, 'sigma': sigma}
models/eg3d_generator_fv.py ADDED
@@ -0,0 +1,320 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # python3.8
2
+ """Contains the implementation of generator described in EG3D."""
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import numpy as np
7
+ from models.utils.official_stylegan2_model_helper import MappingNetwork
8
+ from models.utils.official_stylegan2_model_helper import FullyConnectedLayer
9
+ from models.utils.eg3d_superres import SuperresolutionHybrid2X
10
+ from models.utils.eg3d_superres import SuperresolutionHybrid4X
11
+ from models.utils.eg3d_superres import SuperresolutionHybrid8XDC
12
+ from models.rendering.renderer import Renderer
13
+ from models.rendering.feature_extractor import FeatureExtractor
14
+ from models.volumegan_generator import FeatureVolume
15
+ from models.volumegan_generator import PositionEncoder
16
+
17
+
18
+ class EG3DGeneratorFV(nn.Module):
19
+
20
+ def __init__(
21
+ self,
22
+ # Input latent (Z) dimensionality.
23
+ z_dim,
24
+ # Conditioning label (C) dimensionality.
25
+ c_dim,
26
+ # Intermediate latent (W) dimensionality.
27
+ w_dim,
28
+ # Final output image resolution.
29
+ img_resolution,
30
+ # Number of output color channels.
31
+ img_channels,
32
+ # Number of fp16 layers of SR Network.
33
+ sr_num_fp16_res=0,
34
+ # Arguments for MappingNetwork.
35
+ mapping_kwargs={},
36
+ # Arguments for rendering.
37
+ rendering_kwargs={},
38
+ # Arguments for SuperResolution Network.
39
+ sr_kwargs={},
40
+ # Configs for FeatureVolume.
41
+ fv_cfg=dict(feat_res=32,
42
+ init_res=4,
43
+ base_channels=256,
44
+ output_channels=32,
45
+ w_dim=512),
46
+ # Configs for position encoder.
47
+ embed_cfg=dict(input_dim=3, max_freq_log2=10 - 1, N_freqs=10),
48
+ ):
49
+ super().__init__()
50
+ self.z_dim = z_dim
51
+ self.c_dim = c_dim
52
+ self.w_dim = w_dim
53
+ self.img_resolution = img_resolution
54
+ self.img_channels = img_channels
55
+
56
+ # Set up mapping network.
57
+ # Here `num_ws = 2`: one for FeatureVolume Network injection and one for
58
+ # post_neural_renderer injection.
59
+ num_ws = 2
60
+ self.mapping_network = MappingNetwork(z_dim=z_dim,
61
+ c_dim=c_dim,
62
+ w_dim=w_dim,
63
+ num_ws=num_ws,
64
+ **mapping_kwargs)
65
+
66
+ # Set up the overall renderer.
67
+ self.renderer = Renderer()
68
+
69
+ # Set up the feature extractor.
70
+ self.feature_extractor = FeatureExtractor(ref_mode='feature_volume')
71
+
72
+ # Set up the reference representation generator.
73
+ self.ref_representation_generator = FeatureVolume(**fv_cfg)
74
+
75
+ # Set up the position encoder.
76
+ self.position_encoder = PositionEncoder(**embed_cfg)
77
+
78
+ # Set up the post module in the feature extractor.
79
+ self.post_module = None
80
+
81
+ # Set up the post neural renderer.
82
+ self.post_neural_renderer = None
83
+ sr_kwargs_total = dict(
84
+ channels=32,
85
+ img_resolution=img_resolution,
86
+ sr_num_fp16_res=sr_num_fp16_res,
87
+ sr_antialias=rendering_kwargs['sr_antialias'],)
88
+ sr_kwargs_total.update(**sr_kwargs)
89
+ if img_resolution == 128:
90
+ self.post_neural_renderer = SuperresolutionHybrid2X(
91
+ **sr_kwargs_total)
92
+ elif img_resolution == 256:
93
+ self.post_neural_renderer = SuperresolutionHybrid4X(
94
+ **sr_kwargs_total)
95
+ elif img_resolution == 512:
96
+ self.post_neural_renderer = SuperresolutionHybrid8XDC(
97
+ **sr_kwargs_total)
98
+ else:
99
+ raise TypeError(f'Unsupported image resolution: {img_resolution}!')
100
+
101
+ # Set up the fully-connected layer head.
102
+ self.fc_head = OSGDecoder(
103
+ 32, {
104
+ 'decoder_lr_mul': rendering_kwargs.get('decoder_lr_mul', 1),
105
+ 'decoder_output_dim': 32
106
+ })
107
+
108
+ # Set up some rendering related arguments.
109
+ self.neural_rendering_resolution = rendering_kwargs.get(
110
+ 'resolution', 64)
111
+ self.rendering_kwargs = rendering_kwargs
112
+
113
+ def mapping(self,
114
+ z,
115
+ c,
116
+ truncation_psi=1,
117
+ truncation_cutoff=None,
118
+ update_emas=False):
119
+ if self.rendering_kwargs['c_gen_conditioning_zero']:
120
+ c = torch.zeros_like(c)
121
+ return self.mapping_network(z,
122
+ c *
123
+ self.rendering_kwargs.get('c_scale', 0),
124
+ truncation_psi=truncation_psi,
125
+ truncation_cutoff=truncation_cutoff,
126
+ update_emas=update_emas)
127
+
128
+ def synthesis(self,
129
+ wp,
130
+ c,
131
+ neural_rendering_resolution=None,
132
+ update_emas=False,
133
+ **synthesis_kwargs):
134
+ cam2world_matrix = c[:, :16].view(-1, 4, 4)
135
+ if self.rendering_kwargs.get('random_pose', False):
136
+ cam2world_matrix = None
137
+
138
+ if neural_rendering_resolution is None:
139
+ neural_rendering_resolution = self.neural_rendering_resolution
140
+ else:
141
+ self.neural_rendering_resolution = neural_rendering_resolution
142
+
143
+ feature_volume = self.ref_representation_generator(wp)
144
+
145
+ rendering_result = self.renderer(
146
+ wp=wp,
147
+ feature_extractor=self.feature_extractor,
148
+ rendering_options=self.rendering_kwargs,
149
+ cam2world_matrix=cam2world_matrix,
150
+ position_encoder=self.position_encoder,
151
+ ref_representation=feature_volume,
152
+ post_module=self.post_module,
153
+ fc_head=self.fc_head)
154
+
155
+ feature_samples = rendering_result['composite_rgb']
156
+ depth_samples = rendering_result['composite_depth']
157
+
158
+ # Reshape to keep consistent with 'raw' neural-rendered image.
159
+ N = wp.shape[0]
160
+ H = W = self.neural_rendering_resolution
161
+ feature_image = feature_samples.permute(0, 2, 1).reshape(
162
+ N, feature_samples.shape[-1], H, W).contiguous()
163
+ depth_image = depth_samples.permute(0, 2, 1).reshape(N, 1, H, W)
164
+
165
+ # Run the post neural renderer to get final image.
166
+ # Here, the post neural renderer is a super-resolution network.
167
+ rgb_image = feature_image[:, :3]
168
+ sr_image = self.post_neural_renderer(
169
+ rgb_image,
170
+ feature_image,
171
+ wp,
172
+ noise_mode=self.rendering_kwargs['superresolution_noise_mode'],
173
+ **{
174
+ k: synthesis_kwargs[k]
175
+ for k in synthesis_kwargs.keys() if k != 'noise_mode'
176
+ })
177
+
178
+ return {
179
+ 'image': sr_image,
180
+ 'image_raw': rgb_image,
181
+ 'image_depth': depth_image
182
+ }
183
+
184
+ def sample(self,
185
+ coordinates,
186
+ directions,
187
+ z,
188
+ c,
189
+ truncation_psi=1,
190
+ truncation_cutoff=None,
191
+ update_emas=False):
192
+ # Compute RGB features, density for arbitrary 3D coordinates.
193
+ # Mostly used for extracting shapes.
194
+ wp = self.mapping_network(z,
195
+ c,
196
+ truncation_psi=truncation_psi,
197
+ truncation_cutoff=truncation_cutoff,
198
+ update_emas=update_emas)
199
+ feature_volume = self.ref_representation_generator(wp)
200
+ result = self.renderer.get_sigma_rgb(
201
+ wp=wp,
202
+ points=coordinates,
203
+ feature_extractor=self.feature_extractor,
204
+ fc_head=self.fc_head,
205
+ rendering_options=self.rendering_kwargs,
206
+ ref_representation=feature_volume,
207
+ position_encoder=self.position_encoder,
208
+ post_module=self.post_module,
209
+ ray_dirs=directions)
210
+
211
+ return result
212
+
213
+ def sample_mixed(self,
214
+ coordinates,
215
+ directions,
216
+ wp):
217
+ # Same as function `self.sample()`, but expects latent vectors 'wp'
218
+ # instead of Gaussian noise 'z'.
219
+ feature_volume = self.ref_representation_generator(wp)
220
+ result = self.renderer.get_sigma_rgb(
221
+ wp=wp,
222
+ points=coordinates,
223
+ feature_extractor=self.feature_extractor,
224
+ fc_head=self.fc_head,
225
+ rendering_options=self.rendering_kwargs,
226
+ ref_representation=feature_volume,
227
+ position_encoder=self.position_encoder,
228
+ post_module=self.post_module,
229
+ ray_dirs=directions)
230
+
231
+ return result
232
+
233
+ def forward(self,
234
+ z,
235
+ c,
236
+ c_swapped=None, # `c_swapped` is swapped pose conditioning.
237
+ style_mixing_prob=0,
238
+ truncation_psi=1,
239
+ truncation_cutoff=None,
240
+ neural_rendering_resolution=None,
241
+ update_emas=False,
242
+ sample_mixed=False,
243
+ coordinates=None,
244
+ **synthesis_kwargs):
245
+
246
+ # Render a batch of generated images.
247
+ c_wp = c.clone()
248
+ if c_swapped is not None:
249
+ c_wp = c_swapped.clone()
250
+ wp = self.mapping_network(z,
251
+ c_wp,
252
+ truncation_psi=truncation_psi,
253
+ truncation_cutoff=truncation_cutoff,
254
+ update_emas=update_emas)
255
+ if style_mixing_prob > 0:
256
+ cutoff = torch.empty([], dtype=torch.int64,
257
+ device=wp.device).random_(1, wp.shape[1])
258
+ cutoff = torch.where(
259
+ torch.rand([], device=wp.device) < style_mixing_prob, cutoff,
260
+ torch.full_like(cutoff, wp.shape[1]))
261
+ wp[:, cutoff:] = self.mapping_network(
262
+ torch.randn_like(z), c, update_emas=update_emas)[:, cutoff:]
263
+ if not sample_mixed:
264
+ gen_output = self.synthesis(
265
+ wp,
266
+ c,
267
+ update_emas=update_emas,
268
+ neural_rendering_resolution=neural_rendering_resolution,
269
+ **synthesis_kwargs)
270
+
271
+ return {
272
+ 'wp': wp,
273
+ 'gen_output': gen_output,
274
+ }
275
+
276
+ else:
277
+ # Only for density regularization in training process.
278
+ assert coordinates is not None
279
+ sample_sigma = self.sample_mixed(coordinates,
280
+ torch.randn_like(coordinates),
281
+ wp)['sigma']
282
+
283
+ return {
284
+ 'wp': wp,
285
+ 'sample_sigma': sample_sigma
286
+ }
287
+
288
+
289
+ class OSGDecoder(nn.Module):
290
+ """Defines fully-connected layer head in EG3D."""
291
+ def __init__(self, n_features, options):
292
+ super().__init__()
293
+ self.hidden_dim = 64
294
+
295
+ self.net = nn.Sequential(
296
+ FullyConnectedLayer(n_features,
297
+ self.hidden_dim,
298
+ lr_multiplier=options['decoder_lr_mul']),
299
+ nn.Softplus(),
300
+ FullyConnectedLayer(self.hidden_dim,
301
+ 1 + options['decoder_output_dim'],
302
+ lr_multiplier=options['decoder_lr_mul']))
303
+
304
+ def forward(self, point_features, wp=None, dirs=None):
305
+ # point_features.shape: [N, C, M, 1].
306
+ point_features = point_features.squeeze(-1)
307
+ point_features = point_features.permute(0, 2, 1)
308
+ x = point_features
309
+
310
+ N, M, C = x.shape
311
+ x = x.reshape(N * M, C)
312
+
313
+ x = self.net(x)
314
+ x = x.reshape(N, M, -1)
315
+
316
+ # Uses sigmoid clamping from MipNeRF
317
+ rgb = torch.sigmoid(x[..., 1:]) * (1 + 2 * 0.001) - 0.001
318
+ sigma = x[..., 0:1]
319
+
320
+ return {'rgb': rgb, 'sigma': sigma}
models/ghfeat_encoder.py ADDED
@@ -0,0 +1,563 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # python3.7
2
+ """Contains the implementation of encoder used in GH-Feat (including IDInvert).
3
+
4
+ ResNet is used as the backbone.
5
+
6
+ GH-Feat paper: https://arxiv.org/pdf/2007.10379.pdf
7
+ IDInvert paper: https://arxiv.org/pdf/2004.00049.pdf
8
+
9
+ NOTE: Please use `latent_num` and `num_latents_per_head` to control the
10
+ inversion space, such as Y-space used in GH-Feat and W-space used in IDInvert.
11
+ In addition, IDInvert sets `use_fpn` and `use_sam` as `False` by default.
12
+ """
13
+
14
+ import numpy as np
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ import torch.nn.functional as F
19
+ import torch.distributed as dist
20
+
21
+ __all__ = ['GHFeatEncoder']
22
+
23
+ # Resolutions allowed.
24
+ _RESOLUTIONS_ALLOWED = [8, 16, 32, 64, 128, 256, 512, 1024]
25
+
26
+ # pylint: disable=missing-function-docstring
27
+
28
+ class BasicBlock(nn.Module):
29
+ """Implementation of ResNet BasicBlock."""
30
+
31
+ expansion = 1
32
+
33
+ def __init__(self,
34
+ inplanes,
35
+ planes,
36
+ base_width=64,
37
+ stride=1,
38
+ groups=1,
39
+ dilation=1,
40
+ norm_layer=None,
41
+ downsample=None):
42
+ super().__init__()
43
+ if base_width != 64:
44
+ raise ValueError(f'BasicBlock of ResNet only supports '
45
+ f'`base_width=64`, but {base_width} received!')
46
+ if stride not in [1, 2]:
47
+ raise ValueError(f'BasicBlock of ResNet only supports `stride=1` '
48
+ f'and `stride=2`, but {stride} received!')
49
+ if groups != 1:
50
+ raise ValueError(f'BasicBlock of ResNet only supports `groups=1`, '
51
+ f'but {groups} received!')
52
+ if dilation != 1:
53
+ raise ValueError(f'BasicBlock of ResNet only supports '
54
+ f'`dilation=1`, but {dilation} received!')
55
+ assert self.expansion == 1
56
+
57
+ self.stride = stride
58
+ if norm_layer is None:
59
+ norm_layer = nn.BatchNorm2d
60
+ self.conv1 = nn.Conv2d(in_channels=inplanes,
61
+ out_channels=planes,
62
+ kernel_size=3,
63
+ stride=stride,
64
+ padding=1,
65
+ groups=1,
66
+ dilation=1,
67
+ bias=False)
68
+ self.bn1 = norm_layer(planes)
69
+ self.relu = nn.ReLU(inplace=True)
70
+ self.conv2 = nn.Conv2d(in_channels=planes,
71
+ out_channels=planes,
72
+ kernel_size=3,
73
+ stride=1,
74
+ padding=1,
75
+ groups=1,
76
+ dilation=1,
77
+ bias=False)
78
+ self.bn2 = norm_layer(planes)
79
+ self.downsample = downsample
80
+
81
+ def forward(self, x):
82
+ identity = self.downsample(x) if self.downsample is not None else x
83
+
84
+ out = self.conv1(x)
85
+ out = self.bn1(out)
86
+ out = self.relu(out)
87
+
88
+ out = self.conv2(out)
89
+ out = self.bn2(out)
90
+ out = self.relu(out + identity)
91
+
92
+ return out
93
+
94
+
95
+ class Bottleneck(nn.Module):
96
+ """Implementation of ResNet Bottleneck."""
97
+
98
+ expansion = 4
99
+
100
+ def __init__(self,
101
+ inplanes,
102
+ planes,
103
+ base_width=64,
104
+ stride=1,
105
+ groups=1,
106
+ dilation=1,
107
+ norm_layer=None,
108
+ downsample=None):
109
+ super().__init__()
110
+ if stride not in [1, 2]:
111
+ raise ValueError(f'Bottleneck of ResNet only supports `stride=1` '
112
+ f'and `stride=2`, but {stride} received!')
113
+
114
+ width = int(planes * (base_width / 64)) * groups
115
+ self.stride = stride
116
+ if norm_layer is None:
117
+ norm_layer = nn.BatchNorm2d
118
+ self.conv1 = nn.Conv2d(in_channels=inplanes,
119
+ out_channels=width,
120
+ kernel_size=1,
121
+ stride=1,
122
+ padding=0,
123
+ dilation=1,
124
+ groups=1,
125
+ bias=False)
126
+ self.bn1 = norm_layer(width)
127
+ self.conv2 = nn.Conv2d(in_channels=width,
128
+ out_channels=width,
129
+ kernel_size=3,
130
+ stride=stride,
131
+ padding=dilation,
132
+ groups=groups,
133
+ dilation=dilation,
134
+ bias=False)
135
+ self.bn2 = norm_layer(width)
136
+ self.conv3 = nn.Conv2d(in_channels=width,
137
+ out_channels=planes * self.expansion,
138
+ kernel_size=1,
139
+ stride=1,
140
+ padding=0,
141
+ dilation=1,
142
+ groups=1,
143
+ bias=False)
144
+ self.bn3 = norm_layer(planes * self.expansion)
145
+ self.relu = nn.ReLU(inplace=True)
146
+ self.downsample = downsample
147
+
148
+ def forward(self, x):
149
+ identity = self.downsample(x) if self.downsample is not None else x
150
+
151
+ out = self.conv1(x)
152
+ out = self.bn1(out)
153
+ out = self.relu(out)
154
+
155
+ out = self.conv2(out)
156
+ out = self.bn2(out)
157
+ out = self.relu(out)
158
+
159
+ out = self.conv3(out)
160
+ out = self.bn3(out)
161
+ out = self.relu(out + identity)
162
+
163
+ return out
164
+
165
+
166
+ class GHFeatEncoder(nn.Module):
167
+ """Define the ResNet-based encoder network for GAN inversion.
168
+
169
+ On top of the backbone, there are several task-heads to produce inverted
170
+ codes. Please use `latent_dim` and `num_latents_per_head` to define the
171
+ structure. For example, `latent_dim = [512] * 14` and
172
+ `num_latents_per_head = [4, 4, 6]` can be used for StyleGAN inversion with
173
+ 14-layer latent codes, where 3 task heads (corresponding to 4, 4, 6 layers,
174
+ respectively) are used.
175
+
176
+ Settings for the encoder network:
177
+
178
+ (1) resolution: The resolution of the output image.
179
+ (2) latent_dim: Dimension of the latent space. A number (one code will be
180
+ produced), or a list of numbers regarding layer-wise latent codes.
181
+ (3) num_latents_per_head: Number of latents that is produced by each head.
182
+ (4) image_channels: Number of channels of the output image. (default: 3)
183
+ (5) final_res: Final resolution of the convolutional layers. (default: 4)
184
+
185
+ ResNet-related settings:
186
+
187
+ (1) network_depth: Depth of the network, like 18 for ResNet18. (default: 18)
188
+ (2) inplanes: Number of channels of the first convolutional layer.
189
+ (default: 64)
190
+ (3) groups: Groups of the convolution, used in ResNet. (default: 1)
191
+ (4) width_per_group: Number of channels per group, used in ResNet.
192
+ (default: 64)
193
+ (5) replace_stride_with_dilation: Whether to replace stride with dilation,
194
+ used in ResNet. (default: None)
195
+ (6) norm_layer: Normalization layer used in the encoder. If set as `None`,
196
+ `nn.BatchNorm2d` will be used. Also, please NOTE that when using batch
197
+ normalization, the batch size is required to be larger than one for
198
+ training. (default: nn.BatchNorm2d)
199
+ (7) max_channels: Maximum number of channels in each layer. (default: 512)
200
+
201
+ Task-head related settings:
202
+
203
+ (1) use_fpn: Whether to use Feature Pyramid Network (FPN) before outputting
204
+ the latent code. (default: True)
205
+ (2) fpn_channels: Number of channels used in FPN. (default: 512)
206
+ (3) use_sam: Whether to use Spatial Alignment Module (SAM) before outputting
207
+ the latent code. (default: True)
208
+ (4) sam_channels: Number of channels used in SAM. (default: 512)
209
+ """
210
+
211
+ arch_settings = {
212
+ 18: (BasicBlock, [2, 2, 2, 2]),
213
+ 34: (BasicBlock, [3, 4, 6, 3]),
214
+ 50: (Bottleneck, [3, 4, 6, 3]),
215
+ 101: (Bottleneck, [3, 4, 23, 3]),
216
+ 152: (Bottleneck, [3, 8, 36, 3])
217
+ }
218
+
219
+ def __init__(self,
220
+ resolution,
221
+ latent_dim,
222
+ num_latents_per_head,
223
+ image_channels=3,
224
+ final_res=4,
225
+ network_depth=18,
226
+ inplanes=64,
227
+ groups=1,
228
+ width_per_group=64,
229
+ replace_stride_with_dilation=None,
230
+ norm_layer=nn.BatchNorm2d,
231
+ max_channels=512,
232
+ use_fpn=True,
233
+ fpn_channels=512,
234
+ use_sam=True,
235
+ sam_channels=512):
236
+ super().__init__()
237
+
238
+ if resolution not in _RESOLUTIONS_ALLOWED:
239
+ raise ValueError(f'Invalid resolution: `{resolution}`!\n'
240
+ f'Resolutions allowed: {_RESOLUTIONS_ALLOWED}.')
241
+ if network_depth not in self.arch_settings:
242
+ raise ValueError(f'Invalid network depth: `{network_depth}`!\n'
243
+ f'Options allowed: '
244
+ f'{list(self.arch_settings.keys())}.')
245
+ if isinstance(latent_dim, int):
246
+ latent_dim = [latent_dim]
247
+ assert isinstance(latent_dim, (list, tuple))
248
+ assert isinstance(num_latents_per_head, (list, tuple))
249
+ assert sum(num_latents_per_head) == len(latent_dim)
250
+
251
+ self.resolution = resolution
252
+ self.latent_dim = latent_dim
253
+ self.num_latents_per_head = num_latents_per_head
254
+ self.num_heads = len(self.num_latents_per_head)
255
+ self.image_channels = image_channels
256
+ self.final_res = final_res
257
+ self.inplanes = inplanes
258
+ self.network_depth = network_depth
259
+ self.groups = groups
260
+ self.dilation = 1
261
+ self.base_width = width_per_group
262
+ self.replace_stride_with_dilation = replace_stride_with_dilation
263
+ if norm_layer is None:
264
+ norm_layer = nn.BatchNorm2d
265
+ if norm_layer == nn.BatchNorm2d and dist.is_initialized():
266
+ norm_layer = nn.SyncBatchNorm
267
+ self.norm_layer = norm_layer
268
+ self.max_channels = max_channels
269
+ self.use_fpn = use_fpn
270
+ self.fpn_channels = fpn_channels
271
+ self.use_sam = use_sam
272
+ self.sam_channels = sam_channels
273
+
274
+ block_fn, num_blocks_per_stage = self.arch_settings[network_depth]
275
+
276
+ self.num_stages = int(np.log2(resolution // final_res)) - 1
277
+ # Add one block for additional stages.
278
+ for i in range(len(num_blocks_per_stage), self.num_stages):
279
+ num_blocks_per_stage.append(1)
280
+ if replace_stride_with_dilation is None:
281
+ replace_stride_with_dilation = [False] * self.num_stages
282
+
283
+ # Backbone.
284
+ self.conv1 = nn.Conv2d(in_channels=self.image_channels,
285
+ out_channels=self.inplanes,
286
+ kernel_size=7,
287
+ stride=2,
288
+ padding=3,
289
+ bias=False)
290
+ self.bn1 = norm_layer(self.inplanes)
291
+ self.relu = nn.ReLU(inplace=True)
292
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
293
+
294
+ self.stage_channels = [self.inplanes]
295
+ self.stages = nn.ModuleList()
296
+ for i in range(self.num_stages):
297
+ inplanes = self.inplanes if i == 0 else planes * block_fn.expansion
298
+ planes = min(self.max_channels, self.inplanes * (2 ** i))
299
+ num_blocks = num_blocks_per_stage[i]
300
+ stride = 1 if i == 0 else 2
301
+ dilate = replace_stride_with_dilation[i]
302
+ self.stages.append(self._make_stage(block_fn=block_fn,
303
+ inplanes=inplanes,
304
+ planes=planes,
305
+ num_blocks=num_blocks,
306
+ stride=stride,
307
+ dilate=dilate))
308
+ self.stage_channels.append(planes * block_fn.expansion)
309
+
310
+ if self.num_heads > len(self.stage_channels):
311
+ raise ValueError('Number of task heads is larger than number of '
312
+ 'stages! Please reduce the number of heads.')
313
+
314
+ # Task-head.
315
+ if self.num_heads == 1:
316
+ self.use_fpn = False
317
+ self.use_sam = False
318
+
319
+ if self.use_fpn:
320
+ fpn_pyramid_channels = self.stage_channels[-self.num_heads:]
321
+ self.fpn = FPN(pyramid_channels=fpn_pyramid_channels,
322
+ out_channels=self.fpn_channels)
323
+ if self.use_sam:
324
+ if self.use_fpn:
325
+ sam_pyramid_channels = [self.fpn_channels] * self.num_heads
326
+ else:
327
+ sam_pyramid_channels = self.stage_channels[-self.num_heads:]
328
+ self.sam = SAM(pyramid_channels=sam_pyramid_channels,
329
+ out_channels=self.sam_channels)
330
+
331
+ self.heads = nn.ModuleList()
332
+ for head_idx in range(self.num_heads):
333
+ # Parse in_channels.
334
+ if self.use_sam:
335
+ in_channels = self.sam_channels
336
+ elif self.use_fpn:
337
+ in_channels = self.fpn_channels
338
+ else:
339
+ in_channels = self.stage_channels[head_idx - self.num_heads]
340
+ in_channels = in_channels * final_res * final_res
341
+
342
+ # Parse out_channels.
343
+ start_latent_idx = sum(self.num_latents_per_head[:head_idx])
344
+ end_latent_idx = sum(self.num_latents_per_head[:head_idx + 1])
345
+ out_channels = sum(self.latent_dim[start_latent_idx:end_latent_idx])
346
+
347
+ self.heads.append(CodeHead(in_channels=in_channels,
348
+ out_channels=out_channels,
349
+ norm_layer=self.norm_layer))
350
+
351
+ def _make_stage(self,
352
+ block_fn,
353
+ inplanes,
354
+ planes,
355
+ num_blocks,
356
+ stride,
357
+ dilate):
358
+ norm_layer = self.norm_layer
359
+ downsample = None
360
+ previous_dilation = self.dilation
361
+ if dilate:
362
+ self.dilation *= stride
363
+ stride = 1
364
+ if stride != 1 or inplanes != planes * block_fn.expansion:
365
+ downsample = nn.Sequential(
366
+ nn.Conv2d(in_channels=inplanes,
367
+ out_channels=planes * block_fn.expansion,
368
+ kernel_size=1,
369
+ stride=stride,
370
+ padding=0,
371
+ dilation=1,
372
+ groups=1,
373
+ bias=False),
374
+ norm_layer(planes * block_fn.expansion),
375
+ )
376
+
377
+ blocks = []
378
+ blocks.append(block_fn(inplanes=inplanes,
379
+ planes=planes,
380
+ base_width=self.base_width,
381
+ stride=stride,
382
+ groups=self.groups,
383
+ dilation=previous_dilation,
384
+ norm_layer=norm_layer,
385
+ downsample=downsample))
386
+ for _ in range(1, num_blocks):
387
+ blocks.append(block_fn(inplanes=planes * block_fn.expansion,
388
+ planes=planes,
389
+ base_width=self.base_width,
390
+ stride=1,
391
+ groups=self.groups,
392
+ dilation=self.dilation,
393
+ norm_layer=norm_layer,
394
+ downsample=None))
395
+
396
+ return nn.Sequential(*blocks)
397
+
398
+ def forward(self, x):
399
+ x = self.conv1(x)
400
+ x = self.bn1(x)
401
+ x = self.relu(x)
402
+ x = self.maxpool(x)
403
+
404
+ features = [x]
405
+ for i in range(self.num_stages):
406
+ x = self.stages[i](x)
407
+ features.append(x)
408
+ features = features[-self.num_heads:]
409
+
410
+ if self.use_fpn:
411
+ features = self.fpn(features)
412
+ if self.use_sam:
413
+ features = self.sam(features)
414
+ else:
415
+ final_size = features[-1].shape[2:]
416
+ for i in range(self.num_heads - 1):
417
+ features[i] = F.adaptive_avg_pool2d(features[i], final_size)
418
+
419
+ outputs = []
420
+ for head_idx in range(self.num_heads):
421
+ codes = self.heads[head_idx](features[head_idx])
422
+ start_latent_idx = sum(self.num_latents_per_head[:head_idx])
423
+ end_latent_idx = sum(self.num_latents_per_head[:head_idx + 1])
424
+ split_size = self.latent_dim[start_latent_idx:end_latent_idx]
425
+ outputs.extend(torch.split(codes, split_size, dim=1))
426
+ max_dim = max(self.latent_dim)
427
+ for i, dim in enumerate(self.latent_dim):
428
+ if dim < max_dim:
429
+ outputs[i] = F.pad(outputs[i], (0, max_dim - dim))
430
+ outputs[i] = outputs[i].unsqueeze(1)
431
+
432
+ return torch.cat(outputs, dim=1)
433
+
434
+
435
+ class FPN(nn.Module):
436
+ """Implementation of Feature Pyramid Network (FPN).
437
+
438
+ The input of this module is a pyramid of features with reducing resolutions.
439
+ Then, this module fuses these multi-level features from `top_level` to
440
+ `bottom_level`. In particular, starting from the `top_level`, each feature
441
+ is convoluted, upsampled, and fused into its previous feature (which is also
442
+ convoluted).
443
+
444
+ Args:
445
+ pyramid_channels: A list of integers, each of which indicates the number
446
+ of channels of the feature from a particular level.
447
+ out_channels: Number of channels for each output.
448
+
449
+ Returns:
450
+ A list of feature maps, each of which has `out_channels` channels.
451
+ """
452
+
453
+ def __init__(self, pyramid_channels, out_channels):
454
+ super().__init__()
455
+ assert isinstance(pyramid_channels, (list, tuple))
456
+ self.num_levels = len(pyramid_channels)
457
+
458
+ self.lateral_layers = nn.ModuleList()
459
+ self.feature_layers = nn.ModuleList()
460
+ for i in range(self.num_levels):
461
+ in_channels = pyramid_channels[i]
462
+ self.lateral_layers.append(nn.Conv2d(in_channels=in_channels,
463
+ out_channels=out_channels,
464
+ kernel_size=3,
465
+ padding=1,
466
+ bias=True))
467
+ self.feature_layers.append(nn.Conv2d(in_channels=out_channels,
468
+ out_channels=out_channels,
469
+ kernel_size=3,
470
+ padding=1,
471
+ bias=True))
472
+
473
+ def forward(self, inputs):
474
+ if len(inputs) != self.num_levels:
475
+ raise ValueError('Number of inputs and `num_levels` mismatch!')
476
+
477
+ # Project all related features to `out_channels`.
478
+ laterals = []
479
+ for i in range(self.num_levels):
480
+ laterals.append(self.lateral_layers[i](inputs[i]))
481
+
482
+ # Fusion, starting from `top_level`.
483
+ for i in range(self.num_levels - 1, 0, -1):
484
+ scale_factor = laterals[i - 1].shape[2] // laterals[i].shape[2]
485
+ laterals[i - 1] = (laterals[i - 1] +
486
+ F.interpolate(laterals[i],
487
+ mode='nearest',
488
+ scale_factor=scale_factor))
489
+
490
+ # Get outputs.
491
+ outputs = []
492
+ for i, lateral in enumerate(laterals):
493
+ outputs.append(self.feature_layers[i](lateral))
494
+
495
+ return outputs
496
+
497
+
498
+ class SAM(nn.Module):
499
+ """Implementation of Spatial Alignment Module (SAM).
500
+
501
+ The input of this module is a pyramid of features with reducing resolutions.
502
+ Then this module downsamples all levels of feature to the minimum resolution
503
+ and fuses it with the smallest feature map.
504
+
505
+ Args:
506
+ pyramid_channels: A list of integers, each of which indicates the number
507
+ of channels of the feature from a particular level.
508
+ out_channels: Number of channels for each output.
509
+
510
+ Returns:
511
+ A list of feature maps, each of which has `out_channels` channels.
512
+ """
513
+
514
+ def __init__(self, pyramid_channels, out_channels):
515
+ super().__init__()
516
+ assert isinstance(pyramid_channels, (list, tuple))
517
+ self.num_levels = len(pyramid_channels)
518
+
519
+ self.fusion_layers = nn.ModuleList()
520
+ for i in range(self.num_levels):
521
+ in_channels = pyramid_channels[i]
522
+ self.fusion_layers.append(nn.Conv2d(in_channels=in_channels,
523
+ out_channels=out_channels,
524
+ kernel_size=3,
525
+ padding=1,
526
+ bias=True))
527
+
528
+ def forward(self, inputs):
529
+ if len(inputs) != self.num_levels:
530
+ raise ValueError('Number of inputs and `num_levels` mismatch!')
531
+
532
+ output_res = inputs[-1].shape[2:]
533
+ for i in range(self.num_levels - 1, -1, -1):
534
+ if i != self.num_levels - 1:
535
+ inputs[i] = F.adaptive_avg_pool2d(inputs[i], output_res)
536
+ inputs[i] = self.fusion_layers[i](inputs[i])
537
+ if i != self.num_levels - 1:
538
+ inputs[i] = inputs[i] + inputs[-1]
539
+
540
+ return inputs
541
+
542
+
543
+ class CodeHead(nn.Module):
544
+ """Implementation of the task-head to produce inverted codes."""
545
+
546
+ def __init__(self, in_channels, out_channels, norm_layer):
547
+ super().__init__()
548
+ self.fc = nn.Linear(in_channels, out_channels, bias=True)
549
+ if norm_layer is None:
550
+ self.norm = nn.Identity()
551
+ else:
552
+ self.norm = norm_layer(out_channels)
553
+
554
+ def forward(self, x):
555
+ if x.ndim > 2:
556
+ x = x.flatten(start_dim=1)
557
+ latent = self.fc(x)
558
+ latent = latent.unsqueeze(2).unsqueeze(3)
559
+ latent = self.norm(latent)
560
+
561
+ return latent.flatten(start_dim=1)
562
+
563
+ # pylint: enable=missing-function-docstring
models/inception_model.py ADDED
@@ -0,0 +1,562 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # python3.7
2
+ """Contains the Inception V3 model, which is used for inference ONLY.
3
+
4
+ This file is mostly borrowed from `torchvision/models/inception.py`.
5
+
6
+ Inception model is widely used to compute FID or IS metric for evaluating
7
+ generative models. However, the pre-trained models from torchvision is slightly
8
+ different from the TensorFlow version
9
+
10
+ http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
11
+
12
+ which is used by the official FID implementation
13
+
14
+ https://github.com/bioinf-jku/TTUR
15
+
16
+ In particular:
17
+
18
+ (1) The number of classes in TensorFlow model is 1008 instead of 1000.
19
+ (2) The avg_pool() layers in TensorFlow model does not include the padded zero.
20
+ (3) The last Inception E Block in TensorFlow model use max_pool() instead of
21
+ avg_pool().
22
+
23
+ Hence, to align the evaluation results with those from TensorFlow
24
+ implementation, we modified the inception model to support both versions. Please
25
+ use `align_tf` argument to control the version.
26
+ """
27
+
28
+ import warnings
29
+
30
+ import torch
31
+ import torch.nn as nn
32
+ import torch.nn.functional as F
33
+ import torch.distributed as dist
34
+
35
+ from utils.misc import download_url
36
+
37
+ __all__ = ['InceptionModel']
38
+
39
+ # pylint: disable=line-too-long
40
+
41
+ _MODEL_URL_SHA256 = {
42
+ # This model is provided by `torchvision`, which is ported from TensorFlow.
43
+ 'torchvision_official': (
44
+ 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth',
45
+ '1a9a5a14f40645a370184bd54f4e8e631351e71399112b43ad0294a79da290c8' # hash sha256
46
+ ),
47
+
48
+ # This model is provided by https://github.com/mseitzer/pytorch-fid
49
+ 'tf_inception_v3': (
50
+ 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth',
51
+ '6726825d0af5f729cebd5821db510b11b1cfad8faad88a03f1befd49fb9129b2' # hash sha256
52
+ )
53
+ }
54
+
55
+
56
+ class InceptionModel(object):
57
+ """Defines the Inception (V3) model.
58
+
59
+ This is a static class, which is used to avoid this model to be built
60
+ repeatedly. Consequently, this model is particularly used for inference,
61
+ like computing FID. If training is required, please use the model from
62
+ `torchvision.models` or implement by yourself.
63
+
64
+ NOTE: The pre-trained model assumes the inputs to be with `RGB` channel
65
+ order and pixel range [-1, 1], and will also resize the images to shape
66
+ [299, 299] automatically. If your input is normalized by subtracting
67
+ (0.485, 0.456, 0.406) and dividing (0.229, 0.224, 0.225), please use
68
+ `transform_input` in the `forward()` function to un-normalize it.
69
+ """
70
+ models = dict()
71
+
72
+ @staticmethod
73
+ def build_model(align_tf=True):
74
+ """Builds the model and load pre-trained weights.
75
+
76
+ If `align_tf` is set as True, the model will predict 1008 classes, and
77
+ the pre-trained weight from `https://github.com/mseitzer/pytorch-fid`
78
+ will be loaded. Otherwise, the model will predict 1000 classes, and will
79
+ load the model from `torchvision`.
80
+
81
+ The built model supports following arguments when forwarding:
82
+
83
+ - transform_input: Whether to transform the input back to pixel range
84
+ (-1, 1). Please disable this argument if your input is already with
85
+ pixel range (-1, 1). (default: False)
86
+ - output_logits: Whether to output the categorical logits instead of
87
+ features. (default: False)
88
+ - remove_logits_bias: Whether to remove the bias when computing the
89
+ logits. The official implementation removes the bias by default.
90
+ Please refer to
91
+ `https://github.com/openai/improved-gan/blob/master/inception_score/model.py`.
92
+ (default: False)
93
+ - output_predictions: Whether to output the final predictions, i.e.,
94
+ `softmax(logits)`. (default: False)
95
+ """
96
+ if align_tf:
97
+ num_classes = 1008
98
+ model_source = 'tf_inception_v3'
99
+ else:
100
+ num_classes = 1000
101
+ model_source = 'torchvision_official'
102
+
103
+ fingerprint = model_source
104
+
105
+ if fingerprint not in InceptionModel.models:
106
+ # Build model.
107
+ model = Inception3(num_classes=num_classes,
108
+ aux_logits=False,
109
+ init_weights=False,
110
+ align_tf=align_tf)
111
+
112
+ # Download pre-trained weights.
113
+ if dist.is_initialized() and dist.get_rank() != 0:
114
+ dist.barrier() # Download by chief.
115
+
116
+ url, sha256 = _MODEL_URL_SHA256[model_source]
117
+ filename = f'inception_model_{model_source}_{sha256}.pth'
118
+ model_path, hash_check = download_url(url,
119
+ filename=filename,
120
+ sha256=sha256)
121
+ state_dict = torch.load(model_path, map_location='cpu')
122
+ if hash_check is False:
123
+ warnings.warn(f'Hash check failed! The remote file from URL '
124
+ f'`{url}` may be changed, or the downloading is '
125
+ f'interrupted. The loaded inception model may '
126
+ f'have unexpected behavior.')
127
+
128
+ if dist.is_initialized() and dist.get_rank() == 0:
129
+ dist.barrier() # Wait for other replicas.
130
+
131
+ # Load weights.
132
+ model.load_state_dict(state_dict, strict=False)
133
+ del state_dict
134
+
135
+ # For inference only.
136
+ model.eval().requires_grad_(False).cuda()
137
+ InceptionModel.models[fingerprint] = model
138
+
139
+ return InceptionModel.models[fingerprint]
140
+
141
+ # pylint: disable=missing-function-docstring
142
+ # pylint: disable=missing-class-docstring
143
+ # pylint: disable=super-with-arguments
144
+ # pylint: disable=consider-merging-isinstance
145
+ # pylint: disable=import-outside-toplevel
146
+ # pylint: disable=no-else-return
147
+
148
+ class Inception3(nn.Module):
149
+
150
+ def __init__(self, num_classes=1000, aux_logits=True, inception_blocks=None,
151
+ init_weights=True, align_tf=True):
152
+ super(Inception3, self).__init__()
153
+ if inception_blocks is None:
154
+ inception_blocks = [
155
+ BasicConv2d, InceptionA, InceptionB, InceptionC,
156
+ InceptionD, InceptionE, InceptionAux
157
+ ]
158
+ assert len(inception_blocks) == 7
159
+ conv_block = inception_blocks[0]
160
+ inception_a = inception_blocks[1]
161
+ inception_b = inception_blocks[2]
162
+ inception_c = inception_blocks[3]
163
+ inception_d = inception_blocks[4]
164
+ inception_e = inception_blocks[5]
165
+ inception_aux = inception_blocks[6]
166
+
167
+ self.aux_logits = aux_logits
168
+ self.align_tf = align_tf
169
+ self.Conv2d_1a_3x3 = conv_block(3, 32, kernel_size=3, stride=2)
170
+ self.Conv2d_2a_3x3 = conv_block(32, 32, kernel_size=3)
171
+ self.Conv2d_2b_3x3 = conv_block(32, 64, kernel_size=3, padding=1)
172
+ self.Conv2d_3b_1x1 = conv_block(64, 80, kernel_size=1)
173
+ self.Conv2d_4a_3x3 = conv_block(80, 192, kernel_size=3)
174
+ self.Mixed_5b = inception_a(192, pool_features=32, align_tf=self.align_tf)
175
+ self.Mixed_5c = inception_a(256, pool_features=64, align_tf=self.align_tf)
176
+ self.Mixed_5d = inception_a(288, pool_features=64, align_tf=self.align_tf)
177
+ self.Mixed_6a = inception_b(288)
178
+ self.Mixed_6b = inception_c(768, channels_7x7=128, align_tf=self.align_tf)
179
+ self.Mixed_6c = inception_c(768, channels_7x7=160, align_tf=self.align_tf)
180
+ self.Mixed_6d = inception_c(768, channels_7x7=160, align_tf=self.align_tf)
181
+ self.Mixed_6e = inception_c(768, channels_7x7=192, align_tf=self.align_tf)
182
+ if aux_logits:
183
+ self.AuxLogits = inception_aux(768, num_classes)
184
+ self.Mixed_7a = inception_d(768)
185
+ self.Mixed_7b = inception_e(1280, align_tf=self.align_tf)
186
+ self.Mixed_7c = inception_e(2048, use_max_pool=self.align_tf)
187
+ self.fc = nn.Linear(2048, num_classes)
188
+ if init_weights:
189
+ for m in self.modules():
190
+ if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
191
+ import scipy.stats as stats
192
+ stddev = m.stddev if hasattr(m, 'stddev') else 0.1
193
+ X = stats.truncnorm(-2, 2, scale=stddev)
194
+ values = torch.as_tensor(X.rvs(m.weight.numel()), dtype=m.weight.dtype)
195
+ values = values.view(m.weight.size())
196
+ with torch.no_grad():
197
+ m.weight.copy_(values)
198
+ elif isinstance(m, nn.BatchNorm2d):
199
+ nn.init.constant_(m.weight, 1)
200
+ nn.init.constant_(m.bias, 0)
201
+
202
+ @staticmethod
203
+ def _transform_input(x, transform_input=False):
204
+ if transform_input:
205
+ x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
206
+ x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
207
+ x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
208
+ x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
209
+ return x
210
+
211
+ def _forward(self,
212
+ x,
213
+ output_logits=False,
214
+ remove_logits_bias=False,
215
+ output_predictions=False):
216
+ # Upsample if necessary.
217
+ if x.shape[2] != 299 or x.shape[3] != 299:
218
+ if self.align_tf:
219
+ theta = torch.eye(2, 3).to(x)
220
+ theta[0, 2] += theta[0, 0] / x.shape[3] - theta[0, 0] / 299
221
+ theta[1, 2] += theta[1, 1] / x.shape[2] - theta[1, 1] / 299
222
+ theta = theta.unsqueeze(0).repeat(x.shape[0], 1, 1)
223
+ grid = F.affine_grid(theta,
224
+ size=(x.shape[0], x.shape[1], 299, 299),
225
+ align_corners=False)
226
+ x = F.grid_sample(x, grid,
227
+ mode='bilinear',
228
+ padding_mode='border',
229
+ align_corners=False)
230
+ else:
231
+ x = F.interpolate(
232
+ x, size=(299, 299), mode='bilinear', align_corners=False)
233
+ if x.shape[1] == 1:
234
+ x = x.repeat((1, 3, 1, 1))
235
+
236
+ if self.align_tf:
237
+ x = (x * 127.5 + 127.5 - 128) / 128
238
+
239
+ # N x 3 x 299 x 299
240
+ x = self.Conv2d_1a_3x3(x)
241
+ # N x 32 x 149 x 149
242
+ x = self.Conv2d_2a_3x3(x)
243
+ # N x 32 x 147 x 147
244
+ x = self.Conv2d_2b_3x3(x)
245
+ # N x 64 x 147 x 147
246
+ x = F.max_pool2d(x, kernel_size=3, stride=2)
247
+ # N x 64 x 73 x 73
248
+ x = self.Conv2d_3b_1x1(x)
249
+ # N x 80 x 73 x 73
250
+ x = self.Conv2d_4a_3x3(x)
251
+ # N x 192 x 71 x 71
252
+ x = F.max_pool2d(x, kernel_size=3, stride=2)
253
+ # N x 192 x 35 x 35
254
+ x = self.Mixed_5b(x)
255
+ # N x 256 x 35 x 35
256
+ x = self.Mixed_5c(x)
257
+ # N x 288 x 35 x 35
258
+ x = self.Mixed_5d(x)
259
+ # N x 288 x 35 x 35
260
+ x = self.Mixed_6a(x)
261
+ # N x 768 x 17 x 17
262
+ x = self.Mixed_6b(x)
263
+ # N x 768 x 17 x 17
264
+ x = self.Mixed_6c(x)
265
+ # N x 768 x 17 x 17
266
+ x = self.Mixed_6d(x)
267
+ # N x 768 x 17 x 17
268
+ x = self.Mixed_6e(x)
269
+ # N x 768 x 17 x 17
270
+ if self.training and self.aux_logits:
271
+ aux = self.AuxLogits(x)
272
+ else:
273
+ aux = None
274
+ # N x 768 x 17 x 17
275
+ x = self.Mixed_7a(x)
276
+ # N x 1280 x 8 x 8
277
+ x = self.Mixed_7b(x)
278
+ # N x 2048 x 8 x 8
279
+ x = self.Mixed_7c(x)
280
+ # N x 2048 x 8 x 8
281
+ # Adaptive average pooling
282
+ x = F.adaptive_avg_pool2d(x, (1, 1))
283
+ # N x 2048 x 1 x 1
284
+ x = F.dropout(x, training=self.training)
285
+ # N x 2048 x 1 x 1
286
+ x = torch.flatten(x, 1)
287
+ # N x 2048
288
+ if output_logits or output_predictions:
289
+ x = self.fc(x)
290
+ # N x 1000 (num_classes)
291
+ if remove_logits_bias:
292
+ x = x - self.fc.bias.view(1, -1)
293
+ if output_predictions:
294
+ x = F.softmax(x, dim=1)
295
+ return x, aux
296
+
297
+ def forward(self,
298
+ x,
299
+ transform_input=False,
300
+ output_logits=False,
301
+ remove_logits_bias=False,
302
+ output_predictions=False):
303
+ x = self._transform_input(x, transform_input)
304
+ x, aux = self._forward(
305
+ x, output_logits, remove_logits_bias, output_predictions)
306
+ if self.training and self.aux_logits:
307
+ return x, aux
308
+ else:
309
+ return x
310
+
311
+
312
+ class InceptionA(nn.Module):
313
+
314
+ def __init__(self, in_channels, pool_features, conv_block=None, align_tf=False):
315
+ super(InceptionA, self).__init__()
316
+ if conv_block is None:
317
+ conv_block = BasicConv2d
318
+ self.branch1x1 = conv_block(in_channels, 64, kernel_size=1)
319
+
320
+ self.branch5x5_1 = conv_block(in_channels, 48, kernel_size=1)
321
+ self.branch5x5_2 = conv_block(48, 64, kernel_size=5, padding=2)
322
+
323
+ self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
324
+ self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
325
+ self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, padding=1)
326
+
327
+ self.branch_pool = conv_block(in_channels, pool_features, kernel_size=1)
328
+ self.pool_include_padding = not align_tf
329
+
330
+ def _forward(self, x):
331
+ branch1x1 = self.branch1x1(x)
332
+
333
+ branch5x5 = self.branch5x5_1(x)
334
+ branch5x5 = self.branch5x5_2(branch5x5)
335
+
336
+ branch3x3dbl = self.branch3x3dbl_1(x)
337
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
338
+ branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
339
+
340
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
341
+ count_include_pad=self.pool_include_padding)
342
+ branch_pool = self.branch_pool(branch_pool)
343
+
344
+ outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
345
+ return outputs
346
+
347
+ def forward(self, x):
348
+ outputs = self._forward(x)
349
+ return torch.cat(outputs, 1)
350
+
351
+
352
+ class InceptionB(nn.Module):
353
+
354
+ def __init__(self, in_channels, conv_block=None):
355
+ super(InceptionB, self).__init__()
356
+ if conv_block is None:
357
+ conv_block = BasicConv2d
358
+ self.branch3x3 = conv_block(in_channels, 384, kernel_size=3, stride=2)
359
+
360
+ self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
361
+ self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
362
+ self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, stride=2)
363
+
364
+ def _forward(self, x):
365
+ branch3x3 = self.branch3x3(x)
366
+
367
+ branch3x3dbl = self.branch3x3dbl_1(x)
368
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
369
+ branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
370
+
371
+ branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)
372
+
373
+ outputs = [branch3x3, branch3x3dbl, branch_pool]
374
+ return outputs
375
+
376
+ def forward(self, x):
377
+ outputs = self._forward(x)
378
+ return torch.cat(outputs, 1)
379
+
380
+
381
+ class InceptionC(nn.Module):
382
+
383
+ def __init__(self, in_channels, channels_7x7, conv_block=None, align_tf=False):
384
+ super(InceptionC, self).__init__()
385
+ if conv_block is None:
386
+ conv_block = BasicConv2d
387
+ self.branch1x1 = conv_block(in_channels, 192, kernel_size=1)
388
+
389
+ c7 = channels_7x7
390
+ self.branch7x7_1 = conv_block(in_channels, c7, kernel_size=1)
391
+ self.branch7x7_2 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3))
392
+ self.branch7x7_3 = conv_block(c7, 192, kernel_size=(7, 1), padding=(3, 0))
393
+
394
+ self.branch7x7dbl_1 = conv_block(in_channels, c7, kernel_size=1)
395
+ self.branch7x7dbl_2 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0))
396
+ self.branch7x7dbl_3 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3))
397
+ self.branch7x7dbl_4 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0))
398
+ self.branch7x7dbl_5 = conv_block(c7, 192, kernel_size=(1, 7), padding=(0, 3))
399
+
400
+ self.branch_pool = conv_block(in_channels, 192, kernel_size=1)
401
+ self.pool_include_padding = not align_tf
402
+
403
+ def _forward(self, x):
404
+ branch1x1 = self.branch1x1(x)
405
+
406
+ branch7x7 = self.branch7x7_1(x)
407
+ branch7x7 = self.branch7x7_2(branch7x7)
408
+ branch7x7 = self.branch7x7_3(branch7x7)
409
+
410
+ branch7x7dbl = self.branch7x7dbl_1(x)
411
+ branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
412
+ branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
413
+ branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
414
+ branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
415
+
416
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
417
+ count_include_pad=self.pool_include_padding)
418
+ branch_pool = self.branch_pool(branch_pool)
419
+
420
+ outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
421
+ return outputs
422
+
423
+ def forward(self, x):
424
+ outputs = self._forward(x)
425
+ return torch.cat(outputs, 1)
426
+
427
+
428
+ class InceptionD(nn.Module):
429
+
430
+ def __init__(self, in_channels, conv_block=None):
431
+ super(InceptionD, self).__init__()
432
+ if conv_block is None:
433
+ conv_block = BasicConv2d
434
+ self.branch3x3_1 = conv_block(in_channels, 192, kernel_size=1)
435
+ self.branch3x3_2 = conv_block(192, 320, kernel_size=3, stride=2)
436
+
437
+ self.branch7x7x3_1 = conv_block(in_channels, 192, kernel_size=1)
438
+ self.branch7x7x3_2 = conv_block(192, 192, kernel_size=(1, 7), padding=(0, 3))
439
+ self.branch7x7x3_3 = conv_block(192, 192, kernel_size=(7, 1), padding=(3, 0))
440
+ self.branch7x7x3_4 = conv_block(192, 192, kernel_size=3, stride=2)
441
+
442
+ def _forward(self, x):
443
+ branch3x3 = self.branch3x3_1(x)
444
+ branch3x3 = self.branch3x3_2(branch3x3)
445
+
446
+ branch7x7x3 = self.branch7x7x3_1(x)
447
+ branch7x7x3 = self.branch7x7x3_2(branch7x7x3)
448
+ branch7x7x3 = self.branch7x7x3_3(branch7x7x3)
449
+ branch7x7x3 = self.branch7x7x3_4(branch7x7x3)
450
+
451
+ branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)
452
+ outputs = [branch3x3, branch7x7x3, branch_pool]
453
+ return outputs
454
+
455
+ def forward(self, x):
456
+ outputs = self._forward(x)
457
+ return torch.cat(outputs, 1)
458
+
459
+
460
+ class InceptionE(nn.Module):
461
+
462
+ def __init__(self, in_channels, conv_block=None, align_tf=False, use_max_pool=False):
463
+ super(InceptionE, self).__init__()
464
+ if conv_block is None:
465
+ conv_block = BasicConv2d
466
+ self.branch1x1 = conv_block(in_channels, 320, kernel_size=1)
467
+
468
+ self.branch3x3_1 = conv_block(in_channels, 384, kernel_size=1)
469
+ self.branch3x3_2a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1))
470
+ self.branch3x3_2b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0))
471
+
472
+ self.branch3x3dbl_1 = conv_block(in_channels, 448, kernel_size=1)
473
+ self.branch3x3dbl_2 = conv_block(448, 384, kernel_size=3, padding=1)
474
+ self.branch3x3dbl_3a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1))
475
+ self.branch3x3dbl_3b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0))
476
+
477
+ self.branch_pool = conv_block(in_channels, 192, kernel_size=1)
478
+ self.pool_include_padding = not align_tf
479
+ self.use_max_pool = use_max_pool
480
+
481
+ def _forward(self, x):
482
+ branch1x1 = self.branch1x1(x)
483
+
484
+ branch3x3 = self.branch3x3_1(x)
485
+ branch3x3 = [
486
+ self.branch3x3_2a(branch3x3),
487
+ self.branch3x3_2b(branch3x3),
488
+ ]
489
+ branch3x3 = torch.cat(branch3x3, 1)
490
+
491
+ branch3x3dbl = self.branch3x3dbl_1(x)
492
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
493
+ branch3x3dbl = [
494
+ self.branch3x3dbl_3a(branch3x3dbl),
495
+ self.branch3x3dbl_3b(branch3x3dbl),
496
+ ]
497
+ branch3x3dbl = torch.cat(branch3x3dbl, 1)
498
+
499
+ if self.use_max_pool:
500
+ branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
501
+ else:
502
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
503
+ count_include_pad=self.pool_include_padding)
504
+ branch_pool = self.branch_pool(branch_pool)
505
+
506
+ outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
507
+ return outputs
508
+
509
+ def forward(self, x):
510
+ outputs = self._forward(x)
511
+ return torch.cat(outputs, 1)
512
+
513
+
514
+ class InceptionAux(nn.Module):
515
+
516
+ def __init__(self, in_channels, num_classes, conv_block=None):
517
+ super(InceptionAux, self).__init__()
518
+ if conv_block is None:
519
+ conv_block = BasicConv2d
520
+ self.conv0 = conv_block(in_channels, 128, kernel_size=1)
521
+ self.conv1 = conv_block(128, 768, kernel_size=5)
522
+ self.conv1.stddev = 0.01
523
+ self.fc = nn.Linear(768, num_classes)
524
+ self.fc.stddev = 0.001
525
+
526
+ def forward(self, x):
527
+ # N x 768 x 17 x 17
528
+ x = F.avg_pool2d(x, kernel_size=5, stride=3)
529
+ # N x 768 x 5 x 5
530
+ x = self.conv0(x)
531
+ # N x 128 x 5 x 5
532
+ x = self.conv1(x)
533
+ # N x 768 x 1 x 1
534
+ # Adaptive average pooling
535
+ x = F.adaptive_avg_pool2d(x, (1, 1))
536
+ # N x 768 x 1 x 1
537
+ x = torch.flatten(x, 1)
538
+ # N x 768
539
+ x = self.fc(x)
540
+ # N x 1000
541
+ return x
542
+
543
+
544
+ class BasicConv2d(nn.Module):
545
+
546
+ def __init__(self, in_channels, out_channels, **kwargs):
547
+ super(BasicConv2d, self).__init__()
548
+ self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
549
+ self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
550
+
551
+ def forward(self, x):
552
+ x = self.conv(x)
553
+ x = self.bn(x)
554
+ return F.relu(x, inplace=True)
555
+
556
+ # pylint: enable=line-too-long
557
+ # pylint: enable=missing-function-docstring
558
+ # pylint: enable=missing-class-docstring
559
+ # pylint: enable=super-with-arguments
560
+ # pylint: enable=consider-merging-isinstance
561
+ # pylint: enable=import-outside-toplevel
562
+ # pylint: enable=no-else-return
models/perceptual_model.py ADDED
@@ -0,0 +1,519 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # python3.7
2
+ """Contains the VGG16 model, which is used for inference ONLY.
3
+
4
+ VGG16 is commonly used for perceptual feature extraction. The model implemented
5
+ in this file can be used for evaluation (like computing LPIPS, perceptual path
6
+ length, etc.), OR be used in training for loss computation (like perceptual
7
+ loss, etc.).
8
+
9
+ The pre-trained model is officially shared by
10
+
11
+ https://www.robots.ox.ac.uk/~vgg/research/very_deep/
12
+
13
+ and ported by
14
+
15
+ https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt
16
+
17
+ Compared to the official VGG16 model, this ported model also support evaluating
18
+ LPIPS, which is introduced in
19
+
20
+ https://github.com/richzhang/PerceptualSimilarity
21
+ """
22
+
23
+ import warnings
24
+ import numpy as np
25
+
26
+ import torch
27
+ import torch.nn as nn
28
+ import torch.nn.functional as F
29
+ import torch.distributed as dist
30
+
31
+ from utils.misc import download_url
32
+
33
+ __all__ = ['PerceptualModel']
34
+
35
+ # pylint: disable=line-too-long
36
+ _MODEL_URL_SHA256 = {
37
+ # This model is provided by `torchvision`, which is ported from TensorFlow.
38
+ 'torchvision_official': (
39
+ 'https://download.pytorch.org/models/vgg16-397923af.pth',
40
+ '397923af8e79cdbb6a7127f12361acd7a2f83e06b05044ddf496e83de57a5bf0' # hash sha256
41
+ ),
42
+
43
+ # This model is provided by https://github.com/NVlabs/stylegan2-ada-pytorch
44
+ 'vgg_perceptual_lpips': (
45
+ 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt',
46
+ 'b437eb095feaeb0b83eb3fa11200ebca4548ee39a07fb944a417ddc516cc07c3' # hash sha256
47
+ )
48
+ }
49
+ # pylint: enable=line-too-long
50
+
51
+
52
+ class PerceptualModel(object):
53
+ """Defines the perceptual model, which is based on VGG16 structure.
54
+
55
+ This is a static class, which is used to avoid this model to be built
56
+ repeatedly. Consequently, this model is particularly used for inference,
57
+ like computing LPIPS, or for loss computation, like perceptual loss. If
58
+ training is required, please use the model from `torchvision.models` or
59
+ implement by yourself.
60
+
61
+ NOTE: The pre-trained model assumes the inputs to be with `RGB` channel
62
+ order and pixel range [-1, 1], and will NOT resize the input automatically
63
+ if only perceptual feature is needed.
64
+ """
65
+ models = dict()
66
+
67
+ @staticmethod
68
+ def build_model(use_torchvision=False, no_top=True, enable_lpips=True):
69
+ """Builds the model and load pre-trained weights.
70
+
71
+ 1. If `use_torchvision` is set as True, the model released by
72
+ `torchvision` will be loaded, otherwise, the model released by
73
+ https://www.robots.ox.ac.uk/~vgg/research/very_deep/ will be used.
74
+ (default: False)
75
+
76
+ 2. To save computing resources, these is an option to only load the
77
+ backbone (i.e., without the last three fully-connected layers). This
78
+ is commonly used for perceptual loss or LPIPS loss computation.
79
+ Please use argument `no_top` to control this. (default: True)
80
+
81
+ 3. For LPIPS loss computation, some additional weights (which is used
82
+ for balancing the features from different resolutions) are employed
83
+ on top of the original VGG16 backbone. Details can be found at
84
+ https://github.com/richzhang/PerceptualSimilarity. Please use
85
+ `enable_lpips` to enable this feature. (default: True)
86
+
87
+ The built model supports following arguments when forwarding:
88
+
89
+ - resize_input: Whether to resize the input image to size [224, 224]
90
+ before forwarding. For feature-based computation (i.e., only
91
+ convolutional layers are used), image resizing is not essential.
92
+ (default: False)
93
+ - return_tensor: This field resolves the model behavior. Following
94
+ options are supported:
95
+ `feature1`: Before the first max pooling layer.
96
+ `pool1`: After the first max pooling layer.
97
+ `feature2`: Before the second max pooling layer.
98
+ `pool2`: After the second max pooling layer.
99
+ `feature3`: Before the third max pooling layer.
100
+ `pool3`: After the third max pooling layer.
101
+ `feature4`: Before the fourth max pooling layer.
102
+ `pool4`: After the fourth max pooling layer.
103
+ `feature5`: Before the fifth max pooling layer.
104
+ `pool5`: After the fifth max pooling layer.
105
+ `flatten`: The flattened feature, after `adaptive_avgpool`.
106
+ `feature`: The 4096d feature for logits computation. (default)
107
+ `logits`: The 1000d categorical logits.
108
+ `prediction`: The 1000d predicted probability.
109
+ `lpips`: The LPIPS score between two input images.
110
+ """
111
+ if use_torchvision:
112
+ model_source = 'torchvision_official'
113
+ align_tf_resize = False
114
+ is_torch_script = False
115
+ else:
116
+ model_source = 'vgg_perceptual_lpips'
117
+ align_tf_resize = True
118
+ is_torch_script = True
119
+
120
+ if enable_lpips and model_source != 'vgg_perceptual_lpips':
121
+ warnings.warn('The pre-trained model officially released by '
122
+ '`torchvision` does not support LPIPS computation! '
123
+ 'Equal weights will be used for each resolution.')
124
+
125
+ fingerprint = (model_source, no_top, enable_lpips)
126
+
127
+ if fingerprint not in PerceptualModel.models:
128
+ # Build model.
129
+ model = VGG16(align_tf_resize=align_tf_resize,
130
+ no_top=no_top,
131
+ enable_lpips=enable_lpips)
132
+
133
+ # Download pre-trained weights.
134
+ if dist.is_initialized() and dist.get_rank() != 0:
135
+ dist.barrier() # Download by chief.
136
+
137
+ url, sha256 = _MODEL_URL_SHA256[model_source]
138
+ filename = f'perceptual_model_{model_source}_{sha256}.pth'
139
+ model_path, hash_check = download_url(url,
140
+ filename=filename,
141
+ sha256=sha256)
142
+ if is_torch_script:
143
+ src_state_dict = torch.jit.load(model_path, map_location='cpu')
144
+ else:
145
+ src_state_dict = torch.load(model_path, map_location='cpu')
146
+ if hash_check is False:
147
+ warnings.warn(f'Hash check failed! The remote file from URL '
148
+ f'`{url}` may be changed, or the downloading is '
149
+ f'interrupted. The loaded perceptual model may '
150
+ f'have unexpected behavior.')
151
+
152
+ if dist.is_initialized() and dist.get_rank() == 0:
153
+ dist.barrier() # Wait for other replicas.
154
+
155
+ # Load weights.
156
+ dst_state_dict = _convert_weights(src_state_dict, model_source)
157
+ model.load_state_dict(dst_state_dict, strict=False)
158
+ del src_state_dict, dst_state_dict
159
+
160
+ # For inference only.
161
+ model.eval().requires_grad_(False).cuda()
162
+ PerceptualModel.models[fingerprint] = model
163
+
164
+ return PerceptualModel.models[fingerprint]
165
+
166
+
167
+ def _convert_weights(src_state_dict, model_source):
168
+ if model_source not in _MODEL_URL_SHA256:
169
+ raise ValueError(f'Invalid model source `{model_source}`!\n'
170
+ f'Sources allowed: {list(_MODEL_URL_SHA256.keys())}.')
171
+ if model_source == 'torchvision_official':
172
+ dst_to_src_var_mapping = {
173
+ 'conv11.weight': 'features.0.weight',
174
+ 'conv11.bias': 'features.0.bias',
175
+ 'conv12.weight': 'features.2.weight',
176
+ 'conv12.bias': 'features.2.bias',
177
+ 'conv21.weight': 'features.5.weight',
178
+ 'conv21.bias': 'features.5.bias',
179
+ 'conv22.weight': 'features.7.weight',
180
+ 'conv22.bias': 'features.7.bias',
181
+ 'conv31.weight': 'features.10.weight',
182
+ 'conv31.bias': 'features.10.bias',
183
+ 'conv32.weight': 'features.12.weight',
184
+ 'conv32.bias': 'features.12.bias',
185
+ 'conv33.weight': 'features.14.weight',
186
+ 'conv33.bias': 'features.14.bias',
187
+ 'conv41.weight': 'features.17.weight',
188
+ 'conv41.bias': 'features.17.bias',
189
+ 'conv42.weight': 'features.19.weight',
190
+ 'conv42.bias': 'features.19.bias',
191
+ 'conv43.weight': 'features.21.weight',
192
+ 'conv43.bias': 'features.21.bias',
193
+ 'conv51.weight': 'features.24.weight',
194
+ 'conv51.bias': 'features.24.bias',
195
+ 'conv52.weight': 'features.26.weight',
196
+ 'conv52.bias': 'features.26.bias',
197
+ 'conv53.weight': 'features.28.weight',
198
+ 'conv53.bias': 'features.28.bias',
199
+ 'fc1.weight': 'classifier.0.weight',
200
+ 'fc1.bias': 'classifier.0.bias',
201
+ 'fc2.weight': 'classifier.3.weight',
202
+ 'fc2.bias': 'classifier.3.bias',
203
+ 'fc3.weight': 'classifier.6.weight',
204
+ 'fc3.bias': 'classifier.6.bias',
205
+ }
206
+ elif model_source == 'vgg_perceptual_lpips':
207
+ src_state_dict = src_state_dict.state_dict()
208
+ dst_to_src_var_mapping = {
209
+ 'conv11.weight': 'layers.conv1.weight',
210
+ 'conv11.bias': 'layers.conv1.bias',
211
+ 'conv12.weight': 'layers.conv2.weight',
212
+ 'conv12.bias': 'layers.conv2.bias',
213
+ 'conv21.weight': 'layers.conv3.weight',
214
+ 'conv21.bias': 'layers.conv3.bias',
215
+ 'conv22.weight': 'layers.conv4.weight',
216
+ 'conv22.bias': 'layers.conv4.bias',
217
+ 'conv31.weight': 'layers.conv5.weight',
218
+ 'conv31.bias': 'layers.conv5.bias',
219
+ 'conv32.weight': 'layers.conv6.weight',
220
+ 'conv32.bias': 'layers.conv6.bias',
221
+ 'conv33.weight': 'layers.conv7.weight',
222
+ 'conv33.bias': 'layers.conv7.bias',
223
+ 'conv41.weight': 'layers.conv8.weight',
224
+ 'conv41.bias': 'layers.conv8.bias',
225
+ 'conv42.weight': 'layers.conv9.weight',
226
+ 'conv42.bias': 'layers.conv9.bias',
227
+ 'conv43.weight': 'layers.conv10.weight',
228
+ 'conv43.bias': 'layers.conv10.bias',
229
+ 'conv51.weight': 'layers.conv11.weight',
230
+ 'conv51.bias': 'layers.conv11.bias',
231
+ 'conv52.weight': 'layers.conv12.weight',
232
+ 'conv52.bias': 'layers.conv12.bias',
233
+ 'conv53.weight': 'layers.conv13.weight',
234
+ 'conv53.bias': 'layers.conv13.bias',
235
+ 'fc1.weight': 'layers.fc1.weight',
236
+ 'fc1.bias': 'layers.fc1.bias',
237
+ 'fc2.weight': 'layers.fc2.weight',
238
+ 'fc2.bias': 'layers.fc2.bias',
239
+ 'fc3.weight': 'layers.fc3.weight',
240
+ 'fc3.bias': 'layers.fc3.bias',
241
+ 'lpips.0.weight': 'lpips0',
242
+ 'lpips.1.weight': 'lpips1',
243
+ 'lpips.2.weight': 'lpips2',
244
+ 'lpips.3.weight': 'lpips3',
245
+ 'lpips.4.weight': 'lpips4',
246
+ }
247
+ else:
248
+ raise NotImplementedError(f'Not implemented model source '
249
+ f'`{model_source}`!')
250
+
251
+ dst_state_dict = {}
252
+ for dst_name, src_name in dst_to_src_var_mapping.items():
253
+ if dst_name.startswith('lpips'):
254
+ dst_state_dict[dst_name] = src_state_dict[src_name].unsqueeze(0)
255
+ else:
256
+ dst_state_dict[dst_name] = src_state_dict[src_name].clone()
257
+ return dst_state_dict
258
+
259
+
260
+ _IMG_MEAN = (0.485, 0.456, 0.406)
261
+ _IMG_STD = (0.229, 0.224, 0.225)
262
+ _ALLOWED_RETURN = [
263
+ 'feature1', 'pool1', 'feature2', 'pool2', 'feature3', 'pool3', 'feature4',
264
+ 'pool4', 'feature5', 'pool5', 'flatten', 'feature', 'logits', 'prediction',
265
+ 'lpips'
266
+ ]
267
+
268
+ # pylint: disable=missing-function-docstring
269
+
270
+ class VGG16(nn.Module):
271
+ """Defines the VGG16 structure.
272
+
273
+ This model takes `RGB` images with data format `NCHW` as the raw inputs. The
274
+ pixel range are assumed to be [-1, 1].
275
+ """
276
+
277
+ def __init__(self, align_tf_resize=False, no_top=True, enable_lpips=True):
278
+ """Defines the network structure."""
279
+ super().__init__()
280
+
281
+ self.align_tf_resize = align_tf_resize
282
+ self.no_top = no_top
283
+ self.enable_lpips = enable_lpips
284
+
285
+ self.conv11 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
286
+ self.relu11 = nn.ReLU(inplace=True)
287
+ self.conv12 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
288
+ self.relu12 = nn.ReLU(inplace=True)
289
+ # output `feature1`, with shape [N, 64, 224, 224]
290
+
291
+ self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
292
+ # output `pool1`, with shape [N, 64, 112, 112]
293
+
294
+ self.conv21 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
295
+ self.relu21 = nn.ReLU(inplace=True)
296
+ self.conv22 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
297
+ self.relu22 = nn.ReLU(inplace=True)
298
+ # output `feature2`, with shape [N, 128, 112, 112]
299
+
300
+ self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
301
+ # output `pool2`, with shape [N, 128, 56, 56]
302
+
303
+ self.conv31 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
304
+ self.relu31 = nn.ReLU(inplace=True)
305
+ self.conv32 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
306
+ self.relu32 = nn.ReLU(inplace=True)
307
+ self.conv33 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
308
+ self.relu33 = nn.ReLU(inplace=True)
309
+ # output `feature3`, with shape [N, 256, 56, 56]
310
+
311
+ self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
312
+ # output `pool3`, with shape [N,256, 28, 28]
313
+
314
+ self.conv41 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
315
+ self.relu41 = nn.ReLU(inplace=True)
316
+ self.conv42 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
317
+ self.relu42 = nn.ReLU(inplace=True)
318
+ self.conv43 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
319
+ self.relu43 = nn.ReLU(inplace=True)
320
+ # output `feature4`, with shape [N, 512, 28, 28]
321
+
322
+ self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
323
+ # output `pool4`, with shape [N, 512, 14, 14]
324
+
325
+ self.conv51 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
326
+ self.relu51 = nn.ReLU(inplace=True)
327
+ self.conv52 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
328
+ self.relu52 = nn.ReLU(inplace=True)
329
+ self.conv53 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
330
+ self.relu53 = nn.ReLU(inplace=True)
331
+ # output `feature5`, with shape [N, 512, 14, 14]
332
+
333
+ self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2)
334
+ # output `pool5`, with shape [N, 512, 7, 7]
335
+
336
+ if self.enable_lpips:
337
+ self.lpips = nn.ModuleList()
338
+ for idx, ch in enumerate([64, 128, 256, 512, 512]):
339
+ self.lpips.append(nn.Conv2d(ch, 1, kernel_size=1, bias=False))
340
+ self.lpips[idx].weight.data.copy_(torch.ones(1, ch, 1, 1))
341
+
342
+ if not self.no_top:
343
+ self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
344
+ self.flatten = nn.Flatten(start_dim=1, end_dim=-1)
345
+ # output `flatten`, with shape [N, 25088]
346
+
347
+ self.fc1 = nn.Linear(512 * 7 * 7, 4096)
348
+ self.fc1_relu = nn.ReLU(inplace=True)
349
+ self.fc1_dropout = nn.Dropout(0.5, inplace=False)
350
+ self.fc2 = nn.Linear(4096, 4096)
351
+ self.fc2_relu = nn.ReLU(inplace=True)
352
+ self.fc2_dropout = nn.Dropout(0.5, inplace=False)
353
+ # output `feature`, with shape [N, 4096]
354
+
355
+ self.fc3 = nn.Linear(4096, 1000)
356
+ # output `logits`, with shape [N, 1000]
357
+
358
+ self.out = nn.Softmax(dim=1)
359
+ # output `softmax`, with shape [N, 1000]
360
+
361
+ img_mean = np.array(_IMG_MEAN).reshape((1, 3, 1, 1)).astype(np.float32)
362
+ img_std = np.array(_IMG_STD).reshape((1, 3, 1, 1)).astype(np.float32)
363
+ self.register_buffer('img_mean', torch.from_numpy(img_mean))
364
+ self.register_buffer('img_std', torch.from_numpy(img_std))
365
+
366
+ def forward(self,
367
+ x,
368
+ y=None,
369
+ *,
370
+ resize_input=False,
371
+ return_tensor='feature'):
372
+ return_tensor = return_tensor.lower()
373
+ if return_tensor not in _ALLOWED_RETURN:
374
+ raise ValueError(f'Invalid output tensor name `{return_tensor}` '
375
+ f'for perceptual model (VGG16)!\n'
376
+ f'Names allowed: {_ALLOWED_RETURN}.')
377
+
378
+ if return_tensor == 'lpips' and y is None:
379
+ raise ValueError('Two images are required for LPIPS computation, '
380
+ 'but only one is received!')
381
+
382
+ if return_tensor == 'lpips':
383
+ assert x.shape == y.shape
384
+ x = torch.cat([x, y], dim=0)
385
+ features = []
386
+
387
+ if resize_input:
388
+ if self.align_tf_resize:
389
+ theta = torch.eye(2, 3).to(x)
390
+ theta[0, 2] += theta[0, 0] / x.shape[3] - theta[0, 0] / 224
391
+ theta[1, 2] += theta[1, 1] / x.shape[2] - theta[1, 1] / 224
392
+ theta = theta.unsqueeze(0).repeat(x.shape[0], 1, 1)
393
+ grid = F.affine_grid(theta,
394
+ size=(x.shape[0], x.shape[1], 224, 224),
395
+ align_corners=False)
396
+ x = F.grid_sample(x, grid,
397
+ mode='bilinear',
398
+ padding_mode='border',
399
+ align_corners=False)
400
+ else:
401
+ x = F.interpolate(x,
402
+ size=(224, 224),
403
+ mode='bilinear',
404
+ align_corners=False)
405
+ if x.shape[1] == 1:
406
+ x = x.repeat((1, 3, 1, 1))
407
+
408
+ x = (x + 1) / 2
409
+ x = (x - self.img_mean) / self.img_std
410
+
411
+ x = self.conv11(x)
412
+ x = self.relu11(x)
413
+ x = self.conv12(x)
414
+ x = self.relu12(x)
415
+ if return_tensor == 'feature1':
416
+ return x
417
+ if return_tensor == 'lpips':
418
+ features.append(x)
419
+
420
+ x = self.pool1(x)
421
+ if return_tensor == 'pool1':
422
+ return x
423
+
424
+ x = self.conv21(x)
425
+ x = self.relu21(x)
426
+ x = self.conv22(x)
427
+ x = self.relu22(x)
428
+ if return_tensor == 'feature2':
429
+ return x
430
+ if return_tensor == 'lpips':
431
+ features.append(x)
432
+
433
+ x = self.pool2(x)
434
+ if return_tensor == 'pool2':
435
+ return x
436
+
437
+ x = self.conv31(x)
438
+ x = self.relu31(x)
439
+ x = self.conv32(x)
440
+ x = self.relu32(x)
441
+ x = self.conv33(x)
442
+ x = self.relu33(x)
443
+ if return_tensor == 'feature3':
444
+ return x
445
+ if return_tensor == 'lpips':
446
+ features.append(x)
447
+
448
+ x = self.pool3(x)
449
+ if return_tensor == 'pool3':
450
+ return x
451
+
452
+ x = self.conv41(x)
453
+ x = self.relu41(x)
454
+ x = self.conv42(x)
455
+ x = self.relu42(x)
456
+ x = self.conv43(x)
457
+ x = self.relu43(x)
458
+ if return_tensor == 'feature4':
459
+ return x
460
+ if return_tensor == 'lpips':
461
+ features.append(x)
462
+
463
+ x = self.pool4(x)
464
+ if return_tensor == 'pool4':
465
+ return x
466
+
467
+ x = self.conv51(x)
468
+ x = self.relu51(x)
469
+ x = self.conv52(x)
470
+ x = self.relu52(x)
471
+ x = self.conv53(x)
472
+ x = self.relu53(x)
473
+ if return_tensor == 'feature5':
474
+ return x
475
+ if return_tensor == 'lpips':
476
+ features.append(x)
477
+
478
+ x = self.pool5(x)
479
+ if return_tensor == 'pool5':
480
+ return x
481
+
482
+ if return_tensor == 'lpips':
483
+ score = 0
484
+ assert len(features) == 5
485
+ for idx in range(5):
486
+ feature = features[idx]
487
+ norm = feature.norm(dim=1, keepdim=True)
488
+ feature = feature / (norm + 1e-10)
489
+ feature_x, feature_y = feature.chunk(2, dim=0)
490
+ diff = (feature_x - feature_y).square()
491
+ score += self.lpips[idx](diff).mean(dim=(2, 3), keepdim=False)
492
+ return score.sum(dim=1, keepdim=False)
493
+
494
+ x = self.avgpool(x)
495
+ x = self.flatten(x)
496
+ if return_tensor == 'flatten':
497
+ return x
498
+
499
+ x = self.fc1(x)
500
+ x = self.fc1_relu(x)
501
+ x = self.fc1_dropout(x)
502
+ x = self.fc2(x)
503
+ x = self.fc2_relu(x)
504
+ x = self.fc2_dropout(x)
505
+ if return_tensor == 'feature':
506
+ return x
507
+
508
+ x = self.fc3(x)
509
+ if return_tensor == 'logits':
510
+ return x
511
+
512
+ x = self.out(x)
513
+ if return_tensor == 'prediction':
514
+ return x
515
+
516
+ raise NotImplementedError(f'Output tensor name `{return_tensor}` is '
517
+ f'not implemented!')
518
+
519
+ # pylint: enable=missing-function-docstring