nickmuchi committed on
Commit
3c8f0ad
1 Parent(s): dbd6797

Update app.py

Browse files

Added e5 embedding model

Files changed (1) hide show
  1. app.py +13 -4
app.py CHANGED
@@ -137,8 +137,15 @@ def bi_encode(bi_enc,passages):
137
 
138
  #Compute the embeddings using the multi-process pool
139
  with st.spinner('Encoding passages into a vector space...'):
140
-
141
- corpus_embeddings = bi_encoder.encode(passages, convert_to_tensor=True, show_progress_bar=True)
 
 
 
 
 
 
 
142
 
143
  st.success(f"Embeddings computed. Shape: {corpus_embeddings.shape}")
144
 
@@ -178,7 +185,7 @@ def bm25_api(passages):
178
 
179
  return bm25
180
 
181
- bi_enc_options = ["multi-qa-mpnet-base-dot-v1","all-mpnet-base-v2","multi-qa-MiniLM-L6-cos-v1","neeva/query2query"]
182
 
183
  def display_df_as_table(model,top_k,score='score'):
184
  # Display the df with text and scores as a table
@@ -204,7 +211,7 @@ top_k = st.sidebar.slider("Number of Top Hits Generated",min_value=1,max_value=5
204
 
205
  # This function will search all wikipedia articles for passages that
206
  # answer the query
207
- def search_func(query, top_k=top_k):
208
 
209
  global bi_encoder, cross_encoder
210
 
@@ -229,6 +236,8 @@ def search_func(query, top_k=top_k):
229
  bm25_df = display_df_as_table(bm25_hits,top_k)
230
  st.write(bm25_df.to_html(index=False), unsafe_allow_html=True)
231
 
 
 
232
  ##### Semantic Search #####
233
  # Encode the query using the bi-encoder and find potentially relevant passages
234
  question_embedding = bi_encoder.encode(query, convert_to_tensor=True)
 
137
 
138
  #Compute the embeddings using the multi-process pool
139
  with st.spinner('Encoding passages into a vector space...'):
140
+
141
+ if bi_enc == 'intfloat/e5-base':
142
+
143
+ corpus_embeddings = bi_encoder.encode(['passage: ' + sentence for sentence in passages], convert_to_tensor=True)
144
+
145
+ else:
146
+
147
 + corpus_embeddings = bi_encoder.encode(passages, convert_to_tensor=True)
148
+
149
 
150
  st.success(f"Embeddings computed. Shape: {corpus_embeddings.shape}")
151
 
 
185
 
186
  return bm25
187
 
188
+ bi_enc_options = ["multi-qa-mpnet-base-dot-v1","all-mpnet-base-v2","multi-qa-MiniLM-L6-cos-v1",'intfloat/e5-base',"neeva/query2query"]
189
 
190
  def display_df_as_table(model,top_k,score='score'):
191
  # Display the df with text and scores as a table
 
211
 
212
  # This function will search all wikipedia articles for passages that
213
  # answer the query
214
 + def search_func(query, bi_encoder_type, top_k=top_k):
215
 
216
  global bi_encoder, cross_encoder
217
 
 
236
  bm25_df = display_df_as_table(bm25_hits,top_k)
237
  st.write(bm25_df.to_html(index=False), unsafe_allow_html=True)
238
 
239
+ if bi_encoder_type == 'intfloat/e5-base':
240
+ query = 'query: ' + query
241
  ##### Semantic Search #####
242
  # Encode the query using the bi-encoder and find potentially relevant passages
243
  question_embedding = bi_encoder.encode(query, convert_to_tensor=True)