File size: 2,980 Bytes
ea6afa4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import nltk
from nltk.corpus import stopwords
from transformers import AutoTokenizer, AutoModelForMaskedLM
from vocabulary_split import split_vocabulary, filter_logits
import torch
from lcs import find_common_subsequences
from paraphraser import generate_paraphrase

# Fetch the NLTK data this module relies on: the Punkt models behind
# nltk.word_tokenize and the English stop-word corpus. NLTK 3.9+ loads
# 'punkt_tab' (not 'punkt') inside word_tokenize, so request both;
# nltk.download returns False instead of raising for an unknown id.
nltk.download('punkt', quiet=True)
nltk.download('punkt_tab', quiet=True)
nltk.download('stopwords', quiet=True)

# Masked language model used to score candidate words at a mask position.
tokenizer = AutoTokenizer.from_pretrained("bert-large-cased-whole-word-masking")
model = AutoModelForMaskedLM.from_pretrained("bert-large-cased-whole-word-masking")

# Vocabulary split; presumably seed=42 must match the seed used when the
# watermark was embedded — TODO confirm against vocabulary_split.
permissible, _ = split_vocabulary(seed=42)
# Build the set of permissible ids once. Testing `i in permissible.values()`
# scans the whole values view for every id, which made building the mask
# quadratic in the vocabulary size. Assumes the values are hashable token ids
# (ints) — NOTE(review): confirm against split_vocabulary.
permissible_ids = set(permissible.values())
# Boolean mask over every tokenizer id: True where the id is permissible.
permissible_indices = torch.tensor([i in permissible_ids for i in range(len(tokenizer))])

def get_non_melting_points(original_sentence):
    """Return the subsequences shared by a sentence and its paraphrase(s).

    The sentence is paraphrased with ``generate_paraphrase`` and the result is
    compared with the original by ``find_common_subsequences``. Whatever that
    helper returns is handed back unchanged.
    """
    return find_common_subsequences(
        original_sentence,
        generate_paraphrase(original_sentence),
    )

def _find_token_run(tokens, run, start=0):
    """Return the index of the first occurrence of ``run`` in ``tokens`` at or after ``start``, or -1."""
    width = len(run)
    if width == 0:
        return -1
    for pos in range(start, len(tokens) - width + 1):
        if tokens[pos:pos + width] == run:
            return pos
    return -1


def get_word_between_points(sentence, start_point, end_point):
    """Return the first non-stop-word token lying between two non-melting points.

    The text of each point is read from ``point[1]`` and must be a string.
    The sentence and both point texts are word-tokenized with NLTK. The start
    point's tokens are located in the sentence, then the end point's tokens
    are searched for after them. The tokens strictly between the two runs are
    scanned for the first one whose lower-cased form is not an English stop
    word.

    Args:
        sentence: The sentence being inspected.
        start_point: Non-melting point whose ``[1]`` is the text of the left boundary.
        end_point: Non-melting point whose ``[1]`` is the text of the right boundary.

    Returns:
        ``(word, index)``, where ``index`` is the word's position in
        ``nltk.word_tokenize(sentence)``. Returns ``(None, None)`` if either
        point cannot be located as a token run in that order, or if no
        qualifying word lies between them.
    """
    words = nltk.word_tokenize(sentence)
    stop_words = set(stopwords.words('english'))

    # Work in token positions throughout. The previous version took character
    # offsets from str.index() and used them to slice the token list, and
    # reported words.index(word), i.e. the FIRST occurrence of a repeated word
    # rather than the one actually found between the points.
    start_tokens = nltk.word_tokenize(start_point[1])
    end_tokens = nltk.word_tokenize(end_point[1])

    start_pos = _find_token_run(words, start_tokens)
    if start_pos < 0:
        return None, None
    begin = start_pos + len(start_tokens)

    end = _find_token_run(words, end_tokens, begin)
    if end < 0:
        return None, None

    for index, word in enumerate(words[begin:end], start=begin):
        if word.lower() not in stop_words:
            return word, index
    return None, None

def get_logits_for_mask(sentence):
    """Run the masked LM on ``sentence`` and return the logits at its mask token(s).

    The sentence is encoded with the module-level ``tokenizer``. Positions
    whose id equals ``tokenizer.mask_token_id`` are located in the first (only)
    sequence of the batch. The model is then run without gradient tracking.
    The result is ``logits[0, mask_positions, :]`` with singleton dimensions
    squeezed away. With exactly one mask this is a single vector of per-token
    scores.
    """
    encoded = tokenizer(sentence, return_tensors="pt")
    is_mask = encoded["input_ids"] == tokenizer.mask_token_id
    mask_positions = torch.where(is_mask)[1]

    with torch.no_grad():
        scores = model(**encoded).logits

    return scores[0, mask_positions, :].squeeze()

def detect_watermark(sentence, *, top_k=5):
    """Check whether a sentence carries the permissible-vocabulary watermark.

    Procedure:
      1. Find non-melting points of ``sentence`` via ``get_non_melting_points``.
      2. Pick the first non-stop-word between the first two points.
      3. Replace that word with the tokenizer's mask token and score the mask
         position with the masked LM. The scores are passed through
         ``filter_logits`` with the permissible-id mask.
      4. Report a watermark if the original word is among the decoded
         ``top_k`` highest-scoring filtered predictions.

    Args:
        sentence: The sentence to test.
        top_k: Number of highest-scoring filtered predictions the original
            word must appear among (keyword-only). The default of 5 is the
            previously hard-coded value.

    Returns:
        A ``(is_watermarked, message)`` tuple.

    Raises:
        ValueError: If ``top_k`` is less than 1. A slice ``[-0:]`` would
            otherwise silently select every token id.
    """
    if top_k < 1:
        raise ValueError(f"top_k must be at least 1, got {top_k}")

    non_melting_points = get_non_melting_points(sentence)

    if len(non_melting_points) < 2:
        return False, "Not enough non-melting points found."

    word_to_check, index = get_word_between_points(sentence, non_melting_points[0], non_melting_points[1])

    if word_to_check is None:
        return False, "No suitable word found between non-melting points."

    # Tokenize the same way as get_word_between_points so `index` refers to
    # the same token list.
    words = nltk.word_tokenize(sentence)
    # Use the tokenizer's own mask token ('[MASK]' for this BERT model), the
    # same token get_logits_for_mask looks for via mask_token_id.
    masked_sentence = ' '.join(words[:index] + [tokenizer.mask_token] + words[index + 1:])

    logits = get_logits_for_mask(masked_sentence)
    filtered_logits = filter_logits(logits, permissible_indices)

    # Ascending argsort, so the last `top_k` ids are the highest scores.
    top_predictions = filtered_logits.argsort()[-top_k:]
    # Convert ids to plain ints before decoding each one individually.
    predicted_words = [tokenizer.decode([token_id]) for token_id in top_predictions.tolist()]

    # NOTE(review): the negative message says "not in the permissible
    # vocabulary", but the test is top-k membership. A permissible word that
    # merely ranks below top_k also lands here.
    if word_to_check in predicted_words:
        return True, f"Watermark detected. The word '{word_to_check}' is in the permissible vocabulary."
    return False, f"No watermark detected. The word '{word_to_check}' is not in the permissible vocabulary."

# Example usage
# if __name__ == "__main__":
#     test_sentence = "The quick brown fox jumps over the lazy dog."
#     is_watermarked, message = detect_watermark(test_sentence)
#     print(f"Is the sentence watermarked? {is_watermarked}")
#     print(f"Detection message: {message}")