|
import streamlit as st |
|
import util |
|
import torch |
|
import render_util |
|
import math |
|
from pathlib import Path |
|
from models import PosADANet |
|
import json |
|
import plotly.graph_objects as go |
|
import gdown |
|
|
|
|
|
# Default value shown by the color picker in the result controls (light grey).
DEFAULT_COLOR = '#E1E1E1'

# Point-cloud radio option that switches the input to the user-uploaded file.
FILE_PC_KEY = 'File'

# Marker color (Plotly rgb string) used for the 3D point-cloud preview.
point_color = "rgb(30, 20, 160)"
|
|
|
|
|
@st.cache_resource
def load_model(path: str, num_controls: int, url: str):
    """
    Load a pretrained PosADANet, downloading its weights from Google Drive first if needed.

    The network is constructed on, and the weights are mapped to, the module-level ``device``.
    Because of ``st.cache_resource`` the same model object is reused across reruns for
    identical (path, num_controls, url) arguments.

    :param path: path to save/load the pretrained model weights
    :param num_controls: length of style/control vector the model requires (6 for regular, 8 for metallic roughness)
    :param url: google drive url to download the weights from if they are not already on disk
    :return: the pretrained model, in eval mode
    :raises RuntimeError: if the download does not produce a weights file
    """
    weights_path = Path(path)

    if not weights_path.exists():
        # gdown cannot write into a directory that does not exist yet.
        weights_path.parent.mkdir(parents=True, exist_ok=True)
        # Download under a temporary name and rename only once complete, so an interrupted
        # download never leaves a truncated file that would pass the exists() check next run.
        partial_path = weights_path.with_name(weights_path.name + '.part')
        with st.spinner('Downloading Model'):
            result = gdown.download(url, str(partial_path), quiet=False)
        if result is None or not partial_path.exists():
            raise RuntimeError(f'Failed to download model weights from {url}')
        partial_path.replace(weights_path)

    model = PosADANet(1, 4, num_controls, padding='zeros', bilinear=True).to(device)
    # NOTE(review): torch.load unpickles the checkpoint; only point `url` at trusted weights.
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model.eval()

    return model
|
|
|
|
|
def load_dict_data(path: str):
    """
    Load a JSON file.

    :param path: path to json file
    :return: the parsed JSON content (the settings files used by this demo are objects, i.e. dicts)
    """
    # Decode explicitly as UTF-8 so parsing does not depend on the platform's locale encoding.
    with open(path, 'r', encoding='utf-8') as file:
        return json.load(file)
|
|
|
|
|
def to_rgb(hex_color: str):
    """
    Convert a color in hex format to RGB channel values.

    Accepts ``#RRGGBB`` (the format returned by ``st.color_picker``) as well as the
    ``#RGB`` shorthand; the leading ``#`` is optional.

    :param hex_color: color hex string
    :return: list of three floats for the R, G, B channels, each between 0-1
    :raises ValueError: if the string is not a 3- or 6-digit hex color
    """
    digits = hex_color.lstrip('#')
    if len(digits) == 3:
        # Expand shorthand, e.g. 'F0A' -> 'FF00AA'.
        digits = ''.join(ch * 2 for ch in digits)
    if len(digits) != 6:
        raise ValueError(f'Expected a #RGB or #RRGGBB hex color, got {hex_color!r}')
    return [int(digits[i:i + 2], 16) / 255 for i in (0, 2, 4)]
|
|
|
|
|
st.title('Z2P - Demo')

# Run on the currently selected CUDA device when one is available, otherwise on the CPU.
# `load_model` reads this module-level name when it is called.
if torch.cuda.is_available():
    device = torch.device('cuda', torch.cuda.current_device())
else:
    device = torch.device('cpu')

st.subheader('Settings')

# Default settings keyed by model name and by point-cloud name (paths relative to the CWD).
model_data = load_dict_data('models/default_settings.json')
pc_data = load_dict_data('point_clouds/default_settings.json')
|
|
|
col1_head, col2_head = st.columns(2)

# The available choices are the keys of the two settings files.
model_key = col2_head.radio(
    'Choose Model',
    model_data.keys())

pc_key = col2_head.radio(
    'Choose Point Cloud',
    pc_data.keys())

# Only accept the extensions the label advertises, so unrelated files are rejected up front.
uploaded_file = col2_head.file_uploader('Upload Your Own Point Cloud (.xyz, .obj)', type=['xyz', 'obj'])

if pc_key == FILE_PC_KEY:
    if uploaded_file is None:
        st.warning('Please upload a .xyz or .obj file')
        st.stop()

    try:
        txt = uploaded_file.getvalue().decode("utf-8")
    except UnicodeDecodeError:
        # A binary / non-UTF-8 upload would otherwise crash the app with a raw traceback.
        st.error('Could not read the uploaded file as UTF-8 text')
        st.stop()
    pc = util.xyz2tensor(txt, append_normals=True)
else:
    pc = util.read_xyz_file(pc_data[pc_key]['path'])
|
|
|
st.header('Input')

col1, col2 = st.columns(2)

# Transformation sliders, initialised from the selected point cloud's defaults.
# Defaults are coerced to each slider's numeric type: Streamlit rejects a slider whose
# `value` type (e.g. an int read from JSON) differs from its float/int min and max.
pc_settings = pc_data[pc_key]
col2.subheader("Point Cloud Transformations")
scale = col2.slider('Scale', min_value=0.0, max_value=5.0, value=float(pc_settings['scale']))
rx = col2.slider('X-Rotation', min_value=-math.pi, max_value=math.pi, value=float(pc_settings['rx']))
ry = col2.slider('Y-Rotation', min_value=-math.pi, max_value=math.pi, value=float(pc_settings['ry']))
rz = col2.slider('Z-Rotation', min_value=-math.pi, max_value=math.pi, value=float(pc_settings['rz']))
dy = col2.slider('Height', min_value=0, max_value=500, value=int(pc_settings['dy']))

col1.subheader("Input Z-Buffer")

# Rotate the cloud, then show it as an interactive 3D scatter in the top-left column.
pc = render_util.rotate_pc(pc, rx, ry, rz)
trace1 = [go.Scatter3d(x=pc[:, 0], y=pc[:, 1], z=-pc[:, 2], mode="markers",
                       marker=dict(
                           symbol="circle",
                           size=1,
                           color=point_color))]
fig = go.Figure(trace1, layout=go.Layout())
col1_head.plotly_chart(fig, use_container_width=True)

# Render the rotated cloud into the model input (a 2D numpy array, see the reshape in the
# result section) using the selected model's point radius.
zbuffer = render_util.draw_pc(pc, radius=model_data[model_key]['point_radius'], dy=dy, scale=scale)

# Normalise for display. An empty render (all zeros, e.g. the cloud pushed out of frame)
# would otherwise divide by zero and hand NaNs to st.image.
zbuffer_peak = zbuffer.max()
zbuffer_preview = zbuffer / zbuffer_peak if zbuffer_peak > 0 else zbuffer
col1.image(zbuffer_preview.clip(0, 1), use_column_width=True)

zbuffer: torch.Tensor = torch.from_numpy(zbuffer).float().to(device)
|
|
|
st.header('Result')

len_style = model_data[model_key]['len_style']
model = load_model(model_data[model_key]['path'], len_style, model_data[model_key]['url'])

col1, col2 = st.columns(2)
col2.subheader('Visualization Controls')

# (H, W) z-buffer -> (1, 1, H, W): add batch and channel dimensions for the model.
zbuffer: torch.Tensor = zbuffer.float().to(device)[None, None]

# Control vector: [R, G, B, light radius, light phi, light theta] (+ [metallic, roughness]
# when the selected model takes 8 controls).
style = torch.zeros(len_style, dtype=zbuffer.dtype, device=device)

hex_color = col2.color_picker('Pick A Color', DEFAULT_COLOR)
style[0], style[1], style[2] = to_rgb(hex_color)
# NOTE(review): color channels are capped at 0.9 — presumably the range the model was
# trained on; confirm before changing.
style[:3] = style[:3].clip(0.0, 0.9)

style[3] = col2.slider('Light Radius', min_value=-1.0, max_value=1.0, value=0.0)
style[4] = col2.slider('Light Phi', min_value=-math.pi/4, max_value=math.pi/4, value=0.0)
style[5] = col2.slider('Light Theta', min_value=-math.pi/4, max_value=math.pi/4, value=0.0)

if len_style == 8:
    style[6] = col2.slider('Metallic', min_value=0.0, max_value=1.0, value=0.5)
    style[7] = col2.slider('Roughness', min_value=0.0, max_value=1.0, value=0.5)

# Add the batch dimension: (len_style,) -> (1, len_style).
style = style.unsqueeze(0)

# Inference only: no autograd graph is built, so the output needs no detach().
with torch.no_grad():
    generated = model(zbuffer, style)

# util.embed_color presumably stamps the chosen color (box_size=50) onto the output — see util.
generated = util.embed_color(generated, style[:, :3], box_size=50)
# First batch item, channels-first -> channels-last for st.image.
rendered = generated[0].permute(1, 2, 0).cpu().numpy()

col1.image(rendered.clip(0, 1), use_column_width=True)
|
|