Antoine Chaffin
Moving the auth token to where the model is loaded
4491e36
raw
history blame
5.27 kB
import argparse
import os
import time

import gradio as gr
import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from watermark import Watermarker
# Hugging Face access token read from the environment (None when HF_TOKEN is
# unset); it is handed to both `from_pretrained` calls below.
hf_token = os.getenv('HF_TOKEN')

# Command-line options of the demo.
parser = argparse.ArgumentParser(description='Generative Text Watermarking demo')
parser.add_argument('--model', '-m', type=str, default="meta-llama/Llama-2-7b-chat-hf", help='Language model')
parser.add_argument('--key', '-k', type=int, default=42,
                    help='The seed of the pseudo random number generator')
args = parser.parse_args()

# User names offered in the UI; `embed` uses a name's position in this list as
# the message to hide, and `detect` maps the decoded message back to a name.
USERS = ['Alice', 'Bob', 'Charlie', 'Dan']
# Method names offered in the "Sampling method" / "Detection method" dropdowns.
EMBED_METHODS = ['aaronson', 'kirchenbauer', 'sampling', 'greedy']
DETECT_METHODS = ['aaronson', 'aaronson_simplified', 'aaronson_neyman_pearson', 'kirchenbauer']
# Passed to Watermarker as `payload_bits`.
PAYLOAD_BITS = 2

# Defined once: the file used to assign `device` twice ('cuda', then 'cuda:0')
# and only the second assignment was ever effective.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

DEFAULT_SYSTEM_PROMPT = """\
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\
"""

# `device_map='auto'` already decides where the weights live, so the model is
# not moved again with `.to(device)`: moving a dispatched model is redundant on
# a single device and is rejected/warned about when it is spread or offloaded.
model = AutoModelForCausalLM.from_pretrained(args.model, use_auth_token=hf_token,
                                             torch_dtype=torch.float16, device_map='auto')
tokenizer = AutoTokenizer.from_pretrained(args.model, use_auth_token=hf_token)
def embed(user, max_length, window_size, method, prompt):
    """Generate a text whose watermark carries the selected user's index.

    Args:
        user: Name of the user; looked up in ``USERS``.
        max_length: Forwarded to ``Watermarker.embed`` as ``max_length``.
        window_size: Given both to the ``Watermarker`` constructor and to
            ``Watermarker.embed``.
        method: Forwarded to ``Watermarker.embed`` as ``method``.
        prompt: Forwarded to ``Watermarker.embed`` as ``prompt``.

    Returns:
        The first element of the value returned by ``Watermarker.embed``.

    Raises:
        ValueError: If ``user`` is not an entry of ``USERS``.
    """
    # The hidden message is the user's position in USERS.
    payload = USERS.index(user)
    engine = Watermarker(
        tokenizer=tokenizer,
        model=model,
        window_size=window_size,
        payload_bits=PAYLOAD_BITS,
    )
    generated = engine.embed(
        key=args.key,
        messages=[payload],
        max_length=max_length,
        method=method,
        prompt=prompt,
        window_size=window_size,
    )
    print("watermarked_texts: ", generated)
    return generated[0]
def detect(attacked_text, window_size, method, prompt):
    """Run watermark detection on a text and report the decoded user.

    Args:
        attacked_text: Text to analyse.
        window_size: Given to the ``Watermarker`` constructor.
        method: Forwarded to ``Watermarker.detect`` as ``method``.
        prompt: Prompt forwarded to ``Watermarker.detect`` (as a one-element
            ``prompts`` list).

    Returns:
        A sentence naming the ``USERS`` entry selected by the first decoded
        message, together with the first p-value in scientific notation.
    """
    detector = Watermarker(
        tokenizer=tokenizer,
        model=model,
        window_size=window_size,
        payload_bits=PAYLOAD_BITS,
    )
    pvalues, messages = detector.detect(
        [attacked_text],
        key=args.key,
        method=method,
        prompts=[prompt],
    )
    print("messages: ", messages)
    print("p-values: ", pvalues)
    # Only one text was submitted, so only the first result is reported.
    template = 'The user detected is {:s} with pvalue of {:.3e}'
    return template.format(USERS[messages[0]], pvalues[0])
def get_prompt(message: str) -> str:
    """Wrap a user message in the ``[INST]``/``<<SYS>>`` chat layout.

    The result starts with the system section built from
    ``DEFAULT_SYSTEM_PROMPT`` and ends with the user message followed by the
    closing ``[/INST]`` tag.

    Args:
        message: The user's message, inserted as-is.

    Returns:
        The full prompt string.
    """
    system_section = f'<s>[INST] <<SYS>>\n{DEFAULT_SYSTEM_PROMPT}\n<</SYS>>\n\n'
    # The user message is deliberately inserted without stripping whitespace.
    user_section = f'{message} [/INST]'
    return system_section + user_section
def _embed_from_ui(user, max_length, window_size, method, message):
    """Click handler for "Embed": template the typed message, then call ``embed``."""
    return embed(user, max_length, window_size, method, get_prompt(message))


def _detect_from_ui(attacked_text, window_size, method, message):
    """Click handler for "Detect": template the typed message, then call ``detect``."""
    return detect(attacked_text, window_size, method, get_prompt(message))


# Layout and wiring of the demo page.
with gr.Blocks() as demo:
    gr.Markdown("""# LLM generation watermarking
This spaces let you to try different watermarking scheme for LLM generation.\n
It leverages the upgrades introduced in the paper, reducing the gap between empirical and theoretical false positive detection rate and give the ability to embed a message (of n bits). Here we use this capacity to embed the identity of the user generating the text, but it could also be used to identify different version of a model or just convey a secret message.\n
Simply select an user name, set the maximum text length, the watermarking window size and the prompt. Aaronson and Kirchenbauer watermarking scheme are proposed, along traditional sampling and greedy search without watermarking.\n
Once the text is generated, you can eventually apply some attacks to it (e.g, remove words), select the associated detection method and run the detection. Please note that the detection is non-blind, and require the original prompt to be known and so left untouched.\n
For Aaronson, the original detection function, along the Neyman-Pearson and Simplified Score version are available.""")
    # Generation settings.
    with gr.Row():
        user = gr.Dropdown(choices=USERS, value=USERS[0], label="User")
        text_length = gr.Number(minimum=1, maximum=512, value=256, step=1, precision=0, label="Max text length")
        window_size = gr.Number(minimum=0, maximum=10, value=0, step=1, precision=0, label="Watermarking window size")
        embed_method = gr.Dropdown(choices=EMBED_METHODS, value=EMBED_METHODS[0], label="Sampling method")
        prompt = gr.Textbox(label="prompt")
    with gr.Row():
        btn1 = gr.Button("Embed")
    # Generated text (editable, so it can be "attacked") and detection settings.
    with gr.Row():
        watermarked_text = gr.Textbox(label="Generated text")
        detect_method = gr.Dropdown(choices=DETECT_METHODS, value=DETECT_METHODS[0], label="Detection method")
    with gr.Row():
        btn2 = gr.Button("Detect")
    with gr.Row():
        detection_label = gr.Label(label="Detection result")

    # `inputs` must list components, not values: the `prompt` Textbox itself is
    # passed and `get_prompt` is applied inside the handlers at click time.
    # Calling `get_prompt(prompt)` here would format the component object into
    # a string while the page is being built, never seeing the user's text.
    btn1.click(fn=_embed_from_ui, inputs=[user, text_length, window_size, embed_method, prompt],
               outputs=[watermarked_text], api_name="watermark")
    btn2.click(fn=_detect_from_ui, inputs=[watermarked_text, window_size, detect_method, prompt],
               outputs=[detection_label], api_name="detect")

demo.launch()