# SDF-StyleGan-3D / app.py
# Author: SerdarHelli · commit 47dddc9 ("Update app.py") · 3.79 kB
import os
import random
import subprocess
import sys

import gradio as gr
import numpy as np
import plotly.graph_objects as go
import torch
from huggingface_hub import hf_hub_download
# Make the upstream SDF-StyleGAN code importable (it provides the `utils` and
# `network` modules imported below). The previous `os.system(<url>)` executed
# the URL itself as a shell command, so nothing was ever cloned. Clone only when
# no checkout exists yet; check=True surfaces a failed clone here rather than
# as a confusing ImportError a few lines later.
_REPO_DIR = "SDF-StyleGAN"
if not os.path.isdir(_REPO_DIR):
    subprocess.run(
        ["git", "clone", "https://github.com/Zhengxinyang/SDF-StyleGAN.git", _REPO_DIR],
        check=True,
    )
sys.path.append(_REPO_DIR)
# Code reference: https://github.com/Zhengxinyang/SDF-StyleGAN
from utils.utils import noise, evaluate_in_chunks, scale_to_unit_sphere, volume_noise, process_sdf, linear_slerp
from network.model import StyleGAN2_3D
# Checkpoint for the default "Car" generator, fetched from the Hugging Face Hub.
cars = hf_hub_download(
    "SerdarHelli/SDF-StyleGAN-3D",
    filename="cars.ckpt",
    revision="main",
)

# Load the default (car) model, moving it to GPU 0 when CUDA is available,
# and switch it to inference mode.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = StyleGAN2_3D.load_from_checkpoint(cars)
if device == "cuda":
    model = model.cuda(0)
model.eval()
# Model-type label -> checkpoint path. Only the car checkpoint is downloaded
# above; the others are relative local paths.
# NOTE(review): confirm the plane/chair/rifle/table checkpoints ship with the app.
models = dict(
    Car=cars,
    Airplane="./planes.ckpt",
    Chair="./chairs.ckpt",
    Rifle="./rifles.ckpt",
    Table="./tables.ckpt",
)
def seed_all(seed):
    """Seed PyTorch, NumPy and Python's ``random`` with the same value.

    Args:
        seed: value passed unchanged to each library's seeding function.
    """
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(seed)
# Checkpoint path of the generator currently held in the global `model`
# (the module loads the car checkpoint at startup).
_loaded_ckpt_path = cars


def change_model(ckpt_path):
    """Make the global ``model`` the generator stored at *ckpt_path*.

    The previous version ignored *ckpt_path* (always loading ``cars``) and
    bound the result to a local variable, so the selected model never took
    effect. This version replaces the module-level ``model`` and skips the
    reload when *ckpt_path* is already the active checkpoint.

    Args:
        ckpt_path: path of a ``StyleGAN2_3D`` checkpoint file.

    Raises:
        Whatever ``StyleGAN2_3D.load_from_checkpoint`` raises (e.g. for a
        missing file); the current model is left in place in that case.
    """
    global model, _loaded_ckpt_path
    if ckpt_path == _loaded_ckpt_path:
        return
    new_model = StyleGAN2_3D.load_from_checkpoint(ckpt_path)
    if device == "cuda":
        new_model = new_model.cuda(0)
    new_model.eval()
    # Swap only after a successful load so a failure keeps the old model usable.
    model = new_model
    _loaded_ckpt_path = ckpt_path
def predict(seed, trunc_psi, export_path=None):
    """Generate one mesh with the current global ``model``.

    Args:
        seed: RNG seed; ``None`` falls back to 777. It is cast to ``int``
            because ``np.random.seed`` rejects floats (UI sliders may deliver
            numeric values as floats).
        trunc_psi: truncation value forwarded to ``model.generate_mesh``;
            ``None`` falls back to 1.
        export_path: optional file path. When given, the scaled mesh is also
            written there with ``mesh.export``. (Previously the mesh was always
            written to the Colab-only path ``/content/asdads.obj``, which fails
            elsewhere and was never read back.)

    Returns:
        ``(x, y, z, i, j, k)``: the vertex x/y/z coordinate arrays followed by
        the three face-index arrays, i.e. the inputs of a plotly ``Mesh3d``.
    """
    if seed is None:
        seed = 777
    seed_all(int(seed))
    if trunc_psi is None:
        trunc_psi = 1

    # Average style vector over 100k random latents, stored on the model for
    # generate_mesh. It is only used as a constant, so skip autograd tracking.
    with torch.no_grad():
        latents = noise(100000, model.latent_dim, device=model.device)
        samples = evaluate_in_chunks(1000, model.SE, latents)
        model.av = torch.mean(samples, dim=0, keepdim=True)

    mesh = model.generate_mesh(
        ema=True, mc_vol_size=64, level=-0.015, trunc_psi=trunc_psi)
    mesh = scale_to_unit_sphere(mesh)
    if export_path is not None:
        mesh.export(export_path)

    x, y, z = np.asarray(mesh.vertices).T
    i, j, k = np.asarray(mesh.faces).T
    return x, y, z, i, j, k
def generate(seed, model_name, trunc_psi):
    """UI callback: switch to the requested model type, sample a mesh, plot it.

    Args:
        seed: RNG seed forwarded to ``predict``.
        model_name: key of ``models``. An empty value (e.g. an unset dropdown
            on the initial page-load event) falls back to ``"Car"`` instead of
            raising ``KeyError``.
        trunc_psi: truncation value forwarded to ``predict``.

    Returns:
        A plotly ``Figure`` holding a single flat-shaded ``Mesh3d`` trace.
    """
    if not model_name:
        model_name = "Car"
    change_model(models[model_name])
    x, y, z, i, j, k = predict(seed, trunc_psi)

    lighting = dict(
        ambient=0.5,
        diffuse=1,
        fresnel=4,
        specular=0.5,
        roughness=0.05,
        facenormalsepsilon=0,
        vertexnormalsepsilon=0,
    )
    trace = go.Mesh3d(
        x=x, y=y, z=z,
        i=i, j=j, k=k,
        colorscale="Viridis",
        colorbar_len=0.75,
        flatshading=True,
        lighting=lighting,
        lightposition=dict(x=100, y=100, z=1000),
    )
    return go.Figure(trace)
# Markdown header shown at the top of the page; `device` is interpolated so the
# page reports whether inference runs on "cuda" or "cpu". The blank line before
# the last sentence keeps Markdown from folding it into the preceding bullet.
markdown = f'''
# SDF-StyleGAN: Implicit SDF-Based StyleGAN for 3D Shape Generation
[The space demo for the SGP 2022 paper "SDF-StyleGAN: Implicit SDF-Based StyleGAN for 3D Shape Generation".](https://arxiv.org/abs/2206.12055)
[For the official implementation.](https://github.com/Zhengxinyang/SDF-StyleGAN)
### Future Work based on interest
- Adding new models for new object types
- New customization options

It is running on {device}
'''
with gr.Blocks() as demo:
    with gr.Column():
        with gr.Row():
            gr.Markdown(markdown)
        with gr.Row():
            # step=1 keeps seeds integral; np.random.seed rejects floats.
            seed = gr.Slider(minimum=0, maximum=2**16, step=1, label='Seed')
            # A default value is required: demo.load fires generate() before the
            # user touches the dropdown, and an unset dropdown passes None.
            model_name = gr.Dropdown(
                choices=list(models),
                value="Car",
                label="Choose Model Type",
            )
            # Default of 1 matches predict()'s fallback, rather than the
            # slider minimum 0 the UI previously started at.
            trunc_psi = gr.Slider(minimum=0, maximum=2, value=1, label='Truncate PSI')
            btn = gr.Button(value="Generate")
        mesh = gr.Plot()
    # Render one shape on page load and another on every button click.
    demo.load(generate, [seed, model_name, trunc_psi], mesh)
    btn.click(generate, [seed, model_name, trunc_psi], mesh)
demo.launch(debug=True)