# StyleNeRF / app.py
# Last commit: "Update app.py" by thomagram (f5f804f)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import os, sys
os.system('pip install -r requirements.txt')
import gradio as gr
import numpy as np
import dnnlib
import time
import legacy
import torch
import glob
import cv2
from torch_utils import misc
from renderer import Renderer
from training.networks import Generator
from huggingface_hub import hf_hub_download
# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Optional port from the first CLI argument. A missing or non-integer argument
# falls back to 21111 instead of crashing the app at import time.
# NOTE(review): `port` is not passed to launch() anywhere in this file as shown.
try:
    port = int(sys.argv[1]) if len(sys.argv) > 1 else 21111
except ValueError:
    port = 21111

# Pretrained checkpoints selectable in the UI: display name -> kwargs for
# huggingface_hub.hf_hub_download (repo_id + filename inside that repo).
model_lists = {
    'ffhq-512x512-basic': dict(repo_id='facebook/stylenerf-ffhq-config-basic', filename='ffhq_512.pkl'),
    'ffhq-512x512-cc': dict(repo_id='facebook/stylenerf-ffhq-config-basic', filename='ffhq_512_cc.pkl'),
    'ffhq-256x256-basic': dict(repo_id='facebook/stylenerf-ffhq-config-basic', filename='ffhq_256.pkl'),
    'ffhq-1024x1024-basic': dict(repo_id='facebook/stylenerf-ffhq-config-basic', filename='ffhq_1024.pkl'),
}
# Dict keys in insertion order; model_names[0] is the model loaded at startup.
model_names = list(model_lists)
def set_random_seed(seed):
    """Seed the global torch RNG, then the global (legacy) NumPy RNG, with *seed*."""
    for seeder in (torch.manual_seed, np.random.seed):
        seeder(seed)
def get_camera_traj(model, pitch, yaw, fov=12, batch_size=1, model_name=None):
    """Build the camera for a single viewpoint of *model*.

    Args:
        model: generator whose ``synthesis`` exposes ``C.range_u``,
            ``C.range_v`` and ``get_camera(...)``.
        pitch, yaw: view angles as given by the UI sliders (range [-1, 1]).
        fov: field of view forwarded to ``get_camera``.
        batch_size: forwarded to ``get_camera``.
        model_name: checkpoint name/path. Names containing 'car' or 'Car'
            map the angles linearly from [-1, 1] to [0, 1]. Everything else,
            including the default ``None``, uses the face-model mapping.

    Returns:
        The result of ``model.synthesis.get_camera`` for mode ``[u, v, 0.5]``.
    """
    gen = model.synthesis
    range_u, range_v = gen.C.range_u, gen.C.range_v
    # `'car' in None` would raise TypeError, so treat a missing name as ''.
    name = '' if model_name is None else str(model_name)
    if not ('car' in name or 'Car' in name):  # TODO: hack, better option?
        # Face models: damp the slider values, offset pitch by pi/2, then
        # normalise both angles relative to the generator's camera ranges.
        yaw, pitch = 0.5 * yaw, 0.3 * pitch
        pitch = pitch + np.pi / 2
        u = (yaw - range_u[0]) / (range_u[1] - range_u[0])
        v = (pitch - range_v[0]) / (range_v[1] - range_v[0])
    else:
        u = (yaw + 1) / 2
        v = (pitch + 1) / 2
    cam = gen.get_camera(batch_size=batch_size, mode=[u, v, 0.5], device=device, fov=fov)
    return cam
def check_name(model_name):
    """Resolve *model_name* to the checkpoint location that get_model loads.

    - A key of ``model_lists`` is fetched with ``hf_hub_download`` and the
      returned local path is used.
    - An existing directory resolves to its lexicographically last ``*.pkl``.
    - Anything else is returned unchanged.

    Raises:
        FileNotFoundError: *model_name* is a directory with no ``*.pkl`` file
            (previously this surfaced as a bare IndexError).
    """
    if model_name in model_lists:
        return hf_hub_download(**model_lists[model_name])
    if os.path.isdir(model_name):
        # Escape the directory part so names containing '[', '*' or '?' are
        # matched literally; only the '*.pkl' suffix is a pattern.
        candidates = sorted(glob.glob(os.path.join(glob.escape(model_name), '*.pkl')))
        if not candidates:
            raise FileNotFoundError(f'No .pkl checkpoint found in directory "{model_name}"')
        return candidates[-1]
    return model_name
def get_model(network_pkl, render_option=None):
    """Load the generator stored in *network_pkl* and run one warm-up pass.

    The pickled ``G_ema`` is rebuilt as ``training.networks.Generator`` with
    its parameters and buffers copied over. It is put in eval mode and run
    once on a fixed latent to obtain the output resolution.

    Args:
        network_pkl: location passed to ``dnnlib.util.open_url``.
        render_option: rendering-option string forwarded to the generator.

    Returns:
        ``(G, res, imgs)``:
        - ``G``: the generator.
        - ``res``: the last dimension of the warm-up image.
        - ``imgs``: a zero ``(res, res // 2, 3)`` canvas that f_synthesis
          fills with the two seed thumbnails.
    """
    # SECURITY NOTE(review): this unpickles whatever `network_pkl` points to,
    # and the UI lets users type an arbitrary checkpoint location (model_find).
    # Unpickling untrusted data can execute arbitrary code; restrict the
    # accepted sources if the demo is publicly exposed.
    print('Loading networks from "%s"...' % network_pkl)
    with dnnlib.util.open_url(network_pkl) as f:
        network = legacy.load_network_pkl(f)
    G = network['G_ema'].to(device)  # type: ignore

    with torch.no_grad():
        G2 = Generator(*G.init_args, **G.init_kwargs).to(device)
        misc.copy_params_and_buffers(G, G2, require_all=False)

    print('compile and go through the initial image')
    G2 = G2.eval()
    # Warm-up pass: only the output shape is needed, so skip autograd
    # bookkeeping (the original ran this with gradients enabled).
    with torch.no_grad():
        init_z = torch.from_numpy(np.random.RandomState(0).rand(1, G2.z_dim)).to(device)
        init_cam = get_camera_traj(G2, 0, 0, model_name=network_pkl)
        dummy = G2(z=init_z, c=None, camera_matrices=init_cam, render_option=render_option, theta=0)
    res = dummy['img'].shape[-1]
    imgs = np.zeros((res, res // 2, 3))
    return G2, res, imgs
# Process-wide state shared by every request: [generator, resolution,
# thumbnail canvas]. It is initialised with the first pretrained model.
global_states = [*get_model(check_name(model_names[0]))]
# Cached style codes (numpy) for seed1 / seed2; None until first computed.
wss = [None] * 2
def proc_seed(history, seed):
    """Normalize a raw seed value from the UI into an ``int``.

    Strings and ``None`` map to 0 (same rule f_synthesis applies inline).
    Any other value is converted with ``int``, which truncates floats.

    Args:
        history: unused; kept so existing call signatures keep working.
        seed: raw value from a seed ``Number`` input.

    Returns:
        The normalized seed. The previous body computed it but returned None.
    """
    if seed is None or isinstance(seed, str):
        return 0
    return int(seed)
def _to_hwc_unit(image):
    """Return image 0 of a batched generator output as an HWC numpy array in [0, 1].

    Assumes a (N, C, H, W) tensor with values nominally in [-1, 1].
    """
    return image[0].permute(1, 2, 0).detach().cpu().numpy().clip(-1, 1) * 0.5 + 0.5


def _encode_seed(model, seed, trunc, model_name, render_option):
    """Map *seed* to style codes and render its default-view preview.

    Returns ``(ws, preview)``: ``ws`` as a numpy array, ``preview`` as an
    HWC array in [0, 1].
    """
    set_random_seed(seed)
    z = torch.from_numpy(np.random.RandomState(int(seed)).randn(1, model.z_dim).astype('float32')).to(device)
    # Inference only: no autograd graph is needed for the codes or the preview.
    with torch.no_grad():
        ws = model.mapping(z=z, c=None, truncation_psi=trunc)
        img = model.get_final_output(
            styles=ws,
            camera_matrices=get_camera_traj(model, 0, 0, model_name=model_name),
            render_option=render_option)
    return ws.detach().cpu().numpy(), _to_hwc_unit(img)


def _compose_with_thumbnails(thumbs, image):
    """Place the seed-thumbnail column left of *image*, rescaling it if heights differ."""
    if thumbs.shape[0] != image.shape[0]:
        a = image.shape[0]
        b = int(thumbs.shape[1] / thumbs.shape[0] * a)
        print(f'resize {a} {b} {image.shape} {thumbs.shape}')
        # cv2.resize's 3rd positional argument is `dst`, so interpolation must
        # be passed by keyword (it was silently ignored before).
        thumbs = cv2.resize(thumbs, (b, a), interpolation=cv2.INTER_AREA)
    return np.concatenate([thumbs, image], 1)


def f_synthesis(model_name, model_find, render_option, early, trunc, seed1, seed2, mix1, mix2, yaw, pitch, roll, fov, history):
    """Gradio callback: render a style mix of two seeds from the requested view.

    Args:
        model_name: dropdown choice (key of ``model_lists`` or a path).
        model_find: optional checkpoint path; overrides *model_name* if non-empty.
        render_option: comma-separated rendering options.
        early: 'None', 'Normal Map' or 'Gradient Map' (intermediate output).
        trunc: truncation in percent (0-100).
        seed1, seed2: seeds. String values fall back to 0.
        mix1, mix2: mixing weights of seed1 for style layers [:8] and [8:].
        yaw, pitch, roll: view angles from the UI sliders.
        fov: field of view.
        history: per-session state dict (falsy on the first call).

    Returns:
        ``(uint8 HWC image, history)``. The image is the thumbnail column
        followed by the render.

    NOTE(review): global_states and wss are module-level, so all sessions
    share (and may overwrite) the loaded model and cached style codes.
    """
    history = history or {}
    trunc = trunc / 100

    # A non-empty checkpoint path overrides the dropdown selection.
    if model_find != "":
        model_name = model_find
    model_name = check_name(model_name)

    # Reload when this session last used a different checkpoint.
    if model_name != history.get("model_name", None):
        global_states[:] = get_model(model_name, render_option)
    model, res, imgs = global_states

    seeds = []
    for idx, seed in enumerate([seed1, seed2]):
        seed = 0 if isinstance(seed, str) else int(seed)
        stale = (
            seed != history.get(f'seed{idx}', -1)
            or model_name != history.get("model_name", None)
            or trunc != history.get("trunc", 0.7)
            or wss[idx] is None
        )
        if stale:
            print(f'use seed {seed}')
            ws, preview = _encode_seed(model, seed, trunc, model_name, render_option)
            imgs[idx * res // 2: (1 + idx) * res // 2] = cv2.resize(
                preview, (res // 2, res // 2), interpolation=cv2.INTER_AREA)
            wss[idx] = ws
        else:
            seed = history[f'seed{idx}']
        seeds.append(seed)
        history[f'seed{idx}'] = seed
    # Record trunc/model only after BOTH seeds were compared against the old values.
    history['trunc'] = trunc
    history['model_name'] = model_name

    set_random_seed(sum(seeds))

    # Style mixing: layers [:8] blend by mix1 (labelled "geometry" in the UI),
    # layers [8:] by mix2 (labelled "appearance").
    ws1, ws2 = (torch.from_numpy(w).to(device) for w in wss)
    ws = ws1.clone()
    ws[:, :8] = ws1[:, :8] * mix1 + ws2[:, :8] * (1 - mix1)
    ws[:, 8:] = ws1[:, 8:] * mix2 + ws2[:, 8:] * (1 - mix2)

    # Optional intermediate visualizations.
    if early == 'Normal Map':
        render_option += ',normal,early'
    elif early == 'Gradient Map':
        render_option += ',gradient,early'

    start_t = time.perf_counter()
    with torch.no_grad():
        cam = get_camera_traj(model, pitch, yaw, fov, model_name=model_name)
        image = model.get_final_output(
            styles=ws, camera_matrices=cam,
            theta=roll * np.pi,
            render_option=render_option)
    end_t = time.perf_counter()

    image = _compose_with_thumbnails(imgs, _to_hwc_unit(image))
    print(f'rendering time = {end_t-start_t:.4f}s')
    image = (image * 255).astype('uint8')
    return image, history
# Gradio input components, in the positional order f_synthesis expects.
# An explicit default prevents a None selection, which check_name cannot resolve.
model_name = gr.inputs.Dropdown(model_names, default=model_names[0])
model_find = gr.inputs.Textbox(label="Checkpoint path (folder or .pkl file)", default="")
render_option = gr.inputs.Textbox(label="Additional rendering options", default='freeze_bg,steps:50')
trunc = gr.inputs.Slider(default=70, maximum=100, minimum=0, label='Truncation trick (%)')
seed1 = gr.inputs.Number(default=1, label="Random seed1")
seed2 = gr.inputs.Number(default=9, label="Random seed2")
mix1 = gr.inputs.Slider(minimum=0, maximum=1, default=0, label="Linear mixing ratio (geometry)")
mix2 = gr.inputs.Slider(minimum=0, maximum=1, default=0, label="Linear mixing ratio (appearance)")
early = gr.inputs.Radio(['None', 'Normal Map', 'Gradient Map'], default='None', label='Intermediate output')
yaw = gr.inputs.Slider(minimum=-1, maximum=1, default=0, label="Yaw")
pitch = gr.inputs.Slider(minimum=-1, maximum=1, default=0, label="Pitch")
roll = gr.inputs.Slider(minimum=-1, maximum=1, default=0, label="Roll (optional, not suggested for basic config)")
fov = gr.inputs.Slider(minimum=10, maximum=14, default=12, label="Fov")
# Enlarge the image widgets in the web UI.
css = ".output-image, .input-image, .image-preview {height: 600px !important} "

# Live interface: f_synthesis re-runs whenever an input changes; the trailing
# "state" input/output carries the per-session history dict.
demo = gr.Interface(
    fn=f_synthesis,
    inputs=[
        model_name, model_find, render_option, early, trunc,
        seed1, seed2, mix1, mix2, yaw, pitch, roll, fov, "state",
    ],
    title="Interactive Web Demo for StyleNeRF (ICLR 2022)",
    description="StyleNeRF: A Style-based 3D-Aware Generator for High-resolution Image Synthesis. Currently the demo runs on CPU only.",
    outputs=["image", "state"],
    layout='unaligned',
    css=css,
    theme='dark-seafoam',
    live=True,
)
demo.launch(enable_queue=True)