File size: 3,341 Bytes
411678e
31b6e92
411678e
 
 
3b52176
 
d98d50d
 
 
446f9c9
e232116
16f2ce2
 
 
 
 
 
 
 
 
 
 
 
 
 
e232116
446f9c9
 
 
d98d50d
3b52176
e741287
3b52176
446f9c9
e694dea
e232116
 
 
3b52176
a957eeb
 
 
741aa8b
a957eeb
741aa8b
a957eeb
 
e232116
 
 
a957eeb
e232116
a957eeb
446f9c9
e232116
2d0cd10
e232116
c712f91
e232116
 
 
 
 
a957eeb
e232116
a957eeb
09c79f1
e232116
a957eeb
e232116
a957eeb
e232116
 
a957eeb
e232116
 
 
a957eeb
 
 
 
741aa8b
 
a957eeb
 
741aa8b
82bf281
a957eeb
953c510
8eb51fc
ee6d004
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import streamlit as st
from functions import *

# Page chrome: browser-tab title and icon, sidebar heading, and page heading.
# The icon is U+1F50E (right-pointing magnifying glass) written as an escape
# so it cannot be corrupted again by a wrong file-encoding round trip (the
# previous literal was the mis-decoded UTF-8 bytes of that emoji).
st.set_page_config(page_title="Earnings Semantic Search", page_icon="\U0001F50E")
st.sidebar.header("Semantic Search")
st.markdown("## Earnings Semantic Search with SBert")

def gen_sentiment(text):
    '''Return the sentiment label predicted for ``text``.

    Calls the module-level ``sent_pipe`` (not defined in this file; presumably
    provided by ``functions``) and returns the ``'label'`` entry of the first
    element of its result.
    '''
    predictions = sent_pipe(text)
    first_prediction = predictions[0]
    return first_prediction['label']

def gen_annotated_text(df):
    '''Build ``(text, label, colour)`` tuples for every row of ``df``.

    Rows are read positionally from ``df.itertuples()``: element 1 is used as
    the text and element 2 as the sentiment label (element 0 is the index that
    ``itertuples`` prepends by default).

    Returns a list with one tuple per row. ``'Positive'`` rows are coloured
    ``#8fce00``, ``'Negative'`` rows ``#f44336`` and every other label
    ``#000000``.
    '''
    # Checked in order with ``==`` so any label value is compared the same way
    # an if/elif chain would compare it (no hashing of the label required).
    palette = (('Positive', '#8fce00'), ('Negative', '#f44336'))
    fallback_colour = '#000000'

    def colour_for(label):
        for known_label, colour in palette:
            if label == known_label:
                return colour
        return fallback_colour

    return [(row[1], row[2], colour_for(row[2])) for row in df.itertuples()]

# Display name offered in the sidebar -> embedding model identifier that is
# passed on to embed_text. Insertion order is the order of the selectbox options.
bi_enc_dict = {
    "mpnet-base-v2": "all-mpnet-base-v2",
    "instructor-base": "hkunlp/instructor-base",
    "setfit-finance": "nickmuchi/setfit-finetuned-financial-text-classification",
}

# Main-page query box, pre-filled with an example question.
search_input = st.text_input(
    label='Enter Your Search Query',
    value="What key challenges did the business face?",
    key='search',
)

# Sidebar controls. The selectbox offers the display names of bi_enc_dict.
sbert_model_name = st.sidebar.selectbox(
    "Embedding Model",
    options=list(bi_enc_dict),
    key='sbox',
)

# NOTE(review): chunk_size and overlap_size are not read anywhere in the
# visible part of this file - confirm whether they are consumed elsewhere.
chunk_size = st.sidebar.slider(
    "Number of Words per Chunk of Text",
    min_value=100,
    max_value=250,
    value=200,
)
overlap_size = st.sidebar.slider(
    "Number of Overlap Words in Search Response",
    min_value=30,
    max_value=100,
    value=50,
)
chain_type = st.sidebar.radio(
    "Langchain Chain Type",
    options=['Normal', 'Refined'],
)

# Shown whenever the earnings-call data needed for a search is unavailable.
_MISSING_TRANSCRIPT_MSG = 'Please ensure you have entered the YouTube URL or uploaded the Earnings Call file'

# Session-state entries that must exist before a search can run. 'title' is
# included because it is read below; leaving it out of the check would let a
# missing key escape as an uncaught KeyError.
_REQUIRED_STATE_KEYS = ("sen_df", "earnings_passages", "title")

# Main flow: answer the query against the stored passages, label the answer's
# sentiment, then render the annotated answer and the supporting references.
try:

    if not search_input:

        # An empty query is a different problem from missing transcript data,
        # so tell the user what is actually missing.
        st.write('Please enter a search query')

    elif not all(key in st.session_state for key in _REQUIRED_STATE_KEYS):

        st.write(_MISSING_TRANSCRIPT_MSG)

    else:

        model_name = bi_enc_dict[sbert_model_name]

        with st.spinner(
            text=f"Loading {model_name} embedding model and Generating Response..."
        ):

            result = embed_text(
                search_input,
                st.session_state['earnings_passages'],
                st.session_state['title'],
                model_name,
                emb_tokenizer,
                chain_type,
            )

        # The result mapping carries the generated answer and the documents
        # it was produced from.
        references = [doc.page_content for doc in result['input_documents']]
        answer = result['output_text']

        sentiment_label = gen_sentiment(answer)

        ##### Semantic Search #####

        # One-row frame so the answer can go through gen_annotated_text,
        # which returns a (text, label, colour) tuple per row.
        df = pd.DataFrame.from_dict({'Text': [answer], 'Sentiment': [sentiment_label]})
        text_annotations = gen_annotated_text(df)[0]

        with st.expander(label='Query Result', expanded=True):
            annotated_text(text_annotations)

        with st.expander(label='References from Corpus used to Generate Result'):
            for ref in references:
                st.write(ref)

except RuntimeError:

    # Best-effort fallback kept from the original flow: any RuntimeError
    # raised while producing the answer is reported as missing input.
    st.write(_MISSING_TRANSCRIPT_MSG)