Antoine Chaffin committed
Commit e5562c9
Parent: 87801f9

Adding system prompt in decoding

Files changed (1): watermark.py (+300 −285)
watermark.py CHANGED
@@ -1,285 +1,300 @@
- import transformers
- from transformers import AutoTokenizer
-
- from transformers import pipeline, set_seed, LogitsProcessor
- from transformers.generation.logits_process import TopPLogitsWarper, TopKLogitsWarper
- import torch
- from scipy.special import gamma, gammainc, gammaincc, betainc
- from scipy.optimize import fminbound
- import numpy as np
-
-
-
- device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
-
- def hash_tokens(input_ids: torch.LongTensor, key: int):
-     seed = key
-     salt = 35317
-     for i in input_ids:
-         seed = (seed * salt + i.item()) % (2 ** 64 - 1)
-     return seed
-
- class WatermarkingLogitsProcessor(LogitsProcessor):
-     def __init__(self, n, key, messages, window_size, *args, **kwargs):
-         super().__init__(*args, **kwargs)
-         self.batch_size = len(messages)
-         self.generators = [torch.Generator(device=device) for _ in range(self.batch_size)]
-
-         self.n = n
-         self.key = key
-         self.window_size = window_size
-         if not self.window_size:
-             for b in range(self.batch_size):
-                 self.generators[b].manual_seed(self.key)
-
-         self.messages = messages
-
- class WatermarkingAaronsonLogitsProcessor(WatermarkingLogitsProcessor):
-     def __init__(self, *args, **kwargs):
-         super().__init__(*args, **kwargs)
-
-     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
-         # get random uniform variables
-         B, V = scores.shape
-
-         r = torch.zeros_like(scores)
-         for b in range(B):
-             if self.window_size:
-                 window = input_ids[b, -self.window_size:]
-                 seed = hash_tokens(window, self.key)
-                 self.generators[b].manual_seed(seed)
-             r[b] = torch.rand(self.n, generator=self.generators[b], device=self.generators[b].device).log().roll(-self.messages[b])
-         # generate n but keep only V, as we want to keep the pseudo-random sequences in sync with the decoder
-         r = r[:, :V]
-
-         # modify law as r^(1/p)
-         # Since we want to return logits (logits processor takes and outputs logits),
-         # we return log(q), hence torch.log(r) * torch.log(torch.exp(1/p)) = torch.log(r) / p
-         return r / scores.exp()
-
- class WatermarkingKirchenbauerLogitsProcessor(WatermarkingLogitsProcessor):
-     def __init__(self, *args,
-                  gamma=0.25,
-                  delta=5.0,
-                  **kwargs):
-         super().__init__(*args, **kwargs)
-         self.gamma = gamma
-         self.delta = delta
-
-     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
-         B, V = scores.shape
-
-         for b in range(B):
-             if self.window_size:
-                 window = input_ids[b, -self.window_size:]
-                 seed = hash_tokens(window, self.key)
-                 self.generators[b].manual_seed(seed)
-             vocab_permutation = torch.randperm(self.n, generator=self.generators[b], device=self.generators[b].device)
-             greenlist = vocab_permutation[:int(self.gamma * self.n)]  # gamma * n
-             bias = torch.zeros(self.n).to(scores.device)
-             bias[greenlist] = self.delta
-             bias = bias.roll(-self.messages[b])[:V]
-             scores[b] += bias  # add bias to greenlist words
-
-         return scores
-
- class Watermarker(object):
-     def __init__(self, tokenizer=None, model=None, window_size=0, payload_bits=0, logits_processor=None, *args, **kwargs):
-         self.tokenizer = tokenizer
-         self.model = model
-         self.model.eval()
-         self.window_size = window_size
-
-         # preprocessing wrappers
-         self.logits_processor = logits_processor or []
-
-         self.payload_bits = payload_bits
-         self.V = max(2**payload_bits, self.model.config.vocab_size)
-         self.generator = torch.Generator(device=device)
-
-
-     def embed(self, key=42, messages=[1234], prompt="", max_length=30, method='aaronson'):
-
-         B = len(messages)  # batch size
-         length = max_length
-
-         # compute capacity
-         if self.payload_bits:
-             assert min([message >= 0 and message < 2**self.payload_bits for message in messages])
-
-         # tokenize prompt
-         inputs = self.tokenizer([prompt] * B, return_tensors="pt")
-
-         if method == 'aaronson':
-             # generate with greedy search
-             generated_ids = self.model.generate(inputs.input_ids.to(device), max_length=max_length, do_sample=False,
-                                                 logits_processor=self.logits_processor + [
-                                                     WatermarkingAaronsonLogitsProcessor(n=self.V,
-                                                                                         key=key,
-                                                                                         messages=messages,
-                                                                                         window_size=self.window_size)])
-         elif method == 'kirchenbauer':
-             # use sampling
-             generated_ids = self.model.generate(inputs.input_ids.to(device), max_length=max_length, do_sample=True,
-                                                 logits_processor=self.logits_processor + [
-                                                     WatermarkingKirchenbauerLogitsProcessor(n=self.V,
-                                                                                             key=key,
-                                                                                             messages=messages,
-                                                                                             window_size=self.window_size)])
-         elif method == 'greedy':
-             # generate with greedy search
-             generated_ids = self.model.generate(inputs.input_ids.to(device), max_length=max_length, do_sample=False,
-                                                 logits_processor=self.logits_processor)
-         elif method == 'sampling':
-             # generate with sampling
-             generated_ids = self.model.generate(inputs.input_ids.to(device), max_length=max_length, do_sample=True,
-                                                 logits_processor=self.logits_processor)
-         else:
-             raise Exception('Unknown method %s' % method)
-         decoded_texts = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
-
-         return decoded_texts
-
-     def detect(self, attacked_texts, key=42, method='aaronson', gamma=0.5, prompts=None):
-         if(prompts==None):
-             prompts = [""] * len(attacked_texts)
-
-         generator = self.generator
-
-         #print("attacked_texts = ", attacked_texts)
-
-         cdfs = []
-         ms = []
-
-         MAX = 2**self.payload_bits
-
-         # tokenize input
-         inputs = self.tokenizer(attacked_texts, return_tensors="pt", padding=True, return_attention_mask=True)
-
-         input_ids = inputs["input_ids"].to(self.model.device)
-         attention_masks = inputs["attention_mask"].to(self.model.device)
-
-         B, T = input_ids.shape
-
-         if method == 'aaronson_neyman_pearson':
-             # compute logits
-             outputs = self.model.forward(input_ids, return_dict=True)
-             logits = outputs['logits']
-             # TODO
-             # reapply logits processors to get same distribution
-             #for i in range(T):
-             #    for processor in self.logits_processor:
-             #        logits[:,i] = processor(input_ids[:, :i], logits[:, i])
-
-             probs = logits.softmax(dim=-1)
-             ps = torch.gather(probs, 2, input_ids[:, 1:, None]).squeeze_(-1)
-
-
-         seq_len = input_ids.shape[1]
-         length = seq_len
-
-         V = self.V
-
-         Z = torch.zeros(size=(B, V), dtype=torch.float32, device=device)
-
-
-         # keep a history of contexts we have already seen,
-         # to exclude them from score aggregation and allow
-         # correct p-value computation under H0
-         history = [set() for _ in range(B)]
-
-         attention_masks_prompts = self.tokenizer(prompts, return_tensors="pt", padding=True, return_attention_mask=True)["attention_mask"]
-         prompts_length = torch.sum(attention_masks_prompts, dim=1)
-         for b in range(B):
-             attention_masks[b, :prompts_length[b]] = 0
-             if not self.window_size:
-                 generator.manual_seed(key)
-             # We can go from seq_len - prompt_len, need to change +1 to + prompt_len
-             for i in range(seq_len - 1):
-
-                 if self.window_size:
-                     window = input_ids[b, max(0, i - self.window_size + 1):i + 1]
-                     #print("window = ", window)
-                     seed = hash_tokens(window, key)
-                     if seed not in history[b]:
-                         generator.manual_seed(seed)
-                         history[b].add(seed)
-                     else:
-                         # ignore the token
-                         attention_masks[b, i + 1] = 0
-
-                 if not attention_masks[b, i + 1]:
-                     continue
-
-                 token = int(input_ids[b, i + 1])
-
-                 if method in {'aaronson', 'aaronson_simplified', 'aaronson_neyman_pearson'}:
-                     R = torch.rand(V, generator=generator, device=generator.device)
-
-                     if method == 'aaronson':
-                         r = -(1 - R).log()
-                     elif method in {'aaronson_simplified', 'aaronson_neyman_pearson'}:
-                         r = -R.log()
-                 elif method == 'kirchenbauer':
-                     r = torch.zeros(V, device=device)
-                     vocab_permutation = torch.randperm(V, generator=generator, device=generator.device)
-                     greenlist = vocab_permutation[:int(gamma * V)]
-                     r[greenlist] = 1
-                 else:
-                     raise Exception('Unknown method %s' % method)
-
-                 if method in {'aaronson', 'aaronson_simplified', 'kirchenbauer'}:
-                     # independent of probs
-                     Z[b] += r.roll(-token)
-                 elif method == 'aaronson_neyman_pearson':
-                     # Neyman-Pearson
-                     Z[b] += r.roll(-token) * (1 / ps[b, i] - 1)
-
-         for b in range(B):
-             if method in {'aaronson', 'kirchenbauer'}:
-                 m = torch.argmax(Z[b, :MAX])
-             elif method in {'aaronson_simplified', 'aaronson_neyman_pearson'}:
-                 m = torch.argmin(Z[b, :MAX])
-
-             i = int(m)
-             S = Z[b, i].item()
-             m = i
-
-             # actual sequence length
-             k = torch.sum(attention_masks[b]).item() - 1
-
-             if method == 'aaronson':
-                 cdf = gammaincc(k, S)
-             elif method == 'aaronson_simplified':
-                 cdf = gammainc(k, S)
-             elif method == 'aaronson_neyman_pearson':
-                 # Chernoff bound
-                 ratio = ps[b, :k] / (1 - ps[b, :k])
-                 E = (1 / ratio).sum()
-
-                 if S > E:
-                     cdf = 1.0
-                 else:
-                     # to compute p-value we must solve for c*:
-                     # (1/(c* + ps/(1-ps))).sum() = S
-                     func = lambda c: (((1 / (c + ratio)).sum() - S)**2).item()
-                     c1 = (k / S - torch.min(ratio)).item()
-                     print("max = ", c1)
-                     c = fminbound(func, 0, c1)
-                     print("solved c = ", c)
-                     print("solved s = ", ((1 / (c + ratio)).sum()).item())
-                     # upper bound
-                     cdf = torch.exp(torch.sum(-torch.log(1 + c / ratio)) + c * S)
-             elif method == 'kirchenbauer':
-                 cdf = betainc(S, k - S + 1, gamma)
-
-             if cdf > min(1 / MAX, 1e-5):
-                 cdf = 1 - (1 - cdf)**MAX  # true value
-             else:
-                 cdf = cdf * MAX  # numerically stable upper bound
-             cdfs.append(float(cdf))
-             ms.append(m)
-
-         return cdfs, ms
-
-
+ import transformers
+ from transformers import AutoTokenizer
+
+ from transformers import pipeline, set_seed, LogitsProcessor
+ from transformers.generation.logits_process import TopPLogitsWarper, TopKLogitsWarper
+ import torch
+ from scipy.special import gamma, gammainc, gammaincc, betainc
+ from scipy.optimize import fminbound
+ import numpy as np
+
+
+ device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
+
+ def hash_tokens(input_ids: torch.LongTensor, key: int):
+     # rolling hash of a context window: the same tokens and key reproduce
+     # the same seed at embedding and detection time
+     seed = key
+     salt = 35317
+     for i in input_ids:
+         seed = (seed * salt + i.item()) % (2 ** 64 - 1)
+     return seed
+
+ class WatermarkingLogitsProcessor(LogitsProcessor):
+     # base class: one RNG per batch element, plus the key, payload messages
+     # and the size of the context window used for seeding
+     def __init__(self, n, key, messages, window_size, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self.batch_size = len(messages)
+         self.generators = [torch.Generator(device=device) for _ in range(self.batch_size)]
+
+         self.n = n
+         self.key = key
+         self.window_size = window_size
+         if not self.window_size:
+             # no context window: seed once with the key, one global stream
+             for b in range(self.batch_size):
+                 self.generators[b].manual_seed(self.key)
+
+         self.messages = messages
+
+ class WatermarkingAaronsonLogitsProcessor(WatermarkingLogitsProcessor):
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+
+     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+         # get random uniform variables
+         B, V = scores.shape
+
+         r = torch.zeros_like(scores)
+         for b in range(B):
+             if self.window_size:
+                 window = input_ids[b, -self.window_size:]
+                 seed = hash_tokens(window, self.key)
+                 self.generators[b].manual_seed(seed)
+             # generate n values but keep only the first V, so the pseudo-random
+             # sequence stays in sync with the detector
+             r[b] = torch.rand(self.n, generator=self.generators[b],
+                               device=self.generators[b].device).log().roll(-self.messages[b])[:V]
+
+         # modify the law as r^(1/p): a logits processor takes and returns logits,
+         # so we return log(q) = log(r) / p; under greedy decoding the missing
+         # softmax normalization of scores.exp() does not change the argmax
+         return r / scores.exp()
+
+ class WatermarkingKirchenbauerLogitsProcessor(WatermarkingLogitsProcessor):
+     def __init__(self, *args,
+                  gamma=0.25,
+                  delta=5.0,
+                  **kwargs):
+         super().__init__(*args, **kwargs)
+         self.gamma = gamma
+         self.delta = delta
+
+     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+         B, V = scores.shape
+
+         for b in range(B):
+             if self.window_size:
+                 window = input_ids[b, -self.window_size:]
+                 seed = hash_tokens(window, self.key)
+                 self.generators[b].manual_seed(seed)
+             vocab_permutation = torch.randperm(self.n, generator=self.generators[b], device=self.generators[b].device)
+             greenlist = vocab_permutation[:int(self.gamma * self.n)]  # gamma * n greenlist tokens
+             bias = torch.zeros(self.n).to(scores.device)
+             bias[greenlist] = self.delta
+             bias = bias.roll(-self.messages[b])[:V]
+             scores[b] += bias  # add bias to greenlist words
+
+         return scores
+
+ class Watermarker(object):
+     def __init__(self, tokenizer=None, model=None, window_size=0, payload_bits=0, logits_processor=None, *args, **kwargs):
+         self.tokenizer = tokenizer
+         self.model = model
+         self.model.eval()
+         self.window_size = window_size
+
+         # preprocessing wrappers
+         self.logits_processor = logits_processor or []
+
+         self.payload_bits = payload_bits
+         self.V = max(2**payload_bits, self.model.config.vocab_size)
+         self.generator = torch.Generator(device=device)
+
+     def embed(self, key=42, messages=[1234], prompt="", max_length=30, method='aaronson'):
+         B = len(messages)  # batch size
+
+         # check capacity: every message must fit in payload_bits
+         if self.payload_bits:
+             assert all(0 <= message < 2**self.payload_bits for message in messages)
+
+         # tokenize prompt
+         inputs = self.tokenizer([prompt] * B, return_tensors="pt")
+
+         if method == 'aaronson':
+             # generate with greedy search
+             generated_ids = self.model.generate(inputs.input_ids.to(device), max_length=max_length, do_sample=False,
+                                                 logits_processor=self.logits_processor + [
+                                                     WatermarkingAaronsonLogitsProcessor(n=self.V,
+                                                                                         key=key,
+                                                                                         messages=messages,
+                                                                                         window_size=self.window_size)])
+         elif method == 'kirchenbauer':
+             # use sampling
+             generated_ids = self.model.generate(inputs.input_ids.to(device), max_length=max_length, do_sample=True,
+                                                 logits_processor=self.logits_processor + [
+                                                     WatermarkingKirchenbauerLogitsProcessor(n=self.V,
+                                                                                             key=key,
+                                                                                             messages=messages,
+                                                                                             window_size=self.window_size)])
+         elif method == 'greedy':
+             # unwatermarked baseline: greedy search
+             generated_ids = self.model.generate(inputs.input_ids.to(device), max_length=max_length, do_sample=False,
+                                                 logits_processor=self.logits_processor)
+         elif method == 'sampling':
+             # unwatermarked baseline: sampling
+             generated_ids = self.model.generate(inputs.input_ids.to(device), max_length=max_length, do_sample=True,
+                                                 logits_processor=self.logits_processor)
+         else:
+             raise Exception('Unknown method %s' % method)
+         decoded_texts = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
+
+         return decoded_texts
+
+     DEFAULT_SYSTEM_PROMPT = """\
+ You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
+ If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\
+ """
+
+     @staticmethod
+     def get_prompt(message: str) -> str:
+         # wrap a user message in the Llama-2 chat template; the constant is a
+         # class attribute, so it is qualified with the class name
+         texts = [f'<s>[INST] <<SYS>>\n{Watermarker.DEFAULT_SYSTEM_PROMPT}\n<</SYS>>\n\n']
+         # The first user input is _not_ stripped
+         texts.append(f'{message} [/INST]')
+         return ''.join(texts)
+
+     def detect(self, attacked_texts, key=42, method='aaronson', gamma=0.5, prompts=None):
+         if prompts is None:
+             prompts = [""] * len(attacked_texts)
+         else:
+             # apply the same chat template as at generation time, so that the
+             # system prompt and template tokens are masked out below
+             prompts = [self.get_prompt(p) for p in prompts]
+
+         generator = self.generator
+
+         cdfs = []
+         ms = []
+
+         MAX = 2**self.payload_bits
+
+         # tokenize input
+         inputs = self.tokenizer(attacked_texts, return_tensors="pt", padding=True, return_attention_mask=True)
+
+         input_ids = inputs["input_ids"].to(self.model.device)
+         attention_masks = inputs["attention_mask"].to(self.model.device)
+
+         B, T = input_ids.shape
+
+         if method == 'aaronson_neyman_pearson':
+             # compute the model's probabilities of the observed tokens
+             outputs = self.model.forward(input_ids, return_dict=True)
+             logits = outputs['logits']
+             # TODO: reapply the logits processors to recover the exact
+             # distribution used at generation time
+             probs = logits.softmax(dim=-1)
+             ps = torch.gather(probs, 2, input_ids[:, 1:, None]).squeeze_(-1)
+
+         seq_len = input_ids.shape[1]
+
+         V = self.V
+
+         Z = torch.zeros(size=(B, V), dtype=torch.float32, device=device)
+
+         # keep a history of contexts we have already seen,
+         # to exclude them from score aggregation and allow
+         # correct p-value computation under H0
+         history = [set() for _ in range(B)]
+
+         attention_masks_prompts = self.tokenizer(prompts, return_tensors="pt", padding=True, return_attention_mask=True)["attention_mask"]
+         prompts_length = torch.sum(attention_masks_prompts, dim=1)
+         for b in range(B):
+             # mask out the prompt tokens: only the generated text is scored
+             attention_masks[b, :prompts_length[b]] = 0
+             if not self.window_size:
+                 generator.manual_seed(key)
+             # we could iterate over seq_len - prompt_len tokens instead,
+             # changing the +1 offsets to + prompt_len
+             for i in range(seq_len - 1):
+                 if self.window_size:
+                     window = input_ids[b, max(0, i - self.window_size + 1):i + 1]
+                     seed = hash_tokens(window, key)
+                     if seed not in history[b]:
+                         generator.manual_seed(seed)
+                         history[b].add(seed)
+                     else:
+                         # context already seen: ignore the token
+                         attention_masks[b, i + 1] = 0
+
+                 if not attention_masks[b, i + 1]:
+                     continue
+
+                 token = int(input_ids[b, i + 1])
+
+                 if method in {'aaronson', 'aaronson_simplified', 'aaronson_neyman_pearson'}:
+                     R = torch.rand(V, generator=generator, device=generator.device)
+
+                     if method == 'aaronson':
+                         r = -(1 - R).log()
+                     elif method in {'aaronson_simplified', 'aaronson_neyman_pearson'}:
+                         r = -R.log()
+                 elif method == 'kirchenbauer':
+                     r = torch.zeros(V, device=device)
+                     vocab_permutation = torch.randperm(V, generator=generator, device=generator.device)
+                     greenlist = vocab_permutation[:int(gamma * V)]
+                     r[greenlist] = 1
+                 else:
+                     raise Exception('Unknown method %s' % method)
+
+                 if method in {'aaronson', 'aaronson_simplified', 'kirchenbauer'}:
+                     # independent of probs
+                     Z[b] += r.roll(-token)
+                 elif method == 'aaronson_neyman_pearson':
+                     # Neyman-Pearson weighting
+                     Z[b] += r.roll(-token) * (1 / ps[b, i] - 1)
+
+         for b in range(B):
+             # the best-scoring candidate payload is the decoded message
+             if method in {'aaronson', 'kirchenbauer'}:
+                 m = torch.argmax(Z[b, :MAX])
+             elif method in {'aaronson_simplified', 'aaronson_neyman_pearson'}:
+                 m = torch.argmin(Z[b, :MAX])
+
+             m = int(m)
+             S = Z[b, m].item()
+
+             # actual number of scored tokens
+             k = torch.sum(attention_masks[b]).item() - 1
+
+             if method == 'aaronson':
+                 cdf = gammaincc(k, S)
+             elif method == 'aaronson_simplified':
+                 cdf = gammainc(k, S)
+             elif method == 'aaronson_neyman_pearson':
+                 # Chernoff bound
+                 ratio = ps[b, :k] / (1 - ps[b, :k])
+                 E = (1 / ratio).sum()
+
+                 if S > E:
+                     cdf = 1.0
+                 else:
+                     # to compute the p-value we must solve for c*:
+                     # (1/(c* + ps/(1-ps))).sum() = S
+                     func = lambda c: (((1 / (c + ratio)).sum() - S)**2).item()
+                     c1 = (k / S - torch.min(ratio)).item()
+                     c = fminbound(func, 0, c1)
+                     # upper bound
+                     cdf = torch.exp(torch.sum(-torch.log(1 + c / ratio)) + c * S)
+             elif method == 'kirchenbauer':
+                 cdf = betainc(S, k - S + 1, gamma)
+
+             # correct for searching the best of MAX candidate payloads
+             if cdf > min(1 / MAX, 1e-5):
+                 cdf = 1 - (1 - cdf)**MAX  # true value
+             else:
+                 cdf = cdf * MAX  # numerically stable upper bound
+             cdfs.append(float(cdf))
+             ms.append(m)
+
+         return cdfs, ms
+
+
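
With this change, detect wraps any user-supplied prompts in the same Llama-2 chat template (system prompt included) that was used at generation time, so the template tokens are counted into the prompt length and masked out of the detection score. A minimal sketch of what the helper produces; the query string is illustrative:

    print(Watermarker.get_prompt("What is a watermark?"))
    # <s>[INST] <<SYS>>
    # You are a helpful, respectful and honest assistant. [...]
    # <</SYS>>
    #
    # What is a watermark? [/INST]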
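
For context, a minimal end-to-end sketch of the API this commit touches. The checkpoint (gpt2), key, payload and lengths are illustrative assumptions, not values from the commit:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    from watermark import Watermarker

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    tokenizer = AutoTokenizer.from_pretrained('gpt2')
    tokenizer.pad_token = tokenizer.eos_token  # detect() tokenizes with padding
    model = AutoModelForCausalLM.from_pretrained('gpt2').to(device)

    wm = Watermarker(tokenizer=tokenizer, model=model, window_size=0, payload_bits=8)

    # embed the payload 42, then try to read it back from the generated text
    texts = wm.embed(key=35317, messages=[42], prompt='The story begins',
                     max_length=80, method='aaronson')
    pvalues, payloads = wm.detect(texts, key=35317, method='aaronson')
    print(pvalues, payloads)  # a small p-value suggests a watermark; payloads should be [42]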
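
Embedding and detection stay synchronized through hash_tokens: the same context window and key always reproduce the same RNG seed. A tiny check, with made-up token ids:

    import torch

    from watermark import hash_tokens

    window = torch.tensor([3923, 374, 264])  # hypothetical token ids
    assert hash_tokens(window, key=42) == hash_tokens(window.clone(), key=42)
    assert hash_tokens(window, key=43) != hash_tokens(window, key=42)  # different key, (almost surely) different stream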
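
For reference, the p-values computed at the end of detect follow from standard tail probabilities, assuming the seeded per-token draws are effectively independent under H0 (which is what the seed history is for). With k scored tokens and score S:

    % Aaronson: each scored token contributes -log(1 - R[token]) ~ Exp(1) under H0,
    % so S ~ Gamma(k, 1) and the p-value is the regularized upper incomplete gamma:
    p = \Pr[\mathrm{Gamma}(k, 1) \ge S] = \mathrm{gammaincc}(k, S)
    % Kirchenbauer: S greenlist hits among k tokens, each green with probability gamma under H0:
    p = \Pr[\mathrm{Bin}(k, \gamma) \ge S] = \mathrm{betainc}(S, k - S + 1, \gamma)
    % searching over MAX candidate payloads and keeping the best one:
    p' = 1 - (1 - p)^{MAX} \le MAX \cdot p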