# Provenance (residue of the Hugging Face file viewer, kept as a comment):
# uploaded by sayakpaul (HF staff), commit c4b2b37 ("add files"), 11.7 kB.
import argparse
import os
import imageio
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from data.Imagenet import Imagenet_Segmentation
from numpy import *
from PIL import Image
from sklearn.metrics import precision_recall_curve
from torch.utils.data import DataLoader
from tqdm import tqdm
from utils import render
from utils.iou import IoU
from utils.metrices import *
from utils.saver import Saver
from ViT_explanation_generator import LRP, Baselines
from ViT_LRP import vit_base_patch16_224 as vit_LRP
from ViT_new import vit_base_patch16_224
from ViT_orig_LRP import vit_base_patch16_224 as vit_orig_LRP
# Select matplotlib's "agg" backend before any figure is created.
plt.switch_backend("agg")

# hyperparameters
num_workers, batch_size = 0, 1

# The 20 category names, reproduced verbatim (original spelling preserved).
cls = [
    "airplane", "bicycle", "bird", "boat", "bottle",
    "bus", "car", "cat", "chair", "cow",
    "dining table", "dog", "horse", "motobike", "person",
    "potted plant", "sheep", "sofa", "train", "tv",
]
# Args
def _str2bool(value):
    """Parse a command-line boolean.

    ``type=bool`` is not usable with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy, so ``--is-ablation False`` would
    silently enable the flag.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ("1", "true", "t", "yes", "y", "on"):
        return True
    if text in ("", "0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got %r" % (value,))


parser = argparse.ArgumentParser(description="Training multi-class classifier")
parser.add_argument(
    "--arc", type=str, default="vgg", metavar="N", help="Model architecture"
)
parser.add_argument(
    "--train_dataset", type=str, default="imagenet", metavar="N", help="Testing Dataset"
)
parser.add_argument(
    "--method",
    type=str,
    # The previous default, "grad_rollout", was not one of the choices below
    # (argparse does not validate defaults against ``choices``) and is not
    # handled by eval_batch, so running without --method crashed.
    default="transformer_attribution",
    choices=[
        "rollout",
        "lrp",
        "transformer_attribution",
        "full_lrp",
        "lrp_last_layer",
        "attn_last_layer",
        "attn_gradcam",
    ],
    help="",
)
parser.add_argument("--thr", type=float, default=0.0, help="threshold")
parser.add_argument("--K", type=int, default=1, help="new - top K results")
parser.add_argument("--save-img", action="store_true", default=False, help="")
parser.add_argument("--no-ia", action="store_true", default=False, help="")
parser.add_argument("--no-fx", action="store_true", default=False, help="")
parser.add_argument("--no-fgx", action="store_true", default=False, help="")
parser.add_argument("--no-m", action="store_true", default=False, help="")
parser.add_argument("--no-reg", action="store_true", default=False, help="")
# Accepts "--is-ablation True/False" as before (now parsed correctly) and also
# a bare "--is-ablation", which means True.
parser.add_argument(
    "--is-ablation", type=_str2bool, nargs="?", const=True, default=False, help=""
)
parser.add_argument("--imagenet-seg-path", type=str, required=True)
args = parser.parse_args()

# Name of this run, built from the explanation method and the architecture.
args.checkname = args.method + "_" + args.arc

alpha = 2
cuda = torch.cuda.is_available()
device = torch.device("cuda" if cuda else "cpu")

# Define Saver
saver = Saver(args)
saver.results_dir = os.path.join(saver.experiment_dir, "results")
args.exp_img_path = os.path.join(saver.results_dir, "explain/img")
args.exp_np_path = os.path.join(saver.results_dir, "explain/np")

# os.makedirs creates missing parents ("results", "results/explain"), and
# exist_ok=True removes the check-then-create race of the os.path.exists form.
for _out_dir in (
    os.path.join(saver.results_dir, "input"),
    args.exp_img_path,
    args.exp_np_path,
):
    os.makedirs(_out_dir, exist_ok=True)
# Data
_input_size = (224, 224)

# Images: resize, convert to a tensor, then normalise every channel with
# mean 0.5 and std 0.5.
normalize = transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3)
test_img_trans = transforms.Compose(
    [transforms.Resize(_input_size), transforms.ToTensor(), normalize]
)

# Labels: resize only, with nearest-neighbour resampling.
test_lbl_trans = transforms.Compose([transforms.Resize(_input_size, Image.NEAREST)])

ds = Imagenet_Segmentation(
    args.imagenet_seg_path,
    transform=test_img_trans,
    target_transform=test_lbl_trans,
)

# NOTE(review): the loader uses a literal num_workers=1 rather than the
# module-level ``num_workers`` (0) -- confirm which one is intended.
dl = DataLoader(
    ds,
    batch_size=batch_size,
    shuffle=False,
    num_workers=1,
    drop_last=False,
)
# Model
model = vit_base_patch16_224(pretrained=True).cuda()
baselines = Baselines(model)

# LRP (eval() is chained; it is applied before the explainer is built, as before)
model_LRP = vit_LRP(pretrained=True).cuda().eval()
lrp = LRP(model_LRP)

# orig LRP
model_orig_LRP = vit_orig_LRP(pretrained=True).cuda().eval()
orig_lrp = LRP(model_orig_LRP)

metric = IoU(2, ignore_index=-1)

# Progress-bar wrapper around the data loader.
iterator = tqdm(dl)

model.eval()
def compute_pred(output, num_classes=1000):
    """Return a one-hot float tensor marking the arg-max class of each sample.

    Args:
        output: 2-D score tensor; the maximum is taken along dimension 1.
        num_classes: width of the one-hot encoding. Defaults to 1000, the
            value that used to be hard-coded.

    Returns:
        A float32 tensor of shape ``(output.shape[0], num_classes)`` with a 1.0
        at each sample's arg-max index. A row is all zeros if that index is
        ``>= num_classes``. The tensor is placed on CUDA when available (as
        before) and on the CPU otherwise, instead of raising.

    The previous implementation only worked for a batch of exactly one sample;
    any other batch size failed on a broadcasting error.
    """
    # Index of the largest score per sample, shape (batch, 1).
    pred = output.data.max(1, keepdim=True)[1].cpu()
    class_ids = torch.arange(num_classes, dtype=pred.dtype)
    # (batch, 1) == (num_classes,) broadcasts to (batch, num_classes).
    one_hot = (pred == class_ids).float()
    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return one_hot.to(target_device)
def _save_input_and_mask(image, labels, index):
    """Write the min-max rescaled input image and its label mask as PNG files."""
    img = image[0].permute(1, 2, 0).data.cpu().numpy()
    img = 255 * (img - img.min()) / (img.max() - img.min())
    img = img.astype("uint8")
    Image.fromarray(img, "RGB").save(
        os.path.join(saver.results_dir, "input/{}_input.png".format(index))
    )
    mask = (labels.repeat(3, 1, 1).permute(1, 2, 0).data.cpu().numpy() * 255).astype(
        "uint8"
    )
    Image.fromarray(mask, "RGB").save(
        os.path.join(saver.results_dir, "input/{}_mask.png".format(index))
    )


def _compute_relevance(image):
    """Run the explanation method named by ``args.method``; return a 224x224 map.

    Raises:
        ValueError: if ``args.method`` has no implementation here. Previously
            this surfaced later as an unbound-variable error.
    """
    method = args.method
    if method == "rollout":
        # segmentation test for the rollout baseline
        res = baselines.generate_rollout(image.cuda(), start_layer=1).reshape(
            batch_size, 1, 14, 14
        )
    elif method == "full_lrp":
        # segmentation test for the LRP baseline (this is full LRP, not partial)
        res = orig_lrp.generate_LRP(image.cuda(), method="full").reshape(
            batch_size, 1, 224, 224
        )
    elif method == "transformer_attribution":
        # segmentation test for our method
        res = lrp.generate_LRP(
            image.cuda(), start_layer=1, method="transformer_attribution"
        ).reshape(batch_size, 1, 14, 14)
    elif method == "lrp_last_layer":
        # segmentation test for the partial LRP baseline (last attn layer)
        res = orig_lrp.generate_LRP(
            image.cuda(), method="last_layer", is_ablation=args.is_ablation
        ).reshape(batch_size, 1, 14, 14)
    elif method == "attn_last_layer":
        # segmentation test for the raw attention baseline (last attn layer)
        res = orig_lrp.generate_LRP(
            image.cuda(), method="last_layer_attn", is_ablation=args.is_ablation
        ).reshape(batch_size, 1, 14, 14)
    elif method == "attn_gradcam":
        # segmentation test for the GradCam baseline (last attn layer)
        res = baselines.generate_cam_attn(image.cuda()).reshape(batch_size, 1, 14, 14)
    else:
        raise ValueError(
            "explanation method {!r} is not supported by eval_batch".format(method)
        )

    if method != "full_lrp":
        # interpolate the 14x14 patch map to full image size (224,224)
        res = torch.nn.functional.interpolate(
            res, scale_factor=16, mode="bilinear"
        ).cuda()
    return res


def _save_explanation(relevance_map, fg_mask, index):
    """Write 64x64 JPEG previews of the foreground mask and the relevance heatmap."""
    mask = F.interpolate(fg_mask, [64, 64], mode="bilinear")
    mask = mask[0].squeeze().data.cpu().numpy()
    mask = (255 * mask).astype("uint8")
    imageio.imsave(
        os.path.join(args.exp_img_path, "mask_" + str(index) + ".jpg"), mask
    )

    relevance = F.interpolate(relevance_map, [64, 64], mode="bilinear")
    relevance = relevance[0].permute(1, 2, 0).data.cpu().numpy()
    hm = np.sum(relevance, axis=-1)
    maps = (render.hm_to_rgb(hm, scaling=3, sigma=1, cmap="seismic") * 255).astype(
        np.uint8
    )
    imageio.imsave(
        os.path.join(args.exp_img_path, "heatmap_" + str(index) + ".jpg"), maps
    )


def eval_batch(image, labels, evaluator, index):
    """Explain one batch and score the explanation as a 2-class segmentation.

    The relevance map is min-max normalised and split into foreground and
    background at its mean. The resulting masks are compared with ``labels``.

    Args:
        image: input batch passed to ``evaluator`` and to the explainer.
        labels: ground-truth mask for the batch
            (NOTE(review): assumed to be a binary mask -- confirm against the
            dataset class).
        evaluator: model whose gradients are reset and which is run forward.
        index: batch index, used only to name the saved image files.

    Returns:
        ``(correct, labeled, inter, union, ap, f1, pred, target)``: pixel
        accuracy counts, intersection/union counts, AP and F1 scores, then the
        flattened relevance scores and flattened labels as numpy arrays.

    Raises:
        ValueError: if ``args.method`` is not one of the implemented methods.

    Relies on module-level ``args``, ``saver``, ``baselines``, ``lrp``,
    ``orig_lrp`` and ``batch_size``.
    """
    evaluator.zero_grad()

    # Save input image
    if args.save_img:
        _save_input_and_mask(image, labels, index)

    # requires_grad_() sets the flag in place and returns the same tensor, so
    # the former extra ``image.requires_grad = True`` was redundant.
    image = image.requires_grad_()

    # Forward pass; the output is not used below.
    # NOTE(review): presumably kept for side effects inside the model --
    # confirm before removing.
    evaluator(image)

    relevance_map = _compute_relevance(image)

    # threshold between FG and BG is the mean
    relevance_map = (relevance_map - relevance_map.min()) / (
        relevance_map.max() - relevance_map.min()
    )
    mean_relevance = relevance_map.mean()
    fg_mask = relevance_map.gt(mean_relevance).type(relevance_map.type())
    bg_mask = relevance_map.le(mean_relevance).type(relevance_map.type())
    bg_scores = 1 - relevance_map

    # A constant map gives 0/0 = NaN above; zero those entries in place. The
    # hard masks come from comparisons and can never hold NaN, so they need no
    # such pass.
    relevance_map[torch.isnan(relevance_map)] = 0
    bg_scores[torch.isnan(bg_scores)] = 0

    # TEST
    pred = relevance_map.clamp(min=args.thr) / relevance_map.max()
    pred = pred.view(-1).data.cpu().numpy()
    target = labels.view(-1).data.cpu().numpy()

    # Channel 0 = background, channel 1 = foreground.
    output = torch.cat((bg_mask, fg_mask), 1)
    output_AP = torch.cat((bg_scores, relevance_map), 1)

    if args.save_img:
        _save_explanation(relevance_map, fg_mask, index)

    # Evaluate Segmentation
    correct, labeled = batch_pix_accuracy(output[0].data.cpu(), labels[0])
    inter, union = batch_intersection_union(output[0].data.cpu(), labels[0], 2)
    ap = np.nan_to_num(get_ap_scores(output_AP, labels))
    f1 = np.nan_to_num(get_f1_scores(output[0, 1].data.cpu(), labels[0]))

    return (
        correct,
        labeled,
        inter,
        union,
        ap,
        f1,
        pred,
        target,
    )
# Running totals over the whole dataset.
total_inter, total_union, total_correct, total_label = (
    np.int64(0),
    np.int64(0),
    np.int64(0),
    np.int64(0),
)
total_ap, total_f1 = [], []
predictions, targets = [], []

# Tiny positive constant that keeps the ratios below finite when a
# denominator is zero; hoisted out of the loop because it never changes.
_eps = np.spacing(1, dtype=np.float64)

for batch_idx, (image, labels) in enumerate(iterator):
    # The former "blur" branch was removed: argparse ``choices`` never lets
    # args.method take that value, so it was unreachable.
    images = image.cuda()
    labels = labels.cuda()

    correct, labeled, inter, union, ap, f1, pred, target = eval_batch(
        images, labels, model, batch_idx
    )

    predictions.append(pred)
    targets.append(target)

    total_correct += correct.astype("int64")
    total_label += labeled.astype("int64")
    total_inter += inter.astype("int64")
    total_union += union.astype("int64")
    total_ap.append(ap)
    total_f1.append(f1)

    # Refresh the running metrics shown in the progress bar.
    pixAcc = np.float64(1.0) * total_correct / (_eps + total_label)
    # Named ``iou`` so the imported ``IoU`` class is no longer shadowed.
    iou = np.float64(1.0) * total_inter / (_eps + total_union)
    mIoU = iou.mean()
    mAp = np.mean(total_ap)
    mF1 = np.mean(total_f1)
    iterator.set_description(
        "pixAcc: %.4f, mIoU: %.4f, mAP: %.4f, mF1: %.4f" % (pixAcc, mIoU, mAp, mF1)
    )

if not predictions:
    # Without this guard the failure is an opaque np.concatenate error.
    raise ValueError(
        "no batches were evaluated; the dataset at %r appears to be empty"
        % (args.imagenet_seg_path,)
    )

# Precision/recall curve over every pixel of every evaluated image.
predictions = np.concatenate(predictions)
targets = np.concatenate(targets)
pr, rc, thr = precision_recall_curve(targets, predictions)
np.save(os.path.join(saver.experiment_dir, "precision.npy"), pr)
np.save(os.path.join(saver.experiment_dir, "recall.npy"), rc)

plt.figure()
plt.plot(rc, pr)
plt.savefig(os.path.join(saver.experiment_dir, "PR_curve_{}.png".format(args.method)))

txtfile = os.path.join(saver.experiment_dir, "result_mIoU_%.4f.txt" % mIoU)
# txtfile = 'result_mIoU_%.4f.txt' % mIoU

# Build each summary line once; it goes to both stdout and the results file.
_summary = [
    "Mean IoU over %d classes: %.4f\n" % (2, mIoU),
    "Pixel-wise Accuracy: %2.2f%%\n" % (pixAcc * 100),
    "Mean AP over %d classes: %.4f\n" % (2, mAp),
    "Mean F1 over %d classes: %.4f\n" % (2, mF1),
]
for _line in _summary:
    print(_line)

# The context manager closes the file even if a write fails.
with open(txtfile, "w") as fh:
    fh.writelines(_summary)