Demo_space_2 / app.py
Ganesh43's picture
Update app.py
4265c8b verified
raw
history blame
No virus
1.7 kB
import torch
import streamlit as st
from transformers import BertTokenizer, BertForQuestionAnswering
# Checkpoint used for both the tokenizer and the question-answering model.
# NOTE(review): "bert-base-uncased" is a base checkpoint without a fine-tuned
# question-answering head (transformers reports the `qa_outputs` weights as
# newly initialized), so predicted spans are essentially arbitrary. Consider a
# SQuAD-fine-tuned checkpoint such as
# "bert-large-uncased-whole-word-masking-finetuned-squad" — TODO confirm.
QA_MODEL_NAME = "bert-base-uncased"


@st.cache_resource
def _load_qa_components(model_name):
    """Load and return the ``(tokenizer, model)`` pair for *model_name*.

    Streamlit re-runs this whole script on every widget interaction; caching
    the loaded objects as a shared resource avoids rebuilding the tokenizer and
    model weights on each rerun.
    """
    qa_tokenizer = BertTokenizer.from_pretrained(model_name)
    # BertForQuestionAnswering exposes start/end logits directly.
    qa_model = BertForQuestionAnswering.from_pretrained(model_name)
    return qa_tokenizer, qa_model


tokenizer, model = _load_qa_components(QA_MODEL_NAME)
def answer_query(question, context, *, max_answer_len=30):
    """Return the span of *context* that best answers *question*.

    Encodes the (question, context) pair with the module-level ``tokenizer``,
    runs the module-level ``model`` (a ``BertForQuestionAnswering``), and picks
    the context span whose start logit plus end logit is highest.

    Args:
        question: The question text.
        context: The passage to extract the answer from. If the encoded pair
            is longer than the model's maximum input length, only the context
            is truncated.
        max_answer_len: Maximum number of tokens in the returned span.

    Returns:
        The decoded answer text, or "" if no context tokens are available.

    Raises:
        ValueError: If ``max_answer_len`` is less than 1.
    """
    if max_answer_len < 1:
        raise ValueError("max_answer_len must be at least 1")

    # Truncate only the context (second sequence) so long uploads cannot
    # exceed BERT's position-embedding limit and crash the forward pass.
    inputs = tokenizer(
        question,
        context,
        truncation="only_second",
        max_length=model.config.max_position_embeddings,
        return_tensors="pt",
    )

    with torch.no_grad():
        outputs = model(**inputs)

    # A single pair was encoded, so drop the batch dimension.
    input_ids = inputs["input_ids"][0]
    start_logits = outputs.start_logits[0]
    end_logits = outputs.end_logits[0]

    # In a BERT sentence pair, token_type_id 1 marks the context segment,
    # including its trailing [SEP]. Restrict answers to real context tokens so
    # the span cannot land on [CLS], [SEP] or the question.
    in_context = (inputs["token_type_ids"][0] == 1) & (
        input_ids != tokenizer.sep_token_id
    )
    if not bool(in_context.any()):
        return ""

    # Score every (start, end) pair jointly instead of taking two independent
    # argmaxes, which can produce end < start. Pairs outside the context,
    # with end < start, or longer than max_answer_len tokens get -inf.
    neg_inf = float("-inf")
    start_scores = start_logits.masked_fill(~in_context, neg_inf)
    end_scores = end_logits.masked_fill(~in_context, neg_inf)
    pair_scores = start_scores[:, None] + end_scores[None, :]

    seq_len = input_ids.shape[0]
    allowed = (
        torch.ones(seq_len, seq_len, dtype=torch.bool)
        .triu()  # end >= start
        .tril(max_answer_len - 1)  # end < start + max_answer_len
    )
    pair_scores = pair_scores.masked_fill(~allowed, neg_inf)

    # argmax over the flattened matrix; recover (row, col) = (start, end).
    best_start, best_end = divmod(int(torch.argmax(pair_scores)), seq_len)

    # Slice token ids before decoding: token indices do not correspond to
    # character offsets in the decoded string.
    answer_ids = input_ids[best_start:best_end + 1]
    return tokenizer.decode(answer_ids, skip_special_tokens=True).strip()
# --- Streamlit app ---
st.title("Question Answering App")

# Textbox for the user's question.
user_query = st.text_input("Enter your question:")

# Optional context file; restrict the picker to .txt to match the label.
uploaded_file = st.file_uploader("Upload a context file (txt):", type=["txt"])

if uploaded_file is not None:
    try:
        context = uploaded_file.read().decode("utf-8")
    except UnicodeDecodeError:
        # Show a readable message instead of a traceback for non-UTF-8 files,
        # and halt this script run.
        st.error("The uploaded file is not valid UTF-8 text.")
        st.stop()
else:
    # Fall back to a built-in sample context when no file is uploaded.
    context = (
        "This is a sample context for demonstration purposes. "
        "You can upload your own text file for context."
    )

# Only answer when the question contains something other than whitespace.
if user_query.strip():
    answer = answer_query(user_query, context)
    if answer:
        st.write(f"Answer: {answer}")
    else:
        st.write("No answer found in the context.")
else:
    st.write("Please enter a question.")