File size: 5,573 Bytes
cd438c2
 
 
 
 
 
 
 
7b8151e
cd438c2
 
 
 
 
 
 
 
fce9719
cd438c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7b8151e
 
 
 
 
 
 
cd438c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
import streamlit as st
import util
import torch
import render_util
import math
from pathlib import Path
from models import PosADANet
import json
import plotly.graph_objects as go
import gdown


# Marker color (CSS rgb string) used for the 3D scatter preview of the point cloud
point_color = "rgb(30, 20, 160)"
# 'Choose Point Cloud' option that selects the user-uploaded file instead of a demo cloud
FILE_PC_KEY = 'File'
# Initial value (hex string) shown in the color picker
DEFAULT_COLOR = '#E1E1E1'


@st.cache_resource
def load_model(path: str, num_controls: int, url: str):
    """
    Return the pretrained network, fetching its weights first when they are not on disk
    :param path: local file the weights are read from (and downloaded to when missing)
    :param num_controls: length of style/control vector the model requires (6 for regular, 8 for metallic roughness)
    :param url: google drive url the weights are downloaded from when `path` does not exist
    :return: the network with weights loaded, placed on the module-level `device`, in eval mode
    """
    weights_file = Path(path)
    if not weights_file.exists():
        # first use of this model: pull the checkpoint while showing a spinner
        with st.spinner('Downloading Model'):
            gdown.download(url, path, quiet=False)

    network = PosADANet(1, 4, num_controls, padding='zeros', bilinear=True)
    network = network.to(device)

    # NOTE(review): torch.load unpickles the downloaded file; only point `url` at trusted checkpoints
    state_dict = torch.load(path, map_location=device)
    network.load_state_dict(state_dict)
    network.eval()

    return network


def load_dict_data(path: str) -> dict:
    """
    load a json file
    :param path: path to json file
    :return: dict with json data
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default locale encoding.
    with open(path, 'r', encoding='utf-8') as file:
        return json.load(file)


def to_rgb(hex_color: str):
    """
    turn a hex color string into normalized RGB components
    :param hex_color: color as a hex string, leading '#' characters are ignored
    :return: list of three floats (red, green, blue), each in the range 0-1
    """
    digits = hex_color.lstrip('#')
    channels = []
    # two hex digits per channel: red, then green, then blue
    for start in range(0, 6, 2):
        channel_value = int(digits[start:start + 2], 16)
        channels.append(channel_value / 255)
    return channels


st.title('Z2P - Demo')

# Run on the current CUDA device when one is available, otherwise on the CPU
device = torch.device(torch.cuda.current_device() if torch.cuda.is_available() else 'cpu')

st.subheader('Settings')

# Load model and pc data for info about predefined demo point clouds and pretrained models
model_data = load_dict_data('models/default_settings.json')
pc_data = load_dict_data('point_clouds/default_settings.json')

# Left column is used later for the 3D point cloud preview, right column holds the choices
col1_head, col2_head = st.columns(2)
model_key = col2_head.radio('Choose Model', list(model_data))
pc_key = col2_head.radio('Choose Point Cloud', list(pc_data))

uploaded_file = col2_head.file_uploader('Upload Your Own Point Cloud (.xyz, .obj)')

if pc_key == FILE_PC_KEY:
    # Use point cloud uploaded by user
    if uploaded_file is None:
        st.warning('Please upload a .xyz or .obj file')
        st.stop()
    try:
        txt = uploaded_file.getvalue().decode("utf-8")
    except UnicodeDecodeError:
        # A binary / non-UTF-8 upload would otherwise crash the app with a traceback
        st.error('The uploaded file is not valid UTF-8 text. Please upload a plain-text .xyz or .obj file')
        st.stop()
    pc = util.xyz2tensor(txt, append_normals=True)
else:
    # Load demo point cloud
    pc = util.read_xyz_file(pc_data[pc_key]['path'])

st.header('Input')
col1, col2 = st.columns(2)

# Sliders for the point cloud spatial transformations; their initial values
# come from the selected point cloud's entry in pc_data
col2.subheader("Point Cloud Transformations")
pc_defaults = pc_data[pc_key]
scale = col2.slider('Scale', min_value=0.0, max_value=5.0, value=pc_defaults['scale'])
rx = col2.slider('X-Rotation', min_value=-math.pi, max_value=math.pi, value=pc_defaults['rx'])
ry = col2.slider('Y-Rotation', min_value=-math.pi, max_value=math.pi, value=pc_defaults['ry'])
rz = col2.slider('Z-Rotation', min_value=-math.pi, max_value=math.pi, value=pc_defaults['rz'])
dy = col2.slider('Height', min_value=0, max_value=500, value=pc_defaults['dy'])

col1.subheader("Input Z-Buffer")

# Apply the rotation; scale and height are passed to draw_pc below
pc = render_util.rotate_pc(pc, rx, ry, rz)

# Interactive 3D preview of the rotated point cloud (z axis negated for display)
trace1 = [go.Scatter3d(x=pc[:, 0], y=pc[:, 1], z=-pc[:, 2], mode="markers",
    marker=dict(
        symbol="circle",
        size=1,
        color=point_color))]
fig = go.Figure(trace1, layout=go.Layout())
col1_head.plotly_chart(fig, use_container_width=True)

# Project and render the point z-buffer
zbuffer = render_util.draw_pc(pc, radius=model_data[model_key]['point_radius'], dy=dy, scale=scale)

# Show input z-buffer visualization in streamlit, normalised by its maximum.
# When the maximum is 0 the division would yield NaNs, so show the buffer as is.
zbuffer_max = zbuffer.max()
zbuffer_preview = zbuffer / zbuffer_max if zbuffer_max != 0 else zbuffer
col1.image(zbuffer_preview, use_column_width=True)

# From here on the z-buffer is a float tensor on the inference device
zbuffer: torch.Tensor = torch.from_numpy(zbuffer).float().to(device)

st.header('Result')

model_settings = model_data[model_key]
len_style = model_settings['len_style']
# Load pretrained model
model = load_model(model_settings['path'], len_style, model_settings['url'])
col1, col2 = st.columns(2)
col2.subheader('Visualization Controls')

# Add leading channel and batch axes: 2-D z-buffer -> (1, 1, H, W).
# float()/to(device) are no-ops when the tensor is already a float tensor on `device`.
zbuffer: torch.Tensor = zbuffer.float().to(device).unsqueeze(0).unsqueeze(0)

# Style/control vector: [r, g, b, delta_r, delta_phi, delta_theta, (metallic, roughness)]
style = torch.zeros(len_style, dtype=zbuffer.dtype, device=device)

# Pick color and light direction visualization parameters
hex_color = col2.color_picker('Pick A Color', DEFAULT_COLOR)
picked_rgb = torch.tensor(to_rgb(hex_color), dtype=style.dtype, device=device)
# color channels are capped at 0.9 before being handed to the model
style[:3] = picked_rgb.clip(0.0, 0.9)

# Light direction
style[3] = col2.slider('Light Radius', min_value=-1.0, max_value=1.0, value=0.0)  # delta_r
style[4] = col2.slider('Light Phi', min_value=-math.pi/4, max_value=math.pi/4, value=0.0)  # delta_phi
style[5] = col2.slider('Light Theta', min_value=-math.pi/4, max_value=math.pi/4, value=0.0)  # delta_theta

# Extra Controls for Metallic and Roughness Model
if len_style == 8:
    style[6] = col2.slider('Metallic', min_value=0.0, max_value=1.0, value=0.5)
    style[7] = col2.slider('Roughness', min_value=0.0, max_value=1.0, value=0.5)

# add the batch axis expected alongside the batched z-buffer
style = style.unsqueeze(0)

# generate image with pretrained model
with torch.no_grad():
    generated = model(zbuffer, style)

# embed a white background behind the object using the alpha map
# as well as the color used as input in the bottom right corner
generated = util.embed_color(generated.detach(), style[:, :3], box_size=50)
# first batch item, channels moved last, as a numpy array for st.image
rendered = generated[0].permute(1, 2, 0).cpu().numpy()

# show the image in streamlit
col1.image(rendered.clip(0, 1), use_column_width=True)