nickmuchi committed on
Commit
e429024
1 Parent(s): 446f9c9

Update functions.py

Browse files
Files changed (1) hide show
  1. functions.py +56 -0
functions.py CHANGED
@@ -52,6 +52,62 @@ def load_sbert(model_name):
52
 
53
  return sbert
54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  @st.experimental_singleton(suppress_st_warning=True)
56
  def get_spacy():
57
  nlp = en_core_web_lg.load()
 
52
 
53
  return sbert
54
 
55
@st.experimental_memo(suppress_st_warning=True)
def embed_text(query, corpus, embedding_model):
    '''Embed a query and a corpus of passages, retrieve the closest passages
    with a bi-encoder, then re-rank the retrieved passages with a
    cross-encoder.

    Parameters
    ----------
    query : str
        The search question.
    corpus : list[str]
        Candidate passages to search over.
    embedding_model : str
        Name of the bi-encoder checkpoint. Controls model-specific input
        formatting: 'intfloat/e5-base' needs 'query: '/'passage: ' prefixes,
        'hkunlp/instructor-base' needs instruction triples.

    Returns
    -------
    list[dict]
        The retrieved hits sorted by descending cross-encoder score; each
        dict carries 'corpus_id', 'score' and an added 'cross-score' key.

    NOTE(review): relies on module-level objects `sbert` (bi-encoder),
    `cross_encoder` and `util` (sentence_transformers.util) loaded elsewhere
    in this file — confirm they are initialised before calling.
    '''
    # Apply model-specific formatting to the query and passages.
    if embedding_model == 'intfloat/e5-base':
        # e5 models are trained with explicit 'query:'/'passage:' prefixes.
        search_input = 'query: ' + query
        passages_emb = ['passage: ' + sentence for sentence in corpus]

    elif embedding_model == 'hkunlp/instructor-base':
        # Instructor models take [instruction, text, flag] triples.
        search_input = [['Represent the Financial question for retrieving supporting documents; Input: ', query, 0]]
        passages_emb = [['Represent the Financial document for retrieval; Input: ', sentence, 0] for sentence in corpus]

    else:
        search_input = query
        passages_emb = corpus

    # Embed corpus and question; move to CPU for util.semantic_search.
    corpus_embedding = sbert.encode(passages_emb, convert_to_tensor=True)
    question_embedding = sbert.encode(search_input, convert_to_tensor=True)
    question_embedding = question_embedding.cpu()
    corpus_embedding = corpus_embedding.cpu()

    # Retrieve the top candidates with the bi-encoder.
    hits = util.semantic_search(question_embedding, corpus_embedding, top_k=2)
    hits = hits[0]  # Get the hits for the first (only) query

    # ##### Re-Ranking #####
    # Score each retrieved passage against the query with the cross-encoder.
    # BUG FIX: the original indexed an undefined name `passages`; the ids
    # returned by semantic_search index into the formatted passage list
    # `passages_emb` (the instructor unwrapping below assumes its triples).
    cross_inp = [[search_input, passages_emb[hit['corpus_id']]] for hit in hits]

    if embedding_model == 'hkunlp/instructor-base':
        # Unwrap the instructor triples back to plain (question, document)
        # string pairs for the cross-encoder.
        result = []

        for sublist in cross_inp:
            question = sublist[0][0][1]
            document = sublist[1][1]
            result.append([question, document])

        cross_inp = result

    cross_scores = cross_encoder.predict(cross_inp)

    # Attach the cross-encoder score to each hit and sort best-first.
    for idx in range(len(cross_scores)):
        hits[idx]['cross-score'] = cross_scores[idx]

    hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)

    return hits
111
  @st.experimental_singleton(suppress_st_warning=True)
112
  def get_spacy():
113
  nlp = en_core_web_lg.load()