File size: 3,753 Bytes
89cd5d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
# from src import * 

import logging
import os
import lancedb
from lancedb.rerankers import ColbertReranker

import sys

logger = logging.getLogger(__name__)
from typing import Optional
from src.vectordb.helpers import set_uri


# db = lancedb.connect("/tmp/db")


def search(query: str, table_name: str, filter_condition: Optional[str] = None,
           category: str = "docs", limit: int = 10, reranking: int = 0,
           run_local: Optional[bool] = False) -> list | None:
    """
    Search a LanceDB table by cosine similarity, with optional filtering and reranking.

    Every failure while connecting, opening the table, or running the query is
    logged (with traceback) and reported to the caller as ``None``.

    Args:
        - query: str, search query.
        - table_name: str, name of the table to search.
        - filter_condition: Optional[str], SQL-like condition passed to ``where``;
          skipped when empty or None.
        - category: str, ``"docs"`` selects the document fields; any other value
          is treated as listings (default is 'docs').
        - limit: int, number of results (default is 10).
        - reranking: int (0 or 1), if truthy, ColbertReranker is applied on the
          'text' column for docs and on 'description' otherwise.
        - run_local: Optional[bool], forwarded to ``set_uri`` to pick the database URI.

    Returns:
        A list of dicts with keys (city, country, section, text) for docs, or
        (city, country, type, title, description) otherwise; ``None`` on error.
    """
    uri = set_uri(run_local)

    try:
        db = lancedb.connect(uri)
    except Exception as e:
        logger.exception("Error while connecting to DB: %s", e)
        return None

    # Opening the table can fail too (e.g. unknown table name); report it the
    # same way as every other failure instead of letting it propagate.
    try:
        table = db.open_table(table_name)
    except Exception as e:
        logger.exception("Error while opening table %s: %s", table_name, e)
        return None

    logger.info(f"Connected to {table_name} DB.")

    # Single source of truth for the docs/listings split, so the reranker
    # column and the returned fields can never disagree.
    is_docs = category == "docs"
    stage = "reranking results" if reranking else "searching"

    try:
        search_query = table.search(query).metric('cosine')
        if filter_condition:
            search_query = search_query.where(filter_condition)
        if reranking:
            reranker = ColbertReranker(column='text' if is_docs else 'description')
            search_query = search_query.rerank(reranker=reranker)
        results = search_query.limit(limit).to_list()
    except Exception as e:
        # logger.exception records the full traceback, replacing the manual
        # sys.exc_info() file/line extraction.
        logger.exception("Error while %s: %s", stage, e)
        return None

    logger.info("Found the most relevant documents.")

    if is_docs:
        fields = ("city", "country", "section", "text")
    else:
        fields = ("city", "country", "type", "title", "description")
    return [{name: row[name] for name in fields} for row in results]


def search_wikivoyage_docs(query: str, limit: int = 10, reranking: int = 0,
                           run_local: Optional[bool] = False) -> list | None:
    """
    Search the "wikivoyage_documents" table for documents matching a query.

    Args:
        - query: str, search query.
        - limit: int, number of results (default is 10).
        - reranking: int (0 or 1), forwarded to ``search`` to toggle reranking.
        - run_local: Optional[bool], forwarded to ``search``.

    Returns:
        Whatever ``search`` returns for the "docs" category.
    """
    search_options = {
        "query": query,
        "table_name": "wikivoyage_documents",
        "category": "docs",
        "limit": limit,
        "reranking": reranking,
        "run_local": run_local,
    }
    return search(**search_options)


def search_wikivoyage_listings(query: str, cities: list, limit: int = 10, reranking: int = 0,
                               run_local: Optional[bool] = False) -> list | None:
    """
    Function to search listings in the Wikivoyage database, post-filtered by cities.
    """
    cities_filter = f"city IN {tuple(cities)}"
    return search(query=query, table_name="wikivoyage_listings", filter_condition=cities_filter,
                  category="listings", limit=limit, reranking=reranking, run_local=run_local)