z00mP committed on
Commit
d35069c
1 Parent(s): aa1f83e

add different chunk tables and emb models

Browse files
Files changed (2) hide show
  1. app.py +5 -2
  2. backend/semantic_search.py +5 -3
app.py CHANGED
@@ -44,7 +44,9 @@ def bot(history, api_kind, chunk_table, embedding_model, llm_model, eross_encode
44
  # Retrieve documents relevant to query
45
  document_start = perf_counter()
46
 
47
- documents = retrieve(query, TOP_K)
 
 
48
 
49
  document_time = perf_counter() - document_start
50
  logger.info(f'Finished Retrieving documents in {round(document_time, 2)} seconds...')
@@ -118,10 +120,11 @@ with gr.Blocks() as demo:
118
  )
119
  eross_encoder = gr.Radio(
120
  choices=[
 
121
  "BAAI/bge-reranker-large",
122
  "cross-encoder/ms-marco-MiniLM-L-6-v2",
123
  ],
124
- value="cross-encoder/ms-marco-MiniLM-L-6-v2",
125
  label='Cross-encoder model'
126
  )
127
  top_k_param = gr.Radio(
 
44
  # Retrieve documents relevant to query
45
  document_start = perf_counter()
46
 
47
+ #documents = retrieve(query, TOP_K)
48
+ documents = retrieve(query, top_k_param, chunk_table, embedding_model)
49
+
50
 
51
  document_time = perf_counter() - document_start
52
  logger.info(f'Finished Retrieving documents in {round(document_time, 2)} seconds...')
 
120
  )
121
  eross_encoder = gr.Radio(
122
  choices=[
123
+ "None"
124
  "BAAI/bge-reranker-large",
125
  "cross-encoder/ms-marco-MiniLM-L-6-v2",
126
  ],
127
+ value="None",
128
  label='Cross-encoder model'
129
  )
130
  top_k_param = gr.Radio(
backend/semantic_search.py CHANGED
@@ -6,15 +6,17 @@ from sentence_transformers import SentenceTransformer
6
 
7
  db = lancedb.connect(".lancedb")
8
 
9
- TABLE = db.open_table(os.getenv("TABLE_NAME"))
10
  VECTOR_COLUMN = os.getenv("VECTOR_COLUMN", "vector")
11
  TEXT_COLUMN = os.getenv("TEXT_COLUMN", "text")
12
  BATCH_SIZE = int(os.getenv("BATCH_SIZE", 32))
13
 
14
- retriever = SentenceTransformer(os.getenv("EMB_MODEL"))
15
 
16
 
17
- def retrieve(query, k):
 
 
18
  query_vec = retriever.encode(query)
19
  try:
20
  documents = TABLE.search(query_vec, vector_column_name=VECTOR_COLUMN).limit(k).to_list()
 
6
 
7
  db = lancedb.connect(".lancedb")
8
 
9
+ #TABLE = db.open_table(os.getenv("TABLE_NAME"))
10
  VECTOR_COLUMN = os.getenv("VECTOR_COLUMN", "vector")
11
  TEXT_COLUMN = os.getenv("TEXT_COLUMN", "text")
12
  BATCH_SIZE = int(os.getenv("BATCH_SIZE", 32))
13
 
14
+ #retriever = SentenceTransformer(os.getenv("EMB_MODEL"))
15
 
16
 
17
+ def retrieve(query, k, table_name, emb_name):
18
+ TABLE = db.open_table(table_name)
19
+ retriever = SentenceTransformer(emb_name)
20
  query_vec = retriever.encode(query)
21
  try:
22
  documents = TABLE.search(query_vec, vector_column_name=VECTOR_COLUMN).limit(k).to_list()