""" EVENTUALLY: Tune the batch sizes selected here for a good balance of speed and generality. """ import os import torch import numpy as np from itertools import product from colbert.infra.config import ColBERTConfig from colbert.indexing.codecs.residual_embeddings import ResidualEmbeddings from colbert.utils.utils import print_message import pathlib from torch.utils.cpp_extension import load class ResidualCodec: Embeddings = ResidualEmbeddings def __init__(self, config, centroids, avg_residual=None, bucket_cutoffs=None, bucket_weights=None): self.use_gpu = config.total_visible_gpus > 0 ResidualCodec.try_load_torch_extensions(self.use_gpu) if self.use_gpu > 0: self.centroids = centroids.cuda().half() else: self.centroids = centroids.float() self.dim, self.nbits = config.dim, config.nbits self.avg_residual = avg_residual if torch.is_tensor(self.avg_residual): if self.use_gpu: self.avg_residual = self.avg_residual.cuda().half() if torch.is_tensor(bucket_cutoffs): if self.use_gpu: bucket_cutoffs = bucket_cutoffs.cuda() bucket_weights = bucket_weights.half().cuda() self.bucket_cutoffs = bucket_cutoffs self.bucket_weights = bucket_weights if not self.use_gpu and self.bucket_weights is not None: self.bucket_weights = self.bucket_weights.to(torch.float32) self.arange_bits = torch.arange(0, self.nbits, device='cuda' if self.use_gpu else 'cpu', dtype=torch.uint8) self.rank = config.rank # We reverse the residual bits because arange_bits as # currently constructed produces results with the reverse # of the expected endianness self.reversed_bit_map = [] mask = (1 << self.nbits) - 1 for i in range(256): # The reversed byte z = 0 for j in range(8, 0, -self.nbits): # Extract a subsequence of length n bits x = (i >> (j - self.nbits)) & mask # Reverse the endianness of each bit subsequence (e.g. 10 -> 01) y = 0 for k in range(self.nbits - 1, -1, -1): y += ((x >> (self.nbits - k - 1)) & 1) * (2 ** k) # Set the corresponding bits in the output byte z |= y if j > self.nbits: z <<= self.nbits self.reversed_bit_map.append(z) self.reversed_bit_map = torch.tensor(self.reversed_bit_map).to(torch.uint8) # A table of all possible lookup orders into bucket_weights # given n bits per lookup keys_per_byte = 8 // self.nbits if self.bucket_weights is not None: self.decompression_lookup_table = ( torch.tensor( list( product( list(range(len(self.bucket_weights))), repeat=keys_per_byte ) ) ) .to(torch.uint8) ) else: self.decompression_lookup_table = None if self.use_gpu: self.reversed_bit_map = self.reversed_bit_map.cuda() if self.decompression_lookup_table is not None: self.decompression_lookup_table = self.decompression_lookup_table.cuda() @classmethod def try_load_torch_extensions(cls, use_gpu): if hasattr(cls, "loaded_extensions") or not use_gpu: return print_message(f"Loading decompress_residuals_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...") decompress_residuals_cpp = load( name="decompress_residuals_cpp", sources=[ os.path.join( pathlib.Path(__file__).parent.resolve(), "decompress_residuals.cpp" ), os.path.join( pathlib.Path(__file__).parent.resolve(), "decompress_residuals.cu" ), ], verbose=os.getenv("COLBERT_LOAD_TORCH_EXTENSION_VERBOSE", "False") == "True", ) cls.decompress_residuals = decompress_residuals_cpp.decompress_residuals_cpp print_message(f"Loading packbits_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...") packbits_cpp = load( name="packbits_cpp", sources=[ os.path.join( pathlib.Path(__file__).parent.resolve(), "packbits.cpp" ), os.path.join( pathlib.Path(__file__).parent.resolve(), "packbits.cu" ), ], verbose=os.getenv("COLBERT_LOAD_TORCH_EXTENSION_VERBOSE", "False") == "True", ) cls.packbits = packbits_cpp.packbits_cpp cls.loaded_extensions = True @classmethod def load(cls, index_path): config = ColBERTConfig.load_from_index(index_path) centroids_path = os.path.join(index_path, 'centroids.pt') avgresidual_path = os.path.join(index_path, 'avg_residual.pt') buckets_path = os.path.join(index_path, 'buckets.pt') centroids = torch.load(centroids_path, map_location='cpu') avg_residual = torch.load(avgresidual_path, map_location='cpu') bucket_cutoffs, bucket_weights = torch.load(buckets_path, map_location='cpu') if avg_residual.dim() == 0: avg_residual = avg_residual.item() return cls(config=config, centroids=centroids, avg_residual=avg_residual, bucket_cutoffs=bucket_cutoffs, bucket_weights=bucket_weights) def save(self, index_path): assert self.avg_residual is not None assert torch.is_tensor(self.bucket_cutoffs), self.bucket_cutoffs assert torch.is_tensor(self.bucket_weights), self.bucket_weights centroids_path = os.path.join(index_path, 'centroids.pt') avgresidual_path = os.path.join(index_path, 'avg_residual.pt') buckets_path = os.path.join(index_path, 'buckets.pt') torch.save(self.centroids.half(), centroids_path) torch.save((self.bucket_cutoffs, self.bucket_weights), buckets_path) if torch.is_tensor(self.avg_residual): torch.save(self.avg_residual, avgresidual_path) else: torch.save(torch.tensor([self.avg_residual]), avgresidual_path) def compress(self, embs): codes, residuals = [], [] for batch in embs.split(1 << 18): if self.use_gpu: batch = batch.cuda().half() codes_ = self.compress_into_codes(batch, out_device=batch.device) centroids_ = self.lookup_centroids(codes_, out_device=batch.device) residuals_ = (batch - centroids_) codes.append(codes_.cpu()) residuals.append(self.binarize(residuals_).cpu()) codes = torch.cat(codes) residuals = torch.cat(residuals) return ResidualCodec.Embeddings(codes, residuals) def binarize(self, residuals): residuals = torch.bucketize(residuals.float(), self.bucket_cutoffs).to(dtype=torch.uint8) residuals = residuals.unsqueeze(-1).expand(*residuals.size(), self.nbits) # add a new nbits-wide dim residuals = residuals >> self.arange_bits # divide by 2^bit for each bit position residuals = residuals & 1 # apply mod 2 to binarize assert self.dim % 8 == 0 assert self.dim % (self.nbits * 8) == 0, (self.dim, self.nbits) if self.use_gpu: residuals_packed = ResidualCodec.packbits(residuals.contiguous().flatten()) else: residuals_packed = np.packbits(np.asarray(residuals.contiguous().flatten())) residuals_packed = torch.as_tensor(residuals_packed, dtype=torch.uint8) residuals_packed = residuals_packed.reshape(residuals.size(0), self.dim // 8 * self.nbits) return residuals_packed def compress_into_codes(self, embs, out_device): """ EVENTUALLY: Fusing the kernels or otherwise avoiding materalizing the entire matrix before max(dim=0) seems like it would help here a lot. """ codes = [] bsize = (1 << 29) // self.centroids.size(0) for batch in embs.split(bsize): if self.use_gpu: indices = (self.centroids @ batch.T.cuda().half()).max(dim=0).indices.to(device=out_device) else: indices = (self.centroids @ batch.T.cpu().float()).max(dim=0).indices.to(device=out_device) codes.append(indices) return torch.cat(codes) def lookup_centroids(self, codes, out_device): """ Handles multi-dimensional codes too. EVENTUALLY: The .split() below should happen on a flat view. """ centroids = [] for batch in codes.split(1 << 20): if self.use_gpu: centroids.append(self.centroids[batch.cuda().long()].to(device=out_device)) else: centroids.append(self.centroids[batch.long()].to(device=out_device)) return torch.cat(centroids) #@profile def decompress(self, compressed_embs: Embeddings): """ We batch below even if the target device is CUDA to avoid large temporary buffers causing OOM. """ codes, residuals = compressed_embs.codes, compressed_embs.residuals D = [] for codes_, residuals_ in zip(codes.split(1 << 15), residuals.split(1 << 15)): if self.use_gpu: codes_, residuals_ = codes_.cuda(), residuals_.cuda() centroids_ = ResidualCodec.decompress_residuals( residuals_, self.bucket_weights, self.reversed_bit_map, self.decompression_lookup_table, codes_, self.centroids, self.dim, self.nbits, ).cuda() else: # TODO: Remove dead code centroids_ = self.lookup_centroids(codes_, out_device='cpu') residuals_ = self.reversed_bit_map[residuals_.long()] residuals_ = self.decompression_lookup_table[residuals_.long()] residuals_ = residuals_.reshape(residuals_.shape[0], -1) residuals_ = self.bucket_weights[residuals_.long()] centroids_.add_(residuals_) if self.use_gpu: D_ = torch.nn.functional.normalize(centroids_, p=2, dim=-1).half() else: D_ = torch.nn.functional.normalize(centroids_.to(torch.float32), p=2, dim=-1) D.append(D_) return torch.cat(D)