# Metric3D: mono/utils/visualization.py
# Provenance (from the hosting page): commit 3ef1661 by zach,
# "initial commit based on github repo" (4.83 kB).
import matplotlib.pyplot as plt
import os, cv2
import numpy as np
from mono.utils.transform import gray_to_colormap
import shutil
import glob
from mono.utils.running import main_process
import torch
from html4vision import Col, imagetable
def save_raw_imgs(
    pred: np.ndarray,
    rgb: np.ndarray,
    filename: str,
    save_dir: str,
    scale: float = 200.0,
    target: np.ndarray = None,
):
    """
    Save the raw RGB image, the prediction and (optionally) the GT as separate files.

    The prediction/GT maps are multiplied by ``scale`` and written as 16-bit
    PNGs, so a stored pixel value ``v`` corresponds to ``v / scale`` in the
    input's units.

    Args:
        pred (np.ndarray): predicted map. Must support ``.astype`` (a NumPy
            array, not a torch tensor).
        rgb (np.ndarray): image written unchanged with ``cv2.imwrite``.
        filename (str): source file name; its extension is replaced by the
            ``_rgb.jpg`` / ``_d.png`` / ``_gt.png`` suffixes.
        save_dir (str): output directory (not created here).
        scale (float): multiplier applied before the uint16 cast.
        target (np.ndarray, optional): GT map; saved only when given.
    """
    stem = os.path.splitext(filename)[0]
    max_u16 = np.iinfo(np.uint16).max

    def _to_uint16(depth):
        # Saturate instead of wrapping: negative values and values above
        # 65535 / scale would otherwise overflow in the uint16 cast.
        return np.clip(depth * scale, 0, max_u16).astype(np.uint16)

    cv2.imwrite(os.path.join(save_dir, stem + '_rgb.jpg'), rgb)
    cv2.imwrite(os.path.join(save_dir, stem + '_d.png'), _to_uint16(pred))
    if target is not None:
        cv2.imwrite(os.path.join(save_dir, stem + '_gt.png'), _to_uint16(target))
def save_val_imgs(
    iter: int,
    pred: torch.Tensor,
    target: torch.Tensor,
    rgb: torch.Tensor,
    filename: str,
    save_dir: str,
    tb_logger=None
):
    """
    Save RGB, colorized prediction and colorized GT stacked vertically in one image.

    Args:
        iter (int): step passed to ``tb_logger.add_image``. The name shadows
            the builtin but is kept so keyword callers keep working.
        pred (torch.Tensor): prediction, as accepted by ``get_data_for_log``.
        target (torch.Tensor): GT, as accepted by ``get_data_for_log``.
        rgb (torch.Tensor): normalized image, as accepted by ``get_data_for_log``.
        filename (str): source file name; its extension is replaced by
            ``_merge.jpg``.
        save_dir (str): output directory.
        tb_logger: optional logger exposing ``add_image(tag, img, step)``.
    """
    rgb, _, _, pred_color, target_color = get_data_for_log(pred, target, rgb)
    rgb = rgb.transpose((1, 2, 0))  # CHW -> HWC to stack with the colormaps
    cat_img = np.concatenate([rgb, pred_color, target_color], axis=0)

    merge_name = os.path.splitext(filename)[0] + '_merge.jpg'
    plt.imsave(os.path.join(save_dir, merge_name), cat_img)
    # save to tensorboard (image is passed channel-first)
    if tb_logger is not None:
        tb_logger.add_image(merge_name, cat_img.transpose((2, 0, 1)), iter)
def save_normal_val_imgs(
    iter: int,
    pred: torch.Tensor,
    targ: torch.Tensor,
    rgb: torch.Tensor,
    filename: str,
    save_dir: str,
    tb_logger=None,
    mask=None,
):
    """
    Save RGB, predicted normals and GT normals stacked vertically in one image.

    Args:
        iter (int): step passed to ``tb_logger.add_image``. The name shadows
            the builtin but is kept so keyword callers keep working.
        pred (torch.Tensor): predicted normals; after ``squeeze`` either
            (3, H, W) or (H, W, 3).
        targ (torch.Tensor): GT normals, same layout options as ``pred``.
        rgb (torch.Tensor): normalized image; after ``squeeze`` either
            (3, H, W) or (H, W, 3).
        filename (str): source file name; its extension is replaced by
            ``_merge.jpg``.
        save_dir (str): output directory.
        tb_logger: optional logger exposing ``add_image(tag, img, step)``.
        mask (torch.Tensor, optional): valid-pixel mask forwarded to
            ``vis_surface_normal`` for both prediction and GT.
    """
    # Per-channel statistics used below to undo the input normalization.
    mean = np.array([123.675, 116.28, 103.53])[np.newaxis, np.newaxis, :]
    std = np.array([58.395, 57.12, 57.375])[np.newaxis, np.newaxis, :]

    pred = pred.squeeze()
    targ = targ.squeeze()
    rgb = rgb.squeeze()
    # Bring channel-first inputs to channel-last (H, W, 3).
    if pred.size(0) == 3:
        pred = pred.permute(1, 2, 0)
    if targ.size(0) == 3:
        targ = targ.permute(1, 2, 0)
    if rgb.size(0) == 3:
        rgb = rgb.permute(1, 2, 0)

    pred_color = vis_surface_normal(pred, mask)
    targ_color = vis_surface_normal(targ, mask)
    # Clip before the uint8 cast so out-of-range values saturate, not wrap.
    rgb_color = np.clip(rgb.detach().cpu().numpy() * std + mean, 0, 255).astype(np.uint8)

    try:
        cat_img = np.concatenate([rgb_color, pred_color, targ_color], axis=0)
    except ValueError:
        # Non-concatenation dims differ: resize the normal maps to the RGB size.
        size = (rgb.shape[1], rgb.shape[0])  # cv2 expects (width, height)
        pred_color = cv2.resize(pred_color, size)
        targ_color = cv2.resize(targ_color, size)
        cat_img = np.concatenate([rgb_color, pred_color, targ_color], axis=0)

    merge_name = os.path.splitext(filename)[0] + '_merge.jpg'
    plt.imsave(os.path.join(save_dir, merge_name), cat_img)
    # save to tensorboard (image is passed channel-first)
    if tb_logger is not None:
        tb_logger.add_image(merge_name, cat_img.transpose((2, 0, 1)), iter)
def get_data_for_log(pred: torch.Tensor, target: torch.Tensor, rgb: torch.Tensor):
    """
    Convert prediction, GT and normalized RGB tensors into loggable NumPy images.

    Inputs are never modified: negative values are clamped on copies, not in
    place (``.numpy()`` of a CPU tensor shares its memory).

    Args:
        pred (torch.Tensor): prediction map; 2-D after ``squeeze``.
        target (torch.Tensor): GT map; 2-D after ``squeeze``.
        rgb (torch.Tensor): normalized image; (3, H, W) after ``squeeze``.

    Returns:
        tuple:
            - rgb (np.ndarray, uint8, (3, H, W)): denormalized image, clipped to [0, 255].
            - pred_scale (np.ndarray, uint16): pred scaled so the shared maximum of
              pred/target maps to 10000.
            - target_scale (np.ndarray, uint16): target on the same scale.
            - pred_color: ``gray_to_colormap(pred)`` resized to H x W.
            - target_color: ``gray_to_colormap(target)`` resized to H x W.
    """
    mean = np.array([123.675, 116.28, 103.53])[:, np.newaxis, np.newaxis]
    std = np.array([58.395, 57.12, 57.375])[:, np.newaxis, np.newaxis]

    # np.maximum returns a new array; detach() allows tensors requiring grad.
    pred = np.maximum(pred.detach().squeeze().cpu().numpy(), 0)
    target = np.maximum(target.detach().squeeze().cpu().numpy(), 0)
    rgb = rgb.detach().squeeze().cpu().numpy()

    max_scale = max(pred.max(), target.max())
    if max_scale <= 0:
        # All-zero maps: avoid 0/0 -> NaN; the scaled result is all zeros anyway.
        max_scale = 1.0
    pred_scale = (pred / max_scale * 10000).astype(np.uint16)
    target_scale = (target / max_scale * 10000).astype(np.uint16)

    pred_color = gray_to_colormap(pred)
    target_color = gray_to_colormap(target)
    # cv2.resize takes (width, height); rgb is (3, H, W).
    pred_color = cv2.resize(pred_color, (rgb.shape[2], rgb.shape[1]))
    target_color = cv2.resize(target_color, (rgb.shape[2], rgb.shape[1]))

    rgb = np.clip(rgb * std + mean, 0, 255).astype(np.uint8)
    return rgb, pred_scale, target_scale, pred_color, target_color
def create_html(name2path, save_path='index.html', size=(256, 384)):
    """
    Write an HTML image table with one image column per ``name2path`` entry.

    Args:
        name2path (dict): maps each column title to that column's image
            content, passed as ``Col('img', title, content)``; presumably a
            list of image paths (confirm against callers).
        save_path (str): output HTML file.
        size (tuple): image size forwarded to ``imagetable`` as ``imsize``.
    """
    columns = [Col('img', title, content) for title, content in name2path.items()]
    imagetable(columns, out_file=save_path, imsize=size)
def vis_surface_normal(normal: torch.tensor, mask: torch.tensor=None) -> np.array:
"""
Visualize surface normal. Transfer surface normal value from [-1, 1] to [0, 255]
Aargs:
normal (torch.tensor, [h, w, 3]): surface normal
mask (torch.tensor, [h, w]): valid masks
"""
normal = normal.cpu().numpy().squeeze()
n_img_L2 = np.sqrt(np.sum(normal ** 2, axis=2, keepdims=True))
n_img_norm = normal / (n_img_L2 + 1e-8)
normal_vis = n_img_norm * 127
normal_vis += 128
normal_vis = normal_vis.astype(np.uint8)
if mask is not None:
mask = mask.cpu().numpy().squeeze()
normal_vis[~mask] = 0
return normal_vis