# zhigangjiang's picture
# Update app.py
# 643aed9
'''
@author: Zhigang Jiang
@time: 2022/05/23
@description: Gradio demo for LGT-Net that predicts the 3D room layout of an RGB panorama.
'''
import gradio as gr
import numpy as np
import os
import torch
os.system('pip install --upgrade --no-cache-dir gdown')
from PIL import Image
from utils.logger import get_logger
from config.defaults import get_config
from inference import preprocess, run_one_inference
from models.build import build_model
from argparse import Namespace
import gdown
def down_ckpt(model_cfg, ckpt_dir):
    """Download the pre-trained checkpoint for ``model_cfg`` if it is missing.

    Every supported config path is paired with a Google Drive file id. When
    ``model_cfg`` is one of them and ``<ckpt_dir>/best.pkl`` does not exist,
    ``ckpt_dir`` is created and the file is fetched with gdown. An unknown
    config or an already-present checkpoint makes this a no-op.

    Args:
        model_cfg: config file path, e.g. 'src/config/mp3d.yaml'.
        ckpt_dir: directory expected to hold ``best.pkl``.
    """
    drive_ids = [
        ['src/config/mp3d.yaml', '1o97oAmd-yEP5bQrM0eAWFPLq27FjUDbh'],
        ['src/config/zind.yaml', '1PzBj-dfDfH_vevgSkRe5kczW0GVl_43I'],
        ['src/config/pano.yaml', '1JoeqcPbm_XBPOi6O9GjjWi3_rtyPZS8m'],
        ['src/config/s2d3d.yaml', '1PfJzcxzUsbwwMal7yTkBClIFgn8IdEzI'],
        ['src/config/ablation_study/full.yaml', '1U16TxUkvZlRwJNaJnq9nAUap-BhCVIha']
    ]
    # Config paths are unique, so the first match is the only match.
    entry = next((pair for pair in drive_ids if pair[0] == model_cfg), None)
    if entry is None:
        return

    ckpt_path = os.path.join(ckpt_dir, 'best.pkl')
    if os.path.exists(ckpt_path):
        return

    logger.info(f"Downloading {entry}")
    os.makedirs(ckpt_dir, exist_ok=True)
    # The third positional argument is gdown's ``quiet`` flag.
    gdown.download(f"https://drive.google.com/uc?id={entry[1]}", ckpt_path, False)
def greet(img_path, pre_processing, weight_name, post_processing, visualization, mesh_format, mesh_resolution):
    """Run LGT-Net on one panorama and return the paths of the generated files.

    Gradio callback for the demo interface.

    Args:
        img_path: path of the input panorama image.
        pre_processing: if truthy, ``preprocess`` is run on the image and its
            vanishing-point cache is ``src/output/<img_name>_vp.txt``;
            otherwise the default ``src/demo/default_vp.txt`` is reported.
        weight_name: 'mp3d' or 'zind', selecting the pre-trained model.
        post_processing: post-processing method name, stored on ``args`` and
            passed to ``run_one_inference`` through it.
        visualization: selected 2d visualizations; 'depth-normal-gradient'
            and '2d-floorplan' are recognized.
        mesh_format: mesh file extension ('.gltf', '.obj' or '.glb').
        mesh_resolution: mesh resolution (string or int; converted with int()).

    Returns:
        List of paths: [2d visualization png, 3d mesh, 3d mesh (same path,
        once per Gradio output component), vanishing-point file, layout json].

    Raises:
        NotImplementedError: if ``weight_name`` is not a known weight.
    """
    # NOTE(review): these settings are written to the shared module-level
    # ``args``; concurrent requests would race on them.
    args.pre_processing = pre_processing
    args.post_processing = post_processing
    if weight_name == 'mp3d':
        model = mp3d_model
    elif weight_name == 'zind':
        model = zind_model
    else:
        logger.error("unknown pre-trained weight name")
        raise NotImplementedError(f"unknown pre-trained weight name: {weight_name!r}")

    # splitext keeps inner dots ('a.b.png' -> 'a.b') so distinct inputs do not
    # collide on the same output file names.
    img_name = os.path.splitext(os.path.basename(img_path))[0]
    # convert('RGB') guarantees an HxWx3 array for grayscale/palette/CMYK/RGBA
    # inputs (slicing [..., :3] on a 2-D grayscale array would cut the width
    # axis instead); the context manager closes the underlying file.
    with Image.open(img_path) as pil_img:
        img = np.array(pil_img.convert('RGB').resize((1024, 512), Image.Resampling.BICUBIC))

    vp_cache_path = 'src/demo/default_vp.txt'
    if args.pre_processing:
        vp_cache_path = os.path.join('src/output', f'{img_name}_vp.txt')
        logger.info("pre-processing ...")
        img, _ = preprocess(img, vp_cache_path=vp_cache_path)
    img = (img / 255.0).astype(np.float32)
    run_one_inference(img, model, args, img_name,
                      logger=logger, show=False,
                      show_depth='depth-normal-gradient' in visualization,
                      show_floorplan='2d-floorplan' in visualization,
                      mesh_format=mesh_format, mesh_resolution=int(mesh_resolution))

    mesh_path = os.path.join(args.output_dir, f"{img_name}_3d{mesh_format}")
    return [os.path.join(args.output_dir, f"{img_name}_pred.png"),
            mesh_path,
            mesh_path,
            vp_cache_path,
            os.path.join(args.output_dir, f"{img_name}_pred.json")]
def get_model(args):
    """Build the LGT-Net model described by the config at ``args.cfg``.

    The matching checkpoint is downloaded (via ``down_ckpt``) into the
    config's ``CKPT.DIR`` first. If CUDA is requested, either by
    ``args.device`` or by ``config.TRAIN.DEVICE``, but
    ``torch.cuda.is_available()`` is False, both are switched to "cpu".
    Note that ``args`` is mutated in place.

    Returns:
        The first of the four values returned by ``build_model(config, logger)``.
    """
    config = get_config(args)
    down_ckpt(args.cfg, config.CKPT.DIR)

    wants_cuda = 'cuda' in args.device or 'cuda' in config.TRAIN.DEVICE
    if wants_cuda and not torch.cuda.is_available():
        logger.info(f'The {args.device} is not available, will use cpu...')
        # The config is frozen after loading; unfreeze only for this override.
        config.defrost()
        config.TRAIN.DEVICE = "cpu"
        args.device = "cpu"
        config.freeze()

    net, _, _, _ = build_model(config, logger)
    return net
if __name__ == '__main__':
    logger = get_logger()
    args = Namespace(device='cuda', output_dir='src/output', visualize_3d=False, output_3d=True)
    os.makedirs(args.output_dir, exist_ok=True)

    # Both models are built once at start-up. get_model reads args.cfg and may
    # switch args.device to "cpu", which then also applies to the second model.
    args.cfg = 'src/config/mp3d.yaml'
    mp3d_model = get_model(args)
    args.cfg = 'src/config/zind.yaml'
    zind_model = get_model(args)

    description = (
        "This is a demo of the project "
        "<a href='https://github.com/zhigangjiang/LGT-Net' target='_blank'>LGT-Net</a>. "
        "It uses the Geometry-Aware Transformer Network to predict the 3d room layout of an rgb panorama."
    )

    all_visualizations = ['depth-normal-gradient', '2d-floorplan']

    # Component order must match greet's parameters.
    inputs = [
        gr.Image(type='filepath', label='input rgb panorama', value='src/demo/pano_demo1.png'),
        gr.Checkbox(label='pre-processing', value=True),
        gr.Radio(['mp3d', 'zind'],
                 label='pre-trained weight',
                 value='mp3d'),
        gr.Radio(['manhattan', 'atalanta', 'original'],
                 label='post-processing method',
                 value='manhattan'),
        gr.CheckboxGroup(all_visualizations,
                         label='2d-visualization',
                         value=list(all_visualizations)),
        gr.Radio(['.gltf', '.obj', '.glb'],
                 label='output format of 3d mesh',
                 value='.gltf'),
        gr.Radio(['128', '256', '512', '1024'],
                 label='output resolution of 3d mesh',
                 value='256'),
    ]
    # Component order must match the list returned by greet.
    outputs = [
        gr.Image(label='predicted result 2d-visualization', type='filepath'),
        gr.Model3D(label='3d mesh reconstruction', clear_color=[1.0, 1.0, 1.0, 1.0]),
        gr.File(label='3d mesh file'),
        gr.File(label='vanishing point information'),
        gr.File(label='layout json'),
    ]

    # (image, pre-processing, weight, post-processing). Every example also
    # enables all 2d visualizations and exports a 256-resolution .gltf mesh.
    example_settings = [
        ('src/demo/pano_demo1.png', True, 'mp3d', 'manhattan'),
        ('src/demo/mp3d_demo1.png', False, 'mp3d', 'manhattan'),
        ('src/demo/mp3d_demo2.png', False, 'mp3d', 'manhattan'),
        ('src/demo/mp3d_demo3.png', False, 'mp3d', 'manhattan'),
        ('src/demo/zind_demo1.png', True, 'zind', 'manhattan'),
        ('src/demo/zind_demo2.png', False, 'zind', 'atalanta'),
        ('src/demo/zind_demo3.png', True, 'zind', 'manhattan'),
        ('src/demo/other_demo1.png', False, 'mp3d', 'manhattan'),
        ('src/demo/other_demo2.png', True, 'mp3d', 'manhattan'),
    ]
    # Each row gets its own copy of the visualization list.
    examples = [[path, pre, weight, post, list(all_visualizations), '.gltf', '256']
                for path, pre, weight, post in example_settings]

    # NOTE(review): `allow_flagging` and `enable_queue` are Gradio-version
    # specific (enable_queue is not accepted by launch() in Gradio 4.x) —
    # confirm against the pinned Gradio version.
    demo = gr.Interface(fn=greet,
                        inputs=inputs,
                        outputs=outputs,
                        examples=examples,
                        title='LGT-Net',
                        allow_flagging="never",
                        cache_examples=False,
                        description=description)
    demo.launch(debug=True, enable_queue=False)