jwkirchenbauer committed
Commit 7d29596
1 parent: dafc0b4

more polished interface

app.py CHANGED
@@ -19,9 +19,14 @@ args = Namespace()
 
 arg_dict = {
     'run_gradio': True,
-    'demo_public': False,
-    # 'model_name_or_path': 'facebook/opt-125m',
-    'model_name_or_path': 'facebook/opt-2.7b',
+    # 'demo_public': False,
+    'demo_public': True,
+    'model_name_or_path': 'facebook/opt-125m',
+    # 'model_name_or_path': 'facebook/opt-1.3b',
+    # 'model_name_or_path': 'facebook/opt-2.7b',
+    # 'model_name_or_path': 'facebook/opt-6.7b',
+    # 'model_name_or_path': 'facebook/opt-13b',
+    # 'model_name_or_path': 'facebook/opt-30b',
     'prompt_max_length': None,
     'max_new_tokens': 200,
     'generation_seed': 123,
@@ -36,6 +41,7 @@ arg_dict = {
     'ignore_repeated_bigrams': False,
     'detection_z_threshold': 4.0,
     'select_green_tokens': True,
+    # 'skip_model_load': True,
     'skip_model_load': False,
     'seed_separately': True,
 }
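The diff flips the Space to public and defaults to the smallest OPT checkpoint, keeping the larger ones as one-line comment toggles. The `arg_dict` only takes effect once it is folded into the `args` namespace named in the hunk header; a minimal sketch of that wiring, assuming the surrounding app.py code applies the dict roughly like this (the update call itself is outside the hunk and is an assumption here):

```python
from argparse import Namespace

args = Namespace()

arg_dict = {
    'run_gradio': True,
    'demo_public': True,
    'model_name_or_path': 'facebook/opt-125m',
    'max_new_tokens': 200,
}

# Assumed wiring: fold the plain dict into the Namespace so the rest of
# the app reads settings as attributes (args.model_name_or_path, ...).
args.__dict__.update(arg_dict)
assert args.model_name_or_path == 'facebook/opt-125m'
```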
demo_watermark.py CHANGED
@@ -250,6 +250,41 @@ def generate(prompt, args, model=None, device=None, tokenizer=None):
                                                     args)
             # decoded_output_with_watermark)
 
+def format_names(s):
+    s=s.replace("num_tokens_scored","Tokens Counted (T)")
+    s=s.replace("num_green_tokens","# Tokens in Greenlist")
+    s=s.replace("green_fraction","Fraction of T in Greenlist")
+    s=s.replace("z_score","z-score")
+    s=s.replace("p_value","p value")
+    return s
+# def str_format_scores(score_dict, detection_threshold):
+#     output_str = f"@ z-score threshold={detection_threshold}:\n\n"
+#     for k,v in score_dict.items():
+#         if k=='green_fraction':
+#             output_str+=f"{format_names(k)}={v:.1%}"
+#         elif k=='confidence':
+#             output_str+=f"{format_names(k)}={v:.3%}"
+#         elif isinstance(v, float):
+#             output_str+=f"{format_names(k)}={v:.3g}"
+#         else:
+#             output_str += v
+#     return output_str
+def list_format_scores(score_dict, detection_threshold):
+    lst_2d = []
+    lst_2d.append(["z-score threshold", f"{detection_threshold}"])
+    for k,v in score_dict.items():
+        if k=='green_fraction':
+            lst_2d.append([format_names(k), f"{v:.1%}"])
+        elif k=='confidence':
+            lst_2d.append([format_names(k), f"{v:.3%}"])
+        elif isinstance(v, float):
+            lst_2d.append([format_names(k), f"{v:.3g}"])
+        elif isinstance(v, bool):
+            lst_2d.append([format_names(k), ("Watermarked" if v else "Human/Unwatermarked")])
+        else:
+            lst_2d.append([format_names(k), f"{v}"])
+    return lst_2d
+
 def detect(input_text, args, device=None, tokenizer=None):
     watermark_detector = WatermarkDetector(vocab=list(tokenizer.get_vocab().values()),
                                            gamma=args.gamma,
@@ -262,11 +297,13 @@ def detect(input_text, args, device=None, tokenizer=None):
                                            select_green_tokens=args.select_green_tokens)
     if len(input_text)-1 > watermark_detector.min_prefix_len:
         score_dict = watermark_detector.detect(input_text)
-        output_str = (f"Detection result @ {watermark_detector.z_threshold}:\n"
-                      f"{score_dict}")
+        # output = str_format_scores(score_dict, watermark_detector.z_threshold)
+        output = list_format_scores(score_dict, watermark_detector.z_threshold)
     else:
-        output_str = (f"Error: string not long enough to compute watermark presence.")
-        return output_str, args
+        # output = (f"Error: string not long enough to compute watermark presence.")
+        output = [["Error","string too short to compute metrics"]]
+        output += [["",""] for _ in range(6)]
+    return output, args
 
 def run_gradio(args, model=None, device=None, tokenizer=None):
 
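A quick usage sketch of the new `list_format_scores` helper, given the `format_names`/`list_format_scores` definitions above. The dictionary below is illustrative (values chosen to be consistent with gamma=0.25); the exact key set emitted by `WatermarkDetector.detect` may differ:

```python
# Illustrative input; real score dicts come from WatermarkDetector.detect().
score_dict = {
    "num_tokens_scored": 198,   # int -> falls through to the generic branch
    "num_green_tokens": 109,
    "green_fraction": 109/198,  # rendered as a percentage, "55.1%"
    "z_score": 9.77,            # float -> "%.3g" formatting
    "p_value": 8e-23,
    "prediction": True,         # bool -> "Watermarked" / "Human/Unwatermarked"
}

rows = list_format_scores(score_dict, detection_threshold=4.0)
# rows[0] == ["z-score threshold", "4.0"]
# rows[3] == ["Fraction of T in Greenlist", "55.1%"]
# Six metric rows plus the threshold row give exactly 7 rows,
# matching the gr.Dataframe(row_count=7) components introduced below.
```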
 
 
@@ -276,33 +313,41 @@ def run_gradio(args, model=None, device=None, tokenizer=None):
     with gr.Blocks() as demo:
 
         # Top section, greeting and instructions
-        gr.Markdown("## Demo for ['A Watermark for Large Language Models'](https://arxiv.org/abs/2301.10226)")
-        gr.HTML("""
-                <p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings.
-                <br/>
-                <a href="https://huggingface.co/spaces/tomg-group-umd/lm-watermarking?duplicate=true">
-                <img style="margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
-                <p/>
-                """)
-        # Construct state for parameters, define updates and toggles, and register event listeners
+        gr.Markdown("## 💧 [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) 🔍")
+        gr.Markdown("[jwkirchenbauer/lm-watermarking![](https://badgen.net/badge/icon/GitHub?icon=github&label)](https://github.com/jwkirchenbauer/lm-watermarking)")
+
+        with gr.Accordion("A note on model capability",open=False):
+            gr.Markdown(
+                """
+                The models that can be used in this demo are limited to those that are open source and that fit on a single commodity GPU. In particular, there are few open-source models above 10B parameters, and far fewer trained with instruction finetuning or RLHF, that we can use.
+
+                Therefore, the model, in both its un-watermarked (normal) and watermarked state, is generally not able to respond as well to the kinds of prompts that a 100B+ instruction- and RLHF-tuned model such as ChatGPT, Claude, or Bard can.
+
+                We suggest you try prompts that give the model a few sentences and then allow it to 'continue' the prompt, as these weaker models are more capable in this simpler language modeling setting.
+                """
+            )
+
+        # Construct state for parameters, define updates and toggles
         session_args = gr.State(value=args)
 
-        with gr.Tab("Generation"):
+        with gr.Tab("Generate and Detect"):
 
             with gr.Row():
-                prompt = gr.Textbox(label=f"Prompt", interactive=True)
+                prompt = gr.Textbox(label=f"Prompt", interactive=True,lines=12,max_lines=12)
             with gr.Row():
                 generate_btn = gr.Button("Generate")
             with gr.Row():
                 with gr.Column(scale=2):
-                    output_without_watermark = gr.Textbox(label="Output Without Watermark", interactive=False)
+                    output_without_watermark = gr.Textbox(label="Output Without Watermark", interactive=False,lines=12,max_lines=12)
                 with gr.Column(scale=1):
-                    without_watermark_detection_result = gr.Textbox(label="Detection Result", interactive=False)
+                    # without_watermark_detection_result = gr.Textbox(label="Detection Result", interactive=False,lines=12,max_lines=12)
+                    without_watermark_detection_result = gr.Dataframe(headers=["Metric", "Value"], interactive=False,row_count=7,col_count=2)
             with gr.Row():
                 with gr.Column(scale=2):
-                    output_with_watermark = gr.Textbox(label="Output With Watermark", interactive=False)
+                    output_with_watermark = gr.Textbox(label="Output With Watermark", interactive=False,lines=12,max_lines=12)
                 with gr.Column(scale=1):
-                    with_watermark_detection_result = gr.Textbox(label="Detection Result", interactive=False)
+                    # with_watermark_detection_result = gr.Textbox(label="Detection Result", interactive=False,lines=12,max_lines=12)
+                    with_watermark_detection_result = gr.Dataframe(headers=["Metric", "Value"],interactive=False,row_count=7,col_count=2)
 
             redecoded_input = gr.Textbox(visible=False)
             truncation_warning = gr.Number(visible=False)
@@ -311,24 +356,16 @@ def run_gradio(args, model=None, device=None, tokenizer=None):
                 return redecoded_input + f"\n\n[Prompt was truncated before generation due to length...]", args
             else:
                 return orig_prompt, args
-
-        generate_btn.click(fn=generate_partial, inputs=[prompt,session_args], outputs=[redecoded_input, truncation_warning, output_without_watermark, output_with_watermark,session_args])
-
-        # Show truncated version of prompt if truncation occurred
-        redecoded_input.change(fn=truncate_prompt, inputs=[redecoded_input,truncation_warning,prompt,session_args], outputs=[prompt,session_args])
-
-        # Call detection when the outputs of the generate function are updated.
-        output_without_watermark.change(fn=detect_partial, inputs=[output_without_watermark,session_args], outputs=[without_watermark_detection_result,session_args])
-        output_with_watermark.change(fn=detect_partial, inputs=[output_with_watermark,session_args], outputs=[with_watermark_detection_result,session_args])
 
         with gr.Tab("Detector Only"):
            with gr.Row():
-                detection_input = gr.Textbox(label="Text to Analyze", interactive=True)
-            with gr.Row():
-                detect_btn = gr.Button("Detect")
+                with gr.Column(scale=2):
+                    detection_input = gr.Textbox(label="Text to Analyze", interactive=True,lines=12,max_lines=12)
+                with gr.Column(scale=1):
+                    # detection_result = gr.Textbox(label="Detection Result", interactive=False,lines=12,max_lines=12)
+                    detection_result = gr.Dataframe(headers=["Metric", "Value"], interactive=False,row_count=7,col_count=2)
            with gr.Row():
-                detection_result = gr.Textbox(label="Detection Result", interactive=False)
-                detect_btn.click(fn=detect_partial, inputs=[detection_input,session_args], outputs=[detection_result, session_args])
+                detect_btn = gr.Button("Detect")
 
         # Parameter selection group
         with gr.Accordion("Advanced Settings",open=False):
 
@@ -347,18 +384,23 @@ def run_gradio(args, model=None, device=None, tokenizer=None):
                     max_new_tokens = gr.Slider(label="Max Generated Tokens", minimum=10, maximum=1000, step=10, value=args.max_new_tokens)
 
                 with gr.Column(scale=1):
-                    gr.Markdown(f"#### Watermarking Parameters")
+                    gr.Markdown(f"#### Watermark Parameters")
                     with gr.Row():
                         gamma = gr.Slider(label="gamma",minimum=0.1, maximum=0.9, step=0.05, value=args.gamma)
                     with gr.Row():
                         delta = gr.Slider(label="delta",minimum=0.0, maximum=10.0, step=0.1, value=args.delta)
+                    gr.Markdown(f"#### Detector Parameters")
+                    with gr.Row():
+                        detection_z_threshold = gr.Slider(label="z-score threshold",minimum=0.0, maximum=10.0, step=0.1, value=args.detection_z_threshold)
                     with gr.Row():
                         ignore_repeated_bigrams = gr.Checkbox(label="Ignore Bigram Repeats")
                     with gr.Row():
                         normalizers = gr.CheckboxGroup(label="Normalizations", choices=["unicode", "homoglyphs", "truecase"], value=args.normalizers)
-            gr.Markdown(f"_Note: sliders don't always update perfectly. Clicking on the bar or using the number window to the right can help._")
-            with gr.Accordion("Actual submitted parameters:",open=False):
-                current_parameters = gr.Textbox(label="submitted parameters", value=args)
+            # with gr.Accordion("Actual submitted parameters:",open=False):
+            with gr.Row():
+                gr.Markdown(f"_Note: sliders don't always update perfectly. Clicking on the bar or using the number window to the right can help. Window below shows the current settings._")
+            with gr.Row():
+                current_parameters = gr.Textbox(label="Current Parameters", value=args)
             with gr.Accordion("Legacy Settings",open=False):
                 with gr.Row():
                     with gr.Column(scale=1):
@@ -366,23 +408,31 @@ def run_gradio(args, model=None, device=None, tokenizer=None):
                 with gr.Column(scale=1):
                     select_green_tokens = gr.Checkbox(label="Select 'greenlist' from partition", value=args.select_green_tokens)
 
+        gr.HTML("""
+                <p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings.
+                <br/>
+                <a href="https://huggingface.co/spaces/tomg-group-umd/lm-watermarking?duplicate=true">
+                <img style="margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
+                <p/>
+                """)
 
-        with gr.Accordion("A note on model capability",open=False):
-            gr.Markdown(
-                """
-                The models that can be used in this demo are limited to those that are open source as well as fit on a single commodity GPU. In particular, there are few models above 10B parameters and way fewer trained using both Instruction finetuning or RLHF that are open source that we can use.
-
-                Therefore, the model, in both it's un-watermarked (normal) and watermarked state, is not generally able to respond well to the kinds of prompts that a 100B+ Instruction and RLHF tuned model such as ChatGPT, Claude, or Bard is.
-
-                We suggest you try prompts that give the model a few sentences and then allow it to 'continue' the prompt, as these weaker models are more capable in this simpler language modeling setting.
-                """
-            )
-
-        # State manager logic
+        # Register main generation tab click, outputting generations as well as the encoded+redecoded+potentially truncated prompt and flag
+        generate_btn.click(fn=generate_partial, inputs=[prompt,session_args], outputs=[redecoded_input, truncation_warning, output_without_watermark, output_with_watermark,session_args])
+        # Show truncated version of prompt if truncation occurred
+        redecoded_input.change(fn=truncate_prompt, inputs=[redecoded_input,truncation_warning,prompt,session_args], outputs=[prompt,session_args])
+        # Call detection when the outputs (of the generate function) are updated
+        output_without_watermark.change(fn=detect_partial, inputs=[output_without_watermark,session_args], outputs=[without_watermark_detection_result,session_args])
+        output_with_watermark.change(fn=detect_partial, inputs=[output_with_watermark,session_args], outputs=[with_watermark_detection_result,session_args])
+        # Register main detection tab click
+        detect_btn.click(fn=detect_partial, inputs=[detection_input,session_args], outputs=[detection_result, session_args])
+
+        # State management logic
+        # update callbacks that change the state dict
         def update_sampling_temp(session_state, value): session_state.sampling_temp = float(value); return session_state
         def update_generation_seed(session_state, value): session_state.generation_seed = int(value); return session_state
         def update_gamma(session_state, value): session_state.gamma = float(value); return session_state
         def update_delta(session_state, value): session_state.delta = float(value); return session_state
+        def update_detection_z_threshold(session_state, value): session_state.detection_z_threshold = float(value); return session_state
         def update_decoding(session_state, value):
             if value == "multinomial":
                 session_state.use_sampling = True
@@ -405,11 +455,11 @@ def run_gradio(args, model=None, device=None, tokenizer=None):
         def update_normalizers(session_state, value): session_state.normalizers = value; return session_state
         def update_seed_separately(session_state, value): session_state.seed_separately = value; return session_state
         def update_select_green_tokens(session_state, value): session_state.select_green_tokens = value; return session_state
-
+        # registering callbacks for toggling the visibility of certain parameters
         decoding.change(toggle_sampling_vis,inputs=[decoding], outputs=[sampling_temp])
         decoding.change(toggle_sampling_vis,inputs=[decoding], outputs=[generation_seed])
         decoding.change(toggle_sampling_vis_inv,inputs=[decoding], outputs=[n_beams])
-
+        # registering all state update callbacks
         decoding.change(update_decoding,inputs=[session_args, decoding], outputs=[session_args])
         sampling_temp.change(update_sampling_temp,inputs=[session_args, sampling_temp], outputs=[session_args])
         generation_seed.change(update_generation_seed,inputs=[session_args, generation_seed], outputs=[session_args])
@@ -417,17 +467,36 @@ def run_gradio(args, model=None, device=None, tokenizer=None):
         max_new_tokens.change(update_max_new_tokens,inputs=[session_args, max_new_tokens], outputs=[session_args])
         gamma.change(update_gamma,inputs=[session_args, gamma], outputs=[session_args])
         delta.change(update_delta,inputs=[session_args, delta], outputs=[session_args])
+        detection_z_threshold.change(update_detection_z_threshold,inputs=[session_args, detection_z_threshold], outputs=[session_args])
         ignore_repeated_bigrams.change(update_ignore_repeated_bigrams,inputs=[session_args, ignore_repeated_bigrams], outputs=[session_args])
         normalizers.change(update_normalizers,inputs=[session_args, normalizers], outputs=[session_args])
         seed_separately.change(update_seed_separately,inputs=[session_args, seed_separately], outputs=[session_args])
         select_green_tokens.change(update_select_green_tokens,inputs=[session_args, select_green_tokens], outputs=[session_args])
-
+        # register additional callback on button clicks that updates the shown parameters window
         generate_btn.click(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
         detect_btn.click(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
-
-        # When the parameters change, also fire detection, since some detection params dont change the model output.
-        current_parameters.change(fn=detect_partial, inputs=[output_without_watermark,session_args], outputs=[without_watermark_detection_result,session_args])
-        current_parameters.change(fn=detect_partial, inputs=[output_with_watermark,session_args], outputs=[with_watermark_detection_result,session_args])
+        # When the parameters change, display the update and fire detection, since some detection params don't change the model output.
+        gamma.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
+        gamma.change(fn=detect_partial, inputs=[output_without_watermark,session_args], outputs=[without_watermark_detection_result,session_args])
+        gamma.change(fn=detect_partial, inputs=[output_with_watermark,session_args], outputs=[with_watermark_detection_result,session_args])
+        gamma.change(fn=detect_partial, inputs=[detection_input,session_args], outputs=[detection_input,session_args])
+        detection_z_threshold.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
+        detection_z_threshold.change(fn=detect_partial, inputs=[output_without_watermark,session_args], outputs=[without_watermark_detection_result,session_args])
+        detection_z_threshold.change(fn=detect_partial, inputs=[output_with_watermark,session_args], outputs=[with_watermark_detection_result,session_args])
+        detection_z_threshold.change(fn=detect_partial, inputs=[detection_input,session_args], outputs=[detection_input,session_args])
+        ignore_repeated_bigrams.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
+        ignore_repeated_bigrams.change(fn=detect_partial, inputs=[output_without_watermark,session_args], outputs=[without_watermark_detection_result,session_args])
+        ignore_repeated_bigrams.change(fn=detect_partial, inputs=[output_with_watermark,session_args], outputs=[with_watermark_detection_result,session_args])
+        ignore_repeated_bigrams.change(fn=detect_partial, inputs=[detection_input,session_args], outputs=[detection_input,session_args])
+        normalizers.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
+        normalizers.change(fn=detect_partial, inputs=[output_without_watermark,session_args], outputs=[without_watermark_detection_result,session_args])
+        normalizers.change(fn=detect_partial, inputs=[output_with_watermark,session_args], outputs=[with_watermark_detection_result,session_args])
+        normalizers.change(fn=detect_partial, inputs=[detection_input,session_args], outputs=[detection_input,session_args])
+        select_green_tokens.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
+        select_green_tokens.change(fn=detect_partial, inputs=[output_without_watermark,session_args], outputs=[without_watermark_detection_result,session_args])
+        select_green_tokens.change(fn=detect_partial, inputs=[output_with_watermark,session_args], outputs=[with_watermark_detection_result,session_args])
+        select_green_tokens.change(fn=detect_partial, inputs=[detection_input,session_args], outputs=[detection_input,session_args])
+
 
     demo.queue(concurrency_count=3)
 
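The block above repeats one pattern many times: a `gr.State` holds the argument namespace, every control's `.change` event pipes `(state, value)` through a small updater, and detection-relevant controls additionally re-fire `detect_partial`. A stripped-down sketch of just that loop (toy namespace, not the demo's full wiring):

```python
import gradio as gr
from argparse import Namespace

args = Namespace(gamma=0.25)

with gr.Blocks() as demo:
    session_args = gr.State(value=args)
    gamma = gr.Slider(label="gamma", minimum=0.1, maximum=0.9,
                      step=0.05, value=args.gamma)
    current_parameters = gr.Textbox(label="Current Parameters", value=str(args))

    # Updater mutates the state object and returns it back into gr.State.
    def update_gamma(session_state, value):
        session_state.gamma = float(value)
        return session_state

    gamma.change(update_gamma, inputs=[session_args, gamma],
                 outputs=[session_args])
    # A second .change on the same control displays the refreshed settings;
    # the real demo also chains detect_partial here so the detection tables
    # refresh without regenerating text.
    gamma.change(lambda s: str(s), inputs=[session_args],
                 outputs=[current_parameters])

# demo.launch()
```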
 
homoglyph_data/__init__.py ADDED
@@ -0,0 +1,40 @@
+# This is data for homoglyph finding
+
+"""Original package info:
+
+Homoglyphs
+* Get similar letters
+* Convert string to ASCII letters
+* Detect possible letter languages
+* Detect letter UTF-8 group.
+
+# main package info
+__title__ = 'Homoglyphs'
+__version__ = '2.0.4'
+__author__ = 'Gram Orsinium'
+__license__ = 'MIT'
+
+# License:
+
+MIT License 2019 orsinium <master_fess@mail.ru>
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice (including the next
+paragraph) shall be included in all copies or substantial portions of the
+Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+
+"""
homoglyph_data/categories.json ADDED
The diff for this file is too large to render. See raw diff
 
homoglyph_data/confusables_sept2022.json ADDED
The diff for this file is too large to render. See raw diff
 
homoglyph_data/languages.json ADDED
@@ -0,0 +1,34 @@
+{
+    "ar": "ءآأؤإئابةتثجحخدذرزسشصضطظعغػؼؽؾؿـفقكلمنهوىيًٌٍَُِّ",
+    "be": "ʼЁІЎАБВГДЕЖЗЙКЛМНОПРСТУФХЦЧШЫЬЭЮЯабвгдежзйклмнопрстуфхцчшыьэюяёіў",
+    "bg": "АБВГДЕЖЗИЙКЛМНОПРСТУФХЦЧШЩЪЬЮЯабвгдежзийклмнопрстуфхцчшщъьюя",
+    "ca": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzÀÈÉÍÏÒÓÚÜÇàèéíïòóúüç·",
+    "cz": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzÁÉÍÓÚÝáéíóúýČčĎďĚěŇňŘřŠšŤťŮůŽž",
+    "da": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzÅÆØåæø",
+    "de": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzÄÖÜßäöü",
+    "el": "ΪΫΆΈΉΊΌΎΏΑΒΓΔΕΖΗΘΙΚΛΜΝΞΟΠΡΣΤΥΦΧΨΩΐΰϊϋάέήίαβγδεζηθικλμνξοπρςστυφχψωόύώ",
+    "en": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz",
+    "eo": "ABCDEFGHIJKLMNOPRSTUVZabcdefghijklmnoprstuvzĈĉĜĝĤĥĴĵŜŝŬŭ",
+    "es": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzÁÉÍÑÓÚÜáéíñóúü",
+    "et": "ABDEGHIJKLMNOPRSTUVabdeghijklmnoprstuvÄÕÖÜäõöü",
+    "fi": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzÄÅÖäåöŠšŽž",
+    "fr": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzÀÂÇÈÉÊÎÏÙÛàâçèéêîïùûŒœ",
+    "he": "אבגדהוזחטיךכלםמןנסעףפץצקרשתװױײ",
+    "hr": "ABCDEFGHIJKLMNOPRSTUVZabcdefghijklmnoprstuvzĆćČčĐ𩹮ž",
+    "hu": "ABCDEFGHIJKLMNOPRSTUVZabcdefghijklmnoprstuvzÁÉÍÓÖÚÜáéíóöúüŐőŰű",
+    "it": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzÀÈÉÌÒÓÙàèéìòóù",
+    "lt": "ABCDEFGHIJKLMNOPRSTUVYZabcdefghijklmnoprstuvyzĄąČčĖėĘęĮįŠšŪūŲųŽž",
+    "lv": "ABCDEFGHIJKLMNOPRSTUVZabcdefghijklmnoprstuvzĀāČčĒēĢģĪīĶķĻļŅņŠšŪūŽž",
+    "mk": "ЃЅЈЉЊЌЏАБВГДЕЖЗИКЛМНОПРСТУФХЦЧШабвгдежзиклмнопрстуфхцчшѓѕјљњќџ",
+    "nl": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz",
+    "pl": "ABCDEFGHIJKLMNOPRSTUWYZabcdefghijklmnoprstuwyzÓóĄąĆćĘꣳŃńŚśŹźŻż",
+    "pt": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzÀÁÂÃÇÉÊÍÓÔÕÚàáâãçéêíóôõú",
+    "ro": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzÂÎâîĂăȘșȚț",
+    "ru": "ЁАБВГДЕЖЗИЙКЛМНОПРСТУФХЦЧШЩЪЫЬЭЮЯабвгдежзийклмнопрстуфхцчшщъыьэюяё",
+    "sk": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzÁÄÉÍÓÔÚÝáäéíóôúýČčĎďĹ弾ŇňŔ੹ŤťŽž",
+    "sl": "ABCDEFGHIJKLMNOPRSTUVZabcdefghijklmnoprstuvzČ芚Žž",
+    "sr": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzЂЈЉЊЋЏАБВГДЕЖЗИКЛМНОПРСТУФХЦЧШабвгдежзиклмнопрстуфхцчшђјљњћџ",
+    "th": "กขฃคฅฆงจฉชซฌญฎฏฐฑฒณดตถทธนบปผฝพฟภมยรฤลฦวศษสหฬอฮฯะัาำิีึืฺุู฿เแโใไๅๆ็่้๊๋์ํ๎๏๐๑๒๓๔๕๖๗๘๙๚๛",
+    "tr": "ABCDEFGHIJKLMNOPRSTUVYZabcdefghijklmnoprstuvyzÂÇÎÖÛÜâçîöûüĞğİıŞş",
+    "vi": "ABCDEGHIKLMNOPQRSTUVXYabcdeghiklmnopqrstuvxyÂÊÔâêôĂăĐđƠơƯư"
+}
homoglyphs.py CHANGED
@@ -9,10 +9,6 @@ from itertools import product
 import os
 import unicodedata
 
-import homoglyphs_fork as hg
-
-CURRENT_DIR = hg.core.CURRENT_DIR
-
 # Actions if char not in alphabet
 STRATEGY_LOAD = 1  # load category for this char
 STRATEGY_IGNORE = 2  # add char to result
@@ -21,13 +17,17 @@ STRATEGY_REMOVE = 3  # remove char from result
 ASCII_RANGE = range(128)
 
 
+CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
+DATA_LOCATION = os.path.join(CURRENT_DIR, "homoglyph_data")
+
+
 class Categories:
     """
     Work with aliases from ISO 15924.
     https://en.wikipedia.org/wiki/ISO_15924#List_of_codes
     """
 
-    fpath = os.path.join(CURRENT_DIR, "categories.json")
+    fpath = os.path.join(DATA_LOCATION, "categories.json")
 
     @classmethod
     def _get_ranges(cls, categories):
@@ -70,8 +70,9 @@ class Categories:
         # try detect category by unicodedata
         try:
             category = unicodedata.name(char).split()[0]
-        except TypeError:
+        except (TypeError, ValueError):
             # In Python2 unicodedata.name raise error for non-unicode chars
+            # Python3 raise ValueError for non-unicode characters
             pass
         else:
             if category in data["aliases"]:
@@ -91,7 +92,7 @@
 
 
 class Languages:
-    fpath = os.path.join(CURRENT_DIR, "languages.json")
+    fpath = os.path.join(DATA_LOCATION, "languages.json")
 
     @classmethod
     def get_alphabet(cls, languages):
@@ -167,8 +168,7 @@ class Homoglyphs:
     @staticmethod
     def get_table(alphabet):
         table = defaultdict(set)
-        # removed CURRENT_DIR here:
-        with open(os.path.join("confusables_sept2022.json")) as f:
+        with open(os.path.join(DATA_LOCATION, "confusables_sept2022.json")) as f:
             data = json.load(f)
         for char in alphabet:
             if char in data:
@@ -180,8 +180,7 @@ class Homoglyphs:
     @staticmethod
     def get_restricted_table(source_alphabet, target_alphabet):
         table = defaultdict(set)
-        # removed CURRENT_DIR here:
-        with open(os.path.join("confusables_sept2022.json")) as f:
+        with open(os.path.join(DATA_LOCATION, "confusables_sept2022.json")) as f:
             data = json.load(f)
         for char in source_alphabet:
             if char in data:
@@ -244,9 +243,7 @@ class Homoglyphs:
         alt_chars = self._get_char_variants(char)
 
         if ascii:
-            alt_chars = [
-                char for char in alt_chars if ord(char) in self.ascii_range
-            ]
+            alt_chars = [char for char in alt_chars if ord(char) in self.ascii_range]
             if not alt_chars and self.ascii_strategy == STRATEGY_IGNORE:
                 return
requirements.txt CHANGED
@@ -1,4 +1,3 @@
-homoglyphs_fork
 nltk
 scipy
 torch
watermark_processor.py CHANGED
@@ -216,6 +216,8 @@ class WatermarkDetector(WatermarkBase):
             score_dict.update(dict(num_tokens_scored=num_tokens_scored))
         if return_num_green_tokens:
             score_dict.update(dict(num_green_tokens=green_token_count))
+        if return_green_fraction:
+            score_dict.update(dict(green_fraction=(green_token_count / num_tokens_scored)))
         if return_z_score:
             score_dict.update(dict(z_score=self._compute_z_score(green_token_count, num_tokens_scored)))
         if return_p_value:
@@ -223,8 +225,6 @@
             if z_score is None:
                 z_score = self._compute_z_score(green_token_count, num_tokens_scored)
             score_dict.update(dict(p_value=self._compute_p_value(z_score)))
-        if return_green_fraction:
-            score_dict.update(dict(green_fraction=(green_token_count / num_tokens_scored)))
         if return_green_token_mask:
             score_dict.update(dict(green_token_mask=green_token_mask))
 
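Moving the `green_fraction` entry earlier matters because Python dicts preserve insertion order, and `list_format_scores` iterates that order to build the table rows: counts, then fraction, then z-score, then p value. For reference, the z-statistic these entries summarize is the one-proportion test from the paper: under the null hypothesis (no watermark) each scored token lands in the greenlist with probability gamma, so z = (g - gamma*T) / sqrt(gamma*(1 - gamma)*T). A hedged sketch of that computation (it mirrors the paper's formula; the repository's exact `_compute_z_score` implementation may differ in details):

```python
from math import sqrt

from scipy import stats  # scipy is already in requirements.txt

def compute_z_score(green_token_count: int, num_tokens_scored: int, gamma: float) -> float:
    # One-proportion z-test against the null hypothesis that tokens are
    # greenlisted at rate gamma (per the paper's detection procedure).
    expected = gamma * num_tokens_scored
    denom = sqrt(num_tokens_scored * gamma * (1 - gamma))
    return (green_token_count - expected) / denom

def compute_p_value(z: float) -> float:
    # One-sided tail probability of the observed z-score.
    return stats.norm.sf(z)

# Example: 109 of 198 tokens green at gamma=0.25 -> z ~ 9.77, p ~ 8e-23,
# well past the default detection_z_threshold of 4.0.
```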