from typing import Any, Dict, List

from colbert.infra import ColBERTConfig
from colbert.modeling.checkpoint import Checkpoint
import torch
import logging

logger = logging.getLogger(__name__)

# Hardcoded, I know
MODEL = "fdurant/colbert-xm-for-inference-api"


class EndpointHandler():

    def __init__(self, path=""):
        self._config = ColBERTConfig(
            # Defaults copied from https://github.com/datastax/ragstack-ai/blob/main/libs/colbert/ragstack_colbert/colbert_embedding_model.py
            doc_maxlen=512,  # Maximum number of tokens for document chunks. Should equal the chunk_size.
            nbits=2,  # The number of bits that each dimension encodes to.
            kmeans_niters=4,  # Number of iterations for k-means clustering during quantization.
            nranks=-1,  # Number of ranks (processors) to use for distributed computing; -1 uses all available CPUs/GPUs.
            checkpoint=MODEL,  # Path to the model checkpoint.
        )
        self._checkpoint = Checkpoint(self._config.checkpoint, colbert_config=self._config, verbose=3)

    def __call__(self, data: Any) -> List[Dict[str, Any]]:
        """
        data args:
            inputs (:obj: `str`)
        Return:
            A :obj:`list` : will be serialized and returned.
            When the input is a single query string, the returned list contains a single dictionary with:
            - input (:obj: `str`) : The input query.
            - query_embedding (:obj: `list`) : The query embedding of shape (32, 128).
            When the input is a batch (= list) of chunk strings, the returned list contains a dictionary per chunk:
            - input (:obj: `str`) : The input chunk.
            - chunk_embedding (:obj: `list`) : The chunk embedding of shape (num_tokens, 128).
            - token_ids (:obj: `list`) : The token ids.
            - token_list (:obj: `list`) : The token list.
        """
        inputs = data["inputs"]
        texts = []
        if isinstance(inputs, str):
            texts = [inputs]
        elif isinstance(inputs, list) and all(isinstance(text, str) for text in inputs):
            texts = inputs
        else:
            raise ValueError("Invalid input data format")
        with torch.inference_mode():
            if len(texts) == 1:
                # It's a query (note: a list containing a single string also lands here)
                logger.info(f"Received query of 1 text with {len(texts[0])} characters and {len(texts[0].split())} words")
                embedding = self._checkpoint.queryFromText(
                    queries=texts,
                    full_length_search=False,  # Indicates whether to encode the query for a full-length search.
                )
                logger.info(f"Query embedding shape: {embedding.shape}")
                return [
                    {"input": inputs, "query_embedding": embedding.tolist()[0]}
                ]
            elif len(texts) > 1:
                # It's a batch of chunks
                logger.info(f"Received batch of {len(texts)} chunks")
                for i, text in enumerate(texts):
                    logger.info(f"Chunk {i} has {len(text)} characters and {len(text.split())} words")
                embeddings, token_id_lists = self._checkpoint.docFromText(
                    docs=texts,
                    bsize=self._config.bsize,  # Batch size
                    keep_dims=True,  # Do NOT flatten the embeddings
                    return_tokens=True,  # Return the tokens as well
                )
                logger.info(f"Chunk embeddings shape: {embeddings.shape}")
                token_lists = []
                for text, embedding, token_ids in zip(texts, embeddings, token_id_lists):
                    logger.debug(f"Chunk: {text}")
                    logger.debug(f"Chunk embedding shape: {embedding.shape}")
                    logger.debug(f"Chunk token ids: {token_ids}")
                    token_list = self._checkpoint.doc_tokenizer.tok.convert_ids_to_tokens(token_ids)
                    token_lists.append(token_list)
                    logger.debug(f"Chunk tokens: {token_list}")
                    # reconstructed_text = self._checkpoint.doc_tokenizer.tok.decode(token_ids)
                    # logger.debug(f"Reconstructed text with special tokens: {reconstructed_text}")
                return [
                    {"input": _input, "chunk_embedding": embedding.tolist(), "token_ids": token_ids.tolist(), "token_list": token_list}
                    for _input, embedding, token_ids, token_list in zip(texts, embeddings, token_id_lists, token_lists)
                ]
            else:
                raise ValueError("No data to process")
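

# --- Minimal local smoke test ----------------------------------------------
# A sketch for exercising the handler outside the Hugging Face Inference
# Endpoints runtime. It assumes the MODEL checkpoint above can be downloaded
# from the Hub and that enough memory is available; the query and chunk
# strings are made-up examples for illustration only.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    handler = EndpointHandler()

    # A single string is routed down the query path.
    query_result = handler({"inputs": "What is late interaction in ColBERT?"})
    print(len(query_result[0]["query_embedding"]))  # query_maxlen token vectors (32 by default)

    # A list of two or more strings is routed down the chunk path.
    chunk_results = handler({"inputs": ["ColBERT is a late-interaction retrieval model.",
                                        "ColBERT-XM adds multilingual support."]})
    for result in chunk_results:
        print(result["input"], "->", len(result["token_ids"]), "tokens")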