import torch
import tqdm


def maxsim(pids, centroid_scores, codes, doclens, offsets, idx, nfiltered_docs):
    ncentroids, nquery_vectors = centroid_scores.shape
    centroid_scores = centroid_scores.flatten()
    scores = []

    for i in tqdm.tqdm(range(len(pids)), desc='Calculating maxsim over centroids...'):
        seen_codes = set()
        per_doc_scores = torch.full((nquery_vectors,), -9999, dtype=torch.float32)

        pid = pids[i]
        for j in range(doclens[pid]):
            # Cast to a plain int so set membership and list indexing behave
            # correctly even when `codes` is a tensor
            code = int(codes[offsets[pid] + j])
            assert code < ncentroids
            if idx[code] and code not in seen_codes:
                for k in range(nquery_vectors):
                    per_doc_scores[k] = torch.max(
                        per_doc_scores[k],
                        centroid_scores[code * nquery_vectors + k]
                    )
                seen_codes.add(code)

        score = torch.sum(per_doc_scores).item()
        scores.append((score, pid))

    # Sort by score and keep the top nfiltered_docs pids
    global_scores = sorted(scores, key=lambda x: x[0], reverse=True)
    filtered_pids = [pid for _, pid in global_scores[:nfiltered_docs]]
    filtered_pids = torch.tensor(filtered_pids, dtype=torch.int32)

    return filtered_pids


def filter_pids(pids, centroid_scores, codes, doclens, offsets, idx, nfiltered_docs):
    filtered_pids = maxsim(
        pids, centroid_scores, codes, doclens, offsets, idx, nfiltered_docs
    )
    print('Stage 2 filtering:', pids.shape, '->', filtered_pids.shape)  # (all_docs) -> (n_docs/4)

    nfinal_filtered_docs = nfiltered_docs // 4
    ones = [True] * centroid_scores.size(0)

    final_filtered_pids = maxsim(
        filtered_pids, centroid_scores, codes, doclens, offsets, ones, nfinal_filtered_docs
    )
    print('Stage 3 filtering:', filtered_pids.shape, '->', final_filtered_pids.shape)  # (n_docs) -> (n_docs/4)

    return final_filtered_pids


def decompress_residuals(pids, doclens, offsets, bucket_weights, reversed_bit_map,
                         bucket_weight_combinations, binary_residuals, codes,
                         centroids, dim, nbits):
    npacked_vals_per_byte = 8 // nbits
    packed_dim = dim // npacked_vals_per_byte

    cumulative_lengths = [0 for _ in range(len(pids) + 1)]
    noutputs = 0
    for i in range(len(pids)):
        dlen = int(doclens[pids[i]])
        noutputs += dlen
        cumulative_lengths[i + 1] = cumulative_lengths[i] + dlen

    # Flat buffer holding the (noutputs x dim) decompressed token embeddings
    output = torch.zeros(noutputs * dim, dtype=torch.float32)

    binary_residuals = binary_residuals.flatten()
    centroids = centroids.flatten()

    # Iterate over all documents
    for i in range(len(pids)):
        pid = pids[i]

        # Offset into packed list of token vectors for the given document
        offset = offsets[pid]

        # For each document, iterate over all token vectors
        for j in range(doclens[pid]):
            code = codes[offset + j]

            # For each token vector, iterate over the packed (8-bit) residual values
            for k in range(packed_dim):
                x = binary_residuals[(offset + j) * packed_dim + k]
                # Cast to a Python int so the index arithmetic below does not
                # overflow an 8-bit dtype
                x = int(reversed_bit_map[x])

                # For each packed residual value, iterate over the bucket weight indices.
                # If we use n-bit compression, that means there will be (8 / n) indices per packed value.
                for l in range(npacked_vals_per_byte):
                    output_dim_idx = k * npacked_vals_per_byte + l
                    bucket_weight_idx = bucket_weight_combinations[x * npacked_vals_per_byte + l]
                    output[(cumulative_lengths[i] + j) * dim + output_dim_idx] = \
                        bucket_weights[bucket_weight_idx] + centroids[code * dim + output_dim_idx]

    return output
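

# A minimal, self-contained usage sketch for filter_pids. The toy shapes and
# dtypes below are assumptions inferred from how maxsim/filter_pids index their
# arguments, not taken from any real index: doclens holds per-document token
# counts, offsets their exclusive cumulative sums (start of each document in
# `codes`), codes one centroid id per token, and idx a per-centroid boolean
# pruning mask. All values are synthetic.
if __name__ == '__main__':
    ncentroids, nquery_vectors, ndocs = 8, 4, 16

    doclens = torch.randint(3, 6, (ndocs,))                        # tokens per document
    offsets = torch.cat([torch.zeros(1, dtype=torch.long),
                         doclens.cumsum(0)[:-1]])                  # start of each doc in `codes`
    codes = torch.randint(0, ncentroids, (int(doclens.sum()),))    # centroid id per token
    centroid_scores = torch.rand(ncentroids, nquery_vectors)       # centroid/query-vector similarities
    idx = (centroid_scores.max(dim=1).values > 0.5).tolist()       # centroid pruning mask
    pids = torch.arange(ndocs, dtype=torch.int32)

    top_pids = filter_pids(pids, centroid_scores, codes, doclens, offsets, idx,
                           nfiltered_docs=8)
    print('Top pids:', top_pids)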