# omm / qa.py
# Nikhil0987's picture
# this
# 5bfab8c
# raw
# history blame
# 1.11 kB
import functools

import streamlit as st
from transformers import pipeline
# def que():
# question = st.text_input("ASk me a question")
# oracle = pipeline(task= "question-answering",model="deepset/roberta-base-squad2")
# oracle(question="Where do I live?", context="My name is Wolfgang and I live in Berlin")
@functools.lru_cache(maxsize=4)
def _load_qa_pipeline(model=None):
    """Build (once per ``model``) and return a question-answering pipeline.

    Constructing a transformers pipeline loads model weights, which is far too
    slow to repeat for every question, so instances are cached. The cache is
    bounded because each cached pipeline keeps a whole model in memory.
    """
    return pipeline("question-answering", model=model)


def question_answering(question, context, *, model=None):
    """Answer ``question`` using only the text in ``context``.

    Args:
        question: The question to ask.
        context: The passage the answer is extracted from.
        model: Optional Hugging Face model id or local path for the
            ``"question-answering"`` pipeline. ``None`` (the default) lets
            transformers choose its default model, as before.

    Returns:
        The answer text extracted from ``context``.

    Raises:
        ValueError: If ``question`` or ``context`` is empty or whitespace-only.
    """
    # Validate before touching the model so bad input never triggers a load.
    if not question or not question.strip():
        raise ValueError("question must be a non-empty string")
    if not context or not context.strip():
        raise ValueError("context must be a non-empty string")

    qa_model = _load_qa_pipeline(model)
    output = qa_model(question=question, context=context)
    # The pipeline's result dict carries the extracted span under "answer"
    # (alongside "start"/"end" character offsets and a confidence "score").
    return output["answer"]
if __name__ == "__main__":
    # Demo: ask a simple factual question against a one-sentence passage
    # and print whatever answer the model extracts.
    demo_question = "What is the capital of France?"
    demo_context = "The capital of France is Paris."
    print(question_answering(demo_question, demo_context))