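"""Gradio demo for generative text watermarking.

Text is generated with a short payload (here: the identity of the user) embedded
via the Aaronson or Kirchenbauer watermarking scheme; detection recovers the
payload and a p-value, even after the text has been lightly attacked.
"""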
import argparse
import os

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from watermark import Watermarker

# hf_token = os.getenv('HF_TOKEN')

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

parser = argparse.ArgumentParser(description='Generative Text Watermarking demo')
parser.add_argument('--model', '-m', type=str, default="microsoft/Phi-3-mini-4k-instruct", help='Language model')
parser.add_argument('--key', '-k', type=int, default=42,
                    help='The seed of the pseudo random number generator')

args = parser.parse_args()

USERS = ['Alice', 'Bob', 'Charlie', 'Dan']
EMBED_METHODS = ['aaronson', 'kirchenbauer', 'sampling', 'greedy']
DETECT_METHODS = ['aaronson', 'aaronson_simplified', 'aaronson_neyman_pearson', 'kirchenbauer']
PAYLOAD_BITS = 2
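
# For intuition, a minimal sketch of the Aaronson-style sampling rule implemented
# by Watermarker. This function is illustrative only and is never called by the
# app; the name and the seeding scheme are hypothetical, the real logic lives in
# watermark.py. The secret key and the window of previous tokens seed a PRNG
# drawing r ~ U(0,1)^|V|; the next token maximizes r_v ** (1 / p_v), and detection
# re-derives r to score how improbably large the chosen values are.
def _aaronson_sampling_sketch(probs: torch.Tensor, window: tuple, key: int) -> int:
    seed = hash((key,) + window) % (2 ** 31)        # seed from key + context window
    gen = torch.Generator().manual_seed(seed)
    r = torch.rand(probs.shape[-1], generator=gen)  # secret uniforms, one per vocab entry
    return int(torch.argmax(r ** (1.0 / probs)))    # Gumbel-trick style, distribution-preserving pick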


# device_map='auto' already places the weights on the available device(s); a
# separate .to(device) call can fail on accelerate-dispatched models, so it is omitted.
model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype=torch.float16,
                                             device_map='auto')
tokenizer = AutoTokenizer.from_pretrained(args.model)
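# Causal-LM tokenizers often ship without a pad token; reuse EOS so padding works.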
tokenizer.pad_token = tokenizer.eos_token

DEFAULT_SYSTEM_PROMPT = """\
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe.  Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\
"""
LEN_DEFAULT_PROMPT = len(tokenizer.encode(DEFAULT_SYSTEM_PROMPT))


def embed(user, max_length, window_size, method, prompt):
    """Generate text watermarked with the selected user's identity as payload."""
    uid = USERS.index(user)
    watermarker = Watermarker(tokenizer=tokenizer, model=model, window_size=window_size,
                              payload_bits=PAYLOAD_BITS)
    prompt = get_prompt(prompt)
    watermarked_texts = watermarker.embed(key=args.key, messages=[uid],
                                          max_length=max_length + LEN_DEFAULT_PROMPT,
                                          method=method, prompt=prompt)
    # Keep only the completion, i.e. everything after the closing [/INST] tag.
    return watermarked_texts[0].split("[/INST]")[1]

def detect(attacked_text, window_size, method, prompt):
    """Recover the embedded user identity and a p-value from (possibly attacked) text."""
    watermarker = Watermarker(tokenizer=tokenizer, model=model, window_size=window_size,
                              payload_bits=PAYLOAD_BITS)
    prompt = get_prompt(prompt)
    # Detection is non-blind: the original prompt is prepended to the analysed text.
    pvalues, messages = watermarker.detect([prompt + attacked_text], key=args.key,
                                           method=method, prompts=[prompt])
    print("messages: ", messages)
    print("p-values: ", pvalues)
    user = USERS[messages[0]]
    pvalue = pvalues[0]
    return 'The user detected is {:s} with a p-value of {:.3e}'.format(user, pvalue)

def get_prompt(message: str) -> str:
    # Wrap the message in Llama-style instruction tags; embed() splits on the
    # closing [/INST] tag to recover the generated completion.
    return f"[INST] {message} [/INST]"


with gr.Blocks() as demo:
    gr.Markdown("""# LLM generation watermarking
    This spaces let you to try different watermarking scheme for LLM generation.\n
    It leverages the upgrades introduced in the paper, reducing the gap between empirical and theoretical false positive detection rate and give the ability to embed a message (of n bits). Here we use this capacity to embed the identity of the user generating the text, but it could also be used to identify different version of a model or just convey a secret message.\n
    Simply select an user name, set the maximum text length, the watermarking window size and the prompt. Aaronson and Kirchenbauer watermarking scheme are proposed, along traditional sampling and greedy search without watermarking.\n
    Once the text is generated, you can eventually apply some attacks to it (e.g, remove words), select the associated detection method and run the detection. Please note that the detection is non-blind, and require the original prompt to be known and so left untouched.\n
    For Aaronson, the original detection function, along the Neyman-Pearson and Simplified Score version are available.""")
    with gr.Row():
        user = gr.Dropdown(choices=USERS, value=USERS[0], label="User")
        text_length = gr.Number(minimum=1, maximum=512, value=256, step=1, precision=0, label="Max text length")
        window_size = gr.Number(minimum=0, maximum=10, value=0, step=1, precision=0, label="Watermarking window size")
        embed_method = gr.Dropdown(choices=EMBED_METHODS, value=EMBED_METHODS[0], label="Sampling method")
        prompt = gr.Textbox(label="prompt")
    with gr.Row():
        btn1 = gr.Button("Embed")
    with gr.Row():
        watermarked_text = gr.Textbox(label="Generated text")
        detect_method = gr.Dropdown(choices=DETECT_METHODS, value=DETECT_METHODS[0], label="Detection method")
    with gr.Row():
        btn2 = gr.Button("Detect")
    with gr.Row():
        detection_label = gr.Label(label="Detection result")

    btn1.click(fn=embed, inputs=[user, text_length, window_size, embed_method, prompt], outputs=[watermarked_text], api_name="watermark")
    btn2.click(fn=detect, inputs=[watermarked_text, window_size, detect_method, prompt], outputs=[detection_label], api_name="detect")

demo.launch()