import transformers
import streamlit as st
from annotated_text import annotated_text

# Hugging Face Hub checkpoint used for extractive question answering.
MODEL_NAME = "deepset/roberta-base-squad2"


def _cache_model(func):
    """Cache *func*'s return value across Streamlit reruns.

    ``st.cache`` is deprecated in newer Streamlit releases in favour of
    ``st.cache_resource``, the cache intended for models and other
    unserialisable objects. Use ``cache_resource`` when this Streamlit
    version provides it, otherwise fall back to the legacy ``st.cache``.
    """
    cache_resource = getattr(st, "cache_resource", None)
    if cache_resource is not None:
        return cache_resource(show_spinner=False)(func)
    # allow_output_mutation disables legacy st.cache's mutation check on the
    # returned pipeline object.
    return st.cache(allow_output_mutation=True, show_spinner=False)(func)


@_cache_model
def get_pipe(model_name=MODEL_NAME):
    """Build a question-answering pipeline, cached so it loads only once.

    Args:
        model_name: Model identifier passed to ``from_pretrained`` for both
            the tokenizer and the model. Defaults to ``MODEL_NAME``.

    Returns:
        A ``transformers`` "question-answering" pipeline.
    """
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
    model = transformers.AutoModelForQuestionAnswering.from_pretrained(model_name)
    return transformers.pipeline("question-answering", model=model, tokenizer=tokenizer)


def parse_context(context, prediction):
    """Split *context* into ``annotated_text`` segments highlighting the answer.

    Args:
        context: The text the question was asked against.
        prediction: Pipeline result whose ``start``/``end`` values are used as
            slice offsets into *context*, plus the extracted ``answer``.

    Returns:
        ``[before, (answer, "ANSWER", "#afa"), after]``, where *before* and
        *after* are the slices of *context* around the answer span.
    """
    start, end = prediction["start"], prediction["end"]
    return [
        context[:start],
        (prediction["answer"], "ANSWER", "#afa"),
        context[end:],
    ]


st.set_page_config(page_title="Question Answering")
st.title("Question Answering")
st.write("Enter context and a question and press 'Predict' to extract the answer from the context.")

default_context = "My name is Wolfgang and I live in Berlin."
default_question = "What is my name?"

context = st.text_area("Enter context here:", value=default_context)
question = st.text_input("Enter question here:", value=default_question)
# Clicking the button reruns the script. Its return value is not needed: the
# previous condition `(submit and ready) or ready` is equivalent to `ready`, so
# a prediction is shown whenever both inputs are non-empty (including the
# default example on first load).
st.button('Predict')

with st.spinner("Loading model..."):
    pipe = get_pipe()

if context.strip() and question.strip():
    # Keyword arguments are the documented calling convention for the QA
    # pipeline and avoid depending on how positional arguments are interpreted.
    prediction = pipe(question=question, context=context)
    parsed_context = parse_context(context, prediction)
    st.header("Prediction:")
    annotated_text(*parsed_context)
    st.header('Raw values:')
    st.json(prediction)