# ir_chinese_medqa/colbert/indexing/codecs/residual_embeddings.py
# Commit 58627fa ("add_app_files") by 欧卫; 3.19 kB.
import os
import torch
import ujson
import tqdm
from colbert.indexing.codecs.residual_embeddings_strided import ResidualEmbeddingsStrided
from colbert.utils.utils import print_message
class ResidualEmbeddings:
    """Container for per-embedding codes and their already-compressed residuals.

    Row ``i`` of ``codes`` (int32) and row ``i`` of ``residuals`` (uint8)
    together describe embedding ``i``.
    """

    # Alias so callers can reach the strided variant as ResidualEmbeddings.Strided.
    Strided = ResidualEmbeddingsStrided

    def __init__(self, codes, residuals):
        """Wrap already-compressed codes and residuals.

        Args:
            codes: 1-D tensor with one entry per embedding; stored as int32.
            residuals: 2-D uint8 tensor with the same number of rows as ``codes``.

        Raises:
            AssertionError: if the row counts differ, the ranks are not
                (1, 2), or ``residuals`` is not uint8.
        """
        assert codes.size(0) == residuals.size(0), (codes.size(), residuals.size())
        assert codes.dim() == 1 and residuals.dim() == 2, (codes.size(), residuals.size())
        assert residuals.dtype == torch.uint8

        self.codes = codes.to(torch.int32)  # (num_embeddings,) int32
        self.residuals = residuals  # (num_embeddings, compressed_dim) uint8

    @classmethod
    def load_chunks(cls, index_path, chunk_idxs, num_embeddings):
        """Load several on-disk chunks and concatenate them into one instance.

        Args:
            index_path: directory holding ``metadata.json`` plus the
                ``{chunk_idx}.codes.pt`` / ``{chunk_idx}.residuals.pt`` files.
            chunk_idxs: iterable of chunk indices; chunks are copied in this order.
            num_embeddings: total number of embeddings across those chunks.

        Returns:
            An instance whose buffers have ``num_embeddings + 512`` rows. Rows
            after the last loaded chunk come from ``torch.empty`` and are
            left uninitialized.
        """
        # Pad for access with strides: extra rows beyond the real embeddings.
        num_embeddings += 512

        dim, nbits = get_dim_and_nbits(index_path)
        # Bytes per residual row. get_dim_and_nbits guarantees dim * nbits is a
        # multiple of 8, so this is exact. The former dim // 8 * nbits only
        # matches it when dim itself is a multiple of 8.
        compressed_dim = dim * nbits // 8

        codes = torch.empty(num_embeddings, dtype=torch.int32)
        residuals = torch.empty(num_embeddings, compressed_dim, dtype=torch.uint8)

        codes_offset = 0

        print_message("#> Loading codes and residuals...")

        for chunk_idx in tqdm.tqdm(chunk_idxs):
            chunk = cls.load(index_path, chunk_idx)

            codes_endpos = codes_offset + chunk.codes.size(0)

            # Copy this chunk into its slice of the preallocated buffers.
            codes[codes_offset:codes_endpos] = chunk.codes
            residuals[codes_offset:codes_endpos] = chunk.residuals

            codes_offset = codes_endpos

        return cls(codes, residuals)

    @classmethod
    def load(cls, index_path, chunk_idx):
        """Load one chunk's codes and residuals from ``index_path``."""
        codes = cls.load_codes(index_path, chunk_idx)
        residuals = cls.load_residuals(index_path, chunk_idx)
        return cls(codes, residuals)

    @classmethod
    def load_codes(cls, index_path, chunk_idx):
        """Return the object saved at ``{index_path}/{chunk_idx}.codes.pt``, mapped to CPU."""
        codes_path = os.path.join(index_path, f'{chunk_idx}.codes.pt')
        # torch.load unpickles its input; only point it at index files you trust.
        return torch.load(codes_path, map_location='cpu')

    @classmethod
    def load_residuals(cls, index_path, chunk_idx):
        """Return the object saved at ``{index_path}/{chunk_idx}.residuals.pt``, mapped to CPU."""
        residuals_path = os.path.join(index_path, f'{chunk_idx}.residuals.pt')
        # torch.load unpickles its input; only point it at index files you trust.
        return torch.load(residuals_path, map_location='cpu')

    def save(self, path_prefix):
        """Write ``codes`` and ``residuals`` to ``{path_prefix}.codes.pt`` / ``.residuals.pt``."""
        codes_path = f'{path_prefix}.codes.pt'
        residuals_path = f'{path_prefix}.residuals.pt'

        torch.save(self.codes, codes_path)
        torch.save(self.residuals, residuals_path)

    def __len__(self):
        """Number of rows in ``codes``, including any padding rows added by load_chunks."""
        return self.codes.size(0)
def get_dim_and_nbits(index_path):
    """Read ``dim`` and ``nbits`` from the index's ``metadata.json``.

    Args:
        index_path: directory containing ``metadata.json``. Its top-level
            ``'config'`` entry must provide ``'dim'`` and ``'nbits'``.

    Returns:
        The tuple ``(dim, nbits)``.

    Raises:
        AssertionError: if ``dim * nbits`` is not a multiple of 8.
    """
    # TODO: Ideally load this using ColBERTConfig.load_from_index!
    metadata_path = os.path.join(index_path, 'metadata.json')
    with open(metadata_path) as metadata_file:
        config = ujson.load(metadata_file)['config']

    dim, nbits = config['dim'], config['nbits']

    total_bits = dim * nbits
    assert total_bits % 8 == 0, (dim, nbits, total_bits)

    return dim, nbits