|
import openai |
|
import streamlit_scrollable_textbox as stx |
|
|
|
import pinecone |
|
import streamlit as st |
|
|
|
# First Streamlit call in this script: switch the page to the wide layout so the
# two side-by-side columns created further down have more horizontal room.
# Streamlit documents set_page_config as needing to run before other st.*
# commands, which is presumably why it sits above the utils import.
st.set_page_config(layout="wide")
|
|
|
from utils import ( |
|
clean_entities, |
|
create_dense_embeddings, |
|
create_sparse_embeddings, |
|
extract_entities, |
|
format_query, |
|
get_flan_alpaca_xl_model, |
|
generate_entities_flan_alpaca, |
|
format_entities_flan_alpaca, |
|
generate_flant5_prompt_instruct_chunk_context, |
|
generate_flant5_prompt_instruct_chunk_context_single, |
|
generate_flant5_prompt_instruct_complete_context, |
|
generate_flant5_prompt_summ_chunk_context, |
|
generate_flant5_prompt_summ_chunk_context_single, |
|
generate_gpt_j_two_shot_prompt_1, |
|
generate_gpt_j_two_shot_prompt_2, |
|
generate_gpt_prompt, |
|
generate_text_flan_t5, |
|
get_context_list_prompt, |
|
get_data, |
|
get_flan_t5_model, |
|
get_mpnet_embedding_model, |
|
get_sgpt_embedding_model, |
|
get_spacy_model, |
|
get_splade_sparse_embedding_model, |
|
get_t5_model, |
|
gpt_model, |
|
hybrid_score_norm, |
|
query_pinecone, |
|
query_pinecone_sparse, |
|
retrieve_transcript, |
|
save_key, |
|
sentence_id_combine, |
|
text_lookup, |
|
) |
|
|
|
st.title("Abstractive Question Answering")

# Short description of the corpus backing the app.
_APP_DESCRIPTION = (
    "The app uses the quarterly earnings call transcripts for 10 companies "
    "(Apple, AMD, Amazon, Cisco, Google, Microsoft, Nvidia, ASML, Intel, Micron) "
    "for the years 2016 to 2020."
)
st.write(_APP_DESCRIPTION)

# Two equal-width columns: col1 holds the question, filters and retrieved text;
# col2 holds the model prompt and the generated answer.
col1, col2 = st.columns([3, 3], gap="medium")
|
|
|
|
|
# Let the user pick which named-entity recogniser pre-fills the filters.
with st.sidebar:
    ner_choice = st.selectbox("Select NER Model", ["Alpaca", "Spacy"])

_use_alpaca_ner = ner_choice == "Alpaca"
ner_model = get_flan_alpaca_xl_model() if _use_alpaca_ner else get_spacy_model()

with col1:
    st.subheader("Question")
    query_text = st.text_area(
        "Input Query",
        value="What was discussed regarding Wearables revenue performance?",
    )

# Pull company / quarter / year mentions out of the query with the chosen model.
if not _use_alpaca_ner:
    company_ent, quarter_ent, year_ent = extract_entities(query_text, ner_model)
else:
    # NOTE(review): unlike the spaCy branch, query_text is never passed to the
    # Alpaca entity generator here; confirm generate_entities_flan_alpaca's
    # signature in utils — this looks like it ignores the user's query.
    entity_text = generate_entities_flan_alpaca(ner_model)
    company_ent, quarter_ent, year_ent = format_entities_flan_alpaca(entity_text)

# Map the raw entity strings to selectbox indices for the filter widgets.
ticker_index, quarter_index, year_index = clean_entities(
    company_ent, quarter_ent, year_ent
)
|
|
|
# Options for the year and company filters.
years_choice = ["2020", "2019", "2018", "2017", "2016", "All"]
ticker_choice = [
    "AAPL",
    "CSCO",
    "MSFT",
    "ASML",
    "NVDA",
    "GOOGL",
    "MU",
    "INTC",
    "AMZN",
    "AMD",
]

# For the untouched sample query every filter starts on its first option;
# otherwise the filters are pre-selected from the indices clean_entities
# produced for the extracted entities.
_is_sample_query = (
    query_text == "What was discussed regarding Wearables revenue performance?"
)

with col1:
    year = st.selectbox(
        "Year",
        years_choice,
        index=0 if _is_sample_query else year_index,
    )
    quarter = st.selectbox(
        "Quarter",
        ["Q1", "Q2", "Q3", "Q4", "All"],
        index=0 if _is_sample_query else quarter_index,
    )
    participant_type = st.selectbox("Speaker", ["Company Speaker", "Analyst"])
    ticker = st.selectbox(
        "Company",
        ticker_choice,
        index=0 if _is_sample_query else ticker_index,
    )
|
|
|
# Retrieval (encoder) and generation (decoder) model choices.
encoder_models_choice = ["MPNET", "SGPT", "Hybrid MPNET - SPLADE"]
decoder_models_choice = ["GPT3 - (text-davinci-003)", "T5", "FLAN-T5", "GPT-J"]

# Sidebar controls, rendered in this order.
with st.sidebar:
    st.subheader("Select Options:")
    # How many matches to request from the vector index (1..15, default 5).
    num_results = int(
        st.number_input("Number of Results to query", 1, 15, value=5)
    )
    encoder_model = st.selectbox("Select Encoder Model", encoder_models_choice)
    decoder_model = st.selectbox("Select Decoder Model", decoder_models_choice)
|
|
|
|
|
# Per-encoder settings:
#   (secrets key holding the Pinecone API key, Pinecone environment,
#    index name, loader for the dense query-embedding model)
_ENCODER_SETTINGS = {
    "MPNET": (
        "pinecone_mpnet",
        "us-east1-gcp",
        "week2-all-mpnet-base",
        get_mpnet_embedding_model,
    ),
    "SGPT": (
        "pinecone_sgpt",
        "us-east1-gcp",
        "week2-sgpt-125m",
        get_sgpt_embedding_model,
    ),
    "Hybrid MPNET - SPLADE": (
        "pinecone_hybrid_splade_mpnet",
        "us-central1-gcp",
        "splade-mpnet",
        get_mpnet_embedding_model,
    ),
}

if encoder_model in _ENCODER_SETTINGS:
    (
        _secret_name,
        _pinecone_env,
        pinecone_index_name,
        _load_dense_model,
    ) = _ENCODER_SETTINGS[encoder_model]

    # Connect to the Pinecone project that hosts this encoder's index.
    pinecone.init(api_key=st.secrets[_secret_name], environment=_pinecone_env)
    pinecone_index = pinecone.Index(pinecone_index_name)
    retriever_model = _load_dense_model()

    # The hybrid index additionally needs the SPLADE sparse encoder.
    if encoder_model == "Hybrid MPNET - SPLADE":
        (
            sparse_retriever_model,
            sparse_retriever_tokenizer,
        ) = get_splade_sparse_embedding_model()
|
|
|
with st.sidebar:
    # Passed as `lag` when neighbouring sentences are stitched around a match.
    window = int(st.number_input("Sentence Window Size", 0, 10, value=1))

    # Minimum similarity score forwarded to the Pinecone query helpers.
    raw_threshold = st.number_input(
        label="Similarity Score Threshold",
        value=0.25,
        step=0.05,
        format="%.2f",
    )
    threshold = float(raw_threshold)
|
|
|
# Transcript data; used below to expand matches into sentence windows and to
# look up the full transcript for the selected filters.
data = get_data()

# Query the vector index with the user's question, restricted by the filters.
# The encoder name must match the option in encoder_models_choice exactly:
# the previous check compared against "Hybrid SGPT - SPLADE", which is never
# offered, so the hybrid index was silently queried dense-only.
if encoder_model == "Hybrid MPNET - SPLADE":
    # Hybrid search: dense MPNET vector + sparse SPLADE vector.
    dense_query_embedding = create_dense_embeddings(query_text, retriever_model)
    sparse_query_embedding = create_sparse_embeddings(
        query_text, sparse_retriever_model, sparse_retriever_tokenizer
    )
    # NOTE(review): the third argument presumably is the dense/sparse weighting
    # (alpha); depending on hybrid_score_norm's convention a value of 0 may
    # zero out one of the two vectors entirely — confirm this is intended.
    dense_query_embedding, sparse_query_embedding = hybrid_score_norm(
        dense_query_embedding, sparse_query_embedding, 0
    )
    query_results = query_pinecone_sparse(
        dense_query_embedding,
        sparse_query_embedding,
        num_results,
        pinecone_index,
        year,
        quarter,
        ticker,
        participant_type,
        threshold,
    )
else:
    # Dense-only search (MPNET or SGPT).
    dense_query_embedding = create_dense_embeddings(query_text, retriever_model)
    query_results = query_pinecone(
        dense_query_embedding,
        num_results,
        pinecone_index,
        year,
        quarter,
        ticker,
        participant_type,
        threshold,
    )
|
|
|
|
|
# Build the passages handed to the decoder: up to a 0.90 threshold each match
# is widened with `window` neighbouring sentences; above it the raw matches
# are formatted as-is.
context_list = (
    sentence_id_combine(data, query_results, lag=window)
    if threshold <= 0.90
    else format_query(query_results)
)
|
|
|
|
|
# Generate an answer from the retrieved context with the selected decoder.
# Every branch shows its prompt in an editable text area inside a form in col2.
if decoder_model == "GPT3 - (text-davinci-003)":
    # OpenAI completion: the user may edit the prompt and supplies an API key.
    prompt = generate_gpt_prompt(query_text, context_list)
    with col2:
        with st.form("my_form"):
            edited_prompt = st.text_area(
                label="Model Prompt", value=prompt, height=270
            )

            openai_key = st.text_input(
                "Enter OpenAI key",
                value="",
                type="password",
            )
            submitted = st.form_submit_button("Submit")
            if submitted:
                api_key = save_key(openai_key)
                openai.api_key = api_key
                generated_text = gpt_model(edited_prompt)
                st.subheader("Answer:")
                st.write(generated_text)

elif decoder_model == "T5":
    # T5 runs once per context chunk parsed back out of the (editable) prompt.
    prompt = generate_flant5_prompt_instruct_complete_context(
        query_text, context_list
    )
    t5_pipeline = get_t5_model()
    output_text = []
    with col2:
        with st.form("my_form"):
            edited_prompt = st.text_area(
                label="Model Prompt", value=prompt, height=270
            )
            # NOTE: this rebinds the module-level context_list, so the
            # "See Retrieved Text" expander rendered later shows the chunks
            # parsed from the edited prompt, not the raw retrieval results.
            context_list = get_context_list_prompt(edited_prompt)
            submitted = st.form_submit_button("Submit")
            if submitted:
                for context_text in context_list:
                    output_text.append(
                        t5_pipeline(context_text)[0]["summary_text"]
                    )
                st.subheader("Answer:")
                for text in output_text:
                    st.markdown(f"- {text}")

elif decoder_model == "FLAN-T5":
    flan_t5_model, flan_t5_tokenizer = get_flan_t5_model()
    output_text = []
    # Chunkwise modes build one model input per retrieved chunk.
    _FLAN_CHUNK_PROMPT_BUILDERS = {
        "Chunkwise QA": generate_flant5_prompt_instruct_chunk_context_single,
        "Chunkwise Summarize": generate_flant5_prompt_summ_chunk_context_single,
    }
    with col2:
        prompt_type = st.selectbox(
            "Select prompt type",
            ["Complete Text QA", "Chunkwise QA", "Chunkwise Summarize"],
        )
        if prompt_type == "Complete Text QA":
            prompt = generate_flant5_prompt_instruct_complete_context(
                query_text, context_list
            )
        elif prompt_type == "Chunkwise QA":
            st.write("The following prompt is not editable.")
            prompt = generate_flant5_prompt_instruct_chunk_context(
                query_text, context_list
            )
        elif prompt_type == "Chunkwise Summarize":
            st.write("The following prompt is not editable.")
            prompt = generate_flant5_prompt_summ_chunk_context(
                query_text, context_list
            )
        else:
            prompt = ""
        with st.form("my_form"):
            edited_prompt = st.text_area(
                label="Model Prompt", value=prompt, height=270
            )
            submitted = st.form_submit_button("Submit")
            if submitted:
                if prompt_type == "Complete Text QA":
                    # Use the (possibly user-edited) prompt, as the GPT3 and T5
                    # branches do; previously the edits were silently ignored.
                    output_text_string = generate_text_flan_t5(
                        flan_t5_model, flan_t5_tokenizer, edited_prompt
                    )
                    st.subheader("Answer:")
                    st.write(output_text_string)
                elif prompt_type in _FLAN_CHUNK_PROMPT_BUILDERS:
                    # Chunkwise modes are shown read-only: inputs are rebuilt
                    # per chunk from query_text and the retrieved context.
                    build_chunk_prompt = _FLAN_CHUNK_PROMPT_BUILDERS[prompt_type]
                    for context_text in context_list:
                        model_input = build_chunk_prompt(query_text, context_text)
                        output_text.append(
                            generate_text_flan_t5(
                                flan_t5_model, flan_t5_tokenizer, model_input
                            )
                        )
                    st.subheader("Answer:")
                    for text in output_text:
                        # NOTE(review): outputs containing "(iii)" are dropped;
                        # presumably degenerate generations — confirm intent.
                        if "(iii)" not in text:
                            st.markdown(f"- {text}")

elif decoder_model == "GPT-J":
    # Two-shot prompt variant 2 is used for NVDA, INTC and AMZN; every other
    # ticker (AAPL and AMD included) gets variant 1.
    if ticker in ("NVDA", "INTC", "AMZN"):
        prompt = generate_gpt_j_two_shot_prompt_2(query_text, context_list)
    else:
        prompt = generate_gpt_j_two_shot_prompt_1(query_text, context_list)
    with col2:
        with st.form("my_form"):
            edited_prompt = st.text_area(
                label="Model Prompt", value=prompt, height=270
            )
            st.write(
                "The app currently just shows the prompt. The app does not load the model due to memory limitations."
            )
            submitted = st.form_submit_button("Submit")
|
|
|
|
|
# Show the passages fed to the decoder as a bullet list.
with col1:
    with st.expander("See Retrieved Text"):
        for passage in context_list:
            st.markdown(f"- {passage}")

# Full transcript for the currently selected filters (looked up outside the
# column context so anything the helper renders stays where it was).
file_text = retrieve_transcript(data, year, quarter, ticker)

with col1, st.expander("See Transcript"):
    stx.scrollableTextbox(
        file_text,
        height=700,
        border=False,
        fontFamily="Helvetica",
    )
|
|