File size: 4,408 Bytes
1e434e8
 
 
 
 
8ddfa6e
1e434e8
 
 
 
 
 
 
 
961346c
 
8ddfa6e
 
1f6bd67
c28ffd1
1f6bd67
 
 
 
 
 
 
 
 
 
 
23e04e2
 
 
 
 
 
 
 
 
 
 
c78eaed
 
 
 
 
 
 
 
 
 
 
 
 
 
23e04e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a1d41c3
23e04e2
 
 
 
 
1e434e8
 
 
 
 
 
 
23e04e2
1f6bd67
23e04e2
 
 
 
961346c
1e434e8
 
 
15ccfd9
 
 
 
 
8fbe78b
8fa0f4a
 
 
124247d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
from distutils.command.upload import upload
import pandas as pd
import streamlit as st

from transformers import AutoTokenizer, TFAutoModelForQuestionAnswering
from transformers import pipeline


@st.cache
def load_data(file, nrows=50):
    """Read an uploaded CSV file into a DataFrame.

    Args:
        file: file-like object or path accepted by ``pd.read_csv``.
        nrows: maximum number of rows to read. Defaults to 50 to keep
            the demo responsive (generalized from the old hard-coded cap).

    Returns:
        pandas.DataFrame with at most ``nrows`` rows.
    """
    # Cached by Streamlit so script re-runs don't re-parse the upload.
    df = pd.read_csv(file, encoding='utf-8', nrows=nrows)
    return df


def load_pipeline(model_cp, tokenizer_cp):
    """Build a HuggingFace question-answering pipeline.

    Args:
        model_cp: model checkpoint name or path.
        tokenizer_cp: tokenizer checkpoint name or path.

    Returns:
        A ``transformers`` question-answering pipeline object.
    """
    qa_pipeline = pipeline(
        "question-answering",
        model=model_cp,
        tokenizer=tokenizer_cp,
    )
    return qa_pipeline


def choose_model():
    """Sidebar widget to pick the inference model.

    Returns:
        Tuple ``(model_cp, tokenizer_cp)``. For the fine-tuned
        ``aidan-o-brien/recipe-improver`` checkpoint the tokenizer falls
        back to ``albert-base-v2``; otherwise the model checkpoint is
        reused as the tokenizer checkpoint.
    """
    available = ('deepset/roberta-base-squad2',
                 'aidan-o-brien/recipe-improver')
    with st.sidebar:
        st.write("# Model Selection")
        model_cp = st.selectbox('Select model for inference', available)
        # The custom checkpoint ships without its own tokenizer, so pair
        # it with the albert base tokenizer; every other model is
        # self-contained.
        if model_cp != "aidan-o-brien/recipe-improver":
            return model_cp, model_cp
        return model_cp, "albert-base-v2"


def choose_postprocessing():
    """Sidebar widget to pick the answer postprocessing strategy.

    Returns:
        One of 'no postprocessing', 'remove substrings', 'iteration'.
    """
    methods = ('no postprocessing',
               'remove substrings',
               'iteration')
    with st.sidebar:
        st.write('# Postprocessing')
        selected = st.selectbox('Select postprocessing method', methods)
    return selected


def remove_substring(resp):
    """Drop answers that are substrings of another, longer answer.

    Args:
        resp: list of QA-pipeline result dicts, each with at least an
            'answer' key.

    Returns:
        The entries of ``resp`` whose answer text is not contained in any
        longer surviving answer, in their original order.
    """
    answers = [answer['answer'] for answer in resp]

    # Longest first: an answer survives only if no already-kept (longer
    # or equal) answer contains it.
    answers.sort(key=len, reverse=True)
    filtered_answers = []
    for s in answers:
        if not any(s in other for other in filtered_answers):
            filtered_answers.append(s)

    # Keep original response order and metadata for surviving answers.
    return [r for r in resp if r['answer'] in filtered_answers]


def iterate(resp, context):
    """Run one round of the iterative answer-extraction loop.

    Filters substring answers out of ``resp``, then erases the best
    remaining answer from ``context`` so a subsequent pipeline pass is
    forced to find something new.

    Returns:
        Tuple ``(new_context, filtered_resp)``.
    """
    filtered_resp = remove_substring(resp)
    top_answer = filtered_resp[0]['answer']
    # NOTE(review): str.replace removes *every* occurrence of the answer
    # text, not only the matched span — presumably acceptable here.
    new_context = context.replace(top_answer, "")
    return new_context, filtered_resp


def remove_substring_iter(pipeline, question, context):
    """Iteratively extract answers, erasing each best answer in turn.

    Runs the QA ``pipeline`` (parameter shadows the module-level
    ``transformers.pipeline`` import) up to five times; after each pass
    the best answer is removed from the context so the next pass must
    surface a different span. Stops early once the best answer's score
    drops below 1e-2.

    Returns:
        List of the best answer dict from each confident pass (possibly
        empty if the very first pass is below threshold).
    """
    collected = []

    passes = 0
    while passes < 5:
        passes += 1
        resp = pipeline(question, context, top_k=5,
                        handle_impossible_answer=True)
        # Strip the current best answer out of the context for next pass.
        context, filtered_resp = iterate(resp, context)
        # Best remaining answer no longer confident — stop accumulating.
        if filtered_resp[0]['score'] < 1e-2:
            break
        collected.append(filtered_resp[0])
    return collected


# Page config
title = "Recipe Improver"
icon = "🍣"
st.set_page_config(page_title=title, page_icon=icon)
st.title(title)


# Choose model and postprocessing procedure
model_cp, tokenizer_cp = choose_model()
postprocessing = choose_postprocessing()


# Load model and tokenizer
question_answer = load_pipeline(model_cp, tokenizer_cp)
st.write("Model and tokenizer successfully loaded.")


# Upload csv - format with expander for aesthetics
with st.expander("Upload csv file"):
    uploaded_file = st.file_uploader("Choose a csv file", type="csv", key='file_uploader')


# If file is uploaded, run inference - QA pipelines can't run batch
if uploaded_file is not None:
    df = load_data(uploaded_file)
    max_idx = df.shape[0] - 1

    # Ask user for index to try out (fixed: prompt had an unclosed paren)
    user_idx = st.text_input(f'Enter index (max {max_idx})', "0")

    # Validate input instead of crashing on non-numeric or out-of-range
    # values (int() would raise ValueError; df lookup would raise KeyError)
    try:
        idx = int(user_idx)
    except ValueError:
        idx = -1

    if 0 <= idx <= max_idx:
        first_example = df['review'][idx]
        question = "How was the recipe improved?"
        if postprocessing == "no postprocessing":
            resp = question_answer(question=question,
                                   context=first_example,
                                   top_k=5)
        elif postprocessing == "remove substrings":
            resp = question_answer(question=question,
                                   context=first_example,
                                   top_k=5)
            resp = remove_substring(resp)
        elif postprocessing == "iteration":
            resp = remove_substring_iter(question_answer,
                                         question,
                                         first_example)

        # Present results
        st.markdown(f"""
        # Results

        The review provided was:

        {first_example}

        The question asked was:

        {question}

        The answers were:

        {resp}
        """)
    else:
        st.warning(f"Please enter an integer between 0 and {max_idx}.")