nickmuchi commited on
Commit
ee6d004
Β·
1 Parent(s): 754ce49

Update pages/3_Earnings_Semantic_Search_πŸ”Ž_.py

Browse files
pages/3_Earnings_Semantic_Search_πŸ”Ž_.py CHANGED
@@ -18,49 +18,55 @@ if "sen_df" not in st.session_state:
18
  if "earnings_passages" not in st.session_state:
19
  st.session_state["earnings_passages"] = ''
20
 
21
- ## Save to a dataframe for ease of visualization
22
- sen_df = st.session_state['sen_def']
23
-
24
- passages = preprocess_plain_text(st.session_state['earnings_passages'],window_size=window_size)
25
-
26
- ##### Sematic Search #####
27
- # Encode the query using the bi-encoder and find potentially relevant passages
28
- corpus_embeddings = sbert.encode(passages, convert_to_tensor=True, show_progress_bar=True)
29
- question_embedding = sbert.encode(search_input, convert_to_tensor=True)
30
- question_embedding = question_embedding.cpu()
31
- hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=top_k,score_function=util.dot_score)
32
- hits = hits[0] # Get the hits for the first query
33
-
34
- ##### Re-Ranking #####
35
- # Now, score all retrieved passages with the cross_encoder
36
- cross_inp = [[query, passages[hit['corpus_id']]] for hit in hits]
37
- cross_scores = cross_encoder.predict(cross_inp)
38
-
39
- # Sort results by the cross-encoder scores
40
- for idx in range(len(cross_scores)):
41
- hits[idx]['cross-score'] = cross_scores[idx]
42
-
43
- # Output of top-3 hits from re-ranker
44
- hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
45
 
46
- score='cross-score'
47
- df = pd.DataFrame([(hit[score],passages[hit['corpus_id']]) for hit in hits[0:2]],columns=['Score','Text'])
48
- df['Score'] = round(df['Score'],2)
49
-
50
- def gen_annotated_text(para):
51
- tag_list = []
52
- for i in sent_tokenize(para):
53
- label = sen_df.loc[sen_df['text']==i, 'label'].values[0]
54
- if label == 'Negative':
55
- tag_list.append((i,label,'#faa'))
56
- elif label == 'Positive':
57
- tag_list.append((i,label,'#afa'))
58
- else:
59
- tag_list.append((i,label,'#fea'))
60
- return tag_list
61
-
62
- text_to_annotate = [gen_annotated_text(para) for para in df.Text.tolist()]
63
-
64
- for i in text_to_annotate:
65
- annotated_text(i)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
 
18
  if "earnings_passages" not in st.session_state:
19
  st.session_state["earnings_passages"] = ''
20
 
21
+ if st.session_state["sen_df"] or st.session_state["earnings_passages"]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
+ ## Save to a dataframe for ease of visualization
24
+ sen_df = st.session_state['sen_def']
25
+
26
+ passages = preprocess_plain_text(st.session_state['earnings_passages'],window_size=window_size)
27
+
28
+ ##### Sematic Search #####
29
+ # Encode the query using the bi-encoder and find potentially relevant passages
30
+ corpus_embeddings = sbert.encode(passages, convert_to_tensor=True, show_progress_bar=True)
31
+ question_embedding = sbert.encode(search_input, convert_to_tensor=True)
32
+ question_embedding = question_embedding.cpu()
33
+ hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=top_k,score_function=util.dot_score)
34
+ hits = hits[0] # Get the hits for the first query
35
+
36
+ ##### Re-Ranking #####
37
+ # Now, score all retrieved passages with the cross_encoder
38
+ cross_inp = [[query, passages[hit['corpus_id']]] for hit in hits]
39
+ cross_scores = cross_encoder.predict(cross_inp)
40
+
41
+ # Sort results by the cross-encoder scores
42
+ for idx in range(len(cross_scores)):
43
+ hits[idx]['cross-score'] = cross_scores[idx]
44
+
45
+ # Output of top-3 hits from re-ranker
46
+ hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
47
+
48
+ score='cross-score'
49
+ df = pd.DataFrame([(hit[score],passages[hit['corpus_id']]) for hit in hits[0:2]],columns=['Score','Text'])
50
+ df['Score'] = round(df['Score'],2)
51
+
52
+ def gen_annotated_text(para):
53
+ tag_list = []
54
+ for i in sent_tokenize(para):
55
+ label = sen_df.loc[sen_df['text']==i, 'label'].values[0]
56
+ if label == 'Negative':
57
+ tag_list.append((i,label,'#faa'))
58
+ elif label == 'Positive':
59
+ tag_list.append((i,label,'#afa'))
60
+ else:
61
+ tag_list.append((i,label,'#fea'))
62
+ return tag_list
63
+
64
+ text_to_annotate = [gen_annotated_text(para) for para in df.Text.tolist()]
65
+
66
+ for i in text_to_annotate:
67
+ annotated_text(i)
68
+
69
+ else:
70
+
71
+ st.write('Please ensure you have entered the YouTube URL or uploaded the Earnings Call file')
72