File size: 5,189 Bytes
7f0f123
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch

def visualize_geo_prior(img, geo_prior, save_path, batch_idx=0, point_coords=None, normalize=True, alpha=0.6):
    """
    Visualize geometric prior matrix and overlay the result on the original image
    Args:
        img: Original image tensor [B,C,H,W]
        geo_prior: Geometric prior tensor with shape [B,HW,HW]
        save_path: Save path
        batch_idx: Batch index to visualize
        point_coords: Reference point coordinates in format (h, w). If None, center point will be used
        normalize: Whether to normalize the display result
        alpha: Heatmap transparency, 0.0 means completely transparent, 1.0 means completely opaque
    """
    B, HW, _ = geo_prior.shape
    H = int(np.sqrt(HW))
    W = H  
    geo_prior_single = geo_prior[batch_idx]  # [HW,HW]
    
    if point_coords is None:
        center_h, center_w = H // 2, W // 2
        point_idx = center_h * W + center_w
    else:
        h, w = point_coords
        point_idx = h * W + w
    relation = geo_prior_single[point_idx]  # [HW]
    relation_map = relation.reshape(H, W)
    relation_np = relation_map.detach().cpu().numpy()
    
    if normalize:
        relation_np = (relation_np - relation_np.min()) / (relation_np.max() - relation_np.min() + 1e-6)
    
    orig_img = img[batch_idx].detach().cpu().numpy()
    orig_img = np.transpose(orig_img, (1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    orig_img = std * orig_img + mean
    orig_img = np.clip(orig_img * 255, 0, 255).astype(np.uint8)
    orig_img = cv2.cvtColor(orig_img, cv2.COLOR_RGB2BGR)
    
    orig_h, orig_w = orig_img.shape[:2]
    
    colored_map = cv2.applyColorMap((relation_np * 255).astype(np.uint8), cv2.COLORMAP_RAINBOW)
    colored_map = cv2.resize(colored_map, (orig_w, orig_h), interpolation=cv2.INTER_LINEAR)
    
    overlay = cv2.addWeighted(orig_img, 1-alpha, colored_map, alpha, 0)
    
    if point_coords is None:
        center_w_orig = int(center_w * orig_w / W)
        center_h_orig = int(center_h * orig_h / H)
        cv2.drawMarker(overlay, (center_w_orig, center_h_orig), (255, 255, 255), cv2.MARKER_CROSS, 20, 2)
    else:
        w_orig = int(w * orig_w / W)
        h_orig = int(h * orig_h / H)
        cv2.drawMarker(overlay, (w_orig, h_orig), (255, 255, 255), cv2.MARKER_CROSS, 20, 2)
    
    cv2.imwrite(save_path.replace('.png', '_overlay.png'), overlay)
    
    colored_map = cv2.applyColorMap((relation_np * 255).astype(np.uint8), cv2.COLORMAP_RAINBOW)
    cv2.imwrite(save_path.replace('.png', '_heatmap.png'), colored_map)
    cv2.imwrite(save_path.replace('.png', '_original.png'), orig_img)
    
    plt.figure(figsize=(10, 8))
    plt.imshow(relation_np, cmap='rainbow')
    plt.colorbar(label='Geometric Prior Strength')
    
    if point_coords is None:
        plt.plot(center_w, center_h, 'w*', markersize=10)
    else:
        plt.plot(w, h, 'w*', markersize=10)
    
    plt.title(f'Geometric Prior Visualization (Ref Point: {"center" if point_coords is None else f"({point_coords[0]}, {point_coords[1]})"})')
    plt.savefig(save_path)
    plt.close()
    
    return relation_map


def save_feature_visualization(feature_map, save_path):
    """
    Visualize feature map by averaging all feature maps into one image and resize to 518*518
    Args:
        feature_map: feature map tensor with shape [C,H,W]
        save_path: save path
    """
    
    if len(feature_map.shape) == 4:
        feature_map = feature_map.squeeze(0)
    mean_feature = torch.mean(feature_map, dim=0).detach().cpu().numpy()
    mean_feature = (mean_feature - mean_feature.min()) / (mean_feature.max() - mean_feature.min() + 1e-6)
    mean_feature = (mean_feature * 255).astype(np.uint8)
    mean_feature = cv2.resize(mean_feature, (518, 518), interpolation=cv2.INTER_LINEAR)
    
    colored_feature = cv2.applyColorMap(mean_feature, cv2.COLORMAP_VIRIDIS)
    cv2.imwrite(save_path, colored_feature)

def save_depth_visualization(depth_map, filename):
    """
    Save depth map visualization as a colored image.
    
    Args:
        depth_map (torch.Tensor): Depth map tensor with shape [H, W] or [B, H, W]
        filename (str): Output file path for the visualization
    """
    depth_norm = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min()) * 255.0
    depth_norm = depth_norm.detach().cpu().numpy().astype(np.uint8)
    colored_depth = cv2.applyColorMap(depth_norm, cv2.COLORMAP_INFERNO)
    cv2.imwrite(filename, colored_depth)

def save_image(img_tensor, filename):
    """
    Save image tensor as a BGR image file.
    
    Args:
        img_tensor (torch.Tensor): Image tensor with shape [C, H, W] or [B, C, H, W]
        filename (str): Output file path for the image
    """
    img = img_tensor.detach().cpu().numpy()

    if img.shape[0] == 3:  
        img = np.transpose(img, (1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    img = std * img + mean
    img = np.clip(img * 255, 0, 255).astype(np.uint8)
    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    cv2.imwrite(filename, img)