anonymous-aardvark committed
Commit 85ecf02 (0 parents)

initial commit
.gitattributes ADDED
@@ -0,0 +1,34 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,13 @@
+ ---
+ title: Submission 2841 Demo
+ emoji: 📊
+ colorFrom: pink
+ colorTo: gray
+ sdk: gradio
+ sdk_version: 3.21.0
+ app_file: app.py
+ pinned: false
+ python_version: 3.10.6
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,46 @@
+ # coding=utf-8
+ # Copyright 2023 Anonymous Authors of "A Watermark for Large Language Models"
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from argparse import Namespace
+ args = Namespace()
+
+ arg_dict = {
+     'run_gradio': True,
+     'demo_public': False,
+     'model_name_or_path': 'bigscience/bloom',
+     'load_fp16': False,
+     'prompt_max_length': None,
+     'max_new_tokens': 200,
+     'generation_seed': 123,
+     'use_sampling': True,
+     'n_beams': 1,
+     'sampling_temp': 0.7,
+     'use_gpu': True,
+     'seeding_scheme': 'simple_1',
+     'gamma': 0.5,
+     'delta': 2.0,
+     'normalizers': '',
+     'ignore_repeated_bigrams': False,
+     'detection_z_threshold': 4.0,
+     'select_green_tokens': True,
+     'skip_model_load': True,
+     'seed_separately': True,
+ }
+
+ args.__dict__.update(arg_dict)
+
+ from demo_watermark import main
+
+ main(args)
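The configuration above pins `model_name_or_path` to `bigscience/bloom` and sets `skip_model_load: True`, so the Space generates through the hosted inference API rather than loading weights locally. The same entry point can also drive a locally loaded checkpoint; below is a minimal sketch of such a variant (the `facebook/opt-1.3b` checkpoint and the specific flag values are illustrative assumptions, not part of this commit):

```python
# Hypothetical local variant of app.py: load a small causal LM's weights
# instead of routing generation through the HF inference API.
from argparse import Namespace

from demo_watermark import main

args = Namespace()
args.__dict__.update({
    'run_gradio': True,
    'demo_public': False,
    'model_name_or_path': 'facebook/opt-1.3b',  # assumption: any HF causal LM works
    'load_fp16': False,
    'prompt_max_length': None,
    'max_new_tokens': 200,
    'generation_seed': 123,
    'use_sampling': True,
    'n_beams': 1,
    'sampling_temp': 0.7,
    'use_gpu': True,
    'seeding_scheme': 'simple_1',
    'gamma': 0.25,
    'delta': 2.0,
    'normalizers': '',
    'ignore_repeated_bigrams': False,
    'detection_z_threshold': 4.0,
    'select_green_tokens': True,
    'skip_model_load': False,  # actually load the model weights
    'seed_separately': True,
})

main(args)
```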
demo_watermark.py ADDED
@@ -0,0 +1,914 @@
+ # coding=utf-8
+ # Copyright 2023 Anonymous Authors of "A Watermark for Large Language Models"
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import os
+ import argparse
+ from pprint import pprint
+ from functools import partial
+
+ import numpy  # for gradio hot reload
+ import gradio as gr
+
+ import torch
+
+ from transformers import (AutoTokenizer,
+                           AutoModelForSeq2SeqLM,
+                           AutoModelForCausalLM,
+                           LogitsProcessorList)
+
+ from transformers import GPT2TokenizerFast
+ OPT_TOKENIZER = GPT2TokenizerFast
+
+ from watermark_processor import WatermarkLogitsProcessor, WatermarkDetector
+
+ API_MODEL_MAP = {
+     "bigscience/bloom": {"max_length": 1000, "gamma": 0.5, "delta": 2.0},
+     "bigscience/bloomz": {"max_length": 1000, "gamma": 0.5, "delta": 2.0},
+     "google/flan-ul2": {"max_length": 1000, "gamma": 0.5, "delta": 2.0},
+     "google/flan-t5-xxl": {"max_length": 1000, "gamma": 0.5, "delta": 2.0},
+     "EleutherAI/gpt-neox-20b": {"max_length": 1000, "gamma": 0.5, "delta": 2.0},
+ }
+
+ def str2bool(v):
+     """Util function for user-friendly boolean flag args"""
+     if isinstance(v, bool):
+         return v
+     if v.lower() in ('yes', 'true', 't', 'y', '1'):
+         return True
+     elif v.lower() in ('no', 'false', 'f', 'n', '0'):
+         return False
+     else:
+         raise argparse.ArgumentTypeError('Boolean value expected.')
+
+ def parse_args():
+     """Command line argument specification"""
+
+     parser = argparse.ArgumentParser(description="A minimum working example of applying the watermark to any LLM that supports the huggingface 🤗 `generate` API")
+
+     parser.add_argument(
+         "--run_gradio",
+         type=str2bool,
+         default=True,
+         help="Whether to launch as a gradio demo. Set to False if gradio is not installed and you just want the stdout version.",
+     )
+     parser.add_argument(
+         "--demo_public",
+         type=str2bool,
+         default=False,
+         help="Whether to expose the gradio demo to the internet.",
+     )
+     parser.add_argument(
+         "--model_name_or_path",
+         type=str,
+         default="facebook/opt-6.7b",
+         help="Main model, path to pretrained model or model identifier from huggingface.co/models.",
+     )
+     parser.add_argument(
+         "--prompt_max_length",
+         type=int,
+         default=None,
+         help="Truncation length for the prompt, overrides the model config's max length field.",
+     )
+     parser.add_argument(
+         "--max_new_tokens",
+         type=int,
+         default=200,
+         help="Maximum number of new tokens to generate.",
+     )
+     parser.add_argument(
+         "--generation_seed",
+         type=int,
+         default=123,
+         help="Seed for setting the torch global rng prior to generation.",
+     )
+     parser.add_argument(
+         "--use_sampling",
+         type=str2bool,
+         default=True,
+         help="Whether to generate using multinomial sampling.",
+     )
+     parser.add_argument(
+         "--sampling_temp",
+         type=float,
+         default=0.7,
+         help="Sampling temperature to use when generating using multinomial sampling.",
+     )
+     parser.add_argument(
+         "--n_beams",
+         type=int,
+         default=1,
+         help="Number of beams to use for beam search. 1 is normal greedy decoding.",
+     )
+     parser.add_argument(
+         "--use_gpu",
+         type=str2bool,
+         default=True,
+         help="Whether to run inference and watermark hashing/seeding/permutation on gpu.",
+     )
+     parser.add_argument(
+         "--seeding_scheme",
+         type=str,
+         default="simple_1",
+         help="Seeding scheme to use to generate the greenlists at each generation and verification step.",
+     )
+     parser.add_argument(
+         "--gamma",
+         type=float,
+         default=0.25,
+         help="The fraction of the vocabulary to partition into the greenlist at each generation and verification step.",
+     )
+     parser.add_argument(
+         "--delta",
+         type=float,
+         default=2.0,
+         help="The amount/bias to add to each of the greenlist token logits before each token sampling step.",
+     )
+     parser.add_argument(
+         "--normalizers",
+         type=str,
+         default="",
+         help="Single or comma separated list of the preprocessor/normalizer names to use when performing watermark detection.",
+     )
+     parser.add_argument(
+         "--ignore_repeated_bigrams",
+         type=str2bool,
+         default=False,
+         help="Whether to use the detection method that only counts each unique bigram once as either a green or red hit.",
+     )
+     parser.add_argument(
+         "--detection_z_threshold",
+         type=float,
+         default=4.0,
+         help="The test statistic threshold for the detection hypothesis test.",
+     )
+     parser.add_argument(
+         "--select_green_tokens",
+         type=str2bool,
+         default=True,
+         help="How to treat the permutation when selecting the greenlist tokens at each step. Legacy behavior (False) picks the complement/red tokens first.",
+     )
+     parser.add_argument(
+         "--skip_model_load",
+         type=str2bool,
+         default=False,
+         help="Skip the model loading to debug the interface.",
+     )
+     parser.add_argument(
+         "--seed_separately",
+         type=str2bool,
+         default=True,
+         help="Whether to call the torch seed function before both the unwatermarked and watermarked generate calls.",
+     )
+     parser.add_argument(
+         "--load_fp16",
+         type=str2bool,
+         default=False,
+         help="Whether to run the model in float16 precision.",
+     )
+     args = parser.parse_args()
+     return args
+
+ def load_model(args):
+     """Load and return the model and tokenizer"""
+
+     args.is_seq2seq_model = any([(model_type in args.model_name_or_path) for model_type in ["t5", "T0"]])
+     args.is_decoder_only_model = any([(model_type in args.model_name_or_path) for model_type in ["gpt", "opt", "bloom"]])
+     if args.is_seq2seq_model:
+         model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name_or_path)
+     elif args.is_decoder_only_model:
+         if args.load_fp16:
+             model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, torch_dtype=torch.float16, device_map='auto')
+         else:
+             model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path)
+     else:
+         raise ValueError(f"Unknown model type: {args.model_name_or_path}")
+
+     if args.use_gpu:
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+         if args.load_fp16:
+             pass  # model was already placed on device(s) via device_map='auto'
+         else:
+             model = model.to(device)
+     else:
+         device = "cpu"
+     model.eval()
+
+     tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
+
+     return model, tokenizer, device
+
+
+ from text_generation import InferenceAPIClient
+ from requests.exceptions import ReadTimeout
+ def generate_with_api(prompt, args):
+     hf_api_key = os.environ.get("HF_API_KEY")
+     if hf_api_key is None:
+         raise ValueError("HF_API_KEY environment variable not set, cannot use HF API to generate text.")
+
+     client = InferenceAPIClient(args.model_name_or_path, token=hf_api_key, timeout=60)
+
+     assert args.n_beams == 1, "HF API models do not support beam search."
+     generation_params = {
+         "max_new_tokens": args.max_new_tokens,
+         "do_sample": args.use_sampling,
+     }
+     if args.use_sampling:
+         generation_params["temperature"] = args.sampling_temp
+         generation_params["seed"] = args.generation_seed
+
+     timeout_msg = "[Model API timeout error. Try reducing the max_new_tokens parameter or the prompt length.]"
+     try:
+         generation_params["watermark"] = False
+         without_watermark_iterator = client.generate_stream(prompt, **generation_params)
+     except ReadTimeout as e:
+         print(e)
+         without_watermark_iterator = (char for char in timeout_msg)
+     try:
+         generation_params["watermark"] = True
+         with_watermark_iterator = client.generate_stream(prompt, **generation_params)
+     except ReadTimeout as e:
+         print(e)
+         with_watermark_iterator = (char for char in timeout_msg)
+
+     all_without_words, all_with_words = "", ""
+     for without_word, with_word in zip(without_watermark_iterator, with_watermark_iterator):
+         all_without_words += without_word.token.text
+         all_with_words += with_word.token.text
+         yield all_without_words, all_with_words
+
+
+ def check_prompt(prompt, args, tokenizer, model=None, device=None):
+
+     # This applies to both the local and API model scenarios
+     if args.model_name_or_path in API_MODEL_MAP:
+         args.prompt_max_length = API_MODEL_MAP[args.model_name_or_path]["max_length"]
+     elif hasattr(model.config, "max_position_embeddings"):
+         args.prompt_max_length = model.config.max_position_embeddings - args.max_new_tokens
+     else:
+         args.prompt_max_length = 2048 - args.max_new_tokens
+
+     tokd_input = tokenizer(prompt, return_tensors="pt", add_special_tokens=True, truncation=True, max_length=args.prompt_max_length).to(device)
+     truncation_warning = True if tokd_input["input_ids"].shape[-1] == args.prompt_max_length else False
+     redecoded_input = tokenizer.batch_decode(tokd_input["input_ids"], skip_special_tokens=True)[0]
+
+     return (redecoded_input,
+             int(truncation_warning),
+             args)
+
+
+
+ def generate(prompt, args, tokenizer, model=None, device=None):
+     """Instantiate the WatermarkLogitsProcessor according to the watermark parameters
+     and generate watermarked text by passing it to the generate method of the model
+     as a logits processor."""
+
+     print(f"Generating with {args}")
+     print(f"Prompt: {prompt}")
+
+     if args.model_name_or_path in API_MODEL_MAP:
+         api_outputs = generate_with_api(prompt, args)
+         yield from api_outputs
+     else:
+         tokd_input = tokenizer(prompt, return_tensors="pt", add_special_tokens=True, truncation=True, max_length=args.prompt_max_length).to(device)
+
+         watermark_processor = WatermarkLogitsProcessor(vocab=list(tokenizer.get_vocab().values()),
+                                                        gamma=args.gamma,
+                                                        delta=args.delta,
+                                                        seeding_scheme=args.seeding_scheme,
+                                                        select_green_tokens=args.select_green_tokens)
+
+         gen_kwargs = dict(max_new_tokens=args.max_new_tokens)
+
+         if args.use_sampling:
+             gen_kwargs.update(dict(
+                 do_sample=True,
+                 top_k=0,
+                 temperature=args.sampling_temp
+             ))
+         else:
+             gen_kwargs.update(dict(
+                 num_beams=args.n_beams
+             ))
+
+         generate_without_watermark = partial(
+             model.generate,
+             **gen_kwargs
+         )
+         generate_with_watermark = partial(
+             model.generate,
+             logits_processor=LogitsProcessorList([watermark_processor]),
+             **gen_kwargs
+         )
+
+         torch.manual_seed(args.generation_seed)
+         output_without_watermark = generate_without_watermark(**tokd_input)
+
+         # optional to seed before the second generation, but the outputs will generally not match anyway, unless delta==0.0 (a no-op watermark)
+         if args.seed_separately:
+             torch.manual_seed(args.generation_seed)
+         output_with_watermark = generate_with_watermark(**tokd_input)
+
+         if args.is_decoder_only_model:
+             # need to isolate the newly generated tokens
+             output_without_watermark = output_without_watermark[:, tokd_input["input_ids"].shape[-1]:]
+             output_with_watermark = output_with_watermark[:, tokd_input["input_ids"].shape[-1]:]
+
+         decoded_output_without_watermark = tokenizer.batch_decode(output_without_watermark, skip_special_tokens=True)[0]
+         decoded_output_with_watermark = tokenizer.batch_decode(output_with_watermark, skip_special_tokens=True)[0]
+
+         # mock the API outputs in a whitespace-split generator style
+         all_without_words, all_with_words = "", ""
+         for without_word, with_word in zip(decoded_output_without_watermark.split(), decoded_output_with_watermark.split()):
+             all_without_words += without_word + " "
+             all_with_words += with_word + " "
+             yield all_without_words, all_with_words
+
+
+ def format_names(s):
+     """Format names for the gradio demo interface"""
+     s = s.replace("num_tokens_scored", "Tokens Counted (T)")
+     s = s.replace("num_green_tokens", "# Tokens in Greenlist")
+     s = s.replace("green_fraction", "Fraction of T in Greenlist")
+     s = s.replace("z_score", "z-score")
+     s = s.replace("p_value", "p value")
+     s = s.replace("prediction", "Prediction")
+     s = s.replace("confidence", "Confidence")
+     return s
+
+ def list_format_scores(score_dict, detection_threshold):
+     """Format the detection metrics into a gradio dataframe input format"""
+     lst_2d = []
+     for k, v in score_dict.items():
+         if k == 'green_fraction':
+             lst_2d.append([format_names(k), f"{v:.1%}"])
+         elif k == 'confidence':
+             lst_2d.append([format_names(k), f"{v:.3%}"])
+         elif isinstance(v, float):
+             lst_2d.append([format_names(k), f"{v:.3g}"])
+         elif isinstance(v, bool):
+             lst_2d.append([format_names(k), ("Watermarked" if v else "Human/Unwatermarked")])
+         else:
+             lst_2d.append([format_names(k), f"{v}"])
+     if "confidence" in score_dict:
+         lst_2d.insert(-2, ["z-score Threshold", f"{detection_threshold}"])
+     else:
+         lst_2d.insert(-1, ["z-score Threshold", f"{detection_threshold}"])
+     return lst_2d
+
+ def detect(input_text, args, tokenizer, device=None, return_green_token_mask=True):
+     """Instantiate the WatermarkDetector object and call detect on
+     the input text, returning the scores and outcome of the test"""
+
+     print(f"Detecting with {args}")
+     print(f"Detection Tokenizer: {type(tokenizer)}")
+
+     watermark_detector = WatermarkDetector(vocab=list(tokenizer.get_vocab().values()),
+                                            gamma=args.gamma,
+                                            seeding_scheme=args.seeding_scheme,
+                                            device=device,
+                                            tokenizer=tokenizer,
+                                            z_threshold=args.detection_z_threshold,
+                                            normalizers=args.normalizers,
+                                            ignore_repeated_bigrams=args.ignore_repeated_bigrams,
+                                            select_green_tokens=args.select_green_tokens)
+     # for now, just don't display the green token mask
+     # if we're using normalizers or ignore_repeated_bigrams
+     if args.normalizers != [] or args.ignore_repeated_bigrams:
+         return_green_token_mask = False
+
+     error = False
+     green_token_mask = None
+     if input_text == "":
+         error = True
+     else:
+         try:
+             score_dict = watermark_detector.detect(input_text, return_green_token_mask=return_green_token_mask)
+             green_token_mask = score_dict.pop("green_token_mask", None)
+             output = list_format_scores(score_dict, watermark_detector.z_threshold)
+         except ValueError as e:
+             print(e)
+             error = True
+     if error:
+         output = [["Error", "string too short to compute metrics"]]
+         output += [["", ""] for _ in range(6)]
+
+     html_output = "[No highlight markup generated]"
+     if green_token_mask is not None:
+         # workaround: we need a fast tokenizer with charspan support
+         if "opt" in args.model_name_or_path:
+             tokenizer = OPT_TOKENIZER.from_pretrained(args.model_name_or_path)
+
+         tokens = tokenizer(input_text)
+         if tokens["input_ids"][0] == tokenizer.bos_token_id:
+             tokens["input_ids"] = tokens["input_ids"][1:]  # ignore attention mask
+         skip = watermark_detector.min_prefix_len
+         charspans = [tokens.token_to_chars(i) for i in range(skip, len(tokens["input_ids"]))]
+         charspans = [cs for cs in charspans if cs is not None]  # remove the special token spans
+
+         assert len(charspans) == len(green_token_mask)
+
+         tags = [(f'<span class="green">{input_text[cs.start:cs.end]}</span>' if m else f'<span class="red">{input_text[cs.start:cs.end]}</span>') for cs, m in zip(charspans, green_token_mask)]
+         html_output = f'<p>{" ".join(tags)}</p>'
+
+     return output, args, tokenizer, html_output
+
+ def run_gradio(args, model=None, device=None, tokenizer=None):
+     """Define and launch the gradio demo interface"""
+     check_prompt_partial = partial(check_prompt, model=model, device=device)
+     generate_partial = partial(generate, model=model, device=device)
+     detect_partial = partial(detect, device=device)
+
+
+     css = """
+     .green { color: black!important;line-height:1.9em; padding: 0.2em 0.2em; background: #ccffcc; border-radius:0.5rem;}
+     .red { color: black!important;line-height:1.9em; padding: 0.2em 0.2em; background: #ffad99; border-radius:0.5rem;}
+     """
+
+     with gr.Blocks(css=css) as demo:
+         # Top section, greeting and instructions
+         with gr.Row():
+             with gr.Column(scale=9):
+                 gr.Markdown(
+                     """
+                     ## 💧 A Watermark for Large Language Models 🔍
+
+                     Demo made possible by the HuggingFace 🤗 [text-generation-inference](https://github.com/huggingface/text-generation-inference) serving framework.
+                     """
+                 )
+             with gr.Column(scale=1):
+                 # if the model_name_or_path at startup is not one of the API models, add it to the dropdown
+                 all_models = sorted(list(set(list(API_MODEL_MAP.keys()) + [args.model_name_or_path])))
+                 model_selector = gr.Dropdown(
+                     all_models,
+                     value=args.model_name_or_path,
+                     label="Language Model",
+                 )
+
+         # Construct state for parameters, define updates and toggles
+         default_prompt = args.__dict__.pop("default_prompt")
+         session_args = gr.State(value=args)
+         # note: the State object automatically calls `value` if it's callable; wrap the tokenizer in a lambda to avoid calling it at startup
+         session_tokenizer = gr.State(value=lambda: tokenizer)
+
+         with gr.Tab("Welcome"):
+             with gr.Row():
+                 with gr.Column(scale=2):
+                     gr.Markdown(
+                         """
+                         Potential harms of large language models can be mitigated by *watermarking* a model's output.
+                         *Watermarks* are embedded signals in the generated text that are invisible to humans but algorithmically
+                         detectable, allowing *anyone* to later check whether a given span of text
+                         was likely to have been generated by a model that uses the watermark.
+
+                         This space showcases a watermarking approach that can be applied to _any_ generative language model.
+                         For demonstration purposes, the space demos a selection of multi-billion parameter models (see the following note for caveats).
+                         """
+                     )
+                     with gr.Accordion("A note on the available models:", open=False):
+                         gr.Markdown(
+                             """
+                             This demo uses open-source language models. Today, these models are less powerful than proprietary commercial tools like ChatGPT, Claude, Bard, or Bing/Sydney.
+
+                             Models like [BLOOM (175B)](https://huggingface.co/bigscience/bloom) are designed to "complete" your prompt, and are not fine-tuned to follow instructions.
+                             For best results, prompt that model with a few sentences that form the beginning of a paragraph, and then allow it to "continue" your paragraph.
+                             Some examples include the opening paragraph of a Wikipedia article, or the first few sentences of a story.
+                             Longer prompts that end mid-sentence will result in more fluent generations.
+
+                             Some of the models available in this demo are fine-tuned to follow instructions but have different strengths and will showcase different
+                             types of watermark behavior. [BLOOMZ](https://huggingface.co/bigscience/bloomz) is an instruction-tuned variant of BLOOM capable of following instructions in dozens of languages zero-shot,
+                             and can generate long and coherent paragraphs and stories given the right prompt.
+                             The FLAN models [FLAN-t5-xxl (11B)](https://huggingface.co/google/flan-t5-xxl) and [FLAN-UL2 (20B)](https://huggingface.co/google/flan-ul2) are fine-tuned on a variety of in-context few-shot learning NLP tasks,
+                             such as reasoning and question answering.
+
+                             Generally, short, low-entropy scenarios, where the model has very few choices of correct/suitable responses to the prompt,
+                             will not exhibit as strong a watermark presence, while longer watermarked outputs will produce higher detection statistics.
+                             """
+                         )
+                     gr.Markdown(
+                         """
+                         **[Generate & Detect]**: The first tab shows that the watermark can be embedded with
+                         negligible impact on text quality. You can try any prompt and compare the quality of
+                         normal text (*Output Without Watermark*) to the watermarked text (*Output With Watermark*) below it.
+                         You can also "see" the watermark by looking at the **Highlighted** tab, where the tokens are
+                         colored green or red depending on which list they are in.
+                         Metrics on the right show that the watermark can be reliably detected given a reasonably small number of tokens (25-50).
+                         Detection is very efficient and does not use the language model or its parameters.
+
+                         **[Detector Only]**: You can also copy-paste the watermarked text (or any other text)
+                         into the second tab. This can be used to see how many sentences you could remove and still detect the watermark.
+                         You can also verify here that the detection has, by design, a low false-positive rate;
+                         this means that human-generated text that you copy into this detector will not be marked as machine-generated.
+                         """
+                     )
+
+                 with gr.Column(scale=1):
+                     gr.Markdown(
+                         """
+                         ![](https://drive.google.com/uc?export=view&id=1yVLPcjm-xvaCjQyc3FGLsWIU84v1QRoC)
+                         """
+                     )
+
+         with gr.Tab("Generate & Detect"):
+
+             with gr.Row():
+                 prompt = gr.Textbox(label="Prompt", interactive=True, lines=10, max_lines=10, value=default_prompt)
+             with gr.Row():
+                 generate_btn = gr.Button("Generate")
+             with gr.Row():
+                 with gr.Column(scale=2):
+                     with gr.Tab("Output Without Watermark (Raw Text)"):
+                         output_without_watermark = gr.Textbox(interactive=False, lines=14, max_lines=14)
+                     with gr.Tab("Highlighted"):
+                         html_without_watermark = gr.HTML(elem_id="html-without-watermark")
+                 with gr.Column(scale=1):
+                     without_watermark_detection_result = gr.Dataframe(headers=["Metric", "Value"], interactive=False, row_count=7, col_count=2)
+             with gr.Row():
+                 with gr.Column(scale=2):
+                     with gr.Tab("Output With Watermark (Raw Text)"):
+                         output_with_watermark = gr.Textbox(interactive=False, lines=14, max_lines=14)
+                     with gr.Tab("Highlighted"):
+                         html_with_watermark = gr.HTML(elem_id="html-with-watermark")
+                 with gr.Column(scale=1):
+                     with_watermark_detection_result = gr.Dataframe(headers=["Metric", "Value"], interactive=False, row_count=7, col_count=2)
+
+             redecoded_input = gr.Textbox(visible=False)
+             truncation_warning = gr.Number(visible=False)
+             def truncate_prompt(redecoded_input, truncation_warning, orig_prompt, args):
+                 if truncation_warning:
+                     return redecoded_input + "\n\n[Prompt was truncated before generation due to length...]", args
+                 else:
+                     return orig_prompt, args
+
+         with gr.Tab("Detector Only"):
+             with gr.Row():
+                 with gr.Column(scale=2):
+                     with gr.Tab("Text to Analyze"):
+                         detection_input = gr.Textbox(interactive=True, lines=14, max_lines=14)
+                     with gr.Tab("Highlighted"):
+                         html_detection_input = gr.HTML(elem_id="html-detection-input")
+                 with gr.Column(scale=1):
+                     detection_result = gr.Dataframe(headers=["Metric", "Value"], interactive=False, row_count=7, col_count=2)
+             with gr.Row():
+                 detect_btn = gr.Button("Detect")
+
+         # Parameter selection group
+         with gr.Accordion("Advanced Settings", open=False):
+             with gr.Row():
+                 with gr.Column(scale=1):
+                     gr.Markdown("#### Generation Parameters")
+                     with gr.Row():
+                         decoding = gr.Radio(label="Decoding Method", choices=["multinomial", "greedy"], value=("multinomial" if args.use_sampling else "greedy"))
+                     with gr.Row():
+                         sampling_temp = gr.Slider(label="Sampling Temperature", minimum=0.1, maximum=1.0, step=0.1, value=args.sampling_temp, visible=True)
+                     with gr.Row():
+                         generation_seed = gr.Number(label="Generation Seed", value=args.generation_seed, interactive=True)
+                     with gr.Row():
+                         n_beams = gr.Dropdown(label="Number of Beams", choices=list(range(1, 11, 1)), value=args.n_beams, visible=((not args.use_sampling) and (args.model_name_or_path not in API_MODEL_MAP)))
+                     with gr.Row():
+                         max_new_tokens = gr.Slider(label="Max Generated Tokens", minimum=10, maximum=1000, step=10, value=args.max_new_tokens)
+
+                 with gr.Column(scale=1):
+                     gr.Markdown("#### Watermark Parameters")
+                     with gr.Row():
+                         gamma = gr.Slider(label="gamma", minimum=0.1, maximum=0.9, step=0.05, value=args.gamma)
+                     with gr.Row():
+                         delta = gr.Slider(label="delta", minimum=0.0, maximum=10.0, step=0.1, value=args.delta)
+                     gr.Markdown("#### Detector Parameters")
+                     with gr.Row():
+                         detection_z_threshold = gr.Slider(label="z-score threshold", minimum=0.0, maximum=10.0, step=0.1, value=args.detection_z_threshold)
+                     with gr.Row():
+                         ignore_repeated_bigrams = gr.Checkbox(label="Ignore Bigram Repeats")
+                     with gr.Row():
+                         normalizers = gr.CheckboxGroup(label="Normalizations", choices=["unicode", "homoglyphs", "truecase"], value=args.normalizers)
+             with gr.Row():
+                 gr.Markdown("_Note: sliders don't always update perfectly. Clicking on the bar or using the number window to the right can help. The window below shows the current settings._")
+             with gr.Row():
+                 current_parameters = gr.Textbox(label="Current Parameters", value=args)
+             with gr.Accordion("Legacy Settings", open=False):
+                 with gr.Row():
+                     with gr.Column(scale=1):
+                         seed_separately = gr.Checkbox(label="Seed both generations separately", value=args.seed_separately)
+                     with gr.Column(scale=1):
+                         select_green_tokens = gr.Checkbox(label="Select 'greenlist' from partition", value=args.select_green_tokens)
+
+
+         with gr.Accordion("What do the settings do?", open=False):
+             gr.Markdown(
+                 """
+                 #### Generation Parameters:
+
+                 - **Decoding Method** : We can generate tokens from the model using either multinomial sampling or greedy decoding.
+                 - **Sampling Temperature** : If using multinomial sampling we can set the temperature of the sampling distribution.
+                 0.0 is equivalent to greedy decoding, and 1.0 is the maximum amount of variability/entropy in the next token distribution.
+                 0.7 strikes a nice balance between faithfulness to the model's estimate of top candidates while adding variety. Does not apply to greedy decoding.
+                 - **Generation Seed** : The integer to pass to the torch random number generator before running generation. Makes the multinomial sampling strategy's
+                 outputs reproducible. Does not apply to greedy decoding.
+                 - **Number of Beams** : When using greedy decoding, we can also set the number of beams to > 1 to enable beam search.
+                 This is not implemented/excluded from the paper for multinomial sampling, but may be added in the future.
+                 - **Max Generated Tokens** : The `max_new_tokens` parameter passed to the generation method to stop the output at a certain number of new tokens.
+                 Note that the model is free to generate fewer tokens depending on the prompt.
+                 Implicitly this sets the maximum number of prompt tokens possible as the model's maximum input length minus `max_new_tokens`,
+                 and inputs will be truncated accordingly.
+
+                 #### Watermark Parameters:
+
+                 - **gamma** : The fraction of the vocabulary to be partitioned into the greenlist at each generation step.
+                 Smaller gamma values create a stronger watermark by enabling the watermarked model to achieve
+                 greater differentiation from human/unwatermarked text, because it is preferentially sampling
+                 from a smaller green set, making those tokens less likely to occur by chance.
+                 - **delta** : The amount of positive bias to add to the logits of every token in the greenlist
+                 at each generation step before sampling/choosing the next token. Higher delta values
+                 mean that the greenlist tokens are more heavily preferred by the watermarked model,
+                 and as the bias becomes very large the watermark transitions from "soft" to "hard".
+                 For a hard watermark, nearly all tokens are green, but this can have a detrimental effect on
+                 generation quality, especially when there is not a lot of flexibility in the distribution.
+
+                 #### Detector Parameters:
+
+                 - **z-score threshold** : the z-score cutoff for the hypothesis test. Higher thresholds (such as 4.0) make
+                 _false positives_ (predicting that human/unwatermarked text is watermarked) very unlikely,
+                 as genuine human text with a significant number of tokens will almost never achieve
+                 that high of a z-score. Lower thresholds will capture more _true positives_, as some watermarked
+                 texts will contain fewer green tokens and achieve a lower z-score, yet still pass the lower bar and
+                 be flagged as "watermarked". However, a lower threshold will increase the chance that human text
+                 that contains a slightly higher than average number of green tokens is erroneously flagged.
+                 4.0-5.0 offers extremely low false positive rates while still accurately catching most watermarked text.
+                 - **Ignore Bigram Repeats** : This alternate detection algorithm only considers the unique bigrams in the text during detection,
+                 computing the greenlists based on the first token in each pair and checking whether the second falls within the list.
+                 This means that `T` is now the number of unique bigrams in the text, which becomes less than the total
+                 number of tokens generated if the text contains a lot of repetition. See the paper for a more detailed discussion.
+                 - **Normalizations** : we implement a few basic normalizations to defend against various adversarial perturbations of the
+                 text analyzed during detection. Currently we support converting all characters to unicode,
+                 replacing homoglyphs with a canonical form, and standardizing the capitalization.
+                 See the paper for a detailed discussion of input normalization.
+                 """
+             )
+
+         with gr.Accordion("What do the output metrics mean?", open=False):
+             gr.Markdown(
+                 """
+                 - `z-score threshold` : The cutoff for the hypothesis test
+                 - `Tokens Counted (T)` : The number of tokens in the output that were counted by the detection algorithm.
+                 The first token is omitted in the simple, single-token seeding scheme, since there is no way to generate
+                 a greenlist for it as it has no prefix token(s). Under the "Ignore Bigram Repeats" detection algorithm,
+                 described in the bottom panel, this can be much less than the total number of tokens generated if there is a lot of repetition.
+                 - `# Tokens in Greenlist` : The number of tokens that were observed to fall in their respective greenlist
+                 - `Fraction of T in Greenlist` : `# Tokens in Greenlist` / `T`. This is expected to be approximately `gamma` for human/unwatermarked text.
+                 - `z-score` : The test statistic for the detection hypothesis test. If larger than the `z-score threshold`,
+                 we "reject the null hypothesis" that the text is human/unwatermarked, and conclude it is watermarked.
+                 - `p value` : The likelihood of observing the computed `z-score` under the null hypothesis. This is the likelihood of
+                 observing the `Fraction of T in Greenlist` given that the text was generated without knowledge of the watermark procedure/greenlists.
+                 If this is extremely _small_, we are confident that this many green tokens was not chosen by random chance.
+                 - `prediction` : The outcome of the hypothesis test - whether the observed `z-score` was higher than the `z-score threshold`
+                 - `confidence` : If we reject the null hypothesis, and the `prediction` is "Watermarked", then we report 1 - `p value` to represent
+                 the confidence of the detection based on the unlikeliness of this `z-score` observation.
+                 """
+             )
+
+         # Register the main generation tab click; it outputs the generations as well as the encoded+redecoded+potentially truncated prompt and flag, then calls detection
+         generate_btn.click(fn=check_prompt_partial, inputs=[prompt, session_args, session_tokenizer], outputs=[redecoded_input, truncation_warning, session_args]).success(
+             fn=generate_partial, inputs=[redecoded_input, session_args, session_tokenizer], outputs=[output_without_watermark, output_with_watermark]).success(
+             fn=detect_partial, inputs=[output_without_watermark, session_args, session_tokenizer], outputs=[without_watermark_detection_result, session_args, session_tokenizer, html_without_watermark]).success(
+             fn=detect_partial, inputs=[output_with_watermark, session_args, session_tokenizer], outputs=[with_watermark_detection_result, session_args, session_tokenizer, html_with_watermark])
+         # Show the truncated version of the prompt if truncation occurred
+         redecoded_input.change(fn=truncate_prompt, inputs=[redecoded_input, truncation_warning, prompt, session_args], outputs=[prompt, session_args])
+         # Register the main detection tab click
+         detect_btn.click(fn=detect_partial, inputs=[detection_input, session_args, session_tokenizer], outputs=[detection_result, session_args, session_tokenizer, html_detection_input], api_name="detection")
+
+         # State management logic
+         # define update callbacks that change the state dict
+         def update_model(session_state, value): session_state.model_name_or_path = value; return session_state
+         def update_sampling_temp(session_state, value): session_state.sampling_temp = float(value); return session_state
+         def update_generation_seed(session_state, value): session_state.generation_seed = int(value); return session_state
+         def update_gamma(session_state, value): session_state.gamma = float(value); return session_state
+         def update_delta(session_state, value): session_state.delta = float(value); return session_state
+         def update_detection_z_threshold(session_state, value): session_state.detection_z_threshold = float(value); return session_state
+         def update_decoding(session_state, value):
+             if value == "multinomial":
+                 session_state.use_sampling = True
+             elif value == "greedy":
+                 session_state.use_sampling = False
+             return session_state
+         def toggle_sampling_vis(value):
+             if value == "multinomial":
+                 return gr.update(visible=True)
+             elif value == "greedy":
+                 return gr.update(visible=False)
+         def toggle_sampling_vis_inv(value):
+             if value == "multinomial":
+                 return gr.update(visible=False)
+             elif value == "greedy":
+                 return gr.update(visible=True)
+         # if the model name is in the list of API models, set the num beams parameter to 1 and hide n_beams
+         def toggle_vis_for_api_model(value):
+             if value in API_MODEL_MAP:
+                 return gr.update(visible=False)
+             else:
+                 return gr.update(visible=True)
+         def toggle_beams_for_api_model(value, orig_n_beams):
+             if value in API_MODEL_MAP:
+                 return gr.update(value=1)
+             else:
+                 return gr.update(value=orig_n_beams)
+         # if the model name is in the list of API models, set the interactive parameter to False
+         def toggle_interactive_for_api_model(value):
+             if value in API_MODEL_MAP:
+                 return gr.update(interactive=False)
+             else:
+                 return gr.update(interactive=True)
+         # if the model name is in the list of API models, set gamma and delta based on the API map
+         def toggle_gamma_for_api_model(value, orig_gamma):
+             if value in API_MODEL_MAP:
+                 return gr.update(value=API_MODEL_MAP[value]["gamma"])
+             else:
+                 return gr.update(value=orig_gamma)
+         def toggle_delta_for_api_model(value, orig_delta):
+             if value in API_MODEL_MAP:
+                 return gr.update(value=API_MODEL_MAP[value]["delta"])
+             else:
+                 return gr.update(value=orig_delta)
+
+         def update_n_beams(session_state, value): session_state.n_beams = value; return session_state
+         def update_max_new_tokens(session_state, value): session_state.max_new_tokens = int(value); return session_state
+         def update_ignore_repeated_bigrams(session_state, value): session_state.ignore_repeated_bigrams = value; return session_state
+         def update_normalizers(session_state, value): session_state.normalizers = value; return session_state
+         def update_seed_separately(session_state, value): session_state.seed_separately = value; return session_state
+         def update_select_green_tokens(session_state, value): session_state.select_green_tokens = value; return session_state
+         def update_tokenizer(model_name_or_path):
+             # if model_name_or_path == ALPACA_MODEL_NAME:
+             #     return ALPACA_MODEL_TOKENIZER.from_pretrained(ALPACA_TOKENIZER_PATH)
+             # else:
+             return AutoTokenizer.from_pretrained(model_name_or_path)
+
+         def check_model(value): return value if (value != "" and value is not None) else args.model_name_or_path
+         # enforce the constraint that the model cannot be null or empty,
+         # then attach the model callbacks in particular
+         model_selector.change(check_model, inputs=[model_selector], outputs=[model_selector]).then(
+             toggle_vis_for_api_model, inputs=[model_selector], outputs=[n_beams]
+         ).then(
+             toggle_beams_for_api_model, inputs=[model_selector, n_beams], outputs=[n_beams]
+         ).then(
+             toggle_interactive_for_api_model, inputs=[model_selector], outputs=[gamma]
+         ).then(
+             toggle_interactive_for_api_model, inputs=[model_selector], outputs=[delta]
+         ).then(
+             toggle_gamma_for_api_model, inputs=[model_selector, gamma], outputs=[gamma]
+         ).then(
+             toggle_delta_for_api_model, inputs=[model_selector, delta], outputs=[delta]
+         ).then(
+             update_tokenizer, inputs=[model_selector], outputs=[session_tokenizer]
+         ).then(
+             update_model, inputs=[session_args, model_selector], outputs=[session_args]
+         ).then(
+             lambda value: str(value), inputs=[session_args], outputs=[current_parameters]
+         )
+         # register callbacks for toggling the visibility of certain parameters based on the values of others
+         decoding.change(toggle_sampling_vis, inputs=[decoding], outputs=[sampling_temp])
+         decoding.change(toggle_sampling_vis, inputs=[decoding], outputs=[generation_seed])
+         decoding.change(toggle_sampling_vis_inv, inputs=[decoding], outputs=[n_beams])
+         decoding.change(toggle_vis_for_api_model, inputs=[model_selector], outputs=[n_beams])
+         # register all state update callbacks
+         decoding.change(update_decoding, inputs=[session_args, decoding], outputs=[session_args])
+         sampling_temp.change(update_sampling_temp, inputs=[session_args, sampling_temp], outputs=[session_args])
+         generation_seed.change(update_generation_seed, inputs=[session_args, generation_seed], outputs=[session_args])
+         n_beams.change(update_n_beams, inputs=[session_args, n_beams], outputs=[session_args])
+         max_new_tokens.change(update_max_new_tokens, inputs=[session_args, max_new_tokens], outputs=[session_args])
+         gamma.change(update_gamma, inputs=[session_args, gamma], outputs=[session_args])
+         delta.change(update_delta, inputs=[session_args, delta], outputs=[session_args])
+         detection_z_threshold.change(update_detection_z_threshold, inputs=[session_args, detection_z_threshold], outputs=[session_args])
+         ignore_repeated_bigrams.change(update_ignore_repeated_bigrams, inputs=[session_args, ignore_repeated_bigrams], outputs=[session_args])
+         normalizers.change(update_normalizers, inputs=[session_args, normalizers], outputs=[session_args])
+         seed_separately.change(update_seed_separately, inputs=[session_args, seed_separately], outputs=[session_args])
+         select_green_tokens.change(update_select_green_tokens, inputs=[session_args, select_green_tokens], outputs=[session_args])
+         # register an additional callback on button clicks that updates the shown parameters window
+         generate_btn.click(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
+         detect_btn.click(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
+         # When the parameters change, display the update and also fire detection, since some detection params don't change the model output.
+         delta.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
+         gamma.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
+         gamma.change(fn=detect_partial, inputs=[output_without_watermark, session_args, session_tokenizer], outputs=[without_watermark_detection_result, session_args, session_tokenizer, html_without_watermark])
+         gamma.change(fn=detect_partial, inputs=[output_with_watermark, session_args, session_tokenizer], outputs=[with_watermark_detection_result, session_args, session_tokenizer, html_with_watermark])
+         gamma.change(fn=detect_partial, inputs=[detection_input, session_args, session_tokenizer], outputs=[detection_result, session_args, session_tokenizer, html_detection_input])
+         detection_z_threshold.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
+         detection_z_threshold.change(fn=detect_partial, inputs=[output_without_watermark, session_args, session_tokenizer], outputs=[without_watermark_detection_result, session_args, session_tokenizer, html_without_watermark])
+         detection_z_threshold.change(fn=detect_partial, inputs=[output_with_watermark, session_args, session_tokenizer], outputs=[with_watermark_detection_result, session_args, session_tokenizer, html_with_watermark])
+         detection_z_threshold.change(fn=detect_partial, inputs=[detection_input, session_args, session_tokenizer], outputs=[detection_result, session_args, session_tokenizer, html_detection_input])
+         ignore_repeated_bigrams.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
+         ignore_repeated_bigrams.change(fn=detect_partial, inputs=[output_without_watermark, session_args, session_tokenizer], outputs=[without_watermark_detection_result, session_args, session_tokenizer, html_without_watermark])
+         ignore_repeated_bigrams.change(fn=detect_partial, inputs=[output_with_watermark, session_args, session_tokenizer], outputs=[with_watermark_detection_result, session_args, session_tokenizer, html_with_watermark])
+         ignore_repeated_bigrams.change(fn=detect_partial, inputs=[detection_input, session_args, session_tokenizer], outputs=[detection_result, session_args, session_tokenizer, html_detection_input])
+         normalizers.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
+         normalizers.change(fn=detect_partial, inputs=[output_without_watermark, session_args, session_tokenizer], outputs=[without_watermark_detection_result, session_args, session_tokenizer, html_without_watermark])
+         normalizers.change(fn=detect_partial, inputs=[output_with_watermark, session_args, session_tokenizer], outputs=[with_watermark_detection_result, session_args, session_tokenizer, html_with_watermark])
+         normalizers.change(fn=detect_partial, inputs=[detection_input, session_args, session_tokenizer], outputs=[detection_result, session_args, session_tokenizer, html_detection_input])
+         select_green_tokens.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
+         select_green_tokens.change(fn=detect_partial, inputs=[output_without_watermark, session_args, session_tokenizer], outputs=[without_watermark_detection_result, session_args, session_tokenizer, html_without_watermark])
+         select_green_tokens.change(fn=detect_partial, inputs=[output_with_watermark, session_args, session_tokenizer], outputs=[with_watermark_detection_result, session_args, session_tokenizer, html_with_watermark])
+         select_green_tokens.change(fn=detect_partial, inputs=[detection_input, session_args, session_tokenizer], outputs=[detection_result, session_args, session_tokenizer, html_detection_input])
+
+     demo.queue(concurrency_count=3)
+
+     if args.demo_public:
+         demo.launch(share=True)  # exposes the app to the internet via a randomly generated link
+     else:
+         demo.launch()
+
+ def main(args):
+     """Run a command line version of the generation and detection operations
+     and optionally launch and serve the gradio demo"""
+     # Initial arg processing and logging
+     args.normalizers = (args.normalizers.split(",") if args.normalizers else [])
+     print(args)
+
+     if not args.skip_model_load:
+         model, tokenizer, device = load_model(args)
+     else:
+         model, tokenizer, device = None, None, None
+         tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
+         if args.use_gpu:
+             device = "cuda" if torch.cuda.is_available() else "cpu"
+         else:
+             device = "cpu"
+
+
+     # default prompt: an aardvark passage used as the text-completion example
+     input_text = (
+         "The aardvark is sometimes colloquially called the 'African ant bear',[6] 'anteater' (not to be confused with the South American anteater), or the 'Cape anteater'[6] after the Cape of Good Hope. The name 'aardvark' is Afrikaans (Afrikaans pronunciation: [ˈɑːrtfark]), comes from earlier Afrikaans erdvark[6] and means 'earth pig' or 'ground pig' (aarde: 'earth/ground', vark: 'pig'), because of its burrowing habits.[7][8][9] The name Orycteropus means 'burrowing foot', and the name afer refers to Africa.[10] The name of the aardvark's order, Tubulidentata, comes from the tubule-style teeth.[11]\n\nThe aardvark is not closely related to the pig; rather, it is the sole extant representative of the obscure mammalian order Tubulidentata,[10] in which it is usually considered to form one variable species of the genus Orycteropus, the sole surviving genus in the family Orycteropodidae. The aardvark is"
+     )
+
+     args.default_prompt = input_text
+
+
+     # Generate and detect, report to stdout
+     if not args.skip_model_load:
+
+         term_width = 80
+         print("#" * term_width)
+         print("Prompt:")
+         print(input_text)
+
+         # a generator that yields (without_watermark, with_watermark) pairs
+         generator_outputs = generate(input_text,
+                                      args,
+                                      model=model,
+                                      device=device,
+                                      tokenizer=tokenizer)
+         # we need to iterate over it,
+         # but we only want the last output in this case
+         for out in generator_outputs:
+             decoded_output_without_watermark = out[0]
+             decoded_output_with_watermark = out[1]
+
+         without_watermark_detection_result = detect(decoded_output_without_watermark,
+                                                     args,
+                                                     device=device,
+                                                     tokenizer=tokenizer,
+                                                     return_green_token_mask=False)
+         with_watermark_detection_result = detect(decoded_output_with_watermark,
+                                                  args,
+                                                  device=device,
+                                                  tokenizer=tokenizer,
+                                                  return_green_token_mask=False)
+
+         print("#" * term_width)
+         print("Output without watermark:")
+         print(decoded_output_without_watermark)
+         print("-" * term_width)
+         print(f"Detection result @ {args.detection_z_threshold}:")
+         pprint(without_watermark_detection_result)
+         print("-" * term_width)
+
+         print("#" * term_width)
+         print("Output with watermark:")
+         print(decoded_output_with_watermark)
+         print("-" * term_width)
+         print(f"Detection result @ {args.detection_z_threshold}:")
+         pprint(with_watermark_detection_result)
+         print("-" * term_width)
+
+
+     # Launch the app to generate and detect interactively (implements the hf space demo)
+     if args.run_gradio:
+         run_gradio(args, model=model, tokenizer=tokenizer, device=device)
+
+     return
+
+ if __name__ == "__main__":
+
+     args = parse_args()
+     print(args)
+
+     main(args)
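The metrics panel above describes the detection test in prose. For intuition, here is a minimal sketch of the one-proportion z-test it describes, assuming `WatermarkDetector` follows the statistic from the accompanying paper (`scipy` is an assumed dependency here, used only for the p-value):

```python
# Sketch of the detection statistic: with T scored tokens, g of which fall in
# their greenlists, and greenlist fraction gamma, the null hypothesis
# (human/unwatermarked text) expects g to be roughly gamma * T.
from math import sqrt

from scipy.stats import norm  # assumption: scipy is available


def z_score(green_count: int, num_tokens_scored: int, gamma: float) -> float:
    """One-proportion z-statistic for the observed greenlist fraction."""
    expected = gamma * num_tokens_scored
    std = sqrt(num_tokens_scored * gamma * (1 - gamma))
    return (green_count - expected) / std


# Example: 100 scored tokens, 60 of them green, gamma = 0.25
z = z_score(60, 100, 0.25)   # ~8.08, far above the default 4.0 threshold
p_value = norm.sf(z)         # P(Z >= z) under the null hypothesis
prediction = z > 4.0         # reported as "Watermarked" in the demo
```

This is why even modest-length watermarked outputs are detected reliably: at gamma = 0.25, a run of tokens only needs to be moderately enriched in green tokens before the z-score clears the threshold.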
homoglyph_data/__init__.py ADDED
@@ -0,0 +1,40 @@
+ # This is data for homoglyph finding
+
+ """Original package info:
+
+ Homoglyphs
+ * Get similar letters
+ * Convert string to ASCII letters
+ * Detect possible letter languages
+ * Detect letter UTF-8 group.
+
+ # main package info
+ __title__ = 'Homoglyphs'
+ __version__ = '2.0.4'
+ __author__ = 'Gram Orsinium'
+ __license__ = 'MIT'
+
+ # License:
+
+ MIT License 2019 orsinium <master_fess@mail.ru>
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice (including the next
+ paragraph) shall be included in all copies or substantial portions of the
+ Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
+
+ """
homoglyph_data/categories.json ADDED
The diff for this file is too large to render.
homoglyph_data/confusables_sept2022.json ADDED
The diff for this file is too large to render.
homoglyph_data/languages.json ADDED
@@ -0,0 +1,34 @@
+ {
+     "ar": "ءآأؤإئابةتثجحخدذرزسشصضطظعغػؼؽؾؿـفقكلمنهوىيًٌٍَُِّ",
+     "be": "ʼЁІЎАБВГДЕЖЗЙКЛМНОПРСТУФХЦЧШЫЬЭЮЯабвгдежзйклмнопрстуфхцчшыьэюяёіў",
+     "bg": "АБВГДЕЖЗИЙКЛМНОПРСТУФХЦЧШЩЪЬЮЯабвгдежзийклмнопрстуфхцчшщъьюя",
+     "ca": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzÀÈÉÍÏÒÓÚÜÇàèéíïòóúüç·",
+     "cz": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzÁÉÍÓÚÝáéíóúýČčĎďĚěŇňŘřŠšŤťŮůŽž",
+     "da": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzÅÆØåæø",
+     "de": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzÄÖÜßäöü",
+     "el": "ΪΫΆΈΉΊΌΎΏΑΒΓΔΕΖΗΘΙΚΛΜΝΞΟΠΡΣΤΥΦΧΨΩΐΰϊϋάέήίαβγδεζηθικλμνξοπρςστυφχψωόύώ",
+     "en": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz",
+     "eo": "ABCDEFGHIJKLMNOPRSTUVZabcdefghijklmnoprstuvzĈĉĜĝĤĥĴĵŜŝŬŭ",
+     "es": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzÁÉÍÑÓÚÜáéíñóúü",
+     "et": "ABDEGHIJKLMNOPRSTUVabdeghijklmnoprstuvÄÕÖÜäõöü",
+     "fi": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzÄÅÖäåöŠšŽž",
+     "fr": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzÀÂÇÈÉÊÎÏÙÛàâçèéêîïùûŒœ",
+     "he": "אבגדהוזחטיךכלםמןנסעףפץצקרשתװױײ",
+     "hr": "ABCDEFGHIJKLMNOPRSTUVZabcdefghijklmnoprstuvzĆćČčĐ𩹮ž",
+     "hu": "ABCDEFGHIJKLMNOPRSTUVZabcdefghijklmnoprstuvzÁÉÍÓÖÚÜáéíóöúüŐőŰű",
+     "it": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzÀÈÉÌÒÓÙàèéìòóù",
+     "lt": "ABCDEFGHIJKLMNOPRSTUVYZabcdefghijklmnoprstuvyzĄąČčĖėĘęĮįŠšŪūŲųŽž",
+     "lv": "ABCDEFGHIJKLMNOPRSTUVZabcdefghijklmnoprstuvzĀāČčĒēĢģĪīĶķĻļŅņŠšŪūŽž",
+     "mk": "ЃЅЈЉЊЌЏАБВГДЕЖЗИКЛМНОПРСТУФХЦЧШабвгдежзиклмнопрстуфхцчшѓѕјљњќџ",
+     "nl": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz",
+     "pl": "ABCDEFGHIJKLMNOPRSTUWYZabcdefghijklmnoprstuwyzÓóĄąĆćĘꣳŃńŚśŹźŻż",
+     "pt": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzÀÁÂÃÇÉÊÍÓÔÕÚàáâãçéêíóôõú",
+     "ro": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzÂÎâîĂăȘșȚț",
+     "ru": "ЁАБВГДЕЖЗИЙКЛМНОПРСТУФХЦЧШЩЪЫЬЭЮЯабвгдежзийклмнопрстуфхцчшщъыьэюяё",
+     "sk": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzÁÄÉÍÓÔÚÝáäéíóôúýČčĎďĹ弾ŇňŔ੹ŤťŽž",
+     "sl": "ABCDEFGHIJKLMNOPRSTUVZabcdefghijklmnoprstuvzČ芚Žž",
+     "sr": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzЂЈЉЊЋЏАБВГДЕЖЗИКЛМНОПРСТУФХЦЧШабвгдежзиклмнопрстуфхцчшђјљњћџ",
+     "th": "กขฃคฅฆงจฉชซฌญฎฏฐฑฒณดตถทธนบปผฝพฟภมยรฤลฦวศษสหฬอฮฯะัาำิีึืฺุู฿เแโใไๅๆ็่้๊๋์ํ๎๏๐๑๒๓๔๕๖๗๘๙๚๛",
+     "tr": "ABCDEFGHIJKLMNOPRSTUVYZabcdefghijklmnoprstuvyzÂÇÎÖÛÜâçîöûüĞğİıŞş",
+     "vi": "ABCDEGHIKLMNOPQRSTUVXYabcdeghiklmnopqrstuvxyÂÊÔâêôĂăĐđƠơƯư"
+ }
homoglyphs.py ADDED
@@ -0,0 +1,265 @@
+ """Updated version of core.py from
+ https://github.com/yamatt/homoglyphs/tree/main/homoglyphs_fork
+ for modern python3
+ """
+
+ from collections import defaultdict
+ import json
+ from itertools import product
+ import os
+ import unicodedata
+
+ # Actions if char not in alphabet
+ STRATEGY_LOAD = 1  # load category for this char
+ STRATEGY_IGNORE = 2  # add char to result
+ STRATEGY_REMOVE = 3  # remove char from result
+
+ ASCII_RANGE = range(128)
+
+
+ CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
+ DATA_LOCATION = os.path.join(CURRENT_DIR, "homoglyph_data")
+
+
+ class Categories:
+     """
+     Work with aliases from ISO 15924.
+     https://en.wikipedia.org/wiki/ISO_15924#List_of_codes
+     """
+
+     fpath = os.path.join(DATA_LOCATION, "categories.json")
+
+     @classmethod
+     def _get_ranges(cls, categories):
+         """
+         :return: iter: (start code, end code)
+         :rtype: list
+         """
+         with open(cls.fpath, encoding="utf-8") as f:
+             data = json.load(f)
+
+         for category in categories:
+             if category not in data["aliases"]:
+                 raise ValueError("Invalid category: {}".format(category))
+
+         for point in data["points"]:
+             if point[2] in categories:
+                 yield point[:2]
+
+     @classmethod
+     def get_alphabet(cls, categories):
+         """
+         :return: set of chars in alphabet by categories list
+         :rtype: set
+         """
+         alphabet = set()
+         for start, end in cls._get_ranges(categories):
+             chars = (chr(code) for code in range(start, end + 1))
+             alphabet.update(chars)
+         return alphabet
+
+     @classmethod
+     def detect(cls, char):
+         """
+         :return: category
+         :rtype: str
+         """
+         with open(cls.fpath, encoding="utf-8") as f:
+             data = json.load(f)
+
+         # try to detect the category via unicodedata
+         try:
+             category = unicodedata.name(char).split()[0]
+         except (TypeError, ValueError):
+             # In Python 2, unicodedata.name raises an error for non-unicode chars;
+             # Python 3 raises ValueError for such characters
+             pass
+         else:
+             if category in data["aliases"]:
+                 return category
+
+         # try to detect the category via the ranges from the JSON file.
+         code = ord(char)
+         for point in data["points"]:
+             if point[0] <= code <= point[1]:
+                 return point[2]
+
+     @classmethod
+     def get_all(cls):
+         with open(cls.fpath, encoding="utf-8") as f:
+             data = json.load(f)
+         return set(data["aliases"])
+
+
+ class Languages:
+     fpath = os.path.join(DATA_LOCATION, "languages.json")
+
+     @classmethod
+     def get_alphabet(cls, languages):
+         """
+         :return: set of chars in alphabet by languages list
+         :rtype: set
+         """
+         with open(cls.fpath, encoding="utf-8") as f:
+             data = json.load(f)
+         alphabet = set()
+         for lang in languages:
+             if lang not in data:
+                 raise ValueError("Invalid language code: {}".format(lang))
+             alphabet.update(data[lang])
+         return alphabet
+
+     @classmethod
+     def detect(cls, char):
+         """
+         :return: set of languages which alphabet contains passed char.
+         :rtype: set
+         """
+         with open(cls.fpath, encoding="utf-8") as f:
+             data = json.load(f)
+         languages = set()
+         for lang, alphabet in data.items():
+             if char in alphabet:
+                 languages.add(lang)
+         return languages
+
+     @classmethod
+     def get_all(cls):
+         with open(cls.fpath, encoding="utf-8") as f:
+             data = json.load(f)
+         return set(data.keys())
+
+
+ class Homoglyphs:
+     def __init__(
+         self,
+         categories=None,
+         languages=None,
+         alphabet=None,
+         strategy=STRATEGY_IGNORE,
+         ascii_strategy=STRATEGY_IGNORE,
+         ascii_range=ASCII_RANGE,
+     ):
+         # strategies
+         if strategy not in (STRATEGY_LOAD, STRATEGY_IGNORE, STRATEGY_REMOVE):
+             raise ValueError("Invalid strategy")
+         self.strategy = strategy
+         self.ascii_strategy = ascii_strategy
+         self.ascii_range = ascii_range
+
+         # Homoglyphs must be initialized with some alphabet to work correctly
+         if not categories and not languages and not alphabet:
+             categories = ("LATIN", "COMMON")
+
+         # cats and langs
+         self.categories = set(categories or [])
+         self.languages = set(languages or [])
+
+         # alphabet
+         self.alphabet = set(alphabet or [])
+         if self.categories:
+             alphabet = Categories.get_alphabet(self.categories)
+             self.alphabet.update(alphabet)
+         if self.languages:
+             alphabet = Languages.get_alphabet(self.languages)
+             self.alphabet.update(alphabet)
+         self.table = self.get_table(self.alphabet)
+
+     @staticmethod
+     def get_table(alphabet):
+         table = defaultdict(set)
+         with open(os.path.join(DATA_LOCATION, "confusables_sept2022.json")) as f:
+             data = json.load(f)
+         for char in alphabet:
+             if char in data:
+                 for homoglyph in data[char]:
+                     if homoglyph in alphabet:
+                         table[char].add(homoglyph)
+         return table
+
+     @staticmethod
+     def get_restricted_table(source_alphabet, target_alphabet):
+         table = defaultdict(set)
+         with open(os.path.join(DATA_LOCATION, "confusables_sept2022.json")) as f:
+             data = json.load(f)
+         for char in source_alphabet:
+             if char in data:
+                 for homoglyph in data[char]:
+                     if homoglyph in target_alphabet:
+                         table[char].add(homoglyph)
+         return table
+
+     @staticmethod
+     def uniq_and_sort(data):
+         result = list(set(data))
+         result.sort(key=lambda x: (-len(x), x))
+         return result
+
+     def _update_alphabet(self, char):
+         # try to detect languages
+         langs = Languages.detect(char)
+         if langs:
+             self.languages.update(langs)
+             alphabet = Languages.get_alphabet(langs)
+             self.alphabet.update(alphabet)
+         else:
+             # try to detect categories
+             category = Categories.detect(char)
+             if category is None:
+                 return False
+             self.categories.add(category)
+             alphabet = Categories.get_alphabet([category])
+             self.alphabet.update(alphabet)
+         # update table for new alphabet
+         self.table = self.get_table(self.alphabet)
+         return True
+
+     def _get_char_variants(self, char):
+         if char not in self.alphabet:
+             if self.strategy == STRATEGY_LOAD:
+                 if not self._update_alphabet(char):
+                     return []
+             elif self.strategy == STRATEGY_IGNORE:
+                 return [char]
+             elif self.strategy == STRATEGY_REMOVE:
+                 return []
+
+         # find alternative chars for current char
+         alt_chars = self.table.get(char, set())
+         if alt_chars:
+             # find alternative chars for alternative chars for current char
+             alt_chars2 = [self.table.get(alt_char, set()) for alt_char in alt_chars]
+             # combine all alternatives
+             alt_chars.update(*alt_chars2)
+         # add current char to alternatives
+         alt_chars.add(char)
+
+         # uniq, sort and return
+         return self.uniq_and_sort(alt_chars)
+
+     def _get_combinations(self, text, ascii=False):
+         variations = []
+         for char in text:
+             alt_chars = self._get_char_variants(char)
+
+             if ascii:
+                 alt_chars = [char for char in alt_chars if ord(char) in self.ascii_range]
+                 if not alt_chars and self.ascii_strategy == STRATEGY_IGNORE:
+                     return
+
+             if alt_chars:
+                 variations.append(alt_chars)
+         if variations:
+             for variant in product(*variations):
+                 yield "".join(variant)
+
+     def get_combinations(self, text):
+         return list(self._get_combinations(text))
+
+     def _to_ascii(self, text):
+         for variant in self._get_combinations(text, ascii=True):
+             if max(map(ord, variant)) in self.ascii_range:
+                 yield variant
+
+     def to_ascii(self, text):
+         return self.uniq_and_sort(self._to_ascii(text))
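
A minimal usage sketch of the Homoglyphs class above (the inputs are illustrative assumptions, not taken from the repo):

    from homoglyphs import Homoglyphs, STRATEGY_LOAD

    hg = Homoglyphs(categories=("LATIN", "COMMON"), strategy=STRATEGY_LOAD)
    # enumerate confusable respellings of a short string
    variants = hg.get_combinations("abc")
    # project a string containing Cyrillic look-alikes onto ASCII candidates
    ascii_forms = hg.to_ascii("рaypal")  # leading 'р' is Cyrillic U+0440

With STRATEGY_LOAD, characters outside the initial LATIN/COMMON alphabet trigger loading of their detected language/category tables before variants are computed.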
normalizers.py ADDED
@@ -0,0 +1,202 @@
+ """ Text-based normalizers, used to mitigate simple attacks against watermarking.
+
+ This implementation is unlikely to be a complete list of all possible exploits within the unicode standard;
+ it represents our best effort at the time of writing.
+
+ These normalizers can be used as stand-alone normalizers. They could be made to conform to the HF tokenizers standard, but that would
+ require messing with the limited rust interface of tokenizers.NormalizedString
+ """
+ from collections import defaultdict
+ from functools import cache
+
+ import re
+ import unicodedata
+ import homoglyphs as hg
+
+
+ def normalization_strategy_lookup(strategy_name: str) -> object:
+     if strategy_name == "unicode":
+         return UnicodeSanitizer()
+     elif strategy_name == "homoglyphs":
+         return HomoglyphCanonizer()
+     elif strategy_name == "truecase":
+         return TrueCaser()
+
+
+ class HomoglyphCanonizer:
+     """Attempts to detect homoglyph attacks and find a consistent canon.
+
+     This function does so on a per-ISO-category level. Language-level would also be possible (see commented code).
+     """
+
+     def __init__(self):
+         self.homoglyphs = None
+
+     def __call__(self, homoglyphed_str: str) -> str:
+         # find canon:
+         target_category, all_categories = self._categorize_text(homoglyphed_str)
+         homoglyph_table = self._select_canon_category_and_load(target_category, all_categories)
+         return self._sanitize_text(target_category, homoglyph_table, homoglyphed_str)
+
+     def _categorize_text(self, text: str) -> dict:
+         iso_categories = defaultdict(int)
+         # self.iso_languages = defaultdict(int)
+
+         for char in text:
+             iso_categories[hg.Categories.detect(char)] += 1
+             # for lang in hg.Languages.detect(char):
+             #     self.iso_languages[lang] += 1
+         target_category = max(iso_categories, key=iso_categories.get)
+         all_categories = tuple(iso_categories)
+         return target_category, all_categories
+
+     @cache
+     def _select_canon_category_and_load(self, target_category: str, all_categories: tuple[str]) -> dict:
+         homoglyph_table = hg.Homoglyphs(categories=(target_category, "COMMON"))  # alphabet loaded here from file
+
+         source_alphabet = hg.Categories.get_alphabet(all_categories)
+         restricted_table = homoglyph_table.get_restricted_table(source_alphabet, homoglyph_table.alphabet)  # table loaded here from file
+         return restricted_table
+
+     def _sanitize_text(self, target_category: str, homoglyph_table: dict, homoglyphed_str: str) -> str:
+         sanitized_text = ""
+         for char in homoglyphed_str:
+             # langs = hg.Languages.detect(char)
+             cat = hg.Categories.detect(char)
+             if target_category in cat or "COMMON" in cat or len(cat) == 0:
+                 sanitized_text += char
+             else:
+                 sanitized_text += list(homoglyph_table[char])[0]
+         return sanitized_text
+
+
+ class UnicodeSanitizer:
+     """Regex-based unicode sanitizer. Has different levels of granularity.
+
+     * ruleset="whitespaces" - attempts to remove only whitespace unicode characters
+     * ruleset="IDN.blacklist" - does its best to remove unusual unicode based on Network.IDN.blacklist characters
+     * ruleset="ascii" - brute-forces all text into ascii
+
+     This is unlikely to be a comprehensive list.
+
+     You can find a more comprehensive discussion at https://www.unicode.org/reports/tr36/
+     and https://www.unicode.org/faq/security.html
+     """
+
+     def __init__(self, ruleset="whitespaces"):
+         if ruleset == "whitespaces":
+
+             """Documentation:
+             \u00A0: Non-breaking space
+             \u1680: Ogham space mark
+             \u180E: Mongolian vowel separator
+             \u2000-\u200B: Various space characters, including en space, em space, thin space, hair space, zero-width space, and zero-width non-joiner
+             \u200C\u200D: Zero-width non-joiner and zero-width joiner
+             \u200E,\u200F: Left-to-right-mark, Right-to-left-mark
+             \u2060: Word joiner
+             \u2063: Invisible separator
+             \u202F: Narrow non-breaking space
+             \u205F: Medium mathematical space
+             \u3000: Ideographic space
+             \uFEFF: Zero-width non-breaking space
+             \uFFA0: Halfwidth hangul filler
+             \uFFF9\uFFFA\uFFFB: Interlinear annotation characters
+             \uFE00-\uFE0F: Variation selectors
+             \u202A-\u202F: Embedding characters
+             \u3164: Korean hangul filler.
+
+             Note that these characters are not always superfluous whitespace characters!
+             """
+
+             self.pattern = re.compile(
+                 r"[\u00A0\u1680\u180E\u2000-\u200B\u200C\u200D\u200E\u200F\u2060\u2063\u202F\u205F\u3000\uFEFF\uFFA0\uFFF9\uFFFA\uFFFB"
+                 r"\uFE00\uFE01\uFE02\uFE03\uFE04\uFE05\uFE06\uFE07\uFE08\uFE09\uFE0A\uFE0B\uFE0C\uFE0D\uFE0E\uFE0F\u3164\u202A\u202B\u202C\u202D"
+                 r"\u202E\u202F]"
+             )
+         elif ruleset == "IDN.blacklist":
+
+             """Documentation:
+             [\u00A0\u1680\u180E\u2000-\u200B\u202F\u205F\u2060\u2063\uFEFF]: Matches any whitespace characters in the Unicode character
+             set that are included in the IDN blacklist.
+             \uFFF9-\uFFFB: Matches characters that are not defined in Unicode but are used as language tags in various legacy encodings.
+             These characters are not allowed in domain names.
+             \uD800-\uDB7F: Matches the first part of a surrogate pair. Surrogate pairs are used to represent characters in the Unicode character
+             set that cannot be represented by a single 16-bit value. The first part of a surrogate pair is in the range U+D800 to U+DBFF,
+             and the second part is in the range U+DC00 to U+DFFF.
+             \uDB80-\uDBFF][\uDC00-\uDFFF]?: Matches the second part of a surrogate pair. The second part of a surrogate pair is in the range U+DC00
+             to U+DFFF, and is optional.
+             [\uDB40\uDC20-\uDB40\uDC7F][\uDC00-\uDFFF]: Matches certain invalid UTF-16 sequences which should not appear in IDNs.
+             """
+
+             self.pattern = re.compile(
+                 r"[\u00A0\u1680\u180E\u2000-\u200B\u202F\u205F\u2060\u2063\uFEFF\uFFF9-\uFFFB\uD800-\uDB7F\uDB80-\uDBFF]"
+                 r"[\uDC00-\uDFFF]?|[\uDB40\uDC20-\uDB40\uDC7F][\uDC00-\uDFFF]"
+             )
+         else:
+             """Documentation:
+             This is a simple restriction to "no-unicode", using only ascii characters. Control characters are included.
+             """
+             self.pattern = re.compile(r"[^\x00-\x7F]+")
+
+     def __call__(self, text: str) -> str:
+         text = unicodedata.normalize("NFC", text)  # canon forms
+         text = self.pattern.sub(" ", text)  # pattern match
+         text = re.sub(" +", " ", text)  # collapse whitespaces
+         text = "".join(c for c in text if unicodedata.category(c) != "Cc")  # Remove any remaining non-printable characters
+         return text
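+         # Illustrative example (editorial note, not part of the committed file):
+         # with the default "whitespaces" ruleset, a zero-width space and a
+         # non-breaking space are both replaced by plain spaces and collapsed:
+         #   UnicodeSanitizer()("hello\u200bworld\u00a0!") -> "hello world !"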
+
+
+ class TrueCaser:
+     """True-casing is a capitalization normalization that returns text to its original capitalization.
+
+     This defends against attacks that wRIte TeXt lIkE spOngBoB.
+
+     Here, a simple POS-tagger is used.
+     """
+
+     uppercase_pos = ["PROPN"]  # Name POS tags that should be upper-cased
+
+     def __init__(self, backend="spacy"):
+         if backend == "spacy":
+             spacy_model = "en_core_web_sm"
+             try:
+                 import spacy
+                 self.nlp = spacy.load(spacy_model)
+             except OSError:  # model not downloaded yet
+                 import spacy.cli
+                 spacy.cli.download(spacy_model)
+                 import spacy
+                 self.nlp = spacy.load(spacy_model)
+
+             self.normalize_fn = self._spacy_truecasing
+         else:
+             from nltk import pos_tag, word_tokenize  # noqa
+             import nltk
+
+             nltk.download("punkt")
+             nltk.download("averaged_perceptron_tagger")
+             nltk.download("universal_tagset")
+             self.normalize_fn = self._nltk_truecasing
+
+     def __call__(self, random_capitalized_string: str) -> str:
+         truecased_str = self.normalize_fn(random_capitalized_string)
+         return truecased_str
+
+     def _spacy_truecasing(self, random_capitalized_string: str):
+         doc = self.nlp(random_capitalized_string.lower())
+         POS = self.uppercase_pos
+         truecased_str = "".join([w.text_with_ws.capitalize() if w.pos_ in POS or w.is_sent_start else w.text_with_ws for w in doc])
+         return truecased_str
+
+     def _nltk_truecasing(self, random_capitalized_string: str):
+         from nltk import pos_tag, word_tokenize
+         import nltk
+
+         nltk.download("punkt")
+         nltk.download("averaged_perceptron_tagger")
+         nltk.download("universal_tagset")
+         POS = ["NNP", "NNPS"]
+
+         tagged_text = pos_tag(word_tokenize(random_capitalized_string.lower()))
+         truecased_str = " ".join([w.capitalize() if p in POS else w for (w, p) in tagged_text])
+         return truecased_str
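
A minimal sketch of how these normalizers are chained in practice, mirroring the loop in WatermarkDetector.detect further below (the sample string is an illustrative assumption):

    from normalizers import normalization_strategy_lookup

    normalizers = [normalization_strategy_lookup(name)
                   for name in ["unicode", "homoglyphs", "truecase"]]
    text = "sOmE aTTaCkEd TeXt with weird\u2004spaces"
    for normalizer in normalizers:
        text = normalizer(text)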
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ spacy
+ nltk
+ scipy
+ torch
+ datasets
+ transformers
+ tokenizers
+ accelerate
+ text-generation>=0.3.1
+ gradio>=3.21.0
watermark_processor.py ADDED
@@ -0,0 +1,279 @@
+ # coding=utf-8
+ # Copyright 2023 Anonymous Authors of "A Watermark for Large Language Models"
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from __future__ import annotations
+ import collections
+ from math import sqrt
+
+ import scipy.stats
+
+ import torch
+ from torch import Tensor
+ from tokenizers import Tokenizer
+ from transformers import LogitsProcessor
+
+ from nltk.util import ngrams
+
+ from normalizers import normalization_strategy_lookup
+
+ class WatermarkBase:
+     def __init__(
+         self,
+         vocab: list[int] = None,
+         gamma: float = 0.5,
+         delta: float = 2.0,
+         seeding_scheme: str = "simple_1",  # mostly unused/always default
+         hash_key: int = 15485863,  # just a large prime number to create a rng seed with sufficient bit width
+         select_green_tokens: bool = True,
+     ):
+
+         # watermarking parameters
+         self.vocab = vocab
+         self.vocab_size = len(vocab)
+         self.gamma = gamma
+         self.delta = delta
+         self.seeding_scheme = seeding_scheme
+         self.rng = None
+         self.hash_key = hash_key
+         self.select_green_tokens = select_green_tokens
+
+     def _seed_rng(self, input_ids: torch.LongTensor, seeding_scheme: str = None) -> None:
+         # can optionally override the seeding scheme,
+         # but uses the instance attr by default
+         if seeding_scheme is None:
+             seeding_scheme = self.seeding_scheme
+
+         if seeding_scheme == "simple_1":
+             assert input_ids.shape[-1] >= 1, f"seeding_scheme={seeding_scheme} requires at least a 1 token prefix sequence to seed rng"
+             prev_token = input_ids[-1].item()
+             self.rng.manual_seed(self.hash_key * prev_token)
+         else:
+             raise NotImplementedError(f"Unexpected seeding_scheme: {seeding_scheme}")
+         return
+
+     def _get_greenlist_ids(self, input_ids: torch.LongTensor) -> list[int]:
+         # seed the rng using the previous tokens/prefix
+         # according to the seeding_scheme
+         self._seed_rng(input_ids)
+
+         greenlist_size = int(self.vocab_size * self.gamma)
+         vocab_permutation = torch.randperm(self.vocab_size, device=input_ids.device, generator=self.rng)
+         if self.select_green_tokens:  # directly
+             greenlist_ids = vocab_permutation[:greenlist_size]  # new
+         else:  # select green via red
+             greenlist_ids = vocab_permutation[(self.vocab_size - greenlist_size) :]  # legacy behavior
+         return greenlist_ids
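+         # Illustrative note (editorial, not part of the committed file): for a
+         # GPT-2-sized vocabulary of 50,257 tokens and gamma=0.5, each prefix
+         # token seeds a fresh permutation and the first int(50257 * 0.5) = 25128
+         # entries become the greenlist for the next position.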
+
+
+ class WatermarkLogitsProcessor(WatermarkBase, LogitsProcessor):
+
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+
+     def _calc_greenlist_mask(self, scores: torch.FloatTensor, greenlist_token_ids) -> torch.BoolTensor:
+         # TODO lets see if we can lose this loop
+         green_tokens_mask = torch.zeros_like(scores)
+         for b_idx in range(len(greenlist_token_ids)):
+             green_tokens_mask[b_idx][greenlist_token_ids[b_idx]] = 1
+         final_mask = green_tokens_mask.bool()
+         return final_mask
+
+     def _bias_greenlist_logits(self, scores: torch.Tensor, greenlist_mask: torch.Tensor, greenlist_bias: float) -> torch.Tensor:
+         scores[greenlist_mask] = scores[greenlist_mask] + greenlist_bias
+         return scores
+
+     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+
+         # this is lazy to allow us to colocate on the watermarked model's device
+         if self.rng is None:
+             self.rng = torch.Generator(device=input_ids.device)
+
+         # NOTE, it would be nice to get rid of this batch loop, but currently,
+         # the seed and partition operations are not tensor/vectorized, thus
+         # each sequence in the batch needs to be treated separately.
+         batched_greenlist_ids = [None for _ in range(input_ids.shape[0])]
+
+         for b_idx in range(input_ids.shape[0]):
+             greenlist_ids = self._get_greenlist_ids(input_ids[b_idx])
+             batched_greenlist_ids[b_idx] = greenlist_ids
+
+         green_tokens_mask = self._calc_greenlist_mask(scores=scores, greenlist_token_ids=batched_greenlist_ids)
+
+         scores = self._bias_greenlist_logits(scores=scores, greenlist_mask=green_tokens_mask, greenlist_bias=self.delta)
+         return scores
+
+
+ class WatermarkDetector(WatermarkBase):
+     def __init__(
+         self,
+         *args,
+         device: torch.device = None,
+         tokenizer: Tokenizer = None,
+         z_threshold: float = 4.0,
+         normalizers: list[str] = ["unicode"],  # or also: ["unicode", "homoglyphs", "truecase"]
+         ignore_repeated_bigrams: bool = False,
+         **kwargs,
+     ):
+         super().__init__(*args, **kwargs)
+         # also configure the metrics returned/preprocessing options
+         assert device, "Must pass device"
+         assert tokenizer, "Need an instance of the generating tokenizer to perform detection"
+
+         self.tokenizer = tokenizer
+         self.device = device
+         self.z_threshold = z_threshold
+         self.rng = torch.Generator(device=self.device)
+
+         if self.seeding_scheme == "simple_1":
+             self.min_prefix_len = 1
+         else:
+             raise NotImplementedError(f"Unexpected seeding_scheme: {self.seeding_scheme}")
+
+         self.normalizers = []
+         for normalization_strategy in normalizers:
+             self.normalizers.append(normalization_strategy_lookup(normalization_strategy))
+
+         self.ignore_repeated_bigrams = ignore_repeated_bigrams
+         if self.ignore_repeated_bigrams:
+             assert self.seeding_scheme == "simple_1", "No repeated bigram credit variant assumes the single token seeding scheme."
+
+     def _compute_z_score(self, observed_count, T):
+         # count refers to number of green tokens, T is total number of tokens
+         expected_count = self.gamma
+         numer = observed_count - expected_count * T
+         denom = sqrt(T * expected_count * (1 - expected_count))
+         z = numer / denom
+         return z
+
+     def _compute_p_value(self, z):
+         p_value = scipy.stats.norm.sf(z)
+         return p_value
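+     # Worked example (editorial, not part of the committed file): with
+     # gamma=0.5 and T=200 scored tokens, observing 130 green tokens gives
+     #   z = (130 - 0.5*200) / sqrt(200 * 0.5 * 0.5) = 30 / sqrt(50) ≈ 4.24,
+     # which exceeds the default z_threshold=4.0 and would be flagged.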
+
+     def _score_sequence(
+         self,
+         input_ids: Tensor,
+         return_num_tokens_scored: bool = True,
+         return_num_green_tokens: bool = True,
+         return_green_fraction: bool = True,
+         return_green_token_mask: bool = False,
+         return_z_score: bool = True,
+         return_p_value: bool = True,
+     ):
+         if self.ignore_repeated_bigrams:
+             # Method that only counts a green/red hit once per unique bigram.
+             # New num total tokens scored (T) becomes the number of unique bigrams.
+             # We iterate over all unique token bigrams in the input, computing the greenlist
+             # induced by the first token in each, and then checking whether the second
+             # token falls in that greenlist.
+             assert return_green_token_mask == False, "Can't return the green/red mask when ignoring repeats."
+             bigram_table = {}
+             token_bigram_generator = ngrams(input_ids.cpu().tolist(), 2)
+             freq = collections.Counter(token_bigram_generator)
+             num_tokens_scored = len(freq.keys())
+             for idx, bigram in enumerate(freq.keys()):
+                 prefix = torch.tensor([bigram[0]], device=self.device)  # expects a 1-d prefix tensor on the randperm device
+                 greenlist_ids = self._get_greenlist_ids(prefix)
+                 bigram_table[bigram] = True if bigram[1] in greenlist_ids else False
+             green_token_count = sum(bigram_table.values())
+         else:
+             num_tokens_scored = len(input_ids) - self.min_prefix_len
+             if num_tokens_scored < 1:
+                 raise ValueError((f"Must have at least {1} token to score after "
+                                   f"the first min_prefix_len={self.min_prefix_len} tokens required by the seeding scheme."))
+             # Standard method.
+             # Since we generally need at least 1 token (for the simplest scheme)
+             # we start the iteration over the token sequence with a minimum
+             # num tokens as the first prefix for the seeding scheme,
+             # and at each step, compute the greenlist induced by the
+             # current prefix and check if the current token falls in the greenlist.
+             green_token_count, green_token_mask = 0, []
+             for idx in range(self.min_prefix_len, len(input_ids)):
+                 curr_token = input_ids[idx]
+                 greenlist_ids = self._get_greenlist_ids(input_ids[:idx])
+                 if curr_token in greenlist_ids:
+                     green_token_count += 1
+                     green_token_mask.append(True)
+                 else:
+                     green_token_mask.append(False)
+
+         score_dict = dict()
+         if return_num_tokens_scored:
+             score_dict.update(dict(num_tokens_scored=num_tokens_scored))
+         if return_num_green_tokens:
+             score_dict.update(dict(num_green_tokens=green_token_count))
+         if return_green_fraction:
+             score_dict.update(dict(green_fraction=(green_token_count / num_tokens_scored)))
+         if return_z_score:
+             score_dict.update(dict(z_score=self._compute_z_score(green_token_count, num_tokens_scored)))
+         if return_p_value:
+             z_score = score_dict.get("z_score")
+             if z_score is None:
+                 z_score = self._compute_z_score(green_token_count, num_tokens_scored)
+             score_dict.update(dict(p_value=self._compute_p_value(z_score)))
+         if return_green_token_mask:
+             score_dict.update(dict(green_token_mask=green_token_mask))
+
+         return score_dict
+
+     def detect(
+         self,
+         text: str = None,
+         tokenized_text: list[int] = None,
+         return_prediction: bool = True,
+         return_scores: bool = True,
+         z_threshold: float = None,
+         **kwargs,
+     ) -> dict:
+
+         assert (text is not None) ^ (tokenized_text is not None), "Must pass either the raw or tokenized string"
+         if return_prediction:
+             kwargs["return_p_value"] = True  # to return the "confidence":=1-p of positive detections
+
+         # run optional normalizers on text
+         for normalizer in self.normalizers:
+             text = normalizer(text)
+         if len(self.normalizers) > 0:
+             print(f"Text after normalization:\n\n{text}\n")
+
+         if tokenized_text is None:
+             assert self.tokenizer is not None, (
+                 "Watermark detection on raw string "
+                 "requires an instance of the tokenizer "
+                 "that was used at generation time."
+             )
+             tokenized_text = self.tokenizer(text, return_tensors="pt", add_special_tokens=False)["input_ids"][0].to(self.device)
+             if tokenized_text[0] == self.tokenizer.bos_token_id:
+                 tokenized_text = tokenized_text[1:]
+         else:
+             # try to remove the bos_tok at beginning if it's there
+             if (self.tokenizer is not None) and (tokenized_text[0] == self.tokenizer.bos_token_id):
+                 tokenized_text = tokenized_text[1:]
+
+         # call score method
+         output_dict = {}
+         score_dict = self._score_sequence(tokenized_text, **kwargs)
+         if return_scores:
+             output_dict.update(score_dict)
+         # if passed return_prediction then perform the hypothesis test and return the outcome
+         if return_prediction:
+             z_threshold = z_threshold if z_threshold else self.z_threshold
+             assert z_threshold is not None, "Need a threshold in order to decide outcome of detection test"
+             output_dict["prediction"] = score_dict["z_score"] > z_threshold
+             if output_dict["prediction"]:
+                 output_dict["confidence"] = 1 - score_dict["p_value"]
+
+         return output_dict
+
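
Taken together, a minimal end-to-end sketch of how these pieces are wired up (the model name, prompt, and parameter values are illustrative; generation goes through the standard transformers API, as app.py does above):

    from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessorList
    from watermark_processor import WatermarkLogitsProcessor, WatermarkDetector

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative stand-in model
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    vocab = list(tokenizer.get_vocab().values())

    watermark_processor = WatermarkLogitsProcessor(
        vocab=vocab, gamma=0.5, delta=2.0, seeding_scheme="simple_1")

    inputs = tokenizer("An illustrative demo prompt", return_tensors="pt")
    output_ids = model.generate(
        **inputs, max_new_tokens=200, do_sample=True, temperature=0.7,
        logits_processor=LogitsProcessorList([watermark_processor]))

    detector = WatermarkDetector(
        vocab=vocab, gamma=0.5, seeding_scheme="simple_1",
        device=model.device, tokenizer=tokenizer,
        z_threshold=4.0, normalizers=["unicode"])
    # score only the newly generated tokens, not the prompt
    new_tokens = output_ids[0][inputs["input_ids"].shape[-1]:]
    result = detector.detect(tokenizer.decode(new_tokens, skip_special_tokens=True))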