# WikiAI / app.py — Hugging Face Space by PSM272 (commit 6972a16, 2.57 kB)
# NOTE(review): the lines above were file-viewer page residue ("raw history
# blame" etc.), not Python — converted to a comment header so the file parses.
import wikipedia as wiki
import pprint as pp
from transformers import AutoModelForQuestionAnswering, AutoTokenizer, pipeline, AutoModelForSeq2SeqLM
import torch
import gradio as gr
def greet(name):
    """Answer a free-text question using Wikipedia plus two transformer models.

    Pipeline:
      1. Search Wikipedia for the question and take the top hit's article text.
      2. Run an extractive QA model (RoBERTa SQuAD2) over that text.
      3. Recover the full sentence containing the predicted answer span.
      4. Paraphrase that sentence with Pegasus.

    Parameters
    ----------
    name : str
        The user's question (Gradio passes the textbox contents here).

    Returns
    -------
    str
        The short answer, its source sentence, and a paraphrase of that
        sentence — or an apology message when Wikipedia has no results.

    NOTE(review): models are (re)loaded on every call; caching them at module
    level would make repeat queries much faster — TODO confirm and hoist.
    """
    question = name

    # 1. Retrieve context from Wikipedia.
    results = wiki.search(question)
    if not results:
        # Guard: wiki.search can return an empty list; the original code
        # crashed with IndexError on results[0] in that case.
        return "Sorry, I couldn't find a Wikipedia article for that question."
    page = wiki.page(results[0])
    text = page.content

    def get_sentence(text, pos):
        """Return the period-delimited sentence of `text` containing index `pos`."""
        start = text.rfind('.', 0, pos) + 1
        end = text.find('.', pos)
        if end == -1:  # answer lies in the final, unterminated sentence
            end = len(text)
        return text[start:end].strip()

    # 2. Extractive QA: predict the answer span inside the article.
    #    (The original also loaded this model/tokenizer a second time via
    #    AutoModelForQuestionAnswering/AutoTokenizer and never used them —
    #    that dead, expensive work is removed.)
    model_name = "deepset/roberta-base-squad2"
    nlp = pipeline('question-answering', model=model_name, tokenizer=model_name)
    res = nlp({'question': question, 'context': text})

    # 3. Full sentence around the predicted answer span.
    context = get_sentence(text, res['start']) + '.'

    # 4. Paraphrase the supporting sentence with Pegasus.
    tokenizer = AutoTokenizer.from_pretrained("tuner007/pegasus_paraphrase")
    model = AutoModelForSeq2SeqLM.from_pretrained("tuner007/pegasus_paraphrase")
    torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'

    def get_response(input_text, num_return_sequences, num_beams):
        """Generate `num_return_sequences` beam-search paraphrases of `input_text`."""
        batch = tokenizer([input_text], truncation=True, padding='longest',
                          max_length=60, return_tensors="pt").to(torch_device)
        translated = model.generate(**batch, max_length=60, num_beams=num_beams,
                                    num_return_sequences=num_return_sequences,
                                    temperature=1.5)
        return tokenizer.batch_decode(translated, skip_special_tokens=True)

    paraphrase = get_response(context, num_return_sequences=1, num_beams=20)[0]

    # BUG FIX: the original only print()-ed its findings and returned None,
    # so the Gradio output textbox was always empty. Build and RETURN the
    # answer (still printed for the server log).
    result = (f"Answer: {res['answer']}\n"
              f"Source sentence: {context}\n"
              f"Paraphrase: {paraphrase}")
    print(result)
    return result
# Expose `greet` as a minimal web UI: one textbox in, one textbox out.
demo = gr.Interface(
    fn=greet,
    inputs="text",
    outputs="text",
)

# Start the Gradio server (blocks until shut down).
demo.launch()