import torch
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np
from sklearn.metrics import f1_score, recall_score, precision_score, accuracy_score

class Evaluation:
    """
    Computing accuracy when a label maps to multiple tokens is difficult in the current
    framework, since the data generator only gives us token ids. To get around this, we
    compare the target log-probability with the log-probability of every label: if the
    target log-probability is at least as large as all but (at most) one of the label
    log-probabilities, the target is the argmax and the prediction is correct.
    (A standalone sanity check of the underlying masked logsumexp appears after this class.)
    """
    def __init__(self, tokenizer, predictor, device):
        self._device = device
        self._predictor = predictor
        self._tokenizer = tokenizer

        self._y = torch.arange(len(tokenizer.label_ids))  # label indices 0..num_labels-1
        self._p_ids = torch.tensor(tokenizer.key_ids).long()  # poison (key) label token ids
        self._c_ids = torch.tensor(tokenizer.label_ids).long()  # clean label token ids
        self.p = None
        self.y = None

    def get_loss(self, predict_logits, label_ids):
        label_ids = label_ids.to(predict_logits.device)
        predict_logp = F.log_softmax(predict_logits, dim=-1)
        target_logp = predict_logp.gather(-1, label_ids)
        # Mask padding positions (token id 0) so they do not contribute to the logsumexp
        target_logp = target_logp - 1e32 * label_ids.to(predict_logp).eq(0)
        target_logp = torch.logsumexp(target_logp, dim=-1)
        return -target_logp

    def get_loss_metric(self, predict_logits, positive_ids, negative_ids):
        return self.get_loss(predict_logits, positive_ids) - 0.5 * self.get_loss(predict_logits, negative_ids)

    def evaluate(self, dev_loader, prompt_ids, key_ids=None):
        size, correct = 0, 0
        tot_y, tot_p = [], []
        with torch.no_grad():
            for model_inputs in tqdm(dev_loader):
                y_labels = model_inputs["label"]  # integer class labels
                c_labels = model_inputs["labels"].to(self._device)  # clean label token ids
                p_labels = model_inputs["key_labels"].to(self._device)  # poison label token ids
                poison_idx = None if key_ids is None else np.arange(len(p_labels))
                token_logits = self._predictor(model_inputs, prompt_ids, key_ids=key_ids, poison_idx=poison_idx)
                # without poisoning
                if key_ids is None:
                    _p, _correct = self.predict_clean(token_logits, c_ids=self._c_ids, gold_ids=c_labels)
                    correct += _correct.sum().item()
                # with poisoning
                else:
                    _p, _correct = self.predict_poison(token_logits, c_ids=self._c_ids, p_ids=self._p_ids)
                    correct += _correct.sum().item()
                size += c_labels.size(0)
                tot_p.append(_p)
                tot_y.append(y_labels)
        tot_y = torch.cat(tot_y).detach().cpu()
        tot_p = torch.cat(tot_p).detach().cpu()
        results = self.stat_result(tot_y, tot_p)
        results["acc"] = correct / (size + 1e-32)
        return results

    def stat_result(self, y, p):
        results = {}
        p = p.detach().cpu().numpy() if isinstance(p, torch.Tensor) else p
        y = y.detach().cpu().numpy() if isinstance(y, torch.Tensor) else y
        self.y = y
        self.p = p

        assert p.shape == y.shape
        num_classes = int(y.max() + 1)
        average = "binary" if num_classes <= 2 else "micro"

        # Binary confusion counts, treating class 1 (adversarial/poisoned) as positive.
        adv_idx = np.where(y == 1)[0]
        ben_idx = np.where(y == 0)[0]
        TP = len(np.where(p[adv_idx] == 1)[0])
        FP = len(np.where(p[ben_idx] == 1)[0])
        FN = len(np.where(p[adv_idx] == 0)[0])
        TN = len(np.where(p[ben_idx] == 0)[0])
        results["FPR"] = FP / (FP + TN + 1e-32)
        results["TPR"] = TP / (TP + FN + 1e-32)
        results["ACC"] = accuracy_score(y, p)
        results["Recall"] = recall_score(y, p, average=average)
        results["Precision"] = precision_score(y, p, average=average)
        results["F1Score"] = f1_score(y, p, average=average)
        return results

    def __call__(self, predict_logits, gold_label_ids):
        # Loss (negative total log-probability) for the true label
        gold_logp = self.get_loss(predict_logits, gold_label_ids)

        # Loss for every candidate label
        bsz = predict_logits.size(0)
        all_label_logp = []
        for label_ids in self._c_ids:
            label_logp = self.get_loss(predict_logits, label_ids.repeat(bsz, 1))
            all_label_logp.append(label_logp)
        all_label_logp = torch.stack(all_label_logp, dim=-1)
        _, predictions = all_label_logp.max(dim=-1)
        predictions = torch.tensor([self._y[x] for x in predictions.tolist()])
        # Count the labels whose loss is less than or equal to the gold loss.
        ge_count = all_label_logp.le(gold_logp.unsqueeze(-1)).sum(-1)
        correct = ge_count.le(1)  # <= 1 rather than == 1, in case of numerical-precision ties
        return correct.float()

    def eval_step(self, token_logits, gold_ids=None):
        _logits = token_logits.detach().cpu().clone()
        if gold_ids is not None:
            # evaluate clean batch
            preds, correct = self.predict_clean(_logits, c_ids=self._c_ids, gold_ids=gold_ids)
        else:
            # evaluate poison batch
            preds, correct = self.predict_poison(_logits, c_ids=self._c_ids, p_ids=self._p_ids)
        return preds.detach().cpu(), correct.float()

    def predict_poison(self, predict_logits, c_ids, p_ids):
        """
        No gradients are required here.
        :param predict_logits: token logits from the model
        :param c_ids: clean label token ids
        :param p_ids: poison (key) label token ids
        :return: predicted label indices (index 0 = poison) and a per-sample correctness mask
        """
        _p_ids = p_ids.detach().cpu()
        _c_ids = c_ids.detach().cpu()
        _logits = predict_logits.detach().cpu().clone()
        max_y_logp = []
        # Score each label set by the maximum logit over its token ids; row 0 is the poison label.
        for y in torch.stack([_p_ids.view(-1), _c_ids.view(-1)]):
            max_y_logp.append(_logits[:, y].max(dim=1)[0])
        logits_y = torch.stack(max_y_logp).T
        # Every sample in a poisoned batch should be predicted as the poison label (index 0).
        poison_y = torch.zeros(len(_logits))
        correct = logits_y.argmax(dim=1).eq(poison_y)
        return logits_y.argmax(dim=1), correct

    def predict_clean(self, predict_logits, c_ids, gold_ids):
        """
        No gradients are required here.
        :param predict_logits: token logits from the model
        :param c_ids: clean label token ids
        :param gold_ids: gold token ids for each sample, len(predict_logits) == len(gold_ids)
        :return: predicted label indices and the number of correct predictions
        """
        _c_ids = c_ids.detach().cpu()
        _gold_ids = gold_ids.detach().cpu().clone()
        _logits = predict_logits.detach().cpu().clone()
        max_y_logp = []
        # Score each label by the maximum logit over its token ids.
        for x_c_ids in _c_ids:
            max_y_logp.append(_logits[:, x_c_ids].max(dim=1)[0])
        logits_y = torch.stack(max_y_logp).T

        # Use the sum of each label's token ids as a fingerprint
        y0 = torch.tensor([x.sum() for x in c_ids])
        # and recover each sample's label index by matching its gold-id sum
        # (this assumes the per-label sums are distinct)
        y = torch.tensor([torch.argwhere(x.sum() == y0) for x in _gold_ids])
        preds = logits_y.argmax(dim=1)
        correct = y.eq(preds).sum()
        return logits_y.argmax(dim=1), correct
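

# A minimal, self-contained sanity check for the masked-logsumexp trick used in
# `Evaluation.get_loss`. Illustrative only: `_demo_masked_label_loss`, the toy
# 5-token vocabulary, and the example token ids are made up for this sketch.
def _demo_masked_label_loss():
    logits = torch.randn(1, 5)                          # (batch, vocab)
    label_ids = torch.tensor([[2, 0]])                  # token 2 plus a padding slot (id 0)
    logp = F.log_softmax(logits, dim=-1)
    target = logp.gather(-1, label_ids)
    target = target - 1e32 * label_ids.to(logp).eq(0)   # push the pad slot towards -inf
    loss = -torch.logsumexp(target, dim=-1)
    # With the padding masked out, the loss reduces to the NLL of token 2 alone.
    assert torch.allclose(loss, -logp[:, 2])


# Likewise, a tiny check of the sum fingerprint used in `predict_clean`: a sample's
# label index is recovered by matching the sum of its gold token ids against the
# per-label sums (assuming those sums are distinct). Values below are hypothetical.
def _demo_label_from_token_sum():
    c_ids = torch.tensor([[10, 11], [20, 21]])          # token ids of labels 0 and 1
    gold = torch.tensor([20, 21])                       # one sample whose label is 1
    y0 = torch.tensor([x.sum() for x in c_ids])         # per-label fingerprints: [21, 41]
    label = torch.argwhere(gold.sum() == y0).item()     # -> 1
    assert label == 1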


class ExponentialMovingAverage:
    """
    Running metric tracker. Note that, despite the name, `update`/`get_metric`
    currently accumulate a plain arithmetic mean; `_weight` is accepted for
    compatibility but unused. See the usage example below.
    """
    def __init__(self, weight=0.3):
        self._weight = weight  # unused by the arithmetic-mean implementation
        self.reset()

    def update(self, x):
        self._x += x
        self._i += 1

    def reset(self):
        self._x = 0
        self._i = 0

    def get_metric(self):
        return self._x / (self._i + 1e-13)
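

# Example usage of the metric tracker above. A hypothetical snippet: the loss
# values are made up, and in real training `update`/`get_metric` would be fed
# per-batch values instead.
if __name__ == "__main__":
    meter = ExponentialMovingAverage()
    for batch_loss in (0.9, 0.7, 0.5):
        meter.update(batch_loss)
    print(meter.get_metric())  # ~0.7: the arithmetic mean of the three updates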