File size: 3,644 Bytes
0d0e451
 
 
 
 
 
7872317
 
0d0e451
 
 
 
 
 
 
 
 
 
 
d5c0caa
0d0e451
7872317
 
 
 
 
 
 
0d0e451
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d5c0caa
0d0e451
d5c0caa
0d0e451
 
 
 
 
d5c0caa
0d0e451
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7872317
0d0e451
 
 
 
 
7872317
0d0e451
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import os
import gradio as gr
import plotly.graph_objects as go
import sys
import torch
from huggingface_hub import hf_hub_download
import numpy as np
import random

# Fetch the upstream implementation so its `models` package is importable.
# Skip the clone when the checkout already exists (e.g. on a restart): a second
# `git clone` into an existing directory only fails with a noisy error.
if not os.path.isdir("diffusion-point-cloud"):
    os.system("git clone https://github.com/luost26/diffusion-point-cloud")
sys.path.append("diffusion-point-cloud")


from models.vae_gaussian import *
from models.vae_flow import *

# Local paths of the pretrained generator checkpoints (cached by huggingface_hub).
airplane = hf_hub_download("SerdarHelli/diffusion-point-cloud", filename="GEN_airplane.pt", revision="main")
chair = hf_hub_download("SerdarHelli/diffusion-point-cloud", filename="GEN_chair.pt", revision="main")

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# NOTE(review): torch.load unpickles the file, so it executes code from the
# checkpoint. That is acceptable only because these come from a fixed, trusted
# repo. `weights_only=True` is not an option here because the checkpoints also
# store an `args` namespace.
ckpt_airplane = torch.load(airplane, map_location=torch.device(device))
ckpt_chair = torch.load(chair, map_location=torch.device(device))

def seed_all(seed):
    """Seed the torch, numpy and stdlib ``random`` generators with ``seed``.

    The generators are seeded in that order; nothing is returned.
    """
    for seeder in (torch.manual_seed, np.random.seed, random.seed):
        seeder(seed)

def normalize_point_clouds(pcs, mode):
    """Center and rescale every point cloud of a batch, in place.

    Args:
        pcs: batch of point clouds indexed as ``pcs[i]``. Each cloud has rows of
            3 coordinates (the ``reshape(1, 3)`` / ``view(1, 3)`` below require
            it), so presumably shape (B, N, 3).
        mode: ``None`` leaves ``pcs`` untouched. ``'shape_unit'`` subtracts each
            cloud's centroid and divides by the std of all its coordinates.
            ``'shape_bbox'`` subtracts the bounding-box center and divides by
            half the longest bbox side, so the cloud fits in [-1, 1].

    Returns:
        ``pcs`` itself; each ``pcs[i]`` has been overwritten.

    Raises:
        ValueError: if ``mode`` is not ``None``, ``'shape_unit'`` or ``'shape_bbox'``.
    """
    if mode is None:
        return pcs
    if mode not in ('shape_unit', 'shape_bbox'):
        # Without this check an unknown mode reached `pc - shift` with `shift`
        # unbound and crashed with an UnboundLocalError.
        raise ValueError(f"Unknown normalization mode: {mode!r}")
    for i, pc in enumerate(pcs):
        if mode == 'shape_unit':
            shift = pc.mean(dim=0).reshape(1, 3)
            scale = pc.flatten().std().reshape(1, 1)
        else:  # 'shape_bbox'
            pc_max, _ = pc.max(dim=0, keepdim=True)  # (1, 3)
            pc_min, _ = pc.min(dim=0, keepdim=True)  # (1, 3)
            shift = ((pc_min + pc_max) / 2).view(1, 3)
            scale = (pc_max - pc_min).max().reshape(1, 1) / 2
        # NOTE(review): a degenerate cloud (all points identical) gives scale 0
        # and therefore inf/nan coordinates; left as-is.
        pcs[i] = (pc - shift) / scale
    return pcs

def predict(Seed, ckpt):
    """Sample one 2048-point cloud from a pretrained checkpoint.

    Args:
        Seed: RNG seed; ``None`` falls back to 777. It is passed through
            ``int()`` because ``np.random.seed`` rejects floats, which a UI
            slider may deliver.
        ckpt: checkpoint dict as returned by ``torch.load``. It must provide
            ``'state_dict'`` and ``'args'`` with ``model``, ``latent_dim`` and
            ``flexibility`` attributes.

    Returns:
        The generated cloud as a CPU tensor, normalized with ``'shape_bbox'``.

    Raises:
        ValueError: if ``ckpt['args'].model`` is neither ``'gaussian'`` nor ``'flow'``.
    """
    seed = 777 if Seed is None else int(Seed)
    seed_all(seed)

    args = ckpt['args']
    if args.model == 'gaussian':
        model = GaussianVAE(args).to(device)
    elif args.model == 'flow':
        model = FlowVAE(args).to(device)
    else:
        # Previously `model` was left unbound and the call below crashed obscurely.
        raise ValueError(f"Unsupported model type in checkpoint: {args.model!r}")

    model.load_state_dict(ckpt['state_dict'])
    model.eval()  # inference only: make sure no module runs in training mode

    with torch.no_grad():
        z = torch.randn([1, args.latent_dim]).to(device)
        x = model.sample(z, 2048, flexibility=args.flexibility)

    gen_pcs = normalize_point_clouds(x.detach().cpu(), mode="shape_bbox")
    return gen_pcs[0]

def generate(seed, value):
    """Sample a point cloud for the chosen category and render it as a 3D scatter.

    Args:
        seed: seed forwarded to ``predict``.
        value: category label; ``"Chair"`` selects the chair checkpoint and
            anything else (including ``"Airplane"`` and ``None``) the airplane one.

    Returns:
        A plotly ``go.Figure`` with the points as size-1 markers and all axes hidden.
    """
    ckpt = ckpt_chair if value == "Chair" else ckpt_airplane

    # predict() returns a CPU tensor; give plotly plain numpy arrays.
    points = predict(seed, ckpt).numpy()

    fig = go.Figure(
        data=[
            go.Scatter3d(
                x=points[:, 0], y=points[:, 1], z=points[:, 2],
                mode='markers',
                # A bare (r, g, b) tuple is read by plotly as an array of
                # numbers mapped through a colorscale, not as one colour.
                marker=dict(size=1, color='rgb(238, 75, 43)'),
            )
        ],
        layout=dict(
            scene=dict(
                xaxis=dict(visible=False),
                yaxis=dict(visible=False),
                zaxis=dict(visible=False),
            )
        ),
    )
    return fig
    
markdown=f'''
  # Diffusion Probabilistic Models for 3D Point Cloud Generation
  [[The Paper](https://arxiv.org/abs/2103.01458)] [[Original Code](https://github.com/luost26/diffusion-point-cloud)]
  The space demo for our CVPR 2021 paper "Diffusion Probabilistic Models for 3D Point Cloud Generation".

  It is running on {device}
'''
# UI layout: header text, seed + category inputs, a Generate button and the plot.
with gr.Blocks() as demo:
    with gr.Column():
        with gr.Row():
            gr.Markdown(markdown)
        with gr.Row():
            # step=1 so every integer seed is selectable. Depending on the
            # Gradio version, an unset step may be derived from the range and
            # be much coarser than 1.
            seed = gr.Slider(minimum=0, maximum=2**16, step=1, label='Seed')
            # Explicit default mirrors generate()'s fallback for an empty choice.
            value = gr.Dropdown(choices=["Airplane", "Chair"], value="Airplane", label="Choose Model Type")

        btn = gr.Button(value="Generate")
        point_cloud = gr.Plot()
    # Render one sample when the page loads, and a new one on every click.
    demo.load(generate, [seed, value], point_cloud)
    btn.click(generate, [seed, value], point_cloud)

demo.launch()