SerdarHelli's picture
Update app.py
562e1b7
raw history blame
No virus
3.83 kB
import os
import random
import subprocess
import sys

import gradio as gr
import numpy as np
import plotly.graph_objects as go
import torch
from huggingface_hub import hf_hub_download
# Make the upstream implementation importable (models.vae_gaussian / models.vae_flow
# are imported from it below). Clone only when the checkout is missing so restarts
# don't re-run git, and fail loudly here rather than later with a confusing
# ModuleNotFoundError if the clone does not succeed.
_REPO_DIR = "diffusion-point-cloud"
if not os.path.isdir(_REPO_DIR):
    subprocess.run(
        ["git", "clone", "https://github.com/luost26/diffusion-point-cloud", _REPO_DIR],
        check=True,
    )
sys.path.append(_REPO_DIR)
# Code reference: https://github.com/luost26/diffusion-point-cloud
from models.vae_gaussian import *
from models.vae_flow import *
# Checkpoint locations: the airplane model is downloaded from the Hugging Face Hub,
# the chair model is read from a local file path.
airplane = hf_hub_download(
    "SerdarHelli/diffusion-point-cloud",
    filename="GEN_airplane.pt",
    revision="main",
)
chair = "./GEN_chair.pt"

# Use the GPU when PyTorch can see one, otherwise the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Each checkpoint is later read as ckpt['args'] and ckpt['state_dict'].
# NOTE(review): torch.load unpickles arbitrary objects, so only load trusted files.
# On newer torch releases torch.load may default to weights_only=True, which can
# reject the non-tensor ckpt['args'] object. Confirm the pinned torch version.
ckpt_airplane, ckpt_chair = (
    torch.load(path, map_location=torch.device(device))
    for path in (airplane, chair)
)
def seed_all(seed):
    """Seed PyTorch, NumPy and Python's ``random`` module with the same value.

    After this call, draws from those three global generators repeat for a
    repeated seed.
    """
    for set_seed in (torch.manual_seed, np.random.seed, random.seed):
        set_seed(seed)
def normalize_point_clouds(pcs, mode):
    """Center and rescale every point cloud in a batch, in place.

    Args:
        pcs: tensor of shape ``(batch, num_points, 3)``. It is overwritten with
            the normalized values.
        mode: one of the following.
            ``None``: return ``pcs`` unchanged.
            ``'shape_unit'``: subtract each cloud's mean point and divide by
            the standard deviation of all of that cloud's coordinates.
            ``'shape_bbox'``: subtract the center of each cloud's axis-aligned
            bounding box and divide by half its longest side, so the longest
            side spans ``[-1, 1]``.

    Returns:
        The same tensor object ``pcs``.

    Raises:
        ValueError: if ``mode`` is not ``None``, ``'shape_unit'`` or ``'shape_bbox'``.

    Note:
        A degenerate cloud (all points identical) has scale 0 and becomes
        inf/NaN. This case is not guarded.
    """
    if mode is None:
        return pcs
    if mode == 'shape_unit':
        shift = pcs.mean(dim=1, keepdim=True)  # (B, 1, 3)
        # std over all coordinates of each cloud, like pc.flatten().std() per cloud.
        scale = pcs.flatten(start_dim=1).std(dim=1).reshape(-1, 1, 1)  # (B, 1, 1)
    elif mode == 'shape_bbox':
        pc_max, _ = pcs.max(dim=1, keepdim=True)  # (B, 1, 3)
        pc_min, _ = pcs.min(dim=1, keepdim=True)  # (B, 1, 3)
        shift = (pc_min + pc_max) / 2
        longest, _ = (pc_max - pc_min).max(dim=2, keepdim=True)  # (B, 1, 1)
        scale = longest / 2
    else:
        # Without this, shift/scale would be unbound and the loop would crash
        # with an UnboundLocalError.
        raise ValueError(
            f"Unknown normalization mode {mode!r}; "
            "expected None, 'shape_unit' or 'shape_bbox'"
        )
    # Write back into the caller's tensor, matching the original per-row assignment.
    pcs.copy_((pcs - shift) / scale)
    return pcs
def predict(Seed, ckpt, num_points=2048):
    """Sample a single point cloud from a checkpointed diffusion-point-cloud model.

    Args:
        Seed: RNG seed. ``None`` selects the default seed 777. The value is
            coerced to ``int`` so float values coming from UI widgets are
            accepted by every seeder.
        ckpt: checkpoint dict with ``'args'`` and ``'state_dict'``.
            ``ckpt['args'].model`` must be ``'gaussian'`` or ``'flow'``, and the
            args must provide ``latent_dim`` and ``flexibility``.
        num_points: number of points requested from ``model.sample``
            (default 2048).

    Returns:
        CPU tensor whose last dimension is 3, normalized with ``shape_bbox``
        (centered on its bounding box, longest side spanning ``[-1, 1]``).
        The number of rows is presumably ``num_points``; this depends on
        ``model.sample``.

    Raises:
        ValueError: if ``ckpt['args'].model`` is not a supported model type.
    """
    if Seed is None:
        Seed = 777
    args = ckpt['args']
    model_classes = {'gaussian': GaussianVAE, 'flow': FlowVAE}
    if args.model not in model_classes:
        # Previously this fell through and crashed on an unbound `model`.
        raise ValueError(
            f"Unsupported model type {args.model!r}; expected 'gaussian' or 'flow'"
        )

    # Keep the order seed -> build model -> sample z. Building the model may draw
    # from the torch RNG (parameter init), so reordering would change which cloud
    # a given seed produces.
    seed_all(int(Seed))
    model = model_classes[args.model](args).to(device)
    model.load_state_dict(ckpt['state_dict'])

    with torch.no_grad():
        z = torch.randn([1, args.latent_dim]).to(device)
        x = model.sample(z, num_points, flexibility=args.flexibility)
    gen_pcs = x.detach().cpu()[:1]
    gen_pcs = normalize_point_clouds(gen_pcs, mode="shape_bbox")
    return gen_pcs[0]
def generate(seed, value):
    """Sample a point cloud for the selected object type and plot it.

    Args:
        seed: RNG seed forwarded to ``predict`` (``None`` means predict's default).
        value: model choice, ``"Airplane"`` or ``"Chair"``. Any other value
            (e.g. ``None`` from an empty dropdown) falls back to the airplane model.

    Returns:
        A ``plotly.graph_objects.Figure`` with one 3D marker trace and hidden axes.
    """
    ckpt = {"Airplane": ckpt_airplane, "Chair": ckpt_chair}.get(value, ckpt_airplane)
    # Plain numpy avoids relying on Plotly's handling of torch tensors.
    points = predict(seed, ckpt).numpy()
    fig = go.Figure(
        data=[
            go.Scatter3d(
                x=points[:, 0],
                y=points[:, 1],
                z=points[:, 2],
                mode='markers',
                # A numeric tuple would be read by Plotly as per-point colorscale
                # values (coloring only 3 points). A CSS color string applies one
                # color to every marker.
                marker=dict(size=1, color='rgb(238, 75, 43)'),
            )
        ],
        layout=dict(
            scene=dict(
                xaxis=dict(visible=False),
                yaxis=dict(visible=False),
                zaxis=dict(visible=False),
            )
        ),
    )
    return fig
# Markdown header rendered at the top of the demo. The f-string interpolates
# `device`, so the page shows whether the app is running on 'cuda' or 'cpu'.
markdown=f'''
# Diffusion Probabilistic Models for 3D Point Cloud Generation
[The space demo for the CVPR 2021 paper "Diffusion Probabilistic Models for 3D Point Cloud Generation".](https://arxiv.org/abs/2103.01458)
[For the official implementation.](https://github.com/luost26/diffusion-point-cloud)
### Future Work based on interest
- Adding new models for new type objects
- New Customization
It is running on {device}
'''
# Gradio UI: a seed slider and a model dropdown feed `generate`, whose Plotly
# figure is shown in `point_cloud`.
with gr.Blocks() as demo:
    with gr.Column():
        with gr.Row():
            gr.Markdown(markdown)
        with gr.Row():
            # Explicit step=1 so every integer seed is selectable. Without it,
            # Gradio may derive a coarser step from the slider's range.
            seed = gr.Slider(minimum=0, maximum=2**16, step=1, label='Seed')
            # Preselect a model so the dropdown never sends None.
            value = gr.Dropdown(
                choices=["Airplane", "Chair"],
                value="Airplane",
                label="Choose Model Type",
            )
            # TODO: expose a 'Truncate Std' slider (range 1-2) once predict() accepts it.
        btn = gr.Button(value="Generate")
        point_cloud = gr.Plot()
    # Render an initial cloud when the page loads and a new one on every click.
    demo.load(generate, [seed, value], point_cloud)
    btn.click(generate, [seed, value], point_cloud)

if __name__ == "__main__":
    demo.launch()