Antoine Chaffin commited on
Commit
ed02397
1 Parent(s): 31f8227

Initial commit

Browse files
Files changed (3) hide show
  1. app.py +80 -0
  2. requirements.txt +4 -0
  3. watermark.py +291 -0
app.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import argparse
3
+ import os
4
+ import numpy as np
5
+
6
+ from watermark import Watermarker
7
+ import time
8
+ import gradio as gr
9
+
10
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
11
+
12
+ parser = argparse.ArgumentParser(description='Generative Text Watermarking demo')
13
+ parser.add_argument('--model', '-m', type=str, default="facebook/opt-350m", help='Language model')
14
+ # parser.add_argument('--model', '-m', type=str, default="meta-llama/Llama-2-7b-chat-hf", help='Language model')
15
+ parser.add_argument('--key', '-k', type=int, default=42,
16
+ help='The seed of the pseudo random number generator')
17
+
18
+ args = parser.parse_args()
19
+
20
+ USERS = ['Alice', 'Bob', 'Charlie', 'Dan']
21
+ EMBED_METHODS = [ 'aaronson', 'kirchenbauer', 'sampling', 'greedy' ]
22
+ DETECT_METHODS = [ 'aaronson', 'aaronson_simplified', 'aaronson_neyman_pearson', 'kirchenbauer']
23
+ PAYLOAD_BITS = 2
24
+
25
+ def embed(user, max_length, window_size, method, prompt):
26
+ uid = USERS.index(user)
27
+
28
+ watermarker = Watermarker(modelname=args.model,
29
+ window_size=window_size, payload_bits=PAYLOAD_BITS)
30
+
31
+ watermarked_texts = watermarker.embed(key=args.key, messages=[ uid ],
32
+ max_length=max_length, method=method, prompt=prompt)
33
+ print("watermarked_texts: ", watermarked_texts)
34
+
35
+ return watermarked_texts[0]
36
+
37
+ def detect(attacked_text, window_size, method, prompt):
38
+ watermarker = Watermarker(modelname=args.model,
39
+ window_size=window_size, payload_bits=PAYLOAD_BITS)
40
+
41
+ pvalues, messages = watermarker.detect([ attacked_text ], key=args.key, method=method, prompts=[prompt])
42
+ print("messages: ", messages)
43
+ print("p-values: ", pvalues)
44
+ user = USERS[messages[0]]
45
+ pf = pvalues[0]
46
+ label = 'The user detected is {:s} with pvalue of {:.3e}'.format(user, pf)
47
+
48
+ return label
49
+
50
+
51
+
52
+ with gr.Blocks() as demo:
53
+ gr.Markdown("""# LLM generation watermarking
54
+ This spaces let you to try different watermarking scheme for LLM generation.\n
55
+ It leverages the upgrades introduced in the paper, reducing the gap between empirical and theoretical false positive detection rate and give the ability to embed a message (of n bits). Here we use this capacity to embed the identity of the user generating the text, but it could also be used to identify different version of a model or just convey a secret message.\n
56
+ Simply select an user name, set the maximum text length, the watermarking window size and the prompt. Aaronson and Kirchenbauer watermarking scheme are proposed, along traditional sampling and greedy search without watermarking.\n
57
+ Once the text is generated, you can eventually apply some attacks to it (e.g, remove words), select the associated detection method and run the detection. Please note that the detection is non-blind, and require the original prompt to be known and so left untouched.\n
58
+ For Aaronson, the original detection function, along the Neyman-Pearson and Simplified Score version are available.""")
59
+ with gr.Row():
60
+ user = gr.Dropdown(choices=USERS, value=USERS[0], label="User")
61
+ text_length = gr.Number(minimum=1, maximum=512, value=256, step=1, precision=0, label="Max text length")
62
+ window_size = gr.Number(minimum=0, maximum=10, value=0, step=1, precision=0, label="Watermarking window size")
63
+ embed_method = gr.Dropdown(choices=EMBED_METHODS, value=EMBED_METHODS[0], label="Sampling method")
64
+ prompt = gr.Textbox(label="prompt")
65
+ with gr.Row():
66
+ btn1 = gr.Button("Embed")
67
+ with gr.Row():
68
+ watermarked_text = gr.Textbox(label="Generated text")
69
+ detect_method = gr.Dropdown(choices=DETECT_METHODS, value=DETECT_METHODS[0], label="Detection method")
70
+ with gr.Row():
71
+ btn2 = gr.Button("Detect")
72
+ with gr.Row():
73
+ detection_label = gr.Label(label="Detection result")
74
+
75
+ btn1.click(fn=embed, inputs=[user, text_length, window_size, embed_method, prompt], outputs=[watermarked_text], api_name="watermark")
76
+ btn2.click(fn=detect, inputs=[watermarked_text, window_size, detect_method, prompt], outputs=[detection_label], api_name="detect")
77
+
78
+ demo.launch()
79
+
80
+
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch
2
+ transformers
3
+ scipy
4
+ numpy
watermark.py ADDED
@@ -0,0 +1,291 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import transformers
2
+ from transformers import AutoTokenizer
3
+ from transformers import (
4
+ AutoTokenizer,
5
+ AutoModelForCausalLM,
6
+ )
7
+ from transformers import pipeline, set_seed, LogitsProcessor
8
+ from transformers.generation.logits_process import TopPLogitsWarper, TopKLogitsWarper
9
+ import torch
10
+ from scipy.special import gamma, gammainc, gammaincc, betainc
11
+ from scipy.optimize import fminbound
12
+ import numpy as np
13
+
14
+ import os
15
+
16
+ hf_token = os.getenv('HF_TOKEN')
17
+
18
+
19
+ device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
20
+
21
+ def hash_tokens(input_ids: torch.LongTensor, key: int):
22
+ seed = key
23
+ salt = 35317
24
+ for i in input_ids:
25
+ seed = (seed * salt + i.item()) % (2 ** 64 - 1)
26
+ return seed
27
+
28
+ class WatermarkingLogitsProcessor(LogitsProcessor):
29
+ def __init__(self, n, key, messages, window_size, *args, **kwargs):
30
+ super().__init__(*args, **kwargs)
31
+ self.batch_size = len(messages)
32
+ self.generators = [ torch.Generator(device=device) for _ in range(self.batch_size) ]
33
+
34
+ self.n = n
35
+ self.key = key
36
+ self.window_size = window_size
37
+ if not self.window_size:
38
+ for b in range(self.batch_size):
39
+ self.generators[b].manual_seed(self.key)
40
+
41
+ self.messages = messages
42
+
43
+ class WatermarkingAaronsonLogitsProcessor( WatermarkingLogitsProcessor):
44
+ def __init__(self, *args, **kwargs):
45
+ super().__init__(*args, **kwargs)
46
+
47
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
48
+ # get random uniform variables
49
+ B, V = scores.shape
50
+
51
+ r = torch.zeros_like(scores)
52
+ for b in range(B):
53
+ if self.window_size:
54
+ window = input_ids[b, -self.window_size:]
55
+ seed = hash_tokens(window, self.key)
56
+ self.generators[b].manual_seed(seed)
57
+ r[b] = torch.rand(self.n, generator=self.generators[b], device=self.generators[b].device).log().roll(-self.messages[b])
58
+ # generate n but keep only V, as we want to keep the pseudo-random sequences in sync with the decoder
59
+ r = r[:,:V]
60
+
61
+ # modify law as r^(1/p)
62
+ # Since we want to return logits (logits processor takes and outputs logits),
63
+ # we return log(q), hence torch.log(r) * torch.log(torch.exp(1/p)) = torch.log(r) / p
64
+ return r / scores.exp()
65
+
66
+ class WatermarkingKirchenbauerLogitsProcessor(WatermarkingLogitsProcessor):
67
+ def __init__(self, *args,
68
+ gamma = 0.5,
69
+ delta = 4.0,
70
+ **kwargs):
71
+ super().__init__(*args, **kwargs)
72
+ self.gamma = gamma
73
+ self.delta = delta
74
+
75
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
76
+ B, V = scores.shape
77
+
78
+ for b in range(B):
79
+ if self.window_size:
80
+ window = input_ids[b, -self.window_size:]
81
+ seed = hash_tokens(window, self.key)
82
+ self.generators[b].manual_seed(seed)
83
+ vocab_permutation = torch.randperm(self.n, generator=self.generators[b], device=self.generators[b].device)
84
+ greenlist = vocab_permutation[:int(self.gamma * self.n)] # gamma * n
85
+ bias = torch.zeros(self.n).to(scores.device)
86
+ bias[greenlist] = self.delta
87
+ bias = bias.roll(-self.messages[b])[:V]
88
+ scores[b] += bias # add bias to greenlist words
89
+
90
+ return scores
91
+
92
+ class Watermarker(object):
93
+ def __init__(self, modelname="facebook/opt-350m", window_size = 0, payload_bits = 0, logits_processor = None, *args, **kwargs):
94
+ self.tokenizer = AutoTokenizer.from_pretrained(modelname, use_auth_token=hf_token)
95
+ self.model = AutoModelForCausalLM.from_pretrained(modelname, use_auth_token=hf_token).to(device)
96
+ self.model.eval()
97
+ self.window_size = window_size
98
+
99
+ # preprocessing wrappers
100
+ self.logits_processor = logits_processor or []
101
+
102
+ self.payload_bits = payload_bits
103
+ self.V = max(2**payload_bits, self.model.config.vocab_size)
104
+ self.generator = torch.Generator(device=device)
105
+
106
+
107
+ def embed(self, key=42, messages=[1234], prompt="", max_length=30, method='aaronson'):
108
+
109
+ B = len(messages) # batch size
110
+ length = max_length
111
+
112
+ # compute capacity
113
+ if self.payload_bits:
114
+ assert min([message >= 0 and message < 2**self.payload_bits for message in messages])
115
+
116
+ # tokenize prompt
117
+ inputs = self.tokenizer([ prompt ] * B, return_tensors="pt")
118
+
119
+ if method == 'aaronson':
120
+ # generate with greedy search
121
+ generated_ids = self.model.generate(inputs.input_ids.to(device), max_length=max_length, do_sample=False,
122
+ logits_processor = self.logits_processor + [
123
+ WatermarkingAaronsonLogitsProcessor(n=self.V,
124
+ key=key,
125
+ messages=messages,
126
+ window_size = self.window_size)])
127
+ elif method == 'kirchenbauer':
128
+ # use sampling
129
+ generated_ids = self.model.generate(inputs.input_ids.to(device), max_length=max_length, do_sample=True,
130
+ logits_processor = self.logits_processor + [
131
+ WatermarkingKirchenbauerLogitsProcessor(n=self.V,
132
+ key=key,
133
+ messages=messages,
134
+ window_size = self.window_size)])
135
+ elif method == 'greedy':
136
+ # generate with greedy search
137
+ generated_ids = self.model.generate(inputs.input_ids.to(device), max_length=max_length, do_sample=False,
138
+ logits_processor = self.logits_processor)
139
+ elif method == 'sampling':
140
+ # generate with greedy search
141
+ generated_ids = self.model.generate(inputs.input_ids.to(device), max_length=max_length, do_sample=True,
142
+ logits_processor = self.logits_processor)
143
+ else:
144
+ raise Exception('Unknown method %s' % method)
145
+ decoded_texts = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
146
+
147
+ return decoded_texts
148
+
149
+ def detect(self, attacked_texts, key=42, method='aaronson', gamma=0.5, prompts=None):
150
+ if(prompts==None):
151
+ prompts = [""] * len(attacked_texts)
152
+
153
+ generator = self.generator
154
+
155
+ #print("attacked_texts = ", attacked_texts)
156
+
157
+ cdfs = []
158
+ ms = []
159
+
160
+ MAX = 2**self.payload_bits
161
+
162
+ # tokenize input
163
+ inputs = self.tokenizer(attacked_texts, return_tensors="pt", padding=True, return_attention_mask=True)
164
+
165
+ input_ids = inputs["input_ids"].to(self.model.device)
166
+ attention_masks = inputs["attention_mask"].to(self.model.device)
167
+
168
+ B,T = input_ids.shape
169
+
170
+ if method == 'aaronson_neyman_pearson':
171
+ # compute logits
172
+ outputs = self.model.forward(input_ids, return_dict=True)
173
+ logits = outputs['logits']
174
+ # TODO
175
+ # reapply logits processors to get same distribution
176
+ #for i in range(T):
177
+ # for processor in self.logits_processor:
178
+ # logits[:,i] = processor(input_ids[:, :i], logits[:, i])
179
+
180
+ probs = logits.softmax(dim=-1)
181
+ ps = torch.gather(probs, 2, input_ids[:,1:,None]).squeeze_(-1)
182
+
183
+
184
+ seq_len = input_ids.shape[1]
185
+ length = seq_len
186
+
187
+ V = self.V
188
+
189
+ Z = torch.zeros(size=(B, V), dtype=torch.float32, device=device)
190
+
191
+
192
+ # keep a history of contexts we have already seen,
193
+ # to exclude them from score aggregation and allow
194
+ # correct p-value computation under H0
195
+ history = [set() for _ in range(B)]
196
+
197
+ attention_masks_prompts = self.tokenizer(prompts, return_tensors="pt", padding=True, return_attention_mask=True)["attention_mask"]
198
+ prompts_length = torch.sum(attention_masks_prompts, dim=1)
199
+ for b in range(B):
200
+ attention_masks[b, :prompts_length[b]] = 0
201
+ if not self.window_size:
202
+ generator.manual_seed(key)
203
+ # We can go from seq_len - prompt_len, need to change +1 to + prompt_len
204
+ for i in range(seq_len-1):
205
+
206
+ if self.window_size:
207
+ window = input_ids[b, max(0, i-self.window_size+1):i+1]
208
+ #print("window = ", window)
209
+ seed = hash_tokens(window, key)
210
+ if seed not in history[b]:
211
+ generator.manual_seed(seed)
212
+ history[b].add(seed)
213
+ else:
214
+ # ignore the token
215
+ attention_masks[b, i+1] = 0
216
+
217
+ if not attention_masks[b,i+1]:
218
+ continue
219
+
220
+ token = int(input_ids[b,i+1])
221
+
222
+ if method in {'aaronson', 'aaronson_simplified', 'aaronson_neyman_pearson'}:
223
+ R = torch.rand(V, generator = generator, device = generator.device)
224
+
225
+ if method == 'aaronson':
226
+ r = -(1-R).log()
227
+ elif method in {'aaronson_simplified', 'aaronson_neyman_pearson'}:
228
+ r = -R.log()
229
+ elif method == 'kirchenbauer':
230
+ r = torch.zeros(V, device=device)
231
+ vocab_permutation = torch.randperm(V, generator = generator, device=generator.device)
232
+ greenlist = vocab_permutation[:int(gamma * V)]
233
+ r[greenlist] = 1
234
+ else:
235
+ raise Exception('Unknown method %s' % method)
236
+
237
+ if method in {'aaronson', 'aaronson_simplified', 'kirchenbauer'}:
238
+ # independent of probs
239
+ Z[b] += r.roll(-token)
240
+ elif method == 'aaronson_neyman_pearson':
241
+ # Neyman-Pearson
242
+ Z[b] += r.roll(-token) * (1/ps[b,i] - 1)
243
+
244
+ for b in range(B):
245
+ if method in {'aaronson', 'kirchenbauer'}:
246
+ m = torch.argmax(Z[b,:MAX])
247
+ elif method in {'aaronson_simplified', 'aaronson_neyman_pearson'}:
248
+ m = torch.argmin(Z[b,:MAX])
249
+
250
+ i = int(m)
251
+ S = Z[b, i].item()
252
+ m = i
253
+
254
+ # actual sequence length
255
+ k = torch.sum(attention_masks[b]).item() - 1
256
+
257
+ if method == 'aaronson':
258
+ cdf = gammaincc(k, S)
259
+ elif method == 'aaronson_simplified':
260
+ cdf = gammainc(k, S)
261
+ elif method == 'aaronson_neyman_pearson':
262
+ # Chernoff bound
263
+ ratio = ps[b,:k] / (1 - ps[b,:k])
264
+ E = (1/ratio).sum()
265
+
266
+ if S > E:
267
+ cdf = 1.0
268
+ else:
269
+ # to compute p-value we must solve for c*:
270
+ # (1/(c* + ps/(1-ps))).sum() = S
271
+ func = lambda c : (((1 / (c + ratio)).sum() - S)**2).item()
272
+ c1 = (k / S - torch.min(ratio)).item()
273
+ print("max = ", c1)
274
+ c = fminbound(func, 0, c1)
275
+ print("solved c = ", c)
276
+ print("solved s = ", ((1/(c + ratio)).sum()).item())
277
+ # upper bound
278
+ cdf = torch.exp(torch.sum(-torch.log(1 + c / ratio)) + c * S)
279
+ elif method == 'kirchenbauer':
280
+ cdf = betainc(S, k - S + 1, gamma)
281
+
282
+ if cdf > min(1 / MAX, 1e-5):
283
+ cdf = 1 - (1 - cdf)**MAX # true value
284
+ else:
285
+ cdf = cdf * MAX # numerically stable upper bound
286
+ cdfs.append(float(cdf))
287
+ ms.append(m)
288
+
289
+ return cdfs, ms
290
+
291
+