File size: 3,119 Bytes
c657ec0
16abd01
 
c657ec0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fc19b61
c657ec0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16abd01
c657ec0
 
 
16abd01
 
d990b6f
16abd01
42a39da
 
c657ec0
16abd01
 
2e01c8a
d990b6f
16abd01
42a39da
d990b6f
c657ec0
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import logging
import gradio as gr
import numpy as np

# Module-level logger used by the filter/search helpers below.
# NOTE(review): 'filter methods' is a non-standard logger name —
# logging.getLogger(__name__) is the usual convention; confirm before changing,
# since downstream logging config may reference this name.
log = logging.getLogger('filter methods')
# Import-time side effect: configures the root logger to INFO for the process.
logging.basicConfig(level=logging.INFO)


def filter_docs_by_meta(docs, filter_dict):
    """
    Filter documents by multiple metadata parameters.

    A document is kept only when, for every (key, value) pair in
    ``filter_dict``, its metadata contains ``key`` with a value equal to
    ``value``. Documents missing a filter key are excluded rather than
    raising ``KeyError``. An empty ``filter_dict`` keeps every document.

    Parameters:
        docs : List[langchain.schema.Document]
        filter_dict :  Dict[str, Any]

    Returns: List of filtered documents

    Examples:
        docs = [langchain.schema.Document(metadata={'a': 1, 'b': 2}, text='text1')
                langchain.schema.Document(metadata={'a': 1, 'b': 3}, text='text2')]
        filter_dict = {'a': 1}
        filter_docs_by_meta(docs, filter_dict)
        [langchain.schema.Document(metadata={'a': 1, 'b': 2}, text='text1'),
         langchain.schema.Document(metadata={'a': 1, 'b': 3}, text='text2')]

        docs = [langchain.schema.Document(metadata={'a': 1, 'b': 2}, text='text1')
                langchain.schema.Document(metadata={'a': 1, 'b': 3}, text='text2')]
        filter_dict = {'a': 1, 'b': 2}
        filter_docs_by_meta(docs, filter_dict)
        [langchain.schema.Document(metadata={'a': 1, 'b': 2}, text='text1')]

    """
    # Direct equality (instead of the previous int(value) coercion) honours
    # the documented Dict[str, Any] contract: non-numeric filter values no
    # longer raise ValueError. .get() makes a missing metadata key a
    # non-match instead of a KeyError.
    return [
        doc
        for doc in docs
        if all(doc.metadata.get(key) == value for key, value in filter_dict.items())
    ]


def search_with_filter(vector_store, query, filter_dict, target_k=5, init_k=100, step=50, max_k=50000):
    """
    Expand search with filter until reaching at least a pre-determined number of documents.
    ----------
    Parameters
        vector_store : langchain.vectorstores.FAISS
            The FAISS vector store.
        query : str
            The query to search for.
        filter_dict :  Dict[str, Any]
            The parameters to filter for
        target_k : int
            The minimum number of documents desired after filtering
        init_k : int
            The top-k documents to extract for the initial search.
        step : int
            The size of the step when enlarging the search.
        max_k : int
            Upper bound on the search size (exclusive). Should exceed the
            number of documents actually stored; previously a hard-coded
            constant (50000).

    Returns: List of at least target_k Documents for post-processing
             (possibly fewer if max_k is exhausted).

    """
    context = filter_docs_by_meta(vector_store.similarity_search(query, k=init_k), filter_dict)
    len_docs_begin = len(context)
    if len_docs_begin >= target_k:
        log.info(f'Initial search contains {len_docs_begin} documents. Expansion not required. ')
        return context
    # Start one step beyond init_k: the k=init_k search was already performed
    # above, so beginning the expansion at init_k would redo identical work.
    # A plain range() suffices for an integer progression (no numpy / int()
    # cast needed).
    for top_k_docs in range(init_k + step, max_k, step):
        log.info(f'Context contains {len(context)} documents')
        log.info(f'Expanding search with k={top_k_docs}')
        context = filter_docs_by_meta(vector_store.similarity_search(query, k=top_k_docs), filter_dict)
        if len(context) >= target_k:
            log.info(f'Success. Context contains {len(context)} documents matching the filtering criteria')
            return context
    log.info(f'Failed to reach target number of documents,'
             f' context contains {len(context)} documents matching the filtering criteria')
    return context