import streamlit as st |
import util |
import torch |
import render_util |
import math |
from pathlib import Path |
from models import PosADANet |
import json |
import plotly.graph_objects as go |
import gdown |
# Marker color of the 3D point-cloud scatter plot (plotly color string).
point_color = "rgb(30, 20, 160)"
# Point-cloud radio option (also a key of the point-cloud settings JSON) that
# selects a user-uploaded file instead of a bundled point cloud.
FILE_PC_KEY = 'File'
# Initial value of the material color picker; st.color_picker expects '#RRGGBB'.
DEFAULT_COLOR = '#B4B4B4'
@st.cache_resource
def load_model(path: str, num_controls: int, url: str):
    """
    Load the pretrained PosADANet, downloading its weights from Google Drive first if needed.

    Wrapped in st.cache_resource, so Streamlit reuses the returned model across reruns
    for the same (path, num_controls, url) arguments.

    :param path: local path of the weights file; it is downloaded to this path if missing
    :param num_controls: length of the style/control vector the model requires
        (6 for the regular model, 8 for the metallic/roughness model)
    :param url: Google Drive url used to download the weights when ``path`` does not exist
    :return: the pretrained model in eval mode, moved to the module-level ``device``
    :raises FileNotFoundError: if the weights file is still missing after the download attempt
    """
    weights_path = Path(path)
    if not weights_path.exists():
        # Make sure the destination folder exists before downloading into it.
        weights_path.parent.mkdir(parents=True, exist_ok=True)
        with st.spinner('Downloading Model'):
            gdown.download(url, str(weights_path), quiet=False)
        if not weights_path.exists():
            # Fail with a clear message instead of an opaque error from torch.load.
            raise FileNotFoundError(f'Could not download model weights from {url} to {path}')
    # `device` is the module-level torch.device selected when the script starts.
    model = PosADANet(1, 4, num_controls, padding='zeros', bilinear=True).to(device)
    # NOTE(review): torch.load unpickles the file, which here comes from a remote URL;
    # consider weights_only=True on torch versions that support it.
    model.load_state_dict(torch.load(path, map_location=device))
    model.eval()
    return model
def load_dict_data(path: str):
    """
    Load a JSON file.

    The file is always decoded as UTF-8, independent of the platform's default encoding.

    :param path: path to json file
    :return: the parsed JSON content (a dict for the settings files this script reads)
    """
    with open(path, 'r', encoding='utf-8') as file:
        return json.load(file)
def to_rgb(hex_color: str):
    """
    Convert a color in hex format to RGB channel values.

    Accepts '#RRGGBB', the '#RGB' shorthand, and '#RRGGBBAA' (alpha is ignored);
    the leading '#' is optional and surrounding whitespace is ignored.

    :param hex_color: color hex string
    :return: list of three floats for the R, G, B channels, each between 0 and 1
    :raises ValueError: if the string is not a valid hex color
    """
    digits = hex_color.strip().lstrip('#')
    if len(digits) == 3:
        # Shorthand form: every digit is doubled, e.g. 'f80' -> 'ff8800'.
        digits = ''.join(ch * 2 for ch in digits)
    if len(digits) not in (6, 8):
        raise ValueError(f'Invalid hex color: {hex_color!r}')
    return [int(digits[i:i + 2], 16) / 255 for i in (0, 2, 4)]
st.title('Z2P - Demo')
# Use the current CUDA device when available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda', torch.cuda.current_device())
else:
    device = torch.device('cpu')
st.subheader('Settings')
# Per-model and per-point-cloud settings, keyed by the option names shown in the radios below.
model_data = load_dict_data('models/default_settings.json')
pc_data = load_dict_data('point_clouds/default_settings.json')
col1_head, col2_head = st.columns(2)
model_key = col2_head.radio(
    'Choose Model',
    model_data.keys())
pc_key = col2_head.radio(
    'Choose Point Cloud',
    pc_data.keys())
uploaded_file = col2_head.file_uploader('Upload Your Own Point Cloud (.xyz, .obj)')
if pc_key == FILE_PC_KEY:
    # The uploaded file is only used when the 'File' option is selected.
    if uploaded_file is None:
        st.warning('Please upload a .xyz or .obj file')
        st.stop()
    try:
        txt = uploaded_file.getvalue().decode("utf-8")
    except UnicodeDecodeError:
        # Report unreadable uploads in the UI instead of crashing the app.
        st.error('The uploaded file is not a UTF-8 text file')
        st.stop()
    pc = util.xyz2tensor(txt, append_normals=True)
else:
    pc = util.read_xyz_file(pc_data[pc_key]['path'])
st.header('Input')
col1, col2 = st.columns(2)
col2.subheader("Point Cloud Transformations")
pc_settings = pc_data[pc_key]
# Cast the JSON defaults so each slider's value has the same numeric type as its
# bounds; Streamlit rejects sliders that mix int and float arguments.
scale = col2.slider('Scale', min_value=0.0, max_value=5.0, value=float(pc_settings['scale']))
rx = col2.slider('X-Rotation', min_value=-math.pi, max_value=math.pi, value=float(pc_settings['rx']))
ry = col2.slider('Y-Rotation', min_value=-math.pi, max_value=math.pi, value=float(pc_settings['ry']))
rz = col2.slider('Z-Rotation', min_value=-math.pi, max_value=math.pi, value=float(pc_settings['rz']))
dy = col2.slider('Height', min_value=0, max_value=500, value=int(pc_settings['dy']))
col1.subheader("Input Z-Buffer")
pc = render_util.rotate_pc(pc, rx, ry, rz)
# 3D preview of the rotated point cloud; the z axis is negated for display.
trace1 = [go.Scatter3d(x=pc[:, 0], y=pc[:, 1], z=-pc[:, 2], mode="markers",
                       marker=dict(
                           symbol="circle",
                           size=1,
                           color=point_color))]
fig = go.Figure(trace1, layout=go.Layout())
col1_head.plotly_chart(fig, use_container_width=True)
zbuffer = render_util.draw_pc(pc, radius=model_data[model_key]['point_radius'], dy=dy, scale=scale)
# Normalize for display; skip the division when the z-buffer is empty (max == 0),
# e.g. for scale 0, to avoid showing a NaN image.
zbuffer_max = zbuffer.max()
col1.image(zbuffer / zbuffer_max if zbuffer_max != 0 else zbuffer, use_column_width=True)
zbuffer: torch.Tensor = torch.from_numpy(zbuffer).float().to(device)
st.header('Result')
selected_model = model_data[model_key]
len_style = selected_model['len_style']
model = load_model(selected_model['path'], len_style, selected_model['url'])
col1, col2 = st.columns(2)
col2.subheader('Visualization Controls')
# 2D z-buffer (rows, cols) -> (batch=1, channels=1, rows, cols) network input.
zbuffer: torch.Tensor = zbuffer[None, None]
# Control vector: [R, G, B, light radius, light phi, light theta(, metallic, roughness)].
style = torch.zeros(len_style, dtype=zbuffer.dtype, device=device)
hex_color = col2.color_picker('Pick A Color', DEFAULT_COLOR)
style[0], style[1], style[2] = to_rgb(hex_color)
# Each color channel is capped at 0.9.
style[:3] = style[:3].clip(0.0, 0.9)
style[3] = col2.slider('Light Radius', min_value=-1.0, max_value=1.0, value=0.0)
style[4] = col2.slider('Light Phi', min_value=-math.pi / 4, max_value=math.pi / 4, value=0.0)
style[5] = col2.slider('Light Theta', min_value=-math.pi / 4, max_value=math.pi / 4, value=0.0)
if len_style == 8:
    # Only the 8-control model exposes the material parameters.
    style[6] = col2.slider('Metallic', min_value=0.0, max_value=1.0, value=0.5)
    style[7] = col2.slider('Roughness', min_value=0.0, max_value=1.0, value=0.5)
style = style.unsqueeze(0)
with torch.no_grad():
    generated = model(zbuffer, style)
# NOTE(review): embed_color presumably draws a swatch of the chosen color
# (box_size=50) into the output image — confirm in util.
generated = util.embed_color(generated.detach(), style[:, :3], box_size=50)
# First batch item, channels-last for st.image.
rendered = generated[0].permute(1, 2, 0).cpu().numpy()
col1.image(rendered.clip(0, 1), use_column_width=True)