import argparse
import os

import imageio
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from data.Imagenet import Imagenet_Segmentation
from numpy import *
from PIL import Image
from sklearn.metrics import precision_recall_curve
from torch.utils.data import DataLoader
from tqdm import tqdm
from utils import render
from utils.iou import IoU
from utils.metrices import *
from utils.saver import Saver
from ViT_explanation_generator import LRP, Baselines
from ViT_LRP import vit_base_patch16_224 as vit_LRP
from ViT_new import vit_base_patch16_224
from ViT_orig_LRP import vit_base_patch16_224 as vit_orig_LRP

# Non-interactive backend: figures are only ever written to disk.
plt.switch_backend("agg")

# hyperparameters
num_workers = 0  # intended DataLoader worker count
batch_size = 1

# Class names; this list is not referenced elsewhere in this script.
cls = [
    "airplane",
    "bicycle",
    "bird",
    "boat",
    "bottle",
    "bus",
    "car",
    "cat",
    "chair",
    "cow",
    "dining table",
    "dog",
    "horse",
    "motobike",
    "person",
    "potted plant",
    "sheep",
    "sofa",
    "train",
    "tv",
]

# Args
parser = argparse.ArgumentParser(description="Training multi-class classifier")
parser.add_argument(
    "--arc", type=str, default="vgg", metavar="N", help="Model architecture"
)
parser.add_argument(
    "--train_dataset", type=str, default="imagenet", metavar="N", help="Testing Dataset"
)
# The default must be one of ``choices``: argparse does not validate string
# defaults against ``choices``. The previous default ("grad_rollout") therefore
# parsed fine and only crashed inside eval_batch, which has no branch for it.
# NOTE(review): "lrp" is accepted here but eval_batch has no branch for it either.
parser.add_argument(
    "--method",
    type=str,
    default="transformer_attribution",
    choices=[
        "rollout",
        "lrp",
        "transformer_attribution",
        "full_lrp",
        "lrp_last_layer",
        "attn_last_layer",
        "attn_gradcam",
    ],
    help="",
)
parser.add_argument("--thr", type=float, default=0.0, help="threshold")
parser.add_argument("--K", type=int, default=1, help="new - top K results")
parser.add_argument("--save-img", action="store_true", default=False, help="")
parser.add_argument("--no-ia", action="store_true", default=False, help="")
parser.add_argument("--no-fx", action="store_true", default=False, help="")
parser.add_argument("--no-fgx", action="store_true", default=False, help="")
parser.add_argument("--no-m", action="store_true", default=False, help="")
def _str2bool(value):
    """Convert a command-line string such as ``"True"``/``"false"``/``"1"`` to a bool.

    Used instead of ``type=bool``: ``bool("False")`` is ``True``, so with
    ``type=bool`` every non-empty value -- including ``False`` -- enabled the flag.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value, got {!r}".format(value)
    )


parser.add_argument("--no-reg", action="store_true", default=False, help="")
parser.add_argument("--is-ablation", type=_str2bool, default=False, help="")
parser.add_argument("--imagenet-seg-path", type=str, required=True)
args = parser.parse_args()
args.checkname = args.method + "_" + args.arc

# Methods that _generate_relevance implements. ``--method`` may accept other
# values (and a default outside its choices is not rejected by argparse), so
# fail fast here instead of crashing on the first batch after loading models.
SUPPORTED_METHODS = (
    "rollout",
    "full_lrp",
    "transformer_attribution",
    "lrp_last_layer",
    "attn_last_layer",
    "attn_gradcam",
)
if args.method not in SUPPORTED_METHODS:
    parser.error(
        "method {!r} is not implemented; choose one of: {}".format(
            args.method, ", ".join(SUPPORTED_METHODS)
        )
    )

alpha = 2  # not used below in this script

cuda = torch.cuda.is_available()
device = torch.device("cuda" if cuda else "cpu")

# Define Saver and the output directory tree:
#   results/input, results/explain/img, results/explain/np
saver = Saver(args)
saver.results_dir = os.path.join(saver.experiment_dir, "results")
os.makedirs(os.path.join(saver.results_dir, "input"), exist_ok=True)
args.exp_img_path = os.path.join(saver.results_dir, "explain/img")
os.makedirs(args.exp_img_path, exist_ok=True)
args.exp_np_path = os.path.join(saver.results_dir, "explain/np")
os.makedirs(args.exp_np_path, exist_ok=True)

# Data
normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
test_img_trans = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize,
    ]
)
test_lbl_trans = transforms.Compose(
    [
        # nearest-neighbour so label values are not blended by resizing
        transforms.Resize((224, 224), Image.NEAREST),
    ]
)

ds = Imagenet_Segmentation(
    args.imagenet_seg_path, transform=test_img_trans, target_transform=test_lbl_trans
)
# Use the declared ``num_workers`` hyperparameter instead of a hard-coded value.
dl = DataLoader(
    ds,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    drop_last=False,
)

# Model
model = vit_base_patch16_224(pretrained=True).cuda()
baselines = Baselines(model)

# LRP
model_LRP = vit_LRP(pretrained=True).cuda()
model_LRP.eval()
lrp = LRP(model_LRP)

# orig LRP
model_orig_LRP = vit_orig_LRP(pretrained=True).cuda()
model_orig_LRP.eval()
orig_lrp = LRP(model_orig_LRP)

metric = IoU(2, ignore_index=-1)  # not used below in this script

iterator = tqdm(dl)

model.eval()


def compute_pred(output):
    """Return a one-hot CUDA float tensor of shape (1, 1000) for the arg-max class.

    The arg-max is taken over dim 1 of *output*. Because the prediction is
    squeezed and then re-expanded to a leading dim of 1, this only works for a
    batch of size 1.
    """
    pred = output.data.max(1, keepdim=True)[1]  # index of the max log-probability
    T = pred.squeeze().cpu().numpy()
    T = np.expand_dims(T, 0)
    T = (T[:, np.newaxis] == np.arange(1000)) * 1.0
    T = torch.from_numpy(T).type(torch.FloatTensor)
    Tt = T.cuda()
    return Tt


def _save_input(image, labels, index):
    """Save the min-max rescaled input image and its ground-truth mask as PNGs."""
    img = image[0].permute(1, 2, 0).data.cpu().numpy()
    img = 255 * (img - img.min()) / (img.max() - img.min())
    img = img.astype("uint8")
    Image.fromarray(img, "RGB").save(
        os.path.join(saver.results_dir, "input/{}_input.png".format(index))
    )
    mask_rgb = labels.repeat(3, 1, 1).permute(1, 2, 0).data.cpu().numpy() * 255
    Image.fromarray(mask_rgb.astype("uint8"), "RGB").save(
        os.path.join(saver.results_dir, "input/{}_mask.png".format(index))
    )


def _generate_relevance(image):
    """Return the relevance map for ``args.method`` as a (batch_size, 1, H, W) tensor.

    ``full_lrp`` yields a 224x224 map; every other method yields a 14x14
    token grid that the caller upsamples.

    Raises:
        ValueError: if ``args.method`` has no implementation here.
    """
    method = args.method
    # segmentation test for the rollout baseline
    if method == "rollout":
        return baselines.generate_rollout(image.cuda(), start_layer=1).reshape(
            batch_size, 1, 14, 14
        )
    # segmentation test for the LRP baseline (this is full LRP, not partial)
    if method == "full_lrp":
        return orig_lrp.generate_LRP(image.cuda(), method="full").reshape(
            batch_size, 1, 224, 224
        )
    # segmentation test for our method
    if method == "transformer_attribution":
        return lrp.generate_LRP(
            image.cuda(), start_layer=1, method="transformer_attribution"
        ).reshape(batch_size, 1, 14, 14)
    # segmentation test for the partial LRP baseline (last attn layer)
    if method == "lrp_last_layer":
        return orig_lrp.generate_LRP(
            image.cuda(), method="last_layer", is_ablation=args.is_ablation
        ).reshape(batch_size, 1, 14, 14)
    # segmentation test for the raw attention baseline (last attn layer)
    if method == "attn_last_layer":
        return orig_lrp.generate_LRP(
            image.cuda(), method="last_layer_attn", is_ablation=args.is_ablation
        ).reshape(batch_size, 1, 14, 14)
    # segmentation test for the GradCam baseline (last attn layer)
    if method == "attn_gradcam":
        return baselines.generate_cam_attn(image.cuda()).reshape(
            batch_size, 1, 14, 14
        )
    raise ValueError("Unsupported method: {}".format(method))


def _save_explanation(Res, Res_1, index):
    """Save the binary foreground mask and a seismic heat-map of *Res*, both at 64x64."""
    mask = F.interpolate(Res_1, [64, 64], mode="bilinear")
    mask = mask[0].squeeze().data.cpu().numpy()
    mask = (255 * mask).astype("uint8")
    imageio.imsave(
        os.path.join(args.exp_img_path, "mask_" + str(index) + ".jpg"), mask
    )

    relevance = F.interpolate(Res, [64, 64], mode="bilinear")
    relevance = relevance[0].permute(1, 2, 0).data.cpu().numpy()
    hm = np.sum(relevance, axis=-1)
    maps = (render.hm_to_rgb(hm, scaling=3, sigma=1, cmap="seismic") * 255).astype(
        np.uint8
    )
    imageio.imsave(
        os.path.join(args.exp_img_path, "heatmap_" + str(index) + ".jpg"), maps
    )


def eval_batch(image, labels, evaluator, index):
    """Explain one batch with ``args.method`` and score it against *labels*.

    Returns:
        (correct, labeled, inter, union, ap, f1, pred, target), where *pred* and
        *target* are flattened numpy arrays of per-pixel scores and labels that
        are collected for the global precision/recall curve.
    """
    evaluator.zero_grad()

    if args.save_img:
        _save_input(image, labels, index)

    image.requires_grad_()
    # The forward output is not used here; the pass is kept as in the original.
    # NOTE(review): confirm whether the explanation generators rely on it.
    evaluator(image)

    Res = _generate_relevance(image)

    if args.method != "full_lrp":
        # interpolate to full image size (224,224)
        Res = F.interpolate(Res, scale_factor=16, mode="bilinear").cuda()

    # Min-max normalise; a constant map yields 0/0 = NaN everywhere.
    Res = (Res - Res.min()) / (Res.max() - Res.min())
    # threshold between FG and BG is the mean
    ret = Res.mean()
    Res_1 = Res.gt(ret).type(Res.type())
    Res_0 = Res.le(ret).type(Res.type())

    # Res_1_AP deliberately aliases Res: zeroing its NaNs below also cleans Res,
    # which `pred` is computed from.
    Res_1_AP = Res
    Res_0_AP = 1 - Res
    Res_1_AP[torch.isnan(Res_1_AP)] = 0
    Res_0_AP[torch.isnan(Res_0_AP)] = 0

    pred = Res.clamp(min=args.thr) / Res.max()
    pred = pred.view(-1).data.cpu().numpy()
    # A degenerate (all-zero) map divides by zero above. precision_recall_curve
    # rejects NaN/inf, so score such pixels as 0 rather than abort the whole run.
    pred = np.nan_to_num(pred, nan=0.0, posinf=0.0, neginf=0.0)
    target = labels.view(-1).data.cpu().numpy()

    output = torch.cat((Res_0, Res_1), 1)
    output_AP = torch.cat((Res_0_AP, Res_1_AP), 1)

    if args.save_img:
        _save_explanation(Res, Res_1, index)

    # Segmentation results
    correct, labeled = batch_pix_accuracy(output[0].data.cpu(), labels[0])
    inter, union = batch_intersection_union(output[0].data.cpu(), labels[0], 2)
    ap = np.nan_to_num(get_ap_scores(output_AP, labels))
    f1 = np.nan_to_num(get_f1_scores(output[0, 1].data.cpu(), labels[0]))

    return correct, labeled, inter, union, ap, f1, pred, target


total_inter, total_union, total_correct, total_label = (
    np.int64(0),
    np.int64(0),
    np.int64(0),
    np.int64(0),
)
total_ap, total_f1 = [], []

predictions, targets = [], []
for batch_idx, (image, labels) in enumerate(iterator):
    images = image.cuda()
    labels = labels.cuda()

    correct, labeled, inter, union, ap, f1, pred, target = eval_batch(
        images, labels, model, batch_idx
    )

    predictions.append(pred)
    targets.append(target)

    total_correct += correct.astype("int64")
    total_label += labeled.astype("int64")
    total_inter += inter.astype("int64")
    total_union += union.astype("int64")
    total_ap.append(ap)
    total_f1.append(f1)

    # Running metrics; np.spacing(1) guards against division by zero.
    pixAcc = (
        np.float64(1.0)
        * total_correct
        / (np.spacing(1, dtype=np.float64) + total_label)
    )
    # Named so it does not shadow the imported ``IoU`` class.
    iou_per_class = (
        np.float64(1.0)
        * total_inter
        / (np.spacing(1, dtype=np.float64) + total_union)
    )
    mIoU = iou_per_class.mean()
    mAp = np.mean(total_ap)
    mF1 = np.mean(total_f1)
    iterator.set_description(
        "pixAcc: %.4f, mIoU: %.4f, mAP: %.4f, mF1: %.4f" % (pixAcc, mIoU, mAp, mF1)
    )

predictions = np.concatenate(predictions)
targets = np.concatenate(targets)
pr, rc, thr = precision_recall_curve(targets, predictions)
np.save(os.path.join(saver.experiment_dir, "precision.npy"), pr)
np.save(os.path.join(saver.experiment_dir, "recall.npy"), rc)

plt.figure()
plt.plot(rc, pr)
plt.savefig(os.path.join(saver.experiment_dir, "PR_curve_{}.png".format(args.method)))

txtfile = os.path.join(saver.experiment_dir, "result_mIoU_%.4f.txt" % mIoU)
summary_lines = (
    "Mean IoU over %d classes: %.4f\n" % (2, mIoU),
    "Pixel-wise Accuracy: %2.2f%%\n" % (pixAcc * 100),
    "Mean AP over %d classes: %.4f\n" % (2, mAp),
    "Mean F1 over %d classes: %.4f\n" % (2, mF1),
)
# Context manager guarantees the results file is closed even if printing fails.
with open(txtfile, "w") as fh:
    for line in summary_lines:
        print(line)
    fh.writelines(summary_lines)