zhigangjiang's picture
no message
88b0dcb
raw
history blame
No virus
3.34 kB
"""
@Date: 2021/11/06
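@description: Helpers that derive edge direction, edge angle and angle gradient from depth, and render them as image strips.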
"""
import cv2
import numpy as np
import torch
import matplotlib.pyplot as plt
from utils.conversion import depth2xyz
def convert_img(value, h, need_nor=True, cmap=None):
    """
    Render a tensor as an image strip by repeating it ``h`` times along a new
    leading axis.

    :param value: torch tensor (callers in this file pass a 1-D [patch_num]
        tensor); it is cloned first, so the caller's tensor is not modified
    :param h: number of rows of the strip (cast to int)
    :param need_nor: if True, min-max normalise the values to [0, 1]
    :param cmap: None -> value copied into 3 identical channels,
        cv2.COLORMAP_PLASMA -> OpenCV colour map, 'HSV' -> value used as hue
    :return: numpy array; [h, patch_num, 3] for the three cmap options above
        (float64 in [0, 1] for the two colour branches). Any other cmap
        returns the 2-D strip unchanged.
    """
    value = value.clone().detach().cpu().numpy()[None]
    if need_nor:
        if not np.issubdtype(value.dtype, np.floating):
            # In-place subtraction/division is rejected for bool/integer arrays.
            value = value.astype(np.float32)
        value -= value.min()
        # After the shift the minimum is 0, so the range equals the maximum.
        value_range = value.max()
        if value_range > 0:
            # Skipped for constant input, which would otherwise become 0/0 = NaN.
            value /= value_range

    grad_img = value.repeat(int(h), axis=0)

    if cmap is None:
        grad_img = grad_img[..., np.newaxis].repeat(3, axis=-1)
    elif cmap == cv2.COLORMAP_PLASMA:
        grad_img = cv2.applyColorMap((grad_img * 255).astype(np.uint8), colormap=cmap)
        # Reverse the channel axis of the colour-mapped image.
        grad_img = grad_img[..., ::-1]
        # np.float was removed in NumPy 1.24; np.float64 is the same dtype.
        grad_img = grad_img.astype(np.float64) / 255.0
    elif cmap == 'HSV':
        grad_img = np.round(grad_img * 1000) / 1000.0
        grad_img = grad_img[..., np.newaxis].repeat(3, axis=-1)
        # Channel 0 carries the scaled value; channels 1 and 2 are fixed at 255.
        grad_img[..., 0] = grad_img[..., 0] * 180
        grad_img[..., 1] = 255
        grad_img[..., 2] = 255
        grad_img = grad_img.astype(np.uint8)
        grad_img = cv2.cvtColor(grad_img, cv2.COLOR_HSV2RGB)
        grad_img = grad_img.astype(np.float64) / 255.0
    return grad_img
def show_grad(depth, grad_conv, h=5, show=False):
    """
    Stack three image strips (depth, edge angle, angle gradient) into one image.

    :param depth: [patch_num]
    :param grad_conv: callable forwarded to get_all to compute the gradient
    :param h: height passed to convert_img for each of the three strips
    :param show: if True, display the result with matplotlib
    :return: the three strips concatenated along axis 0
    """
    # get_all expects a batch axis; the direction output is not visualised.
    _, angle, grad = get_all(depth[None], grad_conv)

    strips = [
        convert_img(depth, h, cmap=cv2.COLORMAP_PLASMA),
        convert_img(angle[0], h, cmap='HSV'),
        convert_img(grad[0], h),
    ]
    img = np.concatenate(strips, axis=0)

    if show:
        plt.imshow(img)
        plt.show()
    return img
def get_grad(direction):
    """
    Angle between the following and the preceding direction vector at every
    position (wrapping around along dim 1).

    :param direction: [b patch_num 2] (the last axis is indexed at 0 and 1)
    :return: [b patch_num], result of acos
    """
    following = torch.roll(direction, shifts=-1, dims=1)  # direction[i+1]
    preceding = torch.roll(direction, shifts=1, dims=1)  # direction[i-1]

    # Dot product of the two vectors, written out per component.
    cosine = following[..., 0] * preceding[..., 0] + following[..., 1] * preceding[..., 1]

    # Keep the cosine strictly inside (-1, 1) before taking acos.
    margin = 1e-6
    return torch.acos(torch.clamp(cosine, -1 + margin, 1 - margin))
def get_grad2(angle, grad_conv):
    """
    Apply ``grad_conv`` to ``sin(angle) + 1`` after wrap-around padding.

    :param angle: [b patch_num]
    :param grad_conv: callable that receives a [b, 1, patch_num + 2] tensor
    :return: output of grad_conv reshaped to [b, -1]; this is [b patch_num]
        only if grad_conv drops the two padded samples (not verifiable here)
    """
    shifted = torch.sin(angle) + 1

    # Wrap-around padding: last sample in front, first sample at the back.
    first = shifted[..., :1]
    last = shifted[..., -1:]
    padded = torch.cat((last, shifted, first), dim=-1)

    batch = padded.shape[0]
    response = grad_conv(padded.unsqueeze(1))  # [b, patch_num + 2] -> [b, 1, patch_num + 2]
    return response.reshape(batch, -1)
def get_edge_angle(direction):
    """
    Convert direction vectors to angles with atan2.

    :param direction: [b patch_num 2]
    :return: [b patch_num], atan2(component 1, component 0)
    """
    first_comp = direction[..., 0]
    second_comp = direction[..., 1]
    return torch.atan2(second_comp, first_comp)
def get_edge_direction(depth):
    """
    Unit vector from each point to the next one (wrapping around along dim 1).

    :param depth: passed to depth2xyz (get_all documents it as [b patch_num])
    :return: normalised differences, same shape as the sliced depth2xyz output
    """
    # Every second component of the last axis is kept
    # (presumably x and z of an xyz triple - confirm against depth2xyz).
    points = depth2xyz(depth)[..., ::2]

    following = torch.roll(points, shifts=-1, dims=1)
    edge = following - points  # edge[i] = points[i+1] - points[i]

    length = edge.norm(p=2, dim=-1, keepdim=True)
    return edge / length
def get_all(depth, grad_conv):
    """
    Run the full chain: depth -> edge direction -> edge angle -> angle gradient.

    :param depth: [b patch_num]
    :param grad_conv: callable forwarded to get_grad2
    :return: tuple (direction, angle, angle_grad)
    """
    direction = get_edge_direction(depth)
    angle = get_edge_angle(direction)
    # Gradient comes from get_grad2 (grad_conv applied to the angles).
    return direction, angle, get_grad2(angle, grad_conv)