import os
import time
import copy
import argparse

import numpy as np
import yaml
import cv2
import imageio
import torch
import vispy
from tqdm import tqdm

from mesh import write_ply, read_ply, output_3d_photo
from utils import get_MiDaS_samples, read_MiDaS_depth
from networks import Inpaint_Color_Net, Inpaint_Depth_Net, Inpaint_Edge_Net
from MiDaS.run import run_depth
from MiDaS.monodepth_net import MonoDepthNet
import MiDaS.MiDaS_utils as MiDaS_utils
from boostmonodepth_utils import run_boostmonodepth
from bilateral_filtering import sparse_bilateral_filtering

parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, default='argument.yml',
                    help='Config file for post processing')
args = parser.parse_args()
# safe_load avoids PyYAML's unsafe default loader (bare yaml.load warns or
# fails on PyYAML >= 5.1).
with open(args.config, 'r') as config_fi:
    config = yaml.safe_load(config_fi)
if config['offscreen_rendering'] is True:
    vispy.use(app='egl')
os.makedirs(config['mesh_folder'], exist_ok=True)
os.makedirs(config['video_folder'], exist_ok=True)
os.makedirs(config['depth_folder'], exist_ok=True)
sample_list = get_MiDaS_samples(config['src_folder'], config['depth_folder'], config, config['specific'])
# vispy canvases are created lazily and reused across samples.
normal_canvas, all_canvas = None, None

if isinstance(config["gpu_ids"], int) and (config["gpu_ids"] >= 0):
    device = config["gpu_ids"]
else:
    device = "cpu"
print(f"running on device {device}")

for idx in tqdm(range(len(sample_list))):
    depth = None
    sample = sample_list[idx]
    print("Current Source ==> ", sample['src_pair_name'])
    mesh_fi = os.path.join(config['mesh_folder'], sample['src_pair_name'] + '.ply')
    image = imageio.imread(sample['ref_img_fi'])

    # Estimate a monocular depth map for the reference image, either with
    # BoostingMonocularDepth or with plain MiDaS.
    print(f"Running depth extraction at {time.time()}")
    if config['use_boostmonodepth'] is True:
        run_boostmonodepth(sample['ref_img_fi'], config['src_folder'], config['depth_folder'])
    elif config['require_midas'] is True:
        run_depth([sample['ref_img_fi']], config['src_folder'], config['depth_folder'],
                  config['MiDaS_model_ckpt'], MonoDepthNet, MiDaS_utils, target_w=640)

    # Derive the working resolution from the depth map, scaled so the longer
    # side equals config['longer_side_len'].
    if 'npy' in config['depth_format']:
        config['output_h'], config['output_w'] = np.load(sample['depth_fi']).shape[:2]
    else:
        config['output_h'], config['output_w'] = imageio.imread(sample['depth_fi']).shape[:2]
    frac = config['longer_side_len'] / max(config['output_h'], config['output_w'])
    config['output_h'], config['output_w'] = int(config['output_h'] * frac), int(config['output_w'] * frac)
    config['original_h'], config['original_w'] = config['output_h'], config['output_w']

    # Promote single-channel inputs to three channels and flag images whose
    # channels are identical as grayscale.
    if image.ndim == 2:
        image = image[..., None].repeat(3, -1)
    if (np.sum(np.abs(image[..., 0] - image[..., 1])) == 0 and
            np.sum(np.abs(image[..., 1] - image[..., 2])) == 0):
        config['gray_image'] = True
    else:
        config['gray_image'] = False
    image = cv2.resize(image, (config['output_w'], config['output_h']), interpolation=cv2.INTER_AREA)
    depth = read_MiDaS_depth(sample['depth_fi'], 3.0, config['output_h'], config['output_w'])
    mean_loc_depth = depth[depth.shape[0] // 2, depth.shape[1] // 2]

    if not (config['load_ply'] is True and os.path.exists(mesh_fi)):
        # Edge-preserving smoothing of the raw depth before meshing.
        vis_photos, vis_depths = sparse_bilateral_filtering(depth.copy(), image.copy(), config,
                                                            num_iter=config['sparse_iter'], spdb=False)
        depth = vis_depths[-1]
        # Free any GPU memory held over from the previous sample.
        torch.cuda.empty_cache()
        print("Start Running 3D_Photo ...")
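        # The three networks loaded below implement the inpainting stages of
        # the 3D-photo pipeline (names as exported by networks.py):
        # Inpaint_Edge_Net completes depth edges behind occlusions,
        # Inpaint_Depth_Net fills depth in the disoccluded regions, and
        # Inpaint_Color_Net synthesizes the missing RGB content. All three are
        # used for inference only.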
        print(f"Loading edge model at {time.time()}")
        depth_edge_model = Inpaint_Edge_Net(init_weights=True)
        depth_edge_weight = torch.load(config['depth_edge_model_ckpt'],
                                       map_location=torch.device(device))
        depth_edge_model.load_state_dict(depth_edge_weight)
        depth_edge_model = depth_edge_model.to(device)
        depth_edge_model.eval()

        print(f"Loading depth model at {time.time()}")
        depth_feat_model = Inpaint_Depth_Net()
        depth_feat_weight = torch.load(config['depth_feat_model_ckpt'],
                                       map_location=torch.device(device))
        depth_feat_model.load_state_dict(depth_feat_weight, strict=True)
        depth_feat_model = depth_feat_model.to(device)
        depth_feat_model.eval()

        print(f"Loading rgb model at {time.time()}")
        rgb_model = Inpaint_Color_Net()
        rgb_feat_weight = torch.load(config['rgb_feat_model_ckpt'],
                                     map_location=torch.device(device))
        rgb_model.load_state_dict(rgb_feat_weight)
        rgb_model.eval()
        rgb_model = rgb_model.to(device)

        print(f"Writing depth ply (and basically doing everything) at {time.time()}")
        # depth_edge_model is passed twice on purpose: write_ply takes both a
        # working edge model and a separate init-edge-model argument.
        rt_info = write_ply(image, depth, sample['int_mtx'], mesh_fi, config,
                            rgb_model, depth_edge_model, depth_edge_model, depth_feat_model)
        if rt_info is False:
            continue
        # Drop the network references so CUDA memory can be reclaimed before
        # rendering the video.
        rgb_model = None
        depth_edge_model = None
        depth_feat_model = None
        torch.cuda.empty_cache()

    # Recover the mesh either from the .ply on disk or from write_ply's return.
    if config['save_ply'] is True or config['load_ply'] is True:
        verts, colors, faces, Height, Width, hFov, vFov = read_ply(mesh_fi)
    else:
        verts, colors, faces, Height, Width, hFov, vFov = rt_info

    print(f"Making video at {time.time()}")
    videos_poses, video_basename = copy.deepcopy(sample['tgts_poses']), sample['tgt_name']
    # Crop border computed from the principal point of the intrinsic matrix.
    top = (config.get('original_h') // 2 - sample['int_mtx'][1, 2] * config['output_h'])
    left = (config.get('original_w') // 2 - sample['int_mtx'][0, 2] * config['output_w'])
    down, right = top + config['output_h'], left + config['output_w']
    border = [int(xx) for xx in [top, down, left, right]]
    # Render the camera-trajectory videos from the inpainted mesh; the two
    # canvases persist across samples to avoid re-creating vispy contexts.
    normal_canvas, all_canvas = output_3d_photo(verts.copy(), colors.copy(), faces.copy(),
                                                copy.deepcopy(Height), copy.deepcopy(Width),
                                                copy.deepcopy(hFov), copy.deepcopy(vFov),
                                                copy.deepcopy(sample['tgt_pose']), sample['video_postfix'],
                                                copy.deepcopy(sample['ref_pose']),
                                                copy.deepcopy(config['video_folder']),
                                                image.copy(), copy.deepcopy(sample['int_mtx']), config, image,
                                                videos_poses, video_basename,
                                                config.get('original_h'), config.get('original_w'),
                                                border=border, depth=depth,
                                                normal_canvas=normal_canvas, all_canvas=all_canvas,
                                                mean_loc_depth=mean_loc_depth)
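# Usage sketch (assuming this file is saved as main.py and the three *_ckpt
# paths in argument.yml point at downloaded checkpoints):
#
#   python main.py --config argument.yml
#
# Images are read from config['src_folder']; per-image depth maps are written
# to config['depth_folder'], the layered mesh to config['mesh_folder'] as
# <name>.ply, and the rendered camera-trajectory videos to
# config['video_folder'].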