import random

import gradio as gr
import nltk

nltk.download('stopwords', quiet=True)

from tree import generate_subplot1, generate_subplot2
from paraphraser import generate_paraphrase
from lcs import find_common_subsequences, find_common_gram_positions
from highlighter import highlight_common_words, highlight_common_words_dict, reparaphrased_sentences_html
from entailment import analyze_entailment
from masking_methods import mask_non_stopword, mask_non_stopword_pseudorandom, high_entropy_words
from sampling_methods import sample_word
from detectability import SentenceDetectabilityCalculator
from distortion import SentenceDistortionCalculator
from euclidean_distance import SentenceEuclideanDistanceCalculator
from threeD_plot import gen_three_D_plot
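# Pipeline sketch (inferred from the helper modules): paraphrase the prompt,
# keep paraphrases that entail it, find shared n-grams ("non-melting points"),
# mask candidate watermark slots, sample replacements, re-paraphrase to
# simulate an attack, then score distortion, detectability and distance.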
# Function for the Gradio interface
def model(prompt):
    user_prompt = prompt
    paraphrased_sentences = generate_paraphrase(user_prompt)
    analyzed_paraphrased_sentences, selected_sentences, discarded_sentences = analyze_entailment(
        user_prompt, paraphrased_sentences, 0.7
    )
    print(analyzed_paraphrased_sentences, selected_sentences, discarded_sentences)

    common_grams = find_common_subsequences(user_prompt, selected_sentences)
    subsequences = [subseq for _, subseq in common_grams]
    common_grams_position = find_common_gram_positions(selected_sentences, subsequences)
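    # Assumption: the pairs returned by find_common_subsequences carry the
    # shared n-grams ("non-melting points") in their second element -- text
    # common to the prompt and every accepted paraphrase, which the masking
    # steps are expected to leave intact.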
    masked_sentences = []
    masked_words = []
    masked_logits = []

    # Apply the three masking strategies to every paraphrase in a fixed
    # order; the slicing below depends on 3 masked variants per sentence.
    for sentence in paraphrased_sentences:
        for masking_fn in (
            mask_non_stopword,
            mask_non_stopword_pseudorandom,
            lambda s: high_entropy_words(s, common_grams),
        ):
            masked_sent, logits, words = masking_fn(sentence)
            masked_sentences.append(masked_sent)
            masked_words.append(words)
            masked_logits.append(logits)
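    # The three lists are parallel: each masking helper is assumed to return
    # (masked_sentence, logits, words), where `logits` scores the candidate
    # fill-ins that sample_word draws from below.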
    # Draw one replacement per sampling technique for each masked sentence,
    # again in a fixed order (4 sampled variants per masked sentence).
    sampled_sentences = []
    for masked_sent, words, logits in zip(masked_sentences, masked_words, masked_logits):
        for technique in ('inverse_transform', 'exponential_minimum', 'temperature', 'greedy'):
            sampled_sentences.append(
                sample_word(masked_sent, words, logits, sampling_technique=technique, temperature=1.0)
            )
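    # Sizing note (assumption): with 10 paraphrases, 3 maskings and 4
    # sampling techniques, sampled_sentences holds 10 * 3 * 4 = 120 entries;
    # the strides of 3 and 12 in the tree loop below rely on this grouping.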
    colors = ["red", "blue", "brown", "green"]

    def select_color():
        return random.choice(colors)

    highlight_info = [(word, select_color()) for _, word in common_grams]

    highlighted_user_prompt = highlight_common_words(common_grams, [user_prompt], "Non-melting Points in the User Prompt")
    highlighted_accepted_sentences = highlight_common_words_dict(common_grams, selected_sentences, "Paraphrased Sentences")
    highlighted_discarded_sentences = highlight_common_words_dict(common_grams, discarded_sentences, "Discarded Sentences")
    trees1 = []
    trees2 = []
    masked_index = 0
    sampled_index = 0
    for i, sentence in enumerate(paraphrased_sentences):
        # Each paraphrase owns the next 3 masked and next 12 sampled variants.
        next_masked_sentences = masked_sentences[masked_index:masked_index + 3]
        next_sampled_sentences = sampled_sentences[sampled_index:sampled_index + 12]

        trees1.append(generate_subplot1(sentence, next_masked_sentences, highlight_info, common_grams))
        trees2.append(generate_subplot2(next_masked_sentences, next_sampled_sentences, highlight_info, common_grams))

        masked_index += 3
        sampled_index += 12
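    # Gradio maps returned values to output components one-to-one, so trees1
    # and trees2 must each line up with the 10 Plot tabs declared in the UI.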
    reparaphrased_sentences = generate_paraphrase(sampled_sentences)

    # Group the re-paraphrased sentences into batches of 10 and render each
    # batch as one HTML block; a final batch with fewer than 10 sentences is
    # silently dropped.
    reparaphrased_sentences_list = []
    for i in range(0, len(reparaphrased_sentences), 10):
        batch = reparaphrased_sentences[i:i + 10]
        if len(batch) == 10:
            reparaphrased_sentences_list.append(reparaphrased_sentences_html(batch))
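    # Sizing note (assumption): if generate_paraphrase yields ~10 outputs per
    # input, the 120 sampled sentences produce ~1200 re-paraphrases, i.e. 120
    # full batches -- one per HTML tab in the UI below.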
    # Distortion scores for each re-paraphrased sentence.
    distortion_calculator = SentenceDistortionCalculator(user_prompt, reparaphrased_sentences)
    distortion_calculator.calculate_all_metrics()
    distortion_calculator.normalize_metrics()
    distortion_calculator.calculate_combined_distortion()
    distortion_list = list(distortion_calculator.get_combined_distortions().values())

    # Detectability scores for each re-paraphrased sentence.
    detectability_calculator = SentenceDetectabilityCalculator(user_prompt, reparaphrased_sentences)
    detectability_calculator.calculate_all_metrics()
    detectability_calculator.normalize_metrics()
    detectability_calculator.calculate_combined_detectability()
    detectability_list = list(detectability_calculator.get_combined_detectabilities().values())

    # Euclidean distances, read from the euclidean calculator itself
    # (assumption: get_normalized_metrics returns a dict keyed like the other
    # calculators' getters).
    euclidean_dist_calculator = SentenceEuclideanDistanceCalculator(user_prompt, reparaphrased_sentences)
    euclidean_dist_calculator.calculate_all_metrics()
    euclidean_dist_calculator.normalize_metrics()
    euclidean_dist_list = list(euclidean_dist_calculator.get_normalized_metrics().values())
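    # Each re-paraphrase becomes one point in (detectability, distortion,
    # euclidean distance) space; the "sweet spot" plotted below is presumably
    # where the watermark stays detectable while distorting the text least.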
    three_D_plot = gen_three_D_plot(detectability_list, distortion_list, euclidean_dist_list)

    # Order must match the `outputs` list wired to submit_button below.
    return (
        [highlighted_user_prompt, highlighted_accepted_sentences, highlighted_discarded_sentences]
        + trees1 + trees2 + reparaphrased_sentences_list + [three_D_plot]
    )
with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
    gr.Markdown("# **AIISC Watermarking Model**")

    with gr.Row():
        user_input = gr.Textbox(label="User Prompt")

    with gr.Row():
        submit_button = gr.Button("Submit")
        clear_button = gr.Button("Clear")

    with gr.Row():
        highlighted_user_prompt = gr.HTML()

    with gr.Row():
        with gr.Tabs():
            with gr.TabItem("Paraphrased Sentences"):
                highlighted_accepted_sentences = gr.HTML()
            with gr.TabItem("Discarded Sentences"):
                highlighted_discarded_sentences = gr.HTML()
    # Trees for masked sentences.
    with gr.Row():
        gr.Markdown("### Where to Watermark?")
    with gr.Row():
        with gr.Tabs():
            tree1_tabs = []
            for i in range(10):  # Adjust this range according to the number of trees
                with gr.TabItem(f"Sentence {i + 1}"):
                    tree1_tabs.append(gr.Plot())

    # Trees for sampled sentences.
    with gr.Row():
        gr.Markdown("### How to Watermark?")
    with gr.Row():
        with gr.Tabs():
            tree2_tabs = []
            for i in range(10):  # Adjust this range according to the number of trees
                with gr.TabItem(f"Sentence {i + 1}"):
                    tree2_tabs.append(gr.Plot())
    # Re-paraphrased sentences, one HTML block (batch of 10) per tab.
    with gr.Row():
        gr.Markdown("### Re-paraphrased Sentences")
    with gr.Row():
        with gr.Tabs():
            reparaphrased_sentences_tabs = []
            for i in range(120):  # 120 tabs for 120 batches of sentences
                with gr.TabItem(f"Sentence {i + 1}"):
                    reparaphrased_sentences_tabs.append(gr.HTML())

    with gr.Row():
        gr.Markdown("### 3D Plot for Sweet Spot")
    with gr.Row():
        three_D_plot = gr.Plot()
    all_outputs = (
        [highlighted_user_prompt, highlighted_accepted_sentences, highlighted_discarded_sentences]
        + tree1_tabs + tree2_tabs + reparaphrased_sentences_tabs + [three_D_plot]
    )

    submit_button.click(model, inputs=user_input, outputs=all_outputs)
    clear_button.click(lambda: "", inputs=None, outputs=user_input)
    # A Gradio callback must return one value per output component; a list of
    # Nones of matching length clears every component at once.
    clear_button.click(lambda: [None] * len(all_outputs), inputs=None, outputs=all_outputs)

demo.launch(share=True)
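# Running this file starts the Gradio server locally; share=True additionally
# exposes a temporary public gradio.live link.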