# Author: msures3
# Commit 22ba023 (verified): "Create app.py"
from transformers import AutoModelForQuestionAnswering, AutoTokenizer
import torch
import gradio as gr
# Hub identifiers for the fine-tuned extractive QA model and the tokenizer
# used to encode (question, context) pairs for it.
# NOTE(review): the tokenizer is loaded from the base checkpoint, not from the
# fine-tuned repo; presumably the fine-tune kept the same vocabulary — confirm.
QA_MODEL_NAME = "msures3/distilbert-base-squad"
TOKENIZER_NAME = "distilbert-base-uncased"

model = AutoModelForQuestionAnswering.from_pretrained(QA_MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
def _best_span(start_logits, end_logits, allowed, max_answer_len=30):
    """Return the (start, end) token indices of the highest-scoring valid span.

    A span's score is ``start_logits[start] + end_logits[end]``. A span is
    valid when ``start <= end``, it covers at most ``max_answer_len`` tokens,
    and both endpoints are positions where the boolean mask ``allowed`` is True.
    All three tensors are expected to be 1-D and of equal length.
    """
    seq_len = start_logits.shape[0]
    # scores[i, j] = score of the span that starts at i and ends at j.
    scores = start_logits[:, None] + end_logits[None, :]
    positions = torch.arange(seq_len)
    length_minus_one = positions[None, :] - positions[:, None]  # j - i
    valid = (length_minus_one >= 0) & (length_minus_one < max_answer_len)
    valid &= allowed[:, None] & allowed[None, :]
    scores = scores.masked_fill(~valid, float("-inf"))
    # Map the flat argmax back to a (row, column) = (start, end) pair.
    return divmod(int(scores.argmax()), seq_len)


def generate_response(context, question):
    """Extract the answer to ``question`` from ``context``.

    Encodes the pair with the module-level ``tokenizer`` (question first,
    context second), runs the module-level ``model``, and decodes the best
    answer span found inside the context tokens.

    Args:
        context: Passage to search for the answer. Only the part that fits in
            the tokenizer's maximum length (after the question) is considered.
        question: Question to answer.

    Returns:
        The decoded answer text, or "" when either input is empty/blank or
        no context tokens survive encoding.
    """
    if not (context and context.strip() and question and question.strip()):
        return ""

    # Truncate only the context so long passages don't exceed the model's
    # maximum sequence length (which would otherwise raise in the forward pass).
    inputs = tokenizer(question, context, return_tensors="pt", truncation="only_second")
    with torch.no_grad():
        outputs = model(**inputs)

    # sequence_ids: 0 = question token, 1 = context token, None = special token.
    # The answer must come from the context, never the question or [CLS]/[SEP].
    context_mask = torch.tensor([sid == 1 for sid in inputs.sequence_ids(0)])
    if not bool(context_mask.any()):
        return ""

    # Choose start and end jointly so that end >= start; independent argmaxes
    # can produce an inverted (empty) or out-of-context span.
    start, end = _best_span(outputs.start_logits[0], outputs.end_logits[0], context_mask)
    answer_tokens = inputs.input_ids[0, start : end + 1]
    return tokenizer.decode(answer_tokens, skip_special_tokens=True).strip()
# Interface components: two text inputs (passed to generate_response as
# context, question — in this order) and one text output for the answer.
context_box = gr.Textbox(label="Enter Context")
question_box = gr.Textbox(label="Enter Question")
inputs = [context_box, question_box]
outputs = gr.Textbox(label="Predicted Answer")
# Sample (context, question) pairs shown in the UI; each row follows the
# order of the `inputs` components.
EXAMPLES = [
    [
        "This is a sample context. The quick brown fox jumps over the lazy dog.",
        "What animal jumps over the dog?",
    ],
    [
        "The capital of France is Paris. It is a beautiful city with many attractions.",
        "What is the capital of France?",
    ],
]

# Wire the QA function to the input/output components.
app = gr.Interface(
    fn=generate_response,
    inputs=inputs,
    outputs=outputs,
    title="Context Based Question Answering",
    description="Enter a context and a question to get the predicted answer.",
    examples=EXAMPLES,
)
# Start the web server only when this file is executed as a script
# (e.g. `python app.py`), so importing the module does not block on launch().
if __name__ == "__main__":
    app.launch()