|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
import argparse |
|
from pprint import pprint |
|
from functools import partial |
|
|
|
|
import gradio as gr |
|
|
|
import torch |
|
|
|
from transformers import (AutoTokenizer, |
|
AutoModelForSeq2SeqLM, |
|
AutoModelForCausalLM, |
|
LogitsProcessorList) |
|
|
|
from transformers import GPT2TokenizerFast |
|
OPT_TOKENIZER = GPT2TokenizerFast |
|
|
|
from watermark_processor import WatermarkLogitsProcessor, WatermarkDetector |
|
|
|
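# Models served via the HF text-generation-inference API, with the
# generation/watermark hyperparameters associated with each hosted endpoint.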
API_MODEL_MAP = { |
|
"bigscience/bloom" : {"max_length": 1000, "gamma": 0.5, "delta": 2.0}, |
|
"bigscience/bloomz" : {"max_length": 1000, "gamma": 0.5, "delta": 2.0}, |
|
"google/flan-ul2" : {"max_length": 1000, "gamma": 0.5, "delta": 2.0}, |
|
"google/flan-t5-xxl" : {"max_length": 1000, "gamma": 0.5, "delta": 2.0}, |
|
"EleutherAI/gpt-neox-20b" : {"max_length": 1000, "gamma": 0.5, "delta": 2.0}, |
|
} |
|
|
|
def str2bool(v): |
|
"""Util function for user friendly boolean flag args""" |
|
if isinstance(v, bool): |
|
return v |
|
if v.lower() in ('yes', 'true', 't', 'y', '1'): |
|
return True |
|
elif v.lower() in ('no', 'false', 'f', 'n', '0'): |
|
return False |
|
else: |
|
raise argparse.ArgumentTypeError('Boolean value expected.') |
|
|
|
def parse_args(): |
|
"""Command line argument specification""" |
|
|
|
parser = argparse.ArgumentParser(description="A minimum working example of applying the watermark to any LLM that supports the huggingface 🤗 `generate` API") |
|
|
|
parser.add_argument( |
|
"--run_gradio", |
|
type=str2bool, |
|
default=True, |
|
help="Whether to launch as a gradio demo. Set to False if not installed and want to just run the stdout version.", |
|
) |
|
parser.add_argument( |
|
"--demo_public", |
|
type=str2bool, |
|
default=False, |
|
help="Whether to expose the gradio demo to the internet.", |
|
) |
|
parser.add_argument( |
|
"--model_name_or_path", |
|
type=str, |
|
default="facebook/opt-6.7b", |
|
help="Main model, path to pretrained model or model identifier from huggingface.co/models.", |
|
) |
|
parser.add_argument( |
|
"--prompt_max_length", |
|
type=int, |
|
default=None, |
|
help="Truncation length for prompt, overrides model config's max length field.", |
|
) |
|
parser.add_argument( |
|
"--max_new_tokens", |
|
type=int, |
|
default=200, |
|
help="Maximmum number of new tokens to generate.", |
|
) |
|
parser.add_argument( |
|
"--generation_seed", |
|
type=int, |
|
default=123, |
|
help="Seed for setting the torch global rng prior to generation.", |
|
) |
|
parser.add_argument( |
|
"--use_sampling", |
|
type=str2bool, |
|
default=True, |
|
help="Whether to generate using multinomial sampling.", |
|
) |
|
parser.add_argument( |
|
"--sampling_temp", |
|
type=float, |
|
default=0.7, |
|
help="Sampling temperature to use when generating using multinomial sampling.", |
|
) |
|
parser.add_argument( |
|
"--n_beams", |
|
type=int, |
|
default=1, |
|
help="Number of beams to use for beam search. 1 is normal greedy decoding", |
|
) |
|
parser.add_argument( |
|
"--use_gpu", |
|
type=str2bool, |
|
default=True, |
|
help="Whether to run inference and watermark hashing/seeding/permutation on gpu.", |
|
) |
|
parser.add_argument( |
|
"--seeding_scheme", |
|
type=str, |
|
default="simple_1", |
|
help="Seeding scheme to use to generate the greenlists at each generation and verification step.", |
|
) |
|
parser.add_argument( |
|
"--gamma", |
|
type=float, |
|
default=0.25, |
|
help="The fraction of the vocabulary to partition into the greenlist at each generation and verification step.", |
|
) |
|
parser.add_argument( |
|
"--delta", |
|
type=float, |
|
default=2.0, |
|
help="The amount/bias to add to each of the greenlist token logits before each token sampling step.", |
|
) |
|
parser.add_argument( |
|
"--normalizers", |
|
type=str, |
|
default="", |
|
help="Single or comma separated list of the preprocessors/normalizer names to use when performing watermark detection.", |
|
) |
|
parser.add_argument( |
|
"--ignore_repeated_bigrams", |
|
type=str2bool, |
|
default=False, |
|
help="Whether to use the detection method that only counts each unqiue bigram once as either a green or red hit.", |
|
) |
|
parser.add_argument( |
|
"--detection_z_threshold", |
|
type=float, |
|
default=4.0, |
|
help="The test statistic threshold for the detection hypothesis test.", |
|
) |
|
parser.add_argument( |
|
"--select_green_tokens", |
|
type=str2bool, |
|
default=True, |
|
help="How to treat the permuation when selecting the greenlist tokens at each step. Legacy is (False) to pick the complement/reds first.", |
|
) |
|
parser.add_argument( |
|
"--skip_model_load", |
|
type=str2bool, |
|
default=False, |
|
help="Skip the model loading to debug the interface.", |
|
) |
|
parser.add_argument( |
|
"--seed_separately", |
|
type=str2bool, |
|
default=True, |
|
help="Whether to call the torch seed function before both the unwatermarked and watermarked generate calls.", |
|
) |
|
parser.add_argument( |
|
"--load_fp16", |
|
type=str2bool, |
|
default=False, |
|
help="Whether to run model in float16 precsion.", |
|
) |
|
args = parser.parse_args() |
|
return args |
|
|
|
def load_model(args): |
|
"""Load and return the model and tokenizer""" |
|
|
|
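    # Crude model-family detection from the checkpoint name: seq2seq (t5/T0) vs.
    # decoder-only (gpt/opt/bloom) determines which AutoModel class is used.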
args.is_seq2seq_model = any([(model_type in args.model_name_or_path) for model_type in ["t5","T0"]]) |
|
args.is_decoder_only_model = any([(model_type in args.model_name_or_path) for model_type in ["gpt","opt","bloom"]]) |
|
if args.is_seq2seq_model: |
|
model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name_or_path) |
|
elif args.is_decoder_only_model: |
|
if args.load_fp16: |
|
model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path,torch_dtype=torch.float16, device_map='auto') |
|
else: |
|
model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path) |
|
else: |
|
raise ValueError(f"Unknown model type: {args.model_name_or_path}") |
|
|
|
if args.use_gpu: |
|
device = "cuda" if torch.cuda.is_available() else "cpu" |
|
if args.load_fp16: |
|
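            # device_map='auto' has already dispatched the fp16 weights to devices,
            # so there is nothing to move here.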
pass |
|
else: |
|
model = model.to(device) |
|
else: |
|
device = "cpu" |
|
model.eval() |
|
|
|
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path) |
|
|
|
return model, tokenizer, device |
|
|
|
|
|
from text_generation import InferenceAPIClient |
|
from requests.exceptions import ReadTimeout |
|
def generate_with_api(prompt, args): |
|
hf_api_key = os.environ.get("HF_API_KEY") |
|
if hf_api_key is None: |
|
raise ValueError("HF_API_KEY environment variable not set, cannot use HF API to generate text.") |
|
|
|
client = InferenceAPIClient(args.model_name_or_path, token=hf_api_key, timeout=60) |
|
|
|
assert args.n_beams == 1, "HF API models do not support beam search." |
|
generation_params = { |
|
"max_new_tokens": args.max_new_tokens, |
|
"do_sample": args.use_sampling, |
|
} |
|
if args.use_sampling: |
|
generation_params["temperature"] = args.sampling_temp |
|
generation_params["seed"] = args.generation_seed |
|
|
|
timeout_msg = "[Model API timeout error. Try reducing the max_new_tokens parameter or the prompt length.]" |
|
try: |
|
generation_params["watermark"] = False |
|
without_watermark_iterator = client.generate_stream(prompt, **generation_params) |
|
except ReadTimeout as e: |
|
print(e) |
|
without_watermark_iterator = (char for char in timeout_msg) |
|
try: |
|
generation_params["watermark"] = True |
|
with_watermark_iterator = client.generate_stream(prompt, **generation_params) |
|
except ReadTimeout as e: |
|
print(e) |
|
with_watermark_iterator = (char for char in timeout_msg) |
|
|
|
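    # Interleave the two token streams so the UI shows both generations growing
    # together; zip stops as soon as the shorter (or timed-out) stream is exhausted.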
all_without_words, all_with_words = "", "" |
|
for without_word, with_word in zip(without_watermark_iterator, with_watermark_iterator): |
|
all_without_words += without_word.token.text |
|
all_with_words += with_word.token.text |
|
yield all_without_words, all_with_words |
|
|
|
|
|
def check_prompt(prompt, args, tokenizer, model=None, device=None): |
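    """Resolve the effective prompt truncation length for the selected model,
    then tokenize/truncate the prompt and flag whether truncation occurred."""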
|
|
|
|
|
if args.model_name_or_path in API_MODEL_MAP: |
|
args.prompt_max_length = API_MODEL_MAP[args.model_name_or_path]["max_length"] |
|
elif hasattr(model.config,"max_position_embedding"): |
|
args.prompt_max_length = model.config.max_position_embeddings-args.max_new_tokens |
|
else: |
|
args.prompt_max_length = 2048-args.max_new_tokens |
|
|
|
tokd_input = tokenizer(prompt, return_tensors="pt", add_special_tokens=True, truncation=True, max_length=args.prompt_max_length).to(device) |
|
    truncation_warning = tokd_input["input_ids"].shape[-1] == args.prompt_max_length
|
redecoded_input = tokenizer.batch_decode(tokd_input["input_ids"], skip_special_tokens=True)[0] |
|
|
|
return (redecoded_input, |
|
int(truncation_warning), |
|
args) |
|
|
|
|
|
|
|
def generate(prompt, args, tokenizer, model=None, device=None): |
|
"""Instatiate the WatermarkLogitsProcessor according to the watermark parameters |
|
and generate watermarked text by passing it to the generate method of the model |
|
as a logits processor. """ |
|
|
|
print(f"Generating with {args}") |
|
print(f"Prompt: {prompt}") |
|
|
|
if args.model_name_or_path in API_MODEL_MAP: |
|
api_outputs = generate_with_api(prompt, args) |
|
yield from api_outputs |
|
else: |
|
tokd_input = tokenizer(prompt, return_tensors="pt", add_special_tokens=True, truncation=True, max_length=args.prompt_max_length).to(device) |
|
|
|
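        # At each decoding step the processor seeds an RNG from the preceding token(s)
        # (per `seeding_scheme`), partitions the vocabulary into a gamma-fraction
        # "greenlist", and adds `delta` to the greenlist logits before sampling.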
watermark_processor = WatermarkLogitsProcessor(vocab=list(tokenizer.get_vocab().values()), |
|
gamma=args.gamma, |
|
delta=args.delta, |
|
seeding_scheme=args.seeding_scheme, |
|
select_green_tokens=args.select_green_tokens) |
|
|
|
gen_kwargs = dict(max_new_tokens=args.max_new_tokens) |
|
|
|
if args.use_sampling: |
|
gen_kwargs.update(dict( |
|
do_sample=True, |
|
top_k=0, |
|
temperature=args.sampling_temp |
|
)) |
|
else: |
|
gen_kwargs.update(dict( |
|
num_beams=args.n_beams |
|
)) |
|
|
|
generate_without_watermark = partial( |
|
model.generate, |
|
**gen_kwargs |
|
) |
|
generate_with_watermark = partial( |
|
model.generate, |
|
logits_processor=LogitsProcessorList([watermark_processor]), |
|
**gen_kwargs |
|
) |
|
|
|
torch.manual_seed(args.generation_seed) |
|
output_without_watermark = generate_without_watermark(**tokd_input) |
|
|
|
|
|
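        # Optionally reset the RNG so both generations start from the same seed;
        # outputs still diverge once the watermark bias changes a sampled token.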
if args.seed_separately: |
|
torch.manual_seed(args.generation_seed) |
|
output_with_watermark = generate_with_watermark(**tokd_input) |
|
|
|
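        # Decoder-only models return prompt + continuation, so slice off the prompt
        # tokens; only the newly generated text should be decoded and scored.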
if args.is_decoder_only_model: |
|
|
|
output_without_watermark = output_without_watermark[:,tokd_input["input_ids"].shape[-1]:] |
|
output_with_watermark = output_with_watermark[:,tokd_input["input_ids"].shape[-1]:] |
|
|
|
decoded_output_without_watermark = tokenizer.batch_decode(output_without_watermark, skip_special_tokens=True)[0] |
|
decoded_output_with_watermark = tokenizer.batch_decode(output_with_watermark, skip_special_tokens=True)[0] |
|
|
|
|
|
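        # Re-emit the finished generations word by word so the gradio UI can
        # render them as if they were streaming.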
all_without_words, all_with_words = "", "" |
|
for without_word, with_word in zip(decoded_output_without_watermark.split(), decoded_output_with_watermark.split()): |
|
all_without_words += without_word + " " |
|
all_with_words += with_word + " " |
|
yield all_without_words, all_with_words |
|
|
|
|
|
def format_names(s): |
|
"""Format names for the gradio demo interface""" |
|
s=s.replace("num_tokens_scored","Tokens Counted (T)") |
|
s=s.replace("num_green_tokens","# Tokens in Greenlist") |
|
s=s.replace("green_fraction","Fraction of T in Greenlist") |
|
s=s.replace("z_score","z-score") |
|
s=s.replace("p_value","p value") |
|
s=s.replace("prediction","Prediction") |
|
s=s.replace("confidence","Confidence") |
|
return s |
|
|
|
def list_format_scores(score_dict, detection_threshold): |
|
"""Format the detection metrics into a gradio dataframe input format""" |
|
lst_2d = [] |
|
for k,v in score_dict.items(): |
|
if k=='green_fraction': |
|
lst_2d.append([format_names(k), f"{v:.1%}"]) |
|
elif k=='confidence': |
|
lst_2d.append([format_names(k), f"{v:.3%}"]) |
|
elif isinstance(v, float): |
|
lst_2d.append([format_names(k), f"{v:.3g}"]) |
|
elif isinstance(v, bool): |
|
lst_2d.append([format_names(k), ("Watermarked" if v else "Human/Unwatermarked")]) |
|
else: |
|
lst_2d.append([format_names(k), f"{v}"]) |
|
if "confidence" in score_dict: |
|
lst_2d.insert(-2,["z-score Threshold", f"{detection_threshold}"]) |
|
else: |
|
lst_2d.insert(-1,["z-score Threshold", f"{detection_threshold}"]) |
|
return lst_2d |
|
|
|
def detect(input_text, args, tokenizer, device=None, return_green_token_mask=True): |
|
"""Instantiate the WatermarkDetection object and call detect on |
|
the input text returning the scores and outcome of the test""" |
|
|
|
print(f"Detecting with {args}") |
|
print(f"Detection Tokenizer: {type(tokenizer)}") |
|
|
|
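    # For T scored tokens with g observed greenlist hits, the detector computes the
    # one-proportion z-statistic z = (g - gamma*T) / sqrt(gamma*(1-gamma)*T) and
    # predicts "watermarked" when z exceeds z_threshold (see the paper for details).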
watermark_detector = WatermarkDetector(vocab=list(tokenizer.get_vocab().values()), |
|
gamma=args.gamma, |
|
seeding_scheme=args.seeding_scheme, |
|
device=device, |
|
tokenizer=tokenizer, |
|
z_threshold=args.detection_z_threshold, |
|
normalizers=args.normalizers, |
|
ignore_repeated_bigrams=args.ignore_repeated_bigrams, |
|
select_green_tokens=args.select_green_tokens) |
|
|
|
|
|
if args.normalizers != [] or args.ignore_repeated_bigrams: |
|
return_green_token_mask = False |
|
|
|
error = False |
|
green_token_mask = None |
|
if input_text == "": |
|
error = True |
|
else: |
|
try: |
|
score_dict = watermark_detector.detect(input_text, return_green_token_mask=return_green_token_mask) |
|
green_token_mask = score_dict.pop("green_token_mask", None) |
|
output = list_format_scores(score_dict, watermark_detector.z_threshold) |
|
except ValueError as e: |
|
print(e) |
|
error = True |
|
if error: |
|
output = [["Error","string too short to compute metrics"]] |
|
output += [["",""] for _ in range(6)] |
|
|
|
html_output = "[No highlight markup generated]" |
|
if green_token_mask is not None: |
|
|
|
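        # OPT checkpoints ship a slow tokenizer; swap in the fast GPT-2 tokenizer
        # so that token_to_chars() below can map tokens back to character spans.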
if "opt" in args.model_name_or_path: |
|
tokenizer = OPT_TOKENIZER.from_pretrained(args.model_name_or_path) |
|
|
|
tokens = tokenizer(input_text) |
|
if tokens["input_ids"][0] == tokenizer.bos_token_id: |
|
tokens["input_ids"] = tokens["input_ids"][1:] |
|
skip = watermark_detector.min_prefix_len |
|
charspans = [tokens.token_to_chars(i) for i in range(skip,len(tokens["input_ids"]))] |
|
charspans = [cs for cs in charspans if cs is not None] |
|
|
|
|
assert len(charspans) == len(green_token_mask) |
|
|
|
tags = [(f'<span class="green">{input_text[cs.start:cs.end]}</span>' if m else f'<span class="red">{input_text[cs.start:cs.end]}</span>') for cs, m in zip(charspans, green_token_mask)] |
|
html_output = f'<p>{" ".join(tags)}</p>' |
|
|
|
return output, args, tokenizer, html_output |
|
|
|
def run_gradio(args, model=None, device=None, tokenizer=None): |
|
"""Define and launch the gradio demo interface""" |
|
check_prompt_partial = partial(check_prompt, model=model, device=device) |
|
generate_partial = partial(generate, model=model, device=device) |
|
detect_partial = partial(detect, device=device) |
|
|
|
|
|
css = """ |
|
.green { color: black!important;line-height:1.9em; padding: 0.2em 0.2em; background: #ccffcc; border-radius:0.5rem;} |
|
.red { color: black!important;line-height:1.9em; padding: 0.2em 0.2em; background: #ffad99; border-radius:0.5rem;} |
|
""" |
|
|
|
with gr.Blocks(css=css) as demo: |
|
|
|
with gr.Row(): |
|
with gr.Column(scale=9): |
|
gr.Markdown( |
|
""" |
|
## 💧 A Watermark for Large Language Models 🔍 |
|
|
|
Demo made possible by the HuggingFace 🤗 [text-generation-inference](https://github.com/huggingface/text-generation-inference) serving framework. |
|
""" |
|
) |
|
with gr.Column(scale=1): |
|
|
|
all_models = sorted(list(set(list(API_MODEL_MAP.keys())+[args.model_name_or_path]))) |
|
model_selector = gr.Dropdown( |
|
all_models, |
|
value=args.model_name_or_path, |
|
label="Language Model", |
|
) |
|
|
|
|
|
default_prompt = args.__dict__.pop("default_prompt") |
|
session_args = gr.State(value=args) |
|
|
|
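        # Wrap the tokenizer in a lambda: gr.State deep-copies its initial value,
        # and the Rust-backed fast tokenizers cannot be copied that way.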
        session_tokenizer = gr.State(value=lambda: tokenizer)
|
|
|
with gr.Tab("Welcome"): |
|
with gr.Row(): |
|
with gr.Column(scale=2): |
|
gr.Markdown( |
|
""" |
|
Potential harms of large language models can be mitigated by *watermarking* a model's output. |
|
*Watermarks* are embedded signals in the generated text that are invisible to humans but algorithmically |
|
                      detectable, and they allow *anyone* to later check whether a given span of text
|
was likely to have been generated by a model that uses the watermark. |
|
|
|
This space showcases a watermarking approach that can be applied to _any_ generative language model. |
|
For demonstration purposes, the space demos a selection of multi-billion parameter models (see the following note for caveats). |
|
""" |
|
) |
|
with gr.Accordion("A note on the available models:",open=False): |
|
gr.Markdown( |
|
""" |
|
This demo uses open-source language models. Today, these models are less powerful than proprietary commercial tools like ChatGPT, Claude, Bard, or Bing/Sydney. |
|
|
|
                          Models like [BLOOM (175B)](https://huggingface.co/bigscience/bloom) are designed to "complete" your prompt, and are not fine-tuned to follow instructions. For best results, prompt that model with a few sentences that form the beginning of a paragraph, and then allow it to "continue" your paragraph. Some examples include the opening paragraph of a Wikipedia article, or the first few sentences of a story. Longer prompts that end mid-sentence will result in more fluent generations.
|
|
|
                          Some of the models available in this demo are fine-tuned to follow instructions but have different strengths and will showcase different types of watermark behavior. [BLOOMZ](https://huggingface.co/bigscience/bloomz) is an instruction tuned variant of BLOOM capable of following instructions in dozens of languages zero-shot and can generate long and coherent paragraphs and stories given the right prompt. The FLAN models [FLAN-t5-xxl (11B)](https://huggingface.co/google/flan-t5-xxl) and [FLAN-UL2 (20B)](https://huggingface.co/google/flan-ul2) are fine-tuned on a variety of in-context few-shot learning NLP tasks, such as reasoning and question answering.
|
|
|
                          Generally, short, low-entropy scenarios, where the model has very few correct/suitable responses to the prompt, will not exhibit as strong a watermark presence, while longer watermarked outputs will produce higher detection statistics.
|
""" |
|
) |
|
gr.Markdown( |
|
""" |
|
**[Generate & Detect]**: The first tab shows that the watermark can be embedded with |
|
negligible impact on text quality. You can try any prompt and compare the quality of |
|
normal text (*Output Without Watermark*) to the watermarked text (*Output With Watermark*) below it. |
|
You can also "see" the watermark by looking at the **Highlighted** tab where the tokens are |
|
colored green or red depending on which list they are in. |
|
Metrics on the right show that the watermark can be reliably detected given a reasonably small number of tokens (25-50). |
|
Detection is very efficient and does not use the language model or its parameters. |
|
|
|
**[Detector Only]**: You can also copy-paste the watermarked text (or any other text) |
|
into the second tab. This can be used to see how many sentences you could remove and still detect the watermark. |
|
                      You can also verify here that the detection has, by design, a low false-positive rate:
|
This means that human-generated text that you copy into this detector will not be marked as machine-generated. |
|
""" |
|
) |
|
|
|
with gr.Column(scale=1): |
|
gr.Markdown( |
|
""" |
|
![](https://drive.google.com/uc?export=view&id=1yVLPcjm-xvaCjQyc3FGLsWIU84v1QRoC) |
|
""" |
|
) |
|
|
|
with gr.Tab("Generate & Detect"): |
|
|
|
with gr.Row(): |
|
prompt = gr.Textbox(label=f"Prompt", interactive=True,lines=10,max_lines=10, value=default_prompt) |
|
with gr.Row(): |
|
generate_btn = gr.Button("Generate") |
|
with gr.Row(): |
|
with gr.Column(scale=2): |
|
with gr.Tab("Output Without Watermark (Raw Text)"): |
|
output_without_watermark = gr.Textbox(interactive=False,lines=14,max_lines=14) |
|
with gr.Tab("Highlighted"): |
|
html_without_watermark = gr.HTML(elem_id="html-without-watermark") |
|
with gr.Column(scale=1): |
|
without_watermark_detection_result = gr.Dataframe(headers=["Metric", "Value"], interactive=False,row_count=7,col_count=2) |
|
with gr.Row(): |
|
with gr.Column(scale=2): |
|
with gr.Tab("Output With Watermark (Raw Text)"): |
|
output_with_watermark = gr.Textbox(interactive=False,lines=14,max_lines=14) |
|
with gr.Tab("Highlighted"): |
|
html_with_watermark = gr.HTML(elem_id="html-with-watermark") |
|
with gr.Column(scale=1): |
|
with_watermark_detection_result = gr.Dataframe(headers=["Metric", "Value"],interactive=False,row_count=7,col_count=2) |
|
|
|
redecoded_input = gr.Textbox(visible=False) |
|
truncation_warning = gr.Number(visible=False) |
|
def truncate_prompt(redecoded_input, truncation_warning, orig_prompt, args): |
|
if truncation_warning: |
|
return redecoded_input + f"\n\n[Prompt was truncated before generation due to length...]", args |
|
else: |
|
return orig_prompt, args |
|
|
|
with gr.Tab("Detector Only"): |
|
with gr.Row(): |
|
with gr.Column(scale=2): |
|
with gr.Tab("Text to Analyze"): |
|
detection_input = gr.Textbox(interactive=True,lines=14,max_lines=14) |
|
with gr.Tab("Highlighted"): |
|
html_detection_input = gr.HTML(elem_id="html-detection-input") |
|
with gr.Column(scale=1): |
|
detection_result = gr.Dataframe(headers=["Metric", "Value"], interactive=False,row_count=7,col_count=2) |
|
with gr.Row(): |
|
detect_btn = gr.Button("Detect") |
|
|
|
|
|
with gr.Accordion("Advanced Settings",open=False): |
|
with gr.Row(): |
|
with gr.Column(scale=1): |
|
gr.Markdown(f"#### Generation Parameters") |
|
with gr.Row(): |
|
decoding = gr.Radio(label="Decoding Method",choices=["multinomial", "greedy"], value=("multinomial" if args.use_sampling else "greedy")) |
|
with gr.Row(): |
|
sampling_temp = gr.Slider(label="Sampling Temperature", minimum=0.1, maximum=1.0, step=0.1, value=args.sampling_temp, visible=True) |
|
with gr.Row(): |
|
generation_seed = gr.Number(label="Generation Seed",value=args.generation_seed, interactive=True) |
|
with gr.Row(): |
|
n_beams = gr.Dropdown(label="Number of Beams",choices=list(range(1,11,1)), value=args.n_beams, visible=((not args.use_sampling) and (not args.model_name_or_path in API_MODEL_MAP))) |
|
with gr.Row(): |
|
max_new_tokens = gr.Slider(label="Max Generated Tokens", minimum=10, maximum=1000, step=10, value=args.max_new_tokens) |
|
|
|
with gr.Column(scale=1): |
|
gr.Markdown(f"#### Watermark Parameters") |
|
with gr.Row(): |
|
gamma = gr.Slider(label="gamma",minimum=0.1, maximum=0.9, step=0.05, value=args.gamma) |
|
with gr.Row(): |
|
delta = gr.Slider(label="delta",minimum=0.0, maximum=10.0, step=0.1, value=args.delta) |
|
gr.Markdown(f"#### Detector Parameters") |
|
with gr.Row(): |
|
detection_z_threshold = gr.Slider(label="z-score threshold",minimum=0.0, maximum=10.0, step=0.1, value=args.detection_z_threshold) |
|
with gr.Row(): |
|
ignore_repeated_bigrams = gr.Checkbox(label="Ignore Bigram Repeats") |
|
with gr.Row(): |
|
normalizers = gr.CheckboxGroup(label="Normalizations", choices=["unicode", "homoglyphs", "truecase"], value=args.normalizers) |
|
with gr.Row(): |
|
gr.Markdown(f"_Note: sliders don't always update perfectly. Clicking on the bar or using the number window to the right can help. Window below shows the current settings._") |
|
with gr.Row(): |
|
current_parameters = gr.Textbox(label="Current Parameters", value=args) |
|
with gr.Accordion("Legacy Settings",open=False): |
|
with gr.Row(): |
|
with gr.Column(scale=1): |
|
seed_separately = gr.Checkbox(label="Seed both generations separately", value=args.seed_separately) |
|
with gr.Column(scale=1): |
|
select_green_tokens = gr.Checkbox(label="Select 'greenlist' from partition", value=args.select_green_tokens) |
|
|
|
|
|
with gr.Accordion("What do the settings do?",open=False): |
|
gr.Markdown( |
|
""" |
|
#### Generation Parameters: |
|
|
|
              - **Decoding Method** : We can generate tokens from the model using either multinomial sampling or greedy decoding.
|
- **Sampling Temperature** : If using multinomial sampling we can set the temperature of the sampling distribution. |
|
0.0 is equivalent to greedy decoding, and 1.0 is the maximum amount of variability/entropy in the next token distribution. |
|
              0.7 strikes a nice balance between faithfulness to the model's top candidates and added variety. Does not apply to greedy decoding.
|
- **Generation Seed** : The integer to pass to the torch random number generator before running generation. Makes the multinomial sampling strategy |
|
              outputs reproducible. Does not apply to greedy decoding.
|
- **Number of Beams** : When using greedy decoding, we can also set the number of beams to > 1 to enable beam search. |
|
              Beam search is not implemented for multinomial sampling (it was excluded from the paper) but may be added in the future.
|
- **Max Generated Tokens** : The `max_new_tokens` parameter passed to the generation method to stop the output at a certain number of new tokens. |
|
Note that the model is free to generate fewer tokens depending on the prompt. |
|
Implicitly this sets the maximum number of prompt tokens possible as the model's maximum input length minus `max_new_tokens`, |
|
and inputs will be truncated accordingly. |
|
|
|
#### Watermark Parameters: |
|
|
|
- **gamma** : The fraction of the vocabulary to be partitioned into the greenlist at each generation step. |
|
Smaller gamma values create a stronger watermark by enabling the watermarked model to achieve |
|
a greater differentiation from human/unwatermarked text because it is preferentially sampling |
|
              from a smaller green set, making those tokens less likely to occur by chance.
|
- **delta** : The amount of positive bias to add to the logits of every token in the greenlist |
|
at each generation step before sampling/choosing the next token. Higher delta values |
|
mean that the greenlist tokens are more heavily preferred by the watermarked model |
|
and as the bias becomes very large the watermark transitions from "soft" to "hard". |
|
For a hard watermark, nearly all tokens are green, but this can have a detrimental effect on |
|
generation quality, especially when there is not a lot of flexibility in the distribution. |
|
|
|
#### Detector Parameters: |
|
|
|
              - **z-score threshold** : the z-score cutoff for the hypothesis test. Higher thresholds (such as 4.0) make
|
_false positives_ (predicting that human/unwatermarked text is watermarked) very unlikely |
|
as a genuine human text with a significant number of tokens will almost never achieve |
|
that high of a z-score. Lower thresholds will capture more _true positives_ as some watermarked |
|
              texts will contain fewer green tokens and achieve a lower z-score, but still pass the lower bar and

              be flagged as "watermarked". However, a lower threshold will increase the chance that human text
|
that contains a slightly higher than average number of green tokens is erroneously flagged. |
|
4.0-5.0 offers extremely low false positive rates while still accurately catching most watermarked text. |
|
- **Ignore Bigram Repeats** : This alternate detection algorithm only considers the unique bigrams in the text during detection, |
|
computing the greenlists based on the first in each pair and checking whether the second falls within the list. |
|
This means that `T` is now the unique number of bigrams in the text, which becomes less than the total |
|
number of tokens generated if the text contains a lot of repetition. See the paper for a more detailed discussion. |
|
              - **Normalizations** : we implement a few basic normalizations to defend against various adversarial perturbations of the

              text analyzed during detection. Currently we support converting all characters to unicode,
|
replacing homoglyphs with a canonical form, and standardizing the capitalization. |
|
See the paper for a detailed discussion of input normalization. |
|
""" |
|
) |
|
|
|
with gr.Accordion("What do the output metrics mean?",open=False): |
|
gr.Markdown( |
|
""" |
|
              - `z-score threshold` : The cutoff for the hypothesis test
|
- `Tokens Counted (T)` : The number of tokens in the output that were counted by the detection algorithm. |
|
              The first token is omitted in the simple, single-token seeding scheme since there is no way to generate
|
a greenlist for it as it has no prefix token(s). Under the "Ignore Bigram Repeats" detection algorithm, |
|
described in the bottom panel, this can be much less than the total number of tokens generated if there is a lot of repetition. |
|
- `# Tokens in Greenlist` : The number of tokens that were observed to fall in their respective greenlist |
|
- `Fraction of T in Greenlist` : The `# Tokens in Greenlist` / `T`. This is expected to be approximately `gamma` for human/unwatermarked text. |
|
- `z-score` : The test statistic for the detection hypothesis test. If larger than the `z-score threshold` |
|
we "reject the null hypothesis" that the text is human/unwatermarked, and conclude it is watermarked |
|
- `p value` : The likelihood of observing the computed `z-score` under the null hypothesis. This is the likelihood of |
|
observing the `Fraction of T in Greenlist` given that the text was generated without knowledge of the watermark procedure/greenlists. |
|
If this is extremely _small_ we are confident that this many green tokens was not chosen by random chance. |
|
- `prediction` : The outcome of the hypothesis test - whether the observed `z-score` was higher than the `z-score threshold` |
|
- `confidence` : If we reject the null hypothesis, and the `prediction` is "Watermarked", then we report 1-`p value` to represent |
|
the confidence of the detection based on the unlikeliness of this `z-score` observation. |
|
""" |
|
) |
|
|
|
|
|
generate_btn.click(fn=check_prompt_partial, inputs=[prompt,session_args,session_tokenizer], outputs=[redecoded_input, truncation_warning, session_args]).success( |
|
fn=generate_partial, inputs=[redecoded_input,session_args,session_tokenizer], outputs=[output_without_watermark, output_with_watermark]).success( |
|
fn=detect_partial, inputs=[output_without_watermark,session_args,session_tokenizer], outputs=[without_watermark_detection_result,session_args,session_tokenizer,html_without_watermark]).success( |
|
fn=detect_partial, inputs=[output_with_watermark,session_args,session_tokenizer], outputs=[with_watermark_detection_result,session_args,session_tokenizer,html_with_watermark]) |
|
|
|
redecoded_input.change(fn=truncate_prompt, inputs=[redecoded_input,truncation_warning,prompt,session_args], outputs=[prompt,session_args]) |
|
|
|
detect_btn.click(fn=detect_partial, inputs=[detection_input,session_args,session_tokenizer], outputs=[detection_result, session_args,session_tokenizer,html_detection_input], api_name="detection") |
|
|
|
|
|
|
|
def update_model(session_state, value): session_state.model_name_or_path = value; return session_state |
|
def update_sampling_temp(session_state, value): session_state.sampling_temp = float(value); return session_state |
|
def update_generation_seed(session_state, value): session_state.generation_seed = int(value); return session_state |
|
def update_gamma(session_state, value): session_state.gamma = float(value); return session_state |
|
def update_delta(session_state, value): session_state.delta = float(value); return session_state |
|
def update_detection_z_threshold(session_state, value): session_state.detection_z_threshold = float(value); return session_state |
|
def update_decoding(session_state, value): |
|
if value == "multinomial": |
|
session_state.use_sampling = True |
|
elif value == "greedy": |
|
session_state.use_sampling = False |
|
return session_state |
|
def toggle_sampling_vis(value): |
|
if value == "multinomial": |
|
return gr.update(visible=True) |
|
elif value == "greedy": |
|
return gr.update(visible=False) |
|
def toggle_sampling_vis_inv(value): |
|
if value == "multinomial": |
|
return gr.update(visible=False) |
|
elif value == "greedy": |
|
return gr.update(visible=True) |
|
|
|
def toggle_vis_for_api_model(value): |
|
if value in API_MODEL_MAP: |
|
return gr.update(visible=False) |
|
else: |
|
return gr.update(visible=True) |
|
def toggle_beams_for_api_model(value, orig_n_beams): |
|
if value in API_MODEL_MAP: |
|
return gr.update(value=1) |
|
else: |
|
return gr.update(value=orig_n_beams) |
|
|
|
def toggle_interactive_for_api_model(value): |
|
if value in API_MODEL_MAP: |
|
return gr.update(interactive=False) |
|
else: |
|
return gr.update(interactive=True) |
|
|
|
def toggle_gamma_for_api_model(value, orig_gamma): |
|
if value in API_MODEL_MAP: |
|
return gr.update(value=API_MODEL_MAP[value]["gamma"]) |
|
else: |
|
return gr.update(value=orig_gamma) |
|
def toggle_delta_for_api_model(value, orig_delta): |
|
if value in API_MODEL_MAP: |
|
return gr.update(value=API_MODEL_MAP[value]["delta"]) |
|
else: |
|
return gr.update(value=orig_delta) |
|
|
|
def update_n_beams(session_state, value): session_state.n_beams = value; return session_state |
|
def update_max_new_tokens(session_state, value): session_state.max_new_tokens = int(value); return session_state |
|
def update_ignore_repeated_bigrams(session_state, value): session_state.ignore_repeated_bigrams = value; return session_state |
|
def update_normalizers(session_state, value): session_state.normalizers = value; return session_state |
|
def update_seed_separately(session_state, value): session_state.seed_separately = value; return session_state |
|
def update_select_green_tokens(session_state, value): session_state.select_green_tokens = value; return session_state |
|
def update_tokenizer(model_name_or_path): |
|
|
|
|
|
|
|
return AutoTokenizer.from_pretrained(model_name_or_path) |
|
|
|
def check_model(value): return value if (value!="" and value is not None) else args.model_name_or_path |
|
|
|
|
|
model_selector.change(check_model, inputs=[model_selector], outputs=[model_selector]).then( |
|
toggle_vis_for_api_model,inputs=[model_selector], outputs=[n_beams] |
|
).then( |
|
toggle_beams_for_api_model,inputs=[model_selector,n_beams], outputs=[n_beams] |
|
).then( |
|
toggle_interactive_for_api_model,inputs=[model_selector], outputs=[gamma] |
|
).then( |
|
toggle_interactive_for_api_model,inputs=[model_selector], outputs=[delta] |
|
).then( |
|
toggle_gamma_for_api_model,inputs=[model_selector,gamma], outputs=[gamma] |
|
).then( |
|
toggle_delta_for_api_model,inputs=[model_selector,delta], outputs=[delta] |
|
).then( |
|
update_tokenizer,inputs=[model_selector], outputs=[session_tokenizer] |
|
).then( |
|
update_model,inputs=[session_args, model_selector], outputs=[session_args] |
|
).then( |
|
lambda value: str(value), inputs=[session_args], outputs=[current_parameters] |
|
) |
|
|
|
decoding.change(toggle_sampling_vis,inputs=[decoding], outputs=[sampling_temp]) |
|
decoding.change(toggle_sampling_vis,inputs=[decoding], outputs=[generation_seed]) |
|
decoding.change(toggle_sampling_vis_inv,inputs=[decoding], outputs=[n_beams]) |
|
decoding.change(toggle_vis_for_api_model,inputs=[model_selector], outputs=[n_beams]) |
|
|
|
decoding.change(update_decoding,inputs=[session_args, decoding], outputs=[session_args]) |
|
sampling_temp.change(update_sampling_temp,inputs=[session_args, sampling_temp], outputs=[session_args]) |
|
generation_seed.change(update_generation_seed,inputs=[session_args, generation_seed], outputs=[session_args]) |
|
n_beams.change(update_n_beams,inputs=[session_args, n_beams], outputs=[session_args]) |
|
max_new_tokens.change(update_max_new_tokens,inputs=[session_args, max_new_tokens], outputs=[session_args]) |
|
gamma.change(update_gamma,inputs=[session_args, gamma], outputs=[session_args]) |
|
delta.change(update_delta,inputs=[session_args, delta], outputs=[session_args]) |
|
detection_z_threshold.change(update_detection_z_threshold,inputs=[session_args, detection_z_threshold], outputs=[session_args]) |
|
ignore_repeated_bigrams.change(update_ignore_repeated_bigrams,inputs=[session_args, ignore_repeated_bigrams], outputs=[session_args]) |
|
normalizers.change(update_normalizers,inputs=[session_args, normalizers], outputs=[session_args]) |
|
seed_separately.change(update_seed_separately,inputs=[session_args, seed_separately], outputs=[session_args]) |
|
select_green_tokens.change(update_select_green_tokens,inputs=[session_args, select_green_tokens], outputs=[session_args]) |
|
|
|
generate_btn.click(lambda value: str(value), inputs=[session_args], outputs=[current_parameters]) |
|
detect_btn.click(lambda value: str(value), inputs=[session_args], outputs=[current_parameters]) |
|
|
|
delta.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters]) |
|
gamma.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters]) |
|
gamma.change(fn=detect_partial, inputs=[output_without_watermark,session_args,session_tokenizer], outputs=[without_watermark_detection_result,session_args,session_tokenizer,html_without_watermark]) |
|
gamma.change(fn=detect_partial, inputs=[output_with_watermark,session_args,session_tokenizer], outputs=[with_watermark_detection_result,session_args,session_tokenizer,html_with_watermark]) |
|
gamma.change(fn=detect_partial, inputs=[detection_input,session_args,session_tokenizer], outputs=[detection_result,session_args,session_tokenizer,html_detection_input]) |
|
detection_z_threshold.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters]) |
|
detection_z_threshold.change(fn=detect_partial, inputs=[output_without_watermark,session_args,session_tokenizer], outputs=[without_watermark_detection_result,session_args,session_tokenizer,html_without_watermark]) |
|
detection_z_threshold.change(fn=detect_partial, inputs=[output_with_watermark,session_args,session_tokenizer], outputs=[with_watermark_detection_result,session_args,session_tokenizer,html_with_watermark]) |
|
detection_z_threshold.change(fn=detect_partial, inputs=[detection_input,session_args,session_tokenizer], outputs=[detection_result,session_args,session_tokenizer,html_detection_input]) |
|
ignore_repeated_bigrams.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters]) |
|
ignore_repeated_bigrams.change(fn=detect_partial, inputs=[output_without_watermark,session_args,session_tokenizer], outputs=[without_watermark_detection_result,session_args,session_tokenizer,html_without_watermark]) |
|
ignore_repeated_bigrams.change(fn=detect_partial, inputs=[output_with_watermark,session_args,session_tokenizer], outputs=[with_watermark_detection_result,session_args,session_tokenizer,html_with_watermark]) |
|
ignore_repeated_bigrams.change(fn=detect_partial, inputs=[detection_input,session_args,session_tokenizer], outputs=[detection_result,session_args,session_tokenizer,html_detection_input]) |
|
normalizers.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters]) |
|
normalizers.change(fn=detect_partial, inputs=[output_without_watermark,session_args,session_tokenizer], outputs=[without_watermark_detection_result,session_args,session_tokenizer,html_without_watermark]) |
|
normalizers.change(fn=detect_partial, inputs=[output_with_watermark,session_args,session_tokenizer], outputs=[with_watermark_detection_result,session_args,session_tokenizer,html_with_watermark]) |
|
normalizers.change(fn=detect_partial, inputs=[detection_input,session_args,session_tokenizer], outputs=[detection_result,session_args,session_tokenizer,html_detection_input]) |
|
select_green_tokens.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters]) |
|
select_green_tokens.change(fn=detect_partial, inputs=[output_without_watermark,session_args,session_tokenizer], outputs=[without_watermark_detection_result,session_args,session_tokenizer,html_without_watermark]) |
|
select_green_tokens.change(fn=detect_partial, inputs=[output_with_watermark,session_args,session_tokenizer], outputs=[with_watermark_detection_result,session_args,session_tokenizer,html_with_watermark]) |
|
select_green_tokens.change(fn=detect_partial, inputs=[detection_input,session_args,session_tokenizer], outputs=[detection_result,session_args,session_tokenizer,html_detection_input]) |
|
|
|
demo.queue(concurrency_count=3) |
|
|
|
if args.demo_public: |
|
demo.launch(share=True) |
|
else: |
|
demo.launch() |
|
|
|
def main(args): |
|
"""Run a command line version of the generation and detection operations |
|
and optionally launch and serve the gradio demo""" |
|
|
|
args.normalizers = (args.normalizers.split(",") if args.normalizers else []) |
|
print(args) |
|
|
|
if not args.skip_model_load: |
|
model, tokenizer, device = load_model(args) |
|
else: |
|
model, tokenizer, device = None, None, None |
|
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path) |
|
if args.use_gpu: |
|
device = "cuda" if torch.cuda.is_available() else "cpu" |
|
else: |
|
device = "cpu" |
|
|
|
|
|
|
|
input_text = ( |
|
"The aardvark is sometimes colloquially called the 'African ant bear',[6] 'anteater' (not to be confused with the South American anteater), or the 'Cape anteater'[6] after the Cape of Good Hope. The name 'aardvark' is Afrikaans (Afrikaans pronunciation: [ˈɑːrtfark]), comes from earlier Afrikaans erdvark[6] and means 'earth pig' or 'ground pig' (aarde: 'earth/ground', vark: 'pig'), because of its burrowing habits.[7][8][9] The name Orycteropus means 'burrowing foot', and the name afer refers to Africa.[10] The name of the aardvark's order, Tubulidentata, comes from the tubule-style teeth.[11]\n\nThe aardvark is not closely related to the pig; rather, it is the sole extant representative of the obscure mammalian order Tubulidentata,[10] in which it is usually considered to form one variable species of the genus Orycteropus, the sole surviving genus in the family Orycteropodidae. The aardvark is" |
|
) |
|
|
|
args.default_prompt = input_text |
|
|
|
|
|
|
|
if not args.skip_model_load: |
|
|
|
term_width = 80 |
|
print("#"*term_width) |
|
print("Prompt:") |
|
print(input_text) |
|
|
|
|
|
generator_outputs = generate(input_text, |
|
args, |
|
model=model, |
|
device=device, |
|
tokenizer=tokenizer) |
|
|
|
|
|
for out in generator_outputs: |
|
decoded_output_without_watermark = out[0] |
|
decoded_output_with_watermark = out[1] |
|
|
|
without_watermark_detection_result = detect(decoded_output_without_watermark, |
|
args, |
|
device=device, |
|
tokenizer=tokenizer, |
|
return_green_token_mask=False) |
|
with_watermark_detection_result = detect(decoded_output_with_watermark, |
|
args, |
|
device=device, |
|
tokenizer=tokenizer, |
|
return_green_token_mask=False) |
|
|
|
print("#"*term_width) |
|
print("Output without watermark:") |
|
print(decoded_output_without_watermark) |
|
print("-"*term_width) |
|
print(f"Detection result @ {args.detection_z_threshold}:") |
|
pprint(without_watermark_detection_result) |
|
print("-"*term_width) |
|
|
|
print("#"*term_width) |
|
print("Output with watermark:") |
|
print(decoded_output_with_watermark) |
|
print("-"*term_width) |
|
print(f"Detection result @ {args.detection_z_threshold}:") |
|
pprint(with_watermark_detection_result) |
|
print("-"*term_width) |
|
|
|
|
|
|
|
if args.run_gradio: |
|
run_gradio(args, model=model, tokenizer=tokenizer, device=device) |
|
|
|
return |
|
|
|
if __name__ == "__main__": |
|
|
|
args = parse_args() |
|
print(args) |
|
|
|
main(args) |