import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch

# Per-channel mean/std (RGB order) used to undo the input normalisation.
# These are the commonly used ImageNet statistics.
_IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
_IMAGENET_STD = np.array([0.229, 0.224, 0.225])


def _to_numpy(tensor):
    """Detach a tensor, move it to the CPU and return it as a NumPy array."""
    tensor = tensor.detach().cpu()
    if tensor.dtype == torch.bfloat16:
        # NumPy has no bfloat16 dtype, so Tensor.numpy() would raise.
        tensor = tensor.float()
    return tensor.numpy()


def _denormalize_to_bgr(hwc_img):
    """Undo mean/std normalisation of an HWC RGB array and return uint8 BGR."""
    rgb = _IMAGENET_STD * hwc_img + _IMAGENET_MEAN
    rgb = np.clip(rgb * 255, 0, 255).astype(np.uint8)
    return cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)


def _tagged_path(save_path, tag):
    """Insert ``tag`` before the extension of ``save_path`` (default ``.png``)."""
    root, ext = os.path.splitext(save_path)
    # cv2.imwrite picks the encoder from the extension, so one must be present.
    return f"{root}{tag}{ext or '.png'}"


def _infer_grid_shape(num_tokens, img_h, img_w):
    """Return the (rows, cols) token grid for a flattened map of ``num_tokens``.

    A perfect-square token count yields a square grid. Otherwise the
    factorisation whose aspect ratio is closest to the image's is chosen.
    """
    if num_tokens < 1 or img_h < 1 or img_w < 1:
        raise ValueError(
            f"Cannot infer a token grid from {num_tokens} tokens and a "
            f"{img_h}x{img_w} image"
        )
    side = int(round(np.sqrt(num_tokens)))
    if side * side == num_tokens:
        return side, side

    target = np.log(img_h / img_w)
    candidates = [
        (rows, num_tokens // rows)
        for rows in range(1, num_tokens + 1)
        if num_tokens % rows == 0
    ]
    return min(candidates, key=lambda rc: abs(np.log(rc[0] / rc[1]) - target))


def visualize_geo_prior(img, geo_prior, save_path, batch_idx=0, point_coords=None, normalize=True, alpha=0.6):
    """
    Visualize geometric prior matrix and overlay the result on the original image

    Four files are written: the matplotlib figure at ``save_path`` and, next to
    it, ``*_overlay``, ``*_heatmap`` and ``*_original`` images that reuse the
    extension of ``save_path`` (``.png`` when it has none).

    Args:
        img: Original image tensor [B,C,H,W]
        geo_prior: Geometric prior tensor with shape [B,HW,HW]
        save_path: Save path
        batch_idx: Batch index to visualize
        point_coords: Reference point coordinates in format (h, w), expressed
            in token-grid cells. If None, center point will be used
        normalize: Whether to normalize the display result
        alpha: Heatmap transparency, 0.0 means completely transparent,
            1.0 means completely opaque

    Returns:
        The [H,W] slice of ``geo_prior`` for the reference point (a tensor,
        not normalized).

    Raises:
        IndexError: If ``point_coords`` lies outside the token grid.
        ValueError: If the token grid shape cannot be inferred.
    """
    _, num_tokens, _ = geo_prior.shape
    img_h, img_w = img.shape[-2:]
    H, W = _infer_grid_shape(num_tokens, img_h, img_w)

    if point_coords is None:
        ref_h, ref_w = H // 2, W // 2
        ref_label = "center"
    else:
        ref_h, ref_w = point_coords
        # Without this check an out-of-range column silently wraps into a
        # neighbouring row of the flattened grid.
        if not (0 <= ref_h < H and 0 <= ref_w < W):
            raise IndexError(
                f"point_coords {tuple(point_coords)} outside token grid {H}x{W}"
            )
        ref_label = f"({point_coords[0]}, {point_coords[1]})"
    point_idx = ref_h * W + ref_w

    # Row `point_idx` holds the prior between the reference token and all others.
    relation_map = geo_prior[batch_idx][point_idx].reshape(H, W)
    relation_np = _to_numpy(relation_map)

    if normalize:
        lo, hi = relation_np.min(), relation_np.max()
        relation_np = (relation_np - lo) / (hi - lo + 1e-6)

    orig_img = _denormalize_to_bgr(np.transpose(_to_numpy(img[batch_idx]), (1, 2, 0)))
    orig_h, orig_w = orig_img.shape[:2]

    # Clip before the uint8 cast so un-normalized values cannot wrap around.
    heat_u8 = (np.clip(relation_np, 0.0, 1.0) * 255).astype(np.uint8)
    heatmap = cv2.applyColorMap(heat_u8, cv2.COLORMAP_RAINBOW)
    heatmap_full = cv2.resize(heatmap, (orig_w, orig_h), interpolation=cv2.INTER_LINEAR)
    overlay = cv2.addWeighted(orig_img, 1 - alpha, heatmap_full, alpha, 0)

    # Place the marker at the centre of the reference cell in image pixels.
    marker_xy = (
        int((ref_w + 0.5) * orig_w / W),
        int((ref_h + 0.5) * orig_h / H),
    )
    cv2.drawMarker(overlay, marker_xy, (255, 255, 255), cv2.MARKER_CROSS, 20, 2)

    cv2.imwrite(_tagged_path(save_path, '_overlay'), overlay)
    # The stand-alone heatmap is kept at token-grid resolution.
    cv2.imwrite(_tagged_path(save_path, '_heatmap'), heatmap)
    cv2.imwrite(_tagged_path(save_path, '_original'), orig_img)

    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        image = ax.imshow(relation_np, cmap='rainbow')
        fig.colorbar(image, ax=ax, label='Geometric Prior Strength')
        ax.plot(ref_w, ref_h, 'w*', markersize=10)
        ax.set_title(f'Geometric Prior Visualization (Ref Point: {ref_label})')
        fig.savefig(save_path)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)

    return relation_map


def save_feature_visualization(feature_map, save_path, output_size=(518, 518)):
    """
    Visualize feature map by averaging all feature maps into one image and
    resizing it (518*518 by default)

    Args:
        feature_map: feature map tensor with shape [C,H,W]; for a [B,C,H,W]
            input only the first item is visualized
        save_path: save path
        output_size: (width, height) of the saved image
    """
    if len(feature_map.shape) == 4:
        # squeeze(0) would be a no-op for B > 1 and the mean below would then
        # run over the batch axis instead of the channels.
        feature_map = feature_map[0]

    mean_feature = _to_numpy(torch.mean(feature_map, dim=0))
    lo, hi = mean_feature.min(), mean_feature.max()
    mean_feature = (mean_feature - lo) / (hi - lo + 1e-6)
    mean_feature = (mean_feature * 255).astype(np.uint8)
    mean_feature = cv2.resize(mean_feature, tuple(output_size), interpolation=cv2.INTER_LINEAR)

    colored_feature = cv2.applyColorMap(mean_feature, cv2.COLORMAP_VIRIDIS)
    cv2.imwrite(save_path, colored_feature)


def save_depth_visualization(depth_map, filename):
    """
    Save depth map visualization as a colored image.

    Args:
        depth_map (torch.Tensor): Depth map tensor with shape [H, W] or
            [B, H, W]; for batched input only the first map is saved
        filename (str): Output file path for the visualization
    """
    depth = depth_map.detach()
    # applyColorMap expects a single 2-D map, so drop leading batch axes.
    while depth.dim() > 2:
        depth = depth[0]

    d_min = depth.min()
    span = depth.max() - d_min
    if span > 0:
        scaled = (depth - d_min) / span * 255.0
    else:
        # A constant map would divide by zero; render it as all zeros.
        scaled = torch.zeros_like(depth)

    depth_u8 = _to_numpy(scaled).astype(np.uint8)
    colored_depth = cv2.applyColorMap(depth_u8, cv2.COLORMAP_INFERNO)
    cv2.imwrite(filename, colored_depth)


def save_image(img_tensor, filename):
    """
    Save image tensor as a BGR image file.

    The tensor is de-normalized with the module's mean/std constants before
    it is written.

    Args:
        img_tensor (torch.Tensor): Image tensor with shape [C, H, W] or
            [B, C, H, W]; for batched input only the first image is saved
        filename (str): Output file path for the image
    """
    img = _to_numpy(img_tensor)
    if img.ndim == 4:
        img = img[0]
    if img.shape[0] == 3:
        # Channels-first to channels-last for OpenCV.
        img = np.transpose(img, (1, 2, 0))

    cv2.imwrite(filename, _denormalize_to_bgr(img))