# NOTE(review): removed web-scrape artifacts ("Spaces:", "Runtime error" x2)
# that preceded the module docstring — they were not Python code.
""" | |
@Date: 2021/09/19 | |
@description: | |
""" | |
import json | |
import os | |
import argparse | |
import cv2 | |
import numpy as np | |
import torch | |
import matplotlib.pyplot as plt | |
import glob | |
from tqdm import tqdm | |
from PIL import Image | |
from config.defaults import merge_from_file, get_config | |
from dataset.mp3d_dataset import MP3DDataset | |
from dataset.zind_dataset import ZindDataset | |
from models.build import build_model | |
from loss import GradLoss | |
from postprocessing.post_process import post_process | |
from preprocessing.pano_lsd_align import panoEdgeDetection, rotatePanorama | |
from utils.boundary import corners2boundaries, layout2depth | |
from utils.conversion import depth2xyz | |
from utils.logger import get_logger | |
from utils.misc import tensor2np_d, tensor2np | |
from evaluation.accuracy import show_grad | |
from models.lgt_net import LGT_Net | |
from utils.writer import xyz2json | |
from visualization.boundary import draw_boundaries | |
from visualization.floorplan import draw_floorplan, draw_iou_floorplan | |
from visualization.obj3d import create_3d_obj | |
def parse_option():
    """Parse command-line options for the inference script.

    Returns the parsed ``argparse.Namespace`` with ``mode`` forced to
    ``'test'``; all options are echoed to stdout for traceability.
    """
    parser = argparse.ArgumentParser(
        description='Panorama Layout Transformer training and evaluation script')
    parser.add_argument('--img_glob', type=str, required=True,
                        help='image glob path')
    parser.add_argument('--cfg', type=str, required=True, metavar='FILE',
                        help='path of config file')
    parser.add_argument('--post_processing', type=str, default='manhattan',
                        choices=['manhattan', 'atalanta', 'original'],
                        help='post-processing type')
    parser.add_argument('--output_dir', type=str, default='src/output',
                        help='path of output')
    parser.add_argument('--visualize_3d', action='store_true',
                        help='visualize_3d')
    parser.add_argument('--output_3d', action='store_true',
                        help='output_3d')
    parser.add_argument('--device', type=str, default='cuda',
                        help='device')

    args = parser.parse_args()
    args.mode = 'test'  # this script is inference-only

    print("arguments:")
    for key in vars(args):
        print(key, ":", getattr(args, key))
    print("-" * 50)
    return args
def visualize_2d(img, dt, show_depth=True, show_floorplan=True, show=False, save_path=None):
    """Compose a 2D visualization of the prediction on top of the panorama.

    Draws the raw predicted boundaries in green and (when present) the
    post-processed boundaries in red, optionally appends a depth-gradient
    strip at the bottom and a floorplan panel on the right.

    Returns the composed image as a float array (values in [0, 1]).
    """
    dt_np = tensor2np_d(dt)
    dt_depth = dt_np['depth'][0]
    dt_xyz = depth2xyz(np.abs(dt_depth))
    dt_ratio = dt_np['ratio'][0][0]

    # Raw network prediction, drawn in green.
    boundaries = corners2boundaries(dt_ratio, corners_xyz=dt_xyz, step=None,
                                    visible=False, length=img.shape[1])
    vis_img = draw_boundaries(img, boundary_list=boundaries, boundary_color=[0, 1, 0])

    # Post-processed layout (if computed), drawn in red on top.
    if 'processed_xyz' in dt:
        boundaries = corners2boundaries(dt_ratio, corners_xyz=dt['processed_xyz'][0],
                                        step=None, visible=False, length=img.shape[1])
        vis_img = draw_boundaries(vis_img, boundary_list=boundaries,
                                  boundary_color=[1, 0, 0])

    if show_depth:
        grad_strip = show_depth_normal_grad(dt)
        strip_h = grad_strip.shape[0]
        # Replace the bottom rows of the panorama with the gradient strip.
        vis_img = np.concatenate([vis_img[0:-strip_h, :, :], grad_strip], axis=0)
        # vis_img = dt_grad_img.transpose(1, 2, 0)[100:]

    if show_floorplan:
        if 'processed_xyz' in dt:
            floorplan = draw_iou_floorplan(dt['processed_xyz'][0][..., ::2], dt_xyz[..., ::2],
                                           dt_board_color=[1, 0, 0, 1],
                                           gt_board_color=[0, 1, 0, 1])
        else:
            floorplan = show_alpha_floorplan(dt_xyz, border_color=[0, 1, 0, 1])
        # Trim the floorplan's left/right margins before appending it.
        vis_img = np.concatenate([vis_img, floorplan[:, 60:-60, :]], axis=1)

    if show:
        plt.imshow(vis_img)
        plt.show()
    if save_path:
        Image.fromarray((vis_img * 255).astype(np.uint8)).save(save_path)
    return vis_img
def preprocess(img_ori, q_error=0.7, refine_iter=3, vp_cache_path=None):
    """Align a panorama to its Manhattan vanishing points.

    Args:
        img_ori: input panorama image (HWC array).
        q_error: quantization error passed to ``panoEdgeDetection``.
        refine_iter: number of VP refinement iterations.
        vp_cache_path: optional text-file path; when it exists the vanishing
            points are read from it instead of being re-detected, and the
            (possibly re-)detected VPs are written back to it afterwards.

    Returns:
        (aligned_image, vp) where ``vp`` is the 3x3 vanishing-point array.
    """
    # BUG FIX: os.path.exists(None) raises TypeError — the original crashed
    # whenever the default vp_cache_path=None was used. Only consult the
    # cache when a path was actually supplied.
    if vp_cache_path is not None and os.path.exists(vp_cache_path):
        with open(vp_cache_path) as f:
            vp = [[float(v) for v in line.rstrip().split(' ')] for line in f]
        vp = np.array(vp)
    else:
        # VP detection and line segment extraction
        _, vp, _, _, _, _, _ = panoEdgeDetection(img_ori,
                                                 qError=q_error,
                                                 refineIter=refine_iter)

    i_img = rotatePanorama(img_ori, vp[2::-1])

    if vp_cache_path is not None:
        with open(vp_cache_path, 'w') as f:
            for i in range(3):
                f.write('%.6f %.6f %.6f\n' % (vp[i, 0], vp[i, 1], vp[i, 2]))
    return i_img, vp
def show_depth_normal_grad(dt):
    """Render the predicted depth's gradient as a 1024x60 color strip."""
    device = dt['depth'].device
    grad_conv = GradLoss().to(device).grad_conv
    strip = show_grad(dt['depth'][0], grad_conv, 50)
    # Nearest-neighbor resize keeps the per-column gradient colors crisp.
    return cv2.resize(strip, (1024, 60), interpolation=cv2.INTER_NEAREST)
def show_alpha_floorplan(dt_xyz, side_l=512, border_color=None):
    """Draw a floorplan with a translucent fill composited onto a grey backdrop.

    Returns the composited RGB image as a float array in [0, 1].
    """
    if border_color is None:
        border_color = [1, 0, 0, 1]
    fill_color = [0.2, 0.2, 0.2, 0.2]

    # Foreground: floorplan rendered with an alpha channel.
    floorplan_rgba = draw_floorplan(xz=dt_xyz[..., ::2], fill_color=fill_color,
                                    border_color=border_color, side_l=side_l,
                                    show=False, center_color=[1, 0, 0, 1])
    fg = Image.fromarray((floorplan_rgba * 255).astype(np.uint8), mode='RGBA')

    # Background: opaque light-grey canvas of the same size.
    backdrop = np.zeros([side_l, side_l, len(fill_color)], dtype=np.float32)
    backdrop[..., :] = [0.8, 0.8, 0.8, 1]
    bg = Image.fromarray((backdrop * 255).astype(np.uint8), mode='RGBA')

    composed = Image.alpha_composite(bg, fg).convert("RGB")
    return np.array(composed) / 255.0
def save_pred_json(xyz, ration, save_path):
    """Serialize the predicted layout (corner xyz and camera ratio) to JSON.

    Note: ``ration`` looks like a typo for "ratio", but renaming the
    parameter would break keyword callers, so it is kept.
    """
    # xyz[..., -1] = -xyz[..., -1]
    json_data = xyz2json(xyz, ration)
    serialized = json.dumps(json_data, indent=4)
    with open(save_path, 'w') as f:
        f.write(serialized + '\n')
    return json_data
def inference():
    """Run layout prediction for every image matched by ``args.img_glob``.

    Relies on module-level globals set under ``__main__``: ``img_paths``,
    ``logger``, ``model`` and ``args``.
    """
    if len(img_paths) == 0:
        logger.error('No images found')
        return
    bar = tqdm(img_paths, ncols=100)
    for img_path in bar:
        if not os.path.isfile(img_path):
            # BUG FIX: corrected garbled error message ("not is file").
            logger.error(f'The {img_path} is not a file')
            continue
        name = os.path.basename(img_path).split('.')[0]
        bar.set_description(name)
        img = np.array(Image.open(img_path).resize((1024, 512), Image.Resampling.BICUBIC))[..., :3]
        if args.post_processing is not None and 'manhattan' in args.post_processing:
            bar.set_description("Preprocessing")
            # Manhattan post-processing needs a VP-aligned panorama; cache the
            # detected vanishing points next to the outputs.
            img, vp = preprocess(img, vp_cache_path=os.path.join(args.output_dir, f"{name}_vp.txt"))
        img = (img / 255.0).astype(np.float32)
        # BUG FIX: run_one_inference declares `logger` as a required positional
        # parameter; the original call omitted it, raising a TypeError on the
        # very first image. Pass the module-level logger explicitly.
        run_one_inference(img, model, args, name, logger=logger)
def inference_dataset(dataset):
    """Run layout prediction over a dataset yielding dicts with 'id' and 'image'.

    Uses module-level globals ``model``, ``args`` and ``logger``.
    """
    bar = tqdm(dataset, ncols=100)
    for sample in bar:
        bar.set_description(sample['id'])
        # Dataset images are CHW; run_one_inference expects HWC.
        run_one_inference(sample['image'].transpose(1, 2, 0), model, args,
                          name=sample['id'], logger=logger)
def run_one_inference(img, model, args, name, logger, show=True, show_depth=True,
                      show_floorplan=True, mesh_format='.gltf', mesh_resolution=512):
    """Predict, post-process, visualize and export the layout of one panorama.

    Saves ``{name}_pred.png`` / ``{name}_pred.json`` (and optionally a 3D
    mesh ``{name}_3d{mesh_format}``) under ``args.output_dir``.
    """
    model.eval()
    logger.info("model inference...")
    net_input = torch.from_numpy(img.transpose(2, 0, 1)[None]).to(args.device)
    dt = model(net_input)

    if args.post_processing != 'original':
        logger.info(f"post-processing, type:{args.post_processing}...")
        dt['processed_xyz'] = post_process(tensor2np(dt['depth']),
                                           type_name=args.post_processing)

    visualize_2d(img, dt,
                 show_depth=show_depth,
                 show_floorplan=show_floorplan,
                 show=show,
                 save_path=os.path.join(args.output_dir, f"{name}_pred.png"))

    # Prefer the post-processed corners when post-processing ran.
    if 'processed_xyz' in dt:
        output_xyz = dt['processed_xyz'][0]
    else:
        output_xyz = depth2xyz(tensor2np(dt['depth'][0]))

    logger.info(f"saving predicted layout json...")
    json_data = save_pred_json(output_xyz, tensor2np(dt['ratio'][0])[0],
                               save_path=os.path.join(args.output_dir, f"{name}_pred.json"))
    # if args.visualize_3d:
    #     from visualization.visualizer.visualizer import visualize_3d
    #     visualize_3d(json_data, (img * 255).astype(np.uint8))

    if args.visualize_3d or args.output_3d:
        has_processed = 'processed_xyz' in dt
        dt_boundaries = corners2boundaries(tensor2np(dt['ratio'][0])[0],
                                           corners_xyz=output_xyz, step=None,
                                           length=mesh_resolution if has_processed else None,
                                           visible=has_processed)
        dt_layout_depth = layout2depth(dt_boundaries, show=False)

        logger.info(f"creating 3d mesh ...")
        mesh_save_path = (os.path.join(args.output_dir, f"{name}_3d{mesh_format}")
                          if args.output_3d else None)
        create_3d_obj(cv2.resize(img, dt_layout_depth.shape[::-1]), dt_layout_depth,
                      save_path=mesh_save_path, mesh=True, show=args.visualize_3d)
if __name__ == '__main__':
    logger = get_logger()
    args = parse_option()
    config = get_config(args)

    # Fall back to CPU when CUDA was requested but is not available.
    wants_cuda = 'cuda' in args.device or 'cuda' in config.TRAIN.DEVICE
    if wants_cuda and not torch.cuda.is_available():
        logger.info(f'The {args.device} is not available, will use cpu ...')
        config.defrost()
        args.device = "cpu"
        config.TRAIN.DEVICE = "cpu"
        config.freeze()

    model, _, _, _ = build_model(config, logger)
    os.makedirs(args.output_dir, exist_ok=True)
    img_paths = sorted(glob.glob(args.img_glob))

    inference()

    # Example: run on dataset samples instead of a glob of images.
    # dataset = MP3DDataset(root_dir='./src/dataset/mp3d', mode='test', split_list=[
    #     ['7y3sRwLe3Va', '155fac2d50764bf09feb6c8f33e8fb76'],
    #     ['e9zR4mvMWw7', 'c904c55a5d0e420bbd6e4e030b9fe5b4'],
    # ])
    # dataset = ZindDataset(root_dir='./src/dataset/zind', mode='test', split_list=[
    #     '1169_pano_21',
    #     '0583_pano_59',
    # ], vp_align=True)
    # inference_dataset(dataset)