XuandongZhao commited on
Commit
6a20eb3
β€’
1 Parent(s): f39573c
Files changed (6) hide show
  1. README.md +13 -13
  2. app.py +305 -0
  3. gptwm.py +114 -0
  4. requirements.txt +5 -0
  5. run_detect.py +58 -0
  6. run_generate.py +106 -0
README.md CHANGED
@@ -1,13 +1,13 @@
1
- ---
2
- title: Unigram Watermark
3
- emoji: 😻
4
- colorFrom: purple
5
- colorTo: green
6
- sdk: gradio
7
- sdk_version: 4.7.1
8
- app_file: app.py
9
- pinned: false
10
- license: mit
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
+ ---
2
+ title: Unigram-Watermark
3
+ emoji: πŸ‘€
4
+ colorFrom: purple
5
+ colorTo: indigo
6
+ sdk: gradio
7
+ sdk_version: 4.7.1
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Licensed under the Apache License, Version 2.0 (the "License");
2
+ # you may not use this file except in compliance with the License.
3
+ # You may obtain a copy of the License at
4
+ #
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ #
7
+ # Unless required by applicable law or agreed to in writing, software
8
+ # distributed under the License is distributed on an "AS IS" BASIS,
9
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10
+ # See the License for the specific language governing permissions and
11
+ # limitations under the License.
12
+
13
+ import os
14
+ import argparse
15
+ from argparse import Namespace
16
+ from pprint import pprint
17
+ from functools import partial
18
+
19
+ import numpy # for gradio hot reload
20
+ import gradio as gr
21
+ import pathlib
22
+ import torch
23
+
24
+ from transformers import (AutoTokenizer,
25
+ AutoModelForSeq2SeqLM,
26
+ AutoModelForCausalLM,
27
+ LogitsProcessorList,
28
+ LlamaTokenizer)
29
+
30
+ from gptwm import GPTWatermarkDetector, GPTWatermarkLogitsWarper
31
+
32
+ def str2bool(v):
33
+ """Util function for user friendly boolean flag args"""
34
+ if isinstance(v, bool):
35
+ return v
36
+ if v.lower() in ('yes', 'true', 't', 'y', '1'):
37
+ return True
38
+ elif v.lower() in ('no', 'false', 'f', 'n', '0'):
39
+ return False
40
+ else:
41
+ raise argparse.ArgumentTypeError('Boolean value expected.')
42
+
43
+ def parse_args():
44
+ """Command line argument specification"""
45
+
46
+ parser = argparse.ArgumentParser()
47
+
48
+ parser.add_argument("--run_gradio",type=str2bool,default=True,help="Whether to launch as a gradio demo. Set to False if not installed and want to just run the stdout version.")
49
+ parser.add_argument("--model_name", type=str, default="facebook/opt-125m")
50
+ parser.add_argument("--fraction", type=float, default=0.5)
51
+ parser.add_argument("--strength", type=float, default=2.0)
52
+ parser.add_argument("--wm_key", type=int, default=0)
53
+ parser.add_argument("--max_new_tokens", type=int, default=300)
54
+ parser.add_argument("--beam_size", type=int, default=None)
55
+ parser.add_argument("--top_k", type=int, default=None)
56
+ parser.add_argument("--top_p", type=float, default=0.9)
57
+ parser.add_argument("--test_min_tokens", type=int, default=200)
58
+ parser.add_argument("--threshold", type=float, default=6.0)
59
+ args = parser.parse_args()
60
+ return args
61
+
62
+ def load_model(args):
63
+ """Load and return the model and tokenizer"""
64
+ hf_token = os.getenv('HF_TOKEN')
65
+ if 'llama' in args.model_name:
66
+ tokenizer = LlamaTokenizer.from_pretrained(args.model_name, use_auth_token=hf_token, torch_dtype=torch.float16)
67
+ else:
68
+ tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_auth_token=hf_token, torch_dtype=torch.float16)
69
+ model = AutoModelForCausalLM.from_pretrained(args.model_name, use_auth_token=hf_token, device_map='auto')
70
+ model.eval()
71
+ device = "cuda" if torch.cuda.is_available() else "cpu"
72
+ return model, tokenizer, device
73
+
74
+ def generate(prompt, args, model=None, device=None, tokenizer=None):
75
+ print(f"Generating with {args}")
76
+
77
+ watermark_processor = LogitsProcessorList([GPTWatermarkLogitsWarper(fraction=args.fraction,
78
+ strength=args.strength,
79
+ vocab_size=model.config.vocab_size,
80
+ watermark_key=args.wm_key)])
81
+
82
+
83
+ batch = tokenizer(prompt, truncation=True, return_tensors="pt").to(device)
84
+ num_tokens = len(batch['input_ids'][0])
85
+ with torch.inference_mode():
86
+ generate_args = {
87
+ **batch,
88
+ 'output_scores': True,
89
+ 'return_dict_in_generate': True,
90
+ 'max_new_tokens': args.max_new_tokens,
91
+ }
92
+
93
+ if args.beam_size is not None:
94
+ generate_args['num_beams'] = args.beam_size
95
+ else:
96
+ generate_args['do_sample'] = True
97
+ generate_args['top_k'] = args.top_k
98
+ generate_args['top_p'] = args.top_p
99
+
100
+ generate_without_watermark = partial(
101
+ model.generate,
102
+ **generate_args
103
+ )
104
+ output_without_watermark = generate_without_watermark()
105
+ decoded_output_without_watermark = tokenizer.batch_decode(output_without_watermark['sequences'][:, num_tokens:], skip_special_tokens=True)[0]
106
+ generate_with_watermark = partial(
107
+ model.generate,
108
+ logits_processor=watermark_processor,
109
+ **generate_args
110
+ )
111
+ output_with_watermark = generate_with_watermark()
112
+ decoded_gen_text_with_wm = tokenizer.batch_decode(output_with_watermark['sequences'][:, num_tokens:], skip_special_tokens=True)[0]
113
+
114
+ return (prompt,
115
+ decoded_output_without_watermark,
116
+ decoded_gen_text_with_wm,
117
+ args)
118
+
119
+
120
+
121
+ def detect_demo(input_text, args, device=None, tokenizer=None):
122
+
123
+ vocab_size = 50272 if "opt" in args.model_name else tokenizer.vocab_size
124
+
125
+ watermark_detector = GPTWatermarkDetector(fraction=args.fraction,
126
+ strength=args.strength,
127
+ vocab_size=vocab_size,
128
+ watermark_key=args.wm_key)
129
+ output = []
130
+ html_output = ["Input text is too short to test."]
131
+ tokens = tokenizer(input_text, add_special_tokens=False)
132
+ gen_tokens = tokens["input_ids"]
133
+ if len(gen_tokens)>= args.test_min_tokens:
134
+ z_score,green_tokens_mask,green_tokens,total_tokens = watermark_detector.detect(gen_tokens)
135
+ output.append(['z-score', f"{z_score:.3g}"])
136
+ output.append(['green_tokens', f"{int(green_tokens):d}"])
137
+ output.append(['total_tokens', f"{int(total_tokens):d}"])
138
+ tokenarray =[tokens.token_to_chars(i) for i in range(0,len(gen_tokens))]
139
+ tags = [(f'<span class="green">{input_text[word.start:word.end]}</span>' if b else f'<span class="red">{input_text[word.start:word.end]}</span>') for word, b in zip(tokenarray, green_tokens_mask)]
140
+ html_output = f'<p>{" ".join(tags)}</p>'
141
+ else:
142
+ print(f"Input text is too short to test.")
143
+ return output,html_output, args
144
+
145
+ def run_gradio(args, model=None, device=None, tokenizer=None):
146
+ """Define and launch the gradio demo interface"""
147
+ css = """
148
+ .green {
149
+ color: #008000 !important;
150
+ border: none;
151
+ font-weight: bold;
152
+ }
153
+ .red {
154
+ color: #ffad99 !important;
155
+ border: none;
156
+ font-weight: bold;
157
+ }
158
+ """
159
+
160
+ generate_partial = partial(generate, model=model, device=device, tokenizer=tokenizer)
161
+ detect_partial = partial(detect_demo, device=device, tokenizer=tokenizer)
162
+
163
+ with gr.Blocks(css=css) as demo:
164
+ # Top section, greeting and instructions
165
+ with gr.Row():
166
+ with gr.Row():
167
+ with gr.Column(scale=9):
168
+ gr.Markdown(
169
+ """
170
+ ## πŸ” Unigram-Watermark for AI-Generated Text
171
+
172
+ ## [Paper](https://arxiv.org/abs/2306.17439) [GitHub](https://github.com/XuandongZhao/Unigram-Watermark)
173
+ """
174
+ )
175
+
176
+ with gr.Accordion("Abstract",open=True):
177
+ gr.Markdown(
178
+ """
179
+ We instantiate our language model watermarking with the **Unigram-Watermark**β€”β€”a variant of the K-gram watermark.
180
+
181
+ We prove that our watermark method enjoys guaranteed generation quality, correctness in watermark detection, and is robust against text editing and paraphrasing.
182
+ """
183
+ )
184
+
185
+ gr.Markdown(f"Language model: {args.model_name}")
186
+
187
+ # Construct state for parameters, define updates and toggles
188
+ default_prompt = args.__dict__.pop("default_prompt")
189
+ session_args = gr.State(value=args)
190
+
191
+ with gr.Tab("Method"):
192
+ with gr.Accordion("Watermark process",open=True):
193
+ gr.Markdown(
194
+ """
195
+ 1. Randomly partition the vocabulary into two distinct sets: the green list with $\gamma N$ tokens and the red list with the remaining tokens.
196
+ 2. In $\hat{M}$, the logits of the language model for the green list tokens are increased by $\delta$ while the logits for tokens in the red list remain unchanged.
197
+ """
198
+ )
199
+ with gr.Accordion("Detect process",open=True):
200
+ gr.Markdown(
201
+ """
202
+ 1. Count the number of green tokens in the suspect text.
203
+
204
+ 2. Normalize the test-statistic $z_{y}=(|y|_G-\gamma n) / \sqrt{n \gamma(1-\gamma)}$.
205
+
206
+ 3. Make a calibrated decision on whether we think the suspect text is generated from $\hat{M}$ or not.
207
+ """
208
+ )
209
+ with gr.Tab("Generate and Detect"):
210
+
211
+ with gr.Row():
212
+ prompt = gr.Textbox(label=f"Prompt", interactive=True,lines=10,max_lines=10, value=default_prompt)
213
+ with gr.Row():
214
+ generate_btn = gr.Button("Generate")
215
+ with gr.Row():
216
+ with gr.Column(scale=1):
217
+ with gr.Tab("Output Without Watermark"):
218
+ output_without_watermark = gr.Textbox(label="Text", interactive=False,lines=10,max_lines=10)
219
+ with gr.Tab("Visualization"):# οΏ₯
220
+ html_without_watermark = gr.HTML(elem_id="html-without-watermark")
221
+ with gr.Column(scale=1):
222
+ without_watermark_detection_result = gr.Dataframe(headers=["Metric", "Value"], interactive=False,row_count=7,col_count=2)
223
+ with gr.Row():
224
+ with gr.Column(scale=1):
225
+ with gr.Tab("Output With Watermark"):
226
+ output_with_watermark = gr.Textbox(label="Text", interactive=False,lines=10,max_lines=10)
227
+ with gr.Tab("Visualization"):#
228
+ html_with_watermark = gr.HTML(elem_id="html-with-watermark")
229
+ with gr.Column(scale=1):
230
+ with_watermark_detection_result = gr.Dataframe(headers=["Metric", "Value"],interactive=False,row_count=7,col_count=2)
231
+
232
+ redecoded_input = gr.Textbox(visible=False)
233
+ truncation_warning = gr.Number(visible=False)
234
+ def truncate_prompt(redecoded_input, truncation_warning, orig_prompt, args):
235
+ if truncation_warning:
236
+ return redecoded_input + f"\n\n[Prompt was truncated before generation due to length...]", args
237
+ else:
238
+ return orig_prompt, args
239
+
240
+ with gr.Tab("Detector Only"):
241
+ with gr.Row():
242
+ with gr.Column(scale=2):
243
+ # detect inputbox
244
+ with gr.Tab("Text to Analyze"):
245
+ detection_input = gr.Textbox(label="Input", interactive=True,lines=14,max_lines=14)
246
+ with gr.Tab("Visualization"):
247
+ html_detection = gr.HTML(elem_id="html-detection")
248
+ with gr.Column(scale=1):
249
+ detection_result = gr.Dataframe(headers=["Metric", "Value"], interactive=False,row_count=7,col_count=2)
250
+ with gr.Row():
251
+ # detect
252
+ detect_btn = gr.Button("Detect")
253
+
254
+
255
+
256
+ generate_btn.click(fn=generate_partial, inputs=[prompt,session_args], outputs=[redecoded_input, output_without_watermark, output_with_watermark,session_args])
257
+ # Show truncated version of prompt if truncation occurred
258
+ redecoded_input.change(fn=truncate_prompt, inputs=[redecoded_input,truncation_warning,prompt,session_args], outputs=[prompt,session_args])
259
+ # Call detection when the outputs (of the generate function) are updated
260
+ output_without_watermark.change(fn=detect_partial, inputs=[output_without_watermark,session_args], outputs=[without_watermark_detection_result,html_without_watermark,session_args])
261
+ output_with_watermark.change(fn=detect_partial, inputs=[output_with_watermark,session_args], outputs=[with_watermark_detection_result,html_with_watermark,session_args])
262
+ # Register main detection tab click
263
+ detect_btn.click(fn=detect_partial, inputs=[detection_input,session_args], outputs=[detection_result, html_detection,session_args])
264
+
265
+
266
+ demo.launch()
267
+
268
+ def main(args):
269
+ """Run a command line version of the generation and detection operations
270
+ and optionally launch and serve the gradio demo"""
271
+ # Initial arg processing and log
272
+
273
+ model, tokenizer, device = load_model(args)
274
+
275
+ # Generate and detect, report to stdout
276
+ input_text = (
277
+ "One tank tumbled down an embankment into the Tenaru River, drowning its crew."
278
+ " At 23:00 on 14 September, the remnants of the Kuma battalion conducted another attack on the same portion of the Marine lines, but were repulsed. "
279
+ "A final \"weak\" attack by the Kuma unit on the evening of 15 September was also defeated. Oka's unit of about 650 men attacked the Marines at several locations on the west side of the Lunga perimeter."
280
+ " At about 04:00 on 14 September, two Japanese companies attacked positions held by the 3rd Battalion, 5th Marine Regiment (3/5) near the coast and were thrown back with heavy losses."
281
+ " Another Japanese company captured a small ridge somewhat inland but was then pinned down by Marine artillery fire throughout the day and took heavy losses before withdrawing on the evening of 14 September."
282
+ " The rest of Oka's unit failed to find the Marine lines and did not participate in the attack. "
283
+ "At 13:05 on 14 September, Kawaguchi led the survivors of his shattered brigade away from the ridge and deeper into the jungle, where they rested and tended to their wounded all the next day. "
284
+ "Kawaguchi's units were then ordered to withdraw west to the Matanikau River valley to join with Oka's unit, a march over difficult terrain."
285
+ " Kawaguchi's troops began the march on the morning of 16 September."
286
+ " Almost every soldier able to walk had to help carry the wounded. "
287
+ "As the march progressed, the exhausted and hungry soldiers, who had eaten their last rations on the morning before their withdrawal, began to discard their heavy equipment and then their rifles. "
288
+ "By the time most of them reached Oka's positions at Kokumbona five days later, only half still carried their weapons."
289
+ " The Kuma battalion's survivors, attempting to follow Kawaguchi's Center Body forces, became lost, wandered for three weeks in the jungle, and almost starved to death before finally reaching Kawaguchi's camp."
290
+ )
291
+
292
+ args.default_prompt = input_text
293
+
294
+ # Launch the app to generate and detect interactively (implements the hf space demo)
295
+ if args.run_gradio:
296
+ run_gradio(args, model=model, tokenizer=tokenizer, device=device)
297
+
298
+ return
299
+
300
+ if __name__ == "__main__":
301
+
302
+ args = parse_args()
303
+ print(args)
304
+
305
+ main(args)
gptwm.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+ from typing import List
3
+ import numpy as np
4
+ from scipy.stats import norm
5
+ import torch
6
+ from transformers import LogitsWarper
7
+
8
+
9
+ class GPTWatermarkBase:
10
+ """
11
+ Base class for watermarking distributions with fixed-group green-listed tokens.
12
+
13
+ Args:
14
+ fraction: The fraction of the distribution to be green-listed.
15
+ strength: The strength of the green-listing. Higher values result in higher logit scores for green-listed tokens.
16
+ vocab_size: The size of the vocabulary.
17
+ watermark_key: The random seed for the green-listing.
18
+ """
19
+
20
+ def __init__(self, fraction: float = 0.5, strength: float = 2.0, vocab_size: int = 50257, watermark_key: int = 0):
21
+ rng = np.random.default_rng(self._hash_fn(watermark_key))
22
+ mask = np.array([True] * int(fraction * vocab_size) + [False] * (vocab_size - int(fraction * vocab_size)))
23
+ rng.shuffle(mask)
24
+ self.green_list_mask = torch.tensor(mask, dtype=torch.float32)
25
+ self.strength = strength
26
+ self.fraction = fraction
27
+
28
+ @staticmethod
29
+ def _hash_fn(x: int) -> int:
30
+ """solution from https://stackoverflow.com/questions/67219691/python-hash-function-that-returns-32-or-64-bits"""
31
+ x = np.int64(x)
32
+ return int.from_bytes(hashlib.sha256(x).digest()[:4], 'little')
33
+
34
+
35
+ class GPTWatermarkLogitsWarper(GPTWatermarkBase, LogitsWarper):
36
+ """
37
+ LogitsWarper for watermarking distributions with fixed-group green-listed tokens.
38
+
39
+ Args:
40
+ fraction: The fraction of the distribution to be green-listed.
41
+ strength: The strength of the green-listing. Higher values result in higher logit scores for green-listed tokens.
42
+ vocab_size: The size of the vocabulary.
43
+ watermark_key: The random seed for the green-listing.
44
+ """
45
+
46
+ def __init__(self, *args, **kwargs):
47
+ super().__init__(*args, **kwargs)
48
+
49
+ def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.FloatTensor:
50
+ """Add the watermark to the logits and return new logits."""
51
+ watermark = self.strength * self.green_list_mask
52
+ new_logits = scores + watermark.to(scores.device)
53
+ return new_logits
54
+
55
+
56
+ class GPTWatermarkDetector(GPTWatermarkBase):
57
+ """
58
+ Class for detecting watermarks in a sequence of tokens.
59
+
60
+ Args:
61
+ fraction: The fraction of the distribution to be green-listed.
62
+ strength: The strength of the green-listing. Higher values result in higher logit scores for green-listed tokens.
63
+ vocab_size: The size of the vocabulary.
64
+ watermark_key: The random seed for the green-listing.
65
+ """
66
+
67
+ def __init__(self, *args, **kwargs):
68
+ super().__init__(*args, **kwargs)
69
+
70
+ @staticmethod
71
+ def _z_score(num_green: int, total: int, fraction: float) -> float:
72
+ """Calculate and return the z-score of the number of green tokens in a sequence."""
73
+ return (num_green - fraction * total) / np.sqrt(fraction * (1 - fraction) * total)
74
+
75
+ @staticmethod
76
+ def _compute_tau(m: int, N: int, alpha: float) -> float:
77
+ """
78
+ Compute the threshold tau for the dynamic thresholding.
79
+
80
+ Args:
81
+ m: The number of unique tokens in the sequence.
82
+ N: Vocabulary size.
83
+ alpha: The false positive rate to control.
84
+ Returns:
85
+ The threshold tau.
86
+ """
87
+ factor = np.sqrt(1 - (m - 1) / (N - 1))
88
+ tau = factor * norm.ppf(1 - alpha)
89
+ return tau
90
+
91
+ def detect(self, sequence: List[int]) -> float:
92
+ """Detect the watermark in a sequence of tokens and return the z value."""
93
+ green_tokens = int(sum(self.green_list_mask[i] for i in sequence))
94
+ green_tokens_mask = []
95
+ for i in sequence:
96
+ if self.green_list_mask[i]:
97
+ green_tokens_mask.append(True)
98
+ else:
99
+ green_tokens_mask.append(False)
100
+ # self.green_tokens_mask = green_tokens_mask
101
+
102
+ return self._z_score(green_tokens, len(sequence), self.fraction), green_tokens_mask,green_tokens,len(sequence)
103
+
104
+ def unidetect(self, sequence: List[int]) -> float:
105
+ """Detect the watermark in a sequence of tokens and return the z value. Just for unique tokens."""
106
+ sequence = list(set(sequence))
107
+ green_tokens = int(sum(self.green_list_mask[i] for i in sequence))
108
+ return self._z_score(green_tokens, len(sequence), self.fraction)
109
+
110
+ def dynamic_threshold(self, sequence: List[int], alpha: float, vocab_size: int) -> (bool, float):
111
+ """Dynamic thresholding for watermark detection. True if the sequence is watermarked, False otherwise."""
112
+ z_score = self.unidetect(sequence)
113
+ tau = self._compute_tau(len(list(set(sequence))), vocab_size, alpha)
114
+ return z_score > tau, z_score
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch
2
+ transformers
3
+ scipy
4
+ accelerate
5
+ pathlib
run_detect.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ from tqdm import tqdm
4
+ import torch
5
+ from transformers import AutoTokenizer, LlamaTokenizer
6
+ from gptwm import GPTWatermarkDetector
7
+
8
+
9
+ def main(args):
10
+ with open(args.input_file, 'r') as f:
11
+ data = [json.loads(x) for x in f.read().strip().split("\n")]
12
+ if 'llama' in args.model_name:
13
+ tokenizer = LlamaTokenizer.from_pretrained(args.model_name, torch_dtype=torch.float16)
14
+ else:
15
+ tokenizer = AutoTokenizer.from_pretrained(args.model_name, torch_dtype=torch.float16)
16
+
17
+ vocab_size = 50272 if "opt" in args.model_name else tokenizer.vocab_size
18
+
19
+ detector = GPTWatermarkDetector(fraction=args.fraction,
20
+ strength=args.strength,
21
+ vocab_size=vocab_size,
22
+ watermark_key=args.wm_key)
23
+
24
+ z_score_list = []
25
+ for idx, cur_data in tqdm(enumerate(data), total=len(data)):
26
+ gen_tokens = tokenizer(cur_data['gen_completion'][0], add_special_tokens=False)["input_ids"]
27
+ if len(gen_tokens) >= args.test_min_tokens:
28
+ z_score_list.append(detector.detect(gen_tokens))
29
+ else:
30
+ print(f"Warning: sequence {idx} is too short to test.")
31
+
32
+ save_dict = {
33
+ 'z_score': z_score_list,
34
+ 'wm_pred': [1 if z > args.threshold else 0 for z in z_score_list]
35
+ }
36
+
37
+ print(save_dict)
38
+ with open(args.input_file.replace('.jsonl', '_z.jsonl'), 'w') as f:
39
+ json.dump(save_dict, f)
40
+
41
+ print('Finished!')
42
+
43
+
44
+ if __name__ == "__main__":
45
+ parser = argparse.ArgumentParser()
46
+
47
+ # parser.add_argument("--model_name", type=str, default="facebook/opt-125m")
48
+ parser.add_argument("--model_name", type=str, default="decapoda-research/llama-7b-hf")
49
+ parser.add_argument("--fraction", type=float, default=0.5)
50
+ parser.add_argument("--strength", type=float, default=2.0)
51
+ parser.add_argument("--threshold", type=float, default=6.0)
52
+ parser.add_argument("--wm_key", type=int, default=0)
53
+ parser.add_argument("--input_file", type=str, default="./data/example_output.jsonl")
54
+ parser.add_argument("--test_min_tokens", type=int, default=200)
55
+
56
+ args = parser.parse_args()
57
+
58
+ main(args)
run_generate.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from tqdm import tqdm
3
+ import json
4
+ import torch
5
+ import os
6
+ from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer, LogitsProcessorList
7
+ from gptwm import GPTWatermarkLogitsWarper
8
+
9
+
10
+ def read_file(filename):
11
+ with open(filename, "r") as f:
12
+ return [json.loads(line) for line in f.read().strip().split("\n")]
13
+
14
+
15
+ def write_file(filename, data):
16
+ with open(filename, "a") as f:
17
+ f.write("\n".join(data) + "\n")
18
+
19
+
20
+ def main(args):
21
+ output_file = f"{args.output_dir}/{args.model_name.replace('/', '-')}_strength_{args.strength}_frac_{args.fraction}_len_{args.max_new_tokens}_num_{args.num_test}.jsonl"
22
+ if 'llama' in args.model_name:
23
+ tokenizer = LlamaTokenizer.from_pretrained(args.model_name, torch_dtype=torch.float16)
24
+ else:
25
+ tokenizer = AutoTokenizer.from_pretrained(args.model_name, torch_dtype=torch.float16)
26
+ model = AutoModelForCausalLM.from_pretrained(args.model_name, device_map='auto')
27
+ model.eval()
28
+ device = "cuda" if torch.cuda.is_available() else "cpu"
29
+ watermark_processor = LogitsProcessorList([GPTWatermarkLogitsWarper(fraction=args.fraction,
30
+ strength=args.strength,
31
+ vocab_size=model.config.vocab_size,
32
+ watermark_key=args.wm_key)])
33
+
34
+ data = read_file(args.prompt_file)
35
+ num_cur_outputs = len(read_file(output_file)) if os.path.exists(output_file) else 0
36
+
37
+ outputs = []
38
+
39
+ for idx, cur_data in tqdm(enumerate(data), total=min(len(data), args.num_test)):
40
+ if idx < num_cur_outputs or len(outputs) >= args.num_test:
41
+ continue
42
+
43
+ if "gold_completion" not in cur_data and 'targets' not in cur_data:
44
+ continue
45
+ elif "gold_completion" in cur_data:
46
+ prefix = cur_data['prefix']
47
+ gold_completion = cur_data['gold_completion']
48
+ else:
49
+ prefix = cur_data['prefix']
50
+ gold_completion = cur_data['targets'][0]
51
+
52
+ batch = tokenizer(prefix, truncation=True, return_tensors="pt").to(device)
53
+ num_tokens = len(batch['input_ids'][0])
54
+
55
+ with torch.inference_mode():
56
+ generate_args = {
57
+ **batch,
58
+ 'logits_processor': watermark_processor,
59
+ 'output_scores': True,
60
+ 'return_dict_in_generate': True,
61
+ 'max_new_tokens': args.max_new_tokens,
62
+ }
63
+
64
+ if args.beam_size is not None:
65
+ generate_args['num_beams'] = args.beam_size
66
+ else:
67
+ generate_args['do_sample'] = True
68
+ generate_args['top_k'] = args.top_k
69
+ generate_args['top_p'] = args.top_p
70
+
71
+ generation = model.generate(**generate_args)
72
+ gen_text = tokenizer.batch_decode(generation['sequences'][:, num_tokens:], skip_special_tokens=True)
73
+
74
+ outputs.append(json.dumps({
75
+ "prefix": prefix,
76
+ "gold_completion": gold_completion,
77
+ "gen_completion": gen_text
78
+ }))
79
+
80
+ if (idx + 1) % 10 == 0:
81
+ write_file(output_file, outputs)
82
+ outputs = []
83
+ break
84
+
85
+ write_file(output_file, outputs)
86
+ print("Finished!")
87
+
88
+
89
+ if __name__ == "__main__":
90
+ parser = argparse.ArgumentParser()
91
+
92
+ parser.add_argument("--model_name", type=str, default="facebookopt-125m")
93
+ # parser.add_argument("--model_name", type=str, default="decapoda-research/llama-7b-hf")
94
+ parser.add_argument("--fraction", type=float, default=0.5)
95
+ parser.add_argument("--strength", type=float, default=2.0)
96
+ parser.add_argument("--wm_key", type=int, default=0)
97
+ parser.add_argument("--prompt_file", type=str, default="./data/LFQA/inputs.jsonl")
98
+ parser.add_argument("--output_dir", type=str, default="./data/LFQA/")
99
+ parser.add_argument("--max_new_tokens", type=int, default=300)
100
+ parser.add_argument("--num_test", type=int, default=500)
101
+ parser.add_argument("--beam_size", type=int, default=None)
102
+ parser.add_argument("--top_k", type=int, default=None)
103
+ parser.add_argument("--top_p", type=float, default=0.9)
104
+
105
+ args = parser.parse_args()
106
+ main(args)