File size: 4,751 Bytes
4c96de6
 
 
 
 
 
 
 
 
68b896e
4c96de6
 
 
 
 
 
 
 
 
 
 
68b896e
4c96de6
 
 
 
68b896e
 
 
 
 
 
 
 
 
 
 
 
 
 
4c96de6
 
 
 
 
 
 
 
 
 
 
 
fe7831c
4c96de6
 
 
 
fe7831c
4c96de6
 
 
 
 
fe7831c
 
 
68b896e
4c96de6
 
 
 
 
fe7831c
68b896e
 
 
 
 
 
 
 
 
 
4c96de6
68b896e
 
4c96de6
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import logging
from typing import Any, Dict, List, Tuple

import torch
from colbert.infra import ColBERTConfig
from colbert.modeling.checkpoint import Checkpoint

# Per-module logger; this module configures no handlers or levels itself.
logger = logging.getLogger(name=__name__)

# Checkpoint repo id used by the handler. Deliberately hard-coded rather than taken from the handler's `path` argument.
MODEL = "fdurant/colbert-xm-for-inference-api"

class EndpointHandler:
    """Inference handler that turns text into ColBERT token-level embeddings.

    Requests are dispatched on the *type* of ``inputs``: a single string is
    treated as a search query and encoded with the query encoder; a list of
    strings (of any non-zero length, including one) is treated as a batch of
    document chunks and encoded with the document encoder.
    """

    def __init__(self, path=""):
        """Build the ColBERT config and load the checkpoint.

        Args:
            path: Accepted for interface compatibility but currently ignored;
                the checkpoint is always loaded from ``MODEL``.
        """
        self._config = ColBERTConfig(
            # Defaults copied from https://github.com/datastax/ragstack-ai/blob/main/libs/colbert/ragstack_colbert/colbert_embedding_model.py
            doc_maxlen=512,  # Maximum number of tokens for document chunks. Should equal the chunk_size.
            nbits=2,  # The number of bits that each dimension encodes to.
            kmeans_niters=4,  # Number of iterations for k-means clustering during quantization.
            nranks=-1,  # Number of ranks (processors) to use for distributed computing; -1 uses all available CPUs/GPUs.
            checkpoint=MODEL,  # Path to the model checkpoint.
        )
        self._checkpoint = Checkpoint(self._config.checkpoint, colbert_config=self._config, verbose=3)

    def __call__(self, data: Any) -> List[Dict[str, Any]]:
        """Embed a query string or a batch of chunk strings.

        Args:
            data: Request payload; must contain the key ``"inputs"`` holding
                either a single query string or a list of chunk strings.

        Returns:
            A list that will be serialized and returned.

            For a query string, a one-element list with a dict containing:
                - ``input``: the query string.
                - ``query_embedding``: nested list of the first (only) row of
                  the tensor returned by ``queryFromText``. The full tensor was
                  noted as (1, 32, 128), which makes this a 32 x 128 list.

            For a list of chunks (a single-element list included), one dict per
            chunk, paired positionally with ``docFromText``'s outputs:
                - ``input``: the chunk string.
                - ``chunk_embedding``: nested list of that chunk's row of the
                  batched embedding tensor (presumably (num_tokens, dim), with
                  num_tokens padded across the batch; TODO confirm).
                - ``token_ids``: the chunk's token ids as a list.
                - ``token_list``: the tokenizer tokens for those ids.

        Raises:
            KeyError: if ``data`` has no ``"inputs"`` key.
            ValueError: if ``inputs`` is neither a string nor a list of
                strings, or is an empty list.
        """
        texts, is_query = self._normalize_inputs(data["inputs"])
        with torch.inference_mode():
            if is_query:
                return self._embed_query(texts[0])
            return self._embed_chunks(texts)

    @staticmethod
    def _normalize_inputs(inputs: Any) -> Tuple[List[str], bool]:
        """Validate ``inputs`` and return ``(texts, is_query)``.

        Dispatch is on type, not length: only a bare string is a query, so a
        one-element list is still handled as a batch of chunks.
        """
        if isinstance(inputs, str):
            return [inputs], True
        if not (isinstance(inputs, list) and all(isinstance(text, str) for text in inputs)):
            raise ValueError("Invalid input data format")
        if not inputs:
            raise ValueError("No data to process")
        return inputs, False

    def _embed_query(self, query: str) -> List[Dict[str, Any]]:
        """Encode one query string with the query encoder."""
        logger.info(
            "Received query of 1 text with %d characters and %d words",
            len(query),
            len(query.split()),
        )
        embedding = self._checkpoint.queryFromText(
            queries=[query],
            full_length_search=False,  # Indicates whether to encode the query for a full-length search.
        )
        logger.info("Query embedding shape: %s", embedding.shape)
        # Index before converting so only the single query's row is turned into Python lists.
        return [{"input": query, "query_embedding": embedding[0].tolist()}]

    def _embed_chunks(self, chunks: List[str]) -> List[Dict[str, Any]]:
        """Encode a batch of chunks with the document encoder; one result dict per chunk."""
        logger.info("Received batch of %d chunks", len(chunks))
        for i, chunk in enumerate(chunks):
            logger.info("Chunk %d has %d characters and %d words", i, len(chunk), len(chunk.split()))
        embeddings, token_id_lists = self._checkpoint.docFromText(
            docs=chunks,
            bsize=self._config.bsize,  # Batch size
            keep_dims=True,  # Do NOT flatten the embeddings
            return_tokens=True,  # Return the tokens as well
        )
        logger.info("Chunk embeddings shape: %s", embeddings.shape)
        tokenizer = self._checkpoint.doc_tokenizer.tok
        results = []
        for chunk, embedding, token_ids in zip(chunks, embeddings, token_id_lists):
            logger.debug("Chunk: %s", chunk)
            logger.debug("Chunk embedding shape: %s", embedding.shape)
            logger.debug("Chunk token ids: %s", token_ids)
            token_list = tokenizer.convert_ids_to_tokens(token_ids)
            logger.debug("Chunk tokens: %s", token_list)
            results.append(
                {
                    "input": chunk,
                    "chunk_embedding": embedding.tolist(),
                    "token_ids": token_ids.tolist(),
                    "token_list": token_list,
                }
            )
        return results