# NOTE(review): the lines below were Hugging Face Spaces file-viewer
# artifacts captured by a page scrape (runtime status, commit hashes, and
# a line-number ruler), not Python source. Commented out so the file parses.
# Spaces: Runtime error | File size: 4,408 Bytes
# Commits: 1e434e8 8ddfa6e 961346c 1f6bd67 c28ffd1 23e04e2 c78eaed a1d41c3 15ccfd9 8fbe78b 8fa0f4a 124247d
from distutils.command.upload import upload
import pandas as pd
import streamlit as st
from transformers import AutoTokenizer, TFAutoModelForQuestionAnswering
from transformers import pipeline
@st.cache
def load_data(file, nrows=50):
    """Read an uploaded CSV into a DataFrame, capped at *nrows* rows.

    Args:
        file: path or file-like object accepted by ``pd.read_csv``
            (here, the Streamlit upload buffer).
        nrows: maximum number of rows to read. Defaults to 50 — the
            previous hard-coded cap — to keep the demo responsive on
            large uploads; callers may raise it.

    Returns:
        pandas.DataFrame with at most *nrows* rows.
    """
    df = pd.read_csv(file, encoding='utf-8', nrows=nrows)
    return df
def load_pipeline(model_cp, tokenizer_cp):
    """Construct a Hugging Face question-answering pipeline.

    Args:
        model_cp: model checkpoint name passed to ``pipeline``.
        tokenizer_cp: tokenizer checkpoint name (may differ from the
            model checkpoint — see ``choose_model``).

    Returns:
        A ``transformers`` question-answering pipeline object.
    """
    qa_pipeline = pipeline(
        "question-answering",
        model=model_cp,
        tokenizer=tokenizer_cp,
    )
    return qa_pipeline
def choose_model():
    """Render the sidebar model selector.

    Returns:
        Tuple ``(model_cp, tokenizer_cp)`` of checkpoint names. For the
        'aidan-o-brien/recipe-improver' model the tokenizer checkpoint is
        'albert-base-v2' (presumably its base model's tokenizer — the
        fine-tuned repo appears not to supply one); for every other model
        the tokenizer checkpoint equals the model checkpoint.
    """
    with st.sidebar:
        st.write("# Model Selection")
        model_cp = st.selectbox(
            'Select model for inference',
            ('deepset/roberta-base-squad2',
             'aidan-o-brien/recipe-improver'),
        )
    tokenizer_cp = (
        "albert-base-v2"
        if model_cp == "aidan-o-brien/recipe-improver"
        else model_cp
    )
    return model_cp, tokenizer_cp
def choose_postprocessing():
    """Render the sidebar postprocessing selector.

    Returns:
        One of 'no postprocessing', 'remove substrings', or 'iteration'.
    """
    methods = ('no postprocessing', 'remove substrings', 'iteration')
    with st.sidebar:
        st.write('# Postprocessing')
        choice = st.selectbox('Select postprocessing method', methods)
        return choice
def remove_substring(resp):
    """Drop pipeline answers whose text is a substring of another answer.

    Args:
        resp: list of answer dicts as returned by a question-answering
            pipeline; each dict has at least an 'answer' key.

    Returns:
        The subset of *resp*, original order preserved, whose 'answer'
        text is not contained in any longer retained answer.
    """
    # Examine longest-first so a long answer is retained before any of
    # its substrings are considered (exact duplicates also collapse).
    answers = sorted((r['answer'] for r in resp), key=len, reverse=True)
    kept = []
    for candidate in answers:
        # Generator form short-circuits on the first containing match
        # (the original built a throwaway list inside any()).
        if not any(candidate in other for other in kept):
            kept.append(candidate)
    return [r for r in resp if r['answer'] in kept]
def iterate(resp, context):
    """One refinement step for iterative answer extraction.

    Filters substring answers out of *resp*, then blanks every occurrence
    of the best remaining answer out of *context* so the next pipeline
    pass cannot find it again.

    Args:
        resp: list of answer dicts from the pipeline.
        context: the text the answers were extracted from.

    Returns:
        Tuple ``(new_context, filtered_resp)``.
    """
    filtered = remove_substring(resp)
    top_answer = filtered[0]['answer']
    return context.replace(top_answer, ""), filtered
def remove_substring_iter(pipeline, question, context):
    """Iteratively extract up to five non-overlapping answers.

    Each pass asks the pipeline for its top answers, strips substring
    duplicates, and removes the best answer's text from the context
    before the next pass. Stops early once the best remaining score
    falls below 1e-2.

    NOTE: the ``pipeline`` parameter shadows the module-level
    ``transformers.pipeline`` import; the name is kept for interface
    compatibility.

    Args:
        pipeline: a callable question-answering pipeline.
        question: question string to ask on every pass.
        context: text to extract answers from.

    Returns:
        List of the best answer dict from each pass that cleared the
        score threshold.
    """
    collected = []
    for _ in range(5):
        raw = pipeline(question, context, top_k=5,
                       handle_impossible_answer=True)
        # Blank the best answer out of the context for the next pass.
        context, filtered = iterate(raw, context)
        best = filtered[0]
        if best['score'] < 1e-2:
            break
        collected.append(best)
    return collected
# Page config
title = "Recipe Improver"
icon = "🍣"
st.set_page_config(page_title=title, page_icon=icon)
st.title(title)

# Choose model and postprocessing procedure
model_cp, tokenizer_cp = choose_model()
postprocessing = choose_postprocessing()

# Load model and tokenizer
question_answer = load_pipeline(model_cp, tokenizer_cp)
st.write("Model and tokenizer successfully loaded.")

# Upload csv - format with expander for aesthetics
with st.expander("Upload csv file"):
    uploaded_file = st.file_uploader("Choose a csv file", type="csv",
                                     key='file_uploader')

# If file is uploaded, run inference - QA pipelines can't run batch
if uploaded_file is not None:
    df = load_data(uploaded_file)

    # Ask user for index to try out.
    # FIX: the prompt string was missing its closing parenthesis.
    user_idx = st.text_input(f'Enter index (max {df.shape[0] - 1})', "0")

    # FIX: st.text_input returns a string (never None), so the previous
    # `is not None` guard could not reject bad input — an empty or
    # non-numeric entry crashed int(), and an oversized index raised a
    # lookup error. Validate before indexing.
    if user_idx.strip().isdigit() and int(user_idx) < df.shape[0]:
        first_example = df['review'][int(user_idx)]
        question = "How was the recipe improved?"

        if postprocessing == "no postprocessing":
            resp = question_answer(question=question,
                                   context=first_example,
                                   top_k=5)
        elif postprocessing == "remove substrings":
            resp = question_answer(question=question,
                                   context=first_example,
                                   top_k=5)
            resp = remove_substring(resp)
        elif postprocessing == "iteration":
            resp = remove_substring_iter(question_answer,
                                         question,
                                         first_example)

        # Present results (f-string body kept unindented so the markdown
        # headings render at column 0).
        st.markdown(f"""
# Results
The review provided was:
{first_example}
The question asked was:
{question}
The answers were:
{resp}
""")