# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import sys
import copy
import traceback
import numpy as np
import torch
import torch.fft
import torch.nn
import matplotlib.cm
import dnnlib
from torch_utils.ops import upfirdn2d
import legacy # pylint: disable=import-error

#----------------------------------------------------------------------------

class CapturedException(Exception):
    def __init__(self, msg=None):
        if msg is None:
            _type, value, _traceback = sys.exc_info()
            assert value is not None
            if isinstance(value, CapturedException):
                msg = str(value)
            else:
                msg = traceback.format_exc()
        assert isinstance(msg, str)
        super().__init__(msg)

#----------------------------------------------------------------------------

class CaptureSuccess(Exception):
    def __init__(self, out):
        super().__init__()
        self.out = out

#----------------------------------------------------------------------------

def _sinc(x):
    y = (x * np.pi).abs()
    z = torch.sin(y) / y.clamp(1e-30, float('inf'))
    return torch.where(y < 1e-30, torch.ones_like(x), z)

def _lanczos_window(x, a):
    x = x.abs() / a
    return torch.where(x < 1, _sinc(x), torch.zeros_like(x))

#----------------------------------------------------------------------------

def _construct_affine_bandlimit_filter(mat, a=3, amax=16, aflt=64, up=4, cutoff_in=1, cutoff_out=1):
    assert a <= amax < aflt
    mat = torch.as_tensor(mat).to(torch.float32)

    # Construct 2D filter taps in input & output coordinate spaces.
    taps = ((torch.arange(aflt * up * 2 - 1, device=mat.device) + 1) / up - aflt).roll(1 - aflt * up)
    yi, xi = torch.meshgrid(taps, taps)
    xo, yo = (torch.stack([xi, yi], dim=2) @ mat[:2, :2].t()).unbind(2)

    # Convolution of two oriented 2D sinc filters.
    fi = _sinc(xi * cutoff_in) * _sinc(yi * cutoff_in)
    fo = _sinc(xo * cutoff_out) * _sinc(yo * cutoff_out)
    f = torch.fft.ifftn(torch.fft.fftn(fi) * torch.fft.fftn(fo)).real

    # Convolution of two oriented 2D Lanczos windows.
    wi = _lanczos_window(xi, a) * _lanczos_window(yi, a)
    wo = _lanczos_window(xo, a) * _lanczos_window(yo, a)
    w = torch.fft.ifftn(torch.fft.fftn(wi) * torch.fft.fftn(wo)).real

    # Construct windowed FIR filter.
    f = f * w

    # Finalize.
    c = (aflt - amax) * up
    f = f.roll([aflt * up - 1] * 2, dims=[0,1])[c:-c, c:-c]
    f = torch.nn.functional.pad(f, [0, 1, 0, 1]).reshape(amax * 2, up, amax * 2, up)
    f = f / f.sum([0,2], keepdim=True) / (up ** 2)
    f = f.reshape(amax * 2 * up, amax * 2 * up)[:-1, :-1]
    return f

#----------------------------------------------------------------------------

def _apply_affine_transformation(x, mat, up=4, **filter_kwargs):
    _N, _C, H, W = x.shape
    mat = torch.as_tensor(mat).to(dtype=torch.float32, device=x.device)

    # Construct filter.
    f = _construct_affine_bandlimit_filter(mat, up=up, **filter_kwargs)
    assert f.ndim == 2 and f.shape[0] == f.shape[1] and f.shape[0] % 2 == 1
    p = f.shape[0] // 2

    # Construct sampling grid.
    theta = mat.inverse()
    theta[:2, 2] *= 2
    theta[0, 2] += 1 / up / W
    theta[1, 2] += 1 / up / H
    theta[0, :] *= W / (W + p / up * 2)
    theta[1, :] *= H / (H + p / up * 2)
    theta = theta[:2, :3].unsqueeze(0).repeat([x.shape[0], 1, 1])
    g = torch.nn.functional.affine_grid(theta, x.shape, align_corners=False)

    # Resample image.
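    # Upsample and low-pass filter with the band-limited FIR filter f, padding
    # by the filter's half-width p, then resample along the transformed grid g
    # with bilinear interpolation.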
    y = upfirdn2d.upsample2d(x=x, f=f, up=up, padding=p)
    z = torch.nn.functional.grid_sample(y, g, mode='bilinear', padding_mode='zeros', align_corners=False)

    # Form mask.
    m = torch.zeros_like(y)
    c = p * 2 + 1
    m[:, :, c:-c, c:-c] = 1
    m = torch.nn.functional.grid_sample(m, g, mode='nearest', padding_mode='zeros', align_corners=False)
    return z, m

#----------------------------------------------------------------------------

class Renderer:
    def __init__(self):
        self._device        = torch.device('cuda')
        self._pkl_data      = dict()    # {pkl: dict | CapturedException, ...}
        self._networks      = dict()    # {cache_key: torch.nn.Module, ...}
        self._pinned_bufs   = dict()    # {(shape, dtype): torch.Tensor, ...}
        self._cmaps         = dict()    # {name: torch.Tensor, ...}
        self._is_timing     = False
        self._start_event   = torch.cuda.Event(enable_timing=True)
        self._end_event     = torch.cuda.Event(enable_timing=True)
        self._net_layers    = dict()    # {cache_key: [dnnlib.EasyDict, ...], ...}

    def render(self, **args):
        self._is_timing = True
        self._start_event.record(torch.cuda.current_stream(self._device))
        res = dnnlib.EasyDict()
        try:
            self._render_impl(res, **args)
        except:
            res.error = CapturedException()
        self._end_event.record(torch.cuda.current_stream(self._device))
        if 'image' in res:
            res.image = self.to_cpu(res.image).numpy()
        if 'stats' in res:
            res.stats = self.to_cpu(res.stats).numpy()
        if 'error' in res:
            res.error = str(res.error)
        if self._is_timing:
            self._end_event.synchronize()
            res.render_time = self._start_event.elapsed_time(self._end_event) * 1e-3
            self._is_timing = False
        return res

    def get_network(self, pkl, key, **tweak_kwargs):
        data = self._pkl_data.get(pkl, None)
        if data is None:
            print(f'Loading "{pkl}"... ', end='', flush=True)
            try:
                with dnnlib.util.open_url(pkl, verbose=False) as f:
                    data = legacy.load_network_pkl(f)
                print('Done.')
            except:
                data = CapturedException()
                print('Failed!')
            self._pkl_data[pkl] = data
            self._ignore_timing()
        if isinstance(data, CapturedException):
            raise data

        orig_net = data[key]
        cache_key = (orig_net, self._device, tuple(sorted(tweak_kwargs.items())))
        net = self._networks.get(cache_key, None)
        if net is None:
            try:
                net = copy.deepcopy(orig_net)
                net = self._tweak_network(net, **tweak_kwargs)
                net.to(self._device)
            except:
                net = CapturedException()
            self._networks[cache_key] = net
            self._ignore_timing()
        if isinstance(net, CapturedException):
            raise net
        return net

    def _tweak_network(self, net):
        # Print diagnostics.
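        # Note: the disabled diagnostics below assume `misc` refers to
        # torch_utils.misc, which this file does not currently import;
        # add `from torch_utils import misc` before re-enabling them.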
        #for name, value in misc.named_params_and_buffers(net):
        #    if name.endswith('.magnitude_ema'):
        #        value = value.rsqrt().numpy()
        #        print(f'{name:<50s}{np.min(value):<16g}{np.max(value):g}')
        #    if name.endswith('.weight') and value.ndim == 4:
        #        value = value.square().mean([1,2,3]).sqrt().numpy()
        #        print(f'{name:<50s}{np.min(value):<16g}{np.max(value):g}')
        return net

    def _get_pinned_buf(self, ref):
        key = (tuple(ref.shape), ref.dtype)
        buf = self._pinned_bufs.get(key, None)
        if buf is None:
            buf = torch.empty(ref.shape, dtype=ref.dtype).pin_memory()
            self._pinned_bufs[key] = buf
        return buf

    def to_device(self, buf):
        return self._get_pinned_buf(buf).copy_(buf).to(self._device)

    def to_cpu(self, buf):
        return self._get_pinned_buf(buf).copy_(buf).clone()

    def _ignore_timing(self):
        self._is_timing = False

    def _apply_cmap(self, x, name='viridis'):
        cmap = self._cmaps.get(name, None)
        if cmap is None:
            cmap = matplotlib.cm.get_cmap(name)
            cmap = cmap(np.linspace(0, 1, num=1024), bytes=True)[:, :3]
            cmap = self.to_device(torch.from_numpy(cmap))
            self._cmaps[name] = cmap
        hi = cmap.shape[0] - 1
        x = (x * hi + 0.5).clamp(0, hi).to(torch.int64)
        x = torch.nn.functional.embedding(x, cmap)
        return x

    def _render_impl(self, res,
        pkl             = None,
        w0_seeds        = [[0, 1]],
        stylemix_idx    = [],
        stylemix_seed   = 0,
        trunc_psi       = 1,
        trunc_cutoff    = 0,
        random_seed     = 0,
        noise_mode      = 'const',
        force_fp32      = False,
        layer_name      = None,
        sel_channels    = 3,
        base_channel    = 0,
        img_scale_db    = 0,
        img_normalize   = False,
        fft_show        = False,
        fft_all         = True,
        fft_range_db    = 50,
        fft_beta        = 8,
        input_transform = None,
        untransform     = False,
    ):
        # Dig up network details.
        G = self.get_network(pkl, 'G_ema')
        res.img_resolution = G.img_resolution
        res.num_ws = G.num_ws
        res.has_noise = any('noise_const' in name for name, _buf in G.synthesis.named_buffers())
        res.has_input_transform = (hasattr(G.synthesis, 'input') and hasattr(G.synthesis.input, 'transform'))

        # Set input transform.
        if res.has_input_transform:
            m = np.eye(3)
            try:
                if input_transform is not None:
                    m = np.linalg.inv(np.asarray(input_transform))
            except np.linalg.LinAlgError:
                res.error = CapturedException()
            G.synthesis.input.transform.copy_(torch.from_numpy(m))

        # Generate random latents.
        all_seeds = [seed for seed, _weight in w0_seeds] + [stylemix_seed]
        all_seeds = list(set(all_seeds))
        all_zs = np.zeros([len(all_seeds), G.z_dim], dtype=np.float32)
        all_cs = np.zeros([len(all_seeds), G.c_dim], dtype=np.float32)
        for idx, seed in enumerate(all_seeds):
            rnd = np.random.RandomState(seed)
            all_zs[idx] = rnd.randn(G.z_dim)
            if G.c_dim > 0:
                all_cs[idx, rnd.randint(G.c_dim)] = 1

        # Run mapping network.
        w_avg = G.mapping.w_avg
        all_zs = self.to_device(torch.from_numpy(all_zs))
        all_cs = self.to_device(torch.from_numpy(all_cs))
        all_ws = G.mapping(z=all_zs, c=all_cs, truncation_psi=trunc_psi, truncation_cutoff=trunc_cutoff) - w_avg
        all_ws = dict(zip(all_seeds, all_ws))

        # Calculate final W.
        w = torch.stack([all_ws[seed] * weight for seed, weight in w0_seeds]).sum(dim=0, keepdim=True)
        stylemix_idx = [idx for idx in stylemix_idx if 0 <= idx < G.num_ws]
        if len(stylemix_idx) > 0:
            w[:, stylemix_idx] = all_ws[stylemix_seed][np.newaxis, stylemix_idx]
        w += w_avg

        # Run synthesis network.
        synthesis_kwargs = dnnlib.EasyDict(noise_mode=noise_mode, force_fp32=force_fp32)
        torch.manual_seed(random_seed)
        out, layers = self.run_synthesis_net(G.synthesis, w, capture_layer=layer_name, **synthesis_kwargs)

        # Update layer list.
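        # The layer list is cached per (synthesis network, synthesis kwargs).
        # If a specific layer was captured above, the forward pass stopped
        # early, so one extra full pass is needed to enumerate every layer.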
        cache_key = (G.synthesis, tuple(sorted(synthesis_kwargs.items())))
        if cache_key not in self._net_layers:
            if layer_name is not None:
                torch.manual_seed(random_seed)
                _out, layers = self.run_synthesis_net(G.synthesis, w, **synthesis_kwargs)
            self._net_layers[cache_key] = layers
        res.layers = self._net_layers[cache_key]

        # Untransform.
        if untransform and res.has_input_transform:
            out, _mask = _apply_affine_transformation(out.to(torch.float32), G.synthesis.input.transform, amax=6) # Override amax to hit the fast path in upfirdn2d.

        # Select channels and compute statistics.
        out = out[0].to(torch.float32)
        if sel_channels > out.shape[0]:
            sel_channels = 1
        base_channel = max(min(base_channel, out.shape[0] - sel_channels), 0)
        sel = out[base_channel : base_channel + sel_channels]
        res.stats = torch.stack([
            out.mean(), sel.mean(),
            out.std(), sel.std(),
            out.norm(float('inf')), sel.norm(float('inf')),
        ])

        # Scale and convert to uint8.
        img = sel
        if img_normalize:
            img = img / img.norm(float('inf'), dim=[1,2], keepdim=True).clip(1e-8, 1e8)
        img = img * (10 ** (img_scale_db / 20))
        img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8).permute(1, 2, 0)
        res.image = img

        # FFT.
        if fft_show:
            sig = out if fft_all else sel
            sig = sig.to(torch.float32)
            sig = sig - sig.mean(dim=[1,2], keepdim=True)
            sig = sig * torch.kaiser_window(sig.shape[1], periodic=False, beta=fft_beta, device=self._device)[None, :, None]
            sig = sig * torch.kaiser_window(sig.shape[2], periodic=False, beta=fft_beta, device=self._device)[None, None, :]
            fft = torch.fft.fftn(sig, dim=[1,2]).abs().square().sum(dim=0)
            fft = fft.roll(shifts=[fft.shape[0] // 2, fft.shape[1] // 2], dims=[0,1])
            fft = (fft / fft.mean()).log10() * 10 # dB
            fft = self._apply_cmap((fft / fft_range_db + 1) / 2)
            res.image = torch.cat([img.expand_as(fft), fft], dim=1)

    @staticmethod
    def run_synthesis_net(net, *args, capture_layer=None, **kwargs): # => out, layers
        submodule_names = {mod: name for name, mod in net.named_modules()}
        unique_names = set()
        layers = []

        def module_hook(module, _inputs, outputs):
            outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
            outputs = [out for out in outputs if isinstance(out, torch.Tensor) and out.ndim in [4, 5]]
            for idx, out in enumerate(outputs):
                if out.ndim == 5: # G-CNN => remove group dimension.
                    out = out.mean(2)
                name = submodule_names[module]
                if name == '':
                    name = 'output'
                if len(outputs) > 1:
                    name += f':{idx}'
                if name in unique_names:
                    suffix = 2
                    while f'{name}_{suffix}' in unique_names:
                        suffix += 1
                    name += f'_{suffix}'
                unique_names.add(name)
                shape = [int(x) for x in out.shape]
                dtype = str(out.dtype).split('.')[-1]
                layers.append(dnnlib.EasyDict(name=name, shape=shape, dtype=dtype))
                if name == capture_layer:
                    raise CaptureSuccess(out)

        hooks = [module.register_forward_hook(module_hook) for module in net.modules()]
        try:
            out = net(*args, **kwargs)
        except CaptureSuccess as e:
            out = e.out
        for hook in hooks:
            hook.remove()
        return out, layers

#----------------------------------------------------------------------------
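
#----------------------------------------------------------------------------
# Minimal usage sketch (an illustrative addition, not part of the original
# tool). The checkpoint path is a placeholder; substitute any StyleGAN
# network pickle that legacy.load_network_pkl() understands. Requires a
# CUDA-capable device.

if __name__ == '__main__':
    renderer = Renderer()
    result = renderer.render(pkl='network-snapshot.pkl', w0_seeds=[[42, 1.0]]) # hypothetical path
    if 'error' in result:
        print('Render failed:\n' + result.error)
    else:
        print('Rendered image of shape', result.image.shape) # HxWxC uint8 numpy array

#----------------------------------------------------------------------------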