# Repository page header (author / commit message / short hash): XuandongZhao / init / 6a20eb3
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import argparse
from argparse import Namespace
from pprint import pprint
from functools import partial
import numpy # for gradio hot reload
import gradio as gr
import pathlib
import torch
from transformers import (AutoTokenizer,
AutoModelForSeq2SeqLM,
AutoModelForCausalLM,
LogitsProcessorList,
LlamaTokenizer)
from gptwm import GPTWatermarkDetector, GPTWatermarkLogitsWarper
def str2bool(v):
    """Convert a command-line flag value into a ``bool``.

    Real booleans are returned untouched. Any other value is lower-cased and
    matched against a fixed set of truthy / falsy spellings.

    Raises:
        argparse.ArgumentTypeError: if the value is not one of the
            recognised spellings.
    """
    if isinstance(v, bool):
        return v

    normalized = v.lower()
    if normalized in ('yes', 'true', 't', 'y', '1'):
        return True
    if normalized in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
def parse_args():
    """Build the command-line parser and return the parsed options.

    Returns:
        argparse.Namespace: the options parsed from ``sys.argv``.
    """
    parser = argparse.ArgumentParser()

    # The only option with a custom converter and help text.
    parser.add_argument(
        "--run_gradio",
        type=str2bool,
        default=True,
        help="Whether to launch as a gradio demo. Set to False if not installed and want to just run the stdout version.",
    )

    # Remaining options are plain (flag, type, default) triples; registration
    # order is kept so that ``--help`` output lists them in the same sequence.
    typed_options = (
        ("--model_name", str, "facebook/opt-125m"),
        ("--fraction", float, 0.5),
        ("--strength", float, 2.0),
        ("--wm_key", int, 0),
        ("--max_new_tokens", int, 300),
        ("--beam_size", int, None),
        ("--top_k", int, None),
        ("--top_p", float, 0.9),
        ("--test_min_tokens", int, 200),
        ("--threshold", float, 6.0),
    )
    for flag, value_type, default_value in typed_options:
        parser.add_argument(flag, type=value_type, default=default_value)

    return parser.parse_args()
def load_model(args):
    """Load the tokenizer and causal language model named by ``args.model_name``.

    The access token is read from the ``HF_TOKEN`` environment variable and
    forwarded to both loaders. ``LlamaTokenizer`` is used when the model name
    contains ``'llama'``; ``AutoTokenizer`` otherwise.

    Returns:
        tuple: ``(model, tokenizer, device)`` where ``device`` is ``"cuda"``
        when ``torch.cuda.is_available()`` and ``"cpu"`` otherwise.
    """
    hf_token = os.getenv('HF_TOKEN')

    # NOTE(review): torch_dtype is handed to the tokenizer loader, not to the
    # model loader below -- confirm whether it was meant for the model.
    tokenizer_cls = LlamaTokenizer if 'llama' in args.model_name else AutoTokenizer
    tokenizer = tokenizer_cls.from_pretrained(
        args.model_name,
        use_auth_token=hf_token,
        torch_dtype=torch.float16,
    )

    model = AutoModelForCausalLM.from_pretrained(
        args.model_name,
        use_auth_token=hf_token,
        device_map='auto',
    )
    model.eval()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    return model, tokenizer, device
def generate(prompt, args, model=None, device=None, tokenizer=None):
    """Generate a continuation of ``prompt`` twice: without and with the watermark.

    Both calls to ``model.generate`` share the same keyword arguments; the
    second additionally receives a ``GPTWatermarkLogitsWarper`` wrapped in a
    ``LogitsProcessorList``. Beam search is used when ``args.beam_size`` is
    set; otherwise sampling with ``args.top_k`` / ``args.top_p``.

    Returns:
        tuple: ``(prompt, text_without_watermark, text_with_watermark, args)``.
        The decoded texts contain only the tokens after the prompt.
    """
    print(f"Generating with {args}")

    warper = GPTWatermarkLogitsWarper(
        fraction=args.fraction,
        strength=args.strength,
        vocab_size=model.config.vocab_size,
        watermark_key=args.wm_key,
    )
    watermark_processor = LogitsProcessorList([warper])

    encoded = tokenizer(prompt, truncation=True, return_tensors="pt").to(device)
    # Number of prompt tokens, used to slice the prompt off the generated sequences.
    prompt_len = len(encoded['input_ids'][0])

    if args.beam_size is not None:
        decoding_args = {'num_beams': args.beam_size}
    else:
        decoding_args = {
            'do_sample': True,
            'top_k': args.top_k,
            'top_p': args.top_p,
        }
    generate_args = {
        **encoded,
        'output_scores': True,
        'return_dict_in_generate': True,
        'max_new_tokens': args.max_new_tokens,
        **decoding_args,
    }

    with torch.inference_mode():
        # Un-watermarked pass first, then the watermarked pass.
        plain_result = model.generate(**generate_args)
        plain_text = tokenizer.batch_decode(
            plain_result['sequences'][:, prompt_len:],
            skip_special_tokens=True,
        )[0]

        marked_result = model.generate(
            logits_processor=watermark_processor,
            **generate_args,
        )
        marked_text = tokenizer.batch_decode(
            marked_result['sequences'][:, prompt_len:],
            skip_special_tokens=True,
        )[0]

    return (prompt, plain_text, marked_text, args)
def detect_demo(input_text, args, device=None, tokenizer=None):
    """Run watermark detection on ``input_text`` and build the demo outputs.

    The text is tokenized without special tokens. When it has fewer than
    ``args.test_min_tokens`` tokens, no detection is run. Otherwise the
    detector's z-score and green/total token counts are reported, and every
    token's character span is wrapped in a ``green`` or ``red`` ``<span>``.

    Args:
        input_text: Text to analyse.
        args: Options providing ``model_name``, ``fraction``, ``strength``,
            ``wm_key`` and ``test_min_tokens``.
        device: Unused; accepted so the signature mirrors ``generate``.
        tokenizer: Tokenizer used to encode ``input_text``.

    Returns:
        tuple: ``(rows, html_output, args)``. ``rows`` is a list of
        ``[metric, value]`` pairs (empty when the text is too short) and
        ``html_output`` is always an HTML string.
    """
    import html  # function-scope import; only needed for escaping below

    # NOTE(review): 50272 is hard-coded for OPT models -- presumably the
    # model's embedding size rather than tokenizer.vocab_size; confirm.
    vocab_size = 50272 if "opt" in args.model_name else tokenizer.vocab_size
    watermark_detector = GPTWatermarkDetector(fraction=args.fraction,
                                              strength=args.strength,
                                              vocab_size=vocab_size,
                                              watermark_key=args.wm_key)

    encoding = tokenizer(input_text, add_special_tokens=False)
    gen_tokens = encoding["input_ids"]

    if len(gen_tokens) < args.test_min_tokens:
        too_short_message = "Input text is too short to test."
        print(too_short_message)
        # Return a plain string so the HTML output has the same type in both branches.
        return [], too_short_message, args

    z_score, green_tokens_mask, green_tokens, total_tokens = watermark_detector.detect(gen_tokens)
    output = [
        ['z-score', f"{z_score:.3g}"],
        ['green_tokens', f"{int(green_tokens):d}"],
        ['total_tokens', f"{int(total_tokens):d}"],
    ]

    # NOTE(review): token_to_chars presumably requires a fast tokenizer -- confirm
    # this path works for the LlamaTokenizer branch of load_model.
    char_spans = [encoding.token_to_chars(i) for i in range(len(gen_tokens))]

    tags = []
    for span, is_green in zip(char_spans, green_tokens_mask):
        css_class = "green" if is_green else "red"
        # Escape the slice: the text is user- or model-provided and is rendered as raw HTML.
        token_text = html.escape(input_text[span.start:span.end])
        tags.append(f'<span class="{css_class}">{token_text}</span>')
    html_output = f'<p>{" ".join(tags)}</p>'

    return output, html_output, args
def run_gradio(args, model=None, device=None, tokenizer=None):
    """Define and launch the gradio demo interface.

    Builds a ``gr.Blocks`` app with three tabs ("Method", "Generate and
    Detect", "Detector Only"), wires the generation and detection callbacks,
    and calls ``demo.launch()``.

    Args:
        args: Parsed options. ``default_prompt`` is removed from it (if
            present) and used as the initial prompt text; the remaining
            object is stored in ``gr.State`` and passed to every callback.
        model: Forwarded to ``generate``.
        device: Forwarded to ``generate`` and ``detect_demo``.
        tokenizer: Forwarded to ``generate`` and ``detect_demo``.
    """
    css = """
    .green {
        color: #008000 !important;
        border: none;
        font-weight: bold;
    }
    .red {
        color: #ffad99 !important;
        border: none;
        font-weight: bold;
    }
    """

    generate_partial = partial(generate, model=model, device=device, tokenizer=tokenizer)
    detect_partial = partial(detect_demo, device=device, tokenizer=tokenizer)

    with gr.Blocks(css=css) as demo:
        # Top section, greeting and instructions
        with gr.Row():
            with gr.Row():
                with gr.Column(scale=9):
                    # NOTE(review): the heading emoji was mis-encoded in the file;
                    # restored as U+1F50D (magnifying glass) -- confirm.
                    gr.Markdown(
                        """
                        ## 🔍 Unigram-Watermark for AI-Generated Text
                        ## [Paper](https://arxiv.org/abs/2306.17439) [GitHub](https://github.com/XuandongZhao/Unigram-Watermark)
                        """
                    )
        with gr.Accordion("Abstract", open=True):
            gr.Markdown(
                """
                We instantiate our language model watermarking with the **Unigram-Watermark**——a variant of the K-gram watermark.
                We prove that our watermark method enjoys guaranteed generation quality, correctness in watermark detection, and is robust against text editing and paraphrasing.
                """
            )
        gr.Markdown(f"Language model: {args.model_name}")

        # Construct state for parameters, define updates and toggles.
        # The prompt is taken out of args before args is stored as session
        # state; a missing attribute falls back to an empty prompt.
        default_prompt = vars(args).pop("default_prompt", "")
        session_args = gr.State(value=args)

        with gr.Tab("Method"):
            # Raw strings: the LaTeX backslashes must not be read as Python escapes.
            with gr.Accordion("Watermark process", open=True):
                gr.Markdown(
                    r"""
                    1. Randomly partition the vocabulary into two distinct sets: the green list with $\gamma N$ tokens and the red list with the remaining tokens.
                    2. In $\hat{M}$, the logits of the language model for the green list tokens are increased by $\delta$ while the logits for tokens in the red list remain unchanged.
                    """
                )
            with gr.Accordion("Detect process", open=True):
                gr.Markdown(
                    r"""
                    1. Count the number of green tokens in the suspect text.
                    2. Normalize the test-statistic $z_{y}=(|y|_G-\gamma n) / \sqrt{n \gamma(1-\gamma)}$.
                    3. Make a calibrated decision on whether we think the suspect text is generated from $\hat{M}$ or not.
                    """
                )

        with gr.Tab("Generate and Detect"):
            with gr.Row():
                prompt = gr.Textbox(label="Prompt", interactive=True, lines=10, max_lines=10, value=default_prompt)
            with gr.Row():
                generate_btn = gr.Button("Generate")
            with gr.Row():
                with gr.Column(scale=1):
                    with gr.Tab("Output Without Watermark"):
                        output_without_watermark = gr.Textbox(label="Text", interactive=False, lines=10, max_lines=10)
                    with gr.Tab("Visualization"):
                        html_without_watermark = gr.HTML(elem_id="html-without-watermark")
                with gr.Column(scale=1):
                    without_watermark_detection_result = gr.Dataframe(headers=["Metric", "Value"], interactive=False, row_count=7, col_count=2)
            with gr.Row():
                with gr.Column(scale=1):
                    with gr.Tab("Output With Watermark"):
                        output_with_watermark = gr.Textbox(label="Text", interactive=False, lines=10, max_lines=10)
                    with gr.Tab("Visualization"):
                        html_with_watermark = gr.HTML(elem_id="html-with-watermark")
                with gr.Column(scale=1):
                    with_watermark_detection_result = gr.Dataframe(headers=["Metric", "Value"], interactive=False, row_count=7, col_count=2)

            # Hidden components used by the prompt-truncation callback below.
            # No callback in this function writes truncation_warning.
            redecoded_input = gr.Textbox(visible=False)
            truncation_warning = gr.Number(visible=False)

            def truncate_prompt(redecoded_input, truncation_warning, orig_prompt, args):
                """Return the re-decoded prompt plus a notice when truncation was flagged."""
                if truncation_warning:
                    return redecoded_input + "\n\n[Prompt was truncated before generation due to length...]", args
                return orig_prompt, args

        with gr.Tab("Detector Only"):
            with gr.Row():
                with gr.Column(scale=2):
                    # detect inputbox
                    with gr.Tab("Text to Analyze"):
                        detection_input = gr.Textbox(label="Input", interactive=True, lines=14, max_lines=14)
                    with gr.Tab("Visualization"):
                        html_detection = gr.HTML(elem_id="html-detection")
                with gr.Column(scale=1):
                    detection_result = gr.Dataframe(headers=["Metric", "Value"], interactive=False, row_count=7, col_count=2)
            with gr.Row():
                # detect
                detect_btn = gr.Button("Detect")

        generate_btn.click(fn=generate_partial, inputs=[prompt, session_args], outputs=[redecoded_input, output_without_watermark, output_with_watermark, session_args])
        # Show truncated version of prompt if truncation occurred
        redecoded_input.change(fn=truncate_prompt, inputs=[redecoded_input, truncation_warning, prompt, session_args], outputs=[prompt, session_args])
        # Call detection when the outputs (of the generate function) are updated
        output_without_watermark.change(fn=detect_partial, inputs=[output_without_watermark, session_args], outputs=[without_watermark_detection_result, html_without_watermark, session_args])
        output_with_watermark.change(fn=detect_partial, inputs=[output_with_watermark, session_args], outputs=[with_watermark_detection_result, html_with_watermark, session_args])
        # Register main detection tab click
        detect_btn.click(fn=detect_partial, inputs=[detection_input, session_args], outputs=[detection_result, html_detection, session_args])

    demo.launch()
def main(args):
    """Load the model, attach the default prompt to ``args`` and, when
    ``args.run_gradio`` is set, launch the gradio demo.

    NOTE(review): no stdout generate/detect path is implemented here; with
    ``run_gradio`` false the function returns right after loading the model.
    """
    model, tokenizer, device = load_model(args)

    # Sample passage used as the initial prompt of the demo.
    default_prompt = (
        "One tank tumbled down an embankment into the Tenaru River, drowning its crew."
        " At 23:00 on 14 September, the remnants of the Kuma battalion conducted another attack on the same portion of the Marine lines, but were repulsed. "
        "A final \"weak\" attack by the Kuma unit on the evening of 15 September was also defeated. Oka's unit of about 650 men attacked the Marines at several locations on the west side of the Lunga perimeter."
        " At about 04:00 on 14 September, two Japanese companies attacked positions held by the 3rd Battalion, 5th Marine Regiment (3/5) near the coast and were thrown back with heavy losses."
        " Another Japanese company captured a small ridge somewhat inland but was then pinned down by Marine artillery fire throughout the day and took heavy losses before withdrawing on the evening of 14 September."
        " The rest of Oka's unit failed to find the Marine lines and did not participate in the attack. "
        "At 13:05 on 14 September, Kawaguchi led the survivors of his shattered brigade away from the ridge and deeper into the jungle, where they rested and tended to their wounded all the next day. "
        "Kawaguchi's units were then ordered to withdraw west to the Matanikau River valley to join with Oka's unit, a march over difficult terrain."
        " Kawaguchi's troops began the march on the morning of 16 September."
        " Almost every soldier able to walk had to help carry the wounded. "
        "As the march progressed, the exhausted and hungry soldiers, who had eaten their last rations on the morning before their withdrawal, began to discard their heavy equipment and then their rifles. "
        "By the time most of them reached Oka's positions at Kokumbona five days later, only half still carried their weapons."
        " The Kuma battalion's survivors, attempting to follow Kawaguchi's Center Body forces, became lost, wandered for three weeks in the jungle, and almost starved to death before finally reaching Kawaguchi's camp."
    )
    args.default_prompt = default_prompt

    if not args.run_gradio:
        return

    # Launch the app to generate and detect interactively (implements the hf space demo)
    run_gradio(args, model=model, tokenizer=tokenizer, device=device)
# Script entry point: parse the command-line options, echo them, then run main().
if __name__ == "__main__":
    args = parse_args()
    print(args)
    main(args)