# recipe-improver / app.py
# author: aidan-o-brien
# last change: "changed iteration threshold" (commit a1d41c3)
from distutils.command.upload import upload
import pandas as pd
import streamlit as st
from transformers import AutoTokenizer, TFAutoModelForQuestionAnswering
from transformers import pipeline
@st.cache
def load_data(file, nrows=50):
    """Read an uploaded csv file into a DataFrame (cached by Streamlit).

    Args:
        file: file-like object or path accepted by ``pd.read_csv``.
        nrows: maximum number of rows to read (default 50 keeps the demo
            responsive; previously this limit was hard-coded).

    Returns:
        ``pd.DataFrame`` containing at most ``nrows`` rows.
    """
    return pd.read_csv(file, encoding='utf-8', nrows=nrows)
def load_pipeline(model_cp, tokenizer_cp):
    """Build a Hugging Face question-answering pipeline.

    Args:
        model_cp: model checkpoint name or path.
        tokenizer_cp: tokenizer checkpoint name or path.

    Returns:
        A ``transformers`` question-answering pipeline.
    """
    qa_pipeline = pipeline(
        "question-answering",
        model=model_cp,
        tokenizer=tokenizer_cp,
    )
    return qa_pipeline
def choose_model():
    """Render the sidebar model picker.

    Returns:
        Tuple ``(model_cp, tokenizer_cp)`` of checkpoint identifiers.
    """
    with st.sidebar:
        st.write("# Model Selection")
        model_cp = st.selectbox(
            'Select model for inference',
            ('deepset/roberta-base-squad2',
             'aidan-o-brien/recipe-improver'),
        )
        # The fine-tuned checkpoint is paired with the base albert
        # tokenizer; any other model doubles as its own tokenizer.
        if model_cp == "aidan-o-brien/recipe-improver":
            return model_cp, "albert-base-v2"
        return model_cp, model_cp
def choose_postprocessing():
    """Render the sidebar postprocessing picker.

    Returns:
        The selected postprocessing method name (one of the options below).
    """
    options = ('no postprocessing',
               'remove substrings',
               'iteration')
    with st.sidebar:
        st.write('# Postprocessing')
        return st.selectbox('Select postprocessing method', options)
def remove_substring(resp):
    """Drop answers that are substrings of another, longer answer.

    Args:
        resp: list of pipeline answer dicts, each with an ``'answer'`` key.

    Returns:
        The entries of ``resp`` (original order preserved) whose answer
        text is not contained in a longer kept answer.
    """
    # Longest-first so every answer only needs comparing against the
    # (longer) answers already kept.
    answers = sorted((a['answer'] for a in resp), key=len, reverse=True)
    kept = []
    for text in answers:
        # Generator form avoids building a throwaway list inside any().
        if not any(text in longer for longer in kept):
            kept.append(text)
    return [r for r in resp if r['answer'] in kept]
def iterate(resp, context):
    """Filter substring answers, then strip the best answer from the context.

    Args:
        resp: list of pipeline answer dicts.
        context: the text the answers were extracted from.

    Returns:
        Tuple ``(new_context, filtered_resp)`` where the context has every
        occurrence of the top answer removed.
    """
    filtered_resp = remove_substring(resp)
    top_answer = filtered_resp[0]['answer']
    # str.replace removes every occurrence, not just the first.
    return context.replace(top_answer, ""), filtered_resp
def remove_substring_iter(pipeline, question, context, max_iter=5,
                          threshold=1e-2):
    """Iteratively extract answers, removing each best answer from the
    context before re-running the pipeline.

    Args:
        pipeline: question-answering pipeline callable.
        question: question string asked on every round.
        context: review text to extract answers from.
        max_iter: maximum number of extraction rounds (default 5,
            previously hard-coded).
        threshold: stop once the best remaining score falls below this
            value (default 1e-2, previously hard-coded).

    Returns:
        List of the best answer dict from each round, in extraction order.
    """
    ret_resp = []
    for _ in range(max_iter):
        resp = pipeline(question, context, top_k=5,
                        handle_impossible_answer=True)
        # Guard against an empty response so iterate() cannot index
        # into an empty list.
        if not resp:
            break
        # Update context by stripping out this round's best answer
        context, filtered_resp = iterate(resp, context)
        # If best score not above threshold, quit
        if filtered_resp[0]['score'] < threshold:
            break
        ret_resp.append(filtered_resp[0])
    return ret_resp
# Page config
title = "Recipe Improver"
icon = "🍣"
st.set_page_config(page_title=title, page_icon=icon)
st.title(title)

# Choose model and postprocessing procedure
model_cp, tokenizer_cp = choose_model()
postprocessing = choose_postprocessing()

# Load model and tokenizer
question_answer = load_pipeline(model_cp, tokenizer_cp)
st.write("Model and tokenizer successfully loaded.")

# Upload csv - format with expander for aesthetics
with st.expander("Upload csv file"):
    uploaded_file = st.file_uploader("Choose a csv file", type="csv",
                                     key='file_uploader')

# If file is uploaded, run inference - QA pipelines can't run batch
if uploaded_file is not None:
    df = load_data(uploaded_file)

    # Ask user for index to try out.
    # FIX: the prompt f-string was missing the closing parenthesis
    # after the max index.
    user_idx = st.text_input(f'Enter index (max {df.shape[0] - 1})', "0")

    # Only run rest of script if user provides index
    if user_idx is not None:
        first_example = df['review'][int(user_idx)]
        question = "How was the recipe improved?"

        if postprocessing == "no postprocessing":
            resp = question_answer(question=question,
                                   context=first_example,
                                   top_k=5)
        elif postprocessing == "remove substrings":
            resp = question_answer(question=question,
                                   context=first_example,
                                   top_k=5)
            # Drop answers contained in longer answers
            resp = remove_substring(resp)
        elif postprocessing == "iteration":
            # Repeatedly extract and remove the best answer
            resp = remove_substring_iter(question_answer,
                                         question,
                                         first_example)

        # Present results
        st.markdown(f"""
# Results
The review provided was:
{first_example}
The question asked was:
{question}
The answers were:
{resp}
""")