File size: 3,825 Bytes
0d0e451
 
 
 
 
 
7872317
 
0d0e451
 
 
 
933cc55
0d0e451
 
 
 
2f68096
 
0d0e451
d5c0caa
0d0e451
7872317
 
 
 
 
 
 
0d0e451
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
933cc55
 
 
562e1b7
0d0e451
 
 
 
 
d5c0caa
0d0e451
d5c0caa
0d0e451
 
 
 
 
d5c0caa
562e1b7
0d0e451
 
 
 
 
 
562e1b7
0d0e451
 
 
 
 
 
 
 
562e1b7
0d0e451
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7872317
0d0e451
 
933cc55
 
2f68096
933cc55
2f68096
0d0e451
933cc55
 
 
 
 
 
7872317
2f68096
61d2ddd
0d0e451
 
 
 
 
 
 
 
562e1b7
0d0e451
 
 
562e1b7
 
0d0e451
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import os
import gradio as gr
import plotly.graph_objects as go
import sys
import torch
from huggingface_hub import hf_hub_download
import numpy as np
import random

# The model classes used below are imported from a local clone of the
# official implementation.
# Code reference: https://github.com/luost26/diffusion-point-cloud
_REPO_DIR = "diffusion-point-cloud"
if not os.path.isdir(_REPO_DIR):
    # Only clone when the checkout is missing: cloning into an existing
    # directory always fails and just costs a process spawn on every start.
    clone_status = os.system("git clone https://github.com/luost26/diffusion-point-cloud")
    if clone_status != 0:
        # Best effort: report the failure here so the import error that
        # follows (if any) is easy to diagnose, but do not abort.
        print(
            f"WARNING: cloning diffusion-point-cloud failed (exit status {clone_status})",
            file=sys.stderr,
        )
sys.path.append("diffusion-point-cloud")

from models.vae_gaussian import *
from models.vae_flow import *

# The airplane checkpoint is fetched from the Hugging Face Hub; the chair
# checkpoint is read from a file next to this script.
airplane=hf_hub_download("SerdarHelli/diffusion-point-cloud", filename="GEN_airplane.pt",revision="main")
chair="./GEN_chair.pt"

device='cuda' if torch.cuda.is_available() else 'cpu'

# NOTE(review): torch.load deserializes with pickle, so only checkpoints from
# trusted sources should be loaded here. The loaded objects are indexed with
# 'args' and 'state_dict' further down in this file.
ckpt_airplane = torch.load(airplane,map_location=torch.device(device))
ckpt_chair = torch.load(chair,map_location=torch.device(device))

def seed_all(seed):
    """Seed the PyTorch, NumPy and Python random number generators.

    Args:
        seed: Seed value. It is converted with ``int()`` before use so that
            whole-number floats (for example a value coming from a UI
            slider) are accepted; ``np.random.seed`` rejects float seeds.
    """
    seed = int(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

def normalize_point_clouds(pcs,mode):
    """Center and rescale every point cloud of a batch, in place.

    Args:
        pcs: Tensor indexed along dim 0 by point cloud. Each ``pcs[i]`` is
            combined with a ``(1, 3)`` shift, so it is expected to hold
            3-coordinate points.
        mode: Normalization scheme.
            ``None`` returns ``pcs`` untouched.
            ``'shape_unit'`` subtracts the per-axis mean and divides by the
            standard deviation of all coordinates of that cloud.
            ``'shape_bbox'`` subtracts the bounding-box center and divides by
            half of the longest bounding-box edge.

    Returns:
        The same ``pcs`` tensor object (its entries are overwritten).

    Raises:
        ValueError: If ``mode`` is not one of the values listed above.
    """
    if mode is None:
        return pcs
    if mode not in ('shape_unit', 'shape_bbox'):
        # Fail with a clear message instead of hitting unbound shift/scale
        # variables inside the loop.
        raise ValueError(
            f"Unknown normalization mode {mode!r}; "
            "expected None, 'shape_unit' or 'shape_bbox'"
        )
    for i, pc in enumerate(pcs):
        if mode == 'shape_unit':
            shift = pc.mean(dim=0).reshape(1, 3)
            scale = pc.flatten().std().reshape(1, 1)
        else:  # 'shape_bbox'
            pc_max, _ = pc.max(dim=0, keepdim=True) # (1, 3)
            pc_min, _ = pc.min(dim=0, keepdim=True) # (1, 3)
            shift = ((pc_min + pc_max) / 2).view(1, 3)
            scale = (pc_max - pc_min).max().reshape(1, 1) / 2
        # Write back into the batch so callers holding ``pcs`` see the result.
        pcs[i] = (pc - shift) / scale
    return pcs

    


def predict(Seed,ckpt):
    """Sample one point cloud from a checkpointed model and normalize it.

    Args:
        Seed: Seed for the random number generators; ``None`` falls back
            to 777.
        ckpt: Loaded checkpoint. ``ckpt['args']`` must expose ``model``
            (``'gaussian'`` or ``'flow'``), ``latent_dim`` and
            ``flexibility``; ``ckpt['state_dict']`` holds the weights.

    Returns:
        The first generated point cloud as a CPU tensor, rescaled with the
        ``'shape_bbox'`` normalization (presumably of shape (2048, 3) --
        TODO confirm against ``model.sample``).

    Raises:
        ValueError: If ``ckpt['args'].model`` names an unsupported model.
    """
    if Seed is None:
        Seed = 777
    seed_all(Seed)

    model_args = ckpt['args']
    if model_args.model == 'gaussian':
        model = GaussianVAE(model_args).to(device)
    elif model_args.model == 'flow':
        model = FlowVAE(model_args).to(device)
    else:
        # Previously this fell through and failed later on an unbound name.
        raise ValueError(f"Unsupported model type: {model_args.model!r}")

    model.load_state_dict(ckpt['state_dict'])

    # Generate a single point cloud from one random latent code.
    with torch.no_grad():
        z = torch.randn([1, model_args.latent_dim]).to(device)
        x = model.sample(z, 2048, flexibility=model_args.flexibility)
    gen_pcs = x.detach().cpu()[:1]
    gen_pcs = normalize_point_clouds(gen_pcs, mode="shape_bbox")

    return gen_pcs[0]

def generate(seed,value):
    """Build a Plotly 3D scatter figure of a freshly sampled point cloud.

    Args:
        seed: Seed forwarded to ``predict``.
        value: Model selection; ``"Chair"`` uses the chair checkpoint, any
            other value (including ``"Airplane"`` and ``None``) uses the
            airplane checkpoint.

    Returns:
        A ``plotly.graph_objects.Figure`` with one ``Scatter3d`` trace and
        all three scene axes hidden.
    """
    ckpt = ckpt_chair if value == "Chair" else ckpt_airplane

    # A bare (238, 75, 43) tuple would be read by Plotly as a per-point
    # numeric array mapped onto a colorscale; a CSS color string applies the
    # intended single color to every marker.
    marker_color = 'rgb(238, 75, 43)'
    points = predict(seed, ckpt)

    fig = go.Figure(
        data=[
            go.Scatter3d(
                x=points[:,0], y=points[:,1], z=points[:,2],
                mode='markers',
                marker=dict(size=1, color=marker_color)
            )
        ],
        layout=dict(
            scene=dict(
                xaxis=dict(visible=False),
                yaxis=dict(visible=False),
                zaxis=dict(visible=False)
            )
        )
    )
    return fig
    
# Header text rendered at the top of the demo page; the f-string interpolates
# the device ('cuda' or 'cpu') the models were loaded on.
markdown=f'''
  # Diffusion Probabilistic Models for 3D Point Cloud Generation

  
  [The space demo for the CVPR 2021 paper "Diffusion Probabilistic Models for 3D Point Cloud Generation".](https://arxiv.org/abs/2103.01458)
  
  [For the official implementation.](https://github.com/luost26/diffusion-point-cloud)

  ### Future Work based on interest
  - Adding new models for new type objects
  - New Customization 
  
  

  It is running on {device}
  

'''
# Page layout: header text, a row with the two inputs, a button and the plot.
with gr.Blocks() as demo:
    with gr.Column():
        with gr.Row():
            gr.Markdown(markdown)
        with gr.Row():
            # Inputs handed to generate(seed, value).
            seed = gr.Slider( minimum=0, maximum=2**16,label='Seed')
            value=gr.Dropdown(choices=["Airplane","Chair"],label="Choose Model Type")
            # Disabled control, kept for reference:
            #truncate_std = gr.Slider( minimum=1, maximum=2,label='Truncate Std')

        btn = gr.Button(value="Generate")
        point_cloud = gr.Plot()
    # Render a first figure when the page loads, then again on every click.
    demo.load(generate, [seed,value], point_cloud)
    btn.click(generate, [seed,value], point_cloud)

# Start the Gradio server (runs at import time of this script).
demo.launch()