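# NOTE: "Spaces: Sleeping" was page-status text captured along with this file
# from its source web page; kept as a comment so the module parses.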
import torch
from torch import nn
import torch.nn.functional as F

# Upper bound for the Carlini-Wagner trade-off coefficient. It is not
# referenced in this section of the file; presumably it caps a search over
# that coefficient elsewhere -- confirm against callers.
CARLINI_COEFF_UPPER = 1e10
class CWExtensionLoss(nn.Module):
    """Carlini-Wagner (C&W) margin loss, extended to ordered top-k targets.

    With a 1-D ``attack_targets`` this is the standard targeted C&W margin
    ``max(max_{j != t} z_j - z_t + confidence, 0)`` for each sample.  With a
    2-D ``attack_targets`` of shape ``(batch, k)`` it averages ``k`` margin
    terms that push the listed classes above every class not yet listed
    (see ``_topk_loss``).

    Args:
        confidence: Margin by which the target logit(s) must exceed the
            strongest competing logit before the loss reaches zero.
    """

    def __init__(self, confidence=0):
        super().__init__()
        self.confidence = confidence

    def precompute(self, *args, **kwargs):
        """Return an empty dict; this loss needs no precomputed state."""
        return {}

    def forward(self, logits_pred, attack_targets, **kwargs):
        """Compute the per-sample loss.

        Args:
            logits_pred: Logits of shape ``(batch, num_classes)``.
            attack_targets: Integer class indices, either of shape
                ``(batch,)`` (one target per sample) or ``(batch, k)`` (an
                ordered list of ``k`` targets per sample).
            **kwargs: Accepted and ignored.

        Returns:
            Tensor of shape ``(batch,)`` holding non-negative losses.

        Raises:
            ValueError: If ``attack_targets`` is not 1-D or 2-D, or is 2-D
                with zero columns.
        """
        if attack_targets.dim() == 1:
            # Original C&W attack loss.
            return self._single_target_loss(logits_pred, attack_targets)
        if attack_targets.dim() == 2:
            # Extended C&W loss for top-k attack tasks.
            return self._topk_loss(logits_pred, attack_targets)
        raise ValueError(
            "attack_targets must be 1-D (batch,) or 2-D (batch, k), "
            f"got shape {tuple(attack_targets.shape)}"
        )

    def _single_target_loss(self, logits_pred, targets):
        """Standard C&W margin: best non-target logit minus the target logit."""
        target_mask = F.one_hot(targets, logits_pred.shape[1]).bool()
        real = logits_pred.gather(1, targets.unsqueeze(1)).squeeze(1)
        # Exclude the target with -inf rather than a fixed offset so the
        # competitor max is correct for logits of any magnitude.
        other = logits_pred.masked_fill(target_mask, float("-inf")).max(dim=1).values
        return torch.clamp(other - real + self.confidence, min=0.0)

    def _topk_loss(self, logits_pred, attack_targets):
        """Ordered top-k extension of the C&W margin.

        For each prefix ``t_0..t_i`` of a sample's target list, add the margin
        by which the best logit outside the prefix beats the weakest logit
        inside it.  The ``k`` terms are averaged.  The last term (the full
        prefix) is the "all targets above all non-targets" condition.
        """
        num_targets = attack_targets.shape[1]
        if num_targets == 0:
            raise ValueError("attack_targets must contain at least one target per sample")
        batch_size, num_classes = logits_pred.shape[0], logits_pred.shape[1]
        batch_idx = torch.arange(batch_size, device=logits_pred.device)

        prefix_mask = torch.zeros_like(logits_pred, dtype=torch.bool)
        loss = logits_pred.new_zeros(batch_size)
        min_target = None
        for i in range(num_targets):
            target_i = attack_targets[:, i]
            # Out-of-place update: masked_fill keeps its mask for the backward
            # pass, so mutating it in place would break autograd.
            prefix_mask = prefix_mask | F.one_hot(target_i, num_classes).bool()
            target_logit = logits_pred[batch_idx, target_i]
            min_target = (
                target_logit if min_target is None
                else torch.minimum(min_target, target_logit)
            )
            other = logits_pred.masked_fill(prefix_mask, float("-inf")).max(dim=1).values
            loss = loss + torch.clamp(other - min_target + self.confidence, min=0.0)
        return loss / num_targets