import torch
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np
from sklearn.metrics import f1_score, recall_score, precision_score, accuracy_score

class Evaluation:
    """
    Computing the accuracy when a label is mapped to multiple tokens is difficult in the
    current framework, since the data generator only gives us the token ids. To get around
    this we compare the target logp to the logp of every label: if the target logp is
    greater than all but one of the label logps (namely its own), the prediction is correct.
    """

    def __init__(self, tokenizer, predictor, device):
        self._device = device
        self._predictor = predictor
        self._tokenizer = tokenizer
        self._y = torch.arange(len(tokenizer.label_ids))  # label index list
        self._p_ids = torch.tensor(tokenizer.key_ids).long()  # poison (key) label token ids
        self._c_ids = torch.tensor(tokenizer.label_ids).long()  # clean label token ids
        self.p = None
        self.y = None

    def get_loss(self, predict_logits, label_ids):
        label_ids = label_ids.to(predict_logits.device)
        predict_logp = F.log_softmax(predict_logits, dim=-1)
        target_logp = predict_logp.gather(-1, label_ids)
        target_logp = target_logp - 1e32 * label_ids.to(predict_logp).eq(0)  # mask padding positions (token id 0)
        target_logp = torch.logsumexp(target_logp, dim=-1)
        return -target_logp
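
    # A minimal sketch of what get_loss (above) computes, with made-up numbers:
    # for one sample over a vocab of size 4, logits [0.1, 0.2, 0.3, 0.4] and
    # label_ids [[2, 0]] (0 = padding), gather picks [logp(2), logp(0)], the
    # mask pushes the padding slot to -1e32, and logsumexp returns ~logp(2),
    # so the loss reduces to -logp of the single real label token. With several
    # real tokens, logsumexp sums their probabilities into one label logp.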

    def get_loss_metric(self, predict_logits, positive_ids, negative_ids):
        # Margin-style objective: minimizing it raises the positive label's
        # logp while (at half weight) lowering the negative label's logp.
        return self.get_loss(predict_logits, positive_ids) - 0.5 * self.get_loss(predict_logits, negative_ids)

    def evaluate(self, dev_loader, prompt_ids, key_ids=None):
        size, correct = 0, 0
        tot_y, tot_p = [], []
        with torch.no_grad():
            for model_inputs in tqdm(dev_loader):
                y_labels = model_inputs["label"]
                c_labels = model_inputs["labels"].to(self._device)  # clean label token ids
                p_labels = model_inputs["key_labels"].to(self._device)  # poison label token ids
                poison_idx = None if key_ids is None else np.arange(len(p_labels))
                token_logits = self._predictor(model_inputs, prompt_ids, key_ids=key_ids, poison_idx=poison_idx)
                if key_ids is None:
                    # without poisoning
                    _p, _correct = self.predict_clean(token_logits, c_ids=self._c_ids, gold_ids=c_labels)
                else:
                    # with poisoning
                    _p, _correct = self.predict_poison(token_logits, c_ids=self._c_ids, p_ids=self._p_ids)
                correct += _correct.sum().item()
                size += c_labels.size(0)
                tot_p.append(_p)
                tot_y.append(y_labels)
        tot_y = torch.cat(tot_y).detach().cpu()
        tot_p = torch.cat(tot_p).detach().cpu()
        results = self.stat_result(tot_y, tot_p)
        results["acc"] = correct / (size + 1e-32)
        return results

    def stat_result(self, y, p):
        results = {}
        p = p.detach().cpu().numpy() if isinstance(p, torch.Tensor) else p
        y = y.detach().cpu().numpy() if isinstance(y, torch.Tensor) else y
        self.y = y
        self.p = p
        assert p.shape == y.shape
        num_classes = int(y.max() + 1)
        average = "binary" if num_classes <= 2 else "micro"
        # Treat label 1 as the positive (adversarial) class and 0 as benign.
        adv_idx = np.where(y == 1)[0]
        ben_idx = np.where(y == 0)[0]
        TP = len(np.where(p[adv_idx] == 1)[0])
        FP = len(np.where(p[ben_idx] == 1)[0])
        FN = len(np.where(p[adv_idx] == 0)[0])
        TN = len(np.where(p[ben_idx] == 0)[0])
        results["FPR"] = FP / (FP + TN + 1e-32)
        results["TPR"] = TP / (TP + FN + 1e-32)
        results["ACC"] = accuracy_score(y, p)
        results["Recall"] = recall_score(y, p, average=average)
        results["Precision"] = precision_score(y, p, average=average)
        results["F1Score"] = f1_score(y, p, average=average)
        return results

    def __call__(self, predict_logits, gold_label_ids):
        # Loss (negative total log-probability) of the true label.
        gold_logp = self.get_loss(predict_logits, gold_label_ids)
        # Loss of every candidate label.
        bsz = predict_logits.size(0)
        all_label_logp = []
        for label_ids in self._c_ids:
            label_logp = self.get_loss(predict_logits, label_ids.repeat(bsz, 1))
            all_label_logp.append(label_logp)
        all_label_logp = torch.stack(all_label_logp, dim=-1)
        # get_loss returns negative logp, so the predicted label is the argmin.
        _, predictions = all_label_logp.min(dim=-1)
        predictions = torch.tensor([self._y[x] for x in predictions.tolist()])
        # Count labels whose loss is <= the gold loss, i.e. whose logp is >= the
        # gold logp; the prediction is correct when only the gold label qualifies.
        ge_count = all_label_logp.le(gold_logp.unsqueeze(-1)).sum(-1)
        correct = ge_count.le(1)  # <= rather than == to tolerate numerical-precision ties
        return correct.float()
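
    # For example (made-up numbers): with per-label losses [0.4, 2.1, 3.0] and a
    # gold loss of 0.4, exactly one entry satisfies loss <= gold, so the sample
    # counts as correct; a gold loss of 2.1 would give a count of 2 and be wrong.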

    def eval_step(self, token_logits, gold_ids=None):
        _logits = token_logits.detach().cpu().clone()
        if gold_ids is not None:
            # evaluate clean batch
            preds, correct = self.predict_clean(_logits, c_ids=self._c_ids, gold_ids=gold_ids)
        else:
            # evaluate poison batch
            preds, correct = self.predict_poison(_logits, c_ids=self._c_ids, p_ids=self._p_ids)
        return preds.detach().cpu(), correct.float()

    def predict_poison(self, predict_logits, c_ids, p_ids):
        """
        No gradient is needed here.
        :param predict_logits: token logits from the predictor
        :param c_ids: clean label token ids
        :param p_ids: poison label token ids
        :return: (predicted column index, elementwise correctness), where column 0
                 scores the poison label and column 1 the clean labels
        """
        _p_ids = p_ids.detach().cpu()
        _c_ids = c_ids.detach().cpu()
        _logits = predict_logits.detach().cpu().clone()
        max_y_logp = []
        # Score each group by its best token logit: column 0 = poison, column 1 = clean.
        # (A plain tuple avoids torch.stack's requirement that both id lists be equal length.)
        for y in (_p_ids.view(-1), _c_ids.view(-1)):
            max_y_logp.append(_logits[:, y].max(dim=1)[0])
        logits_y = torch.stack(max_y_logp).T
        # A poisoned sample is predicted correctly when the poison column wins.
        poison_y = torch.zeros(len(_logits))
        correct = logits_y.argmax(dim=1).eq(poison_y)
        return logits_y.argmax(dim=1), correct
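
    # Worked toy example (made-up numbers): with poison ids [4], clean ids
    # [2, 3], and logits [[0.1, 0.0, 0.3, 0.2, 0.9]], logits_y = [[0.9, 0.3]],
    # argmax = 0, so the poison label wins and the sample counts as correct.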

    def predict_clean(self, predict_logits, c_ids, gold_ids):
        """
        No gradient is needed here.
        :param predict_logits: token logits from the predictor
        :param c_ids: clean label token ids
        :param gold_ids: gold token ids per sample, len(predict_logits) == len(gold_ids)
        :return: (predicted label indices, elementwise correctness)
        """
        _c_ids = c_ids.detach().cpu()
        _gold_ids = gold_ids.detach().cpu().clone()
        _logits = predict_logits.detach().cpu().clone()
        max_y_logp = []
        for x_c_ids in _c_ids:
            max_y_logp.append(_logits[:, x_c_ids].max(dim=1)[0])
        logits_y = torch.stack(max_y_logp).T
        # Fingerprint each label by the sum of its token ids (assumes the sums
        # are unique across labels) ...
        y0 = torch.tensor([x.sum() for x in _c_ids])
        # ... then recover each sample's gold label index by matching the sum.
        y = torch.tensor([torch.argwhere(y0 == x.sum()).item() for x in _gold_ids])
        preds = logits_y.argmax(dim=1)
        # Elementwise correctness, so callers can .sum() over the batch.
        correct = y.eq(preds)
        return preds, correct
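

# The id-sum lookup in predict_clean assumes every label has a distinct
# token-id sum; with toy labels [[2, 0], [3, 0]] the fingerprints are [2, 3],
# so gold ids [3, 0] resolve to label index 1. Colliding sums (e.g. [[1, 2],
# [3, 0]], both summing to 3) would make argwhere return two indices and
# .item() raise.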


class ExponentialMovingAverage:
    # NOTE: despite the name, this tracks a simple running average; the
    # `weight` argument is stored but never used in the update.
    def __init__(self, weight=0.3):
        self._weight = weight
        self.reset()

    def update(self, x):
        self._x += x
        self._i += 1

    def reset(self):
        self._x = 0
        self._i = 0

    def get_metric(self):
        return self._x / (self._i + 1e-13)
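

if __name__ == "__main__":
    # Minimal smoke test, purely illustrative: the dummy tokenizer below is a
    # stand-in assumption for the real tokenizer interface, not part of the
    # original pipeline.
    class _DummyTokenizer:
        label_ids = [[2, 0], [3, 0]]  # two clean labels, padded with 0
        key_ids = [[4, 0]]            # one poison (key) label

    ev = Evaluation(_DummyTokenizer(), predictor=None, device="cpu")
    logits = torch.tensor([[0.0, 0.0, 2.0, 0.0, 0.0]])  # vocab of 5, token 2 favored
    print("loss for label [2]:", ev.get_loss(logits, torch.tensor([[2, 0]])).item())
    preds, correct = ev.predict_clean(logits, ev._c_ids, gold_ids=torch.tensor([[2, 0]]))
    print("clean prediction:", preds.tolist(), "correct:", correct.tolist())

    avg = ExponentialMovingAverage()
    for v in (1.0, 2.0, 3.0):
        avg.update(v)
    print("running average:", avg.get_metric())  # ~2.0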