# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved


"""Wrap the generator to render a sequence of images"""
import torch
import torch.nn.functional as F
import numpy as np
from torch import random
import tqdm
import copy
import trimesh


class Renderer(object):
    """Drive a generator's synthesis network to render sequences of frames.

    ``program`` names a ``render_<program>`` method that ``__call__``
    dispatches to. It may also be given as ``"<program>:<image folder>"``; in
    that case an ``ImageFolderDataset`` is loaded from the folder and a random
    real image is passed to the render method (as ``kwargs['img']``) and
    appended to every rendered frame.
    """

    def __init__(self, generator, discriminator=None, program=None):
        self.generator = generator
        self.discriminator = discriminator
        # Forwarded as ``tmp=`` to ``get_latent_codes`` (presumably a sampling
        # temperature -- TODO confirm against the synthesis network).
        self.sample_tmp = 0.65
        self.program = program
        self.seed = 0

        parts = program.split(':') if program is not None else []
        if len(parts) == 2:
            # "<program>:<image folder>": keep the program name and load the
            # real images that will be shown next to the renders.
            from training.dataset import ImageFolderDataset
            self.image_data = ImageFolderDataset(parts[1])
            self.program = parts[0]
        else:
            self.image_data = None

    def set_random_seed(self, seed):
        """Store ``seed`` and seed both torch and numpy global RNGs with it."""
        self.seed = seed
        torch.manual_seed(seed)
        np.random.seed(seed)

    def __call__(self, *args, **kwargs):
        """Render with the configured program.

        The generator is always switched to eval mode first. Without a
        program, the call is forwarded to ``generator.get_final_output`` when
        that exists, otherwise to ``generator`` itself. With a program, the
        call is dispatched to ``self.render_<program>(*args, **kwargs)``.

        When an image folder was configured, ``kwargs['z']`` must be present
        (its device is used for the sampled real image). The resized real
        image is concatenated along dim 0 to every returned frame.

        Returns:
            Whatever the render method returns: a list of frames, or a tuple
            ``(frames, extra)`` whose frames get the real image appended.
        """
        self.generator.eval()  # eval mode...

        if self.program is None:
            if hasattr(self.generator, 'get_final_output'):
                return self.generator.get_final_output(*args, **kwargs)
            return self.generator(*args, **kwargs)

        if self.image_data is not None:
            batch_size = 1
            indices = (np.random.rand(batch_size) * len(self.image_data)).tolist()
            rimages = np.stack([self.image_data._load_raw_image(int(i)) for i in indices], 0)
            # Rescale raw pixels to [-1, 1]; assumes raw values lie in
            # [0, 255] -- TODO confirm against ImageFolderDataset.
            rimages = torch.from_numpy(rimages).float().to(kwargs['z'].device) / 127.5 - 1
            kwargs['img'] = rimages

        outputs = getattr(self, f"render_{self.program}")(*args, **kwargs)

        if self.image_data is not None:
            is_tuple = isinstance(outputs, tuple)
            imgs = outputs[0] if is_tuple else outputs
            size = imgs[0].size(-1)
            rimg = F.interpolate(rimages, (size, size), mode='bicubic', align_corners=False)
            imgs = [torch.cat([img, rimg], 0) for img in imgs]
            outputs = (imgs, outputs[1]) if is_tuple else imgs
        return outputs

    def get_additional_params(self, ws, t=0):
        """Build extra keyword arguments for the synthesis network.

        Args:
            ws: Style codes; only ``ws.size(0)`` (batch) and ``ws.device``
                are used here.
            t: Phase along the 'pigan' camera trajectory (see
                ``get_camera_traj``).

        Returns:
            ``{}`` if the synthesis network has no ``get_latent_codes``;
            otherwise a dict with ``latent_codes`` and ``camera_matrices``.
        """
        gen = self.generator.synthesis
        batch_size = ws.size(0)

        kwargs = {}
        if not hasattr(gen, 'get_latent_codes'):
            return kwargs

        # s_val, t_val, r_val = [[0, 0, 0]], [[0.5, 0.5, 0.5]], [0.]
        # kwargs["transformations"] = gen.get_transformations(batch_size=batch_size, mode=[s_val, t_val, r_val], device=ws.device)
        # kwargs["bg_rotation"] = gen.get_bg_rotation(batch_size, device=ws.device)
        # kwargs["light_dir"] = gen.get_light_dir(batch_size, device=ws.device)
        kwargs["latent_codes"] = gen.get_latent_codes(batch_size, tmp=self.sample_tmp, device=ws.device)
        kwargs["camera_matrices"] = self.get_camera_traj(t, batch_size, device=ws.device)
        return kwargs

    def get_camera_traj(self, t, batch_size=1, traj_type='pigan', device='cpu'):
        """Return the camera at phase ``t`` of a closed trajectory.

        Only ``traj_type='pigan'`` is supported. The trajectory is an ellipse
        in (yaw, pitch) around pitch = pi/2 (see ``_pigan_uv``), and the
        camera is built with ``mode=[u, v, 0.5]``.

        Raises:
            NotImplementedError: for any other ``traj_type``.
        """
        gen = self.generator.synthesis
        if traj_type != 'pigan':
            raise NotImplementedError(f"Unsupported camera trajectory type: {traj_type!r}")
        u, v = self._pigan_uv(t, gen.C.range_u, gen.C.range_v)
        return gen.get_camera(batch_size=batch_size, mode=[u, v, 0.5], device=device)

    @staticmethod
    def _pigan_uv(t, range_u, range_v):
        """Map trajectory phase ``t`` to normalized camera coordinates (u, v).

        Yaw follows ``0.4 * sin(2*pi*t)`` and pitch ``pi/2 + 0.2 * cos(2*pi*t)``.
        Each angle is normalized so that ``range[0]`` maps to 0 and
        ``range[1]`` maps to 1.
        """
        pitch = 0.2 * np.cos(t * 2 * np.pi) + np.pi / 2
        yaw = 0.4 * np.sin(t * 2 * np.pi)
        u = (yaw - range_u[0]) / (range_u[1] - range_u[0])
        v = (pitch - range_v[0]) / (range_v[1] - range_v[0])
        return u, v

    def render_rotation_camera(self, *args, **kwargs):
        """Sweep the camera's u coordinate linearly with v fixed at 0.5.

        Required kwargs: ``n_steps`` (number of frames) and
        ``relative_range_u`` (``(start, end)`` of the u sweep). If ``img`` is
        given it is encoded with ``generator.encoder`` instead of mapping
        ``*args``. If ``return_cameras`` is truthy, the cameras are returned
        as well. All kwargs (except ``img``) are forwarded to the synthesis
        network as-is.

        Returns:
            List of frames, or ``(frames, cameras)``.
        """
        # NOTE(review): latent codes and cameras use batch_size=2 regardless
        # of ws's batch size -- presumably callers pass two z's; confirm.
        batch_size, n_steps = 2, kwargs["n_steps"]
        gen = self.generator.synthesis

        if 'img' not in kwargs:
            ws = self.generator.mapping(*args, **kwargs)
        else:
            ws, _ = self.generator.encoder(kwargs['img'])
        # ws = ws.repeat(batch_size, 1, 1)

        # kwargs["not_render_background"] = True
        if hasattr(gen, 'get_latent_codes'):
            kwargs["latent_codes"] = gen.get_latent_codes(batch_size, tmp=self.sample_tmp, device=ws.device)
        kwargs.pop('img', None)

        out = []
        cameras = []
        relative_range_u = kwargs['relative_range_u']
        u_samples = np.linspace(relative_range_u[0], relative_range_u[1], n_steps)
        for u in tqdm.tqdm(u_samples):
            # Set Camera
            kwargs["camera_matrices"] = gen.get_camera(batch_size=batch_size, mode=[u, 0.5, 0.5], device=ws.device)
            cameras.append(gen.get_camera(batch_size=batch_size, mode=[u, 0.5, 0.5], device=ws.device))
            with torch.no_grad():
                out_i = gen(ws, **kwargs)
                if isinstance(out_i, dict):
                    out_i = out_i['img']
            out.append(out_i)

        if kwargs.get('return_cameras'):
            return out, cameras
        return out

    def render_rotation_camera3(self, styles=None, *args, **kwargs):
        """Render 36 frames along the 'pigan' elliptical camera trajectory.

        Style codes come from ``styles`` if given, otherwise from
        ``generator.encoder(kwargs['img'])['ws']`` when ``img`` is present, or
        from ``generator.mapping(*args, **kwargs)``. Unlike the other methods,
        the third camera mode component is the phase ``t`` instead of 0.5.
        Noise mode is forced to ``'const'``.

        Returns:
            List of 36 frames.
        """
        gen = self.generator.synthesis
        n_steps = 36  # 120

        if styles is None:
            batch_size = 2
            if 'img' not in kwargs:
                ws = self.generator.mapping(*args, **kwargs)
            else:
                ws = self.generator.encoder(kwargs['img'])['ws']
                # ws = ws.repeat(batch_size, 1, 1)
        else:
            ws = styles
            batch_size = ws.size(0)

        # kwargs["not_render_background"] = True
        # Get Random codes and bg rotation
        # NOTE(review): this persistently changes the instance's sample_tmp,
        # so later calls to other render methods also use 0.72.
        self.sample_tmp = 0.72
        if hasattr(gen, 'get_latent_codes'):
            kwargs["latent_codes"] = gen.get_latent_codes(batch_size, tmp=self.sample_tmp, device=ws.device)
        kwargs.pop('img', None)

        # if getattr(gen, "use_noise", False):
        #     from dnnlib.geometry import extract_geometry
        #     kwargs['meshes'] = {}
        #     low_res, high_res = gen.resolution_vol, gen.img_resolution
        #     res = low_res * 2
        #     while res <= high_res:
        #         kwargs['meshes'][res] = [trimesh.Trimesh(*extract_geometry(gen, ws, resolution=res, threshold=30.))]
        #         kwargs['meshes'][res] += [
        #             torch.randn(len(kwargs['meshes'][res][0].vertices),
        #                 2, device=ws.device)[kwargs['meshes'][res][0].faces]]
        #         res = res * 2
        # if getattr(gen, "use_noise", False):
        #     kwargs['voxel_noise'] = gen.get_voxel_field(styles=ws, n_vols=2048, return_noise=True, sphere_noise=True)
        # if getattr(gen, "use_voxel_noise", False):
        #     kwargs['voxel_noise'] = gen.get_voxel_field(styles=ws, n_vols=128, return_noise=True)
        kwargs['noise_mode'] = 'const'

        out = []
        range_u, range_v = gen.C.range_u, gen.C.range_v
        for t in tqdm.tqdm(np.linspace(0, 1, n_steps)):
            u, v = self._pigan_uv(t, range_u, range_v)
            kwargs["camera_matrices"] = gen.get_camera(
                batch_size=batch_size, mode=[u, v, t], device=ws.device)

            with torch.no_grad():
                out_i = gen(ws, **kwargs)
                if isinstance(out_i, dict):
                    out_i = out_i['img']
            out.append(out_i)
        return out

    def render_rotation_both(self, *args, **kwargs):
        """Render 36 'pigan'-trajectory frames, each paired with a geometry view.

        Every frame is the regular render concatenated along dim 0 with a
        second render that uses
        ``render_option='early,no_background,up64,depth,normal'``. The second
        render is bicubically resized to the first's spatial size.

        Returns:
            List of 36 concatenated frames.
        """
        gen = self.generator.synthesis
        batch_size, n_steps = 1, 36
        if 'img' not in kwargs:
            ws = self.generator.mapping(*args, **kwargs)
        else:
            ws, _ = self.generator.encoder(kwargs['img'])
        # presumably ws is (batch, num_ws, dim) -- TODO confirm
        ws = ws.repeat(batch_size, 1, 1)

        # kwargs["not_render_background"] = True
        # Get Random codes and bg rotation
        kwargs["latent_codes"] = gen.get_latent_codes(batch_size, tmp=self.sample_tmp, device=ws.device)
        kwargs.pop('img', None)

        out = []
        range_u, range_v = gen.C.range_u, gen.C.range_v
        for t in tqdm.tqdm(np.linspace(0, 1, n_steps)):
            u, v = self._pigan_uv(t, range_u, range_v)
            kwargs["camera_matrices"] = gen.get_camera(
                batch_size=batch_size, mode=[u, v, 0.5], device=ws.device)

            with torch.no_grad():
                out_i = gen(ws, **kwargs)
                if isinstance(out_i, dict):
                    out_i = out_i['img']

                # Deep copy so the geometry render cannot alter the shared kwargs.
                kwargs_n = copy.deepcopy(kwargs)
                kwargs_n.update({'render_option': 'early,no_background,up64,depth,normal'})
                out_n = gen(ws, **kwargs_n)
                out_n = F.interpolate(out_n,
                    size=(out_i.size(-1), out_i.size(-1)),
                    mode='bicubic', align_corners=True)
                out_i = torch.cat([out_i, out_n], 0)
            out.append(out_i)
        return out

    def render_rotation_grid(self, styles=None, return_cameras=False, *args, **kwargs):
        """Render a 3 x 6 grid of views around the object.

        Camera positions lie on the unit sphere at
        ``(cos b * cos a, cos b * sin a, sin b)``, with ``a`` in [-0.4, 0.4]
        (6 steps) and ``b`` from 0.2 down to -0.2 (3 steps). Each camera is
        oriented with ``dnnlib.camera.look_at``.

        Returns:
            List of 18 frames, rows over ``b`` and columns over ``a``. If
            ``return_cameras`` is true, returns ``(frames, cameras)``, where
            ``cameras`` holds the ``camera_matrices`` tuple used per frame.
        """
        gen = self.generator.synthesis
        if styles is None:
            batch_size = 1
            ws = self.generator.mapping(*args, **kwargs)
            ws = ws.repeat(batch_size, 1, 1)
        else:
            ws = styles
            batch_size = ws.size(0)

        kwargs["latent_codes"] = gen.get_latent_codes(batch_size, tmp=self.sample_tmp, device=ws.device)
        kwargs.pop('img', None)

        if getattr(gen, "use_voxel_noise", False):
            kwargs['voxel_noise'] = gen.get_voxel_field(styles=ws, n_vols=128, return_noise=True)

        from dnnlib.camera import look_at

        out = []
        cameras = []
        # camera_matrix does not depend on the view, so build it once.
        camera_mat = gen.camera_matrix.repeat(batch_size, 1, 1).to(ws.device)

        a_steps, b_steps = 6, 3
        aspace = np.linspace(-0.4, 0.4, a_steps)
        bspace = np.linspace(-0.2, 0.2, b_steps) * -1
        for t_b in tqdm.tqdm(bspace):
            for t_a in aspace:
                loc_x = np.cos(t_b) * np.cos(t_a)
                loc_y = np.cos(t_b) * np.sin(t_a)
                loc_z = np.sin(t_b)
                loc = torch.tensor([[loc_x, loc_y, loc_z]], dtype=torch.float32).to(ws.device)
                R = look_at(loc)
                RT = torch.eye(4).reshape(1, 4, 4).repeat(batch_size, 1, 1)
                RT[:, :3, :3] = R
                RT[:, :3, -1] = loc

                world_mat = RT.to(ws.device)
                #kwargs["camera_matrices"] = gen.get_camera(
                #    batch_size=batch_size, mode=[u, v, 0.5], device=ws.device)
                kwargs["camera_matrices"] = (camera_mat, world_mat, "random", None)
                # Previously never filled, so return_cameras yielded [].
                cameras.append(kwargs["camera_matrices"])

                with torch.no_grad():
                    out_i = gen(ws, **kwargs)
                    if isinstance(out_i, dict):
                        out_i = out_i['img']

                    # kwargs_n = copy.deepcopy(kwargs)
                    # kwargs_n.update({'render_option': 'early,no_background,up64,depth,normal'})
                    # out_n = gen(ws, **kwargs_n)
                    # out_n = F.interpolate(out_n,
                    #     size=(out_i.size(-1), out_i.size(-1)),
                    #     mode='bicubic', align_corners=True)
                    # out_i = torch.cat([out_i, out_n], 0)
                out.append(out_i)

        if return_cameras:
            return out, cameras
        return out

    def render_rotation_camera_grid(self, *args, **kwargs):
        """Render three horizontal sweeps (v = 0.15, 0.5, 1.05) of 60 frames.

        In each sweep u goes linearly from -1.0 to 0.0. The synthesis network
        is always called with ``render_option=None``, so any caller-supplied
        ``render_option`` is dropped after mapping.

        Returns:
            List of 180 frames (sweeps concatenated in order).
        """
        batch_size, n_steps = 1, 60
        gen = self.generator.synthesis

        ws = self.generator.mapping(*args, **kwargs)
        ws = ws.repeat(batch_size, 1, 1)

        # Get Random codes and bg rotation
        kwargs["latent_codes"] = gen.get_latent_codes(batch_size, tmp=self.sample_tmp, device=ws.device)
        # render_option is passed explicitly below, so it must not also be in
        # kwargs; tolerate callers that never supplied it.
        kwargs.pop('render_option', None)

        out = []
        for v in [0.15, 0.5, 1.05]:
            for step in tqdm.tqdm(range(n_steps)):
                # Set Camera
                u = step * 1.0 / (n_steps - 1) - 1.0
                kwargs["camera_matrices"] = gen.get_camera(batch_size=batch_size, mode=[u, v, 0.5], device=ws.device)
                with torch.no_grad():
                    out_i = gen(ws, render_option=None, **kwargs)
                    if isinstance(out_i, dict):
                        out_i = out_i['img']
                    # option_n = 'early,no_background,up64,depth,direct_depth'
                    # option_n = 'early,up128,no_background,depth,normal'
                    # out_n = gen(ws, render_option=option_n, **kwargs)
                    # out_n = F.interpolate(out_n,
                    #     size=(out_i.size(-1), out_i.size(-1)),
                    #     mode='bicubic', align_corners=True)
                    # out_i = torch.cat([out_i, out_n], 0)
                out.append(out_i)

        # out += out[::-1]
        return out