nickmuchi committed on
Commit
741aa8b
•
1 Parent(s): ecca100

Update pages/3_Earnings_Semantic_Search_🔎_.py

Browse files
pages/3_Earnings_Semantic_Search_🔎_.py CHANGED
@@ -18,57 +18,64 @@ if "sen_df" not in st.session_state:
18
  if "earnings_passages" not in st.session_state:
19
  st.session_state["earnings_passages"] = ''
20
 
21
- if any(st.session_state["sen_df"]) or st.session_state["earnings_passages"]:
22
 
23
- ## Save to a dataframe for ease of visualization
24
- sen_df = st.session_state['sen_df']
25
-
26
- passages = preprocess_plain_text(st.session_state['earnings_passages'],window_size=window_size)
27
-
28
- ##### Sematic Search #####
29
- # Encode the query using the bi-encoder and find potentially relevant passages
30
- corpus_embeddings = sbert.encode(passages, convert_to_tensor=True, show_progress_bar=True)
31
- question_embedding = sbert.encode(search_input, convert_to_tensor=True)
32
- question_embedding = question_embedding.cpu()
33
- hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=top_k,score_function=util.dot_score)
34
- hits = hits[0] # Get the hits for the first query
35
-
36
- ##### Re-Ranking #####
37
- # Now, score all retrieved passages with the cross_encoder
38
- cross_inp = [[search_input, passages[hit['corpus_id']]] for hit in hits]
39
- cross_scores = cross_encoder.predict(cross_inp)
40
-
41
- # Sort results by the cross-encoder scores
42
- for idx in range(len(cross_scores)):
43
- hits[idx]['cross-score'] = cross_scores[idx]
44
-
45
- # Output of top-3 hits from re-ranker
46
- hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
47
 
48
- score='cross-score'
49
- df = pd.DataFrame([(hit[score],passages[hit['corpus_id']]) for hit in hits[0:int(top_k)]],columns=['Score','Text'])
50
- df['Score'] = round(df['Score'],2)
51
-
52
- print(f'Test: {df}')
53
-
54
- def gen_annotated_text(para):
55
- tag_list = []
56
- for i in sent_tokenize(para):
57
- label = sen_df.loc[sen_df['text']==i, 'label'].values[0]
58
- if label == 'Negative':
59
- tag_list.append((i,label,'#faa'))
60
- elif label == 'Positive':
61
- tag_list.append((i,label,'#afa'))
62
- else:
63
- tag_list.append((i,label,'#fea'))
64
- return tag_list
65
-
66
- text_to_annotate = [gen_annotated_text(para) for para in df.Text.tolist()]
67
-
68
- for i in text_to_annotate:
69
- annotated_text(i)
70
 
71
- else:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
73
- st.write('Please ensure you have entered the YouTube URL or uploaded the Earnings Call file')
74
-
 
18
  if "earnings_passages" not in st.session_state:
19
  st.session_state["earnings_passages"] = ''
20
 
21
+ if search_input is not None:
22
 
23
+ if any(st.session_state["sen_df"]) or st.session_state["earnings_passages"]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
+ ## Save to a dataframe for ease of visualization
26
+ sen_df = st.session_state['sen_df']
27
+
28
+ passages = preprocess_plain_text(st.session_state['earnings_passages'],window_size=window_size)
29
+
30
+ ##### Sematic Search #####
31
+ # Encode the query using the bi-encoder and find potentially relevant passages
32
+ corpus_embeddings = sbert.encode(passages, convert_to_tensor=True, show_progress_bar=True)
33
+ question_embedding = sbert.encode(search_input, convert_to_tensor=True)
34
+ question_embedding = question_embedding.cpu()
35
+ hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=top_k,score_function=util.dot_score)
36
+ hits = hits[0] # Get the hits for the first query
 
 
 
 
 
 
 
 
 
 
37
 
38
+ ##### Re-Ranking #####
39
+ # Now, score all retrieved passages with the cross_encoder
40
+ cross_inp = [[search_input, passages[hit['corpus_id']]] for hit in hits]
41
+ cross_scores = cross_encoder.predict(cross_inp)
42
+
43
+ # Sort results by the cross-encoder scores
44
+ for idx in range(len(cross_scores)):
45
+ hits[idx]['cross-score'] = cross_scores[idx]
46
+
47
+ # Output of top-3 hits from re-ranker
48
+ hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
49
+
50
+ score='cross-score'
51
+ df = pd.DataFrame([(hit[score],passages[hit['corpus_id']]) for hit in hits[0:int(top_k)]],columns=['Score','Text'])
52
+ df['Score'] = round(df['Score'],2)
53
+
54
+ print(f'Test: {df}')
55
+
56
+ def gen_annotated_text(para):
57
+ tag_list = []
58
+ for i in sent_tokenize(para):
59
+ label = sen_df.loc[sen_df['text']==i, 'label'].values[0]
60
+ if label == 'Negative':
61
+ tag_list.append((i,label,'#faa'))
62
+ elif label == 'Positive':
63
+ tag_list.append((i,label,'#afa'))
64
+ else:
65
+ tag_list.append((i,label,'#fea'))
66
+ return tag_list
67
+
68
+ text_to_annotate = [gen_annotated_text(para) for para in df.Text.tolist()]
69
+
70
+ first,second = text_to_annotate[0],text_to_annotate[-1]
71
+
72
+ with st.container():
73
+ annotate_text(*first)
74
+
75
+ with st.container():
76
+ annotate_text(*second)
77
+
78
+ else:
79
+
80
+ st.write('Please ensure you have entered the YouTube URL or uploaded the Earnings Call file')
81