File size: 3,632 Bytes
411678e 31b6e92 411678e 3b52176 d98d50d 446f9c9 e232116 446f9c9 d98d50d 3b52176 e741287 3b52176 446f9c9 e694dea e232116 3b52176 a957eeb 741aa8b a957eeb 741aa8b a957eeb e232116 a957eeb e232116 a957eeb 446f9c9 e232116 c712f91 e232116 a957eeb e232116 a957eeb e232116 a957eeb e232116 a957eeb e232116 a957eeb e232116 a957eeb 741aa8b a957eeb 741aa8b 82bf281 a957eeb 953c510 8eb51fc ee6d004 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 |
import streamlit as st
from functions import *
# Page-level setup for the semantic-search page.
# set_page_config is the first Streamlit call in this script; older Streamlit
# versions require that, so keep it ahead of any other st.* call.
# NOTE(review): page_icon "π" looks like it may be a mis-encoded emoji from a
# copy/paste — confirm the intended icon.
st.set_page_config(page_title="Earnings Semantic Search", page_icon="π")
# Sidebar section title and main-page heading.
st.sidebar.header("Semantic Search")
st.markdown("## Earnings Semantic Search with SBert")
def gen_sentiment(text):
    """Return the sentiment label of the first prediction for *text*.

    ``sent_pipe`` is not defined in this file (presumably it is provided by
    ``from functions import *``); its result is indexed at ``[0]`` and that
    prediction's ``'label'`` entry is returned.
    """
    predictions = sent_pipe(text)
    best = predictions[0]
    return best['label']
def gen_annotated_text(df):
    """Build ``(text, label, colour)`` tuples suitable for ``annotated_text``.

    Rows of *df* are read positionally via ``itertuples``. Position 0 is the
    frame index, so position 2 is the second column (the text) and
    position 3 is the third column (the sentiment label). 'Positive' labels
    get green, 'Negative' labels get red, and any other label gets black.
    """
    def colour_for(label):
        # Plain equality checks, mirroring a simple if/elif chain.
        if label == 'Positive':
            return '#8fce00'
        if label == 'Negative':
            return '#f44336'
        return '#000000'

    return [
        (record[2], record[3], colour_for(record[3]))
        for record in df.itertuples()
    ]
# Sidebar display name -> embedding model identifier handed to embed_text.
bi_enc_dict = {
    'mpnet-base-v2': "all-mpnet-base-v2",
    'instructor-base': 'hkunlp/instructor-base',
    'setfit-finance': 'nickmuchi/setfit-finetuned-financial-text-classification',
}
# Query box on the main page, then sidebar controls. Widgets render in the
# order they are created, so keep this sequence.
search_input = st.text_input(
    label='Enter Your Search Query',
    value="What key challenges did the business face?",
    key='search',
)
sbert_model_name = st.sidebar.selectbox(
    "Embedding Model",
    options=list(bi_enc_dict),
    key='sbox',
)
# NOTE(review): chunk_size and overlap_size are not referenced anywhere else in
# the visible source (embed_text is not given them) — confirm whether they
# should be wired into the search call or removed.
chunk_size = st.sidebar.slider(
    "Number of Words per Chunk of Text",
    min_value=100,
    max_value=250,
    value=200,
)
overlap_size = st.sidebar.slider(
    "Number of Overlap Words in Search Response",
    min_value=30,
    max_value=100,
    value=50,
)
chain_type = st.sidebar.radio(
    "Langchain Chain Type",
    options=['Normal', 'Refined'],
)
# Run the semantic search for the current query and render the answer,
# its sentiment, and the source passages used to produce it.
# embed_text, emb_tokenizer, pd and annotated_text are not defined in this
# file; presumably they come from `from functions import *`.
try:
    if search_input:
        # The upload/transcription step is expected to have stored these keys.
        # 'title' is read below, so require it too instead of raising KeyError.
        required_keys = ("sen_df", "earnings_passages", "title")
        if all(key in st.session_state for key in required_keys):
            title = st.session_state['title']
            model_id = bi_enc_dict[sbert_model_name]

            with st.spinner(
                text=f"Loading {model_id} embedding model and Generating Response..."
            ):
                result = embed_text(
                    search_input,
                    st.session_state['earnings_passages'],
                    title,
                    model_id,
                    emb_tokenizer,
                    chain_type=chain_type,
                )

            # result exposes 'input_documents' (objects with .page_content)
            # and 'output_text', as read here.
            references = [doc.page_content for doc in result['input_documents']]
            answer = result['output_text']
            sentiment_label = gen_sentiment(answer)

            ##### Semantic Search #####
            # A single-row frame holding the whole answer and its label.
            # (The previous zip(1, answer, sentiment_label) raised TypeError
            # because an int is not iterable, and would otherwise have split
            # the answer into characters.)
            df = pd.DataFrame(
                [(1, answer, sentiment_label)],
                columns=['Index', 'Text', 'Sentiment'],
            )
            text_annotations = gen_annotated_text(df)[0]

            with st.expander(label='Query Result', expanded=True):
                annotated_text(text_annotations)

            with st.expander(label='References from Corpus used to Generate Result'):
                for ref in references:
                    st.write(ref)
        else:
            st.write('Please ensure you have entered the YouTube URL or uploaded the Earnings Call file')
    else:
        # The query box was cleared; the URL/upload hint does not apply here.
        st.write('Please enter a search query')
except RuntimeError:
    st.write('Please ensure you have entered the YouTube URL or uploaded the Earnings Call file')
|