fdurant's picture
Add handler.py, start_emulator.sh and test scripts
4c96de6
raw
history blame
No virus
3.02 kB
from typing import Any, Dict, List
from colbert.infra import ColBERTConfig
from colbert.modeling.checkpoint import Checkpoint
import torch
import logging
# ColBERT checkpoint this handler serves; handed to ColBERTConfig(checkpoint=...).
MODEL: str = "fdurant/colbert-xm-for-inference-api"

# Logger named after this module.
logger = logging.getLogger(name=__name__)
class EndpointHandler:
    """Custom inference handler that returns ColBERT embeddings.

    Appears to follow the Hugging Face ``handler.py`` / ``EndpointHandler``
    convention. A single input text is encoded as a *query*; two or more texts
    are encoded as a batch of *document chunks*. The checkpoint named by
    ``MODEL`` is loaded once, at construction time.
    """

    def __init__(self, path=""):
        """Build the ColBERT configuration and load the checkpoint.

        Args:
            path: Accepted for interface compatibility but not used; the
                checkpoint loaded is always ``MODEL``.
        """
        self._config = ColBERTConfig(
            # Defaults copied from https://github.com/datastax/ragstack-ai/blob/main/libs/colbert/ragstack_colbert/colbert_embedding_model.py
            doc_maxlen=512,  # Maximum number of tokens for document chunks. Should equal the chunk_size.
            nbits=2,  # The number bits that each dimension encodes to.
            kmeans_niters=4,  # Number of iterations for k-means clustering during quantization.
            nranks=-1,  # Number of ranks (processors) to use for distributed computing; -1 uses all available CPUs/GPUs.
            checkpoint=MODEL,
        )
        self._checkpoint = Checkpoint(self._config.checkpoint, colbert_config=self._config, verbose=3)

    def __call__(self, data: Any) -> List[Dict[str, Any]]:
        """Embed the text(s) found under ``data["inputs"]``.

        Args:
            data: Mapping whose ``"inputs"`` entry is a string or a list of
                strings.

        Returns:
            For exactly one text (a bare string or a one-element list):
            ``[{"input": inputs, "query_embedding": ...}]``, where ``input``
            echoes ``inputs`` exactly as received. For two or more texts: one
            ``{"input", "chunk_embedding", "token_count"}`` dict per text, in
            input order.

        Raises:
            KeyError: if ``data`` has no ``"inputs"`` key.
            ValueError: if ``inputs`` is neither a string nor a list of
                strings, or if it is an empty list.
        """
        inputs = data["inputs"]
        texts = self._to_text_list(inputs)
        if not texts:
            raise ValueError("No data to process")

        with torch.inference_mode():
            if len(texts) == 1:
                # A lone text is treated as a query, never as a one-chunk batch.
                return [{"input": inputs, "query_embedding": self._embed_query(texts)}]
            return self._embed_chunks(texts)

    @staticmethod
    def _to_text_list(inputs: Any) -> List[str]:
        """Return ``inputs`` as a list of strings, or raise ValueError if it is neither str nor list[str]."""
        if isinstance(inputs, str):
            return [inputs]
        if isinstance(inputs, list) and all(isinstance(text, str) for text in inputs):
            # Returned as-is (same list object), matching the previous behavior.
            return inputs
        raise ValueError("Invalid input data format")

    def _embed_query(self, texts: List[str]) -> List[Any]:
        """Encode a one-element list of texts as a ColBERT query.

        Returns the first row of the encoded batch, converted with ``tolist()``.
        """
        logger.info("Query: %s", texts)
        embedding = self._checkpoint.queryFromText(
            queries=texts,
            full_length_search=False,  # Indicates whether to encode the query for a full-length search.
        )
        logger.info("Query embedding shape: %s", embedding.shape)
        # Convert only row 0 instead of converting the whole batch and discarding the rest.
        return embedding[0].tolist()

    def _embed_chunks(self, texts: List[str]) -> List[Dict[str, Any]]:
        """Encode several texts as ColBERT document chunks, one result dict per text."""
        logger.info("Batch of chunks: %s", texts)
        embeddings, tokens = self._checkpoint.docFromText(
            docs=texts,
            bsize=self._config.bsize,  # Batch size
            keep_dims=True,  # Do NOT flatten the embeddings
            return_tokens=True,  # Return the tokens as well
        )
        results = []
        for text, embedding, text_tokens in zip(texts, embeddings, tokens):
            logger.info("Chunk: %s", text)
            logger.info("Chunk embedding shape: %s", embedding.shape)
            logger.info("Chunk count: %s", text_tokens)
            results.append(
                {
                    "input": text,
                    "chunk_embedding": embedding.tolist(),
                    # NOTE(review): per return_tokens=True this appears to hold the
                    # returned tokens themselves (via tolist()), not a scalar count;
                    # the key name is kept unchanged for client compatibility.
                    "token_count": text_tokens.tolist(),
                }
            )
        return results