from utils.models import get_bm25_model, preprocess_text
import numpy as np

# BM25 Filtering and Retrieval


def filter_data_docs(data, ticker, quarter, year):
    """Restrict the transcript DataFrame to a single ticker, quarter, and year."""
    year_int = int(year)
    data_subset = data[
        (data["Year"] == year_int)
        & (data["Quarter"] == quarter)
        & (data["Ticker"] == ticker)
    ]
    return data_subset
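
# Example usage (the DataFrame name and column values are hypothetical):
#   subset = filter_data_docs(transcripts_df, ticker="AAPL", quarter="Q1", year="2020")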


def get_bm25_search_hits(corpus, sparse_scores, top_n=50):
    """Return the indices of the top_n highest-scoring documents.

    sparse_scores holds document indices already sorted by descending
    BM25 score. The corpus argument is unused but kept so existing
    callers keep working.
    """
    return [int(idx) for idx in sparse_scores[:top_n]]
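
# Worked example: with sparse_scores = np.array([4, 1, 3, 0, 2]) (indices
# sorted by descending score), get_bm25_search_hits(corpus, sparse_scores,
# top_n=3) returns [4, 1, 3].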


# BM25 filtering
def get_indices_bm25(
    data, query, ticker=None, quarter=None, year=None, num_candidates=50
):
    """Rank documents against the query with BM25 and return candidate indices.

    If ticker, quarter, and year are all given, only that subset of the
    data is indexed; otherwise the full corpus is used.
    """
    if ticker is None or quarter is None or year is None:
        corpus, bm25 = get_bm25_model(data)
    else:
        filtered_data = filter_data_docs(data, ticker, quarter, year)
        corpus, bm25 = get_bm25_model(filtered_data)
    tokenized_query = preprocess_text(query).split()
    # Document indices sorted by descending BM25 score
    sparse_scores = np.argsort(bm25.get_scores(tokenized_query))[::-1]
    indices_hits = get_bm25_search_hits(corpus, sparse_scores, num_candidates)
    return indices_hits
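
# Example usage (hypothetical query and DataFrame):
#   candidates = get_indices_bm25(
#       transcripts_df, "revenue guidance",
#       ticker="AAPL", quarter="Q1", year="2020", num_candidates=50,
#   )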


def query_pinecone(
    dense_vec,
    top_k,
    index,
    year=None,
    quarter=None,
    ticker=None,
    keywords=None,
    indices=None,
    threshold=0.25,
):
    """Query the Pinecone index with a dense vector and metadata filters.

    Only answer passages are retrieved; the optional arguments narrow the
    search to a year, quarter, ticker, keyword set, or BM25 candidate
    indices. Matches scoring below the similarity threshold are dropped.
    """
    filter_dict = {
        "QA_Flag": {"$eq": "Answer"},
    }
    if year is not None:
        filter_dict["Year"] = {"$eq": int(year)}
    if quarter is not None:
        filter_dict["Quarter"] = {"$eq": quarter}
    if ticker is not None:
        filter_dict["Ticker"] = {"$eq": ticker}
    if keywords is not None:
        filter_dict["Keywords"] = {"$in": keywords}
    if indices is not None:
        filter_dict["index"] = {"$in": indices}

    xc = index.query(
        vector=dense_vec,
        top_k=top_k,
        filter=filter_dict,
        include_metadata=True,
    )

    # Keep only the context passages that clear the score threshold
    xc["matches"] = [
        match for match in xc["matches"] if match["score"] >= threshold
    ]
    return xc
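
# Example usage (assumes a pinecone.Index handle and a dense query
# embedding produced elsewhere; the names are hypothetical):
#   results = query_pinecone(
#       dense_vec=query_embedding, top_k=5, index=pinecone_index,
#       year="2020", quarter="Q1", ticker="AAPL", indices=candidates,
#   )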


def sentence_id_combine(data, query_results, lag=1):
    """Expand each matched sentence into a window of neighbors and join them.

    For every match, the lag sentences before and after it are pulled in,
    so each retrieved answer comes with its surrounding context.
    """
    # Extract sentence IDs from the query results (Pinecone returns
    # numeric metadata as floats, so cast back to int)
    ids = [
        int(result["metadata"]["Sentence_id"])
        for result in query_results["matches"]
    ]
    # Generate new IDs by adding a lag window around each original ID
    new_ids = [sid + i for sid in ids for i in range(-lag, lag + 1)]
    # Remove duplicates and sort the new IDs
    new_ids = sorted(set(new_ids))
    # Group the new IDs into lookup windows of lag * 2 + 1 sentences
    lookup_ids = [
        new_ids[i : i + (lag * 2 + 1)]
        for i in range(0, len(new_ids), lag * 2 + 1)
    ]
    # Join the sentences in each lookup window, then join the windows
    # into a single context string
    context_list = [
        " ".join(
            data.loc[data["Sentence_id"].isin(lookup_id), "Text"].to_list()
        )
        for lookup_id in lookup_ids
    ]
    context = " ".join(context_list).strip()
    return context
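
# Worked example: matches with Sentence_id 10 and 42 and lag=1 expand to
# new_ids = [9, 10, 11, 41, 42, 43], grouped into windows
# [[9, 10, 11], [41, 42, 43]]; the sentences in each window are joined,
# then the windows are concatenated into one context string.

# Minimal end-to-end sketch. The DataFrame, Pinecone index handle, and
# embed() function below are assumptions for illustration, not part of
# this module:
#   candidates = get_indices_bm25(
#       transcripts_df, "revenue guidance",
#       ticker="AAPL", quarter="Q1", year="2020",
#   )
#   results = query_pinecone(
#       dense_vec=embed("revenue guidance"), top_k=5, index=pinecone_index,
#       year="2020", quarter="Q1", ticker="AAPL", indices=candidates,
#   )
#   context = sentence_id_combine(transcripts_df, results, lag=1)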