CraigDroke committed on
Commit
32938bb
·
1 Parent(s): 44c1e94

Added all files

Browse files
Files changed (4) hide show
  1. interface.py +85 -0
  2. plaus_functs.py +784 -0
  3. plot_functs.py +154 -0
  4. toy_problem_pgt.py +247 -0
interface.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import sys
3
+ from toy_problem_pgt import toy_problem
4
+
5
# User-facing help text shown in the "Help" accordion of the Gradio demo.
# (Fixed typos: "regulaization" -> "regularization", "atrribution" -> "attribution".)
help_guide = """
## Help Guide

This demo allows you to experiment with the toy problem from the Plausibility Guided Training (PGT) paper.

### Input Parameters:

- **PGT Coefficient**: Choose a number between 0.1 and 10. This determines the emphasis given to the PGT loss function in the training process.
- **Focus Coefficient**: Choose a number between 0.01 and 1. This determines the concentration of pixels around the object that will be rewarded. A higher coefficient results in a more focused reward.
- **X Coord**: Choose a number between 0 and 1. This sets the X coordinate of the target.
- **Y Coord**: Choose a number between 0 and 1. This sets the Y coordinate of the target.

### Outputs:

1. **First 2 images**: Displays the distance regularization map and the first attribution map step.
2. **Second 8 images**: Displays each of the remaining attribution map steps.
3. **PGT Losses**: Visualizes the plausibility losses over each step.
4. **PGT Scores**: Displays the plausibility scores over each step.

"""
25
+
26
+
27
if __name__ == "__main__":

    # Build the Gradio UI for the PGT toy-problem demo.
    with gr.Blocks(title="toy problem demo", theme=gr.themes.Base()) as demo:
        gr.Markdown(
            """
            # Toy Problem Demo
            This is a demo of the toy problem implementation.
            """
        )

        with gr.Accordion("Help", open=False):
            gr.Markdown(help_guide)

        # Input parameter controls.
        # (Fixed label typo "Coefficent" -> "Coefficient"; the focus info text
        # previously said "between 0.1 and 1" but minimum=0.01 — now consistent
        # with the widget bounds and the help guide.)
        with gr.Row() as file_settings:

            pgt_coeff = gr.Number(label="PGT Coefficient", info="Choose a number between 0.1 and 10",
                                  minimum=0.1, maximum=10, value=1, interactive=True, step=1, show_label=True)
            focus_coeff = gr.Number(label="Focus Coefficient", info="Choose a number between 0.01 and 1",
                                    minimum=0.01, maximum=1, value=0.2, interactive=True, step=1, show_label=True)

            #TODO - Target info (this is where we can adjust)
            #We are just going to give the user access to the number of bounding boxes, and the xy coords
            # num_bb = gr.Number(label="Number of Bounding Boxes",info="Choose a number",
            #                    minimum=0,maximum=0,value=0,interactive=True,step=1,show_label=True)
            x_coord = gr.Number(label="X Coord", info="Choose a number between 0 and 1",
                                minimum=0, maximum=1, value=0.8, interactive=True, step=1, show_label=True)
            y_coord = gr.Number(label="Y Coord", info="Choose a number between 0 and 1",
                                minimum=0, maximum=1, value=0.76, interactive=True, step=1, show_label=True)

        # Output image panels (filepaths produced by toy_problem).
        with gr.Row() as outputs:
            output_img1 = gr.Image(type='filepath', label="First 2 images",
                                   show_download_button=True, show_share_button=True, interactive=False, visible=True)
            output_img2 = gr.Image(type='filepath', label="9 images",
                                   show_download_button=True, show_share_button=True, interactive=False, visible=True, scale=4)

        with gr.Row() as outputs_2:
            output_img3 = gr.Image(type='filepath', label="PGT Losses",
                                   show_download_button=True, show_share_button=True, interactive=False, visible=True)
            output_img4 = gr.Image(type='filepath', label="PGT Scores",
                                   show_download_button=True, show_share_button=True, interactive=False, visible=True)

        # List of components for clearing
        clear_comp_list = [output_img1, output_img2, output_img3, output_img4]

        # Row for start and clear buttons
        with gr.Row() as buttons:
            start = gr.Button(value="Start")
            clear = gr.ClearButton(value='Clear All', components=clear_comp_list,
                                   interactive=True, visible=True)

        # Components fed into / produced by toy_problem when Start is clicked.
        run_inputs = [pgt_coeff, focus_coeff, x_coord, y_coord]
        run_outputs = [output_img1, output_img2, output_img3, output_img4]

        start.click(toy_problem, inputs=run_inputs, outputs=run_outputs)

    demo.queue().launch()
plaus_functs.py ADDED
@@ -0,0 +1,784 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from plot_functs import *
4
+ from plot_functs import normalize_tensor, overlay_mask, imshow
5
+ import math
6
+ import time
7
+ import matplotlib.path as mplPath
8
+ from matplotlib.path import Path
9
+ from utils.general import non_max_suppression, xyxy2xywh, scale_coords
10
+
11
def get_gradient(img, grad_wrt, norm=False, absolute=True, grayscale=False, keepmean=False):
    """
    Compute the gradient of an image with respect to a given tensor.

    Args:
        img (torch.Tensor): The input image tensor (must be part of the graph that
            produced grad_wrt, with requires_grad set).
        grad_wrt (torch.Tensor): The tensor with respect to which the gradient is computed.
        norm (bool, optional): Whether to normalize the gradient per image. Defaults to False.
        absolute (bool, optional): Whether to take the absolute values of the gradients. Defaults to True.
        grayscale (bool, optional): Whether to collapse the channel dimension. Defaults to False.
        keepmean (bool, optional): Whether to keep the mean value of the attribution map
            after normalization. Defaults to False.

    Returns:
        torch.Tensor: The computed attribution map.

    """
    # A non-scalar grad_wrt needs explicit grad_outputs (ones); a scalar does not.
    if (grad_wrt.shape != torch.Size([1])) and (grad_wrt.shape != torch.Size([])):
        grad_wrt_outputs = torch.ones_like(grad_wrt).clone().detach()#.requires_grad_(True)#.retains_grad_(True)
    else:
        grad_wrt_outputs = None
    attribution_map = torch.autograd.grad(grad_wrt, img,
                                          grad_outputs=grad_wrt_outputs,
                                          create_graph=True, # Create graph to allow for higher order derivatives but slows down computation significantly
                                          )[0]
    if absolute:
        attribution_map = torch.abs(attribution_map) # attribution_map ** 2 # Take absolute values of gradients
    if grayscale: # Convert to grayscale, saves vram and computation time for plaus_eval
        attribution_map = torch.sum(attribution_map, 1, keepdim=True)
    if norm:
        if keepmean:
            # Record statistics before normalization so the mean can be restored below.
            attmean = torch.mean(attribution_map)
            attmin = torch.min(attribution_map)
            attmax = torch.max(attribution_map)
        attribution_map = normalize_batch(attribution_map) # Normalize attribution maps per image in batch
        if keepmean:
            # Re-center, then shift by the original mean expressed in normalized units.
            attribution_map -= attribution_map.mean()
            attribution_map += (attmean / (attmax - attmin))

    return attribution_map
50
+
51
def get_gaussian(img, grad_wrt, norm=True, absolute=True, grayscale=True, keepmean=False):
    """
    Draw Gaussian noise shaped like the input image (random-attribution baseline).

    Args:
        img (torch.Tensor): Input image; only its shape/dtype/device are used.
        grad_wrt: Unused; kept so the signature mirrors get_gradient.
        norm (bool, optional): Normalize the noise per image in the batch. Defaults to True.
        absolute (bool, optional): Use the absolute value of the noise. Defaults to True.
        grayscale (bool, optional): Sum over the channel dimension. Defaults to True.
        keepmean (bool, optional): Restore the pre-normalization mean. Defaults to False.

    Returns:
        torch.Tensor: Generated Gaussian noise.
    """

    noise = torch.randn_like(img)

    if absolute:
        noise = noise.abs()
    if grayscale:
        # Collapse channels to one map; cheaper to handle downstream.
        noise = noise.sum(dim=1, keepdim=True)
    if norm:
        if keepmean:
            # Capture statistics before they are destroyed by normalization.
            pre_mean = noise.mean()
            pre_min = noise.min()
            pre_max = noise.max()
        noise = normalize_batch(noise)  # per-image min-max normalization
        if keepmean:
            # Re-center, then shift by the original mean in normalized units.
            noise = noise - noise.mean() + pre_mean / (pre_max - pre_min)

    return noise
84
+
85
+
86
def get_plaus_score(targets_out, attr, debug=False, corners=False, imgs=None, eps = 1e-7):
    # TODO: Remove imgs from this function and only take it as input if debug is True
    """
    Calculate the plausibility score: the fraction of total attribution mass
    that falls inside the target bounding boxes.

    Args:
        targets_out (torch.Tensor): Target rows (batch_index, class, x, y, w, h).
        attr (torch.Tensor): Attribution maps, shape (batch, C, H, W).
        debug (bool, optional): Save visualization images. Defaults to False.
        corners (bool, optional): Targets are already pixel corner coords; forwarded
            to get_bbox_map (previously this flag was silently ignored). Defaults to False.
        imgs (torch.Tensor, optional): Input images, used only when debug is True.
        eps (float, optional): Unused; kept for interface compatibility.

    Returns:
        torch.Tensor: Scalar plausibility score in [0, 1].
    """
    # Guard against NaNs in the attribution map poisoning the ratio below.
    if torch.isnan(attr).any():
        attr = torch.nan_to_num(attr, nan=0.0)

    coords_map = get_bbox_map(targets_out, attr, corners=corners)
    plaus_score = torch.sum(attr * coords_map) / torch.sum(attr)

    if debug:
        if imgs is None:
            # Bug fix: this fallback previously ran only after imgs[i] had been
            # dereferenced inside the loop, so debug=True with imgs=None crashed.
            imgs = torch.zeros_like(attr)
        for i in range(len(coords_map)):
            # get_bbox_map returns a float mask; boolean indexing needs bool.
            bbox_mask = torch.cat([coords_map[i][:1]] * 3, dim=0).to(torch.bool)
            test_bbox = torch.zeros_like(imgs[i])
            test_bbox[bbox_mask] = imgs[i][bbox_mask]
            imshow(test_bbox, save_path='figs/test_bbox')
            imshow(imgs[i], save_path='figs/im0')
            imshow(attr[i], save_path='figs/attr')

    return plaus_score
150
+
151
+ from utils.general import bbox_iou
152
+
153
def get_attr_corners(targets_out, attr, debug=False, corners=False, imgs=None, eps = 1e-7):
    # TODO: Remove imgs from this function and only take it as input if debug is True
    """
    Estimate a bounding box from each attribution map and score it against the targets.

    Splits each map into four quadrants around the target center, takes the argmax
    of each quadrant as a candidate corner, builds the enclosing box of those
    candidates, and returns the mean CIoU between those boxes and the targets.

    Args:
        targets_out (torch.Tensor): Target rows (batch_index, class, x, y, w, h),
            coordinates normalized to [0, 1] unless `corners` is True.
        attr (torch.Tensor): Attribution maps, shape (batch, C, H, W).
        debug (bool, optional): Save visualization images (requires imgs). Defaults to False.
        corners (bool, optional): Targets are already pixel corner coords. Defaults to False.
        imgs (torch.Tensor, optional): Input images, used only for debug output.
        eps (float, optional): Unused; kept for interface compatibility.

    Returns:
        torch.Tensor: Scalar mean CIoU score.
    """
    # if imgs is None:
    #     imgs = torch.zeros_like(attr)
    # with torch.no_grad():
    target_inds = targets_out[:, 0].int()  # batch index of each target row
    xyxy_batch = targets_out[:, 2:6]# * pre_gen_gains[out_num]
    # Per-row scale factors from normalized coords to pixels.
    num_pixels = torch.tile(torch.tensor([attr.shape[2], attr.shape[3], attr.shape[2], attr.shape[3]], device=attr.device), (xyxy_batch.shape[0], 1))
    # num_pixels = torch.tile(torch.tensor([1.0, 1.0, 1.0, 1.0], device=imgs.device), (xyxy_batch.shape[0], 1))
    xyxy_corners = (corners_coords_batch(xyxy_batch) * num_pixels).int()
    co = xyxy_corners
    if corners:
        co = targets_out[:, 2:6].int()
    coords_map = torch.zeros_like(attr, dtype=torch.bool)
    # rows = np.arange(co.shape[0])
    # NOTE(review): columns 1/3 index the first spatial axis and 0/2 the second —
    # confirm this row/column convention against the producer of targets_out.
    x1, x2 = co[:,1], co[:,3]
    y1, y2 = co[:,0], co[:,2]

    for ic in range(co.shape[0]): # potential for speedup here with torch indexing instead of for loop
        coords_map[target_inds[ic], :,x1[ic]:x2[ic],y1[ic]:y2[ic]] = True

    if torch.isnan(attr).any():
        attr = torch.nan_to_num(attr, nan=0.0)
    if debug:
        # NOTE(review): imgs must be provided when debug=True, else imgs[i] raises.
        for i in range(len(coords_map)):
            coords_map3ch = torch.cat([coords_map[i][:1], coords_map[i][:1], coords_map[i][:1]], dim=0)
            test_bbox = torch.zeros_like(imgs[i])
            test_bbox[coords_map3ch] = imgs[i][coords_map3ch]
            imshow(test_bbox, save_path='figs/test_bbox')
            imshow(imgs[i], save_path='figs/im0')
            imshow(attr[i], save_path='figs/attr')

    # att_select = attr[coords_map]
    # with torch.no_grad():
    # IoU_num = (torch.sum(attr[coords_map]))
    # IoU_denom = torch.sum(attr)
    # IoU_ = (IoU_num / (IoU_denom))

    # IoU_ = torch.max(attr[coords_map]) - torch.max(attr[~coords_map])
    co = (xyxy_batch * num_pixels).int()  # target centers (and sizes) in pixels
    x1 = co[:,1] + 1
    y1 = co[:,0] + 1
    # with torch.no_grad():
    attr_ = torch.sum(attr, 1, keepdim=True)  # collapse channels to one map
    corners_attr = None #torch.zeros(len(xyxy_batch), 4, device=attr.device)
    for ic in range(co.shape[0]):
        # The four quadrants of the map, split at the target center (x1, y1).
        attr0 = attr_[target_inds[ic], :,:x1[ic],:y1[ic]]
        attr1 = attr_[target_inds[ic], :,:x1[ic],y1[ic]:]
        attr2 = attr_[target_inds[ic], :,x1[ic]:,:y1[ic]]
        attr3 = attr_[target_inds[ic], :,x1[ic]:,y1[ic]:]

        # Brightest point of each quadrant is a candidate corner.
        x_0, y_0 = max_indices_2d(attr0[0])
        x_1, y_1 = max_indices_2d(attr1[0])
        x_2, y_2 = max_indices_2d(attr2[0])
        x_3, y_3 = max_indices_2d(attr3[0])

        # Shift quadrant-local indices back into full-map coordinates.
        y_1 += y1[ic]
        x_2 += x1[ic]
        x_3 += x1[ic]
        y_3 += y1[ic]

        # Tightest box enclosing the four candidates, normalized back to [0, 1].
        max_corners = torch.cat([torch.min(x_0, x_2).unsqueeze(0) / attr_.shape[2],
                                 torch.min(y_0, y_1).unsqueeze(0) / attr_.shape[3],
                                 torch.max(x_1, x_3).unsqueeze(0) / attr_.shape[2],
                                 torch.max(y_2, y_3).unsqueeze(0) / attr_.shape[3]])
        if corners_attr is None:
            corners_attr = max_corners
        else:
            corners_attr = torch.cat([corners_attr, max_corners], dim=0)
        # corners_attr[ic] = max_corners
    # corners_attr = attr[:,0,:4,0]
    corners_attr = corners_attr.view(-1, 4)
    # corners_attr = torch.stack(corners_attr, dim=0)
    IoU_ = bbox_iou(corners_attr.T, xyxy_batch, x1y1x2y2=False, metric='CIoU')
    plaus_score = IoU_.mean()

    return plaus_score
242
+
243
def max_indices_2d(x_inp):
    """
    Return the (row, col) index of the maximum element of a 2-D tensor.

    Args:
        x_inp (torch.Tensor): 2-D input tensor.

    Returns:
        torch.Tensor: 1-D tensor [row, col] locating the (first) maximum.
    """
    # Removed a stray `torch.max(x_inp,)` call whose result was discarded.
    # argmax over the flattened tensor, then unravel into 2-D coordinates.
    index = torch.argmax(x_inp)
    x = index // x_inp.shape[1]
    y = index % x_inp.shape[1]

    return torch.cat([x.unsqueeze(0), y.unsqueeze(0)])
252
+
253
+
254
def point_in_polygon(poly, grid):
    """
    Even-odd (ray casting) point-in-polygon test, looping over polygon edges.

    Args:
        poly (torch.Tensor): Polygon vertices, shape (num_points, 2) as (x, y).
        grid (torch.Tensor): Query points with (x, y) along the last dimension.

    Returns:
        torch.Tensor: Boolean tensor (grid shape minus last dim), True inside.
    """
    gx, gy = grid[..., 0], grid[..., 1]
    inside = torch.zeros_like(gx, dtype=torch.bool)
    n = poly.shape[0]
    for i in range(n):
        j = (i - 1) % n  # previous vertex, wrapping around
        xi, yi = poly[i, 0], poly[i, 1]
        xj, yj = poly[j, 0], poly[j, 1]
        # Edge straddles the horizontal line through the point (either direction)...
        straddles = ((yi < gy) & (yj >= gy)) | ((yj < gy) & (yi >= gy))
        # ...and the point lies left of the edge at that height.
        left_of_edge = (gx - xi) < (xj - xi) * (gy - yi) / (yj - yi)
        # Flip parity for each qualifying crossing.
        inside = inside ^ (straddles & left_of_edge)
    return inside
268
+
269
def point_in_polygon_gpu(poly, grid):
    """
    Vectorized even-odd (ray casting) point-in-polygon test.

    Evaluates all edge-crossing conditions at once instead of looping over
    edges, then XOR-reduces them to a parity bit per grid point.

    Args:
        poly (torch.Tensor): Polygon vertices, shape (num_points, 2) as (x, y).
        grid (torch.Tensor): Grid of query points with (x, y) in the last dim.
            NOTE(review): the expand below uses grid.shape[0] for both spatial
            dims, so this appears to assume a square (N, N, 2) grid — confirm.

    Returns:
        torch.Tensor: Boolean tensor, True where the grid point is inside.
    """
    num_points = poly.shape[0]
    i = torch.arange(num_points)
    j = (i - 1) % num_points  # previous vertex index, wrapping around
    # Expand dimensions
    # t0 = time.time()
    poly_expanded = poly.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, grid.shape[0], grid.shape[0])
    # t1 = time.time()
    # Edge straddles the horizontal line through the point (in either direction)...
    cond1 = (poly_expanded[i, 1] < grid[..., 1]) & (poly_expanded[j, 1] >= grid[..., 1])
    cond2 = (poly_expanded[j, 1] < grid[..., 1]) & (poly_expanded[i, 1] >= grid[..., 1])
    # ...and the point lies to the left of the edge at that height.
    cond3 = (grid[..., 0] - poly_expanded[i, 0]) < (poly_expanded[j, 0] - poly_expanded[i, 0]) * (grid[..., 1] - poly_expanded[i, 1]) / (poly_expanded[j, 1] - poly_expanded[i, 1])
    # t2 = time.time()
    oddNodes = torch.zeros_like(grid[..., 0], dtype=torch.bool)
    cond = (cond1 | cond2) & cond3  # per-edge crossing flags, shape (num_points, ...)
    # t3 = time.time()
    # efficiently perform xor using gpu and avoiding cpu as much as possible
    # Pairwise-halving XOR reduction over the edge dimension; when the length is
    # odd the last element is set aside in c and folded back in at the end.
    c = []
    while len(cond) > 1:
        if len(cond) % 2 == 1: # odd number of elements
            c.append(cond[-1])
            cond = cond[:-1]
        cond = torch.bitwise_xor(cond[:int(len(cond)/2)], cond[int(len(cond)/2):])
    for c_ in c:
        cond = torch.bitwise_xor(cond, c_)
    oddNodes = cond
    # t4 = time.time()
    # for c in cond:
    #     oddNodes = oddNodes ^ c
    # print(f'expand time: {t1-t0} | cond123 time: {t2-t1} | cond logic time: {t3-t2} | bitwise xor time: {t4-t3}')
    # print(f'point in polygon time gpu: {t4-t0}')
    # oddNodes = oddNodes ^ (cond1 | cond2) & cond3
    return oddNodes
301
+
302
+
303
def bitmap_for_polygon(poly, h, w):
    """
    Rasterize a polygon into an (1, h, w) boolean bitmap via point_in_polygon.

    Args:
        poly (torch.Tensor): Polygon vertices, shape (num_points, 2) as (x, y).
        h (int): Bitmap height in pixels.
        w (int): Bitmap width in pixels.

    Returns:
        torch.Tensor: Boolean bitmap of shape (1, h, w).
    """
    ys = torch.arange(h).to(poly.device).float()
    xs = torch.arange(w).to(poly.device).float()
    # 'ij' meshgrid: rows follow y, columns follow x.
    grid_y, grid_x = torch.meshgrid(ys, xs)
    coords = torch.stack((grid_x, grid_y), dim=-1)  # (h, w, 2) as (x, y)
    return point_in_polygon(poly, coords).unsqueeze(0)
310
+
311
+
312
def corners_coords(center_xywh):
    """Convert a single (cx, cy, w, h) box to (x1, y1, x2, y2) corner form."""
    cx, cy, w, h = center_xywh
    left = cx - w/2
    top = cy - h/2
    return torch.tensor([left, top, left + w, top + h])
317
+
318
def corners_coords_batch(center_xywh):
    """Convert (cx, cy, w, h) boxes to (x1, y1, x2, y2) corner form, batched."""
    cx, cy = center_xywh[:, 0], center_xywh[:, 1]
    w, h = center_xywh[:, 2], center_xywh[:, 3]
    left = cx - w/2
    top = cy - h/2
    return torch.stack([left, top, left + w, top + h], dim=1)
324
+
325
def normalize_batch(x):
    """
    Normalize a batch of tensors along each channel.

    Each batch element is independently min-max rescaled to [0, 1].

    Args:
        x (torch.Tensor): Input tensor of shape (batch_size, channels, height, width).

    Returns:
        torch.Tensor: Normalized tensor of the same shape as the input.
    """
    # Reduce over every non-batch dimension; keepdim lets the result broadcast.
    reduce_dims = tuple(range(1, x.dim()))
    per_item_min = torch.amin(x, dim=reduce_dims, keepdim=True)
    per_item_max = torch.amax(x, dim=reduce_dims, keepdim=True)
    return (x - per_item_min) / (per_item_max - per_item_min)
343
+
344
def get_detections(model_clone, img):
    """
    Run the model on an image batch in inference mode.

    Args:
        model_clone (nn.Module): The model to use for detection; switched to eval mode.
        img (torch.Tensor): The input image tensor.

    Returns:
        tuple: (det_out, out) — the two outputs produced by calling the model.
    """
    model_clone.eval()  # inference behavior (no dropout / BN updates)
    # No gradients needed for plain inference.
    with torch.no_grad():
        detections, raw_outputs = model_clone(img)

    del img  # drop the local reference to the (potentially large) input promptly

    return detections, raw_outputs
364
+
365
def get_labels(det_out, imgs, targets, opt):
    """
    Convert raw detector output into normalized predicted label rows.

    Runs NMS on det_out, sorts each image's predictions by confidence, keeps
    them down to (and including) the first prediction at or below 0.1
    confidence, and reformats to rows of (batch_index, class, x, y, w, h)
    with coordinates normalized to [0, 1] — the same layout as `targets`.

    Args:
        det_out: Raw detection output, passed to non_max_suppression.
        imgs (torch.Tensor): Input image batch, shape (nb, C, H, W).
        targets (torch.Tensor): Ground-truth rows (batch_index, class, x, y, w, h), normalized.
        opt: Options namespace; only opt.save_hybrid is read here.

    Returns:
        torch.Tensor: Predicted label rows on imgs.device.
    """
    ###################### Get predicted labels ######################
    nb, _, height, width = imgs.shape # batch size, channels, height, width
    targets_ = targets.clone()
    targets_[:, 2:] = targets_[:, 2:] * torch.Tensor([width, height, width, height]).to(imgs.device) # to pixels
    lb = [targets_[targets_[:, 0] == i, 1:] for i in range(nb)] if opt.save_hybrid else [] # for autolabelling
    o = non_max_suppression(det_out, conf_thres=0.001, iou_thres=0.6, labels=lb, multi_label=True)
    pred_labels = []
    for si, pred in enumerate(o):
        labels = targets_[targets_[:, 0] == si, 1:]
        nl = len(labels)  # NOTE(review): unused here; leftover from validation-loop code
        predn = pred.clone()
        # Sort by confidence (column 4), highest first
        sort_indices = torch.argsort(pred[:, 4], dim=0, descending=True)
        # Apply the sorting indices to the tensor
        sorted_pred = predn[sort_indices]
        # Remove predictions with less than 0.1 confidence
        n_conf = int(torch.sum(sorted_pred[:,4]>0.1)) + 1
        sorted_pred = sorted_pred[:n_conf]
        # Prefix each row with its batch index and reorder to (batch, class, box)
        new_col = torch.ones((sorted_pred.shape[0], 1), device=imgs.device) * si
        preds = torch.cat((new_col, sorted_pred[:, [5, 0, 1, 2, 3]]), dim=1)
        preds[:, 2:] = xyxy2xywh(preds[:, 2:]) # xywh
        gn = torch.tensor([width, height])[[1, 0, 1, 0]] # normalization gain whwh
        preds[:, 2:] /= gn.to(imgs.device) # from pixels
        pred_labels.append(preds)
    pred_labels = torch.cat(pred_labels, 0).to(imgs.device)

    return pred_labels
    ##################################################################
394
+
395
+ from torchvision.utils import make_grid
396
+
397
def get_center_coords(attr):
    """
    Locate the centroid of the brightest pixels in an attribution map.

    Bug fix: the body previously referenced an undefined name `img_tensor`
    instead of the `attr` parameter, so every call raised NameError. Since no
    caller could have relied on the old (crashing) behavior, the centroid is
    now also returned instead of discarded.

    Args:
        attr (torch.Tensor): 2-D attribution map (assumes shape (H, W) — a
            higher-rank input would break the two-way unpack below).

    Returns:
        tuple[float, float]: (centroid_x, centroid_y) of pixels brighter than
        95% of the maximum. NaN if no pixel exceeds the threshold.
    """
    # Scale to [0, 1] relative to the brightest pixel.
    attr = attr / attr.max()

    # Define a brightness threshold
    threshold = 0.95

    # Create a binary mask of the bright pixels
    mask = attr > threshold

    # Get the coordinates of the bright pixels
    y_coords, x_coords = torch.where(mask)

    # Calculate the centroid of the bright pixels
    centroid_x = x_coords.float().mean().item()
    centroid_y = y_coords.float().mean().item()

    print(f'The central bright point is at ({centroid_x}, {centroid_y})')

    return centroid_x, centroid_y
416
+
417
+
418
def get_distance_grids(attr, targets, imgs=None, focus_coeff=0.5, debug=False):
    """
    Compute the distance grids from each pixel to the target coordinates.

    Args:
        attr (torch.Tensor): Attribution maps, shape (batch, C, H, W).
        targets (torch.Tensor): Target rows (batch_index, class, x, y, ...) with
            coordinates normalized to [0, 1].
        imgs (torch.Tensor, optional): Images, used only for debug visualization.
        focus_coeff (float, optional): Focus coefficient, smaller means more focused. Defaults to 0.5.
        debug (bool, optional): Whether to visualize debug information. Defaults to False.

    Returns:
        torch.Tensor: Distance grids, min-max normalized per image and raised to
        focus_coeff; all-zero for images that have no targets.
    """

    # Assign the height and width of the input tensor to variables
    # NOTE(review): height is taken from the last dim and width from the
    # second-to-last — the reverse of the usual (H, W) layout. Confirm intended
    # (harmless for the square maps used by the toy problem).
    height, width = attr.shape[-1], attr.shape[-2]

    # attr = torch.abs(attr) # Take absolute values of gradients
    # attr = normalize_batch(attr) # Normalize attribution maps per image in batch

    # Create a grid of indices
    xx, yy = torch.stack(torch.meshgrid(torch.arange(height), torch.arange(width))).to(attr.device)
    idx_grid = torch.stack((xx, yy), dim=-1).float()

    # Expand the grid to match the batch size
    idx_batch_grid = idx_grid.expand(attr.shape[0], -1, -1, -1)

    # Initialize a list to store the distance grids
    dist_grids_ = [[]] * attr.shape[0]

    # Loop over batches
    for j in range(attr.shape[0]):
        # Get the rows where the first column is the current unique value
        # (boolean-mask indexing returns a copy, so targets itself is not mutated below)
        rows = targets[targets[:, 0] == j]

        if len(rows) != 0:
            # Create a tensor for the target coordinates
            xy = rows[:,2:4] # y, x
            # Flip the x and y coordinates and scale them to the image size
            xy[:, 0], xy[:, 1] = xy[:, 1] * width, xy[:, 0] * height # y, x to x, y
            xy_center = xy.unsqueeze(1).unsqueeze(1)#.requires_grad_(True)

            # Compute the Euclidean distance from each pixel to the target coordinates
            dists = torch.norm(idx_batch_grid[j].expand(len(xy_center), -1, -1, -1) - xy_center, dim=-1)

            # Pick the closest distance to any target for each pixel
            dist_grid_ = torch.min(dists, dim=0)[0].unsqueeze(0)
            # Replicate to 3 channels only when the attribution map is 3-channel.
            dist_grid = torch.cat([dist_grid_, dist_grid_, dist_grid_], dim=0) if attr.shape[1] == 3 else dist_grid_
        else:
            # Set grid to zero if no targets are present
            dist_grid = torch.zeros_like(attr[j])

        dist_grids_[j] = dist_grid
    # Convert the list of distance grids to a tensor for faster computation
    dist_grids = normalize_batch(torch.stack(dist_grids_)) ** focus_coeff
    if torch.isnan(dist_grids).any():
        # Normalization divides by (max - min); a constant grid yields NaNs.
        dist_grids = torch.nan_to_num(dist_grids, nan=0.0)

    if debug:
        for i in range(len(dist_grids)):
            if ((i % 8) == 0):  # visualize only every 8th image in the batch
                grid_show = torch.cat([dist_grids[i][:1], dist_grids[i][:1], dist_grids[i][:1]], dim=0)
                imshow(grid_show, save_path='figs/dist_grids')
                if imgs is None:
                    imgs = torch.zeros_like(attr)
                imshow(imgs[i], save_path='figs/im0')
                img_overlay = (overlay_mask(imgs[i], dist_grids[i][0], alpha = 0.75))
                imshow(img_overlay, save_path='figs/dist_grid_overlay')
                weighted_attr = (dist_grids[i] * attr[i])
                imshow(weighted_attr, save_path='figs/weighted_attr')
                imshow(attr[i], save_path='figs/attr')

    return dist_grids
491
+
492
def attr_reg(attribution_map, distance_map):
    """
    Mean of the attribution map weighted elementwise by a distance map.

    Args:
        attribution_map (torch.Tensor): Attribution values.
        distance_map (torch.Tensor): Per-pixel weights (same shape).

    Returns:
        torch.Tensor: Scalar mean of the weighted map.
    """
    return (distance_map * attribution_map).mean()
498
+
499
def get_bbox_map(targets_out, attr, corners=False):
    """
    Build a float mask the shape of `attr` that is 1 inside every target
    bounding box (in that target's batch element) and 0 elsewhere.

    Args:
        targets_out (torch.Tensor): Target rows (batch_index, class, x, y, w, h),
            coordinates normalized to [0, 1] unless `corners` is True.
        attr (torch.Tensor): Attribution maps, shape (batch, C, H, W); only the
            shape and device are used.
        corners (bool, optional): If True, targets_out[:, 2:6] are already pixel
            corner coordinates. Defaults to False.

    Returns:
        torch.Tensor: Float32 mask with the same shape as `attr`.
    """
    target_inds = targets_out[:, 0].int()  # batch index of each target row
    xyxy_batch = targets_out[:, 2:6]# * pre_gen_gains[out_num]
    # Per-row scale factors from normalized coords to pixels.
    num_pixels = torch.tile(torch.tensor([attr.shape[2], attr.shape[3], attr.shape[2], attr.shape[3]], device=attr.device), (xyxy_batch.shape[0], 1))
    # num_pixels = torch.tile(torch.tensor([1.0, 1.0, 1.0, 1.0], device=imgs.device), (xyxy_batch.shape[0], 1))
    xyxy_corners = (corners_coords_batch(xyxy_batch) * num_pixels).int()
    co = xyxy_corners
    if corners:
        co = targets_out[:, 2:6].int()
    coords_map = torch.zeros_like(attr, dtype=torch.bool)
    # rows = np.arange(co.shape[0])
    # NOTE(review): columns 1/3 index the first spatial axis and 0/2 the second —
    # confirm this row/column convention matches the producer of targets_out.
    x1, x2 = co[:,1], co[:,3]
    y1, y2 = co[:,0], co[:,2]

    for ic in range(co.shape[0]): # potential for speedup here with torch indexing instead of for loop
        coords_map[target_inds[ic], :,x1[ic]:x2[ic],y1[ic]:y2[ic]] = True

    bbox_map = coords_map.to(torch.float32)

    return bbox_map
519
+ ######################################## BCE #######################################
520
######################################## BCE #######################################
def get_plaus_loss(targets, attribution_map, opt, imgs=None, debug=False, only_loss=False):
    """
    Compute the plausibility-guided training (PGT) loss and related metrics.

    Combines a plausibility score (attribution mass inside target boxes) with a
    distance-based regularization term and an optional bbox regularization term.

    Args:
        targets (torch.Tensor): Target rows (batch_index, class, x, y, w, h), normalized.
        attribution_map (torch.Tensor): Attribution maps, shape (batch, C, H, W).
        opt: Options namespace; reads focus_coeff, dist_x_bbox, bbox_coeff,
            dist_reg_only, iou_coeff, dist_coeff and pgt_coeff.
        imgs (torch.Tensor, optional): Images, passed through for visualization.
        debug (bool, optional): Also return the distance map. Defaults to False.
        only_loss (bool, optional): Return only the scalar loss. Defaults to False.

    Returns:
        torch.Tensor | tuple: plaus_loss alone when only_loss; otherwise
        (plaus_loss, (plaus_score, dist_reg, plaus_reg)), with distance_map
        appended when debug is True.
    """
    # if imgs is None:
    #     imgs = torch.zeros_like(attribution_map)
    # Calculate Plausibility IoU with attribution maps
    # attribution_map.retains_grad = True
    # NOTE(review): this plaus_score is recomputed (and overwritten) further down.
    if not only_loss:
        plaus_score = get_plaus_score(targets_out = targets, attr = attribution_map.clone().detach().requires_grad_(True), imgs = imgs)
    else:
        plaus_score = torch.tensor(0.0)

    # attribution_map = normalize_batch(attribution_map) # Normalize attribution maps per image in batch

    # Calculate distance regularization
    distance_map = get_distance_grids(attribution_map, targets, imgs, opt.focus_coeff)
    # distance_map = torch.ones_like(attribution_map)

    if opt.dist_x_bbox:
        # Zero out the distance penalty inside the target boxes.
        bbox_map = get_bbox_map(targets, attribution_map).to(torch.bool)
        distance_map[bbox_map] = 0.0
        # distance_map = distance_map * (1 - bbox_map)

    # Positive regularization term for incentivizing pixels near the target to have high attribution
    dist_attr_pos = attr_reg(attribution_map, (1.0 - distance_map))
    # Negative regularization term for incentivizing pixels far from the target to have low attribution
    dist_attr_neg = attr_reg(attribution_map, distance_map)
    # Calculate plausibility regularization term
    # dist_reg = dist_attr_pos - dist_attr_neg
    dist_reg = ((dist_attr_pos / torch.mean(attribution_map)) - (dist_attr_neg / torch.mean(attribution_map)))
    # dist_reg = torch.mean((dist_attr_pos / torch.mean(attribution_map, dim=(1, 2, 3))) - (dist_attr_neg / torch.mean(attribution_map, dim=(1, 2, 3))))
    # dist_reg = (torch.mean(torch.exp((dist_attr_pos / torch.mean(attribution_map, dim=(1, 2, 3)))) + \
    #            torch.exp(1 - (dist_attr_neg / torch.mean(attribution_map, dim=(1, 2, 3)))))) \
    #            / 2.5

    if opt.bbox_coeff != 0.0:
        # Reward attribution inside boxes, penalize attribution outside them.
        bbox_map = get_bbox_map(targets, attribution_map)
        attr_bbox_pos = attr_reg(attribution_map, bbox_map)
        attr_bbox_neg = attr_reg(attribution_map, (1.0 - bbox_map))
        bbox_reg = attr_bbox_pos - attr_bbox_neg
        # bbox_reg = (attr_bbox_pos / torch.mean(attribution_map)) - (attr_bbox_neg / torch.mean(attribution_map))
    else:
        bbox_reg = 0.0

    # Recompute the plausibility score directly (overrides the detached one above).
    bbox_map = get_bbox_map(targets, attribution_map)
    plaus_score = ((torch.sum((attribution_map * bbox_map))) / (torch.sum(attribution_map)))
    # iou_loss = (1.0 - plaus_score)

    if not opt.dist_reg_only:
        # Map dist_reg from [-1, 1] into [0, 1] before weighting.
        dist_reg_loss = (((1.0 + dist_reg) / 2.0))
        plaus_reg = (plaus_score * opt.iou_coeff) + \
                    (((dist_reg_loss * opt.dist_coeff) + \
                    (bbox_reg * opt.bbox_coeff))\
                    # ((((((1.0 + dist_reg) / 2.0) - 1.0) * opt.dist_coeff) + ((((1.0 + bbox_reg) / 2.0) - 1.0) * opt.bbox_coeff))\
                    # / (plaus_score) \
                    )
    else:
        plaus_reg = (((1.0 + dist_reg) / 2.0))
        # plaus_reg = dist_reg
    # Calculate plausibility loss
    plaus_loss = (1 - plaus_reg) * opt.pgt_coeff
    # plaus_loss = (plaus_reg) * opt.pgt_coeff
    if only_loss:
        return plaus_loss
    if not debug:
        return plaus_loss, (plaus_score, dist_reg, plaus_reg,)
    else:
        return plaus_loss, (plaus_score, dist_reg, plaus_reg,), distance_map
586
+
587
+ ####################################################################################
588
 + #### ALL FUNCTIONS BELOW ARE DEPRECATED AND WILL BE REMOVED IN FUTURE VERSIONS ####
589
+ ####################################################################################
590
+
591
def generate_vanilla_grad(model, input_tensor, loss_func = None,
                          targets_list=None, targets=None, metric=None, out_num = 1,
                          n_max_labels=3, norm=True, abs=True, grayscale=True,
                          class_specific_attr = True, device='cpu'):
    """
    Generate vanilla gradient attribution maps for the given model and input tensor.

    .. deprecated:: This function is deprecated and will be removed in a future
       version (see the section header in this file).

    Args:
        model (nn.Module): The model to generate gradients for.
        input_tensor (torch.Tensor): The input tensor for which gradients are computed.
        loss_func (callable, optional): The loss function to compute gradients with respect to.
            When None, gradients are taken w.r.t. the raw model output instead. Defaults to None.
        targets_list (list, optional): Per-image list of target tensors. Defaults to None.
        targets (torch.Tensor, optional): Batched targets used when class_specific_attr is False.
            Defaults to None.
        metric (callable, optional): The metric function passed through to loss_func. Defaults to None.
        out_num (int, optional): Index of the model output tensor to differentiate. Defaults to 1.
        n_max_labels (int, optional): The maximum number of labels to consider per image. Defaults to 3.
        norm (bool, optional): Whether to normalize the attribution map per image. Defaults to True.
        abs (bool, optional): Whether to take absolute values of gradients. Defaults to True.
        grayscale (bool, optional): Whether to collapse channels (saves VRAM). Defaults to True.
        class_specific_attr (bool, optional): Whether to compute one attribution map per label.
            Defaults to True.
        device (str, optional): The device to use for computation. Defaults to 'cpu'.

    Returns:
        list[torch.Tensor]: One stacked attribution tensor per processed batch image.
    """
    # Run in train mode so the model produces training-style outputs; restore
    # the original mode before returning.
    train_mode = model.training
    if not train_mode:
        model.train()

    input_tensor.requires_grad = True  # needed to compute d(out)/d(input)
    model.zero_grad()
    inpt = input_tensor
    # Forward pass: training outputs only (no inference outputs in train mode).
    train_out = model(inpt)

    # Observed layout (see original notes):
    # train_out[1] = [B, 3, 80, 80, 7]  HxWx(#anchor x C)   cls (class probabilities)
    # train_out[0] = [B, 3, 160, 160, 7] HxWx(#anchor x 4)  box/reg (location and scaling)
    # train_out[2] = [B, 3, 40, 40, 7]  HxWx(#anchor x 1)   obj (objectness / confidence)

    if class_specific_attr:
        # For each image collect, per label, the output channels to attribute:
        # the 5 box/objectness channels plus that label's class channel.
        n_attr_list, index_classes = [], []
        for i in range(len(input_tensor)):
            if len(targets_list[i]) > n_max_labels:
                targets_list[i] = targets_list[i][:n_max_labels]
            if targets_list[i].numel() != 0:
                class_numbers = targets_list[i][:,1]
                index_classes.append([[0, 1, 2, 3, 4, int(uc)] for uc in class_numbers])
                n_attr_list.append(len(targets_list[i]))
            else:
                index_classes.append([0, 1, 2, 3, 4])
                n_attr_list.append(0)

        # Pad every image's target list to the same length by repeating its own
        # labels, so the lists can be stacked into one tensor.
        targets_list_filled = [targ.clone().detach() for targ in targets_list]
        labels_len = [len(targets_list[ih]) for ih in range(len(targets_list))]
        max_labels = np.max(labels_len)
        max_index = np.argmax(labels_len)
        for i in range(len(targets_list)):
            if len(targets_list_filled[i]) < max_labels:
                tlist = [targets_list_filled[i]] * math.ceil(max_labels / len(targets_list_filled[i]))
                targets_list_filled[i] = torch.cat(tlist)[:max_labels].unsqueeze(0)
            else:
                targets_list_filled[i] = targets_list_filled[i].unsqueeze(0)
        # Drop images with no labels (iterate backwards so pops don't shift).
        for i in range(len(targets_list_filled)-1,-1,-1):
            if targets_list_filled[i].numel() == 0:
                targets_list_filled.pop(i)
        targets_list_filled = torch.cat(targets_list_filled)

    # With a loss function a single backward pass covers the whole batch.
    n_img_attrs = len(input_tensor) if class_specific_attr else 1
    n_img_attrs = 1 if loss_func else n_img_attrs

    attrs_batch = []
    for i_batch in range(n_img_attrs):
        if loss_func and class_specific_attr:
            # Use the image with the most labels so every label column exists.
            i_batch = max_index
        n_label_attrs = n_attr_list[i_batch] if class_specific_attr else 1
        n_label_attrs = 1 if not class_specific_attr else n_label_attrs
        attrs_img = []
        for i_attr in range(n_label_attrs):
            if loss_func is None:
                # Differentiate the raw output (optionally restricted to the
                # channels relevant to this label).
                grad_wrt = train_out[out_num]
                if class_specific_attr:
                    grad_wrt = train_out[out_num][:,:,:,:,index_classes[i_batch][i_attr]]
                grad_wrt_outputs = torch.ones_like(grad_wrt)
            else:
                if class_specific_attr:
                    target_indiv = targets_list_filled[:,i_attr]  # batch image input
                else:
                    target_indiv = targets

                try:
                    loss, loss_items = loss_func(train_out, target_indiv, inpt, metric=metric)  # loss scaled by batch_size
                except Exception:
                    # Retry with everything explicitly on the requested device.
                    # NOTE: the original rebound the loop variable
                    # (`for tro in train_out: tro = tro.to(device)`), which
                    # moved nothing; rebuild the list instead.
                    target_indiv = target_indiv.to(device)
                    inpt = inpt.to(device)
                    train_out = [tro.to(device) for tro in train_out]
                    print("Error in loss function, trying again with device specified")
                    loss, loss_items = loss_func(train_out, target_indiv, inpt, metric=metric)
                grad_wrt = loss
                grad_wrt_outputs = None

            model.zero_grad()  # Zero gradients
            gradients = torch.autograd.grad(grad_wrt, inpt,
                                            grad_outputs=grad_wrt_outputs,
                                            retain_graph=True,
                                            # create_graph=True, # allows higher order derivatives but slows computation significantly
                                            )

            attribution_map = gradients[0]  # kept on-graph; not detached to numpy

            if grayscale:  # collapse channels: saves VRAM and compute for plaus_eval
                attribution_map = torch.sum(attribution_map, 1, keepdim=True)
            if abs:
                attribution_map = torch.abs(attribution_map)  # absolute gradient values
            if norm:
                attribution_map = normalize_batch(attribution_map)  # per-image normalization
            attrs_img.append(attribution_map)
        if len(attrs_img) == 0:
            # No labels for this image: emit an all-zero placeholder map.
            attrs_batch.append((torch.zeros_like(inpt).unsqueeze(0)).to(device))
        else:
            attrs_batch.append(torch.stack(attrs_img).to(device))

    out_attr = attrs_batch
    # Restore the model's original train/eval mode.
    if not train_mode:
        model.eval()

    return out_attr
738
+
739
class RVNonLinearFunc(nn.Module):
    """Propagate a random variable (mean, covariance) through an elementwise
    activation.

    First-order (delta-method) approximation: the output covariance is the
    input covariance scaled elementwise by the outer product of the
    activation's pointwise derivative, i.e. ``Sigma_out = Sigma_in * (g g^T)``
    where ``g`` is the gradient of the activation at the input mean.
    """

    def __init__(self, func):
        super().__init__()
        # Deterministic activation applied to the mean (e.g. torch.relu).
        self.func = func

    def forward(self, mu_in, Sigma_in):
        """Apply the wrapped activation to a (mean, covariance) pair.

        Args:
            mu_in (torch.Tensor): (batch_size, input_size) mean input; must
                participate in autograd so the derivative can be taken.
            Sigma_in (torch.Tensor): (batch_size, input_size, input_size)
                covariance input.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: output mean and output
            covariance, shaped like the inputs.
        """
        n = mu_in.size(0)

        # Mean passes straight through the activation.
        mu_out = self.func(mu_in)

        # Pointwise derivative of the activation at mu_in, flattened per sample.
        jac = torch.autograd.grad(
            mu_out, mu_in,
            grad_outputs=torch.ones_like(mu_out),
            create_graph=True,
        )[0].view(n, -1)

        # Per-sample outer product g g^T via batched matmul of a column and a
        # row view, then elementwise scaling of the input covariance.
        col = jac.unsqueeze(dim=2)
        row = jac.unsqueeze(dim=1)
        Sigma_out = Sigma_in * torch.bmm(col, row)

        return mu_out, Sigma_out
plot_functs.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import matplotlib.pyplot as plt
3
+ import torch
4
+ import torch.nn as nn
5
+
6
class Subplots:
    """Accumulate rows of images into one shared matplotlib figure and save it."""

    def __init__(self, figsize = (40, 5)):
        # One figure reused across plot_img_list calls (one call per row).
        self.fig = plt.figure(figsize=figsize)

    def plot_img_list(self, img_list, savedir='figs/test',
                      nrows = 1, rownum = 0,
                      hold = False, coltitles=[], rowtitle=''):
        """Draw img_list as row `rownum` of an `nrows`-tall subplot grid.

        Args:
            img_list: iterable of CHW images (torch tensors or numpy arrays).
            savedir: output path without extension; '.png' is appended on save.
            nrows: total number of rows the finished figure will hold.
            rownum: zero-based index of the row being drawn by this call.
            hold: when True, keep the figure open for more rows instead of saving.
            coltitles: per-column titles; shorter lists leave columns untitled.
            rowtitle: label annotated left of the first image of the row.
        """
        for i, img in enumerate(img_list):
            # Tensors are detached to CPU numpy; anything else is used as-is.
            # NOTE(review): bare except also hides unrelated errors — assumes
            # inputs are only tensors or ndarrays; confirm at call sites.
            try:
                npimg = img.clone().detach().cpu().numpy()
            except:
                npimg = img
            # CHW -> HWC for imshow.
            tpimg = np.transpose(npimg, (1, 2, 0))
            lenrow = int((len(img_list)))
            # Subplot indices are 1-based, laid out row-major.
            ax = self.fig.add_subplot(nrows, lenrow, i+1+(rownum*lenrow))
            if len(coltitles) > i:
                ax.set_title(coltitles[i])
            if i == 0:
                # Place the row label left of the first axis; the x offset is
                # scaled by the label length to roughly center it.
                ax.annotate(rowtitle, xy=((-0.06 * len(rowtitle)), 0.4),# xytext=(-ax.yaxis.labelpad - pad, 0),
                            xycoords='axes fraction', textcoords='offset points',
                            size='large', ha='center', va='baseline')
                # ax.set_ylabel(rowtitle, rotation=90)
            ax.imshow(tpimg)
            ax.axis('off')

        if not hold:
            # Finalize: write '<savedir>.png' and release all open figures.
            self.fig.tight_layout()
            plt.savefig(f'{savedir}.png')
            plt.clf()
            plt.close('all')
+ plt.close('all')
37
+
38
+
39
def VisualizeNumpyImageGrayscale(image_3d):
    r"""Min-max rescale a numpy image so its values span exactly [0, 1]."""
    shifted = image_3d - np.min(image_3d)
    return shifted / np.max(shifted)
46
+
47
def normalize_numpy(image_3d):
    r"""Normalize a numpy array into the [0, 1] range (min-max scaling)."""
    lo = np.min(image_3d)
    span = image_3d - lo
    hi = np.max(span)
    return span / hi
54
+
55
+ # def normalize_tensor(image_3d):
56
+ # r"""Returns a 3D tensor as a grayscale normalized between 0 and 1 2D tensor.
57
+ # """
58
+ # vmin = torch.min(image_3d)
59
+ # image_2d = image_3d - vmin
60
+ # vmax = torch.max(image_2d)
61
+ # return (image_2d / vmax)
62
+
63
def normalize_tensor(image_3d):
    r"""Normalize a torch tensor into the [0, 1] range (min-max scaling)."""
    shifted = image_3d - torch.min(image_3d)
    return shifted / torch.max(shifted)
68
+
69
def format_img(img_):
    """Convert a CHW torch tensor into an HWC numpy array for plotting."""
    return np.transpose(img_.numpy(), (1, 2, 0))
73
+
74
def imshow(img, save_path=None):
    """Display a CHW image with matplotlib and optionally save it to disk.

    Args:
        img: image in (C, H, W) layout; torch tensors are detached to numpy,
            plain arrays pass through unchanged.
        save_path: when given, the figure is written to '<save_path>.png'.
    """
    # Tensors are detached to CPU numpy; anything else is used as-is.
    try:
        npimg = img.clone().detach().cpu().numpy()
    except:
        npimg = img
    # CHW -> HWC, the layout matplotlib expects.
    tpimg = np.transpose(npimg, (1, 2, 0))
    plt.imshow(tpimg)
    # plt.axis('off')
    plt.tight_layout()
    if save_path != None:
        plt.savefig(str(str(save_path) + ".png"))
    # plt.show()
86
+
87
def imshow_img(img, imsave_path):
    """Normalize an image to [0, 1] and save it via imshow().

    Works for torch tensors and numpy arrays alike.
    """
    # Prefer the tensor path; fall back to treating img as a numpy array.
    try:
        scaled = VisualizeNumpyImageGrayscale(img.numpy())
    except:
        scaled = VisualizeNumpyImageGrayscale(img)
    # HWC -> CHW; imshow() transposes back before rendering.
    chw = np.transpose(scaled, (2, 0, 1))
    imshow(chw, save_path=imsave_path)
    print("Saving image as ", imsave_path)
96
+
97
def returnGrad(img, labels, model, compute_loss, loss_metric, augment=None, device = 'cpu'):
    """Return the gradient of the detection loss with respect to the input image.

    Args:
        img (torch.Tensor): input image batch the gradient is taken w.r.t.
        labels (torch.Tensor): ground-truth targets forwarded to `compute_loss`.
        model (nn.Module): detector; put in train mode so training-style
            outputs expected by `compute_loss` are produced, eval-restored at end.
        compute_loss (callable): loss function `(pred, labels, metric=...)`
            returning `(loss, loss_items)`.
        loss_metric: metric object forwarded to `compute_loss`.
        augment: unused; kept for interface compatibility with callers.
        device (str | torch.device, optional): computation device. Defaults to 'cpu'.

    Returns:
        torch.Tensor: float32 copy of d(loss)/d(img), detached from the graph.
    """
    # Accept both strings and torch.device objects; the original assumed a
    # torch.device and crashed on the default 'cpu' string at `device.type`.
    device = torch.device(device)
    model.train()
    model.to(device)
    img = img.to(device)
    img.requires_grad_(True)
    # The original discarded the result of labels.to(device); keep it.
    labels = labels.to(device)
    model.requires_grad_(True)
    cuda = device.type != 'cpu'
    # `amp` was referenced without being imported (NameError); use the torch
    # namespace. Scaling is a no-op when CUDA is disabled.
    scaler = torch.cuda.amp.GradScaler(enabled=cuda)
    pred = model(img)
    # out, train_out = model(img, augment=augment) # inference and training outputs
    loss, loss_items = compute_loss(pred, labels, metric=loss_metric)  # box, obj, cls

    with torch.autograd.set_detect_anomaly(True):
        # Backpropagate only into the image tensor.
        scaler.scale(loss).backward(inputs=img)

    Sc_dx = img.grad
    model.eval()
    # Detach into a fresh float32 tensor; torch.tensor(tensor) copies with a warning.
    Sc_dx = Sc_dx.clone().detach().to(dtype=torch.float32)
    return Sc_dx
121
+
122
def calculate_snr(img, attr, dB=True):
    """Signal-to-noise ratio between an image and its attribution map.

    `img` supplies the signal power and `attr` the noise power, each computed
    as the mean of squared values. Accepts torch tensors or numpy arrays.

    Args:
        img: signal image (tensor or ndarray).
        attr: attribution / noise map (tensor or ndarray).
        dB: when True, return the ratio in decibels.

    Returns:
        float: the SNR, in dB when requested, otherwise as a raw power ratio.
    """
    # Tensors are detached to numpy; plain arrays are used directly.
    try:
        sig = img.detach().cpu().numpy()
        noi = attr.detach().cpu().numpy()
    except:
        sig = img
        noi = attr

    # Mean squared value = power.
    signal_power = np.mean(sig ** 2)
    noise_power = np.mean(noi ** 2)

    if dB == True:
        # 10 log10 of the power ratio.
        return 10 * np.log10(signal_power / noise_power)
    return signal_power / noise_power
144
+
145
def overlay_mask(img, mask, colormap: str = "jet", alpha: float = 0.7):
    """Blend a colormapped attribution mask over an image.

    Args:
        img: image tensor in (C, H, W) layout.
        mask: mask tensor whose leading dimension of size 1 is squeezed away
            before colormapping.
        colormap: matplotlib colormap name used to colorize the mask.
        alpha: blend weight of the image; the mask contributes (1 - alpha).

    Returns:
        torch.Tensor: blended image, placed on the same device as `img`.
    """
    cmap = plt.get_cmap(colormap)
    npmask = np.array(mask.clone().detach().cpu().squeeze(0))
    # Colormap the mask, drop its alpha channel, and go HWC -> CHW.
    # cmpmask = ((255 * cmap(npmask)[:, :, :3]).astype(np.uint8)).transpose((2, 0, 1))
    cmpmask = (cmap(npmask)[:, :, :3]).transpose((2, 0, 1))
    # Convex combination of image and colorized mask.
    overlayed_imgnp = ((alpha * (np.asarray(img.clone().detach().cpu())) + (1 - alpha) * cmpmask))
    overlayed_tensor = torch.tensor(overlayed_imgnp, device=img.device)

    return overlayed_tensor
toy_problem_pgt.py ADDED
@@ -0,0 +1,247 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from plaus_functs import get_center_coords, get_distance_grids, get_plaus_loss, get_bbox_map, normalize_batch
3
+ from plot_functs import imshow
4
+ from torchvision.transforms.functional import gaussian_blur
5
+ import argparse
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
+ import os
9
+ import cv2
10
+
11
def subfigimshow(img, ax):
    """Render a tensor or ndarray image on a matplotlib axis.

    Handles 2D grayscale, (C, H, W) with C in {1, 3}, and (H, W, C) layouts;
    single-channel images are drawn with the gray colormap.

    Raises:
        ValueError: if the image is neither 2D nor 3D.
    """
    print(f'img shape: {img.shape}')
    # Tensors are detached to CPU numpy; anything else is used as-is.
    try:
        arr = img.clone().detach().cpu().numpy()
    except:
        arr = img

    ndim = len(arr.shape)
    if ndim == 2:
        # Plain 2D array: grayscale image.
        ax.imshow(arr, cmap='gray')
        return
    if ndim != 3:
        raise ValueError(f"Unexpected image shape: {arr.shape}")

    # A leading dim of 3 or 1 is taken to be channels-first; move it last.
    hwc = np.transpose(arr, (1, 2, 0)) if arr.shape[0] in (3, 1) else arr
    if hwc.shape[2] == 1:
        # Single channel: squeeze to 2D and draw as grayscale.
        ax.imshow(np.squeeze(hwc), cmap='gray')
    else:
        ax.imshow(hwc)
35
+
36
def draw_bounding_boxes(image, boxes, color=(0, 255, 0), thickness=2):
    """Draw normalized (x_center, y_center, width, height) boxes on a copy of `image`.

    Grayscale inputs are promoted to 3-channel RGB and non-uint8 images are
    converted to uint8 in [0, 255] before drawing with OpenCV.
    """
    # Promote grayscale layouts to 3-channel RGB.
    if len(image.shape) == 2:
        image = np.stack([image] * 3, axis=-1)
    elif len(image.shape) == 3 and image.shape[2] == 1:
        image = np.repeat(image, 3, axis=2)

    # cv2 expects uint8 pixels; assume float inputs are scaled to [0, 1].
    if image.dtype != np.uint8:
        image = (image * 255).clip(0, 255).astype(np.uint8)

    canvas = image.copy()
    h, w = canvas.shape[0], canvas.shape[1]
    for cx, cy, bw, bh in boxes:
        # Convert normalized center/size to absolute pixel corners.
        left = int((cx - bw / 2) * w)
        top = int((cy - bh / 2) * h)
        right = int((cx + bw / 2) * w)
        bottom = int((cy + bh / 2) * h)
        cv2.rectangle(canvas, (left, top), (right, bottom), color, thickness)

    return canvas
57
+
58
def toy_problem(pgt_coeff, focus_coeff, x_coord, y_coord, num_bb=0, alpha=200.0, scheduler=2.0, device="0", dist_coeff=0.5, dist_reg_only=True, iou_coeff=0.5,
                bbox_coeff=0.0, dist_x_bbox=False, iou_loss_only=False, show_dist_reg=True):
    """Run the PGT toy problem and save four diagnostic figures under 'figs/'.

    Starting from a random blurred attribution map, performs 10 gradient
    descent steps on the plausibility loss w.r.t. the map, rendering the
    distance regularization map, the per-step attribution maps, and the
    loss/score curves.

    Args:
        pgt_coeff (float): weight of the PGT loss term.
        focus_coeff (float): concentration of the distance-based reward.
        x_coord (float): normalized x coordinate of the target box center.
        y_coord (float): normalized y coordinate of the target box center.
        num_bb (int): number of bounding boxes; currently unused here — the
            target list below is hard-coded (see TODO).
        alpha (float): initial step size applied to the attribution gradient.
        scheduler (float): multiplier applied to alpha after every step.
        device (str): CUDA device index, exported via CUDA_VISIBLE_DEVICES.
        dist_coeff, dist_reg_only, iou_coeff, bbox_coeff, dist_x_bbox,
        iou_loss_only, show_dist_reg: plausibility-loss options forwarded to
            get_plaus_loss through the opt namespace.

    Returns:
        tuple[str, str, str, str]: paths of the four saved figure PNGs.
    """
    # Bundle all parameters into an argparse.Namespace so get_plaus_loss can
    # read them like parsed command-line options.
    opt = argparse.Namespace()
    opt.pgt_coeff = pgt_coeff
    opt.focus_coeff = focus_coeff
    opt.x_coord = x_coord
    opt.y_coord = y_coord
    opt.num_bb = num_bb
    opt.alpha = alpha
    opt.scheduler = scheduler
    opt.device = device
    opt.dist_coeff = dist_coeff
    opt.dist_reg_only = dist_reg_only
    opt.iou_coeff = iou_coeff
    opt.bbox_coeff = bbox_coeff
    opt.dist_x_bbox = dist_x_bbox
    opt.iou_loss_only = iou_loss_only
    opt.show_dist_reg = show_dist_reg

    # Create a list of save dirs for output
    save_dirs = []

    # Set CUDA device
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = str(int(opt.device))

    #TODO - Adjust this for the number of bounding boxes
    # Targets are [batch_idx, class, x_center, y_center, w, h] (normalized).
    targets = torch.tensor([
        [0, 0, opt.x_coord, opt.y_coord, 0.05, 0.05],
        # [0, 1, 0.4, 0.6, 0.05, 0.07],
        # [1, 0, 0.25, 0.2, 0.04, 0.05],
        # [2, 0, 0.8, 0.76, 0.05, 0.05],
        # [2, 0, 0.8, 0.2, 0.05, 0.05],
        # [0, 0, 0.8, 0.76, 0.05, 0.05],
        # [1, 0, 0.8, 0.2, 0.05, 0.05],
    ])

    unique_classes = torch.unique(targets[:,0])
    # Random smooth starting attribution map, one per batch index in targets.
    # X = (gaussian_blur(torch.rand(len(unique_classes), 1, 50, 50)**2, 3)**4)
    attr = (gaussian_blur(torch.rand(len(unique_classes), 1, 640, 640)**2, 13)**4).requires_grad_(True)
    # Initial loss evaluation (value only; superseded inside the loop).
    plaus_loss = get_plaus_loss(targets, attribution_map=attr,
                                opt=opt,
                                debug=True,
                                only_loss=True)
    if opt.iou_loss_only:
        # Pure IoU-style loss: fraction of attribution mass inside the boxes.
        bbox_map = get_bbox_map(targets, attr)
        plaus_score = ((torch.sum((attr * bbox_map))) / (torch.sum(attr)))
        plaus_loss = (1.0 - plaus_score)

    # Plot params (adjust as necessary)
    nsamples = 10
    rows = len(attr)  # Number of images
    cols = nsamples + 2  # Number of columns for subplots
    size = 3

    # Figure 1: distance regularization map plus the first attribution step.
    fig1 = plt.figure(figsize=(cols * size, rows * size))
    plt.tight_layout()

    # Figure 2: the remaining attribution steps.
    fig2 = plt.figure(figsize=(cols * size, rows * size))
    plt.tight_layout()

    # Figure 3: plausibility losses over steps.
    fig3, ax3 = plt.subplots(figsize=(10, 6))
    plaus_losses = []

    # Figure 4: plausibility scores over steps.
    fig4, ax4 = plt.subplots(figsize=(10, 6))
    plaus_scores = []

    for i in range(10):
        plaus_loss, (plaus_score, dist_reg, plaus_reg,), distance_map = get_plaus_loss(targets.requires_grad_(True), attribution_map=attr, opt=opt, debug=True)

        # One gradient-descent step on the attribution map.
        delta_attr = torch.autograd.grad(plaus_loss, attr, create_graph=True, retain_graph=True)[0]
        attr = attr - (delta_attr * alpha)
        alpha *= opt.scheduler  # step-size schedule

        # Re-evaluate metrics at the updated map for logging/plotting.
        plaus_loss, (plaus_score, dist_reg, plaus_reg,), distance_map = get_plaus_loss(targets, attribution_map=attr, opt=opt, debug=True)
        if opt.iou_loss_only:
            bbox_map = get_bbox_map(targets, attr)
            plaus_score = ((torch.sum((attr * bbox_map))) / (torch.sum(attr)))
            plaus_loss = (1.0 - plaus_score)
            distance_map = bbox_map

        # attr = attr.clamp(0, 1)
        attr = normalize_batch(attr)
        plaus_losses.append(float(plaus_loss))
        plaus_scores.append(float(plaus_score))
        print(f'step: {i}, plaus_loss: {plaus_loss}, plaus_score: {plaus_score}, dist_reg: {dist_reg}, plaus_reg: {plaus_reg}')

        for j in range(len(attr)):

            # Add a subplot for each image
            if i == 0 and opt.show_dist_reg:
                ax = fig1.add_subplot(rows, cols, 1 + (j * cols))
                ax.set_title(f'Distance Regularization Map {j}')
                img_tensor = (1 - distance_map[j]).detach().cpu()
                img_np = img_tensor.detach().cpu().numpy().squeeze()
                img_colored = plt.cm.viridis(img_np)
                bbox_coords = targets[:, 2:6].detach().cpu().numpy()  # [x_coord, y_coord, width, height] (all bb for now)
                img_with_boxes = draw_bounding_boxes(img_colored, bbox_coords)
                subfigimshow(img_with_boxes, ax)
                ax.axis('off')

            else:
                # NOTE(review): with show_dist_reg=False the i==0 pass falls
                # through to the fig2 branch with subplot index 0, which
                # matplotlib rejects — confirm intended usage.
                if i == 1:
                    # First attribution step goes next to the distance map in fig1.
                    ax = fig1.add_subplot(rows, cols, 2 + (j * cols))
                    ax.set_title(f'Attr Step {i}' if j == 0 else '')
                    img_tensor = attr[j].detach().cpu()
                    img_np = img_tensor.detach().cpu().numpy().squeeze()
                    img_colored = plt.cm.viridis(img_np)
                    bbox_coords = targets[:, 2:6].detach().cpu().numpy()  # [x_coord, y_coord, width, height] (all bb for now)
                    img_with_boxes = draw_bounding_boxes(img_colored, bbox_coords)
                    subfigimshow(img_with_boxes, ax)
                    ax.axis('off')
                else:
                    # Subsequent steps go to fig2.
                    ax = fig2.add_subplot(rows, cols, 1 + (i - 1) + (j * cols))
                    ax.set_title(f'Attr Step {i}' if j == 0 else '')
                    img_tensor = attr[j].detach().cpu()
                    img_np = img_tensor.detach().cpu().numpy().squeeze()
                    img_colored = plt.cm.viridis(img_np)
                    subfigimshow(img_colored, ax)
                    ax.axis('off')

    # Plot plausibility losses
    ax3.plot(range(nsamples), plaus_losses, marker='o', label='Plausibility Loss')
    ax3.set_title('Plausibility Losses Across Steps')
    ax3.set_xlabel('Step')
    ax3.set_ylabel('Plausibility Loss')
    ax3.grid(True)
    ax3.legend()

    # Plot plausibility scores
    ax4.plot(range(nsamples), plaus_scores, marker='o', label='Plausibility Scores')
    ax4.set_title('Plausibility Scores Across Steps')
    ax4.set_xlabel('Step')
    ax4.set_ylabel('Plausibility Score')
    ax4.grid(True)
    ax4.legend()

    # Save the figures
    fig1.savefig('figs/distance_and_first_step.png', bbox_inches='tight')
    plt.close(fig1)

    fig2.savefig('figs/remaining_attr_steps.png', bbox_inches='tight')
    plt.close(fig2)

    fig3.savefig('figs/plausibility_losses.png', bbox_inches='tight')
    plt.close(fig3)

    fig4.savefig('figs/plausibility_scores.png', bbox_inches='tight')
    # Fixed: was plt.close(fig3), which closed fig3 twice and leaked fig4.
    plt.close(fig4)

    print('Figures saved: figs/distance_and_first_step.png, figs/remaining_attr_steps.png, and figs/plausibility_losses.png, figs/plausibility_scores.png')
    return 'figs/distance_and_first_step.png', 'figs/remaining_attr_steps.png', 'figs/plausibility_losses.png', 'figs/plausibility_scores.png'
220
+
221
if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    # ##################### Standard Settings #####################
    parser.add_argument('--pgt_coeff', type=float, default=1.0, help='pgt_coeff')
    parser.add_argument('--focus_coeff', type=float, default=0.2, help='focus_coeff')
    parser.add_argument('--alpha', type=float, default=400.0, help='alpha')
    parser.add_argument('--num_bb', type=int, default=0, help='num_bb')
    parser.add_argument('--x_coord', type=float, default=0.2, help='x_coord')
    parser.add_argument('--y_coord', type=float, default=0.35, help='y_coord')
    ########################## Advanced #########################
    parser.add_argument('--scheduler', type=float, default=2.0, help='scheduler for alpha')
    #############################################################
    parser.add_argument('--device', type=str, default='0', help='device')
    parser.add_argument('--dist_coeff', type=float, default=0.5, help='dist_coeff')
    # NOTE: argparse's type=bool treats any non-empty string (even 'False')
    # as True; these flags only behave as expected when left at their defaults.
    parser.add_argument('--dist_reg_only', type=bool, default=True, help='dist_reg_only')
    parser.add_argument('--iou_coeff', type=float, default=0.5, help='iou_coeff')
    parser.add_argument('--bbox_coeff', type=float, default=0.0, help='bbox_coeff')
    parser.add_argument('--dist_x_bbox', type=bool, default=False, help='dist_x_bbox')
    parser.add_argument('--iou_loss_only', type=bool, default=False, help='iou_loss_only')
    parser.add_argument('--show_dist_reg', type=bool, default=True, help='show distance regularization map in figure')
    opt = parser.parse_args()

    # Pass by keyword: the previous positional call swapped alpha/num_bb
    # (signature is ..., num_bb, alpha, ...), so alpha received 0 and the
    # optimization never moved — the old "#TODO - this does not appear to be
    # working correctly".
    toy_problem(opt.pgt_coeff, opt.focus_coeff, opt.x_coord, opt.y_coord,
                num_bb=opt.num_bb, alpha=opt.alpha, scheduler=opt.scheduler,
                device=opt.device, dist_coeff=opt.dist_coeff,
                dist_reg_only=opt.dist_reg_only, iou_coeff=opt.iou_coeff,
                bbox_coeff=opt.bbox_coeff, dist_x_bbox=opt.dist_x_bbox,
                iou_loss_only=opt.iou_loss_only, show_dist_reg=opt.show_dist_reg)