# WikiAI / app.py
# PSM272's picture
# Update app.py
# 1710f02
import functools
import pprint as pp

import gradio as gr
import torch
import wikipedia as wiki
from transformers import AutoModelForQuestionAnswering, AutoTokenizer, pipeline, AutoModelForSeq2SeqLM
_QA_MODEL_NAME = "deepset/roberta-base-squad2"
_PARAPHRASE_MODEL_NAME = "tuner007/pegasus_paraphrase"
_NUM_BEAMS = 20
_NUM_RETURN_SEQUENCES = 1
_MAX_PARAPHRASE_TOKENS = 60


def _format_result(answer, text):
    """Render *answer* and *text* in the string layout shown in the UI.

    NOTE: the values are inserted verbatim (no quote escaping), so the
    result is display text and not guaranteed to be parseable.
    """
    return "{'answer': '" + answer + "', 'text': '" + text + "'}"


def _get_sentence(text, pos):
    """Return the sentence of *text* that surrounds character offset *pos*.

    Boundaries are approximated by the nearest '.' on either side of *pos*;
    the closing period is not part of the returned string.
    """
    start = text.rfind('.', 0, pos) + 1
    end = text.find('.', pos)
    if end == -1:
        end = len(text)
    return text[start:end].strip()


@functools.lru_cache(maxsize=1)
def _load_qa_pipeline():
    """Build the extractive question-answering pipeline once and reuse it."""
    return pipeline(
        'question-answering',
        model=_QA_MODEL_NAME,
        tokenizer=_QA_MODEL_NAME,
    )


@functools.lru_cache(maxsize=1)
def _load_paraphraser():
    """Load the paraphrase tokenizer and model once; return them with the device."""
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    tokenizer = AutoTokenizer.from_pretrained(_PARAPHRASE_MODEL_NAME)
    model = AutoModelForSeq2SeqLM.from_pretrained(_PARAPHRASE_MODEL_NAME)
    # The inputs are moved to `device` below, so the model has to live on
    # the same device or generation fails whenever CUDA is available.
    model = model.to(device)
    return tokenizer, model, device


def _paraphrase(input_text, num_return_sequences, num_beams):
    """Return a list of *num_return_sequences* paraphrases of *input_text*."""
    tokenizer, model, device = _load_paraphraser()
    batch = tokenizer(
        [input_text],
        truncation=True,
        padding='longest',
        max_length=_MAX_PARAPHRASE_TOKENS,
        return_tensors="pt",
    ).to(device)
    # NOTE(review): temperature normally only matters when sampling; with
    # beam search and no do_sample it is presumably ignored -- confirm.
    generated = model.generate(
        **batch,
        max_length=_MAX_PARAPHRASE_TOKENS,
        num_beams=num_beams,
        num_return_sequences=num_return_sequences,
        temperature=1.5,
    )
    return tokenizer.batch_decode(generated, skip_special_tokens=True)


def _fetch_article_text(question):
    """Return the content of the first loadable search hit, or None."""
    for title in wiki.search(question):
        try:
            # Titles returned by search are already exact, so auto-suggest
            # is disabled to keep the library from rewriting them.
            return wiki.page(title, auto_suggest=False).content
        except (wiki.exceptions.DisambiguationError, wiki.exceptions.PageError):
            # Ambiguous or missing page: fall through to the next hit.
            continue
    return None


def greet(name):
    """Answer a question from Wikipedia and paraphrase the supporting sentence.

    The question is used as a Wikipedia search query, the first loadable
    article is fed to an extractive QA model, and the sentence around the
    extracted answer is paraphrased.

    Args:
        name: The question text typed by the user.

    Returns:
        A string of the form ``{'answer': '...', 'text': '...'}``. Both
        fields are empty when the question is blank or no article could be
        loaded for it.
    """
    if not name or not name.strip():
        return _format_result("", "")

    question = name
    text = _fetch_article_text(question)
    if not text:
        return _format_result("", "")

    res = _load_qa_pipeline()({'question': question, 'context': text})

    # res['start'] is used as a character offset into the article text.
    context = _get_sentence(text, res['start']) + '.'
    paraphrase = _paraphrase(context, _NUM_RETURN_SEQUENCES, _NUM_BEAMS)[0]
    return _format_result(res['answer'], paraphrase)
# Text-in / text-out UI around greet(). The interface object is built at
# module level; the web server is only started when this file is executed
# as a script, so importing the module has no server side effect.
demo = gr.Interface(fn=greet, inputs="text", outputs="text")

if __name__ == "__main__":
    demo.launch()