jgyasu committed
Commit be52172
1 Parent(s): 5995f2a

Update app.py

Files changed (1)
  1. app.py +70 -68
app.py CHANGED
@@ -7,10 +7,6 @@ Original file is located at
     https://colab.research.google.com/drive/1pFGR4uvXMMWVJFQeFmn--arumSxqa5Yy
 """
 
-
-import gradio as gr
-
-# import streamlit as st
 from transformers import AutoTokenizer
 from transformers import AutoModelForSeq2SeqLM
 import plotly.graph_objects as go
@@ -35,7 +31,7 @@ import scipy.stats
 import torch
 from transformers import GPT2LMHeadModel
 import seaborn as sns
-from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForMaskedLM
 # from colorama import Fore, Style
 # import openai
 import random
@@ -44,8 +40,11 @@ from termcolor import colored
 import nltk
 from nltk.translate.bleu_score import sentence_bleu
 from transformers import BertTokenizer, BertModel
+import graphviz
+import gradio as gr
+
+
 
-import nltk
 nltk.download('stopwords')
 
 # Function to Initialize the Model
@@ -301,12 +300,6 @@ def generate_paraphrase(question):
 
 question = "Following the declaration of the State of Israel in 1948, neighboring Arab states invaded. The war ended with Israel controlling a significant portion of the territory. Many Palestinians became refugees."
 
-import nltk
-nltk.download('punkt')
-import re
-from nltk.corpus import stopwords
-from nltk.tokenize import word_tokenize
-
 import re
 from nltk.corpus import stopwords
 
@@ -373,52 +366,25 @@ def find_common_subsequences(sentence, str_list):
 
     return common_grams
 
-question = '''the colorado republican party sent a mass email last week with the subject line "god hates pride"'''
-res = generate_paraphrase(question)
-
-res
-
-common_grams = find_common_subsequences(question, res[0:3])
-common_grams
-
-common_gram_words = [word for gram in common_grams for word in gram.split()]
-common_gram_words
-
 def llm_output(prompt):
-    # sequences = text_generator(prompt)
-    # gen_text = sequences[0]["generated_text"]
-    # sentences = gen_text.split('.')
-    # # first_sentence = get_first_sentence(gen_text[len(prompt):])
-    # return gen_text,sentences[-3]
-    return prompt,prompt
-
-import re
-import html
+    return prompt, prompt
 
 def highlight_phrases_with_colors(sentences, phrases):
-    color_map = {}  # Dictionary to store color assignments for each phrase
-    color_index = 0  # Index to assign colors sequentially
-
-    # Generate HTML for highlighting each sentence
+    color_map = {}
+    color_index = 0
     highlighted_html = []
     idx = 1
     for sentence in sentences:
         sentence_with_idx = f"{idx}. {sentence}"
         idx += 1
-        highlighted_sentence = html.escape(sentence_with_idx)
+        highlighted_sentence = sentence_with_idx
         phrase_count = 0
-
-        # Split sentence into words to apply numbering
         words = re.findall(r'\b\w+\b', sentence)
-        word_index = 1  # Index to track words
-
-        # Highlight each phrase with a unique color and number
+        word_index = 1
         for phrase in phrases:
             if phrase not in color_map:
-                # Assign a new color if the phrase hasn't been encountered before
                 color_map[phrase] = f'hsl({color_index * 60 % 360}, 70%, 80%)'
                 color_index += 1
-
             escaped_phrase = re.escape(phrase)
             pattern = rf'\b{escaped_phrase}\b'
             highlighted_sentence, num_replacements = re.subn(
@@ -436,34 +402,68 @@ def highlight_phrases_with_colors(sentences, phrases):
             )
             if num_replacements > 0:
                 phrase_count += 1
-                word_index += 1  # Increment word index after each replacement
-
+                word_index += 1
         highlighted_html.append(highlighted_sentence)
-
-    # Join sentences with line breaks
     final_html = "<br><br>".join(highlighted_html)
-
-    # Wrap in a container div for styling
     return f'''
-    <div style="border: solid 1px #; padding: 16px; background-color: #FFFFFF; color: #374151; box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); border-radius: 12px;">
-        <h3 style="margin-top: 0; font-size: 1.25em; color: #111827;">Paraphrased And Highlighted Text</h3>
-        <div style="background-color: #F5F5F5; line-height: 1.6; padding: 15px; border-radius: 12px;">{final_html}</div>
+    <div style="border: solid 1px #; padding: 16px; background-color: #FFFFFF; color: #374151; box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); border-radius: 2px;">
+        <h3 style="margin-top: 0; font-size: 1em; color: #111827;">Paraphrased And Highlighted Text</h3>
+        <div style="background-color: #F5F5F5; line-height: 1.6; padding: 15px; border-radius: 2px;">{final_html}</div>
     </div>
     '''
 
+# Masking Model
+def mask_non_stopword(sentence):
+    stop_words = set(stopwords.words('english'))
+    words = sentence.split()
+    non_stop_words = [word for word in words if word.lower() not in stop_words]
+    if not non_stop_words:
+        return sentence
+    word_to_mask = random.choice(non_stop_words)
+    masked_sentence = sentence.replace(word_to_mask, '[MASK]', 1)
+    return masked_sentence
+
+# Load tokenizer and model for masked language model
+tokenizer = AutoTokenizer.from_pretrained("bert-large-cased-whole-word-masking")
+model = AutoModelForMaskedLM.from_pretrained("bert-large-cased-whole-word-masking")
+fill_mask = pipeline("fill-mask", model=model, tokenizer=tokenizer)
+
+def mask(sentence):
+    predictions = fill_mask(sentence)
+    masked_sentences = [predictions[i]['sequence'] for i in range(len(predictions))]
+    return masked_sentences
+
+# Function to generate the tree and return the Graphviz source
+def generate_tree(original_sentence: str) -> str:
+    paraphrased_sentences = generate_paraphrase(original_sentence)
+    first_paraphrased_sentence = paraphrased_sentences[0]
+    masked_sentence = mask_non_stopword(first_paraphrased_sentence)
+    masked_versions = mask(masked_sentence)
+    dot = graphviz.Digraph()
+    dot.attr(rankdir='LR', size='8,10!', dpi='72')
+    dot.node("Original", original_sentence)
+    dot.node("Paraphrased", first_paraphrased_sentence)
+    dot.edge("Original", "Paraphrased")
+    for i, masked in enumerate(masked_versions):
+        node_id = f"Masked_{i}"
+        dot.node(node_id, masked)
+        dot.edge("Paraphrased", node_id)
+    return masked_sentence, dot.source
+
+# Function for the Gradio interface
 def model(prompt):
-    generated,sentence = llm_output(prompt)
+    generated, sentence = llm_output(prompt)
     res = generate_paraphrase(sentence)
-    common_subs = longest_common_subss(sentence,res)
-    # non_melting = non_melting_points(sentence, res)
-    common_grams = find_common_subsequences(sentence,res)
-    # common_gram_words = [word for gram in common_grams for word in gram.split()]
+    common_subs = longest_common_subss(sentence, res)
+    common_grams = find_common_subsequences(sentence, res)
     for i in range(len(common_subs)):
         common_subs[i]["Paraphrased Sentence"] = res[i]
-    result = highlight_phrases_with_colors(res,common_grams)
-    return generated, result
-
-    # model(question)
+    result = highlight_phrases_with_colors(res, common_grams)
+    masked_sentence, tree_source = generate_tree(sentence)
+    graph = graphviz.Source(tree_source)
+    svg_content = graph.pipe(format='svg').decode('utf-8')
+    # tree = f'<div style="width: 100%; overflow-x: auto;">{svg_content}</div>'
+    return generated, generated, result, masked_sentence, svg_content
 
 with gr.Blocks(theme = gr.themes.Monochrome()) as demo:
     gr.Markdown("# Paraphrases the Text and Highlights the Non-melting Points")
@@ -485,13 +485,15 @@ with gr.Blocks(theme = gr.themes.Monochrome()) as demo:
         html_output = gr.HTML()
 
     with gr.Row():
+        masked_sentence = gr.Textbox(label="Masked Sentence")
 
-    submit_button.click(model, inputs=user_input, outputs=[ai_output, html_output])
-    clear_button.click(lambda: "", inputs=None, outputs=user_input)
-    clear_button.click(lambda: "", inputs=None, outputs=[ai_output, selected_sentence, html_output])
-
-# Launch the demo
-demo.launch()
+    with gr.Row():
+        tree = gr.HTML(label="Tree")
 
+    submit_button.click(model, inputs=user_input, outputs=[ai_output, selected_sentence, html_output, masked_sentence, tree])
+    clear_button.click(lambda: "", inputs=None, outputs=user_input)
+    clear_button.click(lambda: "", inputs=None, outputs=[ai_output, selected_sentence, html_output, masked_sentence, tree])
 
+# Launch the demo
+demo.launch(share=True)
 
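
Two quick sketches of the new pieces for readers skimming the diff. First, the masking step: mask_non_stopword hides one random non-stopword and the fill-mask pipeline proposes replacements. A minimal standalone sketch (a hypothetical script, not part of this commit), assuming the same bert-large-cased-whole-word-masking checkpoint; app.py builds the pipeline from explicit tokenizer/model objects, while this sketch simply passes the checkpoint name.

# Standalone sketch of the masking step introduced in this commit (hypothetical
# script, not part of app.py). Assumes the bert-large-cased-whole-word-masking
# checkpoint and the NLTK stopwords corpus are available.
import random

import nltk
from nltk.corpus import stopwords
from transformers import pipeline

nltk.download("stopwords", quiet=True)

def mask_non_stopword(sentence: str) -> str:
    # Replace one randomly chosen non-stopword with BERT's [MASK] token.
    stop_words = set(stopwords.words("english"))
    non_stop_words = [w for w in sentence.split() if w.lower() not in stop_words]
    if not non_stop_words:
        return sentence
    return sentence.replace(random.choice(non_stop_words), "[MASK]", 1)

fill_mask = pipeline("fill-mask", model="bert-large-cased-whole-word-masking")

masked = mask_non_stopword("Many Palestinians became refugees.")
# Each prediction dict carries 'sequence', 'score', 'token' and 'token_str'.
candidates = [p["sequence"] for p in fill_mask(masked)]
print(masked)
print(candidates)

The pipeline returns its top candidates (five by default), which generate_tree then fans out into the Masked_{i} nodes.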
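Second, the tree rendering: generate_tree returns Graphviz DOT source, and model() pipes it to SVG markup that a gr.HTML component displays. Below is a minimal sketch of that path with placeholder strings (again hypothetical, not part of app.py); it needs the system Graphviz binaries in addition to the Python graphviz package. Note that the module-level masked-LM variable model is later shadowed by the Gradio callback def model(prompt); the fill-mask pipeline keeps its own reference to the network, so inference still works.

# Standalone sketch of the tree-to-SVG path used in model() (hypothetical script,
# not part of app.py). The sentences are placeholders, not real paraphraser output.
import graphviz
import gradio as gr

def build_tree_svg(original, paraphrase, masked_versions):
    dot = graphviz.Digraph()
    dot.attr(rankdir="LR")  # left-to-right layout, as in generate_tree
    dot.node("Original", original)
    dot.node("Paraphrased", paraphrase)
    dot.edge("Original", "Paraphrased")
    for i, masked in enumerate(masked_versions):
        node_id = f"Masked_{i}"
        dot.node(node_id, masked)
        dot.edge("Paraphrased", node_id)
    # pipe() returns bytes, so decode before handing the markup to gr.HTML.
    return graphviz.Source(dot.source).pipe(format="svg").decode("utf-8")

with gr.Blocks() as demo:
    tree = gr.HTML(label="Tree")
    render = gr.Button("Render")
    render.click(
        lambda: build_tree_svg(
            "original sentence",
            "paraphrased sentence",
            ["masked candidate 1", "masked candidate 2"],
        ),
        inputs=None,
        outputs=tree,
    )

if __name__ == "__main__":
    demo.launch()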