LeetTools commited on
Commit
1af3a2c
·
verified ·
1 Parent(s): d493b39

Upload ask.py

Browse files
Files changed (1) hide show
  1. ask.py +50 -7
ask.py CHANGED
@@ -346,19 +346,62 @@ CREATE TABLE {table_name} (
346
 
347
  self.logger.debug(query_result)
348
 
349
- matched_chunks = []
350
- for record in query_result.fetchall():
 
 
351
  result_record = {
352
- "url": record[1],
353
- "chunk": record[2],
354
  }
355
- matched_chunks.append(result_record)
356
 
357
  if settings.hybrid_search:
358
  self.logger.info("Running full-text search ...")
359
- pass
360
 
361
- return matched_chunks
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
362
 
363
  def _get_api_client(self) -> OpenAI:
364
  return OpenAI(api_key=self.llm_api_key, base_url=self.llm_base_url)
 
346
 
347
  self.logger.debug(query_result)
348
 
349
+ # use a dict to remove duplicates from vector search and full-text search
350
+ matched_chunks_dict = {}
351
+ for vec_result in query_result.fetchall():
352
+ doc_id = vec_result[0]
353
  result_record = {
354
+ "url": vec_result[1],
355
+ "chunk": vec_result[2],
356
  }
357
+ matched_chunks_dict[doc_id] = result_record
358
 
359
  if settings.hybrid_search:
360
  self.logger.info("Running full-text search ...")
 
361
 
362
+ self.db_con.execute(
363
+ f"""
364
+ PREPARE fts_query AS (
365
+ WITH scored_docs AS (
366
+ SELECT *, fts_main_{table_name}.match_bm25(
367
+ doc_id, ?, fields := 'chunk'
368
+ ) AS score FROM {table_name})
369
+ SELECT doc_id, url, chunk, score
370
+ FROM scored_docs
371
+ WHERE score IS NOT NULL
372
+ ORDER BY score DESC
373
+ LIMIT 10)
374
+ """
375
+ )
376
+ self.db_con.execute("PRAGMA threads=4")
377
+
378
+ # You can run more complex query rewrite methods here
379
+ # usually: stemming, stop words, etc.
380
+ escaped_query = query.replace("'", " ")
381
+ fts_result: duckdb.DuckDBPyRelation = self.db_con.execute(
382
+ f"EXECUTE fts_query('{escaped_query}')"
383
+ )
384
+
385
+ index = 0
386
+ for fts_record in fts_result.fetchall():
387
+ index += 1
388
+ self.logger.debug(f"The full text search record #{index}: {fts_record}")
389
+ doc_id = fts_record[0]
390
+ result_record = {
391
+ "url": fts_record[1],
392
+ "chunk": fts_record[2],
393
+ }
394
+
395
+ # You can configure the score threashold and top-k
396
+ if fts_record[3] > 1:
397
+ matched_chunks_dict[doc_id] = result_record
398
+ else:
399
+ break
400
+
401
+ if index >= 10:
402
+ break
403
+
404
+ return matched_chunks_dict.values()
405
 
406
  def _get_api_client(self) -> OpenAI:
407
  return OpenAI(api_key=self.llm_api_key, base_url=self.llm_base_url)