Spaces:
Sleeping
Sleeping
Commit ·
32938bb
1
Parent(s): 44c1e94
Added all files
Browse files- interface.py +85 -0
- plaus_functs.py +784 -0
- plot_functs.py +154 -0
- toy_problem_pgt.py +247 -0
interface.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
import sys
from toy_problem_pgt import toy_problem

# Markdown shown inside the collapsible "Help" accordion below.
help_guide = """
## Help Guide

This demo allows you to experiment with the toy problem from the Plausibility Guided Training (PGT) paper.

### Input Parameters:

- **PGT Coefficient**: Choose a number between 0.1 and 10. This determines the emphasis given to the PGT loss function in the training process.
- **Focus Coefficient**: Choose a number between 0.01 and 1. This determines the concentration of pixels around the object that will be rewarded. Higher coefficient results in a more focused reward.
- **X Coord**: Choose a number between 0 and 1. This sets the X coordinate of the target.
- **Y Coord**: Choose a number between 0 and 1. This sets the Y coordinate of the target.

### Outputs:

1. **First 2 images**: Displays the distance regularization map and the first attribution map step.
2. **Second 8 images**: Displays each of the other attribution map steps.
3. **PGT Losses**: Visualizes the plausibility losses over each step.
4. **PGT Scores**: Displays the plausibility scores over each step.

"""


if __name__ == "__main__":

    with gr.Blocks(title="toy problem demo", theme=gr.themes.Base()) as demo:
        gr.Markdown(
            """
            # Toy Problem Demo
            This is a demo of the toy problem implementation.
            """
        )

        with gr.Accordion("Help", open=False):
            gr.Markdown(help_guide)

        # Numeric inputs controlling the toy-problem run.
        with gr.Row() as file_settings:

            pgt_coeff = gr.Number(label="PGT Coefficient", info="Choose a number between 0.1 and 10",
                                  minimum=0.1, maximum=10, value=1, interactive=True, step=1, show_label=True)
            # info text matches the actual minimum (0.01) and the help guide.
            focus_coeff = gr.Number(label="Focus Coefficient", info="Choose a number between 0.01 and 1",
                                    minimum=0.01, maximum=1, value=0.2, interactive=True, step=1, show_label=True)

            # TODO - Target info (this is where we can adjust)
            # We only expose the xy coordinates of the target to the user.
            x_coord = gr.Number(label="X Coord", info="Choose a number between 0 and 1",
                                minimum=0, maximum=1, value=0.8, interactive=True, step=1, show_label=True)
            y_coord = gr.Number(label="Y Coord", info="Choose a number between 0 and 1",
                                minimum=0, maximum=1, value=0.76, interactive=True, step=1, show_label=True)

        # First output row: distance map + first attribution step, then the rest.
        with gr.Row() as outputs:
            output_img1 = gr.Image(type='filepath', label="First 2 images",
                                   show_download_button=True, show_share_button=True, interactive=False, visible=True)
            output_img2 = gr.Image(type='filepath', label="9 images",
                                   show_download_button=True, show_share_button=True, interactive=False, visible=True, scale=4)

        # Second output row: loss and score plots.
        with gr.Row() as outputs_2:
            output_img3 = gr.Image(type='filepath', label="PGT Losses",
                                   show_download_button=True, show_share_button=True, interactive=False, visible=True)
            output_img4 = gr.Image(type='filepath', label="PGT Scores",
                                   show_download_button=True, show_share_button=True, interactive=False, visible=True)

        # List of components for clearing
        clear_comp_list = [output_img1, output_img2, output_img3, output_img4]

        # Row for start and clear buttons
        with gr.Row() as buttons:
            start = gr.Button(value="Start")
            clear = gr.ClearButton(value='Clear All', components=clear_comp_list,
                                   interactive=True, visible=True)

        # List of gradio components that are input into toy_problem (when start button is clicked)
        run_inputs = [pgt_coeff, focus_coeff, x_coord, y_coord]

        # List of gradio components that toy_problem populates
        run_outputs = [output_img1, output_img2, output_img3, output_img4]

        start.click(toy_problem, inputs=run_inputs, outputs=run_outputs)

    demo.queue().launch()
|
plaus_functs.py
ADDED
|
@@ -0,0 +1,784 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
from plot_functs import *
|
| 4 |
+
from plot_functs import normalize_tensor, overlay_mask, imshow
|
| 5 |
+
import math
|
| 6 |
+
import time
|
| 7 |
+
import matplotlib.path as mplPath
|
| 8 |
+
from matplotlib.path import Path
|
| 9 |
+
from utils.general import non_max_suppression, xyxy2xywh, scale_coords
|
| 10 |
+
|
| 11 |
+
def get_gradient(img, grad_wrt, norm=False, absolute=True, grayscale=False, keepmean=False):
    """
    Compute the gradient of an image with respect to a given tensor.

    Args:
        img (torch.Tensor): The input image tensor (must participate in the
            autograd graph of grad_wrt).
        grad_wrt (torch.Tensor): The tensor with respect to which the gradient is computed.
        norm (bool, optional): Whether to min-max normalize the map per batch element. Defaults to False.
        absolute (bool, optional): Whether to take the absolute values of the gradients. Defaults to True.
        grayscale (bool, optional): Whether to sum over the channel dimension. Defaults to False.
        keepmean (bool, optional): Whether to restore the (rescaled) pre-normalization
            mean after normalizing. Defaults to False.

    Returns:
        torch.Tensor: The computed attribution map.
    """
    # Non-scalar grad_wrt needs explicit grad_outputs (ones of the same shape)
    # so autograd can backpropagate a vector-valued quantity.
    if (grad_wrt.shape != torch.Size([1])) and (grad_wrt.shape != torch.Size([])):
        grad_wrt_outputs = torch.ones_like(grad_wrt).clone().detach()
    else:
        grad_wrt_outputs = None
    attribution_map = torch.autograd.grad(grad_wrt, img,
                                          grad_outputs=grad_wrt_outputs,
                                          create_graph=True, # Create graph to allow for higher order derivatives but slows down computation significantly
                                          )[0]
    if absolute:
        attribution_map = torch.abs(attribution_map)  # Take absolute values of gradients
    if grayscale:  # Convert to grayscale, saves vram and computation time for plaus_eval
        attribution_map = torch.sum(attribution_map, 1, keepdim=True)
    if norm:
        if keepmean:
            # Record pre-normalization statistics so the mean can be re-applied below.
            attmean = torch.mean(attribution_map)
            attmin = torch.min(attribution_map)
            attmax = torch.max(attribution_map)
        attribution_map = normalize_batch(attribution_map) # Normalize attribution maps per image in batch
        if keepmean:
            # Shift the normalized map so its mean equals the rescaled original mean.
            attribution_map -= attribution_map.mean()
            attribution_map += (attmean / (attmax - attmin))

    return attribution_map
|
| 50 |
+
|
| 51 |
+
def get_gaussian(img, grad_wrt, norm=True, absolute=True, grayscale=True, keepmean=False):
    """
    Generate Gaussian noise shaped like ``img``, post-processed the same way
    as a gradient attribution map (useful as a random baseline).

    Args:
        img (torch.Tensor): Reference tensor giving shape/device/dtype.
        grad_wrt: Unused; kept for signature parity with get_gradient.
        norm (bool, optional): Min-max normalize per batch element. Defaults to True.
        absolute (bool, optional): Take absolute values. Defaults to True.
        grayscale (bool, optional): Sum over the channel dimension. Defaults to True.
        keepmean (bool, optional): Restore the (rescaled) pre-normalization mean. Defaults to False.

    Returns:
        torch.Tensor: The generated noise map.
    """
    noise = torch.randn_like(img)

    if absolute:
        noise = noise.abs()  # mirror the |gradient| step of get_gradient
    if grayscale:
        # Collapse channels; cheaper downstream, same as the gradient path.
        noise = noise.sum(dim=1, keepdim=True)
    if norm:
        if keepmean:
            # Record statistics before normalization wipes them out.
            pre_mean = noise.mean()
            pre_min = noise.min()
            pre_max = noise.max()
        noise = normalize_batch(noise)  # per-image min-max normalization
        if keepmean:
            # Re-center on the rescaled original mean.
            noise = noise - noise.mean() + (pre_mean / (pre_max - pre_min))

    return noise
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def get_plaus_score(targets_out, attr, debug=False, corners=False, imgs=None, eps = 1e-7):
    # TODO: Remove imgs from this function and only take it as input if debug is True
    """
    Calculate the plausibility score of an attribution map: the fraction of
    total attribution mass that falls inside the target bounding boxes,
    i.e. sum(attr * bbox_map) / sum(attr).

    Args:
        targets_out (torch.Tensor): Target rows; column 0 is the batch index,
            columns 2:6 an xywh box in normalized coordinates (see get_bbox_map).
        attr (torch.Tensor): Attribution maps, shape (batch, C, H, W).
        debug (bool, optional): Save intermediate visualizations. Defaults to False.
        corners (bool, optional): Unused here — get_bbox_map is called with its
            default corners=False regardless.
        imgs (torch.Tensor, optional): Images, used for debug output only.
        eps (float, optional): Unused.

    Returns:
        torch.Tensor: Scalar plausibility score in [0, 1].
    """
    # NaNs in the attribution map would poison both sums below; zero them out.
    if torch.isnan(attr).any():
        attr = torch.nan_to_num(attr, nan=0.0)

    coords_map = get_bbox_map(targets_out, attr)
    plaus_score = ((torch.sum((attr * coords_map))) / (torch.sum(attr)))

    if debug:
        # NOTE(review): get_bbox_map returns a float32 map, so the boolean-mask
        # indexing below would raise; also imgs is dereferenced before the
        # None check. This debug path looks stale — verify before enabling.
        for i in range(len(coords_map)):
            coords_map3ch = torch.cat([coords_map[i][:1], coords_map[i][:1], coords_map[i][:1]], dim=0)
            test_bbox = torch.zeros_like(imgs[i])
            test_bbox[coords_map3ch] = imgs[i][coords_map3ch]
            imshow(test_bbox, save_path='figs/test_bbox')
            if imgs is None:
                imgs = torch.zeros_like(attr)
            imshow(imgs[i], save_path='figs/im0')
            imshow(attr[i], save_path='figs/attr')

    return plaus_score
|
| 150 |
+
|
| 151 |
+
from utils.general import bbox_iou
|
| 152 |
+
|
| 153 |
+
def get_attr_corners(targets_out, attr, debug=False, corners=False, imgs=None, eps = 1e-7):
    # TODO: Remove imgs from this function and only take it as input if debug is True
    """
    Score how well attribution energy is localized on the targets.

    For each target, the attribution map is split into four regions around the
    target center; the peak-attribution location in each region supplies one
    corner of a pseudo bounding box. Returns the mean CIoU between those pseudo
    boxes and the ground-truth boxes.

    Args:
        targets_out (torch.Tensor): Target rows; column 0 is the batch index,
            columns 2:6 an xywh box in normalized coordinates.
        attr (torch.Tensor): Attribution maps, shape (batch, C, H, W).
        debug (bool, optional): Save intermediate visualizations (requires imgs).
        corners (bool, optional): If True, targets_out[:, 2:6] is taken as
            pixel corner coordinates directly.
        imgs (torch.Tensor, optional): Images, used for debug output only.
        eps (float, optional): Unused.

    Returns:
        torch.Tensor: Scalar mean CIoU score.
    """
    target_inds = targets_out[:, 0].int()
    xyxy_batch = targets_out[:, 2:6]
    # Scale factors converting normalized box coords to pixel indices.
    num_pixels = torch.tile(torch.tensor([attr.shape[2], attr.shape[3], attr.shape[2], attr.shape[3]], device=attr.device), (xyxy_batch.shape[0], 1))
    xyxy_corners = (corners_coords_batch(xyxy_batch) * num_pixels).int()
    co = xyxy_corners
    if corners:
        co = targets_out[:, 2:6].int()
    coords_map = torch.zeros_like(attr, dtype=torch.bool)
    # NOTE(review): columns 0/1 are consumed as y/x — the map is indexed
    # [batch, :, x1:x2, y1:y2]; confirm the axis convention with callers.
    x1, x2 = co[:,1], co[:,3]
    y1, y2 = co[:,0], co[:,2]

    for ic in range(co.shape[0]): # potential for speedup here with torch indexing instead of for loop
        coords_map[target_inds[ic], :,x1[ic]:x2[ic],y1[ic]:y2[ic]] = True

    # NaNs would corrupt the argmax/IoU computations below.
    if torch.isnan(attr).any():
        attr = torch.nan_to_num(attr, nan=0.0)
    if debug:
        for i in range(len(coords_map)):
            coords_map3ch = torch.cat([coords_map[i][:1], coords_map[i][:1], coords_map[i][:1]], dim=0)
            test_bbox = torch.zeros_like(imgs[i])
            test_bbox[coords_map3ch] = imgs[i][coords_map3ch]
            imshow(test_bbox, save_path='figs/test_bbox')
            imshow(imgs[i], save_path='figs/im0')
            imshow(attr[i], save_path='figs/attr')

    # Pixel location of each target center (+1 so the split excludes the center row/col).
    co = (xyxy_batch * num_pixels).int()
    x1 = co[:,1] + 1
    y1 = co[:,0] + 1
    attr_ = torch.sum(attr, 1, keepdim=True)
    corners_attr = None
    for ic in range(co.shape[0]):
        # Split the map into the four regions around the target center.
        attr0 = attr_[target_inds[ic], :,:x1[ic],:y1[ic]]
        attr1 = attr_[target_inds[ic], :,:x1[ic],y1[ic]:]
        attr2 = attr_[target_inds[ic], :,x1[ic]:,:y1[ic]]
        attr3 = attr_[target_inds[ic], :,x1[ic]:,y1[ic]:]

        # Peak-attribution location within each region (region-local indices).
        x_0, y_0 = max_indices_2d(attr0[0])
        x_1, y_1 = max_indices_2d(attr1[0])
        x_2, y_2 = max_indices_2d(attr2[0])
        x_3, y_3 = max_indices_2d(attr3[0])

        # Shift region-local indices back into full-map coordinates.
        y_1 += y1[ic]
        x_2 += x1[ic]
        x_3 += x1[ic]
        y_3 += y1[ic]

        # Tightest normalized box enclosing the four peaks.
        max_corners = torch.cat([torch.min(x_0, x_2).unsqueeze(0) / attr_.shape[2],
                                 torch.min(y_0, y_1).unsqueeze(0) / attr_.shape[3],
                                 torch.max(x_1, x_3).unsqueeze(0) / attr_.shape[2],
                                 torch.max(y_2, y_3).unsqueeze(0) / attr_.shape[3]])
        if corners_attr is None:
            corners_attr = max_corners
        else:
            corners_attr = torch.cat([corners_attr, max_corners], dim=0)
    corners_attr = corners_attr.view(-1, 4)
    IoU_ = bbox_iou(corners_attr.T, xyxy_batch, x1y1x2y2=False, metric='CIoU')
    plaus_score = IoU_.mean()

    return plaus_score
|
| 242 |
+
|
| 243 |
+
def max_indices_2d(x_inp):
    """
    Locate the maximum element of a 2D tensor.

    Args:
        x_inp (torch.Tensor): 2D input tensor.

    Returns:
        torch.Tensor: 1D tensor [row, col] giving the argmax position.
    """
    # argmax on a 2D tensor returns a flat index; recover row/col from it.
    # (The original also called torch.max(x_inp,) and discarded the result —
    # dead work, removed.)
    index = torch.argmax(x_inp)
    row = index // x_inp.shape[1]
    col = index % x_inp.shape[1]

    return torch.cat([row.unsqueeze(0), col.unsqueeze(0)])
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
def point_in_polygon(poly, grid):
    """
    Even-odd-rule (ray casting) point-in-polygon test.

    For every grid point, toggles an "inside" flag once per polygon edge that
    crosses the point's horizontal line to its right; an odd number of
    crossings means the point is inside.

    Args:
        poly (torch.Tensor): Polygon vertices, shape (num_points, 2) as (x, y).
        grid (torch.Tensor): Points to test, shape (..., 2).

    Returns:
        torch.Tensor: Boolean tensor of shape grid.shape[:-1]; True = inside.
    """
    num_vertices = poly.shape[0]
    inside = torch.zeros_like(grid[..., 0], dtype=torch.bool)
    prev = num_vertices - 1
    for cur in range(num_vertices):
        # Edge (prev -> cur) straddles the point's y level, in either direction.
        straddles_up = (poly[cur, 1] < grid[..., 1]) & (poly[prev, 1] >= grid[..., 1])
        straddles_down = (poly[prev, 1] < grid[..., 1]) & (poly[cur, 1] >= grid[..., 1])
        # Point lies left of where the edge crosses that horizontal line.
        left_of_crossing = (grid[..., 0] - poly[cur, 0]) < (poly[prev, 0] - poly[cur, 0]) * (grid[..., 1] - poly[cur, 1]) / (poly[prev, 1] - poly[cur, 1])
        inside = inside ^ ((straddles_up | straddles_down) & left_of_crossing)
        prev = cur
    return inside
|
| 268 |
+
|
| 269 |
+
def point_in_polygon_gpu(poly, grid):
    """
    Vectorized even-odd (ray casting) point-in-polygon test.

    GPU-friendly variant of point_in_polygon: evaluates every edge at once and
    folds the per-edge crossing flags with a pairwise XOR reduction instead of
    a Python loop over edges.

    Args:
        poly (torch.Tensor): Polygon vertices, shape (num_points, 2) as (x, y).
        grid (torch.Tensor): Points to test, shape (H, W, 2).
            NOTE(review): the expand() below uses grid.shape[0] for both
            spatial dims, so a square grid is assumed — confirm.

    Returns:
        torch.Tensor: Boolean tensor; True where the point is inside.
    """
    num_points = poly.shape[0]
    i = torch.arange(num_points)
    j = (i - 1) % num_points  # previous-vertex index for each edge
    # Broadcast every vertex over the full grid so all edges are tested at once.
    poly_expanded = poly.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, grid.shape[0], grid.shape[0])
    # Edge straddles the point's horizontal line (in either direction)...
    cond1 = (poly_expanded[i, 1] < grid[..., 1]) & (poly_expanded[j, 1] >= grid[..., 1])
    cond2 = (poly_expanded[j, 1] < grid[..., 1]) & (poly_expanded[i, 1] >= grid[..., 1])
    # ...and the point lies left of where the edge crosses that line.
    cond3 = (grid[..., 0] - poly_expanded[i, 0]) < (poly_expanded[j, 0] - poly_expanded[i, 0]) * (grid[..., 1] - poly_expanded[i, 1]) / (poly_expanded[j, 1] - poly_expanded[i, 1])
    oddNodes = torch.zeros_like(grid[..., 0], dtype=torch.bool)
    cond = (cond1 | cond2) & cond3
    # Pairwise XOR reduction over the edge dimension: halve the stack each
    # pass, setting aside the odd leftover element, then fold leftovers back in.
    c = []
    while len(cond) > 1:
        if len(cond) % 2 == 1: # odd number of elements
            c.append(cond[-1])
            cond = cond[:-1]
        cond = torch.bitwise_xor(cond[:int(len(cond)/2)], cond[int(len(cond)/2):])
    for c_ in c:
        cond = torch.bitwise_xor(cond, c_)
    oddNodes = cond
    return oddNodes
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
def bitmap_for_polygon(poly, h, w):
    """
    Rasterize a polygon into a (1, h, w) boolean bitmap.

    Builds a grid of pixel (x, y) coordinates and marks each pixel whose
    coordinate lies inside ``poly`` (delegates to point_in_polygon).
    """
    row_idx = torch.arange(h).to(poly.device).float()
    col_idx = torch.arange(w).to(poly.device).float()
    grid_y, grid_x = torch.meshgrid(row_idx, col_idx)
    # Stack as (x, y) pairs to match point_in_polygon's vertex layout.
    pixel_coords = torch.stack((grid_x, grid_y), dim=-1)
    inside = point_in_polygon(poly, pixel_coords)
    return inside.unsqueeze(0)
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
def corners_coords(center_xywh):
    """Convert one (cx, cy, w, h) box to (x1, y1, x2, y2) corner form."""
    cx, cy, w, h = center_xywh
    left = cx - w / 2
    top = cy - h / 2
    return torch.tensor([left, top, left + w, top + h])
|
| 317 |
+
|
| 318 |
+
def corners_coords_batch(center_xywh):
    """Convert a batch of (cx, cy, w, h) boxes to (x1, y1, x2, y2) corner form.

    Args:
        center_xywh (torch.Tensor): Boxes, shape (N, 4) as center-x, center-y,
            width, height.

    Returns:
        torch.Tensor: Corner boxes, shape (N, 4) as x1, y1, x2, y2.
    """
    wh = center_xywh[:, 2:4]
    top_left = center_xywh[:, 0:2] - wh / 2
    bottom_right = top_left + wh
    return torch.cat([top_left, bottom_right], dim=1)
|
| 324 |
+
|
| 325 |
+
def normalize_batch(x):
    """
    Min-max normalize each sample of a batch independently to [0, 1].

    Args:
        x (torch.Tensor): Input tensor of shape (batch_size, channels, height, width).

    Returns:
        torch.Tensor: Normalized tensor of the same shape as the input.

    Note:
        A sample whose values are all equal yields NaNs (0/0), matching the
        original per-sample loop; callers already guard with torch.nan_to_num.
    """
    # Vectorized per-sample min/max over all non-batch dims (the original
    # looped over the batch in Python); keepdim keeps them broadcastable.
    reduce_dims = tuple(range(1, x.dim()))
    mins = x.amin(dim=reduce_dims, keepdim=True)
    maxs = x.amax(dim=reduce_dims, keepdim=True)

    return (x - mins) / (maxs - mins)
|
| 343 |
+
|
| 344 |
+
def get_detections(model_clone, img):
    """
    Run a model in inference mode on an input batch.

    Args:
        model_clone (nn.Module): The model to use for detection; switched to
            eval mode before the forward pass.
        img (torch.Tensor): The input image tensor.

    Returns:
        tuple: (det_out, out) as produced by the model's forward pass.
    """
    model_clone.eval()  # deterministic behavior for dropout/batch-norm layers
    # No autograd graph is needed for plain inference.
    with torch.no_grad():
        det_out, aux_out = model_clone(img)

    del img  # drop our reference to the (possibly large) input promptly

    return det_out, aux_out
|
| 364 |
+
|
| 365 |
+
def get_labels(det_out, imgs, targets, opt):
    """
    Convert raw detector output into target-formatted predicted labels.

    Runs NMS on det_out, sorts each image's detections by confidence, keeps
    those above 0.1 confidence (plus one), and reformats them to the
    ground-truth layout [image_index, class, x, y, w, h] with normalized xywh.

    Args:
        det_out: Raw detection output, fed to non_max_suppression.
        imgs (torch.Tensor): Input batch, shape (nb, C, H, W); supplies the
            pixel dimensions and device.
        targets (torch.Tensor): Ground-truth rows [img_idx, class, x, y, w, h],
            normalized coordinates.
        opt: Options namespace; opt.save_hybrid enables hybrid autolabelling.

    Returns:
        torch.Tensor: Predicted labels, shape (num_preds, 6), on imgs.device.
    """
    nb, _, height, width = imgs.shape # batch size, channels, height, width
    targets_ = targets.clone()
    targets_[:, 2:] = targets_[:, 2:] * torch.Tensor([width, height, width, height]).to(imgs.device) # to pixels
    lb = [targets_[targets_[:, 0] == i, 1:] for i in range(nb)] if opt.save_hybrid else [] # for autolabelling
    o = non_max_suppression(det_out, conf_thres=0.001, iou_thres=0.6, labels=lb, multi_label=True)
    pred_labels = []
    for si, pred in enumerate(o):
        labels = targets_[targets_[:, 0] == si, 1:]
        nl = len(labels)  # NOTE(review): computed but unused
        predn = pred.clone()
        # Sort detections by confidence (column 4), highest first.
        sort_indices = torch.argsort(pred[:, 4], dim=0, descending=True)
        sorted_pred = predn[sort_indices]
        # Keep detections above 0.1 confidence; +1 always keeps at least one row.
        n_conf = int(torch.sum(sorted_pred[:,4]>0.1)) + 1
        sorted_pred = sorted_pred[:n_conf]
        # Prepend the image index column; reorder to [img, class, box...].
        new_col = torch.ones((sorted_pred.shape[0], 1), device=imgs.device) * si
        preds = torch.cat((new_col, sorted_pred[:, [5, 0, 1, 2, 3]]), dim=1)
        preds[:, 2:] = xyxy2xywh(preds[:, 2:]) # xyxy -> xywh
        gn = torch.tensor([width, height])[[1, 0, 1, 0]] # normalization gain whwh
        preds[:, 2:] /= gn.to(imgs.device) # from pixels back to normalized coords
        pred_labels.append(preds)
    pred_labels = torch.cat(pred_labels, 0).to(imgs.device)

    return pred_labels
|
| 393 |
+
##################################################################
|
| 394 |
+
|
| 395 |
+
from torchvision.utils import make_grid
|
| 396 |
+
|
| 397 |
+
def get_center_coords(attr):
    """
    Locate the centroid of the brightest region of an attribution map.

    Fixes the original, which referenced an undefined name ``img_tensor``
    (NameError on every call; the parameter is ``attr``) and discarded the
    computed centroid by returning nothing.

    Args:
        attr (torch.Tensor): Attribution/heat map. NOTE(review): the unpacking
            of torch.where below assumes a 2D (H, W) tensor — confirm callers.

    Returns:
        tuple[float, float]: (centroid_x, centroid_y) of pixels brighter than
        95% of the map's maximum value (NaN if no pixel qualifies).
    """
    # Scale relative to the maximum so the threshold is intensity-independent.
    img_tensor = attr / attr.max()

    # Define a brightness threshold
    threshold = 0.95

    # Create a binary mask of the bright pixels
    mask = img_tensor > threshold

    # Get the coordinates of the bright pixels
    y_coords, x_coords = torch.where(mask)

    # Calculate the centroid of the bright pixels
    centroid_x = x_coords.float().mean().item()
    centroid_y = y_coords.float().mean().item()

    print(f'The central bright point is at ({centroid_x}, {centroid_y})')

    return centroid_x, centroid_y
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
def get_distance_grids(attr, targets, imgs=None, focus_coeff=0.5, debug=False):
    """
    Compute per-image grids of each pixel's distance to its nearest target.

    Args:
        attr (torch.Tensor): Attribution maps, shape (batch, C, H, W); only the
            shape/device are used.
        targets (torch.Tensor): Target rows; column 0 is the batch index,
            columns 2:4 the normalized target center coordinates.
        imgs (torch.Tensor, optional): Images, used for debug output only.
        focus_coeff (float, optional): Exponent applied to the normalized
            distances; smaller values concentrate the grid near the targets.
            Defaults to 0.5.
        debug (bool, optional): Whether to save debug visualizations. Defaults to False.

    Returns:
        torch.Tensor: Distance grids, shape like attr, min-max normalized to
        [0, 1] per image and raised to focus_coeff.
    """
    # NOTE(review): height is read from shape[-1] and width from shape[-2],
    # swapped relative to the usual (B, C, H, W) layout; harmless for square
    # inputs but confirm for rectangular ones.
    height, width = attr.shape[-1], attr.shape[-2]

    # Grid of pixel index coordinates, shape (H, W, 2).
    xx, yy = torch.stack(torch.meshgrid(torch.arange(height), torch.arange(width))).to(attr.device)
    idx_grid = torch.stack((xx, yy), dim=-1).float()

    # Expand the grid to match the batch size
    idx_batch_grid = idx_grid.expand(attr.shape[0], -1, -1, -1)

    # One slot per batch element; each slot is reassigned below, so the
    # shared-empty-list aliasing from [[]] * n is harmless here.
    dist_grids_ = [[]] * attr.shape[0]

    # Loop over batches
    for j in range(attr.shape[0]):
        # Targets belonging to image j (boolean indexing copies the rows).
        rows = targets[targets[:, 0] == j]

        if len(rows) != 0:
            # Target center coordinates.
            xy = rows[:,2:4]
            # Flip the x and y coordinates and scale them to pixel units.
            xy[:, 0], xy[:, 1] = xy[:, 1] * width, xy[:, 0] * height # y, x to x, y
            xy_center = xy.unsqueeze(1).unsqueeze(1)

            # Euclidean distance from every pixel to every target center.
            dists = torch.norm(idx_batch_grid[j].expand(len(xy_center), -1, -1, -1) - xy_center, dim=-1)

            # Keep only the distance to the nearest target per pixel.
            dist_grid_ = torch.min(dists, dim=0)[0].unsqueeze(0)
            # Replicate across channels when attr is 3-channel.
            dist_grid = torch.cat([dist_grid_, dist_grid_, dist_grid_], dim=0) if attr.shape[1] == 3 else dist_grid_
        else:
            # Set grid to zero if no targets are present
            dist_grid = torch.zeros_like(attr[j])

        dist_grids_[j] = dist_grid
    # Normalize per image, then sharpen/flatten with the focus exponent.
    dist_grids = normalize_batch(torch.stack(dist_grids_)) ** focus_coeff
    # normalize_batch yields NaNs for constant grids (e.g. no targets); zero them.
    if torch.isnan(dist_grids).any():
        dist_grids = torch.nan_to_num(dist_grids, nan=0.0)

    if debug:
        for i in range(len(dist_grids)):
            if ((i % 8) == 0):
                grid_show = torch.cat([dist_grids[i][:1], dist_grids[i][:1], dist_grids[i][:1]], dim=0)
                imshow(grid_show, save_path='figs/dist_grids')
                if imgs is None:
                    imgs = torch.zeros_like(attr)
                imshow(imgs[i], save_path='figs/im0')
                img_overlay = (overlay_mask(imgs[i], dist_grids[i][0], alpha = 0.75))
                imshow(img_overlay, save_path='figs/dist_grid_overlay')
                weighted_attr = (dist_grids[i] * attr[i])
                imshow(weighted_attr, save_path='figs/weighted_attr')
                imshow(attr[i], save_path='figs/attr')

    return dist_grids
|
| 491 |
+
|
| 492 |
+
def attr_reg(attribution_map, distance_map):
    """Return the mean of the attribution map weighted elementwise by a distance map.

    Args:
        attribution_map (torch.Tensor): Attribution values, any shape.
        distance_map (torch.Tensor): Weights, broadcastable to ``attribution_map``.

    Returns:
        torch.Tensor: Scalar mean of ``distance_map * attribution_map``.
    """
    return torch.mean(distance_map * attribution_map)
|
| 498 |
+
|
| 499 |
+
def get_bbox_map(targets_out, attr, corners=False):
    """Build a float {0, 1} mask of ground-truth bounding boxes over the attribution maps.

    Args:
        targets_out (torch.Tensor): Rows ``[img_idx, cls, x, y, w, h]``; coords are
            normalized centers/sizes unless ``corners`` is True, in which case
            columns 2:6 are already pixel corner coordinates.
        attr (torch.Tensor): Attribution maps of shape (N, C, H, W); defines the
            output shape and device.
        corners (bool): If True, use ``targets_out[:, 2:6]`` directly as corners.

    Returns:
        torch.Tensor: Float mask, same shape as ``attr``, 1.0 inside boxes.
    """
    target_inds = targets_out[:, 0].int()
    if corners:
        co = targets_out[:, 2:6].int()
    else:
        # fix: only compute the center->corner conversion when it is actually
        # needed (previously corners_coords_batch ran even with corners=True).
        xyxy_batch = targets_out[:, 2:6]
        num_pixels = torch.tile(
            torch.tensor([attr.shape[2], attr.shape[3], attr.shape[2], attr.shape[3]],
                         device=attr.device),
            (xyxy_batch.shape[0], 1))
        co = (corners_coords_batch(xyxy_batch) * num_pixels).int()

    coords_map = torch.zeros_like(attr, dtype=torch.bool)
    # NOTE(review): columns appear swapped (x from co[:,1]/co[:,3], y from
    # co[:,0]/co[:,2]) -- preserved as-is; confirm against corners_coords_batch.
    x1, x2 = co[:, 1], co[:, 3]
    y1, y2 = co[:, 0], co[:, 2]

    # Potential speedup: replace this loop with vectorized torch indexing.
    for ic in range(co.shape[0]):
        coords_map[target_inds[ic], :, x1[ic]:x2[ic], y1[ic]:y2[ic]] = True

    return coords_map.to(torch.float32)
|
| 519 |
+
######################################## BCE #######################################
|
| 520 |
+
def get_plaus_loss(targets, attribution_map, opt, imgs=None, debug=False, only_loss=False):
    """Compute the Plausibility-Guided Training (PGT) loss for a batch of attribution maps.

    Args:
        targets (torch.Tensor): Ground-truth rows ``[img_idx, cls, x, y, w, h]``
            with normalized coordinates.
        attribution_map (torch.Tensor): Attribution maps of shape (N, C, H, W).
        opt (argparse.Namespace): Options; reads ``focus_coeff``, ``dist_x_bbox``,
            ``bbox_coeff``, ``dist_reg_only``, ``iou_coeff``, ``dist_coeff`` and
            ``pgt_coeff``.
        imgs (torch.Tensor, optional): Source images, only used for debug plots
            inside the distance-grid helper.
        debug (bool): If True, also return the distance map.
        only_loss (bool): If True, return only the scalar loss.

    Returns:
        ``plaus_loss`` if ``only_loss``; else ``(plaus_loss, (plaus_score,
        dist_reg, plaus_reg))``, with ``distance_map`` appended when ``debug``.
    """
    # NOTE: a redundant initial get_plaus_score() call was removed here -- its
    # result was unconditionally overwritten by the score computed below.

    # Distance-based regularization grid (larger = farther from any target center).
    distance_map = get_distance_grids(attribution_map, targets, imgs, opt.focus_coeff)

    # The bbox mask is needed for the plausibility score and, optionally, for the
    # bbox regularizer / distance masking -- compute it once instead of three times.
    bbox_map = get_bbox_map(targets, attribution_map)

    if opt.dist_x_bbox:
        # Zero the distance penalty inside ground-truth boxes.
        distance_map[bbox_map.to(torch.bool)] = 0.0

    # Positive term rewards attribution near targets; negative term penalizes
    # attribution far from them. Normalize both by the mean attribution.
    dist_attr_pos = attr_reg(attribution_map, (1.0 - distance_map))
    dist_attr_neg = attr_reg(attribution_map, distance_map)
    attr_mean = torch.mean(attribution_map)
    dist_reg = (dist_attr_pos / attr_mean) - (dist_attr_neg / attr_mean)

    if opt.bbox_coeff != 0.0:
        attr_bbox_pos = attr_reg(attribution_map, bbox_map)
        attr_bbox_neg = attr_reg(attribution_map, (1.0 - bbox_map))
        bbox_reg = attr_bbox_pos - attr_bbox_neg
    else:
        bbox_reg = 0.0

    # Plausibility score: fraction of total attribution that falls inside GT boxes.
    plaus_score = torch.sum(attribution_map * bbox_map) / torch.sum(attribution_map)

    if not opt.dist_reg_only:
        # Map dist_reg from [-1, 1] to [0, 1] before weighting.
        dist_reg_loss = (1.0 + dist_reg) / 2.0
        plaus_reg = (plaus_score * opt.iou_coeff) + \
                    ((dist_reg_loss * opt.dist_coeff) + (bbox_reg * opt.bbox_coeff))
    else:
        plaus_reg = (1.0 + dist_reg) / 2.0

    # The loss decreases as the regularized plausibility increases.
    plaus_loss = (1 - plaus_reg) * opt.pgt_coeff

    if only_loss:
        return plaus_loss
    if not debug:
        return plaus_loss, (plaus_score, dist_reg, plaus_reg,)
    return plaus_loss, (plaus_score, dist_reg, plaus_reg,), distance_map
|
| 586 |
+
|
| 587 |
+
####################################################################################
|
| 588 |
+
#### ALL FUNCTIONS BELOW ARE DEPRECATED AND WILL BE REMOVED IN FUTURE VERSIONS ####
|
| 589 |
+
####################################################################################
|
| 590 |
+
|
| 591 |
+
def generate_vanilla_grad(model, input_tensor, loss_func = None,
                          targets_list=None, targets=None, metric=None, out_num = 1,
                          n_max_labels=3, norm=True, abs=True, grayscale=True,
                          class_specific_attr = True, device='cpu'):
    """
    Generate vanilla gradients for the given model and input tensor.

    .. deprecated:: This function is deprecated and will be removed in future versions.

    Args:
        model (nn.Module): The model to generate gradients for.
        input_tensor (torch.Tensor): The input tensor for which gradients are computed.
        loss_func (callable, optional): The loss function to compute gradients with respect to. Defaults to None.
        targets_list (list, optional): The list of target tensors. Defaults to None.
        targets (torch.Tensor, optional): Targets used when ``class_specific_attr`` is False.
        metric (callable, optional): The metric function to evaluate the loss. Defaults to None.
        out_num (int, optional): The index of the output tensor to compute gradients with respect to. Defaults to 1.
        n_max_labels (int, optional): The maximum number of labels to consider. Defaults to 3.
        norm (bool, optional): Whether to normalize the attribution map. Defaults to True.
        abs (bool, optional): Whether to take the absolute values of gradients. Defaults to True.
            (Name shadows the builtin ``abs``; kept for interface compatibility.)
        grayscale (bool, optional): Whether to convert the attribution map to grayscale. Defaults to True.
        class_specific_attr (bool, optional): Whether to compute class-specific attribution maps. Defaults to True.
        device (str, optional): The device to use for computation. Defaults to 'cpu'.

    Returns:
        list[torch.Tensor]: The generated vanilla-gradient attribution maps, one entry per image.
    """
    # Set model.train() at the beginning and revert to the original mode at the end.
    train_mode = model.training
    if not train_mode:
        model.train()

    input_tensor.requires_grad = True  # needed so gradients flow back to the input
    model.zero_grad()
    inpt = input_tensor
    # Forward pass: training outputs only (no inference outputs in train mode).
    train_out = model(inpt)

    if class_specific_attr:
        # Build, per image, the output-channel indices [box(4) + obj + class] to
        # attribute against, and how many attribution maps each image needs.
        n_attr_list, index_classes = [], []
        for i in range(len(input_tensor)):
            if len(targets_list[i]) > n_max_labels:
                targets_list[i] = targets_list[i][:n_max_labels]
            if targets_list[i].numel() != 0:
                class_numbers = targets_list[i][:, 1]
                index_classes.append([[0, 1, 2, 3, 4, int(uc)] for uc in class_numbers])
                n_attr_list.append(len(targets_list[i]))
            else:
                index_classes.append([0, 1, 2, 3, 4])
                n_attr_list.append(0)

        # Pad every image's target list (by repetition) to the batch maximum so
        # they can be stacked into one tensor.
        targets_list_filled = [targ.clone().detach() for targ in targets_list]
        labels_len = [len(targets_list[ih]) for ih in range(len(targets_list))]
        max_labels = np.max(labels_len)
        max_index = np.argmax(labels_len)
        for i in range(len(targets_list)):
            if len(targets_list_filled[i]) < max_labels:
                tlist = [targets_list_filled[i]] * math.ceil(max_labels / len(targets_list_filled[i]))
                targets_list_filled[i] = torch.cat(tlist)[:max_labels].unsqueeze(0)
            else:
                targets_list_filled[i] = targets_list_filled[i].unsqueeze(0)
        # Drop empty entries (iterate backwards so pops don't shift indices).
        for i in range(len(targets_list_filled) - 1, -1, -1):
            if targets_list_filled[i].numel() == 0:
                targets_list_filled.pop(i)
        targets_list_filled = torch.cat(targets_list_filled)

    n_img_attrs = len(input_tensor) if class_specific_attr else 1
    n_img_attrs = 1 if loss_func else n_img_attrs

    attrs_batch = []
    for i_batch in range(n_img_attrs):
        if loss_func and class_specific_attr:
            # With a loss function the whole batch is attributed at once; use the
            # image with the most labels as the representative index.
            i_batch = max_index
        n_label_attrs = n_attr_list[i_batch] if class_specific_attr else 1
        n_label_attrs = 1 if not class_specific_attr else n_label_attrs
        attrs_img = []
        for i_attr in range(n_label_attrs):
            if loss_func is None:
                # Attribute directly against the selected raw output tensor.
                grad_wrt = train_out[out_num]
                if class_specific_attr:
                    grad_wrt = train_out[out_num][:, :, :, :, index_classes[i_batch][i_attr]]
                grad_wrt_outputs = torch.ones_like(grad_wrt)
            else:
                if class_specific_attr:
                    target_indiv = targets_list_filled[:, i_attr]  # batch image input
                else:
                    target_indiv = targets

                try:
                    loss, loss_items = loss_func(train_out, target_indiv, inpt, metric=metric)  # loss scaled by batch_size
                except Exception:
                    # fix: was a bare `except:`; also the original loop
                    # `for tro in train_out: tro = tro.to(device)` only rebound the
                    # loop variable and never moved the outputs.
                    target_indiv = target_indiv.to(device)
                    inpt = inpt.to(device)
                    train_out = [tro.to(device) for tro in train_out]
                    print("Error in loss function, trying again with device specified")
                    loss, loss_items = loss_func(train_out, target_indiv, inpt, metric=metric)
                grad_wrt = loss
                grad_wrt_outputs = None

            model.zero_grad()
            gradients = torch.autograd.grad(grad_wrt, inpt,
                                            grad_outputs=grad_wrt_outputs,
                                            retain_graph=True,
                                            # create_graph=True would allow higher-order
                                            # derivatives but slows computation significantly
                                            )

            attribution_map = gradients[0]

            if grayscale:  # grayscale saves VRAM and computation time downstream
                attribution_map = torch.sum(attribution_map, 1, keepdim=True)
            if abs:
                attribution_map = torch.abs(attribution_map)
            if norm:
                attribution_map = normalize_batch(attribution_map)  # per-image normalization
            attrs_img.append(attribution_map)
        if len(attrs_img) == 0:
            attrs_batch.append((torch.zeros_like(inpt).unsqueeze(0)).to(device))
        else:
            attrs_batch.append(torch.stack(attrs_img).to(device))

    out_attr = attrs_batch
    # Restore the model's original train/eval mode.
    if not train_mode:
        model.eval()

    return out_attr
|
| 738 |
+
|
| 739 |
+
class RVNonLinearFunc(nn.Module):
    """Bayesian wrapper that pushes a random variable through a deterministic
    nonlinearity, propagating mean and covariance to first order.
    """

    def __init__(self, func):
        super().__init__()
        # The deterministic activation (e.g. torch.relu) applied to the mean.
        self.func = func

    def forward(self, mu_in, Sigma_in):
        """Propagate (mean, covariance) through the wrapped activation.

        Args:
            mu_in (torch.Tensor): Mean input, shape (batch_size, input_size).
            Sigma_in (torch.Tensor): Covariance input,
                shape (batch_size, input_size, input_size).

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Output mean and covariance.
        """
        n = mu_in.size(0)

        # Mean passes straight through the activation.
        mu_out = self.func(mu_in)

        # Jacobian diagonal of the activation w.r.t. the input mean.
        jac = torch.autograd.grad(
            mu_out, mu_in,
            grad_outputs=torch.ones_like(mu_out),
            create_graph=True,
        )[0].view(n, -1)

        # First-order (delta-method) covariance: scale Sigma_in elementwise by
        # the outer product of the Jacobian diagonal with itself.
        outer = torch.bmm(jac.unsqueeze(2), jac.unsqueeze(1))
        Sigma_out = Sigma_in * outer

        return mu_out, Sigma_out
|
plot_functs.py
ADDED
|
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import matplotlib.pyplot as plt
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
|
| 6 |
+
class Subplots:
    """Accumulate rows of images into a single matplotlib figure."""

    def __init__(self, figsize=(40, 5)):
        self.fig = plt.figure(figsize=figsize)

    def plot_img_list(self, img_list, savedir='figs/test',
                      nrows=1, rownum=0,
                      hold=False, coltitles=None, rowtitle=''):
        """Plot one row of CHW images; save and reset the figure unless ``hold``.

        Args:
            img_list: Sequence of CHW torch tensors or numpy arrays.
            savedir (str): Output path without extension (``.png`` is appended).
            nrows (int): Total rows the figure will eventually hold.
            rownum (int): Zero-based row index for this call.
            hold (bool): If True, keep the figure open for more rows.
            coltitles (list[str], optional): Per-column titles.
                (fix: was a mutable default argument ``[]``.)
            rowtitle (str): Label annotated to the left of the first image.
        """
        if coltitles is None:
            coltitles = []
        lenrow = int(len(img_list))  # hoisted: invariant across the loop
        for i, img in enumerate(img_list):
            try:
                npimg = img.clone().detach().cpu().numpy()
            except AttributeError:  # already a numpy array
                npimg = img
            tpimg = np.transpose(npimg, (1, 2, 0))  # CHW -> HWC for imshow
            ax = self.fig.add_subplot(nrows, lenrow, i + 1 + (rownum * lenrow))
            if len(coltitles) > i:
                ax.set_title(coltitles[i])
            if i == 0:
                # Offset scales with title length so long labels stay visible.
                ax.annotate(rowtitle, xy=((-0.06 * len(rowtitle)), 0.4),
                            xycoords='axes fraction', textcoords='offset points',
                            size='large', ha='center', va='baseline')
            ax.imshow(tpimg)
            ax.axis('off')

        if not hold:
            self.fig.tight_layout()
            plt.savefig(f'{savedir}.png')
            plt.clf()
            plt.close('all')
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def VisualizeNumpyImageGrayscale(image_3d):
    r"""Linearly rescale an array so its values span [0, 1]."""
    shifted = image_3d - np.min(image_3d)
    return shifted / np.max(shifted)
|
| 46 |
+
|
| 47 |
+
def normalize_numpy(image_3d):
    r"""Min-max normalize a numpy array into the range [0, 1]."""
    zeroed = image_3d - np.min(image_3d)
    return zeroed / np.max(zeroed)
|
| 54 |
+
|
| 55 |
+
# def normalize_tensor(image_3d):
|
| 56 |
+
# r"""Returns a 3D tensor as a grayscale normalized between 0 and 1 2D tensor.
|
| 57 |
+
# """
|
| 58 |
+
# vmin = torch.min(image_3d)
|
| 59 |
+
# image_2d = image_3d - vmin
|
| 60 |
+
# vmax = torch.max(image_2d)
|
| 61 |
+
# return (image_2d / vmax)
|
| 62 |
+
|
| 63 |
+
def normalize_tensor(image_3d):
    r"""Min-max normalize a tensor into the range [0, 1]."""
    shifted = image_3d - torch.min(image_3d)
    return shifted / torch.max(shifted)
|
| 68 |
+
|
| 69 |
+
def format_img(img_):
    """Convert a CHW torch tensor into an HWC numpy array."""
    return np.transpose(img_.numpy(), (1, 2, 0))
|
| 73 |
+
|
| 74 |
+
def imshow(img, save_path=None):
    """Display a CHW image (torch tensor or numpy array); optionally save it.

    Args:
        img: CHW torch tensor or numpy array.
        save_path (str, optional): If given, the figure is written to
            ``<save_path>.png``.
    """
    try:
        npimg = img.clone().detach().cpu().numpy()
    except AttributeError:  # fix: was a bare except; non-tensors fall through
        npimg = img
    tpimg = np.transpose(npimg, (1, 2, 0))  # CHW -> HWC for matplotlib
    plt.imshow(tpimg)
    plt.tight_layout()
    if save_path is not None:  # fix: was `!= None`
        plt.savefig(f"{save_path}.png")
|
| 86 |
+
|
| 87 |
+
def imshow_img(img, imsave_path):
    """Normalize an HWC image (tensor or numpy), transpose to CHW, and save via imshow.

    Args:
        img: HWC torch tensor or numpy array.
        imsave_path (str): Output path (``.png`` appended by ``imshow``).
    """
    try:
        npimg = VisualizeNumpyImageGrayscale(img.numpy())
    except (AttributeError, RuntimeError, TypeError):
        # fix: was a bare except; fall back when `img` is not a CPU tensor.
        npimg = VisualizeNumpyImageGrayscale(img)
    npimg = np.transpose(npimg, (2, 0, 1))  # HWC -> CHW expected by imshow
    imshow(npimg, save_path=imsave_path)
    print("Saving image as ", imsave_path)
|
| 96 |
+
|
| 97 |
+
def returnGrad(img, labels, model, compute_loss, loss_metric, augment=None, device='cpu'):
    """Return d(loss)/d(img): vanilla gradient saliency for a detection model.

    Args:
        img (torch.Tensor): Input image batch.
        labels (torch.Tensor): Targets consumed by ``compute_loss``.
        model (nn.Module): The model; temporarily switched to train mode.
        compute_loss (callable): ``compute_loss(pred, labels, metric=...)`` ->
            ``(loss, loss_items)``.
        loss_metric: Metric forwarded to ``compute_loss``.
        augment: Unused; kept for interface compatibility.
        device (str | torch.device): Compute device. Defaults to 'cpu'.

    Returns:
        torch.Tensor: float32 gradient of the loss with respect to ``img``.
    """
    model.train()
    model.to(device)
    img = img.to(device)
    img.requires_grad_(True)
    labels = labels.to(device)  # fix: original discarded the result of .to(device)
    model.requires_grad_(True)
    # Accept both torch.device objects and plain strings like 'cpu' / 'cuda:0'
    # (fix: `device.type` raised AttributeError for the default string 'cpu').
    dev_type = device.type if isinstance(device, torch.device) else str(device)
    cuda = not dev_type.startswith('cpu')
    # fix: `amp` was never imported in this module; use the torch namespace.
    scaler = torch.cuda.amp.GradScaler(enabled=cuda)
    pred = model(img)
    loss, loss_items = compute_loss(pred, labels, metric=loss_metric)

    with torch.autograd.set_detect_anomaly(True):
        scaler.scale(loss).backward(inputs=img)

    Sc_dx = img.grad
    model.eval()
    # fix: torch.tensor(tensor) copies with a warning; detach explicitly instead.
    return Sc_dx.clone().detach().to(torch.float32)
|
| 121 |
+
|
| 122 |
+
def calculate_snr(img, attr, dB=True):
    """Signal-to-noise ratio treating the image as signal and the attribution as noise.

    Args:
        img: Torch tensor or numpy array (signal).
        attr: Torch tensor or numpy array (noise).
        dB (bool): If True, return the ratio in decibels.

    Returns:
        float: mean(img**2) / mean(attr**2), optionally as 10*log10(.).
    """
    try:
        img_np = img.detach().cpu().numpy()
        attr_np = attr.detach().cpu().numpy()
    except AttributeError:  # fix: was a bare except; inputs already numpy
        img_np = img
        attr_np = attr

    signal_power = np.mean(img_np ** 2)
    noise_power = np.mean(attr_np ** 2)

    if dB:
        return 10 * np.log10(signal_power / noise_power)
    return signal_power / noise_power
|
| 144 |
+
|
| 145 |
+
def overlay_mask(img, mask, colormap: str = "jet", alpha: float = 0.7):
    """Blend a colormapped mask over an image.

    Args:
        img (torch.Tensor): CHW image tensor.
        mask (torch.Tensor): Mask tensor; its leading dim is squeezed before
            colormapping.
        colormap (str): Matplotlib colormap name.
        alpha (float): Weight of the image in the blend (mask gets 1 - alpha).

    Returns:
        torch.Tensor: Blended CHW tensor on ``img``'s device.
    """
    cmap = plt.get_cmap(colormap)
    mask_np = np.array(mask.clone().detach().cpu().squeeze(0))
    # Drop the alpha channel from the RGBA colormap output; HWC -> CHW.
    colored = (cmap(mask_np)[:, :, :3]).transpose((2, 0, 1))
    img_np = np.asarray(img.clone().detach().cpu())
    blended = alpha * img_np + (1 - alpha) * colored
    return torch.tensor(blended, device=img.device)
|
toy_problem_pgt.py
ADDED
|
@@ -0,0 +1,247 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from plaus_functs import get_center_coords, get_distance_grids, get_plaus_loss, get_bbox_map, normalize_batch
|
| 3 |
+
from plot_functs import imshow
|
| 4 |
+
from torchvision.transforms.functional import gaussian_blur
|
| 5 |
+
import argparse
|
| 6 |
+
import matplotlib.pyplot as plt
|
| 7 |
+
import numpy as np
|
| 8 |
+
import os
|
| 9 |
+
import cv2
|
| 10 |
+
|
| 11 |
+
def subfigimshow(img, ax):
    """Render an image onto the given matplotlib axis, handling common layouts.

    Accepts torch tensors or numpy arrays in (H, W), (C, H, W) with C in
    {1, 3}, or (H, W, C). Grayscale inputs are drawn with the gray colormap.

    Args:
        img: Image tensor/array.
        ax: Matplotlib axis (anything exposing ``imshow``).

    Raises:
        ValueError: If the array is neither 2D nor 3D.
    """
    try:
        npimg = img.clone().detach().cpu().numpy()
    except AttributeError:  # fix: was a bare except; input already numpy
        npimg = img
    # fix: removed leftover debug print of the image shape
    if npimg.ndim == 2:
        # 2D array: grayscale image.
        ax.imshow(npimg, cmap='gray')
    elif npimg.ndim == 3:
        if npimg.shape[0] in (1, 3):
            # Channels-first (C, H, W) -> channels-last (H, W, C).
            tpimg = np.transpose(npimg, (1, 2, 0))
        else:
            tpimg = npimg  # already (H, W, C)
        if tpimg.shape[2] == 1:
            # Single channel: squeeze and draw as grayscale.
            ax.imshow(np.squeeze(tpimg), cmap='gray')
        else:
            ax.imshow(tpimg)
    else:
        raise ValueError(f"Unexpected image shape: {npimg.shape}")
|
| 35 |
+
|
| 36 |
+
def draw_bounding_boxes(image, boxes, color=(0, 255, 0), thickness=2):
    """Draw normalized (x_center, y_center, w, h) boxes on a copy of the image.

    Args:
        image (np.ndarray): Grayscale or RGB image; converted to uint8 RGB.
        boxes: Iterable of normalized ``(cx, cy, w, h)`` tuples.
        color (tuple): BGR/RGB rectangle color.
        thickness (int): Rectangle line thickness.

    Returns:
        np.ndarray: uint8 image copy with rectangles drawn.
    """
    # Promote grayscale inputs to 3-channel RGB.
    if image.ndim == 2:
        image = np.stack([image] * 3, axis=-1)
    elif image.ndim == 3 and image.shape[2] == 1:
        image = np.repeat(image, 3, axis=2)

    # cv2 expects uint8 pixels in [0, 255].
    if image.dtype != np.uint8:
        image = (image * 255).clip(0, 255).astype(np.uint8)

    canvas = image.copy()
    img_h, img_w = canvas.shape[0], canvas.shape[1]
    for cx, cy, bw, bh in boxes:
        top_left = (int((cx - bw / 2) * img_w), int((cy - bh / 2) * img_h))
        bottom_right = (int((cx + bw / 2) * img_w), int((cy + bh / 2) * img_h))
        cv2.rectangle(canvas, top_left, bottom_right, color, thickness)

    return canvas
|
| 57 |
+
|
| 58 |
+
def toy_problem(pgt_coeff, focus_coeff, x_coord, y_coord, num_bb=0, alpha=200.0, scheduler=2.0, device="0", dist_coeff=0.5, dist_reg_only=True, iou_coeff=0.5,
                bbox_coeff=0.0, dist_x_bbox=False, iou_loss_only=False, show_dist_reg=True):
    """Run the PGT toy problem: gradient-descend a random attribution map under the
    plausibility loss and save diagnostic figures.

    Args:
        pgt_coeff (float): Weight of the PGT loss term.
        focus_coeff (float): Exponent controlling how tightly the distance reward
            concentrates around each target.
        x_coord, y_coord (float): Normalized target center coordinates in [0, 1].
        num_bb (int): Number of bounding boxes -- currently unused (see TODO).
        alpha (float): Initial gradient-step size for the attribution map.
        scheduler (float): Multiplier applied to ``alpha`` after every step.
        device (str): CUDA device index exported via ``CUDA_VISIBLE_DEVICES``.
        dist_coeff, dist_reg_only, iou_coeff, bbox_coeff, dist_x_bbox,
        iou_loss_only, show_dist_reg: PGT loss options forwarded via ``opt``.

    Returns:
        tuple[str, str, str, str]: Paths of the four saved figures.
    """
    # Collect all parameters on a Namespace so the PGT helpers can read them as opt.<name>.
    opt = argparse.Namespace()
    opt.pgt_coeff = pgt_coeff
    opt.focus_coeff = focus_coeff
    opt.x_coord = x_coord
    opt.y_coord = y_coord
    opt.num_bb = num_bb
    opt.alpha = alpha
    opt.scheduler = scheduler
    opt.device = device
    opt.dist_coeff = dist_coeff
    opt.dist_reg_only = dist_reg_only
    opt.iou_coeff = iou_coeff
    opt.bbox_coeff = bbox_coeff
    opt.dist_x_bbox = dist_x_bbox
    opt.iou_loss_only = iou_loss_only
    opt.show_dist_reg = show_dist_reg

    # Set CUDA device (fix: removed unused `save_dirs` local).
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = str(int(opt.device))

    # TODO - Adjust this for the number of bounding boxes (opt.num_bb is unused).
    # Each row is [img_idx, class, x_center, y_center, width, height] (normalized).
    targets = torch.tensor([
        [0, 0, opt.x_coord, opt.y_coord, 0.05, 0.05],
        # [0, 1, 0.4, 0.6, 0.05, 0.07],
        # [1, 0, 0.25, 0.2, 0.04, 0.05],
        # [2, 0, 0.8, 0.76, 0.05, 0.05],
        # [2, 0, 0.8, 0.2, 0.05, 0.05],
        # [0, 0, 0.8, 0.76, 0.05, 0.05],
        # [1, 0, 0.8, 0.2, 0.05, 0.05],
        ])

    unique_classes = torch.unique(targets[:, 0])
    # Smoothed random noise as the initial attribution map (one per image index).
    attr = (gaussian_blur(torch.rand(len(unique_classes), 1, 640, 640)**2, 13)**4).requires_grad_(True)
    plaus_loss = get_plaus_loss(targets, attribution_map=attr,
                                opt=opt,
                                debug=True,
                                only_loss=True)
    if opt.iou_loss_only:
        bbox_map = get_bbox_map(targets, attr)
        plaus_score = ((torch.sum((attr * bbox_map))) / (torch.sum(attr)))
        plaus_loss = (1.0 - plaus_score)

    # Plot params (adjust as necessary)
    nsamples = 10
    rows = len(attr)     # number of images
    cols = nsamples + 2  # number of columns for subplots
    size = 3

    # Figure 1: distance regularization map + first attribution step.
    fig1 = plt.figure(figsize=(cols * size, rows * size))
    plt.tight_layout()

    # Figure 2: the remaining attribution steps.
    fig2 = plt.figure(figsize=(cols * size, rows * size))
    plt.tight_layout()

    # Figure 3: plausibility losses across steps.
    fig3, ax3 = plt.subplots(figsize=(10, 6))
    plaus_losses = []

    # Figure 4: plausibility scores across steps.
    fig4, ax4 = plt.subplots(figsize=(10, 6))
    plaus_scores = []

    for i in range(nsamples):
        plaus_loss, (plaus_score, dist_reg, plaus_reg,), distance_map = get_plaus_loss(targets.requires_grad_(True), attribution_map=attr, opt=opt, debug=True)

        # One gradient-descent step on the attribution map; the step size grows
        # each iteration by the scheduler factor.
        delta_attr = torch.autograd.grad(plaus_loss, attr, create_graph=True, retain_graph=True)[0]
        attr = attr - (delta_attr * alpha)
        alpha *= opt.scheduler

        plaus_loss, (plaus_score, dist_reg, plaus_reg,), distance_map = get_plaus_loss(targets, attribution_map=attr, opt=opt, debug=True)
        if opt.iou_loss_only:
            bbox_map = get_bbox_map(targets, attr)
            plaus_score = ((torch.sum((attr * bbox_map))) / (torch.sum(attr)))
            plaus_loss = (1.0 - plaus_score)
            distance_map = bbox_map

        attr = normalize_batch(attr)
        plaus_losses.append(float(plaus_loss))
        plaus_scores.append(float(plaus_score))
        print(f'step: {i}, plaus_loss: {plaus_loss}, plaus_score: {plaus_score}, dist_reg: {dist_reg}, plaus_reg: {plaus_reg}')

        for j in range(len(attr)):
            # Add a subplot for each image.
            if i == 0 and opt.show_dist_reg:
                ax = fig1.add_subplot(rows, cols, 1 + (j * cols))
                ax.set_title(f'Distance Regularization Map {j}')
                img_tensor = (1 - distance_map[j]).detach().cpu()
                img_np = img_tensor.numpy().squeeze()
                img_colored = plt.cm.viridis(img_np)
                bbox_coords = targets[:, 2:6].detach().cpu().numpy()  # [x, y, w, h] for all boxes
                img_with_boxes = draw_bounding_boxes(img_colored, bbox_coords)
                subfigimshow(img_with_boxes, ax)
                ax.axis('off')
            else:
                if i == 1:
                    # The first attribution step goes to fig1, next to the distance map.
                    ax = fig1.add_subplot(rows, cols, 2 + (j * cols))
                    ax.set_title(f'Attr Step {i}' if j == 0 else '')
                    img_tensor = attr[j].detach().cpu()
                    img_np = img_tensor.numpy().squeeze()
                    img_colored = plt.cm.viridis(img_np)
                    bbox_coords = targets[:, 2:6].detach().cpu().numpy()
                    img_with_boxes = draw_bounding_boxes(img_colored, bbox_coords)
                    subfigimshow(img_with_boxes, ax)
                    ax.axis('off')
                else:
                    # Subsequent steps go to fig2.
                    # NOTE(review): with show_dist_reg=False this branch is hit at
                    # i == 0 and computes subplot index 0, which matplotlib
                    # rejects -- confirm intended usage.
                    ax = fig2.add_subplot(rows, cols, 1 + (i - 1) + (j * cols))
                    ax.set_title(f'Attr Step {i}' if j == 0 else '')
                    img_tensor = attr[j].detach().cpu()
                    img_np = img_tensor.numpy().squeeze()
                    img_colored = plt.cm.viridis(img_np)
                    subfigimshow(img_colored, ax)
                    ax.axis('off')

    # Plot plausibility losses.
    ax3.plot(range(nsamples), plaus_losses, marker='o', label='Plausibility Loss')
    ax3.set_title('Plausibility Losses Across Steps')
    ax3.set_xlabel('Step')
    ax3.set_ylabel('Plausibility Loss')
    ax3.grid(True)
    ax3.legend()

    # Plot plausibility scores.
    ax4.plot(range(nsamples), plaus_scores, marker='o', label='Plausibility Scores')
    ax4.set_title('Plausibility Scores Across Steps')
    ax4.set_xlabel('Step')
    ax4.set_ylabel('Plausibility Score')
    ax4.grid(True)
    ax4.legend()

    # Save the figures.
    fig1.savefig('figs/distance_and_first_step.png', bbox_inches='tight')
    plt.close(fig1)

    fig2.savefig('figs/remaining_attr_steps.png', bbox_inches='tight')
    plt.close(fig2)

    fig3.savefig('figs/plausibility_losses.png', bbox_inches='tight')
    plt.close(fig3)

    fig4.savefig('figs/plausibility_scores.png', bbox_inches='tight')
    plt.close(fig4)  # fix: previously closed fig3 twice, leaking fig4

    print('Figures saved: figs/distance_and_first_step.png, figs/remaining_attr_steps.png, and figs/plausibility_losses.png, figs/plausibility_scores.png')
    return 'figs/distance_and_first_step.png', 'figs/remaining_attr_steps.png', 'figs/plausibility_losses.png', 'figs/plausibility_scores.png'
|
| 220 |
+
|
| 221 |
+
if __name__ == '__main__':

    def _str2bool(v):
        """Parse boolean CLI flags (fix: `type=bool` treats any non-empty string,
        including 'False', as True)."""
        return str(v).lower() in ('1', 'true', 't', 'yes', 'y')

    parser = argparse.ArgumentParser()
    # ##################### Standard Settings #####################
    parser.add_argument('--pgt_coeff', type=float, default=1.0, help='pgt_coeff')
    parser.add_argument('--focus_coeff', type=float, default=0.2, help='focus_coeff')
    parser.add_argument('--alpha', type=float, default=400.0, help='alpha')
    parser.add_argument('--num_bb', type=int, default=0, help='num_bb')
    parser.add_argument('--x_coord', type=float, default=0.2, help='x_coord')
    parser.add_argument('--y_coord', type=float, default=0.35, help='y_coord')
    ########################## Advanced #########################
    parser.add_argument('--scheduler', type=float, default=2.0, help='scheduler for alpha')
    #############################################################
    parser.add_argument('--device', type=str, default='0', help='device')
    parser.add_argument('--dist_coeff', type=float, default=0.5, help='dist_coeff')
    parser.add_argument('--dist_reg_only', type=_str2bool, default=True, help='dist_reg_only')
    parser.add_argument('--iou_coeff', type=float, default=0.5, help='iou_coeff')
    parser.add_argument('--bbox_coeff', type=float, default=0.0, help='bbox_coeff')
    parser.add_argument('--dist_x_bbox', type=_str2bool, default=False, help='dist_x_bbox')
    parser.add_argument('--iou_loss_only', type=_str2bool, default=False, help='iou_loss_only')
    parser.add_argument('--show_dist_reg', type=_str2bool, default=True, help='show distance regularization map in figure')
    opt = parser.parse_args()

    # fix: pass everything by keyword -- the previous positional call swapped
    # `alpha` and `num_bb` relative to toy_problem's signature.
    toy_problem(opt.pgt_coeff, opt.focus_coeff, opt.x_coord, opt.y_coord,
                num_bb=opt.num_bb, alpha=opt.alpha, scheduler=opt.scheduler,
                device=opt.device, dist_coeff=opt.dist_coeff,
                dist_reg_only=opt.dist_reg_only, iou_coeff=opt.iou_coeff,
                bbox_coeff=opt.bbox_coeff, dist_x_bbox=opt.dist_x_bbox,
                iou_loss_only=opt.iou_loss_only, show_dist_reg=opt.show_dist_reg)
|