import streamlit as st
import util
import torch
import render_util
import math
from pathlib import Path
from models import PosADANet
import json
import plotly.graph_objects as go
import gdown

point_color = "rgb(30, 20, 160)"  # marker color for the Plotly point-cloud preview
FILE_PC_KEY = 'File'              # radio option that selects a user-uploaded point cloud
DEFAULT_COLOR = '#E1E1E1'         # initial value of the color picker


@st.cache
def load_model(path: str, num_controls: int, url: str):
    """
    Load the pretrained model from disk, downloading it from Google Drive first if needed.
    :param path: path to save/load the pretrained model
    :param num_controls: length of the style/control vector the model expects (6 for the regular model, 8 for metallic-roughness)
    :param url: Google Drive URL to download the model from if it is not already on disk
    :return: the pretrained model in eval mode
    """
    if not Path(path).exists():
        with st.spinner('Downloading Model'):
            gdown.download(url, path, quiet=False)
    model = PosADANet(1, 4, num_controls, padding='zeros', bilinear=True).to(device)
    model.load_state_dict(torch.load(path, map_location=device))
    model.eval()
    return model
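
# Note: @st.cache memoizes load_model across Streamlit reruns, so the download
# and state-dict load run only once per (path, num_controls, url) argument set.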


def load_dict_data(path: str):
    """
    Load a JSON file.
    :param path: path to the JSON file
    :return: dict with the JSON data
    """
    with open(path, 'r') as file:
        data = json.load(file)
    return data


def to_rgb(hex_color: str):
    """
    Convert a color from hex format to RGB format.
    :param hex_color: color hex string, e.g. '#E1E1E1'
    :return: list of three floats, one per RGB channel, each in [0, 1]
    """
    h = hex_color.lstrip('#')
    return [float(int(h[i:i + 2], 16)) / 255 for i in (0, 2, 4)]
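# e.g. to_rgb('#E1E1E1') -> [0.8823529411764706] * 3  (each channel is 0xE1 == 225, over 255)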


st.title('Z2P - Demo')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
st.subheader('Settings')
# Load settings for the pretrained models and the predefined demo point clouds
model_data = load_dict_data('models/default_settings.json')
pc_data = load_dict_data('point_clouds/default_settings.json')
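# Expected JSON structure (inferred from the fields read below):
#   models/default_settings.json:       {model_name: {'path', 'url', 'len_style', 'point_radius'}}
#   point_clouds/default_settings.json: {pc_name: {'path', 'scale', 'rx', 'ry', 'rz', 'dy'}}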
col1_head, col2_head = st.columns(2)
model_key = col2_head.radio(
    'Choose Model',
    model_data.keys())
pc_key = col2_head.radio(
    'Choose Point Cloud',
    pc_data.keys())
uploaded_file = col2_head.file_uploader('Upload Your Own Point Cloud (.xyz, .obj)')
if pc_key == FILE_PC_KEY:
    # Use point cloud uploaded by user
    if uploaded_file is not None:
        txt = uploaded_file.getvalue().decode("utf-8")
        pc = util.xyz2tensor(txt, append_normals=True)
    else:
        st.warning('Please upload a .xyz or .obj file')
        st.stop()
else:
    # Load demo point cloud
    pc = util.read_xyz_file(pc_data[pc_key]['path'])
st.header('Input')
col1, col2 = st.columns(2)
# parameters for point cloud spatial transformations
col2.subheader("Point Cloud Transformations")
scale = col2.slider('Scale', min_value=0.0, max_value=5.0, value=pc_data[pc_key]['scale'])
rx = col2.slider('X-Rotation', min_value=-math.pi, max_value=math.pi, value=pc_data[pc_key]['rx'])
ry = col2.slider('Y-Rotation', min_value=-math.pi, max_value=math.pi, value=pc_data[pc_key]['ry'])
rz = col2.slider('Z-Rotation', min_value=-math.pi, max_value=math.pi, value=pc_data[pc_key]['rz'])
dy = col2.slider('Height', min_value=0, max_value=500, value=pc_data[pc_key]['dy'])
col1.subheader("Input Z-Buffer")
# apply transformations
pc = render_util.rotate_pc(pc, rx, ry, rz)
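# Interactive 3D scatter preview of the rotated point cloud (z is negated for display orientation)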
trace1 = [go.Scatter3d(x=pc[:, 0], y=pc[:, 1], z=-pc[:, 2], mode="markers",
                       marker=dict(
                           symbol="circle",
                           size=1,
                           color=point_color))]
fig = go.Figure(trace1, layout=go.Layout())
col1_head.plotly_chart(fig, use_container_width=True)
# Project the point cloud and rasterize its z-buffer
zbuffer = render_util.draw_pc(pc, radius=model_data[model_key]['point_radius'], dy=dy, scale=scale)
# Show the input z-buffer in Streamlit, normalized to [0, 1] for display
col1.image(zbuffer / zbuffer.max(), use_column_width=True)
zbuffer: torch.Tensor = torch.from_numpy(zbuffer).float().to(device)
st.header('Result')
len_style = model_data[model_key]['len_style']
# Load pretrained model
model = load_model(model_data[model_key]['path'], len_style, model_data[model_key]['url'])
col1, col2 = st.columns(2)
col2.subheader('Visualization Controls')
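# Reshape the z-buffer for the network, assuming draw_pc returns an (H, W) depth map:
# add a channel axis, move it to the front, then add a batch axis -> (1, 1, H, W)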
zbuffer = zbuffer.unsqueeze(-1).permute(2, 0, 1)
zbuffer: torch.Tensor = zbuffer.float().to(device).unsqueeze(0)
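# Style/control vector layout, filled in by the widgets below:
#   [0:3] RGB color, [3] light radius delta, [4] light phi delta, [5] light theta delta,
#   plus [6] metallic and [7] roughness for the metallic-roughness model (len_style == 8)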
style = torch.zeros(len_style, dtype=zbuffer.dtype, device=device)
# Pick color and light direction visualization parameters
hex_color = col2.color_picker('Pick A Color', DEFAULT_COLOR)
style[0], style[1], style[2] = to_rgb(hex_color)
style[:3] = style[:3].clip(0.0, 0.9)
# Light direction
style[3] = col2.slider('Light Radius', min_value=-1.0, max_value=1.0, value=0.0)  # delta_r
style[4] = col2.slider('Light Phi', min_value=-math.pi / 4, max_value=math.pi / 4, value=0.0)  # delta_phi
style[5] = col2.slider('Light Theta', min_value=-math.pi / 4, max_value=math.pi / 4, value=0.0)  # delta_theta
# Extra controls for the metallic-roughness model
if len_style == 8:
    style[6] = col2.slider('Metallic', min_value=0.0, max_value=1.0, value=0.5)
    style[7] = col2.slider('Roughness', min_value=0.0, max_value=1.0, value=0.5)
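# add a batch dimension: (len_style,) -> (1, len_style), matching the z-buffer's batch axis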
style = style.unsqueeze(0)
# generate image with pretrained model
with torch.no_grad():
    generated = model(zbuffer.float(), style)
# composite a white background behind the object using the alpha map,
# and embed the input color as a swatch in the bottom-right corner
generated = util.embed_color(generated.detach(), style[:, :3], box_size=50)
rendered = generated[0].permute(1, 2, 0).cpu().numpy()
# show the image in streamlit
col1.image(rendered.clip(0, 1), use_column_width=True)