# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import argparse
from argparse import Namespace
from pprint import pprint
from functools import partial

import numpy # for gradio hot reload
import gradio as gr
import pathlib
import torch

from transformers import (AutoTokenizer,
                          AutoModelForSeq2SeqLM,
                          AutoModelForCausalLM,
                          LogitsProcessorList,
                          LlamaTokenizer)

from gptwm import GPTWatermarkDetector, GPTWatermarkLogitsWarper

def str2bool(v):
    """Util function for user friendly boolean flag args"""
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')

def parse_args():
    """Command line argument specification"""

    parser = argparse.ArgumentParser()

    parser.add_argument("--run_gradio",type=str2bool,default=True,help="Whether to launch as a gradio demo. Set to False if not installed and want to just run the stdout version.")
    parser.add_argument("--model_name", type=str, default="facebook/opt-125m")
    parser.add_argument("--fraction", type=float, default=0.5)
    parser.add_argument("--strength", type=float, default=2.0)
    parser.add_argument("--wm_key", type=int, default=0)
    parser.add_argument("--max_new_tokens", type=int, default=300)
    parser.add_argument("--beam_size", type=int, default=None)
    parser.add_argument("--top_k", type=int, default=None)
    parser.add_argument("--top_p", type=float, default=0.9)
    parser.add_argument("--test_min_tokens", type=int, default=200)
    parser.add_argument("--threshold", type=float, default=6.0)
    args = parser.parse_args()
    return args

def load_model(args):
    """Load and return the model and tokenizer"""
    hf_token = os.getenv('HF_TOKEN')
    if 'llama' in args.model_name:
        tokenizer = LlamaTokenizer.from_pretrained(args.model_name, use_auth_token=hf_token)
    else:
        tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_auth_token=hf_token)
    # torch_dtype controls the precision of the model weights; it has no effect on tokenizers.
    model = AutoModelForCausalLM.from_pretrained(args.model_name, use_auth_token=hf_token, torch_dtype=torch.float16, device_map='auto')
    model.eval()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return model, tokenizer, device

def generate(prompt, args, model=None, device=None, tokenizer=None):
    """Generate a completion for `prompt` twice: once without and once with the watermark."""
    print(f"Generating with {args}")

    watermark_processor = LogitsProcessorList([GPTWatermarkLogitsWarper(fraction=args.fraction,
                                                                        strength=args.strength,
                                                                        vocab_size=model.config.vocab_size,
                                                                        watermark_key=args.wm_key)])

    
    batch = tokenizer(prompt, truncation=True, return_tensors="pt").to(device)
    num_tokens = len(batch['input_ids'][0])
    with torch.inference_mode():
        generate_args = {
            **batch,
            'output_scores': True,
            'return_dict_in_generate': True,
            'max_new_tokens': args.max_new_tokens,
        }

        if args.beam_size is not None:
            generate_args['num_beams'] = args.beam_size
        else:
            generate_args['do_sample'] = True
            generate_args['top_k'] = args.top_k
            generate_args['top_p'] = args.top_p

        # Generate once without the watermark processor...
        generate_without_watermark = partial(
            model.generate,
            **generate_args
        )
        output_without_watermark = generate_without_watermark()
        decoded_output_without_watermark = tokenizer.batch_decode(output_without_watermark['sequences'][:, num_tokens:], skip_special_tokens=True)[0]

        # ...and once with it, so the two outputs can be compared side by side.
        generate_with_watermark = partial(
            model.generate,
            logits_processor=watermark_processor,
            **generate_args
        )
        output_with_watermark = generate_with_watermark()
        decoded_gen_text_with_wm = tokenizer.batch_decode(output_with_watermark['sequences'][:, num_tokens:], skip_special_tokens=True)[0]

    return (prompt,
            decoded_output_without_watermark, 
            decoded_gen_text_with_wm,
            args) 
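
# --- Illustrative sketch (not the gptwm implementation) ----------------------
# The Method tab below describes the watermarking step as: partition the
# vocabulary into a green list of size fraction*N and a red list, then add
# `strength` to the logits of every green-list token. The helper below is a
# minimal, hypothetical restatement of that single step for readers of this
# file; the demo itself relies on GPTWatermarkLogitsWarper, whose internals
# may differ.
def _boost_green_logits_sketch(logits: torch.Tensor,
                               green_mask: torch.Tensor,
                               strength: float) -> torch.Tensor:
    """Return logits with `strength` added at green-list positions (illustration only)."""
    # green_mask is a boolean (or 0/1) tensor over the vocabulary.
    return logits + strength * green_mask.to(logits.dtype)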



def detect_demo(input_text, args, device=None, tokenizer=None):
    """Run watermark detection on `input_text`; return metric rows, highlighted HTML, and args."""

    vocab_size = 50272 if "opt" in args.model_name else tokenizer.vocab_size

    watermark_detector = GPTWatermarkDetector(fraction=args.fraction,
                                        strength=args.strength,
                                        vocab_size=vocab_size,
                                        watermark_key=args.wm_key)
    output = []
    html_output = "Input text is too short to test."
    tokens = tokenizer(input_text, add_special_tokens=False)
    gen_tokens = tokens["input_ids"]
    if len(gen_tokens) >= args.test_min_tokens:
        z_score, green_tokens_mask, green_tokens, total_tokens = watermark_detector.detect(gen_tokens)
        output.append(['z-score', f"{z_score:.3g}"])
        output.append(['green_tokens', f"{int(green_tokens):d}"])
        output.append(['total_tokens', f"{int(total_tokens):d}"])
        # Map each token back to its character span so green/red tokens can be highlighted.
        tokenarray = [tokens.token_to_chars(i) for i in range(len(gen_tokens))]
        tags = [(f'<span class="green">{input_text[word.start:word.end]}</span>' if b
                 else f'<span class="red">{input_text[word.start:word.end]}</span>')
                for word, b in zip(tokenarray, green_tokens_mask)]
        html_output = f'<p>{" ".join(tags)}</p>'
    else:
        print("Input text is too short to test.")
    return output,html_output, args
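
# --- Illustrative sketch (not the gptwm implementation) ----------------------
# The Method tab below states the detection statistic as
#   z_y = (|y|_G - gamma * n) / sqrt(n * gamma * (1 - gamma)),
# where |y|_G is the number of green-list tokens among the n tokens tested and
# gamma is the green-list fraction. The helper below just restates that formula
# for readers of this file; the demo itself calls GPTWatermarkDetector.detect.
def _unigram_z_score_sketch(green_tokens: int, total_tokens: int, fraction: float) -> float:
    """Normalized green-token count (illustration only)."""
    expected = fraction * total_tokens
    variance = total_tokens * fraction * (1 - fraction)
    return (green_tokens - expected) / variance ** 0.5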

def run_gradio(args, model=None, device=None, tokenizer=None):
    """Define and launch the gradio demo interface"""
    css = """

    .green {

    color: #008000 !important;  

    border: none;  

    font-weight: bold; 

    }

    .red {

    color: #ffad99 !important;  

    border: none;  

    font-weight: bold;  

    }

    """

    generate_partial = partial(generate, model=model, device=device, tokenizer=tokenizer)
    detect_partial = partial(detect_demo, device=device, tokenizer=tokenizer)

    with gr.Blocks(css=css) as demo:
        # Top section, greeting and instructions
        with gr.Row():
            with gr.Row():
                with gr.Column(scale=9):
                    gr.Markdown(
                    """
                    ## 🔍 Unigram-Watermark for AI-Generated Text

                    ## [Paper](https://arxiv.org/abs/2306.17439)  [GitHub](https://github.com/XuandongZhao/Unigram-Watermark)
                    """
                    )

        with gr.Accordion("Abstract",open=True):
            gr.Markdown(
            """

            We instantiate our language model watermarking with the **Unigram-Watermark**——a variant of the K-gram watermark. 

            

            We prove that our watermark method enjoys guaranteed generation quality, correctness in watermark detection, and is robust against text editing and paraphrasing.

            """
            )

        gr.Markdown(f"Language model: {args.model_name}")

        # Construct state for parameters, define updates and toggles
        default_prompt = args.__dict__.pop("default_prompt")
        session_args = gr.State(value=args)

        with gr.Tab("Method"):
            with gr.Accordion("Watermark process",open=True):
                gr.Markdown(
                """

                1. Randomly partition the vocabulary into two distinct sets: the green list with $\gamma N$ tokens and the red list with the remaining tokens. 

                2. In $\hat{M}$, the logits of the language model for the green list tokens are increased by $\delta$ while the logits for tokens in the red list remain unchanged.

                """
                )
            with gr.Accordion("Detect process",open=True):
                gr.Markdown(
                """

                1. Count the number of green tokens in the suspect text.



                2. Normalize the test-statistic $z_{y}=(|y|_G-\gamma n) / \sqrt{n \gamma(1-\gamma)}$.



                3. Make a calibrated decision on whether we think the suspect text is generated from $\hat{M}$ or not. 

                """
                )
        with gr.Tab("Generate and Detect"):
            
            with gr.Row():
                prompt = gr.Textbox(label="Prompt", interactive=True, lines=10, max_lines=10, value=default_prompt)
            with gr.Row():
                generate_btn = gr.Button("Generate")
            with gr.Row():
                with gr.Column(scale=1):
                    with gr.Tab("Output Without Watermark"):
                        output_without_watermark = gr.Textbox(label="Text", interactive=False,lines=10,max_lines=10)
                    with gr.Tab("Visualization"):# ¥
                        html_without_watermark = gr.HTML(elem_id="html-without-watermark")
                with gr.Column(scale=1):
                    without_watermark_detection_result = gr.Dataframe(headers=["Metric", "Value"], interactive=False,row_count=7,col_count=2)
            with gr.Row():
                with gr.Column(scale=1):
                    with gr.Tab("Output With Watermark"):
                        output_with_watermark = gr.Textbox(label="Text", interactive=False,lines=10,max_lines=10)
                    with gr.Tab("Visualization"):# 
                        html_with_watermark = gr.HTML(elem_id="html-with-watermark")
                with gr.Column(scale=1):
                    with_watermark_detection_result = gr.Dataframe(headers=["Metric", "Value"],interactive=False,row_count=7,col_count=2)

            redecoded_input = gr.Textbox(visible=False)
            truncation_warning = gr.Number(visible=False)
            def truncate_prompt(redecoded_input, truncation_warning, orig_prompt, args):
                if truncation_warning:
                    return redecoded_input + "\n\n[Prompt was truncated before generation due to length...]", args
                else: 
                    return orig_prompt, args
        
        with gr.Tab("Detector Only"):
            with gr.Row():
                with gr.Column(scale=2):
                    # detect inputbox
                    with gr.Tab("Text to Analyze"):
                        detection_input = gr.Textbox(label="Input", interactive=True, lines=14, max_lines=14)
                    with gr.Tab("Visualization"):
                        html_detection = gr.HTML(elem_id="html-detection")
                with gr.Column(scale=1):
                    detection_result = gr.Dataframe(headers=["Metric", "Value"], interactive=False, row_count=7, col_count=2)
            with gr.Row():
                    # detect
                    detect_btn = gr.Button("Detect")
        

        
        generate_btn.click(fn=generate_partial, inputs=[prompt,session_args], outputs=[redecoded_input, output_without_watermark, output_with_watermark,session_args])
        # Show truncated version of prompt if truncation occurred
        redecoded_input.change(fn=truncate_prompt, inputs=[redecoded_input,truncation_warning,prompt,session_args], outputs=[prompt,session_args])
        # Call detection when the outputs (of the generate function) are updated
        output_without_watermark.change(fn=detect_partial, inputs=[output_without_watermark,session_args], outputs=[without_watermark_detection_result,html_without_watermark,session_args])
        output_with_watermark.change(fn=detect_partial, inputs=[output_with_watermark,session_args], outputs=[with_watermark_detection_result,html_with_watermark,session_args])
        # Register main detection tab click
        detect_btn.click(fn=detect_partial, inputs=[detection_input,session_args], outputs=[detection_result, html_detection,session_args])

    
    demo.launch()

def main(args):
    """Load the model and tokenizer, set the default prompt,
    and optionally launch and serve the gradio demo."""
    # Initial arg processing and log

    model, tokenizer, device = load_model(args)

    # Default prompt text used to seed the demo
    input_text = (
     "One tank tumbled down an embankment into the Tenaru River, drowning its crew."
     " At 23:00 on 14 September, the remnants of the Kuma battalion conducted another attack on the same portion of the Marine lines, but were repulsed. "
     "A final \"weak\" attack by the Kuma unit on the evening of 15 September was also defeated. Oka's unit of about 650 men attacked the Marines at several locations on the west side of the Lunga perimeter."
     " At about 04:00 on 14 September, two Japanese companies attacked positions held by the 3rd Battalion, 5th Marine Regiment (3/5) near the coast and were thrown back with heavy losses."
     " Another Japanese company captured a small ridge somewhat inland but was then pinned down by Marine artillery fire throughout the day and took heavy losses before withdrawing on the evening of 14 September."
     " The rest of Oka's unit failed to find the Marine lines and did not participate in the attack. "
     "At 13:05 on 14 September, Kawaguchi led the survivors of his shattered brigade away from the ridge and deeper into the jungle, where they rested and tended to their wounded all the next day. "
     "Kawaguchi's units were then ordered to withdraw west to the Matanikau River valley to join with Oka's unit, a march over difficult terrain."
     " Kawaguchi's troops began the march on the morning of 16 September."
     " Almost every soldier able to walk had to help carry the wounded. "
     "As the march progressed, the exhausted and hungry soldiers, who had eaten their last rations on the morning before their withdrawal, began to discard their heavy equipment and then their rifles. "
     "By the time most of them reached Oka's positions at Kokumbona five days later, only half still carried their weapons."
     " The Kuma battalion's survivors, attempting to follow Kawaguchi's Center Body forces, became lost, wandered for three weeks in the jungle, and almost starved to death before finally reaching Kawaguchi's camp."
    )

    args.default_prompt = input_text

    # Launch the app to generate and detect interactively (implements the hf space demo)
    if args.run_gradio:
        run_gradio(args, model=model, tokenizer=tokenizer, device=device)

    return

if __name__ == "__main__":

    args = parse_args()
    print(args)

    main(args)
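
# Example invocation (the filename "app.py" is an assumption; use this script's
# actual name, and make sure gradio and the gptwm module are importable):
#   python app.py --model_name facebook/opt-125m --fraction 0.5 --strength 2.0
#   python app.py --run_gradio False   # skip launching the gradio demo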