recipe-improver / app.py
aidan-o-brien's picture
added default index
124247d
raw history blame
No virus
4.41 kB
from distutils.command.upload import upload
import pandas as pd
import streamlit as st
from transformers import AutoTokenizer, TFAutoModelForQuestionAnswering
from transformers import pipeline
@st.cache
def load_data(file):
    """Read the first 50 rows of an uploaded CSV into a DataFrame.

    Cached by Streamlit so script re-runs skip the parse.
    """
    return pd.read_csv(file, encoding='utf-8', nrows=50)
def load_pipeline(model_cp, tokenizer_cp):
    """Build a HuggingFace question-answering pipeline from the given
    model and tokenizer checkpoints."""
    qa_pipeline = pipeline(
        "question-answering",
        model=model_cp,
        tokenizer=tokenizer_cp,
    )
    return qa_pipeline
def choose_model():
    """Sidebar widget for picking the inference model.

    Returns
    -------
    tuple[str, str]
        (model_checkpoint, tokenizer_checkpoint). The custom
        recipe-improver checkpoint is paired with the albert-base-v2
        tokenizer; any other selection uses its own checkpoint for both.
    """
    with st.sidebar:
        st.write("# Model Selection")
        model_cp = st.selectbox('Select model for inference',
                                ('deepset/roberta-base-squad2',
                                 'aidan-o-brien/recipe-improver'))
    # Presumably the fine-tuned model ships without its own tokenizer,
    # so fall back to the albert base tokenizer for it.
    if model_cp == "aidan-o-brien/recipe-improver":
        tokenizer_cp = "albert-base-v2"
    else:
        tokenizer_cp = model_cp
    return model_cp, tokenizer_cp
def choose_postprocessing():
    """Sidebar widget selecting how raw pipeline answers are cleaned up.

    Returns the selected method name as a string.
    """
    methods = ('no postprocessing',
               'remove substrings',
               'iteration')
    with st.sidebar:
        st.write('# Postprocessing')
        choice = st.selectbox('Select postprocessing method', methods)
        return choice
def remove_substring(resp):
    """Filter QA pipeline output, dropping answers that are substrings
    of another (longer) answer.

    Parameters
    ----------
    resp : list[dict]
        Pipeline results; each dict has at least an 'answer' key.

    Returns
    -------
    list[dict]
        The entries of `resp`, original order preserved, whose 'answer'
        is not contained in any longer surviving answer.
    """
    # Longest first, so each candidate is only checked against strings
    # that could actually contain it.
    answers = sorted((r['answer'] for r in resp), key=len, reverse=True)
    kept = []
    for candidate in answers:
        # Generator form short-circuits on the first containing match
        # (the original built a throwaway list inside any()).
        if not any(candidate in longer for longer in kept):
            kept.append(candidate)
    # Rebuild from resp to preserve the pipeline's score ordering.
    return [r for r in resp if r['answer'] in kept]
def iterate(resp, context):
    """Strip the current best answer out of the context so a follow-up
    pass can surface different spans.

    Returns
    -------
    tuple[str, list[dict]]
        The shortened context and the substring-filtered responses.
    """
    filtered = remove_substring(resp)
    # filtered preserves pipeline score order, so index 0 is the best
    # surviving answer; deleting every occurrence forces the model to
    # look elsewhere on the next pass.
    top_span = filtered[0]['answer']
    shortened = context.replace(top_span, "")
    return shortened, filtered
def remove_substring_iter(pipeline, question, context,
                          max_iter=5, threshold=1e-3):
    """Iteratively query the QA pipeline, removing the best answer from
    the context after each pass so later passes find new spans.

    Parameters
    ----------
    pipeline : callable
        A question-answering pipeline called as
        pipeline(question, context, top_k=..., handle_impossible_answer=...).
        NOTE(review): this parameter shadows the transformers `pipeline`
        import at module level; kept for caller compatibility.
    question : str
        Question posed on every pass.
    context : str
        Initial context; shrinks as answers are removed.
    max_iter : int, optional
        Maximum number of query/remove passes (default 5, the previous
        hard-coded value).
    threshold : float, optional
        Stop once the best remaining score drops below this (default
        1e-3, the previous hard-coded value).

    Returns
    -------
    list[dict]
        The best answer dict from each pass, in order.
    """
    best_answers = []
    for _ in range(max_iter):
        resp = pipeline(question, context, top_k=5,
                        handle_impossible_answer=True)
        # Remove the best answer so the next pass must look elsewhere.
        context, filtered_resp = iterate(resp, context)
        # Quit once confidence falls below the threshold.
        if filtered_resp[0]['score'] < threshold:
            break
        best_answers.append(filtered_resp[0])
    return best_answers
# --- Page config -----------------------------------------------------------
title = "Recipe Improver"
icon = "🍣"
st.set_page_config(page_title=title, page_icon=icon)
st.title(title)

# Sidebar: choose model checkpoint and postprocessing procedure
model_cp, tokenizer_cp = choose_model()
postprocessing = choose_postprocessing()

# Load model and tokenizer
question_answer = load_pipeline(model_cp, tokenizer_cp)
st.write("Model and tokenizer successfully loaded.")

# Upload csv - wrapped in an expander for aesthetics
with st.expander("Upload csv file"):
    uploaded_file = st.file_uploader("Choose a csv file", type="csv",
                                     key='file_uploader')

# If file is uploaded, run inference - QA pipelines can't run batch
if uploaded_file is not None:
    df = load_data(uploaded_file)

    # Ask user for index to try out.
    # BUG FIX: the prompt previously read "(max N" with no closing paren.
    user_idx = st.text_input(f'Enter index (max {df.shape[0] - 1})', "0")

    # st.text_input returns a string (default "0"), never None, so the
    # old `is not None` check always passed and int() could crash on
    # empty/non-numeric input. Validate explicitly instead.
    try:
        idx = int(user_idx)
    except ValueError:
        idx = None
        st.warning("Please enter a whole-number index.")

    # Only run the rest of the script if user provides a valid index
    if idx is not None:
        first_example = df['review'][idx]
        question = "How was the recipe improved?"
        if postprocessing == "no postprocessing":
            resp = question_answer(question=question,
                                   context=first_example,
                                   top_k=5)
        elif postprocessing == "remove substrings":
            resp = question_answer(question=question,
                                   context=first_example,
                                   top_k=5)
            resp = remove_substring(resp)
        elif postprocessing == "iteration":
            resp = remove_substring_iter(question_answer,
                                         question,
                                         first_example)
        # Present results
        st.markdown(f"""
# Results
The review provided was:
{first_example}
The question asked was:
{question}
The answers were:
{resp}
""")