File size: 4,629 Bytes
411678e
31b6e92
411678e
 
 
3b52176
 
d98d50d
 
 
446f9c9
 
 
 
 
 
d98d50d
3b52176
e741287
3b52176
446f9c9
e694dea
d50bad2
3b52176
e741287
3b52176
a957eeb
 
 
741aa8b
a957eeb
741aa8b
a957eeb
 
 
16975a3
d98d50d
a957eeb
446f9c9
a957eeb
446f9c9
d98d50d
741aa8b
a957eeb
 
446f9c9
 
 
 
 
a957eeb
446f9c9
 
 
 
a957eeb
446f9c9
 
 
a957eeb
446f9c9
 
 
a957eeb
446f9c9
 
a957eeb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
741aa8b
 
a957eeb
 
741aa8b
82bf281
a957eeb
953c510
8eb51fc
ee6d004
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import streamlit as st
from functions import *

# Page chrome: browser-tab title and icon, sidebar heading, main-page heading.
# The icon literal had been stored as mis-decoded UTF-8 bytes; it is restored
# to the magnifying-glass emoji (U+1F50E) so Streamlit receives a real emoji.
st.set_page_config(page_title="Earnings Semantic Search", page_icon="🔎")
st.sidebar.header("Semantic Search")
st.markdown("## Earnings Semantic Search with SBert")

def gen_sentiment(text):
    """Return the sentiment label reported for ``text``.

    Calls ``sent_pipe`` on the text and returns the ``'label'`` entry of the
    first element of its result.
    """
    # NOTE(review): sent_pipe is not defined in this file; presumably it is
    # provided by ``from functions import *`` -- confirm.
    predictions = sent_pipe(text)
    top_prediction = predictions[0]
    return top_prediction['label']

# Number of ranked hits kept when the results table is built.
top_k = 2

# Name shown in the sidebar selectbox -> model identifier passed on to the
# model loader and the embedding helper.
bi_enc_dict = {
    'mpnet-base-v2': "all-mpnet-base-v2",
    'e5-base': 'intfloat/e5-base',
    'instructor-base': 'hkunlp/instructor-base',
    'mpnet-base-dot-v1': 'multi-qa-mpnet-base-dot-v1',
    'setfit-finance': 'nickmuchi/setfit-finetuned-financial-text-classification',
}

# Free-text query box on the main page, pre-filled with an example question.
search_input = st.text_input(
    label='Enter Your Search Query',
    value="What key challenges did the business face?",
    key='search',
)

# Sidebar choice of embedding model; options are the keys of bi_enc_dict.
sbert_model_name = st.sidebar.selectbox(
    "Embedding Model",
    options=list(bi_enc_dict),
    key='sbox',
)

# Sidebar slider (1-7, default 3); its value is forwarded as ``window_size``
# when the earnings passages are chunked.
window_size = st.sidebar.slider(
    "Number of Sentences Generated in Search Response",
    min_value=1,
    max_value=7,
    value=3,
)

def gen_annotated_text(df):
    """Build colour-tagged tuples for the ranked passages.

    Args:
        df: DataFrame with a ``Text`` column (passage text) and a
            ``Sentiment`` column (label returned by ``gen_sentiment``).

    Returns:
        A list of ``(text, label, colour)`` tuples, one per row in row order.
        'Positive' maps to green, 'Negative' to red, any other label to black.
    """
    colours = {'Positive': '#8fce00', 'Negative': '#f44336'}
    return [
        (text, label, colours.get(label, '#000000'))
        for text, label in zip(df['Text'], df['Sentiment'])
    ]


# Main page flow: validate inputs, chunk the stored passages, rank them against
# the query, then render the best hit and (when available) one alternative.
try:

    if not search_input:
        # The query box is empty: ask for a query instead of pointing the user
        # at the URL/upload step, which is unrelated to this condition.
        st.write('Please enter a search query')

    elif "sen_df" in st.session_state and "earnings_passages" in st.session_state:

        # NOTE(review): 150 is presumably a maximum chunk length -- confirm
        # against functions.chunk_long_text.
        passages = chunk_long_text(st.session_state['earnings_passages'],150,window_size=window_size)

        embedding_model = bi_enc_dict[sbert_model_name]

        with st.spinner(text=f"Loading {embedding_model} encoder model..."):
            # The returned model is not used below.
            # NOTE(review): presumably load_sbert caches the model so that
            # embed_text can reuse it -- confirm before removing this call.
            sbert = load_sbert(embedding_model)

        # Each hit is read below as a mapping with a 'cross-score' value and a
        # 'corpus_id' index into ``passages``.
        hits = embed_text(search_input,passages,embedding_model)

        score='cross-score'
        rows = [(hit[score],passages[hit['corpus_id']]) for hit in hits[0:int(top_k)]]

        if not rows:
            # Nothing was retrieved; avoid indexing into an empty result list.
            st.write('No matching passages were found for this search query')

        else:
            df = pd.DataFrame(rows,columns=['Score','Text'])
            df['Score'] = df['Score'].round(2)
            df['Sentiment'] = df['Text'].apply(gen_sentiment)

            text_annotations = gen_annotated_text(df)

            with st.expander(label='Best Search Query Result', expanded=True):
                annotated_text(text_annotations[0])

            # A second hit is not guaranteed (e.g. only one passage exists),
            # so the alternative panel is shown only when it is available.
            if len(text_annotations) > 1:
                with st.expander(label='Alternative Search Query Result'):
                    annotated_text(text_annotations[1])

    else:

        st.write('Please ensure you have entered the YouTube URL or uploaded the Earnings Call file')

except RuntimeError:
    # NOTE(review): any RuntimeError raised while loading the model or
    # embedding is reported with the same prompt -- confirm this is intended.
    st.write('Please ensure you have entered the YouTube URL or uploaded the Earnings Call file')