import os
import numpy as np
try:
    import cynetworkx as netx
except ImportError:
    import networkx as netx
import json
import scipy.misc as misc
#import OpenEXR
import scipy.signal as signal
import matplotlib.pyplot as plt
import cv2
import scipy.misc as misc
from skimage import io
from functools import partial
from vispy import scene, io
from vispy.scene import visuals
from functools import reduce
# from moviepy.editor import ImageSequenceClip
import scipy.misc as misc
from vispy.visuals.filters import Alpha
import cv2
from skimage.transform import resize
import copy
import torch
import os
from utils import refine_depth_around_edge, smooth_cntsyn_gap
from utils import require_depth_edge, filter_irrelevant_edge_new, open_small_mask
from skimage.feature import canny
from scipy import ndimage
import time
import transforms3d


def relabel_node(mesh, nodes, cur_node, new_node):
    """Rename `cur_node` to `new_node` inside `mesh`.

    The new node receives a copy of every attribute stored under
    `nodes[cur_node]` and an edge to every neighbor of `cur_node`; the old
    node is then removed. When both labels are equal the mesh is returned
    untouched.

    Args:
        mesh: graph object exposing `add_node`, `neighbors`, `add_edge` and
            `remove_node` (a networkx-style graph).
        nodes: node-attribute mapping of `mesh` (callers in this file pass
            `mesh.nodes`).
        cur_node: existing node label.
        new_node: replacement label.

    Returns:
        The same `mesh` object, modified in place.
    """
    if cur_node == new_node:
        return mesh
    mesh.add_node(new_node)
    for key, value in nodes[cur_node].items():
        nodes[new_node][key] = value
    for ne in mesh.neighbors(cur_node):
        mesh.add_edge(new_node, ne)
    mesh.remove_node(cur_node)

    return mesh


def filter_edge(mesh, edge_ccs, config, invalid=False):
    """Blank out edge components depending on whether they have 'far' context.

    For every edge component the 'far' nodes of its members are collected.
    When the component has more than two nodes, a far node is dropped again
    if it is the only collected far node carrying its particular `edge_id`.
    Nothing is collected when `config['context_thickness'] == 0`.

    Args:
        mesh: graph whose node attributes may contain 'far' (iterable of
            nodes) and 'edge_id'.
        edge_ccs: sequence of sets of nodes, one set per edge component.
        config: mapping providing 'context_thickness'.
        invalid: when `True`, keep the components WITHOUT any collected far
            node instead of the ones with at least one.

    Returns:
        A list with the same length as `edge_ccs`; kept components are the
        original set objects, rejected ones are replaced by an empty set.
    """
    context_ccs = [set() for _ in edge_ccs]
    mesh_nodes = mesh.nodes
    for edge_id, edge_cc in enumerate(edge_ccs):
        if config['context_thickness'] == 0:
            continue
        # Far nodes grouped by the edge_id they themselves belong to.
        edge_group = {}
        for edge_node in edge_cc:
            far_nodes = mesh_nodes[edge_node].get('far')
            if far_nodes is None:
                continue
            for far_node in far_nodes:
                context_ccs[edge_id].add(far_node)
                if mesh_nodes[far_node].get('edge_id') is not None:
                    if edge_group.get(mesh_nodes[far_node]['edge_id']) is None:
                        edge_group[mesh_nodes[far_node]['edge_id']] = set()
                    edge_group[mesh_nodes[far_node]['edge_id']].add(far_node)
        if len(edge_cc) > 2:
            # A far node that is the lone representative of another edge is
            # discarded from this component's context.
            for edge_key in [*edge_group.keys()]:
                if len(edge_group[edge_key]) == 1:
                    context_ccs[edge_id].remove([*edge_group[edge_key]][0])
    valid_edge_ccs = []
    for xidx, yy in enumerate(edge_ccs):
        if invalid is not True and len(context_ccs[xidx]) > 0:
            valid_edge_ccs.append(yy)
        elif invalid is True and len(context_ccs[xidx]) == 0:
            valid_edge_ccs.append(yy)
        else:
            valid_edge_ccs.append(set())

    return valid_edge_ccs


def _extrapolation_anchors(direc, h_off, w_off, noext_H, noext_W, thickness):
    """Compute the crop rectangles used by `extrapolate` for one direction.

    Every rectangle is `[row_start, row_stop, col_start, col_stop]`.

    Returns:
        `(all_anchor, mask_anchor, valid_line_anchor, valid_anchor)`;
        `valid_line_anchor` is `None` for the four corner directions (names
        containing '-').

    Raises:
        ValueError: if `direc` matches none of the eight known directions.
    """
    name = direc.lower()
    is_corner = "-" in name
    row_end = h_off + noext_H          # first row past the un-extended image
    col_end = w_off + noext_W          # first column past the un-extended image
    full_rows = 2 * h_off + noext_H
    full_cols = 2 * w_off + noext_W
    context_anchor = None
    valid_line_anchor = None
    if "up" in name and not is_corner:
        all_anchor = [0, h_off + thickness, w_off, col_end]
        mask_anchor = [0, h_off, w_off, col_end]
        context_anchor = [h_off, h_off + thickness, w_off, col_end]
        valid_line_anchor = [h_off, h_off + 1, w_off, col_end]
    elif "down" in name and not is_corner:
        all_anchor = [row_end - thickness, full_rows, w_off, col_end]
        mask_anchor = [row_end, full_rows, w_off, col_end]
        context_anchor = [row_end - thickness, row_end, w_off, col_end]
        valid_line_anchor = [row_end - 1, row_end, w_off, col_end]
    elif "left" in name and not is_corner:
        all_anchor = [h_off, row_end, 0, w_off + thickness]
        mask_anchor = [h_off, row_end, 0, w_off]
        context_anchor = [h_off, row_end, w_off, w_off + thickness]
        valid_line_anchor = [h_off, row_end, w_off, w_off + 1]
    elif "right" in name and not is_corner:
        all_anchor = [h_off, row_end, col_end - thickness, full_cols]
        mask_anchor = [h_off, row_end, col_end, full_cols]
        context_anchor = [h_off, row_end, col_end - thickness, col_end]
        valid_line_anchor = [h_off, row_end, col_end - 1, col_end]
    elif "left" in name and "up" in name and is_corner:
        all_anchor = [0, h_off + thickness, 0, w_off + thickness]
        mask_anchor = [0, h_off, 0, w_off]
    elif "left" in name and "down" in name and is_corner:
        all_anchor = [row_end - thickness, full_rows, 0, w_off + thickness]
        mask_anchor = [row_end, full_rows, 0, w_off]
    elif "right" in name and "up" in name and is_corner:
        all_anchor = [0, h_off + thickness, col_end - thickness, full_cols]
        mask_anchor = [0, h_off, col_end, full_cols]
    elif "right" in name and "down" in name and is_corner:
        all_anchor = [row_end - thickness, full_rows, col_end - thickness, full_cols]
        mask_anchor = [row_end, full_rows, col_end, full_cols]
    else:
        raise ValueError("Unsupported extrapolation direction: {!r}".format(direc))

    if context_anchor is None:
        # Corner patches: everything in the crop that is not mask is context.
        valid_anchor = all_anchor
    else:
        # Side bands: bounding box of mask and context rectangles.
        valid_anchor = [min(mask_anchor[0], context_anchor[0]), max(mask_anchor[1], context_anchor[1]),
                        min(mask_anchor[2], context_anchor[2]), max(mask_anchor[3], context_anchor[3])]

    return all_anchor, mask_anchor, valid_line_anchor, valid_anchor


def extrapolate(global_mesh,
                info_on_pix,
                image,
                depth,
                other_edge_with_id,
                edge_map,
                edge_ccs,
                depth_edge_model,
                depth_feat_model,
                rgb_feat_model,
                config,
                direc='right-up'):
    """Synthesize colour, depth and depth edges for one border band or corner.

    The region selected by `direc` is cropped, split into a `mask` (pixels to
    synthesize) and a `context` (known pixels), and passed through the three
    models' `forward_3P`. Results are written back into `image` and `depth`
    IN PLACE, the masked nodes of `global_mesh` are relabelled with their new
    depth, and 'near'/'far'/'edge_id' bookkeeping is updated for the edge
    pixels that were extended into the mask.

    Args:
        global_mesh: graph with `graph['hoffset']`, `graph['woffset']`,
            `graph['noext_H']`, `graph['noext_W']`; nodes are `(x, y, depth)`.
        info_on_pix: dict mapping `(x, y)` to a list of per-pixel info dicts.
        image: colour array indexed `[row, col, channel]`; modified in place.
            NOTE(review): the `/ 255.` scaling suggests 0-255 values — confirm.
        depth: 2-D depth array; modified in place.
        other_edge_with_id: 2-D map of edge ids, `-1` meaning "no edge".
        edge_map: 2-D edge indicator map.
        edge_ccs: list of node sets indexed by edge id; extended in place.
        depth_edge_model, depth_feat_model, rgb_feat_model: objects exposing
            `forward_3P(...)`.
        config: mapping providing 'context_thickness', 'gpu_ids',
            'ext_edge_threshold' and optionally 'gray_image'.
        direc: one of 'up', 'down', 'left', 'right', or a dashed corner name
            such as 'left-up' / 'right-down'.

    Returns:
        `(info_on_pix, global_mesh, image, depth, edge_ccs)`.

    Raises:
        ValueError: if `direc` is not a recognised direction.
    """
    h_off, w_off = global_mesh.graph['hoffset'], global_mesh.graph['woffset']
    noext_H, noext_W = global_mesh.graph['noext_H'], global_mesh.graph['noext_W']
    all_anchor, mask_anchor, valid_line_anchor, valid_anchor = _extrapolation_anchors(
        direc, h_off, w_off, noext_H, noext_W, config['context_thickness'])

    # ---- build mask / context and the model inputs -------------------------
    global_mask = np.zeros_like(depth)
    global_mask[mask_anchor[0]:mask_anchor[1], mask_anchor[2]:mask_anchor[3]] = 1
    mask = global_mask[valid_anchor[0]:valid_anchor[1], valid_anchor[2]:valid_anchor[3]] * 1
    context = 1 - mask
    valid_area = mask + context
    input_rgb = image[valid_anchor[0]:valid_anchor[1], valid_anchor[2]:valid_anchor[3]] / 255. * context[..., None]
    input_depth = depth[valid_anchor[0]:valid_anchor[1], valid_anchor[2]:valid_anchor[3]] * context
    log_depth = np.log(input_depth + 1e-8)
    log_depth[mask > 0] = 0
    input_mean_depth = np.mean(log_depth[context > 0])
    input_zero_mean_depth = (log_depth - input_mean_depth) * context
    input_disp = 1./np.abs(input_depth)
    input_disp[mask > 0] = 0
    input_disp = input_disp / input_disp.max()
    # One-pixel-wide line of known pixels adjacent to the mask (side bands only).
    valid_line = np.zeros_like(depth)
    if valid_line_anchor is not None:
        valid_line[valid_line_anchor[0]:valid_line_anchor[1], valid_line_anchor[2]:valid_line_anchor[3]] = 1
    valid_line = valid_line[all_anchor[0]:all_anchor[1], all_anchor[2]:all_anchor[3]]
    input_edge_map = edge_map[all_anchor[0]:all_anchor[1], all_anchor[2]:all_anchor[3]] * context
    input_other_edge_with_id = other_edge_with_id[all_anchor[0]:all_anchor[1], all_anchor[2]:all_anchor[3]]
    # Depth of the edge pixels that touch the mask; zero everywhere else.
    end_depth_maps = ((valid_line * input_edge_map) > 0) * input_depth

    if isinstance(config["gpu_ids"], int) and (config["gpu_ids"] >= 0):
        device = config["gpu_ids"]
    else:
        device = "cpu"

    valid_edge_ids = sorted(list(input_other_edge_with_id[(valid_line * input_edge_map) > 0]))
    valid_edge_ids = valid_edge_ids[1:] if (len(valid_edge_ids) > 0 and valid_edge_ids[0] == -1) else valid_edge_ids
    # Binary map of every edge whose id reaches the valid line.
    edge = reduce(lambda x, y: (x + (input_other_edge_with_id == y).astype(np.uint8)).clip(0, 1),
                  [np.zeros_like(mask)] + list(valid_edge_ids))
    t_edge = torch.FloatTensor(edge).to(device)[None, None, ...]
    t_rgb = torch.FloatTensor(input_rgb).to(device).permute(2,0,1).unsqueeze(0)
    t_mask = torch.FloatTensor(mask).to(device)[None, None, ...]
    t_context = torch.FloatTensor(context).to(device)[None, None, ...]
    t_disp = torch.FloatTensor(input_disp).to(device)[None, None, ...]
    t_depth_zero_mean_depth = torch.FloatTensor(input_zero_mean_depth).to(device)[None, None, ...]

    # ---- extend depth edges into the mask ----------------------------------
    depth_edge_output = depth_edge_model.forward_3P(t_mask, t_context, t_rgb, t_disp, t_edge,
                                                    unit_length=128, cuda=device)
    t_output_edge = (depth_edge_output > config['ext_edge_threshold']).float() * t_mask + t_edge
    output_raw_edge = t_output_edge.data.cpu().numpy().squeeze()

    # 8-connected graph over predicted edge pixels inside the mask.
    mesh = netx.Graph()
    hxs, hys = np.where(output_raw_edge * mask > 0)
    for hx, hy in zip(hxs, hys):
        node = (hx, hy)
        mesh.add_node((hx, hy))
        eight_nes = [ne for ne in [(hx + 1, hy), (hx - 1, hy), (hx, hy + 1), (hx, hy - 1), \
                                   (hx + 1, hy + 1), (hx - 1, hy - 1), (hx - 1, hy + 1), (hx + 1, hy - 1)]\
                     if 0 <= ne[0] < output_raw_edge.shape[0] and 0 <= ne[1] < output_raw_edge.shape[1] and 0 < output_raw_edge[ne[0], ne[1]]]
        for ne in eight_nes:
            mesh.add_edge(node, ne, length=np.hypot(ne[0] - hx, ne[1] - hy))
            if end_depth_maps[ne[0], ne[1]] != 0:
                # Neighbor lies on the valid line: remember it as a connection point.
                mesh.nodes[ne[0], ne[1]]['cnt'] = True
                mesh.nodes[ne[0], ne[1]]['depth'] = end_depth_maps[ne[0], ne[1]]
    ccs = [*netx.connected_components(mesh)]
    end_pts = []
    for cc in ccs:
        end_pts.append(set())
        for node in cc:
            if mesh.nodes[node].get('cnt') is not None:
                end_pts[-1].add((node[0], node[1], mesh.nodes[node]['depth']))

    # npath_map: "near" side of each extended edge; fpath_map: matching "far" side.
    fpath_map = np.zeros_like(output_raw_edge) - 1
    npath_map = np.zeros_like(output_raw_edge) - 1
    for end_pt, cc in zip(end_pts, ccs):
        # Only components attached to exactly one known edge end are extended.
        if len(end_pt) >= 2:
            continue
        if len(end_pt) == 0:
            continue
        if len(end_pt) == 1:
            sub_mesh = mesh.subgraph(list(cc)).copy()
            pnodes = netx.periphery(sub_mesh)
            ends = [*end_pt]
            end_node = (ends[0][0] + all_anchor[0], ends[0][1] + all_anchor[2], -ends[0][2])
            edge_id = global_mesh.nodes[end_node]['edge_id']
            # Walk from the known end to the farthest periphery node.
            pnodes = sorted(pnodes,
                            key=lambda x: np.hypot((x[0] - ends[0][0]), (x[1] - ends[0][1])),
                            reverse=True)[0]
            npath = [*netx.shortest_path(sub_mesh, (ends[0][0], ends[0][1]), pnodes, weight='length')]
            for np_node in npath:
                npath_map[np_node[0], np_node[1]] = edge_id
            fpath = []
            if global_mesh.nodes[end_node].get('far') is None:
                # Previously dropped into pdb here; a breakpoint aborts
                # non-interactive runs, so only the diagnostic is kept.
                print("None far")
            else:
                fnodes = global_mesh.nodes[end_node].get('far')
                fnodes = [(xx[0] - all_anchor[0], xx[1] - all_anchor[2], xx[2]) for xx in fnodes]
                # Grow the mask up to 3 times until it touches a far node
                # lying in the context; that node starts the far path.
                dmask = mask + 0
                did = 0
                while True:
                    did += 1
                    dmask = cv2.dilate(dmask, np.ones((3, 3)), iterations=1)
                    if did > 3:
                        break
                    ffnode = [fnode for fnode in fnodes if (dmask[fnode[0], fnode[1]] > 0 and mask[fnode[0], fnode[1]] == 0)]
                    if len(ffnode) > 0:
                        fnode = ffnode[0]
                        break
                if len(ffnode) == 0:
                    continue
                fpath.append((fnode[0], fnode[1]))
                # Replay the near path's step vectors starting at the far node.
                for step in range(0, len(npath) - 1):
                    parr = (npath[step + 1][0] - npath[step][0], npath[step + 1][1] - npath[step][1])
                    new_loc = (fpath[-1][0] + parr[0], fpath[-1][1] + parr[1])
                    new_loc_nes = [xx for xx in [(new_loc[0] + 1, new_loc[1]), (new_loc[0] - 1, new_loc[1]),
                                                 (new_loc[0], new_loc[1] + 1), (new_loc[0], new_loc[1] - 1)]\
                                   if xx[0] >= 0 and xx[0] < fpath_map.shape[0] and xx[1] >= 0 and xx[1] < fpath_map.shape[1]]
                    # Stop when touching an existing far path (or the crop border).
                    if np.sum([fpath_map[nlne[0], nlne[1]] for nlne in new_loc_nes]) != -4:
                        break
                    if npath_map[new_loc[0], new_loc[1]] != -1:
                        if npath_map[new_loc[0], new_loc[1]] != edge_id:
                            break
                        else:
                            continue
                    if valid_area[new_loc[0], new_loc[1]] == 0:
                        break
                    new_loc_nes_eight = [xx for xx in [(new_loc[0] + 1, new_loc[1]), (new_loc[0] - 1, new_loc[1]),
                                                       (new_loc[0], new_loc[1] + 1), (new_loc[0], new_loc[1] - 1),
                                                       (new_loc[0] + 1, new_loc[1] + 1), (new_loc[0] + 1, new_loc[1] - 1),
                                                       (new_loc[0] - 1, new_loc[1] - 1), (new_loc[0] - 1, new_loc[1] + 1)]\
                                         if xx[0] >= 0 and xx[0] < fpath_map.shape[0] and xx[1] >= 0 and xx[1] < fpath_map.shape[1]]
                    # The far path must stay 8-adjacent to its own near path.
                    if np.sum([int(npath_map[nlne[0], nlne[1]] == edge_id) for nlne in new_loc_nes_eight]) == 0:
                        break
                    fpath.append((fpath[-1][0] + parr[0], fpath[-1][1] + parr[1]))
                if step != len(npath) - 2:
                    # Far path ended early: erase the unmatched tail of the near path.
                    for xx in npath[step+1:]:
                        if npath_map[xx[0], xx[1]] == edge_id:
                            npath_map[xx[0], xx[1]] = -1
            if len(fpath) > 0:
                for fp_node in fpath:
                    fpath_map[fp_node[0], fp_node[1]] = edge_id

    # ---- synthesize depth and colour ---------------------------------------
    update_edge = (npath_map > -1) * mask + edge
    t_update_edge = torch.FloatTensor(update_edge).to(device)[None, None, ...]
    depth_output = depth_feat_model.forward_3P(t_mask, t_context, t_depth_zero_mean_depth, t_update_edge,
                                               unit_length=128, cuda=device)
    depth_output = depth_output.cpu().data.numpy().squeeze()
    # Undo the zero-mean log transform; keep only the synthesized pixels.
    depth_output = np.exp(depth_output + input_mean_depth) * mask
    for near_id in np.unique(npath_map[npath_map > -1]):
        depth_output = refine_depth_around_edge(depth_output.copy(),
                                                (fpath_map == near_id).astype(np.uint8) * mask,  # far edge inside mask
                                                (fpath_map == near_id).astype(np.uint8),         # far edge
                                                (npath_map == near_id).astype(np.uint8) * mask,
                                                mask.copy(),
                                                np.zeros_like(mask),
                                                config)
    rgb_output = rgb_feat_model.forward_3P(t_mask, t_context, t_rgb, t_update_edge,
                                           unit_length=128, cuda=device)
    if config.get('gray_image') is True:
        rgb_output = rgb_output.mean(1, keepdim=True).repeat((1,3,1,1))
    rgb_output = ((rgb_output.squeeze().data.cpu().permute(1,2,0).numpy() * mask[..., None] + input_rgb) * 255).astype(np.uint8)
    # Basic slices are views, so these assignments modify the caller's arrays.
    image[all_anchor[0]:all_anchor[1], all_anchor[2]:all_anchor[3]][mask > 0] = rgb_output[mask > 0]
    depth[all_anchor[0]:all_anchor[1], all_anchor[2]:all_anchor[3]][mask > 0] = depth_output[mask > 0]

    # ---- cut mesh edges crossing a near/far pair (both directions) ---------
    nxs, nys = np.where((npath_map > -1))
    for nx, ny in zip(nxs, nys):
        n_id = npath_map[nx, ny]
        four_nes = [xx for xx in [(nx + 1, ny), (nx - 1, ny), (nx, ny + 1), (nx, ny - 1)]\
                    if 0 <= xx[0] < fpath_map.shape[0] and 0 <= xx[1] < fpath_map.shape[1]]
        for nex, ney in four_nes:
            if fpath_map[nex, ney] == n_id:
                na, nb = (nx + all_anchor[0], ny + all_anchor[2], info_on_pix[(nx + all_anchor[0], ny + all_anchor[2])][0]['depth']), \
                         (nex + all_anchor[0], ney + all_anchor[2], info_on_pix[(nex + all_anchor[0], ney + all_anchor[2])][0]['depth'])
                if global_mesh.has_edge(na, nb):
                    global_mesh.remove_edge(na, nb)
    nxs, nys = np.where((fpath_map > -1))
    for nx, ny in zip(nxs, nys):
        n_id = fpath_map[nx, ny]
        four_nes = [xx for xx in [(nx + 1, ny), (nx - 1, ny), (nx, ny + 1), (nx, ny - 1)]\
                    if 0 <= xx[0] < npath_map.shape[0] and 0 <= xx[1] < npath_map.shape[1]]
        for nex, ney in four_nes:
            if npath_map[nex, ney] == n_id:
                na, nb = (nx + all_anchor[0], ny + all_anchor[2], info_on_pix[(nx + all_anchor[0], ny + all_anchor[2])][0]['depth']), \
                         (nex + all_anchor[0], ney + all_anchor[2], info_on_pix[(nex + all_anchor[0], ney + all_anchor[2])][0]['depth'])
                if global_mesh.has_edge(na, nb):
                    global_mesh.remove_edge(na, nb)

    # ---- relabel every synthesized pixel from depth 0 to its new depth -----
    nxs, nys = np.where(mask > 0)
    for x, y in zip(nxs, nys):
        x = x + all_anchor[0]
        y = y + all_anchor[2]
        cur_node = (x, y, 0)
        new_node = (x, y, -abs(depth[x, y]))
        disp = 1. / -abs(depth[x, y])
        mapping_dict = {cur_node: new_node}
        info_on_pix, global_mesh = update_info(mapping_dict, info_on_pix, global_mesh)
        global_mesh.nodes[new_node]['color'] = image[x, y]
        global_mesh.nodes[new_node]['old_color'] = image[x, y]
        global_mesh.nodes[new_node]['disp'] = disp
        info_on_pix[(x, y)][0]['depth'] = -abs(depth[x, y])
        info_on_pix[(x, y)][0]['disp'] = disp
        info_on_pix[(x, y)][0]['color'] = image[x, y]

    # ---- near side: record edge_id and 'far' links -------------------------
    nxs, nys = np.where((npath_map > -1))
    for nx, ny in zip(nxs, nys):
        self_node = (nx + all_anchor[0], ny + all_anchor[2], info_on_pix[(nx + all_anchor[0], ny + all_anchor[2])][0]['depth'])
        # NOTE(review): this aborts the whole loop on the first missing node;
        # `continue` may have been intended — confirm before changing.
        if global_mesh.has_node(self_node) is False:
            break
        n_id = int(round(npath_map[nx, ny]))
        four_nes = [xx for xx in [(nx + 1, ny), (nx - 1, ny), (nx, ny + 1), (nx, ny - 1)]\
                    if 0 <= xx[0] < fpath_map.shape[0] and 0 <= xx[1] < fpath_map.shape[1]]
        for nex, ney in four_nes:
            ne_node = (nex + all_anchor[0], ney + all_anchor[2], info_on_pix[(nex + all_anchor[0], ney + all_anchor[2])][0]['depth'])
            if global_mesh.has_node(ne_node) is False:
                continue
            if fpath_map[nex, ney] == n_id:
                if global_mesh.nodes[self_node].get('edge_id') is None:
                    global_mesh.nodes[self_node]['edge_id'] = n_id
                    edge_ccs[n_id].add(self_node)
                    info_on_pix[(self_node[0], self_node[1])][0]['edge_id'] = n_id
                if global_mesh.has_edge(self_node, ne_node) is True:
                    global_mesh.remove_edge(self_node, ne_node)
                if global_mesh.nodes[self_node].get('far') is None:
                    global_mesh.nodes[self_node]['far'] = []
                global_mesh.nodes[self_node]['far'].append(ne_node)

    # ---- far side: map each far path to the edge id it overlaps most -------
    global_fpath_map = np.zeros_like(other_edge_with_id) - 1
    global_fpath_map[all_anchor[0]:all_anchor[1], all_anchor[2]:all_anchor[3]] = fpath_map
    fpath_ids = np.unique(global_fpath_map)
    fpath_ids = fpath_ids[1:] if fpath_ids.shape[0] > 0 and fpath_ids[0] == -1 else []
    fpath_real_id_map = np.zeros_like(global_fpath_map) - 1
    for fpath_id in fpath_ids:
        # `np.int` was removed in NumPy 1.24; the builtin `int` is the same dtype.
        fpath_real_id = np.unique(((global_fpath_map == fpath_id).astype(int) * (other_edge_with_id + 1)) - 1)
        # NOTE(review): the `else []` branch would fail on `.astype` below and an
        # empty result would fail in `argmax`; left as in the original — confirm
        # whether these cases can occur.
        fpath_real_id = fpath_real_id[1:] if fpath_real_id.shape[0] > 0 and fpath_real_id[0] == -1 else []
        fpath_real_id = fpath_real_id.astype(int)
        fpath_real_id = np.bincount(fpath_real_id).argmax()
        fpath_real_id_map[global_fpath_map == fpath_id] = fpath_real_id
    nxs, nys = np.where((fpath_map > -1))
    for nx, ny in zip(nxs, nys):
        self_node = (nx + all_anchor[0], ny + all_anchor[2], info_on_pix[(nx + all_anchor[0], ny + all_anchor[2])][0]['depth'])
        n_id = fpath_map[nx, ny]
        four_nes = [xx for xx in [(nx + 1, ny), (nx - 1, ny), (nx, ny + 1), (nx, ny - 1)]\
                    if 0 <= xx[0] < npath_map.shape[0] and 0 <= xx[1] < npath_map.shape[1]]
        for nex, ney in four_nes:
            ne_node = (nex + all_anchor[0], ney + all_anchor[2], info_on_pix[(nex + all_anchor[0], ney + all_anchor[2])][0]['depth'])
            if global_mesh.has_node(ne_node) is False:
                continue
            if npath_map[nex, ney] == n_id or global_mesh.nodes[ne_node].get('edge_id') == n_id:
                if global_mesh.has_edge(self_node, ne_node) is True:
                    global_mesh.remove_edge(self_node, ne_node)
                if global_mesh.nodes[self_node].get('near') is None:
                    global_mesh.nodes[self_node]['near'] = []
                if global_mesh.nodes[self_node].get('edge_id') is None:
                    f_id = int(round(fpath_real_id_map[self_node[0], self_node[1]]))
                    global_mesh.nodes[self_node]['edge_id'] = f_id
                    info_on_pix[(self_node[0], self_node[1])][0]['edge_id'] = f_id
                    edge_ccs[f_id].add(self_node)
                global_mesh.nodes[self_node]['near'].append(ne_node)

    return info_on_pix, global_mesh, image, depth, edge_ccs


def get_valid_size(imap):
    """Return the bounding box of the positive-sum rows and columns of `imap`.

    Args:
        imap: 2-D array-like supporting `.sum(axis)`.

    Returns:
        Dict with 'x_min'/'x_max' (rows) and 'y_min'/'y_max' (columns); the
        `*_max` values are exclusive (last index + 1).

    Raises:
        ValueError: from NumPy's `max`/`min` when no row or column has a
            positive sum.
    """
    rows = np.where(imap.sum(1).squeeze() > 0)[0]
    cols = np.where(imap.sum(0).squeeze() > 0)[0]
    x_max = rows.max() + 1
    x_min = rows.min()
    y_max = cols.max() + 1
    y_min = cols.min()
    size_dict = {'x_max':x_max, 'y_max':y_max, 'x_min':x_min, 'y_min':y_min}

    return size_dict

def dilate_valid_size(isize_dict, imap, dilate=(0, 0)):
    """Grow a bounding box by `dilate`, clipped to the extent of `imap`.

    Args:
        isize_dict: dict with 'x_min', 'x_max', 'y_min', 'y_max'; not modified.
        imap: object with a `.shape`; `shape[0]` bounds x, `shape[1]` bounds y.
        dilate: `(dx, dy)` margins added on both sides of the x and y ranges.

    Returns:
        A new dict with the enlarged, clipped bounds.
    """
    osize_dict = copy.deepcopy(isize_dict)
    osize_dict['x_min'] = max(0, osize_dict['x_min'] - dilate[0])
    osize_dict['x_max'] = min(imap.shape[0], osize_dict['x_max'] + dilate[0])
    # The y range uses the y margin on both sides (the lower bound used to
    # take dilate[0], unlike 'y_max' here and unlike get_union_size).
    osize_dict['y_min'] = max(0, osize_dict['y_min'] - dilate[1])
    osize_dict['y_max'] = min(imap.shape[1], osize_dict['y_max'] + dilate[1])

    return osize_dict


def size_operation(size_a, size_b, operation):
    """Combine two bounding boxes; only the union ('+') is implemented.

    Args:
        size_a, size_b: dicts with 'x_min', 'x_max', 'y_min', 'y_max'.
        operation: '+' for the union of both boxes. '-' is accepted as a name
            but has no implementation.

    Returns:
        Dict describing the smallest box containing both inputs.

    Raises:
        AssertionError: for '-' or any unknown operation. Raised explicitly so
            the check also holds when Python runs with `-O`.
    """
    if operation != '+' and operation != '-':
        raise AssertionError("Operation must be '+' (union) or '-' (exclude)")
    if operation == '-':
        raise AssertionError("Operation '-' is undefined !")
    osize = {}
    osize['x_min'] = min(size_a['x_min'], size_b['x_min'])
    osize['y_min'] = min(size_a['y_min'], size_b['y_min'])
    osize['x_max'] = max(size_a['x_max'], size_b['x_max'])
    osize['y_max'] = max(size_a['y_max'], size_b['y_max'])

    return osize


def fill_dummy_bord(mesh, info_on_pix, image, depth, config):
    """Add placeholder nodes for every pixel outside the un-extended image.

    The un-extended image is the rectangle starting at
    (`mesh.graph['hoffset']`, `mesh.graph['woffset']`) with size
    `mesh.graph['noext_H']` x `mesh.graph['noext_W']`. Every other pixel of
    the `depth` grid gets a node `(x, y, 0)` with black colour, zero disparity
    and `ext_pixel=True`, a matching single-entry list in `info_on_pix`, and
    an edge to each 4-neighbor that already has an `info_on_pix` entry.

    Args:
        mesh: graph with the `graph` keys above plus 'H' and 'W'.
        info_on_pix: dict mapping `(x, y)` to a list of info dicts; updated
            in place.
        image: unused; kept for interface compatibility.
        depth: array whose shape defines the full (extended) pixel grid.
        config: unused; kept for interface compatibility.

    Returns:
        `(mesh, info_on_pix)`.
    """
    context = np.zeros_like(depth).astype(np.uint8)
    context[mesh.graph['hoffset']:mesh.graph['hoffset'] + mesh.graph['noext_H'],
            mesh.graph['woffset']:mesh.graph['woffset'] + mesh.graph['noext_W']] = 1
    mask = 1 - context
    xs, ys = np.where(mask > 0)
    cur_depth = 0
    cur_disp = 0
    # The same list object is stored on every dummy node, as before.
    color = [0, 0, 0]
    for x, y in zip(xs, ys):
        cur_node = (x, y, cur_depth)
        mesh.add_node(cur_node, color=color,
                      synthesis=False,
                      disp=cur_disp,
                      cc_id=set(),
                      ext_pixel=True)
        info_on_pix[(x, y)] = [{'depth':cur_depth,
                                'color':mesh.nodes[(x, y, cur_depth)]['color'],
                                'synthesis':False,
                                'disp':mesh.nodes[cur_node]['disp'],
                                'ext_pixel':True}]
        # Bounds are checked on the neighbor itself (the centre pixel used to
        # be tested, which made the check a no-op).
        four_nes = [(xx, yy) for xx, yy in [(x + 1, y), (x - 1, y), (x, y + 1), (x, y - 1)] if\
                    0 <= xx < mesh.graph['H'] and 0 <= yy < mesh.graph['W'] and info_on_pix.get((xx, yy)) is not None]
        for ne in four_nes:
            mesh.add_edge(cur_node, (ne[0], ne[1], info_on_pix[(ne[0], ne[1])][0]['depth']))

    return mesh, info_on_pix

def enlarge_border(mesh, info_on_pix, depth, image, config): mesh.graph['hoffset'], mesh.graph['woffset'] = config['extrapolation_thickness'], config['extrapolation_thickness'] mesh.graph['bord_up'], mesh.graph['bord_left'], mesh.graph['bord_down'], mesh.graph['bord_right'] = \ 0, 0, mesh.graph['H'], mesh.graph['W'] # new_image = np.pad(image, # pad_width=((config['extrapolation_thickness'], config['extrapolation_thickness']), # (config['extrapolation_thickness'], config['extrapolation_thickness']), (0, 0)), # mode='constant') # new_depth = np.pad(depth, # pad_width=((config['extrapolation_thickness'], config['extrapolation_thickness']), # (config['extrapolation_thickness'], config['extrapolation_thickness'])), # mode='constant') return mesh, info_on_pix, depth, image def fill_missing_node(mesh, info_on_pix, image, depth): for x in range(mesh.graph['bord_up'], mesh.graph['bord_down']): for y in range(mesh.graph['bord_left'], mesh.graph['bord_right']): if info_on_pix.get((x, y)) is None: print("fill missing node = ", x, y) import pdb; pdb.set_trace() re_depth, re_count = 0, 0 for ne in [(x + 1, y), (x - 1, y), (x, y + 1), (x, y - 1)]: if info_on_pix.get(ne) is not None: re_depth += info_on_pix[ne][0]['depth'] re_count += 1 if re_count == 0: re_depth = -abs(depth[x, y]) else: re_depth = re_depth / re_count depth[x, y] = abs(re_depth) info_on_pix[(x, y)] = [{'depth':re_depth, 'color':image[x, y], 'synthesis':False, 'disp':1./re_depth}] mesh.add_node((x, y, re_depth), color=image[x, y], synthesis=False, disp=1./re_depth, cc_id=set()) return mesh, info_on_pix, depth def refresh_bord_depth(mesh, info_on_pix, image, depth): H, W = mesh.graph['H'], mesh.graph['W'] corner_nodes = [(mesh.graph['bord_up'], mesh.graph['bord_left']), (mesh.graph['bord_up'], mesh.graph['bord_right'] - 1), (mesh.graph['bord_down'] - 1, mesh.graph['bord_left']), (mesh.graph['bord_down'] - 1, mesh.graph['bord_right'] - 1)] # (0, W - 1), (H - 1, 0), (H - 1, W - 1)] bord_nodes = [] bord_nodes += 
[(mesh.graph['bord_up'], xx) for xx in range(mesh.graph['bord_left'] + 1, mesh.graph['bord_right'] - 1)] bord_nodes += [(mesh.graph['bord_down'] - 1, xx) for xx in range(mesh.graph['bord_left'] + 1, mesh.graph['bord_right'] - 1)] bord_nodes += [(xx, mesh.graph['bord_left']) for xx in range(mesh.graph['bord_up'] + 1, mesh.graph['bord_down'] - 1)] bord_nodes += [(xx, mesh.graph['bord_right'] - 1) for xx in range(mesh.graph['bord_up'] + 1, mesh.graph['bord_down'] - 1)] for xy in bord_nodes: tgt_loc = None if xy[0] == mesh.graph['bord_up']: tgt_loc = (xy[0] + 1, xy[1])# (1, xy[1]) elif xy[0] == mesh.graph['bord_down'] - 1: tgt_loc = (xy[0] - 1, xy[1]) # (H - 2, xy[1]) elif xy[1] == mesh.graph['bord_left']: tgt_loc = (xy[0], xy[1] + 1) elif xy[1] == mesh.graph['bord_right'] - 1: tgt_loc = (xy[0], xy[1] - 1) if tgt_loc is not None: ne_infos = info_on_pix.get(tgt_loc) if ne_infos is None: import pdb; pdb.set_trace() # if ne_infos is not None and len(ne_infos) == 1: tgt_depth = ne_infos[0]['depth'] tgt_disp = ne_infos[0]['disp'] new_node = (xy[0], xy[1], tgt_depth) src_node = (tgt_loc[0], tgt_loc[1], tgt_depth) tgt_nes_loc = [(xx[0], xx[1]) \ for xx in mesh.neighbors(src_node)] tgt_nes_loc = [(xx[0] - tgt_loc[0] + xy[0], xx[1] - tgt_loc[1] + xy[1]) for xx in tgt_nes_loc \ if abs(xx[0] - xy[0]) == 1 and abs(xx[1] - xy[1]) == 1] tgt_nes_loc = [xx for xx in tgt_nes_loc if info_on_pix.get(xx) is not None] tgt_nes_loc.append(tgt_loc) # if (xy[0], xy[1]) == (559, 60): # import pdb; pdb.set_trace() if info_on_pix.get(xy) is not None and len(info_on_pix.get(xy)) > 0: old_depth = info_on_pix[xy][0].get('depth') old_node = (xy[0], xy[1], old_depth) mesh.remove_edges_from([(old_ne, old_node) for old_ne in mesh.neighbors(old_node)]) mesh.add_edges_from([((zz[0], zz[1], info_on_pix[zz][0]['depth']), old_node) for zz in tgt_nes_loc]) mapping_dict = {old_node: new_node} # if old_node[2] == new_node[2]: # print("mapping_dict = ", mapping_dict) info_on_pix, mesh = update_info(mapping_dict, 
info_on_pix, mesh) else: info_on_pix[xy] = [] info_on_pix[xy][0] = info_on_pix[tgt_loc][0] info_on_pix['color'] = image[xy[0], xy[1]] info_on_pix['old_color'] = image[xy[0], xy[1]] mesh.add_node(new_node) mesh.add_edges_from([((zz[0], zz[1], info_on_pix[zz][0]['depth']), new_node) for zz in tgt_nes_loc]) mesh.nodes[new_node]['far'] = None mesh.nodes[new_node]['near'] = None if mesh.nodes[src_node].get('far') is not None: redundant_nodes = [ne for ne in mesh.nodes[src_node]['far'] if (ne[0], ne[1]) == xy] [mesh.nodes[src_node]['far'].remove(aa) for aa in redundant_nodes] if mesh.nodes[src_node].get('near') is not None: redundant_nodes = [ne for ne in mesh.nodes[src_node]['near'] if (ne[0], ne[1]) == xy] [mesh.nodes[src_node]['near'].remove(aa) for aa in redundant_nodes] for xy in corner_nodes: hx, hy = xy four_nes = [xx for xx in [(hx + 1, hy), (hx - 1, hy), (hx, hy + 1), (hx, hy - 1)] if \ mesh.graph['bord_up'] <= xx[0] < mesh.graph['bord_down'] and \ mesh.graph['bord_left'] <= xx[1] < mesh.graph['bord_right']] ne_nodes = [] ne_depths = [] for ne_loc in four_nes: if info_on_pix.get(ne_loc) is not None: ne_depths.append(info_on_pix[ne_loc][0]['depth']) ne_nodes.append((ne_loc[0], ne_loc[1], info_on_pix[ne_loc][0]['depth'])) new_node = (xy[0], xy[1], float(np.mean(ne_depths))) if info_on_pix.get(xy) is not None and len(info_on_pix.get(xy)) > 0: old_depth = info_on_pix[xy][0].get('depth') old_node = (xy[0], xy[1], old_depth) mesh.remove_edges_from([(old_ne, old_node) for old_ne in mesh.neighbors(old_node)]) mesh.add_edges_from([(zz, old_node) for zz in ne_nodes]) mapping_dict = {old_node: new_node} info_on_pix, mesh = update_info(mapping_dict, info_on_pix, mesh) else: info_on_pix[xy] = [] info_on_pix[xy][0] = info_on_pix[ne_loc[-1]][0] info_on_pix['color'] = image[xy[0], xy[1]] info_on_pix['old_color'] = image[xy[0], xy[1]] mesh.add_node(new_node) mesh.add_edges_from([(zz, new_node) for zz in ne_nodes]) mesh.nodes[new_node]['far'] = None mesh.nodes[new_node]['near'] = 
def get_union_size(mesh, dilate, *alls_cc):
    """Return the dilated bounding box of the union of several node sets.

    Args:
        mesh: object whose ``graph`` mapping provides the canvas size under
            the keys ``'H'`` and ``'W'``.
        dilate: two-element sequence; margin added along axis 0 and axis 1.
        *alls_cc: iterables of nodes; only ``node[0]`` and ``node[1]`` are read.

    Returns:
        dict with keys ``'x_min'``, ``'x_max'``, ``'y_min'``, ``'y_max'``.
        The max bounds are exclusive (largest coordinate + 1 before dilation)
        and every bound is clipped to ``[0, H]`` / ``[0, W]``.
    """
    H, W = mesh.graph['H'], mesh.graph['W']
    all_cc = set().union(*alls_cc)
    min_x, min_y, max_x, max_y = H, W, 0, 0
    for node in all_cc:
        min_x = min(min_x, node[0])
        max_x = max(max_x, node[0])
        min_y = min(min_y, node[1])
        max_y = max(max_y, node[1])
    # Turn the inclusive maxima into exclusive slice bounds.
    max_x = max_x + 1
    max_y = max_y + 1
    return {
        'x_min': max(0, min_x - dilate[0]),
        'x_max': min(H, max_x + dilate[0]),
        'y_min': max(0, min_y - dilate[1]),
        'y_max': min(W, max_y + dilate[1]),
    }


def incomplete_node(mesh, edge_maps, info_on_pix):
    """Reconnect sparsely connected, non-synthesised interior nodes.

    A node is processed when it is not flagged ``'synthesis'``, lies strictly
    inside the ``H`` x ``W`` canvas, and has fewer than three non-synthesised
    neighbours. It is reconnected when it has at most one such neighbour, or
    when its first two neighbours are more than one pixel apart on either
    axis. Reconnection adds an edge to the first non-synthesised entry of
    ``info_on_pix`` at each 4-neighbour pixel whose node exists in ``mesh``.

    Args:
        mesh: graph with ``graph['H']``/``graph['W']`` and ``(x, y, depth)`` nodes.
        edge_maps: unused; kept for interface compatibility.
        info_on_pix: mapping ``(x, y) -> list of dicts`` with a ``'depth'`` key.

    Returns:
        The same ``mesh`` object, modified in place.
    """
    H, W = mesh.graph['H'], mesh.graph['W']
    for node in mesh.nodes:
        if mesh.nodes[node].get('synthesis') is True:
            continue
        nes = [xx for xx in mesh.neighbors(node) if mesh.nodes[xx].get('synthesis') is not True]
        if len(nes) >= 3 or not (0 < node[0] < H - 1 and 0 < node[1] < W - 1):
            continue
        if len(nes) > 1:
            ne_a, ne_b = nes[0], nes[1]
            # Two neighbours that are themselves close together: leave the node alone.
            if abs(ne_a[0] - ne_b[0]) <= 1 and abs(ne_a[1] - ne_b[1]) <= 1:
                continue
        four_nes = [(node[0] - 1, node[1]), (node[0] + 1, node[1]),
                    (node[0], node[1] - 1), (node[0], node[1] + 1)]
        for ne in four_nes:
            for info in info_on_pix[(ne[0], ne[1])]:
                ne_node = (ne[0], ne[1], info['depth'])
                if info.get('synthesis') is not True and mesh.has_node(ne_node):
                    mesh.add_edge(node, ne_node)
                    break  # only the first matching layer at this pixel

    return mesh


def _model_device(config):
    """Return ``config["gpu_ids"]`` when it is a non-negative int, else ``"cpu"``."""
    gpu_ids = config["gpu_ids"]
    return gpu_ids if isinstance(gpu_ids, int) and gpu_ids >= 0 else "cpu"


def edge_inpainting(edge_id, context_cc, erode_context_cc, mask_cc, edge_cc, extend_edge_cc,
                    mesh, edge_map, edge_maps_with_id, config, union_size, depth_edge_model, inpaint_iter):
    """Build the edge maps for one edge component and inpaint the edge in the mask.

    The maps from ``get_edge_from_nodes`` are filtered with
    ``filter_irrelevant_edge_new``, cropped to ``union_size`` and converted to
    tensors. The edge model is only run when ``require_depth_edge`` accepts
    the patch and ``inpaint_iter == 0``; otherwise the filtered edge is used
    as the output unchanged.

    Returns:
        ``(edge_dict, end_depth_maps)``; ``edge_dict['output']`` is a full-size
        ``H`` x ``W`` array with the patch result pasted at ``union_size``.
    """
    H, W = mesh.graph['H'], mesh.graph['W']
    edge_dict = get_edge_from_nodes(context_cc, erode_context_cc, mask_cc, edge_cc, extend_edge_cc,
                                    H, W, mesh)
    edge_dict['edge'], end_depth_maps, _ = \
        filter_irrelevant_edge_new(edge_dict['self_edge'] + edge_dict['comp_edge'],
                                   edge_map,
                                   edge_maps_with_id,
                                   edge_id,
                                   edge_dict['context'],
                                   edge_dict['depth'], mesh, context_cc | erode_context_cc, spdb=True)
    patch_keys = ('mask', 'context', 'rgb', 'disp', 'edge')
    patch_edge_dict = dict(zip(patch_keys,
                               crop_maps_by_size(union_size, *[edge_dict[key] for key in patch_keys])))
    tensor_edge_dict = convert2tensor(patch_edge_dict)
    if require_depth_edge(patch_edge_dict['edge'], patch_edge_dict['mask']) and inpaint_iter == 0:
        with torch.no_grad():
            depth_edge_output = depth_edge_model.forward_3P(tensor_edge_dict['mask'],
                                                            tensor_edge_dict['context'],
                                                            tensor_edge_dict['rgb'],
                                                            tensor_edge_dict['disp'],
                                                            tensor_edge_dict['edge'],
                                                            unit_length=128,
                                                            cuda=_model_device(config))
            depth_edge_output = depth_edge_output.cpu()
        # Threshold the prediction inside the mask and keep the known edge as-is.
        tensor_edge_dict['output'] = \
            (depth_edge_output > config['ext_edge_threshold']).float() * tensor_edge_dict['mask'] + \
            tensor_edge_dict['edge']
    else:
        tensor_edge_dict['output'] = tensor_edge_dict['edge']
    patch_edge_dict['output'] = tensor_edge_dict['output'].squeeze().data.cpu().numpy()
    edge_dict['output'] = np.zeros((H, W))
    edge_dict['output'][union_size['x_min']:union_size['x_max'],
                        union_size['y_min']:union_size['y_max']] = patch_edge_dict['output']

    return edge_dict, end_depth_maps


def depth_inpainting(context_cc, extend_context_cc, erode_context_cc, mask_cc, mesh, config, union_size,
                     depth_feat_model, edge_output, given_depth_dict=False, spdb=False):
    """Inpaint depth inside the mask region with ``depth_feat_model``.

    Args:
        given_depth_dict: when not ``False``, used directly instead of building
            the maps with ``get_depth_from_nodes``.
        edge_output: stored as ``depth_dict['edge']`` when not ``None`` (only
            when the dict is built here).
            NOTE(review): ``depth_dict['edge']`` is required below, so a
            ``None`` ``edge_output`` without ``given_depth_dict`` raises
            ``KeyError`` - confirm callers always provide one of them.
        spdb: when ``True``, plot intermediate results and drop into ``pdb``.

    Returns:
        ``depth_dict`` with ``'output'`` set to the smoothed depth inside the
        mask plus the original depth inside the context.
    """
    if given_depth_dict is False:
        depth_dict = get_depth_from_nodes(context_cc | extend_context_cc, erode_context_cc, mask_cc,
                                          mesh.graph['H'], mesh.graph['W'], mesh, config['log_depth'])
        if edge_output is not None:
            depth_dict['edge'] = edge_output
    else:
        depth_dict = given_depth_dict
    patch_depth_dict = dict()
    # The patch 'depth' entry is cropped from the full-size 'real_depth' map.
    patch_depth_dict['mask'], patch_depth_dict['context'], patch_depth_dict['depth'], \
        patch_depth_dict['zero_mean_depth'], patch_depth_dict['edge'] = \
        crop_maps_by_size(union_size, depth_dict['mask'], depth_dict['context'],
                          depth_dict['real_depth'], depth_dict['zero_mean_depth'], depth_dict['edge'])
    tensor_depth_dict = convert2tensor(patch_depth_dict)
    resize_mask = open_small_mask(tensor_depth_dict['mask'], tensor_depth_dict['context'], 3, 41)
    with torch.no_grad():
        depth_output = depth_feat_model.forward_3P(resize_mask,
                                                   tensor_depth_dict['context'],
                                                   tensor_depth_dict['zero_mean_depth'],
                                                   tensor_depth_dict['edge'],
                                                   unit_length=128,
                                                   cuda=_model_device(config))
        depth_output = depth_output.cpu()
    tensor_depth_dict['output'] = torch.exp(depth_output + depth_dict['mean_depth']) * \
        tensor_depth_dict['mask'] + tensor_depth_dict['depth']
    patch_depth_dict['output'] = tensor_depth_dict['output'].data.cpu().numpy().squeeze()
    depth_dict['output'] = np.zeros((mesh.graph['H'], mesh.graph['W']))
    depth_dict['output'][union_size['x_min']:union_size['x_max'],
                         union_size['y_min']:union_size['y_max']] = patch_depth_dict['output']
    blended = depth_dict['output'] * depth_dict['mask'] + depth_dict['depth'] * depth_dict['context']
    depth_output = smooth_cntsyn_gap(blended,
                                     depth_dict['mask'], depth_dict['context'],
                                     init_mask_region=depth_dict['mask'])
    if spdb is True:
        # Debug-only visualisation; blocks on the plot window and then on pdb.
        f, ((ax1, ax2)) = plt.subplots(1, 2, sharex=True, sharey=True)
        ax1.imshow(depth_output * depth_dict['mask'] + depth_dict['depth'])
        ax2.imshow(depth_dict['output'] * depth_dict['mask'] + depth_dict['depth'])
        plt.show()
        import pdb; pdb.set_trace()
    depth_dict['output'] = depth_output * depth_dict['mask'] + depth_dict['depth'] * depth_dict['context']

    return depth_dict


def update_info(mapping_dict, info_on_pix, *meshes):
    """Relabel one node in every mesh and record its new depth in ``info_on_pix``.

    Only the first ``old_node -> new_node`` entry of ``mapping_dict`` is used.

    Returns:
        ``[info_on_pix, *relabelled_meshes]`` in the order the meshes were given.
    """
    old_node, new_node = [*mapping_dict.items()][0]
    rt_meshes = [relabel_node(mesh, mesh.nodes, old_node, new_node) for mesh in meshes]
    x, y, _ = old_node
    info_on_pix[(x, y)][0]['depth'] = new_node[2]

    return [info_on_pix] + rt_meshes


def build_connection(mesh, cur_node, dst_node):
    """Connect ``dst_node`` to ``cur_node`` and, recursively, to its neighbourhood.

    An edge is added when the two nodes are less than 2 apart in Manhattan
    distance. Recursion continues through the neighbours of ``cur_node`` as
    long as ``cur_node`` is within one pixel of ``dst_node`` on both axes.

    Returns:
        The same ``mesh`` object, modified in place.
    """
    dx = abs(cur_node[0] - dst_node[0])
    dy = abs(cur_node[1] - dst_node[1])
    if dx + dy < 2:
        mesh.add_edge(cur_node, dst_node)
    if dx > 1 or dy > 1:
        return mesh
    # Snapshot the neighbours: the recursion adds edges while we iterate.
    for ne_node in [*mesh.neighbors(cur_node)]:
        if mesh.has_edge(ne_node, dst_node) or ne_node == dst_node:
            continue
        mesh = build_connection(mesh, ne_node, dst_node)

    return mesh


def recursive_add_edge(edge_mesh, mesh, info_on_pix, cur_node, mark):
    """Walk ``edge_mesh`` from ``cur_node`` and re-attach nodes marked ``3``.

    For every neighbour of ``cur_node`` in ``edge_mesh`` whose ``mark`` entry
    is ``3``: clear the mark, drop the node's edges in ``mesh``, rebuild them
    with ``build_connection``, replace the node's depth with the mean depth of
    its new neighbours, and recurse from the relabelled node.

    Returns:
        ``(edge_mesh, mesh, mark, info_on_pix)`` after the updates.
    """
    ne_nodes = [(x[0], x[1]) for x in edge_mesh.neighbors(cur_node)]
    for node_xy in ne_nodes:
        node = (node_xy[0], node_xy[1], info_on_pix[node_xy][0]['depth'])
        if mark[node[0], node[1]] != 3:
            continue
        mark[node[0], node[1]] = 0
        mesh.remove_edges_from([(xx, node) for xx in mesh.neighbors(node)])
        mesh = build_connection(mesh, cur_node, node)
        depth_sum, count = 0, 0
        for re_ne in mesh.neighbors(node):
            depth_sum += re_ne[2]
            count += 1.
        # Keep the node's own depth when it ended up with no neighbours.
        re_depth = depth_sum / count if count else node[2]
        re_node = (node_xy[0], node_xy[1], re_depth)
        mapping_dict = {node: re_node}
        info_on_pix, edge_mesh, mesh = update_info(mapping_dict, info_on_pix, edge_mesh, mesh)

        edge_mesh, mesh, mark, info_on_pix = recursive_add_edge(edge_mesh, mesh, info_on_pix, re_node, mark)

    return edge_mesh, mesh, mark, info_on_pix


def resize_for_edge(tensor_dict, largest_size):
    """Return a cloned tensor dict, downscaled so its longer side fits ``largest_size``.

    Nothing is resized when the longer of the last two dimensions of
    ``tensor_dict['edge']`` is already at most ``largest_size``. Otherwise
    ``'mask'``, ``'context'``, ``'edge'`` and ``'rgb'`` are resized bilinearly
    and ``'disp'`` with nearest-neighbour; mask/context/edge are re-binarised
    and edge/disp/rgb are zeroed outside the resized context.
    """
    interpolate = torch.nn.functional.interpolate
    resize_dict = {k: v.clone() for k, v in tensor_dict.items()}
    frac = largest_size / max(resize_dict['edge'].shape[-2:])
    if frac < 1:
        resize_mark = interpolate(torch.cat((resize_dict['mask'], resize_dict['context']), dim=1),
                                  scale_factor=frac,
                                  mode='bilinear')
        resize_dict['mask'] = (resize_mark[:, 0:1] > 0).float()
        # Context survives only where it interpolated to exactly 1 and is not masked.
        resize_dict['context'] = (resize_mark[:, 1:2] == 1).float()
        resize_dict['context'][resize_dict['mask'] > 0] = 0
        edge = interpolate(resize_dict['edge'], scale_factor=frac, mode='bilinear')
        resize_dict['edge'] = (edge > 0).float() * resize_dict['context']
        disp = interpolate(resize_dict['disp'], scale_factor=frac, mode='nearest')
        resize_dict['disp'] = disp * resize_dict['context']
        rgb = interpolate(resize_dict['rgb'], scale_factor=frac, mode='bilinear')
        resize_dict['rgb'] = rgb * resize_dict['context']

    return resize_dict


def get_map_from_nodes(nodes, height, width):
    """Return a ``height`` x ``width`` float map that is 1 at every node position."""
    omap = np.zeros((height, width))
    for n in nodes:
        omap[n[0], n[1]] = 1

    return omap


def get_map_from_ccs(ccs, height, width, condition_input=None, condition=None, real_id=False, id_shift=0):
    """Rasterise a list of connected components into a single map.

    Args:
        ccs: iterable of node collections.
        condition_input: second argument passed to ``condition``.
        condition: ``condition(node, condition_input)``; a node is drawn only
            when it returns a truthy value. Defaults to accepting every node.
        real_id: when ``True``, pixels hold ``component index + id_shift`` on a
            background of ``-1 + id_shift``; otherwise pixels are 1 on a
            background of 0.
        id_shift: offset applied to ids when ``real_id`` is ``True``.
    """
    if condition is None:
        condition = lambda x, condition_input: True

    if real_id is True:
        omap = np.zeros((height, width)) + (-1) + id_shift
    else:
        omap = np.zeros((height, width))
    for cc_id, cc in enumerate(ccs):
        value = cc_id + id_shift if real_id is True else 1
        for n in cc:
            if condition(n, condition_input):
                omap[n[0], n[1]] = value

    return omap


def revise_map_by_nodes(nodes, imap, operation, limit_constr=None):
    """Set (``'+'``) or clear (``'-'``) the node positions on a copy of ``imap``.

    When ``limit_constr`` is given, the change is rejected if the map sum
    becomes larger than it (``'+'``) or smaller than it (``'-'``); in that
    case the original ``imap`` object is returned.

    Returns:
        ``(map, revise_flag)`` where ``revise_flag`` is ``False`` on rejection.

    Raises:
        AssertionError: if ``operation`` is neither ``'+'`` nor ``'-'``.
    """
    assert operation == '+' or operation == '-', "Operation must be '+' (union) or '-' (exclude)"
    fill_value = 1 if operation == '+' else 0
    omap = copy.deepcopy(imap)
    for n in nodes:
        omap[n[0], n[1]] = fill_value
    if limit_constr is not None:
        total = omap.sum()
        violated = total > limit_constr if operation == '+' else total < limit_constr
        if violated:
            return imap, False

    return omap, True


def repaint_info(mesh, cc, x_anchor, y_anchor, source_type):
    """Rasterise node colour or depth into a channel-first patch.

    The patch covers ``x_anchor[0]:x_anchor[1]`` by ``y_anchor[0]:y_anchor[1]``
    and has 3 channels for ``source_type == 'rgb'`` (colour / 255) and 1
    channel otherwise. Depth ``abs(node[2])`` is written only for
    ``source_type == 'd'``; any other value leaves the patch at zero.
    """
    is_rgb = source_type == 'rgb'
    x0, y0 = x_anchor[0], y_anchor[0]
    feat = np.zeros((3 if is_rgb else 1, x_anchor[1] - x0, y_anchor[1] - y0))
    for node in cc:
        if is_rgb:
            feat[:, node[0] - x0, node[1] - y0] = np.array(mesh.nodes[node]['color']) / 255.
        elif source_type == 'd':
            feat[:, node[0] - x0, node[1] - y0] = abs(node[2])

    return feat


def get_context_from_nodes(mesh, cc, H, W, source_type=''):
    """Rasterise node colour (scaled by 1/255) or ``abs`` depth into a map.

    Returns:
        ``(feat, context)``; ``feat`` is ``H`` x ``W`` x 3 when ``source_type``
        contains ``'rgb'`` or ``'color'`` and ``H`` x ``W`` otherwise.
    """
    use_color = 'rgb' in source_type or 'color' in source_type
    feat = np.zeros((H, W, 3)) if use_color else np.zeros((H, W))
    context = np.zeros((H, W))
    for node in cc:
        if use_color:
            feat[node[0], node[1]] = np.array(mesh.nodes[node]['color']) / 255.
            context[node[0], node[1]] = 1
        else:
            # NOTE(review): context stays all-zero on the depth path - looks
            # unintended, behaviour kept; confirm against callers.
            feat[node[0], node[1]] = abs(node[2])

    return feat, context


def get_mask_from_nodes(mesh, cc, H, W):
    """Return an ``H`` x ``W`` map holding ``abs(node[2])`` at every node position."""
    mask = np.zeros((H, W))
    for node in cc:
        mask[node[0], node[1]] = abs(node[2])

    return mask


def get_edge_from_nodes(context_cc, erode_context_cc, mask_cc, edge_cc, extend_edge_cc, H, W, mesh):
    """Rasterise the node sets of one edge component into ``H`` x ``W`` maps.

    Context nodes (``context_cc`` first, then ``erode_context_cc``, so the
    latter wins on shared pixels) fill ``rgb`` (scaled by 1/255), ``disp``
    (absolute value, divided by its maximum), ``depth`` and ``context``.
    ``real_depth`` is ``depth`` overridden by the node attribute
    ``'real_depth'`` where present.

    Returns:
        dict with keys ``'rgb'``, ``'disp'``, ``'depth'``, ``'real_depth'``,
        ``'self_edge'``, ``'context'``, ``'mask'``, ``'fpath_map'``,
        ``'npath_map'``, ``'comp_edge'``, ``'valid_area'``, ``'near_depth'``.
    """
    context = np.zeros((H, W))
    mask = np.zeros((H, W))
    rgb = np.zeros((H, W, 3))
    disp = np.zeros((H, W))
    depth = np.zeros((H, W))
    edge = np.zeros((H, W))
    comp_edge = np.zeros((H, W))
    fpath_map = np.zeros((H, W)) - 1
    npath_map = np.zeros((H, W)) - 1
    near_depth = np.zeros((H, W))
    context_sets = (context_cc, erode_context_cc)
    for cc in context_sets:
        for node in cc:
            attrs = mesh.nodes[node]
            rgb[node[0], node[1]] = np.array(attrs['color'])
            disp[node[0], node[1]] = attrs['disp']
            depth[node[0], node[1]] = node[2]
            context[node[0], node[1]] = 1
    rgb = rgb / 255.
    disp = np.abs(disp)
    max_disp = disp.max()
    if max_disp > 0:
        # Skip the division for an all-zero map; 0 / 0 would fill it with NaN.
        disp = disp / max_disp
    real_depth = depth.copy()
    for cc in context_sets:
        for node in cc:
            node_real_depth = mesh.nodes[node].get('real_depth')
            if node_real_depth is not None:
                real_depth[node[0], node[1]] = node_real_depth
    for node in mask_cc:
        mask[node[0], node[1]] = 1
        near_depth[node[0], node[1]] = node[2]
    for node in edge_cc:
        edge[node[0], node[1]] = 1
    for node in extend_edge_cc:
        comp_edge[node[0], node[1]] = 1
    rt_dict = {'rgb': rgb, 'disp': disp, 'depth': depth, 'real_depth': real_depth, 'self_edge': edge,
               'context': context, 'mask': mask, 'fpath_map': fpath_map, 'npath_map': npath_map,
               'comp_edge': comp_edge, 'valid_area': context + mask, 'near_depth': near_depth}

    return rt_dict


def get_depth_from_maps(context_map, mask_map, depth_map, H, W, log_depth=False):
    """Build the depth-inpainting input dict from dense maps.

    Args:
        context_map, mask_map: maps cast to ``uint8``.
        depth_map: depth values; the absolute value is used.
        H, W: unused; kept for interface compatibility.
        log_depth: when ``True``, ``'zero_mean_depth'`` is the log depth minus
            its mean over the context, restricted to the context; otherwise it
            is the depth itself and ``'mean_depth'`` is 0.

    Returns:
        dict with keys ``'depth'``, ``'real_depth'``, ``'context'``, ``'mask'``,
        ``'mean_depth'``, ``'zero_mean_depth'`` and an all-zero ``'edge'``.
    """
    context = context_map.astype(np.uint8)
    mask = mask_map.astype(np.uint8).copy()
    depth = np.abs(depth_map)
    real_depth = depth.copy()

    if log_depth is True:
        # The epsilon keeps log() finite where the depth is zero.
        log_depth_map = np.log(real_depth + 1e-8) * context
        mean_depth = np.mean(log_depth_map[context > 0])
        zero_mean_depth = (log_depth_map - mean_depth) * context
    else:
        zero_mean_depth = real_depth
        mean_depth = 0
    edge = np.zeros_like(depth)

    rt_dict = {'depth': depth, 'real_depth': real_depth, 'context': context, 'mask': mask,
               'mean_depth': mean_depth, 'zero_mean_depth': zero_mean_depth, 'edge': edge}

    return rt_dict


def get_depth_from_nodes(context_cc, erode_context_cc, mask_cc, H, W, mesh, log_depth=False):
    """Build the depth-inpainting input dict from node sets.

    ``depth`` holds ``abs(node[2])`` for the context nodes (``context_cc``
    first, then ``erode_context_cc``); ``real_depth`` additionally prefers the
    node attribute ``'real_depth'`` where present. ``log_depth`` has the same
    meaning as in ``get_depth_from_maps``.

    Returns:
        dict with keys ``'depth'``, ``'real_depth'``, ``'context'``, ``'mask'``,
        ``'mean_depth'``, ``'zero_mean_depth'`` (no ``'edge'`` entry).
    """
    context = np.zeros((H, W))
    mask = np.zeros((H, W))
    depth = np.zeros((H, W))
    context_sets = (context_cc, erode_context_cc)
    for cc in context_sets:
        for node in cc:
            depth[node[0], node[1]] = node[2]
            context[node[0], node[1]] = 1
    depth = np.abs(depth)
    real_depth = depth.copy()
    for cc in context_sets:
        for node in cc:
            node_real_depth = mesh.nodes[node].get('real_depth')
            if node_real_depth is not None:
                real_depth[node[0], node[1]] = node_real_depth
    real_depth = np.abs(real_depth)
    for node in mask_cc:
        mask[node[0], node[1]] = 1
    if log_depth is True:
        # The epsilon keeps log() finite where the depth is zero.
        log_depth_map = np.log(real_depth + 1e-8) * context
        mean_depth = np.mean(log_depth_map[context > 0])
        zero_mean_depth = (log_depth_map - mean_depth) * context
    else:
        zero_mean_depth = real_depth
        mean_depth = 0

    rt_dict = {'depth': depth, 'real_depth': real_depth, 'context': context, 'mask': mask,
               'mean_depth': mean_depth, 'zero_mean_depth': zero_mean_depth}

    return rt_dict


def get_rgb_from_nodes(context_cc, erode_context_cc, mask_cc, H, W, mesh):
    """Build the colour-inpainting input dict from node sets.

    Colour (scaled by 1/255) and ``context`` come from ``context_cc`` only.
    ``mask`` is set for ``mask_cc`` and for ``erode_context_cc``; the latter is
    also recorded in ``'erode'``.

    Returns:
        dict with keys ``'rgb'``, ``'context'``, ``'mask'``, ``'erode'``.
    """
    context = np.zeros((H, W))
    mask = np.zeros((H, W))
    rgb = np.zeros((H, W, 3))
    erode_context = np.zeros((H, W))
    for node in context_cc:
        rgb[node[0], node[1]] = np.array(mesh.nodes[node]['color'])
        context[node[0], node[1]] = 1
    rgb = rgb / 255.
    for node in mask_cc:
        mask[node[0], node[1]] = 1
    for node in erode_context_cc:
        erode_context[node[0], node[1]] = 1
        mask[node[0], node[1]] = 1
    rt_dict = {'rgb': rgb, 'context': context, 'mask': mask,
               'erode': erode_context}

    return rt_dict


def crop_maps_by_size(size, *imaps):
    """Return copies of every map cropped to ``size['x_min']:size['x_max']``, ``size['y_min']:size['y_max']``."""
    rows = slice(size['x_min'], size['x_max'])
    cols = slice(size['y_min'], size['y_max'])

    return [imap[rows, cols].copy() for imap in imaps]


def convert2tensor(input_dict):
    """Convert every map in ``input_dict`` to a 4-D float tensor with leading batch dim 1.

    Entries whose key contains ``'rgb'`` or ``'color'`` are permuted from
    channel-last to channel-first; all others get a singleton channel dim.
    """
    rt_dict = {}
    for key, value in input_dict.items():
        tensor = torch.FloatTensor(value)
        if 'rgb' in key or 'color' in key:
            rt_dict[key] = tensor.permute(2, 0, 1)[None, ...]
        else:
            rt_dict[key] = tensor[None, None, ...]

    return rt_dict