Spaces:
Runtime error
Runtime error
BheemaShankerNeyigapula
commited on
Commit
•
ea6afa4
1
Parent(s):
00dafdf
Upload folder using huggingface_hub
Browse files- .gitignore +2 -0
- .gradio/certificate.pem +31 -0
- README.md +4 -8
- app.py +169 -0
- detectability.py +303 -0
- distortion.py +126 -0
- entailment.py +33 -0
- euclidean_distance.py +74 -0
- gpt_mask_filling.py +70 -0
- highlighter.py +92 -0
- lcs.py +63 -0
- masking_methods.py +137 -0
- masking_methods_trial.py +188 -0
- paraphraser.py +45 -0
- requirements.txt +21 -0
- sampling_methods.py +35 -0
- scores.py +51 -0
- threeD_plot.py +69 -0
- tree.py +240 -0
- vocabulary_split.py +56 -0
- watermark_detector.py +75 -0
.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
.env
|
2 |
+
__pycache__/
|
.gradio/certificate.pem
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
-----BEGIN CERTIFICATE-----
|
2 |
+
MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
|
3 |
+
TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
|
4 |
+
cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
|
5 |
+
WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
|
6 |
+
ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
|
7 |
+
MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
|
8 |
+
h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
|
9 |
+
0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
|
10 |
+
A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
|
11 |
+
T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
|
12 |
+
B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
|
13 |
+
B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
|
14 |
+
KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
|
15 |
+
OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
|
16 |
+
jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
|
17 |
+
qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
|
18 |
+
rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
|
19 |
+
HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
|
20 |
+
hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
|
21 |
+
ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
|
22 |
+
3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
|
23 |
+
NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
|
24 |
+
ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
|
25 |
+
TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
|
26 |
+
jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
|
27 |
+
oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
|
28 |
+
4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
|
29 |
+
mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
|
30 |
+
emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
|
31 |
+
-----END CERTIFICATE-----
|
README.md
CHANGED
@@ -1,12 +1,8 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
emoji: 🚀
|
4 |
-
colorFrom: blue
|
5 |
-
colorTo: gray
|
6 |
-
sdk: gradio
|
7 |
-
sdk_version: 5.4.0
|
8 |
app_file: app.py
|
9 |
-
|
|
|
10 |
---
|
11 |
|
12 |
-
|
|
|
1 |
---
|
2 |
+
title: aiisc-watermarking-model
|
|
|
|
|
|
|
|
|
|
|
3 |
app_file: app.py
|
4 |
+
sdk: gradio
|
5 |
+
sdk_version: 4.36.0
|
6 |
---
|
7 |
|
8 |
+
Clone the repository and ``cd`` into it. Run ``gradio app.py`` to start the server.
|
app.py
ADDED
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import nltk
|
2 |
+
nltk.download('stopwords')
|
3 |
+
import plotly.graph_objs as go
|
4 |
+
from transformers import pipeline
|
5 |
+
import random
|
6 |
+
import gradio as gr
|
7 |
+
from tree import generate_subplot1, generate_subplot2
|
8 |
+
from paraphraser import generate_paraphrase
|
9 |
+
from lcs import find_common_subsequences, find_common_gram_positions
|
10 |
+
from highlighter import highlight_common_words, highlight_common_words_dict, reparaphrased_sentences_html
|
11 |
+
from entailment import analyze_entailment
|
12 |
+
from masking_methods import mask_non_stopword, high_entropy_words
|
13 |
+
from sampling_methods import sample_word
|
14 |
+
from detectability import SentenceDetectabilityCalculator
|
15 |
+
from distortion import SentenceDistortionCalculator
|
16 |
+
from euclidean_distance import SentenceEuclideanDistanceCalculator
|
17 |
+
from threeD_plot import gen_three_D_plot
|
18 |
+
|
19 |
+
|
20 |
+
# Function for the Gradio interface
|
21 |
+
def model(prompt):
|
22 |
+
user_prompt = prompt
|
23 |
+
paraphrased_sentences = generate_paraphrase(user_prompt)
|
24 |
+
analyzed_paraphrased_sentences, selected_sentences, discarded_sentences = analyze_entailment(user_prompt, paraphrased_sentences, 0.7)
|
25 |
+
|
26 |
+
common_grams = find_common_subsequences(user_prompt, selected_sentences)
|
27 |
+
subsequences = [subseq for _, subseq in common_grams]
|
28 |
+
common_grams_position = find_common_gram_positions(selected_sentences, subsequences)
|
29 |
+
|
30 |
+
# Create masked results using a single loop
|
31 |
+
masked_results = []
|
32 |
+
for sentence in paraphrased_sentences:
|
33 |
+
masked_results.extend([
|
34 |
+
(mask_non_stopword, sentence),
|
35 |
+
(mask_non_stopword, sentence, True),
|
36 |
+
(high_entropy_words, sentence, common_grams)
|
37 |
+
])
|
38 |
+
|
39 |
+
# Process masking functions and unpack results
|
40 |
+
masked_outputs = [
|
41 |
+
(func(sent) if len(result) == 2 else func(sent, extra))
|
42 |
+
for func, sent, *extra in masked_results
|
43 |
+
for result in [func(sent, *extra)]
|
44 |
+
]
|
45 |
+
|
46 |
+
# Unpack masked outputs into separate lists
|
47 |
+
masked_sentences, masked_words, masked_logits = zip(*masked_outputs) if masked_outputs else ([], [], [])
|
48 |
+
|
49 |
+
sampled_sentences = []
|
50 |
+
for masked_sent, words, logits in zip(masked_sentences, masked_words, masked_logits):
|
51 |
+
for technique in ['inverse_transform', 'exponential_minimum', 'temperature', 'greedy']:
|
52 |
+
sampled_sentences.append(sample_word(masked_sent, words, logits, sampling_technique=technique, temperature=1.0))
|
53 |
+
|
54 |
+
colors = ["red", "blue", "brown", "green"]
|
55 |
+
|
56 |
+
def select_color():
|
57 |
+
return random.choice(colors)
|
58 |
+
|
59 |
+
highlight_info = [(word, select_color()) for _, word in common_grams]
|
60 |
+
|
61 |
+
highlighted_user_prompt = highlight_common_words(common_grams, [user_prompt], "Non-melting Points in the User Prompt")
|
62 |
+
highlighted_accepted_sentences = highlight_common_words_dict(common_grams, selected_sentences, "Paraphrased Sentences")
|
63 |
+
highlighted_discarded_sentences = highlight_common_words_dict(common_grams, discarded_sentences, "Discarded Sentences")
|
64 |
+
|
65 |
+
trees1, trees2 = [], []
|
66 |
+
|
67 |
+
for i, sentence in enumerate(paraphrased_sentences):
|
68 |
+
next_masked_sentences = masked_sentences[i * 3:(i + 1) * 3]
|
69 |
+
next_sampled_sentences = sampled_sentences[i * 12:(i + 1) * 12]
|
70 |
+
|
71 |
+
tree1 = generate_subplot1(sentence, next_masked_sentences, highlight_info, common_grams)
|
72 |
+
trees1.append(tree1)
|
73 |
+
|
74 |
+
tree2 = generate_subplot2(next_masked_sentences, next_sampled_sentences, highlight_info, common_grams)
|
75 |
+
trees2.append(tree2)
|
76 |
+
|
77 |
+
reparaphrased_sentences = generate_paraphrase(sampled_sentences)
|
78 |
+
|
79 |
+
# Process the sentences in batches of 10
|
80 |
+
reparaphrased_sentences_list = []
|
81 |
+
for i in range(0, len(reparaphrased_sentences), 10):
|
82 |
+
batch = reparaphrased_sentences[i:i + 10]
|
83 |
+
if len(batch) == 10:
|
84 |
+
html_block = reparaphrased_sentences_html(batch)
|
85 |
+
reparaphrased_sentences_list.append(html_block)
|
86 |
+
|
87 |
+
# Calculate metrics
|
88 |
+
distortion_calculator = SentenceDistortionCalculator(user_prompt, reparaphrased_sentences)
|
89 |
+
distortion_calculator.calculate_all_metrics()
|
90 |
+
distortion_calculator.normalize_metrics()
|
91 |
+
distortion = distortion_calculator.get_combined_distortions()
|
92 |
+
distortion_list = list(distortion.values())
|
93 |
+
|
94 |
+
detectability_calculator = SentenceDetectabilityCalculator(user_prompt, reparaphrased_sentences)
|
95 |
+
detectability_calculator.calculate_all_metrics()
|
96 |
+
detectability_calculator.normalize_metrics()
|
97 |
+
detectability = detectability_calculator.get_combined_detectabilities()
|
98 |
+
detectability_list = list(detectability.values())
|
99 |
+
|
100 |
+
euclidean_dist_calculator = SentenceEuclideanDistanceCalculator(user_prompt, reparaphrased_sentences)
|
101 |
+
euclidean_dist_calculator.calculate_all_metrics()
|
102 |
+
euclidean_dist_calculator.normalize_metrics()
|
103 |
+
euclidean_dist = euclidean_dist_calculator.get_normalized_metrics()
|
104 |
+
euclidean_dist_list = list(euclidean_dist.values())
|
105 |
+
|
106 |
+
three_D_plot = gen_three_D_plot(detectability_list, distortion_list, euclidean_dist_list)
|
107 |
+
|
108 |
+
return [highlighted_user_prompt, highlighted_accepted_sentences, highlighted_discarded_sentences] + trees1 + trees2 + reparaphrased_sentences_list + [three_D_plot]
|
109 |
+
|
110 |
+
|
111 |
+
# Gradio Interface
|
112 |
+
with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
|
113 |
+
gr.Markdown("# **AIISC Watermarking Model**")
|
114 |
+
|
115 |
+
with gr.Row():
|
116 |
+
user_input = gr.Textbox(label="User Prompt")
|
117 |
+
|
118 |
+
with gr.Row():
|
119 |
+
submit_button = gr.Button("Submit")
|
120 |
+
clear_button = gr.Button("Clear")
|
121 |
+
|
122 |
+
with gr.Row():
|
123 |
+
highlighted_user_prompt = gr.HTML()
|
124 |
+
|
125 |
+
with gr.Row():
|
126 |
+
with gr.Tabs():
|
127 |
+
with gr.TabItem("Paraphrased Sentences"):
|
128 |
+
highlighted_accepted_sentences = gr.HTML()
|
129 |
+
with gr.TabItem("Discarded Sentences"):
|
130 |
+
highlighted_discarded_sentences = gr.HTML()
|
131 |
+
|
132 |
+
with gr.Row():
|
133 |
+
gr.Markdown("### Where to Watermark?") # Label for masked sentences trees
|
134 |
+
with gr.Row():
|
135 |
+
with gr.Tabs():
|
136 |
+
tree1_tabs = [gr.Plot() for _ in range(10)] # Adjust this range according to the number of trees
|
137 |
+
for i, tree1 in enumerate(tree1_tabs):
|
138 |
+
with gr.TabItem(f"Sentence {i + 1}"):
|
139 |
+
pass # Placeholder for each tree plot
|
140 |
+
|
141 |
+
with gr.Row():
|
142 |
+
gr.Markdown("### How to Watermark?") # Label for sampled sentences trees
|
143 |
+
with gr.Row():
|
144 |
+
with gr.Tabs():
|
145 |
+
tree2_tabs = [gr.Plot() for _ in range(10)] # Adjust this range according to the number of trees
|
146 |
+
for i, tree2 in enumerate(tree2_tabs):
|
147 |
+
with gr.TabItem(f"Sentence {i + 1}"):
|
148 |
+
pass # Placeholder for each tree plot
|
149 |
+
|
150 |
+
with gr.Row():
|
151 |
+
gr.Markdown("### Re-paraphrased Sentences") # Label for re-paraphrased sentences
|
152 |
+
|
153 |
+
with gr.Row():
|
154 |
+
with gr.Tabs():
|
155 |
+
reparaphrased_sentences_tabs = [gr.HTML() for _ in range(120)] # 120 tabs for 120 batches of sentences
|
156 |
+
for i, reparaphrased_sent_html in enumerate(reparaphrased_sentences_tabs):
|
157 |
+
with gr.TabItem(f"Sentence {i + 1}"):
|
158 |
+
pass # Placeholder for each batch
|
159 |
+
|
160 |
+
with gr.Row():
|
161 |
+
gr.Markdown("### 3D Plot for Sweet Spot")
|
162 |
+
with gr.Row():
|
163 |
+
three_D_plot = gr.Plot()
|
164 |
+
|
165 |
+
submit_button.click(model, inputs=user_input, outputs=[highlighted_user_prompt, highlighted_accepted_sentences, highlighted_discarded_sentences] + tree1_tabs + tree2_tabs + reparaphrased_sentences_tabs + [three_D_plot])
|
166 |
+
clear_button.click(lambda: "", inputs=None, outputs=user_input)
|
167 |
+
clear_button.click(lambda: "", inputs=None, outputs=[highlighted_user_prompt, highlighted_accepted_sentences, highlighted_discarded_sentences] + tree1_tabs + tree2_tabs + reparaphrased_sentences_tabs + [three_D_plot])
|
168 |
+
|
169 |
+
demo.launch(share=True)
|
detectability.py
ADDED
@@ -0,0 +1,303 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Import necessary libraries
|
2 |
+
import nltk
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
7 |
+
from transformers import BertModel, BertTokenizer
|
8 |
+
from sentence_transformers import SentenceTransformer
|
9 |
+
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
|
10 |
+
|
11 |
+
# Download NLTK data if not already present
|
12 |
+
nltk.download('punkt', quiet=True)
|
13 |
+
|
14 |
+
class SentenceDetectabilityCalculator:
|
15 |
+
"""
|
16 |
+
A class to calculate and analyze detectability metrics between an original sentence and paraphrased sentences.
|
17 |
+
"""
|
18 |
+
|
19 |
+
def __init__(self, original_sentence, paraphrased_sentences):
|
20 |
+
"""
|
21 |
+
Initialize the calculator with the original sentence and a list of paraphrased sentences.
|
22 |
+
"""
|
23 |
+
self.original_sentence = original_sentence
|
24 |
+
self.paraphrased_sentences = paraphrased_sentences
|
25 |
+
self.metrics = {
|
26 |
+
'BLEU Score': {},
|
27 |
+
'Cosine Similarity': {},
|
28 |
+
'STS Score': {}
|
29 |
+
}
|
30 |
+
self.normalized_metrics = {
|
31 |
+
'BLEU Score': {},
|
32 |
+
'Cosine Similarity': {},
|
33 |
+
'STS Score': {}
|
34 |
+
}
|
35 |
+
self.combined_detectabilities = {}
|
36 |
+
|
37 |
+
# Load pre-trained models
|
38 |
+
self.bert_model = BertModel.from_pretrained('bert-base-uncased')
|
39 |
+
self.bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
40 |
+
self.sts_model = SentenceTransformer('paraphrase-MiniLM-L6-v2')
|
41 |
+
|
42 |
+
# Calculate original embeddings
|
43 |
+
self.original_embedding = self._get_sentence_embedding(self.original_sentence)
|
44 |
+
self.sts_original_embedding = self.sts_model.encode(self.original_sentence)
|
45 |
+
|
46 |
+
def calculate_all_metrics(self):
|
47 |
+
"""
|
48 |
+
Calculate all detectability metrics for each paraphrased sentence.
|
49 |
+
"""
|
50 |
+
for idx, paraphrased_sentence in enumerate(self.paraphrased_sentences):
|
51 |
+
key = f"Sentence_{idx + 1}"
|
52 |
+
self.metrics['BLEU Score'][key] = self._calculate_bleu(self.original_sentence, paraphrased_sentence)
|
53 |
+
paraphrase_embedding = self._get_sentence_embedding(paraphrased_sentence)
|
54 |
+
self.metrics['Cosine Similarity'][key] = cosine_similarity([self.original_embedding], [paraphrase_embedding])[0][0]
|
55 |
+
sts_paraphrase_embedding = self.sts_model.encode(paraphrased_sentence)
|
56 |
+
self.metrics['STS Score'][key] = cosine_similarity([self.sts_original_embedding], [sts_paraphrase_embedding])[0][0]
|
57 |
+
|
58 |
+
def normalize_metrics(self):
|
59 |
+
"""
|
60 |
+
Normalize all metrics to be between 0 and 1.
|
61 |
+
"""
|
62 |
+
for metric_name, metric_dict in self.metrics.items():
|
63 |
+
self.normalized_metrics[metric_name] = self._normalize_dict(metric_dict)
|
64 |
+
|
65 |
+
def calculate_combined_detectability(self):
|
66 |
+
"""
|
67 |
+
Calculate the combined detectability using the root mean square of the normalized metrics.
|
68 |
+
"""
|
69 |
+
for key in self.normalized_metrics['BLEU Score'].keys():
|
70 |
+
rms = np.sqrt(sum(
|
71 |
+
self.normalized_metrics[metric][key] ** 2 for metric in self.normalized_metrics
|
72 |
+
) / len(self.normalized_metrics))
|
73 |
+
self.combined_detectabilities[key] = rms
|
74 |
+
|
75 |
+
def plot_metrics(self):
|
76 |
+
"""
|
77 |
+
Plot each normalized metric and the combined detectability in separate graphs.
|
78 |
+
"""
|
79 |
+
keys = list(self.normalized_metrics['BLEU Score'].keys())
|
80 |
+
indices = np.arange(len(keys))
|
81 |
+
|
82 |
+
# Prepare data for plotting
|
83 |
+
metrics = {name: [self.normalized_metrics[name][key] for key in keys] for name in self.normalized_metrics}
|
84 |
+
|
85 |
+
# Plot each metric separately
|
86 |
+
for metric_name, values in metrics.items():
|
87 |
+
plt.figure(figsize=(12, 6))
|
88 |
+
plt.plot(indices, values, marker='o', color=np.random.rand(3,))
|
89 |
+
plt.xlabel('Sentence Index')
|
90 |
+
plt.ylabel('Normalized Value (0-1)')
|
91 |
+
plt.title(f'Normalized {metric_name}')
|
92 |
+
plt.grid(True)
|
93 |
+
plt.tight_layout()
|
94 |
+
plt.show()
|
95 |
+
|
96 |
+
# Private methods for metric calculations
|
97 |
+
def _calculate_bleu(self, reference, candidate):
|
98 |
+
"""
|
99 |
+
Calculate the BLEU score between the original and paraphrased sentence using smoothing.
|
100 |
+
"""
|
101 |
+
reference_tokens = nltk.word_tokenize(reference)
|
102 |
+
candidate_tokens = nltk.word_tokenize(candidate)
|
103 |
+
smoothing = SmoothingFunction().method1
|
104 |
+
return sentence_bleu([reference_tokens], candidate_tokens, smoothing_function=smoothing)
|
105 |
+
|
106 |
+
def _get_sentence_embedding(self, sentence):
|
107 |
+
"""
|
108 |
+
Get sentence embedding using BERT.
|
109 |
+
"""
|
110 |
+
tokens = self.bert_tokenizer(sentence, return_tensors='pt', padding=True, truncation=True, max_length=512)
|
111 |
+
with torch.no_grad():
|
112 |
+
outputs = self.bert_model(**tokens)
|
113 |
+
return outputs.last_hidden_state.mean(dim=1).squeeze().numpy()
|
114 |
+
|
115 |
+
def _normalize_dict(self, metric_dict):
|
116 |
+
"""
|
117 |
+
Normalize the values in a dictionary to be between 0 and 1.
|
118 |
+
"""
|
119 |
+
values = np.array(list(metric_dict.values()))
|
120 |
+
min_val = values.min()
|
121 |
+
max_val = values.max()
|
122 |
+
# Avoid division by zero if all values are the same
|
123 |
+
return dict(zip(metric_dict.keys(), np.zeros_like(values) if max_val - min_val == 0 else (values - min_val) / (max_val - min_val)))
|
124 |
+
|
125 |
+
# Getter methods
|
126 |
+
def get_normalized_metrics(self):
|
127 |
+
"""
|
128 |
+
Get all normalized metrics as a dictionary.
|
129 |
+
"""
|
130 |
+
return self.normalized_metrics
|
131 |
+
|
132 |
+
def get_combined_detectabilities(self):
|
133 |
+
"""
|
134 |
+
Get the dictionary of combined detectability values.
|
135 |
+
"""
|
136 |
+
return self.combined_detectabilities
|
137 |
+
|
138 |
+
|
139 |
+
# Example usage
|
140 |
+
if __name__ == "__main__":
|
141 |
+
# Original sentence
|
142 |
+
original_sentence = "The quick brown fox jumps over the lazy dog"
|
143 |
+
|
144 |
+
# Paraphrased sentences
|
145 |
+
paraphrased_sentences = [
|
146 |
+
# Original 1: "A swift auburn fox leaps across a sleepy canine."
|
147 |
+
"The swift auburn fox leaps across a sleepy canine.",
|
148 |
+
"A quick auburn fox leaps across a sleepy canine.",
|
149 |
+
"A swift ginger fox leaps across a sleepy canine.",
|
150 |
+
"A swift auburn fox bounds across a sleepy canine.",
|
151 |
+
"A swift auburn fox leaps across a tired canine.",
|
152 |
+
"Three swift auburn foxes leap across a sleepy canine.",
|
153 |
+
"The vulpine specimen rapidly traverses over a dormant dog.",
|
154 |
+
"Like lightning, the russet hunter soars over the drowsy guardian.",
|
155 |
+
"Tha quick ginger fox jumps o'er the lazy hound, ye ken.",
|
156 |
+
"One rapid Vulpes vulpes traverses the path of a quiescent canine.",
|
157 |
+
"A swift auburn predator navigates across a lethargic pet.",
|
158 |
+
"Subject A (fox) demonstrates velocity over Subject B (dog).",
|
159 |
+
|
160 |
+
# Original 2: "The agile russet fox bounds over an idle hound."
|
161 |
+
"Some agile russet foxes bound over an idle hound.",
|
162 |
+
"The nimble russet fox bounds over an idle hound.",
|
163 |
+
"The agile brown fox bounds over an idle hound.",
|
164 |
+
"The agile russet fox jumps over an idle hound.",
|
165 |
+
"The agile russet fox bounds over a lazy hound.",
|
166 |
+
"Two agile russet foxes bound over an idle hound.",
|
167 |
+
"A dexterous vulpine surpasses a stationary canine.",
|
168 |
+
"Quick as thought, the copper warrior sails over the guardian.",
|
169 |
+
"Tha nimble reddish fox jumps o'er the doggo, don't ya know.",
|
170 |
+
"A dexterous V. vulpes exceeds the plane of an inactive canine.",
|
171 |
+
"An agile russet hunter maneuvers above a resting hound.",
|
172 |
+
"Test subject F-1 achieves displacement superior to subject D-1.",
|
173 |
+
|
174 |
+
# Original 3: "A nimble mahogany vulpine vaults above a drowsy dog."
|
175 |
+
"The nimble mahogany vulpine vaults above a drowsy dog.",
|
176 |
+
"A swift mahogany vulpine vaults above a drowsy dog.",
|
177 |
+
"A nimble reddish vulpine vaults above a drowsy dog.",
|
178 |
+
"A nimble mahogany fox vaults above a drowsy dog.",
|
179 |
+
"A nimble mahogany vulpine leaps above a drowsy dog.",
|
180 |
+
"Four nimble mahogany vulpines vault above a drowsy dog.",
|
181 |
+
"An agile specimen of reddish fur surpasses a somnolent canine.",
|
182 |
+
"Fleet as wind, the earth-toned hunter soars over the sleepy guard.",
|
183 |
+
"Tha quick brown beastie jumps o'er the tired pup, aye.",
|
184 |
+
"Single V. vulpes demonstrates vertical traverse over C. familiaris.",
|
185 |
+
"A nimble rust-colored predator crosses above a drowsy pet.",
|
186 |
+
"Observed: Subject Red executes vertical motion over Subject Gray.",
|
187 |
+
|
188 |
+
# Original 4: "The speedy copper-colored fox hops over the lethargic pup."
|
189 |
+
"A speedy copper-colored fox hops over the lethargic pup.",
|
190 |
+
"The quick copper-colored fox hops over the lethargic pup.",
|
191 |
+
"The speedy bronze fox hops over the lethargic pup.",
|
192 |
+
"The speedy copper-colored fox jumps over the lethargic pup.",
|
193 |
+
"The speedy copper-colored fox hops over the tired pup.",
|
194 |
+
"Multiple speedy copper-colored foxes hop over the lethargic pup.",
|
195 |
+
"A rapid vulpine of bronze hue traverses an inactive young canine.",
|
196 |
+
"Swift as a dart, the metallic hunter bounds over the lazy puppy.",
|
197 |
+
"Tha fast copper beastie leaps o'er the sleepy wee dog.",
|
198 |
+
"1 rapid V. vulpes crosses above 1 juvenile C. familiaris.",
|
199 |
+
"A fleet copper-toned predator moves past a sluggish young dog.",
|
200 |
+
"Field note: Adult fox subject exceeds puppy subject vertically.",
|
201 |
+
|
202 |
+
# Original 5: "A rapid tawny fox springs over a sluggish dog."
|
203 |
+
"The rapid tawny fox springs over a sluggish dog.",
|
204 |
+
"A quick tawny fox springs over a sluggish dog.",
|
205 |
+
"A rapid golden fox springs over a sluggish dog.",
|
206 |
+
"A rapid tawny fox jumps over a sluggish dog.",
|
207 |
+
"A rapid tawny fox springs over a lazy dog.",
|
208 |
+
"Six rapid tawny foxes spring over a sluggish dog.",
|
209 |
+
"An expeditious yellowish vulpine surpasses a torpid canine.",
|
210 |
+
"Fast as a bullet, the golden hunter vaults over the idle guard.",
|
211 |
+
"Tha swift yellowy fox jumps o'er the lazy mutt, aye.",
|
212 |
+
"One V. vulpes displays rapid transit over one inactive C. familiaris.",
|
213 |
+
"A speedy yellow-brown predator bypasses a motionless dog.",
|
214 |
+
"Log entry: Vulpine subject achieves swift vertical displacement.",
|
215 |
+
|
216 |
+
# Original 6: "The fleet-footed chestnut fox soars above an indolent canine."
|
217 |
+
"A fleet-footed chestnut fox soars above an indolent canine.",
|
218 |
+
"The swift chestnut fox soars above an indolent canine.",
|
219 |
+
"The fleet-footed brown fox soars above an indolent canine.",
|
220 |
+
"The fleet-footed chestnut fox leaps above an indolent canine.",
|
221 |
+
"The fleet-footed chestnut fox soars above a lazy canine.",
|
222 |
+
"Several fleet-footed chestnut foxes soar above an indolent canine.",
|
223 |
+
"A rapid brown vulpine specimen traverses a lethargic domestic dog.",
|
224 |
+
"Graceful as a bird, the nutbrown hunter flies over the lazy guard.",
|
225 |
+
"Tha quick brown beastie sails o'er the sleepy hound, ken.",
|
226 |
+
"Single agile V. vulpes achieves elevation above stationary canine.",
|
227 |
+
"A nimble brown predator glides over an unmoving domestic animal.",
|
228 |
+
"Research note: Brown subject displays superior vertical mobility.",
|
229 |
+
|
230 |
+
# Original 7: "A fast ginger fox hurdles past a slothful dog."
|
231 |
+
"The fast ginger fox hurdles past a slothful dog.",
|
232 |
+
"A quick ginger fox hurdles past a slothful dog.",
|
233 |
+
"A fast red fox hurdles past a slothful dog.",
|
234 |
+
"A fast ginger fox jumps past a slothful dog.",
|
235 |
+
"A fast ginger fox hurdles past a lazy dog.",
|
236 |
+
"Five fast ginger foxes hurdle past a slothful dog.",
|
237 |
+
"A rapid orange vulpine bypasses a lethargic canine.",
|
238 |
+
"Quick as lightning, the flame-colored hunter races past the lazy guard.",
|
239 |
+
"Tha swift ginger beastie leaps past the tired doggy, ye see.",
|
240 |
+
"1 rapid orange V. vulpes surpasses 1 inactive C. familiaris.",
|
241 |
+
"A speedy red-orange predator overtakes a motionless dog.",
|
242 |
+
"Data point: Orange subject demonstrates rapid transit past Gray subject.",
|
243 |
+
|
244 |
+
# Original 8: "The spry rusty-colored fox jumps across a dozing hound."
|
245 |
+
"A spry rusty-colored fox jumps across a dozing hound.",
|
246 |
+
"The agile rusty-colored fox jumps across a dozing hound.",
|
247 |
+
"The spry reddish fox jumps across a dozing hound.",
|
248 |
+
"The spry rusty-colored fox leaps across a dozing hound.",
|
249 |
+
"The spry rusty-colored fox jumps across a sleeping hound.",
|
250 |
+
"Multiple spry rusty-colored foxes jump across a dozing hound.",
|
251 |
+
"An agile rust-toned vulpine traverses a somnolent canine.",
|
252 |
+
"Nimble as thought, the copper hunter bounds over the resting guard.",
|
253 |
+
"Tha lively rust-colored beastie hops o'er the snoozin' hound.",
|
254 |
+
"Single dexterous V. vulpes crosses path of dormant C. familiaris.",
|
255 |
+
"A lithe rust-tinted predator moves past a slumbering dog.",
|
256 |
+
"Observation: Russet subject exhibits agility over dormant subject.",
|
257 |
+
|
258 |
+
# Original 9: "A quick tan fox leaps over an inactive dog."
|
259 |
+
"The quick tan fox leaps over an inactive dog.",
|
260 |
+
"A swift tan fox leaps over an inactive dog.",
|
261 |
+
"A quick beige fox leaps over an inactive dog.",
|
262 |
+
"A quick tan fox jumps over an inactive dog.",
|
263 |
+
"A quick tan fox leaps over a motionless dog.",
|
264 |
+
"Seven quick tan foxes leap over an inactive dog.",
|
265 |
+
"A rapid light-brown vulpine surpasses a stationary canine.",
|
266 |
+
"Fast as wind, the sand-colored hunter soars over the still guard.",
|
267 |
+
"Tha nimble tan beastie jumps o'er the quiet doggy, aye.",
|
268 |
+
"One agile fawn V. vulpes traverses one immobile C. familiaris.",
|
269 |
+
"A fleet tan-colored predator bypasses an unmoving dog.",
|
270 |
+
"Field report: Tan subject demonstrates movement over static subject.",
|
271 |
+
|
272 |
+
# Original 10: "The brisk auburn vulpine bounces over a listless canine."
|
273 |
+
"Some brisk auburn vulpines bounce over a listless canine.",
|
274 |
+
"The quick auburn vulpine bounces over a listless canine.",
|
275 |
+
"The brisk russet vulpine bounces over a listless canine.",
|
276 |
+
"The brisk auburn fox bounces over a listless canine.",
|
277 |
+
"The brisk auburn vulpine jumps over a listless canine.",
|
278 |
+
"Five brisk auburn vulpines bounce over a listless canine.",
|
279 |
+
"The expeditious specimen supersedes a quiescent Canis lupus.",
|
280 |
+
"Swift as wind, the russet hunter vaults over the idle guardian.",
|
281 |
+
"Tha quick ginger beastie hops o'er the lazy mutt, aye.",
|
282 |
+
"One V. vulpes achieves displacement over inactive C. familiaris.",
|
283 |
+
"A high-velocity auburn predator traverses an immobile animal.",
|
284 |
+
"Final observation: Red subject shows mobility over Gray subject."
|
285 |
+
]
|
286 |
+
|
287 |
+
# Create the calculator instance
|
288 |
+
calculator = SentenceDetectabilityCalculator(original_sentence, paraphrased_sentences)
|
289 |
+
|
290 |
+
# Calculate metrics
|
291 |
+
calculator.calculate_all_metrics()
|
292 |
+
calculator.normalize_metrics()
|
293 |
+
calculator.calculate_combined_detectability()
|
294 |
+
|
295 |
+
# Plot metrics
|
296 |
+
calculator.plot_metrics()
|
297 |
+
|
298 |
+
# Get results
|
299 |
+
normalized_metrics = calculator.get_normalized_metrics()
|
300 |
+
combined_detectabilities = calculator.get_combined_detectabilities()
|
301 |
+
|
302 |
+
print("Normalized Metrics:", normalized_metrics)
|
303 |
+
print("Combined Detectabilities:", combined_detectabilities)
|
distortion.py
ADDED
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Import necessary libraries
|
2 |
+
import nltk
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
+
from scipy.special import rel_entr
|
7 |
+
from collections import Counter
|
8 |
+
from transformers import GPT2LMHeadModel, GPT2TokenizerFast
|
9 |
+
|
10 |
+
# Download NLTK data if not already present
|
11 |
+
nltk.download('punkt', quiet=True)
|
12 |
+
|
13 |
+
class SentenceDistortionCalculator:
|
14 |
+
"""
|
15 |
+
A class to calculate and analyze distortion metrics between an original sentence and modified sentences.
|
16 |
+
"""
|
17 |
+
|
18 |
+
def __init__(self, original_sentence, modified_sentences):
|
19 |
+
self.original_sentence = original_sentence
|
20 |
+
self.modified_sentences = modified_sentences
|
21 |
+
self.tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
|
22 |
+
self.model = GPT2LMHeadModel.from_pretrained("gpt2").eval() # Set model to evaluation mode
|
23 |
+
|
24 |
+
# Raw metric dictionaries
|
25 |
+
self.metrics = {
|
26 |
+
'levenshtein': {},
|
27 |
+
'word_level_changes': {},
|
28 |
+
'kl_divergences': {},
|
29 |
+
'perplexities': {},
|
30 |
+
}
|
31 |
+
|
32 |
+
# Combined distortion dictionary
|
33 |
+
self.combined_distortions = {}
|
34 |
+
|
35 |
+
def calculate_all_metrics(self):
|
36 |
+
"""Calculate all distortion metrics for each modified sentence."""
|
37 |
+
for idx, modified_sentence in enumerate(self.modified_sentences):
|
38 |
+
key = f"Sentence_{idx + 1}"
|
39 |
+
self.metrics['levenshtein'][key] = self._calculate_levenshtein_distance(modified_sentence)
|
40 |
+
self.metrics['word_level_changes'][key] = self._calculate_word_level_change(modified_sentence)
|
41 |
+
self.metrics['kl_divergences'][key] = self._calculate_kl_divergence(modified_sentence)
|
42 |
+
self.metrics['perplexities'][key] = self._calculate_perplexity(modified_sentence)
|
43 |
+
|
44 |
+
def normalize_metrics(self):
|
45 |
+
"""Normalize all metrics to be between 0 and 1."""
|
46 |
+
for metric in self.metrics:
|
47 |
+
self.metrics[metric] = self._normalize_dict(self.metrics[metric])
|
48 |
+
|
49 |
+
def calculate_combined_distortion(self):
|
50 |
+
"""Calculate the combined distortion using the root mean square of the normalized metrics."""
|
51 |
+
for key in self.metrics['levenshtein']:
|
52 |
+
rms = np.sqrt(sum(self.metrics[metric][key] ** 2 for metric in self.metrics) / len(self.metrics))
|
53 |
+
self.combined_distortions[key] = rms
|
54 |
+
|
55 |
+
def plot_metrics(self):
|
56 |
+
"""Plot each normalized metric and the combined distortion in separate graphs."""
|
57 |
+
keys = list(self.metrics['levenshtein'].keys())
|
58 |
+
indices = np.arange(len(keys))
|
59 |
+
|
60 |
+
for metric_name, values in self.metrics.items():
|
61 |
+
plt.figure(figsize=(12, 6))
|
62 |
+
plt.plot(indices, list(values.values()), marker='o', label=metric_name)
|
63 |
+
plt.xlabel('Sentence Index')
|
64 |
+
plt.ylabel('Normalized Value (0-1)')
|
65 |
+
plt.title(f'Normalized {metric_name.replace("_", " ").title()}')
|
66 |
+
plt.grid(True)
|
67 |
+
plt.legend()
|
68 |
+
plt.tight_layout()
|
69 |
+
plt.show()
|
70 |
+
|
71 |
+
# Private methods for metric calculations
|
72 |
+
def _calculate_levenshtein_distance(self, modified_sentence):
|
73 |
+
"""Calculate the Levenshtein Distance between the original and modified sentence."""
|
74 |
+
return nltk.edit_distance(self.original_sentence, modified_sentence)
|
75 |
+
|
76 |
+
def _calculate_word_level_change(self, modified_sentence):
|
77 |
+
"""Calculate the proportion of word-level changes between the original and modified sentence."""
|
78 |
+
original_words = self.original_sentence.split()
|
79 |
+
modified_words = modified_sentence.split()
|
80 |
+
total_words = max(len(original_words), len(modified_words))
|
81 |
+
changed_words = sum(o != m for o, m in zip(original_words, modified_words)) + abs(len(original_words) - len(modified_words))
|
82 |
+
return changed_words / total_words if total_words > 0 else 0
|
83 |
+
|
84 |
+
def _calculate_kl_divergence(self, modified_sentence):
|
85 |
+
"""Calculate the KL Divergence between the word distributions of the original and modified sentence."""
|
86 |
+
original_counts = Counter(self.original_sentence.lower().split())
|
87 |
+
modified_counts = Counter(modified_sentence.lower().split())
|
88 |
+
all_words = set(original_counts.keys()).union(modified_counts.keys())
|
89 |
+
|
90 |
+
original_probs = np.array([original_counts[word] for word in all_words], dtype=float)
|
91 |
+
modified_probs = np.array([modified_counts[word] for word in all_words], dtype=float)
|
92 |
+
|
93 |
+
original_probs /= original_probs.sum() + 1e-10 # Avoid division by zero
|
94 |
+
modified_probs /= modified_probs.sum() + 1e-10
|
95 |
+
|
96 |
+
return np.sum(rel_entr(original_probs, modified_probs))
|
97 |
+
|
98 |
+
def _calculate_perplexity(self, sentence):
|
99 |
+
"""Calculate the perplexity of a sentence using GPT-2."""
|
100 |
+
encodings = self.tokenizer(sentence, return_tensors='pt')
|
101 |
+
stride = self.model.config.n_positions
|
102 |
+
log_likelihoods = []
|
103 |
+
|
104 |
+
for i in range(0, encodings.input_ids.size(1), stride):
|
105 |
+
input_ids = encodings.input_ids[:, i:i + stride]
|
106 |
+
with torch.no_grad():
|
107 |
+
outputs = self.model(input_ids, labels=input_ids)
|
108 |
+
log_likelihoods.append(outputs.loss.item())
|
109 |
+
|
110 |
+
avg_log_likelihood = np.mean(log_likelihoods)
|
111 |
+
return torch.exp(torch.tensor(avg_log_likelihood)).item()
|
112 |
+
|
113 |
+
def _normalize_dict(self, metric_dict):
|
114 |
+
"""Normalize the values in a dictionary to be between 0 and 1."""
|
115 |
+
values = np.array(list(metric_dict.values()))
|
116 |
+
min_val, max_val = values.min(), values.max()
|
117 |
+
normalized_values = (values - min_val) / (max_val - min_val) if max_val > min_val else np.zeros_like(values)
|
118 |
+
return dict(zip(metric_dict.keys(), normalized_values))
|
119 |
+
|
120 |
+
def get_normalized_metrics(self):
|
121 |
+
"""Get all normalized metrics as a dictionary."""
|
122 |
+
return {metric: self._normalize_dict(values) for metric, values in self.metrics.items()}
|
123 |
+
|
124 |
+
def get_combined_distortions(self):
|
125 |
+
"""Get the dictionary of combined distortion values."""
|
126 |
+
return self.combined_distortions
|
entailment.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import pipeline
|
2 |
+
|
3 |
+
def analyze_entailment(original_sentence, paraphrased_sentences, threshold):
|
4 |
+
# Load the entailment model once
|
5 |
+
entailment_pipe = pipeline("text-classification", model="ynie/roberta-large-snli_mnli_fever_anli_R1_R2_R3-nli")
|
6 |
+
|
7 |
+
all_sentences = {}
|
8 |
+
selected_sentences = {}
|
9 |
+
discarded_sentences = {}
|
10 |
+
|
11 |
+
# Prepare input for entailment checks
|
12 |
+
inputs = [f"{original_sentence} [SEP] {paraphrase}" for paraphrase in paraphrased_sentences]
|
13 |
+
|
14 |
+
# Perform entailment checks for all paraphrased sentences in one go
|
15 |
+
entailment_results = entailment_pipe(inputs, return_all_scores=True)
|
16 |
+
|
17 |
+
# Iterate over results
|
18 |
+
for paraphrased_sentence, results in zip(paraphrased_sentences, entailment_results):
|
19 |
+
# Extract the entailment score for each paraphrased sentence
|
20 |
+
entailment_score = next((result['score'] for result in results if result['label'] == 'entailment'), 0)
|
21 |
+
|
22 |
+
all_sentences[paraphrased_sentence] = entailment_score
|
23 |
+
|
24 |
+
# Store sentences based on the threshold
|
25 |
+
if entailment_score >= threshold:
|
26 |
+
selected_sentences[paraphrased_sentence] = entailment_score
|
27 |
+
else:
|
28 |
+
discarded_sentences[paraphrased_sentence] = entailment_score
|
29 |
+
|
30 |
+
return all_sentences, selected_sentences, discarded_sentences
|
31 |
+
|
32 |
+
# Example usage
|
33 |
+
# print(analyze_entailment("I love you", ["I adore you", "I hate you"], 0.7))
|
euclidean_distance.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Import necessary libraries
|
2 |
+
import numpy as np
|
3 |
+
import matplotlib.pyplot as plt
|
4 |
+
from sentence_transformers import SentenceTransformer
|
5 |
+
from sklearn.metrics.pairwise import euclidean_distances
|
6 |
+
|
7 |
+
class SentenceEuclideanDistanceCalculator:
|
8 |
+
"""
|
9 |
+
A class to calculate and analyze Euclidean distance between an original sentence and paraphrased sentences.
|
10 |
+
"""
|
11 |
+
|
12 |
+
def __init__(self, original_sentence, paraphrased_sentences):
|
13 |
+
"""
|
14 |
+
Initialize the calculator with the original sentence and a list of paraphrased sentences.
|
15 |
+
"""
|
16 |
+
self.original_sentence = original_sentence
|
17 |
+
self.paraphrased_sentences = paraphrased_sentences
|
18 |
+
|
19 |
+
# Load SentenceTransformer model for embedding calculation
|
20 |
+
self.model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
|
21 |
+
|
22 |
+
# Precompute the original sentence embedding
|
23 |
+
self.original_embedding = self.model.encode(original_sentence, convert_to_tensor=True)
|
24 |
+
|
25 |
+
# Calculate Euclidean distances and normalize them
|
26 |
+
self.euclidean_distances = self._calculate_all_metrics()
|
27 |
+
self.normalized_euclidean = self._normalize_dict(self.euclidean_distances)
|
28 |
+
|
29 |
+
def _calculate_all_metrics(self):
|
30 |
+
"""
|
31 |
+
Calculate Euclidean distance between the original and each paraphrased sentence.
|
32 |
+
"""
|
33 |
+
distances = {}
|
34 |
+
paraphrase_embeddings = self.model.encode(self.paraphrased_sentences, convert_to_tensor=True)
|
35 |
+
|
36 |
+
for idx, paraphrase_embedding in enumerate(paraphrase_embeddings):
|
37 |
+
key = f"Sentence_{idx + 1}"
|
38 |
+
distances[key] = euclidean_distances([self.original_embedding], [paraphrase_embedding])[0][0]
|
39 |
+
|
40 |
+
return distances
|
41 |
+
|
42 |
+
def _normalize_dict(self, metric_dict):
|
43 |
+
"""
|
44 |
+
Normalize the values in a dictionary to be between 0 and 1.
|
45 |
+
"""
|
46 |
+
values = np.array(list(metric_dict.values()))
|
47 |
+
min_val, max_val = values.min(), values.max()
|
48 |
+
|
49 |
+
# Normalize values
|
50 |
+
normalized_values = (values - min_val) / (max_val - min_val) if max_val > min_val else np.zeros_like(values)
|
51 |
+
return dict(zip(metric_dict.keys(), normalized_values))
|
52 |
+
|
53 |
+
def plot_metrics(self):
|
54 |
+
"""
|
55 |
+
Plot the normalized Euclidean distances in a graph.
|
56 |
+
"""
|
57 |
+
keys = list(self.normalized_euclidean.keys())
|
58 |
+
indices = np.arange(len(keys))
|
59 |
+
|
60 |
+
plt.figure(figsize=(12, 6))
|
61 |
+
plt.plot(indices, [self.normalized_euclidean[key] for key in keys], marker='o', color=np.random.rand(3,))
|
62 |
+
plt.xlabel('Sentence Index')
|
63 |
+
plt.ylabel('Normalized Euclidean Distance (0-1)')
|
64 |
+
plt.title('Normalized Euclidean Distance')
|
65 |
+
plt.grid(True)
|
66 |
+
plt.tight_layout()
|
67 |
+
plt.show()
|
68 |
+
|
69 |
+
# Getter methods
|
70 |
+
def get_normalized_metrics(self):
|
71 |
+
"""
|
72 |
+
Get the normalized Euclidean distances as a dictionary.
|
73 |
+
"""
|
74 |
+
return self.normalized_euclidean
|
gpt_mask_filling.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import openai
|
2 |
+
import os
|
3 |
+
from dotenv import load_dotenv
|
4 |
+
|
5 |
+
load_dotenv()
|
6 |
+
|
7 |
+
openai.api_key = os.getenv("API_KEY")
|
8 |
+
|
9 |
+
|
10 |
+
#Takes in a sentence and returns a list of dicts consisiting of key-value pairs of masked words and lists of the possible replacements
|
11 |
+
def predict_masked_words(sentence, n_suggestions=5):
|
12 |
+
|
13 |
+
prompt = (
|
14 |
+
f"Given a sentence with masked words, masked word can be one or more than one, indicated by [MASK], generate {n_suggestions} possible words to fill each mask. "
|
15 |
+
"Return the results as a list of dictionaries, where each dictionary key is a masked word and its value is a list of 5 potential words to fill that mask.\n\n"
|
16 |
+
"Example input: \"The [MASK] fox [MASK] over the [MASK] dog.\"\n\n"
|
17 |
+
"Example output:\n"
|
18 |
+
"[\n"
|
19 |
+
" {\n"
|
20 |
+
" \"[MASK]1\": [\"quick\", \"sly\", \"red\", \"clever\", \"sneaky\"]\n"
|
21 |
+
" },\n"
|
22 |
+
" {\n"
|
23 |
+
" \"[MASK]2\": [\"jumped\", \"leaped\", \"hopped\", \"sprang\", \"bounded\"]\n"
|
24 |
+
" },\n"
|
25 |
+
" {\n"
|
26 |
+
" \"[MASK]3\": [\"lazy\", \"sleeping\", \"brown\", \"tired\", \"old\"]\n"
|
27 |
+
" }\n"
|
28 |
+
"]\n\n"
|
29 |
+
"Example input: \"The [MASK] [MASK] ran swiftly across the [MASK] field.\"\n\n"
|
30 |
+
"Example output:\n"
|
31 |
+
"[\n"
|
32 |
+
" {\n"
|
33 |
+
" \"[MASK]1\": [\"tall\", \"fierce\", \"young\", \"old\", \"beautiful\"]\n"
|
34 |
+
" },\n"
|
35 |
+
" {\n"
|
36 |
+
" \"[MASK]2\": [\"lion\", \"tiger\", \"horse\", \"cheetah\", \"deer\"]\n"
|
37 |
+
" },\n"
|
38 |
+
" {\n"
|
39 |
+
" \"[MASK]3\": [\"green\", \"wide\", \"sunny\", \"open\", \"empty\"]\n"
|
40 |
+
" }\n"
|
41 |
+
"]\n\n"
|
42 |
+
"Example input: \"It was a [MASK] day when the train arrived at the station.\"\n\n"
|
43 |
+
"Example output:\n"
|
44 |
+
"[\n"
|
45 |
+
" {\n"
|
46 |
+
" \"[MASK]1\": [\"sunny\", \"rainy\", \"cloudy\", \"foggy\", \"stormy\"]\n"
|
47 |
+
" },\n"
|
48 |
+
"]\n\n"
|
49 |
+
"Now, please process the following sentence:\n"
|
50 |
+
f"{sentence}"
|
51 |
+
)
|
52 |
+
|
53 |
+
|
54 |
+
response = openai.ChatCompletion.create(
|
55 |
+
model="gpt-3.5-turbo",
|
56 |
+
messages=[
|
57 |
+
{"role": "system", "content": "You are a helpful assistant."},
|
58 |
+
{"role": "user", "content": prompt}
|
59 |
+
],
|
60 |
+
max_tokens=100,
|
61 |
+
n=1,
|
62 |
+
stop=None,
|
63 |
+
temperature=0.7
|
64 |
+
)
|
65 |
+
|
66 |
+
print(response['choices'][0]['message']['content'])
|
67 |
+
|
68 |
+
|
69 |
+
# sentence = "Evacuations and storm [MASK] began on Sunday night as forecasters projected that Hurricane Dorian would hit into Florida’s west coast on Wednesday as a major hurricane packing life-threatening winds and storm surge."
|
70 |
+
# predict_masked_words(sentence, n_suggestions=5)
|
highlighter.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
|
3 |
+
def highlight_common_words(common_words, sentences, title):
|
4 |
+
color_map = {}
|
5 |
+
highlighted_html = []
|
6 |
+
|
7 |
+
for idx, sentence in enumerate(sentences, start=1):
|
8 |
+
highlighted_sentence = f"{idx}. {sentence}"
|
9 |
+
|
10 |
+
for index, word in common_words:
|
11 |
+
if word not in color_map:
|
12 |
+
# Assign color using HSL for better visual distinction
|
13 |
+
color_map[word] = f'hsl({(len(color_map) % 6) * 60}, 70%, 80%)'
|
14 |
+
|
15 |
+
# Create a regex pattern for the word
|
16 |
+
escaped_word = re.escape(word)
|
17 |
+
pattern = rf'\b{escaped_word}\b'
|
18 |
+
color = color_map[word]
|
19 |
+
|
20 |
+
# Use a lambda function for word highlighting
|
21 |
+
highlighted_sentence = re.sub(
|
22 |
+
pattern,
|
23 |
+
lambda m: (f'<span style="background-color: {color}; font-weight: bold;'
|
24 |
+
' padding: 2px 4px; border-radius: 2px; position: relative;">'
|
25 |
+
f'<span style="background-color: black; color: white; border-radius: 50%;'
|
26 |
+
' padding: 2px 5px; margin-right: 5px;">{index}</span>'
|
27 |
+
f'{m.group(0)}'
|
28 |
+
'</span>'),
|
29 |
+
highlighted_sentence,
|
30 |
+
flags=re.IGNORECASE
|
31 |
+
)
|
32 |
+
|
33 |
+
highlighted_html.append(highlighted_sentence)
|
34 |
+
|
35 |
+
# Construct the final HTML output
|
36 |
+
return generate_html(title, highlighted_html)
|
37 |
+
|
38 |
+
def highlight_common_words_dict(common_words, sentences, title):
|
39 |
+
color_map = {}
|
40 |
+
highlighted_html = []
|
41 |
+
|
42 |
+
for idx, (sentence, score) in enumerate(sentences.items(), start=1):
|
43 |
+
highlighted_sentence = f"{idx}. {sentence}"
|
44 |
+
|
45 |
+
for index, word in common_words:
|
46 |
+
if word not in color_map:
|
47 |
+
color_map[word] = f'hsl({(len(color_map) % 6) * 60}, 70%, 80%)'
|
48 |
+
escaped_word = re.escape(word)
|
49 |
+
pattern = rf'\b{escaped_word}\b'
|
50 |
+
color = color_map[word]
|
51 |
+
|
52 |
+
highlighted_sentence = re.sub(
|
53 |
+
pattern,
|
54 |
+
lambda m: (f'<span style="background-color: {color}; font-weight: bold;'
|
55 |
+
' padding: 1px 2px; border-radius: 2px; position: relative;">'
|
56 |
+
f'<span style="background-color: black; color: white; border-radius: 50%;'
|
57 |
+
' padding: 1px 3px; margin-right: 3px; font-size: 0.8em;">{index}</span>'
|
58 |
+
f'{m.group(0)}'
|
59 |
+
'</span>'),
|
60 |
+
highlighted_sentence,
|
61 |
+
flags=re.IGNORECASE
|
62 |
+
)
|
63 |
+
|
64 |
+
highlighted_html.append(
|
65 |
+
f'<div style="margin-bottom: 5px;">'
|
66 |
+
f'{highlighted_sentence}'
|
67 |
+
f'<div style="display: inline-block; margin-left: 5px; padding: 3px 5px; border-radius: 3px;'
|
68 |
+
' background-color: white; font-size: 0.9em;">Entailment Score: {score}</div></div>'
|
69 |
+
)
|
70 |
+
|
71 |
+
return generate_html(title, highlighted_html)
|
72 |
+
|
73 |
+
def generate_html(title, highlighted_html):
|
74 |
+
final_html = "<br><br>".join(highlighted_html)
|
75 |
+
return f'''
|
76 |
+
<div style="border: solid 1px #ccc; padding: 16px; background-color: #FFFFFF; color: #374151;
|
77 |
+
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); border-radius: 8px;">
|
78 |
+
<h3 style="margin-top: 0; font-size: 1em; color: #111827;">{title}</h3>
|
79 |
+
<div style="background-color: #F5F5F5; line-height: 1.6; padding: 15px; border-radius: 8px;">{final_html}</div>
|
80 |
+
</div>
|
81 |
+
'''
|
82 |
+
|
83 |
+
def reparaphrased_sentences_html(sentences):
|
84 |
+
formatted_sentences = [f"{idx + 1}. {sentence}" for idx, sentence in enumerate(sentences)]
|
85 |
+
final_html = "<br><br>".join(formatted_sentences)
|
86 |
+
|
87 |
+
return f'''
|
88 |
+
<div style="border: solid 1px #ccc; padding: 16px; background-color: #FFFFFF; color: #374151;
|
89 |
+
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); border-radius: 8px;">
|
90 |
+
<div style="background-color: #F5F5F5; line-height: 1.6; padding: 15px; border-radius: 8px;">{final_html}</div>
|
91 |
+
</div>
|
92 |
+
'''
|
lcs.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
from nltk.corpus import stopwords
|
3 |
+
|
4 |
+
def find_common_subsequences(sentence, str_list):
|
5 |
+
# Load stop words
|
6 |
+
stop_words = set(stopwords.words('english'))
|
7 |
+
|
8 |
+
# Preprocess the input sentence and list of strings
|
9 |
+
sentence = sentence.lower()
|
10 |
+
cleaned_str_list = [s.lower() for s in str_list]
|
11 |
+
|
12 |
+
def clean_text(text):
|
13 |
+
"""Remove stop words and special characters from a given text."""
|
14 |
+
text = re.sub(r'[^\w\s]', '', text)
|
15 |
+
return " ".join(word for word in text.split() if word not in stop_words)
|
16 |
+
|
17 |
+
cleaned_sentence = clean_text(sentence)
|
18 |
+
cleaned_str_list = [clean_text(s) for s in cleaned_str_list]
|
19 |
+
|
20 |
+
words = cleaned_sentence.split()
|
21 |
+
common_grams = []
|
22 |
+
added_phrases = set()
|
23 |
+
|
24 |
+
for n in range(5, 0, -1): # Check n-grams from size 5 to 1
|
25 |
+
for i in range(len(words) - n + 1):
|
26 |
+
subseq = " ".join(words[i:i + n])
|
27 |
+
if is_present(subseq, cleaned_str_list) and subseq not in added_phrases:
|
28 |
+
common_grams.append((i, subseq))
|
29 |
+
added_phrases.add(subseq)
|
30 |
+
|
31 |
+
# Sort by the first appearance in the original sentence and create indexed common grams
|
32 |
+
common_grams.sort(key=lambda x: x[0])
|
33 |
+
return [(index + 1, subseq) for index, (_, subseq) in enumerate(common_grams)]
|
34 |
+
|
35 |
+
def is_present(subseq, str_list):
|
36 |
+
"""Check if a subsequence is present in all strings in the list."""
|
37 |
+
subseq_regex = re.compile(r'\b' + re.escape(subseq) + r'\b')
|
38 |
+
return all(subseq_regex.search(s) for s in str_list)
|
39 |
+
|
40 |
+
def find_common_gram_positions(str_list, common_grams):
|
41 |
+
"""Find positions of common grams in each string from str_list."""
|
42 |
+
positions = []
|
43 |
+
|
44 |
+
for sentence in str_list:
|
45 |
+
words = re.sub(r'[^\w\s]', '', sentence).lower().split()
|
46 |
+
word_positions = {word: [] for word in words}
|
47 |
+
|
48 |
+
for idx, word in enumerate(words):
|
49 |
+
word_positions[word].append(idx + 1) # Store 1-based index positions
|
50 |
+
|
51 |
+
sentence_positions = []
|
52 |
+
for _, gram in common_grams:
|
53 |
+
gram_words = re.sub(r'[^\w\s]', '', gram).lower().split()
|
54 |
+
|
55 |
+
if all(word in word_positions for word in gram_words):
|
56 |
+
start_idx = word_positions[gram_words[0]][0]
|
57 |
+
sentence_positions.append(start_idx)
|
58 |
+
else:
|
59 |
+
sentence_positions.append(-1) # Common gram not found
|
60 |
+
|
61 |
+
positions.append(sentence_positions)
|
62 |
+
|
63 |
+
return positions
|
masking_methods.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from transformers import AutoTokenizer, AutoModelForMaskedLM, pipeline
|
3 |
+
import random
|
4 |
+
from nltk.corpus import stopwords
|
5 |
+
import nltk
|
6 |
+
from vocabulary_split import split_vocabulary, filter_logits
|
7 |
+
|
8 |
+
# Load tokenizer and model for masked language model
|
9 |
+
tokenizer = AutoTokenizer.from_pretrained("bert-large-cased-whole-word-masking")
|
10 |
+
model = AutoModelForMaskedLM.from_pretrained("bert-large-cased-whole-word-masking")
|
11 |
+
fill_mask = pipeline("fill-mask", model=model, tokenizer=tokenizer)
|
12 |
+
|
13 |
+
# Get permissible vocabulary
|
14 |
+
permissible, _ = split_vocabulary(seed=42)
|
15 |
+
permissible_indices = torch.tensor([i in permissible.values() for i in range(len(tokenizer))])
|
16 |
+
|
17 |
+
# Initialize stop words and ensure NLTK resources are downloaded
|
18 |
+
stop_words = set(stopwords.words('english'))
|
19 |
+
nltk.download('averaged_perceptron_tagger', quiet=True)
|
20 |
+
nltk.download('maxent_ne_chunker', quiet=True)
|
21 |
+
nltk.download('words', quiet=True)
|
22 |
+
|
23 |
+
def get_logits_for_mask(sentence):
|
24 |
+
inputs = tokenizer(sentence, return_tensors="pt")
|
25 |
+
mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
|
26 |
+
|
27 |
+
with torch.no_grad():
|
28 |
+
outputs = model(**inputs)
|
29 |
+
|
30 |
+
logits = outputs.logits
|
31 |
+
return logits[0, mask_token_index, :].squeeze()
|
32 |
+
|
33 |
+
def mask_word(sentence, word):
|
34 |
+
masked_sentence = sentence.replace(word, '[MASK]', 1)
|
35 |
+
logits = get_logits_for_mask(masked_sentence)
|
36 |
+
filtered_logits = filter_logits(logits, permissible_indices)
|
37 |
+
words = [tokenizer.decode([i]) for i in filtered_logits.argsort()[-5:]]
|
38 |
+
return masked_sentence, filtered_logits.tolist(), words
|
39 |
+
|
40 |
+
def mask_non_stopword(sentence, pseudo_random=False):
|
41 |
+
non_stop_words = [word for word in sentence.split() if word.lower() not in stop_words]
|
42 |
+
if not non_stop_words:
|
43 |
+
return sentence, None, None
|
44 |
+
|
45 |
+
if pseudo_random:
|
46 |
+
random.seed(10) # Fixed seed for pseudo-randomness
|
47 |
+
word_to_mask = random.choice(non_stop_words)
|
48 |
+
return mask_word(sentence, word_to_mask)
|
49 |
+
|
50 |
+
def mask_between_lcs(sentence, lcs_points):
|
51 |
+
words = sentence.split()
|
52 |
+
masked_indices = []
|
53 |
+
|
54 |
+
# Mask first word before the first LCS point
|
55 |
+
if lcs_points and lcs_points[0] > 0:
|
56 |
+
idx = random.randint(0, lcs_points[0] - 1)
|
57 |
+
words[idx] = '[MASK]'
|
58 |
+
masked_indices.append(idx)
|
59 |
+
|
60 |
+
# Mask between LCS points
|
61 |
+
for i in range(len(lcs_points) - 1):
|
62 |
+
start, end = lcs_points[i], lcs_points[i + 1]
|
63 |
+
if end - start > 1:
|
64 |
+
mask_index = random.randint(start + 1, end - 1)
|
65 |
+
words[mask_index] = '[MASK]'
|
66 |
+
masked_indices.append(mask_index)
|
67 |
+
|
68 |
+
# Mask last word after the last LCS point
|
69 |
+
if lcs_points and lcs_points[-1] < len(words) - 1:
|
70 |
+
idx = random.randint(lcs_points[-1] + 1, len(words) - 1)
|
71 |
+
words[idx] = '[MASK]'
|
72 |
+
masked_indices.append(idx)
|
73 |
+
|
74 |
+
masked_sentence = ' '.join(words)
|
75 |
+
logits = get_logits_for_mask(masked_sentence)
|
76 |
+
|
77 |
+
logits_list, top_words_list = [], []
|
78 |
+
for idx in masked_indices:
|
79 |
+
filtered_logits = filter_logits(logits[idx], permissible_indices)
|
80 |
+
logits_list.append(filtered_logits.tolist())
|
81 |
+
top_words = [tokenizer.decode([i]) for i in filtered_logits.topk(5).indices.tolist()]
|
82 |
+
top_words_list.append(top_words)
|
83 |
+
|
84 |
+
return masked_sentence, logits_list, top_words_list
|
85 |
+
|
86 |
+
def high_entropy_words(sentence, non_melting_points):
|
87 |
+
non_melting_words = {word.lower() for _, point in non_melting_points for word in point.split()}
|
88 |
+
candidate_words = [word for word in sentence.split() if word.lower() not in stop_words and word.lower() not in non_melting_words]
|
89 |
+
|
90 |
+
if not candidate_words:
|
91 |
+
return sentence, None, None
|
92 |
+
|
93 |
+
max_entropy, max_entropy_word, max_logits = -float('inf'), None, None
|
94 |
+
for word in candidate_words:
|
95 |
+
masked_sentence = sentence.replace(word, '[MASK]', 1)
|
96 |
+
logits = get_logits_for_mask(masked_sentence)
|
97 |
+
filtered_logits = filter_logits(logits, permissible_indices)
|
98 |
+
|
99 |
+
# Calculate entropy
|
100 |
+
probs = torch.softmax(filtered_logits, dim=-1)
|
101 |
+
top_5_probs = probs.topk(5).values
|
102 |
+
entropy = -torch.sum(top_5_probs * torch.log(top_5_probs + 1e-10)) # Avoid log(0)
|
103 |
+
|
104 |
+
if entropy > max_entropy:
|
105 |
+
max_entropy, max_entropy_word, max_logits = entropy, word, filtered_logits
|
106 |
+
|
107 |
+
if max_entropy_word is None:
|
108 |
+
return sentence, None, None
|
109 |
+
|
110 |
+
masked_sentence = sentence.replace(max_entropy_word, '[MASK]', 1)
|
111 |
+
words = [tokenizer.decode([i]) for i in max_logits.argsort()[-5:]]
|
112 |
+
return masked_sentence, max_logits.tolist(), words
|
113 |
+
|
114 |
+
def mask_by_pos(sentence, pos_to_mask=['NOUN', 'VERB', 'ADJ']):
|
115 |
+
words = nltk.word_tokenize(sentence)
|
116 |
+
pos_tags = nltk.pos_tag(words)
|
117 |
+
|
118 |
+
maskable_words = [word for word, pos in pos_tags if pos[:2] in pos_to_mask]
|
119 |
+
if not maskable_words:
|
120 |
+
return sentence, None, None
|
121 |
+
|
122 |
+
word_to_mask = random.choice(maskable_words)
|
123 |
+
return mask_word(sentence, word_to_mask)
|
124 |
+
|
125 |
+
def mask_named_entity(sentence):
|
126 |
+
words = nltk.word_tokenize(sentence)
|
127 |
+
pos_tags = nltk.pos_tag(words)
|
128 |
+
named_entities = nltk.ne_chunk(pos_tags)
|
129 |
+
|
130 |
+
maskable_words = [word for word, tag in named_entities.leaves() if isinstance(tag, nltk.Tree)]
|
131 |
+
if not maskable_words:
|
132 |
+
return sentence, None, None
|
133 |
+
|
134 |
+
word_to_mask = random.choice(maskable_words)
|
135 |
+
return mask_word(sentence, word_to_mask)
|
136 |
+
|
137 |
+
|
masking_methods_trial.py
ADDED
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from transformers import AutoTokenizer, AutoModelForMaskedLM
|
3 |
+
from transformers import pipeline
|
4 |
+
import random
|
5 |
+
from nltk.corpus import stopwords
|
6 |
+
import nltk
|
7 |
+
nltk.download('stopwords')
|
8 |
+
import math
|
9 |
+
from vocabulary_split import split_vocabulary, filter_logits
|
10 |
+
import abc
|
11 |
+
from typing import List
|
12 |
+
|
13 |
+
# Load tokenizer and model for masked language model
|
14 |
+
tokenizer = AutoTokenizer.from_pretrained("bert-large-cased-whole-word-masking")
|
15 |
+
model = AutoModelForMaskedLM.from_pretrained("bert-large-cased-whole-word-masking")
|
16 |
+
fill_mask = pipeline("fill-mask", model=model, tokenizer=tokenizer)
|
17 |
+
|
18 |
+
# Get permissible vocabulary
|
19 |
+
permissible, _ = split_vocabulary(seed=42)
|
20 |
+
permissible_indices = torch.tensor([i in permissible.values() for i in range(len(tokenizer))])
|
21 |
+
|
22 |
+
def get_logits_for_mask(model, tokenizer, sentence):
|
23 |
+
inputs = tokenizer(sentence, return_tensors="pt")
|
24 |
+
mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
|
25 |
+
|
26 |
+
with torch.no_grad():
|
27 |
+
outputs = model(**inputs)
|
28 |
+
|
29 |
+
logits = outputs.logits
|
30 |
+
mask_token_logits = logits[0, mask_token_index, :]
|
31 |
+
return mask_token_logits.squeeze()
|
32 |
+
|
33 |
+
# Abstract Masking Strategy
|
34 |
+
class MaskingStrategy(abc.ABC):
|
35 |
+
@abc.abstractmethod
|
36 |
+
def select_words_to_mask(self, words: List[str], **kwargs) -> List[int]:
|
37 |
+
"""
|
38 |
+
Given a list of words, return the indices of words to mask.
|
39 |
+
"""
|
40 |
+
pass
|
41 |
+
|
42 |
+
# Specific Masking Strategies
|
43 |
+
class RandomNonStopwordMasking(MaskingStrategy):
|
44 |
+
def __init__(self, num_masks: int = 1):
|
45 |
+
self.num_masks = num_masks
|
46 |
+
self.stop_words = set(stopwords.words('english'))
|
47 |
+
|
48 |
+
def select_words_to_mask(self, words: List[str], **kwargs) -> List[int]:
|
49 |
+
non_stop_indices = [i for i, word in enumerate(words) if word.lower() not in self.stop_words]
|
50 |
+
if not non_stop_indices:
|
51 |
+
return []
|
52 |
+
num_masks = min(self.num_masks, len(non_stop_indices))
|
53 |
+
return random.sample(non_stop_indices, num_masks)
|
54 |
+
|
55 |
+
class HighEntropyMasking(MaskingStrategy):
|
56 |
+
def __init__(self, num_masks: int = 1):
|
57 |
+
self.num_masks = num_masks
|
58 |
+
|
59 |
+
def select_words_to_mask(self, words: List[str], sentence: str, model, tokenizer, permissible_indices) -> List[int]:
|
60 |
+
candidate_indices = [i for i, word in enumerate(words) if word.lower() not in set(stopwords.words('english'))]
|
61 |
+
if not candidate_indices:
|
62 |
+
return []
|
63 |
+
|
64 |
+
entropy_scores = {}
|
65 |
+
for idx in candidate_indices:
|
66 |
+
masked_sentence = ' '.join(words[:idx] + ['[MASK]'] + words[idx+1:])
|
67 |
+
logits = get_logits_for_mask(model, tokenizer, masked_sentence)
|
68 |
+
filtered_logits = filter_logits(logits, permissible_indices)
|
69 |
+
probs = torch.softmax(filtered_logits, dim=-1)
|
70 |
+
top_5_probs = probs.topk(5).values
|
71 |
+
entropy = -torch.sum(top_5_probs * torch.log(top_5_probs + 1e-10)).item()
|
72 |
+
entropy_scores[idx] = entropy
|
73 |
+
|
74 |
+
# Select top N indices with highest entropy
|
75 |
+
sorted_indices = sorted(entropy_scores, key=entropy_scores.get, reverse=True)
|
76 |
+
return sorted_indices[:self.num_masks]
|
77 |
+
|
78 |
+
class PseudoRandomNonStopwordMasking(MaskingStrategy):
|
79 |
+
def __init__(self, num_masks: int = 1, seed: int = 10):
|
80 |
+
self.num_masks = num_masks
|
81 |
+
self.seed = seed
|
82 |
+
self.stop_words = set(stopwords.words('english'))
|
83 |
+
|
84 |
+
def select_words_to_mask(self, words: List[str], **kwargs) -> List[int]:
|
85 |
+
non_stop_indices = [i for i, word in enumerate(words) if word.lower() not in self.stop_words]
|
86 |
+
if not non_stop_indices:
|
87 |
+
return []
|
88 |
+
random.seed(self.seed)
|
89 |
+
num_masks = min(self.num_masks, len(non_stop_indices))
|
90 |
+
return random.sample(non_stop_indices, num_masks)
|
91 |
+
|
92 |
+
class CompositeMaskingStrategy(MaskingStrategy):
|
93 |
+
def __init__(self, strategies: List[MaskingStrategy]):
|
94 |
+
self.strategies = strategies
|
95 |
+
|
96 |
+
def select_words_to_mask(self, words: List[str], **kwargs) -> List[int]:
|
97 |
+
selected_indices = []
|
98 |
+
for strategy in self.strategies:
|
99 |
+
if isinstance(strategy, HighEntropyMasking):
|
100 |
+
selected = strategy.select_words_to_mask(words, **kwargs)
|
101 |
+
else:
|
102 |
+
selected = strategy.select_words_to_mask(words)
|
103 |
+
selected_indices.extend(selected)
|
104 |
+
return list(set(selected_indices)) # Remove duplicates
|
105 |
+
|
106 |
+
# Refactored mask_between_lcs function
|
107 |
+
def mask_between_lcs(sentence, lcs_points, masking_strategy: MaskingStrategy, model, tokenizer, permissible_indices):
|
108 |
+
words = sentence.split()
|
109 |
+
masked_indices = []
|
110 |
+
|
111 |
+
segments = []
|
112 |
+
|
113 |
+
# Define segments based on LCS points
|
114 |
+
previous = 0
|
115 |
+
for point in lcs_points:
|
116 |
+
if point > previous:
|
117 |
+
segments.append((previous, point))
|
118 |
+
previous = point + 1
|
119 |
+
if previous < len(words):
|
120 |
+
segments.append((previous, len(words)))
|
121 |
+
|
122 |
+
# Collect all indices to mask from each segment
|
123 |
+
for start, end in segments:
|
124 |
+
segment_words = words[start:end]
|
125 |
+
if isinstance(masking_strategy, HighEntropyMasking):
|
126 |
+
selected = masking_strategy.select_words_to_mask(segment_words, sentence, model, tokenizer, permissible_indices)
|
127 |
+
else:
|
128 |
+
selected = masking_strategy.select_words_to_mask(segment_words)
|
129 |
+
|
130 |
+
# Adjust indices relative to the whole sentence
|
131 |
+
for idx in selected:
|
132 |
+
masked_idx = start + idx
|
133 |
+
if masked_idx not in masked_indices:
|
134 |
+
masked_indices.append(masked_idx)
|
135 |
+
|
136 |
+
# Apply masking
|
137 |
+
for idx in masked_indices:
|
138 |
+
words[idx] = '[MASK]'
|
139 |
+
|
140 |
+
masked_sentence = ' '.join(words)
|
141 |
+
logits = get_logits_for_mask(model, tokenizer, masked_sentence)
|
142 |
+
|
143 |
+
# Process each masked token
|
144 |
+
top_words_list = []
|
145 |
+
logits_list = []
|
146 |
+
for i, idx in enumerate(masked_indices):
|
147 |
+
logits_i = logits[i]
|
148 |
+
if logits_i.dim() > 1:
|
149 |
+
logits_i = logits_i.squeeze()
|
150 |
+
filtered_logits_i = filter_logits(logits_i, permissible_indices)
|
151 |
+
logits_list.append(filtered_logits_i.tolist())
|
152 |
+
top_5_indices = filtered_logits_i.topk(5).indices.tolist()
|
153 |
+
top_words = [tokenizer.decode([i]) for i in top_5_indices]
|
154 |
+
top_words_list.append(top_words)
|
155 |
+
|
156 |
+
return masked_sentence, logits_list, top_words_list
|
157 |
+
|
158 |
+
# Example Usage
|
159 |
+
if __name__ == "__main__":
|
160 |
+
# Example sentence and LCS points
|
161 |
+
sentence = "This is a sample sentence with some LCS points"
|
162 |
+
lcs_points = [2, 5, 8] # Indices of LCS points
|
163 |
+
|
164 |
+
# Initialize masking strategies
|
165 |
+
random_non_stopword_strategy = RandomNonStopwordMasking(num_masks=1)
|
166 |
+
high_entropy_strategy = HighEntropyMasking(num_masks=1)
|
167 |
+
pseudo_random_strategy = PseudoRandomNonStopwordMasking(num_masks=1, seed=10)
|
168 |
+
composite_strategy = CompositeMaskingStrategy([
|
169 |
+
RandomNonStopwordMasking(num_masks=1),
|
170 |
+
HighEntropyMasking(num_masks=1)
|
171 |
+
])
|
172 |
+
|
173 |
+
# Choose a strategy
|
174 |
+
chosen_strategy = composite_strategy # You can choose any initialized strategy
|
175 |
+
|
176 |
+
# Apply masking
|
177 |
+
masked_sentence, logits_list, top_words_list = mask_between_lcs(
|
178 |
+
sentence,
|
179 |
+
lcs_points,
|
180 |
+
masking_strategy=chosen_strategy,
|
181 |
+
model=model,
|
182 |
+
tokenizer=tokenizer,
|
183 |
+
permissible_indices=permissible_indices
|
184 |
+
)
|
185 |
+
|
186 |
+
print("Masked Sentence:", masked_sentence)
|
187 |
+
for idx, top_words in enumerate(top_words_list):
|
188 |
+
print(f"Top words for mask {idx+1}:", top_words)
|
paraphraser.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from openai import OpenAI
|
2 |
+
from dotenv import load_dotenv
|
3 |
+
import os
|
4 |
+
|
5 |
+
# Load environment variables
|
6 |
+
load_dotenv()
|
7 |
+
key = os.getenv("OPENAI_API_KEY")
|
8 |
+
|
9 |
+
# Initialize the OpenAI client
|
10 |
+
client = OpenAI(api_key=key)
|
11 |
+
|
12 |
+
def generate_paraphrase(sentences, model="gpt-4", num_paraphrases=10, max_tokens=150, temperature=0.7):
|
13 |
+
"""Generate paraphrased sentences using the OpenAI GPT-4 model."""
|
14 |
+
|
15 |
+
# Ensure sentences is a list
|
16 |
+
if isinstance(sentences, str):
|
17 |
+
sentences = [sentences]
|
18 |
+
|
19 |
+
paraphrased_sentences_list = []
|
20 |
+
|
21 |
+
for sentence in sentences:
|
22 |
+
full_prompt = f"Paraphrase the following text: '{sentence}'"
|
23 |
+
|
24 |
+
try:
|
25 |
+
chat_completion = client.chat.completions.create(
|
26 |
+
messages=[{"role": "user", "content": full_prompt}],
|
27 |
+
model=model,
|
28 |
+
max_tokens=max_tokens,
|
29 |
+
temperature=temperature,
|
30 |
+
n=num_paraphrases # Number of paraphrased sentences to generate
|
31 |
+
)
|
32 |
+
# Extract paraphrased sentences
|
33 |
+
paraphrased_sentences = [choice.message.content.strip() for choice in chat_completion.choices]
|
34 |
+
paraphrased_sentences_list.extend(paraphrased_sentences)
|
35 |
+
except Exception as e:
|
36 |
+
print(f"Error paraphrasing sentence '{sentence}': {e}")
|
37 |
+
|
38 |
+
return paraphrased_sentences_list
|
39 |
+
|
40 |
+
# Example usage
|
41 |
+
result = generate_paraphrase(
|
42 |
+
"Mayor Eric Adams did not attend the first candidate forum for the New York City mayoral race, but his record — and the criminal charges he faces — received plenty of attention on Saturday from the Democrats who are running to unseat him."
|
43 |
+
)
|
44 |
+
|
45 |
+
print(f"Number of paraphrases generated: {len(result)}")
|
requirements.txt
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
ipywidgets
|
2 |
+
transformers
|
3 |
+
plotly
|
4 |
+
requests
|
5 |
+
Pillow
|
6 |
+
numpy
|
7 |
+
matplotlib
|
8 |
+
tqdm
|
9 |
+
scipy
|
10 |
+
torch
|
11 |
+
seaborn
|
12 |
+
termcolor
|
13 |
+
nltk
|
14 |
+
tenacity
|
15 |
+
pandas
|
16 |
+
graphviz==0.20.3
|
17 |
+
gradio==4.29.0
|
18 |
+
openai
|
19 |
+
python-dotenv
|
20 |
+
scikit-learn
|
21 |
+
sentence-transformers
|
sampling_methods.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import random
|
3 |
+
from vocabulary_split import split_vocabulary, filter_logits
|
4 |
+
from masking_methods import tokenizer
|
5 |
+
|
6 |
+
# Get permissible vocabulary
|
7 |
+
permissible, _ = split_vocabulary(seed=42)
|
8 |
+
permissible_indices = torch.tensor([i in permissible.values() for i in range(len(tokenizer))])
|
9 |
+
|
10 |
+
def sample_word(sentence, words, logits, sampling_technique='inverse_transform', temperature=1.0):
|
11 |
+
# Convert logits to a tensor and filter based on permissible indices
|
12 |
+
filtered_logits = filter_logits(torch.tensor(logits), permissible_indices)
|
13 |
+
probs = torch.softmax(filtered_logits / temperature, dim=-1)
|
14 |
+
|
15 |
+
# Select sampling technique
|
16 |
+
if sampling_technique == 'inverse_transform':
|
17 |
+
cumulative_probs = torch.cumsum(probs, dim=-1)
|
18 |
+
random_prob = random.random()
|
19 |
+
sampled_index = torch.searchsorted(cumulative_probs, random_prob)
|
20 |
+
elif sampling_technique == 'exponential_minimum':
|
21 |
+
exp_probs = torch.exp(-torch.log(probs))
|
22 |
+
sampled_index = torch.argmax(random.rand_like(exp_probs) * exp_probs)
|
23 |
+
elif sampling_technique == 'temperature':
|
24 |
+
sampled_index = torch.multinomial(probs, 1).item()
|
25 |
+
elif sampling_technique == 'greedy':
|
26 |
+
sampled_index = torch.argmax(filtered_logits).item()
|
27 |
+
else:
|
28 |
+
raise ValueError("Invalid sampling technique. Choose 'inverse_transform', 'exponential_minimum', 'temperature', or 'greedy'.")
|
29 |
+
|
30 |
+
sampled_word = tokenizer.decode([sampled_index])
|
31 |
+
|
32 |
+
# Replace [MASK] with the sampled word
|
33 |
+
filled_sentence = sentence.replace('[MASK]', sampled_word)
|
34 |
+
|
35 |
+
return filled_sentence
|
scores.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
from nltk.translate.bleu_score import sentence_bleu
|
4 |
+
from transformers import BertTokenizer, BertModel
|
5 |
+
|
6 |
+
# Function to Calculate the BLEU score
|
7 |
+
def calculate_bleu(reference, candidate):
|
8 |
+
return sentence_bleu([reference], candidate)
|
9 |
+
|
10 |
+
# Function to calculate BERT score
|
11 |
+
def calculate_bert(reference, candidate):
|
12 |
+
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
13 |
+
model = BertModel.from_pretrained('bert-base-uncased')
|
14 |
+
|
15 |
+
reference_tokens = tokenizer.tokenize(reference)
|
16 |
+
candidate_tokens = tokenizer.tokenize(candidate)
|
17 |
+
|
18 |
+
reference_ids = tokenizer.encode(reference, add_special_tokens=True, max_length=512, truncation=True, return_tensors="pt")
|
19 |
+
candidate_ids = tokenizer.encode(candidate, add_special_tokens=True, max_length=512, truncation=True, return_tensors="pt")
|
20 |
+
|
21 |
+
with torch.no_grad():
|
22 |
+
reference_outputs = model(reference_ids)
|
23 |
+
candidate_outputs = model(candidate_ids)
|
24 |
+
|
25 |
+
reference_embeddings = reference_outputs[0][:, 0, :].numpy()
|
26 |
+
candidate_embeddings = candidate_outputs[0][:, 0, :].numpy()
|
27 |
+
|
28 |
+
cosine_similarity = np.dot(reference_embeddings, candidate_embeddings.T) / (np.linalg.norm(reference_embeddings) * np.linalg.norm(candidate_embeddings))
|
29 |
+
return np.mean(cosine_similarity)
|
30 |
+
|
31 |
+
# Function to calculate minimum edit distance
|
32 |
+
def min_edit_distance(reference, candidate):
|
33 |
+
m = len(reference)
|
34 |
+
n = len(candidate)
|
35 |
+
|
36 |
+
dp = [[0] * (n + 1) for _ in range(m + 1)]
|
37 |
+
|
38 |
+
for i in range(m + 1):
|
39 |
+
for j in range(n + 1):
|
40 |
+
if i == 0:
|
41 |
+
dp[i][j] = j
|
42 |
+
elif j == 0:
|
43 |
+
dp[i][j] = i
|
44 |
+
elif reference[i - 1] == candidate[j - 1]:
|
45 |
+
dp[i][j] = dp[i - 1][j - 1]
|
46 |
+
else:
|
47 |
+
dp[i][j] = 1 + min(dp[i][j - 1], # Insert
|
48 |
+
dp[i - 1][j], # Remove
|
49 |
+
dp[i - 1][j - 1]) # Replace
|
50 |
+
|
51 |
+
return dp[m][n]
|
threeD_plot.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import plotly.graph_objects as go
|
3 |
+
from scipy.interpolate import griddata
|
4 |
+
|
5 |
+
def gen_three_D_plot(detectability_val, distortion_val, euclidean_val):
|
6 |
+
# Convert input lists to NumPy arrays
|
7 |
+
detectability = np.array(detectability_val)
|
8 |
+
distortion = np.array(distortion_val)
|
9 |
+
euclidean = np.array(euclidean_val)
|
10 |
+
|
11 |
+
# Normalize the values to range [0, 1]
|
12 |
+
def normalize(data):
|
13 |
+
min_val, max_val = np.min(data), np.max(data)
|
14 |
+
return (data - min_val) / (max_val - min_val) if max_val > min_val else np.zeros_like(data)
|
15 |
+
|
16 |
+
norm_detectability = normalize(detectability)
|
17 |
+
norm_distortion = normalize(distortion)
|
18 |
+
norm_euclidean = normalize(euclidean)
|
19 |
+
|
20 |
+
# Composite score: maximize detectability, minimize distortion and Euclidean distance
|
21 |
+
composite_score = norm_detectability - (norm_distortion + norm_euclidean)
|
22 |
+
|
23 |
+
# Sweet spot values
|
24 |
+
sweet_spot_index = np.argmax(composite_score)
|
25 |
+
sweet_spot = (detectability[sweet_spot_index], distortion[sweet_spot_index], euclidean[sweet_spot_index])
|
26 |
+
|
27 |
+
# Create a meshgrid for interpolation
|
28 |
+
x_grid, y_grid = np.meshgrid(
|
29 |
+
np.linspace(np.min(detectability), np.max(detectability), 30),
|
30 |
+
np.linspace(np.min(distortion), np.max(distortion), 30)
|
31 |
+
)
|
32 |
+
|
33 |
+
# Interpolate z values (Euclidean distances) to fit the grid
|
34 |
+
z_grid = griddata((detectability, distortion), euclidean, (x_grid, y_grid), method='linear')
|
35 |
+
|
36 |
+
if z_grid is None:
|
37 |
+
raise ValueError("griddata could not generate a valid interpolation. Check your input data.")
|
38 |
+
|
39 |
+
# Create the 3D contour plot with the Plasma color scale
|
40 |
+
fig = go.Figure(data=go.Surface(
|
41 |
+
z=z_grid,
|
42 |
+
x=x_grid,
|
43 |
+
y=y_grid,
|
44 |
+
contours={"z": {"show": True, "start": np.min(euclidean), "end": np.max(euclidean), "size": 0.1, "usecolormap": True}},
|
45 |
+
colorscale='Plasma'
|
46 |
+
))
|
47 |
+
|
48 |
+
# Add a marker for the sweet spot
|
49 |
+
fig.add_trace(go.Scatter3d(
|
50 |
+
x=[sweet_spot[0]],
|
51 |
+
y=[sweet_spot[1]],
|
52 |
+
z=[sweet_spot[2]],
|
53 |
+
mode='markers+text',
|
54 |
+
marker=dict(size=10, color='red', symbol='circle'),
|
55 |
+
text=["Sweet Spot"],
|
56 |
+
textposition="top center"
|
57 |
+
))
|
58 |
+
|
59 |
+
# Set axis labels
|
60 |
+
fig.update_layout(
|
61 |
+
scene=dict(
|
62 |
+
xaxis_title='Detectability Score',
|
63 |
+
yaxis_title='Distortion Score',
|
64 |
+
zaxis_title='Euclidean Distance'
|
65 |
+
),
|
66 |
+
margin=dict(l=0, r=0, b=0, t=0)
|
67 |
+
)
|
68 |
+
|
69 |
+
return fig
|
tree.py
ADDED
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import plotly.graph_objects as go
|
2 |
+
import textwrap
|
3 |
+
import re
|
4 |
+
from collections import defaultdict
|
5 |
+
|
6 |
+
def apply_lcs_numbering(sentence, common_grams):
|
7 |
+
"""Apply LCS numbering based on common grams."""
|
8 |
+
for idx, lcs in common_grams:
|
9 |
+
sentence = re.sub(rf"\b{lcs}\b", f"({idx}){lcs}", sentence)
|
10 |
+
return sentence
|
11 |
+
|
12 |
+
def highlight_words(sentence, color_map):
|
13 |
+
"""Highlight specified words in a sentence with corresponding colors."""
|
14 |
+
for word, color in color_map.items():
|
15 |
+
sentence = re.sub(f"\\b{word}\\b", f"{{{{{word}}}}}", sentence, flags=re.IGNORECASE)
|
16 |
+
return sentence
|
17 |
+
|
18 |
+
def clean_and_wrap_nodes(nodes, highlight_info):
|
19 |
+
"""Clean nodes by removing labels and wrap text for display."""
|
20 |
+
global_color_map = dict(highlight_info)
|
21 |
+
cleaned_nodes = [re.sub(r'\sL[0-9]$', '', node) for node in nodes]
|
22 |
+
highlighted_nodes = [highlight_words(node, global_color_map) for node in cleaned_nodes]
|
23 |
+
return ['<br>'.join(textwrap.wrap(node, width=55)) for node in highlighted_nodes]
|
24 |
+
|
25 |
+
def get_levels_and_edges(nodes):
|
26 |
+
"""Determine levels and create edges dynamically."""
|
27 |
+
levels = {}
|
28 |
+
edges = []
|
29 |
+
for i, node in enumerate(nodes):
|
30 |
+
level = int(node.split()[-1][1])
|
31 |
+
levels[i] = level
|
32 |
+
|
33 |
+
# Create edges from level 0 to level 1 nodes
|
34 |
+
root_node = next(i for i, level in levels.items() if level == 0)
|
35 |
+
edges.extend((root_node, i) for i, level in levels.items() if level == 1)
|
36 |
+
|
37 |
+
return levels, edges
|
38 |
+
|
39 |
+
def calculate_positions(levels):
|
40 |
+
"""Calculate x, y positions for each node based on levels."""
|
41 |
+
positions = {}
|
42 |
+
level_heights = defaultdict(int)
|
43 |
+
y_offsets = {level: - (height - 1) / 2 for level, height in level_heights.items()}
|
44 |
+
|
45 |
+
for node, level in levels.items():
|
46 |
+
level_heights[level] += 1
|
47 |
+
x_gap = 2
|
48 |
+
l1_y_gap = 10
|
49 |
+
positions[node] = (-level * x_gap, y_offsets[level] * l1_y_gap)
|
50 |
+
y_offsets[level] += 1
|
51 |
+
|
52 |
+
return positions
|
53 |
+
|
54 |
+
def color_highlighted_words(node, color_map):
|
55 |
+
"""Highlight words in a wrapped node string."""
|
56 |
+
parts = re.split(r'(\{\{.*?\}\})', node)
|
57 |
+
colored_parts = [
|
58 |
+
f"<span style='color: {color_map.get(match.group(1), 'black')};'>{match.group(1)}</span>"
|
59 |
+
if (match := re.match(r'\{\{(.*?)\}\}', part))
|
60 |
+
else part
|
61 |
+
for part in parts
|
62 |
+
]
|
63 |
+
return ''.join(colored_parts)
|
64 |
+
|
65 |
+
def generate_subplot(paraphrased_sentence, scheme_sentences, highlight_info, common_grams, subplot_number):
|
66 |
+
"""Generate a subplot based on the input sentences and highlight info."""
|
67 |
+
# Combine nodes into one list with appropriate labels
|
68 |
+
nodes = [paraphrased_sentence + ' L0'] + [s + ' L1' for s in scheme_sentences]
|
69 |
+
|
70 |
+
# Apply LCS numbering and clean/wrap nodes
|
71 |
+
nodes = [apply_lcs_numbering(node, common_grams) for node in nodes]
|
72 |
+
wrapped_nodes = clean_and_wrap_nodes(nodes, highlight_info)
|
73 |
+
|
74 |
+
# Get levels and edges
|
75 |
+
levels, edges = get_levels_and_edges(nodes)
|
76 |
+
positions = calculate_positions(levels)
|
77 |
+
|
78 |
+
# Create figure
|
79 |
+
fig = go.Figure()
|
80 |
+
|
81 |
+
# Add nodes and edges to the figure
|
82 |
+
for i, node in enumerate(wrapped_nodes):
|
83 |
+
colored_node = color_highlighted_words(node, dict(highlight_info))
|
84 |
+
x, y = positions[i]
|
85 |
+
|
86 |
+
fig.add_trace(go.Scatter(
|
87 |
+
x=[-x], # Reflect the x coordinate
|
88 |
+
y=[y],
|
89 |
+
mode='markers',
|
90 |
+
marker=dict(size=10, color='blue'),
|
91 |
+
hoverinfo='none'
|
92 |
+
))
|
93 |
+
fig.add_annotation(
|
94 |
+
x=-x, # Reflect the x coordinate
|
95 |
+
y=y,
|
96 |
+
text=colored_node,
|
97 |
+
showarrow=False,
|
98 |
+
xshift=15,
|
99 |
+
align="center",
|
100 |
+
font=dict(size=12),
|
101 |
+
bordercolor='black',
|
102 |
+
borderwidth=1,
|
103 |
+
borderpad=2,
|
104 |
+
bgcolor='white',
|
105 |
+
width=300,
|
106 |
+
height=120
|
107 |
+
)
|
108 |
+
|
109 |
+
# Add edges and edge annotations
|
110 |
+
edge_texts = [
|
111 |
+
"Highest Entropy Masking", "Pseudo-random Masking", "Random Masking",
|
112 |
+
"Greedy Sampling", "Temperature Sampling", "Exponential Minimum Sampling",
|
113 |
+
"Inverse Transform Sampling", "Greedy Sampling", "Temperature Sampling",
|
114 |
+
"Exponential Minimum Sampling", "Inverse Transform Sampling",
|
115 |
+
"Greedy Sampling", "Temperature Sampling", "Exponential Minimum Sampling",
|
116 |
+
"Inverse Transform Sampling"
|
117 |
+
]
|
118 |
+
|
119 |
+
for i, edge in enumerate(edges):
|
120 |
+
x0, y0 = positions[edge[0]]
|
121 |
+
x1, y1 = positions[edge[1]]
|
122 |
+
fig.add_trace(go.Scatter(
|
123 |
+
x=[-x0, -x1], # Reflect the x coordinates
|
124 |
+
y=[y0, y1],
|
125 |
+
mode='lines',
|
126 |
+
line=dict(color='black', width=1)
|
127 |
+
))
|
128 |
+
|
129 |
+
# Add text annotation above the edge
|
130 |
+
mid_x = (-x0 + -x1) / 2
|
131 |
+
mid_y = (y0 + y1) / 2
|
132 |
+
fig.add_annotation(
|
133 |
+
x=mid_x,
|
134 |
+
y=mid_y + 0.8, # Adjust y position to shift text upwards
|
135 |
+
text=edge_texts[i], # Use the text specific to this edge
|
136 |
+
showarrow=False,
|
137 |
+
font=dict(size=12),
|
138 |
+
align="center"
|
139 |
+
)
|
140 |
+
|
141 |
+
fig.update_layout(
|
142 |
+
showlegend=False,
|
143 |
+
margin=dict(t=20, b=20, l=20, r=20),
|
144 |
+
xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
|
145 |
+
yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
|
146 |
+
width=1435,
|
147 |
+
height=1000
|
148 |
+
)
|
149 |
+
|
150 |
+
return fig
|
151 |
+
|
152 |
+
def generate_subplot1(paraphrased_sentence, scheme_sentences, highlight_info, common_grams):
|
153 |
+
return generate_subplot(paraphrased_sentence, scheme_sentences, highlight_info, common_grams, subplot_number=1)
|
154 |
+
|
155 |
+
def generate_subplot2(scheme_sentences, sampled_sentence, highlight_info, common_grams):
|
156 |
+
nodes = scheme_sentences + [s + ' L1' for s in sampled_sentence]
|
157 |
+
for i in range(len(scheme_sentences)):
|
158 |
+
nodes[i] += ' L0' # Reassign levels
|
159 |
+
|
160 |
+
# Apply LCS numbering and clean/wrap nodes
|
161 |
+
nodes = [apply_lcs_numbering(node, common_grams) for node in nodes]
|
162 |
+
wrapped_nodes = clean_and_wrap_nodes(nodes, highlight_info)
|
163 |
+
|
164 |
+
# Get levels and edges
|
165 |
+
levels, edges = get_levels_and_edges(nodes)
|
166 |
+
positions = calculate_positions(levels)
|
167 |
+
|
168 |
+
# Create figure
|
169 |
+
fig2 = go.Figure()
|
170 |
+
|
171 |
+
# Add nodes and edges to the figure
|
172 |
+
for i, node in enumerate(wrapped_nodes):
|
173 |
+
colored_node = color_highlighted_words(node, dict(highlight_info))
|
174 |
+
x, y = positions[i]
|
175 |
+
|
176 |
+
fig2.add_trace(go.Scatter(
|
177 |
+
x=[-x], # Reflect the x coordinate
|
178 |
+
y=[y],
|
179 |
+
mode='markers',
|
180 |
+
marker=dict(size=10, color='blue'),
|
181 |
+
hoverinfo='none'
|
182 |
+
))
|
183 |
+
fig2.add_annotation(
|
184 |
+
x=-x, # Reflect the x coordinate
|
185 |
+
y=y,
|
186 |
+
text=colored_node,
|
187 |
+
showarrow=False,
|
188 |
+
xshift=15,
|
189 |
+
align="center",
|
190 |
+
font=dict(size=12),
|
191 |
+
bordercolor='black',
|
192 |
+
borderwidth=1,
|
193 |
+
borderpad=2,
|
194 |
+
bgcolor='white',
|
195 |
+
width=450,
|
196 |
+
height=65
|
197 |
+
)
|
198 |
+
|
199 |
+
# Add edges and text above each edge
|
200 |
+
edge_texts = [
|
201 |
+
"Highest Entropy Masking", "Pseudo-random Masking", "Random Masking",
|
202 |
+
"Greedy Sampling", "Temperature Sampling", "Exponential Minimum Sampling",
|
203 |
+
"Inverse Transform Sampling", "Greedy Sampling", "Temperature Sampling",
|
204 |
+
"Exponential Minimum Sampling", "Inverse Transform Sampling",
|
205 |
+
"Greedy Sampling", "Temperature Sampling", "Exponential Minimum Sampling",
|
206 |
+
"Inverse Transform Sampling"
|
207 |
+
]
|
208 |
+
|
209 |
+
for i, edge in enumerate(edges):
|
210 |
+
x0, y0 = positions[edge[0]]
|
211 |
+
x1, y1 = positions[edge[1]]
|
212 |
+
fig2.add_trace(go.Scatter(
|
213 |
+
x=[-x0, -x1], # Reflect the x coordinates
|
214 |
+
y=[y0, y1],
|
215 |
+
mode='lines',
|
216 |
+
line=dict(color='black', width=1)
|
217 |
+
))
|
218 |
+
|
219 |
+
# Add text annotation above the edge
|
220 |
+
mid_x = (-x0 + -x1) / 2
|
221 |
+
mid_y = (y0 + y1) / 2
|
222 |
+
fig2.add_annotation(
|
223 |
+
x=mid_x,
|
224 |
+
y=mid_y + 0.8, # Adjust y position to shift text upwards
|
225 |
+
text=edge_texts[i], # Use the text specific to this edge
|
226 |
+
showarrow=False,
|
227 |
+
font=dict(size=12),
|
228 |
+
align="center"
|
229 |
+
)
|
230 |
+
|
231 |
+
fig2.update_layout(
|
232 |
+
showlegend=False,
|
233 |
+
margin=dict(t=20, b=20, l=20, r=20),
|
234 |
+
xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
|
235 |
+
yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
|
236 |
+
width=1435,
|
237 |
+
height=1000
|
238 |
+
)
|
239 |
+
|
240 |
+
return fig2
|
vocabulary_split.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
import torch
|
3 |
+
from transformers import AutoTokenizer, AutoModelForMaskedLM
|
4 |
+
|
5 |
+
# Load tokenizer and model once
|
6 |
+
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
|
7 |
+
model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")
|
8 |
+
|
9 |
+
def split_vocabulary(seed=42):
|
10 |
+
"""Split the vocabulary into permissible and non-permissible buckets."""
|
11 |
+
# Get the full vocabulary
|
12 |
+
vocab = list(tokenizer.get_vocab().items())
|
13 |
+
|
14 |
+
# Initialize the random number generator
|
15 |
+
random.seed(seed)
|
16 |
+
|
17 |
+
# Split the vocabulary
|
18 |
+
permissible = {}
|
19 |
+
non_permissible = {}
|
20 |
+
|
21 |
+
for word, index in vocab:
|
22 |
+
target_dict = permissible if random.random() < 0.5 else non_permissible
|
23 |
+
target_dict[word] = index
|
24 |
+
|
25 |
+
return permissible, non_permissible
|
26 |
+
|
27 |
+
def get_logits_for_mask(sentence):
|
28 |
+
"""Get the logits for the masked token in the sentence."""
|
29 |
+
inputs = tokenizer(sentence, return_tensors="pt")
|
30 |
+
mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
|
31 |
+
|
32 |
+
with torch.no_grad():
|
33 |
+
logits = model(**inputs).logits[0, mask_token_index, :]
|
34 |
+
|
35 |
+
return logits.squeeze()
|
36 |
+
|
37 |
+
def filter_logits(logits, permissible_indices):
|
38 |
+
"""Filter logits based on permissible indices."""
|
39 |
+
filtered_logits = logits.clone()
|
40 |
+
|
41 |
+
# Set logits to -inf for non-permissible indices
|
42 |
+
filtered_logits[~permissible_indices] = float('-inf')
|
43 |
+
|
44 |
+
return filtered_logits
|
45 |
+
|
46 |
+
# Usage example
|
47 |
+
permissible, _ = split_vocabulary(seed=42)
|
48 |
+
|
49 |
+
# Create permissible indices tensor
|
50 |
+
permissible_indices = torch.tensor([i in permissible.values() for i in range(len(tokenizer))], dtype=torch.bool)
|
51 |
+
|
52 |
+
# When sampling:
|
53 |
+
sentence = "The [MASK] is bright today."
|
54 |
+
logits = get_logits_for_mask(sentence)
|
55 |
+
filtered_logits = filter_logits(logits, permissible_indices)
|
56 |
+
|
watermark_detector.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import nltk
|
2 |
+
from nltk.corpus import stopwords
|
3 |
+
from transformers import AutoTokenizer, AutoModelForMaskedLM
|
4 |
+
from vocabulary_split import split_vocabulary, filter_logits
|
5 |
+
import torch
|
6 |
+
from lcs import find_common_subsequences
|
7 |
+
from paraphraser import generate_paraphrase
|
8 |
+
|
9 |
+
nltk.download('punkt', quiet=True)
|
10 |
+
nltk.download('stopwords', quiet=True)
|
11 |
+
|
12 |
+
tokenizer = AutoTokenizer.from_pretrained("bert-large-cased-whole-word-masking")
|
13 |
+
model = AutoModelForMaskedLM.from_pretrained("bert-large-cased-whole-word-masking")
|
14 |
+
|
15 |
+
permissible, _ = split_vocabulary(seed=42)
|
16 |
+
permissible_indices = torch.tensor([i in permissible.values() for i in range(len(tokenizer))])
|
17 |
+
|
18 |
+
def get_non_melting_points(original_sentence):
|
19 |
+
paraphrased_sentences = generate_paraphrase(original_sentence)
|
20 |
+
common_subsequences = find_common_subsequences(original_sentence, paraphrased_sentences)
|
21 |
+
return common_subsequences
|
22 |
+
|
23 |
+
def get_word_between_points(sentence, start_point, end_point):
|
24 |
+
words = nltk.word_tokenize(sentence)
|
25 |
+
stop_words = set(stopwords.words('english'))
|
26 |
+
start_index = sentence.index(start_point[1])
|
27 |
+
end_index = sentence.index(end_point[1])
|
28 |
+
|
29 |
+
for word in words[start_index+1:end_index]:
|
30 |
+
if word.lower() not in stop_words:
|
31 |
+
return word, words.index(word)
|
32 |
+
return None, None
|
33 |
+
|
34 |
+
def get_logits_for_mask(sentence):
|
35 |
+
inputs = tokenizer(sentence, return_tensors="pt")
|
36 |
+
mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
|
37 |
+
|
38 |
+
with torch.no_grad():
|
39 |
+
outputs = model(**inputs)
|
40 |
+
|
41 |
+
logits = outputs.logits
|
42 |
+
mask_token_logits = logits[0, mask_token_index, :]
|
43 |
+
return mask_token_logits.squeeze()
|
44 |
+
|
45 |
+
def detect_watermark(sentence):
|
46 |
+
non_melting_points = get_non_melting_points(sentence)
|
47 |
+
|
48 |
+
if len(non_melting_points) < 2:
|
49 |
+
return False, "Not enough non-melting points found."
|
50 |
+
|
51 |
+
word_to_check, index = get_word_between_points(sentence, non_melting_points[0], non_melting_points[1])
|
52 |
+
|
53 |
+
if word_to_check is None:
|
54 |
+
return False, "No suitable word found between non-melting points."
|
55 |
+
|
56 |
+
words = nltk.word_tokenize(sentence)
|
57 |
+
masked_sentence = ' '.join(words[:index] + ['[MASK]'] + words[index+1:])
|
58 |
+
|
59 |
+
logits = get_logits_for_mask(masked_sentence)
|
60 |
+
filtered_logits = filter_logits(logits, permissible_indices)
|
61 |
+
|
62 |
+
top_predictions = filtered_logits.argsort()[-5:]
|
63 |
+
predicted_words = [tokenizer.decode([i]) for i in top_predictions]
|
64 |
+
|
65 |
+
if word_to_check in predicted_words:
|
66 |
+
return True, f"Watermark detected. The word '{word_to_check}' is in the permissible vocabulary."
|
67 |
+
else:
|
68 |
+
return False, f"No watermark detected. The word '{word_to_check}' is not in the permissible vocabulary."
|
69 |
+
|
70 |
+
# Example usage
|
71 |
+
# if __name__ == "__main__":
|
72 |
+
# test_sentence = "The quick brown fox jumps over the lazy dog."
|
73 |
+
# is_watermarked, message = detect_watermark(test_sentence)
|
74 |
+
# print(f"Is the sentence watermarked? {is_watermarked}")
|
75 |
+
# print(f"Detection message: {message}")
|