# Streamlit demo app: "Recipe Improver" — extractive question answering over recipe reviews.
# Standard library (NOTE(review): this import appears unused and distutils is
# deprecated — kept to avoid changing module-level behavior; confirm and remove).
from distutils.command.upload import upload

# Third-party
import pandas as pd
import streamlit as st
from transformers import AutoTokenizer, TFAutoModelForQuestionAnswering
from transformers import pipeline
def load_data(file):
    """Read the first 50 rows of an uploaded CSV into a DataFrame.

    Args:
        file: Path or file-like object accepted by ``pandas.read_csv``.

    Returns:
        ``pandas.DataFrame`` with at most 50 rows.
    """
    # Cap at 50 rows to keep demo inference quick on large uploads.
    df = pd.read_csv(file, encoding='utf-8', nrows=50)
    return df
def load_pipeline(model_cp, tokenizer_cp):
    """Build a Hugging Face question-answering pipeline.

    Args:
        model_cp: Model checkpoint name on the Hugging Face Hub.
        tokenizer_cp: Tokenizer checkpoint name (may differ from the model's).

    Returns:
        A ``transformers`` question-answering pipeline.
    """
    return pipeline("question-answering", model=model_cp, tokenizer=tokenizer_cp)
def choose_model():
    """Render the sidebar model picker and return checkpoint names.

    Returns:
        Tuple ``(model_cp, tokenizer_cp)``. The fine-tuned
        ``aidan-o-brien/recipe-improver`` model is paired with the
        ``albert-base-v2`` tokenizer; every other model uses its own
        checkpoint as the tokenizer.
    """
    with st.sidebar:
        st.write("# Model Selection")
        model_cp = st.selectbox('Select model for inference',
                                ('deepset/roberta-base-squad2',
                                 'aidan-o-brien/recipe-improver'))
        # NOTE(review): pairing with albert-base-v2 presumably matches the
        # base model this checkpoint was fine-tuned from — confirm against
        # the model card.
        if model_cp == "aidan-o-brien/recipe-improver":
            return model_cp, "albert-base-v2"
        else:
            return model_cp, model_cp
def choose_postprocessing():
    """Render the sidebar postprocessing picker.

    Returns:
        The selected method name: one of ``'no postprocessing'``,
        ``'remove substrings'`` or ``'iteration'``.
    """
    with st.sidebar:
        st.write('# Postprocessing')
        postprocessing = st.selectbox('Select postprocessing method',
                                      ('no postprocessing',
                                       'remove substrings',
                                       'iteration'))
        return postprocessing
def remove_substring(resp):
    """Drop answers that are substrings of another, longer answer.

    Args:
        resp: List of pipeline answer dicts, each with an ``'answer'`` key.

    Returns:
        The input dicts, in their original order, whose answer text is not
        contained within any longer surviving answer.
    """
    answers = [answer['answer'] for answer in resp]
    # Longest first, so each candidate is only tested against longer keepers.
    answers.sort(key=len, reverse=True)
    filtered_answers = []
    for s in answers:
        if not any(s in o for o in filtered_answers):
            filtered_answers.append(s)
    # Rebuild from resp so the pipeline's original (score-sorted) order is kept.
    return [r for r in resp if r['answer'] in filtered_answers]
def iterate(resp, context):
    """Strip the best answer out of *context* so the pipeline can be re-run.

    Args:
        resp: Pipeline answer dicts, best-scored first.
        context: The text the answers were extracted from.

    Returns:
        Tuple ``(new_context, filtered_resp)`` — the context with every
        occurrence of the best answer removed, and the substring-filtered
        response list.

    Raises:
        IndexError: If *resp* is empty (nothing survives filtering).
    """
    # Remove substring answers first; order is preserved, so the first
    # survivor is still the best-scored answer.
    filtered_resp = remove_substring(resp)
    best_answer = filtered_resp[0]['answer']
    # str.replace with no count argument removes *all* occurrences.
    new_context = context.replace(best_answer, "")
    return new_context, filtered_resp
def remove_substring_iter(pipeline, question, context):
    """Iteratively extract up to five distinct answers from *context*.

    After each pass the best answer is deleted from the context and the
    pipeline is re-run, so later passes can surface different spans.

    Args:
        pipeline: A question-answering pipeline callable. (NOTE: this
            parameter shadows the module-level ``transformers.pipeline``
            import inside this function body.)
        question: Question string asked on every pass.
        context: Text to extract answers from.

    Returns:
        List of the best answer dict from each pass; stops early once the
        best score drops below 1e-2.
    """
    ret_resp = []
    for _ in range(5):
        # Call with keyword arguments, consistent with how the pipeline is
        # invoked elsewhere in this script — two positional strings are not
        # a documented calling convention for QA pipelines.
        resp = pipeline(question=question, context=context, top_k=5,
                        handle_impossible_answer=True)
        # Remove the best answer from the context for the next pass.
        context, filtered_resp = iterate(resp, context)
        # If the best score is not above the threshold, quit without keeping it.
        if filtered_resp[0]['score'] < 1e-2:
            break
        ret_resp.append(filtered_resp[0])
    return ret_resp
# Page config — st.set_page_config must be the first Streamlit call.
title = "Recipe Improver"
icon = "🍣"
st.set_page_config(page_title=title, page_icon=icon)
st.title(title)
# Sidebar: choose model/tokenizer checkpoints and postprocessing method.
model_cp, tokenizer_cp = choose_model()
postprocessing = choose_postprocessing()

# Load model and tokenizer (downloaded from the Hub on first run).
question_answer = load_pipeline(model_cp, tokenizer_cp)
st.write("Model and tokenizer successfully loaded.")

# Upload csv - format with expander for aesthetics
with st.expander("Upload csv file"):
    uploaded_file = st.file_uploader("Choose a csv file", type="csv", key='file_uploader')
# If file is uploaded, run inference - QA pipelines can't run batch
if uploaded_file is not None:
    df = load_data(uploaded_file)
    # Ask user for index to try out. (Fixed: prompt was missing its
    # closing parenthesis.)
    user_idx = st.text_input(f'Enter index (max {df.shape[0] - 1})', "0")
    # Only run rest of script if user provides index
    if user_idx is not None:
        # NOTE(review): assumes the csv has a 'review' column and that
        # user_idx parses as a valid in-range int — bad input will raise.
        # Confirm whether friendlier validation is wanted here.
        first_example = df['review'][int(user_idx)]
        question = "How was the recipe improved?"
        if postprocessing == "no postprocessing":
            resp = question_answer(question=question,
                                   context=first_example,
                                   top_k=5)
        elif postprocessing == "remove substrings":
            resp = question_answer(question=question,
                                   context=first_example,
                                   top_k=5)
            resp = remove_substring(resp)
        elif postprocessing == "iteration":
            resp = remove_substring_iter(question_answer,
                                         question,
                                         first_example)
        # Present results
        st.markdown(f"""
        # Results
        The review provided was:
        {first_example}
        The question asked was:
        {question}
        The answers were:
        {resp}
        """)