jianuo commited on
Commit
c94015f
·
1 Parent(s): 7253da8

test recover

Browse files
Files changed (2) hide show
  1. app.py +9 -890
  2. demo_watermark.py +1083 -0
app.py CHANGED
@@ -1,40 +1,26 @@
1
  # 安装好环境
2
  # python app.py即可运行
3
- import spaces
4
- print('import 了 spaces')
5
-
6
  import os
7
- import time
8
- from argparse import Namespace
9
- from functools import partial
10
-
11
- import gradio as gr
12
- import gradio.exceptions
13
- import pandas as pd
14
-
15
- import torch
16
- from requests.exceptions import ReadTimeout
17
- from text_generation import InferenceAPIClient
18
- from transformers import (AutoTokenizer,
19
- AutoModelForCausalLM,
20
- LogitsProcessorList)
21
-
22
- from watermark_processor import WatermarkLogitsProcessor, WatermarkDetector
23
 
24
- # os.environ['HF_ENDPOINT']='https://hf-mirror.com'
25
 
 
26
  args = Namespace()
27
 
28
  arg_dict = {
29
  'run_gradio': True,
30
  'demo_public': False,
31
  'model_name_or_path': './model/Qwen2-0.5B-Instruct',
 
 
 
32
  'prompt_max_length': None,
33
  'max_new_tokens': 500,
34
  'generation_seed': 123,
35
  'use_sampling': True,
36
  'n_beams': 1,
37
  'sampling_temp': 0.7,
 
38
  'seeding_scheme': 'simple_1',
39
  'gamma': 0.5,
40
  'delta': 2.0,
@@ -42,879 +28,12 @@ arg_dict = {
42
  'ignore_repeated_bigrams': False,
43
  'detection_z_threshold': 4.0,
44
  'select_green_tokens': True,
 
45
  'seed_separately': True,
46
  }
47
 
48
  args.__dict__.update(arg_dict)
49
 
50
- # FIXME 所有模型的正确长度
51
-
52
- API_MODEL_MAP = {
53
- # "Qwen/Qwen1.5-0.5B-Chat": {"max_length": 2000, "gamma": 0.5, "delta": 2.0},
54
- # "THUDM/chatglm3-6b": {"max_length": 2048, "gamma": 0.5, "delta": 2.0},
55
- }
56
-
57
- default_trace_table = pd.DataFrame(columns=["编号", "水印内容"])
58
- default_trace_table.loc[0] = (0, "默认用户")
59
- default_trace_table.loc[1] = (1, "张三")
60
- default_trace_table.loc[2] = (2, "李四")
61
-
62
- watermark_salt = 0
63
-
64
- model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path,
65
- trust_remote_code=True,
66
- local_files_only=True)
67
-
68
- tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True,
69
- local_files_only=True)
70
-
71
- model.eval()
72
- model.to('cuda')
73
-
74
-
75
- def generate_with_api(prompt, args):
76
- hf_api_key = os.environ.get("HF_API_KEY")
77
- if hf_api_key is None:
78
- raise ValueError("HF_API_KEY environment variable not set, cannot use HF API to generate text.")
79
-
80
- client = InferenceAPIClient(args.model_name_or_path, token=hf_api_key, timeout=60)
81
-
82
- assert args.n_beams == 1, "HF API models do not support beam search."
83
- generation_params = {
84
- "max_new_tokens": args.max_new_tokens,
85
- "do_sample": args.use_sampling,
86
- }
87
- if args.use_sampling:
88
- generation_params["temperature"] = args.sampling_temp
89
- generation_params["seed"] = args.generation_seed
90
-
91
- timeout_msg = "[Model API timeout error. Try reducing the max_new_tokens parameter or the prompt length.]"
92
- try:
93
- generation_params["watermark"] = False
94
- without_watermark_iterator = client.generate_stream(prompt, **generation_params)
95
- except ReadTimeout as e:
96
- print(e)
97
- without_watermark_iterator = (char for char in timeout_msg)
98
- try:
99
- generation_params["watermark"] = True
100
- with_watermark_iterator = client.generate_stream(prompt, **generation_params)
101
- except ReadTimeout as e:
102
- print(e)
103
- with_watermark_iterator = (char for char in timeout_msg)
104
-
105
- all_without_words, all_with_words = "", ""
106
- for without_word, with_word in zip(without_watermark_iterator, with_watermark_iterator):
107
- all_without_words += without_word.token.text
108
- all_with_words += with_word.token.text
109
- yield all_without_words, all_with_words
110
-
111
-
112
- def check_prompt(prompt, args, tokenizer, model=None):
113
- # 这适用于本地和API模型场景
114
- try:
115
- if args.model_name_or_path in API_MODEL_MAP:
116
- args.prompt_max_length = API_MODEL_MAP[args.model_name_or_path]["max_length"]
117
- elif hasattr(model.config, "max_position_embedding"):
118
- args.prompt_max_length = model.config.max_position_embeddings - args.max_new_tokens
119
- else:
120
- args.prompt_max_length = 4096 - args.max_new_tokens
121
- except Exception as e:
122
- print(e)
123
- args.prompt_max_length = 4096 - args.max_new_tokens
124
-
125
- tokd_input = tokenizer(prompt, return_tensors="pt", add_special_tokens=False, truncation=True,
126
- max_length=args.prompt_max_length).to('cuda')
127
- truncation_warning = True if tokd_input["input_ids"].shape[-1] == args.prompt_max_length else False
128
- redecoded_input = tokenizer.batch_decode(tokd_input["input_ids"], skip_special_tokens=True)[0]
129
-
130
- return (redecoded_input,
131
- int(truncation_warning),
132
- args)
133
-
134
-
135
- @spaces.GPU
136
- def generate(prompt, args, tokenizer, model=None):
137
- """根据水印参数实例化 WatermarkLogitsProcessor 并通过将其作为 logits 处理器传递给模型的 generate 方法来生成带水印的文本。"""
138
- print(f"Generating with {args}")
139
- print(f"Prompt: {prompt}")
140
-
141
- print(f'model device: {model.device}')
142
-
143
- if args.model_name_or_path in API_MODEL_MAP:
144
- api_outputs = generate_with_api(prompt, args)
145
- yield from api_outputs
146
- else:
147
- if 'chatglm' in args.model_name_or_path.lower() or 'qwen' in args.model_name_or_path.lower() or 'llama' in args.model_name_or_path.lower():
148
- messages = [
149
- # {"role": "system", "content": "You are a helpful assistant."},
150
- {"role": "user", "content": prompt}
151
- ]
152
-
153
- tokenized_input = tokenizer.apply_chat_template(
154
- messages,
155
- tokenize=False,
156
- add_generation_prompt=True
157
- )
158
-
159
- tokd_input = tokenizer([tokenized_input], return_tensors="pt", truncation=True, add_special_tokens=False,
160
- max_length=args.prompt_max_length).to('cuda')
161
-
162
- else:
163
- tokd_input = tokenizer(prompt, return_tensors="pt", add_special_tokens=True, truncation=True,
164
- max_length=args.prompt_max_length).to('cuda')
165
-
166
- gen_kwargs = dict(max_new_tokens=args.max_new_tokens)
167
-
168
- if args.use_sampling:
169
- gen_kwargs.update(dict(
170
- do_sample=True,
171
- top_k=0,
172
- temperature=args.sampling_temp
173
- ))
174
- else:
175
- gen_kwargs.update(dict(
176
- num_beams=args.n_beams
177
- ))
178
-
179
- watermark_processor = WatermarkLogitsProcessor(vocab=list(tokenizer.get_vocab().values()),
180
- gamma=args.gamma,
181
- delta=args.delta,
182
- seeding_scheme=args.seeding_scheme,
183
- extra_salt=watermark_salt,
184
- select_green_tokens=args.select_green_tokens)
185
-
186
- generate_without_watermark = partial(
187
- model.generate,
188
- **gen_kwargs
189
- )
190
-
191
- generate_with_watermark = partial(
192
- model.generate,
193
- logits_processor=LogitsProcessorList([watermark_processor]),
194
- **gen_kwargs
195
- )
196
-
197
- start_time = time.time()
198
- gr.Info('开始生成正常内容')
199
- torch.manual_seed(args.generation_seed)
200
- output_without_watermark = generate_without_watermark(**tokd_input)
201
-
202
- # 可选择在第二次生成之前种子,但通常不会再次相同,除非 delta==0.0,无操作水印
203
-
204
- print(watermark_salt)
205
- print(default_trace_table)
206
- print(default_trace_table.loc[default_trace_table['编号'] == watermark_salt, '水印内容'])
207
- gr.Info('开始注入水印:“{}”'.format(
208
- default_trace_table.loc[default_trace_table['编号'] == watermark_salt, '水印内容'].item()))
209
- if args.seed_separately:
210
- torch.manual_seed(args.generation_seed)
211
-
212
- output_with_watermark = generate_with_watermark(**tokd_input)
213
-
214
- output_without_watermark = output_without_watermark[:, tokd_input["input_ids"].shape[-1]:]
215
- output_with_watermark = output_with_watermark[:, tokd_input["input_ids"].shape[-1]:]
216
-
217
- decoded_output_without_watermark = tokenizer.batch_decode(output_without_watermark, skip_special_tokens=True)[0]
218
- decoded_output_with_watermark = tokenizer.batch_decode(output_with_watermark, skip_special_tokens=True)[0]
219
-
220
- end_time = time.time()
221
- gr.Info(f"生成结束,共用时{end_time - start_time:.2f}秒")
222
-
223
- print(f"Generation took {end_time - start_time:.2f} seconds")
224
-
225
- # 使用空格分隔生成器风格模拟 API 输出
226
-
227
- all_without_words, all_with_words = "", ""
228
- for without_word, with_word in zip(decoded_output_without_watermark.split(),
229
- decoded_output_with_watermark.split()):
230
- all_without_words += without_word + " "
231
- all_with_words += with_word + " "
232
- yield all_without_words, all_with_words
233
-
234
-
235
- def format_names(s):
236
- """为 gradio 演示界面格式化名称"""
237
- s = s.replace("num_tokens_scored", "总Token")
238
- s = s.replace("num_green_tokens", "Green Token数量")
239
- s = s.replace("green_fraction", "Green Token占比")
240
- s = s.replace("z_score", "z-score")
241
- s = s.replace("p_value", "p value")
242
- s = s.replace("prediction", "预测结果")
243
- s = s.replace("confidence", "置信度")
244
- return s
245
-
246
-
247
- def list_format_scores(score_dict, detection_threshold):
248
- """将检测指标格式化为 gradio 数据框输入格式"""
249
- lst_2d = []
250
- for k, v in score_dict.items():
251
- if k == 'green_fraction':
252
- lst_2d.append([format_names(k), f"{v:.1%}"])
253
- elif k == 'confidence':
254
- lst_2d.append([format_names(k), f"{v:.3%}"])
255
- elif isinstance(v, float):
256
- lst_2d.append([format_names(k), f"{v:.3g}"])
257
- elif isinstance(v, bool):
258
- lst_2d.append([format_names(k), ("含有水印" if v else "无水印")])
259
- else:
260
- lst_2d.append([format_names(k), f"{v}"])
261
- if "confidence" in score_dict:
262
- lst_2d.insert(-2, ["z-score Threshold", f"{detection_threshold}"])
263
- else:
264
- lst_2d.insert(-1, ["z-score Threshold", f"{detection_threshold}"])
265
- return lst_2d
266
-
267
-
268
- def detect(input_text, args, tokenizer, return_green_token_mask=True):
269
- """实例化 WatermarkDetection 对象并调用 detect 方法 在输入文本上返回测试的分数和结果"""
270
-
271
- print(f"Detecting with {args}")
272
- print(f"Detection Tokenizer: {type(tokenizer)}")
273
-
274
- # 现在不要显示绿色的token mask
275
- # 如果我们使用的是normalizers或ignore_repeated_bigrams
276
- if args.normalizers != [] or args.ignore_repeated_bigrams:
277
- return_green_token_mask = False
278
-
279
- error = False
280
- green_token_mask = None
281
- if input_text == "":
282
- error = True
283
- else:
284
- try:
285
- for _, data in default_trace_table.iterrows():
286
- salt = data["编号"]
287
- name = data["水印内容"]
288
- watermark_detector = WatermarkDetector(vocab=list(tokenizer.get_vocab().values()),
289
- gamma=args.gamma,
290
- seeding_scheme=args.seeding_scheme,
291
- extra_salt=salt,
292
- device='cuda',
293
- tokenizer=tokenizer,
294
- z_threshold=args.detection_z_threshold,
295
- normalizers=args.normalizers,
296
- ignore_repeated_bigrams=args.ignore_repeated_bigrams,
297
- select_green_tokens=args.select_green_tokens)
298
- score_dict = watermark_detector.detect(input_text, return_green_token_mask=return_green_token_mask)
299
- if score_dict['prediction']:
300
- print(f"检测到是“{name}”的水印")
301
- break
302
-
303
- green_token_mask = score_dict.pop("green_token_mask", None)
304
- output = list_format_scores(score_dict, watermark_detector.z_threshold)
305
- except ValueError as e:
306
- print(e)
307
- error = True
308
- if error:
309
- output = [["Error", "string too short to compute metrics"]]
310
- output += [["", ""] for _ in range(6)]
311
-
312
- html_output = "[No highlight markup generated]"
313
-
314
- if green_token_mask is None:
315
- html_output = "[Visualizing masks with ignore_repeated_bigrams enabled is not supported, toggle off to see the mask for this text. The mask is the same in both cases - only counting/stats are affected.]"
316
-
317
- if green_token_mask is not None:
318
- # hack 因为我们需要一个带有字符跨度支持的快速分词器
319
- tokens = tokenizer(input_text, add_special_tokens=False)
320
- if tokens["input_ids"][0] == tokenizer.bos_token_id:
321
- tokens["input_ids"] = tokens["input_ids"][1:] # 忽略注意力掩码
322
- skip = watermark_detector.min_prefix_len
323
-
324
- if args.model_name_or_path in ['THUDM/chatglm3-6b']:
325
- # 假设词表中3-258就是字节0-255
326
- charspans = []
327
- for i in range(skip, len(tokens["input_ids"])):
328
- if tokens.data['input_ids'][i - 1] in range(3, 259):
329
- charspans.append("<0x{:X}>".format(tokens.data['input_ids'][i - 1] - 3))
330
- else:
331
- charspans.append(tokenizer.decode(tokens.data['input_ids'][i - 1:i]))
332
-
333
- else:
334
- charspans = [tokens.token_to_chars(i - 1) for i in range(skip, len(tokens["input_ids"]))]
335
-
336
- charspans = [cs for cs in charspans if cs is not None] # remove the special token spans
337
-
338
- if len(charspans) != len(green_token_mask): breakpoint()
339
- assert len(charspans) == len(green_token_mask)
340
-
341
- if args.model_name_or_path in ['THUDM/chatglm3-6b']:
342
- tags = []
343
- for cs, m in zip(charspans, green_token_mask):
344
- tags.append(
345
- f'<span class="green">{cs}</span>' if m else f'<span class="red">{cs}</span>')
346
-
347
- else:
348
- tags = [(
349
- f'<span class="green">{input_text[cs.start:cs.end]}</span>' if m else f'<span class="red">{input_text[cs.start:cs.end]}</span>')
350
- for cs, m in zip(charspans, green_token_mask)]
351
-
352
- html_output = f'<p>{" ".join(tags)}</p>'
353
-
354
- if score_dict['prediction']:
355
- html_look = gr.HTML("""<div style="width: 100%; font-size: 24px; height: 100px; border-radius: 20px; background-color: rgba(255, 0, 0, 0.25); display: flex; justify-content: center; align-items: center; color: white; font-weight: bold;">
356
- <span>有 “{}” 的水印</span>
357
- </div>""".format(name), visible=True)
358
-
359
- else:
360
- html_look = gr.HTML("""<div style="width: 100%; font-size: 24px; height: 100px; border-radius: 20px; background-color: rgba(0, 128, 0, 0.25); display: flex; justify-content: center; align-items: center; color: white; font-weight: bold; text-align: center;">
361
- <span>无水印</span>
362
- </div>""", visible=True)
363
-
364
- return output, args, tokenizer, html_output, html_look
365
-
366
-
367
- def run_gradio(args, model=None, tokenizer=None):
368
- """定义并启动gradio演示界面"""
369
- check_prompt_partial = partial(check_prompt, model=model)
370
- generate_partial = spaces.GPU(partial(generate, model=model))
371
- detect_partial = partial(detect)
372
-
373
- css = """
374
- .green { color: black!important;line-height:1.9em; padding: 0.2em 0.2em; background: #ccffcc; border-radius:0.5rem;}
375
- .red { color: black!important;line-height:1.9em; padding: 0.2em 0.2em; background: #ffad99; border-radius:0.5rem;}
376
- """
377
-
378
- with gr.Blocks(theme='ParityError/Interstellar', css=css) as demo:
379
- # 顶部部分,问候语和说明
380
- with gr.Row():
381
- with gr.Column(scale=9):
382
- gr.Markdown(
383
- """
384
- # 🌸🖼️ LLMwatermark:面向大语言模型生成内容的数字水印版权保护系统 🌟🎓
385
- """
386
- )
387
- with gr.Column(scale=1):
388
- # 如果启动时的 model_name_or_path 不是 API 模型之一,则添加到下拉菜单中
389
- all_models = sorted(list(set(list(API_MODEL_MAP.keys()) + [args.model_name_or_path])))
390
- model_selector = gr.Dropdown(
391
- all_models,
392
- value=args.model_name_or_path,
393
- label="选择大语言模型,进行模型水印",
394
- )
395
-
396
- # 构建参数的状态,定义更新和切换
397
- default_prompt = args.__dict__.pop("default_prompt")
398
- session_args = gr.State(value=args)
399
- # 注意,如果状态对象是可调用的,则自动调用 value,希望在启动时避免调用分词器
400
- session_tokenizer = gr.State(value=lambda: tokenizer)
401
-
402
- with gr.Tab("生成回答和添加文本水印🎓"):
403
-
404
- with gr.Row():
405
- with gr.Column(scale=5):
406
- prompt = gr.Textbox(label=f"Prompt", interactive=True, lines=3, max_lines=10, value=default_prompt)
407
- with gr.Column(scale=3):
408
- trace_source = gr.Dataframe(default_trace_table, datatype=['number', 'str'], interactive=True,
409
- col_count=(2, "fixed"))
410
- with gr.Row(equal_height=True):
411
- with gr.Column(scale=7):
412
- generate_btn = gr.Button("Generate", variant='primary')
413
-
414
- gr.Markdown('水印选择:',
415
- show_label=False)
416
- watermark_salt_choice = gr.Dropdown(
417
- choices=[i[::-1] for i in default_trace_table.to_dict(orient='split')['data']],
418
- value=0,
419
- container=False,
420
- scale=3,
421
- type="value",
422
- interactive=True, label="水印标识选择")
423
-
424
- with gr.Row():
425
- with gr.Column():
426
- with gr.Column(scale=2):
427
- with gr.Tab("原版输出"):
428
- output_without_watermark = gr.Textbox(interactive=False, lines=7, max_lines=14,
429
- show_label=False)
430
- with gr.Tab("显示水印"):
431
- html_without_watermark = gr.HTML(elem_id="html-without-watermark")
432
-
433
- original_watermark_state = gr.HTML('', visible=False)
434
-
435
- with gr.Column(scale=1):
436
- without_watermark_detection_result = gr.Dataframe(headers=["Metric", "Value"],
437
- interactive=False,
438
- row_count=7, col_count=2)
439
- with gr.Column():
440
- with gr.Column(scale=2):
441
- with gr.Tab("带水印的输出"):
442
- output_with_watermark = gr.Textbox(interactive=False, lines=7, max_lines=14,
443
- show_label=False)
444
- with gr.Tab("显示水印"):
445
- html_with_watermark = gr.HTML(elem_id="html-with-watermark")
446
-
447
- change_watermark_state = gr.HTML('', visible=False)
448
-
449
- with gr.Column(scale=1):
450
- with_watermark_detection_result = gr.Dataframe(headers=["Metric", "Value"], interactive=False,
451
- row_count=7, col_count=2)
452
-
453
- redecoded_input = gr.Textbox(visible=False)
454
- truncation_warning = gr.Number(visible=False)
455
-
456
- def truncate_prompt(redecoded_input, truncation_warning, orig_prompt, args):
457
- if truncation_warning:
458
- return redecoded_input + f"\n\n[Prompt was truncated before generation due to length...]", args
459
- else:
460
- return orig_prompt, args
461
-
462
- with gr.Tab("检测文本水印功能🎭"):
463
- with gr.Row():
464
- with gr.Column(scale=5):
465
- with gr.Tab("分析文本"):
466
- detection_input = gr.Textbox(interactive=True, lines=14, max_lines=14, show_label=False)
467
- with gr.Tab("显示水印"):
468
- html_detection_input = gr.HTML(elem_id="html-detection-input")
469
-
470
- detect_watermark_state = gr.HTML('', visible=False)
471
- with gr.Column(scale=2):
472
- trace_source2 = gr.Dataframe(default_trace_table, datatype=['number', 'str'], interactive=True,
473
- col_count=(2, "fixed"))
474
- detection_result = gr.Dataframe(headers=["Metric", "Value"], interactive=False, row_count=7,
475
- col_count=2)
476
-
477
- with gr.Row():
478
- detect_btn = gr.Button("检测", variant='primary')
479
-
480
- with gr.Tab("About📖"):
481
- with gr.Row():
482
- with gr.Column(scale=2):
483
- gr.Markdown(
484
- """
485
- 大语言模型可能带来的潜在危害可以通过*水印*来减轻。*水印*是嵌入在生成的文本中的信息,
486
- 这对人类来说是不可见的,但是可以被特定算法检测到。
487
- 这些水印可以使*任何人*用特定工具判断其是否使用带水印的模型生成的。
488
- 本网站展示了一种水印方法,可以应用于_任何_生成性语言模型。
489
- """
490
- )
491
- gr.Markdown(
492
- """
493
- **[生成文本与添加水印]**:可以给大模型的输出添加水印。
494
- 您可以尝试任何prompt,并比较正常文本(*没有水印的输出*)和水印文本(*有水印的输出*)的质量。
495
- 您还可以点击**显示水印**来“看到”水印,其中的颜色表示其所在的红绿表。
496
-
497
- **[检测]**:您还可以将水印文本(或任何其他文本)复制粘贴到第二个选项卡中。
498
- 可以实验删除多少句子后还能检测到水印。
499
- 还可以在验证,检测器的误报率有多少;
500
- """
501
- )
502
-
503
- with gr.Column(scale=1):
504
- gr.Markdown(
505
- """
506
- ![]()
507
- """
508
- )
509
-
510
- # 参数选择组
511
- with gr.Accordion("高级设置", open=False):
512
- with gr.Row():
513
- with gr.Column(scale=1):
514
- gr.Markdown(f"#### 生成参数")
515
- with gr.Row():
516
- decoding = gr.Radio(label="解码方法", choices=["多项式解码方法", "贪心解码方法"],
517
- value=("multinomial" if args.use_sampling else "greedy"))
518
- with gr.Row():
519
- sampling_temp = gr.Slider(label="采样温度", minimum=0.1, maximum=1.0, step=0.1,
520
- value=args.sampling_temp, visible=True)
521
- with gr.Row():
522
- generation_seed = gr.Number(label="生成种子", value=args.generation_seed, interactive=True)
523
- with gr.Row():
524
- n_beams = gr.Dropdown(label="波束搜索解码", choices=list(range(1, 11, 1)), value=args.n_beams,
525
- visible=((not args.use_sampling) and (
526
- not args.model_name_or_path in API_MODEL_MAP)))
527
- with gr.Row():
528
- max_new_tokens = gr.Slider(label="(生成文本的最大长度)Max Generated Tokens", minimum=10,
529
- maximum=4000, step=10, value=args.max_new_tokens)
530
-
531
- with gr.Column(scale=1):
532
- gr.Markdown(f"#### 模型水印参数设置")
533
- with gr.Row():
534
- gamma = gr.Slider(label="gamma", minimum=0.1, maximum=0.9, step=0.05, value=args.gamma)
535
- with gr.Row():
536
- delta = gr.Slider(label="delta", minimum=0.0, maximum=10.0, step=0.1, value=args.delta)
537
- gr.Markdown(f"#### 检测文本水印参数设置")
538
- with gr.Row():
539
- detection_z_threshold = gr.Slider(label="Z分数阈值", minimum=0.0, maximum=10.0, step=0.1,
540
- value=args.detection_z_threshold)
541
- with gr.Row():
542
- ignore_repeated_bigrams = gr.Checkbox(label="避免生成连续重复的双词组合")
543
- with gr.Row():
544
- normalizers = gr.CheckboxGroup(label="对文本进行标准化处理",
545
- choices=["unicode", "homoglyphs", "truecase"],
546
- value=args.normalizers)
547
- with gr.Row():
548
- gr.Markdown(
549
- f"注意:滑块并不总是能完美更新。点击条形图或使用右侧的数字窗口会有所帮助。下面的窗口显示当前设置。")
550
- with gr.Row():
551
- current_parameters = gr.Textbox(label="当前参数设置", value=args, max_lines=10)
552
- with gr.Accordion("传统设置", open=False):
553
- with gr.Row():
554
- with gr.Column(scale=1):
555
- seed_separately = gr.Checkbox(label="为两个不同的生成过程分别设置随机种子",
556
- value=args.seed_separately)
557
- with gr.Column(scale=1):
558
- select_green_tokens = gr.Checkbox(label="从分区中选择 绿色列表", value=args.select_green_tokens)
559
-
560
- with gr.Accordion("设置有什么作用?", open=False):
561
- gr.Markdown(
562
- """
563
- #### 生成参数:
564
-
565
- - **解码方法**:我们可以使用多项式采样或者贪婪解码的方式从模型中生成标记。
566
- 决定如何从模型中生成token。可以选择多项式采样或贪婪解码。
567
- 多项式采样允许一定的随机性,而贪婪解码总是选择概率最高的下一个token。
568
- - **采样温度**:如果使用多项式采样,我们可以设置采样分布的温度。
569
- - 0.0 相当于贪婪解码,而 1.0 代表最大的随机性。
570
- - 0.7是文本质量和随机性之间的平衡点。不适用于贪婪解码。
571
- - **生成种子**:用于在生成前初始化随机数生成器,使多项式采样的输出可复现。此设置不适用于贪婪解码。
572
- - **束搜索数量**:在使用贪婪解码时,可以设置光束数量来启用束搜索。
573
- 这允许考虑多个候选序列,而不是只选择最有可能的一个。此设置目前仅适用于贪婪解码。
574
- - **最大生成标记数**:传递给生成方法的 `max_new_tokens` 参数,以在一定数量的新标记停止输出。
575
- - 请注意,根据提示,模型可以生成较少的标记。隐含地,这将最大化可能的提示标记数,即模型的最大输入长度减去 `max_new_tokens`,并相应地截断输入。
576
-
577
- 综上所述,这些参数提供了对生成过程的不同方面的控制,包括随机性和多样性
578
- (解码方法、采样温度),可复现性(生成种子),以及输出长度和多样性(光束数量、最大生成token数)。
579
- 合理配置这些参数可以帮助生成高质量的文本输出。
580
-
581
- #### 水印参数:
582
-
583
- - **gamma**:在每个生成步骤中将词汇表的一部分划分为绿色列表的比例。
584
- - 较小的 gamma 值通过使水印模型优先从较小的绿色集中采样,从而使其与人类/未水印文本的差异更大,从而创建更强的水印。
585
- - **delta**:在每个生成步骤中为绿色列表中的每个标记的 logits 添加的正偏差量。较高的 delta 值意味着水印模型更偏好于绿色列表中的标记。
586
- - 随着偏差变得非常大,水印从 "软" 过渡到 "硬"。对于硬水印,几乎所有标记都是绿色的,但这可能对生成质量产生不利影响,特别是当分布的灵活性不大时。
587
-
588
- #### 检测器参数:
589
-
590
- - **z 分数阈值**:假设检验的 z 分数截断。较高的阈值(例如 4.0)使得 _false positives_(预测人类/未水印文本被标记为水印)非常不可能,
591
- 因为一个真正的人类文本几乎永远不会达到那么高的 z 分数。
592
- - 较低的阈值会捕获更多的 _true positives_,因为一些水印文本可能包含较少的绿色标记并达到较低的 z 分数,但仍然通过较低的门槛并被标记为 "水印"。
593
- 然而,较低的阈值会增加包含略高于平均水平的绿色标记的人类文本错误地被标记为水印的几率。4.0-5.0 提供了极低的假阳性率,同时仍准确捕获大多数水印文本。
594
- - **忽略二元重复**:这种替代的检测算法在检测期间仅考虑文本中的唯一二元组,根据每对中的第一个计算绿色列表,并检查第二个是否位于列表中。
595
- - 这意味着 `T` 现在是文本中唯一二元组的数量,如果文本包含大量重复,则这个数字将小于生成的总标记数。有关更详细的讨论,请参阅论文。
596
- - **归一化**:我们实现了一些基本的归一化来抵御文本在检测期间的各种对抗性扰动。
597
- - 目前,我们支持将所有字符转换为 Unicode,将同形异义字符替换为规范形式,并标准化大写。有关输入归一化的详细讨论,请参阅论文。
598
- """
599
- )
600
-
601
- with gr.Accordion("输出指标意味着什么?", open=False):
602
- gr.Markdown(
603
- """
604
- - `z-score threshold`:假设检验的截止值。
605
- - `Tokens Counted (T)`:检测算法计算的输出中的标记数量。在简单的、单标记播种方案中,第一个标记被省略了,因为没有办法为其生成绿色列表,因为它没有前缀标记。
606
- 在底部面板描述的“忽略二元重复”检测算法下,如果有很多重复,这个数量可能远少于生成的总标记数。
607
- - ` Tokens in Greenlist`:观察到落在其相应绿色列表中的标记数量。
608
- - `Fraction of T in Greenlist`:`# Tokens in Greenlist` / `T`。这应该大约等于人类/未水印文本的 `gamma`。
609
- - `z-score`:用于检测假设检验的检验统计量。如果大于 `z-score threshold`,则我们“拒绝零假设”,即文本是人类/未水印的,并得出结论它是带水印的。
610
- - `p value`:在零假设下观察到计算出的 `z-score` 的可能性。这是在不知道水印程序/绿色列表的情况下观察到 `Fraction of T in Greenlist` 的可能性。
611
- 如果这个值极其 _小_,我们可以确信这么多的绿色标记不是由随机机会选择的。
612
- - `prediction`:假设检验的结果 - 观察到的 `z-score` 是否高于 `z-score threshold`。
613
- - `confidence`:如果我们拒绝零假设,且 `prediction` 是“带水印”,那么我们报告 1-`p value` 来表示基于这个 `z-score` 观察的检测的置信度。
614
- """
615
- )
616
-
617
- gr.HTML("""
618
- <p>本方法可以对任何大模型的输出结果进行水印的操作。
619
- 并且可以对输出结果进行水印检测。
620
- <p/>
621
- """)
622
-
623
- # 注册主要生成标签单击事件,输出生成文本以及编码+重新解码+可能被截断的提示和标志,然后调用检测
624
- generate_btn.click(fn=check_prompt_partial, inputs=[prompt, session_args, session_tokenizer],
625
- outputs=[redecoded_input, truncation_warning, session_args]).success(
626
- fn=generate_partial, inputs=[redecoded_input, session_args, session_tokenizer],
627
- outputs=[output_without_watermark, output_with_watermark]).success(
628
- fn=detect_partial, inputs=[output_without_watermark, session_args, session_tokenizer],
629
- outputs=[without_watermark_detection_result, session_args, session_tokenizer,
630
- html_without_watermark, original_watermark_state]).success(
631
- fn=detect_partial, inputs=[output_with_watermark, session_args, session_tokenizer],
632
- outputs=[with_watermark_detection_result, session_args, session_tokenizer, html_with_watermark,
633
- change_watermark_state])
634
- # 如果发生了截断,则显示提示的截断版本
635
- redecoded_input.change(fn=truncate_prompt, inputs=[redecoded_input, truncation_warning, prompt, session_args],
636
- outputs=[prompt, session_args])
637
- # Register main detection tab click
638
- detect_btn.click(fn=detect_partial, inputs=[detection_input, session_args, session_tokenizer],
639
- outputs=[detection_result, session_args, session_tokenizer, html_detection_input,
640
- detect_watermark_state],
641
- api_name="detection")
642
-
643
- # 状态管理逻辑
644
- # 定义更新回调函数以更改状态字典
645
- def update_model(session_state, value):
646
- session_state.model_name_or_path = value
647
- return session_state
648
-
649
- def update_sampling_temp(session_state, value):
650
- session_state.sampling_temp = float(value)
651
- return session_state
652
-
653
- def update_generation_seed(session_state, value):
654
- session_state.generation_seed = int(value)
655
- return session_state
656
-
657
- def update_watermark_salt(value):
658
- global watermark_salt
659
- if isinstance(value, int):
660
- watermark_salt = value
661
- elif value is None:
662
- watermark_salt = 0
663
- elif isinstance(value, str) and value.isdigit():
664
- watermark_salt = int(value)
665
- else:
666
- # 不知道为什么会出现这种倒置的情况
667
- watermark_salt = int(
668
- default_trace_table.loc[default_trace_table['水印内容'] == value, '编号'].item())
669
-
670
- def update_trace_source(value):
671
- global default_trace_table
672
- try:
673
- if '' in value.loc[:, '编号'].tolist():
674
- return value, gr.Dropdown()
675
-
676
- value.loc[:, '编号'] = value.loc[:, '编号'].astype(int)
677
-
678
- if default_trace_table.duplicated(subset='编号').any():
679
- raise gr.Error(f"请检查水印编号,编号不能重复")
680
-
681
- default_trace_table = value
682
-
683
- return value, gr.Dropdown(
684
- choices=[i[::-1] for i in value.to_dict(orient='split')['data']])
685
-
686
- except ValueError as e:
687
- if 'invalid literal for int() with base 10' in str(e):
688
- raise gr.Error(f"请检查水印数据,编号必须是整数:{e}")
689
-
690
- except gradio.exceptions.Error as e:
691
- raise e
692
-
693
- except Exception as e:
694
- print(type(e))
695
- raise e
696
-
697
- def update_gamma(session_state, value):
698
- session_state.gamma = float(value)
699
- return session_state
700
-
701
- def update_delta(session_state, value):
702
- session_state.delta = float(value)
703
- return session_state
704
-
705
- def update_detection_z_threshold(session_state, value):
706
- session_state.detection_z_threshold = float(value)
707
- return session_state
708
-
709
- def update_decoding(session_state, value):
710
- if value == "multinomial":
711
- session_state.use_sampling = True
712
- elif value == "greedy":
713
- session_state.use_sampling = False
714
- return session_state
715
-
716
- def toggle_sampling_vis(value):
717
- if value == "multinomial":
718
- return gr.update(visible=True)
719
- elif value == "greedy":
720
- return gr.update(visible=False)
721
-
722
- def toggle_sampling_vis_inv(value):
723
- if value == "multinomial":
724
- return gr.update(visible=False)
725
- elif value == "greedy":
726
- return gr.update(visible=True)
727
-
728
- # 如果模型名称在 API 模型列表中,则将 num beams 参数设置为 1 并隐藏 n_beams
729
- def toggle_vis_for_api_model(value):
730
- if value in API_MODEL_MAP:
731
- return gr.update(visible=False)
732
- else:
733
- return gr.update(visible=True)
734
-
735
- def toggle_beams_for_api_model(value, orig_n_beams):
736
- if value in API_MODEL_MAP:
737
- return gr.update(value=1)
738
- else:
739
- return gr.update(value=orig_n_beams)
740
-
741
- # 如果模型名称在 API 模型列表中,则将交互参数设置为 false
742
- def toggle_interactive_for_api_model(value):
743
- if value in API_MODEL_MAP:
744
- return gr.update(interactive=False)
745
- else:
746
- return gr.update(interactive=True)
747
-
748
- # 如果模型名称在 API 模型列表中,则根据 API 映射设置 gamma 和 delta
749
- def toggle_gamma_for_api_model(value, orig_gamma):
750
- if value in API_MODEL_MAP:
751
- return gr.update(value=API_MODEL_MAP[value]["gamma"])
752
- else:
753
- return gr.update(value=orig_gamma)
754
-
755
- def toggle_delta_for_api_model(value, orig_delta):
756
- if value in API_MODEL_MAP:
757
- return gr.update(value=API_MODEL_MAP[value]["delta"])
758
- else:
759
- return gr.update(value=orig_delta)
760
-
761
- def update_n_beams(session_state, value):
762
- session_state.n_beams = value;
763
- return session_state
764
-
765
- def update_max_new_tokens(session_state, value):
766
- session_state.max_new_tokens = int(value);
767
- return session_state
768
-
769
- def update_ignore_repeated_bigrams(session_state, value):
770
- session_state.ignore_repeated_bigrams = value;
771
- return session_state
772
-
773
- def update_normalizers(session_state, value):
774
- session_state.normalizers = value;
775
- return session_state
776
-
777
- def update_seed_separately(session_state, value):
778
- session_state.seed_separately = value;
779
- return session_state
780
-
781
- def update_select_green_tokens(session_state, value):
782
- session_state.select_green_tokens = value;
783
- return session_state
784
-
785
- def update_tokenizer(model_name_or_path):
786
- # if model_name_or_path == ALPACA_MODEL_NAME:
787
- # return ALPACA_MODEL_TOKENIZER.from_pretrained(ALPACA_TOKENIZER_PATH)
788
- # else:
789
- return AutoTokenizer.from_pretrained(model_name_or_path)
790
-
791
- def check_model(value):
792
- return value if (value != "" and value is not None) else args.model_name_or_path
793
-
794
- # 强制约束模型不能为 null 或空
795
- # 然后特别附加模型回调函数
796
- model_selector.change(check_model, inputs=[model_selector], outputs=[model_selector]).then(
797
- toggle_vis_for_api_model, inputs=[model_selector], outputs=[n_beams]
798
- ).then(
799
- toggle_beams_for_api_model, inputs=[model_selector, n_beams], outputs=[n_beams]
800
- ).then(
801
- toggle_interactive_for_api_model, inputs=[model_selector], outputs=[gamma]
802
- ).then(
803
- toggle_interactive_for_api_model, inputs=[model_selector], outputs=[delta]
804
- ).then(
805
- toggle_gamma_for_api_model, inputs=[model_selector, gamma], outputs=[gamma]
806
- ).then(
807
- toggle_delta_for_api_model, inputs=[model_selector, delta], outputs=[delta]
808
- ).then(
809
- update_tokenizer, inputs=[model_selector], outputs=[session_tokenizer]
810
- ).then(
811
- update_model, inputs=[session_args, model_selector], outputs=[session_args]
812
- ).then(
813
- lambda value: str(value), inputs=[session_args], outputs=[current_parameters]
814
- )
815
- # 根据其他参数的值注册回调函数以切换特定参数的可见性
816
- decoding.change(toggle_sampling_vis, inputs=[decoding], outputs=[sampling_temp])
817
- decoding.change(toggle_sampling_vis, inputs=[decoding], outputs=[generation_seed])
818
- decoding.change(toggle_sampling_vis_inv, inputs=[decoding], outputs=[n_beams])
819
- decoding.change(toggle_vis_for_api_model, inputs=[model_selector], outputs=[n_beams])
820
- # 注册所有状态更新回调函数
821
- decoding.change(update_decoding, inputs=[session_args, decoding], outputs=[session_args])
822
- sampling_temp.change(update_sampling_temp, inputs=[session_args, sampling_temp], outputs=[session_args])
823
- generation_seed.change(update_generation_seed, inputs=[session_args, generation_seed], outputs=[session_args])
824
- watermark_salt_choice.change(update_watermark_salt, inputs=[watermark_salt_choice])
825
-
826
- # 同步更新
827
- trace_source.change(update_trace_source, inputs=[trace_source],
828
- outputs=[trace_source2, watermark_salt_choice])
829
- trace_source2.change(update_trace_source, inputs=[trace_source2],
830
- outputs=[trace_source, watermark_salt_choice])
831
-
832
- n_beams.change(update_n_beams, inputs=[session_args, n_beams], outputs=[session_args])
833
- max_new_tokens.change(update_max_new_tokens, inputs=[session_args, max_new_tokens], outputs=[session_args])
834
- gamma.change(update_gamma, inputs=[session_args, gamma], outputs=[session_args])
835
- delta.change(update_delta, inputs=[session_args, delta], outputs=[session_args])
836
- detection_z_threshold.change(update_detection_z_threshold, inputs=[session_args, detection_z_threshold],
837
- outputs=[session_args])
838
- ignore_repeated_bigrams.change(update_ignore_repeated_bigrams, inputs=[session_args, ignore_repeated_bigrams],
839
- outputs=[session_args])
840
- normalizers.change(update_normalizers, inputs=[session_args, normalizers], outputs=[session_args])
841
- seed_separately.change(update_seed_separately, inputs=[session_args, seed_separately], outputs=[session_args])
842
- select_green_tokens.change(update_select_green_tokens, inputs=[session_args, select_green_tokens],
843
- outputs=[session_args])
844
- # 注册按钮点击时更新显示参数窗口的额外回调
845
- generate_btn.click(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
846
- detect_btn.click(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
847
- # 当参数更改时,显示更新并触发检测,因为某些检测参数不会改变模型输出。
848
- delta.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
849
- gamma.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
850
- gamma.change(fn=detect_partial, inputs=[output_without_watermark, session_args, session_tokenizer],
851
- outputs=[without_watermark_detection_result, session_args, session_tokenizer,
852
- html_without_watermark])
853
- gamma.change(fn=detect_partial, inputs=[output_with_watermark, session_args, session_tokenizer],
854
- outputs=[with_watermark_detection_result, session_args, session_tokenizer, html_with_watermark])
855
- gamma.change(fn=detect_partial, inputs=[detection_input, session_args, session_tokenizer],
856
- outputs=[detection_result, session_args, session_tokenizer, html_detection_input])
857
- detection_z_threshold.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
858
- detection_z_threshold.change(fn=detect_partial,
859
- inputs=[output_without_watermark, session_args, session_tokenizer],
860
- outputs=[without_watermark_detection_result, session_args, session_tokenizer,
861
- html_without_watermark])
862
- detection_z_threshold.change(fn=detect_partial, inputs=[output_with_watermark, session_args, session_tokenizer],
863
- outputs=[with_watermark_detection_result, session_args, session_tokenizer,
864
- html_with_watermark])
865
- detection_z_threshold.change(fn=detect_partial, inputs=[detection_input, session_args, session_tokenizer],
866
- outputs=[detection_result, session_args, session_tokenizer, html_detection_input])
867
- ignore_repeated_bigrams.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
868
- ignore_repeated_bigrams.change(fn=detect_partial,
869
- inputs=[output_without_watermark, session_args, session_tokenizer],
870
- outputs=[without_watermark_detection_result, session_args, session_tokenizer,
871
- html_without_watermark])
872
- ignore_repeated_bigrams.change(fn=detect_partial,
873
- inputs=[output_with_watermark, session_args, session_tokenizer],
874
- outputs=[with_watermark_detection_result, session_args, session_tokenizer,
875
- html_with_watermark])
876
- ignore_repeated_bigrams.change(fn=detect_partial, inputs=[detection_input, session_args, session_tokenizer],
877
- outputs=[detection_result, session_args, session_tokenizer,
878
- html_detection_input])
879
- normalizers.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
880
- normalizers.change(fn=detect_partial, inputs=[output_without_watermark, session_args, session_tokenizer],
881
- outputs=[without_watermark_detection_result, session_args, session_tokenizer,
882
- html_without_watermark])
883
- normalizers.change(fn=detect_partial, inputs=[output_with_watermark, session_args, session_tokenizer],
884
- outputs=[with_watermark_detection_result, session_args, session_tokenizer,
885
- html_with_watermark])
886
- normalizers.change(fn=detect_partial, inputs=[detection_input, session_args, session_tokenizer],
887
- outputs=[detection_result, session_args, session_tokenizer, html_detection_input])
888
- select_green_tokens.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
889
- select_green_tokens.change(fn=detect_partial,
890
- inputs=[output_without_watermark, session_args, session_tokenizer],
891
- outputs=[without_watermark_detection_result, session_args, session_tokenizer,
892
- html_without_watermark])
893
- select_green_tokens.change(fn=detect_partial, inputs=[output_with_watermark, session_args, session_tokenizer],
894
- outputs=[with_watermark_detection_result, session_args, session_tokenizer,
895
- html_with_watermark])
896
- select_green_tokens.change(fn=detect_partial, inputs=[detection_input, session_args, session_tokenizer],
897
- outputs=[detection_result, session_args, session_tokenizer, html_detection_input])
898
-
899
- # demo.queue(concurrency_count=3) # delete
900
-
901
- if args.demo_public:
902
- demo.launch(share=True) # 通过随机生成的链接将应用程序暴露到互联网上
903
- else:
904
- demo.launch(server_name='0.0.0.0', share=False)
905
-
906
-
907
- # 初始参数处理和日志记录
908
- args.normalizers = (args.normalizers.split(",") if args.normalizers else [])
909
- print(args)
910
-
911
- # terrapin example
912
- input_text = (
913
- "为什么A股指数跌的不多,但是我亏损比之前都多?"
914
- )
915
-
916
- args.default_prompt = input_text
917
 
918
- # Launch the app to generate and detect interactively (implements the hf space demo)
919
- if args.run_gradio:
920
- run_gradio(args, model=model, tokenizer=tokenizer)
 
1
  # 安装好环境
2
  # python app.py即可运行
 
 
 
3
  import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
+ os.environ['HF_ENDPOINT']='https://hf-mirror.com'
6
 
7
+ from argparse import Namespace
8
  args = Namespace()
9
 
10
  arg_dict = {
11
  'run_gradio': True,
12
  'demo_public': False,
13
  'model_name_or_path': './model/Qwen2-0.5B-Instruct',
14
+
15
+ # 'model_name_or_path': 'Qwen/Qwen2-0.5B-Instruct-GGUF',
16
+ 'gguf_file': './qwen2-0_5b-instruct-q8_0.gguf', # 只有用gguf模型会用到,即model_name_or_path里含有gguf字符串才会用到
17
  'prompt_max_length': None,
18
  'max_new_tokens': 500,
19
  'generation_seed': 123,
20
  'use_sampling': True,
21
  'n_beams': 1,
22
  'sampling_temp': 0.7,
23
+ 'use_gpu': True,
24
  'seeding_scheme': 'simple_1',
25
  'gamma': 0.5,
26
  'delta': 2.0,
 
28
  'ignore_repeated_bigrams': False,
29
  'detection_z_threshold': 4.0,
30
  'select_green_tokens': True,
31
+ 'skip_model_load': False,
32
  'seed_separately': True,
33
  }
34
 
35
  args.__dict__.update(arg_dict)
36
 
37
+ from demo_watermark import main
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
+ main(args)
 
 
demo_watermark.py CHANGED
@@ -0,0 +1,1083 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import time
4
+ from functools import partial
5
+ import spaces
6
+
7
+ import gradio.exceptions
8
+
9
+ import gradio as gr
10
+ import pandas as pd
11
+ import torch
12
+ from transformers import (AutoTokenizer,
13
+ AutoModelForCausalLM,
14
+ LogitsProcessorList)
15
+
16
+ from watermark_processor import WatermarkLogitsProcessor, WatermarkDetector
17
+
18
+ # FIXME 所有模型的正确长度
19
+
20
+ API_MODEL_MAP = {
21
+ # "Qwen/Qwen1.5-0.5B-Chat": {"max_length": 2000, "gamma": 0.5, "delta": 2.0},
22
+ # "THUDM/chatglm3-6b": {"max_length": 2048, "gamma": 0.5, "delta": 2.0},
23
+ }
24
+
25
+ default_trace_table = pd.DataFrame(columns=["编号", "水印内容"])
26
+ default_trace_table.loc[0] = (0, "本文本由A模型生成")
27
+ default_trace_table.loc[1] = (1, "本文本由B模型生成")
28
+ default_trace_table.loc[2] = (2, "本文本由用户小王生成")
29
+
30
+ watermark_salt = 0
31
+
32
+
33
+ def str2bool(v):
34
+ """用户友好的布尔标志参数的Util函数"""
35
+ if isinstance(v, bool):
36
+ return v
37
+ if v.lower() in ('yes', 'true', 't', 'y', '1'):
38
+ return True
39
+ elif v.lower() in ('no', 'false', 'f', 'n', '0'):
40
+ return False
41
+ else:
42
+ raise argparse.ArgumentTypeError('Boolean value expected.')
43
+
44
+
45
+ # 定义一个函数用于解析命令行参数
46
+ def parse_args():
47
+ parser = argparse.ArgumentParser(
48
+ description="")
49
+
50
+ parser.add_argument(
51
+ "--run_gradio",
52
+ type=str2bool,
53
+ default=True,
54
+ help="Whether to launch as a gradio demo. Set to False if not installed and want to just run the stdout version.",
55
+ )
56
+ parser.add_argument(
57
+ "--demo_public",
58
+ type=str2bool,
59
+ default=False,
60
+ help="Whether to expose the gradio demo to the internet.",
61
+ )
62
+ parser.add_argument(
63
+ "--model_name_or_path",
64
+ type=str,
65
+ default="Qwen/Qwen1.5-0.5B-Chat",
66
+ help="Main model, path to pretrained model or model identifier from huggingface.co/models.",
67
+ )
68
+ parser.add_argument(
69
+ "--prompt_max_length",
70
+ type=int,
71
+ default=None,
72
+ help="Truncation length for prompt, overrides model config's max length field.",
73
+ )
74
+ parser.add_argument(
75
+ "--max_new_tokens",
76
+ type=int,
77
+ default=200,
78
+ help="Maximmum number of new tokens to generate.",
79
+ )
80
+ parser.add_argument(
81
+ "--generation_seed",
82
+ type=int,
83
+ default=123,
84
+ help="Seed for setting the torch global rng prior to generation.",
85
+ )
86
+ parser.add_argument(
87
+ "--use_sampling",
88
+ type=str2bool,
89
+ default=True,
90
+ help="Whether to generate using multinomial sampling.",
91
+ )
92
+ parser.add_argument(
93
+ "--sampling_temp",
94
+ type=float,
95
+ default=0.7,
96
+ help="Sampling temperature to use when generating using multinomial sampling.",
97
+ )
98
+ parser.add_argument(
99
+ "--n_beams",
100
+ type=int,
101
+ default=1,
102
+ help="Number of beams to use for beam search. 1 is normal greedy decoding",
103
+ )
104
+ parser.add_argument(
105
+ "--use_gpu",
106
+ type=str2bool,
107
+ default=True,
108
+ help="Whether to run inference and watermark hashing/seeding/permutation on gpu.",
109
+ )
110
+ parser.add_argument(
111
+ "--seeding_scheme",
112
+ type=str,
113
+ default="simple_1",
114
+ help="Seeding scheme to use to generate the greenlists at each generation and verification step.",
115
+ )
116
+ parser.add_argument(
117
+ "--gamma",
118
+ type=float,
119
+ default=0.5,
120
+ help="The fraction of the vocabulary to partition into the greenlist at each generation and verification step.",
121
+ )
122
+ parser.add_argument(
123
+ "--delta",
124
+ type=float,
125
+ default=2.0,
126
+ help="The amount/bias to add to each of the greenlist token logits before each token sampling step.",
127
+ )
128
+ parser.add_argument(
129
+ "--normalizers",
130
+ type=str,
131
+ default="",
132
+ help="Single or comma separated list of the preprocessors/normalizer names to use when performing watermark detection.",
133
+ )
134
+ parser.add_argument(
135
+ "--ignore_repeated_bigrams",
136
+ type=str2bool,
137
+ default=False,
138
+ help="Whether to use the detection method that only counts each unqiue bigram once as either a green or red hit.",
139
+ )
140
+ parser.add_argument(
141
+ "--detection_z_threshold",
142
+ type=float,
143
+ default=4.0,
144
+ help="The test statistic threshold for the detection hypothesis test.",
145
+ )
146
+ parser.add_argument(
147
+ "--select_green_tokens",
148
+ type=str2bool,
149
+ default=True,
150
+ help="How to treat the permuation when selecting the greenlist tokens at each step. Legacy is (False) to pick the complement/reds first.",
151
+ )
152
+ parser.add_argument(
153
+ "--skip_model_load",
154
+ type=str2bool,
155
+ default=False,
156
+ help="Skip the model loading to debug the interface.",
157
+ )
158
+ parser.add_argument(
159
+ "--gguf_file",
160
+ type=str,
161
+ default='./qwen2-0_5b-instruct-q2_k.gguf',
162
+ help="gguf文件(如果有)",
163
+ )
164
+
165
+ parser.add_argument(
166
+ "--seed_separately",
167
+ type=str2bool,
168
+ default=True,
169
+ help="Whether to call the torch seed function before both the unwatermarked and watermarked generate calls.",
170
+ )
171
+ args = parser.parse_args()
172
+
173
+ return args
174
+
175
+
176
+ def load_model(args):
177
+ """加载并返回模型和分词器"""
178
+
179
+ if args.use_gpu:
180
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
181
+ else:
182
+ device = "cpu"
183
+
184
+ if 'gguf' in args.model_name_or_path.lower():
185
+
186
+ model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, gguf_file=args.gguf_file,
187
+ trust_remote_code=True,
188
+ local_files_only=True,
189
+ device_map=device)
190
+ tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, gguf_file=args.gguf_file,
191
+ local_files_only=True,
192
+ trust_remote_code=True)
193
+
194
+ else:
195
+ model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path,
196
+ trust_remote_code=True,
197
+ local_files_only=True,
198
+ device_map=device)
199
+
200
+ tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True,
201
+ local_files_only=True, )
202
+
203
+ try:
204
+ model.eval()
205
+ except Exception as e:
206
+ print(e)
207
+
208
+ return model, tokenizer, device
209
+
210
+
211
+ from text_generation import InferenceAPIClient
212
+ from requests.exceptions import ReadTimeout
213
+
214
+
215
+ def generate_with_api(prompt, args):
216
+ hf_api_key = os.environ.get("HF_API_KEY")
217
+ if hf_api_key is None:
218
+ raise ValueError("HF_API_KEY environment variable not set, cannot use HF API to generate text.")
219
+
220
+ client = InferenceAPIClient(args.model_name_or_path, token=hf_api_key, timeout=60)
221
+
222
+ assert args.n_beams == 1, "HF API models do not support beam search."
223
+ generation_params = {
224
+ "max_new_tokens": args.max_new_tokens,
225
+ "do_sample": args.use_sampling,
226
+ }
227
+ if args.use_sampling:
228
+ generation_params["temperature"] = args.sampling_temp
229
+ generation_params["seed"] = args.generation_seed
230
+
231
+ timeout_msg = "[Model API timeout error. Try reducing the max_new_tokens parameter or the prompt length.]"
232
+ try:
233
+ generation_params["watermark"] = False
234
+ without_watermark_iterator = client.generate_stream(prompt, **generation_params)
235
+ except ReadTimeout as e:
236
+ print(e)
237
+ without_watermark_iterator = (char for char in timeout_msg)
238
+ try:
239
+ generation_params["watermark"] = True
240
+ with_watermark_iterator = client.generate_stream(prompt, **generation_params)
241
+ except ReadTimeout as e:
242
+ print(e)
243
+ with_watermark_iterator = (char for char in timeout_msg)
244
+
245
+ all_without_words, all_with_words = "", ""
246
+ for without_word, with_word in zip(without_watermark_iterator, with_watermark_iterator):
247
+ all_without_words += without_word.token.text
248
+ all_with_words += with_word.token.text
249
+ yield all_without_words, all_with_words
250
+
251
+
252
+ def check_prompt(prompt, args, tokenizer, model=None, device=None):
253
+ # 这适用于本地和API模型场景
254
+ try:
255
+ if args.model_name_or_path in API_MODEL_MAP:
256
+ args.prompt_max_length = API_MODEL_MAP[args.model_name_or_path]["max_length"]
257
+ elif hasattr(model.config, "max_position_embedding"):
258
+ args.prompt_max_length = model.config.max_position_embeddings - args.max_new_tokens
259
+ else:
260
+ args.prompt_max_length = 4096 - args.max_new_tokens
261
+ except Exception as e:
262
+ print(e)
263
+ args.prompt_max_length = 4096 - args.max_new_tokens
264
+
265
+ tokd_input = tokenizer(prompt, return_tensors="pt", add_special_tokens=False, truncation=True,
266
+ max_length=args.prompt_max_length).to(device)
267
+ truncation_warning = True if tokd_input["input_ids"].shape[-1] == args.prompt_max_length else False
268
+ redecoded_input = tokenizer.batch_decode(tokd_input["input_ids"], skip_special_tokens=True)[0]
269
+
270
+ return (redecoded_input,
271
+ int(truncation_warning),
272
+ args)
273
+
274
+
275
+ def generate(prompt, args, tokenizer, model=None, device=None):
276
+ """根据水印参数实例化 WatermarkLogitsProcessor 并通过将其作为 logits 处理器传递给模型的 generate 方法来生成带水印的文本。"""
277
+ print(f"Generating with {args}")
278
+ print(f"Prompt: {prompt}")
279
+
280
+ if args.model_name_or_path in API_MODEL_MAP:
281
+ api_outputs = generate_with_api(prompt, args)
282
+ yield from api_outputs
283
+ else:
284
+ if 'chatglm' in args.model_name_or_path.lower() or 'qwen' in args.model_name_or_path.lower() or 'llama' in args.model_name_or_path.lower():
285
+ messages = [
286
+ # {"role": "system", "content": "You are a helpful assistant."},
287
+ {"role": "user", "content": prompt}
288
+ ]
289
+
290
+ tokenized_input = tokenizer.apply_chat_template(
291
+ messages,
292
+ tokenize=False,
293
+ add_generation_prompt=True
294
+ )
295
+
296
+ tokd_input = tokenizer([tokenized_input], return_tensors="pt", truncation=True, add_special_tokens=False,
297
+ max_length=args.prompt_max_length).to(device)
298
+
299
+ else:
300
+ tokd_input = tokenizer(prompt, return_tensors="pt", add_special_tokens=True, truncation=True,
301
+ max_length=args.prompt_max_length).to(device)
302
+
303
+ gen_kwargs = dict(max_new_tokens=args.max_new_tokens)
304
+
305
+ if args.use_sampling:
306
+ gen_kwargs.update(dict(
307
+ do_sample=True,
308
+ top_k=0,
309
+ temperature=args.sampling_temp
310
+ ))
311
+ else:
312
+ gen_kwargs.update(dict(
313
+ num_beams=args.n_beams
314
+ ))
315
+
316
+ watermark_processor = WatermarkLogitsProcessor(vocab=list(tokenizer.get_vocab().values()),
317
+ gamma=args.gamma,
318
+ delta=args.delta,
319
+ seeding_scheme=args.seeding_scheme,
320
+ extra_salt=watermark_salt,
321
+ select_green_tokens=args.select_green_tokens)
322
+
323
+ generate_without_watermark = partial(
324
+ model.generate,
325
+ **gen_kwargs
326
+ )
327
+
328
+ generate_with_watermark = partial(
329
+ model.generate,
330
+ logits_processor=LogitsProcessorList([watermark_processor]),
331
+ **gen_kwargs
332
+ )
333
+
334
+ start_time = time.time()
335
+ gr.Info('开始生成正常内容')
336
+ torch.manual_seed(args.generation_seed)
337
+ output_without_watermark = generate_without_watermark(**tokd_input)
338
+
339
+ # 可选择在第二次生成之前种子,但通常不会再次相同,除非 delta==0.0,无操作水印
340
+
341
+ print(watermark_salt)
342
+ print(default_trace_table)
343
+ print(default_trace_table.loc[default_trace_table['编号'] == watermark_salt, '水印内容'])
344
+ gr.Info('开始注入水印:“{}”'.format(
345
+ default_trace_table.loc[default_trace_table['编号'] == watermark_salt, '水印内容'].item()))
346
+ if args.seed_separately:
347
+ torch.manual_seed(args.generation_seed)
348
+
349
+ output_with_watermark = generate_with_watermark(**tokd_input)
350
+
351
+ output_without_watermark = output_without_watermark[:, tokd_input["input_ids"].shape[-1]:]
352
+ output_with_watermark = output_with_watermark[:, tokd_input["input_ids"].shape[-1]:]
353
+
354
+ decoded_output_without_watermark = tokenizer.batch_decode(output_without_watermark, skip_special_tokens=True)[0]
355
+ decoded_output_with_watermark = tokenizer.batch_decode(output_with_watermark, skip_special_tokens=True)[0]
356
+
357
+ end_time = time.time()
358
+ gr.Info(f"生成结束,共用时{end_time - start_time:.2f}秒")
359
+
360
+ print(f"Generation took {end_time - start_time:.2f} seconds")
361
+
362
+ # 使用空格分隔生成器风格模拟 API 输出
363
+
364
+ all_without_words, all_with_words = "", ""
365
+ for without_word, with_word in zip(decoded_output_without_watermark.split(),
366
+ decoded_output_with_watermark.split()):
367
+ all_without_words += without_word + " "
368
+ all_with_words += with_word + " "
369
+ yield all_without_words, all_with_words
370
+
371
+
372
+ def format_names(s):
373
+ """为 gradio 演示界面格式化名称"""
374
+ s = s.replace("num_tokens_scored", "总Token")
375
+ s = s.replace("num_green_tokens", "Green Token数量")
376
+ s = s.replace("green_fraction", "Green Token占比")
377
+ s = s.replace("z_score", "z-score")
378
+ s = s.replace("p_value", "p value")
379
+ s = s.replace("prediction", "预测结果")
380
+ s = s.replace("confidence", "置信度")
381
+ return s
382
+
383
+
384
+ def list_format_scores(score_dict, detection_threshold):
385
+ """将检测指标格式化为 gradio 数据框输入格式"""
386
+ lst_2d = []
387
+ for k, v in score_dict.items():
388
+ if k == 'green_fraction':
389
+ lst_2d.append([format_names(k), f"{v:.1%}"])
390
+ elif k == 'confidence':
391
+ lst_2d.append([format_names(k), f"{v:.3%}"])
392
+ elif isinstance(v, float):
393
+ lst_2d.append([format_names(k), f"{v:.3g}"])
394
+ elif isinstance(v, bool):
395
+ lst_2d.append([format_names(k), ("含有水印" if v else "无水印")])
396
+ else:
397
+ lst_2d.append([format_names(k), f"{v}"])
398
+ if "confidence" in score_dict:
399
+ lst_2d.insert(-2, ["z-score Threshold", f"{detection_threshold}"])
400
+ else:
401
+ lst_2d.insert(-1, ["z-score Threshold", f"{detection_threshold}"])
402
+ return lst_2d
403
+
404
+
405
+ def detect(input_text, args, tokenizer, device=None, return_green_token_mask=True):
406
+ """实例化 WatermarkDetection 对象并调用 detect 方法 在输入文本上返回测试的分数和结果"""
407
+
408
+ print(f"Detecting with {args}")
409
+ print(f"Detection Tokenizer: {type(tokenizer)}")
410
+
411
+ # 现在不要显示绿色的token mask
412
+ # 如果我们使用的是normalizers或ignore_repeated_bigrams
413
+ if args.normalizers != [] or args.ignore_repeated_bigrams:
414
+ return_green_token_mask = False
415
+
416
+ error = False
417
+ green_token_mask = None
418
+ if input_text == "":
419
+ error = True
420
+ else:
421
+ try:
422
+ for _, data in default_trace_table.iterrows():
423
+ salt = data["编号"]
424
+ name = data["水印内容"]
425
+ watermark_detector = WatermarkDetector(vocab=list(tokenizer.get_vocab().values()),
426
+ gamma=args.gamma,
427
+ seeding_scheme=args.seeding_scheme,
428
+ extra_salt=salt,
429
+ device=device,
430
+ tokenizer=tokenizer,
431
+ z_threshold=args.detection_z_threshold,
432
+ normalizers=args.normalizers,
433
+ ignore_repeated_bigrams=args.ignore_repeated_bigrams,
434
+ select_green_tokens=args.select_green_tokens)
435
+ score_dict = watermark_detector.detect(input_text, return_green_token_mask=return_green_token_mask)
436
+ if score_dict['prediction']:
437
+ print(f"检测到是“{name}”的水印")
438
+ break
439
+
440
+ green_token_mask = score_dict.pop("green_token_mask", None)
441
+ output = list_format_scores(score_dict, watermark_detector.z_threshold)
442
+ except ValueError as e:
443
+ print(e)
444
+ error = True
445
+ if error:
446
+ output = [["Error", "string too short to compute metrics"]]
447
+ output += [["", ""] for _ in range(6)]
448
+
449
+ html_output = "[No highlight markup generated]"
450
+
451
+ if green_token_mask is None:
452
+ html_output = "[Visualizing masks with ignore_repeated_bigrams enabled is not supported, toggle off to see the mask for this text. The mask is the same in both cases - only counting/stats are affected.]"
453
+
454
+ if green_token_mask is not None:
455
+ # hack 因为我们需要一个带有字符跨度支持的快速分词器
456
+ tokens = tokenizer(input_text, add_special_tokens=False)
457
+ if tokens["input_ids"][0] == tokenizer.bos_token_id:
458
+ tokens["input_ids"] = tokens["input_ids"][1:] # 忽略注意力掩码
459
+ skip = watermark_detector.min_prefix_len
460
+
461
+ if args.model_name_or_path in ['THUDM/chatglm3-6b']:
462
+ # 假设词表中3-258就是字节0-255
463
+ charspans = []
464
+ for i in range(skip, len(tokens["input_ids"])):
465
+ if tokens.data['input_ids'][i - 1] in range(3, 259):
466
+ charspans.append("<0x{:X}>".format(tokens.data['input_ids'][i - 1] - 3))
467
+ else:
468
+ charspans.append(tokenizer.decode(tokens.data['input_ids'][i - 1:i]))
469
+
470
+ else:
471
+ charspans = [tokens.token_to_chars(i - 1) for i in range(skip, len(tokens["input_ids"]))]
472
+
473
+ charspans = [cs for cs in charspans if cs is not None] # remove the special token spans
474
+
475
+ if len(charspans) != len(green_token_mask): breakpoint()
476
+ assert len(charspans) == len(green_token_mask)
477
+
478
+ if args.model_name_or_path in ['THUDM/chatglm3-6b']:
479
+ tags = []
480
+ for cs, m in zip(charspans, green_token_mask):
481
+ tags.append(
482
+ f'<span class="green">{cs}</span>' if m else f'<span class="red">{cs}</span>')
483
+
484
+ else:
485
+ tags = [(
486
+ f'<span class="green">{input_text[cs.start:cs.end]}</span>' if m else f'<span class="red">{input_text[cs.start:cs.end]}</span>')
487
+ for cs, m in zip(charspans, green_token_mask)]
488
+
489
+ html_output = f'<p>{" ".join(tags)}</p>'
490
+
491
+ if score_dict['prediction']:
492
+ html_look = gr.HTML("""<div style="width: 100%; font-size: 24px; height: 100px; border-radius: 20px; background-color: rgba(255, 0, 0, 0.25); display: flex; justify-content: center; align-items: center; color: white; font-weight: bold;">
493
+ <span>有 “{}” 的水印</span>
494
+ </div>""".format(name), visible=True)
495
+
496
+ else:
497
+ html_look = gr.HTML("""<div style="width: 100%; font-size: 24px; height: 100px; border-radius: 20px; background-color: rgba(0, 128, 0, 0.25); display: flex; justify-content: center; align-items: center; color: white; font-weight: bold; text-align: center;">
498
+ <span>无水印</span>
499
+ </div>""", visible=True)
500
+
501
+ return output, args, tokenizer, html_output, html_look
502
+
503
+
504
+ def run_gradio(args, model=None, device=None, tokenizer=None):
505
+ """定义并启动gradio演示界面"""
506
+ check_prompt_partial = partial(check_prompt, model=model, device=device)
507
+ generate_partial = spaces.GPU(partial(generate, model=model, device=device))
508
+ detect_partial = partial(detect, device=device)
509
+
510
+ css = """
511
+ .green { color: black!important;line-height:1.9em; padding: 0.2em 0.2em; background: #ccffcc; border-radius:0.5rem;}
512
+ .red { color: black!important;line-height:1.9em; padding: 0.2em 0.2em; background: #ffad99; border-radius:0.5rem;}
513
+ """
514
+
515
+ with gr.Blocks(theme='ParityError/Interstellar', css=css) as demo:
516
+ # 顶部部分,问候语和说明
517
+ with gr.Row():
518
+ with gr.Column(scale=9):
519
+ gr.Markdown(
520
+ """
521
+ # 🌸🖼️ 追本溯源—面向大语言模型生成内容的水印版权保护系统 🌟🎓
522
+ """
523
+ )
524
+ with gr.Column(scale=1):
525
+ # 如果启动时的 model_name_or_path 不是 API 模型之一,则添加到下拉菜单中
526
+ all_models = sorted(list(set(list(API_MODEL_MAP.keys()) + [args.model_name_or_path])))
527
+ model_selector = gr.Dropdown(
528
+ all_models,
529
+ value=args.model_name_or_path,
530
+ label="选择大语言模型,进行模型水印",
531
+ )
532
+
533
+ # 构建参数的状态,定义更新和切换
534
+ default_prompt = args.__dict__.pop("default_prompt")
535
+ session_args = gr.State(value=args)
536
+ # 注意,如果状态对象是可调用的,则自动调用 value,希望在启动时避免调用分词器
537
+ session_tokenizer = gr.State(value=lambda: tokenizer)
538
+
539
+ with gr.Tab("生成回答和添加文本水印🎓"):
540
+
541
+ with gr.Row():
542
+ with gr.Column(scale=5):
543
+ prompt = gr.Textbox(label=f"Prompt", interactive=True, lines=3, max_lines=10, value=default_prompt)
544
+ with gr.Column(scale=3):
545
+ trace_source = gr.Dataframe(default_trace_table, datatype=['number', 'str'], interactive=True,
546
+ col_count=(2, "fixed"))
547
+ with gr.Row(equal_height=True):
548
+ with gr.Column(scale=7):
549
+ generate_btn = gr.Button("Generate", variant='primary')
550
+
551
+ gr.Markdown('水印选择:',
552
+ show_label=False)
553
+ watermark_salt_choice = gr.Dropdown(
554
+ choices=[i[::-1] for i in default_trace_table.to_dict(orient='split')['data']],
555
+ value=0,
556
+ container=False,
557
+ scale=3,
558
+ type="value",
559
+ interactive=True, label="水印标识选择")
560
+
561
+ with gr.Row():
562
+ with gr.Column():
563
+ with gr.Column(scale=2):
564
+ with gr.Tab("原版输出"):
565
+ output_without_watermark = gr.Textbox(interactive=False, lines=7, max_lines=14,
566
+ show_label=False)
567
+ with gr.Tab("显示水印"):
568
+ html_without_watermark = gr.HTML(elem_id="html-without-watermark")
569
+
570
+ original_watermark_state = gr.HTML('', visible=False)
571
+
572
+ with gr.Column(scale=1):
573
+ without_watermark_detection_result = gr.Dataframe(headers=["Metric", "Value"],
574
+ interactive=False,
575
+ row_count=7, col_count=2)
576
+ with gr.Column():
577
+ with gr.Column(scale=2):
578
+ with gr.Tab("带水印的输出"):
579
+ output_with_watermark = gr.Textbox(interactive=False, lines=7, max_lines=14,
580
+ show_label=False)
581
+ with gr.Tab("显示水印"):
582
+ html_with_watermark = gr.HTML(elem_id="html-with-watermark")
583
+
584
+ change_watermark_state = gr.HTML('', visible=False)
585
+
586
+ with gr.Column(scale=1):
587
+ with_watermark_detection_result = gr.Dataframe(headers=["Metric", "Value"], interactive=False,
588
+ row_count=7, col_count=2)
589
+
590
+ redecoded_input = gr.Textbox(visible=False)
591
+ truncation_warning = gr.Number(visible=False)
592
+
593
+ def truncate_prompt(redecoded_input, truncation_warning, orig_prompt, args):
594
+ if truncation_warning:
595
+ return redecoded_input + f"\n\n[Prompt was truncated before generation due to length...]", args
596
+ else:
597
+ return orig_prompt, args
598
+
599
+ with gr.Tab("检测文本水印功能🎭"):
600
+ with gr.Row():
601
+ with gr.Column(scale=5):
602
+ with gr.Tab("分析文本"):
603
+ detection_input = gr.Textbox(interactive=True, lines=14, max_lines=14, show_label=False)
604
+ with gr.Tab("显示水印"):
605
+ html_detection_input = gr.HTML(elem_id="html-detection-input")
606
+
607
+ detect_watermark_state = gr.HTML('', visible=False)
608
+ with gr.Column(scale=2):
609
+ trace_source2 = gr.Dataframe(default_trace_table, datatype=['number', 'str'], interactive=True,
610
+ col_count=(2, "fixed"))
611
+ detection_result = gr.Dataframe(headers=["Metric", "Value"], interactive=False, row_count=7,
612
+ col_count=2)
613
+
614
+ with gr.Row():
615
+ detect_btn = gr.Button("检测", variant='primary')
616
+
617
+ with gr.Tab("About📖"):
618
+ with gr.Row():
619
+ with gr.Column(scale=2):
620
+ gr.Markdown(
621
+ """
622
+ 大语言模型可能带来的潜在危害可以通过*水印*来减轻。*水印*是嵌入在生成的文本中的信息,
623
+ 这对人类来说是不可见的,但是可以被特定算法检测到。
624
+ 这些水印可以使*任何人*用特定工具判断其是否使用带水印的模型生成的。
625
+ 本网站展示了一种水印方法,可以应用于_任何_生成性语言模型。
626
+ """
627
+ )
628
+ gr.Markdown(
629
+ """
630
+ **[生成文本与添加水印]**:可以给大模型的输出添加水印。
631
+ 您可以尝试任何prompt,并比较正常文本(*没有水印的输出*)和水印文本(*有水印的输出*)的质量。
632
+ 您还可以点击**显示水印**来“看到”水印,其中的颜色表示其所在的红绿表。
633
+
634
+ **[检测]**:您还可以将水印文本(或任何其他文本)复制粘贴到第二个选项卡中。
635
+ 可以实验删除多少句子后还能检测到水印。
636
+ 还可以在验证,检测器的误报率有多少;
637
+ """
638
+ )
639
+
640
+ with gr.Column(scale=1):
641
+ gr.Markdown(
642
+ """
643
+ ![]()
644
+ """
645
+ )
646
+
647
+ # 参数选择组
648
+ with gr.Accordion("高级设置", open=False):
649
+ with gr.Row():
650
+ with gr.Column(scale=1):
651
+ gr.Markdown(f"#### 生成参数")
652
+ with gr.Row():
653
+ decoding = gr.Radio(label="解码方法", choices=["多项式解码方法", "贪心解码方法"],
654
+ value=("multinomial" if args.use_sampling else "greedy"))
655
+ with gr.Row():
656
+ sampling_temp = gr.Slider(label="采样温度", minimum=0.1, maximum=1.0, step=0.1,
657
+ value=args.sampling_temp, visible=True)
658
+ with gr.Row():
659
+ generation_seed = gr.Number(label="生成种子", value=args.generation_seed, interactive=True)
660
+ with gr.Row():
661
+ n_beams = gr.Dropdown(label="波束搜索解码", choices=list(range(1, 11, 1)), value=args.n_beams,
662
+ visible=((not args.use_sampling) and (
663
+ not args.model_name_or_path in API_MODEL_MAP)))
664
+ with gr.Row():
665
+ max_new_tokens = gr.Slider(label="(生成文本的最大长度)Max Generated Tokens", minimum=10,
666
+ maximum=4000, step=10, value=args.max_new_tokens)
667
+
668
+ with gr.Column(scale=1):
669
+ gr.Markdown(f"#### 模型水印参数设置")
670
+ with gr.Row():
671
+ gamma = gr.Slider(label="gamma", minimum=0.1, maximum=0.9, step=0.05, value=args.gamma)
672
+ with gr.Row():
673
+ delta = gr.Slider(label="delta", minimum=0.0, maximum=10.0, step=0.1, value=args.delta)
674
+ gr.Markdown(f"#### 检测文本水印参数设置")
675
+ with gr.Row():
676
+ detection_z_threshold = gr.Slider(label="Z分数阈值", minimum=0.0, maximum=10.0, step=0.1,
677
+ value=args.detection_z_threshold)
678
+ with gr.Row():
679
+ ignore_repeated_bigrams = gr.Checkbox(label="避免生成连续重复的双词组合")
680
+ with gr.Row():
681
+ normalizers = gr.CheckboxGroup(label="对文本进行标准化处理",
682
+ choices=["unicode", "homoglyphs", "truecase"],
683
+ value=args.normalizers)
684
+ with gr.Row():
685
+ gr.Markdown(
686
+ f"注意:滑块并不总是能完美更新。点击条形图或使用右侧的数字窗口会有所帮助。下面的窗口显示当前设置。")
687
+ with gr.Row():
688
+ current_parameters = gr.Textbox(label="当前参数设置", value=args, max_lines=10)
689
+ with gr.Accordion("传统设置", open=False):
690
+ with gr.Row():
691
+ with gr.Column(scale=1):
692
+ seed_separately = gr.Checkbox(label="为两个不同的生成过程分别设置随机种子",
693
+ value=args.seed_separately)
694
+ with gr.Column(scale=1):
695
+ select_green_tokens = gr.Checkbox(label="从分区中选择 绿色列表", value=args.select_green_tokens)
696
+
697
+ with gr.Accordion("设置有什么作用?", open=False):
698
+ gr.Markdown(
699
+ """
700
+ #### 生成参数:
701
+
702
+ - **解码方法**:我们可以使用多项式采样或者贪婪解码的方式从模型中生成标记。
703
+ 决定如何从模型中生成token。可以选择多项式采样或贪婪解码。
704
+ 多项式采样允许一定的随机性,而贪婪解码总是选择概率最高的下一个token。
705
+ - **采样温度**:如果使用多项式采样,我们可以设置采样分布的温度。
706
+ - 0.0 相当于贪婪解码,而 1.0 代表最大的随机性。
707
+ - 0.7是文本质量和随机性之间的平衡点。不适用于贪婪解码。
708
+ - **生成种子**:用于在生成前初始化随机数生成器,使多项式采样的输出可复现。此设置不适用于贪婪解码。
709
+ - **束搜索数量**:在使用贪婪解码时,可以设置光束数量来启用束搜索。
710
+ 这允许考虑多个候选序列,而不是只选择最有可能的一个。此设置目前仅适用于贪婪解码。
711
+ - **最大生成标记数**:传递给生成方法的 `max_new_tokens` 参数,以在一定数量的新标记停止输出。
712
+ - 请注意,根据提示,模型可以生成较少的标记。隐含地,这将最大化可能的提示标记数,即模型的最大输入长度减去 `max_new_tokens`,并相应地截断输入。
713
+
714
+ 综上所述,这些参数提供了对生成过程的不同方面的控制,包括随机性和多样性
715
+ (解码方法、采样温度),可复现性(生成种子),以及输出长度和多样性(光束数量、最大生成token数)。
716
+ 合理配置这些参数可以帮助生成高质量的文本输出。
717
+
718
+ #### 水印参数:
719
+
720
+ - **gamma**:在每个生成步骤中将词汇表的一部分划分为绿色列表的比例。
721
+ - 较小的 gamma 值通过使水印模型优先从较小的绿色集中采样,从而使其与人类/未水印文本的差异更大,从而创建更强的水印。
722
+ - **delta**:在每个生成步骤中为绿色列表中的每个标记的 logits 添加的正偏差量。较高的 delta 值意味着水印模型更偏好于绿色列表中的标记。
723
+ - 随着偏差变得非常大,水印从 "软" 过渡到 "硬"。对于硬水印,几乎所有标记都是绿色的,但这可能对生成质量产生不利影响,特别是当分布的灵活性不大时。
724
+
725
+ #### 检测器参数:
726
+
727
+ - **z 分数阈值**:假设检验的 z 分数截断。较高的阈值(例如 4.0)使得 _false positives_(预测人类/未水印文本被标记为水印)非常不可能,
728
+ 因为一个真正的人类文本几乎永远不会达到那么高的 z 分数。
729
+ - 较低的阈值会捕获更多的 _true positives_,因为一些水印文本可能包含较少的绿色标记并达到较低的 z 分数,但仍然通过较低的门槛并被标记为 "水印"。
730
+ 然而,较低的阈值会增加包含略高于平均水平的绿色标记的人类文本错误地被标记为水印的几率。4.0-5.0 提供了极低的假阳性率,同时仍准确捕获大多数水印文本。
731
+ - **忽略二元重复**:这种替代的检测算法在检测期间仅考虑文本中的唯一二元组,根据每对中的第一个计算绿色列表,并检查第二个是��位于列表中。
732
+ - 这意味着 `T` 现在是文本中唯一二元组的数量,如果文本包含大量重复,则这个数字将小于生成的总标记数。有关更详细的讨论,请参阅论文。
733
+ - **归一化**:我们实现了一些基本的归一化来抵御文本在检测期间的各种对抗性扰动。
734
+ - 目前,我们支持将所有字符转换为 Unicode,将同形异义字符替换为规范形式,并标准化大写。有关输入归一化的详细讨论,请参阅论文。
735
+ """
736
+ )
737
+
738
+ with gr.Accordion("输出指标意味着什么?", open=False):
739
+ gr.Markdown(
740
+ """
741
+ - `z-score threshold`:假设检验的截止值。
742
+ - `Tokens Counted (T)`:检测算法计算的输出中的标记数量。在简单的、单标记播种方案中,第一个标记被省略了,因为没有办法为其生成绿色列表,因为它没有前缀标记。
743
+ 在底部面板描述的“忽略二元重复”检测算法下,如果有很多重复,这个数量可能远少于生成的总标记数。
744
+ - ` Tokens in Greenlist`:观察到落在其相应绿色列表中的标记数量。
745
+ - `Fraction of T in Greenlist`:`# Tokens in Greenlist` / `T`。这应该大约等于人类/未水印文本的 `gamma`。
746
+ - `z-score`:用于检测假设检验的检验统计量。如果大于 `z-score threshold`,则我们“拒绝零假设”,即文本是人类/未水印的,并得出结论它是带水印的。
747
+ - `p value`:在零假设下观察到计算出的 `z-score` 的可能性。这是在不知道水印程序/绿色列表的情况下观察到 `Fraction of T in Greenlist` 的可能性。
748
+ 如果这个值极其 _小_,我们可以确信这么多的绿色标记不是由随机机会选择的。
749
+ - `prediction`:假设检验的结果 - 观察到的 `z-score` 是否高于 `z-score threshold`。
750
+ - `confidence`:如果我们拒绝零假设,且 `prediction` 是“带水印”,那么我们报告 1-`p value` 来表示基于这个 `z-score` 观察的检测的置信度。
751
+ """
752
+ )
753
+
754
+ gr.HTML("""
755
+ <p>本方法可以对任何大模型的输出结果进行水印的操作。
756
+ 并且可以对输出结果进行水印检测。
757
+ <p/>
758
+ """)
759
+
760
+ # 注册主要生成标签单击事件,输出生成文本以及编码+重新解码+可能被截断的提示和标志,然后调用检测
761
+ generate_btn.click(fn=check_prompt_partial, inputs=[prompt, session_args, session_tokenizer],
762
+ outputs=[redecoded_input, truncation_warning, session_args]).success(
763
+ fn=generate_partial, inputs=[redecoded_input, session_args, session_tokenizer],
764
+ outputs=[output_without_watermark, output_with_watermark]).success(
765
+ fn=detect_partial, inputs=[output_without_watermark, session_args, session_tokenizer],
766
+ outputs=[without_watermark_detection_result, session_args, session_tokenizer,
767
+ html_without_watermark, original_watermark_state]).success(
768
+ fn=detect_partial, inputs=[output_with_watermark, session_args, session_tokenizer],
769
+ outputs=[with_watermark_detection_result, session_args, session_tokenizer, html_with_watermark,
770
+ change_watermark_state])
771
+ # 如果发生了截断,则显示提示的截断版本
772
+ redecoded_input.change(fn=truncate_prompt, inputs=[redecoded_input, truncation_warning, prompt, session_args],
773
+ outputs=[prompt, session_args])
774
+ # Register main detection tab click
775
+ detect_btn.click(fn=detect_partial, inputs=[detection_input, session_args, session_tokenizer],
776
+ outputs=[detection_result, session_args, session_tokenizer, html_detection_input,
777
+ detect_watermark_state],
778
+ api_name="detection")
779
+
780
+ # 状态管理逻辑
781
+ # 定义更新回调函数以更改状态字典
782
+ def update_model(session_state, value):
783
+ session_state.model_name_or_path = value
784
+ return session_state
785
+
786
+ def update_sampling_temp(session_state, value):
787
+ session_state.sampling_temp = float(value)
788
+ return session_state
789
+
790
+ def update_generation_seed(session_state, value):
791
+ session_state.generation_seed = int(value)
792
+ return session_state
793
+
794
+ def update_watermark_salt(value):
795
+ global watermark_salt
796
+ if isinstance(value, int):
797
+ watermark_salt = value
798
+ elif value is None:
799
+ watermark_salt = 0
800
+ elif isinstance(value, str) and value.isdigit():
801
+ watermark_salt = int(value)
802
+ else:
803
+ # 不知道为什么会出现这种倒置的情况
804
+ watermark_salt = int(
805
+ default_trace_table.loc[default_trace_table['水印内容'] == value, '编号'].item())
806
+
807
+ def update_trace_source(value):
808
+ global default_trace_table
809
+ try:
810
+ if '' in value.loc[:, '编号'].tolist():
811
+ return value, gr.Dropdown()
812
+
813
+ value.loc[:, '编号'] = value.loc[:, '编号'].astype(int)
814
+
815
+ if default_trace_table.duplicated(subset='编号').any():
816
+ raise gr.Error(f"请检查水印编号,编号不能重复")
817
+
818
+ default_trace_table = value
819
+
820
+ return value, gr.Dropdown(
821
+ choices=[i[::-1] for i in value.to_dict(orient='split')['data']])
822
+
823
+ except ValueError as e:
824
+ if 'invalid literal for int() with base 10' in str(e):
825
+ raise gr.Error(f"请检查水印数据,编号必须是整数:{e}")
826
+
827
+ except gradio.exceptions.Error as e:
828
+ raise e
829
+
830
+ except Exception as e:
831
+ print(type(e))
832
+ raise e
833
+
834
+ def update_gamma(session_state, value):
835
+ session_state.gamma = float(value)
836
+ return session_state
837
+
838
+ def update_delta(session_state, value):
839
+ session_state.delta = float(value)
840
+ return session_state
841
+
842
+ def update_detection_z_threshold(session_state, value):
843
+ session_state.detection_z_threshold = float(value)
844
+ return session_state
845
+
846
+ def update_decoding(session_state, value):
847
+ if value == "multinomial":
848
+ session_state.use_sampling = True
849
+ elif value == "greedy":
850
+ session_state.use_sampling = False
851
+ return session_state
852
+
853
+ def toggle_sampling_vis(value):
854
+ if value == "multinomial":
855
+ return gr.update(visible=True)
856
+ elif value == "greedy":
857
+ return gr.update(visible=False)
858
+
859
+ def toggle_sampling_vis_inv(value):
860
+ if value == "multinomial":
861
+ return gr.update(visible=False)
862
+ elif value == "greedy":
863
+ return gr.update(visible=True)
864
+
865
+ # 如果模型名称在 API 模型列表中,则将 num beams 参数设置为 1 并隐藏 n_beams
866
+ def toggle_vis_for_api_model(value):
867
+ if value in API_MODEL_MAP:
868
+ return gr.update(visible=False)
869
+ else:
870
+ return gr.update(visible=True)
871
+
872
+ def toggle_beams_for_api_model(value, orig_n_beams):
873
+ if value in API_MODEL_MAP:
874
+ return gr.update(value=1)
875
+ else:
876
+ return gr.update(value=orig_n_beams)
877
+
878
+ # 如果模型名称在 API 模型列表中,则将交互参数设置为 false
879
+ def toggle_interactive_for_api_model(value):
880
+ if value in API_MODEL_MAP:
881
+ return gr.update(interactive=False)
882
+ else:
883
+ return gr.update(interactive=True)
884
+
885
+ # 如果模型名称在 API 模型列表中,则根据 API 映射设置 gamma 和 delta
886
+ def toggle_gamma_for_api_model(value, orig_gamma):
887
+ if value in API_MODEL_MAP:
888
+ return gr.update(value=API_MODEL_MAP[value]["gamma"])
889
+ else:
890
+ return gr.update(value=orig_gamma)
891
+
892
+ def toggle_delta_for_api_model(value, orig_delta):
893
+ if value in API_MODEL_MAP:
894
+ return gr.update(value=API_MODEL_MAP[value]["delta"])
895
+ else:
896
+ return gr.update(value=orig_delta)
897
+
898
+ def update_n_beams(session_state, value):
899
+ session_state.n_beams = value;
900
+ return session_state
901
+
902
+ def update_max_new_tokens(session_state, value):
903
+ session_state.max_new_tokens = int(value);
904
+ return session_state
905
+
906
+ def update_ignore_repeated_bigrams(session_state, value):
907
+ session_state.ignore_repeated_bigrams = value;
908
+ return session_state
909
+
910
+ def update_normalizers(session_state, value):
911
+ session_state.normalizers = value;
912
+ return session_state
913
+
914
+ def update_seed_separately(session_state, value):
915
+ session_state.seed_separately = value;
916
+ return session_state
917
+
918
+ def update_select_green_tokens(session_state, value):
919
+ session_state.select_green_tokens = value;
920
+ return session_state
921
+
922
+ def update_tokenizer(model_name_or_path):
923
+ # if model_name_or_path == ALPACA_MODEL_NAME:
924
+ # return ALPACA_MODEL_TOKENIZER.from_pretrained(ALPACA_TOKENIZER_PATH)
925
+ # else:
926
+ return AutoTokenizer.from_pretrained(model_name_or_path)
927
+
928
+ def check_model(value):
929
+ return value if (value != "" and value is not None) else args.model_name_or_path
930
+
931
+ # 强制约束模型不能为 null 或���
932
+ # 然后特别附加模型回调函数
933
+ model_selector.change(check_model, inputs=[model_selector], outputs=[model_selector]).then(
934
+ toggle_vis_for_api_model, inputs=[model_selector], outputs=[n_beams]
935
+ ).then(
936
+ toggle_beams_for_api_model, inputs=[model_selector, n_beams], outputs=[n_beams]
937
+ ).then(
938
+ toggle_interactive_for_api_model, inputs=[model_selector], outputs=[gamma]
939
+ ).then(
940
+ toggle_interactive_for_api_model, inputs=[model_selector], outputs=[delta]
941
+ ).then(
942
+ toggle_gamma_for_api_model, inputs=[model_selector, gamma], outputs=[gamma]
943
+ ).then(
944
+ toggle_delta_for_api_model, inputs=[model_selector, delta], outputs=[delta]
945
+ ).then(
946
+ update_tokenizer, inputs=[model_selector], outputs=[session_tokenizer]
947
+ ).then(
948
+ update_model, inputs=[session_args, model_selector], outputs=[session_args]
949
+ ).then(
950
+ lambda value: str(value), inputs=[session_args], outputs=[current_parameters]
951
+ )
952
+ # 根据其他参数的值注册回调函数以切换特定参数的可见性
953
+ decoding.change(toggle_sampling_vis, inputs=[decoding], outputs=[sampling_temp])
954
+ decoding.change(toggle_sampling_vis, inputs=[decoding], outputs=[generation_seed])
955
+ decoding.change(toggle_sampling_vis_inv, inputs=[decoding], outputs=[n_beams])
956
+ decoding.change(toggle_vis_for_api_model, inputs=[model_selector], outputs=[n_beams])
957
+ # 注册所有状态更新回调函数
958
+ decoding.change(update_decoding, inputs=[session_args, decoding], outputs=[session_args])
959
+ sampling_temp.change(update_sampling_temp, inputs=[session_args, sampling_temp], outputs=[session_args])
960
+ generation_seed.change(update_generation_seed, inputs=[session_args, generation_seed], outputs=[session_args])
961
+ watermark_salt_choice.change(update_watermark_salt, inputs=[watermark_salt_choice])
962
+
963
+ # 同步更新
964
+ trace_source.change(update_trace_source, inputs=[trace_source],
965
+ outputs=[trace_source2, watermark_salt_choice])
966
+ trace_source2.change(update_trace_source, inputs=[trace_source2],
967
+ outputs=[trace_source, watermark_salt_choice])
968
+
969
+ n_beams.change(update_n_beams, inputs=[session_args, n_beams], outputs=[session_args])
970
+ max_new_tokens.change(update_max_new_tokens, inputs=[session_args, max_new_tokens], outputs=[session_args])
971
+ gamma.change(update_gamma, inputs=[session_args, gamma], outputs=[session_args])
972
+ delta.change(update_delta, inputs=[session_args, delta], outputs=[session_args])
973
+ detection_z_threshold.change(update_detection_z_threshold, inputs=[session_args, detection_z_threshold],
974
+ outputs=[session_args])
975
+ ignore_repeated_bigrams.change(update_ignore_repeated_bigrams, inputs=[session_args, ignore_repeated_bigrams],
976
+ outputs=[session_args])
977
+ normalizers.change(update_normalizers, inputs=[session_args, normalizers], outputs=[session_args])
978
+ seed_separately.change(update_seed_separately, inputs=[session_args, seed_separately], outputs=[session_args])
979
+ select_green_tokens.change(update_select_green_tokens, inputs=[session_args, select_green_tokens],
980
+ outputs=[session_args])
981
+ # 注册按钮点击时更新显示参数窗口的额外回调
982
+ generate_btn.click(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
983
+ detect_btn.click(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
984
+ # 当参数更改时,显示更新并触发检测,因为某些检测参数不会改变模型输出。
985
+ delta.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
986
+ gamma.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
987
+ gamma.change(fn=detect_partial, inputs=[output_without_watermark, session_args, session_tokenizer],
988
+ outputs=[without_watermark_detection_result, session_args, session_tokenizer,
989
+ html_without_watermark])
990
+ gamma.change(fn=detect_partial, inputs=[output_with_watermark, session_args, session_tokenizer],
991
+ outputs=[with_watermark_detection_result, session_args, session_tokenizer, html_with_watermark])
992
+ gamma.change(fn=detect_partial, inputs=[detection_input, session_args, session_tokenizer],
993
+ outputs=[detection_result, session_args, session_tokenizer, html_detection_input])
994
+ detection_z_threshold.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
995
+ detection_z_threshold.change(fn=detect_partial,
996
+ inputs=[output_without_watermark, session_args, session_tokenizer],
997
+ outputs=[without_watermark_detection_result, session_args, session_tokenizer,
998
+ html_without_watermark])
999
+ detection_z_threshold.change(fn=detect_partial, inputs=[output_with_watermark, session_args, session_tokenizer],
1000
+ outputs=[with_watermark_detection_result, session_args, session_tokenizer,
1001
+ html_with_watermark])
1002
+ detection_z_threshold.change(fn=detect_partial, inputs=[detection_input, session_args, session_tokenizer],
1003
+ outputs=[detection_result, session_args, session_tokenizer, html_detection_input])
1004
+ ignore_repeated_bigrams.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
1005
+ ignore_repeated_bigrams.change(fn=detect_partial,
1006
+ inputs=[output_without_watermark, session_args, session_tokenizer],
1007
+ outputs=[without_watermark_detection_result, session_args, session_tokenizer,
1008
+ html_without_watermark])
1009
+ ignore_repeated_bigrams.change(fn=detect_partial,
1010
+ inputs=[output_with_watermark, session_args, session_tokenizer],
1011
+ outputs=[with_watermark_detection_result, session_args, session_tokenizer,
1012
+ html_with_watermark])
1013
+ ignore_repeated_bigrams.change(fn=detect_partial, inputs=[detection_input, session_args, session_tokenizer],
1014
+ outputs=[detection_result, session_args, session_tokenizer,
1015
+ html_detection_input])
1016
+ normalizers.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
1017
+ normalizers.change(fn=detect_partial, inputs=[output_without_watermark, session_args, session_tokenizer],
1018
+ outputs=[without_watermark_detection_result, session_args, session_tokenizer,
1019
+ html_without_watermark])
1020
+ normalizers.change(fn=detect_partial, inputs=[output_with_watermark, session_args, session_tokenizer],
1021
+ outputs=[with_watermark_detection_result, session_args, session_tokenizer,
1022
+ html_with_watermark])
1023
+ normalizers.change(fn=detect_partial, inputs=[detection_input, session_args, session_tokenizer],
1024
+ outputs=[detection_result, session_args, session_tokenizer, html_detection_input])
1025
+ select_green_tokens.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
1026
+ select_green_tokens.change(fn=detect_partial,
1027
+ inputs=[output_without_watermark, session_args, session_tokenizer],
1028
+ outputs=[without_watermark_detection_result, session_args, session_tokenizer,
1029
+ html_without_watermark])
1030
+ select_green_tokens.change(fn=detect_partial, inputs=[output_with_watermark, session_args, session_tokenizer],
1031
+ outputs=[with_watermark_detection_result, session_args, session_tokenizer,
1032
+ html_with_watermark])
1033
+ select_green_tokens.change(fn=detect_partial, inputs=[detection_input, session_args, session_tokenizer],
1034
+ outputs=[detection_result, session_args, session_tokenizer, html_detection_input])
1035
+
1036
+ # demo.queue(concurrency_count=3) # delete
1037
+
1038
+ if args.demo_public:
1039
+ demo.launch(share=True) # 通过随机生成的链接将应用程序暴露到互联网上
1040
+ else:
1041
+ demo.launch(server_name='0.0.0.0', share=False)
1042
+
1043
+
1044
+ def main(args):
1045
+ """运行生成和检测操作的命令行版本
1046
+ 并可选择启动和提供 gradio 演示"""
1047
+ # 初始参数处理和日志记录
1048
+ args.normalizers = (args.normalizers.split(",") if args.normalizers else [])
1049
+ print(args)
1050
+
1051
+ if not args.skip_model_load:
1052
+ model, tokenizer, device = load_model(args)
1053
+ else:
1054
+ model, tokenizer, device = None, None, None
1055
+ tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
1056
+ if args.use_gpu:
1057
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
1058
+ else:
1059
+ device = "cpu"
1060
+
1061
+ # terrapin example
1062
+ input_text = (
1063
+ "为什么A股指数跌的不多,但是我亏损比之前都多?"
1064
+ )
1065
+
1066
+ args.default_prompt = input_text
1067
+
1068
+ # Generate and detect, report to stdout
1069
+ if not args.skip_model_load:
1070
+ pass
1071
+
1072
+ # Launch the app to generate and detect interactively (implements the hf space demo)
1073
+ if args.run_gradio:
1074
+ run_gradio(args, model=model, tokenizer=tokenizer, device=device)
1075
+
1076
+ return
1077
+
1078
+
1079
+ if __name__ == "__main__":
1080
+ args = parse_args()
1081
+ print(args)
1082
+
1083
+ main(args)