# (Hosting-page banner captured with this file -- "Spaces: Runtime error" -- not part of the module.)
import pytorch_lightning as pl | |
import torch | |
import torch.nn.functional as F | |
from .gates import DiffMaskGateInput | |
from argparse import ArgumentParser | |
from math import sqrt | |
from pytorch_lightning.core.optimizer import LightningOptimizer | |
from torch import Tensor | |
from torch.optim import Optimizer | |
from torch.optim.lr_scheduler import _LRScheduler | |
from transformers import ( | |
get_constant_schedule_with_warmup, | |
get_constant_schedule, | |
ViTForImageClassification, | |
) | |
from transformers.models.vit.configuration_vit import ViTConfig | |
from typing import Optional, Union | |
from utils.getters_setters import vit_getter, vit_setter | |
from utils.metrics import accuracy_precision_recall_f1 | |
from utils.optimizer import LookaheadAdam | |
class ImageInterpretationNet(pl.LightningModule): | |
@staticmethod
def add_model_specific_args(parent_parser: ArgumentParser) -> ArgumentParser:
    """Register the Vision DiffMask command-line options on ``parent_parser``.

    The options are added to a dedicated "Vision DiffMask" argument group.
    The function takes no ``self``/``cls``, so it is declared as a static
    method: it behaves the same whether it is called on the class or on an
    instance.

    Args:
        parent_parser: the parser to extend (modified in place).

    Returns:
        The same ``parent_parser`` object that was passed in.
    """
    group = parent_parser.add_argument_group("Vision DiffMask")

    # (flag, add_argument keyword arguments), kept in the original
    # registration order so the generated --help output does not change.
    options = (
        (
            "--alpha",
            {
                "type": float,
                "default": 20.0,
                "help": "Initial value for the Lagrangian",
            },
        ),
        (
            "--lr",
            {
                "type": float,
                "default": 2e-5,
                "help": "Learning rate for DiffMask.",
            },
        ),
        (
            "--eps",
            {
                "type": float,
                "default": 0.1,
                "help": "KL divergence tolerance.",
            },
        ),
        (
            "--no_placeholder",
            {
                "action": "store_true",
                "help": "Whether to not use placeholder",
            },
        ),
        (
            "--lr_placeholder",
            {
                "type": float,
                "default": 1e-3,
                "help": "Learning for mask vectors.",
            },
        ),
        (
            "--lr_alpha",
            {
                "type": float,
                "default": 0.3,
                "help": "Learning rate for lagrangian optimizer.",
            },
        ),
        (
            "--mul_activation",
            {
                "type": float,
                "default": 15.0,
                "help": "Value to multiply gate activations.",
            },
        ),
        (
            "--add_activation",
            {
                "type": float,
                "default": 8.0,
                "help": "Value to add to gate activations.",
            },
        ),
        (
            "--weighted_layer_distribution",
            {
                "action": "store_true",
                "help": "Whether to use a weighted distribution when picking a layer in DiffMask forward.",
            },
        ),
    )
    for flag, spec in options:
        group.add_argument(flag, **spec)

    return parent_parser
# Declare variables that will be initialized later | |
model: ViTForImageClassification | |
def __init__(
    self,
    model_cfg: ViTConfig,
    alpha: float = 1,
    lr: float = 3e-4,
    eps: float = 0.1,
    eps_valid: float = 0.8,
    acc_valid: float = 0.75,
    lr_placeholder: float = 1e-3,
    lr_alpha: float = 0.3,
    mul_activation: float = 10.0,
    add_activation: float = 5.0,
    placeholder: bool = True,
    weighted_layer_pred: bool = False,
):
    """Build the VisionDiffMask explainer for a Vision Transformer.

    Only the explainer's own state is created here (gate network, Lagrangian
    multipliers, running-metric buffers); the Vision Transformer itself is
    attached afterwards through ``set_vision_transformer``.

    Args:
        model_cfg (ViTConfig): the configuration of the Vision Transformer model
        alpha (float): the initial value for the Lagrangian
        lr (float): the learning rate for the DiffMask gates
        eps (float): the tolerance for the KL divergence
        eps_valid (float): the tolerance for the KL divergence in the validation step
        acc_valid (float): the accuracy threshold for the validation step
        lr_placeholder (float): the learning rate for the learnable masking embeddings
        lr_alpha (float): the learning rate for the Lagrangian
        mul_activation (float): the value to multiply the gate activations by
        add_activation (float): the value to add to the gate activations
        placeholder (bool): whether to use placeholder embeddings or a zero vector
        weighted_layer_pred (bool): whether to use a weighted distribution when picking a layer
    """
    super().__init__()

    # Keep the constructor arguments available under ``self.hparams``.
    self.save_hyperparameters()

    # The gate, the multipliers and every running buffer all use the same
    # number of entries: ``num_hidden_layers + 2``.
    n_slots = model_cfg.num_hidden_layers + 2

    # DiffMask gate network operating on the input.
    self.gate = DiffMaskGateInput(
        hidden_size=model_cfg.hidden_size,
        hidden_attention=model_cfg.hidden_size // 4,
        num_hidden_layers=n_slots,
        max_position_embeddings=1,
        mul_activation=mul_activation,
        add_activation=add_activation,
        placeholder=placeholder,
    )

    # One scalar Lagrangian multiplier per slot, each starting at ``alpha``.
    multipliers = []
    for _ in range(n_slots):
        multipliers.append(torch.nn.Parameter(torch.ones(()) * alpha))
    self.alpha = torch.nn.ParameterList(multipliers)

    # Running metrics, one entry per slot: accuracy and L0 start at one,
    # the step counter starts at zero.
    for buffer_name, initializer in (
        ("running_acc", torch.ones),
        ("running_l0", torch.ones),
        ("running_steps", torch.zeros),
    ):
        self.register_buffer(buffer_name, initializer((n_slots,)))
def set_vision_transformer(self, model: ViTForImageClassification):
    """Attach the Vision Transformer to explain and freeze all of its weights.

    Args:
        model: the classifier stored on ``self.model``; every parameter it
            exposes gets ``requires_grad = False``.
    """
    self.model = model

    # The attached model is never trained: switch off gradients everywhere.
    for frozen in self.model.parameters():
        frozen.requires_grad = False
def forward_explainer(
    self, x: Tensor, attribution: bool = False
) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, int, int]:
    """Run one pass through the explainer (VisionDiffMask) model.

    Args:
        x: the input batch handed to ``vit_getter`` / ``vit_setter``.
        attribution: when True the gate is called with ``layer_pred=None``
            (all hidden states are used); otherwise it is called with the
            layer sampled in this pass.

    Returns:
        The tuple ``(logits, logits_orig, gates, expected_L0, gates_full,
        expected_L0_full, layer_drop, layer_pred)``. ``layer_drop`` is always
        0 and ``layer_pred`` is the sampled layer index, even when
        ``attribution`` is True.
    """
    # Original predictions and per-layer hidden states of the frozen model.
    logits_orig, hidden_states = vit_getter(self.model, x)

    # Prepend the [CLS] token to the first hidden state so its shape
    # matches what ``self.gate`` expects.
    embeddings = hidden_states[0]
    cls_tokens = self.model.vit.embeddings.cls_token.expand(
        len(embeddings), -1, -1
    )
    hidden_states[0] = torch.cat((cls_tokens, embeddings), dim=1)

    # Choose the layer the mask is generated from in this pass.
    n_hidden = len(hidden_states)
    if not self.hparams.weighted_layer_pred:
        layer_pred = torch.randint(n_hidden, ()).item()
    else:
        # Weighted sampling: a layer whose running accuracy is above 0.75,
        # whose running L0 is below 0.1 and which has been picked more than
        # 100 times gets weight 0.1 instead of 1.
        weights = []
        for i in range(n_hidden):
            down_weighted = (
                self.running_acc[i] > 0.75
                and self.running_l0[i] < 0.1
                and self.running_steps[i] > 100
            )
            weights.append(0.1 if down_weighted else 1)
        p = torch.tensor(weights)
        p = p / p.sum()
        # The candidate layers are 0..n_hidden-1, so the sampled position
        # is itself the layer index.
        layer_pred = p.multinomial(num_samples=1).item()

    # Only the input is masked, so the replaced hidden state is always slot 0.
    layer_drop = 0

    gate_outputs = self.gate(
        hidden_states=hidden_states,
        layer_pred=None if attribution else layer_pred,
    )
    (
        new_hidden_state,
        gates,
        expected_L0,
        gates_full,
        expected_L0_full,
    ) = gate_outputs

    # Hidden-state list for the second forward pass: ``None`` everywhere
    # except the slot that receives the masked state.
    new_hidden_states = [None] * n_hidden
    new_hidden_states[layer_drop] = new_hidden_state

    # Logits of the model on the masked input.
    logits, _ = vit_setter(self.model, x, new_hidden_states)

    return (
        logits,
        logits_orig,
        gates,
        expected_L0,
        gates_full,
        expected_L0_full,
        layer_drop,
        layer_pred,
    )
def get_mask(
    self,
    x: Tensor,
    idx: int = -1,
    aggregated_mask: bool = True,
) -> dict[str, Tensor]:
    """Produce an image-sized mask for ``x``.

    Args:
        x: the input to generate the mask for
        idx: the index of the layer to generate the mask from
        aggregated_mask: whether to use an aggregative mask from each layer
            (``exp`` of ``expected_L0_full``) instead of the raw
            ``gates_full`` values

    Returns:
        A dictionary with the keys ``"mask"``, ``"kl_div"`` (KL divergence
        between the original and the masked predictions) and
        ``"pred_class"`` (argmax of the masked logits).
    """
    # Explainer pass in attribution mode; only four of the eight returned
    # values are needed here.
    outputs = self.forward_explainer(x, attribution=True)
    logits, logits_orig = outputs[0], outputs[1]
    gates_full, expected_L0_full = outputs[4], outputs[5]

    # KL divergence between original and masked class distributions.
    original_dist = torch.distributions.Categorical(logits=logits_orig)
    masked_dist = torch.distributions.Categorical(logits=logits)
    kl_div = torch.distributions.kl_divergence(original_dist, masked_dist)

    # Class predicted from the masked input.
    pred_class = logits.argmax(-1)

    # Per-position mask values for the requested layer.
    if aggregated_mask:
        per_position = expected_L0_full[:, :, idx].exp()
    else:
        per_position = gates_full[:, :, idx]

    # Skip the first position (the [CLS] slot prepended in forward_explainer).
    per_position = per_position[:, 1:]

    _, height, width = x.shape[1:]  # channels are not needed
    batch, n_patches = per_position.shape
    side = int(sqrt(n_patches))  # patches per side
    patch_size = int(height / side)

    # Arrange the values on the patch grid and upscale to the input size.
    grid = per_position.reshape(batch, 1, side, side)
    upscaled = F.interpolate(grid, scale_factor=patch_size)
    mask = upscaled.reshape(batch, height, width)

    return {"mask": mask, "kl_div": kl_div, "pred_class": pred_class}
def forward(self, x: Tensor) -> Tensor:
    """Return the ``logits`` attribute of ``self.model(x)`` (no masking involved)."""
    model_output = self.model(x)
    return model_output.logits
def training_step(self, batch: tuple[Tensor, Tensor], *args, **kwargs) -> dict:
    """Compute the DiffMask loss and metrics for one batch.

    The loss is ``alpha[layer_pred] * (KL - eps) + mean(expected_L0)``, where
    the KL divergence is taken between the original and the masked
    predictions.

    Args:
        batch: an ``(x, y)`` pair; only ``x`` is used.
        *args: ignored.
        **kwargs: ignored.

    Returns:
        A dictionary with ``loss``, the individual metrics, and the nested
        ``log`` / ``progress_bar`` metric dictionaries. When the module is
        not in training mode every top-level key is prefixed with ``val_``.
    """
    # The labels are not needed: the target is the model's own prediction.
    x, _ = batch

    # Pass the batch through the explainer (VisionDiffMask) model.
    (
        logits,
        logits_orig,
        _gates,
        expected_L0,
        _gates_full,
        _expected_L0_full,
        _layer_drop,
        layer_pred,
    ) = self.forward_explainer(x)

    # Constraint term: KL divergence minus the tolerance.
    loss_c = (
        torch.distributions.kl_divergence(
            torch.distributions.Categorical(logits=logits_orig),
            torch.distributions.Categorical(logits=logits),
        )
        - self.hparams.eps
    )

    # Sparsity (L0) term.
    loss_g = expected_L0.mean(-1)

    # Lagrangian-weighted total loss.
    loss = self.alpha[layer_pred] * loss_c + loss_g

    # Agreement between masked and original predictions.
    acc, _, _, _ = accuracy_precision_recall_f1(
        logits.argmax(-1), logits_orig.argmax(-1), average=True
    )

    # Average L0 value.
    l0 = expected_L0.exp().mean(-1)

    metrics = {
        "loss_c": loss_c.mean(-1),
        "loss_g": loss_g.mean(-1),
        "alpha": self.alpha[layer_pred].mean(-1),
        "acc": acc,
        "l0": l0.mean(-1),
        "layer_pred": layer_pred,
        "r_acc": self.running_acc[layer_pred],
        "r_l0": self.running_l0[layer_pred],
        "r_steps": self.running_steps[layer_pred],
        "debug_loss": loss.mean(-1),
    }
    outputs_dict = {
        "loss": loss.mean(-1),
        **metrics,
        "log": metrics,
        "progress_bar": metrics,
    }

    # Logged names are never prefixed, regardless of the training flag.
    for name in ("loss", "loss_c", "loss_g", "acc", "l0", "alpha"):
        self.log(
            name, outputs_dict[name], on_step=True, on_epoch=True, prog_bar=True
        )

    prefix = "" if self.training else "val_"
    outputs_dict = {f"{prefix}{k}": v for k, v in outputs_dict.items()}

    if self.training:
        # Exponential moving averages of the per-layer metrics. The update
        # runs under no_grad: ``l0`` depends on the gate parameters, and
        # writing it into a buffer with autograd enabled would attach
        # graph history to the buffer and keep extending it every step.
        with torch.no_grad():
            self.running_acc[layer_pred] = (
                self.running_acc[layer_pred] * 0.9 + acc * 0.1
            )
            self.running_l0[layer_pred] = (
                self.running_l0[layer_pred] * 0.9 + l0.mean(-1) * 0.1
            )
            self.running_steps[layer_pred] += 1

    return outputs_dict
def validation_epoch_end(self, outputs: list[dict]):
    """Average the validation metrics and derive the epoch's ``val_loss``.

    Args:
        outputs: the per-step dictionaries; the keys ``val_loss_c``,
            ``val_loss_g``, ``val_acc`` and ``val_l0`` are read from every
            entry that contains them.

    Returns:
        A dictionary holding the four averaged metrics, ``val_loss`` and a
        ``log`` entry pointing at the averaged metrics. ``val_loss`` equals
        the averaged ``val_l0`` when the averaged KL term is at most
        ``eps_valid`` and the averaged accuracy is at least ``acc_valid``;
        otherwise it is infinity.
    """
    averaged = {}
    for key in ("val_loss_c", "val_loss_g", "val_acc", "val_l0"):
        values = [step[key] for step in outputs if key in step]
        averaged[key] = sum(values) / len(values)

    # The step stores the KL term with ``eps`` subtracted; add it back so
    # the comparison below is made on the divergence itself.
    averaged["val_loss_c"] += self.hparams.eps

    if (
        averaged["val_loss_c"] <= self.hparams.eps_valid
        and averaged["val_acc"] >= self.hparams.acc_valid
    ):
        val_loss = averaged["val_l0"]
    else:
        # Constraints violated: report an infinite loss for this epoch.
        val_loss = torch.full_like(averaged["val_l0"], float("inf"))

    return {"val_loss": val_loss, **averaged, "log": averaged}
def configure_optimizers(self) -> tuple[list[Optimizer], list[_LRScheduler]]:
    """Create the two optimizers and their learning-rate schedulers.

    Returns:
        ``(optimizers, schedulers)`` where optimizer 0 updates the gate
        network and the placeholder, and optimizer 1 updates the Lagrangian
        multipliers. Optimizer 0 gets a constant schedule with
        ``12 * 100`` warm-up steps, stepped per training step; optimizer 1
        gets a plain constant schedule.
    """
    # The placeholder may be a ParameterList or a single tensor/parameter.
    placeholder = self.gate.placeholder
    if isinstance(placeholder, torch.nn.ParameterList):
        placeholder_params = placeholder.parameters()
    else:
        placeholder_params = [placeholder]

    # Same dual handling for the multipliers.
    if isinstance(self.alpha, torch.Tensor):
        alpha_params = [self.alpha]
    else:
        alpha_params = self.alpha.parameters()

    gate_optimizer = LookaheadAdam(
        params=[
            {"params": self.gate.g_hat.parameters(), "lr": self.hparams.lr},
            {"params": placeholder_params, "lr": self.hparams.lr_placeholder},
        ],
    )
    alpha_optimizer = LookaheadAdam(
        params=alpha_params,
        lr=self.hparams.lr_alpha,
    )
    optimizers = [gate_optimizer, alpha_optimizer]

    warmup_schedule = {
        "scheduler": get_constant_schedule_with_warmup(gate_optimizer, 12 * 100),
        "interval": "step",
    }
    schedulers = [warmup_schedule, get_constant_schedule(alpha_optimizer)]

    return optimizers, schedulers
def optimizer_step(
    self,
    epoch: int,
    batch_idx: int,
    optimizer: Union[Optimizer, LightningOptimizer],
    optimizer_idx: int = 0,
    optimizer_closure: Optional[callable] = None,
    on_tpu: bool = False,
    using_native_amp: bool = False,
    using_lbfgs: bool = False,
):
    """Apply one optimizer update for the min-max (Lagrangian) objective.

    * Optimizer 0 takes an ordinary step on the DiffMask parameters.
    * Optimizer 1 flips the sign of the multipliers' gradients before
      stepping (so the step maximizes the loss with respect to them) and
      afterwards clips every multiplier to the range [0, 200].

    For any other ``optimizer_idx`` nothing is done. After a step all
    gradients of the optimizer's parameters are reset to ``None``.

    Args:
        epoch: unused.
        batch_idx: unused.
        optimizer: the optimizer to step.
        optimizer_idx: 0 for the DiffMask parameters, 1 for the multipliers.
        optimizer_closure: forwarded to ``optimizer.step``.
        on_tpu: unused.
        using_native_amp: unused.
        using_lbfgs: unused.
    """
    if optimizer_idx not in (0, 1):
        return

    ascend_on_alpha = optimizer_idx == 1

    if ascend_on_alpha:
        # Reverse the sign of the Lagrangian's gradients. Compare against
        # None explicitly: tensor truthiness would skip zero gradients and
        # fail outright for multi-element tensors.
        # NOTE(review): if ``optimizer_closure`` is what runs the backward
        # pass, the gradients are only produced inside ``step()`` below and
        # this flip happens before they exist -- confirm against the
        # Lightning version in use.
        for multiplier in self.alpha:
            if multiplier.grad is not None:
                multiplier.grad *= -1

    optimizer.step(closure=optimizer_closure)

    # Clear the gradients; dropping them to None as well covers optimizers
    # whose zero_grad() only zeroes them.
    optimizer.zero_grad()
    for group in optimizer.param_groups:
        for param in group["params"]:
            param.grad = None

    if ascend_on_alpha:
        # Clip the Lagrangian's values to [0, 200].
        for multiplier in self.alpha:
            multiplier.data.clamp_(min=0, max=200)
def on_save_checkpoint(self, ckpt: dict):
    """Strip the Vision Transformer weights from a checkpoint before saving.

    Every entry of ``ckpt["state_dict"]`` whose key starts with ``"model."``
    is removed in place; the ViT can be loaded dynamically, so it does not
    need to be stored.

    Args:
        ckpt: the checkpoint dictionary (modified in place).
    """
    state = ckpt["state_dict"]
    # Collect the names first so the dict is not resized while iterating.
    vit_keys = [name for name in state if name.startswith("model.")]
    for name in vit_keys:
        state.pop(name)