jnhdny committed on
Commit
afcfcda
1 Parent(s): e0b3412

rough test

Files changed (1)
  1. app.py +279 -2
app.py CHANGED
@@ -1,12 +1,284 @@
+from __future__ import annotations
 import gradio as gr
 from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
 from tokenizers import Tokenizer
 from transformers import LogitsProcessor
 
 
+import collections
+from math import sqrt
+
+import scipy.stats
+
+import torch
+from torch import Tensor
+from tokenizers import Tokenizer
+from transformers import LogitsProcessor, LogitsProcessorList
+
+from nltk.util import ngrams
+
+from normalizers import normalization_strategy_lookup
+
+class WatermarkBase:
+    def __init__(
+        self,
+        vocab: list[int] = None,
+        gamma: float = 0.5,
+        delta: float = 2.0,
+        seeding_scheme: str = "simple_1",  # mostly unused/always default
+        hash_key: int = 15485863,  # just a large prime number to create a rng seed with sufficient bit width
+        select_green_tokens: bool = True,
+    ):
+
+        # watermarking parameters
+        self.vocab = vocab
+        self.vocab_size = len(vocab)
+        self.gamma = gamma
+        self.delta = delta
+        self.seeding_scheme = seeding_scheme
+        self.rng = None
+        self.hash_key = hash_key
+        self.select_green_tokens = select_green_tokens
+
+    def _seed_rng(self, input_ids: torch.LongTensor, seeding_scheme: str = None) -> None:
+        # can optionally override the seeding scheme,
+        # but uses the instance attr by default
+        if seeding_scheme is None:
+            seeding_scheme = self.seeding_scheme
+
+        if seeding_scheme == "simple_1":
+            assert input_ids.shape[-1] >= 1, f"seeding_scheme={seeding_scheme} requires at least a 1 token prefix sequence to seed rng"
+            prev_token = input_ids[-1].item()
+            self.rng.manual_seed(self.hash_key * prev_token)
+        else:
+            raise NotImplementedError(f"Unexpected seeding_scheme: {seeding_scheme}")
+        return
+
+    def _get_greenlist_ids(self, input_ids: torch.LongTensor) -> list[int]:
+        # seed the rng using the previous tokens/prefix
+        # according to the seeding_scheme
+        self._seed_rng(input_ids)
+
+        greenlist_size = int(self.vocab_size * self.gamma)
+        vocab_permutation = torch.randperm(self.vocab_size, device=input_ids.device, generator=self.rng)
+        if self.select_green_tokens:  # directly
+            greenlist_ids = vocab_permutation[:greenlist_size]  # new
+        else:  # select green via red
+            greenlist_ids = vocab_permutation[(self.vocab_size - greenlist_size) :]  # legacy behavior
+        return greenlist_ids
+
+
+class WatermarkLogitsProcessor(WatermarkBase, LogitsProcessor):
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def _calc_greenlist_mask(self, scores: torch.FloatTensor, greenlist_token_ids) -> torch.BoolTensor:
+        # TODO lets see if we can lose this loop
+        green_tokens_mask = torch.zeros_like(scores)
+        for b_idx in range(len(greenlist_token_ids)):
+            green_tokens_mask[b_idx][greenlist_token_ids[b_idx]] = 1
+        final_mask = green_tokens_mask.bool()
+        return final_mask
+
+    def _bias_greenlist_logits(self, scores: torch.Tensor, greenlist_mask: torch.Tensor, greenlist_bias: float) -> torch.Tensor:
+        scores[greenlist_mask] = scores[greenlist_mask] + greenlist_bias
+        return scores
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+
+        # this is lazy to allow us to colocate on the watermarked model's device
+        if self.rng is None:
+            self.rng = torch.Generator(device=input_ids.device)
+
+        # NOTE, it would be nice to get rid of this batch loop, but currently,
+        # the seed and partition operations are not tensor/vectorized, thus
+        # each sequence in the batch needs to be treated separately.
+        batched_greenlist_ids = [None for _ in range(input_ids.shape[0])]
+
+        for b_idx in range(input_ids.shape[0]):
+            greenlist_ids = self._get_greenlist_ids(input_ids[b_idx])
+            batched_greenlist_ids[b_idx] = greenlist_ids
+
+        green_tokens_mask = self._calc_greenlist_mask(scores=scores, greenlist_token_ids=batched_greenlist_ids)
+
+        scores = self._bias_greenlist_logits(scores=scores, greenlist_mask=green_tokens_mask, greenlist_bias=self.delta)
+        return scores
+
+
+class WatermarkDetector(WatermarkBase):
+    def __init__(
+        self,
+        *args,
+        device: torch.device = None,
+        tokenizer: Tokenizer = None,
+        z_threshold: float = 4.0,
+        normalizers: list[str] = ["unicode"],  # or also: ["unicode", "homoglyphs", "truecase"]
+        ignore_repeated_bigrams: bool = False,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+        # also configure the metrics returned/preprocessing options
+        assert device, "Must pass device"
+        assert tokenizer, "Need an instance of the generating tokenizer to perform detection"
+
+        self.tokenizer = tokenizer
+        self.device = device
+        self.z_threshold = z_threshold
+        self.rng = torch.Generator(device=self.device)
+
+        if self.seeding_scheme == "simple_1":
+            self.min_prefix_len = 1
+        else:
+            raise NotImplementedError(f"Unexpected seeding_scheme: {self.seeding_scheme}")
+
+        self.normalizers = []
+        for normalization_strategy in normalizers:
+            self.normalizers.append(normalization_strategy_lookup(normalization_strategy))
+
+        self.ignore_repeated_bigrams = ignore_repeated_bigrams
+        if self.ignore_repeated_bigrams:
+            assert self.seeding_scheme == "simple_1", "No repeated bigram credit variant assumes the single token seeding scheme."
+
+
+    def _compute_z_score(self, observed_count, T):
+        # count refers to number of green tokens, T is total number of tokens
+        expected_count = self.gamma
+        numer = observed_count - expected_count * T
+        denom = sqrt(T * expected_count * (1 - expected_count))
+        z = numer / denom
+        return z
+
+    def _compute_p_value(self, z):
+        p_value = scipy.stats.norm.sf(z)
+        return p_value
+
+    def _score_sequence(
+        self,
+        input_ids: Tensor,
+        return_num_tokens_scored: bool = True,
+        return_num_green_tokens: bool = True,
+        return_green_fraction: bool = True,
+        return_green_token_mask: bool = False,
+        return_z_score: bool = True,
+        return_p_value: bool = True,
+    ):
+        if self.ignore_repeated_bigrams:
+            # Method that only counts a green/red hit once per unique bigram.
+            # New num total tokens scored (T) becomes the number of unique bigrams.
+            # We iterate over all unique token bigrams in the input, computing the greenlist
+            # induced by the first token in each, and then checking whether the second
+            # token falls in that greenlist.
+            assert return_green_token_mask == False, "Can't return the green/red mask when ignoring repeats."
+            bigram_table = {}
+            token_bigram_generator = ngrams(input_ids.cpu().tolist(), 2)
+            freq = collections.Counter(token_bigram_generator)
+            num_tokens_scored = len(freq.keys())
+            for idx, bigram in enumerate(freq.keys()):
+                prefix = torch.tensor([bigram[0]], device=self.device)  # expects a 1-d prefix tensor on the randperm device
+                greenlist_ids = self._get_greenlist_ids(prefix)
+                bigram_table[bigram] = True if bigram[1] in greenlist_ids else False
+            green_token_count = sum(bigram_table.values())
+        else:
+            num_tokens_scored = len(input_ids) - self.min_prefix_len
+            if num_tokens_scored < 1:
+                raise ValueError((f"Must have at least {1} token to score after "
+                                  f"the first min_prefix_len={self.min_prefix_len} tokens required by the seeding scheme."))
+            # Standard method.
+            # Since we generally need at least 1 token (for the simplest scheme)
+            # we start the iteration over the token sequence with a minimum
+            # num tokens as the first prefix for the seeding scheme,
+            # and at each step, compute the greenlist induced by the
+            # current prefix and check if the current token falls in the greenlist.
+            green_token_count, green_token_mask = 0, []
+            for idx in range(self.min_prefix_len, len(input_ids)):
+                curr_token = input_ids[idx]
+                greenlist_ids = self._get_greenlist_ids(input_ids[:idx])
+                if curr_token in greenlist_ids:
+                    green_token_count += 1
+                    green_token_mask.append(True)
+                else:
+                    green_token_mask.append(False)
+
+        score_dict = dict()
+        if return_num_tokens_scored:
+            score_dict.update(dict(num_tokens_scored=num_tokens_scored))
+        if return_num_green_tokens:
+            score_dict.update(dict(num_green_tokens=green_token_count))
+        if return_green_fraction:
+            score_dict.update(dict(green_fraction=(green_token_count / num_tokens_scored)))
+        if return_z_score:
+            score_dict.update(dict(z_score=self._compute_z_score(green_token_count, num_tokens_scored)))
+        if return_p_value:
+            z_score = score_dict.get("z_score")
+            if z_score is None:
+                z_score = self._compute_z_score(green_token_count, num_tokens_scored)
+            score_dict.update(dict(p_value=self._compute_p_value(z_score)))
+        if return_green_token_mask:
+            score_dict.update(dict(green_token_mask=green_token_mask))
+
+        return score_dict
+
+    def detect(
+        self,
+        text: str = None,
+        tokenized_text: list[int] = None,
+        return_prediction: bool = True,
+        return_scores: bool = True,
+        z_threshold: float = None,
+        **kwargs,
+    ) -> dict:
+
+        assert (text is not None) ^ (tokenized_text is not None), "Must pass either the raw or tokenized string"
+        if return_prediction:
+            kwargs["return_p_value"] = True  # to return the "confidence":=1-p of positive detections
+
+        # run optional normalizers on text
+        for normalizer in self.normalizers:
+            text = normalizer(text)
+        if len(self.normalizers) > 0:
+            print(f"Text after normalization:\n\n{text}\n")
+
+        if tokenized_text is None:
+            assert self.tokenizer is not None, (
+                "Watermark detection on raw string ",
+                "requires an instance of the tokenizer ",
+                "that was used at generation time.",
+            )
+            tokenized_text = self.tokenizer(text, return_tensors="pt", add_special_tokens=False)["input_ids"][0].to(self.device)
+            if tokenized_text[0] == self.tokenizer.bos_token_id:
+                tokenized_text = tokenized_text[1:]
+        else:
+            # try to remove the bos_tok at beginning if it's there
+            if (self.tokenizer is not None) and (tokenized_text[0] == self.tokenizer.bos_token_id):
+                tokenized_text = tokenized_text[1:]
+
+        # call score method
+        output_dict = {}
+        score_dict = self._score_sequence(tokenized_text, **kwargs)
+        if return_scores:
+            output_dict.update(score_dict)
+        # if passed return_prediction then perform the hypothesis test and return the outcome
+        if return_prediction:
+            z_threshold = z_threshold if z_threshold else self.z_threshold
+            assert z_threshold is not None, "Need a threshold in order to decide outcome of detection test"
+            output_dict["prediction"] = score_dict["z_score"] > z_threshold
+            if output_dict["prediction"]:
+                output_dict["confidence"] = 1 - score_dict["p_value"]
+
+        return output_dict
+
+
+
+
 llm_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
 tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
 
+watermark_processor = WatermarkLogitsProcessor(vocab=list(tokenizer.get_vocab().values()),
+                                               gamma=0.25,
+                                               delta=2.0,
+                                               seeding_scheme="simple_1")
+
 # pipeline2 = pipeline(task="image-classification", model="julien-c/hotdog-not-hotdog")
 
 generator = pipeline('text-generation', model="facebook/opt-125m")
 
@@ -23,9 +295,14 @@ def test_it(input):
 
 
 def predict(prompt):
-    inputs = tokenizer(prompt, padding=True, truncation=True, return_tensors="pt")
+    inputs = tokenizer(prompt, padding=True, truncation=True, return_tensors="pt").to(llm_model.device)
     print(inputs)
-    outputs = llm_model(**inputs, labels=inputs["input_ids"])
+    outputs = llm_model.generate(**inputs, labels=inputs["input_ids"],
+                                 logits_processor=LogitsProcessorList([watermark_processor,]))
+
+
+    outputs = outputs[:, inputs["input_ids"].shape[-1]:]
+    print("Watermarked stuff", tokenizer.batch_decode(outputs, skip_special_tokens=True))
 
     print(tokenizer.decode(outputs["logits"][0, -1, :].topk(10).indices))
 
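
For reference, the classes added in this commit partition the vocabulary into a pseudorandom "greenlist" seeded by the previous token, bias generation toward green tokens by delta, and later test a text for an excess of green tokens via a z-score. The snippet below is a minimal round-trip sketch, not part of the commit: it assumes the WatermarkLogitsProcessor and WatermarkDetector defined in the diff are in scope, and the prompt, sampling settings, and the normalizers=[] choice are illustrative.

from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList

# Assumes WatermarkLogitsProcessor and WatermarkDetector from app.py are in scope.
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

# Same constructor arguments used near the bottom of the diff above.
watermark_processor = WatermarkLogitsProcessor(
    vocab=list(tokenizer.get_vocab().values()),
    gamma=0.25,
    delta=2.0,
    seeding_scheme="simple_1",
)

inputs = tokenizer("The meaning of life is", return_tensors="pt")
output_ids = model.generate(
    **inputs,
    max_new_tokens=64,
    do_sample=True,
    logits_processor=LogitsProcessorList([watermark_processor]),
)
# Strip the prompt so only the watermarked continuation is scored.
generated_text = tokenizer.batch_decode(
    output_ids[:, inputs["input_ids"].shape[-1]:], skip_special_tokens=True
)[0]

# The detector re-derives each position's greenlist from the preceding token
# and runs a z-test on the observed fraction of green tokens.
detector = WatermarkDetector(
    vocab=list(tokenizer.get_vocab().values()),
    gamma=0.25,
    seeding_scheme="simple_1",
    device=model.device,
    tokenizer=tokenizer,
    z_threshold=4.0,
    normalizers=[],  # skip the optional text normalizers and their extra dependency
)
print(detector.detect(generated_text))  # includes z_score, p_value, and prediction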