diff --git a/LICENSE.txt b/LICENSE.txt new file mode 100644 index 0000000000000000000000000000000000000000..6b5ee9bf994cc9441cb659c3527160b4ee5bcb33 --- /dev/null +++ b/LICENSE.txt @@ -0,0 +1,97 @@ +Copyright (c) 2021, NVIDIA Corporation & affiliates. All rights reserved. + + +NVIDIA Source Code License for StyleGAN3 + + +======================================================================= + +1. Definitions + +"Licensor" means any person or entity that distributes its Work. + +"Software" means the original work of authorship made available under +this License. + +"Work" means the Software and any additions to or derivative works of +the Software that are made available under this License. + +The terms "reproduce," "reproduction," "derivative works," and +"distribution" have the meaning as provided under U.S. copyright law; +provided, however, that for the purposes of this License, derivative +works shall not include works that remain separable from, or merely +link (or bind by name) to the interfaces of, the Work. + +Works, including the Software, are "made available" under this License +by including in or with the Work either (a) a copyright notice +referencing the applicability of this License to the Work, or (b) a +copy of this License. + +2. License Grants + + 2.1 Copyright Grant. Subject to the terms and conditions of this + License, each Licensor grants to you a perpetual, worldwide, + non-exclusive, royalty-free, copyright license to reproduce, + prepare derivative works of, publicly display, publicly perform, + sublicense and distribute its Work and any resulting derivative + works in any form. + +3. Limitations + + 3.1 Redistribution. You may reproduce or distribute the Work only + if (a) you do so under this License, (b) you include a complete + copy of this License with your distribution, and (c) you retain + without modification any copyright, patent, trademark, or + attribution notices that are present in the Work. + + 3.2 Derivative Works. You may specify that additional or different + terms apply to the use, reproduction, and distribution of your + derivative works of the Work ("Your Terms") only if (a) Your Terms + provide that the use limitation in Section 3.3 applies to your + derivative works, and (b) you identify the specific derivative + works that are subject to Your Terms. Notwithstanding Your Terms, + this License (including the redistribution requirements in Section + 3.1) will continue to apply to the Work itself. + + 3.3 Use Limitation. The Work and any derivative works thereof only + may be used or intended for use non-commercially. Notwithstanding + the foregoing, NVIDIA and its affiliates may use the Work and any + derivative works commercially. As used herein, "non-commercially" + means for research or evaluation purposes only. + + 3.4 Patent Claims. If you bring or threaten to bring a patent claim + against any Licensor (including any claim, cross-claim or + counterclaim in a lawsuit) to enforce any patents that you allege + are infringed by any Work, then your rights under this License from + such Licensor (including the grant in Section 2.1) will terminate + immediately. + + 3.5 Trademarks. This License does not grant any rights to use any + Licensor’s or its affiliates’ names, logos, or trademarks, except + as necessary to reproduce the notices described in this License. + + 3.6 Termination. If you violate any term of this License, then your + rights under this License (including the grant in Section 2.1) will + terminate immediately. + +4. Disclaimer of Warranty. + +THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR +NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER +THIS LICENSE. + +5. Limitation of Liability. + +EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL +THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE +SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, +INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF +OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK +(INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, +LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER +COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF +THE POSSIBILITY OF SUCH DAMAGES. + +======================================================================= diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..5c169484e7883263acb0404500e611bad4971279 --- /dev/null +++ b/app.py @@ -0,0 +1,169 @@ +import os +os.system("pip install --upgrade torch==1.9.1+cu111 torchvision==0.10.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html") +os.system("git clone https://github.com/NVlabs/stylegan3") +os.system("git clone https://github.com/openai/CLIP") +os.system("pip install -e ./CLIP") +os.system("pip install einops ninja scipy numpy Pillow tqdm") +import sys +sys.path.append('./CLIP') +sys.path.append('./stylegan3') +import io +import os, time +import pickle +import shutil +import numpy as np +from PIL import Image +import torch +import torch.nn.functional as F +import requests +import torchvision.transforms as transforms +import torchvision.transforms.functional as TF +import clip +from tqdm.notebook import tqdm +from torchvision.transforms import Compose, Resize, ToTensor, Normalize +from einops import rearrange +device = torch.device('cuda:0') +def fetch(url_or_path): + if str(url_or_path).startswith('http://') or str(url_or_path).startswith('https://'): + r = requests.get(url_or_path) + r.raise_for_status() + fd = io.BytesIO() + fd.write(r.content) + fd.seek(0) + return fd + return open(url_or_path, 'rb') +def fetch_model(url_or_path): + basename = os.path.basename(url_or_path) + if os.path.exists(basename): + return basename + else: + os.system("wget -c '{url_or_path}'") + return basename +def norm1(prompt): + "Normalize to the unit sphere." + return prompt / prompt.square().sum(dim=-1,keepdim=True).sqrt() +def spherical_dist_loss(x, y): + x = F.normalize(x, dim=-1) + y = F.normalize(y, dim=-1) + return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2) +class MakeCutouts(torch.nn.Module): + def __init__(self, cut_size, cutn, cut_pow=1.): + super().__init__() + self.cut_size = cut_size + self.cutn = cutn + self.cut_pow = cut_pow + def forward(self, input): + sideY, sideX = input.shape[2:4] + max_size = min(sideX, sideY) + min_size = min(sideX, sideY, self.cut_size) + cutouts = [] + for _ in range(self.cutn): + size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size) + offsetx = torch.randint(0, sideX - size + 1, ()) + offsety = torch.randint(0, sideY - size + 1, ()) + cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size] + cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size)) + return torch.cat(cutouts) +make_cutouts = MakeCutouts(224, 32, 0.5) +def embed_image(image): + n = image.shape[0] + cutouts = make_cutouts(image) + embeds = clip_model.embed_cutout(cutouts) + embeds = rearrange(embeds, '(cc n) c -> cc n c', n=n) + return embeds +def embed_url(url): + image = Image.open(fetch(url)).convert('RGB') + return embed_image(TF.to_tensor(image).to(device).unsqueeze(0)).mean(0).squeeze(0) +class CLIP(object): + def __init__(self): + clip_model = "ViT-B/32" + self.model, _ = clip.load(clip_model) + self.model = self.model.requires_grad_(False) + self.normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], + std=[0.26862954, 0.26130258, 0.27577711]) + @torch.no_grad() + def embed_text(self, prompt): + "Normalized clip text embedding." + return norm1(self.model.encode_text(clip.tokenize(prompt).to(device)).float()) + def embed_cutout(self, image): + "Normalized clip image embedding." + return norm1(self.model.encode_image(self.normalize(image))) + +clip_model = CLIP() +# Load stylegan model +base_url = "https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/" +model_name = "stylegan3-t-ffhqu-1024x1024.pkl" +#model_name = "stylegan3-r-metfacesu-1024x1024.pkl" +#model_name = "stylegan3-t-afhqv2-512x512.pkl" +network_url = base_url + model_name +os.system("wget -c https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-ffhqu-1024x1024.pkl") +with open('stylegan3-t-ffhqu-1024x1024.pkl', 'rb') as fp: + G = pickle.load(fp)['G_ema'].to(device) +zs = torch.randn([10000, G.mapping.z_dim], device=device) +w_stds = G.mapping(zs, None).std(0) + + +def inference(text): + target = clip_model.embed_text(text) + steps = 600 + seed = 2 + tf = Compose([ + Resize(224), + lambda x: torch.clamp((x+1)/2,min=0,max=1), + ]) + torch.manual_seed(seed) + timestring = time.strftime('%Y%m%d%H%M%S') + with torch.no_grad(): + qs = [] + losses = [] + for _ in range(8): + q = (G.mapping(torch.randn([4,G.mapping.z_dim], device=device), None, truncation_psi=0.7) - G.mapping.w_avg) / w_stds + images = G.synthesis(q * w_stds + G.mapping.w_avg) + embeds = embed_image(images.add(1).div(2)) + loss = spherical_dist_loss(embeds, target).mean(0) + i = torch.argmin(loss) + qs.append(q[i]) + losses.append(loss[i]) + qs = torch.stack(qs) + losses = torch.stack(losses) + print(losses) + print(losses.shape, qs.shape) + i = torch.argmin(losses) + q = qs[i].unsqueeze(0) + q.requires_grad_() + q_ema = q + opt = torch.optim.AdamW([q], lr=0.03, betas=(0.0,0.999)) + loop = tqdm(range(steps)) + for i in loop: + opt.zero_grad() + w = q * w_stds + image = G.synthesis(w + G.mapping.w_avg, noise_mode='const') + embed = embed_image(image.add(1).div(2)) + loss = spherical_dist_loss(embed, target).mean() + loss.backward() + opt.step() + loop.set_postfix(loss=loss.item(), q_magnitude=q.std().item()) + q_ema = q_ema * 0.9 + q * 0.1 + image = G.synthesis(q_ema * w_stds + G.mapping.w_avg, noise_mode='const') + if i % 10 == 0: + display(TF.to_pil_image(tf(image)[0])) + pil_image = TF.to_pil_image(image[0].add(1).div(2).clamp(0,1)) + #os.makedirs(f'samples/{timestring}', exist_ok=True) + #pil_image.save(f'samples/{timestring}/{i:04}.jpg') + return pil_image + + +title = "StyleGAN+CLIP_with_Latent_Bootstraping" +description = "Gradio demo for StyleGAN+CLIP_with_Latent_Bootstraping. To use it, simply add your text, or click one of the examples to load them. Read more at the links below." +article = "

colab by https://twitter.com/EricHallahan Colab

" +examples = [['elon musk']] +gr.Interface( + inference, + "text", + gr.outputs.Image(type="pil", label="Output"), + title=title, + description=description, + article=article, + enable_queue=True, + examples=examples + ).launch(debug=True) diff --git a/avg_spectra.py b/avg_spectra.py new file mode 100644 index 0000000000000000000000000000000000000000..afaef87de54e49df230b432b52fda92667d17667 --- /dev/null +++ b/avg_spectra.py @@ -0,0 +1,276 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Compare average power spectra between real and generated images, +or between multiple generators.""" + +import os +import numpy as np +import torch +import torch.fft +import scipy.ndimage +import matplotlib.pyplot as plt +import click +import tqdm +import dnnlib + +import legacy +from training import dataset + +#---------------------------------------------------------------------------- +# Setup an iterator for streaming images, in uint8 NCHW format, based on the +# respective command line options. + +def stream_source_images(source, num, seed, device, data_loader_kwargs=None): # => num_images, image_size, image_iter + ext = source.split('.')[-1].lower() + if data_loader_kwargs is None: + data_loader_kwargs = dict(pin_memory=True, num_workers=3, prefetch_factor=2) + + if ext == 'pkl': + if num is None: + raise click.ClickException('--num is required when --source points to network pickle') + with dnnlib.util.open_url(source) as f: + G = legacy.load_network_pkl(f)['G_ema'].to(device) + def generate_image(seed): + rnd = np.random.RandomState(seed) + z = torch.from_numpy(rnd.randn(1, G.z_dim)).to(device) + c = torch.zeros([1, G.c_dim], device=device) + if G.c_dim > 0: + c[:, rnd.randint(G.c_dim)] = 1 + return (G(z=z, c=c) * 127.5 + 128).clamp(0, 255).to(torch.uint8) + _ = generate_image(seed) # warm up + image_iter = (generate_image(seed + idx) for idx in range(num)) + return num, G.img_resolution, image_iter + + elif ext == 'zip' or os.path.isdir(source): + dataset_obj = dataset.ImageFolderDataset(path=source, max_size=num, random_seed=seed) + if num is not None and num != len(dataset_obj): + raise click.ClickException(f'--source contains fewer than {num} images') + data_loader = torch.utils.data.DataLoader(dataset_obj, batch_size=1, **data_loader_kwargs) + image_iter = (image.to(device) for image, _label in data_loader) + return len(dataset_obj), dataset_obj.resolution, image_iter + + else: + raise click.ClickException('--source must point to network pickle, dataset zip, or directory') + +#---------------------------------------------------------------------------- +# Load average power spectrum from the specified .npz file and construct +# the corresponding heatmap for visualization. + +def construct_heatmap(npz_file, smooth): + npz_data = np.load(npz_file) + spectrum = npz_data['spectrum'] + image_size = npz_data['image_size'] + hmap = np.log10(spectrum) * 10 # dB + hmap = np.fft.fftshift(hmap) + hmap = np.concatenate([hmap, hmap[:1, :]], axis=0) + hmap = np.concatenate([hmap, hmap[:, :1]], axis=1) + if smooth > 0: + sigma = spectrum.shape[0] / image_size * smooth + hmap = scipy.ndimage.gaussian_filter(hmap, sigma=sigma, mode='nearest') + return hmap, image_size + +#---------------------------------------------------------------------------- + +@click.group() +def main(): + """Compare average power spectra between real and generated images, + or between multiple generators. + + Example: + + \b + # Calculate dataset mean and std, needed in subsequent steps. + python avg_spectra.py stats --source=~/datasets/ffhq-1024x1024.zip + + \b + # Calculate average spectrum for the training data. + python avg_spectra.py calc --source=~/datasets/ffhq-1024x1024.zip \\ + --dest=tmp/training-data.npz --mean=112.684 --std=69.509 + + \b + # Calculate average spectrum for a pre-trained generator. + python avg_spectra.py calc \\ + --source=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-ffhq-1024x1024.pkl \\ + --dest=tmp/stylegan3-r.npz --mean=112.684 --std=69.509 --num=70000 + + \b + # Display results. + python avg_spectra.py heatmap tmp/training-data.npz + python avg_spectra.py heatmap tmp/stylegan3-r.npz + python avg_spectra.py slices tmp/training-data.npz tmp/stylegan3-r.npz + + \b + # Save as PNG. + python avg_spectra.py heatmap tmp/training-data.npz --save=tmp/training-data.png --dpi=300 + python avg_spectra.py heatmap tmp/stylegan3-r.npz --save=tmp/stylegan3-r.png --dpi=300 + python avg_spectra.py slices tmp/training-data.npz tmp/stylegan3-r.npz --save=tmp/slices.png --dpi=300 + """ + +#---------------------------------------------------------------------------- + +@main.command() +@click.option('--source', help='Network pkl, dataset zip, or directory', metavar='[PKL|ZIP|DIR]', required=True) +@click.option('--num', help='Number of images to process [default: all]', metavar='INT', type=click.IntRange(min=1)) +@click.option('--seed', help='Random seed for selecting the images', metavar='INT', type=click.IntRange(min=0), default=0, show_default=True) +def stats(source, num, seed, device=torch.device('cuda')): + """Calculate dataset mean and standard deviation needed by 'calc'.""" + torch.multiprocessing.set_start_method('spawn') + num_images, _image_size, image_iter = stream_source_images(source=source, num=num, seed=seed, device=device) + + # Accumulate moments. + moments = torch.zeros([3], dtype=torch.float64, device=device) + for image in tqdm.tqdm(image_iter, total=num_images): + image = image.to(torch.float64) + moments += torch.stack([torch.ones_like(image).sum(), image.sum(), image.square().sum()]) + moments = moments / moments[0] + + # Compute mean and standard deviation. + mean = moments[1] + std = (moments[2] - moments[1].square()).sqrt() + print(f'--mean={mean:g} --std={std:g}') + +#---------------------------------------------------------------------------- + +@main.command() +@click.option('--source', help='Network pkl, dataset zip, or directory', metavar='[PKL|ZIP|DIR]', required=True) +@click.option('--dest', help='Where to store the result', metavar='NPZ', required=True) +@click.option('--mean', help='Dataset mean for whitening', metavar='FLOAT', type=float, required=True) +@click.option('--std', help='Dataset standard deviation for whitening', metavar='FLOAT', type=click.FloatRange(min=0), required=True) +@click.option('--num', help='Number of images to process [default: all]', metavar='INT', type=click.IntRange(min=1)) +@click.option('--seed', help='Random seed for selecting the images', metavar='INT', type=click.IntRange(min=0), default=0, show_default=True) +@click.option('--beta', help='Shape parameter for the Kaiser window', metavar='FLOAT', type=click.FloatRange(min=0), default=8, show_default=True) +@click.option('--interp', help='Frequency-domain interpolation factor', metavar='INT', type=click.IntRange(min=1), default=4, show_default=True) +def calc(source, dest, mean, std, num, seed, beta, interp, device=torch.device('cuda')): + """Calculate average power spectrum and store it in .npz file.""" + torch.multiprocessing.set_start_method('spawn') + num_images, image_size, image_iter = stream_source_images(source=source, num=num, seed=seed, device=device) + spectrum_size = image_size * interp + padding = spectrum_size - image_size + + # Setup window function. + window = torch.kaiser_window(image_size, periodic=False, beta=beta, device=device) + window *= window.square().sum().rsqrt() + window = window.ger(window).unsqueeze(0).unsqueeze(1) + + # Accumulate power spectrum. + spectrum = torch.zeros([spectrum_size, spectrum_size], dtype=torch.float64, device=device) + for image in tqdm.tqdm(image_iter, total=num_images): + image = (image.to(torch.float64) - mean) / std + image = torch.nn.functional.pad(image * window, [0, padding, 0, padding]) + spectrum += torch.fft.fftn(image, dim=[2,3]).abs().square().mean(dim=[0,1]) + spectrum /= num_images + + # Save result. + if os.path.dirname(dest): + os.makedirs(os.path.dirname(dest), exist_ok=True) + np.savez(dest, spectrum=spectrum.cpu().numpy(), image_size=image_size) + +#---------------------------------------------------------------------------- + +@main.command() +@click.argument('npz-file', nargs=1) +@click.option('--save', help='Save the plot and exit', metavar='[PNG|PDF|...]') +@click.option('--dpi', help='Figure resolution', metavar='FLOAT', type=click.FloatRange(min=1), default=100, show_default=True) +@click.option('--smooth', help='Amount of smoothing', metavar='FLOAT', type=click.FloatRange(min=0), default=1.25, show_default=True) +def heatmap(npz_file, save, smooth, dpi): + """Visualize 2D heatmap based on the given .npz file.""" + hmap, image_size = construct_heatmap(npz_file=npz_file, smooth=smooth) + + # Setup plot. + plt.figure(figsize=[6, 4.8], dpi=dpi, tight_layout=True) + freqs = np.linspace(-0.5, 0.5, num=hmap.shape[0], endpoint=True) * image_size + ticks = np.linspace(freqs[0], freqs[-1], num=5, endpoint=True) + levels = np.linspace(-40, 20, num=13, endpoint=True) + + # Draw heatmap. + plt.xlim(ticks[0], ticks[-1]) + plt.ylim(ticks[0], ticks[-1]) + plt.xticks(ticks) + plt.yticks(ticks) + plt.contourf(freqs, freqs, hmap, levels=levels, extend='both', cmap='Blues') + plt.gca().set_aspect('equal') + plt.colorbar(ticks=levels) + plt.contour(freqs, freqs, hmap, levels=levels, extend='both', linestyles='solid', linewidths=1, colors='midnightblue', alpha=0.2) + + # Display or save. + if save is None: + plt.show() + else: + if os.path.dirname(save): + os.makedirs(os.path.dirname(save), exist_ok=True) + plt.savefig(save) + +#---------------------------------------------------------------------------- + +@main.command() +@click.argument('npz-files', nargs=-1, required=True) +@click.option('--save', help='Save the plot and exit', metavar='[PNG|PDF|...]') +@click.option('--dpi', help='Figure resolution', metavar='FLOAT', type=click.FloatRange(min=1), default=100, show_default=True) +@click.option('--smooth', help='Amount of smoothing', metavar='FLOAT', type=click.FloatRange(min=0), default=0, show_default=True) +def slices(npz_files, save, dpi, smooth): + """Visualize 1D slices based on the given .npz files.""" + cases = [dnnlib.EasyDict(npz_file=npz_file) for npz_file in npz_files] + for c in cases: + c.hmap, c.image_size = construct_heatmap(npz_file=c.npz_file, smooth=smooth) + c.label = os.path.splitext(os.path.basename(c.npz_file))[0] + + # Check consistency. + image_size = cases[0].image_size + hmap_size = cases[0].hmap.shape[0] + if any(c.image_size != image_size or c.hmap.shape[0] != hmap_size for c in cases): + raise click.ClickException('All .npz must have the same resolution') + + # Setup plot. + plt.figure(figsize=[12, 4.6], dpi=dpi, tight_layout=True) + hmap_center = hmap_size // 2 + hmap_range = np.arange(hmap_center, hmap_size) + freqs0 = np.linspace(0, image_size / 2, num=(hmap_size // 2 + 1), endpoint=True) + freqs45 = np.linspace(0, image_size / np.sqrt(2), num=(hmap_size // 2 + 1), endpoint=True) + xticks0 = np.linspace(freqs0[0], freqs0[-1], num=9, endpoint=True) + xticks45 = np.round(np.linspace(freqs45[0], freqs45[-1], num=9, endpoint=True)) + yticks = np.linspace(-50, 30, num=9, endpoint=True) + + # Draw 0 degree slice. + plt.subplot(1, 2, 1) + plt.title('0\u00b0 slice') + plt.xlim(xticks0[0], xticks0[-1]) + plt.ylim(yticks[0], yticks[-1]) + plt.xticks(xticks0) + plt.yticks(yticks) + for c in cases: + plt.plot(freqs0, c.hmap[hmap_center, hmap_range], label=c.label) + plt.grid() + plt.legend(loc='upper right') + + # Draw 45 degree slice. + plt.subplot(1, 2, 2) + plt.title('45\u00b0 slice') + plt.xlim(xticks45[0], xticks45[-1]) + plt.ylim(yticks[0], yticks[-1]) + plt.xticks(xticks45) + plt.yticks(yticks) + for c in cases: + plt.plot(freqs45, c.hmap[hmap_range, hmap_range], label=c.label) + plt.grid() + plt.legend(loc='upper right') + + # Display or save. + if save is None: + plt.show() + else: + if os.path.dirname(save): + os.makedirs(os.path.dirname(save), exist_ok=True) + plt.savefig(save) + +#---------------------------------------------------------------------------- + +if __name__ == "__main__": + main() # pylint: disable=no-value-for-parameter + +#---------------------------------------------------------------------------- diff --git a/calc_metrics.py b/calc_metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..74a398a407f56e749e3a88eb9d8ff976191758f4 --- /dev/null +++ b/calc_metrics.py @@ -0,0 +1,188 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Calculate quality metrics for previous training run or pretrained network pickle.""" + +import os +import click +import json +import tempfile +import copy +import torch + +import dnnlib +import legacy +from metrics import metric_main +from metrics import metric_utils +from torch_utils import training_stats +from torch_utils import custom_ops +from torch_utils import misc +from torch_utils.ops import conv2d_gradfix + +#---------------------------------------------------------------------------- + +def subprocess_fn(rank, args, temp_dir): + dnnlib.util.Logger(should_flush=True) + + # Init torch.distributed. + if args.num_gpus > 1: + init_file = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init')) + if os.name == 'nt': + init_method = 'file:///' + init_file.replace('\\', '/') + torch.distributed.init_process_group(backend='gloo', init_method=init_method, rank=rank, world_size=args.num_gpus) + else: + init_method = f'file://{init_file}' + torch.distributed.init_process_group(backend='nccl', init_method=init_method, rank=rank, world_size=args.num_gpus) + + # Init torch_utils. + sync_device = torch.device('cuda', rank) if args.num_gpus > 1 else None + training_stats.init_multiprocessing(rank=rank, sync_device=sync_device) + if rank != 0 or not args.verbose: + custom_ops.verbosity = 'none' + + # Configure torch. + device = torch.device('cuda', rank) + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.allow_tf32 = False + conv2d_gradfix.enabled = True + + # Print network summary. + G = copy.deepcopy(args.G).eval().requires_grad_(False).to(device) + if rank == 0 and args.verbose: + z = torch.empty([1, G.z_dim], device=device) + c = torch.empty([1, G.c_dim], device=device) + misc.print_module_summary(G, [z, c]) + + # Calculate each metric. + for metric in args.metrics: + if rank == 0 and args.verbose: + print(f'Calculating {metric}...') + progress = metric_utils.ProgressMonitor(verbose=args.verbose) + result_dict = metric_main.calc_metric(metric=metric, G=G, dataset_kwargs=args.dataset_kwargs, + num_gpus=args.num_gpus, rank=rank, device=device, progress=progress) + if rank == 0: + metric_main.report_metric(result_dict, run_dir=args.run_dir, snapshot_pkl=args.network_pkl) + if rank == 0 and args.verbose: + print() + + # Done. + if rank == 0 and args.verbose: + print('Exiting...') + +#---------------------------------------------------------------------------- + +def parse_comma_separated_list(s): + if isinstance(s, list): + return s + if s is None or s.lower() == 'none' or s == '': + return [] + return s.split(',') + +#---------------------------------------------------------------------------- + +@click.command() +@click.pass_context +@click.option('network_pkl', '--network', help='Network pickle filename or URL', metavar='PATH', required=True) +@click.option('--metrics', help='Quality metrics', metavar='[NAME|A,B,C|none]', type=parse_comma_separated_list, default='fid50k_full', show_default=True) +@click.option('--data', help='Dataset to evaluate against [default: look up]', metavar='[ZIP|DIR]') +@click.option('--mirror', help='Enable dataset x-flips [default: look up]', type=bool, metavar='BOOL') +@click.option('--gpus', help='Number of GPUs to use', type=int, default=1, metavar='INT', show_default=True) +@click.option('--verbose', help='Print optional information', type=bool, default=True, metavar='BOOL', show_default=True) + +def calc_metrics(ctx, network_pkl, metrics, data, mirror, gpus, verbose): + """Calculate quality metrics for previous training run or pretrained network pickle. + + Examples: + + \b + # Previous training run: look up options automatically, save result to JSONL file. + python calc_metrics.py --metrics=eqt50k_int,eqr50k \\ + --network=~/training-runs/00000-stylegan3-r-mydataset/network-snapshot-000000.pkl + + \b + # Pre-trained network pickle: specify dataset explicitly, print result to stdout. + python calc_metrics.py --metrics=fid50k_full --data=~/datasets/ffhq-1024x1024.zip --mirror=1 \\ + --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-ffhq-1024x1024.pkl + + \b + Recommended metrics: + fid50k_full Frechet inception distance against the full dataset. + kid50k_full Kernel inception distance against the full dataset. + pr50k3_full Precision and recall againt the full dataset. + ppl2_wend Perceptual path length in W, endpoints, full image. + eqt50k_int Equivariance w.r.t. integer translation (EQ-T). + eqt50k_frac Equivariance w.r.t. fractional translation (EQ-T_frac). + eqr50k Equivariance w.r.t. rotation (EQ-R). + + \b + Legacy metrics: + fid50k Frechet inception distance against 50k real images. + kid50k Kernel inception distance against 50k real images. + pr50k3 Precision and recall against 50k real images. + is50k Inception score for CIFAR-10. + """ + dnnlib.util.Logger(should_flush=True) + + # Validate arguments. + args = dnnlib.EasyDict(metrics=metrics, num_gpus=gpus, network_pkl=network_pkl, verbose=verbose) + if not all(metric_main.is_valid_metric(metric) for metric in args.metrics): + ctx.fail('\n'.join(['--metrics can only contain the following values:'] + metric_main.list_valid_metrics())) + if not args.num_gpus >= 1: + ctx.fail('--gpus must be at least 1') + + # Load network. + if not dnnlib.util.is_url(network_pkl, allow_file_urls=True) and not os.path.isfile(network_pkl): + ctx.fail('--network must point to a file or URL') + if args.verbose: + print(f'Loading network from "{network_pkl}"...') + with dnnlib.util.open_url(network_pkl, verbose=args.verbose) as f: + network_dict = legacy.load_network_pkl(f) + args.G = network_dict['G_ema'] # subclass of torch.nn.Module + + # Initialize dataset options. + if data is not None: + args.dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageFolderDataset', path=data) + elif network_dict['training_set_kwargs'] is not None: + args.dataset_kwargs = dnnlib.EasyDict(network_dict['training_set_kwargs']) + else: + ctx.fail('Could not look up dataset options; please specify --data') + + # Finalize dataset options. + args.dataset_kwargs.resolution = args.G.img_resolution + args.dataset_kwargs.use_labels = (args.G.c_dim != 0) + if mirror is not None: + args.dataset_kwargs.xflip = mirror + + # Print dataset options. + if args.verbose: + print('Dataset options:') + print(json.dumps(args.dataset_kwargs, indent=2)) + + # Locate run dir. + args.run_dir = None + if os.path.isfile(network_pkl): + pkl_dir = os.path.dirname(network_pkl) + if os.path.isfile(os.path.join(pkl_dir, 'training_options.json')): + args.run_dir = pkl_dir + + # Launch processes. + if args.verbose: + print('Launching processes...') + torch.multiprocessing.set_start_method('spawn') + with tempfile.TemporaryDirectory() as temp_dir: + if args.num_gpus == 1: + subprocess_fn(rank=0, args=args, temp_dir=temp_dir) + else: + torch.multiprocessing.spawn(fn=subprocess_fn, args=(args, temp_dir), nprocs=args.num_gpus) + +#---------------------------------------------------------------------------- + +if __name__ == "__main__": + calc_metrics() # pylint: disable=no-value-for-parameter + +#---------------------------------------------------------------------------- diff --git a/dataset_tool.py b/dataset_tool.py new file mode 100644 index 0000000000000000000000000000000000000000..e9382fb1265489053eaed0166385a10ef67965c2 --- /dev/null +++ b/dataset_tool.py @@ -0,0 +1,455 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Tool for creating ZIP/PNG based datasets.""" + +import functools +import gzip +import io +import json +import os +import pickle +import re +import sys +import tarfile +import zipfile +from pathlib import Path +from typing import Callable, Optional, Tuple, Union + +import click +import numpy as np +import PIL.Image +from tqdm import tqdm + +#---------------------------------------------------------------------------- + +def error(msg): + print('Error: ' + msg) + sys.exit(1) + +#---------------------------------------------------------------------------- + +def parse_tuple(s: str) -> Tuple[int, int]: + '''Parse a 'M,N' or 'MxN' integer tuple. + + Example: + '4x2' returns (4,2) + '0,1' returns (0,1) + ''' + if m := re.match(r'^(\d+)[x,](\d+)$', s): + return (int(m.group(1)), int(m.group(2))) + raise ValueError(f'cannot parse tuple {s}') + +#---------------------------------------------------------------------------- + +def maybe_min(a: int, b: Optional[int]) -> int: + if b is not None: + return min(a, b) + return a + +#---------------------------------------------------------------------------- + +def file_ext(name: Union[str, Path]) -> str: + return str(name).split('.')[-1] + +#---------------------------------------------------------------------------- + +def is_image_ext(fname: Union[str, Path]) -> bool: + ext = file_ext(fname).lower() + return f'.{ext}' in PIL.Image.EXTENSION # type: ignore + +#---------------------------------------------------------------------------- + +def open_image_folder(source_dir, *, max_images: Optional[int]): + input_images = [str(f) for f in sorted(Path(source_dir).rglob('*')) if is_image_ext(f) and os.path.isfile(f)] + + # Load labels. + labels = {} + meta_fname = os.path.join(source_dir, 'dataset.json') + if os.path.isfile(meta_fname): + with open(meta_fname, 'r') as file: + labels = json.load(file)['labels'] + if labels is not None: + labels = { x[0]: x[1] for x in labels } + else: + labels = {} + + max_idx = maybe_min(len(input_images), max_images) + + def iterate_images(): + for idx, fname in enumerate(input_images): + arch_fname = os.path.relpath(fname, source_dir) + arch_fname = arch_fname.replace('\\', '/') + img = np.array(PIL.Image.open(fname)) + yield dict(img=img, label=labels.get(arch_fname)) + if idx >= max_idx-1: + break + return max_idx, iterate_images() + +#---------------------------------------------------------------------------- + +def open_image_zip(source, *, max_images: Optional[int]): + with zipfile.ZipFile(source, mode='r') as z: + input_images = [str(f) for f in sorted(z.namelist()) if is_image_ext(f)] + + # Load labels. + labels = {} + if 'dataset.json' in z.namelist(): + with z.open('dataset.json', 'r') as file: + labels = json.load(file)['labels'] + if labels is not None: + labels = { x[0]: x[1] for x in labels } + else: + labels = {} + + max_idx = maybe_min(len(input_images), max_images) + + def iterate_images(): + with zipfile.ZipFile(source, mode='r') as z: + for idx, fname in enumerate(input_images): + with z.open(fname, 'r') as file: + img = PIL.Image.open(file) # type: ignore + img = np.array(img) + yield dict(img=img, label=labels.get(fname)) + if idx >= max_idx-1: + break + return max_idx, iterate_images() + +#---------------------------------------------------------------------------- + +def open_lmdb(lmdb_dir: str, *, max_images: Optional[int]): + import cv2 # pip install opencv-python # pylint: disable=import-error + import lmdb # pip install lmdb # pylint: disable=import-error + + with lmdb.open(lmdb_dir, readonly=True, lock=False).begin(write=False) as txn: + max_idx = maybe_min(txn.stat()['entries'], max_images) + + def iterate_images(): + with lmdb.open(lmdb_dir, readonly=True, lock=False).begin(write=False) as txn: + for idx, (_key, value) in enumerate(txn.cursor()): + try: + try: + img = cv2.imdecode(np.frombuffer(value, dtype=np.uint8), 1) + if img is None: + raise IOError('cv2.imdecode failed') + img = img[:, :, ::-1] # BGR => RGB + except IOError: + img = np.array(PIL.Image.open(io.BytesIO(value))) + yield dict(img=img, label=None) + if idx >= max_idx-1: + break + except: + print(sys.exc_info()[1]) + + return max_idx, iterate_images() + +#---------------------------------------------------------------------------- + +def open_cifar10(tarball: str, *, max_images: Optional[int]): + images = [] + labels = [] + + with tarfile.open(tarball, 'r:gz') as tar: + for batch in range(1, 6): + member = tar.getmember(f'cifar-10-batches-py/data_batch_{batch}') + with tar.extractfile(member) as file: + data = pickle.load(file, encoding='latin1') + images.append(data['data'].reshape(-1, 3, 32, 32)) + labels.append(data['labels']) + + images = np.concatenate(images) + labels = np.concatenate(labels) + images = images.transpose([0, 2, 3, 1]) # NCHW -> NHWC + assert images.shape == (50000, 32, 32, 3) and images.dtype == np.uint8 + assert labels.shape == (50000,) and labels.dtype in [np.int32, np.int64] + assert np.min(images) == 0 and np.max(images) == 255 + assert np.min(labels) == 0 and np.max(labels) == 9 + + max_idx = maybe_min(len(images), max_images) + + def iterate_images(): + for idx, img in enumerate(images): + yield dict(img=img, label=int(labels[idx])) + if idx >= max_idx-1: + break + + return max_idx, iterate_images() + +#---------------------------------------------------------------------------- + +def open_mnist(images_gz: str, *, max_images: Optional[int]): + labels_gz = images_gz.replace('-images-idx3-ubyte.gz', '-labels-idx1-ubyte.gz') + assert labels_gz != images_gz + images = [] + labels = [] + + with gzip.open(images_gz, 'rb') as f: + images = np.frombuffer(f.read(), np.uint8, offset=16) + with gzip.open(labels_gz, 'rb') as f: + labels = np.frombuffer(f.read(), np.uint8, offset=8) + + images = images.reshape(-1, 28, 28) + images = np.pad(images, [(0,0), (2,2), (2,2)], 'constant', constant_values=0) + assert images.shape == (60000, 32, 32) and images.dtype == np.uint8 + assert labels.shape == (60000,) and labels.dtype == np.uint8 + assert np.min(images) == 0 and np.max(images) == 255 + assert np.min(labels) == 0 and np.max(labels) == 9 + + max_idx = maybe_min(len(images), max_images) + + def iterate_images(): + for idx, img in enumerate(images): + yield dict(img=img, label=int(labels[idx])) + if idx >= max_idx-1: + break + + return max_idx, iterate_images() + +#---------------------------------------------------------------------------- + +def make_transform( + transform: Optional[str], + output_width: Optional[int], + output_height: Optional[int] +) -> Callable[[np.ndarray], Optional[np.ndarray]]: + def scale(width, height, img): + w = img.shape[1] + h = img.shape[0] + if width == w and height == h: + return img + img = PIL.Image.fromarray(img) + ww = width if width is not None else w + hh = height if height is not None else h + img = img.resize((ww, hh), PIL.Image.LANCZOS) + return np.array(img) + + def center_crop(width, height, img): + crop = np.min(img.shape[:2]) + img = img[(img.shape[0] - crop) // 2 : (img.shape[0] + crop) // 2, (img.shape[1] - crop) // 2 : (img.shape[1] + crop) // 2] + img = PIL.Image.fromarray(img, 'RGB') + img = img.resize((width, height), PIL.Image.LANCZOS) + return np.array(img) + + def center_crop_wide(width, height, img): + ch = int(np.round(width * img.shape[0] / img.shape[1])) + if img.shape[1] < width or ch < height: + return None + + img = img[(img.shape[0] - ch) // 2 : (img.shape[0] + ch) // 2] + img = PIL.Image.fromarray(img, 'RGB') + img = img.resize((width, height), PIL.Image.LANCZOS) + img = np.array(img) + + canvas = np.zeros([width, width, 3], dtype=np.uint8) + canvas[(width - height) // 2 : (width + height) // 2, :] = img + return canvas + + if transform is None: + return functools.partial(scale, output_width, output_height) + if transform == 'center-crop': + if (output_width is None) or (output_height is None): + error ('must specify --resolution=WxH when using ' + transform + 'transform') + return functools.partial(center_crop, output_width, output_height) + if transform == 'center-crop-wide': + if (output_width is None) or (output_height is None): + error ('must specify --resolution=WxH when using ' + transform + ' transform') + return functools.partial(center_crop_wide, output_width, output_height) + assert False, 'unknown transform' + +#---------------------------------------------------------------------------- + +def open_dataset(source, *, max_images: Optional[int]): + if os.path.isdir(source): + if source.rstrip('/').endswith('_lmdb'): + return open_lmdb(source, max_images=max_images) + else: + return open_image_folder(source, max_images=max_images) + elif os.path.isfile(source): + if os.path.basename(source) == 'cifar-10-python.tar.gz': + return open_cifar10(source, max_images=max_images) + elif os.path.basename(source) == 'train-images-idx3-ubyte.gz': + return open_mnist(source, max_images=max_images) + elif file_ext(source) == 'zip': + return open_image_zip(source, max_images=max_images) + else: + assert False, 'unknown archive type' + else: + error(f'Missing input file or directory: {source}') + +#---------------------------------------------------------------------------- + +def open_dest(dest: str) -> Tuple[str, Callable[[str, Union[bytes, str]], None], Callable[[], None]]: + dest_ext = file_ext(dest) + + if dest_ext == 'zip': + if os.path.dirname(dest) != '': + os.makedirs(os.path.dirname(dest), exist_ok=True) + zf = zipfile.ZipFile(file=dest, mode='w', compression=zipfile.ZIP_STORED) + def zip_write_bytes(fname: str, data: Union[bytes, str]): + zf.writestr(fname, data) + return '', zip_write_bytes, zf.close + else: + # If the output folder already exists, check that is is + # empty. + # + # Note: creating the output directory is not strictly + # necessary as folder_write_bytes() also mkdirs, but it's better + # to give an error message earlier in case the dest folder + # somehow cannot be created. + if os.path.isdir(dest) and len(os.listdir(dest)) != 0: + error('--dest folder must be empty') + os.makedirs(dest, exist_ok=True) + + def folder_write_bytes(fname: str, data: Union[bytes, str]): + os.makedirs(os.path.dirname(fname), exist_ok=True) + with open(fname, 'wb') as fout: + if isinstance(data, str): + data = data.encode('utf8') + fout.write(data) + return dest, folder_write_bytes, lambda: None + +#---------------------------------------------------------------------------- + +@click.command() +@click.pass_context +@click.option('--source', help='Directory or archive name for input dataset', required=True, metavar='PATH') +@click.option('--dest', help='Output directory or archive name for output dataset', required=True, metavar='PATH') +@click.option('--max-images', help='Output only up to `max-images` images', type=int, default=None) +@click.option('--transform', help='Input crop/resize mode', type=click.Choice(['center-crop', 'center-crop-wide'])) +@click.option('--resolution', help='Output resolution (e.g., \'512x512\')', metavar='WxH', type=parse_tuple) +def convert_dataset( + ctx: click.Context, + source: str, + dest: str, + max_images: Optional[int], + transform: Optional[str], + resolution: Optional[Tuple[int, int]] +): + """Convert an image dataset into a dataset archive usable with StyleGAN2 ADA PyTorch. + + The input dataset format is guessed from the --source argument: + + \b + --source *_lmdb/ Load LSUN dataset + --source cifar-10-python.tar.gz Load CIFAR-10 dataset + --source train-images-idx3-ubyte.gz Load MNIST dataset + --source path/ Recursively load all images from path/ + --source dataset.zip Recursively load all images from dataset.zip + + Specifying the output format and path: + + \b + --dest /path/to/dir Save output files under /path/to/dir + --dest /path/to/dataset.zip Save output files into /path/to/dataset.zip + + The output dataset format can be either an image folder or an uncompressed zip archive. + Zip archives makes it easier to move datasets around file servers and clusters, and may + offer better training performance on network file systems. + + Images within the dataset archive will be stored as uncompressed PNG. + Uncompresed PNGs can be efficiently decoded in the training loop. + + Class labels are stored in a file called 'dataset.json' that is stored at the + dataset root folder. This file has the following structure: + + \b + { + "labels": [ + ["00000/img00000000.png",6], + ["00000/img00000001.png",9], + ... repeated for every image in the datase + ["00049/img00049999.png",1] + ] + } + + If the 'dataset.json' file cannot be found, the dataset is interpreted as + not containing class labels. + + Image scale/crop and resolution requirements: + + Output images must be square-shaped and they must all have the same power-of-two + dimensions. + + To scale arbitrary input image size to a specific width and height, use the + --resolution option. Output resolution will be either the original + input resolution (if resolution was not specified) or the one specified with + --resolution option. + + Use the --transform=center-crop or --transform=center-crop-wide options to apply a + center crop transform on the input image. These options should be used with the + --resolution option. For example: + + \b + python dataset_tool.py --source LSUN/raw/cat_lmdb --dest /tmp/lsun_cat \\ + --transform=center-crop-wide --resolution=512x384 + """ + + PIL.Image.init() # type: ignore + + if dest == '': + ctx.fail('--dest output filename or directory must not be an empty string') + + num_files, input_iter = open_dataset(source, max_images=max_images) + archive_root_dir, save_bytes, close_dest = open_dest(dest) + + if resolution is None: resolution = (None, None) + transform_image = make_transform(transform, *resolution) + + dataset_attrs = None + + labels = [] + for idx, image in tqdm(enumerate(input_iter), total=num_files): + idx_str = f'{idx:08d}' + archive_fname = f'{idx_str[:5]}/img{idx_str}.png' + + # Apply crop and resize. + img = transform_image(image['img']) + + # Transform may drop images. + if img is None: + continue + + # Error check to require uniform image attributes across + # the whole dataset. + channels = img.shape[2] if img.ndim == 3 else 1 + cur_image_attrs = { + 'width': img.shape[1], + 'height': img.shape[0], + 'channels': channels + } + if dataset_attrs is None: + dataset_attrs = cur_image_attrs + width = dataset_attrs['width'] + height = dataset_attrs['height'] + if width != height: + error(f'Image dimensions after scale and crop are required to be square. Got {width}x{height}') + if dataset_attrs['channels'] not in [1, 3]: + error('Input images must be stored as RGB or grayscale') + if width != 2 ** int(np.floor(np.log2(width))): + error('Image width/height after scale and crop are required to be power-of-two') + elif dataset_attrs != cur_image_attrs: + err = [f' dataset {k}/cur image {k}: {dataset_attrs[k]}/{cur_image_attrs[k]}' for k in dataset_attrs.keys()] # pylint: disable=unsubscriptable-object + error(f'Image {archive_fname} attributes must be equal across all images of the dataset. Got:\n' + '\n'.join(err)) + + # Save the image as an uncompressed PNG. + img = PIL.Image.fromarray(img, { 1: 'L', 3: 'RGB' }[channels]) + image_bits = io.BytesIO() + img.save(image_bits, format='png', compress_level=0, optimize=False) + save_bytes(os.path.join(archive_root_dir, archive_fname), image_bits.getbuffer()) + labels.append([archive_fname, image['label']] if image['label'] is not None else None) + + metadata = { + 'labels': labels if all(x is not None for x in labels) else None + } + save_bytes(os.path.join(archive_root_dir, 'dataset.json'), json.dumps(metadata)) + close_dest() + +#---------------------------------------------------------------------------- + +if __name__ == "__main__": + convert_dataset() # pylint: disable=no-value-for-parameter diff --git a/dnnlib/__init__.py b/dnnlib/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e0a006715176e91a5ed94a5d2362a87d53b4d889 --- /dev/null +++ b/dnnlib/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +from .util import EasyDict, make_cache_dir_path diff --git a/dnnlib/util.py b/dnnlib/util.py new file mode 100644 index 0000000000000000000000000000000000000000..191b52f6ac7ad75344fb3921f03c37987047287c --- /dev/null +++ b/dnnlib/util.py @@ -0,0 +1,491 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Miscellaneous utility classes and functions.""" + +import ctypes +import fnmatch +import importlib +import inspect +import numpy as np +import os +import shutil +import sys +import types +import io +import pickle +import re +import requests +import html +import hashlib +import glob +import tempfile +import urllib +import urllib.request +import uuid + +from distutils.util import strtobool +from typing import Any, List, Tuple, Union + + +# Util classes +# ------------------------------------------------------------------------------------------ + + +class EasyDict(dict): + """Convenience class that behaves like a dict but allows access with the attribute syntax.""" + + def __getattr__(self, name: str) -> Any: + try: + return self[name] + except KeyError: + raise AttributeError(name) + + def __setattr__(self, name: str, value: Any) -> None: + self[name] = value + + def __delattr__(self, name: str) -> None: + del self[name] + + +class Logger(object): + """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file.""" + + def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True): + self.file = None + + if file_name is not None: + self.file = open(file_name, file_mode) + + self.should_flush = should_flush + self.stdout = sys.stdout + self.stderr = sys.stderr + + sys.stdout = self + sys.stderr = self + + def __enter__(self) -> "Logger": + return self + + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + self.close() + + def write(self, text: Union[str, bytes]) -> None: + """Write text to stdout (and a file) and optionally flush.""" + if isinstance(text, bytes): + text = text.decode() + if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash + return + + if self.file is not None: + self.file.write(text) + + self.stdout.write(text) + + if self.should_flush: + self.flush() + + def flush(self) -> None: + """Flush written text to both stdout and a file, if open.""" + if self.file is not None: + self.file.flush() + + self.stdout.flush() + + def close(self) -> None: + """Flush, close possible files, and remove stdout/stderr mirroring.""" + self.flush() + + # if using multiple loggers, prevent closing in wrong order + if sys.stdout is self: + sys.stdout = self.stdout + if sys.stderr is self: + sys.stderr = self.stderr + + if self.file is not None: + self.file.close() + self.file = None + + +# Cache directories +# ------------------------------------------------------------------------------------------ + +_dnnlib_cache_dir = None + +def set_cache_dir(path: str) -> None: + global _dnnlib_cache_dir + _dnnlib_cache_dir = path + +def make_cache_dir_path(*paths: str) -> str: + if _dnnlib_cache_dir is not None: + return os.path.join(_dnnlib_cache_dir, *paths) + if 'DNNLIB_CACHE_DIR' in os.environ: + return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths) + if 'HOME' in os.environ: + return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths) + if 'USERPROFILE' in os.environ: + return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths) + return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths) + +# Small util functions +# ------------------------------------------------------------------------------------------ + + +def format_time(seconds: Union[int, float]) -> str: + """Convert the seconds to human readable string with days, hours, minutes and seconds.""" + s = int(np.rint(seconds)) + + if s < 60: + return "{0}s".format(s) + elif s < 60 * 60: + return "{0}m {1:02}s".format(s // 60, s % 60) + elif s < 24 * 60 * 60: + return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60) + else: + return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60) + + +def format_time_brief(seconds: Union[int, float]) -> str: + """Convert the seconds to human readable string with days, hours, minutes and seconds.""" + s = int(np.rint(seconds)) + + if s < 60: + return "{0}s".format(s) + elif s < 60 * 60: + return "{0}m {1:02}s".format(s // 60, s % 60) + elif s < 24 * 60 * 60: + return "{0}h {1:02}m".format(s // (60 * 60), (s // 60) % 60) + else: + return "{0}d {1:02}h".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24) + + +def ask_yes_no(question: str) -> bool: + """Ask the user the question until the user inputs a valid answer.""" + while True: + try: + print("{0} [y/n]".format(question)) + return strtobool(input().lower()) + except ValueError: + pass + + +def tuple_product(t: Tuple) -> Any: + """Calculate the product of the tuple elements.""" + result = 1 + + for v in t: + result *= v + + return result + + +_str_to_ctype = { + "uint8": ctypes.c_ubyte, + "uint16": ctypes.c_uint16, + "uint32": ctypes.c_uint32, + "uint64": ctypes.c_uint64, + "int8": ctypes.c_byte, + "int16": ctypes.c_int16, + "int32": ctypes.c_int32, + "int64": ctypes.c_int64, + "float32": ctypes.c_float, + "float64": ctypes.c_double +} + + +def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]: + """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes.""" + type_str = None + + if isinstance(type_obj, str): + type_str = type_obj + elif hasattr(type_obj, "__name__"): + type_str = type_obj.__name__ + elif hasattr(type_obj, "name"): + type_str = type_obj.name + else: + raise RuntimeError("Cannot infer type name from input") + + assert type_str in _str_to_ctype.keys() + + my_dtype = np.dtype(type_str) + my_ctype = _str_to_ctype[type_str] + + assert my_dtype.itemsize == ctypes.sizeof(my_ctype) + + return my_dtype, my_ctype + + +def is_pickleable(obj: Any) -> bool: + try: + with io.BytesIO() as stream: + pickle.dump(obj, stream) + return True + except: + return False + + +# Functionality to import modules/objects by name, and call functions by name +# ------------------------------------------------------------------------------------------ + +def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]: + """Searches for the underlying module behind the name to some python object. + Returns the module and the object name (original name with module part removed).""" + + # allow convenience shorthands, substitute them by full names + obj_name = re.sub("^np.", "numpy.", obj_name) + obj_name = re.sub("^tf.", "tensorflow.", obj_name) + + # list alternatives for (module_name, local_obj_name) + parts = obj_name.split(".") + name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)] + + # try each alternative in turn + for module_name, local_obj_name in name_pairs: + try: + module = importlib.import_module(module_name) # may raise ImportError + get_obj_from_module(module, local_obj_name) # may raise AttributeError + return module, local_obj_name + except: + pass + + # maybe some of the modules themselves contain errors? + for module_name, _local_obj_name in name_pairs: + try: + importlib.import_module(module_name) # may raise ImportError + except ImportError: + if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"): + raise + + # maybe the requested attribute is missing? + for module_name, local_obj_name in name_pairs: + try: + module = importlib.import_module(module_name) # may raise ImportError + get_obj_from_module(module, local_obj_name) # may raise AttributeError + except ImportError: + pass + + # we are out of luck, but we have no idea why + raise ImportError(obj_name) + + +def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any: + """Traverses the object name and returns the last (rightmost) python object.""" + if obj_name == '': + return module + obj = module + for part in obj_name.split("."): + obj = getattr(obj, part) + return obj + + +def get_obj_by_name(name: str) -> Any: + """Finds the python object with the given name.""" + module, obj_name = get_module_from_obj_name(name) + return get_obj_from_module(module, obj_name) + + +def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any: + """Finds the python object with the given name and calls it as a function.""" + assert func_name is not None + func_obj = get_obj_by_name(func_name) + assert callable(func_obj) + return func_obj(*args, **kwargs) + + +def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any: + """Finds the python class with the given name and constructs it with the given arguments.""" + return call_func_by_name(*args, func_name=class_name, **kwargs) + + +def get_module_dir_by_obj_name(obj_name: str) -> str: + """Get the directory path of the module containing the given object name.""" + module, _ = get_module_from_obj_name(obj_name) + return os.path.dirname(inspect.getfile(module)) + + +def is_top_level_function(obj: Any) -> bool: + """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'.""" + return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__ + + +def get_top_level_function_name(obj: Any) -> str: + """Return the fully-qualified name of a top-level function.""" + assert is_top_level_function(obj) + module = obj.__module__ + if module == '__main__': + module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0] + return module + "." + obj.__name__ + + +# File system helpers +# ------------------------------------------------------------------------------------------ + +def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]: + """List all files recursively in a given directory while ignoring given file and directory names. + Returns list of tuples containing both absolute and relative paths.""" + assert os.path.isdir(dir_path) + base_name = os.path.basename(os.path.normpath(dir_path)) + + if ignores is None: + ignores = [] + + result = [] + + for root, dirs, files in os.walk(dir_path, topdown=True): + for ignore_ in ignores: + dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)] + + # dirs need to be edited in-place + for d in dirs_to_remove: + dirs.remove(d) + + files = [f for f in files if not fnmatch.fnmatch(f, ignore_)] + + absolute_paths = [os.path.join(root, f) for f in files] + relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths] + + if add_base_to_relative: + relative_paths = [os.path.join(base_name, p) for p in relative_paths] + + assert len(absolute_paths) == len(relative_paths) + result += zip(absolute_paths, relative_paths) + + return result + + +def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None: + """Takes in a list of tuples of (src, dst) paths and copies files. + Will create all necessary directories.""" + for file in files: + target_dir_name = os.path.dirname(file[1]) + + # will create all intermediate-level directories + if not os.path.exists(target_dir_name): + os.makedirs(target_dir_name) + + shutil.copyfile(file[0], file[1]) + + +# URL helpers +# ------------------------------------------------------------------------------------------ + +def is_url(obj: Any, allow_file_urls: bool = False) -> bool: + """Determine whether the given object is a valid URL string.""" + if not isinstance(obj, str) or not "://" in obj: + return False + if allow_file_urls and obj.startswith('file://'): + return True + try: + res = requests.compat.urlparse(obj) + if not res.scheme or not res.netloc or not "." in res.netloc: + return False + res = requests.compat.urlparse(requests.compat.urljoin(obj, "/")) + if not res.scheme or not res.netloc or not "." in res.netloc: + return False + except: + return False + return True + + +def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any: + """Download the given URL and return a binary-mode file object to access the data.""" + assert num_attempts >= 1 + assert not (return_filename and (not cache)) + + # Doesn't look like an URL scheme so interpret it as a local filename. + if not re.match('^[a-z]+://', url): + return url if return_filename else open(url, "rb") + + # Handle file URLs. This code handles unusual file:// patterns that + # arise on Windows: + # + # file:///c:/foo.txt + # + # which would translate to a local '/c:/foo.txt' filename that's + # invalid. Drop the forward slash for such pathnames. + # + # If you touch this code path, you should test it on both Linux and + # Windows. + # + # Some internet resources suggest using urllib.request.url2pathname() but + # but that converts forward slashes to backslashes and this causes + # its own set of problems. + if url.startswith('file://'): + filename = urllib.parse.urlparse(url).path + if re.match(r'^/[a-zA-Z]:', filename): + filename = filename[1:] + return filename if return_filename else open(filename, "rb") + + assert is_url(url) + + # Lookup from cache. + if cache_dir is None: + cache_dir = make_cache_dir_path('downloads') + + url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest() + if cache: + cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*")) + if len(cache_files) == 1: + filename = cache_files[0] + return filename if return_filename else open(filename, "rb") + + # Download. + url_name = None + url_data = None + with requests.Session() as session: + if verbose: + print("Downloading %s ..." % url, end="", flush=True) + for attempts_left in reversed(range(num_attempts)): + try: + with session.get(url) as res: + res.raise_for_status() + if len(res.content) == 0: + raise IOError("No data received") + + if len(res.content) < 8192: + content_str = res.content.decode("utf-8") + if "download_warning" in res.headers.get("Set-Cookie", ""): + links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link] + if len(links) == 1: + url = requests.compat.urljoin(url, links[0]) + raise IOError("Google Drive virus checker nag") + if "Google Drive - Quota exceeded" in content_str: + raise IOError("Google Drive download quota exceeded -- please try again later") + + match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", "")) + url_name = match[1] if match else url + url_data = res.content + if verbose: + print(" done") + break + except KeyboardInterrupt: + raise + except: + if not attempts_left: + if verbose: + print(" failed") + raise + if verbose: + print(".", end="", flush=True) + + # Save to cache. + if cache: + safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name) + cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name) + temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name) + os.makedirs(cache_dir, exist_ok=True) + with open(temp_file, "wb") as f: + f.write(url_data) + os.replace(temp_file, cache_file) # atomic + if return_filename: + return cache_file + + # Return data as file object. + assert not return_filename + return io.BytesIO(url_data) diff --git a/environment.yml b/environment.yml new file mode 100644 index 0000000000000000000000000000000000000000..578a58a00e6588a0205ccdd55801824ba0ec1922 --- /dev/null +++ b/environment.yml @@ -0,0 +1,24 @@ +name: stylegan3 +channels: + - pytorch + - nvidia +dependencies: + - python >= 3.8 + - pip + - numpy>=1.20 + - click>=8.0 + - pillow=8.3.1 + - scipy=1.7.1 + - pytorch=1.9.1 + - cudatoolkit=11.1 + - requests=2.26.0 + - tqdm=4.62.2 + - ninja=1.10.2 + - matplotlib=3.4.2 + - imageio=2.9.0 + - pip: + - imgui==1.3.0 + - glfw==2.2.0 + - pyopengl==3.1.5 + - imageio-ffmpeg==0.4.3 + - pyspng diff --git a/gen_images.py b/gen_images.py new file mode 100644 index 0000000000000000000000000000000000000000..f8a4b11b2fcdbad986a21753a29d7fee2fc26dbd --- /dev/null +++ b/gen_images.py @@ -0,0 +1,144 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Generate images using pretrained network pickle.""" + +import os +import re +from typing import List, Optional, Tuple, Union + +import click +import dnnlib +import numpy as np +import PIL.Image +import torch + +import legacy + +#---------------------------------------------------------------------------- + +def parse_range(s: Union[str, List]) -> List[int]: + '''Parse a comma separated list of numbers or ranges and return a list of ints. + + Example: '1,2,5-10' returns [1, 2, 5, 6, 7] + ''' + if isinstance(s, list): return s + ranges = [] + range_re = re.compile(r'^(\d+)-(\d+)$') + for p in s.split(','): + if m := range_re.match(p): + ranges.extend(range(int(m.group(1)), int(m.group(2))+1)) + else: + ranges.append(int(p)) + return ranges + +#---------------------------------------------------------------------------- + +def parse_vec2(s: Union[str, Tuple[float, float]]) -> Tuple[float, float]: + '''Parse a floating point 2-vector of syntax 'a,b'. + + Example: + '0,1' returns (0,1) + ''' + if isinstance(s, tuple): return s + parts = s.split(',') + if len(parts) == 2: + return (float(parts[0]), float(parts[1])) + raise ValueError(f'cannot parse 2-vector {s}') + +#---------------------------------------------------------------------------- + +def make_transform(translate: Tuple[float,float], angle: float): + m = np.eye(3) + s = np.sin(angle/360.0*np.pi*2) + c = np.cos(angle/360.0*np.pi*2) + m[0][0] = c + m[0][1] = s + m[0][2] = translate[0] + m[1][0] = -s + m[1][1] = c + m[1][2] = translate[1] + return m + +#---------------------------------------------------------------------------- + +@click.command() +@click.option('--network', 'network_pkl', help='Network pickle filename', required=True) +@click.option('--seeds', type=parse_range, help='List of random seeds (e.g., \'0,1,4-6\')', required=True) +@click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True) +@click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)') +@click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True) +@click.option('--translate', help='Translate XY-coordinate (e.g. \'0.3,1\')', type=parse_vec2, default='0,0', show_default=True, metavar='VEC2') +@click.option('--rotate', help='Rotation angle in degrees', type=float, default=0, show_default=True, metavar='ANGLE') +@click.option('--outdir', help='Where to save the output images', type=str, required=True, metavar='DIR') +def generate_images( + network_pkl: str, + seeds: List[int], + truncation_psi: float, + noise_mode: str, + outdir: str, + translate: Tuple[float,float], + rotate: float, + class_idx: Optional[int] +): + """Generate images using pretrained network pickle. + + Examples: + + \b + # Generate an image using pre-trained AFHQv2 model ("Ours" in Figure 1, left). + python gen_images.py --outdir=out --trunc=1 --seeds=2 \\ + --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-afhqv2-512x512.pkl + + \b + # Generate uncurated images with truncation using the MetFaces-U dataset + python gen_images.py --outdir=out --trunc=0.7 --seeds=600-605 \\ + --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-metfacesu-1024x1024.pkl + """ + + print('Loading networks from "%s"...' % network_pkl) + device = torch.device('cuda') + with dnnlib.util.open_url(network_pkl) as f: + G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore + + os.makedirs(outdir, exist_ok=True) + + # Labels. + label = torch.zeros([1, G.c_dim], device=device) + if G.c_dim != 0: + if class_idx is None: + raise click.ClickException('Must specify class label with --class when using a conditional network') + label[:, class_idx] = 1 + else: + if class_idx is not None: + print ('warn: --class=lbl ignored when running on an unconditional network') + + # Generate images. + for seed_idx, seed in enumerate(seeds): + print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx, len(seeds))) + z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device) + + # Construct an inverse rotation/translation matrix and pass to the generator. The + # generator expects this matrix as an inverse to avoid potentially failing numerical + # operations in the network. + if hasattr(G.synthesis, 'input'): + m = make_transform(translate, rotate) + m = np.linalg.inv(m) + G.synthesis.input.transform.copy_(torch.from_numpy(m)) + + img = G(z, label, truncation_psi=truncation_psi, noise_mode=noise_mode) + img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8) + PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(f'{outdir}/seed{seed:04d}.png') + + +#---------------------------------------------------------------------------- + +if __name__ == "__main__": + generate_images() # pylint: disable=no-value-for-parameter + +#---------------------------------------------------------------------------- diff --git a/gen_video.py b/gen_video.py new file mode 100644 index 0000000000000000000000000000000000000000..7a4bcc0ea7669530fa2727392e4c09500d8eed5e --- /dev/null +++ b/gen_video.py @@ -0,0 +1,178 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Generate lerp videos using pretrained network pickle.""" + +import copy +import os +import re +from typing import List, Optional, Tuple, Union + +import click +import dnnlib +import imageio +import numpy as np +import scipy.interpolate +import torch +from tqdm import tqdm + +import legacy + +#---------------------------------------------------------------------------- + +def layout_grid(img, grid_w=None, grid_h=1, float_to_uint8=True, chw_to_hwc=True, to_numpy=True): + batch_size, channels, img_h, img_w = img.shape + if grid_w is None: + grid_w = batch_size // grid_h + assert batch_size == grid_w * grid_h + if float_to_uint8: + img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8) + img = img.reshape(grid_h, grid_w, channels, img_h, img_w) + img = img.permute(2, 0, 3, 1, 4) + img = img.reshape(channels, grid_h * img_h, grid_w * img_w) + if chw_to_hwc: + img = img.permute(1, 2, 0) + if to_numpy: + img = img.cpu().numpy() + return img + +#---------------------------------------------------------------------------- + +def gen_interp_video(G, mp4: str, seeds, shuffle_seed=None, w_frames=60*4, kind='cubic', grid_dims=(1,1), num_keyframes=None, wraps=2, psi=1, device=torch.device('cuda'), **video_kwargs): + grid_w = grid_dims[0] + grid_h = grid_dims[1] + + if num_keyframes is None: + if len(seeds) % (grid_w*grid_h) != 0: + raise ValueError('Number of input seeds must be divisible by grid W*H') + num_keyframes = len(seeds) // (grid_w*grid_h) + + all_seeds = np.zeros(num_keyframes*grid_h*grid_w, dtype=np.int64) + for idx in range(num_keyframes*grid_h*grid_w): + all_seeds[idx] = seeds[idx % len(seeds)] + + if shuffle_seed is not None: + rng = np.random.RandomState(seed=shuffle_seed) + rng.shuffle(all_seeds) + + zs = torch.from_numpy(np.stack([np.random.RandomState(seed).randn(G.z_dim) for seed in all_seeds])).to(device) + ws = G.mapping(z=zs, c=None, truncation_psi=psi) + _ = G.synthesis(ws[:1]) # warm up + ws = ws.reshape(grid_h, grid_w, num_keyframes, *ws.shape[1:]) + + # Interpolation. + grid = [] + for yi in range(grid_h): + row = [] + for xi in range(grid_w): + x = np.arange(-num_keyframes * wraps, num_keyframes * (wraps + 1)) + y = np.tile(ws[yi][xi].cpu().numpy(), [wraps * 2 + 1, 1, 1]) + interp = scipy.interpolate.interp1d(x, y, kind=kind, axis=0) + row.append(interp) + grid.append(row) + + # Render video. + video_out = imageio.get_writer(mp4, mode='I', fps=60, codec='libx264', **video_kwargs) + for frame_idx in tqdm(range(num_keyframes * w_frames)): + imgs = [] + for yi in range(grid_h): + for xi in range(grid_w): + interp = grid[yi][xi] + w = torch.from_numpy(interp(frame_idx / w_frames)).to(device) + img = G.synthesis(ws=w.unsqueeze(0), noise_mode='const')[0] + imgs.append(img) + video_out.append_data(layout_grid(torch.stack(imgs), grid_w=grid_w, grid_h=grid_h)) + video_out.close() + +#---------------------------------------------------------------------------- + +def parse_range(s: Union[str, List[int]]) -> List[int]: + '''Parse a comma separated list of numbers or ranges and return a list of ints. + + Example: '1,2,5-10' returns [1, 2, 5, 6, 7] + ''' + if isinstance(s, list): return s + ranges = [] + range_re = re.compile(r'^(\d+)-(\d+)$') + for p in s.split(','): + if m := range_re.match(p): + ranges.extend(range(int(m.group(1)), int(m.group(2))+1)) + else: + ranges.append(int(p)) + return ranges + +#---------------------------------------------------------------------------- + +def parse_tuple(s: Union[str, Tuple[int,int]]) -> Tuple[int, int]: + '''Parse a 'M,N' or 'MxN' integer tuple. + + Example: + '4x2' returns (4,2) + '0,1' returns (0,1) + ''' + if isinstance(s, tuple): return s + if m := re.match(r'^(\d+)[x,](\d+)$', s): + return (int(m.group(1)), int(m.group(2))) + raise ValueError(f'cannot parse tuple {s}') + +#---------------------------------------------------------------------------- + +@click.command() +@click.option('--network', 'network_pkl', help='Network pickle filename', required=True) +@click.option('--seeds', type=parse_range, help='List of random seeds', required=True) +@click.option('--shuffle-seed', type=int, help='Random seed to use for shuffling seed order', default=None) +@click.option('--grid', type=parse_tuple, help='Grid width/height, e.g. \'4x3\' (default: 1x1)', default=(1,1)) +@click.option('--num-keyframes', type=int, help='Number of seeds to interpolate through. If not specified, determine based on the length of the seeds array given by --seeds.', default=None) +@click.option('--w-frames', type=int, help='Number of frames to interpolate between latents', default=120) +@click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True) +@click.option('--output', help='Output .mp4 filename', type=str, required=True, metavar='FILE') +def generate_images( + network_pkl: str, + seeds: List[int], + shuffle_seed: Optional[int], + truncation_psi: float, + grid: Tuple[int,int], + num_keyframes: Optional[int], + w_frames: int, + output: str +): + """Render a latent vector interpolation video. + + Examples: + + \b + # Render a 4x2 grid of interpolations for seeds 0 through 31. + python gen_video.py --output=lerp.mp4 --trunc=1 --seeds=0-31 --grid=4x2 \\ + --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-afhqv2-512x512.pkl + + Animation length and seed keyframes: + + The animation length is either determined based on the --seeds value or explicitly + specified using the --num-keyframes option. + + When num keyframes is specified with --num-keyframes, the output video length + will be 'num_keyframes*w_frames' frames. + + If --num-keyframes is not specified, the number of seeds given with + --seeds must be divisible by grid size W*H (--grid). In this case the + output video length will be '# seeds/(w*h)*w_frames' frames. + """ + + print('Loading networks from "%s"...' % network_pkl) + device = torch.device('cuda') + with dnnlib.util.open_url(network_pkl) as f: + G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore + + gen_interp_video(G=G, mp4=output, bitrate='12M', grid_dims=grid, num_keyframes=num_keyframes, w_frames=w_frames, seeds=seeds, shuffle_seed=shuffle_seed, psi=truncation_psi) + +#---------------------------------------------------------------------------- + +if __name__ == "__main__": + generate_images() # pylint: disable=no-value-for-parameter + +#---------------------------------------------------------------------------- diff --git a/gui_utils/__init__.py b/gui_utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8dd34882519598c472f1224cfe68c9ff6952ce69 --- /dev/null +++ b/gui_utils/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +# empty diff --git a/gui_utils/gl_utils.py b/gui_utils/gl_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2dd8bd96f4b74bdfb274f25d810ddfd43dc44068 --- /dev/null +++ b/gui_utils/gl_utils.py @@ -0,0 +1,374 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +import os +import functools +import contextlib +import numpy as np +import OpenGL.GL as gl +import OpenGL.GL.ARB.texture_float +import dnnlib + +#---------------------------------------------------------------------------- + +def init_egl(): + assert os.environ['PYOPENGL_PLATFORM'] == 'egl' # Must be set before importing OpenGL. + import OpenGL.EGL as egl + import ctypes + + # Initialize EGL. + display = egl.eglGetDisplay(egl.EGL_DEFAULT_DISPLAY) + assert display != egl.EGL_NO_DISPLAY + major = ctypes.c_int32() + minor = ctypes.c_int32() + ok = egl.eglInitialize(display, major, minor) + assert ok + assert major.value * 10 + minor.value >= 14 + + # Choose config. + config_attribs = [ + egl.EGL_RENDERABLE_TYPE, egl.EGL_OPENGL_BIT, + egl.EGL_SURFACE_TYPE, egl.EGL_PBUFFER_BIT, + egl.EGL_NONE + ] + configs = (ctypes.c_int32 * 1)() + num_configs = ctypes.c_int32() + ok = egl.eglChooseConfig(display, config_attribs, configs, 1, num_configs) + assert ok + assert num_configs.value == 1 + config = configs[0] + + # Create dummy pbuffer surface. + surface_attribs = [ + egl.EGL_WIDTH, 1, + egl.EGL_HEIGHT, 1, + egl.EGL_NONE + ] + surface = egl.eglCreatePbufferSurface(display, config, surface_attribs) + assert surface != egl.EGL_NO_SURFACE + + # Setup GL context. + ok = egl.eglBindAPI(egl.EGL_OPENGL_API) + assert ok + context = egl.eglCreateContext(display, config, egl.EGL_NO_CONTEXT, None) + assert context != egl.EGL_NO_CONTEXT + ok = egl.eglMakeCurrent(display, surface, surface, context) + assert ok + +#---------------------------------------------------------------------------- + +_texture_formats = { + ('uint8', 1): dnnlib.EasyDict(type=gl.GL_UNSIGNED_BYTE, format=gl.GL_LUMINANCE, internalformat=gl.GL_LUMINANCE8), + ('uint8', 2): dnnlib.EasyDict(type=gl.GL_UNSIGNED_BYTE, format=gl.GL_LUMINANCE_ALPHA, internalformat=gl.GL_LUMINANCE8_ALPHA8), + ('uint8', 3): dnnlib.EasyDict(type=gl.GL_UNSIGNED_BYTE, format=gl.GL_RGB, internalformat=gl.GL_RGB8), + ('uint8', 4): dnnlib.EasyDict(type=gl.GL_UNSIGNED_BYTE, format=gl.GL_RGBA, internalformat=gl.GL_RGBA8), + ('float32', 1): dnnlib.EasyDict(type=gl.GL_FLOAT, format=gl.GL_LUMINANCE, internalformat=OpenGL.GL.ARB.texture_float.GL_LUMINANCE32F_ARB), + ('float32', 2): dnnlib.EasyDict(type=gl.GL_FLOAT, format=gl.GL_LUMINANCE_ALPHA, internalformat=OpenGL.GL.ARB.texture_float.GL_LUMINANCE_ALPHA32F_ARB), + ('float32', 3): dnnlib.EasyDict(type=gl.GL_FLOAT, format=gl.GL_RGB, internalformat=gl.GL_RGB32F), + ('float32', 4): dnnlib.EasyDict(type=gl.GL_FLOAT, format=gl.GL_RGBA, internalformat=gl.GL_RGBA32F), +} + +def get_texture_format(dtype, channels): + return _texture_formats[(np.dtype(dtype).name, int(channels))] + +#---------------------------------------------------------------------------- + +def prepare_texture_data(image): + image = np.asarray(image) + if image.ndim == 2: + image = image[:, :, np.newaxis] + if image.dtype.name == 'float64': + image = image.astype('float32') + return image + +#---------------------------------------------------------------------------- + +def draw_pixels(image, *, pos=0, zoom=1, align=0, rint=True): + pos = np.broadcast_to(np.asarray(pos, dtype='float32'), [2]) + zoom = np.broadcast_to(np.asarray(zoom, dtype='float32'), [2]) + align = np.broadcast_to(np.asarray(align, dtype='float32'), [2]) + image = prepare_texture_data(image) + height, width, channels = image.shape + size = zoom * [width, height] + pos = pos - size * align + if rint: + pos = np.rint(pos) + fmt = get_texture_format(image.dtype, channels) + + gl.glPushAttrib(gl.GL_CURRENT_BIT | gl.GL_PIXEL_MODE_BIT) + gl.glPushClientAttrib(gl.GL_CLIENT_PIXEL_STORE_BIT) + gl.glRasterPos2f(pos[0], pos[1]) + gl.glPixelZoom(zoom[0], -zoom[1]) + gl.glPixelStorei(gl.GL_UNPACK_ALIGNMENT, 1) + gl.glDrawPixels(width, height, fmt.format, fmt.type, image) + gl.glPopClientAttrib() + gl.glPopAttrib() + +#---------------------------------------------------------------------------- + +def read_pixels(width, height, *, pos=0, dtype='uint8', channels=3): + pos = np.broadcast_to(np.asarray(pos, dtype='float32'), [2]) + dtype = np.dtype(dtype) + fmt = get_texture_format(dtype, channels) + image = np.empty([height, width, channels], dtype=dtype) + + gl.glPushClientAttrib(gl.GL_CLIENT_PIXEL_STORE_BIT) + gl.glPixelStorei(gl.GL_PACK_ALIGNMENT, 1) + gl.glReadPixels(int(np.round(pos[0])), int(np.round(pos[1])), width, height, fmt.format, fmt.type, image) + gl.glPopClientAttrib() + return np.flipud(image) + +#---------------------------------------------------------------------------- + +class Texture: + def __init__(self, *, image=None, width=None, height=None, channels=None, dtype=None, bilinear=True, mipmap=True): + self.gl_id = None + self.bilinear = bilinear + self.mipmap = mipmap + + # Determine size and dtype. + if image is not None: + image = prepare_texture_data(image) + self.height, self.width, self.channels = image.shape + self.dtype = image.dtype + else: + assert width is not None and height is not None + self.width = width + self.height = height + self.channels = channels if channels is not None else 3 + self.dtype = np.dtype(dtype) if dtype is not None else np.uint8 + + # Validate size and dtype. + assert isinstance(self.width, int) and self.width >= 0 + assert isinstance(self.height, int) and self.height >= 0 + assert isinstance(self.channels, int) and self.channels >= 1 + assert self.is_compatible(width=width, height=height, channels=channels, dtype=dtype) + + # Create texture object. + self.gl_id = gl.glGenTextures(1) + with self.bind(): + gl.glTexParameterf(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_S, gl.GL_CLAMP_TO_EDGE) + gl.glTexParameterf(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_T, gl.GL_CLAMP_TO_EDGE) + gl.glTexParameterf(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_LINEAR if self.bilinear else gl.GL_NEAREST) + gl.glTexParameterf(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MIN_FILTER, gl.GL_LINEAR_MIPMAP_LINEAR if self.mipmap else gl.GL_NEAREST) + self.update(image) + + def delete(self): + if self.gl_id is not None: + gl.glDeleteTextures([self.gl_id]) + self.gl_id = None + + def __del__(self): + try: + self.delete() + except: + pass + + @contextlib.contextmanager + def bind(self): + prev_id = gl.glGetInteger(gl.GL_TEXTURE_BINDING_2D) + gl.glBindTexture(gl.GL_TEXTURE_2D, self.gl_id) + yield + gl.glBindTexture(gl.GL_TEXTURE_2D, prev_id) + + def update(self, image): + if image is not None: + image = prepare_texture_data(image) + assert self.is_compatible(image=image) + with self.bind(): + fmt = get_texture_format(self.dtype, self.channels) + gl.glPushClientAttrib(gl.GL_CLIENT_PIXEL_STORE_BIT) + gl.glPixelStorei(gl.GL_UNPACK_ALIGNMENT, 1) + gl.glTexImage2D(gl.GL_TEXTURE_2D, 0, fmt.internalformat, self.width, self.height, 0, fmt.format, fmt.type, image) + if self.mipmap: + gl.glGenerateMipmap(gl.GL_TEXTURE_2D) + gl.glPopClientAttrib() + + def draw(self, *, pos=0, zoom=1, align=0, rint=False, color=1, alpha=1, rounding=0): + zoom = np.broadcast_to(np.asarray(zoom, dtype='float32'), [2]) + size = zoom * [self.width, self.height] + with self.bind(): + gl.glPushAttrib(gl.GL_ENABLE_BIT) + gl.glEnable(gl.GL_TEXTURE_2D) + draw_rect(pos=pos, size=size, align=align, rint=rint, color=color, alpha=alpha, rounding=rounding) + gl.glPopAttrib() + + def is_compatible(self, *, image=None, width=None, height=None, channels=None, dtype=None): # pylint: disable=too-many-return-statements + if image is not None: + if image.ndim != 3: + return False + ih, iw, ic = image.shape + if not self.is_compatible(width=iw, height=ih, channels=ic, dtype=image.dtype): + return False + if width is not None and self.width != width: + return False + if height is not None and self.height != height: + return False + if channels is not None and self.channels != channels: + return False + if dtype is not None and self.dtype != dtype: + return False + return True + +#---------------------------------------------------------------------------- + +class Framebuffer: + def __init__(self, *, texture=None, width=None, height=None, channels=None, dtype=None, msaa=0): + self.texture = texture + self.gl_id = None + self.gl_color = None + self.gl_depth_stencil = None + self.msaa = msaa + + # Determine size and dtype. + if texture is not None: + assert isinstance(self.texture, Texture) + self.width = texture.width + self.height = texture.height + self.channels = texture.channels + self.dtype = texture.dtype + else: + assert width is not None and height is not None + self.width = width + self.height = height + self.channels = channels if channels is not None else 4 + self.dtype = np.dtype(dtype) if dtype is not None else np.float32 + + # Validate size and dtype. + assert isinstance(self.width, int) and self.width >= 0 + assert isinstance(self.height, int) and self.height >= 0 + assert isinstance(self.channels, int) and self.channels >= 1 + assert width is None or width == self.width + assert height is None or height == self.height + assert channels is None or channels == self.channels + assert dtype is None or dtype == self.dtype + + # Create framebuffer object. + self.gl_id = gl.glGenFramebuffers(1) + with self.bind(): + + # Setup color buffer. + if self.texture is not None: + assert self.msaa == 0 + gl.glFramebufferTexture2D(gl.GL_FRAMEBUFFER, gl.GL_COLOR_ATTACHMENT0, gl.GL_TEXTURE_2D, self.texture.gl_id, 0) + else: + fmt = get_texture_format(self.dtype, self.channels) + self.gl_color = gl.glGenRenderbuffers(1) + gl.glBindRenderbuffer(gl.GL_RENDERBUFFER, self.gl_color) + gl.glRenderbufferStorageMultisample(gl.GL_RENDERBUFFER, self.msaa, fmt.internalformat, self.width, self.height) + gl.glFramebufferRenderbuffer(gl.GL_FRAMEBUFFER, gl.GL_COLOR_ATTACHMENT0, gl.GL_RENDERBUFFER, self.gl_color) + + # Setup depth/stencil buffer. + self.gl_depth_stencil = gl.glGenRenderbuffers(1) + gl.glBindRenderbuffer(gl.GL_RENDERBUFFER, self.gl_depth_stencil) + gl.glRenderbufferStorageMultisample(gl.GL_RENDERBUFFER, self.msaa, gl.GL_DEPTH24_STENCIL8, self.width, self.height) + gl.glFramebufferRenderbuffer(gl.GL_FRAMEBUFFER, gl.GL_DEPTH_STENCIL_ATTACHMENT, gl.GL_RENDERBUFFER, self.gl_depth_stencil) + + def delete(self): + if self.gl_id is not None: + gl.glDeleteFramebuffers([self.gl_id]) + self.gl_id = None + if self.gl_color is not None: + gl.glDeleteRenderbuffers(1, [self.gl_color]) + self.gl_color = None + if self.gl_depth_stencil is not None: + gl.glDeleteRenderbuffers(1, [self.gl_depth_stencil]) + self.gl_depth_stencil = None + + def __del__(self): + try: + self.delete() + except: + pass + + @contextlib.contextmanager + def bind(self): + prev_fbo = gl.glGetInteger(gl.GL_FRAMEBUFFER_BINDING) + prev_rbo = gl.glGetInteger(gl.GL_RENDERBUFFER_BINDING) + gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, self.gl_id) + if self.width is not None and self.height is not None: + gl.glViewport(0, 0, self.width, self.height) + yield + gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, prev_fbo) + gl.glBindRenderbuffer(gl.GL_RENDERBUFFER, prev_rbo) + + def blit(self, dst=None): + assert dst is None or isinstance(dst, Framebuffer) + with self.bind(): + gl.glBindFramebuffer(gl.GL_DRAW_FRAMEBUFFER, 0 if dst is None else dst.fbo) + gl.glBlitFramebuffer(0, 0, self.width, self.height, 0, 0, self.width, self.height, gl.GL_COLOR_BUFFER_BIT, gl.GL_NEAREST) + +#---------------------------------------------------------------------------- + +def draw_shape(vertices, *, mode=gl.GL_TRIANGLE_FAN, pos=0, size=1, color=1, alpha=1): + assert vertices.ndim == 2 and vertices.shape[1] == 2 + pos = np.broadcast_to(np.asarray(pos, dtype='float32'), [2]) + size = np.broadcast_to(np.asarray(size, dtype='float32'), [2]) + color = np.broadcast_to(np.asarray(color, dtype='float32'), [3]) + alpha = np.clip(np.broadcast_to(np.asarray(alpha, dtype='float32'), []), 0, 1) + + gl.glPushClientAttrib(gl.GL_CLIENT_VERTEX_ARRAY_BIT) + gl.glPushAttrib(gl.GL_CURRENT_BIT | gl.GL_TRANSFORM_BIT) + gl.glMatrixMode(gl.GL_MODELVIEW) + gl.glPushMatrix() + + gl.glEnableClientState(gl.GL_VERTEX_ARRAY) + gl.glEnableClientState(gl.GL_TEXTURE_COORD_ARRAY) + gl.glVertexPointer(2, gl.GL_FLOAT, 0, vertices) + gl.glTexCoordPointer(2, gl.GL_FLOAT, 0, vertices) + gl.glTranslate(pos[0], pos[1], 0) + gl.glScale(size[0], size[1], 1) + gl.glColor4f(color[0] * alpha, color[1] * alpha, color[2] * alpha, alpha) + gl.glDrawArrays(mode, 0, vertices.shape[0]) + + gl.glPopMatrix() + gl.glPopAttrib() + gl.glPopClientAttrib() + +#---------------------------------------------------------------------------- + +def draw_rect(*, pos=0, pos2=None, size=None, align=0, rint=False, color=1, alpha=1, rounding=0): + assert pos2 is None or size is None + pos = np.broadcast_to(np.asarray(pos, dtype='float32'), [2]) + pos2 = np.broadcast_to(np.asarray(pos2, dtype='float32'), [2]) if pos2 is not None else None + size = np.broadcast_to(np.asarray(size, dtype='float32'), [2]) if size is not None else None + size = size if size is not None else pos2 - pos if pos2 is not None else np.array([1, 1], dtype='float32') + pos = pos - size * align + if rint: + pos = np.rint(pos) + rounding = np.broadcast_to(np.asarray(rounding, dtype='float32'), [2]) + rounding = np.minimum(np.abs(rounding) / np.maximum(np.abs(size), 1e-8), 0.5) + if np.min(rounding) == 0: + rounding *= 0 + vertices = _setup_rect(float(rounding[0]), float(rounding[1])) + draw_shape(vertices, mode=gl.GL_TRIANGLE_FAN, pos=pos, size=size, color=color, alpha=alpha) + +@functools.lru_cache(maxsize=10000) +def _setup_rect(rx, ry): + t = np.linspace(0, np.pi / 2, 1 if max(rx, ry) == 0 else 64) + s = 1 - np.sin(t); c = 1 - np.cos(t) + x = [c * rx, 1 - s * rx, 1 - c * rx, s * rx] + y = [s * ry, c * ry, 1 - s * ry, 1 - c * ry] + v = np.stack([x, y], axis=-1).reshape(-1, 2) + return v.astype('float32') + +#---------------------------------------------------------------------------- + +def draw_circle(*, center=0, radius=100, hole=0, color=1, alpha=1): + hole = np.broadcast_to(np.asarray(hole, dtype='float32'), []) + vertices = _setup_circle(float(hole)) + draw_shape(vertices, mode=gl.GL_TRIANGLE_STRIP, pos=center, size=radius, color=color, alpha=alpha) + +@functools.lru_cache(maxsize=10000) +def _setup_circle(hole): + t = np.linspace(0, np.pi * 2, 128) + s = np.sin(t); c = np.cos(t) + v = np.stack([c, s, c * hole, s * hole], axis=-1).reshape(-1, 2) + return v.astype('float32') + +#---------------------------------------------------------------------------- diff --git a/gui_utils/glfw_window.py b/gui_utils/glfw_window.py new file mode 100644 index 0000000000000000000000000000000000000000..94c4a8dd2534b6def2607d6ff2bf29fe472dc53f --- /dev/null +++ b/gui_utils/glfw_window.py @@ -0,0 +1,229 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +import time +import glfw +import OpenGL.GL as gl +from . import gl_utils + +#---------------------------------------------------------------------------- + +class GlfwWindow: # pylint: disable=too-many-public-methods + def __init__(self, *, title='GlfwWindow', window_width=1920, window_height=1080, deferred_show=True, close_on_esc=True): + self._glfw_window = None + self._drawing_frame = False + self._frame_start_time = None + self._frame_delta = 0 + self._fps_limit = None + self._vsync = None + self._skip_frames = 0 + self._deferred_show = deferred_show + self._close_on_esc = close_on_esc + self._esc_pressed = False + self._drag_and_drop_paths = None + self._capture_next_frame = False + self._captured_frame = None + + # Create window. + glfw.init() + glfw.window_hint(glfw.VISIBLE, False) + self._glfw_window = glfw.create_window(width=window_width, height=window_height, title=title, monitor=None, share=None) + self._attach_glfw_callbacks() + self.make_context_current() + + # Adjust window. + self.set_vsync(False) + self.set_window_size(window_width, window_height) + if not self._deferred_show: + glfw.show_window(self._glfw_window) + + def close(self): + if self._drawing_frame: + self.end_frame() + if self._glfw_window is not None: + glfw.destroy_window(self._glfw_window) + self._glfw_window = None + #glfw.terminate() # Commented out to play it nice with other glfw clients. + + def __del__(self): + try: + self.close() + except: + pass + + @property + def window_width(self): + return self.content_width + + @property + def window_height(self): + return self.content_height + self.title_bar_height + + @property + def content_width(self): + width, _height = glfw.get_window_size(self._glfw_window) + return width + + @property + def content_height(self): + _width, height = glfw.get_window_size(self._glfw_window) + return height + + @property + def title_bar_height(self): + _left, top, _right, _bottom = glfw.get_window_frame_size(self._glfw_window) + return top + + @property + def monitor_width(self): + _, _, width, _height = glfw.get_monitor_workarea(glfw.get_primary_monitor()) + return width + + @property + def monitor_height(self): + _, _, _width, height = glfw.get_monitor_workarea(glfw.get_primary_monitor()) + return height + + @property + def frame_delta(self): + return self._frame_delta + + def set_title(self, title): + glfw.set_window_title(self._glfw_window, title) + + def set_window_size(self, width, height): + width = min(width, self.monitor_width) + height = min(height, self.monitor_height) + glfw.set_window_size(self._glfw_window, width, max(height - self.title_bar_height, 0)) + if width == self.monitor_width and height == self.monitor_height: + self.maximize() + + def set_content_size(self, width, height): + self.set_window_size(width, height + self.title_bar_height) + + def maximize(self): + glfw.maximize_window(self._glfw_window) + + def set_position(self, x, y): + glfw.set_window_pos(self._glfw_window, x, y + self.title_bar_height) + + def center(self): + self.set_position((self.monitor_width - self.window_width) // 2, (self.monitor_height - self.window_height) // 2) + + def set_vsync(self, vsync): + vsync = bool(vsync) + if vsync != self._vsync: + glfw.swap_interval(1 if vsync else 0) + self._vsync = vsync + + def set_fps_limit(self, fps_limit): + self._fps_limit = int(fps_limit) + + def should_close(self): + return glfw.window_should_close(self._glfw_window) or (self._close_on_esc and self._esc_pressed) + + def skip_frame(self): + self.skip_frames(1) + + def skip_frames(self, num): # Do not update window for the next N frames. + self._skip_frames = max(self._skip_frames, int(num)) + + def is_skipping_frames(self): + return self._skip_frames > 0 + + def capture_next_frame(self): + self._capture_next_frame = True + + def pop_captured_frame(self): + frame = self._captured_frame + self._captured_frame = None + return frame + + def pop_drag_and_drop_paths(self): + paths = self._drag_and_drop_paths + self._drag_and_drop_paths = None + return paths + + def draw_frame(self): # To be overridden by subclass. + self.begin_frame() + # Rendering code goes here. + self.end_frame() + + def make_context_current(self): + if self._glfw_window is not None: + glfw.make_context_current(self._glfw_window) + + def begin_frame(self): + # End previous frame. + if self._drawing_frame: + self.end_frame() + + # Apply FPS limit. + if self._frame_start_time is not None and self._fps_limit is not None: + delay = self._frame_start_time - time.perf_counter() + 1 / self._fps_limit + if delay > 0: + time.sleep(delay) + cur_time = time.perf_counter() + if self._frame_start_time is not None: + self._frame_delta = cur_time - self._frame_start_time + self._frame_start_time = cur_time + + # Process events. + glfw.poll_events() + + # Begin frame. + self._drawing_frame = True + self.make_context_current() + + # Initialize GL state. + gl.glViewport(0, 0, self.content_width, self.content_height) + gl.glMatrixMode(gl.GL_PROJECTION) + gl.glLoadIdentity() + gl.glTranslate(-1, 1, 0) + gl.glScale(2 / max(self.content_width, 1), -2 / max(self.content_height, 1), 1) + gl.glMatrixMode(gl.GL_MODELVIEW) + gl.glLoadIdentity() + gl.glEnable(gl.GL_BLEND) + gl.glBlendFunc(gl.GL_ONE, gl.GL_ONE_MINUS_SRC_ALPHA) # Pre-multiplied alpha. + + # Clear. + gl.glClearColor(0, 0, 0, 1) + gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_DEPTH_BUFFER_BIT) + + def end_frame(self): + assert self._drawing_frame + self._drawing_frame = False + + # Skip frames if requested. + if self._skip_frames > 0: + self._skip_frames -= 1 + return + + # Capture frame if requested. + if self._capture_next_frame: + self._captured_frame = gl_utils.read_pixels(self.content_width, self.content_height) + self._capture_next_frame = False + + # Update window. + if self._deferred_show: + glfw.show_window(self._glfw_window) + self._deferred_show = False + glfw.swap_buffers(self._glfw_window) + + def _attach_glfw_callbacks(self): + glfw.set_key_callback(self._glfw_window, self._glfw_key_callback) + glfw.set_drop_callback(self._glfw_window, self._glfw_drop_callback) + + def _glfw_key_callback(self, _window, key, _scancode, action, _mods): + if action == glfw.PRESS and key == glfw.KEY_ESCAPE: + self._esc_pressed = True + + def _glfw_drop_callback(self, _window, paths): + self._drag_and_drop_paths = paths + +#---------------------------------------------------------------------------- diff --git a/gui_utils/imgui_utils.py b/gui_utils/imgui_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e5cb118d996a380d4b263762051af248ee0383c7 --- /dev/null +++ b/gui_utils/imgui_utils.py @@ -0,0 +1,169 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +import contextlib +import imgui + +#---------------------------------------------------------------------------- + +def set_default_style(color_scheme='dark', spacing=9, indent=23, scrollbar=27): + s = imgui.get_style() + s.window_padding = [spacing, spacing] + s.item_spacing = [spacing, spacing] + s.item_inner_spacing = [spacing, spacing] + s.columns_min_spacing = spacing + s.indent_spacing = indent + s.scrollbar_size = scrollbar + s.frame_padding = [4, 3] + s.window_border_size = 1 + s.child_border_size = 1 + s.popup_border_size = 1 + s.frame_border_size = 1 + s.window_rounding = 0 + s.child_rounding = 0 + s.popup_rounding = 3 + s.frame_rounding = 3 + s.scrollbar_rounding = 3 + s.grab_rounding = 3 + + getattr(imgui, f'style_colors_{color_scheme}')(s) + c0 = s.colors[imgui.COLOR_MENUBAR_BACKGROUND] + c1 = s.colors[imgui.COLOR_FRAME_BACKGROUND] + s.colors[imgui.COLOR_POPUP_BACKGROUND] = [x * 0.7 + y * 0.3 for x, y in zip(c0, c1)][:3] + [1] + +#---------------------------------------------------------------------------- + +@contextlib.contextmanager +def grayed_out(cond=True): + if cond: + s = imgui.get_style() + text = s.colors[imgui.COLOR_TEXT_DISABLED] + grab = s.colors[imgui.COLOR_SCROLLBAR_GRAB] + back = s.colors[imgui.COLOR_MENUBAR_BACKGROUND] + imgui.push_style_color(imgui.COLOR_TEXT, *text) + imgui.push_style_color(imgui.COLOR_CHECK_MARK, *grab) + imgui.push_style_color(imgui.COLOR_SLIDER_GRAB, *grab) + imgui.push_style_color(imgui.COLOR_SLIDER_GRAB_ACTIVE, *grab) + imgui.push_style_color(imgui.COLOR_FRAME_BACKGROUND, *back) + imgui.push_style_color(imgui.COLOR_FRAME_BACKGROUND_HOVERED, *back) + imgui.push_style_color(imgui.COLOR_FRAME_BACKGROUND_ACTIVE, *back) + imgui.push_style_color(imgui.COLOR_BUTTON, *back) + imgui.push_style_color(imgui.COLOR_BUTTON_HOVERED, *back) + imgui.push_style_color(imgui.COLOR_BUTTON_ACTIVE, *back) + imgui.push_style_color(imgui.COLOR_HEADER, *back) + imgui.push_style_color(imgui.COLOR_HEADER_HOVERED, *back) + imgui.push_style_color(imgui.COLOR_HEADER_ACTIVE, *back) + imgui.push_style_color(imgui.COLOR_POPUP_BACKGROUND, *back) + yield + imgui.pop_style_color(14) + else: + yield + +#---------------------------------------------------------------------------- + +@contextlib.contextmanager +def item_width(width=None): + if width is not None: + imgui.push_item_width(width) + yield + imgui.pop_item_width() + else: + yield + +#---------------------------------------------------------------------------- + +def scoped_by_object_id(method): + def decorator(self, *args, **kwargs): + imgui.push_id(str(id(self))) + res = method(self, *args, **kwargs) + imgui.pop_id() + return res + return decorator + +#---------------------------------------------------------------------------- + +def button(label, width=0, enabled=True): + with grayed_out(not enabled): + clicked = imgui.button(label, width=width) + clicked = clicked and enabled + return clicked + +#---------------------------------------------------------------------------- + +def collapsing_header(text, visible=None, flags=0, default=False, enabled=True, show=True): + expanded = False + if show: + if default: + flags |= imgui.TREE_NODE_DEFAULT_OPEN + if not enabled: + flags |= imgui.TREE_NODE_LEAF + with grayed_out(not enabled): + expanded, visible = imgui.collapsing_header(text, visible=visible, flags=flags) + expanded = expanded and enabled + return expanded, visible + +#---------------------------------------------------------------------------- + +def popup_button(label, width=0, enabled=True): + if button(label, width, enabled): + imgui.open_popup(label) + opened = imgui.begin_popup(label) + return opened + +#---------------------------------------------------------------------------- + +def input_text(label, value, buffer_length, flags, width=None, help_text=''): + old_value = value + color = list(imgui.get_style().colors[imgui.COLOR_TEXT]) + if value == '': + color[-1] *= 0.5 + with item_width(width): + imgui.push_style_color(imgui.COLOR_TEXT, *color) + value = value if value != '' else help_text + changed, value = imgui.input_text(label, value, buffer_length, flags) + value = value if value != help_text else '' + imgui.pop_style_color(1) + if not flags & imgui.INPUT_TEXT_ENTER_RETURNS_TRUE: + changed = (value != old_value) + return changed, value + +#---------------------------------------------------------------------------- + +def drag_previous_control(enabled=True): + dragging = False + dx = 0 + dy = 0 + if imgui.begin_drag_drop_source(imgui.DRAG_DROP_SOURCE_NO_PREVIEW_TOOLTIP): + if enabled: + dragging = True + dx, dy = imgui.get_mouse_drag_delta() + imgui.reset_mouse_drag_delta() + imgui.end_drag_drop_source() + return dragging, dx, dy + +#---------------------------------------------------------------------------- + +def drag_button(label, width=0, enabled=True): + clicked = button(label, width=width, enabled=enabled) + dragging, dx, dy = drag_previous_control(enabled=enabled) + return clicked, dragging, dx, dy + +#---------------------------------------------------------------------------- + +def drag_hidden_window(label, x, y, width, height, enabled=True): + imgui.push_style_color(imgui.COLOR_WINDOW_BACKGROUND, 0, 0, 0, 0) + imgui.push_style_color(imgui.COLOR_BORDER, 0, 0, 0, 0) + imgui.set_next_window_position(x, y) + imgui.set_next_window_size(width, height) + imgui.begin(label, closable=False, flags=(imgui.WINDOW_NO_TITLE_BAR | imgui.WINDOW_NO_RESIZE | imgui.WINDOW_NO_MOVE)) + dragging, dx, dy = drag_previous_control(enabled=enabled) + imgui.end() + imgui.pop_style_color(2) + return dragging, dx, dy + +#---------------------------------------------------------------------------- diff --git a/gui_utils/imgui_window.py b/gui_utils/imgui_window.py new file mode 100644 index 0000000000000000000000000000000000000000..aaf7caa76f0e1261e490bce8ef1c8267a1e5c31d --- /dev/null +++ b/gui_utils/imgui_window.py @@ -0,0 +1,103 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +import os +import imgui +import imgui.integrations.glfw + +from . import glfw_window +from . import imgui_utils +from . import text_utils + +#---------------------------------------------------------------------------- + +class ImguiWindow(glfw_window.GlfwWindow): + def __init__(self, *, title='ImguiWindow', font=None, font_sizes=range(14,24), **glfw_kwargs): + if font is None: + font = text_utils.get_default_font() + font_sizes = {int(size) for size in font_sizes} + super().__init__(title=title, **glfw_kwargs) + + # Init fields. + self._imgui_context = None + self._imgui_renderer = None + self._imgui_fonts = None + self._cur_font_size = max(font_sizes) + + # Delete leftover imgui.ini to avoid unexpected behavior. + if os.path.isfile('imgui.ini'): + os.remove('imgui.ini') + + # Init ImGui. + self._imgui_context = imgui.create_context() + self._imgui_renderer = _GlfwRenderer(self._glfw_window) + self._attach_glfw_callbacks() + imgui.get_io().ini_saving_rate = 0 # Disable creating imgui.ini at runtime. + imgui.get_io().mouse_drag_threshold = 0 # Improve behavior with imgui_utils.drag_custom(). + self._imgui_fonts = {size: imgui.get_io().fonts.add_font_from_file_ttf(font, size) for size in font_sizes} + self._imgui_renderer.refresh_font_texture() + + def close(self): + self.make_context_current() + self._imgui_fonts = None + if self._imgui_renderer is not None: + self._imgui_renderer.shutdown() + self._imgui_renderer = None + if self._imgui_context is not None: + #imgui.destroy_context(self._imgui_context) # Commented out to avoid creating imgui.ini at the end. + self._imgui_context = None + super().close() + + def _glfw_key_callback(self, *args): + super()._glfw_key_callback(*args) + self._imgui_renderer.keyboard_callback(*args) + + @property + def font_size(self): + return self._cur_font_size + + @property + def spacing(self): + return round(self._cur_font_size * 0.4) + + def set_font_size(self, target): # Applied on next frame. + self._cur_font_size = min((abs(key - target), key) for key in self._imgui_fonts.keys())[1] + + def begin_frame(self): + # Begin glfw frame. + super().begin_frame() + + # Process imgui events. + self._imgui_renderer.mouse_wheel_multiplier = self._cur_font_size / 10 + if self.content_width > 0 and self.content_height > 0: + self._imgui_renderer.process_inputs() + + # Begin imgui frame. + imgui.new_frame() + imgui.push_font(self._imgui_fonts[self._cur_font_size]) + imgui_utils.set_default_style(spacing=self.spacing, indent=self.font_size, scrollbar=self.font_size+4) + + def end_frame(self): + imgui.pop_font() + imgui.render() + imgui.end_frame() + self._imgui_renderer.render(imgui.get_draw_data()) + super().end_frame() + +#---------------------------------------------------------------------------- +# Wrapper class for GlfwRenderer to fix a mouse wheel bug on Linux. + +class _GlfwRenderer(imgui.integrations.glfw.GlfwRenderer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.mouse_wheel_multiplier = 1 + + def scroll_callback(self, window, x_offset, y_offset): + self.io.mouse_wheel += y_offset * self.mouse_wheel_multiplier + +#---------------------------------------------------------------------------- diff --git a/gui_utils/text_utils.py b/gui_utils/text_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ed0c7a9b13fe63df5deb20664e22584a8240fc59 --- /dev/null +++ b/gui_utils/text_utils.py @@ -0,0 +1,123 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +import functools +from typing import Optional + +import dnnlib +import numpy as np +import PIL.Image +import PIL.ImageFont +import scipy.ndimage + +from . import gl_utils + +#---------------------------------------------------------------------------- + +def get_default_font(): + url = 'http://fonts.gstatic.com/s/opensans/v17/mem8YaGs126MiZpBA-U1UpcaXcl0Aw.ttf' # Open Sans regular + return dnnlib.util.open_url(url, return_filename=True) + +#---------------------------------------------------------------------------- + +@functools.lru_cache(maxsize=None) +def get_pil_font(font=None, size=32): + if font is None: + font = get_default_font() + return PIL.ImageFont.truetype(font=font, size=size) + +#---------------------------------------------------------------------------- + +def get_array(string, *, dropshadow_radius: int=None, **kwargs): + if dropshadow_radius is not None: + offset_x = int(np.ceil(dropshadow_radius*2/3)) + offset_y = int(np.ceil(dropshadow_radius*2/3)) + return _get_array_priv(string, dropshadow_radius=dropshadow_radius, offset_x=offset_x, offset_y=offset_y, **kwargs) + else: + return _get_array_priv(string, **kwargs) + +@functools.lru_cache(maxsize=10000) +def _get_array_priv( + string: str, *, + size: int = 32, + max_width: Optional[int]=None, + max_height: Optional[int]=None, + min_size=10, + shrink_coef=0.8, + dropshadow_radius: int=None, + offset_x: int=None, + offset_y: int=None, + **kwargs +): + cur_size = size + array = None + while True: + if dropshadow_radius is not None: + # separate implementation for dropshadow text rendering + array = _get_array_impl_dropshadow(string, size=cur_size, radius=dropshadow_radius, offset_x=offset_x, offset_y=offset_y, **kwargs) + else: + array = _get_array_impl(string, size=cur_size, **kwargs) + height, width, _ = array.shape + if (max_width is None or width <= max_width) and (max_height is None or height <= max_height) or (cur_size <= min_size): + break + cur_size = max(int(cur_size * shrink_coef), min_size) + return array + +#---------------------------------------------------------------------------- + +@functools.lru_cache(maxsize=10000) +def _get_array_impl(string, *, font=None, size=32, outline=0, outline_pad=3, outline_coef=3, outline_exp=2, line_pad: int=None): + pil_font = get_pil_font(font=font, size=size) + lines = [pil_font.getmask(line, 'L') for line in string.split('\n')] + lines = [np.array(line, dtype=np.uint8).reshape([line.size[1], line.size[0]]) for line in lines] + width = max(line.shape[1] for line in lines) + lines = [np.pad(line, ((0, 0), (0, width - line.shape[1])), mode='constant') for line in lines] + line_spacing = line_pad if line_pad is not None else size // 2 + lines = [np.pad(line, ((0, line_spacing), (0, 0)), mode='constant') for line in lines[:-1]] + lines[-1:] + mask = np.concatenate(lines, axis=0) + alpha = mask + if outline > 0: + mask = np.pad(mask, int(np.ceil(outline * outline_pad)), mode='constant', constant_values=0) + alpha = mask.astype(np.float32) / 255 + alpha = scipy.ndimage.gaussian_filter(alpha, outline) + alpha = 1 - np.maximum(1 - alpha * outline_coef, 0) ** outline_exp + alpha = (alpha * 255 + 0.5).clip(0, 255).astype(np.uint8) + alpha = np.maximum(alpha, mask) + return np.stack([mask, alpha], axis=-1) + +#---------------------------------------------------------------------------- + +@functools.lru_cache(maxsize=10000) +def _get_array_impl_dropshadow(string, *, font=None, size=32, radius: int, offset_x: int, offset_y: int, line_pad: int=None, **kwargs): + assert (offset_x > 0) and (offset_y > 0) + pil_font = get_pil_font(font=font, size=size) + lines = [pil_font.getmask(line, 'L') for line in string.split('\n')] + lines = [np.array(line, dtype=np.uint8).reshape([line.size[1], line.size[0]]) for line in lines] + width = max(line.shape[1] for line in lines) + lines = [np.pad(line, ((0, 0), (0, width - line.shape[1])), mode='constant') for line in lines] + line_spacing = line_pad if line_pad is not None else size // 2 + lines = [np.pad(line, ((0, line_spacing), (0, 0)), mode='constant') for line in lines[:-1]] + lines[-1:] + mask = np.concatenate(lines, axis=0) + alpha = mask + + mask = np.pad(mask, 2*radius + max(abs(offset_x), abs(offset_y)), mode='constant', constant_values=0) + alpha = mask.astype(np.float32) / 255 + alpha = scipy.ndimage.gaussian_filter(alpha, radius) + alpha = 1 - np.maximum(1 - alpha * 1.5, 0) ** 1.4 + alpha = (alpha * 255 + 0.5).clip(0, 255).astype(np.uint8) + alpha = np.pad(alpha, [(offset_y, 0), (offset_x, 0)], mode='constant')[:-offset_y, :-offset_x] + alpha = np.maximum(alpha, mask) + return np.stack([mask, alpha], axis=-1) + +#---------------------------------------------------------------------------- + +@functools.lru_cache(maxsize=10000) +def get_texture(string, bilinear=True, mipmap=True, **kwargs): + return gl_utils.Texture(image=get_array(string, **kwargs), bilinear=bilinear, mipmap=mipmap) + +#---------------------------------------------------------------------------- diff --git a/legacy.py b/legacy.py new file mode 100644 index 0000000000000000000000000000000000000000..e361fc2351c383c23aa6786f0bbc9daae047afca --- /dev/null +++ b/legacy.py @@ -0,0 +1,323 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Converting legacy network pickle into the new format.""" + +import click +import pickle +import re +import copy +import numpy as np +import torch +import dnnlib +from torch_utils import misc + +#---------------------------------------------------------------------------- + +def load_network_pkl(f, force_fp16=False): + data = _LegacyUnpickler(f).load() + + # Legacy TensorFlow pickle => convert. + if isinstance(data, tuple) and len(data) == 3 and all(isinstance(net, _TFNetworkStub) for net in data): + tf_G, tf_D, tf_Gs = data + G = convert_tf_generator(tf_G) + D = convert_tf_discriminator(tf_D) + G_ema = convert_tf_generator(tf_Gs) + data = dict(G=G, D=D, G_ema=G_ema) + + # Add missing fields. + if 'training_set_kwargs' not in data: + data['training_set_kwargs'] = None + if 'augment_pipe' not in data: + data['augment_pipe'] = None + + # Validate contents. + assert isinstance(data['G'], torch.nn.Module) + assert isinstance(data['D'], torch.nn.Module) + assert isinstance(data['G_ema'], torch.nn.Module) + assert isinstance(data['training_set_kwargs'], (dict, type(None))) + assert isinstance(data['augment_pipe'], (torch.nn.Module, type(None))) + + # Force FP16. + if force_fp16: + for key in ['G', 'D', 'G_ema']: + old = data[key] + kwargs = copy.deepcopy(old.init_kwargs) + fp16_kwargs = kwargs.get('synthesis_kwargs', kwargs) + fp16_kwargs.num_fp16_res = 4 + fp16_kwargs.conv_clamp = 256 + if kwargs != old.init_kwargs: + new = type(old)(**kwargs).eval().requires_grad_(False) + misc.copy_params_and_buffers(old, new, require_all=True) + data[key] = new + return data + +#---------------------------------------------------------------------------- + +class _TFNetworkStub(dnnlib.EasyDict): + pass + +class _LegacyUnpickler(pickle.Unpickler): + def find_class(self, module, name): + if module == 'dnnlib.tflib.network' and name == 'Network': + return _TFNetworkStub + return super().find_class(module, name) + +#---------------------------------------------------------------------------- + +def _collect_tf_params(tf_net): + # pylint: disable=protected-access + tf_params = dict() + def recurse(prefix, tf_net): + for name, value in tf_net.variables: + tf_params[prefix + name] = value + for name, comp in tf_net.components.items(): + recurse(prefix + name + '/', comp) + recurse('', tf_net) + return tf_params + +#---------------------------------------------------------------------------- + +def _populate_module_params(module, *patterns): + for name, tensor in misc.named_params_and_buffers(module): + found = False + value = None + for pattern, value_fn in zip(patterns[0::2], patterns[1::2]): + match = re.fullmatch(pattern, name) + if match: + found = True + if value_fn is not None: + value = value_fn(*match.groups()) + break + try: + assert found + if value is not None: + tensor.copy_(torch.from_numpy(np.array(value))) + except: + print(name, list(tensor.shape)) + raise + +#---------------------------------------------------------------------------- + +def convert_tf_generator(tf_G): + if tf_G.version < 4: + raise ValueError('TensorFlow pickle version too low') + + # Collect kwargs. + tf_kwargs = tf_G.static_kwargs + known_kwargs = set() + def kwarg(tf_name, default=None, none=None): + known_kwargs.add(tf_name) + val = tf_kwargs.get(tf_name, default) + return val if val is not None else none + + # Convert kwargs. + from training import networks_stylegan2 + network_class = networks_stylegan2.Generator + kwargs = dnnlib.EasyDict( + z_dim = kwarg('latent_size', 512), + c_dim = kwarg('label_size', 0), + w_dim = kwarg('dlatent_size', 512), + img_resolution = kwarg('resolution', 1024), + img_channels = kwarg('num_channels', 3), + channel_base = kwarg('fmap_base', 16384) * 2, + channel_max = kwarg('fmap_max', 512), + num_fp16_res = kwarg('num_fp16_res', 0), + conv_clamp = kwarg('conv_clamp', None), + architecture = kwarg('architecture', 'skip'), + resample_filter = kwarg('resample_kernel', [1,3,3,1]), + use_noise = kwarg('use_noise', True), + activation = kwarg('nonlinearity', 'lrelu'), + mapping_kwargs = dnnlib.EasyDict( + num_layers = kwarg('mapping_layers', 8), + embed_features = kwarg('label_fmaps', None), + layer_features = kwarg('mapping_fmaps', None), + activation = kwarg('mapping_nonlinearity', 'lrelu'), + lr_multiplier = kwarg('mapping_lrmul', 0.01), + w_avg_beta = kwarg('w_avg_beta', 0.995, none=1), + ), + ) + + # Check for unknown kwargs. + kwarg('truncation_psi') + kwarg('truncation_cutoff') + kwarg('style_mixing_prob') + kwarg('structure') + kwarg('conditioning') + kwarg('fused_modconv') + unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs) + if len(unknown_kwargs) > 0: + raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0]) + + # Collect params. + tf_params = _collect_tf_params(tf_G) + for name, value in list(tf_params.items()): + match = re.fullmatch(r'ToRGB_lod(\d+)/(.*)', name) + if match: + r = kwargs.img_resolution // (2 ** int(match.group(1))) + tf_params[f'{r}x{r}/ToRGB/{match.group(2)}'] = value + kwargs.synthesis.kwargs.architecture = 'orig' + #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}') + + # Convert params. + G = network_class(**kwargs).eval().requires_grad_(False) + # pylint: disable=unnecessary-lambda + # pylint: disable=f-string-without-interpolation + _populate_module_params(G, + r'mapping\.w_avg', lambda: tf_params[f'dlatent_avg'], + r'mapping\.embed\.weight', lambda: tf_params[f'mapping/LabelEmbed/weight'].transpose(), + r'mapping\.embed\.bias', lambda: tf_params[f'mapping/LabelEmbed/bias'], + r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'mapping/Dense{i}/weight'].transpose(), + r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'mapping/Dense{i}/bias'], + r'synthesis\.b4\.const', lambda: tf_params[f'synthesis/4x4/Const/const'][0], + r'synthesis\.b4\.conv1\.weight', lambda: tf_params[f'synthesis/4x4/Conv/weight'].transpose(3, 2, 0, 1), + r'synthesis\.b4\.conv1\.bias', lambda: tf_params[f'synthesis/4x4/Conv/bias'], + r'synthesis\.b4\.conv1\.noise_const', lambda: tf_params[f'synthesis/noise0'][0, 0], + r'synthesis\.b4\.conv1\.noise_strength', lambda: tf_params[f'synthesis/4x4/Conv/noise_strength'], + r'synthesis\.b4\.conv1\.affine\.weight', lambda: tf_params[f'synthesis/4x4/Conv/mod_weight'].transpose(), + r'synthesis\.b4\.conv1\.affine\.bias', lambda: tf_params[f'synthesis/4x4/Conv/mod_bias'] + 1, + r'synthesis\.b(\d+)\.conv0\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/weight'][::-1, ::-1].transpose(3, 2, 0, 1), + r'synthesis\.b(\d+)\.conv0\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/bias'], + r'synthesis\.b(\d+)\.conv0\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-5}'][0, 0], + r'synthesis\.b(\d+)\.conv0\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/noise_strength'], + r'synthesis\.b(\d+)\.conv0\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_weight'].transpose(), + r'synthesis\.b(\d+)\.conv0\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_bias'] + 1, + r'synthesis\.b(\d+)\.conv1\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/weight'].transpose(3, 2, 0, 1), + r'synthesis\.b(\d+)\.conv1\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/bias'], + r'synthesis\.b(\d+)\.conv1\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-4}'][0, 0], + r'synthesis\.b(\d+)\.conv1\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/noise_strength'], + r'synthesis\.b(\d+)\.conv1\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_weight'].transpose(), + r'synthesis\.b(\d+)\.conv1\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_bias'] + 1, + r'synthesis\.b(\d+)\.torgb\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/weight'].transpose(3, 2, 0, 1), + r'synthesis\.b(\d+)\.torgb\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/bias'], + r'synthesis\.b(\d+)\.torgb\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_weight'].transpose(), + r'synthesis\.b(\d+)\.torgb\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_bias'] + 1, + r'synthesis\.b(\d+)\.skip\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Skip/weight'][::-1, ::-1].transpose(3, 2, 0, 1), + r'.*\.resample_filter', None, + r'.*\.act_filter', None, + ) + return G + +#---------------------------------------------------------------------------- + +def convert_tf_discriminator(tf_D): + if tf_D.version < 4: + raise ValueError('TensorFlow pickle version too low') + + # Collect kwargs. + tf_kwargs = tf_D.static_kwargs + known_kwargs = set() + def kwarg(tf_name, default=None): + known_kwargs.add(tf_name) + return tf_kwargs.get(tf_name, default) + + # Convert kwargs. + kwargs = dnnlib.EasyDict( + c_dim = kwarg('label_size', 0), + img_resolution = kwarg('resolution', 1024), + img_channels = kwarg('num_channels', 3), + architecture = kwarg('architecture', 'resnet'), + channel_base = kwarg('fmap_base', 16384) * 2, + channel_max = kwarg('fmap_max', 512), + num_fp16_res = kwarg('num_fp16_res', 0), + conv_clamp = kwarg('conv_clamp', None), + cmap_dim = kwarg('mapping_fmaps', None), + block_kwargs = dnnlib.EasyDict( + activation = kwarg('nonlinearity', 'lrelu'), + resample_filter = kwarg('resample_kernel', [1,3,3,1]), + freeze_layers = kwarg('freeze_layers', 0), + ), + mapping_kwargs = dnnlib.EasyDict( + num_layers = kwarg('mapping_layers', 0), + embed_features = kwarg('mapping_fmaps', None), + layer_features = kwarg('mapping_fmaps', None), + activation = kwarg('nonlinearity', 'lrelu'), + lr_multiplier = kwarg('mapping_lrmul', 0.1), + ), + epilogue_kwargs = dnnlib.EasyDict( + mbstd_group_size = kwarg('mbstd_group_size', None), + mbstd_num_channels = kwarg('mbstd_num_features', 1), + activation = kwarg('nonlinearity', 'lrelu'), + ), + ) + + # Check for unknown kwargs. + kwarg('structure') + kwarg('conditioning') + unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs) + if len(unknown_kwargs) > 0: + raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0]) + + # Collect params. + tf_params = _collect_tf_params(tf_D) + for name, value in list(tf_params.items()): + match = re.fullmatch(r'FromRGB_lod(\d+)/(.*)', name) + if match: + r = kwargs.img_resolution // (2 ** int(match.group(1))) + tf_params[f'{r}x{r}/FromRGB/{match.group(2)}'] = value + kwargs.architecture = 'orig' + #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}') + + # Convert params. + from training import networks_stylegan2 + D = networks_stylegan2.Discriminator(**kwargs).eval().requires_grad_(False) + # pylint: disable=unnecessary-lambda + # pylint: disable=f-string-without-interpolation + _populate_module_params(D, + r'b(\d+)\.fromrgb\.weight', lambda r: tf_params[f'{r}x{r}/FromRGB/weight'].transpose(3, 2, 0, 1), + r'b(\d+)\.fromrgb\.bias', lambda r: tf_params[f'{r}x{r}/FromRGB/bias'], + r'b(\d+)\.conv(\d+)\.weight', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/weight'].transpose(3, 2, 0, 1), + r'b(\d+)\.conv(\d+)\.bias', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/bias'], + r'b(\d+)\.skip\.weight', lambda r: tf_params[f'{r}x{r}/Skip/weight'].transpose(3, 2, 0, 1), + r'mapping\.embed\.weight', lambda: tf_params[f'LabelEmbed/weight'].transpose(), + r'mapping\.embed\.bias', lambda: tf_params[f'LabelEmbed/bias'], + r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'Mapping{i}/weight'].transpose(), + r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'Mapping{i}/bias'], + r'b4\.conv\.weight', lambda: tf_params[f'4x4/Conv/weight'].transpose(3, 2, 0, 1), + r'b4\.conv\.bias', lambda: tf_params[f'4x4/Conv/bias'], + r'b4\.fc\.weight', lambda: tf_params[f'4x4/Dense0/weight'].transpose(), + r'b4\.fc\.bias', lambda: tf_params[f'4x4/Dense0/bias'], + r'b4\.out\.weight', lambda: tf_params[f'Output/weight'].transpose(), + r'b4\.out\.bias', lambda: tf_params[f'Output/bias'], + r'.*\.resample_filter', None, + ) + return D + +#---------------------------------------------------------------------------- + +@click.command() +@click.option('--source', help='Input pickle', required=True, metavar='PATH') +@click.option('--dest', help='Output pickle', required=True, metavar='PATH') +@click.option('--force-fp16', help='Force the networks to use FP16', type=bool, default=False, metavar='BOOL', show_default=True) +def convert_network_pickle(source, dest, force_fp16): + """Convert legacy network pickle into the native PyTorch format. + + The tool is able to load the main network configurations exported using the TensorFlow version of StyleGAN2 or StyleGAN2-ADA. + It does not support e.g. StyleGAN2-ADA comparison methods, StyleGAN2 configs A-D, or StyleGAN1 networks. + + Example: + + \b + python legacy.py \\ + --source=https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-cat-config-f.pkl \\ + --dest=stylegan2-cat-config-f.pkl + """ + print(f'Loading "{source}"...') + with dnnlib.util.open_url(source) as f: + data = load_network_pkl(f, force_fp16=force_fp16) + print(f'Saving "{dest}"...') + with open(dest, 'wb') as f: + pickle.dump(data, f) + print('Done.') + +#---------------------------------------------------------------------------- + +if __name__ == "__main__": + convert_network_pickle() # pylint: disable=no-value-for-parameter + +#---------------------------------------------------------------------------- diff --git a/metrics/__init__.py b/metrics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8dd34882519598c472f1224cfe68c9ff6952ce69 --- /dev/null +++ b/metrics/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +# empty diff --git a/metrics/equivariance.py b/metrics/equivariance.py new file mode 100644 index 0000000000000000000000000000000000000000..c96ebed07fe478542ab56b56a6506e79f03d1388 --- /dev/null +++ b/metrics/equivariance.py @@ -0,0 +1,267 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Equivariance metrics (EQ-T, EQ-T_frac, and EQ-R) from the paper +"Alias-Free Generative Adversarial Networks".""" + +import copy +import numpy as np +import torch +import torch.fft +from torch_utils.ops import upfirdn2d +from . import metric_utils + +#---------------------------------------------------------------------------- +# Utilities. + +def sinc(x): + y = (x * np.pi).abs() + z = torch.sin(y) / y.clamp(1e-30, float('inf')) + return torch.where(y < 1e-30, torch.ones_like(x), z) + +def lanczos_window(x, a): + x = x.abs() / a + return torch.where(x < 1, sinc(x), torch.zeros_like(x)) + +def rotation_matrix(angle): + angle = torch.as_tensor(angle).to(torch.float32) + mat = torch.eye(3, device=angle.device) + mat[0, 0] = angle.cos() + mat[0, 1] = angle.sin() + mat[1, 0] = -angle.sin() + mat[1, 1] = angle.cos() + return mat + +#---------------------------------------------------------------------------- +# Apply integer translation to a batch of 2D images. Corresponds to the +# operator T_x in Appendix E.1. + +def apply_integer_translation(x, tx, ty): + _N, _C, H, W = x.shape + tx = torch.as_tensor(tx * W).to(dtype=torch.float32, device=x.device) + ty = torch.as_tensor(ty * H).to(dtype=torch.float32, device=x.device) + ix = tx.round().to(torch.int64) + iy = ty.round().to(torch.int64) + + z = torch.zeros_like(x) + m = torch.zeros_like(x) + if abs(ix) < W and abs(iy) < H: + y = x[:, :, max(-iy,0) : H+min(-iy,0), max(-ix,0) : W+min(-ix,0)] + z[:, :, max(iy,0) : H+min(iy,0), max(ix,0) : W+min(ix,0)] = y + m[:, :, max(iy,0) : H+min(iy,0), max(ix,0) : W+min(ix,0)] = 1 + return z, m + +#---------------------------------------------------------------------------- +# Apply integer translation to a batch of 2D images. Corresponds to the +# operator T_x in Appendix E.2. + +def apply_fractional_translation(x, tx, ty, a=3): + _N, _C, H, W = x.shape + tx = torch.as_tensor(tx * W).to(dtype=torch.float32, device=x.device) + ty = torch.as_tensor(ty * H).to(dtype=torch.float32, device=x.device) + ix = tx.floor().to(torch.int64) + iy = ty.floor().to(torch.int64) + fx = tx - ix + fy = ty - iy + b = a - 1 + + z = torch.zeros_like(x) + zx0 = max(ix - b, 0) + zy0 = max(iy - b, 0) + zx1 = min(ix + a, 0) + W + zy1 = min(iy + a, 0) + H + if zx0 < zx1 and zy0 < zy1: + taps = torch.arange(a * 2, device=x.device) - b + filter_x = (sinc(taps - fx) * sinc((taps - fx) / a)).unsqueeze(0) + filter_y = (sinc(taps - fy) * sinc((taps - fy) / a)).unsqueeze(1) + y = x + y = upfirdn2d.filter2d(y, filter_x / filter_x.sum(), padding=[b,a,0,0]) + y = upfirdn2d.filter2d(y, filter_y / filter_y.sum(), padding=[0,0,b,a]) + y = y[:, :, max(b-iy,0) : H+b+a+min(-iy-a,0), max(b-ix,0) : W+b+a+min(-ix-a,0)] + z[:, :, zy0:zy1, zx0:zx1] = y + + m = torch.zeros_like(x) + mx0 = max(ix + a, 0) + my0 = max(iy + a, 0) + mx1 = min(ix - b, 0) + W + my1 = min(iy - b, 0) + H + if mx0 < mx1 and my0 < my1: + m[:, :, my0:my1, mx0:mx1] = 1 + return z, m + +#---------------------------------------------------------------------------- +# Construct an oriented low-pass filter that applies the appropriate +# bandlimit with respect to the input and output of the given affine 2D +# image transformation. + +def construct_affine_bandlimit_filter(mat, a=3, amax=16, aflt=64, up=4, cutoff_in=1, cutoff_out=1): + assert a <= amax < aflt + mat = torch.as_tensor(mat).to(torch.float32) + + # Construct 2D filter taps in input & output coordinate spaces. + taps = ((torch.arange(aflt * up * 2 - 1, device=mat.device) + 1) / up - aflt).roll(1 - aflt * up) + yi, xi = torch.meshgrid(taps, taps) + xo, yo = (torch.stack([xi, yi], dim=2) @ mat[:2, :2].t()).unbind(2) + + # Convolution of two oriented 2D sinc filters. + fi = sinc(xi * cutoff_in) * sinc(yi * cutoff_in) + fo = sinc(xo * cutoff_out) * sinc(yo * cutoff_out) + f = torch.fft.ifftn(torch.fft.fftn(fi) * torch.fft.fftn(fo)).real + + # Convolution of two oriented 2D Lanczos windows. + wi = lanczos_window(xi, a) * lanczos_window(yi, a) + wo = lanczos_window(xo, a) * lanczos_window(yo, a) + w = torch.fft.ifftn(torch.fft.fftn(wi) * torch.fft.fftn(wo)).real + + # Construct windowed FIR filter. + f = f * w + + # Finalize. + c = (aflt - amax) * up + f = f.roll([aflt * up - 1] * 2, dims=[0,1])[c:-c, c:-c] + f = torch.nn.functional.pad(f, [0, 1, 0, 1]).reshape(amax * 2, up, amax * 2, up) + f = f / f.sum([0,2], keepdim=True) / (up ** 2) + f = f.reshape(amax * 2 * up, amax * 2 * up)[:-1, :-1] + return f + +#---------------------------------------------------------------------------- +# Apply the given affine transformation to a batch of 2D images. + +def apply_affine_transformation(x, mat, up=4, **filter_kwargs): + _N, _C, H, W = x.shape + mat = torch.as_tensor(mat).to(dtype=torch.float32, device=x.device) + + # Construct filter. + f = construct_affine_bandlimit_filter(mat, up=up, **filter_kwargs) + assert f.ndim == 2 and f.shape[0] == f.shape[1] and f.shape[0] % 2 == 1 + p = f.shape[0] // 2 + + # Construct sampling grid. + theta = mat.inverse() + theta[:2, 2] *= 2 + theta[0, 2] += 1 / up / W + theta[1, 2] += 1 / up / H + theta[0, :] *= W / (W + p / up * 2) + theta[1, :] *= H / (H + p / up * 2) + theta = theta[:2, :3].unsqueeze(0).repeat([x.shape[0], 1, 1]) + g = torch.nn.functional.affine_grid(theta, x.shape, align_corners=False) + + # Resample image. + y = upfirdn2d.upsample2d(x=x, f=f, up=up, padding=p) + z = torch.nn.functional.grid_sample(y, g, mode='bilinear', padding_mode='zeros', align_corners=False) + + # Form mask. + m = torch.zeros_like(y) + c = p * 2 + 1 + m[:, :, c:-c, c:-c] = 1 + m = torch.nn.functional.grid_sample(m, g, mode='nearest', padding_mode='zeros', align_corners=False) + return z, m + +#---------------------------------------------------------------------------- +# Apply fractional rotation to a batch of 2D images. Corresponds to the +# operator R_\alpha in Appendix E.3. + +def apply_fractional_rotation(x, angle, a=3, **filter_kwargs): + angle = torch.as_tensor(angle).to(dtype=torch.float32, device=x.device) + mat = rotation_matrix(angle) + return apply_affine_transformation(x, mat, a=a, amax=a*2, **filter_kwargs) + +#---------------------------------------------------------------------------- +# Modify the frequency content of a batch of 2D images as if they had undergo +# fractional rotation -- but without actually rotating them. Corresponds to +# the operator R^*_\alpha in Appendix E.3. + +def apply_fractional_pseudo_rotation(x, angle, a=3, **filter_kwargs): + angle = torch.as_tensor(angle).to(dtype=torch.float32, device=x.device) + mat = rotation_matrix(-angle) + f = construct_affine_bandlimit_filter(mat, a=a, amax=a*2, up=1, **filter_kwargs) + y = upfirdn2d.filter2d(x=x, f=f) + m = torch.zeros_like(y) + c = f.shape[0] // 2 + m[:, :, c:-c, c:-c] = 1 + return y, m + +#---------------------------------------------------------------------------- +# Compute the selected equivariance metrics for the given generator. + +def compute_equivariance_metrics(opts, num_samples, batch_size, translate_max=0.125, rotate_max=1, compute_eqt_int=False, compute_eqt_frac=False, compute_eqr=False): + assert compute_eqt_int or compute_eqt_frac or compute_eqr + + # Setup generator and labels. + G = copy.deepcopy(opts.G).eval().requires_grad_(False).to(opts.device) + I = torch.eye(3, device=opts.device) + M = getattr(getattr(getattr(G, 'synthesis', None), 'input', None), 'transform', None) + if M is None: + raise ValueError('Cannot compute equivariance metrics; the given generator does not support user-specified image transformations') + c_iter = metric_utils.iterate_random_labels(opts=opts, batch_size=batch_size) + + # Sampling loop. + sums = None + progress = opts.progress.sub(tag='eq sampling', num_items=num_samples) + for batch_start in range(0, num_samples, batch_size * opts.num_gpus): + progress.update(batch_start) + s = [] + + # Randomize noise buffers, if any. + for name, buf in G.named_buffers(): + if name.endswith('.noise_const'): + buf.copy_(torch.randn_like(buf)) + + # Run mapping network. + z = torch.randn([batch_size, G.z_dim], device=opts.device) + c = next(c_iter) + ws = G.mapping(z=z, c=c) + + # Generate reference image. + M[:] = I + orig = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs) + + # Integer translation (EQ-T). + if compute_eqt_int: + t = (torch.rand(2, device=opts.device) * 2 - 1) * translate_max + t = (t * G.img_resolution).round() / G.img_resolution + M[:] = I + M[:2, 2] = -t + img = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs) + ref, mask = apply_integer_translation(orig, t[0], t[1]) + s += [(ref - img).square() * mask, mask] + + # Fractional translation (EQ-T_frac). + if compute_eqt_frac: + t = (torch.rand(2, device=opts.device) * 2 - 1) * translate_max + M[:] = I + M[:2, 2] = -t + img = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs) + ref, mask = apply_fractional_translation(orig, t[0], t[1]) + s += [(ref - img).square() * mask, mask] + + # Rotation (EQ-R). + if compute_eqr: + angle = (torch.rand([], device=opts.device) * 2 - 1) * (rotate_max * np.pi) + M[:] = rotation_matrix(-angle) + img = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs) + ref, ref_mask = apply_fractional_rotation(orig, angle) + pseudo, pseudo_mask = apply_fractional_pseudo_rotation(img, angle) + mask = ref_mask * pseudo_mask + s += [(ref - pseudo).square() * mask, mask] + + # Accumulate results. + s = torch.stack([x.to(torch.float64).sum() for x in s]) + sums = sums + s if sums is not None else s + progress.update(num_samples) + + # Compute PSNRs. + if opts.num_gpus > 1: + torch.distributed.all_reduce(sums) + sums = sums.cpu() + mses = sums[0::2] / sums[1::2] + psnrs = np.log10(2) * 20 - mses.log10() * 10 + psnrs = tuple(psnrs.numpy()) + return psnrs[0] if len(psnrs) == 1 else psnrs + +#---------------------------------------------------------------------------- diff --git a/metrics/frechet_inception_distance.py b/metrics/frechet_inception_distance.py new file mode 100644 index 0000000000000000000000000000000000000000..1bdd6b6ce1f44a345c1451150634bb2fa0c7e9b3 --- /dev/null +++ b/metrics/frechet_inception_distance.py @@ -0,0 +1,41 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Frechet Inception Distance (FID) from the paper +"GANs trained by a two time-scale update rule converge to a local Nash +equilibrium". Matches the original implementation by Heusel et al. at +https://github.com/bioinf-jku/TTUR/blob/master/fid.py""" + +import numpy as np +import scipy.linalg +from . import metric_utils + +#---------------------------------------------------------------------------- + +def compute_fid(opts, max_real, num_gen): + # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz + detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl' + detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer. + + mu_real, sigma_real = metric_utils.compute_feature_stats_for_dataset( + opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, + rel_lo=0, rel_hi=0, capture_mean_cov=True, max_items=max_real).get_mean_cov() + + mu_gen, sigma_gen = metric_utils.compute_feature_stats_for_generator( + opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, + rel_lo=0, rel_hi=1, capture_mean_cov=True, max_items=num_gen).get_mean_cov() + + if opts.rank != 0: + return float('nan') + + m = np.square(mu_gen - mu_real).sum() + s, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member + fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2)) + return float(fid) + +#---------------------------------------------------------------------------- diff --git a/metrics/inception_score.py b/metrics/inception_score.py new file mode 100644 index 0000000000000000000000000000000000000000..b9a7f4acbcd3c33db6d6848f61d176e0ff97295e --- /dev/null +++ b/metrics/inception_score.py @@ -0,0 +1,38 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Inception Score (IS) from the paper "Improved techniques for training +GANs". Matches the original implementation by Salimans et al. at +https://github.com/openai/improved-gan/blob/master/inception_score/model.py""" + +import numpy as np +from . import metric_utils + +#---------------------------------------------------------------------------- + +def compute_is(opts, num_gen, num_splits): + # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz + detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl' + detector_kwargs = dict(no_output_bias=True) # Match the original implementation by not applying bias in the softmax layer. + + gen_probs = metric_utils.compute_feature_stats_for_generator( + opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, + capture_all=True, max_items=num_gen).get_all() + + if opts.rank != 0: + return float('nan'), float('nan') + + scores = [] + for i in range(num_splits): + part = gen_probs[i * num_gen // num_splits : (i + 1) * num_gen // num_splits] + kl = part * (np.log(part) - np.log(np.mean(part, axis=0, keepdims=True))) + kl = np.mean(np.sum(kl, axis=1)) + scores.append(np.exp(kl)) + return float(np.mean(scores)), float(np.std(scores)) + +#---------------------------------------------------------------------------- diff --git a/metrics/kernel_inception_distance.py b/metrics/kernel_inception_distance.py new file mode 100644 index 0000000000000000000000000000000000000000..e8e0bd12ef56e64d8e77091aaf465891f4984d9e --- /dev/null +++ b/metrics/kernel_inception_distance.py @@ -0,0 +1,46 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Kernel Inception Distance (KID) from the paper "Demystifying MMD +GANs". Matches the original implementation by Binkowski et al. at +https://github.com/mbinkowski/MMD-GAN/blob/master/gan/compute_scores.py""" + +import numpy as np +from . import metric_utils + +#---------------------------------------------------------------------------- + +def compute_kid(opts, max_real, num_gen, num_subsets, max_subset_size): + # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz + detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl' + detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer. + + real_features = metric_utils.compute_feature_stats_for_dataset( + opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, + rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all() + + gen_features = metric_utils.compute_feature_stats_for_generator( + opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, + rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all() + + if opts.rank != 0: + return float('nan') + + n = real_features.shape[1] + m = min(min(real_features.shape[0], gen_features.shape[0]), max_subset_size) + t = 0 + for _subset_idx in range(num_subsets): + x = gen_features[np.random.choice(gen_features.shape[0], m, replace=False)] + y = real_features[np.random.choice(real_features.shape[0], m, replace=False)] + a = (x @ x.T / n + 1) ** 3 + (y @ y.T / n + 1) ** 3 + b = (x @ y.T / n + 1) ** 3 + t += (a.sum() - np.diag(a).sum()) / (m - 1) - b.sum() * 2 / m + kid = t / num_subsets / m + return float(kid) + +#---------------------------------------------------------------------------- diff --git a/metrics/metric_main.py b/metrics/metric_main.py new file mode 100644 index 0000000000000000000000000000000000000000..e4f7389fbe322aff06a1860d581da56f9d9ad937 --- /dev/null +++ b/metrics/metric_main.py @@ -0,0 +1,153 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Main API for computing and reporting quality metrics.""" + +import os +import time +import json +import torch +import dnnlib + +from . import metric_utils +from . import frechet_inception_distance +from . import kernel_inception_distance +from . import precision_recall +from . import perceptual_path_length +from . import inception_score +from . import equivariance + +#---------------------------------------------------------------------------- + +_metric_dict = dict() # name => fn + +def register_metric(fn): + assert callable(fn) + _metric_dict[fn.__name__] = fn + return fn + +def is_valid_metric(metric): + return metric in _metric_dict + +def list_valid_metrics(): + return list(_metric_dict.keys()) + +#---------------------------------------------------------------------------- + +def calc_metric(metric, **kwargs): # See metric_utils.MetricOptions for the full list of arguments. + assert is_valid_metric(metric) + opts = metric_utils.MetricOptions(**kwargs) + + # Calculate. + start_time = time.time() + results = _metric_dict[metric](opts) + total_time = time.time() - start_time + + # Broadcast results. + for key, value in list(results.items()): + if opts.num_gpus > 1: + value = torch.as_tensor(value, dtype=torch.float64, device=opts.device) + torch.distributed.broadcast(tensor=value, src=0) + value = float(value.cpu()) + results[key] = value + + # Decorate with metadata. + return dnnlib.EasyDict( + results = dnnlib.EasyDict(results), + metric = metric, + total_time = total_time, + total_time_str = dnnlib.util.format_time(total_time), + num_gpus = opts.num_gpus, + ) + +#---------------------------------------------------------------------------- + +def report_metric(result_dict, run_dir=None, snapshot_pkl=None): + metric = result_dict['metric'] + assert is_valid_metric(metric) + if run_dir is not None and snapshot_pkl is not None: + snapshot_pkl = os.path.relpath(snapshot_pkl, run_dir) + + jsonl_line = json.dumps(dict(result_dict, snapshot_pkl=snapshot_pkl, timestamp=time.time())) + print(jsonl_line) + if run_dir is not None and os.path.isdir(run_dir): + with open(os.path.join(run_dir, f'metric-{metric}.jsonl'), 'at') as f: + f.write(jsonl_line + '\n') + +#---------------------------------------------------------------------------- +# Recommended metrics. + +@register_metric +def fid50k_full(opts): + opts.dataset_kwargs.update(max_size=None, xflip=False) + fid = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=50000) + return dict(fid50k_full=fid) + +@register_metric +def kid50k_full(opts): + opts.dataset_kwargs.update(max_size=None, xflip=False) + kid = kernel_inception_distance.compute_kid(opts, max_real=1000000, num_gen=50000, num_subsets=100, max_subset_size=1000) + return dict(kid50k_full=kid) + +@register_metric +def pr50k3_full(opts): + opts.dataset_kwargs.update(max_size=None, xflip=False) + precision, recall = precision_recall.compute_pr(opts, max_real=200000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000) + return dict(pr50k3_full_precision=precision, pr50k3_full_recall=recall) + +@register_metric +def ppl2_wend(opts): + ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='end', crop=False, batch_size=2) + return dict(ppl2_wend=ppl) + +@register_metric +def eqt50k_int(opts): + opts.G_kwargs.update(force_fp32=True) + psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqt_int=True) + return dict(eqt50k_int=psnr) + +@register_metric +def eqt50k_frac(opts): + opts.G_kwargs.update(force_fp32=True) + psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqt_frac=True) + return dict(eqt50k_frac=psnr) + +@register_metric +def eqr50k(opts): + opts.G_kwargs.update(force_fp32=True) + psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqr=True) + return dict(eqr50k=psnr) + +#---------------------------------------------------------------------------- +# Legacy metrics. + +@register_metric +def fid50k(opts): + opts.dataset_kwargs.update(max_size=None) + fid = frechet_inception_distance.compute_fid(opts, max_real=50000, num_gen=50000) + return dict(fid50k=fid) + +@register_metric +def kid50k(opts): + opts.dataset_kwargs.update(max_size=None) + kid = kernel_inception_distance.compute_kid(opts, max_real=50000, num_gen=50000, num_subsets=100, max_subset_size=1000) + return dict(kid50k=kid) + +@register_metric +def pr50k3(opts): + opts.dataset_kwargs.update(max_size=None) + precision, recall = precision_recall.compute_pr(opts, max_real=50000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000) + return dict(pr50k3_precision=precision, pr50k3_recall=recall) + +@register_metric +def is50k(opts): + opts.dataset_kwargs.update(max_size=None, xflip=False) + mean, std = inception_score.compute_is(opts, num_gen=50000, num_splits=10) + return dict(is50k_mean=mean, is50k_std=std) + +#---------------------------------------------------------------------------- diff --git a/metrics/metric_utils.py b/metrics/metric_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..44b67eed7b5bbf029481ecbd865457fa42f7cc89 --- /dev/null +++ b/metrics/metric_utils.py @@ -0,0 +1,279 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Miscellaneous utilities used internally by the quality metrics.""" + +import os +import time +import hashlib +import pickle +import copy +import uuid +import numpy as np +import torch +import dnnlib + +#---------------------------------------------------------------------------- + +class MetricOptions: + def __init__(self, G=None, G_kwargs={}, dataset_kwargs={}, num_gpus=1, rank=0, device=None, progress=None, cache=True): + assert 0 <= rank < num_gpus + self.G = G + self.G_kwargs = dnnlib.EasyDict(G_kwargs) + self.dataset_kwargs = dnnlib.EasyDict(dataset_kwargs) + self.num_gpus = num_gpus + self.rank = rank + self.device = device if device is not None else torch.device('cuda', rank) + self.progress = progress.sub() if progress is not None and rank == 0 else ProgressMonitor() + self.cache = cache + +#---------------------------------------------------------------------------- + +_feature_detector_cache = dict() + +def get_feature_detector_name(url): + return os.path.splitext(url.split('/')[-1])[0] + +def get_feature_detector(url, device=torch.device('cpu'), num_gpus=1, rank=0, verbose=False): + assert 0 <= rank < num_gpus + key = (url, device) + if key not in _feature_detector_cache: + is_leader = (rank == 0) + if not is_leader and num_gpus > 1: + torch.distributed.barrier() # leader goes first + with dnnlib.util.open_url(url, verbose=(verbose and is_leader)) as f: + _feature_detector_cache[key] = pickle.load(f).to(device) + if is_leader and num_gpus > 1: + torch.distributed.barrier() # others follow + return _feature_detector_cache[key] + +#---------------------------------------------------------------------------- + +def iterate_random_labels(opts, batch_size): + if opts.G.c_dim == 0: + c = torch.zeros([batch_size, opts.G.c_dim], device=opts.device) + while True: + yield c + else: + dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs) + while True: + c = [dataset.get_label(np.random.randint(len(dataset))) for _i in range(batch_size)] + c = torch.from_numpy(np.stack(c)).pin_memory().to(opts.device) + yield c + +#---------------------------------------------------------------------------- + +class FeatureStats: + def __init__(self, capture_all=False, capture_mean_cov=False, max_items=None): + self.capture_all = capture_all + self.capture_mean_cov = capture_mean_cov + self.max_items = max_items + self.num_items = 0 + self.num_features = None + self.all_features = None + self.raw_mean = None + self.raw_cov = None + + def set_num_features(self, num_features): + if self.num_features is not None: + assert num_features == self.num_features + else: + self.num_features = num_features + self.all_features = [] + self.raw_mean = np.zeros([num_features], dtype=np.float64) + self.raw_cov = np.zeros([num_features, num_features], dtype=np.float64) + + def is_full(self): + return (self.max_items is not None) and (self.num_items >= self.max_items) + + def append(self, x): + x = np.asarray(x, dtype=np.float32) + assert x.ndim == 2 + if (self.max_items is not None) and (self.num_items + x.shape[0] > self.max_items): + if self.num_items >= self.max_items: + return + x = x[:self.max_items - self.num_items] + + self.set_num_features(x.shape[1]) + self.num_items += x.shape[0] + if self.capture_all: + self.all_features.append(x) + if self.capture_mean_cov: + x64 = x.astype(np.float64) + self.raw_mean += x64.sum(axis=0) + self.raw_cov += x64.T @ x64 + + def append_torch(self, x, num_gpus=1, rank=0): + assert isinstance(x, torch.Tensor) and x.ndim == 2 + assert 0 <= rank < num_gpus + if num_gpus > 1: + ys = [] + for src in range(num_gpus): + y = x.clone() + torch.distributed.broadcast(y, src=src) + ys.append(y) + x = torch.stack(ys, dim=1).flatten(0, 1) # interleave samples + self.append(x.cpu().numpy()) + + def get_all(self): + assert self.capture_all + return np.concatenate(self.all_features, axis=0) + + def get_all_torch(self): + return torch.from_numpy(self.get_all()) + + def get_mean_cov(self): + assert self.capture_mean_cov + mean = self.raw_mean / self.num_items + cov = self.raw_cov / self.num_items + cov = cov - np.outer(mean, mean) + return mean, cov + + def save(self, pkl_file): + with open(pkl_file, 'wb') as f: + pickle.dump(self.__dict__, f) + + @staticmethod + def load(pkl_file): + with open(pkl_file, 'rb') as f: + s = dnnlib.EasyDict(pickle.load(f)) + obj = FeatureStats(capture_all=s.capture_all, max_items=s.max_items) + obj.__dict__.update(s) + return obj + +#---------------------------------------------------------------------------- + +class ProgressMonitor: + def __init__(self, tag=None, num_items=None, flush_interval=1000, verbose=False, progress_fn=None, pfn_lo=0, pfn_hi=1000, pfn_total=1000): + self.tag = tag + self.num_items = num_items + self.verbose = verbose + self.flush_interval = flush_interval + self.progress_fn = progress_fn + self.pfn_lo = pfn_lo + self.pfn_hi = pfn_hi + self.pfn_total = pfn_total + self.start_time = time.time() + self.batch_time = self.start_time + self.batch_items = 0 + if self.progress_fn is not None: + self.progress_fn(self.pfn_lo, self.pfn_total) + + def update(self, cur_items): + assert (self.num_items is None) or (cur_items <= self.num_items) + if (cur_items < self.batch_items + self.flush_interval) and (self.num_items is None or cur_items < self.num_items): + return + cur_time = time.time() + total_time = cur_time - self.start_time + time_per_item = (cur_time - self.batch_time) / max(cur_items - self.batch_items, 1) + if (self.verbose) and (self.tag is not None): + print(f'{self.tag:<19s} items {cur_items:<7d} time {dnnlib.util.format_time(total_time):<12s} ms/item {time_per_item*1e3:.2f}') + self.batch_time = cur_time + self.batch_items = cur_items + + if (self.progress_fn is not None) and (self.num_items is not None): + self.progress_fn(self.pfn_lo + (self.pfn_hi - self.pfn_lo) * (cur_items / self.num_items), self.pfn_total) + + def sub(self, tag=None, num_items=None, flush_interval=1000, rel_lo=0, rel_hi=1): + return ProgressMonitor( + tag = tag, + num_items = num_items, + flush_interval = flush_interval, + verbose = self.verbose, + progress_fn = self.progress_fn, + pfn_lo = self.pfn_lo + (self.pfn_hi - self.pfn_lo) * rel_lo, + pfn_hi = self.pfn_lo + (self.pfn_hi - self.pfn_lo) * rel_hi, + pfn_total = self.pfn_total, + ) + +#---------------------------------------------------------------------------- + +def compute_feature_stats_for_dataset(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64, data_loader_kwargs=None, max_items=None, **stats_kwargs): + dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs) + if data_loader_kwargs is None: + data_loader_kwargs = dict(pin_memory=True, num_workers=3, prefetch_factor=2) + + # Try to lookup from cache. + cache_file = None + if opts.cache: + # Choose cache file name. + args = dict(dataset_kwargs=opts.dataset_kwargs, detector_url=detector_url, detector_kwargs=detector_kwargs, stats_kwargs=stats_kwargs) + md5 = hashlib.md5(repr(sorted(args.items())).encode('utf-8')) + cache_tag = f'{dataset.name}-{get_feature_detector_name(detector_url)}-{md5.hexdigest()}' + cache_file = dnnlib.make_cache_dir_path('gan-metrics', cache_tag + '.pkl') + + # Check if the file exists (all processes must agree). + flag = os.path.isfile(cache_file) if opts.rank == 0 else False + if opts.num_gpus > 1: + flag = torch.as_tensor(flag, dtype=torch.float32, device=opts.device) + torch.distributed.broadcast(tensor=flag, src=0) + flag = (float(flag.cpu()) != 0) + + # Load. + if flag: + return FeatureStats.load(cache_file) + + # Initialize. + num_items = len(dataset) + if max_items is not None: + num_items = min(num_items, max_items) + stats = FeatureStats(max_items=num_items, **stats_kwargs) + progress = opts.progress.sub(tag='dataset features', num_items=num_items, rel_lo=rel_lo, rel_hi=rel_hi) + detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose) + + # Main loop. + item_subset = [(i * opts.num_gpus + opts.rank) % num_items for i in range((num_items - 1) // opts.num_gpus + 1)] + for images, _labels in torch.utils.data.DataLoader(dataset=dataset, sampler=item_subset, batch_size=batch_size, **data_loader_kwargs): + if images.shape[1] == 1: + images = images.repeat([1, 3, 1, 1]) + features = detector(images.to(opts.device), **detector_kwargs) + stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank) + progress.update(stats.num_items) + + # Save to cache. + if cache_file is not None and opts.rank == 0: + os.makedirs(os.path.dirname(cache_file), exist_ok=True) + temp_file = cache_file + '.' + uuid.uuid4().hex + stats.save(temp_file) + os.replace(temp_file, cache_file) # atomic + return stats + +#---------------------------------------------------------------------------- + +def compute_feature_stats_for_generator(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64, batch_gen=None, **stats_kwargs): + if batch_gen is None: + batch_gen = min(batch_size, 4) + assert batch_size % batch_gen == 0 + + # Setup generator and labels. + G = copy.deepcopy(opts.G).eval().requires_grad_(False).to(opts.device) + c_iter = iterate_random_labels(opts=opts, batch_size=batch_gen) + + # Initialize. + stats = FeatureStats(**stats_kwargs) + assert stats.max_items is not None + progress = opts.progress.sub(tag='generator features', num_items=stats.max_items, rel_lo=rel_lo, rel_hi=rel_hi) + detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose) + + # Main loop. + while not stats.is_full(): + images = [] + for _i in range(batch_size // batch_gen): + z = torch.randn([batch_gen, G.z_dim], device=opts.device) + img = G(z=z, c=next(c_iter), **opts.G_kwargs) + img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8) + images.append(img) + images = torch.cat(images) + if images.shape[1] == 1: + images = images.repeat([1, 3, 1, 1]) + features = detector(images, **detector_kwargs) + stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank) + progress.update(stats.num_items) + return stats + +#---------------------------------------------------------------------------- diff --git a/metrics/perceptual_path_length.py b/metrics/perceptual_path_length.py new file mode 100644 index 0000000000000000000000000000000000000000..7fb74396475181c3a80feb6321d3b0f45eda7000 --- /dev/null +++ b/metrics/perceptual_path_length.py @@ -0,0 +1,125 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Perceptual Path Length (PPL) from the paper "A Style-Based Generator +Architecture for Generative Adversarial Networks". Matches the original +implementation by Karras et al. at +https://github.com/NVlabs/stylegan/blob/master/metrics/perceptual_path_length.py""" + +import copy +import numpy as np +import torch +from . import metric_utils + +#---------------------------------------------------------------------------- + +# Spherical interpolation of a batch of vectors. +def slerp(a, b, t): + a = a / a.norm(dim=-1, keepdim=True) + b = b / b.norm(dim=-1, keepdim=True) + d = (a * b).sum(dim=-1, keepdim=True) + p = t * torch.acos(d) + c = b - d * a + c = c / c.norm(dim=-1, keepdim=True) + d = a * torch.cos(p) + c * torch.sin(p) + d = d / d.norm(dim=-1, keepdim=True) + return d + +#---------------------------------------------------------------------------- + +class PPLSampler(torch.nn.Module): + def __init__(self, G, G_kwargs, epsilon, space, sampling, crop, vgg16): + assert space in ['z', 'w'] + assert sampling in ['full', 'end'] + super().__init__() + self.G = copy.deepcopy(G) + self.G_kwargs = G_kwargs + self.epsilon = epsilon + self.space = space + self.sampling = sampling + self.crop = crop + self.vgg16 = copy.deepcopy(vgg16) + + def forward(self, c): + # Generate random latents and interpolation t-values. + t = torch.rand([c.shape[0]], device=c.device) * (1 if self.sampling == 'full' else 0) + z0, z1 = torch.randn([c.shape[0] * 2, self.G.z_dim], device=c.device).chunk(2) + + # Interpolate in W or Z. + if self.space == 'w': + w0, w1 = self.G.mapping(z=torch.cat([z0,z1]), c=torch.cat([c,c])).chunk(2) + wt0 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2)) + wt1 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2) + self.epsilon) + else: # space == 'z' + zt0 = slerp(z0, z1, t.unsqueeze(1)) + zt1 = slerp(z0, z1, t.unsqueeze(1) + self.epsilon) + wt0, wt1 = self.G.mapping(z=torch.cat([zt0,zt1]), c=torch.cat([c,c])).chunk(2) + + # Randomize noise buffers. + for name, buf in self.G.named_buffers(): + if name.endswith('.noise_const'): + buf.copy_(torch.randn_like(buf)) + + # Generate images. + img = self.G.synthesis(ws=torch.cat([wt0,wt1]), noise_mode='const', force_fp32=True, **self.G_kwargs) + + # Center crop. + if self.crop: + assert img.shape[2] == img.shape[3] + c = img.shape[2] // 8 + img = img[:, :, c*3 : c*7, c*2 : c*6] + + # Downsample to 256x256. + factor = self.G.img_resolution // 256 + if factor > 1: + img = img.reshape([-1, img.shape[1], img.shape[2] // factor, factor, img.shape[3] // factor, factor]).mean([3, 5]) + + # Scale dynamic range from [-1,1] to [0,255]. + img = (img + 1) * (255 / 2) + if self.G.img_channels == 1: + img = img.repeat([1, 3, 1, 1]) + + # Evaluate differential LPIPS. + lpips_t0, lpips_t1 = self.vgg16(img, resize_images=False, return_lpips=True).chunk(2) + dist = (lpips_t0 - lpips_t1).square().sum(1) / self.epsilon ** 2 + return dist + +#---------------------------------------------------------------------------- + +def compute_ppl(opts, num_samples, epsilon, space, sampling, crop, batch_size): + vgg16_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/vgg16.pkl' + vgg16 = metric_utils.get_feature_detector(vgg16_url, num_gpus=opts.num_gpus, rank=opts.rank, verbose=opts.progress.verbose) + + # Setup sampler and labels. + sampler = PPLSampler(G=opts.G, G_kwargs=opts.G_kwargs, epsilon=epsilon, space=space, sampling=sampling, crop=crop, vgg16=vgg16) + sampler.eval().requires_grad_(False).to(opts.device) + c_iter = metric_utils.iterate_random_labels(opts=opts, batch_size=batch_size) + + # Sampling loop. + dist = [] + progress = opts.progress.sub(tag='ppl sampling', num_items=num_samples) + for batch_start in range(0, num_samples, batch_size * opts.num_gpus): + progress.update(batch_start) + x = sampler(next(c_iter)) + for src in range(opts.num_gpus): + y = x.clone() + if opts.num_gpus > 1: + torch.distributed.broadcast(y, src=src) + dist.append(y) + progress.update(num_samples) + + # Compute PPL. + if opts.rank != 0: + return float('nan') + dist = torch.cat(dist)[:num_samples].cpu().numpy() + lo = np.percentile(dist, 1, interpolation='lower') + hi = np.percentile(dist, 99, interpolation='higher') + ppl = np.extract(np.logical_and(dist >= lo, dist <= hi), dist).mean() + return float(ppl) + +#---------------------------------------------------------------------------- diff --git a/metrics/precision_recall.py b/metrics/precision_recall.py new file mode 100644 index 0000000000000000000000000000000000000000..17e5b4286b43e2d09aeba19d2521869a6cbe7ea1 --- /dev/null +++ b/metrics/precision_recall.py @@ -0,0 +1,62 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Precision/Recall (PR) from the paper "Improved Precision and Recall +Metric for Assessing Generative Models". Matches the original implementation +by Kynkaanniemi et al. at +https://github.com/kynkaat/improved-precision-and-recall-metric/blob/master/precision_recall.py""" + +import torch +from . import metric_utils + +#---------------------------------------------------------------------------- + +def compute_distances(row_features, col_features, num_gpus, rank, col_batch_size): + assert 0 <= rank < num_gpus + num_cols = col_features.shape[0] + num_batches = ((num_cols - 1) // col_batch_size // num_gpus + 1) * num_gpus + col_batches = torch.nn.functional.pad(col_features, [0, 0, 0, -num_cols % num_batches]).chunk(num_batches) + dist_batches = [] + for col_batch in col_batches[rank :: num_gpus]: + dist_batch = torch.cdist(row_features.unsqueeze(0), col_batch.unsqueeze(0))[0] + for src in range(num_gpus): + dist_broadcast = dist_batch.clone() + if num_gpus > 1: + torch.distributed.broadcast(dist_broadcast, src=src) + dist_batches.append(dist_broadcast.cpu() if rank == 0 else None) + return torch.cat(dist_batches, dim=1)[:, :num_cols] if rank == 0 else None + +#---------------------------------------------------------------------------- + +def compute_pr(opts, max_real, num_gen, nhood_size, row_batch_size, col_batch_size): + detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/vgg16.pkl' + detector_kwargs = dict(return_features=True) + + real_features = metric_utils.compute_feature_stats_for_dataset( + opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, + rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all_torch().to(torch.float16).to(opts.device) + + gen_features = metric_utils.compute_feature_stats_for_generator( + opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, + rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all_torch().to(torch.float16).to(opts.device) + + results = dict() + for name, manifold, probes in [('precision', real_features, gen_features), ('recall', gen_features, real_features)]: + kth = [] + for manifold_batch in manifold.split(row_batch_size): + dist = compute_distances(row_features=manifold_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size) + kth.append(dist.to(torch.float32).kthvalue(nhood_size + 1).values.to(torch.float16) if opts.rank == 0 else None) + kth = torch.cat(kth) if opts.rank == 0 else None + pred = [] + for probes_batch in probes.split(row_batch_size): + dist = compute_distances(row_features=probes_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size) + pred.append((dist <= kth).any(dim=1) if opts.rank == 0 else None) + results[name] = float(torch.cat(pred).to(torch.float32).mean() if opts.rank == 0 else 'nan') + return results['precision'], results['recall'] + +#---------------------------------------------------------------------------- diff --git a/torch_utils/__init__.py b/torch_utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8dd34882519598c472f1224cfe68c9ff6952ce69 --- /dev/null +++ b/torch_utils/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +# empty diff --git a/torch_utils/custom_ops.py b/torch_utils/custom_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..dffd4bd8a75b3495862954546ea04bb7ef39c998 --- /dev/null +++ b/torch_utils/custom_ops.py @@ -0,0 +1,157 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +import glob +import hashlib +import importlib +import os +import re +import shutil +import uuid + +import torch +import torch.utils.cpp_extension +from torch.utils.file_baton import FileBaton + +#---------------------------------------------------------------------------- +# Global options. + +verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full' + +#---------------------------------------------------------------------------- +# Internal helper funcs. + +def _find_compiler_bindir(): + patterns = [ + 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64', + 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64', + 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64', + 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin', + ] + for pattern in patterns: + matches = sorted(glob.glob(pattern)) + if len(matches): + return matches[-1] + return None + +#---------------------------------------------------------------------------- + +def _get_mangled_gpu_name(): + name = torch.cuda.get_device_name().lower() + out = [] + for c in name: + if re.match('[a-z0-9_-]+', c): + out.append(c) + else: + out.append('-') + return ''.join(out) + +#---------------------------------------------------------------------------- +# Main entry point for compiling and loading C++/CUDA plugins. + +_cached_plugins = dict() + +def get_plugin(module_name, sources, headers=None, source_dir=None, **build_kwargs): + assert verbosity in ['none', 'brief', 'full'] + if headers is None: + headers = [] + if source_dir is not None: + sources = [os.path.join(source_dir, fname) for fname in sources] + headers = [os.path.join(source_dir, fname) for fname in headers] + + # Already cached? + if module_name in _cached_plugins: + return _cached_plugins[module_name] + + # Print status. + if verbosity == 'full': + print(f'Setting up PyTorch plugin "{module_name}"...') + elif verbosity == 'brief': + print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True) + verbose_build = (verbosity == 'full') + + # Compile and load. + try: # pylint: disable=too-many-nested-blocks + # Make sure we can find the necessary compiler binaries. + if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0: + compiler_bindir = _find_compiler_bindir() + if compiler_bindir is None: + raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".') + os.environ['PATH'] += ';' + compiler_bindir + + # Some containers set TORCH_CUDA_ARCH_LIST to a list that can either + # break the build or unnecessarily restrict what's available to nvcc. + # Unset it to let nvcc decide based on what's available on the + # machine. + os.environ['TORCH_CUDA_ARCH_LIST'] = '' + + # Incremental build md5sum trickery. Copies all the input source files + # into a cached build directory under a combined md5 digest of the input + # source files. Copying is done only if the combined digest has changed. + # This keeps input file timestamps and filenames the same as in previous + # extension builds, allowing for fast incremental rebuilds. + # + # This optimization is done only in case all the source files reside in + # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR + # environment variable is set (we take this as a signal that the user + # actually cares about this.) + # + # EDIT: We now do it regardless of TORCH_EXTENSIOS_DIR, in order to work + # around the *.cu dependency bug in ninja config. + # + all_source_files = sorted(sources + headers) + all_source_dirs = set(os.path.dirname(fname) for fname in all_source_files) + if len(all_source_dirs) == 1: # and ('TORCH_EXTENSIONS_DIR' in os.environ): + + # Compute combined hash digest for all source files. + hash_md5 = hashlib.md5() + for src in all_source_files: + with open(src, 'rb') as f: + hash_md5.update(f.read()) + + # Select cached build directory name. + source_digest = hash_md5.hexdigest() + build_top_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access + cached_build_dir = os.path.join(build_top_dir, f'{source_digest}-{_get_mangled_gpu_name()}') + + if not os.path.isdir(cached_build_dir): + tmpdir = f'{build_top_dir}/srctmp-{uuid.uuid4().hex}' + os.makedirs(tmpdir) + for src in all_source_files: + shutil.copyfile(src, os.path.join(tmpdir, os.path.basename(src))) + try: + os.replace(tmpdir, cached_build_dir) # atomic + except OSError: + # source directory already exists, delete tmpdir and its contents. + shutil.rmtree(tmpdir) + if not os.path.isdir(cached_build_dir): raise + + # Compile. + cached_sources = [os.path.join(cached_build_dir, os.path.basename(fname)) for fname in sources] + torch.utils.cpp_extension.load(name=module_name, build_directory=cached_build_dir, + verbose=verbose_build, sources=cached_sources, **build_kwargs) + else: + torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs) + + # Load. + module = importlib.import_module(module_name) + + except: + if verbosity == 'brief': + print('Failed!') + raise + + # Print status and add to cache dict. + if verbosity == 'full': + print(f'Done setting up PyTorch plugin "{module_name}".') + elif verbosity == 'brief': + print('Done.') + _cached_plugins[module_name] = module + return module + +#---------------------------------------------------------------------------- diff --git a/torch_utils/misc.py b/torch_utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..02f97e276e756bb1c4140ac16e7e4bcc63da628a --- /dev/null +++ b/torch_utils/misc.py @@ -0,0 +1,266 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +import re +import contextlib +import numpy as np +import torch +import warnings +import dnnlib + +#---------------------------------------------------------------------------- +# Cached construction of constant tensors. Avoids CPU=>GPU copy when the +# same constant is used multiple times. + +_constant_cache = dict() + +def constant(value, shape=None, dtype=None, device=None, memory_format=None): + value = np.asarray(value) + if shape is not None: + shape = tuple(shape) + if dtype is None: + dtype = torch.get_default_dtype() + if device is None: + device = torch.device('cpu') + if memory_format is None: + memory_format = torch.contiguous_format + + key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format) + tensor = _constant_cache.get(key, None) + if tensor is None: + tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device) + if shape is not None: + tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape)) + tensor = tensor.contiguous(memory_format=memory_format) + _constant_cache[key] = tensor + return tensor + +#---------------------------------------------------------------------------- +# Replace NaN/Inf with specified numerical values. + +try: + nan_to_num = torch.nan_to_num # 1.8.0a0 +except AttributeError: + def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin + assert isinstance(input, torch.Tensor) + if posinf is None: + posinf = torch.finfo(input.dtype).max + if neginf is None: + neginf = torch.finfo(input.dtype).min + assert nan == 0 + return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out) + +#---------------------------------------------------------------------------- +# Symbolic assert. + +try: + symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access +except AttributeError: + symbolic_assert = torch.Assert # 1.7.0 + +#---------------------------------------------------------------------------- +# Context manager to temporarily suppress known warnings in torch.jit.trace(). +# Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672 + +@contextlib.contextmanager +def suppress_tracer_warnings(): + flt = ('ignore', None, torch.jit.TracerWarning, None, 0) + warnings.filters.insert(0, flt) + yield + warnings.filters.remove(flt) + +#---------------------------------------------------------------------------- +# Assert that the shape of a tensor matches the given list of integers. +# None indicates that the size of a dimension is allowed to vary. +# Performs symbolic assertion when used in torch.jit.trace(). + +def assert_shape(tensor, ref_shape): + if tensor.ndim != len(ref_shape): + raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}') + for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)): + if ref_size is None: + pass + elif isinstance(ref_size, torch.Tensor): + with suppress_tracer_warnings(): # as_tensor results are registered as constants + symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}') + elif isinstance(size, torch.Tensor): + with suppress_tracer_warnings(): # as_tensor results are registered as constants + symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}') + elif size != ref_size: + raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}') + +#---------------------------------------------------------------------------- +# Function decorator that calls torch.autograd.profiler.record_function(). + +def profiled_function(fn): + def decorator(*args, **kwargs): + with torch.autograd.profiler.record_function(fn.__name__): + return fn(*args, **kwargs) + decorator.__name__ = fn.__name__ + return decorator + +#---------------------------------------------------------------------------- +# Sampler for torch.utils.data.DataLoader that loops over the dataset +# indefinitely, shuffling items as it goes. + +class InfiniteSampler(torch.utils.data.Sampler): + def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5): + assert len(dataset) > 0 + assert num_replicas > 0 + assert 0 <= rank < num_replicas + assert 0 <= window_size <= 1 + super().__init__(dataset) + self.dataset = dataset + self.rank = rank + self.num_replicas = num_replicas + self.shuffle = shuffle + self.seed = seed + self.window_size = window_size + + def __iter__(self): + order = np.arange(len(self.dataset)) + rnd = None + window = 0 + if self.shuffle: + rnd = np.random.RandomState(self.seed) + rnd.shuffle(order) + window = int(np.rint(order.size * self.window_size)) + + idx = 0 + while True: + i = idx % order.size + if idx % self.num_replicas == self.rank: + yield order[i] + if window >= 2: + j = (i - rnd.randint(window)) % order.size + order[i], order[j] = order[j], order[i] + idx += 1 + +#---------------------------------------------------------------------------- +# Utilities for operating with torch.nn.Module parameters and buffers. + +def params_and_buffers(module): + assert isinstance(module, torch.nn.Module) + return list(module.parameters()) + list(module.buffers()) + +def named_params_and_buffers(module): + assert isinstance(module, torch.nn.Module) + return list(module.named_parameters()) + list(module.named_buffers()) + +def copy_params_and_buffers(src_module, dst_module, require_all=False): + assert isinstance(src_module, torch.nn.Module) + assert isinstance(dst_module, torch.nn.Module) + src_tensors = dict(named_params_and_buffers(src_module)) + for name, tensor in named_params_and_buffers(dst_module): + assert (name in src_tensors) or (not require_all) + if name in src_tensors: + tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad) + +#---------------------------------------------------------------------------- +# Context manager for easily enabling/disabling DistributedDataParallel +# synchronization. + +@contextlib.contextmanager +def ddp_sync(module, sync): + assert isinstance(module, torch.nn.Module) + if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel): + yield + else: + with module.no_sync(): + yield + +#---------------------------------------------------------------------------- +# Check DistributedDataParallel consistency across processes. + +def check_ddp_consistency(module, ignore_regex=None): + assert isinstance(module, torch.nn.Module) + for name, tensor in named_params_and_buffers(module): + fullname = type(module).__name__ + '.' + name + if ignore_regex is not None and re.fullmatch(ignore_regex, fullname): + continue + tensor = tensor.detach() + if tensor.is_floating_point(): + tensor = nan_to_num(tensor) + other = tensor.clone() + torch.distributed.broadcast(tensor=other, src=0) + assert (tensor == other).all(), fullname + +#---------------------------------------------------------------------------- +# Print summary table of module hierarchy. + +def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True): + assert isinstance(module, torch.nn.Module) + assert not isinstance(module, torch.jit.ScriptModule) + assert isinstance(inputs, (tuple, list)) + + # Register hooks. + entries = [] + nesting = [0] + def pre_hook(_mod, _inputs): + nesting[0] += 1 + def post_hook(mod, _inputs, outputs): + nesting[0] -= 1 + if nesting[0] <= max_nesting: + outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs] + outputs = [t for t in outputs if isinstance(t, torch.Tensor)] + entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs)) + hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()] + hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()] + + # Run module. + outputs = module(*inputs) + for hook in hooks: + hook.remove() + + # Identify unique outputs, parameters, and buffers. + tensors_seen = set() + for e in entries: + e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen] + e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen] + e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen] + tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs} + + # Filter out redundant entries. + if skip_redundant: + entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)] + + # Construct table. + rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']] + rows += [['---'] * len(rows[0])] + param_total = 0 + buffer_total = 0 + submodule_names = {mod: name for name, mod in module.named_modules()} + for e in entries: + name = '' if e.mod is module else submodule_names[e.mod] + param_size = sum(t.numel() for t in e.unique_params) + buffer_size = sum(t.numel() for t in e.unique_buffers) + output_shapes = [str(list(t.shape)) for t in e.outputs] + output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs] + rows += [[ + name + (':0' if len(e.outputs) >= 2 else ''), + str(param_size) if param_size else '-', + str(buffer_size) if buffer_size else '-', + (output_shapes + ['-'])[0], + (output_dtypes + ['-'])[0], + ]] + for idx in range(1, len(e.outputs)): + rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]] + param_total += param_size + buffer_total += buffer_size + rows += [['---'] * len(rows[0])] + rows += [['Total', str(param_total), str(buffer_total), '-', '-']] + + # Print table. + widths = [max(len(cell) for cell in column) for column in zip(*rows)] + print() + for row in rows: + print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths))) + print() + return outputs + +#---------------------------------------------------------------------------- diff --git a/torch_utils/ops/__init__.py b/torch_utils/ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8dd34882519598c472f1224cfe68c9ff6952ce69 --- /dev/null +++ b/torch_utils/ops/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +# empty diff --git a/torch_utils/ops/bias_act.cpp b/torch_utils/ops/bias_act.cpp new file mode 100644 index 0000000000000000000000000000000000000000..218bc8fdf74f8e7ff74f6676f49b231c6a57bfb4 --- /dev/null +++ b/torch_utils/ops/bias_act.cpp @@ -0,0 +1,99 @@ +// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include +#include +#include +#include "bias_act.h" + +//------------------------------------------------------------------------ + +static bool has_same_layout(torch::Tensor x, torch::Tensor y) +{ + if (x.dim() != y.dim()) + return false; + for (int64_t i = 0; i < x.dim(); i++) + { + if (x.size(i) != y.size(i)) + return false; + if (x.size(i) >= 2 && x.stride(i) != y.stride(i)) + return false; + } + return true; +} + +//------------------------------------------------------------------------ + +static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp) +{ + // Validate arguments. + TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); + TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x"); + TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x"); + TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x"); + TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x"); + TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); + TORCH_CHECK(b.dim() == 1, "b must have rank 1"); + TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds"); + TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements"); + TORCH_CHECK(grad >= 0, "grad must be non-negative"); + + // Validate layout. + TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense"); + TORCH_CHECK(b.is_contiguous(), "b must be contiguous"); + TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x"); + TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x"); + TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x"); + + // Create output tensor. + const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); + torch::Tensor y = torch::empty_like(x); + TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x"); + + // Initialize CUDA kernel parameters. + bias_act_kernel_params p; + p.x = x.data_ptr(); + p.b = (b.numel()) ? b.data_ptr() : NULL; + p.xref = (xref.numel()) ? xref.data_ptr() : NULL; + p.yref = (yref.numel()) ? yref.data_ptr() : NULL; + p.dy = (dy.numel()) ? dy.data_ptr() : NULL; + p.y = y.data_ptr(); + p.grad = grad; + p.act = act; + p.alpha = alpha; + p.gain = gain; + p.clamp = clamp; + p.sizeX = (int)x.numel(); + p.sizeB = (int)b.numel(); + p.stepB = (b.numel()) ? (int)x.stride(dim) : 1; + + // Choose CUDA kernel. + void* kernel; + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] + { + kernel = choose_bias_act_kernel(p); + }); + TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func"); + + // Launch CUDA kernel. + p.loopX = 4; + int blockSize = 4 * 32; + int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1; + void* args[] = {&p}; + AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); + return y; +} + +//------------------------------------------------------------------------ + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("bias_act", &bias_act); +} + +//------------------------------------------------------------------------ diff --git a/torch_utils/ops/bias_act.cu b/torch_utils/ops/bias_act.cu new file mode 100644 index 0000000000000000000000000000000000000000..77979ab3c814dcc8e634468b0738e10b7df250fa --- /dev/null +++ b/torch_utils/ops/bias_act.cu @@ -0,0 +1,173 @@ +// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include +#include "bias_act.h" + +//------------------------------------------------------------------------ +// Helpers. + +template struct InternalType; +template <> struct InternalType { typedef double scalar_t; }; +template <> struct InternalType { typedef float scalar_t; }; +template <> struct InternalType { typedef float scalar_t; }; + +//------------------------------------------------------------------------ +// CUDA kernel. + +template +__global__ void bias_act_kernel(bias_act_kernel_params p) +{ + typedef typename InternalType::scalar_t scalar_t; + int G = p.grad; + scalar_t alpha = (scalar_t)p.alpha; + scalar_t gain = (scalar_t)p.gain; + scalar_t clamp = (scalar_t)p.clamp; + scalar_t one = (scalar_t)1; + scalar_t two = (scalar_t)2; + scalar_t expRange = (scalar_t)80; + scalar_t halfExpRange = (scalar_t)40; + scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946; + scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717; + + // Loop over elements. + int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x; + for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x) + { + // Load. + scalar_t x = (scalar_t)((const T*)p.x)[xi]; + scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0; + scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0; + scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0; + scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one; + scalar_t yy = (gain != 0) ? yref / gain : 0; + scalar_t y = 0; + + // Apply bias. + ((G == 0) ? x : xref) += b; + + // linear + if (A == 1) + { + if (G == 0) y = x; + if (G == 1) y = x; + } + + // relu + if (A == 2) + { + if (G == 0) y = (x > 0) ? x : 0; + if (G == 1) y = (yy > 0) ? x : 0; + } + + // lrelu + if (A == 3) + { + if (G == 0) y = (x > 0) ? x : x * alpha; + if (G == 1) y = (yy > 0) ? x : x * alpha; + } + + // tanh + if (A == 4) + { + if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); } + if (G == 1) y = x * (one - yy * yy); + if (G == 2) y = x * (one - yy * yy) * (-two * yy); + } + + // sigmoid + if (A == 5) + { + if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one); + if (G == 1) y = x * yy * (one - yy); + if (G == 2) y = x * yy * (one - yy) * (one - two * yy); + } + + // elu + if (A == 6) + { + if (G == 0) y = (x >= 0) ? x : exp(x) - one; + if (G == 1) y = (yy >= 0) ? x : x * (yy + one); + if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one); + } + + // selu + if (A == 7) + { + if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one); + if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha); + if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha); + } + + // softplus + if (A == 8) + { + if (G == 0) y = (x > expRange) ? x : log(exp(x) + one); + if (G == 1) y = x * (one - exp(-yy)); + if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); } + } + + // swish + if (A == 9) + { + if (G == 0) + y = (x < -expRange) ? 0 : x / (exp(-x) + one); + else + { + scalar_t c = exp(xref); + scalar_t d = c + one; + if (G == 1) + y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d); + else + y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d); + yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain; + } + } + + // Apply gain. + y *= gain * dy; + + // Clamp. + if (clamp >= 0) + { + if (G == 0) + y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp; + else + y = (yref > -clamp & yref < clamp) ? y : 0; + } + + // Store. + ((T*)p.y)[xi] = (T)y; + } +} + +//------------------------------------------------------------------------ +// CUDA kernel selection. + +template void* choose_bias_act_kernel(const bias_act_kernel_params& p) +{ + if (p.act == 1) return (void*)bias_act_kernel; + if (p.act == 2) return (void*)bias_act_kernel; + if (p.act == 3) return (void*)bias_act_kernel; + if (p.act == 4) return (void*)bias_act_kernel; + if (p.act == 5) return (void*)bias_act_kernel; + if (p.act == 6) return (void*)bias_act_kernel; + if (p.act == 7) return (void*)bias_act_kernel; + if (p.act == 8) return (void*)bias_act_kernel; + if (p.act == 9) return (void*)bias_act_kernel; + return NULL; +} + +//------------------------------------------------------------------------ +// Template specializations. + +template void* choose_bias_act_kernel (const bias_act_kernel_params& p); +template void* choose_bias_act_kernel (const bias_act_kernel_params& p); +template void* choose_bias_act_kernel (const bias_act_kernel_params& p); + +//------------------------------------------------------------------------ diff --git a/torch_utils/ops/bias_act.h b/torch_utils/ops/bias_act.h new file mode 100644 index 0000000000000000000000000000000000000000..db231527a14bde51e26dbf41c6cd4f771a7ea50c --- /dev/null +++ b/torch_utils/ops/bias_act.h @@ -0,0 +1,38 @@ +// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +//------------------------------------------------------------------------ +// CUDA kernel parameters. + +struct bias_act_kernel_params +{ + const void* x; // [sizeX] + const void* b; // [sizeB] or NULL + const void* xref; // [sizeX] or NULL + const void* yref; // [sizeX] or NULL + const void* dy; // [sizeX] or NULL + void* y; // [sizeX] + + int grad; + int act; + float alpha; + float gain; + float clamp; + + int sizeX; + int sizeB; + int stepB; + int loopX; +}; + +//------------------------------------------------------------------------ +// CUDA kernel selection. + +template void* choose_bias_act_kernel(const bias_act_kernel_params& p); + +//------------------------------------------------------------------------ diff --git a/torch_utils/ops/bias_act.py b/torch_utils/ops/bias_act.py new file mode 100644 index 0000000000000000000000000000000000000000..af13db543c63934d902c08f1a957bee6012e1f27 --- /dev/null +++ b/torch_utils/ops/bias_act.py @@ -0,0 +1,209 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Custom PyTorch ops for efficient bias and activation.""" + +import os +import numpy as np +import torch +import dnnlib + +from .. import custom_ops +from .. import misc + +#---------------------------------------------------------------------------- + +activation_funcs = { + 'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False), + 'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False), + 'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False), + 'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True), + 'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True), + 'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True), + 'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True), + 'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True), + 'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True), +} + +#---------------------------------------------------------------------------- + +_plugin = None +_null_tensor = torch.empty([0]) + +def _init(): + global _plugin + if _plugin is None: + _plugin = custom_ops.get_plugin( + module_name='bias_act_plugin', + sources=['bias_act.cpp', 'bias_act.cu'], + headers=['bias_act.h'], + source_dir=os.path.dirname(__file__), + extra_cuda_cflags=['--use_fast_math'], + ) + return True + +#---------------------------------------------------------------------------- + +def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'): + r"""Fused bias and activation function. + + Adds bias `b` to activation tensor `x`, evaluates activation function `act`, + and scales the result by `gain`. Each of the steps is optional. In most cases, + the fused op is considerably more efficient than performing the same calculation + using standard PyTorch ops. It supports first and second order gradients, + but not third order gradients. + + Args: + x: Input activation tensor. Can be of any shape. + b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type + as `x`. The shape must be known, and it must match the dimension of `x` + corresponding to `dim`. + dim: The dimension in `x` corresponding to the elements of `b`. + The value of `dim` is ignored if `b` is not specified. + act: Name of the activation function to evaluate, or `"linear"` to disable. + Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc. + See `activation_funcs` for a full list. `None` is not allowed. + alpha: Shape parameter for the activation function, or `None` to use the default. + gain: Scaling factor for the output tensor, or `None` to use default. + See `activation_funcs` for the default scaling of each activation function. + If unsure, consider specifying 1. + clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable + the clamping (default). + impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default). + + Returns: + Tensor of the same shape and datatype as `x`. + """ + assert isinstance(x, torch.Tensor) + assert impl in ['ref', 'cuda'] + if impl == 'cuda' and x.device.type == 'cuda' and _init(): + return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b) + return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp) + +#---------------------------------------------------------------------------- + +@misc.profiled_function +def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None): + """Slow reference implementation of `bias_act()` using standard TensorFlow ops. + """ + assert isinstance(x, torch.Tensor) + assert clamp is None or clamp >= 0 + spec = activation_funcs[act] + alpha = float(alpha if alpha is not None else spec.def_alpha) + gain = float(gain if gain is not None else spec.def_gain) + clamp = float(clamp if clamp is not None else -1) + + # Add bias. + if b is not None: + assert isinstance(b, torch.Tensor) and b.ndim == 1 + assert 0 <= dim < x.ndim + assert b.shape[0] == x.shape[dim] + x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)]) + + # Evaluate activation function. + alpha = float(alpha) + x = spec.func(x, alpha=alpha) + + # Scale by gain. + gain = float(gain) + if gain != 1: + x = x * gain + + # Clamp. + if clamp >= 0: + x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type + return x + +#---------------------------------------------------------------------------- + +_bias_act_cuda_cache = dict() + +def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None): + """Fast CUDA implementation of `bias_act()` using custom ops. + """ + # Parse arguments. + assert clamp is None or clamp >= 0 + spec = activation_funcs[act] + alpha = float(alpha if alpha is not None else spec.def_alpha) + gain = float(gain if gain is not None else spec.def_gain) + clamp = float(clamp if clamp is not None else -1) + + # Lookup from cache. + key = (dim, act, alpha, gain, clamp) + if key in _bias_act_cuda_cache: + return _bias_act_cuda_cache[key] + + # Forward op. + class BiasActCuda(torch.autograd.Function): + @staticmethod + def forward(ctx, x, b): # pylint: disable=arguments-differ + ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride(1) == 1 else torch.contiguous_format + x = x.contiguous(memory_format=ctx.memory_format) + b = b.contiguous() if b is not None else _null_tensor + y = x + if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor: + y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp) + ctx.save_for_backward( + x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, + b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, + y if 'y' in spec.ref else _null_tensor) + return y + + @staticmethod + def backward(ctx, dy): # pylint: disable=arguments-differ + dy = dy.contiguous(memory_format=ctx.memory_format) + x, b, y = ctx.saved_tensors + dx = None + db = None + + if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: + dx = dy + if act != 'linear' or gain != 1 or clamp >= 0: + dx = BiasActCudaGrad.apply(dy, x, b, y) + + if ctx.needs_input_grad[1]: + db = dx.sum([i for i in range(dx.ndim) if i != dim]) + + return dx, db + + # Backward op. + class BiasActCudaGrad(torch.autograd.Function): + @staticmethod + def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ + ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride(1) == 1 else torch.contiguous_format + dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp) + ctx.save_for_backward( + dy if spec.has_2nd_grad else _null_tensor, + x, b, y) + return dx + + @staticmethod + def backward(ctx, d_dx): # pylint: disable=arguments-differ + d_dx = d_dx.contiguous(memory_format=ctx.memory_format) + dy, x, b, y = ctx.saved_tensors + d_dy = None + d_x = None + d_b = None + d_y = None + + if ctx.needs_input_grad[0]: + d_dy = BiasActCudaGrad.apply(d_dx, x, b, y) + + if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]): + d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp) + + if spec.has_2nd_grad and ctx.needs_input_grad[2]: + d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim]) + + return d_dy, d_x, d_b, d_y + + # Add to cache. + _bias_act_cuda_cache[key] = BiasActCuda + return BiasActCuda + +#---------------------------------------------------------------------------- diff --git a/torch_utils/ops/conv2d_gradfix.py b/torch_utils/ops/conv2d_gradfix.py new file mode 100644 index 0000000000000000000000000000000000000000..8056b5dbac9117d64aadb23bf7a3a36388a6bb99 --- /dev/null +++ b/torch_utils/ops/conv2d_gradfix.py @@ -0,0 +1,198 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Custom replacement for `torch.nn.functional.conv2d` that supports +arbitrarily high order gradients with zero performance penalty.""" + +import contextlib +import torch + +# pylint: disable=redefined-builtin +# pylint: disable=arguments-differ +# pylint: disable=protected-access + +#---------------------------------------------------------------------------- + +enabled = False # Enable the custom op by setting this to true. +weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights. + +@contextlib.contextmanager +def no_weight_gradients(disable=True): + global weight_gradients_disabled + old = weight_gradients_disabled + if disable: + weight_gradients_disabled = True + yield + weight_gradients_disabled = old + +#---------------------------------------------------------------------------- + +def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): + if _should_use_custom_op(input): + return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias) + return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups) + +def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1): + if _should_use_custom_op(input): + return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias) + return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation) + +#---------------------------------------------------------------------------- + +def _should_use_custom_op(input): + assert isinstance(input, torch.Tensor) + if (not enabled) or (not torch.backends.cudnn.enabled): + return False + if input.device.type != 'cuda': + return False + return True + +def _tuple_of_ints(xs, ndim): + xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim + assert len(xs) == ndim + assert all(isinstance(x, int) for x in xs) + return xs + +#---------------------------------------------------------------------------- + +_conv2d_gradfix_cache = dict() +_null_tensor = torch.empty([0]) + +def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups): + # Parse arguments. + ndim = 2 + weight_shape = tuple(weight_shape) + stride = _tuple_of_ints(stride, ndim) + padding = _tuple_of_ints(padding, ndim) + output_padding = _tuple_of_ints(output_padding, ndim) + dilation = _tuple_of_ints(dilation, ndim) + + # Lookup from cache. + key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups) + if key in _conv2d_gradfix_cache: + return _conv2d_gradfix_cache[key] + + # Validate arguments. + assert groups >= 1 + assert len(weight_shape) == ndim + 2 + assert all(stride[i] >= 1 for i in range(ndim)) + assert all(padding[i] >= 0 for i in range(ndim)) + assert all(dilation[i] >= 0 for i in range(ndim)) + if not transpose: + assert all(output_padding[i] == 0 for i in range(ndim)) + else: # transpose + assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim)) + + # Helpers. + common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups) + def calc_output_padding(input_shape, output_shape): + if transpose: + return [0, 0] + return [ + input_shape[i + 2] + - (output_shape[i + 2] - 1) * stride[i] + - (1 - 2 * padding[i]) + - dilation[i] * (weight_shape[i + 2] - 1) + for i in range(ndim) + ] + + # Forward & backward. + class Conv2d(torch.autograd.Function): + @staticmethod + def forward(ctx, input, weight, bias): + assert weight.shape == weight_shape + ctx.save_for_backward( + input if weight.requires_grad else _null_tensor, + weight if input.requires_grad else _null_tensor, + ) + ctx.input_shape = input.shape + + # Simple 1x1 convolution => cuBLAS (only on Volta, not on Ampere). + if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0) and torch.cuda.get_device_capability(input.device) < (8, 0): + a = weight.reshape(groups, weight_shape[0] // groups, weight_shape[1]) + b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1) + c = (a.transpose(1, 2) if transpose else a) @ b.permute(1, 2, 0, 3).flatten(2) + c = c.reshape(-1, input.shape[0], *input.shape[2:]).transpose(0, 1) + c = c if bias is None else c + bias.unsqueeze(0).unsqueeze(2).unsqueeze(3) + return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format)) + + # General case => cuDNN. + if transpose: + return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs) + return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs) + + @staticmethod + def backward(ctx, grad_output): + input, weight = ctx.saved_tensors + input_shape = ctx.input_shape + grad_input = None + grad_weight = None + grad_bias = None + + if ctx.needs_input_grad[0]: + p = calc_output_padding(input_shape=input_shape, output_shape=grad_output.shape) + op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs) + grad_input = op.apply(grad_output, weight, None) + assert grad_input.shape == input_shape + + if ctx.needs_input_grad[1] and not weight_gradients_disabled: + grad_weight = Conv2dGradWeight.apply(grad_output, input) + assert grad_weight.shape == weight_shape + + if ctx.needs_input_grad[2]: + grad_bias = grad_output.sum([0, 2, 3]) + + return grad_input, grad_weight, grad_bias + + # Gradient with respect to the weights. + class Conv2dGradWeight(torch.autograd.Function): + @staticmethod + def forward(ctx, grad_output, input): + ctx.save_for_backward( + grad_output if input.requires_grad else _null_tensor, + input if grad_output.requires_grad else _null_tensor, + ) + ctx.grad_output_shape = grad_output.shape + ctx.input_shape = input.shape + + # Simple 1x1 convolution => cuBLAS (on both Volta and Ampere). + if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0): + a = grad_output.reshape(grad_output.shape[0], groups, grad_output.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2) + b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2) + c = (b @ a.transpose(1, 2) if transpose else a @ b.transpose(1, 2)).reshape(weight_shape) + return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format)) + + # General case => cuDNN. + name = 'aten::cudnn_convolution_transpose_backward_weight' if transpose else 'aten::cudnn_convolution_backward_weight' + flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32] + return torch._C._jit_get_operation(name)(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags) + + @staticmethod + def backward(ctx, grad2_grad_weight): + grad_output, input = ctx.saved_tensors + grad_output_shape = ctx.grad_output_shape + input_shape = ctx.input_shape + grad2_grad_output = None + grad2_input = None + + if ctx.needs_input_grad[0]: + grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None) + assert grad2_grad_output.shape == grad_output_shape + + if ctx.needs_input_grad[1]: + p = calc_output_padding(input_shape=input_shape, output_shape=grad_output_shape) + op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs) + grad2_input = op.apply(grad_output, grad2_grad_weight, None) + assert grad2_input.shape == input_shape + + return grad2_grad_output, grad2_input + + _conv2d_gradfix_cache[key] = Conv2d + return Conv2d + +#---------------------------------------------------------------------------- diff --git a/torch_utils/ops/conv2d_resample.py b/torch_utils/ops/conv2d_resample.py new file mode 100644 index 0000000000000000000000000000000000000000..a6d72402fa04b01c7983ed9b372ff5f3283717f3 --- /dev/null +++ b/torch_utils/ops/conv2d_resample.py @@ -0,0 +1,143 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""2D convolution with optional up/downsampling.""" + +import torch + +from .. import misc +from . import conv2d_gradfix +from . import upfirdn2d +from .upfirdn2d import _parse_padding +from .upfirdn2d import _get_filter_size + +#---------------------------------------------------------------------------- + +def _get_weight_shape(w): + with misc.suppress_tracer_warnings(): # this value will be treated as a constant + shape = [int(sz) for sz in w.shape] + misc.assert_shape(w, shape) + return shape + +#---------------------------------------------------------------------------- + +def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True): + """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations. + """ + _out_channels, _in_channels_per_group, kh, kw = _get_weight_shape(w) + + # Flip weight if requested. + # Note: conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False). + if not flip_weight and (kw > 1 or kh > 1): + w = w.flip([2, 3]) + + # Execute using conv2d_gradfix. + op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d + return op(x, w, stride=stride, padding=padding, groups=groups) + +#---------------------------------------------------------------------------- + +@misc.profiled_function +def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False): + r"""2D convolution with optional up/downsampling. + + Padding is performed only once at the beginning, not between the operations. + + Args: + x: Input tensor of shape + `[batch_size, in_channels, in_height, in_width]`. + w: Weight tensor of shape + `[out_channels, in_channels//groups, kernel_height, kernel_width]`. + f: Low-pass filter for up/downsampling. Must be prepared beforehand by + calling upfirdn2d.setup_filter(). None = identity (default). + up: Integer upsampling factor (default: 1). + down: Integer downsampling factor (default: 1). + padding: Padding with respect to the upsampled image. Can be a single number + or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + groups: Split input channels into N groups (default: 1). + flip_weight: False = convolution, True = correlation (default: True). + flip_filter: False = convolution, True = correlation (default: False). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. + """ + # Validate arguments. + assert isinstance(x, torch.Tensor) and (x.ndim == 4) + assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype) + assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32) + assert isinstance(up, int) and (up >= 1) + assert isinstance(down, int) and (down >= 1) + assert isinstance(groups, int) and (groups >= 1) + out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) + fw, fh = _get_filter_size(f) + px0, px1, py0, py1 = _parse_padding(padding) + + # Adjust padding to account for up/downsampling. + if up > 1: + px0 += (fw + up - 1) // 2 + px1 += (fw - up) // 2 + py0 += (fh + up - 1) // 2 + py1 += (fh - up) // 2 + if down > 1: + px0 += (fw - down + 1) // 2 + px1 += (fw - down) // 2 + py0 += (fh - down + 1) // 2 + py1 += (fh - down) // 2 + + # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve. + if kw == 1 and kh == 1 and (down > 1 and up == 1): + x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter) + x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) + return x + + # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample. + if kw == 1 and kh == 1 and (up > 1 and down == 1): + x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) + x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) + return x + + # Fast path: downsampling only => use strided convolution. + if down > 1 and up == 1: + x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter) + x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight) + return x + + # Fast path: upsampling with optional downsampling => use transpose strided convolution. + if up > 1: + if groups == 1: + w = w.transpose(0, 1) + else: + w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw) + w = w.transpose(1, 2) + w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw) + px0 -= kw - 1 + px1 -= kw - up + py0 -= kh - 1 + py1 -= kh - up + pxt = max(min(-px0, -px1), 0) + pyt = max(min(-py0, -py1), 0) + x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight)) + x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter) + if down > 1: + x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) + return x + + # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d. + if up == 1 and down == 1: + if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0: + return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight) + + # Fallback: Generic reference implementation. + x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) + x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) + if down > 1: + x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) + return x + +#---------------------------------------------------------------------------- diff --git a/torch_utils/ops/filtered_lrelu.cpp b/torch_utils/ops/filtered_lrelu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6a99e11ded7eff66f356c0f15d8ed865988d4358 --- /dev/null +++ b/torch_utils/ops/filtered_lrelu.cpp @@ -0,0 +1,300 @@ +// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include +#include +#include +#include "filtered_lrelu.h" + +//------------------------------------------------------------------------ + +static std::tuple filtered_lrelu( + torch::Tensor x, torch::Tensor fu, torch::Tensor fd, torch::Tensor b, torch::Tensor si, + int up, int down, int px0, int px1, int py0, int py1, int sx, int sy, float gain, float slope, float clamp, bool flip_filters, bool writeSigns) +{ + // Set CUDA device. + TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); + const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); + + // Validate arguments. + TORCH_CHECK(fu.device() == x.device() && fd.device() == x.device() && b.device() == x.device(), "all input tensors must reside on the same device"); + TORCH_CHECK(fu.dtype() == torch::kFloat && fd.dtype() == torch::kFloat, "fu and fd must be float32"); + TORCH_CHECK(b.dtype() == x.dtype(), "x and b must have the same dtype"); + TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat, "x and b must be float16 or float32"); + TORCH_CHECK(x.dim() == 4, "x must be rank 4"); + TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX && x.size(3) <= INT_MAX, "x is too large"); + TORCH_CHECK(x.numel() > 0, "x is empty"); + TORCH_CHECK((fu.dim() == 1 || fu.dim() == 2) && (fd.dim() == 1 || fd.dim() == 2), "fu and fd must be rank 1 or 2"); + TORCH_CHECK(fu.size(0) <= INT_MAX && fu.size(-1) <= INT_MAX, "fu is too large"); + TORCH_CHECK(fd.size(0) <= INT_MAX && fd.size(-1) <= INT_MAX, "fd is too large"); + TORCH_CHECK(fu.numel() > 0, "fu is empty"); + TORCH_CHECK(fd.numel() > 0, "fd is empty"); + TORCH_CHECK(b.dim() == 1 && b.size(0) == x.size(1), "b must be a vector with the same number of channels as x"); + TORCH_CHECK(up >= 1 && down >= 1, "up and down must be at least 1"); + + // Figure out how much shared memory is available on the device. + int maxSharedBytes = 0; + AT_CUDA_CHECK(cudaDeviceGetAttribute(&maxSharedBytes, cudaDevAttrMaxSharedMemoryPerBlockOptin, x.device().index())); + int sharedKB = maxSharedBytes >> 10; + + // Populate enough launch parameters to check if a CUDA kernel exists. + filtered_lrelu_kernel_params p; + p.up = up; + p.down = down; + p.fuShape = make_int2((int)fu.size(-1), fu.dim() == 2 ? (int)fu.size(0) : 0); // shape [n, 0] indicates separable filter. + p.fdShape = make_int2((int)fd.size(-1), fd.dim() == 2 ? (int)fd.size(0) : 0); + filtered_lrelu_kernel_spec test_spec = choose_filtered_lrelu_kernel(p, sharedKB); + if (!test_spec.exec) + { + // No kernel found - return empty tensors and indicate missing kernel with return code of -1. + return std::make_tuple(torch::Tensor(), torch::Tensor(), -1); + } + + // Input/output element size. + int64_t sz = (x.dtype() == torch::kHalf) ? 2 : 4; + + // Input sizes. + int64_t xw = (int)x.size(3); + int64_t xh = (int)x.size(2); + int64_t fut_w = (int)fu.size(-1) - 1; + int64_t fut_h = (int)fu.size(0) - 1; + int64_t fdt_w = (int)fd.size(-1) - 1; + int64_t fdt_h = (int)fd.size(0) - 1; + + // Logical size of upsampled buffer. + int64_t cw = xw * up + (px0 + px1) - fut_w; + int64_t ch = xh * up + (py0 + py1) - fut_h; + TORCH_CHECK(cw > fdt_w && ch > fdt_h, "upsampled buffer must be at least the size of downsampling filter"); + TORCH_CHECK(cw <= INT_MAX && ch <= INT_MAX, "upsampled buffer is too large"); + + // Compute output size and allocate. + int64_t yw = (cw - fdt_w + (down - 1)) / down; + int64_t yh = (ch - fdt_h + (down - 1)) / down; + TORCH_CHECK(yw > 0 && yh > 0, "output must be at least 1x1"); + TORCH_CHECK(yw <= INT_MAX && yh <= INT_MAX, "output is too large"); + torch::Tensor y = torch::empty({x.size(0), x.size(1), yh, yw}, x.options(), x.suggest_memory_format()); + + // Allocate sign tensor. + torch::Tensor so; + torch::Tensor s = si; + bool readSigns = !!s.numel(); + int64_t sw_active = 0; // Active width of sign tensor. + if (writeSigns) + { + sw_active = yw * down - (down - 1) + fdt_w; // Active width in elements. + int64_t sh = yh * down - (down - 1) + fdt_h; // Height = active height. + int64_t sw = (sw_active + 15) & ~15; // Width = active width in elements, rounded up to multiple of 16. + TORCH_CHECK(sh <= INT_MAX && (sw >> 2) <= INT_MAX, "signs is too large"); + s = so = torch::empty({x.size(0), x.size(1), sh, sw >> 2}, x.options().dtype(torch::kUInt8), at::MemoryFormat::Contiguous); + } + else if (readSigns) + sw_active = s.size(3) << 2; + + // Validate sign tensor if in use. + if (readSigns || writeSigns) + { + TORCH_CHECK(s.is_contiguous(), "signs must be contiguous"); + TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8"); + TORCH_CHECK(s.device() == x.device(), "signs must reside on the same device as x"); + TORCH_CHECK(s.dim() == 4, "signs must be rank 4"); + TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1), "signs must have same batch & channels as x"); + TORCH_CHECK(s.size(2) <= INT_MAX && s.size(3) <= INT_MAX, "signs is too large"); + } + + // Populate rest of CUDA kernel parameters. + p.x = x.data_ptr(); + p.y = y.data_ptr(); + p.b = b.data_ptr(); + p.s = (readSigns || writeSigns) ? s.data_ptr() : 0; + p.fu = fu.data_ptr(); + p.fd = fd.data_ptr(); + p.pad0 = make_int2(px0, py0); + p.gain = gain; + p.slope = slope; + p.clamp = clamp; + p.flip = (flip_filters) ? 1 : 0; + p.xShape = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0)); + p.yShape = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0)); + p.sShape = (readSigns || writeSigns) ? make_int2((int)s.size(3), (int)s.size(2)) : make_int2(0, 0); // Width is in bytes. Contiguous. + p.sOfs = make_int2(sx, sy); + p.swLimit = (sw_active + 3) >> 2; // Rounded up to bytes. + + // x, y, b strides are in bytes. + p.xStride = make_longlong4(sz * x.stride(3), sz * x.stride(2), sz * x.stride(1), sz * x.stride(0)); + p.yStride = make_longlong4(sz * y.stride(3), sz * y.stride(2), sz * y.stride(1), sz * y.stride(0)); + p.bStride = sz * b.stride(0); + + // fu, fd strides are in elements. + p.fuStride = make_longlong3(fu.stride(-1), fu.dim() == 2 ? fu.stride(0) : 0, 0); + p.fdStride = make_longlong3(fd.stride(-1), fd.dim() == 2 ? fd.stride(0) : 0, 0); + + // Determine if indices don't fit in int32. Support negative strides although Torch currently never produces those. + bool index64b = false; + if (std::abs(p.bStride * x.size(1)) > INT_MAX) index64b = true; + if (std::min(x.size(0) * p.xStride.w, 0ll) + std::min(x.size(1) * p.xStride.z, 0ll) + std::min(x.size(2) * p.xStride.y, 0ll) + std::min(x.size(3) * p.xStride.x, 0ll) < -INT_MAX) index64b = true; + if (std::max(x.size(0) * p.xStride.w, 0ll) + std::max(x.size(1) * p.xStride.z, 0ll) + std::max(x.size(2) * p.xStride.y, 0ll) + std::max(x.size(3) * p.xStride.x, 0ll) > INT_MAX) index64b = true; + if (std::min(y.size(0) * p.yStride.w, 0ll) + std::min(y.size(1) * p.yStride.z, 0ll) + std::min(y.size(2) * p.yStride.y, 0ll) + std::min(y.size(3) * p.yStride.x, 0ll) < -INT_MAX) index64b = true; + if (std::max(y.size(0) * p.yStride.w, 0ll) + std::max(y.size(1) * p.yStride.z, 0ll) + std::max(y.size(2) * p.yStride.y, 0ll) + std::max(y.size(3) * p.yStride.x, 0ll) > INT_MAX) index64b = true; + if (s.numel() > INT_MAX) index64b = true; + + // Choose CUDA kernel. + filtered_lrelu_kernel_spec spec = { 0 }; + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "filtered_lrelu_cuda", [&] + { + if constexpr (sizeof(scalar_t) <= 4) // Exclude doubles. constexpr prevents template instantiation. + { + // Choose kernel based on index type, datatype and sign read/write modes. + if (!index64b && writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel(p, sharedKB); + else if (!index64b && !writeSigns && readSigns) spec = choose_filtered_lrelu_kernel(p, sharedKB); + else if (!index64b && !writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel(p, sharedKB); + else if ( index64b && writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel(p, sharedKB); + else if ( index64b && !writeSigns && readSigns) spec = choose_filtered_lrelu_kernel(p, sharedKB); + else if ( index64b && !writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel(p, sharedKB); + } + }); + TORCH_CHECK(spec.exec, "internal error - CUDA kernel not found") // This should not happen because we tested earlier that kernel exists. + + // Launch CUDA kernel. + void* args[] = {&p}; + int bx = spec.numWarps * 32; + int gx = (p.yShape.x - 1) / spec.tileOut.x + 1; + int gy = (p.yShape.y - 1) / spec.tileOut.y + 1; + int gz = p.yShape.z * p.yShape.w; + + // Repeat multiple horizontal tiles in a CTA? + if (spec.xrep) + { + p.tilesXrep = spec.xrep; + p.tilesXdim = gx; + + gx = (gx + p.tilesXrep - 1) / p.tilesXrep; + std::swap(gx, gy); + } + else + { + p.tilesXrep = 0; + p.tilesXdim = 0; + } + + // Launch filter setup kernel. + AT_CUDA_CHECK(cudaLaunchKernel(spec.setup, 1, 1024, args, 0, at::cuda::getCurrentCUDAStream())); + + // Copy kernels to constant memory. + if ( writeSigns && !readSigns) AT_CUDA_CHECK((copy_filters(at::cuda::getCurrentCUDAStream()))); + else if (!writeSigns && readSigns) AT_CUDA_CHECK((copy_filters(at::cuda::getCurrentCUDAStream()))); + else if (!writeSigns && !readSigns) AT_CUDA_CHECK((copy_filters(at::cuda::getCurrentCUDAStream()))); + + // Set cache and shared memory configurations for main kernel. + AT_CUDA_CHECK(cudaFuncSetCacheConfig(spec.exec, cudaFuncCachePreferShared)); + if (spec.dynamicSharedKB) // Need dynamically allocated shared memory? + AT_CUDA_CHECK(cudaFuncSetAttribute(spec.exec, cudaFuncAttributeMaxDynamicSharedMemorySize, spec.dynamicSharedKB << 10)); + AT_CUDA_CHECK(cudaFuncSetSharedMemConfig(spec.exec, cudaSharedMemBankSizeFourByte)); + + // Launch main kernel. + const int maxSubGz = 65535; // CUDA maximum for block z dimension. + for (int zofs=0; zofs < gz; zofs += maxSubGz) // Do multiple launches if gz is too big. + { + p.blockZofs = zofs; + int subGz = std::min(maxSubGz, gz - zofs); + AT_CUDA_CHECK(cudaLaunchKernel(spec.exec, dim3(gx, gy, subGz), bx, args, spec.dynamicSharedKB << 10, at::cuda::getCurrentCUDAStream())); + } + + // Done. + return std::make_tuple(y, so, 0); +} + +//------------------------------------------------------------------------ + +static torch::Tensor filtered_lrelu_act(torch::Tensor x, torch::Tensor si, int sx, int sy, float gain, float slope, float clamp, bool writeSigns) +{ + // Set CUDA device. + TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); + const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); + + // Validate arguments. + TORCH_CHECK(x.dim() == 4, "x must be rank 4"); + TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX && x.size(3) <= INT_MAX, "x is too large"); + TORCH_CHECK(x.numel() > 0, "x is empty"); + TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat || x.dtype() == torch::kDouble, "x must be float16, float32 or float64"); + + // Output signs if we don't have sign input. + torch::Tensor so; + torch::Tensor s = si; + bool readSigns = !!s.numel(); + if (writeSigns) + { + int64_t sw = x.size(3); + sw = (sw + 15) & ~15; // Round to a multiple of 16 for coalescing. + s = so = torch::empty({x.size(0), x.size(1), x.size(2), sw >> 2}, x.options().dtype(torch::kUInt8), at::MemoryFormat::Contiguous); + } + + // Validate sign tensor if in use. + if (readSigns || writeSigns) + { + TORCH_CHECK(s.is_contiguous(), "signs must be contiguous"); + TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8"); + TORCH_CHECK(s.device() == x.device(), "signs must reside on the same device as x"); + TORCH_CHECK(s.dim() == 4, "signs must be rank 4"); + TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1), "signs must have same batch & channels as x"); + TORCH_CHECK(s.size(2) <= INT_MAX && (s.size(3) << 2) <= INT_MAX, "signs tensor is too large"); + } + + // Initialize CUDA kernel parameters. + filtered_lrelu_act_kernel_params p; + p.x = x.data_ptr(); + p.s = (readSigns || writeSigns) ? s.data_ptr() : 0; + p.gain = gain; + p.slope = slope; + p.clamp = clamp; + p.xShape = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0)); + p.xStride = make_longlong4(x.stride(3), x.stride(2), x.stride(1), x.stride(0)); + p.sShape = (readSigns || writeSigns) ? make_int2((int)s.size(3) << 2, (int)s.size(2)) : make_int2(0, 0); // Width is in elements. Contiguous. + p.sOfs = make_int2(sx, sy); + + // Choose CUDA kernel. + void* func = 0; + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "filtered_lrelu_act_cuda", [&] + { + if (writeSigns) + func = choose_filtered_lrelu_act_kernel(); + else if (readSigns) + func = choose_filtered_lrelu_act_kernel(); + else + func = choose_filtered_lrelu_act_kernel(); + }); + TORCH_CHECK(func, "internal error - CUDA kernel not found"); + + // Launch CUDA kernel. + void* args[] = {&p}; + int bx = 128; // 4 warps per block. + + // Logical size of launch = writeSigns ? p.s : p.x + uint32_t gx = writeSigns ? p.sShape.x : p.xShape.x; + uint32_t gy = writeSigns ? p.sShape.y : p.xShape.y; + uint32_t gz = p.xShape.z * p.xShape.w; // Same as in p.sShape if signs are in use. + gx = (gx - 1) / bx + 1; + + // Make sure grid y and z dimensions are within CUDA launch limits. Kernel loops internally to do the rest. + const uint32_t gmax = 65535; + gy = std::min(gy, gmax); + gz = std::min(gz, gmax); + + // Launch. + AT_CUDA_CHECK(cudaLaunchKernel(func, dim3(gx, gy, gz), bx, args, 0, at::cuda::getCurrentCUDAStream())); + return so; +} + +//------------------------------------------------------------------------ + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("filtered_lrelu", &filtered_lrelu); // The whole thing. + m.def("filtered_lrelu_act_", &filtered_lrelu_act); // Activation and sign tensor handling only. Modifies data tensor in-place. +} + +//------------------------------------------------------------------------ diff --git a/torch_utils/ops/filtered_lrelu.cu b/torch_utils/ops/filtered_lrelu.cu new file mode 100644 index 0000000000000000000000000000000000000000..d76f1e515a6cdcdcae31424b56473fe9f5f955d8 --- /dev/null +++ b/torch_utils/ops/filtered_lrelu.cu @@ -0,0 +1,1284 @@ +// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include +#include "filtered_lrelu.h" +#include + +//------------------------------------------------------------------------ +// Helpers. + +enum // Filter modes. +{ + MODE_SUSD = 0, // Separable upsampling, separable downsampling. + MODE_FUSD = 1, // Full upsampling, separable downsampling. + MODE_SUFD = 2, // Separable upsampling, full downsampling. + MODE_FUFD = 3, // Full upsampling, full downsampling. +}; + +template struct InternalType; +template <> struct InternalType +{ + typedef double scalar_t; typedef double2 vec2_t; typedef double4 vec4_t; + __device__ __forceinline__ static vec2_t zero_vec2(void) { return make_double2(0, 0); } + __device__ __forceinline__ static vec4_t zero_vec4(void) { return make_double4(0, 0, 0, 0); } + __device__ __forceinline__ static double clamp(double x, double c) { return fmin(fmax(x, -c), c); } +}; +template <> struct InternalType +{ + typedef float scalar_t; typedef float2 vec2_t; typedef float4 vec4_t; + __device__ __forceinline__ static vec2_t zero_vec2(void) { return make_float2(0, 0); } + __device__ __forceinline__ static vec4_t zero_vec4(void) { return make_float4(0, 0, 0, 0); } + __device__ __forceinline__ static float clamp(float x, float c) { return fminf(fmaxf(x, -c), c); } +}; +template <> struct InternalType +{ + typedef float scalar_t; typedef float2 vec2_t; typedef float4 vec4_t; + __device__ __forceinline__ static vec2_t zero_vec2(void) { return make_float2(0, 0); } + __device__ __forceinline__ static vec4_t zero_vec4(void) { return make_float4(0, 0, 0, 0); } + __device__ __forceinline__ static float clamp(float x, float c) { return fminf(fmaxf(x, -c), c); } +}; + +#define MIN(A, B) ((A) < (B) ? (A) : (B)) +#define MAX(A, B) ((A) > (B) ? (A) : (B)) +#define CEIL_DIV(A, B) (((B)==1) ? (A) : \ + ((B)==2) ? ((int)((A)+1) >> 1) : \ + ((B)==4) ? ((int)((A)+3) >> 2) : \ + (((A) + ((A) > 0 ? (B) - 1 : 0)) / (B))) + +// This works only up to blocks of size 256 x 256 and for all N that are powers of two. +template __device__ __forceinline__ void fast_div_mod(int& x, int& y, unsigned int i) +{ + if ((N & (N-1)) && N <= 256) + y = (i * ((1<<24)/N + 1)) >> 24; // Assumes N <= 256, i < N*256. + else + y = i/N; + + x = i - y*N; +} + +// Type cast stride before reading it. +template __device__ __forceinline__ T get_stride(const int64_t& x) +{ + return *reinterpret_cast(&x); +} + +//------------------------------------------------------------------------ +// Filters, setup kernel, copying function. + +#define MAX_FILTER_SIZE 32 + +// Combined up/down filter buffers so that transfer can be done with one copy. +__device__ float g_fbuf[2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE]; // Filters in global memory, written by setup kernel. +__device__ __constant__ float c_fbuf[2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE]; // Filters in constant memory, read by main kernel. + +// Accessors to combined buffers to index up/down filters individually. +#define c_fu (c_fbuf) +#define c_fd (c_fbuf + MAX_FILTER_SIZE * MAX_FILTER_SIZE) +#define g_fu (g_fbuf) +#define g_fd (g_fbuf + MAX_FILTER_SIZE * MAX_FILTER_SIZE) + +// Set up filters into global memory buffer. +static __global__ void setup_filters_kernel(filtered_lrelu_kernel_params p) +{ + for (int idx = threadIdx.x; idx < MAX_FILTER_SIZE * MAX_FILTER_SIZE; idx += blockDim.x) + { + int x, y; + fast_div_mod(x, y, idx); + + int fu_x = p.flip ? x : (p.fuShape.x - 1 - x); + int fu_y = p.flip ? y : (p.fuShape.y - 1 - y); + if (p.fuShape.y > 0) + g_fu[idx] = (x >= p.fuShape.x || y >= p.fuShape.y) ? 0.0f : p.fu[fu_x * p.fuStride.x + fu_y * p.fuStride.y]; + else + g_fu[idx] = (x >= p.fuShape.x || y > 0) ? 0.0f : p.fu[fu_x * p.fuStride.x]; + + int fd_x = p.flip ? x : (p.fdShape.x - 1 - x); + int fd_y = p.flip ? y : (p.fdShape.y - 1 - y); + if (p.fdShape.y > 0) + g_fd[idx] = (x >= p.fdShape.x || y >= p.fdShape.y) ? 0.0f : p.fd[fd_x * p.fdStride.x + fd_y * p.fdStride.y]; + else + g_fd[idx] = (x >= p.fdShape.x || y > 0) ? 0.0f : p.fd[fd_x * p.fdStride.x]; + } +} + +// Host function to copy filters written by setup kernel into constant buffer for main kernel. +template static cudaError_t copy_filters(cudaStream_t stream) +{ + void* src = 0; + cudaError_t err = cudaGetSymbolAddress(&src, g_fbuf); + if (err) return err; + return cudaMemcpyToSymbolAsync(c_fbuf, src, 2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE * sizeof(float), 0, cudaMemcpyDeviceToDevice, stream); +} + +//------------------------------------------------------------------------ +// Coordinate spaces: +// - Relative to input tensor: inX, inY, tileInX, tileInY +// - Relative to input tile: relInX, relInY, tileInW, tileInH +// - Relative to upsampled tile: relUpX, relUpY, tileUpW, tileUpH +// - Relative to output tile: relOutX, relOutY, tileOutW, tileOutH +// - Relative to output tensor: outX, outY, tileOutX, tileOutY +// +// Relationships between coordinate spaces: +// - inX = tileInX + relInX +// - inY = tileInY + relInY +// - relUpX = relInX * up + phaseInX +// - relUpY = relInY * up + phaseInY +// - relUpX = relOutX * down +// - relUpY = relOutY * down +// - outX = tileOutX + relOutX +// - outY = tileOutY + relOutY + +extern __shared__ char s_buf_raw[]; // When sharedKB <= 48, allocate shared memory statically inside the kernel, otherwise use the externally allocated shared memory buffer. + +template +static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p) +{ + // Check that we don't try to support non-existing filter modes. + static_assert(up == 1 || up == 2 || up == 4, "only up=1, up=2, up=4 scales supported"); + static_assert(down == 1 || down == 2 || down == 4, "only down=1, down=2, down=4 scales supported"); + static_assert(fuSize >= up, "upsampling filter size must be at least upsampling factor"); + static_assert(fdSize >= down, "downsampling filter size must be at least downsampling factor"); + static_assert(fuSize % up == 0, "upsampling filter size must be divisible with upsampling factor"); + static_assert(fdSize % down == 0, "downsampling filter size must be divisible with downsampling factor"); + static_assert(fuSize <= MAX_FILTER_SIZE && fdSize <= MAX_FILTER_SIZE, "filter size greater than MAX_FILTER_SIZE"); + static_assert(up != 1 || (fuSize == 1 && (filterMode == MODE_FUFD || filterMode == MODE_FUSD)), "up=1 supported only for 1x1 full filters"); + static_assert(down != 1 || (fdSize == 1 && (filterMode == MODE_FUFD || filterMode == MODE_SUFD)), "down=1 supported only for 1x1 full filters"); + static_assert(!(up == 4 && (filterMode == MODE_FUFD || filterMode == MODE_FUSD)), "full filters not supported for up=4"); + static_assert(!(down == 4 && (filterMode == MODE_FUFD || filterMode == MODE_SUFD)), "full filters not supported for down=4"); + + // Static definitions. + typedef typename InternalType::scalar_t scalar_t; + typedef typename InternalType::vec2_t vec2_t; + typedef typename InternalType::vec4_t vec4_t; + const int tileUpW = (tileOutW * down + (fdSize - 1) - (down - 1) + 3) & ~3; // Upsampled tile width, rounded up to multiple of 4. + const int tileUpH = tileOutH * down + (fdSize - 1) - (down - 1); // Upsampled tile height. + const int tileInW = CEIL_DIV(tileUpW + (fuSize - 1), up); // Input tile width. + const int tileInH = CEIL_DIV(tileUpH + (fuSize - 1), up); // Input tile height. + const int tileUpH_up = CEIL_DIV(tileUpH, up) * up; // Upsampled tile height rounded up to a multiple of up. + const int tileInH_up = CEIL_DIV(tileUpH_up + (fuSize - 1), up); // For allocations only, to avoid shared memory read overruns with up=2 and up=4. + + // Merge 1x1 downsampling into last upsampling step for upf1 and ups2. + const bool downInline = (down == 1) && ((up == 1 && filterMode == MODE_FUFD) || (up == 2 && filterMode == MODE_SUFD)); + + // Sizes of logical buffers. + const int szIn = tileInH_up * tileInW; + const int szUpX = tileInH_up * tileUpW; + const int szUpXY = downInline ? 0 : (tileUpH * tileUpW); + const int szDownX = tileUpH * tileOutW; + + // Sizes for shared memory arrays. + const int s_buf0_size_base = + (filterMode == MODE_SUSD) ? MAX(szIn, szUpXY) : + (filterMode == MODE_FUSD) ? MAX(szIn, szDownX) : + (filterMode == MODE_SUFD) ? MAX(szIn, szUpXY) : + (filterMode == MODE_FUFD) ? szIn : + -1; + const int s_buf1_size_base = + (filterMode == MODE_SUSD) ? MAX(szUpX, szDownX) : + (filterMode == MODE_FUSD) ? szUpXY : + (filterMode == MODE_SUFD) ? szUpX : + (filterMode == MODE_FUFD) ? szUpXY : + -1; + + // Ensure U128 alignment. + const int s_buf0_size = (s_buf0_size_base + 3) & ~3; + const int s_buf1_size = (s_buf1_size_base + 3) & ~3; + + // Check at compile time that we don't use too much shared memory. + static_assert((s_buf0_size + s_buf1_size) * sizeof(scalar_t) <= (sharedKB << 10), "shared memory overflow"); + + // Declare shared memory arrays. + scalar_t* s_buf0; + scalar_t* s_buf1; + if (sharedKB <= 48) + { + // Allocate shared memory arrays here. + __shared__ scalar_t s_buf0_st[(sharedKB > 48) ? (1<<24) : (s_buf0_size + s_buf1_size)]; // Prevent launching if this isn't optimized away when unused. + s_buf0 = s_buf0_st; + s_buf1 = s_buf0 + s_buf0_size; + } + else + { + // Use the dynamically allocated shared memory array. + s_buf0 = (scalar_t*)s_buf_raw; + s_buf1 = s_buf0 + s_buf0_size; + } + + // Pointers to the buffers. + scalar_t* s_tileIn; // Input tile: [relInX * tileInH + relInY] + scalar_t* s_tileUpX; // After horizontal upsampling: [relInY * tileUpW + relUpX] + scalar_t* s_tileUpXY; // After upsampling: [relUpY * tileUpW + relUpX] + scalar_t* s_tileDownX; // After horizontal downsampling: [relUpY * tileOutW + relOutX] + if (filterMode == MODE_SUSD) + { + s_tileIn = s_buf0; + s_tileUpX = s_buf1; + s_tileUpXY = s_buf0; + s_tileDownX = s_buf1; + } + else if (filterMode == MODE_FUSD) + { + s_tileIn = s_buf0; + s_tileUpXY = s_buf1; + s_tileDownX = s_buf0; + } + else if (filterMode == MODE_SUFD) + { + s_tileIn = s_buf0; + s_tileUpX = s_buf1; + s_tileUpXY = s_buf0; + } + else if (filterMode == MODE_FUFD) + { + s_tileIn = s_buf0; + s_tileUpXY = s_buf1; + } + + // Allow large grids in z direction via per-launch offset. + int channelIdx = blockIdx.z + p.blockZofs; + int batchIdx = channelIdx / p.yShape.z; + channelIdx -= batchIdx * p.yShape.z; + + // Offset to output feature map. In bytes. + index_t mapOfsOut = channelIdx * get_stride(p.yStride.z) + batchIdx * get_stride(p.yStride.w); + + // Sign shift amount. + uint32_t signXo = ((threadIdx.x + p.sOfs.x) << 1) & 6; + + // Inner tile loop. + #pragma unroll 1 + for (int tileIdx = 0; !enableXrep || (tileIdx < MIN(p.tilesXrep, p.tilesXdim - p.tilesXrep * blockIdx.y)); tileIdx++) + { + // Locate output tile. + int tileX = enableXrep ? blockIdx.y * p.tilesXrep + tileIdx : blockIdx.x; + int tileOutX = tileX * tileOutW; + int tileOutY = (enableXrep ? blockIdx.x : blockIdx.y) * tileOutH; + + // Locate input tile. + int tmpX = tileOutX * down - p.pad0.x; + int tmpY = tileOutY * down - p.pad0.y; + int tileInX = CEIL_DIV(tmpX, up); + int tileInY = CEIL_DIV(tmpY, up); + const int phaseInX = tileInX * up - tmpX; + const int phaseInY = tileInY * up - tmpY; + + // Extra sync if input and output buffers are the same and we are not on first tile. + if (enableXrep && tileIdx > 0 && (filterMode == MODE_FUSD || (filterMode == MODE_SUFD && !downInline) || (filterMode == MODE_FUFD && downInline))) + __syncthreads(); + + // Load input tile & apply bias. Unrolled. + scalar_t b = (scalar_t)*(const T*)((const char*)p.b + (channelIdx * get_stride(p.bStride))); + index_t mapOfsIn = channelIdx * get_stride(p.xStride.z) + batchIdx * get_stride(p.xStride.w); + int idx = threadIdx.x; + const int loopCountIN = CEIL_DIV(tileInW * tileInH, threadsPerBlock); + #pragma unroll + for (int loop = 0; loop < loopCountIN; loop++) + { + int relInX, relInY; + fast_div_mod(relInX, relInY, idx); + int inX = tileInX + relInX; + int inY = tileInY + relInY; + scalar_t v = 0; + + if ((uint32_t)inX < p.xShape.x && (uint32_t)inY < p.xShape.y) + v = (scalar_t)*((const T*)((const char*)p.x + (inX * get_stride(p.xStride.x) + inY * get_stride(p.xStride.y) + mapOfsIn))) + b; + + bool skip = (loop == loopCountIN-1) && (idx >= tileInW * tileInH); + if (!skip) + s_tileIn[idx] = v; + + idx += threadsPerBlock; + } + + if (filterMode == MODE_SUSD || filterMode == MODE_SUFD) // Separable upsampling filter. + { + // Horizontal upsampling. + __syncthreads(); + if (up == 4) + { + for (int idx = threadIdx.x*up; idx < tileUpW * tileInH; idx += blockDim.x*up) + { + int relUpX0, relInY; + fast_div_mod(relUpX0, relInY, idx); + int relInX0 = relUpX0 / up; + int src0 = relInX0 + tileInW * relInY; + int dst = relInY * tileUpW + relUpX0; + vec4_t v = InternalType::zero_vec4(); + scalar_t a = s_tileIn[src0]; + if (phaseInX == 0) + { + #pragma unroll + for (int step = 0; step < fuSize / up; step++) + { + v.x += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileIn[src0 + step + 1]; + v.y += a * (scalar_t)c_fu[step * up + 3]; + v.z += a * (scalar_t)c_fu[step * up + 2]; + v.w += a * (scalar_t)c_fu[step * up + 1]; + } + } + else if (phaseInX == 1) + { + #pragma unroll + for (int step = 0; step < fuSize / up; step++) + { + v.x += a * (scalar_t)c_fu[step * up + 1]; + v.y += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileIn[src0 + step + 1]; + v.z += a * (scalar_t)c_fu[step * up + 3]; + v.w += a * (scalar_t)c_fu[step * up + 2]; + } + } + else if (phaseInX == 2) + { + #pragma unroll + for (int step = 0; step < fuSize / up; step++) + { + v.x += a * (scalar_t)c_fu[step * up + 2]; + v.y += a * (scalar_t)c_fu[step * up + 1]; + v.z += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileIn[src0 + step + 1]; + v.w += a * (scalar_t)c_fu[step * up + 3]; + } + } + else // (phaseInX == 3) + { + #pragma unroll + for (int step = 0; step < fuSize / up; step++) + { + v.x += a * (scalar_t)c_fu[step * up + 3]; + v.y += a * (scalar_t)c_fu[step * up + 2]; + v.z += a * (scalar_t)c_fu[step * up + 1]; + v.w += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileIn[src0 + step + 1]; + } + } + s_tileUpX[dst+0] = v.x; + s_tileUpX[dst+1] = v.y; + s_tileUpX[dst+2] = v.z; + s_tileUpX[dst+3] = v.w; + } + } + else if (up == 2) + { + bool p0 = (phaseInX == 0); + for (int idx = threadIdx.x*up; idx < tileUpW * tileInH; idx += blockDim.x*up) + { + int relUpX0, relInY; + fast_div_mod(relUpX0, relInY, idx); + int relInX0 = relUpX0 / up; + int src0 = relInX0 + tileInW * relInY; + int dst = relInY * tileUpW + relUpX0; + vec2_t v = InternalType::zero_vec2(); + scalar_t a = s_tileIn[src0]; + if (p0) // (phaseInX == 0) + { + #pragma unroll + for (int step = 0; step < fuSize / up; step++) + { + v.x += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileIn[src0 + step + 1]; + v.y += a * (scalar_t)c_fu[step * up + 1]; + } + } + else // (phaseInX == 1) + { + #pragma unroll + for (int step = 0; step < fuSize / up; step++) + { + v.x += a * (scalar_t)c_fu[step * up + 1]; + v.y += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileIn[src0 + step + 1]; + } + } + s_tileUpX[dst+0] = v.x; + s_tileUpX[dst+1] = v.y; + } + } + + // Vertical upsampling & nonlinearity. + + __syncthreads(); + int groupMask = 15 << ((threadIdx.x & 31) & ~3); + int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH : 0; // Skip already written signs. + int sShapeMaxY = MIN(p.sShape.y, tileOutY * down + tileUpH); // Avoid out-of-tile sign writes. + if (up == 4) + { + minY -= 3; // Adjust according to block height. + for (int idx = threadIdx.x; idx < tileUpW * tileUpH_up / up; idx += blockDim.x) + { + int relUpX, relInY0; + fast_div_mod(relUpX, relInY0, idx); + int relUpY0 = relInY0 * up; + int src0 = relInY0 * tileUpW + relUpX; + int dst = relUpY0 * tileUpW + relUpX; + vec4_t v = InternalType::zero_vec4(); + + scalar_t a = s_tileUpX[src0]; + if (phaseInY == 0) + { + #pragma unroll + for (int step = 0; step < fuSize / up; step++) + { + v.x += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileUpX[src0 + (step + 1) * tileUpW]; + v.y += a * (scalar_t)c_fu[step * up + 3]; + v.z += a * (scalar_t)c_fu[step * up + 2]; + v.w += a * (scalar_t)c_fu[step * up + 1]; + } + } + else if (phaseInY == 1) + { + #pragma unroll + for (int step = 0; step < fuSize / up; step++) + { + v.x += a * (scalar_t)c_fu[step * up + 1]; + v.y += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileUpX[src0 + (step + 1) * tileUpW]; + v.z += a * (scalar_t)c_fu[step * up + 3]; + v.w += a * (scalar_t)c_fu[step * up + 2]; + } + } + else if (phaseInY == 2) + { + #pragma unroll + for (int step = 0; step < fuSize / up; step++) + { + v.x += a * (scalar_t)c_fu[step * up + 2]; + v.y += a * (scalar_t)c_fu[step * up + 1]; + v.z += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileUpX[src0 + (step + 1) * tileUpW]; + v.w += a * (scalar_t)c_fu[step * up + 3]; + } + } + else // (phaseInY == 3) + { + #pragma unroll + for (int step = 0; step < fuSize / up; step++) + { + v.x += a * (scalar_t)c_fu[step * up + 3]; + v.y += a * (scalar_t)c_fu[step * up + 2]; + v.z += a * (scalar_t)c_fu[step * up + 1]; + v.w += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileUpX[src0 + (step + 1) * tileUpW]; + } + } + + int x = tileOutX * down + relUpX; + int y = tileOutY * down + relUpY0; + int signX = x + p.sOfs.x; + int signY = y + p.sOfs.y; + int signZ = blockIdx.z + p.blockZofs; + int signXb = signX >> 2; + index_t si0 = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ); + index_t si1 = si0 + p.sShape.x; + index_t si2 = si0 + p.sShape.x * 2; + index_t si3 = si0 + p.sShape.x * 3; + + v.x *= (scalar_t)((float)up * (float)up * p.gain); + v.y *= (scalar_t)((float)up * (float)up * p.gain); + v.z *= (scalar_t)((float)up * (float)up * p.gain); + v.w *= (scalar_t)((float)up * (float)up * p.gain); + + if (signWrite) + { + if (!enableWriteSkip) + { + // Determine and write signs. + int sx = __float_as_uint(v.x) >> 31 << 0; + int sy = __float_as_uint(v.y) >> 31 << 8; + int sz = __float_as_uint(v.z) >> 31 << 16; + int sw = __float_as_uint(v.w) >> 31 << 24; + if (sx) v.x *= p.slope; + if (sy) v.y *= p.slope; + if (sz) v.z *= p.slope; + if (sw) v.w *= p.slope; + if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType::clamp(v.x, p.clamp); } + if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType::clamp(v.y, p.clamp); } + if (fabsf(v.z) > p.clamp) { sz = 2 << 16; v.z = InternalType::clamp(v.z, p.clamp); } + if (fabsf(v.w) > p.clamp) { sw = 2 << 24; v.w = InternalType::clamp(v.w, p.clamp); } + + if ((uint32_t)signXb < p.swLimit && signY >= minY) + { + // Combine signs. + uint32_t s = sx + sy + sw + sz; + s <<= (signX & 3) << 1; + s |= __shfl_xor_sync(groupMask, s, 1); + s |= __shfl_xor_sync(groupMask, s, 2); + + // Write signs. + if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); } + if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); } + if ((uint32_t)(signY + 2) < sShapeMaxY) { p.s[si2] = (unsigned char)(s >> 16); } + if ((uint32_t)(signY + 3) < sShapeMaxY) { p.s[si3] = (unsigned char)(s >> 24); } + } + } + else + { + // Determine and write signs. + if ((uint32_t)signXb < p.swLimit && signY >= minY) + { + int sx = __float_as_uint(v.x) >> 31 << 0; + int sy = __float_as_uint(v.y) >> 31 << 8; + int sz = __float_as_uint(v.z) >> 31 << 16; + int sw = __float_as_uint(v.w) >> 31 << 24; + if (sx) v.x *= p.slope; + if (sy) v.y *= p.slope; + if (sz) v.z *= p.slope; + if (sw) v.w *= p.slope; + if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType::clamp(v.x, p.clamp); } + if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType::clamp(v.y, p.clamp); } + if (fabsf(v.z) > p.clamp) { sz = 2 << 16; v.z = InternalType::clamp(v.z, p.clamp); } + if (fabsf(v.w) > p.clamp) { sw = 2 << 24; v.w = InternalType::clamp(v.w, p.clamp); } + + // Combine signs. + uint32_t s = sx + sy + sw + sz; + s <<= (signX & 3) << 1; + s |= __shfl_xor_sync(groupMask, s, 1); + s |= __shfl_xor_sync(groupMask, s, 2); + + // Write signs. + if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); } + if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); } + if ((uint32_t)(signY + 2) < sShapeMaxY) { p.s[si2] = (unsigned char)(s >> 16); } + if ((uint32_t)(signY + 3) < sShapeMaxY) { p.s[si3] = (unsigned char)(s >> 24); } + } + else + { + // Just compute the values. + if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp); + if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp); + if (v.z < 0.f) v.z *= p.slope; v.z = InternalType::clamp(v.z, p.clamp); + if (v.w < 0.f) v.w *= p.slope; v.w = InternalType::clamp(v.w, p.clamp); + } + } + } + else if (signRead) // Read signs and apply. + { + if ((uint32_t)signXb < p.swLimit) + { + int ss = (signX & 3) << 1; + if ((uint32_t)(signY + 0) < p.sShape.y) { int s = p.s[si0] >> ss; if (s & 1) v.x *= p.slope; if (s & 2) v.x = 0.f; } + if ((uint32_t)(signY + 1) < p.sShape.y) { int s = p.s[si1] >> ss; if (s & 1) v.y *= p.slope; if (s & 2) v.y = 0.f; } + if ((uint32_t)(signY + 2) < p.sShape.y) { int s = p.s[si2] >> ss; if (s & 1) v.z *= p.slope; if (s & 2) v.z = 0.f; } + if ((uint32_t)(signY + 3) < p.sShape.y) { int s = p.s[si3] >> ss; if (s & 1) v.w *= p.slope; if (s & 2) v.w = 0.f; } + } + } + else // Forward pass with no sign write. + { + if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp); + if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp); + if (v.z < 0.f) v.z *= p.slope; v.z = InternalType::clamp(v.z, p.clamp); + if (v.w < 0.f) v.w *= p.slope; v.w = InternalType::clamp(v.w, p.clamp); + } + + s_tileUpXY[dst + 0 * tileUpW] = v.x; + if (relUpY0 + 1 < tileUpH) s_tileUpXY[dst + 1 * tileUpW] = v.y; + if (relUpY0 + 2 < tileUpH) s_tileUpXY[dst + 2 * tileUpW] = v.z; + if (relUpY0 + 3 < tileUpH) s_tileUpXY[dst + 3 * tileUpW] = v.w; + } + } + else if (up == 2) + { + minY -= 1; // Adjust according to block height. + for (int idx = threadIdx.x; idx < tileUpW * tileUpH_up / up; idx += blockDim.x) + { + int relUpX, relInY0; + fast_div_mod(relUpX, relInY0, idx); + int relUpY0 = relInY0 * up; + int src0 = relInY0 * tileUpW + relUpX; + int dst = relUpY0 * tileUpW + relUpX; + vec2_t v = InternalType::zero_vec2(); + + scalar_t a = s_tileUpX[src0]; + if (phaseInY == 0) + { + #pragma unroll + for (int step = 0; step < fuSize / up; step++) + { + v.x += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileUpX[src0 + (step + 1) * tileUpW]; + v.y += a * (scalar_t)c_fu[step * up + 1]; + } + } + else // (phaseInY == 1) + { + #pragma unroll + for (int step = 0; step < fuSize / up; step++) + { + v.x += a * (scalar_t)c_fu[step * up + 1]; + v.y += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileUpX[src0 + (step + 1) * tileUpW]; + } + } + + int x = tileOutX * down + relUpX; + int y = tileOutY * down + relUpY0; + int signX = x + p.sOfs.x; + int signY = y + p.sOfs.y; + int signZ = blockIdx.z + p.blockZofs; + int signXb = signX >> 2; + index_t si0 = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ); + index_t si1 = si0 + p.sShape.x; + + v.x *= (scalar_t)((float)up * (float)up * p.gain); + v.y *= (scalar_t)((float)up * (float)up * p.gain); + + if (signWrite) + { + if (!enableWriteSkip) + { + // Determine and write signs. + int sx = __float_as_uint(v.x) >> 31 << 0; + int sy = __float_as_uint(v.y) >> 31 << 8; + if (sx) v.x *= p.slope; + if (sy) v.y *= p.slope; + if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType::clamp(v.x, p.clamp); } + if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType::clamp(v.y, p.clamp); } + + if ((uint32_t)signXb < p.swLimit && signY >= minY) + { + // Combine signs. + int s = sx + sy; + s <<= signXo; + s |= __shfl_xor_sync(groupMask, s, 1); + s |= __shfl_xor_sync(groupMask, s, 2); + + // Write signs. + if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); } + if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); } + } + } + else + { + // Determine and write signs. + if ((uint32_t)signXb < p.swLimit && signY >= minY) + { + int sx = __float_as_uint(v.x) >> 31 << 0; + int sy = __float_as_uint(v.y) >> 31 << 8; + if (sx) v.x *= p.slope; + if (sy) v.y *= p.slope; + if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType::clamp(v.x, p.clamp); } + if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType::clamp(v.y, p.clamp); } + + // Combine signs. + int s = sx + sy; + s <<= signXo; + s |= __shfl_xor_sync(groupMask, s, 1); + s |= __shfl_xor_sync(groupMask, s, 2); + + // Write signs. + if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); } + if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); } + } + else + { + // Just compute the values. + if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp); + if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp); + } + } + } + else if (signRead) // Read signs and apply. + { + if ((uint32_t)signXb < p.swLimit) + { + if ((uint32_t)(signY + 0) < p.sShape.y) { int s = p.s[si0] >> signXo; if (s & 1) v.x *= p.slope; if (s & 2) v.x = 0.f; } + if ((uint32_t)(signY + 1) < p.sShape.y) { int s = p.s[si1] >> signXo; if (s & 1) v.y *= p.slope; if (s & 2) v.y = 0.f; } + } + } + else // Forward pass with no sign write. + { + if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp); + if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp); + } + + if (!downInline) + { + // Write into temporary buffer. + s_tileUpXY[dst] = v.x; + if (relUpY0 < tileUpH - 1) + s_tileUpXY[dst + tileUpW] = v.y; + } + else + { + // Write directly into output buffer. + if ((uint32_t)x < p.yShape.x) + { + int ymax = MIN(p.yShape.y, tileUpH + tileOutY * down); + index_t ofs = x * get_stride(p.yStride.x) + y * get_stride(p.yStride.y) + mapOfsOut; + if ((uint32_t)y + 0 < p.yShape.y) *((T*)((char*)p.y + ofs)) = (T)(v.x * (scalar_t)c_fd[0]); + if ((uint32_t)y + 1 < ymax) *((T*)((char*)p.y + ofs + get_stride(p.yStride.y))) = (T)(v.y * (scalar_t)c_fd[0]); + } + } + } + } + } + else if (filterMode == MODE_FUSD || filterMode == MODE_FUFD) + { + // Full upsampling filter. + + if (up == 2) + { + // 2 x 2-wide. + __syncthreads(); + int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH + p.sOfs.y : 0; // Skip already written signs. + for (int idx = threadIdx.x * 4; idx < tileUpW * tileUpH; idx += blockDim.x * 4) + { + int relUpX0, relUpY0; + fast_div_mod(relUpX0, relUpY0, idx); + int relInX0 = CEIL_DIV(relUpX0 - phaseInX, up); + int relInY0 = CEIL_DIV(relUpY0 - phaseInY, up); + int src0 = relInX0 + tileInW * relInY0; + int tap0y = (relInY0 * up + phaseInY - relUpY0); + + #define X_LOOP(TAPY, PX) \ + for (int sx = 0; sx < fuSize / up; sx++) \ + { \ + v.x += a * (scalar_t)c_fu[(sx * up + (((PX) - 0) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; \ + v.z += b * (scalar_t)c_fu[(sx * up + (((PX) - 0) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; if ((PX) == 0) { a = b; b = s_tileIn[src0 + 2 + sx + sy * tileInW]; } \ + v.y += a * (scalar_t)c_fu[(sx * up + (((PX) - 1) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; \ + v.w += b * (scalar_t)c_fu[(sx * up + (((PX) - 1) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; if ((PX) == 1) { a = b; b = s_tileIn[src0 + 2 + sx + sy * tileInW]; } \ + } + + vec4_t v = InternalType::zero_vec4(); + if (tap0y == 0 && phaseInX == 0) + #pragma unroll + for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1]; + #pragma unroll + X_LOOP(0, 0) } + if (tap0y == 0 && phaseInX == 1) + #pragma unroll + for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1]; + #pragma unroll + X_LOOP(0, 1) } + if (tap0y == 1 && phaseInX == 0) + #pragma unroll + for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1]; + #pragma unroll + X_LOOP(1, 0) } + if (tap0y == 1 && phaseInX == 1) + #pragma unroll + for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1]; + #pragma unroll + X_LOOP(1, 1) } + + #undef X_LOOP + + int x = tileOutX * down + relUpX0; + int y = tileOutY * down + relUpY0; + int signX = x + p.sOfs.x; + int signY = y + p.sOfs.y; + int signZ = blockIdx.z + p.blockZofs; + int signXb = signX >> 2; + index_t si = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ); + + v.x *= (scalar_t)((float)up * (float)up * p.gain); + v.y *= (scalar_t)((float)up * (float)up * p.gain); + v.z *= (scalar_t)((float)up * (float)up * p.gain); + v.w *= (scalar_t)((float)up * (float)up * p.gain); + + if (signWrite) + { + if (!enableWriteSkip) + { + // Determine and write signs. + int sx = __float_as_uint(v.x) >> 31; + int sy = __float_as_uint(v.y) >> 31; + int sz = __float_as_uint(v.z) >> 31; + int sw = __float_as_uint(v.w) >> 31; + if (sx) v.x *= p.slope; if (fabsf(v.x) > p.clamp) { sx = 2; v.x = InternalType::clamp(v.x, p.clamp); } + if (sy) v.y *= p.slope; if (fabsf(v.y) > p.clamp) { sy = 2; v.y = InternalType::clamp(v.y, p.clamp); } + if (sz) v.z *= p.slope; if (fabsf(v.z) > p.clamp) { sz = 2; v.z = InternalType::clamp(v.z, p.clamp); } + if (sw) v.w *= p.slope; if (fabsf(v.w) > p.clamp) { sw = 2; v.w = InternalType::clamp(v.w, p.clamp); } + + if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY) + { + p.s[si] = sx + (sy << 2) + (sz << 4) + (sw << 6); + } + } + else + { + // Determine and write signs. + if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY) + { + int sx = __float_as_uint(v.x) >> 31; + int sy = __float_as_uint(v.y) >> 31; + int sz = __float_as_uint(v.z) >> 31; + int sw = __float_as_uint(v.w) >> 31; + if (sx) v.x *= p.slope; if (fabsf(v.x) > p.clamp) { sx = 2; v.x = InternalType::clamp(v.x, p.clamp); } + if (sy) v.y *= p.slope; if (fabsf(v.y) > p.clamp) { sy = 2; v.y = InternalType::clamp(v.y, p.clamp); } + if (sz) v.z *= p.slope; if (fabsf(v.z) > p.clamp) { sz = 2; v.z = InternalType::clamp(v.z, p.clamp); } + if (sw) v.w *= p.slope; if (fabsf(v.w) > p.clamp) { sw = 2; v.w = InternalType::clamp(v.w, p.clamp); } + + p.s[si] = sx + (sy << 2) + (sz << 4) + (sw << 6); + } + else + { + // Just compute the values. + if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp); + if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp); + if (v.z < 0.f) v.z *= p.slope; v.z = InternalType::clamp(v.z, p.clamp); + if (v.w < 0.f) v.w *= p.slope; v.w = InternalType::clamp(v.w, p.clamp); + } + } + } + else if (signRead) // Read sign and apply. + { + if ((uint32_t)signY < p.sShape.y) + { + int s = 0; + if ((uint32_t)signXb < p.swLimit) s = p.s[si]; + if ((uint32_t)signXb + 1 < p.swLimit) s |= p.s[si + 1] << 8; + s >>= (signX & 3) << 1; + if (s & 0x01) v.x *= p.slope; if (s & 0x02) v.x = 0.f; + if (s & 0x04) v.y *= p.slope; if (s & 0x08) v.y = 0.f; + if (s & 0x10) v.z *= p.slope; if (s & 0x20) v.z = 0.f; + if (s & 0x40) v.w *= p.slope; if (s & 0x80) v.w = 0.f; + } + } + else // Forward pass with no sign write. + { + if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp); + if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp); + if (v.z < 0.f) v.z *= p.slope; v.z = InternalType::clamp(v.z, p.clamp); + if (v.w < 0.f) v.w *= p.slope; v.w = InternalType::clamp(v.w, p.clamp); + } + + s_tileUpXY[idx + 0] = v.x; + s_tileUpXY[idx + 1] = v.y; + s_tileUpXY[idx + 2] = v.z; + s_tileUpXY[idx + 3] = v.w; + } + } + else if (up == 1) + { + __syncthreads(); + uint32_t groupMask = 15 << ((threadIdx.x & 31) & ~3); + int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH : 0; // Skip already written signs. + for (int idx = threadIdx.x; idx < tileUpW * tileUpH; idx += blockDim.x) + { + int relUpX0, relUpY0; + fast_div_mod(relUpX0, relUpY0, idx); + scalar_t v = s_tileIn[idx] * (scalar_t)c_fu[0]; // 1x1 filter. + + int x = tileOutX * down + relUpX0; + int y = tileOutY * down + relUpY0; + int signX = x + p.sOfs.x; + int signY = y + p.sOfs.y; + int signZ = blockIdx.z + p.blockZofs; + int signXb = signX >> 2; + index_t si = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ); + v *= (scalar_t)((float)up * (float)up * p.gain); + + if (signWrite) + { + if (!enableWriteSkip) + { + // Determine and write sign. + uint32_t s = 0; + uint32_t signXbit = (1u << signXo); + if (v < 0.f) + { + s = signXbit; + v *= p.slope; + } + if (fabsf(v) > p.clamp) + { + s = signXbit * 2; + v = InternalType::clamp(v, p.clamp); + } + if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY) + { + s += __shfl_xor_sync(groupMask, s, 1); // Coalesce. + s += __shfl_xor_sync(groupMask, s, 2); // Coalesce. + p.s[si] = s; // Write. + } + } + else + { + // Determine and write sign. + if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY) + { + uint32_t s = 0; + uint32_t signXbit = (1u << signXo); + if (v < 0.f) + { + s = signXbit; + v *= p.slope; + } + if (fabsf(v) > p.clamp) + { + s = signXbit * 2; + v = InternalType::clamp(v, p.clamp); + } + s += __shfl_xor_sync(groupMask, s, 1); // Coalesce. + s += __shfl_xor_sync(groupMask, s, 2); // Coalesce. + p.s[si] = s; // Write. + } + else + { + // Just compute the value. + if (v < 0.f) v *= p.slope; + v = InternalType::clamp(v, p.clamp); + } + } + } + else if (signRead) + { + // Read sign and apply if within sign tensor bounds. + if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y) + { + int s = p.s[si]; + s >>= signXo; + if (s & 1) v *= p.slope; + if (s & 2) v = 0.f; + } + } + else // Forward pass with no sign write. + { + if (v < 0.f) v *= p.slope; + v = InternalType::clamp(v, p.clamp); + } + + if (!downInline) // Write into temporary buffer. + s_tileUpXY[idx] = v; + else if ((uint32_t)x < p.yShape.x && (uint32_t)y < p.yShape.y) // Write directly into output buffer + *((T*)((char*)p.y + (x * get_stride(p.yStride.x) + y * get_stride(p.yStride.y) + mapOfsOut))) = (T)(v * (scalar_t)c_fd[0]); + } + } + } + + // Downsampling. + if (filterMode == MODE_SUSD || filterMode == MODE_FUSD) + { + // Horizontal downsampling. + __syncthreads(); + if (down == 4 && tileOutW % 4 == 0) + { + // Calculate 4 pixels at a time. + for (int idx = threadIdx.x * 4; idx < tileOutW * tileUpH; idx += blockDim.x * 4) + { + int relOutX0, relUpY; + fast_div_mod(relOutX0, relUpY, idx); + int relUpX0 = relOutX0 * down; + int src0 = relUpY * tileUpW + relUpX0; + vec4_t v = InternalType::zero_vec4(); + #pragma unroll + for (int step = 0; step < fdSize; step++) + { + v.x += s_tileUpXY[src0 + 0 + step] * (scalar_t)c_fd[step]; + v.y += s_tileUpXY[src0 + 4 + step] * (scalar_t)c_fd[step]; + v.z += s_tileUpXY[src0 + 8 + step] * (scalar_t)c_fd[step]; + v.w += s_tileUpXY[src0 + 12 + step] * (scalar_t)c_fd[step]; + } + s_tileDownX[idx+0] = v.x; + s_tileDownX[idx+1] = v.y; + s_tileDownX[idx+2] = v.z; + s_tileDownX[idx+3] = v.w; + } + } + else if ((down == 2 || down == 4) && (tileOutW % 2 == 0)) + { + // Calculate 2 pixels at a time. + for (int idx = threadIdx.x * 2; idx < tileOutW * tileUpH; idx += blockDim.x * 2) + { + int relOutX0, relUpY; + fast_div_mod(relOutX0, relUpY, idx); + int relUpX0 = relOutX0 * down; + int src0 = relUpY * tileUpW + relUpX0; + vec2_t v = InternalType::zero_vec2(); + #pragma unroll + for (int step = 0; step < fdSize; step++) + { + v.x += s_tileUpXY[src0 + 0 + step] * (scalar_t)c_fd[step]; + v.y += s_tileUpXY[src0 + down + step] * (scalar_t)c_fd[step]; + } + s_tileDownX[idx+0] = v.x; + s_tileDownX[idx+1] = v.y; + } + } + else + { + // Calculate 1 pixel at a time. + for (int idx = threadIdx.x; idx < tileOutW * tileUpH; idx += blockDim.x) + { + int relOutX0, relUpY; + fast_div_mod(relOutX0, relUpY, idx); + int relUpX0 = relOutX0 * down; + int src = relUpY * tileUpW + relUpX0; + scalar_t v = 0.f; + #pragma unroll + for (int step = 0; step < fdSize; step++) + v += s_tileUpXY[src + step] * (scalar_t)c_fd[step]; + s_tileDownX[idx] = v; + } + } + + // Vertical downsampling & store output tile. + __syncthreads(); + for (int idx = threadIdx.x; idx < tileOutW * tileOutH; idx += blockDim.x) + { + int relOutX, relOutY0; + fast_div_mod(relOutX, relOutY0, idx); + int relUpY0 = relOutY0 * down; + int src0 = relUpY0 * tileOutW + relOutX; + scalar_t v = 0; + #pragma unroll + for (int step = 0; step < fdSize; step++) + v += s_tileDownX[src0 + step * tileOutW] * (scalar_t)c_fd[step]; + + int outX = tileOutX + relOutX; + int outY = tileOutY + relOutY0; + + if (outX < p.yShape.x & outY < p.yShape.y) + *((T*)((char*)p.y + (outX * get_stride(p.yStride.x) + outY * get_stride(p.yStride.y) + mapOfsOut))) = (T)v; + } + } + else if (filterMode == MODE_SUFD || filterMode == MODE_FUFD) + { + // Full downsampling filter. + if (down == 2) + { + // 2-wide. + __syncthreads(); + for (int idx = threadIdx.x * 2; idx < tileOutW * tileOutH; idx += blockDim.x * 2) + { + int relOutX0, relOutY0; + fast_div_mod(relOutX0, relOutY0, idx); + int relUpX0 = relOutX0 * down; + int relUpY0 = relOutY0 * down; + int src0 = relUpY0 * tileUpW + relUpX0; + vec2_t v = InternalType::zero_vec2(); + #pragma unroll + for (int sy = 0; sy < fdSize; sy++) + #pragma unroll + for (int sx = 0; sx < fdSize; sx++) + { + v.x += s_tileUpXY[src0 + 0 + sx + sy * tileUpW] * (scalar_t)c_fd[sx + sy * MAX_FILTER_SIZE]; + v.y += s_tileUpXY[src0 + 2 + sx + sy * tileUpW] * (scalar_t)c_fd[sx + sy * MAX_FILTER_SIZE]; + } + + int outX = tileOutX + relOutX0; + int outY = tileOutY + relOutY0; + if ((uint32_t)outY < p.yShape.y) + { + index_t ofs = outX * get_stride(p.yStride.x) + outY * get_stride(p.yStride.y) + mapOfsOut; + if (outX + 0 < p.yShape.x) *((T*)((char*)p.y + ofs)) = (T)v.x; + if (outX + 1 < p.yShape.x) *((T*)((char*)p.y + ofs + get_stride(p.yStride.x))) = (T)v.y; + } + } + } + else if (down == 1 && !downInline) + { + // Thread per pixel. + __syncthreads(); + for (int idx = threadIdx.x; idx < tileOutW * tileOutH; idx += blockDim.x) + { + int relOutX0, relOutY0; + fast_div_mod(relOutX0, relOutY0, idx); + scalar_t v = s_tileUpXY[idx] * (scalar_t)c_fd[0]; // 1x1 filter. + + int outX = tileOutX + relOutX0; + int outY = tileOutY + relOutY0; + if ((uint32_t)outX < p.yShape.x && (uint32_t)outY < p.yShape.y) + *((T*)((char*)p.y + (outX * get_stride(p.yStride.x) + outY * get_stride(p.yStride.y) + mapOfsOut))) = (T)v; + } + } + } + + if (!enableXrep) + break; + } +} + +//------------------------------------------------------------------------ +// Compute activation function and signs for upsampled data tensor, modifying data tensor in-place. Used for accelerating the generic variant. +// Sign tensor is known to be contiguous, and p.x and p.s have the same z, w dimensions. 64-bit indexing is always used. + +template +static __global__ void filtered_lrelu_act_kernel(filtered_lrelu_act_kernel_params p) +{ + typedef typename InternalType::scalar_t scalar_t; + + // Indexing. + int32_t x = threadIdx.x + blockIdx.x * blockDim.x; + int32_t ymax = signWrite ? p.sShape.y : p.xShape.y; + int32_t qmax = p.xShape.z * p.xShape.w; // Combined minibatch*channel maximum index. + + // Loop to accommodate oversized tensors. + for (int32_t q = blockIdx.z; q < qmax; q += gridDim.z) + for (int32_t y = blockIdx.y; y < ymax; y += gridDim.y) + { + // Extract z and w (channel, minibatch index). + int32_t w = q / p.xShape.z; + int32_t z = q - w * p.xShape.z; + + // Choose behavior based on sign read/write mode. + if (signWrite) + { + // Process value if in p.x. + uint32_t s = 0; + if (x < p.xShape.x && y < p.xShape.y) + { + int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w; + T* pv = ((T*)p.x) + ix; + scalar_t v = (scalar_t)(*pv); + + // Gain, LReLU, clamp. + v *= p.gain; + if (v < 0.f) + { + v *= p.slope; + s = 1; // Sign. + } + if (fabsf(v) > p.clamp) + { + v = InternalType::clamp(v, p.clamp); + s = 2; // Clamp. + } + + *pv = (T)v; // Write value. + } + + // Coalesce into threads 0 and 16 of warp. + uint32_t m = (threadIdx.x & 16) ? 0xffff0000u : 0x0000ffffu; + s <<= ((threadIdx.x & 15) << 1); // Shift into place. + s |= __shfl_xor_sync(m, s, 1); // Distribute. + s |= __shfl_xor_sync(m, s, 2); + s |= __shfl_xor_sync(m, s, 4); + s |= __shfl_xor_sync(m, s, 8); + + // Write signs if leader and in p.s. + if (!(threadIdx.x & 15) && x < p.sShape.x) // y is always in. + { + uint64_t is = x + p.sShape.x * (y + (int64_t)p.sShape.y * q); // Contiguous. + ((uint32_t*)p.s)[is >> 4] = s; + } + } + else if (signRead) + { + // Process value if in p.x. + if (x < p.xShape.x) // y is always in. + { + int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w; + T* pv = ((T*)p.x) + ix; + scalar_t v = (scalar_t)(*pv); + v *= p.gain; + + // Apply sign buffer offset. + uint32_t sx = x + p.sOfs.x; + uint32_t sy = y + p.sOfs.y; + + // Read and apply signs if we land inside valid region of sign buffer. + if (sx < p.sShape.x && sy < p.sShape.y) + { + uint64_t is = (sx >> 2) + (p.sShape.x >> 2) * (sy + (uint64_t)p.sShape.y * q); // Contiguous. + unsigned char s = p.s[is]; + s >>= (sx & 3) << 1; // Shift into place. + if (s & 1) // Sign? + v *= p.slope; + if (s & 2) // Clamp? + v = 0.f; + } + + *pv = (T)v; // Write value. + } + } + else + { + // Forward pass with no sign write. Process value if in p.x. + if (x < p.xShape.x) // y is always in. + { + int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w; + T* pv = ((T*)p.x) + ix; + scalar_t v = (scalar_t)(*pv); + v *= p.gain; + if (v < 0.f) + v *= p.slope; + if (fabsf(v) > p.clamp) + v = InternalType::clamp(v, p.clamp); + *pv = (T)v; // Write value. + } + } + } +} + +template void* choose_filtered_lrelu_act_kernel(void) +{ + return (void*)filtered_lrelu_act_kernel; +} + +//------------------------------------------------------------------------ +// CUDA kernel selection. + +template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB) +{ + filtered_lrelu_kernel_spec s = { 0 }; + + // Return the first matching kernel. +#define CASE(SH, U, FU, D, FD, MODE, TW, TH, W, XR, WS) \ + if (sharedKB >= SH) \ + if ((p.fuShape.y == 0 && (MODE == MODE_SUSD || MODE == MODE_SUFD)) || (p.fuShape.y > 0 && (MODE == MODE_FUSD || MODE == MODE_FUFD))) \ + if ((p.fdShape.y == 0 && (MODE == MODE_SUSD || MODE == MODE_FUSD)) || (p.fdShape.y > 0 && (MODE == MODE_SUFD || MODE == MODE_FUFD))) \ + if (p.up == U && p.fuShape.x <= FU && p.fuShape.y <= FU && p.down == D && p.fdShape.x <= FD && p.fdShape.y <= FD) \ + { \ + static_assert((D*TW % 4) == 0, "down * tileWidth must be divisible by 4"); \ + static_assert(FU % U == 0, "upscaling filter size must be multiple of upscaling factor"); \ + static_assert(FD % D == 0, "downscaling filter size must be multiple of downscaling factor"); \ + s.setup = (void*)setup_filters_kernel; \ + s.exec = (void*)filtered_lrelu_kernel; \ + s.tileOut = make_int2(TW, TH); \ + s.numWarps = W; \ + s.xrep = XR; \ + s.dynamicSharedKB = (SH == 48) ? 0 : SH; \ + return s; \ + } + + // Launch parameters for various kernel specializations. + // Small filters must be listed before large filters, otherwise the kernel for larger filter will always match first. + // Kernels that use more shared memory must be listed before those that use less, for the same reason. + + CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/1,1, /*mode*/MODE_FUFD, /*tw,th,warps,xrep,wskip*/64, 178, 32, 0, 0) // 1t-upf1-downf1 + CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/1,1, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/152, 95, 16, 0, 0) // 4t-ups2-downf1 + CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/2,8, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/56, 22, 16, 0, 0) // 4t-upf1-downs2 + CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/2,8, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/56, 29, 16, 11, 0) // 4t-ups2-downs2 + CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/2,8, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/60, 28, 16, 0, 0) // 4t-upf2-downs2 + CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/2,8, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/56, 28, 16, 0, 0) // 4t-ups2-downf2 + CASE(/*sharedKB*/48, /*up,fu*/4,16, /*down,fd*/2,8, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/56, 31, 16, 11, 0) // 4t-ups4-downs2 + CASE(/*sharedKB*/48, /*up,fu*/4,16, /*down,fd*/2,8, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/56, 36, 16, 0, 0) // 4t-ups4-downf2 + CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/4,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16, 22, 16, 12, 0) // 4t-ups2-downs4 + CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/4,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/29, 15, 16, 0, 0) // 4t-upf2-downs4 + CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/1,1, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/96, 150, 28, 0, 0) // 6t-ups2-downf1 + CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/2,12, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/32, 35, 24, 0, 0) // 6t-upf1-downs2 + CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 46, 16, 10, 0) // 6t-ups2-downs2 + CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/58, 28, 24, 8, 0) // 6t-upf2-downs2 + CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/52, 28, 16, 0, 0) // 6t-ups2-downf2 + CASE(/*sharedKB*/48, /*up,fu*/4,24, /*down,fd*/2,12, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 51, 16, 5, 0) // 6t-ups4-downs2 + CASE(/*sharedKB*/48, /*up,fu*/4,24, /*down,fd*/2,12, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32, 56, 16, 6, 0) // 6t-ups4-downf2 + CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16, 18, 16, 12, 0) // 6t-ups2-downs4 + CASE(/*sharedKB*/96, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/27, 31, 32, 6, 0) // 6t-upf2-downs4 96kB + CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/27, 13, 24, 0, 0) // 6t-upf2-downs4 + CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/1,1, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/148, 89, 24, 0, 0) // 8t-ups2-downf1 + CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/2,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/32, 31, 16, 5, 0) // 8t-upf1-downs2 + CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 41, 16, 9, 0) // 8t-ups2-downs2 + CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/56, 26, 24, 0, 0) // 8t-upf2-downs2 + CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32, 40, 16, 0, 0) // 8t-ups2-downf2 + CASE(/*sharedKB*/48, /*up,fu*/4,32, /*down,fd*/2,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 46, 24, 5, 0) // 8t-ups4-downs2 + CASE(/*sharedKB*/48, /*up,fu*/4,32, /*down,fd*/2,16, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32, 50, 16, 0, 0) // 8t-ups4-downf2 + CASE(/*sharedKB*/96, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/24, 24, 32, 12, 1) // 8t-ups2-downs4 96kB + CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16, 13, 16, 10, 1) // 8t-ups2-downs4 + CASE(/*sharedKB*/96, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/25, 28, 28, 4, 0) // 8t-upf2-downs4 96kB + CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/25, 10, 24, 0, 0) // 8t-upf2-downs4 + + #undef CASE + return s; // No kernel found. +} + +//------------------------------------------------------------------------ diff --git a/torch_utils/ops/filtered_lrelu.h b/torch_utils/ops/filtered_lrelu.h new file mode 100644 index 0000000000000000000000000000000000000000..2060f5dc0785df162ecebbdf1c034e697046ded1 --- /dev/null +++ b/torch_utils/ops/filtered_lrelu.h @@ -0,0 +1,90 @@ +// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include + +//------------------------------------------------------------------------ +// CUDA kernel parameters. + +struct filtered_lrelu_kernel_params +{ + // These parameters decide which kernel to use. + int up; // upsampling ratio (1, 2, 4) + int down; // downsampling ratio (1, 2, 4) + int2 fuShape; // [size, 1] | [size, size] + int2 fdShape; // [size, 1] | [size, size] + + int _dummy; // Alignment. + + // Rest of the parameters. + const void* x; // Input tensor. + void* y; // Output tensor. + const void* b; // Bias tensor. + unsigned char* s; // Sign tensor in/out. NULL if unused. + const float* fu; // Upsampling filter. + const float* fd; // Downsampling filter. + + int2 pad0; // Left/top padding. + float gain; // Additional gain factor. + float slope; // Leaky ReLU slope on negative side. + float clamp; // Clamp after nonlinearity. + int flip; // Filter kernel flip for gradient computation. + + int tilesXdim; // Original number of horizontal output tiles. + int tilesXrep; // Number of horizontal tiles per CTA. + int blockZofs; // Block z offset to support large minibatch, channel dimensions. + + int4 xShape; // [width, height, channel, batch] + int4 yShape; // [width, height, channel, batch] + int2 sShape; // [width, height] - width is in bytes. Contiguous. Zeros if unused. + int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor. + int swLimit; // Active width of sign tensor in bytes. + + longlong4 xStride; // Strides of all tensors except signs, same component order as shapes. + longlong4 yStride; // + int64_t bStride; // + longlong3 fuStride; // + longlong3 fdStride; // +}; + +struct filtered_lrelu_act_kernel_params +{ + void* x; // Input/output, modified in-place. + unsigned char* s; // Sign tensor in/out. NULL if unused. + + float gain; // Additional gain factor. + float slope; // Leaky ReLU slope on negative side. + float clamp; // Clamp after nonlinearity. + + int4 xShape; // [width, height, channel, batch] + longlong4 xStride; // Input/output tensor strides, same order as in shape. + int2 sShape; // [width, height] - width is in elements. Contiguous. Zeros if unused. + int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor. +}; + +//------------------------------------------------------------------------ +// CUDA kernel specialization. + +struct filtered_lrelu_kernel_spec +{ + void* setup; // Function for filter kernel setup. + void* exec; // Function for main operation. + int2 tileOut; // Width/height of launch tile. + int numWarps; // Number of warps per thread block, determines launch block size. + int xrep; // For processing multiple horizontal tiles per thread block. + int dynamicSharedKB; // How much dynamic shared memory the exec kernel wants. +}; + +//------------------------------------------------------------------------ +// CUDA kernel selection. + +template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); +template void* choose_filtered_lrelu_act_kernel(void); +template cudaError_t copy_filters(cudaStream_t stream); + +//------------------------------------------------------------------------ diff --git a/torch_utils/ops/filtered_lrelu.py b/torch_utils/ops/filtered_lrelu.py new file mode 100644 index 0000000000000000000000000000000000000000..48d619a85c348343edb2d2e0fd9d1cf6f99ce8b7 --- /dev/null +++ b/torch_utils/ops/filtered_lrelu.py @@ -0,0 +1,274 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +import os +import numpy as np +import torch +import warnings + +from .. import custom_ops +from .. import misc +from . import upfirdn2d +from . import bias_act + +#---------------------------------------------------------------------------- + +_plugin = None + +def _init(): + global _plugin + if _plugin is None: + _plugin = custom_ops.get_plugin( + module_name='filtered_lrelu_plugin', + sources=['filtered_lrelu.cpp', 'filtered_lrelu_wr.cu', 'filtered_lrelu_rd.cu', 'filtered_lrelu_ns.cu'], + headers=['filtered_lrelu.h', 'filtered_lrelu.cu'], + source_dir=os.path.dirname(__file__), + extra_cuda_cflags=['--use_fast_math'], + ) + return True + +def _get_filter_size(f): + if f is None: + return 1, 1 + assert isinstance(f, torch.Tensor) + assert 1 <= f.ndim <= 2 + return f.shape[-1], f.shape[0] # width, height + +def _parse_padding(padding): + if isinstance(padding, int): + padding = [padding, padding] + assert isinstance(padding, (list, tuple)) + assert all(isinstance(x, (int, np.integer)) for x in padding) + padding = [int(x) for x in padding] + if len(padding) == 2: + px, py = padding + padding = [px, px, py, py] + px0, px1, py0, py1 = padding + return px0, px1, py0, py1 + +#---------------------------------------------------------------------------- + +def filtered_lrelu(x, fu=None, fd=None, b=None, up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False, impl='cuda'): + r"""Filtered leaky ReLU for a batch of 2D images. + + Performs the following sequence of operations for each channel: + + 1. Add channel-specific bias if provided (`b`). + + 2. Upsample the image by inserting N-1 zeros after each pixel (`up`). + + 3. Pad the image with the specified number of zeros on each side (`padding`). + Negative padding corresponds to cropping the image. + + 4. Convolve the image with the specified upsampling FIR filter (`fu`), shrinking it + so that the footprint of all output pixels lies within the input image. + + 5. Multiply each value by the provided gain factor (`gain`). + + 6. Apply leaky ReLU activation function to each value. + + 7. Clamp each value between -clamp and +clamp, if `clamp` parameter is provided. + + 8. Convolve the image with the specified downsampling FIR filter (`fd`), shrinking + it so that the footprint of all output pixels lies within the input image. + + 9. Downsample the image by keeping every Nth pixel (`down`). + + The fused op is considerably more efficient than performing the same calculation + using standard PyTorch ops. It supports gradients of arbitrary order. + + Args: + x: Float32/float16/float64 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. + fu: Float32 upsampling FIR filter of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), or + `None` (identity). + fd: Float32 downsampling FIR filter of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), or + `None` (identity). + b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type + as `x`. The length of vector must must match the channel dimension of `x`. + up: Integer upsampling factor (default: 1). + down: Integer downsampling factor. (default: 1). + padding: Padding with respect to the upsampled image. Can be a single number + or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + gain: Overall scaling factor for signal magnitude (default: sqrt(2)). + slope: Slope on the negative side of leaky ReLU (default: 0.2). + clamp: Maximum magnitude for leaky ReLU output (default: None). + flip_filter: False = convolution, True = correlation (default: False). + impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. + """ + assert isinstance(x, torch.Tensor) + assert impl in ['ref', 'cuda'] + if impl == 'cuda' and x.device.type == 'cuda' and _init(): + return _filtered_lrelu_cuda(up=up, down=down, padding=padding, gain=gain, slope=slope, clamp=clamp, flip_filter=flip_filter).apply(x, fu, fd, b, None, 0, 0) + return _filtered_lrelu_ref(x, fu=fu, fd=fd, b=b, up=up, down=down, padding=padding, gain=gain, slope=slope, clamp=clamp, flip_filter=flip_filter) + +#---------------------------------------------------------------------------- + +@misc.profiled_function +def _filtered_lrelu_ref(x, fu=None, fd=None, b=None, up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False): + """Slow and memory-inefficient reference implementation of `filtered_lrelu()` using + existing `upfirdn2n()` and `bias_act()` ops. + """ + assert isinstance(x, torch.Tensor) and x.ndim == 4 + fu_w, fu_h = _get_filter_size(fu) + fd_w, fd_h = _get_filter_size(fd) + if b is not None: + assert isinstance(b, torch.Tensor) and b.dtype == x.dtype + misc.assert_shape(b, [x.shape[1]]) + assert isinstance(up, int) and up >= 1 + assert isinstance(down, int) and down >= 1 + px0, px1, py0, py1 = _parse_padding(padding) + assert gain == float(gain) and gain > 0 + assert slope == float(slope) and slope >= 0 + assert clamp is None or (clamp == float(clamp) and clamp >= 0) + + # Calculate output size. + batch_size, channels, in_h, in_w = x.shape + in_dtype = x.dtype + out_w = (in_w * up + (px0 + px1) - (fu_w - 1) - (fd_w - 1) + (down - 1)) // down + out_h = (in_h * up + (py0 + py1) - (fu_h - 1) - (fd_h - 1) + (down - 1)) // down + + # Compute using existing ops. + x = bias_act.bias_act(x=x, b=b) # Apply bias. + x = upfirdn2d.upfirdn2d(x=x, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) # Upsample. + x = bias_act.bias_act(x=x, act='lrelu', alpha=slope, gain=gain, clamp=clamp) # Bias, leaky ReLU, clamp. + x = upfirdn2d.upfirdn2d(x=x, f=fd, down=down, flip_filter=flip_filter) # Downsample. + + # Check output shape & dtype. + misc.assert_shape(x, [batch_size, channels, out_h, out_w]) + assert x.dtype == in_dtype + return x + +#---------------------------------------------------------------------------- + +_filtered_lrelu_cuda_cache = dict() + +def _filtered_lrelu_cuda(up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False): + """Fast CUDA implementation of `filtered_lrelu()` using custom ops. + """ + assert isinstance(up, int) and up >= 1 + assert isinstance(down, int) and down >= 1 + px0, px1, py0, py1 = _parse_padding(padding) + assert gain == float(gain) and gain > 0 + gain = float(gain) + assert slope == float(slope) and slope >= 0 + slope = float(slope) + assert clamp is None or (clamp == float(clamp) and clamp >= 0) + clamp = float(clamp if clamp is not None else 'inf') + + # Lookup from cache. + key = (up, down, px0, px1, py0, py1, gain, slope, clamp, flip_filter) + if key in _filtered_lrelu_cuda_cache: + return _filtered_lrelu_cuda_cache[key] + + # Forward op. + class FilteredLReluCuda(torch.autograd.Function): + @staticmethod + def forward(ctx, x, fu, fd, b, si, sx, sy): # pylint: disable=arguments-differ + assert isinstance(x, torch.Tensor) and x.ndim == 4 + + # Replace empty up/downsample kernels with full 1x1 kernels (faster than separable). + if fu is None: + fu = torch.ones([1, 1], dtype=torch.float32, device=x.device) + if fd is None: + fd = torch.ones([1, 1], dtype=torch.float32, device=x.device) + assert 1 <= fu.ndim <= 2 + assert 1 <= fd.ndim <= 2 + + # Replace separable 1x1 kernels with full 1x1 kernels when scale factor is 1. + if up == 1 and fu.ndim == 1 and fu.shape[0] == 1: + fu = fu.square()[None] + if down == 1 and fd.ndim == 1 and fd.shape[0] == 1: + fd = fd.square()[None] + + # Missing sign input tensor. + if si is None: + si = torch.empty([0]) + + # Missing bias tensor. + if b is None: + b = torch.zeros([x.shape[1]], dtype=x.dtype, device=x.device) + + # Construct internal sign tensor only if gradients are needed. + write_signs = (si.numel() == 0) and (x.requires_grad or b.requires_grad) + + # Warn if input storage strides are not in decreasing order due to e.g. channels-last layout. + strides = [x.stride(i) for i in range(x.ndim) if x.size(i) > 1] + if any(a < b for a, b in zip(strides[:-1], strides[1:])): + warnings.warn("low-performance memory layout detected in filtered_lrelu input", RuntimeWarning) + + # Call C++/Cuda plugin if datatype is supported. + if x.dtype in [torch.float16, torch.float32]: + if torch.cuda.current_stream(x.device) != torch.cuda.default_stream(x.device): + warnings.warn("filtered_lrelu called with non-default cuda stream but concurrent execution is not supported", RuntimeWarning) + y, so, return_code = _plugin.filtered_lrelu(x, fu, fd, b, si, up, down, px0, px1, py0, py1, sx, sy, gain, slope, clamp, flip_filter, write_signs) + else: + return_code = -1 + + # No Cuda kernel found? Fall back to generic implementation. Still more memory efficient than the reference implementation because + # only the bit-packed sign tensor is retained for gradient computation. + if return_code < 0: + warnings.warn("filtered_lrelu called with parameters that have no optimized CUDA kernel, using generic fallback", RuntimeWarning) + + y = x.add(b.unsqueeze(-1).unsqueeze(-1)) # Add bias. + y = upfirdn2d.upfirdn2d(x=y, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) # Upsample. + so = _plugin.filtered_lrelu_act_(y, si, sx, sy, gain, slope, clamp, write_signs) # Activation function and sign handling. Modifies y in-place. + y = upfirdn2d.upfirdn2d(x=y, f=fd, down=down, flip_filter=flip_filter) # Downsample. + + # Prepare for gradient computation. + ctx.save_for_backward(fu, fd, (si if si.numel() else so)) + ctx.x_shape = x.shape + ctx.y_shape = y.shape + ctx.s_ofs = sx, sy + return y + + @staticmethod + def backward(ctx, dy): # pylint: disable=arguments-differ + fu, fd, si = ctx.saved_tensors + _, _, xh, xw = ctx.x_shape + _, _, yh, yw = ctx.y_shape + sx, sy = ctx.s_ofs + dx = None # 0 + dfu = None; assert not ctx.needs_input_grad[1] + dfd = None; assert not ctx.needs_input_grad[2] + db = None # 3 + dsi = None; assert not ctx.needs_input_grad[4] + dsx = None; assert not ctx.needs_input_grad[5] + dsy = None; assert not ctx.needs_input_grad[6] + + if ctx.needs_input_grad[0] or ctx.needs_input_grad[3]: + pp = [ + (fu.shape[-1] - 1) + (fd.shape[-1] - 1) - px0, + xw * up - yw * down + px0 - (up - 1), + (fu.shape[0] - 1) + (fd.shape[0] - 1) - py0, + xh * up - yh * down + py0 - (up - 1), + ] + gg = gain * (up ** 2) / (down ** 2) + ff = (not flip_filter) + sx = sx - (fu.shape[-1] - 1) + px0 + sy = sy - (fu.shape[0] - 1) + py0 + dx = _filtered_lrelu_cuda(up=down, down=up, padding=pp, gain=gg, slope=slope, clamp=None, flip_filter=ff).apply(dy, fd, fu, None, si, sx, sy) + + if ctx.needs_input_grad[3]: + db = dx.sum([0, 2, 3]) + + return dx, dfu, dfd, db, dsi, dsx, dsy + + # Add to cache. + _filtered_lrelu_cuda_cache[key] = FilteredLReluCuda + return FilteredLReluCuda + +#---------------------------------------------------------------------------- diff --git a/torch_utils/ops/filtered_lrelu_ns.cu b/torch_utils/ops/filtered_lrelu_ns.cu new file mode 100644 index 0000000000000000000000000000000000000000..55ef84131882c26f3ee329063d010a1622607517 --- /dev/null +++ b/torch_utils/ops/filtered_lrelu_ns.cu @@ -0,0 +1,27 @@ +// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include "filtered_lrelu.cu" + +// Template/kernel specializations for no signs mode (no gradients required). + +// Full op, 32-bit indexing. +template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); +template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); + +// Full op, 64-bit indexing. +template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); +template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); + +// Activation/signs only for generic variant. 64-bit indexing. +template void* choose_filtered_lrelu_act_kernel(void); +template void* choose_filtered_lrelu_act_kernel(void); +template void* choose_filtered_lrelu_act_kernel(void); + +// Copy filters to constant memory. +template cudaError_t copy_filters(cudaStream_t stream); diff --git a/torch_utils/ops/filtered_lrelu_rd.cu b/torch_utils/ops/filtered_lrelu_rd.cu new file mode 100644 index 0000000000000000000000000000000000000000..f41832c71411711093820f3a0bbc3fc69f6cfe79 --- /dev/null +++ b/torch_utils/ops/filtered_lrelu_rd.cu @@ -0,0 +1,27 @@ +// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include "filtered_lrelu.cu" + +// Template/kernel specializations for sign read mode. + +// Full op, 32-bit indexing. +template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); +template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); + +// Full op, 64-bit indexing. +template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); +template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); + +// Activation/signs only for generic variant. 64-bit indexing. +template void* choose_filtered_lrelu_act_kernel(void); +template void* choose_filtered_lrelu_act_kernel(void); +template void* choose_filtered_lrelu_act_kernel(void); + +// Copy filters to constant memory. +template cudaError_t copy_filters(cudaStream_t stream); diff --git a/torch_utils/ops/filtered_lrelu_wr.cu b/torch_utils/ops/filtered_lrelu_wr.cu new file mode 100644 index 0000000000000000000000000000000000000000..0265d183015feb13e6581fde9036eb6ef40cc89d --- /dev/null +++ b/torch_utils/ops/filtered_lrelu_wr.cu @@ -0,0 +1,27 @@ +// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include "filtered_lrelu.cu" + +// Template/kernel specializations for sign write mode. + +// Full op, 32-bit indexing. +template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); +template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); + +// Full op, 64-bit indexing. +template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); +template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); + +// Activation/signs only for generic variant. 64-bit indexing. +template void* choose_filtered_lrelu_act_kernel(void); +template void* choose_filtered_lrelu_act_kernel(void); +template void* choose_filtered_lrelu_act_kernel(void); + +// Copy filters to constant memory. +template cudaError_t copy_filters(cudaStream_t stream); diff --git a/torch_utils/ops/fma.py b/torch_utils/ops/fma.py new file mode 100644 index 0000000000000000000000000000000000000000..8e9b3e9f6e3dd8afae6ccd7ea2b66b5e13ec0dcf --- /dev/null +++ b/torch_utils/ops/fma.py @@ -0,0 +1,60 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Fused multiply-add, with slightly faster gradients than `torch.addcmul()`.""" + +import torch + +#---------------------------------------------------------------------------- + +def fma(a, b, c): # => a * b + c + return _FusedMultiplyAdd.apply(a, b, c) + +#---------------------------------------------------------------------------- + +class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c + @staticmethod + def forward(ctx, a, b, c): # pylint: disable=arguments-differ + out = torch.addcmul(c, a, b) + ctx.save_for_backward(a, b) + ctx.c_shape = c.shape + return out + + @staticmethod + def backward(ctx, dout): # pylint: disable=arguments-differ + a, b = ctx.saved_tensors + c_shape = ctx.c_shape + da = None + db = None + dc = None + + if ctx.needs_input_grad[0]: + da = _unbroadcast(dout * b, a.shape) + + if ctx.needs_input_grad[1]: + db = _unbroadcast(dout * a, b.shape) + + if ctx.needs_input_grad[2]: + dc = _unbroadcast(dout, c_shape) + + return da, db, dc + +#---------------------------------------------------------------------------- + +def _unbroadcast(x, shape): + extra_dims = x.ndim - len(shape) + assert extra_dims >= 0 + dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)] + if len(dim): + x = x.sum(dim=dim, keepdim=True) + if extra_dims: + x = x.reshape(-1, *x.shape[extra_dims+1:]) + assert x.shape == shape + return x + +#---------------------------------------------------------------------------- diff --git a/torch_utils/ops/grid_sample_gradfix.py b/torch_utils/ops/grid_sample_gradfix.py new file mode 100644 index 0000000000000000000000000000000000000000..269ffe81b04a8b11b4a8dea1913ae876b0ac4d30 --- /dev/null +++ b/torch_utils/ops/grid_sample_gradfix.py @@ -0,0 +1,77 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Custom replacement for `torch.nn.functional.grid_sample` that +supports arbitrarily high order gradients between the input and output. +Only works on 2D images and assumes +`mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`.""" + +import torch + +# pylint: disable=redefined-builtin +# pylint: disable=arguments-differ +# pylint: disable=protected-access + +#---------------------------------------------------------------------------- + +enabled = False # Enable the custom op by setting this to true. + +#---------------------------------------------------------------------------- + +def grid_sample(input, grid): + if _should_use_custom_op(): + return _GridSample2dForward.apply(input, grid) + return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) + +#---------------------------------------------------------------------------- + +def _should_use_custom_op(): + return enabled + +#---------------------------------------------------------------------------- + +class _GridSample2dForward(torch.autograd.Function): + @staticmethod + def forward(ctx, input, grid): + assert input.ndim == 4 + assert grid.ndim == 4 + output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) + ctx.save_for_backward(input, grid) + return output + + @staticmethod + def backward(ctx, grad_output): + input, grid = ctx.saved_tensors + grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid) + return grad_input, grad_grid + +#---------------------------------------------------------------------------- + +class _GridSample2dBackward(torch.autograd.Function): + @staticmethod + def forward(ctx, grad_output, input, grid): + op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward') + grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False) + ctx.save_for_backward(grid) + return grad_input, grad_grid + + @staticmethod + def backward(ctx, grad2_grad_input, grad2_grad_grid): + _ = grad2_grad_grid # unused + grid, = ctx.saved_tensors + grad2_grad_output = None + grad2_input = None + grad2_grid = None + + if ctx.needs_input_grad[0]: + grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid) + + assert not ctx.needs_input_grad[2] + return grad2_grad_output, grad2_input, grad2_grid + +#---------------------------------------------------------------------------- diff --git a/torch_utils/ops/upfirdn2d.cpp b/torch_utils/ops/upfirdn2d.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8262e0f063dcf919a686bdd0d69e001dbb9b4a8b --- /dev/null +++ b/torch_utils/ops/upfirdn2d.cpp @@ -0,0 +1,107 @@ +// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include +#include +#include +#include "upfirdn2d.h" + +//------------------------------------------------------------------------ + +static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain) +{ + // Validate arguments. + TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); + TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x"); + TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32"); + TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); + TORCH_CHECK(f.numel() <= INT_MAX, "f is too large"); + TORCH_CHECK(x.numel() > 0, "x has zero size"); + TORCH_CHECK(f.numel() > 0, "f has zero size"); + TORCH_CHECK(x.dim() == 4, "x must be rank 4"); + TORCH_CHECK(f.dim() == 2, "f must be rank 2"); + TORCH_CHECK((x.size(0)-1)*x.stride(0) + (x.size(1)-1)*x.stride(1) + (x.size(2)-1)*x.stride(2) + (x.size(3)-1)*x.stride(3) <= INT_MAX, "x memory footprint is too large"); + TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1"); + TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1"); + TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1"); + + // Create output tensor. + const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); + int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx; + int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy; + TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1"); + torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format()); + TORCH_CHECK(y.numel() <= INT_MAX, "output is too large"); + TORCH_CHECK((y.size(0)-1)*y.stride(0) + (y.size(1)-1)*y.stride(1) + (y.size(2)-1)*y.stride(2) + (y.size(3)-1)*y.stride(3) <= INT_MAX, "output memory footprint is too large"); + + // Initialize CUDA kernel parameters. + upfirdn2d_kernel_params p; + p.x = x.data_ptr(); + p.f = f.data_ptr(); + p.y = y.data_ptr(); + p.up = make_int2(upx, upy); + p.down = make_int2(downx, downy); + p.pad0 = make_int2(padx0, pady0); + p.flip = (flip) ? 1 : 0; + p.gain = gain; + p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0)); + p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0)); + p.filterSize = make_int2((int)f.size(1), (int)f.size(0)); + p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0)); + p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0)); + p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0)); + p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z; + p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1; + + // Choose CUDA kernel. + upfirdn2d_kernel_spec spec; + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] + { + spec = choose_upfirdn2d_kernel(p); + }); + + // Set looping options. + p.loopMajor = (p.sizeMajor - 1) / 16384 + 1; + p.loopMinor = spec.loopMinor; + p.loopX = spec.loopX; + p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1; + p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1; + + // Compute grid size. + dim3 blockSize, gridSize; + if (spec.tileOutW < 0) // large + { + blockSize = dim3(4, 32, 1); + gridSize = dim3( + ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor, + (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1, + p.launchMajor); + } + else // small + { + blockSize = dim3(256, 1, 1); + gridSize = dim3( + ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor, + (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1, + p.launchMajor); + } + + // Launch CUDA kernel. + void* args[] = {&p}; + AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); + return y; +} + +//------------------------------------------------------------------------ + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("upfirdn2d", &upfirdn2d); +} + +//------------------------------------------------------------------------ diff --git a/torch_utils/ops/upfirdn2d.cu b/torch_utils/ops/upfirdn2d.cu new file mode 100644 index 0000000000000000000000000000000000000000..84d45bacfb621591157ae573de4ba8479dc5ec76 --- /dev/null +++ b/torch_utils/ops/upfirdn2d.cu @@ -0,0 +1,384 @@ +// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include +#include "upfirdn2d.h" + +//------------------------------------------------------------------------ +// Helpers. + +template struct InternalType; +template <> struct InternalType { typedef double scalar_t; }; +template <> struct InternalType { typedef float scalar_t; }; +template <> struct InternalType { typedef float scalar_t; }; + +static __device__ __forceinline__ int floor_div(int a, int b) +{ + int t = 1 - a / b; + return (a + t * b) / b - t; +} + +//------------------------------------------------------------------------ +// Generic CUDA implementation for large filters. + +template static __global__ void upfirdn2d_kernel_large(upfirdn2d_kernel_params p) +{ + typedef typename InternalType::scalar_t scalar_t; + + // Calculate thread index. + int minorBase = blockIdx.x * blockDim.x + threadIdx.x; + int outY = minorBase / p.launchMinor; + minorBase -= outY * p.launchMinor; + int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y; + int majorBase = blockIdx.z * p.loopMajor; + if (outXBase >= p.outSize.x | outY >= p.outSize.y | majorBase >= p.sizeMajor) + return; + + // Setup Y receptive field. + int midY = outY * p.down.y + p.up.y - 1 - p.pad0.y; + int inY = min(max(floor_div(midY, p.up.y), 0), p.inSize.y); + int h = min(max(floor_div(midY + p.filterSize.y, p.up.y), 0), p.inSize.y) - inY; + int filterY = midY + p.filterSize.y - (inY + 1) * p.up.y; + if (p.flip) + filterY = p.filterSize.y - 1 - filterY; + + // Loop over major, minor, and X. + for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++) + for (int minorIdx = 0, minor = minorBase; minorIdx < p.loopMinor & minor < p.sizeMinor; minorIdx++, minor += p.launchMinor) + { + int nc = major * p.sizeMinor + minor; + int n = nc / p.inSize.z; + int c = nc - n * p.inSize.z; + for (int loopX = 0, outX = outXBase; loopX < p.loopX & outX < p.outSize.x; loopX++, outX += blockDim.y) + { + // Setup X receptive field. + int midX = outX * p.down.x + p.up.x - 1 - p.pad0.x; + int inX = min(max(floor_div(midX, p.up.x), 0), p.inSize.x); + int w = min(max(floor_div(midX + p.filterSize.x, p.up.x), 0), p.inSize.x) - inX; + int filterX = midX + p.filterSize.x - (inX + 1) * p.up.x; + if (p.flip) + filterX = p.filterSize.x - 1 - filterX; + + // Initialize pointers. + const T* xp = &((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w]; + const float* fp = &p.f[filterX * p.filterStride.x + filterY * p.filterStride.y]; + int filterStepX = ((p.flip) ? p.up.x : -p.up.x) * p.filterStride.x; + int filterStepY = ((p.flip) ? p.up.y : -p.up.y) * p.filterStride.y; + + // Inner loop. + scalar_t v = 0; + for (int y = 0; y < h; y++) + { + for (int x = 0; x < w; x++) + { + v += (scalar_t)(*xp) * (scalar_t)(*fp); + xp += p.inStride.x; + fp += filterStepX; + } + xp += p.inStride.y - w * p.inStride.x; + fp += filterStepY - w * filterStepX; + } + + // Store result. + v *= p.gain; + ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v; + } + } +} + +//------------------------------------------------------------------------ +// Specialized CUDA implementation for small filters. + +template +static __global__ void upfirdn2d_kernel_small(upfirdn2d_kernel_params p) +{ + typedef typename InternalType::scalar_t scalar_t; + const int tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1; + const int tileInH = ((tileOutH - 1) * downy + filterH - 1) / upy + 1; + __shared__ volatile scalar_t sf[filterH][filterW]; + __shared__ volatile scalar_t sx[tileInH][tileInW][loopMinor]; + + // Calculate tile index. + int minorBase = blockIdx.x; + int tileOutY = minorBase / p.launchMinor; + minorBase -= tileOutY * p.launchMinor; + minorBase *= loopMinor; + tileOutY *= tileOutH; + int tileOutXBase = blockIdx.y * p.loopX * tileOutW; + int majorBase = blockIdx.z * p.loopMajor; + if (tileOutXBase >= p.outSize.x | tileOutY >= p.outSize.y | majorBase >= p.sizeMajor) + return; + + // Load filter (flipped). + for (int tapIdx = threadIdx.x; tapIdx < filterH * filterW; tapIdx += blockDim.x) + { + int fy = tapIdx / filterW; + int fx = tapIdx - fy * filterW; + scalar_t v = 0; + if (fx < p.filterSize.x & fy < p.filterSize.y) + { + int ffx = (p.flip) ? fx : p.filterSize.x - 1 - fx; + int ffy = (p.flip) ? fy : p.filterSize.y - 1 - fy; + v = (scalar_t)p.f[ffx * p.filterStride.x + ffy * p.filterStride.y]; + } + sf[fy][fx] = v; + } + + // Loop over major and X. + for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++) + { + int baseNC = major * p.sizeMinor + minorBase; + int n = baseNC / p.inSize.z; + int baseC = baseNC - n * p.inSize.z; + for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outSize.x; loopX++, tileOutX += tileOutW) + { + // Load input pixels. + int tileMidX = tileOutX * downx + upx - 1 - p.pad0.x; + int tileMidY = tileOutY * downy + upy - 1 - p.pad0.y; + int tileInX = floor_div(tileMidX, upx); + int tileInY = floor_div(tileMidY, upy); + __syncthreads(); + for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW * loopMinor; inIdx += blockDim.x) + { + int relC = inIdx; + int relInX = relC / loopMinor; + int relInY = relInX / tileInW; + relC -= relInX * loopMinor; + relInX -= relInY * tileInW; + int c = baseC + relC; + int inX = tileInX + relInX; + int inY = tileInY + relInY; + scalar_t v = 0; + if (inX >= 0 & inY >= 0 & inX < p.inSize.x & inY < p.inSize.y & c < p.inSize.z) + v = (scalar_t)((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w]; + sx[relInY][relInX][relC] = v; + } + + // Loop over output pixels. + __syncthreads(); + for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW * loopMinor; outIdx += blockDim.x) + { + int relC = outIdx; + int relOutX = relC / loopMinor; + int relOutY = relOutX / tileOutW; + relC -= relOutX * loopMinor; + relOutX -= relOutY * tileOutW; + int c = baseC + relC; + int outX = tileOutX + relOutX; + int outY = tileOutY + relOutY; + + // Setup receptive field. + int midX = tileMidX + relOutX * downx; + int midY = tileMidY + relOutY * downy; + int inX = floor_div(midX, upx); + int inY = floor_div(midY, upy); + int relInX = inX - tileInX; + int relInY = inY - tileInY; + int filterX = (inX + 1) * upx - midX - 1; // flipped + int filterY = (inY + 1) * upy - midY - 1; // flipped + + // Inner loop. + if (outX < p.outSize.x & outY < p.outSize.y & c < p.outSize.z) + { + scalar_t v = 0; + #pragma unroll + for (int y = 0; y < filterH / upy; y++) + #pragma unroll + for (int x = 0; x < filterW / upx; x++) + v += sx[relInY + y][relInX + x][relC] * sf[filterY + y * upy][filterX + x * upx]; + v *= p.gain; + ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v; + } + } + } + } +} + +//------------------------------------------------------------------------ +// CUDA kernel selection. + +template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p) +{ + int s = p.inStride.z, fx = p.filterSize.x, fy = p.filterSize.y; + upfirdn2d_kernel_spec spec = {(void*)upfirdn2d_kernel_large, -1,-1,1, 4}; // contiguous + if (s == 1) spec = {(void*)upfirdn2d_kernel_large, -1,-1,4, 1}; // channels_last + + // No up/downsampling. + if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) + { + // contiguous + if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 64,32,1, 1}; + if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 64,32,1, 1}; + if (s != 1 && fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (s != 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (s != 1 && fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (s != 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (s != 1 && fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (s != 1 && fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (s != 1 && fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (s != 1 && fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (s != 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (s != 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (s != 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + // channels_last + if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (s == 1 && fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (s == 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (s == 1 && fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (s == 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (s == 1 && fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (s == 1 && fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (s == 1 && fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (s == 1 && fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (s == 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (s == 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (s == 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + } + + // 2x upsampling. + if (p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) + { + // contiguous + if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 64,32,1, 1}; + if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 64,32,1, 1}; + if (s != 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (s != 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (s != 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (s != 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + // channels_last + if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (s == 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (s == 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (s == 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (s == 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + } + if (p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) + { + // contiguous + if (s != 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (s != 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (s != 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + // channels_last + if (s == 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (s == 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (s == 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + } + if (p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) + { + // contiguous + if (s != 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (s != 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (s != 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + // channels_last + if (s == 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (s == 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (s == 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + } + + // 2x downsampling. + if (p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) + { + // contiguous + if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; + if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; + if (s != 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; + if (s != 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; + if (s != 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; + if (s != 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; + // channels_last + if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 16,16,1, 1}; + if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 16,16,1, 1}; + if (s == 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1}; + if (s == 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1}; + if (s == 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1}; + if (s == 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1}; + } + if (p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) + { + // contiguous + if (s != 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1}; + if (s != 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1}; + if (s != 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1}; + // channels_last + if (s == 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1}; + if (s == 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1}; + if (s == 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1}; + } + if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) + { + // contiguous + if (s != 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; + if (s != 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; + if (s != 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; + // channels_last + if (s == 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1}; + if (s == 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1}; + if (s == 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1}; + } + + // 4x upsampling. + if (p.up.x == 4 && p.up.y == 4 && p.down.x == 1 && p.down.y == 1) + { + // contiguous + if (s != 1 && fx <= 48 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small, 64,32,1, 1}; + if (s != 1 && fx <= 32 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small, 64,32,1, 1}; + // channels_last + if (s == 1 && fx <= 48 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (s == 1 && fx <= 32 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + } + if (p.up.x == 4 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) + { + // contiguous + if (s != 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (s != 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + // channels_last + if (s == 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (s == 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + } + if (p.up.x == 1 && p.up.y == 4 && p.down.x == 1 && p.down.y == 1) + { + // contiguous + if (s != 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (s != 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + // channels_last + if (s == 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (s == 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + } + + // 4x downsampling (inefficient). + if (p.up.x == 1 && p.up.y == 1 && p.down.x == 4 && p.down.y == 1) + { + // contiguous + if (s != 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; + if (s != 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; + // channels_last + if (s == 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 32,1,8, 1}; + if (s == 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 32,1,8, 1}; + } + if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 4) + { + // contiguous + if (s != 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; + if (s != 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; + // channels_last + if (s == 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small, 1,32,8, 1}; + if (s == 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small, 1,32,8, 1}; + } + return spec; +} + +//------------------------------------------------------------------------ +// Template specializations. + +template upfirdn2d_kernel_spec choose_upfirdn2d_kernel (const upfirdn2d_kernel_params& p); +template upfirdn2d_kernel_spec choose_upfirdn2d_kernel (const upfirdn2d_kernel_params& p); +template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p); + +//------------------------------------------------------------------------ diff --git a/torch_utils/ops/upfirdn2d.h b/torch_utils/ops/upfirdn2d.h new file mode 100644 index 0000000000000000000000000000000000000000..eb7db43943d43a803624d9379508dd03ae2b6cdb --- /dev/null +++ b/torch_utils/ops/upfirdn2d.h @@ -0,0 +1,59 @@ +// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include + +//------------------------------------------------------------------------ +// CUDA kernel parameters. + +struct upfirdn2d_kernel_params +{ + const void* x; + const float* f; + void* y; + + int2 up; + int2 down; + int2 pad0; + int flip; + float gain; + + int4 inSize; // [width, height, channel, batch] + int4 inStride; + int2 filterSize; // [width, height] + int2 filterStride; + int4 outSize; // [width, height, channel, batch] + int4 outStride; + int sizeMinor; + int sizeMajor; + + int loopMinor; + int loopMajor; + int loopX; + int launchMinor; + int launchMajor; +}; + +//------------------------------------------------------------------------ +// CUDA kernel specialization. + +struct upfirdn2d_kernel_spec +{ + void* kernel; + int tileOutW; + int tileOutH; + int loopMinor; + int loopX; +}; + +//------------------------------------------------------------------------ +// CUDA kernel selection. + +template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p); + +//------------------------------------------------------------------------ diff --git a/torch_utils/ops/upfirdn2d.py b/torch_utils/ops/upfirdn2d.py new file mode 100644 index 0000000000000000000000000000000000000000..504d3d5c31c3339d450bbeacf6df9db0a5f534ed --- /dev/null +++ b/torch_utils/ops/upfirdn2d.py @@ -0,0 +1,389 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Custom PyTorch ops for efficient resampling of 2D images.""" + +import os +import numpy as np +import torch + +from .. import custom_ops +from .. import misc +from . import conv2d_gradfix + +#---------------------------------------------------------------------------- + +_plugin = None + +def _init(): + global _plugin + if _plugin is None: + _plugin = custom_ops.get_plugin( + module_name='upfirdn2d_plugin', + sources=['upfirdn2d.cpp', 'upfirdn2d.cu'], + headers=['upfirdn2d.h'], + source_dir=os.path.dirname(__file__), + extra_cuda_cflags=['--use_fast_math'], + ) + return True + +def _parse_scaling(scaling): + if isinstance(scaling, int): + scaling = [scaling, scaling] + assert isinstance(scaling, (list, tuple)) + assert all(isinstance(x, int) for x in scaling) + sx, sy = scaling + assert sx >= 1 and sy >= 1 + return sx, sy + +def _parse_padding(padding): + if isinstance(padding, int): + padding = [padding, padding] + assert isinstance(padding, (list, tuple)) + assert all(isinstance(x, int) for x in padding) + if len(padding) == 2: + padx, pady = padding + padding = [padx, padx, pady, pady] + padx0, padx1, pady0, pady1 = padding + return padx0, padx1, pady0, pady1 + +def _get_filter_size(f): + if f is None: + return 1, 1 + assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] + fw = f.shape[-1] + fh = f.shape[0] + with misc.suppress_tracer_warnings(): + fw = int(fw) + fh = int(fh) + misc.assert_shape(f, [fh, fw][:f.ndim]) + assert fw >= 1 and fh >= 1 + return fw, fh + +#---------------------------------------------------------------------------- + +def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None): + r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`. + + Args: + f: Torch tensor, numpy array, or python list of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), + `[]` (impulse), or + `None` (identity). + device: Result device (default: cpu). + normalize: Normalize the filter so that it retains the magnitude + for constant input signal (DC)? (default: True). + flip_filter: Flip the filter? (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + separable: Return a separable filter? (default: select automatically). + + Returns: + Float32 tensor of the shape + `[filter_height, filter_width]` (non-separable) or + `[filter_taps]` (separable). + """ + # Validate. + if f is None: + f = 1 + f = torch.as_tensor(f, dtype=torch.float32) + assert f.ndim in [0, 1, 2] + assert f.numel() > 0 + if f.ndim == 0: + f = f[np.newaxis] + + # Separable? + if separable is None: + separable = (f.ndim == 1 and f.numel() >= 8) + if f.ndim == 1 and not separable: + f = f.ger(f) + assert f.ndim == (1 if separable else 2) + + # Apply normalize, flip, gain, and device. + if normalize: + f /= f.sum() + if flip_filter: + f = f.flip(list(range(f.ndim))) + f = f * (gain ** (f.ndim / 2)) + f = f.to(device=device) + return f + +#---------------------------------------------------------------------------- + +def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'): + r"""Pad, upsample, filter, and downsample a batch of 2D images. + + Performs the following sequence of operations for each channel: + + 1. Upsample the image by inserting N-1 zeros after each pixel (`up`). + + 2. Pad the image with the specified number of zeros on each side (`padding`). + Negative padding corresponds to cropping the image. + + 3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it + so that the footprint of all output pixels lies within the input image. + + 4. Downsample the image by keeping every Nth pixel (`down`). + + This sequence of operations bears close resemblance to scipy.signal.upfirdn(). + The fused op is considerably more efficient than performing the same calculation + using standard PyTorch ops. It supports gradients of arbitrary order. + + Args: + x: Float32/float64/float16 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. + f: Float32 FIR filter of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), or + `None` (identity). + up: Integer upsampling factor. Can be a single int or a list/tuple + `[x, y]` (default: 1). + down: Integer downsampling factor. Can be a single int or a list/tuple + `[x, y]` (default: 1). + padding: Padding with respect to the upsampled image. Can be a single number + or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + flip_filter: False = convolution, True = correlation (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. + """ + assert isinstance(x, torch.Tensor) + assert impl in ['ref', 'cuda'] + if impl == 'cuda' and x.device.type == 'cuda' and _init(): + return _upfirdn2d_cuda(up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain).apply(x, f) + return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain) + +#---------------------------------------------------------------------------- + +@misc.profiled_function +def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1): + """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops. + """ + # Validate arguments. + assert isinstance(x, torch.Tensor) and x.ndim == 4 + if f is None: + f = torch.ones([1, 1], dtype=torch.float32, device=x.device) + assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] + assert f.dtype == torch.float32 and not f.requires_grad + batch_size, num_channels, in_height, in_width = x.shape + upx, upy = _parse_scaling(up) + downx, downy = _parse_scaling(down) + padx0, padx1, pady0, pady1 = _parse_padding(padding) + + # Check that upsampled buffer is not smaller than the filter. + upW = in_width * upx + padx0 + padx1 + upH = in_height * upy + pady0 + pady1 + assert upW >= f.shape[-1] and upH >= f.shape[0] + + # Upsample by inserting zeros. + x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1]) + x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1]) + x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx]) + + # Pad or crop. + x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)]) + x = x[:, :, max(-pady0, 0) : x.shape[2] - max(-pady1, 0), max(-padx0, 0) : x.shape[3] - max(-padx1, 0)] + + # Setup filter. + f = f * (gain ** (f.ndim / 2)) + f = f.to(x.dtype) + if not flip_filter: + f = f.flip(list(range(f.ndim))) + + # Convolve with the filter. + f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim) + if f.ndim == 4: + x = conv2d_gradfix.conv2d(input=x, weight=f, groups=num_channels) + else: + x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels) + x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels) + + # Downsample by throwing away pixels. + x = x[:, :, ::downy, ::downx] + return x + +#---------------------------------------------------------------------------- + +_upfirdn2d_cuda_cache = dict() + +def _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1): + """Fast CUDA implementation of `upfirdn2d()` using custom ops. + """ + # Parse arguments. + upx, upy = _parse_scaling(up) + downx, downy = _parse_scaling(down) + padx0, padx1, pady0, pady1 = _parse_padding(padding) + + # Lookup from cache. + key = (upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain) + if key in _upfirdn2d_cuda_cache: + return _upfirdn2d_cuda_cache[key] + + # Forward op. + class Upfirdn2dCuda(torch.autograd.Function): + @staticmethod + def forward(ctx, x, f): # pylint: disable=arguments-differ + assert isinstance(x, torch.Tensor) and x.ndim == 4 + if f is None: + f = torch.ones([1, 1], dtype=torch.float32, device=x.device) + if f.ndim == 1 and f.shape[0] == 1: + f = f.square().unsqueeze(0) # Convert separable-1 into full-1x1. + assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] + y = x + if f.ndim == 2: + y = _plugin.upfirdn2d(y, f, upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain) + else: + y = _plugin.upfirdn2d(y, f.unsqueeze(0), upx, 1, downx, 1, padx0, padx1, 0, 0, flip_filter, 1.0) + y = _plugin.upfirdn2d(y, f.unsqueeze(1), 1, upy, 1, downy, 0, 0, pady0, pady1, flip_filter, gain) + ctx.save_for_backward(f) + ctx.x_shape = x.shape + return y + + @staticmethod + def backward(ctx, dy): # pylint: disable=arguments-differ + f, = ctx.saved_tensors + _, _, ih, iw = ctx.x_shape + _, _, oh, ow = dy.shape + fw, fh = _get_filter_size(f) + p = [ + fw - padx0 - 1, + iw * upx - ow * downx + padx0 - upx + 1, + fh - pady0 - 1, + ih * upy - oh * downy + pady0 - upy + 1, + ] + dx = None + df = None + + if ctx.needs_input_grad[0]: + dx = _upfirdn2d_cuda(up=down, down=up, padding=p, flip_filter=(not flip_filter), gain=gain).apply(dy, f) + + assert not ctx.needs_input_grad[1] + return dx, df + + # Add to cache. + _upfirdn2d_cuda_cache[key] = Upfirdn2dCuda + return Upfirdn2dCuda + +#---------------------------------------------------------------------------- + +def filter2d(x, f, padding=0, flip_filter=False, gain=1, impl='cuda'): + r"""Filter a batch of 2D images using the given 2D FIR filter. + + By default, the result is padded so that its shape matches the input. + User-specified padding is applied on top of that, with negative values + indicating cropping. Pixels outside the image are assumed to be zero. + + Args: + x: Float32/float64/float16 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. + f: Float32 FIR filter of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), or + `None` (identity). + padding: Padding with respect to the output. Can be a single number or a + list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + flip_filter: False = convolution, True = correlation (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. + """ + padx0, padx1, pady0, pady1 = _parse_padding(padding) + fw, fh = _get_filter_size(f) + p = [ + padx0 + fw // 2, + padx1 + (fw - 1) // 2, + pady0 + fh // 2, + pady1 + (fh - 1) // 2, + ] + return upfirdn2d(x, f, padding=p, flip_filter=flip_filter, gain=gain, impl=impl) + +#---------------------------------------------------------------------------- + +def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'): + r"""Upsample a batch of 2D images using the given 2D FIR filter. + + By default, the result is padded so that its shape is a multiple of the input. + User-specified padding is applied on top of that, with negative values + indicating cropping. Pixels outside the image are assumed to be zero. + + Args: + x: Float32/float64/float16 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. + f: Float32 FIR filter of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), or + `None` (identity). + up: Integer upsampling factor. Can be a single int or a list/tuple + `[x, y]` (default: 1). + padding: Padding with respect to the output. Can be a single number or a + list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + flip_filter: False = convolution, True = correlation (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. + """ + upx, upy = _parse_scaling(up) + padx0, padx1, pady0, pady1 = _parse_padding(padding) + fw, fh = _get_filter_size(f) + p = [ + padx0 + (fw + upx - 1) // 2, + padx1 + (fw - upx) // 2, + pady0 + (fh + upy - 1) // 2, + pady1 + (fh - upy) // 2, + ] + return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain*upx*upy, impl=impl) + +#---------------------------------------------------------------------------- + +def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'): + r"""Downsample a batch of 2D images using the given 2D FIR filter. + + By default, the result is padded so that its shape is a fraction of the input. + User-specified padding is applied on top of that, with negative values + indicating cropping. Pixels outside the image are assumed to be zero. + + Args: + x: Float32/float64/float16 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. + f: Float32 FIR filter of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), or + `None` (identity). + down: Integer downsampling factor. Can be a single int or a list/tuple + `[x, y]` (default: 1). + padding: Padding with respect to the input. Can be a single number or a + list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + flip_filter: False = convolution, True = correlation (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. + """ + downx, downy = _parse_scaling(down) + padx0, padx1, pady0, pady1 = _parse_padding(padding) + fw, fh = _get_filter_size(f) + p = [ + padx0 + (fw - downx + 1) // 2, + padx1 + (fw - downx) // 2, + pady0 + (fh - downy + 1) // 2, + pady1 + (fh - downy) // 2, + ] + return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl) + +#---------------------------------------------------------------------------- diff --git a/torch_utils/persistence.py b/torch_utils/persistence.py new file mode 100644 index 0000000000000000000000000000000000000000..a61fa032cf356df27df1acfcc62fc2201aea3974 --- /dev/null +++ b/torch_utils/persistence.py @@ -0,0 +1,251 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Facilities for pickling Python code alongside other data. + +The pickled code is automatically imported into a separate Python module +during unpickling. This way, any previously exported pickles will remain +usable even if the original code is no longer available, or if the current +version of the code is not consistent with what was originally pickled.""" + +import sys +import pickle +import io +import inspect +import copy +import uuid +import types +import dnnlib + +#---------------------------------------------------------------------------- + +_version = 6 # internal version number +_decorators = set() # {decorator_class, ...} +_import_hooks = [] # [hook_function, ...] +_module_to_src_dict = dict() # {module: src, ...} +_src_to_module_dict = dict() # {src: module, ...} + +#---------------------------------------------------------------------------- + +def persistent_class(orig_class): + r"""Class decorator that extends a given class to save its source code + when pickled. + + Example: + + from torch_utils import persistence + + @persistence.persistent_class + class MyNetwork(torch.nn.Module): + def __init__(self, num_inputs, num_outputs): + super().__init__() + self.fc = MyLayer(num_inputs, num_outputs) + ... + + @persistence.persistent_class + class MyLayer(torch.nn.Module): + ... + + When pickled, any instance of `MyNetwork` and `MyLayer` will save its + source code alongside other internal state (e.g., parameters, buffers, + and submodules). This way, any previously exported pickle will remain + usable even if the class definitions have been modified or are no + longer available. + + The decorator saves the source code of the entire Python module + containing the decorated class. It does *not* save the source code of + any imported modules. Thus, the imported modules must be available + during unpickling, also including `torch_utils.persistence` itself. + + It is ok to call functions defined in the same module from the + decorated class. However, if the decorated class depends on other + classes defined in the same module, they must be decorated as well. + This is illustrated in the above example in the case of `MyLayer`. + + It is also possible to employ the decorator just-in-time before + calling the constructor. For example: + + cls = MyLayer + if want_to_make_it_persistent: + cls = persistence.persistent_class(cls) + layer = cls(num_inputs, num_outputs) + + As an additional feature, the decorator also keeps track of the + arguments that were used to construct each instance of the decorated + class. The arguments can be queried via `obj.init_args` and + `obj.init_kwargs`, and they are automatically pickled alongside other + object state. A typical use case is to first unpickle a previous + instance of a persistent class, and then upgrade it to use the latest + version of the source code: + + with open('old_pickle.pkl', 'rb') as f: + old_net = pickle.load(f) + new_net = MyNetwork(*old_obj.init_args, **old_obj.init_kwargs) + misc.copy_params_and_buffers(old_net, new_net, require_all=True) + """ + assert isinstance(orig_class, type) + if is_persistent(orig_class): + return orig_class + + assert orig_class.__module__ in sys.modules + orig_module = sys.modules[orig_class.__module__] + orig_module_src = _module_to_src(orig_module) + + class Decorator(orig_class): + _orig_module_src = orig_module_src + _orig_class_name = orig_class.__name__ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._init_args = copy.deepcopy(args) + self._init_kwargs = copy.deepcopy(kwargs) + assert orig_class.__name__ in orig_module.__dict__ + _check_pickleable(self.__reduce__()) + + @property + def init_args(self): + return copy.deepcopy(self._init_args) + + @property + def init_kwargs(self): + return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs)) + + def __reduce__(self): + fields = list(super().__reduce__()) + fields += [None] * max(3 - len(fields), 0) + if fields[0] is not _reconstruct_persistent_obj: + meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2]) + fields[0] = _reconstruct_persistent_obj # reconstruct func + fields[1] = (meta,) # reconstruct args + fields[2] = None # state dict + return tuple(fields) + + Decorator.__name__ = orig_class.__name__ + _decorators.add(Decorator) + return Decorator + +#---------------------------------------------------------------------------- + +def is_persistent(obj): + r"""Test whether the given object or class is persistent, i.e., + whether it will save its source code when pickled. + """ + try: + if obj in _decorators: + return True + except TypeError: + pass + return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck + +#---------------------------------------------------------------------------- + +def import_hook(hook): + r"""Register an import hook that is called whenever a persistent object + is being unpickled. A typical use case is to patch the pickled source + code to avoid errors and inconsistencies when the API of some imported + module has changed. + + The hook should have the following signature: + + hook(meta) -> modified meta + + `meta` is an instance of `dnnlib.EasyDict` with the following fields: + + type: Type of the persistent object, e.g. `'class'`. + version: Internal version number of `torch_utils.persistence`. + module_src Original source code of the Python module. + class_name: Class name in the original Python module. + state: Internal state of the object. + + Example: + + @persistence.import_hook + def wreck_my_network(meta): + if meta.class_name == 'MyNetwork': + print('MyNetwork is being imported. I will wreck it!') + meta.module_src = meta.module_src.replace("True", "False") + return meta + """ + assert callable(hook) + _import_hooks.append(hook) + +#---------------------------------------------------------------------------- + +def _reconstruct_persistent_obj(meta): + r"""Hook that is called internally by the `pickle` module to unpickle + a persistent object. + """ + meta = dnnlib.EasyDict(meta) + meta.state = dnnlib.EasyDict(meta.state) + for hook in _import_hooks: + meta = hook(meta) + assert meta is not None + + assert meta.version == _version + module = _src_to_module(meta.module_src) + + assert meta.type == 'class' + orig_class = module.__dict__[meta.class_name] + decorator_class = persistent_class(orig_class) + obj = decorator_class.__new__(decorator_class) + + setstate = getattr(obj, '__setstate__', None) + if callable(setstate): + setstate(meta.state) # pylint: disable=not-callable + else: + obj.__dict__.update(meta.state) + return obj + +#---------------------------------------------------------------------------- + +def _module_to_src(module): + r"""Query the source code of a given Python module. + """ + src = _module_to_src_dict.get(module, None) + if src is None: + src = inspect.getsource(module) + _module_to_src_dict[module] = src + _src_to_module_dict[src] = module + return src + +def _src_to_module(src): + r"""Get or create a Python module for the given source code. + """ + module = _src_to_module_dict.get(src, None) + if module is None: + module_name = "_imported_module_" + uuid.uuid4().hex + module = types.ModuleType(module_name) + sys.modules[module_name] = module + _module_to_src_dict[module] = src + _src_to_module_dict[src] = module + exec(src, module.__dict__) # pylint: disable=exec-used + return module + +#---------------------------------------------------------------------------- + +def _check_pickleable(obj): + r"""Check that the given object is pickleable, raising an exception if + it is not. This function is expected to be considerably more efficient + than actually pickling the object. + """ + def recurse(obj): + if isinstance(obj, (list, tuple, set)): + return [recurse(x) for x in obj] + if isinstance(obj, dict): + return [[recurse(x), recurse(y)] for x, y in obj.items()] + if isinstance(obj, (str, int, float, bool, bytes, bytearray)): + return None # Python primitive types are pickleable. + if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor', 'torch.nn.parameter.Parameter']: + return None # NumPy arrays and PyTorch tensors are pickleable. + if is_persistent(obj): + return None # Persistent objects are pickleable, by virtue of the constructor check. + return obj + with io.BytesIO() as f: + pickle.dump(recurse(obj), f) + +#---------------------------------------------------------------------------- diff --git a/torch_utils/training_stats.py b/torch_utils/training_stats.py new file mode 100644 index 0000000000000000000000000000000000000000..64e7835210d51d923e6a45240d27020a20e219de --- /dev/null +++ b/torch_utils/training_stats.py @@ -0,0 +1,268 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Facilities for reporting and collecting training statistics across +multiple processes and devices. The interface is designed to minimize +synchronization overhead as well as the amount of boilerplate in user +code.""" + +import re +import numpy as np +import torch +import dnnlib + +from . import misc + +#---------------------------------------------------------------------------- + +_num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares] +_reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction. +_counter_dtype = torch.float64 # Data type to use for the internal counters. +_rank = 0 # Rank of the current process. +_sync_device = None # Device to use for multiprocess communication. None = single-process. +_sync_called = False # Has _sync() been called yet? +_counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor +_cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor + +#---------------------------------------------------------------------------- + +def init_multiprocessing(rank, sync_device): + r"""Initializes `torch_utils.training_stats` for collecting statistics + across multiple processes. + + This function must be called after + `torch.distributed.init_process_group()` and before `Collector.update()`. + The call is not necessary if multi-process collection is not needed. + + Args: + rank: Rank of the current process. + sync_device: PyTorch device to use for inter-process + communication, or None to disable multi-process + collection. Typically `torch.device('cuda', rank)`. + """ + global _rank, _sync_device + assert not _sync_called + _rank = rank + _sync_device = sync_device + +#---------------------------------------------------------------------------- + +@misc.profiled_function +def report(name, value): + r"""Broadcasts the given set of scalars to all interested instances of + `Collector`, across device and process boundaries. + + This function is expected to be extremely cheap and can be safely + called from anywhere in the training loop, loss function, or inside a + `torch.nn.Module`. + + Warning: The current implementation expects the set of unique names to + be consistent across processes. Please make sure that `report()` is + called at least once for each unique name by each process, and in the + same order. If a given process has no scalars to broadcast, it can do + `report(name, [])` (empty list). + + Args: + name: Arbitrary string specifying the name of the statistic. + Averages are accumulated separately for each unique name. + value: Arbitrary set of scalars. Can be a list, tuple, + NumPy array, PyTorch tensor, or Python scalar. + + Returns: + The same `value` that was passed in. + """ + if name not in _counters: + _counters[name] = dict() + + elems = torch.as_tensor(value) + if elems.numel() == 0: + return value + + elems = elems.detach().flatten().to(_reduce_dtype) + moments = torch.stack([ + torch.ones_like(elems).sum(), + elems.sum(), + elems.square().sum(), + ]) + assert moments.ndim == 1 and moments.shape[0] == _num_moments + moments = moments.to(_counter_dtype) + + device = moments.device + if device not in _counters[name]: + _counters[name][device] = torch.zeros_like(moments) + _counters[name][device].add_(moments) + return value + +#---------------------------------------------------------------------------- + +def report0(name, value): + r"""Broadcasts the given set of scalars by the first process (`rank = 0`), + but ignores any scalars provided by the other processes. + See `report()` for further details. + """ + report(name, value if _rank == 0 else []) + return value + +#---------------------------------------------------------------------------- + +class Collector: + r"""Collects the scalars broadcasted by `report()` and `report0()` and + computes their long-term averages (mean and standard deviation) over + user-defined periods of time. + + The averages are first collected into internal counters that are not + directly visible to the user. They are then copied to the user-visible + state as a result of calling `update()` and can then be queried using + `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the + internal counters for the next round, so that the user-visible state + effectively reflects averages collected between the last two calls to + `update()`. + + Args: + regex: Regular expression defining which statistics to + collect. The default is to collect everything. + keep_previous: Whether to retain the previous averages if no + scalars were collected on a given round + (default: True). + """ + def __init__(self, regex='.*', keep_previous=True): + self._regex = re.compile(regex) + self._keep_previous = keep_previous + self._cumulative = dict() + self._moments = dict() + self.update() + self._moments.clear() + + def names(self): + r"""Returns the names of all statistics broadcasted so far that + match the regular expression specified at construction time. + """ + return [name for name in _counters if self._regex.fullmatch(name)] + + def update(self): + r"""Copies current values of the internal counters to the + user-visible state and resets them for the next round. + + If `keep_previous=True` was specified at construction time, the + operation is skipped for statistics that have received no scalars + since the last update, retaining their previous averages. + + This method performs a number of GPU-to-CPU transfers and one + `torch.distributed.all_reduce()`. It is intended to be called + periodically in the main training loop, typically once every + N training steps. + """ + if not self._keep_previous: + self._moments.clear() + for name, cumulative in _sync(self.names()): + if name not in self._cumulative: + self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype) + delta = cumulative - self._cumulative[name] + self._cumulative[name].copy_(cumulative) + if float(delta[0]) != 0: + self._moments[name] = delta + + def _get_delta(self, name): + r"""Returns the raw moments that were accumulated for the given + statistic between the last two calls to `update()`, or zero if + no scalars were collected. + """ + assert self._regex.fullmatch(name) + if name not in self._moments: + self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype) + return self._moments[name] + + def num(self, name): + r"""Returns the number of scalars that were accumulated for the given + statistic between the last two calls to `update()`, or zero if + no scalars were collected. + """ + delta = self._get_delta(name) + return int(delta[0]) + + def mean(self, name): + r"""Returns the mean of the scalars that were accumulated for the + given statistic between the last two calls to `update()`, or NaN if + no scalars were collected. + """ + delta = self._get_delta(name) + if int(delta[0]) == 0: + return float('nan') + return float(delta[1] / delta[0]) + + def std(self, name): + r"""Returns the standard deviation of the scalars that were + accumulated for the given statistic between the last two calls to + `update()`, or NaN if no scalars were collected. + """ + delta = self._get_delta(name) + if int(delta[0]) == 0 or not np.isfinite(float(delta[1])): + return float('nan') + if int(delta[0]) == 1: + return float(0) + mean = float(delta[1] / delta[0]) + raw_var = float(delta[2] / delta[0]) + return np.sqrt(max(raw_var - np.square(mean), 0)) + + def as_dict(self): + r"""Returns the averages accumulated between the last two calls to + `update()` as an `dnnlib.EasyDict`. The contents are as follows: + + dnnlib.EasyDict( + NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT), + ... + ) + """ + stats = dnnlib.EasyDict() + for name in self.names(): + stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name)) + return stats + + def __getitem__(self, name): + r"""Convenience getter. + `collector[name]` is a synonym for `collector.mean(name)`. + """ + return self.mean(name) + +#---------------------------------------------------------------------------- + +def _sync(names): + r"""Synchronize the global cumulative counters across devices and + processes. Called internally by `Collector.update()`. + """ + if len(names) == 0: + return [] + global _sync_called + _sync_called = True + + # Collect deltas within current rank. + deltas = [] + device = _sync_device if _sync_device is not None else torch.device('cpu') + for name in names: + delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device) + for counter in _counters[name].values(): + delta.add_(counter.to(device)) + counter.copy_(torch.zeros_like(counter)) + deltas.append(delta) + deltas = torch.stack(deltas) + + # Sum deltas across ranks. + if _sync_device is not None: + torch.distributed.all_reduce(deltas) + + # Update cumulative values. + deltas = deltas.cpu() + for idx, name in enumerate(names): + if name not in _cumulative: + _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype) + _cumulative[name].add_(deltas[idx]) + + # Return name-value pairs. + return [(name, _cumulative[name]) for name in names] + +#---------------------------------------------------------------------------- diff --git a/train.py b/train.py new file mode 100644 index 0000000000000000000000000000000000000000..7f055b300f7b03747bb237b2458c067dd6512379 --- /dev/null +++ b/train.py @@ -0,0 +1,288 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Train a GAN using the techniques described in the paper +"Alias-Free Generative Adversarial Networks".""" + +import os +import click +import re +import json +import tempfile +import torch + +import dnnlib +from training import training_loop +from metrics import metric_main +from torch_utils import training_stats +from torch_utils import custom_ops + +#---------------------------------------------------------------------------- + +def subprocess_fn(rank, c, temp_dir): + dnnlib.util.Logger(file_name=os.path.join(c.run_dir, 'log.txt'), file_mode='a', should_flush=True) + + # Init torch.distributed. + if c.num_gpus > 1: + init_file = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init')) + if os.name == 'nt': + init_method = 'file:///' + init_file.replace('\\', '/') + torch.distributed.init_process_group(backend='gloo', init_method=init_method, rank=rank, world_size=c.num_gpus) + else: + init_method = f'file://{init_file}' + torch.distributed.init_process_group(backend='nccl', init_method=init_method, rank=rank, world_size=c.num_gpus) + + # Init torch_utils. + sync_device = torch.device('cuda', rank) if c.num_gpus > 1 else None + training_stats.init_multiprocessing(rank=rank, sync_device=sync_device) + if rank != 0: + custom_ops.verbosity = 'none' + + # Execute training loop. + training_loop.training_loop(rank=rank, **c) + +#---------------------------------------------------------------------------- + +def launch_training(c, desc, outdir, dry_run): + dnnlib.util.Logger(should_flush=True) + + # Pick output directory. + prev_run_dirs = [] + if os.path.isdir(outdir): + prev_run_dirs = [x for x in os.listdir(outdir) if os.path.isdir(os.path.join(outdir, x))] + prev_run_ids = [re.match(r'^\d+', x) for x in prev_run_dirs] + prev_run_ids = [int(x.group()) for x in prev_run_ids if x is not None] + cur_run_id = max(prev_run_ids, default=-1) + 1 + c.run_dir = os.path.join(outdir, f'{cur_run_id:05d}-{desc}') + assert not os.path.exists(c.run_dir) + + # Print options. + print() + print('Training options:') + print(json.dumps(c, indent=2)) + print() + print(f'Output directory: {c.run_dir}') + print(f'Number of GPUs: {c.num_gpus}') + print(f'Batch size: {c.batch_size} images') + print(f'Training duration: {c.total_kimg} kimg') + print(f'Dataset path: {c.training_set_kwargs.path}') + print(f'Dataset size: {c.training_set_kwargs.max_size} images') + print(f'Dataset resolution: {c.training_set_kwargs.resolution}') + print(f'Dataset labels: {c.training_set_kwargs.use_labels}') + print(f'Dataset x-flips: {c.training_set_kwargs.xflip}') + print() + + # Dry run? + if dry_run: + print('Dry run; exiting.') + return + + # Create output directory. + print('Creating output directory...') + os.makedirs(c.run_dir) + with open(os.path.join(c.run_dir, 'training_options.json'), 'wt') as f: + json.dump(c, f, indent=2) + + # Launch processes. + print('Launching processes...') + torch.multiprocessing.set_start_method('spawn') + with tempfile.TemporaryDirectory() as temp_dir: + if c.num_gpus == 1: + subprocess_fn(rank=0, c=c, temp_dir=temp_dir) + else: + torch.multiprocessing.spawn(fn=subprocess_fn, args=(c, temp_dir), nprocs=c.num_gpus) + +#---------------------------------------------------------------------------- + +def init_dataset_kwargs(data): + try: + dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageFolderDataset', path=data, use_labels=True, max_size=None, xflip=False) + dataset_obj = dnnlib.util.construct_class_by_name(**dataset_kwargs) # Subclass of training.dataset.Dataset. + dataset_kwargs.resolution = dataset_obj.resolution # Be explicit about resolution. + dataset_kwargs.use_labels = dataset_obj.has_labels # Be explicit about labels. + dataset_kwargs.max_size = len(dataset_obj) # Be explicit about dataset size. + return dataset_kwargs, dataset_obj.name + except IOError as err: + raise click.ClickException(f'--data: {err}') + +#---------------------------------------------------------------------------- + +def parse_comma_separated_list(s): + if isinstance(s, list): + return s + if s is None or s.lower() == 'none' or s == '': + return [] + return s.split(',') + +#---------------------------------------------------------------------------- + +@click.command() + +# Required. +@click.option('--outdir', help='Where to save the results', metavar='DIR', required=True) +@click.option('--cfg', help='Base configuration', type=click.Choice(['stylegan3-t', 'stylegan3-r', 'stylegan2']), required=True) +@click.option('--data', help='Training data', metavar='[ZIP|DIR]', type=str, required=True) +@click.option('--gpus', help='Number of GPUs to use', metavar='INT', type=click.IntRange(min=1), required=True) +@click.option('--batch', help='Total batch size', metavar='INT', type=click.IntRange(min=1), required=True) +@click.option('--gamma', help='R1 regularization weight', metavar='FLOAT', type=click.FloatRange(min=0), required=True) + +# Optional features. +@click.option('--cond', help='Train conditional model', metavar='BOOL', type=bool, default=False, show_default=True) +@click.option('--mirror', help='Enable dataset x-flips', metavar='BOOL', type=bool, default=False, show_default=True) +@click.option('--aug', help='Augmentation mode', type=click.Choice(['noaug', 'ada', 'fixed']), default='ada', show_default=True) +@click.option('--resume', help='Resume from given network pickle', metavar='[PATH|URL]', type=str) +@click.option('--freezed', help='Freeze first layers of D', metavar='INT', type=click.IntRange(min=0), default=0, show_default=True) + +# Misc hyperparameters. +@click.option('--p', help='Probability for --aug=fixed', metavar='FLOAT', type=click.FloatRange(min=0, max=1), default=0.2, show_default=True) +@click.option('--target', help='Target value for --aug=ada', metavar='FLOAT', type=click.FloatRange(min=0, max=1), default=0.6, show_default=True) +@click.option('--batch-gpu', help='Limit batch size per GPU', metavar='INT', type=click.IntRange(min=1)) +@click.option('--cbase', help='Capacity multiplier', metavar='INT', type=click.IntRange(min=1), default=32768, show_default=True) +@click.option('--cmax', help='Max. feature maps', metavar='INT', type=click.IntRange(min=1), default=512, show_default=True) +@click.option('--glr', help='G learning rate [default: varies]', metavar='FLOAT', type=click.FloatRange(min=0)) +@click.option('--dlr', help='D learning rate', metavar='FLOAT', type=click.FloatRange(min=0), default=0.002, show_default=True) +@click.option('--map-depth', help='Mapping network depth [default: varies]', metavar='INT', type=click.IntRange(min=1)) +@click.option('--mbstd-group', help='Minibatch std group size', metavar='INT', type=click.IntRange(min=1), default=4, show_default=True) + +# Misc settings. +@click.option('--desc', help='String to include in result dir name', metavar='STR', type=str) +@click.option('--metrics', help='Quality metrics', metavar='[NAME|A,B,C|none]', type=parse_comma_separated_list, default='fid50k_full', show_default=True) +@click.option('--kimg', help='Total training duration', metavar='KIMG', type=click.IntRange(min=1), default=25000, show_default=True) +@click.option('--tick', help='How often to print progress', metavar='KIMG', type=click.IntRange(min=1), default=4, show_default=True) +@click.option('--snap', help='How often to save snapshots', metavar='TICKS', type=click.IntRange(min=1), default=50, show_default=True) +@click.option('--seed', help='Random seed', metavar='INT', type=click.IntRange(min=0), default=0, show_default=True) +@click.option('--fp32', help='Disable mixed-precision', metavar='BOOL', type=bool, default=False, show_default=True) +@click.option('--nobench', help='Disable cuDNN benchmarking', metavar='BOOL', type=bool, default=False, show_default=True) +@click.option('--workers', help='DataLoader worker processes', metavar='INT', type=click.IntRange(min=1), default=3, show_default=True) +@click.option('-n','--dry-run', help='Print training options and exit', is_flag=True) + +def main(**kwargs): + """Train a GAN using the techniques described in the paper + "Alias-Free Generative Adversarial Networks". + + Examples: + + \b + # Train StyleGAN3-T for AFHQv2 using 8 GPUs. + python train.py --outdir=~/training-runs --cfg=stylegan3-t --data=~/datasets/afhqv2-512x512.zip \\ + --gpus=8 --batch=32 --gamma=8.2 --mirror=1 + + \b + # Fine-tune StyleGAN3-R for MetFaces-U using 1 GPU, starting from the pre-trained FFHQ-U pickle. + python train.py --outdir=~/training-runs --cfg=stylegan3-r --data=~/datasets/metfacesu-1024x1024.zip \\ + --gpus=8 --batch=32 --gamma=6.6 --mirror=1 --kimg=5000 --snap=5 \\ + --resume=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-ffhqu-1024x1024.pkl + + \b + # Train StyleGAN2 for FFHQ at 1024x1024 resolution using 8 GPUs. + python train.py --outdir=~/training-runs --cfg=stylegan2 --data=~/datasets/ffhq-1024x1024.zip \\ + --gpus=8 --batch=32 --gamma=10 --mirror=1 --aug=noaug + """ + + # Initialize config. + opts = dnnlib.EasyDict(kwargs) # Command line arguments. + c = dnnlib.EasyDict() # Main config dict. + c.G_kwargs = dnnlib.EasyDict(class_name=None, z_dim=512, w_dim=512, mapping_kwargs=dnnlib.EasyDict()) + c.D_kwargs = dnnlib.EasyDict(class_name='training.networks_stylegan2.Discriminator', block_kwargs=dnnlib.EasyDict(), mapping_kwargs=dnnlib.EasyDict(), epilogue_kwargs=dnnlib.EasyDict()) + c.G_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', betas=[0,0.99], eps=1e-8) + c.D_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', betas=[0,0.99], eps=1e-8) + c.loss_kwargs = dnnlib.EasyDict(class_name='training.loss.StyleGAN2Loss') + c.data_loader_kwargs = dnnlib.EasyDict(pin_memory=True, prefetch_factor=2) + + # Training set. + c.training_set_kwargs, dataset_name = init_dataset_kwargs(data=opts.data) + if opts.cond and not c.training_set_kwargs.use_labels: + raise click.ClickException('--cond=True requires labels specified in dataset.json') + c.training_set_kwargs.use_labels = opts.cond + c.training_set_kwargs.xflip = opts.mirror + + # Hyperparameters & settings. + c.num_gpus = opts.gpus + c.batch_size = opts.batch + c.batch_gpu = opts.batch_gpu or opts.batch // opts.gpus + c.G_kwargs.channel_base = c.D_kwargs.channel_base = opts.cbase + c.G_kwargs.channel_max = c.D_kwargs.channel_max = opts.cmax + c.G_kwargs.mapping_kwargs.num_layers = (8 if opts.cfg == 'stylegan2' else 2) if opts.map_depth is None else opts.map_depth + c.D_kwargs.block_kwargs.freeze_layers = opts.freezed + c.D_kwargs.epilogue_kwargs.mbstd_group_size = opts.mbstd_group + c.loss_kwargs.r1_gamma = opts.gamma + c.G_opt_kwargs.lr = (0.002 if opts.cfg == 'stylegan2' else 0.0025) if opts.glr is None else opts.glr + c.D_opt_kwargs.lr = opts.dlr + c.metrics = opts.metrics + c.total_kimg = opts.kimg + c.kimg_per_tick = opts.tick + c.image_snapshot_ticks = c.network_snapshot_ticks = opts.snap + c.random_seed = c.training_set_kwargs.random_seed = opts.seed + c.data_loader_kwargs.num_workers = opts.workers + + # Sanity checks. + if c.batch_size % c.num_gpus != 0: + raise click.ClickException('--batch must be a multiple of --gpus') + if c.batch_size % (c.num_gpus * c.batch_gpu) != 0: + raise click.ClickException('--batch must be a multiple of --gpus times --batch-gpu') + if c.batch_gpu < c.D_kwargs.epilogue_kwargs.mbstd_group_size: + raise click.ClickException('--batch-gpu cannot be smaller than --mbstd') + if any(not metric_main.is_valid_metric(metric) for metric in c.metrics): + raise click.ClickException('\n'.join(['--metrics can only contain the following values:'] + metric_main.list_valid_metrics())) + + # Base configuration. + c.ema_kimg = c.batch_size * 10 / 32 + if opts.cfg == 'stylegan2': + c.G_kwargs.class_name = 'training.networks_stylegan2.Generator' + c.loss_kwargs.style_mixing_prob = 0.9 # Enable style mixing regularization. + c.loss_kwargs.pl_weight = 2 # Enable path length regularization. + c.G_reg_interval = 4 # Enable lazy regularization for G. + c.G_kwargs.fused_modconv_default = 'inference_only' # Speed up training by using regular convolutions instead of grouped convolutions. + c.loss_kwargs.pl_no_weight_grad = True # Speed up path length regularization by skipping gradient computation wrt. conv2d weights. + else: + c.G_kwargs.class_name = 'training.networks_stylegan3.Generator' + c.G_kwargs.magnitude_ema_beta = 0.5 ** (c.batch_size / (20 * 1e3)) + if opts.cfg == 'stylegan3-r': + c.G_kwargs.conv_kernel = 1 # Use 1x1 convolutions. + c.G_kwargs.channel_base *= 2 # Double the number of feature maps. + c.G_kwargs.channel_max *= 2 + c.G_kwargs.use_radial_filters = True # Use radially symmetric downsampling filters. + c.loss_kwargs.blur_init_sigma = 10 # Blur the images seen by the discriminator. + c.loss_kwargs.blur_fade_kimg = c.batch_size * 200 / 32 # Fade out the blur during the first N kimg. + + # Augmentation. + if opts.aug != 'noaug': + c.augment_kwargs = dnnlib.EasyDict(class_name='training.augment.AugmentPipe', xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1) + if opts.aug == 'ada': + c.ada_target = opts.target + if opts.aug == 'fixed': + c.augment_p = opts.p + + # Resume. + if opts.resume is not None: + c.resume_pkl = opts.resume + c.ada_kimg = 100 # Make ADA react faster at the beginning. + c.ema_rampup = None # Disable EMA rampup. + c.loss_kwargs.blur_init_sigma = 0 # Disable blur rampup. + + # Performance-related toggles. + if opts.fp32: + c.G_kwargs.num_fp16_res = c.D_kwargs.num_fp16_res = 0 + c.G_kwargs.conv_clamp = c.D_kwargs.conv_clamp = None + if opts.nobench: + c.cudnn_benchmark = False + + # Description string. + desc = f'{opts.cfg:s}-{dataset_name:s}-gpus{c.num_gpus:d}-batch{c.batch_size:d}-gamma{c.loss_kwargs.r1_gamma:g}' + if opts.desc is not None: + desc += f'-{opts.desc}' + + # Launch. + launch_training(c=c, desc=desc, outdir=opts.outdir, dry_run=opts.dry_run) + +#---------------------------------------------------------------------------- + +if __name__ == "__main__": + main() # pylint: disable=no-value-for-parameter + +#---------------------------------------------------------------------------- diff --git a/training/__init__.py b/training/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8dd34882519598c472f1224cfe68c9ff6952ce69 --- /dev/null +++ b/training/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +# empty diff --git a/training/augment.py b/training/augment.py new file mode 100644 index 0000000000000000000000000000000000000000..c5a4b02ffd6079b089f8f0fe7f5d86703500215d --- /dev/null +++ b/training/augment.py @@ -0,0 +1,436 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Augmentation pipeline from the paper +"Training Generative Adversarial Networks with Limited Data". +Matches the original implementation by Karras et al. at +https://github.com/NVlabs/stylegan2-ada/blob/main/training/augment.py""" + +import numpy as np +import scipy.signal +import torch +from torch_utils import persistence +from torch_utils import misc +from torch_utils.ops import upfirdn2d +from torch_utils.ops import grid_sample_gradfix +from torch_utils.ops import conv2d_gradfix + +#---------------------------------------------------------------------------- +# Coefficients of various wavelet decomposition low-pass filters. + +wavelets = { + 'haar': [0.7071067811865476, 0.7071067811865476], + 'db1': [0.7071067811865476, 0.7071067811865476], + 'db2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025], + 'db3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569], + 'db4': [-0.010597401784997278, 0.032883011666982945, 0.030841381835986965, -0.18703481171888114, -0.02798376941698385, 0.6308807679295904, 0.7148465705525415, 0.23037781330885523], + 'db5': [0.003335725285001549, -0.012580751999015526, -0.006241490213011705, 0.07757149384006515, -0.03224486958502952, -0.24229488706619015, 0.13842814590110342, 0.7243085284385744, 0.6038292697974729, 0.160102397974125], + 'db6': [-0.00107730108499558, 0.004777257511010651, 0.0005538422009938016, -0.031582039318031156, 0.02752286553001629, 0.09750160558707936, -0.12976686756709563, -0.22626469396516913, 0.3152503517092432, 0.7511339080215775, 0.4946238903983854, 0.11154074335008017], + 'db7': [0.0003537138000010399, -0.0018016407039998328, 0.00042957797300470274, 0.012550998556013784, -0.01657454163101562, -0.03802993693503463, 0.0806126091510659, 0.07130921926705004, -0.22403618499416572, -0.14390600392910627, 0.4697822874053586, 0.7291320908465551, 0.39653931948230575, 0.07785205408506236], + 'db8': [-0.00011747678400228192, 0.0006754494059985568, -0.0003917403729959771, -0.00487035299301066, 0.008746094047015655, 0.013981027917015516, -0.04408825393106472, -0.01736930100202211, 0.128747426620186, 0.00047248457399797254, -0.2840155429624281, -0.015829105256023893, 0.5853546836548691, 0.6756307362980128, 0.3128715909144659, 0.05441584224308161], + 'sym2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025], + 'sym3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569], + 'sym4': [-0.07576571478927333, -0.02963552764599851, 0.49761866763201545, 0.8037387518059161, 0.29785779560527736, -0.09921954357684722, -0.012603967262037833, 0.0322231006040427], + 'sym5': [0.027333068345077982, 0.029519490925774643, -0.039134249302383094, 0.1993975339773936, 0.7234076904024206, 0.6339789634582119, 0.01660210576452232, -0.17532808990845047, -0.021101834024758855, 0.019538882735286728], + 'sym6': [0.015404109327027373, 0.0034907120842174702, -0.11799011114819057, -0.048311742585633, 0.4910559419267466, 0.787641141030194, 0.3379294217276218, -0.07263752278646252, -0.021060292512300564, 0.04472490177066578, 0.0017677118642428036, -0.007800708325034148], + 'sym7': [0.002681814568257878, -0.0010473848886829163, -0.01263630340325193, 0.03051551316596357, 0.0678926935013727, -0.049552834937127255, 0.017441255086855827, 0.5361019170917628, 0.767764317003164, 0.2886296317515146, -0.14004724044296152, -0.10780823770381774, 0.004010244871533663, 0.010268176708511255], + 'sym8': [-0.0033824159510061256, -0.0005421323317911481, 0.03169508781149298, 0.007607487324917605, -0.1432942383508097, -0.061273359067658524, 0.4813596512583722, 0.7771857517005235, 0.3644418948353314, -0.05194583810770904, -0.027219029917056003, 0.049137179673607506, 0.003808752013890615, -0.01495225833704823, -0.0003029205147213668, 0.0018899503327594609], +} + +#---------------------------------------------------------------------------- +# Helpers for constructing transformation matrices. + +def matrix(*rows, device=None): + assert all(len(row) == len(rows[0]) for row in rows) + elems = [x for row in rows for x in row] + ref = [x for x in elems if isinstance(x, torch.Tensor)] + if len(ref) == 0: + return misc.constant(np.asarray(rows), device=device) + assert device is None or device == ref[0].device + elems = [x if isinstance(x, torch.Tensor) else misc.constant(x, shape=ref[0].shape, device=ref[0].device) for x in elems] + return torch.stack(elems, dim=-1).reshape(ref[0].shape + (len(rows), -1)) + +def translate2d(tx, ty, **kwargs): + return matrix( + [1, 0, tx], + [0, 1, ty], + [0, 0, 1], + **kwargs) + +def translate3d(tx, ty, tz, **kwargs): + return matrix( + [1, 0, 0, tx], + [0, 1, 0, ty], + [0, 0, 1, tz], + [0, 0, 0, 1], + **kwargs) + +def scale2d(sx, sy, **kwargs): + return matrix( + [sx, 0, 0], + [0, sy, 0], + [0, 0, 1], + **kwargs) + +def scale3d(sx, sy, sz, **kwargs): + return matrix( + [sx, 0, 0, 0], + [0, sy, 0, 0], + [0, 0, sz, 0], + [0, 0, 0, 1], + **kwargs) + +def rotate2d(theta, **kwargs): + return matrix( + [torch.cos(theta), torch.sin(-theta), 0], + [torch.sin(theta), torch.cos(theta), 0], + [0, 0, 1], + **kwargs) + +def rotate3d(v, theta, **kwargs): + vx = v[..., 0]; vy = v[..., 1]; vz = v[..., 2] + s = torch.sin(theta); c = torch.cos(theta); cc = 1 - c + return matrix( + [vx*vx*cc+c, vx*vy*cc-vz*s, vx*vz*cc+vy*s, 0], + [vy*vx*cc+vz*s, vy*vy*cc+c, vy*vz*cc-vx*s, 0], + [vz*vx*cc-vy*s, vz*vy*cc+vx*s, vz*vz*cc+c, 0], + [0, 0, 0, 1], + **kwargs) + +def translate2d_inv(tx, ty, **kwargs): + return translate2d(-tx, -ty, **kwargs) + +def scale2d_inv(sx, sy, **kwargs): + return scale2d(1 / sx, 1 / sy, **kwargs) + +def rotate2d_inv(theta, **kwargs): + return rotate2d(-theta, **kwargs) + +#---------------------------------------------------------------------------- +# Versatile image augmentation pipeline from the paper +# "Training Generative Adversarial Networks with Limited Data". +# +# All augmentations are disabled by default; individual augmentations can +# be enabled by setting their probability multipliers to 1. + +@persistence.persistent_class +class AugmentPipe(torch.nn.Module): + def __init__(self, + xflip=0, rotate90=0, xint=0, xint_max=0.125, + scale=0, rotate=0, aniso=0, xfrac=0, scale_std=0.2, rotate_max=1, aniso_std=0.2, xfrac_std=0.125, + brightness=0, contrast=0, lumaflip=0, hue=0, saturation=0, brightness_std=0.2, contrast_std=0.5, hue_max=1, saturation_std=1, + imgfilter=0, imgfilter_bands=[1,1,1,1], imgfilter_std=1, + noise=0, cutout=0, noise_std=0.1, cutout_size=0.5, + ): + super().__init__() + self.register_buffer('p', torch.ones([])) # Overall multiplier for augmentation probability. + + # Pixel blitting. + self.xflip = float(xflip) # Probability multiplier for x-flip. + self.rotate90 = float(rotate90) # Probability multiplier for 90 degree rotations. + self.xint = float(xint) # Probability multiplier for integer translation. + self.xint_max = float(xint_max) # Range of integer translation, relative to image dimensions. + + # General geometric transformations. + self.scale = float(scale) # Probability multiplier for isotropic scaling. + self.rotate = float(rotate) # Probability multiplier for arbitrary rotation. + self.aniso = float(aniso) # Probability multiplier for anisotropic scaling. + self.xfrac = float(xfrac) # Probability multiplier for fractional translation. + self.scale_std = float(scale_std) # Log2 standard deviation of isotropic scaling. + self.rotate_max = float(rotate_max) # Range of arbitrary rotation, 1 = full circle. + self.aniso_std = float(aniso_std) # Log2 standard deviation of anisotropic scaling. + self.xfrac_std = float(xfrac_std) # Standard deviation of frational translation, relative to image dimensions. + + # Color transformations. + self.brightness = float(brightness) # Probability multiplier for brightness. + self.contrast = float(contrast) # Probability multiplier for contrast. + self.lumaflip = float(lumaflip) # Probability multiplier for luma flip. + self.hue = float(hue) # Probability multiplier for hue rotation. + self.saturation = float(saturation) # Probability multiplier for saturation. + self.brightness_std = float(brightness_std) # Standard deviation of brightness. + self.contrast_std = float(contrast_std) # Log2 standard deviation of contrast. + self.hue_max = float(hue_max) # Range of hue rotation, 1 = full circle. + self.saturation_std = float(saturation_std) # Log2 standard deviation of saturation. + + # Image-space filtering. + self.imgfilter = float(imgfilter) # Probability multiplier for image-space filtering. + self.imgfilter_bands = list(imgfilter_bands) # Probability multipliers for individual frequency bands. + self.imgfilter_std = float(imgfilter_std) # Log2 standard deviation of image-space filter amplification. + + # Image-space corruptions. + self.noise = float(noise) # Probability multiplier for additive RGB noise. + self.cutout = float(cutout) # Probability multiplier for cutout. + self.noise_std = float(noise_std) # Standard deviation of additive RGB noise. + self.cutout_size = float(cutout_size) # Size of the cutout rectangle, relative to image dimensions. + + # Setup orthogonal lowpass filter for geometric augmentations. + self.register_buffer('Hz_geom', upfirdn2d.setup_filter(wavelets['sym6'])) + + # Construct filter bank for image-space filtering. + Hz_lo = np.asarray(wavelets['sym2']) # H(z) + Hz_hi = Hz_lo * ((-1) ** np.arange(Hz_lo.size)) # H(-z) + Hz_lo2 = np.convolve(Hz_lo, Hz_lo[::-1]) / 2 # H(z) * H(z^-1) / 2 + Hz_hi2 = np.convolve(Hz_hi, Hz_hi[::-1]) / 2 # H(-z) * H(-z^-1) / 2 + Hz_fbank = np.eye(4, 1) # Bandpass(H(z), b_i) + for i in range(1, Hz_fbank.shape[0]): + Hz_fbank = np.dstack([Hz_fbank, np.zeros_like(Hz_fbank)]).reshape(Hz_fbank.shape[0], -1)[:, :-1] + Hz_fbank = scipy.signal.convolve(Hz_fbank, [Hz_lo2]) + Hz_fbank[i, (Hz_fbank.shape[1] - Hz_hi2.size) // 2 : (Hz_fbank.shape[1] + Hz_hi2.size) // 2] += Hz_hi2 + self.register_buffer('Hz_fbank', torch.as_tensor(Hz_fbank, dtype=torch.float32)) + + def forward(self, images, debug_percentile=None): + assert isinstance(images, torch.Tensor) and images.ndim == 4 + batch_size, num_channels, height, width = images.shape + device = images.device + if debug_percentile is not None: + debug_percentile = torch.as_tensor(debug_percentile, dtype=torch.float32, device=device) + + # ------------------------------------- + # Select parameters for pixel blitting. + # ------------------------------------- + + # Initialize inverse homogeneous 2D transform: G_inv @ pixel_out ==> pixel_in + I_3 = torch.eye(3, device=device) + G_inv = I_3 + + # Apply x-flip with probability (xflip * strength). + if self.xflip > 0: + i = torch.floor(torch.rand([batch_size], device=device) * 2) + i = torch.where(torch.rand([batch_size], device=device) < self.xflip * self.p, i, torch.zeros_like(i)) + if debug_percentile is not None: + i = torch.full_like(i, torch.floor(debug_percentile * 2)) + G_inv = G_inv @ scale2d_inv(1 - 2 * i, 1) + + # Apply 90 degree rotations with probability (rotate90 * strength). + if self.rotate90 > 0: + i = torch.floor(torch.rand([batch_size], device=device) * 4) + i = torch.where(torch.rand([batch_size], device=device) < self.rotate90 * self.p, i, torch.zeros_like(i)) + if debug_percentile is not None: + i = torch.full_like(i, torch.floor(debug_percentile * 4)) + G_inv = G_inv @ rotate2d_inv(-np.pi / 2 * i) + + # Apply integer translation with probability (xint * strength). + if self.xint > 0: + t = (torch.rand([batch_size, 2], device=device) * 2 - 1) * self.xint_max + t = torch.where(torch.rand([batch_size, 1], device=device) < self.xint * self.p, t, torch.zeros_like(t)) + if debug_percentile is not None: + t = torch.full_like(t, (debug_percentile * 2 - 1) * self.xint_max) + G_inv = G_inv @ translate2d_inv(torch.round(t[:,0] * width), torch.round(t[:,1] * height)) + + # -------------------------------------------------------- + # Select parameters for general geometric transformations. + # -------------------------------------------------------- + + # Apply isotropic scaling with probability (scale * strength). + if self.scale > 0: + s = torch.exp2(torch.randn([batch_size], device=device) * self.scale_std) + s = torch.where(torch.rand([batch_size], device=device) < self.scale * self.p, s, torch.ones_like(s)) + if debug_percentile is not None: + s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.scale_std)) + G_inv = G_inv @ scale2d_inv(s, s) + + # Apply pre-rotation with probability p_rot. + p_rot = 1 - torch.sqrt((1 - self.rotate * self.p).clamp(0, 1)) # P(pre OR post) = p + if self.rotate > 0: + theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.rotate_max + theta = torch.where(torch.rand([batch_size], device=device) < p_rot, theta, torch.zeros_like(theta)) + if debug_percentile is not None: + theta = torch.full_like(theta, (debug_percentile * 2 - 1) * np.pi * self.rotate_max) + G_inv = G_inv @ rotate2d_inv(-theta) # Before anisotropic scaling. + + # Apply anisotropic scaling with probability (aniso * strength). + if self.aniso > 0: + s = torch.exp2(torch.randn([batch_size], device=device) * self.aniso_std) + s = torch.where(torch.rand([batch_size], device=device) < self.aniso * self.p, s, torch.ones_like(s)) + if debug_percentile is not None: + s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.aniso_std)) + G_inv = G_inv @ scale2d_inv(s, 1 / s) + + # Apply post-rotation with probability p_rot. + if self.rotate > 0: + theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.rotate_max + theta = torch.where(torch.rand([batch_size], device=device) < p_rot, theta, torch.zeros_like(theta)) + if debug_percentile is not None: + theta = torch.zeros_like(theta) + G_inv = G_inv @ rotate2d_inv(-theta) # After anisotropic scaling. + + # Apply fractional translation with probability (xfrac * strength). + if self.xfrac > 0: + t = torch.randn([batch_size, 2], device=device) * self.xfrac_std + t = torch.where(torch.rand([batch_size, 1], device=device) < self.xfrac * self.p, t, torch.zeros_like(t)) + if debug_percentile is not None: + t = torch.full_like(t, torch.erfinv(debug_percentile * 2 - 1) * self.xfrac_std) + G_inv = G_inv @ translate2d_inv(t[:,0] * width, t[:,1] * height) + + # ---------------------------------- + # Execute geometric transformations. + # ---------------------------------- + + # Execute if the transform is not identity. + if G_inv is not I_3: + + # Calculate padding. + cx = (width - 1) / 2 + cy = (height - 1) / 2 + cp = matrix([-cx, -cy, 1], [cx, -cy, 1], [cx, cy, 1], [-cx, cy, 1], device=device) # [idx, xyz] + cp = G_inv @ cp.t() # [batch, xyz, idx] + Hz_pad = self.Hz_geom.shape[0] // 4 + margin = cp[:, :2, :].permute(1, 0, 2).flatten(1) # [xy, batch * idx] + margin = torch.cat([-margin, margin]).max(dim=1).values # [x0, y0, x1, y1] + margin = margin + misc.constant([Hz_pad * 2 - cx, Hz_pad * 2 - cy] * 2, device=device) + margin = margin.max(misc.constant([0, 0] * 2, device=device)) + margin = margin.min(misc.constant([width-1, height-1] * 2, device=device)) + mx0, my0, mx1, my1 = margin.ceil().to(torch.int32) + + # Pad image and adjust origin. + images = torch.nn.functional.pad(input=images, pad=[mx0,mx1,my0,my1], mode='reflect') + G_inv = translate2d((mx0 - mx1) / 2, (my0 - my1) / 2) @ G_inv + + # Upsample. + images = upfirdn2d.upsample2d(x=images, f=self.Hz_geom, up=2) + G_inv = scale2d(2, 2, device=device) @ G_inv @ scale2d_inv(2, 2, device=device) + G_inv = translate2d(-0.5, -0.5, device=device) @ G_inv @ translate2d_inv(-0.5, -0.5, device=device) + + # Execute transformation. + shape = [batch_size, num_channels, (height + Hz_pad * 2) * 2, (width + Hz_pad * 2) * 2] + G_inv = scale2d(2 / images.shape[3], 2 / images.shape[2], device=device) @ G_inv @ scale2d_inv(2 / shape[3], 2 / shape[2], device=device) + grid = torch.nn.functional.affine_grid(theta=G_inv[:,:2,:], size=shape, align_corners=False) + images = grid_sample_gradfix.grid_sample(images, grid) + + # Downsample and crop. + images = upfirdn2d.downsample2d(x=images, f=self.Hz_geom, down=2, padding=-Hz_pad*2, flip_filter=True) + + # -------------------------------------------- + # Select parameters for color transformations. + # -------------------------------------------- + + # Initialize homogeneous 3D transformation matrix: C @ color_in ==> color_out + I_4 = torch.eye(4, device=device) + C = I_4 + + # Apply brightness with probability (brightness * strength). + if self.brightness > 0: + b = torch.randn([batch_size], device=device) * self.brightness_std + b = torch.where(torch.rand([batch_size], device=device) < self.brightness * self.p, b, torch.zeros_like(b)) + if debug_percentile is not None: + b = torch.full_like(b, torch.erfinv(debug_percentile * 2 - 1) * self.brightness_std) + C = translate3d(b, b, b) @ C + + # Apply contrast with probability (contrast * strength). + if self.contrast > 0: + c = torch.exp2(torch.randn([batch_size], device=device) * self.contrast_std) + c = torch.where(torch.rand([batch_size], device=device) < self.contrast * self.p, c, torch.ones_like(c)) + if debug_percentile is not None: + c = torch.full_like(c, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.contrast_std)) + C = scale3d(c, c, c) @ C + + # Apply luma flip with probability (lumaflip * strength). + v = misc.constant(np.asarray([1, 1, 1, 0]) / np.sqrt(3), device=device) # Luma axis. + if self.lumaflip > 0: + i = torch.floor(torch.rand([batch_size, 1, 1], device=device) * 2) + i = torch.where(torch.rand([batch_size, 1, 1], device=device) < self.lumaflip * self.p, i, torch.zeros_like(i)) + if debug_percentile is not None: + i = torch.full_like(i, torch.floor(debug_percentile * 2)) + C = (I_4 - 2 * v.ger(v) * i) @ C # Householder reflection. + + # Apply hue rotation with probability (hue * strength). + if self.hue > 0 and num_channels > 1: + theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.hue_max + theta = torch.where(torch.rand([batch_size], device=device) < self.hue * self.p, theta, torch.zeros_like(theta)) + if debug_percentile is not None: + theta = torch.full_like(theta, (debug_percentile * 2 - 1) * np.pi * self.hue_max) + C = rotate3d(v, theta) @ C # Rotate around v. + + # Apply saturation with probability (saturation * strength). + if self.saturation > 0 and num_channels > 1: + s = torch.exp2(torch.randn([batch_size, 1, 1], device=device) * self.saturation_std) + s = torch.where(torch.rand([batch_size, 1, 1], device=device) < self.saturation * self.p, s, torch.ones_like(s)) + if debug_percentile is not None: + s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.saturation_std)) + C = (v.ger(v) + (I_4 - v.ger(v)) * s) @ C + + # ------------------------------ + # Execute color transformations. + # ------------------------------ + + # Execute if the transform is not identity. + if C is not I_4: + images = images.reshape([batch_size, num_channels, height * width]) + if num_channels == 3: + images = C[:, :3, :3] @ images + C[:, :3, 3:] + elif num_channels == 1: + C = C[:, :3, :].mean(dim=1, keepdims=True) + images = images * C[:, :, :3].sum(dim=2, keepdims=True) + C[:, :, 3:] + else: + raise ValueError('Image must be RGB (3 channels) or L (1 channel)') + images = images.reshape([batch_size, num_channels, height, width]) + + # ---------------------- + # Image-space filtering. + # ---------------------- + + if self.imgfilter > 0: + num_bands = self.Hz_fbank.shape[0] + assert len(self.imgfilter_bands) == num_bands + expected_power = misc.constant(np.array([10, 1, 1, 1]) / 13, device=device) # Expected power spectrum (1/f). + + # Apply amplification for each band with probability (imgfilter * strength * band_strength). + g = torch.ones([batch_size, num_bands], device=device) # Global gain vector (identity). + for i, band_strength in enumerate(self.imgfilter_bands): + t_i = torch.exp2(torch.randn([batch_size], device=device) * self.imgfilter_std) + t_i = torch.where(torch.rand([batch_size], device=device) < self.imgfilter * self.p * band_strength, t_i, torch.ones_like(t_i)) + if debug_percentile is not None: + t_i = torch.full_like(t_i, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.imgfilter_std)) if band_strength > 0 else torch.ones_like(t_i) + t = torch.ones([batch_size, num_bands], device=device) # Temporary gain vector. + t[:, i] = t_i # Replace i'th element. + t = t / (expected_power * t.square()).sum(dim=-1, keepdims=True).sqrt() # Normalize power. + g = g * t # Accumulate into global gain. + + # Construct combined amplification filter. + Hz_prime = g @ self.Hz_fbank # [batch, tap] + Hz_prime = Hz_prime.unsqueeze(1).repeat([1, num_channels, 1]) # [batch, channels, tap] + Hz_prime = Hz_prime.reshape([batch_size * num_channels, 1, -1]) # [batch * channels, 1, tap] + + # Apply filter. + p = self.Hz_fbank.shape[1] // 2 + images = images.reshape([1, batch_size * num_channels, height, width]) + images = torch.nn.functional.pad(input=images, pad=[p,p,p,p], mode='reflect') + images = conv2d_gradfix.conv2d(input=images, weight=Hz_prime.unsqueeze(2), groups=batch_size*num_channels) + images = conv2d_gradfix.conv2d(input=images, weight=Hz_prime.unsqueeze(3), groups=batch_size*num_channels) + images = images.reshape([batch_size, num_channels, height, width]) + + # ------------------------ + # Image-space corruptions. + # ------------------------ + + # Apply additive RGB noise with probability (noise * strength). + if self.noise > 0: + sigma = torch.randn([batch_size, 1, 1, 1], device=device).abs() * self.noise_std + sigma = torch.where(torch.rand([batch_size, 1, 1, 1], device=device) < self.noise * self.p, sigma, torch.zeros_like(sigma)) + if debug_percentile is not None: + sigma = torch.full_like(sigma, torch.erfinv(debug_percentile) * self.noise_std) + images = images + torch.randn([batch_size, num_channels, height, width], device=device) * sigma + + # Apply cutout with probability (cutout * strength). + if self.cutout > 0: + size = torch.full([batch_size, 2, 1, 1, 1], self.cutout_size, device=device) + size = torch.where(torch.rand([batch_size, 1, 1, 1, 1], device=device) < self.cutout * self.p, size, torch.zeros_like(size)) + center = torch.rand([batch_size, 2, 1, 1, 1], device=device) + if debug_percentile is not None: + size = torch.full_like(size, self.cutout_size) + center = torch.full_like(center, debug_percentile) + coord_x = torch.arange(width, device=device).reshape([1, 1, 1, -1]) + coord_y = torch.arange(height, device=device).reshape([1, 1, -1, 1]) + mask_x = (((coord_x + 0.5) / width - center[:, 0]).abs() >= size[:, 0] / 2) + mask_y = (((coord_y + 0.5) / height - center[:, 1]).abs() >= size[:, 1] / 2) + mask = torch.logical_or(mask_x, mask_y).to(torch.float32) + images = images * mask + + return images + +#---------------------------------------------------------------------------- diff --git a/training/dataset.py b/training/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..1d87ba7f20dfcf68a4a5c9f97545765c0b8537a7 --- /dev/null +++ b/training/dataset.py @@ -0,0 +1,238 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Streaming images and labels from datasets created with dataset_tool.py.""" + +import os +import numpy as np +import zipfile +import PIL.Image +import json +import torch +import dnnlib + +try: + import pyspng +except ImportError: + pyspng = None + +#---------------------------------------------------------------------------- + +class Dataset(torch.utils.data.Dataset): + def __init__(self, + name, # Name of the dataset. + raw_shape, # Shape of the raw image data (NCHW). + max_size = None, # Artificially limit the size of the dataset. None = no limit. Applied before xflip. + use_labels = False, # Enable conditioning labels? False = label dimension is zero. + xflip = False, # Artificially double the size of the dataset via x-flips. Applied after max_size. + random_seed = 0, # Random seed to use when applying max_size. + ): + self._name = name + self._raw_shape = list(raw_shape) + self._use_labels = use_labels + self._raw_labels = None + self._label_shape = None + + # Apply max_size. + self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64) + if (max_size is not None) and (self._raw_idx.size > max_size): + np.random.RandomState(random_seed).shuffle(self._raw_idx) + self._raw_idx = np.sort(self._raw_idx[:max_size]) + + # Apply xflip. + self._xflip = np.zeros(self._raw_idx.size, dtype=np.uint8) + if xflip: + self._raw_idx = np.tile(self._raw_idx, 2) + self._xflip = np.concatenate([self._xflip, np.ones_like(self._xflip)]) + + def _get_raw_labels(self): + if self._raw_labels is None: + self._raw_labels = self._load_raw_labels() if self._use_labels else None + if self._raw_labels is None: + self._raw_labels = np.zeros([self._raw_shape[0], 0], dtype=np.float32) + assert isinstance(self._raw_labels, np.ndarray) + assert self._raw_labels.shape[0] == self._raw_shape[0] + assert self._raw_labels.dtype in [np.float32, np.int64] + if self._raw_labels.dtype == np.int64: + assert self._raw_labels.ndim == 1 + assert np.all(self._raw_labels >= 0) + return self._raw_labels + + def close(self): # to be overridden by subclass + pass + + def _load_raw_image(self, raw_idx): # to be overridden by subclass + raise NotImplementedError + + def _load_raw_labels(self): # to be overridden by subclass + raise NotImplementedError + + def __getstate__(self): + return dict(self.__dict__, _raw_labels=None) + + def __del__(self): + try: + self.close() + except: + pass + + def __len__(self): + return self._raw_idx.size + + def __getitem__(self, idx): + image = self._load_raw_image(self._raw_idx[idx]) + assert isinstance(image, np.ndarray) + assert list(image.shape) == self.image_shape + assert image.dtype == np.uint8 + if self._xflip[idx]: + assert image.ndim == 3 # CHW + image = image[:, :, ::-1] + return image.copy(), self.get_label(idx) + + def get_label(self, idx): + label = self._get_raw_labels()[self._raw_idx[idx]] + if label.dtype == np.int64: + onehot = np.zeros(self.label_shape, dtype=np.float32) + onehot[label] = 1 + label = onehot + return label.copy() + + def get_details(self, idx): + d = dnnlib.EasyDict() + d.raw_idx = int(self._raw_idx[idx]) + d.xflip = (int(self._xflip[idx]) != 0) + d.raw_label = self._get_raw_labels()[d.raw_idx].copy() + return d + + @property + def name(self): + return self._name + + @property + def image_shape(self): + return list(self._raw_shape[1:]) + + @property + def num_channels(self): + assert len(self.image_shape) == 3 # CHW + return self.image_shape[0] + + @property + def resolution(self): + assert len(self.image_shape) == 3 # CHW + assert self.image_shape[1] == self.image_shape[2] + return self.image_shape[1] + + @property + def label_shape(self): + if self._label_shape is None: + raw_labels = self._get_raw_labels() + if raw_labels.dtype == np.int64: + self._label_shape = [int(np.max(raw_labels)) + 1] + else: + self._label_shape = raw_labels.shape[1:] + return list(self._label_shape) + + @property + def label_dim(self): + assert len(self.label_shape) == 1 + return self.label_shape[0] + + @property + def has_labels(self): + return any(x != 0 for x in self.label_shape) + + @property + def has_onehot_labels(self): + return self._get_raw_labels().dtype == np.int64 + +#---------------------------------------------------------------------------- + +class ImageFolderDataset(Dataset): + def __init__(self, + path, # Path to directory or zip. + resolution = None, # Ensure specific resolution, None = highest available. + **super_kwargs, # Additional arguments for the Dataset base class. + ): + self._path = path + self._zipfile = None + + if os.path.isdir(self._path): + self._type = 'dir' + self._all_fnames = {os.path.relpath(os.path.join(root, fname), start=self._path) for root, _dirs, files in os.walk(self._path) for fname in files} + elif self._file_ext(self._path) == '.zip': + self._type = 'zip' + self._all_fnames = set(self._get_zipfile().namelist()) + else: + raise IOError('Path must point to a directory or zip') + + PIL.Image.init() + self._image_fnames = sorted(fname for fname in self._all_fnames if self._file_ext(fname) in PIL.Image.EXTENSION) + if len(self._image_fnames) == 0: + raise IOError('No image files found in the specified path') + + name = os.path.splitext(os.path.basename(self._path))[0] + raw_shape = [len(self._image_fnames)] + list(self._load_raw_image(0).shape) + if resolution is not None and (raw_shape[2] != resolution or raw_shape[3] != resolution): + raise IOError('Image files do not match the specified resolution') + super().__init__(name=name, raw_shape=raw_shape, **super_kwargs) + + @staticmethod + def _file_ext(fname): + return os.path.splitext(fname)[1].lower() + + def _get_zipfile(self): + assert self._type == 'zip' + if self._zipfile is None: + self._zipfile = zipfile.ZipFile(self._path) + return self._zipfile + + def _open_file(self, fname): + if self._type == 'dir': + return open(os.path.join(self._path, fname), 'rb') + if self._type == 'zip': + return self._get_zipfile().open(fname, 'r') + return None + + def close(self): + try: + if self._zipfile is not None: + self._zipfile.close() + finally: + self._zipfile = None + + def __getstate__(self): + return dict(super().__getstate__(), _zipfile=None) + + def _load_raw_image(self, raw_idx): + fname = self._image_fnames[raw_idx] + with self._open_file(fname) as f: + if pyspng is not None and self._file_ext(fname) == '.png': + image = pyspng.load(f.read()) + else: + image = np.array(PIL.Image.open(f)) + if image.ndim == 2: + image = image[:, :, np.newaxis] # HW => HWC + image = image.transpose(2, 0, 1) # HWC => CHW + return image + + def _load_raw_labels(self): + fname = 'dataset.json' + if fname not in self._all_fnames: + return None + with self._open_file(fname) as f: + labels = json.load(f)['labels'] + if labels is None: + return None + labels = dict(labels) + labels = [labels[fname.replace('\\', '/')] for fname in self._image_fnames] + labels = np.array(labels) + labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim]) + return labels + +#---------------------------------------------------------------------------- diff --git a/training/loss.py b/training/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..920d05945b7156727044db3ed80e7aa94322af7b --- /dev/null +++ b/training/loss.py @@ -0,0 +1,140 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Loss functions.""" + +import numpy as np +import torch +from torch_utils import training_stats +from torch_utils.ops import conv2d_gradfix +from torch_utils.ops import upfirdn2d + +#---------------------------------------------------------------------------- + +class Loss: + def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, gain, cur_nimg): # to be overridden by subclass + raise NotImplementedError() + +#---------------------------------------------------------------------------- + +class StyleGAN2Loss(Loss): + def __init__(self, device, G, D, augment_pipe=None, r1_gamma=10, style_mixing_prob=0, pl_weight=0, pl_batch_shrink=2, pl_decay=0.01, pl_no_weight_grad=False, blur_init_sigma=0, blur_fade_kimg=0): + super().__init__() + self.device = device + self.G = G + self.D = D + self.augment_pipe = augment_pipe + self.r1_gamma = r1_gamma + self.style_mixing_prob = style_mixing_prob + self.pl_weight = pl_weight + self.pl_batch_shrink = pl_batch_shrink + self.pl_decay = pl_decay + self.pl_no_weight_grad = pl_no_weight_grad + self.pl_mean = torch.zeros([], device=device) + self.blur_init_sigma = blur_init_sigma + self.blur_fade_kimg = blur_fade_kimg + + def run_G(self, z, c, update_emas=False): + ws = self.G.mapping(z, c, update_emas=update_emas) + if self.style_mixing_prob > 0: + with torch.autograd.profiler.record_function('style_mixing'): + cutoff = torch.empty([], dtype=torch.int64, device=ws.device).random_(1, ws.shape[1]) + cutoff = torch.where(torch.rand([], device=ws.device) < self.style_mixing_prob, cutoff, torch.full_like(cutoff, ws.shape[1])) + ws[:, cutoff:] = self.G.mapping(torch.randn_like(z), c, update_emas=False)[:, cutoff:] + img = self.G.synthesis(ws, update_emas=update_emas) + return img, ws + + def run_D(self, img, c, blur_sigma=0, update_emas=False): + blur_size = np.floor(blur_sigma * 3) + if blur_size > 0: + with torch.autograd.profiler.record_function('blur'): + f = torch.arange(-blur_size, blur_size + 1, device=img.device).div(blur_sigma).square().neg().exp2() + img = upfirdn2d.filter2d(img, f / f.sum()) + if self.augment_pipe is not None: + img = self.augment_pipe(img) + logits = self.D(img, c, update_emas=update_emas) + return logits + + def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, gain, cur_nimg): + assert phase in ['Gmain', 'Greg', 'Gboth', 'Dmain', 'Dreg', 'Dboth'] + if self.pl_weight == 0: + phase = {'Greg': 'none', 'Gboth': 'Gmain'}.get(phase, phase) + if self.r1_gamma == 0: + phase = {'Dreg': 'none', 'Dboth': 'Dmain'}.get(phase, phase) + blur_sigma = max(1 - cur_nimg / (self.blur_fade_kimg * 1e3), 0) * self.blur_init_sigma if self.blur_fade_kimg > 0 else 0 + + # Gmain: Maximize logits for generated images. + if phase in ['Gmain', 'Gboth']: + with torch.autograd.profiler.record_function('Gmain_forward'): + gen_img, _gen_ws = self.run_G(gen_z, gen_c) + gen_logits = self.run_D(gen_img, gen_c, blur_sigma=blur_sigma) + training_stats.report('Loss/scores/fake', gen_logits) + training_stats.report('Loss/signs/fake', gen_logits.sign()) + loss_Gmain = torch.nn.functional.softplus(-gen_logits) # -log(sigmoid(gen_logits)) + training_stats.report('Loss/G/loss', loss_Gmain) + with torch.autograd.profiler.record_function('Gmain_backward'): + loss_Gmain.mean().mul(gain).backward() + + # Gpl: Apply path length regularization. + if phase in ['Greg', 'Gboth']: + with torch.autograd.profiler.record_function('Gpl_forward'): + batch_size = gen_z.shape[0] // self.pl_batch_shrink + gen_img, gen_ws = self.run_G(gen_z[:batch_size], gen_c[:batch_size]) + pl_noise = torch.randn_like(gen_img) / np.sqrt(gen_img.shape[2] * gen_img.shape[3]) + with torch.autograd.profiler.record_function('pl_grads'), conv2d_gradfix.no_weight_gradients(self.pl_no_weight_grad): + pl_grads = torch.autograd.grad(outputs=[(gen_img * pl_noise).sum()], inputs=[gen_ws], create_graph=True, only_inputs=True)[0] + pl_lengths = pl_grads.square().sum(2).mean(1).sqrt() + pl_mean = self.pl_mean.lerp(pl_lengths.mean(), self.pl_decay) + self.pl_mean.copy_(pl_mean.detach()) + pl_penalty = (pl_lengths - pl_mean).square() + training_stats.report('Loss/pl_penalty', pl_penalty) + loss_Gpl = pl_penalty * self.pl_weight + training_stats.report('Loss/G/reg', loss_Gpl) + with torch.autograd.profiler.record_function('Gpl_backward'): + loss_Gpl.mean().mul(gain).backward() + + # Dmain: Minimize logits for generated images. + loss_Dgen = 0 + if phase in ['Dmain', 'Dboth']: + with torch.autograd.profiler.record_function('Dgen_forward'): + gen_img, _gen_ws = self.run_G(gen_z, gen_c, update_emas=True) + gen_logits = self.run_D(gen_img, gen_c, blur_sigma=blur_sigma, update_emas=True) + training_stats.report('Loss/scores/fake', gen_logits) + training_stats.report('Loss/signs/fake', gen_logits.sign()) + loss_Dgen = torch.nn.functional.softplus(gen_logits) # -log(1 - sigmoid(gen_logits)) + with torch.autograd.profiler.record_function('Dgen_backward'): + loss_Dgen.mean().mul(gain).backward() + + # Dmain: Maximize logits for real images. + # Dr1: Apply R1 regularization. + if phase in ['Dmain', 'Dreg', 'Dboth']: + name = 'Dreal' if phase == 'Dmain' else 'Dr1' if phase == 'Dreg' else 'Dreal_Dr1' + with torch.autograd.profiler.record_function(name + '_forward'): + real_img_tmp = real_img.detach().requires_grad_(phase in ['Dreg', 'Dboth']) + real_logits = self.run_D(real_img_tmp, real_c, blur_sigma=blur_sigma) + training_stats.report('Loss/scores/real', real_logits) + training_stats.report('Loss/signs/real', real_logits.sign()) + + loss_Dreal = 0 + if phase in ['Dmain', 'Dboth']: + loss_Dreal = torch.nn.functional.softplus(-real_logits) # -log(sigmoid(real_logits)) + training_stats.report('Loss/D/loss', loss_Dgen + loss_Dreal) + + loss_Dr1 = 0 + if phase in ['Dreg', 'Dboth']: + with torch.autograd.profiler.record_function('r1_grads'), conv2d_gradfix.no_weight_gradients(): + r1_grads = torch.autograd.grad(outputs=[real_logits.sum()], inputs=[real_img_tmp], create_graph=True, only_inputs=True)[0] + r1_penalty = r1_grads.square().sum([1,2,3]) + loss_Dr1 = r1_penalty * (self.r1_gamma / 2) + training_stats.report('Loss/r1_penalty', r1_penalty) + training_stats.report('Loss/D/reg', loss_Dr1) + + with torch.autograd.profiler.record_function(name + '_backward'): + (loss_Dreal + loss_Dr1).mean().mul(gain).backward() + +#---------------------------------------------------------------------------- diff --git a/training/networks_stylegan2.py b/training/networks_stylegan2.py new file mode 100644 index 0000000000000000000000000000000000000000..8ab31062217fc7c8b8bc5ae8f45ddb23705fafe6 --- /dev/null +++ b/training/networks_stylegan2.py @@ -0,0 +1,794 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Network architectures from the paper +"Analyzing and Improving the Image Quality of StyleGAN". +Matches the original implementation of configs E-F by Karras et al. at +https://github.com/NVlabs/stylegan2/blob/master/training/networks_stylegan2.py""" + +import numpy as np +import torch +from torch_utils import misc +from torch_utils import persistence +from torch_utils.ops import conv2d_resample +from torch_utils.ops import upfirdn2d +from torch_utils.ops import bias_act +from torch_utils.ops import fma + +#---------------------------------------------------------------------------- + +@misc.profiled_function +def normalize_2nd_moment(x, dim=1, eps=1e-8): + return x * (x.square().mean(dim=dim, keepdim=True) + eps).rsqrt() + +#---------------------------------------------------------------------------- + +@misc.profiled_function +def modulated_conv2d( + x, # Input tensor of shape [batch_size, in_channels, in_height, in_width]. + weight, # Weight tensor of shape [out_channels, in_channels, kernel_height, kernel_width]. + styles, # Modulation coefficients of shape [batch_size, in_channels]. + noise = None, # Optional noise tensor to add to the output activations. + up = 1, # Integer upsampling factor. + down = 1, # Integer downsampling factor. + padding = 0, # Padding with respect to the upsampled image. + resample_filter = None, # Low-pass filter to apply when resampling activations. Must be prepared beforehand by calling upfirdn2d.setup_filter(). + demodulate = True, # Apply weight demodulation? + flip_weight = True, # False = convolution, True = correlation (matches torch.nn.functional.conv2d). + fused_modconv = True, # Perform modulation, convolution, and demodulation as a single fused operation? +): + batch_size = x.shape[0] + out_channels, in_channels, kh, kw = weight.shape + misc.assert_shape(weight, [out_channels, in_channels, kh, kw]) # [OIkk] + misc.assert_shape(x, [batch_size, in_channels, None, None]) # [NIHW] + misc.assert_shape(styles, [batch_size, in_channels]) # [NI] + + # Pre-normalize inputs to avoid FP16 overflow. + if x.dtype == torch.float16 and demodulate: + weight = weight * (1 / np.sqrt(in_channels * kh * kw) / weight.norm(float('inf'), dim=[1,2,3], keepdim=True)) # max_Ikk + styles = styles / styles.norm(float('inf'), dim=1, keepdim=True) # max_I + + # Calculate per-sample weights and demodulation coefficients. + w = None + dcoefs = None + if demodulate or fused_modconv: + w = weight.unsqueeze(0) # [NOIkk] + w = w * styles.reshape(batch_size, 1, -1, 1, 1) # [NOIkk] + if demodulate: + dcoefs = (w.square().sum(dim=[2,3,4]) + 1e-8).rsqrt() # [NO] + if demodulate and fused_modconv: + w = w * dcoefs.reshape(batch_size, -1, 1, 1, 1) # [NOIkk] + + # Execute by scaling the activations before and after the convolution. + if not fused_modconv: + x = x * styles.to(x.dtype).reshape(batch_size, -1, 1, 1) + x = conv2d_resample.conv2d_resample(x=x, w=weight.to(x.dtype), f=resample_filter, up=up, down=down, padding=padding, flip_weight=flip_weight) + if demodulate and noise is not None: + x = fma.fma(x, dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1), noise.to(x.dtype)) + elif demodulate: + x = x * dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1) + elif noise is not None: + x = x.add_(noise.to(x.dtype)) + return x + + # Execute as one fused op using grouped convolution. + with misc.suppress_tracer_warnings(): # this value will be treated as a constant + batch_size = int(batch_size) + misc.assert_shape(x, [batch_size, in_channels, None, None]) + x = x.reshape(1, -1, *x.shape[2:]) + w = w.reshape(-1, in_channels, kh, kw) + x = conv2d_resample.conv2d_resample(x=x, w=w.to(x.dtype), f=resample_filter, up=up, down=down, padding=padding, groups=batch_size, flip_weight=flip_weight) + x = x.reshape(batch_size, -1, *x.shape[2:]) + if noise is not None: + x = x.add_(noise) + return x + +#---------------------------------------------------------------------------- + +@persistence.persistent_class +class FullyConnectedLayer(torch.nn.Module): + def __init__(self, + in_features, # Number of input features. + out_features, # Number of output features. + bias = True, # Apply additive bias before the activation function? + activation = 'linear', # Activation function: 'relu', 'lrelu', etc. + lr_multiplier = 1, # Learning rate multiplier. + bias_init = 0, # Initial value for the additive bias. + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.activation = activation + self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier) + self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None + self.weight_gain = lr_multiplier / np.sqrt(in_features) + self.bias_gain = lr_multiplier + + def forward(self, x): + w = self.weight.to(x.dtype) * self.weight_gain + b = self.bias + if b is not None: + b = b.to(x.dtype) + if self.bias_gain != 1: + b = b * self.bias_gain + + if self.activation == 'linear' and b is not None: + x = torch.addmm(b.unsqueeze(0), x, w.t()) + else: + x = x.matmul(w.t()) + x = bias_act.bias_act(x, b, act=self.activation) + return x + + def extra_repr(self): + return f'in_features={self.in_features:d}, out_features={self.out_features:d}, activation={self.activation:s}' + +#---------------------------------------------------------------------------- + +@persistence.persistent_class +class Conv2dLayer(torch.nn.Module): + def __init__(self, + in_channels, # Number of input channels. + out_channels, # Number of output channels. + kernel_size, # Width and height of the convolution kernel. + bias = True, # Apply additive bias before the activation function? + activation = 'linear', # Activation function: 'relu', 'lrelu', etc. + up = 1, # Integer upsampling factor. + down = 1, # Integer downsampling factor. + resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations. + conv_clamp = None, # Clamp the output to +-X, None = disable clamping. + channels_last = False, # Expect the input to have memory_format=channels_last? + trainable = True, # Update the weights of this layer during training? + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.activation = activation + self.up = up + self.down = down + self.conv_clamp = conv_clamp + self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter)) + self.padding = kernel_size // 2 + self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2)) + self.act_gain = bias_act.activation_funcs[activation].def_gain + + memory_format = torch.channels_last if channels_last else torch.contiguous_format + weight = torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format) + bias = torch.zeros([out_channels]) if bias else None + if trainable: + self.weight = torch.nn.Parameter(weight) + self.bias = torch.nn.Parameter(bias) if bias is not None else None + else: + self.register_buffer('weight', weight) + if bias is not None: + self.register_buffer('bias', bias) + else: + self.bias = None + + def forward(self, x, gain=1): + w = self.weight * self.weight_gain + b = self.bias.to(x.dtype) if self.bias is not None else None + flip_weight = (self.up == 1) # slightly faster + x = conv2d_resample.conv2d_resample(x=x, w=w.to(x.dtype), f=self.resample_filter, up=self.up, down=self.down, padding=self.padding, flip_weight=flip_weight) + + act_gain = self.act_gain * gain + act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None + x = bias_act.bias_act(x, b, act=self.activation, gain=act_gain, clamp=act_clamp) + return x + + def extra_repr(self): + return ' '.join([ + f'in_channels={self.in_channels:d}, out_channels={self.out_channels:d}, activation={self.activation:s},', + f'up={self.up}, down={self.down}']) + +#---------------------------------------------------------------------------- + +@persistence.persistent_class +class MappingNetwork(torch.nn.Module): + def __init__(self, + z_dim, # Input latent (Z) dimensionality, 0 = no latent. + c_dim, # Conditioning label (C) dimensionality, 0 = no label. + w_dim, # Intermediate latent (W) dimensionality. + num_ws, # Number of intermediate latents to output, None = do not broadcast. + num_layers = 8, # Number of mapping layers. + embed_features = None, # Label embedding dimensionality, None = same as w_dim. + layer_features = None, # Number of intermediate features in the mapping layers, None = same as w_dim. + activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc. + lr_multiplier = 0.01, # Learning rate multiplier for the mapping layers. + w_avg_beta = 0.998, # Decay for tracking the moving average of W during training, None = do not track. + ): + super().__init__() + self.z_dim = z_dim + self.c_dim = c_dim + self.w_dim = w_dim + self.num_ws = num_ws + self.num_layers = num_layers + self.w_avg_beta = w_avg_beta + + if embed_features is None: + embed_features = w_dim + if c_dim == 0: + embed_features = 0 + if layer_features is None: + layer_features = w_dim + features_list = [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim] + + if c_dim > 0: + self.embed = FullyConnectedLayer(c_dim, embed_features) + for idx in range(num_layers): + in_features = features_list[idx] + out_features = features_list[idx + 1] + layer = FullyConnectedLayer(in_features, out_features, activation=activation, lr_multiplier=lr_multiplier) + setattr(self, f'fc{idx}', layer) + + if num_ws is not None and w_avg_beta is not None: + self.register_buffer('w_avg', torch.zeros([w_dim])) + + def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False): + # Embed, normalize, and concat inputs. + x = None + with torch.autograd.profiler.record_function('input'): + if self.z_dim > 0: + misc.assert_shape(z, [None, self.z_dim]) + x = normalize_2nd_moment(z.to(torch.float32)) + if self.c_dim > 0: + misc.assert_shape(c, [None, self.c_dim]) + y = normalize_2nd_moment(self.embed(c.to(torch.float32))) + x = torch.cat([x, y], dim=1) if x is not None else y + + # Main layers. + for idx in range(self.num_layers): + layer = getattr(self, f'fc{idx}') + x = layer(x) + + # Update moving average of W. + if update_emas and self.w_avg_beta is not None: + with torch.autograd.profiler.record_function('update_w_avg'): + self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta)) + + # Broadcast. + if self.num_ws is not None: + with torch.autograd.profiler.record_function('broadcast'): + x = x.unsqueeze(1).repeat([1, self.num_ws, 1]) + + # Apply truncation. + if truncation_psi != 1: + with torch.autograd.profiler.record_function('truncate'): + assert self.w_avg_beta is not None + if self.num_ws is None or truncation_cutoff is None: + x = self.w_avg.lerp(x, truncation_psi) + else: + x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi) + return x + + def extra_repr(self): + return f'z_dim={self.z_dim:d}, c_dim={self.c_dim:d}, w_dim={self.w_dim:d}, num_ws={self.num_ws:d}' + +#---------------------------------------------------------------------------- + +@persistence.persistent_class +class SynthesisLayer(torch.nn.Module): + def __init__(self, + in_channels, # Number of input channels. + out_channels, # Number of output channels. + w_dim, # Intermediate latent (W) dimensionality. + resolution, # Resolution of this layer. + kernel_size = 3, # Convolution kernel size. + up = 1, # Integer upsampling factor. + use_noise = True, # Enable noise input? + activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc. + resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations. + conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping. + channels_last = False, # Use channels_last format for the weights? + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.w_dim = w_dim + self.resolution = resolution + self.up = up + self.use_noise = use_noise + self.activation = activation + self.conv_clamp = conv_clamp + self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter)) + self.padding = kernel_size // 2 + self.act_gain = bias_act.activation_funcs[activation].def_gain + + self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1) + memory_format = torch.channels_last if channels_last else torch.contiguous_format + self.weight = torch.nn.Parameter(torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format)) + if use_noise: + self.register_buffer('noise_const', torch.randn([resolution, resolution])) + self.noise_strength = torch.nn.Parameter(torch.zeros([])) + self.bias = torch.nn.Parameter(torch.zeros([out_channels])) + + def forward(self, x, w, noise_mode='random', fused_modconv=True, gain=1): + assert noise_mode in ['random', 'const', 'none'] + in_resolution = self.resolution // self.up + misc.assert_shape(x, [None, self.in_channels, in_resolution, in_resolution]) + styles = self.affine(w) + + noise = None + if self.use_noise and noise_mode == 'random': + noise = torch.randn([x.shape[0], 1, self.resolution, self.resolution], device=x.device) * self.noise_strength + if self.use_noise and noise_mode == 'const': + noise = self.noise_const * self.noise_strength + + flip_weight = (self.up == 1) # slightly faster + x = modulated_conv2d(x=x, weight=self.weight, styles=styles, noise=noise, up=self.up, + padding=self.padding, resample_filter=self.resample_filter, flip_weight=flip_weight, fused_modconv=fused_modconv) + + act_gain = self.act_gain * gain + act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None + x = bias_act.bias_act(x, self.bias.to(x.dtype), act=self.activation, gain=act_gain, clamp=act_clamp) + return x + + def extra_repr(self): + return ' '.join([ + f'in_channels={self.in_channels:d}, out_channels={self.out_channels:d}, w_dim={self.w_dim:d},', + f'resolution={self.resolution:d}, up={self.up}, activation={self.activation:s}']) + +#---------------------------------------------------------------------------- + +@persistence.persistent_class +class ToRGBLayer(torch.nn.Module): + def __init__(self, in_channels, out_channels, w_dim, kernel_size=1, conv_clamp=None, channels_last=False): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.w_dim = w_dim + self.conv_clamp = conv_clamp + self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1) + memory_format = torch.channels_last if channels_last else torch.contiguous_format + self.weight = torch.nn.Parameter(torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format)) + self.bias = torch.nn.Parameter(torch.zeros([out_channels])) + self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2)) + + def forward(self, x, w, fused_modconv=True): + styles = self.affine(w) * self.weight_gain + x = modulated_conv2d(x=x, weight=self.weight, styles=styles, demodulate=False, fused_modconv=fused_modconv) + x = bias_act.bias_act(x, self.bias.to(x.dtype), clamp=self.conv_clamp) + return x + + def extra_repr(self): + return f'in_channels={self.in_channels:d}, out_channels={self.out_channels:d}, w_dim={self.w_dim:d}' + +#---------------------------------------------------------------------------- + +@persistence.persistent_class +class SynthesisBlock(torch.nn.Module): + def __init__(self, + in_channels, # Number of input channels, 0 = first block. + out_channels, # Number of output channels. + w_dim, # Intermediate latent (W) dimensionality. + resolution, # Resolution of this block. + img_channels, # Number of output color channels. + is_last, # Is this the last block? + architecture = 'skip', # Architecture: 'orig', 'skip', 'resnet'. + resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations. + conv_clamp = 256, # Clamp the output of convolution layers to +-X, None = disable clamping. + use_fp16 = False, # Use FP16 for this block? + fp16_channels_last = False, # Use channels-last memory format with FP16? + fused_modconv_default = True, # Default value of fused_modconv. 'inference_only' = True for inference, False for training. + **layer_kwargs, # Arguments for SynthesisLayer. + ): + assert architecture in ['orig', 'skip', 'resnet'] + super().__init__() + self.in_channels = in_channels + self.w_dim = w_dim + self.resolution = resolution + self.img_channels = img_channels + self.is_last = is_last + self.architecture = architecture + self.use_fp16 = use_fp16 + self.channels_last = (use_fp16 and fp16_channels_last) + self.fused_modconv_default = fused_modconv_default + self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter)) + self.num_conv = 0 + self.num_torgb = 0 + + if in_channels == 0: + self.const = torch.nn.Parameter(torch.randn([out_channels, resolution, resolution])) + + if in_channels != 0: + self.conv0 = SynthesisLayer(in_channels, out_channels, w_dim=w_dim, resolution=resolution, up=2, + resample_filter=resample_filter, conv_clamp=conv_clamp, channels_last=self.channels_last, **layer_kwargs) + self.num_conv += 1 + + self.conv1 = SynthesisLayer(out_channels, out_channels, w_dim=w_dim, resolution=resolution, + conv_clamp=conv_clamp, channels_last=self.channels_last, **layer_kwargs) + self.num_conv += 1 + + if is_last or architecture == 'skip': + self.torgb = ToRGBLayer(out_channels, img_channels, w_dim=w_dim, + conv_clamp=conv_clamp, channels_last=self.channels_last) + self.num_torgb += 1 + + if in_channels != 0 and architecture == 'resnet': + self.skip = Conv2dLayer(in_channels, out_channels, kernel_size=1, bias=False, up=2, + resample_filter=resample_filter, channels_last=self.channels_last) + + def forward(self, x, img, ws, force_fp32=False, fused_modconv=None, update_emas=False, **layer_kwargs): + _ = update_emas # unused + misc.assert_shape(ws, [None, self.num_conv + self.num_torgb, self.w_dim]) + w_iter = iter(ws.unbind(dim=1)) + if ws.device.type != 'cuda': + force_fp32 = True + dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32 + memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format + if fused_modconv is None: + fused_modconv = self.fused_modconv_default + if fused_modconv == 'inference_only': + fused_modconv = (not self.training) + + # Input. + if self.in_channels == 0: + x = self.const.to(dtype=dtype, memory_format=memory_format) + x = x.unsqueeze(0).repeat([ws.shape[0], 1, 1, 1]) + else: + misc.assert_shape(x, [None, self.in_channels, self.resolution // 2, self.resolution // 2]) + x = x.to(dtype=dtype, memory_format=memory_format) + + # Main layers. + if self.in_channels == 0: + x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs) + elif self.architecture == 'resnet': + y = self.skip(x, gain=np.sqrt(0.5)) + x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs) + x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, gain=np.sqrt(0.5), **layer_kwargs) + x = y.add_(x) + else: + x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs) + x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs) + + # ToRGB. + if img is not None: + misc.assert_shape(img, [None, self.img_channels, self.resolution // 2, self.resolution // 2]) + img = upfirdn2d.upsample2d(img, self.resample_filter) + if self.is_last or self.architecture == 'skip': + y = self.torgb(x, next(w_iter), fused_modconv=fused_modconv) + y = y.to(dtype=torch.float32, memory_format=torch.contiguous_format) + img = img.add_(y) if img is not None else y + + assert x.dtype == dtype + assert img is None or img.dtype == torch.float32 + return x, img + + def extra_repr(self): + return f'resolution={self.resolution:d}, architecture={self.architecture:s}' + +#---------------------------------------------------------------------------- + +@persistence.persistent_class +class SynthesisNetwork(torch.nn.Module): + def __init__(self, + w_dim, # Intermediate latent (W) dimensionality. + img_resolution, # Output image resolution. + img_channels, # Number of color channels. + channel_base = 32768, # Overall multiplier for the number of channels. + channel_max = 512, # Maximum number of channels in any layer. + num_fp16_res = 4, # Use FP16 for the N highest resolutions. + **block_kwargs, # Arguments for SynthesisBlock. + ): + assert img_resolution >= 4 and img_resolution & (img_resolution - 1) == 0 + super().__init__() + self.w_dim = w_dim + self.img_resolution = img_resolution + self.img_resolution_log2 = int(np.log2(img_resolution)) + self.img_channels = img_channels + self.num_fp16_res = num_fp16_res + self.block_resolutions = [2 ** i for i in range(2, self.img_resolution_log2 + 1)] + channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions} + fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8) + + self.num_ws = 0 + for res in self.block_resolutions: + in_channels = channels_dict[res // 2] if res > 4 else 0 + out_channels = channels_dict[res] + use_fp16 = (res >= fp16_resolution) + is_last = (res == self.img_resolution) + block = SynthesisBlock(in_channels, out_channels, w_dim=w_dim, resolution=res, + img_channels=img_channels, is_last=is_last, use_fp16=use_fp16, **block_kwargs) + self.num_ws += block.num_conv + if is_last: + self.num_ws += block.num_torgb + setattr(self, f'b{res}', block) + + def forward(self, ws, **block_kwargs): + block_ws = [] + with torch.autograd.profiler.record_function('split_ws'): + misc.assert_shape(ws, [None, self.num_ws, self.w_dim]) + ws = ws.to(torch.float32) + w_idx = 0 + for res in self.block_resolutions: + block = getattr(self, f'b{res}') + block_ws.append(ws.narrow(1, w_idx, block.num_conv + block.num_torgb)) + w_idx += block.num_conv + + x = img = None + for res, cur_ws in zip(self.block_resolutions, block_ws): + block = getattr(self, f'b{res}') + x, img = block(x, img, cur_ws, **block_kwargs) + return img + + def extra_repr(self): + return ' '.join([ + f'w_dim={self.w_dim:d}, num_ws={self.num_ws:d},', + f'img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d},', + f'num_fp16_res={self.num_fp16_res:d}']) + +#---------------------------------------------------------------------------- + +@persistence.persistent_class +class Generator(torch.nn.Module): + def __init__(self, + z_dim, # Input latent (Z) dimensionality. + c_dim, # Conditioning label (C) dimensionality. + w_dim, # Intermediate latent (W) dimensionality. + img_resolution, # Output resolution. + img_channels, # Number of output color channels. + mapping_kwargs = {}, # Arguments for MappingNetwork. + **synthesis_kwargs, # Arguments for SynthesisNetwork. + ): + super().__init__() + self.z_dim = z_dim + self.c_dim = c_dim + self.w_dim = w_dim + self.img_resolution = img_resolution + self.img_channels = img_channels + self.synthesis = SynthesisNetwork(w_dim=w_dim, img_resolution=img_resolution, img_channels=img_channels, **synthesis_kwargs) + self.num_ws = self.synthesis.num_ws + self.mapping = MappingNetwork(z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_ws=self.num_ws, **mapping_kwargs) + + def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False, **synthesis_kwargs): + ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas) + img = self.synthesis(ws, update_emas=update_emas, **synthesis_kwargs) + return img + +#---------------------------------------------------------------------------- + +@persistence.persistent_class +class DiscriminatorBlock(torch.nn.Module): + def __init__(self, + in_channels, # Number of input channels, 0 = first block. + tmp_channels, # Number of intermediate channels. + out_channels, # Number of output channels. + resolution, # Resolution of this block. + img_channels, # Number of input color channels. + first_layer_idx, # Index of the first layer. + architecture = 'resnet', # Architecture: 'orig', 'skip', 'resnet'. + activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc. + resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations. + conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping. + use_fp16 = False, # Use FP16 for this block? + fp16_channels_last = False, # Use channels-last memory format with FP16? + freeze_layers = 0, # Freeze-D: Number of layers to freeze. + ): + assert in_channels in [0, tmp_channels] + assert architecture in ['orig', 'skip', 'resnet'] + super().__init__() + self.in_channels = in_channels + self.resolution = resolution + self.img_channels = img_channels + self.first_layer_idx = first_layer_idx + self.architecture = architecture + self.use_fp16 = use_fp16 + self.channels_last = (use_fp16 and fp16_channels_last) + self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter)) + + self.num_layers = 0 + def trainable_gen(): + while True: + layer_idx = self.first_layer_idx + self.num_layers + trainable = (layer_idx >= freeze_layers) + self.num_layers += 1 + yield trainable + trainable_iter = trainable_gen() + + if in_channels == 0 or architecture == 'skip': + self.fromrgb = Conv2dLayer(img_channels, tmp_channels, kernel_size=1, activation=activation, + trainable=next(trainable_iter), conv_clamp=conv_clamp, channels_last=self.channels_last) + + self.conv0 = Conv2dLayer(tmp_channels, tmp_channels, kernel_size=3, activation=activation, + trainable=next(trainable_iter), conv_clamp=conv_clamp, channels_last=self.channels_last) + + self.conv1 = Conv2dLayer(tmp_channels, out_channels, kernel_size=3, activation=activation, down=2, + trainable=next(trainable_iter), resample_filter=resample_filter, conv_clamp=conv_clamp, channels_last=self.channels_last) + + if architecture == 'resnet': + self.skip = Conv2dLayer(tmp_channels, out_channels, kernel_size=1, bias=False, down=2, + trainable=next(trainable_iter), resample_filter=resample_filter, channels_last=self.channels_last) + + def forward(self, x, img, force_fp32=False): + if (x if x is not None else img).device.type != 'cuda': + force_fp32 = True + dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32 + memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format + + # Input. + if x is not None: + misc.assert_shape(x, [None, self.in_channels, self.resolution, self.resolution]) + x = x.to(dtype=dtype, memory_format=memory_format) + + # FromRGB. + if self.in_channels == 0 or self.architecture == 'skip': + misc.assert_shape(img, [None, self.img_channels, self.resolution, self.resolution]) + img = img.to(dtype=dtype, memory_format=memory_format) + y = self.fromrgb(img) + x = x + y if x is not None else y + img = upfirdn2d.downsample2d(img, self.resample_filter) if self.architecture == 'skip' else None + + # Main layers. + if self.architecture == 'resnet': + y = self.skip(x, gain=np.sqrt(0.5)) + x = self.conv0(x) + x = self.conv1(x, gain=np.sqrt(0.5)) + x = y.add_(x) + else: + x = self.conv0(x) + x = self.conv1(x) + + assert x.dtype == dtype + return x, img + + def extra_repr(self): + return f'resolution={self.resolution:d}, architecture={self.architecture:s}' + +#---------------------------------------------------------------------------- + +@persistence.persistent_class +class MinibatchStdLayer(torch.nn.Module): + def __init__(self, group_size, num_channels=1): + super().__init__() + self.group_size = group_size + self.num_channels = num_channels + + def forward(self, x): + N, C, H, W = x.shape + with misc.suppress_tracer_warnings(): # as_tensor results are registered as constants + G = torch.min(torch.as_tensor(self.group_size), torch.as_tensor(N)) if self.group_size is not None else N + F = self.num_channels + c = C // F + + y = x.reshape(G, -1, F, c, H, W) # [GnFcHW] Split minibatch N into n groups of size G, and channels C into F groups of size c. + y = y - y.mean(dim=0) # [GnFcHW] Subtract mean over group. + y = y.square().mean(dim=0) # [nFcHW] Calc variance over group. + y = (y + 1e-8).sqrt() # [nFcHW] Calc stddev over group. + y = y.mean(dim=[2,3,4]) # [nF] Take average over channels and pixels. + y = y.reshape(-1, F, 1, 1) # [nF11] Add missing dimensions. + y = y.repeat(G, 1, H, W) # [NFHW] Replicate over group and pixels. + x = torch.cat([x, y], dim=1) # [NCHW] Append to input as new channels. + return x + + def extra_repr(self): + return f'group_size={self.group_size}, num_channels={self.num_channels:d}' + +#---------------------------------------------------------------------------- + +@persistence.persistent_class +class DiscriminatorEpilogue(torch.nn.Module): + def __init__(self, + in_channels, # Number of input channels. + cmap_dim, # Dimensionality of mapped conditioning label, 0 = no label. + resolution, # Resolution of this block. + img_channels, # Number of input color channels. + architecture = 'resnet', # Architecture: 'orig', 'skip', 'resnet'. + mbstd_group_size = 4, # Group size for the minibatch standard deviation layer, None = entire minibatch. + mbstd_num_channels = 1, # Number of features for the minibatch standard deviation layer, 0 = disable. + activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc. + conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping. + ): + assert architecture in ['orig', 'skip', 'resnet'] + super().__init__() + self.in_channels = in_channels + self.cmap_dim = cmap_dim + self.resolution = resolution + self.img_channels = img_channels + self.architecture = architecture + + if architecture == 'skip': + self.fromrgb = Conv2dLayer(img_channels, in_channels, kernel_size=1, activation=activation) + self.mbstd = MinibatchStdLayer(group_size=mbstd_group_size, num_channels=mbstd_num_channels) if mbstd_num_channels > 0 else None + self.conv = Conv2dLayer(in_channels + mbstd_num_channels, in_channels, kernel_size=3, activation=activation, conv_clamp=conv_clamp) + self.fc = FullyConnectedLayer(in_channels * (resolution ** 2), in_channels, activation=activation) + self.out = FullyConnectedLayer(in_channels, 1 if cmap_dim == 0 else cmap_dim) + + def forward(self, x, img, cmap, force_fp32=False): + misc.assert_shape(x, [None, self.in_channels, self.resolution, self.resolution]) # [NCHW] + _ = force_fp32 # unused + dtype = torch.float32 + memory_format = torch.contiguous_format + + # FromRGB. + x = x.to(dtype=dtype, memory_format=memory_format) + if self.architecture == 'skip': + misc.assert_shape(img, [None, self.img_channels, self.resolution, self.resolution]) + img = img.to(dtype=dtype, memory_format=memory_format) + x = x + self.fromrgb(img) + + # Main layers. + if self.mbstd is not None: + x = self.mbstd(x) + x = self.conv(x) + x = self.fc(x.flatten(1)) + x = self.out(x) + + # Conditioning. + if self.cmap_dim > 0: + misc.assert_shape(cmap, [None, self.cmap_dim]) + x = (x * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim)) + + assert x.dtype == dtype + return x + + def extra_repr(self): + return f'resolution={self.resolution:d}, architecture={self.architecture:s}' + +#---------------------------------------------------------------------------- + +@persistence.persistent_class +class Discriminator(torch.nn.Module): + def __init__(self, + c_dim, # Conditioning label (C) dimensionality. + img_resolution, # Input resolution. + img_channels, # Number of input color channels. + architecture = 'resnet', # Architecture: 'orig', 'skip', 'resnet'. + channel_base = 32768, # Overall multiplier for the number of channels. + channel_max = 512, # Maximum number of channels in any layer. + num_fp16_res = 4, # Use FP16 for the N highest resolutions. + conv_clamp = 256, # Clamp the output of convolution layers to +-X, None = disable clamping. + cmap_dim = None, # Dimensionality of mapped conditioning label, None = default. + block_kwargs = {}, # Arguments for DiscriminatorBlock. + mapping_kwargs = {}, # Arguments for MappingNetwork. + epilogue_kwargs = {}, # Arguments for DiscriminatorEpilogue. + ): + super().__init__() + self.c_dim = c_dim + self.img_resolution = img_resolution + self.img_resolution_log2 = int(np.log2(img_resolution)) + self.img_channels = img_channels + self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)] + channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]} + fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8) + + if cmap_dim is None: + cmap_dim = channels_dict[4] + if c_dim == 0: + cmap_dim = 0 + + common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp) + cur_layer_idx = 0 + for res in self.block_resolutions: + in_channels = channels_dict[res] if res < img_resolution else 0 + tmp_channels = channels_dict[res] + out_channels = channels_dict[res // 2] + use_fp16 = (res >= fp16_resolution) + block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res, + first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs) + setattr(self, f'b{res}', block) + cur_layer_idx += block.num_layers + if c_dim > 0: + self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs) + self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, **epilogue_kwargs, **common_kwargs) + + def forward(self, img, c, update_emas=False, **block_kwargs): + _ = update_emas # unused + x = None + for res in self.block_resolutions: + block = getattr(self, f'b{res}') + x, img = block(x, img, **block_kwargs) + + cmap = None + if self.c_dim > 0: + cmap = self.mapping(None, c) + x = self.b4(x, img, cmap) + return x + + def extra_repr(self): + return f'c_dim={self.c_dim:d}, img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d}' + +#---------------------------------------------------------------------------- diff --git a/training/networks_stylegan3.py b/training/networks_stylegan3.py new file mode 100644 index 0000000000000000000000000000000000000000..a6a2b4d7f5ebde6874768c0cac011cc7867d2a08 --- /dev/null +++ b/training/networks_stylegan3.py @@ -0,0 +1,515 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Generator architecture from the paper +"Alias-Free Generative Adversarial Networks".""" + +import numpy as np +import scipy.signal +import scipy.optimize +import torch +from torch_utils import misc +from torch_utils import persistence +from torch_utils.ops import conv2d_gradfix +from torch_utils.ops import filtered_lrelu +from torch_utils.ops import bias_act + +#---------------------------------------------------------------------------- + +@misc.profiled_function +def modulated_conv2d( + x, # Input tensor: [batch_size, in_channels, in_height, in_width] + w, # Weight tensor: [out_channels, in_channels, kernel_height, kernel_width] + s, # Style tensor: [batch_size, in_channels] + demodulate = True, # Apply weight demodulation? + padding = 0, # Padding: int or [padH, padW] + input_gain = None, # Optional scale factors for the input channels: [], [in_channels], or [batch_size, in_channels] +): + with misc.suppress_tracer_warnings(): # this value will be treated as a constant + batch_size = int(x.shape[0]) + out_channels, in_channels, kh, kw = w.shape + misc.assert_shape(w, [out_channels, in_channels, kh, kw]) # [OIkk] + misc.assert_shape(x, [batch_size, in_channels, None, None]) # [NIHW] + misc.assert_shape(s, [batch_size, in_channels]) # [NI] + + # Pre-normalize inputs. + if demodulate: + w = w * w.square().mean([1,2,3], keepdim=True).rsqrt() + s = s * s.square().mean().rsqrt() + + # Modulate weights. + w = w.unsqueeze(0) # [NOIkk] + w = w * s.unsqueeze(1).unsqueeze(3).unsqueeze(4) # [NOIkk] + + # Demodulate weights. + if demodulate: + dcoefs = (w.square().sum(dim=[2,3,4]) + 1e-8).rsqrt() # [NO] + w = w * dcoefs.unsqueeze(2).unsqueeze(3).unsqueeze(4) # [NOIkk] + + # Apply input scaling. + if input_gain is not None: + input_gain = input_gain.expand(batch_size, in_channels) # [NI] + w = w * input_gain.unsqueeze(1).unsqueeze(3).unsqueeze(4) # [NOIkk] + + # Execute as one fused op using grouped convolution. + x = x.reshape(1, -1, *x.shape[2:]) + w = w.reshape(-1, in_channels, kh, kw) + x = conv2d_gradfix.conv2d(input=x, weight=w.to(x.dtype), padding=padding, groups=batch_size) + x = x.reshape(batch_size, -1, *x.shape[2:]) + return x + +#---------------------------------------------------------------------------- + +@persistence.persistent_class +class FullyConnectedLayer(torch.nn.Module): + def __init__(self, + in_features, # Number of input features. + out_features, # Number of output features. + activation = 'linear', # Activation function: 'relu', 'lrelu', etc. + bias = True, # Apply additive bias before the activation function? + lr_multiplier = 1, # Learning rate multiplier. + weight_init = 1, # Initial standard deviation of the weight tensor. + bias_init = 0, # Initial value of the additive bias. + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.activation = activation + self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) * (weight_init / lr_multiplier)) + bias_init = np.broadcast_to(np.asarray(bias_init, dtype=np.float32), [out_features]) + self.bias = torch.nn.Parameter(torch.from_numpy(bias_init / lr_multiplier)) if bias else None + self.weight_gain = lr_multiplier / np.sqrt(in_features) + self.bias_gain = lr_multiplier + + def forward(self, x): + w = self.weight.to(x.dtype) * self.weight_gain + b = self.bias + if b is not None: + b = b.to(x.dtype) + if self.bias_gain != 1: + b = b * self.bias_gain + if self.activation == 'linear' and b is not None: + x = torch.addmm(b.unsqueeze(0), x, w.t()) + else: + x = x.matmul(w.t()) + x = bias_act.bias_act(x, b, act=self.activation) + return x + + def extra_repr(self): + return f'in_features={self.in_features:d}, out_features={self.out_features:d}, activation={self.activation:s}' + +#---------------------------------------------------------------------------- + +@persistence.persistent_class +class MappingNetwork(torch.nn.Module): + def __init__(self, + z_dim, # Input latent (Z) dimensionality. + c_dim, # Conditioning label (C) dimensionality, 0 = no labels. + w_dim, # Intermediate latent (W) dimensionality. + num_ws, # Number of intermediate latents to output. + num_layers = 2, # Number of mapping layers. + lr_multiplier = 0.01, # Learning rate multiplier for the mapping layers. + w_avg_beta = 0.998, # Decay for tracking the moving average of W during training. + ): + super().__init__() + self.z_dim = z_dim + self.c_dim = c_dim + self.w_dim = w_dim + self.num_ws = num_ws + self.num_layers = num_layers + self.w_avg_beta = w_avg_beta + + # Construct layers. + self.embed = FullyConnectedLayer(self.c_dim, self.w_dim) if self.c_dim > 0 else None + features = [self.z_dim + (self.w_dim if self.c_dim > 0 else 0)] + [self.w_dim] * self.num_layers + for idx, in_features, out_features in zip(range(num_layers), features[:-1], features[1:]): + layer = FullyConnectedLayer(in_features, out_features, activation='lrelu', lr_multiplier=lr_multiplier) + setattr(self, f'fc{idx}', layer) + self.register_buffer('w_avg', torch.zeros([w_dim])) + + def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False): + misc.assert_shape(z, [None, self.z_dim]) + if truncation_cutoff is None: + truncation_cutoff = self.num_ws + + # Embed, normalize, and concatenate inputs. + x = z.to(torch.float32) + x = x * (x.square().mean(1, keepdim=True) + 1e-8).rsqrt() + if self.c_dim > 0: + misc.assert_shape(c, [None, self.c_dim]) + y = self.embed(c.to(torch.float32)) + y = y * (y.square().mean(1, keepdim=True) + 1e-8).rsqrt() + x = torch.cat([x, y], dim=1) if x is not None else y + + # Execute layers. + for idx in range(self.num_layers): + x = getattr(self, f'fc{idx}')(x) + + # Update moving average of W. + if update_emas: + self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta)) + + # Broadcast and apply truncation. + x = x.unsqueeze(1).repeat([1, self.num_ws, 1]) + if truncation_psi != 1: + x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi) + return x + + def extra_repr(self): + return f'z_dim={self.z_dim:d}, c_dim={self.c_dim:d}, w_dim={self.w_dim:d}, num_ws={self.num_ws:d}' + +#---------------------------------------------------------------------------- + +@persistence.persistent_class +class SynthesisInput(torch.nn.Module): + def __init__(self, + w_dim, # Intermediate latent (W) dimensionality. + channels, # Number of output channels. + size, # Output spatial size: int or [width, height]. + sampling_rate, # Output sampling rate. + bandwidth, # Output bandwidth. + ): + super().__init__() + self.w_dim = w_dim + self.channels = channels + self.size = np.broadcast_to(np.asarray(size), [2]) + self.sampling_rate = sampling_rate + self.bandwidth = bandwidth + + # Draw random frequencies from uniform 2D disc. + freqs = torch.randn([self.channels, 2]) + radii = freqs.square().sum(dim=1, keepdim=True).sqrt() + freqs /= radii * radii.square().exp().pow(0.25) + freqs *= bandwidth + phases = torch.rand([self.channels]) - 0.5 + + # Setup parameters and buffers. + self.weight = torch.nn.Parameter(torch.randn([self.channels, self.channels])) + self.affine = FullyConnectedLayer(w_dim, 4, weight_init=0, bias_init=[1,0,0,0]) + self.register_buffer('transform', torch.eye(3, 3)) # User-specified inverse transform wrt. resulting image. + self.register_buffer('freqs', freqs) + self.register_buffer('phases', phases) + + def forward(self, w): + # Introduce batch dimension. + transforms = self.transform.unsqueeze(0) # [batch, row, col] + freqs = self.freqs.unsqueeze(0) # [batch, channel, xy] + phases = self.phases.unsqueeze(0) # [batch, channel] + + # Apply learned transformation. + t = self.affine(w) # t = (r_c, r_s, t_x, t_y) + t = t / t[:, :2].norm(dim=1, keepdim=True) # t' = (r'_c, r'_s, t'_x, t'_y) + m_r = torch.eye(3, device=w.device).unsqueeze(0).repeat([w.shape[0], 1, 1]) # Inverse rotation wrt. resulting image. + m_r[:, 0, 0] = t[:, 0] # r'_c + m_r[:, 0, 1] = -t[:, 1] # r'_s + m_r[:, 1, 0] = t[:, 1] # r'_s + m_r[:, 1, 1] = t[:, 0] # r'_c + m_t = torch.eye(3, device=w.device).unsqueeze(0).repeat([w.shape[0], 1, 1]) # Inverse translation wrt. resulting image. + m_t[:, 0, 2] = -t[:, 2] # t'_x + m_t[:, 1, 2] = -t[:, 3] # t'_y + transforms = m_r @ m_t @ transforms # First rotate resulting image, then translate, and finally apply user-specified transform. + + # Transform frequencies. + phases = phases + (freqs @ transforms[:, :2, 2:]).squeeze(2) + freqs = freqs @ transforms[:, :2, :2] + + # Dampen out-of-band frequencies that may occur due to the user-specified transform. + amplitudes = (1 - (freqs.norm(dim=2) - self.bandwidth) / (self.sampling_rate / 2 - self.bandwidth)).clamp(0, 1) + + # Construct sampling grid. + theta = torch.eye(2, 3, device=w.device) + theta[0, 0] = 0.5 * self.size[0] / self.sampling_rate + theta[1, 1] = 0.5 * self.size[1] / self.sampling_rate + grids = torch.nn.functional.affine_grid(theta.unsqueeze(0), [1, 1, self.size[1], self.size[0]], align_corners=False) + + # Compute Fourier features. + x = (grids.unsqueeze(3) @ freqs.permute(0, 2, 1).unsqueeze(1).unsqueeze(2)).squeeze(3) # [batch, height, width, channel] + x = x + phases.unsqueeze(1).unsqueeze(2) + x = torch.sin(x * (np.pi * 2)) + x = x * amplitudes.unsqueeze(1).unsqueeze(2) + + # Apply trainable mapping. + weight = self.weight / np.sqrt(self.channels) + x = x @ weight.t() + + # Ensure correct shape. + x = x.permute(0, 3, 1, 2) # [batch, channel, height, width] + misc.assert_shape(x, [w.shape[0], self.channels, int(self.size[1]), int(self.size[0])]) + return x + + def extra_repr(self): + return '\n'.join([ + f'w_dim={self.w_dim:d}, channels={self.channels:d}, size={list(self.size)},', + f'sampling_rate={self.sampling_rate:g}, bandwidth={self.bandwidth:g}']) + +#---------------------------------------------------------------------------- + +@persistence.persistent_class +class SynthesisLayer(torch.nn.Module): + def __init__(self, + w_dim, # Intermediate latent (W) dimensionality. + is_torgb, # Is this the final ToRGB layer? + is_critically_sampled, # Does this layer use critical sampling? + use_fp16, # Does this layer use FP16? + + # Input & output specifications. + in_channels, # Number of input channels. + out_channels, # Number of output channels. + in_size, # Input spatial size: int or [width, height]. + out_size, # Output spatial size: int or [width, height]. + in_sampling_rate, # Input sampling rate (s). + out_sampling_rate, # Output sampling rate (s). + in_cutoff, # Input cutoff frequency (f_c). + out_cutoff, # Output cutoff frequency (f_c). + in_half_width, # Input transition band half-width (f_h). + out_half_width, # Output Transition band half-width (f_h). + + # Hyperparameters. + conv_kernel = 3, # Convolution kernel size. Ignored for final the ToRGB layer. + filter_size = 6, # Low-pass filter size relative to the lower resolution when up/downsampling. + lrelu_upsampling = 2, # Relative sampling rate for leaky ReLU. Ignored for final the ToRGB layer. + use_radial_filters = False, # Use radially symmetric downsampling filter? Ignored for critically sampled layers. + conv_clamp = 256, # Clamp the output to [-X, +X], None = disable clamping. + magnitude_ema_beta = 0.999, # Decay rate for the moving average of input magnitudes. + ): + super().__init__() + self.w_dim = w_dim + self.is_torgb = is_torgb + self.is_critically_sampled = is_critically_sampled + self.use_fp16 = use_fp16 + self.in_channels = in_channels + self.out_channels = out_channels + self.in_size = np.broadcast_to(np.asarray(in_size), [2]) + self.out_size = np.broadcast_to(np.asarray(out_size), [2]) + self.in_sampling_rate = in_sampling_rate + self.out_sampling_rate = out_sampling_rate + self.tmp_sampling_rate = max(in_sampling_rate, out_sampling_rate) * (1 if is_torgb else lrelu_upsampling) + self.in_cutoff = in_cutoff + self.out_cutoff = out_cutoff + self.in_half_width = in_half_width + self.out_half_width = out_half_width + self.conv_kernel = 1 if is_torgb else conv_kernel + self.conv_clamp = conv_clamp + self.magnitude_ema_beta = magnitude_ema_beta + + # Setup parameters and buffers. + self.affine = FullyConnectedLayer(self.w_dim, self.in_channels, bias_init=1) + self.weight = torch.nn.Parameter(torch.randn([self.out_channels, self.in_channels, self.conv_kernel, self.conv_kernel])) + self.bias = torch.nn.Parameter(torch.zeros([self.out_channels])) + self.register_buffer('magnitude_ema', torch.ones([])) + + # Design upsampling filter. + self.up_factor = int(np.rint(self.tmp_sampling_rate / self.in_sampling_rate)) + assert self.in_sampling_rate * self.up_factor == self.tmp_sampling_rate + self.up_taps = filter_size * self.up_factor if self.up_factor > 1 and not self.is_torgb else 1 + self.register_buffer('up_filter', self.design_lowpass_filter( + numtaps=self.up_taps, cutoff=self.in_cutoff, width=self.in_half_width*2, fs=self.tmp_sampling_rate)) + + # Design downsampling filter. + self.down_factor = int(np.rint(self.tmp_sampling_rate / self.out_sampling_rate)) + assert self.out_sampling_rate * self.down_factor == self.tmp_sampling_rate + self.down_taps = filter_size * self.down_factor if self.down_factor > 1 and not self.is_torgb else 1 + self.down_radial = use_radial_filters and not self.is_critically_sampled + self.register_buffer('down_filter', self.design_lowpass_filter( + numtaps=self.down_taps, cutoff=self.out_cutoff, width=self.out_half_width*2, fs=self.tmp_sampling_rate, radial=self.down_radial)) + + # Compute padding. + pad_total = (self.out_size - 1) * self.down_factor + 1 # Desired output size before downsampling. + pad_total -= (self.in_size + self.conv_kernel - 1) * self.up_factor # Input size after upsampling. + pad_total += self.up_taps + self.down_taps - 2 # Size reduction caused by the filters. + pad_lo = (pad_total + self.up_factor) // 2 # Shift sample locations according to the symmetric interpretation (Appendix C.3). + pad_hi = pad_total - pad_lo + self.padding = [int(pad_lo[0]), int(pad_hi[0]), int(pad_lo[1]), int(pad_hi[1])] + + def forward(self, x, w, noise_mode='random', force_fp32=False, update_emas=False): + assert noise_mode in ['random', 'const', 'none'] # unused + misc.assert_shape(x, [None, self.in_channels, int(self.in_size[1]), int(self.in_size[0])]) + misc.assert_shape(w, [x.shape[0], self.w_dim]) + + # Track input magnitude. + if update_emas: + with torch.autograd.profiler.record_function('update_magnitude_ema'): + magnitude_cur = x.detach().to(torch.float32).square().mean() + self.magnitude_ema.copy_(magnitude_cur.lerp(self.magnitude_ema, self.magnitude_ema_beta)) + input_gain = self.magnitude_ema.rsqrt() + + # Execute affine layer. + styles = self.affine(w) + if self.is_torgb: + weight_gain = 1 / np.sqrt(self.in_channels * (self.conv_kernel ** 2)) + styles = styles * weight_gain + + # Execute modulated conv2d. + dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == 'cuda') else torch.float32 + x = modulated_conv2d(x=x.to(dtype), w=self.weight, s=styles, + padding=self.conv_kernel-1, demodulate=(not self.is_torgb), input_gain=input_gain) + + # Execute bias, filtered leaky ReLU, and clamping. + gain = 1 if self.is_torgb else np.sqrt(2) + slope = 1 if self.is_torgb else 0.2 + x = filtered_lrelu.filtered_lrelu(x=x, fu=self.up_filter, fd=self.down_filter, b=self.bias.to(x.dtype), + up=self.up_factor, down=self.down_factor, padding=self.padding, gain=gain, slope=slope, clamp=self.conv_clamp) + + # Ensure correct shape and dtype. + misc.assert_shape(x, [None, self.out_channels, int(self.out_size[1]), int(self.out_size[0])]) + assert x.dtype == dtype + return x + + @staticmethod + def design_lowpass_filter(numtaps, cutoff, width, fs, radial=False): + assert numtaps >= 1 + + # Identity filter. + if numtaps == 1: + return None + + # Separable Kaiser low-pass filter. + if not radial: + f = scipy.signal.firwin(numtaps=numtaps, cutoff=cutoff, width=width, fs=fs) + return torch.as_tensor(f, dtype=torch.float32) + + # Radially symmetric jinc-based filter. + x = (np.arange(numtaps) - (numtaps - 1) / 2) / fs + r = np.hypot(*np.meshgrid(x, x)) + f = scipy.special.j1(2 * cutoff * (np.pi * r)) / (np.pi * r) + beta = scipy.signal.kaiser_beta(scipy.signal.kaiser_atten(numtaps, width / (fs / 2))) + w = np.kaiser(numtaps, beta) + f *= np.outer(w, w) + f /= np.sum(f) + return torch.as_tensor(f, dtype=torch.float32) + + def extra_repr(self): + return '\n'.join([ + f'w_dim={self.w_dim:d}, is_torgb={self.is_torgb},', + f'is_critically_sampled={self.is_critically_sampled}, use_fp16={self.use_fp16},', + f'in_sampling_rate={self.in_sampling_rate:g}, out_sampling_rate={self.out_sampling_rate:g},', + f'in_cutoff={self.in_cutoff:g}, out_cutoff={self.out_cutoff:g},', + f'in_half_width={self.in_half_width:g}, out_half_width={self.out_half_width:g},', + f'in_size={list(self.in_size)}, out_size={list(self.out_size)},', + f'in_channels={self.in_channels:d}, out_channels={self.out_channels:d}']) + +#---------------------------------------------------------------------------- + +@persistence.persistent_class +class SynthesisNetwork(torch.nn.Module): + def __init__(self, + w_dim, # Intermediate latent (W) dimensionality. + img_resolution, # Output image resolution. + img_channels, # Number of color channels. + channel_base = 32768, # Overall multiplier for the number of channels. + channel_max = 512, # Maximum number of channels in any layer. + num_layers = 14, # Total number of layers, excluding Fourier features and ToRGB. + num_critical = 2, # Number of critically sampled layers at the end. + first_cutoff = 2, # Cutoff frequency of the first layer (f_{c,0}). + first_stopband = 2**2.1, # Minimum stopband of the first layer (f_{t,0}). + last_stopband_rel = 2**0.3, # Minimum stopband of the last layer, expressed relative to the cutoff. + margin_size = 10, # Number of additional pixels outside the image. + output_scale = 0.25, # Scale factor for the output image. + num_fp16_res = 4, # Use FP16 for the N highest resolutions. + **layer_kwargs, # Arguments for SynthesisLayer. + ): + super().__init__() + self.w_dim = w_dim + self.num_ws = num_layers + 2 + self.img_resolution = img_resolution + self.img_channels = img_channels + self.num_layers = num_layers + self.num_critical = num_critical + self.margin_size = margin_size + self.output_scale = output_scale + self.num_fp16_res = num_fp16_res + + # Geometric progression of layer cutoffs and min. stopbands. + last_cutoff = self.img_resolution / 2 # f_{c,N} + last_stopband = last_cutoff * last_stopband_rel # f_{t,N} + exponents = np.minimum(np.arange(self.num_layers + 1) / (self.num_layers - self.num_critical), 1) + cutoffs = first_cutoff * (last_cutoff / first_cutoff) ** exponents # f_c[i] + stopbands = first_stopband * (last_stopband / first_stopband) ** exponents # f_t[i] + + # Compute remaining layer parameters. + sampling_rates = np.exp2(np.ceil(np.log2(np.minimum(stopbands * 2, self.img_resolution)))) # s[i] + half_widths = np.maximum(stopbands, sampling_rates / 2) - cutoffs # f_h[i] + sizes = sampling_rates + self.margin_size * 2 + sizes[-2:] = self.img_resolution + channels = np.rint(np.minimum((channel_base / 2) / cutoffs, channel_max)) + channels[-1] = self.img_channels + + # Construct layers. + self.input = SynthesisInput( + w_dim=self.w_dim, channels=int(channels[0]), size=int(sizes[0]), + sampling_rate=sampling_rates[0], bandwidth=cutoffs[0]) + self.layer_names = [] + for idx in range(self.num_layers + 1): + prev = max(idx - 1, 0) + is_torgb = (idx == self.num_layers) + is_critically_sampled = (idx >= self.num_layers - self.num_critical) + use_fp16 = (sampling_rates[idx] * (2 ** self.num_fp16_res) > self.img_resolution) + layer = SynthesisLayer( + w_dim=self.w_dim, is_torgb=is_torgb, is_critically_sampled=is_critically_sampled, use_fp16=use_fp16, + in_channels=int(channels[prev]), out_channels= int(channels[idx]), + in_size=int(sizes[prev]), out_size=int(sizes[idx]), + in_sampling_rate=int(sampling_rates[prev]), out_sampling_rate=int(sampling_rates[idx]), + in_cutoff=cutoffs[prev], out_cutoff=cutoffs[idx], + in_half_width=half_widths[prev], out_half_width=half_widths[idx], + **layer_kwargs) + name = f'L{idx}_{layer.out_size[0]}_{layer.out_channels}' + setattr(self, name, layer) + self.layer_names.append(name) + + def forward(self, ws, **layer_kwargs): + misc.assert_shape(ws, [None, self.num_ws, self.w_dim]) + ws = ws.to(torch.float32).unbind(dim=1) + + # Execute layers. + x = self.input(ws[0]) + for name, w in zip(self.layer_names, ws[1:]): + x = getattr(self, name)(x, w, **layer_kwargs) + if self.output_scale != 1: + x = x * self.output_scale + + # Ensure correct shape and dtype. + misc.assert_shape(x, [None, self.img_channels, self.img_resolution, self.img_resolution]) + x = x.to(torch.float32) + return x + + def extra_repr(self): + return '\n'.join([ + f'w_dim={self.w_dim:d}, num_ws={self.num_ws:d},', + f'img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d},', + f'num_layers={self.num_layers:d}, num_critical={self.num_critical:d},', + f'margin_size={self.margin_size:d}, num_fp16_res={self.num_fp16_res:d}']) + +#---------------------------------------------------------------------------- + +@persistence.persistent_class +class Generator(torch.nn.Module): + def __init__(self, + z_dim, # Input latent (Z) dimensionality. + c_dim, # Conditioning label (C) dimensionality. + w_dim, # Intermediate latent (W) dimensionality. + img_resolution, # Output resolution. + img_channels, # Number of output color channels. + mapping_kwargs = {}, # Arguments for MappingNetwork. + **synthesis_kwargs, # Arguments for SynthesisNetwork. + ): + super().__init__() + self.z_dim = z_dim + self.c_dim = c_dim + self.w_dim = w_dim + self.img_resolution = img_resolution + self.img_channels = img_channels + self.synthesis = SynthesisNetwork(w_dim=w_dim, img_resolution=img_resolution, img_channels=img_channels, **synthesis_kwargs) + self.num_ws = self.synthesis.num_ws + self.mapping = MappingNetwork(z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_ws=self.num_ws, **mapping_kwargs) + + def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False, **synthesis_kwargs): + ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas) + img = self.synthesis(ws, update_emas=update_emas, **synthesis_kwargs) + return img + +#---------------------------------------------------------------------------- diff --git a/training/training_loop.py b/training/training_loop.py new file mode 100644 index 0000000000000000000000000000000000000000..656839cc66370fde536c3aa54fadf623cda33ebc --- /dev/null +++ b/training/training_loop.py @@ -0,0 +1,426 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Main training loop.""" + +import os +import time +import copy +import json +import pickle +import psutil +import PIL.Image +import numpy as np +import torch +import dnnlib +from torch_utils import misc +from torch_utils import training_stats +from torch_utils.ops import conv2d_gradfix +from torch_utils.ops import grid_sample_gradfix + +import legacy +from metrics import metric_main + +#---------------------------------------------------------------------------- + +def setup_snapshot_image_grid(training_set, random_seed=0): + rnd = np.random.RandomState(random_seed) + gw = np.clip(7680 // training_set.image_shape[2], 7, 32) + gh = np.clip(4320 // training_set.image_shape[1], 4, 32) + + # No labels => show random subset of training samples. + if not training_set.has_labels: + all_indices = list(range(len(training_set))) + rnd.shuffle(all_indices) + grid_indices = [all_indices[i % len(all_indices)] for i in range(gw * gh)] + + else: + # Group training samples by label. + label_groups = dict() # label => [idx, ...] + for idx in range(len(training_set)): + label = tuple(training_set.get_details(idx).raw_label.flat[::-1]) + if label not in label_groups: + label_groups[label] = [] + label_groups[label].append(idx) + + # Reorder. + label_order = sorted(label_groups.keys()) + for label in label_order: + rnd.shuffle(label_groups[label]) + + # Organize into grid. + grid_indices = [] + for y in range(gh): + label = label_order[y % len(label_order)] + indices = label_groups[label] + grid_indices += [indices[x % len(indices)] for x in range(gw)] + label_groups[label] = [indices[(i + gw) % len(indices)] for i in range(len(indices))] + + # Load data. + images, labels = zip(*[training_set[i] for i in grid_indices]) + return (gw, gh), np.stack(images), np.stack(labels) + +#---------------------------------------------------------------------------- + +def save_image_grid(img, fname, drange, grid_size): + lo, hi = drange + img = np.asarray(img, dtype=np.float32) + img = (img - lo) * (255 / (hi - lo)) + img = np.rint(img).clip(0, 255).astype(np.uint8) + + gw, gh = grid_size + _N, C, H, W = img.shape + img = img.reshape([gh, gw, C, H, W]) + img = img.transpose(0, 3, 1, 4, 2) + img = img.reshape([gh * H, gw * W, C]) + + assert C in [1, 3] + if C == 1: + PIL.Image.fromarray(img[:, :, 0], 'L').save(fname) + if C == 3: + PIL.Image.fromarray(img, 'RGB').save(fname) + +#---------------------------------------------------------------------------- + +def training_loop( + run_dir = '.', # Output directory. + training_set_kwargs = {}, # Options for training set. + data_loader_kwargs = {}, # Options for torch.utils.data.DataLoader. + G_kwargs = {}, # Options for generator network. + D_kwargs = {}, # Options for discriminator network. + G_opt_kwargs = {}, # Options for generator optimizer. + D_opt_kwargs = {}, # Options for discriminator optimizer. + augment_kwargs = None, # Options for augmentation pipeline. None = disable. + loss_kwargs = {}, # Options for loss function. + metrics = [], # Metrics to evaluate during training. + random_seed = 0, # Global random seed. + num_gpus = 1, # Number of GPUs participating in the training. + rank = 0, # Rank of the current process in [0, num_gpus[. + batch_size = 4, # Total batch size for one training iteration. Can be larger than batch_gpu * num_gpus. + batch_gpu = 4, # Number of samples processed at a time by one GPU. + ema_kimg = 10, # Half-life of the exponential moving average (EMA) of generator weights. + ema_rampup = 0.05, # EMA ramp-up coefficient. None = no rampup. + G_reg_interval = None, # How often to perform regularization for G? None = disable lazy regularization. + D_reg_interval = 16, # How often to perform regularization for D? None = disable lazy regularization. + augment_p = 0, # Initial value of augmentation probability. + ada_target = None, # ADA target value. None = fixed p. + ada_interval = 4, # How often to perform ADA adjustment? + ada_kimg = 500, # ADA adjustment speed, measured in how many kimg it takes for p to increase/decrease by one unit. + total_kimg = 25000, # Total length of the training, measured in thousands of real images. + kimg_per_tick = 4, # Progress snapshot interval. + image_snapshot_ticks = 50, # How often to save image snapshots? None = disable. + network_snapshot_ticks = 50, # How often to save network snapshots? None = disable. + resume_pkl = None, # Network pickle to resume training from. + resume_kimg = 0, # First kimg to report when resuming training. + cudnn_benchmark = True, # Enable torch.backends.cudnn.benchmark? + abort_fn = None, # Callback function for determining whether to abort training. Must return consistent results across ranks. + progress_fn = None, # Callback function for updating training progress. Called for all ranks. +): + # Initialize. + start_time = time.time() + device = torch.device('cuda', rank) + np.random.seed(random_seed * num_gpus + rank) + torch.manual_seed(random_seed * num_gpus + rank) + torch.backends.cudnn.benchmark = cudnn_benchmark # Improves training speed. + torch.backends.cuda.matmul.allow_tf32 = False # Improves numerical accuracy. + torch.backends.cudnn.allow_tf32 = False # Improves numerical accuracy. + conv2d_gradfix.enabled = True # Improves training speed. + grid_sample_gradfix.enabled = True # Avoids errors with the augmentation pipe. + + # Load training set. + if rank == 0: + print('Loading training set...') + training_set = dnnlib.util.construct_class_by_name(**training_set_kwargs) # subclass of training.dataset.Dataset + training_set_sampler = misc.InfiniteSampler(dataset=training_set, rank=rank, num_replicas=num_gpus, seed=random_seed) + training_set_iterator = iter(torch.utils.data.DataLoader(dataset=training_set, sampler=training_set_sampler, batch_size=batch_size//num_gpus, **data_loader_kwargs)) + if rank == 0: + print() + print('Num images: ', len(training_set)) + print('Image shape:', training_set.image_shape) + print('Label shape:', training_set.label_shape) + print() + + # Construct networks. + if rank == 0: + print('Constructing networks...') + common_kwargs = dict(c_dim=training_set.label_dim, img_resolution=training_set.resolution, img_channels=training_set.num_channels) + G = dnnlib.util.construct_class_by_name(**G_kwargs, **common_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module + D = dnnlib.util.construct_class_by_name(**D_kwargs, **common_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module + G_ema = copy.deepcopy(G).eval() + + # Resume from existing pickle. + if (resume_pkl is not None) and (rank == 0): + print(f'Resuming from "{resume_pkl}"') + with dnnlib.util.open_url(resume_pkl) as f: + resume_data = legacy.load_network_pkl(f) + for name, module in [('G', G), ('D', D), ('G_ema', G_ema)]: + misc.copy_params_and_buffers(resume_data[name], module, require_all=False) + + # Print network summary tables. + if rank == 0: + z = torch.empty([batch_gpu, G.z_dim], device=device) + c = torch.empty([batch_gpu, G.c_dim], device=device) + img = misc.print_module_summary(G, [z, c]) + misc.print_module_summary(D, [img, c]) + + # Setup augmentation. + if rank == 0: + print('Setting up augmentation...') + augment_pipe = None + ada_stats = None + if (augment_kwargs is not None) and (augment_p > 0 or ada_target is not None): + augment_pipe = dnnlib.util.construct_class_by_name(**augment_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module + augment_pipe.p.copy_(torch.as_tensor(augment_p)) + if ada_target is not None: + ada_stats = training_stats.Collector(regex='Loss/signs/real') + + # Distribute across GPUs. + if rank == 0: + print(f'Distributing across {num_gpus} GPUs...') + for module in [G, D, G_ema, augment_pipe]: + if module is not None: + for param in misc.params_and_buffers(module): + if param.numel() > 0 and num_gpus > 1: + torch.distributed.broadcast(param, src=0) + + # Setup training phases. + if rank == 0: + print('Setting up training phases...') + loss = dnnlib.util.construct_class_by_name(device=device, G=G, D=D, augment_pipe=augment_pipe, **loss_kwargs) # subclass of training.loss.Loss + phases = [] + for name, module, opt_kwargs, reg_interval in [('G', G, G_opt_kwargs, G_reg_interval), ('D', D, D_opt_kwargs, D_reg_interval)]: + if reg_interval is None: + opt = dnnlib.util.construct_class_by_name(params=module.parameters(), **opt_kwargs) # subclass of torch.optim.Optimizer + phases += [dnnlib.EasyDict(name=name+'both', module=module, opt=opt, interval=1)] + else: # Lazy regularization. + mb_ratio = reg_interval / (reg_interval + 1) + opt_kwargs = dnnlib.EasyDict(opt_kwargs) + opt_kwargs.lr = opt_kwargs.lr * mb_ratio + opt_kwargs.betas = [beta ** mb_ratio for beta in opt_kwargs.betas] + opt = dnnlib.util.construct_class_by_name(module.parameters(), **opt_kwargs) # subclass of torch.optim.Optimizer + phases += [dnnlib.EasyDict(name=name+'main', module=module, opt=opt, interval=1)] + phases += [dnnlib.EasyDict(name=name+'reg', module=module, opt=opt, interval=reg_interval)] + for phase in phases: + phase.start_event = None + phase.end_event = None + if rank == 0: + phase.start_event = torch.cuda.Event(enable_timing=True) + phase.end_event = torch.cuda.Event(enable_timing=True) + + # Export sample images. + grid_size = None + grid_z = None + grid_c = None + if rank == 0: + print('Exporting sample images...') + grid_size, images, labels = setup_snapshot_image_grid(training_set=training_set) + save_image_grid(images, os.path.join(run_dir, 'reals.png'), drange=[0,255], grid_size=grid_size) + grid_z = torch.randn([labels.shape[0], G.z_dim], device=device).split(batch_gpu) + grid_c = torch.from_numpy(labels).to(device).split(batch_gpu) + images = torch.cat([G_ema(z=z, c=c, noise_mode='const').cpu() for z, c in zip(grid_z, grid_c)]).numpy() + save_image_grid(images, os.path.join(run_dir, 'fakes_init.png'), drange=[-1,1], grid_size=grid_size) + + # Initialize logs. + if rank == 0: + print('Initializing logs...') + stats_collector = training_stats.Collector(regex='.*') + stats_metrics = dict() + stats_jsonl = None + stats_tfevents = None + if rank == 0: + stats_jsonl = open(os.path.join(run_dir, 'stats.jsonl'), 'wt') + try: + import torch.utils.tensorboard as tensorboard + stats_tfevents = tensorboard.SummaryWriter(run_dir) + except ImportError as err: + print('Skipping tfevents export:', err) + + # Train. + if rank == 0: + print(f'Training for {total_kimg} kimg...') + print() + cur_nimg = resume_kimg * 1000 + cur_tick = 0 + tick_start_nimg = cur_nimg + tick_start_time = time.time() + maintenance_time = tick_start_time - start_time + batch_idx = 0 + if progress_fn is not None: + progress_fn(0, total_kimg) + while True: + + # Fetch training data. + with torch.autograd.profiler.record_function('data_fetch'): + phase_real_img, phase_real_c = next(training_set_iterator) + phase_real_img = (phase_real_img.to(device).to(torch.float32) / 127.5 - 1).split(batch_gpu) + phase_real_c = phase_real_c.to(device).split(batch_gpu) + all_gen_z = torch.randn([len(phases) * batch_size, G.z_dim], device=device) + all_gen_z = [phase_gen_z.split(batch_gpu) for phase_gen_z in all_gen_z.split(batch_size)] + all_gen_c = [training_set.get_label(np.random.randint(len(training_set))) for _ in range(len(phases) * batch_size)] + all_gen_c = torch.from_numpy(np.stack(all_gen_c)).pin_memory().to(device) + all_gen_c = [phase_gen_c.split(batch_gpu) for phase_gen_c in all_gen_c.split(batch_size)] + + # Execute training phases. + for phase, phase_gen_z, phase_gen_c in zip(phases, all_gen_z, all_gen_c): + if batch_idx % phase.interval != 0: + continue + if phase.start_event is not None: + phase.start_event.record(torch.cuda.current_stream(device)) + + # Accumulate gradients. + phase.opt.zero_grad(set_to_none=True) + phase.module.requires_grad_(True) + for real_img, real_c, gen_z, gen_c in zip(phase_real_img, phase_real_c, phase_gen_z, phase_gen_c): + loss.accumulate_gradients(phase=phase.name, real_img=real_img, real_c=real_c, gen_z=gen_z, gen_c=gen_c, gain=phase.interval, cur_nimg=cur_nimg) + phase.module.requires_grad_(False) + + # Update weights. + with torch.autograd.profiler.record_function(phase.name + '_opt'): + params = [param for param in phase.module.parameters() if param.numel() > 0 and param.grad is not None] + if len(params) > 0: + flat = torch.cat([param.grad.flatten() for param in params]) + if num_gpus > 1: + torch.distributed.all_reduce(flat) + flat /= num_gpus + misc.nan_to_num(flat, nan=0, posinf=1e5, neginf=-1e5, out=flat) + grads = flat.split([param.numel() for param in params]) + for param, grad in zip(params, grads): + param.grad = grad.reshape(param.shape) + phase.opt.step() + + # Phase done. + if phase.end_event is not None: + phase.end_event.record(torch.cuda.current_stream(device)) + + # Update G_ema. + with torch.autograd.profiler.record_function('Gema'): + ema_nimg = ema_kimg * 1000 + if ema_rampup is not None: + ema_nimg = min(ema_nimg, cur_nimg * ema_rampup) + ema_beta = 0.5 ** (batch_size / max(ema_nimg, 1e-8)) + for p_ema, p in zip(G_ema.parameters(), G.parameters()): + p_ema.copy_(p.lerp(p_ema, ema_beta)) + for b_ema, b in zip(G_ema.buffers(), G.buffers()): + b_ema.copy_(b) + + # Update state. + cur_nimg += batch_size + batch_idx += 1 + + # Execute ADA heuristic. + if (ada_stats is not None) and (batch_idx % ada_interval == 0): + ada_stats.update() + adjust = np.sign(ada_stats['Loss/signs/real'] - ada_target) * (batch_size * ada_interval) / (ada_kimg * 1000) + augment_pipe.p.copy_((augment_pipe.p + adjust).max(misc.constant(0, device=device))) + + # Perform maintenance tasks once per tick. + done = (cur_nimg >= total_kimg * 1000) + if (not done) and (cur_tick != 0) and (cur_nimg < tick_start_nimg + kimg_per_tick * 1000): + continue + + # Print status line, accumulating the same information in training_stats. + tick_end_time = time.time() + fields = [] + fields += [f"tick {training_stats.report0('Progress/tick', cur_tick):<5d}"] + fields += [f"kimg {training_stats.report0('Progress/kimg', cur_nimg / 1e3):<8.1f}"] + fields += [f"time {dnnlib.util.format_time(training_stats.report0('Timing/total_sec', tick_end_time - start_time)):<12s}"] + fields += [f"sec/tick {training_stats.report0('Timing/sec_per_tick', tick_end_time - tick_start_time):<7.1f}"] + fields += [f"sec/kimg {training_stats.report0('Timing/sec_per_kimg', (tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg) * 1e3):<7.2f}"] + fields += [f"maintenance {training_stats.report0('Timing/maintenance_sec', maintenance_time):<6.1f}"] + fields += [f"cpumem {training_stats.report0('Resources/cpu_mem_gb', psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}"] + fields += [f"gpumem {training_stats.report0('Resources/peak_gpu_mem_gb', torch.cuda.max_memory_allocated(device) / 2**30):<6.2f}"] + fields += [f"reserved {training_stats.report0('Resources/peak_gpu_mem_reserved_gb', torch.cuda.max_memory_reserved(device) / 2**30):<6.2f}"] + torch.cuda.reset_peak_memory_stats() + fields += [f"augment {training_stats.report0('Progress/augment', float(augment_pipe.p.cpu()) if augment_pipe is not None else 0):.3f}"] + training_stats.report0('Timing/total_hours', (tick_end_time - start_time) / (60 * 60)) + training_stats.report0('Timing/total_days', (tick_end_time - start_time) / (24 * 60 * 60)) + if rank == 0: + print(' '.join(fields)) + + # Check for abort. + if (not done) and (abort_fn is not None) and abort_fn(): + done = True + if rank == 0: + print() + print('Aborting...') + + # Save image snapshot. + if (rank == 0) and (image_snapshot_ticks is not None) and (done or cur_tick % image_snapshot_ticks == 0): + images = torch.cat([G_ema(z=z, c=c, noise_mode='const').cpu() for z, c in zip(grid_z, grid_c)]).numpy() + save_image_grid(images, os.path.join(run_dir, f'fakes{cur_nimg//1000:06d}.png'), drange=[-1,1], grid_size=grid_size) + + # Save network snapshot. + snapshot_pkl = None + snapshot_data = None + if (network_snapshot_ticks is not None) and (done or cur_tick % network_snapshot_ticks == 0): + snapshot_data = dict(training_set_kwargs=dict(training_set_kwargs)) + for name, module in [('G', G), ('D', D), ('G_ema', G_ema), ('augment_pipe', augment_pipe)]: + if module is not None: + if num_gpus > 1: + misc.check_ddp_consistency(module, ignore_regex=r'.*\.[^.]+_(avg|ema)') + module = copy.deepcopy(module).eval().requires_grad_(False).cpu() + snapshot_data[name] = module + del module # conserve memory + snapshot_pkl = os.path.join(run_dir, f'network-snapshot-{cur_nimg//1000:06d}.pkl') + if rank == 0: + with open(snapshot_pkl, 'wb') as f: + pickle.dump(snapshot_data, f) + + # Evaluate metrics. + if (snapshot_data is not None) and (len(metrics) > 0): + if rank == 0: + print('Evaluating metrics...') + for metric in metrics: + result_dict = metric_main.calc_metric(metric=metric, G=snapshot_data['G_ema'], + dataset_kwargs=training_set_kwargs, num_gpus=num_gpus, rank=rank, device=device) + if rank == 0: + metric_main.report_metric(result_dict, run_dir=run_dir, snapshot_pkl=snapshot_pkl) + stats_metrics.update(result_dict.results) + del snapshot_data # conserve memory + + # Collect statistics. + for phase in phases: + value = [] + if (phase.start_event is not None) and (phase.end_event is not None): + phase.end_event.synchronize() + value = phase.start_event.elapsed_time(phase.end_event) + training_stats.report0('Timing/' + phase.name, value) + stats_collector.update() + stats_dict = stats_collector.as_dict() + + # Update logs. + timestamp = time.time() + if stats_jsonl is not None: + fields = dict(stats_dict, timestamp=timestamp) + stats_jsonl.write(json.dumps(fields) + '\n') + stats_jsonl.flush() + if stats_tfevents is not None: + global_step = int(cur_nimg / 1e3) + walltime = timestamp - start_time + for name, value in stats_dict.items(): + stats_tfevents.add_scalar(name, value.mean, global_step=global_step, walltime=walltime) + for name, value in stats_metrics.items(): + stats_tfevents.add_scalar(f'Metrics/{name}', value, global_step=global_step, walltime=walltime) + stats_tfevents.flush() + if progress_fn is not None: + progress_fn(cur_nimg // 1000, total_kimg) + + # Update state. + cur_tick += 1 + tick_start_nimg = cur_nimg + tick_start_time = time.time() + maintenance_time = tick_start_time - tick_end_time + if done: + break + + # Done. + if rank == 0: + print() + print('Exiting...') + +#---------------------------------------------------------------------------- diff --git a/visualizer.py b/visualizer.py new file mode 100644 index 0000000000000000000000000000000000000000..4168447d7d6ec7481fc76b889d498ac009dc5549 --- /dev/null +++ b/visualizer.py @@ -0,0 +1,334 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +import click +import os + +import multiprocessing +import numpy as np +import imgui +import dnnlib +from gui_utils import imgui_window +from gui_utils import imgui_utils +from gui_utils import gl_utils +from gui_utils import text_utils +from viz import renderer +from viz import pickle_widget +from viz import latent_widget +from viz import stylemix_widget +from viz import trunc_noise_widget +from viz import performance_widget +from viz import capture_widget +from viz import layer_widget +from viz import equivariance_widget + +#---------------------------------------------------------------------------- + +class Visualizer(imgui_window.ImguiWindow): + def __init__(self, capture_dir=None): + super().__init__(title='GAN Visualizer', window_width=3840, window_height=2160) + + # Internals. + self._last_error_print = None + self._async_renderer = AsyncRenderer() + self._defer_rendering = 0 + self._tex_img = None + self._tex_obj = None + + # Widget interface. + self.args = dnnlib.EasyDict() + self.result = dnnlib.EasyDict() + self.pane_w = 0 + self.label_w = 0 + self.button_w = 0 + + # Widgets. + self.pickle_widget = pickle_widget.PickleWidget(self) + self.latent_widget = latent_widget.LatentWidget(self) + self.stylemix_widget = stylemix_widget.StyleMixingWidget(self) + self.trunc_noise_widget = trunc_noise_widget.TruncationNoiseWidget(self) + self.perf_widget = performance_widget.PerformanceWidget(self) + self.capture_widget = capture_widget.CaptureWidget(self) + self.layer_widget = layer_widget.LayerWidget(self) + self.eq_widget = equivariance_widget.EquivarianceWidget(self) + + if capture_dir is not None: + self.capture_widget.path = capture_dir + + # Initialize window. + self.set_position(0, 0) + self._adjust_font_size() + self.skip_frame() # Layout may change after first frame. + + def close(self): + super().close() + if self._async_renderer is not None: + self._async_renderer.close() + self._async_renderer = None + + def add_recent_pickle(self, pkl, ignore_errors=False): + self.pickle_widget.add_recent(pkl, ignore_errors=ignore_errors) + + def load_pickle(self, pkl, ignore_errors=False): + self.pickle_widget.load(pkl, ignore_errors=ignore_errors) + + def print_error(self, error): + error = str(error) + if error != self._last_error_print: + print('\n' + error + '\n') + self._last_error_print = error + + def defer_rendering(self, num_frames=1): + self._defer_rendering = max(self._defer_rendering, num_frames) + + def clear_result(self): + self._async_renderer.clear_result() + + def set_async(self, is_async): + if is_async != self._async_renderer.is_async: + self._async_renderer.set_async(is_async) + self.clear_result() + if 'image' in self.result: + self.result.message = 'Switching rendering process...' + self.defer_rendering() + + def _adjust_font_size(self): + old = self.font_size + self.set_font_size(min(self.content_width / 120, self.content_height / 60)) + if self.font_size != old: + self.skip_frame() # Layout changed. + + def draw_frame(self): + self.begin_frame() + self.args = dnnlib.EasyDict() + self.pane_w = self.font_size * 45 + self.button_w = self.font_size * 5 + self.label_w = round(self.font_size * 4.5) + + # Detect mouse dragging in the result area. + dragging, dx, dy = imgui_utils.drag_hidden_window('##result_area', x=self.pane_w, y=0, width=self.content_width-self.pane_w, height=self.content_height) + if dragging: + self.latent_widget.drag(dx, dy) + + # Begin control pane. + imgui.set_next_window_position(0, 0) + imgui.set_next_window_size(self.pane_w, self.content_height) + imgui.begin('##control_pane', closable=False, flags=(imgui.WINDOW_NO_TITLE_BAR | imgui.WINDOW_NO_RESIZE | imgui.WINDOW_NO_MOVE)) + + # Widgets. + expanded, _visible = imgui_utils.collapsing_header('Network & latent', default=True) + self.pickle_widget(expanded) + self.latent_widget(expanded) + self.stylemix_widget(expanded) + self.trunc_noise_widget(expanded) + expanded, _visible = imgui_utils.collapsing_header('Performance & capture', default=True) + self.perf_widget(expanded) + self.capture_widget(expanded) + expanded, _visible = imgui_utils.collapsing_header('Layers & channels', default=True) + self.layer_widget(expanded) + with imgui_utils.grayed_out(not self.result.get('has_input_transform', False)): + expanded, _visible = imgui_utils.collapsing_header('Equivariance', default=True) + self.eq_widget(expanded) + + # Render. + if self.is_skipping_frames(): + pass + elif self._defer_rendering > 0: + self._defer_rendering -= 1 + elif self.args.pkl is not None: + self._async_renderer.set_args(**self.args) + result = self._async_renderer.get_result() + if result is not None: + self.result = result + + # Display. + max_w = self.content_width - self.pane_w + max_h = self.content_height + pos = np.array([self.pane_w + max_w / 2, max_h / 2]) + if 'image' in self.result: + if self._tex_img is not self.result.image: + self._tex_img = self.result.image + if self._tex_obj is None or not self._tex_obj.is_compatible(image=self._tex_img): + self._tex_obj = gl_utils.Texture(image=self._tex_img, bilinear=False, mipmap=False) + else: + self._tex_obj.update(self._tex_img) + zoom = min(max_w / self._tex_obj.width, max_h / self._tex_obj.height) + zoom = np.floor(zoom) if zoom >= 1 else zoom + self._tex_obj.draw(pos=pos, zoom=zoom, align=0.5, rint=True) + if 'error' in self.result: + self.print_error(self.result.error) + if 'message' not in self.result: + self.result.message = str(self.result.error) + if 'message' in self.result: + tex = text_utils.get_texture(self.result.message, size=self.font_size, max_width=max_w, max_height=max_h, outline=2) + tex.draw(pos=pos, align=0.5, rint=True, color=1) + + # End frame. + self._adjust_font_size() + imgui.end() + self.end_frame() + +#---------------------------------------------------------------------------- + +class AsyncRenderer: + def __init__(self): + self._closed = False + self._is_async = False + self._cur_args = None + self._cur_result = None + self._cur_stamp = 0 + self._renderer_obj = None + self._args_queue = None + self._result_queue = None + self._process = None + + def close(self): + self._closed = True + self._renderer_obj = None + if self._process is not None: + self._process.terminate() + self._process = None + self._args_queue = None + self._result_queue = None + + @property + def is_async(self): + return self._is_async + + def set_async(self, is_async): + self._is_async = is_async + + def set_args(self, **args): + assert not self._closed + if args != self._cur_args: + if self._is_async: + self._set_args_async(**args) + else: + self._set_args_sync(**args) + self._cur_args = args + + def _set_args_async(self, **args): + if self._process is None: + self._args_queue = multiprocessing.Queue() + self._result_queue = multiprocessing.Queue() + try: + multiprocessing.set_start_method('spawn') + except RuntimeError: + pass + self._process = multiprocessing.Process(target=self._process_fn, args=(self._args_queue, self._result_queue), daemon=True) + self._process.start() + self._args_queue.put([args, self._cur_stamp]) + + def _set_args_sync(self, **args): + if self._renderer_obj is None: + self._renderer_obj = renderer.Renderer() + self._cur_result = self._renderer_obj.render(**args) + + def get_result(self): + assert not self._closed + if self._result_queue is not None: + while self._result_queue.qsize() > 0: + result, stamp = self._result_queue.get() + if stamp == self._cur_stamp: + self._cur_result = result + return self._cur_result + + def clear_result(self): + assert not self._closed + self._cur_args = None + self._cur_result = None + self._cur_stamp += 1 + + @staticmethod + def _process_fn(args_queue, result_queue): + renderer_obj = renderer.Renderer() + cur_args = None + cur_stamp = None + while True: + args, stamp = args_queue.get() + while args_queue.qsize() > 0: + args, stamp = args_queue.get() + if args != cur_args or stamp != cur_stamp: + result = renderer_obj.render(**args) + if 'error' in result: + result.error = renderer.CapturedException(result.error) + result_queue.put([result, stamp]) + cur_args = args + cur_stamp = stamp + +#---------------------------------------------------------------------------- + +@click.command() +@click.argument('pkls', metavar='PATH', nargs=-1) +@click.option('--capture-dir', help='Where to save screenshot captures', metavar='PATH', default=None) +@click.option('--browse-dir', help='Specify model path for the \'Browse...\' button', metavar='PATH') +def main( + pkls, + capture_dir, + browse_dir +): + """Interactive model visualizer. + + Optional PATH argument can be used specify which .pkl file to load. + """ + viz = Visualizer(capture_dir=capture_dir) + + if browse_dir is not None: + viz.pickle_widget.search_dirs = [browse_dir] + + # List pickles. + if len(pkls) > 0: + for pkl in pkls: + viz.add_recent_pickle(pkl) + viz.load_pickle(pkls[0]) + else: + pretrained = [ + 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-afhqv2-512x512.pkl', + 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-ffhq-1024x1024.pkl', + 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-ffhqu-1024x1024.pkl', + 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-ffhqu-256x256.pkl', + 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-metfaces-1024x1024.pkl', + 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-metfacesu-1024x1024.pkl', + 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-afhqv2-512x512.pkl', + 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-ffhq-1024x1024.pkl', + 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-ffhqu-1024x1024.pkl', + 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-ffhqu-256x256.pkl', + 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-metfaces-1024x1024.pkl', + 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-metfacesu-1024x1024.pkl', + 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-afhqcat-512x512.pkl', + 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-afhqdog-512x512.pkl', + 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-afhqv2-512x512.pkl', + 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-afhqwild-512x512.pkl', + 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-brecahad-512x512.pkl', + 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-celebahq-256x256.pkl', + 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-cifar10-32x32.pkl', + 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-ffhq-1024x1024.pkl', + 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-ffhq-256x256.pkl', + 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-ffhq-512x512.pkl', + 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-ffhqu-1024x1024.pkl', + 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-ffhqu-256x256.pkl', + 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-lsundog-256x256.pkl', + 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-metfaces-1024x1024.pkl', + 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-metfacesu-1024x1024.pkl' + ] + + # Populate recent pickles list with pretrained model URLs. + for url in pretrained: + viz.add_recent_pickle(url) + + # Run. + while not viz.should_close(): + viz.draw_frame() + viz.close() + +#---------------------------------------------------------------------------- + +if __name__ == "__main__": + main() + +#---------------------------------------------------------------------------- diff --git a/viz/__init__.py b/viz/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8dd34882519598c472f1224cfe68c9ff6952ce69 --- /dev/null +++ b/viz/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +# empty diff --git a/viz/capture_widget.py b/viz/capture_widget.py new file mode 100644 index 0000000000000000000000000000000000000000..311ae885cf2413f37f716d2c1b71cd8bec1ca889 --- /dev/null +++ b/viz/capture_widget.py @@ -0,0 +1,87 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +import os +import re +import numpy as np +import imgui +import PIL.Image +from gui_utils import imgui_utils +from . import renderer + +#---------------------------------------------------------------------------- + +class CaptureWidget: + def __init__(self, viz): + self.viz = viz + self.path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '_screenshots')) + self.dump_image = False + self.dump_gui = False + self.defer_frames = 0 + self.disabled_time = 0 + + def dump_png(self, image): + viz = self.viz + try: + _height, _width, channels = image.shape + assert channels in [1, 3] + assert image.dtype == np.uint8 + os.makedirs(self.path, exist_ok=True) + file_id = 0 + for entry in os.scandir(self.path): + if entry.is_file(): + match = re.fullmatch(r'(\d+).*', entry.name) + if match: + file_id = max(file_id, int(match.group(1)) + 1) + if channels == 1: + pil_image = PIL.Image.fromarray(image[:, :, 0], 'L') + else: + pil_image = PIL.Image.fromarray(image, 'RGB') + pil_image.save(os.path.join(self.path, f'{file_id:05d}.png')) + except: + viz.result.error = renderer.CapturedException() + + @imgui_utils.scoped_by_object_id + def __call__(self, show=True): + viz = self.viz + if show: + with imgui_utils.grayed_out(self.disabled_time != 0): + imgui.text('Capture') + imgui.same_line(viz.label_w) + _changed, self.path = imgui_utils.input_text('##path', self.path, 1024, + flags=(imgui.INPUT_TEXT_AUTO_SELECT_ALL | imgui.INPUT_TEXT_ENTER_RETURNS_TRUE), + width=(-1 - viz.button_w * 2 - viz.spacing * 2), + help_text='PATH') + if imgui.is_item_hovered() and not imgui.is_item_active() and self.path != '': + imgui.set_tooltip(self.path) + imgui.same_line() + if imgui_utils.button('Save image', width=viz.button_w, enabled=(self.disabled_time == 0 and 'image' in viz.result)): + self.dump_image = True + self.defer_frames = 2 + self.disabled_time = 0.5 + imgui.same_line() + if imgui_utils.button('Save GUI', width=-1, enabled=(self.disabled_time == 0)): + self.dump_gui = True + self.defer_frames = 2 + self.disabled_time = 0.5 + + self.disabled_time = max(self.disabled_time - viz.frame_delta, 0) + if self.defer_frames > 0: + self.defer_frames -= 1 + elif self.dump_image: + if 'image' in viz.result: + self.dump_png(viz.result.image) + self.dump_image = False + elif self.dump_gui: + viz.capture_next_frame() + self.dump_gui = False + captured_frame = viz.pop_captured_frame() + if captured_frame is not None: + self.dump_png(captured_frame) + +#---------------------------------------------------------------------------- diff --git a/viz/equivariance_widget.py b/viz/equivariance_widget.py new file mode 100644 index 0000000000000000000000000000000000000000..49ef74fbfd96b92758df6128ffb92326ea87aac0 --- /dev/null +++ b/viz/equivariance_widget.py @@ -0,0 +1,115 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +import numpy as np +import imgui +import dnnlib +from gui_utils import imgui_utils + +#---------------------------------------------------------------------------- + +class EquivarianceWidget: + def __init__(self, viz): + self.viz = viz + self.xlate = dnnlib.EasyDict(x=0, y=0, anim=False, round=False, speed=1e-2) + self.xlate_def = dnnlib.EasyDict(self.xlate) + self.rotate = dnnlib.EasyDict(val=0, anim=False, speed=5e-3) + self.rotate_def = dnnlib.EasyDict(self.rotate) + self.opts = dnnlib.EasyDict(untransform=False) + self.opts_def = dnnlib.EasyDict(self.opts) + + @imgui_utils.scoped_by_object_id + def __call__(self, show=True): + viz = self.viz + if show: + imgui.text('Translate') + imgui.same_line(viz.label_w) + with imgui_utils.item_width(viz.font_size * 8): + _changed, (self.xlate.x, self.xlate.y) = imgui.input_float2('##xlate', self.xlate.x, self.xlate.y, format='%.4f') + imgui.same_line(viz.label_w + viz.font_size * 8 + viz.spacing) + _clicked, dragging, dx, dy = imgui_utils.drag_button('Drag fast##xlate', width=viz.button_w) + if dragging: + self.xlate.x += dx / viz.font_size * 2e-2 + self.xlate.y += dy / viz.font_size * 2e-2 + imgui.same_line() + _clicked, dragging, dx, dy = imgui_utils.drag_button('Drag slow##xlate', width=viz.button_w) + if dragging: + self.xlate.x += dx / viz.font_size * 4e-4 + self.xlate.y += dy / viz.font_size * 4e-4 + imgui.same_line() + _clicked, self.xlate.anim = imgui.checkbox('Anim##xlate', self.xlate.anim) + imgui.same_line() + _clicked, self.xlate.round = imgui.checkbox('Round##xlate', self.xlate.round) + imgui.same_line() + with imgui_utils.item_width(-1 - viz.button_w - viz.spacing), imgui_utils.grayed_out(not self.xlate.anim): + changed, speed = imgui.slider_float('##xlate_speed', self.xlate.speed, 0, 0.5, format='Speed %.5f', power=5) + if changed: + self.xlate.speed = speed + imgui.same_line() + if imgui_utils.button('Reset##xlate', width=-1, enabled=(self.xlate != self.xlate_def)): + self.xlate = dnnlib.EasyDict(self.xlate_def) + + if show: + imgui.text('Rotate') + imgui.same_line(viz.label_w) + with imgui_utils.item_width(viz.font_size * 8): + _changed, self.rotate.val = imgui.input_float('##rotate', self.rotate.val, format='%.4f') + imgui.same_line(viz.label_w + viz.font_size * 8 + viz.spacing) + _clicked, dragging, dx, _dy = imgui_utils.drag_button('Drag fast##rotate', width=viz.button_w) + if dragging: + self.rotate.val += dx / viz.font_size * 2e-2 + imgui.same_line() + _clicked, dragging, dx, _dy = imgui_utils.drag_button('Drag slow##rotate', width=viz.button_w) + if dragging: + self.rotate.val += dx / viz.font_size * 4e-4 + imgui.same_line() + _clicked, self.rotate.anim = imgui.checkbox('Anim##rotate', self.rotate.anim) + imgui.same_line() + with imgui_utils.item_width(-1 - viz.button_w - viz.spacing), imgui_utils.grayed_out(not self.rotate.anim): + changed, speed = imgui.slider_float('##rotate_speed', self.rotate.speed, -1, 1, format='Speed %.4f', power=3) + if changed: + self.rotate.speed = speed + imgui.same_line() + if imgui_utils.button('Reset##rotate', width=-1, enabled=(self.rotate != self.rotate_def)): + self.rotate = dnnlib.EasyDict(self.rotate_def) + + if show: + imgui.set_cursor_pos_x(imgui.get_content_region_max()[0] - 1 - viz.button_w*1 - viz.font_size*16) + _clicked, self.opts.untransform = imgui.checkbox('Untransform', self.opts.untransform) + imgui.same_line(imgui.get_content_region_max()[0] - 1 - viz.button_w) + if imgui_utils.button('Reset##opts', width=-1, enabled=(self.opts != self.opts_def)): + self.opts = dnnlib.EasyDict(self.opts_def) + + if self.xlate.anim: + c = np.array([self.xlate.x, self.xlate.y], dtype=np.float64) + t = c.copy() + if np.max(np.abs(t)) < 1e-4: + t += 1 + t *= 0.1 / np.hypot(*t) + t += c[::-1] * [1, -1] + d = t - c + d *= (viz.frame_delta * self.xlate.speed) / np.hypot(*d) + self.xlate.x += d[0] + self.xlate.y += d[1] + + if self.rotate.anim: + self.rotate.val += viz.frame_delta * self.rotate.speed + + pos = np.array([self.xlate.x, self.xlate.y], dtype=np.float64) + if self.xlate.round and 'img_resolution' in viz.result: + pos = np.rint(pos * viz.result.img_resolution) / viz.result.img_resolution + angle = self.rotate.val * np.pi * 2 + + viz.args.input_transform = [ + [np.cos(angle), np.sin(angle), pos[0]], + [-np.sin(angle), np.cos(angle), pos[1]], + [0, 0, 1]] + + viz.args.update(untransform=self.opts.untransform) + +#---------------------------------------------------------------------------- diff --git a/viz/latent_widget.py b/viz/latent_widget.py new file mode 100644 index 0000000000000000000000000000000000000000..ca6fa62d414183a14aae16cf7551fb1361157dc1 --- /dev/null +++ b/viz/latent_widget.py @@ -0,0 +1,78 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +import numpy as np +import imgui +import dnnlib +from gui_utils import imgui_utils + +#---------------------------------------------------------------------------- + +class LatentWidget: + def __init__(self, viz): + self.viz = viz + self.latent = dnnlib.EasyDict(x=0, y=0, anim=False, speed=0.25) + self.latent_def = dnnlib.EasyDict(self.latent) + self.step_y = 100 + + def drag(self, dx, dy): + viz = self.viz + self.latent.x += dx / viz.font_size * 4e-2 + self.latent.y += dy / viz.font_size * 4e-2 + + @imgui_utils.scoped_by_object_id + def __call__(self, show=True): + viz = self.viz + if show: + imgui.text('Latent') + imgui.same_line(viz.label_w) + seed = round(self.latent.x) + round(self.latent.y) * self.step_y + with imgui_utils.item_width(viz.font_size * 8): + changed, seed = imgui.input_int('##seed', seed) + if changed: + self.latent.x = seed + self.latent.y = 0 + imgui.same_line(viz.label_w + viz.font_size * 8 + viz.spacing) + frac_x = self.latent.x - round(self.latent.x) + frac_y = self.latent.y - round(self.latent.y) + with imgui_utils.item_width(viz.font_size * 5): + changed, (new_frac_x, new_frac_y) = imgui.input_float2('##frac', frac_x, frac_y, format='%+.2f', flags=imgui.INPUT_TEXT_ENTER_RETURNS_TRUE) + if changed: + self.latent.x += new_frac_x - frac_x + self.latent.y += new_frac_y - frac_y + imgui.same_line(viz.label_w + viz.font_size * 13 + viz.spacing * 2) + _clicked, dragging, dx, dy = imgui_utils.drag_button('Drag', width=viz.button_w) + if dragging: + self.drag(dx, dy) + imgui.same_line(viz.label_w + viz.font_size * 13 + viz.button_w + viz.spacing * 3) + _clicked, self.latent.anim = imgui.checkbox('Anim', self.latent.anim) + imgui.same_line(round(viz.font_size * 27.7)) + with imgui_utils.item_width(-1 - viz.button_w * 2 - viz.spacing * 2), imgui_utils.grayed_out(not self.latent.anim): + changed, speed = imgui.slider_float('##speed', self.latent.speed, -5, 5, format='Speed %.3f', power=3) + if changed: + self.latent.speed = speed + imgui.same_line() + snapped = dnnlib.EasyDict(self.latent, x=round(self.latent.x), y=round(self.latent.y)) + if imgui_utils.button('Snap', width=viz.button_w, enabled=(self.latent != snapped)): + self.latent = snapped + imgui.same_line() + if imgui_utils.button('Reset', width=-1, enabled=(self.latent != self.latent_def)): + self.latent = dnnlib.EasyDict(self.latent_def) + + if self.latent.anim: + self.latent.x += viz.frame_delta * self.latent.speed + viz.args.w0_seeds = [] # [[seed, weight], ...] + for ofs_x, ofs_y in [[0, 0], [1, 0], [0, 1], [1, 1]]: + seed_x = np.floor(self.latent.x) + ofs_x + seed_y = np.floor(self.latent.y) + ofs_y + seed = (int(seed_x) + int(seed_y) * self.step_y) & ((1 << 32) - 1) + weight = (1 - abs(self.latent.x - seed_x)) * (1 - abs(self.latent.y - seed_y)) + if weight > 0: + viz.args.w0_seeds.append([seed, weight]) + +#---------------------------------------------------------------------------- diff --git a/viz/layer_widget.py b/viz/layer_widget.py new file mode 100644 index 0000000000000000000000000000000000000000..52758a9539bc0c0ee273e4dc3529a981fd0a6d13 --- /dev/null +++ b/viz/layer_widget.py @@ -0,0 +1,183 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +import imgui +from gui_utils import imgui_utils + +#---------------------------------------------------------------------------- + +class LayerWidget: + def __init__(self, viz): + self.viz = viz + self.prev_layers = None + self.cur_layer = None + self.sel_channels = 3 + self.base_channel = 0 + self.img_scale_db = 0 + self.img_normalize = False + self.fft_show = False + self.fft_all = True + self.fft_range_db = 50 + self.fft_beta = 8 + self.refocus = False + + @imgui_utils.scoped_by_object_id + def __call__(self, show=True): + viz = self.viz + layers = viz.result.get('layers', []) + if self.prev_layers != layers: + self.prev_layers = layers + self.refocus = True + layer = ([layer for layer in layers if layer.name == self.cur_layer] + [None])[0] + if layer is None and len(layers) > 0: + layer = layers[-1] + self.cur_layer = layer.name + num_channels = layer.shape[1] if layer is not None else 0 + base_channel_max = max(num_channels - self.sel_channels, 0) + + if show: + bg_color = [0.16, 0.29, 0.48, 0.2] + dim_color = list(imgui.get_style().colors[imgui.COLOR_TEXT]) + dim_color[-1] *= 0.5 + + # Begin list. + width = viz.font_size * 28 + height = imgui.get_text_line_height_with_spacing() * 12 + viz.spacing + imgui.push_style_var(imgui.STYLE_FRAME_PADDING, [0, 0]) + imgui.push_style_color(imgui.COLOR_CHILD_BACKGROUND, *bg_color) + imgui.push_style_color(imgui.COLOR_HEADER, 0, 0, 0, 0) + imgui.push_style_color(imgui.COLOR_HEADER_HOVERED, 0.16, 0.29, 0.48, 0.5) + imgui.push_style_color(imgui.COLOR_HEADER_ACTIVE, 0.16, 0.29, 0.48, 0.9) + imgui.begin_child('##list', width=width, height=height, border=True, flags=imgui.WINDOW_ALWAYS_VERTICAL_SCROLLBAR) + + # List items. + for layer in layers: + selected = (self.cur_layer == layer.name) + _opened, selected = imgui.selectable(f'##{layer.name}_selectable', selected) + imgui.same_line(viz.spacing) + _clicked, selected = imgui.checkbox(f'{layer.name}##radio', selected) + if selected: + self.cur_layer = layer.name + if self.refocus: + imgui.set_scroll_here() + viz.skip_frame() # Focus will change on next frame. + self.refocus = False + imgui.same_line(width - viz.font_size * 13) + imgui.text_colored('x'.join(str(x) for x in layer.shape[2:]), *dim_color) + imgui.same_line(width - viz.font_size * 8) + imgui.text_colored(str(layer.shape[1]), *dim_color) + imgui.same_line(width - viz.font_size * 5) + imgui.text_colored(layer.dtype, *dim_color) + + # End list. + if len(layers) == 0: + imgui.text_colored('No layers found', *dim_color) + imgui.end_child() + imgui.pop_style_color(4) + imgui.pop_style_var(1) + + # Begin options. + imgui.same_line() + imgui.begin_child('##options', width=-1, height=height, border=False) + + # RGB & normalize. + rgb = (self.sel_channels == 3) + _clicked, rgb = imgui.checkbox('RGB', rgb) + self.sel_channels = 3 if rgb else 1 + imgui.same_line(viz.font_size * 4) + _clicked, self.img_normalize = imgui.checkbox('Normalize', self.img_normalize) + imgui.same_line(imgui.get_content_region_max()[0] - 1 - viz.button_w) + if imgui_utils.button('Reset##img_flags', width=-1, enabled=(self.sel_channels != 3 or self.img_normalize)): + self.sel_channels = 3 + self.img_normalize = False + + # Image scale. + with imgui_utils.item_width(-1 - viz.button_w - viz.spacing): + _changed, self.img_scale_db = imgui.slider_float('##scale', self.img_scale_db, min_value=-40, max_value=40, format='Scale %+.1f dB') + imgui.same_line() + if imgui_utils.button('Reset##scale', width=-1, enabled=(self.img_scale_db != 0)): + self.img_scale_db = 0 + + # Base channel. + self.base_channel = min(max(self.base_channel, 0), base_channel_max) + narrow_w = imgui.get_text_line_height_with_spacing() + with imgui_utils.grayed_out(base_channel_max == 0): + with imgui_utils.item_width(-1 - viz.button_w - narrow_w * 2 - viz.spacing * 3): + _changed, self.base_channel = imgui.drag_int('##channel', self.base_channel, change_speed=0.05, min_value=0, max_value=base_channel_max, format=f'Channel %d/{num_channels}') + imgui.same_line() + if imgui_utils.button('-##channel', width=narrow_w): + self.base_channel -= 1 + imgui.same_line() + if imgui_utils.button('+##channel', width=narrow_w): + self.base_channel += 1 + imgui.same_line() + self.base_channel = min(max(self.base_channel, 0), base_channel_max) + if imgui_utils.button('Reset##channel', width=-1, enabled=(self.base_channel != 0 and base_channel_max > 0)): + self.base_channel = 0 + + # Stats. + stats = viz.result.get('stats', None) + stats = [f'{stats[idx]:g}' if stats is not None else 'N/A' for idx in range(6)] + rows = [ + ['Statistic', 'All channels', 'Selected'], + ['Mean', stats[0], stats[1]], + ['Std', stats[2], stats[3]], + ['Max', stats[4], stats[5]], + ] + height = imgui.get_text_line_height_with_spacing() * len(rows) + viz.spacing + imgui.push_style_color(imgui.COLOR_CHILD_BACKGROUND, *bg_color) + imgui.begin_child('##stats', width=-1, height=height, border=True) + for y, cols in enumerate(rows): + for x, col in enumerate(cols): + if x != 0: + imgui.same_line(viz.font_size * (4 + (x - 1) * 6)) + if x == 0 or y == 0: + imgui.text_colored(col, *dim_color) + else: + imgui.text(col) + imgui.end_child() + imgui.pop_style_color(1) + + # FFT & all. + _clicked, self.fft_show = imgui.checkbox('FFT', self.fft_show) + imgui.same_line(viz.font_size * 4) + with imgui_utils.grayed_out(not self.fft_show or base_channel_max == 0): + _clicked, self.fft_all = imgui.checkbox('All channels', self.fft_all) + imgui.same_line(imgui.get_content_region_max()[0] - 1 - viz.button_w) + with imgui_utils.grayed_out(not self.fft_show): + if imgui_utils.button('Reset##fft_flags', width=-1, enabled=(self.fft_show or not self.fft_all)): + self.fft_show = False + self.fft_all = True + + # FFT range. + with imgui_utils.grayed_out(not self.fft_show): + with imgui_utils.item_width(-1 - viz.button_w - viz.spacing): + _changed, self.fft_range_db = imgui.slider_float('##fft_range_db', self.fft_range_db, min_value=0.1, max_value=100, format='Range +-%.1f dB') + imgui.same_line() + if imgui_utils.button('Reset##fft_range_db', width=-1, enabled=(self.fft_range_db != 50)): + self.fft_range_db = 50 + + # FFT beta. + with imgui_utils.grayed_out(not self.fft_show): + with imgui_utils.item_width(-1 - viz.button_w - viz.spacing): + _changed, self.fft_beta = imgui.slider_float('##fft_beta', self.fft_beta, min_value=0, max_value=50, format='Kaiser beta %.2f', power=2.63) + imgui.same_line() + if imgui_utils.button('Reset##fft_beta', width=-1, enabled=(self.fft_beta != 8)): + self.fft_beta = 8 + + # End options. + imgui.end_child() + + self.base_channel = min(max(self.base_channel, 0), base_channel_max) + viz.args.layer_name = self.cur_layer if len(layers) > 0 and self.cur_layer != layers[-1].name else None + viz.args.update(sel_channels=self.sel_channels, base_channel=self.base_channel, img_scale_db=self.img_scale_db, img_normalize=self.img_normalize) + viz.args.fft_show = self.fft_show + if self.fft_show: + viz.args.update(fft_all=self.fft_all, fft_range_db=self.fft_range_db, fft_beta=self.fft_beta) + +#---------------------------------------------------------------------------- diff --git a/viz/performance_widget.py b/viz/performance_widget.py new file mode 100644 index 0000000000000000000000000000000000000000..67e874a42d518658f67487ce101b216371464a58 --- /dev/null +++ b/viz/performance_widget.py @@ -0,0 +1,73 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +import array +import numpy as np +import imgui +from gui_utils import imgui_utils + +#---------------------------------------------------------------------------- + +class PerformanceWidget: + def __init__(self, viz): + self.viz = viz + self.gui_times = [float('nan')] * 60 + self.render_times = [float('nan')] * 30 + self.fps_limit = 60 + self.use_vsync = False + self.is_async = False + self.force_fp32 = False + + @imgui_utils.scoped_by_object_id + def __call__(self, show=True): + viz = self.viz + self.gui_times = self.gui_times[1:] + [viz.frame_delta] + if 'render_time' in viz.result: + self.render_times = self.render_times[1:] + [viz.result.render_time] + del viz.result.render_time + + if show: + imgui.text('GUI') + imgui.same_line(viz.label_w) + with imgui_utils.item_width(viz.font_size * 8): + imgui.plot_lines('##gui_times', array.array('f', self.gui_times), scale_min=0) + imgui.same_line(viz.label_w + viz.font_size * 9) + t = [x for x in self.gui_times if x > 0] + t = np.mean(t) if len(t) > 0 else 0 + imgui.text(f'{t*1e3:.1f} ms' if t > 0 else 'N/A') + imgui.same_line(viz.label_w + viz.font_size * 14) + imgui.text(f'{1/t:.1f} FPS' if t > 0 else 'N/A') + imgui.same_line(viz.label_w + viz.font_size * 18 + viz.spacing * 3) + with imgui_utils.item_width(viz.font_size * 6): + _changed, self.fps_limit = imgui.input_int('FPS limit', self.fps_limit, flags=imgui.INPUT_TEXT_ENTER_RETURNS_TRUE) + self.fps_limit = min(max(self.fps_limit, 5), 1000) + imgui.same_line(imgui.get_content_region_max()[0] - 1 - viz.button_w * 2 - viz.spacing) + _clicked, self.use_vsync = imgui.checkbox('Vertical sync', self.use_vsync) + + if show: + imgui.text('Render') + imgui.same_line(viz.label_w) + with imgui_utils.item_width(viz.font_size * 8): + imgui.plot_lines('##render_times', array.array('f', self.render_times), scale_min=0) + imgui.same_line(viz.label_w + viz.font_size * 9) + t = [x for x in self.render_times if x > 0] + t = np.mean(t) if len(t) > 0 else 0 + imgui.text(f'{t*1e3:.1f} ms' if t > 0 else 'N/A') + imgui.same_line(viz.label_w + viz.font_size * 14) + imgui.text(f'{1/t:.1f} FPS' if t > 0 else 'N/A') + imgui.same_line(viz.label_w + viz.font_size * 18 + viz.spacing * 3) + _clicked, self.is_async = imgui.checkbox('Separate process', self.is_async) + imgui.same_line(imgui.get_content_region_max()[0] - 1 - viz.button_w * 2 - viz.spacing) + _clicked, self.force_fp32 = imgui.checkbox('Force FP32', self.force_fp32) + + viz.set_fps_limit(self.fps_limit) + viz.set_vsync(self.use_vsync) + viz.set_async(self.is_async) + viz.args.force_fp32 = self.force_fp32 + +#---------------------------------------------------------------------------- diff --git a/viz/pickle_widget.py b/viz/pickle_widget.py new file mode 100644 index 0000000000000000000000000000000000000000..5e5f9535741772698ef9b885be562cb014ced0df --- /dev/null +++ b/viz/pickle_widget.py @@ -0,0 +1,170 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +import glob +import os +import re + +import dnnlib +import imgui +import numpy as np +from gui_utils import imgui_utils + +from . import renderer + +#---------------------------------------------------------------------------- + +def _locate_results(pattern): + return pattern + +#---------------------------------------------------------------------------- + +class PickleWidget: + def __init__(self, viz): + self.viz = viz + self.search_dirs = [] + self.cur_pkl = None + self.user_pkl = '' + self.recent_pkls = [] + self.browse_cache = dict() # {tuple(path, ...): [dnnlib.EasyDict(), ...], ...} + self.browse_refocus = False + self.load('', ignore_errors=True) + + def add_recent(self, pkl, ignore_errors=False): + try: + resolved = self.resolve_pkl(pkl) + if resolved not in self.recent_pkls: + self.recent_pkls.append(resolved) + except: + if not ignore_errors: + raise + + def load(self, pkl, ignore_errors=False): + viz = self.viz + viz.clear_result() + viz.skip_frame() # The input field will change on next frame. + try: + resolved = self.resolve_pkl(pkl) + name = resolved.replace('\\', '/').split('/')[-1] + self.cur_pkl = resolved + self.user_pkl = resolved + viz.result.message = f'Loading {name}...' + viz.defer_rendering() + if resolved in self.recent_pkls: + self.recent_pkls.remove(resolved) + self.recent_pkls.insert(0, resolved) + except: + self.cur_pkl = None + self.user_pkl = pkl + if pkl == '': + viz.result = dnnlib.EasyDict(message='No network pickle loaded') + else: + viz.result = dnnlib.EasyDict(error=renderer.CapturedException()) + if not ignore_errors: + raise + + @imgui_utils.scoped_by_object_id + def __call__(self, show=True): + viz = self.viz + recent_pkls = [pkl for pkl in self.recent_pkls if pkl != self.user_pkl] + if show: + imgui.text('Pickle') + imgui.same_line(viz.label_w) + changed, self.user_pkl = imgui_utils.input_text('##pkl', self.user_pkl, 1024, + flags=(imgui.INPUT_TEXT_AUTO_SELECT_ALL | imgui.INPUT_TEXT_ENTER_RETURNS_TRUE), + width=(-1 - viz.button_w * 2 - viz.spacing * 2), + help_text=' | | | | /.pkl') + if changed: + self.load(self.user_pkl, ignore_errors=True) + if imgui.is_item_hovered() and not imgui.is_item_active() and self.user_pkl != '': + imgui.set_tooltip(self.user_pkl) + imgui.same_line() + if imgui_utils.button('Recent...', width=viz.button_w, enabled=(len(recent_pkls) != 0)): + imgui.open_popup('recent_pkls_popup') + imgui.same_line() + if imgui_utils.button('Browse...', enabled=len(self.search_dirs) > 0, width=-1): + imgui.open_popup('browse_pkls_popup') + self.browse_cache.clear() + self.browse_refocus = True + + if imgui.begin_popup('recent_pkls_popup'): + for pkl in recent_pkls: + clicked, _state = imgui.menu_item(pkl) + if clicked: + self.load(pkl, ignore_errors=True) + imgui.end_popup() + + if imgui.begin_popup('browse_pkls_popup'): + def recurse(parents): + key = tuple(parents) + items = self.browse_cache.get(key, None) + if items is None: + items = self.list_runs_and_pkls(parents) + self.browse_cache[key] = items + for item in items: + if item.type == 'run' and imgui.begin_menu(item.name): + recurse([item.path]) + imgui.end_menu() + if item.type == 'pkl': + clicked, _state = imgui.menu_item(item.name) + if clicked: + self.load(item.path, ignore_errors=True) + if len(items) == 0: + with imgui_utils.grayed_out(): + imgui.menu_item('No results found') + recurse(self.search_dirs) + if self.browse_refocus: + imgui.set_scroll_here() + viz.skip_frame() # Focus will change on next frame. + self.browse_refocus = False + imgui.end_popup() + + paths = viz.pop_drag_and_drop_paths() + if paths is not None and len(paths) >= 1: + self.load(paths[0], ignore_errors=True) + + viz.args.pkl = self.cur_pkl + + def list_runs_and_pkls(self, parents): + items = [] + run_regex = re.compile(r'\d+-.*') + pkl_regex = re.compile(r'network-snapshot-\d+\.pkl') + for parent in set(parents): + if os.path.isdir(parent): + for entry in os.scandir(parent): + if entry.is_dir() and run_regex.fullmatch(entry.name): + items.append(dnnlib.EasyDict(type='run', name=entry.name, path=os.path.join(parent, entry.name))) + if entry.is_file() and pkl_regex.fullmatch(entry.name): + items.append(dnnlib.EasyDict(type='pkl', name=entry.name, path=os.path.join(parent, entry.name))) + + items = sorted(items, key=lambda item: (item.name.replace('_', ' '), item.path)) + return items + + def resolve_pkl(self, pattern): + assert isinstance(pattern, str) + assert pattern != '' + + # URL => return as is. + if dnnlib.util.is_url(pattern): + return pattern + + # Short-hand pattern => locate. + path = _locate_results(pattern) + + # Run dir => pick the last saved snapshot. + if os.path.isdir(path): + pkl_files = sorted(glob.glob(os.path.join(path, 'network-snapshot-*.pkl'))) + if len(pkl_files) == 0: + raise IOError(f'No network pickle found in "{path}"') + path = pkl_files[-1] + + # Normalize. + path = os.path.abspath(path) + return path + +#---------------------------------------------------------------------------- diff --git a/viz/renderer.py b/viz/renderer.py new file mode 100644 index 0000000000000000000000000000000000000000..dc8941229f811bfc1ce215c033e5cfc3300549e4 --- /dev/null +++ b/viz/renderer.py @@ -0,0 +1,377 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +import sys +import copy +import traceback +import numpy as np +import torch +import torch.fft +import torch.nn +import matplotlib.cm +import dnnlib +from torch_utils.ops import upfirdn2d +import legacy # pylint: disable=import-error + +#---------------------------------------------------------------------------- + +class CapturedException(Exception): + def __init__(self, msg=None): + if msg is None: + _type, value, _traceback = sys.exc_info() + assert value is not None + if isinstance(value, CapturedException): + msg = str(value) + else: + msg = traceback.format_exc() + assert isinstance(msg, str) + super().__init__(msg) + +#---------------------------------------------------------------------------- + +class CaptureSuccess(Exception): + def __init__(self, out): + super().__init__() + self.out = out + +#---------------------------------------------------------------------------- + +def _sinc(x): + y = (x * np.pi).abs() + z = torch.sin(y) / y.clamp(1e-30, float('inf')) + return torch.where(y < 1e-30, torch.ones_like(x), z) + +def _lanczos_window(x, a): + x = x.abs() / a + return torch.where(x < 1, _sinc(x), torch.zeros_like(x)) + +#---------------------------------------------------------------------------- + +def _construct_affine_bandlimit_filter(mat, a=3, amax=16, aflt=64, up=4, cutoff_in=1, cutoff_out=1): + assert a <= amax < aflt + mat = torch.as_tensor(mat).to(torch.float32) + + # Construct 2D filter taps in input & output coordinate spaces. + taps = ((torch.arange(aflt * up * 2 - 1, device=mat.device) + 1) / up - aflt).roll(1 - aflt * up) + yi, xi = torch.meshgrid(taps, taps) + xo, yo = (torch.stack([xi, yi], dim=2) @ mat[:2, :2].t()).unbind(2) + + # Convolution of two oriented 2D sinc filters. + fi = _sinc(xi * cutoff_in) * _sinc(yi * cutoff_in) + fo = _sinc(xo * cutoff_out) * _sinc(yo * cutoff_out) + f = torch.fft.ifftn(torch.fft.fftn(fi) * torch.fft.fftn(fo)).real + + # Convolution of two oriented 2D Lanczos windows. + wi = _lanczos_window(xi, a) * _lanczos_window(yi, a) + wo = _lanczos_window(xo, a) * _lanczos_window(yo, a) + w = torch.fft.ifftn(torch.fft.fftn(wi) * torch.fft.fftn(wo)).real + + # Construct windowed FIR filter. + f = f * w + + # Finalize. + c = (aflt - amax) * up + f = f.roll([aflt * up - 1] * 2, dims=[0,1])[c:-c, c:-c] + f = torch.nn.functional.pad(f, [0, 1, 0, 1]).reshape(amax * 2, up, amax * 2, up) + f = f / f.sum([0,2], keepdim=True) / (up ** 2) + f = f.reshape(amax * 2 * up, amax * 2 * up)[:-1, :-1] + return f + +#---------------------------------------------------------------------------- + +def _apply_affine_transformation(x, mat, up=4, **filter_kwargs): + _N, _C, H, W = x.shape + mat = torch.as_tensor(mat).to(dtype=torch.float32, device=x.device) + + # Construct filter. + f = _construct_affine_bandlimit_filter(mat, up=up, **filter_kwargs) + assert f.ndim == 2 and f.shape[0] == f.shape[1] and f.shape[0] % 2 == 1 + p = f.shape[0] // 2 + + # Construct sampling grid. + theta = mat.inverse() + theta[:2, 2] *= 2 + theta[0, 2] += 1 / up / W + theta[1, 2] += 1 / up / H + theta[0, :] *= W / (W + p / up * 2) + theta[1, :] *= H / (H + p / up * 2) + theta = theta[:2, :3].unsqueeze(0).repeat([x.shape[0], 1, 1]) + g = torch.nn.functional.affine_grid(theta, x.shape, align_corners=False) + + # Resample image. + y = upfirdn2d.upsample2d(x=x, f=f, up=up, padding=p) + z = torch.nn.functional.grid_sample(y, g, mode='bilinear', padding_mode='zeros', align_corners=False) + + # Form mask. + m = torch.zeros_like(y) + c = p * 2 + 1 + m[:, :, c:-c, c:-c] = 1 + m = torch.nn.functional.grid_sample(m, g, mode='nearest', padding_mode='zeros', align_corners=False) + return z, m + +#---------------------------------------------------------------------------- + +class Renderer: + def __init__(self): + self._device = torch.device('cuda') + self._pkl_data = dict() # {pkl: dict | CapturedException, ...} + self._networks = dict() # {cache_key: torch.nn.Module, ...} + self._pinned_bufs = dict() # {(shape, dtype): torch.Tensor, ...} + self._cmaps = dict() # {name: torch.Tensor, ...} + self._is_timing = False + self._start_event = torch.cuda.Event(enable_timing=True) + self._end_event = torch.cuda.Event(enable_timing=True) + self._net_layers = dict() # {cache_key: [dnnlib.EasyDict, ...], ...} + + def render(self, **args): + self._is_timing = True + self._start_event.record(torch.cuda.current_stream(self._device)) + res = dnnlib.EasyDict() + try: + self._render_impl(res, **args) + except: + res.error = CapturedException() + self._end_event.record(torch.cuda.current_stream(self._device)) + if 'image' in res: + res.image = self.to_cpu(res.image).numpy() + if 'stats' in res: + res.stats = self.to_cpu(res.stats).numpy() + if 'error' in res: + res.error = str(res.error) + if self._is_timing: + self._end_event.synchronize() + res.render_time = self._start_event.elapsed_time(self._end_event) * 1e-3 + self._is_timing = False + return res + + def get_network(self, pkl, key, **tweak_kwargs): + data = self._pkl_data.get(pkl, None) + if data is None: + print(f'Loading "{pkl}"... ', end='', flush=True) + try: + with dnnlib.util.open_url(pkl, verbose=False) as f: + data = legacy.load_network_pkl(f) + print('Done.') + except: + data = CapturedException() + print('Failed!') + self._pkl_data[pkl] = data + self._ignore_timing() + if isinstance(data, CapturedException): + raise data + + orig_net = data[key] + cache_key = (orig_net, self._device, tuple(sorted(tweak_kwargs.items()))) + net = self._networks.get(cache_key, None) + if net is None: + try: + net = copy.deepcopy(orig_net) + net = self._tweak_network(net, **tweak_kwargs) + net.to(self._device) + except: + net = CapturedException() + self._networks[cache_key] = net + self._ignore_timing() + if isinstance(net, CapturedException): + raise net + return net + + def _tweak_network(self, net): + # Print diagnostics. + #for name, value in misc.named_params_and_buffers(net): + # if name.endswith('.magnitude_ema'): + # value = value.rsqrt().numpy() + # print(f'{name:<50s}{np.min(value):<16g}{np.max(value):g}') + # if name.endswith('.weight') and value.ndim == 4: + # value = value.square().mean([1,2,3]).sqrt().numpy() + # print(f'{name:<50s}{np.min(value):<16g}{np.max(value):g}') + return net + + def _get_pinned_buf(self, ref): + key = (tuple(ref.shape), ref.dtype) + buf = self._pinned_bufs.get(key, None) + if buf is None: + buf = torch.empty(ref.shape, dtype=ref.dtype).pin_memory() + self._pinned_bufs[key] = buf + return buf + + def to_device(self, buf): + return self._get_pinned_buf(buf).copy_(buf).to(self._device) + + def to_cpu(self, buf): + return self._get_pinned_buf(buf).copy_(buf).clone() + + def _ignore_timing(self): + self._is_timing = False + + def _apply_cmap(self, x, name='viridis'): + cmap = self._cmaps.get(name, None) + if cmap is None: + cmap = matplotlib.cm.get_cmap(name) + cmap = cmap(np.linspace(0, 1, num=1024), bytes=True)[:, :3] + cmap = self.to_device(torch.from_numpy(cmap)) + self._cmaps[name] = cmap + hi = cmap.shape[0] - 1 + x = (x * hi + 0.5).clamp(0, hi).to(torch.int64) + x = torch.nn.functional.embedding(x, cmap) + return x + + def _render_impl(self, res, + pkl = None, + w0_seeds = [[0, 1]], + stylemix_idx = [], + stylemix_seed = 0, + trunc_psi = 1, + trunc_cutoff = 0, + random_seed = 0, + noise_mode = 'const', + force_fp32 = False, + layer_name = None, + sel_channels = 3, + base_channel = 0, + img_scale_db = 0, + img_normalize = False, + fft_show = False, + fft_all = True, + fft_range_db = 50, + fft_beta = 8, + input_transform = None, + untransform = False, + ): + # Dig up network details. + G = self.get_network(pkl, 'G_ema') + res.img_resolution = G.img_resolution + res.num_ws = G.num_ws + res.has_noise = any('noise_const' in name for name, _buf in G.synthesis.named_buffers()) + res.has_input_transform = (hasattr(G.synthesis, 'input') and hasattr(G.synthesis.input, 'transform')) + + # Set input transform. + if res.has_input_transform: + m = np.eye(3) + try: + if input_transform is not None: + m = np.linalg.inv(np.asarray(input_transform)) + except np.linalg.LinAlgError: + res.error = CapturedException() + G.synthesis.input.transform.copy_(torch.from_numpy(m)) + + # Generate random latents. + all_seeds = [seed for seed, _weight in w0_seeds] + [stylemix_seed] + all_seeds = list(set(all_seeds)) + all_zs = np.zeros([len(all_seeds), G.z_dim], dtype=np.float32) + all_cs = np.zeros([len(all_seeds), G.c_dim], dtype=np.float32) + for idx, seed in enumerate(all_seeds): + rnd = np.random.RandomState(seed) + all_zs[idx] = rnd.randn(G.z_dim) + if G.c_dim > 0: + all_cs[idx, rnd.randint(G.c_dim)] = 1 + + # Run mapping network. + w_avg = G.mapping.w_avg + all_zs = self.to_device(torch.from_numpy(all_zs)) + all_cs = self.to_device(torch.from_numpy(all_cs)) + all_ws = G.mapping(z=all_zs, c=all_cs, truncation_psi=trunc_psi, truncation_cutoff=trunc_cutoff) - w_avg + all_ws = dict(zip(all_seeds, all_ws)) + + # Calculate final W. + w = torch.stack([all_ws[seed] * weight for seed, weight in w0_seeds]).sum(dim=0, keepdim=True) + stylemix_idx = [idx for idx in stylemix_idx if 0 <= idx < G.num_ws] + if len(stylemix_idx) > 0: + w[:, stylemix_idx] = all_ws[stylemix_seed][np.newaxis, stylemix_idx] + w += w_avg + + # Run synthesis network. + synthesis_kwargs = dnnlib.EasyDict(noise_mode=noise_mode, force_fp32=force_fp32) + torch.manual_seed(random_seed) + out, layers = self.run_synthesis_net(G.synthesis, w, capture_layer=layer_name, **synthesis_kwargs) + + # Update layer list. + cache_key = (G.synthesis, tuple(sorted(synthesis_kwargs.items()))) + if cache_key not in self._net_layers: + if layer_name is not None: + torch.manual_seed(random_seed) + _out, layers = self.run_synthesis_net(G.synthesis, w, **synthesis_kwargs) + self._net_layers[cache_key] = layers + res.layers = self._net_layers[cache_key] + + # Untransform. + if untransform and res.has_input_transform: + out, _mask = _apply_affine_transformation(out.to(torch.float32), G.synthesis.input.transform, amax=6) # Override amax to hit the fast path in upfirdn2d. + + # Select channels and compute statistics. + out = out[0].to(torch.float32) + if sel_channels > out.shape[0]: + sel_channels = 1 + base_channel = max(min(base_channel, out.shape[0] - sel_channels), 0) + sel = out[base_channel : base_channel + sel_channels] + res.stats = torch.stack([ + out.mean(), sel.mean(), + out.std(), sel.std(), + out.norm(float('inf')), sel.norm(float('inf')), + ]) + + # Scale and convert to uint8. + img = sel + if img_normalize: + img = img / img.norm(float('inf'), dim=[1,2], keepdim=True).clip(1e-8, 1e8) + img = img * (10 ** (img_scale_db / 20)) + img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8).permute(1, 2, 0) + res.image = img + + # FFT. + if fft_show: + sig = out if fft_all else sel + sig = sig.to(torch.float32) + sig = sig - sig.mean(dim=[1,2], keepdim=True) + sig = sig * torch.kaiser_window(sig.shape[1], periodic=False, beta=fft_beta, device=self._device)[None, :, None] + sig = sig * torch.kaiser_window(sig.shape[2], periodic=False, beta=fft_beta, device=self._device)[None, None, :] + fft = torch.fft.fftn(sig, dim=[1,2]).abs().square().sum(dim=0) + fft = fft.roll(shifts=[fft.shape[0] // 2, fft.shape[1] // 2], dims=[0,1]) + fft = (fft / fft.mean()).log10() * 10 # dB + fft = self._apply_cmap((fft / fft_range_db + 1) / 2) + res.image = torch.cat([img.expand_as(fft), fft], dim=1) + + @staticmethod + def run_synthesis_net(net, *args, capture_layer=None, **kwargs): # => out, layers + submodule_names = {mod: name for name, mod in net.named_modules()} + unique_names = set() + layers = [] + + def module_hook(module, _inputs, outputs): + outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs] + outputs = [out for out in outputs if isinstance(out, torch.Tensor) and out.ndim in [4, 5]] + for idx, out in enumerate(outputs): + if out.ndim == 5: # G-CNN => remove group dimension. + out = out.mean(2) + name = submodule_names[module] + if name == '': + name = 'output' + if len(outputs) > 1: + name += f':{idx}' + if name in unique_names: + suffix = 2 + while f'{name}_{suffix}' in unique_names: + suffix += 1 + name += f'_{suffix}' + unique_names.add(name) + shape = [int(x) for x in out.shape] + dtype = str(out.dtype).split('.')[-1] + layers.append(dnnlib.EasyDict(name=name, shape=shape, dtype=dtype)) + if name == capture_layer: + raise CaptureSuccess(out) + + hooks = [module.register_forward_hook(module_hook) for module in net.modules()] + try: + out = net(*args, **kwargs) + except CaptureSuccess as e: + out = e.out + for hook in hooks: + hook.remove() + return out, layers + +#---------------------------------------------------------------------------- diff --git a/viz/stylemix_widget.py b/viz/stylemix_widget.py new file mode 100644 index 0000000000000000000000000000000000000000..b997855470b77c3cb523679febeb004df42d6738 --- /dev/null +++ b/viz/stylemix_widget.py @@ -0,0 +1,66 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +import imgui +from gui_utils import imgui_utils + +#---------------------------------------------------------------------------- + +class StyleMixingWidget: + def __init__(self, viz): + self.viz = viz + self.seed_def = 1000 + self.seed = self.seed_def + self.animate = False + self.enables = [] + + @imgui_utils.scoped_by_object_id + def __call__(self, show=True): + viz = self.viz + num_ws = viz.result.get('num_ws', 0) + num_enables = viz.result.get('num_ws', 18) + self.enables += [False] * max(num_enables - len(self.enables), 0) + + if show: + imgui.text('Stylemix') + imgui.same_line(viz.label_w) + with imgui_utils.item_width(viz.font_size * 8), imgui_utils.grayed_out(num_ws == 0): + _changed, self.seed = imgui.input_int('##seed', self.seed) + imgui.same_line(viz.label_w + viz.font_size * 8 + viz.spacing) + with imgui_utils.grayed_out(num_ws == 0): + _clicked, self.animate = imgui.checkbox('Anim', self.animate) + + pos2 = imgui.get_content_region_max()[0] - 1 - viz.button_w + pos1 = pos2 - imgui.get_text_line_height() - viz.spacing + pos0 = viz.label_w + viz.font_size * 12 + imgui.push_style_var(imgui.STYLE_FRAME_PADDING, [0, 0]) + for idx in range(num_enables): + imgui.same_line(round(pos0 + (pos1 - pos0) * (idx / (num_enables - 1)))) + if idx == 0: + imgui.set_cursor_pos_y(imgui.get_cursor_pos_y() + 3) + with imgui_utils.grayed_out(num_ws == 0): + _clicked, self.enables[idx] = imgui.checkbox(f'##{idx}', self.enables[idx]) + if imgui.is_item_hovered(): + imgui.set_tooltip(f'{idx}') + imgui.pop_style_var(1) + + imgui.same_line(pos2) + imgui.set_cursor_pos_y(imgui.get_cursor_pos_y() - 3) + with imgui_utils.grayed_out(num_ws == 0): + if imgui_utils.button('Reset', width=-1, enabled=(self.seed != self.seed_def or self.animate or any(self.enables[:num_enables]))): + self.seed = self.seed_def + self.animate = False + self.enables = [False] * num_enables + + if any(self.enables[:num_ws]): + viz.args.stylemix_idx = [idx for idx, enable in enumerate(self.enables) if enable] + viz.args.stylemix_seed = self.seed & ((1 << 32) - 1) + if self.animate: + self.seed += 1 + +#---------------------------------------------------------------------------- diff --git a/viz/trunc_noise_widget.py b/viz/trunc_noise_widget.py new file mode 100644 index 0000000000000000000000000000000000000000..4597b28e81e0941103b234364c60b1aeb5aa0361 --- /dev/null +++ b/viz/trunc_noise_widget.py @@ -0,0 +1,75 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +import imgui +from gui_utils import imgui_utils + +#---------------------------------------------------------------------------- + +class TruncationNoiseWidget: + def __init__(self, viz): + self.viz = viz + self.prev_num_ws = 0 + self.trunc_psi = 1 + self.trunc_cutoff = 0 + self.noise_enable = True + self.noise_seed = 0 + self.noise_anim = False + + @imgui_utils.scoped_by_object_id + def __call__(self, show=True): + viz = self.viz + num_ws = viz.result.get('num_ws', 0) + has_noise = viz.result.get('has_noise', False) + if num_ws > 0 and num_ws != self.prev_num_ws: + if self.trunc_cutoff > num_ws or self.trunc_cutoff == self.prev_num_ws: + self.trunc_cutoff = num_ws + self.prev_num_ws = num_ws + + if show: + imgui.text('Truncate') + imgui.same_line(viz.label_w) + with imgui_utils.item_width(viz.font_size * 10), imgui_utils.grayed_out(num_ws == 0): + _changed, self.trunc_psi = imgui.slider_float('##psi', self.trunc_psi, -1, 2, format='Psi %.2f') + imgui.same_line() + if num_ws == 0: + imgui_utils.button('Cutoff 0', width=(viz.font_size * 8 + viz.spacing), enabled=False) + else: + with imgui_utils.item_width(viz.font_size * 8 + viz.spacing): + changed, new_cutoff = imgui.slider_int('##cutoff', self.trunc_cutoff, 0, num_ws, format='Cutoff %d') + if changed: + self.trunc_cutoff = min(max(new_cutoff, 0), num_ws) + + with imgui_utils.grayed_out(not has_noise): + imgui.same_line() + _clicked, self.noise_enable = imgui.checkbox('Noise##enable', self.noise_enable) + imgui.same_line(round(viz.font_size * 27.7)) + with imgui_utils.grayed_out(not self.noise_enable): + with imgui_utils.item_width(-1 - viz.button_w - viz.spacing - viz.font_size * 4): + _changed, self.noise_seed = imgui.input_int('##seed', self.noise_seed) + imgui.same_line(spacing=0) + _clicked, self.noise_anim = imgui.checkbox('Anim##noise', self.noise_anim) + + is_def_trunc = (self.trunc_psi == 1 and self.trunc_cutoff == num_ws) + is_def_noise = (self.noise_enable and self.noise_seed == 0 and not self.noise_anim) + with imgui_utils.grayed_out(is_def_trunc and not has_noise): + imgui.same_line(imgui.get_content_region_max()[0] - 1 - viz.button_w) + if imgui_utils.button('Reset', width=-1, enabled=(not is_def_trunc or not is_def_noise)): + self.prev_num_ws = num_ws + self.trunc_psi = 1 + self.trunc_cutoff = num_ws + self.noise_enable = True + self.noise_seed = 0 + self.noise_anim = False + + if self.noise_anim: + self.noise_seed += 1 + viz.args.update(trunc_psi=self.trunc_psi, trunc_cutoff=self.trunc_cutoff, random_seed=self.noise_seed) + viz.args.noise_mode = ('none' if not self.noise_enable else 'const' if self.noise_seed == 0 else 'random') + +#----------------------------------------------------------------------------