File size: 3,870 Bytes
0d0e451
 
 
 
 
 
7872317
 
0d0e451
 
 
 
 
 
 
 
2f68096
 
0d0e451
d5c0caa
0d0e451
7872317
 
 
 
 
 
 
0d0e451
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d5c0caa
0d0e451
d5c0caa
0d0e451
 
 
 
 
d5c0caa
0d0e451
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7872317
0d0e451
 
2f68096
 
0d0e451
7872317
2f68096
 
 
 
 
 
 
 
 
 
 
0d0e451
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import os
import gradio as gr
import plotly.graph_objects as go
import sys
import torch
from huggingface_hub import hf_hub_download
import numpy as np
import random

# Fetch the official implementation at startup so its `models` package can be imported.
# NOTE(review): os.system ignores the exit status; if the directory already exists
# (e.g. on a restart) the clone presumably fails and the existing checkout is reused.
os.system("git clone https://github.com/luost26/diffusion-point-cloud")
# Make the cloned repository importable.
sys.path.append("diffusion-point-cloud")


# These star imports must run after the sys.path change above; GaussianVAE and
# FlowVAE (used in predict) are expected to come from these modules.
from models.vae_gaussian import *
from models.vae_flow import *

# Pretrained generator checkpoints: the airplane model is downloaded from the
# Hugging Face Hub; the chair model is read from a path relative to the working directory.
airplane=hf_hub_download("SerdarHelli/diffusion-point-cloud", filename="GEN_airplane.pt",revision="main")
chair="./GEN_chair.pt"

# Prefer the GPU when one is available.
device='cuda' if torch.cuda.is_available() else 'cpu'

# Load both checkpoints once at import time, remapping their tensors onto `device`.
# Each is later indexed as ckpt['args'] / ckpt['state_dict'] by predict().
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints; newer
# torch releases default to weights_only=True, which may reject a pickled `args` object.
ckpt_airplane = torch.load(airplane,map_location=torch.device(device))
ckpt_chair = torch.load(chair,map_location=torch.device(device))

def seed_all(seed):
    """Seed the PyTorch, NumPy and Python ``random`` global RNGs with one value.

    Args:
        seed: Value passed unchanged to ``torch.manual_seed``,
            ``np.random.seed`` and ``random.seed`` (in that order).
    """
    for seeder in (torch.manual_seed, np.random.seed, random.seed):
        seeder(seed)

def normalize_point_clouds(pcs, mode):
    """Normalize every point cloud in a batch, in place, and return the batch.

    Args:
        pcs: Tensor whose first dimension indexes point clouds; each ``pcs[i]``
            is expected to be ``(num_points, 3)`` (the shift is reshaped to ``(1, 3)``).
        mode: ``None`` returns ``pcs`` untouched.
            ``'shape_unit'`` centers each cloud on its mean and divides by the
            standard deviation of all of its coordinates.
            ``'shape_bbox'`` centers each cloud on its bounding-box center and
            divides by half of its largest bounding-box side.

    Returns:
        The same tensor object ``pcs``, modified in place.

    Raises:
        ValueError: If ``mode`` is not ``None``, ``'shape_unit'`` or ``'shape_bbox'``.

    Note:
        A degenerate cloud (zero spread) divides by zero and yields non-finite values.
    """
    if mode is None:
        return pcs
    if mode not in ('shape_unit', 'shape_bbox'):
        # Without this check an unknown mode surfaced as an UnboundLocalError on `shift`.
        raise ValueError(f"Unknown normalization mode: {mode!r}")
    for i, pc in enumerate(pcs):
        if mode == 'shape_unit':
            shift = pc.mean(dim=0).reshape(1, 3)
            scale = pc.flatten().std().reshape(1, 1)
        else:  # 'shape_bbox'
            pc_max, _ = pc.max(dim=0, keepdim=True)  # (1, 3)
            pc_min, _ = pc.min(dim=0, keepdim=True)  # (1, 3)
            shift = ((pc_min + pc_max) / 2).view(1, 3)
            scale = (pc_max - pc_min).max().reshape(1, 1) / 2
        pcs[i] = (pc - shift) / scale
    return pcs

def predict(Seed, ckpt):
    """Sample one point cloud from a pretrained generator checkpoint.

    Args:
        Seed: RNG seed; ``None`` falls back to 777. Non-integer numbers (e.g.
            float slider values) are truncated with ``int()``.
        ckpt: Checkpoint dict with ``'args'`` (providing ``.model``,
            ``.latent_dim`` and ``.flexibility``) and ``'state_dict'``.

    Returns:
        A single bbox-normalized point cloud tensor on the CPU (presumably
        ``(2048, 3)``, since 2048 points are requested from ``model.sample``).

    Raises:
        ValueError: If ``ckpt['args'].model`` is neither ``'gaussian'`` nor ``'flow'``.
    """
    if Seed is None:
        Seed = 777
    # np.random.seed rejects float seeds, which a UI slider may deliver.
    seed_all(int(Seed))

    args = ckpt['args']
    if args.model == 'gaussian':
        model = GaussianVAE(args).to(device)
    elif args.model == 'flow':
        model = FlowVAE(args).to(device)
    else:
        # Previously this fell through and failed later with an UnboundLocalError.
        raise ValueError(f"Unsupported model type in checkpoint: {args.model!r}")
    # NOTE(review): the model is rebuilt on every call; caching per checkpoint would
    # avoid repeated construction and weight loading.
    model.load_state_dict(ckpt['state_dict'])

    with torch.no_grad():
        z = torch.randn([1, args.latent_dim]).to(device)
        x = model.sample(z, 2048, flexibility=args.flexibility)
    # Keep only the first (and only) sample in the batch, moved to the CPU.
    gen_pcs = x.detach().cpu()[:1]
    gen_pcs = normalize_point_clouds(gen_pcs, mode="shape_bbox")

    return gen_pcs[0]

def generate(seed,value):
    """Sample a point cloud for the chosen category and plot it in 3D.

    Args:
        seed: RNG seed forwarded to ``predict``.
        value: Category name; ``"Chair"`` selects the chair checkpoint, anything
            else (``"Airplane"``, or ``None`` before a choice is made) the airplane one.

    Returns:
        A ``plotly.graph_objects.Figure`` holding one marker-only 3D scatter
        trace, with all axes hidden.
    """
    ckpt = ckpt_chair if value == "Chair" else ckpt_airplane

    # A bare (r, g, b) tuple would be read by Plotly as three numeric per-point
    # values mapped through a colorscale, not as a colour; use a CSS colour string.
    color = 'rgb(238, 75, 43)'
    # Hand Plotly plain NumPy data rather than a torch tensor.
    points = predict(seed, ckpt).numpy()

    fig = go.Figure(
        data=[
            go.Scatter3d(
                x=points[:, 0], y=points[:, 1], z=points[:, 2],
                mode='markers',
                marker=dict(size=1, color=color)
            )
        ],
        layout=dict(
            scene=dict(
                xaxis=dict(visible=False),
                yaxis=dict(visible=False),
                zaxis=dict(visible=False)
            )
        )
    )
    return fig
    
# Page header rendered by gr.Markdown. It is an f-string (to show `device`), so the
# literal BibTeX braces are doubled; single braces would be parsed as replacement
# fields and make the module fail to compile. The BibTeX sits in a fenced code
# block so Markdown renders it verbatim instead of joining it into one paragraph.
markdown=f'''
# Diffusion Probabilistic Models for 3D Point Cloud Generation
[The space demo for the CVPR 2021 paper "Diffusion Probabilistic Models for 3D Point Cloud Generation".](https://arxiv.org/abs/2103.01458)
[For the official implementation.](https://github.com/luost26/diffusion-point-cloud)

It is running on {device}

### Citation By

```
@inproceedings{{luo2021diffusion,
  author = {{Luo, Shitong and Hu, Wei}},
  title = {{Diffusion Probabilistic Models for 3D Point Cloud Generation}},
  booktitle = {{Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}},
  month = {{June}},
  year = {{2021}}
}}
```

'''
# Gradio UI: description, seed and category inputs, a Generate button and a Plotly output.
with gr.Blocks() as demo:
    with gr.Column():
        with gr.Row():
            gr.Markdown(markdown)
        with gr.Row():
            # Explicit step=1 so every integer seed is selectable; some Gradio
            # versions derive a coarse step from the range when it is omitted.
            seed = gr.Slider(minimum=0, maximum=2**16, step=1, label='Seed')
            # Explicit default matching the airplane fallback used for an empty choice.
            value = gr.Dropdown(choices=["Airplane", "Chair"], value="Airplane", label="Choose Model Type")

        btn = gr.Button(value="Generate")
        point_cloud = gr.Plot()
    # Render a sample when the page loads, and again on every button click.
    demo.load(generate, [seed, value], point_cloud)
    btn.click(generate, [seed, value], point_cloud)

demo.launch()