awinml commited on
Commit
1a08523
1 Parent(s): 8b61059

Upload 8 files

Browse files
Files changed (4) hide show
  1. app.py +287 -129
  2. utils/models.py +8 -6
  3. utils/prompts.py +48 -1
  4. utils/retriever.py +52 -0
app.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import openai
2
  import streamlit_scrollable_textbox as stx
3
 
@@ -25,7 +27,7 @@ from utils.models import (
25
  get_spacy_model,
26
  get_splade_sparse_embedding_model,
27
  get_t5_model,
28
- gpt_model,
29
  save_key,
30
  )
31
  from utils.prompts import (
@@ -36,8 +38,10 @@ from utils.prompts import (
36
  generate_flant5_prompt_summ_chunk_context_single,
37
  generate_gpt_j_two_shot_prompt_1,
38
  generate_gpt_j_two_shot_prompt_2,
39
- generate_gpt_prompt,
40
- generate_gpt_prompt_2,
 
 
41
  get_context_list_prompt,
42
  )
43
  from utils.retriever import (
@@ -46,6 +50,7 @@ from utils.retriever import (
46
  query_pinecone_sparse,
47
  sentence_id_combine,
48
  text_lookup,
 
49
  )
50
  from utils.transcript_retrieval import retrieve_transcript
51
  from utils.vector_index import (
@@ -66,59 +71,29 @@ col1, col2 = st.columns([3, 3], gap="medium")
66
 
67
  with st.sidebar:
68
  ner_choice = st.selectbox("Select NER Model", ["Spacy", "Alpaca"])
 
 
 
69
 
70
  if ner_choice == "Spacy":
71
  ner_model = get_spacy_model()
72
 
73
  with col1:
74
  st.subheader("Question")
75
- query_text = st.text_area(
76
- "Input Query",
77
- value="What was discussed regarding Wearables revenue performance?",
78
- )
79
-
80
- if ner_choice == "Alpaca":
81
- ner_prompt = generate_alpaca_ner_prompt(query_text)
82
- entity_text = generate_entities_flan_alpaca_inference_api(ner_prompt)
83
- company_ent, quarter_ent, year_ent = format_entities_flan_alpaca(
84
- entity_text
85
- )
86
- else:
87
- company_ent = extract_ticker_spacy(query_text, ner_model)
88
- quarter_ent, year_ent = extract_quarter_year(query_text)
89
-
90
- ticker_index, quarter_index, year_index = clean_entities(
91
- company_ent, quarter_ent, year_ent
92
- )
93
-
94
- with col1:
95
- years_choice = ["2020", "2019", "2018", "2017", "2016", "All"]
96
-
97
- with col1:
98
- # Hardcoding the defaults for a question without metadata
99
- if (
100
- query_text
101
- == "What was discussed regarding Wearables revenue performance?"
102
- ):
103
- year = st.selectbox("Year", years_choice)
104
- else:
105
- year = st.selectbox("Year", years_choice, index=year_index)
106
-
107
- with col1:
108
- # Hardcoding the defaults for a question without metadata
109
- if (
110
- query_text
111
- == "What was discussed regarding Wearables revenue performance?"
112
- ):
113
- quarter = st.selectbox("Quarter", ["Q1", "Q2", "Q3", "Q4", "All"])
114
  else:
115
- quarter = st.selectbox(
116
- "Quarter", ["Q1", "Q2", "Q3", "Q4", "All"], index=quarter_index
 
117
  )
118
 
119
- with col1:
120
- participant_type = st.selectbox("Speaker", ["Company Speaker", "Analyst"])
121
 
 
 
122
  ticker_choice = [
123
  "AAPL",
124
  "CSCO",
@@ -132,23 +107,87 @@ ticker_choice = [
132
  "AMD",
133
  ]
134
 
135
- with col1:
136
- # Hardcoding the defaults for a question without metadata
137
- if (
138
- query_text
139
- == "What was discussed regarding Wearables revenue performance?"
140
- ):
141
- ticker = st.selectbox("Company", ticker_choice)
 
142
  else:
143
- ticker = st.selectbox("Company", ticker_choice, ticker_index)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
 
145
  with st.sidebar:
146
  st.subheader("Select Options:")
147
 
148
- with st.sidebar:
149
- num_results = int(
150
- st.number_input("Number of Results to query", 1, 15, value=5)
151
- )
 
 
 
 
152
 
153
 
154
  # Choose encoder model
@@ -160,8 +199,11 @@ with st.sidebar:
160
 
161
  # Choose decoder model
162
 
163
- decoder_models_choice = ["GPT3 - (text-davinci-003)", "T5", "FLAN-T5", "GPT-J"]
164
-
 
 
 
165
  with st.sidebar:
166
  decoder_model = st.selectbox("Select Decoder Model", decoder_models_choice)
167
 
@@ -198,66 +240,140 @@ elif encoder_model == "Hybrid MPNET - SPLADE":
198
  ) = get_splade_sparse_embedding_model()
199
 
200
  with st.sidebar:
201
- window = int(st.number_input("Sentence Window Size", 0, 10, value=1))
202
-
203
- with st.sidebar:
204
- threshold = float(
205
- st.number_input(
206
- label="Similarity Score Threshold",
207
- step=0.05,
208
- format="%.2f",
209
- value=0.25,
 
 
 
 
 
 
 
 
 
 
 
 
210
  )
211
- )
212
 
213
  data = get_data()
214
 
215
- if encoder_model == "Hybrid SGPT - SPLADE":
216
- dense_query_embedding = create_dense_embeddings(
217
- query_text, retriever_model
218
- )
219
- sparse_query_embedding = create_sparse_embeddings(
220
- query_text, sparse_retriever_model, sparse_retriever_tokenizer
221
- )
222
- dense_query_embedding, sparse_query_embedding = hybrid_score_norm(
223
- dense_query_embedding, sparse_query_embedding, 0
224
- )
225
- query_results = query_pinecone_sparse(
226
- dense_query_embedding,
227
- sparse_query_embedding,
228
- num_results,
229
- pinecone_index,
230
- year,
231
- quarter,
232
- ticker,
233
- participant_type,
234
- threshold,
235
- )
 
236
 
237
- else:
238
- dense_query_embedding = create_dense_embeddings(
239
- query_text, retriever_model
240
- )
241
- query_results = query_pinecone(
242
- dense_query_embedding,
243
- num_results,
244
- pinecone_index,
245
- year,
246
- quarter,
247
- ticker,
248
- participant_type,
249
- threshold,
250
- )
251
 
 
 
 
 
252
 
253
- if threshold <= 0.90:
254
- context_list = sentence_id_combine(data, query_results, lag=window)
255
  else:
256
- context_list = format_query(query_results)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
257
 
 
 
 
 
 
 
 
 
 
 
258
 
259
- if decoder_model == "GPT3 - (text-davinci-003)":
260
- prompt = generate_gpt_prompt(query_text, context_list)
261
  with col2:
262
  with st.form("my_form"):
263
  edited_prompt = st.text_area(
@@ -273,9 +389,20 @@ if decoder_model == "GPT3 - (text-davinci-003)":
273
  if submitted:
274
  api_key = save_key(openai_key)
275
  openai.api_key = api_key
276
- generated_text = gpt_model(edited_prompt)
277
  st.subheader("Answer:")
278
- st.write(generated_text)
 
 
 
 
 
 
 
 
 
 
 
279
 
280
 
281
  elif decoder_model == "T5":
@@ -384,22 +511,53 @@ if decoder_model == "GPT-J":
384
  )
385
  submitted = st.form_submit_button("Submit")
386
 
 
387
 
388
- with col1:
389
- with st.expander("See Retrieved Text"):
390
- st.subheader("Retrieved Text:")
391
- for context_text in context_list:
392
- context_text = f"""{context_text}"""
393
- st.write(
394
- f"<ul><li><p>{context_text}</p></li></ul>",
395
- unsafe_allow_html=True,
396
- )
397
-
398
- file_text = retrieve_transcript(data, year, quarter, ticker)
399
 
400
- with col1:
401
- with st.expander("See Transcript"):
402
- st.subheader("Earnings Call Transcript:")
403
- stx.scrollableTextbox(
404
- file_text, height=700, border=False, fontFamily="Helvetica"
405
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
  import openai
4
  import streamlit_scrollable_textbox as stx
5
 
 
27
  get_spacy_model,
28
  get_splade_sparse_embedding_model,
29
  get_t5_model,
30
+ gpt_turbo_model,
31
  save_key,
32
  )
33
  from utils.prompts import (
 
38
  generate_flant5_prompt_summ_chunk_context_single,
39
  generate_gpt_j_two_shot_prompt_1,
40
  generate_gpt_j_two_shot_prompt_2,
41
+ generate_gpt_prompt_alpaca,
42
+ generate_gpt_prompt_alpaca_multi_doc,
43
+ generate_gpt_prompt_original,
44
+ generate_multi_doc_context,
45
  get_context_list_prompt,
46
  )
47
  from utils.retriever import (
 
50
  query_pinecone_sparse,
51
  sentence_id_combine,
52
  text_lookup,
53
+ year_quarter_range,
54
  )
55
  from utils.transcript_retrieval import retrieve_transcript
56
  from utils.vector_index import (
 
71
 
72
  with st.sidebar:
73
  ner_choice = st.selectbox("Select NER Model", ["Spacy", "Alpaca"])
74
+ document_type = st.selectbox(
75
+ "Select Query Type", ["Single-Document", "Multi-Document"]
76
+ )
77
 
78
  if ner_choice == "Spacy":
79
  ner_model = get_spacy_model()
80
 
81
  with col1:
82
  st.subheader("Question")
83
+ if document_type == "Single-Document":
84
+ query_text = st.text_area(
85
+ "Input Query",
86
+ value="What was discussed regarding Wearables revenue performance?",
87
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  else:
89
+ query_text = st.text_area(
90
+ "Input Query",
91
+ value="How has Apple's revenue from Wearables performed over the past 2 years?",
92
  )
93
 
 
 
94
 
95
+ years_choice = ["2020", "2019", "2018", "2017", "2016", "All"]
96
+ quarters_choice = ["Q1", "Q2", "Q3", "Q4", "All"]
97
  ticker_choice = [
98
  "AAPL",
99
  "CSCO",
 
107
  "AMD",
108
  ]
109
 
110
+
111
+ if document_type == "Single-Document":
112
+ if ner_choice == "Alpaca":
113
+ ner_prompt = generate_alpaca_ner_prompt(query_text)
114
+ entity_text = generate_entities_flan_alpaca_inference_api(ner_prompt)
115
+ company_ent, quarter_ent, year_ent = format_entities_flan_alpaca(
116
+ entity_text
117
+ )
118
  else:
119
+ company_ent = extract_ticker_spacy(query_text, ner_model)
120
+ quarter_ent, year_ent = extract_quarter_year(query_text)
121
+
122
+ ticker_index, quarter_index, year_index = clean_entities(
123
+ company_ent, quarter_ent, year_ent
124
+ )
125
+
126
+ with col1:
127
+ # Hardcoding the defaults for a question without metadata
128
+ if (
129
+ query_text
130
+ == "What was discussed regarding Wearables revenue performance?"
131
+ ):
132
+ year = st.selectbox("Year", years_choice)
133
+ quarter = st.selectbox("Quarter", quarters_choice)
134
+ ticker = st.selectbox("Company", ticker_choice)
135
+ else:
136
+ year = st.selectbox("Year", years_choice, index=year_index)
137
+ quarter = st.selectbox(
138
+ "Quarter", quarters_choice, index=quarter_index
139
+ )
140
+ ticker = st.selectbox("Company", ticker_choice, ticker_index)
141
+
142
+ participant_type = st.selectbox(
143
+ "Speaker", ["Company Speaker", "Analyst"]
144
+ )
145
+
146
+ else:
147
+ # Multi-Document Case
148
+
149
+ with col1:
150
+ # Hardcoding the defaults for a question without metadata
151
+ if (
152
+ query_text
153
+ == "How has Apple's revenue from Wearables performed over the past 2 years?"
154
+ ):
155
+ start_year = st.selectbox("Start Year", years_choice, index=2)
156
+ start_quarter = st.selectbox(
157
+ "Start Quarter", quarters_choice, index=0
158
+ )
159
+
160
+ end_year = st.selectbox("End Year", years_choice, index=0)
161
+ end_quarter = st.selectbox("End Quarter", quarters_choice, index=0)
162
+
163
+ ticker = st.selectbox("Company", ticker_choice, index=0)
164
+ else:
165
+ start_year = st.selectbox("Start Year", years_choice, index=2)
166
+ start_quarter = st.selectbox(
167
+ "Start Quarter", quarters_choice, index=0
168
+ )
169
+
170
+ end_year = st.selectbox("End Year", years_choice, index=0)
171
+ end_quarter = st.selectbox("End Quarter", quarters_choice, index=0)
172
+
173
+ ticker = st.selectbox("Company", ticker_choice, index=0)
174
+
175
+ participant_type = st.selectbox(
176
+ "Speaker", ["Company Speaker", "Analyst"]
177
+ )
178
+
179
 
180
  with st.sidebar:
181
  st.subheader("Select Options:")
182
 
183
+ if document_type == "Single-Document":
184
+ num_results = int(
185
+ st.number_input("Number of Results to query", 1, 15, value=5)
186
+ )
187
+ else:
188
+ num_results = int(
189
+ st.number_input("Number of Results to query", 1, 15, value=2)
190
+ )
191
 
192
 
193
  # Choose encoder model
 
199
 
200
  # Choose decoder model
201
 
202
+ # Restricting multi-document to only GPT-3
203
+ if document_type == "Single-Document":
204
+ decoder_models_choice = ["GPT-3.5 Turbo", "T5", "FLAN-T5", "GPT-J"]
205
+ else:
206
+ decoder_models_choice = ["GPT-3.5 Turbo"]
207
  with st.sidebar:
208
  decoder_model = st.selectbox("Select Decoder Model", decoder_models_choice)
209
 
 
240
  ) = get_splade_sparse_embedding_model()
241
 
242
  with st.sidebar:
243
+ if document_type == "Single-Document":
244
+ window = int(st.number_input("Sentence Window Size", 0, 10, value=1))
245
+
246
+ threshold = float(
247
+ st.number_input(
248
+ label="Similarity Score Threshold",
249
+ step=0.05,
250
+ format="%.2f",
251
+ value=0.25,
252
+ )
253
+ )
254
+ else:
255
+ window = int(st.number_input("Sentence Window Size", 0, 10, value=0))
256
+
257
+ threshold = float(
258
+ st.number_input(
259
+ label="Similarity Score Threshold",
260
+ step=0.05,
261
+ format="%.2f",
262
+ value=0.6,
263
+ )
264
  )
 
265
 
266
  data = get_data()
267
 
268
+ if document_type == "Single-Document":
269
+ if encoder_model == "Hybrid SGPT - SPLADE":
270
+ dense_query_embedding = create_dense_embeddings(
271
+ query_text, retriever_model
272
+ )
273
+ sparse_query_embedding = create_sparse_embeddings(
274
+ query_text, sparse_retriever_model, sparse_retriever_tokenizer
275
+ )
276
+ dense_query_embedding, sparse_query_embedding = hybrid_score_norm(
277
+ dense_query_embedding, sparse_query_embedding, 0
278
+ )
279
+ query_results = query_pinecone_sparse(
280
+ dense_query_embedding,
281
+ sparse_query_embedding,
282
+ num_results,
283
+ pinecone_index,
284
+ year,
285
+ quarter,
286
+ ticker,
287
+ participant_type,
288
+ threshold,
289
+ )
290
 
291
+ else:
292
+ dense_query_embedding = create_dense_embeddings(
293
+ query_text, retriever_model
294
+ )
295
+ query_results = query_pinecone(
296
+ dense_query_embedding,
297
+ num_results,
298
+ pinecone_index,
299
+ year,
300
+ quarter,
301
+ ticker,
302
+ participant_type,
303
+ threshold,
304
+ )
305
 
306
+ if threshold <= 0.90:
307
+ context_list = sentence_id_combine(data, query_results, lag=window)
308
+ else:
309
+ context_list = format_query(query_results)
310
 
 
 
311
  else:
312
+ # Multi-Document Retreival
313
+ if encoder_model == "Hybrid SGPT - SPLADE":
314
+ dense_query_embedding = create_dense_embeddings(
315
+ query_text, retriever_model
316
+ )
317
+ sparse_query_embedding = create_sparse_embeddings(
318
+ query_text, sparse_retriever_model, sparse_retriever_tokenizer
319
+ )
320
+ dense_query_embedding, sparse_query_embedding = hybrid_score_norm(
321
+ dense_query_embedding, sparse_query_embedding, 0
322
+ )
323
+ year_quarter_list = year_quarter_range(
324
+ start_quarter, start_year, end_quarter, end_year
325
+ )
326
+
327
+ context_group = []
328
+ for year, quarter in year_quarter_list:
329
+ query_results = query_pinecone_sparse(
330
+ dense_query_embedding,
331
+ sparse_query_embedding,
332
+ num_results,
333
+ pinecone_index,
334
+ year,
335
+ quarter,
336
+ ticker,
337
+ participant_type,
338
+ threshold,
339
+ )
340
+ results_list = sentence_id_combine(data, query_results, lag=window)
341
+ context_group.append((results_list, year, quarter))
342
+
343
+ else:
344
+ dense_query_embedding = create_dense_embeddings(
345
+ query_text, retriever_model
346
+ )
347
+ year_quarter_list = year_quarter_range(
348
+ start_quarter, start_year, end_quarter, end_year
349
+ )
350
+
351
+ context_group = []
352
+ for year, quarter in year_quarter_list:
353
+ query_results = query_pinecone(
354
+ dense_query_embedding,
355
+ num_results,
356
+ pinecone_index,
357
+ year,
358
+ quarter,
359
+ ticker,
360
+ participant_type,
361
+ threshold,
362
+ )
363
+ results_list = sentence_id_combine(data, query_results, lag=window)
364
+ context_group.append((results_list, year, quarter))
365
 
366
+ multi_doc_context = generate_multi_doc_context(context_group)
367
+
368
+
369
+ if decoder_model == "GPT-3.5 Turbo":
370
+ if document_type == "Single-Document":
371
+ prompt = generate_gpt_prompt_alpaca(query_text, context_list)
372
+ else:
373
+ prompt = generate_gpt_prompt_alpaca_multi_doc(
374
+ query_text, context_group
375
+ )
376
 
 
 
377
  with col2:
378
  with st.form("my_form"):
379
  edited_prompt = st.text_area(
 
389
  if submitted:
390
  api_key = save_key(openai_key)
391
  openai.api_key = api_key
392
+ generated_text = gpt_turbo_model(edited_prompt)
393
  st.subheader("Answer:")
394
+ regex_pattern_sentences = (
395
+ "(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s"
396
+ )
397
+ generated_text_list = re.split(
398
+ regex_pattern_sentences, generated_text
399
+ )
400
+ for answer_text in generated_text_list:
401
+ answer_text = f"""{answer_text}"""
402
+ st.write(
403
+ f"<ul><li><p>{answer_text}</p></li></ul>",
404
+ unsafe_allow_html=True,
405
+ )
406
 
407
 
408
  elif decoder_model == "T5":
 
511
  )
512
  submitted = st.form_submit_button("Submit")
513
 
514
+ tab1, tab2 = st.tabs(["Retrived Text", "Retrieved Documents"])
515
 
 
 
 
 
 
 
 
 
 
 
 
516
 
517
+ with tab1:
518
+ if document_type == "Single-Document":
519
+ with st.expander("See Retrieved Text"):
520
+ st.subheader("Retrieved Text:")
521
+ for context_text in context_list:
522
+ context_text = f"""{context_text}"""
523
+ st.write(
524
+ f"<ul><li><p>{context_text}</p></li></ul>",
525
+ unsafe_allow_html=True,
526
+ )
527
+ else:
528
+ with st.expander("See Retrieved Text"):
529
+ st.subheader("Retrieved Text:")
530
+ sections = [
531
+ s.strip()
532
+ for s in multi_doc_context.split("Document: ")
533
+ if s.strip()
534
+ ]
535
+
536
+ # Add "Document: " back to the beginning of each section
537
+ context_list = [
538
+ "Document: " + s[0:7] + "\n" + s[7:] for s in sections
539
+ ]
540
+ for context_text in context_list:
541
+ context_text = f"""{context_text}"""
542
+ st.write(
543
+ f"<ul><li><p>{context_text}</p></li></ul>",
544
+ unsafe_allow_html=True,
545
+ )
546
+
547
+
548
+ with tab2:
549
+ if document_type == "Single-Document":
550
+ file_text = retrieve_transcript(data, year, quarter, ticker)
551
+ with st.expander("See Transcript"):
552
+ st.subheader("Earnings Call Transcript:")
553
+ stx.scrollableTextbox(
554
+ file_text, height=700, border=False, fontFamily="Helvetica"
555
+ )
556
+ else:
557
+ for year, quarter in year_quarter_list:
558
+ file_text = retrieve_transcript(data, year, quarter, ticker)
559
+ with st.expander(f"See Transcript - {quarter} {year}"):
560
+ st.subheader("Earnings Call Transcript - {quarter} {year}:")
561
+ stx.scrollableTextbox(
562
+ file_text, height=700, border=False, fontFamily="Helvetica"
563
+ )
utils/models.py CHANGED
@@ -103,14 +103,16 @@ def save_key(api_key):
103
  # Text Generation
104
 
105
 
106
- def gpt_model(prompt):
107
- response = openai.Completion.create(
108
- model="text-davinci-003",
109
- prompt=prompt,
110
- temperature=0,
 
 
111
  max_tokens=1024,
112
  )
113
- return response.choices[0].text
114
 
115
 
116
  def generate_text_flan_t5(model, tokenizer, input_text):
 
103
  # Text Generation
104
 
105
 
106
+ def gpt_turbo_model(prompt):
107
+ response = openai.ChatCompletion.create(
108
+ model="gpt-3.5-turbo",
109
+ messages=[
110
+ {"role": "user", "content": prompt},
111
+ ],
112
+ temperature=0.01,
113
  max_tokens=1024,
114
  )
115
+ return response["choices"][0]["message"]["content"]
116
 
117
 
118
  def generate_text_flan_t5(model, tokenizer, input_text):
utils/prompts.py CHANGED
@@ -1,4 +1,51 @@
1
- def generate_gpt_prompt(query_text, context_list):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  context = " ".join(context_list)
3
  prompt = f"""Answer the question in 6 long detailed points as accurately as possible using the provided context. Include as many key details as possible.
4
  Context: {context}
 
1
+ def generate_multi_doc_context(context_group):
2
+ multi_doc_context = ""
3
+ for context_text_list, year, quarter in context_group:
4
+ print((context_text_list, year, quarter))
5
+ if context_text_list == []:
6
+ break
7
+ else:
8
+ multi_doc_context = (
9
+ multi_doc_context
10
+ + "\n"
11
+ + f"Document: {quarter} {year}"
12
+ + "\n"
13
+ + " ".join(context_text_list)
14
+ )
15
+ return multi_doc_context
16
+
17
+
18
+ def generate_gpt_prompt_alpaca(query_text, context_list):
19
+ context = " ".join(context_list)
20
+ prompt = f"""Below is an instruction that describes a task, paired with an input that provides further context. Use the following guidelines to write a response that that appropriately completes the request:
21
+ ### Instruction:
22
+ - Write a detailed paragraph consisting of exactly five complete sentences that answer the question based on the provided context.
23
+ - Focus on addressing the specific question posed, providing as much relevant information and detail as possible.
24
+ - Only use details from the provided context that directly address the question; do not include any additional information that is not explicitly stated.
25
+ - Aim to provide a clear and concise summary that fully addresses the question.
26
+
27
+ Question: {query_text}
28
+ Context: {context}
29
+ ### Response:"""
30
+ return prompt
31
+
32
+
33
+ def generate_gpt_prompt_alpaca_multi_doc(query_text, context_group):
34
+ multi_doc_context = generate_multi_doc_context(context_group)
35
+ prompt = f"""Below is an instruction that describes a task, paired with an input that provides further context. Use the following guidelines to write a response that that appropriately completes the request:
36
+ ### Instruction:
37
+ - Write a detailed paragraph consisting of exactly five complete sentences that answer the question based on the provided context.
38
+ - Focus on addressing the specific question posed, providing as much relevant information and detail as possible.
39
+ - Only use details from the provided context that directly address the question; do not include any additional information that is not explicitly stated.
40
+ - Aim to provide a clear and concise summary that fully addresses the question.
41
+
42
+ Question: {query_text}
43
+ Context: {multi_doc_context}
44
+ ### Response:"""
45
+ return prompt
46
+
47
+
48
+ def generate_gpt_prompt_original(query_text, context_list):
49
  context = " ".join(context_list)
50
  prompt = f"""Answer the question in 6 long detailed points as accurately as possible using the provided context. Include as many key details as possible.
51
  Context: {context}
utils/retriever.py CHANGED
@@ -195,3 +195,55 @@ def sentence_id_combine(data, query_results, lag=1):
195
  def text_lookup(data, sentence_ids):
196
  context = ". ".join(data.iloc[sentence_ids].to_list())
197
  return context
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
195
  def text_lookup(data, sentence_ids):
196
  context = ". ".join(data.iloc[sentence_ids].to_list())
197
  return context
198
+
199
+
200
+ def year_quarter_range(start_quarter, start_year, end_quarter, end_year):
201
+ """Creates a list of all (year, quarter) pairs that lie in the range including the start and end quarters."""
202
+ start_year = int(start_year)
203
+ end_year = int(end_year)
204
+
205
+ quarters = (
206
+ [("Q1", "Q2", "Q3", "Q4")] * (end_year - start_year)
207
+ + [("Q1", "Q2", "Q3" if end_quarter == "Q4" else "Q4")]
208
+ * (end_quarter == "Q4")
209
+ + [
210
+ (
211
+ "Q1"
212
+ if start_quarter == "Q1"
213
+ else "Q2"
214
+ if start_quarter == "Q2"
215
+ else "Q3"
216
+ if start_quarter == "Q3"
217
+ else "Q4",
218
+ )
219
+ * (end_year - start_year)
220
+ ]
221
+ )
222
+ years = list(range(start_year, end_year + 1))
223
+ list_year_quarter = [
224
+ (y, q) for y in years for q in quarters[years.index(y)]
225
+ ]
226
+ # Remove duplicate pairs
227
+ seen = set()
228
+ list_year_quarter_cleaned = []
229
+ for tup in list_year_quarter:
230
+ if tup not in seen:
231
+ seen.add(tup)
232
+ list_year_quarter_cleaned.append(tup)
233
+ return list_year_quarter_cleaned
234
+
235
+
236
+ def multi_document_query(
237
+ dense_query_embedding,
238
+ sparse_query_embedding,
239
+ num_results,
240
+ pinecone_index,
241
+ start_quarter,
242
+ start_year,
243
+ end_quarter,
244
+ end_year,
245
+ ticker,
246
+ participant_type,
247
+ threshold,
248
+ ):
249
+ pass