Snowball commited on
Commit
f115ecd
1 Parent(s): 7a98109

Upload 6 files

Browse files
app.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from models.watermark_faster import watermark_model
3
+ import pdb
4
+ from options import get_parser_main_model
5
+
6
+ opts = get_parser_main_model().parse_args()
7
+ model = watermark_model(language=opts.language, mode=opts.mode, tau_word=opts.tau_word, lamda=opts.lamda)
8
+ def watermark_embed_demo(raw):
9
+
10
+ watermarked_text = model.embed(raw)
11
+ return watermarked_text
12
+
13
+ def watermark_extract(raw):
14
+ is_watermark, p_value, n, ones, z_value = model.watermark_detector_fast(raw)
15
+ confidence = (1 - p_value) * 100
16
+
17
+ return f"{confidence:.2f}%"
18
+
19
+ def precise_watermark_detect(raw):
20
+ is_watermark, p_value, n, ones, z_value = model.watermark_detector_precise(raw)
21
+ confidence = (1 - p_value) * 100
22
+
23
+ return f"{confidence:.2f}%"
24
+
25
+
26
+ demo = gr.Blocks()
27
+ with demo:
28
+ with gr.Column():
29
+ gr.Markdown("# Watermarking Text Generated by Black-Box Language Models")
30
+
31
+ inputs = gr.TextArea(label="Input text", placeholder="Copy your text here...")
32
+ output = gr.Textbox(label="Watermarked Text")
33
+ analysis_button = gr.Button("Inject Watermark")
34
+ inputs_embed = [inputs]
35
+ analysis_button.click(fn=watermark_embed_demo, inputs=inputs_embed, outputs=output)
36
+
37
+ inputs_w = gr.TextArea(label="Text to Analyze", placeholder="Copy your watermarked text here...")
38
+
39
+ mode = gr.Dropdown(
40
+ label="Detection Mode", choices=["Fast", "Precise"], default="Fast"
41
+ )
42
+ output_detect = gr.Textbox(label="Confidence (the likelihood of the text containing a watermark)")
43
+ detect_button = gr.Button("Detect")
44
+
45
+ def detect_watermark(inputs_w, mode):
46
+ if mode == "Fast":
47
+ return watermark_extract(inputs_w)
48
+ else:
49
+ return precise_watermark_detect(inputs_w)
50
+
51
+ detect_button.click(fn=detect_watermark, inputs=[inputs_w, mode], outputs=output_detect)
52
+
53
+
54
+ if __name__ == "__main__":
55
+ gr.close_all()
56
+ demo.title = "Watermarking Text Generated by Black-Box Language Models"
57
+ demo.launch(share = True, server_port=8899)
models/__pycache__/watermark_faster.cpython-39.pyc ADDED
Binary file (15.9 kB). View file
 
models/watermark_faster.py ADDED
@@ -0,0 +1,465 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import nltk
2
+ from nltk.corpus import stopwords
3
+ from nltk import word_tokenize, pos_tag
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from torch import nn
7
+ import hashlib
8
+ from scipy.stats import norm
9
+ import gensim
10
+ import pdb
11
+ from transformers import BertForMaskedLM as WoBertForMaskedLM
12
+ from wobert import WoBertTokenizer
13
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
14
+
15
+ from transformers import BertForMaskedLM, BertTokenizer, RobertaForSequenceClassification, RobertaTokenizer
16
+ import gensim.downloader as api
17
+ import Levenshtein
18
+ import string
19
+ import spacy
20
+ import paddle
21
+ from jieba import posseg
22
+ paddle.enable_static()
23
+ import re
24
+ def cut_sent(para):
25
+ para = re.sub('([。!?\?])([^”’])', r'\1\n\2', para)
26
+ para = re.sub('([。!?\?][”’])([^,。!?\?\n ])', r'\1\n\2', para)
27
+ para = re.sub('(\.{6}|\…{2})([^”’\n])', r'\1\n\2', para)
28
+ para = re.sub('([^。!?\?]*)([::][^。!?\?\n]*)', r'\1\n\2', para)
29
+ para = re.sub('([。!?\?][”’])$', r'\1\n', para)
30
+ para = para.rstrip()
31
+ return para.split("\n")
32
+
33
+ def is_subword(token: str):
34
+ return token.startswith('##')
35
+
36
+ def binary_encoding_function(token):
37
+ hash_value = int(hashlib.sha256(token.encode('utf-8')).hexdigest(), 16)
38
+ random_bit = hash_value % 2
39
+ return random_bit
40
+
41
+ def is_similar(x, y, threshold=0.5):
42
+ distance = Levenshtein.distance(x, y)
43
+ if distance / max(len(x), len(y)) < threshold:
44
+ return True
45
+ return False
46
+
47
+ class watermark_model:
48
+ def __init__(self, language, mode, tau_word, lamda):
49
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
50
+ self.language = language
51
+ self.mode = mode
52
+ self.tau_word = tau_word
53
+ self.tau_sent = 0.8
54
+ self.lamda = lamda
55
+ self.cn_tag_black_list = set(['','x','u','j','k','zg','y','eng','uv','uj','ud','nr','nrfg','nrt','nw','nz','ns','nt','m','mq','r','w','PER','LOC','ORG'])#set(['','f','u','nr','nw','nz','m','r','p','c','w','PER','LOC','ORG'])
56
+ self.en_tag_white_list = set(['MD', 'NN', 'NNS', 'UH', 'VB', 'VBD', 'VBG', 'VBN', 'VBP', 'VBZ', 'RP', 'RB', 'RBR', 'RBS', 'JJ', 'JJR', 'JJS'])
57
+ if language == 'Chinese':
58
+ self.relatedness_tokenizer = AutoTokenizer.from_pretrained("IDEA-CCNL/Erlangshen-Roberta-330M-Similarity")
59
+ self.relatedness_model = AutoModelForSequenceClassification.from_pretrained("IDEA-CCNL/Erlangshen-Roberta-330M-Similarity").to(self.device)
60
+ self.tokenizer = WoBertTokenizer.from_pretrained("junnyu/wobert_chinese_plus_base")
61
+ self.model = WoBertForMaskedLM.from_pretrained("junnyu/wobert_chinese_plus_base", output_hidden_states=True).to(self.device)
62
+ self.w2v_model = gensim.models.KeyedVectors.load_word2vec_format('sgns.merge.word.bz2', binary=False, unicode_errors='ignore', limit=50000)
63
+ elif language == 'English':
64
+ self.tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
65
+ self.model = BertForMaskedLM.from_pretrained('bert-base-cased', output_hidden_states=True).to(self.device)
66
+ self.relatedness_model = RobertaForSequenceClassification.from_pretrained('roberta-large-mnli').to(self.device)
67
+ self.relatedness_tokenizer = RobertaTokenizer.from_pretrained('roberta-large-mnli')
68
+ self.w2v_model = api.load("glove-wiki-gigaword-100")
69
+ nltk.download('stopwords')
70
+ self.stop_words = set(stopwords.words('english'))
71
+ self.nlp = spacy.load('en_core_web_sm')
72
+
73
+ def cut(self,ori_text,text_len):
74
+ if self.language == 'Chinese':
75
+ if len(ori_text) > text_len+5:
76
+ ori_text = ori_text[:text_len+5]
77
+ if len(ori_text) < text_len-5:
78
+ return 'Short'
79
+ return ori_text
80
+ elif self.language == 'English':
81
+ tokens = self.tokenizer.tokenize(ori_text)
82
+ if len(tokens) > text_len+5:
83
+ ori_text = self.tokenizer.convert_tokens_to_string(tokens[:text_len+5])
84
+ if len(tokens) < text_len-5:
85
+ return 'Short'
86
+ return ori_text
87
+ else:
88
+ print(f'Unsupported Language:{self.language}')
89
+ raise NotImplementedError
90
+
91
+ def sent_tokenize(self,ori_text):
92
+ if self.language == 'Chinese':
93
+ return cut_sent(ori_text)
94
+ elif self.language == 'English':
95
+ return nltk.sent_tokenize(ori_text)
96
+
97
+ def pos_filter(self, tokens, masked_token_index, input_text):
98
+ if self.language == 'Chinese':
99
+ pairs = posseg.lcut(input_text)
100
+ pos_dict = {word: pos for word, pos in pairs}
101
+ pos_list_input = [pos for _, pos in pairs]
102
+ pos = pos_dict.get(tokens[masked_token_index], '')
103
+ if pos in self.cn_tag_black_list:
104
+ return False
105
+ else:
106
+ return True
107
+ elif self.language == 'English':
108
+ pos_tags = pos_tag(tokens)
109
+ pos = pos_tags[masked_token_index][1]
110
+ if pos not in self.en_tag_white_list:
111
+ return False
112
+ if is_subword(tokens[masked_token_index]) or is_subword(tokens[masked_token_index+1]) or (tokens[masked_token_index] in self.stop_words or tokens[masked_token_index] in string.punctuation):
113
+ return False
114
+ return True
115
+
116
+ def filter_special_candidate(self, top_n_tokens, tokens,masked_token_index,input_text):
117
+ if self.language == 'English':
118
+ filtered_tokens = [tok for tok in top_n_tokens if tok not in self.stop_words and tok not in string.punctuation and pos_tag([tok])[0][1] in self.en_tag_white_list and not is_subword(tok)]
119
+
120
+ base_word = tokens[masked_token_index]
121
+
122
+ processed_tokens = [tok for tok in filtered_tokens if not is_similar(tok,base_word)]
123
+ return processed_tokens
124
+ elif self.language == 'Chinese':
125
+ pairs = posseg.lcut(input_text)
126
+ pos_dict = {word: pos for word, pos in pairs}
127
+ pos_list_input = [pos for _, pos in pairs]
128
+ pos = pos_dict.get(tokens[masked_token_index], '')
129
+ filtered_tokens = []
130
+ for tok in top_n_tokens:
131
+ watermarked_text_segtest = self.tokenizer.convert_tokens_to_string(tokens[1:masked_token_index] + [tok] + tokens[masked_token_index+1:-1])
132
+ watermarked_text_segtest = re.sub(r'(?<=[\u4e00-\u9fff])\s+(?=[\u4e00-\u9fff,。?!、:])|(?<=[\u4e00-\u9fff,。?!、:])\s+(?=[\u4e00-\u9fff])', '', watermarked_text_segtest)
133
+ pairs_tok = posseg.lcut(watermarked_text_segtest)
134
+ pos_dict_tok = {word: pos for word, pos in pairs_tok}
135
+ flag = pos_dict_tok.get(tok, '')
136
+ if flag not in self.cn_tag_black_list and flag == pos:
137
+ filtered_tokens.append(tok)
138
+ processed_tokens = filtered_tokens
139
+ return processed_tokens
140
+
141
+ def global_word_sim(self,word,ori_word):
142
+ try:
143
+ global_score = self.w2v_model.similarity(word,ori_word)
144
+ except KeyError:
145
+ global_score = 0
146
+ return global_score
147
+
148
+ def context_word_sim(self, init_candidates_list, tokens, index_space, input_text):
149
+ original_input_tensor = self.tokenizer.encode(input_text, return_tensors='pt').to(self.device)
150
+
151
+ all_cos_sims = []
152
+
153
+ for init_candidates, masked_token_index in zip(init_candidates_list, index_space):
154
+ batch_input_ids = [
155
+ [self.tokenizer.convert_tokens_to_ids(['[CLS]'] + tokens[1:masked_token_index] + [token] + tokens[masked_token_index + 1:-1] + ['[SEP]'])] for token in
156
+ init_candidates]
157
+ batch_input_tensors = torch.tensor(batch_input_ids).squeeze(1).to(self.device)
158
+
159
+ batch_input_tensors = torch.cat((batch_input_tensors, original_input_tensor), dim=0)
160
+
161
+ with torch.no_grad():
162
+ outputs = self.model(batch_input_tensors)
163
+ cos_sims = torch.zeros([len(init_candidates)]).to(self.device)
164
+ num_layers = len(outputs[1])
165
+ N = 8
166
+ i = masked_token_index
167
+ # We want to calculate similarity for the last N layers
168
+ hidden_states = outputs[1][-N:]
169
+
170
+ # Shape of hidden_states: [N, batch_size, sequence_length, hidden_size]
171
+ hidden_states = torch.stack(hidden_states)
172
+
173
+ # Separate the source and candidate hidden states
174
+ source_hidden_states = hidden_states[:, len(init_candidates):, i, :]
175
+ candidate_hidden_states = hidden_states[:, :len(init_candidates), i, :]
176
+
177
+ # Calculate cosine similarities across all layers and sum
178
+ cos_sim_sum = F.cosine_similarity(source_hidden_states.unsqueeze(2), candidate_hidden_states.unsqueeze(1), dim=-1).sum(dim=0)
179
+
180
+ cos_sim_avg = cos_sim_sum / N
181
+ cos_sims += cos_sim_avg.squeeze()
182
+
183
+ all_cos_sims.append(cos_sims.tolist())
184
+
185
+ return all_cos_sims
186
+
187
+
188
+ def sentence_sim(self, init_candidates_list, tokens, index_space, input_text):
189
+
190
+ batch_size=128
191
+ all_batch_sentences = []
192
+ all_index_lengths = []
193
+ for init_candidates, masked_token_index in zip(init_candidates_list, index_space):
194
+ if self.language == 'Chinese':
195
+ batch_sents = [self.tokenizer.convert_tokens_to_string(tokens[1:masked_token_index] + [token] + tokens[masked_token_index + 1:-1]) for token in init_candidates]
196
+ batch_sentences = [re.sub(r'(?<=[\u4e00-\u9fff])\s+(?=[\u4e00-\u9fff,。?!、:])|(?<=[\u4e00-\u9fff,。?!、:])\s+(?=[\u4e00-\u9fff])', '', sent) for sent in batch_sents]
197
+ all_batch_sentences.extend([input_text + '[SEP]' + s for s in batch_sentences])
198
+ elif self.language == 'English':
199
+ batch_sentences = [self.tokenizer.convert_tokens_to_string(tokens[1:masked_token_index] + [token] + tokens[masked_token_index + 1:-1]) for token in init_candidates]
200
+ all_batch_sentences.extend([input_text + '</s></s>' + s for s in batch_sentences])
201
+
202
+ all_index_lengths.append(len(init_candidates))
203
+
204
+ all_relatedness_scores = []
205
+ start_index = 0
206
+ for i in range(0, len(all_batch_sentences), batch_size):
207
+ batch_sentences = all_batch_sentences[i: i + batch_size]
208
+ encoded_dict = self.relatedness_tokenizer.batch_encode_plus(
209
+ batch_sentences,
210
+ padding=True,
211
+ truncation=True,
212
+ max_length=512,
213
+ return_tensors='pt')
214
+
215
+ input_ids = encoded_dict['input_ids'].to(self.device)
216
+ attention_masks = encoded_dict['attention_mask'].to(self.device)
217
+
218
+ with torch.no_grad():
219
+ outputs = self.relatedness_model(input_ids=input_ids, attention_mask=attention_masks)
220
+ logits = outputs[0]
221
+ probs = torch.softmax(logits, dim=1)
222
+ if self.language == 'Chinese':
223
+ relatedness_scores = probs[:, 1]#.tolist()
224
+ elif self.language == 'English':
225
+ relatedness_scores = probs[:, 2]#.tolist()
226
+ all_relatedness_scores.extend(relatedness_scores)
227
+
228
+ all_relatedness_scores_split = []
229
+ for length in all_index_lengths:
230
+ all_relatedness_scores_split.append(all_relatedness_scores[start_index:start_index + length])
231
+ start_index += length
232
+
233
+
234
+ return all_relatedness_scores_split
235
+
236
+
237
+ def candidates_gen(self, tokens, index_space, input_text, topk=64, dropout_prob=0.3):
238
+ input_ids_bert = self.tokenizer.convert_tokens_to_ids(tokens)
239
+ new_index_space = []
240
+ masked_text = self.tokenizer.convert_tokens_to_string(tokens)
241
+ # Create a tensor of input IDs
242
+ input_tensor = torch.tensor([input_ids_bert]).to(self.device)
243
+
244
+ with torch.no_grad():
245
+ embeddings = self.model.bert.embeddings(input_tensor.repeat(len(index_space), 1))
246
+
247
+ dropout = nn.Dropout2d(p=dropout_prob)
248
+
249
+ masked_indices = torch.tensor(index_space).to(self.device)
250
+ embeddings[torch.arange(len(index_space)), masked_indices] = dropout(embeddings[torch.arange(len(index_space)), masked_indices])
251
+
252
+
253
+ with torch.no_grad():
254
+ outputs = self.model(inputs_embeds=embeddings)
255
+
256
+ all_processed_tokens = []
257
+ for i, masked_token_index in enumerate(index_space):
258
+ predicted_logits = outputs[0][i][masked_token_index]
259
+ # Set the number of top predictions to return
260
+ n = topk
261
+ # Get the top n predicted tokens and their probabilities
262
+ probs = torch.nn.functional.softmax(predicted_logits, dim=-1)
263
+ top_n_probs, top_n_indices = torch.topk(probs, n)
264
+ top_n_tokens = self.tokenizer.convert_ids_to_tokens(top_n_indices.tolist())
265
+ processed_tokens = self.filter_special_candidate(top_n_tokens, tokens, masked_token_index,input_text)
266
+
267
+ if tokens[masked_token_index] not in processed_tokens:
268
+ processed_tokens = [tokens[masked_token_index]] + processed_tokens
269
+ all_processed_tokens.append(processed_tokens)
270
+ new_index_space.append(masked_token_index)
271
+
272
+ return all_processed_tokens,new_index_space
273
+
274
+
275
+ def filter_candidates(self, init_candidates_list, tokens, index_space, input_text):
276
+
277
+ all_context_word_similarity_scores = self.context_word_sim(init_candidates_list, tokens, index_space, input_text)
278
+
279
+ all_sentence_similarity_scores = self.sentence_sim(init_candidates_list, tokens, index_space, input_text)
280
+
281
+ all_filtered_candidates = []
282
+ new_index_space = []
283
+
284
+ for init_candidates, context_word_similarity_scores, sentence_similarity_scores, masked_token_index in zip(init_candidates_list, all_context_word_similarity_scores, all_sentence_similarity_scores, index_space):
285
+ filtered_candidates = []
286
+ for idx, candidate in enumerate(init_candidates):
287
+ global_word_similarity_score = self.global_word_sim(tokens[masked_token_index], candidate)
288
+ word_similarity_score = self.lamda*context_word_similarity_scores[idx]+(1-self.lamda)*global_word_similarity_score
289
+ if word_similarity_score >= self.tau_word and sentence_similarity_scores[idx] >= self.tau_sent:
290
+ filtered_candidates.append((candidate, word_similarity_score))
291
+
292
+ if len(filtered_candidates) >= 1:
293
+ all_filtered_candidates.append(filtered_candidates)
294
+ new_index_space.append(masked_token_index)
295
+ return all_filtered_candidates, new_index_space
296
+
297
+ def get_candidate_encodings(self, tokens, enhanced_candidates, index_space):
298
+ best_candidates = []
299
+ new_index_space = []
300
+
301
+ for init_candidates, masked_token_index in zip(enhanced_candidates, index_space):
302
+ filtered_candidates = []
303
+
304
+ for idx, candidate in enumerate(init_candidates):
305
+ if masked_token_index-1 in new_index_space:
306
+ bit = binary_encoding_function(best_candidates[-1]+candidate[0])
307
+ else:
308
+ bit = binary_encoding_function(tokens[masked_token_index-1]+candidate[0])
309
+
310
+ if bit==1:
311
+ filtered_candidates.append(candidate)
312
+
313
+ # Sort the candidates based on their scores
314
+ filtered_candidates = sorted(filtered_candidates, key=lambda x: x[1], reverse=True)
315
+
316
+ if len(filtered_candidates) >= 1:
317
+ best_candidates.append(filtered_candidates[0][0])
318
+ new_index_space.append(masked_token_index)
319
+
320
+ return best_candidates, new_index_space
321
+
322
+ def watermark_embed(self,text):
323
+ input_text = text
324
+ # Tokenize the input text
325
+ tokens = self.tokenizer.tokenize(input_text)
326
+ tokens = ['[CLS]'] + tokens + ['[SEP]']
327
+ masked_tokens=tokens.copy()
328
+ start_index = 1
329
+ end_index = len(tokens) - 1
330
+
331
+ index_space = []
332
+
333
+ for masked_token_index in range(start_index+1, end_index-1):
334
+ binary_encoding = binary_encoding_function(tokens[masked_token_index - 1] + tokens[masked_token_index])
335
+ if binary_encoding == 1 and masked_token_index-1 not in index_space:
336
+ continue
337
+ if not self.pos_filter(tokens,masked_token_index,input_text):
338
+ continue
339
+ index_space.append(masked_token_index)
340
+
341
+ if len(index_space)==0:
342
+ return text
343
+ init_candidates, new_index_space = self.candidates_gen(tokens,index_space,input_text, 8, 0)
344
+ if len(new_index_space)==0:
345
+ return text
346
+ enhanced_candidates, new_index_space = self.filter_candidates(init_candidates,tokens,new_index_space,input_text)
347
+
348
+ enhanced_candidates, new_index_space = self.get_candidate_encodings(tokens, enhanced_candidates, new_index_space)
349
+
350
+ for init_candidate, masked_token_index in zip(enhanced_candidates, new_index_space):
351
+ tokens[masked_token_index] = init_candidate
352
+ watermarked_text = self.tokenizer.convert_tokens_to_string(tokens[1:-1])
353
+
354
+ if self.language == 'Chinese':
355
+ watermarked_text = re.sub(r'(?<=[\u4e00-\u9fff])\s+(?=[\u4e00-\u9fff,。?!、:])|(?<=[\u4e00-\u9fff,。?!、:])\s+(?=[\u4e00-\u9fff])', '', watermarked_text)
356
+ return watermarked_text
357
+
358
+ def embed(self, ori_text):
359
+ sents = self.sent_tokenize(ori_text)
360
+ sents = [s for s in sents if s.strip()]
361
+ num_sents = len(sents)
362
+ watermarked_text = ''
363
+
364
+ for i in range(0, num_sents, 2):
365
+ if i+1 < num_sents:
366
+ sent_pair = sents[i] + sents[i+1]
367
+ else:
368
+ sent_pair = sents[i]
369
+ # keywords = jieba.analyse.extract_tags(sent_pair, topK=5, withWeight=False)
370
+ if len(watermarked_text) == 0:
371
+ watermarked_text = self.watermark_embed(sent_pair)
372
+ else:
373
+ watermarked_text = watermarked_text + self.watermark_embed(sent_pair)
374
+ if len(self.get_encodings_fast(ori_text)) == 0:
375
+ # print(ori_text)
376
+ return ''
377
+ return watermarked_text
378
+
379
+ def get_encodings_fast(self,text):
380
+ sents = self.sent_tokenize(text)
381
+ sents = [s for s in sents if s.strip()]
382
+ num_sents = len(sents)
383
+ encodings = []
384
+ for i in range(0, num_sents, 2):
385
+ if i+1 < num_sents:
386
+ sent_pair = sents[i] + sents[i+1]
387
+ else:
388
+ sent_pair = sents[i]
389
+ tokens = self.tokenizer.tokenize(sent_pair)
390
+
391
+ for index in range(1,len(tokens)-1):
392
+ if not self.pos_filter(tokens,index,text):
393
+ continue
394
+ bit = binary_encoding_function(tokens[index-1]+tokens[index])
395
+ encodings.append(bit)
396
+ return encodings
397
+
398
+ def watermark_detector_fast(self, text,alpha=0.05):
399
+ p = 0.5
400
+ encodings = self.get_encodings_fast(text)
401
+ n = len(encodings)
402
+ ones = sum(encodings)
403
+ if n == 0:
404
+ z = 0
405
+ else:
406
+ z = (ones - p * n) / (n * p * (1 - p)) ** 0.5
407
+ threshold = norm.ppf(1 - alpha, loc=0, scale=1)
408
+ p_value = norm.sf(z)
409
+ # p_value = norm.sf(abs(z)) * 2
410
+ is_watermark = z >= threshold
411
+ return is_watermark, p_value, n, ones, z
412
+
413
+ def get_encodings_precise(self, text):
414
+ # pdb.set_trace()
415
+ sents = self.sent_tokenize(text)
416
+ sents = [s for s in sents if s.strip()]
417
+ num_sents = len(sents)
418
+ encodings = []
419
+ for i in range(0, num_sents, 2):
420
+ if i+1 < num_sents:
421
+ sent_pair = sents[i] + sents[i+1]
422
+ else:
423
+ sent_pair = sents[i]
424
+
425
+ tokens = self.tokenizer.tokenize(sent_pair)
426
+
427
+ tokens = ['[CLS]'] + tokens + ['[SEP]']
428
+
429
+ masked_tokens=tokens.copy()
430
+
431
+ start_index = 1
432
+ end_index = len(tokens) - 1
433
+
434
+ index_space = []
435
+ for masked_token_index in range(start_index+1, end_index-1):
436
+ if not self.pos_filter(tokens,masked_token_index,sent_pair):
437
+ continue
438
+ index_space.append(masked_token_index)
439
+ if len(index_space)==0:
440
+ continue
441
+
442
+ init_candidates, new_index_space = self.candidates_gen(tokens,index_space,sent_pair, 8, 0)
443
+ enhanced_candidates, new_index_space = self.filter_candidates(init_candidates,tokens,new_index_space,sent_pair)
444
+
445
+ # pdb.set_trace()
446
+ for j,idx in enumerate(new_index_space):
447
+ if len(enhanced_candidates[j])>1:
448
+ bit = binary_encoding_function(tokens[idx-1]+tokens[idx])
449
+ encodings.append(bit)
450
+ return encodings
451
+
452
+
453
+ def watermark_detector_precise(self,text,alpha=0.05):
454
+ p = 0.5
455
+ encodings = self.get_encodings_precise(text)
456
+ n = len(encodings)
457
+ ones = sum(encodings)
458
+ if n == 0:
459
+ z = 0
460
+ else:
461
+ z = (ones - p * n) / (n * p * (1 - p)) ** 0.5
462
+ threshold = norm.ppf(1 - alpha, loc=0, scale=1)
463
+ p_value = norm.sf(z)
464
+ is_watermark = z >= threshold
465
+ return is_watermark, p_value, n, ones, z
models/watermark_original.py ADDED
@@ -0,0 +1,368 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import nltk
2
+ from nltk.corpus import stopwords
3
+ from nltk import word_tokenize, pos_tag
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from torch import nn
7
+ import hashlib
8
+ from scipy.stats import norm
9
+ import gensim
10
+ import pdb
11
+ from transformers import BertForMaskedLM as WoBertForMaskedLM
12
+ from wobert import WoBertTokenizer
13
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
14
+
15
+ from transformers import BertForMaskedLM, BertTokenizer, RobertaForSequenceClassification, RobertaTokenizer
16
+ import gensim.downloader as api
17
+ import Levenshtein
18
+ import string
19
+ import spacy
20
+ import paddle
21
+ from jieba import posseg
22
+
23
+ paddle.enable_static()
24
+ import re
25
+ def cut_sent(para):
26
+ para = re.sub('([。!?\?])([^”’])', r'\1\n\2', para)
27
+ para = re.sub('([。!?\?][”’])([^,。!?\?\n ])', r'\1\n\2', para)
28
+ para = re.sub('(\.{6}|\…{2})([^”’\n])', r'\1\n\2', para)
29
+ para = re.sub('([^。!?\?]*)([::][^。!?\?\n]*)', r'\1\n\2', para)
30
+ para = re.sub('([。!?\?][”’])$', r'\1\n', para)
31
+ para = para.rstrip()
32
+ return para.split("\n")
33
+
34
+ def is_subword(token: str):
35
+ return token.startswith('##')
36
+
37
+ def binary_encoding_function(token):
38
+ hash_value = int(hashlib.sha256(token.encode('utf-8')).hexdigest(), 16)
39
+ random_bit = hash_value % 2
40
+ return random_bit
41
+
42
+ def is_similar(x, y, threshold=0.5):
43
+ distance = Levenshtein.distance(x, y)
44
+ if distance / max(len(x), len(y)) < threshold:
45
+ return True
46
+ return False
47
+
48
+ class watermark_model:
49
+ def __init__(self, language, mode, tau_word, lamda):
50
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
51
+ self.language = language
52
+ self.mode = mode
53
+ self.tau_word = tau_word
54
+ self.tau_sent = 0.8
55
+ self.lamda = lamda
56
+ self.cn_tag_black_list = set(['','x','u','j','k','zg','y','eng','uv','uj','ud','nr','nrfg','nrt','nw','nz','ns','nt','m','mq','r','w','PER','LOC','ORG'])#set(['','f','u','nr','nw','nz','m','r','p','c','w','PER','LOC','ORG'])
57
+ self.en_tag_white_list = set(['MD', 'NN', 'NNS', 'UH', 'VB', 'VBD', 'VBG', 'VBN', 'VBP', 'VBZ', 'RP', 'RB', 'RBR', 'RBS', 'JJ', 'JJR', 'JJS'])
58
+ if language == 'Chinese':
59
+ self.relatedness_tokenizer = AutoTokenizer.from_pretrained("IDEA-CCNL/Erlangshen-Roberta-330M-Similarity")
60
+ self.relatedness_model = AutoModelForSequenceClassification.from_pretrained("IDEA-CCNL/Erlangshen-Roberta-330M-Similarity").to(self.device)
61
+ self.tokenizer = WoBertTokenizer.from_pretrained("junnyu/wobert_chinese_plus_base")
62
+ self.model = WoBertForMaskedLM.from_pretrained("junnyu/wobert_chinese_plus_base", output_hidden_states=True).to(self.device)
63
+ self.w2v_model = gensim.models.KeyedVectors.load_word2vec_format('sgns.merge.word.bz2', binary=False, unicode_errors='ignore', limit=50000)
64
+ elif language == 'English':
65
+ self.tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
66
+ self.model = BertForMaskedLM.from_pretrained('bert-base-cased', output_hidden_states=True).to(self.device)
67
+ self.relatedness_model = RobertaForSequenceClassification.from_pretrained('roberta-large-mnli').to(self.device)
68
+ self.relatedness_tokenizer = RobertaTokenizer.from_pretrained('roberta-large-mnli')
69
+ self.w2v_model = api.load("glove-wiki-gigaword-100")
70
+ nltk.download('stopwords')
71
+ self.stop_words = set(stopwords.words('english'))
72
+ self.nlp = spacy.load('en_core_web_sm')
73
+
74
+ def cut(self,ori_text,text_len):
75
+ if self.language == 'Chinese':
76
+ if len(ori_text) > text_len+5:
77
+ ori_text = ori_text[:text_len+5]
78
+ if len(ori_text) < text_len-5:
79
+ return 'Short'
80
+ elif self.language == 'English':
81
+ tokens = self.tokenizer.tokenize(ori_text)
82
+ if len(tokens) > text_len+5:
83
+ ori_text = self.tokenizer.convert_tokens_to_string(tokens[:text_len+5])
84
+ if len(tokens) < text_len-5:
85
+ return 'Short'
86
+ return ori_text
87
+ else:
88
+ print(f'Unsupported Language:{self.language}')
89
+ raise NotImplementedError
90
+
91
+ def sent_tokenize(self,ori_text):
92
+ if self.language == 'Chinese':
93
+ return cut_sent(ori_text)
94
+ elif self.language == 'English':
95
+ return nltk.sent_tokenize(ori_text)
96
+
97
+ def pos_filter(self, tokens, masked_token_index, input_text):
98
+ if self.language == 'Chinese':
99
+ pairs = posseg.lcut(input_text)
100
+ pos_dict = {word: pos for word, pos in pairs}
101
+ pos_list_input = [pos for _, pos in pairs]
102
+ pos = pos_dict.get(tokens[masked_token_index], '')
103
+ if pos in self.cn_tag_black_list:
104
+ return False
105
+ else:
106
+ return True
107
+ elif self.language == 'English':
108
+ pos_tags = pos_tag(tokens)
109
+ pos = pos_tags[masked_token_index][1]
110
+ if pos not in self.en_tag_white_list:
111
+ return False
112
+ if is_subword(tokens[masked_token_index]) or is_subword(tokens[masked_token_index+1]) or (tokens[masked_token_index] in self.stop_words or tokens[masked_token_index] in string.punctuation):
113
+ return False
114
+ return True
115
+
116
+ def filter_special_candidate(self, top_n_tokens, tokens,masked_token_index,input_text):
117
+ if self.language == 'English':
118
+ filtered_tokens = [tok for tok in top_n_tokens if tok not in self.stop_words and tok not in string.punctuation and pos_tag([tok])[0][1] in self.en_tag_white_list and not is_subword(tok)]
119
+
120
+ lemmatized_tokens = []
121
+ # for token in filtered_tokens:
122
+ # doc = self.nlp(token)
123
+ # lemma = doc[0].lemma_ if doc[0].lemma_ != "-PRON-" else token
124
+ # lemmatized_tokens.append(lemma)
125
+
126
+ base_word = tokens[masked_token_index]
127
+ base_word_lemma = self.nlp(base_word)[0].lemma_
128
+ processed_tokens = [base_word]+[tok for tok in filtered_tokens if self.nlp(tok)[0].lemma_ != base_word_lemma]
129
+ return processed_tokens
130
+ elif self.language == 'Chinese':
131
+ pairs = posseg.lcut(input_text)
132
+ pos_dict = {word: pos for word, pos in pairs}
133
+ pos_list_input = [pos for _, pos in pairs]
134
+ pos = pos_dict.get(tokens[masked_token_index], '')
135
+ filtered_tokens = []
136
+ for tok in top_n_tokens:
137
+ watermarked_text_segtest = self.tokenizer.convert_tokens_to_string(tokens[1:masked_token_index] + [tok] + tokens[masked_token_index+1:-1])
138
+ watermarked_text_segtest = re.sub(r'(?<=[\u4e00-\u9fff])\s+(?=[\u4e00-\u9fff,。?!、:])|(?<=[\u4e00-\u9fff,。?!、:])\s+(?=[\u4e00-\u9fff])', '', watermarked_text_segtest)
139
+ pairs_tok = posseg.lcut(watermarked_text_segtest)
140
+ pos_dict_tok = {word: pos for word, pos in pairs_tok}
141
+ flag = pos_dict_tok.get(tok, '')
142
+ if flag not in self.cn_tag_black_list and flag == pos:
143
+ filtered_tokens.append(tok)
144
+ processed_tokens = filtered_tokens
145
+ return processed_tokens
146
+
147
+ def global_word_sim(self,word,ori_word):
148
+ try:
149
+ global_score = self.w2v_model.similarity(word,ori_word)
150
+ except KeyError:
151
+ global_score = 0
152
+ return global_score
153
+
154
+ def context_word_sim(self,init_candidates, tokens, masked_token_index, input_text):
155
+ original_input_tensor = self.tokenizer.encode(input_text,return_tensors='pt').to(self.device)
156
+ batch_input_ids = [[self.tokenizer.convert_tokens_to_ids(['[CLS]'] + tokens[1:masked_token_index] + [token] + tokens[masked_token_index+1:-1]+ ['[SEP]'])] for token in init_candidates]
157
+ batch_input_tensors = torch.tensor(batch_input_ids).squeeze().to(self.device)
158
+ batch_input_tensors = torch.cat((batch_input_tensors,original_input_tensor),dim=0)
159
+ with torch.no_grad():
160
+ outputs = self.model(batch_input_tensors)
161
+ cos_sims = torch.zeros([len(init_candidates)]).to(self.device)
162
+ num_layers = len(outputs[1])
163
+ N = 8
164
+ i = masked_token_index
165
+ cos_sim_sum = 0
166
+ for layer in range(num_layers-N,num_layers):
167
+ ls_hidden_states = outputs[1][layer][0:len(init_candidates), i, :]
168
+ source_hidden_state = outputs[1][layer][len(init_candidates), i, :]
169
+ cos_sim_sum += F.cosine_similarity(source_hidden_state, ls_hidden_states, dim=1)
170
+ cos_sim_avg = cos_sim_sum / N
171
+
172
+ cos_sims += cos_sim_avg
173
+ return cos_sims.tolist()
174
+
175
+ def sentence_sim(self,init_candidates, tokens, masked_token_index, input_text):
176
+ if self.language == 'Chinese':
177
+ batch_sents = [self.tokenizer.convert_tokens_to_string(tokens[1:masked_token_index] + [token] + tokens[masked_token_index+1:-1]) for token in init_candidates]
178
+ batch_sentences = [re.sub(r'(?<=[\u4e00-\u9fff])\s+(?=[\u4e00-\u9fff,。?!、:])|(?<=[\u4e00-\u9fff,。?!、:])\s+(?=[\u4e00-\u9fff])', '', sent) for sent in batch_sents]
179
+ roberta_inputs = [input_text + '[SEP]' + s for s in batch_sentences]
180
+ elif self.language == 'English':
181
+ batch_sentences = [self.tokenizer.convert_tokens_to_string(tokens[1:masked_token_index] + [token] + tokens[masked_token_index+1:-1]) for token in init_candidates]
182
+ roberta_inputs = [input_text + '</s></s>' + s for s in batch_sentences]
183
+
184
+ encoded_dict = self.relatedness_tokenizer.batch_encode_plus(
185
+ roberta_inputs,
186
+ padding=True,
187
+ truncation=True,
188
+ max_length=512,
189
+ return_tensors='pt')
190
+ # Extract input_ids and attention_masks
191
+ input_ids = encoded_dict['input_ids'].to(self.device)
192
+ attention_masks = encoded_dict['attention_mask'].to(self.device)
193
+ with torch.no_grad():
194
+ outputs = self.relatedness_model(input_ids=input_ids, attention_mask=attention_masks)
195
+ logits = outputs[0]
196
+ probs = torch.softmax(logits, dim=1)
197
+ if self.language == 'Chinese':
198
+ relatedness_scores = probs[:, 1].tolist()
199
+ elif self.language == 'English':
200
+ relatedness_scores = probs[:, 2].tolist()
201
+
202
+ return relatedness_scores
203
+
204
+ def candidates_gen(self,tokens,masked_token_index,input_text,topk=64, dropout_prob=0.3):
205
+ input_ids_bert = self.tokenizer.convert_tokens_to_ids(tokens)
206
+ if not self.pos_filter(tokens,masked_token_index,input_text):
207
+ return []
208
+ masked_text = self.tokenizer.convert_tokens_to_string(tokens)
209
+ # Create a tensor of input IDs
210
+ input_tensor = torch.tensor([input_ids_bert]).to(self.device)
211
+
212
+ with torch.no_grad():
213
+ embeddings = self.model.bert.embeddings(input_tensor)
214
+ dropout = nn.Dropout2d(p=dropout_prob)
215
+ # Get the predicted logits
216
+ embeddings[:, masked_token_index, :] = dropout(embeddings[:, masked_token_index, :])
217
+ with torch.no_grad():
218
+ outputs = self.model(inputs_embeds=embeddings)
219
+
220
+ predicted_logits = outputs[0][0][masked_token_index]
221
+
222
+ # Set the number of top predictions to return
223
+ n = topk
224
+ # Get the top n predicted tokens and their probabilities
225
+ probs = torch.nn.functional.softmax(predicted_logits, dim=-1)
226
+ top_n_probs, top_n_indices = torch.topk(probs, n)
227
+ top_n_tokens = self.tokenizer.convert_ids_to_tokens(top_n_indices.tolist())
228
+ processed_tokens = self.filter_special_candidate(top_n_tokens,tokens,masked_token_index)
229
+
230
+ return processed_tokens
231
+
232
+ def filter_candidates(self, init_candidates, tokens, masked_token_index, input_text):
233
+ context_word_similarity_scores = self.context_word_sim(init_candidates, tokens, masked_token_index, input_text)
234
+ sentence_similarity_scores = self.sentence_sim(init_candidates, tokens, masked_token_index, input_text)
235
+ filtered_candidates = []
236
+ for idx, candidate in enumerate(init_candidates):
237
+ global_word_similarity_score = self.global_word_sim(tokens[masked_token_index], candidate)
238
+ word_similarity_score = self.lamda*context_word_similarity_scores[idx]+(1-self.lamda)*global_word_similarity_score
239
+ if word_similarity_score >= self.tau_word and sentence_similarity_scores[idx] >= self.tau_sent:
240
+ filtered_candidates.append((candidate, word_similarity_score))#, sentence_similarity_scores[idx]))
241
+ return filtered_candidates
242
+
243
+ def watermark_embed(self,text):
244
+ input_text = text
245
+ # Tokenize the input text
246
+ tokens = self.tokenizer.tokenize(input_text)
247
+ tokens = ['[CLS]'] + tokens + ['[SEP]']
248
+ masked_tokens=tokens.copy()
249
+ start_index = 1
250
+ end_index = len(tokens) - 1
251
+ for masked_token_index in range(start_index+1, end_index-1):
252
+ # pdb.set_trace()
253
+ binary_encoding = binary_encoding_function(tokens[masked_token_index - 1] + tokens[masked_token_index])
254
+ if binary_encoding == 1:
255
+ continue
256
+ init_candidates = self.candidates_gen(tokens,masked_token_index,input_text, 32, 0.3)
257
+ if len(init_candidates) <=1:
258
+ continue
259
+ enhanced_candidates = self.filter_candidates(init_candidates,tokens,masked_token_index,input_text)
260
+ hash_top_tokens = enhanced_candidates.copy()
261
+ for i, tok in enumerate(enhanced_candidates):
262
+ binary_encoding = binary_encoding_function(tokens[masked_token_index - 1] + tok[0])
263
+ if binary_encoding != 1 or (is_similar(tok[0], tokens[masked_token_index])) or (tokens[masked_token_index - 1] in tok or tokens[masked_token_index + 1] in tok):
264
+ hash_top_tokens.remove(tok)
265
+ hash_top_tokens.sort(key=lambda x: x[1], reverse=True)
266
+ if len(hash_top_tokens) > 0:
267
+ selected_token = hash_top_tokens[0][0]
268
+ else:
269
+ selected_token = tokens[masked_token_index]
270
+
271
+ tokens[masked_token_index] = selected_token
272
+ watermarked_text = self.tokenizer.convert_tokens_to_string(tokens[1:-1])
273
+ if self.language == 'Chinese':
274
+ watermarked_text = re.sub(r'(?<=[\u4e00-\u9fff])\s+(?=[\u4e00-\u9fff,。?!、:])|(?<=[\u4e00-\u9fff,。?!、:])\s+(?=[\u4e00-\u9fff])', '', watermarked_text)
275
+
276
+ return watermarked_text
277
+
278
+ def embed(self, ori_text):
279
+ sents = self.sent_tokenize(ori_text)
280
+ sents = [s for s in sents if s.strip()]
281
+ num_sents = len(sents)
282
+ watermarked_text = ''
283
+ for i in range(0, num_sents, 2):
284
+ if i+1 < num_sents:
285
+ sent_pair = sents[i] + sents[i+1]
286
+ else:
287
+ sent_pair = sents[i]
288
+ if len(watermarked_text) == 0:
289
+ watermarked_text = self.watermark_embed(sent_pair)
290
+ else:
291
+ watermarked_text = watermarked_text + self.watermark_embed(sent_pair)
292
+ if len(self.get_encodings_fast(ori_text)) == 0:
293
+ return ''
294
+ return watermarked_text
295
+
296
+ def get_encodings_fast(self,text):
297
+ sents = self.sent_tokenize(text)
298
+ sents = [s for s in sents if s.strip()]
299
+ num_sents = len(sents)
300
+ encodings = []
301
+ for i in range(0, num_sents, 2):
302
+ if i+1 < num_sents:
303
+ sent_pair = sents[i] + sents[i+1]
304
+ else:
305
+ sent_pair = sents[i]
306
+ tokens = self.tokenizer.tokenize(sent_pair)
307
+
308
+ for index in range(1,len(tokens)-1):
309
+ if not self.pos_filter(tokens,index,text):
310
+ continue
311
+ bit = binary_encoding_function(tokens[index-1]+tokens[index])
312
+ encodings.append(bit)
313
+ return encodings
314
+
315
+ def watermark_detector_fast(self, text,alpha=0.05):
316
+ p = 0.5
317
+ encodings = self.get_encodings_fast(text)
318
+ n = len(encodings)
319
+ ones = sum(encodings)
320
+ z = (ones - p * n) / (n * p * (1 - p)) ** 0.5
321
+ threshold = norm.ppf(1 - alpha, loc=0, scale=1)
322
+ p_value = norm.sf(z)
323
+ is_watermark = z >= threshold
324
+ return is_watermark, p_value, n, ones, z
325
+
326
+ def get_encodings_precise(self, text):
327
+ sents = self.sent_tokenize(text)
328
+ sents = [s for s in sents if s.strip()]
329
+ num_sents = len(sents)
330
+ encodings = []
331
+ for i in range(0, num_sents, 2):
332
+ if i+1 < num_sents:
333
+ sent_pair = sents[i] + sents[i+1]
334
+ else:
335
+ sent_pair = sents[i]
336
+
337
+ tokens = self.tokenizer.tokenize(sent_pair)
338
+
339
+ tokens = ['[CLS]'] + tokens + ['[SEP]']
340
+
341
+ masked_tokens=tokens.copy()
342
+
343
+ start_index = 1
344
+ end_index = len(tokens) - 1
345
+
346
+ for masked_token_index in range(start_index+1, end_index-1):
347
+ init_candidates = self.candidates_gen(tokens,masked_token_index,sent_pair, 8, 0)
348
+ if len(init_candidates) <=1:
349
+ continue
350
+ enhanced_candidates = self.filter_candidates(init_candidates,tokens,masked_token_index,sent_pair)
351
+ if len(enhanced_candidates) > 1:
352
+ bit = binary_encoding_function(tokens[masked_token_index-1]+tokens[masked_token_index])
353
+ encodings.append(bit)
354
+ return encodings
355
+
356
+ def watermark_detector_precise(self,text,alpha=0.05):
357
+ p = 0.5
358
+ encodings = self.get_encodings_precise(text)
359
+ n = len(encodings)
360
+ ones = sum(encodings)
361
+ if n == 0:
362
+ z = 0
363
+ else:
364
+ z = (ones - p * n) / (n * p * (1 - p)) ** 0.5
365
+ threshold = norm.ppf(1 - alpha, loc=0, scale=1)
366
+ p_value = norm.sf(z)
367
+ is_watermark = z >= threshold
368
+ return is_watermark, p_value, n, ones, z
options.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ # TODO: add help for the parameters
3
+
4
+ def get_parser_main_model():
5
+ parser = argparse.ArgumentParser()
6
+ # TODO: basic parameters training related
7
+
8
+ # for embed
9
+ parser.add_argument('--language', type=str, default='English', help='text language')
10
+ parser.add_argument('--mode', type=str, choices=['embed', 'fast_detect', 'precise_detect'], default='embed', help='Mode options: embed (default), fast_detect, precise_detect')
11
+ parser.add_argument('--tau_word', type=float, default=0.8, help='word-level similarity thresh')
12
+ parser.add_argument('--lamda', type=float, default=0.83, help='word-level similarity weight')
13
+
14
+ return parser
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gensim==4.3.0
2
+ gradio==3.30.0
3
+ jieba==0.42.1
4
+ nltk==3.8.1
5
+ paddle==1.0.2
6
+ paddlepaddle==2.4.2
7
+ python_Levenshtein==0.21.0
8
+ scipy==1.7.3
9
+ spacy==3.5.0
10
+ torch==1.11.0
11
+ transformers==4.26.1
12
+ wobert==0.0.1