Sourab Mangrulkar commited on
Commit
7d055fb
β€’
1 Parent(s): 6f4afc6
Files changed (2) hide show
  1. app.py +17 -2
  2. search_index.bin β†’ embeddings.npy +2 -2
app.py CHANGED
@@ -3,6 +3,7 @@ import json
3
  import re
4
  from sentence_transformers import SentenceTransformer, CrossEncoder
5
  import hnswlib
 
6
  from typing import Iterator
7
 
8
  import gradio as gr
@@ -21,6 +22,7 @@ EMBED_DIM = 1024
21
  K = 10
22
  EF = 100
23
  SEARCH_INDEX = "search_index.bin"
 
24
  DOCUMENT_DATASET = "chunked_data.parquet"
25
  COSINE_THRESHOLD = 0.7
26
 
@@ -119,6 +121,19 @@ def load_hnsw_index(index_file):
119
  return index
120
 
121
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
  def create_query_embedding(query):
123
  # Encode the query to get its embedding
124
  embedding = biencoder.encode([query], normalize_embeddings=True)[0]
@@ -274,7 +289,7 @@ def check_input_token_length(message: str, chat_history: list[tuple[str, str]],
274
  )
275
 
276
 
277
- search_index = load_hnsw_index(SEARCH_INDEX)
278
  data_df = pd.read_parquet(DOCUMENT_DATASET).reset_index()
279
  with gr.Blocks(css="style.css") as demo:
280
  gr.Markdown(DESCRIPTION)
@@ -448,4 +463,4 @@ with gr.Blocks(css="style.css") as demo:
448
  api_name=False,
449
  )
450
 
451
- demo.queue(max_size=20).launch(debug=True, share=True)
 
3
  import re
4
  from sentence_transformers import SentenceTransformer, CrossEncoder
5
  import hnswlib
6
+ import numpy as np
7
  from typing import Iterator
8
 
9
  import gradio as gr
 
22
  K = 10
23
  EF = 100
24
  SEARCH_INDEX = "search_index.bin"
25
+ EMBEDDINGS_FILE = "embeddings.npy"
26
  DOCUMENT_DATASET = "chunked_data.parquet"
27
  COSINE_THRESHOLD = 0.7
28
 
 
121
  return index
122
 
123
 
124
+ # create the index for the PEFT docs from numpy embeddings
125
+ # avoid the arch mismatches when creating search index
126
+ def create_hnsw_index(embeddings_file, M=16, efC=100):
127
+ embeddings = np.load(embeddings_file)
128
+ # Create the HNSW index
129
+ num_dim = embeddings.shape[1]
130
+ ids = np.arange(embeddings.shape[0])
131
+ index = hnswlib.Index(space="ip", dim=num_dim)
132
+ index.init_index(max_elements=embeddings.shape[0], ef_construction=efC, M=M)
133
+ index.add_items(embeddings, ids)
134
+ return index
135
+
136
+
137
  def create_query_embedding(query):
138
  # Encode the query to get its embedding
139
  embedding = biencoder.encode([query], normalize_embeddings=True)[0]
 
289
  )
290
 
291
 
292
+ search_index = create_hnsw_index(EMBEDDINGS_FILE) # load_hnsw_index(SEARCH_INDEX)
293
  data_df = pd.read_parquet(DOCUMENT_DATASET).reset_index()
294
  with gr.Blocks(css="style.css") as demo:
295
  gr.Markdown(DESCRIPTION)
 
463
  api_name=False,
464
  )
465
 
466
+ demo.queue(max_size=20).launch(debug=True, share=False)
search_index.bin β†’ embeddings.npy RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:14e38e3cb1c2b2e64977ca2ca5ded4ebff397412e228d6777304626448da8680
3
- size 4911056
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d98d063ffe42060493c8e52bb0c8f0b33f57d6316dd0b27651ebdccad212defa
3
+ size 4735104