# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""Interactive Gradio demo for StyleNeRF.

Running this script installs the requirements, loads the first checkpoint in
``model_lists``, builds the Gradio widgets and launches the web interface.
"""
import os
import sys

# Must run before the third-party imports below: they may only become
# importable once the requirements have been installed.
os.system('pip install -r requirements.txt')

import gradio as gr
import numpy as np
import dnnlib
import time
import legacy
import torch
import glob

import cv2

from torch_utils import misc
from renderer import Renderer
from training.networks import Generator
from huggingface_hub import hf_hub_download


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Parsed from the first CLI argument; not passed to ``launch()`` in this file.
port = int(sys.argv[1]) if len(sys.argv) > 1 else 21111

# Display name -> keyword arguments for ``hf_hub_download``.
model_lists = {
    'ffhq-512x512-basic': dict(repo_id='facebook/stylenerf-ffhq-config-basic', filename='ffhq_512.pkl'),
    'ffhq-512x512-cc': dict(repo_id='facebook/stylenerf-ffhq-config-basic', filename='ffhq_512_cc.pkl'),
    'ffhq-256x256-basic': dict(repo_id='facebook/stylenerf-ffhq-config-basic', filename='ffhq_256.pkl'),
    'ffhq-1024x1024-basic': dict(repo_id='facebook/stylenerf-ffhq-config-basic', filename='ffhq_1024.pkl'),
}
model_names = [name for name in model_lists]


def set_random_seed(seed):
    """Seed both the torch and the numpy global random generators."""
    torch.manual_seed(seed)
    np.random.seed(seed)


def get_camera_traj(model, pitch, yaw, fov=12, batch_size=1, model_name=None):
    """Build camera matrices for the given pitch / yaw.

    Args:
        model: generator exposing ``synthesis.C.range_u`` / ``range_v`` and
            ``synthesis.get_camera``.
        pitch, yaw: view angles as provided by the UI sliders.
        fov: field of view forwarded to ``get_camera``.
        batch_size: number of cameras to create.
        model_name: name or path of the checkpoint. A name containing
            ``car`` / ``Car`` selects the alternative (u, v) mapping.
            ``None`` is treated as an empty name.

    Returns:
        Whatever ``model.synthesis.get_camera`` returns.
    """
    gen = model.synthesis
    range_u, range_v = gen.C.range_u, gen.C.range_v
    # The default ``model_name=None`` previously crashed on the ``in`` test.
    name = model_name or ''
    if not (('car' in name) or ('Car' in name)):  # TODO: hack, better option?
        yaw, pitch = 0.5 * yaw, 0.3 * pitch
        pitch = pitch + np.pi / 2
        u = (yaw - range_u[0]) / (range_u[1] - range_u[0])
        v = (pitch - range_v[0]) / (range_v[1] - range_v[0])
    else:
        u = (yaw + 1) / 2
        v = (pitch + 1) / 2
    cam = gen.get_camera(batch_size=batch_size, mode=[u, v, 0.5], device=device, fov=fov)
    return cam


def check_name(model_name):
    """Resolve a model name to a checkpoint path.

    * A key of ``model_lists`` is fetched with ``hf_hub_download``.
    * A directory resolves to its last ``*.pkl`` file in sorted order.
    * Anything else is returned unchanged.

    Raises:
        FileNotFoundError: if ``model_name`` is a directory that contains no
            ``.pkl`` file.
    """
    if model_name in model_lists:
        return hf_hub_download(**model_lists[model_name])
    if os.path.isdir(model_name):
        candidates = sorted(glob.glob(model_name + '/*.pkl'))
        if not candidates:
            raise FileNotFoundError(f'no .pkl checkpoint found in "{model_name}"')
        return candidates[-1]
    return model_name


def get_model(network_pkl, render_option=None):
    """Load a generator and run one forward pass to find its resolution.

    Args:
        network_pkl: path or URL understood by ``dnnlib.util.open_url``.
        render_option: forwarded to the generator for the initial pass.

    Returns:
        ``(generator, res, imgs)`` where ``res`` is the last dimension of the
        generated image and ``imgs`` is a zero-filled ``(res, res // 2, 3)``
        array used for the two stacked seed thumbnails.
    """
    print('Loading networks from "%s"...' % network_pkl)
    with dnnlib.util.open_url(network_pkl) as f:
        network = legacy.load_network_pkl(f)
        G = network['G_ema'].to(device)  # type: ignore

    # Re-instantiate the generator from the current source code and copy the
    # trained weights into it.
    with torch.no_grad():
        G2 = Generator(*G.init_args, **G.init_kwargs).to(device)
        misc.copy_params_and_buffers(G, G2, require_all=False)

    print('compile and go through the initial image')
    G2 = G2.eval()
    init_z = torch.from_numpy(np.random.RandomState(0).rand(1, G2.z_dim)).to(device)
    init_cam = get_camera_traj(G2, 0, 0, model_name=network_pkl)
    dummy = G2(z=init_z, c=None, camera_matrices=init_cam, render_option=render_option, theta=0)
    res = dummy['img'].shape[-1]
    imgs = np.zeros((res, res // 2, 3))
    return G2, res, imgs


# Process-wide state: [generator, resolution, thumbnail strip] and the cached
# style codes of the two seeds.
# NOTE(review): these are module globals while ``history`` appears to be
# per-session state -- confirm behavior with several concurrent users.
global_states = list(get_model(check_name(model_names[0])))
wss = [None, None]


def proc_seed(history, seed):
    """Normalise a seed coming from the UI to an ``int``.

    Strings (and ``None``) map to 0; everything else goes through ``int()``.
    ``history`` is unused and kept only for interface compatibility.
    """
    if seed is None or isinstance(seed, str):
        return 0
    return int(seed)


def f_synthesis(model_name, model_find, render_option, early, trunc, seed1, seed2, mix1, mix2, yaw, pitch, roll, fov, history):
    """Gradio callback: render the mixed image next to both seed thumbnails.

    Args:
        model_name: entry chosen in the model dropdown.
        model_find: optional checkpoint path; overrides ``model_name`` when
            non-empty.
        render_option: comma separated rendering options string.
        early: ``'Normal Map'`` / ``'Gradient Map'`` append extra options to
            ``render_option`` for the final render; other values do not.
        trunc: truncation in percent (divided by 100 before use).
        seed1, seed2: seeds of the two source latents.
        mix1, mix2: weights of seed1 for style layers ``[:8]`` and ``[8:]``
            (labelled geometry and appearance in the UI).
        yaw, pitch, roll, fov: camera controls.
        history: state dict from the previous call, or a falsy value.

    Returns:
        ``(image, history)`` with ``image`` a ``uint8`` array.
    """
    history = history or {}
    trunc = trunc / 100

    if model_find != "":
        model_name = model_find
    model_name = check_name(model_name)

    # Evaluate the change flags once, before ``history`` is updated, so that
    # both seeds see the same answer.
    model_changed = model_name != history.get("model_name", None)
    trunc_changed = trunc != history.get("trunc", 0.7)

    if model_changed:
        global_states[:] = get_model(model_name, render_option)
    model, res, imgs = global_states
    half = res // 2

    seeds = []
    for idx, raw_seed in enumerate([seed1, seed2]):
        seed = proc_seed(history, raw_seed)
        stale = (
            seed != history.get(f'seed{idx}', -1)
            or model_changed
            or trunc_changed
            or wss[idx] is None
        )
        if stale:
            print(f'use seed {seed}')
            set_random_seed(seed)
            z = torch.from_numpy(np.random.RandomState(int(seed)).randn(1, model.z_dim).astype('float32')).to(device)
            ws = model.mapping(z=z, c=None, truncation_psi=trunc)
            img = model.get_final_output(
                styles=ws,
                camera_matrices=get_camera_traj(model, 0, 0, model_name=model_name),
                render_option=render_option)
            ws = ws.detach().cpu().numpy()
            img = img[0].permute(1, 2, 0).detach().cpu().numpy()

            # ``interpolation`` must be passed by keyword: the third
            # positional parameter of ``cv2.resize`` is ``dst``.
            imgs[idx * res // 2: (1 + idx) * res // 2] = cv2.resize(
                np.asarray(img).clip(-1, 1) * 0.5 + 0.5,
                (half, half), interpolation=cv2.INTER_AREA)
            wss[idx] = ws
        seeds.append(seed)
        history[f'seed{idx}'] = seed

    history['trunc'] = trunc
    history['model_name'] = model_name

    set_random_seed(sum(seeds))

    # style mixing (?)
    ws1, ws2 = [torch.from_numpy(ws).to(device) for ws in wss]
    ws = ws1.clone()
    ws[:, :8] = ws1[:, :8] * mix1 + ws2[:, :8] * (1 - mix1)
    ws[:, 8:] = ws1[:, 8:] * mix2 + ws2[:, 8:] * (1 - mix2)

    # set visualization for other types of inputs.
    if early == 'Normal Map':
        render_option += ',normal,early'
    elif early == 'Gradient Map':
        render_option += ',gradient,early'

    start_t = time.time()
    with torch.no_grad():
        cam = get_camera_traj(model, pitch, yaw, fov, model_name=model_name)
        image = model.get_final_output(
            styles=ws, camera_matrices=cam,
            theta=roll * np.pi,
            render_option=render_option)
    end_t = time.time()

    image = image[0].permute(1, 2, 0).detach().cpu().numpy().clip(-1, 1) * 0.5 + 0.5

    if imgs.shape[0] == image.shape[0]:
        image = np.concatenate([imgs, image], 1)
    else:
        # The render height differs from the thumbnail strip: rescale the
        # strip to the render height, keeping its aspect ratio.
        a = image.shape[0]
        b = int(imgs.shape[1] / imgs.shape[0] * a)
        print(f'resize {a} {b} {image.shape} {imgs.shape}')
        image = np.concatenate([cv2.resize(imgs, (b, a), interpolation=cv2.INTER_AREA), image], 1)

    print(f'rendering time = {end_t-start_t:.4f}s')
    image = (image * 255).astype('uint8')
    return image, history


# Input widgets, in the same order as the parameters of ``f_synthesis``.
model_name = gr.inputs.Dropdown(model_names)
model_find = gr.inputs.Textbox(label="Checkpoint path (folder or .pkl file)", default="")
render_option = gr.inputs.Textbox(label="Additional rendering options", default='freeze_bg,steps:50')
trunc = gr.inputs.Slider(default=70, maximum=100, minimum=0, label='Truncation trick (%)')
seed1 = gr.inputs.Number(default=1, label="Random seed1")
seed2 = gr.inputs.Number(default=9, label="Random seed2")
mix1 = gr.inputs.Slider(minimum=0, maximum=1, default=0, label="Linear mixing ratio (geometry)")
mix2 = gr.inputs.Slider(minimum=0, maximum=1, default=0, label="Linear mixing ratio (apparence)")
early = gr.inputs.Radio(['None', 'Normal Map', 'Gradient Map'], default='None', label='Intermedia output')
yaw = gr.inputs.Slider(minimum=-1, maximum=1, default=0, label="Yaw")
pitch = gr.inputs.Slider(minimum=-1, maximum=1, default=0, label="Pitch")
roll = gr.inputs.Slider(minimum=-1, maximum=1, default=0, label="Roll (optional, not suggested for basic config)")
fov = gr.inputs.Slider(minimum=10, maximum=14, default=12, label="Fov")
css = ".output-image, .input-image, .image-preview {height: 600px !important} "

gr.Interface(fn=f_synthesis,
             inputs=[model_name, model_find, render_option, early, trunc, seed1, seed2, mix1, mix2, yaw, pitch, roll, fov, "state"],
             title="Interactive Web Demo for StyleNeRF (ICLR 2022)",
             description="StyleNeRF: A Style-based 3D-Aware Generator for High-resolution Image Synthesis. Currently the demo runs on CPU only.",
             outputs=["image", "state"],
             layout='unaligned',
             css=css, theme='dark-seafoam',
             live=True).launch(enable_queue=True)