Spaces:
Runtime error
Runtime error
Antoine Chaffin
commited on
Commit
•
e5562c9
1
Parent(s):
87801f9
Adding system prompt in decoding
Browse files- watermark.py +300 -285
watermark.py
CHANGED
@@ -1,285 +1,300 @@
|
|
1 |
-
import transformers
|
2 |
-
from transformers import AutoTokenizer
|
3 |
-
|
4 |
-
from transformers import pipeline, set_seed, LogitsProcessor
|
5 |
-
from transformers.generation.logits_process import TopPLogitsWarper, TopKLogitsWarper
|
6 |
-
import torch
|
7 |
-
from scipy.special import gamma, gammainc, gammaincc, betainc
|
8 |
-
from scipy.optimize import fminbound
|
9 |
-
import numpy as np
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
|
14 |
-
|
15 |
-
def hash_tokens(input_ids: torch.LongTensor, key: int):
|
16 |
-
seed = key
|
17 |
-
salt = 35317
|
18 |
-
for i in input_ids:
|
19 |
-
seed = (seed * salt + i.item()) % (2 ** 64 - 1)
|
20 |
-
return seed
|
21 |
-
|
22 |
-
class WatermarkingLogitsProcessor(LogitsProcessor):
|
23 |
-
def __init__(self, n, key, messages, window_size, *args, **kwargs):
|
24 |
-
super().__init__(*args, **kwargs)
|
25 |
-
self.batch_size = len(messages)
|
26 |
-
self.generators = [ torch.Generator(device=device) for _ in range(self.batch_size) ]
|
27 |
-
|
28 |
-
self.n = n
|
29 |
-
self.key = key
|
30 |
-
self.window_size = window_size
|
31 |
-
if not self.window_size:
|
32 |
-
for b in range(self.batch_size):
|
33 |
-
self.generators[b].manual_seed(self.key)
|
34 |
-
|
35 |
-
self.messages = messages
|
36 |
-
|
37 |
-
class WatermarkingAaronsonLogitsProcessor( WatermarkingLogitsProcessor):
|
38 |
-
def __init__(self, *args, **kwargs):
|
39 |
-
super().__init__(*args, **kwargs)
|
40 |
-
|
41 |
-
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
42 |
-
# get random uniform variables
|
43 |
-
B, V = scores.shape
|
44 |
-
|
45 |
-
r = torch.zeros_like(scores)
|
46 |
-
for b in range(B):
|
47 |
-
if self.window_size:
|
48 |
-
window = input_ids[b, -self.window_size:]
|
49 |
-
seed = hash_tokens(window, self.key)
|
50 |
-
self.generators[b].manual_seed(seed)
|
51 |
-
r[b] = torch.rand(self.n, generator=self.generators[b], device=self.generators[b].device).log().roll(-self.messages[b])
|
52 |
-
# generate n but keep only V, as we want to keep the pseudo-random sequences in sync with the decoder
|
53 |
-
r = r[:,:V]
|
54 |
-
|
55 |
-
# modify law as r^(1/p)
|
56 |
-
# Since we want to return logits (logits processor takes and outputs logits),
|
57 |
-
# we return log(q), hence torch.log(r) * torch.log(torch.exp(1/p)) = torch.log(r) / p
|
58 |
-
return r / scores.exp()
|
59 |
-
|
60 |
-
class WatermarkingKirchenbauerLogitsProcessor(WatermarkingLogitsProcessor):
|
61 |
-
def __init__(self, *args,
|
62 |
-
gamma = 0.25,
|
63 |
-
delta = 5.0,
|
64 |
-
**kwargs):
|
65 |
-
super().__init__(*args, **kwargs)
|
66 |
-
self.gamma = gamma
|
67 |
-
self.delta = delta
|
68 |
-
|
69 |
-
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
70 |
-
B, V = scores.shape
|
71 |
-
|
72 |
-
for b in range(B):
|
73 |
-
if self.window_size:
|
74 |
-
window = input_ids[b, -self.window_size:]
|
75 |
-
seed = hash_tokens(window, self.key)
|
76 |
-
self.generators[b].manual_seed(seed)
|
77 |
-
vocab_permutation = torch.randperm(self.n, generator=self.generators[b], device=self.generators[b].device)
|
78 |
-
greenlist = vocab_permutation[:int(self.gamma * self.n)] # gamma * n
|
79 |
-
bias = torch.zeros(self.n).to(scores.device)
|
80 |
-
bias[greenlist] = self.delta
|
81 |
-
bias = bias.roll(-self.messages[b])[:V]
|
82 |
-
scores[b] += bias # add bias to greenlist words
|
83 |
-
|
84 |
-
return scores
|
85 |
-
|
86 |
-
class Watermarker(object):
|
87 |
-
def __init__(self, tokenizer=None, model=None, window_size = 0, payload_bits = 0, logits_processor = None, *args, **kwargs):
|
88 |
-
self.tokenizer = tokenizer
|
89 |
-
self.model = model
|
90 |
-
self.model.eval()
|
91 |
-
self.window_size = window_size
|
92 |
-
|
93 |
-
# preprocessing wrappers
|
94 |
-
self.logits_processor = logits_processor or []
|
95 |
-
|
96 |
-
self.payload_bits = payload_bits
|
97 |
-
self.V = max(2**payload_bits, self.model.config.vocab_size)
|
98 |
-
self.generator = torch.Generator(device=device)
|
99 |
-
|
100 |
-
|
101 |
-
def embed(self, key=42, messages=[1234], prompt="", max_length=30, method='aaronson'):
|
102 |
-
|
103 |
-
B = len(messages) # batch size
|
104 |
-
length = max_length
|
105 |
-
|
106 |
-
# compute capacity
|
107 |
-
if self.payload_bits:
|
108 |
-
assert min([message >= 0 and message < 2**self.payload_bits for message in messages])
|
109 |
-
|
110 |
-
# tokenize prompt
|
111 |
-
inputs = self.tokenizer([ prompt ] * B, return_tensors="pt")
|
112 |
-
|
113 |
-
if method == 'aaronson':
|
114 |
-
# generate with greedy search
|
115 |
-
generated_ids = self.model.generate(inputs.input_ids.to(device), max_length=max_length, do_sample=False,
|
116 |
-
logits_processor = self.logits_processor + [
|
117 |
-
WatermarkingAaronsonLogitsProcessor(n=self.V,
|
118 |
-
key=key,
|
119 |
-
messages=messages,
|
120 |
-
window_size = self.window_size)])
|
121 |
-
elif method == 'kirchenbauer':
|
122 |
-
# use sampling
|
123 |
-
generated_ids = self.model.generate(inputs.input_ids.to(device), max_length=max_length, do_sample=True,
|
124 |
-
logits_processor = self.logits_processor + [
|
125 |
-
WatermarkingKirchenbauerLogitsProcessor(n=self.V,
|
126 |
-
key=key,
|
127 |
-
messages=messages,
|
128 |
-
window_size = self.window_size)])
|
129 |
-
elif method == 'greedy':
|
130 |
-
# generate with greedy search
|
131 |
-
generated_ids = self.model.generate(inputs.input_ids.to(device), max_length=max_length, do_sample=False,
|
132 |
-
logits_processor = self.logits_processor)
|
133 |
-
elif method == 'sampling':
|
134 |
-
# generate with greedy search
|
135 |
-
generated_ids = self.model.generate(inputs.input_ids.to(device), max_length=max_length, do_sample=True,
|
136 |
-
logits_processor = self.logits_processor)
|
137 |
-
else:
|
138 |
-
raise Exception('Unknown method %s' % method)
|
139 |
-
decoded_texts = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
140 |
-
|
141 |
-
return decoded_texts
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
if method in {'aaronson', 'aaronson_simplified', '
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
-
|
269 |
-
|
270 |
-
|
271 |
-
|
272 |
-
|
273 |
-
|
274 |
-
|
275 |
-
|
276 |
-
|
277 |
-
|
278 |
-
|
279 |
-
|
280 |
-
|
281 |
-
|
282 |
-
|
283 |
-
|
284 |
-
|
285 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import transformers
|
2 |
+
from transformers import AutoTokenizer
|
3 |
+
|
4 |
+
from transformers import pipeline, set_seed, LogitsProcessor
|
5 |
+
from transformers.generation.logits_process import TopPLogitsWarper, TopKLogitsWarper
|
6 |
+
import torch
|
7 |
+
from scipy.special import gamma, gammainc, gammaincc, betainc
|
8 |
+
from scipy.optimize import fminbound
|
9 |
+
import numpy as np
|
10 |
+
|
11 |
+
|
12 |
+
|
13 |
+
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
|
14 |
+
|
15 |
+
def hash_tokens(input_ids: torch.LongTensor, key: int):
|
16 |
+
seed = key
|
17 |
+
salt = 35317
|
18 |
+
for i in input_ids:
|
19 |
+
seed = (seed * salt + i.item()) % (2 ** 64 - 1)
|
20 |
+
return seed
|
21 |
+
|
22 |
+
class WatermarkingLogitsProcessor(LogitsProcessor):
|
23 |
+
def __init__(self, n, key, messages, window_size, *args, **kwargs):
|
24 |
+
super().__init__(*args, **kwargs)
|
25 |
+
self.batch_size = len(messages)
|
26 |
+
self.generators = [ torch.Generator(device=device) for _ in range(self.batch_size) ]
|
27 |
+
|
28 |
+
self.n = n
|
29 |
+
self.key = key
|
30 |
+
self.window_size = window_size
|
31 |
+
if not self.window_size:
|
32 |
+
for b in range(self.batch_size):
|
33 |
+
self.generators[b].manual_seed(self.key)
|
34 |
+
|
35 |
+
self.messages = messages
|
36 |
+
|
37 |
+
class WatermarkingAaronsonLogitsProcessor( WatermarkingLogitsProcessor):
|
38 |
+
def __init__(self, *args, **kwargs):
|
39 |
+
super().__init__(*args, **kwargs)
|
40 |
+
|
41 |
+
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
42 |
+
# get random uniform variables
|
43 |
+
B, V = scores.shape
|
44 |
+
|
45 |
+
r = torch.zeros_like(scores)
|
46 |
+
for b in range(B):
|
47 |
+
if self.window_size:
|
48 |
+
window = input_ids[b, -self.window_size:]
|
49 |
+
seed = hash_tokens(window, self.key)
|
50 |
+
self.generators[b].manual_seed(seed)
|
51 |
+
r[b] = torch.rand(self.n, generator=self.generators[b], device=self.generators[b].device).log().roll(-self.messages[b])
|
52 |
+
# generate n but keep only V, as we want to keep the pseudo-random sequences in sync with the decoder
|
53 |
+
r = r[:,:V]
|
54 |
+
|
55 |
+
# modify law as r^(1/p)
|
56 |
+
# Since we want to return logits (logits processor takes and outputs logits),
|
57 |
+
# we return log(q), hence torch.log(r) * torch.log(torch.exp(1/p)) = torch.log(r) / p
|
58 |
+
return r / scores.exp()
|
59 |
+
|
60 |
+
class WatermarkingKirchenbauerLogitsProcessor(WatermarkingLogitsProcessor):
|
61 |
+
def __init__(self, *args,
|
62 |
+
gamma = 0.25,
|
63 |
+
delta = 5.0,
|
64 |
+
**kwargs):
|
65 |
+
super().__init__(*args, **kwargs)
|
66 |
+
self.gamma = gamma
|
67 |
+
self.delta = delta
|
68 |
+
|
69 |
+
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
70 |
+
B, V = scores.shape
|
71 |
+
|
72 |
+
for b in range(B):
|
73 |
+
if self.window_size:
|
74 |
+
window = input_ids[b, -self.window_size:]
|
75 |
+
seed = hash_tokens(window, self.key)
|
76 |
+
self.generators[b].manual_seed(seed)
|
77 |
+
vocab_permutation = torch.randperm(self.n, generator=self.generators[b], device=self.generators[b].device)
|
78 |
+
greenlist = vocab_permutation[:int(self.gamma * self.n)] # gamma * n
|
79 |
+
bias = torch.zeros(self.n).to(scores.device)
|
80 |
+
bias[greenlist] = self.delta
|
81 |
+
bias = bias.roll(-self.messages[b])[:V]
|
82 |
+
scores[b] += bias # add bias to greenlist words
|
83 |
+
|
84 |
+
return scores
|
85 |
+
|
86 |
+
class Watermarker(object):
|
87 |
+
def __init__(self, tokenizer=None, model=None, window_size = 0, payload_bits = 0, logits_processor = None, *args, **kwargs):
|
88 |
+
self.tokenizer = tokenizer
|
89 |
+
self.model = model
|
90 |
+
self.model.eval()
|
91 |
+
self.window_size = window_size
|
92 |
+
|
93 |
+
# preprocessing wrappers
|
94 |
+
self.logits_processor = logits_processor or []
|
95 |
+
|
96 |
+
self.payload_bits = payload_bits
|
97 |
+
self.V = max(2**payload_bits, self.model.config.vocab_size)
|
98 |
+
self.generator = torch.Generator(device=device)
|
99 |
+
|
100 |
+
|
101 |
+
def embed(self, key=42, messages=[1234], prompt="", max_length=30, method='aaronson'):
|
102 |
+
|
103 |
+
B = len(messages) # batch size
|
104 |
+
length = max_length
|
105 |
+
|
106 |
+
# compute capacity
|
107 |
+
if self.payload_bits:
|
108 |
+
assert min([message >= 0 and message < 2**self.payload_bits for message in messages])
|
109 |
+
|
110 |
+
# tokenize prompt
|
111 |
+
inputs = self.tokenizer([ prompt ] * B, return_tensors="pt")
|
112 |
+
|
113 |
+
if method == 'aaronson':
|
114 |
+
# generate with greedy search
|
115 |
+
generated_ids = self.model.generate(inputs.input_ids.to(device), max_length=max_length, do_sample=False,
|
116 |
+
logits_processor = self.logits_processor + [
|
117 |
+
WatermarkingAaronsonLogitsProcessor(n=self.V,
|
118 |
+
key=key,
|
119 |
+
messages=messages,
|
120 |
+
window_size = self.window_size)])
|
121 |
+
elif method == 'kirchenbauer':
|
122 |
+
# use sampling
|
123 |
+
generated_ids = self.model.generate(inputs.input_ids.to(device), max_length=max_length, do_sample=True,
|
124 |
+
logits_processor = self.logits_processor + [
|
125 |
+
WatermarkingKirchenbauerLogitsProcessor(n=self.V,
|
126 |
+
key=key,
|
127 |
+
messages=messages,
|
128 |
+
window_size = self.window_size)])
|
129 |
+
elif method == 'greedy':
|
130 |
+
# generate with greedy search
|
131 |
+
generated_ids = self.model.generate(inputs.input_ids.to(device), max_length=max_length, do_sample=False,
|
132 |
+
logits_processor = self.logits_processor)
|
133 |
+
elif method == 'sampling':
|
134 |
+
# generate with greedy search
|
135 |
+
generated_ids = self.model.generate(inputs.input_ids.to(device), max_length=max_length, do_sample=True,
|
136 |
+
logits_processor = self.logits_processor)
|
137 |
+
else:
|
138 |
+
raise Exception('Unknown method %s' % method)
|
139 |
+
decoded_texts = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
140 |
+
|
141 |
+
return decoded_texts
|
142 |
+
|
143 |
+
DEFAULT_SYSTEM_PROMPT = """\
|
144 |
+
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
|
145 |
+
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\
|
146 |
+
"""
|
147 |
+
|
148 |
+
def get_prompt(message: str) -> str:
|
149 |
+
texts = [f'<s>[INST] <<SYS>>\n{DEFAULT_SYSTEM_PROMPT}\n<</SYS>>\n\n']
|
150 |
+
# The first user input is _not_ stripped
|
151 |
+
texts.append(f'{message} [/INST]')
|
152 |
+
return ''.join(texts)
|
153 |
+
|
154 |
+
|
155 |
+
def detect(self, attacked_texts, key=42, method='aaronson', gamma=0.5, prompts=None):
|
156 |
+
if(prompts==None):
|
157 |
+
prompts = [""] * len(attacked_texts)
|
158 |
+
else:
|
159 |
+
for i in range(len(prompts)):
|
160 |
+
prompts[i] = get_prompt(prompts[i])
|
161 |
+
|
162 |
+
generator = self.generator
|
163 |
+
|
164 |
+
#print("attacked_texts = ", attacked_texts)
|
165 |
+
|
166 |
+
cdfs = []
|
167 |
+
ms = []
|
168 |
+
|
169 |
+
MAX = 2**self.payload_bits
|
170 |
+
|
171 |
+
# tokenize input
|
172 |
+
inputs = self.tokenizer(attacked_texts, return_tensors="pt", padding=True, return_attention_mask=True)
|
173 |
+
|
174 |
+
input_ids = inputs["input_ids"].to(self.model.device)
|
175 |
+
attention_masks = inputs["attention_mask"].to(self.model.device)
|
176 |
+
|
177 |
+
B,T = input_ids.shape
|
178 |
+
|
179 |
+
if method == 'aaronson_neyman_pearson':
|
180 |
+
# compute logits
|
181 |
+
outputs = self.model.forward(input_ids, return_dict=True)
|
182 |
+
logits = outputs['logits']
|
183 |
+
# TODO
|
184 |
+
# reapply logits processors to get same distribution
|
185 |
+
#for i in range(T):
|
186 |
+
# for processor in self.logits_processor:
|
187 |
+
# logits[:,i] = processor(input_ids[:, :i], logits[:, i])
|
188 |
+
|
189 |
+
probs = logits.softmax(dim=-1)
|
190 |
+
ps = torch.gather(probs, 2, input_ids[:,1:,None]).squeeze_(-1)
|
191 |
+
|
192 |
+
|
193 |
+
seq_len = input_ids.shape[1]
|
194 |
+
length = seq_len
|
195 |
+
|
196 |
+
V = self.V
|
197 |
+
|
198 |
+
Z = torch.zeros(size=(B, V), dtype=torch.float32, device=device)
|
199 |
+
|
200 |
+
|
201 |
+
# keep a history of contexts we have already seen,
|
202 |
+
# to exclude them from score aggregation and allow
|
203 |
+
# correct p-value computation under H0
|
204 |
+
history = [set() for _ in range(B)]
|
205 |
+
|
206 |
+
attention_masks_prompts = self.tokenizer(prompts, return_tensors="pt", padding=True, return_attention_mask=True)["attention_mask"]
|
207 |
+
prompts_length = torch.sum(attention_masks_prompts, dim=1)
|
208 |
+
for b in range(B):
|
209 |
+
attention_masks[b, :prompts_length[b]] = 0
|
210 |
+
if not self.window_size:
|
211 |
+
generator.manual_seed(key)
|
212 |
+
# We can go from seq_len - prompt_len, need to change +1 to + prompt_len
|
213 |
+
for i in range(seq_len-1):
|
214 |
+
|
215 |
+
if self.window_size:
|
216 |
+
window = input_ids[b, max(0, i-self.window_size+1):i+1]
|
217 |
+
#print("window = ", window)
|
218 |
+
seed = hash_tokens(window, key)
|
219 |
+
if seed not in history[b]:
|
220 |
+
generator.manual_seed(seed)
|
221 |
+
history[b].add(seed)
|
222 |
+
else:
|
223 |
+
# ignore the token
|
224 |
+
attention_masks[b, i+1] = 0
|
225 |
+
|
226 |
+
if not attention_masks[b,i+1]:
|
227 |
+
continue
|
228 |
+
|
229 |
+
token = int(input_ids[b,i+1])
|
230 |
+
|
231 |
+
if method in {'aaronson', 'aaronson_simplified', 'aaronson_neyman_pearson'}:
|
232 |
+
R = torch.rand(V, generator = generator, device = generator.device)
|
233 |
+
|
234 |
+
if method == 'aaronson':
|
235 |
+
r = -(1-R).log()
|
236 |
+
elif method in {'aaronson_simplified', 'aaronson_neyman_pearson'}:
|
237 |
+
r = -R.log()
|
238 |
+
elif method == 'kirchenbauer':
|
239 |
+
r = torch.zeros(V, device=device)
|
240 |
+
vocab_permutation = torch.randperm(V, generator = generator, device=generator.device)
|
241 |
+
greenlist = vocab_permutation[:int(gamma * V)]
|
242 |
+
r[greenlist] = 1
|
243 |
+
else:
|
244 |
+
raise Exception('Unknown method %s' % method)
|
245 |
+
|
246 |
+
if method in {'aaronson', 'aaronson_simplified', 'kirchenbauer'}:
|
247 |
+
# independent of probs
|
248 |
+
Z[b] += r.roll(-token)
|
249 |
+
elif method == 'aaronson_neyman_pearson':
|
250 |
+
# Neyman-Pearson
|
251 |
+
Z[b] += r.roll(-token) * (1/ps[b,i] - 1)
|
252 |
+
|
253 |
+
for b in range(B):
|
254 |
+
if method in {'aaronson', 'kirchenbauer'}:
|
255 |
+
m = torch.argmax(Z[b,:MAX])
|
256 |
+
elif method in {'aaronson_simplified', 'aaronson_neyman_pearson'}:
|
257 |
+
m = torch.argmin(Z[b,:MAX])
|
258 |
+
|
259 |
+
i = int(m)
|
260 |
+
S = Z[b, i].item()
|
261 |
+
m = i
|
262 |
+
|
263 |
+
# actual sequence length
|
264 |
+
k = torch.sum(attention_masks[b]).item() - 1
|
265 |
+
|
266 |
+
if method == 'aaronson':
|
267 |
+
cdf = gammaincc(k, S)
|
268 |
+
elif method == 'aaronson_simplified':
|
269 |
+
cdf = gammainc(k, S)
|
270 |
+
elif method == 'aaronson_neyman_pearson':
|
271 |
+
# Chernoff bound
|
272 |
+
ratio = ps[b,:k] / (1 - ps[b,:k])
|
273 |
+
E = (1/ratio).sum()
|
274 |
+
|
275 |
+
if S > E:
|
276 |
+
cdf = 1.0
|
277 |
+
else:
|
278 |
+
# to compute p-value we must solve for c*:
|
279 |
+
# (1/(c* + ps/(1-ps))).sum() = S
|
280 |
+
func = lambda c : (((1 / (c + ratio)).sum() - S)**2).item()
|
281 |
+
c1 = (k / S - torch.min(ratio)).item()
|
282 |
+
print("max = ", c1)
|
283 |
+
c = fminbound(func, 0, c1)
|
284 |
+
print("solved c = ", c)
|
285 |
+
print("solved s = ", ((1/(c + ratio)).sum()).item())
|
286 |
+
# upper bound
|
287 |
+
cdf = torch.exp(torch.sum(-torch.log(1 + c / ratio)) + c * S)
|
288 |
+
elif method == 'kirchenbauer':
|
289 |
+
cdf = betainc(S, k - S + 1, gamma)
|
290 |
+
|
291 |
+
if cdf > min(1 / MAX, 1e-5):
|
292 |
+
cdf = 1 - (1 - cdf)**MAX # true value
|
293 |
+
else:
|
294 |
+
cdf = cdf * MAX # numerically stable upper bound
|
295 |
+
cdfs.append(float(cdf))
|
296 |
+
ms.append(m)
|
297 |
+
|
298 |
+
return cdfs, ms
|
299 |
+
|
300 |
+
|