SDF-StyleGan-3D / app.py
SerdarHelli's picture
Update app.py
681d682
raw
history blame contribute delete
No virus
4.65 kB
import os
import gradio as gr
import plotly.graph_objects as go
import sys
import torch
from huggingface_hub import hf_hub_download
import numpy as np
import random
# Fetch the official SDF-StyleGAN code and make its packages importable.
# Codes reference : https://github.com/Zhengxinyang/SDF-StyleGAN
REPO_URL = "https://github.com/Zhengxinyang/SDF-StyleGAN.git"
REPO_DIR = "SDF-StyleGAN"
if not os.path.isdir(REPO_DIR):
    # Clone only once: `git clone` into an existing directory fails, and a
    # failed clone would otherwise only surface later as an ImportError.
    if os.system(f"git clone {REPO_URL} {REPO_DIR}") != 0:
        raise RuntimeError(f"Failed to clone {REPO_URL} into {REPO_DIR!r}")
sys.path.append(REPO_DIR)
from utils.utils import evaluate_in_chunks, scale_to_unit_sphere
from network.model import StyleGAN2_3D
def noise(batch_size, latent_dim, device):
    """Draw a ``(batch_size, latent_dim)`` batch of standard-normal latent codes on *device*."""
    shape = (batch_size, latent_dim)
    return torch.randn(shape, device=device)
def noise_list(batch_size, layers, latent_dim, device):
    """Return a one-entry style list: a single latent batch shared by *layers* layers.

    The result is ``[(latents, layers)]``, where ``latents`` comes from ``noise``.
    """
    latents = noise(batch_size, latent_dim, device)
    return [(latents, layers)]
def volume_noise(n, vol_size, device):
    """Sample uniform [0, 1) noise of shape ``(n, vol_size, vol_size, vol_size, 1)``.

    The values are always drawn with the CPU generator and then moved to
    *device*, so a given torch seed yields the same noise on any device.

    Args:
        n: number of noise volumes.
        vol_size: edge length of each cubic volume.
        device: any target accepted by ``Tensor.to`` ("cpu", "cuda", "cuda:1",
            or a ``torch.device``). The previous version only moved the tensor
            for the exact string "cuda" and silently kept it on CPU otherwise.

    Returns:
        A float32 tensor on *device*.
    """
    shape = (n, vol_size, vol_size, vol_size, 1)
    return torch.empty(shape, dtype=torch.float32).uniform_(0., 1.).to(device)
class StyleGAN2_3D_not_cuda(StyleGAN2_3D):
    """StyleGAN2_3D variant whose feature-volume sampling follows ``self.device``.

    Only ``generate_feature_volume`` is overridden. Latents and volume noise
    come from this module's ``noise_list`` / ``volume_noise`` helpers and are
    placed on ``self.device``. This app loads this subclass when no CUDA device
    is available.
    """

    @torch.no_grad()
    def generate_feature_volume(self, ema=False, trunc_psi=0.75):
        """Sample one truncated feature volume without tracking gradients.

        Args:
            ema: when truthy, use the ``SE``/``GE`` networks; otherwise ``S``/``G``.
            trunc_psi: truncation strength forwarded to ``generate_truncated``.

        Returns:
            Whatever ``generate_truncated`` returns for a batch of one.
        """
        latents = noise_list(1, self.num_layers, self.latent_dim, device=self.device)
        vol_noise = volume_noise(1, self.G_vol_size, device=self.device)
        style_net, gen_net = (self.SE, self.GE) if ema else (self.S, self.G)
        return self.generate_truncated(style_net, gen_net, latents, vol_noise, trunc_psi)
# Compute device for every model in this app.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Default model: the car checkpoint is downloaded from the Hugging Face Hub.
# It is also the fallback when an unknown model name is requested.
cars = hf_hub_download("SerdarHelli/SDF-StyleGAN-3D", filename="cars.ckpt", revision="main")

# Dropdown label -> checkpoint path. The non-car checkpoints are relative to
# the working directory (presumably shipped alongside this app — verify).
models = dict(
    Car=cars,
    Airplane="./planes.ckpt",
    Chair="./chairs.ckpt",
    Rifle="./rifles.ckpt",
    Table="./tables.ckpt",
)
def seed_all(seed):
    """Seed torch, NumPy and Python's ``random`` module with the same *seed*."""
    for seeder in (torch.manual_seed, np.random.seed, random.seed):
        seeder(seed)
def predict(seed, model, trunc_psi):
    """Sample one mesh from *model* and return it as Plotly ``Mesh3d`` arrays.

    Args:
        seed: RNG seed. ``None`` falls back to 777. It is coerced to ``int``
            because ``np.random.seed`` rejects float seeds, such as those a
            Gradio slider may deliver.
        model: a loaded StyleGAN2_3D model. Its ``av`` attribute is overwritten.
        trunc_psi: truncation strength passed to ``generate_mesh``. ``None``
            falls back to 1.

    Returns:
        ``(x, y, z, i, j, k)``: per-vertex coordinates and per-face vertex
        indices of the unit-sphere-scaled mesh.
    """
    if seed is None:
        seed = 777
    seed_all(int(seed))
    if trunc_psi is None:
        trunc_psi = 1

    # Estimate the mean style vector (model.av) from 100k random latents. It is
    # only an average, so no autograd graph is needed; without no_grad the graph
    # for all chunks would be kept alive for nothing.
    with torch.no_grad():
        latents = noise(100000, model.latent_dim, device=model.device)
        samples = evaluate_in_chunks(1000, model.SE, latents)
        model.av = torch.mean(samples, dim=0, keepdim=True)

    # mc_vol_size / level are presumably the marching-cubes grid size and SDF
    # iso-level used by generate_mesh — confirm against the upstream model.
    mesh = model.generate_mesh(
        ema=True, mc_vol_size=64, level=-0.015, trunc_psi=trunc_psi)
    mesh = scale_to_unit_sphere(mesh)

    # Transpose (N, 3) arrays into three per-axis rows.
    x, y, z = np.asarray(mesh.vertices).T
    i, j, k = np.asarray(mesh.faces).T
    return x, y, z, i, j, k
def _mesh_figure(x, y, z, i, j, k):
    """Wrap triangle-mesh arrays in a flat-shaded Plotly ``Mesh3d`` figure."""
    return go.Figure(go.Mesh3d(
        x=x, y=y, z=z,
        i=i, j=j, k=k,
        colorscale="Viridis",
        colorbar_len=0.75,
        flatshading=True,
        lighting=dict(ambient=0.5,
                      diffuse=1,
                      fresnel=4,
                      specular=0.5,
                      roughness=0.05,
                      facenormalsepsilon=0,
                      vertexnormalsepsilon=0),
        lightposition=dict(x=100,
                           y=100,
                           z=1000)))


def generate(seed, model_name, trunc_psi):
    """Gradio callback: load the chosen checkpoint, sample a mesh, return a figure.

    Args:
        seed: RNG seed forwarded to ``predict``.
        model_name: a key of ``models``. Unknown or unset names (e.g. an empty
            dropdown) fall back to the car checkpoint.
        trunc_psi: truncation strength forwarded to ``predict``.

    Returns:
        A Plotly ``Figure`` containing the generated mesh.
    """
    print(model_name)
    ckpt = models.get(model_name, cars)
    if device == "cuda":
        model = StyleGAN2_3D.load_from_checkpoint(ckpt).cuda(0)
    else:
        model = StyleGAN2_3D_not_cuda.load_from_checkpoint(ckpt)
    model.eval()
    return _mesh_figure(*predict(seed, model, trunc_psi))
# Header text rendered at the top of the Gradio page by gr.Markdown.
# It is an f-string so it can report the compute device selected at import time.
markdown=f'''
# SDF-StyleGAN: Implicit SDF-Based StyleGAN for 3D Shape Generation
[The space demo for the SGP 2022 paper "SDF-StyleGAN: Implicit SDF-Based StyleGAN for 3D Shape Generation".](https://arxiv.org/abs/2206.12055)
[For the official implementation.](https://github.com/Zhengxinyang/SDF-StyleGAN)
### Future Work based on interest
- Adding new models for new type objects
- New Customization
It is running on {device}
'''
# Gradio UI: inputs on one row and the rendered mesh below. A mesh is
# generated once on page load and again on every button click.
with gr.Blocks() as demo:
    with gr.Column():
        with gr.Row():
            gr.Markdown(markdown)
        with gr.Row():
            # Seeds are integers; step=1 keeps the slider on whole numbers.
            seed = gr.Slider(minimum=0, maximum=2**16, step=1, label='Seed')
            # generate() falls back to the car checkpoint for an unset name,
            # so preselect "Car" to make the UI match what is rendered.
            # Choices come from `models` so the two lists cannot drift apart.
            model_name = gr.Dropdown(choices=list(models), value="Car",
                                     label="Choose Model Type")
            # value=1 matches predict()'s fallback for trunc_psi=None. Without
            # an explicit value the slider presumably starts at its minimum (0),
            # which under the usual truncation trick maps every seed to the
            # mean shape — verify against generate_truncated.
            trunc_psi = gr.Slider(minimum=0, maximum=2, value=1, label='Truncate PSI')
            btn = gr.Button(value="Generate")
        mesh = gr.Plot()
    demo.load(generate, [seed, model_name, trunc_psi], mesh)
    btn.click(generate, [seed, model_name, trunc_psi], mesh)
demo.launch(debug=True)