jwkirchenbauer committed
Commit e2e6e76
1 Parent(s): 60d5b9f

try text-generation 0.3.1 release

Files changed (2)
  1. demo_watermark.py +162 -125
  2. requirements.txt +2 -2
demo_watermark.py CHANGED
@@ -16,7 +16,6 @@
 
 import os
 import argparse
-from argparse import Namespace
 from pprint import pprint
 from functools import partial
 
@@ -30,15 +29,25 @@ from transformers import (AutoTokenizer,
                           AutoModelForCausalLM,
                           LogitsProcessorList)
 
+# from local_tokenizers.tokenization_llama import LLaMATokenizer
+
+from transformers import GPT2TokenizerFast
+OPT_TOKENIZER = GPT2TokenizerFast
+
 from watermark_processor import WatermarkLogitsProcessor, WatermarkDetector
 
+
+# ALPACA_MODEL_NAME = "alpaca"
+# ALPACA_MODEL_TOKENIZER = LLaMATokenizer
+# ALPACA_TOKENIZER_PATH = "/cmlscratch/jkirchen/llama"
+
 # FIXME correct lengths for all models
 API_MODEL_MAP = {
-    "bigscience/bloom" : {"max_length": 2048, "gamma": 0.5, "delta": 2.0},
-    "bigscience/bloomz" : {"max_length": 2048, "gamma": 0.5, "delta": 2.0},
-    "google/flan-ul2" : {"max_length": 2048, "gamma": 0.5, "delta": 2.0},
-    "google/flan-t5-xxl" : {"max_length": 2048, "gamma": 0.5, "delta": 2.0},
-    "EleutherAI/gpt-neox-20b" : {"max_length": 2048, "gamma": 0.5, "delta": 2.0},
+    "bigscience/bloom" : {"max_length": 1000, "gamma": 0.5, "delta": 2.0},
+    "bigscience/bloomz" : {"max_length": 1000, "gamma": 0.5, "delta": 2.0},
+    "google/flan-ul2" : {"max_length": 1000, "gamma": 0.5, "delta": 2.0},
+    "google/flan-t5-xxl" : {"max_length": 1000, "gamma": 0.5, "delta": 2.0},
+    "EleutherAI/gpt-neox-20b" : {"max_length": 1000, "gamma": 0.5, "delta": 2.0},
 }
 
 def str2bool(v):
@@ -231,35 +240,29 @@ def generate_with_api(prompt, args):
     timeout_msg = "[Model API timeout error. Try reducing the max_new_tokens parameter or the prompt length.]"
     try:
         generation_params["watermark"] = False
-        output = client.generate(prompt, **generation_params)
-        output_text_without_watermark = output.generated_text
+        without_watermark_iterator = client.generate_stream(prompt, **generation_params)
     except ReadTimeout as e:
         print(e)
-        output_text_without_watermark = timeout_msg
+        without_watermark_iterator = (char for char in timeout_msg)
     try:
         generation_params["watermark"] = True
-        output = client.generate(prompt, **generation_params)
-        output_text_with_watermark = output.generated_text
+        with_watermark_iterator = client.generate_stream(prompt, **generation_params)
     except ReadTimeout as e:
         print(e)
-        output_text_with_watermark = timeout_msg
-
-    return (output_text_without_watermark,
-            output_text_with_watermark)
+        with_watermark_iterator = (char for char in timeout_msg)
+
+    all_without_words, all_with_words = "", ""
+    for without_word, with_word in zip(without_watermark_iterator, with_watermark_iterator):
+        all_without_words += without_word.token.text
+        all_with_words += with_word.token.text
+        yield all_without_words, all_with_words
 
 
-def generate(prompt, args, tokenizer, model=None, device=None):
-    """Instatiate the WatermarkLogitsProcessor according to the watermark parameters
-    and generate watermarked text by passing it to the generate method of the model
-    as a logits processor. """
-
-    print(f"Generating with {args}")
+def check_prompt(prompt, args, tokenizer, model=None, device=None):
 
     # This applies to both the local and API model scenarios
-    if args.prompt_max_length:
-        pass
-    elif args.model_name_or_path in API_MODEL_MAP:
-        args.prompt_max_length = API_MODEL_MAP[args.model_name_or_path]["max_length"]-args.max_new_tokens
+    if args.model_name_or_path in API_MODEL_MAP:
+        args.prompt_max_length = API_MODEL_MAP[args.model_name_or_path]["max_length"]
     elif hasattr(model.config,"max_position_embedding"):
         args.prompt_max_length = model.config.max_position_embeddings-args.max_new_tokens
     else:
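
Note: this hunk swaps the blocking client.generate call for client.generate_stream, which yields one token per step. As a rough sketch (not part of the commit) of how the text-generation 0.3.x client is typically driven; the model name, timeout, and prompt are illustrative assumptions:

    from text_generation import InferenceAPIClient

    # connect to a hub-hosted model; repo id and timeout are placeholder choices
    client = InferenceAPIClient("bigscience/bloom", timeout=30)

    text = ""
    for response in client.generate_stream("The diamondback terrapin is", max_new_tokens=20):
        # each StreamResponse carries one Token with its decoded text
        if not response.token.special:
            text += response.token.text
    print(text)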
@@ -269,69 +272,77 @@ def generate(prompt, args, tokenizer, model=None, device=None):
     truncation_warning = True if tokd_input["input_ids"].shape[-1] == args.prompt_max_length else False
     redecoded_input = tokenizer.batch_decode(tokd_input["input_ids"], skip_special_tokens=True)[0]
 
-    if args.model_name_or_path in API_MODEL_MAP:
-        api_outputs = generate_with_api(prompt, args)
-        decoded_output_without_watermark = api_outputs[0]
-        decoded_output_with_watermark = api_outputs[1]
-        return (redecoded_input,
-                int(truncation_warning),
-                decoded_output_without_watermark,
-                decoded_output_with_watermark,
-                args,
-                tokenizer)
-
-    watermark_processor = WatermarkLogitsProcessor(vocab=list(tokenizer.get_vocab().values()),
-                                                   gamma=args.gamma,
-                                                   delta=args.delta,
-                                                   seeding_scheme=args.seeding_scheme,
-                                                   select_green_tokens=args.select_green_tokens)
-
-    gen_kwargs = dict(max_new_tokens=args.max_new_tokens)
-
-    if args.use_sampling:
-        gen_kwargs.update(dict(
-            do_sample=True,
-            top_k=0,
-            temperature=args.sampling_temp
-        ))
-    else:
-        gen_kwargs.update(dict(
-            num_beams=args.n_beams
-        ))
-
-    generate_without_watermark = partial(
-        model.generate,
-        **gen_kwargs
-    )
-    generate_with_watermark = partial(
-        model.generate,
-        logits_processor=LogitsProcessorList([watermark_processor]),
-        **gen_kwargs
-    )
-
-    torch.manual_seed(args.generation_seed)
-    output_without_watermark = generate_without_watermark(**tokd_input)
-
-    # optional to seed before second generation, but will not be the same again generally, unless delta==0.0, no-op watermark
-    if args.seed_separately:
-        torch.manual_seed(args.generation_seed)
-    output_with_watermark = generate_with_watermark(**tokd_input)
-
-    if args.is_decoder_only_model:
-        # need to isolate the newly generated tokens
-        output_without_watermark = output_without_watermark[:,tokd_input["input_ids"].shape[-1]:]
-        output_with_watermark = output_with_watermark[:,tokd_input["input_ids"].shape[-1]:]
-
-    decoded_output_without_watermark = tokenizer.batch_decode(output_without_watermark, skip_special_tokens=True)[0]
-    decoded_output_with_watermark = tokenizer.batch_decode(output_with_watermark, skip_special_tokens=True)[0]
-
-    return (redecoded_input,
-            int(truncation_warning),
-            decoded_output_without_watermark,
-            decoded_output_with_watermark,
-            args,
-            tokenizer)
+    return (redecoded_input,
+            int(truncation_warning),
+            args)
 
 
+def generate(prompt, args, tokenizer, model=None, device=None):
+    """Instantiate the WatermarkLogitsProcessor according to the watermark parameters
+    and generate watermarked text by passing it to the generate method of the model
+    as a logits processor."""
+
+    print(f"Generating with {args}")
+    print(f"Prompt: {prompt}")
+
+    if args.model_name_or_path in API_MODEL_MAP:
+        api_outputs = generate_with_api(prompt, args)
+        yield from api_outputs
+    else:
+        tokd_input = tokenizer(prompt, return_tensors="pt", add_special_tokens=True, truncation=True, max_length=args.prompt_max_length).to(device)
+
+        watermark_processor = WatermarkLogitsProcessor(vocab=list(tokenizer.get_vocab().values()),
+                                                       gamma=args.gamma,
+                                                       delta=args.delta,
+                                                       seeding_scheme=args.seeding_scheme,
+                                                       select_green_tokens=args.select_green_tokens)
+
+        gen_kwargs = dict(max_new_tokens=args.max_new_tokens)
+
+        if args.use_sampling:
+            gen_kwargs.update(dict(
+                do_sample=True,
+                top_k=0,
+                temperature=args.sampling_temp
+            ))
+        else:
+            gen_kwargs.update(dict(
+                num_beams=args.n_beams
+            ))
+
+        generate_without_watermark = partial(
+            model.generate,
+            **gen_kwargs
+        )
+        generate_with_watermark = partial(
+            model.generate,
+            logits_processor=LogitsProcessorList([watermark_processor]),
+            **gen_kwargs
+        )
+
+        torch.manual_seed(args.generation_seed)
+        output_without_watermark = generate_without_watermark(**tokd_input)
+
+        # optional to seed before second generation, but will not be the same again generally, unless delta==0.0, no-op watermark
+        if args.seed_separately:
+            torch.manual_seed(args.generation_seed)
+        output_with_watermark = generate_with_watermark(**tokd_input)
+
+        if args.is_decoder_only_model:
+            # need to isolate the newly generated tokens
+            output_without_watermark = output_without_watermark[:,tokd_input["input_ids"].shape[-1]:]
+            output_with_watermark = output_with_watermark[:,tokd_input["input_ids"].shape[-1]:]
+
+        decoded_output_without_watermark = tokenizer.batch_decode(output_without_watermark, skip_special_tokens=True)[0]
+        decoded_output_with_watermark = tokenizer.batch_decode(output_with_watermark, skip_special_tokens=True)[0]
+
+        # mocking the API outputs in a whitespace split generator style
+        all_without_words, all_with_words = "", ""
+        for without_word, with_word in zip(decoded_output_without_watermark.split(), decoded_output_with_watermark.split()):
+            all_without_words += without_word + " "
+            all_with_words += with_word + " "
+            yield all_without_words, all_with_words
 
 
 def format_names(s):
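
Note: generate is now a generator that yields the cumulative (without_watermark, with_watermark) strings so the UI can render partial output as it arrives. Gradio only streams from generator callbacks when queuing is enabled; a minimal sketch under that assumption, with illustrative component names:

    import time
    import gradio as gr

    def slow_echo(message):
        acc = ""
        for word in message.split():
            acc += word + " "
            time.sleep(0.1)
            yield acc  # each yield repaints the output component

    with gr.Blocks() as demo:
        inp = gr.Textbox()
        out = gr.Textbox()
        gr.Button("Echo").click(fn=slow_echo, inputs=inp, outputs=out)

    demo.queue().launch()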
@@ -348,7 +359,6 @@ def format_names(s):
 def list_format_scores(score_dict, detection_threshold):
     """Format the detection metrics into a gradio dataframe input format"""
     lst_2d = []
-    # lst_2d.append(["z-score threshold", f"{detection_threshold}"])
     for k,v in score_dict.items():
         if k=='green_fraction':
             lst_2d.append([format_names(k), f"{v:.1%}"])
@@ -366,9 +376,10 @@ def list_format_scores(score_dict, detection_threshold):
     lst_2d.insert(-1,["z-score Threshold", f"{detection_threshold}"])
     return lst_2d
 
-def detect(input_text, args, tokenizer, device=None):
+def detect(input_text, args, tokenizer, device=None, return_green_token_mask=True):
     """Instantiate the WatermarkDetection object and call detect on
     the input text returning the scores and outcome of the test"""
+
     print(f"Detecting with {args}")
     print(f"Detection Tokenizer: {type(tokenizer)}")
 
@@ -381,14 +392,14 @@
                                          normalizers=args.normalizers,
                                          ignore_repeated_bigrams=args.ignore_repeated_bigrams,
                                          select_green_tokens=args.select_green_tokens)
-    # if len(input_text)-1 > watermark_detector.min_prefix_len:
     error = False
+    green_token_mask = None
     if input_text == "":
         error = True
     else:
         try:
-            score_dict = watermark_detector.detect(input_text)
-            # output = str_format_scores(score_dict, watermark_detector.z_threshold)
+            score_dict = watermark_detector.detect(input_text, return_green_token_mask=return_green_token_mask)
+            green_token_mask = score_dict.pop("green_token_mask", None)
             output = list_format_scores(score_dict, watermark_detector.z_threshold)
         except ValueError as e:
             print(e)
@@ -396,16 +407,41 @@
     if error:
         output = [["Error","string too short to compute metrics"]]
         output += [["",""] for _ in range(6)]
-    return output, args, tokenizer
+
+    html_output = ""
+    if green_token_mask is not None:
+        # hack bc we need a fast tokenizer with charspan support
+        if "opt" in args.model_name_or_path:
+            tokenizer = OPT_TOKENIZER.from_pretrained(args.model_name_or_path)
+
+        tokens = tokenizer(input_text)
+        if tokens["input_ids"][0] == tokenizer.bos_token_id:
+            tokens["input_ids"] = tokens["input_ids"][1:] # ignore attention mask
+        skip = watermark_detector.min_prefix_len
+        charspans = [tokens.token_to_chars(i) for i in range(skip,len(tokens["input_ids"]))]
+        charspans = [cs for cs in charspans if cs is not None] # remove the special token spans
+
+        if len(charspans) != len(green_token_mask): breakpoint()
+        assert len(charspans) == len(green_token_mask)
+
+        tags = [(f'<span class="green">{input_text[cs.start:cs.end]}</span>' if m else f'<span class="red">{input_text[cs.start:cs.end]}</span>') for cs, m in zip(charspans, green_token_mask)]
+        html_output = f'<p>{" ".join(tags)}</p>'
+
+    return output, args, tokenizer, html_output
 
 def run_gradio(args, model=None, device=None, tokenizer=None):
     """Define and launch the gradio demo interface"""
-    # generate_partial = partial(generate, model=model, device=device, tokenizer=tokenizer)
-    # detect_partial = partial(detect, device=device, tokenizer=tokenizer)
+    check_prompt_partial = partial(check_prompt, model=model, device=device)
     generate_partial = partial(generate, model=model, device=device)
     detect_partial = partial(detect, device=device)
 
-    with gr.Blocks() as demo:
+
+    css = """
+    .green { color: black!important;line-height:1.9em; padding: 0.2em 0.2em; background: #ccffcc; border-radius:0.5rem;}
+    .red { color: black!important;line-height:1.9em; padding: 0.2em 0.2em; background: #ffad99; border-radius:0.5rem;}
+    """
+
+    with gr.Blocks(css=css) as demo:
         # Top section, greeting and instructions
         with gr.Row():
             with gr.Column(scale=9):
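
Note: the highlighting path above depends on character-span lookups, which only "fast" (Rust-backed) tokenizers support; that is why the OPT_TOKENIZER fallback is needed. A small sketch of the span API, with an illustrative model choice:

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")  # any fast tokenizer works
    text = "The quick brown fox"
    enc = tok(text)
    for i in range(len(enc["input_ids"])):
        span = enc.token_to_chars(i)  # CharSpan(start, end), or None for special tokens
        if span is not None:
            print(i, repr(text[span.start:span.end]))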
@@ -422,7 +458,6 @@ def run_gradio(args, model=None, device=None, tokenizer=None):
                 [![](https://badgen.net/badge/icon/GitHub?icon=github&label)](https://github.com/jwkirchenbauer/lm-watermarking)
                 """
                 )
-            # gr.Markdown(f"Language model: {args.model_name_or_path} {'(float16 mode)' if args.load_fp16 else ''}")
             # if model_name_or_path at startup not one of the API models then add to dropdown
             all_models = sorted(list(set(list(API_MODEL_MAP.keys())+[args.model_name_or_path])))
             model_selector = gr.Dropdown(
@@ -475,8 +510,10 @@ def run_gradio(args, model=None, device=None, tokenizer=None):
                 """
                 **[Generate & Detect]**: The first tab shows that the watermark can be embedded with
                 negligible impact on text quality. You can try any prompt and compare the quality of
-                normal text (*Output Without Watermark*) to the watermarked text (*Output With Watermark*) below it.
-                Metrics on the right show that the watermark can be reliably detected.
+                normal text (*Output Without Watermark*) to the watermarked text (*Output With Watermark*) below it.
+                You can also "see" the watermark by looking at the **Highlighted** tab where the tokens are
+                colored green or red depending on which list they are in.
+                Metrics on the right show that the watermark can be reliably detected given a reasonably small number of tokens (25-50).
                 Detection is very efficient and does not use the language model or its parameters.
 
                 **[Detector Only]**: You can also copy-paste the watermarked text (or any other text)
@@ -495,7 +532,6 @@ def run_gradio(args, model=None, device=None, tokenizer=None):
                 """
                 )
 
-
         with gr.Tab("Generate & Detect"):
 
             with gr.Row():
@@ -504,15 +540,19 @@ def run_gradio(args, model=None, device=None, tokenizer=None):
                 generate_btn = gr.Button("Generate")
             with gr.Row():
                 with gr.Column(scale=2):
-                    output_without_watermark = gr.Textbox(label="Output Without Watermark", interactive=False,lines=14,max_lines=14)
+                    with gr.Tab("Output Without Watermark (Raw Text)"):
+                        output_without_watermark = gr.Textbox(interactive=False,lines=14,max_lines=14)
+                    with gr.Tab("Highlighted"):
+                        html_without_watermark = gr.HTML(elem_id="html-without-watermark")
                 with gr.Column(scale=1):
-                    # without_watermark_detection_result = gr.Textbox(label="Detection Result", interactive=False,lines=14,max_lines=14)
                     without_watermark_detection_result = gr.Dataframe(headers=["Metric", "Value"], interactive=False,row_count=7,col_count=2)
             with gr.Row():
                 with gr.Column(scale=2):
-                    output_with_watermark = gr.Textbox(label="Output With Watermark", interactive=False,lines=14,max_lines=14)
+                    with gr.Tab("Output With Watermark (Raw Text)"):
+                        output_with_watermark = gr.Textbox(interactive=False,lines=14,max_lines=14)
+                    with gr.Tab("Highlighted"):
+                        html_with_watermark = gr.HTML(elem_id="html-with-watermark")
                 with gr.Column(scale=1):
-                    # with_watermark_detection_result = gr.Textbox(label="Detection Result", interactive=False,lines=14,max_lines=14)
                     with_watermark_detection_result = gr.Dataframe(headers=["Metric", "Value"],interactive=False,row_count=7,col_count=2)
 
             redecoded_input = gr.Textbox(visible=False)
@@ -528,7 +568,6 @@ def run_gradio(args, model=None, device=None, tokenizer=None):
                 with gr.Column(scale=2):
                     detection_input = gr.Textbox(label="Text to Analyze", interactive=True,lines=14,max_lines=14)
                 with gr.Column(scale=1):
-                    # detection_result = gr.Textbox(label="Detection Result", interactive=False,lines=14,max_lines=14)
                     detection_result = gr.Dataframe(headers=["Metric", "Value"], interactive=False,row_count=7,col_count=2)
             with gr.Row():
                 detect_btn = gr.Button("Detect")
@@ -562,7 +601,6 @@ def run_gradio(args, model=None, device=None, tokenizer=None):
                     ignore_repeated_bigrams = gr.Checkbox(label="Ignore Bigram Repeats")
                 with gr.Row():
                     normalizers = gr.CheckboxGroup(label="Normalizations", choices=["unicode", "homoglyphs", "truecase"], value=args.normalizers)
-            # with gr.Accordion("Actual submitted parameters:",open=False):
            with gr.Row():
                 gr.Markdown(f"_Note: sliders don't always update perfectly. Clicking on the bar or using the number window to the right can help. Window below shows the current settings._")
             with gr.Row():
@@ -657,15 +695,14 @@ def run_gradio(args, model=None, device=None, tokenizer=None):
             <p/>
             """)
 
-        # Register main generation tab click, outputing generations as well as a the encoded+redecoded+potentially truncated prompt and flag
-        generate_btn.click(fn=generate_partial, inputs=[prompt,session_args,session_tokenizer], outputs=[redecoded_input, truncation_warning, output_without_watermark, output_with_watermark,session_args,session_tokenizer])
+        # Register main generation tab click, outputting generations as well as the encoded+redecoded+potentially truncated prompt and flag, then call detection
+        generate_btn.click(fn=check_prompt_partial, inputs=[prompt,session_args,session_tokenizer], outputs=[redecoded_input, truncation_warning, session_args]).success(
+            fn=generate_partial, inputs=[redecoded_input,session_args,session_tokenizer], outputs=[output_without_watermark, output_with_watermark]).success(
+            fn=detect_partial, inputs=[output_without_watermark,session_args,session_tokenizer], outputs=[without_watermark_detection_result,session_args,session_tokenizer,html_without_watermark]).success(
+            fn=detect_partial, inputs=[output_with_watermark,session_args,session_tokenizer], outputs=[with_watermark_detection_result,session_args,session_tokenizer,html_with_watermark])
         # Show truncated version of prompt if truncation occurred
         redecoded_input.change(fn=truncate_prompt, inputs=[redecoded_input,truncation_warning,prompt,session_args], outputs=[prompt,session_args])
-        # Call detection when the outputs (of the generate function) are updated
-        output_without_watermark.change(fn=detect_partial, inputs=[output_without_watermark,session_args,session_tokenizer], outputs=[without_watermark_detection_result,session_args,session_tokenizer])
-        output_with_watermark.change(fn=detect_partial, inputs=[output_with_watermark,session_args,session_tokenizer], outputs=[with_watermark_detection_result,session_args,session_tokenizer])
         # Register main detection tab click
-        # detect_btn.click(fn=detect_partial, inputs=[detection_input,session_args,session_tokenizer], outputs=[detection_result, session_args,session_tokenizer])
         detect_btn.click(fn=detect_partial, inputs=[detection_input,session_args,session_tokenizer], outputs=[detection_result, session_args,session_tokenizer], api_name="detection")
 
         # State management logic
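
Note: the rewired click handler chains check_prompt, generate, and the two detect calls with .success(), which runs each follow-up step only if the previous one finished without raising; this chaining is what motivates the gradio>=3.21.0 pin in requirements.txt. A minimal sketch of the pattern, with illustrative names:

    import gradio as gr

    def validate(text):
        if not text:
            raise gr.Error("empty input")  # aborts the chain
        return text

    with gr.Blocks() as demo:
        inp, mid, out = gr.Textbox(), gr.Textbox(), gr.Textbox()
        btn = gr.Button("Run")
        # the second step runs only if validate did not raise
        btn.click(fn=validate, inputs=inp, outputs=mid).success(
            fn=lambda s: s.upper(), inputs=mid, outputs=out)

    demo.launch()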
@@ -727,7 +764,11 @@ def run_gradio(args, model=None, device=None, tokenizer=None):
         def update_normalizers(session_state, value): session_state.normalizers = value; return session_state
         def update_seed_separately(session_state, value): session_state.seed_separately = value; return session_state
         def update_select_green_tokens(session_state, value): session_state.select_green_tokens = value; return session_state
-        def update_tokenizer(model_name_or_path): return AutoTokenizer.from_pretrained(model_name_or_path)
+        def update_tokenizer(model_name_or_path):
+            # if model_name_or_path == ALPACA_MODEL_NAME:
+            #     return ALPACA_MODEL_TOKENIZER.from_pretrained(ALPACA_TOKENIZER_PATH)
+            # else:
+            return AutoTokenizer.from_pretrained(model_name_or_path)
         # registering callbacks for toggling the visibilty of certain parameters based on the values of others
         decoding.change(toggle_sampling_vis,inputs=[decoding], outputs=[sampling_temp])
         decoding.change(toggle_sampling_vis,inputs=[decoding], outputs=[generation_seed])
@@ -829,19 +870,6 @@ def main(args):
         "on their body and head. The diamondback terrapin has large webbed "
         "feet.[9] The species is"
     )
-
-    # teaser example
-    # input_text = (
-    # "In this work, we study watermarking of language model output. "
-    # "A watermark is a hidden pattern in text that is imperceptible to humans, "
-    # "while making the text algorithmically identifiable as synthetic. "
-    # "We propose an efficient watermark that makes synthetic text detectable "
-    # "from short spans of tokens (as few as 25 words), while false-positives "
-    # "(where human text is marked as machine-generated) are statistically improbable. "
-    # "The watermark detection algorithm can be made public, enabling third parties "
-    # "(e.g., social media platforms) to run it themselves, or it can be kept private "
-    # "and run behind an API. We seek a watermark with the following properties:\n"
-    # )
 
     args.default_prompt = input_text
@@ -854,19 +882,28 @@ def main(args):
     print("Prompt:")
     print(input_text)
 
-    _, _, decoded_output_without_watermark, decoded_output_with_watermark, _, _ = generate(input_text,
-                                                                                           args,
-                                                                                           model=model,
-                                                                                           device=device,
-                                                                                           tokenizer=tokenizer)
+    # a generator that yields (without_watermark, with_watermark) pairs
+    generator_outputs = generate(input_text,
+                                 args,
+                                 model=model,
+                                 device=device,
+                                 tokenizer=tokenizer)
+    # we need to iterate over it,
+    # but we only want the last output in this case
+    for out in generator_outputs:
+        decoded_output_without_watermark = out[0]
+        decoded_output_with_watermark = out[1]
+
     without_watermark_detection_result = detect(decoded_output_without_watermark,
                                                 args,
                                                 device=device,
-                                                tokenizer=tokenizer)
+                                                tokenizer=tokenizer,
+                                                return_green_token_mask=False)
     with_watermark_detection_result = detect(decoded_output_with_watermark,
                                              args,
                                              device=device,
-                                             tokenizer=tokenizer)
+                                             tokenizer=tokenizer,
+                                             return_green_token_mask=False)
 
     print("#"*term_width)
     print("Output without watermark:")
requirements.txt CHANGED
@@ -1,9 +1,9 @@
 spacy
-gradio
+gradio>=3.21.0
 nltk
 scipy
 torch
 transformers
 tokenizers
 accelerate
-text-generation>=0.3.0
+text-generation>=0.3.1