Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -7,10 +7,6 @@ Original file is located at
|
|
7 |
https://colab.research.google.com/drive/1pFGR4uvXMMWVJFQeFmn--arumSxqa5Yy
|
8 |
"""
|
9 |
|
10 |
-
|
11 |
-
import gradio as gr
|
12 |
-
|
13 |
-
# import streamlit as st
|
14 |
from transformers import AutoTokenizer
|
15 |
from transformers import AutoModelForSeq2SeqLM
|
16 |
import plotly.graph_objects as go
|
@@ -35,7 +31,7 @@ import scipy.stats
|
|
35 |
import torch
|
36 |
from transformers import GPT2LMHeadModel
|
37 |
import seaborn as sns
|
38 |
-
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
39 |
# from colorama import Fore, Style
|
40 |
# import openai
|
41 |
import random
|
@@ -44,8 +40,11 @@ from termcolor import colored
|
|
44 |
import nltk
|
45 |
from nltk.translate.bleu_score import sentence_bleu
|
46 |
from transformers import BertTokenizer, BertModel
|
|
|
|
|
|
|
|
|
47 |
|
48 |
-
import nltk
|
49 |
nltk.download('stopwords')
|
50 |
|
51 |
# Function to Initialize the Model
|
@@ -301,12 +300,6 @@ def generate_paraphrase(question):
|
|
301 |
|
302 |
question = "Following the declaration of the State of Israel in 1948, neighboring Arab states invaded. The war ended with Israel controlling a significant portion of the territory. Many Palestinians became refugees."
|
303 |
|
304 |
-
import nltk
|
305 |
-
nltk.download('punkt')
|
306 |
-
import re
|
307 |
-
from nltk.corpus import stopwords
|
308 |
-
from nltk.tokenize import word_tokenize
|
309 |
-
|
310 |
import re
|
311 |
from nltk.corpus import stopwords
|
312 |
|
@@ -373,52 +366,25 @@ def find_common_subsequences(sentence, str_list):
|
|
373 |
|
374 |
return common_grams
|
375 |
|
376 |
-
question = '''the colorado republican party sent a mass email last week with the subject line "god hates pride"'''
|
377 |
-
res = generate_paraphrase(question)
|
378 |
-
|
379 |
-
res
|
380 |
-
|
381 |
-
common_grams = find_common_subsequences(question, res[0:3])
|
382 |
-
common_grams
|
383 |
-
|
384 |
-
common_gram_words = [word for gram in common_grams for word in gram.split()]
|
385 |
-
common_gram_words
|
386 |
-
|
387 |
def llm_output(prompt):
|
388 |
-
|
389 |
-
# gen_text = sequences[0]["generated_text"]
|
390 |
-
# sentences = gen_text.split('.')
|
391 |
-
# # first_sentence = get_first_sentence(gen_text[len(prompt):])
|
392 |
-
# return gen_text,sentences[-3]
|
393 |
-
return prompt,prompt
|
394 |
-
|
395 |
-
import re
|
396 |
-
import html
|
397 |
|
398 |
def highlight_phrases_with_colors(sentences, phrases):
|
399 |
-
color_map = {}
|
400 |
-
color_index = 0
|
401 |
-
|
402 |
-
# Generate HTML for highlighting each sentence
|
403 |
highlighted_html = []
|
404 |
idx = 1
|
405 |
for sentence in sentences:
|
406 |
sentence_with_idx = f"{idx}. {sentence}"
|
407 |
idx += 1
|
408 |
-
highlighted_sentence =
|
409 |
phrase_count = 0
|
410 |
-
|
411 |
-
# Split sentence into words to apply numbering
|
412 |
words = re.findall(r'\b\w+\b', sentence)
|
413 |
-
word_index = 1
|
414 |
-
|
415 |
-
# Highlight each phrase with a unique color and number
|
416 |
for phrase in phrases:
|
417 |
if phrase not in color_map:
|
418 |
-
# Assign a new color if the phrase hasn't been encountered before
|
419 |
color_map[phrase] = f'hsl({color_index * 60 % 360}, 70%, 80%)'
|
420 |
color_index += 1
|
421 |
-
|
422 |
escaped_phrase = re.escape(phrase)
|
423 |
pattern = rf'\b{escaped_phrase}\b'
|
424 |
highlighted_sentence, num_replacements = re.subn(
|
@@ -436,34 +402,68 @@ def highlight_phrases_with_colors(sentences, phrases):
|
|
436 |
)
|
437 |
if num_replacements > 0:
|
438 |
phrase_count += 1
|
439 |
-
word_index += 1
|
440 |
-
|
441 |
highlighted_html.append(highlighted_sentence)
|
442 |
-
|
443 |
-
# Join sentences with line breaks
|
444 |
final_html = "<br><br>".join(highlighted_html)
|
445 |
-
|
446 |
-
# Wrap in a container div for styling
|
447 |
return f'''
|
448 |
-
<div style="border: solid 1px #; padding: 16px; background-color: #FFFFFF; color: #374151; box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); border-radius:
|
449 |
-
<h3 style="margin-top: 0; font-size:
|
450 |
-
<div style="background-color: #F5F5F5; line-height: 1.6; padding: 15px; border-radius:
|
451 |
</div>
|
452 |
'''
|
453 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
454 |
def model(prompt):
|
455 |
-
generated,sentence = llm_output(prompt)
|
456 |
res = generate_paraphrase(sentence)
|
457 |
-
common_subs = longest_common_subss(sentence,res)
|
458 |
-
|
459 |
-
common_grams = find_common_subsequences(sentence,res)
|
460 |
-
# common_gram_words = [word for gram in common_grams for word in gram.split()]
|
461 |
for i in range(len(common_subs)):
|
462 |
common_subs[i]["Paraphrased Sentence"] = res[i]
|
463 |
-
result = highlight_phrases_with_colors(res,common_grams)
|
464 |
-
|
465 |
-
|
466 |
-
|
|
|
|
|
467 |
|
468 |
with gr.Blocks(theme = gr.themes.Monochrome()) as demo:
|
469 |
gr.Markdown("# Paraphrases the Text and Highlights the Non-melting Points")
|
@@ -485,13 +485,15 @@ with gr.Blocks(theme = gr.themes.Monochrome()) as demo:
|
|
485 |
html_output = gr.HTML()
|
486 |
|
487 |
with gr.Row():
|
|
|
488 |
|
489 |
-
|
490 |
-
|
491 |
-
clear_button.click(lambda: "", inputs=None, outputs=[ai_output, selected_sentence, html_output])
|
492 |
-
|
493 |
-
# Launch the demo
|
494 |
-
demo.launch()
|
495 |
|
|
|
|
|
|
|
496 |
|
|
|
|
|
497 |
|
|
|
7 |
https://colab.research.google.com/drive/1pFGR4uvXMMWVJFQeFmn--arumSxqa5Yy
|
8 |
"""
|
9 |
|
|
|
|
|
|
|
|
|
10 |
from transformers import AutoTokenizer
|
11 |
from transformers import AutoModelForSeq2SeqLM
|
12 |
import plotly.graph_objects as go
|
|
|
31 |
import torch
|
32 |
from transformers import GPT2LMHeadModel
|
33 |
import seaborn as sns
|
34 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForMaskedLM
|
35 |
# from colorama import Fore, Style
|
36 |
# import openai
|
37 |
import random
|
|
|
40 |
import nltk
|
41 |
from nltk.translate.bleu_score import sentence_bleu
|
42 |
from transformers import BertTokenizer, BertModel
|
43 |
+
import graphviz
|
44 |
+
import gradio as gr
|
45 |
+
|
46 |
+
|
47 |
|
|
|
48 |
nltk.download('stopwords')
|
49 |
|
50 |
# Function to Initialize the Model
|
|
|
300 |
|
301 |
question = "Following the declaration of the State of Israel in 1948, neighboring Arab states invaded. The war ended with Israel controlling a significant portion of the territory. Many Palestinians became refugees."
|
302 |
|
|
|
|
|
|
|
|
|
|
|
|
|
303 |
import re
|
304 |
from nltk.corpus import stopwords
|
305 |
|
|
|
366 |
|
367 |
return common_grams
|
368 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
369 |
def llm_output(prompt):
    """Placeholder for the text-generation step.

    No language model is invoked here: the prompt is echoed back twice so
    that callers can unpack the result as ``(generated, sentence)``.

    Args:
        prompt: Text entered by the user.

    Returns:
        The 2-tuple ``(prompt, prompt)``.
    """
    return prompt, prompt
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
371 |
|
372 |
def highlight_phrases_with_colors(sentences, phrases):
|
373 |
+
color_map = {}
|
374 |
+
color_index = 0
|
|
|
|
|
375 |
highlighted_html = []
|
376 |
idx = 1
|
377 |
for sentence in sentences:
|
378 |
sentence_with_idx = f"{idx}. {sentence}"
|
379 |
idx += 1
|
380 |
+
highlighted_sentence = sentence_with_idx
|
381 |
phrase_count = 0
|
|
|
|
|
382 |
words = re.findall(r'\b\w+\b', sentence)
|
383 |
+
word_index = 1
|
|
|
|
|
384 |
for phrase in phrases:
|
385 |
if phrase not in color_map:
|
|
|
386 |
color_map[phrase] = f'hsl({color_index * 60 % 360}, 70%, 80%)'
|
387 |
color_index += 1
|
|
|
388 |
escaped_phrase = re.escape(phrase)
|
389 |
pattern = rf'\b{escaped_phrase}\b'
|
390 |
highlighted_sentence, num_replacements = re.subn(
|
|
|
402 |
)
|
403 |
if num_replacements > 0:
|
404 |
phrase_count += 1
|
405 |
+
word_index += 1
|
|
|
406 |
highlighted_html.append(highlighted_sentence)
|
|
|
|
|
407 |
final_html = "<br><br>".join(highlighted_html)
|
|
|
|
|
408 |
return f'''
|
409 |
+
<div style="border: solid 1px #; padding: 16px; background-color: #FFFFFF; color: #374151; box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); border-radius: 2px;">
|
410 |
+
<h3 style="margin-top: 0; font-size: 1em; color: #111827;">Paraphrased And Highlighted Text</h3>
|
411 |
+
<div style="background-color: #F5F5F5; line-height: 1.6; padding: 15px; border-radius: 2px;">{final_html}</div>
|
412 |
</div>
|
413 |
'''
|
414 |
|
415 |
+
# Masking Model
|
416 |
+
def mask_non_stopword(sentence):
    """Replace one randomly chosen non-stop-word in *sentence* with ``[MASK]``.

    Words are located by position with a regex rather than ``str.split`` +
    ``str.replace``: a substring replace can hit the chosen text inside an
    earlier, longer word, and whitespace tokens keep their punctuation
    ("the," is not recognised as a stop word, "territory." would be masked
    together with its full stop).

    Args:
        sentence: The text to mask.

    Returns:
        *sentence* with exactly one word replaced by ``[MASK]``; all other
        characters (whitespace, punctuation) are preserved. If every word is
        an English stop word, or there are no words, *sentence* is returned
        unchanged.
    """
    stop_words = set(stopwords.words('english'))

    # Keep contractions such as "don't" in one piece so they can match the
    # contracted forms in the stop-word list.
    candidates = [
        match
        for match in re.finditer(r"\w+(?:'\w+)*", sentence)
        if match.group().lower() not in stop_words
    ]
    if not candidates:
        return sentence

    target = random.choice(candidates)
    # Splice by span so the exact occurrence that was picked is the one masked.
    return sentence[:target.start()] + '[MASK]' + sentence[target.end():]
|
425 |
+
|
426 |
+
# Load the tokenizer and masked-language model and wrap them in the
# fill-mask pipeline that mask() calls.
# NOTE(review): `pipeline` is not among the imports visible around this code;
# importing it here is harmless if it is already imported elsewhere.
# NOTE(review): the module-level name `model` is rebound further down by
# `def model(prompt)`; the pipeline is built from the object before that
# happens. Also confirm no earlier code relies on a different global
# `tokenizer`/`model`, since these assignments overwrite them.
from transformers import pipeline

tokenizer = AutoTokenizer.from_pretrained("bert-large-cased-whole-word-masking")
model = AutoModelForMaskedLM.from_pretrained("bert-large-cased-whole-word-masking")
fill_mask = pipeline("fill-mask", model=model, tokenizer=tokenizer)
|
430 |
+
|
431 |
+
def mask(sentence):
    """Return the fill-mask pipeline's completions of *sentence*.

    Args:
        sentence: Text containing a ``[MASK]`` placeholder, as produced by
            ``mask_non_stopword``.

    Returns:
        A list with the ``'sequence'`` field of every prediction, in the
        order the pipeline returned them. An empty list is returned when
        *sentence* has no ``[MASK]`` (``mask_non_stopword`` leaves the text
        untouched when it finds nothing to mask), because there is nothing
        for the pipeline to fill in.
    """
    if '[MASK]' not in sentence:
        return []
    predictions = fill_mask(sentence)
    return [prediction['sequence'] for prediction in predictions]
|
435 |
+
|
436 |
+
# Function to generate the tree and return the Graphviz source
|
437 |
+
def generate_tree(original_sentence: str) -> "tuple[str, str]":
    """Build a Graphviz tree: original -> first paraphrase -> mask fill-ins.

    The first paraphrase of *original_sentence* has one word masked, the
    masked sentence is passed to ``mask``, and each returned completion
    becomes a child node of the paraphrase.

    Args:
        original_sentence: The sentence to paraphrase.

    Returns:
        A 2-tuple ``(masked_sentence, dot_source)``: the masked paraphrase
        and the DOT source text of the graph.

    Raises:
        IndexError: If ``generate_paraphrase`` returns an empty sequence.
    """
    paraphrased_sentences = generate_paraphrase(original_sentence)
    first_paraphrased_sentence = paraphrased_sentences[0]
    masked_sentence = mask_non_stopword(first_paraphrased_sentence)
    masked_versions = mask(masked_sentence)

    dot = graphviz.Digraph()
    dot.attr(rankdir='LR', size='8,10!', dpi='72')
    dot.node("Original", original_sentence)
    dot.node("Paraphrased", first_paraphrased_sentence)
    dot.edge("Original", "Paraphrased")

    # One child per completion, hanging off the paraphrase node.
    for i, masked in enumerate(masked_versions):
        node_id = f"Masked_{i}"
        dot.node(node_id, masked)
        dot.edge("Paraphrased", node_id)

    return masked_sentence, dot.source
|
452 |
+
|
453 |
+
# Function for the Gradio interface
|
454 |
def model(prompt):
    """Gradio callback: paraphrase the prompt, highlight overlaps, draw the tree.

    Args:
        prompt: Text entered by the user.

    Returns:
        A 5-tuple matching the outputs wired to the submit button:
        ``(generated, generated, highlighted_html, masked_sentence, svg)``.
    """
    generated, sentence = llm_output(prompt)
    res = generate_paraphrase(sentence)
    common_subs = longest_common_subss(sentence, res)
    common_grams = find_common_subsequences(sentence, res)

    # Pair each entry with its paraphrase; zip stops at the shorter sequence
    # instead of raising IndexError when the two lengths differ.
    # NOTE(review): common_subs is not part of the returned values.
    for entry, paraphrase in zip(common_subs, res):
        entry["Paraphrased Sentence"] = paraphrase

    result = highlight_phrases_with_colors(res, common_grams)

    # NOTE(review): generate_tree calls generate_paraphrase again on the same
    # sentence, so the paraphraser runs twice per request.
    masked_sentence, tree_source = generate_tree(sentence)
    graph = graphviz.Source(tree_source)
    svg_content = graph.pipe(format='svg').decode('utf-8')

    # NOTE(review): `generated` is returned for both of the first two
    # outputs; `sentence` is equal today because llm_output echoes the
    # prompt twice. Confirm which one the second output should show.
    return generated, generated, result, masked_sentence, svg_content
|
467 |
|
468 |
with gr.Blocks(theme = gr.themes.Monochrome()) as demo:
|
469 |
gr.Markdown("# Paraphrases the Text and Highlights the Non-melting Points")
|
|
|
485 |
html_output = gr.HTML()
|
486 |
|
487 |
with gr.Row():
|
488 |
+
masked_sentence = gr.Textbox(label="Masked Sentence")
|
489 |
|
490 |
+
with gr.Row():
|
491 |
+
tree = gr.HTML(label="Tree")
|
|
|
|
|
|
|
|
|
492 |
|
493 |
+
submit_button.click(model, inputs=user_input, outputs=[ai_output, selected_sentence, html_output, masked_sentence, tree])
clear_button.click(lambda: "", inputs=None, outputs=user_input)
# A handler wired to several outputs must return one value per component,
# so clear the five result widgets with five empty strings, not a single "".
clear_button.click(lambda: ("",) * 5, inputs=None, outputs=[ai_output, selected_sentence, html_output, masked_sentence, tree])
|
496 |
|
497 |
+
# Launch the demo
|
498 |
+
demo.launch(share=True)
|
499 |
|