Spaces:
Runtime error
Runtime error
""" | |
EVENTUALLY: Tune the batch sizes selected here for a good balance of speed and generality. | |
""" | |
import os | |
import torch | |
import numpy as np | |
from itertools import product | |
from colbert.infra.config import ColBERTConfig | |
from colbert.indexing.codecs.residual_embeddings import ResidualEmbeddings | |
from colbert.utils.utils import print_message | |
import pathlib | |
from torch.utils.cpp_extension import load | |
class ResidualCodec:
    """Compress embeddings as (nearest-centroid code, bucketized residual) pairs.

    ``compress_into_codes`` assigns each embedding the index of the centroid with
    the largest dot product. ``binarize`` bucketizes the embedding-minus-centroid
    residual against ``bucket_cutoffs`` and packs ``nbits`` bits per dimension
    into uint8 bytes. ``decompress`` reverses this and L2-normalizes the result.
    On CPU it adds ``bucket_weights[bucket]`` to the looked-up centroid via
    lookup tables. On GPU it delegates to the ``decompress_residuals`` extension.

    When ``config.total_visible_gpus > 0``:

    * centroids and bucket weights are kept on CUDA in half precision;
    * packing and unpacking use JIT-compiled C++/CUDA extensions.

    Otherwise everything stays on CPU in float32 and packing uses
    ``numpy.packbits``.
    """

    # Container type returned by compress() and consumed by decompress().
    Embeddings = ResidualEmbeddings

    def __init__(self, config, centroids, avg_residual=None, bucket_cutoffs=None, bucket_weights=None):
        """Build a codec from trained centroids and residual buckets.

        Args:
            config: ColBERT config. Reads ``total_visible_gpus``, ``dim``,
                ``nbits`` and ``rank``.
            centroids: Centroid tensor, presumably ``(num_centroids, dim)``
                (TODO confirm). Converted to CUDA/half on GPU, float32 on CPU.
            avg_residual: Optional tensor or float. Stored as-is, except that a
                tensor is moved to CUDA/half when a GPU is in use.
            bucket_cutoffs: Optional boundaries tensor for ``torch.bucketize``.
            bucket_weights: Optional tensor with one reconstruction value per
                bucket. Its length sizes the decompression lookup table.

        Raises:
            ValueError: if ``config.nbits`` does not evenly divide 8. The
                bit-packing tables assume every byte holds a whole number of
                ``nbits``-wide keys.
        """
        self.use_gpu = config.total_visible_gpus > 0

        ResidualCodec.try_load_torch_extensions(self.use_gpu)

        if self.use_gpu:
            self.centroids = centroids.cuda().half()
        else:
            self.centroids = centroids.float()

        self.dim, self.nbits = config.dim, config.nbits
        if self.nbits <= 0 or 8 % self.nbits != 0:
            raise ValueError(f"nbits must evenly divide 8 (got nbits={self.nbits})")

        self.avg_residual = avg_residual
        if torch.is_tensor(self.avg_residual) and self.use_gpu:
            self.avg_residual = self.avg_residual.cuda().half()

        if torch.is_tensor(bucket_cutoffs) and self.use_gpu:
            bucket_cutoffs = bucket_cutoffs.cuda()
            bucket_weights = bucket_weights.half().cuda()

        self.bucket_cutoffs = bucket_cutoffs
        self.bucket_weights = bucket_weights
        if not self.use_gpu and self.bucket_weights is not None:
            self.bucket_weights = self.bucket_weights.to(torch.float32)

        device = 'cuda' if self.use_gpu else 'cpu'
        self.arange_bits = torch.arange(0, self.nbits, device=device, dtype=torch.uint8)

        self.rank = config.rank

        self.reversed_bit_map = self._build_reversed_bit_map(self.nbits)

        # Row z of this table lists, most-significant group first, the
        # keys_per_byte bucket indices encoded by a (bit-reversed) byte z.
        keys_per_byte = 8 // self.nbits
        if self.bucket_weights is not None:
            self.decompression_lookup_table = torch.tensor(
                list(product(range(len(self.bucket_weights)), repeat=keys_per_byte))
            ).to(torch.uint8)
        else:
            self.decompression_lookup_table = None

        if self.use_gpu:
            self.reversed_bit_map = self.reversed_bit_map.cuda()
            if self.decompression_lookup_table is not None:
                self.decompression_lookup_table = self.decompression_lookup_table.cuda()

    @staticmethod
    def _build_reversed_bit_map(nbits):
        """Return a 256-entry uint8 table that bit-reverses each ``nbits``-wide group of a byte.

        ``binarize`` emits each bucket index least-significant bit first
        (``>> arange_bits``). Packed bytes are decoded most-significant group
        first, so each group comes out bit-reversed relative to the bucket
        index. Indexing this table with the packed byte undoes that reversal.
        """
        mask = (1 << nbits) - 1
        table = []
        for byte in range(256):
            out = 0
            # Visit the groups starting from the most-significant end of the byte.
            for shift in range(8 - nbits, -1, -nbits):
                group = (byte >> shift) & mask
                # Reverse the bits of this group (e.g. 0b10 -> 0b01).
                reversed_group = 0
                for k in range(nbits):
                    reversed_group |= ((group >> k) & 1) << (nbits - 1 - k)
                out = (out << nbits) | reversed_group
            table.append(out)
        return torch.tensor(table).to(torch.uint8)

    @classmethod
    def try_load_torch_extensions(cls, use_gpu):
        """JIT-compile and attach the GPU extensions, at most once per class.

        No-op when ``use_gpu`` is false or ``cls.loaded_extensions`` is already
        set. Otherwise sets ``cls.decompress_residuals``, ``cls.packbits`` and
        ``cls.loaded_extensions``.
        """
        if hasattr(cls, "loaded_extensions") or not use_gpu:
            return

        print_message("Loading decompress_residuals_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...")
        decompress_residuals_cpp = cls._compile_extension("decompress_residuals_cpp", "decompress_residuals")
        cls.decompress_residuals = decompress_residuals_cpp.decompress_residuals_cpp

        print_message("Loading packbits_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...")
        packbits_cpp = cls._compile_extension("packbits_cpp", "packbits")
        cls.packbits = packbits_cpp.packbits_cpp

        cls.loaded_extensions = True

    @staticmethod
    def _compile_extension(name, stem):
        """JIT-build extension ``name`` from ``<stem>.cpp`` and ``<stem>.cu`` located next to this file."""
        source_dir = pathlib.Path(__file__).parent.resolve()
        # `load` here is the module-level torch.utils.cpp_extension.load,
        # not ResidualCodec.load.
        return load(
            name=name,
            sources=[
                os.path.join(source_dir, f"{stem}.cpp"),
                os.path.join(source_dir, f"{stem}.cu"),
            ],
            verbose=os.getenv("COLBERT_LOAD_TORCH_EXTENSION_VERBOSE", "False") == "True",
        )

    @classmethod
    def load(cls, index_path):
        """Rebuild a codec from the files in ``index_path``.

        Reads ``centroids.pt``, ``avg_residual.pt`` and ``buckets.pt``. A 0-d
        ``avg_residual`` tensor is converted back to a Python number.

        NOTE: ``torch.load`` unpickles its input; only load trusted indexes.
        """
        config = ColBERTConfig.load_from_index(index_path)

        centroids_path = os.path.join(index_path, 'centroids.pt')
        avgresidual_path = os.path.join(index_path, 'avg_residual.pt')
        buckets_path = os.path.join(index_path, 'buckets.pt')

        centroids = torch.load(centroids_path, map_location='cpu')
        avg_residual = torch.load(avgresidual_path, map_location='cpu')
        bucket_cutoffs, bucket_weights = torch.load(buckets_path, map_location='cpu')

        if avg_residual.dim() == 0:
            avg_residual = avg_residual.item()

        return cls(
            config=config,
            centroids=centroids,
            avg_residual=avg_residual,
            bucket_cutoffs=bucket_cutoffs,
            bucket_weights=bucket_weights,
        )

    def save(self, index_path):
        """Write the three files that ``load`` reads into ``index_path``.

        Centroids are saved in half precision. A non-tensor ``avg_residual`` is
        saved as a 0-d tensor so that ``load`` converts it back to a scalar.
        """
        assert self.avg_residual is not None
        assert torch.is_tensor(self.bucket_cutoffs), self.bucket_cutoffs
        assert torch.is_tensor(self.bucket_weights), self.bucket_weights

        centroids_path = os.path.join(index_path, 'centroids.pt')
        avgresidual_path = os.path.join(index_path, 'avg_residual.pt')
        buckets_path = os.path.join(index_path, 'buckets.pt')

        torch.save(self.centroids.half(), centroids_path)
        torch.save((self.bucket_cutoffs, self.bucket_weights), buckets_path)

        if torch.is_tensor(self.avg_residual):
            torch.save(self.avg_residual, avgresidual_path)
        else:
            # Use 0-d, not shape (1,): load() only restores a scalar when dim() == 0.
            torch.save(torch.tensor(self.avg_residual), avgresidual_path)

    def compress(self, embs):
        """Compress ``embs`` into centroid codes plus packed residual bits.

        Rows are processed ``1 << 18`` at a time. Per-batch results are moved to
        CPU before concatenation.

        Returns:
            ``ResidualCodec.Embeddings(codes, residuals)``.
        """
        codes, residuals = [], []

        for batch in embs.split(1 << 18):
            if self.use_gpu:
                batch = batch.cuda().half()
            codes_ = self.compress_into_codes(batch, out_device=batch.device)
            centroids_ = self.lookup_centroids(codes_, out_device=batch.device)

            residuals_ = batch - centroids_

            codes.append(codes_.cpu())
            residuals.append(self.binarize(residuals_).cpu())

        return ResidualCodec.Embeddings(torch.cat(codes), torch.cat(residuals))

    def binarize(self, residuals):
        """Bucketize ``residuals`` and pack ``nbits`` bits per value into uint8 bytes.

        Returns:
            uint8 tensor of shape ``(residuals.size(0), dim // 8 * nbits)``.
        """
        assert self.dim % 8 == 0
        assert self.dim % (self.nbits * 8) == 0, (self.dim, self.nbits)

        buckets = torch.bucketize(residuals.float(), self.bucket_cutoffs).to(dtype=torch.uint8)

        # Split each bucket index into its nbits bits, least-significant bit first.
        bits = buckets.unsqueeze(-1).expand(*buckets.size(), self.nbits)
        bits = (bits >> self.arange_bits) & 1
        flat_bits = bits.contiguous().flatten()

        if self.use_gpu:
            packed = ResidualCodec.packbits(flat_bits)
        else:
            packed = np.packbits(np.asarray(flat_bits))
        packed = torch.as_tensor(packed, dtype=torch.uint8)

        return packed.reshape(buckets.size(0), self.dim // 8 * self.nbits)

    def compress_into_codes(self, embs, out_device):
        """Return, for each row of ``embs``, the index of the highest-dot-product centroid.

        Batch size is chosen so each centroids-by-batch score matrix holds about
        ``1 << 29`` entries.

        EVENTUALLY: Fusing the kernels or otherwise avoiding materializing the
        entire matrix before max(dim=0) seems like it would help here a lot.
        """
        codes = []

        # Clamp to at least one row; split(0) would raise for > 2**29 centroids.
        bsize = max(1, (1 << 29) // self.centroids.size(0))
        for batch in embs.split(bsize):
            if self.use_gpu:
                batch_t = batch.T.cuda().half()
            else:
                batch_t = batch.T.cpu().float()
            indices = (self.centroids @ batch_t).max(dim=0).indices.to(device=out_device)
            codes.append(indices)

        return torch.cat(codes)

    def lookup_centroids(self, codes, out_device):
        """Gather ``self.centroids[codes]`` in chunks of ``1 << 20`` and move the result to ``out_device``.

        Handles multi-dimensional codes too.
        EVENTUALLY: The .split() below should happen on a flat view.
        """
        centroids = []

        for batch in codes.split(1 << 20):
            if self.use_gpu:
                batch = batch.cuda()
            centroids.append(self.centroids[batch.long()].to(device=out_device))

        return torch.cat(centroids)

    def decompress(self, compressed_embs: Embeddings):
        """Reconstruct L2-normalized embeddings from codes and packed residuals.

        Works in chunks of ``1 << 15`` rows, even on CUDA, so large temporary
        buffers do not cause OOM.

        Returns:
            Half-precision output on GPU; float32 CPU tensors otherwise.
        """
        codes, residuals = compressed_embs.codes, compressed_embs.residuals

        D = []
        for codes_, residuals_ in zip(codes.split(1 << 15), residuals.split(1 << 15)):
            if self.use_gpu:
                codes_, residuals_ = codes_.cuda(), residuals_.cuda()
                centroids_ = ResidualCodec.decompress_residuals(
                    residuals_,
                    self.bucket_weights,
                    self.reversed_bit_map,
                    self.decompression_lookup_table,
                    codes_,
                    self.centroids,
                    self.dim,
                    self.nbits,
                ).cuda()
                D_ = torch.nn.functional.normalize(centroids_, p=2, dim=-1).half()
            else:
                # CPU path (live whenever no GPU is visible): decode with lookup tables.
                centroids_ = self.lookup_centroids(codes_, out_device='cpu')
                # Packed byte -> byte whose nbits-wide groups equal the bucket indices.
                fixed_bytes = self.reversed_bit_map[residuals_.long()]
                # Byte -> keys_per_byte bucket indices; flatten to one index per value.
                bucket_ids = self.decompression_lookup_table[fixed_bytes.long()]
                bucket_ids = bucket_ids.reshape(bucket_ids.shape[0], -1)
                centroids_.add_(self.bucket_weights[bucket_ids.long()])
                D_ = torch.nn.functional.normalize(centroids_.to(torch.float32), p=2, dim=-1)
            D.append(D_)

        return torch.cat(D)