awinml committed on
Commit
e514fa8
1 Parent(s): 40eb760

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -11
app.py CHANGED
@@ -59,11 +59,17 @@ def save_key(api_key):
59
  return api_key
60
 
61
 
62
- def query_pinecone(query, top_k, model, index):
63
  # generate embeddings for the query
64
  xq = model.encode([query]).tolist()
65
  # search pinecone index for context passage with the answer
66
  xc = index.query(xq, top_k=top_k, include_metadata=True)
 
 
 
 
 
 
67
  return xc
68
 
69
 
@@ -127,19 +133,19 @@ st.title("Abstractive Question Answering - APPL")
127
 
128
  query_text = st.text_input("Input Query", value="Who is the CEO of Apple?")
129
 
130
- num_results = int(st.number_input("Number of Results to query", 1, 5, value=2))
131
 
132
 
133
  # Choose encoder model
134
 
135
- encoder_models_choice = ["MPNET", "SGPT"]
136
 
137
  encoder_model = st.selectbox("Select Encoder Model", encoder_models_choice)
138
 
139
 
140
  # Choose decoder model
141
 
142
- decoder_models_choice = ["GPT3 (QA_davinci)", "GPT3 (text_davinci)", "T5", "FLAN-T5"]
143
 
144
  decoder_model = st.selectbox("Select Decoder Model", decoder_models_choice)
145
 
@@ -163,23 +169,33 @@ elif encoder_model == "SGPT":
163
  retriever_model = get_sgpt_embedding_model()
164
 
165
 
166
- query_results = query_pinecone(query_text, num_results, retriever_model, pinecone_index)
167
-
168
  window = int(st.number_input("Sentence Window Size", 1, 3, value=1))
169
 
 
 
 
 
 
 
170
  data = get_data()
171
 
172
- # context_list = format_query(query_results)
173
- context_list = sentence_id_combine(data, query_results, lag=window)
 
 
 
 
 
 
174
 
175
 
176
  st.subheader("Answer:")
177
 
178
 
179
- if decoder_model == "GPT3 (text_davinci)":
180
  openai_key = st.text_input(
181
  "Enter OpenAI key",
182
- value="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",
183
  type="password",
184
  )
185
  api_key = save_key(openai_key)
@@ -193,7 +209,7 @@ if decoder_model == "GPT3 (text_davinci)":
193
  elif decoder_model == "GPT3 (QA_davinci)":
194
  openai_key = st.text_input(
195
  "Enter OpenAI key",
196
- value="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",
197
  type="password",
198
  )
199
  api_key = save_key(openai_key)
 
59
  return api_key
60
 
61
 
62
+ def query_pinecone(query, top_k, model, index, threshold=0.5):
63
  # generate embeddings for the query
64
  xq = model.encode([query]).tolist()
65
  # search pinecone index for context passage with the answer
66
  xc = index.query(xq, top_k=top_k, include_metadata=True)
67
+ # filter the context passages based on the score threshold
68
+ filtered_matches = []
69
+ for match in xc["matches"]:
70
+ if match["score"] >= threshold:
71
+ filtered_matches.append(match)
72
+ xc["matches"] = filtered_matches
73
  return xc
74
 
75
 
 
133
 
134
  query_text = st.text_input("Input Query", value="Who is the CEO of Apple?")
135
 
136
+ num_results = int(st.number_input("Number of Results to query", 1, 5, value=3))
137
 
138
 
139
  # Choose encoder model
140
 
141
+ encoder_models_choice = ["SGPT", "MPNET"]
142
 
143
  encoder_model = st.selectbox("Select Encoder Model", encoder_models_choice)
144
 
145
 
146
  # Choose decoder model
147
 
148
+ decoder_models_choice = ["GPT3 (QA_davinci)", "GPT3 (summary_davinci)", "T5", "FLAN-T5"]
149
 
150
  decoder_model = st.selectbox("Select Decoder Model", decoder_models_choice)
151
 
 
169
  retriever_model = get_sgpt_embedding_model()
170
 
171
 
 
 
172
  window = int(st.number_input("Sentence Window Size", 1, 3, value=1))
173
 
174
+ threshold = float(
175
+ st.number_input(
176
+ label="Similarity Score Threshold", step=0.05, format="%.2f", value=0.55
177
+ )
178
+ )
179
+
180
  data = get_data()
181
 
182
+ query_results = query_pinecone(
183
+ query_text, num_results, retriever_model, pinecone_index, threshold
184
+ )
185
+
186
+ if threshold <= 0.65:
187
+ context_list = sentence_id_combine(data, query_results, lag=window)
188
+ else:
189
+ context_list = format_query(query_results)
190
 
191
 
192
  st.subheader("Answer:")
193
 
194
 
195
+ if decoder_model == "GPT3 (summary_davinci)":
196
  openai_key = st.text_input(
197
  "Enter OpenAI key",
198
+ value="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",
199
  type="password",
200
  )
201
  api_key = save_key(openai_key)
 
209
  elif decoder_model == "GPT3 (QA_davinci)":
210
  openai_key = st.text_input(
211
  "Enter OpenAI key",
212
+ value="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",
213
  type="password",
214
  )
215
  api_key = save_key(openai_key)