din0s committed on
Commit
d4ab5ac
1 Parent(s): d395a3a
.gitattributes CHANGED
@@ -25,3 +25,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zstandard filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ checkpoints/diffmask.ckpt filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,63 @@
+ import sys
+ sys.path.insert(0, './code')
+
+ from datamodules.transformations import UnNest
+ from models.interpretation import ImageInterpretationNet
+ from transformers import ViTFeatureExtractor, ViTForImageClassification
+ from utils.plot import smoothen, draw_mask_on_image, draw_heatmap_on_image
+
+ import gradio as gr
+ import numpy as np
+ import torch
+
+ # Load Vision Transformer
+ hf_model = "tanlq/vit-base-patch16-224-in21k-finetuned-cifar10"
+ vit = ViTForImageClassification.from_pretrained(hf_model)
+ vit.eval()
+
+ # Load Feature Extractor
+ feature_extractor = ViTFeatureExtractor.from_pretrained(hf_model, return_tensors="pt")
+ feature_extractor = UnNest(feature_extractor)
+
+ # Load Vision DiffMask
+ diffmask = ImageInterpretationNet.load_from_checkpoint('checkpoints/diffmask.ckpt')
+ diffmask.set_vision_transformer(vit)
+
+
+ # Define mask plotting functions
+ def draw_mask(image, mask):
+     return draw_mask_on_image(image, smoothen(mask))\
+         .permute(1, 2, 0)\
+         .clip(0, 1)\
+         .numpy()
+
+
+ def draw_heatmap(image, mask):
+     return draw_heatmap_on_image(image, smoothen(mask))\
+         .permute(1, 2, 0)\
+         .clip(0, 1)\
+         .numpy()
+
+
+ # Define callable method for the demo
+ def get_mask(image):
+     if image is None:
+         return None
+
+     image = torch.from_numpy(image).permute(2, 0, 1).float() / 255
+     dm_image = feature_extractor(image).unsqueeze(0)
+     mask = diffmask.get_mask(dm_image)["mask"][0].detach()
+
+     masked_img = draw_mask(image, mask)
+     heatmap = draw_heatmap(image, mask)
+     return np.hstack((masked_img, heatmap))
+
+
+ # Launch demo interface
+ gr.Interface(
+     get_mask,
+     inputs=gr.inputs.Image(label="Input", shape=(224, 224), source="upload", type="numpy"),
+     outputs=[gr.outputs.Image(label="Output")],
+     title="Vision DiffMask Demo",
+     live=True,
+ ).launch()
checkpoints/.gitkeep ADDED
@@ -0,0 +1 @@
+
checkpoints/diffmask.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:33ceff3adc10ffb86bdaa3c90380e7925e76e7b170ed42d1cc00ff33328fc77b
+ size 16610391
code/attributions/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .attention_rollout import attention_rollout
+ from .grad_cam import grad_cam
code/attributions/attention_rollout.py ADDED
@@ -0,0 +1,59 @@
+ import torch
+ import torch.nn.functional as F
+
+ from math import sqrt
+ from torch import Tensor
+ from transformers import ViTForImageClassification
+
+
+ @torch.no_grad()
+ def attention_rollout(
+     images: Tensor,
+     vit: ViTForImageClassification,
+     discard_ratio: float = 0.9,
+     head_fusion: str = "mean",
+     device: str = "cpu",
+ ) -> Tensor:
+     """Performs the Attention Rollout method on a batch of images (https://arxiv.org/pdf/2005.00928.pdf)."""
+     # Forward pass and save attention maps
+     attentions = vit(images, output_attentions=True).attentions
+
+     B, _, H, W = images.shape  # Batch size, channels, height, width
+     P = attentions[0].size(-1)  # Number of patches
+
+     mask = torch.eye(P).to(device)
+     # Iterate over layers
+     for j, attention in enumerate(attentions):
+         if head_fusion == "mean":
+             attention_heads_fused = attention.mean(axis=1)
+         elif head_fusion == "max":
+             attention_heads_fused = attention.max(axis=1)[0]
+         elif head_fusion == "min":
+             attention_heads_fused = attention.min(axis=1)[0]
+         else:
+             raise ValueError("Attention head fusion type not supported")
+
+         # Drop the lowest attentions, but don't drop the class token
+         flat = attention_heads_fused.view(B, -1)
+         _, indices = flat.topk(int(flat.size(-1) * discard_ratio), -1, False)
+         indices = indices[indices != 0]
+         flat[0, indices] = 0
+
+         a = (attention_heads_fused + torch.eye(P).to(device)) / 2
+         a = a / a.sum(dim=-1).view(-1, P, 1)
+
+         mask = a @ mask
+
+     # Look at the total attention between the class token and the image patches
+     mask = mask[:, 0, 1:]
+     mask = mask / torch.max(mask)
+
+     N = int(sqrt(P))
+     S = int(H / N)
+
+     mask = mask.reshape(B, 1, N, N)
+     mask = F.interpolate(mask, scale_factor=S)
+     mask = mask.reshape(B, H, W)
+
+     return mask
code/attributions/grad_cam.py ADDED
@@ -0,0 +1,55 @@
+ import torch
+
+ from pytorch_grad_cam import GradCAM
+ from torch import Tensor
+ from transformers import ViTForImageClassification
+
+
+ def grad_cam(images: Tensor, vit: ViTForImageClassification, use_cuda: bool = False) -> Tensor:
+     """Performs the Grad-CAM method on a batch of images (https://arxiv.org/pdf/1610.02391.pdf)."""
+
+     # Wrap the ViT model to be compatible with GradCAM
+     vit = ViTWrapper(vit)
+     vit.eval()
+
+     # Create GradCAM object
+     cam = GradCAM(
+         model=vit,
+         target_layers=[vit.target_layer],
+         reshape_transform=_reshape_transform,
+         use_cuda=use_cuda,
+     )
+
+     # Compute GradCAM masks
+     grayscale_cam = cam(
+         input_tensor=images,
+         targets=None,
+         eigen_smooth=True,
+         aug_smooth=True,
+     )
+
+     return torch.from_numpy(grayscale_cam)
+
+
+ def _reshape_transform(tensor, height=14, width=14):
+     result = tensor[:, 1:, :].reshape(tensor.size(0), height, width, tensor.size(2))
+
+     # Bring the channels to the first dimension
+     result = result.transpose(2, 3).transpose(1, 2)
+
+     return result
+
+
+ class ViTWrapper(torch.nn.Module):
+     """ViT Wrapper to use with Grad-CAM."""
+
+     def __init__(self, vit: ViTForImageClassification):
+         super().__init__()
+         self.vit = vit
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return self.vit(x).logits
+
+     @property
+     def target_layer(self):
+         return self.vit.vit.encoder.layer[-2].layernorm_after
code/datamodules/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .base import ImageDataModule
+ from .image_classification import CIFAR10DataModule, MNISTDataModule
+ from .visual_qa import CIFAR10QADataModule, ToyQADataModule
code/datamodules/base.py ADDED
@@ -0,0 +1,156 @@
1
+ from .transformations import AddGaussianNoise
2
+ from abc import abstractmethod, ABCMeta
3
+ from argparse import ArgumentParser
4
+ from pytorch_lightning import LightningDataModule
5
+ from torch.utils.data import (
6
+ DataLoader,
7
+ Dataset,
8
+ default_collate,
9
+ RandomSampler,
10
+ SequentialSampler,
11
+ )
12
+ from torchvision import transforms
13
+ from typing import Optional
14
+
15
+
16
+ class ImageDataModule(LightningDataModule, metaclass=ABCMeta):
17
+ @staticmethod
18
+ def add_model_specific_args(parent_parser: ArgumentParser) -> ArgumentParser:
19
+ parser = parent_parser.add_argument_group("Data Modules")
20
+ parser.add_argument(
21
+ "--data_dir",
22
+ type=str,
23
+ default="data/",
24
+ help="The directory where the data is stored.",
25
+ )
26
+ parser.add_argument(
27
+ "--batch_size",
28
+ type=int,
29
+ default=32,
30
+ help="The batch size to use.",
31
+ )
32
+ parser.add_argument(
33
+ "--add_noise",
34
+ action="store_true",
35
+ help="Use gaussian noise augmentation.",
36
+ )
37
+ parser.add_argument(
38
+ "--add_rotation",
39
+ action="store_true",
40
+ help="Use rotation augmentation.",
41
+ )
42
+ parser.add_argument(
43
+ "--add_blur",
44
+ action="store_true",
45
+ help="Use blur augmentation.",
46
+ )
47
+ parser.add_argument(
48
+ "--num_workers",
49
+ type=int,
50
+ default=4,
51
+ help="Number of workers to use for data loading.",
52
+ )
53
+ return parent_parser
54
+
55
+ # Declare variables that will be initialized later
56
+ train_data: Dataset
57
+ val_data: Dataset
58
+ test_data: Dataset
59
+
60
+ def __init__(
61
+ self,
62
+ feature_extractor: Optional[callable] = None,
63
+ data_dir: str = "data/",
64
+ batch_size: int = 32,
65
+ add_noise: bool = False,
66
+ add_rotation: bool = False,
67
+ add_blur: bool = False,
68
+ num_workers: int = 4,
69
+ ):
70
+ """Abstract Pytorch Lightning DataModule for image datasets.
71
+
72
+ Args:
73
+ feature_extractor (callable): feature extractor instance
74
+ data_dir (str): directory to store the dataset
75
+ batch_size (int): batch size for the train/val/test dataloaders
76
+ add_noise (bool): whether to add noise to the images
77
+ add_rotation (bool): whether to add random rotation to the images
78
+ add_blur (bool): whether to add blur to the images
79
+ num_workers (int): number of workers for train/val/test dataloaders
80
+ """
81
+ super().__init__()
82
+
83
+ # Store hyperparameters
84
+ self.data_dir = data_dir
85
+ self.batch_size = batch_size
86
+ self.feature_extractor = feature_extractor
87
+ self.num_workers = num_workers
88
+
89
+ # Set the transforms
90
+ # If the feature_extractor is None, then we do not split the images into features
91
+ init_transforms = [feature_extractor] if feature_extractor else []
92
+ self.transform = transforms.Compose(init_transforms)
93
+ self._add_transforms(add_noise, add_rotation, add_blur)
94
+
95
+ # Set the collate function and the samplers
96
+ # These can be adapted in a child datamodule class to have a different behavior
97
+ self.collate_fn = default_collate
98
+ self.shuffled_sampler = RandomSampler
99
+ self.sequential_sampler = SequentialSampler
100
+
101
+ def _add_transforms(self, noise: bool, rotation: bool, blur: bool):
102
+ """Add transforms to the module's transformations list.
103
+
104
+ Args:
105
+ noise (bool): whether to add noise to the images
106
+ rotation (bool): whether to add random rotation to the images
107
+ blur (bool): whether to add blur to the images
108
+ """
109
+ # TODO:
110
+ # - Which order to add the transforms in?
111
+ # - Applied in both train and test or just test?
112
+ # - Check what transforms are applied by the model
113
+ if noise:
114
+ self.transform.transforms.append(AddGaussianNoise(0.0, 1.0))
115
+ if rotation:
116
+ self.transform.transforms.append(transforms.RandomRotation(20))
117
+ if blur:
118
+ self.transform.transforms.append(transforms.GaussianBlur(3))
119
+
120
+ @abstractmethod
121
+ def prepare_data(self):
122
+ raise NotImplementedError()
123
+
124
+ @abstractmethod
125
+ def setup(self, stage: Optional[str] = None):
126
+ raise NotImplementedError()
127
+
128
+ # noinspection PyTypeChecker
129
+ def train_dataloader(self) -> DataLoader:
130
+ return DataLoader(
131
+ self.train_data,
132
+ batch_size=self.batch_size,
133
+ num_workers=self.num_workers,
134
+ collate_fn=self.collate_fn,
135
+ sampler=self.shuffled_sampler(self.train_data),
136
+ )
137
+
138
+ # noinspection PyTypeChecker
139
+ def val_dataloader(self) -> DataLoader:
140
+ return DataLoader(
141
+ self.val_data,
142
+ batch_size=self.batch_size,
143
+ num_workers=self.num_workers,
144
+ collate_fn=self.collate_fn,
145
+ sampler=self.sequential_sampler(self.val_data),
146
+ )
147
+
148
+ # noinspection PyTypeChecker
149
+ def test_dataloader(self) -> DataLoader:
150
+ return DataLoader(
151
+ self.test_data,
152
+ batch_size=self.batch_size,
153
+ num_workers=self.num_workers,
154
+ collate_fn=self.collate_fn,
155
+ sampler=self.sequential_sampler(self.test_data),
156
+ )
code/datamodules/image_classification.py ADDED
@@ -0,0 +1,44 @@
+ from .base import ImageDataModule
+ from torch.utils.data import random_split
+ from torchvision.datasets import MNIST, CIFAR10
+ from typing import Optional
+
+
+ class MNISTDataModule(ImageDataModule):
+     """Datamodule for the MNIST dataset."""
+
+     def prepare_data(self):
+         # Download MNIST
+         MNIST(self.data_dir, train=True, download=True)
+         MNIST(self.data_dir, train=False, download=True)
+
+     def setup(self, stage: Optional[str] = None):
+         # Set the training and validation data
+         if stage == "fit" or stage is None:
+             mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
+             self.train_data, self.val_data = random_split(mnist_full, [55000, 5000])
+
+         # Set the test data
+         if stage == "test" or stage is None:
+             self.test_data = MNIST(self.data_dir, train=False, transform=self.transform)
+
+
+ class CIFAR10DataModule(ImageDataModule):
+     """Datamodule for the CIFAR10 dataset."""
+
+     def prepare_data(self):
+         # Download CIFAR10
+         CIFAR10(self.data_dir, train=True, download=True)
+         CIFAR10(self.data_dir, train=False, download=True)
+
+     def setup(self, stage: Optional[str] = None):
+         # Set the training and validation data
+         if stage == "fit" or stage is None:
+             cifar10_full = CIFAR10(self.data_dir, train=True, transform=self.transform)
+             self.train_data, self.val_data = random_split(cifar10_full, [45000, 5000])
+
+         # Set the test data
+         if stage == "test" or stage is None:
+             self.test_data = CIFAR10(
+                 self.data_dir, train=False, transform=self.transform
+             )
code/datamodules/transformations.py ADDED
@@ -0,0 +1,41 @@
+ from torch import Tensor
+ from transformers.image_utils import ImageInput
+
+ import torch
+
+
+ class AddGaussianNoise:
+     """Add Gaussian noise to an image.
+
+     Args:
+         mean (float): mean of the Gaussian noise
+         std (float): standard deviation of the Gaussian noise
+     """
+
+     def __init__(self, mean: float = 0.0, std: float = 1.0):
+         self.std = std
+         self.mean = mean
+
+     def __call__(self, tensor: Tensor) -> Tensor:
+         return tensor + torch.randn(tensor.size()) * self.std + self.mean
+
+     def __repr__(self) -> str:
+         return self.__class__.__name__ + "(mean={0}, std={1})".format(
+             self.mean, self.std
+         )
+
+
+ class UnNest:
+     """Un-nest the output of a feature extractor"""
+
+     def __init__(self, feature_extractor: callable):
+         self.feature_extractor = feature_extractor
+
+     def __call__(self, x: ImageInput) -> Tensor:
+         # Pass the input through the feature extractor
+         x = self.feature_extractor(x)
+         # Un-nest the pixel_values tensor
+         x = torch.tensor(x["pixel_values"][0])
+
+         # HuggingFace models expect 3D tensors [C, H, W]
+         return x if len(x) == 3 else x.unsqueeze(0)
code/datamodules/utils.py ADDED
@@ -0,0 +1,133 @@
1
+ from .image_classification import CIFAR10DataModule, ImageDataModule, MNISTDataModule
2
+ from .transformations import UnNest
3
+ from .visual_qa import CIFAR10QADataModule, ToyQADataModule
4
+ from argparse import Namespace
5
+ from transformers import ConvNextFeatureExtractor, ViTFeatureExtractor
6
+
7
+
8
+ def get_configs(args: Namespace) -> tuple[dict, dict]:
9
+ """Get the model and feature extractor configs from the command line args.
10
+
11
+ Args:
12
+ args (Namespace): the argparse Namespace object
13
+
14
+ Returns:
15
+ a tuple containing the model and feature extractor configs
16
+ """
17
+ if args.dataset == "MNIST":
18
+ # We upsample the MNIST images to 112x112, with 1 channel (grayscale)
19
+ # and 10 classes (0-9). We normalize the image to have a mean of 0.5
20
+ # and a standard deviation of ±0.5.
21
+ model_cfg_args = {
22
+ "image_size": 112,
23
+ "num_channels": 1,
24
+ "num_labels": 10,
25
+ }
26
+ fe_cfg_args = {
27
+ "image_mean": [0.5],
28
+ "image_std": [0.5],
29
+ }
30
+ elif args.dataset.startswith("CIFAR10"):
31
+ if args.dataset not in ("CIFAR10", "CIFAR10_QA"):
32
+ raise Exception(f"Unknown CIFAR10 variant: {args.dataset}")
33
+
34
+ # We upsample the CIFAR10 images to 224x224, with 3 channels (RGB) and
35
+ # 10 classes (0-9) for the normal dataset, or (grid_size)^2 + 1 for the
36
+ # toy task. We normalize the image to have a mean of 0.5 and a standard
37
+ # deviation of ±0.5.
38
+ model_cfg_args = {
39
+ "image_size": 224, # fixed to 224 because pretrained models have that size
40
+ "num_channels": 3,
41
+ "num_labels": (args.grid_size**2) + 1
42
+ if args.dataset == "CIFAR10_QA"
43
+ else 10,
44
+ }
45
+ fe_cfg_args = {
46
+ "image_mean": [0.5, 0.5, 0.5],
47
+ "image_std": [0.5, 0.5, 0.5],
48
+ }
49
+ elif args.dataset == "toy":
50
+ # We use an image size so that each patch contains a single color, with
51
+ # 3 channels (RGB) and (grid_size)^2 + 1 classes. We normalize the image
52
+ # to have a mean of 0.5 and a standard deviation of ±0.5.
53
+ model_cfg_args = {
54
+ "image_size": args.grid_size * 16,
55
+ "num_channels": 3,
56
+ "num_labels": (args.grid_size**2) + 1,
57
+ }
58
+ fe_cfg_args = {
59
+ "image_mean": [0.5, 0.5, 0.5],
60
+ "image_std": [0.5, 0.5, 0.5],
61
+ }
62
+ else:
63
+ raise Exception(f"Unknown dataset: {args.dataset}")
64
+
65
+ # Set the feature extractor's size attribute to be the same as the model's image size
66
+ fe_cfg_args["size"] = model_cfg_args["image_size"]
67
+ # Set the tensors' return type to PyTorch tensors
68
+ fe_cfg_args["return_tensors"] = "pt"
69
+
70
+ return model_cfg_args, fe_cfg_args
71
+
72
+
73
+ def datamodule_factory(args: Namespace) -> ImageDataModule:
74
+ """A factory method for creating a datamodule based on the command line args.
75
+
76
+ Args:
77
+ args (Namespace): the argparse Namespace object
78
+
79
+ Returns:
80
+ an ImageDataModule instance
81
+ """
82
+ # Get the model and feature extractor configs
83
+ model_cfg_args, fe_cfg_args = get_configs(args)
84
+
85
+ # Set the feature extractor class based on the provided base model name
86
+ if args.base_model == "ViT":
87
+ fe_class = ViTFeatureExtractor
88
+ elif args.base_model == "ConvNeXt":
89
+ fe_class = ConvNextFeatureExtractor
90
+ else:
91
+ raise Exception(f"Unknown base model: {args.base_model}")
92
+
93
+ # Create the feature extractor instance
94
+ if args.from_pretrained:
95
+ feature_extractor = fe_class.from_pretrained(
96
+ args.from_pretrained, **fe_cfg_args
97
+ )
98
+ else:
99
+ feature_extractor = fe_class(**fe_cfg_args)
100
+
101
+ # Un-nest the feature extractor's output
102
+ feature_extractor = UnNest(feature_extractor)
103
+
104
+ # Define the datamodule's configuration
105
+ dm_cfg = {
106
+ "feature_extractor": feature_extractor,
107
+ "batch_size": args.batch_size,
108
+ "add_noise": args.add_noise,
109
+ "add_rotation": args.add_rotation,
110
+ "add_blur": args.add_blur,
111
+ "num_workers": args.num_workers,
112
+ }
113
+
114
+ # Determine the dataset class based on the provided dataset name
115
+ if args.dataset.startswith("CIFAR10"):
116
+ if args.dataset == "CIFAR10":
117
+ dm_class = CIFAR10DataModule
118
+ elif args.dataset == "CIFAR10_QA":
119
+ dm_cfg["class_idx"] = args.class_idx
120
+ dm_cfg["grid_size"] = args.grid_size
121
+ dm_class = CIFAR10QADataModule
122
+ else:
123
+ raise Exception(f"Unknown CIFAR10 variant: {args.dataset}")
124
+ elif args.dataset == "MNIST":
125
+ dm_class = MNISTDataModule
126
+ elif args.dataset == "toy":
127
+ dm_cfg["class_idx"] = args.class_idx
128
+ dm_cfg["grid_size"] = args.grid_size
129
+ dm_class = ToyQADataModule
130
+ else:
131
+ raise Exception(f"Unknown dataset: {args.dataset}")
132
+
133
+ return dm_class(**dm_cfg)
code/datamodules/visual_qa.py ADDED
@@ -0,0 +1,241 @@
1
+ from .image_classification import CIFAR10DataModule
2
+ from argparse import ArgumentParser
3
+ from functools import partial
4
+ from torch import LongTensor
5
+ from torch.utils.data import default_collate, random_split, Sampler
6
+ from torchvision import transforms
7
+ from torchvision.datasets import VisionDataset
8
+ from typing import Iterator, Optional
9
+
10
+ import itertools
11
+ import random
12
+ import torch
13
+
14
+
15
+ class CIFAR10QADataModule(CIFAR10DataModule):
16
+ @staticmethod
17
+ def add_model_specific_args(parent_parser: ArgumentParser) -> ArgumentParser:
18
+ parser = parent_parser.add_argument_group("Visual QA")
19
+ parser.add_argument(
20
+ "--class_idx",
21
+ type=int,
22
+ default=3,
23
+ help="The class (index) to count.",
24
+ )
25
+ parser.add_argument(
26
+ "--grid_size",
27
+ type=int,
28
+ default=3,
29
+ help="The number of images per row in the grid.",
30
+ )
31
+ return parent_parser
32
+
33
+ def __init__(
34
+ self,
35
+ class_idx: int,
36
+ grid_size: int = 3,
37
+ feature_extractor: callable = None,
38
+ data_dir: str = "data/",
39
+ batch_size: int = 32,
40
+ add_noise: bool = False,
41
+ add_rotation: bool = False,
42
+ add_blur: bool = False,
43
+ num_workers: int = 4,
44
+ ):
45
+ """A datamodule for a modified CIFAR10 dataset that is used for Question Answering.
46
+ More specifically, the task is to count the number of images of a certain class in a grid.
47
+
48
+ Args:
49
+ class_idx (int): the class (index) to count
50
+ grid_size (int): the number of images per row in the grid
51
+ feature_extractor (callable): a callable feature extractor instance
52
+ data_dir (str): the directory to store the dataset
53
+ batch_size (int): the batch size for the train/val/test dataloaders
54
+ add_noise (bool): whether to add noise to the images
55
+ add_rotation (bool): whether to add rotation augmentation
56
+ add_blur (bool): whether to add blur augmentation
57
+ num_workers (int): the number of workers to use for data loading
58
+ """
59
+ super().__init__(
60
+ feature_extractor,
61
+ data_dir,
62
+ (grid_size**2) * batch_size,
63
+ add_noise,
64
+ add_rotation,
65
+ add_blur,
66
+ num_workers,
67
+ )
68
+
69
+ # Store hyperparameters
70
+ self.class_idx = class_idx
71
+ self.grid_size = grid_size
72
+
73
+ # Save the existing transformations to be applied after creating the grid
74
+ self.post_transform = self.transform
75
+ # Set the pre-batch transformation to be the conversion from PIL to tensor
76
+ self.transform = transforms.PILToTensor()
77
+
78
+ # Specify the custom collate function and samplers
79
+ self.collate_fn = self.custom_collate_fn
80
+ self.shuffled_sampler = partial(
81
+ FairGridSampler,
82
+ class_idx=class_idx,
83
+ grid_size=grid_size,
84
+ shuffle=True,
85
+ )
86
+ self.sequential_sampler = partial(
87
+ FairGridSampler,
88
+ class_idx=class_idx,
89
+ grid_size=grid_size,
90
+ shuffle=False,
91
+ )
92
+
93
+ def custom_collate_fn(self, batch):
94
+ # Split the batch into groups of grid_size**2
95
+ idx = range(len(batch))
96
+ grids = zip(*(iter(idx),) * (self.grid_size**2))
97
+
98
+ new_batch = []
99
+ for grid in grids:
100
+ # Create a grid of images from the indices in the batch
101
+ img = torch.hstack(
102
+ [
103
+ torch.dstack(
104
+ [batch[i][0] for i in grid[idx : idx + self.grid_size]]
105
+ )
106
+ for idx in range(
107
+ 0, self.grid_size**2 - self.grid_size + 1, self.grid_size
108
+ )
109
+ ]
110
+ )
111
+ # Apply the post transformations to the grid
112
+ img = self.post_transform(img)
113
+ # Define the target as the number of images that have the class_idx
114
+ targets = [batch[i][1] for i in grid]
115
+ target = targets.count(self.class_idx)
116
+ # Append grid and target to the batch
117
+ new_batch += [(img, target)]
118
+
119
+ return default_collate(new_batch)
120
+
121
+
122
+ class ToyQADataModule(CIFAR10QADataModule):
123
+ """A datamodule for the toy dataset as described in the paper."""
124
+
125
+ def prepare_data(self):
126
+ # No need to download anything for the toy task
127
+ pass
128
+
129
+ def setup(self, stage: Optional[str] = None):
130
+ img_size = 16
131
+
132
+ samples = []
133
+ # Generate 6000 samples based on 6 different colors
134
+ for r, g, b in itertools.product((0, 1), (0, 1), (0, 1)):
135
+ if r == g == b:
136
+ # We do not want black/white patches
137
+ continue
138
+
139
+ for _ in range(1000):
140
+ patch = torch.vstack(
141
+ [
142
+ r * torch.ones(1, img_size, img_size),
143
+ g * torch.ones(1, img_size, img_size),
144
+ b * torch.ones(1, img_size, img_size),
145
+ ]
146
+ )
147
+
148
+ # Assign a unique id to each color
149
+ target = int(f"{r}{g}{b}", 2) - 1
150
+ # Append the patch and target to the samples
151
+ samples += [(patch, target)]
152
+
153
+ # Split the data to 90% train, 5% validation and 5% test
154
+ train_size = int(len(samples) * 0.9)
155
+ val_size = (len(samples) - train_size) // 2
156
+ test_size = len(samples) - train_size - val_size
157
+ self.train_data, self.val_data, self.test_data = random_split(
158
+ samples,
159
+ [
160
+ train_size,
161
+ val_size,
162
+ test_size,
163
+ ],
164
+ )
165
+
166
+
167
+ class FairGridSampler(Sampler[int]):
168
+ def __init__(
169
+ self,
170
+ dataset: VisionDataset,
171
+ class_idx: int,
172
+ grid_size: int,
173
+ shuffle: bool = False,
174
+ ):
175
+ """A sampler that returns a grid of images from the dataset, with a uniformly random
176
+ amount of appearances for a specific class of interest.
177
+
178
+ Args:
179
+ dataset (VisionDataset): the dataset to sample from
180
+ class_idx(int): the class (index) to treat as the class of interest
181
+ grid_size (int): the number of images per row in the grid
182
+ shuffle (bool): whether to shuffle the dataset before sampling
183
+ """
184
+ super().__init__(dataset)
185
+
186
+ # Save the hyperparameters
187
+ self.dataset = dataset
188
+ self.grid_size = grid_size
189
+ self.n_images = grid_size**2
190
+
191
+ # Get the indices of the class of interest
192
+ self.class_indices = LongTensor(
193
+ [i for i, x in enumerate(dataset) if x[1] == class_idx]
194
+ )
195
+ # Get the indices of all other classes
196
+ self.other_indices = LongTensor(
197
+ [i for i, x in enumerate(dataset) if x[1] != class_idx]
198
+ )
199
+
200
+ # Fix the seed if shuffle is False
201
+ self.seed = None if shuffle else self._get_seed()
202
+
203
+ @staticmethod
204
+ def _get_seed() -> int:
205
+ """Utility function for generating a random seed."""
206
+ return int(torch.empty((), dtype=torch.int64).random_().item())
207
+
208
+ def __iter__(self) -> Iterator[int]:
209
+ # Create a torch Generator object
210
+ seed = self.seed if self.seed is not None else self._get_seed()
211
+ gen = torch.Generator()
212
+ gen.manual_seed(seed)
213
+
214
+ # Sample the batches
215
+ for _ in range(len(self.dataset) // self.n_images):
216
+ # Pick the number of instances for the class of interest
217
+ n_samples = torch.randint(self.n_images + 1, (), generator=gen).item()
218
+
219
+ # Sample the indices from the class of interest
220
+ idx_from_class = torch.randperm(
221
+ len(self.class_indices),
222
+ generator=gen,
223
+ )[:n_samples]
224
+ # Sample the indices from the other classes
225
+ idx_from_other = torch.randperm(
226
+ len(self.other_indices),
227
+ generator=gen,
228
+ )[: self.n_images - n_samples]
229
+
230
+ # Concatenate the corresponding lists of patches to form a grid
231
+ grid = (
232
+ self.class_indices[idx_from_class].tolist()
233
+ + self.other_indices[idx_from_other].tolist()
234
+ )
235
+
236
+ # Shuffle the order of the patches within the grid
237
+ random.shuffle(grid)
238
+ yield from grid
239
+
240
+ def __len__(self) -> int:
241
+ return len(self.dataset)
code/eval_base.py ADDED
@@ -0,0 +1,102 @@
1
+ from datamodules import CIFAR10QADataModule, ImageDataModule
2
+ from datamodules.utils import datamodule_factory
3
+ from models import ImageClassificationNet
4
+ from models.utils import model_factory
5
+ from pytorch_lightning.loggers import WandbLogger
6
+
7
+ import argparse
8
+ import pytorch_lightning as pl
9
+
10
+
11
+ def main(args: argparse.Namespace):
12
+ # Seed
13
+ pl.seed_everything(args.seed)
14
+
15
+ # Create base model
16
+ base = model_factory(args, own_config=True)
17
+
18
+ # Load datamodule
19
+ dm = datamodule_factory(args)
20
+
21
+ # Load the model from the specified checkpoint
22
+ model = ImageClassificationNet.load_from_checkpoint(
23
+ args.checkpoint,
24
+ model=base,
25
+ num_train_steps=0,
26
+ )
27
+
28
+ # Create wandb logger
29
+ wandb_logger = WandbLogger(
30
+ name=f"{args.dataset}_eval_{args.base_model} ({args.from_pretrained})",
31
+ project="Patch-DiffMask",
32
+ )
33
+
34
+ # Create trainer
35
+ trainer = pl.Trainer(
36
+ accelerator="auto",
37
+ logger=wandb_logger,
38
+ max_epochs=1,
39
+ enable_progress_bar=args.enable_progress_bar,
40
+ )
41
+
42
+ # Evaluate the model
43
+ trainer.test(model, dm)
44
+
45
+ # Save the HuggingFace model to be used with --from_pretrained
46
+ save_dir = f"checkpoints/{args.base_model}_{args.dataset}"
47
+ model.model.save_pretrained(save_dir)
48
+ dm.feature_extractor.save_pretrained(save_dir)
49
+
50
+
51
+ if __name__ == "__main__":
52
+ parser = argparse.ArgumentParser()
53
+
54
+ parser.add_argument(
55
+ "--checkpoint",
56
+ type=str,
57
+ required=True,
58
+ help="Checkpoint to resume the training from.",
59
+ )
60
+
61
+ # Trainer
62
+ parser.add_argument(
63
+ "--enable_progress_bar",
64
+ action="store_true",
65
+ help="Whether to show progress bar during training. NOT recommended when logging to files.",
66
+ )
67
+ parser.add_argument(
68
+ "--seed",
69
+ type=int,
70
+ default=123,
71
+ help="Random seed for reproducibility.",
72
+ )
73
+
74
+ # Base (classification) model
75
+ parser.add_argument(
76
+ "--base_model",
77
+ type=str,
78
+ default="ViT",
79
+ choices=["ViT", "ConvNeXt"],
80
+ help="Base model architecture to train.",
81
+ )
82
+ parser.add_argument(
83
+ "--from_pretrained",
84
+ type=str,
85
+ # default="tanlq/vit-base-patch16-224-in21k-finetuned-cifar10",
86
+ help="The name of the pretrained HF model to fine-tune from.",
87
+ )
88
+
89
+ # Datamodule
90
+ ImageDataModule.add_model_specific_args(parser)
91
+ CIFAR10QADataModule.add_model_specific_args(parser)
92
+ parser.add_argument(
93
+ "--dataset",
94
+ type=str,
95
+ default="toy",
96
+ choices=["MNIST", "CIFAR10", "CIFAR10_QA", "toy"],
97
+ help="The dataset to use.",
98
+ )
99
+
100
+ args = parser.parse_args()
101
+
102
+ main(args)
code/main.py ADDED
@@ -0,0 +1,215 @@
1
+ from argparse import ArgumentParser, Namespace
2
+ from attributions import attention_rollout, grad_cam
3
+ from datamodules import CIFAR10QADataModule, ImageDataModule
4
+ from datamodules.utils import datamodule_factory
5
+ from functools import partial
6
+ from models import ImageInterpretationNet
7
+ from pytorch_lightning.callbacks import ModelCheckpoint
8
+ from pytorch_lightning.loggers import WandbLogger
9
+ from transformers import ViTForImageClassification
10
+ from utils.plot import DrawMaskCallback, log_masks
11
+
12
+ import pytorch_lightning as pl
13
+
14
+
15
+ def get_experiment_name(args: Namespace):
16
+ """Create a name for the experiment based on the command line arguments."""
17
+ # Convert to dictionary
18
+ args = vars(args)
19
+
20
+ # Create a list with non-experiment arguments
21
+ non_experiment_args = [
22
+ "add_blur",
23
+ "add_noise",
24
+ "add_rotation",
25
+ "base_model",
26
+ "batch_size",
27
+ "class_idx",
28
+ "data_dir",
29
+ "enable_progress_bar",
30
+ "from_pretrained",
31
+ "log_every_n_steps",
32
+ "num_epochs",
33
+ "num_workers",
34
+ "sample_images",
35
+ "seed",
36
+ ]
37
+
38
+ # Create experiment name from experiment arguments
39
+ return "-".join(
40
+ [
41
+ f"{name}={value}"
42
+ for name, value in sorted(args.items())
43
+ if name not in non_experiment_args
44
+ ]
45
+ )
46
+
47
+
48
+ def setup_sample_image_logs(
49
+ dm: ImageDataModule,
50
+ args: Namespace,
51
+ logger: WandbLogger,
52
+ n_panels: int = 2, # TODO: change?
53
+ ):
54
+ """Setup the log callbacks for sampling and plotting images."""
55
+ images_per_panel = args.sample_images
56
+
57
+ # Sample images
58
+ sample_images = []
59
+ iter_loader = iter(dm.val_dataloader())
60
+ for panel in range(n_panels):
61
+ X, Y = next(iter_loader)
62
+ sample_images += [(X[:images_per_panel], Y[:images_per_panel])]
63
+
64
+ # Define mask callback
65
+ mask_cb = partial(DrawMaskCallback, log_every_n_steps=args.log_every_n_steps)
66
+
67
+ callbacks = []
68
+ for panel in range(n_panels):
69
+ # Initialize ViT model
70
+ vit = ViTForImageClassification.from_pretrained(args.from_pretrained)
71
+
72
+ # Extract samples for current panel
73
+ samples = sample_images[panel]
74
+ X, _ = samples
75
+
76
+ # Log GradCAM
77
+ gradcam_masks = grad_cam(X, vit)
78
+ log_masks(X, gradcam_masks, f"GradCAM {panel}", logger)
79
+
80
+ # Log Attention Rollout
81
+ rollout_masks = attention_rollout(X, vit)
82
+ log_masks(X, rollout_masks, f"Attention Rollout {panel}", logger)
83
+
84
+ # Create mask callback
85
+ callbacks += [mask_cb(samples, key=f"{panel}")]
86
+
87
+ return callbacks
88
+
89
+
90
+ def main(args: Namespace):
91
+ # Seed
92
+ pl.seed_everything(args.seed)
93
+
94
+ # Load pre-trained Transformer
95
+ model = ViTForImageClassification.from_pretrained(args.from_pretrained)
96
+
97
+ # Load datamodule
98
+ dm = datamodule_factory(args)
99
+
100
+ # Setup datamodule to sample images for the mask callback
101
+ dm.prepare_data()
102
+ dm.setup("fit")
103
+
104
+ # Create Vision DiffMask for the model
105
+ diffmask = ImageInterpretationNet(
106
+ model_cfg=model.config,
107
+ alpha=args.alpha,
108
+ lr=args.lr,
109
+ eps=args.eps,
110
+ lr_placeholder=args.lr_placeholder,
111
+ lr_alpha=args.lr_alpha,
112
+ mul_activation=args.mul_activation,
113
+ add_activation=args.add_activation,
114
+ placeholder=not args.no_placeholder,
115
+ weighted_layer_pred=args.weighted_layer_distribution,
116
+ )
117
+ diffmask.set_vision_transformer(model)
118
+
119
+ # Create wandb logger instance
120
+ wandb_logger = WandbLogger(
121
+ name=get_experiment_name(args),
122
+ project="Patch-DiffMask",
123
+ )
124
+
125
+ # Create checkpoint callback
126
+ ckpt_cb = ModelCheckpoint(
127
+ save_top_k=-1,
128
+ dirpath=f"checkpoints/{wandb_logger.version}",
129
+ every_n_train_steps=args.log_every_n_steps,
130
+ )
131
+
132
+ # Create mask callbacks
133
+ mask_cbs = setup_sample_image_logs(dm, args, wandb_logger)
134
+
135
+ # Create trainer
136
+ trainer = pl.Trainer(
137
+ accelerator="auto",
138
+ callbacks=[ckpt_cb, *mask_cbs],
139
+ enable_progress_bar=args.enable_progress_bar,
140
+ logger=wandb_logger,
141
+ max_epochs=args.num_epochs,
142
+ )
143
+
144
+ # Train the model
145
+ trainer.fit(diffmask, dm)
146
+
147
+
148
+ if __name__ == "__main__":
149
+ parser = ArgumentParser()
150
+
151
+ # Trainer
152
+ parser.add_argument(
153
+ "--enable_progress_bar",
154
+ action="store_true",
155
+ help="Whether to enable the progress bar (NOT recommended when logging to file).",
156
+ )
157
+ parser.add_argument(
158
+ "--num_epochs",
159
+ type=int,
160
+ default=5,
161
+ help="Number of epochs to train.",
162
+ )
163
+ parser.add_argument(
164
+ "--seed",
165
+ type=int,
166
+ default=123,
167
+ help="Random seed for reproducibility.",
168
+ )
169
+
170
+ # Logging
171
+ parser.add_argument(
172
+ "--sample_images",
173
+ type=int,
174
+ default=8,
175
+ help="Number of images to sample for the mask callback.",
176
+ )
177
+ parser.add_argument(
178
+ "--log_every_n_steps",
179
+ type=int,
180
+ default=200,
181
+ help="Number of steps between logging media & checkpoints.",
182
+ )
183
+
184
+ # Base (classification) model
185
+ parser.add_argument(
186
+ "--base_model",
187
+ type=str,
188
+ default="ViT",
189
+ choices=["ViT"],
190
+ help="Base model architecture to train.",
191
+ )
192
+ parser.add_argument(
193
+ "--from_pretrained",
194
+ type=str,
195
+ default="tanlq/vit-base-patch16-224-in21k-finetuned-cifar10",
196
+ help="The name of the pretrained HF model to load.",
197
+ )
198
+
199
+ # Interpretation model
200
+ ImageInterpretationNet.add_model_specific_args(parser)
201
+
202
+ # Datamodule
203
+ ImageDataModule.add_model_specific_args(parser)
204
+ CIFAR10QADataModule.add_model_specific_args(parser)
205
+ parser.add_argument(
206
+ "--dataset",
207
+ type=str,
208
+ default="CIFAR10",
209
+ choices=["MNIST", "CIFAR10", "CIFAR10_QA", "toy"],
210
+ help="The dataset to use.",
211
+ )
212
+
213
+ args = parser.parse_args()
214
+
215
+ main(args)
code/models/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .classification import ImageClassificationNet
+ from .interpretation import ImageInterpretationNet
code/models/classification.py ADDED
@@ -0,0 +1,112 @@
1
+ """
2
+ Parts of this file have been adapted from
3
+ https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial15/Vision_Transformer.html
4
+ """
5
+
6
+ import pytorch_lightning as pl
7
+ import torch.nn.functional as F
8
+
9
+ from argparse import ArgumentParser
10
+ from torch import Tensor
11
+ from torch.optim import AdamW, Optimizer, RAdam
12
+ from torch.optim.lr_scheduler import _LRScheduler
13
+ from transformers import get_scheduler, PreTrainedModel
14
+
15
+
16
+ class ImageClassificationNet(pl.LightningModule):
17
+ @staticmethod
18
+ def add_model_specific_args(parent_parser: ArgumentParser) -> ArgumentParser:
19
+ parser = parent_parser.add_argument_group("Classification Model")
20
+ parser.add_argument(
21
+ "--optimizer",
22
+ type=str,
23
+ default="AdamW",
24
+ choices=["AdamW", "RAdam"],
25
+ help="The optimizer to use to train the model.",
26
+ )
27
+ parser.add_argument(
28
+ "--weight_decay",
29
+ type=float,
30
+ default=1e-2,
31
+ help="The optimizer's weight decay.",
32
+ )
33
+ parser.add_argument(
34
+ "--lr",
35
+ type=float,
36
+ default=5e-5,
37
+ help="The initial learning rate for the model.",
38
+ )
39
+ return parent_parser
40
+
41
+ def __init__(
42
+ self,
43
+ model: PreTrainedModel,
44
+ num_train_steps: int,
45
+ optimizer: str = "AdamW",
46
+ weight_decay: float = 1e-2,
47
+ lr: float = 5e-5,
48
+ ):
49
+ """A PyTorch Lightning Module for a HuggingFace model used for image classification.
50
+
51
+ Args:
52
+ model (PreTrainedModel): a pretrained model for image classification
53
+ num_train_steps (int): number of training steps
54
+ optimizer (str): optimizer to use
55
+ weight_decay (float): weight decay for optimizer
56
+ lr (float): the learning rate used for training
57
+ """
58
+ super().__init__()
59
+
60
+ # Save the hyperparameters and the model
61
+ self.save_hyperparameters(ignore=["model"])
62
+ self.model = model
63
+
64
+ def forward(self, x: Tensor) -> Tensor:
65
+ return self.model(x).logits
66
+
67
+ def configure_optimizers(self) -> tuple[list[Optimizer], list[_LRScheduler]]:
68
+ # Set the optimizer class based on the hyperparameter
69
+ if self.hparams.optimizer == "AdamW":
70
+ optim_class = AdamW
71
+ elif self.hparams.optimizer == "RAdam":
72
+ optim_class = RAdam
73
+ else:
74
+ raise Exception(f"Unknown optimizer {self.hparams.optimizer}")
75
+
76
+ # Create the optimizer and the learning rate scheduler
77
+ optimizer = optim_class(
78
+ self.parameters(),
79
+ weight_decay=self.hparams.weight_decay,
80
+ lr=self.hparams.lr,
81
+ )
82
+ lr_scheduler = get_scheduler(
83
+ name="linear",
84
+ optimizer=optimizer,
85
+ num_warmup_steps=0,
86
+ num_training_steps=self.hparams.num_train_steps,
87
+ )
88
+
89
+ return [optimizer], [lr_scheduler]
90
+
91
+ def _calculate_loss(self, batch: tuple[Tensor, Tensor], mode: str) -> Tensor:
92
+ imgs, labels = batch
93
+
94
+ preds = self.model(imgs).logits
95
+ loss = F.cross_entropy(preds, labels)
96
+ acc = (preds.argmax(dim=-1) == labels).float().mean()
97
+
98
+ self.log(f"{mode}_loss", loss)
99
+ self.log(f"{mode}_acc", acc)
100
+
101
+ return loss
102
+
103
+ def training_step(self, batch: tuple[Tensor, Tensor], _: Tensor) -> Tensor:
104
+ loss = self._calculate_loss(batch, mode="train")
105
+
106
+ return loss
107
+
108
+ def validation_step(self, batch: tuple[Tensor, Tensor], _: Tensor):
109
+ self._calculate_loss(batch, mode="val")
110
+
111
+ def test_step(self, batch: tuple[Tensor, Tensor], _: Tensor):
112
+ self._calculate_loss(batch, mode="test")
code/models/gates.py ADDED
@@ -0,0 +1,261 @@
1
+ """
2
+ Parts of this file have been adapted from
3
+ https://github.com/nicola-decao/diffmask/blob/master/diffmask/models/gates.py
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+ from torch import Tensor
10
+ from typing import Optional
11
+ from utils.distributions import RectifiedStreched, BinaryConcrete
12
+
13
+
14
+ class MLPGate(nn.Module):
15
+ def __init__(self, input_size: int, hidden_size: int, bias: bool = True):
16
+ """
17
+ This is an MLP with the following structure;
18
+ Linear(input_size, hidden_size), Tanh(), Linear(hidden_size, 1)
19
+ The bias of the last layer is set to 5.0 to start with high probability
20
+ of keeping states (fundamental for good convergence as the initialized
21
+ DiffMask has not learned what to mask yet).
22
+
23
+ Args:
24
+ input_size (int): the number of input features
25
+ hidden_size (int): the number of hidden units
26
+ bias (bool): whether to use a bias term
27
+ """
28
+ super().__init__()
29
+
30
+ self.f = nn.Sequential(
31
+ nn.utils.weight_norm(nn.Linear(input_size, hidden_size)),
32
+ nn.Tanh(),
33
+ nn.utils.weight_norm(nn.Linear(hidden_size, 1, bias=bias)),
34
+ )
35
+
36
+ if bias:
37
+ self.f[-1].bias.data[:] = 5.0
38
+
39
+ def forward(self, *args: Tensor) -> Tensor:
40
+ return self.f(torch.cat(args, -1))
41
+
42
+
43
+ class MLPMaxGate(nn.Module):
44
+ def __init__(
45
+ self,
46
+ input_size: int,
47
+ hidden_size: int,
48
+ mul_activation: float = 10.0,
49
+ add_activation: float = 5.0,
50
+ bias: bool = True,
51
+ ):
52
+ """
53
+ This is an MLP with the following structure;
54
+ Linear(input_size, hidden_size), Tanh(), Linear(hidden_size, 1)
55
+ The bias of the last layer is set to 5.0 to start with high probability
56
+ of keeping states (fundamental for good convergence as the initialized
57
+ DiffMask has not learned what to mask yet).
58
+ It also uses a scaler for the output of the activation function.
59
+
60
+ Args:
61
+ input_size (int): the number of input features
62
+ hidden_size (int): the number of hidden units
63
+ mul_activation (float): the scaler for the output of the activation function
64
+ add_activation (float): the offset for the output of the activation function
65
+ bias (bool): whether to use a bias term
66
+ """
67
+ super().__init__()
68
+
69
+ self.f = nn.Sequential(
70
+ nn.utils.weight_norm(nn.Linear(input_size, hidden_size)),
71
+ nn.Tanh(),
72
+ nn.utils.weight_norm(nn.Linear(hidden_size, 1, bias=bias)),
73
+ nn.Tanh(),
74
+ )
75
+ self.add_activation = nn.Parameter(torch.tensor(add_activation))
76
+ self.mul_activation = mul_activation
77
+
78
+ def forward(self, *args: Tensor) -> Tensor:
79
+ return self.f(torch.cat(args, -1)) * self.mul_activation + self.add_activation
80
+
81
+
82
+ class DiffMaskGateInput(nn.Module):
83
+ def __init__(
84
+ self,
85
+ hidden_size: int,
86
+ hidden_attention: int,
87
+ num_hidden_layers: int,
88
+ max_position_embeddings: int,
89
+ gate_fn: nn.Module = MLPMaxGate,
90
+ mul_activation: float = 10.0,
91
+ add_activation: float = 5.0,
92
+ gate_bias: bool = True,
93
+ placeholder: bool = False,
94
+ init_vector: Tensor = None,
95
+ ):
96
+ """This is a DiffMask module that masks the input of the first layer.
97
+
98
+ Args:
99
+ hidden_size (int): the size of the hidden representations
100
+ hidden_attention (int) the amount of units in the gate's hidden (bottleneck) layer
101
+ num_hidden_layers (int): the number of hidden layers (and thus gates to use)
102
+ max_position_embeddings (int): the amount of placeholder embeddings to learn for the masked positions
103
+ gate_fn (nn.Module): the PyTorch module to use as a gate
104
+ mul_activation (float): the scaler for the output of the activation function
105
+ add_activation (float): the offset for the output of the activation function
106
+ gate_bias (bool): whether to use a bias term
107
+ placeholder (bool): whether to use placeholder embeddings or a zero vector
108
+ init_vector (Tensor): the initial vector to use for the placeholder embeddings
109
+ """
110
+ super().__init__()
111
+
112
+ # Create a ModuleList with the gates
113
+ self.g_hat = nn.ModuleList(
114
+ [
115
+ gate_fn(
116
+ hidden_size * 2,
117
+ hidden_attention,
118
+ mul_activation,
119
+ add_activation,
120
+ gate_bias,
121
+ )
122
+ for _ in range(num_hidden_layers)
123
+ ]
124
+ )
125
+
126
+ if placeholder:
127
+ # Use a placeholder embedding for the masked positions
128
+ self.placeholder = nn.Parameter(
129
+ nn.init.xavier_normal_(
130
+ torch.empty((1, max_position_embeddings, hidden_size))
131
+ )
132
+ if init_vector is None
133
+ else init_vector.view(1, 1, hidden_size).repeat(
134
+ 1, max_position_embeddings, 1
135
+ )
136
+ )
137
+ else:
138
+ # Use a zero vector for the masked positions
139
+ self.register_buffer(
140
+ "placeholder",
141
+ torch.zeros((1, 1, hidden_size)),
142
+ )
143
+
144
+ def forward(
145
+ self, hidden_states: tuple[Tensor], layer_pred: Optional[int]
146
+ ) -> tuple[tuple[Tensor], Tensor, Tensor, Tensor, Tensor]:
147
+ # Concatenate the output of all the gates
148
+ logits = torch.cat(
149
+ [
150
+ self.g_hat[i](hidden_states[0], hidden_states[i])
151
+ for i in range(
152
+ (layer_pred + 1) if layer_pred is not None else len(hidden_states)
153
+ )
154
+ ],
155
+ -1,
156
+ )
157
+
158
+ # Define a Hard Concrete distribution
159
+ dist = RectifiedStreched(
160
+ BinaryConcrete(torch.full_like(logits, 0.2), logits),
161
+ l=-0.2,
162
+ r=1.0,
163
+ )
164
+
165
+ # Calculate the expectation for the full gate probabilities
166
+ # These act as votes for the masked positions
167
+ gates_full = dist.rsample().cumprod(-1)
168
+ expected_L0_full = dist.log_expected_L0().cumsum(-1)
169
+
170
+ # Extract the probabilities from the last layer, which acts
171
+ # as an aggregation of the votes per position
172
+ gates = gates_full[..., -1]
173
+ expected_L0 = expected_L0_full[..., -1]
174
+
175
+ return (
176
+ hidden_states[0] * gates.unsqueeze(-1)
177
+ + self.placeholder[:, : hidden_states[0].shape[-2]]
178
+ * (1 - gates).unsqueeze(-1),
179
+ gates,
180
+ expected_L0,
181
+ gates_full,
182
+ expected_L0_full,
183
+ )
184
+
185
+
186
+ # class DiffMaskGateHidden(nn.Module):
187
+ # def __init__(
188
+ # self,
189
+ # hidden_size: int,
190
+ # hidden_attention: int,
191
+ # num_hidden_layers: int,
192
+ # max_position_embeddings: int,
193
+ # gate_fn: nn.Module = MLPMaxGate,
194
+ # gate_bias: bool = True,
195
+ # placeholder: bool = False,
196
+ # init_vector: Tensor = None,
197
+ # ):
198
+ # super().__init__()
199
+ #
200
+ # self.g_hat = nn.ModuleList(
201
+ # [
202
+ # gate_fn(hidden_size, hidden_attention, bias=gate_bias)
203
+ # for _ in range(num_hidden_layers)
204
+ # ]
205
+ # )
206
+ #
207
+ # if placeholder:
208
+ # self.placeholder = nn.ParameterList(
209
+ # [
210
+ # nn.Parameter(
211
+ # nn.init.xavier_normal_(
212
+ # torch.empty((1, max_position_embeddings, hidden_size))
213
+ # )
214
+ # if init_vector is None
215
+ # else init_vector.view(1, 1, hidden_size).repeat(
216
+ # 1, max_position_embeddings, 1
217
+ # )
218
+ # )
219
+ # for _ in range(num_hidden_layers)
220
+ # ]
221
+ # )
222
+ # else:
223
+ # self.register_buffer(
224
+ # "placeholder",
225
+ # torch.zeros((num_hidden_layers, 1, 1, hidden_size)),
226
+ # )
227
+ #
228
+ # def forward(
229
+ # self, hidden_states: tuple[Tensor], layer_pred: Optional[int]
230
+ # ) -> tuple[tuple[Tensor], Tensor, Tensor, Tensor, Tensor]:
231
+ # if layer_pred is not None:
232
+ # logits = self.g_hat[layer_pred](hidden_states[layer_pred])
233
+ # else:
234
+ # logits = torch.cat(
235
+ # [self.g_hat[i](hidden_states[i]) for i in range(len(hidden_states))], -1
236
+ # )
237
+ #
238
+ # dist = RectifiedStreched(
239
+ # BinaryConcrete(torch.full_like(logits, 0.2), logits),
240
+ # l=-0.2,
241
+ # r=1.0,
242
+ # )
243
+ #
244
+ # gates_full = dist.rsample()
245
+ # expected_L0_full = dist.log_expected_L0()
246
+ #
247
+ # gates = gates_full if layer_pred is not None else gates_full[..., :1]
248
+ # expected_L0 = (
249
+ # expected_L0_full if layer_pred is not None else expected_L0_full[..., :1]
250
+ # )
251
+ #
252
+ # layer_pred = layer_pred or 0 # equiv to "layer_pred if layer_pred else 0"
253
+ # return (
254
+ # hidden_states[layer_pred] * gates
255
+ # + self.placeholder[layer_pred][:, : hidden_states[layer_pred].shape[-2]]
256
+ # * (1 - gates),
257
+ # gates.squeeze(-1),
258
+ # expected_L0.squeeze(-1),
259
+ # gates_full,
260
+ # expected_L0_full,
261
+ # )
code/models/interpretation.py ADDED
@@ -0,0 +1,482 @@
1
+ import pytorch_lightning as pl
2
+ import torch
3
+ import torch.nn.functional as F
4
+
5
+ from .gates import DiffMaskGateInput
6
+ from argparse import ArgumentParser
7
+ from math import sqrt
8
+ from pytorch_lightning.core.optimizer import LightningOptimizer
9
+ from torch import Tensor
10
+ from torch.optim import Optimizer
11
+ from torch.optim.lr_scheduler import _LRScheduler
12
+ from transformers import (
13
+ get_constant_schedule_with_warmup,
14
+ get_constant_schedule,
15
+ ViTForImageClassification,
16
+ )
17
+ from transformers.models.vit.configuration_vit import ViTConfig
18
+ from typing import Optional, Union
19
+ from utils.getters_setters import vit_getter, vit_setter
20
+ from utils.metrics import accuracy_precision_recall_f1
21
+ from utils.optimizer import LookaheadAdam
22
+
23
+
24
+ class ImageInterpretationNet(pl.LightningModule):
25
+ @staticmethod
26
+ def add_model_specific_args(parent_parser: ArgumentParser) -> ArgumentParser:
27
+ parser = parent_parser.add_argument_group("Vision DiffMask")
28
+ parser.add_argument(
29
+ "--alpha",
30
+ type=float,
31
+ default=20.0,
32
+ help="Initial value for the Lagrangian",
33
+ )
34
+ parser.add_argument(
35
+ "--lr",
36
+ type=float,
37
+ default=2e-5,
38
+ help="Learning rate for DiffMask.",
39
+ )
40
+ parser.add_argument(
41
+ "--eps",
42
+ type=float,
43
+ default=0.1,
44
+ help="KL divergence tolerance.",
45
+ )
46
+ parser.add_argument(
47
+ "--no_placeholder",
48
+ action="store_true",
49
+ help="Whether to not use placeholder",
50
+ )
51
+ parser.add_argument(
52
+ "--lr_placeholder",
53
+ type=float,
54
+ default=1e-3,
55
+ help="Learning for mask vectors.",
56
+ )
57
+ parser.add_argument(
58
+ "--lr_alpha",
59
+ type=float,
60
+ default=0.3,
61
+ help="Learning rate for lagrangian optimizer.",
62
+ )
63
+ parser.add_argument(
64
+ "--mul_activation",
65
+ type=float,
66
+ default=15.0,
67
+ help="Value to multiply gate activations.",
68
+ )
69
+ parser.add_argument(
70
+ "--add_activation",
71
+ type=float,
72
+ default=8.0,
73
+ help="Value to add to gate activations.",
74
+ )
75
+ parser.add_argument(
76
+ "--weighted_layer_distribution",
77
+ action="store_true",
78
+ help="Whether to use a weighted distribution when picking a layer in DiffMask forward.",
79
+ )
80
+ return parent_parser
81
+
82
+ # Declare variables that will be initialized later
83
+ model: ViTForImageClassification
84
+
85
+ def __init__(
86
+ self,
87
+ model_cfg: ViTConfig,
88
+ alpha: float = 1,
89
+ lr: float = 3e-4,
90
+ eps: float = 0.1,
91
+ eps_valid: float = 0.8,
92
+ acc_valid: float = 0.75,
93
+ lr_placeholder: float = 1e-3,
94
+ lr_alpha: float = 0.3,
95
+ mul_activation: float = 10.0,
96
+ add_activation: float = 5.0,
97
+ placeholder: bool = True,
98
+ weighted_layer_pred: bool = False,
99
+ ):
100
+ """A PyTorch Lightning Module for the VisionDiffMask model on the Vision Transformer.
101
+
102
+ Args:
103
+ model_cfg (ViTConfig): the configuration of the Vision Transformer model
104
+ alpha (float): the initial value for the Lagrangian
105
+ lr (float): the learning rate for the DiffMask gates
106
+ eps (float): the tolerance for the KL divergence
107
+ eps_valid (float): the tolerance for the KL divergence in the validation step
108
+ acc_valid (float): the accuracy threshold for the validation step
109
+ lr_placeholder (float): the learning rate for the learnable masking embeddings
110
+ lr_alpha (float): the learning rate for the Lagrangian
111
+ mul_activation (float): the value to multiply the gate activations by
112
+ add_activation (float): the value to add to the gate activations
113
+ placeholder (bool): whether to use placeholder embeddings or a zero vector
114
+ weighted_layer_pred (bool): whether to use a weighted distribution when picking a layer
115
+ """
116
+ super().__init__()
117
+
118
+ # Save the hyperparameters
119
+ self.save_hyperparameters()
120
+
121
+ # Create DiffMask instance
122
+ self.gate = DiffMaskGateInput(
123
+ hidden_size=model_cfg.hidden_size,
124
+ hidden_attention=model_cfg.hidden_size // 4,
125
+ num_hidden_layers=model_cfg.num_hidden_layers + 2,
126
+ max_position_embeddings=1,
127
+ mul_activation=mul_activation,
128
+ add_activation=add_activation,
129
+ placeholder=placeholder,
130
+ )
131
+
132
+ # Create the Lagrangian values for the dual optimization
133
+ self.alpha = torch.nn.ParameterList(
134
+ [
135
+ torch.nn.Parameter(torch.ones(()) * alpha)
136
+ for _ in range(model_cfg.num_hidden_layers + 2)
137
+ ]
138
+ )
139
+
140
+ # Register buffers for running metrics
141
+ self.register_buffer(
142
+ "running_acc", torch.ones((model_cfg.num_hidden_layers + 2,))
143
+ )
144
+ self.register_buffer(
145
+ "running_l0", torch.ones((model_cfg.num_hidden_layers + 2,))
146
+ )
147
+ self.register_buffer(
148
+ "running_steps", torch.zeros((model_cfg.num_hidden_layers + 2,))
149
+ )
150
+
151
+ def set_vision_transformer(self, model: ViTForImageClassification):
152
+ """Set the Vision Transformer model to be used with this module."""
153
+ # Save the model instance as a class attribute
154
+ self.model = model
155
+ # Freeze the model's parameters
156
+ for param in self.model.parameters():
157
+ param.requires_grad = False
158
+
159
+ def forward_explainer(
160
+ self, x: Tensor, attribution: bool = False
161
+ ) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, int, int]:
162
+ """Performs a forward pass through the explainer (VisionDiffMask) model."""
163
+ # Get the original logits and hidden states from the model
164
+ logits_orig, hidden_states = vit_getter(self.model, x)
165
+
166
+ # Add [CLS] token to deal with shape mismatch in self.gate() call
167
+ patch_embeddings = hidden_states[0]
168
+ batch_size = len(patch_embeddings)
169
+ cls_tokens = self.model.vit.embeddings.cls_token.expand(batch_size, -1, -1)
170
+ hidden_states[0] = torch.cat((cls_tokens, patch_embeddings), dim=1)
171
+
172
+ # Select the layer to generate the mask from in this pass
173
+ n_hidden = len(hidden_states)
174
+ if self.hparams.weighted_layer_pred:
175
+ # If weighted layer prediction is enabled, use a weighted distribution
176
+ # instead of uniformly picking a layer after a number of steps
177
+ low_weight = (
178
+ lambda i: self.running_acc[i] > 0.75
179
+ and self.running_l0[i] < 0.1
180
+ and self.running_steps[i] > 100
181
+ )
182
+ layers = torch.tensor(list(range(n_hidden)))
183
+ p = torch.tensor([0.1 if low_weight(i) else 1 for i in range(n_hidden)])
184
+ p = p / p.sum()
185
+ idx = p.multinomial(num_samples=1)
186
+ layer_pred = layers[idx].item()
187
+ else:
188
+ layer_pred = torch.randint(n_hidden, ()).item()
189
+
190
+ # Set the layer to drop to 0, since we are only interested in masking the input
191
+ layer_drop = 0
192
+
193
+ (
194
+ new_hidden_state,
195
+ gates,
196
+ expected_L0,
197
+ gates_full,
198
+ expected_L0_full,
199
+ ) = self.gate(
200
+ hidden_states=hidden_states,
201
+ layer_pred=None
202
+ if attribution
203
+ else layer_pred, # if attribution, we get all the hidden states
204
+ )
205
+
206
+ # Create the list of the new hidden states for the new forward pass
207
+ new_hidden_states = (
208
+ [None] * layer_drop
209
+ + [new_hidden_state]
210
+ + [None] * (n_hidden - layer_drop - 1)
211
+ )
212
+
213
+ # Get the new logits from the masked input
214
+ logits, _ = vit_setter(self.model, x, new_hidden_states)
215
+
216
+ return (
217
+ logits,
218
+ logits_orig,
219
+ gates,
220
+ expected_L0,
221
+ gates_full,
222
+ expected_L0_full,
223
+ layer_drop,
224
+ layer_pred,
225
+ )
226
+
227
+ def get_mask(self, x: Tensor,
228
+ idx: int = -1,
229
+ aggregated_mask: bool = True,
230
+ ) -> dict[str, Tensor]:
231
+ """
232
+ Generates a mask for the given input.
233
+ Args:
234
+ x: the input to generate the mask for
235
+ idx: the index of the layer to generate the mask from
236
+ aggregated_mask: whether to use an aggregative mask from each layer
237
+ Returns:
238
+ a dictionary containing the mask, kl divergence and the predicted class
239
+ """
240
+
241
+ # Pass from forward explainer with attribution=True
242
+ (
243
+ logits,
244
+ logits_orig,
245
+ gates,
246
+ expected_L0,
247
+ gates_full,
248
+ expected_L0_full,
249
+ layer_drop,
250
+ layer_pred,
251
+ ) = self.forward_explainer(x, attribution=True)
252
+
253
+ # Calculate KL-divergence
254
+ kl_div = torch.distributions.kl_divergence(
255
+ torch.distributions.Categorical(logits=logits_orig),
256
+ torch.distributions.Categorical(logits=logits),
257
+ )
258
+
259
+ # Get predicted class
260
+ pred_class = logits.argmax(-1)
261
+
262
+ # Calculate mask
263
+ if aggregated_mask:
264
+ mask = expected_L0_full[:, :, idx].exp()
265
+ else:
266
+ mask = gates_full[:, :, idx]
267
+
268
+ mask = mask[:, 1:]
269
+
270
+ C, H, W = x.shape[1:] # channels, height, width
271
+ B, P = mask.shape # batch, patches
272
+ N = int(sqrt(P)) # patches per side
273
+ S = int(H / N) # patch size
274
+
275
+ # Reshape mask to match input shape
276
+ mask = mask.reshape(B, 1, N, N)
277
+ mask = F.interpolate(mask, scale_factor=S)
278
+ mask = mask.reshape(B, H, W)
279
+
280
+ return {"mask": mask, "kl_div": kl_div, "pred_class": pred_class}
281
+
282
+ def forward(self, x: Tensor) -> Tensor:
283
+ return self.model(x).logits
284
+
285
+ def training_step(self, batch: tuple[Tensor, Tensor], *args, **kwargs) -> dict:
286
+ # Unpack the batch
287
+ x, y = batch
288
+
289
+ # Pass the batch through the explainer (VisionDiffMask) model
290
+ (
291
+ logits,
292
+ logits_orig,
293
+ gates,
294
+ expected_L0,
295
+ gates_full,
296
+ expected_L0_full,
297
+ layer_drop,
298
+ layer_pred,
299
+ ) = self.forward_explainer(x)
300
+
301
+ # Calculate the KL-divergence loss term
302
+ loss_c = (
303
+ torch.distributions.kl_divergence(
304
+ torch.distributions.Categorical(logits=logits_orig),
305
+ torch.distributions.Categorical(logits=logits),
306
+ )
307
+ - self.hparams.eps
308
+ )
309
+
310
+ # Calculate the L0 loss term
311
+ loss_g = expected_L0.mean(-1)
312
+
313
+ # Calculate the full loss term
314
+ loss = self.alpha[layer_pred] * loss_c + loss_g
315
+
316
+ # Calculate the accuracy
317
+ acc, _, _, _ = accuracy_precision_recall_f1(
318
+ logits.argmax(-1), logits_orig.argmax(-1), average=True
319
+ )
320
+
321
+ # Calculate the average L0 loss
322
+ l0 = expected_L0.exp().mean(-1)
323
+
324
+ outputs_dict = {
325
+ "loss_c": loss_c.mean(-1),
326
+ "loss_g": loss_g.mean(-1),
327
+ "alpha": self.alpha[layer_pred].mean(-1),
328
+ "acc": acc,
329
+ "l0": l0.mean(-1),
330
+ "layer_pred": layer_pred,
331
+ "r_acc": self.running_acc[layer_pred],
332
+ "r_l0": self.running_l0[layer_pred],
333
+ "r_steps": self.running_steps[layer_pred],
334
+ "debug_loss": loss.mean(-1),
335
+ }
336
+
337
+ outputs_dict = {
338
+ "loss": loss.mean(-1),
339
+ **outputs_dict,
340
+ "log": outputs_dict,
341
+ "progress_bar": outputs_dict,
342
+ }
343
+
344
+ self.log(
345
+ "loss", outputs_dict["loss"], on_step=True, on_epoch=True, prog_bar=True
346
+ )
347
+ self.log(
348
+ "loss_c", outputs_dict["loss_c"], on_step=True, on_epoch=True, prog_bar=True
349
+ )
350
+ self.log(
351
+ "loss_g", outputs_dict["loss_g"], on_step=True, on_epoch=True, prog_bar=True
352
+ )
353
+ self.log("acc", outputs_dict["acc"], on_step=True, on_epoch=True, prog_bar=True)
354
+ self.log("l0", outputs_dict["l0"], on_step=True, on_epoch=True, prog_bar=True)
355
+ self.log(
356
+ "alpha", outputs_dict["alpha"], on_step=True, on_epoch=True, prog_bar=True
357
+ )
358
+
359
+ outputs_dict = {
360
+ "{}{}".format("" if self.training else "val_", k): v
361
+ for k, v in outputs_dict.items()
362
+ }
363
+
364
+ if self.training:
365
+ self.running_acc[layer_pred] = (
366
+ self.running_acc[layer_pred] * 0.9 + acc * 0.1
367
+ )
368
+ self.running_l0[layer_pred] = (
369
+ self.running_l0[layer_pred] * 0.9 + l0.mean(-1) * 0.1
370
+ )
371
+ self.running_steps[layer_pred] += 1
372
+
373
+ return outputs_dict
374
+
375
+ def validation_epoch_end(self, outputs: list[dict]):
376
+ outputs_dict = {
377
+ k: [e[k] for e in outputs if k in e]
378
+ for k in ("val_loss_c", "val_loss_g", "val_acc", "val_l0")
379
+ }
380
+
381
+ outputs_dict = {k: sum(v) / len(v) for k, v in outputs_dict.items()}
382
+
383
+ outputs_dict["val_loss_c"] += self.hparams.eps
384
+
385
+ outputs_dict = {
386
+ "val_loss": outputs_dict["val_l0"]
387
+ if outputs_dict["val_loss_c"] <= self.hparams.eps_valid
388
+ and outputs_dict["val_acc"] >= self.hparams.acc_valid
389
+ else torch.full_like(outputs_dict["val_l0"], float("inf")),
390
+ **outputs_dict,
391
+ "log": outputs_dict,
392
+ }
393
+
394
+ return outputs_dict
395
+
396
+ def configure_optimizers(self) -> tuple[list[Optimizer], list[_LRScheduler]]:
397
+ optimizers = [
398
+ LookaheadAdam(
399
+ params=[
400
+ {
401
+ "params": self.gate.g_hat.parameters(),
402
+ "lr": self.hparams.lr,
403
+ },
404
+ {
405
+ "params": self.gate.placeholder.parameters()
406
+ if isinstance(self.gate.placeholder, torch.nn.ParameterList)
407
+ else [self.gate.placeholder],
408
+ "lr": self.hparams.lr_placeholder,
409
+ },
410
+ ],
411
+ # centered=True, # this is for LookaheadRMSprop
412
+ ),
413
+ LookaheadAdam(
414
+ params=[self.alpha]
415
+ if isinstance(self.alpha, torch.Tensor)
416
+ else self.alpha.parameters(),
417
+ lr=self.hparams.lr_alpha,
418
+ ),
419
+ ]
420
+
421
+ schedulers = [
422
+ {
423
+ "scheduler": get_constant_schedule_with_warmup(optimizers[0], 12 * 100),
424
+ "interval": "step",
425
+ },
426
+ get_constant_schedule(optimizers[1]),
427
+ ]
428
+ return optimizers, schedulers
429
+
430
+ def optimizer_step(
431
+ self,
432
+ epoch: int,
433
+ batch_idx: int,
434
+ optimizer: Union[Optimizer, LightningOptimizer],
435
+ optimizer_idx: int = 0,
436
+ optimizer_closure: Optional[callable] = None,
437
+ on_tpu: bool = False,
438
+ using_native_amp: bool = False,
439
+ using_lbfgs: bool = False,
440
+ ):
441
+ # Optimizer 0: Minimize loss w.r.t. DiffMask's parameters
442
+ if optimizer_idx == 0:
443
+ # Gradient descent on DiffMask's parameters
444
+ optimizer.step(closure=optimizer_closure)
445
+ optimizer.zero_grad()
446
+ for g in optimizer.param_groups:
447
+ for p in g["params"]:
448
+ p.grad = None
449
+
450
+ # Optimizer 1: Maximize loss w.r.t. the Lagrangian
451
+ elif optimizer_idx == 1:
452
+ # Reverse the sign of the Lagrangian's gradients
453
+ for i in range(len(self.alpha)):
454
+ if self.alpha[i].grad:
455
+ self.alpha[i].grad *= -1
456
+
457
+ # Gradient ascent on the Lagrangian
458
+ optimizer.step(closure=optimizer_closure)
459
+ optimizer.zero_grad()
460
+ for g in optimizer.param_groups:
461
+ for p in g["params"]:
462
+ p.grad = None
463
+
464
+ # Clip the Lagrangian's values
465
+ for i in range(len(self.alpha)):
466
+ self.alpha[i].data = torch.where(
467
+ self.alpha[i].data < 0,
468
+ torch.full_like(self.alpha[i].data, 0),
469
+ self.alpha[i].data,
470
+ )
471
+ self.alpha[i].data = torch.where(
472
+ self.alpha[i].data > 200,
473
+ torch.full_like(self.alpha[i].data, 200),
474
+ self.alpha[i].data,
475
+ )
476
+
477
+ def on_save_checkpoint(self, ckpt: dict):
478
+ # Remove the ViT from the checkpoint, since it can be re-attached dynamically at load time
479
+ keys = list(ckpt["state_dict"].keys())
480
+ for key in keys:
481
+ if key.startswith("model."):
482
+ del ckpt["state_dict"][key]
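Note on the checkpoint hook above: because on_save_checkpoint strips the frozen ViT weights, a reloaded DiffMask checkpoint needs the backbone re-attached via set_vision_transformer before get_mask can be used. A minimal sketch of that round trip follows; the import path, checkpoint path and ViT weights name are placeholder assumptions, not values from this diff.

# illustrative sketch -- paths and model names are assumptions
from models.interpretation import ImageInterpretationNet  # assumes ./code is on sys.path
from transformers import ViTForImageClassification
import torch

vit = ViTForImageClassification.from_pretrained("<vit-checkpoint-name>")
diffmask = ImageInterpretationNet.load_from_checkpoint("<path/to/diffmask.ckpt>")
diffmask.set_vision_transformer(vit)        # re-attach the frozen backbone
out = diffmask.get_mask(x)                  # x: a preprocessed (B, 3, 224, 224) batch
mask, kl_div = out["mask"], out["kl_div"]   # per-pixel mask in [0, 1] and the KL to the original prediction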
code/models/utils.py ADDED
@@ -0,0 +1,64 @@
1
+ from datamodules.utils import get_configs
2
+ from transformers import (
3
+ ConvNextConfig,
4
+ ConvNextForImageClassification,
5
+ PreTrainedModel,
6
+ ViTConfig,
7
+ ViTForImageClassification,
8
+ )
9
+
10
+ import argparse
11
+ import torch
12
+
13
+
14
+ def set_clf_head(base: PreTrainedModel, num_classes: int):
15
+ """Set the classification head of the model in case of an output mismatch.
16
+
17
+ Args:
18
+ base (PreTrainedModel): the model to modify
19
+ num_classes (int): the number of classes to use for the output layer
20
+ """
21
+ if base.classifier.out_features != num_classes:
22
+ in_features = base.classifier.in_features
23
+ base.classifier = torch.nn.Linear(in_features, num_classes)
24
+
25
+
26
+ def model_factory(
27
+ args: argparse.Namespace,
28
+ own_config: bool = False,
29
+ ) -> PreTrainedModel:
30
+ """A factory method for creating a HuggingFace model based on the command line args.
31
+
32
+ Args:
33
+ args (Namespace): the argparse Namespace object
34
+ own_config (bool): whether to create our own model config instead of a pretrained one;
35
+ this is recommended when the model was pre-trained on another task with a different
36
+ amount of classes for its classifier head
37
+
38
+ Returns:
39
+ a PreTrainedModel instance
40
+ """
41
+ if args.base_model == "ViT":
42
+ # Create a new Vision Transformer
43
+ config_class = ViTConfig
44
+ base_class = ViTForImageClassification
45
+ elif args.base_model == "ConvNeXt":
46
+ # Create a new ConvNext model
47
+ config_class = ConvNextConfig
48
+ base_class = ConvNextForImageClassification
49
+ else:
50
+ raise Exception(f"Unknown base model: {args.base_model}")
51
+
52
+ # Get the model config
53
+ model_cfg_args, _ = get_configs(args)
54
+ if not own_config and args.from_pretrained:
55
+ # Create a model from a pretrained model
56
+ base = base_class.from_pretrained(args.from_pretrained)
57
+ # Set the classifier head if needed
58
+ set_clf_head(base, model_cfg_args["num_labels"])
59
+ else:
60
+ # Create a model based on the config
61
+ config = config_class(**model_cfg_args)
62
+ base = base_class(config)
63
+
64
+ return base
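As a concrete illustration of set_clf_head, loading a backbone that was pre-trained with a different number of labels and swapping in a fresh head might look like the sketch below; the checkpoint name and class count are assumptions for illustration only.

# illustrative sketch -- checkpoint name and label count are assumptions
from transformers import ViTForImageClassification

base = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")  # ships with a 1000-way head
set_clf_head(base, num_classes=10)          # replace the classifier with a freshly initialised 10-way linear layer
assert base.classifier.out_features == 10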
code/train_base.py ADDED
@@ -0,0 +1,123 @@
1
+ import argparse
2
+ import pytorch_lightning as pl
3
+
4
+ from datamodules import CIFAR10QADataModule, ImageDataModule
5
+ from datamodules.utils import datamodule_factory
6
+ from models import ImageClassificationNet
7
+ from models.utils import model_factory
8
+ from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
9
+ from pytorch_lightning.loggers import WandbLogger
10
+
11
+
12
+ def main(args: argparse.Namespace):
13
+ # Seed
14
+ pl.seed_everything(args.seed)
15
+
16
+ # Create base model
17
+ base = model_factory(args)
18
+
19
+ # Load datamodule
20
+ dm = datamodule_factory(args)
21
+ dm.prepare_data()
22
+ dm.setup("fit")
23
+
24
+ if args.checkpoint:
25
+ # Load the model from the specified checkpoint
26
+ model = ImageClassificationNet.load_from_checkpoint(args.checkpoint, model=base)
27
+ else:
28
+ # Create a new instance of the classification model
29
+ model = ImageClassificationNet(
30
+ model=base,
31
+ num_train_steps=args.num_epochs * len(dm.train_dataloader()),
32
+ optimizer=args.optimizer,
33
+ weight_decay=args.weight_decay,
34
+ lr=args.lr,
35
+ )
36
+
37
+ # Create wandb logger
38
+ wandb_logger = WandbLogger(
39
+ name=f"{args.dataset}_training_{args.base_model} ({args.from_pretrained})",
40
+ project="Patch-DiffMask",
41
+ )
42
+
43
+ # Create checkpoint callback
44
+ ckpt_cb = ModelCheckpoint(dirpath=f"checkpoints/{wandb_logger.version}")
45
+ # Create early stopping callback
46
+ es_cb = EarlyStopping(monitor="val_acc", mode="max", patience=5)
47
+
48
+ # Create trainer
49
+ trainer = pl.Trainer(
50
+ accelerator="auto",
51
+ callbacks=[ckpt_cb, es_cb],
52
+ logger=wandb_logger,
53
+ max_epochs=args.num_epochs,
54
+ enable_progress_bar=args.enable_progress_bar,
55
+ )
56
+
57
+ trainer_args = {}
58
+ if args.checkpoint:
59
+ # Resume trainer from checkpoint
60
+ trainer_args["ckpt_path"] = args.checkpoint
61
+
62
+ # Train the model
63
+ trainer.fit(model, dm, **trainer_args)
64
+
65
+
66
+ if __name__ == "__main__":
67
+ parser = argparse.ArgumentParser()
68
+
69
+ parser.add_argument(
70
+ "--checkpoint",
71
+ type=str,
72
+ help="Checkpoint to resume the training from.",
73
+ )
74
+
75
+ # Trainer
76
+ parser.add_argument(
77
+ "--enable_progress_bar",
78
+ action="store_true",
79
+ help="Whether to show progress bar during training. NOT recommended when logging to files.",
80
+ )
81
+ parser.add_argument(
82
+ "--num_epochs",
83
+ type=int,
84
+ default=5,
85
+ help="Number of epochs to train.",
86
+ )
87
+ parser.add_argument(
88
+ "--seed",
89
+ type=int,
90
+ default=123,
91
+ help="Random seed for reproducibility.",
92
+ )
93
+
94
+ # Base (classification) model
95
+ ImageClassificationNet.add_model_specific_args(parser)
96
+ parser.add_argument(
97
+ "--base_model",
98
+ type=str,
99
+ default="ViT",
100
+ choices=["ViT", "ConvNeXt"],
101
+ help="Base model architecture to train.",
102
+ )
103
+ parser.add_argument(
104
+ "--from_pretrained",
105
+ type=str,
106
+ # default="tanlq/vit-base-patch16-224-in21k-finetuned-cifar10",
107
+ help="The name of the pretrained HF model to fine-tune from.",
108
+ )
109
+
110
+ # Datamodule
111
+ ImageDataModule.add_model_specific_args(parser)
112
+ CIFAR10QADataModule.add_model_specific_args(parser)
113
+ parser.add_argument(
114
+ "--dataset",
115
+ type=str,
116
+ default="toy",
117
+ choices=["MNIST", "CIFAR10", "CIFAR10_QA", "toy"],
118
+ help="The dataset to use.",
119
+ )
120
+
121
+ args = parser.parse_args()
122
+
123
+ main(args)
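Put together, the script above can be launched directly; for example, an invocation along the lines of python code/train_base.py --dataset CIFAR10 --base_model ViT --num_epochs 5 --enable_progress_bar would train a ViT classifier on CIFAR-10 and log it to the Patch-DiffMask W&B project (assuming the remaining optimizer and datamodule flags registered via add_model_specific_args all have defaults). Passing --from_pretrained with a HuggingFace checkpoint name fine-tunes that checkpoint instead of training from scratch.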
code/utils/__init__.py ADDED
File without changes
code/utils/distributions.py ADDED
@@ -0,0 +1,64 @@
1
+ """
2
+ File copied from
3
+ https://github.com/nicola-decao/diffmask/blob/master/diffmask/models/distributions.py
4
+ """
5
+
6
+ import torch
7
+ import torch.distributions as distr
8
+ import torch.nn.functional as F
9
+
10
+ from torch import Tensor
11
+
12
+
13
+ class BinaryConcrete(distr.relaxed_bernoulli.RelaxedBernoulli):
14
+ def __init__(self, temperature: Tensor, logits: Tensor):
15
+ super().__init__(temperature=temperature, logits=logits)
16
+ self.device = self.temperature.device
17
+
18
+ def cdf(self, value: Tensor) -> Tensor:
19
+ return torch.sigmoid(
20
+ (torch.log(value) - torch.log(1.0 - value)) * self.temperature - self.logits
21
+ )
22
+
23
+ def log_prob(self, value: Tensor) -> Tensor:
24
+ return torch.where(
25
+ (value > 0) & (value < 1),
26
+ super().log_prob(value),
27
+ torch.full_like(value, -float("inf")),
28
+ )
29
+
30
+ def log_expected_L0(self, value: Tensor) -> Tensor:
31
+ return -F.softplus(
32
+ (torch.log(value) - torch.log(1 - value)) * self.temperature - self.logits
33
+ )
34
+
35
+
36
+ class Streched(distr.TransformedDistribution):
37
+ def __init__(self, base_dist, l: float = -0.1, r: float = 1.1):
38
+ super().__init__(base_dist, distr.AffineTransform(loc=l, scale=r - l))
39
+
40
+ def log_expected_L0(self) -> Tensor:
41
+ value = torch.tensor(0.0, device=self.base_dist.device)
42
+ for transform in self.transforms[::-1]:
43
+ value = transform.inv(value)
44
+ if self._validate_args:
45
+ self.base_dist._validate_sample(value)
46
+ value = self.base_dist.log_expected_L0(value)
47
+ value = self._monotonize_cdf(value)
48
+ return value
49
+
50
+ def expected_L0(self) -> Tensor:
51
+ return self.log_expected_L0().exp()
52
+
53
+
54
+ class RectifiedStreched(Streched):
55
+ def __init__(self, *args, **kwargs):
56
+ super().__init__(*args, **kwargs)
57
+
58
+ @torch.no_grad()
59
+ def sample(self, sample_shape: torch.Size = torch.Size([])) -> Tensor:
60
+ return self.rsample(sample_shape)
61
+
62
+ def rsample(self, sample_shape: torch.Size = torch.Size([])) -> Tensor:
63
+ x = super().rsample(sample_shape)
64
+ return x.clamp(0, 1)
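These two classes implement the rectified, stretched Binary Concrete ("hard concrete") distribution that DiffMask samples its gates from. The sketch below shows roughly how they compose; the temperature, stretch bounds and tensor shapes are illustrative assumptions rather than values taken from this repository.

# illustrative sketch -- temperature, bounds and shapes are assumptions
import torch
from utils.distributions import BinaryConcrete, RectifiedStreched  # assumes ./code is on sys.path

logits = torch.randn(4, 197)    # one gate logit per token
dist = RectifiedStreched(BinaryConcrete(torch.tensor([0.2]), logits), l=-0.2, r=1.0)
gates = dist.rsample()          # differentiable samples in [0, 1], with point masses at exactly 0 and 1
exp_l0 = dist.expected_L0()     # probability that each gate is non-zero, used as the sparsity penalty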
code/utils/getters_setters.py ADDED
@@ -0,0 +1,122 @@
1
+ from torch import Tensor
2
+ from torch.nn import Module
3
+ from torch.utils.hooks import RemovableHandle
4
+ from transformers import ViTForImageClassification
5
+ from typing import Optional, Union
6
+
7
+
8
+ def _add_hooks(
9
+ model: ViTForImageClassification, get_hook: callable
10
+ ) -> list[RemovableHandle]:
11
+ """Adds a list of hooks to the model according to the get_hook function provided.
12
+
13
+ Args:
14
+ model (ViTForImageClassification): the ViT instance to add hooks to
15
+ get_hook (callable): a function that takes an index and returns a hook
16
+
17
+ Returns:
18
+ a list of RemovableHandle instances
19
+ """
20
+ return (
21
+ [model.vit.embeddings.patch_embeddings.register_forward_hook(get_hook(0))]
22
+ + [
23
+ layer.register_forward_pre_hook(get_hook(i + 1))
24
+ for i, layer in enumerate(model.vit.encoder.layer)
25
+ ]
26
+ + [
27
+ model.vit.encoder.layer[-1].register_forward_hook(
28
+ get_hook(len(model.vit.encoder.layer) + 1)
29
+ )
30
+ ]
31
+ )
32
+
33
+
34
+ def vit_getter(
35
+ model: ViTForImageClassification, x: Tensor
36
+ ) -> tuple[Tensor, list[Tensor]]:
37
+ """A function that returns the logits and hidden states of the model.
38
+
39
+ Args:
40
+ model (ViTForImageClassification): the ViT instance to use for the forward pass
41
+ x (Tensor): the input to the model
42
+
43
+ Returns:
44
+ a tuple of the model's logits and hidden states
45
+ """
46
+ hidden_states_ = []
47
+
48
+ def get_hook(i: int) -> callable:
49
+ def hook(_: Module, inputs: tuple, outputs: Optional[tuple] = None):
50
+ if i == 0:
51
+ hidden_states_.append(outputs)
52
+ elif 1 <= i <= len(model.vit.encoder.layer):
53
+ hidden_states_.append(inputs[0])
54
+ elif i == len(model.vit.encoder.layer) + 1:
55
+ hidden_states_.append(outputs[0])
56
+
57
+ return hook
58
+
59
+ handles = _add_hooks(model, get_hook)
60
+ try:
61
+ logits = model(x).logits
62
+ finally:
63
+ for handle in handles:
64
+ handle.remove()
65
+
66
+ return logits, hidden_states_
67
+
68
+
69
+ def vit_setter(
70
+ model: ViTForImageClassification, x: Tensor, hidden_states: list[Optional[Tensor]]
71
+ ) -> tuple[Tensor, list[Tensor]]:
72
+ """A function that sets some of the model's hidden states and returns its (new) logits
73
+ and hidden states after another forward pass.
74
+
75
+ Args:
76
+ model (ViTForImageClassification): the ViT instance to use for the forward pass
77
+ x (Tensor): the input to the model
78
+ hidden_states (list[Optional[Tensor]]): a list, with each element corresponding to
79
+ a hidden state to set or None to calculate anew for that index
80
+
81
+ Returns:
82
+ a tuple of the model's logits and (new) hidden states
83
+ """
84
+ hidden_states_ = []
85
+
86
+ def get_hook(i: int) -> callable:
87
+ def hook(
88
+ _: Module, inputs: tuple, outputs: Optional[tuple] = None
89
+ ) -> Optional[Union[tuple, Tensor]]:
90
+ if i == 0:
91
+ if hidden_states[i] is not None:
92
+ # print(hidden_states[i].shape)
93
+ hidden_states_.append(hidden_states[i][:, 1:])
94
+ return hidden_states_[-1]
95
+ else:
96
+ hidden_states_.append(outputs)
97
+
98
+ elif 1 <= i <= len(model.vit.encoder.layer):
99
+ if hidden_states[i] is not None:
100
+ hidden_states_.append(hidden_states[i])
101
+ return (hidden_states[i],) + inputs[1:]
102
+ else:
103
+ hidden_states_.append(inputs[0])
104
+
105
+ elif i == len(model.vit.encoder.layer) + 1:
106
+ if hidden_states[i] is not None:
107
+ hidden_states_.append(hidden_states[i])
108
+ return (hidden_states[i],) + outputs[1:]
109
+ else:
110
+ hidden_states_.append(outputs[0])
111
+
112
+ return hook
113
+
114
+ handles = _add_hooks(model, get_hook)
115
+
116
+ try:
117
+ logits = model(x).logits
118
+ finally:
119
+ for handle in handles:
120
+ handle.remove()
121
+
122
+ return logits, hidden_states_
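In short, vit_getter records every intermediate hidden state during an ordinary forward pass, and vit_setter re-runs the model with some of those states substituted. A small sketch of the round trip follows; the checkpoint name and input shape are assumptions for illustration.

# illustrative sketch -- checkpoint name and shapes are assumptions
import torch
from transformers import ViTForImageClassification

vit = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")
x = torch.rand(1, 3, 224, 224)
logits, hidden = vit_getter(vit, x)         # hidden[0]: patch embeddings, hidden[i]: input to encoder layer i
states = [None] * len(hidden)               # None means "recompute this hidden state as usual"
states[1] = torch.zeros_like(hidden[1])     # e.g. feed a blanked-out input to the first encoder layer
new_logits, _ = vit_setter(vit, x, states)  # logits obtained after the substitution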
code/utils/metrics.py ADDED
@@ -0,0 +1,67 @@
1
+ """
2
+ File copied from
3
+ https://github.com/nicola-decao/diffmask/blob/master/diffmask/utils/util.py
4
+ """
5
+
6
+ import torch
7
+
8
+ from torch import Tensor
9
+
10
+
11
+ def accuracy_precision_recall_f1(
12
+ y_pred: Tensor, y_true: Tensor, average: bool = True
13
+ ) -> tuple[Tensor, Tensor, Tensor, Tensor]:
14
+ """Calculates the accuracy, precision, recall and f1 score given the predicted and true labels.
15
+
16
+ Args:
17
+ y_pred (Tensor): predicted labels
18
+ y_true (Tensor): true labels
19
+ average (bool): whether to average the scores or not
20
+
21
+ Returns:
22
+ a tuple of the accuracy, precision, recall and f1 score
23
+ """
24
+ M = confusion_matrix(y_pred, y_true)
25
+
26
+ tp = M.diagonal(dim1=-2, dim2=-1).float()
27
+
28
+ precision_den = M.sum(-2)
29
+ precision = torch.where(
30
+ precision_den == 0, torch.zeros_like(tp), tp / precision_den
31
+ )
32
+
33
+ recall_den = M.sum(-1)
34
+ recall = torch.where(recall_den == 0, torch.ones_like(tp), tp / recall_den)
35
+
36
+ f1_den = precision + recall
37
+ f1 = torch.where(
38
+ f1_den == 0, torch.zeros_like(tp), 2 * (precision * recall) / f1_den
39
+ )
40
+
41
+ # noinspection PyTypeChecker
42
+ return ((y_pred == y_true).float().mean(-1),) + (
43
+ tuple(e.mean(-1) for e in (precision, recall, f1))
44
+ if average
45
+ else (precision, recall, f1)
46
+ )
47
+
48
+
49
+ def confusion_matrix(y_pred: Tensor, y_true: Tensor) -> Tensor:
50
+ """Creates a confusion matrix given the predicted and true labels."""
51
+ device = y_pred.device
52
+ labels = max(y_pred.max().item() + 1, y_true.max().item() + 1)
53
+
54
+ return (
55
+ (
56
+ torch.stack((y_true, y_pred), -1).unsqueeze(-2).unsqueeze(-2)
57
+ == torch.stack(
58
+ (
59
+ torch.arange(labels, device=device).unsqueeze(-1).repeat(1, labels),
60
+ torch.arange(labels, device=device).unsqueeze(-2).repeat(labels, 1),
61
+ ),
62
+ -1,
63
+ )
64
+ )
65
+ .all(-1)
66
+ .sum(-3)
67
+ )
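A tiny hand-checkable example of the averaged variant:

# illustrative example -- values can be verified by hand
import torch

y_pred = torch.tensor([0, 1, 1, 2])
y_true = torch.tensor([0, 1, 2, 2])
acc, p, r, f1 = accuracy_precision_recall_f1(y_pred, y_true, average=True)
# acc = 0.75; precision and recall are macro-averaged over the 3 classes (both ~0.83), f1 ~0.78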
code/utils/optimizer.py ADDED
@@ -0,0 +1,151 @@
1
+ """
2
+ File copied from
3
+ https://github.com/nicola-decao/diffmask/blob/master/diffmask/optim/lookahead.py
4
+ """
5
+
6
+ import torch
7
+ import torch.optim as optim
8
+
9
+ from collections import defaultdict
10
+ from torch import Tensor
11
+ from torch.optim.optimizer import Optimizer
12
+ from typing import Iterable, Optional, Union
13
+
14
+
15
+ _params_type = Union[Iterable[Tensor], Iterable[dict]]
16
+
17
+
18
+ class Lookahead(Optimizer):
19
+ """Lookahead optimizer: https://arxiv.org/abs/1907.08610"""
20
+
21
+ # noinspection PyMissingConstructor
22
+ def __init__(self, base_optimizer: Optimizer, alpha: float = 0.5, k: int = 6):
23
+ if not 0.0 <= alpha <= 1.0:
24
+ raise ValueError(f"Invalid slow update rate: {alpha}")
25
+ if not 1 <= k:
26
+ raise ValueError(f"Invalid lookahead steps: {k}")
27
+ defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0)
28
+ self.base_optimizer = base_optimizer
29
+ self.param_groups = self.base_optimizer.param_groups
30
+ self.defaults = base_optimizer.defaults
31
+ self.defaults.update(defaults)
32
+ self.state = defaultdict(dict)
33
+ # manually add our defaults to the param groups
34
+ for name, default in defaults.items():
35
+ for group in self.param_groups:
36
+ group.setdefault(name, default)
37
+
38
+ def update_slow(self, group: dict):
39
+ for fast_p in group["params"]:
40
+ if fast_p.grad is None:
41
+ continue
42
+ param_state = self.state[fast_p]
43
+ if "slow_buffer" not in param_state:
44
+ param_state["slow_buffer"] = torch.empty_like(fast_p.data)
45
+ param_state["slow_buffer"].copy_(fast_p.data)
46
+ slow = param_state["slow_buffer"]
47
+ slow.add_(fast_p.data - slow, alpha=group["lookahead_alpha"])
48
+ fast_p.data.copy_(slow)
49
+
50
+ def sync_lookahead(self):
51
+ for group in self.param_groups:
52
+ self.update_slow(group)
53
+
54
+ def step(self, closure: Optional[callable] = None) -> Optional[float]:
55
+ # print(self.k)
56
+ # assert id(self.param_groups) == id(self.base_optimizer.param_groups)
57
+ loss = self.base_optimizer.step(closure)
58
+ for group in self.param_groups:
59
+ group["lookahead_step"] += 1
60
+ if group["lookahead_step"] % group["lookahead_k"] == 0:
61
+ self.update_slow(group)
62
+ return loss
63
+
64
+ def state_dict(self) -> dict:
65
+ fast_state_dict = self.base_optimizer.state_dict()
66
+ slow_state = {
67
+ (id(k) if isinstance(k, torch.Tensor) else k): v
68
+ for k, v in self.state.items()
69
+ }
70
+ fast_state = fast_state_dict["state"]
71
+ param_groups = fast_state_dict["param_groups"]
72
+ return {
73
+ "state": fast_state,
74
+ "slow_state": slow_state,
75
+ "param_groups": param_groups,
76
+ }
77
+
78
+ def load_state_dict(self, state_dict: dict):
79
+ fast_state_dict = {
80
+ "state": state_dict["state"],
81
+ "param_groups": state_dict["param_groups"],
82
+ }
83
+ self.base_optimizer.load_state_dict(fast_state_dict)
84
+
85
+ # We want to restore the slow state, but share param_groups reference
86
+ # with base_optimizer. This is a bit redundant but least code
87
+ slow_state_new = False
88
+ if "slow_state" not in state_dict:
89
+ print("Loading state_dict from optimizer without Lookahead applied.")
90
+ state_dict["slow_state"] = defaultdict(dict)
91
+ slow_state_new = True
92
+ slow_state_dict = {
93
+ "state": state_dict["slow_state"],
94
+ "param_groups": state_dict[
95
+ "param_groups"
96
+ ], # this is pointless but saves code
97
+ }
98
+ super(Lookahead, self).load_state_dict(slow_state_dict)
99
+ self.param_groups = (
100
+ self.base_optimizer.param_groups
101
+ ) # make both ref same container
102
+ if slow_state_new:
103
+ # reapply defaults to catch missing lookahead specific ones
104
+ for name, default in self.defaults.items():
105
+ for group in self.param_groups:
106
+ group.setdefault(name, default)
107
+
108
+
109
+ def LookaheadAdam(
110
+ params: _params_type,
111
+ lr: float = 1e-3,
112
+ betas: tuple[float, float] = (0.9, 0.999),
113
+ eps: float = 1e-08,
114
+ weight_decay: float = 0,
115
+ amsgrad: bool = False,
116
+ lalpha: float = 0.5,
117
+ k: int = 6,
118
+ ):
119
+ return Lookahead(
120
+ torch.optim.Adam(params, lr, betas, eps, weight_decay, amsgrad), lalpha, k
121
+ )
122
+
123
+
124
+ def LookaheadRAdam(
125
+ params: _params_type,
126
+ lr: float = 1e-3,
127
+ betas: tuple[float, float] = (0.9, 0.999),
128
+ eps: float = 1e-8,
129
+ weight_decay: float = 0,
130
+ lalpha: float = 0.5,
131
+ k: int = 6,
132
+ ):
133
+ return Lookahead(optim.RAdam(params, lr, betas, eps, weight_decay), lalpha, k)
134
+
135
+
136
+ def LookaheadRMSprop(
137
+ params: _params_type,
138
+ lr: float = 1e-2,
139
+ alpha: float = 0.99,
140
+ eps: float = 1e-08,
141
+ weight_decay: float = 0,
142
+ momentum: float = 0,
143
+ centered: bool = False,
144
+ lalpha: float = 0.5,
145
+ k: int = 6,
146
+ ):
147
+ return Lookahead(
148
+ torch.optim.RMSprop(params, lr, alpha, eps, weight_decay, momentum, centered),
149
+ lalpha,
150
+ k,
151
+ )
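LookaheadAdam behaves as a drop-in replacement for Adam: the wrapped optimizer takes the fast steps, and every k steps the slow weights are pulled towards them by a factor of alpha. A minimal usage sketch on a toy model (not part of this repository):

# illustrative sketch -- toy model, not from this repository
import torch

model = torch.nn.Linear(10, 2)
opt = LookaheadAdam(model.parameters(), lr=1e-3)  # Adam inner steps; slow weights synced every k=6 steps
loss = model(torch.randn(4, 10)).sum()
loss.backward()
opt.step()                                        # fast (Adam) update; the slow update fires on every 6th call
opt.zero_grad()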
code/utils/plot.py ADDED
@@ -0,0 +1,252 @@
1
+ import cv2
2
+ import numpy as np
3
+ import torch
4
+
5
+ from pytorch_lightning import LightningModule
6
+ from pytorch_lightning.callbacks import Callback
7
+ from pytorch_lightning.loggers import WandbLogger
8
+ from pytorch_lightning.trainer import Trainer
9
+ from torch import Tensor
10
+
11
+
12
+ @torch.no_grad()
13
+ def unnormalize(
14
+ images: Tensor,
15
+ mean: tuple[float] = (0.5, 0.5, 0.5),
16
+ std: tuple[float] = (0.5, 0.5, 0.5),
17
+ ) -> Tensor:
18
+ """Reverts the normalization transformation applied before ViT.
19
+
20
+ Args:
21
+ images (Tensor): a batch of images
22
+ mean (tuple[float]): the means used for normalization - defaults to (0.5, 0.5, 0.5)
23
+ std (tuple[float]): the stds used for normalization - defaults to (0.5, 0.5, 0.5)
24
+
25
+ Returns:
26
+ the un-normalized batch of images
27
+ """
28
+ unnormalized_images = images.clone()
29
+ for i, (m, s) in enumerate(zip(mean, std)):
30
+ unnormalized_images[:, i, :, :].mul_(s).add_(m)
31
+
32
+ return unnormalized_images
33
+
34
+
35
+ @torch.no_grad()
36
+ def smoothen(mask: Tensor, patch_size: int = 16) -> Tensor:
37
+ """Smoothens a mask by downsampling it and re-upsampling it
38
+ with bi-linear interpolation.
39
+
40
+ Args:
41
+ mask (Tensor): a 2D float torch tensor with values in [0, 1]
42
+ patch_size (int): the patch size in pixels
43
+
44
+ Returns:
45
+ a smoothened mask at the pixel level
46
+ """
47
+ device = mask.device
48
+ (h, w) = mask.shape
49
+ mask = cv2.resize(
50
+ mask.cpu().numpy(),
51
+ (h // patch_size, w // patch_size),
52
+ interpolation=cv2.INTER_NEAREST,
53
+ )
54
+ mask = cv2.resize(mask, (h, w), interpolation=cv2.INTER_LINEAR)
55
+ return torch.tensor(mask).to(device)
56
+
57
+
58
+ @torch.no_grad()
59
+ def draw_mask_on_image(image: Tensor, mask: Tensor) -> Tensor:
60
+ """Overlays a dimming mask on the image.
61
+
62
+ Args:
63
+ image (Tensor): a float torch tensor with values in [0, 1]
64
+ mask (Tensor): a float torch tensor with values in [0, 1]
65
+
66
+ Returns:
67
+ the image with parts of it dimmed according to the mask
68
+ """
69
+ masked_image = image * mask
70
+
71
+ return masked_image
72
+
73
+
74
+ @torch.no_grad()
75
+ def draw_heatmap_on_image(
76
+ image: Tensor,
77
+ mask: Tensor,
78
+ colormap: int = cv2.COLORMAP_JET,
79
+ ) -> Tensor:
80
+ """Overlays a heatmap on the image.
81
+
82
+ Args:
83
+ image (Tensor): a float torch tensor with values in [0, 1]
84
+ mask (Tensor): a float torch tensor with values in [0, 1]
85
+ colormap (int): the OpenCV colormap to be used
86
+
87
+ Returns:
88
+ the image with the heatmap overlaid
89
+ """
90
+ # Save the device of the image
91
+ original_device = image.device
92
+
93
+ # Convert image & mask to numpy
94
+ image = image.permute(1, 2, 0).cpu().numpy()
95
+ mask = mask.cpu().numpy()
96
+
97
+ # Create heatmap
98
+ heatmap = cv2.applyColorMap(np.uint8(255 * mask), colormap)
99
+ heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
100
+ heatmap = np.float32(heatmap) / 255
101
+
102
+ # Overlay heatmap on image
103
+ masked_image = image + heatmap
104
+ masked_image = masked_image / np.max(masked_image)
105
+
106
+ return torch.tensor(masked_image).permute(2, 0, 1).to(original_device)
107
+
108
+
109
+ def _prepare_samples(images: Tensor, masks: Tensor) -> tuple[Tensor, list[float]]:
110
+ """Prepares the samples for the masking/heatmap visualization.
111
+
112
+ Args:
113
+ images (Tensor): a float torch tensor with values in [0, 1]
114
+ masks (Tensor): a float torch tensor with values in [0, 1]
115
+
116
+ Returns:
117
+ a tuple of image triplets (img, masked, heatmap) and their
118
+ corresponding masking percentages
119
+ """
120
+ num_channels = images[0].shape[0]
121
+
122
+ # Smoothen masks
123
+ masks = [smoothen(m) for m in masks]
124
+
125
+ # Un-normalize images
126
+ if num_channels == 1:
127
+ images = [
128
+ torch.repeat_interleave(img, 3, 0)
129
+ for img in unnormalize(images, mean=(0.5,), std=(0.5,))
130
+ ]
131
+ else:
132
+ images = [img for img in unnormalize(images)]
133
+
134
+ # Draw mask on sample images
135
+ images_with_mask = [
136
+ draw_mask_on_image(image, mask) for image, mask in zip(images, masks)
137
+ ]
138
+
139
+ # Draw heatmap on sample images
140
+ images_with_heatmap = [
141
+ draw_heatmap_on_image(image, mask) for image, mask in zip(images, masks)
142
+ ]
143
+
144
+ # Chunk to triplets (image, masked image, heatmap)
145
+ samples = torch.cat(
146
+ [
147
+ torch.cat(images, dim=2),
148
+ torch.cat(images_with_mask, dim=2),
149
+ torch.cat(images_with_heatmap, dim=2),
150
+ ],
151
+ dim=1,
152
+ ).chunk(len(images), dim=-1)
153
+
154
+ # Compute masking percentages
155
+ masked_pixels_percentages = [
156
+ 100 * (1 - torch.stack(masks)[i].mean(-1).mean(-1).item())
157
+ for i in range(len(masks))
158
+ ]
159
+
160
+ return samples, masked_pixels_percentages
161
+
162
+
163
+ def log_masks(images: Tensor, masks: Tensor, key: str, logger: WandbLogger):
164
+ """Logs a set of images with their masks to WandB.
165
+
166
+ Args:
167
+ images (Tensor): a float torch tensor with values in [0, 1]
168
+ masks (Tensor): a float torch tensor with values in [0, 1]
169
+ key (str): the key to log the images with
170
+ logger (WandbLogger): the logger to log the images to
171
+ """
172
+ samples, masked_pixels_percentages = _prepare_samples(images, masks)
173
+
174
+ # Log with wandb
175
+ logger.log_image(
176
+ key=key,
177
+ images=list(samples),
178
+ caption=[
179
+ f"Masking: {masked_pixels_percentage:.2f}% "
180
+ for masked_pixels_percentage in masked_pixels_percentages
181
+ ],
182
+ )
183
+
184
+
185
+ class DrawMaskCallback(Callback):
186
+ def __init__(
187
+ self,
188
+ samples: list[tuple[Tensor, Tensor]],
189
+ log_every_n_steps: int = 200,
190
+ key: str = "",
191
+ ):
192
+ """A callback that logs VisionDiffMask masks for the sample images to WandB.
193
+
194
+ Args:
195
+ samples (list[tuple[Tensor, Tensor]]): a list of (image, label) pairs
196
+ log_every_n_steps (int): the interval in steps to log the masks to WandB
197
+ key (str): the key to log the images with (allows for multiple batches)
198
+ """
199
+ self.images = torch.stack([img for img in samples[0]])
200
+ self.labels = [label.item() for label in samples[1]]
201
+ self.log_every_n_steps = log_every_n_steps
202
+ self.key = key
203
+
204
+ def _log_masks(self, trainer: Trainer, pl_module: LightningModule):
205
+ # Predict mask
206
+ with torch.no_grad():
207
+ pl_module.eval()
208
+ outputs = pl_module.get_mask(self.images)
209
+ pl_module.train()
210
+
211
+ # Unnest outputs
212
+ masks = outputs["mask"]
213
+ kl_divs = outputs["kl_div"]
214
+ pred_classes = outputs["pred_class"].cpu()
215
+
216
+ # Prepare masked samples for logging
217
+ samples, masked_pixels_percentages = _prepare_samples(self.images, masks)
218
+
219
+ # Log with wandb
220
+ trainer.logger.log_image(
221
+ key="DiffMask " + self.key,
222
+ images=list(samples),
223
+ caption=[
224
+ f"Masking: {masked_pixels_percentage:.2f}% "
225
+ f"\n KL-divergence: {kl_div:.4f} "
226
+ f"\n Class: {pl_module.model.config.id2label[label]} "
227
+ f"\n Predicted Class: {pl_module.model.config.id2label[pred_class.item()]}"
228
+ for masked_pixels_percentage, kl_div, label, pred_class in zip(
229
+ masked_pixels_percentages, kl_divs, self.labels, pred_classes
230
+ )
231
+ ],
232
+ )
233
+
234
+ def on_fit_start(self, trainer: Trainer, pl_module: LightningModule):
235
+ # Transfer sample images to correct device
236
+ self.images = self.images.to(pl_module.device)
237
+
238
+ # Log sample images
239
+ self._log_masks(trainer, pl_module)
240
+
241
+ def on_train_batch_end(
242
+ self,
243
+ trainer: Trainer,
244
+ pl_module: LightningModule,
245
+ outputs: dict,
246
+ batch: tuple[Tensor, Tensor],
247
+ batch_idx: int,
248
+ unused: int = 0,
249
+ ):
250
+ # Log sample images every n steps
251
+ if batch_idx % self.log_every_n_steps == 0:
252
+ self._log_masks(trainer, pl_module)
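Outside of training, the same drawing helpers used by the callback above can be applied directly to visualise a mask on an image. A small sketch with random data, purely for illustration:

# illustrative sketch -- random image and mask
import torch

image = torch.rand(3, 224, 224)                          # float image with values in [0, 1]
mask = torch.rand(224, 224)                              # per-pixel relevance in [0, 1]
dimmed = draw_mask_on_image(image, smoothen(mask))       # irrelevant pixels dimmed towards black
heatmap = draw_heatmap_on_image(image, smoothen(mask))   # JET heatmap blended onto the image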
requirements.txt ADDED
@@ -0,0 +1,7 @@
1
+ numpy
2
+ opencv-python
3
+ pytorch_lightning
4
+ torch
5
+ torchvision
6
+ transformers
7
+