# Earnings-Call-Analysis-Whisperer / pages/3_Earnings_Semantic_Search_πŸ”Ž_.py
# (Hugging Face Spaces viewer metadata, kept as comments so the file parses:)
# nickmuchi's picture
# Update pages/3_Earnings_Semantic_Search_πŸ”Ž_.py
# e232116
# raw / history / blame
# 3.6 kB
# Streamlit page: semantic search over previously-ingested earnings-call text.
import streamlit as st
# NOTE(review): wildcard import is expected to supply the names used below —
# sent_pipe, embed_text, emb_tokenizer, annotated_text, and pd (pandas).
# Verify against functions.py.
from functions import *

# Page chrome: browser-tab title/icon, sidebar header, and main heading.
st.set_page_config(page_title="Earnings Semantic Search", page_icon="πŸ”Ž")
st.sidebar.header("Semantic Search")
st.markdown("## Earnings Semantic Search with SBert")
def gen_sentiment(text):
    '''Return the sentiment label predicted for *text*.

    Uses the module-level `sent_pipe` classification pipeline (brought in via
    `from functions import *`) and returns the top prediction's label string
    (e.g. 'Positive' / 'Negative', as consumed by gen_annotated_text).
    '''
    predictions = sent_pipe(text)
    top = predictions[0]
    return top['label']
def gen_annotated_text(df):
    '''Build (text, label, colour) triples for annotated-text rendering.

    *df* is iterated with itertuples(); position 2 of each row is the text and
    position 3 the sentiment label (i.e. columns ordered Index, Text,
    Sentiment). 'Positive' maps to green, 'Negative' to red, anything else to
    black. Returns the list of triples.
    '''
    colour_map = {'Positive': '#8fce00', 'Negative': '#f44336'}
    annotations = []
    for row in df.itertuples():
        text, label = row[2], row[3]
        annotations.append((text, label, colour_map.get(label, '#000000')))
    return annotations
# Display name -> sentence-transformers / Hugging Face model id for the
# bi-encoder embedding model selectable in the sidebar.
bi_enc_dict = {'mpnet-base-v2':"all-mpnet-base-v2",
               'instructor-base': 'hkunlp/instructor-base',
               'setfit-finance': 'nickmuchi/setfit-finetuned-financial-text-classification'}

# Main query input (pre-filled with an example question).
search_input = st.text_input(
    label='Enter Your Search Query', value="What key challenges did the business face?", key='search')

# Sidebar controls: embedding model, chunking parameters, and LangChain chain type.
sbert_model_name = st.sidebar.selectbox("Embedding Model", options=list(bi_enc_dict.keys()), key='sbox')
chunk_size = st.sidebar.slider("Number of Words per Chunk of Text", min_value=100, max_value=250, value=200)
overlap_size = st.sidebar.slider("Number of Overlap Words in Search Response", min_value=30, max_value=100, value=50)
chain_type = st.sidebar.radio("Langchain Chain Type", options=['Normal', 'Refined'])
try:
    if search_input:
        if "sen_df" in st.session_state and "earnings_passages" in st.session_state:
            # Sentiment dataframe and transcript title produced by the upstream
            # ingestion pages and stashed in session state.
            sen_df = st.session_state['sen_df']
            title = st.session_state['title']

            with st.spinner(
                text=f"Loading {bi_enc_dict[sbert_model_name]} embedding model and Generating Response..."
            ):
                # Retrieval + LangChain QA over the stored transcript passages.
                result = embed_text(search_input, st.session_state['earnings_passages'], title,
                                    bi_enc_dict[sbert_model_name],
                                    emb_tokenizer, chain_type=chain_type)

            references = [doc.page_content for doc in result['input_documents']]
            answer = result['output_text']
            sentiment_label = gen_sentiment(answer)

            ##### Semantic Search #####
            # BUG FIX: the original used zip(1, answer, sentiment_label), which
            # raises TypeError (int is not iterable) — and would otherwise pair
            # the answer string character-by-character. A single-row frame
            # holding the whole answer is what gen_annotated_text expects.
            df = pd.DataFrame([(1, answer, sentiment_label)],
                              columns=['Index', 'Text', 'Sentiment'])
            text_annotations = gen_annotated_text(df)[0]

            with st.expander(label='Query Result', expanded=True):
                annotated_text(text_annotations)

            with st.expander(label='References from Corpus used to Generate Result'):
                for ref in references:
                    st.write(ref)
        else:
            st.write('Please ensure you have entered the YouTube URL or uploaded the Earnings Call file')
    else:
        st.write('Please ensure you have entered the YouTube URL or uploaded the Earnings Call file')
except RuntimeError:
    # embed_text can fail at model-load/inference time; fall back to the same
    # user guidance rather than surfacing a traceback.
    st.write('Please ensure you have entered the YouTube URL or uploaded the Earnings Call file')