jwkirchenbauer committed
Commit d6b2709
1 Parent(s): 619f2e3

adding files

Files changed (7)
  1. README.md +5 -4
  2. app.py +41 -0
  3. demo_watermark.py +379 -0
  4. homoglyphs.py +268 -0
  5. normalizers.py +195 -0
  6. requirements.txt +6 -0
  7. watermark_processor.py +281 -0
README.md CHANGED
@@ -1,13 +1,14 @@
  ---
- title: Lm Watermarking
- emoji: 🌍
- colorFrom: green
- colorTo: yellow
+ title: A Watermark for LLMs
+ emoji: 💧
+ colorFrom: blue
+ colorTo: purple
  sdk: gradio
  sdk_version: 3.18.0
  app_file: app.py
  pinned: false
  license: apache-2.0
+ python_version: 3.10.6
  ---

  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,41 @@
+ # coding=utf-8
+ # Copyright 2023 Authors of "A Watermark for Large Language Models"
+ # available at https://arxiv.org/abs/2301.10226
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from argparse import Namespace
+ args = Namespace()
+
+ # Mirrors the CLI flags declared in demo_watermark.parse_args;
+ # main() reads every field below, so all of them must be present.
+ arg_dict = {
+     "run_gradio": True,
+     "demo_public": False,
+     "model_name_or_path": "facebook/opt-125m",
+     # "model_name_or_path": "facebook/opt-1.3b",
+     # "model_name_or_path": "facebook/opt-2.7b",
+     "prompt_max_length": None,
+     "max_new_tokens": 200,
+     "generation_seed": 123,
+     "use_sampling": True,
+     "sampling_temp": 0.7,
+     "n_beams": 1,
+     "use_gpu": True,
+     "seeding_scheme": "markov_1",
+     "gamma": 0.25,
+     "delta": 2.0,
+     "normalizers": "",
+     "ignore_repeated_bigrams": False,
+     "detection_z_threshold": 4.0,
+     "select_green_tokens": True,
+ }
+
+ args.__dict__.update(arg_dict)
+ print(args)
+
+ from demo_watermark import main
+
+ main(args)
demo_watermark.py ADDED
@@ -0,0 +1,379 @@
+ # coding=utf-8
+ # Copyright 2023 Authors of "A Watermark for Large Language Models"
+ # available at https://arxiv.org/abs/2301.10226
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import os
+ import argparse
+ from pprint import pprint
+ from functools import partial
+
+ import torch
+
+ from transformers import (AutoTokenizer,
+                           AutoModelForSeq2SeqLM,
+                           AutoModelForCausalLM,
+                           LogitsProcessorList)
+
+ from watermark_processor import WatermarkLogitsProcessor, WatermarkDetector
+
+ def str2bool(v):
+     """Parse a command-line flag into a bool, accepting common spellings."""
+     if isinstance(v, bool):
+         return v
+     if v.lower() in ('yes', 'true', 't', 'y', '1'):
+         return True
+     elif v.lower() in ('no', 'false', 'f', 'n', '0'):
+         return False
+     else:
+         raise argparse.ArgumentTypeError('Boolean value expected.')
+
+ def parse_args():
+
+     parser = argparse.ArgumentParser(description="A minimum working example of applying the watermark to any LLM that supports the huggingface 🤗 `generate` API")
+
+     parser.add_argument(
+         "--run_gradio",
+         type=str2bool,
+         default=False,
+         help="Whether to launch as a gradio demo.",
+     )
+     parser.add_argument(
+         "--demo_public",
+         type=str2bool,
+         default=False,
+         help="Whether to expose the gradio demo to the internet.",
+     )
+     parser.add_argument(
+         "--model_name_or_path",
+         type=str,
+         default="facebook/opt-6.7b",
+         help="Main model, path to pretrained model or model identifier from huggingface.co/models.",
+     )
+     parser.add_argument(
+         "--prompt_max_length",
+         type=int,
+         default=None,
+         help="Truncation length for the prompt, overrides the model config's max length field.",
+     )
+     parser.add_argument(
+         "--max_new_tokens",
+         type=int,
+         default=200,
+         help="Maximum number of new tokens to generate.",
+     )
+     parser.add_argument(
+         "--generation_seed",
+         type=int,
+         default=123,
+         help="Seed for setting the torch global rng prior to generation.",
+     )
+     parser.add_argument(
+         "--use_sampling",
+         type=str2bool,
+         default=True,
+         help="Whether to generate using multinomial sampling.",
+     )
+     parser.add_argument(
+         "--sampling_temp",
+         type=float,
+         default=0.7,
+         help="Sampling temperature to use when generating using multinomial sampling.",
+     )
+     parser.add_argument(
+         "--n_beams",
+         type=int,
+         default=1,
+         help="Number of beams to use when generating without sampling.",
+     )
+     parser.add_argument(
+         "--use_gpu",
+         type=str2bool,
+         default=True,
+         help="Whether to run inference and watermark hashing/seeding/permutation on gpu.",
+     )
+     parser.add_argument(
+         "--seeding_scheme",
+         type=str,
+         default="markov_1",
+         help="Seeding scheme to use to generate the greenlists at each generation and verification step.",
+     )
+     parser.add_argument(
+         "--gamma",
+         type=float,
+         default=0.25,
+         help="The fraction of the vocabulary to partition into the greenlist at each generation and verification step.",
+     )
+     parser.add_argument(
+         "--delta",
+         type=float,
+         default=2.0,
+         help="The amount/bias to add to each of the greenlist token logits before each token sampling step.",
+     )
+     parser.add_argument(
+         "--normalizers",
+         type=str,
+         default="",
+         help="Single or comma-separated list of the preprocessor/normalizer names to use when performing watermark detection.",
+     )
+     parser.add_argument(
+         "--ignore_repeated_bigrams",
+         type=str2bool,
+         default=False,
+         help="Whether to use the detection method that only counts each unique bigram once as either a green or red hit.",
+     )
+     parser.add_argument(
+         "--detection_z_threshold",
+         type=float,
+         default=4.0,
+         help="The test statistic threshold for the detection hypothesis test.",
+     )
+     parser.add_argument(
+         "--select_green_tokens",
+         type=str2bool,
+         default=True,
+         help="How to treat the permutation when selecting the greenlist tokens at each step. Legacy behavior (False) picks the complement/red tokens first.",
+     )
+     args = parser.parse_args()
+     return args
+
+
+ def main(args):
+
+     is_seq2seq_model = any([(model_type in args.model_name_or_path) for model_type in ["t5", "T0"]])
+     is_decoder_only_model = any([(model_type in args.model_name_or_path) for model_type in ["gpt", "opt", "bloom"]])
+     if is_seq2seq_model:
+         model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name_or_path)
+     elif is_decoder_only_model:
+         model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path)
+     else:
+         raise ValueError(f"Unknown model type: {args.model_name_or_path}")
+
+     if args.use_gpu:
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+         model = model.to(device)
+     else:
+         device = "cpu"
+     model.eval()
+
+     tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
+     vocabulary = list(tokenizer.get_vocab().values())
+
+     def generate(prompt):
+
+         watermark_processor = WatermarkLogitsProcessor(vocab=vocabulary,
+                                                        gamma=args.gamma,
+                                                        delta=args.delta,
+                                                        seeding_scheme=args.seeding_scheme,
+                                                        select_green_tokens=args.select_green_tokens)
+
+         gen_kwargs = dict(max_new_tokens=args.max_new_tokens)
+
+         if args.use_sampling:
+             gen_kwargs.update(dict(
+                 do_sample=True,
+                 top_k=0,
+                 temperature=args.sampling_temp
+             ))
+         else:
+             gen_kwargs.update(dict(
+                 num_beams=args.n_beams
+             ))
+
+         generate_without_watermark = partial(
+             model.generate,
+             **gen_kwargs
+         )
+         generate_with_watermark = partial(
+             model.generate,
+             logits_processor=LogitsProcessorList([watermark_processor]),
+             **gen_kwargs
+         )
+         if args.prompt_max_length:
+             pass  # respect an explicitly passed truncation length
+         elif hasattr(model.config, "max_position_embeddings"):
+             args.prompt_max_length = model.config.max_position_embeddings - args.max_new_tokens
+         else:
+             args.prompt_max_length = 2048 - args.max_new_tokens
+
+         tokd_input = tokenizer(prompt, return_tensors="pt", add_special_tokens=True, truncation=True, max_length=args.prompt_max_length).to(device)
+         truncation_warning = True if tokd_input["input_ids"].shape[-1] == args.prompt_max_length else False
+         redecoded_input = tokenizer.batch_decode(tokd_input["input_ids"], skip_special_tokens=True)[0]
+
+         torch.manual_seed(args.generation_seed)
+         output_without_watermark = generate_without_watermark(**tokd_input)
+         # torch.manual_seed(args.generation_seed)  # optional: reseeding would not reproduce the unwatermarked output anyway, unless delta==0.0 (a no-op watermark)
+         output_with_watermark = generate_with_watermark(**tokd_input)
+
+         if is_decoder_only_model:
+             # need to isolate the newly generated tokens
+             output_without_watermark = output_without_watermark[:, tokd_input["input_ids"].shape[-1]:]
+             output_with_watermark = output_with_watermark[:, tokd_input["input_ids"].shape[-1]:]
+
+         decoded_output_without_watermark = tokenizer.batch_decode(output_without_watermark, skip_special_tokens=True)[0]
+         decoded_output_with_watermark = tokenizer.batch_decode(output_with_watermark, skip_special_tokens=True)[0]
+
+         return (redecoded_input,
+                 int(truncation_warning),
+                 decoded_output_without_watermark,
+                 decoded_output_with_watermark)
+
+     def detect(input_text):
+         watermark_detector = WatermarkDetector(vocab=list(tokenizer.get_vocab().values()),
+                                                gamma=args.gamma,
+                                                seeding_scheme=args.seeding_scheme,
+                                                device=device,
+                                                tokenizer=tokenizer,
+                                                z_threshold=args.detection_z_threshold,
+                                                normalizers=(args.normalizers.split(",") if args.normalizers else []),
+                                                ignore_repeated_bigrams=args.ignore_repeated_bigrams,
+                                                select_green_tokens=args.select_green_tokens)
+         # rough length check in characters; the detector re-tokenizes the string itself
+         if len(input_text) - 1 > watermark_detector.min_prefix_len:
+             score_dict = watermark_detector.detect(input_text)
+             output_str = (f"Detection result @ {watermark_detector.z_threshold}:\n"
+                           f"{score_dict}")
+         else:
+             output_str = "Error: string not long enough to compute watermark presence."
+         return output_str
+
+     # Generate and detect, report to stdout
+
+     # input_text = (
+     #     "The diamondback terrapin or simply terrapin (Malaclemys terrapin) is a "
+     #     "species of turtle native to the brackish coastal tidal marshes of the "
+     #     "Northeastern and southern United States, and in Bermuda.[6] It belongs "
+     #     "to the monotypic genus Malaclemys. It has one of the largest ranges of "
+     #     "all turtles in North America, stretching as far south as the Florida Keys "
+     #     "and as far north as Cape Cod.[7] The name 'terrapin' is derived from the "
+     #     "Algonquian word torope.[8] It applies to Malaclemys terrapin in both "
+     #     "British English and American English. The name originally was used by "
+     #     "early European settlers in North America to describe these brackish-water "
+     #     "turtles that inhabited neither freshwater habitats nor the sea. It retains "
+     #     "this primary meaning in American English.[8] In British English, however, "
+     #     "other semi-aquatic turtle species, such as the red-eared slider, might "
+     #     "also be called terrapins. The common name refers to the diamond pattern "
+     #     "on top of its shell (carapace), but the overall pattern and coloration "
+     #     "vary greatly. The shell is usually wider at the back than in the front, "
+     #     "and from above it appears wedge-shaped. The shell coloring can vary "
+     #     "from brown to grey, and its body color can be grey, brown, yellow, "
+     #     "or white. All have a unique pattern of wiggly, black markings or spots "
+     #     "on their body and head. The diamondback terrapin has large webbed "
+     #     "feet.[9] The species is"
+     # )
+
+     input_text = "In this work, we study watermarking of language model output. A watermark is a hidden pattern in text that is imperceptible to humans, while making the text algorithmically identifiable as synthetic. We propose an efficient watermark that makes synthetic text detectable from short spans of tokens (as few as 25 words), while false-positives (where human text is marked as machine-generated) are statistically improbable. The watermark detection algorithm can be made public, enabling third parties (e.g., social media platforms) to run it themselves, or it can be kept private and run behind an API. We seek a watermark with the following properties:\n"
+
+     try:
+         term_width = os.get_terminal_size()[0]
+     except OSError:
+         term_width = 80  # no terminal attached (e.g. when hosted as a Space)
+     print("#" * term_width)
+     print("Prompt:")
+     print(input_text)
+
+     _, _, decoded_output_without_watermark, decoded_output_with_watermark = generate(input_text)
+     without_watermark_detection_result = detect(decoded_output_without_watermark)
+     with_watermark_detection_result = detect(decoded_output_with_watermark)
+
+     print("#" * term_width)
+     print("Output without watermark:")
+     print(decoded_output_without_watermark)
+     print("-" * term_width)
+     print(f"Detection result @ {args.detection_z_threshold}:")
+     pprint(without_watermark_detection_result)
+     print("-" * term_width)
+
+     print("#" * term_width)
+     print("Output with watermark:")
+     print(decoded_output_with_watermark)
+     print("-" * term_width)
+     print(f"Detection result @ {args.detection_z_threshold}:")
+     pprint(with_watermark_detection_result)
+     print("-" * term_width)
+
+     # Launch the app to generate and detect interactively (implements the hf space demo)
+
+     if args.run_gradio:
+         import gradio as gr
+
+         with gr.Blocks() as demo:
+             gr.Markdown("## Demo for ['A Watermark for Large Language Models'](https://arxiv.org/abs/2301.10226)")
+             # gr.HTML("""
+             # <p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings.
+             # <br/>
+             # <a href="https://huggingface.co/spaces/tomg-group-umd/pez-dispenser?duplicate=true">
+             # <img style="margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
+             # <p/>
+             # """)
+             gr.Markdown(f"#### Generation and Watermarking Parameters:\n\n{args.__dict__}")
+
+             with gr.Tab("Generation"):
+                 with gr.Row():
+                     prompt = gr.Textbox(label=f"Prompt (max {args.prompt_max_length} tokens)", interactive=True)
+                 with gr.Row():
+                     generate_btn = gr.Button("Generate")
+                 with gr.Row():
+                     with gr.Column(scale=2):
+                         output_without_watermark = gr.Textbox(label="Output Without Watermark", interactive=False)
+                     with gr.Column(scale=1):
+                         without_watermark_detection_result = gr.Textbox(label="Detection Result", interactive=False)
+                 with gr.Row():
+                     with gr.Column(scale=2):
+                         output_with_watermark = gr.Textbox(label="Output With Watermark", interactive=False)
+                     with gr.Column(scale=1):
+                         with_watermark_detection_result = gr.Textbox(label="Detection Result", interactive=False)
+
+             redecoded_input = gr.Textbox(visible=False)
+             truncation_warning = gr.Number(visible=False)
+
+             def truncate_prompt(redecoded_input, truncation_warning, orig_prompt):
+                 if truncation_warning:
+                     return redecoded_input + "\n\n[Prompt was truncated before generation due to length...]"
+                 else:
+                     return orig_prompt
+
+             generate_btn.click(fn=generate, inputs=[prompt], outputs=[redecoded_input, truncation_warning, output_without_watermark, output_with_watermark])
+
+             # Show the truncated version of the prompt if truncation occurred
+             redecoded_input.change(fn=truncate_prompt, inputs=[redecoded_input, truncation_warning, prompt], outputs=[prompt])
+
+             # Call detection when the outputs of the generate function are updated.
+             output_without_watermark.change(fn=detect, inputs=output_without_watermark, outputs=without_watermark_detection_result)
+             output_with_watermark.change(fn=detect, inputs=output_with_watermark, outputs=with_watermark_detection_result)
+
+             with gr.Tab("Detector Only"):
+                 with gr.Row():
+                     detection_input = gr.Textbox(label="Text to Analyze", interactive=True)
+                 with gr.Row():
+                     detect_btn = gr.Button("Detect")
+                 with gr.Row():
+                     detection_result = gr.Textbox(label="Detection Result", interactive=False)
+                 detect_btn.click(fn=detect, inputs=detection_input, outputs=detection_result)
+
+             with gr.Accordion("A note on model capability", open=False):
+                 gr.Markdown(
+                     """
+                     The models that can be used in this demo are limited to those that are open source and small enough to fit on a single commodity GPU. In particular, there are few open-source models above 10B parameters, and far fewer that were trained with instruction finetuning or RLHF.
+
+                     Therefore, the model, in both its un-watermarked (normal) and watermarked states, generally cannot respond to prompts as well as a 100B+ instruction- and RLHF-tuned model such as ChatGPT, Claude, or Bard can.
+
+                     We suggest prompts that give the model a few sentences and then allow it to 'continue' them, as these weaker models are more capable in this simpler language modeling setting.
+                     """
+                 )
+
+         if args.demo_public:
+             demo.launch(share=True)  # exposes the app to the internet via a randomly generated link
+         else:
+             demo.launch()
+
+     return
+
+ if __name__ == "__main__":
+
+     args = parse_args()
+     print(args)
+
+     main(args)
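
For reference, the `generate` closure above reduces to a standard 🤗 `generate` call with one extra logits processor. A minimal sketch of that pattern in isolation, assuming a small placeholder model ("facebook/opt-125m" here is illustrative; any causal LM works):

# Minimal sketch of the watermarked-generation pattern used in demo_watermark.py.
# Parameters mirror the demo defaults but are otherwise illustrative.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList
from watermark_processor import WatermarkLogitsProcessor

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

watermark = WatermarkLogitsProcessor(vocab=list(tokenizer.get_vocab().values()),
                                     gamma=0.25, delta=2.0, seeding_scheme="markov_1")

inputs = tokenizer("The watermark makes the text", return_tensors="pt")
torch.manual_seed(123)
out = model.generate(**inputs, max_new_tokens=40, do_sample=True, top_k=0, temperature=0.7,
                     logits_processor=LogitsProcessorList([watermark]))
# decode only the newly generated tokens
print(tokenizer.batch_decode(out[:, inputs["input_ids"].shape[-1]:], skip_special_tokens=True)[0])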
homoglyphs.py ADDED
@@ -0,0 +1,268 @@
+ """Updated version of core.py from
+ https://github.com/yamatt/homoglyphs/tree/main/homoglyphs_fork
+ for modern python3
+ """
+
+ from collections import defaultdict
+ import json
+ from itertools import product
+ import os
+ import unicodedata
+
+ import homoglyphs_fork as hg
+
+ CURRENT_DIR = hg.core.CURRENT_DIR
+
+ # Actions if char not in alphabet
+ STRATEGY_LOAD = 1  # load category for this char
+ STRATEGY_IGNORE = 2  # add char to result
+ STRATEGY_REMOVE = 3  # remove char from result
+
+ ASCII_RANGE = range(128)
+
+
+ class Categories:
+     """
+     Work with aliases from ISO 15924.
+     https://en.wikipedia.org/wiki/ISO_15924#List_of_codes
+     """
+
+     fpath = os.path.join(CURRENT_DIR, "categories.json")
+
+     @classmethod
+     def _get_ranges(cls, categories):
+         """
+         :return: iter: (start code, end code)
+         :rtype: list
+         """
+         with open(cls.fpath, encoding="utf-8") as f:
+             data = json.load(f)
+
+         for category in categories:
+             if category not in data["aliases"]:
+                 raise ValueError("Invalid category: {}".format(category))
+
+         for point in data["points"]:
+             if point[2] in categories:
+                 yield point[:2]
+
+     @classmethod
+     def get_alphabet(cls, categories):
+         """
+         :return: set of chars in alphabet by categories list
+         :rtype: set
+         """
+         alphabet = set()
+         for start, end in cls._get_ranges(categories):
+             chars = (chr(code) for code in range(start, end + 1))
+             alphabet.update(chars)
+         return alphabet
+
+     @classmethod
+     def detect(cls, char):
+         """
+         :return: category
+         :rtype: str
+         """
+         with open(cls.fpath, encoding="utf-8") as f:
+             data = json.load(f)
+
+         # try to detect the category via unicodedata
+         try:
+             category = unicodedata.name(char).split()[0]
+         except (TypeError, ValueError):
+             # unicodedata.name raises ValueError for unnamed chars in python3
+             # (TypeError for non-unicode chars in python2)
+             pass
+         else:
+             if category in data["aliases"]:
+                 return category
+
+         # try to detect the category via the ranges from the JSON file
+         code = ord(char)
+         for point in data["points"]:
+             if point[0] <= code <= point[1]:
+                 return point[2]
+
+     @classmethod
+     def get_all(cls):
+         with open(cls.fpath, encoding="utf-8") as f:
+             data = json.load(f)
+         return set(data["aliases"])
+
+
+ class Languages:
+     fpath = os.path.join(CURRENT_DIR, "languages.json")
+
+     @classmethod
+     def get_alphabet(cls, languages):
+         """
+         :return: set of chars in alphabet by languages list
+         :rtype: set
+         """
+         with open(cls.fpath, encoding="utf-8") as f:
+             data = json.load(f)
+         alphabet = set()
+         for lang in languages:
+             if lang not in data:
+                 raise ValueError("Invalid language code: {}".format(lang))
+             alphabet.update(data[lang])
+         return alphabet
+
+     @classmethod
+     def detect(cls, char):
+         """
+         :return: set of languages which alphabet contains passed char.
+         :rtype: set
+         """
+         with open(cls.fpath, encoding="utf-8") as f:
+             data = json.load(f)
+         languages = set()
+         for lang, alphabet in data.items():
+             if char in alphabet:
+                 languages.add(lang)
+         return languages
+
+     @classmethod
+     def get_all(cls):
+         with open(cls.fpath, encoding="utf-8") as f:
+             data = json.load(f)
+         return set(data.keys())
+
+
+ class Homoglyphs:
+     def __init__(
+         self,
+         categories=None,
+         languages=None,
+         alphabet=None,
+         strategy=STRATEGY_IGNORE,
+         ascii_strategy=STRATEGY_IGNORE,
+         ascii_range=ASCII_RANGE,
+     ):
+         # strategies
+         if strategy not in (STRATEGY_LOAD, STRATEGY_IGNORE, STRATEGY_REMOVE):
+             raise ValueError("Invalid strategy")
+         self.strategy = strategy
+         self.ascii_strategy = ascii_strategy
+         self.ascii_range = ascii_range
+
+         # Homoglyphs must be initialized with some alphabet to work correctly
+         if not categories and not languages and not alphabet:
+             categories = ("LATIN", "COMMON")
+
+         # categories and languages
+         self.categories = set(categories or [])
+         self.languages = set(languages or [])
+
+         # alphabet
+         self.alphabet = set(alphabet or [])
+         if self.categories:
+             alphabet = Categories.get_alphabet(self.categories)
+             self.alphabet.update(alphabet)
+         if self.languages:
+             alphabet = Languages.get_alphabet(self.languages)
+             self.alphabet.update(alphabet)
+         self.table = self.get_table(self.alphabet)
+
+     @staticmethod
+     def get_table(alphabet):
+         table = defaultdict(set)
+         # CURRENT_DIR was removed here, so the confusables table is resolved
+         # relative to the current working directory
+         with open("confusables_sept2022.json", encoding="utf-8") as f:
+             data = json.load(f)
+         for char in alphabet:
+             if char in data:
+                 for homoglyph in data[char]:
+                     if homoglyph in alphabet:
+                         table[char].add(homoglyph)
+         return table
+
+     @staticmethod
+     def get_restricted_table(source_alphabet, target_alphabet):
+         table = defaultdict(set)
+         # CURRENT_DIR was removed here as well; same working-directory lookup
+         with open("confusables_sept2022.json", encoding="utf-8") as f:
+             data = json.load(f)
+         for char in source_alphabet:
+             if char in data:
+                 for homoglyph in data[char]:
+                     if homoglyph in target_alphabet:
+                         table[char].add(homoglyph)
+         return table
+
+     @staticmethod
+     def uniq_and_sort(data):
+         result = list(set(data))
+         result.sort(key=lambda x: (-len(x), x))
+         return result
+
+     def _update_alphabet(self, char):
+         # try to detect languages
+         langs = Languages.detect(char)
+         if langs:
+             self.languages.update(langs)
+             alphabet = Languages.get_alphabet(langs)
+             self.alphabet.update(alphabet)
+         else:
+             # try to detect categories
+             category = Categories.detect(char)
+             if category is None:
+                 return False
+             self.categories.add(category)
+             alphabet = Categories.get_alphabet([category])
+             self.alphabet.update(alphabet)
+         # update the table for the new alphabet
+         self.table = self.get_table(self.alphabet)
+         return True
+
+     def _get_char_variants(self, char):
+         if char not in self.alphabet:
+             if self.strategy == STRATEGY_LOAD:
+                 if not self._update_alphabet(char):
+                     return []
+             elif self.strategy == STRATEGY_IGNORE:
+                 return [char]
+             elif self.strategy == STRATEGY_REMOVE:
+                 return []
+
+         # find alternative chars for the current char
+         alt_chars = self.table.get(char, set())
+         if alt_chars:
+             # find alternative chars for the alternatives of the current char
+             alt_chars2 = [self.table.get(alt_char, set()) for alt_char in alt_chars]
+             # combine all alternatives
+             alt_chars.update(*alt_chars2)
+         # add the current char to the alternatives
+         alt_chars.add(char)
+
+         # uniq, sort and return
+         return self.uniq_and_sort(alt_chars)
+
+     def _get_combinations(self, text, ascii=False):
+         variations = []
+         for char in text:
+             alt_chars = self._get_char_variants(char)
+
+             if ascii:
+                 alt_chars = [
+                     char for char in alt_chars if ord(char) in self.ascii_range
+                 ]
+                 if not alt_chars and self.ascii_strategy == STRATEGY_IGNORE:
+                     return
+
+             if alt_chars:
+                 variations.append(alt_chars)
+         if variations:
+             for variant in product(*variations):
+                 yield "".join(variant)
+
+     def get_combinations(self, text):
+         return list(self._get_combinations(text))
+
+     def _to_ascii(self, text):
+         for variant in self._get_combinations(text, ascii=True):
+             if max(map(ord, variant)) in self.ascii_range:
+                 yield variant
+
+     def to_ascii(self, text):
+         return self.uniq_and_sort(self._to_ascii(text))
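
For reference, a minimal sketch of how this vendored class behaves. It assumes confusables_sept2022.json is present in the working directory, since the loading path above resolves against the current directory; the exact outputs depend on that confusables table:

# Minimal usage sketch of the vendored Homoglyphs class.
from homoglyphs import Homoglyphs, STRATEGY_LOAD

hg_engine = Homoglyphs(strategy=STRATEGY_LOAD)  # default alphabet is LATIN + COMMON

# enumerate confusable respellings of a short string
# (the combinatorial product can be large for long inputs)
print(hg_engine.get_combinations("ab")[:5])

# project a mixed-script string back onto ascii candidates,
# e.g. cyrillic 'Н', 'е', 'о' mapping to latin lookalikes
print(hg_engine.to_ascii("Неllо"))  # e.g. ['Hello'], table-dependent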
normalizers.py ADDED
@@ -0,0 +1,195 @@
+ """ Text-based normalizers, used to mitigate simple attacks against watermarking.
+
+ This implementation is unlikely to be a complete list of all possible exploits within the unicode standard;
+ it represents our best effort at the time of writing.
+
+ These normalizers can be used as stand-alone normalizers. They could be made to conform to the HF tokenizers standard, but that would
+ require messing with the limited rust interface of tokenizers.NormalizedString.
+ """
+ from collections import defaultdict
+ from functools import cache
+
+ import re
+ import unicodedata
+ import homoglyphs as hg
+
+
+ def normalization_strategy_lookup(strategy_name: str) -> object:
+     if strategy_name == "unicode":
+         return UnicodeSanitizer()
+     elif strategy_name == "homoglyphs":
+         return HomoglyphCanonizer()
+     elif strategy_name == "truecase":
+         return TrueCaser()
+     else:
+         # fail loudly instead of silently returning None for an unknown name
+         raise ValueError(f"Unknown normalization strategy: {strategy_name}")
+
+
+ class HomoglyphCanonizer:
+     """Attempts to detect homoglyph attacks and find a consistent canon.
+
+     This function does so on a per-ISO-category level. Language-level would also be possible (see commented code).
+     """
+
+     def __init__(self):
+         self.homoglyphs = None
+
+     def __call__(self, homoglyphed_str: str) -> str:
+         # find canon:
+         target_category, all_categories = self._categorize_text(homoglyphed_str)
+         homoglyph_table = self._select_canon_category_and_load(target_category, all_categories)
+         return self._sanitize_text(target_category, homoglyph_table, homoglyphed_str)
+
+     def _categorize_text(self, text: str) -> tuple:
+         iso_categories = defaultdict(int)
+         # self.iso_languages = defaultdict(int)
+
+         for char in text:
+             iso_categories[hg.Categories.detect(char)] += 1
+             # for lang in hg.Languages.detect(char):
+             #     self.iso_languages[lang] += 1
+         target_category = max(iso_categories, key=iso_categories.get)
+         all_categories = tuple(iso_categories)
+         return target_category, all_categories
+
+     @cache
+     def _select_canon_category_and_load(self, target_category: str, all_categories: tuple[str]) -> dict:
+         homoglyph_table = hg.Homoglyphs(categories=(target_category, "COMMON"))  # alphabet loaded here from file
+
+         source_alphabet = hg.Categories.get_alphabet(all_categories)
+         restricted_table = homoglyph_table.get_restricted_table(source_alphabet, homoglyph_table.alphabet)  # table loaded here from file
+         return restricted_table
+
+     def _sanitize_text(self, target_category: str, homoglyph_table: dict, homoglyphed_str: str) -> str:
+         sanitized_text = ""
+         for char in homoglyphed_str:
+             # langs = hg.Languages.detect(char)
+             cat = hg.Categories.detect(char) or ""  # detect() may return None for unknown chars
+             if target_category in cat or "COMMON" in cat or len(cat) == 0:
+                 sanitized_text += char
+             else:
+                 sanitized_text += list(homoglyph_table[char])[0]
+         return sanitized_text
+
+
+ class UnicodeSanitizer:
+     """Regex-based unicode sanitizer. Has different levels of granularity.
+
+     * ruleset="whitespaces" - attempts to remove only whitespace unicode characters
+     * ruleset="IDN.blacklist" - does its best to remove unusual unicode based on Network.IDN.blacklist characters
+     * ruleset="ascii" - brute-forces all text into ascii
+
+     This is unlikely to be a comprehensive list.
+
+     You can find a more comprehensive discussion at https://www.unicode.org/reports/tr36/
+     and https://www.unicode.org/faq/security.html
+     """
+
+     def __init__(self, ruleset="whitespaces"):
+         if ruleset == "whitespaces":
+
+             """Documentation:
+             \u00A0: Non-breaking space
+             \u1680: Ogham space mark
+             \u180E: Mongolian vowel separator
+             \u2000-\u200B: Various space characters, including en space, em space, thin space, hair space, zero-width space, and zero-width non-joiner
+             \u200C\u200D: Zero-width non-joiner and zero-width joiner
+             \u200E,\u200F: Left-to-right-mark, Right-to-left-mark
+             \u2060: Word joiner
+             \u2063: Invisible separator
+             \u202F: Narrow non-breaking space
+             \u205F: Medium mathematical space
+             \u3000: Ideographic space
+             \uFEFF: Zero-width non-breaking space
+             \uFFA0: Halfwidth hangul filler
+             \uFFF9\uFFFA\uFFFB: Interlinear annotation characters
+             \uFE00-\uFE0F: Variation selectors
+             \u202A-\u202F: Embedding characters
+             \u3164: Korean hangul filler.
+
+             Note that these characters are not always superfluous whitespace characters!
+             """
+
+             self.pattern = re.compile(
+                 r"[\u00A0\u1680\u180E\u2000-\u200B\u200C\u200D\u200E\u200F\u2060\u2063\u202F\u205F\u3000\uFEFF\uFFA0\uFFF9\uFFFA\uFFFB"
+                 r"\uFE00\uFE01\uFE02\uFE03\uFE04\uFE05\uFE06\uFE07\uFE08\uFE09\uFE0A\uFE0B\uFE0C\uFE0D\uFE0E\uFE0F\u3164\u202A\u202B\u202C\u202D"
+                 r"\u202E\u202F]"
+             )
+         elif ruleset == "IDN.blacklist":
+
+             """Documentation:
+             [\u00A0\u1680\u180E\u2000-\u200B\u202F\u205F\u2060\u2063\uFEFF]: Matches any whitespace characters in the Unicode character
+             set that are included in the IDN blacklist.
+             \uFFF9-\uFFFB: Matches characters that are not defined in Unicode but are used as language tags in various legacy encodings.
+             These characters are not allowed in domain names.
+             \uD800-\uDB7F: Matches the first part of a surrogate pair. Surrogate pairs are used to represent characters in the Unicode character
+             set that cannot be represented by a single 16-bit value. The first part of a surrogate pair is in the range U+D800 to U+DBFF,
+             and the second part is in the range U+DC00 to U+DFFF.
+             \uDB80-\uDBFF][\uDC00-\uDFFF]?: Matches the second part of a surrogate pair. The second part of a surrogate pair is in the range U+DC00
+             to U+DFFF, and is optional.
+             [\uDB40\uDC20-\uDB40\uDC7F][\uDC00-\uDFFF]: Matches certain invalid UTF-16 sequences which should not appear in IDNs.
+             """
+
+             self.pattern = re.compile(
+                 r"[\u00A0\u1680\u180E\u2000-\u200B\u202F\u205F\u2060\u2063\uFEFF\uFFF9-\uFFFB\uD800-\uDB7F\uDB80-\uDBFF]"
+                 r"[\uDC00-\uDFFF]?|[\uDB40\uDC20-\uDB40\uDC7F][\uDC00-\uDFFF]"
+             )
+         else:
+             """Documentation:
+             This is a simple restriction to "no-unicode", using only ascii characters. Control characters are included.
+             """
+             self.pattern = re.compile(r"[^\x00-\x7F]+")
+
+     def __call__(self, text: str) -> str:
+         text = unicodedata.normalize("NFC", text)  # canon forms
+         text = self.pattern.sub(" ", text)  # pattern match
+         text = re.sub(" +", " ", text)  # collapse whitespaces
+         text = "".join(c for c in text if unicodedata.category(c) != "Cc")  # remove any remaining non-printable characters
+         return text
+
+
+ class TrueCaser:
+     """True-casing is a capitalization normalization that returns text to its original capitalization.
+
+     This defends against attacks that wRIte TeXt lIkE spOngBoB.
+
+     Here, a simple POS-tagger is used.
+     """
+
+     uppercase_pos = ["PROPN"]  # POS tags that should be upper-cased
+
+     def __init__(self, backend="spacy"):
+         if backend == "spacy":
+             import spacy
+
+             self.nlp = spacy.load("en_core_web_sm")
+             self.normalize_fn = self._spacy_truecasing
+         else:
+             from nltk import pos_tag, word_tokenize  # noqa
+             import nltk
+
+             nltk.download("punkt")
+             nltk.download("averaged_perceptron_tagger")
+             nltk.download("universal_tagset")
+             self.normalize_fn = self._nltk_truecasing
+
+     def __call__(self, random_capitalized_string: str) -> str:
+         truecased_str = self.normalize_fn(random_capitalized_string)
+         return truecased_str
+
+     def _spacy_truecasing(self, random_capitalized_string: str):
+         doc = self.nlp(random_capitalized_string.lower())
+         POS = self.uppercase_pos
+         truecased_str = "".join([w.text_with_ws.capitalize() if w.pos_ in POS or w.is_sent_start else w.text_with_ws for w in doc])
+         return truecased_str
+
+     def _nltk_truecasing(self, random_capitalized_string: str):
+         from nltk import pos_tag, word_tokenize
+         import nltk
+
+         nltk.download("punkt")
+         nltk.download("averaged_perceptron_tagger")
+         nltk.download("universal_tagset")
+         POS = ["NNP", "NNPS"]  # proper-noun tags in the Penn Treebank tagset
+
+         tagged_text = pos_tag(word_tokenize(random_capitalized_string.lower()))
+         truecased_str = " ".join([w.capitalize() if p in POS else w for (w, p) in tagged_text])
+         return truecased_str
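
For reference, a minimal sketch of composing these normalizers the way WatermarkDetector does. The truecase strategy is omitted because it loads a tagger model on first use, and the homoglyph step again assumes confusables_sept2022.json in the working directory:

# Chain the normalizers over a string containing cyrillic lookalike characters.
from normalizers import normalization_strategy_lookup

text = "sоmе tехt with lookalikes"  # 'о', 'е', 'х' here are cyrillic, not latin
for name in ["unicode", "homoglyphs"]:
    text = normalization_strategy_lookup(name)(text)
print(text)  # expected: the same string with latin characters restored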
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ homoglyphs_fork
+ nltk
+ scipy
+ torch
+ transformers
+ tokenizers
watermark_processor.py ADDED
@@ -0,0 +1,281 @@
+ # coding=utf-8
+ # Copyright 2023 Authors of "A Watermark for Large Language Models"
+ # available at https://arxiv.org/abs/2301.10226
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from __future__ import annotations
+ import collections
+ from math import sqrt
+
+ import scipy.stats
+
+ import torch
+ from torch import Tensor
+ from tokenizers import Tokenizer
+ from transformers import LogitsProcessor
+
+ from nltk.util import ngrams
+
+ from normalizers import normalization_strategy_lookup
+
+ class WatermarkBase:
+     def __init__(
+         self,
+         vocab: list[int] = None,
+         gamma: float = 0.5,
+         delta: float = 2.0,
+         seeding_scheme: str = "markov_1",  # mostly unused/always default
+         hash_key: int = 15485863,  # just a large prime number to create a rng seed with sufficient bit width
+         select_green_tokens: bool = True,
+     ):
+
+         # watermarking parameters
+         self.vocab = vocab
+         self.vocab_size = len(vocab)
+         self.gamma = gamma
+         self.delta = delta
+         self.seeding_scheme = seeding_scheme
+         self.rng = None
+         self.hash_key = hash_key
+         self.select_green_tokens = select_green_tokens
+
+     def _seed_rng(self, input_ids: torch.LongTensor, seeding_scheme: str = None) -> None:
+         # can optionally override the seeding scheme,
+         # but uses the instance attr by default
+         if seeding_scheme is None:
+             seeding_scheme = self.seeding_scheme
+
+         if seeding_scheme == "markov_1":
+             assert input_ids.shape[-1] >= 1, f"seeding_scheme={seeding_scheme} requires at least a 1 token prefix sequence to seed rng"
+             prev_token = input_ids[-1].item()
+             self.rng.manual_seed(self.hash_key * prev_token)
+         else:
+             raise NotImplementedError(f"Unexpected seeding_scheme: {seeding_scheme}")
+         return
+
+     def _get_greenlist_ids(self, input_ids: torch.LongTensor) -> torch.LongTensor:
+         # seed the rng using the previous tokens/prefix
+         # according to the seeding_scheme
+         self._seed_rng(input_ids)
+
+         greenlist_size = int(self.vocab_size * self.gamma)
+         vocab_permutation = torch.randperm(self.vocab_size, device=input_ids.device, generator=self.rng)
+         if self.select_green_tokens:  # directly
+             greenlist_ids = vocab_permutation[:greenlist_size]  # new
+         else:  # select green via red
+             greenlist_ids = vocab_permutation[(self.vocab_size - greenlist_size):]  # legacy behavior
+         return greenlist_ids
+
+
+ class WatermarkLogitsProcessor(WatermarkBase, LogitsProcessor):
+
+     # FIXME maybe make this explicit instead of args/kwargs
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+
+     def _calc_greenlist_mask(self, scores: torch.FloatTensor, greenlist_token_ids) -> torch.BoolTensor:
+         # TODO let's see if we can lose this loop
+         green_tokens_mask = torch.zeros_like(scores)
+         for b_idx in range(len(greenlist_token_ids)):
+             green_tokens_mask[b_idx][greenlist_token_ids[b_idx]] = 1
+         final_mask = green_tokens_mask.bool()
+         return final_mask
+
+     def _bias_greenlist_logits(self, scores: torch.Tensor, greenlist_mask: torch.Tensor, greenlist_bias: float) -> torch.Tensor:
+         scores[greenlist_mask] = scores[greenlist_mask] + greenlist_bias
+         return scores
+
+     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+
+         # this is lazy to allow us to colocate on the watermarked model's device
+         if self.rng is None:
+             self.rng = torch.Generator(device=input_ids.device)
+
+         # NOTE, it would be nice to get rid of this batch loop, but currently,
+         # the seed and partition operations are not tensor/vectorized, thus
+         # each sequence in the batch needs to be treated separately.
+         batched_greenlist_ids = [None for _ in range(input_ids.shape[0])]
+
+         for b_idx in range(input_ids.shape[0]):
+             greenlist_ids = self._get_greenlist_ids(input_ids[b_idx])
+             batched_greenlist_ids[b_idx] = greenlist_ids
+
+         green_tokens_mask = self._calc_greenlist_mask(scores=scores, greenlist_token_ids=batched_greenlist_ids)
+
+         scores = self._bias_greenlist_logits(scores=scores, greenlist_mask=green_tokens_mask, greenlist_bias=self.delta)
+         return scores
+
+
+ class WatermarkDetector(WatermarkBase):
+     def __init__(
+         self,
+         *args,
+         device: torch.device = None,
+         tokenizer: Tokenizer = None,
+         z_threshold: float = 4.0,
+         normalizers: list[str] = ["unicode"],  # or also: ["unicode", "homoglyphs", "truecase"]
+         ignore_repeated_bigrams: bool = False,
+         **kwargs,
+     ):
+         super().__init__(*args, **kwargs)
+         # also configure the metrics returned/preprocessing options
+         assert device, "Must pass device"
+         assert tokenizer, "Need an instance of the generating tokenizer to perform detection"
+
+         self.tokenizer = tokenizer
+         self.device = device
+         self.z_threshold = z_threshold
+         self.rng = torch.Generator(device=self.device)
+
+         if self.seeding_scheme == "markov_1":
+             self.min_prefix_len = 1
+         else:
+             raise NotImplementedError(f"Unexpected seeding_scheme: {self.seeding_scheme}")
+
+         self.normalizers = []
+         for normalization_strategy in normalizers:
+             self.normalizers.append(normalization_strategy_lookup(normalization_strategy))
+
+         self.ignore_repeated_bigrams = ignore_repeated_bigrams
+         if self.ignore_repeated_bigrams:
+             assert self.seeding_scheme == "markov_1", "No repeated bigram credit variant assumes the single token seeding scheme."
+
+     def _compute_z_score(self, observed_count, T):
+         # observed_count is the number of green tokens, T is the total number of tokens scored
+         expected_count = self.gamma
+         numer = observed_count - expected_count * T
+         denom = sqrt(T * expected_count * (1 - expected_count))
+         z = numer / denom
+         return z
+
+     def _compute_p_value(self, z):
+         p_value = scipy.stats.norm.sf(z)
+         return p_value
+
+     def _score_sequence(
+         self,
+         input_ids: Tensor,
+         return_num_tokens_scored: bool = True,
+         return_num_green_tokens: bool = True,
+         return_green_fraction: bool = True,
+         return_green_token_mask: bool = False,
+         return_z_score: bool = True,
+         return_p_value: bool = True,
+     ):
+         if self.ignore_repeated_bigrams:
+             # Method that only counts a green/red hit once per unique bigram.
+             # The new total number of tokens scored (T) becomes the number of unique bigrams.
+             # We iterate over all unique token bigrams in the input, computing the greenlist
+             # induced by the first token in each, and then checking whether the second
+             # token falls in that greenlist.
+             assert not return_green_token_mask, "Can't return the green/red mask when ignoring repeats."
+             bigram_table = {}
+             token_bigram_generator = ngrams(input_ids.cpu().tolist(), 2)
+             freq = collections.Counter(token_bigram_generator)
+             num_tokens_scored = len(freq.keys())
+             for bigram in freq.keys():
+                 prefix = torch.tensor([bigram[0]], device=self.device)  # expects a 1-d prefix tensor on the randperm device
+                 greenlist_ids = self._get_greenlist_ids(prefix)
+                 bigram_table[bigram] = bigram[1] in greenlist_ids
+             green_token_count = sum(bigram_table.values())
+         else:
+             num_tokens_scored = len(input_ids) - self.min_prefix_len
+             if num_tokens_scored < 1:
+                 raise ValueError(("Must have at least 1 token to score after "
+                                   f"the first min_prefix_len={self.min_prefix_len} tokens required by the seeding scheme."))
+             # Standard method.
+             # Since we generally need at least 1 token (for the simplest scheme),
+             # we start the iteration over the token sequence with a minimum
+             # number of tokens as the first prefix for the seeding scheme,
+             # and at each step, compute the greenlist induced by the
+             # current prefix and check if the current token falls in the greenlist.
+             green_token_count, green_token_mask = 0, []
+             for idx in range(self.min_prefix_len, len(input_ids)):
+                 curr_token = input_ids[idx]
+                 greenlist_ids = self._get_greenlist_ids(input_ids[:idx])
+                 if curr_token in greenlist_ids:
+                     green_token_count += 1
+                     green_token_mask.append(True)
+                 else:
+                     green_token_mask.append(False)
+
+         score_dict = dict()
+         if return_num_tokens_scored:
+             score_dict.update(dict(num_tokens_scored=num_tokens_scored))
+         if return_num_green_tokens:
+             score_dict.update(dict(num_green_tokens=green_token_count))
+         if return_z_score:
+             score_dict.update(dict(z_score=self._compute_z_score(green_token_count, num_tokens_scored)))
+         if return_p_value:
+             z_score = score_dict.get("z_score")
+             if z_score is None:
+                 z_score = self._compute_z_score(green_token_count, num_tokens_scored)
+             score_dict.update(dict(p_value=self._compute_p_value(z_score)))
+         if return_green_fraction:
+             score_dict.update(dict(green_fraction=(green_token_count / num_tokens_scored)))
+         if return_green_token_mask:
+             score_dict.update(dict(green_token_mask=green_token_mask))
+
+         return score_dict
+
+     def detect(
+         self,
+         text: str = None,
+         tokenized_text: list[int] = None,
+         return_prediction: bool = True,
+         return_scores: bool = True,
+         z_threshold: float = None,
+         **kwargs,
+     ) -> dict:
+
+         assert (text is not None) ^ (tokenized_text is not None), "Must pass either the raw or tokenized string"
+         if return_prediction:
+             kwargs["return_p_value"] = True  # to return the "confidence":=1-p of positive detections
+
+         # run optional normalizers on text (they operate on the raw string only)
+         if text is not None:
+             for normalizer in self.normalizers:
+                 text = normalizer(text)
+             if len(self.normalizers) > 0:
+                 print(f"Text after normalization:\n\n{text}\n")
+
+         if tokenized_text is None:
+             assert self.tokenizer is not None, ("Watermark detection on a raw string "
+                                                 "requires an instance of the tokenizer "
+                                                 "that was used at generation time.")
+             tokenized_text = self.tokenizer(text, return_tensors="pt", add_special_tokens=False)["input_ids"][0].to(self.device)
+             if tokenized_text[0] == self.tokenizer.bos_token_id:
+                 tokenized_text = tokenized_text[1:]
+         else:
+             # try to remove the bos_tok at the beginning if it's there
+             if (self.tokenizer is not None) and (tokenized_text[0] == self.tokenizer.bos_token_id):
+                 tokenized_text = tokenized_text[1:]
+
+         # call the score method
+         output_dict = {}
+         score_dict = self._score_sequence(tokenized_text, **kwargs)
+         if return_scores:
+             output_dict.update(score_dict)
+         # if return_prediction was passed, perform the hypothesis test and return the outcome
+         if return_prediction:
+             z_threshold = z_threshold if z_threshold else self.z_threshold
+             assert z_threshold is not None, "Need a threshold in order to decide the outcome of the detection test"
+             output_dict["prediction"] = score_dict["z_score"] > z_threshold
+             if output_dict["prediction"]:
+                 output_dict["confidence"] = 1 - score_dict["p_value"]
+
+         return output_dict
+
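
For reference, a worked sketch of the hypothesis test implemented by _compute_z_score and _compute_p_value. With gamma=0.25 and T=200 scored tokens, about 50 green tokens are expected under the null hypothesis (human-written text); observing 90 gives z ≈ 6.5, far above the default threshold of 4.0. The counts below are illustrative, not outputs of a real run:

# Worked example of the detection statistic (illustrative numbers).
from math import sqrt
import scipy.stats

gamma, T, green_count = 0.25, 200, 90
z = (green_count - gamma * T) / sqrt(T * gamma * (1 - gamma))  # (90 - 50) / sqrt(37.5)
p = scipy.stats.norm.sf(z)  # one-sided survival function of the standard normal
print(f"z={z:.2f}, p={p:.2e}, prediction={z > 4.0}")  # z=6.53, so detection fires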