Spaces:
Running
Running
File size: 3,753 Bytes
89cd5d5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 |
# from src import *
import logging
import os
import lancedb
from lancedb.rerankers import ColbertReranker
import sys
logger = logging.getLogger(__name__)
from typing import Optional
from src.vectordb.helpers import set_uri
# db = lancedb.connect("/tmp/db")
def search(query: str, table_name: str, filter_condition: Optional[str] = None,
category: str = "docs", limit: int = 10, reranking: int = 0,
run_local: Optional[bool] = False) -> list | None:
"""
Generalized function to search a database table, with optional filters and reranking.
Args:
- query: str, search query.
- table_name: str, name of the table to search.
- filter_condition: Optional[str], optional SQL-like condition for filtering results.
- category: str, type of category (default is 'docs').
- limit: int, number of results (default is 10).
- reranking: int (0 or 1), if activated, ColbertReranker is used.
- run_local: Optional[bool], whether to run in a local environment.
Returns:
A list of the most relevant documents or listings based on the category.
"""
uri = set_uri(run_local)
try:
db = lancedb.connect(uri)
except Exception as e:
logger.error(f"Error while connecting to DB: {e}")
return None
logger.info(f"Connected to {table_name} DB.")
table = db.open_table(table_name)
search_query = table.search(query).metric('cosine')
if filter_condition:
search_query = search_query.where(filter_condition)
if reranking:
try:
column = 'description' if category == 'listings' else 'text'
reranker = ColbertReranker(column=column)
results = search_query.rerank(reranker=reranker).limit(limit).to_list()
except Exception as e:
exc_type, exc_obj, exc_tb = sys.exc_info()
fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
logger.error(f"Error while reranking results: {e}, {(exc_type, fname, exc_tb.tb_lineno)}")
return None
else:
try:
results = search_query.limit(limit).to_list()
except Exception as e:
exc_type, exc_obj, exc_tb = sys.exc_info()
fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
logger.error(f"Error while searching: {e}, {(exc_type, fname, exc_tb.tb_lineno)}")
return None
logger.info("Found the most relevant documents.")
if category == "docs":
return [{"city": r['city'], "country": r['country'], "section": r['section'], "text": r['text']} for r in
results]
else:
return [{"city": r['city'], "country": r['country'], "type": r['type'], "title": r['title'],
"description": r['description']} for r in results]
def search_wikivoyage_docs(query: str, limit: int = 10, reranking: int = 0,
                           run_local: Optional[bool] = False) -> list | None:
    """
    Search the Wikivoyage documents table.

    Delegates to `search` against the "wikivoyage_documents" table with
    category "docs" and no filter condition.

    Args:
    - query: str, search query.
    - limit: int, number of results (default is 10).
    - reranking: int (0 or 1), forwarded to `search` to toggle reranking.
    - run_local: Optional[bool], forwarded to `search`.
    Returns:
    Whatever `search` returns for the "docs" category.
    """
    options = dict(category="docs", limit=limit, reranking=reranking, run_local=run_local)
    return search(query, "wikivoyage_documents", **options)
def search_wikivoyage_listings(query: str, cities: list, limit: int = 10, reranking: int = 0,
run_local: Optional[bool] = False) -> list | None:
"""
Function to search listings in the Wikivoyage database, post-filtered by cities.
"""
cities_filter = f"city IN {tuple(cities)}"
return search(query=query, table_name="wikivoyage_listings", filter_condition=cities_filter,
category="listings", limit=limit, reranking=reranking, run_local=run_local)
|