File size: 3,357 Bytes
411678e
31b6e92
411678e
ce57a20
411678e
ce57a20
3b52176
d98d50d
 
 
446f9c9
e232116
16f2ce2
 
 
 
 
 
 
 
 
 
 
 
 
 
e232116
446f9c9
 
 
d98d50d
3b52176
e741287
3b52176
446f9c9
e694dea
9f7471f
 
3b52176
a957eeb
 
 
741aa8b
a957eeb
741aa8b
a957eeb
 
e232116
 
09e96c9
 
e232116
a957eeb
09e96c9
a957eeb
ba86e7b
cffcba4
446f9c9
cffcba4
e232116
c712f91
e232116
 
 
 
 
a957eeb
e232116
a957eeb
09c79f1
e232116
a957eeb
e232116
a957eeb
e232116
 
a957eeb
e232116
 
 
a957eeb
 
 
 
741aa8b
 
a957eeb
 
741aa8b
82bf281
a957eeb
953c510
8eb51fc
ee6d004
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import streamlit as st
from functions import *

# Page chrome. The icon is spelled as a Unicode escape (U+1F50E, 🔎) so the
# literal cannot be mangled by re-encoding; the previous literal was
# mojibake of this emoji's UTF-8 bytes.
st.set_page_config(page_title="Earnings Question/Answering", page_icon="\U0001F50E")
st.sidebar.header("Semantic Search")
st.markdown("## Earnings Semantic Search with LangChain, OpenAI & SBert")

def gen_sentiment(text):
    '''Classify *text* with ``sent_pipe`` and return the predicted label.

    ``sent_pipe`` is not defined in this view; presumably it is provided by
    the ``functions`` star import. The 'label' of the first entry it returns
    is used.
    '''
    predictions = sent_pipe(text)
    first = predictions[0]
    return first['label']

def gen_annotated_text(df):
    '''Turn a two-column frame into ``(text, label, colour)`` tuples.

    Each row is read positionally via ``itertuples`` (position 0 is the index),
    so the first data column is the text and the second is the label.
    'Positive' labels are green, 'Negative' labels red, anything else black.
    '''

    def colour_for(label):
        # Map a sentiment label to its highlight colour.
        if label == 'Positive':
            return '#8fce00'
        if label == 'Negative':
            return '#f44336'
        return '#000000'

    return [(row[1], row[2], colour_for(row[2])) for row in df.itertuples()]

# Sidebar display name -> embedding model identifier. The selected value is
# what gets handed to process_corpus/embed_text; insertion order sets the
# order of the selectbox options.
bi_enc_dict = {
    'mpnet-base-v2': "all-mpnet-base-v2",
    'instructor-base': 'hkunlp/instructor-base',
    'setfit-finance': 'nickmuchi/setfit-finetuned-financial-text-classification',
}

# Free-text query box, pre-filled with an example question.
search_input = st.text_input(
    label='Enter Your Search Query',
    value="What key challenges did the business face?",
    key='search',
)

# Sidebar: which embedding model to use (display names from bi_enc_dict).
sbert_model_name = st.sidebar.selectbox(
    "Embedding Model",
    options=list(bi_enc_dict),
    key='sbox',
)

# NOTE(review): chunk_size and overlap_size are not passed to process_corpus
# or embed_text in the code visible here, so these sliders currently appear
# to have no effect — confirm whether they should be wired through.
chunk_size = st.sidebar.slider(
    "Number of Chars per Chunk of Text",
    min_value=500,
    max_value=2000,
    value=1000,
)
overlap_size = st.sidebar.slider(
    "Number of Overlap Chars in Search Response",
    min_value=50,
    max_value=300,
    value=50,
)

# Main page flow: answer the query against the earnings corpus stored in
# session state, tag the answer with a sentiment label and show the source
# passages that were used.
try:

    if not search_input:
        st.write('Please enter a search query')

    # Every session-state key read below must be present; previously only
    # 'sen_df' and 'earnings_passages' were checked, so a missing 'title'
    # raised a KeyError that escaped the RuntimeError handler.
    elif not all(key in st.session_state for key in ('sen_df', 'earnings_passages', 'title')):
        st.write('Please ensure you have entered the YouTube URL or uploaded the Earnings Call file')

    else:
        title = st.session_state['title']
        embedding_model = bi_enc_dict[sbert_model_name]

        with st.spinner(
            text=f"Loading {embedding_model} embedding model and Generating Response..."
        ):
            docsearch = process_corpus(st.session_state['earnings_passages'], title, embedding_model)

            # NOTE(review): chain_type is not defined in this view; presumably
            # it comes from the `from functions import *` star import — confirm,
            # otherwise this raises NameError (not caught by the handler below).
            result = embed_text(search_input, title, embedding_model, docsearch, chain_type)

        references = [doc.page_content for doc in result['input_documents']]
        answer = result['output_text']
        sentiment_label = gen_sentiment(answer)

        # Single-row frame so gen_annotated_text can colour the answer by sentiment.
        df = pd.DataFrame.from_dict({'Text': [answer], 'Sentiment': [sentiment_label]})
        text_annotations = gen_annotated_text(df)[0]

        with st.expander(label='Query Result', expanded=True):
            annotated_text(text_annotations)

        with st.expander(label='References from Corpus used to Generate Result'):
            for ref in references:
                st.write(ref)

except RuntimeError:

    st.write('Please ensure you have entered the YouTube URL or uploaded the Earnings Call file')