nickmuchi committed
Commit 3b52176 • 1 Parent(s): c2ce126

Update pages/3_Earnings_Semantic_Search_🔎_.py
pages/3_Earnings_Semantic_Search_🔎_.py CHANGED
@@ -3,4 +3,61 @@ from functions import *
 
 st.set_page_config(page_title="Earnings Semantic Search", page_icon="🔎")
 st.sidebar.header("Semantic Search")
-st.markdown("## Earnings Semantic Search with SBert")
+st.markdown("## Earnings Semantic Search with SBert")
+
+search_input = st.text_input(
+    label='Enter Your Search Query, e.g. "What challenges did the business face?"', key='search')
+
+top_k = st.sidebar.slider("Number of Top Hits Generated", min_value=1, max_value=5, value=2)
+
+window_size = st.sidebar.slider("Number of Sentences Generated in Search Response", min_value=1, max_value=5, value=3)
+
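+# top_k bounds both the bi-encoder retrieval and the two results tables below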
+# Retrieve the transcribed earnings text saved in session state by an earlier step
+earnings_passages = st.session_state['earnings_passages']
+
+earnings_sentiment, earnings_sentences = sentiment_pipe(earnings_passages)
+
+with st.expander("See Transcribed Earnings Text"):
+    st.write(f"Number of Sentences: {len(earnings_sentences)}")
+
+    st.write(earnings_passages)
+
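+# sentiment_pipe presumably yields one {'label', 'score'} dict per sentence,
+# which is the shape the dataframe below expects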
+
+## Save to a dataframe for ease of visualization
+sen_df = pd.DataFrame(earnings_sentiment)
+sen_df['text'] = earnings_sentences
+grouped = pd.DataFrame(sen_df['label'].value_counts()).reset_index()
+grouped.columns = ['sentiment', 'count']
+
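+# window_size controls how many consecutive sentences form each searchable passage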
+passages = preprocess_plain_text(st.session_state['earnings_passages'], window_size=window_size)
+
+##### Semantic Search #####
+# Encode the query using the bi-encoder and find potentially relevant passages
+corpus_embeddings = sbert.encode(passages, convert_to_tensor=True, show_progress_bar=True)
+question_embedding = sbert.encode(search_input, convert_to_tensor=True)
+question_embedding = question_embedding.cpu()
+hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=top_k, score_function=util.dot_score)
+hits = hits[0]  # Get the hits for the first query
+
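+# Note: util.dot_score ranks by raw dot product, which suits bi-encoders trained
+# for dot-product similarity (e.g. the msmarco models); a cosine-trained model
+# would call for util.cos_sim instead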
+##### Re-Ranking #####
+# Now, score all retrieved passages with the cross-encoder
+cross_inp = [[search_input, passages[hit['corpus_id']]] for hit in hits]
+cross_scores = cross_encoder.predict(cross_inp)
+
+# Attach the cross-encoder score to each hit
+for idx in range(len(cross_scores)):
+    hits[idx]['cross-score'] = cross_scores[idx]
+
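+# cross_encoder.predict returns one relevance score per (query, passage) pair;
+# higher means more relevant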
+# Output of the top-k hits from the bi-encoder
+st.markdown("\n-------------------------\n")
+st.subheader(f"Top-{top_k} Bi-Encoder Retrieval hits")
+hits = sorted(hits, key=lambda x: x['score'], reverse=True)
+
+cross_df = display_df_as_table(hits, top_k)
+st.write(cross_df.to_html(index=False), unsafe_allow_html=True)
+
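+# Showing both tables lets the user compare the raw retrieval order with the re-ranked order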
+# Output of the top-k hits from the cross-encoder re-ranker
+st.markdown("\n-------------------------\n")
+st.subheader(f"Top-{top_k} Cross-Encoder Re-ranker hits")
+hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
+
+rerank_df = display_df_as_table(hits, top_k, 'cross-score')
+st.write(rerank_df.to_html(index=False), unsafe_allow_html=True)
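
For reference: the page leans on from functions import * for sbert, cross_encoder, sentiment_pipe, preprocess_plain_text, and display_df_as_table, none of which appear in this diff. Below is a minimal sketch of what functions.py could provide, assuming a standard sentence-transformers retrieve-and-re-rank setup; the model names and helper bodies are illustrative assumptions, not the repo's actual code.

import pandas as pd
from sentence_transformers import SentenceTransformer, CrossEncoder, util
from transformers import pipeline

# Assumed models: any bi-encoder/cross-encoder pair works; a msmarco
# dot-product bi-encoder matches the util.dot_score call above.
sbert = SentenceTransformer('msmarco-distilbert-base-dot-prod-v3')
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
sentiment = pipeline('text-classification',
                     model='distilbert-base-uncased-finetuned-sst-2-english')

def sentiment_pipe(earnings_text):
    """Split the transcript into sentences and score each one for sentiment."""
    earnings_sentences = [s.strip() for s in earnings_text.split('.') if s.strip()]
    earnings_sentiment = sentiment(earnings_sentences)  # one {label, score} per sentence
    return earnings_sentiment, earnings_sentences

def preprocess_plain_text(text, window_size=3):
    """Group consecutive sentences into passages of window_size sentences."""
    sentences = [s.strip() for s in text.split('.') if s.strip()]
    return [' '.join(sentences[i:i + window_size])
            for i in range(0, len(sentences), window_size)]

def display_df_as_table(hits, top_k, score_col='score'):
    """Tabulate the top_k hits; the real helper presumably also joins passage text."""
    rows = [(round(float(hit[score_col]), 3), hit['corpus_id']) for hit in hits[:top_k]]
    return pd.DataFrame(rows, columns=[score_col.capitalize(), 'Passage Id'])

This mirrors the usual retrieve-and-re-rank pattern: the bi-encoder scores the whole corpus cheaply, and the slower but more accurate cross-encoder re-scores only the top_k candidates before display.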