Create app.py
app.py
ADDED
@@ -0,0 +1,171 @@
import nltk
import ntlk_utils  # NLTK data is downloaded in a separate file (a sketch follows the listing)
from nltk.corpus import wordnet as wn
from nltk.tokenize import sent_tokenize
from nltk.corpus import stopwords
from time import sleep

from flashtext import KeywordProcessor
from pprint import pprint
import random
import pke
import traceback

import json
import requests
import string
import re
import itertools

import streamlit as st
from transformers import T5ForConditionalGeneration, T5Tokenizer
from transformers import pipeline

import torch
import numpy as np


def set_seed(seed: int):
    # Make generation reproducible across random, numpy, and torch.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


set_seed(42)

# Run on GPU when available; the encodings below are moved to the same device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# T5 for summarization and a SQuAD-fine-tuned T5 for question generation.
summary_model = T5ForConditionalGeneration.from_pretrained('t5-base').to(device)
summary_tokenizer = T5Tokenizer.from_pretrained('t5-base')

question_model = T5ForConditionalGeneration.from_pretrained('ramsrigouthamg/t5_squad_v1').to(device)
question_tokenizer = T5Tokenizer.from_pretrained('ramsrigouthamg/t5_squad_v1')

# summarizer = pipeline("summarization", model="facebook/bart-large-cnn")  # alternative summarizer (see sketch after the listing)


st.header("Question Creation")
st.subheader("Enter the text and click Generate Question. Questions will be created automatically.")
text = st.text_area("Input the text to get questions", placeholder="Enter the text", height=200)
button = st.button("Generate Question")


def postprocesstext(content):
    # Capitalize each sentence of the generated summary.
    final = ""
    for sent in sent_tokenize(content):
        sent = sent.capitalize()
        final = final + " " + sent
    return final


def summarizer(text, model, tokenizer):
    # Summarize the input text with T5 using the "summarize:" task prefix.
    text = text.strip().replace("\n", " ")
    text = "summarize: " + text
    print(text)
    max_len = 512
    encoding = tokenizer.encode_plus(text, max_length=max_len, padding=False, truncation=True, return_tensors="pt").to(device)

    input_ids, attention_mask = encoding["input_ids"], encoding["attention_mask"]

    outs = model.generate(input_ids=input_ids,
                          attention_mask=attention_mask,
                          early_stopping=True,
                          num_beams=3,
                          num_return_sequences=1,
                          no_repeat_ngram_size=2,
                          min_length=75,
                          max_length=300)

    dec = [tokenizer.decode(ids, skip_special_tokens=True) for ids in outs]
    summary = dec[0]
    summary = postprocesstext(summary)
    summary = summary.strip()
    print("done from summarizer")
    return summary


def get_nouns_multipartite(content):
    # Extract candidate keyphrases (nouns and proper nouns) with MultipartiteRank.
    out = []
    try:
        extractor = pke.unsupervised.MultipartiteRank()
        extractor.load_document(input=content, language='en')
        # Candidates should not contain punctuation marks or stopwords.
        pos = {'PROPN', 'NOUN'}
        stoplist = list(string.punctuation)
        stoplist += ['-lrb-', '-rrb-', '-lcb-', '-rcb-', '-lsb-', '-rsb-']
        stoplist += stopwords.words('english')
        # extractor.candidate_selection(pos=pos, stoplist=stoplist)
        extractor.candidate_selection(pos=pos)
        # Build the Multipartite graph and rank candidates with a random walk;
        # alpha controls the weight adjustment mechanism, see TopicRank for
        # the threshold/method parameters.
        extractor.candidate_weighting(alpha=1.1,
                                      threshold=0.75,
                                      method='average')
        keyphrases = extractor.get_n_best(n=15)

        for val in keyphrases:
            out.append(val[0])
    except Exception:
        out = []
        traceback.print_exc()

    return out


def get_keywords(originaltext, summarytext):
    # Keep only the keyphrases from the original text that also appear in the summary.
    keywords = get_nouns_multipartite(originaltext)
    print("keywords unsummarized: ", keywords)
    keyword_processor = KeywordProcessor()
    for keyword in keywords:
        keyword_processor.add_keyword(keyword)

    keywords_found = keyword_processor.extract_keywords(summarytext)
    keywords_found = list(set(keywords_found))
    print("keywords_found in summarized: ", keywords_found)

    important_keywords = []
    for keyword in keywords:
        if keyword in keywords_found:
            important_keywords.append(keyword)

    return important_keywords[:4]


def get_question(context, answer, model, tokenizer):
    # Generate a question whose answer is `answer`, conditioned on `context`.
    text = "context: {} answer: {}".format(context, answer)
    encoding = tokenizer.encode_plus(text, max_length=384, padding=False, truncation=True, return_tensors="pt").to(device)
    input_ids, attention_mask = encoding["input_ids"], encoding["attention_mask"]

    outs = model.generate(input_ids=input_ids,
                          attention_mask=attention_mask,
                          early_stopping=True,
                          num_beams=5,
                          num_return_sequences=1,
                          no_repeat_ngram_size=2,
                          max_length=72)

    dec = [tokenizer.decode(ids, skip_special_tokens=True) for ids in outs]

    question = dec[0].replace("question:", "")
    question = question.strip()
    return question


if text and button:
    summarized_text = summarizer(text, summary_model, summary_tokenizer)
    print("done with summarization")
    sleep(0.5)
    print("summary:", summarized_text)
    # summarized_text = summarizer(text, max_length=130, min_length=30, do_sample=False)
    # summarized_text = "Musk tweeted that his electric vehicle-making company tesla will not accept payments in bitcoin because of environmental concerns. He also said that the company was working with developers of dogecoin to improve system transaction efficiency. The world's largest cryptocurrency hit a two-month low, while doge coin rallied by about 20 percent. Musk has in recent months often tweeted in support of crypto, but rarely for bitcoin."
    imp_keywords = get_keywords(text, summarized_text)
    for answer in imp_keywords:
        ques = get_question(summarized_text, answer, question_model, question_tokenizer)
        st.write(ques)
        st.write(answer.capitalize())
        st.write("\n")
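The ntlk_utils module imported at the top is not part of this commit; the inline comment only says that the NLTK downloads live in a separate file. Below is a minimal sketch of what that helper might contain, assuming it simply fetches the corpora this script relies on (punkt for sent_tokenize, stopwords, and wordnet). The file contents are an assumption, not code from this repository.

# ntlk_utils.py -- assumed contents, not part of this commit
import nltk

# Fetch the NLTK data that app.py relies on; quiet=True suppresses progress output.
for pkg in ["punkt", "stopwords", "wordnet"]:
    nltk.download(pkg, quiet=True)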
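Two commented-out lines reference a Hugging Face summarization pipeline with facebook/bart-large-cnn as an alternative to the T5 summarizer. A rough sketch of how that path could be wired in if it were ever enabled; the function name here is made up, and note that the pipeline returns a list of dicts rather than a plain string.

# Alternative summarizer sketched from the commented-out pipeline lines.
from transformers import pipeline

bart_summarizer = pipeline("summarization", model="facebook/bart-large-cnn")

def summarize_with_bart(text: str) -> str:
    # The pipeline returns e.g. [{"summary_text": "..."}], so unpack the first result.
    result = bart_summarizer(text, max_length=130, min_length=30, do_sample=False)
    return result[0]["summary_text"]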