Spaces:
Sleeping
Sleeping
import streamlit as st | |
import inference | |
from app_utils import get_default_texts, display_output | |
# Model locations. None of these names is read in the code visible here;
# presumably they are kept for reference or read elsewhere — confirm before removing.
summarizer_path, qa_path, entity_recognition_path = (
    "google/pegasus-large",  # looks like a Hugging Face Hub model id
    "qa_model",              # assumed local model directory — TODO confirm
    "entity_rec",            # assumed local model directory — TODO confirm
)
def _init_session_state():
    """Seed per-session defaults, leaving any value already stored untouched."""
    defaults = {
        'sample_index': 0,                # position in the current sample list
        'which_button': 'sample_button',  # input source used when predicting
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value


_init_session_state()
st.title('NLP Demo')

# Sidebar controls: which NLP operations to run and which bundled dataset
# supplies the example texts shown in the main area.
with st.sidebar:
    st.header("Select your choices")
    ops_to_perform = st.multiselect(
        'Select operation to perform :',
        ['Question Answering', 'Entity Recognition', 'Text Summarization'],
        default=['Question Answering'],
    )
    # Label typo fixed ("datasest" -> "datasets").
    chosen_dataset = st.selectbox(
        "Choose one of the datasets to get samples :",
        ['squad-qa', 'bbc-xsum-summarization', 'conll-ner'],
    )
# Load the example texts for the chosen dataset. `samples_dict` is indexed with
# ints below, so its keys are presumably 0..len-1 — confirm in app_utils.
samples_dict = get_default_texts(chosen_dataset)
tot_index = len(samples_dict)
st.write('**Select from sample texts**')
st.write("Select one from these available samples: ")

# `sample_index` is kept in session state and survives a dataset switch, so it
# can exceed the new dataset's size; wrap it into range before indexing.
# With no samples at all, every modulo would divide by zero, so pin to 0.
current_index = st.session_state['sample_index'] % tot_index if tot_index else 0

prev_button, next_button = st.columns(2)
with prev_button:
    prev_clicked = st.button('prev_text')
with next_button:
    # Not named `next`: that would shadow the builtin for the rest of the module.
    next_clicked = st.button('next_text')

if tot_index:
    # Step through the samples cyclically in either direction.
    if prev_clicked:
        current_index = (current_index - 1) % tot_index
    if next_clicked:
        current_index = (current_index + 1) % tot_index
st.session_state['sample_index'] = current_index

sample_text = samples_dict[current_index] if tot_index else ''
input_text = st.text_area("Input text to perform selected operations on : ", sample_text)
# The question box appears only when Question Answering is selected;
# otherwise `question` is None when handed to the inference call.
question = (
    st.text_input("Enter a valid question here :")
    if "Question Answering" in ops_to_perform
    else None
)
predict_clicked = st.button("Submit for predictions")
# On submit, run inference and cache the result in session state. On any other
# rerun (e.g. after using prev/next), re-render the last cached outputs.
if predict_clicked:
    # NOTE(review): in the code visible here, 'which_button' is only ever set to
    # 'sample_button', so the branch below always runs. Confirm whether another
    # input source is still planned.
    which_button = st.session_state['which_button']
    if which_button == 'sample_button':
        # Validate first, so a missing selection or question produces a message
        # instead of reaching the model code.
        if not ops_to_perform:
            st.warning("Select at least one operation to perform.")
        elif "Question Answering" in ops_to_perform and not (question or "").strip():
            st.warning("Enter a question to run Question Answering.")
        else:
            all_outputs = inference.get_predictions(input_text, ops_to_perform, question)
            st.session_state['prev_outputs'] = all_outputs
            display_output(all_outputs)
elif 'prev_outputs' in st.session_state:
    all_outputs = st.session_state['prev_outputs']
    display_output(all_outputs)