File size: 5,573 Bytes
cd438c2 7b8151e cd438c2 fce9719 cd438c2 7b8151e cd438c2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 |
import streamlit as st
import util
import torch
import render_util
import math
from pathlib import Path
from models import PosADANet
import json
import plotly.graph_objects as go
import gdown
# Label of the point-cloud choice meaning "use the file uploaded by the user".
FILE_PC_KEY = "File"
# Plotly-style RGB color string for drawing the point cloud
# (presumably the 3D preview markers — confirm at the use site).
point_color = 'rgb(30, 20, 160)'
def load_model(path: str, num_controls: int, url: str):
    """Load the pretrained model from disk, downloading it from Google Drive first if needed.

    :param path: path to save/load the pretrained model
    :param num_controls: length of the style/control vector the model requires
        (6 for regular, 8 for metallic roughness)
    :param url: Google Drive URL to download the model from if it is not already on disk
    :return: the pretrained model with its weights loaded, moved to the module-level ``device``
    """
    model_path = Path(path)
    if not model_path.exists():
        # Make sure the target directory exists so the download has somewhere to write.
        model_path.parent.mkdir(parents=True, exist_ok=True)
        with st.spinner('Downloading Model'):
            # NOTE(review): this call was partially lost in the source view (only
            # ", path, quiet=False)" survived); reconstructed from the `url`
            # parameter and the file-level `gdown` import — confirm.
            gdown.download(url, path, quiet=False)
    # Positional args (1, 4, ...) are kept as in the original; their meaning is
    # defined by models.PosADANet.
    model = PosADANet(1, 4, num_controls, padding='zeros', bilinear=True).to(device)
    # map_location remaps saved tensors onto `device`, so GPU-saved weights load on CPU too.
    model.load_state_dict(torch.load(path, map_location=device))
    return model
def load_dict_data(path: str):
    """Load a JSON file.

    :param path: path to the JSON file
    :return: dict with the parsed JSON data
    """
    # JSON text is UTF-8 (RFC 8259); don't depend on the platform's locale encoding.
    with open(path, 'r', encoding='utf-8') as file:
        return json.load(file)
def to_rgb(hex_color: str) -> list:
    """Convert a color in hex format to RGB channel values.

    Accepts ``#RRGGBB`` / ``RRGGBB`` and also the CSS shorthand ``#RGB``
    (each digit doubled, e.g. ``#f0a`` == ``#ff00aa``).

    :param hex_color: color hex string, with or without a leading '#'
    :return: list of three floats for the R, G, B channels, each between 0 and 1
    """
    h = hex_color.lstrip('#')
    if len(h) == 3:
        # Expand shorthand; without this the third slice would be empty and int('') fails.
        h = ''.join(ch * 2 for ch in h)
    return [int(h[i:i + 2], 16) / 255 for i in (0, 2, 4)]
st.title('Z2P - Demo')

# Pick the compute device: the current CUDA device if available, else the CPU.
# `device` is also read as a module-level global by load_model().
if torch.cuda.is_available():
    device = torch.device(torch.cuda.current_device())
else:
    device = torch.device(torch.device('cpu'))

# Settings describing the pretrained models and the predefined demo point clouds.
model_data = load_dict_data('models/default_settings.json')
pc_data = load_dict_data('point_clouds/default_settings.json')
col1_head, col2_head = st.columns(2)

# NOTE(review): the two selectbox calls below were truncated in the source view
# (only their labels survived). Reconstructed: the options are the settings-dict
# keys, because model_key / pc_key are only ever used to index those dicts.
# FILE_PC_KEY is presumably one of pc_data's keys, since pc_data[pc_key] is read
# for it later. The target column is a guess — confirm.
model_key = col1_head.selectbox(
    'Choose Model',
    list(model_data.keys()))
pc_key = col1_head.selectbox(
    'Choose Point Cloud',
    list(pc_data.keys()))
uploaded_file = col2_head.file_uploader('Upload Your Own Point Cloud (.xyz, .obj)')

if pc_key == FILE_PC_KEY:
    # Use the point cloud uploaded by the user.
    if uploaded_file is None:
        st.warning('Please upload a .xyz or .obj file')
        # Nothing to render yet: stop this run instead of failing later on an undefined `pc`.
        st.stop()
    txt = uploaded_file.getvalue().decode("utf-8")
    pc = util.xyz2tensor(txt, append_normals=True)
else:
    # Load one of the predefined demo point clouds.
    pc = util.read_xyz_file(pc_data[pc_key]['path'])
col1, col2 = st.columns(2)

# Parameters for the point cloud's spatial transformations; defaults come from
# the selected point cloud's entry in point_clouds/default_settings.json.
pc_settings = pc_data[pc_key]
col2.subheader("Point Cloud Transformations")
scale = col2.slider('Scale', min_value=0.0, max_value=5.0, value=pc_settings['scale'])
rx = col2.slider('X-Rotation', min_value=-math.pi, max_value=math.pi, value=pc_settings['rx'])
ry = col2.slider('Y-Rotation', min_value=-math.pi, max_value=math.pi, value=pc_settings['ry'])
rz = col2.slider('Z-Rotation', min_value=-math.pi, max_value=math.pi, value=pc_settings['rz'])
dy = col2.slider('Height', min_value=0, max_value=500, value=pc_settings['dy'])
col1.subheader("Input Z-Buffer")

# Apply the rotation here; scale and height (dy) are passed to draw_pc below.
pc = render_util.rotate_pc(pc, rx, ry, rz)

# Interactive 3D scatter preview of the rotated point cloud.
# NOTE(review): the tail of this call was lost in the source view; the marker
# settings are reconstructed (point_color is otherwise unused) — confirm the size.
trace1 = [go.Scatter3d(x=pc[:, 0], y=pc[:, 1], z=-pc[:, 2], mode="markers",
                       marker=dict(size=1, color=point_color))]
fig = go.Figure(trace1, layout=go.Layout())
col1_head.plotly_chart(fig, use_container_width=True)

# Project and render the point cloud z-buffer.
zbuffer = render_util.draw_pc(pc, radius=model_data[model_key]['point_radius'], dy=dy, scale=scale)
# Show the z-buffer scaled by its maximum; skip the scaling for an all-zero
# buffer so we never divide by zero and hand NaNs to st.image.
zmax = zbuffer.max()
col1.image(zbuffer / zmax if zmax > 0 else zbuffer, use_column_width=True)
zbuffer: torch.Tensor = torch.from_numpy(zbuffer).float().to(device)
len_style = model_data[model_key]['len_style']
# Load the pretrained model (downloaded on first use by load_model).
model = load_model(model_data[model_key]['path'], len_style, model_data[model_key]['url'])

col1, col2 = st.columns(2)
col2.subheader('Visualization Controls')

# 2-D z-buffer -> (1, 1, rows, cols): add channel and batch dimensions for the
# network. zbuffer is already a float tensor on `device`, so no conversion is needed.
zbuffer = zbuffer.unsqueeze(0).unsqueeze(0)
style = torch.zeros(len_style, dtype=zbuffer.dtype, device=device)

# NOTE(review): DEFAULT_COLOR was used here but never defined in this file; the
# value below is a reconstruction (it matches point_color) — confirm.
DEFAULT_COLOR = '#1E14A0'

# Pick color and light direction visualization parameters.
hex_color = col2.color_picker('Pick A Color', DEFAULT_COLOR)
style[0], style[1], style[2] = to_rgb(hex_color)
# Limit the base color channels to [0, 0.9].
style[:3] = style[:3].clip(0.0, 0.9)

# Light direction
style[3] = col2.slider('Light Radius', min_value=-1.0, max_value=1.0, value=0.0)  # delta_r
style[4] = col2.slider('Light Phi', min_value=-math.pi / 4, max_value=math.pi / 4, value=0.0)  # delta_phi
style[5] = col2.slider('Light Theta', min_value=-math.pi / 4, max_value=math.pi / 4, value=0.0)  # delta_theta

# Extra controls for the metallic-roughness model (8-element style vector).
if len_style == 8:
    style[6] = col2.slider('Metallic', min_value=0.0, max_value=1.0, value=0.5)
    style[7] = col2.slider('Roughness', min_value=0.0, max_value=1.0, value=0.5)

# Add the batch dimension to match the batched z-buffer.
style = style.unsqueeze(0)
# Inference only: run the pretrained network without building an autograd graph.
with torch.no_grad():
    raw_output = model(zbuffer.float(), style)

# util.embed_color puts a white background behind the object using the alpha map
# and draws the input color in the bottom-right corner.
composited = util.embed_color(raw_output.detach(), style[:, :3], box_size=50)

# Take the single batch item and move its first axis last for display
# (presumably channels-first -> channels-last — confirm).
image_hwc = composited[0].permute(1, 2, 0).cpu().numpy()

# Show the image in Streamlit, clamped to the displayable [0, 1] range.
col1.image(image_hwc.clip(0, 1), use_column_width=True)