"""Streamlit demo for Z2P: render a point cloud's z-buffer with a pretrained PosADANet.

The user picks a pretrained model and a point cloud (a bundled demo or an uploaded
file), adjusts spatial transformations and visualization controls, and the app shows
the input z-buffer next to the model's generated image.
"""
import json
import math
from pathlib import Path

import gdown
import plotly.graph_objects as go
import streamlit as st
import torch

import render_util
import util
from models import PosADANet

point_color = "rgb(30, 20, 160)"
FILE_PC_KEY = 'File'
DEFAULT_COLOR = '#E1E1E1'

# Computation device. It is defined before load_model because that function reads it as a global.
if torch.cuda.is_available():
    device = torch.device('cuda', torch.cuda.current_device())
else:
    device = torch.device('cpu')


@st.cache_resource
def load_model(path: str, num_controls: int, url: str):
    """
    Load model from memory, or download from drive
    :param path: path to save/load the pretrained model
    :param num_controls: length of style/control vector the model requires (6 for regular, 8 for metallic roughness)
    :param url: google drive url to download the model if its not already downloaded
    :return: returns the pretrained model (moved to the module-level ``device``, in eval mode)
    :raises FileNotFoundError: if the weights are missing and the download did not produce them
    """
    model_path = Path(path)
    if not model_path.exists():
        # Ensure the target directory exists before downloading into it.
        model_path.parent.mkdir(parents=True, exist_ok=True)
        with st.spinner('Downloading Model'):
            gdown.download(url, path, quiet=False)
        # Fail with a clear message instead of an obscure torch.load error later.
        if not model_path.exists():
            raise FileNotFoundError(f'Could not download model weights from {url} to {path}')

    model = PosADANet(1, 4, num_controls, padding='zeros', bilinear=True).to(device)
    model.load_state_dict(torch.load(path, map_location=device))
    model.eval()
    return model


def load_dict_data(path: str):
    """
    load a json file
    :param path: path to json file
    :return: dict with json data
    """
    with open(path, 'r', encoding='utf-8') as file:
        return json.load(file)


def to_rgb(hex_color: str):
    """
    convert color in hex format to rgb format
    :param hex_color: color hex string, '#RRGGBB' or shorthand '#RGB' (leading '#' optional)
    :return: list of three numbers for RGB channels between 0-1
    """
    h = hex_color.lstrip('#')
    if len(h) == 3:
        # Expand CSS-style shorthand, e.g. 'abc' -> 'aabbcc'.
        h = ''.join(c * 2 for c in h)
    return [int(h[i:i + 2], 16) / 255 for i in (0, 2, 4)]


st.title('Z2P - Demo')

st.subheader('Settings')

# Load model and pc data for info about predefined demo point clouds and pretrained models
model_data = load_dict_data('models/default_settings.json')
pc_data = load_dict_data('point_clouds/default_settings.json')

col1_head, col2_head = st.columns(2)
model_key = col2_head.radio('Choose Model', model_data.keys())
pc_key = col2_head.radio('Choose Point Cloud', pc_data.keys())
uploaded_file = col2_head.file_uploader('Upload Your Own Point Cloud (.xyz, .obj)')

if pc_key == FILE_PC_KEY:
    # Use point cloud uploaded by user
    if uploaded_file is None:
        st.warning('Please upload a .xyz or .obj file')
        st.stop()
    try:
        txt = uploaded_file.getvalue().decode("utf-8")
    except UnicodeDecodeError:
        # Uploaded content is untrusted; report instead of crashing the app.
        st.error('Could not decode the uploaded file as UTF-8 text')
        st.stop()
    pc = util.xyz2tensor(txt, append_normals=True)
else:
    # Load demo point cloud
    pc = util.read_xyz_file(pc_data[pc_key]['path'])

st.header('Input')
col1, col2 = st.columns(2)

# parameters for point cloud spatial transformations
pc_settings = pc_data[pc_key]
col2.subheader("Point Cloud Transformations")
scale = col2.slider('Scale', min_value=0.0, max_value=5.0, value=pc_settings['scale'])
rx = col2.slider('X-Rotation', min_value=-math.pi, max_value=math.pi, value=pc_settings['rx'])
ry = col2.slider('Y-Rotation', min_value=-math.pi, max_value=math.pi, value=pc_settings['ry'])
rz = col2.slider('Z-Rotation', min_value=-math.pi, max_value=math.pi, value=pc_settings['rz'])
dy = col2.slider('Height', min_value=0, max_value=500, value=pc_settings['dy'])

col1.subheader("Input Z-Buffer")

# apply transformations
pc = render_util.rotate_pc(pc, rx, ry, rz)

# 3D scatter preview of the rotated point cloud (z is negated for display)
trace1 = [go.Scatter3d(x=pc[:, 0], y=pc[:, 1], z=-pc[:, 2], mode="markers",
                       marker=dict(symbol="circle", size=1, color=point_color))]
fig = go.Figure(trace1, layout=go.Layout())
col1_head.plotly_chart(fig, use_container_width=True)

# Project and render the point z-buffer
zbuffer = render_util.draw_pc(pc, radius=model_data[model_key]['point_radius'], dy=dy, scale=scale)

# Show input z-buffer visualization in streamlit. Normalizing by the max would
# produce NaN/inf for an all-zero (or non-positive) buffer, so only normalize when it is positive.
zbuffer_max = zbuffer.max()
zbuffer_vis = zbuffer / zbuffer_max if zbuffer_max > 0 else zbuffer
col1.image(zbuffer_vis, use_column_width=True)

st.header('Result')

len_style = model_data[model_key]['len_style']
# Load pretrained model
model = load_model(model_data[model_key]['path'], len_style, model_data[model_key]['url'])

col1, col2 = st.columns(2)
col2.subheader('Visualization Controls')

# (H, W) numpy buffer -> (1, 1, H, W) float tensor on the model's device (batch, channel, H, W)
zbuffer: torch.Tensor = torch.from_numpy(zbuffer).float().to(device).unsqueeze(0).unsqueeze(0)

style = torch.zeros(len_style, dtype=zbuffer.dtype, device=device)

# Pick color and light direction visualization parameters
hex_color = col2.color_picker('Pick A Color', DEFAULT_COLOR)
style[0], style[1], style[2] = to_rgb(hex_color)
# NOTE(review): colour channels are capped at 0.9, presumably to stay within the
# range the model was trained on — TODO confirm.
style[:3] = style[:3].clip(0.0, 0.9)

# Light direction
style[3] = col2.slider('Light Radius', min_value=-1.0, max_value=1.0, value=0.0)  # delta_r
style[4] = col2.slider('Light Phi', min_value=-math.pi / 4, max_value=math.pi / 4, value=0.0)  # delta_phi
style[5] = col2.slider('Light Theta', min_value=-math.pi / 4, max_value=math.pi / 4, value=0.0)  # delta_theta

# Extra Controls for Metallic and Roughness Model
if len_style == 8:
    style[6] = col2.slider('Metallic', min_value=0.0, max_value=1.0, value=0.5)
    style[7] = col2.slider('Roughness', min_value=0.0, max_value=1.0, value=0.5)

# Add a batch dimension: (len_style,) -> (1, len_style)
style = style.unsqueeze(0)

# generate image with pretrained model
with torch.no_grad():
    generated = model(zbuffer, style)

# embed a white background behind the object using the alpha map
# as well as the color used as input in the bottom right corner
generated = util.embed_color(generated.detach(), style[:, :3], box_size=50)
# (1, C, H, W) -> (H, W, C) numpy image for display
rendered = generated[0].permute(1, 2, 0).cpu().numpy()

# show the image in streamlit
col1.image(rendered.clip(0, 1), use_column_width=True)