import torch
import tqdm


def maxsim(pids, centroid_scores, codes, doclens, offsets, idx, nfiltered_docs):
    ncentroids, nquery_vectors = centroid_scores.shape
    centroid_scores = centroid_scores.flatten()
    scores = []

    for i in tqdm.tqdm(range(len(pids)), desc='Calculating maxsim over centroids...'):
        seen_codes = set()
        per_doc_scores = torch.full((nquery_vectors,), -9999, dtype=torch.float32)

        pid = pids[i]
        for j in range(doclens[pid]):
            # Plain int so the assert, the `idx` lookup, and set dedup behave the same
            # whether `codes` is a Python list or a tensor.
            code = int(codes[offsets[pid] + j])
            assert code < ncentroids
            if idx[code] and code not in seen_codes:
                # Elementwise max over query vectors: each query vector keeps the best
                # centroid score seen so far for this document.
                for k in range(nquery_vectors):
                    per_doc_scores[k] = torch.max(
                        per_doc_scores[k],
                        centroid_scores[code * nquery_vectors + k]
                    )
                seen_codes.add(code)

        # Sum of per-query-vector maxima is the document's approximate MaxSim score.
        score = torch.sum(per_doc_scores[:nquery_vectors]).item()
        scores.append((score, pid))

    # Keep the top-nfiltered_docs passage ids by approximate score.
    global_scores = sorted(scores, key=lambda x: x[0], reverse=True)
    filtered_pids = [pid for _, pid in global_scores[:nfiltered_docs]]
    filtered_pids = torch.tensor(filtered_pids, dtype=torch.int32)

    return filtered_pids


def filter_pids(pids, centroid_scores, codes, doclens, offsets, idx, nfiltered_docs):
    # Stage 2: score candidates using only the centroids selected by `idx`.
    filtered_pids = maxsim(
        pids, centroid_scores, codes, doclens, offsets, idx, nfiltered_docs
    )

    print('Stage 2 filtering:', pids.shape, '->', filtered_pids.shape)

    # Stage 3: rescore the survivors using all centroids, keeping a quarter of them.
    nfinal_filtered_docs = int(nfiltered_docs / 4)
    ones = [True] * centroid_scores.size(0)

    final_filtered_pids = maxsim(
        filtered_pids, centroid_scores, codes, doclens, offsets, ones, nfinal_filtered_docs
    )

    print('Stage 3 filtering:', filtered_pids.shape, '->', final_filtered_pids.shape)

    return final_filtered_pids
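

# Illustrative usage sketch, not part of the original code: the shapes and toy values
# below are assumptions inferred from how maxsim/filter_pids index their arguments, and
# `_demo_filter_pids` is a hypothetical helper added only to show the expected inputs.
def _demo_filter_pids():
    ncentroids, nquery_vectors = 16, 4
    centroid_scores = torch.rand(ncentroids, nquery_vectors)   # centroid-vs-query-vector scores
    doclens = [3, 2, 4, 3, 2, 3, 4, 3]                         # tokens per passage
    offsets = [0, 3, 5, 9, 12, 14, 17, 21]                     # start of each passage in `codes`
    codes = torch.randint(0, ncentroids, (sum(doclens),))      # centroid id assigned to each token
    idx = (centroid_scores.max(dim=1).values >= 0.5).tolist()  # stage-2 centroid pruning mask
    pids = torch.arange(len(doclens))
    return filter_pids(pids, centroid_scores, codes, doclens, offsets, idx, nfiltered_docs=4)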


def decompress_residuals(pids, doclens, offsets, bucket_weights, reversed_bit_map,
                         bucket_weight_combinations, binary_residuals, codes,
                         centroids, dim, nbits):
    npacked_vals_per_byte = 8 // nbits
    packed_dim = dim // npacked_vals_per_byte
    cumulative_lengths = [0 for _ in range(len(pids) + 1)]
    noutputs = 0
    for i in range(len(pids)):
        noutputs += doclens[pids[i]]
        cumulative_lengths[i + 1] = cumulative_lengths[i] + doclens[pids[i]]

    # Flat buffer of decompressed embeddings: one dim-sized slice per output token.
    # (The original allocated an empty list here, which cannot be indexed below.)
    output = torch.zeros(int(noutputs) * dim, dtype=torch.float32)

    binary_residuals = binary_residuals.flatten()
    centroids = centroids.flatten()
    # The indexing below assumes a flat (256 * npacked_vals_per_byte) combination table;
    # flattening is a no-op if it is already 1-D.
    bucket_weight_combinations = bucket_weight_combinations.flatten()

    for i in range(len(pids)):
        pid = pids[i]
        offset = offsets[pid]

        for j in range(doclens[pid]):
            code = codes[offset + j]

            for k in range(packed_dim):
                # Each byte packs npacked_vals_per_byte bucket indices of nbits each.
                x = int(binary_residuals[(offset + j) * packed_dim + k])
                x = int(reversed_bit_map[x])  # plain int avoids uint8 overflow in the index below

                for l in range(npacked_vals_per_byte):
                    output_dim_idx = k * npacked_vals_per_byte + l
                    bucket_weight_idx = bucket_weight_combinations[x * npacked_vals_per_byte + l]
                    # Decompressed value = quantized residual weight + centroid component.
                    output[(cumulative_lengths[i] + j) * dim + output_dim_idx] = \
                        bucket_weights[bucket_weight_idx] + centroids[code * dim + output_dim_idx]

    return output
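

# Illustrative usage sketch, not part of the original code: toy inputs assuming nbits=2,
# dim=4, an identity `reversed_bit_map`, and a big-endian 2-bit field layout per byte.
# The real tables come from the residual codec, so this only demonstrates the expected
# shapes and indexing, not the true encoding.
def _demo_decompress_residuals():
    nbits, dim = 2, 4
    npacked = 8 // nbits
    bucket_weights = torch.tensor([-0.1, -0.03, 0.03, 0.1])  # 2**nbits quantization buckets
    reversed_bit_map = torch.arange(256)                      # identity stand-in for the bit-reversal table
    shifts = torch.tensor([6, 4, 2, 0])                       # assumed bit positions of the packed fields
    bucket_weight_combinations = (torch.arange(256).unsqueeze(1) >> shifts) & 0b11  # (256, npacked) bucket ids
    centroids = torch.randn(2, dim)
    doclens = [2, 1, 3]
    offsets = [0, 2, 3]
    codes = [0, 1, 1, 0, 1, 0]                                # centroid id per token
    binary_residuals = torch.randint(0, 256, (len(codes), dim // npacked), dtype=torch.uint8)
    out = decompress_residuals([0, 2], doclens, offsets, bucket_weights, reversed_bit_map,
                               bucket_weight_combinations, binary_residuals, codes,
                               centroids, dim, nbits)
    return out.view(-1, dim)                                  # one decompressed embedding per token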