awinml commited on
Commit
9975133
1 Parent(s): 72a93a0

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +73 -37
  2. utils.py +15 -7
app.py CHANGED
@@ -1,5 +1,8 @@
1
  import pinecone
2
  import streamlit as st
 
 
 
3
  import streamlit_scrollable_textbox as stx
4
  import openai
5
  from utils import (
@@ -17,23 +20,32 @@ from utils import (
17
  format_query,
18
  sentence_id_combine,
19
  text_lookup,
20
- gpt3,
 
21
  )
22
 
23
 
24
  st.title("Abstractive Question Answering")
25
 
 
26
  st.write(
27
  "The app uses the quarterly earnings call transcripts for 10 companies (Apple, AMD, Amazon, Cisco, Google, Microsoft, Nvidia, ASML, Intel, Micron) for the years 2016 to 2020."
28
  )
29
 
30
- query_text = st.text_input("Input Query", value="Who is the CEO of Apple?")
 
 
 
 
31
 
32
- years_choice = ["2020", "2019", "2018", "2017", "2016"]
 
33
 
34
- year = st.selectbox("Year", years_choice)
 
35
 
36
- quarter = st.selectbox("Quarter", ["Q1", "Q2", "Q3", "Q4"])
 
37
 
38
  ticker_choice = [
39
  "AAPL",
@@ -48,23 +60,33 @@ ticker_choice = [
48
  "AMD",
49
  ]
50
 
51
- ticker = st.selectbox("Company", ticker_choice)
 
52
 
53
- num_results = int(st.number_input("Number of Results to query", 1, 5, value=5))
 
 
 
 
54
 
55
 
56
  # Choose encoder model
57
 
58
  encoder_models_choice = ["SGPT", "MPNET"]
59
-
60
- encoder_model = st.selectbox("Select Encoder Model", encoder_models_choice)
61
 
62
 
63
  # Choose decoder model
64
 
65
- decoder_models_choice = ["FLAN-T5", "T5", "GPT3 - (text-davinci-003)"]
 
 
 
 
66
 
67
- decoder_model = st.selectbox("Select Decoder Model", decoder_models_choice)
 
68
 
69
 
70
  if encoder_model == "MPNET":
@@ -82,13 +104,15 @@ elif encoder_model == "SGPT":
82
  retriever_model = get_sgpt_embedding_model()
83
 
84
 
85
- window = int(st.number_input("Sentence Window Size", 0, 5, value=3))
 
86
 
87
- threshold = float(
88
- st.number_input(
89
- label="Similarity Score Threshold", step=0.05, format="%.2f", value=0.35
 
 
90
  )
91
- )
92
 
93
  data = get_data()
94
 
@@ -109,22 +133,26 @@ else:
109
  context_list = format_query(query_results)
110
 
111
 
112
- st.subheader("Answer:")
113
-
114
 
115
  if decoder_model == "GPT3 - (text-davinci-003)":
116
- with st.form("my_form"):
117
- openai_key = st.text_input(
118
- "Enter OpenAI key",
119
- value="",
120
- type="password",
121
- )
122
- submitted = st.form_submit_button("Submit")
123
- if submitted:
124
- api_key = save_key(openai_key)
125
- openai.api_key = api_key
126
- generated_text = gpt3(query_text, context_list)
127
- st.write(generated_text)
 
 
 
 
 
128
 
129
  elif decoder_model == "T5":
130
  t5_pipeline = get_t5_model()
@@ -132,7 +160,9 @@ elif decoder_model == "T5":
132
  for context_text in context_list:
133
  output_text.append(t5_pipeline(context_text)[0]["summary_text"])
134
  generated_text = ". ".join(output_text)
135
- st.write(t5_pipeline(generated_text)[0]["summary_text"])
 
 
136
 
137
  elif decoder_model == "FLAN-T5":
138
  flan_t5_pipeline = get_flan_t5_model()
@@ -140,13 +170,19 @@ elif decoder_model == "FLAN-T5":
140
  for context_text in context_list:
141
  output_text.append(flan_t5_pipeline(context_text)[0]["summary_text"])
142
  generated_text = ". ".join(output_text)
143
- st.write(flan_t5_pipeline(generated_text)[0]["summary_text"])
 
 
144
 
145
- with st.expander("See Retrieved Text"):
146
- for context_text in context_list:
147
- st.markdown(f"- {context_text}")
 
148
 
149
  file_text = retrieve_transcript(data, year, quarter, ticker)
150
 
151
- with st.expander("See Transcript"):
152
- stx.scrollableTextbox(file_text, height=700, border=False, fontFamily="Helvetica")
 
 
 
 
1
  import pinecone
2
  import streamlit as st
3
+
4
+ st.set_page_config(layout="wide")
5
+
6
  import streamlit_scrollable_textbox as stx
7
  import openai
8
  from utils import (
 
20
  format_query,
21
  sentence_id_combine,
22
  text_lookup,
23
+ generate_prompt,
24
+ gpt_model,
25
  )
26
 
27
 
28
  st.title("Abstractive Question Answering")
29
 
30
+
31
  st.write(
32
  "The app uses the quarterly earnings call transcripts for 10 companies (Apple, AMD, Amazon, Cisco, Google, Microsoft, Nvidia, ASML, Intel, Micron) for the years 2016 to 2020."
33
  )
34
 
35
+ col1, col2 = st.columns([3, 3], gap="medium")
36
+
37
+ with col1:
38
+ st.subheader("Question")
39
+ query_text = st.text_input("Input Query", value="Who is the CEO of Apple?")
40
 
41
+ with col1:
42
+ years_choice = ["2020", "2019", "2018", "2017", "2016"]
43
 
44
+ with col1:
45
+ year = st.selectbox("Year", years_choice)
46
 
47
+ with col1:
48
+ quarter = st.selectbox("Quarter", ["Q1", "Q2", "Q3", "Q4"])
49
 
50
  ticker_choice = [
51
  "AAPL",
 
60
  "AMD",
61
  ]
62
 
63
+ with col1:
64
+ ticker = st.selectbox("Company", ticker_choice)
65
 
66
+ with st.sidebar:
67
+ st.subheader("Select Options:")
68
+
69
+ with st.sidebar:
70
+ num_results = int(st.number_input("Number of Results to query", 1, 5, value=5))
71
 
72
 
73
  # Choose encoder model
74
 
75
  encoder_models_choice = ["SGPT", "MPNET"]
76
+ with st.sidebar:
77
+ encoder_model = st.selectbox("Select Encoder Model", encoder_models_choice)
78
 
79
 
80
  # Choose decoder model
81
 
82
+ decoder_models_choice = [
83
+ "GPT3 - (text-davinci-003)",
84
+ "T5",
85
+ "FLAN-T5",
86
+ ]
87
 
88
+ with st.sidebar:
89
+ decoder_model = st.selectbox("Select Decoder Model", decoder_models_choice)
90
 
91
 
92
  if encoder_model == "MPNET":
 
104
  retriever_model = get_sgpt_embedding_model()
105
 
106
 
107
+ with st.sidebar:
108
+ window = int(st.number_input("Sentence Window Size", 0, 5, value=3))
109
 
110
+ with st.sidebar:
111
+ threshold = float(
112
+ st.number_input(
113
+ label="Similarity Score Threshold", step=0.05, format="%.2f", value=0.35
114
+ )
115
  )
 
116
 
117
  data = get_data()
118
 
 
133
  context_list = format_query(query_results)
134
 
135
 
136
+ prompt = generate_prompt(query_text, context_list)
 
137
 
138
  if decoder_model == "GPT3 - (text-davinci-003)":
139
+ with col2:
140
+ with st.form("my_form"):
141
+ edited_prompt = st.text_area(label="Model Prompt", value=prompt, height=270)
142
+
143
+ openai_key = st.text_input(
144
+ "Enter OpenAI key",
145
+ value="",
146
+ type="password",
147
+ )
148
+ submitted = st.form_submit_button("Submit")
149
+ if submitted:
150
+ api_key = save_key(openai_key)
151
+ openai.api_key = api_key
152
+ generated_text = gpt_model(edited_prompt)
153
+ with col2:
154
+ st.subheader("Answer:")
155
+ st.write(generated_text)
156
 
157
  elif decoder_model == "T5":
158
  t5_pipeline = get_t5_model()
 
160
  for context_text in context_list:
161
  output_text.append(t5_pipeline(context_text)[0]["summary_text"])
162
  generated_text = ". ".join(output_text)
163
+ with col2:
164
+ st.subheader("Answer:")
165
+ st.write(t5_pipeline(generated_text)[0]["summary_text"])
166
 
167
  elif decoder_model == "FLAN-T5":
168
  flan_t5_pipeline = get_flan_t5_model()
 
170
  for context_text in context_list:
171
  output_text.append(flan_t5_pipeline(context_text)[0]["summary_text"])
172
  generated_text = ". ".join(output_text)
173
+ with col2:
174
+ st.subheader("Answer:")
175
+ st.write(flan_t5_pipeline(generated_text)[0]["summary_text"])
176
 
177
+ with col1:
178
+ with st.expander("See Retrieved Text"):
179
+ for context_text in context_list:
180
+ st.markdown(f"- {context_text}")
181
 
182
  file_text = retrieve_transcript(data, year, quarter, ticker)
183
 
184
+ with col1:
185
+ with st.expander("See Transcript"):
186
+ stx.scrollableTextbox(
187
+ file_text, height=700, border=False, fontFamily="Helvetica"
188
+ )
utils.py CHANGED
@@ -113,15 +113,23 @@ def text_lookup(data, sentence_ids):
113
  return context
114
 
115
 
116
- def gpt3(query, result):
 
 
 
 
 
 
 
 
 
 
 
 
 
117
  response = openai.Completion.create(
118
  model="text-davinci-003",
119
- prompt=f"""Context information is below. \n"
120
- "---------------------\n"
121
- "{result}"
122
- "\n---------------------\n"
123
- "Given the context information and prior knowledge, answer this question: {query}. \n"
124
- "Try to include as many key details as possible and format the answer in points. \n" """,
125
  temperature=0.1,
126
  max_tokens=512,
127
  top_p=1.0,
 
113
  return context
114
 
115
 
116
+ def generate_prompt(query_text, context_list):
117
+ #context = " ".join(context_list)
118
+ prompt = f"""
119
+ Context information is below:
120
+ ---------------------
121
+ {context_list}
122
+ ---------------------
123
+ Given the context information and prior knowledge, answer this question:
124
+ {query_text}
125
+ Try to include as many key details as possible and format the answer in points."""
126
+ return prompt
127
+
128
+
129
+ def gpt_model(prompt):
130
  response = openai.Completion.create(
131
  model="text-davinci-003",
132
+ prompt=prompt,
 
 
 
 
 
133
  temperature=0.1,
134
  max_tokens=512,
135
  top_p=1.0,