Demo_space_2 / question_answering.py
Ganesh43's picture
Create question_answering.py
229899d verified
raw
history blame
No virus
904 Bytes
import torch
from transformers import BertTokenizer, BertForQuestionAnswering
# Load a BERT checkpoint that has been fine-tuned for extractive question
# answering (SQuAD). Plain "bert-base-uncased" has no trained QA head: its
# span-prediction layer would be randomly initialised and give random answers.
# The tokenizer is loaded from the same checkpoint so vocabulary and casing
# match the model exactly.
MODEL_NAME = "bert-large-uncased-whole-word-masking-finetuned-squad"
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
model = BertForQuestionAnswering.from_pretrained(MODEL_NAME)
def answer_query(question, context):
    """Return the span of *context* that the model predicts answers *question*.

    The question and context are encoded together as one BERT input pair.
    The model scores every token as a possible answer start and end, and the
    highest-scoring valid span lying inside the context is decoded to text.

    Args:
        question: The question to answer.
        context: The passage to search for the answer. Passages longer than
            the model's position limit are truncated; the question is kept.

    Returns:
        The decoded answer string, or "" when no context tokens remain after
        encoding. The text is rebuilt from tokens, so it may differ from the
        original context in casing/whitespace (the tokenizer is uncased).
    """
    # Encode as "[CLS] question [SEP] context [SEP]". Truncate only the
    # context so long passages do not exceed BERT's position-embedding limit,
    # which would otherwise raise inside the model.
    inputs = tokenizer(
        question,
        context,
        return_tensors="pt",
        truncation="only_second",
        max_length=model.config.max_position_embeddings,
    )

    with torch.no_grad():
        outputs = model(**inputs)

    # Batch size is 1, so take the first (only) row of every tensor.
    input_ids = inputs["input_ids"][0]
    start_logits = outputs.start_logits[0]
    end_logits = outputs.end_logits[0]

    # The answer must come from the context: segment id 1 marks the second
    # sequence (context plus its trailing [SEP]); exclude the [SEP] itself.
    in_context = (inputs["token_type_ids"][0] == 1) & (
        input_ids != tokenizer.sep_token_id
    )
    if not in_context.any():
        return ""

    # Best start position among context tokens only.
    masked_start = start_logits.masked_fill(~in_context, float("-inf"))
    answer_start = int(torch.argmax(masked_start))

    # Best end position among context tokens at or after the start, so the
    # span can never be inverted. answer_start itself is always allowed.
    valid_end = in_context.clone()
    valid_end[:answer_start] = False
    masked_end = end_logits.masked_fill(~valid_end, float("-inf"))
    answer_end = int(torch.argmax(masked_end)) + 1  # exclusive bound

    # Decode the span's token ids back to text; decode() re-joins WordPiece
    # sub-tokens ("##...") and drops any special tokens.
    return tokenizer.decode(
        input_ids[answer_start:answer_end], skip_special_tokens=True
    )