3D_Photo_Inpainting / mesh_tools.py
Saini
init
0b9f920
import os
import numpy as np
try:
import cynetworkx as netx
except ImportError:
import networkx as netx
import json
import scipy.misc as misc
#import OpenEXR
import scipy.signal as signal
import matplotlib.pyplot as plt
import cv2
import scipy.misc as misc
from skimage import io
from functools import partial
from vispy import scene, io
from vispy.scene import visuals
from functools import reduce
# from moviepy.editor import ImageSequenceClip
import scipy.misc as misc
from vispy.visuals.filters import Alpha
import cv2
from skimage.transform import resize
import copy
import torch
import os
from utils import refine_depth_around_edge, smooth_cntsyn_gap
from utils import require_depth_edge, filter_irrelevant_edge_new, open_small_mask
from skimage.feature import canny
from scipy import ndimage
import time
import transforms3d
def relabel_node(mesh, nodes, cur_node, new_node):
    """Move a node to a new key, carrying its attributes and neighbours over.

    ``nodes`` is indexed like ``mesh.nodes`` (node -> attribute mapping).
    Edge attributes are not copied: new edges are created bare. Returns
    ``mesh`` (mutated in place); a no-op when both keys are equal.
    """
    if cur_node != new_node:
        mesh.add_node(new_node)
        # Copy every attribute of the old node onto the new one.
        nodes[new_node].update(nodes[cur_node])
        for neighbor in mesh.neighbors(cur_node):
            mesh.add_edge(new_node, neighbor)
        mesh.remove_node(cur_node)
    return mesh
def filter_edge(mesh, edge_ccs, config, invalid=False):
    """Blank out edge components depending on whether they have far-side context.

    For every component the 'far' nodes of its members are collected. When
    the component has more than two members, a far node that is the only
    one collected for its 'edge_id' is discarded again. With ``invalid``
    not True, components whose collected set is non-empty are kept; with
    ``invalid is True`` the ones with an empty set are kept. Every other
    slot becomes an empty set, so the result is index-aligned with
    ``edge_ccs``.
    """
    node_attrs = mesh.nodes
    context_per_cc = []
    for component in edge_ccs:
        far_context = set()
        context_per_cc.append(far_context)
        if config['context_thickness'] == 0:
            continue
        far_by_edge_id = {}
        for member in component:
            far_list = node_attrs[member].get('far')
            if far_list is None:
                continue
            for far in far_list:
                far_context.add(far)
                far_edge_id = node_attrs[far].get('edge_id')
                if far_edge_id is not None:
                    far_by_edge_id.setdefault(far_edge_id, set()).add(far)
        if len(component) > 2:
            # Drop far nodes that are the lone representative of their edge.
            for group in list(far_by_edge_id.values()):
                if len(group) == 1:
                    far_context.remove(next(iter(group)))
    keep_when_empty = invalid is True
    return [component if (len(far_context) == 0) == keep_when_empty else set()
            for component, far_context in zip(edge_ccs, context_per_cc)]
def extrapolate(global_mesh,
                info_on_pix,
                image,
                depth,
                other_edge_with_id,
                edge_map,
                edge_ccs,
                depth_edge_model,
                depth_feat_model,
                rgb_feat_model,
                config,
                direc='right-up'):
    """Synthesize one border band (edge, depth, color) and stitch it into the mesh.

    The band is selected by ``direc``: one of 'up', 'down', 'left', 'right'
    (no dash) or a dashed corner such as 'left-up' / 'right-down'. Offsets
    and the un-extended image size are read from ``global_mesh.graph``
    ('hoffset', 'woffset', 'noext_H', 'noext_W'); the width of the known
    strip used as context is ``config['context_thickness']``.

    Steps, as implemented below:
      1. build mask/context crops for the band;
      2. run ``depth_edge_model.forward_3P`` and threshold it with
         ``config['ext_edge_threshold']``;
      3. trace a "near" path along each predicted edge component that
         touches exactly one known depth-edge pixel, and a parallel "far"
         path starting from a 'far' node of that pixel;
      4. run ``depth_feat_model`` / ``rgb_feat_model`` and write the results
         into ``depth`` and ``image`` IN PLACE;
      5. relabel the band's mesh nodes through ``update_info`` and cut mesh
         edges between near and far path pixels, recording 'far' / 'near' /
         'edge_id' on the nodes and in ``edge_ccs``.

    Returns ``(info_on_pix, global_mesh, image, depth, edge_ccs)``.

    Raises:
        ValueError: if ``direc`` matches none of the supported directions.

    NOTE(review): node depths in the mesh appear to be stored negated
    (``-abs(depth)``) — confirm against the code that builds the mesh.
    """
    h_off, w_off = global_mesh.graph['hoffset'], global_mesh.graph['woffset']
    noext_H, noext_W = global_mesh.graph['noext_H'], global_mesh.graph['noext_W']
    side = direc.lower()
    thick = config['context_thickness']
    diagonal = "-" in side
    # Anchors are [row_start, row_stop, col_start, col_stop].
    context_anchor = None
    valid_line_anchor = None
    if "up" in side and not diagonal:
        all_anchor = [0, h_off + thick, w_off, w_off + noext_W]
        mask_anchor = [0, h_off, w_off, w_off + noext_W]
        context_anchor = [h_off, h_off + thick, w_off, w_off + noext_W]
        valid_line_anchor = [h_off, h_off + 1, w_off, w_off + noext_W]
    elif "down" in side and not diagonal:
        all_anchor = [h_off + noext_H - thick, 2 * h_off + noext_H, w_off, w_off + noext_W]
        mask_anchor = [h_off + noext_H, 2 * h_off + noext_H, w_off, w_off + noext_W]
        context_anchor = [h_off + noext_H - thick, h_off + noext_H, w_off, w_off + noext_W]
        valid_line_anchor = [h_off + noext_H - 1, h_off + noext_H, w_off, w_off + noext_W]
    elif "left" in side and not diagonal:
        all_anchor = [h_off, h_off + noext_H, 0, w_off + thick]
        mask_anchor = [h_off, h_off + noext_H, 0, w_off]
        context_anchor = [h_off, h_off + noext_H, w_off, w_off + thick]
        valid_line_anchor = [h_off, h_off + noext_H, w_off, w_off + 1]
    elif "right" in side and not diagonal:
        all_anchor = [h_off, h_off + noext_H, w_off + noext_W - thick, 2 * w_off + noext_W]
        mask_anchor = [h_off, h_off + noext_H, w_off + noext_W, 2 * w_off + noext_W]
        context_anchor = [h_off, h_off + noext_H, w_off + noext_W - thick, w_off + noext_W]
        valid_line_anchor = [h_off, h_off + noext_H, w_off + noext_W - 1, w_off + noext_W]
    elif "left" in side and "up" in side and diagonal:
        all_anchor = [0, h_off + thick, 0, w_off + thick]
        mask_anchor = [0, h_off, 0, w_off]
    elif "left" in side and "down" in side and diagonal:
        all_anchor = [h_off + noext_H - thick, 2 * h_off + noext_H, 0, w_off + thick]
        mask_anchor = [h_off + noext_H, 2 * h_off + noext_H, 0, w_off]
    elif "right" in side and "up" in side and diagonal:
        all_anchor = [0, h_off + thick, w_off + noext_W - thick, 2 * w_off + noext_W]
        mask_anchor = [0, h_off, w_off + noext_W, 2 * w_off + noext_W]
    elif "right" in side and "down" in side and diagonal:
        all_anchor = [h_off + noext_H - thick, 2 * h_off + noext_H,
                      w_off + noext_W - thick, 2 * w_off + noext_W]
        mask_anchor = [h_off + noext_H, 2 * h_off + noext_H, w_off + noext_W, 2 * w_off + noext_W]
    else:
        raise ValueError("Unsupported extrapolation direction: %r" % (direc,))
    if context_anchor is None:
        # Corner patches: everything in the crop that is not mask is context.
        valid_anchor = all_anchor
    else:
        valid_anchor = [min(mask_anchor[0], context_anchor[0]), max(mask_anchor[1], context_anchor[1]),
                        min(mask_anchor[2], context_anchor[2]), max(mask_anchor[3], context_anchor[3])]
    r0, r1, c0, c1 = all_anchor

    # ---- 1. mask / context crops and normalized network inputs ----
    global_mask = np.zeros_like(depth)
    global_mask[mask_anchor[0]:mask_anchor[1], mask_anchor[2]:mask_anchor[3]] = 1
    mask = global_mask[valid_anchor[0]:valid_anchor[1], valid_anchor[2]:valid_anchor[3]] * 1
    context = 1 - mask
    valid_area = mask + context
    input_rgb = image[valid_anchor[0]:valid_anchor[1], valid_anchor[2]:valid_anchor[3]] / 255. * context[..., None]
    input_depth = depth[valid_anchor[0]:valid_anchor[1], valid_anchor[2]:valid_anchor[3]] * context
    log_depth = np.log(input_depth + 1e-8)
    log_depth[mask > 0] = 0
    input_mean_depth = np.mean(log_depth[context > 0])
    input_zero_mean_depth = (log_depth - input_mean_depth) * context
    input_disp = 1./np.abs(input_depth)
    input_disp[mask > 0] = 0
    input_disp = input_disp / input_disp.max()
    # One-pixel-wide line of known pixels adjacent to the band (none for corners).
    valid_line = np.zeros_like(depth)
    if valid_line_anchor is not None:
        valid_line[valid_line_anchor[0]:valid_line_anchor[1], valid_line_anchor[2]:valid_line_anchor[3]] = 1
    valid_line = valid_line[r0:r1, c0:c1]
    input_edge_map = edge_map[r0:r1, c0:c1] * context
    input_other_edge_with_id = other_edge_with_id[r0:r1, c0:c1]
    end_depth_maps = ((valid_line * input_edge_map) > 0) * input_depth
    if isinstance(config["gpu_ids"], int) and (config["gpu_ids"] >= 0):
        device = config["gpu_ids"]
    else:
        device = "cpu"
    valid_edge_ids = sorted(list(input_other_edge_with_id[(valid_line * input_edge_map) > 0]))
    valid_edge_ids = valid_edge_ids[1:] if (len(valid_edge_ids) > 0 and valid_edge_ids[0] == -1) else valid_edge_ids
    edge = reduce(lambda x, y: (x + (input_other_edge_with_id == y).astype(np.uint8)).clip(0, 1),
                  [np.zeros_like(mask)] + list(valid_edge_ids))
    t_edge = torch.FloatTensor(edge).to(device)[None, None, ...]
    t_rgb = torch.FloatTensor(input_rgb).to(device).permute(2, 0, 1).unsqueeze(0)
    t_mask = torch.FloatTensor(mask).to(device)[None, None, ...]
    t_context = torch.FloatTensor(context).to(device)[None, None, ...]
    t_disp = torch.FloatTensor(input_disp).to(device)[None, None, ...]
    t_depth_zero_mean_depth = torch.FloatTensor(input_zero_mean_depth).to(device)[None, None, ...]

    # ---- 2. predict depth edges inside the band ----
    depth_edge_output = depth_edge_model.forward_3P(t_mask, t_context, t_rgb, t_disp, t_edge, unit_length=128,
                                                    cuda=device)
    t_output_edge = (depth_edge_output > config['ext_edge_threshold']).float() * t_mask + t_edge
    output_raw_edge = t_output_edge.data.cpu().numpy().squeeze()

    # ---- 3. trace near / far paths along predicted edges ----
    # Local 8-connected graph over predicted edge pixels; pixels that sit on
    # a known depth edge are tagged 'cnt' together with their depth.
    mesh = netx.Graph()
    hxs, hys = np.where(output_raw_edge * mask > 0)
    for hx, hy in zip(hxs, hys):
        node = (hx, hy)
        mesh.add_node((hx, hy))
        eight_nes = [ne for ne in [(hx + 1, hy), (hx - 1, hy), (hx, hy + 1), (hx, hy - 1),
                                   (hx + 1, hy + 1), (hx - 1, hy - 1), (hx - 1, hy + 1), (hx + 1, hy - 1)]
                     if 0 <= ne[0] < output_raw_edge.shape[0] and 0 <= ne[1] < output_raw_edge.shape[1]
                     and 0 < output_raw_edge[ne[0], ne[1]]]
        for ne in eight_nes:
            mesh.add_edge(node, ne, length=np.hypot(ne[0] - hx, ne[1] - hy))
            if end_depth_maps[ne[0], ne[1]] != 0:
                mesh.nodes[ne[0], ne[1]]['cnt'] = True
                mesh.nodes[ne[0], ne[1]]['depth'] = end_depth_maps[ne[0], ne[1]]
    ccs = [*netx.connected_components(mesh)]
    end_pts = []
    for cc in ccs:
        end_pts.append(set())
        for node in cc:
            if mesh.nodes[node].get('cnt') is not None:
                end_pts[-1].add((node[0], node[1], mesh.nodes[node]['depth']))
    # Both maps hold -1 where empty, otherwise the edge id of the path.
    fpath_map = np.zeros_like(output_raw_edge) - 1
    npath_map = np.zeros_like(output_raw_edge) - 1
    for end_pt, cc in zip(end_pts, ccs):
        # Only components attached to exactly one known edge pixel are traced.
        if len(end_pt) != 1:
            continue
        sub_mesh = mesh.subgraph(list(cc)).copy()
        pnodes = netx.periphery(sub_mesh)
        ends = [*end_pt]
        anchor_key = (ends[0][0] + r0, ends[0][1] + c0, -ends[0][2])
        edge_id = global_mesh.nodes[anchor_key]['edge_id']
        # Walk from the known end point to the periphery node farthest from it.
        pnodes = sorted(pnodes,
                        key=lambda x: np.hypot((x[0] - ends[0][0]), (x[1] - ends[0][1])),
                        reverse=True)[0]
        npath = [*netx.shortest_path(sub_mesh, (ends[0][0], ends[0][1]), pnodes, weight='length')]
        for np_node in npath:
            npath_map[np_node[0], np_node[1]] = edge_id
        fpath = []
        fnodes = global_mesh.nodes[anchor_key].get('far')
        if fnodes is None:
            # Previously dropped into pdb here; now just report and leave the
            # far path empty so batch runs are not blocked.
            print("None far")
        else:
            fnodes = [(xx[0] - r0, xx[1] - c0, xx[2]) for xx in fnodes]
            dmask = mask + 0
            did = 0
            ffnode = []
            # Grow the mask up to three times until it reaches a far node
            # that lies in the context region.
            while True:
                did += 1
                dmask = cv2.dilate(dmask, np.ones((3, 3)), iterations=1)
                if did > 3:
                    break
                ffnode = [fnode for fnode in fnodes
                          if (dmask[fnode[0], fnode[1]] > 0 and mask[fnode[0], fnode[1]] == 0)]
                if len(ffnode) > 0:
                    fnode = ffnode[0]
                    break
            if len(ffnode) == 0:
                continue
            fpath.append((fnode[0], fnode[1]))
            # Defined up front so the check after the loop is safe even if
            # the loop body never runs.
            step = -1
            for step in range(0, len(npath) - 1):
                # Replay the near path's step direction from the far path's tip.
                parr = (npath[step + 1][0] - npath[step][0], npath[step + 1][1] - npath[step][1])
                new_loc = (fpath[-1][0] + parr[0], fpath[-1][1] + parr[1])
                new_loc_nes = [xx for xx in [(new_loc[0] + 1, new_loc[1]), (new_loc[0] - 1, new_loc[1]),
                                             (new_loc[0], new_loc[1] + 1), (new_loc[0], new_loc[1] - 1)]
                               if xx[0] >= 0 and xx[0] < fpath_map.shape[0]
                               and xx[1] >= 0 and xx[1] < fpath_map.shape[1]]
                # Stops when any 4-neighbour already carries a far path, and
                # also when fewer than four neighbours are inside the patch.
                if np.sum([fpath_map[nlne[0], nlne[1]] for nlne in new_loc_nes]) != -4:
                    break
                if npath_map[new_loc[0], new_loc[1]] != -1:
                    if npath_map[new_loc[0], new_loc[1]] != edge_id:
                        break
                    else:
                        continue
                if valid_area[new_loc[0], new_loc[1]] == 0:
                    break
                new_loc_nes_eight = [xx for xx in [(new_loc[0] + 1, new_loc[1]), (new_loc[0] - 1, new_loc[1]),
                                                   (new_loc[0], new_loc[1] + 1), (new_loc[0], new_loc[1] - 1),
                                                   (new_loc[0] + 1, new_loc[1] + 1), (new_loc[0] + 1, new_loc[1] - 1),
                                                   (new_loc[0] - 1, new_loc[1] - 1), (new_loc[0] - 1, new_loc[1] + 1)]
                                     if xx[0] >= 0 and xx[0] < fpath_map.shape[0]
                                     and xx[1] >= 0 and xx[1] < fpath_map.shape[1]]
                # The far path must stay adjacent to its own near path.
                if np.sum([int(npath_map[nlne[0], nlne[1]] == edge_id) for nlne in new_loc_nes_eight]) == 0:
                    break
                fpath.append((fpath[-1][0] + parr[0], fpath[-1][1] + parr[1]))
            if step != len(npath) - 2:
                # The far path stopped early: erase the unmatched tail of the near path.
                for xx in npath[step + 1:]:
                    if npath_map[xx[0], xx[1]] == edge_id:
                        npath_map[xx[0], xx[1]] = -1
        if len(fpath) > 0:
            for fp_node in fpath:
                fpath_map[fp_node[0], fp_node[1]] = edge_id

    # ---- 4. inpaint depth and color, write back in place ----
    update_edge = (npath_map > -1) * mask + edge
    t_update_edge = torch.FloatTensor(update_edge).to(device)[None, None, ...]
    depth_output = depth_feat_model.forward_3P(t_mask, t_context, t_depth_zero_mean_depth, t_update_edge,
                                               unit_length=128, cuda=device)
    depth_output = depth_output.cpu().data.numpy().squeeze()
    # Undo the zero-mean log normalization applied to the network input.
    depth_output = np.exp(depth_output + input_mean_depth) * mask
    for near_id in np.unique(npath_map[npath_map > -1]):
        depth_output = refine_depth_around_edge(depth_output.copy(),
                                                (fpath_map == near_id).astype(np.uint8) * mask,
                                                (fpath_map == near_id).astype(np.uint8),
                                                (npath_map == near_id).astype(np.uint8) * mask,
                                                mask.copy(),
                                                np.zeros_like(mask),
                                                config)
    rgb_output = rgb_feat_model.forward_3P(t_mask, t_context, t_rgb, t_update_edge, unit_length=128,
                                           cuda=device)
    if config.get('gray_image') is True:
        rgb_output = rgb_output.mean(1, keepdim=True).repeat((1, 3, 1, 1))
    rgb_output = ((rgb_output.squeeze().data.cpu().permute(1, 2, 0).numpy() * mask[..., None] + input_rgb) * 255).astype(np.uint8)
    image[r0:r1, c0:c1][mask > 0] = rgb_output[mask > 0]
    depth[r0:r1, c0:c1][mask > 0] = depth_output[mask > 0]

    # ---- 5. update the global mesh ----
    def four_nes_of(px, py):
        # In-patch 4-neighbours of a patch-local pixel.
        return [xx for xx in [(px + 1, py), (px - 1, py), (px, py + 1), (px, py - 1)]
                if 0 <= xx[0] < fpath_map.shape[0] and 0 <= xx[1] < fpath_map.shape[1]]

    def to_global(px, py):
        # Mesh key (global row, global col, stored depth) of a patch-local pixel.
        gx, gy = px + r0, py + c0
        return (gx, gy, info_on_pix[(gx, gy)][0]['depth'])

    def cut_between(src_map, dst_map):
        # Remove mesh edges joining a src-path pixel and a dst-path pixel of the same id.
        pxs, pys = np.where(src_map > -1)
        for px, py in zip(pxs, pys):
            path_id = src_map[px, py]
            for nex, ney in four_nes_of(px, py):
                if dst_map[nex, ney] == path_id:
                    na, nb = to_global(px, py), to_global(nex, ney)
                    if global_mesh.has_edge(na, nb):
                        global_mesh.remove_edge(na, nb)

    cut_between(npath_map, fpath_map)
    cut_between(fpath_map, npath_map)
    # Relabel every band pixel from its depth-0 key to its synthesized depth.
    nxs, nys = np.where(mask > 0)
    for x, y in zip(nxs, nys):
        x = x + r0
        y = y + c0
        cur_node = (x, y, 0)
        new_node = (x, y, -abs(depth[x, y]))
        disp = 1. / -abs(depth[x, y])
        mapping_dict = {cur_node: new_node}
        info_on_pix, global_mesh = update_info(mapping_dict, info_on_pix, global_mesh)
        global_mesh.nodes[new_node]['color'] = image[x, y]
        global_mesh.nodes[new_node]['old_color'] = image[x, y]
        global_mesh.nodes[new_node]['disp'] = disp
        info_on_pix[(x, y)][0]['depth'] = -abs(depth[x, y])
        info_on_pix[(x, y)][0]['disp'] = disp
        info_on_pix[(x, y)][0]['color'] = image[x, y]
    # Near-path nodes: record their far neighbours and register the edge id.
    nxs, nys = np.where((npath_map > -1))
    for nx, ny in zip(nxs, nys):
        self_node = to_global(nx, ny)
        if global_mesh.has_node(self_node) is False:
            break
        n_id = int(round(npath_map[nx, ny]))
        for nex, ney in four_nes_of(nx, ny):
            ne_node = to_global(nex, ney)
            if global_mesh.has_node(ne_node) is False:
                continue
            if fpath_map[nex, ney] == n_id:
                if global_mesh.nodes[self_node].get('edge_id') is None:
                    global_mesh.nodes[self_node]['edge_id'] = n_id
                    edge_ccs[n_id].add(self_node)
                    info_on_pix[(self_node[0], self_node[1])][0]['edge_id'] = n_id
                if global_mesh.has_edge(self_node, ne_node) is True:
                    global_mesh.remove_edge(self_node, ne_node)
                if global_mesh.nodes[self_node].get('far') is None:
                    global_mesh.nodes[self_node]['far'] = []
                global_mesh.nodes[self_node]['far'].append(ne_node)
    # Map each far path onto an id already present in other_edge_with_id.
    global_fpath_map = np.zeros_like(other_edge_with_id) - 1
    global_fpath_map[r0:r1, c0:c1] = fpath_map
    fpath_ids = np.unique(global_fpath_map)
    fpath_ids = fpath_ids[1:] if fpath_ids.shape[0] > 0 and fpath_ids[0] == -1 else []
    fpath_real_id_map = np.zeros_like(global_fpath_map) - 1
    for fpath_id in fpath_ids:
        # `int` replaces the `np.int` alias, which NumPy 1.24 removed.
        real_ids = np.unique(((global_fpath_map == fpath_id).astype(int) * (other_edge_with_id + 1)) - 1)
        real_ids = real_ids[real_ids > -1].astype(int)
        if real_ids.size == 0:
            # The far path overlaps no labelled edge pixel; leave it at -1
            # instead of failing on an empty bincount.
            continue
        fpath_real_id_map[global_fpath_map == fpath_id] = np.bincount(real_ids).argmax()
    # Far-path nodes: record their near neighbours and register the edge id.
    nxs, nys = np.where((fpath_map > -1))
    for nx, ny in zip(nxs, nys):
        self_node = to_global(nx, ny)
        n_id = fpath_map[nx, ny]
        for nex, ney in four_nes_of(nx, ny):
            ne_node = to_global(nex, ney)
            if global_mesh.has_node(ne_node) is False:
                continue
            if npath_map[nex, ney] == n_id or global_mesh.nodes[ne_node].get('edge_id') == n_id:
                if global_mesh.has_edge(self_node, ne_node) is True:
                    global_mesh.remove_edge(self_node, ne_node)
                if global_mesh.nodes[self_node].get('near') is None:
                    global_mesh.nodes[self_node]['near'] = []
                if global_mesh.nodes[self_node].get('edge_id') is None:
                    f_id = int(round(fpath_real_id_map[self_node[0], self_node[1]]))
                    # -1 means no real id was found; indexing edge_ccs[-1]
                    # would silently hit the last component.
                    if f_id >= 0:
                        global_mesh.nodes[self_node]['edge_id'] = f_id
                        info_on_pix[(self_node[0], self_node[1])][0]['edge_id'] = f_id
                        edge_ccs[f_id].add(self_node)
                global_mesh.nodes[self_node]['near'].append(ne_node)
    return info_on_pix, global_mesh, image, depth, edge_ccs
# for edge_cc in edge_ccs:
# for edge_node in edge_cc:
# edge_ccs
# context_ccs, mask_ccs, broken_mask_ccs, edge_ccs, erode_context_ccs, init_mask_connect, edge_maps, extend_context_ccs, extend_edge_ccs
def get_valid_size(imap):
    """Return the bounding box of the rows/columns of ``imap`` with a positive sum.

    The result is a dict with keys 'x_max', 'y_max', 'x_min', 'y_min';
    x indexes axis 0 and y axis 1, and the *_max values are exclusive.
    """
    occupied_rows = np.where(imap.sum(1).squeeze() > 0)[0]
    row_stop = occupied_rows.max() + 1
    row_start = occupied_rows.min()
    occupied_cols = np.where(imap.sum(0).squeeze() > 0)[0]
    col_stop = occupied_cols.max() + 1
    col_start = occupied_cols.min()
    return {'x_max': row_stop, 'y_max': col_stop, 'x_min': row_start, 'y_min': col_start}
def dilate_valid_size(isize_dict, imap, dilate=(0, 0)):
    """Return a copy of a bounding box grown by ``dilate`` and clipped to ``imap``.

    Args:
        isize_dict: box with keys 'x_min', 'x_max', 'y_min', 'y_max'
            (x along axis 0, y along axis 1). It is not modified.
        imap: array whose ``shape[0]`` / ``shape[1]`` bound the x / y extents.
        dilate: ``(dx, dy)`` margins added on both sides of each axis.

    Returns:
        A new dict with the enlarged, clipped extents.
    """
    dilate_x, dilate_y = dilate[0], dilate[1]
    osize_dict = copy.deepcopy(isize_dict)
    osize_dict['x_min'] = max(0, osize_dict['x_min'] - dilate_x)
    osize_dict['x_max'] = min(imap.shape[0], osize_dict['x_max'] + dilate_x)
    # y_min used to be shrunk by dilate[0]; it must use the y margin like y_max.
    osize_dict['y_min'] = max(0, osize_dict['y_min'] - dilate_y)
    osize_dict['y_max'] = min(imap.shape[1], osize_dict['y_max'] + dilate_y)
    return osize_dict
def size_operation(size_a, size_b, operation):
    """Combine two bounding-box dicts.

    Only '+' (union) is implemented: the result takes the smaller of each
    *_min and the larger of each *_max. '-' passes the first check but then
    fails the second assertion.
    """
    assert operation in ('+', '-'), "Operation must be '+' (union) or '-' (exclude)"
    osize = {}
    if operation == '+':
        for key, pick in (('x_min', min), ('y_min', min), ('x_max', max), ('y_max', max)):
            osize[key] = pick(size_a[key], size_b[key])
    assert operation != '-', "Operation '-' is undefined !"
    return osize
def fill_dummy_bord(mesh, info_on_pix, image, depth, config):
    """Add placeholder nodes for every pixel outside the original image area.

    The original area is the rectangle given by ``mesh.graph`` keys
    'hoffset', 'woffset', 'noext_H', 'noext_W'. Each pixel of ``depth``
    outside it gets a node ``(x, y, 0)`` with black color, disparity 0 and
    ``ext_pixel=True``, a matching single-entry record in ``info_on_pix``,
    and mesh edges to those 4-neighbours that already have a record.

    ``image`` and ``config`` are accepted for interface compatibility and
    are not used. Returns ``(mesh, info_on_pix)``, both mutated in place.
    """
    context = np.zeros_like(depth).astype(np.uint8)
    context[mesh.graph['hoffset']:mesh.graph['hoffset'] + mesh.graph['noext_H'],
            mesh.graph['woffset']:mesh.graph['woffset'] + mesh.graph['noext_W']] = 1
    mask = 1 - context
    xs, ys = np.where(mask > 0)
    cur_depth = 0
    cur_disp = 0
    color = [0, 0, 0]
    for x, y in zip(xs, ys):
        cur_node = (x, y, cur_depth)
        mesh.add_node(cur_node, color=color,
                      synthesis=False,
                      disp=cur_disp,
                      cc_id=set(),
                      ext_pixel=True)
        info_on_pix[(x, y)] = [{'depth': cur_depth,
                                'color': mesh.nodes[(x, y, cur_depth)]['color'],
                                'synthesis': False,
                                'disp': mesh.nodes[cur_node]['disp'],
                                'ext_pixel': True}]
        # The bounds check is on the neighbour (xx, yy); it used to test the
        # current pixel, which is always inside the image.
        four_nes = [(xx, yy) for xx, yy in [(x + 1, y), (x - 1, y), (x, y + 1), (x, y - 1)]
                    if 0 <= xx < mesh.graph['H'] and 0 <= yy < mesh.graph['W']
                    and info_on_pix.get((xx, yy)) is not None]
        for ne in four_nes:
            mesh.add_edge(cur_node, (ne[0], ne[1], info_on_pix[(ne[0], ne[1])][0]['depth']))
    return mesh, info_on_pix
def enlarge_border(mesh, info_on_pix, depth, image, config):
    """Record extrapolation offsets and full-frame border limits on the mesh.

    Writes 'hoffset' / 'woffset' (both ``config['extrapolation_thickness']``)
    and 'bord_up' / 'bord_left' / 'bord_down' / 'bord_right' (0, 0, H, W)
    into ``mesh.graph``. ``info_on_pix``, ``depth`` and ``image`` are
    returned untouched; no padding is performed here.
    """
    graph = mesh.graph
    graph['hoffset'] = config['extrapolation_thickness']
    graph['woffset'] = config['extrapolation_thickness']
    full_h, full_w = graph['H'], graph['W']
    graph['bord_up'] = 0
    graph['bord_left'] = 0
    graph['bord_down'] = full_h
    graph['bord_right'] = full_w
    return mesh, info_on_pix, depth, image
def fill_missing_node(mesh, info_on_pix, image, depth):
    """Create a node for every pixel inside the border box that has no record.

    The box is ``mesh.graph['bord_up':'bord_down'] x ['bord_left':'bord_right']``.
    A missing pixel gets the mean stored depth of its 4-neighbours that have
    a record, or ``-abs(depth[x, y])`` when none has one. ``depth`` is
    updated in place with the absolute value.

    Returns ``(mesh, info_on_pix, depth)``.
    """
    for x in range(mesh.graph['bord_up'], mesh.graph['bord_down']):
        for y in range(mesh.graph['bord_left'], mesh.graph['bord_right']):
            if info_on_pix.get((x, y)) is not None:
                continue
            # A pdb breakpoint used to sit here; it blocked unattended runs.
            print("fill missing node = ", x, y)
            re_depth, re_count = 0, 0
            for ne in [(x + 1, y), (x - 1, y), (x, y + 1), (x, y - 1)]:
                if info_on_pix.get(ne) is not None:
                    re_depth += info_on_pix[ne][0]['depth']
                    re_count += 1
            if re_count == 0:
                re_depth = -abs(depth[x, y])
            else:
                re_depth = re_depth / re_count
            depth[x, y] = abs(re_depth)
            # Zero depth (as written by fill_dummy_bord) maps to disparity 0
            # instead of dividing by zero.
            re_disp = 1. / re_depth if re_depth != 0 else 0
            info_on_pix[(x, y)] = [{'depth': re_depth,
                                    'color': image[x, y],
                                    'synthesis': False,
                                    'disp': re_disp}]
            mesh.add_node((x, y, re_depth), color=image[x, y],
                          synthesis=False,
                          disp=re_disp,
                          cc_id=set())
    return mesh, info_on_pix, depth
def refresh_bord_depth(mesh, info_on_pix, image, depth):
    """Re-derive depth and connectivity for pixels on the border rectangle.

    The rectangle is given by ``mesh.graph`` keys 'bord_up', 'bord_down',
    'bord_left', 'bord_right' (down/right exclusive).

    * Each non-corner border pixel takes the stored depth of the pixel one
      step inward and is re-wired to that pixel plus the shifted positions
      of that pixel's neighbours which are diagonal to the border pixel.
    * Each corner takes the mean depth of its in-rectangle 4-neighbours that
      have a record and is wired to them.
    * ``depth`` is then updated in place with the absolute stored depths, and
      every non-corner border node gets fresh 'far' / 'near' lists holding
      the in-rectangle 4-neighbours it is NOT connected to.

    Existing nodes are relabelled through ``update_info``.

    Returns ``(mesh, info_on_pix, depth)``.

    Raises:
        ValueError: if an inward neighbour has no record, or a corner with
            no record of its own has no neighbour to clone one from.
    """
    bord_up, bord_down = mesh.graph['bord_up'], mesh.graph['bord_down']
    bord_left, bord_right = mesh.graph['bord_left'], mesh.graph['bord_right']
    corner_nodes = [(bord_up, bord_left),
                    (bord_up, bord_right - 1),
                    (bord_down - 1, bord_left),
                    (bord_down - 1, bord_right - 1)]
    bord_nodes = []
    bord_nodes += [(bord_up, xx) for xx in range(bord_left + 1, bord_right - 1)]
    bord_nodes += [(bord_down - 1, xx) for xx in range(bord_left + 1, bord_right - 1)]
    bord_nodes += [(xx, bord_left) for xx in range(bord_up + 1, bord_down - 1)]
    bord_nodes += [(xx, bord_right - 1) for xx in range(bord_up + 1, bord_down - 1)]
    for xy in bord_nodes:
        # Location one step inward from the border pixel.
        tgt_loc = None
        if xy[0] == bord_up:
            tgt_loc = (xy[0] + 1, xy[1])
        elif xy[0] == bord_down - 1:
            tgt_loc = (xy[0] - 1, xy[1])
        elif xy[1] == bord_left:
            tgt_loc = (xy[0], xy[1] + 1)
        elif xy[1] == bord_right - 1:
            tgt_loc = (xy[0], xy[1] - 1)
        if tgt_loc is None:
            continue
        ne_infos = info_on_pix.get(tgt_loc)
        if ne_infos is None:
            # Used to open pdb and then fail on ne_infos[0]; fail clearly instead.
            raise ValueError("refresh_bord_depth: no pixel info at %s (inward neighbour of %s)"
                             % (tgt_loc, xy))
        tgt_depth = ne_infos[0]['depth']
        new_node = (xy[0], xy[1], tgt_depth)
        src_node = (tgt_loc[0], tgt_loc[1], tgt_depth)
        tgt_nes_loc = [(xx[0], xx[1]) for xx in mesh.neighbors(src_node)]
        # Keep the inward pixel's neighbours that are diagonal to xy and shift
        # them by the same offset that maps tgt_loc onto xy.
        tgt_nes_loc = [(xx[0] - tgt_loc[0] + xy[0], xx[1] - tgt_loc[1] + xy[1]) for xx in tgt_nes_loc
                       if abs(xx[0] - xy[0]) == 1 and abs(xx[1] - xy[1]) == 1]
        tgt_nes_loc = [xx for xx in tgt_nes_loc if info_on_pix.get(xx) is not None]
        tgt_nes_loc.append(tgt_loc)
        if info_on_pix.get(xy) is not None and len(info_on_pix.get(xy)) > 0:
            old_depth = info_on_pix[xy][0].get('depth')
            old_node = (xy[0], xy[1], old_depth)
            mesh.remove_edges_from([(old_ne, old_node) for old_ne in mesh.neighbors(old_node)])
            mesh.add_edges_from([((zz[0], zz[1], info_on_pix[zz][0]['depth']), old_node) for zz in tgt_nes_loc])
            mapping_dict = {old_node: new_node}
            info_on_pix, mesh = update_info(mapping_dict, info_on_pix, mesh)
        else:
            # No record yet. The old code indexed an empty list and wrote
            # 'color' onto the top-level dict; clone the inward pixel's record
            # (as a copy, so the two pixels do not share one dict) instead.
            entry = dict(info_on_pix[tgt_loc][0])
            entry['color'] = image[xy[0], xy[1]]
            entry['old_color'] = image[xy[0], xy[1]]
            info_on_pix[xy] = [entry]
            # NOTE(review): the node is created without attributes, as in
            # the original — confirm whether color/disp should be set too.
            mesh.add_node(new_node)
            mesh.add_edges_from([((zz[0], zz[1], info_on_pix[zz][0]['depth']), new_node) for zz in tgt_nes_loc])
        mesh.nodes[new_node]['far'] = None
        mesh.nodes[new_node]['near'] = None
        # The inward pixel must no longer list this border pixel as far/near.
        for relation in ('far', 'near'):
            related = mesh.nodes[src_node].get(relation)
            if related is not None:
                redundant_nodes = [ne for ne in related if (ne[0], ne[1]) == xy]
                for aa in redundant_nodes:
                    related.remove(aa)
    for xy in corner_nodes:
        hx, hy = xy
        four_nes = [xx for xx in [(hx + 1, hy), (hx - 1, hy), (hx, hy + 1), (hx, hy - 1)]
                    if bord_up <= xx[0] < bord_down and bord_left <= xx[1] < bord_right]
        ne_nodes = []
        ne_depths = []
        for ne_loc in four_nes:
            if info_on_pix.get(ne_loc) is not None:
                ne_depths.append(info_on_pix[ne_loc][0]['depth'])
                ne_nodes.append((ne_loc[0], ne_loc[1], info_on_pix[ne_loc][0]['depth']))
        new_node = (xy[0], xy[1], float(np.mean(ne_depths)))
        if info_on_pix.get(xy) is not None and len(info_on_pix.get(xy)) > 0:
            old_depth = info_on_pix[xy][0].get('depth')
            old_node = (xy[0], xy[1], old_depth)
            mesh.remove_edges_from([(old_ne, old_node) for old_ne in mesh.neighbors(old_node)])
            mesh.add_edges_from([(zz, old_node) for zz in ne_nodes])
            mapping_dict = {old_node: new_node}
            info_on_pix, mesh = update_info(mapping_dict, info_on_pix, mesh)
        else:
            if not ne_nodes:
                raise ValueError("refresh_bord_depth: corner %s has no neighbour with pixel info" % (xy,))
            # Same repair as for border pixels; the old code additionally used
            # ne_loc[-1] (an int) as a pixel key. Clone the last neighbour's
            # record and make its depth match the new node key.
            last_ne = ne_nodes[-1]
            entry = dict(info_on_pix[(last_ne[0], last_ne[1])][0])
            entry['depth'] = new_node[2]
            # disp follows the 1/depth convention used when nodes are filled
            # elsewhere in this file; NOTE(review): confirm for corners.
            entry['disp'] = 1. / new_node[2] if new_node[2] != 0 else 0
            entry['color'] = image[xy[0], xy[1]]
            entry['old_color'] = image[xy[0], xy[1]]
            info_on_pix[xy] = [entry]
            mesh.add_node(new_node)
            mesh.add_edges_from([(zz, new_node) for zz in ne_nodes])
        mesh.nodes[new_node]['far'] = None
        mesh.nodes[new_node]['near'] = None
    for xy in bord_nodes + corner_nodes:
        depth[xy[0], xy[1]] = abs(info_on_pix[xy][0]['depth'])
    for xy in bord_nodes:
        cur_node = (xy[0], xy[1], info_on_pix[xy][0]['depth'])
        nes = mesh.neighbors(cur_node)
        # 4-neighbour locations that cur_node is not connected to.
        four_nes = set([(xy[0] + 1, xy[1]), (xy[0] - 1, xy[1]), (xy[0], xy[1] + 1), (xy[0], xy[1] - 1)]) - \
                   set([(ne[0], ne[1]) for ne in nes])
        four_nes = [ne for ne in four_nes
                    if bord_up <= ne[0] < bord_down and bord_left <= ne[1] < bord_right]
        four_nes = [(ne[0], ne[1], info_on_pix[(ne[0], ne[1])][0]['depth']) for ne in four_nes]
        mesh.nodes[cur_node]['far'] = []
        mesh.nodes[cur_node]['near'] = []
        for ne in four_nes:
            if abs(ne[2]) >= abs(cur_node[2]):
                mesh.nodes[cur_node]['far'].append(ne)
            else:
                mesh.nodes[cur_node]['near'].append(ne)
    return mesh, info_on_pix, depth
def get_union_size(mesh, dilate, *alls_cc):
    """Return the dilated bounding box of the union of several node sets.

    Nodes are indexed as ``(x, y, ...)``. The box is grown by ``dilate[0]``
    along x and ``dilate[1]`` along y and clipped to ``mesh.graph['H']`` /
    ``mesh.graph['W']``; *_max values are exclusive. With no nodes at all
    the starting extremes (H, W, 0, 0) are used as-is.
    """
    H, W = mesh.graph['H'], mesh.graph['W']
    merged = set()
    for cc in alls_cc:
        merged = merged | cc
    # Seed with the same starting extremes the running scan would use.
    xs = [H] + [node[0] for node in merged]
    ys = [W] + [node[1] for node in merged]
    low_x, low_y = min(xs), min(ys)
    high_x = max([0] + xs[1:]) + 1
    high_y = max([0] + ys[1:]) + 1
    return {
        'x_min': max(0, low_x - dilate[0]),
        'x_max': min(H, high_x + dilate[0]),
        'y_min': max(0, low_y - dilate[1]),
        'y_max': min(W, high_y + dilate[1]),
    }
def incomplete_node(mesh, edge_maps, info_on_pix):
    """Reconnect under-connected, non-synthesized interior nodes to their 4-neighbours.

    A node qualifies when it is not on the outermost rows/columns, has fewer
    than three non-synthesized neighbours, and either has at most one such
    neighbour or its first two are more than one pixel apart on some axis.
    For each of its four adjacent pixels, the first non-synthesized record
    in ``info_on_pix`` whose node exists in the mesh is linked to it.

    ``edge_maps`` is unused. Returns ``mesh`` (mutated in place).
    """
    height, width = mesh.graph['H'], mesh.graph['W']
    vis_map = np.zeros((height, width))
    for node in mesh.nodes:
        if mesh.nodes[node].get('synthesis') is True:
            continue
        real_nes = [ne for ne in mesh.neighbors(node) if mesh.nodes[ne].get('synthesis') is not True]
        interior = 0 < node[0] < height - 1 and 0 < node[1] < width - 1
        if not (len(real_nes) < 3 and interior):
            continue
        if len(real_nes) > 1:
            first, second = real_nes[0], real_nes[1]
            # Two neighbours that are close together: leave the node as is.
            if abs(first[0] - second[0]) <= 1 and abs(first[1] - second[1]) <= 1:
                continue
        vis_map[node[0], node[1]] = len(real_nes)
        row, col = node[0], node[1]
        for loc in ((row - 1, col), (row + 1, col), (row, col - 1), (row, col + 1)):
            for info in info_on_pix[(loc[0], loc[1])]:
                candidate = (loc[0], loc[1], info['depth'])
                if info.get('synthesis') is not True and mesh.has_node(candidate):
                    mesh.add_edge(node, candidate)
                    break
    return mesh
def edge_inpainting(edge_id, context_cc, erode_context_cc, mask_cc, edge_cc, extend_edge_cc,
                    mesh, edge_map, edge_maps_with_id, config, union_size, depth_edge_model, inpaint_iter):
    """Inpaint the depth-edge map inside the mask region of one edge component.

    The node sets are rasterised by ``get_edge_from_nodes``, the combined
    self/complementary edge map is passed through
    ``filter_irrelevant_edge_new``, and all maps are cropped to
    ``union_size``. When ``require_depth_edge`` returns true for the crop and
    ``inpaint_iter == 0``, ``depth_edge_model.forward_3P`` is run on the crop;
    otherwise the cropped edge map is used unchanged.

    Returns:
        ``(edge_dict, end_depth_maps)``. ``edge_dict['output']`` is an
        ``(H, W)`` array that is zero outside the ``union_size`` window and
        holds the patch result inside it; ``end_depth_maps`` is the second
        value returned by ``filter_irrelevant_edge_new``.
    """
    edge_dict = get_edge_from_nodes(context_cc, erode_context_cc, mask_cc, edge_cc, extend_edge_cc,
                                    mesh.graph['H'], mesh.graph['W'], mesh)
    edge_dict['edge'], end_depth_maps, _ = \
        filter_irrelevant_edge_new(edge_dict['self_edge'] + edge_dict['comp_edge'],
                                   edge_map,
                                   edge_maps_with_id,
                                   edge_id,
                                   edge_dict['context'],
                                   edge_dict['depth'], mesh, context_cc | erode_context_cc, spdb=True)
    # Crop every map to the working window before building tensors.
    patch_edge_dict = dict()
    patch_edge_dict['mask'], patch_edge_dict['context'], patch_edge_dict['rgb'], \
        patch_edge_dict['disp'], patch_edge_dict['edge'] = \
        crop_maps_by_size(union_size, edge_dict['mask'], edge_dict['context'],
                          edge_dict['rgb'], edge_dict['disp'], edge_dict['edge'])
    tensor_edge_dict = convert2tensor(patch_edge_dict)
    if require_depth_edge(patch_edge_dict['edge'], patch_edge_dict['mask']) and inpaint_iter == 0:
        with torch.no_grad():
            # A non-negative integer ``gpu_ids`` is forwarded as-is; anything else selects "cpu".
            device = config["gpu_ids"] if isinstance(config["gpu_ids"], int) and config["gpu_ids"] >= 0 else "cpu"
            depth_edge_output = depth_edge_model.forward_3P(tensor_edge_dict['mask'],
                                                            tensor_edge_dict['context'],
                                                            tensor_edge_dict['rgb'],
                                                            tensor_edge_dict['disp'],
                                                            tensor_edge_dict['edge'],
                                                            unit_length=128,
                                                            cuda=device)
            depth_edge_output = depth_edge_output.cpu()
        # Threshold the prediction, keep it only inside the mask, and add the known edges back.
        tensor_edge_dict['output'] = (depth_edge_output > config['ext_edge_threshold']).float() * tensor_edge_dict['mask'] + tensor_edge_dict['edge']
    else:
        tensor_edge_dict['output'] = tensor_edge_dict['edge']
        # NOTE(review): this value is not read again within this function.
        depth_edge_output = tensor_edge_dict['edge'] + 0
    patch_edge_dict['output'] = tensor_edge_dict['output'].squeeze().data.cpu().numpy()
    # Paste the patch back into a full-size, otherwise-zero map.
    edge_dict['output'] = np.zeros((mesh.graph['H'], mesh.graph['W']))
    edge_dict['output'][union_size['x_min']:union_size['x_max'], union_size['y_min']:union_size['y_max']] = \
        patch_edge_dict['output']
    return edge_dict, end_depth_maps
def depth_inpainting(context_cc, extend_context_cc, erode_context_cc, mask_cc, mesh, config, union_size, depth_feat_model, edge_output, given_depth_dict=False, spdb=False):
    """Inpaint depth inside the mask region using ``depth_feat_model``.

    Args:
        context_cc, extend_context_cc, erode_context_cc, mask_cc: node sets
            rasterised by ``get_depth_from_nodes`` (``context_cc`` and
            ``extend_context_cc`` are merged first). Ignored when
            ``given_depth_dict`` is supplied.
        mesh: graph providing ``graph['H']`` / ``graph['W']``.
        config: mapping read for ``'log_depth'`` and ``'gpu_ids'``.
        union_size: dict with ``x_min``/``x_max``/``y_min``/``y_max`` crop bounds.
        depth_feat_model: model exposing ``forward_3P``.
        edge_output: ``(H, W)`` edge map stored as ``depth_dict['edge']``;
            ``None`` now falls back to an all-zero map instead of raising
            ``KeyError``.
        given_depth_dict: pre-built depth dict used as-is, or ``False``.
        spdb: when ``True``, show debug plots and drop into ``pdb``.

    Returns:
        The depth dict, with ``'output'`` holding the blended full-size depth.
    """
    height, width = mesh.graph['H'], mesh.graph['W']
    if given_depth_dict is False:
        depth_dict = get_depth_from_nodes(context_cc | extend_context_cc, erode_context_cc, mask_cc,
                                          height, width, mesh, config['log_depth'])
        if edge_output is not None:
            depth_dict['edge'] = edge_output
        else:
            # get_depth_from_nodes provides no 'edge' entry; use the same
            # all-zero default that get_depth_from_maps returns.
            depth_dict['edge'] = np.zeros((height, width))
    else:
        depth_dict = given_depth_dict
    patch_depth_dict = dict()
    patch_depth_dict['mask'], patch_depth_dict['context'], patch_depth_dict['depth'], \
        patch_depth_dict['zero_mean_depth'], patch_depth_dict['edge'] = \
        crop_maps_by_size(union_size, depth_dict['mask'], depth_dict['context'],
                          depth_dict['real_depth'], depth_dict['zero_mean_depth'], depth_dict['edge'])
    tensor_depth_dict = convert2tensor(patch_depth_dict)
    resize_mask = open_small_mask(tensor_depth_dict['mask'], tensor_depth_dict['context'], 3, 41)
    with torch.no_grad():
        # A non-negative integer ``gpu_ids`` is forwarded as-is; anything else selects "cpu".
        device = config["gpu_ids"] if isinstance(config["gpu_ids"], int) and config["gpu_ids"] >= 0 else "cpu"
        depth_output = depth_feat_model.forward_3P(resize_mask,
                                                   tensor_depth_dict['context'],
                                                   tensor_depth_dict['zero_mean_depth'],
                                                   tensor_depth_dict['edge'],
                                                   unit_length=128,
                                                   cuda=device)
        depth_output = depth_output.cpu()
    # NOTE(review): exp() is applied regardless of config['log_depth'] --
    # confirm the model always predicts zero-mean log depth.
    tensor_depth_dict['output'] = torch.exp(depth_output + depth_dict['mean_depth']) * \
        tensor_depth_dict['mask'] + tensor_depth_dict['depth']
    patch_depth_dict['output'] = tensor_depth_dict['output'].data.cpu().numpy().squeeze()
    # Paste the patch back into a full-size, otherwise-zero map.
    depth_dict['output'] = np.zeros((height, width))
    depth_dict['output'][union_size['x_min']:union_size['x_max'], union_size['y_min']:union_size['y_max']] = \
        patch_depth_dict['output']
    depth_output = smooth_cntsyn_gap(depth_dict['output'].copy() * depth_dict['mask'] + depth_dict['depth'] * depth_dict['context'],
                                     depth_dict['mask'], depth_dict['context'],
                                     init_mask_region=depth_dict['mask'])
    if spdb is True:
        # Interactive debugging aid: smoothed result next to the raw one.
        _, (ax1, ax2) = plt.subplots(1, 2, sharex=True, sharey=True)
        ax1.imshow(depth_output * depth_dict['mask'] + depth_dict['depth'])
        ax2.imshow(depth_dict['output'] * depth_dict['mask'] + depth_dict['depth'])
        plt.show()
        import pdb
        pdb.set_trace()
    depth_dict['output'] = depth_output * depth_dict['mask'] + depth_dict['depth'] * depth_dict['context']
    return depth_dict
def update_info(mapping_dict, info_on_pix, *meshes):
    """Apply a single node relabelling to several meshes and to ``info_on_pix``.

    Only the first ``old_node -> new_node`` pair of ``mapping_dict`` is used.
    Each mesh is passed through ``relabel_node`` and the depth of the first
    entry of ``info_on_pix`` at the old node's pixel is set to ``new_node[2]``.

    Returns:
        A list: ``info_on_pix`` followed by the relabelled meshes, in order.
    """
    old_node, new_node = [*mapping_dict.items()][0]
    relabeled = [relabel_node(mesh, mesh.nodes, old_node, new_node) for mesh in meshes]
    row, col, _ = old_node
    info_on_pix[(row, col)][0]['depth'] = new_node[2]
    return [info_on_pix] + relabeled
def build_connection(mesh, cur_node, dst_node):
    """Link ``dst_node`` to ``cur_node`` and, recursively, to nearby neighbours.

    An edge ``cur_node``-``dst_node`` is added when their Manhattan distance
    in the first two coordinates is below 2. If the two nodes are more than
    one pixel apart along either axis the recursion stops. Otherwise every
    neighbour of ``cur_node`` that is neither ``dst_node`` nor already linked
    to it is visited in turn.

    Returns:
        The mesh returned by the last recursive call (``mesh`` itself when
        no recursion happens).
    """
    row_gap = abs(cur_node[0] - dst_node[0])
    col_gap = abs(cur_node[1] - dst_node[1])
    if row_gap + col_gap < 2:
        mesh.add_edge(cur_node, dst_node)
    if row_gap > 1 or col_gap > 1:
        return mesh
    # Snapshot the neighbours: the recursion adds edges while we iterate.
    for candidate in list(mesh.neighbors(cur_node)):
        if candidate != dst_node and not mesh.has_edge(candidate, dst_node):
            mesh = build_connection(mesh, candidate, dst_node)
    return mesh
def recursive_add_edge(edge_mesh, mesh, info_on_pix, cur_node, mark):
    """Walk ``edge_mesh`` from ``cur_node``, re-attaching pixels marked ``3``.

    For every neighbour pixel of ``cur_node`` in ``edge_mesh`` whose ``mark``
    value is 3: the mark is cleared, the node's edges in ``mesh`` are removed,
    it is reconnected via ``build_connection``, its depth is replaced by the
    mean depth of its new neighbours (or kept when it has none), the node is
    relabelled in both graphs and in ``info_on_pix`` through ``update_info``,
    and the walk continues from the relabelled node.

    Returns:
        ``(edge_mesh, mesh, mark, info_on_pix)`` after all updates.
    """
    # Snapshot the neighbour pixels: relabelling mutates edge_mesh below.
    ne_nodes = [(x[0], x[1]) for x in edge_mesh.neighbors(cur_node)]
    for node_xy in ne_nodes:
        node = (node_xy[0], node_xy[1], info_on_pix[node_xy][0]['depth'])
        if mark[node[0], node[1]] != 3:
            continue
        mark[node[0], node[1]] = 0
        mesh.remove_edges_from([(xx, node) for xx in mesh.neighbors(node)])
        mesh = build_connection(mesh, cur_node, node)
        ne_depths = [re_ne[2] for re_ne in mesh.neighbors(node)]
        if ne_depths:
            re_depth = sum(ne_depths) / float(len(ne_depths))
        else:
            # No neighbour to average over: keep the node's own depth. This
            # replaces a bare ``except`` that only ever guarded a 0 / 0.
            re_depth = node[2]
        re_node = (node_xy[0], node_xy[1], re_depth)
        mapping_dict = {node: re_node}
        info_on_pix, edge_mesh, mesh = update_info(mapping_dict, info_on_pix, edge_mesh, mesh)
        edge_mesh, mesh, mark, info_on_pix = recursive_add_edge(edge_mesh, mesh, info_on_pix, re_node, mark)
    return edge_mesh, mesh, mark, info_on_pix
def resize_for_edge(tensor_dict, largest_size):
    """Return a cloned tensor dict, downscaled so its longest side fits.

    The scale factor is ``largest_size`` divided by the larger of the last
    two dimensions of ``tensor_dict['edge']``. When it is below 1 the
    ``'mask'``, ``'context'``, ``'edge'``, ``'disp'`` and ``'rgb'`` entries
    are resized (``'disp'`` with nearest, the rest with bilinear
    interpolation), ``'mask'``, ``'context'`` and ``'edge'`` are re-binarised,
    and ``'edge'``, ``'disp'``, ``'rgb'`` are multiplied by the new context.
    The input tensors are never modified.
    """
    resize_dict = {name: tensor.clone() for name, tensor in tensor_dict.items()}
    frac = largest_size / np.array([*resize_dict['edge'].shape[-2:]]).max()
    if not (frac < 1):
        return resize_dict

    interpolate = torch.nn.functional.interpolate
    stacked = torch.cat((resize_dict['mask'], resize_dict['context']), dim=1)
    stacked = interpolate(stacked, scale_factor=frac, mode='bilinear')
    # Any mask contribution counts as mask; context must be pure context.
    mask = (stacked[:, 0:1] > 0).float()
    context = (stacked[:, 1:2] == 1).float()
    context[mask > 0] = 0

    edge = interpolate(resize_dict['edge'], scale_factor=frac, mode='bilinear')
    edge = (edge > 0).float() * context
    disp = interpolate(resize_dict['disp'], scale_factor=frac, mode='nearest') * context
    rgb = interpolate(resize_dict['rgb'], scale_factor=frac, mode='bilinear') * context

    resize_dict['mask'] = mask
    resize_dict['context'] = context
    resize_dict['edge'] = edge
    resize_dict['disp'] = disp
    resize_dict['rgb'] = rgb
    return resize_dict
def get_map_from_nodes(nodes, height, width):
    """Rasterise ``nodes`` into a ``(height, width)`` float map.

    Every pixel ``(node[0], node[1])`` is set to 1; all others stay 0.
    """
    occupancy = np.zeros((height, width))
    for node in nodes:
        row, col = node[0], node[1]
        occupancy[row, col] = 1
    return occupancy
def get_map_from_ccs(ccs, height, width, condition_input=None, condition=None, real_id=False, id_shift=0):
    """Rasterise connected components into a ``(height, width)`` map.

    Args:
        ccs: sequence of node collections; a node's pixel is
            ``(node[0], node[1])``.
        condition_input, condition: optional filter called as
            ``condition(node, condition_input)``; nodes for which it is falsy
            are skipped. With no ``condition`` every node is kept.
        real_id: only the literal ``True`` switches to id mode, where kept
            pixels hold ``cc_index + id_shift`` and the background is
            ``-1 + id_shift``. Otherwise kept pixels are 1 over a 0 background.
        id_shift: offset used in id mode.
    """
    use_ids = real_id is True
    if condition is None:
        def condition(node, extra):
            return True
    omap = np.zeros((height, width))
    if use_ids:
        omap = omap + (-1) + id_shift
    for cc_id, cc in enumerate(ccs):
        pixel_value = cc_id + id_shift if use_ids else 1
        for node in cc:
            if condition(node, condition_input):
                omap[node[0], node[1]] = pixel_value
    return omap
def revise_map_by_nodes(nodes, imap, operation, limit_constr=None):
    """Add (``'+'``) or remove (``'-'``) node pixels on a copy of ``imap``.

    Args:
        nodes: iterable of nodes; the pixel is ``(node[0], node[1])``.
        imap: input map; never modified.
        operation: ``'+'`` writes 1 at each pixel, ``'-'`` writes 0.
        limit_constr: optional bound on the revised map's sum. The revision
            is rejected when the sum exceeds it for ``'+'`` or falls below it
            for ``'-'``.

    Returns:
        ``(map, revised)``: the revised copy and ``True``, or the original
        ``imap`` object and ``False`` when the revision is rejected.
    """
    assert operation == '+' or operation == '-', "Operation must be '+' (union) or '-' (exclude)"
    adding = operation == '+'
    fill_value = 1 if adding else 0
    omap = copy.deepcopy(imap)
    for node in nodes:
        omap[node[0], node[1]] = fill_value
    if limit_constr is not None:
        total = omap.sum()
        rejected = total > limit_constr if adding else total < limit_constr
        if rejected:
            return imap, False
    return omap, True
def repaint_info(mesh, cc, x_anchor, y_anchor, source_type):
    """Paint node features into a channel-first patch.

    The patch spans ``x_anchor[0]:x_anchor[1]`` by ``y_anchor[0]:y_anchor[1]``
    and node positions are offset by the anchor starts. ``'rgb'`` gives three
    channels holding ``mesh.nodes[node]['color'] / 255``; ``'d'`` gives one
    channel holding ``abs(node[2])``. Any other ``source_type`` returns a
    one-channel patch of zeros.
    """
    x_start, y_start = x_anchor[0], y_anchor[0]
    patch_h = x_anchor[1] - x_start
    patch_w = y_anchor[1] - y_start
    is_rgb = source_type == 'rgb'
    feat = np.zeros((3 if is_rgb else 1, patch_h, patch_w))
    if is_rgb:
        for node in cc:
            feat[:, node[0] - x_start, node[1] - y_start] = np.array(mesh.nodes[node]['color']) / 255.
    elif source_type == 'd':
        for node in cc:
            feat[:, node[0] - x_start, node[1] - y_start] = abs(node[2])
    return feat
def get_context_from_nodes(mesh, cc, H, W, source_type=''):
    """Rasterise a node set into a feature map and a binary context map.

    Args:
        mesh: object whose ``nodes[node]['color']`` is read in colour mode.
        cc: iterable of nodes; the pixel is ``(node[0], node[1])``.
        H, W: output height and width.
        source_type: if it contains ``'rgb'`` or ``'color'`` the feature map
            is ``(H, W, 3)`` with colours divided by 255; otherwise it is
            ``(H, W)`` holding ``abs(node[2])``.

    Returns:
        ``(feat, context)``; ``context`` is 1 at every node pixel. It was
        previously left all-zero outside colour mode, because the assignment
        sat inside the colour branch.
    """
    use_color = 'rgb' in source_type or 'color' in source_type
    feat = np.zeros((H, W, 3)) if use_color else np.zeros((H, W))
    context = np.zeros((H, W))
    for node in cc:
        row, col = node[0], node[1]
        if use_color:
            feat[row, col] = np.array(mesh.nodes[node]['color']) / 255.
        else:
            feat[row, col] = abs(node[2])
        context[row, col] = 1
    return feat, context
def get_mask_from_nodes(mesh, cc, H, W):
    """Return an ``(H, W)`` map holding ``abs(node[2])`` at each node pixel.

    ``mesh`` is not used; pixels without a node stay 0.
    """
    depth_mask = np.zeros((H, W))
    for row, col, depth in ((n[0], n[1], n[2]) for n in cc):
        depth_mask[row, col] = abs(depth)
    return depth_mask
def get_edge_from_nodes(context_cc, erode_context_cc, mask_cc, edge_cc, extend_edge_cc, H, W, mesh):
    """Rasterise the node sets of one edge component into ``(H, W)`` maps.

    Context nodes (``context_cc`` then ``erode_context_cc``) fill ``rgb``
    (colour / 255), ``disp`` (absolute value divided by its maximum),
    ``depth`` (``node[2]``), ``real_depth`` (the node's ``'real_depth'``
    attribute where present, else ``depth``) and ``context``. ``mask_cc``
    fills ``mask`` and ``near_depth``; ``edge_cc`` fills ``self_edge``;
    ``extend_edge_cc`` fills ``comp_edge``.

    Returns:
        dict of arrays with those keys plus ``'fpath_map'`` and
        ``'npath_map'`` (all -1) and ``'valid_area'`` (``context + mask``).
    """
    shape = (H, W)
    context = np.zeros(shape)
    mask = np.zeros(shape)
    rgb = np.zeros((H, W, 3))
    disp = np.zeros(shape)
    depth = np.zeros(shape)
    edge = np.zeros(shape)
    comp_edge = np.zeros(shape)
    fpath_map = np.zeros(shape) - 1
    npath_map = np.zeros(shape) - 1
    near_depth = np.zeros(shape)

    known_groups = (context_cc, erode_context_cc)
    for group in known_groups:
        for node in group:
            row, col = node[0], node[1]
            rgb[row, col] = np.array(mesh.nodes[node]['color'])
            disp[row, col] = mesh.nodes[node]['disp']
            depth[row, col] = node[2]
            context[row, col] = 1
    rgb = rgb / 255.
    disp = np.abs(disp)
    disp = disp / disp.max()

    # Second pass so that a stored 'real_depth' overrides the copied depth.
    real_depth = depth.copy()
    for group in known_groups:
        for node in group:
            stored = mesh.nodes[node].get('real_depth')
            if stored is not None:
                real_depth[node[0], node[1]] = mesh.nodes[node]['real_depth']

    for node in mask_cc:
        mask[node[0], node[1]] = 1
        near_depth[node[0], node[1]] = node[2]
    for node in edge_cc:
        edge[node[0], node[1]] = 1
    for node in extend_edge_cc:
        comp_edge[node[0], node[1]] = 1

    return {'rgb': rgb, 'disp': disp, 'depth': depth, 'real_depth': real_depth, 'self_edge': edge, 'context': context,
            'mask': mask, 'fpath_map': fpath_map, 'npath_map': npath_map, 'comp_edge': comp_edge, 'valid_area': context + mask,
            'near_depth': near_depth}
def get_depth_from_maps(context_map, mask_map, depth_map, H, W, log_depth=False):
    """Build the depth-inpainting input dict from dense maps.

    Args:
        context_map, mask_map: arrays cast to ``uint8``.
        depth_map: array whose absolute value is used as depth.
        H, W: accepted for interface compatibility; the outputs take their
            shape from the input maps.
        log_depth: only the literal ``True`` enables log-depth mode.

    Returns:
        dict with ``'depth'``, ``'real_depth'``, ``'context'``, ``'mask'``,
        ``'mean_depth'``, ``'zero_mean_depth'`` and an all-zero ``'edge'``.
        In log mode ``'zero_mean_depth'`` is the context-masked log depth
        minus its mean over context pixels; otherwise it is the very same
        array as ``'real_depth'`` and ``'mean_depth'`` is 0.
    """
    context = context_map.astype(np.uint8)
    mask = mask_map.astype(np.uint8).copy()
    depth = np.abs(depth_map)
    real_depth = depth.copy()
    if log_depth is True:
        masked_log = np.log(real_depth + 1e-8) * context
        mean_depth = np.mean(masked_log[context > 0])
        zero_mean_depth = (masked_log - mean_depth) * context
    else:
        zero_mean_depth, mean_depth = real_depth, 0
    return {'depth': depth, 'real_depth': real_depth, 'context': context, 'mask': mask,
            'mean_depth': mean_depth, 'zero_mean_depth': zero_mean_depth,
            'edge': np.zeros_like(depth)}
def get_depth_from_nodes(context_cc, erode_context_cc, mask_cc, H, W, mesh, log_depth=False):
    """Build the depth-inpainting input dict from node sets.

    Context nodes (``context_cc`` then ``erode_context_cc``) fill ``depth``
    with ``abs(node[2])`` and mark ``context``. ``real_depth`` starts as a
    copy of ``depth`` and is replaced by the absolute value of a node's
    ``'real_depth'`` attribute where one is stored. ``mask_cc`` marks ``mask``.

    Returns:
        dict with ``'depth'``, ``'real_depth'``, ``'context'``, ``'mask'``,
        ``'mean_depth'`` and ``'zero_mean_depth'``. When ``log_depth`` is the
        literal ``True``, ``'zero_mean_depth'`` is the context-masked log
        depth minus its mean over context pixels; otherwise it is the very
        same array as ``'real_depth'`` and ``'mean_depth'`` is 0.
    """
    context = np.zeros((H, W))
    mask = np.zeros((H, W))
    depth = np.zeros((H, W))
    known_groups = (context_cc, erode_context_cc)

    for group in known_groups:
        for node in group:
            depth[node[0], node[1]] = node[2]
            context[node[0], node[1]] = 1
    depth = np.abs(depth)

    # Second pass so that a stored 'real_depth' overrides the copied depth.
    real_depth = depth.copy()
    for group in known_groups:
        for node in group:
            if mesh.nodes[node].get('real_depth') is not None:
                real_depth[node[0], node[1]] = mesh.nodes[node]['real_depth']
    real_depth = np.abs(real_depth)

    for node in mask_cc:
        mask[node[0], node[1]] = 1

    if log_depth is True:
        masked_log = np.log(real_depth + 1e-8) * context
        mean_depth = np.mean(masked_log[context > 0])
        zero_mean_depth = (masked_log - mean_depth) * context
    else:
        zero_mean_depth, mean_depth = real_depth, 0
    return {'depth': depth, 'real_depth': real_depth, 'context': context, 'mask': mask,
            'mean_depth': mean_depth, 'zero_mean_depth': zero_mean_depth}
def get_rgb_from_nodes(context_cc, erode_context_cc, mask_cc, H, W, mesh):
    """Rasterise node sets into colour, context, mask and erode maps.

    ``context_cc`` fills ``rgb`` (node colour / 255) and marks ``context``.
    ``mask_cc`` marks ``mask``. ``erode_context_cc`` marks both ``erode``
    and ``mask``; its colours are not read.

    Returns:
        dict with keys ``'rgb'`` (``(H, W, 3)``), ``'context'``, ``'mask'``
        and ``'erode'`` (each ``(H, W)``).
    """
    rgb = np.zeros((H, W, 3))
    context = np.zeros((H, W))
    for node in context_cc:
        row, col = node[0], node[1]
        rgb[row, col] = np.array(mesh.nodes[node]['color'])
        context[row, col] = 1
    rgb = rgb / 255.

    mask = np.zeros((H, W))
    for node in mask_cc:
        mask[node[0], node[1]] = 1

    erode_context = np.zeros((H, W))
    for node in erode_context_cc:
        row, col = node[0], node[1]
        erode_context[row, col] = 1
        mask[row, col] = 1

    return {'rgb': rgb, 'context': context, 'mask': mask,
            'erode': erode_context}
def crop_maps_by_size(size, *imaps):
    """Crop every map to the window described by ``size``.

    Args:
        size: dict with ``'x_min'``/``'x_max'`` (first axis) and
            ``'y_min'``/``'y_max'`` (second axis); max bounds are exclusive.
        *imaps: arrays to crop.

    Returns:
        List of independent copies of the cropped regions, in input order.
    """
    if not imaps:
        return []
    rows = slice(size['x_min'], size['x_max'])
    cols = slice(size['y_min'], size['y_max'])
    return [imap[rows, cols].copy() for imap in imaps]
def convert2tensor(input_dict):
    """Convert a dict of arrays into batched ``torch.FloatTensor`` values.

    Entries whose key contains ``'rgb'`` or ``'color'`` are permuted with
    ``(2, 0, 1)`` and given a leading batch axis; every other entry gets
    leading batch and channel axes.
    """
    rt_dict = {}
    for key in input_dict:
        tensor = torch.FloatTensor(input_dict[key])
        if 'rgb' in key or 'color' in key:
            tensor = tensor.permute(2, 0, 1)[None, ...]
        else:
            tensor = tensor[None, None, ...]
        rt_dict[key] = tensor
    return rt_dict