Nighter's picture
Update app.py
680c7ae
raw
history blame contribute delete
No virus
7.22 kB
import gradio as gr
from transformers import pipeline
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.preprocessing.text import text_to_word_sequence
import pickle
import re
from tensorflow.keras.models import load_model
# Load long model
# Both artifacts live in the local 'lstm-qa-long-answers-model' directory:
# the pickled tokenizer used to encode text, and the Keras model that scores
# a (candidate answer, question) pair.
# NOTE: pickle.load can execute arbitrary code; only unpickle the bundled file.
with open('lstm-qa-long-answers-model/tokenizer.pickle', 'rb') as tokenizer_file:
    tokenizer = pickle.load(tokenizer_file)

long_answer_model = load_model('lstm-qa-long-answers-model/model.h5')
def clean_text(text):
    """Strip markup and unsupported characters from *text*.

    Applied in order:
      1. remove ``<...>`` spans (non-greedy, single line),
      2. remove numeric bracket references such as ``[12]``,
      3. remove every character that is not an ASCII letter, a digit,
         whitespace, or one of ``( ) . ,``.

    Returns the cleaned string.
    """
    removal_patterns = (
        r'<.*?>',
        r'\[\d+\]',
        r'[^a-zA-Z0-9\s().,]',
    )
    cleaned = text
    for removal_pattern in removal_patterns:
        cleaned = re.sub(removal_pattern, '', cleaned)
    return cleaned
def remove_parentheses(text):
    """Return *text* with every ``(...)`` span deleted.

    A span runs from an opening parenthesis to the first closing parenthesis
    after it, so nested parentheses are not balanced: only the innermost
    closing parenthesis ends the match. Unclosed parentheses are left as-is.
    """
    parenthesised_span = re.compile(r'\([^)]*\)')
    return parenthesised_span.sub('', text)
def predict_correct_answer(question, answer1, answer2):
    """Pick whichever of two candidate answers the long-answer model scores higher.

    Each candidate is cleaned with ``clean_text``, encoded with the
    module-level ``tokenizer``, padded/truncated to 300 tokens and scored
    against the encoded question by ``long_answer_model``.

    Args:
        question: The question text.
        answer1: First candidate answer.
        answer2: Second candidate answer.

    Returns:
        A ``(cleaned_answer, score)`` tuple for the best-scoring candidate.
        On a tie the earlier candidate is kept.
    """
    # The question encoding does not depend on the candidate: build it once.
    question_seq = tokenizer.texts_to_sequences([question])
    padded_question = pad_sequences(question_seq, padding='post')

    correct_answer = None
    # None (rather than 0) as the starting score, so a candidate is always
    # returned even when every model score is zero or negative.
    best_score = None
    for answer in (answer1, answer2):
        clean_answer = clean_text(answer)
        answer_seq = tokenizer.texts_to_sequences([clean_answer])
        padded_answer = pad_sequences(
            answer_seq, maxlen=300, padding='post', truncating='post'
        )
        # The model takes [answer, question]; predict yields one row per
        # sample, and the first element of the first row is the score.
        score = long_answer_model.predict([padded_answer, padded_question])[0][0]
        if best_score is None or score > best_score:
            best_score = score
            correct_answer = clean_answer
    return correct_answer, best_score
# def split_into_sentences(text):
# sentences = re.split(r'\.\s*', text)
# return [s.strip() for s in sentences if s]
# def predict_answer(context, question):
# sentences = split_into_sentences(context)
# best_sentence = None
# best_score = 0
# for sentence in sentences:
# clean_sentence = clean_text(sentence)
# question_seq = tokenizer.texts_to_sequences([question])
# sentence_seq = tokenizer.texts_to_sequences([clean_sentence])
# max_sentence_length = 300
# padded_question = pad_sequences(question_seq, padding='post')
# padded_sentence = pad_sequences(sentence_seq, maxlen=max_sentence_length, padding='post', truncating='post')
# score = long_answer_model.predict([padded_sentence, padded_question])[0]
# if score > best_score:
# best_score = score
# best_sentence = clean_sentence
# return best_score, best_sentence
# Load short model
# One question-answering pipeline per checkpoint. The three "Nighter/..."
# checkpoints are loaded with from_tf=True (presumably because they were
# published as TensorFlow weights -- verify on the model hub).
distilbert_base_uncased = pipeline(
    model="Nighter/QA_wiki_data_short_answer",
    from_tf=True,
)
bert_base_uncased = pipeline(
    model="Nighter/QA_bert_base_uncased_wiki_data_short_answer",
    from_tf=True,
)
roberta_base = pipeline(
    model="Nighter/QA_wiki_data_roberta_base_short_answer",
    from_tf=True,
)
longformer_base = pipeline(
    model="aware-ai/longformer-squadv2",
)
# Function to answer on all models
def answer_questions(context, question):
    """Ask every short-answer QA pipeline the same question.

    Parenthesised spans are removed from *context* (via
    ``remove_parentheses``) before it is handed to the models.

    Args:
        context: Passage the answer should be extracted from.
        question: Question to answer.

    Returns:
        An 8-tuple of ``(answer, score)`` pairs, flattened, in the order
        DistilBERT, BERT, RoBERTa, Longformer. The order must line up with
        the output components passed to ``submit_btn.click``.
    """
    # Strip parentheses once instead of once per model.
    stripped_context = remove_parentheses(context)

    qa_models = (
        distilbert_base_uncased,
        bert_base_uncased,
        roberta_base,
        longformer_base,
    )
    flat_results = []
    for qa_model in qa_models:
        result = qa_model(question=question, context=stripped_context)
        # Keep each model's answer next to its own score; the RoBERTa and
        # Longformer scores were previously returned in each other's slot.
        flat_results.append(result['answer'])
        flat_results.append(result['score'])
    return tuple(flat_results)
# App Interface
# Builds a two-tab Gradio Blocks UI and exposes it as `app`:
#   * "QA Short Answer" runs answer_questions over a context/question pair;
#   * "Long Answer Prediction" runs predict_correct_answer over a question
#     and two candidate answers.
with gr.Blocks() as app:
gr.Markdown("<center> <h1>Question Answering with Short and Long Answer Models </h1> </center><hr>")
# ---- Tab 1: short-answer QA ----
with gr.Tab("QA Short Answer"):
with gr.Row():
# Input side: context, question, submit and clear controls.
with gr.Column():
context_input = gr.Textbox(lines=8, label="Context", placeholder="Input Context here...")
question_input = gr.Textbox(lines=3, label="Question", placeholder="Input Question here...")
submit_btn = gr.Button("Submit")
gr.ClearButton([context_input,question_input])
# Output side: for each model, an answer textbox (scale=6) and a score
# number (scale=2).
with gr.Column():
with gr.Row():
with gr.Column(scale=6):
distilbert_base_uncased_output = gr.Textbox(lines=2, label="Distil BERT Base Uncased")
with gr.Column(scale=2):
distilbert_base_uncased_score = gr.Number(label="Distil BERT Base Uncased Score")
with gr.Row():
with gr.Column(scale=6):
bert_base_uncased_output = gr.Textbox(lines=2, label="BERT Base Uncased")
with gr.Column(scale=2):
bert_base_uncased_score = gr.Number(label="BERT Base Uncased Score")
with gr.Row():
with gr.Column(scale=6):
roberta_base_output = gr.Textbox(lines=2, label="RoBERTa Base")
with gr.Column(scale=2):
roberta_base_score = gr.Number(label="RoBERTa Base Score")
with gr.Row():
with gr.Column(scale=6):
longformer_base_output = gr.Textbox(lines=2, label="Longformer Base")
with gr.Column(scale=2):
longformer_base_score = gr.Number(label="Longformer Base Score")
# The values returned by answer_questions are assigned to `outputs`
# positionally, so the return order must match this list:
# (answer, score) for DistilBERT, BERT, RoBERTa, Longformer.
submit_btn.click(fn=answer_questions, inputs=[context_input, question_input], outputs=[distilbert_base_uncased_output, distilbert_base_uncased_score, bert_base_uncased_output, bert_base_uncased_score, roberta_base_output, roberta_base_score, longformer_base_output, longformer_base_score])
# `examples` is passed to gr.Examples as a string -- presumably the name of a
# directory of example data shipped next to this file; verify it exists.
examples='examples'
# Example inputs for this tab; the same output components and
# answer_questions are passed along with them.
gr.Examples(examples,[context_input, question_input],[distilbert_base_uncased_output, distilbert_base_uncased_score, bert_base_uncased_output, bert_base_uncased_score, roberta_base_output, roberta_base_score, longformer_base_output, longformer_base_score],answer_questions)
# ---- Tab 2: choose the better of two long answers ----
with gr.Tab("Long Answer Prediction"):
with gr.Row():
# Input side: the question, two candidate answers, submit and clear controls.
with gr.Column():
long_question_input = gr.Textbox(lines=3,label="Question", placeholder="Enter the question")
answer1_input = gr.Textbox(lines=3,label="Answer 1", placeholder="Enter answer 1")
answer2_input = gr.Textbox(lines=3,label="Answer 2", placeholder="Enter answer 2")
submit_btn_long = gr.Button("Submit")
gr.ClearButton([long_question_input, answer1_input, answer2_input])
# Output side: the selected answer and its score.
with gr.Column():
correct_answer_output = gr.Textbox(lines=3,label="Correct Answer")
score_output = gr.Number(label="Score")
# predict_correct_answer returns (answer, score), matching `outputs` below.
submit_btn_long.click(fn=predict_correct_answer, inputs=[long_question_input, answer1_input, answer2_input],
outputs=[correct_answer_output, score_output])
# As above, a string is passed as the examples source -- presumably a
# directory name; verify. No outputs or fn are given for these examples.
long_examples = 'long_examples'
gr.Examples(long_examples,[long_question_input, answer1_input, answer2_input])
# Start the Gradio server only when the file is executed as a script.
if __name__ == "__main__":
app.launch()