Spaces:
Runtime error
Runtime error
File size: 5,444 Bytes
dd5beeb c4eea25 dd5beeb ad681ae 4491e36 dd5beeb 1dc9286 dd5beeb a2f05a9 dd5beeb 802f7de ad681ae a2f05a9 ad681ae 995dea4 a09d0a4 802f7de dd5beeb 87801f9 dd5beeb a2f05a9 1af9b4f dd5beeb 802f7de 6758170 dd5beeb a2f05a9 1af9b4f aab4f47 19b6ec4 1af9b4f dd5beeb a09d0a4 27511bd 013a7fe 27511bd 013a7fe dd5beeb 1af9b4f dd5beeb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 |
import torch
import argparse
import os
import numpy as np
from watermark import Watermarker
import time
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
# --- Command-line configuration ---------------------------------------------
# `--key` seeds the pseudo-random generator; embed() and detect() both pass it
# to the Watermarker, so the same key must be used for both operations.
parser = argparse.ArgumentParser(description='Generative Text Watermarking demo')
parser.add_argument('--model', '-m', type=str, default="microsoft/Phi-3-mini-4k-instruct",
                    help='Language model')
parser.add_argument('--key', '-k', type=int, default=42,
                    help='The seed of the pseudo random number generator')
args = parser.parse_args()

# Identities that can be embedded; embed() uses the index in this list as the
# message and detect() maps the decoded message back through this list.
USERS = ['Alice', 'Bob', 'Charlie', 'Dan']
EMBED_METHODS = ['aaronson', 'kirchenbauer', 'sampling', 'greedy']
DETECT_METHODS = ['aaronson', 'aaronson_simplified', 'aaronson_neyman_pearson', 'kirchenbauer']
# NOTE(review): presumably 2 bits encode 2**2 == 4 messages, one per entry of
# USERS -- confirm against Watermarker before changing either value.
PAYLOAD_BITS = 2

# Single module-level device handle (previously assigned twice). Model
# placement itself is delegated to `device_map='auto'` below.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

# `device_map='auto'` lets accelerate place the weights. Do not chain
# `.to(device)`: it is redundant when everything lands on one device and raises
# a RuntimeError when accelerate has offloaded some modules to CPU/disk.
# NOTE(review): float16 on a CPU-only host is slow and some ops may be
# unsupported on older torch versions -- consider bfloat16 there.
model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype=torch.float16,
                                             device_map='auto')
tokenizer = AutoTokenizer.from_pretrained(args.model)
# Reuse EOS as the padding token (presumably because the tokenizer may not
# define one -- confirm).
tokenizer.pad_token = tokenizer.eos_token

DEFAULT_SYSTEM_PROMPT = """\
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\
"""
# Token count of DEFAULT_SYSTEM_PROMPT; embed() adds it to the requested length.
# NOTE(review): get_prompt() does not prepend this system prompt, so this now
# only acts as extra length budget -- confirm the intended max_length semantics.
LEN_DEFAULT_PROMPT = len(tokenizer.encode(DEFAULT_SYSTEM_PROMPT))
def embed(user, max_length, window_size, method, prompt):
    """Generate text watermarked with the identity of ``user``.

    Args:
        user: A name from USERS; its index in USERS is the embedded message.
        max_length: Maximum text length chosen in the UI. LEN_DEFAULT_PROMPT is
            added before it is forwarded to ``Watermarker.embed``.
        window_size: Watermarking window size forwarded to the Watermarker.
        method: One of EMBED_METHODS.
        prompt: The raw user prompt; it is wrapped with ``get_prompt`` first.

    Returns:
        The generated text after the first ``"[/INST]"`` marker. If the marker
        is absent, the whole generated text is returned unchanged.

    Raises:
        ValueError: If ``user`` is not in USERS (from ``list.index``).
    """
    uid = USERS.index(user)
    watermarker = Watermarker(tokenizer=tokenizer, model=model, window_size=window_size,
                              payload_bits=PAYLOAD_BITS)
    prompt = get_prompt(prompt)
    # NOTE(review): LEN_DEFAULT_PROMPT counts DEFAULT_SYSTEM_PROMPT tokens, which
    # get_prompt() no longer includes -- confirm how Watermarker.embed
    # interprets max_length (total length vs. new tokens).
    watermarked_texts = watermarker.embed(key=args.key, messages=[uid],
                                          max_length=max_length + LEN_DEFAULT_PROMPT,
                                          method=method, prompt=prompt)
    # Presumably the decoded text is prompt + completion; keep what follows the
    # first closing marker. Partitioning once (instead of `split(...)[1]`) avoids
    # an IndexError when the marker is missing. It also avoids truncating the
    # completion if the model itself emits another "[/INST]".
    text = watermarked_texts[0]
    _, marker, completion = text.partition("[/INST]")
    return completion if marker else text
def detect(attacked_text, window_size, method, prompt):
    """Detect the watermark in ``attacked_text`` and describe the result.

    Detection is non-blind: the formatted prompt is prepended to the (possibly
    edited) text and also passed to the detector. The first decoded message is
    mapped back to a name in USERS. The result is a one-line summary with the
    associated p-value.
    """
    detector = Watermarker(tokenizer=tokenizer, model=model, window_size=window_size,
                           payload_bits=PAYLOAD_BITS)
    full_prompt = get_prompt(prompt)
    candidate = full_prompt + attacked_text
    # Debug output left in place to aid inspection in the Space logs.
    print([candidate])
    pvalues, messages = detector.detect([candidate], key=args.key, method=method,
                                        prompts=[full_prompt])
    print("messages: ", messages)
    print("p-values: ", pvalues)
    detected_user, pvalue = USERS[messages[0]], pvalues[0]
    return f'The user detected is {detected_user:s} with pvalue of {pvalue:.3e}'
def get_prompt(message: str) -> str:
    """Wrap a raw user message in ``[INST] ... [/INST]`` instruction markers.

    No system prompt is prepended; DEFAULT_SYSTEM_PROMPT is not used here. The
    message is inserted as-is, without stripping. ``embed`` looks for the
    closing ``[/INST]`` marker in generated text, and ``detect`` prepends this
    same string again. The format therefore has to stay identical between the
    two.

    >>> get_prompt("Hello")
    '[INST] Hello [/INST]'
    """
    return f"[INST] {message} [/INST]"
# Gradio UI: generate watermarked text, optionally edit ("attack") it, then run
# detection with the same prompt and window size.
with gr.Blocks() as demo:
    gr.Markdown("""# LLM generation watermarking
This Space lets you try different watermarking schemes for LLM generation.\n
It leverages the improvements introduced in the paper, reducing the gap between the empirical and theoretical false-positive rates and adding the ability to embed a message (of n bits). Here we use this capacity to embed the identity of the user generating the text, but it could also be used to identify different versions of a model or simply to convey a secret message.\n
Simply select a user name, then set the maximum text length, the watermarking window size and the prompt. The Aaronson and Kirchenbauer watermarking schemes are available, alongside traditional sampling and greedy search without watermarking.\n
Once the text is generated, you can optionally apply some attacks to it (e.g., remove words), select the associated detection method and run the detection. Please note that the detection is non-blind: it requires the original prompt to be known, so leave it untouched.\n
For Aaronson, the original detection function is available, along with the Neyman-Pearson and simplified-score versions.""")
    with gr.Row():
        user = gr.Dropdown(choices=USERS, value=USERS[0], label="User")
        text_length = gr.Number(minimum=1, maximum=512, value=256, step=1, precision=0,
                                label="Max text length")
        # Shared by embedding and detection (both click handlers read it).
        window_size = gr.Number(minimum=0, maximum=10, value=0, step=1, precision=0,
                                label="Watermarking window size")
        embed_method = gr.Dropdown(choices=EMBED_METHODS, value=EMBED_METHODS[0],
                                   label="Sampling method")
    # Also an input to detection, so it must match the prompt used to embed.
    prompt = gr.Textbox(label="prompt")
    with gr.Row():
        btn1 = gr.Button("Embed")
    with gr.Row():
        # Output of embedding and input of detection: edits made here are the
        # "attacks" whose robustness is then tested.
        watermarked_text = gr.Textbox(label="Generated text")
        detect_method = gr.Dropdown(choices=DETECT_METHODS, value=DETECT_METHODS[0],
                                    label="Detection method")
    with gr.Row():
        btn2 = gr.Button("Detect")
    with gr.Row():
        detection_label = gr.Label(label="Detection result")
    # Event listeners must be registered inside the Blocks context.
    btn1.click(fn=embed, inputs=[user, text_length, window_size, embed_method, prompt],
               outputs=[watermarked_text], api_name="watermark")
    btn2.click(fn=detect, inputs=[watermarked_text, window_size, detect_method, prompt],
               outputs=[detection_label], api_name="detect")
demo.launch()
|