awinml commited on
Commit
0175cb6
1 Parent(s): 76c87df

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +44 -13
  2. utils.py +44 -3
app.py CHANGED
@@ -12,15 +12,18 @@ from utils import (
12
  create_sparse_embeddings,
13
  extract_entities,
14
  format_query,
 
 
 
15
  generate_flant5_prompt_instruct_chunk_context,
16
- generate_flant5_prompt_instruct_complete_context,
17
  generate_flant5_prompt_instruct_chunk_context_single,
18
- generate_flant5_prompt_summ_chunk_context_single,
19
  generate_flant5_prompt_summ_chunk_context,
20
- generate_text_flan_t5,
21
- generate_gpt_prompt,
22
  generate_gpt_j_two_shot_prompt_1,
23
  generate_gpt_j_two_shot_prompt_2,
 
 
24
  get_context_list_prompt,
25
  get_data,
26
  get_flan_t5_model,
@@ -49,7 +52,13 @@ st.write(
49
  col1, col2 = st.columns([3, 3], gap="medium")
50
 
51
 
52
- spacy_model = get_spacy_model()
 
 
 
 
 
 
53
 
54
  with col1:
55
  st.subheader("Question")
@@ -58,7 +67,12 @@ with col1:
58
  value="What was discussed regarding Wearables revenue performance?",
59
  )
60
 
61
- company_ent, quarter_ent, year_ent = extract_entities(query_text, spacy_model)
 
 
 
 
 
62
  ticker_index, quarter_index, year_index = clean_entities(
63
  company_ent, quarter_ent, year_ent
64
  )
@@ -251,7 +265,9 @@ if decoder_model == "GPT3 - (text-davinci-003)":
251
 
252
 
253
  elif decoder_model == "T5":
254
- prompt = generate_flant5_prompt_instruct_complete_context(query_text, context_list)
 
 
255
  t5_pipeline = get_t5_model()
256
  output_text = []
257
  with col2:
@@ -275,7 +291,8 @@ elif decoder_model == "FLAN-T5":
275
  output_text = []
276
  with col2:
277
  prompt_type = st.selectbox(
278
- "Select prompt type", ["Complete Text QA", "Chunkwise QA", "Chunkwise Summarize"]
 
279
  )
280
  if prompt_type == "Complete Text QA":
281
  prompt = generate_flant5_prompt_instruct_complete_context(
@@ -300,23 +317,37 @@ elif decoder_model == "FLAN-T5":
300
  submitted = st.form_submit_button("Submit")
301
  if submitted:
302
  if prompt_type == "Complete Text QA":
303
- output_text_string = generate_text_flan_t5(flan_t5_model, flan_t5_tokenizer, prompt)
 
 
304
  st.subheader("Answer:")
305
  st.write(output_text_string)
306
  elif prompt_type == "Chunkwise QA":
307
  for context_text in context_list:
308
- model_input = generate_flant5_prompt_instruct_chunk_context_single(query_text, context_text)
 
 
309
  output_text.append(
310
- generate_text_flan_t5(flan_t5_model, flan_t5_tokenizer, model_input))
 
 
 
311
  st.subheader("Answer:")
312
  for text in output_text:
313
  if "(iii)" not in text:
314
  st.markdown(f"- {text}")
315
  elif prompt_type == "Chunkwise Summarize":
316
  for context_text in context_list:
317
- model_input = generate_flant5_prompt_summ_chunk_context_single(query_text, context_text)
 
 
 
 
318
  output_text.append(
319
- generate_text_flan_t5(flan_t5_model, flan_t5_tokenizer, model_input))
 
 
 
320
  st.subheader("Answer:")
321
  for text in output_text:
322
  if "(iii)" not in text:
 
12
  create_sparse_embeddings,
13
  extract_entities,
14
  format_query,
15
+ get_flan_alpaca_xl_model,
16
+ generate_entities_flan_alpaca,
17
+ format_entities_flan_alpaca,
18
  generate_flant5_prompt_instruct_chunk_context,
 
19
  generate_flant5_prompt_instruct_chunk_context_single,
20
+ generate_flant5_prompt_instruct_complete_context,
21
  generate_flant5_prompt_summ_chunk_context,
22
+ generate_flant5_prompt_summ_chunk_context_single,
 
23
  generate_gpt_j_two_shot_prompt_1,
24
  generate_gpt_j_two_shot_prompt_2,
25
+ generate_gpt_prompt,
26
+ generate_text_flan_t5,
27
  get_context_list_prompt,
28
  get_data,
29
  get_flan_t5_model,
 
52
  col1, col2 = st.columns([3, 3], gap="medium")
53
 
54
 
55
+ with st.sidebar:
56
+ ner_choice = st.selectbox("Select NER Model", ["Alpaca", "Spacy"])
57
+
58
+ if ner_choice == "Alpaca":
59
+ ner_model = get_flan_alpaca_xl_model()
60
+ else:
61
+ ner_model = get_spacy_model()
62
 
63
  with col1:
64
  st.subheader("Question")
 
67
  value="What was discussed regarding Wearables revenue performance?",
68
  )
69
 
70
+ if ner_choice == "Alpaca":
71
+ entity_text = generate_entities_flan_alpaca(ner_model)
72
+ company_ent, quarter_ent, year_ent = format_entities_flan_alpaca(entity_text)
73
+ else:
74
+ company_ent, quarter_ent, year_ent = extract_entities(query_text, ner_model)
75
+
76
  ticker_index, quarter_index, year_index = clean_entities(
77
  company_ent, quarter_ent, year_ent
78
  )
 
265
 
266
 
267
  elif decoder_model == "T5":
268
+ prompt = generate_flant5_prompt_instruct_complete_context(
269
+ query_text, context_list
270
+ )
271
  t5_pipeline = get_t5_model()
272
  output_text = []
273
  with col2:
 
291
  output_text = []
292
  with col2:
293
  prompt_type = st.selectbox(
294
+ "Select prompt type",
295
+ ["Complete Text QA", "Chunkwise QA", "Chunkwise Summarize"],
296
  )
297
  if prompt_type == "Complete Text QA":
298
  prompt = generate_flant5_prompt_instruct_complete_context(
 
317
  submitted = st.form_submit_button("Submit")
318
  if submitted:
319
  if prompt_type == "Complete Text QA":
320
+ output_text_string = generate_text_flan_t5(
321
+ flan_t5_model, flan_t5_tokenizer, prompt
322
+ )
323
  st.subheader("Answer:")
324
  st.write(output_text_string)
325
  elif prompt_type == "Chunkwise QA":
326
  for context_text in context_list:
327
+ model_input = generate_flant5_prompt_instruct_chunk_context_single(
328
+ query_text, context_text
329
+ )
330
  output_text.append(
331
+ generate_text_flan_t5(
332
+ flan_t5_model, flan_t5_tokenizer, model_input
333
+ )
334
+ )
335
  st.subheader("Answer:")
336
  for text in output_text:
337
  if "(iii)" not in text:
338
  st.markdown(f"- {text}")
339
  elif prompt_type == "Chunkwise Summarize":
340
  for context_text in context_list:
341
+ model_input = (
342
+ generate_flant5_prompt_summ_chunk_context_single(
343
+ query_text, context_text
344
+ )
345
+ )
346
  output_text.append(
347
+ generate_text_flan_t5(
348
+ flan_t5_model, flan_t5_tokenizer, model_input
349
+ )
350
+ )
351
  st.subheader("Answer:")
352
  for text in output_text:
353
  if "(iii)" not in text:
utils.py CHANGED
@@ -2,7 +2,6 @@ import re
2
 
3
  import openai
4
  import pandas as pd
5
- import pinecone
6
  import spacy
7
  import streamlit_scrollable_textbox as stx
8
  import torch
@@ -12,11 +11,12 @@ from transformers import (
12
  AutoModelForMaskedLM,
13
  AutoModelForSeq2SeqLM,
14
  AutoTokenizer,
 
 
15
  pipeline,
16
  )
17
- from transformers import T5Tokenizer, T5ForConditionalGeneration
18
-
19
 
 
20
  import streamlit as st
21
 
22
 
@@ -34,6 +34,11 @@ def get_spacy_model():
34
  return spacy.load("en_core_web_sm")
35
 
36
 
 
 
 
 
 
37
  # Initialize models from HuggingFace
38
 
39
 
@@ -469,6 +474,42 @@ Answer:?"""
469
 
470
  # Entity Extraction
471
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
472
 
473
  def extract_quarter_year(string):
474
  # Extract year from string
 
2
 
3
  import openai
4
  import pandas as pd
 
5
  import spacy
6
  import streamlit_scrollable_textbox as stx
7
  import torch
 
11
  AutoModelForMaskedLM,
12
  AutoModelForSeq2SeqLM,
13
  AutoTokenizer,
14
+ T5ForConditionalGeneration,
15
+ T5Tokenizer,
16
  pipeline,
17
  )
 
 
18
 
19
+ import pinecone
20
  import streamlit as st
21
 
22
 
 
34
  return spacy.load("en_core_web_sm")
35
 
36
 
37
+ @st.experimental_singleton
38
+ def get_flan_alpaca_xl_model():
39
+ return pipeline(model="declare-lab/flan-alpaca-xl")
40
+
41
+
42
  # Initialize models from HuggingFace
43
 
44
 
 
474
 
475
  # Entity Extraction
476
 
477
+ def generate_entities_flan_alpaca(model):
478
+ output = model(prompt, max_length=512, temperature=0.1)
479
+ generated_text = output[0]["generated_text"]
480
+ return generated_text
481
+
482
+
483
+ def format_entities_flan_alpaca(model_output):
484
+ """
485
+ Extracts the text for each entity from the output generated by the
486
+ Flan-Alpaca model.
487
+ """
488
+ try:
489
+ company_string, quarter_string, year_string = values.split(", ")
490
+ except:
491
+ company = None
492
+ quarter = None
493
+ year = None
494
+ try:
495
+ company = company_string.split(" - ")[1].lower()
496
+ company = None if company.lower() == 'none' else company
497
+ except:
498
+ company = None
499
+ try:
500
+ quarter = quarter_string.split(" - ")[1]
501
+ quarter = None if quarter.lower() == 'none' else quarter
502
+
503
+ except:
504
+ quarter = None
505
+ try:
506
+ year = year_string.split(" - ")[1]
507
+ year = None if year.lower() == 'none' else year
508
+
509
+ except:
510
+ year = None
511
+ return company, quarter, year
512
+
513
 
514
  def extract_quarter_year(string):
515
  # Extract year from string