Spaces:
Build error
Build error
'''
@author: Zhigang Jiang
@time: 2022/05/23
@description: Gradio demo for LGT-Net that predicts the 3D room layout of an RGB panorama.
'''
import gradio as gr | |
import numpy as np | |
import os | |
import torch | |
os.system('pip install --upgrade --no-cache-dir gdown') | |
from PIL import Image | |
from utils.logger import get_logger | |
from config.defaults import get_config | |
from inference import preprocess, run_one_inference | |
from models.build import build_model | |
from argparse import Namespace | |
import gdown | |
def down_ckpt(model_cfg, ckpt_dir):
    """Download the pre-trained checkpoint for ``model_cfg`` unless it is already present.

    The checkpoint is fetched from Google Drive and saved as ``<ckpt_dir>/best.pkl``.

    Args:
        model_cfg: Config file path (e.g. 'src/config/mp3d.yaml') that selects
            which checkpoint to fetch.
        ckpt_dir: Directory for the checkpoint; created if a download is needed.
    """
    # Config path -> Google Drive file id of its pre-trained weights.
    model_ids = {
        'src/config/mp3d.yaml': '1o97oAmd-yEP5bQrM0eAWFPLq27FjUDbh',
        'src/config/zind.yaml': '1PzBj-dfDfH_vevgSkRe5kczW0GVl_43I',
        'src/config/pano.yaml': '1JoeqcPbm_XBPOi6O9GjjWi3_rtyPZS8m',
        'src/config/s2d3d.yaml': '1PfJzcxzUsbwwMal7yTkBClIFgn8IdEzI',
        'src/config/ablation_study/full.yaml': '1U16TxUkvZlRwJNaJnq9nAUap-BhCVIha',
    }
    file_id = model_ids.get(model_cfg)
    if file_id is None:
        # Without this warning an unknown config is skipped silently and no
        # checkpoint is fetched for it.
        logger.warning(f"No pre-trained checkpoint registered for {model_cfg}")
        return
    path = os.path.join(ckpt_dir, 'best.pkl')
    if os.path.exists(path):
        return
    logger.info(f"Downloading {model_cfg} checkpoint ({file_id}) to {path}")
    os.makedirs(ckpt_dir, exist_ok=True)
    # Some gdown versions report a failed download by returning None instead of raising.
    if gdown.download(f"https://drive.google.com/uc?id={file_id}", path, quiet=False) is None:
        logger.error(f"Failed to download checkpoint for {model_cfg} to {path}")
def greet(img_path, pre_processing, weight_name, post_processing, visualization, mesh_format, mesh_resolution):
    """Run LGT-Net layout prediction on one panorama for the Gradio interface.

    Args:
        img_path: File path of the uploaded RGB panorama.
        pre_processing: Whether to run ``preprocess`` before inference; when
            enabled, a per-image vanishing-point cache path under
            ``args.output_dir`` is passed to it and returned.
        weight_name: Pre-trained weights to use, 'mp3d' or 'zind'.
        post_processing: Post-processing method name ('manhattan', 'atalanta'
            or 'original' in this UI); stored on the global ``args``.
        visualization: Selected 2D visualizations; may contain
            'depth-normal-gradient' and/or '2d-floorplan'.
        mesh_format: Extension of the exported 3D mesh, e.g. '.gltf'.
        mesh_resolution: Mesh resolution (string or int), converted with ``int``.

    Returns:
        A list of five paths: the 2D prediction image, the 3D mesh (twice: once
        for the 3D viewer, once as a downloadable file), the vanishing-point
        file, and the layout JSON.

    Raises:
        ValueError: If no image was provided.
        NotImplementedError: If ``weight_name`` is not a known weight set.
    """
    if img_path is None:
        # gr.Image(type='filepath') passes None when the input has been cleared.
        raise ValueError("no input panorama was provided")

    # Pick the model before touching the shared ``args`` so an invalid request
    # leaves it unchanged.
    if weight_name == 'mp3d':
        model = mp3d_model
    elif weight_name == 'zind':
        model = zind_model
    else:
        logger.error("unknown pre-trained weight name")
        raise NotImplementedError(f"unknown pre-trained weight name: {weight_name}")

    # ``args`` is passed to run_one_inference, presumably to read these options.
    args.pre_processing = pre_processing
    args.post_processing = post_processing

    # splitext keeps dots inside the stem ("room.v2.png" -> "room.v2"), so
    # differently named uploads do not collide on the same output files.
    img_name = os.path.splitext(os.path.basename(img_path))[0]
    with Image.open(img_path) as src:
        # Grayscale/palette/CMYK images would not yield an (H, W, 3) array;
        # RGB and RGBA are resized as-is so their result is unchanged.
        rgb = src if src.mode in ('RGB', 'RGBA') else src.convert('RGB')
        img = np.array(rgb.resize((1024, 512), Image.Resampling.BICUBIC))[..., :3]

    vp_cache_path = 'src/demo/default_vp.txt'
    if args.pre_processing:
        vp_cache_path = os.path.join(args.output_dir, f'{img_name}_vp.txt')
        logger.info("pre-processing ...")
        img, _ = preprocess(img, vp_cache_path=vp_cache_path)
    img = (img / 255.0).astype(np.float32)
    run_one_inference(img, model, args, img_name,
                      logger=logger, show=False,
                      show_depth='depth-normal-gradient' in visualization,
                      show_floorplan='2d-floorplan' in visualization,
                      mesh_format=mesh_format, mesh_resolution=int(mesh_resolution))

    mesh_path = os.path.join(args.output_dir, f"{img_name}_3d{mesh_format}")
    return [os.path.join(args.output_dir, f"{img_name}_pred.png"),
            mesh_path,
            mesh_path,
            vp_cache_path,
            os.path.join(args.output_dir, f"{img_name}_pred.json")]
def get_model(args):
    """Build the LGT-Net model for the config named by ``args.cfg``.

    Ensures the matching pre-trained checkpoint is downloaded (via
    ``down_ckpt``) and falls back to CPU when CUDA is requested but
    unavailable. The fallback rewrites ``args.device`` to "cpu", so it
    persists for later calls that reuse the same ``args``.

    Args:
        args: Namespace with at least ``cfg`` (config path) and ``device``;
            also passed to ``get_config``.

    Returns:
        The first element returned by ``build_model``.
    """
    config = get_config(args)
    down_ckpt(args.cfg, config.CKPT.DIR)
    wants_cuda = 'cuda' in args.device or 'cuda' in config.TRAIN.DEVICE
    if wants_cuda and not torch.cuda.is_available():
        # Name the device that actually asked for CUDA: after an earlier
        # fallback args.device is already "cpu" and only the config requests CUDA.
        requested = args.device if 'cuda' in args.device else config.TRAIN.DEVICE
        logger.info(f'The {requested} is not available, will use cpu ...')
        config.defrost()
        args.device = "cpu"
        config.TRAIN.DEVICE = "cpu"
        config.freeze()
    model, _, _, _ = build_model(config, logger)
    return model
if __name__ == '__main__':
    # Module-level globals read by greet() / get_model(): logger, args and the two models.
    logger = get_logger()
    args = Namespace(device='cuda', output_dir='src/output', visualize_3d=False, output_3d=True)
    os.makedirs(args.output_dir, exist_ok=True)

    # Load both pre-trained models up front; greet() picks one per request.
    args.cfg = 'src/config/mp3d.yaml'
    mp3d_model = get_model(args)
    args.cfg = 'src/config/zind.yaml'
    zind_model = get_model(args)

    description = (
        "This demo of the project "
        "<a href='https://github.com/zhigangjiang/LGT-Net' target='_blank'>LGT-Net</a>. "
        "It uses the Geometry-Aware Transformer Network to predict the 3d room layout of an rgb panorama."
    )

    # Input widgets, in the same order as greet()'s parameters.
    inputs = [
        gr.Image(type='filepath', label='input rgb panorama', value='src/demo/pano_demo1.png'),
        gr.Checkbox(label='pre-processing', value=True),
        gr.Radio(['mp3d', 'zind'], label='pre-trained weight', value='mp3d'),
        gr.Radio(['manhattan', 'atalanta', 'original'],
                 label='post-processing method', value='manhattan'),
        gr.CheckboxGroup(['depth-normal-gradient', '2d-floorplan'],
                         label='2d-visualization',
                         value=['depth-normal-gradient', '2d-floorplan']),
        gr.Radio(['.gltf', '.obj', '.glb'], label='output format of 3d mesh', value='.gltf'),
        gr.Radio(['128', '256', '512', '1024'], label='output resolution of 3d mesh', value='256'),
    ]

    # Output widgets, matching the five paths greet() returns.
    outputs = [
        gr.Image(label='predicted result 2d-visualization', type='filepath'),
        gr.Model3D(label='3d mesh reconstruction', clear_color=[1.0, 1.0, 1.0, 1.0]),
        gr.File(label='3d mesh file'),
        gr.File(label='vanishing point information'),
        gr.File(label='layout json'),
    ]

    # Example rows: (image, pre-processing, weights, post-processing); every
    # example enables both visualizations and exports a 256-resolution .gltf.
    example_rows = [
        ('src/demo/pano_demo1.png', True, 'mp3d', 'manhattan'),
        ('src/demo/mp3d_demo1.png', False, 'mp3d', 'manhattan'),
        ('src/demo/mp3d_demo2.png', False, 'mp3d', 'manhattan'),
        ('src/demo/mp3d_demo3.png', False, 'mp3d', 'manhattan'),
        ('src/demo/zind_demo1.png', True, 'zind', 'manhattan'),
        ('src/demo/zind_demo2.png', False, 'zind', 'atalanta'),
        ('src/demo/zind_demo3.png', True, 'zind', 'manhattan'),
        ('src/demo/other_demo1.png', False, 'mp3d', 'manhattan'),
        ('src/demo/other_demo2.png', True, 'mp3d', 'manhattan'),
    ]
    examples = [
        [image, pre, weight, post, ['depth-normal-gradient', '2d-floorplan'], '.gltf', '256']
        for image, pre, weight, post in example_rows
    ]

    demo = gr.Interface(fn=greet, inputs=inputs, outputs=outputs, examples=examples,
                        title='LGT-Net', allow_flagging="never", cache_examples=False,
                        description=description)
    demo.launch(debug=True, enable_queue=False)