| import torch |
| import torchattacks |
| from PIL import Image |
| from typing import List, Tuple, Optional |
| import numpy as np |
| import warnings |
| from pathlib import Path |
| import types |
| import os |
|
|
| try: |
| import torchvision.models as tv_models |
| except Exception: |
| tv_models = None |
|
|
| try: |
| import timm |
| except Exception: |
| timm = None |
|
|
| try: |
| from huggingface_hub import hf_hub_download |
| except Exception: |
| hf_hub_download = None |
|
|
def capture_outputs_and_attentions(model, x_norm: torch.Tensor):
    """Run a single forward pass while capturing attention maps via hooks.

    Works for timm ViT models that expose a ``blocks`` list whose blocks have
    an ``attn`` submodule with a fused ``qkv`` projection and ``num_heads``.

    Args:
        model: Vision Transformer (timm-style). Put into eval mode here.
        x_norm: Input batch already normalized for the model's forward.

    Returns:
        Tuple ``(outputs, attentions)`` where ``attentions`` is a list of
        per-layer tensors shaped ``[B, H, T, T]`` moved to CPU. The list is
        empty when the model exposes no compatible attention modules.
    """

    attentions: List[torch.Tensor] = []

    def make_attention_hook():
        def hook(module, input, output):
            # Recompute the attention map from the block input: timm's fused
            # attention does not expose the [B, H, T, T] tensor directly.
            if not (hasattr(module, 'qkv') and hasattr(module, 'num_heads')):
                return
            x = input[0]
            B, N, C = x.shape
            head_dim = C // module.num_heads
            qkv = module.qkv(x).reshape(B, N, 3, module.num_heads, head_dim).permute(2, 0, 3, 1, 4)
            q, k, v = qkv.unbind(0)
            # Prefer the module's own scale (may differ from head_dim**-0.5,
            # e.g. with qk-norm variants), matching how the TGR patched
            # forward resolves it.
            scale = getattr(module, 'scale', head_dim ** -0.5)
            attn = (q @ k.transpose(-2, -1)) * scale
            attn = attn.softmax(dim=-1)
            attentions.append(attn.detach())
        return hook

    hooks = []
    if hasattr(model, 'blocks'):
        for block in model.blocks:
            if hasattr(block, 'attn'):
                hooks.append(block.attn.register_forward_hook(make_attention_hook()))

    model.eval()
    outputs = model(x_norm)

    # Always remove the hooks so repeated calls do not stack them.
    for h in hooks:
        h.remove()

    attentions = [a.cpu() for a in attentions]
    return outputs, attentions
|
|
def denormalize_imagenet(tensor: torch.Tensor) -> torch.Tensor:
    """Undo the ImageNet normalization of a tensor.

    Args:
        tensor: Normalized (CxHxW) tensor, produced with
            mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225].

    Returns:
        Denormalized tensor with values clamped to [0, 1].
    """
    imagenet_mean = torch.tensor([0.485, 0.456, 0.406], device=tensor.device).view(3, 1, 1)
    imagenet_std = torch.tensor([0.229, 0.224, 0.225], device=tensor.device).view(3, 1, 1)

    # Invert (x - mean) / std, then clip to the valid pixel range.
    return (tensor * imagenet_std + imagenet_mean).clamp(0, 1)
|
|
def tensor_to_pil(tensor: torch.Tensor, denormalize: bool = True) -> Image.Image:
    """Convert a (C, H, W) tensor into an RGB PIL image.

    Args:
        tensor: Tensor with shape (C, H, W).
        denormalize: When True, undo ImageNet normalization before converting.

    Returns:
        PIL Image in RGB space with values in [0, 255].
    """
    source = denormalize_imagenet(tensor) if denormalize else tensor

    # (C, H, W) float -> (H, W, C) uint8 for PIL.
    array = source.detach().cpu().numpy().transpose(1, 2, 0)
    array = (array * 255).clip(0, 255).astype(np.uint8)
    return Image.fromarray(array, mode='RGB')
|
|
class FGSM(torchattacks.FGSM):
    """FGSM (Fast Gradient Sign Method) extended to record the original and
    the final adversarial image, plus the attention maps of both forwards.

    FGSM is a single-step (non-iterative) attack.
    """
    def __init__(self, model, eps=0.03):
        super().__init__(model, eps=eps)
        # History filled by forward(): [original, adversarial] snapshots.
        self.iteration_images: List[Image.Image] = []
        self.iteration_tensors: List[torch.Tensor] = []
        # One list of per-layer attention tensors per forward pass.
        self.attentions_per_iter: List[List[torch.Tensor]] = []

    def forward(self, images, labels) -> Tuple[torch.Tensor, List[Image.Image]]:
        """Run the FGSM attack.

        Returns:
            adv_images: final adversarial tensor (ImageNet-normalized)
            iteration_images: [original_image, adversarial_image] as PIL
        """
        images = images.clone().detach().to(self.device)
        labels = labels.clone().detach().to(self.device)

        criterion = torch.nn.CrossEntropyLoss()

        # ImageNet statistics used to move between normalized and pixel space.
        mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(self.device)
        std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(self.device)
        pixel_images = images * std + mean

        # Reset histories and record the clean image first.
        self.iteration_images = [tensor_to_pil(pixel_images[0], denormalize=False)]
        self.iteration_tensors = [images.clone().detach()]
        self.attentions_per_iter = []

        images.requires_grad = True
        outputs, clean_attentions = capture_outputs_and_attentions(self.model, images)
        self.attentions_per_iter.append(list(clean_attentions))

        if self.targeted:
            target_labels = self.get_target_label(images, labels)
            cost = -criterion(outputs, target_labels)
        else:
            cost = criterion(outputs, labels)

        grad = torch.autograd.grad(cost, images, retain_graph=False, create_graph=False)[0]

        # Single signed-gradient step taken in pixel space [0, 1], then clip.
        adv_pixel = torch.clamp(pixel_images + self.eps * grad.sign(), min=0, max=1).detach()

        # Back to normalized space for the model / downstream consumers.
        adv_images = (adv_pixel - mean) / std

        self.iteration_images.append(tensor_to_pil(adv_pixel[0], denormalize=False))
        self.iteration_tensors.append(adv_images.clone().detach())

        # Capture attention of the adversarial image as well.
        outputs_adv, adv_attentions = capture_outputs_and_attentions(self.model, adv_images)
        self.attentions_per_iter.append(list(adv_attentions))

        return adv_images, self.iteration_images
|
|
class PGDIterations(torchattacks.PGD):
    """
    Extension of the standard PGD attack that captures and returns the
    adversarial image of every iteration as a list of PIL Images.
    """
    def __init__(self, model, eps=0.05, alpha=0.005, steps=10, random_start=True):

        super().__init__(model, eps=eps, alpha=alpha, steps=steps, random_start=random_start)
        # History populated by forward(): PIL snapshots, normalized tensors,
        # and per-forward attention maps (one list of per-layer tensors each).
        self.iteration_images: List[Image.Image] = []
        self.iteration_tensors: List[torch.Tensor] = []
        self.attentions_per_iter: List[List[torch.Tensor]] = []

    def forward(self, images, labels) -> Tuple[torch.Tensor, List[Image.Image]]:
        """
        Run the PGD attack and return:
        - adv_images: final adversarial tensor (ImageNet-normalized)
        - iteration_images: list of PIL Images (one per attack iteration)

        Implementation adapted to work with ImageNet-normalized images: the
        step and projection happen in pixel space [0, 1], while each forward
        pass receives a renormalized tensor.
        """
        images = images.clone().detach().to(self.device)
        labels = labels.clone().detach().to(self.device)

        if self.targeted:
            target_labels = self.get_target_label(images, labels)

        loss = torch.nn.CrossEntropyLoss()
        adv_images = images.clone().detach()

        # ImageNet statistics for switching between normalized and pixel space.
        mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(self.device)
        std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(self.device)

        images_denorm = images * std + mean
        adv_images_denorm = images_denorm.clone().detach()

        if self.random_start:
            # Uniform init inside the L_inf ball, then clip to valid pixels.
            adv_images_denorm = adv_images_denorm + torch.empty_like(adv_images_denorm).uniform_(-self.eps, self.eps)
            adv_images_denorm = torch.clamp(adv_images_denorm, min=0, max=1).detach()

        self.iteration_images = []
        self.iteration_tensors = []
        self.attentions_per_iter = []

        # Record the clean image (index 0 of every history list).
        pil_img_orig = tensor_to_pil(images_denorm[0], denormalize=False)
        self.iteration_images.append(pil_img_orig)
        self.iteration_tensors.append(images.clone().detach())

        outputs0, attentions0 = capture_outputs_and_attentions(self.model, images)
        self.attentions_per_iter.append([att for att in attentions0])

        for step_idx in range(self.steps):
            # Forward on the current adversarial image (renormalized).
            adv_images = (adv_images_denorm - mean) / std
            adv_images.requires_grad = True
            outputs, attentions = capture_outputs_and_attentions(self.model, adv_images)

            if self.targeted:
                cost = -loss(outputs, target_labels)
            else:
                cost = loss(outputs, labels)

            grad = torch.autograd.grad(cost, adv_images,
                                       retain_graph=False, create_graph=False)[0]

            # Signed gradient step in pixel space, then L_inf projection + clip.
            adv_images_denorm = adv_images_denorm.detach() + self.alpha * grad.sign()
            delta = torch.clamp(adv_images_denorm - images_denorm, min=-self.eps, max=self.eps)
            adv_images_denorm = torch.clamp(images_denorm + delta, min=0, max=1).detach()

            adv_images_normalized = (adv_images_denorm - mean) / std

            pil_img = tensor_to_pil(adv_images_denorm[0], denormalize=False)
            self.iteration_images.append(pil_img)
            self.iteration_tensors.append(adv_images_normalized.clone().detach())

            # NOTE(review): these attentions come from the forward on the
            # PRE-step image, while the image appended above is POST-step, so
            # entries i>0 of attentions_per_iter lag iteration_images by one
            # step, and the final adversarial image's attention is never
            # captured (unlike the FGSM class) — confirm this pairing is
            # intended.
            self.attentions_per_iter.append([att for att in attentions])

        # Return the final adversarial example in normalized space.
        adv_images = (adv_images_denorm - mean) / std
        return adv_images, self.iteration_images
| |
class SAGA(torch.nn.Module):
    """
    SAGA: Self-Attention Gradient Attack

    Adversarial attack tailored to Vision Transformers: it multiplies the
    FGSM gradient by the model's attention map, concentrating the
    perturbation on the regions the model considers important.

    Based on: https://github.com/MetaMain/ViTRobust
    Paper: "On the Robustness of Vision Transformers to Adversarial Examples" (ICCV 2021)
    """

    def __init__(self, model, eps=8/255, steps=10, discard_ratio: float = 0.0,
                 head_fusion: str = "mean", use_resnet: bool = False,
                 cnn_checkpoint_path: str = "resnet.pth", vit_weight=0.5):
        """SAGA implementation based on the original code (SelfAttentionGradientAttack).

        Parameters:
        - model: Vision Transformer (must expose attentions via forward or a helper in visualization utils)
        - eps: maximum L_inf budget (in pixel space [0,1])
        - steps: number of iterations (iterative FGSM)
        - discard_ratio: discard ratio used in attention rollout
        - head_fusion: head-fusion strategy ('mean','max','min')
        - use_resnet: when True, also computes the gradient of an external CNN backbone and blends it with the attention-weighted gradient
        - cnn_checkpoint_path: default path/spec of the auxiliary CNN backbone (loaded lazily)
        - vit_weight: blend weight of the ViT gradient when use_resnet=True
        """
        super().__init__()
        self.model = model
        self.eps = eps
        self.steps = steps
        # Per-step L_inf increment; max(1, steps) guards against steps == 0.
        self.eps_step = self.eps / max(1, steps)
        self.discard_ratio = discard_ratio
        self.head_fusion = head_fusion
        self.use_resnet = use_resnet

        # Checkpoint spec may be a local path or an hf:// URI (resolved lazily
        # by _resolve_checkpoint_path).
        self.cnn_checkpoint_spec = cnn_checkpoint_path
        self.cnn_model: Optional[torch.nn.Module] = None
        self.vit_weight = vit_weight
        self.device = next(model.parameters()).device
        # History buffers, reset at the start of every forward() call.
        self.iteration_images: List[Image.Image] = []
        self.iteration_tensors: List[torch.Tensor] = []
        self.attention_masks_cache: List[np.ndarray] = []

        self.attentions_per_iter: List[List[torch.Tensor]] = []
        self.loss_fn = torch.nn.CrossEntropyLoss()

    @staticmethod
    def _resolve_checkpoint_path(spec: object) -> Optional[Path]:
        """Resolve a local or Hugging Face Hub checkpoint spec into a local Path.

        Supported (HF) format:
        - hf://owner/repo/path/to/file.pth
        - hf://owner/repo@revision/path/to/file.pth

        Returns None when resolution is not possible.
        """
        if spec is None:
            return None

        if isinstance(spec, Path):
            return spec

        if not isinstance(spec, str):
            return None

        s = spec.strip()
        if s.startswith("hf://") or s.startswith("hf:"):
            if hf_hub_download is None:
                warnings.warn("huggingface_hub não está instalado; não é possível baixar checkpoint via HF Hub.")
                return None

            # Strip the scheme prefix, then split into repo and file parts.
            rest = s[len("hf://"):] if s.startswith("hf://") else s[len("hf:"):]
            rest = rest.lstrip("/")
            parts = [p for p in rest.split("/") if p]
            if len(parts) < 3:
                raise ValueError(
                    "Formato inválido para checkpoint HF. Use hf://owner/repo/path/to/file.pth"
                )

            # First two components are owner/repo; the rest is the file path.
            repo_part = "/".join(parts[:2])
            filename = "/".join(parts[2:])
            revision = None
            if "@" in repo_part:
                repo_id, revision = repo_part.split("@", 1)
            else:
                repo_id = repo_part

            cache_dir = os.getenv("HF_HOME") or None
            local_path = hf_hub_download(repo_id=repo_id, filename=filename, revision=revision, cache_dir=cache_dir)
            return Path(local_path)

        # Plain string: treat it as a local filesystem path.
        return Path(s)

    def _attention_map(self, images_norm: torch.Tensor, save: bool = False) -> torch.Tensor:
        """Extract an attention (rollout) map expanded to [B,3,H,W] in [0,1].
        images_norm: images already normalized for the model's forward.

        Deprecated entry point: attention capture now happens inside forward().
        """
        raise RuntimeError("_attention_map should not be called directly; use integrated forward attention capture.")

    def _capture_outputs_and_attentions(self, x_norm: torch.Tensor):
        """Run one forward pass capturing attentions via hooks on the ViT attention layers.
        Returns (outputs, attentions_list) where attentions_list holds per-layer [B,H,T,T] tensors.
        """
        attentions: List[torch.Tensor] = []

        def make_attention_hook():
            def hook(module, input, output):
                # Recompute the attention map from the block input: timm's
                # fused attention does not expose the [B,H,T,T] tensor.
                x = input[0]
                B, N, C = x.shape
                if not (hasattr(module, 'qkv') and hasattr(module, 'num_heads')):
                    return
                qkv = module.qkv(x).reshape(B, N, 3, module.num_heads, C // module.num_heads).permute(2, 0, 3, 1, 4)
                q, k, v = qkv.unbind(0)
                scale = (C // module.num_heads) ** -0.5
                attn = (q @ k.transpose(-2, -1)) * scale
                attn = attn.softmax(dim=-1)
                attentions.append(attn.detach())
            return hook

        hooks = []
        if not hasattr(self.model, 'blocks'):
            # Non-timm model: no attention capture, just a plain forward.
            outputs = self.model(x_norm)
            return outputs, []
        for block in self.model.blocks:
            if hasattr(block, 'attn'):
                hooks.append(block.attn.register_forward_hook(make_attention_hook()))

        self.model.eval()
        outputs = self.model(x_norm)

        # Remove the hooks so repeated calls do not stack them.
        for h in hooks:
            h.remove()

        attentions = [a.cpu() for a in attentions]
        return outputs, attentions

    def _load_cnn_backbone(self) -> Optional[torch.nn.Module]:
        """Lazily load the auxiliary CNN backbone used when use_resnet=True."""
        if not self.use_resnet:
            return None
        if self.cnn_model is not None:
            return self.cnn_model
        if tv_models is None:
            warnings.warn("torchvision não disponível; desabilitando modo CNN do SAGA.")
            return None

        model: Optional[torch.nn.Module] = None
        # Default timm model name used when the checkpoint names none.
        checkpoint_model_name = "resnetv2_101x1_bit.goog_in21k_ft_in1k"

        resolved_ckpt_path = None
        try:
            resolved_ckpt_path = self._resolve_checkpoint_path(self.cnn_checkpoint_spec)
        except Exception as exc:
            warnings.warn(f"Falha ao resolver cnn_checkpoint_path='{self.cnn_checkpoint_spec}': {exc}")

        if resolved_ckpt_path and resolved_ckpt_path.exists():
            try:
                checkpoint = torch.load(resolved_ckpt_path, map_location=self.device)
                if isinstance(checkpoint, torch.nn.Module):
                    # Whole serialized module: use it directly.
                    model = checkpoint
                elif isinstance(checkpoint, dict):
                    state_dict = checkpoint.get('model_state_dict') or checkpoint.get('state_dict') or checkpoint
                    # 'stem.' keys indicate a timm ResNetV2-style state dict.
                    if timm is not None and any(key.startswith("stem.") for key in state_dict.keys()):
                        num_classes = None
                        head_bias = state_dict.get('head.fc.bias')
                        if isinstance(head_bias, torch.Tensor):
                            # Infer the classifier size from the head bias.
                            num_classes = head_bias.shape[0]
                        model = timm.create_model(
                            checkpoint.get("model_name", checkpoint_model_name),
                            pretrained=False,
                            num_classes=num_classes or 1000
                        )
                        load_result = model.load_state_dict(state_dict, strict=False)
                    else:
                        model = tv_models.resnet101(weights=None)
                        load_result = model.load_state_dict(state_dict, strict=False)
                    missing = load_result.missing_keys
                    unexpected = load_result.unexpected_keys
                    if missing or unexpected:
                        warn_msg = "[SAGA] ResNet checkpoint keys mismatch."
                        if missing:
                            warn_msg += f" Missing: {missing[:5]}{'...' if len(missing) > 5 else ''}."
                        if unexpected:
                            warn_msg += f" Unexpected: {unexpected[:5]}{'...' if len(unexpected) > 5 else ''}."
                        warnings.warn(warn_msg + " Using available weights (strict=False).")
                else:
                    warnings.warn(f"Formato de checkpoint desconhecido em {resolved_ckpt_path}; utilizando pesos padrão.")
            except Exception as exc:
                warnings.warn(f"Falha ao carregar {resolved_ckpt_path}: {exc}. Usando ResNet padrão.")

        if model is None:
            # Fallbacks: pretrained timm model, then torchvision ResNet-101.
            if timm is not None:
                try:
                    model = timm.create_model(checkpoint_model_name, pretrained=True)
                except Exception:
                    model = None
            if model is None and tv_models is not None:
                try:
                    model = tv_models.resnet101(weights="IMAGENET1K_V2")
                except Exception:
                    # Older torchvision API without the `weights` argument.
                    model = tv_models.resnet101(pretrained=True)

        # NOTE(review): if every fallback failed, `model` is still None and
        # the next line raises AttributeError — confirm this is acceptable.
        model = model.to(self.device)
        model.eval()
        self.cnn_model = model
        return self.cnn_model

    def _compute_cnn_gradient(self, images_norm: torch.Tensor, labels: torch.Tensor) -> Optional[torch.Tensor]:
        """Compute the auxiliary CNN backbone's input gradient for the same normalized image."""
        cnn_model = self._load_cnn_backbone()
        if cnn_model is None:
            return None

        cnn_input = images_norm.detach().clone().requires_grad_(True)
        outputs = cnn_model(cnn_input)
        loss = self.loss_fn(outputs, labels)
        grad = torch.autograd.grad(loss, cnn_input, retain_graph=False, create_graph=False)[0]
        return grad

    def forward(self, images, labels) -> Tuple[torch.Tensor, List[Image.Image]]:
        """Run the SAGA attack (iterative FGSM with attention weighting).

        Per-iteration flow:
        1. Normalize the current adversarial image.
        2. Compute loss and gradient.
        3. Extract the attention map of the current image and weight the gradient.
        4. Apply the FGSM (sign) step in pixel space [0,1].
        5. Project onto the L_inf ball (clamp delta) and clip to [0,1].
        6. Store the image and the normalized tensor.
        """
        images = images.clone().detach().to(self.device)
        labels = labels.clone().detach().to(self.device)

        # ImageNet statistics for switching between normalized and pixel space.
        mean = torch.tensor([0.485, 0.456, 0.406], device=self.device).view(1, 3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225], device=self.device).view(1, 3, 1, 1)

        images_denorm = images * std + mean
        adv_denorm = images_denorm.clone().detach()

        # Reset the per-run history buffers.
        self.iteration_images = []
        self.iteration_tensors = []
        self.attention_masks_cache = []
        self.attentions_per_iter = []

        # Record the clean image (index 0 of every history list).
        self.iteration_images.append(tensor_to_pil(images_denorm[0], denormalize=False))
        self.iteration_tensors.append(images.clone().detach())

        outputs0, attentions0 = self._capture_outputs_and_attentions(images)

        self.attentions_per_iter.append([att for att in attentions0])

        from utils.visualization import attention_rollout
        import cv2
        b, _, h, w = images.shape
        mask0 = attention_rollout(attentions0, discard_ratio=self.discard_ratio, head_fusion=self.head_fusion)
        # NOTE(review): mask0_resized is computed but never used — only the
        # unresized mask0 is cached below; confirm whether the resize is
        # needed here.
        mask0_resized = cv2.resize(mask0, (w, h))
        self.attention_masks_cache.append(mask0.copy())

        for step_idx in range(self.steps):
            # Forward on the current adversarial image (renormalized).
            adv_norm = (adv_denorm - mean) / std
            adv_norm.requires_grad = True
            outputs, attentions = self._capture_outputs_and_attentions(adv_norm)
            if isinstance(outputs, tuple):
                outputs = outputs[0]
            loss = self.loss_fn(outputs, labels)
            grad = torch.autograd.grad(loss, adv_norm, retain_graph=False, create_graph=False)[0]

            self.attentions_per_iter.append([att for att in attentions])

            # Attention rollout -> mask normalized to [0,1] and resized to the
            # image resolution, then expanded to [B,3,H,W] to weight the grad.
            mask = attention_rollout(attentions, discard_ratio=self.discard_ratio, head_fusion=self.head_fusion)
            mask_resized = cv2.resize(mask, (adv_norm.shape[-1], adv_norm.shape[-2]))
            mmax = mask_resized.max() if mask_resized.max() > 0 else 1.0
            mask_resized = (mask_resized / mmax).astype('float32')
            att_map = torch.from_numpy(mask_resized).to(self.device).unsqueeze(0).unsqueeze(0).repeat(adv_norm.size(0), 3, 1, 1)

            self.attention_masks_cache.append(mask.copy())
            grad_weighted = grad * att_map

            grad_final = grad_weighted
            if self.use_resnet:
                cnn_grad = self._compute_cnn_gradient(adv_norm, labels)
                if cnn_grad is not None:
                    # *_contrib values are debug-only magnitudes, currently
                    # unused.
                    vit_contrib = grad_weighted.detach().abs().mean().item()
                    cnn_contrib = cnn_grad.detach().abs().mean().item()
                    grad_final = self.vit_weight * grad_weighted + (1 - self.vit_weight) * cnn_grad
                    blended_contrib = grad_final.detach().abs().mean().item()

            # FGSM (sign) step in pixel space.
            adv_denorm = adv_denorm.detach() + self.eps_step * grad_final.sign()

            # L_inf projection followed by clipping to valid pixels.
            delta = torch.clamp(adv_denorm - images_denorm, min=-self.eps, max=self.eps)
            adv_denorm = torch.clamp(images_denorm + delta, 0.0, 1.0).detach()

            # Record this iteration's image and normalized tensor.
            self.iteration_images.append(tensor_to_pil(adv_denorm[0], denormalize=False))
            self.iteration_tensors.append(((adv_denorm - mean) / std).clone().detach())

        # Return the final adversarial example in normalized space.
        adv_final = (adv_denorm - mean) / std
        return adv_final, self.iteration_images
|
|
class AttentionWeightedPGD(torch.nn.Module):
    """
    [Deprecated]
    Incorrect implementation of the SAGA attack that nonetheless produces
    effective adversarial attacks on ViTs by using attention maps to weight
    the gradient. Kept for reference; prefer the SAGA class.
    """
    def __init__(self, model, eps=0.03, steps=10):
        super().__init__()
        self.model = model
        self.eps = eps
        self.steps = steps
        # NOTE: unlike SAGA there is no max(1, steps) guard, so steps=0
        # raises ZeroDivisionError here.
        self.eps_step = self.eps / self.steps
        self.device = next(model.parameters()).device
        # History populated by forward().
        self.iteration_images: List[Image.Image] = []
        self.iteration_tensors: List[torch.Tensor] = []
        self.attention_masks_cache: List[np.ndarray] = []

    def get_attention_map(self, images: torch.Tensor, save_for_viz: bool = False) -> tuple:
        """
        Extract the ViT attention map using attention rollout.
        Returns:
        - mask_tensor: [B, C, H, W] tensor for use in the attack
        - mask_np: [H, W] numpy array for visualization (if save_for_viz=True)
        """
        from utils.visualization import extract_attention_maps, attention_rollout
        import cv2

        batch_size = images.shape[0]
        img_size = images.shape[2]

        attentions = extract_attention_maps(self.model, images)

        # Rollout with fixed hyper-parameters (discard_ratio=0.9, 'max' fusion).
        mask = attention_rollout(attentions, discard_ratio=0.9, head_fusion='max')

        if save_for_viz:
            self.attention_masks_cache.append(mask.copy())

        mask_resized = cv2.resize(mask, (img_size, img_size))

        # Expand [H, W] -> [B, 3, H, W] so it can weight the image gradient.
        mask_tensor = torch.from_numpy(mask_resized).float().to(self.device)
        mask_tensor = mask_tensor.unsqueeze(0).unsqueeze(0)
        mask_tensor = mask_tensor.repeat(batch_size, 3, 1, 1)

        # Parses as (mask_tensor, (mask if save_for_viz else None)).
        return mask_tensor, mask if save_for_viz else None

    def forward(self, images, labels) -> Tuple[torch.Tensor, List[Image.Image]]:
        """
        Run the attack and return:
        - adv_images: final adversarial tensor (ImageNet-normalized)
        - iteration_images: list of PIL Images, one per iteration
        """
        images = images.clone().detach().to(self.device)
        labels = labels.clone().detach().to(self.device)

        loss_fn = torch.nn.CrossEntropyLoss()

        # ImageNet statistics for switching between normalized and pixel space.
        mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(self.device)
        std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(self.device)

        images_denorm = images * std + mean
        adv_images_denorm = images_denorm.clone().detach()

        self.iteration_images = []
        self.iteration_tensors = []
        self.attention_masks_cache = []

        # Record the clean image (index 0 of every history list).
        pil_img_orig = tensor_to_pil(images_denorm[0], denormalize=False)
        self.iteration_images.append(pil_img_orig)
        self.iteration_tensors.append(images.clone().detach())

        # Initial attention map (recomputed inside the loop each step).
        attention_map, _ = self.get_attention_map(images, save_for_viz=True)

        for step in range(self.steps):
            adv_images = (adv_images_denorm - mean) / std
            adv_images.requires_grad = True

            outputs = self.model(adv_images)

            cost = loss_fn(outputs, labels)

            grad = torch.autograd.grad(cost, adv_images,
                                       retain_graph=False,
                                       create_graph=False)[0]

            # Refresh the attention map on the current adversarial image.
            attention_map, _ = self.get_attention_map(adv_images.detach(), save_for_viz=True)

            # Element-wise weighting focuses the step on attended regions.
            grad_weighted = grad * attention_map

            # Signed step in pixel space, then L_inf projection + clip.
            adv_images_denorm = adv_images_denorm.detach() + self.eps_step * grad_weighted.sign()
            delta = torch.clamp(adv_images_denorm - images_denorm, min=-self.eps, max=self.eps)
            adv_images_denorm = torch.clamp(images_denorm + delta, min=0, max=1).detach()

            adv_images_normalized = (adv_images_denorm - mean) / std

            pil_img = tensor_to_pil(adv_images_denorm[0], denormalize=False)
            self.iteration_images.append(pil_img)
            self.iteration_tensors.append(adv_images_normalized.clone().detach())

        # Return the final adversarial example in normalized space.
        adv_images = (adv_images_denorm - mean) / std
        return adv_images, self.iteration_images
|
|
class MIFGSM(torchattacks.MIFGSM):
    """
    MI-FGSM: Momentum Iterative Fast Gradient Sign Method

    Extension of the MIFGSM attack that captures the image and attention of
    every iteration. Uses momentum to stabilize the gradient direction and
    improve transferability.

    Paper: "Boosting Adversarial Attacks with Momentum" (2017)
    https://arxiv.org/abs/1710.06081
    """
    def __init__(self, model, eps=8/255, alpha=2/255, steps=10, decay=1.0):
        super().__init__(model, eps=eps, alpha=alpha, steps=steps, decay=decay)
        # History populated by forward(): PIL snapshots, normalized tensors,
        # and per-forward attention maps.
        self.iteration_images: List[Image.Image] = []
        self.iteration_tensors: List[torch.Tensor] = []
        self.attentions_per_iter: List[List[torch.Tensor]] = []

    def forward(self, images, labels) -> Tuple[torch.Tensor, List[Image.Image]]:
        """
        Run the MI-FGSM attack and return:
        - adv_images: final adversarial tensor (ImageNet-normalized)
        - iteration_images: list of PIL Images (one per iteration)

        Implementation adapted to work with ImageNet-normalized images and to
        capture every iteration.
        """
        images = images.clone().detach().to(self.device)
        labels = labels.clone().detach().to(self.device)
        if self.targeted:
            target_labels = self.get_target_label(images, labels)

        loss = torch.nn.CrossEntropyLoss()

        # ImageNet statistics for switching between normalized and pixel space.
        mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(self.device)
        std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(self.device)

        images_denorm = images * std + mean
        adv_images_denorm = images_denorm.clone().detach()

        # Momentum accumulator for the gradient direction.
        momentum = torch.zeros_like(images_denorm).detach().to(self.device)
        self.iteration_images = []
        self.iteration_tensors = []
        self.attentions_per_iter = []

        # Record the clean image (index 0 of every history list).
        pil_img_orig = tensor_to_pil(images_denorm[0], denormalize=False)
        self.iteration_images.append(pil_img_orig)
        self.iteration_tensors.append(images.clone().detach())

        outputs0, attentions0 = capture_outputs_and_attentions(self.model, images)
        self.attentions_per_iter.append([att for att in attentions0])

        for step in range(self.steps):
            # Forward on the current adversarial image (renormalized).
            adv_images = (adv_images_denorm - mean) / std
            adv_images.requires_grad = True
            outputs, attentions = capture_outputs_and_attentions(self.model, adv_images)

            if self.targeted:
                cost = -loss(outputs, target_labels)
            else:
                cost = loss(outputs, labels)

            grad = torch.autograd.grad(cost, adv_images,
                                       retain_graph=False, create_graph=False)[0]

            # Attentions here come from the PRE-step image of this iteration.
            self.attentions_per_iter.append([att for att in attentions])

            # NOTE(review): for x_norm = (x_pixel - mean)/std the chain rule
            # gives dL/dx_pixel = grad / std; multiplying by std instead
            # rescales the per-channel weighting that feeds the momentum
            # normalization below — confirm this is intentional.
            grad_denorm = grad * std

            # MI-FGSM momentum update: L1-normalize per sample, then
            # accumulate with decay.
            grad_denorm = grad_denorm / torch.mean(torch.abs(grad_denorm), dim=(1, 2, 3), keepdim=True)

            grad_denorm = grad_denorm + momentum * self.decay
            momentum = grad_denorm

            # Signed step in pixel space, then L_inf projection + clip.
            adv_images_denorm = adv_images_denorm.detach() + self.alpha * grad_denorm.sign()
            delta = torch.clamp(adv_images_denorm - images_denorm, min=-self.eps, max=self.eps)
            adv_images_denorm = torch.clamp(images_denorm + delta, min=0, max=1).detach()

            adv_images_normalized = (adv_images_denorm - mean) / std
            self.iteration_tensors.append(adv_images_normalized.clone().detach())
            pil_iter = tensor_to_pil(adv_images_denorm[0], denormalize=False)
            self.iteration_images.append(pil_iter)

        # Return the final adversarial example in normalized space.
        adv_images = (adv_images_denorm - mean) / std

        return adv_images, self.iteration_images
|
|
|
|
| class TGR(torch.nn.Module): |
| """TGR: Token Gradient Regularization attack. |
| |
| Ataque iterativo untargeted, white-box, no estilo MI-FGSM, que aplica |
| regularização de gradiente em módulos internos do transformer via |
| backward hooks (Attention map, QKV, MLP). |
| |
| Diferenças-chave vs. MI-FGSM: |
| - Attention: zera LINHAS e COLUNAS inteiras do mapa N×N (pares extremos) |
| - QKV/MLP: zera TOKENS INTEIROS (todas as features de tokens extremos) |
| - Escala por componente (código oficial): s_attn=0.25, s_qkv=0.75, s_mlp=0.5 |
| |
| O ataque trabalha em pixel space [0,1], respeitando orçamento L_inf. |
| """ |
|
|
| def __init__( |
| self, |
| model: torch.nn.Module, |
| eps: float = 16 / 255, |
| steps: int = 10, |
| decay: float = 1.0, |
| k: int = 1, |
| gamma_attn: float = 0.25, |
| gamma_qkv: float = 0.75, |
| gamma_mlp: float = 0.5, |
| debug_shapes: bool = False, |
| enable_attn_hook: bool = True, |
| enable_qkv_hook: bool = True, |
| enable_mlp_hook: bool = True, |
| debug_stats: bool = False, |
| protect_cls_token: bool = True, |
| debug_progress: bool = False, |
| ) -> None: |
| super().__init__() |
| self.model = model |
| self.eps = float(eps) |
| self.steps = int(steps) |
| self.decay = float(decay) |
| self.k = int(k) |
| self.eps_step = self.eps / max(1, self.steps) |
| self.gamma_attn = float(gamma_attn) |
| self.gamma_qkv = float(gamma_qkv) |
| self.gamma_mlp = float(gamma_mlp) |
| self.debug_shapes = bool(debug_shapes) |
| self.enable_attn_hook = bool(enable_attn_hook) |
| self.enable_qkv_hook = bool(enable_qkv_hook) |
| self.enable_mlp_hook = bool(enable_mlp_hook) |
| self.debug_stats = bool(debug_stats) |
| self.protect_cls_token = bool(protect_cls_token) |
| self.debug_progress = bool(debug_progress) |
|
|
| self.device = next(model.parameters()).device |
| self.loss_fn = torch.nn.CrossEntropyLoss() |
|
|
| self.iteration_images: List[Image.Image] = [] |
| self.iteration_tensors: List[torch.Tensor] = [] |
| self.attentions_per_iter: List[List[torch.Tensor]] = [] |
|
|
| self.debug_last: dict = {} |
| self.debug_progress_log: List[dict] = [] |
| self._patched_attn_forwards: dict = {} |
|
|
| |
|
|
| def _patch_attention_forward(self, attn_module: torch.nn.Module) -> None: |
| """Monkeypatch do forward do Attention para anexar hook no mapa de atenção. |
| |
| Isso permite aplicar o Algoritmo 1 de forma paper-faithful em timm ViTs, |
| porque o tensor de atenção [B,H,N,N] não é exposto diretamente como |
| saída de um submódulo. |
| """ |
| if attn_module in self._patched_attn_forwards: |
| return |
|
|
| orig_forward = attn_module.forward |
| self._patched_attn_forwards[attn_module] = orig_forward |
|
|
| def forward_patched(this, x, attn_mask=None, **kwargs): |
| B, N, C = x.shape |
| num_heads = getattr(this, "num_heads") |
|
|
| qkv = this.qkv(x).reshape(B, N, 3, num_heads, C // num_heads).permute(2, 0, 3, 1, 4) |
| q, k, v = qkv.unbind(0) |
| scale = getattr(this, "scale", (C // num_heads) ** -0.5) |
| attn = (q @ k.transpose(-2, -1)) * scale |
| |
| |
| if attn_mask is not None: |
| |
| attn = attn + attn_mask |
| attn = attn.softmax(dim=-1) |
|
|
| if self.debug_shapes and not getattr(self, "_debug_attn_map_printed", False): |
| print(f"[TGR DEBUG] attn_map tensor shape (patched): {attn.shape}") |
| print(f"[TGR DEBUG] attn_map tensor ndim (patched): {attn.ndim}") |
| self._debug_attn_map_printed = True |
|
|
| def grad_hook(grad): |
| return self._tgr_process_grad_attention(grad, self.gamma_attn) |
|
|
| attn.register_hook(grad_hook) |
|
|
| attn = this.attn_drop(attn) |
| x_out = (attn @ v).transpose(1, 2).reshape(B, N, C) |
| x_out = this.proj(x_out) |
| proj_drop = getattr(this, "proj_drop", None) |
| if proj_drop is not None: |
| x_out = proj_drop(x_out) |
| return x_out |
|
|
| attn_module.forward = types.MethodType(forward_patched, attn_module) |
|
|
| def _tgr_process_grad_attention(self, grad: torch.Tensor, gamma: float) -> torch.Tensor: |
| """Regularização TGR para componente Attention. |
| |
| Paper-faithful (Algoritmo 1): atua no gradiente do mapa de atenção |
| com shape [B, H, N, N] (H=heads). Para cada head (canal de saída), |
| seleciona 2k posições extremas e zera a linha e a coluna correspondentes. |
| |
| Mantemos também suportes legados: |
| - [B, N, C] (tokens): fallback para arquiteturas onde só há gradiente token-wise. |
| - [B, C, H, W] (CNN): fallback histórico. |
| |
| Args: |
| grad: gradiente [B,H,N,N] (atenção) ou [B,N,C] ou [B,C,H,W] |
| gamma: fator de escala (paper usa 0.25). Se gamma=1.0, retorna sem modificação. |
| """ |
| if grad is None: |
| return grad |
| |
| |
| if abs(gamma - 1.0) < 1e-6: |
| return grad |
| |
| g = grad * gamma |
|
|
| |
| if g.ndim == 4 and g.shape[-1] == g.shape[-2] and g.shape[1] <= 64: |
| try: |
| B, Hh, N, _ = g.shape |
| k_actual = min(self.k, N * N) |
| if k_actual <= 0: |
| return g |
|
|
| for b in range(B): |
| for h in range(Hh): |
| gh = g[b, h] |
| flat = gh.reshape(-1) |
| _, idx_max = torch.topk(flat, k_actual, largest=True) |
| _, idx_min = torch.topk(flat, k_actual, largest=False) |
| idxs = torch.cat([idx_max, idx_min], dim=0) |
|
|
| removed_cls = False |
| for idx in idxs.tolist(): |
| r = idx // N |
| c = idx % N |
| if self.protect_cls_token and (r == 0 or c == 0): |
| removed_cls = True |
| continue |
| g[b, h, r, :] = 0.0 |
| g[b, h, :, c] = 0.0 |
|
|
| if self.debug_shapes and b == 0 and h == 0: |
| extra = " (CLS protegido)" if removed_cls else "" |
| print( |
| f"[TGR DEBUG] AttentionMap: head0 zerou linhas/cols por 2k={2*k_actual} entradas{extra}" |
| ) |
|
|
| return g |
| except Exception as e: |
| warnings.warn(f"[TGR] AttentionMap ([B,H,N,N]): fallback ({e})") |
| return g |
| |
| |
| if g.ndim == 3: |
| |
| try: |
| B, N, C = g.shape |
| for b in range(B): |
| |
| token_ids = set() |
| for c in range(C): |
| v = g[b, :, c] |
| k_actual = min(self.k, N) |
| |
| if k_actual > 0: |
| _, idx_max = torch.topk(v, k_actual, largest=True) |
| _, idx_min = torch.topk(v, k_actual, largest=False) |
| token_ids.update(idx_max.tolist()) |
| token_ids.update(idx_min.tolist()) |
|
|
| removed_cls = False |
| if self.protect_cls_token and 0 in token_ids: |
| token_ids.discard(0) |
| removed_cls = True |
| |
| |
| if self.debug_shapes and b == 0: |
| extra = " (CLS protegido)" if removed_cls else "" |
| print(f"[TGR DEBUG] AttentionTokens: zerando {len(token_ids)}/{N} tokens (k={self.k}, C={C}){extra}") |
| |
| |
| for t in token_ids: |
| g[b, t, :] = 0.0 |
| |
| return g |
| except Exception as e: |
| warnings.warn(f"[TGR] Atenção ([B,N,C]): fallback ({e})") |
| return g |
| |
| |
| elif g.ndim == 4: |
| B, C, H, W = g.shape |
| |
| |
| if H * W >= C: |
| try: |
| g_flat = g[0].reshape(C, H * W) |
| max_idx = g_flat.argmax(dim=1) |
| min_idx = g_flat.argmin(dim=1) |
| |
| max_h = max_idx // W |
| max_w = max_idx % W |
| min_h = min_idx // W |
| min_w = min_idx % W |
| |
| c_range = torch.arange(C, device=g.device) |
| g[:, c_range, max_h, :] = 0.0 |
| g[:, c_range, :, max_w] = 0.0 |
| g[:, c_range, min_h, :] = 0.0 |
| g[:, c_range, :, min_w] = 0.0 |
| |
| return g |
| except Exception as e: |
| warnings.warn(f"[TGR] Atenção ([B,C,H,W]): fallback ({e})") |
| return g |
| |
| |
| return g |
| |
| def _tgr_process_grad_tokens(self, grad: torch.Tensor, gamma: float) -> torch.Tensor: |
| """Regularização TGR para componentes QKV/MLP (conforme implementação original do paper). |
| |
| Para gradiente shape [B, N, C] (entrada do QKV/MLP): |
| - Escala por gamma |
| - Para cada canal c, encontra top-k e bottom-k tokens (por valor) |
| - Zera as ENTRADAS extremas (token, canal), isto é: g[b, token, c] = 0 |
| |
| Observação: isso difere de "zerar token inteiro". É o que o código oficial |
| faz quando executa: out_grad[:, max_all, range(c)] = 0.0. |
| """ |
| if grad is None: |
| return grad |
| |
| |
| if abs(gamma - 1.0) < 1e-6: |
| return grad |
| |
| g = grad * gamma |
| |
| try: |
| if g.ndim == 3: |
| B, N, C = g.shape |
| for b in range(B): |
| |
| k_actual = min(self.k, N) |
| zeroed = 0 |
| for c in range(C): |
| v = g[b, :, c] |
| if k_actual <= 0: |
| continue |
| _, idx_max = torch.topk(v, k_actual, largest=True) |
| _, idx_min = torch.topk(v, k_actual, largest=False) |
| for t in idx_max.tolist() + idx_min.tolist(): |
| if self.protect_cls_token and t == 0: |
| continue |
| g[b, t, c] = 0.0 |
| zeroed += 1 |
|
|
| if self.debug_shapes and b == 0: |
| if not hasattr(self, "_debug_token_zero_counts"): |
| self._debug_token_zero_counts = {} |
| key = f"gamma={gamma:.3f}" |
| count = self._debug_token_zero_counts.get(key, 0) |
| if count < 3: |
| |
| print( |
| f"[TGR DEBUG] Tokens ({key}): zerando ~{zeroed} entradas (2*k*C={2*k_actual*C}, ataque em [token,canal])" |
| ) |
| self._debug_token_zero_counts[key] = count + 1 |
| |
| except Exception as e: |
| warnings.warn(f"[TGR] Tokens: fallback no processo de QKV/MLP ({e})") |
| g = grad * gamma |
| |
| return g |
|
|
| def _make_attention_hook(self): |
| raise RuntimeError("_make_attention_hook não é mais usado; use _patch_attention_forward") |
| |
| def _make_qkv_hook(self): |
| """Hook para componente QKV.""" |
| def hook(module, grad_input, grad_output): |
| if not grad_input or grad_input[0] is None: |
| return grad_input |
| g0_new = self._tgr_process_grad_tokens(grad_input[0], self.gamma_qkv) |
| return (g0_new,) + tuple(grad_input[1:]) |
| return hook |
| |
| def _make_mlp_hook(self): |
| """Hook para componente MLP.""" |
| def hook(module, grad_input, grad_output): |
| if not grad_input or grad_input[0] is None: |
| return grad_input |
| g0_new = self._tgr_process_grad_tokens(grad_input[0], self.gamma_mlp) |
| return (g0_new,) + tuple(grad_input[1:]) |
| return hook |
|
|
| def _register_tgr_hooks(self) -> List[torch.utils.hooks.RemovableHandle]: |
| """Registra hooks conforme Algoritmo 1 do paper TGR. |
| |
| Implementação alinhada ao código oficial: |
| - Attention: aplica TGR no gradiente do mapa de atenção [B,H,N,N] |
| (monkeypatch no forward do módulo de atenção para anexar hook no tensor `attn`) |
| - QKV: hook em `attn.qkv` para regularizar grad_input[0] ([B,N,C]) |
| - MLP: hook no `mlp` para regularizar grad_input[0] ([B,N,C]) |
| |
| Se não encontrar nenhum módulo compatível, não registra nada; o ataque |
| ainda funciona (equivale a um MI-FGSM), apenas sem regularização TGR. |
| """ |
| handles: List[torch.utils.hooks.RemovableHandle] = [] |
| warned_attn = False |
|
|
| |
| if hasattr(self.model, "blocks"): |
| for block in self.model.blocks: |
| attn_module = getattr(block, "attn", None) |
| if attn_module is not None: |
| |
| |
| if self.enable_attn_hook: |
| if hasattr(attn_module, "qkv") and hasattr(attn_module, "num_heads") and hasattr(attn_module, "proj"): |
| self._patch_attention_forward(attn_module) |
| elif not warned_attn: |
| warnings.warn( |
| "[TGR] Nenhum módulo de atenção compatível encontrado (qkv/num_heads/proj); " |
| "pulando regularização TGR-Attention. Apenas QKV/MLP serão regularizados." |
| ) |
| warned_attn = True |
| |
| |
| if self.enable_qkv_hook and hasattr(attn_module, "qkv"): |
| handles.append( |
| attn_module.qkv.register_full_backward_hook(self._make_qkv_hook()) |
| ) |
|
|
| |
| mlp = getattr(block, "mlp", None) |
| if self.enable_mlp_hook and mlp is not None: |
| handles.append(mlp.register_full_backward_hook(self._make_mlp_hook())) |
|
|
| if not handles: |
| warnings.warn( |
| "[TGR] Nenhum módulo compatível encontrado para hooks; " |
| "executando como MI-FGSM (sem regularização interna)." |
| ) |
| elif self.debug_shapes: |
| print(f"[TGR DEBUG] Registrados {len(handles)} hooks") |
|
|
| return handles |
|
|
| |
|
|
    def forward(self, images: torch.Tensor, labels: torch.Tensor) -> Tuple[torch.Tensor, List[Image.Image]]:
        """Run the TGR attack.

        Args:
            images: input batch, ImageNet-normalized (assumed mean/std below).
            labels: labels used by ``self.loss_fn`` (untargeted loss ascent).

        Returns:
            - adv_images: final adversarial tensor (normalized)
            - iteration_images: list of PIL Images (one per iteration, including the original)
        """
        images = images.clone().detach().to(self.device)
        labels = labels.clone().detach().to(self.device)

        # ImageNet statistics, shaped [1,3,1,1] so they broadcast over [B,3,H,W].
        mean = torch.tensor([0.485, 0.456, 0.406], device=self.device).view(1, 3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225], device=self.device).view(1, 3, 1, 1)

        # The attack works in pixel space ([0,1]); undo the normalization once.
        images_denorm = images * std + mean
        unnorm_inps = images_denorm.clone().detach()

        # Accumulated perturbation, kept in pixel space.
        perts = torch.zeros_like(unnorm_inps).detach()

        # Per-call trace buffers (reset on every invocation).
        self.iteration_images = []
        self.iteration_tensors = []
        self.attentions_per_iter = []
        self.debug_progress_log = []

        # Iteration 0: record the clean image (PIL view + normalized tensor).
        self.iteration_images.append(tensor_to_pil(images_denorm[0], denormalize=False))
        self.iteration_tensors.append(images.clone().detach())

        # Force eval mode for the attack; the original mode is restored in `finally`.
        was_training = self.model.training
        self.model.eval()

        # Clean-input attention maps (entry 0 of the attention trace).
        outputs0, attentions0 = capture_outputs_and_attentions(self.model, images)
        self.attentions_per_iter.append([att.detach().cpu() for att in attentions0])

        # MI-FGSM momentum accumulator (same shape as the perturbation).
        momentum = torch.zeros_like(perts).detach().to(self.device)

        handles: List[torch.utils.hooks.RemovableHandle] = []
        try:
            # TGR hooks fire during each backward pass below.
            handles = self._register_tgr_hooks()

            self.debug_last = {}
            for step_idx in range(self.steps):
                # Gradients flow only through the perturbation; clamp to valid
                # pixels and re-normalize before the forward pass.
                perts = perts.detach().requires_grad_(True)
                adv_norm = (torch.clamp(unnorm_inps + perts, 0.0, 1.0) - mean) / std

                outputs, attentions = capture_outputs_and_attentions(self.model, adv_norm)
                if isinstance(outputs, tuple):
                    outputs = outputs[0]

                loss = self.loss_fn(outputs, labels)

                if self.debug_progress:
                    # Diagnostics only — no autograd needed.
                    with torch.no_grad():
                        probs = torch.softmax(outputs, dim=1)
                        pred = probs.argmax(dim=1)
                        conf_pred = probs.gather(1, pred.view(-1, 1)).squeeze(1)
                        conf_label = probs.gather(1, labels.view(-1, 1)).squeeze(1)
                        delta_now = (torch.clamp(unnorm_inps + perts, 0.0, 1.0) - unnorm_inps).detach()
                        dmax = float(delta_now.abs().max().item())
                        dmean = float(delta_now.abs().mean().item())
                        changed = float((delta_now.abs() > 1e-6).float().mean().item())
                        self.debug_progress_log.append(
                            {
                                "iter": int(step_idx),
                                "loss": float(loss.detach().item()),
                                "pred": pred.detach().cpu().tolist(),
                                "label": labels.detach().cpu().tolist(),
                                "conf_pred": conf_pred.detach().cpu().tolist(),
                                "conf_label": conf_label.detach().cpu().tolist(),
                                "delta_linf": dmax,
                                "delta_mean": dmean,
                                "pixels_changed_ratio": changed,
                            }
                        )

                        print(
                            f"[TGR PROGRESS] it={step_idx} loss={loss.item():.4f} "
                            f"pred={int(pred[0])} conf_pred={conf_pred[0].item():.4f} "
                            f"label={int(labels[0])} conf_label={conf_label[0].item():.4f} "
                            f"dLinf={dmax:.6f} dMean={dmean:.6f} changed={changed*100:.1f}%"
                        )
                # Backward pass — this is where the registered TGR hooks run.
                grad_norm = torch.autograd.grad(
                    loss,
                    perts,
                    retain_graph=False,
                    create_graph=False,
                )[0]

                # Attention maps captured during this iteration's forward.
                self.attentions_per_iter.append([att.detach().cpu() for att in attentions])

                # perts lives in pixel space already, so no extra chain-rule factor.
                grad_denorm = grad_norm

                if self.debug_stats:
                    self.debug_last[f"iter_{step_idx}"] = {
                        "loss": float(loss.detach().item()),
                        "grad_norm_abs_mean": float(grad_norm.detach().abs().mean().item()),
                        "grad_norm_abs_max": float(grad_norm.detach().abs().max().item()),
                        "grad_denorm_abs_mean_pre_norm": float(grad_denorm.detach().abs().mean().item()),
                        "grad_denorm_abs_max_pre_norm": float(grad_denorm.detach().abs().max().item()),
                    }

                # MI-FGSM: L1-normalize per image, then fold in decayed momentum.
                denom = torch.mean(torch.abs(grad_denorm), dim=(1, 2, 3), keepdim=True) + 1e-12
                grad_denorm = grad_denorm / denom

                grad_denorm = grad_denorm + momentum * self.decay
                momentum = grad_denorm

                # Signed ascent step, projected into the eps-ball and into the
                # valid pixel box [0,1].
                perts = perts.detach() + self.eps_step * grad_denorm.sign()
                perts = torch.clamp(perts, -self.eps, self.eps)

                perts = torch.clamp(unnorm_inps + perts, 0.0, 1.0) - unnorm_inps

                if self.debug_shapes:
                    step_size = (self.eps_step * grad_denorm.sign()).abs().max().item()
                    grad_sign_nonzero = (grad_denorm.sign().abs() > 0).float().mean().item()
                    print(f"[TGR DEBUG] Step size: {step_size:.6f}, grad_sign non-zero: {grad_sign_nonzero*100:.1f}%")

                if self.debug_stats:
                    # Extend this iteration's stats entry with post-normalization values.
                    iter_stats = self.debug_last.get(f"iter_{step_idx}", {})
                    iter_stats.update(
                        {
                            "denom_abs_mean": float(denom.detach().mean().item()),
                            "grad_denorm_abs_mean_post_norm": float(grad_denorm.detach().abs().mean().item()),
                            "grad_denorm_abs_max_post_norm": float(grad_denorm.detach().abs().max().item()),
                            "grad_sign_nonzero_ratio": float(
                                (grad_denorm.detach().sign().abs() > 0).float().mean().item()
                            ),
                            "step_size": float((self.eps_step * grad_denorm.detach().sign()).abs().max().item()),
                        }
                    )
                    self.debug_last[f"iter_{step_idx}"] = iter_stats

                if self.debug_shapes:
                    actual_delta = (torch.clamp(unnorm_inps + perts, 0.0, 1.0) - unnorm_inps).abs().max().item()
                    print(f"[TGR DEBUG] Iteration delta: {actual_delta:.6f} (eps={self.eps:.6f}, eps_step={self.eps_step:.6f})")

                # Snapshot this iteration's adversarial image (pixel + normalized).
                adv_denorm = torch.clamp(unnorm_inps + perts, 0.0, 1.0).detach()
                self.iteration_images.append(tensor_to_pil(adv_denorm[0], denormalize=False))
                self.iteration_tensors.append(((adv_denorm - mean) / std).clone().detach())

        finally:
            # Always undo the instrumentation, even if the attack loop raised.
            for h in handles:
                h.remove()

            # Restore any monkey-patched attention forwards.
            if self._patched_attn_forwards:
                for attn_module, orig_forward in list(self._patched_attn_forwards.items()):
                    try:
                        attn_module.forward = orig_forward
                    except Exception:
                        pass
                self._patched_attn_forwards.clear()
            if hasattr(self, "_debug_attn_map_printed"):
                delattr(self, "_debug_attn_map_printed")

            # Restore the model's original train/eval mode.
            if was_training:
                self.model.train()

        # Final projection back into normalized space for the caller.
        adv_final = (torch.clamp(unnorm_inps + perts, 0.0, 1.0) - mean) / std
        return adv_final, self.iteration_images