Upload 9 files
- tools/__init__.py +0 -0
- tools/__pycache__/__init__.cpython-38.pyc +0 -0
- tools/__pycache__/mask_convert.cpython-38.pyc +0 -0
- tools/__pycache__/utilss.cpython-38.pyc +0 -0
- tools/grad_cam_CNN.py +72 -0
- tools/grad_cam_transformer.py +95 -0
- tools/mask_convert.py +103 -0
- tools/params_flops.py +55 -0
- tools/utilss.py +249 -0
tools/__init__.py
ADDED
File without changes
tools/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (134 Bytes)
tools/__pycache__/mask_convert.cpython-38.pyc
ADDED
Binary file (2.02 kB)
tools/__pycache__/utilss.cpython-38.pyc
ADDED
Binary file (7.35 kB)
tools/grad_cam_CNN.py
ADDED
@@ -0,0 +1,72 @@
+import os
+import sys
+sys.path.append('.')
+
+import matplotlib.pyplot as plt
+from utilss import GradCAM, show_cam_on_image, center_crop_img
+
+import argparse
+from utils.config import Config
+from train import *
+
+def get_args():
+    parser = argparse.ArgumentParser(description='Change detection of remote sensing images')
+    parser.add_argument("-c", "--config", type=str, default="configs/cdxformer.py")
+    parser.add_argument("--output_dir", default=None)
+    parser.add_argument("--layer", default=None)
+    return parser.parse_args()
+
+def main():
+    args = get_args()
+
+    if args.layer is None:
+        raise ValueError("Please ensure the parameter '--layer' is not None!\n e.g. --layer=model.net.decoderhead.LHBlock2.mlp_l")
+
+    cfg = Config.fromfile(args.config)
+
+    model = myTrain.load_from_checkpoint(cfg.test_ckpt_path, cfg=cfg)
+    model = model.to('cuda')
+
+    test_loader = build_dataloader(cfg.dataset_config, mode='test')
+
+    if args.output_dir:
+        base_dir = args.output_dir
+    else:
+        base_dir = os.path.dirname(cfg.test_ckpt_path)
+    gradcam_output_dir = os.path.join(base_dir, "grad_cam", args.layer)
+    if os.path.exists(gradcam_output_dir):
+        raise FileExistsError("Please ensure gradcam_output_dir does not exist!")
+
+    os.makedirs(gradcam_output_dir)
+
+    for batch in tqdm(test_loader):
+        target_layers = [eval(args.layer)]  # dotted name of the network layer, resolved against `model`
+        mask, img_id = batch[2].cuda(), batch[3]
+
+        cam = GradCAM(cfg, model=model.net, target_layers=target_layers, use_cuda=True)
+        target_category = 1  # index of the "change" class
+
+        grayscale_cam_all = cam(input_tensor=(batch[0], batch[1]), target_category=target_category)
+
+        for i in range(grayscale_cam_all.shape[0]):
+            grayscale_cam = grayscale_cam_all[i, :]
+            visualization = show_cam_on_image(0,  # no base image: render the heatmap alone
+                                              grayscale_cam,
+                                              use_rgb=True)
+            fig = plt.figure()
+            ax = fig.add_subplot(111)
+            ax.imshow(visualization)
+            # ax = fig.add_subplot(122)
+            # ax.imshow(mask[i].cpu().numpy())
+            ax.set_xticks([])
+            ax.set_yticks([])
+            ax.spines['top'].set_visible(False)
+            ax.spines['right'].set_visible(False)
+            ax.spines['bottom'].set_visible(False)
+            ax.spines['left'].set_visible(False)
+            plt.savefig(os.path.join(gradcam_output_dir, '{}.png'.format(img_id[i])))
+            plt.close()
+
+
+if __name__ == '__main__':
+    main()
tools/grad_cam_transformer.py
ADDED
@@ -0,0 +1,95 @@
+import os
+import sys
+sys.path.append('.')
+
+import matplotlib.pyplot as plt
+from utilss import GradCAM, show_cam_on_image, center_crop_img
+import math
+import argparse
+from utils.config import Config
+from train import *
+
+def get_args():
+    # ResizeTransform below assumes transformer features of shape [B, L, C];
+    # if not, please adjust the order
+    parser = argparse.ArgumentParser(description='Change detection of remote sensing images')
+    parser.add_argument("-c", "--config", type=str, default="configs/cdmask.py")
+    parser.add_argument("--output_dir", default=None)
+    parser.add_argument("--layer", default=None)
+    parser.add_argument("--imgsize", type=int, default=256)
+    return parser.parse_args()
+
+class ResizeTransform:
+    def __init__(self, im_h: int, im_w: int):
+        self.height = im_h
+        self.width = im_w
+
+    def __call__(self, x):
+        # input x: B, L, C
+        result = x.reshape(x.size(0),
+                           self.height,
+                           self.width,
+                           x.size(2))
+
+        # Bring the channels to the first dimension,
+        # like in CNNs.
+        # [batch_size, H, W, C] -> [batch, C, H, W]
+        result = result.permute(0, 3, 1, 2)
+
+        return result
+
+def main():
+    args = get_args()
+
+    if args.layer is None:
+        raise ValueError("Please ensure the parameter '--layer' is not None!\n e.g. --layer=model.net.decoderhead.LHBlock2.mlp_l")
+
+    cfg = Config.fromfile(args.config)
+
+    model = myTrain.load_from_checkpoint(cfg.test_ckpt_path, cfg=cfg)
+    model = model.to('cuda')
+
+    test_loader = build_dataloader(cfg.dataset_config, mode='test')
+
+    if args.output_dir:
+        base_dir = args.output_dir
+    else:
+        base_dir = os.path.dirname(cfg.test_ckpt_path)
+    gradcam_output_dir = os.path.join(base_dir, "grad_cam", args.layer)
+    if os.path.exists(gradcam_output_dir):
+        raise FileExistsError("Please ensure gradcam_output_dir does not exist!")
+
+    os.makedirs(gradcam_output_dir)
+
+    for batch in tqdm(test_loader):
+        target_layers = [eval(args.layer)]  # dotted name of the network layer, resolved against `model`
+        mask, img_id = batch[2].cuda(), batch[3]
+
+        cam = GradCAM(cfg, model=model.net, target_layers=target_layers, use_cuda=True,
+                      reshape_transform=ResizeTransform(im_h=args.imgsize, im_w=args.imgsize))
+        target_category = 1  # index of the "change" class
+
+        grayscale_cam_all = cam(input_tensor=(batch[0], batch[1]), target_category=target_category)
+
+        for i in range(grayscale_cam_all.shape[0]):
+            grayscale_cam = grayscale_cam_all[i, :]
+            visualization = show_cam_on_image(0,  # no base image: render the heatmap alone
+                                              grayscale_cam,
+                                              use_rgb=True)
+            fig = plt.figure()
+            ax = fig.add_subplot(111)
+            ax.imshow(visualization)
+            # ax = fig.add_subplot(122)
+            # ax.imshow(mask[i].cpu().numpy())
+            ax.set_xticks([])
+            ax.set_yticks([])
+            ax.spines['top'].set_visible(False)
+            ax.spines['right'].set_visible(False)
+            ax.spines['bottom'].set_visible(False)
+            ax.spines['left'].set_visible(False)
+            plt.savefig(os.path.join(gradcam_output_dir, '{}.png'.format(img_id[i])))
+            plt.close()
+
+
+if __name__ == '__main__':
+    main()
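The only difference from the CNN script is the `reshape_transform`: hooked transformer blocks emit token sequences of shape [B, L, C], which Grad-CAM must see as [B, C, H, W] feature maps. A quick check of what `ResizeTransform` does, with illustrative shapes (note that `im_h * im_w` must equal the token count L of the hooked layer, so `--imgsize` has to match that layer's token grid):

import torch

tokens = torch.randn(2, 16 * 16, 64)        # [B, L, C]: 2 images, a 16x16 token grid, 64 channels
resize = ResizeTransform(im_h=16, im_w=16)  # ResizeTransform from grad_cam_transformer.py
fmap = resize(tokens)                       # [B, L, C] -> [B, H, W, C] -> [B, C, H, W]
assert fmap.shape == (2, 64, 16, 16)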
tools/mask_convert.py
ADDED
@@ -0,0 +1,103 @@
+import numpy as np
+import argparse
+import glob
+import os
+import sys
+import torch
+import cv2
+import random
+import time
+import multiprocessing.pool as mpp
+import multiprocessing as mp
+SEED = 66
+
+def seed_everything(seed):
+    random.seed(seed)
+    os.environ['PYTHONHASHSEED'] = str(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False  # benchmark must stay off for deterministic runs
+def label2rgb(mask, mask_pred):
+    real_1 = (mask == 1)
+    real_0 = (mask == 0)
+    pred_1 = (mask_pred == 1)
+    pred_0 = (mask_pred == 0)
+
+    TP = np.logical_and(real_1, pred_1)
+    TN = np.logical_and(real_0, pred_0)
+    FN = np.logical_and(real_1, pred_0)
+    FP = np.logical_and(real_0, pred_1)
+
+    mask_TP = TP[np.newaxis, :, :]
+    mask_TN = TN[np.newaxis, :, :]
+    mask_FN = FN[np.newaxis, :, :]
+    mask_FP = FP[np.newaxis, :, :]
+
+    h, w = mask.shape[0], mask.shape[1]
+    mask_rgb = np.zeros(shape=(h, w, 3), dtype=np.uint8)
+    mask_rgb[np.all(mask_TP, axis=0)] = [255, 255, 255]  # TP
+    mask_rgb[np.all(mask_TN, axis=0)] = [0, 0, 0]  # TN
+    mask_rgb[np.all(mask_FN, axis=0)] = [0, 255, 0]  # FN
+    mask_rgb[np.all(mask_FP, axis=0)] = [255, 0, 0]  # FP
+
+    return mask_rgb
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--dataset", default="Vaihingen")
+    parser.add_argument("--mask-dir", default="data/Test/masks")
+    parser.add_argument("--output-mask-dir", default="data/Test/masks_rgb")
+    return parser.parse_args()
+
+def mask_save(inp):
+    (mask, mask_pred, masks_output_dir, file_name) = inp
+    out_mask_path = os.path.join(masks_output_dir, "{}.png".format(file_name))
+
+    label = label2rgb(mask.copy(), mask_pred.copy())
+
+    rgb_label = cv2.cvtColor(label, cv2.COLOR_BGR2RGB)
+    cv2.imwrite(out_mask_path, rgb_label)
+
+# def get_rgb(inp):
+#     (mask_path, masks_output_dir, dataset) = inp
+#     mask_filename = os.path.splitext(os.path.basename(mask_path))[0]
+#     mask_bgr = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)
+#     mask = cv2.cvtColor(mask_bgr, cv2.COLOR_BGR2RGB)
+#     if dataset == "LoveDA":
+#         rgb_label = loveda_label2rgb(mask.copy())
+#     elif dataset == "Vaihingen":
+#         rgb_label = vaihingen_label2rgb(mask.copy())
+#     elif dataset == "Potsdam":
+#         rgb_label = potsdam_label2rgb(mask.copy())
+#     elif dataset == "uavid":
+#         rgb_label = uavid_label2rgb(mask.copy())
+#     else: return
+#     #rgb_label = cv2.cvtColor(rgb_label, cv2.COLOR_RGB2BGR)
+
+#     out_mask_path_rgb = os.path.join(masks_output_dir, "{}.png".format(mask_filename))
+#     rgb_label = cv2.cvtColor(rgb_label, cv2.COLOR_BGR2RGB)
+#     cv2.imwrite(out_mask_path_rgb, rgb_label)

+# if __name__ == '__main__':
+#     base_path = "/home/xwma/lrr/rssegmentation/"
+#     args = parse_args()
+#     dataset = args.dataset

+#     seed_everything(SEED)
+#     masks_dir = args.mask_dir
+#     masks_output_dir = args.output_mask_dir
+#     masks_dir = base_path + masks_dir
+#     masks_output_dir = base_path + masks_output_dir

+#     mask_paths = glob.glob(os.path.join(masks_dir, "*.png"))
+#     inp = [(mask_path, masks_output_dir, dataset) for mask_path in mask_paths]
+#     if not os.path.exists(masks_output_dir):
+#         os.makedirs(masks_output_dir)

+#     t0 = time.time()
+#     mpp.Pool(processes=mp.cpu_count()).map(get_rgb, inp)
+#     t1 = time.time()
+#     split_time = t1 - t0
+#     print('images splitting spends: {} s'.format(split_time))
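`label2rgb` encodes the confusion matrix of a binary change map pixel by pixel: white for true positives, black for true negatives, green for false negatives, red for false positives. A toy example of the mapping:

import numpy as np

mask      = np.array([[1, 1], [0, 0]])  # ground truth
mask_pred = np.array([[1, 0], [1, 0]])  # prediction

rgb = label2rgb(mask, mask_pred)  # label2rgb from mask_convert.py
# rgb[0, 0] == [255, 255, 255]  real 1, pred 1 -> TP, white
# rgb[0, 1] == [0, 255, 0]      real 1, pred 0 -> FN, green
# rgb[1, 0] == [255, 0, 0]      real 0, pred 1 -> FP, red
# rgb[1, 1] == [0, 0, 0]        real 0, pred 0 -> TN, black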
tools/params_flops.py
ADDED
@@ -0,0 +1,55 @@
+import sys
+import torch
+sys.path.append('.')
+import argparse  # used by parse_args(); do not rely on `from train import *` re-exporting it
+from train import *
+from fvcore.nn import FlopCountAnalysis, flop_count_table, flop_count, parameter_count
+from rscd.models.backbones.lamba_util.csms6s import flops_selective_scan_fn, flops_selective_scan_ref, selective_scan_flop_jit
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='count params and flops')
+    parser.add_argument("-c", "--config", type=str, default="configs/cdlama.py")
+    parser.add_argument("--size", type=int, default=256)
+    args = parser.parse_args()
+    return args
+
+def flops_mamba(model, shape=(3, 224, 224)):
+    # shape = self.__input_shape__[1:]
+    supported_ops = {
+        "aten::silu": None,  # as relu is in _IGNORED_OPS
+        "aten::neg": None,  # as relu is in _IGNORED_OPS
+        "aten::exp": None,  # as relu is in _IGNORED_OPS
+        "aten::flip": None,  # as permute is in _IGNORED_OPS
+        # "prim::PythonOp.CrossScan": None,
+        # "prim::PythonOp.CrossMerge": None,
+        "prim::PythonOp.SelectiveScanCuda": selective_scan_flop_jit,
+        "prim::PythonOp.SelectiveScanMamba": selective_scan_flop_jit,
+        "prim::PythonOp.SelectiveScanOflex": selective_scan_flop_jit,
+        "prim::PythonOp.SelectiveScanCore": selective_scan_flop_jit,
+        "prim::PythonOp.SelectiveScanNRow": selective_scan_flop_jit,
+    }
+
+    model.cuda().eval()
+
+    input1 = torch.randn((1, *shape), device=next(model.parameters()).device)
+    input2 = torch.randn((1, *shape), device=next(model.parameters()).device)
+    params = parameter_count(model)[""]
+    Gflops, unsupported = flop_count(model=model, inputs=(input1, input2), supported_ops=supported_ops)
+
+    del model, input1, input2
+    # return sum(Gflops.values()) * 1e9
+    return f"params {params / 1e6:.2f} M  GFLOPs {sum(Gflops.values()):.2f}"
+
+if __name__ == "__main__":
+    args = parse_args()
+    cfg = Config.fromfile(args.config)
+    net = myTrain(cfg).net.cuda()
+
+    size = args.size
+    x = torch.rand((1, 3, size, size)).cuda()
+
+    net.eval()
+    flops = FlopCountAnalysis(net, (x, x))
+    print(flop_count_table(flops, max_depth=2))
+
+    print(flops_mamba(net, (3, size, size)))
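fvcore counts FLOPs by tracing the model; ops it does not recognize (here the custom selective-scan kernels) must be registered via `supported_ops`, which is why the script passes `selective_scan_flop_jit` handlers. A self-contained sketch of the same fvcore calls on a plain module, just to show what gets printed (module and shapes are illustrative):

import torch
import torch.nn as nn
from fvcore.nn import FlopCountAnalysis, flop_count_table, parameter_count

net = nn.Conv2d(3, 8, kernel_size=3, padding=1)
x = torch.rand(1, 3, 256, 256)

flops = FlopCountAnalysis(net, x)
print(flop_count_table(flops, max_depth=2))  # per-module FLOP table, as in the script
print(parameter_count(net)[""])              # total parameters: 3*8*3*3 + 8 = 224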
tools/utilss.py
ADDED
@@ -0,0 +1,249 @@
+import cv2
+import numpy as np
+from torch.nn import functional as F
+import torch
+
+class ActivationsAndGradients:
+    """ Class for extracting activations and
+    registering gradients from targeted intermediate layers """
+
+    def __init__(self, model, target_layers, reshape_transform):
+        self.model = model
+        self.gradients = []
+        self.activations = []
+        self.reshape_transform = reshape_transform
+        self.handles = []
+        for target_layer in target_layers:
+            self.handles.append(
+                target_layer.register_forward_hook(
+                    self.save_activation))
+            # Backward compatibility with older pytorch versions:
+            if hasattr(target_layer, 'register_full_backward_hook'):
+                self.handles.append(
+                    target_layer.register_full_backward_hook(
+                        self.save_gradient))
+            else:
+                self.handles.append(
+                    target_layer.register_backward_hook(
+                        self.save_gradient))
+
+    def save_activation(self, module, input, output):
+        activation = output
+        if self.reshape_transform is not None:
+            activation = self.reshape_transform(activation)
+        self.activations.append(activation.cpu().detach())
+
+    def save_gradient(self, module, grad_input, grad_output):
+        # Gradients are computed in reverse order
+        grad = grad_output[0]
+        if self.reshape_transform is not None:
+            grad = self.reshape_transform(grad)
+        self.gradients = [grad.cpu().detach()] + self.gradients
+
+    def __call__(self, x, y):
+        self.gradients = []
+        self.activations = []
+        return self.model(x, y)
+
+    def release(self):
+        for handle in self.handles:
+            handle.remove()
+
+
+class GradCAM:
+    def __init__(self,
+                 cfg,
+                 model,
+                 target_layers,
+                 reshape_transform=None,
+                 use_cuda=False):
+        self.cfg = cfg
+        self.model = model.eval()
+        self.target_layers = target_layers
+        self.reshape_transform = reshape_transform
+        self.cuda = use_cuda
+        if self.cuda:
+            self.model = model.cuda()
+        self.activations_and_grads = ActivationsAndGradients(
+            self.model, target_layers, reshape_transform)
+
+    """ Get a vector of weights for every channel in the target layer.
+        Methods that return weights channels,
+        will typically need to only implement this function. """
+
+    @staticmethod
+    def get_cam_weights(grads):
+        return np.mean(grads, axis=(2, 3), keepdims=True)
+
+    @staticmethod
+    def get_loss(output, target_category):
+        loss = 0
+        for i in range(len(target_category)):
+            loss = loss + output[i]
+        return loss
+
+    def get_cam_image(self, activations, grads):
+        weights = self.get_cam_weights(grads)
+        weighted_activations = weights * activations
+        cam = weighted_activations.sum(axis=1)
+
+        return cam
+
+    @staticmethod
+    def get_target_width_height(input_tensor):
+        width, height = input_tensor.size(-1), input_tensor.size(-2)
+        return width, height
+
+    def compute_cam_per_layer(self, input_tensor):
+        activations_list = [a.cpu().data.numpy()
+                            for a in self.activations_and_grads.activations]
+        grads_list = [g.cpu().data.numpy()
+                      for g in self.activations_and_grads.gradients]
+        target_size = self.get_target_width_height(input_tensor)
+
+        cam_per_target_layer = []
+        # Loop over the saliency image from every layer
+
+        for layer_activations, layer_grads in zip(activations_list, grads_list):
+            cam = self.get_cam_image(layer_activations, layer_grads)
+            cam[cam < 0] = 0  # ReLU: drop negative contributions before the min-max scaling in scale_cam_image
+            scaled = self.scale_cam_image(cam, target_size)
+            cam_per_target_layer.append(scaled[:, None, :])
+
+        return cam_per_target_layer
+
+    def aggregate_multi_layers(self, cam_per_target_layer):
+        cam_per_target_layer = np.concatenate(cam_per_target_layer, axis=1)
+        cam_per_target_layer = np.maximum(cam_per_target_layer, 0)
+        result = np.mean(cam_per_target_layer, axis=1)
+        return self.scale_cam_image(result)
+
+    @staticmethod
+    def scale_cam_image(cam, target_size=None):
+        result = []
+        for img in cam:
+            img = img - np.min(img)
+            img = img / (1e-7 + np.max(img))
+            if target_size is not None:
+                img = cv2.resize(img, target_size)
+            result.append(img)
+        result = np.float32(result)
+
+        return result
+
+    def __call__(self, input_tensor, target_category=None):
+        x, y = input_tensor
+        if self.cuda:
+            x = x.cuda()
+            y = y.cuda()
+
+        # Forward pass to get the network output logits (before softmax)
+        if self.cfg.net == 'cdmask':
+            o, outputs = self.activations_and_grads(x, y)
+            mask_cls_results = outputs["pred_logits"]
+            mask_pred_results = outputs["pred_masks"]
+            mask_pred_results = F.interpolate(
+                mask_pred_results,
+                scale_factor=(4, 4),
+                mode="bilinear",
+                align_corners=False,
+            )
+            mask_cls = F.softmax(mask_cls_results, dim=-1)[..., 1:]
+            mask_pred = mask_pred_results.sigmoid()
+            output = torch.einsum("bqc,bqhw->bchw", mask_cls, mask_pred)
+        else:
+            output = self.activations_and_grads(x, y)
+
+        if isinstance(target_category, int):
+            target_category = [target_category] * x.size(0)
+
+        if target_category is None:
+            target_category = np.argmax(output.cpu().data.numpy(), axis=-1)
+            print(f"category id: {target_category}")
+        else:
+            assert len(target_category) == x.size(0)
+
+        self.model.zero_grad()
+        loss = self.get_loss(output, target_category).sum()
+        loss.backward(retain_graph=True)
+
+        # In most of the saliency attribution papers, the saliency is
+        # computed with a single target layer.
+        # Commonly it is the last convolutional layer.
+        # Here we support passing a list with multiple target layers.
+        # It will compute the saliency image for every image,
+        # and then aggregate them (with a default mean aggregation).
+        # This gives you more flexibility in case you just want to
+        # use all conv layers for example, all Batchnorm layers,
+        # or something else.
+        cam_per_layer = self.compute_cam_per_layer(x)
+        return self.aggregate_multi_layers(cam_per_layer)
+
+    def __del__(self):
+        self.activations_and_grads.release()
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, exc_tb):
+        self.activations_and_grads.release()
+        if isinstance(exc_value, IndexError):
+            # Handle IndexError here...
+            print(
+                f"An exception occurred in CAM with block: {exc_type}. Message: {exc_value}")
+            return True
+
+
+def show_cam_on_image(img: np.ndarray,
+                      mask: np.ndarray,
+                      use_rgb: bool = False,
+                      colormap: int = cv2.COLORMAP_JET) -> np.ndarray:
+    """ This function overlays the cam mask on the image as a heatmap.
+    By default the heatmap is in BGR format.
+
+    :param img: The base image in RGB or BGR format.
+    :param mask: The cam mask.
+    :param use_rgb: Whether to use an RGB or BGR heatmap; this should be set to True if 'img' is in RGB format.
+    :param colormap: The OpenCV colormap to be used.
+    :returns: The default image with the cam overlay.
+    """
+
+    heatmap = cv2.applyColorMap(np.uint8(255 * mask), colormap)
+    if use_rgb:
+        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
+    heatmap = np.float32(heatmap) / 255
+
+    if np.max(img) > 1:
+        raise Exception(
+            "The input image should be np.float32 in the range [0, 1]")
+
+    cam = heatmap + img
+    cam = cam / np.max(cam)
+    return np.uint8(255 * cam)
+
+
+def center_crop_img(img: np.ndarray, size: int):
+    h, w, c = img.shape
+
+    if w == h == size:
+        return img
+
+    if w < h:
+        ratio = size / w
+        new_w = size
+        new_h = int(h * ratio)
+    else:
+        ratio = size / h
+        new_h = size
+        new_w = int(w * ratio)
+
+    img = cv2.resize(img, dsize=(new_w, new_h))
+
+    if new_w == size:
+        h = (new_h - size) // 2
+        img = img[h: h + size]
+    else:
+        w = (new_w - size) // 2
+        img = img[:, w: w + size]
+
+    return img
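`get_cam_weights`, `get_cam_image`, `compute_cam_per_layer`, and `scale_cam_image` together implement the standard Grad-CAM recipe: channel weights are the spatial mean of the gradients, the CAM is the weighted channel sum of the activations, negatives are clipped, and the result is min-max scaled. A tiny numpy restatement with illustrative shapes:

import numpy as np

acts  = np.random.rand(1, 4, 8, 8).astype(np.float32)  # activations [B, C, H, W]
grads = np.random.rand(1, 4, 8, 8).astype(np.float32)  # gradients   [B, C, H, W]

weights = np.mean(grads, axis=(2, 3), keepdims=True)   # [B, C, 1, 1], as in get_cam_weights
cam = (weights * acts).sum(axis=1)                     # [B, H, W],   as in get_cam_image
cam = np.maximum(cam, 0)                               # clip negatives, as in compute_cam_per_layer
cam = (cam - cam.min()) / (1e-7 + cam.max())           # min-max scale, as in scale_cam_image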