# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import os, sys

# Install dependencies at startup (a common workaround for Hugging Face Spaces).
os.system('pip install -r requirements.txt')
import gradio as gr
import numpy as np
import dnnlib
import time
import legacy
import torch
import glob
import cv2
import signal
from torch_utils import misc
from renderer import Renderer
from training.networks import Generator
from huggingface_hub import hf_hub_download
device = torch.device('cuda')
port = int(sys.argv[1]) if len(sys.argv) > 1 else 21111
def handler(signum, frame):
    res = input("Ctrl-c was pressed. Do you really want to exit? y/n ")
    if res == 'y':
        gr.close_all()
        exit(1)

signal.signal(signal.SIGINT, handler)
def set_random_seed(seed):
    torch.manual_seed(seed)
    np.random.seed(seed)
def get_camera_traj(model, pitch, yaw, fov=12, batch_size=1, model_name='FFHQ512'):
    """Maps the UI yaw/pitch sliders (both in [-1, 1]) to a camera matrix."""
    gen = model.synthesis
    range_u, range_v = gen.C.range_u, gen.C.range_v
    if not (('car' in model_name) or ('Car' in model_name)):  # TODO: hack, better option?
        yaw, pitch = 0.5 * yaw, 0.3 * pitch
        pitch = pitch + np.pi / 2
        u = (yaw - range_u[0]) / (range_u[1] - range_u[0])
        v = (pitch - range_v[0]) / (range_v[1] - range_v[0])
    else:
        u = (yaw + 1) / 2
        v = (pitch + 1) / 2
    cam = gen.get_camera(batch_size=batch_size, mode=[u, v, 0.5], device=device, fov=fov)
    return cam
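# Illustrative slider-to-camera examples (commented out so nothing extra runs in
# the hosted demo): for non-car models the yaw/pitch sliders are scaled and
# normalized into the generator's (range_u, range_v) camera box.
# cam = get_camera_traj(model, pitch=0.0, yaw=0.0)   # frontal view
# cam = get_camera_traj(model, pitch=0.0, yaw=1.0)   # strongest yaw offset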
def check_name(model_name='FFHQ512'):
    """Resolves a model name to a checkpoint (.pkl) path or URL."""
    if model_name == 'FFHQ512':
        network_pkl = hf_hub_download(repo_id='thomagram/stylenerf-ffhq-config-basic', filename='ffhq_512.pkl')
    # TODO: checkpoints to be updated!
    # elif model_name == 'FFHQ512v2':
    #     network_pkl = "./pretrained/ffhq_512_eg3d.pkl"
    # elif model_name == 'AFHQ512':
    #     network_pkl = "./pretrained/afhq_512.pkl"
    # elif model_name == 'MetFaces512':
    #     network_pkl = "./pretrained/metfaces_512.pkl"
    # elif model_name == 'CompCars256':
    #     network_pkl = "./pretrained/cars_256.pkl"
    # elif model_name == 'FFHQ1024':
    #     network_pkl = "./pretrained/ffhq_1024.pkl"
    else:
        if os.path.isdir(model_name):
            network_pkl = sorted(glob.glob(model_name + '/*.pkl'))[-1]
        else:
            network_pkl = model_name
    return network_pkl
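# check_name() also resolves local checkpoints; both calls below are sketches
# with illustrative paths, not files shipped with this demo:
# network_pkl = check_name('./pretrained')      # newest *.pkl in a directory
# network_pkl = check_name('./my_model.pkl')    # explicit checkpoint path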
def get_model(network_pkl, render_option=None):
    print('Loading networks from "%s"...' % network_pkl)
    with dnnlib.util.open_url(network_pkl) as f:
        network = legacy.load_network_pkl(f)
        G = network['G_ema'].to(device)  # type: ignore

    with torch.no_grad():
        # Rebuild the generator from its init args and copy over the weights,
        # so the code in this repo (not the pickled code) is what runs.
        G2 = Generator(*G.init_args, **G.init_kwargs).to(device)
        misc.copy_params_and_buffers(G, G2, require_all=False)

    print('Compiling the model with an initial forward pass...')
    G2 = G2.eval()
    init_z = torch.from_numpy(np.random.RandomState(0).randn(1, G2.z_dim).astype('float32')).to(device)
    init_cam = get_camera_traj(G2, 0, 0, model_name=network_pkl)
    dummy = G2(z=init_z, c=None, camera_matrices=init_cam, render_option=render_option, theta=0)
    res = dummy['img'].shape[-1]
    imgs = np.zeros((res, res // 2, 3))  # canvas for the two reference views, stacked vertically
    return G2, res, imgs
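# A direct loading example (commented out; the hosted demo loads the default
# FFHQ512 checkpoint just below, and this path is only illustrative):
# G, res, imgs = get_model('./pretrained/ffhq_512.pkl', render_option='steps:40')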
global_states = list(get_model(check_name()))  # [model, res, imgs]
wss = [None, None]  # cached W codes for seed1 / seed2
def proc_seed(history, seed):
    """Normalizes a seed input from the UI to an int (0 for non-numeric text)."""
    if isinstance(seed, str):
        seed = 0
    else:
        seed = int(seed)
    return seed
def f_synthesis(model_name, model_find, render_option, early, trunc,
                seed1, seed2, mix1, mix2, yaw, pitch, roll, fov, history):
    history = history or {}
    seeds = []

    # Reload the generator if a new checkpoint path or model name was given.
    if model_find != "":
        model_name = model_find
    model_name = check_name(model_name)
    if model_name != history.get("model_name", None):
        model, res, imgs = get_model(model_name, render_option)
        global_states[0] = model
        global_states[1] = res
        global_states[2] = imgs
    model, res, imgs = global_states
    for idx, seed in enumerate([seed1, seed2]):
        if isinstance(seed, str):
            seed = 0
        else:
            seed = int(seed)

        # Re-render this reference view only if the seed, model, or truncation
        # changed since the last call (its W code is cached in wss).
        if (seed != history.get(f'seed{idx}', -1)) or \
           (model_name != history.get("model_name", None)) or \
           (trunc != history.get("trunc", 0.7)) or \
           (wss[idx] is None):
            print(f'use seed {seed}')
            set_random_seed(seed)
            z = torch.from_numpy(np.random.RandomState(int(seed)).randn(1, model.z_dim).astype('float32')).to(device)
            ws = model.mapping(z=z, c=None, truncation_psi=trunc)
            img = model.get_final_output(styles=ws, camera_matrices=get_camera_traj(model, 0, 0), render_option=render_option)
            ws = ws.detach().cpu().numpy()
            img = img[0].permute(1, 2, 0).detach().cpu().numpy()
            imgs[idx * res // 2: (1 + idx) * res // 2] = cv2.resize(
                np.asarray(img).clip(-1, 1) * 0.5 + 0.5,
                (res // 2, res // 2), interpolation=cv2.INTER_AREA)
            wss[idx] = ws
        else:
            seed = history[f'seed{idx}']
        seeds += [seed]
        history[f'seed{idx}'] = seed

    history['trunc'] = trunc
    history['model_name'] = model_name
    set_random_seed(sum(seeds))
    # Style mixing: blend the two cached W codes. Matching the slider labels,
    # the first 8 layers act on coarse structure (geometry), the rest on
    # appearance.
    ws1, ws2 = [torch.from_numpy(ws).to(device) for ws in wss]
    ws = ws1.clone()
    ws[:, :8] = ws1[:, :8] * mix1 + ws2[:, :8] * (1 - mix1)
    ws[:, 8:] = ws1[:, 8:] * mix2 + ws2[:, 8:] * (1 - mix2)

    # Switch the renderer to intermediate outputs when requested.
    if early == 'Normal Map':
        render_option += ',normal,early'
    elif early == 'Gradient Map':
        render_option += ',gradient,early'
    start_t = time.time()
    with torch.no_grad():
        cam = get_camera_traj(model, pitch, yaw, fov, model_name=model_name)
        image = model.get_final_output(
            styles=ws, camera_matrices=cam,
            theta=roll * np.pi,
            render_option=render_option)
    end_t = time.time()

    image = image[0].permute(1, 2, 0).detach().cpu().numpy().clip(-1, 1) * 0.5 + 0.5

    # Place the two reference views (stacked vertically) to the left of the
    # rendered view, resizing the reference strip if the heights differ.
    if imgs.shape[0] == image.shape[0]:
        image = np.concatenate([imgs, image], 1)
    else:
        a = image.shape[0]
        b = int(imgs.shape[1] / imgs.shape[0] * a)
        print(f'resize {a} {b} {image.shape} {imgs.shape}')
        image = np.concatenate([cv2.resize(imgs, (b, a), interpolation=cv2.INTER_AREA), image], 1)

    print(f'rendering time = {end_t - start_t:.4f}s')
    image = (image * 255).astype('uint8')
    return image, history
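# Headless sketch of a single synthesis call for local debugging (commented
# out; the argument order mirrors the Interface inputs declared below):
# img, hist = f_synthesis('FFHQ512', '', 'steps:40', 'None', 0.7,
#                         1, 9, 0, 0, 0.0, 0.0, 0.0, 12, None)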
model_name = gr.inputs.Dropdown(['FFHQ512'])  # 'FFHQ512v2', 'AFHQ512', 'MetFaces512', 'CompCars256', 'FFHQ1024'
model_find = gr.inputs.Textbox(label="checkpoint path", default="")
render_option = gr.inputs.Textbox(label="rendering options", default='steps:40')
trunc = gr.inputs.Slider(default=0.7, maximum=1.0, minimum=0.0, label='truncation trick')
seed1 = gr.inputs.Number(default=1, label="seed1")
seed2 = gr.inputs.Number(default=9, label="seed2")
mix1 = gr.inputs.Slider(minimum=0, maximum=1, default=0, label="linear mixing ratio (geometry)")
mix2 = gr.inputs.Slider(minimum=0, maximum=1, default=0, label="linear mixing ratio (appearance)")
early = gr.inputs.Radio(['None', 'Normal Map', 'Gradient Map'], default='None', label='intermediate output')
yaw = gr.inputs.Slider(minimum=-1, maximum=1, default=0, label="yaw")
pitch = gr.inputs.Slider(minimum=-1, maximum=1, default=0, label="pitch")
roll = gr.inputs.Slider(minimum=-1, maximum=1, default=0, label="roll (optional, not recommended)")
fov = gr.inputs.Slider(minimum=9, maximum=15, default=12, label="fov")
css = ".output_image {height: 40rem !important; width: 100% !important;}"
gr.Interface(fn=f_synthesis,
             inputs=[model_name, model_find, render_option, early, trunc, seed1, seed2, mix1, mix2, yaw, pitch, roll, fov, "state"],
             title="Interactive Web Demo for StyleNeRF (ICLR 2022)",
             outputs=["image", "state"],
             layout='unaligned',
             css=css, theme='dark-huggingface',
             live=True).launch(enable_queue=True)
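# Note: `port` is parsed near the top of this file but not used by the hosted
# demo. For a local run one could presumably pass it through to gradio, e.g.:
# .launch(enable_queue=True, server_port=port)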