# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

from socket import has_dualstack_ipv6  # NOTE(review): unused in this module
import sys
import copy
import traceback
import math
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import torch
import torch.fft
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.cm
import dnnlib
from torch_utils.ops import upfirdn2d
import legacy  # pylint: disable=import-error

# ----------------------------------------------------------------------------


class CapturedException(Exception):
    """Exception whose message is a formatted traceback.

    When constructed with ``msg=None`` it must be created inside an ``except``
    clause: the message becomes the formatted traceback of the exception being
    handled. If that exception is itself a ``CapturedException``, its message
    is reused as-is so tracebacks are not nested.
    """

    def __init__(self, msg=None):
        if msg is None:
            _type, value, _traceback = sys.exc_info()
            assert value is not None
            if isinstance(value, CapturedException):
                msg = str(value)
            else:
                msg = traceback.format_exc()
        assert isinstance(msg, str)
        super().__init__(msg)

# ----------------------------------------------------------------------------


class CaptureSuccess(Exception):
    """Exception that carries a result object in its ``out`` attribute."""

    def __init__(self, out):
        super().__init__()
        self.out = out

# ----------------------------------------------------------------------------


def _load_watermark_font(size):
    """Return an Arial font of ``size`` px, or Pillow's default font if Arial is missing."""
    try:
        return ImageFont.truetype('arial.ttf', size)
    except OSError:
        # 'arial.ttf' is frequently not installed (e.g. on Linux); degrade
        # gracefully instead of failing the whole render.
        return ImageFont.load_default()


def _measure_text(draw, text, font):
    """Return ``(width, height)`` of ``text`` drawn at the origin with ``font``.

    ``FreeTypeFont.getsize`` was removed in Pillow 10, so ``ImageDraw.textbbox``
    is preferred. ``getsize`` is only used on old Pillow versions, where
    ``textbbox`` is missing or rejects bitmap fonts.
    """
    try:
        _left, _top, right, bottom = draw.textbbox((0, 0), text, font=font)
        # right/bottom measured from the origin mirror the old getsize() result.
        return right, bottom
    except (AttributeError, ValueError):
        return font.getsize(text)


def add_watermark_np(input_image_array, watermark_text="AI Generated"):
    """Overlay semi-transparent white text in the bottom-right corner of an image.

    Args:
        input_image_array: Image array accepted by ``PIL.Image.fromarray``
            after conversion to ``uint8``.
        watermark_text: Text to draw.

    Returns:
        ``np.ndarray`` holding the watermarked image in RGBA mode.
    """
    image = Image.fromarray(np.uint8(input_image_array)).convert("RGBA")

    # Fully transparent canvas of the same size; the text is drawn here and
    # then alpha-composited over the image.
    txt = Image.new('RGBA', image.size, (255, 255, 255, 0))
    # Font size scales with image width (25 px at 512 px wide), at least 1 px.
    font = _load_watermark_font(max(1, round(25 / 512 * image.size[0])))
    d = ImageDraw.Draw(txt)

    text_width, text_height = _measure_text(d, watermark_text, font)
    text_position = (image.size[0] - text_width - 10, image.size[1] - text_height - 10)

    # White with alpha 128, i.e. semi-transparent.
    text_color = (255, 255, 255, 128)
    d.text(text_position, watermark_text, font=font, fill=text_color)

    watermarked = Image.alpha_composite(image, txt)
    return np.array(watermarked)

# ----------------------------------------------------------------------------


class Renderer:
    """Render images from a StyleGAN generator while optimizing its latent for point dragging.

    State created lazily by ``init_network`` / ``_render_drag_impl``:
    ``G`` (the generator), ``w0`` (initial latent), ``w`` (latent optimized by
    ``w_optim``), ``feat_refs`` / ``feat0_resize`` (reference features at the
    handle points) and ``points`` (current handle points, ``[row, col]``).
    """

    def __init__(self, disable_timing=False):
        self._device = torch.device(
            'cuda' if torch.cuda.is_available()
            else 'mps' if torch.backends.mps.is_available()
            else 'cpu')
        self._dtype = torch.float32 if self._device.type == 'mps' else torch.float64
        self._pkl_data = dict()     # {pkl: dict | CapturedException, ...}
        self._networks = dict()     # {cache_key: torch.nn.Module, ...}
        self._pinned_bufs = dict()  # {(shape, dtype): torch.Tensor, ...}
        self._cmaps = dict()        # {name: torch.Tensor, ...}
        self._is_timing = False
        # CUDA events can only be created and recorded with a CUDA device, so
        # timing is forcibly disabled on MPS/CPU (it used to crash there).
        if self._device.type != 'cuda':
            disable_timing = True
        if not disable_timing:
            self._start_event = torch.cuda.Event(enable_timing=True)
            self._end_event = torch.cuda.Event(enable_timing=True)
        self._disable_timing = disable_timing
        self._net_layers = dict()   # {cache_key: [dnnlib.EasyDict, ...], ...}

    def render(self, **args):
        """Run one render/drag step and return the results.

        Reads ``args['pkl']``, ``args['w_load']``, ``args['w0_seed']``,
        ``args['w_plus']`` and ``args['reset_w']`` to decide whether to
        (re)initialize the network. It then forwards all ``args`` to
        ``init_network`` (when needed) and ``_render_drag_impl``.

        Returns:
            ``dnnlib.EasyDict``. On success it contains ``image`` (watermarked
            numpy array, or a PIL image when ``to_pil`` was requested). Any
            exception raised while rendering is stored as a string in
            ``error``. ``render_time`` (seconds) is set when CUDA timing is active.
        """
        if self._disable_timing:
            self._is_timing = False
        else:
            self._start_event.record(torch.cuda.current_stream(self._device))
            self._is_timing = True
        res = dnnlib.EasyDict()
        try:
            # Every check is evaluated, so a missing key in ``args`` is
            # reported through ``res.error``.
            init_net = False
            if not hasattr(self, 'G'):
                init_net = True
            if hasattr(self, 'pkl'):
                if self.pkl != args['pkl']:
                    init_net = True
            if hasattr(self, 'w_load'):
                # Identity check: a new tensor object means a new latent to load.
                if self.w_load is not args['w_load']:
                    init_net = True
            if hasattr(self, 'w0_seed'):
                if self.w0_seed != args['w0_seed']:
                    init_net = True
            if hasattr(self, 'w_plus'):
                if self.w_plus != args['w_plus']:
                    init_net = True
            if args['reset_w']:
                init_net = True
            res.init_net = init_net
            if init_net:
                self.init_network(res, **args)
            self._render_drag_impl(res, **args)
        except Exception:
            res.error = CapturedException()
        if not self._disable_timing:
            self._end_event.record(torch.cuda.current_stream(self._device))
        if 'image' in res:
            if isinstance(res.image, Image.Image):
                # to_pil=True produced a PIL image, which has no tensor shape
                # for to_cpu(); watermark it and keep returning a PIL image.
                res.image = Image.fromarray(
                    add_watermark_np(np.asarray(res.image), 'AI Generated'))
            else:
                res.image = self.to_cpu(res.image).detach().numpy()
                res.image = add_watermark_np(res.image, 'AI Generated')
        if 'stats' in res:
            res.stats = self.to_cpu(res.stats).detach().numpy()
        if 'error' in res:
            res.error = str(res.error)
        if self._is_timing and not self._disable_timing:
            self._end_event.synchronize()
            res.render_time = self._start_event.elapsed_time(self._end_event) * 1e-3
            self._is_timing = False
        return res

    def get_network(self, pkl, key, **tweak_kwargs):
        """Load ``data[key]`` from network pickle ``pkl`` and rebuild it as a fresh Generator.

        Both the loaded pickle data and the rebuilt network are cached,
        including failures: a cached ``CapturedException`` is re-raised on
        later calls. The Generator class is chosen from the substring
        'stylegan2', 'stylegan3' or 'stylegan_human' in ``pkl``.

        Raises:
            CapturedException: if loading or rebuilding failed.
        """
        data = self._pkl_data.get(pkl, None)
        if data is None:
            print(f'Loading "{pkl}"... ', end='', flush=True)
            try:
                # SECURITY NOTE(review): network pickles are unpickled, which can
                # execute arbitrary code -- only load trusted files.
                with dnnlib.util.open_url(pkl, verbose=False) as f:
                    data = legacy.load_network_pkl(f)
                print('Done.')
            except Exception:
                data = CapturedException()
                print('Failed!')
            self._pkl_data[pkl] = data
            self._ignore_timing()
        if isinstance(data, CapturedException):
            raise data

        orig_net = data[key]
        cache_key = (orig_net, self._device, tuple(sorted(tweak_kwargs.items())))
        net = self._networks.get(cache_key, None)
        if net is None:
            try:
                if 'stylegan2' in pkl:
                    from training.networks_stylegan2 import Generator
                elif 'stylegan3' in pkl:
                    from training.networks_stylegan3 import Generator
                elif 'stylegan_human' in pkl:
                    from stylegan_human.training_scripts.sg2.training.networks import Generator
                else:
                    raise NameError('Cannot infer model type from pkl name!')

                print(orig_net.init_args)
                print(orig_net.init_kwargs)
                if 'stylegan_human' in pkl:
                    net = Generator(*orig_net.init_args, **orig_net.init_kwargs,
                                    square=False, padding=True)
                else:
                    net = Generator(*orig_net.init_args, **orig_net.init_kwargs)
                net.load_state_dict(orig_net.state_dict())
                net.to(self._device)
            except Exception:
                net = CapturedException()
            self._networks[cache_key] = net
            self._ignore_timing()
        if isinstance(net, CapturedException):
            raise net
        return net

    def _get_pinned_buf(self, ref):
        """Return a cached host staging buffer with ``ref``'s shape and dtype.

        The buffer is page-locked only when the device is CUDA. ``pin_memory()``
        fails without CUDA and is useless there anyway.
        """
        key = (tuple(ref.shape), ref.dtype)
        buf = self._pinned_bufs.get(key, None)
        if buf is None:
            buf = torch.empty(ref.shape, dtype=ref.dtype)
            if self._device.type == 'cuda':
                buf = buf.pin_memory()
            self._pinned_bufs[key] = buf
        return buf

    def to_device(self, buf):
        """Return a copy of ``buf`` on the renderer's device."""
        if self._device.type == 'cuda':
            return self._get_pinned_buf(buf).copy_(buf).to(self._device)
        # On CPU, ``.to('cpu')`` of the staging buffer would return the shared
        # cached buffer itself; force a real copy so results never alias it.
        return buf.to(self._device, copy=True)

    def to_cpu(self, buf):
        """Return a fresh CPU copy of ``buf``, staged through a cached buffer."""
        return self._get_pinned_buf(buf).copy_(buf).clone()

    def _ignore_timing(self):
        self._is_timing = False

    def _apply_cmap(self, x, name='viridis'):
        """Map values of ``x`` (clamped to [0, 1]) to RGB via a 1024-entry matplotlib colormap."""
        cmap = self._cmaps.get(name, None)
        if cmap is None:
            # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; use the
            # ``matplotlib.colormaps`` registry when it exists.
            registry = getattr(matplotlib, 'colormaps', None)
            if registry is not None:
                cmap = registry[name]
            else:
                cmap = matplotlib.cm.get_cmap(name)
            cmap = cmap(np.linspace(0, 1, num=1024), bytes=True)[:, :3]
            cmap = self.to_device(torch.from_numpy(cmap))
            self._cmaps[name] = cmap
        hi = cmap.shape[0] - 1
        x = (x * hi + 0.5).clamp(0, hi).to(torch.int64)
        x = torch.nn.functional.embedding(x, cmap)
        return x

    def init_network(self, res,
                     pkl=None,
                     w0_seed=0,
                     w_load=None,
                     w_plus=True,
                     noise_mode='const',
                     trunc_psi=0.7,
                     trunc_cutoff=None,
                     input_transform=None,
                     lr=0.001,
                     **kwargs
                     ):
        """Load the generator and initialize the latent ``w`` and its Adam optimizer.

        Args:
            res: Result dict; network metadata (``img_resolution``, ``num_ws``,
                ``has_noise``, ``has_input_transform``, ``stop``) is written into it.
            pkl: Network pickle path/URL passed to ``get_network``.
            w0_seed: Seed for the random ``z`` (used when ``w_load`` is None).
            w_load: Optional precomputed latent tensor to start from.
            w_plus: If True, optimize the full per-layer ``w``; otherwise only
                the first layer's ``w[:, 0, :]``.
            trunc_psi, trunc_cutoff: Truncation settings for ``G.mapping``.
            input_transform: Optional matrix; its inverse is copied into
                ``G.synthesis.input.transform`` when the network has one.
            lr: Adam learning rate.
            noise_mode, kwargs: accepted but unused here.
        """
        # Dig up network details.
        self.pkl = pkl
        G = self.get_network(pkl, 'G_ema')
        self.G = G
        res.img_resolution = G.img_resolution
        res.num_ws = G.num_ws
        res.has_noise = any('noise_const' in name for name, _buf in G.synthesis.named_buffers())
        res.has_input_transform = (
            hasattr(G.synthesis, 'input') and hasattr(G.synthesis.input, 'transform'))
        res.stop = False

        # Set input transform.
        if res.has_input_transform:
            m = np.eye(3)
            try:
                if input_transform is not None:
                    m = np.linalg.inv(np.asarray(input_transform))
            except np.linalg.LinAlgError:
                res.error = CapturedException()
            G.synthesis.input.transform.copy_(torch.from_numpy(m))

        self.w0_seed = w0_seed
        self.w_load = w_load

        if self.w_load is None:
            # Random z of the generator's latent size (512 if G does not say).
            z_dim = getattr(G, 'z_dim', 512)
            z = torch.from_numpy(np.random.RandomState(w0_seed).randn(1, z_dim)).to(
                self._device, dtype=self._dtype)
            # Run mapping network with an all-zero class label.
            label = torch.zeros([1, G.c_dim], device=self._device)
            w = G.mapping(z, label, truncation_psi=trunc_psi, truncation_cutoff=trunc_cutoff)
        else:
            w = self.w_load.clone().to(self._device)

        self.w0 = w.detach().clone()
        self.w_plus = w_plus
        if w_plus:
            self.w = w.detach()
        else:
            self.w = w[:, 0, :].detach()
        self.w.requires_grad = True
        self.w_optim = torch.optim.Adam([self.w], lr=lr)

        self.feat_refs = None
        self.points0_pt = None

    def update_lr(self, lr):
        """Recreate the Adam optimizer for ``self.w`` with learning rate ``lr``.

        Tracked reference features and points are kept.
        """
        del self.w_optim
        self.w_optim = torch.optim.Adam([self.w], lr=lr)
        print(f'Rebuild optimizer with lr: {lr}')
        print(' Remain feat_refs and points0_pt')

    def _render_drag_impl(self, res,
                          points=None,
                          targets=None,
                          mask=None,
                          lambda_mask=10,
                          reg=0,
                          feature_idx=5,
                          r1=3,
                          r2=12,
                          random_seed=0,
                          noise_mode='const',
                          trunc_psi=0.7,
                          force_fp32=False,
                          layer_name=None,
                          sel_channels=3,
                          base_channel=0,
                          img_scale_db=0,
                          img_normalize=False,
                          untransform=False,
                          is_drag=False,
                          reset=False,
                          to_pil=False,
                          **kwargs
                          ):
        """Synthesize an image and, if ``is_drag``, take one optimization step on ``w``.

        Args:
            res: Result dict. ``image``, ``w`` and (when dragging) ``points``
                and ``stop`` are written into it.
            points: Handle points as ``[row, col]`` pairs. The list is UPDATED
                IN PLACE with the tracked positions.
            targets: Target points, same ``[row, col]`` convention as ``points``.
            mask: Optional tensor. If its min is 0 and its max is 1, features
                where it is nonzero are pulled toward the initial features with
                weight ``lambda_mask``.
            reg: Weight of the L1 penalty between the current latent and ``w0``.
            feature_idx: Index into the feature list returned by ``G``.
            r1, r2: Motion-supervision and tracking radii, expressed for a
                512-px image and scaled by ``img_resolution / 512``.
            noise_mode, trunc_psi: Passed to the generator.
            img_scale_db, img_normalize: Output intensity scaling.
            is_drag: Whether to track points and step the optimizer.
            reset: Forget the reference features and initial points.
            to_pil: Return ``res.image`` as a PIL image instead of a tensor.
            random_seed, force_fp32, layer_name, sel_channels, base_channel,
            untransform, kwargs: accepted for interface compatibility; unused.
        """
        # Fresh lists per call instead of shared mutable defaults (``points``
        # is mutated in place below).
        if points is None:
            points = []
        if targets is None:
            targets = []

        G = self.G
        ws = self.w
        if ws.dim() == 2:
            ws = ws.unsqueeze(1).repeat(1, 6, 1)
        # Only the first 6 layers use the optimized latent; the rest stay at w0.
        ws = torch.cat([ws[:, :6, :], self.w0[:, 6:, :]], dim=1)
        if hasattr(self, 'points'):
            if len(points) != len(self.points):
                reset = True
        if reset:
            self.feat_refs = None
            self.points0_pt = None
        self.points = points

        # Run synthesis network.
        label = torch.zeros([1, G.c_dim], device=self._device)
        img, feat = G(ws, label, truncation_psi=trunc_psi, noise_mode=noise_mode,
                      input_is_w=True, return_feature=True)

        h, w = G.img_resolution, G.img_resolution

        if is_drag:
            # Integer pixel coordinate grids: xx[i, j] = i (row), yy[i, j] = j (col).
            # (linspace(0, h, h) spaced samples by h / (h - 1), drifting by up
            # to one pixel at the far edge.) Moved to the device once, not per point.
            X = torch.arange(h, dtype=torch.float32)
            Y = torch.arange(w, dtype=torch.float32)
            xx, yy = torch.meshgrid(X, Y, indexing='ij')
            xx = xx.to(self._device)
            yy = yy.to(self._device)
            feat_resize = F.interpolate(feat[feature_idx], [h, w], mode='bilinear')
            if self.feat_refs is None:
                self.feat0_resize = F.interpolate(feat[feature_idx].detach(), [h, w], mode='bilinear')
                self.feat_refs = []
                for point in points:
                    py, px = round(point[0]), round(point[1])
                    self.feat_refs.append(self.feat0_resize[:, :, py, px])
                self.points0_pt = torch.Tensor(points).unsqueeze(0).to(self._device)  # 1, N, 2

            # Point tracking: move each point to the pixel in a (2r+1)-square
            # window whose feature is nearest (L2) to its reference feature.
            with torch.no_grad():
                r = round(r2 / 512 * h)
                for j, point in enumerate(points):
                    up = max(point[0] - r, 0)
                    down = min(point[0] + r + 1, h)
                    left = max(point[1] - r, 0)
                    right = min(point[1] + r + 1, w)
                    feat_patch = feat_resize[:, :, up:down, left:right]
                    L2 = torch.linalg.norm(feat_patch - self.feat_refs[j].reshape(1, -1, 1, 1), dim=1)
                    _, idx = torch.min(L2.view(1, -1), -1)
                    flat = idx.item()
                    width = right - left
                    points[j] = [flat // width + up, flat % width + left]

            res.points = [[point[0], point[1]] for point in points]

            # Motion supervision.
            loss_motion = 0
            res.stop = True
            stop_thresh = max(2 / 512 * h, 2)
            motion_radius = round(r1 / 512 * h)
            for j, point in enumerate(points):
                # (d_col, d_row) from the current point toward its target.
                direction = torch.Tensor([targets[j][1] - point[1], targets[j][0] - point[0]])
                if torch.linalg.norm(direction) > stop_thresh:
                    res.stop = False
                if torch.linalg.norm(direction) > 1:
                    distance = ((xx - point[0]) ** 2 + (yy - point[1]) ** 2) ** 0.5
                    relis, reljs = torch.where(distance < motion_radius)
                    direction = direction / (torch.linalg.norm(direction) + 1e-7)
                    # Sample features one unit step back along ``direction``
                    # (normalized to [-1, 1] for grid_sample).
                    gridh = (relis - direction[1]) / (h - 1) * 2 - 1
                    gridw = (reljs - direction[0]) / (w - 1) * 2 - 1
                    grid = torch.stack([gridw, gridh], dim=-1).unsqueeze(0).unsqueeze(0)
                    target = F.grid_sample(feat_resize.float(), grid, align_corners=True).squeeze(2)
                    loss_motion += F.l1_loss(feat_resize[:, :, relis, reljs], target.detach())

            loss = loss_motion
            if mask is not None:
                if mask.min() == 0 and mask.max() == 1:
                    mask_usq = mask.to(self._device).unsqueeze(0).unsqueeze(0)
                    loss_fix = F.l1_loss(feat_resize * mask_usq, self.feat0_resize * mask_usq)
                    loss += lambda_mask * loss_fix

            loss += reg * F.l1_loss(ws, self.w0)  # latent code regularization
            if not res.stop:
                self.w_optim.zero_grad()
                loss.backward()
                self.w_optim.step()

        # Scale and convert to uint8 HWC.
        img = img[0]
        if img_normalize:
            img = img / img.norm(float('inf'), dim=[1, 2], keepdim=True).clip(1e-8, 1e8)
        img = img * (10 ** (img_scale_db / 20))
        img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8).permute(1, 2, 0)
        if to_pil:
            img = Image.fromarray(img.cpu().numpy())
        res.image = img
        res.w = ws.detach().cpu().numpy()

# ----------------------------------------------------------------------------