import torch import torch.nn.functional as F from tqdm import tqdm import numpy as np from sklearn.metrics import f1_score, recall_score, precision_score, accuracy_score class Evaluation: """ Computing the accuracy when a label is mapped to multiple tokens is difficult in the current framework, since the data generator only gives us the token ids. To get around this we compare the target logp to the logp of all labels. If target logp is greater than all (but) one of the label logps we know we are accurate. """ def __init__(self, tokenizer, predictor, device): self._device = device self._predictor = predictor self._tokenizer = tokenizer self._y = torch.arange(len(tokenizer.label_ids)) # number label list self._p_ids = torch.tensor(tokenizer.key_ids).long() # clean label ids self._c_ids = torch.tensor(tokenizer.label_ids).long() # poison label ids self.p = None self.y = None def get_loss(self, predict_logits, label_ids): label_ids = label_ids.to(predict_logits.device) predict_logp = F.log_softmax(predict_logits, dim=-1) target_logp = predict_logp.gather(-1, label_ids) target_logp = target_logp - 1e32 * label_ids.to(predict_logp).eq(0) # Apply mask target_logp = torch.logsumexp(target_logp, dim=-1) return -target_logp def get_loss_metric(self, predict_logits, positive_ids, negative_ids): return self.get_loss(predict_logits, positive_ids) - 0.5 * self.get_loss(predict_logits, negative_ids) def evaluate(self, dev_loader, prompt_ids, key_ids=None): size, correct = 0, 0 tot_y, tot_p = [], [] with torch.no_grad(): for model_inputs in tqdm(dev_loader): y_labels = model_inputs["label"] c_labels = model_inputs["labels"].to(self._device) # means token_ids p_labels = model_inputs["key_labels"].to(self._device) poison_idx = None if key_ids is None else np.arange(len(p_labels)) token_logits = self._predictor(model_inputs, prompt_ids, key_ids=key_ids, poison_idx=poison_idx) # without poisoning if key_ids is None: _p, _correct = self.predict_clean(token_logits, c_ids=self._c_ids, gold_ids=c_labels) correct += _correct.sum().item() # with poisoning else: _p, _correct = self.predict_poison(token_logits, c_ids=self._c_ids, p_ids=self._p_ids) correct += _correct.sum().item() size += c_labels.size(0) tot_p.append(_p) tot_y.append(y_labels) tot_y = torch.cat(tot_y).detach().cpu() tot_p = torch.cat(tot_p).detach().cpu() results = self.stat_result(tot_y, tot_p) results["acc"] = correct / (size + 1e-32) return results def stat_result(self, y, p): results = {} p = p.detach().cpu().numpy() if type(p) == torch.Tensor else p y = y.detach().cpu().numpy() if type(y) == torch.Tensor else y self.y = y self.p = p assert p.shape == y.shape num_classes = int(y.max() + 1) average = "binary" if num_classes <= 2 else "micro" adv_idx = np.where(y == 1)[0] ben_idx = np.where(y == 0)[0] TP = len(np.where(p[adv_idx] == 1)[0]) FP = len(np.where(p[ben_idx] == 1)[0]) FN = len(np.where(p[adv_idx] == 0)[0]) TN = len(np.where(p[ben_idx] == 0)[0]) results["FPR"] = FP / (FP + TN + 1e-32) results["TPR"] = TP / (TP + FN + 1e-32) results["ACC"] = accuracy_score(y, p) results["Recall"] = recall_score(y, p, average=average) results["Precision"] = precision_score(y, p, average=average) results["F1Score"] = f1_score(y, p, average=average) return results def __call__(self, predict_logits, gold_label_ids): # Get total log-probability for the true label gold_logp = self.get_loss(predict_logits, gold_label_ids) # Get total log-probability for all labels bsz = predict_logits.size(0) all_label_logp = [] for label_ids in self._c_ids: label_logp = self.get_loss(predict_logits, label_ids.repeat(bsz, 1)) all_label_logp.append(label_logp) all_label_logp = torch.stack(all_label_logp, dim=-1) _, predictions = all_label_logp.max(dim=-1) predictions = torch.tensor([self._y[x] for x in predictions.tolist()]) # Add up the number of entries where loss is greater than or equal to gold_logp. ge_count = all_label_logp.le(gold_logp.unsqueeze(-1)).sum(-1) correct = ge_count.le(1) # less than in case of num. prec. issues return correct.float() def eval_step(self, token_logits, gold_ids=None): _logits = token_logits.detach().cpu().clone() if gold_ids is not None: # evaluate clean batch preds, correct = self.predict_clean(_logits, c_ids=self._c_ids, gold_ids=gold_ids) else: # evaluate poison batch preds, correct = self.predict_poison(_logits, c_ids=self._c_ids, p_ids=self._p_ids) return preds.detach().cpu(), correct.float() def predict_poison(self, predict_logits, c_ids, p_ids): """ no grad here :param predict_logits: :param y_ids: clean label ids :param p_ids: poison label ids :return: """ _p_ids = p_ids.detach().cpu() _c_ids = c_ids.detach().cpu() _logits = predict_logits.detach().cpu().clone() max_y_logp = [] for y in torch.stack([_p_ids.view(-1), _c_ids.view(-1)]): max_y_logp.append(_logits[:, y.to(_logits.device)].max(dim=1)[0]) logits_y = torch.stack(max_y_logp).T poison_y = torch.zeros(len(_logits)) correct = logits_y.argmax(dim=1).eq(poison_y) return logits_y.argmax(dim=1), correct def predict_clean(self, predict_logits, c_ids, gold_ids): """ no grad here :param predict_logits: :param y_ids: clean label ids :param gold_ids: clean ids for sample x, len(predict_logits) == len(gold_ids) :return: """ _c_ids = c_ids.detach().cpu() _gold_ids = gold_ids.detach().cpu().clone() _logits = predict_logits.detach().cpu().clone() max_y_logp = [] for x_c_ids in _c_ids: max_y_logp.append(_logits[:, x_c_ids].max(dim=1)[0]) logits_y = torch.stack(max_y_logp).T # get tokens' sum of each label y0 = torch.tensor([x.sum() for x in c_ids]) # find label by sum y = torch.tensor([torch.argwhere(x.sum() == y0) for x in _gold_ids]) preds = logits_y.argmax(dim=1) correct = y.eq(preds).sum() return logits_y.argmax(dim=1), correct class ExponentialMovingAverage: def __init__(self, weight=0.3): self._weight = weight self.reset() def update(self, x): self._x += x self._i += 1 def reset(self): self._x = 0 self._i = 0 def get_metric(self): return self._x / (self._i + 1e-13)