ppsurf / app.py
perler's picture
copy later
ef0052a
raw
history blame
7.96 kB
#!/usr/bin/env python
from __future__ import annotations
import sys
import os
import datetime
import gradio as gr
import spaces
@spaces.GPU(duration=60 * 3)
def run_on_gpu(input_point_cloud: gr.utils.NamedString,
               gen_resolution_global: int,
               padding_factor: float,
               gen_subsample_manifold_iter: int,
               gen_refine_iter: int) -> str:
    """Reconstruct a surface mesh from an uploaded point cloud with PPSurf.

    Runs ``ppsurf/pps.py predict`` in a child process with the given
    generation settings, then copies the resulting PLY next to the uploaded
    file so Gradio can serve it.

    Args:
        input_point_cloud: Upload from ``gr.File``; its ``.name`` is the path
            of the uploaded file on disk. ``None`` when nothing was uploaded.
        gen_resolution_global: Passed as
            ``--model.init_args.gen_resolution_global``.
        padding_factor: Passed as ``--data.init_args.padding_factor``.
        gen_subsample_manifold_iter: Passed as
            ``--model.init_args.gen_subsample_manifold_iter``.
        gen_refine_iter: Passed as ``--model.init_args.gen_refine_iter``.

    Returns:
        Path of the reconstructed mesh, ``<input path without extension>_ppsurf.ply``.

    Raises:
        gr.Error: If no file was uploaded, the PPSurf process could not be
            started or exited with a non-zero status, or it produced no
            output mesh.
    """
    import shutil
    import subprocess
    import uuid

    print('Started inference at {}'.format(datetime.datetime.now()))
    print('Inputs:', input_point_cloud, gen_resolution_global, padding_factor,
          gen_subsample_manifold_iter, gen_refine_iter)
    print('Types:', type(input_point_cloud), type(gen_resolution_global), type(padding_factor),
          type(gen_subsample_manifold_iter), type(gen_refine_iter))

    # gr.File delivers None when the button is clicked without an upload.
    if input_point_cloud is None:
        raise gr.Error('Please upload a point cloud or mesh file first.')

    # Make the PPSurf sources importable; guarded so repeated requests do not
    # keep appending the same entry to sys.path.
    ppsurf_dir = os.path.abspath('ppsurf')
    if ppsurf_dir not in sys.path:
        sys.path.append(ppsurf_dir)

    in_file = str(input_point_cloud.name)
    out_dir = '/tmp/outputs/{}'.format(uuid.uuid4().hex)  # unique scratch dir per request
    in_file_basename = os.path.basename(in_file)
    # NOTE(review): assumes pps.py writes <results_dir>/<input name>/<input name>.ply — confirm.
    out_file = os.path.join(out_dir, in_file_basename, in_file_basename + '.ply')
    out_file_gradio = os.path.splitext(in_file)[0] + '_ppsurf.ply'
    print('in_file:', in_file)
    print('out_dir:', out_dir)
    print('out_file:', out_file)
    print('out_file_gradio:', out_file_gradio)
    os.makedirs(out_dir, exist_ok=True)

    model_path = 'models/ppsurf_50nn/version_0/checkpoints/last.ckpt'
    # Slider values may arrive as floats (e.g. 129.0); integer options are cast
    # so the CLI receives '129' rather than '129.0'.
    cmd = [
        sys.executable, 'ppsurf/pps.py', 'predict',
        '-c', 'ppsurf/configs/poco.yaml',
        '-c', 'ppsurf/configs/ppsurf.yaml',
        '-c', 'ppsurf/configs/ppsurf_50nn.yaml',
        '--ckpt_path', model_path,
        '--data.init_args.in_file', in_file,
        '--model.init_args.results_dir', out_dir,
        '--trainer.logger', 'False',
        '--trainer.devices', '1',
        '--model.init_args.gen_resolution_global', str(int(gen_resolution_global)),
        '--data.init_args.padding_factor', str(padding_factor),
        '--model.init_args.gen_subsample_manifold_iter', str(int(gen_subsample_manifold_iter)),
        '--model.init_args.gen_refine_iter', str(int(gen_refine_iter)),
    ]

    try:
        # A child process is used so PPSurf can spawn its own workers.
        # sys.executable guarantees the same interpreter/environment as this app.
        try:
            completed = subprocess.run(cmd, check=False)
        except OSError as e:
            raise gr.Error('Reconstruction failed:\n{}'.format(e)) from e
        print('Finished inference at {}'.format(datetime.datetime.now()))

        if completed.returncode != 0:
            raise gr.Error('Reconstruction failed: PPSurf exited with code {}.'.format(
                completed.returncode))
        if not os.path.isfile(out_file):
            raise gr.Error('Reconstruction failed: PPSurf produced no output mesh.')

        shutil.copyfile(src=out_file, dst=out_file_gradio)
    finally:
        # The result lives next to the upload now; remove the per-request
        # scratch directory so /tmp/outputs does not grow without bound.
        shutil.rmtree(out_dir, ignore_errors=True)

    return out_file_gradio
def main():
    """Build the PPSurf Gradio demo page and launch it.

    The page shows a Markdown header and two description columns, a file
    upload with four reconstruction-setting sliders, a 3D viewer for the
    result, and a button that calls ``run_on_gpu`` with the upload and the
    slider values and shows the returned mesh in the viewer.
    """
    # Markdown texts rendered at the top of the page (header + two columns).
    description_header = '# PPSurf: Combining Patches and Point Convolutions for Detailed Surface Reconstruction'
    description_col0 = '''## [Github](https://github.com/cg-tuwien/ppsurf)
Supported input file formats:
- PLY, STL, OBJ and other mesh files,
- XYZ as whitespace-separated text file,
- NPY and NPZ (key='arr_0'),
- LAS and LAZ (version 1.0-1.4), COPC and CRS.
Best results for 50k-250k points.
'''
    description_col1 = '''## [Project Info](https://www.cg.tuwien.ac.at/research/publications/2024/erler_2024_ppsurf/)
This method is meant for scans of single and few objects.
Quality for scenes and landscapes will be lower.
Reconstructions with default settings will be done in about 30 seconds.
Inference will be terminated after 180 seconds.
'''
    # can't render many input types directly in Gradio Model3D
    # so we need to convert to supported format
    # Gradio can't draw point clouds anyway (2024-03-04), so we skip this for now
    # def convert_to_ply(input_point_cloud_upload: gr.utils.NamedString):
    #
    #     # add absolute path to import dirs
    #     import sys
    #     import os
    #     sys.path.append(os.path.abspath('ppsurf'))
    #
    #     # import os
    #     # os.chdir('ppsurf')
    #
    #     print('Inputs:', input_point_cloud_upload, type(input_point_cloud_upload))
    #     input_shape: str = input_point_cloud_upload.name
    #     if not input_shape.endswith('.ply'):
    #         # load file
    #         from ppsurf.source.occupancy_data_module import OccupancyDataModule
    #         pts_np = OccupancyDataModule.load_pts(input_shape)
    #
    #         # convert to ply
    #         import trimesh
    #         mesh = trimesh.Trimesh(vertices=pts_np[:, :3])
    #         input_shape = input_shape + '.ply'
    #         mesh.export(input_shape)
    #
    #         print('ls:\n', subprocess.run(['ls', os.path.dirname(input_shape)]))
    #
    #     # show in viewer
    #     print(type(input_tabs))
    #     # print(type(input_point_cloud_viewer))
    #     # input_tabs.selected = 'pc_viewer'
    #     # input_point_cloud_viewer.value = input_shape
    # Components are laid out in creation order inside the nested with-blocks,
    # so the order of statements below defines the page layout.
    with gr.Blocks(css='style.css') as demo:
        # descriptions
        gr.Markdown(description_header)
        with gr.Row():
            with gr.Column():
                gr.Markdown(description_col0)
            with gr.Column():
                gr.Markdown(description_col1)
        # inputs and outputs
        with gr.Row():
            # Left column: input file and reconstruction settings.
            with gr.Column():
                input_point_cloud_upload = gr.File(show_label=False, file_count='single')
                # with gr.Tabs() as input_tabs:  # re-enable when Gradio supports point clouds
                #     with gr.TabItem(label='Input Point Cloud Upload', id='pc_upload'):
                #         input_point_cloud_upload.upload(
                #             fn=convert_to_ply,
                #             inputs=[
                #                 input_point_cloud_upload,
                #             ],
                #             outputs=[
                #                 # input_point_cloud_viewer,  # not available here
                #             ])
                #     with gr.TabItem(label='Input Point Cloud Viewer', id='pc_viewer'):
                #         input_point_cloud_viewer = gr.Model3D(show_label=False)
                # The slider values are forwarded by run_on_gpu as PPSurf CLI options.
                # step=2 with an odd minimum keeps the grid resolution odd.
                gen_resolution_global = gr.Slider(
                    label='Grid Resolution (larger for more details)',
                    minimum=17, maximum=513, value=129, step=2)
                padding_factor = gr.Slider(
                    label='Padding Factor (larger if object is cut off at boundaries)',
                    minimum=0, maximum=1.0, value=0.05, step=0.05)
                gen_subsample_manifold_iter = gr.Slider(
                    label='Subsample Manifold Iterations (larger for larger point clouds)',
                    minimum=3, maximum=30, value=10, step=1)
                gen_refine_iter = gr.Slider(
                    label='Edge Refinement Iterations (larger for more details)',
                    minimum=3, maximum=30, value=10, step=1)
            # Right column: viewer for the mesh path returned by run_on_gpu.
            with gr.Column():
                result_3d_model = gr.Model3D(label='Reconstructed 3D model')
                # progress_text = gr.Text(label='Progress')
                # with gr.Tabs():
                #     with gr.TabItem(label='Reconstructed 3D model'):
                #         result_3d_model = gr.Model3D(show_label=False)
                #     with gr.TabItem(label='Output mesh file'):
                #         output_file = gr.File(show_label=False)
        with gr.Row():
            run_button = gr.Button('Reconstruct with PPSurf')
        # Event listeners must be registered inside the Blocks context.
        # The inputs list order must match run_on_gpu's parameter order.
        run_button.click(fn=run_on_gpu,
                         inputs=[
                             input_point_cloud_upload,
                             gen_resolution_global,
                             padding_factor,
                             gen_subsample_manifold_iter,
                             gen_refine_iter,
                         ],
                         outputs=[
                             result_3d_model,
                             # output_file,
                             # progress_text,
                         ])
    # Limit the request queue to 5 waiting events, then start the server.
    demo.queue(max_size=5)
    demo.launch(debug=True)
if __name__ == '__main__':
    # Log only the environment variable *names*: on Hugging Face Spaces the
    # values can include secrets (e.g. access tokens), and printing them would
    # copy those secrets into the Space logs.
    print('Environment variables:', sorted(os.environ))
    main()