import logging
from typing import Any, Dict, List

import torch
from colbert.infra import ColBERTConfig
from colbert.modeling.checkpoint import Checkpoint

logger = logging.getLogger(__name__)
# Hardcoded, I know
MODEL = "fdurant/colbert-xm-for-inference-api"

class EndpointHandler():
    def __init__(self, path=""):
        self._config = ColBERTConfig(
            # Defaults copied from https://github.com/datastax/ragstack-ai/blob/main/libs/colbert/ragstack_colbert/colbert_embedding_model.py
            doc_maxlen=512,  # Maximum number of tokens per document chunk. Should equal the chunk_size.
            nbits=2,  # The number of bits each dimension is encoded to.
            kmeans_niters=4,  # Number of iterations for k-means clustering during quantization.
            nranks=-1,  # Number of ranks (processors) to use for distributed computing; -1 uses all available CPUs/GPUs.
            checkpoint=MODEL,  # Path to the model checkpoint.
        )
        self._checkpoint = Checkpoint(self._config.checkpoint, colbert_config=self._config, verbose=3)

    def __call__(self, data: Any) -> List[Dict[str, Any]]:
        """
        data args:
            inputs (:obj:`str` or :obj:`list` of :obj:`str`)
        Return:
            A :obj:`list` that will be serialized and returned.
            When the input is a single query string, the returned list contains a single dictionary with:
            - input (:obj:`str`) : The input query.
            - query_embedding (:obj:`list`) : The query embedding of shape (32, 128), i.e. (query_maxlen, dim).
            When the input is a batch (a list) of chunk strings, the returned list contains one dictionary per chunk:
            - input (:obj:`str`) : The input chunk.
            - chunk_embedding (:obj:`list`) : The chunk embedding of shape (num_tokens, 128).
            - token_ids (:obj:`list`) : The token ids.
            - token_list (:obj:`list`) : The tokens corresponding to the token ids.
        """
        inputs = data["inputs"]
        texts = []
        if isinstance(inputs, str):
            texts = [inputs]
        elif isinstance(inputs, list) and all(isinstance(text, str) for text in inputs):
            texts = inputs
        else:
            raise ValueError("Invalid input data format: expected a string or a list of strings")
        with torch.inference_mode():
            if len(texts) == 1:
                # It's a query
                logger.info(f"Received query of 1 text with {len(texts[0])} characters and {len(texts[0].split())} words")
                embedding = self._checkpoint.queryFromText(
                    queries=texts,
                    full_length_search=False,  # Indicates whether to encode the query for a full-length search.
                )
                logger.info(f"Query embedding shape: {embedding.shape}")
                return [
                    # texts[0] keeps "input" a string even when a single-element list was sent
                    {"input": texts[0], "query_embedding": embedding.tolist()[0]}
                ]
            elif len(texts) > 1:
                # It's a batch of chunks
                logger.info(f"Received batch of {len(texts)} chunks")
                for i, text in enumerate(texts):
                    logger.info(f"Chunk {i} has {len(text)} characters and {len(text.split())} words")
                embeddings, token_id_lists = self._checkpoint.docFromText(
                    docs=texts,
                    bsize=self._config.bsize,  # Batch size
                    keep_dims=True,  # Do NOT flatten the embeddings
                    return_tokens=True,  # Return the tokens as well
                )
                logger.info(f"Chunk embeddings shape: {embeddings.shape}")
                token_lists = []
                for text, embedding, token_ids in zip(texts, embeddings, token_id_lists):
                    logger.debug(f"Chunk: {text}")
                    logger.debug(f"Chunk embedding shape: {embedding.shape}")
                    logger.debug(f"Chunk token ids: {token_ids}")
                    token_list = self._checkpoint.doc_tokenizer.tok.convert_ids_to_tokens(token_ids)
                    token_lists.append(token_list)
                    logger.debug(f"Chunk tokens: {token_list}")
                    # reconstructed_text = self._checkpoint.doc_tokenizer.tok.decode(token_ids)
                    # logger.debug(f"Reconstructed text with special tokens: {reconstructed_text}")
                return [
                    {"input": _input, "chunk_embedding": embedding.tolist(), "token_ids": token_ids.tolist(), "token_list": token_list}
                    for _input, embedding, token_ids, token_list in zip(texts, embeddings, token_id_lists, token_lists)
                ]
            else:
                raise ValueError("No data to process")
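

# A minimal local smoke-test sketch, not part of the deployed handler. It assumes the
# colbert-ai and torch packages are installed and that the MODEL checkpoint can be
# downloaded from the Hugging Face Hub; the example strings are arbitrary placeholders.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    handler = EndpointHandler()

    # A single string is treated as a query: one dict with "query_embedding" is returned.
    query_result = handler({"inputs": "What does ColBERT-XM do?"})
    print(len(query_result), list(query_result[0].keys()))

    # A list of strings is treated as a batch of chunks: one dict per chunk is returned.
    chunk_results = handler({"inputs": [
        "ColBERT is a late-interaction retrieval model.",
        "Each token gets its own embedding vector.",
    ]})
    for result in chunk_results:
        print(result["input"], len(result["token_ids"]))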