diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..29d53413603f5cf9209f175a40ccbf475af3378a
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,30 @@
+FROM nvidia/cuda:11.1.1-devel-ubuntu20.04
+
+ENV CUDA_HOME=/usr/local/cuda
+ENV PATH=${CUDA_HOME}/bin:/home/${USER_NAME}/.local/bin:${PATH}
+ENV LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH}
+ENV LIBRARY_PATH=${CUDA_HOME}/lib64/stubs:${LIBRARY_PATH}
+
+# apt install by root user
+RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
+    build-essential \
+    curl \
+    git \
+    python-is-python3 \
+    python3.8-dev \
+    python3-pip \
+    wget \
+    && rm -rf /var/lib/apt/lists/*
+
+RUN pip install torch==1.9.1+cu111 torchvision==0.10.1+cu111 torchaudio==0.9.1 -f https://download.pytorch.org/whl/torch_stable.html
+
+
+WORKDIR /code
+
+COPY ./requirements.txt /code/requirements.txt
+
+RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
+
+COPY . .
+
+CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860"]
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..67748a1050de25cfd7829074d4e40090daa3d81a
--- /dev/null
+++ b/app.py
@@ -0,0 +1,267 @@
+import gradio as gr
+from models import build_model
+from PIL import Image
+import numpy as np
+import torchvision
+import ninja
+import torch
+from tqdm import trange
+import imageio
+
+checkpoint = '/mnt/petrelfs/zhangqihang/data/berfscene_clevr.pth'
+state = torch.load(checkpoint, map_location='cpu')
+G = build_model(**state['model_kwargs_init']['generator_smooth'])
+o0, o1 = G.load_state_dict(state['models']['generator_smooth'], strict=False)
+G.eval().cuda()
+G.backbone.synthesis.input.x_offset =0
+G.backbone.synthesis.input.y_offset =0
+G_kwargs= dict(noise_mode='const',
+               fused_modulate=False,
+               impl='cuda',
+               fp16_res=None)
+
+def trans(x, y, z, length):
+    w = h = length
+    x = 0.5 * w - 128 + 256 - (x/9 + .5) * 256
+    y = 0.5 * h - 128 + (y/9 + .5) * 256
+    z = z / 9 * 256
+    return x, y, z
+def get_bev_from_objs(objs, length=256, scale = 6):
+    h, w = length, length *scale
+    nc = 14
+    canvas = np.zeros([h, w, nc])
+    xx = np.ones([h,w]).cumsum(0)
+    yy = np.ones([h,w]).cumsum(1)
+
+    for x, y, z, shape, color, material, rot in objs:
+        y, x, z = trans(x, y, z, length)
+
+        feat = [0] * nc
+        feat[0] = 1
+        feat[COLOR_NAME_LIST.index(color) + 1] = 1
+        feat[SHAPE_NAME_LIST.index(shape) + 1 + len(COLOR_NAME_LIST)] = 1
+        feat[MATERIAL_NAME_LIST.index(material) + 1 + len(COLOR_NAME_LIST) + len(SHAPE_NAME_LIST)] = 1
+        feat = np.array(feat)
+        rot_sin = np.sin(rot / 180 * np.pi)
+        rot_cos = np.cos(rot / 180 * np.pi)
+
+        if shape == 'cube':
+            mask = (np.abs(+rot_cos * (xx-x) + rot_sin * (yy-y)) <= z) * \
+                   (np.abs(-rot_sin * (xx-x) + rot_cos * (yy-y)) <= z)
+        else:
+            mask = ((xx-x)**2 + (y-yy)**2) ** 0.5 <= z
+        canvas[mask] = feat
+    canvas = np.transpose(canvas, [2, 0, 1]).astype(np.float32)
+    rotate_angle = 0
+    canvas = torchvision.transforms.functional.rotate(torch.tensor(canvas), rotate_angle).numpy()
+    return canvas
+
+# COLOR_NAME_LIST = ['cyan', 'green', 'purple', 'red', 'yellow', 'gray', 'brown', 'blue']
+COLOR_NAME_LIST = ['cyan', 'green', 'purple', 'red', 'yellow', 'gray', 'purple', 'blue']
+SHAPE_NAME_LIST = ['cube', 'sphere', 'cylinder']
+MATERIAL_NAME_LIST = ['rubber', 'metal']
+
+xy_lib = dict()
+xy_lib['B'] = [
+    [-2, -1],
+    [-1, -1],
+    [-2, 0],
+    [-2, 1],
+    [-1, .5],
+    [0, 1],
+    [0, 0],
+    [0, -1],
+    [0, 2],
+    [-1, 2],
+    [-2, 2]
+]
+xy_lib['B'] = 
[ + [-2.5, 1.25], + [-2, 2], + [-2, 0.5], + [-2, -0.75], + [-1, -1], + [-1, 2], + [-1, 0], + [-1, 2], + [0, 1], + [0, 0], + [0, -1], + [0, 2], + # [-1, 2], + +] +xy_lib['B'] = [ + [-2.5, 1.25], + [-2, 2], + [-2, 0.5], + [-2, -1], + [-1, -1.25], + [-1, 2], + [-1, 0], + [-1, 2], + [0, 1], + [0, 0], + [0, -1.25], + [0, 2], + # [-1, 2], + +] +xy_lib['R'] = [ + [0, -1], + [0, 0], + [0, 1], + [0, 2], + [-1, -1], + # [-1, 2], + [-2, -1], + [-2, 0], + [-2.25, 2], + [-1, 1] +] +xy_lib['C'] = [ + [0, -1], + [0, 0], + [0, 1], + [0, 2], + [-1, -1], + [-1, 2], + [-2, -1], + # [-2, .5], + [-2, 2], + # [-1, .5] +] +xy_lib['s'] = [ + [0, -1], + [0, 0], + [0, 2], + [-1, -1], + [-1, 2], + [-2, -1], + [-2, 1], + [-2, 2], + [-1, .5] +] + +xy_lib['F'] = [ + [0, -1], + [0, 0], + [0, 1], + [0, 2], + [-1, -1], + # [-1, 2], + [-2, -1], + [-2, .5], + # [-2, 2], + [-1, .5] +] + +xy_lib['c'] = [ + [0.8,1], + # [-0.8,1], + [0,0.1], + [0,1.9], +] + +xy_lib['e'] = [ + [0, -1], + [0, 0], + [0, 1], + [0, 2], + [-1, -1], + [-1, 2], + [-2, -1], + [-2, .5], + [-2, 2], + [-1, .5] +] +xy_lib['n'] = [ + [0,1], + [0,-1], + [0,0.1], + [0,1.9], + [-1,0], + [-2,1], + [-3,-1], + [-3,1], + [-3,0.1], + [-3,1.9], +] +offset_x = dict(B=4, R=4, C=4, F=4, c=3, s=4, e=4, n=4.8) +s = 'BeRFsCene' +objs = [] +offset = 2 +for idx, c in enumerate(s): + xy = xy_lib[c] + + + color = np.random.choice(COLOR_NAME_LIST) + for i in range(len(xy)): + # while 1: + # is_ok = 1 + # x, y = + + # for prev_x, prev_y in zip(xpool, ypool): + x, y = xy[i] + y *= 1.5 + y -= 0.5 + x -= offset + z = 0.35 + # if idx<4: + # color = np.random.choice(COLOR_NAME_LIST[:-1]) + # else: + # color = 'blue' + shape = 'cube' + material = 'rubber' + rot = 0 + objs.append([x, y, z, shape, color, material, rot]) + offset += offset_x[c] +Image.fromarray((255 * .8 - get_bev_from_objs(objs)[0] *.8 * 255).astype(np.uint8)) + +batch_size = 1 +code = torch.randn(1, G.z_dim).cuda() +to_pil = torchvision.transforms.ToPILImage() +large_bevs = torch.tensor(get_bev_from_objs(objs)).cuda()[None] +bevs = large_bevs[..., 0: 0+256] +RT = torch.tensor([[ -1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5000, -0.8660, + 10.3923, 0.0000, -0.8660, -0.5000, 6.0000, 0.0000, 0.0000, + 0.0000, 1.0000, 262.5000, 0.0000, 32.0000, 0.0000, 262.5000, + 32.0000, 0.0000, 0.0000, 1.0000]], device='cuda') + +print('prepare finish', flush=True) + +def inference(name): + print('inference', name, flush=True) + gen = G(code, RT, bevs) + rgb = gen['gen_output']['image'][0] * .5 + .5 + print('inference', name, flush=True) + return np.array(to_pil(rgb)) + + # to_pil(rgb).save('tmp.png') + # save_path = '/mnt/petrelfs/zhangqihang/code/3d-scene-gen/tmp.png' + # return [save_path] + +with gr.Blocks() as demo: + gr.HTML( + """ + abc + """) + + with gr.Group(): + with gr.Row(): + with gr.Column(): + with gr.Row(): + with gr.Column(): + with gr.Row(): + num_frames = gr.Dropdown(["24 - frames", "32 - frames", "40 - frames", "48 - frames", "56 - frames", "80 - recommended to run on local GPUs", "240 - recommended to run on local GPUs", "600 - recommended to run on local GPUs", "1200 - recommended to run on local GPUs", "10000 - recommended to run on local GPUs"], label="Number of Video Frames", info="For >56 frames use local workstation!", value="24 - frames") + + with gr.Row(): + with gr.Row(): + btn = gr.Button("Result") + + gallery = gr.Image(label='img', show_label=True, elem_id="gallery") + + btn.click(fn=inference, inputs=num_frames, outputs=[gallery], postprocess=False) + +demo.queue() +demo.launch(server_name='0.0.0.0', 
server_port=10093, debug=True, show_error=True) diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a6f6066f37d20f67148afff4519a03e24dd758e9 --- /dev/null +++ b/models/__init__.py @@ -0,0 +1,63 @@ +# python3.7 +"""Collects all models.""" + +from .pggan_generator import PGGANGenerator +from .pggan_discriminator import PGGANDiscriminator +from .stylegan_generator import StyleGANGenerator +from .stylegan_discriminator import StyleGANDiscriminator +from .stylegan2_generator import StyleGAN2Generator +from .stylegan2_discriminator import StyleGAN2Discriminator +from .stylegan3_generator import StyleGAN3Generator +from .ghfeat_encoder import GHFeatEncoder +from .perceptual_model import PerceptualModel +from .inception_model import InceptionModel +from .eg3d_generator import EG3DGenerator +from .eg3d_discriminator import DualDiscriminator +from .pigan_generator import PiGANGenerator +from .pigan_discriminator import PiGANDiscriminator +from .volumegan_generator import VolumeGANGenerator +from .volumegan_discriminator import VolumeGANDiscriminator +from .eg3d_generator_fv import EG3DGeneratorFV +from .bev3d_generator import BEV3DGenerator +from .sgbev3d_generator import SGBEV3DGenerator + +__all__ = ['build_model'] + +_MODELS = { + 'PGGANGenerator': PGGANGenerator, + 'PGGANDiscriminator': PGGANDiscriminator, + 'StyleGANGenerator': StyleGANGenerator, + 'StyleGANDiscriminator': StyleGANDiscriminator, + 'StyleGAN2Generator': StyleGAN2Generator, + 'StyleGAN2Discriminator': StyleGAN2Discriminator, + 'StyleGAN3Generator': StyleGAN3Generator, + 'GHFeatEncoder': GHFeatEncoder, + 'PerceptualModel': PerceptualModel.build_model, + 'InceptionModel': InceptionModel.build_model, + 'EG3DGenerator': EG3DGenerator, + 'EG3DDiscriminator': DualDiscriminator, + 'PiGANGenerator': PiGANGenerator, + 'PiGANDiscriminator': PiGANDiscriminator, + 'VolumeGANGenerator': VolumeGANGenerator, + 'VolumeGANDiscriminator': VolumeGANDiscriminator, + 'EG3DGeneratorFV': EG3DGeneratorFV, + 'BEV3DGenerator': BEV3DGenerator, + 'SGBEV3DGenerator': SGBEV3DGenerator, +} + + +def build_model(model_type, **kwargs): + """Builds a model based on its class type. + + Args: + model_type: Class type to which the model belongs, which is case + sensitive. + **kwargs: Additional arguments to build the model. + + Raises: + ValueError: If the `model_type` is not supported. 
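+
+    Example:
+        # As in `app.py`, keyword arguments usually come from a checkpoint:
+        G = build_model(**state['model_kwargs_init']['generator_smooth'])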
+ """ + if model_type not in _MODELS: + raise ValueError(f'Invalid model type: `{model_type}`!\n' + f'Types allowed: {list(_MODELS)}.') + return _MODELS[model_type](**kwargs) diff --git a/models/__pycache__/__init__.cpython-37.pyc b/models/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ad8987a0c802f012f8fd5ef33c18219fede4914 Binary files /dev/null and b/models/__pycache__/__init__.cpython-37.pyc differ diff --git a/models/__pycache__/__init__.cpython-39.pyc b/models/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..554ec990b6ad3553953f71847ca764ec6b180884 Binary files /dev/null and b/models/__pycache__/__init__.cpython-39.pyc differ diff --git a/models/__pycache__/bev3d_generator.cpython-37.pyc b/models/__pycache__/bev3d_generator.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8cddbd3800b8984aa4fd3a0ca5e2ce5ec84dce79 Binary files /dev/null and b/models/__pycache__/bev3d_generator.cpython-37.pyc differ diff --git a/models/__pycache__/bev3d_generator.cpython-39.pyc b/models/__pycache__/bev3d_generator.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8f9505171143985103b692d5c71686126240d9d7 Binary files /dev/null and b/models/__pycache__/bev3d_generator.cpython-39.pyc differ diff --git a/models/__pycache__/eg3d_discriminator.cpython-37.pyc b/models/__pycache__/eg3d_discriminator.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e00fdc53df9a3a481894943fa2447ce3bb3be82a Binary files /dev/null and b/models/__pycache__/eg3d_discriminator.cpython-37.pyc differ diff --git a/models/__pycache__/eg3d_discriminator.cpython-39.pyc b/models/__pycache__/eg3d_discriminator.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5451647ec2d7a21e1bc4ebaa89bbc0bf89b5c80b Binary files /dev/null and b/models/__pycache__/eg3d_discriminator.cpython-39.pyc differ diff --git a/models/__pycache__/eg3d_generator.cpython-37.pyc b/models/__pycache__/eg3d_generator.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..68b4e4567745b27854821fa4ddca545751ef6209 Binary files /dev/null and b/models/__pycache__/eg3d_generator.cpython-37.pyc differ diff --git a/models/__pycache__/eg3d_generator.cpython-39.pyc b/models/__pycache__/eg3d_generator.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..58a656f79e7bae3bc91497227745e5cf00a8c014 Binary files /dev/null and b/models/__pycache__/eg3d_generator.cpython-39.pyc differ diff --git a/models/__pycache__/eg3d_generator_fv.cpython-37.pyc b/models/__pycache__/eg3d_generator_fv.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..971041c1b9add3e50dc4b2e7da90a9be714c2d27 Binary files /dev/null and b/models/__pycache__/eg3d_generator_fv.cpython-37.pyc differ diff --git a/models/__pycache__/eg3d_generator_fv.cpython-39.pyc b/models/__pycache__/eg3d_generator_fv.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a7a8849e2fe24e00fd23af5911159b6f9ac5453d Binary files /dev/null and b/models/__pycache__/eg3d_generator_fv.cpython-39.pyc differ diff --git a/models/__pycache__/ghfeat_encoder.cpython-37.pyc b/models/__pycache__/ghfeat_encoder.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bc05ed52c527f75332300058f1a1a8269dd625ff Binary files /dev/null and 
b/models/__pycache__/ghfeat_encoder.cpython-37.pyc differ diff --git a/models/__pycache__/ghfeat_encoder.cpython-39.pyc b/models/__pycache__/ghfeat_encoder.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..82e58927ba46f18de44cd004661ea9436939a69e Binary files /dev/null and b/models/__pycache__/ghfeat_encoder.cpython-39.pyc differ diff --git a/models/__pycache__/inception_model.cpython-37.pyc b/models/__pycache__/inception_model.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c29ea2e425bb222bbf26ed6ff26bfd07ba54902d Binary files /dev/null and b/models/__pycache__/inception_model.cpython-37.pyc differ diff --git a/models/__pycache__/inception_model.cpython-39.pyc b/models/__pycache__/inception_model.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c43350fd22f69c79646cb9936e256f77eb1bec19 Binary files /dev/null and b/models/__pycache__/inception_model.cpython-39.pyc differ diff --git a/models/__pycache__/perceptual_model.cpython-37.pyc b/models/__pycache__/perceptual_model.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb811ea79064bc510881e20b5b13684c6f296df8 Binary files /dev/null and b/models/__pycache__/perceptual_model.cpython-37.pyc differ diff --git a/models/__pycache__/perceptual_model.cpython-39.pyc b/models/__pycache__/perceptual_model.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..091af90a2b09b64cb7547fcf28b423e4f1c22d85 Binary files /dev/null and b/models/__pycache__/perceptual_model.cpython-39.pyc differ diff --git a/models/__pycache__/pggan_discriminator.cpython-37.pyc b/models/__pycache__/pggan_discriminator.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0439f4041c4d075e4b66fb81e4064bc195f2e3e4 Binary files /dev/null and b/models/__pycache__/pggan_discriminator.cpython-37.pyc differ diff --git a/models/__pycache__/pggan_discriminator.cpython-39.pyc b/models/__pycache__/pggan_discriminator.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb443ba60d752544a358a654ea67250877ed03eb Binary files /dev/null and b/models/__pycache__/pggan_discriminator.cpython-39.pyc differ diff --git a/models/__pycache__/pggan_generator.cpython-37.pyc b/models/__pycache__/pggan_generator.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..82279fc933f7c291c632e564aeeee5513e4fd89a Binary files /dev/null and b/models/__pycache__/pggan_generator.cpython-37.pyc differ diff --git a/models/__pycache__/pggan_generator.cpython-39.pyc b/models/__pycache__/pggan_generator.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e9a8c90d76a6d792267ba2633f5f623ec399fc05 Binary files /dev/null and b/models/__pycache__/pggan_generator.cpython-39.pyc differ diff --git a/models/__pycache__/pigan_discriminator.cpython-37.pyc b/models/__pycache__/pigan_discriminator.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fdf828d2175efc9b39c14def7e5e0329e7e225cc Binary files /dev/null and b/models/__pycache__/pigan_discriminator.cpython-37.pyc differ diff --git a/models/__pycache__/pigan_discriminator.cpython-39.pyc b/models/__pycache__/pigan_discriminator.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..696b0f343238098964f764a4da85b755025c79be Binary files /dev/null and b/models/__pycache__/pigan_discriminator.cpython-39.pyc differ diff --git 
a/models/__pycache__/pigan_generator.cpython-37.pyc b/models/__pycache__/pigan_generator.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dfeecc09dc6ce41ef79e3f5c1ecec0c9c3bd615f Binary files /dev/null and b/models/__pycache__/pigan_generator.cpython-37.pyc differ diff --git a/models/__pycache__/pigan_generator.cpython-39.pyc b/models/__pycache__/pigan_generator.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c3930c111b8459c51750620416bc66637c816e5b Binary files /dev/null and b/models/__pycache__/pigan_generator.cpython-39.pyc differ diff --git a/models/__pycache__/sgbev3d_generator.cpython-37.pyc b/models/__pycache__/sgbev3d_generator.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..be608fcd6bc27dd0a31598781ebb09f3437587d7 Binary files /dev/null and b/models/__pycache__/sgbev3d_generator.cpython-37.pyc differ diff --git a/models/__pycache__/sgbev3d_generator.cpython-39.pyc b/models/__pycache__/sgbev3d_generator.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..81eb677350489cc3917e430bb03cf5433f220b20 Binary files /dev/null and b/models/__pycache__/sgbev3d_generator.cpython-39.pyc differ diff --git a/models/__pycache__/stylegan2_discriminator.cpython-37.pyc b/models/__pycache__/stylegan2_discriminator.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..858b0f17dc2721e3a4f0514caa1a65304a51a3b8 Binary files /dev/null and b/models/__pycache__/stylegan2_discriminator.cpython-37.pyc differ diff --git a/models/__pycache__/stylegan2_discriminator.cpython-39.pyc b/models/__pycache__/stylegan2_discriminator.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fb47abd594063316df1ae81c5e8acf028e7dd0fd Binary files /dev/null and b/models/__pycache__/stylegan2_discriminator.cpython-39.pyc differ diff --git a/models/__pycache__/stylegan2_generator.cpython-37.pyc b/models/__pycache__/stylegan2_generator.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cf29277caf659c29d7f750fcd2ba4d7950cc94a0 Binary files /dev/null and b/models/__pycache__/stylegan2_generator.cpython-37.pyc differ diff --git a/models/__pycache__/stylegan2_generator.cpython-39.pyc b/models/__pycache__/stylegan2_generator.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..28ca8fa8dec248852b2b2eba78883d8ff0c059df Binary files /dev/null and b/models/__pycache__/stylegan2_generator.cpython-39.pyc differ diff --git a/models/__pycache__/stylegan3_generator.cpython-37.pyc b/models/__pycache__/stylegan3_generator.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..64fc655980fde8231f2d873be116901b41bec7a3 Binary files /dev/null and b/models/__pycache__/stylegan3_generator.cpython-37.pyc differ diff --git a/models/__pycache__/stylegan3_generator.cpython-39.pyc b/models/__pycache__/stylegan3_generator.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1bdd89797fcd63aabb19ac49c6e0cada0113c10a Binary files /dev/null and b/models/__pycache__/stylegan3_generator.cpython-39.pyc differ diff --git a/models/__pycache__/stylegan_discriminator.cpython-37.pyc b/models/__pycache__/stylegan_discriminator.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2eeb38e1b9120680c9d0e67da497e765f0524495 Binary files /dev/null and b/models/__pycache__/stylegan_discriminator.cpython-37.pyc differ diff --git 
a/models/__pycache__/stylegan_discriminator.cpython-39.pyc b/models/__pycache__/stylegan_discriminator.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..52983a6ede8abb4ab66560a0e348afb75cdc09e5 Binary files /dev/null and b/models/__pycache__/stylegan_discriminator.cpython-39.pyc differ diff --git a/models/__pycache__/stylegan_generator.cpython-37.pyc b/models/__pycache__/stylegan_generator.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b7111ad4910c08734d85173fbebc13d14574d77d Binary files /dev/null and b/models/__pycache__/stylegan_generator.cpython-37.pyc differ diff --git a/models/__pycache__/stylegan_generator.cpython-39.pyc b/models/__pycache__/stylegan_generator.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..65bc6a4c99a717f363166acf196a1b5aba623786 Binary files /dev/null and b/models/__pycache__/stylegan_generator.cpython-39.pyc differ diff --git a/models/__pycache__/volumegan_discriminator.cpython-37.pyc b/models/__pycache__/volumegan_discriminator.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4b30fdb3e18fa716f340ab5cc7cae86ab3fdd1f4 Binary files /dev/null and b/models/__pycache__/volumegan_discriminator.cpython-37.pyc differ diff --git a/models/__pycache__/volumegan_discriminator.cpython-39.pyc b/models/__pycache__/volumegan_discriminator.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5b0a6360d06adf0cb6a49c5bba1b32562935eb79 Binary files /dev/null and b/models/__pycache__/volumegan_discriminator.cpython-39.pyc differ diff --git a/models/__pycache__/volumegan_generator.cpython-37.pyc b/models/__pycache__/volumegan_generator.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ac334ecdaa38550c39e9e85f90d7284fb54a3205 Binary files /dev/null and b/models/__pycache__/volumegan_generator.cpython-37.pyc differ diff --git a/models/__pycache__/volumegan_generator.cpython-39.pyc b/models/__pycache__/volumegan_generator.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8ca81c2c61035f5f7d095f0c3a208382c9e5c23a Binary files /dev/null and b/models/__pycache__/volumegan_generator.cpython-39.pyc differ diff --git a/models/bev3d_generator.py b/models/bev3d_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..400aef98f70875589cd255eb700f002a7fa43fd6 --- /dev/null +++ b/models/bev3d_generator.py @@ -0,0 +1,301 @@ +# python3.8 +"""Contains the implementation of generator described in BEV3D.""" + +import torch +import torch.nn as nn +from models.utils.official_stylegan2_model_helper import Generator as StyleGAN2Backbone +from models.utils.official_stylegan2_model_helper import FullyConnectedLayer +from models.utils.eg3d_superres import SuperresolutionHybrid2X +from models.utils.eg3d_superres import SuperresolutionHybrid4X +from models.utils.eg3d_superres import SuperresolutionHybrid4X_conststyle +from models.utils.eg3d_superres import SuperresolutionHybrid8XDC +from models.rendering.renderer import Renderer +from models.rendering.feature_extractor import FeatureExtractor + +from models.utils.spade import SPADEGenerator + +class BEV3DGenerator(nn.Module): + + def __init__( + self, + z_dim, + semantic_nc, + ngf, + bev_grid_size, + aspect_ratio, + num_upsampling_layers, + not_use_vae, + norm_G, + img_resolution, + interpolate_sr, + segmask=False, + dim_seq='16,8,4,2,1', + xyz_pe=False, + hidden_dim=64, + additional_layer_num=0, + 
sr_num_fp16_res=0, # Number of fp16 layers of SR Network. + rendering_kwargs={}, # Arguments for rendering. + sr_kwargs={}, # Arguments for SuperResolution Network. + ): + super().__init__() + + self.z_dim = z_dim + self.interpolate_sr = interpolate_sr + self.segmask = segmask + + # Set up the overall renderer. + self.renderer = Renderer() + + # Set up the feature extractor. + self.feature_extractor = FeatureExtractor(ref_mode='bev_plane_clevr', xyz_pe=xyz_pe) + + # Set up the reference representation generator. + self.backbone = SPADEGenerator(z_dim=z_dim, semantic_nc=semantic_nc, ngf=ngf, dim_seq=dim_seq, bev_grid_size=bev_grid_size, + aspect_ratio=aspect_ratio, num_upsampling_layers=num_upsampling_layers, + not_use_vae=not_use_vae, norm_G=norm_G) + print('backbone SPADEGenerator set up!') + + # Set up the post module in the feature extractor. + self.post_module = None + + # Set up the post neural renderer. + self.post_neural_renderer = None + sr_kwargs_total = dict( + channels=32, + img_resolution=img_resolution, + sr_num_fp16_res=sr_num_fp16_res, + sr_antialias=rendering_kwargs['sr_antialias'],) + sr_kwargs_total.update(**sr_kwargs) + if img_resolution == 128: + self.post_neural_renderer = SuperresolutionHybrid2X( + **sr_kwargs_total) + elif img_resolution == 256: + self.post_neural_renderer = SuperresolutionHybrid4X_conststyle( + **sr_kwargs_total) + elif img_resolution == 512: + self.post_neural_renderer = SuperresolutionHybrid8XDC( + **sr_kwargs_total) + else: + raise TypeError(f'Unsupported image resolution: {img_resolution}!') + + # Set up the fully-connected layer head. + self.fc_head = OSGDecoder( + 128 if xyz_pe else 64 , { + 'decoder_lr_mul': rendering_kwargs.get('decoder_lr_mul', 1), + 'decoder_output_dim': 32 + }, + hidden_dim=hidden_dim, + additional_layer_num=additional_layer_num + ) + + # Set up some rendering related arguments. + self.neural_rendering_resolution = rendering_kwargs.get( + 'resolution', 64) + self.rendering_kwargs = rendering_kwargs + + def synthesis(self, + z, + c, + seg, + neural_rendering_resolution=None, + update_emas=False, + **synthesis_kwargs): + cam2world_matrix = c[:, :16].view(-1, 4, 4) + if self.rendering_kwargs.get('random_pose', False): + cam2world_matrix = None + + if neural_rendering_resolution is None: + neural_rendering_resolution = self.neural_rendering_resolution + else: + self.neural_rendering_resolution = neural_rendering_resolution + + xy_planes = self.backbone(z=z, input=seg) + if self.segmask: + xy_planes = xy_planes * seg[:, 0, ...][:, None, ...] + + # import pdb;pdb.set_trace() + + wp = z # in our case, we do not use wp. + + rendering_result = self.renderer( + wp=wp, + feature_extractor=self.feature_extractor, + rendering_options=self.rendering_kwargs, + cam2world_matrix=cam2world_matrix, + position_encoder=None, + ref_representation=xy_planes, + post_module=self.post_module, + fc_head=self.fc_head) + + feature_samples = rendering_result['composite_rgb'] + depth_samples = rendering_result['composite_depth'] + + # Reshape to keep consistent with 'raw' neural-rendered image. + N = wp.shape[0] + H = W = self.neural_rendering_resolution + feature_image = feature_samples.permute(0, 2, 1).reshape( + N, feature_samples.shape[-1], H, W).contiguous() + depth_image = depth_samples.permute(0, 2, 1).reshape(N, 1, H, W) + + # Run the post neural renderer to get final image. + # Here, the post neural renderer is a super-resolution network. 
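+        # When `interpolate_sr` is set, the learned SR network is skipped and the raw RGB render is simply bilinearly upsampled to 256x256.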
+ rgb_image = feature_image[:, :3] + if self.interpolate_sr: + sr_image = torch.nn.functional.interpolate(rgb_image, size=(256, 256), mode='bilinear', align_corners=False) + else: + sr_image = self.post_neural_renderer( + rgb_image, + feature_image, + # wp, + noise_mode=self.rendering_kwargs['superresolution_noise_mode'], + **{ + k: synthesis_kwargs[k] + for k in synthesis_kwargs.keys() if k != 'noise_mode' + }) + + return { + 'image': sr_image, + 'image_raw': rgb_image, + 'image_depth': depth_image + } + + def sample(self, + coordinates, + directions, + z, + c, + seg, + truncation_psi=1, + truncation_cutoff=None, + update_emas=False, + **synthesis_kwargs): + # Compute RGB features, density for arbitrary 3D coordinates. + # Mostly used for extracting shapes. + cam2world_matrix = c[:, :16].view(-1, 4, 4) + xy_planes = self.backbone(z=z, input=seg) + wp = z + result = self.renderer.get_sigma_rgb( + wp=wp, + points=coordinates, + feature_extractor=self.feature_extractor, + fc_head=self.fc_head, + rendering_options=self.rendering_kwargs, + ref_representation=xy_planes, + post_module=self.post_module, + ray_dirs=directions, + cam_matrix=cam2world_matrix) + + return result + + def sample_mixed(self, + coordinates, + directions, + z, c, seg, + truncation_psi=1, + truncation_cutoff=None, + update_emas=False, + **synthesis_kwargs): + # Same as function `self.sample()`, but expects latent vectors 'wp' + # instead of Gaussian noise 'z'. + cam2world_matrix = c[:, :16].view(-1, 4, 4) + xy_planes = self.backbone(z=z, input=seg) + wp = z + result = self.renderer.get_sigma_rgb( + wp=wp, + points=coordinates, + feature_extractor=self.feature_extractor, + fc_head=self.fc_head, + rendering_options=self.rendering_kwargs, + ref_representation=xy_planes, + post_module=self.post_module, + ray_dirs=directions, + cam_matrix=cam2world_matrix) + + return result + + def forward(self, + z, + c, + seg, + c_swapped=None, # `c_swapped` is swapped pose conditioning. + style_mixing_prob=0, + truncation_psi=1, + truncation_cutoff=None, + neural_rendering_resolution=None, + update_emas=False, + sample_mixed=False, + coordinates=None, + **synthesis_kwargs): + + # Render a batch of generated images. + c_wp = c.clone() + if c_swapped is not None: + c_wp = c_swapped.clone() + + if not sample_mixed: + gen_output = self.synthesis( + z, + c, + seg, + update_emas=update_emas, + neural_rendering_resolution=neural_rendering_resolution, + **synthesis_kwargs) + + return { + 'wp': z, + 'gen_output': gen_output, + } + + else: + # Only for density regularization in training process. 
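+            # Directions are not needed for this density-only query, so random vectors stand in for ray directions.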
+ assert coordinates is not None + sample_sigma = self.sample_mixed(coordinates, + torch.randn_like(coordinates), + z, c, seg, + update_emas=False)['sigma'] + + return { + 'wp': z, + 'sample_sigma': sample_sigma + } + + +class OSGDecoder(nn.Module): + """Defines fully-connected layer head in EG3D.""" + def __init__(self, n_features, options, hidden_dim=64, additional_layer_num=0): + super().__init__() + self.hidden_dim = hidden_dim + + lst = [] + lst.append(FullyConnectedLayer(n_features, self.hidden_dim, lr_multiplier=options['decoder_lr_mul'])) + lst.append(nn.Softplus()) + for i in range(additional_layer_num): + lst.append(FullyConnectedLayer(self.hidden_dim, self.hidden_dim, lr_multiplier=options['decoder_lr_mul'])) + lst.append(nn.Softplus()) + lst.append(FullyConnectedLayer(self.hidden_dim, 1+options['decoder_output_dim'], lr_multiplier=options['decoder_lr_mul'])) + self.net = nn.Sequential(*lst) + + # self.net = nn.Sequential( + # FullyConnectedLayer(n_features, + # self.hidden_dim, + # lr_multiplier=options['decoder_lr_mul']), + # nn.Softplus(), + # FullyConnectedLayer(self.hidden_dim, + # 1 + options['decoder_output_dim'], + # lr_multiplier=options['decoder_lr_mul'])) + + def forward(self, point_features, wp=None, dirs=None): + # Aggregate features + # point_features.shape: [N, R, K, C]. + # Average across 'X, Y, Z' planes. + + N, R, K, C = point_features.shape + x = point_features.reshape(-1, point_features.shape[-1]) + x = self.net(x) + x = x.view(N, -1, x.shape[-1]) + + # Uses sigmoid clamping from MipNeRF + rgb = torch.sigmoid(x[..., 1:]) * (1 + 2 * 0.001) - 0.001 + sigma = x[..., 0:1] + + return {'rgb': rgb, 'sigma': sigma} diff --git a/models/eg3d_discriminator.py b/models/eg3d_discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..79d572488734cb6f5397681cd15b64fc30f4ef4b --- /dev/null +++ b/models/eg3d_discriminator.py @@ -0,0 +1,243 @@ +# python 3.7 +"""Contains the implementation of discriminator described in EG3D.""" + + +import numpy as np +import torch +from third_party.stylegan2_official_ops import upfirdn2d +from models.utils.official_stylegan2_model_helper import DiscriminatorBlock +from models.utils.official_stylegan2_model_helper import MappingNetwork +from models.utils.official_stylegan2_model_helper import DiscriminatorEpilogue + + +class SingleDiscriminator(torch.nn.Module): + def __init__(self, + c_dim, # Conditioning label (C) dimensionality. + img_resolution, # Input resolution. + img_channels, # Number of input color channels. + architecture = 'resnet', # Architecture: 'orig', 'skip', 'resnet'. + channel_base = 32768, # Overall multiplier for the number of channels. + channel_max = 512, # Maximum number of channels in any layer. + num_fp16_res = 4, # Use FP16 for the N highest resolutions. + conv_clamp = 256, # Clamp the output of convolution layers to +-X, None = disable clamping. + cmap_dim = None, # Dimensionality of mapped conditioning label, None = default. + sr_upsample_factor = 1, # Ignored for SingleDiscriminator + block_kwargs = {}, # Arguments for DiscriminatorBlock. + mapping_kwargs = {}, # Arguments for MappingNetwork. + epilogue_kwargs = {}, # Arguments for DiscriminatorEpilogue. 
+ ): + super().__init__() + self.c_dim = c_dim + self.img_resolution = img_resolution + self.img_resolution_log2 = int(np.log2(img_resolution)) + self.img_channels = img_channels + self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)] + channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]} + fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8) + + if cmap_dim is None: + cmap_dim = channels_dict[4] + if c_dim == 0: + cmap_dim = 0 + + common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp) + cur_layer_idx = 0 + for res in self.block_resolutions: + in_channels = channels_dict[res] if res < img_resolution else 0 + tmp_channels = channels_dict[res] + out_channels = channels_dict[res // 2] + use_fp16 = (res >= fp16_resolution) + block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res, + first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs) + setattr(self, f'b{res}', block) + cur_layer_idx += block.num_layers + if c_dim > 0: + self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs) + self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, **epilogue_kwargs, **common_kwargs) + + def forward(self, img, c, update_emas=False, **block_kwargs): + img = img['image'] + + _ = update_emas # unused + x = None + for res in self.block_resolutions: + block = getattr(self, f'b{res}') + x, img = block(x, img, **block_kwargs) + + cmap = None + if self.c_dim > 0: + cmap = self.mapping(None, c) + x = self.b4(x, img, cmap) + return x + + def extra_repr(self): + return f'c_dim={self.c_dim:d}, img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d}' + +#---------------------------------------------------------------------------- + +def filtered_resizing(image_orig_tensor, size, f, filter_mode='antialiased'): + if filter_mode == 'antialiased': + ada_filtered_64 = torch.nn.functional.interpolate(image_orig_tensor, size=(size, size), mode='bilinear', align_corners=False) + elif filter_mode == 'classic': + ada_filtered_64 = upfirdn2d.upsample2d(image_orig_tensor, f, up=2) + ada_filtered_64 = torch.nn.functional.interpolate(ada_filtered_64, size=(size * 2 + 2, size * 2 + 2), mode='bilinear', align_corners=False) + ada_filtered_64 = upfirdn2d.downsample2d(ada_filtered_64, f, down=2, flip_filter=True, padding=-1) + elif filter_mode == 'none': + ada_filtered_64 = torch.nn.functional.interpolate(image_orig_tensor, size=(size, size), mode='bilinear', align_corners=False) + elif type(filter_mode) == float: + assert 0 < filter_mode < 1 + + filtered = torch.nn.functional.interpolate(image_orig_tensor, size=(size, size), mode='bilinear', align_corners=False) + aliased = torch.nn.functional.interpolate(image_orig_tensor, size=(size, size), mode='bilinear', align_corners=False) + ada_filtered_64 = (1 - filter_mode) * aliased + (filter_mode) * filtered + + return ada_filtered_64 + +#---------------------------------------------------------------------------- + +class DualDiscriminator(torch.nn.Module): + def __init__(self, + c_dim, # Conditioning label (C) dimensionality. + img_resolution, # Input resolution. + img_channels, # Number of input color channels. + bev_channels = 0, + architecture = 'resnet', # Architecture: 'orig', 'skip', 'resnet'. + channel_base = 32768, # Overall multiplier for the number of channels. 
+ channel_max = 512, # Maximum number of channels in any layer. + num_fp16_res = 4, # Use FP16 for the N highest resolutions. + conv_clamp = 256, # Clamp the output of convolution layers to +-X, None = disable clamping. + cmap_dim = None, # Dimensionality of mapped conditioning label, None = default. + disc_c_noise = 0, # Corrupt camera parameters with X std dev of noise before disc. pose conditioning. + block_kwargs = {}, # Arguments for DiscriminatorBlock. + mapping_kwargs = {}, # Arguments for MappingNetwork. + epilogue_kwargs = {}, # Arguments for DiscriminatorEpilogue. + ): + super().__init__() + img_channels *= 2 + + self.c_dim = c_dim + self.img_resolution = img_resolution + self.img_resolution_log2 = int(np.log2(img_resolution)) + self.img_channels = img_channels + bev_channels + self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)] + channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]} + fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8) + + if cmap_dim is None: + cmap_dim = channels_dict[4] + if c_dim == 0: + cmap_dim = 0 + + common_kwargs = dict(img_channels=self.img_channels, architecture=architecture, conv_clamp=conv_clamp) + cur_layer_idx = 0 + for res in self.block_resolutions: + in_channels = channels_dict[res] if res < img_resolution else 0 + tmp_channels = channels_dict[res] + out_channels = channels_dict[res // 2] + use_fp16 = (res >= fp16_resolution) + block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res, + first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs) + setattr(self, f'b{res}', block) + cur_layer_idx += block.num_layers + if c_dim > 0: + self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs) + self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, **epilogue_kwargs, **common_kwargs) + self.register_buffer('resample_filter', upfirdn2d.setup_filter([1,3,3,1])) + self.disc_c_noise = disc_c_noise + + def forward(self, img, c, bev=None, update_emas=False, **block_kwargs): + image_raw = filtered_resizing(img['image_raw'], size=img['image'].shape[-1], f=self.resample_filter) + img = torch.cat([img['image'], image_raw], 1) + if bev is not None: + img = torch.cat([img, bev], 1) + + _ = update_emas # unused + x = None + for res in self.block_resolutions: + block = getattr(self, f'b{res}') + x, img = block(x, img, **block_kwargs) + + cmap = None + if self.c_dim > 0: + if self.disc_c_noise > 0: c += torch.randn_like(c) * c.std(0) * self.disc_c_noise + cmap = self.mapping(None, c) + x = self.b4(x, img, cmap) + return x + + def extra_repr(self): + return f'c_dim={self.c_dim:d}, img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d}' + +#---------------------------------------------------------------------------- + +class DummyDualDiscriminator(torch.nn.Module): + def __init__(self, + c_dim, # Conditioning label (C) dimensionality. + img_resolution, # Input resolution. + img_channels, # Number of input color channels. + architecture = 'resnet', # Architecture: 'orig', 'skip', 'resnet'. + channel_base = 32768, # Overall multiplier for the number of channels. + channel_max = 512, # Maximum number of channels in any layer. + num_fp16_res = 4, # Use FP16 for the N highest resolutions. + conv_clamp = 256, # Clamp the output of convolution layers to +-X, None = disable clamping. 
+ cmap_dim = None, # Dimensionality of mapped conditioning label, None = default. + block_kwargs = {}, # Arguments for DiscriminatorBlock. + mapping_kwargs = {}, # Arguments for MappingNetwork. + epilogue_kwargs = {}, # Arguments for DiscriminatorEpilogue. + ): + super().__init__() + img_channels *= 2 + + self.c_dim = c_dim + self.img_resolution = img_resolution + self.img_resolution_log2 = int(np.log2(img_resolution)) + self.img_channels = img_channels + self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)] + channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]} + fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8) + + if cmap_dim is None: + cmap_dim = channels_dict[4] + if c_dim == 0: + cmap_dim = 0 + + common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp) + cur_layer_idx = 0 + for res in self.block_resolutions: + in_channels = channels_dict[res] if res < img_resolution else 0 + tmp_channels = channels_dict[res] + out_channels = channels_dict[res // 2] + use_fp16 = (res >= fp16_resolution) + block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res, + first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs) + setattr(self, f'b{res}', block) + cur_layer_idx += block.num_layers + if c_dim > 0: + self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs) + self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, **epilogue_kwargs, **common_kwargs) + self.register_buffer('resample_filter', upfirdn2d.setup_filter([1,3,3,1])) + + self.raw_fade = 1 + + def forward(self, img, c, update_emas=False, **block_kwargs): + self.raw_fade = max(0, self.raw_fade - 1/(500000/32)) + + image_raw = filtered_resizing(img['image_raw'], size=img['image'].shape[-1], f=self.resample_filter) * self.raw_fade + img = torch.cat([img['image'], image_raw], 1) + + _ = update_emas # unused + x = None + for res in self.block_resolutions: + block = getattr(self, f'b{res}') + x, img = block(x, img, **block_kwargs) + + cmap = None + if self.c_dim > 0: + cmap = self.mapping(None, c) + x = self.b4(x, img, cmap) + return x + + def extra_repr(self): + return f'c_dim={self.c_dim:d}, img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d}' + +#---------------------------------------------------------------------------- + diff --git a/models/eg3d_generator.py b/models/eg3d_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..4284c912234d75f1989c54f186e52d0ce1428007 --- /dev/null +++ b/models/eg3d_generator.py @@ -0,0 +1,315 @@ +# python3.8 +"""Contains the implementation of generator described in EG3D.""" + +import torch +import torch.nn as nn +from models.utils.official_stylegan2_model_helper import Generator as StyleGAN2Backbone +from models.utils.official_stylegan2_model_helper import FullyConnectedLayer +from models.utils.eg3d_superres import SuperresolutionHybrid2X +from models.utils.eg3d_superres import SuperresolutionHybrid4X +from models.utils.eg3d_superres import SuperresolutionHybrid8XDC +from models.rendering.renderer import Renderer +from models.rendering.feature_extractor import FeatureExtractor + +class EG3DGenerator(nn.Module): + + def __init__( + self, + z_dim, # Input latent (Z) dimensionality. + c_dim, # Conditioning label (C) dimensionality. + w_dim, # Intermediate latent (W) dimensionality. 
+ img_resolution, # Output resolution. + img_channels, # Number of output color channels. + sr_num_fp16_res=0, # Number of fp16 layers of SR Network. + mapping_kwargs={}, # Arguments for MappingNetwork. + rendering_kwargs={}, # Arguments for rendering. + sr_kwargs={}, # Arguments for SuperResolution Network. + **synthesis_kwargs, # Arguments for SynthesisNetwork. + ): + super().__init__() + self.z_dim = z_dim + self.c_dim = c_dim + self.w_dim = w_dim + self.img_resolution = img_resolution + self.img_channels = img_channels + + # Set up the overall renderer. + self.renderer = Renderer() + + # Set up the feature extractor. + self.feature_extractor = FeatureExtractor(ref_mode='tri_plane') + + # Set up the reference representation generator. + self.backbone = StyleGAN2Backbone(z_dim, + c_dim, + w_dim, + img_resolution=256, + img_channels=32 * 3, + mapping_kwargs=mapping_kwargs, + **synthesis_kwargs) + + # Set up the post module in the feature extractor. + self.post_module = None + + # Set up the post neural renderer. + self.post_neural_renderer = None + sr_kwargs_total = dict( + channels=32, + img_resolution=img_resolution, + sr_num_fp16_res=sr_num_fp16_res, + sr_antialias=rendering_kwargs['sr_antialias'],) + sr_kwargs_total.update(**sr_kwargs) + if img_resolution == 128: + self.post_neural_renderer = SuperresolutionHybrid2X( + **sr_kwargs_total) + elif img_resolution == 256: + self.post_neural_renderer = SuperresolutionHybrid4X( + **sr_kwargs_total) + elif img_resolution == 512: + self.post_neural_renderer = SuperresolutionHybrid8XDC( + **sr_kwargs_total) + else: + raise TypeError(f'Unsupported image resolution: {img_resolution}!') + + # Set up the fully-connected layer head. + self.fc_head = OSGDecoder( + 32, { + 'decoder_lr_mul': rendering_kwargs.get('decoder_lr_mul', 1), + 'decoder_output_dim': 32 + }) + + # Set up some rendering related arguments. + self.neural_rendering_resolution = rendering_kwargs.get( + 'resolution', 64) + self.rendering_kwargs = rendering_kwargs + + def mapping(self, + z, + c, + truncation_psi=1, + truncation_cutoff=None, + update_emas=False): + if self.rendering_kwargs['c_gen_conditioning_zero']: + c = torch.zeros_like(c) + return self.backbone.mapping(z, + c * + self.rendering_kwargs.get('c_scale', 0), + truncation_psi=truncation_psi, + truncation_cutoff=truncation_cutoff, + update_emas=update_emas) + + def synthesis(self, + wp, + c, + neural_rendering_resolution=None, + update_emas=False, + **synthesis_kwargs): + cam2world_matrix = c[:, :16].view(-1, 4, 4) + if self.rendering_kwargs.get('random_pose', False): + cam2world_matrix = None + + if neural_rendering_resolution is None: + neural_rendering_resolution = self.neural_rendering_resolution + else: + self.neural_rendering_resolution = neural_rendering_resolution + + tri_planes = self.backbone.synthesis(wp, + update_emas=update_emas, + **synthesis_kwargs) + tri_planes = tri_planes.view(len(tri_planes), 3, -1, + tri_planes.shape[-2], + tri_planes.shape[-1]) + + rendering_result = self.renderer( + wp=wp, + feature_extractor=self.feature_extractor, + rendering_options=self.rendering_kwargs, + cam2world_matrix=cam2world_matrix, + position_encoder=None, + ref_representation=tri_planes, + post_module=self.post_module, + fc_head=self.fc_head) + + feature_samples = rendering_result['composite_rgb'] + depth_samples = rendering_result['composite_depth'] + + # Reshape to keep consistent with 'raw' neural-rendered image. 
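+        # feature_samples: [N, H*W, C] -> feature_image: [N, C, H, W]; depth_samples -> [N, 1, H, W].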
+ N = wp.shape[0] + H = W = self.neural_rendering_resolution + feature_image = feature_samples.permute(0, 2, 1).reshape( + N, feature_samples.shape[-1], H, W).contiguous() + depth_image = depth_samples.permute(0, 2, 1).reshape(N, 1, H, W) + + # Run the post neural renderer to get final image. + # Here, the post neural renderer is a super-resolution network. + rgb_image = feature_image[:, :3] + sr_image = self.post_neural_renderer( + rgb_image, + feature_image, + wp, + noise_mode=self.rendering_kwargs['superresolution_noise_mode'], + **{ + k: synthesis_kwargs[k] + for k in synthesis_kwargs.keys() if k != 'noise_mode' + }) + + return { + 'image': sr_image, + 'image_raw': rgb_image, + 'image_depth': depth_image + } + + def sample(self, + coordinates, + directions, + z, + c, + truncation_psi=1, + truncation_cutoff=None, + update_emas=False, + **synthesis_kwargs): + # Compute RGB features, density for arbitrary 3D coordinates. + # Mostly used for extracting shapes. + wp = self.mapping(z, + c, + truncation_psi=truncation_psi, + truncation_cutoff=truncation_cutoff, + update_emas=update_emas) + tri_planes = self.backbone.synthesis(wp, + update_emas=update_emas, + **synthesis_kwargs) + tri_planes = tri_planes.view(len(tri_planes), 3, -1, + tri_planes.shape[-2], + tri_planes.shape[-1]) + result = self.renderer.get_sigma_rgb( + wp=wp, + points=coordinates, + feature_extractor=self.feature_extractor, + fc_head=self.fc_head, + rendering_options=self.rendering_kwargs, + ref_representation=tri_planes, + post_module=self.post_module, + ray_dirs=directions) + + return result + + def sample_mixed(self, + coordinates, + directions, + wp, + truncation_psi=1, + truncation_cutoff=None, + update_emas=False, + **synthesis_kwargs): + # Same as function `self.sample()`, but expects latent vectors 'wp' + # instead of Gaussian noise 'z'. + tri_planes = self.backbone.synthesis(wp, + update_emas=update_emas, + **synthesis_kwargs) + tri_planes = tri_planes.view(len(tri_planes), 3, -1, + tri_planes.shape[-2], + tri_planes.shape[-1]) + + result = self.renderer.get_sigma_rgb( + wp=wp, + points=coordinates, + feature_extractor=self.feature_extractor, + fc_head=self.fc_head, + rendering_options=self.rendering_kwargs, + ref_representation=tri_planes, + post_module=self.post_module, + ray_dirs=directions) + + return result + + def forward(self, + z, + c, + c_swapped=None, # `c_swapped` is swapped pose conditioning. + style_mixing_prob=0, + truncation_psi=1, + truncation_cutoff=None, + neural_rendering_resolution=None, + update_emas=False, + sample_mixed=False, + coordinates=None, + **synthesis_kwargs): + + # Render a batch of generated images. + c_wp = c.clone() + if c_swapped is not None: + c_wp = c_swapped.clone() + wp = self.mapping(z, + c_wp, + truncation_psi=truncation_psi, + truncation_cutoff=truncation_cutoff, + update_emas=update_emas) + if style_mixing_prob > 0: + cutoff = torch.empty([], dtype=torch.int64, + device=wp.device).random_(1, wp.shape[1]) + cutoff = torch.where( + torch.rand([], device=wp.device) < style_mixing_prob, + cutoff, torch.full_like(cutoff, wp.shape[1])) + wp[:, cutoff:] = self.mapping(torch.randn_like(z), + c, + update_emas=update_emas)[:, cutoff:] + if not sample_mixed: + gen_output = self.synthesis( + wp, + c, + update_emas=update_emas, + neural_rendering_resolution=neural_rendering_resolution, + **synthesis_kwargs) + + return { + 'wp': wp, + 'gen_output': gen_output, + } + + else: + # Only for density regularization in training process. 
+ assert coordinates is not None + sample_sigma = self.sample_mixed(coordinates, + torch.randn_like(coordinates), + wp, + update_emas=False)['sigma'] + + return { + 'wp': wp, + 'sample_sigma': sample_sigma + } + + +class OSGDecoder(nn.Module): + """Defines fully-connected layer head in EG3D.""" + def __init__(self, n_features, options): + super().__init__() + self.hidden_dim = 64 + + self.net = nn.Sequential( + FullyConnectedLayer(n_features, + self.hidden_dim, + lr_multiplier=options['decoder_lr_mul']), + nn.Softplus(), + FullyConnectedLayer(self.hidden_dim, + 1 + options['decoder_output_dim'], + lr_multiplier=options['decoder_lr_mul'])) + + def forward(self, point_features, wp=None, dirs=None): + # Aggregate features + # point_features.shape: [N, 3, M, C]. + # Average across 'X, Y, Z' planes. + point_features = point_features.mean(1) + x = point_features + + N, M, C = x.shape + x = x.view(N * M, C) + + x = self.net(x) + x = x.view(N, M, -1) + + # Uses sigmoid clamping from MipNeRF + rgb = torch.sigmoid(x[..., 1:]) * (1 + 2 * 0.001) - 0.001 + sigma = x[..., 0:1] + + return {'rgb': rgb, 'sigma': sigma} diff --git a/models/eg3d_generator_fv.py b/models/eg3d_generator_fv.py new file mode 100644 index 0000000000000000000000000000000000000000..14cc488278034879a2e0f618c27e3a61ecd5c845 --- /dev/null +++ b/models/eg3d_generator_fv.py @@ -0,0 +1,320 @@ +# python3.8 +"""Contains the implementation of generator described in EG3D.""" + +import torch +import torch.nn as nn +import numpy as np +from models.utils.official_stylegan2_model_helper import MappingNetwork +from models.utils.official_stylegan2_model_helper import FullyConnectedLayer +from models.utils.eg3d_superres import SuperresolutionHybrid2X +from models.utils.eg3d_superres import SuperresolutionHybrid4X +from models.utils.eg3d_superres import SuperresolutionHybrid8XDC +from models.rendering.renderer import Renderer +from models.rendering.feature_extractor import FeatureExtractor +from models.volumegan_generator import FeatureVolume +from models.volumegan_generator import PositionEncoder + + +class EG3DGeneratorFV(nn.Module): + + def __init__( + self, + # Input latent (Z) dimensionality. + z_dim, + # Conditioning label (C) dimensionality. + c_dim, + # Intermediate latent (W) dimensionality. + w_dim, + # Final output image resolution. + img_resolution, + # Number of output color channels. + img_channels, + # Number of fp16 layers of SR Network. + sr_num_fp16_res=0, + # Arguments for MappingNetwork. + mapping_kwargs={}, + # Arguments for rendering. + rendering_kwargs={}, + # Arguments for SuperResolution Network. + sr_kwargs={}, + # Configs for FeatureVolume. + fv_cfg=dict(feat_res=32, + init_res=4, + base_channels=256, + output_channels=32, + w_dim=512), + # Configs for position encoder. + embed_cfg=dict(input_dim=3, max_freq_log2=10 - 1, N_freqs=10), + ): + super().__init__() + self.z_dim = z_dim + self.c_dim = c_dim + self.w_dim = w_dim + self.img_resolution = img_resolution + self.img_channels = img_channels + + # Set up mapping network. + # Here `num_ws = 2`: one for FeatureVolume Network injection and one for + # post_neural_renderer injection. + num_ws = 2 + self.mapping_network = MappingNetwork(z_dim=z_dim, + c_dim=c_dim, + w_dim=w_dim, + num_ws=num_ws, + **mapping_kwargs) + + # Set up the overall renderer. + self.renderer = Renderer() + + # Set up the feature extractor. + self.feature_extractor = FeatureExtractor(ref_mode='feature_volume') + + # Set up the reference representation generator. 
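+        # Unlike the tri-plane backbone in `eg3d_generator.py`, wp here drives a VolumeGAN-style FeatureVolume that the renderer samples directly.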
+ self.ref_representation_generator = FeatureVolume(**fv_cfg) + + # Set up the position encoder. + self.position_encoder = PositionEncoder(**embed_cfg) + + # Set up the post module in the feature extractor. + self.post_module = None + + # Set up the post neural renderer. + self.post_neural_renderer = None + sr_kwargs_total = dict( + channels=32, + img_resolution=img_resolution, + sr_num_fp16_res=sr_num_fp16_res, + sr_antialias=rendering_kwargs['sr_antialias'],) + sr_kwargs_total.update(**sr_kwargs) + if img_resolution == 128: + self.post_neural_renderer = SuperresolutionHybrid2X( + **sr_kwargs_total) + elif img_resolution == 256: + self.post_neural_renderer = SuperresolutionHybrid4X( + **sr_kwargs_total) + elif img_resolution == 512: + self.post_neural_renderer = SuperresolutionHybrid8XDC( + **sr_kwargs_total) + else: + raise TypeError(f'Unsupported image resolution: {img_resolution}!') + + # Set up the fully-connected layer head. + self.fc_head = OSGDecoder( + 32, { + 'decoder_lr_mul': rendering_kwargs.get('decoder_lr_mul', 1), + 'decoder_output_dim': 32 + }) + + # Set up some rendering related arguments. + self.neural_rendering_resolution = rendering_kwargs.get( + 'resolution', 64) + self.rendering_kwargs = rendering_kwargs + + def mapping(self, + z, + c, + truncation_psi=1, + truncation_cutoff=None, + update_emas=False): + if self.rendering_kwargs['c_gen_conditioning_zero']: + c = torch.zeros_like(c) + return self.mapping_network(z, + c * + self.rendering_kwargs.get('c_scale', 0), + truncation_psi=truncation_psi, + truncation_cutoff=truncation_cutoff, + update_emas=update_emas) + + def synthesis(self, + wp, + c, + neural_rendering_resolution=None, + update_emas=False, + **synthesis_kwargs): + cam2world_matrix = c[:, :16].view(-1, 4, 4) + if self.rendering_kwargs.get('random_pose', False): + cam2world_matrix = None + + if neural_rendering_resolution is None: + neural_rendering_resolution = self.neural_rendering_resolution + else: + self.neural_rendering_resolution = neural_rendering_resolution + + feature_volume = self.ref_representation_generator(wp) + + rendering_result = self.renderer( + wp=wp, + feature_extractor=self.feature_extractor, + rendering_options=self.rendering_kwargs, + cam2world_matrix=cam2world_matrix, + position_encoder=self.position_encoder, + ref_representation=feature_volume, + post_module=self.post_module, + fc_head=self.fc_head) + + feature_samples = rendering_result['composite_rgb'] + depth_samples = rendering_result['composite_depth'] + + # Reshape to keep consistent with 'raw' neural-rendered image. + N = wp.shape[0] + H = W = self.neural_rendering_resolution + feature_image = feature_samples.permute(0, 2, 1).reshape( + N, feature_samples.shape[-1], H, W).contiguous() + depth_image = depth_samples.permute(0, 2, 1).reshape(N, 1, H, W) + + # Run the post neural renderer to get final image. + # Here, the post neural renderer is a super-resolution network. + rgb_image = feature_image[:, :3] + sr_image = self.post_neural_renderer( + rgb_image, + feature_image, + wp, + noise_mode=self.rendering_kwargs['superresolution_noise_mode'], + **{ + k: synthesis_kwargs[k] + for k in synthesis_kwargs.keys() if k != 'noise_mode' + }) + + return { + 'image': sr_image, + 'image_raw': rgb_image, + 'image_depth': depth_image + } + + def sample(self, + coordinates, + directions, + z, + c, + truncation_psi=1, + truncation_cutoff=None, + update_emas=False): + # Compute RGB features, density for arbitrary 3D coordinates. + # Mostly used for extracting shapes. 
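+        # Map (z, c) to wp, rebuild the feature volume, then query sigma/rgb at the given points without compositing a full image.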
+ wp = self.mapping_network(z, + c, + truncation_psi=truncation_psi, + truncation_cutoff=truncation_cutoff, + update_emas=update_emas) + feature_volume = self.ref_representation_generator(wp) + result = self.renderer.get_sigma_rgb( + wp=wp, + points=coordinates, + feature_extractor=self.feature_extractor, + fc_head=self.fc_head, + rendering_options=self.rendering_kwargs, + ref_representation=feature_volume, + position_encoder=self.position_encoder, + post_module=self.post_module, + ray_dirs=directions) + + return result + + def sample_mixed(self, + coordinates, + directions, + wp): + # Same as function `self.sample()`, but expects latent vectors 'wp' + # instead of Gaussian noise 'z'. + feature_volume = self.ref_representation_generator(wp) + result = self.renderer.get_sigma_rgb( + wp=wp, + points=coordinates, + feature_extractor=self.feature_extractor, + fc_head=self.fc_head, + rendering_options=self.rendering_kwargs, + ref_representation=feature_volume, + position_encoder=self.position_encoder, + post_module=self.post_module, + ray_dirs=directions) + + return result + + def forward(self, + z, + c, + c_swapped=None, # `c_swapped` is swapped pose conditioning. + style_mixing_prob=0, + truncation_psi=1, + truncation_cutoff=None, + neural_rendering_resolution=None, + update_emas=False, + sample_mixed=False, + coordinates=None, + **synthesis_kwargs): + + # Render a batch of generated images. + c_wp = c.clone() + if c_swapped is not None: + c_wp = c_swapped.clone() + wp = self.mapping_network(z, + c_wp, + truncation_psi=truncation_psi, + truncation_cutoff=truncation_cutoff, + update_emas=update_emas) + if style_mixing_prob > 0: + cutoff = torch.empty([], dtype=torch.int64, + device=wp.device).random_(1, wp.shape[1]) + cutoff = torch.where( + torch.rand([], device=wp.device) < style_mixing_prob, cutoff, + torch.full_like(cutoff, wp.shape[1])) + wp[:, cutoff:] = self.mapping_network( + torch.randn_like(z), c, update_emas=update_emas)[:, cutoff:] + if not sample_mixed: + gen_output = self.synthesis( + wp, + c, + update_emas=update_emas, + neural_rendering_resolution=neural_rendering_resolution, + **synthesis_kwargs) + + return { + 'wp': wp, + 'gen_output': gen_output, + } + + else: + # Only for density regularization in training process. + assert coordinates is not None + sample_sigma = self.sample_mixed(coordinates, + torch.randn_like(coordinates), + wp)['sigma'] + + return { + 'wp': wp, + 'sample_sigma': sample_sigma + } + + +class OSGDecoder(nn.Module): + """Defines fully-connected layer head in EG3D.""" + def __init__(self, n_features, options): + super().__init__() + self.hidden_dim = 64 + + self.net = nn.Sequential( + FullyConnectedLayer(n_features, + self.hidden_dim, + lr_multiplier=options['decoder_lr_mul']), + nn.Softplus(), + FullyConnectedLayer(self.hidden_dim, + 1 + options['decoder_output_dim'], + lr_multiplier=options['decoder_lr_mul'])) + + def forward(self, point_features, wp=None, dirs=None): + # point_features.shape: [N, C, M, 1]. 
+ point_features = point_features.squeeze(-1) + point_features = point_features.permute(0, 2, 1) + x = point_features + + N, M, C = x.shape + x = x.reshape(N * M, C) + + x = self.net(x) + x = x.reshape(N, M, -1) + + # Uses sigmoid clamping from MipNeRF + rgb = torch.sigmoid(x[..., 1:]) * (1 + 2 * 0.001) - 0.001 + sigma = x[..., 0:1] + + return {'rgb': rgb, 'sigma': sigma} diff --git a/models/ghfeat_encoder.py b/models/ghfeat_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..dba2ec7fe8e03df2ae9fda81daf087deb218cdf0 --- /dev/null +++ b/models/ghfeat_encoder.py @@ -0,0 +1,563 @@ +# python3.7 +"""Contains the implementation of encoder used in GH-Feat (including IDInvert). + +ResNet is used as the backbone. + +GH-Feat paper: https://arxiv.org/pdf/2007.10379.pdf +IDInvert paper: https://arxiv.org/pdf/2004.00049.pdf + +NOTE: Please use `latent_num` and `num_latents_per_head` to control the +inversion space, such as Y-space used in GH-Feat and W-space used in IDInvert. +In addition, IDInvert sets `use_fpn` and `use_sam` as `False` by default. +""" + +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as dist + +__all__ = ['GHFeatEncoder'] + +# Resolutions allowed. +_RESOLUTIONS_ALLOWED = [8, 16, 32, 64, 128, 256, 512, 1024] + +# pylint: disable=missing-function-docstring + +class BasicBlock(nn.Module): + """Implementation of ResNet BasicBlock.""" + + expansion = 1 + + def __init__(self, + inplanes, + planes, + base_width=64, + stride=1, + groups=1, + dilation=1, + norm_layer=None, + downsample=None): + super().__init__() + if base_width != 64: + raise ValueError(f'BasicBlock of ResNet only supports ' + f'`base_width=64`, but {base_width} received!') + if stride not in [1, 2]: + raise ValueError(f'BasicBlock of ResNet only supports `stride=1` ' + f'and `stride=2`, but {stride} received!') + if groups != 1: + raise ValueError(f'BasicBlock of ResNet only supports `groups=1`, ' + f'but {groups} received!') + if dilation != 1: + raise ValueError(f'BasicBlock of ResNet only supports ' + f'`dilation=1`, but {dilation} received!') + assert self.expansion == 1 + + self.stride = stride + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self.conv1 = nn.Conv2d(in_channels=inplanes, + out_channels=planes, + kernel_size=3, + stride=stride, + padding=1, + groups=1, + dilation=1, + bias=False) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(in_channels=planes, + out_channels=planes, + kernel_size=3, + stride=1, + padding=1, + groups=1, + dilation=1, + bias=False) + self.bn2 = norm_layer(planes) + self.downsample = downsample + + def forward(self, x): + identity = self.downsample(x) if self.downsample is not None else x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out + identity) + + return out + + +class Bottleneck(nn.Module): + """Implementation of ResNet Bottleneck.""" + + expansion = 4 + + def __init__(self, + inplanes, + planes, + base_width=64, + stride=1, + groups=1, + dilation=1, + norm_layer=None, + downsample=None): + super().__init__() + if stride not in [1, 2]: + raise ValueError(f'Bottleneck of ResNet only supports `stride=1` ' + f'and `stride=2`, but {stride} received!') + + width = int(planes * (base_width / 64)) * groups + self.stride = stride + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self.conv1 = nn.Conv2d(in_channels=inplanes, + out_channels=width, 
+ kernel_size=1, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=False) + self.bn1 = norm_layer(width) + self.conv2 = nn.Conv2d(in_channels=width, + out_channels=width, + kernel_size=3, + stride=stride, + padding=dilation, + groups=groups, + dilation=dilation, + bias=False) + self.bn2 = norm_layer(width) + self.conv3 = nn.Conv2d(in_channels=width, + out_channels=planes * self.expansion, + kernel_size=1, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=False) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + + def forward(self, x): + identity = self.downsample(x) if self.downsample is not None else x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + out = self.relu(out + identity) + + return out + + +class GHFeatEncoder(nn.Module): + """Define the ResNet-based encoder network for GAN inversion. + + On top of the backbone, there are several task-heads to produce inverted + codes. Please use `latent_dim` and `num_latents_per_head` to define the + structure. For example, `latent_dim = [512] * 14` and + `num_latents_per_head = [4, 4, 6]` can be used for StyleGAN inversion with + 14-layer latent codes, where 3 task heads (corresponding to 4, 4, 6 layers, + respectively) are used. + + Settings for the encoder network: + + (1) resolution: The resolution of the output image. + (2) latent_dim: Dimension of the latent space. A number (one code will be + produced), or a list of numbers regarding layer-wise latent codes. + (3) num_latents_per_head: Number of latents that is produced by each head. + (4) image_channels: Number of channels of the output image. (default: 3) + (5) final_res: Final resolution of the convolutional layers. (default: 4) + + ResNet-related settings: + + (1) network_depth: Depth of the network, like 18 for ResNet18. (default: 18) + (2) inplanes: Number of channels of the first convolutional layer. + (default: 64) + (3) groups: Groups of the convolution, used in ResNet. (default: 1) + (4) width_per_group: Number of channels per group, used in ResNet. + (default: 64) + (5) replace_stride_with_dilation: Whether to replace stride with dilation, + used in ResNet. (default: None) + (6) norm_layer: Normalization layer used in the encoder. If set as `None`, + `nn.BatchNorm2d` will be used. Also, please NOTE that when using batch + normalization, the batch size is required to be larger than one for + training. (default: nn.BatchNorm2d) + (7) max_channels: Maximum number of channels in each layer. (default: 512) + + Task-head related settings: + + (1) use_fpn: Whether to use Feature Pyramid Network (FPN) before outputting + the latent code. (default: True) + (2) fpn_channels: Number of channels used in FPN. (default: 512) + (3) use_sam: Whether to use Spatial Alignment Module (SAM) before outputting + the latent code. (default: True) + (4) sam_channels: Number of channels used in SAM. 
(default: 512) + """ + + arch_settings = { + 18: (BasicBlock, [2, 2, 2, 2]), + 34: (BasicBlock, [3, 4, 6, 3]), + 50: (Bottleneck, [3, 4, 6, 3]), + 101: (Bottleneck, [3, 4, 23, 3]), + 152: (Bottleneck, [3, 8, 36, 3]) + } + + def __init__(self, + resolution, + latent_dim, + num_latents_per_head, + image_channels=3, + final_res=4, + network_depth=18, + inplanes=64, + groups=1, + width_per_group=64, + replace_stride_with_dilation=None, + norm_layer=nn.BatchNorm2d, + max_channels=512, + use_fpn=True, + fpn_channels=512, + use_sam=True, + sam_channels=512): + super().__init__() + + if resolution not in _RESOLUTIONS_ALLOWED: + raise ValueError(f'Invalid resolution: `{resolution}`!\n' + f'Resolutions allowed: {_RESOLUTIONS_ALLOWED}.') + if network_depth not in self.arch_settings: + raise ValueError(f'Invalid network depth: `{network_depth}`!\n' + f'Options allowed: ' + f'{list(self.arch_settings.keys())}.') + if isinstance(latent_dim, int): + latent_dim = [latent_dim] + assert isinstance(latent_dim, (list, tuple)) + assert isinstance(num_latents_per_head, (list, tuple)) + assert sum(num_latents_per_head) == len(latent_dim) + + self.resolution = resolution + self.latent_dim = latent_dim + self.num_latents_per_head = num_latents_per_head + self.num_heads = len(self.num_latents_per_head) + self.image_channels = image_channels + self.final_res = final_res + self.inplanes = inplanes + self.network_depth = network_depth + self.groups = groups + self.dilation = 1 + self.base_width = width_per_group + self.replace_stride_with_dilation = replace_stride_with_dilation + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if norm_layer == nn.BatchNorm2d and dist.is_initialized(): + norm_layer = nn.SyncBatchNorm + self.norm_layer = norm_layer + self.max_channels = max_channels + self.use_fpn = use_fpn + self.fpn_channels = fpn_channels + self.use_sam = use_sam + self.sam_channels = sam_channels + + block_fn, num_blocks_per_stage = self.arch_settings[network_depth] + + self.num_stages = int(np.log2(resolution // final_res)) - 1 + # Add one block for additional stages. + for i in range(len(num_blocks_per_stage), self.num_stages): + num_blocks_per_stage.append(1) + if replace_stride_with_dilation is None: + replace_stride_with_dilation = [False] * self.num_stages + + # Backbone. + self.conv1 = nn.Conv2d(in_channels=self.image_channels, + out_channels=self.inplanes, + kernel_size=7, + stride=2, + padding=3, + bias=False) + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + self.stage_channels = [self.inplanes] + self.stages = nn.ModuleList() + for i in range(self.num_stages): + inplanes = self.inplanes if i == 0 else planes * block_fn.expansion + planes = min(self.max_channels, self.inplanes * (2 ** i)) + num_blocks = num_blocks_per_stage[i] + stride = 1 if i == 0 else 2 + dilate = replace_stride_with_dilation[i] + self.stages.append(self._make_stage(block_fn=block_fn, + inplanes=inplanes, + planes=planes, + num_blocks=num_blocks, + stride=stride, + dilate=dilate)) + self.stage_channels.append(planes * block_fn.expansion) + + if self.num_heads > len(self.stage_channels): + raise ValueError('Number of task heads is larger than number of ' + 'stages! Please reduce the number of heads.') + + # Task-head. 
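+        # With a single head, multi-level fusion brings nothing, so FPN/SAM are
+        # disabled below. Each head flattens a `final_res x final_res` feature
+        # map and regresses only the latent dimensions assigned to it.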
+ if self.num_heads == 1: + self.use_fpn = False + self.use_sam = False + + if self.use_fpn: + fpn_pyramid_channels = self.stage_channels[-self.num_heads:] + self.fpn = FPN(pyramid_channels=fpn_pyramid_channels, + out_channels=self.fpn_channels) + if self.use_sam: + if self.use_fpn: + sam_pyramid_channels = [self.fpn_channels] * self.num_heads + else: + sam_pyramid_channels = self.stage_channels[-self.num_heads:] + self.sam = SAM(pyramid_channels=sam_pyramid_channels, + out_channels=self.sam_channels) + + self.heads = nn.ModuleList() + for head_idx in range(self.num_heads): + # Parse in_channels. + if self.use_sam: + in_channels = self.sam_channels + elif self.use_fpn: + in_channels = self.fpn_channels + else: + in_channels = self.stage_channels[head_idx - self.num_heads] + in_channels = in_channels * final_res * final_res + + # Parse out_channels. + start_latent_idx = sum(self.num_latents_per_head[:head_idx]) + end_latent_idx = sum(self.num_latents_per_head[:head_idx + 1]) + out_channels = sum(self.latent_dim[start_latent_idx:end_latent_idx]) + + self.heads.append(CodeHead(in_channels=in_channels, + out_channels=out_channels, + norm_layer=self.norm_layer)) + + def _make_stage(self, + block_fn, + inplanes, + planes, + num_blocks, + stride, + dilate): + norm_layer = self.norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or inplanes != planes * block_fn.expansion: + downsample = nn.Sequential( + nn.Conv2d(in_channels=inplanes, + out_channels=planes * block_fn.expansion, + kernel_size=1, + stride=stride, + padding=0, + dilation=1, + groups=1, + bias=False), + norm_layer(planes * block_fn.expansion), + ) + + blocks = [] + blocks.append(block_fn(inplanes=inplanes, + planes=planes, + base_width=self.base_width, + stride=stride, + groups=self.groups, + dilation=previous_dilation, + norm_layer=norm_layer, + downsample=downsample)) + for _ in range(1, num_blocks): + blocks.append(block_fn(inplanes=planes * block_fn.expansion, + planes=planes, + base_width=self.base_width, + stride=1, + groups=self.groups, + dilation=self.dilation, + norm_layer=norm_layer, + downsample=None)) + + return nn.Sequential(*blocks) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + features = [x] + for i in range(self.num_stages): + x = self.stages[i](x) + features.append(x) + features = features[-self.num_heads:] + + if self.use_fpn: + features = self.fpn(features) + if self.use_sam: + features = self.sam(features) + else: + final_size = features[-1].shape[2:] + for i in range(self.num_heads - 1): + features[i] = F.adaptive_avg_pool2d(features[i], final_size) + + outputs = [] + for head_idx in range(self.num_heads): + codes = self.heads[head_idx](features[head_idx]) + start_latent_idx = sum(self.num_latents_per_head[:head_idx]) + end_latent_idx = sum(self.num_latents_per_head[:head_idx + 1]) + split_size = self.latent_dim[start_latent_idx:end_latent_idx] + outputs.extend(torch.split(codes, split_size, dim=1)) + max_dim = max(self.latent_dim) + for i, dim in enumerate(self.latent_dim): + if dim < max_dim: + outputs[i] = F.pad(outputs[i], (0, max_dim - dim)) + outputs[i] = outputs[i].unsqueeze(1) + + return torch.cat(outputs, dim=1) + + +class FPN(nn.Module): + """Implementation of Feature Pyramid Network (FPN). + + The input of this module is a pyramid of features with reducing resolutions. + Then, this module fuses these multi-level features from `top_level` to + `bottom_level`. 
In particular, starting from the `top_level`, each feature + is convoluted, upsampled, and fused into its previous feature (which is also + convoluted). + + Args: + pyramid_channels: A list of integers, each of which indicates the number + of channels of the feature from a particular level. + out_channels: Number of channels for each output. + + Returns: + A list of feature maps, each of which has `out_channels` channels. + """ + + def __init__(self, pyramid_channels, out_channels): + super().__init__() + assert isinstance(pyramid_channels, (list, tuple)) + self.num_levels = len(pyramid_channels) + + self.lateral_layers = nn.ModuleList() + self.feature_layers = nn.ModuleList() + for i in range(self.num_levels): + in_channels = pyramid_channels[i] + self.lateral_layers.append(nn.Conv2d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + padding=1, + bias=True)) + self.feature_layers.append(nn.Conv2d(in_channels=out_channels, + out_channels=out_channels, + kernel_size=3, + padding=1, + bias=True)) + + def forward(self, inputs): + if len(inputs) != self.num_levels: + raise ValueError('Number of inputs and `num_levels` mismatch!') + + # Project all related features to `out_channels`. + laterals = [] + for i in range(self.num_levels): + laterals.append(self.lateral_layers[i](inputs[i])) + + # Fusion, starting from `top_level`. + for i in range(self.num_levels - 1, 0, -1): + scale_factor = laterals[i - 1].shape[2] // laterals[i].shape[2] + laterals[i - 1] = (laterals[i - 1] + + F.interpolate(laterals[i], + mode='nearest', + scale_factor=scale_factor)) + + # Get outputs. + outputs = [] + for i, lateral in enumerate(laterals): + outputs.append(self.feature_layers[i](lateral)) + + return outputs + + +class SAM(nn.Module): + """Implementation of Spatial Alignment Module (SAM). + + The input of this module is a pyramid of features with reducing resolutions. + Then this module downsamples all levels of feature to the minimum resolution + and fuses it with the smallest feature map. + + Args: + pyramid_channels: A list of integers, each of which indicates the number + of channels of the feature from a particular level. + out_channels: Number of channels for each output. + + Returns: + A list of feature maps, each of which has `out_channels` channels. 
+ """ + + def __init__(self, pyramid_channels, out_channels): + super().__init__() + assert isinstance(pyramid_channels, (list, tuple)) + self.num_levels = len(pyramid_channels) + + self.fusion_layers = nn.ModuleList() + for i in range(self.num_levels): + in_channels = pyramid_channels[i] + self.fusion_layers.append(nn.Conv2d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + padding=1, + bias=True)) + + def forward(self, inputs): + if len(inputs) != self.num_levels: + raise ValueError('Number of inputs and `num_levels` mismatch!') + + output_res = inputs[-1].shape[2:] + for i in range(self.num_levels - 1, -1, -1): + if i != self.num_levels - 1: + inputs[i] = F.adaptive_avg_pool2d(inputs[i], output_res) + inputs[i] = self.fusion_layers[i](inputs[i]) + if i != self.num_levels - 1: + inputs[i] = inputs[i] + inputs[-1] + + return inputs + + +class CodeHead(nn.Module): + """Implementation of the task-head to produce inverted codes.""" + + def __init__(self, in_channels, out_channels, norm_layer): + super().__init__() + self.fc = nn.Linear(in_channels, out_channels, bias=True) + if norm_layer is None: + self.norm = nn.Identity() + else: + self.norm = norm_layer(out_channels) + + def forward(self, x): + if x.ndim > 2: + x = x.flatten(start_dim=1) + latent = self.fc(x) + latent = latent.unsqueeze(2).unsqueeze(3) + latent = self.norm(latent) + + return latent.flatten(start_dim=1) + +# pylint: enable=missing-function-docstring diff --git a/models/inception_model.py b/models/inception_model.py new file mode 100644 index 0000000000000000000000000000000000000000..68fe4ece6b6cdc864b7de49719d7714cabfacedf --- /dev/null +++ b/models/inception_model.py @@ -0,0 +1,562 @@ +# python3.7 +"""Contains the Inception V3 model, which is used for inference ONLY. + +This file is mostly borrowed from `torchvision/models/inception.py`. + +Inception model is widely used to compute FID or IS metric for evaluating +generative models. However, the pre-trained models from torchvision is slightly +different from the TensorFlow version + +http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz + +which is used by the official FID implementation + +https://github.com/bioinf-jku/TTUR + +In particular: + +(1) The number of classes in TensorFlow model is 1008 instead of 1000. +(2) The avg_pool() layers in TensorFlow model does not include the padded zero. +(3) The last Inception E Block in TensorFlow model use max_pool() instead of + avg_pool(). + +Hence, to align the evaluation results with those from TensorFlow +implementation, we modified the inception model to support both versions. Please +use `align_tf` argument to control the version. +""" + +import warnings + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as dist + +from utils.misc import download_url + +__all__ = ['InceptionModel'] + +# pylint: disable=line-too-long + +_MODEL_URL_SHA256 = { + # This model is provided by `torchvision`, which is ported from TensorFlow. 
+ 'torchvision_official': ( + 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth', + '1a9a5a14f40645a370184bd54f4e8e631351e71399112b43ad0294a79da290c8' # hash sha256 + ), + + # This model is provided by https://github.com/mseitzer/pytorch-fid + 'tf_inception_v3': ( + 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth', + '6726825d0af5f729cebd5821db510b11b1cfad8faad88a03f1befd49fb9129b2' # hash sha256 + ) +} + + +class InceptionModel(object): + """Defines the Inception (V3) model. + + This is a static class, which is used to avoid this model to be built + repeatedly. Consequently, this model is particularly used for inference, + like computing FID. If training is required, please use the model from + `torchvision.models` or implement by yourself. + + NOTE: The pre-trained model assumes the inputs to be with `RGB` channel + order and pixel range [-1, 1], and will also resize the images to shape + [299, 299] automatically. If your input is normalized by subtracting + (0.485, 0.456, 0.406) and dividing (0.229, 0.224, 0.225), please use + `transform_input` in the `forward()` function to un-normalize it. + """ + models = dict() + + @staticmethod + def build_model(align_tf=True): + """Builds the model and load pre-trained weights. + + If `align_tf` is set as True, the model will predict 1008 classes, and + the pre-trained weight from `https://github.com/mseitzer/pytorch-fid` + will be loaded. Otherwise, the model will predict 1000 classes, and will + load the model from `torchvision`. + + The built model supports following arguments when forwarding: + + - transform_input: Whether to transform the input back to pixel range + (-1, 1). Please disable this argument if your input is already with + pixel range (-1, 1). (default: False) + - output_logits: Whether to output the categorical logits instead of + features. (default: False) + - remove_logits_bias: Whether to remove the bias when computing the + logits. The official implementation removes the bias by default. + Please refer to + `https://github.com/openai/improved-gan/blob/master/inception_score/model.py`. + (default: False) + - output_predictions: Whether to output the final predictions, i.e., + `softmax(logits)`. (default: False) + """ + if align_tf: + num_classes = 1008 + model_source = 'tf_inception_v3' + else: + num_classes = 1000 + model_source = 'torchvision_official' + + fingerprint = model_source + + if fingerprint not in InceptionModel.models: + # Build model. + model = Inception3(num_classes=num_classes, + aux_logits=False, + init_weights=False, + align_tf=align_tf) + + # Download pre-trained weights. + if dist.is_initialized() and dist.get_rank() != 0: + dist.barrier() # Download by chief. + + url, sha256 = _MODEL_URL_SHA256[model_source] + filename = f'inception_model_{model_source}_{sha256}.pth' + model_path, hash_check = download_url(url, + filename=filename, + sha256=sha256) + state_dict = torch.load(model_path, map_location='cpu') + if hash_check is False: + warnings.warn(f'Hash check failed! The remote file from URL ' + f'`{url}` may be changed, or the downloading is ' + f'interrupted. The loaded inception model may ' + f'have unexpected behavior.') + + if dist.is_initialized() and dist.get_rank() == 0: + dist.barrier() # Wait for other replicas. + + # Load weights. + model.load_state_dict(state_dict, strict=False) + del state_dict + + # For inference only. 
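+            # The cached instance stays frozen on GPU and is shared across
+            # callers. A typical call (sketch, names illustrative) is
+            # `feats = InceptionModel.build_model()(images)` with `images` in
+            # [-1, 1], yielding the 2048-d pooled features used for FID.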
+ model.eval().requires_grad_(False).cuda() + InceptionModel.models[fingerprint] = model + + return InceptionModel.models[fingerprint] + +# pylint: disable=missing-function-docstring +# pylint: disable=missing-class-docstring +# pylint: disable=super-with-arguments +# pylint: disable=consider-merging-isinstance +# pylint: disable=import-outside-toplevel +# pylint: disable=no-else-return + +class Inception3(nn.Module): + + def __init__(self, num_classes=1000, aux_logits=True, inception_blocks=None, + init_weights=True, align_tf=True): + super(Inception3, self).__init__() + if inception_blocks is None: + inception_blocks = [ + BasicConv2d, InceptionA, InceptionB, InceptionC, + InceptionD, InceptionE, InceptionAux + ] + assert len(inception_blocks) == 7 + conv_block = inception_blocks[0] + inception_a = inception_blocks[1] + inception_b = inception_blocks[2] + inception_c = inception_blocks[3] + inception_d = inception_blocks[4] + inception_e = inception_blocks[5] + inception_aux = inception_blocks[6] + + self.aux_logits = aux_logits + self.align_tf = align_tf + self.Conv2d_1a_3x3 = conv_block(3, 32, kernel_size=3, stride=2) + self.Conv2d_2a_3x3 = conv_block(32, 32, kernel_size=3) + self.Conv2d_2b_3x3 = conv_block(32, 64, kernel_size=3, padding=1) + self.Conv2d_3b_1x1 = conv_block(64, 80, kernel_size=1) + self.Conv2d_4a_3x3 = conv_block(80, 192, kernel_size=3) + self.Mixed_5b = inception_a(192, pool_features=32, align_tf=self.align_tf) + self.Mixed_5c = inception_a(256, pool_features=64, align_tf=self.align_tf) + self.Mixed_5d = inception_a(288, pool_features=64, align_tf=self.align_tf) + self.Mixed_6a = inception_b(288) + self.Mixed_6b = inception_c(768, channels_7x7=128, align_tf=self.align_tf) + self.Mixed_6c = inception_c(768, channels_7x7=160, align_tf=self.align_tf) + self.Mixed_6d = inception_c(768, channels_7x7=160, align_tf=self.align_tf) + self.Mixed_6e = inception_c(768, channels_7x7=192, align_tf=self.align_tf) + if aux_logits: + self.AuxLogits = inception_aux(768, num_classes) + self.Mixed_7a = inception_d(768) + self.Mixed_7b = inception_e(1280, align_tf=self.align_tf) + self.Mixed_7c = inception_e(2048, use_max_pool=self.align_tf) + self.fc = nn.Linear(2048, num_classes) + if init_weights: + for m in self.modules(): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): + import scipy.stats as stats + stddev = m.stddev if hasattr(m, 'stddev') else 0.1 + X = stats.truncnorm(-2, 2, scale=stddev) + values = torch.as_tensor(X.rvs(m.weight.numel()), dtype=m.weight.dtype) + values = values.view(m.weight.size()) + with torch.no_grad(): + m.weight.copy_(values) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + @staticmethod + def _transform_input(x, transform_input=False): + if transform_input: + x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5 + x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5 + x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5 + x = torch.cat((x_ch0, x_ch1, x_ch2), 1) + return x + + def _forward(self, + x, + output_logits=False, + remove_logits_bias=False, + output_predictions=False): + # Upsample if necessary. 
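+        # When `align_tf` is set, resizing goes through an affine grid plus
+        # `grid_sample`, which is intended to mimic TensorFlow's resize, and
+        # the `(x * 127.5 + 127.5 - 128) / 128` requantization below mirrors
+        # the original uint8 preprocessing; otherwise plain bilinear is used.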
+ if x.shape[2] != 299 or x.shape[3] != 299: + if self.align_tf: + theta = torch.eye(2, 3).to(x) + theta[0, 2] += theta[0, 0] / x.shape[3] - theta[0, 0] / 299 + theta[1, 2] += theta[1, 1] / x.shape[2] - theta[1, 1] / 299 + theta = theta.unsqueeze(0).repeat(x.shape[0], 1, 1) + grid = F.affine_grid(theta, + size=(x.shape[0], x.shape[1], 299, 299), + align_corners=False) + x = F.grid_sample(x, grid, + mode='bilinear', + padding_mode='border', + align_corners=False) + else: + x = F.interpolate( + x, size=(299, 299), mode='bilinear', align_corners=False) + if x.shape[1] == 1: + x = x.repeat((1, 3, 1, 1)) + + if self.align_tf: + x = (x * 127.5 + 127.5 - 128) / 128 + + # N x 3 x 299 x 299 + x = self.Conv2d_1a_3x3(x) + # N x 32 x 149 x 149 + x = self.Conv2d_2a_3x3(x) + # N x 32 x 147 x 147 + x = self.Conv2d_2b_3x3(x) + # N x 64 x 147 x 147 + x = F.max_pool2d(x, kernel_size=3, stride=2) + # N x 64 x 73 x 73 + x = self.Conv2d_3b_1x1(x) + # N x 80 x 73 x 73 + x = self.Conv2d_4a_3x3(x) + # N x 192 x 71 x 71 + x = F.max_pool2d(x, kernel_size=3, stride=2) + # N x 192 x 35 x 35 + x = self.Mixed_5b(x) + # N x 256 x 35 x 35 + x = self.Mixed_5c(x) + # N x 288 x 35 x 35 + x = self.Mixed_5d(x) + # N x 288 x 35 x 35 + x = self.Mixed_6a(x) + # N x 768 x 17 x 17 + x = self.Mixed_6b(x) + # N x 768 x 17 x 17 + x = self.Mixed_6c(x) + # N x 768 x 17 x 17 + x = self.Mixed_6d(x) + # N x 768 x 17 x 17 + x = self.Mixed_6e(x) + # N x 768 x 17 x 17 + if self.training and self.aux_logits: + aux = self.AuxLogits(x) + else: + aux = None + # N x 768 x 17 x 17 + x = self.Mixed_7a(x) + # N x 1280 x 8 x 8 + x = self.Mixed_7b(x) + # N x 2048 x 8 x 8 + x = self.Mixed_7c(x) + # N x 2048 x 8 x 8 + # Adaptive average pooling + x = F.adaptive_avg_pool2d(x, (1, 1)) + # N x 2048 x 1 x 1 + x = F.dropout(x, training=self.training) + # N x 2048 x 1 x 1 + x = torch.flatten(x, 1) + # N x 2048 + if output_logits or output_predictions: + x = self.fc(x) + # N x 1000 (num_classes) + if remove_logits_bias: + x = x - self.fc.bias.view(1, -1) + if output_predictions: + x = F.softmax(x, dim=1) + return x, aux + + def forward(self, + x, + transform_input=False, + output_logits=False, + remove_logits_bias=False, + output_predictions=False): + x = self._transform_input(x, transform_input) + x, aux = self._forward( + x, output_logits, remove_logits_bias, output_predictions) + if self.training and self.aux_logits: + return x, aux + else: + return x + + +class InceptionA(nn.Module): + + def __init__(self, in_channels, pool_features, conv_block=None, align_tf=False): + super(InceptionA, self).__init__() + if conv_block is None: + conv_block = BasicConv2d + self.branch1x1 = conv_block(in_channels, 64, kernel_size=1) + + self.branch5x5_1 = conv_block(in_channels, 48, kernel_size=1) + self.branch5x5_2 = conv_block(48, 64, kernel_size=5, padding=2) + + self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1) + self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1) + self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, padding=1) + + self.branch_pool = conv_block(in_channels, pool_features, kernel_size=1) + self.pool_include_padding = not align_tf + + def _forward(self, x): + branch1x1 = self.branch1x1(x) + + branch5x5 = self.branch5x5_1(x) + branch5x5 = self.branch5x5_2(branch5x5) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) + + branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, + count_include_pad=self.pool_include_padding) 
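+        # In the TF-aligned case, `count_include_pad=False` keeps padded zeros
+        # out of the average, matching the behavior of the TensorFlow model
+        # noted in the module docstring.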
+ branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] + return outputs + + def forward(self, x): + outputs = self._forward(x) + return torch.cat(outputs, 1) + + +class InceptionB(nn.Module): + + def __init__(self, in_channels, conv_block=None): + super(InceptionB, self).__init__() + if conv_block is None: + conv_block = BasicConv2d + self.branch3x3 = conv_block(in_channels, 384, kernel_size=3, stride=2) + + self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1) + self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1) + self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, stride=2) + + def _forward(self, x): + branch3x3 = self.branch3x3(x) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) + + branch_pool = F.max_pool2d(x, kernel_size=3, stride=2) + + outputs = [branch3x3, branch3x3dbl, branch_pool] + return outputs + + def forward(self, x): + outputs = self._forward(x) + return torch.cat(outputs, 1) + + +class InceptionC(nn.Module): + + def __init__(self, in_channels, channels_7x7, conv_block=None, align_tf=False): + super(InceptionC, self).__init__() + if conv_block is None: + conv_block = BasicConv2d + self.branch1x1 = conv_block(in_channels, 192, kernel_size=1) + + c7 = channels_7x7 + self.branch7x7_1 = conv_block(in_channels, c7, kernel_size=1) + self.branch7x7_2 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3)) + self.branch7x7_3 = conv_block(c7, 192, kernel_size=(7, 1), padding=(3, 0)) + + self.branch7x7dbl_1 = conv_block(in_channels, c7, kernel_size=1) + self.branch7x7dbl_2 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0)) + self.branch7x7dbl_3 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3)) + self.branch7x7dbl_4 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0)) + self.branch7x7dbl_5 = conv_block(c7, 192, kernel_size=(1, 7), padding=(0, 3)) + + self.branch_pool = conv_block(in_channels, 192, kernel_size=1) + self.pool_include_padding = not align_tf + + def _forward(self, x): + branch1x1 = self.branch1x1(x) + + branch7x7 = self.branch7x7_1(x) + branch7x7 = self.branch7x7_2(branch7x7) + branch7x7 = self.branch7x7_3(branch7x7) + + branch7x7dbl = self.branch7x7dbl_1(x) + branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) + + branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, + count_include_pad=self.pool_include_padding) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] + return outputs + + def forward(self, x): + outputs = self._forward(x) + return torch.cat(outputs, 1) + + +class InceptionD(nn.Module): + + def __init__(self, in_channels, conv_block=None): + super(InceptionD, self).__init__() + if conv_block is None: + conv_block = BasicConv2d + self.branch3x3_1 = conv_block(in_channels, 192, kernel_size=1) + self.branch3x3_2 = conv_block(192, 320, kernel_size=3, stride=2) + + self.branch7x7x3_1 = conv_block(in_channels, 192, kernel_size=1) + self.branch7x7x3_2 = conv_block(192, 192, kernel_size=(1, 7), padding=(0, 3)) + self.branch7x7x3_3 = conv_block(192, 192, kernel_size=(7, 1), padding=(3, 0)) + self.branch7x7x3_4 = conv_block(192, 192, kernel_size=3, stride=2) + + def _forward(self, x): + branch3x3 = self.branch3x3_1(x) + branch3x3 = 
self.branch3x3_2(branch3x3) + + branch7x7x3 = self.branch7x7x3_1(x) + branch7x7x3 = self.branch7x7x3_2(branch7x7x3) + branch7x7x3 = self.branch7x7x3_3(branch7x7x3) + branch7x7x3 = self.branch7x7x3_4(branch7x7x3) + + branch_pool = F.max_pool2d(x, kernel_size=3, stride=2) + outputs = [branch3x3, branch7x7x3, branch_pool] + return outputs + + def forward(self, x): + outputs = self._forward(x) + return torch.cat(outputs, 1) + + +class InceptionE(nn.Module): + + def __init__(self, in_channels, conv_block=None, align_tf=False, use_max_pool=False): + super(InceptionE, self).__init__() + if conv_block is None: + conv_block = BasicConv2d + self.branch1x1 = conv_block(in_channels, 320, kernel_size=1) + + self.branch3x3_1 = conv_block(in_channels, 384, kernel_size=1) + self.branch3x3_2a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1)) + self.branch3x3_2b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0)) + + self.branch3x3dbl_1 = conv_block(in_channels, 448, kernel_size=1) + self.branch3x3dbl_2 = conv_block(448, 384, kernel_size=3, padding=1) + self.branch3x3dbl_3a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1)) + self.branch3x3dbl_3b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0)) + + self.branch_pool = conv_block(in_channels, 192, kernel_size=1) + self.pool_include_padding = not align_tf + self.use_max_pool = use_max_pool + + def _forward(self, x): + branch1x1 = self.branch1x1(x) + + branch3x3 = self.branch3x3_1(x) + branch3x3 = [ + self.branch3x3_2a(branch3x3), + self.branch3x3_2b(branch3x3), + ] + branch3x3 = torch.cat(branch3x3, 1) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = [ + self.branch3x3dbl_3a(branch3x3dbl), + self.branch3x3dbl_3b(branch3x3dbl), + ] + branch3x3dbl = torch.cat(branch3x3dbl, 1) + + if self.use_max_pool: + branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1) + else: + branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, + count_include_pad=self.pool_include_padding) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] + return outputs + + def forward(self, x): + outputs = self._forward(x) + return torch.cat(outputs, 1) + + +class InceptionAux(nn.Module): + + def __init__(self, in_channels, num_classes, conv_block=None): + super(InceptionAux, self).__init__() + if conv_block is None: + conv_block = BasicConv2d + self.conv0 = conv_block(in_channels, 128, kernel_size=1) + self.conv1 = conv_block(128, 768, kernel_size=5) + self.conv1.stddev = 0.01 + self.fc = nn.Linear(768, num_classes) + self.fc.stddev = 0.001 + + def forward(self, x): + # N x 768 x 17 x 17 + x = F.avg_pool2d(x, kernel_size=5, stride=3) + # N x 768 x 5 x 5 + x = self.conv0(x) + # N x 128 x 5 x 5 + x = self.conv1(x) + # N x 768 x 1 x 1 + # Adaptive average pooling + x = F.adaptive_avg_pool2d(x, (1, 1)) + # N x 768 x 1 x 1 + x = torch.flatten(x, 1) + # N x 768 + x = self.fc(x) + # N x 1000 + return x + + +class BasicConv2d(nn.Module): + + def __init__(self, in_channels, out_channels, **kwargs): + super(BasicConv2d, self).__init__() + self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) + self.bn = nn.BatchNorm2d(out_channels, eps=0.001) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return F.relu(x, inplace=True) + +# pylint: enable=line-too-long +# pylint: enable=missing-function-docstring +# pylint: enable=missing-class-docstring +# pylint: enable=super-with-arguments +# pylint: 
enable=consider-merging-isinstance +# pylint: enable=import-outside-toplevel +# pylint: enable=no-else-return diff --git a/models/perceptual_model.py b/models/perceptual_model.py new file mode 100644 index 0000000000000000000000000000000000000000..7f0aaa82789f19e9f4760d3b42e00b44e3728ffa --- /dev/null +++ b/models/perceptual_model.py @@ -0,0 +1,519 @@ +# python3.7 +"""Contains the VGG16 model, which is used for inference ONLY. + +VGG16 is commonly used for perceptual feature extraction. The model implemented +in this file can be used for evaluation (like computing LPIPS, perceptual path +length, etc.), OR be used in training for loss computation (like perceptual +loss, etc.). + +The pre-trained model is officially shared by + +https://www.robots.ox.ac.uk/~vgg/research/very_deep/ + +and ported by + +https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt + +Compared to the official VGG16 model, this ported model also support evaluating +LPIPS, which is introduced in + +https://github.com/richzhang/PerceptualSimilarity +""" + +import warnings +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as dist + +from utils.misc import download_url + +__all__ = ['PerceptualModel'] + +# pylint: disable=line-too-long +_MODEL_URL_SHA256 = { + # This model is provided by `torchvision`, which is ported from TensorFlow. + 'torchvision_official': ( + 'https://download.pytorch.org/models/vgg16-397923af.pth', + '397923af8e79cdbb6a7127f12361acd7a2f83e06b05044ddf496e83de57a5bf0' # hash sha256 + ), + + # This model is provided by https://github.com/NVlabs/stylegan2-ada-pytorch + 'vgg_perceptual_lpips': ( + 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt', + 'b437eb095feaeb0b83eb3fa11200ebca4548ee39a07fb944a417ddc516cc07c3' # hash sha256 + ) +} +# pylint: enable=line-too-long + + +class PerceptualModel(object): + """Defines the perceptual model, which is based on VGG16 structure. + + This is a static class, which is used to avoid this model to be built + repeatedly. Consequently, this model is particularly used for inference, + like computing LPIPS, or for loss computation, like perceptual loss. If + training is required, please use the model from `torchvision.models` or + implement by yourself. + + NOTE: The pre-trained model assumes the inputs to be with `RGB` channel + order and pixel range [-1, 1], and will NOT resize the input automatically + if only perceptual feature is needed. + """ + models = dict() + + @staticmethod + def build_model(use_torchvision=False, no_top=True, enable_lpips=True): + """Builds the model and load pre-trained weights. + + 1. If `use_torchvision` is set as True, the model released by + `torchvision` will be loaded, otherwise, the model released by + https://www.robots.ox.ac.uk/~vgg/research/very_deep/ will be used. + (default: False) + + 2. To save computing resources, these is an option to only load the + backbone (i.e., without the last three fully-connected layers). This + is commonly used for perceptual loss or LPIPS loss computation. + Please use argument `no_top` to control this. (default: True) + + 3. For LPIPS loss computation, some additional weights (which is used + for balancing the features from different resolutions) are employed + on top of the original VGG16 backbone. Details can be found at + https://github.com/richzhang/PerceptualSimilarity. Please use + `enable_lpips` to enable this feature. 
(default: True) + + The built model supports following arguments when forwarding: + + - resize_input: Whether to resize the input image to size [224, 224] + before forwarding. For feature-based computation (i.e., only + convolutional layers are used), image resizing is not essential. + (default: False) + - return_tensor: This field resolves the model behavior. Following + options are supported: + `feature1`: Before the first max pooling layer. + `pool1`: After the first max pooling layer. + `feature2`: Before the second max pooling layer. + `pool2`: After the second max pooling layer. + `feature3`: Before the third max pooling layer. + `pool3`: After the third max pooling layer. + `feature4`: Before the fourth max pooling layer. + `pool4`: After the fourth max pooling layer. + `feature5`: Before the fifth max pooling layer. + `pool5`: After the fifth max pooling layer. + `flatten`: The flattened feature, after `adaptive_avgpool`. + `feature`: The 4096d feature for logits computation. (default) + `logits`: The 1000d categorical logits. + `prediction`: The 1000d predicted probability. + `lpips`: The LPIPS score between two input images. + """ + if use_torchvision: + model_source = 'torchvision_official' + align_tf_resize = False + is_torch_script = False + else: + model_source = 'vgg_perceptual_lpips' + align_tf_resize = True + is_torch_script = True + + if enable_lpips and model_source != 'vgg_perceptual_lpips': + warnings.warn('The pre-trained model officially released by ' + '`torchvision` does not support LPIPS computation! ' + 'Equal weights will be used for each resolution.') + + fingerprint = (model_source, no_top, enable_lpips) + + if fingerprint not in PerceptualModel.models: + # Build model. + model = VGG16(align_tf_resize=align_tf_resize, + no_top=no_top, + enable_lpips=enable_lpips) + + # Download pre-trained weights. + if dist.is_initialized() and dist.get_rank() != 0: + dist.barrier() # Download by chief. + + url, sha256 = _MODEL_URL_SHA256[model_source] + filename = f'perceptual_model_{model_source}_{sha256}.pth' + model_path, hash_check = download_url(url, + filename=filename, + sha256=sha256) + if is_torch_script: + src_state_dict = torch.jit.load(model_path, map_location='cpu') + else: + src_state_dict = torch.load(model_path, map_location='cpu') + if hash_check is False: + warnings.warn(f'Hash check failed! The remote file from URL ' + f'`{url}` may be changed, or the downloading is ' + f'interrupted. The loaded perceptual model may ' + f'have unexpected behavior.') + + if dist.is_initialized() and dist.get_rank() == 0: + dist.barrier() # Wait for other replicas. + + # Load weights. + dst_state_dict = _convert_weights(src_state_dict, model_source) + model.load_state_dict(dst_state_dict, strict=False) + del src_state_dict, dst_state_dict + + # For inference only. 
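+            # The cached instance stays frozen on GPU. A typical LPIPS call
+            # (sketch, names illustrative) is
+            # `d = PerceptualModel.build_model()(img1, img2,
+            #                                    return_tensor='lpips')`
+            # with both images in [-1, 1] and of identical shape.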
+ model.eval().requires_grad_(False).cuda() + PerceptualModel.models[fingerprint] = model + + return PerceptualModel.models[fingerprint] + + +def _convert_weights(src_state_dict, model_source): + if model_source not in _MODEL_URL_SHA256: + raise ValueError(f'Invalid model source `{model_source}`!\n' + f'Sources allowed: {list(_MODEL_URL_SHA256.keys())}.') + if model_source == 'torchvision_official': + dst_to_src_var_mapping = { + 'conv11.weight': 'features.0.weight', + 'conv11.bias': 'features.0.bias', + 'conv12.weight': 'features.2.weight', + 'conv12.bias': 'features.2.bias', + 'conv21.weight': 'features.5.weight', + 'conv21.bias': 'features.5.bias', + 'conv22.weight': 'features.7.weight', + 'conv22.bias': 'features.7.bias', + 'conv31.weight': 'features.10.weight', + 'conv31.bias': 'features.10.bias', + 'conv32.weight': 'features.12.weight', + 'conv32.bias': 'features.12.bias', + 'conv33.weight': 'features.14.weight', + 'conv33.bias': 'features.14.bias', + 'conv41.weight': 'features.17.weight', + 'conv41.bias': 'features.17.bias', + 'conv42.weight': 'features.19.weight', + 'conv42.bias': 'features.19.bias', + 'conv43.weight': 'features.21.weight', + 'conv43.bias': 'features.21.bias', + 'conv51.weight': 'features.24.weight', + 'conv51.bias': 'features.24.bias', + 'conv52.weight': 'features.26.weight', + 'conv52.bias': 'features.26.bias', + 'conv53.weight': 'features.28.weight', + 'conv53.bias': 'features.28.bias', + 'fc1.weight': 'classifier.0.weight', + 'fc1.bias': 'classifier.0.bias', + 'fc2.weight': 'classifier.3.weight', + 'fc2.bias': 'classifier.3.bias', + 'fc3.weight': 'classifier.6.weight', + 'fc3.bias': 'classifier.6.bias', + } + elif model_source == 'vgg_perceptual_lpips': + src_state_dict = src_state_dict.state_dict() + dst_to_src_var_mapping = { + 'conv11.weight': 'layers.conv1.weight', + 'conv11.bias': 'layers.conv1.bias', + 'conv12.weight': 'layers.conv2.weight', + 'conv12.bias': 'layers.conv2.bias', + 'conv21.weight': 'layers.conv3.weight', + 'conv21.bias': 'layers.conv3.bias', + 'conv22.weight': 'layers.conv4.weight', + 'conv22.bias': 'layers.conv4.bias', + 'conv31.weight': 'layers.conv5.weight', + 'conv31.bias': 'layers.conv5.bias', + 'conv32.weight': 'layers.conv6.weight', + 'conv32.bias': 'layers.conv6.bias', + 'conv33.weight': 'layers.conv7.weight', + 'conv33.bias': 'layers.conv7.bias', + 'conv41.weight': 'layers.conv8.weight', + 'conv41.bias': 'layers.conv8.bias', + 'conv42.weight': 'layers.conv9.weight', + 'conv42.bias': 'layers.conv9.bias', + 'conv43.weight': 'layers.conv10.weight', + 'conv43.bias': 'layers.conv10.bias', + 'conv51.weight': 'layers.conv11.weight', + 'conv51.bias': 'layers.conv11.bias', + 'conv52.weight': 'layers.conv12.weight', + 'conv52.bias': 'layers.conv12.bias', + 'conv53.weight': 'layers.conv13.weight', + 'conv53.bias': 'layers.conv13.bias', + 'fc1.weight': 'layers.fc1.weight', + 'fc1.bias': 'layers.fc1.bias', + 'fc2.weight': 'layers.fc2.weight', + 'fc2.bias': 'layers.fc2.bias', + 'fc3.weight': 'layers.fc3.weight', + 'fc3.bias': 'layers.fc3.bias', + 'lpips.0.weight': 'lpips0', + 'lpips.1.weight': 'lpips1', + 'lpips.2.weight': 'lpips2', + 'lpips.3.weight': 'lpips3', + 'lpips.4.weight': 'lpips4', + } + else: + raise NotImplementedError(f'Not implemented model source ' + f'`{model_source}`!') + + dst_state_dict = {} + for dst_name, src_name in dst_to_src_var_mapping.items(): + if dst_name.startswith('lpips'): + dst_state_dict[dst_name] = src_state_dict[src_name].unsqueeze(0) + else: + dst_state_dict[dst_name] = src_state_dict[src_name].clone() + 
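+    # The LPIPS balancing weights are apparently stored without the leading
+    # output-channel axis in the source checkpoint, hence the `unsqueeze(0)`
+    # above to recover [1, C, 1, 1] convolution kernels.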
return dst_state_dict + + +_IMG_MEAN = (0.485, 0.456, 0.406) +_IMG_STD = (0.229, 0.224, 0.225) +_ALLOWED_RETURN = [ + 'feature1', 'pool1', 'feature2', 'pool2', 'feature3', 'pool3', 'feature4', + 'pool4', 'feature5', 'pool5', 'flatten', 'feature', 'logits', 'prediction', + 'lpips' +] + +# pylint: disable=missing-function-docstring + +class VGG16(nn.Module): + """Defines the VGG16 structure. + + This model takes `RGB` images with data format `NCHW` as the raw inputs. The + pixel range are assumed to be [-1, 1]. + """ + + def __init__(self, align_tf_resize=False, no_top=True, enable_lpips=True): + """Defines the network structure.""" + super().__init__() + + self.align_tf_resize = align_tf_resize + self.no_top = no_top + self.enable_lpips = enable_lpips + + self.conv11 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1) + self.relu11 = nn.ReLU(inplace=True) + self.conv12 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) + self.relu12 = nn.ReLU(inplace=True) + # output `feature1`, with shape [N, 64, 224, 224] + + self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) + # output `pool1`, with shape [N, 64, 112, 112] + + self.conv21 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1) + self.relu21 = nn.ReLU(inplace=True) + self.conv22 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1) + self.relu22 = nn.ReLU(inplace=True) + # output `feature2`, with shape [N, 128, 112, 112] + + self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) + # output `pool2`, with shape [N, 128, 56, 56] + + self.conv31 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1) + self.relu31 = nn.ReLU(inplace=True) + self.conv32 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) + self.relu32 = nn.ReLU(inplace=True) + self.conv33 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) + self.relu33 = nn.ReLU(inplace=True) + # output `feature3`, with shape [N, 256, 56, 56] + + self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2) + # output `pool3`, with shape [N,256, 28, 28] + + self.conv41 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1) + self.relu41 = nn.ReLU(inplace=True) + self.conv42 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + self.relu42 = nn.ReLU(inplace=True) + self.conv43 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + self.relu43 = nn.ReLU(inplace=True) + # output `feature4`, with shape [N, 512, 28, 28] + + self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2) + # output `pool4`, with shape [N, 512, 14, 14] + + self.conv51 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + self.relu51 = nn.ReLU(inplace=True) + self.conv52 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + self.relu52 = nn.ReLU(inplace=True) + self.conv53 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + self.relu53 = nn.ReLU(inplace=True) + # output `feature5`, with shape [N, 512, 14, 14] + + self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2) + # output `pool5`, with shape [N, 512, 7, 7] + + if self.enable_lpips: + self.lpips = nn.ModuleList() + for idx, ch in enumerate([64, 128, 256, 512, 512]): + self.lpips.append(nn.Conv2d(ch, 1, kernel_size=1, bias=False)) + self.lpips[idx].weight.data.copy_(torch.ones(1, ch, 1, 1)) + + if not self.no_top: + self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) + self.flatten = nn.Flatten(start_dim=1, end_dim=-1) + # output `flatten`, with shape [N, 25088] + + self.fc1 = nn.Linear(512 * 7 * 7, 4096) + self.fc1_relu = nn.ReLU(inplace=True) + self.fc1_dropout = nn.Dropout(0.5, inplace=False) + self.fc2 = nn.Linear(4096, 4096) + 
self.fc2_relu = nn.ReLU(inplace=True) + self.fc2_dropout = nn.Dropout(0.5, inplace=False) + # output `feature`, with shape [N, 4096] + + self.fc3 = nn.Linear(4096, 1000) + # output `logits`, with shape [N, 1000] + + self.out = nn.Softmax(dim=1) + # output `softmax`, with shape [N, 1000] + + img_mean = np.array(_IMG_MEAN).reshape((1, 3, 1, 1)).astype(np.float32) + img_std = np.array(_IMG_STD).reshape((1, 3, 1, 1)).astype(np.float32) + self.register_buffer('img_mean', torch.from_numpy(img_mean)) + self.register_buffer('img_std', torch.from_numpy(img_std)) + + def forward(self, + x, + y=None, + *, + resize_input=False, + return_tensor='feature'): + return_tensor = return_tensor.lower() + if return_tensor not in _ALLOWED_RETURN: + raise ValueError(f'Invalid output tensor name `{return_tensor}` ' + f'for perceptual model (VGG16)!\n' + f'Names allowed: {_ALLOWED_RETURN}.') + + if return_tensor == 'lpips' and y is None: + raise ValueError('Two images are required for LPIPS computation, ' + 'but only one is received!') + + if return_tensor == 'lpips': + assert x.shape == y.shape + x = torch.cat([x, y], dim=0) + features = [] + + if resize_input: + if self.align_tf_resize: + theta = torch.eye(2, 3).to(x) + theta[0, 2] += theta[0, 0] / x.shape[3] - theta[0, 0] / 224 + theta[1, 2] += theta[1, 1] / x.shape[2] - theta[1, 1] / 224 + theta = theta.unsqueeze(0).repeat(x.shape[0], 1, 1) + grid = F.affine_grid(theta, + size=(x.shape[0], x.shape[1], 224, 224), + align_corners=False) + x = F.grid_sample(x, grid, + mode='bilinear', + padding_mode='border', + align_corners=False) + else: + x = F.interpolate(x, + size=(224, 224), + mode='bilinear', + align_corners=False) + if x.shape[1] == 1: + x = x.repeat((1, 3, 1, 1)) + + x = (x + 1) / 2 + x = (x - self.img_mean) / self.img_std + + x = self.conv11(x) + x = self.relu11(x) + x = self.conv12(x) + x = self.relu12(x) + if return_tensor == 'feature1': + return x + if return_tensor == 'lpips': + features.append(x) + + x = self.pool1(x) + if return_tensor == 'pool1': + return x + + x = self.conv21(x) + x = self.relu21(x) + x = self.conv22(x) + x = self.relu22(x) + if return_tensor == 'feature2': + return x + if return_tensor == 'lpips': + features.append(x) + + x = self.pool2(x) + if return_tensor == 'pool2': + return x + + x = self.conv31(x) + x = self.relu31(x) + x = self.conv32(x) + x = self.relu32(x) + x = self.conv33(x) + x = self.relu33(x) + if return_tensor == 'feature3': + return x + if return_tensor == 'lpips': + features.append(x) + + x = self.pool3(x) + if return_tensor == 'pool3': + return x + + x = self.conv41(x) + x = self.relu41(x) + x = self.conv42(x) + x = self.relu42(x) + x = self.conv43(x) + x = self.relu43(x) + if return_tensor == 'feature4': + return x + if return_tensor == 'lpips': + features.append(x) + + x = self.pool4(x) + if return_tensor == 'pool4': + return x + + x = self.conv51(x) + x = self.relu51(x) + x = self.conv52(x) + x = self.relu52(x) + x = self.conv53(x) + x = self.relu53(x) + if return_tensor == 'feature5': + return x + if return_tensor == 'lpips': + features.append(x) + + x = self.pool5(x) + if return_tensor == 'pool5': + return x + + if return_tensor == 'lpips': + score = 0 + assert len(features) == 5 + for idx in range(5): + feature = features[idx] + norm = feature.norm(dim=1, keepdim=True) + feature = feature / (norm + 1e-10) + feature_x, feature_y = feature.chunk(2, dim=0) + diff = (feature_x - feature_y).square() + score += self.lpips[idx](diff).mean(dim=(2, 3), keepdim=False) + return score.sum(dim=1, keepdim=False) + + x 
= self.avgpool(x) + x = self.flatten(x) + if return_tensor == 'flatten': + return x + + x = self.fc1(x) + x = self.fc1_relu(x) + x = self.fc1_dropout(x) + x = self.fc2(x) + x = self.fc2_relu(x) + x = self.fc2_dropout(x) + if return_tensor == 'feature': + return x + + x = self.fc3(x) + if return_tensor == 'logits': + return x + + x = self.out(x) + if return_tensor == 'prediction': + return x + + raise NotImplementedError(f'Output tensor name `{return_tensor}` is ' + f'not implemented!') + +# pylint: enable=missing-function-docstring diff --git a/models/pggan_discriminator.py b/models/pggan_discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..30b0868dd6a753ba7f2712c10b4f19708b67eee3 --- /dev/null +++ b/models/pggan_discriminator.py @@ -0,0 +1,465 @@ +# python3.7 +"""Contains the implementation of discriminator described in PGGAN. + +Paper: https://arxiv.org/pdf/1710.10196.pdf + +Official TensorFlow implementation: +https://github.com/tkarras/progressive_growing_of_gans +""" + +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F + +__all__ = ['PGGANDiscriminator'] + +# Resolutions allowed. +_RESOLUTIONS_ALLOWED = [8, 16, 32, 64, 128, 256, 512, 1024] + +# Default gain factor for weight scaling. +_WSCALE_GAIN = np.sqrt(2.0) + +# pylint: disable=missing-function-docstring + +class PGGANDiscriminator(nn.Module): + """Defines the discriminator network in PGGAN. + + NOTE: The discriminator takes images with `RGB` channel order and pixel + range [-1, 1] as inputs. + + Settings for the network: + + (1) resolution: The resolution of the input image. + (2) init_res: Smallest resolution of the convolutional backbone. + (default: 4) + (3) image_channels: Number of channels of the input image. (default: 3) + (4) label_dim: Dimension of the additional label for conditional generation. + In one-hot conditioning case, it is equal to the number of classes. If + set to 0, conditioning training will be disabled. (default: 0) + (5) fused_scale: Whether to fused `conv2d` and `downsample` together, + resulting in `conv2d` with strides. (default: False) + (6) use_wscale: Whether to use weight scaling. (default: True) + (7) wscale_gain: The factor to control weight scaling. (default: sqrt(2.0)) + (8) mbstd_groups: Group size for the minibatch standard deviation layer. + `0` means disable. (default: 16) + (9) fmaps_base: Factor to control number of feature maps for each layer. + (default: 16 << 10) + (10) fmaps_max: Maximum number of feature maps in each layer. (default: 512) + (11) eps: A small value to avoid divide overflow. (default: 1e-8) + """ + + def __init__(self, + resolution, + init_res=4, + image_channels=3, + label_dim=0, + fused_scale=False, + use_wscale=True, + wscale_gain=np.sqrt(2.0), + mbstd_groups=16, + fmaps_base=16 << 10, + fmaps_max=512, + eps=1e-8): + """Initializes with basic settings. + + Raises: + ValueError: If the `resolution` is not supported. 
+ """ + super().__init__() + + if resolution not in _RESOLUTIONS_ALLOWED: + raise ValueError(f'Invalid resolution: `{resolution}`!\n' + f'Resolutions allowed: {_RESOLUTIONS_ALLOWED}.') + + self.init_res = init_res + self.init_res_log2 = int(np.log2(self.init_res)) + self.resolution = resolution + self.final_res_log2 = int(np.log2(self.resolution)) + self.image_channels = image_channels + self.label_dim = label_dim + self.fused_scale = fused_scale + self.use_wscale = use_wscale + self.wscale_gain = wscale_gain + self.mbstd_groups = mbstd_groups + self.fmaps_base = fmaps_base + self.fmaps_max = fmaps_max + self.eps = eps + + # Level-of-details (used for progressive training). + self.register_buffer('lod', torch.zeros(())) + self.pth_to_tf_var_mapping = {'lod': 'lod'} + + for res_log2 in range(self.final_res_log2, self.init_res_log2 - 1, -1): + res = 2 ** res_log2 + in_channels = self.get_nf(res) + out_channels = self.get_nf(res // 2) + block_idx = self.final_res_log2 - res_log2 + + # Input convolution layer for each resolution. + self.add_module( + f'input{block_idx}', + ConvLayer(in_channels=self.image_channels, + out_channels=in_channels, + kernel_size=1, + add_bias=True, + downsample=False, + fused_scale=False, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + activation_type='lrelu')) + self.pth_to_tf_var_mapping[f'input{block_idx}.weight'] = ( + f'FromRGB_lod{block_idx}/weight') + self.pth_to_tf_var_mapping[f'input{block_idx}.bias'] = ( + f'FromRGB_lod{block_idx}/bias') + + # Convolution block for each resolution (except the last one). + if res != self.init_res: + self.add_module( + f'layer{2 * block_idx}', + ConvLayer(in_channels=in_channels, + out_channels=in_channels, + kernel_size=3, + add_bias=True, + downsample=False, + fused_scale=False, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + activation_type='lrelu')) + tf_layer0_name = 'Conv0' + self.add_module( + f'layer{2 * block_idx + 1}', + ConvLayer(in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + add_bias=True, + downsample=True, + fused_scale=fused_scale, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + activation_type='lrelu')) + tf_layer1_name = 'Conv1_down' if fused_scale else 'Conv1' + + # Convolution block for last resolution. + else: + self.mbstd = MiniBatchSTDLayer(groups=mbstd_groups, eps=eps) + self.add_module( + f'layer{2 * block_idx}', + ConvLayer( + in_channels=in_channels + 1, + out_channels=in_channels, + kernel_size=3, + add_bias=True, + downsample=False, + fused_scale=False, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + activation_type='lrelu')) + tf_layer0_name = 'Conv' + self.add_module( + f'layer{2 * block_idx + 1}', + DenseLayer(in_channels=in_channels * res * res, + out_channels=out_channels, + add_bias=True, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + activation_type='lrelu')) + tf_layer1_name = 'Dense0' + + self.pth_to_tf_var_mapping[f'layer{2 * block_idx}.weight'] = ( + f'{res}x{res}/{tf_layer0_name}/weight') + self.pth_to_tf_var_mapping[f'layer{2 * block_idx}.bias'] = ( + f'{res}x{res}/{tf_layer0_name}/bias') + self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 1}.weight'] = ( + f'{res}x{res}/{tf_layer1_name}/weight') + self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 1}.bias'] = ( + f'{res}x{res}/{tf_layer1_name}/bias') + + # Final dense layer. 
+ self.output = DenseLayer(in_channels=out_channels, + out_channels=1 + self.label_dim, + add_bias=True, + use_wscale=self.use_wscale, + wscale_gain=1.0, + activation_type='linear') + self.pth_to_tf_var_mapping['output.weight'] = ( + f'{res}x{res}/Dense1/weight') + self.pth_to_tf_var_mapping['output.bias'] = ( + f'{res}x{res}/Dense1/bias') + + def get_nf(self, res): + """Gets number of feature maps according to the given resolution.""" + return min(self.fmaps_base // res, self.fmaps_max) + + def forward(self, image, lod=None): + expected_shape = (self.image_channels, self.resolution, self.resolution) + if image.ndim != 4 or image.shape[1:] != expected_shape: + raise ValueError(f'The input tensor should be with shape ' + f'[batch_size, channel, height, width], where ' + f'`channel` equals to {self.image_channels}, ' + f'`height`, `width` equal to {self.resolution}!\n' + f'But `{image.shape}` is received!') + + lod = self.lod.item() if lod is None else lod + if lod + self.init_res_log2 > self.final_res_log2: + raise ValueError(f'Maximum level-of-details (lod) is ' + f'{self.final_res_log2 - self.init_res_log2}, ' + f'but `{lod}` is received!') + + lod = self.lod.item() + for res_log2 in range(self.final_res_log2, self.init_res_log2 - 1, -1): + block_idx = current_lod = self.final_res_log2 - res_log2 + if current_lod <= lod < current_lod + 1: + x = getattr(self, f'input{block_idx}')(image) + elif current_lod - 1 < lod < current_lod: + alpha = lod - np.floor(lod) + y = getattr(self, f'input{block_idx}')(image) + x = y * alpha + x * (1 - alpha) + if lod < current_lod + 1: + if res_log2 == self.init_res_log2: + x = self.mbstd(x) + x = getattr(self, f'layer{2 * block_idx}')(x) + x = getattr(self, f'layer{2 * block_idx + 1}')(x) + if lod > current_lod: + image = F.avg_pool2d( + image, kernel_size=2, stride=2, padding=0) + x = self.output(x) + + return {'score': x} + + +class MiniBatchSTDLayer(nn.Module): + """Implements the minibatch standard deviation layer.""" + + def __init__(self, groups, eps): + super().__init__() + self.groups = groups + self.eps = eps + + def extra_repr(self): + return f'groups={self.groups}, epsilon={self.eps}' + + def forward(self, x): + if self.groups <= 1: + return x + + N, C, H, W = x.shape + G = min(self.groups, N) # Number of groups. + + y = x.reshape(G, -1, C, H, W) # [GnCHW] + y = y - y.mean(dim=0) # [GnCHW] + y = y.square().mean(dim=0) # [nCHW] + y = (y + self.eps).sqrt() # [nCHW] + y = y.mean(dim=(1, 2, 3), keepdim=True) # [n111] + y = y.repeat(G, 1, H, W) # [N1HW] + x = torch.cat([x, y], dim=1) # [N(C+1)HW] + + return x + + +class DownsamplingLayer(nn.Module): + """Implements the downsampling layer. + + Basically, this layer can be used to downsample feature maps with average + pooling. + """ + + def __init__(self, scale_factor): + super().__init__() + self.scale_factor = scale_factor + + def extra_repr(self): + return f'factor={self.scale_factor}' + + def forward(self, x): + if self.scale_factor <= 1: + return x + return F.avg_pool2d(x, + kernel_size=self.scale_factor, + stride=self.scale_factor, + padding=0) + + +class ConvLayer(nn.Module): + """Implements the convolutional layer. + + Basically, this layer executes convolution, activation, and downsampling (if + needed) in sequence. + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + add_bias, + downsample, + fused_scale, + use_wscale, + wscale_gain, + activation_type): + """Initializes with layer settings. + + Args: + in_channels: Number of channels of the input tensor. 
+ out_channels: Number of channels of the output tensor. + kernel_size: Size of the convolutional kernels. + add_bias: Whether to add bias onto the convolutional result. + downsample: Whether to downsample the result after convolution. + fused_scale: Whether to fused `conv2d` and `downsample` together, + resulting in `conv2d` with strides. + use_wscale: Whether to use weight scaling. + wscale_gain: Gain factor for weight scaling. + activation_type: Type of activation. + """ + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.add_bias = add_bias + self.downsample = downsample + self.fused_scale = fused_scale + self.use_wscale = use_wscale + self.wscale_gain = wscale_gain + self.activation_type = activation_type + + if downsample and not fused_scale: + self.down = DownsamplingLayer(scale_factor=2) + else: + self.down = nn.Identity() + + if downsample and fused_scale: + self.use_stride = True + self.stride = 2 + self.padding = 1 + else: + self.use_stride = False + self.stride = 1 + self.padding = kernel_size // 2 + + weight_shape = (out_channels, in_channels, kernel_size, kernel_size) + fan_in = kernel_size * kernel_size * in_channels + wscale = wscale_gain / np.sqrt(fan_in) + if use_wscale: + self.weight = nn.Parameter(torch.randn(*weight_shape)) + self.wscale = wscale + else: + self.weight = nn.Parameter(torch.randn(*weight_shape) * wscale) + self.wscale = 1.0 + + if add_bias: + self.bias = nn.Parameter(torch.zeros(out_channels)) + else: + self.bias = None + + assert activation_type in ['linear', 'relu', 'lrelu'] + + def extra_repr(self): + return (f'in_ch={self.in_channels}, ' + f'out_ch={self.out_channels}, ' + f'ksize={self.kernel_size}, ' + f'wscale_gain={self.wscale_gain:.3f}, ' + f'bias={self.add_bias}, ' + f'downsample={self.scale_factor}, ' + f'fused_scale={self.fused_scale}, ' + f'act={self.activation_type}') + + def forward(self, x): + weight = self.weight + if self.wscale != 1.0: + weight = weight * self.wscale + + if self.use_stride: + weight = F.pad(weight, (1, 1, 1, 1, 0, 0, 0, 0), 'constant', 0.0) + weight = (weight[:, :, 1:, 1:] + weight[:, :, :-1, 1:] + + weight[:, :, 1:, :-1] + weight[:, :, :-1, :-1]) * 0.25 + x = F.conv2d(x, + weight=weight, + bias=self.bias, + stride=self.stride, + padding=self.padding) + + if self.activation_type == 'linear': + pass + elif self.activation_type == 'relu': + x = F.relu(x, inplace=True) + elif self.activation_type == 'lrelu': + x = F.leaky_relu(x, negative_slope=0.2, inplace=True) + else: + raise NotImplementedError(f'Not implemented activation type ' + f'`{self.activation_type}`!') + x = self.down(x) + + return x + + +class DenseLayer(nn.Module): + """Implements the dense layer.""" + + def __init__(self, + in_channels, + out_channels, + add_bias, + use_wscale, + wscale_gain, + activation_type): + """Initializes with layer settings. + + Args: + in_channels: Number of channels of the input tensor. + out_channels: Number of channels of the output tensor. + add_bias: Whether to add bias onto the fully-connected result. + use_wscale: Whether to use weight scaling. + wscale_gain: Gain factor for weight scaling. + activation_type: Type of activation. + + Raises: + NotImplementedError: If the `activation_type` is not supported. 
+ """ + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.add_bias = add_bias + self.use_wscale = use_wscale + self.wscale_gain = wscale_gain + self.activation_type = activation_type + + weight_shape = (out_channels, in_channels) + wscale = wscale_gain / np.sqrt(in_channels) + if use_wscale: + self.weight = nn.Parameter(torch.randn(*weight_shape)) + self.wscale = wscale + else: + self.weight = nn.Parameter(torch.randn(*weight_shape) * wscale) + self.wscale = 1.0 + + if add_bias: + self.bias = nn.Parameter(torch.zeros(out_channels)) + else: + self.bias = None + + assert activation_type in ['linear', 'relu', 'lrelu'] + + def forward(self, x): + if x.ndim != 2: + x = x.flatten(start_dim=1) + + weight = self.weight + if self.wscale != 1.0: + weight = weight * self.wscale + + x = F.linear(x, weight=weight, bias=self.bias) + + if self.activation_type == 'linear': + pass + elif self.activation_type == 'relu': + x = F.relu(x, inplace=True) + elif self.activation_type == 'lrelu': + x = F.leaky_relu(x, negative_slope=0.2, inplace=True) + else: + raise NotImplementedError(f'Not implemented activation type ' + f'`{self.activation_type}`!') + + return x + +# pylint: enable=missing-function-docstring diff --git a/models/pggan_generator.py b/models/pggan_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..771c5ee0e66304fcd21432a8d873e27b92ca1db5 --- /dev/null +++ b/models/pggan_generator.py @@ -0,0 +1,401 @@ +# python3.7 +"""Contains the implementation of generator described in PGGAN. + +Paper: https://arxiv.org/pdf/1710.10196.pdf + +Official TensorFlow implementation: +https://github.com/tkarras/progressive_growing_of_gans +""" + +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F + +__all__ = ['PGGANGenerator'] + +# Resolutions allowed. +_RESOLUTIONS_ALLOWED = [8, 16, 32, 64, 128, 256, 512, 1024] + +# pylint: disable=missing-function-docstring + +class PGGANGenerator(nn.Module): + """Defines the generator network in PGGAN. + + NOTE: The synthesized images are with `RGB` channel order and pixel range + [-1, 1]. + + Settings for the network: + + (1) resolution: The resolution of the output image. + (2) init_res: The initial resolution to start with convolution. (default: 4) + (3) z_dim: Dimension of the input latent space, Z. (default: 512) + (4) image_channels: Number of channels of the output image. (default: 3) + (5) final_tanh: Whether to use `tanh` to control the final pixel range. + (default: False) + (6) label_dim: Dimension of the additional label for conditional generation. + In one-hot conditioning case, it is equal to the number of classes. If + set to 0, conditioning training will be disabled. (default: 0) + (7) fused_scale: Whether to fused `upsample` and `conv2d` together, + resulting in `conv2d_transpose`. (default: False) + (8) use_wscale: Whether to use weight scaling. (default: True) + (9) wscale_gain: The factor to control weight scaling. (default: sqrt(2.0)) + (10) fmaps_base: Factor to control number of feature maps for each layer. + (default: 16 << 10) + (11) fmaps_max: Maximum number of feature maps in each layer. (default: 512) + (12) eps: A small value to avoid divide overflow. 
(default: 1e-8) + """ + + def __init__(self, + resolution, + init_res=4, + z_dim=512, + image_channels=3, + final_tanh=False, + label_dim=0, + fused_scale=False, + use_wscale=True, + wscale_gain=np.sqrt(2.0), + fmaps_base=16 << 10, + fmaps_max=512, + eps=1e-8): + """Initializes with basic settings. + + Raises: + ValueError: If the `resolution` is not supported. + """ + super().__init__() + + if resolution not in _RESOLUTIONS_ALLOWED: + raise ValueError(f'Invalid resolution: `{resolution}`!\n' + f'Resolutions allowed: {_RESOLUTIONS_ALLOWED}.') + + self.init_res = init_res + self.init_res_log2 = int(np.log2(self.init_res)) + self.resolution = resolution + self.final_res_log2 = int(np.log2(self.resolution)) + self.z_dim = z_dim + self.image_channels = image_channels + self.final_tanh = final_tanh + self.label_dim = label_dim + self.fused_scale = fused_scale + self.use_wscale = use_wscale + self.wscale_gain = wscale_gain + self.fmaps_base = fmaps_base + self.fmaps_max = fmaps_max + self.eps = eps + + # Dimension of latent space, which is convenient for sampling. + self.latent_dim = (self.z_dim,) + + # Number of convolutional layers. + self.num_layers = (self.final_res_log2 - self.init_res_log2 + 1) * 2 + + # Level-of-details (used for progressive training). + self.register_buffer('lod', torch.zeros(())) + self.pth_to_tf_var_mapping = {'lod': 'lod'} + + for res_log2 in range(self.init_res_log2, self.final_res_log2 + 1): + res = 2 ** res_log2 + in_channels = self.get_nf(res // 2) + out_channels = self.get_nf(res) + block_idx = res_log2 - self.init_res_log2 + + # First convolution layer for each resolution. + if res == self.init_res: + self.add_module( + f'layer{2 * block_idx}', + ConvLayer(in_channels=z_dim + label_dim, + out_channels=out_channels, + kernel_size=init_res, + padding=init_res - 1, + add_bias=True, + upsample=False, + fused_scale=False, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + activation_type='lrelu', + eps=eps)) + tf_layer_name = 'Dense' + else: + self.add_module( + f'layer{2 * block_idx}', + ConvLayer(in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + padding=1, + add_bias=True, + upsample=True, + fused_scale=fused_scale, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + activation_type='lrelu', + eps=eps)) + tf_layer_name = 'Conv0_up' if fused_scale else 'Conv0' + self.pth_to_tf_var_mapping[f'layer{2 * block_idx}.weight'] = ( + f'{res}x{res}/{tf_layer_name}/weight') + self.pth_to_tf_var_mapping[f'layer{2 * block_idx}.bias'] = ( + f'{res}x{res}/{tf_layer_name}/bias') + + # Second convolution layer for each resolution. + self.add_module( + f'layer{2 * block_idx + 1}', + ConvLayer(in_channels=out_channels, + out_channels=out_channels, + kernel_size=3, + padding=1, + add_bias=True, + upsample=False, + fused_scale=False, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + activation_type='lrelu', + eps=eps)) + tf_layer_name = 'Conv' if res == self.init_res else 'Conv1' + self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 1}.weight'] = ( + f'{res}x{res}/{tf_layer_name}/weight') + self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 1}.bias'] = ( + f'{res}x{res}/{tf_layer_name}/bias') + + # Output convolution layer for each resolution. 
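+ # Each resolution also gets a 1x1 `ToRGB` convolution, so intermediate
+ # feature maps can be converted to images and blended across levels of
+ # detail during progressive growing.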
+ self.add_module( + f'output{block_idx}', + ConvLayer(in_channels=out_channels, + out_channels=image_channels, + kernel_size=1, + padding=0, + add_bias=True, + upsample=False, + fused_scale=False, + use_wscale=use_wscale, + wscale_gain=1.0, + activation_type='linear', + eps=eps)) + self.pth_to_tf_var_mapping[f'output{block_idx}.weight'] = ( + f'ToRGB_lod{self.final_res_log2 - res_log2}/weight') + self.pth_to_tf_var_mapping[f'output{block_idx}.bias'] = ( + f'ToRGB_lod{self.final_res_log2 - res_log2}/bias') + + def get_nf(self, res): + """Gets number of feature maps according to the given resolution.""" + return min(self.fmaps_base // res, self.fmaps_max) + + def forward(self, z, label=None, lod=None): + if z.ndim != 2 or z.shape[1] != self.z_dim: + raise ValueError(f'Input latent code should be with shape ' + f'[batch_size, latent_dim], where ' + f'`latent_dim` equals to {self.z_dim}!\n' + f'But `{z.shape}` is received!') + z = self.layer0.pixel_norm(z) + if self.label_dim: + if label is None: + raise ValueError(f'Model requires an additional label ' + f'(with size {self.label_dim}) as input, ' + f'but no label is received!') + if label.ndim != 2 or label.shape != (z.shape[0], self.label_dim): + raise ValueError(f'Input label should be with shape ' + f'[batch_size, label_dim], where ' + f'`batch_size` equals to that of ' + f'latent codes ({z.shape[0]}) and ' + f'`label_dim` equals to {self.label_dim}!\n' + f'But `{label.shape}` is received!') + label = label.to(dtype=torch.float32) + z = torch.cat((z, label), dim=1) + + lod = self.lod.item() if lod is None else lod + if lod + self.init_res_log2 > self.final_res_log2: + raise ValueError(f'Maximum level-of-details (lod) is ' + f'{self.final_res_log2 - self.init_res_log2}, ' + f'but `{lod}` is received!') + + x = z.view(z.shape[0], self.z_dim + self.label_dim, 1, 1) + for res_log2 in range(self.init_res_log2, self.final_res_log2 + 1): + current_lod = self.final_res_log2 - res_log2 + block_idx = res_log2 - self.init_res_log2 + if lod < current_lod + 1: + x = getattr(self, f'layer{2 * block_idx}')(x) + x = getattr(self, f'layer{2 * block_idx + 1}')(x) + if current_lod - 1 < lod <= current_lod: + image = getattr(self, f'output{block_idx}')(x) + elif current_lod < lod < current_lod + 1: + alpha = np.ceil(lod) - lod + temp = getattr(self, f'output{block_idx}')(x) + image = F.interpolate(image, scale_factor=2, mode='nearest') + image = temp * alpha + image * (1 - alpha) + elif lod >= current_lod + 1: + image = F.interpolate(image, scale_factor=2, mode='nearest') + if self.final_tanh: + image = torch.tanh(image) + + results = { + 'z': z, + 'label': label, + 'image': image, + } + return results + + +class PixelNormLayer(nn.Module): + """Implements pixel-wise feature vector normalization layer.""" + + def __init__(self, dim, eps): + super().__init__() + self.dim = dim + self.eps = eps + + def extra_repr(self): + return f'dim={self.dim}, epsilon={self.eps}' + + def forward(self, x): + scale = (x.square().mean(dim=self.dim, keepdim=True) + self.eps).rsqrt() + return x * scale + + +class UpsamplingLayer(nn.Module): + """Implements the upsampling layer. + + Basically, this layer can be used to upsample feature maps with nearest + neighbor interpolation. 
+ """ + + def __init__(self, scale_factor): + super().__init__() + self.scale_factor = scale_factor + + def extra_repr(self): + return f'factor={self.scale_factor}' + + def forward(self, x): + if self.scale_factor <= 1: + return x + return F.interpolate(x, scale_factor=self.scale_factor, mode='nearest') + + +class ConvLayer(nn.Module): + """Implements the convolutional layer. + + Basically, this layer executes pixel-wise normalization, upsampling (if + needed), convolution, and activation in sequence. + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + padding, + add_bias, + upsample, + fused_scale, + use_wscale, + wscale_gain, + activation_type, + eps): + """Initializes with layer settings. + + Args: + in_channels: Number of channels of the input tensor. + out_channels: Number of channels of the output tensor. + kernel_size: Size of the convolutional kernels. + padding: Padding used in convolution. + add_bias: Whether to add bias onto the convolutional result. + upsample: Whether to upsample the input tensor before convolution. + fused_scale: Whether to fused `upsample` and `conv2d` together, + resulting in `conv2d_transpose`. + use_wscale: Whether to use weight scaling. + wscale_gain: Gain factor for weight scaling. + activation_type: Type of activation. + eps: A small value to avoid divide overflow. + """ + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.padding = padding + self.add_bias = add_bias + self.upsample = upsample + self.fused_scale = fused_scale + self.use_wscale = use_wscale + self.wscale_gain = wscale_gain + self.activation_type = activation_type + self.eps = eps + + self.pixel_norm = PixelNormLayer(dim=1, eps=eps) + + if upsample and not fused_scale: + self.up = UpsamplingLayer(scale_factor=2) + else: + self.up = nn.Identity() + + if upsample and fused_scale: + self.use_conv2d_transpose = True + weight_shape = (in_channels, out_channels, kernel_size, kernel_size) + self.stride = 2 + self.padding = 1 + else: + self.use_conv2d_transpose = False + weight_shape = (out_channels, in_channels, kernel_size, kernel_size) + self.stride = 1 + + fan_in = kernel_size * kernel_size * in_channels + wscale = wscale_gain / np.sqrt(fan_in) + if use_wscale: + self.weight = nn.Parameter(torch.randn(*weight_shape)) + self.wscale = wscale + else: + self.weight = nn.Parameter(torch.randn(*weight_shape) * wscale) + self.wscale = 1.0 + + if add_bias: + self.bias = nn.Parameter(torch.zeros(out_channels)) + else: + self.bias = None + + assert activation_type in ['linear', 'relu', 'lrelu'] + + def extra_repr(self): + return (f'in_ch={self.in_channels}, ' + f'out_ch={self.out_channels}, ' + f'ksize={self.kernel_size}, ' + f'padding={self.padding}, ' + f'wscale_gain={self.wscale_gain:.3f}, ' + f'bias={self.add_bias}, ' + f'upsample={self.scale_factor}, ' + f'fused_scale={self.fused_scale}, ' + f'act={self.activation_type}') + + def forward(self, x): + x = self.pixel_norm(x) + x = self.up(x) + weight = self.weight + if self.wscale != 1.0: + weight = weight * self.wscale + if self.use_conv2d_transpose: + weight = F.pad(weight, (1, 1, 1, 1, 0, 0, 0, 0), 'constant', 0.0) + weight = (weight[:, :, 1:, 1:] + weight[:, :, :-1, 1:] + + weight[:, :, 1:, :-1] + weight[:, :, :-1, :-1]) + x = F.conv_transpose2d(x, + weight=weight, + bias=self.bias, + stride=self.stride, + padding=self.padding) + else: + x = F.conv2d(x, + weight=weight, + bias=self.bias, + stride=self.stride, + padding=self.padding) + + if 
self.activation_type == 'linear': + pass + elif self.activation_type == 'relu': + x = F.relu(x, inplace=True) + elif self.activation_type == 'lrelu': + x = F.leaky_relu(x, negative_slope=0.2, inplace=True) + else: + raise NotImplementedError(f'Not implemented activation type ' + f'`{self.activation_type}`!') + + return x + +# pylint: enable=missing-function-docstring diff --git a/models/pigan_discriminator.py b/models/pigan_discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..31760d07330e2baa2e1f906616619d7331798dea --- /dev/null +++ b/models/pigan_discriminator.py @@ -0,0 +1,305 @@ +# python3.8 +"""Contains the implementation of discriminator described in StyleGAN. + +Paper: https://arxiv.org/pdf/1812.04948.pdf + +Official TensorFlow implementation: https://github.com/NVlabs/stylegan +""" + +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.cuda.amp import autocast + +import math + +__all__ = ['PiGANDiscriminator'] + + +class PiGANDiscriminator(nn.Module): + + def __init__(self, + resolution, + latent_dim=256, + label_dim=0, + embedding_dim=256, + normalize_embedding=True, + **kwargs): # from 4 * 2^0 to 4 * 2^7 4 -> 512 + super().__init__() + self.label_dim = label_dim + self.embedding_dim = embedding_dim + self.normalize_embedding = normalize_embedding + + self.register_buffer('lod', torch.zeros(())) + + self.use_embedding = label_dim > 0 and embedding_dim > 0 + if self.use_embedding > 0: + self.class_embedding = EqualLinear(label_dim, + embedding_dim, + bias=True, + bias_init=0, + lr_mul=1) + self.norm = PixelNormLayer(dim=1, eps=1e-8) + + self.layers = nn.ModuleList([ + ResidualCCBlock(32, 64), # 6 256x256 -> 128x128 + ResidualCCBlock(64, 128), # 5 128x128 -> 64x64 + ResidualCCBlock(128, 256), # 4 64x64 -> 32x32 + ResidualCCBlock(256, 400), # 3 32x32 -> 16x16 + ResidualCCBlock(400, 400), # 2 16x16 -> 8x8 + ResidualCCBlock(400, 400), # 1 8x8 -> 4x4 + ResidualCCBlock(400, 400), # 7 4x4 -> 2x2 + ]) + + self.fromRGB = nn.ModuleList([ + AdapterBlock(32), + AdapterBlock(64), + AdapterBlock(128), + AdapterBlock(256), + AdapterBlock(400), + AdapterBlock(400), + AdapterBlock(400), + AdapterBlock(400), + ]) + self.score_conv = nn.Conv2d( + 400, embedding_dim if self.use_embedding else max(label_dim, 1), 2) + self.latent_conv = nn.Conv2d(400, latent_dim, 2) + self.camera_conv = nn.Conv2d(400, 2, 2) + + self.img_size_to_layer = { + 2: 7, + 4: 6, + 8: 5, + 16: 4, + 32: 3, + 64: 2, + 128: 1, + 256: 0 + } + self.register_buffer('lod', torch.zeros(())) + + def forward(self, + input, + label=None, + options=None, + alpha=None, + enable_amp=False, + **kwargs): + + if self.label_dim > 0: + if label is None: + raise ValueError( + f'Model requires an additional label ' + f'(with dimension {self.label_dim}) as input, ' + f'but no label is received!') + if label.ndim != 2 or label.shape != (input.shape[0], + self.label_dim): + raise ValueError(f'Input label should be with shape ' + f'[batch_size, label_dim], where ' + f'`batch_size` equals to that of ' + f'images ({input.shape[0]}) and ' + f'`label_dim` equals to {self.label_dim}!\n' + f'But `{label.shape}` is received!') + label = label.to(dtype=torch.float32) + if self.use_embedding: + embed = self.class_embedding(label) + if self.normalize_embedding: + embed = self.norm(embed) + + start = self.img_size_to_layer[input.shape[-1]] + + with autocast(enabled=enable_amp): + x = self.fromRGB[start](input) + + if kwargs.get('instance_noise', 0) > 0: + x = x + torch.randn_like(x) 
* kwargs['instance_noise'] + + for i, layer in enumerate(self.layers[start:]): + if i == 1 and alpha < 1: + down_image = F.interpolate(input, + scale_factor=0.5, + mode='nearest') + x = alpha * x + (1 - alpha) * self.fromRGB[start + + 1](down_image) + + x = layer(x) + + # x = self.final_layer(x).reshape(x.shape[0], -1) + score = self.score_conv(x).reshape(x.shape[0], -1) + if self.use_embedding: + score = (score * embed).sum(dim=1, keepdim=True) + score = score / np.sqrt(self.embedding_dim) + elif self.label_dim > 0: + score = (score * label).sum(dim=1, keepdim=True) + + latent = self.latent_conv(x).reshape(x.shape[0], -1) + position = self.camera_conv(x).reshape(x.shape[0], -1) + + results = { + 'score': score, + 'latent': latent, + 'camera': position, + } + return results + + +class ResidualCCBlock(nn.Module): + + def __init__(self, inplanes, planes, kernel_size=3): + super().__init__() + p = kernel_size // 2 + self.network = nn.Sequential( + CoordConv(inplanes, planes, kernel_size=kernel_size, padding=p), + nn.LeakyReLU(0.2, inplace=True), + CoordConv(planes, + planes, + kernel_size=kernel_size, + stride=2, + padding=p), nn.LeakyReLU(0.2, inplace=True)) + self.proj = nn.Conv2d(inplanes, planes, 1, stride=2) + + def init_weights(self): + for module in self.modules(): + if isinstance(module, nn.Linear): + nn.init.kaiming_normal_(module.weight, + a=0.2, + mode='fan_in', + nonlinearity='leaky_relu') + + def forward(self, input): + y = self.network(input) + + identity = self.proj(input) + + y = (y + identity) / math.sqrt(2) + return y + + +class AdapterBlock(nn.Module): + + def __init__(self, output_channels): + super().__init__() + self.model = nn.Sequential(nn.Conv2d(3, output_channels, 1, padding=0), + nn.LeakyReLU(0.2, inplace=True)) + + def forward(self, input): + return self.model(input) + + +class AddCoords(nn.Module): + """ + Source: + https://github.com/mkocabas/CoordConv-pytorch/blob/master/CoordConv.py + """ + + def __init__(self, with_r=False): + super().__init__() + self.with_r = with_r + + def forward(self, input_tensor): + """ + Args: + input_tensor: shape(batch, channel, x_dim, y_dim) + """ + batch_size, _, x_dim, y_dim = input_tensor.size() + + xx_channel = torch.arange(x_dim).repeat(1, y_dim, 1) + yy_channel = torch.arange(y_dim).repeat(1, x_dim, 1).transpose(1, 2) + + xx_channel = xx_channel.float() / (x_dim - 1) + yy_channel = yy_channel.float() / (y_dim - 1) + + xx_channel = xx_channel * 2 - 1 + yy_channel = yy_channel * 2 - 1 + + xx_channel = xx_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3) + yy_channel = yy_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3) + + ret = torch.cat([ + input_tensor, + xx_channel.type_as(input_tensor), + yy_channel.type_as(input_tensor) + ], + dim=1) + + if self.with_r: + rr = torch.sqrt( + torch.pow(xx_channel.type_as(input_tensor) - 0.5, 2) + + torch.pow(yy_channel.type_as(input_tensor) - 0.5, 2)) + ret = torch.cat([ret, rr], dim=1) + + return ret + + +class CoordConv(nn.Module): + """ + Source: + https://github.com/mkocabas/CoordConv-pytorch/blob/master/CoordConv.py + """ + + def __init__(self, in_channels, out_channels, with_r=False, **kwargs): + super().__init__() + self.addcoords = AddCoords(with_r=with_r) + in_size = in_channels + 2 + if with_r: + in_size += 1 + self.conv = nn.Conv2d(in_size, out_channels, **kwargs) + + def forward(self, x): + ret = self.addcoords(x) + ret = self.conv(ret) + return ret + + +class EqualLinear(nn.Module): + + def __init__( + self, + in_dim, + out_dim, + bias=True, + bias_init=0, + lr_mul=1, + ): 
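+ # Equalized-learning-rate linear layer: weights are stored pre-divided by
+ # `lr_mul` and rescaled at runtime by `scale = lr_mul / sqrt(in_dim)`.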
+ super().__init__() + + self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) + + if bias: + self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) + + else: + self.bias = None + + self.scale = (1 / math.sqrt(in_dim)) * lr_mul + self.lr_mul = lr_mul + + def forward(self, input): + out = F.linear(input, + self.weight * self.scale, + bias=self.bias * self.lr_mul) + return out + + def __repr__(self): + return ( + f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})' + ) + + +class PixelNormLayer(nn.Module): + """Implements pixel-wise feature vector normalization layer.""" + + def __init__(self, dim, eps): + super().__init__() + self.dim = dim + self.eps = eps + + def extra_repr(self): + return f'dim={self.dim}, epsilon={self.eps}' + + def forward(self, x): + scale = (x.square().mean(dim=self.dim, keepdim=True) + + self.eps).rsqrt() + return x * scale \ No newline at end of file diff --git a/models/pigan_generator.py b/models/pigan_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..a107cb6bea3561a2f79d7d0204d3bd673f264db9 --- /dev/null +++ b/models/pigan_generator.py @@ -0,0 +1,514 @@ +# python3.7 +"""Contains the implementation of generator described in PiGAN.""" + +import math +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.cuda.amp import autocast + +from .utils.ops import all_gather +from .rendering.renderer import Renderer +from .rendering.feature_extractor import FeatureExtractor + +__all__ = ['PiGANGenerator'] + + +class PiGANGenerator(nn.Module): + """Defines the generator network in PiGAN.""" + def __init__(self, + # Settings for mapping network. + z_dim=256, + w_dim=256, + repeat_w=False, + normalize_z=False, + mapping_layers=3, + mapping_hidden_dim=256, + # Settings for conditional generation. + label_dim=0, + embedding_dim=512, + normalize_embedding=True, + normalize_embedding_latent=False, + label_concat=True, + # Settings for synthesis network. + resolution=-1, + synthesis_input_dim=3, + synthesis_output_dim=256, + synthesis_layers=8, + grid_scale=0.24, + eps=1e-8, + # Settings for rendering module. + rendering_kwargs={}): + """Initializes with basic settings.""" + super().__init__() + + self.z_dim = z_dim + self.w_dim = w_dim + self.repeat_w = repeat_w + self.normalize_z = normalize_z + self.mapping_layers = mapping_layers + + self.latent_dim = (z_dim,) + self.label_dim = label_dim + self.embedding_dim = embedding_dim + self.normalize_embedding = normalize_embedding + self.normalize_embedding_latent = normalize_embedding_latent + + self.resolution = resolution + self.num_layers = synthesis_layers + self.eps = eps + + if self.repeat_w: + self.mapping_space_dim = self.w_dim + else: + self.mapping_space_dim = self.w_dim * (self.num_layers + 1) + + # Mapping Network to tranform latent codes from Z-Space into W-Space. + self.mapping = MappingNetwork( + input_dim=z_dim, + output_dim=w_dim, + num_outputs=synthesis_layers + 1, + repeat_output=repeat_w, + normalize_input=normalize_z, + num_layers=mapping_layers, + hidden_dim=mapping_hidden_dim, + label_dim=label_dim, + embedding_dim=embedding_dim, + normalize_embedding=normalize_embedding, + normalize_embedding_latent=normalize_embedding_latent, + eps=eps, + label_concat=label_concat, + lr=None) + + # Set up the overall renderer. + self.renderer = Renderer() + + # Set up the reference representation generator. + self.ref_representation_generator = None + + # Set up the feature extractor. 
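+ # Pi-GAN uses no reference representation (`ref_mode='none'`), so the
+ # feature extractor simply forwards the sampled points to the
+ # FiLM-conditioned MLP configured below.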
+ self.feature_extractor = FeatureExtractor(ref_mode='none') + + # Set up the post module in the feature extractor. + self.post_module = MLPNetwork(w_dim=w_dim, + in_channels=synthesis_input_dim, + num_layers=synthesis_layers, + out_channels=synthesis_output_dim, + grid_scale=grid_scale) + + # Set up the fully-connected layer head. + self.fc_head = FCHead(w_dim=w_dim, + channels=synthesis_output_dim, + mlp_length=self.post_module.mlp_length) + + # Set up the post neural renderer. + self.post_neural_renderer = None + + # This is used for truncation trick. + if self.repeat_w: + self.register_buffer('w_avg', torch.zeros(w_dim)) + else: + self.register_buffer('w_avg', torch.zeros(self.num_layers * w_dim)) + + # Set up some rendering related arguments. + self.rendering_kwargs = rendering_kwargs + + # Initialize weights. + self.init_weights() + + def init_weights(self): + self.mapping.init_weights() + self.post_module.init_weights() + self.fc_head.init_weights() + + def forward(self, + z, + label=None, + lod=None, + w_moving_decay=None, + sync_w_avg=False, + style_mixing_prob=None, + noise_std=None, + trunc_psi=None, + trunc_layers=None, + enable_amp=False): + if noise_std is not None: + self.rendering_kwargs.update(noise_std=noise_std) + + lod = self.post_module.lod.cpu().tolist() if lod is None else lod + + mapping_results = self.mapping(z, label) + w = mapping_results['w'] + wp = mapping_results.pop('wp') + + if self.training and w_moving_decay is not None: + if sync_w_avg: + batch_w_avg = all_gather(w.detach()).mean(dim=0) + else: + batch_w_avg = w.detach().mean(dim=0) + self.w_avg.copy_(batch_w_avg.lerp(self.w_avg, w_moving_decay)) + + # Truncation. + if not self.training: + trunc_psi = 1.0 if trunc_psi is None else trunc_psi + trunc_layers = 0 if trunc_layers is None else trunc_layers + if trunc_psi < 1.0 and trunc_layers > 0: + w_avg = self.w_avg.reshape(1, -1, self.w_dim)[:, :trunc_layers] + wp[:, :trunc_layers] = w_avg.lerp( + wp[:, :trunc_layers], trunc_psi) + + with autocast(enabled=enable_amp): + rendering_result = self.renderer( + wp=wp, + feature_extractor=self.feature_extractor, + rendering_options=self.rendering_kwargs, + position_encoder=None, + ref_representation=None, + post_module=self.post_module, + post_module_kwargs=dict(lod=lod), + fc_head=self.fc_head) + + image = rendering_result['composite_rgb'].reshape( + z.shape[0], self.resolution, self.resolution, + -1).permute(0, 3, 1, 2) + + camera = torch.cat([ + rendering_result['camera_polar'], + rendering_result['camera_azimuthal'] + ], -1) + + return { + **mapping_results, + 'image': image, + 'camera': camera, + 'latent': z + } + + +class MappingNetwork(nn.Module): + """Implements the latent space mapping module. + + Basically, this module executes several dense layers in sequence, and the + label embedding if needed. 
+ """ + + def __init__(self, + input_dim, + output_dim, + num_outputs, + repeat_output, + normalize_input, + num_layers, + hidden_dim, + label_dim, + embedding_dim, + normalize_embedding, + normalize_embedding_latent, + eps, + label_concat, + lr=None): + super().__init__() + + self.input_dim = input_dim + self.output_dim = output_dim + self.num_outputs = num_outputs + self.repeat_output = repeat_output + self.normalize_input = normalize_input + self.num_layers = num_layers + # self.out_channels = out_channels + # TODO + # self.lr_mul = lr_mul + + self.label_dim = label_dim + self.embedding_dim = embedding_dim + self.normalize_embedding = normalize_embedding + self.normalize_embedding_latent = normalize_embedding_latent + self.eps = eps + self.label_concat = label_concat + + self.norm = PixelNormLayer(dim=1, eps=eps) + + if num_outputs is not None and not repeat_output: + output_dim = output_dim * num_outputs + + if self.label_dim > 0: + if self.label_concat: + input_dim = input_dim + embedding_dim + self.embedding = EqualLinear(label_dim, + embedding_dim, + bias=True, + bias_init=0, + lr_mul=1) + else: + self.embedding = EqualLinear(label_dim, + output_dim, + bias=True, + bias_init=0, + lr_mul=1) + + network = [] + for i in range(num_layers): + in_channels = (input_dim if i == 0 else hidden_dim) + out_channels = (output_dim if i == (num_layers - 1) else hidden_dim) + network.append(nn.Linear(in_channels, out_channels)) + network.append(nn.LeakyReLU(0.2, inplace=True)) + self.network = nn.Sequential(*network) + + def init_weights(self): + for module in self.network.modules(): + if isinstance(module, nn.Linear): + nn.init.kaiming_normal_(module.weight, + a=0.2, + mode='fan_in', + nonlinearity='leaky_relu') + + def forward(self, z, label=None): + if z.ndim != 2 or z.shape[1] != self.input_dim: + raise ValueError(f'Input latent code should be with shape ' + f'[batch_size, input_dim], where ' + f'`input_dim` equals to {self.input_dim}!\n' + f'But `{z.shape}` is received!') + if self.normalize_input: + z = self.norm(z) + if self.label_dim > 0: + if label is None: + raise ValueError(f'Model requires an additional label ' + f'(with dimension {self.label_dim}) as input, ' + f'but no label is received!') + if label.ndim != 2 or label.shape != (z.shape[0], self.label_dim): + raise ValueError(f'Input label should be with shape ' + f'[batch_size, label_dim], where ' + f'`batch_size` equals to that of ' + f'latent codes ({z.shape[0]}) and ' + f'`label_dim` equals to {self.label_dim}!\n' + f'But `{label.shape}` is received!') + label = label.to(dtype=torch.float32) + + embedding = self.embedding(label) + if self.normalize_embedding and self.label_concat: + embedding = self.norm(embedding) + if self.label_concat: + w = torch.cat((z, embedding), dim=1) + else: + w = z + else: + w = z + + if (self.label_dim > 0 and self.normalize_embedding_latent + and self.label_concat): + w = self.norm(w) + + for layer in self.network: + w = layer(w) + + if self.label_dim > 0 and (not self.label_concat): + w = w * embedding + + wp = None + if self.num_outputs is not None: + if self.repeat_output: + wp = w.unsqueeze(1).repeat((1, self.num_outputs, 1)) + else: + wp = w.reshape(-1, self.num_outputs, self.output_dim) + + results = { + 'z': z, + 'label': label, + 'w': w, + 'wp': wp, + } + if self.label_dim > 0: + results['embedding'] = embedding + return results + + +class MLPNetwork(nn.Module): + """Defines MLP Network in Pi-GAN.""" + def __init__(self, + w_dim, + in_channels, + num_layers, + out_channels, + grid_scale=0.24): 
+ super().__init__() + + self.in_channels = in_channels + self.w_dim = w_dim + self.out_channels = out_channels + + self.register_buffer('lod', torch.zeros(())) + + self.grid_warper = UniformBoxWarp(grid_scale) + + network = [] + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + out_channels = out_channels + film = FiLMLayer(in_channels, out_channels, w_dim) + network.append(film) + self.mlp_network = nn.Sequential(*network) + + self.mlp_length = len(self.mlp_network) + + def init_weights(self): + for module in self.modules(): + if isinstance(module, FiLMLayer): + module.init_weights() + + self.mlp_network[0].init_weights(first=True) + + def forward(self, pts, wp, lod=None): + num_dims = pts.ndim + assert num_dims in [3, 4, 5] + if num_dims == 5: + N, H, W, K, C = pts.shape + pts = pts.reshape(N, H * W * K, C) + elif num_dims == 4: + N, R, K, C = pts.shape + pts = pts.reshape(N, R * K, C) + + x = self.grid_warper(pts) + + for idx, layer in enumerate(self.mlp_network): + x = layer(x, wp[:, idx]) + + return x + + +class FCHead(nn.Module): + """Defines fully-connected layer head in Pi-GAN to decode `feature` into + `sigma` and `rgb`.""" + + def __init__(self, w_dim, channels, mlp_length): + super().__init__() + + self.w_dim = w_dim + self.channels = channels + self.mlp_length = mlp_length + + self.sigma_head = nn.Linear(channels, 1) + self.rgb_film = FiLMLayer(channels + 3, channels, w_dim) + self.rgb_head = nn.Linear(channels, 3) + + def init_weights(self,): + self.sigma_head.apply(freq_init(25)) + self.rgb_head.apply(freq_init(25)) + + self.rgb_film.init_weights() + + def forward(self, point_features, wp, dirs): + sigma = self.sigma_head(point_features) + + dirs = torch.cat([point_features, dirs], dim=-1) + rgb = self.rgb_film(dirs, wp[:, self.mlp_length]) + rgb = self.rgb_head(rgb).sigmoid() + + results = {'sigma': sigma, 'rgb': rgb} + + return results + + +class FiLMLayer(nn.Module): + def __init__(self, input_dim, output_dim, w_dim, **kwargs): + super().__init__() + self.input_dim = input_dim + self.output_dim = output_dim + self.w_dim = w_dim + + self.layer = nn.Linear(input_dim, output_dim) + self.style = nn.Linear(w_dim, output_dim*2) + + def init_weights(self, first=False): + # initial with 25 frequency + if not first: + self.layer.apply(freq_init(25)) + else: + self.layer.apply(first_film_init) + # kaiming initial && scale 1/4 + nn.init.kaiming_normal_(self.style.weight, + a=0.2, + mode='fan_in', + nonlinearity='leaky_relu') + with torch.no_grad(): self.style.weight *= 0.25 + + def extra_repr(self): + return (f'in_ch={self.input_dim}, ' + f'out_ch={self.output_dim}, ' + f'w_ch={self.w_dim}') + + def forward(self, x, wp): + x = self.layer(x) + style = self.style(wp) + style_split = style.unsqueeze(1).chunk(2, dim=2) + freq = style_split[0] + # Scale for sin activation + freq = freq*15 + 30 + phase_shift = style_split[1] + return torch.sin(freq * x + phase_shift) + +class PixelNormLayer(nn.Module): + """Implements pixel-wise feature vector normalization layer.""" + + def __init__(self, dim, eps): + super().__init__() + self.dim = dim + self.eps = eps + + def extra_repr(self): + return f'dim={self.dim}, epsilon={self.eps}' + + def forward(self, x): + scale = (x.square().mean(dim=self.dim, keepdim=True) + self.eps).rsqrt() + return x * scale + +class UniformBoxWarp(nn.Module): + def __init__(self, sidelength): + super().__init__() + self.scale_factor = 2 / sidelength + + def forward(self, coordinates): + return coordinates * self.scale_factor + +def 
first_film_init(m): + with torch.no_grad(): + if isinstance(m, nn.Linear): + num_input = m.weight.size(-1) + m.weight.uniform_(-1/num_input, 1/num_input) + +def freq_init(freq): + def init(m): + with torch.no_grad(): + if isinstance(m, nn.Linear): + num_input = m.weight.size(-1) + m.weight.uniform_(-np.sqrt(6/num_input)/freq, + np.sqrt(6/num_input)/freq) + return init + +class EqualLinear(nn.Module): + def __init__( + self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, + ): + super().__init__() + + self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) + + if bias: + self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) + + else: + self.bias = None + + self.scale = (1 / math.sqrt(in_dim)) * lr_mul + self.lr_mul = lr_mul + + def forward(self, input): + out = F.linear( + input, self.weight * self.scale, bias=self.bias * self.lr_mul + ) + return out + + def __repr__(self): + return ( + f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})' + ) \ No newline at end of file diff --git a/models/rendering/__init__.py b/models/rendering/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b549cde2c4bd771bf954b17e57e775ae22b40c37 --- /dev/null +++ b/models/rendering/__init__.py @@ -0,0 +1,10 @@ +# pyhton3.8 +"""Collects all functions for rendering.""" +from .renderer import Renderer +from .feature_extractor import FeatureExtractor +from .utils import interpolate_feature +from .point_sampler import PointSampler + +__all__ = [ + 'Renderer', 'FeatureExtractor', 'interpolate_feature', 'PointSampler' +] \ No newline at end of file diff --git a/models/rendering/__pycache__/__init__.cpython-37.pyc b/models/rendering/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..698a67a87bbd35bd274f41d08c77e5aa2e2239e5 Binary files /dev/null and b/models/rendering/__pycache__/__init__.cpython-37.pyc differ diff --git a/models/rendering/__pycache__/__init__.cpython-39.pyc b/models/rendering/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..447308ab876457b9aedeb96333dabcae54f191da Binary files /dev/null and b/models/rendering/__pycache__/__init__.cpython-39.pyc differ diff --git a/models/rendering/__pycache__/feature_extractor.cpython-37.pyc b/models/rendering/__pycache__/feature_extractor.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b1b346475e3eb13af27018001454f03e2624307e Binary files /dev/null and b/models/rendering/__pycache__/feature_extractor.cpython-37.pyc differ diff --git a/models/rendering/__pycache__/feature_extractor.cpython-39.pyc b/models/rendering/__pycache__/feature_extractor.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..47d9f53e7abb513792648f1c0d3fd4bb94f5e872 Binary files /dev/null and b/models/rendering/__pycache__/feature_extractor.cpython-39.pyc differ diff --git a/models/rendering/__pycache__/integrator.cpython-37.pyc b/models/rendering/__pycache__/integrator.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..38a3bb503097a1836c66c32d8bddf0718087b650 Binary files /dev/null and b/models/rendering/__pycache__/integrator.cpython-37.pyc differ diff --git a/models/rendering/__pycache__/integrator.cpython-39.pyc b/models/rendering/__pycache__/integrator.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..807abd5e0696041a0cf5f92fb7f7378cf4601141 Binary files /dev/null and 
b/models/rendering/__pycache__/integrator.cpython-39.pyc differ diff --git a/models/rendering/__pycache__/point_sampler.cpython-37.pyc b/models/rendering/__pycache__/point_sampler.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..342b37338597fe46cf3e48a346da0c14b370e96b Binary files /dev/null and b/models/rendering/__pycache__/point_sampler.cpython-37.pyc differ diff --git a/models/rendering/__pycache__/point_sampler.cpython-39.pyc b/models/rendering/__pycache__/point_sampler.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d2174edccb6e2a4fa72e1a35bdf60e30b9131fec Binary files /dev/null and b/models/rendering/__pycache__/point_sampler.cpython-39.pyc differ diff --git a/models/rendering/__pycache__/renderer.cpython-37.pyc b/models/rendering/__pycache__/renderer.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ce6552b9cb3ff8baeafb0dc187e11a07c98032c4 Binary files /dev/null and b/models/rendering/__pycache__/renderer.cpython-37.pyc differ diff --git a/models/rendering/__pycache__/renderer.cpython-39.pyc b/models/rendering/__pycache__/renderer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a9eeb5157b7f91ffc0d2db4248669c2a030a88f Binary files /dev/null and b/models/rendering/__pycache__/renderer.cpython-39.pyc differ diff --git a/models/rendering/__pycache__/triplane_sampler.cpython-37.pyc b/models/rendering/__pycache__/triplane_sampler.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6e39e27a0a191027c90a68c37596824414a5fa5b Binary files /dev/null and b/models/rendering/__pycache__/triplane_sampler.cpython-37.pyc differ diff --git a/models/rendering/__pycache__/triplane_sampler.cpython-39.pyc b/models/rendering/__pycache__/triplane_sampler.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..258ae46d42de82c73a90bc450eee314e9976c63b Binary files /dev/null and b/models/rendering/__pycache__/triplane_sampler.cpython-39.pyc differ diff --git a/models/rendering/__pycache__/utils.cpython-37.pyc b/models/rendering/__pycache__/utils.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7916d6de851dd7bf8ece9bcd1d1cef158caf0092 Binary files /dev/null and b/models/rendering/__pycache__/utils.cpython-37.pyc differ diff --git a/models/rendering/__pycache__/utils.cpython-39.pyc b/models/rendering/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb1063232c85f1b4ea563832e58c0debf82a587b Binary files /dev/null and b/models/rendering/__pycache__/utils.cpython-39.pyc differ diff --git a/models/rendering/feature_extractor.py b/models/rendering/feature_extractor.py new file mode 100644 index 0000000000000000000000000000000000000000..41645fe014710ac520d68ea3de266352dc234f2a --- /dev/null +++ b/models/rendering/feature_extractor.py @@ -0,0 +1,174 @@ +# python3.8 +"""Defines feature extractor in 3D generation pipeline.""" + +import torch +from .triplane_sampler import TriplaneSampler +from .utils import interpolate_feature +from einops import rearrange +import math + +__all__ = ['FeatureExtractor'] + + +_REF_MODE = ['none', 'tri_plane', 'feature_volume', 'bev_plane_clevr_256', 'bev_plane_clevr_512', 'bev_plane_carla'] + + +class FeatureExtractor(torch.nn.Module): + """Defines the feature extractor in 3D Generation Pipeline. 
+ + Basically, the feature extractor takes in the latent code and sampled points + in addition to the reference representation as input, and outputs the + feature representation which contains information of each point's color and + density. + + """ + + def __init__(self, ref_mode='none', xyz_pe=False, reverse_xy=True): + super().__init__() + self.ref_mode = ref_mode + self.xyz_pe = xyz_pe + self.reverse_xy = reverse_xy + assert ref_mode in _REF_MODE + if ref_mode == 'tri_plane': + self.plane_axes = TriplaneSampler.generate_planes() + + def forward(self, + wp, + points, + rendering_options, + position_encoder=None, + ref_representation=None, + post_module=None, + post_module_kwargs={}, + ray_dirs=None, + cam_matrix=None,): + assert points.ndim in [3, 4] + if points.ndim == 3: + points = points.unsqueeze(2) # shape: [N, R, C] -> [N, R, 1, C] + N, R, K, _ = points.shape[:4] + # (Optional) Positional encoding. + if position_encoder is not None: + points_encoding = position_encoder(points) # shape: [N, R, K, C]. + points_encoding = rearrange(points_encoding, + 'N R K C -> N C (R K) 1').contiguous() + + # Reshape `points` with shape [N, R*K, 3]. + points = points.reshape(points.shape[0], -1, points.shape[-1]) + + # Get pre-point-features by sampling from + # the reference representation (if exists). + pre_point_features = points + if ref_representation is not None: + assert self.ref_mode is not None + if self.ref_mode == 'tri_plane': + pre_point_features = TriplaneSampler.sample_from_planes( + self.plane_axes.to(points.device), + ref_representation, + points, + padding_mode='zeros', + box_warp=rendering_options.get('box_warp', 1.0)) + # shape: [N, 3, num_points, C], where num_points = H*W*K. + elif self.ref_mode == 'feature_volume': + bounds = rendering_options.get( + 'bounds', + [[-0.1886, -0.1671, -0.1956], [0.1887, 0.1692, 0.1872]]) + bounds = torch.Tensor(bounds).to(points.device) + pre_point_features = interpolate_feature( + points, ref_representation, bounds) # shape: [N, C, R*K]. + pre_point_features = pre_point_features.unsqueeze(-1) + # shape: [N, C, R*K, 1]. 
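+ # Forward the positional encoding of the points to the post module as
+ # well (this branch assumes a `position_encoder` is supplied).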
+ post_module_kwargs.update(points_encoding=points_encoding) + elif 'bev_plane_clevr' in self.ref_mode: + h = w = int(self.ref_mode[-3:]) + # first, transform points from world coordinates to bev coordinates + # cam_matrix: N, 4, 4 + # points: N, 3, R*K + + points_reshape = points # N, R*K, 3 + # points_homo = torch.cat([points_reshape, torch.ones([*points_reshape.shape[:2], 1]).to(points_reshape.device)], -1) + # points_cam = torch.einsum('nxy,nby->nbx', cam_matrix, points_homo) # N, R*K, 4 + + if self.reverse_xy: + y = (0.5 * w - 128 + 256 - (points_reshape[..., 0] /9 + .5) * 256 ) / w * 2 - 1 + x = (0.5 * h - 128 + (points_reshape[..., 1] /9 + .5) * 256 ) / h * 2 - 1 + else: + x = (0.5 * w - 128 + 256 - (points_reshape[..., 0] /9 + .5) * 256 ) / w * 2 - 1 + y = (0.5 * h - 128 + (points_reshape[..., 1] /9 + .5) * 256 ) / h * 2 - 1 + z = points_reshape[..., -1] / 9 + points_bev = torch.cat([x.unsqueeze(-1), y.unsqueeze(-1), z.unsqueeze(-1)], -1) + + # second, sample feature from BEV map + # ref_representation: N, C, A, A + # points_bev: N, R*K, 3 + xy = points_bev[..., :2] # N, R*K, 2 + xy = xy.unsqueeze(2) # N, R*K, 1, 2 + feat_xy = torch.nn.functional.grid_sample(ref_representation, xy, mode='bilinear', + padding_mode='zeros', align_corners=False) # N, C, R*K, 1 + feat_xy = feat_xy.squeeze(3) # N, C,R*K + x = points_bev[..., 0] # N, R*K + y = points_bev[..., 1] # N, R*K + z = points_bev[..., -1] # N, R*K + + # third, do positional encoding on z + d_model = 32 + div_term = torch.exp((torch.arange(0, d_model, 2, dtype=torch.float) *-(math.log(10000.0) / d_model))).to(z.device) + + pe_x = torch.zeros([x.shape[0], x.shape[1], d_model]).to(x.device) + pe_x[..., 0::2] = torch.sin(x.unsqueeze(-1).float() * div_term) + pe_x[..., 1::2] = torch.cos(x.unsqueeze(-1).float() * div_term) + pe_y = torch.zeros([y.shape[0], y.shape[1], d_model]).to(y.device) + pe_y[..., 0::2] = torch.sin(y.unsqueeze(-1).float() * div_term) + pe_y[..., 1::2] = torch.cos(y.unsqueeze(-1).float() * div_term) + pe_z = torch.zeros([z.shape[0], z.shape[1], d_model]).to(z.device) + pe_z[..., 0::2] = torch.sin(z.unsqueeze(-1).float() * div_term) + pe_z[..., 1::2] = torch.cos(z.unsqueeze(-1).float() * div_term) + if self.xyz_pe: + feat_xyz = torch.cat([feat_xy, pe_x.permute(0, 2, 1), pe_y.permute(0,2,1),pe_z.permute(0, 2, 1)], 1) # N, C+d_model, R*K + else: + feat_xyz = torch.cat([feat_xy ,pe_z.permute(0, 2, 1)], 1) # N, C+d_model, R*K + pre_point_features = feat_xyz.permute(0, 2, 1) # N, RK, C+d_model + pre_point_features = pre_point_features.view(N, R, K, -1) + elif self.ref_mode == 'bev_plane_carla': + x = (217.5 - 8 * points[..., 0]) / 128 - 1 + y = (128.0 + 8 * points[..., 1]) / 128 - 1 + z = points[..., 2] + points_bev = torch.cat([x.unsqueeze(-1), y.unsqueeze(-1), z.unsqueeze(-1)], -1) + + xy = points_bev[..., :2] + xy = xy.unsqueeze(2) + feat_xy = torch.nn.functional.grid_sample(ref_representation, xy, mode='bilinear',padding_mode='zeros', align_corners=False) + feat_xy = feat_xy.squeeze(3) + z = points_bev[..., -1] + d_model = 32 + div_term = torch.exp((torch.arange(0, d_model, 2, dtype=torch.float) *-(math.log(10000.0) / d_model))).to(z.device) + pe_x = torch.zeros([x.shape[0], x.shape[1], d_model]).to(x.device) + pe_x[..., 0::2] = torch.sin(x.unsqueeze(-1).float() * div_term) + pe_x[..., 1::2] = torch.cos(x.unsqueeze(-1).float() * div_term) + pe_y = torch.zeros([y.shape[0], y.shape[1], d_model]).to(y.device) + pe_y[..., 0::2] = torch.sin(y.unsqueeze(-1).float() * div_term) + pe_y[..., 1::2] = 
torch.cos(y.unsqueeze(-1).float() * div_term) + pe_z = torch.zeros([z.shape[0], z.shape[1], d_model]).to(z.device) + pe_z[..., 0::2] = torch.sin(z.unsqueeze(-1).float() * div_term) + pe_z[..., 1::2] = torch.cos(z.unsqueeze(-1).float() * div_term) + if self.xyz_pe: + feat_xyz = torch.cat([feat_xy, pe_x.permute(0, 2, 1), pe_y.permute(0,2,1),pe_z.permute(0, 2, 1)], 1) # N, C+d_model, R*K + else: + feat_xyz = torch.cat([feat_xy ,pe_z.permute(0, 2, 1)], 1) # N, C+d_model, R*K + pre_point_features = feat_xyz.permute(0, 2, 1) # N, RK, C+d_model + pre_point_features = pre_point_features.view(N, R, K, -1) + else: + raise NotImplementedError + + # Get post-point-features by feeding pre-point-features into the + # post-module (if exists). + if post_module is not None: + post_point_features = post_module(pre_point_features, wp, + **post_module_kwargs) + else: + post_point_features = pre_point_features + + if post_point_features.ndim == 2: + post_point_features = rearrange('(N R K) C -> N R K C', + N=N, R=R, K=K).contiguous() + + return post_point_features diff --git a/models/rendering/integrator.py b/models/rendering/integrator.py new file mode 100644 index 0000000000000000000000000000000000000000..b4b57b5c272ed08a71ce01813b50b1d5c5b505d3 --- /dev/null +++ b/models/rendering/integrator.py @@ -0,0 +1,116 @@ +# python 3.7 +"""Contains the function to march rays (integration).""" + +import torch +import torch.nn.functional as F + +__all__ = ['Integrator'] + + +class Integrator(torch.nn.Module): + """Defines the class to help march rays, i.e. do integral along each ray. + + The ray marcher takes the raw output of the implicit representation + (including colors(i.e. rgbs) and densities(i.e. sigmas)) and uses the + volume rendering equation to produce composited colors and depths. + """ + + def __init__(self): + super().__init__() + + def integration(self, rgbs, sigmas, depths, rendering_options): + """Integrate the values along the ray. + + `N` denotes batch size. + `R` denotes the number of rays, equals `H * W`. + `K` denotes the number of points on each ray. + + Args: + rgbs (torch.tensor): colors' value of each point in the fields, with + shape [N, R, K, 3]. + sigmas (torch.tensor): densities' value of each point in the fields, + with shape [N, R, K, 1]. + depths (torch.tensor): depths' value of each point in the fields, + with shape [N, R, K, 1]. + rendering_options (dict): Additional keyword arguments of rendering + option. + + Returns: + A dictionary, containing + - `composite_rgb`: camera radius w.r.t. the world coordinate + system, with shape [N, R, 3]. + - `composite_depth`: camera polar w.r.t. the world coordinate + system, with shape [N, R, 1]. + - `weights`: importance weights of each point in the field, + with shape [N, R, K, 1]. + """ + num_dims = rgbs.ndim + assert num_dims == 4 + assert sigmas.ndim == num_dims and depths.ndim == num_dims + + N, R, K = rgbs.shape[:3] + + # Get deltas for rendering. 
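+ # `deltas` are the distances between consecutive samples along each ray;
+ # they enter the discrete volume rendering equation through
+ # `alpha_i = 1 - exp(-sigma_i * delta_i)` below.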
+ deltas = depths[:, :, 1:] - depths[:, :, :-1] + if rendering_options.get('use_max_depth', False): + max_depth = rendering_options.get('max_depth', None) + if max_depth is not None: + delta_inf = max_depth - deltas[:, :, -1:] + else: + delta_inf = 1e10 * torch.ones_like(deltas[:, :, :1]) + deltas = torch.cat([deltas, delta_inf], -2) + if rendering_options.get('no_dist', False): + deltas[:] = 1 + + use_mid_point = rendering_options.get('use_mid_point', True) + if use_mid_point: + rgbs = (rgbs[:, :, :-1] + rgbs[:, :, 1:]) / 2 + sigmas = (sigmas[:, :, :-1] + sigmas[:, :, 1:]) / 2 + depths = (depths[:, :, :-1] + depths[:, :, 1:]) / 2 + + clamp_mode = rendering_options.get('clamp_mode', 'mipnerf') + if clamp_mode == 'softplus': + sigmas = F.softplus(sigmas) + elif clamp_mode == 'relu': + sigmas = F.relu(sigmas) + elif clamp_mode == 'mipnerf': + sigmas = F.softplus(sigmas - 1) + else: + raise ValueError(f'Invalid clamping mode: `{clamp_mode}`!\n') + + alphas = 1 - torch.exp(- deltas * sigmas) + alphas_shifted = torch.cat( + [torch.ones_like(alphas[:, :, :1]), 1 - alphas + 1e-10], -2) + weights = alphas * torch.cumprod(alphas_shifted, -2)[:, :, :-1] + weights_sum = weights.sum(2) + if rendering_options.get('last_back', False): + weights[:, :, -1] = weights[:, :, -1] + (1 - weights_sum) + + composite_rgb = torch.sum(weights * rgbs, -2) + composite_depth = torch.sum(weights * depths, -2) + + if rendering_options.get('normalize_rgb', False): + composite_rgb = composite_rgb / weights_sum + if rendering_options.get('normalize_depth', True): + composite_depth = composite_depth / weights_sum + if rendering_options.get('clip_depth', True): + composite_depth = torch.nan_to_num(composite_depth, float('inf')) + composite_depth = torch.clip(composite_depth, torch.min(depths), + torch.max(depths)) + + if rendering_options.get('white_back', False): + composite_rgb = composite_rgb + 1 - weights_sum + + composite_rgb = composite_rgb * 2 - 1 # Scale to (-1, 1) + + results = { + 'composite_rgb': composite_rgb, + 'composite_depth': composite_depth, + 'weights': weights + } + + return results + + def forward(self, rgbs, sigmas, depths, rendering_options): + results = self.integration(rgbs, sigmas, depths, rendering_options) + return results diff --git a/models/rendering/point_sampler.py b/models/rendering/point_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..920e43f3043b6c8f037d87bab206995e1e811e58 --- /dev/null +++ b/models/rendering/point_sampler.py @@ -0,0 +1,1046 @@ +# python3.7 +"""Contains the functions to sample points in 3D space.""" + +import numpy as np + +import torch +import torch.nn.functional as F + +__all__ = [ + 'PointSampler' +] + +_POINT_SAMPLING_STRATEGIES = [ + 'uniform', 'normal', 'ray_dependent', 'point_dependent' +] + +_POINT_PERTURBING_STRATEGIES = [ + 'no', 'middle_uniform', 'uniform', 'self_uniform' +] + +_TENSOR_SAMPLING_STRATEGIES = [ + 'fix', 'uniform', 'normal', 'hybrid', 'truncated_normal' +] + + +class PointSampler(torch.nn.Module): + """Defines the class to help sample points. + + This class implements the `forward()` function for point sampling, which + includes the following steps: + + 1. Sample rays in the camera coordinate system. + 2. Sample points on each ray. + 3. Perturb points on each ray. + 4. Sample camera extrinsics. + 5. Transform points to the world coordinate system. 
+ """ + + def __init__(self, + num_points=16, + fov=30, + image_boundary_value=1.0, + cam_look_at_dir=-1, + pixel_center=False, + y_descending=True, + # Point sampling (i.e., radial distance w.r.t. camera) related. + sampling_strategy='uniform', + focal=None, + dis_min=None, + dis_max=None, + dis_mean=None, + dis_stddev=None, + per_ray_ref=None, + per_point_ref=None, + perturbation_strategy='middle_uniform', + # Camera sampling related. + radius_strategy='fix', + radius_fix=None, + radius_min=None, + radius_max=None, + radius_mean=None, + radius_stddev=None, + polar_strategy='uniform', + polar_fix=None, + polar_min=None, + polar_max=None, + polar_mean=None, + polar_stddev=None, + azimuthal_strategy='uniform', + azimuthal_fix=None, + azimuthal_min=None, + azimuthal_max=None, + azimuthal_mean=None, + azimuthal_stddev=None, + use_spherical_uniform_position=False, + pitch_strategy='fix', + pitch_fix=0, + pitch_min=None, + pitch_max=None, + pitch_mean=None, + pitch_stddev=None, + yaw_strategy='fix', + yaw_fix=0, + yaw_min=None, + yaw_max=None, + yaw_mean=None, + yaw_stddev=None, + roll_strategy='fix', + roll_fix=0, + roll_min=None, + roll_max=None, + roll_mean=None, + roll_stddev=None): + """Initializes hyper-parameters for point sampling. + + Detailed description of each argument can be found in functions + `get_ray_per_pixel()`, `sample_points_per_ray()`, + `perturb_points_per_ray()`, `sample_camera_extrinsics()`. + """ + super().__init__() + self.num_points = num_points + self.fov = fov + self.image_boundary_value = image_boundary_value + self.cam_look_at_dir = cam_look_at_dir + self.pixel_center = pixel_center + self.y_descending = y_descending + + self.sampling_strategy = sampling_strategy + self.dis_min = dis_min + self.dis_max = dis_max + self.dis_mean = dis_mean + self.dis_stddev = dis_stddev + self.per_ray_ref = per_ray_ref + self.per_point_ref = per_point_ref + self.perturbation_strategy = perturbation_strategy + + self.radius_strategy = radius_strategy + self.radius_fix = radius_fix + self.radius_min = radius_min + self.radius_max = radius_max + self.radius_mean = radius_mean + self.radius_stddev = radius_stddev + self.polar_strategy = polar_strategy + self.polar_fix = polar_fix + self.polar_min = polar_min + self.polar_max = polar_max + self.polar_mean = polar_mean + self.polar_stddev = polar_stddev + self.azimuthal_strategy = azimuthal_strategy + self.azimuthal_fix = azimuthal_fix + self.azimuthal_min = azimuthal_min + self.azimuthal_max = azimuthal_max + self.azimuthal_mean = azimuthal_mean + self.azimuthal_stddev = azimuthal_stddev + self.use_spherical_uniform_position = use_spherical_uniform_position + self.pitch_strategy = pitch_strategy + self.pitch_fix = pitch_fix + self.pitch_min = pitch_min + self.pitch_max = pitch_max + self.pitch_mean = pitch_mean + self.pitch_stddev = pitch_stddev + self.yaw_strategy = yaw_strategy + self.yaw_fix = yaw_fix + self.yaw_min = yaw_min + self.yaw_max = yaw_max + self.yaw_mean = yaw_mean + self.yaw_stddev = yaw_stddev + self.roll_strategy = roll_strategy + self.roll_fix = roll_fix + self.roll_min = roll_min + self.roll_max = roll_max + self.roll_mean = roll_mean + self.roll_stddev = roll_stddev + self.focal = focal + + def forward(self, + batch_size, + image_size, + focal=None, + cam2world_matrix=None, + **kwargs): + """Samples points. + + `K` denotes the number of points on each ray. + + Args: + batch_size: Batch size of images. Denoted as `N`. + image_size: Size of the image. 
One element indicates square image, + while two elements stand for height and width respectively. + Denoted as `H` and `W`. + **kwargs: Additional keyword arguments to override the variables + initialized in `__init__()`. + + Returns: + A dictionary, containing + - `camera_radius`: camera radius w.r.t. the world coordinate + system, with shape [N]. + - `camera_polar`: camera polar w.r.t. the world coordinate + system, with shape [N]. + - `camera_azimuthal`: camera azimuthal w.r.t. the world + coordinate system, with shape [N]. + - `camera_pitch`: camera pitch w.r.t. the camera coordinate + system, with shape [N]. + - `camera_yaw`: camera yaw w.r.t. the camera coordinate system, + with shape [N]. + - `camera_roll`: camera roll w.r.t. the camera coordinate + system, with shape [N]. + - `camera_pos`: camera position, i.e., the (x, y, z) coordinate + in the world coordinate system, with shape [N, 3]. + - `cam2world_matrix`: transformation matrix to transform the + camera coordinate system to the world coordinate system, + with shape [N, 4, 4]. + - `rays_camera`: ray directions in the camera coordinate system, + with shape [N, H, W, 3]. + - `rays_world`: ray directions in the world coordinate system, + with shape [N, H, W, 3]. + - `radii_raw`: raw per-point radial distance w.r.t. the camera + position, with shape [N, H, W, K]. + - `radii`: per-point radial distance after perturbation w.r.t. + the camera position, with shape [N, H, W, K]. + - `points_camera`: per-point coordinate in the camera coordinate + system, with shape [N, H, W, K, 3]. + - `points_world`: per-point coordinate in the world coordinate + system, with shape [N, H, W, K, 3]. + """ + num_points = kwargs.get('num_points', self.num_points) + fov = kwargs.get('fov', self.fov) + image_boundary_value = kwargs.get( + 'image_boundary_value', self.image_boundary_value) + cam_look_at_dir = kwargs.get('cam_look_at_dir', self.cam_look_at_dir) + pixel_center = kwargs.get('pixel_center', self.pixel_center) + y_descending = kwargs.get('y_descending', self.y_descending) + sampling_strategy = kwargs.get( + 'sampling_strategy', self.sampling_strategy) + dis_min = kwargs.get('dis_min', self.dis_min) + dis_max = kwargs.get('dis_max', self.dis_max) + dis_mean = kwargs.get('dis_mean', self.dis_mean) + dis_stddev = kwargs.get('dis_stddev', self.dis_stddev) + per_ray_ref = kwargs.get('per_ray_ref', self.per_ray_ref) + per_point_ref = kwargs.get('per_point_ref', self.per_point_ref) + perturbation_strategy = kwargs.get( + 'perturbation_strategy', self.perturbation_strategy) + radius_strategy = kwargs.get('radius_strategy', self.radius_strategy) + radius_fix = kwargs.get('radius_fix', self.radius_fix) + radius_min = kwargs.get('radius_min', self.radius_min) + radius_max = kwargs.get('radius_max', self.radius_max) + radius_mean = kwargs.get('radius_mean', self.radius_mean) + radius_stddev = kwargs.get('radius_stddev', self.radius_stddev) + polar_strategy = kwargs.get('polar_strategy', self.polar_strategy) + polar_fix = kwargs.get('polar_fix', self.polar_fix) + polar_min = kwargs.get('polar_min', self.polar_min) + polar_max = kwargs.get('polar_max', self.polar_max) + polar_mean = kwargs.get('polar_mean', self.polar_mean) + polar_stddev = kwargs.get('polar_stddev', self.polar_stddev) + azimuthal_strategy = kwargs.get( + 'azimuthal_strategy', self.azimuthal_strategy) + azimuthal_fix = kwargs.get('azimuthal_fix', self.azimuthal_fix) + azimuthal_min = kwargs.get('azimuthal_min', self.azimuthal_min) + azimuthal_max = kwargs.get('azimuthal_max', 
self.azimuthal_max) + azimuthal_mean = kwargs.get('azimuthal_mean', self.azimuthal_mean) + azimuthal_stddev = kwargs.get('azimuthal_stddev', self.azimuthal_stddev) + use_spherical_uniform_position = kwargs.get( + 'use_spherical_uniform_position', + self.use_spherical_uniform_position) + pitch_strategy = kwargs.get('pitch_strategy', self.pitch_strategy) + pitch_fix = kwargs.get('pitch_fix', self.pitch_fix) + pitch_min = kwargs.get('pitch_min', self.pitch_min) + pitch_max = kwargs.get('pitch_max', self.pitch_max) + pitch_mean = kwargs.get('pitch_mean', self.pitch_mean) + pitch_stddev = kwargs.get('pitch_stddev', self.pitch_stddev) + yaw_strategy = kwargs.get('yaw_strategy', self.yaw_strategy) + yaw_fix = kwargs.get('yaw_fix', self.yaw_fix) + yaw_min = kwargs.get('yaw_min', self.yaw_min) + yaw_max = kwargs.get('yaw_max', self.yaw_max) + yaw_mean = kwargs.get('yaw_mean', self.yaw_mean) + yaw_stddev = kwargs.get('yaw_stddev', self.yaw_stddev) + roll_strategy = kwargs.get('roll_strategy', self.roll_strategy) + roll_fix = kwargs.get('roll_fix', self.roll_fix) + roll_min = kwargs.get('roll_min', self.roll_min) + roll_max = kwargs.get('roll_max', self.roll_max) + roll_mean = kwargs.get('roll_mean', self.roll_mean) + roll_stddev = kwargs.get('roll_stddev', self.roll_stddev) + + rays_camera = get_ray_per_pixel(batch_size=batch_size, + image_size=image_size, + fov=fov, + boundary=image_boundary_value, + focal=focal, + cam_look_at_dir=cam_look_at_dir, + pixel_center=pixel_center, + y_descending=y_descending) + + radii_raw = sample_points_per_ray(batch_size=batch_size, + image_size=image_size, + num_points=num_points, + strategy=sampling_strategy, + dis_min=dis_min, + dis_max=dis_max, + dis_mean=dis_mean, + dis_stddev=dis_stddev, + per_ray_ref=per_ray_ref, + per_point_ref=per_point_ref) + radii = perturb_points_per_ray(radii=radii_raw, + strategy=perturbation_strategy) + + camera_info = {} + if cam2world_matrix is not None: + camera_info.update(dict( + cam2world_matrix=cam2world_matrix, + radius=None, + polar=None, + azimuthal=None, + pitch=None, + yaw=None, + roll=None, + camera_pos=None, + )) + else: + camera_info = sample_camera_extrinsics( + batch_size=batch_size, + radius_strategy=radius_strategy, + radius_fix=radius_fix, + radius_min=radius_min, + radius_max=radius_max, + radius_mean=radius_mean, + radius_stddev=radius_stddev, + polar_strategy=polar_strategy, + polar_fix=polar_fix, + polar_min=polar_min, + polar_max=polar_max, + polar_mean=polar_mean, + polar_stddev=polar_stddev, + azimuthal_strategy=azimuthal_strategy, + azimuthal_fix=azimuthal_fix, + azimuthal_min=azimuthal_min, + azimuthal_max=azimuthal_max, + azimuthal_mean=azimuthal_mean, + azimuthal_stddev=azimuthal_stddev, + use_spherical_uniform_position=use_spherical_uniform_position, + pitch_strategy=pitch_strategy, + pitch_fix=pitch_fix, + pitch_min=pitch_min, + pitch_max=pitch_max, + pitch_mean=pitch_mean, + pitch_stddev=pitch_stddev, + yaw_strategy=yaw_strategy, + yaw_fix=yaw_fix, + yaw_min=yaw_min, + yaw_max=yaw_max, + yaw_mean=yaw_mean, + yaw_stddev=yaw_stddev, + roll_strategy=roll_strategy, + roll_fix=roll_fix, + roll_min=roll_min, + roll_max=roll_max, + roll_mean=roll_mean, + roll_stddev=roll_stddev) + + points = get_point_coord( + rays_camera=rays_camera, + radii=radii, + cam2world_matrix=camera_info['cam2world_matrix']) + + return { + 'camera_radius': camera_info['radius'], # [N] + 'camera_polar': camera_info['polar'], # [N] + 'camera_azimuthal': camera_info['azimuthal'], # [N] + 'camera_pitch': camera_info['pitch'], # [N] + 
'camera_yaw': camera_info['yaw'], # [N] + 'camera_roll': camera_info['roll'], # [N] + 'camera_pos':camera_info['camera_pos'], # [N, 3] + 'cam2world_matrix': camera_info['cam2world_matrix'], # [N, 4, 4] + 'rays_camera': rays_camera, # [N, H, W, 3] + 'rays_world': points['rays_world'], # [N, H, W, 3] + 'ray_origins_world': points['ray_origins_world'], # [N, H, W, 3] + 'radii_raw': radii_raw, # [N, H, W, K] + 'radii': radii, # [N, H, W, K] + 'points_camera': points['points_camera'], # [N, H, W, K, 3] + 'points_world': points['points_world'] # [N, H, W, K, 3] + } + + +def get_ray_per_pixel(batch_size, + image_size, + fov, + boundary=1.0, + focal=None, + cam_look_at_dir=-1, + pixel_center=False, + y_descending=True): + """Gets ray direction for each image pixel under camera coordinate system. + + Each ray direction is represent by a vector, [x, y, z], under the following + coordinate system: + + - The origin is set at the camera position. + - The X axis is set as the horizontal direction of the image plane, with + larger value on the right. + - The Y axis is set as the vertical direction of the image plane, with + larger value on the top. + - The Z axis is set as the direction perpendicular to the image plane, + from the image center pointing to the camera. In other words, the z + coordinate of the image plane is negative. + - The above coordinate system is a right-hand one. + + Taking a 5x5 image (with boundary 1.0) as an instance, the per-pixel (x, y) + coordinates should look like: + + (-1.0, 1.0) (-0.5, 1.0) (0.0, 1.0) (0.5, 1.0) (1.0, 1.0) + (-1.0, 0.5) (-0.5, 0.5) (0.0, 0.5) (0.5, 0.5) (1.0, 0.5) + (-1.0, 0.0) (-0.5, 0.0) (0.0, 0.0) (0.5, 0.0) (1.0, 0.0) + (-1.0, -0.5) (-0.5, -0.5) (0.0, -0.5) (0.5, -0.5) (1.0, -0.5) + (-1.0, -1.0) (-0.5, -1.0) (0.0, -1.0) (0.5, -1.0) (1.0, -1.0) + + NOTE: + The X-axis focal and Y-axis focal are assumed to be the same according + to the pinhole camera model. + + Args: + batch_size: Batch size of images, each of which has the same ray + directions. Denoted as `N`. + image_size: Size of the image. One element indicates square image, while + two elements stand for height and width respectively. Denoted as `H` + and `W`. + fov: Field of view (along X axis) of the camera, in unit of degree. + boundary: The maximum value of the X coordinate. Defaults to `1.0`. + focal (optional): Focal Length of camera. If given, it will cover the + focal calculated by `fov`. Note that the focal is a normalized one + which is divided by size of the image. + cam_look_at_dir: Direction of camera looks at. Defaults to `-1`, which + means camera looks at `-z` direction. + pixel_center: Whether rays originate from the pixel center or not. For + example, assume a pixel is at (H, W). If `pixel_center` is set + `True`, then the ray originate from (H+0.5, W+0.5), otherwise it + originate from (H, W). + y_descending: Whether the Y axis is in descending order from top to + bottom. If set `True`, the coordinates are the same as the above + example. If set `False`, the coordinate system is consistent with + 2D image plane coordinate system, where Y axis is in ascending + order. Defaults to `True`. + Returns: + A tensor, with shape [N, H, W, 3], representing the per-pixel ray + direction. Each direction is normalized to a unit vector. + """ + # Check inputs. 
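+ # Illustrative example (assumed values, not taken from a config): with
+ # fov=30 and boundary=1.0, the focal computed below and the ray through the
+ # top-right corner of a square image work out to
+ #     focal = 1.0 / np.tan(np.radians(30) / 2)                # ~3.73
+ #     corner = F.normalize(torch.tensor([1.0, 1.0, -focal]), dim=-1)
+ #     # -> roughly [0.25, 0.25, -0.94]
+ # when cam_look_at_dir=-1, i.e. the camera looks down the -z axis.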
+ assert isinstance(batch_size, int) and batch_size > 0 + N = batch_size + assert isinstance(image_size, (int, list, tuple)) + if isinstance(image_size, int): + H = image_size + W = image_size + else: + H, W = image_size + assert isinstance(H, int) and H > 0 + assert isinstance(W, int) and W > 0 + assert 0 < fov < 180 + assert boundary > 0 + + # Get running device. + device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu' + + # Get (x, y) grid by boundary. + max_x = boundary + max_y = boundary / W * H + if pixel_center: + y, x = torch.meshgrid( + torch.linspace(max_y - 0.5 / H, -max_y + 0.5 / H, H, + device=device), + torch.linspace(-max_x + 0.5 / W, max_x - 0.5 / W, W, + device=device)) + else: + y, x = torch.meshgrid(torch.linspace(max_y, -max_y, H, device=device), + torch.linspace(-max_x, max_x, W, device=device)) + # Get z coordinate of the image plane by focal (i.e., FOV). + if not y_descending: + y = -y + if focal is None: + focal = boundary / np.tan((2 * np.pi * fov / 360) / 2) + z = np.sign(cam_look_at_dir) * focal * torch.ones_like(x) # [H, W] + # Normalize directions to unit vectors. + rays = F.normalize(torch.stack([x, y, z], dim=-1), dim=-1) # [H, W, 3] + + return rays.unsqueeze(0).repeat(N, 1, 1, 1) # [N, H, W, 3] + + +def sample_points_per_ray(batch_size, + image_size, + num_points, + strategy='uniform', + dis_min=None, + dis_max=None, + dis_mean=None, + dis_stddev=None, + per_ray_ref=None, + per_point_ref=None): + """Samples per-point radial distance on each ray. + + This function is independent of ray directions, hence, each point is + represent by a number, indicating its radial distance to the origin (i.e., + the camera). + + The following sampling strategies are supported: + + - `uniform`: + For each ray, the points uniformly locate in range `[dis_min, dis_max]`. + + - `normal`: + For each ray, the points are sampled subject to + `Gaussian(dis_mean, dis_stddev^2)`. + + - `ray_dependent`: + Each ray follows a separate strategy, controlled by `per_ray_ref`. + + - `point_dependent`: + Each point follows a separate strategy, controlled by `per_point_ref`. + + Args: + batch_size: Batch size of images, for which points are sampled + independently. Denoted as `N`. + image_size: Size of the image. One element indicates square image, while + two elements stand for height and width respectively. Denoted as `H` + and `W`. + num_points: Number of points sampled on each ray. Denoted as `K`. + strategy: Strategy for point sampling. Defaults to `uniform`. + dis_min: Minimum radial distance (with camera as the origin) for each + point. Defaults to `None`. + dis_max: Maximum radial distance (with camera as the origin) for each + point. Defaults to `None`. + dis_mean: Mean radial distance (with camera as the origin) for each + point. Defaults to `None`. + dis_stddev: Standard deviation of the radial distance (with camera as + the origin) for each point. Defaults to `None`. + per_ray_ref: Reference for each ray, which will guide the sampling + process. Shape [N, H, W, c] is expected, where `c` is the dimension + of a single reference. Defaults to `None`. + per_point_ref: Reference for each point, which will guide the sampling + process. Shape [N, H, W, K, c] is expected, where `c` is the + dimension of a single reference. Defaults to `None`. + + Returns: + A tensor, with shape [N, H, W, K], representing the per-point radial + distance on each ray. All numbers should be positive, and the + distances on each ray should follow a non-descending order. 
+ + Raises: + ValueError: If the sampling strategy is not supported. + NotImplementedError: If the sampling strategy is not implemented. + """ + # Check inputs. + assert isinstance(batch_size, int) and batch_size > 0 + N = batch_size + assert isinstance(image_size, (int, list, tuple)) + if isinstance(image_size, int): + H = image_size + W = image_size + else: + H, W = image_size + assert isinstance(H, int) and H > 0 + assert isinstance(W, int) and W > 0 + assert isinstance(num_points, int) and num_points > 0 + K = num_points + strategy = strategy.lower() + if strategy not in _POINT_SAMPLING_STRATEGIES: + raise ValueError(f'Invalid point sampling strategy: `{strategy}`!\n' + f'Strategies allowed: {_POINT_SAMPLING_STRATEGIES}.') + + # Get running device. + device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu' + + # Sample points according to strategy. + if strategy == 'uniform': + assert dis_max >= dis_min > 0 + radii = torch.linspace(dis_min, dis_max, K, device=device) # [K] + return radii.reshape(1, 1, 1, K).repeat(N, H, W, 1) # [N, H, W, K] + + if strategy == 'normal': + # TODO: Should we support the normal sampling strategy? + assert dis_mean > 0 and dis_stddev >= 0 + + if strategy == 'ray_dependent': + # TODO: Strategy dependent on depth? + assert per_ray_ref.ndim == 4 + assert per_ray_ref.shape[:3] == (N, H, W) + + if strategy == 'point_dependent': + # TODO: This is hierarchical sampling? + assert per_point_ref.ndim == 5 + assert per_point_ref.shape[:4] == (N, H, W, K) + + raise NotImplementedError(f'Not implemented point sampling strategy: ' + f'`{strategy}`!') + + +def perturb_points_per_ray(radii, strategy='middle_uniform'): + # Stratified sampling approach described in original NeRF paper. + """Perturbs point radii within their local range on each ray. + + `N`, `H`, `W`, `K` denote batch size, image height, image width, number of + points per ray, respectively. + + The following perturbing strategies are supported: + + - `no`: + Disable point perturbation. + + - `middle_uniform`: + For each point, it is perturbed between two midpoints. One locates + within the point itself and its previous one on the same ray, while the + other locates within the point itself and its next one on the same ray. + + - `uniform`: + For each point, it is perturbed between itself and its next one. + For example, there are `n+1` points on the ray: [x_0, x_1, ..., x_n]. + Then the perturbed points are [x_0', x_1', ..., x_n'] with distribution + xi' ~ U(x_i, x_i+1), where x_n+1 = x_n + (x_n - x_n-1). + + - `self_uniform`: + For each point, it is perturbed around itself.For example, there are + `n+1` points on the ray: [x_0, x_1, ..., x_n]. Then the perturbed points + are [x_0', x_1', ..., x_n'] with distribution + xi' ~ U(x_i - 0.5, x_i+1 - 0.5). + + Args: + radii: A collection of point radii, with shape [N, H, W, K]. + strategy: Strategy to perturb each point. Defaults to `middle_uniform`. + + Returns: + A tensor, with shape [N, H, W, K], representing the per-point radial + distance on each ray. All numbers should be positive, and the + distances on each ray should follow a non-descending order. + + Raises: + ValueError: If the input point radii are with invalid shape, or the + perturbing strategy is not supported. + NotImplementedError: If the perturbing strategy is not implemented. + """ + # Check inputs. 
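+ # Illustrative example (assumed values): under the default 'middle_uniform'
+ # strategy, each sample is jittered inside the cell bounded by the midpoints
+ # of its neighbours, e.g. for a single ray:
+ #     radii = torch.tensor([1.0, 2.0, 4.0]).reshape(1, 1, 1, 3)
+ #     mid = (radii[..., 1:] + radii[..., :-1]) / 2   # [1.5, 3.0]
+ #     left = torch.cat([radii[..., :1], mid], -1)    # [1.0, 1.5, 3.0]
+ #     right = torch.cat([mid, radii[..., -1:]], -1)  # [1.5, 3.0, 4.0]
+ # so the perturbed samples are drawn from U(1.0, 1.5), U(1.5, 3.0) and
+ # U(3.0, 4.0), which keeps the radii on each ray in non-descending order.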
+ if radii.ndim != 4: + raise ValueError(f'The input point radii should be with shape ' + f'[batch_size, height, width, num_points], ' + f'but `{radii.shape}` is received!') + strategy = strategy.lower() + if strategy not in _POINT_PERTURBING_STRATEGIES: + raise ValueError(f'Invalid point perturbing strategy: `{strategy}`!\n' + f'Strategies allowed: {_POINT_PERTURBING_STRATEGIES}.') + + if strategy == 'no': + return radii + + if strategy == 'middle_uniform': + # Get midpoints. + midpoint = (radii[..., 1:] + radii[..., :-1]) / 2 # [N, H, W, K-1] + # Get intervals. + left = torch.cat([radii[..., :1], midpoint], dim=-1) # [N, H, W, K] + right = torch.cat([midpoint, radii[..., -1:]], dim=-1) # [N, H, W, K] + # Uniformly sample within each interval. + t = torch.rand_like(radii) # [N, H, W, K] + return left + (right - left) * t # [N, H, W, K] + elif strategy == 'uniform': + delta = radii[..., 1:2] - radii[..., 0:1] # [N, H, W, 1] + t = torch.rand_like(radii) # [N, H, W, K] + return radii + t * delta # [N, H, W, K] + elif strategy == 'self_uniform': + delta = radii[..., 1:2] - radii[..., 0:1] # [N, H, W, 1] + t = torch.rand_like(radii) - 0.5 # [N, H, W, K] + return radii + t * delta # [N, H, W, K] + + raise NotImplementedError(f'Not implemented point perturbing strategy: ' + f'`{strategy}`!') + + +def sample_tensor(size, + strategy='uniform', + entry_fix=None, + entry_min=None, + entry_max=None, + entry_mean=None, + entry_stddev=None): + """Samples a tensor according to specified strategy. + + The following sampling strategies are supported: + + - `fix`: + Each entry is fixed as `entry_fix`. + + - `uniform`: + Each entry is uniformly sampled from range `[entry_min, entry_max]`. + + - `normal`: + Each entry is sampled subject to `Gaussian(entry_mean, entry_stddev^2)`. + + - `hybrid`: + Each entry is 50% sampled with `uniform` and 50% sampled with `normal`. + + - `truncated_normal`: + Each entry is sampled subject to a truncated normal distribution, with + `entry_min` and `entry_max` as the cut-off values. + + + Args: + size: Size of the sampled tensor. This field is expected to be an + integer, a list, or a tuple. + strategy: Strategy to sample points. Defaults to `uniform`. + entry_min: Minimum value of each entry. Defaults to `None`. + entry_max: Maximum value of each entry. Defaults to `None`. + entry_mean: Mean value of each entry. Defaults to `None`. + entry_stddev: Standard deviation of each entry. Defaults to `None`. + + Returns: + A tensor, with expected size. + + Raises: + ValueError: If the sampling strategy is not supported. + NotImplementedError: If the sampling strategy is not implemented. + """ + # Check inputs. + if isinstance(size, int): + size = (size,) + elif isinstance(size, list): + size = tuple(size) + assert isinstance(size, tuple) + strategy = strategy.lower() + if strategy not in _TENSOR_SAMPLING_STRATEGIES: + raise ValueError(f'Invalid tensor sampling strategy: `{strategy}`!\n' + f'Strategies allowed: {_TENSOR_SAMPLING_STRATEGIES}.') + + # Get running device. 
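+ # Illustrative usage (assumed values): for camera pose sampling one would
+ # call, e.g., sample_tensor(4, strategy='uniform', entry_min=np.pi / 2 - 0.3,
+ # entry_max=np.pi / 2 + 0.3) to get a [4] tensor of angles (in radians)
+ # around pi / 2, or strategy='fix' with entry_fix=np.pi / 2 for a constant
+ # tensor. Note that the 'hybrid' strategy below chooses between the uniform
+ # and normal branches once per call, not per entry.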
+ device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu' + + if strategy == 'fix': + assert entry_fix is not None + return torch.ones(size, device=device) * entry_fix + + if strategy == 'uniform': + assert entry_max >= entry_min + t = torch.rand(size, device=device) + return entry_min + (entry_max - entry_min) * t + + if strategy == 'normal': + assert entry_mean is not None and entry_stddev >= 0 + return torch.randn(size, device=device) * entry_stddev + entry_mean + + if strategy == 'hybrid': + assert entry_max >= entry_min + assert entry_mean is not None and entry_stddev >= 0 + if np.random.random() < 0.5: + t = torch.rand(size, device=device) + return entry_min + (entry_max - entry_min) * t + return torch.randn(size, device=device) * entry_stddev + entry_mean + + if strategy == 'truncated_normal': + # TODO: Truncated normal distribution differs from cut-off. + assert entry_max >= entry_min + assert entry_mean is not None and entry_stddev >= 0 + tensor = torch.randn(size, device=device) * entry_stddev + entry_mean + tensor = torch.clamp(tensor, entry_min, entry_max) + return tensor + + raise NotImplementedError(f'Not implemented tensor sampling strategy: ' + f'`{strategy}`!') + + +def sample_camera_extrinsics(batch_size, + radius_strategy='fix', + radius_fix=None, + radius_min=None, + radius_max=None, + radius_mean=None, + radius_stddev=None, + polar_strategy='uniform', + polar_fix=None, + polar_min=None, + polar_max=None, + polar_mean=None, + polar_stddev=None, + azimuthal_strategy='uniform', + azimuthal_fix=None, + azimuthal_min=None, + azimuthal_max=None, + azimuthal_mean=None, + azimuthal_stddev=None, + use_spherical_uniform_position=False, + pitch_strategy='fix', + pitch_fix=0, + pitch_min=None, + pitch_max=None, + pitch_mean=None, + pitch_stddev=None, + yaw_strategy='fix', + yaw_fix=0, + yaw_min=None, + yaw_max=None, + yaw_mean=None, + yaw_stddev=None, + roll_strategy='fix', + roll_fix=0, + roll_min=None, + roll_max=None, + roll_mean=None, + roll_stddev=None): + """Samples camera extrinsics. + + This function supports sampling camera extrinsics from 6 dimensions (here, + all angles are in unit of radian): + + - Camera position: + - radius: Distance from the camera position to the origin of the world + coordinate system. + - polar: The polar angle with respect to the origin of the world + coordinate system. + - azimuthal: The azimuthal angle with respect to the origin of the world + coordinate system. + - Camera orientation: + - pitch: Pitch angle (X axis) regarding the camera coordinate system. + - yaw: Yaw angle (Y axis) regarding the camera coordinate system. + - roll: Roll angle (Z axis) regarding the camera coordinate system. + + and then convert the camera extrinsics to camera position and coordinate + transformation matrix. + + More details about sampling as well as arguments can be found in function + `sample_tensor()`. + + NOTE: + Without camera orientation (i.e., `pitch = 0, yaw = 0, roll = 0`), this + function assumes the camera pointing to the origin of the world + coordinate system. Furthermore, camera orientation controls the rotation + within the camera coordinate system, which is independent from the + transformation across coordinate systems. As a result, the camera does + not necessarily point to the origin of the world coordinate system + anymore. + + Args: + batch_size: Batch size of the sampled camera. Denoted as `N`. + use_spherical_uniform_position: Whether to sample the camera position + subject to a spherical uniform distribution. 
Defaults to False. + + Returns: + A dictionary, containing + - `camera_radius`: camera radius w.r.t. the world coordinate system, + with shape [N]. + - `camera_polar`: camera polar w.r.t. the world coordinate system, + with shape [N]. + - `camera_azimuthal`: camera azimuthal w.r.t. the world coordinate + system, with shape [N]. + - `camera_pitch`: camera pitch w.r.t. the camera coordinate system, + with shape [N]. + - `camera_yaw`: camera yaw w.r.t. the camera coordinate system, + with shape [N]. + - `camera_roll`: camera roll w.r.t. the camera coordinate system, + with shape [N]. + - `camera_pos`: camera position, i.e., the (x, y, z) coordinate + in the world coordinate system, with shape [N, 3]. + - `cam2world_matrix`: transformation matrix to transform the camera + coordinate system to the world coordinate system, with shape + [N, 4, 4]. + """ + # Sample camera position. + radius = sample_tensor(size=batch_size, + strategy=radius_strategy, + entry_fix=radius_fix, + entry_min=radius_min, + entry_max=radius_max, + entry_mean=radius_mean, + entry_stddev=radius_stddev) + if use_spherical_uniform_position: + # TODO: Check the local spherical uniform distribution? + polar = sample_tensor(size=batch_size, + strategy='uniform', + entry_fix=polar_fix, + entry_min=polar_min, + entry_max=polar_max, + entry_mean=polar_mean, + entry_stddev=polar_stddev) + azimuthal_cos_val = sample_tensor(size=batch_size, + strategy='uniform', + entry_min=azimuthal_min / np.pi, + entry_max=azimuthal_max / np.pi) + azimuthal = torch.arccos(1 - 2 * azimuthal_cos_val) + else: + polar = sample_tensor(size=batch_size, + strategy=polar_strategy, + entry_fix=polar_fix, + entry_min=polar_min, + entry_max=polar_max, + entry_mean=polar_mean, + entry_stddev=polar_stddev) + azimuthal = sample_tensor(size=batch_size, + strategy=azimuthal_strategy, + entry_fix=azimuthal_fix, + entry_min=azimuthal_min, + entry_max=azimuthal_max, + entry_mean=azimuthal_mean, + entry_stddev=azimuthal_stddev) + + # Sample camera orientation. + pitch = sample_tensor(size=batch_size, + strategy=pitch_strategy, + entry_fix=pitch_fix, + entry_min=pitch_min, + entry_max=pitch_max, + entry_mean=pitch_mean, + entry_stddev=pitch_stddev) + yaw = sample_tensor(size=batch_size, + strategy=yaw_strategy, + entry_fix=yaw_fix, + entry_min=yaw_min, + entry_max=yaw_max, + entry_mean=yaw_mean, + entry_stddev=yaw_stddev) + roll = sample_tensor(size=batch_size, + strategy=roll_strategy, + entry_fix=roll_fix, + entry_min=roll_min, + entry_max=roll_max, + entry_mean=roll_mean, + entry_stddev=roll_stddev) + + # Get running device. + device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu' + + # Get camera position. + N = batch_size + camera_pos = torch.zeros((N, 3), device=device) + camera_pos[:, 0] = radius * torch.sin(polar) * torch.cos(azimuthal) + camera_pos[:, 1] = radius * torch.cos(polar) + camera_pos[:, 2] = radius * torch.sin(polar) * torch.sin(azimuthal) + + # Get transformation matrix with the following steps. + # 1. Use pitch, yaw, and roll to get the rotation matrix within the camera + # coordinate system. + # 2. Get the forward axis, which points from the camper position to the + # origin of the world coordinate system. + # 3. Get a "pseudo" up axis, which is [0, 1, 0]. + # 4. Get the left axis by crossing the "pseudo" up axis with the forward + # axis. + # 5. Get the "actual" up axis by crossing the forward axis with the left + # axis. + # 6. 
Get the camera-to-world rotation matrix with the aforementioned + # forward axis, left axis, and "actual" up axis. + # 7. Get the camera-to-world transformation matrix. + pitch_matrix = torch.eye(4, device=device).unsqueeze(0).repeat(N, 1, 1) + pitch_matrix[:, 1, 1] = torch.cos(pitch) + pitch_matrix[:, 2, 2] = torch.cos(pitch) + pitch_matrix[:, 1, 2] = -torch.sin(pitch) + pitch_matrix[:, 2, 1] = torch.sin(pitch) # [N, 4, 4] + yaw_matrix = torch.eye(4, device=device).unsqueeze(0).repeat(N, 1, 1) + yaw_matrix[:, 0, 0] = torch.cos(yaw) + yaw_matrix[:, 2, 2] = torch.cos(yaw) + yaw_matrix[:, 2, 0] = -torch.sin(yaw) + yaw_matrix[:, 0, 2] = torch.sin(yaw) # [N, 4, 4] + roll_matrix = torch.eye(4, device=device).unsqueeze(0).repeat(N, 1, 1) + roll_matrix[:, 0, 0] = torch.cos(roll) + roll_matrix[:, 1, 1] = torch.cos(roll) + roll_matrix[:, 0, 1] = -torch.sin(roll) + roll_matrix[:, 1, 0] = torch.sin(roll) # [N, 4, 4] + + forward_axis = F.normalize(camera_pos * -1, dim=-1) # [N, 3] + pseudo_up_axis = torch.as_tensor([0.0, 1.0, 0.0], device=device) # [3] + pseudo_up_axis = pseudo_up_axis.reshape(1, 3).repeat(N, 1) # [N, 3] + left_axis = torch.cross(pseudo_up_axis, forward_axis, dim=-1) # [N, 3] + left_axis = F.normalize(left_axis, dim=-1) # [N, 3] + up_axis = torch.cross(forward_axis, left_axis, dim=-1) # [N, 3] + up_axis = F.normalize(up_axis, dim=-1) # [N, 3] + + rotation_matrix = torch.eye(4, device=device).unsqueeze(0).repeat(N, 1, 1) + rotation_matrix[:, :3, 0] = -left_axis + rotation_matrix[:, :3, 1] = up_axis + rotation_matrix[:, :3, 2] = -forward_axis # [N, 4, 4] + + translation_matrix = torch.eye(4, device=device) + translation_matrix = translation_matrix.unsqueeze(0).repeat(N, 1, 1) + translation_matrix[:, :3, 3] = camera_pos # [N, 4, 4] + + cam2world_matrix = (translation_matrix @ rotation_matrix @ + roll_matrix @ yaw_matrix @ pitch_matrix) # [N, 4, 4] + + return { + 'radius': radius, + 'polar': polar, + 'azimuthal': azimuthal, + 'pitch': pitch, + 'yaw': yaw, + 'roll': roll, + 'camera_pos':camera_pos, + 'cam2world_matrix': cam2world_matrix + } + + +def get_point_coord(rays_camera, radii, cam2world_matrix): + """Gets pre-point coordinate in the world coordinate system. + + `N`, `H`, `W`, `K` denote batch size, image height, image width, number of + points per ray, respectively. + + Args: + rays_camera: Per-pixel ray direction, with shape [N, H, W, 3], in the + camera coordinate system. + radii: Per-point radial distance on each ray, with shape [N, H, W, K]. + cam2world_matrix: Transformation matrix that transforms the camera + coordinate system to the world coordinate system, with shape + [N, 4, 4]. + + Returns: + A dictionary, containing + - `rays_world`: ray directions in the world coordinate system, + with shape [N, H, W, 3]. + - `points_camera`: per-point coordinate in the camera coordinate + system, with shape [N, H, W, K, 3]. + - `points_world`: per-point coordinate in the world coordinate + system, with shape [N, H, W, K, 3]. + + Raises: + ValueError: If any input has invalid shape. + """ + # Check inputs. 
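+ # Illustrative note: each camera-space point is direction * radius, and the
+ # world-space point is the homogeneous transform p_world = R @ p_cam + t
+ # with R = cam2world_matrix[:, :3, :3] and t = cam2world_matrix[:, :3, 3].
+ # For example (assumed values), with R = I and t = [0, 0, 2], a ray
+ # direction [0, 0, -1] sampled at radius 1.5 maps to the world point
+ # [0, 0, 0.5].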
+ if rays_camera.ndim != 4 or rays_camera.shape[3] != 3: + raise ValueError(f'The input rays should be with shape ' + f'[batch_size, height, width, 3], ' + f'but `{rays_camera.shape}` is received!') + N, H, W, _ = rays_camera.shape + if radii.ndim != 4 or radii.shape[:3] != (N, H, W): + raise ValueError(f'The input radii should be with shape ' + f'[batch_size, height, width, num_points], where ' + f'batch_size, height, width align with those of rays, ' + f'but `{radii.shape}` is received!') + K = radii.shape[3] + if cam2world_matrix.shape != (N, 4, 4): + raise ValueError(f'The input cam2world_matrix should be with shape ' + f'[batch_size, 4, 4], where batch_size align with ' + f'that of rays and radii ' + f'but `{cam2world_matrix.shape}` is received!') + + # Get running device. + device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu' + + # Transform rays. + rays_world = (cam2world_matrix[:, :3, :3] @ + rays_camera.reshape(N, -1, 3).permute(0, 2, 1)) + rays_world = rays_world.permute(0, 2, 1).reshape(N, H, W, 3) + + # Transform ray origins. + ray_origins_homo = torch.zeros((N, H * W, 4), device=device) + ray_origins_homo[..., 3] = 1 + ray_origins_world = torch.bmm(cam2world_matrix, + ray_origins_homo.permute(0, 2, 1)).permute( + 0, 2, 1)[..., :3] + ray_origins_world = ray_origins_world.reshape(N, H, W, 3) + + # Transform points. + points_camera = (rays_camera.unsqueeze(3) * + radii.unsqueeze(4)) # [N, H, W, K, 3] + points_camera_homo = torch.cat( + [points_camera, torch.ones((N, H, W, K, 1), device=device)], + dim=-1) # [N, H, W, K, 4] + points_world_homo = (cam2world_matrix @ + points_camera_homo.reshape(N, -1, 4).permute(0, 2, 1)) + points_world = points_world_homo.permute(0, 2, 1)[:, :, :3] + points_world = points_world.reshape(N, H, W, K, 3) + + return { + 'rays_world': rays_world, + 'ray_origins_world': ray_origins_world, + 'points_camera': points_camera, + 'points_world': points_world, + } diff --git a/models/rendering/renderer.py b/models/rendering/renderer.py new file mode 100644 index 0000000000000000000000000000000000000000..90737e009d243ed9bfd15221f0916e77f027f1fe --- /dev/null +++ b/models/rendering/renderer.py @@ -0,0 +1,354 @@ +# python3.8 +"""Contains image renderer class.""" + +import torch +import torch.nn as nn +from .point_sampler import PointSampler +from .integrator import Integrator + +__all__ = ['Renderer'] + + +class Renderer(nn.Module): + """Defines the class to render images. + + The renderer is a module that takes in latent codes and points, decides + where to sample along each ray, and computes pixel colors/features using the + volume rendering equation. + + Basically, the volume rendering pipiline consists of the following steps: + + 1. Sample points in 3D Space. + 2. (Optional) Get the reference representation by injecting latent codes + into the reference representation generator. Generally, the reference + representation can be a feature volume (VolumenGAN), a triplane (EG3D) or + others. + 3. Get the corresponding feature of each sampled point by the given feature + extractor. Typically, the overall formulation is: + feat = F(wp, points, options, ref_representation, post_module) + where + `feat`: The output points' features. + `F`: The feature extractor. + `wp`: The latent codes in W-sapce. + `points`: Sampled points. + `options`: Some options for rendering. + `ref_representation`: The reference representation obtained in step 2. + `post_module`: The post module, is usually a MLP. + 4. 
Get the sigma's and rgb's value (or feature) by feeding `feat` in + step 3 into one or two fully-connected layer head. + 5. Coarse pass to do the integration. + 6. Hierarchically sample points on top of step 5. + 6. Fine pass to do the integration. + + Note: In the following scripts, meanings of variables `N, H, W, R, K, C` are: + + - `N`: Batch size. + - `H`: Height of image. + - `W`: Width of image. + - `R`: Number of rays, usually equals `H * W`. + - `K`: Number of points on each ray. + - `C`: Number of channels w.r.t. features or images, e.t.c. + """ + + def __init__(self): + super().__init__() + self.point_sampler = PointSampler() + self.integrator = Integrator() + + def forward( + self, + wp, + feature_extractor, + rendering_options, + cam2world_matrix=None, + position_encoder=None, + ref_representation=None, + post_module=None, + post_module_kwargs={}, + fc_head=None, + fc_head_kwargs={}, + ): + #TODO: Organize `rendering_options` like the following format: + ''' + rendering_options = dict( + point_sampler_options=dict( + focal=None, + ... + ) + integrator_options=dict(...), + ...., + xxx=xxx, # some public parameters. + ... + ) + ''' + batch_size= wp.shape[0] + + # Sample points. + sampling_point_res = self.point_sampler( + batch_size=batch_size, + focal=rendering_options.get('focal', None), + image_boundary_value=rendering_options.get('image_boundary_value', + 0.5), + cam_look_at_dir=rendering_options.get('cam_look_at_dir', +1), + pixel_center=rendering_options.get('pixel_center', True), + y_descending=rendering_options.get('y_descending', False), + image_size=rendering_options.get('resolution', 64), + dis_min=rendering_options.get('ray_start', None), + dis_max=rendering_options.get('ray_end', None), + cam2world_matrix=cam2world_matrix, + num_points=rendering_options.get('depth_resolution', 48), + perturbation_strategy=rendering_options.get( + 'perturbation_strategy', 'uniform'), + radius_strategy=rendering_options.get('radius_strategy', None), + radius_fix=rendering_options.get('radius_fix', None), + polar_strategy=rendering_options.get('polar_strategy', None), + polar_fix=rendering_options.get('polar_fix', None), + polar_mean=rendering_options.get('polar_mean', None), + polar_stddev=rendering_options.get('polar_stddev', None), + azimuthal_strategy=rendering_options.get('azimuthal_strategy', + None), + azimuthal_fix=rendering_options.get('azimuthal_fix', None), + azimuthal_mean=rendering_options.get('azimuthal_mean', None), + azimuthal_stddev=rendering_options.get('azimuthal_stddev', None), + fov=rendering_options.get('fov', 30), + ) + points = sampling_point_res['points_world'] # [N, H, W, K, 3] + ray_dirs = sampling_point_res['rays_world'] # [N, H, W, 3] + ray_origins = sampling_point_res['ray_origins_world'] # [N, H, W, 3] + z_coarse = sampling_point_res['radii'] # [N, H, W, K] + + # NOTE: `pitch` is used to stand for `polar` in other code. + camera_polar = sampling_point_res['camera_polar'] # [N] + # NOTE: `yaw` is used to stand for `azimuthal` in other code. + camera_azimuthal = sampling_point_res['camera_azimuthal'] # [N] + if camera_polar is not None: + camera_polar = camera_polar.unsqueeze(-1) + if camera_azimuthal is not None: + camera_azimuthal = camera_azimuthal.unsqueeze(-1) + + # Reshape. 
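+ # Illustrative note (assumed values): the sampler returns image-shaped
+ # tensors that are flattened to per-ray form here. With a 64x64 neural
+ # rendering resolution and 48 samples per ray, `points` goes from
+ # [N, 64, 64, 48, 3] to [N, 4096, 48, 3], where R = H * W = 4096.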
+ N, H, W, K, _ = points.shape + assert N == batch_size + R = H * W # number of rays + points = points.reshape(N, R, K, -1) + ray_dirs = ray_dirs.reshape(N, R, -1) + ray_origins = ray_origins.reshape(N, R, -1) + z_coarse = z_coarse.reshape(N, R, K, -1) + + out = self.get_sigma_rgb(wp, + points, + feature_extractor, + rendering_options=rendering_options, + position_encoder=position_encoder, + ref_representation=ref_representation, + post_module=post_module, + post_module_kwargs=post_module_kwargs, + fc_head=fc_head, + fc_head_kwargs=dict(**fc_head_kwargs, + wp=wp), + ray_dirs=ray_dirs, + cam_matrix=cam2world_matrix) + + sigmas_coarse = out['sigma'] # [N, H * W * K, 1] + rgbs_coarse = out['rgb'] # [N, H * W * K, C] + sigmas_coarse = sigmas_coarse.reshape(N, R, K, + sigmas_coarse.shape[-1]) + rgbs_coarse = rgbs_coarse.reshape(N, R, K, rgbs_coarse.shape[-1]) + + # Do the integration. + N_importance = rendering_options.get('depth_resolution_importance', 0) + if N_importance > 0: + # Do the integration in coarse pass. + rendering_result = self.integrator(rgbs_coarse, sigmas_coarse, + z_coarse, rendering_options) + weights = rendering_result['weights'] + + # Importrance sampling. + z_fine = self.sample_importance( + z_coarse, + weights, + N_importance, + smooth_weights=rendering_options.get('smooth_weights', True)) + points = ray_origins.unsqueeze(-2) + z_fine * ray_dirs.unsqueeze(-2) + + # Get sigma's and rgb's value (or feature). + out = self.get_sigma_rgb(wp, + points, + feature_extractor, + rendering_options=rendering_options, + position_encoder=position_encoder, + ref_representation=ref_representation, + post_module=post_module, + post_module_kwargs=post_module_kwargs, + fc_head=fc_head, + fc_head_kwargs=dict(**fc_head_kwargs, + wp=wp), + ray_dirs=ray_dirs, + cam_matrix=cam2world_matrix) + + sigmas_fine = out['sigma'] + rgbs_fine = out['rgb'] + sigmas_fine = sigmas_fine.reshape(N, R, N_importance, + sigmas_fine.shape[-1]) + rgbs_fine = rgbs_fine.reshape(N, R, N_importance, + rgbs_fine.shape[-1]) + + # Gather coarse and fine results. + all_zs, all_rgbs, all_sigmas = self.unify_samples( + z_coarse, rgbs_coarse, sigmas_coarse, + z_fine, rgbs_fine, sigmas_fine) + + # Do the integration in fine pass. + final_rendering_result = self.integrator( + all_rgbs, all_sigmas, all_zs, rendering_options) + + else: + final_rendering_result = self.integrator( + rgbs_coarse, sigmas_coarse, z_coarse, rendering_options) + + return { + **final_rendering_result, + **{ + 'camera_azimuthal': camera_azimuthal, + 'camera_polar': camera_polar + }, + **{ + 'points': points, + 'sigmas': sigmas_fine, + } + } + + def get_sigma_rgb(self, + wp, + points, + feature_extractor, + rendering_options, + position_encoder=None, + ref_representation=None, + post_module=None, + post_module_kwargs={}, + fc_head=None, + fc_head_kwargs={}, + ray_dirs=None, + cam_matrix=None): + # Get point feature in coarse pass. + point_features = feature_extractor(wp, points, rendering_options, + position_encoder, + ref_representation, post_module, + post_module_kwargs, ray_dirs, cam_matrix) + + # Get sigma's and rgb's value (or feature). 
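+ # Illustrative note: the per-ray viewing direction is broadcast to every
+ # sample on that ray, i.e. ray_dirs is expanded from [N, R, 3] to
+ # [N, R, K, 3] and flattened to [N, R*K, 3], so that fc_head can predict the
+ # per-point density ('sigma', [N, R*K, 1]) and color/feature ('rgb',
+ # [N, R*K, C]) in a single batched call.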
+ if ray_dirs.ndim != points.ndim: + ray_dirs = ray_dirs.unsqueeze(-2).expand_as(points) + ray_dirs = ray_dirs.reshape(ray_dirs.shape[0], -1, ray_dirs.shape[-1]) + # with shape [N, R * K, 3] + out = fc_head(point_features, dirs=ray_dirs, **fc_head_kwargs) + + if rendering_options.get('noise_std', 0) > 0: + out['sigma'] = out['sigma'] + torch.randn_like( + out['sigma']) * rendering_options['noise_std'] + + return out + + def unify_samples(self, depths1, rgbs1, sigmas1, depths2, rgbs2, sigmas2): + all_depths = torch.cat([depths1, depths2], dim=-2) + all_colors = torch.cat([rgbs1, rgbs2], dim=-2) + all_densities = torch.cat([sigmas1, sigmas2], dim=-2) + + _, indices = torch.sort(all_depths, dim=-2) + all_depths = torch.gather(all_depths, -2, indices) + all_colors = torch.gather( + all_colors, -2, indices.expand(-1, -1, -1, all_colors.shape[-1])) + all_densities = torch.gather(all_densities, -2, + indices.expand(-1, -1, -1, 1)) + + return all_depths, all_colors, all_densities + + def sample_importance(self, + z_vals, + weights, + N_importance, + smooth_weights=False): + """ Implements NeRF importance sampling. + + Returns: + importance_z_vals: Depths of importance sampled points along rays. + """ + with torch.no_grad(): + batch_size, num_rays, samples_per_ray, _ = z_vals.shape + z_vals = z_vals.reshape(batch_size * num_rays, samples_per_ray) + weights = weights.reshape(batch_size * num_rays, -1) + 1e-5 + + # smooth weights + if smooth_weights: + weights = torch.nn.functional.max_pool1d( + weights.unsqueeze(1).float(), 2, 1, padding=1) + weights = torch.nn.functional.avg_pool1d(weights, 2, + 1).squeeze() + weights = weights + 0.01 + + z_vals_mid = 0.5 * (z_vals[:, :-1] + z_vals[:, 1:]) + importance_z_vals = self.sample_pdf(z_vals_mid, weights[:, 1:-1], + N_importance).detach().reshape( + batch_size, num_rays, + N_importance, 1) + return importance_z_vals + + def sample_pdf(self, bins, weights, N_importance, det=False, eps=1e-5): + """Sample `N_importance` samples from `bins` with distribution defined + by `weights`. + + Args: + bins: (N_rays, N_samples_+1) where N_samples_ is the number of + coarse samples per ray - 2 + weights: (N_rays, N_samples_) + N_importance: the number of samples to draw from the distribution + det: deterministic or not + eps: a small number to prevent division by zero + + Returns: + samples: the sampled samples + + Source: + https://github.com/kwea123/nerf_pl/blob/master/models/rendering.py + + """ + N_rays, N_samples_ = weights.shape + weights = weights + eps + # prevent division by zero (don't do inplace op!) 
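+ # Illustrative example (assumed values): this is inverse-transform sampling.
+ # For one ray with weights [0.2, 0.3, 0.5], the cdf computed below is
+ # [0.0, 0.2, 0.5, 1.0]; a uniform draw u = 0.6 falls between 0.5 and 1.0, so
+ # searchsorted selects the last bin and the new depth is placed a fraction
+ # (0.6 - 0.5) / (1.0 - 0.5) = 0.2 of the way into that bin's interval.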
+ pdf = weights / torch.sum(weights, -1, + keepdim=True) # (N_rays, N_samples_) + cdf = torch.cumsum(pdf, -1) # (N_rays, N_samples), + # cumulative distribution function + cdf = torch.cat([torch.zeros_like(cdf[:, :1]), cdf], + -1) # (N_rays, N_samples_+1) + # padded to 0~1 inclusive + + if det: + u = torch.linspace(0, 1, N_importance, device=bins.device) + u = u.expand(N_rays, N_importance) + else: + u = torch.rand(N_rays, N_importance, device=bins.device) + u = u.contiguous() + + inds = torch.searchsorted(cdf, u) + below = torch.clamp_min(inds - 1, 0) + above = torch.clamp_max(inds, N_samples_) + + inds_sampled = torch.stack([below, above], + -1).view(N_rays, 2 * N_importance) + cdf_g = torch.gather(cdf, 1, inds_sampled) + cdf_g = cdf_g.view(N_rays, N_importance, 2) + bins_g = torch.gather(bins, 1, + inds_sampled).view(N_rays, N_importance, 2) + + denom = cdf_g[..., 1] - cdf_g[..., 0] + denom[denom < eps] = 1 # denom equals 0 means a bin has weight 0, + # in which case it will not be sampled + # anyway, therefore any value for it is fine + # (set to 1 here) + + samples = (bins_g[..., 0] + (u - cdf_g[..., 0]) / + denom * (bins_g[..., 1] - bins_g[..., 0])) + + return samples diff --git a/models/rendering/triplane_sampler.py b/models/rendering/triplane_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..c8868936914ad726bc9cffbf36943a5b2b4ddac9 --- /dev/null +++ b/models/rendering/triplane_sampler.py @@ -0,0 +1,110 @@ +"""Contain the functions to sample point features from the triplane + representation.""" + +import torch + +__all__ = ['TriplaneSampler'] + + +class TriplaneSampler(torch.nn.Module): + """Defines the class to help sample point features from the triplane + representation. + + Basically, this class implements the following functions for sampling point + features (rgb && sigma) from the triplane representation: + + 1. `generate_planes()`. + 2. `project_onto_planes()`. + 3. `sample_from_planes()`. + 4. `sample_from_3dgrid()`. + """ + + def __init__(self): + super().__init__() + + @staticmethod + def generate_planes(): + """ + Defines planes by the three vectors that form the "axes" of the + plane. Should work with arbitrary number of planes and planes of + arbitrary orientation. + """ + return torch.tensor([[[1, 0, 0], + [0, 1, 0], + [0, 0, 1]], + [[1, 0, 0], + [0, 0, 1], + [0, 1, 0]], + [[0, 0, 1], + [1, 0, 0], + [0, 1, 0]]], dtype=torch.float32) + + @staticmethod + def project_onto_planes(planes, coordinates): + """ + Does a projection of a 3D point onto a batch of 2D planes, + returning 2D plane coordinates. 
+ + Args: + planes: Plane axes of shape (n_planes, 3, 3) + coordinates: Coordinates of shape (N, M, 3) + + Returns: + projections: Projections of shape (N*n_planes, M, 2) + """ + N, M, C = coordinates.shape + n_planes, _, _ = planes.shape + coordinates = coordinates.unsqueeze(1).expand(-1, n_planes, -1, + -1).reshape( + N * n_planes, M, 3) + inv_planes = torch.linalg.inv(planes).unsqueeze(0).expand( + N, -1, -1, -1).reshape(N * n_planes, 3, 3) + projections = torch.bmm(coordinates, inv_planes) + return projections[..., :2] + + @staticmethod + def sample_from_planes(plane_axes, + plane_features, + coordinates, + mode='bilinear', + padding_mode='zeros', + box_warp=None): + assert padding_mode == 'zeros' + N, n_planes, C, H, W = plane_features.shape + _, M, _ = coordinates.shape + plane_features = plane_features.view(N * n_planes, C, H, W) + + coordinates = (2 / box_warp) * coordinates + + projected_coordinates = TriplaneSampler.project_onto_planes( + plane_axes, coordinates).unsqueeze(1) + output_features = torch.nn.functional.grid_sample( + plane_features, + projected_coordinates.float(), + mode=mode, + padding_mode=padding_mode, + align_corners=False).permute(0, 3, 2, + 1).reshape(N, n_planes, M, C) + return output_features + + @staticmethod + def sample_from_3dgrid(grid, coordinates): + """ + Expects coordinates in shape (batch_size, num_points_per_batch, 3) + Expects grid in shape (1, channels, H, W, D) + (Also works if grid has batch size) + Returns: + Sampled features + with shape: (batch_size, num_points_per_batch, feature_channels). + """ + batch_size, n_coords, n_dims = coordinates.shape + sampled_features = torch.nn.functional.grid_sample( + grid.expand(batch_size, -1, -1, -1, -1), + coordinates.reshape(batch_size, 1, 1, -1, n_dims), + mode='bilinear', + padding_mode='zeros', + align_corners=False) + N, C, H, W, D = sampled_features.shape + sampled_features = sampled_features.permute(0, 4, 3, 2, 1).reshape( + N, H * W * D, C) + return sampled_features \ No newline at end of file diff --git a/models/rendering/utils.py b/models/rendering/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..21f51715f99e705b2647e813beec80ed8d0867ec --- /dev/null +++ b/models/rendering/utils.py @@ -0,0 +1,187 @@ +# python3.8 +"""Contains utility functions for rendering.""" +import torch + +def normalize_vecs(vectors): + """ + Normalize vector lengths. + """ + return vectors / (torch.norm(vectors, dim=-1, keepdim=True)) + +def truncated_normal(tensor, mean=0, std=1): + """ + Samples from truncated normal distribution. 
+ """ + size = tensor.shape + tmp = tensor.new_empty(size + (4,)).normal_() + valid = (tmp < 2) & (tmp > -2) + ind = valid.max(-1, keepdim=True)[1] + tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1)) + tensor.data.mul_(std).add_(mean) + return tensor + +def get_grid_coords(points, bounds): + """ transform points from the world coordinate to the volume coordinate + pts: batch_size, num_point, 3 + bounds: 2, 3 + """ + # normalize the points + bounds = bounds[None] + min_xyz = bounds[:, :1] + points = points - min_xyz + # convert the voxel coordinate to [-1, 1] + size = bounds[:, 1] - bounds[:, 0] + points = (points / size[:, None]) * 2 - 1 + return points + +def grid_sample_3d(image, optical): + """grid sample images by the optical in 3D format + image: batch_size, channel, D, H, W + optical: batch_size, D, H, W, 3 + """ + N, C, ID, IH, IW = image.shape + _, D, H, W, _ = optical.shape + + ix = optical[..., 0] + iy = optical[..., 1] + iz = optical[..., 2] + + ix = ((ix + 1) / 2) * (IW - 1) + iy = ((iy + 1) / 2) * (IH - 1) + iz = ((iz + 1) / 2) * (ID - 1) + with torch.no_grad(): + ix_tnw = torch.floor(ix) + iy_tnw = torch.floor(iy) + iz_tnw = torch.floor(iz) + + ix_tne = ix_tnw + 1 + iy_tne = iy_tnw + iz_tne = iz_tnw + + ix_tsw = ix_tnw + iy_tsw = iy_tnw + 1 + iz_tsw = iz_tnw + + ix_tse = ix_tnw + 1 + iy_tse = iy_tnw + 1 + iz_tse = iz_tnw + + ix_bnw = ix_tnw + iy_bnw = iy_tnw + iz_bnw = iz_tnw + 1 + + ix_bne = ix_tnw + 1 + iy_bne = iy_tnw + iz_bne = iz_tnw + 1 + + ix_bsw = ix_tnw + iy_bsw = iy_tnw + 1 + iz_bsw = iz_tnw + 1 + + ix_bse = ix_tnw + 1 + iy_bse = iy_tnw + 1 + iz_bse = iz_tnw + 1 + + tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz) + tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz) + tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz) + tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz) + bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse) + bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw) + bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne) + bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw) + + with torch.no_grad(): + torch.clamp(ix_tnw, 0, IW - 1, out=ix_tnw) + torch.clamp(iy_tnw, 0, IH - 1, out=iy_tnw) + torch.clamp(iz_tnw, 0, ID - 1, out=iz_tnw) + + torch.clamp(ix_tne, 0, IW - 1, out=ix_tne) + torch.clamp(iy_tne, 0, IH - 1, out=iy_tne) + torch.clamp(iz_tne, 0, ID - 1, out=iz_tne) + + torch.clamp(ix_tsw, 0, IW - 1, out=ix_tsw) + torch.clamp(iy_tsw, 0, IH - 1, out=iy_tsw) + torch.clamp(iz_tsw, 0, ID - 1, out=iz_tsw) + + torch.clamp(ix_tse, 0, IW - 1, out=ix_tse) + torch.clamp(iy_tse, 0, IH - 1, out=iy_tse) + torch.clamp(iz_tse, 0, ID - 1, out=iz_tse) + + torch.clamp(ix_bnw, 0, IW - 1, out=ix_bnw) + torch.clamp(iy_bnw, 0, IH - 1, out=iy_bnw) + torch.clamp(iz_bnw, 0, ID - 1, out=iz_bnw) + + torch.clamp(ix_bne, 0, IW - 1, out=ix_bne) + torch.clamp(iy_bne, 0, IH - 1, out=iy_bne) + torch.clamp(iz_bne, 0, ID - 1, out=iz_bne) + + torch.clamp(ix_bsw, 0, IW - 1, out=ix_bsw) + torch.clamp(iy_bsw, 0, IH - 1, out=iy_bsw) + torch.clamp(iz_bsw, 0, ID - 1, out=iz_bsw) + + torch.clamp(ix_bse, 0, IW - 1, out=ix_bse) + torch.clamp(iy_bse, 0, IH - 1, out=iy_bse) + torch.clamp(iz_bse, 0, ID - 1, out=iz_bse) + + image = image.view(N, C, ID * IH * IW) + + tnw_val = torch.gather(image, 2, + (iz_tnw * IW * IH + iy_tnw * IW + + ix_tnw).long().view(N, 1, + D * H * W).repeat(1, C, 1)) + tne_val = torch.gather(image, 2, + (iz_tne * IW * IH + iy_tne * IW + + ix_tne).long().view(N, 1, + D * H * W).repeat(1, C, 1)) + tsw_val = torch.gather(image, 2, + (iz_tsw * IW * IH + iy_tsw * IW + + 
ix_tsw).long().view(N, 1, + D * H * W).repeat(1, C, 1)) + tse_val = torch.gather(image, 2, + (iz_tse * IW * IH + iy_tse * IW + + ix_tse).long().view(N, 1, + D * H * W).repeat(1, C, 1)) + bnw_val = torch.gather(image, 2, + (iz_bnw * IW * IH + iy_bnw * IW + + ix_bnw).long().view(N, 1, + D * H * W).repeat(1, C, 1)) + bne_val = torch.gather(image, 2, + (iz_bne * IW * IH + iy_bne * IW + + ix_bne).long().view(N, 1, + D * H * W).repeat(1, C, 1)) + bsw_val = torch.gather(image, 2, + (iz_bsw * IW * IH + iy_bsw * IW + + ix_bsw).long().view(N, 1, + D * H * W).repeat(1, C, 1)) + bse_val = torch.gather(image, 2, + (iz_bse * IW * IH + iy_bse * IW + + ix_bse).long().view(N, 1, + D * H * W).repeat(1, C, 1)) + + out_val = (tnw_val.view(N, C, D, H, W) * tnw.view(N, 1, D, H, W) + + tne_val.view(N, C, D, H, W) * tne.view(N, 1, D, H, W) + + tsw_val.view(N, C, D, H, W) * tsw.view(N, 1, D, H, W) + + tse_val.view(N, C, D, H, W) * tse.view(N, 1, D, H, W) + + bnw_val.view(N, C, D, H, W) * bnw.view(N, 1, D, H, W) + + bne_val.view(N, C, D, H, W) * bne.view(N, 1, D, H, W) + + bsw_val.view(N, C, D, H, W) * bsw.view(N, 1, D, H, W) + + bse_val.view(N, C, D, H, W) * bse.view(N, 1, D, H, W)) + + return out_val + +def interpolate_feature(points, volume, bounds): + """ + points: batch_size, num_point, 3 + volume: batch_size, num_channel, d, h, w + bounds: 2, 3 + """ + grid_coords = get_grid_coords(points, bounds) + grid_coords = grid_coords[:, None, None] + # point_features = F.grid_sample(volume, + # grid_coords, + # padding_mode='zeros', + # align_corners=True) + point_features = grid_sample_3d(volume, grid_coords) + point_features = point_features[:, :, 0, 0] + return point_features \ No newline at end of file diff --git a/models/sgbev3d_generator.py b/models/sgbev3d_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..3c7c8429bf4cdcdf87b1d62bc530b70462759a3f --- /dev/null +++ b/models/sgbev3d_generator.py @@ -0,0 +1,332 @@ +# python3.8 +"""Contains the implementation of generator described in SGBEV3D.""" + +import torch +import torch.nn as nn +from models.utils.official_stylegan2_model_helper import Generator as StyleGAN2Backbone +from models.utils.official_stylegan3_model_helper import Generator as StyleGAN3Backbone +from models.utils.unet import Generator as StyleGAN4Backbone +from models.utils.official_stylegan2_model_helper import FullyConnectedLayer +from models.utils.eg3d_superres import SuperresolutionHybrid2X +from models.utils.eg3d_superres import SuperresolutionHybrid4X +from models.utils.eg3d_superres import SuperresolutionHybrid4X_conststyle +from models.utils.eg3d_superres import SuperresolutionHybrid8XDC +from models.rendering.renderer import Renderer +from models.rendering.feature_extractor import FeatureExtractor + +from models.utils.spade import SPADEGenerator + +class SGBEV3DGenerator(nn.Module): + + def __init__( + self, + z_dim, + c_dim, + w_dim, + semantic_nc, + ngf, + bev_grid_size, + aspect_ratio, + num_upsampling_layers, + not_use_vae, + norm_G, + interpolate_sr, + segmask=False, + dim_seq='16,8,4,2,1', + xyz_pe=False, + reverse_xy=True, + hidden_dim=64, + additional_layer_num=0, + block_num=5, + layer_num=2, + ff_input=False, + ref_mode='bev_plane_clevr_256', + sel_type=None, + backbone_ver=2, + img_resolution=256, + bev_resolution=256, + sr_num_fp16_res=0, # Number of fp16 layers of SR Network. + mapping_kwargs={}, + rendering_kwargs={}, # Arguments for rendering. + sr_kwargs={}, # Arguments for SuperResolution Network. 
+ **synthesis_kwargs + ): + super().__init__() + + self.z_dim = z_dim + self.interpolate_sr = interpolate_sr + self.segmask = segmask + + # Set up the overall renderer. + self.renderer = Renderer() + + # Set up the feature extractor. + self.feature_extractor = FeatureExtractor(ref_mode=ref_mode, xyz_pe=xyz_pe, reverse_xy=reverse_xy) + + # Set up the reference representation generator. + self.backbone = globals()[f'StyleGAN{backbone_ver}Backbone'](z_dim, c_dim, w_dim, img_resolution=bev_resolution, img_channels=32, label_nc=semantic_nc, use_sel=True, sel_type=sel_type, mapping_kwargs=mapping_kwargs, ff_input=ff_input, block_num=block_num, layer_num=layer_num, **synthesis_kwargs) + + # Set up the post module in the feature extractor. + self.post_module = None + + # Set up the post neural renderer. + self.post_neural_renderer = None + sr_kwargs_total = dict( + channels=32, + img_resolution=img_resolution, + sr_num_fp16_res=sr_num_fp16_res, + sr_antialias=rendering_kwargs['sr_antialias'],) + sr_kwargs_total.update(**sr_kwargs) + if img_resolution == 128: + self.post_neural_renderer = SuperresolutionHybrid2X( + **sr_kwargs_total) + elif img_resolution == 256: + self.post_neural_renderer = SuperresolutionHybrid4X_conststyle( + **sr_kwargs_total) + elif img_resolution == 512: + self.post_neural_renderer = SuperresolutionHybrid8XDC( + **sr_kwargs_total) + else: + raise TypeError(f'Unsupported image resolution: {img_resolution}!') + + # Set up the fully-connected layer head. + self.fc_head = OSGDecoder( + 128 if xyz_pe else 64 , { + 'decoder_lr_mul': rendering_kwargs.get('decoder_lr_mul', 1), + 'decoder_output_dim': 32 + }, + hidden_dim=hidden_dim, + additional_layer_num=additional_layer_num + ) + + # Set up some rendering related arguments. + self.neural_rendering_resolution = rendering_kwargs.get( + 'resolution', 64) + self.rendering_kwargs = rendering_kwargs + + def mapping(self, + z, + c, + truncation_psi=1, + truncation_cutoff=None, + update_emas=False): + if self.rendering_kwargs['c_gen_conditioning_zero']: + c = torch.zeros_like(c) + return self.backbone.mapping(z, + c * + self.rendering_kwargs.get('c_scale', 0), + truncation_psi=truncation_psi, + truncation_cutoff=truncation_cutoff, + update_emas=update_emas) + + def synthesis(self, + wp, + c, + seg, + neural_rendering_resolution=None, + update_emas=False, + **synthesis_kwargs): + cam2world_matrix = c[:, :16].view(-1, 4, 4) + if self.rendering_kwargs.get('random_pose', False): + cam2world_matrix = None + + if neural_rendering_resolution is None: + neural_rendering_resolution = self.neural_rendering_resolution + else: + self.neural_rendering_resolution = neural_rendering_resolution + + xy_planes = self.backbone.synthesis(wp, heatmap=seg, update_emas=update_emas, **synthesis_kwargs) + if self.segmask: + xy_planes = xy_planes * seg[:, 0, ...][:, None, ...] + + rendering_result = self.renderer( + wp=wp, + feature_extractor=self.feature_extractor, + rendering_options=self.rendering_kwargs, + cam2world_matrix=cam2world_matrix, + position_encoder=None, + ref_representation=xy_planes, + post_module=self.post_module, + fc_head=self.fc_head) + + feature_samples = rendering_result['composite_rgb'] + depth_samples = rendering_result['composite_depth'] + + # Reshape to keep consistent with 'raw' neural-rendered image. 
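# --- Illustrative sketch (not part of the patch): the renderer returns one
# feature vector per ray, flattened as [N, H*W, C]; the reshape below folds
# that back into an image-shaped tensor [N, C, H, W] before super-resolution.
# Standalone shape check with made-up sizes:
import torch

N, H, W, C = 2, 64, 64, 32
feature_samples = torch.randn(N, H * W, C)          # one feature per ray
feature_image = feature_samples.permute(0, 2, 1).reshape(N, C, H, W)
rgb_image = feature_image[:, :3]                    # first three channels act as RGB
assert feature_image.shape == (2, 32, 64, 64)
assert rgb_image.shape == (2, 3, 64, 64)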
+ N = wp.shape[0] + H = W = self.neural_rendering_resolution + feature_image = feature_samples.permute(0, 2, 1).reshape( + N, feature_samples.shape[-1], H, W).contiguous() + depth_image = depth_samples.permute(0, 2, 1).reshape(N, 1, H, W) + + # Run the post neural renderer to get final image. + # Here, the post neural renderer is a super-resolution network. + rgb_image = feature_image[:, :3] + if self.interpolate_sr: + sr_image = torch.nn.functional.interpolate(rgb_image, size=(256, 256), mode='bilinear', align_corners=False) + else: + sr_image = self.post_neural_renderer( + rgb_image, + feature_image, + # wp, # todo: study SR with wp + noise_mode=self.rendering_kwargs['superresolution_noise_mode'], + **{ + k: synthesis_kwargs[k] + for k in synthesis_kwargs.keys() if k != 'noise_mode' + }) + + return { + 'image': sr_image, + 'image_raw': rgb_image, + 'image_depth': depth_image, + 'plane': xy_planes, + 'points': rendering_result['points'], + 'sigmas': rendering_result['sigmas'] + } + + def sample(self, + coordinates, + directions, + z, + c, + seg, + truncation_psi=1, + truncation_cutoff=None, + update_emas=False, + **synthesis_kwargs): + # Compute RGB features, density for arbitrary 3D coordinates. + # Mostly used for extracting shapes. + cam2world_matrix = c[:, :16].view(-1, 4, 4) + wp = self.mapping(z, c, truncation_psi=truncation_psi, + truncation_cutoff=truncation_cutoff, + update_emas=update_emas) + xy_planes = self.backbone.synthesis(wp, heatmap=seg, update_emas=update_emas, **synthesis_kwargs) + result = self.renderer.get_sigma_rgb( + wp=wp, + points=coordinates, + feature_extractor=self.feature_extractor, + fc_head=self.fc_head, + rendering_options=self.rendering_kwargs, + ref_representation=xy_planes, + post_module=self.post_module, + ray_dirs=directions, + cam_matrix=cam2world_matrix) + + return result + + def sample_mixed(self, + coordinates, + directions, + wp, c, seg, + truncation_psi=1, + truncation_cutoff=None, + update_emas=False, + **synthesis_kwargs): + # Same as function `self.sample()`, but expects latent vectors 'wp' + # instead of Gaussian noise 'z'. + cam2world_matrix = c[:, :16].view(-1, 4, 4) + xy_planes = self.backbone.synthesis(wp, heatmap=seg, update_emas=update_emas, **synthesis_kwargs) + result = self.renderer.get_sigma_rgb( + wp=wp, + points=coordinates, + feature_extractor=self.feature_extractor, + fc_head=self.fc_head, + rendering_options=self.rendering_kwargs, + ref_representation=xy_planes, + post_module=self.post_module, + ray_dirs=directions, + cam_matrix=cam2world_matrix) + + return result + + def forward(self, + z, + c, + seg, + c_swapped=None, # `c_swapped` is swapped pose conditioning. + style_mixing_prob=0, + truncation_psi=1, + truncation_cutoff=None, + neural_rendering_resolution=None, + update_emas=False, + sample_mixed=False, + coordinates=None, + **synthesis_kwargs): + + # Render a batch of generated images. + c_wp = c.clone() + if c_swapped is not None: + c_wp = c_swapped.clone() + wp = self.mapping(z, c_wp, truncation_psi=truncation_psi, + truncation_cutoff=truncation_cutoff, + update_emas=update_emas) + + #TODO: implement style mixing + + if not sample_mixed: + gen_output = self.synthesis( + wp, + c, + seg, + update_emas=update_emas, + neural_rendering_resolution=neural_rendering_resolution, + **synthesis_kwargs) + + return { + 'wp': z, + 'gen_output': gen_output, + } + + else: + # Only for density regularization in training process. 
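# --- Illustrative sketch (not part of the patch): how the `sample_mixed` branch
# below is typically used for density regularization during training. The box
# size, point count, and perturbation scale are made-up values; only the idea of
# querying densities at nearby points is implied by this branch.
import torch
import torch.nn.functional as F

def density_reg_sketch(sigma_fn, batch=2, n_pts=1000, box=2.0, eps=0.004):
    # Query densities at random points and at slightly perturbed copies, then
    # penalize the difference (a smoothness prior on sigma).
    pts = (torch.rand(batch, n_pts, 3) - 0.5) * box
    sigma_a = sigma_fn(pts)
    sigma_b = sigma_fn(pts + torch.randn_like(pts) * eps)
    return F.l1_loss(sigma_a, sigma_b)

# e.g. loss = density_reg_sketch(
#     lambda p: G(z, c, seg, sample_mixed=True, coordinates=p)['sample_sigma'])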
+ assert coordinates is not None + sample_sigma = self.sample_mixed(coordinates, + torch.randn_like(coordinates), + wp, c, seg, + update_emas=False)['sigma'] + + return { + 'wp': z, + 'sample_sigma': sample_sigma + } + + +class OSGDecoder(nn.Module): + """Defines fully-connected layer head in EG3D.""" + def __init__(self, n_features, options, hidden_dim=64, additional_layer_num=0): + super().__init__() + self.hidden_dim = hidden_dim + + lst = [] + lst.append(FullyConnectedLayer(n_features, self.hidden_dim, lr_multiplier=options['decoder_lr_mul'])) + lst.append(nn.Softplus()) + for i in range(additional_layer_num): + lst.append(FullyConnectedLayer(self.hidden_dim, self.hidden_dim, lr_multiplier=options['decoder_lr_mul'])) + lst.append(nn.Softplus()) + lst.append(FullyConnectedLayer(self.hidden_dim, 1+options['decoder_output_dim'], lr_multiplier=options['decoder_lr_mul'])) + self.net = nn.Sequential(*lst) + + # self.net = nn.Sequential( + # FullyConnectedLayer(n_features, + # self.hidden_dim, + # lr_multiplier=options['decoder_lr_mul']), + # nn.Softplus(), + # FullyConnectedLayer(self.hidden_dim, + # 1 + options['decoder_output_dim'], + # lr_multiplier=options['decoder_lr_mul'])) + + def forward(self, point_features, wp=None, dirs=None): + # Aggregate features + # point_features.shape: [N, R, K, C]. + # Average across 'X, Y, Z' planes. + + N, R, K, C = point_features.shape + x = point_features.reshape(-1, point_features.shape[-1]) + x = self.net(x) + x = x.view(N, -1, x.shape[-1]) + + # Uses sigmoid clamping from MipNeRF + rgb = torch.sigmoid(x[..., 1:]) * (1 + 2 * 0.001) - 0.001 + sigma = x[..., 0:1] + + return {'rgb': rgb, 'sigma': sigma} diff --git a/models/stylegan2_discriminator.py b/models/stylegan2_discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..1802d44b68d0801290dcd691f09f605c3cfc9cbd --- /dev/null +++ b/models/stylegan2_discriminator.py @@ -0,0 +1,729 @@ +# python3.7 +"""Contains the implementation of discriminator described in StyleGAN2. + +Compared to that of StyleGAN, the discriminator in StyleGAN2 mainly adds skip +connections, increases model size and disables progressive growth. This script +ONLY supports config F in the original paper. + +Paper: https://arxiv.org/pdf/1912.04958.pdf + +Official TensorFlow implementation: https://github.com/NVlabs/stylegan2 +""" + +import numpy as np + +import torch +import torch.nn as nn + +from third_party.stylegan2_official_ops import bias_act +from third_party.stylegan2_official_ops import upfirdn2d +from third_party.stylegan2_official_ops import conv2d_gradfix + +__all__ = ['StyleGAN2Discriminator'] + +# Resolutions allowed. +_RESOLUTIONS_ALLOWED = [8, 16, 32, 64, 128, 256, 512, 1024] + +# Architectures allowed. +_ARCHITECTURES_ALLOWED = ['resnet', 'skip', 'origin'] + +# pylint: disable=missing-function-docstring + +class StyleGAN2Discriminator(nn.Module): + """Defines the discriminator network in StyleGAN2. + + NOTE: The discriminator takes images with `RGB` channel order and pixel + range [-1, 1] as inputs. + + Settings for the backbone: + + (1) resolution: The resolution of the input image. (default: -1) + (2) init_res: Smallest resolution of the convolutional backbone. + (default: 4) + (3) image_channels: Number of channels of the input image. (default: 3) + (4) architecture: Type of architecture. Support `origin`, `skip`, and + `resnet`. (default: `resnet`) + (5) use_wscale: Whether to use weight scaling. (default: True) + (6) wscale_gain: The factor to control weight scaling. 
(default: 1.0) + (7) lr_mul: Learning rate multiplier for backbone. (default: 1.0) + (8) mbstd_groups: Group size for the minibatch standard deviation layer. + `0` means disable. (default: 4) + (9) mbstd_channels: Number of new channels (appended to the original feature + map) after the minibatch standard deviation layer. (default: 1) + (10) fmaps_base: Factor to control number of feature maps for each layer. + (default: 32 << 10) + (11) fmaps_max: Maximum number of feature maps in each layer. (default: 512) + (12) filter_kernel: Kernel used for filtering (e.g., downsampling). + (default: (1, 3, 3, 1)) + (13) conv_clamp: A threshold to clamp the output of convolution layers to + avoid overflow under FP16 training. (default: None) + (14) eps: A small value to avoid divide overflow. (default: 1e-8) + + Settings for conditional model: + + (1) label_dim: Dimension of the additional label for conditional generation. + In one-hot conditioning case, it is equal to the number of classes. If + set to 0, conditioning training will be disabled. (default: 0) + (2) embedding_dim: Dimension of the embedding space, if needed. + (default: 512) + (3) embedding_bias: Whether to add bias to embedding learning. + (default: True) + (4) embedding_use_wscale: Whether to use weight scaling for embedding + learning. (default: True) + (5) embedding_lr_mul: Learning rate multiplier for the embedding learning. + (default: 1.0) + (6) normalize_embedding: Whether to normalize the embedding. (default: True) + (7) mapping_layers: Number of layers of the additional mapping network after + embedding. (default: 0) + (8) mapping_fmaps: Number of hidden channels of the additional mapping + network after embedding. (default: 512) + (9) mapping_use_wscale: Whether to use weight scaling for the additional + mapping network. (default: True) + (10) mapping_lr_mul: Learning rate multiplier for the additional mapping + network after embedding. (default: 0.1) + + Runtime settings: + + (1) fp16_res: Layers at resolution higher than (or equal to) this field will + use `float16` precision for computation. This is merely used for + acceleration. If set as `None`, all layers will use `float32` by + default. (default: None) + (2) impl: Implementation mode of some particular ops, e.g., `filtering`, + `bias_act`, etc. `cuda` means using the official CUDA implementation + from StyleGAN2, while `ref` means using the native PyTorch ops. + (default: `cuda`) + """ + + def __init__(self, + # Settings for backbone. + resolution=-1, + init_res=4, + image_channels=3, + architecture='resnet', + use_wscale=True, + wscale_gain=1.0, + lr_mul=1.0, + mbstd_groups=4, + mbstd_channels=1, + fmaps_base=32 << 10, + fmaps_max=512, + filter_kernel=(1, 3, 3, 1), + conv_clamp=None, + eps=1e-8, + # Settings for conditional model. + label_dim=0, + embedding_dim=512, + embedding_bias=True, + embedding_use_wscale=True, + embedding_lr_mul=1.0, + normalize_embedding=True, + mapping_layers=0, + mapping_fmaps=512, + mapping_use_wscale=True, + mapping_lr_mul=0.1): + """Initializes with basic settings. + + Raises: + ValueError: If the `resolution` is not supported, or `architecture` + is not supported. 
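# --- Illustrative usage sketch (not part of the patch): constructing the
# discriminator defined here and scoring a batch. `impl='ref'` selects the
# native PyTorch ops as described in the class docstring (assuming the bundled
# third_party ops import cleanly); resolution and batch size are made-up.
import torch

D = StyleGAN2Discriminator(resolution=256)           # unconditional, RGB input
images = torch.randn(4, 3, 256, 256).clamp(-1, 1)    # pixel range [-1, 1]
scores = D(images, impl='ref')['score']              # shape [4, 1]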
+ """ + super().__init__() + + if resolution not in _RESOLUTIONS_ALLOWED: + raise ValueError(f'Invalid resolution: `{resolution}`!\n' + f'Resolutions allowed: {_RESOLUTIONS_ALLOWED}.') + architecture = architecture.lower() + if architecture not in _ARCHITECTURES_ALLOWED: + raise ValueError(f'Invalid architecture: `{architecture}`!\n' + f'Architectures allowed: ' + f'{_ARCHITECTURES_ALLOWED}.') + + self.init_res = init_res + self.init_res_log2 = int(np.log2(init_res)) + self.resolution = resolution + self.final_res_log2 = int(np.log2(resolution)) + self.image_channels = image_channels + self.architecture = architecture + self.use_wscale = use_wscale + self.wscale_gain = wscale_gain + self.lr_mul = lr_mul + self.mbstd_groups = mbstd_groups + self.mbstd_channels = mbstd_channels + self.fmaps_base = fmaps_base + self.fmaps_max = fmaps_max + self.filter_kernel = filter_kernel + self.conv_clamp = conv_clamp + self.eps = eps + + self.label_dim = label_dim + self.embedding_dim = embedding_dim + self.embedding_bias = embedding_bias + self.embedding_use_wscale = embedding_use_wscale + self.embedding_lr_mul = embedding_lr_mul + self.normalize_embedding = normalize_embedding + self.mapping_layers = mapping_layers + self.mapping_fmaps = mapping_fmaps + self.mapping_use_wscale = mapping_use_wscale + self.mapping_lr_mul = mapping_lr_mul + + self.pth_to_tf_var_mapping = {} + + # Embedding for conditional discrimination. + self.use_embedding = label_dim > 0 and embedding_dim > 0 + if self.use_embedding: + self.embedding = DenseLayer(in_channels=label_dim, + out_channels=embedding_dim, + add_bias=embedding_bias, + init_bias=0.0, + use_wscale=embedding_use_wscale, + wscale_gain=wscale_gain, + lr_mul=embedding_lr_mul, + activation_type='linear') + self.pth_to_tf_var_mapping['embedding.weight'] = 'LabelEmbed/weight' + if self.embedding_bias: + self.pth_to_tf_var_mapping['embedding.bias'] = 'LabelEmbed/bias' + + if self.normalize_embedding: + self.norm = PixelNormLayer(dim=1, eps=eps) + + for i in range(mapping_layers): + in_channels = (embedding_dim if i == 0 else mapping_fmaps) + out_channels = (embedding_dim if i == (mapping_layers - 1) else + mapping_fmaps) + layer_name = f'mapping{i}' + self.add_module(layer_name, + DenseLayer(in_channels=in_channels, + out_channels=out_channels, + add_bias=True, + init_bias=0.0, + use_wscale=mapping_use_wscale, + wscale_gain=wscale_gain, + lr_mul=mapping_lr_mul, + activation_type='lrelu')) + self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = ( + f'Mapping{i}/weight') + self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = ( + f'Mapping{i}/bias') + + # Convolutional backbone. + for res_log2 in range(self.final_res_log2, self.init_res_log2 - 1, -1): + res = 2 ** res_log2 + in_channels = self.get_nf(res) + out_channels = self.get_nf(res // 2) + block_idx = self.final_res_log2 - res_log2 + + # Input convolution layer for each resolution (if needed). 
+ if res_log2 == self.final_res_log2 or self.architecture == 'skip': + layer_name = f'input{block_idx}' + self.add_module(layer_name, + ConvLayer(in_channels=image_channels, + out_channels=in_channels, + kernel_size=1, + add_bias=True, + scale_factor=1, + filter_kernel=None, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + lr_mul=lr_mul, + activation_type='lrelu', + conv_clamp=conv_clamp)) + self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = ( + f'{res}x{res}/FromRGB/weight') + self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = ( + f'{res}x{res}/FromRGB/bias') + + # Convolution block for each resolution (except the last one). + if res != self.init_res: + # First layer (kernel 3x3) without downsampling. + layer_name = f'layer{2 * block_idx}' + self.add_module(layer_name, + ConvLayer(in_channels=in_channels, + out_channels=in_channels, + kernel_size=3, + add_bias=True, + scale_factor=1, + filter_kernel=None, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + lr_mul=lr_mul, + activation_type='lrelu', + conv_clamp=conv_clamp)) + self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = ( + f'{res}x{res}/Conv0/weight') + self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = ( + f'{res}x{res}/Conv0/bias') + + # Second layer (kernel 3x3) with downsampling + layer_name = f'layer{2 * block_idx + 1}' + self.add_module(layer_name, + ConvLayer(in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + add_bias=True, + scale_factor=2, + filter_kernel=filter_kernel, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + lr_mul=lr_mul, + activation_type='lrelu', + conv_clamp=conv_clamp)) + self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = ( + f'{res}x{res}/Conv1_down/weight') + self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = ( + f'{res}x{res}/Conv1_down/bias') + + # Residual branch (kernel 1x1) with downsampling, without bias, + # with linear activation. + if self.architecture == 'resnet': + layer_name = f'residual{block_idx}' + self.add_module(layer_name, + ConvLayer(in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + add_bias=False, + scale_factor=2, + filter_kernel=filter_kernel, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + lr_mul=lr_mul, + activation_type='linear', + conv_clamp=None)) + self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = ( + f'{res}x{res}/Skip/weight') + + # Convolution block for last resolution. + else: + self.mbstd = MiniBatchSTDLayer( + groups=mbstd_groups, new_channels=mbstd_channels, eps=eps) + + # First layer (kernel 3x3) without downsampling. + layer_name = f'layer{2 * block_idx}' + self.add_module( + layer_name, + ConvLayer(in_channels=in_channels + mbstd_channels, + out_channels=in_channels, + kernel_size=3, + add_bias=True, + scale_factor=1, + filter_kernel=None, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + lr_mul=lr_mul, + activation_type='lrelu', + conv_clamp=conv_clamp)) + self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = ( + f'{res}x{res}/Conv/weight') + self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = ( + f'{res}x{res}/Conv/bias') + + # Second layer, as a fully-connected layer. 
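# --- Illustrative sketch (not part of the patch): what the minibatch-stddev
# layer used above adds. It appends, per spatial location, the group-wise
# standard deviation averaged over features, so the discriminator can detect
# batches with suspiciously low variation. Minimal single-channel version:
import torch

def minibatch_stddev_sketch(x, groups=4, eps=1e-8):
    N, C, H, W = x.shape
    G = min(groups, N)
    y = x.reshape(G, -1, C, H, W)                   # [G, N//G, C, H, W]
    y = y - y.mean(dim=0)                           # center within each group
    y = (y.square().mean(dim=0) + eps).sqrt()       # per-feature stddev over the group
    y = y.mean(dim=(1, 2, 3), keepdim=True)         # one scalar per group member
    y = y.reshape(-1, 1, 1, 1).repeat(G, 1, H, W)   # broadcast back to [N, 1, H, W]
    return torch.cat([x, y], dim=1)

assert minibatch_stddev_sketch(torch.randn(8, 16, 4, 4)).shape == (8, 17, 4, 4)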
+ layer_name = f'layer{2 * block_idx + 1}' + self.add_module(layer_name, + DenseLayer(in_channels=in_channels * res * res, + out_channels=in_channels, + add_bias=True, + init_bias=0.0, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + lr_mul=lr_mul, + activation_type='lrelu')) + self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = ( + f'{res}x{res}/Dense0/weight') + self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = ( + f'{res}x{res}/Dense0/bias') + + # Final dense layer to output score. + self.output = DenseLayer(in_channels=in_channels, + out_channels=(embedding_dim + if self.use_embedding + else max(label_dim, 1)), + add_bias=True, + init_bias=0.0, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + lr_mul=lr_mul, + activation_type='linear') + self.pth_to_tf_var_mapping['output.weight'] = 'Output/weight' + self.pth_to_tf_var_mapping['output.bias'] = 'Output/bias' + + # Used for downsampling input image for `skip` architecture. + if self.architecture == 'skip': + self.register_buffer( + 'filter', upfirdn2d.setup_filter(filter_kernel)) + + def get_nf(self, res): + """Gets number of feature maps according to the given resolution.""" + return min(self.fmaps_base // res, self.fmaps_max) + + def forward(self, image, label=None, fp16_res=None, impl='cuda'): + # Check shape. + expected_shape = (self.image_channels, self.resolution, self.resolution) + if image.ndim != 4 or image.shape[1:] != expected_shape: + raise ValueError(f'The input tensor should be with shape ' + f'[batch_size, channel, height, width], where ' + f'`channel` equals to {self.image_channels}, ' + f'`height`, `width` equal to {self.resolution}!\n' + f'But `{image.shape}` is received!') + if self.label_dim > 0: + if label is None: + raise ValueError(f'Model requires an additional label ' + f'(with dimension {self.label_dim}) as input, ' + f'but no label is received!') + batch_size = image.shape[0] + if label.ndim != 2 or label.shape != (batch_size, self.label_dim): + raise ValueError(f'Input label should be with shape ' + f'[batch_size, label_dim], where ' + f'`batch_size` equals to that of ' + f'images ({image.shape[0]}) and ' + f'`label_dim` equals to {self.label_dim}!\n' + f'But `{label.shape}` is received!') + label = label.to(dtype=torch.float32) + if self.use_embedding: + embed = self.embedding(label, impl=impl) + if self.normalize_embedding: + embed = self.norm(embed) + for i in range(self.mapping_layers): + embed = getattr(self, f'mapping{i}')(embed, impl=impl) + + # Cast to `torch.float16` if needed. + if fp16_res is not None and self.resolution >= fp16_res: + image = image.to(torch.float16) + + x = self.input0(image, impl=impl) + + for res_log2 in range(self.final_res_log2, self.init_res_log2, -1): + res = 2 ** res_log2 + # Cast to `torch.float16` if needed. + if fp16_res is not None and res >= fp16_res: + x = x.to(torch.float16) + else: + x = x.to(torch.float32) + + idx = self.final_res_log2 - res_log2 # Block index + + if self.architecture == 'skip' and idx > 0: + image = upfirdn2d.downsample2d(image, self.filter, impl=impl) + # Cast to `torch.float16` if needed. 
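# --- Illustrative note (not part of the patch): the fp16_res policy used
# throughout this forward pass. Blocks running at resolution >= fp16_res are
# cast to float16, everything else (and always the final block) stays in
# float32. For example, with fp16_res=128:
#
#   def pick_dtype(res, fp16_res=128):
#       return torch.float16 if (fp16_res is not None and res >= fp16_res) else torch.float32
#
#   [pick_dtype(r) for r in (256, 128, 64)]   # -> [float16, float16, float32]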
+ if fp16_res is not None and res >= fp16_res: + image = image.to(torch.float16) + else: + image = image.to(torch.float32) + y = getattr(self, f'input{idx}')(image, impl=impl) + x = x + y + + if self.architecture == 'resnet': + residual = getattr(self, f'residual{idx}')( + x, runtime_gain=np.sqrt(0.5), impl=impl) + x = getattr(self, f'layer{2 * idx}')(x, impl=impl) + x = getattr(self, f'layer{2 * idx + 1}')( + x, runtime_gain=np.sqrt(0.5), impl=impl) + x = x + residual + else: + x = getattr(self, f'layer{2 * idx}')(x, impl=impl) + x = getattr(self, f'layer{2 * idx + 1}')(x, impl=impl) + + # Final output. + idx += 1 + if fp16_res is not None: # Always use FP32 for the last block. + x = x.to(torch.float32) + if self.architecture == 'skip': + image = upfirdn2d.downsample2d(image, self.filter, impl=impl) + if fp16_res is not None: # Always use FP32 for the last block. + image = image.to(torch.float32) + y = getattr(self, f'input{idx}')(image, impl=impl) + x = x + y + x = self.mbstd(x) + x = getattr(self, f'layer{2 * idx}')(x, impl=impl) + x = getattr(self, f'layer{2 * idx + 1}')(x, impl=impl) + x = self.output(x, impl=impl) + + if self.use_embedding: + x = (x * embed).sum(dim=1, keepdim=True) + x = x / np.sqrt(self.embedding_dim) + elif self.label_dim > 0: + x = (x * label).sum(dim=1, keepdim=True) + + results = { + 'score': x, + 'label': label + } + if self.use_embedding: + results['embedding'] = embed + return results + + +class PixelNormLayer(nn.Module): + """Implements pixel-wise feature vector normalization layer.""" + + def __init__(self, dim, eps): + super().__init__() + self.dim = dim + self.eps = eps + + def extra_repr(self): + return f'dim={self.dim}, epsilon={self.eps}' + + def forward(self, x): + scale = (x.square().mean(dim=self.dim, keepdim=True) + self.eps).rsqrt() + return x * scale + + +class MiniBatchSTDLayer(nn.Module): + """Implements the minibatch standard deviation layer.""" + + def __init__(self, groups, new_channels, eps): + super().__init__() + self.groups = groups + self.new_channels = new_channels + self.eps = eps + + def extra_repr(self): + return (f'groups={self.groups}, ' + f'new_channels={self.new_channels}, ' + f'epsilon={self.eps}') + + def forward(self, x): + if self.groups <= 1 or self.new_channels < 1: + return x + + dtype = x.dtype + + N, C, H, W = x.shape + G = min(self.groups, N) # Number of groups. + nC = self.new_channels # Number of channel groups. + c = C // nC # Channels per channel group. + + y = x.reshape(G, -1, nC, c, H, W) # [GnFcHW] + y = y - y.mean(dim=0) # [GnFcHW] + y = y.square().mean(dim=0) # [nFcHW] + y = (y + self.eps).sqrt() # [nFcHW] + y = y.mean(dim=(2, 3, 4)) # [nF] + y = y.reshape(-1, nC, 1, 1) # [nF11] + y = y.repeat(G, 1, H, W) # [NFHW] + x = torch.cat((x, y), dim=1) # [N(C+F)HW] + + assert x.dtype == dtype + return x + + +class ConvLayer(nn.Module): + """Implements the convolutional layer. + + If downsampling is needed (i.e., `scale_factor = 2`), the feature map will + be filtered with `filter_kernel` first. + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + add_bias, + scale_factor, + filter_kernel, + use_wscale, + wscale_gain, + lr_mul, + activation_type, + conv_clamp): + """Initializes with layer settings. + + Args: + in_channels: Number of channels of the input tensor. + out_channels: Number of channels of the output tensor. + kernel_size: Size of the convolutional kernels. + add_bias: Whether to add bias onto the convolutional result. + scale_factor: Scale factor for downsampling. 
`1` means skip + downsampling. + filter_kernel: Kernel used for filtering. + use_wscale: Whether to use weight scaling. + wscale_gain: Gain factor for weight scaling. + lr_mul: Learning multiplier for both weight and bias. + activation_type: Type of activation. + conv_clamp: A threshold to clamp the output of convolution layers to + avoid overflow under FP16 training. + """ + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.add_bias = add_bias + self.scale_factor = scale_factor + self.filter_kernel = filter_kernel + self.use_wscale = use_wscale + self.wscale_gain = wscale_gain + self.lr_mul = lr_mul + self.activation_type = activation_type + self.conv_clamp = conv_clamp + + weight_shape = (out_channels, in_channels, kernel_size, kernel_size) + fan_in = kernel_size * kernel_size * in_channels + wscale = wscale_gain / np.sqrt(fan_in) + if use_wscale: + self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul) + self.wscale = wscale * lr_mul + else: + self.weight = nn.Parameter( + torch.randn(*weight_shape) * wscale / lr_mul) + self.wscale = lr_mul + + if add_bias: + self.bias = nn.Parameter(torch.zeros(out_channels)) + self.bscale = lr_mul + else: + self.bias = None + self.act_gain = bias_act.activation_funcs[activation_type].def_gain + + if scale_factor > 1: + assert filter_kernel is not None + self.register_buffer( + 'filter', upfirdn2d.setup_filter(filter_kernel)) + fh, fw = self.filter.shape + self.filter_padding = ( + kernel_size // 2 + (fw - scale_factor + 1) // 2, + kernel_size // 2 + (fw - scale_factor) // 2, + kernel_size // 2 + (fh - scale_factor + 1) // 2, + kernel_size // 2 + (fh - scale_factor) // 2) + + def extra_repr(self): + return (f'in_ch={self.in_channels}, ' + f'out_ch={self.out_channels}, ' + f'ksize={self.kernel_size}, ' + f'wscale_gain={self.wscale_gain:.3f}, ' + f'bias={self.add_bias}, ' + f'lr_mul={self.lr_mul:.3f}, ' + f'downsample={self.scale_factor}, ' + f'downsample_filter={self.filter_kernel}, ' + f'act={self.activation_type}, ' + f'clamp={self.conv_clamp}') + + def forward(self, x, runtime_gain=1.0, impl='cuda'): + dtype = x.dtype + + weight = self.weight + if self.wscale != 1.0: + weight = weight * self.wscale + bias = None + if self.bias is not None: + bias = self.bias.to(dtype) + if self.bscale != 1.0: + bias = bias * self.bscale + + if self.scale_factor == 1: # Native convolution without downsampling. + padding = self.kernel_size // 2 + x = conv2d_gradfix.conv2d( + x, weight.to(dtype), stride=1, padding=padding, impl=impl) + else: # Convolution with downsampling. + down = self.scale_factor + f = self.filter + padding = self.filter_padding + # When kernel size = 1, use filtering function for downsampling. + if self.kernel_size == 1: + x = upfirdn2d.upfirdn2d( + x, f, down=down, padding=padding, impl=impl) + x = conv2d_gradfix.conv2d( + x, weight.to(dtype), stride=1, padding=0, impl=impl) + # When kernel size != 1, use stride convolution for downsampling. 
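# --- Illustrative sketch (not part of the patch): a plain-PyTorch approximation
# of the "blur with the (1, 3, 3, 1) kernel, then stride-2 convolution" path
# implemented above with upfirdn2d / conv2d_gradfix. The exact padding and gain
# handling of the official ops is simplified here.
import torch
import torch.nn.functional as F

def blur_then_strided_conv(x, weight, filter_kernel=(1., 3., 3., 1.)):
    k = torch.tensor(filter_kernel)
    blur = torch.outer(k, k)
    blur = (blur / blur.sum()).reshape(1, 1, 4, 4).repeat(x.shape[1], 1, 1, 1)
    x = F.conv2d(x, blur, padding=2, groups=x.shape[1])   # depthwise low-pass filter
    return F.conv2d(x, weight, stride=2, padding=1)       # strided 3x3 convolution

out = blur_then_strided_conv(torch.randn(1, 8, 32, 32), torch.randn(16, 8, 3, 3))
print(out.shape)   # torch.Size([1, 16, 17, 17]); the real ops pad to land exactly on 16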
+ else: + x = upfirdn2d.upfirdn2d( + x, f, down=1, padding=padding, impl=impl) + x = conv2d_gradfix.conv2d( + x, weight.to(dtype), stride=down, padding=0, impl=impl) + + act_gain = self.act_gain * runtime_gain + act_clamp = None + if self.conv_clamp is not None: + act_clamp = self.conv_clamp * runtime_gain + x = bias_act.bias_act(x, bias, + act=self.activation_type, + gain=act_gain, + clamp=act_clamp, + impl=impl) + + assert x.dtype == dtype + return x + + +class DenseLayer(nn.Module): + """Implements the dense layer.""" + + def __init__(self, + in_channels, + out_channels, + add_bias, + init_bias, + use_wscale, + wscale_gain, + lr_mul, + activation_type): + """Initializes with layer settings. + + Args: + in_channels: Number of channels of the input tensor. + out_channels: Number of channels of the output tensor. + add_bias: Whether to add bias onto the fully-connected result. + init_bias: The initial bias value before training. + use_wscale: Whether to use weight scaling. + wscale_gain: Gain factor for weight scaling. + lr_mul: Learning multiplier for both weight and bias. + activation_type: Type of activation. + """ + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.add_bias = add_bias + self.init_bias = init_bias + self.use_wscale = use_wscale + self.wscale_gain = wscale_gain + self.lr_mul = lr_mul + self.activation_type = activation_type + + weight_shape = (out_channels, in_channels) + wscale = wscale_gain / np.sqrt(in_channels) + if use_wscale: + self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul) + self.wscale = wscale * lr_mul + else: + self.weight = nn.Parameter( + torch.randn(*weight_shape) * wscale / lr_mul) + self.wscale = lr_mul + + if add_bias: + init_bias = np.float32(init_bias) / lr_mul + self.bias = nn.Parameter(torch.full([out_channels], init_bias)) + self.bscale = lr_mul + else: + self.bias = None + + def extra_repr(self): + return (f'in_ch={self.in_channels}, ' + f'out_ch={self.out_channels}, ' + f'wscale_gain={self.wscale_gain:.3f}, ' + f'bias={self.add_bias}, ' + f'init_bias={self.init_bias}, ' + f'lr_mul={self.lr_mul:.3f}, ' + f'act={self.activation_type}') + + def forward(self, x, impl='cuda'): + dtype = x.dtype + + if x.ndim != 2: + x = x.flatten(start_dim=1) + + weight = self.weight.to(dtype) * self.wscale + bias = None + if self.bias is not None: + bias = self.bias.to(dtype) + if self.bscale != 1.0: + bias = bias * self.bscale + + # Fast pass for linear activation. + if self.activation_type == 'linear' and bias is not None: + x = torch.addmm(bias.unsqueeze(0), x, weight.t()) + else: + x = x.matmul(weight.t()) + x = bias_act.bias_act(x, bias, act=self.activation_type, impl=impl) + + assert x.dtype == dtype + return x + +# pylint: enable=missing-function-docstring diff --git a/models/stylegan2_generator.py b/models/stylegan2_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..50472f326dfa4e198c40b6bcaf95d03750d0ff63 --- /dev/null +++ b/models/stylegan2_generator.py @@ -0,0 +1,1394 @@ +# python3.7 +"""Contains the implementation of generator described in StyleGAN2. + +Compared to that of StyleGAN, the generator in StyleGAN2 mainly introduces style +demodulation, adds skip connections, increases model size, and disables +progressive growth. This script ONLY supports config F in the original paper. 
+ +Paper: https://arxiv.org/pdf/1912.04958.pdf + +Official TensorFlow implementation: https://github.com/NVlabs/stylegan2 +""" + +import numpy as np + +import torch +import torch.nn as nn + +from third_party.stylegan2_official_ops import fma +from third_party.stylegan2_official_ops import bias_act +from third_party.stylegan2_official_ops import upfirdn2d +from third_party.stylegan2_official_ops import conv2d_gradfix +from .utils.ops import all_gather + +__all__ = ['StyleGAN2Generator'] + +# Resolutions allowed. +_RESOLUTIONS_ALLOWED = [8, 16, 32, 64, 128, 256, 512, 1024] + +# Architectures allowed. +_ARCHITECTURES_ALLOWED = ['resnet', 'skip', 'origin'] + +# pylint: disable=missing-function-docstring + +class StyleGAN2Generator(nn.Module): + """Defines the generator network in StyleGAN2. + + NOTE: The synthesized images are with `RGB` channel order and pixel range + [-1, 1]. + + Settings for the mapping network: + + (1) z_dim: Dimension of the input latent space, Z. (default: 512) + (2) w_dim: Dimension of the output latent space, W. (default: 512) + (3) repeat_w: Repeat w-code for different layers. (default: True) + (4) normalize_z: Whether to normalize the z-code. (default: True) + (5) mapping_layers: Number of layers of the mapping network. (default: 8) + (6) mapping_fmaps: Number of hidden channels of the mapping network. + (default: 512) + (7) mapping_use_wscale: Whether to use weight scaling for the mapping + network. (default: True) + (8) mapping_wscale_gain: The factor to control weight scaling for the + mapping network (default: 1.0) + (9) mapping_lr_mul: Learning rate multiplier for the mapping network. + (default: 0.01) + + Settings for conditional generation: + + (1) label_dim: Dimension of the additional label for conditional generation. + In one-hot conditioning case, it is equal to the number of classes. If + set to 0, conditioning training will be disabled. (default: 0) + (2) embedding_dim: Dimension of the embedding space, if needed. + (default: 512) + (3) embedding_bias: Whether to add bias to embedding learning. + (default: True) + (4) embedding_use_wscale: Whether to use weight scaling for embedding + learning. (default: True) + (5) embedding_wscale_gain: The factor to control weight scaling for + embedding. (default: 1.0) + (6) embedding_lr_mul: Learning rate multiplier for the embedding learning. + (default: 1.0) + (7) normalize_embedding: Whether to normalize the embedding. (default: True) + (8) normalize_embedding_latent: Whether to normalize the embedding together + with the latent. (default: False) + + Settings for the synthesis network: + + (1) resolution: The resolution of the output image. (default: -1) + (2) init_res: The initial resolution to start with convolution. (default: 4) + (3) image_channels: Number of channels of the output image. (default: 3) + (4) final_tanh: Whether to use `tanh` to control the final pixel range. + (default: False) + (5) const_input: Whether to use a constant in the first convolutional layer. + (default: True) + (6) architecture: Type of architecture. Support `origin`, `skip`, and + `resnet`. (default: `skip`) + (7) demodulate: Whether to perform style demodulation. (default: True) + (8) use_wscale: Whether to use weight scaling. (default: True) + (9) wscale_gain: The factor to control weight scaling. (default: 1.0) + (10) lr_mul: Learning rate multiplier for the synthesis network. + (default: 1.0) + (11) noise_type: Type of noise added to the convolutional results at each + layer. 
(default: `spatial`) + (12) fmaps_base: Factor to control number of feature maps for each layer. + (default: 32 << 10) + (13) fmaps_max: Maximum number of feature maps in each layer. (default: 512) + (14) filter_kernel: Kernel used for filtering (e.g., downsampling). + (default: (1, 3, 3, 1)) + (15) conv_clamp: A threshold to clamp the output of convolution layers to + avoid overflow under FP16 training. (default: None) + (16) eps: A small value to avoid divide overflow. (default: 1e-8) + + Runtime settings: + + (1) w_moving_decay: Decay factor for updating `w_avg`, which is used for + training only. Set `None` to disable. (default: None) + (2) sync_w_avg: Synchronizing the stats of `w_avg` across replicas. If set + as `True`, the stats will be more accurate, yet the speed maybe a little + bit slower. (default: False) + (3) style_mixing_prob: Probability to perform style mixing as a training + regularization. Set `None` to disable. (default: None) + (4) trunc_psi: Truncation psi, set `None` to disable. (default: None) + (5) trunc_layers: Number of layers to perform truncation. (default: None) + (6) noise_mode: Mode of the layer-wise noise. Support `none`, `random`, + `const`. (default: `const`) + (7) fused_modulate: Whether to fuse `style_modulate` and `conv2d` together. + (default: False) + (8) fp16_res: Layers at resolution higher than (or equal to) this field will + use `float16` precision for computation. This is merely used for + acceleration. If set as `None`, all layers will use `float32` by + default. (default: None) + (9) impl: Implementation mode of some particular ops, e.g., `filtering`, + `bias_act`, etc. `cuda` means using the official CUDA implementation + from StyleGAN2, while `ref` means using the native PyTorch ops. + (default: `cuda`) + """ + + def __init__(self, + # Settings for mapping network. + z_dim=512, + w_dim=512, + repeat_w=True, + normalize_z=True, + mapping_layers=8, + mapping_fmaps=512, + mapping_use_wscale=True, + mapping_wscale_gain=1.0, + mapping_lr_mul=0.01, + # Settings for conditional generation. + label_dim=0, + embedding_dim=512, + embedding_bias=True, + embedding_use_wscale=True, + embedding_wscale_gian=1.0, + embedding_lr_mul=1.0, + normalize_embedding=True, + normalize_embedding_latent=False, + # Settings for synthesis network. + resolution=-1, + init_res=4, + image_channels=3, + final_tanh=False, + const_input=True, + architecture='skip', + demodulate=True, + use_wscale=True, + wscale_gain=1.0, + lr_mul=1.0, + noise_type='spatial', + fmaps_base=32 << 10, + fmaps_max=512, + filter_kernel=(1, 3, 3, 1), + conv_clamp=None, + eps=1e-8): + """Initializes with basic settings. + + Raises: + ValueError: If the `resolution` is not supported, or `architecture` + is not supported. 
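# --- Illustrative usage sketch (not part of the patch): building the generator
# defined here and synthesizing a batch. `impl='ref'` selects the native
# PyTorch ops per the class docstring (assuming the bundled third_party ops
# import cleanly); resolution and batch size are made-up.
import torch

G = StyleGAN2Generator(resolution=256)
z = torch.randn(4, G.z_dim)
out = G(z, impl='ref')
print(out['image'].shape)   # torch.Size([4, 3, 256, 256])
print(out['wp'].shape)      # torch.Size([4, 14, 512]); 14 layers at 256x256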
+ """ + super().__init__() + + if resolution not in _RESOLUTIONS_ALLOWED: + raise ValueError(f'Invalid resolution: `{resolution}`!\n' + f'Resolutions allowed: {_RESOLUTIONS_ALLOWED}.') + architecture = architecture.lower() + if architecture not in _ARCHITECTURES_ALLOWED: + raise ValueError(f'Invalid architecture: `{architecture}`!\n' + f'Architectures allowed: ' + f'{_ARCHITECTURES_ALLOWED}.') + + self.z_dim = z_dim + self.w_dim = w_dim + self.repeat_w = repeat_w + self.normalize_z = normalize_z + self.mapping_layers = mapping_layers + self.mapping_fmaps = mapping_fmaps + self.mapping_use_wscale = mapping_use_wscale + self.mapping_wscale_gain = mapping_wscale_gain + self.mapping_lr_mul = mapping_lr_mul + + self.label_dim = label_dim + self.embedding_dim = embedding_dim + self.embedding_bias = embedding_bias + self.embedding_use_wscale = embedding_use_wscale + self.embedding_wscale_gain = embedding_wscale_gian + self.embedding_lr_mul = embedding_lr_mul + self.normalize_embedding = normalize_embedding + self.normalize_embedding_latent = normalize_embedding_latent + + self.resolution = resolution + self.init_res = init_res + self.image_channels = image_channels + self.final_tanh = final_tanh + self.const_input = const_input + self.architecture = architecture + self.demodulate = demodulate + self.use_wscale = use_wscale + self.wscale_gain = wscale_gain + self.lr_mul = lr_mul + self.noise_type = noise_type.lower() + self.fmaps_base = fmaps_base + self.fmaps_max = fmaps_max + self.filter_kernel = filter_kernel + self.conv_clamp = conv_clamp + self.eps = eps + + # Dimension of latent space, which is convenient for sampling. + self.latent_dim = (z_dim,) + + # Number of synthesis (convolutional) layers. + self.num_layers = int(np.log2(resolution // init_res * 2)) * 2 + + self.mapping = MappingNetwork( + input_dim=z_dim, + output_dim=w_dim, + num_outputs=self.num_layers, + repeat_output=repeat_w, + normalize_input=normalize_z, + num_layers=mapping_layers, + hidden_dim=mapping_fmaps, + use_wscale=mapping_use_wscale, + wscale_gain=mapping_wscale_gain, + lr_mul=mapping_lr_mul, + label_dim=label_dim, + embedding_dim=embedding_dim, + embedding_bias=embedding_bias, + embedding_use_wscale=embedding_use_wscale, + embedding_wscale_gian=embedding_wscale_gian, + embedding_lr_mul=embedding_lr_mul, + normalize_embedding=normalize_embedding, + normalize_embedding_latent=normalize_embedding_latent, + eps=eps) + + # This is used for truncation trick. + if self.repeat_w: + self.register_buffer('w_avg', torch.zeros(w_dim)) + else: + self.register_buffer('w_avg', torch.zeros(self.num_layers * w_dim)) + + self.synthesis = SynthesisNetwork(resolution=resolution, + init_res=init_res, + w_dim=w_dim, + image_channels=image_channels, + final_tanh=final_tanh, + const_input=const_input, + architecture=architecture, + demodulate=demodulate, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + lr_mul=lr_mul, + noise_type=noise_type, + fmaps_base=fmaps_base, + filter_kernel=filter_kernel, + fmaps_max=fmaps_max, + conv_clamp=conv_clamp, + eps=eps) + + self.pth_to_tf_var_mapping = {'w_avg': 'dlatent_avg'} + for key, val in self.mapping.pth_to_tf_var_mapping.items(): + self.pth_to_tf_var_mapping[f'mapping.{key}'] = val + for key, val in self.synthesis.pth_to_tf_var_mapping.items(): + self.pth_to_tf_var_mapping[f'synthesis.{key}'] = val + + def set_space_of_latent(self, space_of_latent): + """Sets the space to which the latent code belong. + + See `SynthesisNetwork` for more details. 
+ """ + self.synthesis.set_space_of_latent(space_of_latent) + + def forward(self, + z, + label=None, + w_moving_decay=None, + sync_w_avg=False, + style_mixing_prob=None, + trunc_psi=None, + trunc_layers=None, + noise_mode='const', + fused_modulate=False, + fp16_res=None, + impl='cuda'): + """Connects mapping network and synthesis network. + + This forward function will also update the average `w_code`, perform + style mixing as a training regularizer, and do truncation trick, which + is specially designed for inference. + + Concretely, the truncation trick acts as follows: + + For layers in range [0, truncation_layers), the truncated w-code is + computed as + + w_new = w_avg + (w - w_avg) * trunc_psi + + To disable truncation, please set + + (1) trunc_psi = 1.0 (None) OR + (2) trunc_layers = 0 (None) + """ + + mapping_results = self.mapping(z, label, impl=impl) + + w = mapping_results['w'] + if self.training and w_moving_decay is not None: + if sync_w_avg: + batch_w_avg = all_gather(w.detach()).mean(dim=0) + else: + batch_w_avg = w.detach().mean(dim=0) + self.w_avg.copy_(batch_w_avg.lerp(self.w_avg, w_moving_decay)) + + wp = mapping_results.pop('wp') + if self.training and style_mixing_prob is not None: + if np.random.uniform() < style_mixing_prob: + new_z = torch.randn_like(z) + new_wp = self.mapping(new_z, label, impl=impl)['wp'] + mixing_cutoff = np.random.randint(1, self.num_layers) + wp[:, mixing_cutoff:] = new_wp[:, mixing_cutoff:] + + if not self.training: + trunc_psi = 1.0 if trunc_psi is None else trunc_psi + trunc_layers = 0 if trunc_layers is None else trunc_layers + if trunc_psi < 1.0 and trunc_layers > 0: + w_avg = self.w_avg.reshape(1, -1, self.w_dim)[:, :trunc_layers] + wp[:, :trunc_layers] = w_avg.lerp( + wp[:, :trunc_layers], trunc_psi) + + synthesis_results = self.synthesis(wp, + noise_mode=noise_mode, + fused_modulate=fused_modulate, + impl=impl, + fp16_res=fp16_res) + + return {**mapping_results, **synthesis_results} + + +class MappingNetwork(nn.Module): + """Implements the latent space mapping network. + + Basically, this network executes several dense layers in sequence, and the + label embedding if needed. 
+ """ + + def __init__(self, + input_dim, + output_dim, + num_outputs, + repeat_output, + normalize_input, + num_layers, + hidden_dim, + use_wscale, + wscale_gain, + lr_mul, + label_dim, + embedding_dim, + embedding_bias, + embedding_use_wscale, + embedding_wscale_gian, + embedding_lr_mul, + normalize_embedding, + normalize_embedding_latent, + eps): + super().__init__() + + self.input_dim = input_dim + self.output_dim = output_dim + self.num_outputs = num_outputs + self.repeat_output = repeat_output + self.normalize_input = normalize_input + self.num_layers = num_layers + self.hidden_dim = hidden_dim + self.use_wscale = use_wscale + self.wscale_gain = wscale_gain + self.lr_mul = lr_mul + self.label_dim = label_dim + self.embedding_dim = embedding_dim + self.embedding_bias = embedding_bias + self.embedding_use_wscale = embedding_use_wscale + self.embedding_wscale_gian = embedding_wscale_gian + self.embedding_lr_mul = embedding_lr_mul + self.normalize_embedding = normalize_embedding + self.normalize_embedding_latent = normalize_embedding_latent + self.eps = eps + + self.pth_to_tf_var_mapping = {} + + self.norm = PixelNormLayer(dim=1, eps=eps) + + if self.label_dim > 0: + input_dim = input_dim + embedding_dim + self.embedding = DenseLayer(in_channels=label_dim, + out_channels=embedding_dim, + add_bias=embedding_bias, + init_bias=0.0, + use_wscale=embedding_use_wscale, + wscale_gain=embedding_wscale_gian, + lr_mul=embedding_lr_mul, + activation_type='linear') + self.pth_to_tf_var_mapping['embedding.weight'] = 'LabelEmbed/weight' + if self.embedding_bias: + self.pth_to_tf_var_mapping['embedding.bias'] = 'LabelEmbed/bias' + + if num_outputs is not None and not repeat_output: + output_dim = output_dim * num_outputs + for i in range(num_layers): + in_channels = (input_dim if i == 0 else hidden_dim) + out_channels = (output_dim if i == (num_layers - 1) else hidden_dim) + self.add_module(f'dense{i}', + DenseLayer(in_channels=in_channels, + out_channels=out_channels, + add_bias=True, + init_bias=0.0, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + lr_mul=lr_mul, + activation_type='lrelu')) + self.pth_to_tf_var_mapping[f'dense{i}.weight'] = f'Dense{i}/weight' + self.pth_to_tf_var_mapping[f'dense{i}.bias'] = f'Dense{i}/bias' + + def forward(self, z, label=None, impl='cuda'): + if z.ndim != 2 or z.shape[1] != self.input_dim: + raise ValueError(f'Input latent code should be with shape ' + f'[batch_size, input_dim], where ' + f'`input_dim` equals to {self.input_dim}!\n' + f'But `{z.shape}` is received!') + if self.normalize_input: + z = self.norm(z) + + if self.label_dim > 0: + if label is None: + raise ValueError(f'Model requires an additional label ' + f'(with dimension {self.label_dim}) as input, ' + f'but no label is received!') + if label.ndim != 2 or label.shape != (z.shape[0], self.label_dim): + raise ValueError(f'Input label should be with shape ' + f'[batch_size, label_dim], where ' + f'`batch_size` equals to that of ' + f'latent codes ({z.shape[0]}) and ' + f'`label_dim` equals to {self.label_dim}!\n' + f'But `{label.shape}` is received!') + label = label.to(dtype=torch.float32) + embedding = self.embedding(label, impl=impl) + if self.normalize_embedding: + embedding = self.norm(embedding) + w = torch.cat((z, embedding), dim=1) + else: + w = z + + if self.label_dim > 0 and self.normalize_embedding_latent: + w = self.norm(w) + + for i in range(self.num_layers): + w = getattr(self, f'dense{i}')(w, impl=impl) + + wp = None + if self.num_outputs is not None: + if self.repeat_output: + wp = 
w.unsqueeze(1).repeat((1, self.num_outputs, 1)) + else: + wp = w.reshape(-1, self.num_outputs, self.output_dim) + + results = { + 'z': z, + 'label': label, + 'w': w, + 'wp': wp, + } + if self.label_dim > 0: + results['embedding'] = embedding + return results + + +class SynthesisNetwork(nn.Module): + """Implements the image synthesis network. + + Basically, this network executes several convolutional layers in sequence. + """ + + def __init__(self, + resolution, + init_res, + w_dim, + image_channels, + final_tanh, + const_input, + architecture, + demodulate, + use_wscale, + wscale_gain, + lr_mul, + noise_type, + fmaps_base, + fmaps_max, + filter_kernel, + conv_clamp, + eps): + super().__init__() + + self.init_res = init_res + self.init_res_log2 = int(np.log2(init_res)) + self.resolution = resolution + self.final_res_log2 = int(np.log2(resolution)) + self.w_dim = w_dim + self.image_channels = image_channels + self.final_tanh = final_tanh + self.const_input = const_input + self.architecture = architecture.lower() + self.demodulate = demodulate + self.use_wscale = use_wscale + self.wscale_gain = wscale_gain + self.lr_mul = lr_mul + self.noise_type = noise_type.lower() + self.fmaps_base = fmaps_base + self.fmaps_max = fmaps_max + self.filter_kernel = filter_kernel + self.conv_clamp = conv_clamp + self.eps = eps + + self.num_layers = (self.final_res_log2 - self.init_res_log2 + 1) * 2 + + self.pth_to_tf_var_mapping = {} + + for res_log2 in range(self.init_res_log2, self.final_res_log2 + 1): + res = 2 ** res_log2 + in_channels = self.get_nf(res // 2) + out_channels = self.get_nf(res) + block_idx = res_log2 - self.init_res_log2 + + # Early layer. + if res == init_res: + if self.const_input: + self.add_module('early_layer', + InputLayer(init_res=res, + channels=out_channels)) + self.pth_to_tf_var_mapping['early_layer.const'] = ( + f'{res}x{res}/Const/const') + else: + channels = out_channels * res * res + self.add_module('early_layer', + DenseLayer(in_channels=w_dim, + out_channels=channels, + add_bias=True, + init_bias=0.0, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + lr_mul=lr_mul, + activation_type='lrelu')) + self.pth_to_tf_var_mapping['early_layer.weight'] = ( + f'{res}x{res}/Dense/weight') + self.pth_to_tf_var_mapping['early_layer.bias'] = ( + f'{res}x{res}/Dense/bias') + else: + # Residual branch (kernel 1x1) with upsampling, without bias, + # with linear activation. + if self.architecture == 'resnet': + layer_name = f'residual{block_idx}' + self.add_module(layer_name, + ConvLayer(in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + add_bias=False, + scale_factor=2, + filter_kernel=filter_kernel, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + lr_mul=lr_mul, + activation_type='linear', + conv_clamp=None)) + self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = ( + f'{res}x{res}/Skip/weight') + + # First layer (kernel 3x3) with upsampling. 
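# --- Illustrative sketch (not part of the patch): how the mapping output above
# becomes per-layer codes. With repeat_output=True every synthesis layer shares
# the same w; otherwise the mapping network emits num_outputs * w_dim values
# that are reshaped into distinct per-layer codes. Shapes with made-up sizes:
import torch

num_outputs, w_dim = 14, 512
w_shared = torch.randn(4, w_dim)
wp_repeat = w_shared.unsqueeze(1).repeat(1, num_outputs, 1)   # [4, 14, 512]
w_stacked = torch.randn(4, num_outputs * w_dim)
wp_distinct = w_stacked.reshape(-1, num_outputs, w_dim)       # [4, 14, 512]
assert wp_repeat.shape == wp_distinct.shape == (4, 14, 512)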
+ layer_name = f'layer{2 * block_idx - 1}' + self.add_module(layer_name, + ModulateConvLayer(in_channels=in_channels, + out_channels=out_channels, + resolution=res, + w_dim=w_dim, + kernel_size=3, + add_bias=True, + scale_factor=2, + filter_kernel=filter_kernel, + demodulate=demodulate, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + lr_mul=lr_mul, + noise_type=noise_type, + activation_type='lrelu', + conv_clamp=conv_clamp, + eps=eps)) + self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = ( + f'{res}x{res}/Conv0_up/weight') + self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = ( + f'{res}x{res}/Conv0_up/bias') + self.pth_to_tf_var_mapping[f'{layer_name}.style.weight'] = ( + f'{res}x{res}/Conv0_up/mod_weight') + self.pth_to_tf_var_mapping[f'{layer_name}.style.bias'] = ( + f'{res}x{res}/Conv0_up/mod_bias') + self.pth_to_tf_var_mapping[f'{layer_name}.noise_strength'] = ( + f'{res}x{res}/Conv0_up/noise_strength') + self.pth_to_tf_var_mapping[f'{layer_name}.noise'] = ( + f'noise{2 * block_idx - 1}') + + # Second layer (kernel 3x3) without upsampling. + layer_name = f'layer{2 * block_idx}' + self.add_module(layer_name, + ModulateConvLayer(in_channels=out_channels, + out_channels=out_channels, + resolution=res, + w_dim=w_dim, + kernel_size=3, + add_bias=True, + scale_factor=1, + filter_kernel=None, + demodulate=demodulate, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + lr_mul=lr_mul, + noise_type=noise_type, + activation_type='lrelu', + conv_clamp=conv_clamp, + eps=eps)) + tf_layer_name = 'Conv' if res == self.init_res else 'Conv1' + self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = ( + f'{res}x{res}/{tf_layer_name}/weight') + self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = ( + f'{res}x{res}/{tf_layer_name}/bias') + self.pth_to_tf_var_mapping[f'{layer_name}.style.weight'] = ( + f'{res}x{res}/{tf_layer_name}/mod_weight') + self.pth_to_tf_var_mapping[f'{layer_name}.style.bias'] = ( + f'{res}x{res}/{tf_layer_name}/mod_bias') + self.pth_to_tf_var_mapping[f'{layer_name}.noise_strength'] = ( + f'{res}x{res}/{tf_layer_name}/noise_strength') + self.pth_to_tf_var_mapping[f'{layer_name}.noise'] = ( + f'noise{2 * block_idx}') + + # Output convolution layer for each resolution (if needed). + if res_log2 == self.final_res_log2 or self.architecture == 'skip': + layer_name = f'output{block_idx}' + self.add_module(layer_name, + ModulateConvLayer(in_channels=out_channels, + out_channels=image_channels, + resolution=res, + w_dim=w_dim, + kernel_size=1, + add_bias=True, + scale_factor=1, + filter_kernel=None, + demodulate=False, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + lr_mul=lr_mul, + noise_type='none', + activation_type='linear', + conv_clamp=conv_clamp, + eps=eps)) + self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = ( + f'{res}x{res}/ToRGB/weight') + self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = ( + f'{res}x{res}/ToRGB/bias') + self.pth_to_tf_var_mapping[f'{layer_name}.style.weight'] = ( + f'{res}x{res}/ToRGB/mod_weight') + self.pth_to_tf_var_mapping[f'{layer_name}.style.bias'] = ( + f'{res}x{res}/ToRGB/mod_bias') + + # Used for upsampling output images for each resolution block for sum. + if self.architecture == 'skip': + self.register_buffer( + 'filter', upfirdn2d.setup_filter(filter_kernel)) + + def get_nf(self, res): + """Gets number of feature maps according to the given resolution.""" + return min(self.fmaps_base // res, self.fmaps_max) + + def set_space_of_latent(self, space_of_latent): + """Sets the space to which the latent code belong. 
+ + This function is particularly used for choosing how to inject the latent + code into the convolutional layers. The original generator will take a + W-Space code and apply it for style modulation after an affine + transformation. But, sometimes, it may need to directly feed an already + affine-transformed code into the convolutional layer, e.g., when + training an encoder for GAN inversion. We term the transformed space as + Style Space (or Y-Space). This function is designed to tell the + convolutional layers how to use the input code. + + Args: + space_of_latent: The space to which the latent code belong. Case + insensitive. Support `W` and `Y`. + """ + space_of_latent = space_of_latent.upper() + for module in self.modules(): + if isinstance(module, ModulateConvLayer): + setattr(module, 'space_of_latent', space_of_latent) + + def forward(self, + wp, + noise_mode='const', + fused_modulate=False, + fp16_res=None, + impl='cuda'): + results = {'wp': wp} + + if self.const_input: + x = self.early_layer(wp[:, 0]) + else: + x = self.early_layer(wp[:, 0], impl=impl) + + # Cast to `torch.float16` if needed. + if fp16_res is not None and self.init_res >= fp16_res: + x = x.to(torch.float16) + + if self.architecture == 'origin': + for layer_idx in range(self.num_layers - 1): + layer = getattr(self, f'layer{layer_idx}') + x, style = layer(x, + wp[:, layer_idx], + noise_mode=noise_mode, + fused_modulate=fused_modulate, + impl=impl) + results[f'style{layer_idx}'] = style + + # Cast to `torch.float16` if needed. + if layer_idx % 2 == 0 and layer_idx != self.num_layers - 2: + res = self.init_res * (2 ** (layer_idx // 2)) + if fp16_res is not None and res * 2 >= fp16_res: + x = x.to(torch.float16) + else: + x = x.to(torch.float32) + output_layer = getattr(self, f'output{layer_idx // 2}') + image, style = output_layer(x, + wp[:, layer_idx + 1], + fused_modulate=fused_modulate, + impl=impl) + image = image.to(torch.float32) + results[f'output_style{layer_idx // 2}'] = style + + elif self.architecture == 'skip': + for layer_idx in range(self.num_layers - 1): + layer = getattr(self, f'layer{layer_idx}') + x, style = layer(x, + wp[:, layer_idx], + noise_mode=noise_mode, + fused_modulate=fused_modulate, + impl=impl) + results[f'style{layer_idx}'] = style + if layer_idx % 2 == 0: + output_layer = getattr(self, f'output{layer_idx // 2}') + y, style = output_layer(x, + wp[:, layer_idx + 1], + fused_modulate=fused_modulate, + impl=impl) + results[f'output_style{layer_idx // 2}'] = style + if layer_idx == 0: + image = y.to(torch.float32) + else: + image = y.to(torch.float32) + upfirdn2d.upsample2d( + image, self.filter, impl=impl) + + # Cast to `torch.float16` if needed. + if layer_idx != self.num_layers - 2: + res = self.init_res * (2 ** (layer_idx // 2)) + if fp16_res is not None and res * 2 >= fp16_res: + x = x.to(torch.float16) + else: + x = x.to(torch.float32) + + elif self.architecture == 'resnet': + x, style = self.layer0(x, + wp[:, 0], + noise_mode=noise_mode, + fused_modulate=fused_modulate, + impl=impl) + results['style0'] = style + for layer_idx in range(1, self.num_layers - 1, 2): + # Cast to `torch.float16` if needed. 
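+ # A residual block runs in `float16` only when its output resolution
+ # (`res * 2`) is at least `fp16_res`; otherwise it stays in `float32`.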
+ if layer_idx % 2 == 1: + res = self.init_res * (2 ** (layer_idx // 2)) + if fp16_res is not None and res * 2 >= fp16_res: + x = x.to(torch.float16) + else: + x = x.to(torch.float32) + + skip_layer = getattr(self, f'residual{layer_idx // 2 + 1}') + residual = skip_layer(x, runtime_gain=np.sqrt(0.5), impl=impl) + layer = getattr(self, f'layer{layer_idx}') + x, style = layer(x, + wp[:, layer_idx], + noise_mode=noise_mode, + fused_modulate=fused_modulate, + impl=impl) + results[f'style{layer_idx}'] = style + layer = getattr(self, f'layer{layer_idx + 1}') + x, style = layer(x, + wp[:, layer_idx + 1], + runtime_gain=np.sqrt(0.5), + noise_mode=noise_mode, + fused_modulate=fused_modulate, + impl=impl) + results[f'style{layer_idx + 1}'] = style + x = x + residual + output_layer = getattr(self, f'output{layer_idx // 2 + 1}') + image, style = output_layer(x, + wp[:, layer_idx + 2], + fused_modulate=fused_modulate, + impl=impl) + image = image.to(torch.float32) + results[f'output_style{layer_idx // 2}'] = style + + if self.final_tanh: + image = torch.tanh(image) + results['image'] = image + return results + + +class PixelNormLayer(nn.Module): + """Implements pixel-wise feature vector normalization layer.""" + + def __init__(self, dim, eps): + super().__init__() + self.dim = dim + self.eps = eps + + def extra_repr(self): + return f'dim={self.dim}, epsilon={self.eps}' + + def forward(self, x): + scale = (x.square().mean(dim=self.dim, keepdim=True) + self.eps).rsqrt() + return x * scale + + +class InputLayer(nn.Module): + """Implements the input layer to start convolution with. + + Basically, this block starts from a const input, which is with shape + `(channels, init_res, init_res)`. + """ + + def __init__(self, init_res, channels): + super().__init__() + self.const = nn.Parameter(torch.randn(1, channels, init_res, init_res)) + + def forward(self, w): + x = self.const.repeat(w.shape[0], 1, 1, 1) + return x + + +class ConvLayer(nn.Module): + """Implements the convolutional layer. + + If upsampling is needed (i.e., `scale_factor = 2`), the feature map will + be filtered with `filter_kernel` after convolution. This layer will only be + used for skip connection in `resnet` architecture. + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + add_bias, + scale_factor, + filter_kernel, + use_wscale, + wscale_gain, + lr_mul, + activation_type, + conv_clamp): + """Initializes with layer settings. + + Args: + in_channels: Number of channels of the input tensor. + out_channels: Number of channels of the output tensor. + kernel_size: Size of the convolutional kernels. + add_bias: Whether to add bias onto the convolutional result. + scale_factor: Scale factor for upsampling. + filter_kernel: Kernel used for filtering. + use_wscale: Whether to use weight scaling. + wscale_gain: Gain factor for weight scaling. + lr_mul: Learning multiplier for both weight and bias. + activation_type: Type of activation. + conv_clamp: A threshold to clamp the output of convolution layers to + avoid overflow under FP16 training. 
+ """ + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.add_bias = add_bias + self.scale_factor = scale_factor + self.filter_kernel = filter_kernel + self.use_wscale = use_wscale + self.wscale_gain = wscale_gain + self.lr_mul = lr_mul + self.activation_type = activation_type + self.conv_clamp = conv_clamp + + weight_shape = (out_channels, in_channels, kernel_size, kernel_size) + fan_in = kernel_size * kernel_size * in_channels + wscale = wscale_gain / np.sqrt(fan_in) + if use_wscale: + self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul) + self.wscale = wscale * lr_mul + else: + self.weight = nn.Parameter( + torch.randn(*weight_shape) * wscale / lr_mul) + self.wscale = lr_mul + + if add_bias: + self.bias = nn.Parameter(torch.zeros(out_channels)) + self.bscale = lr_mul + else: + self.bias = None + self.act_gain = bias_act.activation_funcs[activation_type].def_gain + + if scale_factor > 1: + assert filter_kernel is not None + self.register_buffer( + 'filter', upfirdn2d.setup_filter(filter_kernel)) + fh, fw = self.filter.shape + self.filter_padding = ( + kernel_size // 2 + (fw + scale_factor - 1) // 2, + kernel_size // 2 + (fw - scale_factor) // 2, + kernel_size // 2 + (fh + scale_factor - 1) // 2, + kernel_size // 2 + (fh - scale_factor) // 2) + + def extra_repr(self): + return (f'in_ch={self.in_channels}, ' + f'out_ch={self.out_channels}, ' + f'ksize={self.kernel_size}, ' + f'wscale_gain={self.wscale_gain:.3f}, ' + f'bias={self.add_bias}, ' + f'lr_mul={self.lr_mul:.3f}, ' + f'upsample={self.scale_factor}, ' + f'upsample_filter={self.filter_kernel}, ' + f'act={self.activation_type}, ' + f'clamp={self.conv_clamp}') + + def forward(self, x, runtime_gain=1.0, impl='cuda'): + dtype = x.dtype + + weight = self.weight + if self.wscale != 1.0: + weight = weight * self.wscale + bias = None + if self.bias is not None: + bias = self.bias.to(dtype) + if self.bscale != 1.0: + bias = bias * self.bscale + + if self.scale_factor == 1: # Native convolution without upsampling. + padding = self.kernel_size // 2 + x = conv2d_gradfix.conv2d( + x, weight.to(dtype), stride=1, padding=padding, impl=impl) + else: # Convolution with upsampling. + up = self.scale_factor + f = self.filter + # When kernel size = 1, use filtering function for upsampling. + if self.kernel_size == 1: + padding = self.filter_padding + x = conv2d_gradfix.conv2d( + x, weight.to(dtype), stride=1, padding=0, impl=impl) + x = upfirdn2d.upfirdn2d( + x, f, up=up, padding=padding, gain=up ** 2, impl=impl) + # When kernel size != 1, use transpose convolution for upsampling. 
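+ # The padding arithmetic inside this branch rewrites the `upfirdn2d`
+ # padding as an equivalent `conv_transpose2d` padding plus a residual
+ # pad handled by the follow-up filtering call.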
+ else: + # Following codes are borrowed from + # https://github.com/NVlabs/stylegan2-ada-pytorch + px0, px1, py0, py1 = self.filter_padding + kh, kw = weight.shape[2:] + px0 = px0 - (kw - 1) + px1 = px1 - (kw - up) + py0 = py0 - (kh - 1) + py1 = py1 - (kh - up) + pxt = max(min(-px0, -px1), 0) + pyt = max(min(-py0, -py1), 0) + weight = weight.transpose(0, 1) + padding = (pyt, pxt) + x = conv2d_gradfix.conv_transpose2d( + x, weight.to(dtype), stride=up, padding=padding, impl=impl) + padding = (px0 + pxt, px1 + pxt, py0 + pyt, py1 + pyt) + x = upfirdn2d.upfirdn2d( + x, f, up=1, padding=padding, gain=up ** 2, impl=impl) + + act_gain = self.act_gain * runtime_gain + act_clamp = None + if self.conv_clamp is not None: + act_clamp = self.conv_clamp * runtime_gain + x = bias_act.bias_act(x, bias, + act=self.activation_type, + gain=act_gain, + clamp=act_clamp, + impl=impl) + + assert x.dtype == dtype + return x + + +class ModulateConvLayer(nn.Module): + """Implements the convolutional layer with style modulation.""" + + def __init__(self, + in_channels, + out_channels, + resolution, + w_dim, + kernel_size, + add_bias, + scale_factor, + filter_kernel, + demodulate, + use_wscale, + wscale_gain, + lr_mul, + noise_type, + activation_type, + conv_clamp, + eps): + """Initializes with layer settings. + + Args: + in_channels: Number of channels of the input tensor. + out_channels: Number of channels of the output tensor. + resolution: Resolution of the output tensor. + w_dim: Dimension of W space for style modulation. + kernel_size: Size of the convolutional kernels. + add_bias: Whether to add bias onto the convolutional result. + scale_factor: Scale factor for upsampling. + filter_kernel: Kernel used for filtering. + demodulate: Whether to perform style demodulation. + use_wscale: Whether to use weight scaling. + wscale_gain: Gain factor for weight scaling. + lr_mul: Learning multiplier for both weight and bias. + noise_type: Type of noise added to the feature map after the + convolution (if needed). Support `none`, `spatial` and + `channel`. + activation_type: Type of activation. + conv_clamp: A threshold to clamp the output of convolution layers to + avoid overflow under FP16 training. + eps: A small value to avoid divide overflow. + """ + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.resolution = resolution + self.w_dim = w_dim + self.kernel_size = kernel_size + self.add_bias = add_bias + self.scale_factor = scale_factor + self.filter_kernel = filter_kernel + self.demodulate = demodulate + self.use_wscale = use_wscale + self.wscale_gain = wscale_gain + self.lr_mul = lr_mul + self.noise_type = noise_type.lower() + self.activation_type = activation_type + self.conv_clamp = conv_clamp + self.eps = eps + + self.space_of_latent = 'W' + + # Set up weight. + weight_shape = (out_channels, in_channels, kernel_size, kernel_size) + fan_in = kernel_size * kernel_size * in_channels + wscale = wscale_gain / np.sqrt(fan_in) + if use_wscale: + self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul) + self.wscale = wscale * lr_mul + else: + self.weight = nn.Parameter( + torch.randn(*weight_shape) * wscale / lr_mul) + self.wscale = lr_mul + + # Set up bias. + if add_bias: + self.bias = nn.Parameter(torch.zeros(out_channels)) + self.bscale = lr_mul + else: + self.bias = None + self.act_gain = bias_act.activation_funcs[activation_type].def_gain + + # Set up style. 
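+ # The style branch is an affine projection from the w-code (`w_dim`
+ # dimensions) to one scale per input channel; `init_bias=1.0` makes the
+ # modulation start out close to an identity scaling.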
+ self.style = DenseLayer(in_channels=w_dim, + out_channels=in_channels, + add_bias=True, + init_bias=1.0, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + lr_mul=lr_mul, + activation_type='linear') + + # Set up noise. + if self.noise_type != 'none': + self.noise_strength = nn.Parameter(torch.zeros(())) + if self.noise_type == 'spatial': + self.register_buffer( + 'noise', torch.randn(1, 1, resolution, resolution)) + elif self.noise_type == 'channel': + self.register_buffer( + 'noise', torch.randn(1, out_channels, 1, 1)) + else: + raise NotImplementedError(f'Not implemented noise type: ' + f'`{self.noise_type}`!') + + if scale_factor > 1: + assert filter_kernel is not None + self.register_buffer( + 'filter', upfirdn2d.setup_filter(filter_kernel)) + fh, fw = self.filter.shape + self.filter_padding = ( + kernel_size // 2 + (fw + scale_factor - 1) // 2, + kernel_size // 2 + (fw - scale_factor) // 2, + kernel_size // 2 + (fh + scale_factor - 1) // 2, + kernel_size // 2 + (fh - scale_factor) // 2) + + def extra_repr(self): + return (f'in_ch={self.in_channels}, ' + f'out_ch={self.out_channels}, ' + f'ksize={self.kernel_size}, ' + f'wscale_gain={self.wscale_gain:.3f}, ' + f'bias={self.add_bias}, ' + f'lr_mul={self.lr_mul:.3f}, ' + f'upsample={self.scale_factor}, ' + f'upsample_filter={self.filter_kernel}, ' + f'demodulate={self.demodulate}, ' + f'noise_type={self.noise_type}, ' + f'act={self.activation_type}, ' + f'clamp={self.conv_clamp}') + + def forward_style(self, w, impl='cuda'): + """Gets style code from the given input. + + More specifically, if the input is from W-Space, it will be projected by + an affine transformation. If it is from the Style Space (Y-Space), no + operation is required. + + NOTE: For codes from Y-Space, we use slicing to make sure the dimension + is correct, in case that the code is padded before fed into this layer. + """ + space_of_latent = self.space_of_latent.upper() + if space_of_latent == 'W': + if w.ndim != 2 or w.shape[1] != self.w_dim: + raise ValueError(f'The input tensor should be with shape ' + f'[batch_size, w_dim], where ' + f'`w_dim` equals to {self.w_dim}!\n' + f'But `{w.shape}` is received!') + style = self.style(w, impl=impl) + elif space_of_latent == 'Y': + if w.ndim != 2 or w.shape[1] < self.in_channels: + raise ValueError(f'The input tensor should be with shape ' + f'[batch_size, y_dim], where ' + f'`y_dim` equals to {self.in_channels}!\n' + f'But `{w.shape}` is received!') + style = w[:, :self.in_channels] + else: + raise NotImplementedError(f'Not implemented `space_of_latent`: ' + f'`{space_of_latent}`!') + return style + + def forward(self, + x, + w, + runtime_gain=1.0, + noise_mode='const', + fused_modulate=False, + impl='cuda'): + dtype = x.dtype + N, C, H, W = x.shape + + fused_modulate = (fused_modulate and + not self.training and + (dtype == torch.float32 or N == 1)) + + weight = self.weight + out_ch, in_ch, kh, kw = weight.shape + assert in_ch == C + + # Affine on `w`. + style = self.forward_style(w, impl=impl) + if not self.demodulate: + _style = style * self.wscale # Equivalent to scaling weight. + else: + _style = style + + # Prepare noise. 
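+ # `noise_mode` controls the additive noise: `const` reuses the
+ # registered buffer, `random` redraws noise for every forward pass, and
+ # `none` disables it.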
+ noise = None + noise_mode = noise_mode.lower() + if self.noise_type != 'none' and noise_mode != 'none': + if noise_mode == 'random': + noise = torch.randn((N, *self.noise.shape[1:]), device=x.device) + elif noise_mode == 'const': + noise = self.noise + else: + raise ValueError(f'Unknown noise mode `{noise_mode}`!') + noise = (noise * self.noise_strength).to(dtype) + + # Pre-normalize inputs to avoid FP16 overflow. + if dtype == torch.float16 and self.demodulate: + weight_max = weight.norm(float('inf'), dim=(1, 2, 3), keepdim=True) + weight = weight * (self.wscale / weight_max) + style_max = _style.norm(float('inf'), dim=1, keepdim=True) + _style = _style / style_max + + if self.demodulate or fused_modulate: + _weight = weight.unsqueeze(0) + _weight = _weight * _style.reshape(N, 1, in_ch, 1, 1) + if self.demodulate: + decoef = (_weight.square().sum(dim=(2, 3, 4)) + self.eps).rsqrt() + if self.demodulate and fused_modulate: + _weight = _weight * decoef.reshape(N, out_ch, 1, 1, 1) + + if not fused_modulate: + x = x * _style.to(dtype).reshape(N, in_ch, 1, 1) + w = weight.to(dtype) + groups = 1 + else: # Use group convolution to fuse style modulation and convolution. + x = x.reshape(1, N * in_ch, H, W) + w = _weight.reshape(N * out_ch, in_ch, kh, kw).to(dtype) + groups = N + + if self.scale_factor == 1: # Native convolution without upsampling. + up = 1 + padding = self.kernel_size // 2 + x = conv2d_gradfix.conv2d( + x, w, stride=1, padding=padding, groups=groups, impl=impl) + else: # Convolution with upsampling. + up = self.scale_factor + f = self.filter + # When kernel size = 1, use filtering function for upsampling. + if self.kernel_size == 1: + padding = self.filter_padding + x = conv2d_gradfix.conv2d( + x, w, stride=1, padding=0, groups=groups, impl=impl) + x = upfirdn2d.upfirdn2d( + x, f, up=up, padding=padding, gain=up ** 2, impl=impl) + # When kernel size != 1, use stride convolution for upsampling. + else: + # Following codes are borrowed from + # https://github.com/NVlabs/stylegan2-ada-pytorch + px0, px1, py0, py1 = self.filter_padding + px0 = px0 - (kw - 1) + px1 = px1 - (kw - up) + py0 = py0 - (kh - 1) + py1 = py1 - (kh - up) + pxt = max(min(-px0, -px1), 0) + pyt = max(min(-py0, -py1), 0) + if groups == 1: + w = w.transpose(0, 1) + else: + w = w.reshape(N, out_ch, in_ch, kh, kw) + w = w.transpose(1, 2) + w = w.reshape(N * in_ch, out_ch, kh, kw) + padding = (pyt, pxt) + x = conv2d_gradfix.conv_transpose2d( + x, w, stride=up, padding=padding, groups=groups, impl=impl) + padding = (px0 + pxt, px1 + pxt, py0 + pyt, py1 + pyt) + x = upfirdn2d.upfirdn2d( + x, f, up=1, padding=padding, gain=up ** 2, impl=impl) + + if not fused_modulate: + if self.demodulate: + decoef = decoef.to(dtype).reshape(N, out_ch, 1, 1) + if self.demodulate and noise is not None: + x = fma.fma(x, decoef, noise, impl=impl) + else: + if self.demodulate: + x = x * decoef + if noise is not None: + x = x + noise + else: + x = x.reshape(N, out_ch, H * up, W * up) + if noise is not None: + x = x + noise + + bias = None + if self.bias is not None: + bias = self.bias.to(dtype) + if self.bscale != 1.0: + bias = bias * self.bscale + + if self.activation_type == 'linear': # Shortcut for output layer. 
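+ # ToRGB layers use a linear activation, so no activation gain is
+ # applied and only the optional clamp is kept.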
+ x = bias_act.bias_act( + x, bias, act='linear', clamp=self.conv_clamp, impl=impl) + else: + act_gain = self.act_gain * runtime_gain + act_clamp = None + if self.conv_clamp is not None: + act_clamp = self.conv_clamp * runtime_gain + x = bias_act.bias_act(x, bias, + act=self.activation_type, + gain=act_gain, + clamp=act_clamp, + impl=impl) + + assert x.dtype == dtype + assert style.dtype == torch.float32 + return x, style + + +class DenseLayer(nn.Module): + """Implements the dense layer.""" + + def __init__(self, + in_channels, + out_channels, + add_bias, + init_bias, + use_wscale, + wscale_gain, + lr_mul, + activation_type): + """Initializes with layer settings. + + Args: + in_channels: Number of channels of the input tensor. + out_channels: Number of channels of the output tensor. + add_bias: Whether to add bias onto the fully-connected result. + init_bias: The initial bias value before training. + use_wscale: Whether to use weight scaling. + wscale_gain: Gain factor for weight scaling. + lr_mul: Learning multiplier for both weight and bias. + activation_type: Type of activation. + """ + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.add_bias = add_bias + self.init_bias = init_bias + self.use_wscale = use_wscale + self.wscale_gain = wscale_gain + self.lr_mul = lr_mul + self.activation_type = activation_type + + weight_shape = (out_channels, in_channels) + wscale = wscale_gain / np.sqrt(in_channels) + if use_wscale: + self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul) + self.wscale = wscale * lr_mul + else: + self.weight = nn.Parameter( + torch.randn(*weight_shape) * wscale / lr_mul) + self.wscale = lr_mul + + if add_bias: + init_bias = np.float32(init_bias) / lr_mul + self.bias = nn.Parameter(torch.full([out_channels], init_bias)) + self.bscale = lr_mul + else: + self.bias = None + + def extra_repr(self): + return (f'in_ch={self.in_channels}, ' + f'out_ch={self.out_channels}, ' + f'wscale_gain={self.wscale_gain:.3f}, ' + f'bias={self.add_bias}, ' + f'init_bias={self.init_bias}, ' + f'lr_mul={self.lr_mul:.3f}, ' + f'act={self.activation_type}') + + def forward(self, x, impl='cuda'): + dtype = x.dtype + + if x.ndim != 2: + x = x.flatten(start_dim=1) + + weight = self.weight.to(dtype) * self.wscale + bias = None + if self.bias is not None: + bias = self.bias.to(dtype) + if self.bscale != 1.0: + bias = bias * self.bscale + + # Fast pass for linear activation. + if self.activation_type == 'linear' and bias is not None: + x = torch.addmm(bias.unsqueeze(0), x, weight.t()) + else: + x = x.matmul(weight.t()) + x = bias_act.bias_act(x, bias, act=self.activation_type, impl=impl) + + assert x.dtype == dtype + return x + +# pylint: enable=missing-function-docstring diff --git a/models/stylegan3_generator.py b/models/stylegan3_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..9e609ba4625e2a642d2a5b3c15c368fcfe6e7ff2 --- /dev/null +++ b/models/stylegan3_generator.py @@ -0,0 +1,1332 @@ +# python3.7 +"""Contains the implementation of generator described in StyleGAN3. + +Compared to that of StyleGAN2, the generator in StyleGAN3 controls the frequency +flow along with the convolutional layers growing. 
+ +Paper: https://arxiv.org/pdf/2106.12423.pdf + +Official implementation: https://github.com/NVlabs/stylegan3 +""" + +import numpy as np +import scipy.signal + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from third_party.stylegan3_official_ops import bias_act +from third_party.stylegan3_official_ops import filtered_lrelu +from third_party.stylegan3_official_ops import conv2d_gradfix +from .utils.ops import all_gather + +__all__ = ['StyleGAN3Generator'] + +# Resolutions allowed. +_RESOLUTIONS_ALLOWED = [8, 16, 32, 64, 128, 256, 512, 1024] + +# pylint: disable=missing-function-docstring + +class StyleGAN3Generator(nn.Module): + """Defines the generator network in StyleGAN3. + + NOTE: The synthesized images are with `RGB` channel order and pixel range + [-1, 1]. + + Settings for the mapping network: + + (1) z_dim: Dimension of the input latent space, Z. (default: 512) + (2) w_dim: Dimension of the output latent space, W. (default: 512) + (3) repeat_w: Repeat w-code for different layers. (default: True) + (4) normalize_z: Whether to normalize the z-code. (default: True) + (5) mapping_layers: Number of layers of the mapping network. (default: 2) + (6) mapping_fmaps: Number of hidden channels of the mapping network. + (default: 512) + (7) mapping_lr_mul: Learning rate multiplier for the mapping network. + (default: 0.01) + + Settings for conditional generation: + + (1) label_dim: Dimension of the additional label for conditional generation. + In one-hot conditioning case, it is equal to the number of classes. If + set to 0, conditioning training will be disabled. (default: 0) + (2) embedding_dim: Dimension of the embedding space, if needed. + (default: 512) + (3) embedding_bias: Whether to add bias to embedding learning. + (default: True) + (4) embedding_lr_mul: Learning rate multiplier for the embedding learning. + (default: 1.0) + (5) normalize_embedding: Whether to normalize the embedding. (default: True) + (6) normalize_embedding_latent: Whether to normalize the embedding together + with the latent. (default: False) + + Settings for the synthesis network: + + (1) resolution: The resolution of the output image. (default: -1) + (2) image_channels: Number of channels of the output image. (default: 3) + (3) final_tanh: Whether to use `tanh` to control the final pixel range. + (default: False) + (4) output_scale: Factor to scaling the output image. (default: 0.25) + (5) num_layers: Number of synthesis layers, excluding the first positional + encoding layer and the last ToRGB layer. (default: 14) + (6) num_critical: Number of synthesis layers with critical sampling. These + layers are always set as top (with highest resolution) ones. + (7) fmaps_base: Factor to control number of feature maps for each layer. + (default: 32 << 10) + (8) fmaps_max: Maximum number of feature maps in each layer. (default: 512) + (9) kernel_size: Size of convolutional kernels. (default: 1) + (10) conv_clamp: A threshold to clamp the output of convolution layers to + avoid overflow under FP16 training. (default: None) + (11) first_cutoff: Cutoff frequency of the first layer. (default: 2) + (12) first_stopband: Stopband of the first layer. (default: 2 ** 2.1) + (13) last_stopband_rel: Stopband of the last layer, relative to the last + cutoff, which is `resolution / 2`. Concretely, `last_stopband` will be + equal to `resolution / 2 * last_stopband_rel`. (default: 2 ** 0.3) + (14) margin_size: Size of margin for each feature map. 
(default: 10) + (15) filter_size: Size of filter for upsampling and downsampling around the + activation. (default: 6) + (16) act_upsampling: Factor used to upsample the feature map before + activation for anti-aliasing. (default: 2) + (17) use_radial_filter: Whether to use radial filter for downsampling after + the activation. (default: False) + (18) eps: A small value to avoid divide overflow. (default: 1e-8) + + Runtime settings: + + (1) w_moving_decay: Decay factor for updating `w_avg`, which is used for + training only. Set `None` to disable. (default: 0.998) + (2) sync_w_avg: Synchronizing the stats of `w_avg` across replicas. If set + as `True`, the stats will be more accurate, yet the speed maybe a little + bit slower. (default: False) + (3) style_mixing_prob: Probability to perform style mixing as a training + regularization. Set `None` to disable. (default: None) + (4) trunc_psi: Truncation psi, set `None` to disable. (default: None) + (5) trunc_layers: Number of layers to perform truncation. (default: None) + (6) magnitude_moving_decay: Decay factor for updating `magnitude_ema` in + each `SynthesisLayer`, which is used for training only. Set `None` to + disable. (default: 0.999) + (7) update_ema: Whether to update `w_avg` in the `MappingNetwork` and + `magnitude_ema` in each `SynthesisLayer`. This field only takes effect + in `training` model. (default: False) + (8) fp16_res: Layers at resolution higher than (or equal to) this field will + use `float16` precision for computation. This is merely used for + acceleration. If set as `None`, all layers will use `float32` by + default. (default: None) + (9) impl: Implementation mode of some particular ops, e.g., `filtering`, + `bias_act`, etc. `cuda` means using the official CUDA implementation + from StyleGAN3, while `ref` means using the native PyTorch ops. + (default: `cuda`) + """ + + def __init__(self, + # Settings for mapping network. + z_dim=512, + w_dim=512, + repeat_w=True, + normalize_z=True, + mapping_layers=2, + mapping_fmaps=512, + mapping_lr_mul=0.01, + # Settings for conditional generation. + label_dim=0, + embedding_dim=512, + embedding_bias=True, + embedding_lr_mul=1.0, + normalize_embedding=True, + normalize_embedding_latent=False, + # Settings for synthesis network. + resolution=-1, + image_channels=3, + final_tanh=False, + output_scale=0.25, + num_layers=14, + num_critical=2, + fmaps_base=32 << 10, + fmaps_max=512, + kernel_size=1, + conv_clamp=256, + first_cutoff=2, + first_stopband=2 ** 2.1, + last_stopband_rel=2 ** 0.3, + margin_size=10, + filter_size=6, + act_upsampling=2, + use_radial_filter=False, + eps=1e-8): + """Initializes with basic settings.""" + super().__init__() + + if resolution not in _RESOLUTIONS_ALLOWED: + raise ValueError(f'Invalid resolution: `{resolution}`!\n' + f'Resolutions allowed: {_RESOLUTIONS_ALLOWED}.') + + self.z_dim = z_dim + self.w_dim = w_dim + self.repeat_w = repeat_w + self.normalize_z = normalize_z + self.mapping_layers = mapping_layers + self.mapping_fmaps = mapping_fmaps + self.mapping_lr_mul = mapping_lr_mul + + self.label_dim = label_dim + self.embedding_dim = embedding_dim + self.embedding_bias = embedding_bias + self.embedding_lr_mul = embedding_lr_mul + self.normalize_embedding = normalize_embedding + self.normalize_embedding_latent = normalize_embedding_latent + + self.resolution = resolution + self.image_channels = image_channels + self.final_tanh = final_tanh + self.output_scale = output_scale + self.num_layers = num_layers + 2 # Including InputLayer and ToRGBLayer. 
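+ # With the default `num_layers=14`, the mapping network produces
+ # 16 w-codes in total: one for the input layer, 14 for the synthesis
+ # layers, and one for the ToRGB layer.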
+ self.num_critical = num_critical + self.fmaps_base = fmaps_base + self.fmaps_max = fmaps_max + self.kernel_size = kernel_size + self.conv_clamp = conv_clamp + self.first_cutoff = first_cutoff + self.first_stopband = first_stopband + self.last_stopband_rel = last_stopband_rel + self.margin_size = margin_size + self.filter_size = filter_size + self.act_upsampling = act_upsampling + self.use_radial_filter = use_radial_filter + self.eps = eps + + # Dimension of latent space, which is convenient for sampling. + self.latent_dim = (z_dim,) + + self.mapping = MappingNetwork( + input_dim=z_dim, + output_dim=w_dim, + num_outputs=self.num_layers, + repeat_output=repeat_w, + normalize_input=normalize_z, + num_layers=mapping_layers, + hidden_dim=mapping_fmaps, + lr_mul=mapping_lr_mul, + label_dim=label_dim, + embedding_dim=embedding_dim, + embedding_bias=embedding_bias, + embedding_lr_mul=embedding_lr_mul, + normalize_embedding=normalize_embedding, + normalize_embedding_latent=normalize_embedding_latent, + eps=eps) + + # This is used for truncation trick. + if self.repeat_w: + self.register_buffer('w_avg', torch.zeros(w_dim)) + else: + self.register_buffer('w_avg', torch.zeros(self.num_layers * w_dim)) + + self.synthesis = SynthesisNetwork(resolution=resolution, + w_dim=w_dim, + image_channels=image_channels, + final_tanh=final_tanh, + output_scale=output_scale, + num_layers=num_layers, + num_critical=num_critical, + fmaps_base=fmaps_base, + fmaps_max=fmaps_max, + kernel_size=kernel_size, + conv_clamp=conv_clamp, + first_cutoff=first_cutoff, + first_stopband=first_stopband, + last_stopband_rel=last_stopband_rel, + margin_size=margin_size, + filter_size=filter_size, + act_upsampling=act_upsampling, + use_radial_filter=use_radial_filter, + eps=eps) + + self.var_mapping = {'w_avg': 'mapping.w_avg'} + for key, val in self.mapping.var_mapping.items(): + self.var_mapping[f'mapping.{key}'] = f'mapping.{val}' + for key, val in self.synthesis.var_mapping.items(): + self.var_mapping[f'synthesis.{key}'] = f'synthesis.{val}' + + def set_space_of_latent(self, space_of_latent): + """Sets the space to which the latent code belong. + + See `SynthesisNetwork` for more details. + """ + self.synthesis.set_space_of_latent(space_of_latent) + + def forward(self, + z, + label=None, + w_moving_decay=0.998, + sync_w_avg=False, + style_mixing_prob=None, + trunc_psi=None, + trunc_layers=None, + magnitude_moving_decay=0.999, + update_ema=False, + fp16_res=None, + impl='cuda'): + """Connects mapping network and synthesis network. + + This forward function will also update the average `w_code`, perform + style mixing as a training regularizer, and do truncation trick, which + is specially designed for inference. 
+ + Concretely, the truncation trick acts as follows: + + For layers in range [0, truncation_layers), the truncated w-code is + computed as + + w_new = w_avg + (w - w_avg) * trunc_psi + + To disable truncation, please set + + (1) trunc_psi = 1.0 (None) OR + (2) trunc_layers = 0 (None) + """ + + mapping_results = self.mapping(z, label, impl=impl) + + w = mapping_results['w'] + if self.training and update_ema and w_moving_decay is not None: + if sync_w_avg: + batch_w_avg = all_gather(w.detach()).mean(dim=0) + else: + batch_w_avg = w.detach().mean(dim=0) + self.w_avg.copy_(batch_w_avg.lerp(self.w_avg, w_moving_decay)) + + wp = mapping_results.pop('wp') + if self.training and style_mixing_prob is not None: + if np.random.uniform() < style_mixing_prob: + new_z = torch.randn_like(z) + new_wp = self.mapping(new_z, label, impl=impl)['wp'] + mixing_cutoff = np.random.randint(1, self.num_layers) + wp[:, mixing_cutoff:] = new_wp[:, mixing_cutoff:] + + if not self.training: + trunc_psi = 1.0 if trunc_psi is None else trunc_psi + trunc_layers = 0 if trunc_layers is None else trunc_layers + if trunc_psi < 1.0 and trunc_layers > 0: + w_avg = self.w_avg.reshape(1, -1, self.w_dim)[:, :trunc_layers] + wp[:, :trunc_layers] = w_avg.lerp( + wp[:, :trunc_layers], trunc_psi) + + synthesis_results = self.synthesis( + wp, + magnitude_moving_decay=magnitude_moving_decay, + update_ema=update_ema, + fp16_res=fp16_res, + impl=impl) + + return {**mapping_results, **synthesis_results} + + +class MappingNetwork(nn.Module): + """Implements the latent space mapping network. + + Basically, this network executes several dense layers in sequence, and the + label embedding if needed. + """ + + def __init__(self, + input_dim, + output_dim, + num_outputs, + repeat_output, + normalize_input, + num_layers, + hidden_dim, + lr_mul, + label_dim, + embedding_dim, + embedding_bias, + embedding_lr_mul, + normalize_embedding, + normalize_embedding_latent, + eps): + super().__init__() + + self.input_dim = input_dim + self.output_dim = output_dim + self.num_outputs = num_outputs + self.repeat_output = repeat_output + self.normalize_input = normalize_input + self.num_layers = num_layers + self.hidden_dim = hidden_dim + self.lr_mul = lr_mul + self.label_dim = label_dim + self.embedding_dim = embedding_dim + self.embedding_bias = embedding_bias + self.embedding_lr_mul = embedding_lr_mul + self.normalize_embedding = normalize_embedding + self.normalize_embedding_latent = normalize_embedding_latent + self.eps = eps + + self.var_mapping = {} + + self.norm = PixelNormLayer(dim=1, eps=eps) + + if self.label_dim > 0: + input_dim = input_dim + embedding_dim + self.embedding = DenseLayer(in_channels=label_dim, + out_channels=embedding_dim, + init_weight_std=1.0, + add_bias=embedding_bias, + init_bias=0.0, + lr_mul=embedding_lr_mul, + activation_type='linear') + self.var_mapping['embedding.weight'] = 'embed.weight' + if self.embedding_bias: + self.var_mapping['embedding.bias'] = 'embed.bias' + + if num_outputs is not None and not repeat_output: + output_dim = output_dim * num_outputs + for i in range(num_layers): + in_channels = (input_dim if i == 0 else hidden_dim) + out_channels = (output_dim if i == (num_layers - 1) else hidden_dim) + self.add_module(f'dense{i}', + DenseLayer(in_channels=in_channels, + out_channels=out_channels, + init_weight_std=1.0, + add_bias=True, + init_bias=0.0, + lr_mul=lr_mul, + activation_type='lrelu')) + self.var_mapping[f'dense{i}.weight'] = f'fc{i}.weight' + self.var_mapping[f'dense{i}.bias'] = f'fc{i}.bias' + + def 
forward(self, z, label=None, impl='cuda'): + if z.ndim != 2 or z.shape[1] != self.input_dim: + raise ValueError(f'Input latent code should be with shape ' + f'[batch_size, input_dim], where ' + f'`input_dim` equals to {self.input_dim}!\n' + f'But `{z.shape}` is received!') + if self.normalize_input: + z = self.norm(z) + + if self.label_dim > 0: + if label is None: + raise ValueError(f'Model requires an additional label ' + f'(with dimension {self.label_dim}) as input, ' + f'but no label is received!') + if label.ndim != 2 or label.shape != (z.shape[0], self.label_dim): + raise ValueError(f'Input label should be with shape ' + f'[batch_size, label_dim], where ' + f'`batch_size` equals to that of ' + f'latent codes ({z.shape[0]}) and ' + f'`label_dim` equals to {self.label_dim}!\n' + f'But `{label.shape}` is received!') + label = label.to(dtype=torch.float32) + embedding = self.embedding(label, impl=impl) + if self.normalize_embedding: + embedding = self.norm(embedding) + w = torch.cat((z, embedding), dim=1) + else: + w = z + + if self.label_dim > 0 and self.normalize_embedding_latent: + w = self.norm(w) + + for i in range(self.num_layers): + w = getattr(self, f'dense{i}')(w, impl=impl) + + wp = None + if self.num_outputs is not None: + if self.repeat_output: + wp = w.unsqueeze(1).repeat((1, self.num_outputs, 1)) + else: + wp = w.reshape(-1, self.num_outputs, self.output_dim) + + results = { + 'z': z, + 'label': label, + 'w': w, + 'wp': wp, + } + if self.label_dim > 0: + results['embedding'] = embedding + return results + + +class SynthesisNetwork(nn.Module): + """Implements the image synthesis network. + + Basically, this network executes several convolutional layers in sequence. + """ + + def __init__(self, + resolution, + w_dim, + image_channels, + final_tanh, + output_scale, + num_layers, + num_critical, + fmaps_base, + fmaps_max, + kernel_size, + conv_clamp, + first_cutoff, + first_stopband, + last_stopband_rel, + margin_size, + filter_size, + act_upsampling, + use_radial_filter, + eps): + super().__init__() + + self.resolution = resolution + self.w_dim = w_dim + self.image_channels = image_channels + self.final_tanh = final_tanh + self.output_scale = output_scale + self.num_layers = num_layers + self.num_critical = num_critical + self.fmaps_base = fmaps_base + self.fmaps_max = fmaps_max + self.kernel_size = kernel_size + self.conv_clamp = conv_clamp + self.first_cutoff = first_cutoff + self.first_stopband = first_stopband + self.last_stopband_rel = last_stopband_rel + self.margin_size = margin_size + self.filter_size = filter_size + self.act_upsampling = act_upsampling + self.use_radial_filter = use_radial_filter + self.eps = eps + + self.var_mapping = {} + + # Get layer settings. 
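+ # Cutoff and stopband frequencies are interpolated geometrically from
+ # the first layer up to the critically sampled top layers:
+ #   cutoff_i   = first_cutoff * (last_cutoff / first_cutoff) ** e_i
+ #   stopband_i = first_stopband * (last_stopband / first_stopband) ** e_i
+ # with e_i = min(i / (num_layers - num_critical), 1), so the critically
+ # sampled top layers (and the ToRGB layer) share the final cutoff
+ # `resolution / 2`.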
+ last_cutoff = resolution / 2 + last_stopband = last_cutoff * last_stopband_rel + layer_indices = np.arange(num_layers + 1) + exponents = np.minimum(layer_indices / (num_layers - num_critical), 1) + cutoffs = first_cutoff * (last_cutoff / first_cutoff) ** exponents + stopbands = ( + first_stopband * (last_stopband / first_stopband) ** exponents) + sampling_rates = np.exp2(np.ceil(np.log2( + np.minimum(stopbands * 2, self.resolution)))) + sampling_rates = np.int64(sampling_rates) + half_widths = np.maximum(stopbands, sampling_rates / 2) - cutoffs + sizes = sampling_rates + margin_size * 2 + sizes[-2:] = resolution + sizes = np.int64(sizes) + channels = np.rint(np.minimum((fmaps_base / 2) / cutoffs, fmaps_max)) + channels[-1] = image_channels + channels = np.int64(channels) + + self.cutoffs = cutoffs + self.stopbands = stopbands + self.sampling_rates = sampling_rates + self.half_widths = half_widths + self.sizes = sizes + self.channels = channels + + # Input layer, with positional encoding. + self.early_layer = InputLayer(w_dim=w_dim, + channels=channels[0], + size=sizes[0], + sampling_rate=sampling_rates[0], + cutoff=cutoffs[0]) + self.var_mapping['early_layer.weight'] = 'input.weight' + self.var_mapping['early_layer.affine.weight'] = 'input.affine.weight' + self.var_mapping['early_layer.affine.bias'] = 'input.affine.bias' + self.var_mapping['early_layer.transform'] = 'input.transform' + self.var_mapping['early_layer.frequency'] = 'input.freqs' + self.var_mapping['early_layer.phase'] = 'input.phases' + + # Convolutional layers. + for idx in range(num_layers + 1): + # Position related settings. + if idx < num_layers: + kernel_size = self.kernel_size + demodulate = True + act_upsampling = self.act_upsampling + else: # ToRGB layer. + kernel_size = 1 + demodulate = False + act_upsampling = 1 + if idx < num_layers - num_critical: # Non-critical sampling. + use_radial_filter = self.use_radial_filter + else: # Critical sampling. + use_radial_filter = False + + prev_idx = max(idx - 1, 0) + layer_name = f'layer{idx}' + official_layer_name = f'L{idx}_{sizes[idx]}_{channels[idx]}' + self.add_module( + layer_name, + SynthesisLayer(in_channels=channels[prev_idx], + out_channels=channels[idx], + w_dim=w_dim, + kernel_size=kernel_size, + demodulate=demodulate, + eps=eps, + conv_clamp=conv_clamp, + in_size=sizes[prev_idx], + out_size=sizes[idx], + in_sampling_rate=sampling_rates[prev_idx], + out_sampling_rate=sampling_rates[idx], + in_cutoff=cutoffs[prev_idx], + out_cutoff=cutoffs[idx], + in_half_width=half_widths[prev_idx], + out_half_width=half_widths[idx], + filter_size=filter_size, + use_radial_filter=use_radial_filter, + act_upsampling=act_upsampling)) + + self.var_mapping[f'{layer_name}.magnitude_ema'] = ( + f'{official_layer_name}.magnitude_ema') + self.var_mapping[f'{layer_name}.conv.weight'] = ( + f'{official_layer_name}.weight') + self.var_mapping[f'{layer_name}.conv.style.weight'] = ( + f'{official_layer_name}.affine.weight') + self.var_mapping[f'{layer_name}.conv.style.bias'] = ( + f'{official_layer_name}.affine.bias') + self.var_mapping[f'{layer_name}.filter.bias'] = ( + f'{official_layer_name}.bias') + if idx < num_layers: # ToRGB layer does not need filters. + self.var_mapping[f'{layer_name}.filter.up_filter'] = ( + f'{official_layer_name}.up_filter') + self.var_mapping[f'{layer_name}.filter.down_filter'] = ( + f'{official_layer_name}.down_filter') + + def set_space_of_latent(self, space_of_latent): + """Sets the space to which the latent code belong. 
+ + This function is particularly used for choosing how to inject the latent + code into the convolutional layers. The original generator will take a + W-Space code and apply it for style modulation after an affine + transformation. But, sometimes, it may need to directly feed an already + affine-transformed code into the convolutional layer, e.g., when + training an encoder for GAN inversion. We term the transformed space as + Style Space (or Y-Space). This function is designed to tell the + convolutional layers how to use the input code. + + Args: + space_of_latent: The space to which the latent code belong. Case + insensitive. Support `W` and `Y`. + """ + space_of_latent = space_of_latent.upper() + for module in self.modules(): + if isinstance(module, ModulateConvLayer): + setattr(module, 'space_of_latent', space_of_latent) + + def forward(self, + wp, + magnitude_moving_decay=0.999, + update_ema=False, + fp16_res=None, + impl='cuda'): + results = {'wp': wp} + + x = self.early_layer(wp[:, 0]) + for idx, sampling_rate in enumerate(self.sampling_rates): + if fp16_res is not None and sampling_rate >= fp16_res: + x = x.to(torch.float16) + layer = getattr(self, f'layer{idx}') + x, style = layer(x, wp[:, idx + 1], + magnitude_moving_decay=magnitude_moving_decay, + update_ema=update_ema, + impl=impl) + results[f'style{idx}'] = style + + if self.output_scale != 1: + x = x * self.output_scale + x = x.to(torch.float32) + if self.final_tanh: + x = torch.tanh(x) + results['image'] = x + return results + + +class PixelNormLayer(nn.Module): + """Implements pixel-wise feature vector normalization layer.""" + + def __init__(self, dim, eps): + super().__init__() + self.dim = dim + self.eps = eps + + def extra_repr(self): + return f'dim={self.dim}, epsilon={self.eps}' + + def forward(self, x): + scale = (x.square().mean(dim=self.dim, keepdim=True) + self.eps).rsqrt() + return x * scale + + +class InputLayer(nn.Module): + """Implements the input layer with positional encoding. + + Basically, this block outputs a feature map with shape + `(channels, size, size)` based on the coordinate information. + `sampling_rate` and `cutoff` are used to control the coordinate range and + strength respectively. + + For a low-pass filter, `cutoff` is the same as the `bandwidth`. + The initial frequency of the starting feature map is controlled by the + positional encoding `sin(2 * pi * x)`, where + `x = trans(coord) * frequency + phase`. We would like to introduce rich + information (i.e. frequencies), but keep all frequencies lower than + stopband, which is `sampling_rate / 2`. + + Besides, this layer also supports learning a transformation from the latent + code w, and providing a customized transformation for inference. Please + use the buffer `transform`. + + NOTE: `size` is different from `sampling_rate`. `sampling_rate` is the + actual size of the current stage, which determines the maximum frequency + that the feature maps can hold. `size` is the actual height and width of the + current feature map, including the extended border. + """ + + def __init__(self, w_dim, channels, size, sampling_rate, cutoff): + super().__init__() + + self.w_dim = w_dim + self.channels = channels + self.size = size + self.sampling_rate = sampling_rate + self.cutoff = cutoff + + # Coordinate of the entire feature map, with resolution (size, size). + # The coordinate range for the central (sampling_rate, sampling_rate) + # region is set as (-0.0, 0.5), which extends to the remaining region. 
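+ # (With `theta` scaled by `0.5 / sampling_rate * size`, the central
+ # `sampling_rate`-wide region covers coordinates in (-0.5, 0.5), and
+ # the full `size`-wide grid spans +/- 0.5 * size / sampling_rate.)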
+ theta = torch.eye(2, 3) + theta[0, 0] = 0.5 / sampling_rate * size + theta[1, 1] = 0.5 / sampling_rate * size + grid = F.affine_grid(theta=theta.unsqueeze(0), + size=(1, 1, size, size), + align_corners=False) + self.register_buffer('grid', grid) + + # Draw random frequency from a uniform 2D disc for each channel + # regarding X and Y dimension. And also draw a random phase for each + # channel. Accordingly, each channel has three pre-defined parameters, + # which are X-frequency, Y-frequency, and phase. + frequency = torch.randn(channels, 2) + radius = frequency.square().sum(dim=1, keepdim=True).sqrt() + frequency = frequency / (radius * radius.square().exp().pow(0.25)) + frequency = frequency * cutoff + self.register_buffer('frequency', frequency) + phase = torch.rand(channels) - 0.5 + self.register_buffer('phase', phase) + + # This layer is used to map the latent code w to transform factors, + # with order: cos(angle), sin(angle), transpose_x, transpose_y. + self.affine = DenseLayer(in_channels=w_dim, + out_channels=4, + init_weight_std=0.0, + add_bias=True, + init_bias=(1, 0, 0, 0), + lr_mul=1.0, + activation_type='linear') + + # It is possible to use this buffer to customize the transform of the + # output synthesis. + self.register_buffer('transform', torch.eye(3)) + + # Use 1x1 conv to convert positional encoding to features. + self.weight = nn.Parameter(torch.randn(channels, channels)) + self.weight_scale = 1 / np.sqrt(channels) + + def extra_repr(self): + return (f'channels={self.channels}, ' + f'size={self.size}, ' + f'sampling_rate={self.sampling_rate}, ' + f'cutoff={self.cutoff:.3f}, ') + + def forward(self, w): + batch = w.shape[0] + + # Get transformation matrix. + # Factor controlled by latent code. + transformation_factor = self.affine(w) + # Ensure the range of cosine and sine value (first two dimension). + _norm = transformation_factor[:, :2].norm(dim=1, keepdim=True) + transformation_factor = transformation_factor / _norm + # Rotation. + rotation = torch.eye(3, device=w.device).unsqueeze(0) + rotation = rotation.repeat((batch, 1, 1)) + rotation[:, 0, 0] = transformation_factor[:, 0] + rotation[:, 0, 1] = -transformation_factor[:, 1] + rotation[:, 1, 0] = transformation_factor[:, 1] + rotation[:, 1, 1] = transformation_factor[:, 0] + # Translation. + translation = torch.eye(3, device=w.device).unsqueeze(0) + translation = translation.repeat((batch, 1, 1)) + translation[:, 0, 2] = -transformation_factor[:, 2] + translation[:, 1, 2] = -transformation_factor[:, 3] + # Customized transformation. + transform = rotation @ translation @ self.transform.unsqueeze(0) + + # Transform frequency and shift, which is equivalent to transforming + # the coordinate. For example, given a coordinate, X, we would like to + # first transform it with the rotation matrix, R, and the translation + # matrix, T, as X' = RX + T. Then, we will apply frequency, f, and + # phase, p, with sin(2 * pi * (fX' + p)). Natively, we have + # fX' + p = f(RX + T) + p = (fR)X + (fT + p) + frequency = self.frequency.unsqueeze(0) @ transform[:, :2, :2] # [NC2] + phase = self.frequency.unsqueeze(0) @ transform[:, :2, 2:] # [NC] + phase = phase.squeeze(2) + self.phase.unsqueeze(0) # [NC] + + # Positional encoding. 
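+ # Each channel c yields a Fourier feature sin(2 * pi * (f_c . X' + p_c)),
+ # where X' is the rotated and translated coordinate grid and (f_c, p_c)
+ # are the transformed frequency and phase computed above.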
+ x = self.grid # [NHW2] + x = x.unsqueeze(3) # [NHW12] + x = x @ frequency.transpose(1, 2).unsqueeze(1).unsqueeze(2) # [NHW1C] + x = x.squeeze(3) # [NHWC] + x = x + phase.unsqueeze(1).unsqueeze(2) # [NHWC] + x = torch.sin(2 * np.pi * x) # [NHWC] + + # Dampen out-of-band frequency that may be introduced by the customized + # transform `self.transform`. + frequency_norm = frequency.norm(dim=2) + stopband = self.sampling_rate / 2 + factor = (frequency_norm - self.cutoff) / (stopband - self.cutoff) + amplitude = (1 - factor).clamp(0, 1) # [NC] + x = x * amplitude.unsqueeze(1).unsqueeze(2) # [NHWC] + + # Project positional encoding to features. + weight = self.weight * self.weight_scale + x = x @ weight.t() + + return x.permute(0, 3, 1, 2).contiguous() + + +class SynthesisLayer(nn.Module): + """Implements the synthesis layer. + + Each synthesis layer (including ToRGB layer) consists of a + `ModulateConvLayer` and a `FilteringActLayer`. Besides, this layer will + trace the magnitude (norm) of the input feature map, and update the + statistic with `magnitude_moving_decay`. + """ + + def __init__(self, + # Settings for modulated convolution. + in_channels, + out_channels, + w_dim, + kernel_size, + demodulate, + eps, + conv_clamp, + # Settings for filtering activation. + in_size, + out_size, + in_sampling_rate, + out_sampling_rate, + in_cutoff, + out_cutoff, + in_half_width, + out_half_width, + filter_size, + use_radial_filter, + act_upsampling): + """Initializes with layer settings. + + Args: + in_channels: Number of channels of the input tensor. + out_channels: Number of channels of the output tensor. + w_dim: Dimension of W space for style modulation. + kernel_size: Size of the convolutional kernels. + demodulate: Whether to perform style demodulation. + eps: A small value to avoid divide overflow. + conv_clamp: A threshold to clamp the output of convolution layers to + avoid overflow under FP16 training. + in_size: Size of the input feature map, i.e., height and width. + out_size: Size of the output feature map, i.e., height and width. + in_sampling_rate: Sampling rate of the input feature map. Different + from `in_size` that includes extended border, this field + controls the actual maximum frequency that can be represented + by the feature map. + out_sampling_rate: Sampling rate of the output feature map. + in_cutoff: Cutoff frequency of the input feature map. + out_cutoff: Cutoff frequency of the output feature map. + in_half_width: Half-width of the transition band of the input + feature map. + out_half_width: Half-width of the transition band of the output + feature map. + filter_size: Size of the filter used in this layer. + use_radial_filter: Whether to use radial filter. + act_upsampling: Upsampling factor used before the activation. + `1` means do not wrap upsampling and downsampling around the + activation. 
+ """ + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.w_dim = w_dim + self.kernel_size = kernel_size + self.demodulate = demodulate + self.eps = eps + self.conv_clamp = conv_clamp + + self.in_size = in_size + self.out_size = out_size + self.in_sampling_rate = in_sampling_rate + self.out_sampling_rate = out_sampling_rate + self.in_cutoff = in_cutoff + self.out_cutoff = out_cutoff + self.in_half_width = in_half_width + self.out_half_width = out_half_width + self.filter_size = filter_size + self.use_radial_filter = use_radial_filter + self.act_upsampling = act_upsampling + + self.conv = ModulateConvLayer(in_channels=in_channels, + out_channels=out_channels, + w_dim=w_dim, + kernel_size=kernel_size, + demodulate=demodulate, + eps=eps) + self.register_buffer('magnitude_ema', torch.ones(())) + self.filter = FilteringActLayer(out_channels=out_channels, + in_size=in_size, + out_size=out_size, + in_sampling_rate=in_sampling_rate, + out_sampling_rate=out_sampling_rate, + in_cutoff=in_cutoff, + out_cutoff=out_cutoff, + in_half_width=in_half_width, + out_half_width=out_half_width, + filter_size=filter_size, + use_radial_filter=use_radial_filter, + conv_padding=self.conv.padding, + act_upsampling=act_upsampling) + + def extra_repr(self): + return f'conv_clamp={self.conv_clamp}' + + def forward(self, + x, + w, + magnitude_moving_decay=0.999, + update_ema=False, + impl='cuda'): + if self.training and update_ema and magnitude_moving_decay is not None: + magnitude = x.detach().to(torch.float32).square().mean() + self.magnitude_ema.copy_( + magnitude.lerp(self.magnitude_ema, magnitude_moving_decay)) + + input_gain = self.magnitude_ema.rsqrt() + x, style = self.conv(x, w, gain=input_gain, impl=impl) + if self.act_upsampling > 1: + x = self.filter(x, np.sqrt(2), 0.2, self.conv_clamp, impl=impl) + else: + x = self.filter(x, 1, 1, self.conv_clamp, impl=impl) + + return x, style + + +class ModulateConvLayer(nn.Module): + """Implements the convolutional layer with style modulation. + + Different from the one introduced in StyleGAN2, this layer has following + changes: + + (1) fusing `conv` and `style modulation` into one op by default + (2) NOT adding a noise onto the output feature map. + (3) NOT activating the feature map, which is moved to `FilteringActLayer`. + """ + + def __init__(self, + in_channels, + out_channels, + w_dim, + kernel_size, + demodulate, + eps): + """Initializes with layer settings. + + Args: + in_channels: Number of channels of the input tensor. + out_channels: Number of channels of the output tensor. + w_dim: Dimension of W space for style modulation. + kernel_size: Size of the convolutional kernels. + demodulate: Whether to perform style demodulation. + eps: A small value to avoid divide overflow. + """ + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.w_dim = w_dim + self.kernel_size = kernel_size + self.demodulate = demodulate + self.eps = eps + + self.space_of_latent = 'W' + + # Set up weight. + weight_shape = (out_channels, in_channels, kernel_size, kernel_size) + self.weight = nn.Parameter(torch.randn(*weight_shape)) + self.wscale = 1.0 / np.sqrt(kernel_size * kernel_size * in_channels) + self.padding = kernel_size - 1 + + # Set up style. 
+ self.style = DenseLayer(in_channels=w_dim, + out_channels=in_channels, + init_weight_std=1.0, + add_bias=True, + init_bias=1.0, + lr_mul=1.0, + activation_type='linear') + + def extra_repr(self): + return (f'in_ch={self.in_channels}, ' + f'out_ch={self.out_channels}, ' + f'ksize={self.kernel_size}, ' + f'demodulate={self.demodulate}') + + def forward_style(self, w, impl='cuda'): + """Gets style code from the given input. + + More specifically, if the input is from W-Space, it will be projected by + an affine transformation. If it is from the Style Space (Y-Space), no + operation is required. + + NOTE: For codes from Y-Space, we use slicing to make sure the dimension + is correct, in case that the code is padded before fed into this layer. + """ + space_of_latent = self.space_of_latent.upper() + if space_of_latent == 'W': + if w.ndim != 2 or w.shape[1] != self.w_dim: + raise ValueError(f'The input tensor should be with shape ' + f'[batch_size, w_dim], where ' + f'`w_dim` equals to {self.w_dim}!\n' + f'But `{w.shape}` is received!') + style = self.style(w, impl=impl) + elif space_of_latent == 'Y': + if w.ndim != 2 or w.shape[1] < self.in_channels: + raise ValueError(f'The input tensor should be with shape ' + f'[batch_size, y_dim], where ' + f'`y_dim` equals to {self.in_channels}!\n' + f'But `{w.shape}` is received!') + style = w[:, :self.in_channels] + else: + raise NotImplementedError(f'Not implemented `space_of_latent`: ' + f'`{space_of_latent}`!') + return style + + def forward(self, x, w, gain=None, impl='cuda'): + dtype = x.dtype + N, C, H, W = x.shape + + # Affine on `w`. + style = self.forward_style(w, impl=impl) + if not self.demodulate: + _style = style * self.wscale # Equivalent to scaling weight. + else: + _style = style + + weight = self.weight + out_ch, in_ch, kh, kw = weight.shape + assert in_ch == C + + # Pre-normalize inputs. + if self.demodulate: + weight = (weight * + weight.square().mean(dim=(1, 2, 3), keepdim=True).rsqrt()) + _style = _style * _style.square().mean().rsqrt() + + weight = weight.unsqueeze(0) + weight = weight * _style.reshape(N, 1, in_ch, 1, 1) # modulation + if self.demodulate: + decoef = (weight.square().sum(dim=(2, 3, 4)) + self.eps).rsqrt() + weight = weight * decoef.reshape(N, out_ch, 1, 1, 1) # demodulation + + if gain is not None: + gain = gain.expand(N, in_ch) + weight = weight * gain.reshape(N, 1, in_ch, 1, 1) + + # Fuse `conv` and `style modulation` as one op, using group convolution. + x = x.reshape(1, N * in_ch, H, W) + w = weight.reshape(N * out_ch, in_ch, kh, kw).to(dtype) + x = conv2d_gradfix.conv2d( + x, w, padding=self.padding, groups=N, impl=impl) + x = x.reshape(N, out_ch, x.shape[2], x.shape[3]) + + assert x.dtype == dtype + assert style.dtype == torch.float32 + return x, style + + +class FilteringActLayer(nn.Module): + """Implements the activation, wrapped with upsampling and downsampling. + + Basically, this layer executes the following operations in order: + + (1) Apply bias. + (2) Upsample the feature map to increase sampling rate. + (3) Apply non-linearity as activation. + (4) Downsample the feature map to target size. 
+ + This layer is mostly borrowed from the official implementation: + + https://github.com/NVlabs/stylegan3/blob/main/training/networks_stylegan3.py + """ + + def __init__(self, + out_channels, + in_size, + out_size, + in_sampling_rate, + out_sampling_rate, + in_cutoff, + out_cutoff, + in_half_width, + out_half_width, + filter_size, + use_radial_filter, + conv_padding, + act_upsampling): + """Initializes with layer settings. + + Args: + out_channels: Number of output channels, which is used for `bias`. + in_size: Size of the input feature map, i.e., height and width. + out_size: Size of the output feature map, i.e., height and width. + in_sampling_rate: Sampling rate of the input feature map. Different + from `in_size` that includes extended border, this field + controls the actual maximum frequency that can be represented + by the feature map. + out_sampling_rate: Sampling rate of the output feature map. + in_cutoff: Cutoff frequency of the input feature map. + out_cutoff: Cutoff frequency of the output feature map. + in_half_width: Half-width of the transition band of the input + feature map. + out_half_width: Half-width of the transition band of the output + feature map. + filter_size: Size of the filter used in this layer. + use_radial_filter: Whether to use radial filter. + conv_padding: The padding used in the previous convolutional layer. + act_upsampling: Upsampling factor used before the activation. + `1` means do not wrap upsampling and downsampling around the + activation. + """ + super().__init__() + + self.out_channels = out_channels + self.in_size = in_size + self.out_size = out_size + self.in_sampling_rate = in_sampling_rate + self.out_sampling_rate = out_sampling_rate + self.in_cutoff = in_cutoff + self.out_cutoff = out_cutoff + self.in_half_width = in_half_width + self.out_half_width = out_half_width + self.filter_size = filter_size + self.use_radial_filter = use_radial_filter + self.conv_padding = conv_padding + self.act_upsampling = act_upsampling + + # Define bias. + self.bias = nn.Parameter(torch.zeros(out_channels)) + + # This sampling rate describes the upsampled feature map before + # activation. + temp_sampling_rate = max(in_sampling_rate, out_sampling_rate) + temp_sampling_rate = temp_sampling_rate * act_upsampling + + # Design upsampling filter. + up_factor = int(np.rint(temp_sampling_rate / in_sampling_rate)) + assert in_sampling_rate * up_factor == temp_sampling_rate + if up_factor > 1: + self.up_factor = up_factor + self.up_taps = filter_size * up_factor + else: + self.up_factor = 1 + self.up_taps = 1 # No filtering. + self.register_buffer( + 'up_filter', + self.design_lowpass_filter(numtaps=self.up_taps, + cutoff=in_cutoff, + width=in_half_width * 2, + fs=temp_sampling_rate, + radial=False)) + + # Design downsampling filter. + down_factor = int(np.rint(temp_sampling_rate / out_sampling_rate)) + assert out_sampling_rate * down_factor == temp_sampling_rate + if down_factor > 1: + self.down_factor = down_factor + self.down_taps = filter_size * down_factor + else: + self.down_factor = 1 + self.down_taps = 1 # No filtering. + self.register_buffer( + 'down_filter', + self.design_lowpass_filter(numtaps=self.down_taps, + cutoff=out_cutoff, + width=out_half_width * 2, + fs=temp_sampling_rate, + radial=use_radial_filter)) + + # Compute padding. + # Desired output size before downsampling. + pad_total = (out_size - 1) * self.down_factor + 1 + # Input size after upsampling. 
+ pad_total = pad_total - (in_size + conv_padding) * self.up_factor + # Size reduction caused by the filters. + pad_total = pad_total + self.up_taps + self.down_taps - 2 + # Shift sample locations according to the symmetric interpretation. + pad_lo = (pad_total + self.up_factor) // 2 + pad_hi = pad_total - pad_lo + self.padding = list(map(int, (pad_lo, pad_hi, pad_lo, pad_hi))) + + def extra_repr(self): + return (f'in_size={self.in_size}, ' + f'out_size={self.out_size}, ' + f'in_srate={self.in_sampling_rate}, ' + f'out_srate={self.out_sampling_rate}, ' + f'in_cutoff={self.in_cutoff:.3f}, ' + f'out_cutoff={self.out_cutoff:.3f}, ' + f'in_half_width={self.in_half_width:.3f}, ' + f'out_half_width={self.out_half_width:.3f}, ' + f'up_factor={self.up_factor}, ' + f'up_taps={self.up_taps}, ' + f'down_factor={self.down_factor}, ' + f'down_taps={self.down_taps}, ' + f'filter_size={self.filter_size}, ' + f'radial_filter={self.use_radial_filter}, ' + f'conv_padding={self.conv_padding}, ' + f'act_upsampling={self.act_upsampling}') + + @staticmethod + def design_lowpass_filter(numtaps, cutoff, width, fs, radial=False): + """Designs a low-pass filter. + + Args: + numtaps: Length of the filter (number of coefficients, i.e., the + filter order + 1). + cutoff: Cutoff frequency of the output filter. + width: Width of the transition region. + fs: Sampling frequency. + radial: Whether to use radially symmetric jinc-based filter. + (default: False) + """ + if numtaps == 1: + return None + + assert numtaps > 1 + + if not radial: # Separable Kaiser low-pass filter. + f = scipy.signal.firwin(numtaps=numtaps, + cutoff=cutoff, + width=width, + fs=fs) + else: # Radially symmetric jinc-based filter. + x = (np.arange(numtaps) - (numtaps - 1) / 2) / fs + r = np.hypot(*np.meshgrid(x, x)) + f = scipy.special.j1(2 * cutoff * (np.pi * r)) / (np.pi * r) + beta = scipy.signal.kaiser_beta( + scipy.signal.kaiser_atten(numtaps, width / (fs / 2))) + w = np.kaiser(numtaps, beta) + f = f * np.outer(w, w) + f = f / np.sum(f) + return torch.as_tensor(f, dtype=torch.float32) + + def forward(self, x, gain, slope, clamp, impl='cuda'): + dtype = x.dtype + + x = filtered_lrelu.filtered_lrelu(x=x, + fu=self.up_filter, + fd=self.down_filter, + b=self.bias.to(dtype), + up=self.up_factor, + down=self.down_factor, + padding=self.padding, + gain=gain, + slope=slope, + clamp=clamp, + impl=impl) + + assert x.dtype == dtype + return x + + +class DenseLayer(nn.Module): + """Implements the dense layer.""" + + def __init__(self, + in_channels, + out_channels, + init_weight_std, + add_bias, + init_bias, + lr_mul, + activation_type): + """Initializes with layer settings. + + Args: + in_channels: Number of channels of the input tensor. + out_channels: Number of channels of the output tensor. + init_weight_std: The initial standard deviation of weight. + add_bias: Whether to add bias onto the fully-connected result. + init_bias: The initial bias value before training. + lr_mul: Learning multiplier for both weight and bias. + activation_type: Type of activation. 
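+
+        Example (illustrative values only):
+            dense = DenseLayer(in_channels=512, out_channels=512,
+                               init_weight_std=1.0, add_bias=True,
+                               init_bias=0.0, lr_mul=1.0,
+                               activation_type='lrelu')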
+ """ + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.init_weight_std = init_weight_std + self.add_bias = add_bias + self.init_bias = init_bias + self.lr_mul = lr_mul + self.activation_type = activation_type + + weight_shape = (out_channels, in_channels) + self.weight = nn.Parameter( + torch.randn(*weight_shape) * init_weight_std / lr_mul) + self.wscale = lr_mul / np.sqrt(in_channels) + + if add_bias: + init_bias = np.float32(np.float32(init_bias) / lr_mul) + if isinstance(init_bias, np.float32): + self.bias = nn.Parameter(torch.full([out_channels], init_bias)) + else: + assert isinstance(init_bias, np.ndarray) + self.bias = nn.Parameter(torch.from_numpy(init_bias)) + self.bscale = lr_mul + else: + self.bias = None + + def extra_repr(self): + return (f'in_ch={self.in_channels}, ' + f'out_ch={self.out_channels}, ' + f'init_weight_std={self.init_weight_std}, ' + f'bias={self.add_bias}, ' + f'init_bias={self.init_bias}, ' + f'lr_mul={self.lr_mul:.3f}, ' + f'act={self.activation_type}') + + def forward(self, x, impl='cuda'): + dtype = x.dtype + + if x.ndim != 2: + x = x.flatten(start_dim=1) + + weight = self.weight.to(dtype) * self.wscale + bias = None + if self.bias is not None: + bias = self.bias.to(dtype) + if self.bscale != 1.0: + bias = bias * self.bscale + + # Fast pass for linear activation. + if self.activation_type == 'linear' and bias is not None: + x = torch.addmm(bias.unsqueeze(0), x, weight.t()) + else: + x = x.matmul(weight.t()) + x = bias_act.bias_act(x, bias, act=self.activation_type, impl=impl) + + assert x.dtype == dtype + return x + +# pylint: enable=missing-function-docstring diff --git a/models/stylegan_discriminator.py b/models/stylegan_discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..c76ff3430af76839a259035fe37783ac3aa807de --- /dev/null +++ b/models/stylegan_discriminator.py @@ -0,0 +1,624 @@ +# python3.7 +"""Contains the implementation of discriminator described in StyleGAN. + +Paper: https://arxiv.org/pdf/1812.04948.pdf + +Official TensorFlow implementation: https://github.com/NVlabs/stylegan +""" + +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.cuda.amp import autocast + +__all__ = ['StyleGANDiscriminator'] + +# Resolutions allowed. +_RESOLUTIONS_ALLOWED = [8, 16, 32, 64, 128, 256, 512, 1024] + +# Fused-scale options allowed. +_FUSED_SCALE_ALLOWED = [True, False, 'auto'] + +# pylint: disable=missing-function-docstring + +class StyleGANDiscriminator(nn.Module): + """Defines the discriminator network in StyleGAN. + + NOTE: The discriminator takes images with `RGB` channel order and pixel + range [-1, 1] as inputs. + + Settings for the backbone: + + (1) resolution: The resolution of the input image. (default: -1) + (2) init_res: Smallest resolution of the convolutional backbone. + (default: 4) + (3) image_channels: Number of channels of the input image. (default: 3) + (4) fused_scale: The strategy of fusing `conv2d` and `downsample` as one + operator. `True` means blocks from all resolutions will fuse. `False` + means blocks from all resolutions will not fuse. `auto` means blocks + from resolutions higher than (or equal to) `fused_scale_res` will fuse. + (default: `auto`) + (5) fused_scale_res: Minimum resolution to fuse `conv2d` and `downsample` + as one operator. This field only takes effect if `fused_scale` is set + as `auto`. (default: 128) + (6) use_wscale: Whether to use weight scaling. 
(default: True) + (7) wscale_gain: The factor to control weight scaling. (default: sqrt(2.0)) + (8) lr_mul: Learning rate multiplier for backbone. (default: 1.0) + (9) mbstd_groups: Group size for the minibatch standard deviation layer. + `0` means disable. (default: 4) + (10) mbstd_channels: Number of new channels (appended to the original + feature map) after the minibatch standard deviation layer. (default: 1) + (11) fmaps_base: Factor to control number of feature maps for each layer. + (default: 16 << 10) + (12) fmaps_max: Maximum number of feature maps in each layer. (default: 512) + (13) filter_kernel: Kernel used for filtering (e.g., downsampling). + (default: (1, 2, 1)) + (14) eps: A small value to avoid divide overflow. (default: 1e-8) + + Settings for conditional model: + + (1) label_dim: Dimension of the additional label for conditional generation. + In one-hot conditioning case, it is equal to the number of classes. If + set to 0, conditioning training will be disabled. (default: 0) + + Runtime settings: + + (1) enable_amp: Whether to enable automatic mixed precision training. + (default: False) + """ + + def __init__(self, + # Settings for backbone. + resolution=-1, + init_res=4, + image_channels=3, + fused_scale='auto', + fused_scale_res=128, + use_wscale=True, + wscale_gain=np.sqrt(2.0), + lr_mul=1.0, + mbstd_groups=4, + mbstd_channels=1, + fmaps_base=16 << 10, + fmaps_max=512, + filter_kernel=(1, 2, 1), + eps=1e-8, + # Settings for conditional model. + label_dim=0): + """Initializes with basic settings. + + Raises: + ValueError: If the `resolution` is not supported, or `fused_scale` + is not supported. + """ + super().__init__() + + if resolution not in _RESOLUTIONS_ALLOWED: + raise ValueError(f'Invalid resolution: `{resolution}`!\n' + f'Resolutions allowed: {_RESOLUTIONS_ALLOWED}.') + if fused_scale not in _FUSED_SCALE_ALLOWED: + raise ValueError(f'Invalid fused-scale option: `{fused_scale}`!\n' + f'Options allowed: {_FUSED_SCALE_ALLOWED}.') + + self.init_res = init_res + self.init_res_log2 = int(np.log2(init_res)) + self.resolution = resolution + self.final_res_log2 = int(np.log2(resolution)) + self.image_channels = image_channels + self.fused_scale = fused_scale + self.fused_scale_res = fused_scale_res + self.use_wscale = use_wscale + self.wscale_gain = wscale_gain + self.lr_mul = lr_mul + self.mbstd_groups = mbstd_groups + self.mbstd_channels = mbstd_channels + self.fmaps_base = fmaps_base + self.fmaps_max = fmaps_max + self.filter_kernel = filter_kernel + self.eps = eps + self.label_dim = label_dim + + # Level-of-details (used for progressive training). + self.register_buffer('lod', torch.zeros(())) + self.pth_to_tf_var_mapping = {'lod': 'lod'} + + for res_log2 in range(self.final_res_log2, self.init_res_log2 - 1, -1): + res = 2 ** res_log2 + in_channels = self.get_nf(res) + out_channels = self.get_nf(res // 2) + block_idx = self.final_res_log2 - res_log2 + + # Input convolution layer for each resolution. + self.add_module( + f'input{block_idx}', + ConvLayer(in_channels=image_channels, + out_channels=in_channels, + kernel_size=1, + add_bias=True, + scale_factor=1, + fused_scale=False, + filter_kernel=None, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + lr_mul=lr_mul, + activation_type='lrelu')) + self.pth_to_tf_var_mapping[f'input{block_idx}.weight'] = ( + f'FromRGB_lod{block_idx}/weight') + self.pth_to_tf_var_mapping[f'input{block_idx}.bias'] = ( + f'FromRGB_lod{block_idx}/bias') + + # Convolution block for each resolution (except the last one). 
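+            # Each such block halves the resolution: a 3x3 convolution at the
+            # current resolution, followed by a 3x3 convolution with 2x
+            # downsampling (fused into a strided convolution when
+            # `fused_scale` applies).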
+ if res != self.init_res: + # First layer (kernel 3x3) without downsampling. + layer_name = f'layer{2 * block_idx}' + self.add_module( + layer_name, + ConvLayer(in_channels=in_channels, + out_channels=in_channels, + kernel_size=3, + add_bias=True, + scale_factor=1, + fused_scale=False, + filter_kernel=None, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + lr_mul=lr_mul, + activation_type='lrelu')) + self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = ( + f'{res}x{res}/Conv0/weight') + self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = ( + f'{res}x{res}/Conv0/bias') + + # Second layer (kernel 3x3) with downsampling + layer_name = f'layer{2 * block_idx + 1}' + self.add_module( + layer_name, + ConvLayer(in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + add_bias=True, + scale_factor=2, + fused_scale=(res >= fused_scale_res + if fused_scale == 'auto' + else fused_scale), + filter_kernel=filter_kernel, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + lr_mul=lr_mul, + activation_type='lrelu')) + self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = ( + f'{res}x{res}/Conv1_down/weight') + self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = ( + f'{res}x{res}/Conv1_down/bias') + + # Convolution block for last resolution. + else: + self.mbstd = MiniBatchSTDLayer(groups=mbstd_groups, + new_channels=mbstd_channels, + eps=eps) + + # First layer (kernel 3x3) without downsampling. + layer_name = f'layer{2 * block_idx}' + self.add_module( + layer_name, + ConvLayer(in_channels=in_channels + mbstd_channels, + out_channels=in_channels, + kernel_size=3, + add_bias=True, + scale_factor=1, + fused_scale=False, + filter_kernel=None, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + lr_mul=lr_mul, + activation_type='lrelu')) + self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = ( + f'{res}x{res}/Conv/weight') + self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = ( + f'{res}x{res}/Conv/bias') + + # Second layer, as a fully-connected layer. + layer_name = f'layer{2 * block_idx + 1}' + self.add_module( + f'layer{2 * block_idx + 1}', + DenseLayer(in_channels=in_channels * res * res, + out_channels=in_channels, + add_bias=True, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + lr_mul=lr_mul, + activation_type='lrelu')) + self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = ( + f'{res}x{res}/Dense0/weight') + self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = ( + f'{res}x{res}/Dense0/bias') + + # Final dense layer to output score. 
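+        # With `label_dim == 0` this produces a single realness score; in the
+        # conditional case it produces one score per class, and `forward()`
+        # keeps the entry selected by the one-hot input label.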
+ self.output = DenseLayer(in_channels=in_channels, + out_channels=max(label_dim, 1), + add_bias=True, + use_wscale=use_wscale, + wscale_gain=1.0, + lr_mul=lr_mul, + activation_type='linear') + self.pth_to_tf_var_mapping['output.weight'] = ( + f'{res}x{res}/Dense1/weight') + self.pth_to_tf_var_mapping['output.bias'] = ( + f'{res}x{res}/Dense1/bias') + + def get_nf(self, res): + """Gets number of feature maps according to the given resolution.""" + return min(self.fmaps_base // res, self.fmaps_max) + + def forward(self, image, label=None, lod=None, enable_amp=False): + expected_shape = (self.image_channels, self.resolution, self.resolution) + if image.ndim != 4 or image.shape[1:] != expected_shape: + raise ValueError(f'The input tensor should be with shape ' + f'[batch_size, channel, height, width], where ' + f'`channel` equals to {self.image_channels}, ' + f'`height`, `width` equal to {self.resolution}!\n' + f'But `{image.shape}` is received!') + + lod = self.lod.item() if lod is None else lod + if lod + self.init_res_log2 > self.final_res_log2: + raise ValueError(f'Maximum level-of-details (lod) is ' + f'{self.final_res_log2 - self.init_res_log2}, ' + f'but `{lod}` is received!') + + if self.label_dim: + if label is None: + raise ValueError(f'Model requires an additional label ' + f'(with dimension {self.label_dim}) as input, ' + f'but no label is received!') + batch = image.shape[0] + if (label.ndim != 2 or label.shape != (batch, self.label_dim)): + raise ValueError(f'Input label should be with shape ' + f'[batch_size, label_dim], where ' + f'`batch_size` equals to {batch}, and ' + f'`label_dim` equals to {self.label_dim}!\n' + f'But `{label.shape}` is received!') + label = label.to(dtype=torch.float32) + + with autocast(enabled=enable_amp): + for res_log2 in range( + self.final_res_log2, self.init_res_log2 - 1, -1): + block_idx = current_lod = self.final_res_log2 - res_log2 + if current_lod <= lod < current_lod + 1: + x = getattr(self, f'input{block_idx}')(image) + elif current_lod - 1 < lod < current_lod: + alpha = lod - np.floor(lod) + y = getattr(self, f'input{block_idx}')(image) + x = y * alpha + x * (1 - alpha) + if lod < current_lod + 1: + if res_log2 == self.init_res_log2: + x = self.mbstd(x) + x = getattr(self, f'layer{2 * block_idx}')(x) + x = getattr(self, f'layer{2 * block_idx + 1}')(x) + if lod > current_lod: + image = F.avg_pool2d( + image, kernel_size=2, stride=2, padding=0) + x = self.output(x) + + if self.label_dim: + x = (x * label).sum(dim=1, keepdim=True) + + results = { + 'score': x, + 'label': label + } + return results + + +class MiniBatchSTDLayer(nn.Module): + """Implements the minibatch standard deviation layer.""" + + def __init__(self, groups, new_channels, eps): + super().__init__() + self.groups = groups + self.new_channels = new_channels + self.eps = eps + + def extra_repr(self): + return (f'groups={self.groups}, ' + f'new_channels={self.new_channels}, ' + f'epsilon={self.eps}') + + def forward(self, x): + if self.groups <= 1 or self.new_channels < 1: + return x + + N, C, H, W = x.shape + G = min(self.groups, N) # Number of groups. + nC = self.new_channels # Number of channel groups. + c = C // nC # Channels per channel group. 
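+        # Shape walk-through with illustrative values: N=8, C=512, H=W=4,
+        # G=4, nC=1 gives c=512; the reshape below yields [4, 2, 1, 512, 4, 4],
+        # the standard deviation is taken over the group axis, averaged over
+        # channels and pixels, and broadcast back as one extra feature map per
+        # sample, i.e. an [8, 1, 4, 4] tensor concatenated onto `x`.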
+ + y = x.reshape(G, -1, nC, c, H, W) # [GnFcHW] + y = y - y.mean(dim=0) # [GnFcHW] + y = y.square().mean(dim=0) # [nFcHW] + y = (y + self.eps).sqrt() # [nFcHW] + y = y.mean(dim=(2, 3, 4)) # [nF] + y = y.reshape(-1, nC, 1, 1) # [nF11] + y = y.repeat(G, 1, H, W) # [NFHW] + x = torch.cat((x, y), dim=1) # [N(C+F)HW] + + return x + + +class Blur(torch.autograd.Function): + """Defines blur operation with customized gradient computation.""" + + @staticmethod + def forward(ctx, x, kernel): + assert kernel.shape[2] == 3 and kernel.shape[3] == 3 + ctx.save_for_backward(kernel) + y = F.conv2d(input=x, + weight=kernel, + bias=None, + stride=1, + padding=1, + groups=x.shape[1]) + return y + + @staticmethod + def backward(ctx, dy): + kernel, = ctx.saved_tensors + dx = BlurBackPropagation.apply(dy, kernel) + return dx, None, None + + +class BlurBackPropagation(torch.autograd.Function): + """Defines the back propagation of blur operation. + + NOTE: This is used to speed up the backward of gradient penalty. + """ + + @staticmethod + def forward(ctx, dy, kernel): + ctx.save_for_backward(kernel) + dx = F.conv2d(input=dy, + weight=kernel.flip((2, 3)), + bias=None, + stride=1, + padding=1, + groups=dy.shape[1]) + return dx + + @staticmethod + def backward(ctx, ddx): + kernel, = ctx.saved_tensors + ddy = F.conv2d(input=ddx, + weight=kernel, + bias=None, + stride=1, + padding=1, + groups=ddx.shape[1]) + return ddy, None, None + + +class ConvLayer(nn.Module): + """Implements the convolutional layer. + + If downsampling is needed (i.e., `scale_factor = 2`), the feature map will + be filtered with `filter_kernel` first. If `fused_scale` is set as `True`, + `conv2d` and `downsample` will be fused as one operator, using stride + convolution. + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + add_bias, + scale_factor, + fused_scale, + filter_kernel, + use_wscale, + wscale_gain, + lr_mul, + activation_type): + """Initializes with layer settings. + + Args: + in_channels: Number of channels of the input tensor. + out_channels: Number of channels of the output tensor. + kernel_size: Size of the convolutional kernels. + add_bias: Whether to add bias onto the convolutional result. + scale_factor: Scale factor for downsampling. `1` means skip + downsampling. + fused_scale: Whether to fuse `conv2d` and `downsample` as one + operator, using stride convolution. + filter_kernel: Kernel used for filtering. + use_wscale: Whether to use weight scaling. + wscale_gain: Gain factor for weight scaling. + lr_mul: Learning multiplier for both weight and bias. + activation_type: Type of activation. 
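+
+        Example (illustrative values only):
+            conv = ConvLayer(in_channels=256, out_channels=512, kernel_size=3,
+                             add_bias=True, scale_factor=2, fused_scale=True,
+                             filter_kernel=(1, 2, 1), use_wscale=True,
+                             wscale_gain=np.sqrt(2.0), lr_mul=1.0,
+                             activation_type='lrelu')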
+ """ + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.add_bias = add_bias + self.scale_factor = scale_factor + self.fused_scale = fused_scale + self.filter_kernel = filter_kernel + self.use_wscale = use_wscale + self.wscale_gain = wscale_gain + self.lr_mul = lr_mul + self.activation_type = activation_type + + weight_shape = (out_channels, in_channels, kernel_size, kernel_size) + fan_in = kernel_size * kernel_size * in_channels + wscale = wscale_gain / np.sqrt(fan_in) + if use_wscale: + self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul) + self.wscale = wscale * lr_mul + else: + self.weight = nn.Parameter( + torch.randn(*weight_shape) * wscale / lr_mul) + self.wscale = lr_mul + + if add_bias: + self.bias = nn.Parameter(torch.zeros(out_channels)) + self.bscale = lr_mul + else: + self.bias = None + + if scale_factor > 1: + assert filter_kernel is not None + kernel = np.array(filter_kernel, dtype=np.float32).reshape(1, -1) + kernel = kernel.T.dot(kernel) + kernel = kernel / np.sum(kernel) + kernel = kernel[np.newaxis, np.newaxis] + self.register_buffer('filter', torch.from_numpy(kernel)) + + if scale_factor > 1 and fused_scale: # use stride convolution. + self.stride = scale_factor + else: + self.stride = 1 + self.padding = kernel_size // 2 + + assert activation_type in ['linear', 'relu', 'lrelu'] + + def extra_repr(self): + return (f'in_ch={self.in_channels}, ' + f'out_ch={self.out_channels}, ' + f'ksize={self.kernel_size}, ' + f'wscale_gain={self.wscale_gain:.3f}, ' + f'bias={self.add_bias}, ' + f'lr_mul={self.lr_mul:.3f}, ' + f'downsample={self.scale_factor}, ' + f'fused_scale={self.fused_scale}, ' + f'downsample_filter={self.filter_kernel}, ' + f'act={self.activation_type}') + + def forward(self, x): + if self.scale_factor > 1: + # Disable `autocast` for customized autograd function. + # Please check reference: + # https://pytorch.org/docs/stable/notes/amp_examples.html#autocast-and-custom-autograd-functions + with autocast(enabled=False): + f = self.filter.repeat(self.in_channels, 1, 1, 1) + x = Blur.apply(x.float(), f) # Always use FP32. + + weight = self.weight + if self.wscale != 1.0: + weight = weight * self.wscale + bias = None + if self.bias is not None: + bias = self.bias + if self.bscale != 1.0: + bias = bias * self.bscale + + if self.scale_factor > 1 and self.fused_scale: + weight = F.pad(weight, (1, 1, 1, 1, 0, 0, 0, 0), 'constant', 0.0) + weight = (weight[:, :, 1:, 1:] + weight[:, :, :-1, 1:] + + weight[:, :, 1:, :-1] + weight[:, :, :-1, :-1]) * 0.25 + x = F.conv2d(x, + weight=weight, + bias=bias, + stride=self.stride, + padding=self.padding) + if self.scale_factor > 1 and not self.fused_scale: + down = self.scale_factor + x = F.avg_pool2d(x, kernel_size=down, stride=down, padding=0) + + if self.activation_type == 'linear': + pass + elif self.activation_type == 'relu': + x = F.relu(x, inplace=True) + elif self.activation_type == 'lrelu': + x = F.leaky_relu(x, negative_slope=0.2, inplace=True) + else: + raise NotImplementedError(f'Not implemented activation type ' + f'`{self.activation_type}`!') + + return x + + +class DenseLayer(nn.Module): + """Implements the dense layer.""" + + def __init__(self, + in_channels, + out_channels, + add_bias, + use_wscale, + wscale_gain, + lr_mul, + activation_type): + """Initializes with layer settings. + + Args: + in_channels: Number of channels of the input tensor. + out_channels: Number of channels of the output tensor. 
+ add_bias: Whether to add bias onto the fully-connected result. + use_wscale: Whether to use weight scaling. + wscale_gain: Gain factor for weight scaling. + lr_mul: Learning multiplier for both weight and bias. + activation_type: Type of activation. + """ + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.add_bias = add_bias + self.use_wscale = use_wscale + self.wscale_gain = wscale_gain + self.lr_mul = lr_mul + self.activation_type = activation_type + + weight_shape = (out_channels, in_channels) + wscale = wscale_gain / np.sqrt(in_channels) + if use_wscale: + self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul) + self.wscale = wscale * lr_mul + else: + self.weight = nn.Parameter( + torch.randn(*weight_shape) * wscale / lr_mul) + self.wscale = lr_mul + + if add_bias: + self.bias = nn.Parameter(torch.zeros(out_channels)) + self.bscale = lr_mul + else: + self.bias = None + + assert activation_type in ['linear', 'relu', 'lrelu'] + + def extra_repr(self): + return (f'in_ch={self.in_channels}, ' + f'out_ch={self.out_channels}, ' + f'wscale_gain={self.wscale_gain:.3f}, ' + f'bias={self.add_bias}, ' + f'lr_mul={self.lr_mul:.3f}, ' + f'act={self.activation_type}') + + def forward(self, x): + if x.ndim != 2: + x = x.flatten(start_dim=1) + + weight = self.weight + if self.wscale != 1.0: + weight = weight * self.wscale + bias = None + if self.bias is not None: + bias = self.bias + if self.bscale != 1.0: + bias = bias * self.bscale + + x = F.linear(x, weight=weight, bias=bias) + + if self.activation_type == 'linear': + pass + elif self.activation_type == 'relu': + x = F.relu(x, inplace=True) + elif self.activation_type == 'lrelu': + x = F.leaky_relu(x, negative_slope=0.2, inplace=True) + else: + raise NotImplementedError(f'Not implemented activation type ' + f'`{self.activation_type}`!') + + return x + +# pylint: enable=missing-function-docstring diff --git a/models/stylegan_generator.py b/models/stylegan_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..c0034b34a5b72bfe6b305a9f6ff8d772b391c4f5 --- /dev/null +++ b/models/stylegan_generator.py @@ -0,0 +1,999 @@ +# python3.7 +"""Contains the implementation of generator described in StyleGAN. + +Paper: https://arxiv.org/pdf/1812.04948.pdf + +Official TensorFlow implementation: https://github.com/NVlabs/stylegan +""" + +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.cuda.amp import autocast + +from .utils.ops import all_gather + +__all__ = ['StyleGANGenerator'] + +# Resolutions allowed. +_RESOLUTIONS_ALLOWED = [8, 16, 32, 64, 128, 256, 512, 1024] + +# Fused-scale options allowed. +_FUSED_SCALE_ALLOWED = [True, False, 'auto'] + +# pylint: disable=missing-function-docstring + +class StyleGANGenerator(nn.Module): + """Defines the generator network in StyleGAN. + + NOTE: The synthesized images are with `RGB` channel order and pixel range + [-1, 1]. + + Settings for the mapping network: + + (1) z_dim: Dimension of the input latent space, Z. (default: 512) + (2) w_dim: Dimension of the output latent space, W. (default: 512) + (3) repeat_w: Repeat w-code for different layers. (default: True) + (4) normalize_z: Whether to normalize the z-code. (default: True) + (5) mapping_layers: Number of layers of the mapping network. (default: 8) + (6) mapping_fmaps: Number of hidden channels of the mapping network. + (default: 512) + (7) mapping_use_wscale: Whether to use weight scaling for the mapping + network. 
(default: True) + (8) mapping_wscale_gain: The factor to control weight scaling for the + mapping network (default: sqrt(2.0)) + (9) mapping_lr_mul: Learning rate multiplier for the mapping network. + (default: 0.01) + + Settings for conditional generation: + + (1) label_dim: Dimension of the additional label for conditional generation. + In one-hot conditioning case, it is equal to the number of classes. If + set to 0, conditioning training will be disabled. (default: 0) + (2) embedding_dim: Dimension of the embedding space, if needed. + (default: 512) + + Settings for the synthesis network: + + (1) resolution: The resolution of the output image. (default: -1) + (2) init_res: The initial resolution to start with convolution. (default: 4) + (3) image_channels: Number of channels of the output image. (default: 3) + (4) final_tanh: Whether to use `tanh` to control the final pixel range. + (default: False) + (5) fused_scale: The strategy of fusing `upsample` and `conv2d` as one + operator. `True` means blocks from all resolutions will fuse. `False` + means blocks from all resolutions will not fuse. `auto` means blocks + from resolutions higher than (or equal to) `fused_scale_res` will fuse. + (default: `auto`) + (6) fused_scale_res: Minimum resolution to fuse `conv2d` and `downsample` + as one operator. This field only takes effect if `fused_scale` is set + as `auto`. (default: 128) + (7) use_wscale: Whether to use weight scaling. (default: True) + (8) wscale_gain: The factor to control weight scaling. (default: sqrt(2.0)) + (9) lr_mul: Learning rate multiplier for the synthesis network. + (default: 1.0) + (10) noise_type: Type of noise added to the convolutional results at each + layer. (default: `spatial`) + (11) fmaps_base: Factor to control number of feature maps for each layer. + (default: 16 << 10) + (12) fmaps_max: Maximum number of feature maps in each layer. (default: 512) + (13) filter_kernel: Kernel used for filtering (e.g., downsampling). + (default: (1, 2, 1)) + (14) eps: A small value to avoid divide overflow. (default: 1e-8) + + Runtime settings: + + (1) w_moving_decay: Decay factor for updating `w_avg`, which is used for + training only. Set `None` to disable. (default: None) + (2) sync_w_avg: Synchronizing the stats of `w_avg` across replicas. If set + as `True`, the stats will be more accurate, yet the speed maybe a little + bit slower. (default: False) + (3) style_mixing_prob: Probability to perform style mixing as a training + regularization. Set `None` to disable. (default: None) + (4) trunc_psi: Truncation psi, set `None` to disable. (default: None) + (5) trunc_layers: Number of layers to perform truncation. (default: None) + (6) noise_mode: Mode of the layer-wise noise. Support `none`, `random`, + `const`. (default: `const`) + (7) enable_amp: Whether to enable automatic mixed precision training. + (default: False) + """ + + def __init__(self, + # Settings for mapping network. + z_dim=512, + w_dim=512, + repeat_w=True, + normalize_z=True, + mapping_layers=8, + mapping_fmaps=512, + mapping_use_wscale=True, + mapping_wscale_gain=np.sqrt(2.0), + mapping_lr_mul=0.01, + # Settings for conditional generation. + label_dim=0, + embedding_dim=512, + # Settings for synthesis network. 
+ resolution=-1, + init_res=4, + image_channels=3, + final_tanh=False, + fused_scale='auto', + fused_scale_res=128, + use_wscale=True, + wscale_gain=np.sqrt(2.0), + lr_mul=1.0, + noise_type='spatial', + fmaps_base=16 << 10, + fmaps_max=512, + filter_kernel=(1, 2, 1), + eps=1e-8): + """Initializes with basic settings. + + Raises: + ValueError: If the `resolution` is not supported, or `fused_scale` + is not supported. + """ + super().__init__() + + if resolution not in _RESOLUTIONS_ALLOWED: + raise ValueError(f'Invalid resolution: `{resolution}`!\n' + f'Resolutions allowed: {_RESOLUTIONS_ALLOWED}.') + if fused_scale not in _FUSED_SCALE_ALLOWED: + raise ValueError(f'Invalid fused-scale option: `{fused_scale}`!\n' + f'Options allowed: {_FUSED_SCALE_ALLOWED}.') + + self.z_dim = z_dim + self.w_dim = w_dim + self.repeat_w = repeat_w + self.normalize_z = normalize_z + self.mapping_layers = mapping_layers + self.mapping_fmaps = mapping_fmaps + self.mapping_use_wscale = mapping_use_wscale + self.mapping_wscale_gain = mapping_wscale_gain + self.mapping_lr_mul = mapping_lr_mul + + self.label_dim = label_dim + self.embedding_dim = embedding_dim + + self.resolution = resolution + self.init_res = init_res + self.image_channels = image_channels + self.final_tanh = final_tanh + self.fused_scale = fused_scale + self.fused_scale_res = fused_scale_res + self.use_wscale = use_wscale + self.wscale_gain = wscale_gain + self.lr_mul = lr_mul + self.noise_type = noise_type.lower() + self.fmaps_base = fmaps_base + self.fmaps_max = fmaps_max + self.filter_kernel = filter_kernel + self.eps = eps + + # Dimension of latent space, which is convenient for sampling. + self.latent_dim = (z_dim,) + + # Number of synthesis (convolutional) layers. + self.num_layers = int(np.log2(resolution // init_res * 2)) * 2 + + self.mapping = MappingNetwork(input_dim=z_dim, + output_dim=w_dim, + num_outputs=self.num_layers, + repeat_output=repeat_w, + normalize_input=normalize_z, + num_layers=mapping_layers, + hidden_dim=mapping_fmaps, + use_wscale=mapping_use_wscale, + wscale_gain=mapping_wscale_gain, + lr_mul=mapping_lr_mul, + label_dim=label_dim, + embedding_dim=embedding_dim, + eps=eps) + + # This is used for truncation trick. + if self.repeat_w: + self.register_buffer('w_avg', torch.zeros(w_dim)) + else: + self.register_buffer('w_avg', torch.zeros(self.num_layers * w_dim)) + + self.synthesis = SynthesisNetwork(resolution=resolution, + init_res=init_res, + w_dim=w_dim, + image_channels=image_channels, + final_tanh=final_tanh, + fused_scale=fused_scale, + fused_scale_res=fused_scale_res, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + lr_mul=lr_mul, + noise_type=noise_type, + fmaps_base=fmaps_base, + fmaps_max=fmaps_max, + filter_kernel=filter_kernel, + eps=eps) + + self.pth_to_tf_var_mapping = {'w_avg': 'dlatent_avg'} + for key, val in self.mapping.pth_to_tf_var_mapping.items(): + self.pth_to_tf_var_mapping[f'mapping.{key}'] = val + for key, val in self.synthesis.pth_to_tf_var_mapping.items(): + self.pth_to_tf_var_mapping[f'synthesis.{key}'] = val + + def set_space_of_latent(self, space_of_latent): + """Sets the space to which the latent code belong. + + See `SynthesisNetwork` for more details. 
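+
+        For example, `generator.set_space_of_latent('y')` switches all
+        style-modulated layers to expect already affine-transformed Style
+        Space codes (the value is case insensitive; `W` is the default).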
+ """ + self.synthesis.set_space_of_latent(space_of_latent) + + def forward(self, + z, + label=None, + lod=None, + w_moving_decay=None, + sync_w_avg=False, + style_mixing_prob=None, + trunc_psi=None, + trunc_layers=None, + noise_mode='const', + enable_amp=False): + mapping_results = self.mapping(z, label) + + w = mapping_results['w'] + if self.training and w_moving_decay is not None: + if sync_w_avg: + batch_w_avg = all_gather(w.detach()).mean(dim=0) + else: + batch_w_avg = w.detach().mean(dim=0) + self.w_avg.copy_(batch_w_avg.lerp(self.w_avg, w_moving_decay)) + + wp = mapping_results.pop('wp') + if self.training and style_mixing_prob is not None: + if np.random.uniform() < style_mixing_prob: + new_z = torch.randn_like(z) + new_wp = self.mapping(new_z, label)['wp'] + lod = self.synthesis.lod.item() if lod is None else lod + current_layers = self.num_layers - int(lod) * 2 + mixing_cutoff = np.random.randint(1, current_layers) + wp[:, mixing_cutoff:] = new_wp[:, mixing_cutoff:] + + if not self.training: + trunc_psi = 1.0 if trunc_psi is None else trunc_psi + trunc_layers = 0 if trunc_layers is None else trunc_layers + if trunc_psi < 1.0 and trunc_layers > 0: + w_avg = self.w_avg.reshape(1, -1, self.w_dim)[:, :trunc_layers] + wp[:, :trunc_layers] = w_avg.lerp( + wp[:, :trunc_layers], trunc_psi) + + with autocast(enabled=enable_amp): + synthesis_results = self.synthesis(wp, + lod=lod, + noise_mode=noise_mode) + + return {**mapping_results, **synthesis_results} + + +class MappingNetwork(nn.Module): + """Implements the latent space mapping module. + + Basically, this module executes several dense layers in sequence, and the + label embedding if needed. + """ + + def __init__(self, + input_dim, + output_dim, + num_outputs, + repeat_output, + normalize_input, + num_layers, + hidden_dim, + use_wscale, + wscale_gain, + lr_mul, + label_dim, + embedding_dim, + eps): + super().__init__() + + self.input_dim = input_dim + self.output_dim = output_dim + self.num_outputs = num_outputs + self.repeat_output = repeat_output + self.normalize_input = normalize_input + self.num_layers = num_layers + self.hidden_dim = hidden_dim + self.use_wscale = use_wscale + self.wscale_gain = wscale_gain + self.lr_mul = lr_mul + self.label_dim = label_dim + self.embedding_dim = embedding_dim + self.eps = eps + + self.pth_to_tf_var_mapping = {} + + if normalize_input: + self.norm = PixelNormLayer(dim=1, eps=eps) + + if self.label_dim > 0: + input_dim = input_dim + embedding_dim + self.embedding = nn.Parameter( + torch.randn(label_dim, embedding_dim)) + self.pth_to_tf_var_mapping['embedding'] = 'LabelConcat/weight' + + if num_outputs is not None and not repeat_output: + output_dim = output_dim * num_outputs + for i in range(num_layers): + in_channels = (input_dim if i == 0 else hidden_dim) + out_channels = (output_dim if i == (num_layers - 1) else hidden_dim) + self.add_module(f'dense{i}', + DenseLayer(in_channels=in_channels, + out_channels=out_channels, + add_bias=True, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + lr_mul=lr_mul, + activation_type='lrelu')) + self.pth_to_tf_var_mapping[f'dense{i}.weight'] = f'Dense{i}/weight' + self.pth_to_tf_var_mapping[f'dense{i}.bias'] = f'Dense{i}/bias' + + def forward(self, z, label=None): + if z.ndim != 2 or z.shape[1] != self.input_dim: + raise ValueError(f'Input latent code should be with shape ' + f'[batch_size, input_dim], where ' + f'`input_dim` equals to {self.input_dim}!\n' + f'But `{z.shape}` is received!') + + if self.label_dim > 0: + if label is None: + raise 
ValueError(f'Model requires an additional label ' + f'(with dimension {self.label_dim}) as input, ' + f'but no label is received!') + if label.ndim != 2 or label.shape != (z.shape[0], self.label_dim): + raise ValueError(f'Input label should be with shape ' + f'[batch_size, label_dim], where ' + f'`batch_size` equals to that of ' + f'latent codes ({z.shape[0]}) and ' + f'`label_dim` equals to {self.label_dim}!\n' + f'But `{label.shape}` is received!') + label = label.to(dtype=torch.float32) + embedding = torch.matmul(label, self.embedding) + z = torch.cat((z, embedding), dim=1) + + if self.normalize_input: + w = self.norm(z) + else: + w = z + + for i in range(self.num_layers): + w = getattr(self, f'dense{i}')(w) + + wp = None + if self.num_outputs is not None: + if self.repeat_output: + wp = w.unsqueeze(1).repeat((1, self.num_outputs, 1)) + else: + wp = w.reshape(-1, self.num_outputs, self.output_dim) + + results = { + 'z': z, + 'label': label, + 'w': w, + 'wp': wp, + } + if self.label_dim > 0: + results['embedding'] = embedding + return results + + +class SynthesisNetwork(nn.Module): + """Implements the image synthesis module. + + Basically, this module executes several convolutional layers in sequence. + """ + + def __init__(self, + resolution, + init_res, + w_dim, + image_channels, + final_tanh, + fused_scale, + fused_scale_res, + use_wscale, + wscale_gain, + lr_mul, + noise_type, + fmaps_base, + fmaps_max, + filter_kernel, + eps): + super().__init__() + + self.init_res = init_res + self.init_res_log2 = int(np.log2(init_res)) + self.resolution = resolution + self.final_res_log2 = int(np.log2(resolution)) + self.w_dim = w_dim + self.image_channels = image_channels + self.final_tanh = final_tanh + self.fused_scale = fused_scale + self.fused_scale_res = fused_scale_res + self.use_wscale = use_wscale + self.wscale_gain = wscale_gain + self.lr_mul = lr_mul + self.noise_type = noise_type.lower() + self.fmaps_base = fmaps_base + self.fmaps_max = fmaps_max + self.eps = eps + + self.num_layers = (self.final_res_log2 - self.init_res_log2 + 1) * 2 + + # Level-of-details (used for progressive training). 
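+        # `lod` counts resolution levels down from the final one: with
+        # resolution 256 and init_res 4, lod 0 runs the full 256x256 network,
+        # lod 6 runs only the 4x4 block (its output is then upsampled), and
+        # fractional values blend the outputs of two adjacent resolutions in
+        # `forward()`.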
+ self.register_buffer('lod', torch.zeros(())) + self.pth_to_tf_var_mapping = {'lod': 'lod'} + + for res_log2 in range(self.init_res_log2, self.final_res_log2 + 1): + res = 2 ** res_log2 + in_channels = self.get_nf(res // 2) + out_channels = self.get_nf(res) + block_idx = res_log2 - self.init_res_log2 + + # First layer (kernel 3x3) with upsampling + layer_name = f'layer{2 * block_idx}' + if res == self.init_res: + self.add_module(layer_name, + ModulateConvLayer(in_channels=0, + out_channels=out_channels, + resolution=res, + w_dim=w_dim, + kernel_size=None, + add_bias=True, + scale_factor=None, + fused_scale=None, + filter_kernel=None, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + lr_mul=lr_mul, + noise_type=noise_type, + activation_type='lrelu', + use_style=True, + eps=eps)) + tf_layer_name = 'Const' + self.pth_to_tf_var_mapping[f'{layer_name}.const'] = ( + f'{res}x{res}/{tf_layer_name}/const') + else: + self.add_module( + layer_name, + ModulateConvLayer(in_channels=in_channels, + out_channels=out_channels, + resolution=res, + w_dim=w_dim, + kernel_size=3, + add_bias=True, + scale_factor=2, + fused_scale=(res >= fused_scale_res + if fused_scale == 'auto' + else fused_scale), + filter_kernel=filter_kernel, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + lr_mul=lr_mul, + noise_type=noise_type, + activation_type='lrelu', + use_style=True, + eps=eps)) + tf_layer_name = 'Conv0_up' + self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = ( + f'{res}x{res}/{tf_layer_name}/weight') + self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = ( + f'{res}x{res}/{tf_layer_name}/bias') + self.pth_to_tf_var_mapping[f'{layer_name}.style.weight'] = ( + f'{res}x{res}/{tf_layer_name}/StyleMod/weight') + self.pth_to_tf_var_mapping[f'{layer_name}.style.bias'] = ( + f'{res}x{res}/{tf_layer_name}/StyleMod/bias') + self.pth_to_tf_var_mapping[f'{layer_name}.noise_strength'] = ( + f'{res}x{res}/{tf_layer_name}/Noise/weight') + self.pth_to_tf_var_mapping[f'{layer_name}.noise'] = ( + f'noise{2 * block_idx}') + + # Second layer (kernel 3x3) without upsampling. + layer_name = f'layer{2 * block_idx + 1}' + self.add_module(layer_name, + ModulateConvLayer(in_channels=out_channels, + out_channels=out_channels, + resolution=res, + w_dim=w_dim, + kernel_size=3, + add_bias=True, + scale_factor=1, + fused_scale=False, + filter_kernel=None, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + lr_mul=lr_mul, + noise_type=noise_type, + activation_type='lrelu', + use_style=True, + eps=eps)) + tf_layer_name = 'Conv' if res == self.init_res else 'Conv1' + self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = ( + f'{res}x{res}/{tf_layer_name}/weight') + self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = ( + f'{res}x{res}/{tf_layer_name}/bias') + self.pth_to_tf_var_mapping[f'{layer_name}.style.weight'] = ( + f'{res}x{res}/{tf_layer_name}/StyleMod/weight') + self.pth_to_tf_var_mapping[f'{layer_name}.style.bias'] = ( + f'{res}x{res}/{tf_layer_name}/StyleMod/bias') + self.pth_to_tf_var_mapping[f'{layer_name}.noise_strength'] = ( + f'{res}x{res}/{tf_layer_name}/Noise/weight') + self.pth_to_tf_var_mapping[f'{layer_name}.noise'] = ( + f'noise{2 * block_idx + 1}') + + # Output convolution layer for each resolution. 
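+            # Per-resolution ToRGB layer: a plain 1x1 convolution (no style
+            # modulation, no noise) mapping features to `image_channels`,
+            # used by `forward()` for progressive blending across
+            # resolutions.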
+ self.add_module(f'output{block_idx}', + ModulateConvLayer(in_channels=out_channels, + out_channels=image_channels, + resolution=res, + w_dim=w_dim, + kernel_size=1, + add_bias=True, + scale_factor=1, + fused_scale=False, + filter_kernel=None, + use_wscale=use_wscale, + wscale_gain=1.0, + lr_mul=lr_mul, + noise_type='none', + activation_type='linear', + use_style=False, + eps=eps)) + self.pth_to_tf_var_mapping[f'output{block_idx}.weight'] = ( + f'ToRGB_lod{self.final_res_log2 - res_log2}/weight') + self.pth_to_tf_var_mapping[f'output{block_idx}.bias'] = ( + f'ToRGB_lod{self.final_res_log2 - res_log2}/bias') + + def get_nf(self, res): + """Gets number of feature maps according to the given resolution.""" + return min(self.fmaps_base // res, self.fmaps_max) + + def set_space_of_latent(self, space_of_latent): + """Sets the space to which the latent code belong. + + This function is particularly used for choosing how to inject the latent + code into the convolutional layers. The original generator will take a + W-Space code and apply it for style modulation after an affine + transformation. But, sometimes, it may need to directly feed an already + affine-transformed code into the convolutional layer, e.g., when + training an encoder for GAN inversion. We term the transformed space as + Style Space (or Y-Space). This function is designed to tell the + convolutional layers how to use the input code. + + Args: + space_of_latent: The space to which the latent code belong. Case + insensitive. Support `W` and `Y`. + """ + space_of_latent = space_of_latent.upper() + for module in self.modules(): + if isinstance(module, ModulateConvLayer) and module.use_style: + setattr(module, 'space_of_latent', space_of_latent) + + def forward(self, wp, lod=None, noise_mode='const'): + lod = self.lod.item() if lod is None else lod + if lod + self.init_res_log2 > self.final_res_log2: + raise ValueError(f'Maximum level-of-details (lod) is ' + f'{self.final_res_log2 - self.init_res_log2}, ' + f'but `{lod}` is received!') + + results = {'wp': wp} + x = None + for res_log2 in range(self.init_res_log2, self.final_res_log2 + 1): + current_lod = self.final_res_log2 - res_log2 + block_idx = res_log2 - self.init_res_log2 + if lod < current_lod + 1: + layer = getattr(self, f'layer{2 * block_idx}') + x, style = layer(x, wp[:, 2 * block_idx], noise_mode) + results[f'style{2 * block_idx}'] = style + layer = getattr(self, f'layer{2 * block_idx + 1}') + x, style = layer(x, wp[:, 2 * block_idx + 1], noise_mode) + results[f'style{2 * block_idx + 1}'] = style + if current_lod - 1 < lod <= current_lod: + image = getattr(self, f'output{block_idx}')(x) + elif current_lod < lod < current_lod + 1: + alpha = np.ceil(lod) - lod + temp = getattr(self, f'output{block_idx}')(x) + image = F.interpolate(image, scale_factor=2, mode='nearest') + image = temp * alpha + image * (1 - alpha) + elif lod >= current_lod + 1: + image = F.interpolate(image, scale_factor=2, mode='nearest') + + if self.final_tanh: + image = torch.tanh(image) + results['image'] = image + return results + + +class PixelNormLayer(nn.Module): + """Implements pixel-wise feature vector normalization layer.""" + + def __init__(self, dim, eps): + super().__init__() + self.dim = dim + self.eps = eps + + def extra_repr(self): + return f'dim={self.dim}, epsilon={self.eps}' + + def forward(self, x): + scale = (x.square().mean(dim=self.dim, keepdim=True) + self.eps).rsqrt() + return x * scale + + +class Blur(torch.autograd.Function): + """Defines blur operation with customized gradient 
computation.""" + + @staticmethod + def forward(ctx, x, kernel): + assert kernel.shape[2] == 3 and kernel.shape[3] == 3 + ctx.save_for_backward(kernel) + y = F.conv2d(input=x, + weight=kernel, + bias=None, + stride=1, + padding=1, + groups=x.shape[1]) + return y + + @staticmethod + def backward(ctx, dy): + kernel, = ctx.saved_tensors + dx = F.conv2d(input=dy, + weight=kernel.flip((2, 3)), + bias=None, + stride=1, + padding=1, + groups=dy.shape[1]) + return dx, None, None + + +class ModulateConvLayer(nn.Module): + """Implements the convolutional layer with style modulation.""" + + def __init__(self, + in_channels, + out_channels, + resolution, + w_dim, + kernel_size, + add_bias, + scale_factor, + fused_scale, + filter_kernel, + use_wscale, + wscale_gain, + lr_mul, + noise_type, + activation_type, + use_style, + eps): + """Initializes with layer settings. + + Args: + in_channels: Number of channels of the input tensor. + out_channels: Number of channels of the output tensor. + resolution: Resolution of the output tensor. + w_dim: Dimension of W space for style modulation. + kernel_size: Size of the convolutional kernels. + add_bias: Whether to add bias onto the convolutional result. + scale_factor: Scale factor for upsampling. + fused_scale: Whether to fuse `upsample` and `conv2d` as one + operator, using transpose convolution. + filter_kernel: Kernel used for filtering. + use_wscale: Whether to use weight scaling. + wscale_gain: Gain factor for weight scaling. + lr_mul: Learning multiplier for both weight and bias. + noise_type: Type of noise added to the feature map after the + convolution (if needed). Support `none`, `spatial` and + `channel`. + activation_type: Type of activation. + use_style: Whether to apply style modulation. + eps: A small value to avoid divide overflow. + """ + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.resolution = resolution + self.w_dim = w_dim + self.kernel_size = kernel_size + self.add_bias = add_bias + self.scale_factor = scale_factor + self.fused_scale = fused_scale + self.filter_kernel = filter_kernel + self.use_wscale = use_wscale + self.wscale_gain = wscale_gain + self.lr_mul = lr_mul + self.noise_type = noise_type.lower() + self.activation_type = activation_type + self.use_style = use_style + self.eps = eps + + # Set up noise. + if self.noise_type == 'none': + pass + elif self.noise_type == 'spatial': + self.register_buffer( + 'noise', torch.randn(1, 1, resolution, resolution)) + self.noise_strength = nn.Parameter( + torch.zeros(1, out_channels, 1, 1)) + elif self.noise_type == 'channel': + self.register_buffer( + 'noise', torch.randn(1, out_channels, 1, 1)) + self.noise_strength = nn.Parameter( + torch.zeros(1, 1, resolution, resolution)) + else: + raise NotImplementedError(f'Not implemented noise type: ' + f'`{noise_type}`!') + + # Set up bias. + if add_bias: + self.bias = nn.Parameter(torch.zeros(out_channels)) + self.bscale = lr_mul + else: + self.bias = None + + # Set up activation. + assert activation_type in ['linear', 'relu', 'lrelu'] + + # Set up style. + if use_style: + self.space_of_latent = 'W' + self.style = DenseLayer(in_channels=w_dim, + out_channels=out_channels * 2, + add_bias=True, + use_wscale=use_wscale, + wscale_gain=1.0, + lr_mul=1.0, + activation_type='linear') + + if in_channels == 0: # First layer. + self.const = nn.Parameter( + torch.ones(1, out_channels, resolution, resolution)) + return + + # Set up weight. 
+ weight_shape = (out_channels, in_channels, kernel_size, kernel_size) + fan_in = kernel_size * kernel_size * in_channels + wscale = wscale_gain / np.sqrt(fan_in) + if use_wscale: + self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul) + self.wscale = wscale * lr_mul + else: + self.weight = nn.Parameter( + torch.randn(*weight_shape) * wscale / lr_mul) + self.wscale = lr_mul + + # Set up upsampling filter (if needed). + if scale_factor > 1: + assert filter_kernel is not None + kernel = np.array(filter_kernel, dtype=np.float32).reshape(1, -1) + kernel = kernel.T.dot(kernel) + kernel = kernel / np.sum(kernel) + kernel = kernel[np.newaxis, np.newaxis] + self.register_buffer('filter', torch.from_numpy(kernel)) + + if scale_factor > 1 and fused_scale: # use transpose convolution. + self.stride = scale_factor + else: + self.stride = 1 + self.padding = kernel_size // 2 + + def extra_repr(self): + return (f'in_ch={self.in_channels}, ' + f'out_ch={self.out_channels}, ' + f'ksize={self.kernel_size}, ' + f'wscale_gain={self.wscale_gain:.3f}, ' + f'bias={self.add_bias}, ' + f'lr_mul={self.lr_mul:.3f}, ' + f'upsample={self.scale_factor}, ' + f'fused_scale={self.fused_scale}, ' + f'upsample_filter={self.filter_kernel}, ' + f'noise_type={self.noise_type}, ' + f'act={self.activation_type}, ' + f'use_style={self.use_style}') + + def forward_style(self, w): + """Gets style code from the given input. + + More specifically, if the input is from W-Space, it will be projected by + an affine transformation. If it is from the Style Space (Y-Space), no + operation is required. + + NOTE: For codes from Y-Space, we use slicing to make sure the dimension + is correct, in case that the code is padded before fed into this layer. + """ + space_of_latent = self.space_of_latent.upper() + if space_of_latent == 'W': + if w.ndim != 2 or w.shape[1] != self.w_dim: + raise ValueError(f'The input tensor should be with shape ' + f'[batch_size, w_dim], where ' + f'`w_dim` equals to {self.w_dim}!\n' + f'But `{w.shape}` is received!') + style = self.style(w) + elif space_of_latent == 'Y': + if w.ndim != 2 or w.shape[1] < self.out_channels * 2: + raise ValueError(f'The input tensor should be with shape ' + f'[batch_size, y_dim], where ' + f'`y_dim` equals to {self.out_channels * 2}!\n' + f'But `{w.shape}` is received!') + style = w[:, :self.out_channels * 2] + else: + raise NotImplementedError(f'Not implemented `space_of_latent`: ' + f'`{space_of_latent}`!') + return style + + def forward(self, x, w=None, noise_mode='const'): + if self.in_channels == 0: + assert x is None + x = self.const.repeat(w.shape[0], 1, 1, 1) + else: + weight = self.weight + if self.wscale != 1.0: + weight = weight * self.wscale + + if self.scale_factor > 1 and self.fused_scale: + weight = F.pad(weight, (1, 1, 1, 1, 0, 0, 0, 0), 'constant', 0) + weight = (weight[:, :, 1:, 1:] + weight[:, :, :-1, 1:] + + weight[:, :, 1:, :-1] + weight[:, :, :-1, :-1]) + x = F.conv_transpose2d(x, + weight=weight.transpose(0, 1), + bias=None, + stride=self.stride, + padding=self.padding) + else: + if self.scale_factor > 1: + up = self.scale_factor + x = F.interpolate(x, scale_factor=up, mode='nearest') + x = F.conv2d(x, + weight=weight, + bias=None, + stride=self.stride, + padding=self.padding) + + if self.scale_factor > 1: + # Disable `autocast` for customized autograd function. 
+ # Please check reference: + # https://pytorch.org/docs/stable/notes/amp_examples.html#autocast-and-custom-autograd-functions + with autocast(enabled=False): + f = self.filter.repeat(self.out_channels, 1, 1, 1) + x = Blur.apply(x.float(), f) # Always use FP32. + + # Prepare noise. + noise_mode = noise_mode.lower() + if self.noise_type != 'none' and noise_mode != 'none': + if noise_mode == 'random': + noise = torch.randn( + (x.shape[0], *self.noise.shape[1:]), device=x.device) + elif noise_mode == 'const': + noise = self.noise + else: + raise ValueError(f'Unknown noise mode `{noise_mode}`!') + x = x + noise * self.noise_strength + + if self.bias is not None: + bias = self.bias + if self.bscale != 1.0: + bias = bias * self.bscale + x = x + bias.reshape(1, self.out_channels, 1, 1) + + if self.activation_type == 'linear': + pass + elif self.activation_type == 'relu': + x = F.relu(x, inplace=True) + elif self.activation_type == 'lrelu': + x = F.leaky_relu(x, negative_slope=0.2, inplace=True) + else: + raise NotImplementedError(f'Not implemented activation type ' + f'`{self.activation_type}`!') + + if not self.use_style: + return x + + # Instance normalization. + x = x - x.mean(dim=(2, 3), keepdim=True) + scale = (x.square().mean(dim=(2, 3), keepdim=True) + self.eps).rsqrt() + x = x * scale + # Style modulation. + style = self.forward_style(w) + style_split = style.unsqueeze(2).unsqueeze(3).chunk(2, dim=1) + x = x * (style_split[0] + 1) + style_split[1] + + return x, style + + +class DenseLayer(nn.Module): + """Implements the dense layer.""" + + def __init__(self, + in_channels, + out_channels, + add_bias, + use_wscale, + wscale_gain, + lr_mul, + activation_type): + """Initializes with layer settings. + + Args: + in_channels: Number of channels of the input tensor. + out_channels: Number of channels of the output tensor. + add_bias: Whether to add bias onto the fully-connected result. + use_wscale: Whether to use weight scaling. + wscale_gain: Gain factor for weight scaling. + lr_mul: Learning multiplier for both weight and bias. + activation_type: Type of activation. 
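+
+        Example (illustrative values, matching the mapping-network defaults):
+            dense = DenseLayer(in_channels=512, out_channels=512,
+                               add_bias=True, use_wscale=True,
+                               wscale_gain=np.sqrt(2.0), lr_mul=0.01,
+                               activation_type='lrelu')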
+ """ + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.add_bias = add_bias + self.use_wscale = use_wscale + self.wscale_gain = wscale_gain + self.lr_mul = lr_mul + self.activation_type = activation_type + + weight_shape = (out_channels, in_channels) + wscale = wscale_gain / np.sqrt(in_channels) + if use_wscale: + self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul) + self.wscale = wscale * lr_mul + else: + self.weight = nn.Parameter( + torch.randn(*weight_shape) * wscale / lr_mul) + self.wscale = lr_mul + + if add_bias: + self.bias = nn.Parameter(torch.zeros(out_channels)) + self.bscale = lr_mul + else: + self.bias = None + + assert activation_type in ['linear', 'relu', 'lrelu'] + + def extra_repr(self): + return (f'in_ch={self.in_channels}, ' + f'out_ch={self.out_channels}, ' + f'wscale_gain={self.wscale_gain:.3f}, ' + f'bias={self.add_bias}, ' + f'lr_mul={self.lr_mul:.3f}, ' + f'act={self.activation_type}') + + def forward(self, x): + if x.ndim != 2: + x = x.flatten(start_dim=1) + + weight = self.weight + if self.wscale != 1.0: + weight = weight * self.wscale + bias = None + if self.bias is not None: + bias = self.bias + if self.bscale != 1.0: + bias = bias * self.bscale + + x = F.linear(x, weight=weight, bias=bias) + + if self.activation_type == 'linear': + pass + elif self.activation_type == 'relu': + x = F.relu(x, inplace=True) + elif self.activation_type == 'lrelu': + x = F.leaky_relu(x, negative_slope=0.2, inplace=True) + else: + raise NotImplementedError(f'Not implemented activation type ' + f'`{self.activation_type}`!') + + return x + +# pylint: enable=missing-function-docstring diff --git a/models/stylenerf_discriminator.py b/models/stylenerf_discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..042031fb63fba1e484f7881417a1f8bd0998185b --- /dev/null +++ b/models/stylenerf_discriminator.py @@ -0,0 +1,256 @@ +# python3.8 +"""Contains implementation of Discriminator described in StyleNeRF.""" + +import math +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from models.utils.ops import upsample +from models.utils.ops import downsample +from models.utils.camera import camera_9d_to_16d + +from models.utils.official_stylegan2_model_helper import EqualConv2d +from models.utils.official_stylegan2_model_helper import MappingNetwork +from models.utils.official_stylegan2_model_helper import DiscriminatorBlock +from models.utils.official_stylegan2_model_helper import DiscriminatorEpilogue + + +class Discriminator(nn.Module): + def __init__(self, + c_dim, # Conditioning label (C) dimensionality. + img_resolution, # Input resolution. + img_channels, # Number of input color channels. + architecture = 'resnet', # Architecture: 'orig', 'skip', 'resnet'. + channel_base = 1, # Overall multiplier for the number of channels. + channel_max = 512, # Maximum number of channels in any layer. + num_fp16_res = 0, # Use FP16 for the N highest resolutions. + conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping. + cmap_dim = None, # Dimensionality of mapped conditioning label, None = default. + lowres_head = None, # add a low-resolution discriminator head + dual_discriminator = False, # add low-resolution (NeRF) image + dual_input_ratio = None, # optional another low-res image input, which will be interpolated to the main input + block_kwargs = {}, # Arguments for DiscriminatorBlock. + mapping_kwargs = {}, # Arguments for MappingNetwork. 
+ epilogue_kwargs = {}, # Arguments for DiscriminatorEpilogue. + upsample_type = 'default', + + progressive = False, + resize_real_early = False, # Peform resizing before the training loop + enable_ema = False, # Additionally save an EMA checkpoint + + predict_camera = False, # Learn camera predictor as InfoGAN + predict_9d_camera = False, # Use 9D camera distribution + predict_3d_camera = False, # Use 3D camera (u, v, r), assuming camera is on the unit sphere + no_camera_condition = False, # Disable camera conditioning in the discriminator + saperate_camera = False, # by default, only works in the lowest resolution. + **unused + ): + super().__init__() + # setup parameters + self.img_resolution = img_resolution + self.img_resolution_log2 = int(np.log2(img_resolution)) + self.img_channels = img_channels + self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)] + self.architecture = architecture + self.lowres_head = lowres_head + self.dual_input_ratio = dual_input_ratio + self.dual_discriminator = dual_discriminator + self.upsample_type = upsample_type + self.progressive = progressive + self.resize_real_early = resize_real_early + self.enable_ema = enable_ema + self.predict_camera = predict_camera + self.predict_9d_camera = predict_9d_camera + self.predict_3d_camera = predict_3d_camera + self.no_camera_condition = no_camera_condition + self.separate_camera = saperate_camera + if self.progressive: + assert self.architecture == 'skip', "not supporting other types for now." + if self.dual_input_ratio is not None: # similar to EG3d, concat low/high-res images + self.img_channels = self.img_channels * 2 + if self.predict_camera: + assert not (self.predict_9d_camera and self.predict_3d_camera), "cannot achieve at the same time" + channel_base = int(channel_base * 32768) + channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]} + fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8) + + # camera prediction module + self.c_dim = c_dim + if predict_camera: + if not self.no_camera_condition: + if self.predict_3d_camera: + self.c_dim = out_dim = 3 # (u, v) on the sphere + else: + self.c_dim = 16 # extrinsic 4x4 (for now) + if self.predict_9d_camera: + out_dim = 9 + else: + out_dim = 16 + self.projector = EqualConv2d(channels_dict[4], out_dim, 4, padding=0, bias=False) + + if cmap_dim is None: + cmap_dim = channels_dict[4] + if self.c_dim == 0: + cmap_dim = 0 + if self.c_dim > 0: + self.mapping = MappingNetwork(z_dim=0, c_dim=self.c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs) + + # main discriminator blocks + common_kwargs = dict(img_channels=self.img_channels, architecture=architecture, conv_clamp=conv_clamp) + cur_layer_idx = 0 + for res in self.block_resolutions: + in_channels = channels_dict[res] if res < img_resolution else 0 + tmp_channels = channels_dict[res] + out_channels = channels_dict[res // 2] + use_fp16 = (res >= fp16_resolution) + block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res, + first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs) + setattr(self, f'b{res}', block) + cur_layer_idx += block.num_layers + + # dual discriminator or separate camera predictor + if self.separate_camera or self.dual_discriminator: + cur_layer_idx = 0 + for res in [r for r in self.block_resolutions if r <= self.lowres_head]: + in_channels = channels_dict[res] if res < img_resolution else 0 + tmp_channels = channels_dict[res] + 
out_channels = channels_dict[res // 2] + block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res, + first_layer_idx=cur_layer_idx, use_fp16=False, **block_kwargs, **common_kwargs) + setattr(self, f'c{res}', block) + cur_layer_idx += block.num_layers + + # final output module + self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, **epilogue_kwargs, **common_kwargs) + self.register_buffer("alpha", torch.scalar_tensor(-1)) + + def set_alpha(self, alpha): + if alpha is not None: + self.alpha = self.alpha * 0 + alpha + + def set_resolution(self, res): + self.curr_status = res + + def get_estimated_camera(self, img, **block_kwargs): + if isinstance(img, dict): + img = img['img'] + img4cam = img.clone() + if self.progressive and (img.size(-1) != self.lowres_head): + img4cam = downsample(img, self.lowres_head) + + c, xc = None, None + for res in [r for r in self.block_resolutions if r <= self.lowres_head or (not self.progressive)]: + xc, img4cam = getattr(self, f'c{res}')(xc, img4cam, **block_kwargs) + + if self.separate_camera: + c = self.projector(xc)[:,:,0,0] + if self.predict_9d_camera: + c = camera_9d_to_16d(c) + return c, xc, img4cam + + def get_camera_loss(self, RT=None, UV=None, c=None): + if UV is not None: # UV has higher priority? + return F.mse_loss(UV, c) + # lu = torch.stack([(UV[:,0] - c[:, 0]) ** 2, (UV[:,0] - c[:, 0] + 1) ** 2, (UV[:,0] - c[:, 0] - 1) ** 2], 0).min(0).values + # return torch.mean(sum(lu + (UV[:,1] - c[:, 1]) ** 2 + (UV[:,2] - c[:, 2]) ** 2)) + elif RT is not None: + return F.smooth_l1_loss(RT.reshape(RT.size(0), -1), c) * 10 + return None + + def get_block_resolutions(self, input_img): + block_resolutions = self.block_resolutions + lowres_head = self.lowres_head + alpha = self.alpha + img_res = input_img.size(-1) + if self.progressive and (self.lowres_head is not None) and (self.alpha > -1): + if (self.alpha < 1) and (self.alpha > 0): + try: + n_levels, _, before_res, target_res = self.curr_status + alpha, index = math.modf(self.alpha * n_levels) + index = int(index) + except Exception as e: # TODO: this is a hack, better to save status as buffers. + before_res = target_res = img_res + if before_res == target_res: # no upsampling was used in generator, do not increase the discriminator + alpha = 0 + block_resolutions = [res for res in self.block_resolutions if res <= target_res] + lowres_head = before_res + elif self.alpha == 0: + block_resolutions = [res for res in self.block_resolutions if res <= lowres_head] + return block_resolutions, alpha, lowres_head + + def forward(self, inputs, c=None, aug_pipe=None, return_camera=False, **block_kwargs): + if not isinstance(inputs, dict): + inputs = {'img': inputs} + img = inputs['img'] + block_resolutions, alpha, lowres_head = self.get_block_resolutions(img) + if img.size(-1) > block_resolutions[0]: + img = downsample(img, block_resolutions[0]) + + # this is to handle real images to obtain nerf-size image. + if (self.dual_discriminator or (self.dual_input_ratio is not None)) and ('img_nerf' not in inputs): + inputs['img_nerf'] = img + if self.dual_discriminator and (inputs['img_nerf'].size(-1) > self.lowres_head): # using Conv to read image. 
+ inputs['img_nerf'] = downsample(inputs['img_nerf'], self.lowres_head) + elif self.dual_input_ratio is not None: # similar to EG3d + if inputs['img_nerf'].size(-1) > (img.size(-1) // self.dual_input_ratio): + inputs['img_nerf'] = downsample(inputs['img_nerf'], img.size(-1) // self.dual_input_ratio) + img = torch.cat([img, upsample(inputs['img_nerf'], img.size(-1))], 1) + + camera_loss = None + RT = inputs['camera_matrices'][1].detach() if 'camera_matrices' in inputs else None + UV = inputs['camera_matrices'][2].detach() if 'camera_matrices' in inputs else None + + # perform separate camera predictor or dual discriminator + if self.dual_discriminator or self.separate_camera: + temp_img = img if not self.dual_discriminator else inputs['img_nerf'] + c_nerf, x_nerf, img_nerf = self.get_estimated_camera(temp_img, **block_kwargs) + if c.size(-1) == 0 and self.separate_camera: + c = c_nerf + if self.predict_3d_camera: + camera_loss = self.get_camera_loss(RT, UV, c) + + # if applied data augmentation for discriminator + if aug_pipe is not None: + assert self.separate_camera or (not self.predict_camera), "ada may break the camera predictor." + img = aug_pipe(img) + + # obtain the downsampled image for progressive growing + if self.progressive and (self.lowres_head is not None) and (self.alpha > -1) and (self.alpha < 1) and (alpha > 0): + img0 = downsample(img, img.size(-1) // 2) + + x = None if (not self.progressive) or (block_resolutions[0] == self.img_resolution) \ + else getattr(self, f'b{block_resolutions[0]}').fromrgb(img) + for res in block_resolutions: + block = getattr(self, f'b{res}') + if (lowres_head == res) and (self.alpha > -1) and (self.alpha < 1) and (alpha > 0): + if self.architecture == 'skip': + img = img * alpha + img0 * (1 - alpha) + if self.progressive: + x = x * alpha + block.fromrgb(img0) * (1 - alpha) + x, img = block(x, img, **block_kwargs) + + # predict camera based on discriminator features + if (c.size(-1) == 0) and self.predict_camera and (not self.separate_camera): + c = self.projector(x)[:,:,0,0] + if self.predict_9d_camera: + c = camera_9d_to_16d(c) + if self.predict_3d_camera: + camera_loss = self.get_camera_loss(RT, UV, c) + + # camera conditional discriminator + cmap = None + if self.c_dim > 0: + cc = c.clone().detach() + cmap = self.mapping(None, cc) + logits = self.b4(x, img, cmap) + if self.dual_discriminator: + logits = torch.cat([logits, self.b4(x_nerf, img_nerf, cmap)], 0) + + outputs = {'logits': logits} + if self.predict_camera and (camera_loss is not None): + outputs['camera_loss'] = camera_loss + if return_camera: + outputs['camera'] = c + return outputs diff --git a/models/stylenerf_generator.py b/models/stylenerf_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..34525d923d71dbf8accb3eff99b79fdb40c0ab07 --- /dev/null +++ b/models/stylenerf_generator.py @@ -0,0 +1,427 @@ +# python3.8 +"""Contains the implementation of generator described in StyleNeRF.""" + +import math +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import repeat +from einops import rearrange + +from utils import eg3d_misc as misc +from models.utils.official_stylegan2_model_helper import modulated_conv2d +from third_party.stylegan2_official_ops import upfirdn2d +from third_party.stylegan2_official_ops import bias_act +from models.utils.official_stylegan2_model_helper import FullyConnectedLayer +from models.utils.official_stylegan2_model_helper import MappingNetwork +from 
models.utils.official_stylegan2_model_helper import SynthesisBlock +from models.utils.official_stylegan2_model_helper import ToRGBLayer +from models.utils.official_stylegan2_model_helper import Conv2dLayer +from models.rendering import Renderer +from models.rendering import FeatureExtractor +from models.volumegan_generator import PositionEncoder + + +class StyleNeRFGenerator(nn.Module): + """Defines the generator network in StyleNeRF.""" + + def __init__(self, ): + super().__init__() + + # Set up mapping network. + self.mapping = MappingNetwork() ### TODO: Accomplish filling kwargs. + + # Set up overall Renderer. + self.renderer = Renderer() + + # Set up the position encoder. + self.position_encoder = PositionEncoder() ### TODO: Accomplish filling kwargs. + + # Set up the feature extractor. + self.feature_extractor = FeatureExtractor(ref_mode='none') + + # Set up the post module in the feature extractor. + self.post_module = NeRFMLPNetwork() ### TODO: Accomplish filling kwargs. + + # Set up the fully-connected layer head. + self.fc_head = FCHead() ### TODO: Accomplish filling kwargs. + + # Set up the post neural renderer. + self.post_neural_renderer = PostNeuralRendererNetwork() ### TODO: Accomplish filling kwargs. + + def forward(self,): + pass + + +class NeRFMLPNetwork(nn.Module): + """Defines class of FOREGROUND/BACKGROUND NeRF MLP Network in StyleNeRF. + + Basically, this module consists of several `Style2Layer`s where convolutions + with 1x1 kernel are involved. Note that this module is not strictly + equivalent to MLP. Since 1x1 convolution is equal to fully-connected layer, + we name this module `NeRFMLPNetwork`. Besides, our `NeRFMLPNetwork` takes in + sampled points, view directions, latent codes as input, and outputs features + for the following computation of `sigma` and `rgb`. + """ + + def __init__( + self, + # dimensions + input_dim=60, + w_dim=512, # style latent + hidden_size=128, + n_blocks=8, + # architecture settings + activation='lrelu', + use_skip=False, + nerf_kwargs={} + ): + super().__init__() + self.input_dim = input_dim + self.hidden_size = hidden_size + self.w_dim = w_dim + self.activation = activation + self.n_blocks = n_blocks + self.use_skip = use_skip + + for key in nerf_kwargs: + setattr(self, key, nerf_kwargs[key]) + + self.fc_in = Style2Layer(self.input_dim, + self.hidden_size, + self.w_dim, + activation=self.activation) + self.num_wp = 1 + self.skip_layer = self.n_blocks // 2 - 1 if self.use_skip else None + if self.n_blocks > 1: + self.blocks = nn.ModuleList([ + Style2Layer(self.hidden_size if i != self.skip_layer else + self.hidden_size + self.input_dim, + self.hidden_size, + w_dim, + activation=self.activation, + magnitude_ema_beta=self.magnitude_ema_beta) + for i in range(self.n_blocks - 1) + ]) + self.num_wp += (self.n_blocks - 1) + + def forward(self, + pre_point_features, + points_encoding, + wp=None, + use_both=False): + input_p = points_encoding + if use_both: + input_p = torch.cat([pre_point_features, input_p], 1) + out = self.fc_in(points_encoding, wp[:, 0] if wp is not None else None) + if self.n_blocks > 1: + for idx, layer in enumerate(self.blocks): + wp_i = wp[:, idx + 1] if wp is not None else None + if (self.skip_layer is not None) and (idx == self.skip_layer): + out = torch.cat([out, input_p], 1) + out = layer(out, wp_i, up=1) + return out + + +class FCHead(nn.Module): + """Defines the fully connnected layer head in StyleNeRF. 
+
+    Basically, this module is composed of several `ToRGBLayer`s and
+    `Conv2dLayer`s where all convolutions are with kernel size 1x1, in order to
+    decode the common feature of each point to the sigma (feature) and
+    rgb (feature). Note that this module is not strictly equivalent to a fully
+    connected layer. Since a 1x1 convolution is equivalent to a fully-connected
+    layer, we name this module `FCHead`.
+    """
+
+    def __init__(self,
+                 in_dim=128,
+                 w_dim=512,
+                 w_idx=8,
+                 sigma_out_dim=1,
+                 rgb_out_dim=256,
+                 img_channels=3,
+                 predict_rgb=True):
+        super().__init__()
+        self.predict_rgb = predict_rgb
+        self.w_idx = w_idx
+        self.sigma_head = ToRGBLayer(in_dim,
+                                     sigma_out_dim,
+                                     w_dim,
+                                     kernel_size=1)
+        self.rgb_head = ToRGBLayer(in_dim, rgb_out_dim, w_dim, kernel_size=1)
+        # Predict RGB over features.
+        if self.predict_rgb:
+            self.to_rgb = Conv2dLayer(rgb_out_dim,
+                                      img_channels,
+                                      kernel_size=1,
+                                      activation='linear')
+
+    def forward(self,
+                post_point_features,
+                wp=None,
+                dirs=None,
+                height=None,
+                width=None):
+        assert (height is not None) and (width is not None)
+        # TODO: Check shape. Here R = H * W rays with K samples per ray, i.e.,
+        # (N, C, R * K, 1) is reshaped to (N * K, C, H, W).
+        post_point_features = rearrange(post_point_features,
+                                        'N C (R K) 1 -> N C R K',
+                                        R=height * width)
+        post_point_features = rearrange(post_point_features,
+                                        'N C (H W) K -> (N K) C H W',
+                                        H=height,
+                                        W=width)
+
+        sigma = self.sigma_head(post_point_features, wp[:, self.w_idx])
+        rgb_feat = self.rgb_head(post_point_features, wp[:, -1])
+        rgb = self.to_rgb(post_point_features)
+        rgb_feat = torch.cat([rgb_feat, rgb], dim=1)
+
+        results = {'sigma': sigma, 'rgb': rgb_feat}
+
+        return results
+
+
+class PostNeuralRendererNetwork(nn.Module):
+    """Implements the post neural renderer network in StyleNeRF to render
+    high-resolution images.
+
+    Basically, this module comprises several `SynthesisBlock`s at different
+    resolutions, which is analogous to the StyleGAN2 architecture, and it
+    is trained progressively during training. Besides, it is called `Upsampler`
+    in the official implementation.
+    """
+
+    no_2d_renderer = False
+    block_reses = None
+    upsample_type = 'default'
+    img_channels = 3
+    in_res = 32
+    out_res = 512
+    channel_base = 1
+    channel_base_sz = None  # usually 32768, which equals 2 ** 15.
+    channel_max = 512
+    channel_dict = None
+    out_channel_dict = None
+
+    def __init__(self, upsampler_kwargs, **other_kwargs):
+        super().__init__()
+        for key in other_kwargs:
+            if hasattr(self, key) and (key not in upsampler_kwargs):
+                setattr(upsampler_kwargs, key, other_kwargs[key])
+        for key in upsampler_kwargs:
+            if hasattr(self, key):
+                setattr(self, key, upsampler_kwargs[key])
+
+        self.out_res_log2 = int(np.log2(self.out_res))
+
+        # Set up resolution of blocks.
+        if self.block_reses is None:
+            self.block_resolutions = [
+                2**i for i in range(2, self.out_res_log2 + 1)
+            ]
+            self.block_resolutions = [
+                res for res in self.block_resolutions if res > self.in_res
+            ]
+        else:
+            self.block_resolutions = self.block_reses
+
+        if self.no_2d_renderer:
+            self.block_resolutions = []
+
+    def build_network(self, w_dim, in_dim, **block_kwargs):
+        networks = []
+        if len(self.block_resolutions) == 0:
+            return networks
+
+        channel_base = int(
+            self.channel_base * 32768
+        ) if self.channel_base_sz is None else self.channel_base_sz
+
+        # Don't use fp16 for the first block.
+        fp16_resolution = self.block_resolutions[0] * 2
+
+        if self.channel_dict is None:
+            channel_dict = {
+                res: min(channel_base // res, self.channel_max)
+                for res in self.block_resolutions
+            }
+        else:
+            channel_dict = self.channel_dict
+
+        # Use the externally provided per-resolution output channels if given;
+        # otherwise fall back to `self.img_channels` for every resolution.
+        if self.out_channel_dict is not None:
+            img_channels = self.out_channel_dict
+        else:
+            img_channels = {
+                res: self.img_channels
+                for res in self.block_resolutions
+            }
+
+        for idx, res in enumerate(self.block_resolutions):
+            res_before = self.block_resolutions[idx - 1] if idx > 0 else self.in_res
+            in_channels = channel_dict[res_before] if idx > 0 else in_dim
+            out_channels = channel_dict[res]
+            use_fp16 = (res > fp16_resolution)
+            is_last = (idx == (len(self.block_resolutions) - 1))
+            block = SynthesisBlock(in_channels=in_channels,
+                                   out_channels=out_channels,
+                                   w_dim=w_dim,
+                                   resolution=res,
+                                   img_channels=img_channels[res],
+                                   is_last=is_last,
+                                   use_fp16=use_fp16,
+                                   **block_kwargs)  # TODO: Check the kwargs of `SynthesisBlock`, and add `upsample_mode` in our `SynthesisBlock`
+            networks += [
+                {'block': block,
+                 'num_wp': block.num_conv if not is_last else block.num_conv + block.num_torgb,
+                 'name': f'b{res}' if res_before != res else f'b{res}_l{idx}'}
+            ]
+        self.num_wp = sum(net['num_wp'] for net in networks)
+
+        return networks
+
+    def split_wp(self, wp, blocks):
+        block_wp = []
+        w_idx = 0
+        for idx, _ in enumerate(self.block_resolutions):
+            block = blocks[idx]
+            block_wp.append(
+                wp.narrow(1, w_idx, block.num_conv + block.num_torgb))
+            w_idx = w_idx + block.num_conv
+        return block_wp
+
+    def forward(self, blocks, block_wp, x, image, target_res):
+        images = []
+        for idx, (res,
+                  cur_wp) in enumerate(zip(self.block_resolutions, block_wp)):
+            if res > target_res:
+                break
+
+            block = blocks[idx]
+            x, image = block(x, image, cur_wp)  # TODO: Check whether to use noise here.
+
+            images.append(image)
+
+        return images
+
+
+class Style2Layer(nn.Module):
+    """Defines a simplified `SynthesisLayer` used in the NeRF block, with
+    the following modifications:
+
+    - No noise injection;
+    - Kernel size set to be 1x1.
+    """
+
+    def __init__(
+            self,
+            in_channels,
+            out_channels,
+            w_dim,
+            activation='lrelu',
+            resample_filter=[1, 3, 3, 1],
+            magnitude_ema_beta=-1,  # -1 means not using magnitude ema
+            **unused_kwargs):
+
+        super().__init__()
+        self.activation = activation
+        self.conv_clamp = None
+        self.register_buffer('resample_filter',
+                             upfirdn2d.setup_filter(resample_filter))
+        self.padding = 0
+        self.act_gain = bias_act.activation_funcs[activation].def_gain
+        self.w_dim = w_dim
+        self.in_features = in_channels
+        self.out_features = out_channels
+        memory_format = torch.contiguous_format
+
+        if w_dim > 0:
+            self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
+            self.weight = torch.nn.Parameter(
+                torch.randn([out_channels, in_channels, 1,
+                             1]).to(memory_format=memory_format))
+            self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
+
+        else:
+            self.weight = torch.nn.Parameter(
+                torch.Tensor(out_channels, in_channels))
+            self.bias = torch.nn.Parameter(torch.Tensor(out_channels))
+            self.weight_gain = 1.
+
+            # Initialization.
+ torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out( + self.weight) + bound = 1 / math.sqrt(fan_in) + torch.nn.init.uniform_(self.bias, -bound, bound) + + self.magnitude_ema_beta = magnitude_ema_beta + if magnitude_ema_beta > 0: + self.register_buffer('w_avg', torch.ones([])) + + def extra_repr(self) -> str: + return 'in_features={}, out_features={}, style={}'.format( + self.in_features, self.out_features, self.w_dim) + + def forward(self, + x, + w=None, + fused_modconv=None, + gain=1, + up=1, + **unused_kwargs): + flip_weight = True + act = self.activation + + if (self.magnitude_ema_beta > 0): + if self.training: # updating EMA. + with torch.autograd.profiler.record_function( + 'update_magnitude_ema'): + magnitude_cur = x.detach().to( + torch.float32).square().mean() + self.w_avg.copy_( + magnitude_cur.lerp(self.w_avg, + self.magnitude_ema_beta)) + input_gain = self.w_avg.rsqrt() + x = x * input_gain + + if fused_modconv is None: + with misc.suppress_tracer_warnings(): + # this value will be treated as a constant + fused_modconv = not self.training + + if self.w_dim > 0: # modulated convolution + assert x.ndim == 4, "currently not support modulated MLP" + styles = self.affine(w) # Batch x style_dim + if x.size(0) > styles.size(0): + styles = repeat(styles, + 'b c -> (b s) c', + s=x.size(0) // styles.size(0)) + + x = modulated_conv2d(x=x, + weight=self.weight, + styles=styles, + noise=None, + up=up, + padding=self.padding, + resample_filter=self.resample_filter, + flip_weight=flip_weight, + fused_modconv=fused_modconv) + act_gain = self.act_gain * gain + act_clamp = (self.conv_clamp * + gain if self.conv_clamp is not None else None) + x = bias_act.bias_act(x, + self.bias.to(x.dtype), + act=act, + gain=act_gain, + clamp=act_clamp) + + else: + if x.ndim == 2: # MLP mode + x = F.relu(F.linear(x, self.weight, self.bias.to(x.dtype))) + else: + x = F.relu( + F.conv2d(x, self.weight[:, :, None, None], self.bias)) + return x \ No newline at end of file diff --git a/models/test.py b/models/test.py new file mode 100644 index 0000000000000000000000000000000000000000..3f1e0239e223537d299a2c52c65928b6c59406da --- /dev/null +++ b/models/test.py @@ -0,0 +1,146 @@ +# python3.7 +"""Unit test for loading pre-trained models. + +Basically, this file tests whether the perceptual model (VGG16) and the +inception model (InceptionV3), which are commonly used for loss computation and +evaluation, have the expected behavior after loading pre-trained weights. 
In +particular, we compare with the models from repo + +https://github.com/NVlabs/stylegan2-ada-pytorch +""" + +import torch + +from models import build_model +from utils.misc import download_url + +__all__ = ['test_model'] + +_BATCH_SIZE = 4 +# pylint: disable=line-too-long +_PERCEPTUAL_URL = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt' +_INCEPTION_URL = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt' +# pylint: enable=line-too-long + + +def test_model(): + """Collects all model tests.""" + torch.backends.cudnn.enabled = True + torch.backends.cudnn.allow_tf32 = False + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + print('========== Start Model Test ==========') + test_perceptual() + test_inception() + print('========== Finish Model Test ==========') + + +def test_perceptual(): + """Test the perceptual model.""" + print('===== Testing Perceptual Model =====') + + print('Build test model.') + model = build_model('PerceptualModel', + use_torchvision=False, + no_top=False, + enable_lpips=True) + + print('Build reference model.') + ref_model_path, _, = download_url(_PERCEPTUAL_URL) + with open(ref_model_path, 'rb') as f: + ref_model = torch.jit.load(f).eval().cuda() + + print('Test performance: ') + for size in [224, 128, 256, 512, 1024]: + raw_img = torch.randint(0, 256, size=(_BATCH_SIZE, 3, size, size)) + raw_img_comp = torch.randint(0, 256, size=(_BATCH_SIZE, 3, size, size)) + + # The test model requires input images to have range [-1, 1]. + img = raw_img.to(torch.float32).cuda() / 127.5 - 1 + img_comp = raw_img_comp.to(torch.float32).cuda() / 127.5 - 1 + feat = model(img, resize_input=True, return_tensor='feature') + pred = model(img, resize_input=True, return_tensor='prediction') + lpips = model(img, img_comp, resize_input=False, return_tensor='lpips') + assert feat.shape == (_BATCH_SIZE, 4096) + assert pred.shape == (_BATCH_SIZE, 1000) + assert lpips.shape == (_BATCH_SIZE,) + + # The reference model requires input images to have range [0, 255]. 
+ img = raw_img.to(torch.float32).cuda() + img_comp = raw_img_comp.to(torch.float32).cuda() + ref_feat = ref_model(img, resize_images=True, return_features=True) + ref_pred = ref_model(img, resize_images=True, return_features=False) + temp = ref_model(torch.cat([img, img_comp], dim=0), + resize_images=False, return_lpips=True).chunk(2) + ref_lpips = (temp[0] - temp[1]).square().sum(dim=1, keepdim=False) + assert ref_feat.shape == (_BATCH_SIZE, 4096) + assert ref_pred.shape == (_BATCH_SIZE, 1000) + assert ref_lpips.shape == (_BATCH_SIZE,) + + print(f' Size {size}x{size}, feature (with resize):\n ' + f'mean: {(feat - ref_feat).abs().mean().item():.3e}, ' + f'max: {(feat - ref_feat).abs().max().item():.3e}, ' + f'ref_mean: {ref_feat.abs().mean().item():.3e}, ' + f'ref_max: {ref_feat.abs().max().item():.3e}.') + print(f' Size {size}x{size}, prediction (with resize):\n ' + f'mean: {(pred - ref_pred).abs().mean().item():.3e}, ' + f'max: {(pred - ref_pred).abs().max().item():.3e}, ' + f'ref_mean: {ref_pred.abs().mean().item():.3e}, ' + f'ref_max: {ref_pred.abs().max().item():.3e}.') + print(f' Size {size}x{size}, LPIPS (without resize):\n ' + f'mean: {(lpips - ref_lpips).abs().mean().item():.3e}, ' + f'max: {(lpips - ref_lpips).abs().max().item():.3e}, ' + f'ref_mean: {ref_lpips.abs().mean().item():.3e}, ' + f'ref_max: {ref_lpips.abs().max().item():.3e}.') + + +def test_inception(): + """Test the inception model.""" + print('===== Testing Inception Model =====') + + print('Build test model.') + model = build_model('InceptionModel', align_tf=True) + + print('Build reference model.') + ref_model_path, _, = download_url(_INCEPTION_URL) + with open(ref_model_path, 'rb') as f: + ref_model = torch.jit.load(f).eval().cuda() + + print('Test performance: ') + for size in [299, 128, 256, 512, 1024]: + raw_img = torch.randint(0, 256, size=(_BATCH_SIZE, 3, size, size)) + + # The test model requires input images to have range [-1, 1]. + img = raw_img.to(torch.float32).cuda() / 127.5 - 1 + feat = model(img) + pred = model(img, output_predictions=True) + pred_nb = model(img, output_predictions=True, remove_logits_bias=True) + assert feat.shape == (_BATCH_SIZE, 2048) + assert pred.shape == (_BATCH_SIZE, 1008) + assert pred_nb.shape == (_BATCH_SIZE, 1008) + + # The reference model requires input images to have range [0, 255]. 
+ img = raw_img.to(torch.float32).cuda() + ref_feat = ref_model(img, return_features=True) + ref_pred = ref_model(img) + ref_pred_nb = ref_model(img, no_output_bias=True) + assert ref_feat.shape == (_BATCH_SIZE, 2048) + assert ref_pred.shape == (_BATCH_SIZE, 1008) + assert ref_pred_nb.shape == (_BATCH_SIZE, 1008) + + print(f' Size {size}x{size}, feature:\n ' + f'mean: {(feat - ref_feat).abs().mean().item():.3e}, ' + f'max: {(feat - ref_feat).abs().max().item():.3e}, ' + f'ref_mean: {ref_feat.abs().mean().item():.3e}, ' + f'ref_max: {ref_feat.abs().max().item():.3e}.') + print(f' Size {size}x{size}, prediction:\n ' + f'mean: {(pred - ref_pred).abs().mean().item():.3e}, ' + f'max: {(pred - ref_pred).abs().max().item():.3e}, ' + f'ref_mean: {ref_pred.abs().mean().item():.3e}, ' + f'ref_max: {ref_pred.abs().max().item():.3e}.') + print(f' Size {size}x{size}, prediction (without bias):\n ' + f'mean: {(pred_nb - ref_pred_nb).abs().mean().item():.3e}, ' + f'max: {(pred_nb - ref_pred_nb).abs().max().item():.3e}, ' + f'ref_mean: {ref_pred_nb.abs().mean().item():.3e}, ' + f'ref_max: {ref_pred_nb.abs().max().item():.3e}.') diff --git a/models/utils/__init__.py b/models/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/utils/__pycache__/__init__.cpython-37.pyc b/models/utils/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6db32e564b9264df5d9aae4efe38f0f5cc5e0f14 Binary files /dev/null and b/models/utils/__pycache__/__init__.cpython-37.pyc differ diff --git a/models/utils/__pycache__/__init__.cpython-39.pyc b/models/utils/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6b2fc1478863d13dff4a18b96c83156bb039b24c Binary files /dev/null and b/models/utils/__pycache__/__init__.cpython-39.pyc differ diff --git a/models/utils/__pycache__/batchnorm.cpython-37.pyc b/models/utils/__pycache__/batchnorm.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4270a3771af2fa2ecc6cec39f3b30b033aa0df52 Binary files /dev/null and b/models/utils/__pycache__/batchnorm.cpython-37.pyc differ diff --git a/models/utils/__pycache__/batchnorm.cpython-39.pyc b/models/utils/__pycache__/batchnorm.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ba3583316aa0c104696cb2c866fc805ca529a524 Binary files /dev/null and b/models/utils/__pycache__/batchnorm.cpython-39.pyc differ diff --git a/models/utils/__pycache__/blurpool.cpython-37.pyc b/models/utils/__pycache__/blurpool.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..034a779f5d47b4738708b240da38051f73c9053b Binary files /dev/null and b/models/utils/__pycache__/blurpool.cpython-37.pyc differ diff --git a/models/utils/__pycache__/blurpool.cpython-39.pyc b/models/utils/__pycache__/blurpool.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6c96da958f01c8bb578e1d78924a52a34b3eb6f8 Binary files /dev/null and b/models/utils/__pycache__/blurpool.cpython-39.pyc differ diff --git a/models/utils/__pycache__/comm.cpython-37.pyc b/models/utils/__pycache__/comm.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..272c23bc0c682e59c3c2efeec8e03fbe132e7c7b Binary files /dev/null and b/models/utils/__pycache__/comm.cpython-37.pyc differ diff --git a/models/utils/__pycache__/comm.cpython-39.pyc 
b/models/utils/__pycache__/comm.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0fb9ce7dfbede257189a9e79a8651fac1ee81ace Binary files /dev/null and b/models/utils/__pycache__/comm.cpython-39.pyc differ diff --git a/models/utils/__pycache__/eg3d_superres.cpython-37.pyc b/models/utils/__pycache__/eg3d_superres.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cc1f68071e12af6ff7a7be79f840d1d19a66acc8 Binary files /dev/null and b/models/utils/__pycache__/eg3d_superres.cpython-37.pyc differ diff --git a/models/utils/__pycache__/eg3d_superres.cpython-39.pyc b/models/utils/__pycache__/eg3d_superres.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b6043677dd0a99bd4aff175a3bae8e119e82fff3 Binary files /dev/null and b/models/utils/__pycache__/eg3d_superres.cpython-39.pyc differ diff --git a/models/utils/__pycache__/official_stylegan2_model_helper.cpython-37.pyc b/models/utils/__pycache__/official_stylegan2_model_helper.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a9145875b5e49d974d3af0a74465ea843c04a28f Binary files /dev/null and b/models/utils/__pycache__/official_stylegan2_model_helper.cpython-37.pyc differ diff --git a/models/utils/__pycache__/official_stylegan2_model_helper.cpython-39.pyc b/models/utils/__pycache__/official_stylegan2_model_helper.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e54308956fc1e1330c4971430048440afb79872 Binary files /dev/null and b/models/utils/__pycache__/official_stylegan2_model_helper.cpython-39.pyc differ diff --git a/models/utils/__pycache__/official_stylegan3_model_helper.cpython-37.pyc b/models/utils/__pycache__/official_stylegan3_model_helper.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..794af7ee4e7c251ee4a1235d85a52272db0dec13 Binary files /dev/null and b/models/utils/__pycache__/official_stylegan3_model_helper.cpython-37.pyc differ diff --git a/models/utils/__pycache__/official_stylegan3_model_helper.cpython-39.pyc b/models/utils/__pycache__/official_stylegan3_model_helper.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f63115a80a1bc82c612729ee2006972878a2acf Binary files /dev/null and b/models/utils/__pycache__/official_stylegan3_model_helper.cpython-39.pyc differ diff --git a/models/utils/__pycache__/ops.cpython-37.pyc b/models/utils/__pycache__/ops.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..de88ccbb462c0efeb7b7d4b0b98218a9b8ed6ed7 Binary files /dev/null and b/models/utils/__pycache__/ops.cpython-37.pyc differ diff --git a/models/utils/__pycache__/ops.cpython-39.pyc b/models/utils/__pycache__/ops.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..024a9e6b5399a2bd28ecc1e6780a1b177f123821 Binary files /dev/null and b/models/utils/__pycache__/ops.cpython-39.pyc differ diff --git a/models/utils/__pycache__/replicate.cpython-37.pyc b/models/utils/__pycache__/replicate.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..273aaff3440ce51d3efd6a0a0eda303c75ff6a29 Binary files /dev/null and b/models/utils/__pycache__/replicate.cpython-37.pyc differ diff --git a/models/utils/__pycache__/replicate.cpython-39.pyc b/models/utils/__pycache__/replicate.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..14953f76bb70520e7101e58572966a9e40fe995d Binary files /dev/null and 
b/models/utils/__pycache__/replicate.cpython-39.pyc differ diff --git a/models/utils/__pycache__/sg3.cpython-37.pyc b/models/utils/__pycache__/sg3.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0d3c893da52ba258627c726c944543766eb2ba9a Binary files /dev/null and b/models/utils/__pycache__/sg3.cpython-37.pyc differ diff --git a/models/utils/__pycache__/sg3_nohm.cpython-37.pyc b/models/utils/__pycache__/sg3_nohm.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..182d0e20cba39e09b6d2c672354185b8301fae55 Binary files /dev/null and b/models/utils/__pycache__/sg3_nohm.cpython-37.pyc differ diff --git a/models/utils/__pycache__/sg3_sharebevencoder.cpython-37.pyc b/models/utils/__pycache__/sg3_sharebevencoder.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a12373a8b45d44efaf477fb5cdc65212508455f Binary files /dev/null and b/models/utils/__pycache__/sg3_sharebevencoder.cpython-37.pyc differ diff --git a/models/utils/__pycache__/spade.cpython-37.pyc b/models/utils/__pycache__/spade.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..63aca154456a75433d01e77186403bf62f1d79ac Binary files /dev/null and b/models/utils/__pycache__/spade.cpython-37.pyc differ diff --git a/models/utils/__pycache__/spade.cpython-39.pyc b/models/utils/__pycache__/spade.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..78fa02638c153afdb95706441dfb9c33fa5df122 Binary files /dev/null and b/models/utils/__pycache__/spade.cpython-39.pyc differ diff --git a/models/utils/__pycache__/unet.cpython-37.pyc b/models/utils/__pycache__/unet.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1284cb40544408f40050a2d3d99518d280fbafc6 Binary files /dev/null and b/models/utils/__pycache__/unet.cpython-37.pyc differ diff --git a/models/utils/__pycache__/unet.cpython-39.pyc b/models/utils/__pycache__/unet.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1db5ca2ae27c7353853c46f0a3cfb0d032a217f4 Binary files /dev/null and b/models/utils/__pycache__/unet.cpython-39.pyc differ diff --git a/models/utils/batchnorm.py b/models/utils/batchnorm.py new file mode 100644 index 0000000000000000000000000000000000000000..bf8d7a7325b474771a11a137053971fd40426079 --- /dev/null +++ b/models/utils/batchnorm.py @@ -0,0 +1,412 @@ +# -*- coding: utf-8 -*- +# File : batchnorm.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. 
+ +import collections +import contextlib + +import torch +import torch.nn.functional as F + +from torch.nn.modules.batchnorm import _BatchNorm + +try: + from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast +except ImportError: + ReduceAddCoalesced = Broadcast = None + +try: + from jactorch.parallel.comm import SyncMaster + from jactorch.parallel.data_parallel import JacDataParallel as DataParallelWithCallback +except ImportError: + from .comm import SyncMaster + from .replicate import DataParallelWithCallback + +__all__ = [ + 'set_sbn_eps_mode', + 'SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d', + 'patch_sync_batchnorm', 'convert_model' +] + + +SBN_EPS_MODE = 'clamp' + + +def set_sbn_eps_mode(mode): + global SBN_EPS_MODE + assert mode in ('clamp', 'plus') + SBN_EPS_MODE = mode + + +def _sum_ft(tensor): + """sum over the first and last dimention""" + return tensor.sum(dim=0).sum(dim=-1) + + +def _unsqueeze_ft(tensor): + """add new dimensions at the front and the tail""" + return tensor.unsqueeze(0).unsqueeze(-1) + + +_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size']) +_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std']) + + +class _SynchronizedBatchNorm(_BatchNorm): + def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True): + assert ReduceAddCoalesced is not None, 'Can not use Synchronized Batch Normalization without CUDA support.' + + super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine, + track_running_stats=track_running_stats) + + if not self.track_running_stats: + import warnings + warnings.warn('track_running_stats=False is not supported by the SynchronizedBatchNorm.') + + self._sync_master = SyncMaster(self._data_parallel_master) + + self._is_parallel = False + self._parallel_id = None + self._slave_pipe = None + + def forward(self, input): + # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation. + if not (self._is_parallel and self.training): + return F.batch_norm( + input, self.running_mean, self.running_var, self.weight, self.bias, + self.training, self.momentum, self.eps) + + # Resize the input to (B, C, -1). + input_shape = input.size() + assert input.size(1) == self.num_features, 'Channel size mismatch: got {}, expect {}.'.format(input.size(1), self.num_features) + input = input.view(input.size(0), self.num_features, -1) + + # Compute the sum and square-sum. + sum_size = input.size(0) * input.size(2) + input_sum = _sum_ft(input) + input_ssum = _sum_ft(input ** 2) + + # Reduce-and-broadcast the statistics. + if self._parallel_id == 0: + mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) + else: + mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) + + # Compute the output. + if self.affine: + # MJY:: Fuse the multiplication for speed. + output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias) + else: + output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) + + # Reshape it. + return output.view(input_shape) + + def __data_parallel_replicate__(self, ctx, copy_id): + self._is_parallel = True + self._parallel_id = copy_id + + # parallel_id == 0 means master device. 
+ if self._parallel_id == 0: + ctx.sync_master = self._sync_master + else: + self._slave_pipe = ctx.sync_master.register_slave(copy_id) + + def _data_parallel_master(self, intermediates): + """Reduce the sum and square-sum, compute the statistics, and broadcast it.""" + + # Always using same "device order" makes the ReduceAdd operation faster. + # Thanks to:: Tete Xiao (http://tetexiao.com/) + intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) + + to_reduce = [i[1][:2] for i in intermediates] + to_reduce = [j for i in to_reduce for j in i] # flatten + target_gpus = [i[1].sum.get_device() for i in intermediates] + + sum_size = sum([i[1].sum_size for i in intermediates]) + sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) + mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) + + broadcasted = Broadcast.apply(target_gpus, mean, inv_std) + + outputs = [] + for i, rec in enumerate(intermediates): + outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2]))) + + return outputs + + def _compute_mean_std(self, sum_, ssum, size): + """Compute the mean and standard-deviation with sum and square-sum. This method + also maintains the moving average on the master device.""" + assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.' + mean = sum_ / size + sumvar = ssum - sum_ * mean + unbias_var = sumvar / (size - 1) + bias_var = sumvar / size + + if hasattr(torch, 'no_grad'): + with torch.no_grad(): + self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data + self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data + else: + self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data + self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data + + if SBN_EPS_MODE == 'clamp': + return mean, bias_var.clamp(self.eps) ** -0.5 + elif SBN_EPS_MODE == 'plus': + return mean, (bias_var + self.eps) ** -0.5 + else: + raise ValueError('Unknown EPS mode: {}.'.format(SBN_EPS_MODE)) + + +class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): + r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a + mini-batch. + + .. math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm1d as the mean and + standard-deviation are reduced across all devices during training. + + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. 
+ + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm + + Args: + num_features: num_features from an expected input of size + `batch_size x num_features [x width]` + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. Default: ``True`` + + Shape:: + - Input: :math:`(N, C)` or :math:`(N, C, L)` + - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm1d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm1d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 2 and input.dim() != 3: + raise ValueError('expected 2D or 3D input (got {}D input)' + .format(input.dim())) + + +class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): + r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch + of 3d inputs + + .. math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm2d as the mean and + standard-deviation are reduced across all devices during training. + + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm + + Args: + num_features: num_features from an expected input of + size batch_size x num_features x height x width + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. 
Default: ``True`` + + Shape:: + - Input: :math:`(N, C, H, W)` + - Output: :math:`(N, C, H, W)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm2d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm2d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 4: + raise ValueError('expected 4D input (got {}D input)' + .format(input.dim())) + + +class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): + r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch + of 4d inputs + + .. math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm3d as the mean and + standard-deviation are reduced across all devices during training. + + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm + or Spatio-temporal BatchNorm + + Args: + num_features: num_features from an expected input of + size batch_size x num_features x depth x height x width + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. 
Default: ``True`` + + Shape:: + - Input: :math:`(N, C, D, H, W)` + - Output: :math:`(N, C, D, H, W)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm3d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm3d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 5: + raise ValueError('expected 5D input (got {}D input)' + .format(input.dim())) + + +@contextlib.contextmanager +def patch_sync_batchnorm(): + import torch.nn as nn + + backup = nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d + + nn.BatchNorm1d = SynchronizedBatchNorm1d + nn.BatchNorm2d = SynchronizedBatchNorm2d + nn.BatchNorm3d = SynchronizedBatchNorm3d + + yield + + nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d = backup + + +def convert_model(module): + """Traverse the input module and its child recursively + and replace all instance of torch.nn.modules.batchnorm.BatchNorm*N*d + to SynchronizedBatchNorm*N*d + + Args: + module: the input module needs to be convert to SyncBN model + + Examples: + >>> import torch.nn as nn + >>> import torchvision + >>> # m is a standard pytorch model + >>> m = torchvision.models.resnet18(True) + >>> m = nn.DataParallel(m) + >>> # after convert, m is using SyncBN + >>> m = convert_model(m) + """ + if isinstance(module, torch.nn.DataParallel): + mod = module.module + mod = convert_model(mod) + mod = DataParallelWithCallback(mod, device_ids=module.device_ids) + return mod + + mod = module + for pth_module, sync_module in zip([torch.nn.modules.batchnorm.BatchNorm1d, + torch.nn.modules.batchnorm.BatchNorm2d, + torch.nn.modules.batchnorm.BatchNorm3d], + [SynchronizedBatchNorm1d, + SynchronizedBatchNorm2d, + SynchronizedBatchNorm3d]): + if isinstance(module, pth_module): + mod = sync_module(module.num_features, module.eps, module.momentum, module.affine) + mod.running_mean = module.running_mean + mod.running_var = module.running_var + if module.affine: + mod.weight.data = module.weight.data.clone().detach() + mod.bias.data = module.bias.data.clone().detach() + + for name, child in module.named_children(): + mod.add_module(name, convert_model(child)) + + return mod diff --git a/models/utils/blurpool.py b/models/utils/blurpool.py new file mode 100644 index 0000000000000000000000000000000000000000..9fe39819a3d36a7416da3f7421bc4142e433f2e3 --- /dev/null +++ b/models/utils/blurpool.py @@ -0,0 +1,117 @@ +# Copyright (c) 2019, Adobe Inc. All rights reserved. +# +# This work is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike +# 4.0 International Public License. To view a copy of this license, visit +# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode. +# code borrowed from: https://github.com/adobe/antialiased-cnns/blob/master/antialiased_cnns/blurpool.py + +import torch +import torch.nn.parallel +import numpy as np +import torch.nn as nn +import torch.nn.functional as F + +class BlurPool(nn.Module): + def __init__(self, channels, pad_type='reflect', filt_size=4, stride=2, pad_off=0): + super(BlurPool, self).__init__() + self.filt_size = filt_size + self.pad_off = pad_off + self.pad_sizes = [int(1.*(filt_size-1)/2), int(np.ceil(1.*(filt_size-1)/2)), int(1.*(filt_size-1)/2), int(np.ceil(1.*(filt_size-1)/2))] + self.pad_sizes = [pad_size+pad_off for pad_size in self.pad_sizes] + self.stride = stride + self.off = int((self.stride-1)/2.) 
+ self.channels = channels + + if(self.filt_size==1): + a = np.array([1.,]) + elif(self.filt_size==2): + a = np.array([1., 1.]) + elif(self.filt_size==3): + a = np.array([1., 2., 1.]) + elif(self.filt_size==4): + a = np.array([1., 3., 3., 1.]) + elif(self.filt_size==5): + a = np.array([1., 4., 6., 4., 1.]) + elif(self.filt_size==6): + a = np.array([1., 5., 10., 10., 5., 1.]) + elif(self.filt_size==7): + a = np.array([1., 6., 15., 20., 15., 6., 1.]) + + filt = torch.Tensor(a[:,None]*a[None,:]) + filt = filt/torch.sum(filt) + self.register_buffer('filt', filt[None,None,:,:].repeat((self.channels,1,1,1))) + + self.pad = get_pad_layer(pad_type)(self.pad_sizes) + + def forward(self, inp): + if(self.filt_size==1): + if(self.pad_off==0): + return inp[:,:,::self.stride,::self.stride] + else: + return self.pad(inp)[:,:,::self.stride,::self.stride] + else: + return F.conv2d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1]) + +def get_pad_layer(pad_type): + if(pad_type in ['refl','reflect']): + PadLayer = nn.ReflectionPad2d + elif(pad_type in ['repl','replicate']): + PadLayer = nn.ReplicationPad2d + elif(pad_type=='zero'): + PadLayer = nn.ZeroPad2d + else: + print('Pad type [%s] not recognized'%pad_type) + return PadLayer + +class BlurPool1D(nn.Module): + def __init__(self, channels, pad_type='reflect', filt_size=3, stride=2, pad_off=0): + super(BlurPool1D, self).__init__() + self.filt_size = filt_size + self.pad_off = pad_off + self.pad_sizes = [int(1. * (filt_size - 1) / 2), int(np.ceil(1. * (filt_size - 1) / 2))] + self.pad_sizes = [pad_size + pad_off for pad_size in self.pad_sizes] + self.stride = stride + self.off = int((self.stride - 1) / 2.) + self.channels = channels + + # print('Filter size [%i]' % filt_size) + if(self.filt_size == 1): + a = np.array([1., ]) + elif(self.filt_size == 2): + a = np.array([1., 1.]) + elif(self.filt_size == 3): + a = np.array([1., 2., 1.]) + elif(self.filt_size == 4): + a = np.array([1., 3., 3., 1.]) + elif(self.filt_size == 5): + a = np.array([1., 4., 6., 4., 1.]) + elif(self.filt_size == 6): + a = np.array([1., 5., 10., 10., 5., 1.]) + elif(self.filt_size == 7): + a = np.array([1., 6., 15., 20., 15., 6., 1.]) + + filt = torch.Tensor(a) + filt = filt / torch.sum(filt) + self.register_buffer('filt', filt[None, None, :].repeat((self.channels, 1, 1))) + + self.pad = get_pad_layer_1d(pad_type)(self.pad_sizes) + + def forward(self, inp): + if(self.filt_size == 1): + if(self.pad_off == 0): + return inp[:, :, ::self.stride] + else: + return self.pad(inp)[:, :, ::self.stride] + else: + return F.conv1d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1]) + +def get_pad_layer_1d(pad_type): + if(pad_type in ['refl', 'reflect']): + PadLayer = nn.ReflectionPad1d + elif(pad_type in ['repl', 'replicate']): + PadLayer = nn.ReplicationPad1d + elif(pad_type == 'zero'): + PadLayer = nn.ZeroPad1d + else: + print('Pad type [%s] not recognized' % pad_type) + return PadLayer diff --git a/models/utils/camera.py b/models/utils/camera.py new file mode 100644 index 0000000000000000000000000000000000000000..ec0397b2106b6dec6eb8c9498ed3ab4ff50bf957 --- /dev/null +++ b/models/utils/camera.py @@ -0,0 +1,39 @@ +# python3.8 +"""Contains some functions related to cameras, including rotation +transformation, projection, etc.""" + +import torch +import torch.nn.functional as F + +def camera_9d_to_16d(d9): + d6, translation = d9[..., :6], d9[..., 6:] + rotation = rotation_6d_to_matrix(d6) + RT = torch.eye(4).to(device=d9.device, dtype=d9.dtype).reshape( + 1, 4, 
4).repeat(d6.size(0), 1, 1) + RT[:, :3, :3] = rotation + RT[:, :3, -1] = translation + return RT.reshape(-1, 16) + + +def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor: + """ + Converts 6D rotation representation by Zhou et al. [1] to rotation matrix + using Gram--Schmidt orthogonalization per Section B of [1]. + Args: + d6: 6D rotation representation, of size (*, 6) + + Returns: + batch of rotation matrices of size (*, 3, 3) + + [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. + On the Continuity of Rotation Representations in Neural Networks. + IEEE Conference on Computer Vision and Pattern Recognition, 2019. + Retrieved from http://arxiv.org/abs/1812.07035. + """ + + a1, a2 = d6[..., :3], d6[..., 3:] + b1 = F.normalize(a1, dim=-1) + b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1 + b2 = F.normalize(b2, dim=-1) + b3 = torch.cross(b1, b2, dim=-1) + return torch.stack((b1, b2, b3), dim=-2) \ No newline at end of file diff --git a/models/utils/comm.py b/models/utils/comm.py new file mode 100644 index 0000000000000000000000000000000000000000..ac95c651fb2c5dc3ceef1a89bc2df1a416804d27 --- /dev/null +++ b/models/utils/comm.py @@ -0,0 +1,134 @@ +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import queue +import collections +import threading + +__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] + + +class FutureResult(object): + """A thread-safe future implementation. Used only as one-to-one pipe.""" + + def __init__(self): + self._result = None + self._lock = threading.Lock() + self._cond = threading.Condition(self._lock) + + def put(self, result): + with self._lock: + assert self._result is None, 'Previous result has\'t been fetched.' + self._result = result + self._cond.notify() + + def get(self): + with self._lock: + if self._result is None: + self._cond.wait() + + res = self._result + self._result = None + return res + + +_MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) +_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) + + +class SlavePipe(_SlavePipeBase): + """Pipe for master-slave communication.""" + + def run_slave(self, msg): + self.queue.put((self.identifier, msg)) + ret = self.result.get() + self.queue.put(True) + return ret + + +class SyncMaster(object): + """An abstract `SyncMaster` object. + + - During the replication, as the data parallel will trigger an callback of each module, all slave devices should + call `register(id)` and obtain an `SlavePipe` to communicate with the master. + - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, + and passed to a registered callback. + - After receiving the messages, the master device should gather the information and determine to message passed + back to each slave devices. + """ + + def __init__(self, master_callback): + """ + + Args: + master_callback: a callback to be invoked after having collected messages from slave devices. + """ + self._master_callback = master_callback + self._queue = queue.Queue() + self._registry = collections.OrderedDict() + self._activated = False + + def __getstate__(self): + return {'master_callback': self._master_callback} + + def __setstate__(self, state): + self.__init__(state['master_callback']) + + def register_slave(self, identifier): + """ + Register an slave device. 
+ + Args: + identifier: an identifier, usually is the device id. + + Returns: a `SlavePipe` object which can be used to communicate with the master device. + + """ + if self._activated: + assert self._queue.empty(), 'Queue is not clean before next initialization.' + self._activated = False + self._registry.clear() + future = FutureResult() + self._registry[identifier] = _MasterRegistry(future) + return SlavePipe(identifier, self._queue, future) + + def run_master(self, master_msg): + """ + Main entry for the master device in each forward pass. + The messages were first collected from each devices (including the master device), and then + an callback will be invoked to compute the message to be sent back to each devices + (including the master device). + + Args: + master_msg: the message that the master want to send to itself. This will be placed as the first + message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. + + Returns: the message to be sent back to the master device. + + """ + self._activated = True + + intermediates = [(0, master_msg)] + for i in range(self.nr_slaves): + intermediates.append(self._queue.get()) + + results = self._master_callback(intermediates) + assert results[0][0] == 0, 'The first result should belongs to the master.' + + for i, res in results: + if i == 0: + continue + self._registry[i].result.put(res) + + for i in range(self.nr_slaves): + assert self._queue.get() is True + + return results[0][1] + + @property + def nr_slaves(self): + return len(self._registry) diff --git a/models/utils/eg3d_superres.py b/models/utils/eg3d_superres.py new file mode 100644 index 0000000000000000000000000000000000000000..aa549a7b43dfab3feefdc24bb53c0c23b63c0710 --- /dev/null +++ b/models/utils/eg3d_superres.py @@ -0,0 +1,288 @@ +# python 3.7 +"""Contains Super-Resolution Module described in EG3D.""" + +import numpy as np +import torch +from .official_stylegan2_model_helper import Conv2dLayer +from .official_stylegan2_model_helper import ToRGBLayer +from .official_stylegan2_model_helper import SynthesisLayer +from .official_stylegan2_model_helper import SynthesisBlock +from third_party.stylegan2_official_ops import upfirdn2d +from utils import eg3d_misc as misc + +#---------------------------------------------------------------------------- + +# for 512x512 generation +class SuperresolutionHybrid8X(torch.nn.Module): + def __init__(self, channels, img_resolution, sr_num_fp16_res, sr_antialias, + num_fp16_res=4, conv_clamp=None, channel_base=None, channel_max=None,# IGNORE + **block_kwargs): + super().__init__() + assert img_resolution == 512 + + use_fp16 = sr_num_fp16_res > 0 + self.input_resolution = 128 + self.sr_antialias = sr_antialias + self.block0 = SynthesisBlock(channels, 128, w_dim=512, resolution=256, + img_channels=3, is_last=False, use_fp16=use_fp16, conv_clamp=(256 if use_fp16 else None), **block_kwargs) + self.block1 = SynthesisBlock(128, 64, w_dim=512, resolution=512, + img_channels=3, is_last=True, use_fp16=use_fp16, conv_clamp=(256 if use_fp16 else None), **block_kwargs) + self.register_buffer('resample_filter', upfirdn2d.setup_filter([1,3,3,1])) + + def forward(self, rgb, x, ws, **block_kwargs): + ws = ws[:, -1:, :].repeat(1, 3, 1) + + if x.shape[-1] != self.input_resolution: + x = torch.nn.functional.interpolate(x, size=(self.input_resolution, self.input_resolution), + mode='bilinear', align_corners=False) + rgb = torch.nn.functional.interpolate(rgb, size=(self.input_resolution, self.input_resolution), + 
mode='bilinear', align_corners=False) + + x, rgb = self.block0(x, rgb, ws, **block_kwargs) + x, rgb = self.block1(x, rgb, ws, **block_kwargs) + return rgb + +#---------------------------------------------------------------------------- + +# for 256x256 generation +class SuperresolutionHybrid4X(torch.nn.Module): + def __init__(self, channels, img_resolution, sr_num_fp16_res, sr_antialias, + num_fp16_res=4, conv_clamp=None, channel_base=None, channel_max=None,# IGNORE + **block_kwargs): + super().__init__() + assert img_resolution == 256 + use_fp16 = sr_num_fp16_res > 0 + self.sr_antialias = sr_antialias + self.input_resolution = 128 + self.block0 = SynthesisBlockNoUp(channels, 128, w_dim=512, resolution=128, + img_channels=3, is_last=False, use_fp16=use_fp16, conv_clamp=(256 if use_fp16 else None), **block_kwargs) + self.block1 = SynthesisBlock(128, 64, w_dim=512, resolution=256, + img_channels=3, is_last=True, use_fp16=use_fp16, conv_clamp=(256 if use_fp16 else None), **block_kwargs) + self.register_buffer('resample_filter', upfirdn2d.setup_filter([1,3,3,1])) + + def forward(self, rgb, x, ws, **block_kwargs): + ws = ws[:, -1:, :].repeat(1, 3, 1) + + if x.shape[-1] < self.input_resolution: + x = torch.nn.functional.interpolate(x, size=(self.input_resolution, self.input_resolution), + mode='bilinear', align_corners=False) + rgb = torch.nn.functional.interpolate(rgb, size=(self.input_resolution, self.input_resolution), + mode='bilinear', align_corners=False) + + x, rgb = self.block0(x, rgb, ws, **block_kwargs) + x, rgb = self.block1(x, rgb, ws, **block_kwargs) + return rgb + +class SuperresolutionHybrid4X_conststyle(SuperresolutionHybrid4X): + def forward(self, rgb, x, **block_kwargs): + ws = torch.ones([x.shape[0], 3, 512]).float().to(x.device) + if x.shape[-1] < self.input_resolution: + x = torch.nn.functional.interpolate(x, size=(self.input_resolution, self.input_resolution), + mode='bilinear', align_corners=False) + rgb = torch.nn.functional.interpolate(rgb, size=(self.input_resolution, self.input_resolution), + mode='bilinear', align_corners=False) + + x, rgb = self.block0(x, rgb, ws, **block_kwargs) + x, rgb = self.block1(x, rgb, ws, **block_kwargs) + return rgb + +#---------------------------------------------------------------------------- + +# for 128 x 128 generation +class SuperresolutionHybrid2X(torch.nn.Module): + def __init__(self, channels, img_resolution, sr_num_fp16_res, sr_antialias, + num_fp16_res=4, conv_clamp=None, channel_base=None, channel_max=None,# IGNORE + **block_kwargs): + super().__init__() + assert img_resolution == 128 + + use_fp16 = sr_num_fp16_res > 0 + self.input_resolution = 64 + self.sr_antialias = sr_antialias + self.block0 = SynthesisBlockNoUp(channels, 128, w_dim=512, resolution=64, + img_channels=3, is_last=False, use_fp16=use_fp16, conv_clamp=(256 if use_fp16 else None), **block_kwargs) + self.block1 = SynthesisBlock(128, 64, w_dim=512, resolution=128, + img_channels=3, is_last=True, use_fp16=use_fp16, conv_clamp=(256 if use_fp16 else None), **block_kwargs) + self.register_buffer('resample_filter', upfirdn2d.setup_filter([1,3,3,1])) + + def forward(self, rgb, x, ws, **block_kwargs): + ws = ws[:, -1:, :].repeat(1, 3, 1) + + if x.shape[-1] != self.input_resolution: + x = torch.nn.functional.interpolate(x, size=(self.input_resolution, self.input_resolution), + mode='bilinear', align_corners=False) + rgb = torch.nn.functional.interpolate(rgb, size=(self.input_resolution, self.input_resolution), + mode='bilinear', align_corners=False) + + x, rgb = 
self.block0(x, rgb, ws, **block_kwargs) + x, rgb = self.block1(x, rgb, ws, **block_kwargs) + return rgb + +#---------------------------------------------------------------------------- + +# TODO: Delete (here for backwards compatibility with old 256x256 models) +class SuperresolutionHybridDeepfp32(torch.nn.Module): + def __init__(self, channels, img_resolution, sr_num_fp16_res, + num_fp16_res=4, conv_clamp=None, channel_base=None, channel_max=None,# IGNORE + **block_kwargs): + super().__init__() + assert img_resolution == 256 + use_fp16 = sr_num_fp16_res > 0 + + self.input_resolution = 128 + self.block0 = SynthesisBlockNoUp(channels, 128, w_dim=512, resolution=128, + img_channels=3, is_last=False, use_fp16=use_fp16, conv_clamp=(256 if use_fp16 else None), **block_kwargs) + self.block1 = SynthesisBlock(128, 64, w_dim=512, resolution=256, + img_channels=3, is_last=True, use_fp16=use_fp16, conv_clamp=(256 if use_fp16 else None), **block_kwargs) + self.register_buffer('resample_filter', upfirdn2d.setup_filter([1,3,3,1])) + + def forward(self, rgb, x, ws, **block_kwargs): + ws = ws[:, -1:, :].repeat(1, 3, 1) + + if x.shape[-1] < self.input_resolution: + x = torch.nn.functional.interpolate(x, size=(self.input_resolution, self.input_resolution), + mode='bilinear', align_corners=False) + rgb = torch.nn.functional.interpolate(rgb, size=(self.input_resolution, self.input_resolution), + mode='bilinear', align_corners=False) + + x, rgb = self.block0(x, rgb, ws, **block_kwargs) + x, rgb = self.block1(x, rgb, ws, **block_kwargs) + return rgb + +#---------------------------------------------------------------------------- + +class SynthesisBlockNoUp(torch.nn.Module): + def __init__(self, + in_channels, # Number of input channels, 0 = first block. + out_channels, # Number of output channels. + w_dim, # Intermediate latent (W) dimensionality. + resolution, # Resolution of this block. + img_channels, # Number of output color channels. + is_last, # Is this the last block? + architecture = 'skip', # Architecture: 'orig', 'skip', 'resnet'. + resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations. + conv_clamp = 256, # Clamp the output of convolution layers to +-X, None = disable clamping. + use_fp16 = False, # Use FP16 for this block? + fp16_channels_last = False, # Use channels-last memory format with FP16? + fused_modconv_default = True, # Default value of fused_modconv. 'inference_only' = True for inference, False for training. + **layer_kwargs, # Arguments for SynthesisLayer. 
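+        # NOTE: this block mirrors the imported StyleGAN2 SynthesisBlock but does
+        # not upsample, so `resolution` is both its input and output size; several
+        # of the super-resolution modules above use it as their resolution-
+        # preserving first block. The up=2 in the 'resnet' skip path below looks
+        # like a leftover from the upsampling variant and is not built under the
+        # default 'skip' architecture.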
+ ): + assert architecture in ['orig', 'skip', 'resnet'] + super().__init__() + self.in_channels = in_channels + self.w_dim = w_dim + self.resolution = resolution + self.img_channels = img_channels + self.is_last = is_last + self.architecture = architecture + self.use_fp16 = use_fp16 + self.channels_last = (use_fp16 and fp16_channels_last) + self.fused_modconv_default = fused_modconv_default + self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter)) + self.num_conv = 0 + self.num_torgb = 0 + + if in_channels == 0: + self.const = torch.nn.Parameter(torch.randn([out_channels, resolution, resolution])) + + if in_channels != 0: + self.conv0 = SynthesisLayer(in_channels, out_channels, w_dim=w_dim, resolution=resolution, + conv_clamp=conv_clamp, channels_last=self.channels_last, **layer_kwargs) + self.num_conv += 1 + + self.conv1 = SynthesisLayer(out_channels, out_channels, w_dim=w_dim, resolution=resolution, + conv_clamp=conv_clamp, channels_last=self.channels_last, **layer_kwargs) + self.num_conv += 1 + + if is_last or architecture == 'skip': + self.torgb = ToRGBLayer(out_channels, img_channels, w_dim=w_dim, + conv_clamp=conv_clamp, channels_last=self.channels_last) + self.num_torgb += 1 + + if in_channels != 0 and architecture == 'resnet': + self.skip = Conv2dLayer(in_channels, out_channels, kernel_size=1, bias=False, up=2, + resample_filter=resample_filter, channels_last=self.channels_last) + + def forward(self, x, img, ws, force_fp32=False, fused_modconv=None, update_emas=False, **layer_kwargs): + _ = update_emas # unused + misc.assert_shape(ws, [None, self.num_conv + self.num_torgb, self.w_dim]) + w_iter = iter(ws.unbind(dim=1)) + if ws.device.type != 'cuda': + force_fp32 = True + dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32 + memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format + if fused_modconv is None: + fused_modconv = self.fused_modconv_default + if fused_modconv == 'inference_only': + fused_modconv = (not self.training) + + # Input. + if self.in_channels == 0: + x = self.const.to(dtype=dtype, memory_format=memory_format) + x = x.unsqueeze(0).repeat([ws.shape[0], 1, 1, 1]) + else: + misc.assert_shape(x, [None, self.in_channels, self.resolution, self.resolution]) + x = x.to(dtype=dtype, memory_format=memory_format) + + # Main layers. + if self.in_channels == 0: + x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs) + elif self.architecture == 'resnet': + y = self.skip(x, gain=np.sqrt(0.5)) + x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs) + x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, gain=np.sqrt(0.5), **layer_kwargs) + x = y.add_(x) + else: + x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs) + x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs) + + # ToRGB. 
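+        # NOTE: unlike the upsampling SynthesisBlock, the running RGB image is
+        # not upsampled here because this block keeps the spatial resolution
+        # fixed; the original upsample2d call is retained below, commented out,
+        # for reference.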
+ # if img is not None: + # misc.assert_shape(img, [None, self.img_channels, self.resolution // 2, self.resolution // 2]) + # img = upfirdn2d.upsample2d(img, self.resample_filter) + if self.is_last or self.architecture == 'skip': + y = self.torgb(x, next(w_iter), fused_modconv=fused_modconv) + y = y.to(dtype=torch.float32, memory_format=torch.contiguous_format) + img = img.add_(y) if img is not None else y + + assert x.dtype == dtype + assert img is None or img.dtype == torch.float32 + return x, img + + def extra_repr(self): + return f'resolution={self.resolution:d}, architecture={self.architecture:s}' + + +#---------------------------------------------------------------------------- + +# for 512x512 generation +class SuperresolutionHybrid8XDC(torch.nn.Module): + def __init__(self, channels, img_resolution, sr_num_fp16_res, sr_antialias, + num_fp16_res=4, conv_clamp=None, channel_base=None, channel_max=None,# IGNORE + **block_kwargs): + super().__init__() + assert img_resolution == 512 + + use_fp16 = sr_num_fp16_res > 0 + self.input_resolution = 128 + self.sr_antialias = sr_antialias + self.block0 = SynthesisBlock(channels, 256, w_dim=512, resolution=256, + img_channels=3, is_last=False, use_fp16=use_fp16, conv_clamp=(256 if use_fp16 else None), **block_kwargs) + self.block1 = SynthesisBlock(256, 128, w_dim=512, resolution=512, + img_channels=3, is_last=True, use_fp16=use_fp16, conv_clamp=(256 if use_fp16 else None), **block_kwargs) + + def forward(self, rgb, x, ws, **block_kwargs): + ws = ws[:, -1:, :].repeat(1, 3, 1) + + if x.shape[-1] != self.input_resolution: + x = torch.nn.functional.interpolate(x, size=(self.input_resolution, self.input_resolution), + mode='bilinear', align_corners=False) + rgb = torch.nn.functional.interpolate(rgb, size=(self.input_resolution, self.input_resolution), + mode='bilinear', align_corners=False) + + x, rgb = self.block0(x, rgb, ws, **block_kwargs) + x, rgb = self.block1(x, rgb, ws, **block_kwargs) + return rgb + +#---------------------------------------------------------------------------- + diff --git a/models/utils/official_stylegan2_model_helper.py b/models/utils/official_stylegan2_model_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..84063050a5dcc407d36732b08227ddd8ebcd0ceb --- /dev/null +++ b/models/utils/official_stylegan2_model_helper.py @@ -0,0 +1,958 @@ +# python 3.7 +"""Contains some helper calsses and functions of EG3D model.""" + +import math +import random +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import repeat +from utils import eg3d_misc as misc +from third_party.stylegan2_official_ops import conv2d_resample +from third_party.stylegan2_official_ops import upfirdn2d +from third_party.stylegan2_official_ops import bias_act +from third_party.stylegan2_official_ops import fma + +#---------------------------------------------------------------------------- + +def normalize_2nd_moment(x, dim=1, eps=1e-8): + return x * (x.square().mean(dim=dim, keepdim=True) + eps).rsqrt() + +#---------------------------------------------------------------------------- + +def modulated_conv2d( + x, # Input tensor of shape [batch_size, in_channels, in_height, in_width]. + weight, # Weight tensor of shape [out_channels, in_channels, kernel_height, kernel_width]. + styles, # Modulation coefficients of shape [batch_size, in_channels]. + noise = None, # Optional noise tensor to add to the output activations. + up = 1, # Integer upsampling factor. 
+ down = 1, # Integer downsampling factor. + padding = 0, # Padding with respect to the upsampled image. + resample_filter = None, # Low-pass filter to apply when resampling activations. Must be prepared beforehand by calling upfirdn2d.setup_filter(). + demodulate = True, # Apply weight demodulation? + flip_weight = True, # False = convolution, True = correlation (matches torch.nn.functional.conv2d). + fused_modconv = True, # Perform modulation, convolution, and demodulation as a single fused operation? +): + batch_size = x.shape[0] + out_channels, in_channels, kh, kw = weight.shape + misc.assert_shape(weight, [out_channels, in_channels, kh, kw]) # [OIkk] + misc.assert_shape(x, [batch_size, in_channels, None, None]) # [NIHW] + misc.assert_shape(styles, [batch_size, in_channels]) # [NI] + + # Pre-normalize inputs to avoid FP16 overflow. + if x.dtype == torch.float16 and demodulate: + weight = weight * (1 / np.sqrt(in_channels * kh * kw) / weight.norm(float('inf'), dim=[1,2,3], keepdim=True)) # max_Ikk + styles = styles / styles.norm(float('inf'), dim=1, keepdim=True) # max_I + + # Calculate per-sample weights and demodulation coefficients. + w = None + dcoefs = None + if demodulate or fused_modconv: + w = weight.unsqueeze(0) # [NOIkk] + w = w * styles.reshape(batch_size, 1, -1, 1, 1) # [NOIkk] + if demodulate: + dcoefs = (w.square().sum(dim=[2,3,4]) + 1e-8).rsqrt() # [NO] + if demodulate and fused_modconv: + w = w * dcoefs.reshape(batch_size, -1, 1, 1, 1) # [NOIkk] + + # Execute by scaling the activations before and after the convolution. + if not fused_modconv: + x = x * styles.to(x.dtype).reshape(batch_size, -1, 1, 1) + x = conv2d_resample.conv2d_resample(x=x, w=weight.to(x.dtype), f=resample_filter, up=up, down=down, padding=padding, flip_weight=flip_weight) + if demodulate and noise is not None: + x = fma.fma(x, dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1), noise.to(x.dtype)) + elif demodulate: + x = x * dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1) + elif noise is not None: + x = x.add_(noise.to(x.dtype)) + return x + + # Execute as one fused op using grouped convolution. + with misc.suppress_tracer_warnings(): # this value will be treated as a constant + batch_size = int(batch_size) + misc.assert_shape(x, [batch_size, in_channels, None, None]) + x = x.reshape(1, -1, *x.shape[2:]) + w = w.reshape(-1, in_channels, kh, kw) + x = conv2d_resample.conv2d_resample(x=x, w=w.to(x.dtype), f=resample_filter, up=up, down=down, padding=padding, groups=batch_size, flip_weight=flip_weight) + x = x.reshape(batch_size, -1, *x.shape[2:]) + if noise is not None: + x = x.add_(noise) + return x + +#---------------------------------------------------------------------------- +class SEL(torch.nn.Module): + def __init__(self, norm_nc, label_nc, hidden_nc=128): + super().__init__() + self.norm = nn.InstanceNorm2d(norm_nc, affine=False) + self.mlp_shared = nn.Sequential( + nn.Conv2d(label_nc, hidden_nc, kernel_size=1, padding=0), + nn.ReLU()) + self.mlp_gamma = nn.Conv2d(hidden_nc, norm_nc, kernel_size=1, padding=0) + self.mlp_beta = nn.Conv2d(hidden_nc, norm_nc, kernel_size=1, padding=0) + + def forward(self, x, hm): + x_s = x + x = self.norm(x) + hm = F.interpolate(hm, size=x.size()[2:], mode='bilinear', align_corners=True) + actv = self.mlp_shared(hm) + gamma = self.mlp_gamma(actv) + beta = self.mlp_beta(actv) + out = x * (1+gamma) + beta + + return out + 0.1 * x_s + + +class FullyConnectedLayer(torch.nn.Module): + def __init__(self, + in_features, # Number of input features. 
+ out_features, # Number of output features. + bias = True, # Apply additive bias before the activation function? + activation = 'linear', # Activation function: 'relu', 'lrelu', etc. + lr_multiplier = 1, # Learning rate multiplier. + bias_init = 0, # Initial value for the additive bias. + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.activation = activation + self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier) + self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None + self.weight_gain = lr_multiplier / np.sqrt(in_features) + self.bias_gain = lr_multiplier + + def forward(self, x): + w = self.weight.to(x.dtype) * self.weight_gain + b = self.bias + if b is not None: + b = b.to(x.dtype) + if self.bias_gain != 1: + b = b * self.bias_gain + + if self.activation == 'linear' and b is not None: + x = torch.addmm(b.unsqueeze(0), x, w.t()) + else: + x = x.matmul(w.t()) + x = bias_act.bias_act(x, b, act=self.activation) + return x + + def extra_repr(self): + return f'in_features={self.in_features:d}, out_features={self.out_features:d}, activation={self.activation:s}' + +#---------------------------------------------------------------------------- + +class Conv2dLayer(torch.nn.Module): + def __init__(self, + in_channels, # Number of input channels. + out_channels, # Number of output channels. + kernel_size, # Width and height of the convolution kernel. + bias = True, # Apply additive bias before the activation function? + activation = 'linear', # Activation function: 'relu', 'lrelu', etc. + up = 1, # Integer upsampling factor. + down = 1, # Integer downsampling factor. + resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations. + conv_clamp = None, # Clamp the output to +-X, None = disable clamping. + channels_last = False, # Expect the input to have memory_format=channels_last? + trainable = True, # Update the weights of this layer during training? 
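+        # NOTE: as in StyleGAN2, weights are stored at unit variance and scaled
+        # at run time by weight_gain = 1/sqrt(in_channels * kernel_size^2)
+        # (equalized learning rate); e.g. a 3x3 conv with 512 input channels uses
+        # a gain of 1/sqrt(512*9) ~= 0.0147. With trainable=False the weight and
+        # bias are registered as buffers rather than parameters, which is what
+        # the Freeze-D option in the discriminator blocks below relies on.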
+ ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.activation = activation + self.up = up + self.down = down + self.conv_clamp = conv_clamp + self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter)) + self.padding = kernel_size // 2 + self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2)) + self.act_gain = bias_act.activation_funcs[activation].def_gain + + memory_format = torch.channels_last if channels_last else torch.contiguous_format + weight = torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format) + bias = torch.zeros([out_channels]) if bias else None + if trainable: + self.weight = torch.nn.Parameter(weight) + self.bias = torch.nn.Parameter(bias) if bias is not None else None + else: + self.register_buffer('weight', weight) + if bias is not None: + self.register_buffer('bias', bias) + else: + self.bias = None + + def forward(self, x, gain=1): + w = self.weight * self.weight_gain + b = self.bias.to(x.dtype) if self.bias is not None else None + flip_weight = (self.up == 1) # slightly faster + x = conv2d_resample.conv2d_resample(x=x, w=w.to(x.dtype), f=self.resample_filter, up=self.up, down=self.down, padding=self.padding, flip_weight=flip_weight) + + act_gain = self.act_gain * gain + act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None + x = bias_act.bias_act(x, b, act=self.activation, gain=act_gain, clamp=act_clamp) + return x + + def extra_repr(self): + return ' '.join([ + f'in_channels={self.in_channels:d}, out_channels={self.out_channels:d}, activation={self.activation:s},', + f'up={self.up}, down={self.down}']) + +#---------------------------------------------------------------------------- + +class MappingNetwork(torch.nn.Module): + def __init__(self, + z_dim, # Input latent (Z) dimensionality, 0 = no latent. + c_dim, # Conditioning label (C) dimensionality, 0 = no label. + w_dim, # Intermediate latent (W) dimensionality. + num_ws, # Number of intermediate latents to output, None = do not broadcast. + num_layers = 8, # Number of mapping layers. + embed_features = None, # Label embedding dimensionality, None = same as w_dim. + layer_features = None, # Number of intermediate features in the mapping layers, None = same as w_dim. + activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc. + lr_multiplier = 0.01, # Learning rate multiplier for the mapping layers. + w_avg_beta = 0.998, # Decay for tracking the moving average of W during training, None = do not track. + ): + super().__init__() + self.z_dim = z_dim + self.c_dim = c_dim + self.w_dim = w_dim + self.num_ws = num_ws + self.num_layers = num_layers + self.w_avg_beta = w_avg_beta + + if embed_features is None: + embed_features = w_dim + if c_dim == 0: + embed_features = 0 + if layer_features is None: + layer_features = w_dim + features_list = [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim] + + if c_dim > 0: + self.embed = FullyConnectedLayer(c_dim, embed_features) + for idx in range(num_layers): + in_features = features_list[idx] + out_features = features_list[idx + 1] + layer = FullyConnectedLayer(in_features, out_features, activation=activation, lr_multiplier=lr_multiplier) + setattr(self, f'fc{idx}', layer) + + if num_ws is not None and w_avg_beta is not None: + self.register_buffer('w_avg', torch.zeros([w_dim])) + + def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False): + # Embed, normalize, and concat inputs. 
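+        # Overview of the steps below: z (and the label embedding, if any) is
+        # normalized to unit second moment, passed through fc0..fc{num_layers-1},
+        # optionally accumulated into the w_avg EMA, broadcast to num_ws copies,
+        # and, when truncation_psi != 1, lerped towards w_avg; e.g.
+        # truncation_psi=0.7 keeps 70% of the offset from w_avg, i.e. pulls each
+        # w 30% of the way towards the average.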
+ x = None + with torch.autograd.profiler.record_function('input'): + if self.z_dim > 0: + misc.assert_shape(z, [None, self.z_dim]) + x = normalize_2nd_moment(z.to(torch.float32)) + if self.c_dim > 0: + misc.assert_shape(c, [None, self.c_dim]) + y = normalize_2nd_moment(self.embed(c.to(torch.float32))) + x = torch.cat([x, y], dim=1) if x is not None else y + + # Main layers. + for idx in range(self.num_layers): + layer = getattr(self, f'fc{idx}') + x = layer(x) + + # Update moving average of W. + if update_emas and self.w_avg_beta is not None: + with torch.autograd.profiler.record_function('update_w_avg'): + self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta)) + + # Broadcast. + if self.num_ws is not None: + with torch.autograd.profiler.record_function('broadcast'): + x = x.unsqueeze(1).repeat([1, self.num_ws, 1]) + + # Apply truncation. + if truncation_psi != 1: + with torch.autograd.profiler.record_function('truncate'): + assert self.w_avg_beta is not None + if self.num_ws is None or truncation_cutoff is None: + x = self.w_avg.lerp(x, truncation_psi) + else: + x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi) + return x + + def extra_repr(self): + return f'z_dim={self.z_dim:d}, c_dim={self.c_dim:d}, w_dim={self.w_dim:d}' + +#---------------------------------------------------------------------------- + +class SynthesisLayer(torch.nn.Module): + def __init__(self, + in_channels, # Number of input channels. + out_channels, # Number of output channels. + w_dim, # Intermediate latent (W) dimensionality. + resolution, # Resolution of this layer. + label_nc = 0, + use_sel = False, + kernel_size = 3, # Convolution kernel size. + up = 1, # Integer upsampling factor. + use_noise = True, # Enable noise input? + activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc. + resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations. + conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping. + channels_last = False, # Use channels_last format for the weights? 
+ **unused_kwargs + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.w_dim = w_dim + self.resolution = resolution + self.up = up + self.use_noise = use_noise + self.activation = activation + self.conv_clamp = conv_clamp + self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter)) + self.padding = kernel_size // 2 + self.act_gain = bias_act.activation_funcs[activation].def_gain + + self.use_sel = use_sel + + self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1) + memory_format = torch.channels_last if channels_last else torch.contiguous_format + self.weight = torch.nn.Parameter(torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format)) + if use_noise: + self.register_buffer('noise_const', torch.randn([resolution, resolution])) + self.noise_strength = torch.nn.Parameter(torch.zeros([])) + self.bias = torch.nn.Parameter(torch.zeros([out_channels])) + if self.use_sel: + self.sel = SEL(norm_nc=in_channels, label_nc=label_nc) + + def forward(self, x, w, heatmap=None, noise_mode='random', fused_modconv=True, gain=1): + assert noise_mode in ['random', 'const', 'none'] + in_resolution = self.resolution // self.up + misc.assert_shape(x, [None, self.in_channels, in_resolution, in_resolution]) + styles = self.affine(w) + + if self.use_sel: + x = self.sel(x, heatmap).to(x.dtype) + + noise = None + if self.use_noise and noise_mode == 'random': + noise = torch.randn([x.shape[0], 1, self.resolution, self.resolution], device=x.device) * self.noise_strength + if self.use_noise and noise_mode == 'const': + noise = self.noise_const * self.noise_strength + + flip_weight = (self.up == 1) # slightly faster + x = modulated_conv2d(x=x, weight=self.weight, styles=styles, noise=noise, up=self.up, + padding=self.padding, resample_filter=self.resample_filter, flip_weight=flip_weight, fused_modconv=fused_modconv) + + act_gain = self.act_gain * gain + act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None + x = bias_act.bias_act(x, self.bias.to(x.dtype), act=self.activation, gain=act_gain, clamp=act_clamp) + return x + + def extra_repr(self): + return ' '.join([ + f'in_channels={self.in_channels:d}, out_channels={self.out_channels:d}, w_dim={self.w_dim:d},', + f'resolution={self.resolution:d}, up={self.up}, activation={self.activation:s}']) + +#---------------------------------------------------------------------------- + +class ToRGBLayer(torch.nn.Module): + def __init__(self, in_channels, out_channels, w_dim, label_nc=0, use_sel=False, kernel_size=1, conv_clamp=None, channels_last=False): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.w_dim = w_dim + self.conv_clamp = conv_clamp + self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1) + memory_format = torch.channels_last if channels_last else torch.contiguous_format + self.weight = torch.nn.Parameter(torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format)) + self.bias = torch.nn.Parameter(torch.zeros([out_channels])) + self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2)) + + self.use_sel = use_sel + if self.use_sel: + self.sel=SEL(norm_nc=in_channels, label_nc=label_nc) + + def forward(self, x, w, heatmap=None, fused_modconv=True): + styles = self.affine(w) * self.weight_gain + if self.use_sel: + x = self.sel(x, heatmap).to(x.dtype) + + if x.size(0) > styles.size(0): + assert (x.size(0) // styles.size(0) * 
styles.size(0) == x.size(0)) + styles = repeat(styles, 'b c -> (b s) c', s=x.size(0) // styles.size(0)) + x = modulated_conv2d(x=x, weight=self.weight, styles=styles, demodulate=False, fused_modconv=fused_modconv) + x = bias_act.bias_act(x, self.bias.to(x.dtype), clamp=self.conv_clamp) + return x + + def extra_repr(self): + return f'in_channels={self.in_channels:d}, out_channels={self.out_channels:d}, w_dim={self.w_dim:d}' + +#---------------------------------------------------------------------------- +class SynthesisInput(torch.nn.Module): + def __init__(self, + w_dim, + channels, + offset_scale='0,0', + bound_len=0.5, + # sampling_rate, + # bandwidth + ): + super().__init__() + self.w_dim = w_dim + self.channels = channels + self.bound_len = bound_len + # self.sampling_rate = sampling_rate + # self.bandwidth = bandwidth + self.size = [4,4] + self.x_offset_scale, self.y_offset_scale = list(map(float, offset_scale.split(','))) + + freqs = torch.randn([self.channels, 2]) + radii = freqs.square().sum(dim=1, keepdim=True).sqrt() + freqs /= radii * radii.square().exp().pow(0.25) + # freqs *= bandwidth + phases = torch.rand([self.channels]) - 0.5 + + self.weight = torch.nn.Parameter(torch.randn([self.channels, self.channels])) + # self.affine = FullyConnectedLayer(w_dim, 4, weight_init=0, bias_init=[1,0,0,0]) + self.register_buffer('transform', torch.eye(3, 3)) # User-specified inverse transform wrt. resulting image. + self.register_buffer('freqs', freqs) + self.register_buffer('phases', phases) + + self.x_offset = None + self.y_offset = None + + def forward(self, w): + transforms = self.transform.unsqueeze(0) # [batch, row, col] + freqs = self.freqs.unsqueeze(0) # [batch, channel, xy] + phases = self.phases.unsqueeze(0) # [batch, channel] + + # t = self.affine(w) # t = (r_c, r_s, t_x, t_y) + # t = t / t[:, :2].norm(dim=1, keepdim=True) # t' = (r'_c, r'_s, t'_x, t'_y) + # m_r = torch.eye(3, device=w.device).unsqueeze(0).repeat([w.shape[0], 1, 1]) # Inverse rotation wrt. resulting image. + # m_r[:, 0, 0] = t[:, 0] # r'_c + # m_r[:, 0, 1] = -t[:, 1] # r'_s + # m_r[:, 1, 0] = t[:, 1] # r'_s + # m_r[:, 1, 1] = t[:, 0] # r'_c + # m_t = torch.eye(3, device=w.device).unsqueeze(0).repeat([w.shape[0], 1, 1]) # Inverse translation wrt. resulting image. + # m_t[:, 0, 2] = -t[:, 2] # t'_x + # m_t[:, 1, 2] = -t[:, 3] # t'_y + # transforms = m_r @ m_t @ transforms # First rotate resulting image, then translate, and finally apply user-specified transform. + + # Transform frequencies. + phases = phases + (freqs @ transforms[:, :2, 2:]).squeeze(2) + freqs = freqs @ transforms[:, :2, :2] + + # Dampen out-of-band frequencies that may occur due to the user-specified transform. 
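+        # NOTE: the StyleGAN3-style amplitude dampening is disabled here (the
+        # sampling_rate/bandwidth arguments are commented out in __init__);
+        # instead a bound_len-scaled affine grid is built below and shifted by
+        # random (or externally fixed x_offset/y_offset) translations before the
+        # Fourier features are evaluated.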
+ # amplitudes = (1 - (freqs.norm(dim=2) - self.bandwidth) / (self.sampling_rate / 2 - self.bandwidth)).clamp(0, 1) + + theta = torch.eye(2, 3, device=w.device) + theta[0, 0] = self.bound_len + theta[1, 1] = self.bound_len + # theta[0, 0] = 0.5 * self.size[0] / self.sampling_rate + # theta[1, 1] = 0.5 * self.size[1] / self.sampling_rate + grids = torch.nn.functional.affine_grid(theta.unsqueeze(0), [1, 1, self.size[1], self.size[0]], align_corners=False) + + offset_len = 1 - 2*self.bound_len + dx = (random.random() * offset_len - offset_len/2) if self.x_offset == None else self.x_offset + dy = (random.random() * offset_len - offset_len/2) if self.y_offset == None else self.y_offset + dx *= self.x_offset_scale + dy *= self.y_offset_scale + grids[..., 0] += dx + grids[..., 1] += dy + + x = (grids.unsqueeze(3) @ freqs.permute(0, 2, 1).unsqueeze(1).unsqueeze(2)).squeeze(3) + x = x + phases.unsqueeze(1).unsqueeze(2) + x = torch.sin(x * (np.pi * 2)) + # x = x * amplitudes.unsqueeze(1).unsqueeze(2) + + weight = self.weight / np.sqrt(self.channels) + x = x @ weight.t() + + x = x.permute(0, 3, 1, 2) # [batch, channel, height, width] + return x + + +class SynthesisBlock(torch.nn.Module): + def __init__(self, + in_channels, # Number of input channels, 0 = first block. + out_channels, # Number of output channels. + w_dim, # Intermediate latent (W) dimensionality. + resolution, # Resolution of this block. + img_channels, # Number of output color channels. + is_last, # Is this the last block? + label_nc = 0, + use_sel = False, + architecture = 'skip', # Architecture: 'orig', 'skip', 'resnet'. + resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations. + conv_clamp = 256, # Clamp the output of convolution layers to +-X, None = disable clamping. + use_fp16 = False, # Use FP16 for this block? + fp16_channels_last = False, # Use channels-last memory format with FP16? + fused_modconv_default = True, # Default value of fused_modconv. 'inference_only' = True for inference, False for training. + ff_input = False, + offset_scale = '0,0', + bound_len=0.5, + **layer_kwargs, # Arguments for SynthesisLayer. 
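+        # NOTE: relative to the vanilla StyleGAN2 block this adds (a) optional
+        # SEL conditioning on a spatial heatmap (use_sel/label_nc) and (b) an
+        # optional Fourier-feature input (ff_input) that replaces the learned
+        # constant in the first block, with offset_scale/bound_len controlling
+        # the sampled crop.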
+ ): + assert architecture in ['orig', 'skip', 'resnet'] + super().__init__() + self.in_channels = in_channels + self.w_dim = w_dim + self.resolution = resolution + self.img_channels = img_channels + self.is_last = is_last + self.architecture = architecture + self.use_fp16 = use_fp16 + self.channels_last = (use_fp16 and fp16_channels_last) + self.fused_modconv_default = fused_modconv_default + self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter)) + self.num_conv = 0 + self.num_torgb = 0 + self.ff_input = ff_input + + if in_channels == 0: + if self.ff_input: + self.input = SynthesisInput(w_dim=w_dim, channels=out_channels, offset_scale=offset_scale, bound_len=bound_len) + else: + self.const = torch.nn.Parameter(torch.randn([out_channels, resolution, resolution])) + + if in_channels != 0: + self.conv0 = SynthesisLayer(in_channels, out_channels, w_dim=w_dim, resolution=resolution, up=2, use_sel=use_sel, label_nc=label_nc, + resample_filter=resample_filter, conv_clamp=conv_clamp, channels_last=self.channels_last, **layer_kwargs) + self.num_conv += 1 + + self.conv1 = SynthesisLayer(out_channels, out_channels, w_dim=w_dim, resolution=resolution, use_sel=use_sel, label_nc=label_nc, + conv_clamp=conv_clamp, channels_last=self.channels_last, **layer_kwargs) + self.num_conv += 1 + + if is_last or architecture == 'skip': + self.torgb = ToRGBLayer(out_channels, img_channels, w_dim=w_dim, label_nc=label_nc, use_sel=use_sel, + conv_clamp=conv_clamp, channels_last=self.channels_last) + self.num_torgb += 1 + + if in_channels != 0 and architecture == 'resnet': + self.skip = Conv2dLayer(in_channels, out_channels, kernel_size=1, bias=False, up=2, + resample_filter=resample_filter, channels_last=self.channels_last) + + def forward(self, x, img, ws, heatmap=None, force_fp32=False, fused_modconv=None, update_emas=False, **layer_kwargs): + _ = update_emas # unused + misc.assert_shape(ws, [None, self.num_conv + self.num_torgb, self.w_dim]) + w_iter = iter(ws.unbind(dim=1)) + if ws.device.type != 'cuda': + force_fp32 = True + dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32 + memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format + if fused_modconv is None: + fused_modconv = self.fused_modconv_default + if fused_modconv == 'inference_only': + fused_modconv = (not self.training) + + # Input. + if self.in_channels == 0: + if self.ff_input: + x = self.input(ws) + else: + x = self.const.to(dtype=dtype, memory_format=memory_format) + x = x.unsqueeze(0).repeat([ws.shape[0], 1, 1, 1]) + else: + misc.assert_shape(x, [None, self.in_channels, self.resolution // 2, self.resolution // 2]) + x = x.to(dtype=dtype, memory_format=memory_format) + + # Main layers. + if self.in_channels == 0: + x = self.conv1(x, next(w_iter), heatmap, fused_modconv=fused_modconv, **layer_kwargs) + elif self.architecture == 'resnet': + y = self.skip(x, gain=np.sqrt(0.5)) + x = self.conv0(x, next(w_iter), heatmap, fused_modconv=fused_modconv, **layer_kwargs) + x = self.conv1(x, next(w_iter), heatmap, fused_modconv=fused_modconv, gain=np.sqrt(0.5), **layer_kwargs) + x = y.add_(x) + else: + x = self.conv0(x, next(w_iter), heatmap, fused_modconv=fused_modconv, **layer_kwargs) + x = self.conv1(x, next(w_iter), heatmap, fused_modconv=fused_modconv, **layer_kwargs) + + # ToRGB. 
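+        # In the 'skip' architecture the running RGB image is upsampled 2x and
+        # the current block's ToRGB output (cast to FP32) is added to it; the
+        # final image is therefore a sum of per-resolution RGB outputs.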
+ if img is not None: + misc.assert_shape(img, [None, self.img_channels, self.resolution // 2, self.resolution // 2]) + img = upfirdn2d.upsample2d(img, self.resample_filter) + if self.is_last or self.architecture == 'skip': + y = self.torgb(x, next(w_iter), heatmap, fused_modconv=fused_modconv) + y = y.to(dtype=torch.float32, memory_format=torch.contiguous_format) + img = img.add_(y) if img is not None else y + + assert x.dtype == dtype + assert img is None or img.dtype == torch.float32 + return x, img + + def extra_repr(self): + return f'resolution={self.resolution:d}, architecture={self.architecture:s}' + +#---------------------------------------------------------------------------- + +class SynthesisNetwork(torch.nn.Module): + def __init__(self, + w_dim, # Intermediate latent (W) dimensionality. + img_resolution, # Output image resolution. + img_channels, # Number of color channels. + label_nc = 0, + use_sel = False, + channel_base = 32768, # Overall multiplier for the number of channels. + channel_max = 512, # Maximum number of channels in any layer. + num_fp16_res = 4, # Use FP16 for the N highest resolutions. + ff_input = False, + offset_scale = '0,0', + bound_len=0.5, + **block_kwargs, # Arguments for SynthesisBlock. + ): + assert img_resolution >= 4 and img_resolution & (img_resolution - 1) == 0 + super().__init__() + self.w_dim = w_dim + self.img_resolution = img_resolution + self.img_resolution_log2 = int(np.log2(img_resolution)) + self.img_channels = img_channels + self.num_fp16_res = num_fp16_res + self.block_resolutions = [2 ** i for i in range(2, self.img_resolution_log2 + 1)] + channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions} + fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8) + + self.use_sel = use_sel + + self.num_ws = 0 + for res in self.block_resolutions: + in_channels = channels_dict[res // 2] if res > 4 else 0 + out_channels = channels_dict[res] + use_fp16 = (res >= fp16_resolution) + is_last = (res == self.img_resolution) + block = SynthesisBlock(in_channels, out_channels, w_dim=w_dim, resolution=res, use_sel=self.use_sel, label_nc=label_nc, + img_channels=img_channels, is_last=is_last, use_fp16=use_fp16, ff_input=ff_input, offset_scale=offset_scale, bound_len=bound_len, **block_kwargs) + self.num_ws += block.num_conv + if is_last: + self.num_ws += block.num_torgb + setattr(self, f'b{res}', block) + + def forward(self, ws, heatmap=None, **block_kwargs): + block_ws = [] + with torch.autograd.profiler.record_function('split_ws'): + misc.assert_shape(ws, [None, self.num_ws, self.w_dim]) + ws = ws.to(torch.float32) + w_idx = 0 + for res in self.block_resolutions: + block = getattr(self, f'b{res}') + block_ws.append(ws.narrow(1, w_idx, block.num_conv + block.num_torgb)) + w_idx += block.num_conv + + x = img = None + for res, cur_ws in zip(self.block_resolutions, block_ws): + block = getattr(self, f'b{res}') + x, img = block(x, img, cur_ws, heatmap, **block_kwargs) + return img + + def extra_repr(self): + return ' '.join([ + f'w_dim={self.w_dim:d}, num_ws={self.num_ws:d},', + f'img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d},', + f'num_fp16_res={self.num_fp16_res:d}']) + +#---------------------------------------------------------------------------- + +class Generator(torch.nn.Module): + def __init__(self, + z_dim, # Input latent (Z) dimensionality. + c_dim, # Conditioning label (C) dimensionality. + w_dim, # Intermediate latent (W) dimensionality. 
+ img_resolution, # Output resolution. + img_channels, # Number of output color channels. + label_nc = None, + use_sel = False, + ff_input = False, + offset_scale = '0,0', + bound_len = 0.5, + + mapping_kwargs = {}, # Arguments for MappingNetwork. + **synthesis_kwargs, # Arguments for SynthesisNetwork. + ): + super().__init__() + self.z_dim = z_dim + self.c_dim = c_dim + self.w_dim = w_dim + self.img_resolution = img_resolution + self.img_channels = img_channels + self.synthesis = SynthesisNetwork(w_dim=w_dim, img_resolution=img_resolution, img_channels=img_channels, label_nc=label_nc, use_sel=use_sel, ff_input=ff_input, offset_scale=offset_scale, bound_len=bound_len, **synthesis_kwargs) + self.num_ws = self.synthesis.num_ws + self.mapping = MappingNetwork(z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_ws=self.num_ws, **mapping_kwargs) + + def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False, heatmap=None, **synthesis_kwargs): + ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas) + img = self.synthesis(ws, update_emas=update_emas, spatial_map=heatmap, **synthesis_kwargs) + return img + +#---------------------------------------------------------------------------- + +class DiscriminatorBlock(torch.nn.Module): + def __init__(self, + in_channels, # Number of input channels, 0 = first block. + tmp_channels, # Number of intermediate channels. + out_channels, # Number of output channels. + resolution, # Resolution of this block. + img_channels, # Number of input color channels. + first_layer_idx, # Index of the first layer. + architecture = 'resnet', # Architecture: 'orig', 'skip', 'resnet'. + activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc. + resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations. + conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping. + use_fp16 = False, # Use FP16 for this block? + fp16_channels_last = False, # Use channels-last memory format with FP16? + freeze_layers = 0, # Freeze-D: Number of layers to freeze. 
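+        # NOTE: layers whose global index (first_layer_idx + local offset) is
+        # below freeze_layers are created with trainable=False, implementing
+        # Freeze-D; e.g. freeze_layers=2 would freeze fromrgb and conv0 of the
+        # highest-resolution block.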
+ ): + assert in_channels in [0, tmp_channels] + assert architecture in ['orig', 'skip', 'resnet'] + super().__init__() + self.in_channels = in_channels + self.resolution = resolution + self.img_channels = img_channels + self.first_layer_idx = first_layer_idx + self.architecture = architecture + self.use_fp16 = use_fp16 + self.channels_last = (use_fp16 and fp16_channels_last) + self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter)) + + self.num_layers = 0 + def trainable_gen(): + while True: + layer_idx = self.first_layer_idx + self.num_layers + trainable = (layer_idx >= freeze_layers) + self.num_layers += 1 + yield trainable + trainable_iter = trainable_gen() + + if in_channels == 0 or architecture == 'skip': + self.fromrgb = Conv2dLayer(img_channels, tmp_channels, kernel_size=1, activation=activation, + trainable=next(trainable_iter), conv_clamp=conv_clamp, channels_last=self.channels_last) + + self.conv0 = Conv2dLayer(tmp_channels, tmp_channels, kernel_size=3, activation=activation, + trainable=next(trainable_iter), conv_clamp=conv_clamp, channels_last=self.channels_last) + + self.conv1 = Conv2dLayer(tmp_channels, out_channels, kernel_size=3, activation=activation, down=2, + trainable=next(trainable_iter), resample_filter=resample_filter, conv_clamp=conv_clamp, channels_last=self.channels_last) + + if architecture == 'resnet': + self.skip = Conv2dLayer(tmp_channels, out_channels, kernel_size=1, bias=False, down=2, + trainable=next(trainable_iter), resample_filter=resample_filter, channels_last=self.channels_last) + + def forward(self, x, img, force_fp32=False): + if (x if x is not None else img).device.type != 'cuda': + force_fp32 = True + dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32 + memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format + + # Input. + if x is not None: + misc.assert_shape(x, [None, self.in_channels, self.resolution, self.resolution]) + x = x.to(dtype=dtype, memory_format=memory_format) + + # FromRGB. + if self.in_channels == 0 or self.architecture == 'skip': + misc.assert_shape(img, [None, self.img_channels, self.resolution, self.resolution]) + img = img.to(dtype=dtype, memory_format=memory_format) + y = self.fromrgb(img) + x = x + y if x is not None else y + img = upfirdn2d.downsample2d(img, self.resample_filter) if self.architecture == 'skip' else None + + # Main layers. + if self.architecture == 'resnet': + y = self.skip(x, gain=np.sqrt(0.5)) + x = self.conv0(x) + x = self.conv1(x, gain=np.sqrt(0.5)) + x = y.add_(x) + else: + x = self.conv0(x) + x = self.conv1(x) + + assert x.dtype == dtype + return x, img + + def extra_repr(self): + return f'resolution={self.resolution:d}, architecture={self.architecture:s}' + +#---------------------------------------------------------------------------- + +class MinibatchStdLayer(torch.nn.Module): + def __init__(self, group_size, num_channels=1): + super().__init__() + self.group_size = group_size + self.num_channels = num_channels + + def forward(self, x): + N, C, H, W = x.shape + with misc.suppress_tracer_warnings(): # as_tensor results are registered as constants + G = torch.min(torch.as_tensor(self.group_size), torch.as_tensor(N)) if self.group_size is not None else N + F = self.num_channels + c = C // F + + y = x.reshape(G, -1, F, c, H, W) # [GnFcHW] Split minibatch N into n groups of size G, and channels C into F groups of size c. + y = y - y.mean(dim=0) # [GnFcHW] Subtract mean over group. 
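+        # The remaining steps compute the per-group standard deviation of each
+        # feature, average it over channels and pixels, and append it to every
+        # sample in the group as an extra feature map (the minibatch-stddev
+        # trick used to discourage mode collapse).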
+ y = y.square().mean(dim=0) # [nFcHW] Calc variance over group. + y = (y + 1e-8).sqrt() # [nFcHW] Calc stddev over group. + y = y.mean(dim=[2,3,4]) # [nF] Take average over channels and pixels. + y = y.reshape(-1, F, 1, 1) # [nF11] Add missing dimensions. + y = y.repeat(G, 1, H, W) # [NFHW] Replicate over group and pixels. + x = torch.cat([x, y], dim=1) # [NCHW] Append to input as new channels. + return x + + def extra_repr(self): + return f'group_size={self.group_size}, num_channels={self.num_channels:d}' + +#---------------------------------------------------------------------------- + +class DiscriminatorEpilogue(torch.nn.Module): + def __init__(self, + in_channels, # Number of input channels. + cmap_dim, # Dimensionality of mapped conditioning label, 0 = no label. + resolution, # Resolution of this block. + img_channels, # Number of input color channels. + architecture = 'resnet', # Architecture: 'orig', 'skip', 'resnet'. + mbstd_group_size = 4, # Group size for the minibatch standard deviation layer, None = entire minibatch. + mbstd_num_channels = 1, # Number of features for the minibatch standard deviation layer, 0 = disable. + activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc. + conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping. + ): + assert architecture in ['orig', 'skip', 'resnet'] + super().__init__() + self.in_channels = in_channels + self.cmap_dim = cmap_dim + self.resolution = resolution + self.img_channels = img_channels + self.architecture = architecture + + if architecture == 'skip': + self.fromrgb = Conv2dLayer(img_channels, in_channels, kernel_size=1, activation=activation) + self.mbstd = MinibatchStdLayer(group_size=mbstd_group_size, num_channels=mbstd_num_channels) if mbstd_num_channels > 0 else None + self.conv = Conv2dLayer(in_channels + mbstd_num_channels, in_channels, kernel_size=3, activation=activation, conv_clamp=conv_clamp) + self.fc = FullyConnectedLayer(in_channels * (resolution ** 2), in_channels, activation=activation) + self.out = FullyConnectedLayer(in_channels, 1 if cmap_dim == 0 else cmap_dim) + + def forward(self, x, img, cmap, force_fp32=False): + misc.assert_shape(x, [None, self.in_channels, self.resolution, self.resolution]) # [NCHW] + _ = force_fp32 # unused + dtype = torch.float32 + memory_format = torch.contiguous_format + + # FromRGB. + x = x.to(dtype=dtype, memory_format=memory_format) + if self.architecture == 'skip': + misc.assert_shape(img, [None, self.img_channels, self.resolution, self.resolution]) + img = img.to(dtype=dtype, memory_format=memory_format) + x = x + self.fromrgb(img) + + # Main layers. + if self.mbstd is not None: + x = self.mbstd(x) + x = self.conv(x) + x = self.fc(x.flatten(1)) + x = self.out(x) + + # Conditioning. + if self.cmap_dim > 0: + misc.assert_shape(cmap, [None, self.cmap_dim]) + x = (x * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim)) + + assert x.dtype == dtype + return x + + def extra_repr(self): + return f'resolution={self.resolution:d}, architecture={self.architecture:s}' + +#---------------------------------------------------------------------------- + +class Discriminator(torch.nn.Module): + def __init__(self, + c_dim, # Conditioning label (C) dimensionality. + img_resolution, # Input resolution. + img_channels, # Number of input color channels. + architecture = 'resnet', # Architecture: 'orig', 'skip', 'resnet'. + channel_base = 32768, # Overall multiplier for the number of channels. 
+ channel_max = 512, # Maximum number of channels in any layer. + num_fp16_res = 4, # Use FP16 for the N highest resolutions. + conv_clamp = 256, # Clamp the output of convolution layers to +-X, None = disable clamping. + cmap_dim = None, # Dimensionality of mapped conditioning label, None = default. + block_kwargs = {}, # Arguments for DiscriminatorBlock. + mapping_kwargs = {}, # Arguments for MappingNetwork. + epilogue_kwargs = {}, # Arguments for DiscriminatorEpilogue. + ): + super().__init__() + self.c_dim = c_dim + self.img_resolution = img_resolution + self.img_resolution_log2 = int(np.log2(img_resolution)) + self.img_channels = img_channels + self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)] + channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]} + fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8) + + if cmap_dim is None: + cmap_dim = channels_dict[4] + if c_dim == 0: + cmap_dim = 0 + + common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp) + cur_layer_idx = 0 + for res in self.block_resolutions: + in_channels = channels_dict[res] if res < img_resolution else 0 + tmp_channels = channels_dict[res] + out_channels = channels_dict[res // 2] + use_fp16 = (res >= fp16_resolution) + block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res, + first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs) + setattr(self, f'b{res}', block) + cur_layer_idx += block.num_layers + if c_dim > 0: + self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs) + self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, **epilogue_kwargs, **common_kwargs) + + def forward(self, img, c, update_emas=False, **block_kwargs): + _ = update_emas # unused + x = None + for res in self.block_resolutions: + block = getattr(self, f'b{res}') + x, img = block(x, img, **block_kwargs) + + cmap = None + if self.c_dim > 0: + cmap = self.mapping(None, c) + x = self.b4(x, img, cmap) + return x + + def extra_repr(self): + return f'c_dim={self.c_dim:d}, img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d}' + +#---------------------------------------------------------------------------- + +class EqualConv2d(torch.nn.Module): + def __init__( + self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True + ): + super().__init__() + new_scale = 1.0 + self.weight = torch.nn.Parameter( + torch.randn(out_channel, in_channel, kernel_size, kernel_size) * new_scale + ) + self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2) + self.stride = stride + self.padding = padding + if bias: + self.bias = torch.nn.Parameter(torch.zeros(out_channel)) + else: + self.bias = None + + def forward(self, input): + out = F.conv2d( + input, + self.weight * self.scale, + bias=self.bias, + stride=self.stride, + padding=self.padding, + ) + return out + + def __repr__(self): + return ( + f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},' + f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})' + ) + +#---------------------------------------------------------------------------- diff --git a/models/utils/official_stylegan3_model_helper.py b/models/utils/official_stylegan3_model_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..f50d320a7fec83bbdc40d6154c923fd90d49bd4b --- 
/dev/null +++ b/models/utils/official_stylegan3_model_helper.py @@ -0,0 +1,656 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Generator architecture from the paper +"Alias-Free Generative Adversarial Networks".""" + +import numpy as np +import scipy.signal +import scipy.optimize +import torch +import torch.nn as nn +import torch.nn.functional as F +import random +from utils import eg3d_misc as misc +from third_party.stylegan3_official_ops import conv2d_gradfix +from third_party.stylegan3_official_ops import filtered_lrelu +from third_party.stylegan3_official_ops import bias_act +from third_party.stylegan3_official_ops import upfirdn2d + +#---------------------------------------------------------------------------- +class SEL(torch.nn.Module): + def __init__(self, norm_nc, label_nc, hidden_nc=128): + super().__init__() + self.norm = nn.InstanceNorm2d(norm_nc, affine=False) + self.mlp_shared = nn.Conv2d(label_nc, hidden_nc, kernel_size=1, padding=0) + self.actv = nn.ReLU() + self.mlp_gamma = nn.Conv2d(hidden_nc, norm_nc, kernel_size=1, padding=0) + self.mlp_beta = nn.Conv2d(hidden_nc, norm_nc, kernel_size=1, padding=0) + + def forward(self, x, hm): + x_s = x + x = self.norm(x) + hm = F.interpolate(hm, size=x.size()[2:], mode='bilinear', align_corners=True) + actv = self.actv(self.mlp_shared(hm)) + gamma = self.mlp_gamma(actv) + beta = self.mlp_beta(actv) + out = x * (1+gamma) + beta + + return out + 0.1 * x_s + +class SEL_unet_pro(SEL): + def __init__(self, norm_nc, label_nc, hidden_nc=128, down_filter=None, slope=0.2, gain=np.sqrt(2), clamp=None): + super().__init__(norm_nc, label_nc, hidden_nc) + self.register_buffer('down_filter', down_filter) + self.slope = slope + self.gain = gain + self.clamp = clamp + + def forward(self, x, hm): + x_size = x.shape[-1] + hm_size = hm.shape[-1] + if x_size != hm_size: + hm = upfirdn2d.upfirdn2d(x=hm, f=self.down_filter, down=hm_size//x_size, flip_filter=False, padding=int(2.5 * hm_size//x_size)) + hm = self.mlp_shared(hm) + hm = bias_act.bias_act(x=hm, act='lrelu', alpha=self.slope, gain=self.gain, clamp=self.clamp) + gamma = self.mlp_gamma(hm) + beta = self.mlp_beta(hm) + + out = self.norm(x) * (1+gamma) + beta + return out + 0.1 * x + +class SEL_pro(SEL): + def __init__(self, norm_nc, label_nc, hidden_nc=128, down_filter=None, slope=0.2, gain=np.sqrt(2), clamp=None): + super().__init__(norm_nc, label_nc, hidden_nc) + self.zero_pad = nn.ZeroPad2d((576-256)//2) + self.register_buffer('down_filter', down_filter) + self.slope = slope + self.gain = gain + self.clamp = clamp + self.size_dict = {36: 36, 52: 72, 84: 144, 148: 288, 276: 576} + + def forward(self, x, hm): + x_size = x.shape[-1] + x_large_size = self.size_dict[x_size] + hm = self.zero_pad(hm) + hm_size = hm.shape[-1] + assert hm_size % x_large_size == 0, f'hm shape {hm.shape[-1]}, x shape {x.shape[-1]}, {x_large_size}' + if hm_size != x_large_size: + hm = upfirdn2d.upfirdn2d(x=hm, f=self.down_filter, down=hm_size//x_large_size, flip_filter=False, + padding=int(2.5 * hm_size//x_large_size)) + hm = self.mlp_shared(hm) + hm = bias_act.bias_act(x=hm, act='lrelu', alpha=self.slope, gain=self.gain, 
clamp=self.clamp) + pad_len = (x_large_size - x_size) // 2 + if pad_len > 0: + hm = hm[..., pad_len:-pad_len, pad_len:-pad_len] + gamma = self.mlp_gamma(hm) + beta = self.mlp_beta(hm) + + out = self.norm(x) * (1+gamma) + beta + return out + 0.1 * x + +@misc.profiled_function +def modulated_conv2d( + x, # Input tensor: [batch_size, in_channels, in_height, in_width] + w, # Weight tensor: [out_channels, in_channels, kernel_height, kernel_width] + s, # Style tensor: [batch_size, in_channels] + demodulate = True, # Apply weight demodulation? + padding = 0, # Padding: int or [padH, padW] + input_gain = None, # Optional scale factors for the input channels: [], [in_channels], or [batch_size, in_channels] +): + with misc.suppress_tracer_warnings(): # this value will be treated as a constant + batch_size = int(x.shape[0]) + out_channels, in_channels, kh, kw = w.shape + misc.assert_shape(w, [out_channels, in_channels, kh, kw]) # [OIkk] + misc.assert_shape(x, [batch_size, in_channels, None, None]) # [NIHW] + misc.assert_shape(s, [batch_size, in_channels]) # [NI] + + # Pre-normalize inputs. + if demodulate: + w = w * w.square().mean([1,2,3], keepdim=True).rsqrt() + s = s * s.square().mean().rsqrt() + + # Modulate weights. + w = w.unsqueeze(0) # [NOIkk] + w = w * s.unsqueeze(1).unsqueeze(3).unsqueeze(4) # [NOIkk] + + # Demodulate weights. + if demodulate: + dcoefs = (w.square().sum(dim=[2,3,4]) + 1e-8).rsqrt() # [NO] + w = w * dcoefs.unsqueeze(2).unsqueeze(3).unsqueeze(4) # [NOIkk] + + # Apply input scaling. + if input_gain is not None: + input_gain = input_gain.expand(batch_size, in_channels) # [NI] + w = w * input_gain.unsqueeze(1).unsqueeze(3).unsqueeze(4) # [NOIkk] + + # Execute as one fused op using grouped convolution. + x = x.reshape(1, -1, *x.shape[2:]) + w = w.reshape(-1, in_channels, kh, kw) + x = conv2d_gradfix.conv2d(input=x, weight=w.to(x.dtype), padding=padding, groups=batch_size) + x = x.reshape(batch_size, -1, *x.shape[2:]) + return x + +#---------------------------------------------------------------------------- + +class FullyConnectedLayer(torch.nn.Module): + def __init__(self, + in_features, # Number of input features. + out_features, # Number of output features. + activation = 'linear', # Activation function: 'relu', 'lrelu', etc. + bias = True, # Apply additive bias before the activation function? + lr_multiplier = 1, # Learning rate multiplier. + weight_init = 1, # Initial standard deviation of the weight tensor. + bias_init = 0, # Initial value of the additive bias. 
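+ # low_rank appears to be a repo-specific addition to the official layer: when set, the weight is
+ # factorized as weight_left [out_features, low_rank] @ weight_right [low_rank, in_features], which
+ # reduces the parameter count of the per-layer affine projections at some cost in expressiveness.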
+ low_rank = None, + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.activation = activation + self.low_rank = low_rank + self.register_buffer('lr_multiplier', torch.tensor(lr_multiplier, dtype=torch.float)) + if self.low_rank is None: + self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) * (weight_init / lr_multiplier)) + else: + self.weight_left = torch.nn.Parameter(torch.randn([out_features, self.low_rank]) * (weight_init / lr_multiplier)) + self.weight_right = torch.nn.Parameter(torch.randn([self.low_rank, in_features]) * (weight_init / lr_multiplier)) + bias_init = np.broadcast_to(np.asarray(bias_init, dtype=np.float32), [out_features]) + self.bias = torch.nn.Parameter(torch.from_numpy(bias_init / lr_multiplier)) if bias else None + self.weight_gain = 1 / np.sqrt(in_features) + self.bias_gain = 1 + + def forward(self, x): + if self.low_rank is None: + w = self.weight.to(x.dtype) * self.weight_gain * self.lr_multiplier + else: + w = torch.einsum('ab,bc->ac', self.weight_left, self.weight_right).to(x.dtype) * self.weight_gain * self.lr_multiplier + b = self.bias + if b is not None: + b = b.to(x.dtype) + if self.bias_gain * self.lr_multiplier != 1: + b = b * self.bias_gain * self.lr_multiplier + if self.activation == 'linear' and b is not None: + x = torch.addmm(b.unsqueeze(0), x, w.t()) + else: + x = x.matmul(w.t()) + x = bias_act.bias_act(x, b, act=self.activation) + return x + + def extra_repr(self): + return f'in_features={self.in_features:d}, out_features={self.out_features:d}, activation={self.activation:s}' + +#---------------------------------------------------------------------------- + +class MappingNetwork(torch.nn.Module): + def __init__(self, + z_dim, # Input latent (Z) dimensionality. + c_dim, # Conditioning label (C) dimensionality, 0 = no labels. + w_dim, # Intermediate latent (W) dimensionality. + num_ws, # Number of intermediate latents to output. + num_layers = 2, # Number of mapping layers. + lr_multiplier = 0.01, # Learning rate multiplier for the mapping layers. + w_avg_beta = 0.998, # Decay for tracking the moving average of W during training. + ): + super().__init__() + self.z_dim = z_dim + self.c_dim = c_dim + self.w_dim = w_dim + self.num_ws = num_ws + self.num_layers = num_layers + self.w_avg_beta = w_avg_beta + + # Construct layers. + self.embed = FullyConnectedLayer(self.c_dim, self.w_dim) if self.c_dim > 0 else None + features = [self.z_dim + (self.w_dim if self.c_dim > 0 else 0)] + [self.w_dim] * self.num_layers + for idx, in_features, out_features in zip(range(num_layers), features[:-1], features[1:]): + layer = FullyConnectedLayer(in_features, out_features, activation='lrelu', lr_multiplier=lr_multiplier) + setattr(self, f'fc{idx}', layer) + self.register_buffer('w_avg', torch.zeros([w_dim])) + + def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False): + misc.assert_shape(z, [None, self.z_dim]) + if truncation_cutoff is None: + truncation_cutoff = self.num_ws + + # Embed, normalize, and concatenate inputs. + x = z.to(torch.float32) + x = x * (x.square().mean(1, keepdim=True) + 1e-8).rsqrt() + if self.c_dim > 0: + misc.assert_shape(c, [None, self.c_dim]) + y = self.embed(c.to(torch.float32)) + y = y * (y.square().mean(1, keepdim=True) + 1e-8).rsqrt() + x = torch.cat([x, y], dim=1) if x is not None else y + + # Execute layers. + for idx in range(self.num_layers): + x = getattr(self, f'fc{idx}')(x) + + # Update moving average of W. 
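+ # w_avg is an exponential moving average of the mapped latents, refreshed only when update_emas is
+ # True (i.e. during training). It feeds the truncation trick below: with truncation_psi < 1 the first
+ # truncation_cutoff ws are lerped toward w_avg, e.g. calling the mapping with truncation_psi=0.7,
+ # truncation_cutoff=8 trades sample diversity for fidelity at inference time.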
+ if update_emas: + self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta)) + + # Broadcast and apply truncation. + x = x.unsqueeze(1).repeat([1, self.num_ws, 1]) + if truncation_psi != 1: + x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi) + return x + + def extra_repr(self): + return f'z_dim={self.z_dim:d}, c_dim={self.c_dim:d}, w_dim={self.w_dim:d}, num_ws={self.num_ws:d}' + +#---------------------------------------------------------------------------- + +class SynthesisInput(torch.nn.Module): + def __init__(self, + w_dim, # Intermediate latent (W) dimensionality. + channels, # Number of output channels. + size, # Output spatial size: int or [width, height]. + sampling_rate, # Output sampling rate. + bandwidth, # Output bandwidth. + offset_scale='0,0', + bound_len=0.5, + wo_transform=False, + ): + super().__init__() + self.w_dim = w_dim + self.channels = channels + self.size = np.broadcast_to(np.asarray(size), [2]) + self.sampling_rate = sampling_rate + self.bandwidth = bandwidth + self.wo_transform = wo_transform + + self.x_offset_scale, self.y_offset_scale = list(map(float, offset_scale.split(','))) + self.bound_len = bound_len + + # Draw random frequencies from uniform 2D disc. + freqs = torch.randn([self.channels, 2]) + radii = freqs.square().sum(dim=1, keepdim=True).sqrt() + freqs /= radii * radii.square().exp().pow(0.25) + freqs *= bandwidth + phases = torch.rand([self.channels]) - 0.5 + + # Setup parameters and buffers. + self.weight = torch.nn.Parameter(torch.randn([self.channels, self.channels])) + if not wo_transform: + self.affine = FullyConnectedLayer(w_dim, 4, weight_init=0, bias_init=[1,0,0,0]) + self.register_buffer('transform', torch.eye(3, 3)) # User-specified inverse transform wrt. resulting image. + self.register_buffer('freqs', freqs) + self.register_buffer('phases', phases) + + self.x_offset = None + self.y_offset = None + + def forward(self, w): + # Introduce batch dimension. + transforms = self.transform.unsqueeze(0) # [batch, row, col] + freqs = self.freqs.unsqueeze(0) # [batch, channel, xy] + phases = self.phases.unsqueeze(0) # [batch, channel] + + # Apply learned transformation. + if not self.wo_transform: + t = self.affine(w) # t = (r_c, r_s, t_x, t_y) + t = t / t[:, :2].norm(dim=1, keepdim=True) # t' = (r'_c, r'_s, t'_x, t'_y) + m_r = torch.eye(3, device=w.device).unsqueeze(0).repeat([w.shape[0], 1, 1]) # Inverse rotation wrt. resulting image. + m_r[:, 0, 0] = t[:, 0] # r'_c + m_r[:, 0, 1] = -t[:, 1] # r'_s + m_r[:, 1, 0] = t[:, 1] # r'_s + m_r[:, 1, 1] = t[:, 0] # r'_c + m_t = torch.eye(3, device=w.device).unsqueeze(0).repeat([w.shape[0], 1, 1]) # Inverse translation wrt. resulting image. + m_t[:, 0, 2] = -t[:, 2] # t'_x + m_t[:, 1, 2] = -t[:, 3] # t'_y + transforms = m_r @ m_t @ transforms # First rotate resulting image, then translate, and finally apply user-specified transform. + + # Transform frequencies. + phases = phases + (freqs @ transforms[:, :2, 2:]).squeeze(2) + freqs = freqs @ transforms[:, :2, :2] + + # Dampen out-of-band frequencies that may occur due to the user-specified transform. + amplitudes = (1 - (freqs.norm(dim=2) - self.bandwidth) / (self.sampling_rate / 2 - self.bandwidth)).clamp(0, 1) + + # Construct sampling grid. 
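+ # The affine grid below samples a sub-window of the Fourier-feature plane: bound_len sets the window
+ # size relative to the full sampling_rate, while x_offset / y_offset shift it (drawn uniformly from
+ # [-offset_len/2, offset_len/2] when left as None, then scaled by offset_scale). Pinning x_offset and
+ # y_offset to fixed values therefore makes the spatial layout deterministic across forward passes.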
+ theta = torch.eye(2, 3, device=w.device) + theta[0, 0] = self.bound_len * self.size[0] / self.sampling_rate + theta[1, 1] = self.bound_len * self.size[1] / self.sampling_rate + grids = torch.nn.functional.affine_grid(theta.unsqueeze(0), [1, 1, self.size[1], self.size[0]], align_corners=False) + offset_len = 1 - 2*self.bound_len + dx = (random.random() * offset_len - offset_len/2) if self.x_offset == None else self.x_offset + dy = (random.random() * offset_len - offset_len/2) if self.y_offset == None else self.y_offset + dx *= self.x_offset_scale + dy *= self.y_offset_scale + grids[..., 0] += dx + grids[..., 1] += dy + + # Compute Fourier features. + x = (grids.unsqueeze(3) @ freqs.permute(0, 2, 1).unsqueeze(1).unsqueeze(2)).squeeze(3) # [batch, height, width, channel] + x = x + phases.unsqueeze(1).unsqueeze(2) + x = torch.sin(x * (np.pi * 2)) + x = x * amplitudes.unsqueeze(1).unsqueeze(2) + + # Apply trainable mapping. + weight = self.weight / np.sqrt(self.channels) + x = x @ weight.t() + + # Ensure correct shape. + x = x.permute(0, 3, 1, 2) # [batch, channel, height, width] + if self.wo_transform: + x = x.repeat(w.shape[0], 1, 1, 1) + misc.assert_shape(x, [w.shape[0], self.channels, int(self.size[1]), int(self.size[0])]) + return x + + def extra_repr(self): + return '\n'.join([ + f'w_dim={self.w_dim:d}, channels={self.channels:d}, size={list(self.size)},', + f'sampling_rate={self.sampling_rate:g}, bandwidth={self.bandwidth:g}']) + +#---------------------------------------------------------------------------- + +class SynthesisLayer(torch.nn.Module): + def __init__(self, + w_dim, # Intermediate latent (W) dimensionality. + is_torgb, # Is this the final ToRGB layer? + is_critically_sampled, # Does this layer use critical sampling? + use_fp16, # Does this layer use FP16? + + # Input & output specifications. + in_channels, # Number of input channels. + out_channels, # Number of output channels. + in_size, # Input spatial size: int or [width, height]. + out_size, # Output spatial size: int or [width, height]. + in_sampling_rate, # Input sampling rate (s). + out_sampling_rate, # Output sampling rate (s). + in_cutoff, # Input cutoff frequency (f_c). + out_cutoff, # Output cutoff frequency (f_c). + in_half_width, # Input transition band half-width (f_h). + out_half_width, # Output Transition band half-width (f_h). + + # Hyperparameters. + conv_kernel = 3, # Convolution kernel size. Ignored for final the ToRGB layer. + filter_size = 6, # Low-pass filter size relative to the lower resolution when up/downsampling. + lrelu_upsampling = 2, # Relative sampling rate for leaky ReLU. Ignored for final the ToRGB layer. + use_radial_filters = False, # Use radially symmetric downsampling filter? Ignored for critically sampled layers. + conv_clamp = 256, # Clamp the output to [-X, +X], None = disable clamping. + magnitude_ema_beta = 0.999, # Decay rate for the moving average of input magnitudes. 
+ label_nc = 0, + use_sel = False, + low_rank = None, + sel_type = 'normal', + **useless_stuff, # absorbed only for compatibility with the StyleGAN2 configuration system; these arguments are ignored + ): + super().__init__() + print('these configuration arguments are not used:', useless_stuff) + self.w_dim = w_dim + self.is_torgb = is_torgb + self.is_critically_sampled = is_critically_sampled + self.use_fp16 = use_fp16 + self.in_channels = in_channels + self.out_channels = out_channels + self.in_size = np.broadcast_to(np.asarray(in_size), [2]) + self.out_size = np.broadcast_to(np.asarray(out_size), [2]) + self.in_sampling_rate = in_sampling_rate + self.out_sampling_rate = out_sampling_rate + self.tmp_sampling_rate = max(in_sampling_rate, out_sampling_rate) * (1 if is_torgb else lrelu_upsampling) + self.in_cutoff = in_cutoff + self.out_cutoff = out_cutoff + self.in_half_width = in_half_width + self.out_half_width = out_half_width + self.conv_kernel = 1 if is_torgb else conv_kernel + self.conv_clamp = conv_clamp + self.magnitude_ema_beta = magnitude_ema_beta + self.use_sel = use_sel and not is_torgb + self.sel_type = sel_type + + size_dict = {36: 36, 52: 72, 84: 144, 148: 288, 276: 576} + self.down_radial = use_radial_filters and not self.is_critically_sampled + if self.use_sel: + if self.sel_type == 'normal': + self.sel = SEL(norm_nc=in_channels, label_nc=label_nc) + elif self.sel_type == 'pro': + sel_down_factor = 576 // size_dict[self.in_size[0]] + sel_pro_down_filter = self.design_lowpass_filter( + numtaps=filter_size * sel_down_factor, cutoff=self.in_cutoff, width=self.in_half_width*2, fs=self.tmp_sampling_rate, radial=self.down_radial) + self.sel = SEL_pro(norm_nc=in_channels, label_nc=label_nc, hidden_nc=128, down_filter=sel_pro_down_filter, + slope=1 if is_torgb else 0.2, gain=1 if is_torgb else np.sqrt(2), clamp=256 if is_torgb else None) + + # Setup parameters and buffers. + self.affine = FullyConnectedLayer(self.w_dim, self.in_channels, bias_init=1, low_rank=low_rank) + self.weight = torch.nn.Parameter(torch.randn([self.out_channels, self.in_channels, self.conv_kernel, self.conv_kernel])) + self.bias = torch.nn.Parameter(torch.zeros([self.out_channels])) + self.register_buffer('magnitude_ema', torch.ones([])) + + # Design upsampling filter. + self.up_factor = int(np.rint(self.tmp_sampling_rate / self.in_sampling_rate)) + assert self.in_sampling_rate * self.up_factor == self.tmp_sampling_rate + self.up_taps = filter_size * self.up_factor if self.up_factor > 1 and not self.is_torgb else 1 + self.register_buffer('up_filter', self.design_lowpass_filter( + numtaps=self.up_taps, cutoff=self.in_cutoff, width=self.in_half_width*2, fs=self.tmp_sampling_rate)) + + # Design downsampling filter. + self.down_factor = int(np.rint(self.tmp_sampling_rate / self.out_sampling_rate)) + assert self.out_sampling_rate * self.down_factor == self.tmp_sampling_rate + self.down_taps = filter_size * self.down_factor if self.down_factor > 1 and not self.is_torgb else 1 + self.register_buffer('down_filter', self.design_lowpass_filter( + numtaps=self.down_taps, cutoff=self.out_cutoff, width=self.out_half_width*2, fs=self.tmp_sampling_rate, radial=self.down_radial)) + + # Compute padding. + pad_total = (self.out_size - 1) * self.down_factor + 1 # Desired output size before downsampling. + pad_total -= (self.in_size + self.conv_kernel - 1) * self.up_factor # Input size after upsampling. + pad_total += self.up_taps + self.down_taps - 2 # Size reduction caused by the filters. 
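+ # pad_total is the extra extent needed on the intermediate (tmp_sampling_rate) grid so that, after
+ # upsampling, convolution and the two FIR filters, downsampling yields exactly out_size samples; it
+ # is split into asymmetric lo/hi margins below.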
+ pad_lo = (pad_total + self.up_factor) // 2 # Shift sample locations according to the symmetric interpretation (Appendix C.3). + pad_hi = pad_total - pad_lo + self.padding = [int(pad_lo[0]), int(pad_hi[0]), int(pad_lo[1]), int(pad_hi[1])] + + def forward(self, x, w, heatmap=None, noise_mode='random', force_fp32=False, update_emas=False): + assert noise_mode in ['random', 'const', 'none'] # unused + misc.assert_shape(x, [None, self.in_channels, int(self.in_size[1]), int(self.in_size[0])]) + misc.assert_shape(w, [x.shape[0], self.w_dim]) + + # Track input magnitude. + if update_emas: + with torch.autograd.profiler.record_function('update_magnitude_ema'): + magnitude_cur = x.detach().to(torch.float32).square().mean() + self.magnitude_ema.copy_(magnitude_cur.lerp(self.magnitude_ema, self.magnitude_ema_beta)) + input_gain = self.magnitude_ema.rsqrt() + + if self.use_sel: + x = self.sel(x, heatmap).to(x.dtype) + + # Execute affine layer. + styles = self.affine(w) + if self.is_torgb: + weight_gain = 1 / np.sqrt(self.in_channels * (self.conv_kernel ** 2)) + styles = styles * weight_gain + + # Execute modulated conv2d. + dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == 'cuda') else torch.float32 + x = modulated_conv2d(x=x.to(dtype), w=self.weight, s=styles, + padding=self.conv_kernel-1, demodulate=(not self.is_torgb), input_gain=input_gain) + + # Execute bias, filtered leaky ReLU, and clamping. + gain = 1 if self.is_torgb else np.sqrt(2) + slope = 1 if self.is_torgb else 0.2 + if self.up_factor == 1 and self.down_factor == 1: + x = bias_act.bias_act(x, self.bias.to(x.dtype), clamp=self.conv_clamp, act='lrelu') + else: + x = filtered_lrelu.filtered_lrelu(x=x, fu=self.up_filter, fd=self.down_filter, b=self.bias.to(x.dtype), + up=self.up_factor, down=self.down_factor, padding=self.padding, gain=gain, slope=slope, clamp=self.conv_clamp) + + # Ensure correct shape and dtype. + misc.assert_shape(x, [None, self.out_channels, int(self.out_size[1]), int(self.out_size[0])]) + assert x.dtype == dtype + return x + + @staticmethod + def design_lowpass_filter(numtaps, cutoff, width, fs, radial=False): + assert numtaps >= 1 + + # Identity filter. + if numtaps == 1: + return None + + # Separable Kaiser low-pass filter. + if not radial: + f = scipy.signal.firwin(numtaps=numtaps, cutoff=cutoff, width=width, fs=fs) + return torch.as_tensor(f, dtype=torch.float32) + + # Radially symmetric jinc-based filter. 
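+ # The radial kernel is a jinc (first-order Bessel J1) low-pass response tapered by a separable Kaiser
+ # window and normalized to unit DC gain; it is selected (via use_radial_filters) for the downsampling
+ # filters of non-critically-sampled layers.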
+ x = (np.arange(numtaps) - (numtaps - 1) / 2) / fs + r = np.hypot(*np.meshgrid(x, x)) + f = scipy.special.j1(2 * cutoff * (np.pi * r)) / (np.pi * r) + beta = scipy.signal.kaiser_beta(scipy.signal.kaiser_atten(numtaps, width / (fs / 2))) + w = np.kaiser(numtaps, beta) + f *= np.outer(w, w) + f /= np.sum(f) + return torch.as_tensor(f, dtype=torch.float32) + + def extra_repr(self): + return '\n'.join([ + f'w_dim={self.w_dim:d}, is_torgb={self.is_torgb},', + f'is_critically_sampled={self.is_critically_sampled}, use_fp16={self.use_fp16},', + f'in_sampling_rate={self.in_sampling_rate:g}, out_sampling_rate={self.out_sampling_rate:g},', + f'in_cutoff={self.in_cutoff:g}, out_cutoff={self.out_cutoff:g},', + f'in_half_width={self.in_half_width:g}, out_half_width={self.out_half_width:g},', + f'in_size={list(self.in_size)}, out_size={list(self.out_size)},', + f'in_channels={self.in_channels:d}, out_channels={self.out_channels:d}']) + +#---------------------------------------------------------------------------- + +class SynthesisNetwork(torch.nn.Module): + def __init__(self, + w_dim, # Intermediate latent (W) dimensionality. + img_resolution, # Output image resolution. + img_channels, # Number of color channels. + channel_base = 32768, # Overall multiplier for the number of channels. + channel_max = 512, # Maximum number of channels in any layer. + num_layers = 14, # Total number of layers, excluding Fourier features and ToRGB. + num_critical = 2, # Number of critically sampled layers at the end. + first_cutoff = 2, # Cutoff frequency of the first layer (f_{c,0}). + first_stopband = 2**2.1, # Minimum stopband of the first layer (f_{t,0}). + last_stopband_rel = 2**0.3, # Minimum stopband of the last layer, expressed relative to the cutoff. + margin_size = 10, # Number of additional pixels outside the image. + output_scale = 0.25, # Scale factor for the output image. + num_fp16_res = 4, # Use FP16 for the N highest resolutions. + label_nc = 0, + use_sel = False, + sel_type = 'pro', + low_rank = None, + offset_scale = '0,0', + bound_len = 0.5, + wo_transform = False, + **layer_kwargs, # Arguments for SynthesisLayer. + ): + super().__init__() + self.w_dim = w_dim + self.num_ws = num_layers + 2 + self.img_resolution = img_resolution + self.img_channels = img_channels + self.num_layers = num_layers + self.num_critical = num_critical + self.margin_size = margin_size + self.output_scale = output_scale + self.num_fp16_res = num_fp16_res + + self.use_sel = use_sel + self.sel_type = sel_type + + # Geometric progression of layer cutoffs and min. stopbands. + last_cutoff = self.img_resolution / 2 # f_{c,N} + last_stopband = last_cutoff * last_stopband_rel # f_{t,N} + exponents = np.minimum(np.arange(self.num_layers + 1) / (self.num_layers - self.num_critical), 1) + cutoffs = first_cutoff * (last_cutoff / first_cutoff) ** exponents # f_c[i] + stopbands = first_stopband * (last_stopband / first_stopband) ** exponents # f_t[i] + + # Compute remaining layer parameters. + sampling_rates = np.exp2(np.ceil(np.log2(np.minimum(stopbands * 2, self.img_resolution)))) # s[i] + half_widths = np.maximum(stopbands, sampling_rates / 2) - cutoffs # f_h[i] + sizes = sampling_rates + self.margin_size * 2 + sizes[-2:] = self.img_resolution + channels = np.rint(np.minimum((channel_base / 2) / cutoffs, channel_max)) + channels[-1] = self.img_channels + + # Construct layers. 
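+ # One SynthesisInput plus num_layers + 1 SynthesisLayer blocks (the last one acting as ToRGB) are
+ # created here, which is why num_ws = num_layers + 2; each block is registered under the name
+ # f'L{idx}_{out_size}_{out_channels}' and looked up by that name in forward().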
+ self.input = SynthesisInput( + w_dim=self.w_dim, channels=int(channels[0]), size=int(sizes[0]), + sampling_rate=sampling_rates[0], bandwidth=cutoffs[0], offset_scale=offset_scale, bound_len=bound_len, wo_transform=wo_transform) + self.layer_names = [] + for idx in range(self.num_layers + 1): + prev = max(idx - 1, 0) + is_torgb = (idx == self.num_layers) + is_critically_sampled = (idx >= self.num_layers - self.num_critical) + use_fp16 = (sampling_rates[idx] * (2 ** self.num_fp16_res) > self.img_resolution) + layer = SynthesisLayer( + w_dim=self.w_dim, is_torgb=is_torgb, is_critically_sampled=is_critically_sampled, use_fp16=use_fp16, + in_channels=int(channels[prev]), out_channels= int(channels[idx]), + in_size=int(sizes[prev]), out_size=int(sizes[idx]), + in_sampling_rate=int(sampling_rates[prev]), out_sampling_rate=int(sampling_rates[idx]), + in_cutoff=cutoffs[prev], out_cutoff=cutoffs[idx], + in_half_width=half_widths[prev], out_half_width=half_widths[idx], + use_sel=use_sel, sel_type=sel_type, label_nc=label_nc, low_rank=low_rank, + **layer_kwargs) + name = f'L{idx}_{layer.out_size[0]}_{layer.out_channels}' + setattr(self, name, layer) + self.layer_names.append(name) + + def forward(self, ws, heatmap=None, **layer_kwargs): + misc.assert_shape(ws, [None, self.num_ws, self.w_dim]) + ws = ws.to(torch.float32).unbind(dim=1) + + # Execute layers. + x = self.input(ws[0]) + for name, w in zip(self.layer_names, ws[1:]): + x = getattr(self, name)(x, w, heatmap, **layer_kwargs) + if self.output_scale != 1: + x = x * self.output_scale + + # Ensure correct shape and dtype. + misc.assert_shape(x, [None, self.img_channels, self.img_resolution, self.img_resolution]) + x = x.to(torch.float32) + return x + + def extra_repr(self): + return '\n'.join([ + f'w_dim={self.w_dim:d}, num_ws={self.num_ws:d},', + f'img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d},', + f'num_layers={self.num_layers:d}, num_critical={self.num_critical:d},', + f'margin_size={self.margin_size:d}, num_fp16_res={self.num_fp16_res:d}']) + +#---------------------------------------------------------------------------- + +class Generator(torch.nn.Module): + def __init__(self, + z_dim, # Input latent (Z) dimensionality. + c_dim, # Conditioning label (C) dimensionality. + w_dim, # Intermediate latent (W) dimensionality. + img_resolution, # Output resolution. + img_channels, # Number of output color channels. + label_nc = None, + use_sel = False, + sel_type = False, + low_rank = None, + offset_scale = '0,0', + bound_len = 0.5, + wo_transform = False, + mapping_kwargs = {}, # Arguments for MappingNetwork. + **synthesis_kwargs, # Arguments for SynthesisNetwork. 
+ ): + super().__init__() + self.z_dim = z_dim + self.c_dim = c_dim + self.w_dim = w_dim + self.img_resolution = img_resolution + self.img_channels = img_channels + self.synthesis = SynthesisNetwork(w_dim=w_dim, img_resolution=img_resolution, img_channels=img_channels, label_nc=label_nc, use_sel=use_sel, sel_type=sel_type, low_rank=low_rank, offset_scale=offset_scale, bound_len=bound_len, wo_transform=wo_transform, **synthesis_kwargs) + self.num_ws = self.synthesis.num_ws + self.mapping = MappingNetwork(z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_ws=self.num_ws, **mapping_kwargs) + + def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False, heatmap=None, **synthesis_kwargs): + ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas) + img = self.synthesis(ws, update_emas=update_emas, heatmap=heatmap, **synthesis_kwargs) + return img + +#---------------------------------------------------------------------------- diff --git a/models/utils/ops.py b/models/utils/ops.py new file mode 100644 index 0000000000000000000000000000000000000000..65cbb8945a7d90b81c8d4c05d84404e23b8d3a9c --- /dev/null +++ b/models/utils/ops.py @@ -0,0 +1,56 @@ +# python3.8 +"""Contains some utility operators.""" + +import math +import torch +import torch.distributed as dist +import torch.nn.functional as F + +__all__ = [ + 'all_gather', + 'upsample', + 'downsample', +] + + +def all_gather(tensor): + """Gathers tensor from all devices and executes averaging.""" + if not dist.is_initialized(): + return tensor + + world_size = dist.get_world_size() + tensor_list = [torch.ones_like(tensor) for _ in range(world_size)] + dist.all_gather(tensor_list, tensor, async_op=False) + return torch.stack(tensor_list, dim=0).mean(dim=0) + + +def upsample(img_nerf, size, filter=None): + up = size // img_nerf.size(-1) + if up <= 1: + return img_nerf + + if filter is not None: + from third_party.stylegan2_official_ops import upfirdn2d + for _ in range(int(math.log2(up))): + img_nerf = upfirdn2d.upsample2d(img_nerf, filter, up=2) + else: + img_nerf = F.interpolate(img_nerf, (size, size), + mode='bilinear', + align_corners=False) + return img_nerf + + +def downsample(img0, size, filter=None): + down = img0.size(-1) // size + if down <= 1: + return img0 + + if filter is not None: + from third_party.stylegan2_official_ops import upfirdn2d + for _ in range(int(math.log2(down))): + img0 = upfirdn2d.downsample2d(img0, filter, down=2) + else: + img0 = F.interpolate(img0, (size, size), + mode='bilinear', + align_corners=False) + return img0 \ No newline at end of file diff --git a/models/utils/replicate.py b/models/utils/replicate.py new file mode 100644 index 0000000000000000000000000000000000000000..b71c7b8ed51a1d6c55b1f753bdd8d90bad79bd06 --- /dev/null +++ b/models/utils/replicate.py @@ -0,0 +1,94 @@ +# -*- coding: utf-8 -*- +# File : replicate.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. 
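+# Vendored helper: plain nn.DataParallel never invokes __data_parallel_replicate__, so replicas of
+# SynchronizedBatchNorm layers cannot share a common context; DataParallelWithCallback and
+# patch_replication_callback below add that hook.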
+ +import functools + +from torch.nn.parallel.data_parallel import DataParallel + +__all__ = [ + 'CallbackContext', + 'execute_replication_callbacks', + 'DataParallelWithCallback', + 'patch_replication_callback' +] + + +class CallbackContext(object): + pass + + +def execute_replication_callbacks(modules): + """ + Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. + + The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` + + Note that, as all modules are isomorphism, we assign each sub-module with a context + (shared among multiple copies of this module on different devices). + Through this context, different copies can share some information. + + We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback + of any slave copies. + """ + master_copy = modules[0] + nr_modules = len(list(master_copy.modules())) + ctxs = [CallbackContext() for _ in range(nr_modules)] + + for i, module in enumerate(modules): + for j, m in enumerate(module.modules()): + if hasattr(m, '__data_parallel_replicate__'): + m.__data_parallel_replicate__(ctxs[j], i) + + +class DataParallelWithCallback(DataParallel): + """ + Data Parallel with a replication callback. + + An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by + original `replicate` function. + The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` + + Examples: + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + # sync_bn.__data_parallel_replicate__ will be invoked. + """ + + def replicate(self, module, device_ids): + modules = super(DataParallelWithCallback, self).replicate(module, device_ids) + execute_replication_callbacks(modules) + return modules + + +def patch_replication_callback(data_parallel): + """ + Monkey-patch an existing `DataParallel` object. Add the replication callback. + Useful when you have customized `DataParallel` implementation. 
+ + Examples: + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) + > patch_replication_callback(sync_bn) + # this is equivalent to + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + """ + + assert isinstance(data_parallel, DataParallel) + + old_replicate = data_parallel.replicate + + @functools.wraps(old_replicate) + def new_replicate(module, device_ids): + modules = old_replicate(module, device_ids) + execute_replication_callbacks(modules) + return modules + + data_parallel.replicate = new_replicate diff --git a/models/utils/spade.py b/models/utils/spade.py new file mode 100644 index 0000000000000000000000000000000000000000..38f9ed5ab35d6ec75ba2fd1e19acea16794c0c70 --- /dev/null +++ b/models/utils/spade.py @@ -0,0 +1,276 @@ +import re +import torch +import torch.nn as nn +from torch.nn import init +import torch.nn.functional as F +import torch.nn.utils.spectral_norm as spectral_norm + +from models.utils.batchnorm import SynchronizedBatchNorm2d + +class SPADE(nn.Module): + def __init__(self, config_text, norm_nc, label_nc): + super().__init__() + + assert config_text.startswith('spade') + parsed = re.search('spade(\D+)(\d)x\d', config_text) + param_free_norm_type = str(parsed.group(1)) + ks = int(parsed.group(2)) + + if param_free_norm_type == 'instance': + self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False) + elif param_free_norm_type == 'syncbatch': + self.param_free_norm = SynchronizedBatchNorm2d(norm_nc, affine=False) + elif param_free_norm_type == 'batch': + self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False) + else: + raise ValueError('%s is not a recognized param-free norm type in SPADE' + % param_free_norm_type) + + # The dimension of the intermediate embedding space. Yes, hardcoded. + nhidden = 128 + + pw = ks // 2 + self.mlp_shared = nn.Sequential( + nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw), + nn.ReLU() + ) + self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw) + self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw) + + def forward(self, x, segmap): + + # Part 1. generate parameter-free normalized activations + normalized = self.param_free_norm(x) + + # Part 2. 
produce scaling and bias conditioned on semantic map + segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest') + actv = self.mlp_shared(segmap) + gamma = self.mlp_gamma(actv) + beta = self.mlp_beta(actv) + + # apply scale and bias + out = normalized * (1 + gamma) + beta + + return out + +class SPADEResnetBlock(nn.Module): + def __init__(self, fin, fout, norm_G, semantic_nc): + super().__init__() + # Attributes + self.learned_shortcut = (fin != fout) + fmiddle = min(fin, fout) + + # create conv layers + self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1) + self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1) + if self.learned_shortcut: + self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False) + + # apply spectral norm if specified + if 'spectral' in norm_G: + self.conv_0 = spectral_norm(self.conv_0) + self.conv_1 = spectral_norm(self.conv_1) + if self.learned_shortcut: + self.conv_s = spectral_norm(self.conv_s) + + # define normalization layers + spade_config_str = norm_G.replace('spectral', '') + self.norm_0 = SPADE(spade_config_str, fin, semantic_nc) + self.norm_1 = SPADE(spade_config_str, fmiddle, semantic_nc) + if self.learned_shortcut: + self.norm_s = SPADE(spade_config_str, fin, semantic_nc) + + # note the resnet block with SPADE also takes in |seg|, + # the semantic segmentation map as input + def forward(self, x, seg): + x_s = self.shortcut(x, seg) + + dx = self.conv_0(self.actvn(self.norm_0(x, seg))) + dx = self.conv_1(self.actvn(self.norm_1(dx, seg))) + + out = x_s + dx + + return out + + def shortcut(self, x, seg): + if self.learned_shortcut: + x_s = self.conv_s(self.norm_s(x, seg)) + else: + x_s = x + return x_s + + def actvn(self, x): + return F.leaky_relu(x, 2e-1) + +class BaseNetwork(nn.Module): + def __init__(self): + super(BaseNetwork, self).__init__() + + @staticmethod + def modify_commandline_options(parser, is_train): + return parser + + def print_network(self): + if isinstance(self, list): + self = self[0] + num_params = 0 + for param in self.parameters(): + num_params += param.numel() + print('Network [%s] was created. Total number of parameters: %.1f million. ' + 'To see the architecture, do print(network).' 
+ % (type(self).__name__, num_params / 1000000)) + + def init_weights(self, init_type='normal', gain=0.02): + def init_func(m): + classname = m.__class__.__name__ + if classname.find('BatchNorm2d') != -1: + if hasattr(m, 'weight') and m.weight is not None: + init.normal_(m.weight.data, 1.0, gain) + if hasattr(m, 'bias') and m.bias is not None: + init.constant_(m.bias.data, 0.0) + elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): + if init_type == 'normal': + init.normal_(m.weight.data, 0.0, gain) + elif init_type == 'xavier': + init.xavier_normal_(m.weight.data, gain=gain) + elif init_type == 'xavier_uniform': + init.xavier_uniform_(m.weight.data, gain=1.0) + elif init_type == 'kaiming': + init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') + elif init_type == 'orthogonal': + init.orthogonal_(m.weight.data, gain=gain) + elif init_type == 'none': # uses pytorch's default init method + m.reset_parameters() + else: + raise NotImplementedError('initialization method [%s] is not implemented' % init_type) + if hasattr(m, 'bias') and m.bias is not None: + init.constant_(m.bias.data, 0.0) + + self.apply(init_func) + + # propagate to children + for m in self.children(): + if hasattr(m, 'init_weights'): + m.init_weights(init_type, gain) + +class SPADEGenerator(BaseNetwork): + def __init__(self, z_dim, semantic_nc, ngf, dim_seq, bev_grid_size, aspect_ratio, + num_upsampling_layers, not_use_vae, norm_G): + super().__init__() + nf = ngf + self.not_use_vae = not_use_vae + self.z_dim = z_dim + self.ngf = ngf + self.dim_seq = list(map(int, dim_seq.split(','))) + self.num_upsampling_layers = num_upsampling_layers + + self.sw, self.sh = self.compute_latent_vector_size(num_upsampling_layers, bev_grid_size, aspect_ratio) + + if not not_use_vae: + # In case of VAE, we will sample from random z vector + self.fc = nn.Linear(z_dim, self.dim_seq[0] * nf * self.sw * self.sh) + else: + # Otherwise, we make the network deterministic by starting with + # downsampled segmentation map instead of random z + self.fc = nn.Conv2d(semantic_nc, self.dim_seq[0] * nf, 3, padding=1) + + self.head_0 = SPADEResnetBlock(self.dim_seq[0] * nf, self.dim_seq[0] * nf, norm_G, semantic_nc) + + self.G_middle_0 = SPADEResnetBlock(self.dim_seq[0] * nf, self.dim_seq[0] * nf, norm_G, semantic_nc) + self.G_middle_1 = SPADEResnetBlock(self.dim_seq[0] * nf, self.dim_seq[0] * nf, norm_G, semantic_nc) + + self.up_0 = SPADEResnetBlock(self.dim_seq[0] * nf, self.dim_seq[1] * nf, norm_G, semantic_nc) + self.up_1 = SPADEResnetBlock(self.dim_seq[1] * nf, self.dim_seq[2] * nf, norm_G, semantic_nc) + self.up_2 = SPADEResnetBlock(self.dim_seq[2] * nf, self.dim_seq[3] * nf, norm_G, semantic_nc) + self.up_3 = SPADEResnetBlock(self.dim_seq[3] * nf, self.dim_seq[4] * nf, norm_G, semantic_nc) + + final_nc = nf * self.dim_seq[4] + + if num_upsampling_layers == 'most': + self.up_4 = SPADEResnetBlock(self.dim_seq[4] * nf, nf // 2, norm_G, semantic_nc) + final_nc = nf // 2 + + self.conv_img = nn.Conv2d(final_nc, 32, 3, padding=1) + # self.conv_img = nn.Conv2d(final_nc, 3, 3, padding=1) + + self.up = nn.Upsample(scale_factor=2) + + def compute_latent_vector_size(self, num_upsampling_layers, bev_grid_size, aspect_ratio): + if num_upsampling_layers == 'normal': + num_up_layers = 5 + elif num_upsampling_layers == 'more': + num_up_layers = 6 + elif num_upsampling_layers == 'most': + num_up_layers = 7 + else: + raise ValueError('num_upsampling_layers [%s] not recognized' % + num_upsampling_layers) + + sw = bev_grid_size 
// (2**num_up_layers) + sh = round(sw / aspect_ratio) + + return sw, sh + + def forward(self, input, z=None): + seg = input + + if not self.not_use_vae: + # we sample z from unit normal and reshape the tensor + if z is None: + z = torch.randn(input.size(0), self.z_dim, + dtype=torch.float32, device=input.get_device()) + x = self.fc(z) + x = x.view(-1, self.dim_seq[0] * self.ngf, self.sh, self.sw) + else: + # we downsample segmap and run convolution + x = F.interpolate(seg, size=(self.sh, self.sw)) + x = self.fc(x) + + x = self.head_0(x, seg) + + x = self.up(x) + x = self.G_middle_0(x, seg) + + if self.num_upsampling_layers == 'more' or \ + self.num_upsampling_layers == 'most': + x = self.up(x) + + x = self.G_middle_1(x, seg) + + x = self.up(x) + x = self.up_0(x, seg) + x = self.up(x) + x = self.up_1(x, seg) + x = self.up(x) + x = self.up_2(x, seg) + x = self.up(x) + x = self.up_3(x, seg) + + if self.num_upsampling_layers == 'most': + x = self.up(x) + x = self.up_4(x, seg) + + # TODO: Wtf is this leaky relu + x = self.conv_img(F.leaky_relu(x, 2e-1)) + # x = torch.tanh(x) + + return x + +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('--z_dim', type=int, default=10) + parser.add_argument('--semantic_nc', type=int, default=10) + parser.add_argument('--ngf', type=int, default=64) + parser.add_argument('--bev_grid_size', type=int, default=512) + parser.add_argument('--aspect_ratio', type=float, default=1.0) + parser.add_argument('--num_upsampling_layers', type=str, default='more') + parser.add_argument('--not_use_vae', action="store_true") + parser.add_argument('--norm_G', type=str, default='spectralspadesyncbatch3x3', help='instance normalization or batch normalization') + + args = parser.parse_args() + sg = SPADEGenerator(args).cuda() + seg = torch.zeros([2, 10, 5, 5]).cuda() + while 1: + import pdb;pdb.set_trace() + out = sg(seg) diff --git a/models/utils/unet.py b/models/utils/unet.py new file mode 100644 index 0000000000000000000000000000000000000000..97048b77d6663f8909f0cc3a0a94211df04256e7 --- /dev/null +++ b/models/utils/unet.py @@ -0,0 +1,175 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import scipy.signal +from .blurpool import BlurPool +from .official_stylegan3_model_helper import SEL, SEL_unet_pro, MappingNetwork, FullyConnectedLayer, modulated_conv2d, SynthesisInput +from third_party.stylegan3_official_ops import filtered_lrelu +from third_party.stylegan3_official_ops import upfirdn2d +from third_party.stylegan3_official_ops import bias_act + +class UNetBlock(nn.Module): + def __init__(self, w_dim, in_channel, latent_channel, out_channel, ks=3, layer_num=2): + super().__init__() + self.ks = ks + + self.layer_num = layer_num + self.weight1 = nn.Parameter(torch.randn([latent_channel, in_channel, ks, ks])) + self.weight2 = nn.Parameter(torch.randn([out_channel, latent_channel, ks, ks])) + self.bias1 = nn.Parameter(torch.zeros([latent_channel])) + self.bias2 = nn.Parameter(torch.zeros([out_channel])) + self.affine1 = FullyConnectedLayer(w_dim, in_channel, bias_init=1) + self.affine2 = FullyConnectedLayer(w_dim, latent_channel, bias_init=1) + + if self.layer_num == 3: + self.weight_mid = nn.Parameter(torch.randn([latent_channel, latent_channel, ks, ks])) + self.bias_mid = nn.Parameter(torch.zeros([latent_channel])) + self.affine_mid = FullyConnectedLayer(w_dim, latent_channel, bias_init=1) + + def forward(self, x, *w): + + s1 = self.affine1(w[0]) + if self.layer_num == 3: + s_mid = 
self.affine_mid(w[1]) + s2 = self.affine2(w[2]) + else: + s2 = self.affine2(w[1]) + + x = modulated_conv2d(x, w=self.weight1, s=s1, padding=self.ks//2) + x = bias_act.bias_act(x, self.bias1.to(x.dtype), act='lrelu') + if self.layer_num == 3: + x = modulated_conv2d(x, w=self.weight_mid, s=s_mid, padding=self.ks//2) + x = bias_act.bias_act(x, self.bias_mid.to(x.dtype), act='lrelu') + x = modulated_conv2d(x, w=self.weight2, s=s2, padding=self.ks//2) + x = bias_act.bias_act(x, self.bias2.to(x.dtype), act='lrelu') + + return x + + +class UNet(nn.Module): + def __init__(self, w_dim, in_dim=3, base_dim=64, ks=3, block_num=3, layer_num=2, filt_size=3, output_dim=3, label_nc=14, sel_type='normal', img_resolution=256, wo_transform = False,): + super().__init__() + + self.block_num = block_num + self.layer_num = layer_num + + self.sel_type = sel_type + if self.sel_type == 'normal': + self.pool = nn.AvgPool2d(kernel_size=2, stride=2, padding=0) + else: + for i in range(block_num): + self.register_buffer(f'down_filter_{i}', self.design_lowpass_filter(numtaps=12, cutoff=2**((block_num-i+1)/2), width=None, fs=img_resolution//(2**i))) + self.register_buffer(f'sel_down_filter_{i}', self.design_lowpass_filter(numtaps=6*2**i, cutoff=2**((block_num-i+2)/2), width=None, fs=img_resolution//(2**(i-1)))) + self.input = SynthesisInput(w_dim=w_dim, channels=in_dim, size=img_resolution, sampling_rate=img_resolution, bound_len=0, bandwidth=4, wo_transform=wo_transform) # what is the bandwidth + + encoder_list, sel_enc_list, sel_dec_list, decoder_list, bp_list = [], [], [], [], [] + + for i in range(block_num): + if i == 0: + encoder_list.append(UNetBlock(w_dim, in_dim, base_dim, base_dim, layer_num=layer_num)) + else: + encoder_list.append(UNetBlock(w_dim, base_dim * 2 ** (i-1), base_dim * 2 ** i, base_dim * 2 ** i, layer_num=layer_num)) + + decoder_list.append(UNetBlock(w_dim, + base_dim * 2 ** (block_num-i), + base_dim * 2 ** (block_num-i-1), + base_dim * 2 ** (block_num-i-2) if i < block_num-1 else base_dim * 2 ** (block_num-i-1), + layer_num=layer_num + )) + + if self.sel_type == 'normal': + sel_enc_list.append(SEL(in_dim if i==0 else base_dim * 2 ** (i-1), label_nc)) + sel_dec_list.append(SEL(base_dim * 2 ** (block_num-i-1), label_nc)) + else: + sel_enc_list.append(SEL_unet_pro(in_dim if i==0 else base_dim * 2 ** (i-1), label_nc, down_filter=getattr(self, f'sel_down_filter_{i}'))) + sel_dec_list.append(SEL_unet_pro(base_dim * 2 ** (block_num-i-1), label_nc, down_filter=getattr(self, f'sel_down_filter_{block_num-i-1}'))) + + self.encoders = nn.ModuleList(encoder_list) + self.decoders = nn.ModuleList(decoder_list) + self.enc_sels = nn.ModuleList(sel_enc_list) + self.dec_sels = nn.ModuleList(sel_dec_list) + + self.torgb = UNetBlock(w_dim, base_dim, base_dim, output_dim) + + @staticmethod + def design_lowpass_filter(numtaps, cutoff, fs, width=None): + if numtaps == 1: + return None + f = scipy.signal.firwin(numtaps=numtaps, cutoff=cutoff, width=width, fs=fs) + return torch.as_tensor(f, dtype=torch.float32) + + def forward(self, ws, heatmap, **kwargs): + ws = ws.unbind(1) + x = self.input(ws[0]) + ws = ws[1:] + + enc_x = [] + for i in range(self.block_num): + # modulate with SEL + x = self.enc_sels[i] (x, heatmap) + + if self.layer_num==2: + x = self.encoders[i] (x, ws[2*i], ws[2*i+1]) + else: + x = self.encoders[i] (x, ws[3*i], ws[3*i+1], ws[3*i+2]) + + enc_x.append(x) + if self.sel_type == 'normal': + x = self.pool(x) + else: + x = upfirdn2d.upfirdn2d(x=x, f=getattr(self, f'down_filter_{i}'), down=2, 
flip_filter=False, padding=5) + + ws = ws[self.layer_num*self.block_num: ] + for i in range(self.block_num): + x = F.interpolate(x, size=x.shape[-1] * 2, mode='bilinear', align_corners=False) + # modulate with SEL + x = self.dec_sels[i] (x, heatmap) + if self.layer_num==2: + x = self.decoders[i] (torch.cat([x, enc_x[-1-i]], 1), ws[2*i], ws[2*i+1]) + else: + x = self.decoders[i] (torch.cat([x, enc_x[-1-i]], 1), ws[3*i], ws[3*i+1], ws[3*i+2]) + + ws = ws[self.layer_num*self.block_num: ] + x = self.torgb(x, ws[0], ws[1]) + return x + +class Generator(nn.Module): + def __init__(self, z_dim, c_dim, w_dim, img_resolution=256, img_channels=3, + in_dim=3, base_dim=64, ks=3, block_num=3, layer_num=2, filt_size=3, output_dim=3, label_nc=14, sel_type='normal', wo_transform=False, **kwargs): + super().__init__() + self.mapping = MappingNetwork(z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_ws=2*layer_num*block_num+3) + self.synthesis = UNet(w_dim=w_dim, in_dim=in_dim, base_dim=64, ks=3, block_num=block_num, layer_num=layer_num, filt_size=3, output_dim=img_channels, label_nc=label_nc, sel_type=sel_type, img_resolution=img_resolution, wo_transform=wo_transform) + + def forward(self, z, c, heatmap, truncation_psi=1, truncation_cutoff=None, update_emas=False): + ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas) + + ret = self.synthesis(ws, heatmap=heatmap) + + return ret + +# class SELUNet(UNet): +# def forward(self, x, hm): + + +if __name__ == '__main__': + # g = Generator(z_dim=64, c_dim=0, w_dim=512, block_num=4, img_resolution=256, img_channels=32, sel_type='abn') + # hm = torch.ones([10, 14, 256, 256]) + # z = torch.zeros([10, 64]) + # c = None + # opt = g(z, c, hm) + g = Generator(z_dim=64, c_dim=0, w_dim=512, block_num=4,layer_num=3, img_resolution=512, img_channels=32, sel_type='abn') + hm = torch.ones([10, 14, 512, 512]) + z = torch.zeros([10, 64]) + c = None + opt = g(z, c, hm) + g = Generator(z_dim=64, c_dim=0, w_dim=512, block_num=4,layer_num=3, img_resolution=256, img_channels=32, sel_type='abn') + hm = torch.ones([10, 14,256,256]) + z = torch.zeros([10, 64]) + c = None + opt = g(z, c, hm) + + + + + + diff --git a/models/volumegan_discriminator.py b/models/volumegan_discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..a2245aea0b530043801cd6fa55be541513539551 --- /dev/null +++ b/models/volumegan_discriminator.py @@ -0,0 +1,728 @@ +# python3.7 +"""Contains the implementation of discriminator described in VolumeGAN. +Paper: https://arxiv.org/pdf/2112.10759.pdf +""" + +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from third_party.stylegan2_official_ops import bias_act +from third_party.stylegan2_official_ops import upfirdn2d +from third_party.stylegan2_official_ops import conv2d_gradfix + +__all__ = ['VolumeGANDiscriminator'] + +# Resolutions allowed. +_RESOLUTIONS_ALLOWED = [8, 16, 32, 64, 128, 256, 512, 1024] + +# Architectures allowed. +_ARCHITECTURES_ALLOWED = ['resnet', 'skip', 'origin'] + + +class VolumeGANDiscriminator(nn.Module): + """Defines the discriminator network in VolumeGAN. + + NOTE: The discriminator takes images with `RGB` channel order and pixel + range [-1, 1] as inputs. + + Settings for the backbone: + + (1) resolution: The resolution of the input image. (default: -1) + (2) init_res: The initial resolution to start with convolution. (default: 4) + (3) image_channels: Number of channels of the input image. 
(default: 3) + (4) architecture: Type of architecture. Support `origin`, `skip`, and + `resnet`. (default: `resnet`) + (5) use_wscale: Whether to use weight scaling. (default: True) + (6) wscale_gain: The factor to control weight scaling. (default: 1.0) + (7) lr_mul: Learning rate multiplier for backbone. (default: 1.0) + (8) mbstd_groups: Group size for the minibatch standard deviation layer. + `0` means disable. (default: 4) + (9) mbstd_channels: Number of new channels (appended to the original feature + map) after the minibatch standard deviation layer. (default: 1) + (10) fmaps_base: Factor to control number of feature maps for each layer. + (default: 32 << 10) + (11) fmaps_max: Maximum number of feature maps in each layer. (default: 512) + (12) filter_kernel: Kernel used for filtering (e.g., downsampling). + (default: (1, 3, 3, 1)) + (13) conv_clamp: A threshold to clamp the output of convolution layers to + avoid overflow under FP16 training. (default: None) + (14) eps: A small value to avoid divide overflow. (default: 1e-8) + + Settings for conditional model: + + (1) label_dim: Dimension of the additional label for conditional generation. + In one-hot conditioning case, it is equal to the number of classes. If + set to 0, conditioning training will be disabled. (default: 0) + (2) embedding_dim: Dimension of the embedding space, if needed. + (default: 512) + (3) embedding_bias: Whether to add bias to embedding learning. + (default: True) + (4) embedding_use_wscale: Whether to use weight scaling for embedding + learning. (default: True) + (5) embedding_lr_mul: Learning rate multiplier for the embedding learning. + (default: 1.0) + (6) normalize_embedding: Whether to normalize the embedding. (default: True) + (7) mapping_layers: Number of layers of the additional mapping network after + embedding. (default: 0) + (8) mapping_fmaps: Number of hidden channels of the additional mapping + network after embedding. (default: 512) + (9) mapping_use_wscale: Whether to use weight scaling for the additional + mapping network. (default: True) + (10) mapping_lr_mul: Learning rate multiplier for the additional mapping + network after embedding. (default: 0.1) + + Runtime settings: + + (1) fp16_res: Layers at resolution higher than (or equal to) this field will + use `float16` precision for computation. This is merely used for + acceleration. If set as `None`, all layers will use `float32` by + default. (default: None) + (2) impl: Implementation mode of some particular ops, e.g., `filtering`, + `bias_act`, etc. `cuda` means using the official CUDA implementation + from StyleGAN2, while `ref` means using the native PyTorch ops. + (default: `cuda`) + """ + + def __init__(self, + # Settings for backbone. + resolution=-1, + init_res=4, + image_channels=3, + architecture='resnet', + use_wscale=True, + wscale_gain=1.0, + lr_mul=1.0, + mbstd_groups=4, + mbstd_channels=1, + fmaps_base=32 << 10, + fmaps_max=512, + filter_kernel=(1, 3, 3, 1), + conv_clamp=None, + eps=1e-8, + # Settings for conditional model. + label_dim=0, + embedding_dim=512, + embedding_bias=True, + embedding_use_wscale=True, + embedding_lr_mul=1.0, + normalize_embedding=True, + mapping_layers=0, + mapping_fmaps=512, + mapping_use_wscale=True, + mapping_lr_mul=0.1): + """Initializes with basic settings. + + Raises: + ValueError: If the `resolution` is not supported, or `architecture` + is not supported. 
+ """ + super().__init__() + + if resolution not in _RESOLUTIONS_ALLOWED: + raise ValueError(f'Invalid resolution: `{resolution}`!\n' + f'Resolutions allowed: {_RESOLUTIONS_ALLOWED}.') + architecture = architecture.lower() + if architecture not in _ARCHITECTURES_ALLOWED: + raise ValueError(f'Invalid architecture: `{architecture}`!\n' + f'Architectures allowed: ' + f'{_ARCHITECTURES_ALLOWED}.') + + self.init_res = init_res + self.init_res_log2 = int(np.log2(init_res)) + self.resolution = resolution + self.final_res_log2 = int(np.log2(resolution)) + self.image_channels = image_channels + self.architecture = architecture + self.use_wscale = use_wscale + self.wscale_gain = wscale_gain + self.lr_mul = lr_mul + self.mbstd_groups = mbstd_groups + self.mbstd_channels = mbstd_channels + self.fmaps_base = fmaps_base + self.fmaps_max = fmaps_max + self.filter_kernel = filter_kernel + self.conv_clamp = conv_clamp + self.eps = eps + + self.label_dim = label_dim + self.embedding_dim = embedding_dim + self.embedding_bias = embedding_bias + self.embedding_use_wscale = embedding_use_wscale + self.embedding_lr_mul = embedding_lr_mul + self.normalize_embedding = normalize_embedding + self.mapping_layers = mapping_layers + self.mapping_fmaps = mapping_fmaps + self.mapping_use_wscale = mapping_use_wscale + self.mapping_lr_mul = mapping_lr_mul + + self.pth_to_tf_var_mapping = {} + self.register_buffer('lod', torch.zeros(())) + # Embedding for conditional discrimination. + self.use_embedding = label_dim > 0 and embedding_dim > 0 + if self.use_embedding: + self.embedding = DenseLayer(in_channels=label_dim, + out_channels=embedding_dim, + add_bias=embedding_bias, + init_bias=0.0, + use_wscale=embedding_use_wscale, + wscale_gain=wscale_gain, + lr_mul=embedding_lr_mul, + activation_type='linear') + self.pth_to_tf_var_mapping['embedding.weight'] = 'LabelEmbed/weight' + if self.embedding_bias: + self.pth_to_tf_var_mapping['embedding.bias'] = 'LabelEmbed/bias' + + if self.normalize_embedding: + self.norm = PixelNormLayer(dim=1, eps=eps) + + for i in range(mapping_layers): + in_channels = (embedding_dim if i == 0 else mapping_fmaps) + out_channels = (embedding_dim if i == (mapping_layers - 1) else + mapping_fmaps) + layer_name = f'mapping{i}' + self.add_module(layer_name, + DenseLayer(in_channels=in_channels, + out_channels=out_channels, + add_bias=True, + init_bias=0.0, + use_wscale=mapping_use_wscale, + wscale_gain=wscale_gain, + lr_mul=mapping_lr_mul, + activation_type='lrelu')) + self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = ( + f'Mapping{i}/weight') + self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = ( + f'Mapping{i}/bias') + + # Convolutional backbone. + for res_log2 in range(self.final_res_log2, self.init_res_log2 - 1, -1): + res = 2 ** res_log2 + in_channels = self.get_nf(res) + out_channels = self.get_nf(res // 2) + block_idx = self.final_res_log2 - res_log2 + + # Input convolution layer for each resolution (if needed). + + layer_name = f'input{block_idx}' + self.add_module(layer_name, + ConvLayer(in_channels=image_channels, + out_channels=in_channels, + kernel_size=1, + add_bias=True, + scale_factor=1, + filter_kernel=None, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + lr_mul=lr_mul, + activation_type='lrelu', + conv_clamp=conv_clamp)) + self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = ( + f'{res}x{res}/FromRGB/weight') + self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = ( + f'{res}x{res}/FromRGB/bias') + + # Convolution block for each resolution (except the last one). 
+ if res != self.init_res: + # First layer (kernel 3x3) without downsampling. + layer_name = f'layer{2 * block_idx}' + self.add_module(layer_name, + ConvLayer(in_channels=in_channels, + out_channels=in_channels, + kernel_size=3, + add_bias=True, + scale_factor=1, + filter_kernel=None, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + lr_mul=lr_mul, + activation_type='lrelu', + conv_clamp=conv_clamp)) + self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = ( + f'{res}x{res}/Conv0/weight') + self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = ( + f'{res}x{res}/Conv0/bias') + + # Second layer (kernel 3x3) with downsampling + layer_name = f'layer{2 * block_idx + 1}' + self.add_module(layer_name, + ConvLayer(in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + add_bias=True, + scale_factor=2, + filter_kernel=filter_kernel, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + lr_mul=lr_mul, + activation_type='lrelu', + conv_clamp=conv_clamp)) + self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = ( + f'{res}x{res}/Conv1_down/weight') + self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = ( + f'{res}x{res}/Conv1_down/bias') + + # Residual branch (kernel 1x1) with downsampling, without bias, + # with linear activation. + if self.architecture == 'resnet': + layer_name = f'residual{block_idx}' + self.add_module(layer_name, + ConvLayer(in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + add_bias=False, + scale_factor=2, + filter_kernel=filter_kernel, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + lr_mul=lr_mul, + activation_type='linear', + conv_clamp=None)) + self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = ( + f'{res}x{res}/Skip/weight') + + # Convolution block for last resolution. + else: + self.mbstd = MiniBatchSTDLayer( + groups=mbstd_groups, new_channels=mbstd_channels, eps=eps) + + # First layer (kernel 3x3) without downsampling. + layer_name = f'layer{2 * block_idx}' + self.add_module( + layer_name, + ConvLayer(in_channels=in_channels + mbstd_channels, + out_channels=in_channels, + kernel_size=3, + add_bias=True, + scale_factor=1, + filter_kernel=None, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + lr_mul=lr_mul, + activation_type='lrelu', + conv_clamp=conv_clamp)) + self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = ( + f'{res}x{res}/Conv/weight') + self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = ( + f'{res}x{res}/Conv/bias') + + # Second layer, as a fully-connected layer. + layer_name = f'layer{2 * block_idx + 1}' + self.add_module(layer_name, + DenseLayer(in_channels=in_channels * res * res, + out_channels=in_channels, + add_bias=True, + init_bias=0.0, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + lr_mul=lr_mul, + activation_type='lrelu')) + self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = ( + f'{res}x{res}/Dense0/weight') + self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = ( + f'{res}x{res}/Dense0/bias') + + # Final dense layer to output score. + self.output = DenseLayer(in_channels=in_channels, + out_channels=(embedding_dim + if self.use_embedding + else max(label_dim, 1)), + add_bias=True, + init_bias=0.0, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + lr_mul=lr_mul, + activation_type='linear') + self.pth_to_tf_var_mapping['output.weight'] = 'Output/weight' + self.pth_to_tf_var_mapping['output.bias'] = 'Output/bias' + + # Used for downsampling input image for `skip` architecture. 
+ if self.architecture == 'skip': + self.register_buffer( + 'filter', upfirdn2d.setup_filter(filter_kernel)) + + def get_nf(self, res): + """Gets number of feature maps according to current resolution.""" + return min(self.fmaps_base // res, self.fmaps_max) + + def forward(self, image, lod=None, label=None, fp16_res=None, impl='cuda'): + # Check shape. + expected_shape = (self.image_channels, self.resolution, self.resolution) + if image.ndim != 4 or image.shape[1:] != expected_shape: + raise ValueError(f'The input tensor should be with shape ' + f'[batch_size, channel, height, width], where ' + f'`channel` equals to {self.image_channels}, ' + f'`height`, `width` equal to {self.resolution}!\n' + f'But `{image.shape}` is received!') + if self.label_dim > 0: + if label is None: + raise ValueError(f'Model requires an additional label ' + f'(with dimension {self.label_dim}) as input, ' + f'but no label is received!') + batch_size = image.shape[0] + if label.ndim != 2 or label.shape != (batch_size, self.label_dim): + raise ValueError(f'Input label should be with shape ' + f'[batch_size, label_dim], where ' + f'`batch_size` equals to that of ' + f'images ({image.shape[0]}) and ' + f'`label_dim` equals to {self.label_dim}!\n' + f'But `{label.shape}` is received!') + label = label.to(dtype=torch.float32) + if self.use_embedding: + embed = self.embedding(label, impl=impl) + if self.normalize_embedding: + embed = self.norm(embed) + for i in range(self.mapping_layers): + embed = getattr(self, f'mapping{i}')(embed, impl=impl) + + # Cast to `torch.float16` if needed. + if fp16_res is not None and self.resolution >= fp16_res: + image = image.to(torch.float16) + + lod = self.lod.item() if lod is None else lod + x = self.input0(image, impl=impl) + + for res_log2 in range(self.final_res_log2, self.init_res_log2, -1): + res = 2 ** res_log2 + # Cast to `torch.float16` if needed. + if fp16_res is not None and res >= fp16_res: + x = x.to(torch.float16) + else: + x = x.to(torch.float32) + + idx = cur_lod = self.final_res_log2 - res_log2 # Block index + + if cur_lod <= lod < cur_lod + 1: + x = getattr(self, f'input{idx}')(image, impl=impl) + elif cur_lod - 1 < lod < cur_lod: + alpha = lod - np.floor(lod) + y = getattr(self, f'input{idx}')(image, impl=impl) + x = y * alpha + x * (1 - alpha) + if lod < cur_lod + 1: + if self.architecture == 'skip' and idx > 0: + image = upfirdn2d.downsample2d(image, self.filter, impl=impl) + # Cast to `torch.float16` if needed. + if fp16_res is not None and res >= fp16_res: + image = image.to(torch.float16) + else: + image = image.to(torch.float32) + y = getattr(self, f'input{idx}')(image, impl=impl) + x = x + y + if self.architecture == 'resnet': + residual = getattr(self, f'residual{idx}')( + x, runtime_gain=np.sqrt(0.5), impl=impl) + x = getattr(self, f'layer{2 * idx}')(x, impl=impl) + x = getattr(self, f'layer{2 * idx + 1}')( + x, runtime_gain=np.sqrt(0.5), impl=impl) + x = x + residual + else: + x = getattr(self, f'layer{2 * idx}')(x, impl=impl) + x = getattr(self, f'layer{2 * idx + 1}')(x, impl=impl) + + if lod > cur_lod: + image = F.avg_pool2d( + image, kernel_size=2, stride=2, padding=0) + # Final output. + if fp16_res is not None: # Always use FP32 for the last block. + x = x.to(torch.float32) + if self.architecture == 'skip': + image = upfirdn2d.downsample2d(image, self.filter, impl=impl) + if fp16_res is not None: # Always use FP32 for the last block. 
+ image = image.to(torch.float32) + y = getattr(self, f'input{idx}')(image, impl=impl) + x = x + y + x = self.mbstd(x) + x = getattr(self, f'layer{2 * idx + 2}')(x, impl=impl) + x = getattr(self, f'layer{2 * idx + 3}')(x, impl=impl) + x = self.output(x, impl=impl) + + if self.use_embedding: + x = (x * embed).sum(dim=1, keepdim=True) + x = x / np.sqrt(self.embedding_dim) + elif self.label_dim > 0: + x = (x * label).sum(dim=1, keepdim=True) + + results = { + 'score': x, + 'label': label + } + if self.use_embedding: + results['embedding'] = embed + return results + + +class PixelNormLayer(nn.Module): + """Implements pixel-wise feature vector normalization layer.""" + + def __init__(self, dim, eps): + super().__init__() + self.dim = dim + self.eps = eps + + def extra_repr(self): + return f'dim={self.dim}, epsilon={self.eps}' + + def forward(self, x): + scale = (x.square().mean(dim=self.dim, keepdim=True) + self.eps).rsqrt() + return x * scale + + +class MiniBatchSTDLayer(nn.Module): + """Implements the minibatch standard deviation layer.""" + + def __init__(self, groups, new_channels, eps): + super().__init__() + self.groups = groups + self.new_channels = new_channels + self.eps = eps + + def extra_repr(self): + return (f'groups={self.groups}, ' + f'new_channels={self.new_channels}, ' + f'epsilon={self.eps}') + + def forward(self, x): + if self.groups <= 1 or self.new_channels < 1: + return x + + dtype = x.dtype + + N, C, H, W = x.shape + G = min(self.groups, N) # Number of groups. + nC = self.new_channels # Number of channel groups. + c = C // nC # Channels per channel group. + + y = x.reshape(G, -1, nC, c, H, W) # [GnFcHW] + y = y - y.mean(dim=0) # [GnFcHW] + y = y.square().mean(dim=0) # [nFcHW] + y = (y + self.eps).sqrt() # [nFcHW] + y = y.mean(dim=(2, 3, 4)) # [nF] + y = y.reshape(-1, nC, 1, 1) # [nF11] + y = y.repeat(G, 1, H, W) # [NFHW] + x = torch.cat((x, y), dim=1) # [N(C+F)HW] + + assert x.dtype == dtype + return x + + +class ConvLayer(nn.Module): + """Implements the convolutional layer. + + If downsampling is needed (i.e., `scale_factor = 2`), the feature map will + be filtered with `filter_kernel` first. + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + add_bias, + scale_factor, + filter_kernel, + use_wscale, + wscale_gain, + lr_mul, + activation_type, + conv_clamp): + """Initializes with layer settings. + + Args: + in_channels: Number of channels of the input tensor. + out_channels: Number of channels of the output tensor. + kernel_size: Size of the convolutional kernels. + add_bias: Whether to add bias onto the convolutional result. + scale_factor: Scale factor for downsampling. `1` means skip + downsampling. + filter_kernel: Kernel used for filtering. + use_wscale: Whether to use weight scaling. + wscale_gain: Gain factor for weight scaling. + lr_mul: Learning multiplier for both weight and bias. + activation_type: Type of activation. + conv_clamp: A threshold to clamp the output of convolution layers to + avoid overflow under FP16 training. 
+ """ + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.add_bias = add_bias + self.scale_factor = scale_factor + self.filter_kernel = filter_kernel + self.use_wscale = use_wscale + self.wscale_gain = wscale_gain + self.lr_mul = lr_mul + self.activation_type = activation_type + self.conv_clamp = conv_clamp + + weight_shape = (out_channels, in_channels, kernel_size, kernel_size) + fan_in = kernel_size * kernel_size * in_channels + wscale = wscale_gain / np.sqrt(fan_in) + if use_wscale: + self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul) + self.wscale = wscale * lr_mul + else: + self.weight = nn.Parameter( + torch.randn(*weight_shape) * wscale / lr_mul) + self.wscale = lr_mul + + if add_bias: + self.bias = nn.Parameter(torch.zeros(out_channels)) + self.bscale = lr_mul + else: + self.bias = None + self.act_gain = bias_act.activation_funcs[activation_type].def_gain + + if scale_factor > 1: + assert filter_kernel is not None + self.register_buffer( + 'filter', upfirdn2d.setup_filter(filter_kernel)) + fh, fw = self.filter.shape + self.filter_padding = ( + kernel_size // 2 + (fw - scale_factor + 1) // 2, + kernel_size // 2 + (fw - scale_factor) // 2, + kernel_size // 2 + (fh - scale_factor + 1) // 2, + kernel_size // 2 + (fh - scale_factor) // 2) + + def extra_repr(self): + return (f'in_ch={self.in_channels}, ' + f'out_ch={self.out_channels}, ' + f'ksize={self.kernel_size}, ' + f'wscale_gain={self.wscale_gain:.3f}, ' + f'bias={self.add_bias}, ' + f'lr_mul={self.lr_mul:.3f}, ' + f'downsample={self.scale_factor}, ' + f'downsample_filter={self.filter_kernel}, ' + f'act={self.activation_type}, ' + f'clamp={self.conv_clamp}') + + def forward(self, x, runtime_gain=1.0, impl='cuda'): + dtype = x.dtype + + weight = self.weight + if self.wscale != 1.0: + weight = weight * self.wscale + bias = None + if self.bias is not None: + bias = self.bias.to(dtype) + if self.bscale != 1.0: + bias = bias * self.bscale + + if self.scale_factor == 1: # Native convolution without downsampling. + padding = self.kernel_size // 2 + x = conv2d_gradfix.conv2d( + x, weight.to(dtype), stride=1, padding=padding, impl=impl) + else: # Convolution with downsampling. + down = self.scale_factor + f = self.filter + padding = self.filter_padding + # When kernel size = 1, use filtering function for downsampling. + if self.kernel_size == 1: + x = upfirdn2d.upfirdn2d( + x, f, down=down, padding=padding, impl=impl) + x = conv2d_gradfix.conv2d( + x, weight.to(dtype), stride=1, padding=0, impl=impl) + # When kernel size != 1, use stride convolution for downsampling. + else: + x = upfirdn2d.upfirdn2d( + x, f, down=1, padding=padding, impl=impl) + x = conv2d_gradfix.conv2d( + x, weight.to(dtype), stride=down, padding=0, impl=impl) + + act_gain = self.act_gain * runtime_gain + act_clamp = None + if self.conv_clamp is not None: + act_clamp = self.conv_clamp * runtime_gain + x = bias_act.bias_act(x, bias, + act=self.activation_type, + gain=act_gain, + clamp=act_clamp, + impl=impl) + + assert x.dtype == dtype + return x + + +class DenseLayer(nn.Module): + """Implements the dense layer.""" + + def __init__(self, + in_channels, + out_channels, + add_bias, + init_bias, + use_wscale, + wscale_gain, + lr_mul, + activation_type): + """Initializes with layer settings. + + Args: + in_channels: Number of channels of the input tensor. + out_channels: Number of channels of the output tensor. + add_bias: Whether to add bias onto the fully-connected result. 
+ init_bias: The initial bias value before training. + use_wscale: Whether to use weight scaling. + wscale_gain: Gain factor for weight scaling. + lr_mul: Learning multiplier for both weight and bias. + activation_type: Type of activation. + """ + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.add_bias = add_bias + self.init_bias = init_bias + self.use_wscale = use_wscale + self.wscale_gain = wscale_gain + self.lr_mul = lr_mul + self.activation_type = activation_type + + weight_shape = (out_channels, in_channels) + wscale = wscale_gain / np.sqrt(in_channels) + if use_wscale: + self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul) + self.wscale = wscale * lr_mul + else: + self.weight = nn.Parameter( + torch.randn(*weight_shape) * wscale / lr_mul) + self.wscale = lr_mul + + if add_bias: + init_bias = np.float32(init_bias) / lr_mul + self.bias = nn.Parameter(torch.full([out_channels], init_bias)) + self.bscale = lr_mul + else: + self.bias = None + + def extra_repr(self): + return (f'in_ch={self.in_channels}, ' + f'out_ch={self.out_channels}, ' + f'wscale_gain={self.wscale_gain:.3f}, ' + f'bias={self.add_bias}, ' + f'init_bias={self.init_bias}, ' + f'lr_mul={self.lr_mul:.3f}, ' + f'act={self.activation_type}') + + def forward(self, x, impl='cuda'): + dtype = x.dtype + + if x.ndim != 2: + x = x.flatten(start_dim=1) + + weight = self.weight.to(dtype) * self.wscale + bias = None + if self.bias is not None: + bias = self.bias.to(dtype) + if self.bscale != 1.0: + bias = bias * self.bscale + + # Fast pass for linear activation. + if self.activation_type == 'linear' and bias is not None: + x = torch.addmm(bias.unsqueeze(0), x, weight.t()) + else: + x = x.matmul(weight.t()) + x = bias_act.bias_act(x, bias, act=self.activation_type, impl=impl) + + assert x.dtype == dtype + return x diff --git a/models/volumegan_generator.py b/models/volumegan_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..231b7ed1ea703b900939ecbc4fa0143f5b2a9607 --- /dev/null +++ b/models/volumegan_generator.py @@ -0,0 +1,828 @@ +# python3.8 +"""Contains the implementation of generator described in VolumeGAN. + +Paper: https://arxiv.org/pdf/2112.10759.pdf +""" + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +from .stylegan2_generator import MappingNetwork +from .stylegan2_generator import ModulateConvLayer +from .stylegan2_generator import ConvLayer +from .stylegan2_generator import DenseLayer +from third_party.stylegan2_official_ops import upfirdn2d +from .rendering import Renderer +from .rendering import FeatureExtractor +from .utils.ops import all_gather + + +class VolumeGANGenerator(nn.Module): + """Defines the generator network in VoumeGAN.""" + + def __init__( + self, + # Settings for mapping network. + z_dim=512, + w_dim=512, + repeat_w=True, + normalize_z=True, + mapping_layers=8, + mapping_fmaps=512, + mapping_use_wscale=True, + mapping_wscale_gain=1.0, + mapping_lr_mul=0.01, + # Settings for conditional generation. + label_dim=0, + embedding_dim=512, + embedding_bias=True, + embedding_use_wscale=True, + embedding_wscale_gian=1.0, + embedding_lr_mul=1.0, + normalize_embedding=True, + normalize_embedding_latent=False, + # Settings for post neural renderer network. 
+ resolution=-1, + nerf_res=32, + image_channels=3, + final_tanh=False, + demodulate=True, + use_wscale=True, + wscale_gain=1.0, + lr_mul=1.0, + noise_type='spatial', + fmaps_base=32 << 10, + fmaps_max=512, + filter_kernel=(1, 3, 3, 1), + conv_clamp=None, + eps=1e-8, + rgb_init_res_out=True, + # Settings for feature volume. + fv_cfg=dict(feat_res=32, + init_res=4, + base_channels=256, + output_channels=32, + w_dim=512), + # Settings for position encoder. + embed_cfg=dict(input_dim=3, max_freq_log2=10 - 1, N_freqs=10), + # Settings for MLP network. + fg_cfg=dict(num_layers=4, hidden_dim=256, activation_type='lrelu'), + bg_cfg=None, + out_dim=512, + # Settings for rendering. + rendering_kwargs={}): + + super().__init__() + + self.z_dim = z_dim + self.w_dim = w_dim + self.repeat_w = repeat_w + self.normalize_z = normalize_z + self.mapping_layers = mapping_layers + self.mapping_fmaps = mapping_fmaps + self.mapping_use_wscale = mapping_use_wscale + self.mapping_wscale_gain = mapping_wscale_gain + self.mapping_lr_mul = mapping_lr_mul + + self.latent_dim = (z_dim,) + self.label_size = label_dim + self.label_dim = label_dim + self.embedding_dim = embedding_dim + self.embedding_bias = embedding_bias + self.embedding_use_wscale = embedding_use_wscale + self.embedding_wscale_gain = embedding_wscale_gian + self.embedding_lr_mul = embedding_lr_mul + self.normalize_embedding = normalize_embedding + self.normalize_embedding_latent = normalize_embedding_latent + + self.resolution = resolution + self.nerf_res = nerf_res + self.image_channels = image_channels + self.final_tanh = final_tanh + self.demodulate = demodulate + self.use_wscale = use_wscale + self.wscale_gain = wscale_gain + self.lr_mul = lr_mul + self.noise_type = noise_type.lower() + self.fmaps_base = fmaps_base + self.fmaps_max = fmaps_max + self.filter_kernel = filter_kernel + self.conv_clamp = conv_clamp + self.eps = eps + + self.num_nerf_layers = fg_cfg['num_layers'] + self.num_cnn_layers = int(np.log2(resolution // nerf_res * 2)) * 2 + self.num_layers = self.num_nerf_layers + self.num_cnn_layers + + # Set up `w_avg` for truncation trick. + if self.repeat_w: + self.register_buffer('w_avg', torch.zeros(w_dim)) + else: + self.register_buffer('w_avg', torch.zeros(self.num_layers * w_dim)) + + # Set up the mapping network. + self.mapping = MappingNetwork( + input_dim=z_dim, + output_dim=w_dim, + num_outputs=self.num_layers, + repeat_output=repeat_w, + normalize_input=normalize_z, + num_layers=mapping_layers, + hidden_dim=mapping_fmaps, + use_wscale=mapping_use_wscale, + wscale_gain=mapping_wscale_gain, + lr_mul=mapping_lr_mul, + label_dim=label_dim, + embedding_dim=embedding_dim, + embedding_bias=embedding_bias, + embedding_use_wscale=embedding_use_wscale, + embedding_wscale_gian=embedding_wscale_gian, + embedding_lr_mul=embedding_lr_mul, + normalize_embedding=normalize_embedding, + normalize_embedding_latent=normalize_embedding_latent, + eps=eps) + + # Set up the overall renderer. + self.renderer = Renderer() + + # Set up the reference representation generator. + self.ref_representation_generator = FeatureVolume(**fv_cfg) + + # Set up the position encoder. + self.position_encoder = PositionEncoder(**embed_cfg) + + # Set up the feature extractor. + self.feature_extractor = FeatureExtractor(ref_mode='feature_volume') + + # Set up the post module in the feature extractor. 
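# Hedged sanity check of the MLP input width built just below, using the
# default embed_cfg/fv_cfg above (raw xyz kept, 10 frequency bands with sin
# and cos, 32 feature-volume channels):
input_dim, n_freqs, n_periodic_fns = 3, 10, 2
pe_out_dim = input_dim + input_dim * n_freqs * n_periodic_fns  # include_input=True
print(pe_out_dim, pe_out_dim + 32)  # 63 95  -> NeRFMLPNetwork input_dim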
+ self.post_module = NeRFMLPNetwork(input_dim=self.position_encoder.out_dim + + fv_cfg['output_channels'], + fg_cfg=fg_cfg, + bg_cfg=bg_cfg) + + # Set up the fully-connected layer head. + self.fc_head = FCHead(fg_cfg=fg_cfg, bg_cfg=bg_cfg, out_dim=out_dim) + + # Set up the post neural renderer. + self.post_neural_renderer = PostNeuralRendererNetwork( + resolution=resolution, + init_res=nerf_res, + w_dim=w_dim, + image_channels=image_channels, + final_tanh=final_tanh, + demodulate=demodulate, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + lr_mul=lr_mul, + noise_type=noise_type, + fmaps_base=fmaps_base, + filter_kernel=filter_kernel, + fmaps_max=fmaps_max, + conv_clamp=conv_clamp, + eps=eps, + rgb_init_res_out=rgb_init_res_out) + + # Set up some rendering related arguments. + self.rendering_kwargs = rendering_kwargs + + # Set up vars' mapping from current implementation to the official + # implementation. Note that this is only for debug. + self.cur_to_official_part_mapping = { + 'w_avg': 'w_avg', + 'mapping': 'mapping', + 'ref_representation_generator': 'nerfmlp.fv', + 'post_module.fg_mlp': 'nerfmlp.fg_mlps', + 'fc_head.fg_sigma_head': 'nerfmlp.fg_density', + 'fc_head.fg_rgb_head': 'nerfmlp.fg_color', + 'post_neural_renderer': 'synthesis' + } + + # Set debug mode only when debugging. + if self.rendering_kwargs.get('debug_mode', False): + self.set_weights_from_official( + rendering_kwargs.get('cur_state', None), + rendering_kwargs.get('official_state', None)) + + def get_cur_to_official_full_mapping(self, keys_cur): + cur_to_official_full_mapping = {} + for key, val in self.cur_to_official_part_mapping.items(): + for key_cur_full in keys_cur: + if key in key_cur_full: + sub_key = key_cur_full.replace(key, '') + cur_to_official_full_mapping[key + sub_key] = val + sub_key + return cur_to_official_full_mapping + + def set_weights_from_official(self, cur_state, official_state): + keys_cur = cur_state['models']['generator_smooth'].keys() + self.cur_to_official_full_mapping = ( + self.get_cur_to_official_full_mapping(keys_cur)) + for name, param in self.named_parameters(): + param.data = (official_state['models']['generator_smooth'][ + self.cur_to_official_full_mapping[name]]) + + def forward( + self, + z, + label=None, + lod=None, + w_moving_decay=None, + sync_w_avg=False, + style_mixing_prob=None, + trunc_psi=None, + trunc_layers=None, + noise_mode='const', + fused_modulate=False, + impl='cuda', + fp16_res=None, + ): + mapping_results = self.mapping(z, label, impl=impl) + w = mapping_results['w'] + lod = self.post_neural_renderer.lod.item() if lod is None else lod + + if self.training and w_moving_decay is not None: + if sync_w_avg: + batch_w_avg = all_gather(w.detach()).mean(dim=0) + else: + batch_w_avg = w.detach().mean(dim=0) + self.w_avg.copy_(batch_w_avg.lerp(self.w_avg, w_moving_decay)) + + wp = mapping_results['wp'] + + if self.training and style_mixing_prob is not None: + if np.random.uniform() < style_mixing_prob: + new_z = torch.randn_like(z) + new_wp = self.mapping(new_z, label, impl=impl)['wp'] + current_layers = self.num_layers + if current_layers > self.num_nerf_layers: + mixing_cutoff = np.random.randint(self.num_nerf_layers, + current_layers) + wp[:, mixing_cutoff:] = new_wp[:, mixing_cutoff:] + + if not self.training: + trunc_psi = 1.0 if trunc_psi is None else trunc_psi + trunc_layers = 0 if trunc_layers is None else trunc_layers + if trunc_psi < 1.0 and trunc_layers > 0: + w_avg = self.w_avg.reshape(1, -1, self.w_dim)[:, :trunc_layers] + wp[:, :trunc_layers] = 
w_avg.lerp( + wp[:, :trunc_layers], trunc_psi) + + nerf_w = wp[:,:self.num_nerf_layers] + cnn_w = wp[:,self.num_nerf_layers:] + + feature_volume = self.ref_representation_generator(nerf_w) + + rendering_results = self.renderer( + wp=nerf_w, + feature_extractor=self.feature_extractor, + rendering_options=self.rendering_kwargs, + position_encoder=self.position_encoder, + ref_representation=feature_volume, + post_module=self.post_module, + fc_head=self.fc_head) + + feature2d = rendering_results['composite_rgb'] + feature2d = feature2d.reshape(feature2d.shape[0], self.nerf_res, + self.nerf_res, -1).permute(0, 3, 1, 2) + + final_results = self.post_neural_renderer( + feature2d, + cnn_w, + lod=None, + noise_mode=noise_mode, + fused_modulate=fused_modulate, + impl=impl, + fp16_res=fp16_res) + + return {**mapping_results, **final_results} + + +class PositionEncoder(nn.Module): + """Implements the class for positional encoding.""" + + def __init__(self, + input_dim, + max_freq_log2, + N_freqs, + log_sampling=True, + include_input=True, + periodic_fns=(torch.sin, torch.cos)): + """Initializes with basic settings. + + Args: + input_dim: Dimension of input to be embedded. + max_freq_log2: `log2` of max freq; min freq is 1 by default. + N_freqs: Number of frequency bands. + log_sampling: If True, frequency bands are linerly sampled in + log-space. + include_input: If True, raw input is included in the embedding. + Defaults to True. + periodic_fns: Periodic functions used to embed input. + Defaults to (torch.sin, torch.cos). + """ + super().__init__() + + self.input_dim = input_dim + self.include_input = include_input + self.periodic_fns = periodic_fns + + self.out_dim = 0 + if self.include_input: + self.out_dim += self.input_dim + + self.out_dim += self.input_dim * N_freqs * len(self.periodic_fns) + + if log_sampling: + self.freq_bands = 2.**torch.linspace(0., max_freq_log2, N_freqs) + else: + self.freq_bands = torch.linspace(2.**0., 2.**max_freq_log2, + N_freqs) + + self.freq_bands = self.freq_bands.numpy().tolist() + + def forward(self, input): + assert (input.shape[-1] == self.input_dim) + + out = [] + if self.include_input: + out.append(input) + + for i in range(len(self.freq_bands)): + freq = self.freq_bands[i] + for p_fn in self.periodic_fns: + out.append(p_fn(input * freq)) + out = torch.cat(out, dim=-1) + + assert (out.shape[-1] == self.out_dim) + + return out + + +class FeatureVolume(nn.Module): + """Defines feature volume in VolumeGAN.""" + + def __init__(self, + feat_res=32, + init_res=4, + base_channels=256, + output_channels=32, + w_dim=512, + **kwargs): + super().__init__() + self.num_stages = int(np.log2(feat_res // init_res)) + 1 + + self.const = nn.Parameter( + torch.ones(1, base_channels, init_res, init_res, init_res)) + inplanes = base_channels + outplanes = base_channels + + self.stage_channels = [] + for i in range(self.num_stages): + conv = nn.Conv3d(inplanes, + outplanes, + kernel_size=(3, 3, 3), + padding=(1, 1, 1)) + self.stage_channels.append(outplanes) + self.add_module(f'layer{i}', conv) + instance_norm = InstanceNormLayer(num_features=outplanes, + affine=False) + + self.add_module(f'instance_norm{i}', instance_norm) + inplanes = outplanes + outplanes = max(outplanes // 2, output_channels) + if i == self.num_stages - 1: + outplanes = output_channels + + self.mapping_network = nn.Linear(w_dim, sum(self.stage_channels) * 2) + self.mapping_network.apply(kaiming_leaky_init) + with torch.no_grad(): + self.mapping_network.weight *= 0.25 + self.upsample = UpsamplingLayer() + 
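# Hedged recomputation of the mapping_network sizing above for the default
# fv_cfg (feat_res=32, init_res=4, base_channels=256, output_channels=32):
# four stages at 256/128/64/32 channels, i.e. Linear(512, 960) emitting one
# scale and one shift per channel per stage.
stage_channels, outplanes = [], 256
for _ in range(4):  # num_stages = int(log2(32 // 4)) + 1 = 4
    stage_channels.append(outplanes)
    outplanes = max(outplanes // 2, 32)
print(stage_channels, sum(stage_channels) * 2)  # [256, 128, 64, 32] 960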
self.lrelu = nn.LeakyReLU(negative_slope=0.2) + + def forward(self, w, **kwargs): + if w.ndim == 3: + _w = w[:, 0] + else: + _w = w + scale_shifts = self.mapping_network(_w) + scales = scale_shifts[..., :scale_shifts.shape[-1] // 2] + shifts = scale_shifts[..., scale_shifts.shape[-1] // 2:] + + x = self.const.repeat(w.shape[0], 1, 1, 1, 1) + for idx in range(self.num_stages): + if idx != 0: + x = self.upsample(x) + conv_layer = self.__getattr__(f'layer{idx}') + x = conv_layer(x) + instance_norm = self.__getattr__(f'instance_norm{idx}') + scale = scales[:, + sum(self.stage_channels[:idx] + ):sum(self.stage_channels[:idx + 1])] + shift = shifts[:, + sum(self.stage_channels[:idx] + ):sum(self.stage_channels[:idx + 1])] + scale = scale.view(scale.shape + (1, 1, 1)) + shift = shift.view(shift.shape + (1, 1, 1)) + x = instance_norm(x, weight=scale, bias=shift) + x = self.lrelu(x) + + return x + + +def kaiming_leaky_init(m): + classname = m.__class__.__name__ + if classname.find('Linear') != -1: + torch.nn.init.kaiming_normal_(m.weight, + a=0.2, + mode='fan_in', + nonlinearity='leaky_relu') + + +class InstanceNormLayer(nn.Module): + """Implements instance normalization layer.""" + + def __init__(self, num_features, epsilon=1e-8, affine=False): + super().__init__() + self.eps = epsilon + self.affine = affine + if self.affine: + self.weight = nn.Parameter(torch.Tensor(1, num_features, 1, 1, 1)) + self.bias = nn.Parameter(torch.Tensor(1, num_features, 1, 1, 1)) + self.weight.data.uniform_() + self.bias.data.zero_() + + def forward(self, x, weight=None, bias=None): + x = x - torch.mean(x, dim=[2, 3, 4], keepdim=True) + norm = torch.sqrt( + torch.mean(x**2, dim=[2, 3, 4], keepdim=True) + self.eps) + x = x / norm + isnot_input_none = weight is not None and bias is not None + assert (isnot_input_none and not self.affine) or (not isnot_input_none + and self.affine) + if self.affine: + x = x * self.weight + self.bias + else: + x = x * weight + bias + return x + + +class UpsamplingLayer(nn.Module): + + def __init__(self, scale_factor=2): + super().__init__() + self.scale_factor = scale_factor + + def forward(self, x): + if self.scale_factor <= 1: + return x + return F.interpolate(x, scale_factor=self.scale_factor, mode='nearest') + + +class NeRFMLPNetwork(nn.Module): + """Defines class of MLP Network described in VolumeGAN. + + Basically, this class takes in latent codes and point coodinates as input, + and outputs features of each point, which is followed by two fully-connected + layer heads. + """ + + def __init__(self, input_dim, fg_cfg, bg_cfg=None): + super().__init__() + self.fg_mlp = self.build_mlp(input_dim=input_dim, **fg_cfg) + + def build_mlp(self, input_dim, num_layers, hidden_dim, activation_type, + **kwargs): + """Implements function to build the `MLP`. + + Note that here the `MLP` network is consists of a series of + `ModulateConvLayer` with `kernel_size=1` to simulate fully-connected + layer. Typically, the input's shape of convolutional layers is + `[N, C, H, W]`. And the input's shape is `[N, C, R*K, 1]` here, which + aims to keep consistent with `MLP`. 
+ """ + default_conv_cfg = dict(resolution=32, + w_dim=512, + kernel_size=1, + add_bias=True, + scale_factor=1, + filter_kernel=None, + demodulate=True, + use_wscale=True, + wscale_gain=1, + lr_mul=1, + noise_type='none', + conv_clamp=None, + eps=1e-8) + mlp_list = nn.ModuleList() + in_ch = input_dim + out_ch = hidden_dim + for _ in range(num_layers): + mlp = ModulateConvLayer(in_channels=in_ch, + out_channels=out_ch, + activation_type=activation_type, + **default_conv_cfg) + mlp_list.append(mlp) + in_ch = out_ch + out_ch = hidden_dim + + return mlp_list + + def forward(self, + pre_point_features, + wp, + points_encoding=None, + fused_modulate=False, + impl='cuda'): + N, C, R_K, _ = points_encoding.shape + x = torch.cat([pre_point_features, points_encoding], dim=1) + + for idx, mlp in enumerate(self.fg_mlp): + if wp.ndim == 3: + _w = wp[:, idx] + else: + _w = wp + x, _ = mlp(x, _w, fused_modulate=fused_modulate, impl=impl) + + return x # x's shape: [N, C, R*K, 1] + + +class FCHead(nn.Module): + """Defines fully-connected layer head in VolumeGAN to decode `feature` into + `sigma` and `rgb`.""" + + def __init__(self, fg_cfg, bg_cfg=None, out_dim=512): + super().__init__() + self.fg_sigma_head = DenseLayer(in_channels=fg_cfg['hidden_dim'], + out_channels=1, + add_bias=True, + init_bias=0.0, + use_wscale=True, + wscale_gain=1, + lr_mul=1, + activation_type='linear') + self.fg_rgb_head = DenseLayer(in_channels=fg_cfg['hidden_dim'], + out_channels=out_dim, + add_bias=True, + init_bias=0.0, + use_wscale=True, + wscale_gain=1, + lr_mul=1, + activation_type='linear') + + def forward(self, post_point_features, wp=None, dirs=None): + post_point_features = rearrange( + post_point_features, 'N C (R_K) 1 -> (N R_K) C').contiguous() + fg_sigma = self.fg_sigma_head(post_point_features) + fg_rgb = self.fg_rgb_head(post_point_features) + + results = {'sigma': fg_sigma, 'rgb': fg_rgb} + + return results + + +class PostNeuralRendererNetwork(nn.Module): + """Implements the neural renderer in VolumeGAN to render high-resolution + images. + + Basically, this network executes several convolutional layers in sequence. + """ + + def __init__( + self, + resolution, + init_res, + w_dim, + image_channels, + final_tanh, + demodulate, + use_wscale, + wscale_gain, + lr_mul, + noise_type, + fmaps_base, + fmaps_max, + filter_kernel, + conv_clamp, + eps, + rgb_init_res_out=False, + ): + super().__init__() + + self.init_res = init_res + self.init_res_log2 = int(np.log2(init_res)) + self.resolution = resolution + self.final_res_log2 = int(np.log2(resolution)) + self.w_dim = w_dim + self.image_channels = image_channels + self.final_tanh = final_tanh + self.demodulate = demodulate + self.use_wscale = use_wscale + self.wscale_gain = wscale_gain + self.lr_mul = lr_mul + self.noise_type = noise_type.lower() + self.fmaps_base = fmaps_base + self.fmaps_max = fmaps_max + self.filter_kernel = filter_kernel + self.conv_clamp = conv_clamp + self.eps = eps + self.rgb_init_res_out = rgb_init_res_out + + self.num_layers = (self.final_res_log2 - self.init_res_log2 + 1) * 2 + + self.register_buffer('lod', torch.zeros(())) + + for res_log2 in range(self.init_res_log2, self.final_res_log2 + 1): + res = 2**res_log2 + in_channels = self.get_nf(res // 2) + out_channels = self.get_nf(res) + block_idx = res_log2 - self.init_res_log2 + + # Early layer. 
+ if res > init_res: + layer_name = f'layer{2 * block_idx - 1}' + self.add_module( + layer_name, + ModulateConvLayer(in_channels=in_channels, + out_channels=out_channels, + resolution=res, + w_dim=w_dim, + kernel_size=1, + add_bias=True, + scale_factor=2, + filter_kernel=filter_kernel, + demodulate=demodulate, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + lr_mul=lr_mul, + noise_type=noise_type, + activation_type='lrelu', + conv_clamp=conv_clamp, + eps=eps)) + if block_idx == 0: + if self.rgb_init_res_out: + self.rgb_init_res = ConvLayer( + in_channels=out_channels, + out_channels=image_channels, + kernel_size=1, + add_bias=True, + scale_factor=1, + filter_kernel=None, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + lr_mul=lr_mul, + activation_type='linear', + conv_clamp=conv_clamp, + ) + continue + # Second layer (kernel 1x1) without upsampling. + layer_name = f'layer{2 * block_idx}' + self.add_module( + layer_name, + ModulateConvLayer(in_channels=out_channels, + out_channels=out_channels, + resolution=res, + w_dim=w_dim, + kernel_size=1, + add_bias=True, + scale_factor=1, + filter_kernel=None, + demodulate=demodulate, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + lr_mul=lr_mul, + noise_type=noise_type, + activation_type='lrelu', + conv_clamp=conv_clamp, + eps=eps)) + + # Output convolution layer for each resolution (if needed). + layer_name = f'output{block_idx}' + self.add_module( + layer_name, + ModulateConvLayer(in_channels=out_channels, + out_channels=image_channels, + resolution=res, + w_dim=w_dim, + kernel_size=1, + add_bias=True, + scale_factor=1, + filter_kernel=None, + demodulate=False, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + lr_mul=lr_mul, + noise_type='none', + activation_type='linear', + conv_clamp=conv_clamp, + eps=eps)) + + # Used for upsampling output images for each resolution block for sum. + self.register_buffer('filter', upfirdn2d.setup_filter(filter_kernel)) + + def get_nf(self, res): + """Gets number of feature maps according to current resolution.""" + return min(self.fmaps_base // res, self.fmaps_max) + + def set_space_of_latent(self, space_of_latent): + """Sets the space to which the latent code belong. + + Args: + space_of_latent: The space to which the latent code belong. Case + insensitive. Support `W` and `Y`. + """ + space_of_latent = space_of_latent.upper() + for module in self.modules(): + if isinstance(module, ModulateConvLayer): + setattr(module, 'space_of_latent', space_of_latent) + + def forward(self, + x, + wp, + lod=None, + noise_mode='const', + fused_modulate=False, + impl='cuda', + fp16_res=None, + nerf_out=False): + lod = self.lod.item() if lod is None else lod + + results = {} + + # Cast to `torch.float16` if needed. 
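# Hedged illustration of the FP16 convention used throughout this file:
# blocks whose resolution is at or above fp16_res run in float16, everything
# else (and the final image) stays in float32; fp16_res=None disables mixed
# precision entirely.
import torch

def block_dtype(res, fp16_res):
    return torch.float16 if (fp16_res is not None and res >= fp16_res) else torch.float32

print([block_dtype(r, 128) for r in (32, 64, 128, 256)])
# [torch.float32, torch.float32, torch.float16, torch.float16]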
+ if fp16_res is not None and self.init_res >= fp16_res: + x = x.to(torch.float16) + + for res_log2 in range(self.init_res_log2, self.final_res_log2 + 1): + cur_lod = self.final_res_log2 - res_log2 + block_idx = res_log2 - self.init_res_log2 + + layer_idxs = [2 * block_idx - 1, 2 * + block_idx] if block_idx > 0 else [ + 2 * block_idx, + ] + # determine forward until cur resolution + if lod < cur_lod + 1: + for layer_idx in layer_idxs: + if layer_idx == 0: + # image = x[:,:3] + if self.rgb_init_res_out: + cur_image = self.rgb_init_res(x, + runtime_gain=1, + impl=impl) + else: + cur_image = x[:, :3] + continue + layer = getattr(self, f'layer{layer_idx}') + x, style = layer( + x, + wp[:, layer_idx], + noise_mode=noise_mode, + fused_modulate=fused_modulate, + impl=impl, + ) + results[f'style{layer_idx}'] = style + if layer_idx % 2 == 0: + output_layer = getattr(self, f'output{layer_idx // 2}') + y, style = output_layer( + x, + wp[:, layer_idx + 1], + fused_modulate=fused_modulate, + impl=impl, + ) + results[f'output_style{layer_idx // 2}'] = style + if layer_idx == 0: + cur_image = y.to(torch.float32) + else: + if not nerf_out: + cur_image = y.to( + torch.float32) + upfirdn2d.upsample2d( + cur_image, self.filter, impl=impl) + else: + cur_image = y.to(torch.float32) + cur_image + + # Cast to `torch.float16` if needed. + if layer_idx != self.num_layers - 2: + res = self.init_res * (2**(layer_idx // 2)) + if fp16_res is not None and res * 2 >= fp16_res: + x = x.to(torch.float16) + else: + x = x.to(torch.float32) + + # rgb interpolation + if cur_lod - 1 < lod <= cur_lod: + image = cur_image + elif cur_lod < lod < cur_lod + 1: + alpha = np.ceil(lod) - lod + image = F.interpolate(image, scale_factor=2, mode='nearest') + image = cur_image * alpha + image * (1 - alpha) + elif lod >= cur_lod + 1: + image = F.interpolate(image, scale_factor=2, mode='nearest') + + if self.final_tanh: + image = torch.tanh(image) + results['image'] = image + + return results diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..62d07f73748b422855f04a28d7f322d7faeee144 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,3 @@ +einops +imageio +gradio diff --git a/test_local.py b/test_local.py new file mode 100644 index 0000000000000000000000000000000000000000..303bccbe0a36470eb64a76e1d9e3390cf0592e5b --- /dev/null +++ b/test_local.py @@ -0,0 +1,239 @@ +import gradio as gr +from models import build_model +from PIL import Image +import numpy as np +import torchvision +import ninja +import torch +from tqdm import trange +import imageio + +checkpoint = '/mnt/petrelfs/zhangqihang/data/berfscene_clevr.pth' +state = torch.load(checkpoint, map_location='cpu') +G = build_model(**state['model_kwargs_init']['generator_smooth']) +o0, o1 = G.load_state_dict(state['models']['generator_smooth'], strict=False) +G.eval().cuda() +G.backbone.synthesis.input.x_offset =0 +G.backbone.synthesis.input.y_offset =0 +G_kwargs= dict(noise_mode='const', + fused_modulate=False, + impl='cuda', + fp16_res=None) + +def trans(x, y, z, length): + w = h = length + x = 0.5 * w - 128 + 256 - (x/9 + .5) * 256 + y = 0.5 * h - 128 + (y/9 + .5) * 256 + z = z / 9 * 256 + return x, y, z +def get_bev_from_objs(objs, length=256, scale = 6): + h, w = length, length *scale + nc = 14 + canvas = np.zeros([h, w, nc]) + xx = np.ones([h,w]).cumsum(0) + yy = np.ones([h,w]).cumsum(1) + + for x, y, z, shape, color, material, rot in objs: + y, x, z = trans(x, y, z, length) + + feat = [0] * nc + feat[0] = 1 + 
feat[COLOR_NAME_LIST.index(color) + 1] = 1 + feat[SHAPE_NAME_LIST.index(shape) + 1 + len(COLOR_NAME_LIST)] = 1 + feat[MATERIAL_NAME_LIST.index(material) + 1 + len(COLOR_NAME_LIST) + len(SHAPE_NAME_LIST)] = 1 + feat = np.array(feat) + rot_sin = np.sin(rot / 180 * np.pi) + rot_cos = np.cos(rot / 180 * np.pi) + + if shape == 'cube': + mask = (np.abs(+rot_cos * (xx-x) + rot_sin * (yy-y)) <= z) * \ + (np.abs(-rot_sin * (xx-x) + rot_cos * (yy-y)) <= z) + else: + mask = ((xx-x)**2 + (y-yy)**2) ** 0.5 <= z + canvas[mask] = feat + canvas = np.transpose(canvas, [2, 0, 1]).astype(np.float32) + rotate_angle = 0 + canvas = torchvision.transforms.functional.rotate(torch.tensor(canvas), rotate_angle).numpy() + return canvas + +# COLOR_NAME_LIST = ['cyan', 'green', 'purple', 'red', 'yellow', 'gray', 'brown', 'blue'] +COLOR_NAME_LIST = ['cyan', 'green', 'purple', 'red', 'yellow', 'gray', 'purple', 'blue'] +SHAPE_NAME_LIST = ['cube', 'sphere', 'cylinder'] +MATERIAL_NAME_LIST = ['rubber', 'metal'] + +xy_lib = dict() +xy_lib['B'] = [ + [-2, -1], + [-1, -1], + [-2, 0], + [-2, 1], + [-1, .5], + [0, 1], + [0, 0], + [0, -1], + [0, 2], + [-1, 2], + [-2, 2] +] +xy_lib['B'] = [ + [-2.5, 1.25], + [-2, 2], + [-2, 0.5], + [-2, -0.75], + [-1, -1], + [-1, 2], + [-1, 0], + [-1, 2], + [0, 1], + [0, 0], + [0, -1], + [0, 2], + # [-1, 2], + +] +xy_lib['B'] = [ + [-2.5, 1.25], + [-2, 2], + [-2, 0.5], + [-2, -1], + [-1, -1.25], + [-1, 2], + [-1, 0], + [-1, 2], + [0, 1], + [0, 0], + [0, -1.25], + [0, 2], + # [-1, 2], + +] +xy_lib['R'] = [ + [0, -1], + [0, 0], + [0, 1], + [0, 2], + [-1, -1], + # [-1, 2], + [-2, -1], + [-2, 0], + [-2.25, 2], + [-1, 1] +] +xy_lib['C'] = [ + [0, -1], + [0, 0], + [0, 1], + [0, 2], + [-1, -1], + [-1, 2], + [-2, -1], + # [-2, .5], + [-2, 2], + # [-1, .5] +] +xy_lib['s'] = [ + [0, -1], + [0, 0], + [0, 2], + [-1, -1], + [-1, 2], + [-2, -1], + [-2, 1], + [-2, 2], + [-1, .5] +] + +xy_lib['F'] = [ + [0, -1], + [0, 0], + [0, 1], + [0, 2], + [-1, -1], + # [-1, 2], + [-2, -1], + [-2, .5], + # [-2, 2], + [-1, .5] +] + +xy_lib['c'] = [ + [0.8,1], + # [-0.8,1], + [0,0.1], + [0,1.9], +] + +xy_lib['e'] = [ + [0, -1], + [0, 0], + [0, 1], + [0, 2], + [-1, -1], + [-1, 2], + [-2, -1], + [-2, .5], + [-2, 2], + [-1, .5] +] +xy_lib['n'] = [ + [0,1], + [0,-1], + [0,0.1], + [0,1.9], + [-1,0], + [-2,1], + [-3,-1], + [-3,1], + [-3,0.1], + [-3,1.9], +] +offset_x = dict(B=4, R=4, C=4, F=4, c=3, s=4, e=4, n=4.8) +s = 'BeRFsCene' +objs = [] +offset = 2 +for idx, c in enumerate(s): + xy = xy_lib[c] + + + color = np.random.choice(COLOR_NAME_LIST) + for i in range(len(xy)): + # while 1: + # is_ok = 1 + # x, y = + + # for prev_x, prev_y in zip(xpool, ypool): + x, y = xy[i] + y *= 1.5 + y -= 0.5 + x -= offset + z = 0.35 + # if idx<4: + # color = np.random.choice(COLOR_NAME_LIST[:-1]) + # else: + # color = 'blue' + shape = 'cube' + material = 'rubber' + rot = 0 + objs.append([x, y, z, shape, color, material, rot]) + offset += offset_x[c] +Image.fromarray((255 * .8 - get_bev_from_objs(objs)[0] *.8 * 255).astype(np.uint8)) + +batch_size = 1 +code = torch.randn(1, G.z_dim).cuda() +to_pil = torchvision.transforms.ToPILImage() +large_bevs = torch.tensor(get_bev_from_objs(objs)).cuda()[None] +bevs = large_bevs[..., 0: 0+256] +RT = torch.tensor([[ -1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5000, -0.8660, + 10.3923, 0.0000, -0.8660, -0.5000, 6.0000, 0.0000, 0.0000, + 0.0000, 1.0000, 262.5000, 0.0000, 32.0000, 0.0000, 262.5000, + 32.0000, 0.0000, 0.0000, 1.0000]], device='cuda') + +print('prepare finish', flush=True) + +gen = G(code, RT, bevs) 
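# Hedged sketch of an extension this script appears set up for (the BEV
# canvas is six crops wide and imageio/trange are already imported): slide
# the 256-pixel window across the canvas and stitch the frames into a GIF.
# Window size, stride, and the output filename are illustrative assumptions,
# and this presumes G accepts any 256-wide slice with the same code and RT.
frames = []
for start in trange(0, large_bevs.shape[-1] - 256 + 1, 16):
    crop = large_bevs[..., start:start + 256]
    with torch.no_grad():
        out = G(code, RT, crop)
    frame = out['gen_output']['image'][0] * .5 + .5
    frames.append(np.array(to_pil(frame.clamp(0, 1).cpu())))
imageio.mimsave('sweep.gif', frames, duration=0.1)  # ~10 fps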
+rgb = gen['gen_output']['image'][0] * .5 + .5 + +to_pil(rgb).save('tmp.png') + # save_path = '/mnt/petrelfs/zhangqihang/code/3d-scene-gen/tmp.png' + # return [save_path] + diff --git a/third_party/__init__.py b/third_party/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/third_party/__pycache__/__init__.cpython-37.pyc b/third_party/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2b83f3ced199c710eefd60ea8d4f269fb71487f6 Binary files /dev/null and b/third_party/__pycache__/__init__.cpython-37.pyc differ diff --git a/third_party/__pycache__/__init__.cpython-39.pyc b/third_party/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d51e095e6f8b00b01414edec63e01e96154464bf Binary files /dev/null and b/third_party/__pycache__/__init__.cpython-39.pyc differ diff --git a/third_party/stylegan2_official_ops/README.md b/third_party/stylegan2_official_ops/README.md new file mode 100644 index 0000000000000000000000000000000000000000..bf93e3809c8a1e982b55cb1d3164783cc97de9b4 --- /dev/null +++ b/third_party/stylegan2_official_ops/README.md @@ -0,0 +1,30 @@ +# Operators for StyleGAN2 + +All files in this directory are borrowed from repository [stylegan2-ada-pytorch](https://github.com/NVlabs/stylegan2-ada-pytorch). Basically, these files implement customized operators, which are faster than the native operators from PyTorch, especially for second-derivative computation, including + +- `bias_act.bias_act()`: Fuse adding bias and then performing activation as one operator. +- `upfirdn2d.setup_filter()`: Set up the kernel used for filtering. +- `upfirdn2d.filter2d()`: Filtering a 2D feature map with given kernel. +- `upfirdn2d.upsample2d()`: Upsampling a 2D feature map. +- `upfirdn2d.downsample2d()`: Downsampling a 2D feature map. +- `upfirdn2d.upfirdn2d()`: Upsampling, filtering, and then downsampling a 2D feature map. +- `conv2d_gradfix.conv2d()`: Convolutional layer, supporting arbitrarily high order gradients and fixing gradient when computing penalty. +- `conv2d_gradfix.conv_transpose2d()`: Transposed convolutional layer, supporting arbitrarily high order gradients and fixing gradient when computing penalty. +- `conv2d_resample.conv2d_resample()`: Wraps `upfirdn2d()` and `conv2d()` (or `conv_transpose2d()`). This is not used in our network implementation (*i.e.*, `models/stylegan2_generator.py` and `models/stylegan2_discriminator.py`) + +We make following slight modifications beyond disabling some lint warnings: + +- Line 25 of file `misc.py`: Use `EasyDict` from module `easydict` to replace that from `dnnlib` from [stylegan2-ada-pytorch](https://github.com/NVlabs/stylegan2-ada-pytorch). +- Line 35 of file `custom_ops.py`: Disable log message when setting up customized operators. +- Line 53/89 of file `custom_ops.py`: Add necessary CUDA compiler path. (***NOTE**: If your cuda binary does not locate at `/usr/local/cuda/bin`, please specify in function `_find_compiler_bindir_posix()`.*) +- Line 24 of file `bias_act.py`: Use `EasyDict` from module `easydict` to replace that from `dnnlib` from [stylegan2-ada-pytorch](https://github.com/NVlabs/stylegan2-ada-pytorch). +- Line 32 of file `grid_sample_gradfix.py`: Enable customized grid sampling operator by default. +- Line 36 of file `grid_sample_gradfix.py`: Use `impl` to disable customized grid sample operator. 
+- Line 33 of file `conv2d_gradfix.py`: Enable customized convolution operators by default. +- Line 46/51 of file `conv2d_gradfix.py`: Use `impl` to disable customized convolution operators. +- Line 66 of file `conv2d_gradfix.py`: Update PyTorch version check considering the sustained development of the community. +- Line 47 of file `grid_sample_gradfix.py`: Update PyTorch version check considering the sustained development of the community. +- Line 36/66 of file `conv2d_resample.py`: Use `impl` to disable customized convolution operators. +- Line 23 of file `fma.py`: Use `impl` to disable customized add-multiply operator. + +Please use `ref` or `cuda` to choose which implementation to use. `ref` refers to native PyTorch operators while `cuda` refers to the customized operators from the official repository. `cuda` is used by default. diff --git a/third_party/stylegan2_official_ops/__init__.py b/third_party/stylegan2_official_ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/third_party/stylegan2_official_ops/__pycache__/__init__.cpython-37.pyc b/third_party/stylegan2_official_ops/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c82fbded47983df1b9ce3791e2d442340c2bf37b Binary files /dev/null and b/third_party/stylegan2_official_ops/__pycache__/__init__.cpython-37.pyc differ diff --git a/third_party/stylegan2_official_ops/__pycache__/__init__.cpython-39.pyc b/third_party/stylegan2_official_ops/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9061276d369cc3c060206222b37c42aaa78ff909 Binary files /dev/null and b/third_party/stylegan2_official_ops/__pycache__/__init__.cpython-39.pyc differ diff --git a/third_party/stylegan2_official_ops/__pycache__/bias_act.cpython-37.pyc b/third_party/stylegan2_official_ops/__pycache__/bias_act.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef7e9c9446478f0866e0a461635b7363538a3478 Binary files /dev/null and b/third_party/stylegan2_official_ops/__pycache__/bias_act.cpython-37.pyc differ diff --git a/third_party/stylegan2_official_ops/__pycache__/bias_act.cpython-39.pyc b/third_party/stylegan2_official_ops/__pycache__/bias_act.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7171689085cee910ec1a27959703196b2291f157 Binary files /dev/null and b/third_party/stylegan2_official_ops/__pycache__/bias_act.cpython-39.pyc differ diff --git a/third_party/stylegan2_official_ops/__pycache__/conv2d_gradfix.cpython-37.pyc b/third_party/stylegan2_official_ops/__pycache__/conv2d_gradfix.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f1fd7db2902a501759249be14d37c39556db91c9 Binary files /dev/null and b/third_party/stylegan2_official_ops/__pycache__/conv2d_gradfix.cpython-37.pyc differ diff --git a/third_party/stylegan2_official_ops/__pycache__/conv2d_gradfix.cpython-39.pyc b/third_party/stylegan2_official_ops/__pycache__/conv2d_gradfix.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b5862c4f35d1981d58acadbe173f01094af96bf3 Binary files /dev/null and b/third_party/stylegan2_official_ops/__pycache__/conv2d_gradfix.cpython-39.pyc differ diff --git a/third_party/stylegan2_official_ops/__pycache__/conv2d_resample.cpython-37.pyc b/third_party/stylegan2_official_ops/__pycache__/conv2d_resample.cpython-37.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..25693fb54bb488f440ae06811eabb32631aff5cf Binary files /dev/null and b/third_party/stylegan2_official_ops/__pycache__/conv2d_resample.cpython-37.pyc differ diff --git a/third_party/stylegan2_official_ops/__pycache__/conv2d_resample.cpython-39.pyc b/third_party/stylegan2_official_ops/__pycache__/conv2d_resample.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bf5c14494bcc41513491ddca5f99020b6fe4fa59 Binary files /dev/null and b/third_party/stylegan2_official_ops/__pycache__/conv2d_resample.cpython-39.pyc differ diff --git a/third_party/stylegan2_official_ops/__pycache__/custom_ops.cpython-37.pyc b/third_party/stylegan2_official_ops/__pycache__/custom_ops.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fb7db1bbb28fcde9e0802bc5d77c1fb09989d94c Binary files /dev/null and b/third_party/stylegan2_official_ops/__pycache__/custom_ops.cpython-37.pyc differ diff --git a/third_party/stylegan2_official_ops/__pycache__/custom_ops.cpython-39.pyc b/third_party/stylegan2_official_ops/__pycache__/custom_ops.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cdd9a26bca3ed8f6f46e312cdfc39a6bb9dacb17 Binary files /dev/null and b/third_party/stylegan2_official_ops/__pycache__/custom_ops.cpython-39.pyc differ diff --git a/third_party/stylegan2_official_ops/__pycache__/fma.cpython-37.pyc b/third_party/stylegan2_official_ops/__pycache__/fma.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f94246c5c912b4bb67cffd412b57ff667a035edc Binary files /dev/null and b/third_party/stylegan2_official_ops/__pycache__/fma.cpython-37.pyc differ diff --git a/third_party/stylegan2_official_ops/__pycache__/fma.cpython-39.pyc b/third_party/stylegan2_official_ops/__pycache__/fma.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d825f226b16325819d8cdae38741756bf5c08cc0 Binary files /dev/null and b/third_party/stylegan2_official_ops/__pycache__/fma.cpython-39.pyc differ diff --git a/third_party/stylegan2_official_ops/__pycache__/grid_sample_gradfix.cpython-37.pyc b/third_party/stylegan2_official_ops/__pycache__/grid_sample_gradfix.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef4047d0b6983ed001e2d6239674eccaf27ace2e Binary files /dev/null and b/third_party/stylegan2_official_ops/__pycache__/grid_sample_gradfix.cpython-37.pyc differ diff --git a/third_party/stylegan2_official_ops/__pycache__/misc.cpython-37.pyc b/third_party/stylegan2_official_ops/__pycache__/misc.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..70fd6bff5664946eb8fe328e383bdbe56317fb7d Binary files /dev/null and b/third_party/stylegan2_official_ops/__pycache__/misc.cpython-37.pyc differ diff --git a/third_party/stylegan2_official_ops/__pycache__/misc.cpython-39.pyc b/third_party/stylegan2_official_ops/__pycache__/misc.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..234e0b3723f2b5f0e8e02d05b38a034726a61171 Binary files /dev/null and b/third_party/stylegan2_official_ops/__pycache__/misc.cpython-39.pyc differ diff --git a/third_party/stylegan2_official_ops/__pycache__/upfirdn2d.cpython-37.pyc b/third_party/stylegan2_official_ops/__pycache__/upfirdn2d.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ad70a046bb19321049c76ec6acbf09951f9c0421 Binary files /dev/null and 
b/third_party/stylegan2_official_ops/__pycache__/upfirdn2d.cpython-37.pyc differ diff --git a/third_party/stylegan2_official_ops/__pycache__/upfirdn2d.cpython-39.pyc b/third_party/stylegan2_official_ops/__pycache__/upfirdn2d.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..959ded85a9e6f7ce27a6b81b492fa300b8651c4c Binary files /dev/null and b/third_party/stylegan2_official_ops/__pycache__/upfirdn2d.cpython-39.pyc differ diff --git a/third_party/stylegan2_official_ops/bias_act.cpp b/third_party/stylegan2_official_ops/bias_act.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5d2425d8054991a8e8b6f7a940fd0ff7fa0bb330 --- /dev/null +++ b/third_party/stylegan2_official_ops/bias_act.cpp @@ -0,0 +1,99 @@ +// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include +#include +#include +#include "bias_act.h" + +//------------------------------------------------------------------------ + +static bool has_same_layout(torch::Tensor x, torch::Tensor y) +{ + if (x.dim() != y.dim()) + return false; + for (int64_t i = 0; i < x.dim(); i++) + { + if (x.size(i) != y.size(i)) + return false; + if (x.size(i) >= 2 && x.stride(i) != y.stride(i)) + return false; + } + return true; +} + +//------------------------------------------------------------------------ + +static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp) +{ + // Validate arguments. + TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); + TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x"); + TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x"); + TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x"); + TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x"); + TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); + TORCH_CHECK(b.dim() == 1, "b must have rank 1"); + TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds"); + TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements"); + TORCH_CHECK(grad >= 0, "grad must be non-negative"); + + // Validate layout. + TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense"); + TORCH_CHECK(b.is_contiguous(), "b must be contiguous"); + TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x"); + TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x"); + TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x"); + + // Create output tensor. 
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); + torch::Tensor y = torch::empty_like(x); + TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x"); + + // Initialize CUDA kernel parameters. + bias_act_kernel_params p; + p.x = x.data_ptr(); + p.b = (b.numel()) ? b.data_ptr() : NULL; + p.xref = (xref.numel()) ? xref.data_ptr() : NULL; + p.yref = (yref.numel()) ? yref.data_ptr() : NULL; + p.dy = (dy.numel()) ? dy.data_ptr() : NULL; + p.y = y.data_ptr(); + p.grad = grad; + p.act = act; + p.alpha = alpha; + p.gain = gain; + p.clamp = clamp; + p.sizeX = (int)x.numel(); + p.sizeB = (int)b.numel(); + p.stepB = (b.numel()) ? (int)x.stride(dim) : 1; + + // Choose CUDA kernel. + void* kernel; + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] + { + kernel = choose_bias_act_kernel(p); + }); + TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func"); + + // Launch CUDA kernel. + p.loopX = 4; + int blockSize = 4 * 32; + int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1; + void* args[] = {&p}; + AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); + return y; +} + +//------------------------------------------------------------------------ + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("bias_act", &bias_act); +} + +//------------------------------------------------------------------------ diff --git a/third_party/stylegan2_official_ops/bias_act.cu b/third_party/stylegan2_official_ops/bias_act.cu new file mode 100644 index 0000000000000000000000000000000000000000..dd8fc4756d7d94727f94af738665b68d9c518880 --- /dev/null +++ b/third_party/stylegan2_official_ops/bias_act.cu @@ -0,0 +1,173 @@ +// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include +#include "bias_act.h" + +//------------------------------------------------------------------------ +// Helpers. + +template struct InternalType; +template <> struct InternalType { typedef double scalar_t; }; +template <> struct InternalType { typedef float scalar_t; }; +template <> struct InternalType { typedef float scalar_t; }; + +//------------------------------------------------------------------------ +// CUDA kernel. + +template +__global__ void bias_act_kernel(bias_act_kernel_params p) +{ + typedef typename InternalType::scalar_t scalar_t; + int G = p.grad; + scalar_t alpha = (scalar_t)p.alpha; + scalar_t gain = (scalar_t)p.gain; + scalar_t clamp = (scalar_t)p.clamp; + scalar_t one = (scalar_t)1; + scalar_t two = (scalar_t)2; + scalar_t expRange = (scalar_t)80; + scalar_t halfExpRange = (scalar_t)40; + scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946; + scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717; + + // Loop over elements. + int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x; + for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x) + { + // Load. + scalar_t x = (scalar_t)((const T*)p.x)[xi]; + scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0; + scalar_t xref = (p.xref) ? 
(scalar_t)((const T*)p.xref)[xi] : 0; + scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0; + scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one; + scalar_t yy = (gain != 0) ? yref / gain : 0; + scalar_t y = 0; + + // Apply bias. + ((G == 0) ? x : xref) += b; + + // linear + if (A == 1) + { + if (G == 0) y = x; + if (G == 1) y = x; + } + + // relu + if (A == 2) + { + if (G == 0) y = (x > 0) ? x : 0; + if (G == 1) y = (yy > 0) ? x : 0; + } + + // lrelu + if (A == 3) + { + if (G == 0) y = (x > 0) ? x : x * alpha; + if (G == 1) y = (yy > 0) ? x : x * alpha; + } + + // tanh + if (A == 4) + { + if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); } + if (G == 1) y = x * (one - yy * yy); + if (G == 2) y = x * (one - yy * yy) * (-two * yy); + } + + // sigmoid + if (A == 5) + { + if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one); + if (G == 1) y = x * yy * (one - yy); + if (G == 2) y = x * yy * (one - yy) * (one - two * yy); + } + + // elu + if (A == 6) + { + if (G == 0) y = (x >= 0) ? x : exp(x) - one; + if (G == 1) y = (yy >= 0) ? x : x * (yy + one); + if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one); + } + + // selu + if (A == 7) + { + if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one); + if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha); + if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha); + } + + // softplus + if (A == 8) + { + if (G == 0) y = (x > expRange) ? x : log(exp(x) + one); + if (G == 1) y = x * (one - exp(-yy)); + if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); } + } + + // swish + if (A == 9) + { + if (G == 0) + y = (x < -expRange) ? 0 : x / (exp(-x) + one); + else + { + scalar_t c = exp(xref); + scalar_t d = c + one; + if (G == 1) + y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d); + else + y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d); + yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain; + } + } + + // Apply gain. + y *= gain * dy; + + // Clamp. + if (clamp >= 0) + { + if (G == 0) + y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp; + else + y = (yref > -clamp & yref < clamp) ? y : 0; + } + + // Store. + ((T*)p.y)[xi] = (T)y; + } +} + +//------------------------------------------------------------------------ +// CUDA kernel selection. + +template void* choose_bias_act_kernel(const bias_act_kernel_params& p) +{ + if (p.act == 1) return (void*)bias_act_kernel; + if (p.act == 2) return (void*)bias_act_kernel; + if (p.act == 3) return (void*)bias_act_kernel; + if (p.act == 4) return (void*)bias_act_kernel; + if (p.act == 5) return (void*)bias_act_kernel; + if (p.act == 6) return (void*)bias_act_kernel; + if (p.act == 7) return (void*)bias_act_kernel; + if (p.act == 8) return (void*)bias_act_kernel; + if (p.act == 9) return (void*)bias_act_kernel; + return NULL; +} + +//------------------------------------------------------------------------ +// Template specializations. 
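+// The explicit instantiations below cover the floating-point dtypes that
+// bias_act.cpp dispatches via AT_DISPATCH_FLOATING_TYPES_AND_HALF, i.e.
+// double, float, and c10::Half (the half type maps to a float accumulator
+// through InternalType above).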
+ +template void* choose_bias_act_kernel (const bias_act_kernel_params& p); +template void* choose_bias_act_kernel (const bias_act_kernel_params& p); +template void* choose_bias_act_kernel (const bias_act_kernel_params& p); + +//------------------------------------------------------------------------ diff --git a/third_party/stylegan2_official_ops/bias_act.h b/third_party/stylegan2_official_ops/bias_act.h new file mode 100644 index 0000000000000000000000000000000000000000..a32187e1fb7e3bae509d4eceaf900866866875a4 --- /dev/null +++ b/third_party/stylegan2_official_ops/bias_act.h @@ -0,0 +1,38 @@ +// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +//------------------------------------------------------------------------ +// CUDA kernel parameters. + +struct bias_act_kernel_params +{ + const void* x; // [sizeX] + const void* b; // [sizeB] or NULL + const void* xref; // [sizeX] or NULL + const void* yref; // [sizeX] or NULL + const void* dy; // [sizeX] or NULL + void* y; // [sizeX] + + int grad; + int act; + float alpha; + float gain; + float clamp; + + int sizeX; + int sizeB; + int stepB; + int loopX; +}; + +//------------------------------------------------------------------------ +// CUDA kernel selection. + +template void* choose_bias_act_kernel(const bias_act_kernel_params& p); + +//------------------------------------------------------------------------ diff --git a/third_party/stylegan2_official_ops/bias_act.py b/third_party/stylegan2_official_ops/bias_act.py new file mode 100644 index 0000000000000000000000000000000000000000..b94dca1fb0a7f3bc13dce952d8e97a211ec94a88 --- /dev/null +++ b/third_party/stylegan2_official_ops/bias_act.py @@ -0,0 +1,227 @@ +# python3.7 + +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Custom ops to fuse bias and activation as one operator, which is efficient. + +Please refer to https://github.com/NVlabs/stylegan2-ada-pytorch +""" + +# pylint: disable=line-too-long +# pylint: disable=missing-class-docstring +# pylint: disable=global-statement +# pylint: disable=bare-except + +import os +import warnings +import traceback +from easydict import EasyDict +import numpy as np +import torch + +from . import custom_ops +from . 
import misc + +#---------------------------------------------------------------------------- + +activation_funcs = { + 'linear': EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False), + 'relu': EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False), + 'lrelu': EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False), + 'tanh': EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True), + 'sigmoid': EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True), + 'elu': EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True), + 'selu': EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True), + 'softplus': EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True), + 'swish': EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True), +} + +#---------------------------------------------------------------------------- + +_inited = False +_plugin = None +_null_tensor = torch.empty([0]) + +def _init(): + global _inited, _plugin + if not _inited: + _inited = True + sources = ['bias_act.cpp', 'bias_act.cu'] + sources = [os.path.join(os.path.dirname(__file__), s) for s in sources] + try: + _plugin = custom_ops.get_plugin('bias_act_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math']) + except: + warnings.warn('Failed to build CUDA kernels for bias_act. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc()) + return _plugin is not None + +#---------------------------------------------------------------------------- + +def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'): + r"""Fused bias and activation function. + + Adds bias `b` to activation tensor `x`, evaluates activation function `act`, + and scales the result by `gain`. Each of the steps is optional. In most cases, + the fused op is considerably more efficient than performing the same calculation + using standard PyTorch ops. It supports first and second order gradients, + but not third order gradients. + + Args: + x: Input activation tensor. Can be of any shape. + b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type + as `x`. The shape must be known, and it must match the dimension of `x` + corresponding to `dim`. + dim: The dimension in `x` corresponding to the elements of `b`. + The value of `dim` is ignored if `b` is not specified. + act: Name of the activation function to evaluate, or `"linear"` to disable. + Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc. + See `activation_funcs` for a full list. `None` is not allowed. + alpha: Shape parameter for the activation function, or `None` to use the default. + gain: Scaling factor for the output tensor, or `None` to use default. + See `activation_funcs` for the default scaling of each activation function. + If unsure, consider specifying 1. + clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable + the clamping (default). + impl: Name of the implementation to use. 
Can be `"ref"` or `"cuda"` (default). + + Returns: + Tensor of the same shape and datatype as `x`. + """ + assert isinstance(x, torch.Tensor) + assert impl in ['ref', 'cuda'] + if impl == 'cuda' and x.device.type == 'cuda' and _init(): + return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b) + return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp) + +#---------------------------------------------------------------------------- + +@misc.profiled_function +def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None): + """Slow reference implementation of `bias_act()` using standard TensorFlow ops. + """ + assert isinstance(x, torch.Tensor) + assert clamp is None or clamp >= 0 + spec = activation_funcs[act] + alpha = float(alpha if alpha is not None else spec.def_alpha) + gain = float(gain if gain is not None else spec.def_gain) + clamp = float(clamp if clamp is not None else -1) + + # Add bias. + if b is not None: + assert isinstance(b, torch.Tensor) and b.ndim == 1 + assert 0 <= dim < x.ndim + assert b.shape[0] == x.shape[dim] + x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)]) + + # Evaluate activation function. + alpha = float(alpha) + x = spec.func(x, alpha=alpha) + + # Scale by gain. + gain = float(gain) + if gain != 1: + x = x * gain + + # Clamp. + if clamp >= 0: + x = x.clamp(-clamp, clamp) + return x + +#---------------------------------------------------------------------------- + +_bias_act_cuda_cache = dict() + +def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None): + """Fast CUDA implementation of `bias_act()` using custom ops. + """ + # Parse arguments. + assert clamp is None or clamp >= 0 + spec = activation_funcs[act] + alpha = float(alpha if alpha is not None else spec.def_alpha) + gain = float(gain if gain is not None else spec.def_gain) + clamp = float(clamp if clamp is not None else -1) + + # Lookup from cache. + key = (dim, act, alpha, gain, clamp) + if key in _bias_act_cuda_cache: + return _bias_act_cuda_cache[key] + + # Forward op. + class BiasActCuda(torch.autograd.Function): + @staticmethod + def forward(ctx, x, b): # pylint: disable=arguments-differ + ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride()[1] == 1 else torch.contiguous_format + x = x.contiguous(memory_format=ctx.memory_format) + b = b.contiguous() if b is not None else _null_tensor + y = x + if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor: + y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp) + ctx.save_for_backward( + x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, + b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, + y if 'y' in spec.ref else _null_tensor) + return y + + @staticmethod + def backward(ctx, dy): # pylint: disable=arguments-differ + dy = dy.contiguous(memory_format=ctx.memory_format) + x, b, y = ctx.saved_tensors + dx = None + db = None + + if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: + dx = dy + if act != 'linear' or gain != 1 or clamp >= 0: + dx = BiasActCudaGrad.apply(dy, x, b, y) + + if ctx.needs_input_grad[1]: + db = dx.sum([i for i in range(dx.ndim) if i != dim]) + + return dx, db + + # Backward op. 
+ class BiasActCudaGrad(torch.autograd.Function): + @staticmethod + def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ + ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride()[1] == 1 else torch.contiguous_format + dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp) + ctx.save_for_backward( + dy if spec.has_2nd_grad else _null_tensor, + x, b, y) + return dx + + @staticmethod + def backward(ctx, d_dx): # pylint: disable=arguments-differ + d_dx = d_dx.contiguous(memory_format=ctx.memory_format) + dy, x, b, y = ctx.saved_tensors + d_dy = None + d_x = None + d_b = None + d_y = None + + if ctx.needs_input_grad[0]: + d_dy = BiasActCudaGrad.apply(d_dx, x, b, y) + + if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]): + d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp) + + if spec.has_2nd_grad and ctx.needs_input_grad[2]: + d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim]) + + return d_dy, d_x, d_b, d_y + + # Add to cache. + _bias_act_cuda_cache[key] = BiasActCuda + return BiasActCuda + +#---------------------------------------------------------------------------- + +# pylint: enable=line-too-long +# pylint: enable=missing-class-docstring +# pylint: enable=global-statement +# pylint: enable=bare-except diff --git a/third_party/stylegan2_official_ops/conv2d_gradfix.py b/third_party/stylegan2_official_ops/conv2d_gradfix.py new file mode 100644 index 0000000000000000000000000000000000000000..f2872868ff2cdfa72917ea08e43dfe8dbdc76b83 --- /dev/null +++ b/third_party/stylegan2_official_ops/conv2d_gradfix.py @@ -0,0 +1,191 @@ +# python3.7 + +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Custom replacement for convolution operators. + +Operators in this file support arbitrarily high order gradients with zero +performance penalty. Please set `impl` as `cuda` to use faster customized +operators, OR as `ref` to use native `torch.nn.functional.conv2d` and +`torch.nn.functional.conv_transpose2d`. + +Please refer to https://github.com/NVlabs/stylegan2-ada-pytorch +""" + +# pylint: disable=redefined-builtin +# pylint: disable=arguments-differ +# pylint: disable=protected-access +# pylint: disable=line-too-long +# pylint: disable=global-statement +# pylint: disable=missing-class-docstring +# pylint: disable=missing-function-docstring + +import warnings +import contextlib +import torch + +from distutils.version import LooseVersion + +enabled = True # Enable the custom op by setting this to true. +weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights. 
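+# `enabled` switches the custom gradfix path on or off globally, while
+# `weight_gradients_disabled` is toggled by the no_weight_gradients() context
+# manager below so that Conv2d.backward() skips the grad_weight computation
+# when only input gradients are needed (e.g. for a regularization pass).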
+ +@contextlib.contextmanager +def no_weight_gradients(): + global weight_gradients_disabled + old = weight_gradients_disabled + weight_gradients_disabled = True + yield + weight_gradients_disabled = old + +#---------------------------------------------------------------------------- + +def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, impl='cuda'): + if impl == 'cuda' and _should_use_custom_op(input): + return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias) + return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups) + +def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1, impl='cuda'): + if impl == 'cuda' and _should_use_custom_op(input): + return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias) + return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation) + +#---------------------------------------------------------------------------- + +def _should_use_custom_op(input): + assert isinstance(input, torch.Tensor) + if (not enabled) or (not torch.backends.cudnn.enabled): + return False + if input.device.type != 'cuda': + return False + if LooseVersion(torch.__version__) >= LooseVersion('1.7.0'): + return True + warnings.warn(f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().') + return False + +def _tuple_of_ints(xs, ndim): + xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim + assert len(xs) == ndim + assert all(isinstance(x, int) for x in xs) + return xs + +#---------------------------------------------------------------------------- + +_conv2d_gradfix_cache = dict() + +def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups): + # Parse arguments. + ndim = 2 + weight_shape = tuple(weight_shape) + stride = _tuple_of_ints(stride, ndim) + padding = _tuple_of_ints(padding, ndim) + output_padding = _tuple_of_ints(output_padding, ndim) + dilation = _tuple_of_ints(dilation, ndim) + + # Lookup from cache. + key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups) + if key in _conv2d_gradfix_cache: + return _conv2d_gradfix_cache[key] + + # Validate arguments. + assert groups >= 1 + assert len(weight_shape) == ndim + 2 + assert all(stride[i] >= 1 for i in range(ndim)) + assert all(padding[i] >= 0 for i in range(ndim)) + assert all(dilation[i] >= 0 for i in range(ndim)) + if not transpose: + assert all(output_padding[i] == 0 for i in range(ndim)) + else: # transpose + assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim)) + + # Helpers. + common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups) + def calc_output_padding(input_shape, output_shape): + if transpose: + return [0, 0] + return [ + input_shape[i + 2] + - (output_shape[i + 2] - 1) * stride[i] + - (1 - 2 * padding[i]) + - dilation[i] * (weight_shape[i + 2] - 1) + for i in range(ndim) + ] + + # Forward & backward. 
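+    # The gradient w.r.t. the input of a strided convolution is a transposed
+    # convolution (and vice versa), so backward() re-enters _conv2d_gradfix
+    # with `transpose` flipped and the output_padding computed by
+    # calc_output_padding() above, recovering the exact input shape.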
+ class Conv2d(torch.autograd.Function): + @staticmethod + def forward(ctx, input, weight, bias): + assert weight.shape == weight_shape + if not transpose: + output = torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs) + else: # transpose + output = torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs) + ctx.save_for_backward(input, weight) + return output + + @staticmethod + def backward(ctx, grad_output): + input, weight = ctx.saved_tensors + grad_input = None + grad_weight = None + grad_bias = None + + if ctx.needs_input_grad[0]: + p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape) + grad_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, weight, None) + assert grad_input.shape == input.shape + + if ctx.needs_input_grad[1] and not weight_gradients_disabled: + grad_weight = Conv2dGradWeight.apply(grad_output, input) + assert grad_weight.shape == weight_shape + + if ctx.needs_input_grad[2]: + grad_bias = grad_output.sum([0, 2, 3]) + + return grad_input, grad_weight, grad_bias + + # Gradient with respect to the weights. + class Conv2dGradWeight(torch.autograd.Function): + @staticmethod + def forward(ctx, grad_output, input): + op = torch._C._jit_get_operation('aten::cudnn_convolution_backward_weight' if not transpose else 'aten::cudnn_convolution_transpose_backward_weight') + flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32] + grad_weight = op(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags) + assert grad_weight.shape == weight_shape + ctx.save_for_backward(grad_output, input) + return grad_weight + + @staticmethod + def backward(ctx, grad2_grad_weight): + grad_output, input = ctx.saved_tensors + grad2_grad_output = None + grad2_input = None + + if ctx.needs_input_grad[0]: + grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None) + assert grad2_grad_output.shape == grad_output.shape + + if ctx.needs_input_grad[1]: + p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape) + grad2_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, grad2_grad_weight, None) + assert grad2_input.shape == input.shape + + return grad2_grad_output, grad2_input + + _conv2d_gradfix_cache[key] = Conv2d + return Conv2d + +#---------------------------------------------------------------------------- + +# pylint: enable=redefined-builtin +# pylint: enable=arguments-differ +# pylint: enable=protected-access +# pylint: enable=line-too-long +# pylint: enable=global-statement +# pylint: enable=missing-class-docstring +# pylint: enable=missing-function-docstring diff --git a/third_party/stylegan2_official_ops/conv2d_resample.py b/third_party/stylegan2_official_ops/conv2d_resample.py new file mode 100644 index 0000000000000000000000000000000000000000..fb76aa245dd4b2c99f79f24c30403c1a1958c90b --- /dev/null +++ b/third_party/stylegan2_official_ops/conv2d_resample.py @@ -0,0 +1,168 @@ +# python3.7 + +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. 
Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""2D convolution with optional up/downsampling. + +Please refer to https://github.com/NVlabs/stylegan2-ada-pytorch +""" + +# pylint: disable=line-too-long + +import torch + +from . import misc +from . import conv2d_gradfix +from . import upfirdn2d +from .upfirdn2d import _parse_padding +from .upfirdn2d import _get_filter_size + +#---------------------------------------------------------------------------- + +def _get_weight_shape(w): + with misc.suppress_tracer_warnings(): # this value will be treated as a constant + shape = [int(sz) for sz in w.shape] + misc.assert_shape(w, shape) + return shape + +#---------------------------------------------------------------------------- + +def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True, impl='cuda'): + """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations. + """ + out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) + + # Flip weight if requested. + if not flip_weight: # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False). + w = w.flip([2, 3]) + + # Workaround performance pitfall in cuDNN 8.0.5, triggered when using + # 1x1 kernel + memory_format=channels_last + less than 64 channels. + if kw == 1 and kh == 1 and stride == 1 and padding in [0, [0, 0], (0, 0)] and not transpose: + if x.stride()[1] == 1 and min(out_channels, in_channels_per_group) < 64: + if out_channels <= 4 and groups == 1: + in_shape = x.shape + x = w.squeeze(3).squeeze(2) @ x.reshape([in_shape[0], in_channels_per_group, -1]) + x = x.reshape([in_shape[0], out_channels, in_shape[2], in_shape[3]]) + else: + x = x.to(memory_format=torch.contiguous_format) + w = w.to(memory_format=torch.contiguous_format) + x = conv2d_gradfix.conv2d(x, w, groups=groups, impl=impl) + return x.to(memory_format=torch.channels_last) + + # Otherwise => execute using conv2d_gradfix. + op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d + return op(x, w, stride=stride, padding=padding, groups=groups, impl=impl) + +#---------------------------------------------------------------------------- + +@misc.profiled_function +def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False, impl='cuda'): + r"""2D convolution with optional up/downsampling. + + Padding is performed only once at the beginning, not between the operations. + + Args: + x: Input tensor of shape + `[batch_size, in_channels, in_height, in_width]`. + w: Weight tensor of shape + `[out_channels, in_channels//groups, kernel_height, kernel_width]`. + f: Low-pass filter for up/downsampling. Must be prepared beforehand by + calling upfirdn2d.setup_filter(). None = identity (default). + up: Integer upsampling factor (default: 1). + down: Integer downsampling factor (default: 1). + padding: Padding with respect to the upsampled image. Can be a single number + or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + groups: Split input channels into N groups (default: 1). + flip_weight: False = convolution, True = correlation (default: True). + flip_filter: False = convolution, True = correlation (default: False). + impl: Implementation mode of customized ops. 
'ref' for native PyTorch + implementation, 'cuda' for `.cu` implementation + (default: 'cuda'). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. + """ + # Validate arguments. + assert isinstance(x, torch.Tensor) and (x.ndim == 4) + assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype) + assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32) + assert isinstance(up, int) and (up >= 1) + assert isinstance(down, int) and (down >= 1) + assert isinstance(groups, int) and (groups >= 1) + out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) + fw, fh = _get_filter_size(f) + px0, px1, py0, py1 = _parse_padding(padding) + + # Adjust padding to account for up/downsampling. + if up > 1: + px0 += (fw + up - 1) // 2 + px1 += (fw - up) // 2 + py0 += (fh + up - 1) // 2 + py1 += (fh - up) // 2 + if down > 1: + px0 += (fw - down + 1) // 2 + px1 += (fw - down) // 2 + py0 += (fh - down + 1) // 2 + py1 += (fh - down) // 2 + + # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve. + if kw == 1 and kh == 1 and (down > 1 and up == 1): + x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter, impl=impl) + x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight, impl=impl) + return x + + # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample. + if kw == 1 and kh == 1 and (up > 1 and down == 1): + x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight, impl=impl) + x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter, impl=impl) + return x + + # Fast path: downsampling only => use strided convolution. + if down > 1 and up == 1: + x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter, impl=impl) + x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight, impl=impl) + return x + + # Fast path: upsampling with optional downsampling => use transpose strided convolution. + if up > 1: + if groups == 1: + w = w.transpose(0, 1) + else: + w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw) + w = w.transpose(1, 2) + w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw) + px0 -= kw - 1 + px1 -= kw - up + py0 -= kh - 1 + py1 -= kh - up + pxt = max(min(-px0, -px1), 0) + pyt = max(min(-py0, -py1), 0) + x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight), impl=impl) + x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter, impl=impl) + if down > 1: + x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter, impl=impl) + return x + + # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d. + if up == 1 and down == 1: + if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0: + return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight, impl=impl) + + # Fallback: Generic reference implementation. 
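+    # Upsample and pad with upfirdn2d, run the plain (grouped) convolution,
+    # then low-pass filter and downsample if requested. Handles any parameter
+    # combination, just more slowly than the fused fast paths above.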
+ x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter, impl=impl) + x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight, impl=impl) + if down > 1: + x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter, impl=impl) + return x + +#---------------------------------------------------------------------------- + +# pylint: enable=line-too-long diff --git a/third_party/stylegan2_official_ops/custom_ops.py b/third_party/stylegan2_official_ops/custom_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..4b9a8ef3ec71d144eed7584378546d7ccc183748 --- /dev/null +++ b/third_party/stylegan2_official_ops/custom_ops.py @@ -0,0 +1,159 @@ +# python3.7 + +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Utility functions to setup customized operators. + +Please refer to https://github.com/NVlabs/stylegan2-ada-pytorch +""" + +# pylint: disable=line-too-long +# pylint: disable=missing-function-docstring +# pylint: disable=useless-suppression +# pylint: disable=inconsistent-quotes + +import os +import glob +import importlib +import hashlib +import shutil +from pathlib import Path + +import torch +from torch.utils.file_baton import FileBaton +import torch.utils.cpp_extension + +#---------------------------------------------------------------------------- +# Global options. + +verbosity = 'none' # Verbosity level: 'none', 'brief', 'full' + +#---------------------------------------------------------------------------- +# Internal helper funcs. + +def _find_compiler_bindir(): + patterns = [ + 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64', + 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64', + 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64', + 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin', + ] + for pattern in patterns: + matches = sorted(glob.glob(pattern)) + if len(matches): + return matches[-1] + return None + +def _find_compiler_bindir_posix(): + patterns = [ + '/usr/local/cuda/bin' + ] + for pattern in patterns: + matches = sorted(glob.glob(pattern)) + if len(matches): + return matches[-1] + return None + +#---------------------------------------------------------------------------- +# Main entry point for compiling and loading C++/CUDA plugins. + +_cached_plugins = dict() + +def get_plugin(module_name, sources, **build_kwargs): + assert verbosity in ['none', 'brief', 'full'] + + # Already cached? + if module_name in _cached_plugins: + return _cached_plugins[module_name] + + # Print status. + if verbosity == 'full': + print(f'Setting up PyTorch plugin "{module_name}"...') + elif verbosity == 'brief': + print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True) + + try: # pylint: disable=too-many-nested-blocks + # Make sure we can find the necessary compiler binaries. 
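+        # On Windows, the MSVC Hostx64 bin dir is appended to PATH when cl.exe
+        # is not already reachable; on POSIX the CUDA bin dir is appended so
+        # nvcc can be located (os.pathsep is the portable separator for PATH
+        # entries). A RuntimeError points at the helper to edit if no match is
+        # found.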
+ if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0: + compiler_bindir = _find_compiler_bindir() + if compiler_bindir is None: + raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".') + os.environ['PATH'] += ';' + compiler_bindir + + elif os.name == 'posix': + compiler_bindir = _find_compiler_bindir_posix() + if compiler_bindir is None: + raise RuntimeError(f'Could not find NVCC installation on this computer. Check _find_compiler_bindir_posix() in "{__file__}".') + os.environ['PATH'] += ';' + compiler_bindir + + # Compile and load. + verbose_build = (verbosity == 'full') + + # Incremental build md5sum trickery. Copies all the input source files + # into a cached build directory under a combined md5 digest of the input + # source files. Copying is done only if the combined digest has changed. + # This keeps input file timestamps and filenames the same as in previous + # extension builds, allowing for fast incremental rebuilds. + # + # This optimization is done only in case all the source files reside in + # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR + # environment variable is set (we take this as a signal that the user + # actually cares about this.) + source_dirs_set = set(os.path.dirname(source) for source in sources) + if len(source_dirs_set) == 1 and ('TORCH_EXTENSIONS_DIR' in os.environ): + all_source_files = sorted(list(x for x in Path(list(source_dirs_set)[0]).iterdir() if x.is_file())) + + # Compute a combined hash digest for all source files in the same + # custom op directory (usually .cu, .cpp, .py and .h files). + hash_md5 = hashlib.md5() + for src in all_source_files: + with open(src, 'rb') as f: + hash_md5.update(f.read()) + build_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access + digest_build_dir = os.path.join(build_dir, hash_md5.hexdigest()) + + if not os.path.isdir(digest_build_dir): + os.makedirs(digest_build_dir, exist_ok=True) + baton = FileBaton(os.path.join(digest_build_dir, 'lock')) + if baton.try_acquire(): + try: + for src in all_source_files: + shutil.copyfile(src, os.path.join(digest_build_dir, os.path.basename(src))) + finally: + baton.release() + else: + # Someone else is copying source files under the digest dir, + # wait until done and continue. + baton.wait() + digest_sources = [os.path.join(digest_build_dir, os.path.basename(x)) for x in sources] + torch.utils.cpp_extension.load(name=module_name, build_directory=build_dir, + verbose=verbose_build, sources=digest_sources, **build_kwargs) + else: + torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs) + module = importlib.import_module(module_name) + + except: + if verbosity == 'brief': + print('Failed!') + raise + + # Print status and add to cache. 
+ if verbosity == 'full': + print(f'Done setting up PyTorch plugin "{module_name}".') + elif verbosity == 'brief': + print('Done.') + _cached_plugins[module_name] = module + return module + +#---------------------------------------------------------------------------- + +# pylint: enable=line-too-long +# pylint: enable=missing-function-docstring +# pylint: enable=useless-suppression +# pylint: enable=inconsistent-quotes diff --git a/third_party/stylegan2_official_ops/fma.py b/third_party/stylegan2_official_ops/fma.py new file mode 100644 index 0000000000000000000000000000000000000000..7304d85825d16612eec488242b220c2dbd83b6d7 --- /dev/null +++ b/third_party/stylegan2_official_ops/fma.py @@ -0,0 +1,73 @@ +# python3.7 + +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Fused multiply-add, with slightly faster gradients than `torch.addcmul()`. + +Please refer to https://github.com/NVlabs/stylegan2-ada-pytorch +""" + +# pylint: disable=line-too-long +# pylint: disable=missing-function-docstring + +import torch + +#---------------------------------------------------------------------------- + +def fma(a, b, c, impl='cuda'): # => a * b + c + if impl == 'cuda': + return _FusedMultiplyAdd.apply(a, b, c) + return torch.addcmul(c, a, b) + +#---------------------------------------------------------------------------- + +class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c + @staticmethod + def forward(ctx, a, b, c): # pylint: disable=arguments-differ + out = torch.addcmul(c, a, b) + ctx.save_for_backward(a, b) + ctx.c_shape = c.shape + return out + + @staticmethod + def backward(ctx, dout): # pylint: disable=arguments-differ + a, b = ctx.saved_tensors + c_shape = ctx.c_shape + da = None + db = None + dc = None + + if ctx.needs_input_grad[0]: + da = _unbroadcast(dout * b, a.shape) + + if ctx.needs_input_grad[1]: + db = _unbroadcast(dout * a, b.shape) + + if ctx.needs_input_grad[2]: + dc = _unbroadcast(dout, c_shape) + + return da, db, dc + +#---------------------------------------------------------------------------- + +def _unbroadcast(x, shape): + extra_dims = x.ndim - len(shape) + assert extra_dims >= 0 + dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)] + if len(dim): + x = x.sum(dim=dim, keepdim=True) + if extra_dims: + x = x.reshape(-1, *x.shape[extra_dims+1:]) + assert x.shape == shape + return x + +#---------------------------------------------------------------------------- + +# pylint: enable=line-too-long +# pylint: enable=missing-function-docstring diff --git a/third_party/stylegan2_official_ops/grid_sample_gradfix.py b/third_party/stylegan2_official_ops/grid_sample_gradfix.py new file mode 100644 index 0000000000000000000000000000000000000000..516e61eccdbfa087436853c85394488fd96dad23 --- /dev/null +++ b/third_party/stylegan2_official_ops/grid_sample_gradfix.py @@ -0,0 +1,99 @@ +# python3.7 + +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. 
Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Custom replacement for `torch.nn.functional.grid_sample`. + +This is useful for differentiable augmentation. This customized operator +supports arbitrarily high order gradients between the input and output. Only +works on 2D images and assumes `mode=bilinear`, `padding_mode=zeros`, and +`align_corners=False`. + +Please refer to https://github.com/NVlabs/stylegan2-ada-pytorch +""" + +# pylint: disable=redefined-builtin +# pylint: disable=arguments-differ +# pylint: disable=protected-access +# pylint: disable=line-too-long +# pylint: disable=missing-function-docstring + +import warnings +import torch +from distutils.version import LooseVersion + +#---------------------------------------------------------------------------- + +enabled = True # Enable the custom op by setting this to true. + +#---------------------------------------------------------------------------- + +def grid_sample(input, grid, impl='cuda'): + if impl == 'cuda' and _should_use_custom_op(): + return _GridSample2dForward.apply(input, grid) + return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) + +#---------------------------------------------------------------------------- + +def _should_use_custom_op(): + if not enabled: + return False + if LooseVersion(torch.__version__) >= LooseVersion('1.7.0'): + return True + warnings.warn(f'grid_sample_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.grid_sample().') + return False + +#---------------------------------------------------------------------------- + +class _GridSample2dForward(torch.autograd.Function): + @staticmethod + def forward(ctx, input, grid): + assert input.ndim == 4 + assert grid.ndim == 4 + output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) + ctx.save_for_backward(input, grid) + return output + + @staticmethod + def backward(ctx, grad_output): + input, grid = ctx.saved_tensors + grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid) + return grad_input, grad_grid + +#---------------------------------------------------------------------------- + +class _GridSample2dBackward(torch.autograd.Function): + @staticmethod + def forward(ctx, grad_output, input, grid): + op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward') + grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False) + ctx.save_for_backward(grid) + return grad_input, grad_grid + + @staticmethod + def backward(ctx, grad2_grad_input, grad2_grad_grid): + _ = grad2_grad_grid # unused + grid, = ctx.saved_tensors + grad2_grad_output = None + grad2_input = None + grad2_grid = None + + if ctx.needs_input_grad[0]: + grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid) + + assert not ctx.needs_input_grad[2] + return grad2_grad_output, grad2_input, grad2_grid + +#---------------------------------------------------------------------------- + +# pylint: enable=redefined-builtin +# pylint: enable=arguments-differ +# pylint: enable=protected-access +# pylint: enable=line-too-long +# pylint: enable=missing-function-docstring diff --git a/third_party/stylegan2_official_ops/misc.py b/third_party/stylegan2_official_ops/misc.py new file mode 100644 index 
0000000000000000000000000000000000000000..7973619f8db5a41f18ef42c83c4c5e5e013e7ff7 --- /dev/null +++ b/third_party/stylegan2_official_ops/misc.py @@ -0,0 +1,281 @@ +# python3.7 + +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Misc functions for customized operations. + +Please refer to https://github.com/NVlabs/stylegan2-ada-pytorch +""" + +# pylint: disable=line-too-long +# pylint: disable=missing-class-docstring +# pylint: disable=missing-function-docstring +# pylint: disable=use-maxsplit-arg +# pylint: disable=unnecessary-comprehension + +import re +import contextlib +import warnings +from easydict import EasyDict +import numpy as np +import torch + +#---------------------------------------------------------------------------- +# Cached construction of constant tensors. Avoids CPU=>GPU copy when the +# same constant is used multiple times. + +_constant_cache = dict() + +def constant(value, shape=None, dtype=None, device=None, memory_format=None): + value = np.asarray(value) + if shape is not None: + shape = tuple(shape) + if dtype is None: + dtype = torch.get_default_dtype() + if device is None: + device = torch.device('cpu') + if memory_format is None: + memory_format = torch.contiguous_format + + key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format) + tensor = _constant_cache.get(key, None) + if tensor is None: + tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device) + if shape is not None: + tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape)) + tensor = tensor.contiguous(memory_format=memory_format) + _constant_cache[key] = tensor + return tensor + +#---------------------------------------------------------------------------- +# Replace NaN/Inf with specified numerical values. + +try: + nan_to_num = torch.nan_to_num # 1.8.0a0 +except AttributeError: + def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin + assert isinstance(input, torch.Tensor) + if posinf is None: + posinf = torch.finfo(input.dtype).max + if neginf is None: + neginf = torch.finfo(input.dtype).min + assert nan == 0 + return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out) + +#---------------------------------------------------------------------------- +# Symbolic assert. + +try: + symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access +except AttributeError: + symbolic_assert = torch.Assert # 1.7.0 + +#---------------------------------------------------------------------------- +# Context manager to suppress known warnings in torch.jit.trace(). + +class suppress_tracer_warnings(warnings.catch_warnings): + def __enter__(self): + super().__enter__() + warnings.simplefilter('ignore', category=torch.jit.TracerWarning) + return self + +#---------------------------------------------------------------------------- +# Assert that the shape of a tensor matches the given list of integers. +# None indicates that the size of a dimension is allowed to vary. +# Performs symbolic assertion when used in torch.jit.trace(). 
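+# Usage sketch with illustrative values: assert_shape(x, [None, 512, 4, 4])
+# accepts any batch size but requires the remaining dimensions to match.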
+ +def assert_shape(tensor, ref_shape): + if tensor.ndim != len(ref_shape): + raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}') + for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)): + if ref_size is None: + pass + elif isinstance(ref_size, torch.Tensor): + with suppress_tracer_warnings(): # as_tensor results are registered as constants + symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}') + elif isinstance(size, torch.Tensor): + with suppress_tracer_warnings(): # as_tensor results are registered as constants + symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}') + elif size != ref_size: + raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}') + +#---------------------------------------------------------------------------- +# Function decorator that calls torch.autograd.profiler.record_function(). + +def profiled_function(fn): + def decorator(*args, **kwargs): + with torch.autograd.profiler.record_function(fn.__name__): + return fn(*args, **kwargs) + decorator.__name__ = fn.__name__ + return decorator + +#---------------------------------------------------------------------------- +# Sampler for torch.utils.data.DataLoader that loops over the dataset +# indefinitely, shuffling items as it goes. + +class InfiniteSampler(torch.utils.data.Sampler): + def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5): + assert len(dataset) > 0 + assert num_replicas > 0 + assert 0 <= rank < num_replicas + assert 0 <= window_size <= 1 + super().__init__(dataset) + self.dataset = dataset + self.rank = rank + self.num_replicas = num_replicas + self.shuffle = shuffle + self.seed = seed + self.window_size = window_size + + def __iter__(self): + order = np.arange(len(self.dataset)) + rnd = None + window = 0 + if self.shuffle: + rnd = np.random.RandomState(self.seed) + rnd.shuffle(order) + window = int(np.rint(order.size * self.window_size)) + + idx = 0 + while True: + i = idx % order.size + if idx % self.num_replicas == self.rank: + yield order[i] + if window >= 2: + j = (i - rnd.randint(window)) % order.size + order[i], order[j] = order[j], order[i] + idx += 1 + +#---------------------------------------------------------------------------- +# Utilities for operating with torch.nn.Module parameters and buffers. + +def params_and_buffers(module): + assert isinstance(module, torch.nn.Module) + return list(module.parameters()) + list(module.buffers()) + +def named_params_and_buffers(module): + assert isinstance(module, torch.nn.Module) + return list(module.named_parameters()) + list(module.named_buffers()) + +def copy_params_and_buffers(src_module, dst_module, require_all=False): + assert isinstance(src_module, torch.nn.Module) + assert isinstance(dst_module, torch.nn.Module) + src_tensors = {name: tensor for name, tensor in named_params_and_buffers(src_module)} + for name, tensor in named_params_and_buffers(dst_module): + assert (name in src_tensors) or (not require_all) + if name in src_tensors: + tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad) + +#---------------------------------------------------------------------------- +# Context manager for easily enabling/disabling DistributedDataParallel +# synchronization. 
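+# Usage sketch (names are illustrative):
+#     with ddp_sync(ddp_module, sync=is_last_accumulation_step):
+#         loss.backward()
+# Gradients are all-reduced only when `sync` is True; otherwise the module's
+# no_sync() context suppresses the collective for that backward pass.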
+ +@contextlib.contextmanager +def ddp_sync(module, sync): + assert isinstance(module, torch.nn.Module) + if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel): + yield + else: + with module.no_sync(): + yield + +#---------------------------------------------------------------------------- +# Check DistributedDataParallel consistency across processes. + +def check_ddp_consistency(module, ignore_regex=None): + assert isinstance(module, torch.nn.Module) + for name, tensor in named_params_and_buffers(module): + fullname = type(module).__name__ + '.' + name + if ignore_regex is not None and re.fullmatch(ignore_regex, fullname): + continue + tensor = tensor.detach() + other = tensor.clone() + torch.distributed.broadcast(tensor=other, src=0) + assert (nan_to_num(tensor) == nan_to_num(other)).all(), fullname + +#---------------------------------------------------------------------------- +# Print summary table of module hierarchy. + +def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True): + assert isinstance(module, torch.nn.Module) + assert not isinstance(module, torch.jit.ScriptModule) + assert isinstance(inputs, (tuple, list)) + + # Register hooks. + entries = [] + nesting = [0] + def pre_hook(_mod, _inputs): + nesting[0] += 1 + def post_hook(mod, _inputs, outputs): + nesting[0] -= 1 + if nesting[0] <= max_nesting: + outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs] + outputs = [t for t in outputs if isinstance(t, torch.Tensor)] + entries.append(EasyDict(mod=mod, outputs=outputs)) + hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()] + hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()] + + # Run module. + outputs = module(*inputs) + for hook in hooks: + hook.remove() + + # Identify unique outputs, parameters, and buffers. + tensors_seen = set() + for e in entries: + e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen] + e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen] + e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen] + tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs} + + # Filter out redundant entries. + if skip_redundant: + entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)] + + # Construct table. + rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']] + rows += [['---'] * len(rows[0])] + param_total = 0 + buffer_total = 0 + submodule_names = {mod: name for name, mod in module.named_modules()} + for e in entries: + name = '' if e.mod is module else submodule_names[e.mod] + param_size = sum(t.numel() for t in e.unique_params) + buffer_size = sum(t.numel() for t in e.unique_buffers) + output_shapes = [str(list(e.outputs[0].shape)) for t in e.outputs] + output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs] + rows += [[ + name + (':0' if len(e.outputs) >= 2 else ''), + str(param_size) if param_size else '-', + str(buffer_size) if buffer_size else '-', + (output_shapes + ['-'])[0], + (output_dtypes + ['-'])[0], + ]] + for idx in range(1, len(e.outputs)): + rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]] + param_total += param_size + buffer_total += buffer_size + rows += [['---'] * len(rows[0])] + rows += [['Total', str(param_total), str(buffer_total), '-', '-']] + + # Print table. 
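+    # Pad each cell to the widest entry in its column so the summary renders
+    # as a fixed-width table on stdout.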
+ widths = [max(len(cell) for cell in column) for column in zip(*rows)] + print() + for row in rows: + print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths))) + print() + return outputs + +#---------------------------------------------------------------------------- + +# pylint: enable=line-too-long +# pylint: enable=missing-class-docstring +# pylint: enable=missing-function-docstring +# pylint: enable=use-maxsplit-arg +# pylint: enable=unnecessary-comprehension diff --git a/third_party/stylegan2_official_ops/upfirdn2d.cpp b/third_party/stylegan2_official_ops/upfirdn2d.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2d7177fc60040751d20e9a8da0301fa3ab64968a --- /dev/null +++ b/third_party/stylegan2_official_ops/upfirdn2d.cpp @@ -0,0 +1,103 @@ +// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include +#include +#include +#include "upfirdn2d.h" + +//------------------------------------------------------------------------ + +static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain) +{ + // Validate arguments. + TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); + TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x"); + TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32"); + TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); + TORCH_CHECK(f.numel() <= INT_MAX, "f is too large"); + TORCH_CHECK(x.dim() == 4, "x must be rank 4"); + TORCH_CHECK(f.dim() == 2, "f must be rank 2"); + TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1"); + TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1"); + TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1"); + + // Create output tensor. + const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); + int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx; + int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy; + TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1"); + torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format()); + TORCH_CHECK(y.numel() <= INT_MAX, "output is too large"); + + // Initialize CUDA kernel parameters. + upfirdn2d_kernel_params p; + p.x = x.data_ptr(); + p.f = f.data_ptr(); + p.y = y.data_ptr(); + p.up = make_int2(upx, upy); + p.down = make_int2(downx, downy); + p.pad0 = make_int2(padx0, pady0); + p.flip = (flip) ? 
1 : 0; + p.gain = gain; + p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0)); + p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0)); + p.filterSize = make_int2((int)f.size(1), (int)f.size(0)); + p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0)); + p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0)); + p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0)); + p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z; + p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1; + + // Choose CUDA kernel. + upfirdn2d_kernel_spec spec; + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] + { + spec = choose_upfirdn2d_kernel(p); + }); + + // Set looping options. + p.loopMajor = (p.sizeMajor - 1) / 16384 + 1; + p.loopMinor = spec.loopMinor; + p.loopX = spec.loopX; + p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1; + p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1; + + // Compute grid size. + dim3 blockSize, gridSize; + if (spec.tileOutW < 0) // large + { + blockSize = dim3(4, 32, 1); + gridSize = dim3( + ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor, + (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1, + p.launchMajor); + } + else // small + { + blockSize = dim3(256, 1, 1); + gridSize = dim3( + ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor, + (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1, + p.launchMajor); + } + + // Launch CUDA kernel. + void* args[] = {&p}; + AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); + return y; +} + +//------------------------------------------------------------------------ + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("upfirdn2d", &upfirdn2d); +} + +//------------------------------------------------------------------------ diff --git a/third_party/stylegan2_official_ops/upfirdn2d.cu b/third_party/stylegan2_official_ops/upfirdn2d.cu new file mode 100644 index 0000000000000000000000000000000000000000..ebdd9879f4bb16fc57a23cbc81f9de8ef54e4916 --- /dev/null +++ b/third_party/stylegan2_official_ops/upfirdn2d.cu @@ -0,0 +1,350 @@ +// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include +#include "upfirdn2d.h" + +//------------------------------------------------------------------------ +// Helpers. + +template struct InternalType; +template <> struct InternalType { typedef double scalar_t; }; +template <> struct InternalType { typedef float scalar_t; }; +template <> struct InternalType { typedef float scalar_t; }; + +static __device__ __forceinline__ int floor_div(int a, int b) +{ + int t = 1 - a / b; + return (a + t * b) / b - t; +} + +//------------------------------------------------------------------------ +// Generic CUDA implementation for large filters. + +template static __global__ void upfirdn2d_kernel_large(upfirdn2d_kernel_params p) +{ + typedef typename InternalType::scalar_t scalar_t; + + // Calculate thread index. 
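+    // blockIdx.x/threadIdx.x enumerate (output row, minor-channel) pairs,
+    // blockIdx.y covers output columns in chunks of loopX * blockDim.y, and
+    // blockIdx.z steps through the major (batch/channel) dimension, so each
+    // thread accumulates a strip of output pixels.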
+ int minorBase = blockIdx.x * blockDim.x + threadIdx.x; + int outY = minorBase / p.launchMinor; + minorBase -= outY * p.launchMinor; + int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y; + int majorBase = blockIdx.z * p.loopMajor; + if (outXBase >= p.outSize.x | outY >= p.outSize.y | majorBase >= p.sizeMajor) + return; + + // Setup Y receptive field. + int midY = outY * p.down.y + p.up.y - 1 - p.pad0.y; + int inY = min(max(floor_div(midY, p.up.y), 0), p.inSize.y); + int h = min(max(floor_div(midY + p.filterSize.y, p.up.y), 0), p.inSize.y) - inY; + int filterY = midY + p.filterSize.y - (inY + 1) * p.up.y; + if (p.flip) + filterY = p.filterSize.y - 1 - filterY; + + // Loop over major, minor, and X. + for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++) + for (int minorIdx = 0, minor = minorBase; minorIdx < p.loopMinor & minor < p.sizeMinor; minorIdx++, minor += p.launchMinor) + { + int nc = major * p.sizeMinor + minor; + int n = nc / p.inSize.z; + int c = nc - n * p.inSize.z; + for (int loopX = 0, outX = outXBase; loopX < p.loopX & outX < p.outSize.x; loopX++, outX += blockDim.y) + { + // Setup X receptive field. + int midX = outX * p.down.x + p.up.x - 1 - p.pad0.x; + int inX = min(max(floor_div(midX, p.up.x), 0), p.inSize.x); + int w = min(max(floor_div(midX + p.filterSize.x, p.up.x), 0), p.inSize.x) - inX; + int filterX = midX + p.filterSize.x - (inX + 1) * p.up.x; + if (p.flip) + filterX = p.filterSize.x - 1 - filterX; + + // Initialize pointers. + const T* xp = &((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w]; + const float* fp = &p.f[filterX * p.filterStride.x + filterY * p.filterStride.y]; + int filterStepX = ((p.flip) ? p.up.x : -p.up.x) * p.filterStride.x; + int filterStepY = ((p.flip) ? p.up.y : -p.up.y) * p.filterStride.y; + + // Inner loop. + scalar_t v = 0; + for (int y = 0; y < h; y++) + { + for (int x = 0; x < w; x++) + { + v += (scalar_t)(*xp) * (scalar_t)(*fp); + xp += p.inStride.x; + fp += filterStepX; + } + xp += p.inStride.y - w * p.inStride.x; + fp += filterStepY - w * filterStepX; + } + + // Store result. + v *= p.gain; + ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v; + } + } +} + +//------------------------------------------------------------------------ +// Specialized CUDA implementation for small filters. + +template +static __global__ void upfirdn2d_kernel_small(upfirdn2d_kernel_params p) +{ + typedef typename InternalType::scalar_t scalar_t; + const int tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1; + const int tileInH = ((tileOutH - 1) * downy + filterH - 1) / upy + 1; + __shared__ volatile scalar_t sf[filterH][filterW]; + __shared__ volatile scalar_t sx[tileInH][tileInW][loopMinor]; + + // Calculate tile index. + int minorBase = blockIdx.x; + int tileOutY = minorBase / p.launchMinor; + minorBase -= tileOutY * p.launchMinor; + minorBase *= loopMinor; + tileOutY *= tileOutH; + int tileOutXBase = blockIdx.y * p.loopX * tileOutW; + int majorBase = blockIdx.z * p.loopMajor; + if (tileOutXBase >= p.outSize.x | tileOutY >= p.outSize.y | majorBase >= p.sizeMajor) + return; + + // Load filter (flipped). + for (int tapIdx = threadIdx.x; tapIdx < filterH * filterW; tapIdx += blockDim.x) + { + int fy = tapIdx / filterW; + int fx = tapIdx - fy * filterW; + scalar_t v = 0; + if (fx < p.filterSize.x & fy < p.filterSize.y) + { + int ffx = (p.flip) ? 
fx : p.filterSize.x - 1 - fx; + int ffy = (p.flip) ? fy : p.filterSize.y - 1 - fy; + v = (scalar_t)p.f[ffx * p.filterStride.x + ffy * p.filterStride.y]; + } + sf[fy][fx] = v; + } + + // Loop over major and X. + for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++) + { + int baseNC = major * p.sizeMinor + minorBase; + int n = baseNC / p.inSize.z; + int baseC = baseNC - n * p.inSize.z; + for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outSize.x; loopX++, tileOutX += tileOutW) + { + // Load input pixels. + int tileMidX = tileOutX * downx + upx - 1 - p.pad0.x; + int tileMidY = tileOutY * downy + upy - 1 - p.pad0.y; + int tileInX = floor_div(tileMidX, upx); + int tileInY = floor_div(tileMidY, upy); + __syncthreads(); + for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW * loopMinor; inIdx += blockDim.x) + { + int relC = inIdx; + int relInX = relC / loopMinor; + int relInY = relInX / tileInW; + relC -= relInX * loopMinor; + relInX -= relInY * tileInW; + int c = baseC + relC; + int inX = tileInX + relInX; + int inY = tileInY + relInY; + scalar_t v = 0; + if (inX >= 0 & inY >= 0 & inX < p.inSize.x & inY < p.inSize.y & c < p.inSize.z) + v = (scalar_t)((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w]; + sx[relInY][relInX][relC] = v; + } + + // Loop over output pixels. + __syncthreads(); + for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW * loopMinor; outIdx += blockDim.x) + { + int relC = outIdx; + int relOutX = relC / loopMinor; + int relOutY = relOutX / tileOutW; + relC -= relOutX * loopMinor; + relOutX -= relOutY * tileOutW; + int c = baseC + relC; + int outX = tileOutX + relOutX; + int outY = tileOutY + relOutY; + + // Setup receptive field. + int midX = tileMidX + relOutX * downx; + int midY = tileMidY + relOutY * downy; + int inX = floor_div(midX, upx); + int inY = floor_div(midY, upy); + int relInX = inX - tileInX; + int relInY = inY - tileInY; + int filterX = (inX + 1) * upx - midX - 1; // flipped + int filterY = (inY + 1) * upy - midY - 1; // flipped + + // Inner loop. + if (outX < p.outSize.x & outY < p.outSize.y & c < p.outSize.z) + { + scalar_t v = 0; + #pragma unroll + for (int y = 0; y < filterH / upy; y++) + #pragma unroll + for (int x = 0; x < filterW / upx; x++) + v += sx[relInY + y][relInX + x][relC] * sf[filterY + y * upy][filterX + x * upx]; + v *= p.gain; + ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v; + } + } + } + } +} + +//------------------------------------------------------------------------ +// CUDA kernel selection. 
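+// choose_upfirdn2d_kernel() starts from the generic large-filter kernel
+// (tileOutW = -1) and, when the memory layout (p.inStride.z == 1 indicates
+// channels_last), the up/down factors, and the filter tap counts match one
+// of the cases below, overrides it with a specialized small-filter kernel
+// whose tile size and loop counts are tuned for that configuration.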
+ +template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p) +{ + int s = p.inStride.z, fx = p.filterSize.x, fy = p.filterSize.y; + + upfirdn2d_kernel_spec spec = {(void*)upfirdn2d_kernel_large, -1,-1,1, 4}; // contiguous + if (s == 1) spec = {(void*)upfirdn2d_kernel_large, -1,-1,4, 1}; // channels_last + + if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous + { + if (fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + } + if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last + { + if (fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + } + if (s != 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous + { + if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + } + if (s == 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last + { + if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (fx <= 6 && fy <= 6 ) spec 
= {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + } + if (s != 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous + { + if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + } + if (s == 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last + { + if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + } + if (s != 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous + { + if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + } + if (s == 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last + { + if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + } + if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // contiguous + { + if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; + if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; + if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; + if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; + } + if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // channels_last + { + if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1}; + if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1}; + if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1}; + if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1}; + } + if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // contiguous + { + if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1}; + if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1}; + if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1}; + if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1}; + if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1}; + } + if (s == 1 && p.up.x 
== 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // channels_last + { + if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1}; + if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1}; + if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1}; + if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1}; + if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1}; + } + if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // contiguous + { + if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; + if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; + if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; + if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; + if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; + } + if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // channels_last + { + if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1}; + if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1}; + if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1}; + if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1}; + if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1}; + } + return spec; +} + +//------------------------------------------------------------------------ +// Template specializations. + +template upfirdn2d_kernel_spec choose_upfirdn2d_kernel (const upfirdn2d_kernel_params& p); +template upfirdn2d_kernel_spec choose_upfirdn2d_kernel (const upfirdn2d_kernel_params& p); +template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p); + +//------------------------------------------------------------------------ diff --git a/third_party/stylegan2_official_ops/upfirdn2d.h b/third_party/stylegan2_official_ops/upfirdn2d.h new file mode 100644 index 0000000000000000000000000000000000000000..c9e2032bcac9d2abde7a75eea4d812da348afadd --- /dev/null +++ b/third_party/stylegan2_official_ops/upfirdn2d.h @@ -0,0 +1,59 @@ +// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include + +//------------------------------------------------------------------------ +// CUDA kernel parameters. + +struct upfirdn2d_kernel_params +{ + const void* x; + const float* f; + void* y; + + int2 up; + int2 down; + int2 pad0; + int flip; + float gain; + + int4 inSize; // [width, height, channel, batch] + int4 inStride; + int2 filterSize; // [width, height] + int2 filterStride; + int4 outSize; // [width, height, channel, batch] + int4 outStride; + int sizeMinor; + int sizeMajor; + + int loopMinor; + int loopMajor; + int loopX; + int launchMinor; + int launchMajor; +}; + +//------------------------------------------------------------------------ +// CUDA kernel specialization. 
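+// Describes the kernel chosen by choose_upfirdn2d_kernel(): the kernel
+// function pointer, the output tile handled per block (tileOutW < 0 marks
+// the generic large-filter kernel), and the loop counts over the minor and
+// X axes used when computing the launch configuration.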
+ +struct upfirdn2d_kernel_spec +{ + void* kernel; + int tileOutW; + int tileOutH; + int loopMinor; + int loopX; +}; + +//------------------------------------------------------------------------ +// CUDA kernel selection. + +template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p); + +//------------------------------------------------------------------------ diff --git a/third_party/stylegan2_official_ops/upfirdn2d.py b/third_party/stylegan2_official_ops/upfirdn2d.py new file mode 100644 index 0000000000000000000000000000000000000000..8447210b41595023a580c50654e2e29557b0fe58 --- /dev/null +++ b/third_party/stylegan2_official_ops/upfirdn2d.py @@ -0,0 +1,401 @@ +# python3.7 + +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Custom operators for efficient resampling of 2D images. + +`upfirdn` means executing upsampling, FIR filtering, downsampling in sequence. + +Please refer to https://github.com/NVlabs/stylegan2-ada-pytorch +""" + +# pylint: disable=line-too-long +# pylint: disable=missing-class-docstring +# pylint: disable=global-statement +# pylint: disable=bare-except + +import os +import warnings +import traceback +import numpy as np +import torch + +from . import custom_ops +from . import misc +from . import conv2d_gradfix + +#---------------------------------------------------------------------------- + +_inited = False +_plugin = None + +def _init(): + global _inited, _plugin + if not _inited: + sources = ['upfirdn2d.cpp', 'upfirdn2d.cu'] + sources = [os.path.join(os.path.dirname(__file__), s) for s in sources] + try: + _plugin = custom_ops.get_plugin('upfirdn2d_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math']) + except: + warnings.warn('Failed to build CUDA kernels for upfirdn2d. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc()) + return _plugin is not None + +def _parse_scaling(scaling): + if isinstance(scaling, int): + scaling = [scaling, scaling] + assert isinstance(scaling, (list, tuple)) + assert all(isinstance(x, int) for x in scaling) + sx, sy = scaling + assert sx >= 1 and sy >= 1 + return sx, sy + +def _parse_padding(padding): + if isinstance(padding, int): + padding = [padding, padding] + assert isinstance(padding, (list, tuple)) + assert all(isinstance(x, int) for x in padding) + if len(padding) == 2: + padx, pady = padding + padding = [padx, padx, pady, pady] + padx0, padx1, pady0, pady1 = padding + return padx0, padx1, pady0, pady1 + +def _get_filter_size(f): + if f is None: + return 1, 1 + assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] + fw = f.shape[-1] + fh = f.shape[0] + with misc.suppress_tracer_warnings(): + fw = int(fw) + fh = int(fh) + misc.assert_shape(f, [fh, fw][:f.ndim]) + assert fw >= 1 and fh >= 1 + return fw, fh + +#---------------------------------------------------------------------------- + +def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None): + r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`. 
+ + Args: + f: Torch tensor, numpy array, or python list of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), + `[]` (impulse), or + `None` (identity). + device: Result device (default: cpu). + normalize: Normalize the filter so that it retains the magnitude + for constant input signal (DC)? (default: True). + flip_filter: Flip the filter? (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + separable: Return a separable filter? (default: select automatically). + + Returns: + Float32 tensor of the shape + `[filter_height, filter_width]` (non-separable) or + `[filter_taps]` (separable). + """ + # Validate. + if f is None: + f = 1 + f = torch.as_tensor(f, dtype=torch.float32) + assert f.ndim in [0, 1, 2] + assert f.numel() > 0 + if f.ndim == 0: + f = f[np.newaxis] + + # Separable? + if separable is None: + separable = (f.ndim == 1 and f.numel() >= 8) + if f.ndim == 1 and not separable: + f = f.ger(f) + assert f.ndim == (1 if separable else 2) + + # Apply normalize, flip, gain, and device. + if normalize: + f /= f.sum() + if flip_filter: + f = f.flip(list(range(f.ndim))) + f = f * (gain ** (f.ndim / 2)) + f = f.to(device=device) + return f + +#---------------------------------------------------------------------------- + +def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'): + r"""Pad, upsample, filter, and downsample a batch of 2D images. + + Performs the following sequence of operations for each channel: + + 1. Upsample the image by inserting N-1 zeros after each pixel (`up`). + + 2. Pad the image with the specified number of zeros on each side (`padding`). + Negative padding corresponds to cropping the image. + + 3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it + so that the footprint of all output pixels lies within the input image. + + 4. Downsample the image by keeping every Nth pixel (`down`). + + This sequence of operations bears close resemblance to scipy.signal.upfirdn(). + The fused op is considerably more efficient than performing the same calculation + using standard PyTorch ops. It supports gradients of arbitrary order. + + Args: + x: Float32/float64/float16 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. + f: Float32 FIR filter of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), or + `None` (identity). + up: Integer upsampling factor. Can be a single int or a list/tuple + `[x, y]` (default: 1). + down: Integer downsampling factor. Can be a single int or a list/tuple + `[x, y]` (default: 1). + padding: Padding with respect to the upsampled image. Can be a single number + or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + flip_filter: False = convolution, True = correlation (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. 
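+
+    Example:
+        A minimal sketch: upsample a `[N, C, H, W]` tensor `x` by 2x using a
+        4-tap binomial filter (an arbitrary illustrative choice). This
+        matches what `upsample2d(x, f, up=2)` does internally.
+
+            f = setup_filter([1, 3, 3, 1])  # normalized 4x4 filter
+            y = upfirdn2d(x, f, up=2, padding=[2, 1, 2, 1], gain=4)
+            # y has shape [batch_size, num_channels, 2 * in_height, 2 * in_width]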
+ """ + assert isinstance(x, torch.Tensor) + assert impl in ['ref', 'cuda'] + if impl == 'cuda' and x.device.type == 'cuda' and _init(): + return _upfirdn2d_cuda(up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain).apply(x, f) + return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain) + +#---------------------------------------------------------------------------- + +@misc.profiled_function +def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1): + """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops. + """ + # Validate arguments. + assert isinstance(x, torch.Tensor) and x.ndim == 4 + if f is None: + f = torch.ones([1, 1], dtype=torch.float32, device=x.device) + assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] + assert f.dtype == torch.float32 and not f.requires_grad + batch_size, num_channels, in_height, in_width = x.shape + upx, upy = _parse_scaling(up) + downx, downy = _parse_scaling(down) + padx0, padx1, pady0, pady1 = _parse_padding(padding) + + # Upsample by inserting zeros. + x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1]) + x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1]) + x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx]) + + # Pad or crop. + x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)]) + x = x[:, :, max(-pady0, 0) : x.shape[2] - max(-pady1, 0), max(-padx0, 0) : x.shape[3] - max(-padx1, 0)] + + # Setup filter. + f = f * (gain ** (f.ndim / 2)) + f = f.to(x.dtype) + if not flip_filter: + f = f.flip(list(range(f.ndim))) + + # Convolve with the filter. + f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim) + if f.ndim == 4: + x = conv2d_gradfix.conv2d(input=x, weight=f, groups=num_channels) + else: + x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels) + x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels) + + # Downsample by throwing away pixels. + x = x[:, :, ::downy, ::downx] + return x + +#---------------------------------------------------------------------------- + +_upfirdn2d_cuda_cache = dict() + +def _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1): + """Fast CUDA implementation of `upfirdn2d()` using custom ops. + """ + # Parse arguments. + upx, upy = _parse_scaling(up) + downx, downy = _parse_scaling(down) + padx0, padx1, pady0, pady1 = _parse_padding(padding) + + # Lookup from cache. + key = (upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain) + if key in _upfirdn2d_cuda_cache: + return _upfirdn2d_cuda_cache[key] + + # Forward op. 
+ class Upfirdn2dCuda(torch.autograd.Function): + @staticmethod + def forward(ctx, x, f): # pylint: disable=arguments-differ + assert isinstance(x, torch.Tensor) and x.ndim == 4 + if f is None: + f = torch.ones([1, 1], dtype=torch.float32, device=x.device) + assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] + y = x + if f.ndim == 2: + y = _plugin.upfirdn2d(y, f, upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain) + else: + y = _plugin.upfirdn2d(y, f.unsqueeze(0), upx, 1, downx, 1, padx0, padx1, 0, 0, flip_filter, np.sqrt(gain)) + y = _plugin.upfirdn2d(y, f.unsqueeze(1), 1, upy, 1, downy, 0, 0, pady0, pady1, flip_filter, np.sqrt(gain)) + ctx.save_for_backward(f) + ctx.x_shape = x.shape + return y + + @staticmethod + def backward(ctx, dy): # pylint: disable=arguments-differ + f, = ctx.saved_tensors + _, _, ih, iw = ctx.x_shape + _, _, oh, ow = dy.shape + fw, fh = _get_filter_size(f) + p = [ + fw - padx0 - 1, + iw * upx - ow * downx + padx0 - upx + 1, + fh - pady0 - 1, + ih * upy - oh * downy + pady0 - upy + 1, + ] + dx = None + df = None + + if ctx.needs_input_grad[0]: + dx = _upfirdn2d_cuda(up=down, down=up, padding=p, flip_filter=(not flip_filter), gain=gain).apply(dy, f) + + assert not ctx.needs_input_grad[1] + return dx, df + + # Add to cache. + _upfirdn2d_cuda_cache[key] = Upfirdn2dCuda + return Upfirdn2dCuda + +#---------------------------------------------------------------------------- + +def filter2d(x, f, padding=0, flip_filter=False, gain=1, impl='cuda'): + r"""Filter a batch of 2D images using the given 2D FIR filter. + + By default, the result is padded so that its shape matches the input. + User-specified padding is applied on top of that, with negative values + indicating cropping. Pixels outside the image are assumed to be zero. + + Args: + x: Float32/float64/float16 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. + f: Float32 FIR filter of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), or + `None` (identity). + padding: Padding with respect to the output. Can be a single number or a + list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + flip_filter: False = convolution, True = correlation (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. + """ + padx0, padx1, pady0, pady1 = _parse_padding(padding) + fw, fh = _get_filter_size(f) + p = [ + padx0 + fw // 2, + padx1 + (fw - 1) // 2, + pady0 + fh // 2, + pady1 + (fh - 1) // 2, + ] + return upfirdn2d(x, f, padding=p, flip_filter=flip_filter, gain=gain, impl=impl) + +#---------------------------------------------------------------------------- + +def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'): + r"""Upsample a batch of 2D images using the given 2D FIR filter. + + By default, the result is padded so that its shape is a multiple of the input. + User-specified padding is applied on top of that, with negative values + indicating cropping. Pixels outside the image are assumed to be zero. + + Args: + x: Float32/float64/float16 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. + f: Float32 FIR filter of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), or + `None` (identity). + up: Integer upsampling factor. 
Can be a single int or a list/tuple + `[x, y]` (default: 1). + padding: Padding with respect to the output. Can be a single number or a + list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + flip_filter: False = convolution, True = correlation (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. + """ + upx, upy = _parse_scaling(up) + padx0, padx1, pady0, pady1 = _parse_padding(padding) + fw, fh = _get_filter_size(f) + p = [ + padx0 + (fw + upx - 1) // 2, + padx1 + (fw - upx) // 2, + pady0 + (fh + upy - 1) // 2, + pady1 + (fh - upy) // 2, + ] + return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain*upx*upy, impl=impl) + +#---------------------------------------------------------------------------- + +def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'): + r"""Downsample a batch of 2D images using the given 2D FIR filter. + + By default, the result is padded so that its shape is a fraction of the input. + User-specified padding is applied on top of that, with negative values + indicating cropping. Pixels outside the image are assumed to be zero. + + Args: + x: Float32/float64/float16 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. + f: Float32 FIR filter of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), or + `None` (identity). + down: Integer downsampling factor. Can be a single int or a list/tuple + `[x, y]` (default: 1). + padding: Padding with respect to the input. Can be a single number or a + list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + flip_filter: False = convolution, True = correlation (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. + """ + downx, downy = _parse_scaling(down) + padx0, padx1, pady0, pady1 = _parse_padding(padding) + fw, fh = _get_filter_size(f) + p = [ + padx0 + (fw - downx + 1) // 2, + padx1 + (fw - downx) // 2, + pady0 + (fh - downy + 1) // 2, + pady1 + (fh - downy) // 2, + ] + return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl) + +#---------------------------------------------------------------------------- + +# pylint: enable=line-too-long +# pylint: enable=missing-class-docstring +# pylint: enable=global-statement +# pylint: enable=bare-except diff --git a/third_party/stylegan3_official_ops/README.md b/third_party/stylegan3_official_ops/README.md new file mode 100644 index 0000000000000000000000000000000000000000..417b50ba3dde490f4a4c5c6dfc6afbc28ba640d0 --- /dev/null +++ b/third_party/stylegan3_official_ops/README.md @@ -0,0 +1,30 @@ +# Operators for StyleGAN2 + +All files in this directory are borrowed from repository [stylegan3](https://github.com/NVlabs/stylegan3). Basically, these files implement customized operators, which are faster than the native operators from PyTorch, especially for second-derivative computation, including + +- `bias_act.bias_act()`: Fuse adding bias and then performing activation as one operator. +- `upfirdn2d.setup_filter()`: Set up the kernel used for filtering. 
+- `upfirdn2d.filter2d()`: Filtering a 2D feature map with a given kernel. +- `upfirdn2d.upsample2d()`: Upsampling a 2D feature map. +- `upfirdn2d.downsample2d()`: Downsampling a 2D feature map. +- `upfirdn2d.upfirdn2d()`: Upsampling, filtering, and then downsampling a 2D feature map. +- `filtered_lrelu.filtered_lrelu()`: Leaky ReLU layer, wrapped with upsampling and downsampling for anti-aliasing. +- `conv2d_gradfix.conv2d()`: Convolutional layer, supporting arbitrarily high order gradients and fixing the gradient when computing the penalty. +- `conv2d_gradfix.conv_transpose2d()`: Transposed convolutional layer, supporting arbitrarily high order gradients and fixing the gradient when computing the penalty. +- `conv2d_resample.conv2d_resample()`: Wraps `upfirdn2d()` and `conv2d()` (or `conv_transpose2d()`). This is not used in our network implementation (*i.e.*, `models/stylegan2_generator.py` and `models/stylegan2_discriminator.py`). + +We make the following slight modifications beyond disabling some lint warnings: + +- Line 24 of file `misc.py`: Use `EasyDict` from module `easydict` to replace that from `dnnlib` from [stylegan3](https://github.com/NVlabs/stylegan3). +- Line 36 of file `custom_ops.py`: Disable the log message when setting up customized operators. +- Line 54/109 of file `custom_ops.py`: Add the necessary CUDA compiler path. (***NOTE**: If your CUDA binary is not located at `/usr/local/cuda/bin`, please specify its path in the function `_find_compiler_bindir_posix()`.*) +- Line 21 of file `bias_act.py`: Use `EasyDict` from module `easydict` to replace that from `dnnlib` from [stylegan3](https://github.com/NVlabs/stylegan3). +- Line 162-165 of file `filtered_lrelu.py`: Change some implementations in `_filtered_lrelu_ref()` to use `ref`. +- Line 31 of file `grid_sample_gradfix.py`: Enable the customized grid sampling operator by default. +- Line 35 of file `grid_sample_gradfix.py`: Use `impl` to disable the customized grid sampling operator. +- Line 34 of file `conv2d_gradfix.py`: Enable the customized convolution operators by default. +- Line 48/53 of file `conv2d_gradfix.py`: Use `impl` to disable the customized convolution operators. +- Line 36/53 of file `conv2d_resample.py`: Use `impl` to disable the customized convolution operators. +- Line 23 of file `fma.py`: Use `impl` to disable the customized add-multiply operator. + +Please use `ref` or `cuda` to choose which implementation to use. `ref` refers to the native PyTorch operators, while `cuda` refers to the customized operators from the official repository. `cuda` is used by default.
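+
+As a minimal usage sketch (assuming a CUDA device is available, the extensions build successfully, and the repository root is on `PYTHONPATH` so that `third_party.stylegan3_official_ops` is importable):
+
+```python
+import torch
+from third_party.stylegan3_official_ops import bias_act, upfirdn2d
+
+x = torch.randn(4, 8, 32, 32, device='cuda')
+b = torch.zeros(8, device='cuda')
+f = upfirdn2d.setup_filter([1, 3, 3, 1], device='cuda')
+
+y = bias_act.bias_act(x, b, act='lrelu', impl='cuda')      # fused customized op
+y_ref = bias_act.bias_act(x, b, act='lrelu', impl='ref')   # native PyTorch fallback
+z = upfirdn2d.upsample2d(x, f, up=2, impl='cuda')          # -> [4, 8, 64, 64]
+```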
diff --git a/third_party/stylegan3_official_ops/__init__.py b/third_party/stylegan3_official_ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/third_party/stylegan3_official_ops/__pycache__/__init__.cpython-37.pyc b/third_party/stylegan3_official_ops/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c900b47166397d09c1ef519f715ad8f5ab374b4e Binary files /dev/null and b/third_party/stylegan3_official_ops/__pycache__/__init__.cpython-37.pyc differ diff --git a/third_party/stylegan3_official_ops/__pycache__/__init__.cpython-39.pyc b/third_party/stylegan3_official_ops/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..53854ad35c162700d29b3ed1cabc850106812679 Binary files /dev/null and b/third_party/stylegan3_official_ops/__pycache__/__init__.cpython-39.pyc differ diff --git a/third_party/stylegan3_official_ops/__pycache__/bias_act.cpython-37.pyc b/third_party/stylegan3_official_ops/__pycache__/bias_act.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d4c366aba73b50ebcca6815d3f6e8499e9a65687 Binary files /dev/null and b/third_party/stylegan3_official_ops/__pycache__/bias_act.cpython-37.pyc differ diff --git a/third_party/stylegan3_official_ops/__pycache__/bias_act.cpython-39.pyc b/third_party/stylegan3_official_ops/__pycache__/bias_act.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cba8f682cef84e2df9e1236d4007da9d6b3c9653 Binary files /dev/null and b/third_party/stylegan3_official_ops/__pycache__/bias_act.cpython-39.pyc differ diff --git a/third_party/stylegan3_official_ops/__pycache__/conv2d_gradfix.cpython-37.pyc b/third_party/stylegan3_official_ops/__pycache__/conv2d_gradfix.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f9c1ede3f92139e2e27ae47518dad1465f704a7b Binary files /dev/null and b/third_party/stylegan3_official_ops/__pycache__/conv2d_gradfix.cpython-37.pyc differ diff --git a/third_party/stylegan3_official_ops/__pycache__/conv2d_gradfix.cpython-39.pyc b/third_party/stylegan3_official_ops/__pycache__/conv2d_gradfix.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dab6c2bcac2c56c244fa6b9f7b614e66fc855738 Binary files /dev/null and b/third_party/stylegan3_official_ops/__pycache__/conv2d_gradfix.cpython-39.pyc differ diff --git a/third_party/stylegan3_official_ops/__pycache__/custom_ops.cpython-37.pyc b/third_party/stylegan3_official_ops/__pycache__/custom_ops.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e419e91d37554e5154418977662d8714268e2795 Binary files /dev/null and b/third_party/stylegan3_official_ops/__pycache__/custom_ops.cpython-37.pyc differ diff --git a/third_party/stylegan3_official_ops/__pycache__/custom_ops.cpython-39.pyc b/third_party/stylegan3_official_ops/__pycache__/custom_ops.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b04d7b341c28d2053308b0e33e37054a0964a7f7 Binary files /dev/null and b/third_party/stylegan3_official_ops/__pycache__/custom_ops.cpython-39.pyc differ diff --git a/third_party/stylegan3_official_ops/__pycache__/filtered_lrelu.cpython-37.pyc b/third_party/stylegan3_official_ops/__pycache__/filtered_lrelu.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cc7f5f3475b55b264c93302c8621c3b4f4b8feb3 Binary files 
/dev/null and b/third_party/stylegan3_official_ops/__pycache__/filtered_lrelu.cpython-37.pyc differ diff --git a/third_party/stylegan3_official_ops/__pycache__/filtered_lrelu.cpython-39.pyc b/third_party/stylegan3_official_ops/__pycache__/filtered_lrelu.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dd5a365f605efec94a5498f8c7e7d24f18176d1f Binary files /dev/null and b/third_party/stylegan3_official_ops/__pycache__/filtered_lrelu.cpython-39.pyc differ diff --git a/third_party/stylegan3_official_ops/__pycache__/misc.cpython-37.pyc b/third_party/stylegan3_official_ops/__pycache__/misc.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e20f093d5535e9d7d18eda1d42952aebac490fff Binary files /dev/null and b/third_party/stylegan3_official_ops/__pycache__/misc.cpython-37.pyc differ diff --git a/third_party/stylegan3_official_ops/__pycache__/misc.cpython-39.pyc b/third_party/stylegan3_official_ops/__pycache__/misc.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c4500f778d8beed8785a8507705d4cf0d0c364a4 Binary files /dev/null and b/third_party/stylegan3_official_ops/__pycache__/misc.cpython-39.pyc differ diff --git a/third_party/stylegan3_official_ops/__pycache__/upfirdn2d.cpython-37.pyc b/third_party/stylegan3_official_ops/__pycache__/upfirdn2d.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ba9f83975be635dcdbf0cf8e4b9f8fa5dba2b130 Binary files /dev/null and b/third_party/stylegan3_official_ops/__pycache__/upfirdn2d.cpython-37.pyc differ diff --git a/third_party/stylegan3_official_ops/__pycache__/upfirdn2d.cpython-39.pyc b/third_party/stylegan3_official_ops/__pycache__/upfirdn2d.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2a56a84ba067c91aa873d76a8a5b7331e9cd1f4c Binary files /dev/null and b/third_party/stylegan3_official_ops/__pycache__/upfirdn2d.cpython-39.pyc differ diff --git a/third_party/stylegan3_official_ops/bias_act.cpp b/third_party/stylegan3_official_ops/bias_act.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3adaeee2ae44e96655d354c2bdfb81de8ebfe6c6 --- /dev/null +++ b/third_party/stylegan3_official_ops/bias_act.cpp @@ -0,0 +1,99 @@ +// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include +#include +#include +#include "bias_act.h" + +//------------------------------------------------------------------------ + +static bool has_same_layout(torch::Tensor x, torch::Tensor y) +{ + if (x.dim() != y.dim()) + return false; + for (int64_t i = 0; i < x.dim(); i++) + { + if (x.size(i) != y.size(i)) + return false; + if (x.size(i) >= 2 && x.stride(i) != y.stride(i)) + return false; + } + return true; +} + +//------------------------------------------------------------------------ + +static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp) +{ + // Validate arguments. 
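+    // x is the input and b the optional bias; xref/yref/dy carry the saved
+    // reference input/output and the incoming gradient and are used for the
+    // gradient passes (grad selects forward, first, or second derivative).
+    // act/alpha/gain/clamp select and parameterize the activation function.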
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); + TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x"); + TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x"); + TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x"); + TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x"); + TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); + TORCH_CHECK(b.dim() == 1, "b must have rank 1"); + TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds"); + TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements"); + TORCH_CHECK(grad >= 0, "grad must be non-negative"); + + // Validate layout. + TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense"); + TORCH_CHECK(b.is_contiguous(), "b must be contiguous"); + TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x"); + TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x"); + TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x"); + + // Create output tensor. + const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); + torch::Tensor y = torch::empty_like(x); + TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x"); + + // Initialize CUDA kernel parameters. + bias_act_kernel_params p; + p.x = x.data_ptr(); + p.b = (b.numel()) ? b.data_ptr() : NULL; + p.xref = (xref.numel()) ? xref.data_ptr() : NULL; + p.yref = (yref.numel()) ? yref.data_ptr() : NULL; + p.dy = (dy.numel()) ? dy.data_ptr() : NULL; + p.y = y.data_ptr(); + p.grad = grad; + p.act = act; + p.alpha = alpha; + p.gain = gain; + p.clamp = clamp; + p.sizeX = (int)x.numel(); + p.sizeB = (int)b.numel(); + p.stepB = (b.numel()) ? (int)x.stride(dim) : 1; + + // Choose CUDA kernel. + void* kernel; + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] + { + kernel = choose_bias_act_kernel(p); + }); + TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func"); + + // Launch CUDA kernel. + p.loopX = 4; + int blockSize = 4 * 32; + int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1; + void* args[] = {&p}; + AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); + return y; +} + +//------------------------------------------------------------------------ + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("bias_act", &bias_act); +} + +//------------------------------------------------------------------------ diff --git a/third_party/stylegan3_official_ops/bias_act.cu b/third_party/stylegan3_official_ops/bias_act.cu new file mode 100644 index 0000000000000000000000000000000000000000..ed1d16f14eadd1344939e074ace1375cfd936cea --- /dev/null +++ b/third_party/stylegan3_official_ops/bias_act.cu @@ -0,0 +1,173 @@ +// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include +#include "bias_act.h" + +//------------------------------------------------------------------------ +// Helpers. + +template struct InternalType; +template <> struct InternalType { typedef double scalar_t; }; +template <> struct InternalType { typedef float scalar_t; }; +template <> struct InternalType { typedef float scalar_t; }; + +//------------------------------------------------------------------------ +// CUDA kernel. + +template +__global__ void bias_act_kernel(bias_act_kernel_params p) +{ + typedef typename InternalType::scalar_t scalar_t; + int G = p.grad; + scalar_t alpha = (scalar_t)p.alpha; + scalar_t gain = (scalar_t)p.gain; + scalar_t clamp = (scalar_t)p.clamp; + scalar_t one = (scalar_t)1; + scalar_t two = (scalar_t)2; + scalar_t expRange = (scalar_t)80; + scalar_t halfExpRange = (scalar_t)40; + scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946; + scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717; + + // Loop over elements. + int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x; + for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x) + { + // Load. + scalar_t x = (scalar_t)((const T*)p.x)[xi]; + scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0; + scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0; + scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0; + scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one; + scalar_t yy = (gain != 0) ? yref / gain : 0; + scalar_t y = 0; + + // Apply bias. + ((G == 0) ? x : xref) += b; + + // linear + if (A == 1) + { + if (G == 0) y = x; + if (G == 1) y = x; + } + + // relu + if (A == 2) + { + if (G == 0) y = (x > 0) ? x : 0; + if (G == 1) y = (yy > 0) ? x : 0; + } + + // lrelu + if (A == 3) + { + if (G == 0) y = (x > 0) ? x : x * alpha; + if (G == 1) y = (yy > 0) ? x : x * alpha; + } + + // tanh + if (A == 4) + { + if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); } + if (G == 1) y = x * (one - yy * yy); + if (G == 2) y = x * (one - yy * yy) * (-two * yy); + } + + // sigmoid + if (A == 5) + { + if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one); + if (G == 1) y = x * yy * (one - yy); + if (G == 2) y = x * yy * (one - yy) * (one - two * yy); + } + + // elu + if (A == 6) + { + if (G == 0) y = (x >= 0) ? x : exp(x) - one; + if (G == 1) y = (yy >= 0) ? x : x * (yy + one); + if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one); + } + + // selu + if (A == 7) + { + if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one); + if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha); + if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha); + } + + // softplus + if (A == 8) + { + if (G == 0) y = (x > expRange) ? x : log(exp(x) + one); + if (G == 1) y = x * (one - exp(-yy)); + if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); } + } + + // swish + if (A == 9) + { + if (G == 0) + y = (x < -expRange) ? 
0 : x / (exp(-x) + one); + else + { + scalar_t c = exp(xref); + scalar_t d = c + one; + if (G == 1) + y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d); + else + y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d); + yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain; + } + } + + // Apply gain. + y *= gain * dy; + + // Clamp. + if (clamp >= 0) + { + if (G == 0) + y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp; + else + y = (yref > -clamp & yref < clamp) ? y : 0; + } + + // Store. + ((T*)p.y)[xi] = (T)y; + } +} + +//------------------------------------------------------------------------ +// CUDA kernel selection. + +template void* choose_bias_act_kernel(const bias_act_kernel_params& p) +{ + if (p.act == 1) return (void*)bias_act_kernel; + if (p.act == 2) return (void*)bias_act_kernel; + if (p.act == 3) return (void*)bias_act_kernel; + if (p.act == 4) return (void*)bias_act_kernel; + if (p.act == 5) return (void*)bias_act_kernel; + if (p.act == 6) return (void*)bias_act_kernel; + if (p.act == 7) return (void*)bias_act_kernel; + if (p.act == 8) return (void*)bias_act_kernel; + if (p.act == 9) return (void*)bias_act_kernel; + return NULL; +} + +//------------------------------------------------------------------------ +// Template specializations. + +template void* choose_bias_act_kernel (const bias_act_kernel_params& p); +template void* choose_bias_act_kernel (const bias_act_kernel_params& p); +template void* choose_bias_act_kernel (const bias_act_kernel_params& p); + +//------------------------------------------------------------------------ diff --git a/third_party/stylegan3_official_ops/bias_act.h b/third_party/stylegan3_official_ops/bias_act.h new file mode 100644 index 0000000000000000000000000000000000000000..60b81c6058d54638a6d74a13046fa388442d767d --- /dev/null +++ b/third_party/stylegan3_official_ops/bias_act.h @@ -0,0 +1,38 @@ +// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +//------------------------------------------------------------------------ +// CUDA kernel parameters. + +struct bias_act_kernel_params +{ + const void* x; // [sizeX] + const void* b; // [sizeB] or NULL + const void* xref; // [sizeX] or NULL + const void* yref; // [sizeX] or NULL + const void* dy; // [sizeX] or NULL + void* y; // [sizeX] + + int grad; + int act; + float alpha; + float gain; + float clamp; + + int sizeX; + int sizeB; + int stepB; + int loopX; +}; + +//------------------------------------------------------------------------ +// CUDA kernel selection. + +template void* choose_bias_act_kernel(const bias_act_kernel_params& p); + +//------------------------------------------------------------------------ diff --git a/third_party/stylegan3_official_ops/bias_act.py b/third_party/stylegan3_official_ops/bias_act.py new file mode 100644 index 0000000000000000000000000000000000000000..c90e4f0fcc22b2eeb0e5b6a10d1d3f700f808e00 --- /dev/null +++ b/third_party/stylegan3_official_ops/bias_act.py @@ -0,0 +1,222 @@ +# python3.7 + +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Custom ops to fuse bias and activation as one operator, which is efficient. + +Please refer to https://github.com/NVlabs/stylegan3 +""" + +# pylint: disable=line-too-long +# pylint: disable=missing-class-docstring +# pylint: disable=global-statement + +import os +from easydict import EasyDict +import numpy as np +import torch + +from . import custom_ops +from . import misc + +#---------------------------------------------------------------------------- + +activation_funcs = { + 'linear': EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False), + 'relu': EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False), + 'lrelu': EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False), + 'tanh': EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True), + 'sigmoid': EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True), + 'elu': EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True), + 'selu': EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True), + 'softplus': EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True), + 'swish': EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True), +} + +#---------------------------------------------------------------------------- + +_plugin = None +_null_tensor = torch.empty([0]) + +def _init(): + global _plugin + if _plugin is None: + _plugin = custom_ops.get_plugin( + module_name='bias_act_plugin', + sources=['bias_act.cpp', 'bias_act.cu'], + headers=['bias_act.h'], + source_dir=os.path.dirname(__file__), + extra_cuda_cflags=['--use_fast_math'], + ) + return True + +#---------------------------------------------------------------------------- + +def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'): + r"""Fused bias and activation function. + + Adds bias `b` to activation tensor `x`, evaluates activation function `act`, + and scales the result by `gain`. Each of the steps is optional. In most cases, + the fused op is considerably more efficient than performing the same calculation + using standard PyTorch ops. It supports first and second order gradients, + but not third order gradients. + + Args: + x: Input activation tensor. Can be of any shape. + b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type + as `x`. The shape must be known, and it must match the dimension of `x` + corresponding to `dim`. + dim: The dimension in `x` corresponding to the elements of `b`. + The value of `dim` is ignored if `b` is not specified. + act: Name of the activation function to evaluate, or `"linear"` to disable. + Can be e.g. 
`"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc. + See `activation_funcs` for a full list. `None` is not allowed. + alpha: Shape parameter for the activation function, or `None` to use the default. + gain: Scaling factor for the output tensor, or `None` to use default. + See `activation_funcs` for the default scaling of each activation function. + If unsure, consider specifying 1. + clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable + the clamping (default). + impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default). + + Returns: + Tensor of the same shape and datatype as `x`. + """ + assert isinstance(x, torch.Tensor) + assert impl in ['ref', 'cuda'] + if impl == 'cuda' and x.device.type == 'cuda' and _init(): + return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b) + return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp) + +#---------------------------------------------------------------------------- + +@misc.profiled_function +def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None): + """Slow reference implementation of `bias_act()` using standard TensorFlow ops. + """ + assert isinstance(x, torch.Tensor) + assert clamp is None or clamp >= 0 + spec = activation_funcs[act] + alpha = float(alpha if alpha is not None else spec.def_alpha) + gain = float(gain if gain is not None else spec.def_gain) + clamp = float(clamp if clamp is not None else -1) + + # Add bias. + if b is not None: + assert isinstance(b, torch.Tensor) and b.ndim == 1 + assert 0 <= dim < x.ndim + assert b.shape[0] == x.shape[dim] + x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)]) + + # Evaluate activation function. + alpha = float(alpha) + x = spec.func(x, alpha=alpha) + + # Scale by gain. + gain = float(gain) + if gain != 1: + x = x * gain + + # Clamp. + if clamp >= 0: + x = x.clamp(-clamp, clamp) + return x + +#---------------------------------------------------------------------------- + +_bias_act_cuda_cache = dict() + +def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None): + """Fast CUDA implementation of `bias_act()` using custom ops. + """ + # Parse arguments. + assert clamp is None or clamp >= 0 + spec = activation_funcs[act] + alpha = float(alpha if alpha is not None else spec.def_alpha) + gain = float(gain if gain is not None else spec.def_gain) + clamp = float(clamp if clamp is not None else -1) + + # Lookup from cache. + key = (dim, act, alpha, gain, clamp) + if key in _bias_act_cuda_cache: + return _bias_act_cuda_cache[key] + + # Forward op. 
+ class BiasActCuda(torch.autograd.Function): + @staticmethod + def forward(ctx, x, b): # pylint: disable=arguments-differ + ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride(1) == 1 else torch.contiguous_format + x = x.contiguous(memory_format=ctx.memory_format) + b = b.contiguous() if b is not None else _null_tensor + y = x + if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor: + y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp) + ctx.save_for_backward( + x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, + b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, + y if 'y' in spec.ref else _null_tensor) + return y + + @staticmethod + def backward(ctx, dy): # pylint: disable=arguments-differ + dy = dy.contiguous(memory_format=ctx.memory_format) + x, b, y = ctx.saved_tensors + dx = None + db = None + + if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: + dx = dy + if act != 'linear' or gain != 1 or clamp >= 0: + dx = BiasActCudaGrad.apply(dy, x, b, y) + + if ctx.needs_input_grad[1]: + db = dx.sum([i for i in range(dx.ndim) if i != dim]) + + return dx, db + + # Backward op. + class BiasActCudaGrad(torch.autograd.Function): + @staticmethod + def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ + ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride(1) == 1 else torch.contiguous_format + dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp) + ctx.save_for_backward( + dy if spec.has_2nd_grad else _null_tensor, + x, b, y) + return dx + + @staticmethod + def backward(ctx, d_dx): # pylint: disable=arguments-differ + d_dx = d_dx.contiguous(memory_format=ctx.memory_format) + dy, x, b, y = ctx.saved_tensors + d_dy = None + d_x = None + d_b = None + d_y = None + + if ctx.needs_input_grad[0]: + d_dy = BiasActCudaGrad.apply(d_dx, x, b, y) + + if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]): + d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp) + + if spec.has_2nd_grad and ctx.needs_input_grad[2]: + d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim]) + + return d_dy, d_x, d_b, d_y + + # Add to cache. + _bias_act_cuda_cache[key] = BiasActCuda + return BiasActCuda + +#---------------------------------------------------------------------------- + +# pylint: enable=line-too-long +# pylint: enable=missing-class-docstring +# pylint: enable=global-statement diff --git a/third_party/stylegan3_official_ops/conv2d_gradfix.py b/third_party/stylegan3_official_ops/conv2d_gradfix.py new file mode 100644 index 0000000000000000000000000000000000000000..19aba5ca78f1228e4b8e3aafccbbe072c747f007 --- /dev/null +++ b/third_party/stylegan3_official_ops/conv2d_gradfix.py @@ -0,0 +1,219 @@ +# python3.7 + +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Custom replacement for convolution operators. + +Operators in this file support arbitrarily high order gradients with zero +performance penalty. 
Please set `impl` as `cuda` to use faster customized +operators, OR as `ref` to use native `torch.nn.functional.conv2d` and +`torch.nn.functional.conv_transpose2d`. + +Please refer to https://github.com/NVlabs/stylegan3 +""" + +# pylint: disable=redefined-builtin +# pylint: disable=arguments-differ +# pylint: disable=protected-access +# pylint: disable=line-too-long +# pylint: disable=global-statement +# pylint: disable=missing-class-docstring +# pylint: disable=missing-function-docstring + +import contextlib +import torch + +#---------------------------------------------------------------------------- + +enabled = True # Enable the custom op by setting this to true. +weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights. + +@contextlib.contextmanager +def no_weight_gradients(disable=True): + global weight_gradients_disabled + old = weight_gradients_disabled + if disable: + weight_gradients_disabled = True + yield + weight_gradients_disabled = old + +#---------------------------------------------------------------------------- + +def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, impl='cuda'): + if impl == 'cuda' and _should_use_custom_op(input): + return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias) + return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups) + +def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1, impl='cuda'): + if impl == 'cuda' and _should_use_custom_op(input): + return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias) + return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation) + +#---------------------------------------------------------------------------- + +def _should_use_custom_op(input): + assert isinstance(input, torch.Tensor) + if (not enabled) or (not torch.backends.cudnn.enabled): + return False + if input.device.type != 'cuda': + return False + return True + +def _tuple_of_ints(xs, ndim): + xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim + assert len(xs) == ndim + assert all(isinstance(x, int) for x in xs) + return xs + +#---------------------------------------------------------------------------- + +_conv2d_gradfix_cache = dict() +_null_tensor = torch.empty([0]) + +def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups): + # Parse arguments. + ndim = 2 + weight_shape = tuple(weight_shape) + stride = _tuple_of_ints(stride, ndim) + padding = _tuple_of_ints(padding, ndim) + output_padding = _tuple_of_ints(output_padding, ndim) + dilation = _tuple_of_ints(dilation, ndim) + + # Lookup from cache. + key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups) + if key in _conv2d_gradfix_cache: + return _conv2d_gradfix_cache[key] + + # Validate arguments. 
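+    # Illustrative note: the scalar arguments were normalized by _tuple_of_ints() above,
+    # so a hypothetical call such as conv2d(x, w, stride=2, padding=1) arrives here with
+    # stride=(2, 2), padding=(1, 1) and output_padding=(0, 0); the asserts below only
+    # enforce basic sanity on those normalized values and on the 4-D weight shape.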
+ assert groups >= 1 + assert len(weight_shape) == ndim + 2 + assert all(stride[i] >= 1 for i in range(ndim)) + assert all(padding[i] >= 0 for i in range(ndim)) + assert all(dilation[i] >= 0 for i in range(ndim)) + if not transpose: + assert all(output_padding[i] == 0 for i in range(ndim)) + else: # transpose + assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim)) + + # Helpers. + common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups) + def calc_output_padding(input_shape, output_shape): + if transpose: + return [0, 0] + return [ + input_shape[i + 2] + - (output_shape[i + 2] - 1) * stride[i] + - (1 - 2 * padding[i]) + - dilation[i] * (weight_shape[i + 2] - 1) + for i in range(ndim) + ] + + # Forward & backward. + class Conv2d(torch.autograd.Function): + @staticmethod + def forward(ctx, input, weight, bias): + assert weight.shape == weight_shape + ctx.save_for_backward( + input if weight.requires_grad else _null_tensor, + weight if input.requires_grad else _null_tensor, + ) + ctx.input_shape = input.shape + + # Simple 1x1 convolution => cuBLAS (only on Volta, not on Ampere). + if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0) and torch.cuda.get_device_capability(input.device) < (8, 0): + a = weight.reshape(groups, weight_shape[0] // groups, weight_shape[1]) + b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1) + c = (a.transpose(1, 2) if transpose else a) @ b.permute(1, 2, 0, 3).flatten(2) + c = c.reshape(-1, input.shape[0], *input.shape[2:]).transpose(0, 1) + c = c if bias is None else c + bias.unsqueeze(0).unsqueeze(2).unsqueeze(3) + return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format)) + + # General case => cuDNN. + if transpose: + return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs) + return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs) + + @staticmethod + def backward(ctx, grad_output): + input, weight = ctx.saved_tensors + input_shape = ctx.input_shape + grad_input = None + grad_weight = None + grad_bias = None + + if ctx.needs_input_grad[0]: + p = calc_output_padding(input_shape=input_shape, output_shape=grad_output.shape) + op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs) + grad_input = op.apply(grad_output, weight, None) + assert grad_input.shape == input_shape + + if ctx.needs_input_grad[1] and not weight_gradients_disabled: + grad_weight = Conv2dGradWeight.apply(grad_output, input) + assert grad_weight.shape == weight_shape + + if ctx.needs_input_grad[2]: + grad_bias = grad_output.sum([0, 2, 3]) + + return grad_input, grad_weight, grad_bias + + # Gradient with respect to the weights. + class Conv2dGradWeight(torch.autograd.Function): + @staticmethod + def forward(ctx, grad_output, input): + ctx.save_for_backward( + grad_output if input.requires_grad else _null_tensor, + input if grad_output.requires_grad else _null_tensor, + ) + ctx.grad_output_shape = grad_output.shape + ctx.input_shape = input.shape + + # Simple 1x1 convolution => cuBLAS (on both Volta and Ampere). 
+ if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0): + a = grad_output.reshape(grad_output.shape[0], groups, grad_output.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2) + b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2) + c = (b @ a.transpose(1, 2) if transpose else a @ b.transpose(1, 2)).reshape(weight_shape) + return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format)) + + # General case => cuDNN. + name = 'aten::cudnn_convolution_transpose_backward_weight' if transpose else 'aten::cudnn_convolution_backward_weight' + flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32] + return torch._C._jit_get_operation(name)(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags) + + @staticmethod + def backward(ctx, grad2_grad_weight): + grad_output, input = ctx.saved_tensors + grad_output_shape = ctx.grad_output_shape + input_shape = ctx.input_shape + grad2_grad_output = None + grad2_input = None + + if ctx.needs_input_grad[0]: + grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None) + assert grad2_grad_output.shape == grad_output_shape + + if ctx.needs_input_grad[1]: + p = calc_output_padding(input_shape=input_shape, output_shape=grad_output_shape) + op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs) + grad2_input = op.apply(grad_output, grad2_grad_weight, None) + assert grad2_input.shape == input_shape + + return grad2_grad_output, grad2_input + + _conv2d_gradfix_cache[key] = Conv2d + return Conv2d + +#---------------------------------------------------------------------------- + +# pylint: enable=redefined-builtin +# pylint: enable=arguments-differ +# pylint: enable=protected-access +# pylint: enable=line-too-long +# pylint: enable=global-statement +# pylint: enable=missing-class-docstring +# pylint: enable=missing-function-docstring diff --git a/third_party/stylegan3_official_ops/conv2d_resample.py b/third_party/stylegan3_official_ops/conv2d_resample.py new file mode 100644 index 0000000000000000000000000000000000000000..dfde81ee19204a7993fd1c3cd21055a51418231b --- /dev/null +++ b/third_party/stylegan3_official_ops/conv2d_resample.py @@ -0,0 +1,154 @@ +# python3.7 + +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""2D convolution with optional up/downsampling. + +Please refer to https://github.com/NVlabs/stylegan3 +""" + +# pylint: disable=line-too-long + +import torch + +from . import misc +from . import conv2d_gradfix +from . 
import upfirdn2d +from .upfirdn2d import _parse_padding +from .upfirdn2d import _get_filter_size + +#---------------------------------------------------------------------------- + +def _get_weight_shape(w): + with misc.suppress_tracer_warnings(): # this value will be treated as a constant + shape = [int(sz) for sz in w.shape] + misc.assert_shape(w, shape) + return shape + +#---------------------------------------------------------------------------- + +def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True, impl='cuda'): + """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations. + """ + _out_channels, _in_channels_per_group, kh, kw = _get_weight_shape(w) + + # Flip weight if requested. + # Note: conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False). + if not flip_weight and (kw > 1 or kh > 1): + w = w.flip([2, 3]) + + # Execute using conv2d_gradfix. + op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d + return op(x, w, stride=stride, padding=padding, groups=groups, impl=impl) + +#---------------------------------------------------------------------------- + +@misc.profiled_function +def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False, impl='cuda'): + r"""2D convolution with optional up/downsampling. + + Padding is performed only once at the beginning, not between the operations. + + Args: + x: Input tensor of shape + `[batch_size, in_channels, in_height, in_width]`. + w: Weight tensor of shape + `[out_channels, in_channels//groups, kernel_height, kernel_width]`. + f: Low-pass filter for up/downsampling. Must be prepared beforehand by + calling upfirdn2d.setup_filter(). None = identity (default). + up: Integer upsampling factor (default: 1). + down: Integer downsampling factor (default: 1). + padding: Padding with respect to the upsampled image. Can be a single number + or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + groups: Split input channels into N groups (default: 1). + flip_weight: False = convolution, True = correlation (default: True). + flip_filter: False = convolution, True = correlation (default: False). + impl: Implementation mode, 'cuda' for CUDA implementation, and 'ref' for + native PyTorch implementation (default: 'cuda'). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. + """ + # Validate arguments. + assert isinstance(x, torch.Tensor) and (x.ndim == 4) + assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype) + assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32) + assert isinstance(up, int) and (up >= 1) + assert isinstance(down, int) and (down >= 1) + assert isinstance(groups, int) and (groups >= 1) + out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) + fw, fh = _get_filter_size(f) + px0, px1, py0, py1 = _parse_padding(padding) + + # Adjust padding to account for up/downsampling. + if up > 1: + px0 += (fw + up - 1) // 2 + px1 += (fw - up) // 2 + py0 += (fh + up - 1) // 2 + py1 += (fh - up) // 2 + if down > 1: + px0 += (fw - down + 1) // 2 + px1 += (fw - down) // 2 + py0 += (fh - down + 1) // 2 + py1 += (fh - down) // 2 + + # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve. 
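+    # Illustrative note: a 1x1 kernel only mixes channels per pixel, so it commutes with
+    # spatial filtering; downsampling first therefore gives the same result while running
+    # the grouped convolution on a smaller tensor (e.g. on [N, C, 128, 128] instead of
+    # [N, C, 256, 256] when down=2).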
+ if kw == 1 and kh == 1 and (down > 1 and up == 1): + x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter, impl=impl) + x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight, impl=impl) + return x + + # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample. + if kw == 1 and kh == 1 and (up > 1 and down == 1): + x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight, impl=impl) + x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter, impl=impl) + return x + + # Fast path: downsampling only => use strided convolution. + if down > 1 and up == 1: + x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter, impl=impl) + x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight, impl=impl) + return x + + # Fast path: upsampling with optional downsampling => use transpose strided convolution. + if up > 1: + if groups == 1: + w = w.transpose(0, 1) + else: + w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw) + w = w.transpose(1, 2) + w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw) + px0 -= kw - 1 + px1 -= kw - up + py0 -= kh - 1 + py1 -= kh - up + pxt = max(min(-px0, -px1), 0) + pyt = max(min(-py0, -py1), 0) + x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight), impl=impl) + x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter, impl=impl) + if down > 1: + x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter, impl=impl) + return x + + # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d. + if up == 1 and down == 1: + if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0: + return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight, impl=impl) + + # Fallback: Generic reference implementation. + x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter, impl=impl) + x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight, impl=impl) + if down > 1: + x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter, impl=impl) + return x + +#---------------------------------------------------------------------------- + +# pylint: enable=line-too-long diff --git a/third_party/stylegan3_official_ops/custom_ops.py b/third_party/stylegan3_official_ops/custom_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..c5853ac187e6e3ae522b0ef1aabefc7b188f7083 --- /dev/null +++ b/third_party/stylegan3_official_ops/custom_ops.py @@ -0,0 +1,191 @@ +# python3.7 + +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Utility functions to setup customized operators. 
+ +Please refer to https://github.com/NVlabs/stylegan3 +""" + +# pylint: disable=line-too-long +# pylint: disable=multiple-statements +# pylint: disable=missing-function-docstring +# pylint: disable=useless-suppression +# pylint: disable=inconsistent-quotes + +import glob +import hashlib +import importlib +import os +import re +import shutil +import uuid + +import torch +import torch.utils.cpp_extension + +#---------------------------------------------------------------------------- +# Global options. + +verbosity = 'none' # Verbosity level: 'none', 'brief', 'full' + +#---------------------------------------------------------------------------- +# Internal helper funcs. + +def _find_compiler_bindir(): + patterns = [ + 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64', + 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64', + 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64', + 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin', + ] + for pattern in patterns: + matches = sorted(glob.glob(pattern)) + if len(matches): + return matches[-1] + return None + +def _find_compiler_bindir_posix(): + patterns = [ + '/usr/local/cuda/bin' + ] + for pattern in patterns: + matches = sorted(glob.glob(pattern)) + if len(matches): + return matches[-1] + return None + +#---------------------------------------------------------------------------- + +def _get_mangled_gpu_name(): + name = torch.cuda.get_device_name().lower() + out = [] + for c in name: + if re.match('[a-z0-9_-]+', c): + out.append(c) + else: + out.append('-') + return ''.join(out) + +#---------------------------------------------------------------------------- +# Main entry point for compiling and loading C++/CUDA plugins. + +_cached_plugins = dict() + +def get_plugin(module_name, sources, headers=None, source_dir=None, **build_kwargs): + assert verbosity in ['none', 'brief', 'full'] + if headers is None: + headers = [] + if source_dir is not None: + sources = [os.path.join(source_dir, fname) for fname in sources] + headers = [os.path.join(source_dir, fname) for fname in headers] + + # Already cached? + if module_name in _cached_plugins: + return _cached_plugins[module_name] + + # Print status. + if verbosity == 'full': + print(f'Setting up PyTorch plugin "{module_name}"...') + elif verbosity == 'brief': + print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True) + verbose_build = (verbosity == 'full') + + # Compile and load. + try: # pylint: disable=too-many-nested-blocks + # Make sure we can find the necessary compiler binaries. + if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0: + compiler_bindir = _find_compiler_bindir() + if compiler_bindir is None: + raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".') + os.environ['PATH'] += ';' + compiler_bindir + + elif os.name == 'posix': + compiler_bindir = _find_compiler_bindir_posix() + if compiler_bindir is None: + raise RuntimeError(f'Could not find NVCC installation on this computer. Check _find_compiler_bindir_posix() in "{__file__}".') + os.environ['PATH'] += ';' + compiler_bindir + + # Some containers set TORCH_CUDA_ARCH_LIST to a list that can either + # break the build or unnecessarily restrict what's available to nvcc. + # Unset it to let nvcc decide based on what's available on the + # machine. 
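+            # Illustrative note: TORCH_CUDA_ARCH_LIST is normally a space-separated list
+            # such as "7.0 7.5+PTX"; clearing it lets nvcc target the architecture of
+            # the GPU that is actually present.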
+ os.environ['TORCH_CUDA_ARCH_LIST'] = '' + + # Incremental build md5sum trickery. Copies all the input source files + # into a cached build directory under a combined md5 digest of the input + # source files. Copying is done only if the combined digest has changed. + # This keeps input file timestamps and filenames the same as in previous + # extension builds, allowing for fast incremental rebuilds. + # + # This optimization is done only in case all the source files reside in + # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR + # environment variable is set (we take this as a signal that the user + # actually cares about this.) + # + # EDIT: We now do it regardless of TORCH_EXTENSIONS_DIR, in order to work + # around the *.cu dependency bug in ninja config. + # + all_source_files = sorted(sources + headers) + all_source_dirs = set(os.path.dirname(fname) for fname in all_source_files) + if len(all_source_dirs) == 1: # and ('TORCH_EXTENSIONS_DIR' in os.environ): + + # Compute combined hash digest for all source files. + hash_md5 = hashlib.md5() + for src in all_source_files: + with open(src, 'rb') as f: + hash_md5.update(f.read()) + + # Select cached build directory name. + source_digest = hash_md5.hexdigest() + build_top_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access + cached_build_dir = os.path.join(build_top_dir, f'{source_digest}-{_get_mangled_gpu_name()}') + + if not os.path.isdir(cached_build_dir): + tmpdir = f'{build_top_dir}/srctmp-{uuid.uuid4().hex}' + os.makedirs(tmpdir) + for src in all_source_files: + shutil.copyfile(src, os.path.join(tmpdir, os.path.basename(src))) + try: + os.replace(tmpdir, cached_build_dir) # atomic + except OSError: + # source directory already exists, delete tmpdir and its contents. + shutil.rmtree(tmpdir) + if not os.path.isdir(cached_build_dir): raise + + # Compile. + cached_sources = [os.path.join(cached_build_dir, os.path.basename(fname)) for fname in sources] + torch.utils.cpp_extension.load(name=module_name, build_directory=cached_build_dir, + verbose=verbose_build, sources=cached_sources, **build_kwargs) + else: + torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs) + + # Load. + module = importlib.import_module(module_name) + + except: + if verbosity == 'brief': + print('Failed!') + raise + + # Print status and add to cache dict. + if verbosity == 'full': + print(f'Done setting up PyTorch plugin "{module_name}".') + elif verbosity == 'brief': + print('Done.') + _cached_plugins[module_name] = module + return module + +#---------------------------------------------------------------------------- + +# pylint: enable=line-too-long +# pylint: enable=multiple-statements +# pylint: enable=missing-function-docstring +# pylint: enable=useless-suppression +# pylint: enable=inconsistent-quotes diff --git a/third_party/stylegan3_official_ops/filtered_lrelu.cpp b/third_party/stylegan3_official_ops/filtered_lrelu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ff4149b8b46b54d2f400ae10e44d19f20503ba1f --- /dev/null +++ b/third_party/stylegan3_official_ops/filtered_lrelu.cpp @@ -0,0 +1,300 @@ +// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. 
Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include +#include +#include +#include "filtered_lrelu.h" + +//------------------------------------------------------------------------ + +static std::tuple filtered_lrelu( + torch::Tensor x, torch::Tensor fu, torch::Tensor fd, torch::Tensor b, torch::Tensor si, + int up, int down, int px0, int px1, int py0, int py1, int sx, int sy, float gain, float slope, float clamp, bool flip_filters, bool writeSigns) +{ + // Set CUDA device. + TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); + const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); + + // Validate arguments. + TORCH_CHECK(fu.device() == x.device() && fd.device() == x.device() && b.device() == x.device(), "all input tensors must reside on the same device"); + TORCH_CHECK(fu.dtype() == torch::kFloat && fd.dtype() == torch::kFloat, "fu and fd must be float32"); + TORCH_CHECK(b.dtype() == x.dtype(), "x and b must have the same dtype"); + TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat, "x and b must be float16 or float32"); + TORCH_CHECK(x.dim() == 4, "x must be rank 4"); + TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX && x.size(3) <= INT_MAX, "x is too large"); + TORCH_CHECK(x.numel() > 0, "x is empty"); + TORCH_CHECK((fu.dim() == 1 || fu.dim() == 2) && (fd.dim() == 1 || fd.dim() == 2), "fu and fd must be rank 1 or 2"); + TORCH_CHECK(fu.size(0) <= INT_MAX && fu.size(-1) <= INT_MAX, "fu is too large"); + TORCH_CHECK(fd.size(0) <= INT_MAX && fd.size(-1) <= INT_MAX, "fd is too large"); + TORCH_CHECK(fu.numel() > 0, "fu is empty"); + TORCH_CHECK(fd.numel() > 0, "fd is empty"); + TORCH_CHECK(b.dim() == 1 && b.size(0) == x.size(1), "b must be a vector with the same number of channels as x"); + TORCH_CHECK(up >= 1 && down >= 1, "up and down must be at least 1"); + + // Figure out how much shared memory is available on the device. + int maxSharedBytes = 0; + AT_CUDA_CHECK(cudaDeviceGetAttribute(&maxSharedBytes, cudaDevAttrMaxSharedMemoryPerBlockOptin, x.device().index())); + int sharedKB = maxSharedBytes >> 10; + + // Populate enough launch parameters to check if a CUDA kernel exists. + filtered_lrelu_kernel_params p; + p.up = up; + p.down = down; + p.fuShape = make_int2((int)fu.size(-1), fu.dim() == 2 ? (int)fu.size(0) : 0); // shape [n, 0] indicates separable filter. + p.fdShape = make_int2((int)fd.size(-1), fd.dim() == 2 ? (int)fd.size(0) : 0); + filtered_lrelu_kernel_spec test_spec = choose_filtered_lrelu_kernel(p, sharedKB); + if (!test_spec.exec) + { + // No kernel found - return empty tensors and indicate missing kernel with return code of -1. + return std::make_tuple(torch::Tensor(), torch::Tensor(), -1); + } + + // Input/output element size. + int64_t sz = (x.dtype() == torch::kHalf) ? 2 : 4; + + // Input sizes. + int64_t xw = (int)x.size(3); + int64_t xh = (int)x.size(2); + int64_t fut_w = (int)fu.size(-1) - 1; + int64_t fut_h = (int)fu.size(0) - 1; + int64_t fdt_w = (int)fd.size(-1) - 1; + int64_t fdt_h = (int)fd.size(0) - 1; + + // Logical size of upsampled buffer. + int64_t cw = xw * up + (px0 + px1) - fut_w; + int64_t ch = xh * up + (py0 + py1) - fut_h; + TORCH_CHECK(cw > fdt_w && ch > fdt_h, "upsampled buffer must be at least the size of downsampling filter"); + TORCH_CHECK(cw <= INT_MAX && ch <= INT_MAX, "upsampled buffer is too large"); + + // Compute output size and allocate. 
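+    // Illustrative numbers: with xw = 64, up = 2, px0 = px1 = 5, a 12-tap upsampling
+    // filter (fut_w = 11) and a 12-tap downsampling filter (fdt_w = 11),
+    // cw = 64*2 + 10 - 11 = 127, and for down = 2 the width below evaluates to
+    // yw = (127 - 11 + 1) / 2 = 58. The same formula is applied independently per axis.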
+ int64_t yw = (cw - fdt_w + (down - 1)) / down; + int64_t yh = (ch - fdt_h + (down - 1)) / down; + TORCH_CHECK(yw > 0 && yh > 0, "output must be at least 1x1"); + TORCH_CHECK(yw <= INT_MAX && yh <= INT_MAX, "output is too large"); + torch::Tensor y = torch::empty({x.size(0), x.size(1), yh, yw}, x.options(), x.suggest_memory_format()); + + // Allocate sign tensor. + torch::Tensor so; + torch::Tensor s = si; + bool readSigns = !!s.numel(); + int64_t sw_active = 0; // Active width of sign tensor. + if (writeSigns) + { + sw_active = yw * down - (down - 1) + fdt_w; // Active width in elements. + int64_t sh = yh * down - (down - 1) + fdt_h; // Height = active height. + int64_t sw = (sw_active + 15) & ~15; // Width = active width in elements, rounded up to multiple of 16. + TORCH_CHECK(sh <= INT_MAX && (sw >> 2) <= INT_MAX, "signs is too large"); + s = so = torch::empty({x.size(0), x.size(1), sh, sw >> 2}, x.options().dtype(torch::kUInt8), at::MemoryFormat::Contiguous); + } + else if (readSigns) + sw_active = s.size(3) << 2; + + // Validate sign tensor if in use. + if (readSigns || writeSigns) + { + TORCH_CHECK(s.is_contiguous(), "signs must be contiguous"); + TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8"); + TORCH_CHECK(s.device() == x.device(), "signs must reside on the same device as x"); + TORCH_CHECK(s.dim() == 4, "signs must be rank 4"); + TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1), "signs must have same batch & channels as x"); + TORCH_CHECK(s.size(2) <= INT_MAX && s.size(3) <= INT_MAX, "signs is too large"); + } + + // Populate rest of CUDA kernel parameters. + p.x = x.data_ptr(); + p.y = y.data_ptr(); + p.b = b.data_ptr(); + p.s = (readSigns || writeSigns) ? s.data_ptr() : 0; + p.fu = fu.data_ptr(); + p.fd = fd.data_ptr(); + p.pad0 = make_int2(px0, py0); + p.gain = gain; + p.slope = slope; + p.clamp = clamp; + p.flip = (flip_filters) ? 1 : 0; + p.xShape = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0)); + p.yShape = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0)); + p.sShape = (readSigns || writeSigns) ? make_int2((int)s.size(3), (int)s.size(2)) : make_int2(0, 0); // Width is in bytes. Contiguous. + p.sOfs = make_int2(sx, sy); + p.swLimit = (sw_active + 3) >> 2; // Rounded up to bytes. + + // x, y, b strides are in bytes. + p.xStride = make_longlong4(sz * x.stride(3), sz * x.stride(2), sz * x.stride(1), sz * x.stride(0)); + p.yStride = make_longlong4(sz * y.stride(3), sz * y.stride(2), sz * y.stride(1), sz * y.stride(0)); + p.bStride = sz * b.stride(0); + + // fu, fd strides are in elements. + p.fuStride = make_longlong3(fu.stride(-1), fu.dim() == 2 ? fu.stride(0) : 0, 0); + p.fdStride = make_longlong3(fd.stride(-1), fd.dim() == 2 ? fd.stride(0) : 0, 0); + + // Determine if indices don't fit in int32. Support negative strides although Torch currently never produces those. 
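+    // Illustrative note: 64-bit indexing is needed once any byte offset can exceed
+    // INT_MAX; for instance a float32 activation of shape [8, 1024, 512, 512] already
+    // holds 2^31 elements (~8 GiB), which trips the range checks below.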
+ bool index64b = false; + if (std::abs(p.bStride * x.size(1)) > INT_MAX) index64b = true; + if (std::min(x.size(0) * p.xStride.w, 0ll) + std::min(x.size(1) * p.xStride.z, 0ll) + std::min(x.size(2) * p.xStride.y, 0ll) + std::min(x.size(3) * p.xStride.x, 0ll) < -INT_MAX) index64b = true; + if (std::max(x.size(0) * p.xStride.w, 0ll) + std::max(x.size(1) * p.xStride.z, 0ll) + std::max(x.size(2) * p.xStride.y, 0ll) + std::max(x.size(3) * p.xStride.x, 0ll) > INT_MAX) index64b = true; + if (std::min(y.size(0) * p.yStride.w, 0ll) + std::min(y.size(1) * p.yStride.z, 0ll) + std::min(y.size(2) * p.yStride.y, 0ll) + std::min(y.size(3) * p.yStride.x, 0ll) < -INT_MAX) index64b = true; + if (std::max(y.size(0) * p.yStride.w, 0ll) + std::max(y.size(1) * p.yStride.z, 0ll) + std::max(y.size(2) * p.yStride.y, 0ll) + std::max(y.size(3) * p.yStride.x, 0ll) > INT_MAX) index64b = true; + if (s.numel() > INT_MAX) index64b = true; + + // Choose CUDA kernel. + filtered_lrelu_kernel_spec spec = { 0 }; + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "filtered_lrelu_cuda", [&] + { + if constexpr (sizeof(scalar_t) <= 4) // Exclude doubles. constexpr prevents template instantiation. + { + // Choose kernel based on index type, datatype and sign read/write modes. + if (!index64b && writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel(p, sharedKB); + else if (!index64b && !writeSigns && readSigns) spec = choose_filtered_lrelu_kernel(p, sharedKB); + else if (!index64b && !writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel(p, sharedKB); + else if ( index64b && writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel(p, sharedKB); + else if ( index64b && !writeSigns && readSigns) spec = choose_filtered_lrelu_kernel(p, sharedKB); + else if ( index64b && !writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel(p, sharedKB); + } + }); + TORCH_CHECK(spec.exec, "internal error - CUDA kernel not found") // This should not happen because we tested earlier that kernel exists. + + // Launch CUDA kernel. + void* args[] = {&p}; + int bx = spec.numWarps * 32; + int gx = (p.yShape.x - 1) / spec.tileOut.x + 1; + int gy = (p.yShape.y - 1) / spec.tileOut.y + 1; + int gz = p.yShape.z * p.yShape.w; + + // Repeat multiple horizontal tiles in a CTA? + if (spec.xrep) + { + p.tilesXrep = spec.xrep; + p.tilesXdim = gx; + + gx = (gx + p.tilesXrep - 1) / p.tilesXrep; + std::swap(gx, gy); + } + else + { + p.tilesXrep = 0; + p.tilesXdim = 0; + } + + // Launch filter setup kernel. + AT_CUDA_CHECK(cudaLaunchKernel(spec.setup, 1, 1024, args, 0, at::cuda::getCurrentCUDAStream())); + + // Copy kernels to constant memory. + if ( writeSigns && !readSigns) AT_CUDA_CHECK((copy_filters(at::cuda::getCurrentCUDAStream()))); + else if (!writeSigns && readSigns) AT_CUDA_CHECK((copy_filters(at::cuda::getCurrentCUDAStream()))); + else if (!writeSigns && !readSigns) AT_CUDA_CHECK((copy_filters(at::cuda::getCurrentCUDAStream()))); + + // Set cache and shared memory configurations for main kernel. + AT_CUDA_CHECK(cudaFuncSetCacheConfig(spec.exec, cudaFuncCachePreferShared)); + if (spec.dynamicSharedKB) // Need dynamically allocated shared memory? + AT_CUDA_CHECK(cudaFuncSetAttribute(spec.exec, cudaFuncAttributeMaxDynamicSharedMemorySize, spec.dynamicSharedKB << 10)); + AT_CUDA_CHECK(cudaFuncSetSharedMemConfig(spec.exec, cudaSharedMemBankSizeFourByte)); + + // Launch main kernel. + const int maxSubGz = 65535; // CUDA maximum for block z dimension. 
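+    // Illustrative note: gz = channels * batch, so e.g. a batch of 16 with 8192 channels
+    // gives gz = 131072 and the loop below issues three launches of at most 65535
+    // z-blocks each.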
+ for (int zofs=0; zofs < gz; zofs += maxSubGz) // Do multiple launches if gz is too big. + { + p.blockZofs = zofs; + int subGz = std::min(maxSubGz, gz - zofs); + AT_CUDA_CHECK(cudaLaunchKernel(spec.exec, dim3(gx, gy, subGz), bx, args, spec.dynamicSharedKB << 10, at::cuda::getCurrentCUDAStream())); + } + + // Done. + return std::make_tuple(y, so, 0); +} + +//------------------------------------------------------------------------ + +static torch::Tensor filtered_lrelu_act(torch::Tensor x, torch::Tensor si, int sx, int sy, float gain, float slope, float clamp, bool writeSigns) +{ + // Set CUDA device. + TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); + const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); + + // Validate arguments. + TORCH_CHECK(x.dim() == 4, "x must be rank 4"); + TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX && x.size(3) <= INT_MAX, "x is too large"); + TORCH_CHECK(x.numel() > 0, "x is empty"); + TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat || x.dtype() == torch::kDouble, "x must be float16, float32 or float64"); + + // Output signs if we don't have sign input. + torch::Tensor so; + torch::Tensor s = si; + bool readSigns = !!s.numel(); + if (writeSigns) + { + int64_t sw = x.size(3); + sw = (sw + 15) & ~15; // Round to a multiple of 16 for coalescing. + s = so = torch::empty({x.size(0), x.size(1), x.size(2), sw >> 2}, x.options().dtype(torch::kUInt8), at::MemoryFormat::Contiguous); + } + + // Validate sign tensor if in use. + if (readSigns || writeSigns) + { + TORCH_CHECK(s.is_contiguous(), "signs must be contiguous"); + TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8"); + TORCH_CHECK(s.device() == x.device(), "signs must reside on the same device as x"); + TORCH_CHECK(s.dim() == 4, "signs must be rank 4"); + TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1), "signs must have same batch & channels as x"); + TORCH_CHECK(s.size(2) <= INT_MAX && (s.size(3) << 2) <= INT_MAX, "signs tensor is too large"); + } + + // Initialize CUDA kernel parameters. + filtered_lrelu_act_kernel_params p; + p.x = x.data_ptr(); + p.s = (readSigns || writeSigns) ? s.data_ptr() : 0; + p.gain = gain; + p.slope = slope; + p.clamp = clamp; + p.xShape = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0)); + p.xStride = make_longlong4(x.stride(3), x.stride(2), x.stride(1), x.stride(0)); + p.sShape = (readSigns || writeSigns) ? make_int2((int)s.size(3) << 2, (int)s.size(2)) : make_int2(0, 0); // Width is in elements. Contiguous. + p.sOfs = make_int2(sx, sy); + + // Choose CUDA kernel. + void* func = 0; + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "filtered_lrelu_act_cuda", [&] + { + if (writeSigns) + func = choose_filtered_lrelu_act_kernel(); + else if (readSigns) + func = choose_filtered_lrelu_act_kernel(); + else + func = choose_filtered_lrelu_act_kernel(); + }); + TORCH_CHECK(func, "internal error - CUDA kernel not found"); + + // Launch CUDA kernel. + void* args[] = {&p}; + int bx = 128; // 4 warps per block. + + // Logical size of launch = writeSigns ? p.s : p.x + uint32_t gx = writeSigns ? p.sShape.x : p.xShape.x; + uint32_t gy = writeSigns ? p.sShape.y : p.xShape.y; + uint32_t gz = p.xShape.z * p.xShape.w; // Same as in p.sShape if signs are in use. + gx = (gx - 1) / bx + 1; + + // Make sure grid y and z dimensions are within CUDA launch limits. Kernel loops internally to do the rest. 
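+    // Note: CUDA caps gridDim.y and gridDim.z at 65535, while gridDim.x may be up to
+    // 2^31 - 1; that is why only the y and z extents are clamped here and the kernel
+    // loops over the remainder internally.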
+ const uint32_t gmax = 65535; + gy = std::min(gy, gmax); + gz = std::min(gz, gmax); + + // Launch. + AT_CUDA_CHECK(cudaLaunchKernel(func, dim3(gx, gy, gz), bx, args, 0, at::cuda::getCurrentCUDAStream())); + return so; +} + +//------------------------------------------------------------------------ + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("filtered_lrelu", &filtered_lrelu); // The whole thing. + m.def("filtered_lrelu_act_", &filtered_lrelu_act); // Activation and sign tensor handling only. Modifies data tensor in-place. +} + +//------------------------------------------------------------------------ diff --git a/third_party/stylegan3_official_ops/filtered_lrelu.cu b/third_party/stylegan3_official_ops/filtered_lrelu.cu new file mode 100644 index 0000000000000000000000000000000000000000..8e6f47f873d42f7181a0faf64779377e70be3012 --- /dev/null +++ b/third_party/stylegan3_official_ops/filtered_lrelu.cu @@ -0,0 +1,1284 @@ +// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include +#include "filtered_lrelu.h" +#include + +//------------------------------------------------------------------------ +// Helpers. + +enum // Filter modes. +{ + MODE_SUSD = 0, // Separable upsampling, separable downsampling. + MODE_FUSD = 1, // Full upsampling, separable downsampling. + MODE_SUFD = 2, // Separable upsampling, full downsampling. + MODE_FUFD = 3, // Full upsampling, full downsampling. +}; + +template struct InternalType; +template <> struct InternalType +{ + typedef double scalar_t; typedef double2 vec2_t; typedef double4 vec4_t; + __device__ __forceinline__ static vec2_t zero_vec2(void) { return make_double2(0, 0); } + __device__ __forceinline__ static vec4_t zero_vec4(void) { return make_double4(0, 0, 0, 0); } + __device__ __forceinline__ static double clamp(double x, double c) { return fmin(fmax(x, -c), c); } +}; +template <> struct InternalType +{ + typedef float scalar_t; typedef float2 vec2_t; typedef float4 vec4_t; + __device__ __forceinline__ static vec2_t zero_vec2(void) { return make_float2(0, 0); } + __device__ __forceinline__ static vec4_t zero_vec4(void) { return make_float4(0, 0, 0, 0); } + __device__ __forceinline__ static float clamp(float x, float c) { return fminf(fmaxf(x, -c), c); } +}; +template <> struct InternalType +{ + typedef float scalar_t; typedef float2 vec2_t; typedef float4 vec4_t; + __device__ __forceinline__ static vec2_t zero_vec2(void) { return make_float2(0, 0); } + __device__ __forceinline__ static vec4_t zero_vec4(void) { return make_float4(0, 0, 0, 0); } + __device__ __forceinline__ static float clamp(float x, float c) { return fminf(fmaxf(x, -c), c); } +}; + +#define MIN(A, B) ((A) < (B) ? (A) : (B)) +#define MAX(A, B) ((A) > (B) ? (A) : (B)) +#define CEIL_DIV(A, B) (((B)==1) ? (A) : \ + ((B)==2) ? ((int)((A)+1) >> 1) : \ + ((B)==4) ? ((int)((A)+3) >> 2) : \ + (((A) + ((A) > 0 ? (B) - 1 : 0)) / (B))) + +// This works only up to blocks of size 256 x 256 and for all N that are powers of two. 
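+// Illustrative note: the multiply-shift branch approximates i / N with a fixed-point
+// reciprocal. For N = 3, (1<<24)/3 + 1 = 5592406 and (767 * 5592406) >> 24 = 255,
+// which equals 767 / 3; per the comment above, this matches the exact quotient for the
+// index ranges used here (i < N * 256, N <= 256).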
+template __device__ __forceinline__ void fast_div_mod(int& x, int& y, unsigned int i) +{ + if ((N & (N-1)) && N <= 256) + y = (i * ((1<<24)/N + 1)) >> 24; // Assumes N <= 256, i < N*256. + else + y = i/N; + + x = i - y*N; +} + +// Type cast stride before reading it. +template __device__ __forceinline__ T get_stride(const int64_t& x) +{ + return *reinterpret_cast(&x); +} + +//------------------------------------------------------------------------ +// Filters, setup kernel, copying function. + +#define MAX_FILTER_SIZE 32 + +// Combined up/down filter buffers so that transfer can be done with one copy. +__device__ float g_fbuf[2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE]; // Filters in global memory, written by setup kernel. +__device__ __constant__ float c_fbuf[2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE]; // Filters in constant memory, read by main kernel. + +// Accessors to combined buffers to index up/down filters individually. +#define c_fu (c_fbuf) +#define c_fd (c_fbuf + MAX_FILTER_SIZE * MAX_FILTER_SIZE) +#define g_fu (g_fbuf) +#define g_fd (g_fbuf + MAX_FILTER_SIZE * MAX_FILTER_SIZE) + +// Set up filters into global memory buffer. +static __global__ void setup_filters_kernel(filtered_lrelu_kernel_params p) +{ + for (int idx = threadIdx.x; idx < MAX_FILTER_SIZE * MAX_FILTER_SIZE; idx += blockDim.x) + { + int x, y; + fast_div_mod(x, y, idx); + + int fu_x = p.flip ? x : (p.fuShape.x - 1 - x); + int fu_y = p.flip ? y : (p.fuShape.y - 1 - y); + if (p.fuShape.y > 0) + g_fu[idx] = (x >= p.fuShape.x || y >= p.fuShape.y) ? 0.0f : p.fu[fu_x * p.fuStride.x + fu_y * p.fuStride.y]; + else + g_fu[idx] = (x >= p.fuShape.x || y > 0) ? 0.0f : p.fu[fu_x * p.fuStride.x]; + + int fd_x = p.flip ? x : (p.fdShape.x - 1 - x); + int fd_y = p.flip ? y : (p.fdShape.y - 1 - y); + if (p.fdShape.y > 0) + g_fd[idx] = (x >= p.fdShape.x || y >= p.fdShape.y) ? 0.0f : p.fd[fd_x * p.fdStride.x + fd_y * p.fdStride.y]; + else + g_fd[idx] = (x >= p.fdShape.x || y > 0) ? 0.0f : p.fd[fd_x * p.fdStride.x]; + } +} + +// Host function to copy filters written by setup kernel into constant buffer for main kernel. +template static cudaError_t copy_filters(cudaStream_t stream) +{ + void* src = 0; + cudaError_t err = cudaGetSymbolAddress(&src, g_fbuf); + if (err) return err; + return cudaMemcpyToSymbolAsync(c_fbuf, src, 2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE * sizeof(float), 0, cudaMemcpyDeviceToDevice, stream); +} + +//------------------------------------------------------------------------ +// Coordinate spaces: +// - Relative to input tensor: inX, inY, tileInX, tileInY +// - Relative to input tile: relInX, relInY, tileInW, tileInH +// - Relative to upsampled tile: relUpX, relUpY, tileUpW, tileUpH +// - Relative to output tile: relOutX, relOutY, tileOutW, tileOutH +// - Relative to output tensor: outX, outY, tileOutX, tileOutY +// +// Relationships between coordinate spaces: +// - inX = tileInX + relInX +// - inY = tileInY + relInY +// - relUpX = relInX * up + phaseInX +// - relUpY = relInY * up + phaseInY +// - relUpX = relOutX * down +// - relUpY = relOutY * down +// - outX = tileOutX + relOutX +// - outY = tileOutY + relOutY + +extern __shared__ char s_buf_raw[]; // When sharedKB <= 48, allocate shared memory statically inside the kernel, otherwise use the externally allocated shared memory buffer. + +template +static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p) +{ + // Check that we don't try to support non-existing filter modes. 
+ static_assert(up == 1 || up == 2 || up == 4, "only up=1, up=2, up=4 scales supported"); + static_assert(down == 1 || down == 2 || down == 4, "only down=1, down=2, down=4 scales supported"); + static_assert(fuSize >= up, "upsampling filter size must be at least upsampling factor"); + static_assert(fdSize >= down, "downsampling filter size must be at least downsampling factor"); + static_assert(fuSize % up == 0, "upsampling filter size must be divisible with upsampling factor"); + static_assert(fdSize % down == 0, "downsampling filter size must be divisible with downsampling factor"); + static_assert(fuSize <= MAX_FILTER_SIZE && fdSize <= MAX_FILTER_SIZE, "filter size greater than MAX_FILTER_SIZE"); + static_assert(up != 1 || (fuSize == 1 && (filterMode == MODE_FUFD || filterMode == MODE_FUSD)), "up=1 supported only for 1x1 full filters"); + static_assert(down != 1 || (fdSize == 1 && (filterMode == MODE_FUFD || filterMode == MODE_SUFD)), "down=1 supported only for 1x1 full filters"); + static_assert(!(up == 4 && (filterMode == MODE_FUFD || filterMode == MODE_FUSD)), "full filters not supported for up=4"); + static_assert(!(down == 4 && (filterMode == MODE_FUFD || filterMode == MODE_SUFD)), "full filters not supported for down=4"); + + // Static definitions. + typedef typename InternalType::scalar_t scalar_t; + typedef typename InternalType::vec2_t vec2_t; + typedef typename InternalType::vec4_t vec4_t; + const int tileUpW = (tileOutW * down + (fdSize - 1) - (down - 1) + 3) & ~3; // Upsampled tile width, rounded up to multiple of 4. + const int tileUpH = tileOutH * down + (fdSize - 1) - (down - 1); // Upsampled tile height. + const int tileInW = CEIL_DIV(tileUpW + (fuSize - 1), up); // Input tile width. + const int tileInH = CEIL_DIV(tileUpH + (fuSize - 1), up); // Input tile height. + const int tileUpH_up = CEIL_DIV(tileUpH, up) * up; // Upsampled tile height rounded up to a multiple of up. + const int tileInH_up = CEIL_DIV(tileUpH_up + (fuSize - 1), up); // For allocations only, to avoid shared memory read overruns with up=2 and up=4. + + // Merge 1x1 downsampling into last upsampling step for upf1 and ups2. + const bool downInline = (down == 1) && ((up == 1 && filterMode == MODE_FUFD) || (up == 2 && filterMode == MODE_SUFD)); + + // Sizes of logical buffers. + const int szIn = tileInH_up * tileInW; + const int szUpX = tileInH_up * tileUpW; + const int szUpXY = downInline ? 0 : (tileUpH * tileUpW); + const int szDownX = tileUpH * tileOutW; + + // Sizes for shared memory arrays. + const int s_buf0_size_base = + (filterMode == MODE_SUSD) ? MAX(szIn, szUpXY) : + (filterMode == MODE_FUSD) ? MAX(szIn, szDownX) : + (filterMode == MODE_SUFD) ? MAX(szIn, szUpXY) : + (filterMode == MODE_FUFD) ? szIn : + -1; + const int s_buf1_size_base = + (filterMode == MODE_SUSD) ? MAX(szUpX, szDownX) : + (filterMode == MODE_FUSD) ? szUpXY : + (filterMode == MODE_SUFD) ? szUpX : + (filterMode == MODE_FUFD) ? szUpXY : + -1; + + // Ensure U128 alignment. + const int s_buf0_size = (s_buf0_size_base + 3) & ~3; + const int s_buf1_size = (s_buf1_size_base + 3) & ~3; + + // Check at compile time that we don't use too much shared memory. + static_assert((s_buf0_size + s_buf1_size) * sizeof(scalar_t) <= (sharedKB << 10), "shared memory overflow"); + + // Declare shared memory arrays. + scalar_t* s_buf0; + scalar_t* s_buf1; + if (sharedKB <= 48) + { + // Allocate shared memory arrays here. + __shared__ scalar_t s_buf0_st[(sharedKB > 48) ? 
(1<<24) : (s_buf0_size + s_buf1_size)]; // Prevent launching if this isn't optimized away when unused. + s_buf0 = s_buf0_st; + s_buf1 = s_buf0 + s_buf0_size; + } + else + { + // Use the dynamically allocated shared memory array. + s_buf0 = (scalar_t*)s_buf_raw; + s_buf1 = s_buf0 + s_buf0_size; + } + + // Pointers to the buffers. + scalar_t* s_tileIn; // Input tile: [relInX * tileInH + relInY] + scalar_t* s_tileUpX; // After horizontal upsampling: [relInY * tileUpW + relUpX] + scalar_t* s_tileUpXY; // After upsampling: [relUpY * tileUpW + relUpX] + scalar_t* s_tileDownX; // After horizontal downsampling: [relUpY * tileOutW + relOutX] + if (filterMode == MODE_SUSD) + { + s_tileIn = s_buf0; + s_tileUpX = s_buf1; + s_tileUpXY = s_buf0; + s_tileDownX = s_buf1; + } + else if (filterMode == MODE_FUSD) + { + s_tileIn = s_buf0; + s_tileUpXY = s_buf1; + s_tileDownX = s_buf0; + } + else if (filterMode == MODE_SUFD) + { + s_tileIn = s_buf0; + s_tileUpX = s_buf1; + s_tileUpXY = s_buf0; + } + else if (filterMode == MODE_FUFD) + { + s_tileIn = s_buf0; + s_tileUpXY = s_buf1; + } + + // Allow large grids in z direction via per-launch offset. + int channelIdx = blockIdx.z + p.blockZofs; + int batchIdx = channelIdx / p.yShape.z; + channelIdx -= batchIdx * p.yShape.z; + + // Offset to output feature map. In bytes. + index_t mapOfsOut = channelIdx * get_stride(p.yStride.z) + batchIdx * get_stride(p.yStride.w); + + // Sign shift amount. + uint32_t signXo = ((threadIdx.x + p.sOfs.x) << 1) & 6; + + // Inner tile loop. + #pragma unroll 1 + for (int tileIdx = 0; !enableXrep || (tileIdx < MIN(p.tilesXrep, p.tilesXdim - p.tilesXrep * blockIdx.y)); tileIdx++) + { + // Locate output tile. + int tileX = enableXrep ? blockIdx.y * p.tilesXrep + tileIdx : blockIdx.x; + int tileOutX = tileX * tileOutW; + int tileOutY = (enableXrep ? blockIdx.x : blockIdx.y) * tileOutH; + + // Locate input tile. + int tmpX = tileOutX * down - p.pad0.x; + int tmpY = tileOutY * down - p.pad0.y; + int tileInX = CEIL_DIV(tmpX, up); + int tileInY = CEIL_DIV(tmpY, up); + const int phaseInX = tileInX * up - tmpX; + const int phaseInY = tileInY * up - tmpY; + + // Extra sync if input and output buffers are the same and we are not on first tile. + if (enableXrep && tileIdx > 0 && (filterMode == MODE_FUSD || (filterMode == MODE_SUFD && !downInline) || (filterMode == MODE_FUFD && downInline))) + __syncthreads(); + + // Load input tile & apply bias. Unrolled. + scalar_t b = (scalar_t)*(const T*)((const char*)p.b + (channelIdx * get_stride(p.bStride))); + index_t mapOfsIn = channelIdx * get_stride(p.xStride.z) + batchIdx * get_stride(p.xStride.w); + int idx = threadIdx.x; + const int loopCountIN = CEIL_DIV(tileInW * tileInH, threadsPerBlock); + #pragma unroll + for (int loop = 0; loop < loopCountIN; loop++) + { + int relInX, relInY; + fast_div_mod(relInX, relInY, idx); + int inX = tileInX + relInX; + int inY = tileInY + relInY; + scalar_t v = 0; + + if ((uint32_t)inX < p.xShape.x && (uint32_t)inY < p.xShape.y) + v = (scalar_t)*((const T*)((const char*)p.x + (inX * get_stride(p.xStride.x) + inY * get_stride(p.xStride.y) + mapOfsIn))) + b; + + bool skip = (loop == loopCountIN-1) && (idx >= tileInW * tileInH); + if (!skip) + s_tileIn[idx] = v; + + idx += threadsPerBlock; + } + + if (filterMode == MODE_SUSD || filterMode == MODE_SUFD) // Separable upsampling filter. + { + // Horizontal upsampling. 
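+        // Illustrative note: in the separable modes the upsampling filter is applied in
+        // two 1-D passes: this block convolves rows with c_fu along x into s_tileUpX,
+        // and the following "Vertical upsampling" block convolves columns along y into
+        // s_tileUpXY, where the leaky-ReLU, clamping and sign handling are also applied.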
+ __syncthreads(); + if (up == 4) + { + for (int idx = threadIdx.x*up; idx < tileUpW * tileInH; idx += blockDim.x*up) + { + int relUpX0, relInY; + fast_div_mod(relUpX0, relInY, idx); + int relInX0 = relUpX0 / up; + int src0 = relInX0 + tileInW * relInY; + int dst = relInY * tileUpW + relUpX0; + vec4_t v = InternalType::zero_vec4(); + scalar_t a = s_tileIn[src0]; + if (phaseInX == 0) + { + #pragma unroll + for (int step = 0; step < fuSize / up; step++) + { + v.x += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileIn[src0 + step + 1]; + v.y += a * (scalar_t)c_fu[step * up + 3]; + v.z += a * (scalar_t)c_fu[step * up + 2]; + v.w += a * (scalar_t)c_fu[step * up + 1]; + } + } + else if (phaseInX == 1) + { + #pragma unroll + for (int step = 0; step < fuSize / up; step++) + { + v.x += a * (scalar_t)c_fu[step * up + 1]; + v.y += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileIn[src0 + step + 1]; + v.z += a * (scalar_t)c_fu[step * up + 3]; + v.w += a * (scalar_t)c_fu[step * up + 2]; + } + } + else if (phaseInX == 2) + { + #pragma unroll + for (int step = 0; step < fuSize / up; step++) + { + v.x += a * (scalar_t)c_fu[step * up + 2]; + v.y += a * (scalar_t)c_fu[step * up + 1]; + v.z += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileIn[src0 + step + 1]; + v.w += a * (scalar_t)c_fu[step * up + 3]; + } + } + else // (phaseInX == 3) + { + #pragma unroll + for (int step = 0; step < fuSize / up; step++) + { + v.x += a * (scalar_t)c_fu[step * up + 3]; + v.y += a * (scalar_t)c_fu[step * up + 2]; + v.z += a * (scalar_t)c_fu[step * up + 1]; + v.w += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileIn[src0 + step + 1]; + } + } + s_tileUpX[dst+0] = v.x; + s_tileUpX[dst+1] = v.y; + s_tileUpX[dst+2] = v.z; + s_tileUpX[dst+3] = v.w; + } + } + else if (up == 2) + { + bool p0 = (phaseInX == 0); + for (int idx = threadIdx.x*up; idx < tileUpW * tileInH; idx += blockDim.x*up) + { + int relUpX0, relInY; + fast_div_mod(relUpX0, relInY, idx); + int relInX0 = relUpX0 / up; + int src0 = relInX0 + tileInW * relInY; + int dst = relInY * tileUpW + relUpX0; + vec2_t v = InternalType::zero_vec2(); + scalar_t a = s_tileIn[src0]; + if (p0) // (phaseInX == 0) + { + #pragma unroll + for (int step = 0; step < fuSize / up; step++) + { + v.x += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileIn[src0 + step + 1]; + v.y += a * (scalar_t)c_fu[step * up + 1]; + } + } + else // (phaseInX == 1) + { + #pragma unroll + for (int step = 0; step < fuSize / up; step++) + { + v.x += a * (scalar_t)c_fu[step * up + 1]; + v.y += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileIn[src0 + step + 1]; + } + } + s_tileUpX[dst+0] = v.x; + s_tileUpX[dst+1] = v.y; + } + } + + // Vertical upsampling & nonlinearity. + + __syncthreads(); + int groupMask = 15 << ((threadIdx.x & 31) & ~3); + int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH : 0; // Skip already written signs. + int sShapeMaxY = MIN(p.sShape.y, tileOutY * down + tileUpH); // Avoid out-of-tile sign writes. + if (up == 4) + { + minY -= 3; // Adjust according to block height. 
+ for (int idx = threadIdx.x; idx < tileUpW * tileUpH_up / up; idx += blockDim.x) + { + int relUpX, relInY0; + fast_div_mod(relUpX, relInY0, idx); + int relUpY0 = relInY0 * up; + int src0 = relInY0 * tileUpW + relUpX; + int dst = relUpY0 * tileUpW + relUpX; + vec4_t v = InternalType::zero_vec4(); + + scalar_t a = s_tileUpX[src0]; + if (phaseInY == 0) + { + #pragma unroll + for (int step = 0; step < fuSize / up; step++) + { + v.x += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileUpX[src0 + (step + 1) * tileUpW]; + v.y += a * (scalar_t)c_fu[step * up + 3]; + v.z += a * (scalar_t)c_fu[step * up + 2]; + v.w += a * (scalar_t)c_fu[step * up + 1]; + } + } + else if (phaseInY == 1) + { + #pragma unroll + for (int step = 0; step < fuSize / up; step++) + { + v.x += a * (scalar_t)c_fu[step * up + 1]; + v.y += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileUpX[src0 + (step + 1) * tileUpW]; + v.z += a * (scalar_t)c_fu[step * up + 3]; + v.w += a * (scalar_t)c_fu[step * up + 2]; + } + } + else if (phaseInY == 2) + { + #pragma unroll + for (int step = 0; step < fuSize / up; step++) + { + v.x += a * (scalar_t)c_fu[step * up + 2]; + v.y += a * (scalar_t)c_fu[step * up + 1]; + v.z += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileUpX[src0 + (step + 1) * tileUpW]; + v.w += a * (scalar_t)c_fu[step * up + 3]; + } + } + else // (phaseInY == 3) + { + #pragma unroll + for (int step = 0; step < fuSize / up; step++) + { + v.x += a * (scalar_t)c_fu[step * up + 3]; + v.y += a * (scalar_t)c_fu[step * up + 2]; + v.z += a * (scalar_t)c_fu[step * up + 1]; + v.w += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileUpX[src0 + (step + 1) * tileUpW]; + } + } + + int x = tileOutX * down + relUpX; + int y = tileOutY * down + relUpY0; + int signX = x + p.sOfs.x; + int signY = y + p.sOfs.y; + int signZ = blockIdx.z + p.blockZofs; + int signXb = signX >> 2; + index_t si0 = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ); + index_t si1 = si0 + p.sShape.x; + index_t si2 = si0 + p.sShape.x * 2; + index_t si3 = si0 + p.sShape.x * 3; + + v.x *= (scalar_t)((float)up * (float)up * p.gain); + v.y *= (scalar_t)((float)up * (float)up * p.gain); + v.z *= (scalar_t)((float)up * (float)up * p.gain); + v.w *= (scalar_t)((float)up * (float)up * p.gain); + + if (signWrite) + { + if (!enableWriteSkip) + { + // Determine and write signs. + int sx = __float_as_uint(v.x) >> 31 << 0; + int sy = __float_as_uint(v.y) >> 31 << 8; + int sz = __float_as_uint(v.z) >> 31 << 16; + int sw = __float_as_uint(v.w) >> 31 << 24; + if (sx) v.x *= p.slope; + if (sy) v.y *= p.slope; + if (sz) v.z *= p.slope; + if (sw) v.w *= p.slope; + if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType::clamp(v.x, p.clamp); } + if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType::clamp(v.y, p.clamp); } + if (fabsf(v.z) > p.clamp) { sz = 2 << 16; v.z = InternalType::clamp(v.z, p.clamp); } + if (fabsf(v.w) > p.clamp) { sw = 2 << 24; v.w = InternalType::clamp(v.w, p.clamp); } + + if ((uint32_t)signXb < p.swLimit && signY >= minY) + { + // Combine signs. + uint32_t s = sx + sy + sw + sz; + s <<= (signX & 3) << 1; + s |= __shfl_xor_sync(groupMask, s, 1); + s |= __shfl_xor_sync(groupMask, s, 2); + + // Write signs. 
+ if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); } + if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); } + if ((uint32_t)(signY + 2) < sShapeMaxY) { p.s[si2] = (unsigned char)(s >> 16); } + if ((uint32_t)(signY + 3) < sShapeMaxY) { p.s[si3] = (unsigned char)(s >> 24); } + } + } + else + { + // Determine and write signs. + if ((uint32_t)signXb < p.swLimit && signY >= minY) + { + int sx = __float_as_uint(v.x) >> 31 << 0; + int sy = __float_as_uint(v.y) >> 31 << 8; + int sz = __float_as_uint(v.z) >> 31 << 16; + int sw = __float_as_uint(v.w) >> 31 << 24; + if (sx) v.x *= p.slope; + if (sy) v.y *= p.slope; + if (sz) v.z *= p.slope; + if (sw) v.w *= p.slope; + if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType::clamp(v.x, p.clamp); } + if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType::clamp(v.y, p.clamp); } + if (fabsf(v.z) > p.clamp) { sz = 2 << 16; v.z = InternalType::clamp(v.z, p.clamp); } + if (fabsf(v.w) > p.clamp) { sw = 2 << 24; v.w = InternalType::clamp(v.w, p.clamp); } + + // Combine signs. + uint32_t s = sx + sy + sw + sz; + s <<= (signX & 3) << 1; + s |= __shfl_xor_sync(groupMask, s, 1); + s |= __shfl_xor_sync(groupMask, s, 2); + + // Write signs. + if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); } + if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); } + if ((uint32_t)(signY + 2) < sShapeMaxY) { p.s[si2] = (unsigned char)(s >> 16); } + if ((uint32_t)(signY + 3) < sShapeMaxY) { p.s[si3] = (unsigned char)(s >> 24); } + } + else + { + // Just compute the values. + if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp); + if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp); + if (v.z < 0.f) v.z *= p.slope; v.z = InternalType::clamp(v.z, p.clamp); + if (v.w < 0.f) v.w *= p.slope; v.w = InternalType::clamp(v.w, p.clamp); + } + } + } + else if (signRead) // Read signs and apply. + { + if ((uint32_t)signXb < p.swLimit) + { + int ss = (signX & 3) << 1; + if ((uint32_t)(signY + 0) < p.sShape.y) { int s = p.s[si0] >> ss; if (s & 1) v.x *= p.slope; if (s & 2) v.x = 0.f; } + if ((uint32_t)(signY + 1) < p.sShape.y) { int s = p.s[si1] >> ss; if (s & 1) v.y *= p.slope; if (s & 2) v.y = 0.f; } + if ((uint32_t)(signY + 2) < p.sShape.y) { int s = p.s[si2] >> ss; if (s & 1) v.z *= p.slope; if (s & 2) v.z = 0.f; } + if ((uint32_t)(signY + 3) < p.sShape.y) { int s = p.s[si3] >> ss; if (s & 1) v.w *= p.slope; if (s & 2) v.w = 0.f; } + } + } + else // Forward pass with no sign write. + { + if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp); + if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp); + if (v.z < 0.f) v.z *= p.slope; v.z = InternalType::clamp(v.z, p.clamp); + if (v.w < 0.f) v.w *= p.slope; v.w = InternalType::clamp(v.w, p.clamp); + } + + s_tileUpXY[dst + 0 * tileUpW] = v.x; + if (relUpY0 + 1 < tileUpH) s_tileUpXY[dst + 1 * tileUpW] = v.y; + if (relUpY0 + 2 < tileUpH) s_tileUpXY[dst + 2 * tileUpW] = v.z; + if (relUpY0 + 3 < tileUpH) s_tileUpXY[dst + 3 * tileUpW] = v.w; + } + } + else if (up == 2) + { + minY -= 1; // Adjust according to block height. 
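The sign-handling branches above store two bits per activation — bit 0 if the value was negative (gradient must be scaled by the slope), bit 1 if it was clamped (gradient must be zeroed) — and pack four neighbouring x positions into one byte of p.s. A plain-Python sketch of that encoding and of how the read path re-applies it (the warp-shuffle coalescing itself is omitted):

```python
# Plain-Python sketch of the 2-bit sign code used above (warp-shuffle packing omitted).
def encode(v, slope, clamp):
    """Forward: apply leaky ReLU and clamp to one value, return (value, 2-bit code)."""
    code = 0
    if v < 0.0:
        v *= slope
        code = 1                        # bit 0: value was negative
    if abs(v) > clamp:
        v = clamp if v > 0 else -clamp
        code = 2                        # bit 1: value was clamped (overrides bit 0)
    return v, code

def pack(codes):
    """Pack four 2-bit codes (x & 3 == 0..3) into one byte of the sign tensor."""
    b = 0
    for lane, code in enumerate(codes):
        b |= code << (lane * 2)
    return b

def apply_code(v, byte, lane, slope):
    """Read path (e.g. the gradient pass): re-apply the stored decision."""
    code = (byte >> (lane * 2)) & 3
    if code & 1:
        v *= slope
    if code & 2:
        v = 0.0
    return v

out, codes = zip(*(encode(v, slope=0.2, clamp=5.0) for v in [0.7, -1.3, 9.0, -0.1]))
byte = pack(codes)
print(out, format(byte, '08b'))                                       # codes 0,1,2,1 -> 01100100
print([apply_code(1.0, byte, lane, slope=0.2) for lane in range(4)])  # [1.0, 0.2, 0.0, 0.2]
```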
+ for (int idx = threadIdx.x; idx < tileUpW * tileUpH_up / up; idx += blockDim.x) + { + int relUpX, relInY0; + fast_div_mod(relUpX, relInY0, idx); + int relUpY0 = relInY0 * up; + int src0 = relInY0 * tileUpW + relUpX; + int dst = relUpY0 * tileUpW + relUpX; + vec2_t v = InternalType::zero_vec2(); + + scalar_t a = s_tileUpX[src0]; + if (phaseInY == 0) + { + #pragma unroll + for (int step = 0; step < fuSize / up; step++) + { + v.x += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileUpX[src0 + (step + 1) * tileUpW]; + v.y += a * (scalar_t)c_fu[step * up + 1]; + } + } + else // (phaseInY == 1) + { + #pragma unroll + for (int step = 0; step < fuSize / up; step++) + { + v.x += a * (scalar_t)c_fu[step * up + 1]; + v.y += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileUpX[src0 + (step + 1) * tileUpW]; + } + } + + int x = tileOutX * down + relUpX; + int y = tileOutY * down + relUpY0; + int signX = x + p.sOfs.x; + int signY = y + p.sOfs.y; + int signZ = blockIdx.z + p.blockZofs; + int signXb = signX >> 2; + index_t si0 = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ); + index_t si1 = si0 + p.sShape.x; + + v.x *= (scalar_t)((float)up * (float)up * p.gain); + v.y *= (scalar_t)((float)up * (float)up * p.gain); + + if (signWrite) + { + if (!enableWriteSkip) + { + // Determine and write signs. + int sx = __float_as_uint(v.x) >> 31 << 0; + int sy = __float_as_uint(v.y) >> 31 << 8; + if (sx) v.x *= p.slope; + if (sy) v.y *= p.slope; + if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType::clamp(v.x, p.clamp); } + if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType::clamp(v.y, p.clamp); } + + if ((uint32_t)signXb < p.swLimit && signY >= minY) + { + // Combine signs. + int s = sx + sy; + s <<= signXo; + s |= __shfl_xor_sync(groupMask, s, 1); + s |= __shfl_xor_sync(groupMask, s, 2); + + // Write signs. + if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); } + if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); } + } + } + else + { + // Determine and write signs. + if ((uint32_t)signXb < p.swLimit && signY >= minY) + { + int sx = __float_as_uint(v.x) >> 31 << 0; + int sy = __float_as_uint(v.y) >> 31 << 8; + if (sx) v.x *= p.slope; + if (sy) v.y *= p.slope; + if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType::clamp(v.x, p.clamp); } + if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType::clamp(v.y, p.clamp); } + + // Combine signs. + int s = sx + sy; + s <<= signXo; + s |= __shfl_xor_sync(groupMask, s, 1); + s |= __shfl_xor_sync(groupMask, s, 2); + + // Write signs. + if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); } + if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); } + } + else + { + // Just compute the values. + if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp); + if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp); + } + } + } + else if (signRead) // Read signs and apply. + { + if ((uint32_t)signXb < p.swLimit) + { + if ((uint32_t)(signY + 0) < p.sShape.y) { int s = p.s[si0] >> signXo; if (s & 1) v.x *= p.slope; if (s & 2) v.x = 0.f; } + if ((uint32_t)(signY + 1) < p.sShape.y) { int s = p.s[si1] >> signXo; if (s & 1) v.y *= p.slope; if (s & 2) v.y = 0.f; } + } + } + else // Forward pass with no sign write. + { + if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp); + if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp); + } + + if (!downInline) + { + // Write into temporary buffer. 
+ s_tileUpXY[dst] = v.x; + if (relUpY0 < tileUpH - 1) + s_tileUpXY[dst + tileUpW] = v.y; + } + else + { + // Write directly into output buffer. + if ((uint32_t)x < p.yShape.x) + { + int ymax = MIN(p.yShape.y, tileUpH + tileOutY * down); + index_t ofs = x * get_stride(p.yStride.x) + y * get_stride(p.yStride.y) + mapOfsOut; + if ((uint32_t)y + 0 < p.yShape.y) *((T*)((char*)p.y + ofs)) = (T)(v.x * (scalar_t)c_fd[0]); + if ((uint32_t)y + 1 < ymax) *((T*)((char*)p.y + ofs + get_stride(p.yStride.y))) = (T)(v.y * (scalar_t)c_fd[0]); + } + } + } + } + } + else if (filterMode == MODE_FUSD || filterMode == MODE_FUFD) + { + // Full upsampling filter. + + if (up == 2) + { + // 2 x 2-wide. + __syncthreads(); + int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH + p.sOfs.y : 0; // Skip already written signs. + for (int idx = threadIdx.x * 4; idx < tileUpW * tileUpH; idx += blockDim.x * 4) + { + int relUpX0, relUpY0; + fast_div_mod(relUpX0, relUpY0, idx); + int relInX0 = CEIL_DIV(relUpX0 - phaseInX, up); + int relInY0 = CEIL_DIV(relUpY0 - phaseInY, up); + int src0 = relInX0 + tileInW * relInY0; + int tap0y = (relInY0 * up + phaseInY - relUpY0); + + #define X_LOOP(TAPY, PX) \ + for (int sx = 0; sx < fuSize / up; sx++) \ + { \ + v.x += a * (scalar_t)c_fu[(sx * up + (((PX) - 0) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; \ + v.z += b * (scalar_t)c_fu[(sx * up + (((PX) - 0) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; if ((PX) == 0) { a = b; b = s_tileIn[src0 + 2 + sx + sy * tileInW]; } \ + v.y += a * (scalar_t)c_fu[(sx * up + (((PX) - 1) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; \ + v.w += b * (scalar_t)c_fu[(sx * up + (((PX) - 1) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; if ((PX) == 1) { a = b; b = s_tileIn[src0 + 2 + sx + sy * tileInW]; } \ + } + + vec4_t v = InternalType::zero_vec4(); + if (tap0y == 0 && phaseInX == 0) + #pragma unroll + for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1]; + #pragma unroll + X_LOOP(0, 0) } + if (tap0y == 0 && phaseInX == 1) + #pragma unroll + for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1]; + #pragma unroll + X_LOOP(0, 1) } + if (tap0y == 1 && phaseInX == 0) + #pragma unroll + for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1]; + #pragma unroll + X_LOOP(1, 0) } + if (tap0y == 1 && phaseInX == 1) + #pragma unroll + for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1]; + #pragma unroll + X_LOOP(1, 1) } + + #undef X_LOOP + + int x = tileOutX * down + relUpX0; + int y = tileOutY * down + relUpY0; + int signX = x + p.sOfs.x; + int signY = y + p.sOfs.y; + int signZ = blockIdx.z + p.blockZofs; + int signXb = signX >> 2; + index_t si = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ); + + v.x *= (scalar_t)((float)up * (float)up * p.gain); + v.y *= (scalar_t)((float)up * (float)up * p.gain); + v.z *= (scalar_t)((float)up * (float)up * p.gain); + v.w *= (scalar_t)((float)up * (float)up * p.gain); + + if (signWrite) + { + if (!enableWriteSkip) + { + // Determine and write signs. 
+ int sx = __float_as_uint(v.x) >> 31; + int sy = __float_as_uint(v.y) >> 31; + int sz = __float_as_uint(v.z) >> 31; + int sw = __float_as_uint(v.w) >> 31; + if (sx) v.x *= p.slope; if (fabsf(v.x) > p.clamp) { sx = 2; v.x = InternalType::clamp(v.x, p.clamp); } + if (sy) v.y *= p.slope; if (fabsf(v.y) > p.clamp) { sy = 2; v.y = InternalType::clamp(v.y, p.clamp); } + if (sz) v.z *= p.slope; if (fabsf(v.z) > p.clamp) { sz = 2; v.z = InternalType::clamp(v.z, p.clamp); } + if (sw) v.w *= p.slope; if (fabsf(v.w) > p.clamp) { sw = 2; v.w = InternalType::clamp(v.w, p.clamp); } + + if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY) + { + p.s[si] = sx + (sy << 2) + (sz << 4) + (sw << 6); + } + } + else + { + // Determine and write signs. + if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY) + { + int sx = __float_as_uint(v.x) >> 31; + int sy = __float_as_uint(v.y) >> 31; + int sz = __float_as_uint(v.z) >> 31; + int sw = __float_as_uint(v.w) >> 31; + if (sx) v.x *= p.slope; if (fabsf(v.x) > p.clamp) { sx = 2; v.x = InternalType::clamp(v.x, p.clamp); } + if (sy) v.y *= p.slope; if (fabsf(v.y) > p.clamp) { sy = 2; v.y = InternalType::clamp(v.y, p.clamp); } + if (sz) v.z *= p.slope; if (fabsf(v.z) > p.clamp) { sz = 2; v.z = InternalType::clamp(v.z, p.clamp); } + if (sw) v.w *= p.slope; if (fabsf(v.w) > p.clamp) { sw = 2; v.w = InternalType::clamp(v.w, p.clamp); } + + p.s[si] = sx + (sy << 2) + (sz << 4) + (sw << 6); + } + else + { + // Just compute the values. + if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp); + if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp); + if (v.z < 0.f) v.z *= p.slope; v.z = InternalType::clamp(v.z, p.clamp); + if (v.w < 0.f) v.w *= p.slope; v.w = InternalType::clamp(v.w, p.clamp); + } + } + } + else if (signRead) // Read sign and apply. + { + if ((uint32_t)signY < p.sShape.y) + { + int s = 0; + if ((uint32_t)signXb < p.swLimit) s = p.s[si]; + if ((uint32_t)signXb + 1 < p.swLimit) s |= p.s[si + 1] << 8; + s >>= (signX & 3) << 1; + if (s & 0x01) v.x *= p.slope; if (s & 0x02) v.x = 0.f; + if (s & 0x04) v.y *= p.slope; if (s & 0x08) v.y = 0.f; + if (s & 0x10) v.z *= p.slope; if (s & 0x20) v.z = 0.f; + if (s & 0x40) v.w *= p.slope; if (s & 0x80) v.w = 0.f; + } + } + else // Forward pass with no sign write. + { + if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp); + if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp); + if (v.z < 0.f) v.z *= p.slope; v.z = InternalType::clamp(v.z, p.clamp); + if (v.w < 0.f) v.w *= p.slope; v.w = InternalType::clamp(v.w, p.clamp); + } + + s_tileUpXY[idx + 0] = v.x; + s_tileUpXY[idx + 1] = v.y; + s_tileUpXY[idx + 2] = v.z; + s_tileUpXY[idx + 3] = v.w; + } + } + else if (up == 1) + { + __syncthreads(); + uint32_t groupMask = 15 << ((threadIdx.x & 31) & ~3); + int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH : 0; // Skip already written signs. + for (int idx = threadIdx.x; idx < tileUpW * tileUpH; idx += blockDim.x) + { + int relUpX0, relUpY0; + fast_div_mod(relUpX0, relUpY0, idx); + scalar_t v = s_tileIn[idx] * (scalar_t)c_fu[0]; // 1x1 filter. 
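Once the filtering is done, every branch above applies the same per-element nonlinearity: scale by up²·gain (zero-insertion compensation times the user gain), leaky ReLU, symmetric clamp. For comparison, the same math written in plain PyTorch (a reference sketch, not what the kernel executes):

```python
import torch

def pointwise_nonlinearity(v, up, gain, slope, clamp):
    """Reference for the per-element math in the branches above."""
    v = v * (up * up * gain)                 # compensate zero-insertion, apply user gain
    v = torch.where(v < 0, v * slope, v)     # leaky ReLU
    return v.clamp(-clamp, clamp)            # symmetric clamp

x = torch.randn(8)
print(pointwise_nonlinearity(x, up=2, gain=2 ** 0.5, slope=0.2, clamp=256.0))
```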
+ + int x = tileOutX * down + relUpX0; + int y = tileOutY * down + relUpY0; + int signX = x + p.sOfs.x; + int signY = y + p.sOfs.y; + int signZ = blockIdx.z + p.blockZofs; + int signXb = signX >> 2; + index_t si = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ); + v *= (scalar_t)((float)up * (float)up * p.gain); + + if (signWrite) + { + if (!enableWriteSkip) + { + // Determine and write sign. + uint32_t s = 0; + uint32_t signXbit = (1u << signXo); + if (v < 0.f) + { + s = signXbit; + v *= p.slope; + } + if (fabsf(v) > p.clamp) + { + s = signXbit * 2; + v = InternalType::clamp(v, p.clamp); + } + if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY) + { + s += __shfl_xor_sync(groupMask, s, 1); // Coalesce. + s += __shfl_xor_sync(groupMask, s, 2); // Coalesce. + p.s[si] = s; // Write. + } + } + else + { + // Determine and write sign. + if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY) + { + uint32_t s = 0; + uint32_t signXbit = (1u << signXo); + if (v < 0.f) + { + s = signXbit; + v *= p.slope; + } + if (fabsf(v) > p.clamp) + { + s = signXbit * 2; + v = InternalType::clamp(v, p.clamp); + } + s += __shfl_xor_sync(groupMask, s, 1); // Coalesce. + s += __shfl_xor_sync(groupMask, s, 2); // Coalesce. + p.s[si] = s; // Write. + } + else + { + // Just compute the value. + if (v < 0.f) v *= p.slope; + v = InternalType::clamp(v, p.clamp); + } + } + } + else if (signRead) + { + // Read sign and apply if within sign tensor bounds. + if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y) + { + int s = p.s[si]; + s >>= signXo; + if (s & 1) v *= p.slope; + if (s & 2) v = 0.f; + } + } + else // Forward pass with no sign write. + { + if (v < 0.f) v *= p.slope; + v = InternalType::clamp(v, p.clamp); + } + + if (!downInline) // Write into temporary buffer. + s_tileUpXY[idx] = v; + else if ((uint32_t)x < p.yShape.x && (uint32_t)y < p.yShape.y) // Write directly into output buffer + *((T*)((char*)p.y + (x * get_stride(p.yStride.x) + y * get_stride(p.yStride.y) + mapOfsOut))) = (T)(v * (scalar_t)c_fd[0]); + } + } + } + + // Downsampling. + if (filterMode == MODE_SUSD || filterMode == MODE_FUSD) + { + // Horizontal downsampling. + __syncthreads(); + if (down == 4 && tileOutW % 4 == 0) + { + // Calculate 4 pixels at a time. + for (int idx = threadIdx.x * 4; idx < tileOutW * tileUpH; idx += blockDim.x * 4) + { + int relOutX0, relUpY; + fast_div_mod(relOutX0, relUpY, idx); + int relUpX0 = relOutX0 * down; + int src0 = relUpY * tileUpW + relUpX0; + vec4_t v = InternalType::zero_vec4(); + #pragma unroll + for (int step = 0; step < fdSize; step++) + { + v.x += s_tileUpXY[src0 + 0 + step] * (scalar_t)c_fd[step]; + v.y += s_tileUpXY[src0 + 4 + step] * (scalar_t)c_fd[step]; + v.z += s_tileUpXY[src0 + 8 + step] * (scalar_t)c_fd[step]; + v.w += s_tileUpXY[src0 + 12 + step] * (scalar_t)c_fd[step]; + } + s_tileDownX[idx+0] = v.x; + s_tileDownX[idx+1] = v.y; + s_tileDownX[idx+2] = v.z; + s_tileDownX[idx+3] = v.w; + } + } + else if ((down == 2 || down == 4) && (tileOutW % 2 == 0)) + { + // Calculate 2 pixels at a time. 
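The downsampling passes starting above (4-, 2- and 1-pixel variants, followed by the vertical pass below) all evaluate the same thing: a FIR filter on the upsampled grid, sampled at stride `down`. A NumPy sketch of the horizontal pass only, with illustrative indexing (the kernel's tiling and tap layout are more involved):

```python
import numpy as np

def downsample_1d(row, fd, down):
    """y[i] = sum_k row[i*down + k] * fd[k] -- the FIR filter evaluated every `down` samples."""
    n_out = (len(row) - len(fd)) // down + 1
    return np.array([np.dot(row[i * down : i * down + len(fd)], fd) for i in range(n_out)])

row = np.arange(20, dtype=np.float64)
fd = np.array([0.25, 0.5, 0.25])
print(downsample_1d(row, fd, down=2))   # [1, 3, 5, ..., 17] for this ramp input
```

The vertical pass below repeats the same dot product down columns of the intermediate buffer (stride tileOutW instead of 1).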
+ for (int idx = threadIdx.x * 2; idx < tileOutW * tileUpH; idx += blockDim.x * 2) + { + int relOutX0, relUpY; + fast_div_mod(relOutX0, relUpY, idx); + int relUpX0 = relOutX0 * down; + int src0 = relUpY * tileUpW + relUpX0; + vec2_t v = InternalType::zero_vec2(); + #pragma unroll + for (int step = 0; step < fdSize; step++) + { + v.x += s_tileUpXY[src0 + 0 + step] * (scalar_t)c_fd[step]; + v.y += s_tileUpXY[src0 + down + step] * (scalar_t)c_fd[step]; + } + s_tileDownX[idx+0] = v.x; + s_tileDownX[idx+1] = v.y; + } + } + else + { + // Calculate 1 pixel at a time. + for (int idx = threadIdx.x; idx < tileOutW * tileUpH; idx += blockDim.x) + { + int relOutX0, relUpY; + fast_div_mod(relOutX0, relUpY, idx); + int relUpX0 = relOutX0 * down; + int src = relUpY * tileUpW + relUpX0; + scalar_t v = 0.f; + #pragma unroll + for (int step = 0; step < fdSize; step++) + v += s_tileUpXY[src + step] * (scalar_t)c_fd[step]; + s_tileDownX[idx] = v; + } + } + + // Vertical downsampling & store output tile. + __syncthreads(); + for (int idx = threadIdx.x; idx < tileOutW * tileOutH; idx += blockDim.x) + { + int relOutX, relOutY0; + fast_div_mod(relOutX, relOutY0, idx); + int relUpY0 = relOutY0 * down; + int src0 = relUpY0 * tileOutW + relOutX; + scalar_t v = 0; + #pragma unroll + for (int step = 0; step < fdSize; step++) + v += s_tileDownX[src0 + step * tileOutW] * (scalar_t)c_fd[step]; + + int outX = tileOutX + relOutX; + int outY = tileOutY + relOutY0; + + if (outX < p.yShape.x & outY < p.yShape.y) + *((T*)((char*)p.y + (outX * get_stride(p.yStride.x) + outY * get_stride(p.yStride.y) + mapOfsOut))) = (T)v; + } + } + else if (filterMode == MODE_SUFD || filterMode == MODE_FUFD) + { + // Full downsampling filter. + if (down == 2) + { + // 2-wide. + __syncthreads(); + for (int idx = threadIdx.x * 2; idx < tileOutW * tileOutH; idx += blockDim.x * 2) + { + int relOutX0, relOutY0; + fast_div_mod(relOutX0, relOutY0, idx); + int relUpX0 = relOutX0 * down; + int relUpY0 = relOutY0 * down; + int src0 = relUpY0 * tileUpW + relUpX0; + vec2_t v = InternalType::zero_vec2(); + #pragma unroll + for (int sy = 0; sy < fdSize; sy++) + #pragma unroll + for (int sx = 0; sx < fdSize; sx++) + { + v.x += s_tileUpXY[src0 + 0 + sx + sy * tileUpW] * (scalar_t)c_fd[sx + sy * MAX_FILTER_SIZE]; + v.y += s_tileUpXY[src0 + 2 + sx + sy * tileUpW] * (scalar_t)c_fd[sx + sy * MAX_FILTER_SIZE]; + } + + int outX = tileOutX + relOutX0; + int outY = tileOutY + relOutY0; + if ((uint32_t)outY < p.yShape.y) + { + index_t ofs = outX * get_stride(p.yStride.x) + outY * get_stride(p.yStride.y) + mapOfsOut; + if (outX + 0 < p.yShape.x) *((T*)((char*)p.y + ofs)) = (T)v.x; + if (outX + 1 < p.yShape.x) *((T*)((char*)p.y + ofs + get_stride(p.yStride.x))) = (T)v.y; + } + } + } + else if (down == 1 && !downInline) + { + // Thread per pixel. + __syncthreads(); + for (int idx = threadIdx.x; idx < tileOutW * tileOutH; idx += blockDim.x) + { + int relOutX0, relOutY0; + fast_div_mod(relOutX0, relOutY0, idx); + scalar_t v = s_tileUpXY[idx] * (scalar_t)c_fd[0]; // 1x1 filter. + + int outX = tileOutX + relOutX0; + int outY = tileOutY + relOutY0; + if ((uint32_t)outX < p.yShape.x && (uint32_t)outY < p.yShape.y) + *((T*)((char*)p.y + (outX * get_stride(p.yStride.x) + outY * get_stride(p.yStride.y) + mapOfsOut))) = (T)v; + } + } + } + + if (!enableXrep) + break; + } +} + +//------------------------------------------------------------------------ +// Compute activation function and signs for upsampled data tensor, modifying data tensor in-place. 
Used for accelerating the generic variant. +// Sign tensor is known to be contiguous, and p.x and p.s have the same z, w dimensions. 64-bit indexing is always used. + +template +static __global__ void filtered_lrelu_act_kernel(filtered_lrelu_act_kernel_params p) +{ + typedef typename InternalType::scalar_t scalar_t; + + // Indexing. + int32_t x = threadIdx.x + blockIdx.x * blockDim.x; + int32_t ymax = signWrite ? p.sShape.y : p.xShape.y; + int32_t qmax = p.xShape.z * p.xShape.w; // Combined minibatch*channel maximum index. + + // Loop to accommodate oversized tensors. + for (int32_t q = blockIdx.z; q < qmax; q += gridDim.z) + for (int32_t y = blockIdx.y; y < ymax; y += gridDim.y) + { + // Extract z and w (channel, minibatch index). + int32_t w = q / p.xShape.z; + int32_t z = q - w * p.xShape.z; + + // Choose behavior based on sign read/write mode. + if (signWrite) + { + // Process value if in p.x. + uint32_t s = 0; + if (x < p.xShape.x && y < p.xShape.y) + { + int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w; + T* pv = ((T*)p.x) + ix; + scalar_t v = (scalar_t)(*pv); + + // Gain, LReLU, clamp. + v *= p.gain; + if (v < 0.f) + { + v *= p.slope; + s = 1; // Sign. + } + if (fabsf(v) > p.clamp) + { + v = InternalType::clamp(v, p.clamp); + s = 2; // Clamp. + } + + *pv = (T)v; // Write value. + } + + // Coalesce into threads 0 and 16 of warp. + uint32_t m = (threadIdx.x & 16) ? 0xffff0000u : 0x0000ffffu; + s <<= ((threadIdx.x & 15) << 1); // Shift into place. + s |= __shfl_xor_sync(m, s, 1); // Distribute. + s |= __shfl_xor_sync(m, s, 2); + s |= __shfl_xor_sync(m, s, 4); + s |= __shfl_xor_sync(m, s, 8); + + // Write signs if leader and in p.s. + if (!(threadIdx.x & 15) && x < p.sShape.x) // y is always in. + { + uint64_t is = x + p.sShape.x * (y + (int64_t)p.sShape.y * q); // Contiguous. + ((uint32_t*)p.s)[is >> 4] = s; + } + } + else if (signRead) + { + // Process value if in p.x. + if (x < p.xShape.x) // y is always in. + { + int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w; + T* pv = ((T*)p.x) + ix; + scalar_t v = (scalar_t)(*pv); + v *= p.gain; + + // Apply sign buffer offset. + uint32_t sx = x + p.sOfs.x; + uint32_t sy = y + p.sOfs.y; + + // Read and apply signs if we land inside valid region of sign buffer. + if (sx < p.sShape.x && sy < p.sShape.y) + { + uint64_t is = (sx >> 2) + (p.sShape.x >> 2) * (sy + (uint64_t)p.sShape.y * q); // Contiguous. + unsigned char s = p.s[is]; + s >>= (sx & 3) << 1; // Shift into place. + if (s & 1) // Sign? + v *= p.slope; + if (s & 2) // Clamp? + v = 0.f; + } + + *pv = (T)v; // Write value. + } + } + else + { + // Forward pass with no sign write. Process value if in p.x. + if (x < p.xShape.x) // y is always in. + { + int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w; + T* pv = ((T*)p.x) + ix; + scalar_t v = (scalar_t)(*pv); + v *= p.gain; + if (v < 0.f) + v *= p.slope; + if (fabsf(v) > p.clamp) + v = InternalType::clamp(v, p.clamp); + *pv = (T)v; // Write value. + } + } + } +} + +template void* choose_filtered_lrelu_act_kernel(void) +{ + return (void*)filtered_lrelu_act_kernel; +} + +//------------------------------------------------------------------------ +// CUDA kernel selection. + +template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB) +{ + filtered_lrelu_kernel_spec s = { 0 }; + + // Return the first matching kernel. 
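choose_filtered_lrelu_kernel walks the static CASE table below and returns the first specialization compatible with the requested up/down factors, maximum filter sizes, shared-memory budget, and separable-vs-full mode (a 1-D filter, i.e. fuShape.y == 0, selects a separable mode); as the comments in the table note, ordering therefore matters. A rough Python analogue of that first-match logic, with invented entries rather than the real launch table:

```python
# Rough analogue of the first-match selection in the CASE table below (invented entries).
def choose_kernel(up, fu_taps, fu_separable, down, fd_taps, fd_separable, shared_kb, table):
    for spec in table:
        if shared_kb < spec['sharedKB']:
            continue
        if fu_separable != spec['mode'].startswith('SU'):
            continue
        if fd_separable != spec['mode'].endswith('SD'):
            continue
        if up == spec['up'] and fu_taps <= spec['fu'] and down == spec['down'] and fd_taps <= spec['fd']:
            return spec
    return None   # no specialization found -> caller falls back to the generic act kernel

table = [   # hypothetical entries; small filters must come before large ones
    dict(sharedKB=48, up=2, fu=8,  down=2, fd=8,  mode='SUSD'),
    dict(sharedKB=48, up=2, fu=12, down=2, fd=12, mode='SUSD'),
]
print(choose_kernel(up=2, fu_taps=12, fu_separable=True,
                    down=2, fd_taps=12, fd_separable=True,
                    shared_kb=48, table=table))   # matches the second (12-tap) entry
```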
+#define CASE(SH, U, FU, D, FD, MODE, TW, TH, W, XR, WS) \ + if (sharedKB >= SH) \ + if ((p.fuShape.y == 0 && (MODE == MODE_SUSD || MODE == MODE_SUFD)) || (p.fuShape.y > 0 && (MODE == MODE_FUSD || MODE == MODE_FUFD))) \ + if ((p.fdShape.y == 0 && (MODE == MODE_SUSD || MODE == MODE_FUSD)) || (p.fdShape.y > 0 && (MODE == MODE_SUFD || MODE == MODE_FUFD))) \ + if (p.up == U && p.fuShape.x <= FU && p.fuShape.y <= FU && p.down == D && p.fdShape.x <= FD && p.fdShape.y <= FD) \ + { \ + static_assert((D*TW % 4) == 0, "down * tileWidth must be divisible by 4"); \ + static_assert(FU % U == 0, "upscaling filter size must be multiple of upscaling factor"); \ + static_assert(FD % D == 0, "downscaling filter size must be multiple of downscaling factor"); \ + s.setup = (void*)setup_filters_kernel; \ + s.exec = (void*)filtered_lrelu_kernel; \ + s.tileOut = make_int2(TW, TH); \ + s.numWarps = W; \ + s.xrep = XR; \ + s.dynamicSharedKB = (SH == 48) ? 0 : SH; \ + return s; \ + } + + // Launch parameters for various kernel specializations. + // Small filters must be listed before large filters, otherwise the kernel for larger filter will always match first. + // Kernels that use more shared memory must be listed before those that use less, for the same reason. + + CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/1,1, /*mode*/MODE_FUFD, /*tw,th,warps,xrep,wskip*/64, 178, 32, 0, 0) // 1t-upf1-downf1 + CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/1,1, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/152, 95, 16, 0, 0) // 4t-ups2-downf1 + CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/2,8, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/56, 22, 16, 0, 0) // 4t-upf1-downs2 + CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/2,8, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/56, 29, 16, 11, 0) // 4t-ups2-downs2 + CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/2,8, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/60, 28, 16, 0, 0) // 4t-upf2-downs2 + CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/2,8, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/56, 28, 16, 0, 0) // 4t-ups2-downf2 + CASE(/*sharedKB*/48, /*up,fu*/4,16, /*down,fd*/2,8, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/56, 31, 16, 11, 0) // 4t-ups4-downs2 + CASE(/*sharedKB*/48, /*up,fu*/4,16, /*down,fd*/2,8, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/56, 36, 16, 0, 0) // 4t-ups4-downf2 + CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/4,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16, 22, 16, 12, 0) // 4t-ups2-downs4 + CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/4,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/29, 15, 16, 0, 0) // 4t-upf2-downs4 + CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/1,1, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/96, 150, 28, 0, 0) // 6t-ups2-downf1 + CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/2,12, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/32, 35, 24, 0, 0) // 6t-upf1-downs2 + CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 46, 16, 10, 0) // 6t-ups2-downs2 + CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/58, 28, 24, 8, 0) // 6t-upf2-downs2 + CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/52, 28, 16, 0, 0) // 6t-ups2-downf2 + CASE(/*sharedKB*/48, /*up,fu*/4,24, /*down,fd*/2,12, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 51, 16, 5, 0) // 6t-ups4-downs2 + CASE(/*sharedKB*/48, /*up,fu*/4,24, /*down,fd*/2,12, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32, 56, 16, 6, 0) // 
6t-ups4-downf2 + CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16, 18, 16, 12, 0) // 6t-ups2-downs4 + CASE(/*sharedKB*/96, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/27, 31, 32, 6, 0) // 6t-upf2-downs4 96kB + CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/27, 13, 24, 0, 0) // 6t-upf2-downs4 + CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/1,1, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/148, 89, 24, 0, 0) // 8t-ups2-downf1 + CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/2,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/32, 31, 16, 5, 0) // 8t-upf1-downs2 + CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 41, 16, 9, 0) // 8t-ups2-downs2 + CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/56, 26, 24, 0, 0) // 8t-upf2-downs2 + CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32, 40, 16, 0, 0) // 8t-ups2-downf2 + CASE(/*sharedKB*/48, /*up,fu*/4,32, /*down,fd*/2,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 46, 24, 5, 0) // 8t-ups4-downs2 + CASE(/*sharedKB*/48, /*up,fu*/4,32, /*down,fd*/2,16, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32, 50, 16, 0, 0) // 8t-ups4-downf2 + CASE(/*sharedKB*/96, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/24, 24, 32, 12, 1) // 8t-ups2-downs4 96kB + CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16, 13, 16, 10, 1) // 8t-ups2-downs4 + CASE(/*sharedKB*/96, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/25, 28, 28, 4, 0) // 8t-upf2-downs4 96kB + CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/25, 10, 24, 0, 0) // 8t-upf2-downs4 + + #undef CASE + return s; // No kernel found. +} + +//------------------------------------------------------------------------ diff --git a/third_party/stylegan3_official_ops/filtered_lrelu.h b/third_party/stylegan3_official_ops/filtered_lrelu.h new file mode 100644 index 0000000000000000000000000000000000000000..2c403e3f275f472315662321cad54dd0dbc56d00 --- /dev/null +++ b/third_party/stylegan3_official_ops/filtered_lrelu.h @@ -0,0 +1,90 @@ +// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include + +//------------------------------------------------------------------------ +// CUDA kernel parameters. + +struct filtered_lrelu_kernel_params +{ + // These parameters decide which kernel to use. + int up; // upsampling ratio (1, 2, 4) + int down; // downsampling ratio (1, 2, 4) + int2 fuShape; // [size, 1] | [size, size] + int2 fdShape; // [size, 1] | [size, size] + + int _dummy; // Alignment. + + // Rest of the parameters. + const void* x; // Input tensor. + void* y; // Output tensor. + const void* b; // Bias tensor. + unsigned char* s; // Sign tensor in/out. NULL if unused. + const float* fu; // Upsampling filter. + const float* fd; // Downsampling filter. + + int2 pad0; // Left/top padding. 
+ float gain; // Additional gain factor. + float slope; // Leaky ReLU slope on negative side. + float clamp; // Clamp after nonlinearity. + int flip; // Filter kernel flip for gradient computation. + + int tilesXdim; // Original number of horizontal output tiles. + int tilesXrep; // Number of horizontal tiles per CTA. + int blockZofs; // Block z offset to support large minibatch, channel dimensions. + + int4 xShape; // [width, height, channel, batch] + int4 yShape; // [width, height, channel, batch] + int2 sShape; // [width, height] - width is in bytes. Contiguous. Zeros if unused. + int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor. + int swLimit; // Active width of sign tensor in bytes. + + longlong4 xStride; // Strides of all tensors except signs, same component order as shapes. + longlong4 yStride; // + int64_t bStride; // + longlong3 fuStride; // + longlong3 fdStride; // +}; + +struct filtered_lrelu_act_kernel_params +{ + void* x; // Input/output, modified in-place. + unsigned char* s; // Sign tensor in/out. NULL if unused. + + float gain; // Additional gain factor. + float slope; // Leaky ReLU slope on negative side. + float clamp; // Clamp after nonlinearity. + + int4 xShape; // [width, height, channel, batch] + longlong4 xStride; // Input/output tensor strides, same order as in shape. + int2 sShape; // [width, height] - width is in elements. Contiguous. Zeros if unused. + int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor. +}; + +//------------------------------------------------------------------------ +// CUDA kernel specialization. + +struct filtered_lrelu_kernel_spec +{ + void* setup; // Function for filter kernel setup. + void* exec; // Function for main operation. + int2 tileOut; // Width/height of launch tile. + int numWarps; // Number of warps per thread block, determines launch block size. + int xrep; // For processing multiple horizontal tiles per thread block. + int dynamicSharedKB; // How much dynamic shared memory the exec kernel wants. +}; + +//------------------------------------------------------------------------ +// CUDA kernel selection. + +template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); +template void* choose_filtered_lrelu_act_kernel(void); +template cudaError_t copy_filters(cudaStream_t stream); + +//------------------------------------------------------------------------ diff --git a/third_party/stylegan3_official_ops/filtered_lrelu.py b/third_party/stylegan3_official_ops/filtered_lrelu.py new file mode 100644 index 0000000000000000000000000000000000000000..ec924b630622f9e945baa2d3c674cf158b524005 --- /dev/null +++ b/third_party/stylegan3_official_ops/filtered_lrelu.py @@ -0,0 +1,297 @@ +# python3.7 + +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Custom operators for Leaky ReLU, wrapped with upsampling and downsampling. + +Leaky ReLU will introduce an extremely high frequency into the source feature +map. To solve this problem, an upsampling layer and a downsampling layer are +wrapped around the Leaky ReLU operator. 
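The docstring above states the whole idea in one sentence: the pointwise leaky ReLU creates frequency content the sampling grid cannot represent, so the op oversamples, applies the nonlinearity, low-pass filters, and resamples. A deliberately crude stand-in using stock PyTorch ops, just to show the shape of the computation — bilinear upsampling and average pooling here are placeholders for the designed FIR filters `fu`/`fd`, not what this module does:

```python
import torch
import torch.nn.functional as F

def naive_lrelu(x, slope=0.2):
    return F.leaky_relu(x, slope)

def oversampled_lrelu(x, slope=0.2, up=2):
    x = F.interpolate(x, scale_factor=up, mode='bilinear', align_corners=False)  # oversample
    x = F.leaky_relu(x, slope)                                                   # nonlinearity
    return F.avg_pool2d(x, kernel_size=up)                                       # crude low-pass + downsample

x = torch.randn(1, 3, 32, 32)
print(naive_lrelu(x).shape, oversampled_lrelu(x).shape)  # both (1, 3, 32, 32)
```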
+ +Please refer to https://github.com/NVlabs/stylegan3 +""" + +# pylint: disable=line-too-long +# pylint: disable=missing-class-docstring +# pylint: disable=global-statement +# pylint: disable=multiple-statements +# pylint: disable=inconsistent-quotes + +import os +import warnings +import numpy as np +import torch + +from . import custom_ops +from . import misc +from . import upfirdn2d +from . import bias_act + +#---------------------------------------------------------------------------- + +_plugin = None + +def _init(): + global _plugin + if _plugin is None: + _plugin = custom_ops.get_plugin( + module_name='filtered_lrelu_plugin', + sources=['filtered_lrelu.cpp', 'filtered_lrelu_wr.cu', 'filtered_lrelu_rd.cu', 'filtered_lrelu_ns.cu'], + headers=['filtered_lrelu.h', 'filtered_lrelu.cu'], + source_dir=os.path.dirname(__file__), + extra_cuda_cflags=['--use_fast_math'], + ) + return True + +def _get_filter_size(f): + if f is None: + return 1, 1 + assert isinstance(f, torch.Tensor) + assert 1 <= f.ndim <= 2 + return f.shape[-1], f.shape[0] # width, height + +def _parse_padding(padding): + if isinstance(padding, int): + padding = [padding, padding] + assert isinstance(padding, (list, tuple)) + assert all(isinstance(x, (int, np.integer)) for x in padding) + padding = [int(x) for x in padding] + if len(padding) == 2: + px, py = padding + padding = [px, px, py, py] + px0, px1, py0, py1 = padding + return px0, px1, py0, py1 + +#---------------------------------------------------------------------------- + +def filtered_lrelu(x, fu=None, fd=None, b=None, up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False, impl='cuda'): + r"""Filtered leaky ReLU for a batch of 2D images. + + Performs the following sequence of operations for each channel: + + 1. Add channel-specific bias if provided (`b`). + + 2. Upsample the image by inserting N-1 zeros after each pixel (`up`). + + 3. Pad the image with the specified number of zeros on each side (`padding`). + Negative padding corresponds to cropping the image. + + 4. Convolve the image with the specified upsampling FIR filter (`fu`), shrinking it + so that the footprint of all output pixels lies within the input image. + + 5. Multiply each value by the provided gain factor (`gain`). + + 6. Apply leaky ReLU activation function to each value. + + 7. Clamp each value between -clamp and +clamp, if `clamp` parameter is provided. + + 8. Convolve the image with the specified downsampling FIR filter (`fd`), shrinking + it so that the footprint of all output pixels lies within the input image. + + 9. Downsample the image by keeping every Nth pixel (`down`). + + The fused op is considerably more efficient than performing the same calculation + using standard PyTorch ops. It supports gradients of arbitrary order. + + Args: + x: Float32/float16/float64 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. + fu: Float32 upsampling FIR filter of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), or + `None` (identity). + fd: Float32 downsampling FIR filter of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), or + `None` (identity). + b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type + as `x`. The length of vector must must match the channel dimension of `x`. + up: Integer upsampling factor (default: 1). + down: Integer downsampling factor. (default: 1). + padding: Padding with respect to the upsampled image. 
Can be a single number + or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + gain: Overall scaling factor for signal magnitude (default: sqrt(2)). + slope: Slope on the negative side of leaky ReLU (default: 0.2). + clamp: Maximum magnitude for leaky ReLU output (default: None). + flip_filter: False = convolution, True = correlation (default: False). + impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. + """ + assert isinstance(x, torch.Tensor) + assert impl in ['ref', 'cuda'] + if impl == 'cuda' and x.device.type == 'cuda' and _init(): + return _filtered_lrelu_cuda(up=up, down=down, padding=padding, gain=gain, slope=slope, clamp=clamp, flip_filter=flip_filter).apply(x, fu, fd, b, None, 0, 0) + return _filtered_lrelu_ref(x, fu=fu, fd=fd, b=b, up=up, down=down, padding=padding, gain=gain, slope=slope, clamp=clamp, flip_filter=flip_filter) + +#---------------------------------------------------------------------------- + +@misc.profiled_function +def _filtered_lrelu_ref(x, fu=None, fd=None, b=None, up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False): + """Slow and memory-inefficient reference implementation of `filtered_lrelu()` using + existing `upfirdn2n()` and `bias_act()` ops. + """ + assert isinstance(x, torch.Tensor) and x.ndim == 4 + fu_w, fu_h = _get_filter_size(fu) + fd_w, fd_h = _get_filter_size(fd) + if b is not None: + assert isinstance(b, torch.Tensor) and b.dtype == x.dtype + misc.assert_shape(b, [x.shape[1]]) + assert isinstance(up, int) and up >= 1 + assert isinstance(down, int) and down >= 1 + px0, px1, py0, py1 = _parse_padding(padding) + assert gain == float(gain) and gain > 0 + assert slope == float(slope) and slope >= 0 + assert clamp is None or (clamp == float(clamp) and clamp >= 0) + + # Calculate output size. + batch_size, channels, in_h, in_w = x.shape + in_dtype = x.dtype + out_w = (in_w * up + (px0 + px1) - (fu_w - 1) - (fd_w - 1) + (down - 1)) // down + out_h = (in_h * up + (py0 + py1) - (fu_h - 1) - (fd_h - 1) + (down - 1)) // down + + # Compute using existing ops. + x = bias_act.bias_act(x=x, b=b, impl='ref') # Apply bias. + x = upfirdn2d.upfirdn2d(x=x, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter, impl='ref') # Upsample. + x = bias_act.bias_act(x=x, act='lrelu', alpha=slope, gain=gain, clamp=clamp, impl='ref') # Bias, leaky ReLU, clamp. + x = upfirdn2d.upfirdn2d(x=x, f=fd, down=down, flip_filter=flip_filter, impl='ref') # Downsample. + + # Check output shape & dtype. + misc.assert_shape(x, [batch_size, channels, out_h, out_w]) + assert x.dtype == in_dtype + return x + +#---------------------------------------------------------------------------- + +_filtered_lrelu_cuda_cache = dict() + +def _filtered_lrelu_cuda(up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False): + """Fast CUDA implementation of `filtered_lrelu()` using custom ops. + """ + assert isinstance(up, int) and up >= 1 + assert isinstance(down, int) and down >= 1 + px0, px1, py0, py1 = _parse_padding(padding) + assert gain == float(gain) and gain > 0 + gain = float(gain) + assert slope == float(slope) and slope >= 0 + slope = float(slope) + assert clamp is None or (clamp == float(clamp) and clamp >= 0) + clamp = float(clamp if clamp is not None else 'inf') + + # Lookup from cache. 
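The reference implementation above also spells out the output-size bookkeeping: upsample by `up`, add the explicit padding, lose `taps − 1` pixels to each of the two convolutions, then divide by `down`. A worked example with made-up sizes:

```python
# Worked example (illustrative numbers) of the output-size formula in _filtered_lrelu_ref above.
def filtered_lrelu_out_size(in_size, up, down, pad, fu_taps, fd_taps):
    px0, px1 = pad
    return (in_size * up + (px0 + px1) - (fu_taps - 1) - (fd_taps - 1) + (down - 1)) // down

# e.g. a 36x36 feature map, 2x up / 2x down, 12-tap filters, 5 pixels of padding per side:
print(filtered_lrelu_out_size(36, up=2, down=2, pad=(5, 5), fu_taps=12, fd_taps=12))  # 30
```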
+ key = (up, down, px0, px1, py0, py1, gain, slope, clamp, flip_filter) + if key in _filtered_lrelu_cuda_cache: + return _filtered_lrelu_cuda_cache[key] + + # Forward op. + class FilteredLReluCuda(torch.autograd.Function): + @staticmethod + def forward(ctx, x, fu, fd, b, si, sx, sy): # pylint: disable=arguments-differ + assert isinstance(x, torch.Tensor) and x.ndim == 4 + + # Replace empty up/downsample kernels with full 1x1 kernels (faster than separable). + if fu is None: + fu = torch.ones([1, 1], dtype=torch.float32, device=x.device) + if fd is None: + fd = torch.ones([1, 1], dtype=torch.float32, device=x.device) + assert 1 <= fu.ndim <= 2 + assert 1 <= fd.ndim <= 2 + + # Replace separable 1x1 kernels with full 1x1 kernels when scale factor is 1. + if up == 1 and fu.ndim == 1 and fu.shape[0] == 1: + fu = fu.square()[None] + if down == 1 and fd.ndim == 1 and fd.shape[0] == 1: + fd = fd.square()[None] + + # Missing sign input tensor. + if si is None: + si = torch.empty([0]) + + # Missing bias tensor. + if b is None: + b = torch.zeros([x.shape[1]], dtype=x.dtype, device=x.device) + + # Construct internal sign tensor only if gradients are needed. + write_signs = (si.numel() == 0) and (x.requires_grad or b.requires_grad) + + # Warn if input storage strides are not in decreasing order due to e.g. channels-last layout. + strides = [x.stride(i) for i in range(x.ndim) if x.size(i) > 1] + if any(a < b for a, b in zip(strides[:-1], strides[1:])): + warnings.warn("low-performance memory layout detected in filtered_lrelu input", RuntimeWarning) + + # Call C++/Cuda plugin if datatype is supported. + if x.dtype in [torch.float16, torch.float32]: + if torch.cuda.current_stream(x.device) != torch.cuda.default_stream(x.device): + warnings.warn("filtered_lrelu called with non-default cuda stream but concurrent execution is not supported", RuntimeWarning) + y, so, return_code = _plugin.filtered_lrelu(x, fu, fd, b, si, up, down, px0, px1, py0, py1, sx, sy, gain, slope, clamp, flip_filter, write_signs) + else: + return_code = -1 + + # No Cuda kernel found? Fall back to generic implementation. Still more memory efficient than the reference implementation because + # only the bit-packed sign tensor is retained for gradient computation. + if return_code < 0: + warnings.warn("filtered_lrelu called with parameters that have no optimized CUDA kernel, using generic fallback", RuntimeWarning) + + y = x.add(b.unsqueeze(-1).unsqueeze(-1)) # Add bias. + y = upfirdn2d.upfirdn2d(x=y, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) # Upsample. + so = _plugin.filtered_lrelu_act_(y, si, sx, sy, gain, slope, clamp, write_signs) # Activation function and sign handling. Modifies y in-place. + y = upfirdn2d.upfirdn2d(x=y, f=fd, down=down, flip_filter=flip_filter) # Downsample. + + # Prepare for gradient computation. 
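One detail of the forward pass above worth calling out is the memory-layout warning: the plugin expects strides in decreasing order (plain contiguous NCHW), and e.g. channels-last tensors trip the check. A minimal reproduction of just that check, in plain PyTorch with no plugin required:

```python
import torch

def has_decreasing_strides(x):
    """Mirror of the layout check above: True for contiguous NCHW, False for channels-last."""
    strides = [x.stride(i) for i in range(x.ndim) if x.size(i) > 1]
    return not any(a < b for a, b in zip(strides[:-1], strides[1:]))

x = torch.randn(2, 8, 16, 16)
print(has_decreasing_strides(x))                                        # True
print(has_decreasing_strides(x.to(memory_format=torch.channels_last)))  # False -> warning path
```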
+ ctx.save_for_backward(fu, fd, (si if si.numel() else so)) + ctx.x_shape = x.shape + ctx.y_shape = y.shape + ctx.s_ofs = sx, sy + return y + + @staticmethod + def backward(ctx, dy): # pylint: disable=arguments-differ + fu, fd, si = ctx.saved_tensors + _, _, xh, xw = ctx.x_shape + _, _, yh, yw = ctx.y_shape + sx, sy = ctx.s_ofs + dx = None # 0 + dfu = None; assert not ctx.needs_input_grad[1] + dfd = None; assert not ctx.needs_input_grad[2] + db = None # 3 + dsi = None; assert not ctx.needs_input_grad[4] + dsx = None; assert not ctx.needs_input_grad[5] + dsy = None; assert not ctx.needs_input_grad[6] + + if ctx.needs_input_grad[0] or ctx.needs_input_grad[3]: + pp = [ + (fu.shape[-1] - 1) + (fd.shape[-1] - 1) - px0, + xw * up - yw * down + px0 - (up - 1), + (fu.shape[0] - 1) + (fd.shape[0] - 1) - py0, + xh * up - yh * down + py0 - (up - 1), + ] + gg = gain * (up ** 2) / (down ** 2) + ff = (not flip_filter) + sx = sx - (fu.shape[-1] - 1) + px0 + sy = sy - (fu.shape[0] - 1) + py0 + dx = _filtered_lrelu_cuda(up=down, down=up, padding=pp, gain=gg, slope=slope, clamp=None, flip_filter=ff).apply(dy, fd, fu, None, si, sx, sy) + + if ctx.needs_input_grad[3]: + db = dx.sum([0, 2, 3]) + + return dx, dfu, dfd, db, dsi, dsx, dsy + + # Add to cache. + _filtered_lrelu_cuda_cache[key] = FilteredLReluCuda + return FilteredLReluCuda + +#---------------------------------------------------------------------------- + +# pylint: enable=line-too-long +# pylint: enable=missing-class-docstring +# pylint: enable=global-statement +# pylint: enable=multiple-statements +# pylint: enable=inconsistent-quotes diff --git a/third_party/stylegan3_official_ops/filtered_lrelu_ns.cu b/third_party/stylegan3_official_ops/filtered_lrelu_ns.cu new file mode 100644 index 0000000000000000000000000000000000000000..ef5d948c4fdf9cb0fe8a42f6268c61aeef6b2000 --- /dev/null +++ b/third_party/stylegan3_official_ops/filtered_lrelu_ns.cu @@ -0,0 +1,27 @@ +// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include "filtered_lrelu.cu" + +// Template/kernel specializations for no signs mode (no gradients required). + +// Full op, 32-bit indexing. +template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); +template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); + +// Full op, 64-bit indexing. +template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); +template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); + +// Activation/signs only for generic variant. 64-bit indexing. +template void* choose_filtered_lrelu_act_kernel(void); +template void* choose_filtered_lrelu_act_kernel(void); +template void* choose_filtered_lrelu_act_kernel(void); + +// Copy filters to constant memory. 
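Putting the autograd plumbing above together, the public entry point remains `filtered_lrelu()` defined earlier in this file. A hypothetical usage sketch: the import path simply mirrors this diff's directory layout, the Hamming-window taps are arbitrary placeholders rather than StyleGAN3's designed filters, and a working CUDA build of the plugin plus a GPU are assumed:

```python
import numpy as np
import torch
from third_party.stylegan3_official_ops import filtered_lrelu  # assumed import path

x  = torch.randn(4, 64, 36, 36, device='cuda', requires_grad=True)
b  = torch.zeros(64, device='cuda', requires_grad=True)
fu = torch.tensor(np.hamming(12), dtype=torch.float32, device='cuda')  # placeholder 12-tap separable filter
fd = torch.tensor(np.hamming(12), dtype=torch.float32, device='cuda')

y = filtered_lrelu.filtered_lrelu(x, fu=fu, fd=fd, b=b, up=2, down=2, padding=5,
                                  gain=np.sqrt(2), slope=0.2, clamp=256)
print(y.shape)        # torch.Size([4, 64, 30, 30]) with these sizes
y.sum().backward()    # gradients flow through the custom autograd.Function above
```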
+template cudaError_t copy_filters(cudaStream_t stream); diff --git a/third_party/stylegan3_official_ops/filtered_lrelu_rd.cu b/third_party/stylegan3_official_ops/filtered_lrelu_rd.cu new file mode 100644 index 0000000000000000000000000000000000000000..968347882e9aebd36204f67e201cd16226dd9132 --- /dev/null +++ b/third_party/stylegan3_official_ops/filtered_lrelu_rd.cu @@ -0,0 +1,27 @@ +// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include "filtered_lrelu.cu" + +// Template/kernel specializations for sign read mode. + +// Full op, 32-bit indexing. +template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); +template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); + +// Full op, 64-bit indexing. +template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); +template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); + +// Activation/signs only for generic variant. 64-bit indexing. +template void* choose_filtered_lrelu_act_kernel(void); +template void* choose_filtered_lrelu_act_kernel(void); +template void* choose_filtered_lrelu_act_kernel(void); + +// Copy filters to constant memory. +template cudaError_t copy_filters(cudaStream_t stream); diff --git a/third_party/stylegan3_official_ops/filtered_lrelu_wr.cu b/third_party/stylegan3_official_ops/filtered_lrelu_wr.cu new file mode 100644 index 0000000000000000000000000000000000000000..a4c6a24aae908bc07248f7ff710cbd1a11a38bb1 --- /dev/null +++ b/third_party/stylegan3_official_ops/filtered_lrelu_wr.cu @@ -0,0 +1,27 @@ +// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include "filtered_lrelu.cu" + +// Template/kernel specializations for sign write mode. + +// Full op, 32-bit indexing. +template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); +template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); + +// Full op, 64-bit indexing. +template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); +template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); + +// Activation/signs only for generic variant. 64-bit indexing. +template void* choose_filtered_lrelu_act_kernel(void); +template void* choose_filtered_lrelu_act_kernel(void); +template void* choose_filtered_lrelu_act_kernel(void); + +// Copy filters to constant memory. 
+template cudaError_t copy_filters(cudaStream_t stream); diff --git a/third_party/stylegan3_official_ops/fma.py b/third_party/stylegan3_official_ops/fma.py new file mode 100644 index 0000000000000000000000000000000000000000..26195fdb5d4e0329703b7d6e5578f4d17ec57cde --- /dev/null +++ b/third_party/stylegan3_official_ops/fma.py @@ -0,0 +1,73 @@ +# python3.7 + +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Fused multiply-add, with slightly faster gradients than `torch.addcmul()`. + +Please refer to https://github.com/NVlabs/stylegan3 +""" + +# pylint: disable=line-too-long +# pylint: disable=missing-function-docstring + +import torch + +#---------------------------------------------------------------------------- + +def fma(a, b, c, impl='cuda'): # => a * b + c + if impl == 'cuda': + return _FusedMultiplyAdd.apply(a, b, c) + return torch.addcmul(c, a, b) + +#---------------------------------------------------------------------------- + +class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c + @staticmethod + def forward(ctx, a, b, c): # pylint: disable=arguments-differ + out = torch.addcmul(c, a, b) + ctx.save_for_backward(a, b) + ctx.c_shape = c.shape + return out + + @staticmethod + def backward(ctx, dout): # pylint: disable=arguments-differ + a, b = ctx.saved_tensors + c_shape = ctx.c_shape + da = None + db = None + dc = None + + if ctx.needs_input_grad[0]: + da = _unbroadcast(dout * b, a.shape) + + if ctx.needs_input_grad[1]: + db = _unbroadcast(dout * a, b.shape) + + if ctx.needs_input_grad[2]: + dc = _unbroadcast(dout, c_shape) + + return da, db, dc + +#---------------------------------------------------------------------------- + +def _unbroadcast(x, shape): + extra_dims = x.ndim - len(shape) + assert extra_dims >= 0 + dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)] + if len(dim): + x = x.sum(dim=dim, keepdim=True) + if extra_dims: + x = x.reshape(-1, *x.shape[extra_dims+1:]) + assert x.shape == shape + return x + +#---------------------------------------------------------------------------- + +# pylint: enable=line-too-long +# pylint: enable=missing-function-docstring diff --git a/third_party/stylegan3_official_ops/grid_sample_gradfix.py b/third_party/stylegan3_official_ops/grid_sample_gradfix.py new file mode 100644 index 0000000000000000000000000000000000000000..c3d9cd591a13e146eeeedddbef28871d7c3a0742 --- /dev/null +++ b/third_party/stylegan3_official_ops/grid_sample_gradfix.py @@ -0,0 +1,92 @@ +# python3.7 + +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Custom replacement for `torch.nn.functional.grid_sample`. + +This is useful for differentiable augmentation. 
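Returning to `fma.py` above: the wrapper exists because expressing a·b + c through a single autograd Function gives slightly cheaper gradients than letting autograd differentiate the composite expression, and `_unbroadcast` folds broadcasted gradient dimensions back down to each input's shape. A quick check (import path assumed from this diff's layout; runs on CPU since the 'cuda' impl here is pure autograd):

```python
import torch
from third_party.stylegan3_official_ops import fma  # assumed import path

a = torch.randn(4, 3, 8, 8, requires_grad=True)
b = torch.randn(4, 3, 8, 8, requires_grad=True)
c = torch.randn(3, 1, 1, requires_grad=True)         # broadcast against a * b

out = fma.fma(a, b, c)
assert torch.allclose(out, a * b + c)                 # forward matches addcmul semantics
out.sum().backward()
print(c.grad.shape)                                   # torch.Size([3, 1, 1]) -- reduced by _unbroadcast
```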
This customized operator +supports arbitrarily high order gradients between the input and output. Only +works on 2D images and assumes `mode=bilinear`, `padding_mode=zeros`, and +`align_corners=False`. + +Please refer to https://github.com/NVlabs/stylegan3 +""" + +# pylint: disable=redefined-builtin +# pylint: disable=arguments-differ +# pylint: disable=protected-access +# pylint: disable=line-too-long +# pylint: disable=missing-function-docstring + +import torch + +#---------------------------------------------------------------------------- + +enabled = True # Enable the custom op by setting this to true. + +#---------------------------------------------------------------------------- + +def grid_sample(input, grid, impl='cuda'): + if impl == 'cuda' and _should_use_custom_op(): + return _GridSample2dForward.apply(input, grid) + return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) + +#---------------------------------------------------------------------------- + +def _should_use_custom_op(): + return enabled + +#---------------------------------------------------------------------------- + +class _GridSample2dForward(torch.autograd.Function): + @staticmethod + def forward(ctx, input, grid): + assert input.ndim == 4 + assert grid.ndim == 4 + output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) + ctx.save_for_backward(input, grid) + return output + + @staticmethod + def backward(ctx, grad_output): + input, grid = ctx.saved_tensors + grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid) + return grad_input, grad_grid + +#---------------------------------------------------------------------------- + +class _GridSample2dBackward(torch.autograd.Function): + @staticmethod + def forward(ctx, grad_output, input, grid): + op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward') + grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False) + ctx.save_for_backward(grid) + return grad_input, grad_grid + + @staticmethod + def backward(ctx, grad2_grad_input, grad2_grad_grid): + _ = grad2_grad_grid # unused + grid, = ctx.saved_tensors + grad2_grad_output = None + grad2_input = None + grad2_grid = None + + if ctx.needs_input_grad[0]: + grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid) + + assert not ctx.needs_input_grad[2] + return grad2_grad_output, grad2_input, grad2_grid + +#---------------------------------------------------------------------------- + +# pylint: enable=redefined-builtin +# pylint: enable=arguments-differ +# pylint: enable=protected-access +# pylint: enable=line-too-long +# pylint: enable=missing-function-docstring diff --git a/third_party/stylegan3_official_ops/misc.py b/third_party/stylegan3_official_ops/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..1acfb7ea16904c07e362aeaae7337920d06fe5ca --- /dev/null +++ b/third_party/stylegan3_official_ops/misc.py @@ -0,0 +1,283 @@ +# python3.7 + +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. 
+ +"""Misc functions for customized operations. + +Please refer to https://github.com/NVlabs/stylegan3 +""" + +# pylint: disable=line-too-long +# pylint: disable=missing-class-docstring +# pylint: disable=missing-function-docstring +# pylint: disable=use-maxsplit-arg + +import re +import contextlib +import warnings +from easydict import EasyDict +import numpy as np +import torch + +#---------------------------------------------------------------------------- +# Cached construction of constant tensors. Avoids CPU=>GPU copy when the +# same constant is used multiple times. + +_constant_cache = dict() + +def constant(value, shape=None, dtype=None, device=None, memory_format=None): + value = np.asarray(value) + if shape is not None: + shape = tuple(shape) + if dtype is None: + dtype = torch.get_default_dtype() + if device is None: + device = torch.device('cpu') + if memory_format is None: + memory_format = torch.contiguous_format + + key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format) + tensor = _constant_cache.get(key, None) + if tensor is None: + tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device) + if shape is not None: + tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape)) + tensor = tensor.contiguous(memory_format=memory_format) + _constant_cache[key] = tensor + return tensor + +#---------------------------------------------------------------------------- +# Replace NaN/Inf with specified numerical values. + +try: + nan_to_num = torch.nan_to_num # 1.8.0a0 +except AttributeError: + def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin + assert isinstance(input, torch.Tensor) + if posinf is None: + posinf = torch.finfo(input.dtype).max + if neginf is None: + neginf = torch.finfo(input.dtype).min + assert nan == 0 + return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out) + +#---------------------------------------------------------------------------- +# Symbolic assert. + +try: + symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access +except AttributeError: + symbolic_assert = torch.Assert # 1.7.0 + +#---------------------------------------------------------------------------- +# Context manager to temporarily suppress known warnings in torch.jit.trace(). +# Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672 + +@contextlib.contextmanager +def suppress_tracer_warnings(): + flt = ('ignore', None, torch.jit.TracerWarning, None, 0) + warnings.filters.insert(0, flt) + yield + warnings.filters.remove(flt) + +#---------------------------------------------------------------------------- +# Assert that the shape of a tensor matches the given list of integers. +# None indicates that the size of a dimension is allowed to vary. +# Performs symbolic assertion when used in torch.jit.trace(). 
+ +def assert_shape(tensor, ref_shape): + if tensor.ndim != len(ref_shape): + raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}') + for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)): + if ref_size is None: + pass + elif isinstance(ref_size, torch.Tensor): + with suppress_tracer_warnings(): # as_tensor results are registered as constants + symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}') + elif isinstance(size, torch.Tensor): + with suppress_tracer_warnings(): # as_tensor results are registered as constants + symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}') + elif size != ref_size: + raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}') + +#---------------------------------------------------------------------------- +# Function decorator that calls torch.autograd.profiler.record_function(). + +def profiled_function(fn): + def decorator(*args, **kwargs): + with torch.autograd.profiler.record_function(fn.__name__): + return fn(*args, **kwargs) + decorator.__name__ = fn.__name__ + return decorator + +#---------------------------------------------------------------------------- +# Sampler for torch.utils.data.DataLoader that loops over the dataset +# indefinitely, shuffling items as it goes. + +class InfiniteSampler(torch.utils.data.Sampler): + def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5): + assert len(dataset) > 0 + assert num_replicas > 0 + assert 0 <= rank < num_replicas + assert 0 <= window_size <= 1 + super().__init__(dataset) + self.dataset = dataset + self.rank = rank + self.num_replicas = num_replicas + self.shuffle = shuffle + self.seed = seed + self.window_size = window_size + + def __iter__(self): + order = np.arange(len(self.dataset)) + rnd = None + window = 0 + if self.shuffle: + rnd = np.random.RandomState(self.seed) + rnd.shuffle(order) + window = int(np.rint(order.size * self.window_size)) + + idx = 0 + while True: + i = idx % order.size + if idx % self.num_replicas == self.rank: + yield order[i] + if window >= 2: + j = (i - rnd.randint(window)) % order.size + order[i], order[j] = order[j], order[i] + idx += 1 + +#---------------------------------------------------------------------------- +# Utilities for operating with torch.nn.Module parameters and buffers. + +def params_and_buffers(module): + assert isinstance(module, torch.nn.Module) + return list(module.parameters()) + list(module.buffers()) + +def named_params_and_buffers(module): + assert isinstance(module, torch.nn.Module) + return list(module.named_parameters()) + list(module.named_buffers()) + +def copy_params_and_buffers(src_module, dst_module, require_all=False): + assert isinstance(src_module, torch.nn.Module) + assert isinstance(dst_module, torch.nn.Module) + src_tensors = dict(named_params_and_buffers(src_module)) + for name, tensor in named_params_and_buffers(dst_module): + assert (name in src_tensors) or (not require_all) + if name in src_tensors: + tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad) + +#---------------------------------------------------------------------------- +# Context manager for easily enabling/disabling DistributedDataParallel +# synchronization. 
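# Usage sketch (illustrative only; `ddp_model`, `images` and `num_rounds` are
# hypothetical names): the context manager defined below is typically used to
# skip gradient all-reduce on every accumulation round except the last one.
#
#   for i in range(num_rounds):
#       with ddp_sync(ddp_model, sync=(i == num_rounds - 1)):
#           loss = ddp_model(images[i]).mean()
#           loss.backward()  # gradients sync across ranks only when sync=True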
+ +@contextlib.contextmanager +def ddp_sync(module, sync): + assert isinstance(module, torch.nn.Module) + if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel): + yield + else: + with module.no_sync(): + yield + +#---------------------------------------------------------------------------- +# Check DistributedDataParallel consistency across processes. + +def check_ddp_consistency(module, ignore_regex=None): + assert isinstance(module, torch.nn.Module) + for name, tensor in named_params_and_buffers(module): + fullname = type(module).__name__ + '.' + name + if ignore_regex is not None and re.fullmatch(ignore_regex, fullname): + continue + tensor = tensor.detach() + if tensor.is_floating_point(): + tensor = nan_to_num(tensor) + other = tensor.clone() + torch.distributed.broadcast(tensor=other, src=0) + assert (tensor == other).all(), fullname + +#---------------------------------------------------------------------------- +# Print summary table of module hierarchy. + +def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True): + assert isinstance(module, torch.nn.Module) + assert not isinstance(module, torch.jit.ScriptModule) + assert isinstance(inputs, (tuple, list)) + + # Register hooks. + entries = [] + nesting = [0] + def pre_hook(_mod, _inputs): + nesting[0] += 1 + def post_hook(mod, _inputs, outputs): + nesting[0] -= 1 + if nesting[0] <= max_nesting: + outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs] + outputs = [t for t in outputs if isinstance(t, torch.Tensor)] + entries.append(EasyDict(mod=mod, outputs=outputs)) + hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()] + hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()] + + # Run module. + outputs = module(*inputs) + for hook in hooks: + hook.remove() + + # Identify unique outputs, parameters, and buffers. + tensors_seen = set() + for e in entries: + e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen] + e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen] + e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen] + tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs} + + # Filter out redundant entries. + if skip_redundant: + entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)] + + # Construct table. + rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']] + rows += [['---'] * len(rows[0])] + param_total = 0 + buffer_total = 0 + submodule_names = {mod: name for name, mod in module.named_modules()} + for e in entries: + name = '' if e.mod is module else submodule_names[e.mod] + param_size = sum(t.numel() for t in e.unique_params) + buffer_size = sum(t.numel() for t in e.unique_buffers) + output_shapes = [str(list(t.shape)) for t in e.outputs] + output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs] + rows += [[ + name + (':0' if len(e.outputs) >= 2 else ''), + str(param_size) if param_size else '-', + str(buffer_size) if buffer_size else '-', + (output_shapes + ['-'])[0], + (output_dtypes + ['-'])[0], + ]] + for idx in range(1, len(e.outputs)): + rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]] + param_total += param_size + buffer_total += buffer_size + rows += [['---'] * len(rows[0])] + rows += [['Total', str(param_total), str(buffer_total), '-', '-']] + + # Print table. 
+ widths = [max(len(cell) for cell in column) for column in zip(*rows)] + print() + for row in rows: + print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths))) + print() + return outputs + +#---------------------------------------------------------------------------- + +# pylint: enable=line-too-long +# pylint: enable=missing-class-docstring +# pylint: enable=missing-function-docstring +# pylint: enable=use-maxsplit-arg diff --git a/third_party/stylegan3_official_ops/upfirdn2d.cpp b/third_party/stylegan3_official_ops/upfirdn2d.cpp new file mode 100644 index 0000000000000000000000000000000000000000..44fa337d8d4c34dfa010a59cd27d86857db671aa --- /dev/null +++ b/third_party/stylegan3_official_ops/upfirdn2d.cpp @@ -0,0 +1,107 @@ +// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include +#include +#include +#include "upfirdn2d.h" + +//------------------------------------------------------------------------ + +static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain) +{ + // Validate arguments. + TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); + TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x"); + TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32"); + TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); + TORCH_CHECK(f.numel() <= INT_MAX, "f is too large"); + TORCH_CHECK(x.numel() > 0, "x has zero size"); + TORCH_CHECK(f.numel() > 0, "f has zero size"); + TORCH_CHECK(x.dim() == 4, "x must be rank 4"); + TORCH_CHECK(f.dim() == 2, "f must be rank 2"); + TORCH_CHECK((x.size(0)-1)*x.stride(0) + (x.size(1)-1)*x.stride(1) + (x.size(2)-1)*x.stride(2) + (x.size(3)-1)*x.stride(3) <= INT_MAX, "x memory footprint is too large"); + TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1"); + TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1"); + TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1"); + + // Create output tensor. + const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); + int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx; + int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy; + TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1"); + torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format()); + TORCH_CHECK(y.numel() <= INT_MAX, "output is too large"); + TORCH_CHECK((y.size(0)-1)*y.stride(0) + (y.size(1)-1)*y.stride(1) + (y.size(2)-1)*y.stride(2) + (y.size(3)-1)*y.stride(3) <= INT_MAX, "output memory footprint is too large"); + + // Initialize CUDA kernel parameters. + upfirdn2d_kernel_params p; + p.x = x.data_ptr(); + p.f = f.data_ptr(); + p.y = y.data_ptr(); + p.up = make_int2(upx, upy); + p.down = make_int2(downx, downy); + p.pad0 = make_int2(padx0, pady0); + p.flip = (flip) ? 
1 : 0; + p.gain = gain; + p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0)); + p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0)); + p.filterSize = make_int2((int)f.size(1), (int)f.size(0)); + p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0)); + p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0)); + p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0)); + p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z; + p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1; + + // Choose CUDA kernel. + upfirdn2d_kernel_spec spec; + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] + { + spec = choose_upfirdn2d_kernel(p); + }); + + // Set looping options. + p.loopMajor = (p.sizeMajor - 1) / 16384 + 1; + p.loopMinor = spec.loopMinor; + p.loopX = spec.loopX; + p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1; + p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1; + + // Compute grid size. + dim3 blockSize, gridSize; + if (spec.tileOutW < 0) // large + { + blockSize = dim3(4, 32, 1); + gridSize = dim3( + ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor, + (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1, + p.launchMajor); + } + else // small + { + blockSize = dim3(256, 1, 1); + gridSize = dim3( + ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor, + (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1, + p.launchMajor); + } + + // Launch CUDA kernel. + void* args[] = {&p}; + AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); + return y; +} + +//------------------------------------------------------------------------ + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("upfirdn2d", &upfirdn2d); +} + +//------------------------------------------------------------------------ diff --git a/third_party/stylegan3_official_ops/upfirdn2d.cu b/third_party/stylegan3_official_ops/upfirdn2d.cu new file mode 100644 index 0000000000000000000000000000000000000000..3a33e31bbb1bbc1cd02ee7d2ede3943917f3906e --- /dev/null +++ b/third_party/stylegan3_official_ops/upfirdn2d.cu @@ -0,0 +1,384 @@ +// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include +#include "upfirdn2d.h" + +//------------------------------------------------------------------------ +// Helpers. + +template struct InternalType; +template <> struct InternalType { typedef double scalar_t; }; +template <> struct InternalType { typedef float scalar_t; }; +template <> struct InternalType { typedef float scalar_t; }; + +static __device__ __forceinline__ int floor_div(int a, int b) +{ + int t = 1 - a / b; + return (a + t * b) / b - t; +} + +//------------------------------------------------------------------------ +// Generic CUDA implementation for large filters. + +template static __global__ void upfirdn2d_kernel_large(upfirdn2d_kernel_params p) +{ + typedef typename InternalType::scalar_t scalar_t; + + // Calculate thread index. 
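// Note: blockIdx.x/threadIdx.x jointly enumerate (output row, minor channel)
// pairs, blockIdx.y steps across output columns in chunks of loopX * blockDim.y,
// and blockIdx.z advances the major (batch * channel) axis in chunks of
// loopMajor; the loops further down cover the remaining strides.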
+ int minorBase = blockIdx.x * blockDim.x + threadIdx.x; + int outY = minorBase / p.launchMinor; + minorBase -= outY * p.launchMinor; + int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y; + int majorBase = blockIdx.z * p.loopMajor; + if (outXBase >= p.outSize.x | outY >= p.outSize.y | majorBase >= p.sizeMajor) + return; + + // Setup Y receptive field. + int midY = outY * p.down.y + p.up.y - 1 - p.pad0.y; + int inY = min(max(floor_div(midY, p.up.y), 0), p.inSize.y); + int h = min(max(floor_div(midY + p.filterSize.y, p.up.y), 0), p.inSize.y) - inY; + int filterY = midY + p.filterSize.y - (inY + 1) * p.up.y; + if (p.flip) + filterY = p.filterSize.y - 1 - filterY; + + // Loop over major, minor, and X. + for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++) + for (int minorIdx = 0, minor = minorBase; minorIdx < p.loopMinor & minor < p.sizeMinor; minorIdx++, minor += p.launchMinor) + { + int nc = major * p.sizeMinor + minor; + int n = nc / p.inSize.z; + int c = nc - n * p.inSize.z; + for (int loopX = 0, outX = outXBase; loopX < p.loopX & outX < p.outSize.x; loopX++, outX += blockDim.y) + { + // Setup X receptive field. + int midX = outX * p.down.x + p.up.x - 1 - p.pad0.x; + int inX = min(max(floor_div(midX, p.up.x), 0), p.inSize.x); + int w = min(max(floor_div(midX + p.filterSize.x, p.up.x), 0), p.inSize.x) - inX; + int filterX = midX + p.filterSize.x - (inX + 1) * p.up.x; + if (p.flip) + filterX = p.filterSize.x - 1 - filterX; + + // Initialize pointers. + const T* xp = &((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w]; + const float* fp = &p.f[filterX * p.filterStride.x + filterY * p.filterStride.y]; + int filterStepX = ((p.flip) ? p.up.x : -p.up.x) * p.filterStride.x; + int filterStepY = ((p.flip) ? p.up.y : -p.up.y) * p.filterStride.y; + + // Inner loop. + scalar_t v = 0; + for (int y = 0; y < h; y++) + { + for (int x = 0; x < w; x++) + { + v += (scalar_t)(*xp) * (scalar_t)(*fp); + xp += p.inStride.x; + fp += filterStepX; + } + xp += p.inStride.y - w * p.inStride.x; + fp += filterStepY - w * filterStepX; + } + + // Store result. + v *= p.gain; + ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v; + } + } +} + +//------------------------------------------------------------------------ +// Specialized CUDA implementation for small filters. + +template +static __global__ void upfirdn2d_kernel_small(upfirdn2d_kernel_params p) +{ + typedef typename InternalType::scalar_t scalar_t; + const int tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1; + const int tileInH = ((tileOutH - 1) * downy + filterH - 1) / upy + 1; + __shared__ volatile scalar_t sf[filterH][filterW]; + __shared__ volatile scalar_t sx[tileInH][tileInW][loopMinor]; + + // Calculate tile index. + int minorBase = blockIdx.x; + int tileOutY = minorBase / p.launchMinor; + minorBase -= tileOutY * p.launchMinor; + minorBase *= loopMinor; + tileOutY *= tileOutH; + int tileOutXBase = blockIdx.y * p.loopX * tileOutW; + int majorBase = blockIdx.z * p.loopMajor; + if (tileOutXBase >= p.outSize.x | tileOutY >= p.outSize.y | majorBase >= p.sizeMajor) + return; + + // Load filter (flipped). + for (int tapIdx = threadIdx.x; tapIdx < filterH * filterW; tapIdx += blockDim.x) + { + int fy = tapIdx / filterW; + int fx = tapIdx - fy * filterW; + scalar_t v = 0; + if (fx < p.filterSize.x & fy < p.filterSize.y) + { + int ffx = (p.flip) ? 
fx : p.filterSize.x - 1 - fx; + int ffy = (p.flip) ? fy : p.filterSize.y - 1 - fy; + v = (scalar_t)p.f[ffx * p.filterStride.x + ffy * p.filterStride.y]; + } + sf[fy][fx] = v; + } + + // Loop over major and X. + for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++) + { + int baseNC = major * p.sizeMinor + minorBase; + int n = baseNC / p.inSize.z; + int baseC = baseNC - n * p.inSize.z; + for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outSize.x; loopX++, tileOutX += tileOutW) + { + // Load input pixels. + int tileMidX = tileOutX * downx + upx - 1 - p.pad0.x; + int tileMidY = tileOutY * downy + upy - 1 - p.pad0.y; + int tileInX = floor_div(tileMidX, upx); + int tileInY = floor_div(tileMidY, upy); + __syncthreads(); + for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW * loopMinor; inIdx += blockDim.x) + { + int relC = inIdx; + int relInX = relC / loopMinor; + int relInY = relInX / tileInW; + relC -= relInX * loopMinor; + relInX -= relInY * tileInW; + int c = baseC + relC; + int inX = tileInX + relInX; + int inY = tileInY + relInY; + scalar_t v = 0; + if (inX >= 0 & inY >= 0 & inX < p.inSize.x & inY < p.inSize.y & c < p.inSize.z) + v = (scalar_t)((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w]; + sx[relInY][relInX][relC] = v; + } + + // Loop over output pixels. + __syncthreads(); + for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW * loopMinor; outIdx += blockDim.x) + { + int relC = outIdx; + int relOutX = relC / loopMinor; + int relOutY = relOutX / tileOutW; + relC -= relOutX * loopMinor; + relOutX -= relOutY * tileOutW; + int c = baseC + relC; + int outX = tileOutX + relOutX; + int outY = tileOutY + relOutY; + + // Setup receptive field. + int midX = tileMidX + relOutX * downx; + int midY = tileMidY + relOutY * downy; + int inX = floor_div(midX, upx); + int inY = floor_div(midY, upy); + int relInX = inX - tileInX; + int relInY = inY - tileInY; + int filterX = (inX + 1) * upx - midX - 1; // flipped + int filterY = (inY + 1) * upy - midY - 1; // flipped + + // Inner loop. + if (outX < p.outSize.x & outY < p.outSize.y & c < p.outSize.z) + { + scalar_t v = 0; + #pragma unroll + for (int y = 0; y < filterH / upy; y++) + #pragma unroll + for (int x = 0; x < filterW / upx; x++) + v += sx[relInY + y][relInX + x][relC] * sf[filterY + y * upy][filterX + x * upx]; + v *= p.gain; + ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v; + } + } + } + } +} + +//------------------------------------------------------------------------ +// CUDA kernel selection. + +template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p) +{ + int s = p.inStride.z, fx = p.filterSize.x, fy = p.filterSize.y; + upfirdn2d_kernel_spec spec = {(void*)upfirdn2d_kernel_large, -1,-1,1, 4}; // contiguous + if (s == 1) spec = {(void*)upfirdn2d_kernel_large, -1,-1,4, 1}; // channels_last + + // No up/downsampling. 
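// Note on the selection table below: `s != 1` marks the contiguous (NCHW)
// layout and `s == 1` the channels_last layout, as detected from inStride.z
// above; each matching branch replaces the generic large-filter kernel with a
// tiled small-filter variant whose tileOutW, tileOutH, loopMinor and loopX
// (the last four numbers of each spec) are tuned to the up/down factors and
// filter-size bounds tested on that line.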
+ if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) + { + // contiguous + if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 64,32,1, 1}; + if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 64,32,1, 1}; + if (s != 1 && fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (s != 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (s != 1 && fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (s != 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (s != 1 && fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (s != 1 && fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (s != 1 && fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (s != 1 && fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (s != 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (s != 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (s != 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + // channels_last + if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (s == 1 && fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (s == 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (s == 1 && fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (s == 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (s == 1 && fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (s == 1 && fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (s == 1 && fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (s == 1 && fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (s == 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (s == 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (s == 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + } + + // 2x upsampling. 
+ if (p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) + { + // contiguous + if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 64,32,1, 1}; + if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 64,32,1, 1}; + if (s != 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (s != 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (s != 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (s != 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + // channels_last + if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (s == 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (s == 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (s == 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (s == 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + } + if (p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) + { + // contiguous + if (s != 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (s != 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (s != 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + // channels_last + if (s == 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (s == 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (s == 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + } + if (p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) + { + // contiguous + if (s != 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (s != 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (s != 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + // channels_last + if (s == 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (s == 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (s == 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + } + + // 2x downsampling. 
+ if (p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) + { + // contiguous + if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; + if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; + if (s != 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; + if (s != 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; + if (s != 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; + if (s != 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; + // channels_last + if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 16,16,1, 1}; + if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 16,16,1, 1}; + if (s == 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1}; + if (s == 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1}; + if (s == 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1}; + if (s == 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1}; + } + if (p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) + { + // contiguous + if (s != 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1}; + if (s != 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1}; + if (s != 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1}; + // channels_last + if (s == 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1}; + if (s == 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1}; + if (s == 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1}; + } + if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) + { + // contiguous + if (s != 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; + if (s != 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; + if (s != 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; + // channels_last + if (s == 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1}; + if (s == 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1}; + if (s == 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1}; + } + + // 4x upsampling. 
+ if (p.up.x == 4 && p.up.y == 4 && p.down.x == 1 && p.down.y == 1) + { + // contiguous + if (s != 1 && fx <= 48 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small, 64,32,1, 1}; + if (s != 1 && fx <= 32 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small, 64,32,1, 1}; + // channels_last + if (s == 1 && fx <= 48 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (s == 1 && fx <= 32 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + } + if (p.up.x == 4 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) + { + // contiguous + if (s != 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (s != 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + // channels_last + if (s == 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (s == 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + } + if (p.up.x == 1 && p.up.y == 4 && p.down.x == 1 && p.down.y == 1) + { + // contiguous + if (s != 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (s != 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + // channels_last + if (s == 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (s == 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + } + + // 4x downsampling (inefficient). + if (p.up.x == 1 && p.up.y == 1 && p.down.x == 4 && p.down.y == 1) + { + // contiguous + if (s != 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; + if (s != 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; + // channels_last + if (s == 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 32,1,8, 1}; + if (s == 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 32,1,8, 1}; + } + if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 4) + { + // contiguous + if (s != 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; + if (s != 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; + // channels_last + if (s == 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small, 1,32,8, 1}; + if (s == 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small, 1,32,8, 1}; + } + return spec; +} + +//------------------------------------------------------------------------ +// Template specializations. + +template upfirdn2d_kernel_spec choose_upfirdn2d_kernel (const upfirdn2d_kernel_params& p); +template upfirdn2d_kernel_spec choose_upfirdn2d_kernel (const upfirdn2d_kernel_params& p); +template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p); + +//------------------------------------------------------------------------ diff --git a/third_party/stylegan3_official_ops/upfirdn2d.h b/third_party/stylegan3_official_ops/upfirdn2d.h new file mode 100644 index 0000000000000000000000000000000000000000..2793daf874492af01e8634a7863c036e17b6731f --- /dev/null +++ b/third_party/stylegan3_official_ops/upfirdn2d.h @@ -0,0 +1,59 @@ +// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. 
Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include + +//------------------------------------------------------------------------ +// CUDA kernel parameters. + +struct upfirdn2d_kernel_params +{ + const void* x; + const float* f; + void* y; + + int2 up; + int2 down; + int2 pad0; + int flip; + float gain; + + int4 inSize; // [width, height, channel, batch] + int4 inStride; + int2 filterSize; // [width, height] + int2 filterStride; + int4 outSize; // [width, height, channel, batch] + int4 outStride; + int sizeMinor; + int sizeMajor; + + int loopMinor; + int loopMajor; + int loopX; + int launchMinor; + int launchMajor; +}; + +//------------------------------------------------------------------------ +// CUDA kernel specialization. + +struct upfirdn2d_kernel_spec +{ + void* kernel; + int tileOutW; + int tileOutH; + int loopMinor; + int loopX; +}; + +//------------------------------------------------------------------------ +// CUDA kernel selection. + +template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p); + +//------------------------------------------------------------------------ diff --git a/third_party/stylegan3_official_ops/upfirdn2d.py b/third_party/stylegan3_official_ops/upfirdn2d.py new file mode 100644 index 0000000000000000000000000000000000000000..b4cf0bb8fc299e66997b28cd517b8252619d3f26 --- /dev/null +++ b/third_party/stylegan3_official_ops/upfirdn2d.py @@ -0,0 +1,404 @@ +# python3.7 + +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Custom operators for efficient resampling of 2D images. + +`upfirdn` means executing upsampling, FIR filtering, downsampling in sequence. + +Please refer to https://github.com/NVlabs/stylegan3 +""" + +# pylint: disable=line-too-long +# pylint: disable=missing-class-docstring +# pylint: disable=global-statement + +import os +import numpy as np +import torch + +from . import custom_ops +from . import misc +from . 
import conv2d_gradfix + +#---------------------------------------------------------------------------- + +_plugin = None + +def _init(): + global _plugin + if _plugin is None: + _plugin = custom_ops.get_plugin( + module_name='upfirdn2d_plugin', + sources=['upfirdn2d.cpp', 'upfirdn2d.cu'], + headers=['upfirdn2d.h'], + source_dir=os.path.dirname(__file__), + extra_cuda_cflags=['--use_fast_math'], + ) + return True + +def _parse_scaling(scaling): + if isinstance(scaling, int): + scaling = [scaling, scaling] + assert isinstance(scaling, (list, tuple)) + assert all(isinstance(x, int) for x in scaling) + sx, sy = scaling + assert sx >= 1 and sy >= 1 + return sx, sy + +def _parse_padding(padding): + if isinstance(padding, int): + padding = [padding, padding] + assert isinstance(padding, (list, tuple)) + assert all(isinstance(x, int) for x in padding) + if len(padding) == 2: + padx, pady = padding + padding = [padx, padx, pady, pady] + padx0, padx1, pady0, pady1 = padding + return padx0, padx1, pady0, pady1 + +def _get_filter_size(f): + if f is None: + return 1, 1 + assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] + fw = f.shape[-1] + fh = f.shape[0] + with misc.suppress_tracer_warnings(): + fw = int(fw) + fh = int(fh) + misc.assert_shape(f, [fh, fw][:f.ndim]) + assert fw >= 1 and fh >= 1 + return fw, fh + +#---------------------------------------------------------------------------- + +def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None): + r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`. + + Args: + f: Torch tensor, numpy array, or python list of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), + `[]` (impulse), or + `None` (identity). + device: Result device (default: cpu). + normalize: Normalize the filter so that it retains the magnitude + for constant input signal (DC)? (default: True). + flip_filter: Flip the filter? (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + separable: Return a separable filter? (default: select automatically). + + Returns: + Float32 tensor of the shape + `[filter_height, filter_width]` (non-separable) or + `[filter_taps]` (separable). + """ + # Validate. + if f is None: + f = 1 + f = torch.as_tensor(f, dtype=torch.float32) + assert f.ndim in [0, 1, 2] + assert f.numel() > 0 + if f.ndim == 0: + f = f[np.newaxis] + + # Separable? + if separable is None: + separable = (f.ndim == 1 and f.numel() >= 8) + if f.ndim == 1 and not separable: + f = f.ger(f) + assert f.ndim == (1 if separable else 2) + + # Apply normalize, flip, gain, and device. + if normalize: + f /= f.sum() + if flip_filter: + f = f.flip(list(range(f.ndim))) + f = f * (gain ** (f.ndim / 2)) + f = f.to(device=device) + return f + +#---------------------------------------------------------------------------- + +def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'): + r"""Pad, upsample, filter, and downsample a batch of 2D images. + + Performs the following sequence of operations for each channel: + + 1. Upsample the image by inserting N-1 zeros after each pixel (`up`). + + 2. Pad the image with the specified number of zeros on each side (`padding`). + Negative padding corresponds to cropping the image. + + 3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it + so that the footprint of all output pixels lies within the input image. + + 4. Downsample the image by keeping every Nth pixel (`down`). 
+ + This sequence of operations bears close resemblance to scipy.signal.upfirdn(). + The fused op is considerably more efficient than performing the same calculation + using standard PyTorch ops. It supports gradients of arbitrary order. + + Args: + x: Float32/float64/float16 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. + f: Float32 FIR filter of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), or + `None` (identity). + up: Integer upsampling factor. Can be a single int or a list/tuple + `[x, y]` (default: 1). + down: Integer downsampling factor. Can be a single int or a list/tuple + `[x, y]` (default: 1). + padding: Padding with respect to the upsampled image. Can be a single number + or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + flip_filter: False = convolution, True = correlation (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. + """ + assert isinstance(x, torch.Tensor) + assert impl in ['ref', 'cuda'] + if impl == 'cuda' and x.device.type == 'cuda' and _init(): + return _upfirdn2d_cuda(up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain).apply(x, f) + return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain) + +#---------------------------------------------------------------------------- + +@misc.profiled_function +def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1): + """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops. + """ + # Validate arguments. + assert isinstance(x, torch.Tensor) and x.ndim == 4 + if f is None: + f = torch.ones([1, 1], dtype=torch.float32, device=x.device) + assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] + assert f.dtype == torch.float32 and not f.requires_grad + batch_size, num_channels, in_height, in_width = x.shape + upx, upy = _parse_scaling(up) + downx, downy = _parse_scaling(down) + padx0, padx1, pady0, pady1 = _parse_padding(padding) + + # Check that upsampled buffer is not smaller than the filter. + upW = in_width * upx + padx0 + padx1 + upH = in_height * upy + pady0 + pady1 + assert upW >= f.shape[-1] and upH >= f.shape[0] + + # Upsample by inserting zeros. + x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1]) + x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1]) + x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx]) + + # Pad or crop. + x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)]) + x = x[:, :, max(-pady0, 0) : x.shape[2] - max(-pady1, 0), max(-padx0, 0) : x.shape[3] - max(-padx1, 0)] + + # Setup filter. + f = f * (gain ** (f.ndim / 2)) + f = f.to(x.dtype) + if not flip_filter: + f = f.flip(list(range(f.ndim))) + + # Convolve with the filter. + f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim) + if f.ndim == 4: + x = conv2d_gradfix.conv2d(input=x, weight=f, groups=num_channels) + else: + x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels) + x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels) + + # Downsample by throwing away pixels. 
+ x = x[:, :, ::downy, ::downx] + return x + +#---------------------------------------------------------------------------- + +_upfirdn2d_cuda_cache = dict() + +def _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1): + """Fast CUDA implementation of `upfirdn2d()` using custom ops. + """ + # Parse arguments. + upx, upy = _parse_scaling(up) + downx, downy = _parse_scaling(down) + padx0, padx1, pady0, pady1 = _parse_padding(padding) + + # Lookup from cache. + key = (upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain) + if key in _upfirdn2d_cuda_cache: + return _upfirdn2d_cuda_cache[key] + + # Forward op. + class Upfirdn2dCuda(torch.autograd.Function): + @staticmethod + def forward(ctx, x, f): # pylint: disable=arguments-differ + assert isinstance(x, torch.Tensor) and x.ndim == 4 + if f is None: + f = torch.ones([1, 1], dtype=torch.float32, device=x.device) + if f.ndim == 1 and f.shape[0] == 1: + f = f.square().unsqueeze(0) # Convert separable-1 into full-1x1. + assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] + y = x + if f.ndim == 2: + y = _plugin.upfirdn2d(y, f, upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain) + else: + y = _plugin.upfirdn2d(y, f.unsqueeze(0), upx, 1, downx, 1, padx0, padx1, 0, 0, flip_filter, 1.0) + y = _plugin.upfirdn2d(y, f.unsqueeze(1), 1, upy, 1, downy, 0, 0, pady0, pady1, flip_filter, gain) + ctx.save_for_backward(f) + ctx.x_shape = x.shape + return y + + @staticmethod + def backward(ctx, dy): # pylint: disable=arguments-differ + f, = ctx.saved_tensors + _, _, ih, iw = ctx.x_shape + _, _, oh, ow = dy.shape + fw, fh = _get_filter_size(f) + p = [ + fw - padx0 - 1, + iw * upx - ow * downx + padx0 - upx + 1, + fh - pady0 - 1, + ih * upy - oh * downy + pady0 - upy + 1, + ] + dx = None + df = None + + if ctx.needs_input_grad[0]: + dx = _upfirdn2d_cuda(up=down, down=up, padding=p, flip_filter=(not flip_filter), gain=gain).apply(dy, f) + + assert not ctx.needs_input_grad[1] + return dx, df + + # Add to cache. + _upfirdn2d_cuda_cache[key] = Upfirdn2dCuda + return Upfirdn2dCuda + +#---------------------------------------------------------------------------- + +def filter2d(x, f, padding=0, flip_filter=False, gain=1, impl='cuda'): + r"""Filter a batch of 2D images using the given 2D FIR filter. + + By default, the result is padded so that its shape matches the input. + User-specified padding is applied on top of that, with negative values + indicating cropping. Pixels outside the image are assumed to be zero. + + Args: + x: Float32/float64/float16 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. + f: Float32 FIR filter of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), or + `None` (identity). + padding: Padding with respect to the output. Can be a single number or a + list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + flip_filter: False = convolution, True = correlation (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. 
+ """ + padx0, padx1, pady0, pady1 = _parse_padding(padding) + fw, fh = _get_filter_size(f) + p = [ + padx0 + fw // 2, + padx1 + (fw - 1) // 2, + pady0 + fh // 2, + pady1 + (fh - 1) // 2, + ] + return upfirdn2d(x, f, padding=p, flip_filter=flip_filter, gain=gain, impl=impl) + +#---------------------------------------------------------------------------- + +def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'): + r"""Upsample a batch of 2D images using the given 2D FIR filter. + + By default, the result is padded so that its shape is a multiple of the input. + User-specified padding is applied on top of that, with negative values + indicating cropping. Pixels outside the image are assumed to be zero. + + Args: + x: Float32/float64/float16 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. + f: Float32 FIR filter of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), or + `None` (identity). + up: Integer upsampling factor. Can be a single int or a list/tuple + `[x, y]` (default: 1). + padding: Padding with respect to the output. Can be a single number or a + list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + flip_filter: False = convolution, True = correlation (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. + """ + upx, upy = _parse_scaling(up) + padx0, padx1, pady0, pady1 = _parse_padding(padding) + fw, fh = _get_filter_size(f) + p = [ + padx0 + (fw + upx - 1) // 2, + padx1 + (fw - upx) // 2, + pady0 + (fh + upy - 1) // 2, + pady1 + (fh - upy) // 2, + ] + return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain*upx*upy, impl=impl) + +#---------------------------------------------------------------------------- + +def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'): + r"""Downsample a batch of 2D images using the given 2D FIR filter. + + By default, the result is padded so that its shape is a fraction of the input. + User-specified padding is applied on top of that, with negative values + indicating cropping. Pixels outside the image are assumed to be zero. + + Args: + x: Float32/float64/float16 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. + f: Float32 FIR filter of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), or + `None` (identity). + down: Integer downsampling factor. Can be a single int or a list/tuple + `[x, y]` (default: 1). + padding: Padding with respect to the input. Can be a single number or a + list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + flip_filter: False = convolution, True = correlation (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. 
+ """ + downx, downy = _parse_scaling(down) + padx0, padx1, pady0, pady1 = _parse_padding(padding) + fw, fh = _get_filter_size(f) + p = [ + padx0 + (fw - downx + 1) // 2, + padx1 + (fw - downx) // 2, + pady0 + (fh - downy + 1) // 2, + pady1 + (fh - downy) // 2, + ] + return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl) + +#---------------------------------------------------------------------------- + +# pylint: enable=line-too-long +# pylint: enable=missing-class-docstring +# pylint: enable=global-statement diff --git a/tmp.png b/tmp.png new file mode 100644 index 0000000000000000000000000000000000000000..4ffbbcd7dd167486aeaedea2afea3b5817262dc9 Binary files /dev/null and b/tmp.png differ diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/utils/__pycache__/__init__.cpython-37.pyc b/utils/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..03f5cd8b0c399cd4119a141c1cf5315919cfd07d Binary files /dev/null and b/utils/__pycache__/__init__.cpython-37.pyc differ diff --git a/utils/__pycache__/__init__.cpython-39.pyc b/utils/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..905cd40e9f350b5a7fc1a7085c92ef8587d29d4d Binary files /dev/null and b/utils/__pycache__/__init__.cpython-39.pyc differ diff --git a/utils/__pycache__/dist_utils.cpython-37.pyc b/utils/__pycache__/dist_utils.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8d00a2700da54dc1de4bbccdcbfe53a1e6b91e30 Binary files /dev/null and b/utils/__pycache__/dist_utils.cpython-37.pyc differ diff --git a/utils/__pycache__/eg3d_misc.cpython-37.pyc b/utils/__pycache__/eg3d_misc.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..521c0177bfb3db8dfa9d4fcbe4ffbedc6f007f26 Binary files /dev/null and b/utils/__pycache__/eg3d_misc.cpython-37.pyc differ diff --git a/utils/__pycache__/eg3d_misc.cpython-39.pyc b/utils/__pycache__/eg3d_misc.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b6a9a774194de3d19db3f2f75df056defe5fc3d Binary files /dev/null and b/utils/__pycache__/eg3d_misc.cpython-39.pyc differ diff --git a/utils/__pycache__/formatting_utils.cpython-37.pyc b/utils/__pycache__/formatting_utils.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6144a195cdd6aba84c6047a951124ddb95f4efe8 Binary files /dev/null and b/utils/__pycache__/formatting_utils.cpython-37.pyc differ diff --git a/utils/__pycache__/formatting_utils.cpython-39.pyc b/utils/__pycache__/formatting_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e36e7f865df78e3142e6aef0ae2093be3adeba0 Binary files /dev/null and b/utils/__pycache__/formatting_utils.cpython-39.pyc differ diff --git a/utils/__pycache__/image_utils.cpython-37.pyc b/utils/__pycache__/image_utils.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff0d27ebd2a22a3ee8a50fc1cdc58f9dbab0b093 Binary files /dev/null and b/utils/__pycache__/image_utils.cpython-37.pyc differ diff --git a/utils/__pycache__/misc.cpython-37.pyc b/utils/__pycache__/misc.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5eb183154e2cd3a8dd407341a287d80f31846a92 Binary files /dev/null and b/utils/__pycache__/misc.cpython-37.pyc differ diff --git 
a/utils/__pycache__/misc.cpython-39.pyc b/utils/__pycache__/misc.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e21fe3e99b59a7e2cbf99e8a4e88797842be11fa Binary files /dev/null and b/utils/__pycache__/misc.cpython-39.pyc differ diff --git a/utils/__pycache__/parsing_utils.cpython-37.pyc b/utils/__pycache__/parsing_utils.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8efc18b710f5fc8c4db656a4df300c6888257136 Binary files /dev/null and b/utils/__pycache__/parsing_utils.cpython-37.pyc differ diff --git a/utils/__pycache__/tf_utils.cpython-37.pyc b/utils/__pycache__/tf_utils.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3032c7fd2feabd9309d6bd24cf97bd3811a20756 Binary files /dev/null and b/utils/__pycache__/tf_utils.cpython-37.pyc differ diff --git a/utils/dist_utils.py b/utils/dist_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c792ba8b34fed30942b79c8054141b98c6273d8f --- /dev/null +++ b/utils/dist_utils.py @@ -0,0 +1,67 @@ +# python3.7 +"""Contains utility functions used for distribution.""" + +import contextlib +import os +import subprocess + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +__all__ = ['init_dist', 'exit_dist', 'ddp_sync', 'get_ddp_module'] + + +def init_dist(launcher, backend='nccl', **kwargs): + """Initializes distributed environment.""" + if mp.get_start_method(allow_none=True) is None: + mp.set_start_method('spawn') + if launcher == 'pytorch': + rank = int(os.environ['RANK']) + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(rank % num_gpus) + dist.init_process_group(backend=backend, **kwargs) + elif launcher == 'slurm': + proc_id = int(os.environ['SLURM_PROCID']) + ntasks = int(os.environ['SLURM_NTASKS']) + node_list = os.environ['SLURM_NODELIST'] + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(proc_id % num_gpus) + addr = subprocess.getoutput( + f'scontrol show hostname {node_list} | head -n1') + port = os.environ.get('PORT', 29500) + os.environ['MASTER_PORT'] = str(port) + os.environ['MASTER_ADDR'] = addr + os.environ['WORLD_SIZE'] = str(ntasks) + os.environ['RANK'] = str(proc_id) + dist.init_process_group(backend=backend) + else: + raise NotImplementedError(f'Not implemented launcher type: ' + f'`{launcher}`!') + + +def exit_dist(): + """Exits the distributed environment.""" + if dist.is_initialized(): + dist.destroy_process_group() + + +@contextlib.contextmanager +def ddp_sync(model, sync): + """Controls whether the `DistributedDataParallel` model should be synced.""" + assert isinstance(model, torch.nn.Module) + is_ddp = isinstance(model, torch.nn.parallel.DistributedDataParallel) + if sync or not is_ddp: + yield + else: + with model.no_sync(): + yield + + +def get_ddp_module(model): + """Gets the module from `DistributedDataParallel`.""" + assert isinstance(model, torch.nn.Module) + is_ddp = isinstance(model, torch.nn.parallel.DistributedDataParallel) + if is_ddp: + return model.module + return model diff --git a/utils/eg3d_misc.py b/utils/eg3d_misc.py new file mode 100644 index 0000000000000000000000000000000000000000..b90b747d6e9dcbd176048db4b0f01fd04ce90f27 --- /dev/null +++ b/utils/eg3d_misc.py @@ -0,0 +1,261 @@ +# python 3.7 +"""Contains some """ + +import re +import contextlib +import numpy as np +import torch +import warnings +import easydict + +#---------------------------------------------------------------------------- +# Cached construction of 
constant tensors. Avoids CPU=>GPU copy when the +# same constant is used multiple times. + +_constant_cache = dict() + +def constant(value, shape=None, dtype=None, device=None, memory_format=None): + value = np.asarray(value) + if shape is not None: + shape = tuple(shape) + if dtype is None: + dtype = torch.get_default_dtype() + if device is None: + device = torch.device('cpu') + if memory_format is None: + memory_format = torch.contiguous_format + + key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format) + tensor = _constant_cache.get(key, None) + if tensor is None: + tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device) + if shape is not None: + tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape)) + tensor = tensor.contiguous(memory_format=memory_format) + _constant_cache[key] = tensor + return tensor + +#---------------------------------------------------------------------------- +# Replace NaN/Inf with specified numerical values. + +try: + nan_to_num = torch.nan_to_num # 1.8.0a0 +except AttributeError: + def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin + assert isinstance(input, torch.Tensor) + if posinf is None: + posinf = torch.finfo(input.dtype).max + if neginf is None: + neginf = torch.finfo(input.dtype).min + assert nan == 0 + return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out) + +#---------------------------------------------------------------------------- +# Symbolic assert. + +try: + symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access +except AttributeError: + symbolic_assert = torch.Assert # 1.7.0 + +#---------------------------------------------------------------------------- +# Context manager to temporarily suppress known warnings in torch.jit.trace(). +# Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672 + +@contextlib.contextmanager +def suppress_tracer_warnings(): + flt = ('ignore', None, torch.jit.TracerWarning, None, 0) + warnings.filters.insert(0, flt) + yield + warnings.filters.remove(flt) + +#---------------------------------------------------------------------------- +# Assert that the shape of a tensor matches the given list of integers. +# None indicates that the size of a dimension is allowed to vary. +# Performs symbolic assertion when used in torch.jit.trace(). + +def assert_shape(tensor, ref_shape): + if tensor.ndim != len(ref_shape): + raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}') + for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)): + if ref_size is None: + pass + elif isinstance(ref_size, torch.Tensor): + with suppress_tracer_warnings(): # as_tensor results are registered as constants + symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}') + elif isinstance(size, torch.Tensor): + with suppress_tracer_warnings(): # as_tensor results are registered as constants + symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}') + elif size != ref_size: + raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}') + +#---------------------------------------------------------------------------- +# Function decorator that calls torch.autograd.profiler.record_function(). 
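+# For example (illustrative, not part of the original code), a decorated function such as +# +#   @profiled_function +#   def render_step(x): +#       return x * 2 +# +# shows up as its own named range in torch.autograd.profiler / torch.profiler +# traces when the surrounding code is profiled.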
+ +def profiled_function(fn): + def decorator(*args, **kwargs): + with torch.autograd.profiler.record_function(fn.__name__): + return fn(*args, **kwargs) + decorator.__name__ = fn.__name__ + return decorator + +#---------------------------------------------------------------------------- +# Sampler for torch.utils.data.DataLoader that loops over the dataset +# indefinitely, shuffling items as it goes. + +class InfiniteSampler(torch.utils.data.Sampler): + def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5): + assert len(dataset) > 0 + assert num_replicas > 0 + assert 0 <= rank < num_replicas + assert 0 <= window_size <= 1 + super().__init__(dataset) + self.dataset = dataset + self.rank = rank + self.num_replicas = num_replicas + self.shuffle = shuffle + self.seed = seed + self.window_size = window_size + + def __iter__(self): + order = np.arange(len(self.dataset)) + rnd = None + window = 0 + if self.shuffle: + rnd = np.random.RandomState(self.seed) + rnd.shuffle(order) + window = int(np.rint(order.size * self.window_size)) + + idx = 0 + while True: + i = idx % order.size + if idx % self.num_replicas == self.rank: + yield order[i] + if window >= 2: + j = (i - rnd.randint(window)) % order.size + order[i], order[j] = order[j], order[i] + idx += 1 + +#---------------------------------------------------------------------------- +# Utilities for operating with torch.nn.Module parameters and buffers. + +def params_and_buffers(module): + assert isinstance(module, torch.nn.Module) + return list(module.parameters()) + list(module.buffers()) + +def named_params_and_buffers(module): + assert isinstance(module, torch.nn.Module) + return list(module.named_parameters()) + list(module.named_buffers()) + +def copy_params_and_buffers(src_module, dst_module, require_all=False): + assert isinstance(src_module, torch.nn.Module) + assert isinstance(dst_module, torch.nn.Module) + src_tensors = dict(named_params_and_buffers(src_module)) + for name, tensor in named_params_and_buffers(dst_module): + assert (name in src_tensors) or (not require_all) + if name in src_tensors: + tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad) + +#---------------------------------------------------------------------------- +# Context manager for easily enabling/disabling DistributedDataParallel +# synchronization. + +@contextlib.contextmanager +def ddp_sync(module, sync): + assert isinstance(module, torch.nn.Module) + if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel): + yield + else: + with module.no_sync(): + yield + +#---------------------------------------------------------------------------- +# Check DistributedDataParallel consistency across processes. + +def check_ddp_consistency(module, ignore_regex=None): + assert isinstance(module, torch.nn.Module) + for name, tensor in named_params_and_buffers(module): + fullname = type(module).__name__ + '.' + name + if ignore_regex is not None and re.fullmatch(ignore_regex, fullname): + continue + tensor = tensor.detach() + if tensor.is_floating_point(): + tensor = nan_to_num(tensor) + other = tensor.clone() + torch.distributed.broadcast(tensor=other, src=0) + assert (tensor == other).all(), fullname + +#---------------------------------------------------------------------------- +# Print summary table of module hierarchy. 
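+# The printed table lists, for each (non-redundant) submodule, its parameter +# count, buffer count, and output shapes/dtypes, followed by totals. Note that a +# real forward pass is executed, so a hypothetical call such as +# +#   img = print_module_summary(G, [z, c]) +# +# returns the module's actual output for the given inputs.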
+ +def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True): + assert isinstance(module, torch.nn.Module) + assert not isinstance(module, torch.jit.ScriptModule) + assert isinstance(inputs, (tuple, list)) + + # Register hooks. + entries = [] + nesting = [0] + def pre_hook(_mod, _inputs): + nesting[0] += 1 + def post_hook(mod, _inputs, outputs): + nesting[0] -= 1 + if nesting[0] <= max_nesting: + outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs] + outputs = [t for t in outputs if isinstance(t, torch.Tensor)] + entries.append(easydict.EasyDict(mod=mod, outputs=outputs)) + hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()] + hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()] + + # Run module. + outputs = module(*inputs) + for hook in hooks: + hook.remove() + + # Identify unique outputs, parameters, and buffers. + tensors_seen = set() + for e in entries: + e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen] + e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen] + e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen] + tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs} + + # Filter out redundant entries. + if skip_redundant: + entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)] + + # Construct table. + rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']] + rows += [['---'] * len(rows[0])] + param_total = 0 + buffer_total = 0 + submodule_names = {mod: name for name, mod in module.named_modules()} + for e in entries: + name = '' if e.mod is module else submodule_names[e.mod] + param_size = sum(t.numel() for t in e.unique_params) + buffer_size = sum(t.numel() for t in e.unique_buffers) + output_shapes = [str(list(t.shape)) for t in e.outputs] + output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs] + rows += [[ + name + (':0' if len(e.outputs) >= 2 else ''), + str(param_size) if param_size else '-', + str(buffer_size) if buffer_size else '-', + (output_shapes + ['-'])[0], + (output_dtypes + ['-'])[0], + ]] + for idx in range(1, len(e.outputs)): + rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]] + param_total += param_size + buffer_total += buffer_size + rows += [['---'] * len(rows[0])] + rows += [['Total', str(param_total), str(buffer_total), '-', '-']] + + # Print table. + widths = [max(len(cell) for cell in column) for column in zip(*rows)] + print() + for row in rows: + print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths))) + print() + return outputs + +#---------------------------------------------------------------------------- diff --git a/utils/file_transmitters/__init__.py b/utils/file_transmitters/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..027642bc8f3bcd6bc992daf4aa419fd4ab464372 --- /dev/null +++ b/utils/file_transmitters/__init__.py @@ -0,0 +1,30 @@ +# python3.7 +"""Collects all file transmitters.""" + +from .local_file_transmitter import LocalFileTransmitter +from .dummy_file_transmitter import DummyFileTransmitter + +__all__ = ['build_file_transmitter'] + +_TRANSMITTERS = { + 'local': LocalFileTransmitter, + 'dummy': DummyFileTransmitter, +} + + +def build_file_transmitter(transmitter_type='local', **kwargs): + """Builds a file transmitter. 
+ + Args: + transmitter_type: Type of the file transmitter, which is case + insensitive. (default: `local`) + **kwargs: Additional arguments to build the file transmitter. + + Raises: + ValueError: If the `transmitter_type` is not supported. + """ + transmitter_type = transmitter_type.lower() + if transmitter_type not in _TRANSMITTERS: + raise ValueError(f'Invalid transmitter type: `{transmitter_type}`!\n' + f'Types allowed: {list(_TRANSMITTERS)}.') + return _TRANSMITTERS[transmitter_type](**kwargs) diff --git a/utils/file_transmitters/__pycache__/__init__.cpython-37.pyc b/utils/file_transmitters/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a8289f9e223128389dc2d7850e7665284096e166 Binary files /dev/null and b/utils/file_transmitters/__pycache__/__init__.cpython-37.pyc differ diff --git a/utils/file_transmitters/__pycache__/base_file_transmitter.cpython-37.pyc b/utils/file_transmitters/__pycache__/base_file_transmitter.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c2f3b8e411478de5296e5541eb6edb20ed8ccc8 Binary files /dev/null and b/utils/file_transmitters/__pycache__/base_file_transmitter.cpython-37.pyc differ diff --git a/utils/file_transmitters/__pycache__/dummy_file_transmitter.cpython-37.pyc b/utils/file_transmitters/__pycache__/dummy_file_transmitter.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f555d036d6122475331b7e835fb97b6fd4f690be Binary files /dev/null and b/utils/file_transmitters/__pycache__/dummy_file_transmitter.cpython-37.pyc differ diff --git a/utils/file_transmitters/__pycache__/local_file_transmitter.cpython-37.pyc b/utils/file_transmitters/__pycache__/local_file_transmitter.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c263bc0f0e0cb0183d30756e1ddd4654cc575082 Binary files /dev/null and b/utils/file_transmitters/__pycache__/local_file_transmitter.cpython-37.pyc differ diff --git a/utils/file_transmitters/base_file_transmitter.py b/utils/file_transmitters/base_file_transmitter.py new file mode 100644 index 0000000000000000000000000000000000000000..5e93d2e85581378d01a8227feca29219f5c51417 --- /dev/null +++ b/utils/file_transmitters/base_file_transmitter.py @@ -0,0 +1,92 @@ +# python3.7 +"""Contains the base class to transmit files across file systems. + +Basically, a file transmitter connects the local file system, on which the +programme runs, to a remote file system. This is particularly used for +(1) pulling files that are required by the programme from remote, and +(2) pushing results that are produced by the programme to remote. In this way, +the programme can focus on the local file system only. + +NOTE: The remote file system can be the same as the local file system, since +users may want to transmit files across directories. +""" + +import warnings + +__all__ = ['BaseFileTransmitter'] + + +class BaseFileTransmitter(object): + """Defines the base file transmitter. + + A transmitter should have the following functions: + + (1) pull(): The function to pull a file/directory from remote to local. + (2) push(): The function to push a file/directory from local to remote. + (3) remove(): The function to remove a file/directory. + (4) make_remote_dir(): Make directory remotely. + + + To simplify, each derived class just needs to implement the following helper + functions: + + (1) download_hard(): Hard download a file/directory from remote to local.
+ (2) download_soft(): Soft download a file/directory from remote to local. + This is especially used to save space (e.g., soft link). + (3) upload(): Upload a file/directory from local to remote. + (4) delete(): Delete a file/directory according to the given path. + """ + + def __init__(self): + pass + + @property + def name(self): + """Returns the class name of the file transmitter.""" + return self.__class__.__name__ + + @staticmethod + def download_hard(src, dst): + """Downloads (in hard mode) a file/directory from remote to local.""" + raise NotImplementedError('Should be implemented in derived class!') + + @staticmethod + def download_soft(src, dst): + """Downloads (in soft mode) a file/directory from remote to local.""" + raise NotImplementedError('Should be implemented in derived class!') + + @staticmethod + def upload(src, dst): + """Uploads a file/directory from local to remote.""" + raise NotImplementedError('Should be implemented in derived class!') + + @staticmethod + def delete(path): + """Deletes the given path.""" + # TODO: should we secure the path to avoid mis-removing / attacks? + raise NotImplementedError('Should be implemented in derived class!') + + def pull(self, src, dst, hard=False): + """Pulls a file/directory from remote to local. + + The argument `hard` controls the download mode (hard or soft). + For example, the hard mode may make an actual copy of the file, while + the soft mode may just create a soft link to it. + """ + if hard: + self.download_hard(src, dst) + else: + self.download_soft(src, dst) + + def push(self, src, dst): + """Pushes a file/directory from local to remote.""" + self.upload(src, dst) + + def remove(self, path): + """Removes the given path.""" + warnings.warn(f'`{path}` will be removed!') + self.delete(path) + + def make_remote_dir(self, directory): + """Makes a directory on the remote system.""" + raise NotImplementedError('Should be implemented in derived class!') diff --git a/utils/file_transmitters/dummy_file_transmitter.py b/utils/file_transmitters/dummy_file_transmitter.py new file mode 100644 index 0000000000000000000000000000000000000000..c553f4082061da9e6d8194dbbc2ce16f7a122554 --- /dev/null +++ b/utils/file_transmitters/dummy_file_transmitter.py @@ -0,0 +1,34 @@ +# python3.7 +"""Contains the class of dummy file transmitter. + +This file transmitter has all expected data transmission functions but behaves +silently, which is very useful in multi-processing mode. Only the chief process +can have the file transmitter with normal behavior. +""" + +from .base_file_transmitter import BaseFileTransmitter + +__all__ = ['DummyFileTransmitter'] + + +class DummyFileTransmitter(BaseFileTransmitter): + """Implements a dummy transmitter which transmits nothing.""" + + @staticmethod + def download_hard(src, dst): + return + + @staticmethod + def download_soft(src, dst): + return + + @staticmethod + def upload(src, dst): + return + + @staticmethod + def delete(path): + return + + def make_remote_dir(self, directory): + return diff --git a/utils/file_transmitters/local_file_transmitter.py b/utils/file_transmitters/local_file_transmitter.py new file mode 100644 index 0000000000000000000000000000000000000000..562becf65ce0052559109300557c8c8de2e142b6 --- /dev/null +++ b/utils/file_transmitters/local_file_transmitter.py @@ -0,0 +1,35 @@ +# python3.7 +"""Contains the class of local file transmitter. + +The transmitter builds the connection between the local file system and itself. +This can be used to transmit files from one directory to another.
Consequently, +`remote` in this file also means `local`. +""" + +from utils.misc import print_and_execute +from .base_file_transmitter import BaseFileTransmitter + +__all__ = ['LocalFileTransmitter'] + + +class LocalFileTransmitter(BaseFileTransmitter): + """Implements the transmitter connecting local file system to itself.""" + + @staticmethod + def download_hard(src, dst): + print_and_execute(f'cp {src} {dst}') + + @staticmethod + def download_soft(src, dst): + print_and_execute(f'ln -s {src} {dst}') + + @staticmethod + def upload(src, dst): + print_and_execute(f'cp {src} {dst}') + + @staticmethod + def delete(path): + print_and_execute(f'rm -r {path}') + + def make_remote_dir(self, directory): + print_and_execute(f'mkdir -p {directory}') diff --git a/utils/formatting_utils.py b/utils/formatting_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..20f9f14050da889b7b9be0867e9373ff54ebe42d --- /dev/null +++ b/utils/formatting_utils.py @@ -0,0 +1,178 @@ +# python3.7 +"""Contains utility functions used for formatting.""" + +import cv2 +import numpy as np + +__all__ = [ + 'format_time', 'format_range', 'format_image_size', 'format_image', + 'raw_label_to_one_hot', 'one_hot_to_raw_label' +] + + +def format_time(seconds): + """Formats seconds to readable time string. + + Args: + seconds: Number of seconds to format. + + Returns: + The formatted time string. + + Raises: + ValueError: If the input `seconds` is less than 0. + """ + if seconds < 0: + raise ValueError(f'Input `seconds` should be greater than or equal to ' + f'0, but `{seconds}` is received!') + + # Returns seconds as float if less than 1 minute. + if seconds < 10: + return f'{seconds:7.3f} s' + if seconds < 60: + return f'{seconds:7.2f} s' + + seconds = int(seconds + 0.5) + days, seconds = divmod(seconds, 86400) + hours, seconds = divmod(seconds, 3600) + minutes, seconds = divmod(seconds, 60) + if days: + return f'{days:2d} d {hours:02d} h' + if hours: + return f'{hours:2d} h {minutes:02d} m' + return f'{minutes:2d} m {seconds:02d} s' + + +def format_range(obj, min_val=None, max_val=None): + """Formats the given object to a valid range. + + If `min_val` or `max_val` is provided, both the starting value and the end + value will be clamped to range `[min_val, max_val]`. + + NOTE: (a, b) is regarded as a valid range if and only if `a <= b`. + + Args: + obj: The input object to format. + min_val: The minimum value to cut off the input range. If not provided, + the default minimum value is negative infinity. (default: None) + max_val: The maximum value to cut off the input range. If not provided, + the default maximum value is infinity. (default: None) + + Returns: + A two-elements tuple, indicating the start and the end of the range. + + Raises: + ValueError: If the input object is an invalid range. 
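+ + Example (illustrative, added for clarity): `format_range((1, 8), min_val=2, max_val=5)` + gives `(2, 5)`.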
+ """ + if not isinstance(obj, (tuple, list)): + raise ValueError(f'Input object must be a tuple or a list, ' + f'but `{type(obj)}` received!') + if len(obj) != 2: + raise ValueError(f'Input object is expected to contain two elements, ' + f'but `{len(obj)}` received!') + if obj[0] > obj[1]: + raise ValueError(f'The second element is expected to be equal to or ' + f'greater than the first one, ' + f'but `({obj[0]}, {obj[1]})` received!') + + obj = list(obj) + if min_val is not None: + obj[0] = max(obj[0], min_val) + obj[1] = max(obj[1], min_val) + if max_val is not None: + obj[0] = min(obj[0], max_val) + obj[1] = min(obj[1], max_val) + return tuple(obj) + + +def format_image_size(size): + """Formats the given image size to a two-element tuple. + + A valid image size can be an integer, indicating both the height and the + width, OR can be a two-element list or tuple. Both height and width are + assumed to be positive integer. + + Args: + size: The input size to format. + + Returns: + A two-elements tuple, indicating the height and the width, respectively. + + Raises: + ValueError: If the input size is invalid. + """ + if not isinstance(size, (int, tuple, list)): + raise ValueError(f'Input size must be an integer, a tuple, or a list, ' + f'but `{type(size)}` received!') + if isinstance(size, int): + size = (size, size) + else: + if len(size) == 1: + size = (size[0], size[0]) + if not len(size) == 2: + raise ValueError(f'Input size is expected to have two numbers at ' + f'most, but `{len(size)}` numbers received!') + if not isinstance(size[0], int) or size[0] < 0: + raise ValueError(f'The height is expected to be a non-negative ' + f'integer, but `{size[0]}` received!') + if not isinstance(size[1], int) or size[1] < 0: + raise ValueError(f'The width is expected to be a non-negative ' + f'integer, but `{size[1]}` received!') + return tuple(size) + + +def format_image(image): + """Formats an image read from `cv2`. + + NOTE: This function will always return a 3-dimensional image (i.e., with + shape [H, W, C]) in pixel range [0, 255]. For color images, the channel + order of the input is expected to be with `BGR` or `BGRA`, which is the + raw image decoded by `cv2`; while the channel order of the output is set to + `RGB` or `RGBA` by default. + + Args: + image: `np.ndarray`, an image read by `cv2.imread()` or + `cv2.imdecode()`. + + Returns: + An image with shape [H, W, C] (where `C = 1` for grayscale image). + """ + if image.ndim == 2: # add additional axis if given a grayscale image + image = image[:, :, np.newaxis] + + assert isinstance(image, np.ndarray) + assert image.dtype == np.uint8 + assert image.ndim == 3 and image.shape[2] in [1, 3, 4] + + if image.shape[2] == 3: # BGR image + return cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + if image.shape[2] == 4: # BGRA image + return cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA) + return image + + +def raw_label_to_one_hot(raw_label, num_classes): + """Converts a single label into one-hot vector. + + Args: + raw_label: The raw label. + num_classes: Total number of classes. + + Returns: + one-hot vector of the given raw label. + """ + one_hot = np.zeros(num_classes, dtype=np.float32) + one_hot[raw_label] = 1.0 + return one_hot + + +def one_hot_to_raw_label(one_hot): + """Converts a one-hot vector to a single value label. + + Args: + one_hot: `np.ndarray`, a one-hot encoded vector. + + Returns: + A single integer to represent the category. 
+ """ + return np.argmax(one_hot) diff --git a/utils/image_utils.py b/utils/image_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c640ac5ef977e3a7824dabbc43e5d56e733d0d76 --- /dev/null +++ b/utils/image_utils.py @@ -0,0 +1,332 @@ +# python3.7 +"""Contains utility functions for image processing. + +The module is primarily built on `cv2`. But, differently, we assume all colorful +images are with `RGB` channel order by default. Also, we assume all gray-scale +images to be with shape [height, width, 1]. +""" + +import os +import cv2 +import numpy as np + +from .misc import IMAGE_EXTENSIONS +from .misc import check_file_ext + +__all__ = [ + 'get_blank_image', 'load_image', 'save_image', 'resize_image', + 'add_text_to_image', 'preprocess_image', 'postprocess_image', + 'parse_image_size', 'get_grid_shape', 'list_images_from_dir' +] + + +def _check_2d_image(image): + """Checks whether a given image is valid. + + A valid image is expected to be with dtype `uint8`. Also, it should have + shape like: + + (1) (height, width, 1) # gray-scale image. + (2) (height, width, 3) # colorful image. + (3) (height, width, 4) # colorful image with transparency (RGBA) + """ + assert isinstance(image, np.ndarray) + assert image.dtype == np.uint8 + assert image.ndim == 3 and image.shape[2] in [1, 3, 4] + + +def get_blank_image(height, width, channels=3, use_black=True): + """Gets a blank image, either white of black. + + NOTE: This function will always return an image with `RGB` channel order for + color image and pixel range [0, 255]. + + Args: + height: Height of the returned image. + width: Width of the returned image. + channels: Number of channels. (default: 3) + use_black: Whether to return a black image. (default: True) + """ + shape = (height, width, channels) + if use_black: + return np.zeros(shape, dtype=np.uint8) + return np.ones(shape, dtype=np.uint8) * 255 + + +def load_image(path): + """Loads an image from disk. + + NOTE: This function will always return an image with `RGB` channel order for + color image and pixel range [0, 255]. + + Args: + path: Path to load the image from. + + Returns: + An image with dtype `np.ndarray`, or `None` if `path` does not exist. + """ + image = cv2.imread(path, cv2.IMREAD_UNCHANGED) + if image is None: + return None + + if image.ndim == 2: + image = image[:, :, np.newaxis] + _check_2d_image(image) + if image.shape[2] == 3: + return cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + if image.shape[2] == 4: + return cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA) + return image + + +def save_image(path, image): + """Saves an image to disk. + + NOTE: The input image (if colorful) is assumed to be with `RGB` channel + order and pixel range [0, 255]. + + Args: + path: Path to save the image to. + image: Image to save. + """ + if image is None: + return + + _check_2d_image(image) + if image.shape[2] == 1: + cv2.imwrite(path, image) + elif image.shape[2] == 3: + cv2.imwrite(path, cv2.cvtColor(image, cv2.COLOR_RGB2BGR)) + elif image.shape[2] == 4: + cv2.imwrite(path, cv2.cvtColor(image, cv2.COLOR_RGBA2BGRA)) + + +def resize_image(image, *args, **kwargs): + """Resizes image. + + This is a wrap of `cv2.resize()`. + + NOTE: The channel order of the input image will not be changed. + + Args: + image: Image to resize. + *args: Additional positional arguments. + **kwargs: Additional keyword arguments. + + Returns: + An image with dtype `np.ndarray`, or `None` if `image` is empty. 
+ """ + if image is None: + return None + + _check_2d_image(image) + if image.shape[2] == 1: # Re-expand the squeezed dim of gray-scale image. + return cv2.resize(image, *args, **kwargs)[:, :, np.newaxis] + return cv2.resize(image, *args, **kwargs) + + +def add_text_to_image(image, + text='', + position=None, + font=cv2.FONT_HERSHEY_TRIPLEX, + font_size=1.0, + line_type=cv2.LINE_8, + line_width=1, + color=(255, 255, 255)): + """Overlays text on given image. + + NOTE: The input image is assumed to be with `RGB` channel order. + + Args: + image: The image to overlay text on. + text: Text content to overlay on the image. (default: empty) + position: Target position (bottom-left corner) to add text. If not set, + center of the image will be used by default. (default: None) + font: Font of the text added. (default: cv2.FONT_HERSHEY_TRIPLEX) + font_size: Font size of the text added. (default: 1.0) + line_type: Line type used to depict the text. (default: cv2.LINE_8) + line_width: Line width used to depict the text. (default: 1) + color: Color of the text added in `RGB` channel order. (default: + (255, 255, 255)) + + Returns: + An image with target text overlaid on. + """ + if image is None or not text: + return image + + _check_2d_image(image) + cv2.putText(img=image, + text=text, + org=position, + fontFace=font, + fontScale=font_size, + color=color, + thickness=line_width, + lineType=line_type, + bottomLeftOrigin=False) + return image + + +def preprocess_image(image, min_val=-1.0, max_val=1.0): + """Pre-processes image by adjusting the pixel range and to dtype `float32`. + + This function is particularly used to convert an image or a batch of images + to `NCHW` format, which matches the data type commonly used in deep models. + + NOTE: The input image is assumed to be with pixel range [0, 255] and with + format `HWC` or `NHWC`. The returned image will be always be with format + `NCHW`. + + Args: + image: The input image for pre-processing. + min_val: Minimum value of the output image. + max_val: Maximum value of the output image. + + Returns: + The pre-processed image. + """ + assert isinstance(image, np.ndarray) + + image = image.astype(np.float64) + image = image / 255.0 * (max_val - min_val) + min_val + + if image.ndim == 3: + image = image[np.newaxis] + assert image.ndim == 4 and image.shape[3] in [1, 3, 4] + return image.transpose(0, 3, 1, 2) + + +def postprocess_image(image, min_val=-1.0, max_val=1.0): + """Post-processes image to pixel range [0, 255] with dtype `uint8`. + + This function is particularly used to handle the results produced by deep + models. + + NOTE: The input image is assumed to be with format `NCHW`, and the returned + image will always be with format `NHWC`. + + Args: + image: The input image for post-processing. + min_val: Expected minimum value of the input image. + max_val: Expected maximum value of the input image. + + Returns: + The post-processed image. + """ + assert isinstance(image, np.ndarray) + + image = image.astype(np.float64) + image = (image - min_val) / (max_val - min_val) * 255 + image = np.clip(image + 0.5, 0, 255).astype(np.uint8) + + assert image.ndim == 4 and image.shape[1] in [1, 3, 4] + return image.transpose(0, 2, 3, 1) + + +def parse_image_size(obj): + """Parses an object to a pair of image size, i.e., (height, width). + + Args: + obj: The input object to parse image size from. + + Returns: + A two-element tuple, indicating image height and width respectively. + + Raises: + If the input is invalid, i.e., neither a list or tuple, nor a string. 
+ """ + if obj is None or obj == '': + height = 0 + width = 0 + elif isinstance(obj, int): + height = obj + width = obj + elif isinstance(obj, (list, tuple, str, np.ndarray)): + if isinstance(obj, str): + splits = obj.replace(' ', '').split(',') + numbers = tuple(map(int, splits)) + else: + numbers = tuple(obj) + if len(numbers) == 0: + height = 0 + width = 0 + elif len(numbers) == 1: + height = int(numbers[0]) + width = int(numbers[0]) + elif len(numbers) == 2: + height = int(numbers[0]) + width = int(numbers[1]) + else: + raise ValueError('At most two elements for image size.') + else: + raise ValueError(f'Invalid type of input: `{type(obj)}`!') + + return (max(0, height), max(0, width)) + + +def get_grid_shape(size, height=0, width=0, is_portrait=False): + """Gets the shape of a grid based on the size. + + This function makes greatest effort on making the output grid square if + neither `height` nor `width` is set. If `is_portrait` is set as `False`, the + height will always be equal to or smaller than the width. For example, if + input `size = 16`, output shape will be `(4, 4)`; if input `size = 15`, + output shape will be (3, 5). Otherwise, the height will always be equal to + or larger than the width. + + Args: + size: Size (height * width) of the target grid. + height: Expected height. If `size % height != 0`, this field will be + ignored. (default: 0) + width: Expected width. If `size % width != 0`, this field will be + ignored. (default: 0) + is_portrait: Whether to return a portrait size of a landscape size. + (default: False) + + Returns: + A two-element tuple, representing height and width respectively. + """ + assert isinstance(size, int) + assert isinstance(height, int) + assert isinstance(width, int) + if size <= 0: + return (0, 0) + + if height > 0 and width > 0 and height * width != size: + height = 0 + width = 0 + + if height > 0 and width > 0 and height * width == size: + return (height, width) + if height > 0 and size % height == 0: + return (height, size // height) + if width > 0 and size % width == 0: + return (size // width, width) + + height = int(np.sqrt(size)) + while height > 0: + if size % height == 0: + width = size // height + break + height = height - 1 + + return (width, height) if is_portrait else (height, width) + + +def list_images_from_dir(directory): + """Lists all images from the given directory. + + NOTE: Do NOT support finding images recursively. + + Args: + directory: The directory to find images from. + + Returns: + A list of sorted filenames, with the directory as prefix. + """ + image_list = [] + for filename in os.listdir(directory): + if check_file_ext(filename, *IMAGE_EXTENSIONS): + image_list.append(os.path.join(directory, filename)) + return sorted(image_list) diff --git a/utils/loggers/__init__.py b/utils/loggers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..665fd01dc34ae7a520dadfe4581c97e59dd6affe --- /dev/null +++ b/utils/loggers/__init__.py @@ -0,0 +1,32 @@ +# python3.7 +"""Collects all loggers.""" + +from .normal_logger import NormalLogger +from .rich_logger import RichLogger +from .dummy_logger import DummyLogger + +__all__ = ['build_logger'] + +_LOGGERS = { + 'normal': NormalLogger, + 'rich': RichLogger, + 'dummy': DummyLogger +} + + +def build_logger(logger_type='normal', **kwargs): + """Builds a logger. + + Args: + logger_type: Type of logger, which is case insensitive. + (default: `normal`) + **kwargs: Additional arguments to build the logger. 
+ + Raises: + ValueError: If the `logger_type` is not supported. + """ + logger_type = logger_type.lower() + if logger_type not in _LOGGERS: + raise ValueError(f'Invalid logger type: `{logger_type}`!\n' + f'Types allowed: {list(_LOGGERS)}.') + return _LOGGERS[logger_type](**kwargs) diff --git a/utils/loggers/__pycache__/__init__.cpython-37.pyc b/utils/loggers/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8de9b055ec6cbb7b1ce42998929c7536565cf154 Binary files /dev/null and b/utils/loggers/__pycache__/__init__.cpython-37.pyc differ diff --git a/utils/loggers/__pycache__/base_logger.cpython-37.pyc b/utils/loggers/__pycache__/base_logger.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cf478f2f13ccd3e59cf85379cf7b39616c0ca9f4 Binary files /dev/null and b/utils/loggers/__pycache__/base_logger.cpython-37.pyc differ diff --git a/utils/loggers/__pycache__/dummy_logger.cpython-37.pyc b/utils/loggers/__pycache__/dummy_logger.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a2c4d74bc5b08b84bc4bebcbdfcf3bd874ff6daf Binary files /dev/null and b/utils/loggers/__pycache__/dummy_logger.cpython-37.pyc differ diff --git a/utils/loggers/__pycache__/normal_logger.cpython-37.pyc b/utils/loggers/__pycache__/normal_logger.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0d7c0908ce936e924aefbbbbf249f3bd3b38e2e8 Binary files /dev/null and b/utils/loggers/__pycache__/normal_logger.cpython-37.pyc differ diff --git a/utils/loggers/__pycache__/rich_logger.cpython-37.pyc b/utils/loggers/__pycache__/rich_logger.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7630ff830a16563912da66a7faeec2bc5d414b07 Binary files /dev/null and b/utils/loggers/__pycache__/rich_logger.cpython-37.pyc differ diff --git a/utils/loggers/base_logger.py b/utils/loggers/base_logger.py new file mode 100644 index 0000000000000000000000000000000000000000..c08fa7fec115a54fd6fff578af5cf0229e5395e8 --- /dev/null +++ b/utils/loggers/base_logger.py @@ -0,0 +1,258 @@ +# python3.7 +"""Contains the base class for logging. + +Basically, this is an interface bridging the program and the local file system. +A logger is able to log wrapped message onto the screen and a log file. +""" + +import logging + +__all__ = ['BaseLogger'] + + +class BaseLogger(object): + """Defines the base logger. + + A logger should have the following members: + + (1) logger: The logger to record message. + (2) pbar: The progressive bar (shown on the screen only). + (3) pbar_kwargs: The arguments for the progressive bar. + (4) file_stream: The stream to log messages into if needed. + + A logger should have the following functions: + + (1) log(): The base function to log message. + (2) debug(): The function to log message with `DEBUG` level. + (3) info(): The function to log message with `INFO` level. + (4) warning(): The function to log message with `WARNING` level. + (5) warn(): Same as function `warning()`. + (6) error(): The function to log message with `ERROR` level. + (7) exception(): The function to log message with exception information. + (8) critical(): The function to log message with `CRITICAL` level. + (9) fatal(): Same as function `critical()`. + (10) print(): The function to print the message without any decoration. + (11) init_pbar(): The function to initialize the progressive bar. + (12) add_pbar_task(): The function to add a task to the progressive bar. 
+ (13) update_pbar(): The function to update the progressive bar. + (14) close_pbar(): The function to close the progressive bar. + + The logger will record log message both on screen and to file. + + Args: + logger_name: Unique name for the logger. (default: `logger`) + logfile: Path to the log file. If set as `None`, the file stream + will be skipped. (default: `None`) + screen_level: Minimum level of message to log onto screen. + (default: `logging.INFO`) + file_level: Minimum level of message to log into file. + (default: `logging.DEBUG`) + indent_space: Number of spaces between two adjacent indent levels. + (default: 4) + verbose_log: Whether to log verbose message. (default: False) + """ + + def __init__(self, + logger_name='logger', + logfile=None, + screen_level=logging.INFO, + file_level=logging.DEBUG, + indent_space=4, + verbose_log=False): + self.logger_name = logger_name + self.logfile = logfile + self.screen_level = screen_level + self.file_level = file_level + self.indent_space = indent_space + self.verbose_log = verbose_log + + self.logger = None + self.pbar = None + self.pbar_kwargs = None + self.file_stream = None + + self.warn = self.warning + self.fatal = self.critical + + def __del__(self): + self.close() + + def close(self): + """Closes the logger.""" + if self.file_stream is not None: + self.file_stream.close() + + @property + def name(self): + """Returns the class name of the logger.""" + return self.__class__.__name__ + + # Log message. + def wrap_message(self, message, indent_level=0): + """Wraps the message with indent.""" + if message is None: + message = '' + assert isinstance(message, str) + assert isinstance(indent_level, int) and indent_level >= 0 + if message == '': + return '' + return ' ' * (indent_level * self.indent_space) + message + + def _log(self, message, **kwargs): + """Logs wrapped message.""" + raise NotImplementedError('Should be implemented in derived class!') + + def _debug(self, message, **kwargs): + """Logs wrapped message with `DEBUG` level.""" + raise NotImplementedError('Should be implemented in derived class!') + + def _info(self, message, **kwargs): + """Logs wrapped message with `INFO` level.""" + raise NotImplementedError('Should be implemented in derived class!') + + def _warning(self, message, **kwargs): + """Logs wrapped message with `WARNING` level.""" + raise NotImplementedError('Should be implemented in derived class!') + + def _error(self, message, **kwargs): + """Logs wrapped message with `ERROR` level.""" + raise NotImplementedError('Should be implemented in derived class!') + + def _exception(self, message, **kwargs): + """Logs wrapped message with exception information.""" + raise NotImplementedError('Should be implemented in derived class!') + + def _critical(self, message, **kwargs): + """Logs wrapped message with `CRITICAL` level.""" + raise NotImplementedError('Should be implemented in derived class!') + + def _print(self, *messages, **kwargs): + """Prints wrapped message without any decoration.""" + raise NotImplementedError('Should be implemented in derived class!') + + def log(self, message, indent_level=0, is_verbose=False, **kwargs): + """Logs message. + + The message is wrapped with indent, and will be disabled if `is_verbose` + is set as `True`. + """ + if is_verbose and not self.verbose_log: + return + message = self.wrap_message(message, indent_level=indent_level) + self._log(message, **kwargs) + + def debug(self, message, indent_level=0, is_verbose=False, **kwargs): + """Logs message with `DEBUG` level. 
+ + The message is wrapped with indent, and will be disabled if `is_verbose` + is set as `True`. + """ + if is_verbose and not self.verbose_log: + return + message = self.wrap_message(message, indent_level=indent_level) + self._debug(message, **kwargs) + + def info(self, message, indent_level=0, is_verbose=False, **kwargs): + """Logs message with `INFO` level. + + The message is wrapped with indent, and will be disabled if `is_verbose` + is set as `True`. + """ + if is_verbose and not self.verbose_log: + return + message = self.wrap_message(message, indent_level=indent_level) + self._info(message, **kwargs) + + def warning(self, message, indent_level=0, is_verbose=False, **kwargs): + """Logs message with `WARNING` level. + + The message is wrapped with indent, and will be disabled if `is_verbose` + is set as `True`. + """ + if is_verbose and not self.verbose_log: + return + message = self.wrap_message(message, indent_level=indent_level) + self._warning(message, **kwargs) + + def error(self, message, indent_level=0, is_verbose=False, **kwargs): + """Logs message with `ERROR` level. + + The message is wrapped with indent, and will be disabled if `is_verbose` + is set as `True`. + """ + if is_verbose and not self.verbose_log: + return + message = self.wrap_message(message, indent_level=indent_level) + self._error(message, **kwargs) + + def exception(self, message, indent_level=0, is_verbose=False, **kwargs): + """Logs message with exception information. + + The message is wrapped with indent, and will be disabled if `is_verbose` + is set as `True`. + """ + if is_verbose and not self.verbose_log: + return + message = self.wrap_message(message, indent_level=indent_level) + self._exception(message, **kwargs) + + def critical(self, message, indent_level=0, is_verbose=False, **kwargs): + """Logs message with `CRITICAL` level. + + The message is wrapped with indent, and will be disabled if `is_verbose` + is set as `True`. + """ + if is_verbose and not self.verbose_log: + return + message = self.wrap_message(message, indent_level=indent_level) + self._critical(message, **kwargs) + + def print(self, *messages, indent_level=0, is_verbose=False, **kwargs): + """Prints message without any decoration. + + The message is wrapped with indent, and will be disabled if `is_verbose` + is set as `True`. + """ + if is_verbose and not self.verbose_log: + return + new_messages = [] + for message in messages: + new_messages.append( + self.wrap_message(message, indent_level=indent_level)) + self._print(*new_messages, **kwargs) + + # Progressive bar. + def init_pbar(self, leave=False): + """Initializes the progressive bar. + + Args: + leave: Whether to leave the trace of the progressive bar. + (default: False) + """ + raise NotImplementedError('Should be implemented in derived class!') + + def add_pbar_task(self, name, total, **kwargs): + """Adds a task to the progressive bar. + + Args: + name: Name of the added task. + total: Total number of steps (samples) contained in the task. + **kwargs: Additional arguments. + + Returns: + Task ID. + """ + raise NotImplementedError('Should be implemented in derived class!') + + def update_pbar(self, task_id, advance=1): + """Updates the progressive bar. + + Args: + task_id: ID of the task to update. + advance: Number of steps advanced onto the target task. 
(default: 1) + """ + raise NotImplementedError('Should be implemented in derived class!') + + def close_pbar(self): + """Closes the progress bar.""" + raise NotImplementedError('Should be implemented in derived class!') diff --git a/utils/loggers/dummy_logger.py b/utils/loggers/dummy_logger.py new file mode 100644 index 0000000000000000000000000000000000000000..fb6220e6757c6ce4516834f5102cd0957f8669df --- /dev/null +++ b/utils/loggers/dummy_logger.py @@ -0,0 +1,65 @@ +# python3.7 +"""Contains the class of dummy logger. + +This logger has all expected logging functions but behaves silently, which is +very useful in multi-processing mode. Only the chief process can have the logger +with normal behavior. +""" + +from .base_logger import BaseLogger + +__all__ = ['DummyLogger'] + + +class DummyLogger(BaseLogger): + """Implements a dummy logger which logs nothing.""" + + def __init__(self, + logger_name='logger', + logfile=None, + screen_level=None, + file_level=None, + indent_space=4, + verbose_log=False): + super().__init__(logger_name=logger_name, + logfile=logfile, + screen_level=screen_level, + file_level=file_level, + indent_space=indent_space, + verbose_log=verbose_log) + + def _log(self, message, **kwargs): + return + + def _debug(self, message, **kwargs): + return + + def _info(self, message, **kwargs): + return + + def _warning(self, message, **kwargs): + return + + def _error(self, message, **kwargs): + return + + def _exception(self, message, **kwargs): + return + + def _critical(self, message, **kwargs): + return + + def _print(self, *messages, **kwargs): + return + + def init_pbar(self, leave=False): + return + + def add_pbar_task(self, name, total, **kwargs): + return -1 + + def update_pbar(self, task_id, advance=1): + return + + def close_pbar(self): + return diff --git a/utils/loggers/normal_logger.py b/utils/loggers/normal_logger.py new file mode 100644 index 0000000000000000000000000000000000000000..57f66d21afce89d6f4dcb1bfd7fd7ddc3b441e31 --- /dev/null +++ b/utils/loggers/normal_logger.py @@ -0,0 +1,124 @@ +# python3.7 +"""Contains the class of normal logger. + +This class is built based on the built-in function `print()`, the module +`logging` and the module `tqdm` for progressive bar. +""" + +import sys +import logging +from copy import deepcopy +from tqdm import tqdm + +from .base_logger import BaseLogger + +__all__ = ['NormalLogger'] + + +class NormalLogger(BaseLogger): + """Implements the logger based on `logging` module.""" + + def __init__(self, + logger_name='logger', + logfile=None, + screen_level=logging.INFO, + file_level=logging.DEBUG, + indent_space=4, + verbose_log=False): + super().__init__(logger_name=logger_name, + logfile=logfile, + screen_level=screen_level, + file_level=file_level, + indent_space=indent_space, + verbose_log=verbose_log) + + # Get logger and check whether the logger has already been created. + self.logger = logging.getLogger(self.logger_name) + self.logger.propagate = False + if self.logger.hasHandlers(): # Already existed + raise SystemExit(f'Logger `{self.logger_name}` has already ' + f'existed!\n' + f'Please use another name, or otherwise the ' + f'messages may be mixed up.') + + # Set format. + self.logger.setLevel(logging.DEBUG) + formatter = logging.Formatter( + '[%(asctime)s][%(levelname)s] %(message)s', + datefmt='%Y-%m-%d %H:%M:%S') + + # Print log message onto the screen. 
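+ # The terminal handler below filters at `screen_level`, while the file + # handler added afterwards (only when `logfile` is given) filters at + # `file_level`, so the two sinks can use different verbosities.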
+ terminal_handler = logging.StreamHandler(stream=sys.stdout) + terminal_handler.setLevel(self.screen_level) + terminal_handler.setFormatter(formatter) + self.logger.addHandler(terminal_handler) + + # Save log message into log file if needed. + if self.logfile: + # File will be closed when the logger is closed in `self.close()`. + self.file_stream = open(self.logfile, 'a') # pylint: disable=consider-using-with + file_handler = logging.StreamHandler(stream=self.file_stream) + file_handler.setLevel(self.file_level) + file_handler.setFormatter(formatter) + self.logger.addHandler(file_handler) + + self.pbar = [] + self.pbar_kwargs = {} + + def _log(self, message, **kwargs): + self.logger.log(message, **kwargs) + + def _debug(self, message, **kwargs): + self.logger.debug(message, **kwargs) + + def _info(self, message, **kwargs): + self.logger.info(message, **kwargs) + + def _warning(self, message, **kwargs): + self.logger.warning(message, **kwargs) + + def _error(self, message, **kwargs): + self.logger.error(message, **kwargs) + + def _exception(self, message, **kwargs): + self.logger.exception(message, **kwargs) + + def _critical(self, message, **kwargs): + self.logger.critical(message, **kwargs) + + def _print(self, *messages, **kwargs): + for handler in self.logger.handlers: + print(*messages, file=handler.stream) + + def init_pbar(self, leave=False): + columns = [ + '{desc}', + '{bar}', + ' {percentage:5.1f}%', + '[{elapsed}<{remaining}, {rate_fmt}{postfix}]', + ] + self.pbar_kwargs = dict( + leave=leave, + bar_format=' '.join(columns), + unit='', + ) + + def add_pbar_task(self, name, total, **kwargs): + assert isinstance(self.pbar_kwargs, dict) + pbar_kwargs = deepcopy(self.pbar_kwargs) + pbar_kwargs.update(**kwargs) + self.pbar.append(tqdm(desc=name, total=total, **pbar_kwargs)) + return len(self.pbar) - 1 + + def update_pbar(self, task_id, advance=1): + assert len(self.pbar) > task_id and isinstance(self.pbar[task_id], tqdm) + if self.pbar[task_id].n < self.pbar[task_id].total: + self.pbar[task_id].update(advance) + if self.pbar[task_id].n >= self.pbar[task_id].total: + self.pbar[task_id].refresh() + + def close_pbar(self): + for pbar in self.pbar[::-1]: + pbar.close() + self.pbar = [] + self.pbar_kwargs = {} diff --git a/utils/loggers/rich_logger.py b/utils/loggers/rich_logger.py new file mode 100644 index 0000000000000000000000000000000000000000..38c09d2b0c898aff394415a639742a623b07178a --- /dev/null +++ b/utils/loggers/rich_logger.py @@ -0,0 +1,177 @@ +# python3.7 +"""Contains the class of rich logger. + +This class is based on the module `rich`. Please refer to +https://github.com/Textualize/rich for more details. +""" + +import sys +import logging +from copy import deepcopy +from rich.console import Console +from rich.logging import RichHandler +from rich.progress import Progress +from rich.progress import ProgressColumn +from rich.progress import TextColumn +from rich.progress import BarColumn +from rich.text import Text + +from .base_logger import BaseLogger + +__all__ = ['RichLogger'] + + +def _format_time(seconds): + """Formats seconds to readable time string. + + This function is used to display time in progress bar. 
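+ + Examples (illustrative, added for clarity): `_format_time(3725)` gives `'1:02:05'`, + while `_format_time(0)` or `_format_time(None)` gives `'--:--'`.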
+ """ + if not seconds: + return '--:--' + + seconds = int(seconds) + hours, seconds = divmod(seconds, 3600) + minutes, seconds = divmod(seconds, 60) + if hours: + return f'{hours}:{minutes:02d}:{seconds:02d}' + return f'{minutes:02d}:{seconds:02d}' + + +class TimeColumn(ProgressColumn): + """Renders total time, ETA, and speed in progress bar.""" + + max_refresh = 0.5 # Only refresh twice a second to prevent jitter + + def render(self, task): + elapsed_time = _format_time(task.elapsed) + eta = _format_time(task.time_remaining) + speed = f'{task.speed:.2f}/s' if task.speed else '?/s' + return Text(f'[{elapsed_time}<{eta}, {speed}]', + style='progress.remaining') + + +class RichLogger(BaseLogger): + """Implements the logger based on `rich` module.""" + + def __init__(self, + logger_name='logger', + logfile=None, + screen_level=logging.INFO, + file_level=logging.DEBUG, + indent_space=4, + verbose_log=False): + super().__init__(logger_name=logger_name, + logfile=logfile, + screen_level=screen_level, + file_level=file_level, + indent_space=indent_space, + verbose_log=verbose_log) + + # Get logger and check whether the logger has already been created. + self.logger = logging.getLogger(self.logger_name) + self.logger.propagate = False + if self.logger.hasHandlers(): # Already existed + raise SystemExit(f'Logger `{self.logger_name}` has already ' + f'existed!\n' + f'Please use another name, or otherwise the ' + f'messages may be mixed up.') + + # Set format. + self.logger.setLevel(logging.DEBUG) + + # Print log message onto the screen. + terminal_console = Console( + file=sys.stdout, log_time=False, log_path=False) + terminal_handler = RichHandler( + level=self.screen_level, + console=terminal_console, + show_time=True, + show_level=True, + show_path=False, + log_time_format='[%Y-%m-%d %H:%M:%S] ') + terminal_handler.setFormatter(logging.Formatter('%(message)s')) + self.logger.addHandler(terminal_handler) + + # Save log message into log file if needed. + if self.logfile: + # File will be closed when the logger is closed in `self.close()`. + self.file_stream = open(self.logfile, 'a') # pylint: disable=consider-using-with + file_console = Console( + file=self.file_stream, log_time=False, log_path=False) + file_handler = RichHandler( + level=self.file_level, + console=file_console, + show_time=True, + show_level=True, + show_path=False, + log_time_format='[%Y-%m-%d %H:%M:%S] ') + file_handler.setFormatter(logging.Formatter('%(message)s')) + self.logger.addHandler(file_handler) + + self.pbar = None + self.pbar_kwargs = {} + + def _log(self, message, **kwargs): + self.logger.log(message, **kwargs) + + def _debug(self, message, **kwargs): + self.logger.debug(message, **kwargs) + + def _info(self, message, **kwargs): + self.logger.info(message, **kwargs) + + def _warning(self, message, **kwargs): + self.logger.warning(message, **kwargs) + + def _error(self, message, **kwargs): + self.logger.error(message, **kwargs) + + def _exception(self, message, **kwargs): + self.logger.exception(message, **kwargs) + + def _critical(self, message, **kwargs): + self.logger.critical(message, **kwargs) + + def _print(self, *messages, **kwargs): + for handler in self.logger.handlers: + handler.console.print(*messages, **kwargs) + + def init_pbar(self, leave=False): + assert self.pbar is None + + # Columns shown in the progress bar. 
+ columns = ( + TextColumn('[progress.description]{task.description}'), + BarColumn(bar_width=None), + TextColumn('[progress.percentage]{task.percentage:>5.1f}%'), + TimeColumn(), + ) + + self.pbar = Progress(*columns, + console=self.logger.handlers[0].console, + transient=not leave, + auto_refresh=True, + refresh_per_second=10) + self.pbar.start() + + def add_pbar_task(self, name, total, **kwargs): + assert isinstance(self.pbar, Progress) + assert isinstance(self.pbar_kwargs, dict) + pbar_kwargs = deepcopy(self.pbar_kwargs) + pbar_kwargs.update(**kwargs) + task_id = self.pbar.add_task(name, total=total, **pbar_kwargs) + return task_id + + def update_pbar(self, task_id, advance=1): + assert isinstance(self.pbar, Progress) + if self.pbar.tasks[task_id].finished: + if self.pbar.tasks[task_id].stop_time is None: + self.pbar.stop_task(task_id) + else: + self.pbar.update(task_id, advance=advance) + + def close_pbar(self): + assert isinstance(self.pbar, Progress) + self.pbar.stop() + self.pbar = None + self.pbar_kwargs = {} diff --git a/utils/loggers/test.py b/utils/loggers/test.py new file mode 100644 index 0000000000000000000000000000000000000000..096f7fd9b32458ac88b551f7acaf676cdc13be4f --- /dev/null +++ b/utils/loggers/test.py @@ -0,0 +1,63 @@ +# python3.7 +"""Unit test for logger.""" + +import os +import time + +from . import build_logger + +__all__ = ['test_logger'] + +_TEST_DIR = 'logger_test' + + +def test_logger(test_dir=_TEST_DIR): + """Tests loggers.""" + print('========== Start Logger Test ==========') + + os.makedirs(test_dir, exist_ok=True) + + for logger_type in ['normal', 'rich', 'dummy']: + for indent_space in [2, 4]: + for verbose_log in [False, True]: + if logger_type == 'normal': + class_name = 'Logger' + elif logger_type == 'rich': + class_name = 'RichLogger' + elif logger_type == 'dummy': + class_name = 'DummyLogger' + + print(f'===== ' + f'Testing `utils.logger.{class_name}` ' + f' (indent: {indent_space}, verbose: {verbose_log}) ' + f'=====') + logger_name = (f'{logger_type}_logger_' + f'indent_{indent_space}_' + f'verbose_{verbose_log}') + logger = build_logger( + logger_type, + logger_name=logger_name, + logfile=os.path.join(test_dir, f'test_{logger_name}.log'), + verbose_log=verbose_log, + indent_space=indent_space) + logger.print('print log') + logger.print('print log,', 'log 2') + logger.print('print log (indent level 0)', indent_level=0) + logger.print('print log (indent level 1)', indent_level=1) + logger.print('print log (indent level 2)', indent_level=2) + logger.print('print log (verbose `False`)', is_verbose=False) + logger.print('print log (verbose `True`)', is_verbose=True) + logger.debug('debug log') + logger.info('info log') + logger.warning('warning log') + logger.init_pbar() + task_1 = logger.add_pbar_task('Task 1', 500) + task_2 = logger.add_pbar_task('Task 2', 1000) + for _ in range(1000): + logger.update_pbar(task_1, 1) + logger.update_pbar(task_2, 1) + time.sleep(0.002) + logger.close_pbar() + print('Success!') + + print('========== Finish Logger Test ==========') diff --git a/utils/misc.py b/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..a34ba8344c44a7d86de3674387cd0ab167a515ba --- /dev/null +++ b/utils/misc.py @@ -0,0 +1,234 @@ +# python3.7 +"""Misc utility functions.""" + +import os +import hashlib + +from torch.hub import download_url_to_file + +__all__ = [ + 'REPO_NAME', 'Infix', 'print_and_execute', 'check_file_ext', + 'IMAGE_EXTENSIONS', 'VIDEO_EXTENSIONS', 'MEDIA_EXTENSIONS', + 'parse_file_format', 
'set_cache_dir', 'get_cache_dir', 'download_url' +] + +REPO_NAME = 'Hammer' # Name of the repository (project). + + +class Infix(object): + """Helper class to create custom infix operators. + + When using it, make sure to put the operator between `<<` and `>>`. + `<< INFIX_OP_NAME >>` should be considered as a whole operator. + + Examples: + + # Use `Infix` to create infix operators directly. + add = Infix(lambda a, b: a + b) + 1 << add >> 2 # gives 3 + 1 << add >> 2 << add >> 3 # gives 6 + + # Use `Infix` as a decorator. + @Infix + def mul(a, b): + return a * b + 2 << mul >> 4 # gives 8 + 2 << mul >> 3 << mul >> 7 # gives 42 + """ + + def __init__(self, function): + self.function = function + self.left_value = None + + def __rlshift__(self, left_value): # override `<<` before `Infix` instance + assert self.left_value is None # make sure left is only called once + self.left_value = left_value + return self + + def __rshift__(self, right_value): # override `>>` after `Infix` instance + result = self.function(self.left_value, right_value) + self.left_value = None # reset to None + return result + + +def print_and_execute(cmd): + """Prints and executes a system command. + + Args: + cmd: Command to be executed. + """ + print(cmd) + os.system(cmd) + + +def check_file_ext(filename, *ext_list): + """Checks whether the given filename is with target extension(s). + + NOTE: If `ext_list` is empty, this function will always return `False`. + + Args: + filename: Filename to check. + *ext_list: A list of extensions. + + Returns: + `True` if the filename is with one of extensions in `ext_list`, + otherwise `False`. + """ + if len(ext_list) == 0: + return False + ext_list = [ext if ext.startswith('.') else '.' + ext for ext in ext_list] + ext_list = [ext.lower() for ext in ext_list] + basename = os.path.basename(filename) + ext = os.path.splitext(basename)[1].lower() + return ext in ext_list + + +# File extensions regarding images (not including GIFs). +IMAGE_EXTENSIONS = ( + '.bmp', '.ppm', '.pgm', '.jpeg', '.jpg', '.jpe', '.jp2', '.png', '.webp', + '.tiff', '.tif' +) +# File extensions regarding videos. +VIDEO_EXTENSIONS = ( + '.avi', '.mkv', '.mp4', '.m4v', '.mov', '.webm', '.flv', '.rmvb', '.rm', + '.3gp' +) +# File extensions regarding media, i.e., images, videos, GIFs. +MEDIA_EXTENSIONS = ('.gif', *IMAGE_EXTENSIONS, *VIDEO_EXTENSIONS) + + +def parse_file_format(path): + """Parses the file format of a given path. + + This function basically parses the file format according to its extension. + It will also return `dir` is the given path is a directory. + + Parable file formats: + + - zip: with `.zip` extension. + - tar: with `.tar` / `.tgz` / `.tar.gz` extension. + - lmdb: a folder ending with `lmdb`. + - txt: with `.txt` / `.text` extension, OR without extension (e.g. LICENSE). + - json: with `.json` extension. + - jpg: with `.jpeg` / `jpg` / `jpe` extension. + - png: with `.png` extension. + + Args: + path: The path to the file to parse format from. + + Returns: + A lower-case string, indicating the file format, or `None` if the format + cannot be successfully parsed. + """ + # Handle directory. + if os.path.isdir(path) or path.endswith('/'): + if path.rstrip('/').lower().endswith('lmdb'): + return 'lmdb' + return 'dir' + # Handle file. + if os.path.isfile(path) and os.path.splitext(path)[1] == '': + return 'txt' + path = path.lower() + if path.endswith('.tar.gz'): # Cannot parse accurate extension. 
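+        # `os.path.splitext()` would only report `.gz` here, so the double
+        # extension is special-cased before the generic checks below.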
+ return 'tar' + ext = os.path.splitext(path)[1] + if ext == '.zip': + return 'zip' + if ext in ['.tar', '.tgz']: + return 'tar' + if ext in ['.txt', '.text']: + return 'txt' + if ext == '.json': + return 'json' + if ext in ['.jpeg', '.jpg', '.jpe']: + return 'jpg' + if ext == '.png': + return 'png' + # Unparsable. + return None + + +_cache_dir = None + + +def set_cache_dir(directory=None): + """Sets the global cache directory. + + The cache directory can be used to save some files that will be shared + across jobs. The default cache directory is set as `~/.cache/`. This + function can be used to redirect the cache directory. Or, users can use + `None` to reset the cache directory back to default. + + Args: + directory: The target directory used to cache files. If set as `None`, + the cache directory will be reset back to default. (default: None) + """ + assert directory is None or isinstance(directory, str), 'Invalid directory!' + global _cache_dir # pylint: disable=global-statement + _cache_dir = directory + + +def get_cache_dir(use_repo_name=True): + """Gets the global cache directory. + + The global cache directory is primarily set as `~/.cache/` by default, and + can be redirected with `set_cache_dir()`. + + Args: + use_repo_name: Whether to create a folder, named `REPO_NAME`, under + `_cache_dir` as the actual cache directory. (default: True) + + Returns: + A string, representing the global cache directory. + """ + if _cache_dir is None: + cache_dir = os.path.join(os.path.expanduser('~'), '.cache') + else: + cache_dir = _cache_dir + if use_repo_name: + return os.path.join(cache_dir, REPO_NAME) + return cache_dir + + +def download_url(url, path=None, filename=None, sha256=None): + """Downloads file from URL. + + This function downloads a file from given URL, and executes Hash check if + needed. + + Args: + url: The URL to download file from. + path: Path (directory) to save the downloaded file. If set as `None`, + the cache directory will be used. Please see `get_cache_dir()` for + more details. (default: None) + filename: The name to save the file. If set as `None`, this name will be + automatically parsed from the given URL. (default: None) + sha256: The expected sha256 of the downloaded file. If set as `None`, + the hash check will be skipped. Otherwise, this function will check + whether the sha256 of the downloaded file matches this field. + + Returns: + A two-element tuple, where the first term is the full path of the + downloaded file, and the second term indicate the hash check result. + `True` means hash check passes, `False` means hash check fails, + while `None` means no hash check is executed. + """ + # Handle file path. + if path is None: + path = get_cache_dir() + if filename is None: + filename = os.path.basename(url) + save_path = os.path.join(path, filename) + # Download file if needed. + if not os.path.exists(save_path): + print(f'Downloading URL `{url}` to path `{save_path}` ...') + os.makedirs(path, exist_ok=True) + download_url_to_file(url, save_path, hash_prefix=None, progress=True) + # Check hash if needed. 
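+    # The downloaded file is read in full and its sha256 digest is compared
+    # against the expected value; `check_result` stays `None` when no hash
+    # is provided.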
+ check_result = None + if sha256 is not None: + with open(save_path, 'rb') as f: + file_hash = hashlib.sha256(f.read()) + check_result = (file_hash.hexdigest() == sha256) + + return save_path, check_result diff --git a/utils/parsing_utils.py b/utils/parsing_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a41749861ad68f36cff7ba33053aa416d27bb07c --- /dev/null +++ b/utils/parsing_utils.py @@ -0,0 +1,214 @@ +# python3.7 +"""Contains the utility functions for parsing arguments.""" + +import json +import argparse +import click + +__all__ = [ + 'parse_int', 'parse_float', 'parse_bool', 'parse_index', 'parse_json', + 'IntegerParamType', 'FloatParamType', 'BooleanParamType', 'IndexParamType', + 'JsonParamType', 'DictAction' +] + + +def parse_int(arg): + """Parses an argument to integer. + + Support converting string `none` and `null` to `None`. + """ + if arg is None: + return None + if isinstance(arg, str) and arg.lower() in ['none', 'null']: + return None + return int(arg) + + +def parse_float(arg): + """Parses an argument to float number. + + Support converting string `none` and `null` to `None`. + """ + if arg is None: + return None + if isinstance(arg, str) and arg.lower() in ['none', 'null']: + return None + return float(arg) + + +def parse_bool(arg): + """Parses an argument to boolean. + + `None` will be converted to `False`. + """ + if isinstance(arg, bool): + return arg + if arg is None: + return False + if arg.lower() in ['1', 'true', 't', 'yes', 'y']: + return True + if arg.lower() in ['0', 'false', 'f', 'no', 'n', 'none', 'null']: + return False + raise ValueError(f'`{arg}` cannot be converted to boolean!') + + +def parse_index(arg, min_val=None, max_val=None): + """Parses indices. + + If the input is a list or tuple, this function has no effect. + + If the input is a string, it can be either a comma separated list of numbers + `1, 3, 5`, or a dash separated range `3 - 10`. Spaces in the string will be + ignored. + + Args: + arg: The input argument to parse indices from. + min_val: If not `None`, this function will check that all indices are + equal to or larger than this value. (default: None) + max_val: If not `None`, this function will check that all indices are + equal to or smaller than this field. (default: None) + + Returns: + A list of integers. + + Raises: + ValueError: If the input is invalid, i.e., neither a list or tuple, nor + a string. + """ + if arg is None or arg == '': + indices = [] + elif isinstance(arg, int): + indices = [arg] + elif isinstance(arg, (list, tuple)): + indices = list(arg) + elif isinstance(arg, str): + indices = [] + if arg.lower() not in ['none', 'null']: + splits = arg.replace(' ', '').split(',') + for split in splits: + numbers = list(map(int, split.split('-'))) + if len(numbers) == 1: + indices.append(numbers[0]) + elif len(numbers) == 2: + indices.extend(list(range(numbers[0], numbers[1] + 1))) + else: + raise ValueError(f'Invalid type of input: `{type(arg)}`!') + + assert isinstance(indices, list) + indices = sorted(list(set(indices))) + for idx in indices: + assert isinstance(idx, int) + if min_val is not None: + assert idx >= min_val, f'{idx} is smaller than min val `{min_val}`!' + if max_val is not None: + assert idx <= max_val, f'{idx} is larger than max val `{max_val}`!' + + return indices + + +def parse_json(arg): + """Parses a string-like argument following JSON format. + + - `None` arguments will be kept. + - Non-string arguments will be kept. 
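+    - For example, the string `'{"lr": 0.01}'` is parsed to `{'lr': 0.01}`,
+      while a plain string such as `'abc'` (invalid JSON) is returned as-is.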
+ """ + if not isinstance(arg, str): + return arg + try: + return json.loads(arg) + except json.decoder.JSONDecodeError: + return arg + + +class IntegerParamType(click.ParamType): + """Defines a `click.ParamType` to parse integer arguments.""" + + name = 'int' + + def convert(self, value, param, ctx): # pylint: disable=inconsistent-return-statements + try: + return parse_int(value) + except ValueError: + self.fail(f'`{value}` cannot be parsed as an integer!', param, ctx) + + +class FloatParamType(click.ParamType): + """Defines a `click.ParamType` to parse float arguments.""" + + name = 'float' + + def convert(self, value, param, ctx): # pylint: disable=inconsistent-return-statements + try: + return parse_float(value) + except ValueError: + self.fail(f'`{value}` cannot be parsed as a float!', param, ctx) + + +class BooleanParamType(click.ParamType): + """Defines a `click.ParamType` to parse boolean arguments.""" + + name = 'bool' + + def convert(self, value, param, ctx): # pylint: disable=inconsistent-return-statements + try: + return parse_bool(value) + except ValueError: + self.fail(f'`{value}` cannot be parsed as a boolean!', param, ctx) + + +class IndexParamType(click.ParamType): + """Defines a `click.ParamType` to parse indices arguments.""" + + name = 'index' + + def __init__(self, min_val=None, max_val=None): + self.min_val = min_val + self.max_val = max_val + + def convert(self, value, param, ctx): # pylint: disable=inconsistent-return-statements + try: + return parse_index(value, self.min_val, self.max_val) + except ValueError: + self.fail( + f'`{value}` cannot be parsed as a list of indices!', param, ctx) + + +class JsonParamType(click.ParamType): + """Defines a `click.ParamType` to parse arguments following JSON format.""" + + name = 'json' + + def convert(self, value, param, ctx): + return parse_json(value) + + +class DictAction(argparse.Action): + """Argparse action to split each argument into (key, value) pair. + + Each argument should be with `key=value` format, where `value` should be a + string with JSON format. + + For example, with an argparse: + + parser.add_argument('--options', nargs='+', action=DictAction) + + , you can use following arguments in the command line: + + --options \ + a=1 \ + b=1.5 + c=true \ + d=null \ + e=[1,2,3,4,5] \ + f='{"x":1,"y":2,"z":3}' \ + + NOTE: No space is allowed in each argument. Also, the dictionary-type + argument should be quoted with single quotation marks `'`. + """ + + def __call__(self, parser, namespace, values, option_string=None): + options = {} + for argument in values: + key, val = argument.split('=', maxsplit=1) + options[key] = parse_json(val) + setattr(namespace, self.dest, options) diff --git a/utils/tf_utils.py b/utils/tf_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..80e48ed06e614571d920125d4b64fbfefbf804c0 --- /dev/null +++ b/utils/tf_utils.py @@ -0,0 +1,47 @@ +# python3.7 +"""Contains the utility functions to handle import TensorFlow modules. + +Basically, TensorFlow may not be supported in the current environment, or may +cause some warnings. This file provides functions to help ease TensorFlow +related imports, such as TensorBoard. +""" + +import warnings + +__all__ = ['import_tf', 'import_tb_writer'] + + +def import_tf(): + """Imports TensorFlow module if possible. + + If `ImportError` is raised, `None` will be returned. Otherwise, the module + `tensorflow` will be returned. 
+ """ + warnings.filterwarnings('ignore', category=FutureWarning) + try: + import tensorflow as tf # pylint: disable=import-outside-toplevel + tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) + module = tf + except ImportError: + module = None + warnings.filterwarnings('default', category=FutureWarning) + return module + + +def import_tb_writer(): + """Imports the SummaryWriter of TensorBoard. + + If `ImportError` is raised, `None` will be returned. Otherwise, the class + `SummaryWriter` will be returned. + + NOTE: This function attempts to import `SummaryWriter` from + `torch.utils.tensorboard`. But it does not necessarily mean the import + always succeeds because installing TensorBoard is not a duty of `PyTorch`. + """ + warnings.filterwarnings('ignore', category=FutureWarning) + try: + from torch.utils.tensorboard import SummaryWriter # pylint: disable=import-outside-toplevel + except ImportError: # In case TensorBoard is not supported. + SummaryWriter = None + warnings.filterwarnings('default', category=FutureWarning) + return SummaryWriter diff --git a/utils/visualizers.py b/utils/visualizers.py new file mode 100644 index 0000000000000000000000000000000000000000..eb7cd5d925c96beb386d0652be6d4ecb6ace8a2f --- /dev/null +++ b/utils/visualizers.py @@ -0,0 +1,746 @@ +# python3.7 +"""Utility functions for visualizing results.""" + +import base64 +import os.path +import cv2 +import numpy as np +from bs4 import BeautifulSoup + +__all__ = [ + 'get_grid_shape', 'get_blank_image', 'load_image', 'save_image', + 'resize_image', 'postprocess_image', 'add_text_to_image', + 'parse_image_size', 'fuse_images', 'HtmlPageVisualizer', 'HtmlPageReader', + 'VideoReader', 'VideoWriter' +] + + +def get_grid_shape(size, row=0, col=0, is_portrait=False): + """Gets the shape of a grid based on the size. + + This function makes greatest effort on making the output grid square if + neither `row` nor `col` is set. If `is_portrait` is set as `False`, the + height will always be equal to or smaller than the width. For example, if + input `size = 16`, output shape will be `(4, 4)`; if input `size = 15`, + output shape will be (3, 5). Otherwise, the height will always be equal to + or larger than the width. + + Args: + size: Size (height * width) of the target grid. + is_portrait: Whether to return a portrait size of a landscape size. + (default: False) + + Returns: + A two-element tuple, representing height and width respectively. + """ + assert isinstance(size, int) + assert isinstance(row, int) + assert isinstance(col, int) + if size == 0: + return (0, 0) + + if row > 0 and col > 0 and row * col != size: + row = 0 + col = 0 + + if row > 0 and size % row == 0: + return (row, size // row) + if col > 0 and size % col == 0: + return (size // col, col) + + row = int(np.sqrt(size)) + while row > 0: + if size % row == 0: + col = size // row + break + row = row - 1 + + return (col, row) if is_portrait else (row, col) + + +def get_blank_image(height, width, channels=3, is_black=True): + """Gets a blank image, either white of black. + + NOTE: This function will always return an image with `RGB` channel order for + color image and pixel range [0, 255]. + + Args: + height: Height of the returned image. + width: Width of the returned image. + channels: Number of channels. (default: 3) + is_black: Whether to return a black image. 
(default: True) + """ + shape = (height, width, channels) + if is_black: + return np.zeros(shape, dtype=np.uint8) + return np.ones(shape, dtype=np.uint8) * 255 + + +def load_image(path, image_channels=3): + """Loads an image from disk. + + NOTE: This function will always return an image with `RGB` channel order for + color image and pixel range [0, 255]. + + Args: + path: Path to load the image from. + image_channels: Number of image channels of returned image. This field + is employed since `cv2.imread()` will always return a 3-channel + image, even for grayscale image. + + Returns: + An image with dtype `np.ndarray`, or `None` if `path` does not exist. + """ + if not os.path.isfile(path): + return None + + assert image_channels in [1, 3] + + image = cv2.imread(path) + assert image.ndim == 3 and image.shape[2] == 3 + if image_channels == 1: + return image[:, :, 0:1] + return image[:, :, ::-1] + + +def save_image(path, image): + """Saves an image to disk. + + NOTE: The input image (if colorful) is assumed to be with `RGB` channel + order and pixel range [0, 255]. + + Args: + path: Path to save the image to. + image: Image to save. + """ + if image is None: + return + + assert image.ndim == 3 and image.shape[2] in [1, 3] + cv2.imwrite(path, image[:, :, ::-1]) + + +def resize_image(image, *args, **kwargs): + """Resizes image. + + This is a wrap of `cv2.resize()`. + + NOTE: THe channel order of the input image will not be changed. + + Args: + image: Image to resize. + """ + if image is None: + return None + + assert image.ndim == 3 and image.shape[2] in [1, 3] + image = cv2.resize(image, *args, **kwargs) + if image.ndim == 2: + return image[:, :, np.newaxis] + return image + + +def postprocess_image(image, min_val=-1.0, max_val=1.0, data_format='NCHW'): + """Post-processes image to pixel range [0, 255] with dtype `uint8`. + + NOTE: The returned image will always be with `HWC` format. + + Args: + min_val: Minimum value of the input image. + max_val: Maximum value of the input image. + data_format: Data format of the input image. Supporting `NCHW`, `NHWC`, + `CHW`, `HWC`. + + Returns: + The post-processed image. + + Raises: + NotImplementedError: If the input `data_format` is not support. + """ + assert isinstance(image, np.ndarray) + image = image.astype(np.float64) + image = (image - min_val) * 255 / (max_val - min_val) + image = np.clip(image + 0.5, 0, 255).astype(np.uint8) + data_format = data_format.upper() + if data_format == 'NCHW': + assert image.ndim == 4 and image.shape[1] in [1, 3] + return image.transpose(0, 2, 3, 1) + if data_format == 'NHWC': + assert image.ndim == 4 and image.shape[3] in [1, 3] + return image + if data_format == 'CHW': + assert image.ndim == 3 and image.shape[0] in [1, 3] + return image.transpose(1, 2, 0) + if data_format == 'HWC': + assert image.ndim == 3 and image.shape[2] in [1, 3] + return image + raise NotImplementedError(f'Data format `{data_format}` is not supported!') + + +def add_text_to_image(image, + text='', + position=None, + font=cv2.FONT_HERSHEY_TRIPLEX, + font_size=1.0, + line_type=cv2.LINE_8, + line_width=1, + color=(255, 255, 255)): + """Overlays text on given image. + + NOTE: The input image is assumed to be with `RGB` channel order. + + Args: + image: The image to overlay text on. + text: Text content to overlay on the image. (default: '') + position: Target position (bottom-left corner) to add text. If not set, + center of the image will be used by default. (default: None) + font: Font of the text added. 
(default: cv2.FONT_HERSHEY_TRIPLEX) + font_size: Font size of the text added. (default: 1.0) + line_type: Line type used to depict the text. (default: cv2.LINE_8) + line_width: Line width used to depict the text. (default: 1) + color: Color of the text added in `RGB` channel order. (default: + (255, 255, 255)) + + Returns: + An image with target text overlayed on. + """ + if image is None or not text: + return image + + cv2.putText(img=image, + text=text, + org=position, + fontFace=font, + fontScale=font_size, + color=color, + thickness=line_width, + lineType=line_type, + bottomLeftOrigin=False) + + return image + + +def parse_image_size(obj): + """Parses object to a pair of image size, i.e., (width, height). + + Args: + obj: The input object to parse image size from. + + Returns: + A two-element tuple, indicating image width and height respectively. + + Raises: + If the input is invalid, i.e., neither a list or tuple, nor a string. + """ + if obj is None or obj == '': + width = height = 0 + elif isinstance(obj, int): + width = height = obj + elif isinstance(obj, (list, tuple, np.ndarray)): + numbers = tuple(obj) + if len(numbers) == 0: + width = height = 0 + elif len(numbers) == 1: + width = height = numbers[0] + elif len(numbers) == 2: + width = numbers[0] + height = numbers[1] + else: + raise ValueError(f'At most two elements for image size.') + elif isinstance(obj, str): + splits = obj.replace(' ', '').split(',') + numbers = tuple(map(int, splits)) + if len(numbers) == 0: + width = height = 0 + elif len(numbers) == 1: + width = height = numbers[0] + elif len(numbers) == 2: + width = numbers[0] + height = numbers[1] + else: + raise ValueError(f'At most two elements for image size.') + else: + raise ValueError(f'Invalid type of input: {type(obj)}!') + + return (max(0, width), max(0, height)) + + +def fuse_images(images, + image_size=None, + row=0, + col=0, + is_row_major=True, + is_portrait=False, + row_spacing=0, + col_spacing=0, + border_left=0, + border_right=0, + border_top=0, + border_bottom=0, + black_background=True): + """Fuses a collection of images into an entire image. + + Args: + images: A collection of images to fuse. Should be with shape [num, + height, width, channels]. + image_size: This field is used to resize the image before fusion. `0` + disables resizing. (default: None) + row: Number of rows used for image fusion. If not set, this field will + be automatically assigned based on `col` and total number of images. + (default: None) + col: Number of columns used for image fusion. If not set, this field + will be automatically assigned based on `row` and total number of + images. (default: None) + is_row_major: Whether the input images should be arranged row-major or + column-major. (default: True) + is_portrait: Only active when both `row` and `col` should be assigned + automatically. (default: False) + row_spacing: Space between rows. (default: 0) + col_spacing: Space between columns. (default: 0) + border_left: Width of left border. (default: 0) + border_right: Width of right border. (default: 0) + border_top: Width of top border. (default: 0) + border_bottom: Width of bottom border. (default: 0) + + Returns: + The fused image. + + Raises: + ValueError: If the input `images` is not with shape [num, height, width, + width]. 
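+
+    For example, under the default settings, fusing 16 images of resolution
+    256x256 yields a single 1024x1024 grid (4 rows by 4 columns).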
+ """ + if images is None: + return images + + if images.ndim != 4: + raise ValueError(f'Input `images` should be with shape [num, height, ' + f'width, channels], but {images.shape} is received!') + + num, image_height, image_width, channels = images.shape + width, height = parse_image_size(image_size) + height = height or image_height + width = width or image_width + row, col = get_grid_shape(num, row=row, col=col, is_portrait=is_portrait) + fused_height = ( + height * row + row_spacing * (row - 1) + border_top + border_bottom) + fused_width = ( + width * col + col_spacing * (col - 1) + border_left + border_right) + fused_image = get_blank_image( + fused_height, fused_width, channels=channels, is_black=black_background) + images = images.reshape(row, col, image_height, image_width, channels) + if not is_row_major: + images = images.transpose(1, 0, 2, 3, 4) + + for i in range(row): + y = border_top + i * (height + row_spacing) + for j in range(col): + x = border_left + j * (width + col_spacing) + if height != image_height or width != image_width: + image = cv2.resize(images[i, j], (width, height)) + else: + image = images[i, j] + fused_image[y:y + height, x:x + width] = image + + return fused_image + + +def get_sortable_html_header(column_name_list, sort_by_ascending=False): + """Gets header for sortable html page. + + Basically, the html page contains a sortable table, where user can sort the + rows by a particular column by clicking the column head. + + Example: + + column_name_list = [name_1, name_2, name_3] + header = get_sortable_html_header(column_name_list) + footer = get_sortable_html_footer() + sortable_table = ... + html_page = header + sortable_table + footer + + Args: + column_name_list: List of column header names. + sort_by_ascending: Default sorting order. If set as `True`, the html + page will be sorted by ascending order when the header is clicked + for the first time. + + Returns: + A string, which represents for the header for a sortable html page. + """ + header = '\n'.join([ + '', + '', + '', + '', + '', + '', + '', + '', + '', + '', + '', + '', + '', + '']) + for idx, name in enumerate(column_name_list): + header += f' \n' + header += '\n' + header += '\n' + header += '\n' + + return header + + +def get_sortable_html_footer(): + """Gets footer for sortable html page. + + Check function `get_sortable_html_header()` for more details. + """ + return '\n
{name}
\n\n\n\n' + + +def encode_image_to_html_str(image, image_size=None): + """Encodes an image to html language. + + NOTE: Input image is always assumed to be with `RGB` channel order. + + Args: + image: The input image to encode. Should be with `RGB` channel order. + image_size: This field is used to resize the image before encoding. `0` + disables resizing. (default: None) + + Returns: + A string which represents the encoded image. + """ + if image is None: + return '' + + assert image.ndim == 3 and image.shape[2] in [1, 3] + + # Change channel order to `BGR`, which is opencv-friendly. + image = image[:, :, ::-1] + + # Resize the image if needed. + width, height = parse_image_size(image_size) + if height or width: + height = height or image.shape[0] + width = width or image.shape[1] + image = cv2.resize(image, (width, height)) + + # Encode the image to html-format string. + encoded_image = cv2.imencode('.jpg', image)[1].tostring() + encoded_image_base64 = base64.b64encode(encoded_image).decode('utf-8') + html_str = f'' + + return html_str + + +def decode_html_str_to_image(html_str, image_size=None): + """Decodes image from html. + + Args: + html_str: Image string parsed from html. + image_size: This field is used to resize the image after decoding. `0` + disables resizing. (default: None) + + Returns: + An image with `RGB` channel order. + """ + if not html_str: + return None + + assert isinstance(html_str, str) + image_str = html_str.split(',')[-1] + encoded_image = base64.b64decode(image_str) + encoded_image_numpy = np.frombuffer(encoded_image, dtype=np.uint8) + image = cv2.imdecode(encoded_image_numpy, flags=cv2.IMREAD_COLOR) + + # Resize the image if needed. + width, height = parse_image_size(image_size) + if height or width: + height = height or image.shape[0] + width = width or image.shape[1] + image = cv2.resize(image, (width, height)) + + return image[:, :, ::-1] + + +class HtmlPageVisualizer(object): + """Defines the html page visualizer. + + This class can be used to visualize image results as html page. Basically, + it is based on an html-format sorted table with helper functions + `get_sortable_html_header()`, `get_sortable_html_footer()`, and + `encode_image_to_html_str()`. To simplify the usage, specifying the + following fields are enough to create a visualization page: + + (1) num_rows: Number of rows of the table (header-row exclusive). + (2) num_cols: Number of columns of the table. + (3) header contents (optional): Title of each column. + + NOTE: `grid_size` can be used to assign `num_rows` and `num_cols` + automatically. 
+ + Example: + + html = HtmlPageVisualizer(num_rows, num_cols) + html.set_headers([...]) + for i in range(num_rows): + for j in range(num_cols): + html.set_cell(i, j, text=..., image=..., highlight=False) + html.save('visualize.html') + """ + + def __init__(self, + num_rows=0, + num_cols=0, + grid_size=0, + is_portrait=True, + viz_size=None): + if grid_size > 0: + num_rows, num_cols = get_grid_shape( + grid_size, row=num_rows, col=num_cols, is_portrait=is_portrait) + assert num_rows > 0 and num_cols > 0 + + self.num_rows = num_rows + self.num_cols = num_cols + self.viz_size = parse_image_size(viz_size) + self.headers = ['' for _ in range(self.num_cols)] + self.cells = [[{ + 'text': '', + 'image': '', + 'highlight': False, + } for _ in range(self.num_cols)] for _ in range(self.num_rows)] + + def set_header(self, col_idx, content): + """Sets the content of a particular header by column index.""" + self.headers[col_idx] = content + + def set_headers(self, contents): + """Sets the contents of all headers.""" + if isinstance(contents, str): + contents = [contents] + assert isinstance(contents, (list, tuple)) + assert len(contents) == self.num_cols + for col_idx, content in enumerate(contents): + self.set_header(col_idx, content) + + def set_cell(self, row_idx, col_idx, text='', image=None, highlight=False): + """Sets the content of a particular cell. + + Basically, a cell contains some text as well as an image. Both text and + image can be empty. + + Args: + row_idx: Row index of the cell to edit. + col_idx: Column index of the cell to edit. + text: Text to add into the target cell. (default: None) + image: Image to show in the target cell. Should be with `RGB` + channel order. (default: None) + highlight: Whether to highlight this cell. (default: False) + """ + self.cells[row_idx][col_idx]['text'] = text + self.cells[row_idx][col_idx]['image'] = encode_image_to_html_str( + image, self.viz_size) + self.cells[row_idx][col_idx]['highlight'] = bool(highlight) + + def save(self, save_path): + """Saves the html page.""" + html = '' + for i in range(self.num_rows): + html += f'\n' + for j in range(self.num_cols): + text = self.cells[i][j]['text'] + image = self.cells[i][j]['image'] + if self.cells[i][j]['highlight']: + color = ' bgcolor="#FF8888"' + else: + color = '' + if text: + html += f' {text}

{image}\n' + else: + html += f' {image}\n' + html += f'\n' + + header = get_sortable_html_header(self.headers) + footer = get_sortable_html_footer() + + with open(save_path, 'w') as f: + f.write(header + html + footer) + + +class HtmlPageReader(object): + """Defines the html page reader. + + This class can be used to parse results from the visualization page + generated by `HtmlPageVisualizer`. + + Example: + + html = HtmlPageReader(html_path) + for j in range(html.num_cols): + header = html.get_header(j) + for i in range(html.num_rows): + for j in range(html.num_cols): + text = html.get_text(i, j) + image = html.get_image(i, j, image_size=None) + """ + def __init__(self, html_path): + """Initializes by loading the content from file.""" + self.html_path = html_path + if not os.path.isfile(html_path): + raise ValueError(f'File `{html_path}` does not exist!') + + # Load content. + with open(html_path, 'r') as f: + self.html = BeautifulSoup(f, 'html.parser') + + # Parse headers. + thead = self.html.find('thead') + headers = thead.findAll('th') + self.headers = [] + for header in headers: + self.headers.append(header.text) + self.num_cols = len(self.headers) + + # Parse cells. + tbody = self.html.find('tbody') + rows = tbody.findAll('tr') + self.cells = [] + for row in rows: + cells = row.findAll('td') + self.cells.append([]) + for cell in cells: + self.cells[-1].append({ + 'text': cell.text, + 'image': cell.find('img')['src'], + }) + assert len(self.cells[-1]) == self.num_cols + self.num_rows = len(self.cells) + + def get_header(self, j): + """Gets header for a particular column.""" + return self.headers[j] + + def get_text(self, i, j): + """Gets text from a particular cell.""" + return self.cells[i][j]['text'] + + def get_image(self, i, j, image_size=None): + """Gets image from a particular cell.""" + return decode_html_str_to_image(self.cells[i][j]['image'], image_size) + + +class VideoReader(object): + """Defines the video reader. + + This class can be used to read frames from a given video. + """ + + def __init__(self, path): + """Initializes the video reader by loading the video from disk.""" + if not os.path.isfile(path): + raise ValueError(f'Video `{path}` does not exist!') + + self.path = path + self.video = cv2.VideoCapture(path) + assert self.video.isOpened() + self.position = 0 + + self.length = int(self.video.get(cv2.CAP_PROP_FRAME_COUNT)) + self.frame_height = int(self.video.get(cv2.CAP_PROP_FRAME_HEIGHT)) + self.frame_width = int(self.video.get(cv2.CAP_PROP_FRAME_WIDTH)) + self.fps = self.video.get(cv2.CAP_PROP_FPS) + + def __del__(self): + """Releases the opened video.""" + self.video.release() + + def read(self, position=None): + """Reads a certain frame. + + NOTE: The returned frame is assumed to be with `RGB` channel order. + + Args: + position: Optional. If set, the reader will read frames from the + exact position. Otherwise, the reader will read next frames. + (default: None) + """ + if position is not None and position < self.length: + self.video.set(cv2.CAP_PROP_POS_FRAMES, position) + self.position = position + + success, frame = self.video.read() + self.position = self.position + 1 + + return frame[:, :, ::-1] if success else None + + +class VideoWriter(object): + """Defines the video writer. + + This class can be used to create a video. + + NOTE: `.avi` and `DIVX` is the most recommended codec format since it does + not rely on other dependencies. 
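+    Other containers (e.g., `.mp4` together with the `mp4v` codec) can be
+    requested through the `codec` argument, provided the local OpenCV build
+    ships the corresponding encoder.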
+ """ + + def __init__(self, path, frame_height, frame_width, fps=24, codec='DIVX'): + """Creates the video writer.""" + self.path = path + self.frame_height = frame_height + self.frame_width = frame_width + self.fps = fps + self.codec = codec + + self.video = cv2.VideoWriter(filename=path, + fourcc=cv2.VideoWriter_fourcc(*codec), + fps=fps, + frameSize=(frame_width, frame_height)) + + def __del__(self): + """Releases the opened video.""" + self.video.release() + + def write(self, frame): + """Writes a target frame. + + NOTE: The input frame is assumed to be with `RGB` channel order. + """ + self.video.write(frame[:, :, ::-1]) diff --git a/utils/visualizers/__init__.py b/utils/visualizers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..df9fbaca361c3802c0f8221c053b61bf66a2456f --- /dev/null +++ b/utils/visualizers/__init__.py @@ -0,0 +1,14 @@ +# python3.7 +"""Collects all visualizers.""" + +from .grid_visualizer import GridVisualizer +from .gif_visualizer import GifVisualizer +from .html_visualizer import HtmlVisualizer +from .html_visualizer import HtmlReader +from .video_visualizer import VideoVisualizer +from .video_visualizer import VideoReader + +__all__ = [ + 'GridVisualizer', 'GifVisualizer', 'HtmlVisualizer', 'HtmlReader', + 'VideoVisualizer', 'VideoReader' +] diff --git a/utils/visualizers/__pycache__/__init__.cpython-37.pyc b/utils/visualizers/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..80b3a9e1068ffd836d4f9c7676126efe261da30d Binary files /dev/null and b/utils/visualizers/__pycache__/__init__.cpython-37.pyc differ diff --git a/utils/visualizers/__pycache__/gif_visualizer.cpython-37.pyc b/utils/visualizers/__pycache__/gif_visualizer.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..77e137d2907d64f09bb0ec83d9b2372ca9fb056f Binary files /dev/null and b/utils/visualizers/__pycache__/gif_visualizer.cpython-37.pyc differ diff --git a/utils/visualizers/__pycache__/grid_visualizer.cpython-37.pyc b/utils/visualizers/__pycache__/grid_visualizer.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4874e27c01483314d01bff014caf8e6047404465 Binary files /dev/null and b/utils/visualizers/__pycache__/grid_visualizer.cpython-37.pyc differ diff --git a/utils/visualizers/__pycache__/html_visualizer.cpython-37.pyc b/utils/visualizers/__pycache__/html_visualizer.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe45394783b9d606f6954fbc27e7f3ba43924f38 Binary files /dev/null and b/utils/visualizers/__pycache__/html_visualizer.cpython-37.pyc differ diff --git a/utils/visualizers/__pycache__/video_visualizer.cpython-37.pyc b/utils/visualizers/__pycache__/video_visualizer.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c7e24da7b9a32a1e174e0098d893b884b8d25524 Binary files /dev/null and b/utils/visualizers/__pycache__/video_visualizer.cpython-37.pyc differ diff --git a/utils/visualizers/gif_visualizer.py b/utils/visualizers/gif_visualizer.py new file mode 100644 index 0000000000000000000000000000000000000000..a5528e8af79fda2e3840c67cbf60ea87ff273f4c --- /dev/null +++ b/utils/visualizers/gif_visualizer.py @@ -0,0 +1,79 @@ +# python3.7 +"""Contains the visualizer to visualize images as a GIF.""" + +from PIL import Image + +from ..image_utils import parse_image_size +from ..image_utils import load_image +from ..image_utils import resize_image +from ..image_utils import 
list_images_from_dir + +__all__ = ['GifVisualizer'] + + +class GifVisualizer(object): + """Defines the visualizer that visualizes an image collection as GIF.""" + + def __init__(self, image_size=None, duration=100, loop=0): + """Initializes the GIF visualizer. + + Args: + image_size: Size for image visualization. (default: None) + duration: Duration between two frames, in milliseconds. + (default: 100) + loop: How many times to loop the GIF. `0` means infinite. + (default: 0) + """ + self.set_image_size(image_size) + self.set_duration(duration) + self.set_loop(loop) + + def set_image_size(self, image_size=None): + """Sets the image size of the GIF.""" + height, width = parse_image_size(image_size) + self.image_height = height + self.image_width = width + + def set_duration(self, duration=100): + """Sets the GIF duration.""" + self.duration = duration + + def set_loop(self, loop=0): + """Sets how many times the GIF will be looped. `0` means infinite.""" + self.loop = loop + + def visualize_collection(self, images, save_path): + """Visualizes a collection of images one by one.""" + height, width = images[0].shape[0:2] + height = self.image_height or height + width = self.image_width or width + pil_images = [] + for image in images: + if image.shape[0:2] != (height, width): + image = resize_image(image, (width, height)) + pil_images.append(Image.fromarray(image)) + pil_images[0].save(save_path, format='GIF', save_all=True, + append_images=pil_images[1:], + duration=self.duration, + loop=self.loop) + + def visualize_list(self, image_list, save_path): + """Visualizes a list of image files.""" + height, width = load_image(image_list[0]).shape[0:2] + height = self.image_height or height + width = self.image_width or width + pil_images = [] + for filename in image_list: + image = load_image(filename) + if image.shape[0:2] != (height, width): + image = resize_image(image, (width, height)) + pil_images.append(Image.fromarray(image)) + pil_images[0].save(save_path, format='GIF', save_all=True, + append_images=pil_images[1:], + duration=self.duration, + loop=self.loop) + + def visualize_directory(self, directory, save_path): + """Visualizes all images under a directory.""" + image_list = list_images_from_dir(directory) + self.visualize_list(image_list, save_path) diff --git a/utils/visualizers/grid_visualizer.py b/utils/visualizers/grid_visualizer.py new file mode 100644 index 0000000000000000000000000000000000000000..c88db155b5986b9e86ed38958077956beb60b6df --- /dev/null +++ b/utils/visualizers/grid_visualizer.py @@ -0,0 +1,234 @@ +# python3.7 +"""Contains the visualizer to visualize images by composing them as a gird.""" + +from ..image_utils import get_blank_image +from ..image_utils import get_grid_shape +from ..image_utils import parse_image_size +from ..image_utils import load_image +from ..image_utils import save_image +from ..image_utils import resize_image +from ..image_utils import list_images_from_dir + +__all__ = ['GridVisualizer'] + + +class GridVisualizer(object): + """Defines the visualizer that visualizes images as a grid. + + Basically, given a collection of images, this visualizer stitches them one + by one. Notably, this class also supports adding spaces between images, + adding borders around images, and using white/black background. 
+ + Example: + + grid = GridVisualizer(num_rows, num_cols) + for i in range(num_rows): + for j in range(num_cols): + grid.add(i, j, image) + grid.save('visualize.jpg') + """ + + def __init__(self, + grid_size=0, + num_rows=0, + num_cols=0, + is_portrait=False, + image_size=None, + image_channels=0, + row_spacing=0, + col_spacing=0, + border_left=0, + border_right=0, + border_top=0, + border_bottom=0, + use_black_background=True): + """Initializes the grid visualizer. + + Args: + grid_size: Total number of cells, i.e., height * width. (default: 0) + num_rows: Number of rows. (default: 0) + num_cols: Number of columns. (default: 0) + is_portrait: Whether the grid should be portrait or landscape. + This is only used when it requires to compute `num_rows` and + `num_cols` automatically. See function `get_grid_shape()` in + file `./image_utils.py` for details. (default: False) + image_size: Size to visualize each image. (default: 0) + image_channels: Number of image channels. (default: 0) + row_spacing: Spacing between rows. (default: 0) + col_spacing: Spacing between columns. (default: 0) + border_left: Width of left border. (default: 0) + border_right: Width of right border. (default: 0) + border_top: Width of top border. (default: 0) + border_bottom: Width of bottom border. (default: 0) + use_black_background: Whether to use black background. + (default: True) + """ + self.reset(grid_size, num_rows, num_cols, is_portrait) + self.set_image_size(image_size) + self.set_image_channels(image_channels) + self.set_row_spacing(row_spacing) + self.set_col_spacing(col_spacing) + self.set_border_left(border_left) + self.set_border_right(border_right) + self.set_border_top(border_top) + self.set_border_bottom(border_bottom) + self.set_background(use_black_background) + self.grid = None + + def reset(self, + grid_size=0, + num_rows=0, + num_cols=0, + is_portrait=False): + """Resets the grid shape, i.e., number of rows/columns.""" + if grid_size > 0: + num_rows, num_cols = get_grid_shape(grid_size, + height=num_rows, + width=num_cols, + is_portrait=is_portrait) + self.grid_size = num_rows * num_cols + self.num_rows = num_rows + self.num_cols = num_cols + self.grid = None + + def set_image_size(self, image_size=None): + """Sets the image size of each cell in the grid.""" + height, width = parse_image_size(image_size) + self.image_height = height + self.image_width = width + + def set_image_channels(self, image_channels=0): + """Sets the number of channels of the grid.""" + self.image_channels = image_channels + + def set_row_spacing(self, row_spacing=0): + """Sets the spacing between grid rows.""" + self.row_spacing = row_spacing + + def set_col_spacing(self, col_spacing=0): + """Sets the spacing between grid columns.""" + self.col_spacing = col_spacing + + def set_border_left(self, border_left=0): + """Sets the width of the left border of the grid.""" + self.border_left = border_left + + def set_border_right(self, border_right=0): + """Sets the width of the right border of the grid.""" + self.border_right = border_right + + def set_border_top(self, border_top=0): + """Sets the width of the top border of the grid.""" + self.border_top = border_top + + def set_border_bottom(self, border_bottom=0): + """Sets the width of the bottom border of the grid.""" + self.border_bottom = border_bottom + + def set_background(self, use_black=True): + """Sets the grid background.""" + self.use_black_background = use_black + + def init_grid(self): + """Initializes the grid with a blank image.""" + assert self.num_rows > 0 
+ assert self.num_cols > 0 + assert self.image_height > 0 + assert self.image_width > 0 + assert self.image_channels > 0 + grid_height = (self.image_height * self.num_rows + + self.row_spacing * (self.num_rows - 1) + + self.border_top + self.border_bottom) + grid_width = (self.image_width * self.num_cols + + self.col_spacing * (self.num_cols - 1) + + self.border_left + self.border_right) + self.grid = get_blank_image(grid_height, grid_width, + channels=self.image_channels, + use_black=self.use_black_background) + + def add(self, i, j, image): + """Adds an image into the grid. + + NOTE: The input image is assumed to be with `RGB` channel order. + """ + channels = 1 if image.ndim == 2 else image.shape[2] + if self.grid is None: + height, width = image.shape[0:2] + height = self.image_height or height + width = self.image_width or width + channels = self.image_channels or channels + self.set_image_size((height, width)) + self.set_image_channels(channels) + self.init_grid() + if image.shape[0:2] != (self.image_height, self.image_width): + image = resize_image(image, (self.image_width, self.image_height)) + y = self.border_top + i * (self.image_height + self.row_spacing) + x = self.border_left + j * (self.image_width + self.col_spacing) + self.grid[y:y + self.image_height, + x:x + self.image_width, + :channels] = image + + def visualize_collection(self, + images, + save_path=None, + num_rows=0, + num_cols=0, + is_portrait=False, + is_row_major=True): + """Visualizes a collection of images one by one.""" + self.grid = None + self.reset(grid_size=len(images), + num_rows=num_rows, + num_cols=num_cols, + is_portrait=is_portrait) + for idx, image in enumerate(images): + if is_row_major: + row_idx, col_idx = divmod(idx, self.num_cols) + else: + col_idx, row_idx = divmod(idx, self.num_rows) + self.add(row_idx, col_idx, image) + if save_path: + self.save(save_path) + + def visualize_list(self, + image_list, + save_path=None, + num_rows=0, + num_cols=0, + is_portrait=False, + is_row_major=True): + """Visualizes a list of image files.""" + self.grid = None + self.reset(grid_size=len(image_list), + num_rows=num_rows, + num_cols=num_cols, + is_portrait=is_portrait) + for idx, filename in enumerate(image_list): + image = load_image(filename) + if is_row_major: + row_idx, col_idx = divmod(idx, self.num_cols) + else: + col_idx, row_idx = divmod(idx, self.num_rows) + self.add(row_idx, col_idx, image) + if save_path: + self.save(save_path) + + def visualize_directory(self, + directory, + save_path=None, + num_rows=0, + num_cols=0, + is_portrait=False, + is_row_major=True): + """Visualizes all images under a directory.""" + image_list = list_images_from_dir(directory) + self.visualize_list(image_list=image_list, + save_path=save_path, + num_rows=num_rows, + num_cols=num_cols, + is_portrait=is_portrait, + is_row_major=is_row_major) + + def save(self, path): + """Saves the grid.""" + save_image(path, self.grid) diff --git a/utils/visualizers/html_visualizer.py b/utils/visualizers/html_visualizer.py new file mode 100644 index 0000000000000000000000000000000000000000..7bb63385af54021febb4791bdda6534324d2d901 --- /dev/null +++ b/utils/visualizers/html_visualizer.py @@ -0,0 +1,438 @@ +# python3.7 +"""Contains the visualizer to visualize images with HTML page.""" + +import os +import base64 +import cv2 +import numpy as np +from bs4 import BeautifulSoup + +from ..image_utils import get_grid_shape +from ..image_utils import parse_image_size +from ..image_utils import load_image +from ..image_utils import resize_image 
+from ..image_utils import list_images_from_dir + +__all__ = ['HtmlVisualizer', 'HtmlReader'] + + +def get_sortable_html_header(column_name_list, sort_by_ascending=False): + """Gets header for sortable HTML page. + + Basically, the HTML page contains a sortable table, where user can sort the + rows by a particular column by clicking the column head. + + Example: + + column_name_list = [name_1, name_2, name_3] + header = get_sortable_html_header(column_name_list) + footer = get_sortable_html_footer() + sortable_table = ... + html_page = header + sortable_table + footer + + Args: + column_name_list: List of column header names. + sort_by_ascending: Default sorting order. If set as `True`, the HTML + page will be sorted by ascending order when the header is clicked + for the first time. + + Returns: + A string, which represents for the header for a sortable HTML page. + """ + header = '\n'.join([ + '', + '', + '', + '', + '', + '', + '', + '', + '', + '', + '', + '', + '', + '']) + for idx, name in enumerate(column_name_list): + header += f' \n' + header += '\n' + header += '\n' + header += '\n' + + return header + + +def get_sortable_html_footer(): + """Gets footer for sortable HTML page. + + Check function `get_sortable_html_header()` for more details. + """ + return '\n
{name}
\n\n\n\n' + + +def encode_image_to_html_str(image, image_size=None): + """Encodes an image to HTML language. + + NOTE: Input image is always assumed to be with `RGB` channel order. + + Args: + image: The input image to encode. Should be with `RGB` channel order. + image_size: This field is used to resize the image before encoding. + `None` disables resizing. (default: None) + + Returns: + A string that represents the encoded image. + """ + if image is None: + return '' + + assert image.ndim == 3 and image.shape[2] in [1, 3, 4] + if image.shape[2] == 3: + image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + elif image.shape[2] == 4: + image = cv2.cvtColor(image, cv2.COLOR_RGBA2BGRA) + + # Resize the image if needed. + height, width = parse_image_size(image_size) + height = height or image.shape[0] + width = width or image.shape[1] + if image.shape[0:2] != (height, width): + image = resize_image(image, (width, height)) + + # Encode the image to HTML-format string. + if image.shape[2] == 4: # Use `png` to encoder RGBA image. + encoded = cv2.imencode('.png', image)[1].tostring() + encoded_base64 = base64.b64encode(encoded).decode('utf-8') + html_str = f'' + else: + encoded = cv2.imencode('.jpg', image)[1].tostring() + encoded_base64 = base64.b64encode(encoded).decode('utf-8') + html_str = f'' + + return html_str + + +def decode_html_str_to_image(html_str, image_size=None): + """Decodes an image from HTML string. + + Args: + html_str: An HTML string that represents an image. + image_size: This field is used to resize the image after decoding. + `None` disables resizing. (default: None) + + Returns: + An image with `RGB` channel order. + """ + if not html_str: + return None + + assert isinstance(html_str, str) + image_str = html_str.split(',')[-1].strip() + encoded_image = base64.b64decode(image_str) + encoded_image_numpy = np.frombuffer(encoded_image, dtype=np.uint8) + image = cv2.imdecode(encoded_image_numpy, flags=cv2.IMREAD_UNCHANGED) + + if image.ndim == 2: + image = image[:, :, np.newaxis] + assert image.ndim == 3 and image.shape[2] in [1, 3, 4] + if image.shape[2] == 3: + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + if image.shape[2] == 4: + image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA) + + # Resize the image if needed. + height, width = parse_image_size(image_size) + height = height or image.shape[0] + width = width or image.shape[1] + if image.shape[0:2] != (height, width): + image = resize_image(image, (width, height)) + + return image + + +class HtmlVisualizer(object): + """Defines the HTML visualizer that visualizes images on an HTML page. + + This class can be used to visualize image results on an HTML page. + Basically, it is based on an HTML-format sorted table with helper functions + `get_sortable_html_header()`, `get_sortable_html_footer()`, and + `encode_image_to_html_str()`. To simplify the usage, specifying the + following fields are enough to create a visualization page: + + (1) num_rows: Number of rows of the table (header-row exclusive). + (2) num_cols: Number of columns of the table. + (3) header_contents (optional): Title of each column. + + NOTE: `grid_size` can be used to assign `num_rows` and `num_cols` + automatically. 
+ + Example: + + html = HtmlVisualizer(num_rows, num_cols) + html.set_headers([...]) + for i in range(num_rows): + for j in range(num_cols): + html.set_cell(i, j, text=..., image=..., highlight=False) + html.save('visualize.html') + """ + + def __init__(self, + grid_size=0, + num_rows=0, + num_cols=0, + is_portrait=True, + image_size=None): + """Initializes the html visualizer. + + Args: + grid_size: Total number of cells, i.e., height * width. (default: 0) + num_rows: Number of rows. (default: 0) + num_cols: Number of columns. (default: 0) + is_portrait: Whether the HTML page should be portrait or landscape. + This is only used when it requires to compute `num_rows` and + `num_cols` automatically. See function `get_grid_shape()` in + file `./image_utils.py` for details. (default: True) + image_size: Size to visualize each image. (default: None) + """ + self.reset(grid_size, num_rows, num_cols, is_portrait) + self.set_image_size(image_size) + + def reset(self, + grid_size=0, + num_rows=0, + num_cols=0, + is_portrait=True): + """Resets the HTML page with new number of rows and columns.""" + if grid_size > 0: + num_rows, num_cols = get_grid_shape(grid_size, + height=num_rows, + width=num_cols, + is_portrait=is_portrait) + self.grid_size = num_rows * num_cols + self.num_rows = num_rows + self.num_cols = num_cols + self.headers = ['' for _ in range(self.num_cols)] + self.cells = [[{ + 'text': '', + 'image': '', + 'highlight': False, + } for _ in range(self.num_cols)] for _ in range(self.num_rows)] + + def set_image_size(self, image_size=None): + """Sets the image size of each cell in the HTML page.""" + self.image_size = image_size + + def set_header(self, col_idx, content): + """Sets the content of a particular header by column index.""" + self.headers[col_idx] = content + + def set_headers(self, contents): + """Sets the contents of all headers.""" + assert isinstance(contents, (list, tuple)) + assert len(contents) == self.num_cols + for col_idx, content in enumerate(contents): + self.set_header(col_idx, content) + + def set_cell(self, row_idx, col_idx, text='', image=None, highlight=False): + """Sets the content of a particular cell. + + Basically, a cell contains some text as well as an image. Both text and + image can be empty. + + NOTE: The image is assumed to be with `RGB` channel order. + + Args: + row_idx: Row index of the cell to edit. + col_idx: Column index of the cell to edit. + text: Text to add into the target cell. (default: None) + image: Image to show in the target cell. Should be with `RGB` + channel order. (default: None) + highlight: Whether to highlight this cell. 
(default: False) + """ + self.cells[row_idx][col_idx]['text'] = text + self.cells[row_idx][col_idx]['image'] = encode_image_to_html_str( + image, self.image_size) + self.cells[row_idx][col_idx]['highlight'] = bool(highlight) + + def visualize_collection(self, + images, + save_path=None, + num_rows=0, + num_cols=0, + is_portrait=True, + is_row_major=True): + """Visualizes a collection of images one by one.""" + self.reset(grid_size=len(images), + num_rows=num_rows, + num_cols=num_cols, + is_portrait=is_portrait) + for idx, image in enumerate(images): + if is_row_major: + row_idx, col_idx = divmod(idx, self.num_cols) + else: + col_idx, row_idx = divmod(idx, self.num_rows) + self.set_cell(row_idx, col_idx, text=f'Index {idx:03d}', + image=image) + if save_path: + self.save(save_path) + + def visualize_list(self, + image_list, + save_path=None, + num_rows=0, + num_cols=0, + is_portrait=True, + is_row_major=True): + """Visualizes a list of image files.""" + self.reset(grid_size=len(image_list), + num_rows=num_rows, + num_cols=num_cols, + is_portrait=is_portrait) + for idx, filename in enumerate(image_list): + basename = os.path.basename(filename) + image = load_image(filename) + if is_row_major: + row_idx, col_idx = divmod(idx, self.num_cols) + else: + col_idx, row_idx = divmod(idx, self.num_rows) + self.set_cell(row_idx, col_idx, + text=f'{basename} (index {idx:03d})', image=image) + if save_path: + self.save(save_path) + + def visualize_directory(self, + directory, + save_path=None, + num_rows=0, + num_cols=0, + is_portrait=True, + is_row_major=True): + """Visualizes all images under a directory.""" + image_list = list_images_from_dir(directory) + self.visualize_list(image_list=image_list, + save_path=save_path, + num_rows=num_rows, + num_cols=num_cols, + is_portrait=is_portrait, + is_row_major=is_row_major) + + def save(self, path): + """Saves the HTML page.""" + html = '' + for i in range(self.num_rows): + html += '\n' + for j in range(self.num_cols): + text = self.cells[i][j]['text'] + image = self.cells[i][j]['image'] + if self.cells[i][j]['highlight']: + color = ' bgcolor="#FF8888"' + else: + color = '' + if text: + html += f' {text}
<br><br>
{image}\n' + else: + html += f' {image}\n' + html += '\n' + + header = get_sortable_html_header(self.headers) + footer = get_sortable_html_footer() + + with open(path, 'w') as f: + f.write(header + html + footer) + + +class HtmlReader(object): + """Defines the HTML page reader. + + This class can be used to parse results from the visualization page + generated by `HtmlVisualizer`. + + Example: + + html = HtmlReader(html_path) + for j in range(html.num_cols): + header = html.get_header(j) + for i in range(html.num_rows): + for j in range(html.num_cols): + text = html.get_text(i, j) + image = html.get_image(i, j, image_size=None) + """ + def __init__(self, path): + """Initializes by loading the content from file.""" + self.path = path + + # Load content. + with open(path, 'r') as f: + self.html = BeautifulSoup(f, 'html.parser') + + # Parse headers. + thead = self.html.find('thead') + headers = thead.findAll('th') + self.headers = [] + for header in headers: + self.headers.append(header.text) + self.num_cols = len(self.headers) + + # Parse cells. + tbody = self.html.find('tbody') + rows = tbody.findAll('tr') + self.cells = [] + for row in rows: + cells = row.findAll('td') + self.cells.append([]) + for cell in cells: + self.cells[-1].append({ + 'text': cell.text, + 'image': cell.find('img')['src'], + }) + assert len(self.cells[-1]) == self.num_cols + self.num_rows = len(self.cells) + + def get_header(self, j): + """Gets header for a particular column.""" + return self.headers[j] + + def get_text(self, i, j): + """Gets text from a particular cell.""" + return self.cells[i][j]['text'] + + def get_image(self, i, j, image_size=None): + """Gets image from a particular cell.""" + return decode_html_str_to_image(self.cells[i][j]['image'], image_size) diff --git a/utils/visualizers/test.py b/utils/visualizers/test.py new file mode 100644 index 0000000000000000000000000000000000000000..765ebf9c721b0792fb373ecb515ebf188f728df0 --- /dev/null +++ b/utils/visualizers/test.py @@ -0,0 +1,97 @@ +# python3.7 +"""Unit test for visualizer.""" + +import os +import skvideo.datasets + +from ..image_utils import save_image +from . import GridVisualizer +from . import HtmlVisualizer +from . import HtmlReader +from . import GifVisualizer +from . import VideoVisualizer +from . import VideoReader + +__all__ = ['test_visualizer'] + +_TEST_DIR = 'visualizer_test' + + +def test_visualizer(test_dir=_TEST_DIR): + """Tests visualizers.""" + print('========== Start Visualizer Test ==========') + + frame_dir = os.path.join(test_dir, 'test_frames') + os.makedirs(frame_dir, exist_ok=True) + + print('===== Testing `VideoReader` =====') + # Total 132 frames, with size (720, 1080). + video_reader = VideoReader(skvideo.datasets.bigbuckbunny()) + frame_height = video_reader.frame_height + frame_width = video_reader.frame_width + frame_size = (frame_height, frame_width) + half_size = (frame_height // 2, frame_width // 2) + # Save frames as the test set. 
+ for idx in range(80): + frame = video_reader.read() + save_image(os.path.join(frame_dir, f'{idx:02d}.png'), frame) + + print('===== Testing `GirdVisualizer` =====') + grid_visualizer = GridVisualizer() + grid_visualizer.set_row_spacing(30) + grid_visualizer.set_col_spacing(30) + grid_visualizer.set_background(use_black=True) + path = os.path.join(test_dir, 'portrait_row_major_ori_space30_black.png') + grid_visualizer.visualize_directory(frame_dir, path, + is_portrait=True, is_row_major=True) + path = os.path.join( + test_dir, 'landscape_col_major_downsample_space15_white.png') + grid_visualizer.set_image_size(half_size) + grid_visualizer.set_row_spacing(15) + grid_visualizer.set_col_spacing(15) + grid_visualizer.set_background(use_black=False) + grid_visualizer.visualize_directory(frame_dir, path, + is_portrait=False, is_row_major=False) + + print('===== Testing `HtmlVisualizer` =====') + html_visualizer = HtmlVisualizer() + path = os.path.join(test_dir, 'portrait_col_major_ori.html') + html_visualizer.visualize_directory(frame_dir, path, + is_portrait=True, is_row_major=False) + path = os.path.join(test_dir, 'landscape_row_major_downsample.html') + html_visualizer.set_image_size(half_size) + html_visualizer.visualize_directory(frame_dir, path, + is_portrait=False, is_row_major=True) + + print('===== Testing `HtmlReader` =====') + path = os.path.join(test_dir, 'landscape_row_major_downsample.html') + html_reader = HtmlReader(path) + for j in range(html_reader.num_cols): + assert html_reader.get_header(j) == '' + parsed_dir = os.path.join(test_dir, 'parsed_frames') + os.makedirs(parsed_dir, exist_ok=True) + for i in range(html_reader.num_rows): + for j in range(html_reader.num_cols): + idx = i * html_reader.num_cols + j + assert html_reader.get_text(i, j).endswith(f'(index {idx:03d})') + image = html_reader.get_image(i, j, image_size=frame_size) + assert image.shape[0:2] == frame_size + save_image(os.path.join(parsed_dir, f'{idx:02d}.png'), image) + + print('===== Testing `GifVisualizer` =====') + gif_visualizer = GifVisualizer() + path = os.path.join(test_dir, 'gif_ori.gif') + gif_visualizer.visualize_directory(frame_dir, path) + gif_visualizer.set_image_size(half_size) + path = os.path.join(test_dir, 'gif_downsample.gif') + gif_visualizer.visualize_directory(frame_dir, path) + + print('===== Testing `VideoVisualizer` =====') + video_visualizer = VideoVisualizer() + path = os.path.join(test_dir, 'video_ori.mp4') + video_visualizer.visualize_directory(frame_dir, path) + path = os.path.join(test_dir, 'video_downsample.mp4') + video_visualizer.set_frame_size(half_size) + video_visualizer.visualize_directory(frame_dir, path) + + print('========== Finish Visualizer Test ==========') diff --git a/utils/visualizers/video_visualizer.py b/utils/visualizers/video_visualizer.py new file mode 100644 index 0000000000000000000000000000000000000000..b4c3a5934224edc5d557ad1b1458d4f776148a75 --- /dev/null +++ b/utils/visualizers/video_visualizer.py @@ -0,0 +1,173 @@ +# python3.7 +"""Contains the visualizer to visualize images as a video. + +This file relies on `FFmpeg`. Use `sudo apt-get install ffmpeg` and +`brew install ffmpeg` to install on Ubuntu and MacOS respectively. 
+""" + +import os.path +from skvideo.io import FFmpegWriter +from skvideo.io import FFmpegReader + +from ..image_utils import parse_image_size +from ..image_utils import load_image +from ..image_utils import resize_image +from ..image_utils import list_images_from_dir + +__all__ = ['VideoVisualizer', 'VideoReader'] + + +class VideoVisualizer(object): + """Defines the video visualizer that presents images as a video.""" + + def __init__(self, + path=None, + frame_size=None, + fps=25.0, + codec='libx264', + pix_fmt='yuv420p', + crf=1): + """Initializes the video visualizer. + + Args: + path: Path to write the video. (default: None) + frame_size: Frame size, i.e., (height, width). (default: None) + fps: Frames per second. (default: 24) + codec: Codec. (default: `libx264`) + pix_fmt: Pixel format. (default: `yuv420p`) + crf: Constant rate factor, which controls the compression. The + larger this field is, the higher compression and lower quality. + `0` means no compression and consequently the highest quality. + To enable QuickTime playing (requires YUV to be 4:2:0, but + `crf = 0` results YUV to be 4:4:4), please set this field as + at least 1. (default: 1) + """ + self.set_path(path) + self.set_frame_size(frame_size) + self.set_fps(fps) + self.set_codec(codec) + self.set_pix_fmt(pix_fmt) + self.set_crf(crf) + self.video = None + + def set_path(self, path=None): + """Sets the path to save the video.""" + self.path = path + + def set_frame_size(self, frame_size=None): + """Sets the video frame size.""" + height, width = parse_image_size(frame_size) + self.frame_height = height + self.frame_width = width + + def set_fps(self, fps=25.0): + """Sets the FPS (frame per second) of the video.""" + self.fps = fps + + def set_codec(self, codec='libx264'): + """Sets the video codec.""" + self.codec = codec + + def set_pix_fmt(self, pix_fmt='yuv420p'): + """Sets the video pixel format.""" + self.pix_fmt = pix_fmt + + def set_crf(self, crf=1): + """Sets the CRF (constant rate factor) of the video.""" + self.crf = crf + + def init_video(self): + """Initializes an empty video with expected settings.""" + assert not os.path.exists(self.path), f'Video `{self.path}` existed!' + assert self.frame_height > 0 + assert self.frame_width > 0 + + video_setting = { + '-r': f'{self.fps:.2f}', + '-s': f'{self.frame_width}x{self.frame_height}', + '-vcodec': f'{self.codec}', + '-crf': f'{self.crf}', + '-pix_fmt': f'{self.pix_fmt}', + } + self.video = FFmpegWriter(self.path, outputdict=video_setting) + + def add(self, frame): + """Adds a frame into the video visualizer. + + NOTE: The input frame is assumed to be with `RGB` channel order. 
+ """ + if self.video is None: + height, width = frame.shape[0:2] + height = self.frame_height or height + width = self.frame_width or width + self.set_frame_size((height, width)) + self.init_video() + if frame.shape[0:2] != (self.frame_height, self.frame_width): + frame = resize_image(frame, (self.frame_width, self.frame_height)) + self.video.writeFrame(frame) + + def visualize_collection(self, images, save_path=None): + """Visualizes a collection of images one by one.""" + if save_path is not None and save_path != self.path: + self.save() + self.set_path(save_path) + for image in images: + self.add(image) + self.save() + + def visualize_list(self, image_list, save_path=None): + """Visualizes a list of image files.""" + if save_path is not None and save_path != self.path: + self.save() + self.set_path(save_path) + for filename in image_list: + image = load_image(filename) + self.add(image) + self.save() + + def visualize_directory(self, directory, save_path=None): + """Visualizes all images under a directory.""" + image_list = list_images_from_dir(directory) + self.visualize_list(image_list, save_path) + + def save(self): + """Saves the video by closing the file.""" + if self.video is not None: + self.video.close() + self.video = None + self.set_path(None) + + +class VideoReader(object): + """Defines the video reader. + + This class can be used to read frames from a given video. + + NOTE: Each frame can be read only once. + TODO: Fix this? + """ + + def __init__(self, path, inputdict=None): + """Initializes the video reader by loading the video from disk.""" + self.path = path + self.video = FFmpegReader(path, inputdict=inputdict) + + self.length = self.video.inputframenum + self.frame_height = self.video.inputheight + self.frame_width = self.video.inputwidth + self.fps = self.video.inputfps + self.pix_fmt = self.video.pix_fmt + + def __del__(self): + """Releases the opened video.""" + self.video.close() + + def read(self, image_size=None): + """Reads the next frame.""" + frame = next(self.video.nextFrame()) + height, width = parse_image_size(image_size) + height = height or frame.shape[0] + width = width or frame.shape[1] + if frame.shape[0:2] != (height, width): + frame = resize_image(frame, (width, height)) + return frame