File size: 1,920 Bytes
8f175aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import torch
import tqdm

def maxsim(pids, centroid_scores, codes, doclens, offsets, idx, nfiltered_docs):
    """Rank passages by a centroid-level MaxSim score and keep the best ones.

    For each passage, every distinct centroid code it contains (and that
    ``idx`` enables) contributes its row of ``centroid_scores``. The passage
    score is the sum, over query vectors, of the element-wise maximum of
    those rows.

    Args:
        pids: Sequence (list or 1-D integer tensor) of passage ids to score.
        centroid_scores: 2-D tensor of shape ``(ncentroids, nquery_vectors)``.
        codes: Flat sequence of centroid codes for all passages; passage
            ``pid`` owns positions ``offsets[pid]`` up to
            ``offsets[pid] + doclens[pid]`` (exclusive).
        doclens: Number of codes per passage, indexed by pid.
        offsets: Start position of each passage in ``codes``, indexed by pid.
        idx: Indexable by centroid code; only codes whose entry is truthy
            take part in the score.
        nfiltered_docs: Maximum number of passage ids to return.

    Returns:
        A 1-D ``torch.int32`` tensor holding at most ``nfiltered_docs`` pids,
        ordered by descending score; ties keep their input order.
    """
    ncentroids, nquery_vectors = centroid_scores.shape
    scores = []

    for i in tqdm.tqdm(range(len(pids)), desc='Calculating maxsim over centroids...'):
        pid = pids[i]
        start = offsets[pid]
        seen_codes = set()
        # -9999 is the "no centroid matched" floor: a passage without any
        # enabled code ends up with a score of -9999 * nquery_vectors.
        per_doc_scores = torch.full((nquery_vectors,), -9999, dtype=torch.float32)

        for j in range(doclens[pid]):
            # Convert to a plain int: tensors hash by identity, so keeping the
            # 0-dim tensor would make the ``seen_codes`` lookup always miss.
            code = int(codes[start + j])
            assert code < ncentroids
            if not idx[code] or code in seen_codes:
                continue
            seen_codes.add(code)
            # One element-wise max over the whole row replaces a Python loop
            # over query vectors; the accumulator stays float32 on its device.
            centroid_row = centroid_scores[code].to(per_doc_scores)
            per_doc_scores = torch.maximum(per_doc_scores, centroid_row)

        scores.append((per_doc_scores.sum().item(), pid))

    # Stable sort on the score only, so equal scores keep their input order.
    global_scores = sorted(scores, key=lambda x: x[0], reverse=True)
    filtered_pids = [pid for _, pid in global_scores[:nfiltered_docs]]
    return torch.tensor(filtered_pids, dtype=torch.int32)


def filter_pids(pids, centroid_scores, codes, doclens, offsets, idx, nfiltered_docs):
    """Narrow ``pids`` with two successive ``maxsim`` passes.

    The first pass scores only the centroids enabled by ``idx`` and keeps up
    to ``nfiltered_docs`` pids. The second pass rescores those survivors with
    every centroid enabled and keeps up to ``int(nfiltered_docs / 4)`` of
    them. The pid-tensor shapes before and after each pass are printed.

    Returns:
        The pid tensor produced by the second ``maxsim`` pass.
    """
    stage2_pids = maxsim(
        pids, centroid_scores, codes, doclens, offsets, idx, nfiltered_docs
    )
    print('Stage 2 filtering:', pids.shape, '->', stage2_pids.shape)

    # Second pass: a quarter of the first-pass budget, with no centroid masked.
    stage3_budget = int(nfiltered_docs / 4)
    every_centroid = [True] * centroid_scores.size(0)
    stage3_pids = maxsim(
        stage2_pids, centroid_scores, codes, doclens, offsets, every_centroid, stage3_budget
    )
    print('Stage 3 filtering:', stage2_pids.shape, '->', stage3_pids.shape)

    return stage3_pids