import argparse
import os

import h5py

# Import saliency methods and models
from misc_functions import *
from torchvision.datasets import ImageNet
from tqdm import tqdm

# torch, numpy and torchvision.transforms are used directly below, so import
# them explicitly instead of relying on the wildcard import above.
import numpy as np
import torch
from torchvision import transforms

from ViT_explanation_generator import LRP, Baselines
from ViT_LRP import vit_base_patch16_224 as vit_LRP
from ViT_new import vit_base_patch16_224
from ViT_orig_LRP import vit_base_patch16_224 as vit_orig_LRP
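
# Example invocation (script file name and dataset path are illustrative):
#
#   python generate_visualizations.py \
#       --method transformer_attribution \
#       --vis-class top \
#       --imagenet-validation-path /path/to/imagenet
#
# Saliency maps, input images and targets are appended to
# <method_dir>/results.hdf5 as the datasets "vis", "image" and "target".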


def normalize(tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]):
    # In-place channel-wise normalization of an NCHW image batch.
    dtype = tensor.dtype
    mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device)
    std = torch.as_tensor(std, dtype=dtype, device=tensor.device)
    tensor.sub_(mean[None, :, None, None]).div_(std[None, :, None, None])
    return tensor


def compute_saliency_and_save(args):
    first = True
    with h5py.File(os.path.join(args.method_dir, "results.hdf5"), "a") as f:
        data_cam = f.create_dataset(
            "vis",
            (1, 1, 224, 224),
            maxshape=(None, 1, 224, 224),
            dtype=np.float32,
            compression="gzip",
        )
        data_image = f.create_dataset(
            "image",
            (1, 3, 224, 224),
            maxshape=(None, 3, 224, 224),
            dtype=np.float32,
            compression="gzip",
        )
        data_target = f.create_dataset(
            "target", (1,), maxshape=(None,), dtype=np.int32, compression="gzip"
        )
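
        # Each HDF5 dataset starts with a single placeholder row and is grown
        # batch-by-batch below (maxshape=None on axis 0), so results can be
        # streamed to disk without knowing the total dataset size up front.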
        for batch_idx, (data, target) in enumerate(tqdm(sample_loader)):
            if first:
                first = False
                data_cam.resize(data_cam.shape[0] + data.shape[0] - 1, axis=0)
                data_image.resize(data_image.shape[0] + data.shape[0] - 1, axis=0)
                data_target.resize(data_target.shape[0] + data.shape[0] - 1, axis=0)
            else:
                data_cam.resize(data_cam.shape[0] + data.shape[0], axis=0)
                data_image.resize(data_image.shape[0] + data.shape[0], axis=0)
                data_target.resize(data_target.shape[0] + data.shape[0], axis=0)

            # Add data
            data_image[-data.shape[0] :] = data.data.cpu().numpy()
            data_target[-data.shape[0] :] = target.data.cpu().numpy()

            target = target.to(device)

            data = normalize(data)
            data = data.to(device)
            data.requires_grad_()
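
            # Pick the class to explain: the ground-truth label when
            # --vis-class is "target"; otherwise index stays None and the
            # explanation generators fall back to their own default
            # (typically the model's top prediction). Every method below
            # yields a 14x14 patch-level relevance map except full_lrp,
            # which is already at 224x224 pixel resolution.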
            index = None
            if args.vis_class == "target":
                index = target

            if args.method == "rollout":
                Res = baselines.generate_rollout(data, start_layer=1).reshape(
                    data.shape[0], 1, 14, 14
                )
                # Res = Res - Res.mean()
            elif args.method == "lrp":
                Res = lrp.generate_LRP(data, start_layer=1, index=index).reshape(
                    data.shape[0], 1, 14, 14
                )
                # Res = Res - Res.mean()
            elif args.method == "transformer_attribution":
                Res = lrp.generate_LRP(
                    data, start_layer=1, method="grad", index=index
                ).reshape(data.shape[0], 1, 14, 14)
                # Res = Res - Res.mean()
            elif args.method == "full_lrp":
                Res = orig_lrp.generate_LRP(data, method="full", index=index).reshape(
                    data.shape[0], 1, 224, 224
                )
                # Res = Res - Res.mean()
            elif args.method == "lrp_last_layer":
                Res = orig_lrp.generate_LRP(
                    data, method="last_layer", is_ablation=args.is_ablation, index=index
                ).reshape(data.shape[0], 1, 14, 14)
                # Res = Res - Res.mean()
            elif args.method == "attn_last_layer":
                Res = lrp.generate_LRP(
                    data, method="last_layer_attn", is_ablation=args.is_ablation
                ).reshape(data.shape[0], 1, 14, 14)
            elif args.method == "attn_gradcam":
                Res = baselines.generate_cam_attn(data, index=index).reshape(
                    data.shape[0], 1, 14, 14
                )

            # Upsample patch-level maps (14x14 -> 224x224) and min-max normalize.
            if args.method != "full_lrp" and args.method != "input_grads":
                Res = torch.nn.functional.interpolate(
                    Res, scale_factor=16, mode="bilinear"
                ).to(device)
            Res = (Res - Res.min()) / (Res.max() - Res.min())

            data_cam[-data.shape[0] :] = Res.data.cpu().numpy()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Generate ViT saliency/relevance maps over the ImageNet validation set"
    )
    parser.add_argument("--batch-size", type=int, default=1, help="")
    # NOTE: the default below is not one of the valid choices, so --method
    # should always be passed explicitly.
    parser.add_argument(
        "--method",
        type=str,
        default="grad_rollout",
        choices=[
            "rollout",
            "lrp",
            "transformer_attribution",
            "full_lrp",
            "lrp_last_layer",
            "attn_last_layer",
            "attn_gradcam",
        ],
        help="",
    )
    parser.add_argument("--lmd", type=float, default=10, help="")
    parser.add_argument(
        "--vis-class",
        type=str,
        default="top",
        choices=["top", "target", "index"],
        help="",
    )
    parser.add_argument("--class-id", type=int, default=0, help="")
    parser.add_argument("--cls-agn", action="store_true", default=False, help="")
    parser.add_argument("--no-ia", action="store_true", default=False, help="")
    parser.add_argument("--no-fx", action="store_true", default=False, help="")
    parser.add_argument("--no-fgx", action="store_true", default=False, help="")
    parser.add_argument("--no-m", action="store_true", default=False, help="")
    parser.add_argument("--no-reg", action="store_true", default=False, help="")
    # NOTE: with type=bool, any non-empty string (even "False") parses as True;
    # omit the flag entirely to keep the default of False.
    parser.add_argument("--is-ablation", type=bool, default=False, help="")
    parser.add_argument("--imagenet-validation-path", type=str, required=True, help="")
    args = parser.parse_args()
    # PATH variables
    PATH = os.path.dirname(os.path.abspath(__file__)) + "/"
    os.makedirs(os.path.join(PATH, "visualizations"), exist_ok=True)
    os.makedirs(
        os.path.join(PATH, "visualizations/{}".format(args.method)), exist_ok=True
    )

    if args.vis_class == "index":
        os.makedirs(
            os.path.join(
                PATH,
                "visualizations/{}/{}_{}".format(
                    args.method, args.vis_class, args.class_id
                ),
            ),
            exist_ok=True,
        )
        args.method_dir = os.path.join(
            PATH,
            "visualizations/{}/{}_{}".format(
                args.method, args.vis_class, args.class_id
            ),
        )
    else:
        ablation_fold = "ablation" if args.is_ablation else "not_ablation"
        os.makedirs(
            os.path.join(
                PATH,
                "visualizations/{}/{}/{}".format(
                    args.method, args.vis_class, ablation_fold
                ),
            ),
            exist_ok=True,
        )
        args.method_dir = os.path.join(
            PATH,
            "visualizations/{}/{}/{}".format(
                args.method, args.vis_class, ablation_fold
            ),
        )

    # Remove any stale results file from the actual output directory; the file
    # is opened in append mode, and create_dataset would fail if the datasets
    # already exist.
    try:
        os.remove(os.path.join(args.method_dir, "results.hdf5"))
    except OSError:
        pass
    cuda = torch.cuda.is_available()
    device = torch.device("cuda" if cuda else "cpu")

    # Model (moved to the selected device and set to eval mode for inference)
    model = vit_base_patch16_224(pretrained=True).to(device)
    model.eval()
    baselines = Baselines(model)

    # LRP
    model_LRP = vit_LRP(pretrained=True).to(device)
    model_LRP.eval()
    lrp = LRP(model_LRP)

    # orig LRP
    model_orig_LRP = vit_orig_LRP(pretrained=True).to(device)
    model_orig_LRP.eval()
    orig_lrp = LRP(model_orig_LRP)

    # Dataset loader for sample images
    transform = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ]
    )

    imagenet_ds = ImageNet(
        args.imagenet_validation_path, split="val", download=False, transform=transform
    )
    sample_loader = torch.utils.data.DataLoader(
        imagenet_ds, batch_size=args.batch_size, shuffle=False, num_workers=4
    )

    compute_saliency_and_save(args)
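
# Reading the results back (a sketch; the exact path depends on the --method,
# --vis-class and --is-ablation options used above):
#
#   import h5py
#   path = "visualizations/transformer_attribution/top/not_ablation/results.hdf5"
#   with h5py.File(path, "r") as f:
#       vis = f["vis"][:]        # (N, 1, 224, 224) saliency maps in [0, 1]
#       images = f["image"][:]   # (N, 3, 224, 224) un-normalized inputs
#       targets = f["target"][:] # (N,) ground-truth class ids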