import pandas as pd
import streamlit as st
from transformers import pipeline


@st.cache
def load_data(file):
    df = pd.read_csv(file, encoding='utf-8', nrows=50)
    return df


def load_pipeline(model_cp, tokenizer_cp):
    return pipeline("question-answering", model=model_cp, tokenizer=tokenizer_cp)


def choose_model():
    with st.sidebar:
        st.write("# Model Selection")
        model_cp = st.selectbox('Select model for inference',
                                ('deepset/roberta-base-squad2',
                                 'aidan-o-brien/recipe-improver'))
    # The fine-tuned model was trained with the albert-base-v2 tokenizer;
    # for any other model, the tokenizer checkpoint matches the model checkpoint
    if model_cp == "aidan-o-brien/recipe-improver":
        return model_cp, "albert-base-v2"
    else:
        return model_cp, model_cp


def choose_postprocessing():
    with st.sidebar:
        st.write('# Postprocessing')
        postprocessing = st.selectbox('Select postprocessing method',
                                      ('no postprocessing',
                                       'remove substrings',
                                       'iteration'))
    return postprocessing


def remove_substring(resp):
    """Filter model output to remove answers that are substrings of
    another answer."""
    answers = [answer['answer'] for answer in resp]
    answers.sort(key=lambda s: len(s), reverse=True)

    # Keep an answer only if it is not contained in a longer kept answer
    filtered_answers = []
    for s in answers:
        if not any(s in o for o in filtered_answers):
            filtered_answers.append(s)

    filtered_resp = list(filter(lambda r: r['answer'] in filtered_answers, resp))
    return filtered_resp


def iterate(resp, context):
    # Remove substring answers
    filtered_resp = remove_substring(resp)
    # Remove best answer from original context
    best_answer = filtered_resp[0]['answer']
    new_context = context.replace(best_answer, "")
    return new_context, filtered_resp


def remove_substring_iter(qa_pipeline, question, context):
    # Create answers placeholder
    ret_resp = []
    # Loop at most five times, removing the best answer and re-running
    for _ in range(5):
        resp = qa_pipeline(question=question, context=context,
                           top_k=5, handle_impossible_answer=True)
        # Update context
        context, filtered_resp = iterate(resp, context)
        # If best score not above threshold, quit
        if filtered_resp[0]['score'] < 1e-2:
            break
        ret_resp.append(filtered_resp[0])
    return ret_resp


# Page config
title = "Recipe Improver"
icon = "🍣"
st.set_page_config(page_title=title, page_icon=icon)
st.title(title)

# Choose model and postprocessing procedure
model_cp, tokenizer_cp = choose_model()
postprocessing = choose_postprocessing()

# Load model and tokenizer
question_answer = load_pipeline(model_cp, tokenizer_cp)
st.write("Model and tokenizer successfully loaded.")

# Upload csv - format with expander for aesthetics
with st.expander("Upload csv file"):
    uploaded_file = st.file_uploader("Choose a csv file",
                                     type="csv",
                                     key='file_uploader')

# If a file is uploaded, run inference - QA pipelines can't run batches
if uploaded_file is not None:
    df = load_data(uploaded_file)

    # Ask user for index to try out
    user_idx = st.text_input(f'Enter index (max {df.shape[0] - 1})', "0")

    # Only run rest of script if user provides an index
    if user_idx:
        first_example = df['review'][int(user_idx)]
        question = "How was the recipe improved?"
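        # With top_k > 1 the question-answering pipeline returns a list of
        # dicts, one per candidate answer, e.g. (values illustrative only):
        #   [{'score': 0.42, 'start': 10, 'end': 25, 'answer': 'used less sugar'}, ...]
        # Every postprocessing branch below consumes this format.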
        if postprocessing == "no postprocessing":
            resp = question_answer(question=question,
                                   context=first_example,
                                   top_k=5)
        elif postprocessing == "remove substrings":
            resp = question_answer(question=question,
                                   context=first_example,
                                   top_k=5)
            resp = remove_substring(resp)
        elif postprocessing == "iteration":
            resp = remove_substring_iter(question_answer, question, first_example)

        # Present results
        st.markdown(f"""
# Results

The review provided was:

{first_example}

The question asked was:

{question}

The answers were:

{resp}
""")
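# To try the app locally: `streamlit run app.py` (the filename is illustrative,
# use whatever this file is saved as), then upload a csv containing a 'review'
# column, since the example lookup above reads df['review'].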