|
import os |
|
import gradio as gr |
|
import plotly.graph_objects as go |
|
import sys |
|
import torch |
|
from huggingface_hub import hf_hub_download |
|
|
|
# The model implementation is not an installable package, so the upstream
# repository is cloned next to this script and put on sys.path.
# Cloning into an existing directory only makes git fail, so the clone is
# skipped when the checkout is already present (e.g. after a restart).
if not os.path.isdir("diffusion-point-cloud"):
    clone_status = os.system("git clone https://github.com/luost26/diffusion-point-cloud")
    if clone_status != 0:
        # Fail here with a clear message instead of at the `models.*` imports
        # further down, where the cause would be much harder to spot.
        raise RuntimeError(
            "git clone of https://github.com/luost26/diffusion-point-cloud "
            f"failed with exit status {clone_status}"
        )

sys.path.append("diffusion-point-cloud")
|
|
|
|
|
from models.vae_gaussian import * |
|
from models.vae_flow import * |
|
|
|
# Fetch the pretrained generator checkpoints from the Hugging Face Hub; the
# returned values are passed to torch.load below as file locations.
airplane = hf_hub_download(
    "SerdarHelli/diffusion-point-cloud",
    filename="GEN_airplane.pt",
    revision="main",
)
chair = hf_hub_download(
    "SerdarHelli/diffusion-point-cloud",
    filename="GEN_chair.pt",
    revision="main",
)

# map_location="cpu" makes loading independent of the device the checkpoints
# were saved from, so the app also starts on hosts without a GPU.
# NOTE(review): torch.load unpickles the file, so only checkpoints from a
# trusted repository should ever be loaded here.
ckpt_airplane = torch.load(airplane, map_location="cpu")
ckpt_chair = torch.load(chair, map_location="cpu")
|
|
|
def normalize_point_clouds(pcs, mode):
    """Center and rescale every point cloud of a batch, in place.

    Args:
        pcs: Tensor whose first dimension indexes the clouds. Each cloud
            ``pcs[i]`` is reduced along its own first dimension and the
            result is reshaped to ``(1, 3)``, so a cloud is expected to be
            laid out as ``(num_points, 3)``.
        mode: ``None`` to skip normalization, ``'shape_unit'`` to subtract the
            per-axis mean and divide by the standard deviation of all
            coordinates, or ``'shape_bbox'`` to center on the bounding box
            and divide by half of its longest edge.

    Returns:
        The same tensor object ``pcs``; its contents are overwritten unless
        ``mode`` is ``None``.

    Raises:
        ValueError: If ``mode`` is not one of the supported values.
    """
    if mode is None:
        return pcs
    if mode not in ('shape_unit', 'shape_bbox'):
        # Previously an unknown mode surfaced as an UnboundLocalError on
        # `shift`; reject it explicitly before touching the data.
        raise ValueError(
            "Unknown normalization mode %r; expected None, 'shape_unit' "
            "or 'shape_bbox'" % (mode,)
        )

    for i in range(pcs.size(0)):
        pc = pcs[i]
        if mode == 'shape_unit':
            shift = pc.mean(dim=0).reshape(1, 3)
            scale = pc.flatten().std().reshape(1, 1)
        else:  # 'shape_bbox'
            pc_max, _ = pc.max(dim=0, keepdim=True)
            pc_min, _ = pc.min(dim=0, keepdim=True)
            shift = ((pc_min + pc_max) / 2).view(1, 3)
            # A single scalar scale for all axes keeps the aspect ratio.
            scale = (pc_max - pc_min).max().reshape(1, 1) / 2
        pcs[i] = (pc - shift) / scale
    return pcs
|
|
|
def predict(Seed, ckpt):
    """Sample one point cloud from the generator stored in ``ckpt``.

    Args:
        Seed: Seed handed to ``seed_all``; ``None`` falls back to 777. The
            value is converted to ``int`` first because UI widgets may
            deliver the number as a float.
        ckpt: Loaded checkpoint mapping with an ``'args'`` entry (exposing
            ``model``, ``latent_dim`` and ``flexibility``) and a
            ``'state_dict'`` entry with the model weights.

    Returns:
        The generated cloud as a CPU tensor, normalized with the
        ``"shape_bbox"`` mode (presumably of shape ``(2048, 3)`` -- the
        exact layout is decided by ``model.sample``).

    Raises:
        ValueError: If ``ckpt['args'].model`` is neither ``'gaussian'`` nor
            ``'flow'``.
    """
    if Seed is None:
        Seed = 777
    seed_all(int(Seed))

    # Use the GPU when there is one, but keep the demo usable on CPU-only hosts.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_kind = ckpt['args'].model
    if model_kind == 'gaussian':
        model = GaussianVAE(ckpt['args']).to(device)
    elif model_kind == 'flow':
        model = FlowVAE(ckpt['args']).to(device)
    else:
        # Previously this fell through and failed later with an unbound `model`.
        raise ValueError(
            "Unsupported model type %r; expected 'gaussian' or 'flow'" % (model_kind,)
        )

    model.load_state_dict(ckpt['state_dict'])

    with torch.no_grad():
        # The latent is drawn on the CPU and then moved, exactly as before, so
        # a given seed keeps producing the same latent code.
        z = torch.randn([1, ckpt['args'].latent_dim]).to(device)
        x = model.sample(z, 2048, flexibility=ckpt['args'].flexibility)

    gen_pcs = x.detach().cpu()[:1]
    gen_pcs = normalize_point_clouds(gen_pcs, mode="shape_bbox")

    return gen_pcs[0]
|
|
|
def generate(seed, value):
    """Generate a point cloud and wrap it in a 3D scatter figure.

    Args:
        seed: Random seed, forwarded unchanged to ``predict``.
        value: Selected model type. ``"Chair"`` picks the chair checkpoint;
            anything else (``"Airplane"``, ``None``, unknown strings) picks
            the airplane checkpoint.

    Returns:
        A ``plotly.graph_objects.Figure`` holding a single ``Scatter3d`` trace
        with all three axes hidden.
    """
    ckpt = ckpt_chair if value == "Chair" else ckpt_airplane

    print(value)

    # A bare tuple of numbers is read by Plotly as per-point colorscale
    # values, not as an RGB triple, so the colour is given as a CSS string.
    marker_color = "rgb(238, 75, 43)"
    points = predict(seed, ckpt)

    fig = go.Figure(
        data=[
            go.Scatter3d(
                x=points[:, 0], y=points[:, 1], z=points[:, 2],
                mode='markers',
                marker=dict(size=1, color=marker_color),
            )
        ],
        layout=dict(
            scene=dict(
                xaxis=dict(visible=False),
                yaxis=dict(visible=False),
                zaxis=dict(visible=False),
            )
        ),
    )
    return fig
|
# Header text shown at the top of the demo page. It is a plain string: there
# is nothing to interpolate, so no f-prefix is needed.
markdown = '''
# Diffusion Probabilistic Models for 3D Point Cloud Generation

[[The Paper](https://arxiv.org/abs/2103.01458)] [[Original Code](https://github.com/luost26/diffusion-point-cloud)]

The space demo for our CVPR 2021 paper "Diffusion Probabilistic Models for 3D Point Cloud Generation".


'''
|
# Page layout: header text, a row with the two inputs, the trigger button and
# the plot that receives the figure returned by `generate`.
with gr.Blocks() as demo:
    with gr.Column():
        with gr.Row():
            gr.Markdown(markdown)
        with gr.Row():
            # step=1 makes every whole-number seed in the range selectable.
            seed = gr.Slider(minimum=0, maximum=2**16, step=1, label='Seed')
            # Preselect "Airplane" so the dropdown matches the cloud that is
            # rendered on page load (`generate` uses the airplane checkpoint
            # whenever no chair is requested).
            value = gr.Dropdown(
                choices=["Airplane", "Chair"],
                value="Airplane",
                label="Choose Model Type",
            )

        btn = gr.Button(value="Generate")
        point_cloud = gr.Plot()

    # Render once when the page opens, then again on every button click.
    demo.load(generate, [seed, value], point_cloud)
    btn.click(generate, [seed, value], point_cloud)

demo.launch()