import argparse
import os
import time
from functools import partial

import spaces
import gradio.exceptions
import gradio as gr
import pandas as pd
import torch
from transformers import (AutoTokenizer, AutoModelForCausalLM, LogitsProcessorList)

from watermark_processor import WatermarkLogitsProcessor, WatermarkDetector

# FIXME: correct max lengths for all the models
API_MODEL_MAP = {
    # "Qwen/Qwen1.5-0.5B-Chat": {"max_length": 2000, "gamma": 0.5, "delta": 2.0},
    # "THUDM/chatglm3-6b": {"max_length": 2048, "gamma": 0.5, "delta": 2.0},
}

# Traceability table: each row maps a numeric watermark ID to the message it stands for.
default_trace_table = pd.DataFrame(columns=["编号", "水印内容"])
default_trace_table.loc[0] = (0, "本文本由A模型生成")
default_trace_table.loc[1] = (1, "本文本由B模型生成")
default_trace_table.loc[2] = (2, "本文本由用户小王生成")

watermark_salt = 0


def str2bool(v):
    """Utility for user-friendly boolean flag arguments."""
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')


# Define a function to parse the command-line arguments
def parse_args():
    parser = argparse.ArgumentParser(description="")
    parser.add_argument(
        "--run_gradio", type=str2bool, default=True,
        help="Whether to launch as a gradio demo. Set to False if not installed and want to just run the stdout version.")
    parser.add_argument(
        "--demo_public", type=str2bool, default=False,
        help="Whether to expose the gradio demo to the internet.")
    parser.add_argument(
        "--model_name_or_path", type=str, default="Qwen/Qwen1.5-0.5B-Chat",
        help="Main model, path to pretrained model or model identifier from huggingface.co/models.")
    parser.add_argument(
        "--prompt_max_length", type=int, default=None,
        help="Truncation length for prompt, overrides model config's max length field.")
    parser.add_argument(
        "--max_new_tokens", type=int, default=200,
        help="Maximum number of new tokens to generate.")
    parser.add_argument(
        "--generation_seed", type=int, default=123,
        help="Seed for setting the torch global rng prior to generation.")
    parser.add_argument(
        "--use_sampling", type=str2bool, default=True,
        help="Whether to generate using multinomial sampling.")
    parser.add_argument(
        "--sampling_temp", type=float, default=0.7,
        help="Sampling temperature to use when generating using multinomial sampling.")
    parser.add_argument(
        "--n_beams", type=int, default=1,
        help="Number of beams to use for beam search. 1 is normal greedy decoding.")
    parser.add_argument(
        "--use_gpu", type=str2bool, default=True,
        help="Whether to run inference and watermark hashing/seeding/permutation on gpu.")
    parser.add_argument(
        "--seeding_scheme", type=str, default="simple_1",
        help="Seeding scheme to use to generate the greenlists at each generation and verification step.")
    parser.add_argument(
        "--gamma", type=float, default=0.5,
        help="The fraction of the vocabulary to partition into the greenlist at each generation and verification step.")
    parser.add_argument(
        "--delta", type=float, default=2.0,
        help="The amount/bias to add to each of the greenlist token logits before each token sampling step.")
    parser.add_argument(
        "--normalizers", type=str, default="",
        help="Single or comma separated list of the preprocessors/normalizer names to use when performing watermark detection.")
    parser.add_argument(
        "--ignore_repeated_bigrams", type=str2bool, default=False,
        help="Whether to use the detection method that only counts each unique bigram once as either a green or red hit.")
    parser.add_argument(
        "--detection_z_threshold", type=float, default=4.0,
        help="The test statistic threshold for the detection hypothesis test.")
    parser.add_argument(
        "--select_green_tokens", type=str2bool, default=True,
        help="How to treat the permutation when selecting the greenlist tokens at each step. Legacy is (False) to pick the complement/reds first.")
    parser.add_argument(
        "--skip_model_load", type=str2bool, default=False,
        help="Skip the model loading to debug the interface.")
    parser.add_argument(
        "--gguf_file", type=str, default='./qwen2-0_5b-instruct-q2_k.gguf',
        help="GGUF file to load, if any.")
    parser.add_argument(
        "--seed_separately", type=str2bool, default=True,
        help="Whether to call the torch seed function before both the unwatermarked and watermarked generate calls.")
    args = parser.parse_args()
    return args


def load_model(args):
    """Load and return the model and tokenizer."""
    if args.use_gpu:
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
    else:
        device = "cpu"
    if 'gguf' in args.model_name_or_path.lower():
        model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, gguf_file=args.gguf_file,
                                                     trust_remote_code=True, local_files_only=True,
                                                     device_map=device)
        tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, gguf_file=args.gguf_file,
                                                  local_files_only=True, trust_remote_code=True)
    else:
        model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, trust_remote_code=True,
                                                     local_files_only=True, device_map=device)
        tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True,
                                                  local_files_only=True)
    try:
        model.eval()
    except Exception as e:
        print(e)
    return model, tokenizer, device


from text_generation import InferenceAPIClient
from requests.exceptions import ReadTimeout


def generate_with_api(prompt, args):
    hf_api_key = os.environ.get("HF_API_KEY")
    if hf_api_key is None:
        raise ValueError("HF_API_KEY environment variable not set, cannot use HF API to generate text.")

    client = InferenceAPIClient(args.model_name_or_path, token=hf_api_key, timeout=60)

    assert args.n_beams == 1, "HF API models do not support beam search."
    generation_params = {
        "max_new_tokens": args.max_new_tokens,
        "do_sample": args.use_sampling,
    }
    if args.use_sampling:
        generation_params["temperature"] = args.sampling_temp
        generation_params["seed"] = args.generation_seed

    timeout_msg = "[Model API timeout error. Try reducing the max_new_tokens parameter or the prompt length.]"
    try:
        generation_params["watermark"] = False
        without_watermark_iterator = client.generate_stream(prompt, **generation_params)
    except ReadTimeout as e:
        print(e)
        without_watermark_iterator = (char for char in timeout_msg)
    try:
        generation_params["watermark"] = True
        with_watermark_iterator = client.generate_stream(prompt, **generation_params)
    except ReadTimeout as e:
        print(e)
        with_watermark_iterator = (char for char in timeout_msg)

    all_without_words, all_with_words = "", ""
    for without_word, with_word in zip(without_watermark_iterator, with_watermark_iterator):
        all_without_words += without_word.token.text
        all_with_words += with_word.token.text
        yield all_without_words, all_with_words


def check_prompt(prompt, args, tokenizer, model=None, device=None):
    # This applies to both the local-model and the API-model scenarios
    try:
        if args.model_name_or_path in API_MODEL_MAP:
            args.prompt_max_length = API_MODEL_MAP[args.model_name_or_path]["max_length"]
        elif hasattr(model.config, "max_position_embeddings"):
            args.prompt_max_length = model.config.max_position_embeddings - args.max_new_tokens
        else:
            args.prompt_max_length = 4096 - args.max_new_tokens
    except Exception as e:
        print(e)
        args.prompt_max_length = 4096 - args.max_new_tokens

    tokd_input = tokenizer(prompt, return_tensors="pt", add_special_tokens=False, truncation=True,
                           max_length=args.prompt_max_length).to(device)
    truncation_warning = True if tokd_input["input_ids"].shape[-1] == args.prompt_max_length else False
    redecoded_input = tokenizer.batch_decode(tokd_input["input_ids"], skip_special_tokens=True)[0]

    return redecoded_input, int(truncation_warning), args


@spaces.GPU
def generate(prompt, args, tokenizer, model=None, device=None):
    """Instantiate a WatermarkLogitsProcessor from the watermark parameters and generate
    watermarked text by passing it to the model's generate method as a logits processor."""
    print(f"Generating with {args}")
    print(f"Prompt: {prompt}")

    if args.model_name_or_path in API_MODEL_MAP:
        api_outputs = generate_with_api(prompt, args)
        yield from api_outputs
    else:
        if ('chatglm' in args.model_name_or_path.lower()
                or 'qwen' in args.model_name_or_path.lower()
                or 'llama' in args.model_name_or_path.lower()):
            messages = [
                # {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": prompt}
            ]
            tokenized_input = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            tokd_input = tokenizer([tokenized_input], return_tensors="pt", truncation=True,
                                   add_special_tokens=False, max_length=args.prompt_max_length).to(device)
        else:
            tokd_input = tokenizer(prompt, return_tensors="pt", add_special_tokens=True, truncation=True,
                                   max_length=args.prompt_max_length).to(device)

        gen_kwargs = dict(max_new_tokens=args.max_new_tokens)
        if args.use_sampling:
            gen_kwargs.update(dict(do_sample=True, top_k=0, temperature=args.sampling_temp))
        else:
            gen_kwargs.update(dict(num_beams=args.n_beams))

        watermark_processor = WatermarkLogitsProcessor(vocab=list(tokenizer.get_vocab().values()),
                                                       gamma=args.gamma,
                                                       delta=args.delta,
                                                       seeding_scheme=args.seeding_scheme,
                                                       extra_salt=watermark_salt,
                                                       select_green_tokens=args.select_green_tokens)

        generate_without_watermark = partial(model.generate, **gen_kwargs)
        generate_with_watermark = partial(model.generate,
                                          logits_processor=LogitsProcessorList([watermark_processor]),
                                          **gen_kwargs)

        start_time = time.time()
        gr.Info('开始生成正常内容')
        torch.manual_seed(args.generation_seed)
        output_without_watermark = generate_without_watermark(**tokd_input)

        # Optionally seed again before the second generation; the outputs will not generally
        # match anyway unless delta == 0.0 (a no-op watermark).
        print(watermark_salt)
        print(default_trace_table)
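        # `watermark_salt` holds the ID currently selected in the traceability table; the lookup
        # below pulls that row's "水印内容" text so the UI can report which watermark message is
        # being embedded. The same value was passed to WatermarkLogitsProcessor as `extra_salt`
        # above, so each ID seeds a different green-list partition and detect() can later tell
        # the table entries apart by retrying detection with every salt.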
        print(default_trace_table.loc[default_trace_table['编号'] == watermark_salt, '水印内容'])
        gr.Info('开始注入水印:“{}”'.format(
            default_trace_table.loc[default_trace_table['编号'] == watermark_salt, '水印内容'].item()))
        if args.seed_separately:
            torch.manual_seed(args.generation_seed)
        output_with_watermark = generate_with_watermark(**tokd_input)

        output_without_watermark = output_without_watermark[:, tokd_input["input_ids"].shape[-1]:]
        output_with_watermark = output_with_watermark[:, tokd_input["input_ids"].shape[-1]:]

        decoded_output_without_watermark = tokenizer.batch_decode(output_without_watermark,
                                                                  skip_special_tokens=True)[0]
        decoded_output_with_watermark = tokenizer.batch_decode(output_with_watermark,
                                                               skip_special_tokens=True)[0]

        end_time = time.time()
        gr.Info(f"生成结束,共用时{end_time - start_time:.2f}秒")
        print(f"Generation took {end_time - start_time:.2f} seconds")

        # Mimic the API's streaming output by yielding space-separated words from a generator.
        all_without_words, all_with_words = "", ""
        for without_word, with_word in zip(decoded_output_without_watermark.split(),
                                           decoded_output_with_watermark.split()):
            all_without_words += without_word + " "
            all_with_words += with_word + " "
            yield all_without_words, all_with_words


def format_names(s):
    """Format metric names for the gradio demo UI."""
    s = s.replace("num_tokens_scored", "总Token")
    s = s.replace("num_green_tokens", "Green Token数量")
    s = s.replace("green_fraction", "Green Token占比")
    s = s.replace("z_score", "z-score")
    s = s.replace("p_value", "p value")
    s = s.replace("prediction", "预测结果")
    s = s.replace("confidence", "置信度")
    return s


def list_format_scores(score_dict, detection_threshold):
    """Format the detection metrics into the gradio dataframe input format."""
    lst_2d = []
    for k, v in score_dict.items():
        if k == 'green_fraction':
            lst_2d.append([format_names(k), f"{v:.1%}"])
        elif k == 'confidence':
            lst_2d.append([format_names(k), f"{v:.3%}"])
        elif isinstance(v, float):
            lst_2d.append([format_names(k), f"{v:.3g}"])
        elif isinstance(v, bool):
            lst_2d.append([format_names(k), ("含有水印" if v else "无水印")])
        else:
            lst_2d.append([format_names(k), f"{v}"])
    if "confidence" in score_dict:
        lst_2d.insert(-2, ["z-score Threshold", f"{detection_threshold}"])
    else:
        lst_2d.insert(-1, ["z-score Threshold", f"{detection_threshold}"])
    return lst_2d


def detect(input_text, args, tokenizer, device=None, return_green_token_mask=True):
    """Instantiate a WatermarkDetector object and call its detect method on the input text,
    returning the scores and outcome of the test."""
    print(f"Detecting with {args}")
    print(f"Detection Tokenizer: {type(tokenizer)}")

    # Don't compute the green token mask when normalizers or ignore_repeated_bigrams are in use.
    if args.normalizers != [] or args.ignore_repeated_bigrams:
        return_green_token_mask = False

    error = False
    green_token_mask = None
    if input_text == "":
        error = True
    else:
        try:
            # Try every salt in the traceability table until one triggers a positive detection.
            for _, data in default_trace_table.iterrows():
                salt = data["编号"]
                name = data["水印内容"]
                watermark_detector = WatermarkDetector(vocab=list(tokenizer.get_vocab().values()),
                                                       gamma=args.gamma,
                                                       seeding_scheme=args.seeding_scheme,
                                                       extra_salt=salt,
                                                       device=device,
                                                       tokenizer=tokenizer,
                                                       z_threshold=args.detection_z_threshold,
                                                       normalizers=args.normalizers,
                                                       ignore_repeated_bigrams=args.ignore_repeated_bigrams,
                                                       select_green_tokens=args.select_green_tokens)
                score_dict = watermark_detector.detect(input_text, return_green_token_mask=return_green_token_mask)
                if score_dict['prediction']:
                    print(f"检测到是“{name}”的水印")
                    break
            green_token_mask = score_dict.pop("green_token_mask", None)
            output = list_format_scores(score_dict, watermark_detector.z_threshold)
        except ValueError as e:
            print(e)
            error = True
    if error:
        output = [["Error", "string too short to compute metrics"]]
        output += [["", ""] for _ in range(6)]
    html_output = "[No highlight markup generated]"
    if green_token_mask is None:
        html_output = ("[Visualizing masks with ignore_repeated_bigrams enabled is not supported, "
                       "toggle off to see the mask for this text. The mask is the same in both cases - "
                       "only counting/stats are affected.]")
    if green_token_mask is not None:
        # hack because we need a fast tokenizer with char-span support
        tokens = tokenizer(input_text, add_special_tokens=False)
        if tokens["input_ids"][0] == tokenizer.bos_token_id:
            tokens["input_ids"] = tokens["input_ids"][1:]  # ignore the attention mask
        skip = watermark_detector.min_prefix_len

        if args.model_name_or_path in ['THUDM/chatglm3-6b']:
            # Assume vocab ids 3-258 map to the raw bytes 0-255.
            charspans = []
            for i in range(skip, len(tokens["input_ids"])):
                if tokens.data['input_ids'][i - 1] in range(3, 259):
                    charspans.append("<0x{:X}>".format(tokens.data['input_ids'][i - 1] - 3))
                else:
                    charspans.append(tokenizer.decode(tokens.data['input_ids'][i - 1:i]))
        else:
            charspans = [tokens.token_to_chars(i - 1) for i in range(skip, len(tokens["input_ids"]))]
            charspans = [cs for cs in charspans if cs is not None]  # remove the special token spans

        if len(charspans) != len(green_token_mask):
            breakpoint()
        assert len(charspans) == len(green_token_mask)

        if args.model_name_or_path in ['THUDM/chatglm3-6b']:
            tags = []
            for cs, m in zip(charspans, green_token_mask):
                tags.append(f'<span class="green">{cs}</span>' if m else f'<span class="red">{cs}</span>')
        else:
            tags = [(f'<span class="green">{input_text[cs.start:cs.end]}</span>'
                     if m else
                     f'<span class="red">{input_text[cs.start:cs.end]}</span>')
                    for cs, m in zip(charspans, green_token_mask)]
        html_output = f'<p>{" ".join(tags)}</p>'

    if score_dict['prediction']:
        html_look = gr.HTML("""本方法可以对任何大模型的输出结果进行水印的操作。
        并且可以对输出结果进行水印检测。
        """)

    # Register the main generation tab click: output the generated text as well as the encoded,
    # re-decoded (and possibly truncated) prompt and the truncation flag, then call detection.
    generate_btn.click(fn=check_prompt_partial, inputs=[prompt, session_args, session_tokenizer],
                       outputs=[redecoded_input, truncation_warning, session_args]).success(
        fn=generate_partial, inputs=[redecoded_input, session_args, session_tokenizer],
        outputs=[output_without_watermark, output_with_watermark]).success(
        fn=detect_partial, inputs=[output_without_watermark, session_args, session_tokenizer],
        outputs=[without_watermark_detection_result, session_args, session_tokenizer,
                 html_without_watermark, original_watermark_state]).success(
        fn=detect_partial, inputs=[output_with_watermark, session_args, session_tokenizer],
        outputs=[with_watermark_detection_result, session_args, session_tokenizer,
                 html_with_watermark, change_watermark_state])

    # If truncation occurred, show the truncated version of the prompt.
    redecoded_input.change(fn=truncate_prompt, inputs=[redecoded_input, truncation_warning, prompt, session_args],
                           outputs=[prompt, session_args])

    # Register main detection tab click
    detect_btn.click(fn=detect_partial, inputs=[detection_input, session_args, session_tokenizer],
                     outputs=[detection_result, session_args, session_tokenizer, html_detection_input,
                              detect_watermark_state],
                     api_name="detection")

    # State management logic
    # Define the update callbacks that mutate the session state dict.
    def update_model(session_state, value):
        session_state.model_name_or_path = value
        return session_state

    def update_sampling_temp(session_state, value):
        session_state.sampling_temp = float(value)
        return session_state

    def update_generation_seed(session_state, value):
        session_state.generation_seed = int(value)
        return session_state

    def update_watermark_salt(value):
        global watermark_salt
        if isinstance(value, int):
            watermark_salt = value
        elif value is None:
            watermark_salt = 0
        elif isinstance(value, str) and value.isdigit():
            watermark_salt = int(value)
        else:
            # Not sure why this inverted case occurs: look the ID up by its watermark text.
            watermark_salt = int(
                default_trace_table.loc[default_trace_table['水印内容'] == value, '编号'].item())

    def update_trace_source(value):
        global default_trace_table
        try:
            if '' in value.loc[:, '编号'].tolist():
                return value, gr.Dropdown()
            value.loc[:, '编号'] = value.loc[:, '编号'].astype(int)
            if value.duplicated(subset='编号').any():
                raise gr.Error("请检查水印编号,编号不能重复")
            default_trace_table = value
            return value, gr.Dropdown(
                choices=[i[::-1] for i in value.to_dict(orient='split')['data']])
        except ValueError as e:
            if 'invalid literal for int() with base 10' in str(e):
                raise gr.Error(f"请检查水印数据,编号必须是整数:{e}")
        except gradio.exceptions.Error as e:
            raise e
        except Exception as e:
            print(type(e))
            raise e

    def update_gamma(session_state, value):
        session_state.gamma = float(value)
        return session_state

    def update_delta(session_state, value):
        session_state.delta = float(value)
        return session_state

    def update_detection_z_threshold(session_state, value):
        session_state.detection_z_threshold = float(value)
        return session_state

    def update_decoding(session_state, value):
        if value == "multinomial":
            session_state.use_sampling = True
        elif value == "greedy":
            session_state.use_sampling = False
        return session_state

    def toggle_sampling_vis(value):
        if value == "multinomial":
            return gr.update(visible=True)
        elif value == "greedy":
            return gr.update(visible=False)

    def toggle_sampling_vis_inv(value):
        if value == "multinomial":
            return gr.update(visible=False)
        elif value == "greedy":
            return gr.update(visible=True)

    # If the model name is in the API model list, set num beams to 1 and hide the n_beams control.
    def toggle_vis_for_api_model(value):
        if value in API_MODEL_MAP:
            return gr.update(visible=False)
        else:
            return gr.update(visible=True)

    def toggle_beams_for_api_model(value, orig_n_beams):
        if value in API_MODEL_MAP:
            return gr.update(value=1)
        else:
            return gr.update(value=orig_n_beams)

    # If the model name is in the API model list, make the control non-interactive.
    def toggle_interactive_for_api_model(value):
        if value in API_MODEL_MAP:
            return gr.update(interactive=False)
        else:
            return gr.update(interactive=True)

    # If the model name is in the API model list, set gamma and delta from the API map.
    def toggle_gamma_for_api_model(value, orig_gamma):
        if value in API_MODEL_MAP:
            return gr.update(value=API_MODEL_MAP[value]["gamma"])
        else:
            return gr.update(value=orig_gamma)

    def toggle_delta_for_api_model(value, orig_delta):
        if value in API_MODEL_MAP:
            return gr.update(value=API_MODEL_MAP[value]["delta"])
        else:
            return gr.update(value=orig_delta)

    def update_n_beams(session_state, value):
        session_state.n_beams = value
        return session_state

    def update_max_new_tokens(session_state, value):
        session_state.max_new_tokens = int(value)
        return session_state

    def update_ignore_repeated_bigrams(session_state, value):
        session_state.ignore_repeated_bigrams = value
        return session_state

    def update_normalizers(session_state, value):
        session_state.normalizers = value
        return session_state

    def update_seed_separately(session_state, value):
        session_state.seed_separately = value
        return session_state

    def update_select_green_tokens(session_state, value):
        session_state.select_green_tokens = value
        return session_state

    def update_tokenizer(model_name_or_path):
        # if model_name_or_path == ALPACA_MODEL_NAME:
        #     return ALPACA_MODEL_TOKENIZER.from_pretrained(ALPACA_TOKENIZER_PATH)
        # else:
        return AutoTokenizer.from_pretrained(model_name_or_path)

    def check_model(value):
        return value if (value != "" and value is not None) else args.model_name_or_path

    # Enforce the constraint that the model name cannot be null or empty,
    # then attach the model-specific callbacks in particular.
    model_selector.change(check_model, inputs=[model_selector], outputs=[model_selector]).then(
        toggle_vis_for_api_model, inputs=[model_selector], outputs=[n_beams]
    ).then(
        toggle_beams_for_api_model, inputs=[model_selector, n_beams], outputs=[n_beams]
    ).then(
        toggle_interactive_for_api_model, inputs=[model_selector], outputs=[gamma]
    ).then(
        toggle_interactive_for_api_model, inputs=[model_selector], outputs=[delta]
    ).then(
        toggle_gamma_for_api_model, inputs=[model_selector, gamma], outputs=[gamma]
    ).then(
        toggle_delta_for_api_model, inputs=[model_selector, delta], outputs=[delta]
    ).then(
        update_tokenizer, inputs=[model_selector], outputs=[session_tokenizer]
    ).then(
        update_model, inputs=[session_args, model_selector], outputs=[session_args]
    ).then(
        lambda value: str(value), inputs=[session_args], outputs=[current_parameters]
    )

    # Register callbacks that toggle visibility of certain parameters based on the values of others.
    decoding.change(toggle_sampling_vis, inputs=[decoding], outputs=[sampling_temp])
    decoding.change(toggle_sampling_vis, inputs=[decoding], outputs=[generation_seed])
    decoding.change(toggle_sampling_vis_inv, inputs=[decoding], outputs=[n_beams])
    decoding.change(toggle_vis_for_api_model, inputs=[model_selector], outputs=[n_beams])

    # Register all the state update callbacks.
    decoding.change(update_decoding, inputs=[session_args, decoding], outputs=[session_args])
    sampling_temp.change(update_sampling_temp, inputs=[session_args, sampling_temp], outputs=[session_args])
    generation_seed.change(update_generation_seed, inputs=[session_args, generation_seed], outputs=[session_args])
    watermark_salt_choice.change(update_watermark_salt, inputs=[watermark_salt_choice])

    # Keep the two editable trace tables in sync.
    trace_source.change(update_trace_source, inputs=[trace_source],
                        outputs=[trace_source2, watermark_salt_choice])
    trace_source2.change(update_trace_source, inputs=[trace_source2],
                         outputs=[trace_source, watermark_salt_choice])
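    # Note: editing either table runs update_trace_source, which validates the IDs, replaces
    # the module-level default_trace_table, mirrors the edit into the other table component,
    # and rebuilds the watermark_salt_choice dropdown from the edited rows (each row reversed
    # so the watermark text is shown as the label and the numeric ID is the value).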
    n_beams.change(update_n_beams, inputs=[session_args, n_beams], outputs=[session_args])
    max_new_tokens.change(update_max_new_tokens, inputs=[session_args, max_new_tokens], outputs=[session_args])
    gamma.change(update_gamma, inputs=[session_args, gamma], outputs=[session_args])
    delta.change(update_delta, inputs=[session_args, delta], outputs=[session_args])
    detection_z_threshold.change(update_detection_z_threshold, inputs=[session_args, detection_z_threshold],
                                 outputs=[session_args])
    ignore_repeated_bigrams.change(update_ignore_repeated_bigrams, inputs=[session_args, ignore_repeated_bigrams],
                                   outputs=[session_args])
    normalizers.change(update_normalizers, inputs=[session_args, normalizers], outputs=[session_args])
    seed_separately.change(update_seed_separately, inputs=[session_args, seed_separately], outputs=[session_args])
    select_green_tokens.change(update_select_green_tokens, inputs=[session_args, select_green_tokens],
                               outputs=[session_args])

    # Extra callbacks to refresh the displayed parameter window when either button is clicked.
    generate_btn.click(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
    detect_btn.click(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])

    # When these parameters change, refresh the display and re-run detection, since
    # detection-only parameters do not change the model output.
    delta.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
    gamma.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
    gamma.change(fn=detect_partial, inputs=[output_without_watermark, session_args, session_tokenizer],
                 outputs=[without_watermark_detection_result, session_args, session_tokenizer,
                          html_without_watermark])
    gamma.change(fn=detect_partial, inputs=[output_with_watermark, session_args, session_tokenizer],
                 outputs=[with_watermark_detection_result, session_args, session_tokenizer,
                          html_with_watermark])
    gamma.change(fn=detect_partial, inputs=[detection_input, session_args, session_tokenizer],
                 outputs=[detection_result, session_args, session_tokenizer, html_detection_input])
    detection_z_threshold.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
    detection_z_threshold.change(fn=detect_partial, inputs=[output_without_watermark, session_args, session_tokenizer],
                                 outputs=[without_watermark_detection_result, session_args, session_tokenizer,
                                          html_without_watermark])
    detection_z_threshold.change(fn=detect_partial, inputs=[output_with_watermark, session_args, session_tokenizer],
                                 outputs=[with_watermark_detection_result, session_args, session_tokenizer,
                                          html_with_watermark])
    detection_z_threshold.change(fn=detect_partial, inputs=[detection_input, session_args, session_tokenizer],
                                 outputs=[detection_result, session_args, session_tokenizer, html_detection_input])
    ignore_repeated_bigrams.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
    ignore_repeated_bigrams.change(fn=detect_partial, inputs=[output_without_watermark, session_args, session_tokenizer],
                                   outputs=[without_watermark_detection_result, session_args, session_tokenizer,
                                            html_without_watermark])
    ignore_repeated_bigrams.change(fn=detect_partial, inputs=[output_with_watermark, session_args, session_tokenizer],
                                   outputs=[with_watermark_detection_result, session_args, session_tokenizer,
                                            html_with_watermark])
    ignore_repeated_bigrams.change(fn=detect_partial, inputs=[detection_input, session_args, session_tokenizer],
                                   outputs=[detection_result, session_args, session_tokenizer, html_detection_input])
    normalizers.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
    normalizers.change(fn=detect_partial, inputs=[output_without_watermark, session_args, session_tokenizer],
                       outputs=[without_watermark_detection_result, session_args, session_tokenizer,
                                html_without_watermark])
    normalizers.change(fn=detect_partial, inputs=[output_with_watermark, session_args, session_tokenizer],
                       outputs=[with_watermark_detection_result, session_args, session_tokenizer,
                                html_with_watermark])
    normalizers.change(fn=detect_partial, inputs=[detection_input, session_args, session_tokenizer],
                       outputs=[detection_result, session_args, session_tokenizer, html_detection_input])
    select_green_tokens.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
    select_green_tokens.change(fn=detect_partial, inputs=[output_without_watermark, session_args, session_tokenizer],
                               outputs=[without_watermark_detection_result, session_args, session_tokenizer,
                                        html_without_watermark])
    select_green_tokens.change(fn=detect_partial, inputs=[output_with_watermark, session_args, session_tokenizer],
                               outputs=[with_watermark_detection_result, session_args, session_tokenizer,
                                        html_with_watermark])
    select_green_tokens.change(fn=detect_partial, inputs=[detection_input, session_args, session_tokenizer],
                               outputs=[detection_result, session_args, session_tokenizer, html_detection_input])

    # demo.queue(concurrency_count=3)  # delete
    if args.demo_public:
        demo.launch(share=True)  # Expose the app to the internet via a randomly generated share link.
    else:
        demo.launch(server_name='0.0.0.0', share=False)


def main(args):
    """Run a command-line version of the generation and detection operations
    and optionally launch and serve the gradio demo."""
    # Initial argument processing and logging
    args.normalizers = (args.normalizers.split(",") if args.normalizers else [])
    print(args)

    if not args.skip_model_load:
        model, tokenizer, device = load_model(args)
    else:
        model, tokenizer, device = None, None, None
        tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
        if args.use_gpu:
            device = "cuda:0" if torch.cuda.is_available() else "cpu"
        else:
            device = "cpu"

    # Default example prompt
    input_text = (
        "为什么A股指数跌的不多,但是我亏损比之前都多?"
    )

    args.default_prompt = input_text

    # Generate and detect, report to stdout
    if not args.skip_model_load:
        pass

    # Launch the app to generate and detect interactively (implements the hf space demo)
    if args.run_gradio:
        run_gradio(args, model=model, tokenizer=tokenizer, device=device)

    return


if __name__ == "__main__":
    args = parse_args()
    print(args)

    main(args)
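# Example invocations, as a rough guide only. The script filename "app.py" is assumed; every
# flag shown is defined in parse_args above.
#
#   python app.py                                      # local Qwen1.5-0.5B-Chat, multinomial sampling at temp 0.7
#   python app.py --skip_model_load True               # debug the Gradio interface without loading weights
#   python app.py --use_sampling False --n_beams 4     # greedy / beam-search decoding instead of sampling
#   python app.py --demo_public True                   # share the demo via a public gradio link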