import logging
import os
import sys
from typing import Optional

import lancedb
from lancedb.rerankers import ColbertReranker

from src.vectordb.helpers import set_uri

logger = logging.getLogger(__name__)


def _log_failure(action: str, error: Exception) -> None:
    """Log *error* with the exception type, file name and line where it surfaced.

    Must be called from inside an ``except`` block: the location is read from
    ``sys.exc_info()``, which is only populated while an exception is handled.
    """
    exc_type, _, exc_tb = sys.exc_info()
    fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
    logger.error(f"Error while {action}: {error}, {(exc_type, fname, exc_tb.tb_lineno)}")


def _sql_string_list(values) -> str:
    """Render *values* as a parenthesised list of SQL string literals.

    Each value is wrapped in single quotes and embedded single quotes are
    doubled, so ``["Xi'an"]`` becomes ``('Xi''an')``. Relying on Python's
    tuple ``repr`` instead breaks for one element (trailing comma) and for
    names containing an apostrophe (repr switches to double quotes).
    """
    literals = ("'" + str(value).replace("'", "''") + "'" for value in values)
    return "(" + ", ".join(literals) + ")"


def search(query: str, table_name: str, filter_condition: Optional[str] = None,
           category: str = "docs", limit: int = 10, reranking: int = 0,
           run_local: Optional[bool] = False) -> list | None:
    """Search a database table, with an optional filter and optional reranking.

    Args:
        query: Search query.
        table_name: Name of the table to search.
        filter_condition: Optional SQL-like condition passed to ``where()``.
        category: ``"docs"`` returns city/country/section/text rows; any other
            value returns city/country/type/title/description rows. When
            reranking, ``"listings"`` reranks on ``description`` and every
            other value reranks on ``text``.
        limit: Maximum number of results.
        reranking: Any truthy value enables ``ColbertReranker``.
        run_local: Forwarded to ``set_uri`` to choose the database URI.

    Returns:
        A list of dicts shaped according to ``category``, or ``None`` when
        connecting to the database or executing the query fails (the error is
        logged). Errors raised while opening the table or building the query
        are not caught here and propagate to the caller.
    """
    uri = set_uri(run_local)
    try:
        db = lancedb.connect(uri)
    except Exception as e:
        logger.error(f"Error while connecting to DB: {e}")
        return None
    logger.info(f"Connected to {table_name} DB.")

    table = db.open_table(table_name)
    search_query = table.search(query).metric('cosine')
    if filter_condition:
        search_query = search_query.where(filter_condition)

    # Both paths share one try block; only the optional rerank step and the
    # wording of the logged error differ.
    action = "reranking results" if reranking else "searching"
    try:
        if reranking:
            column = 'description' if category == 'listings' else 'text'
            search_query = search_query.rerank(reranker=ColbertReranker(column=column))
        results = search_query.limit(limit).to_list()
    except Exception as e:
        _log_failure(action, e)
        return None

    logger.info("Found the most relevant documents.")
    if category == "docs":
        fields = ("city", "country", "section", "text")
    else:
        fields = ("city", "country", "type", "title", "description")
    return [{field: r[field] for field in fields} for r in results]


def search_wikivoyage_docs(query: str, limit: int = 10, reranking: int = 0,
                           run_local: Optional[bool] = False) -> list | None:
    """Search the ``wikivoyage_documents`` table.

    Returns a list of city/country/section/text dicts, or ``None`` on a
    logged connection or query failure.
    """
    return search(query=query, table_name="wikivoyage_documents", category="docs",
                  limit=limit, reranking=reranking, run_local=run_local)


def search_wikivoyage_listings(query: str, cities: list[str], limit: int = 10,
                               reranking: int = 0,
                               run_local: Optional[bool] = False) -> list | None:
    """Search the ``wikivoyage_listings`` table, restricted to *cities*.

    The restriction is passed to ``search`` as a ``city IN (...)`` filter.
    NOTE(review): whether LanceDB applies it before or after the vector
    search depends on the ``where()`` default of the installed version.

    Returns a list of city/country/type/title/description dicts, an empty
    list when *cities* is empty (no city can match, so the database is not
    queried), or ``None`` on a logged connection or query failure.
    """
    if not cities:
        # "city IN ()" is not valid SQL; an empty selection matches nothing.
        return []
    cities_filter = f"city IN {_sql_string_list(cities)}"
    return search(query=query, table_name="wikivoyage_listings",
                  filter_condition=cities_filter, category="listings",
                  limit=limit, reranking=reranking, run_local=run_local)