File size: 4,390 Bytes
12f7a48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4976a8c
12f7a48
 
 
 
 
 
4976a8c
12f7a48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import json
from pathlib import Path
from typing import Any, Dict, List, Optional

from huggingface_hub import hf_hub_download
from huggingface_hub.errors import EntryNotFoundError
from loguru import logger
from vllm import (
    AsyncLLMEngine, AsyncEngineArgs,
    PoolingParams, EmbeddingRequestOutput,
)

from hfendpoints import EndpointConfig, Handler, __version__
from hfendpoints.http import Context, run
from hfendpoints.tasks import Usage
from hfendpoints.tasks.embedding import EmbeddingRequest, EmbeddingResponse


def get_sentence_transformers_config(config: EndpointConfig) -> Optional[Dict[str, Any]]:
    """Load the model's ``config_sentence_transformers.json``, if present.

    Looks for the file inside the local repository first (skipped in debug
    mode), then falls back to downloading it from the Hugging Face Hub.

    Args:
        config: Endpoint configuration providing ``repository``, ``model_id``
            and the ``is_debug`` flag.

    Returns:
        The parsed JSON configuration as a dict, or ``None`` when the model
        publishes no Sentence Transformers config on the Hub.
    """
    st_config_path = None
    if not config.is_debug:
        st_config_path = Path(config.repository) / "config_sentence_transformers.json"

    if not st_config_path or not st_config_path.exists():
        try:
            st_config_path = hf_hub_download(config.model_id, filename="config_sentence_transformers.json")
        except EntryNotFoundError:
            # Best-effort: a missing config simply means no named prompts.
            logger.info(f"Sentence Transformers config not found on {config.model_id}")
            return None

    with open(st_config_path, "r", encoding="utf-8") as config_f:
        return json.load(config_f)


class VllmEmbeddingHandler(Handler):
    """Embedding request handler backed by a vLLM ``AsyncLLMEngine``.

    Optionally prepends a Sentence Transformers prompt — looked up by name in
    the model's ``config_sentence_transformers.json`` — to the input text
    before embedding.
    """

    # fix: was `__slot__` (inert typo) and listed `_sentence_transformer_config`,
    # which does not match the attribute assigned in __init__.
    __slots__ = ("_engine", "_sentence_transformers_config")

    def __init__(self, config: EndpointConfig):
        self._sentence_transformers_config = get_sentence_transformers_config(config)
        self._engine = AsyncLLMEngine.from_engine_args(
            AsyncEngineArgs(
                str(config.repository),
                task="embed",
                device="auto",
                dtype="bfloat16",
                kv_cache_dtype="auto",
                enforce_eager=False,
                enable_prefix_caching=True,
                disable_log_requests=True,
            )
        )

    async def embeds(
            self,
            prompts: str,
            pooling: PoolingParams,
            request_id: str
    ) -> List[EmbeddingRequestOutput]:
        """Run the engine on ``prompts`` and collect every streamed output.

        Args:
            prompts: Text to embed.
            pooling: Pooling parameters (e.g. output dimensionality) or ``None``.
            request_id: Unique id the engine uses to track this request.

        Returns:
            All ``EmbeddingRequestOutput`` items produced by the engine.
        """
        outputs = []
        async for item in self._engine.encode(
                prompts,
                pooling_params=pooling,
                request_id=request_id,
                lora_request=None,
        ):
            outputs.append(EmbeddingRequestOutput.from_base(item))

        return outputs

    async def __call__(self, request: EmbeddingRequest, ctx: Context) -> EmbeddingResponse:
        """Embed a single document or a batch of documents.

        Honours two optional request parameters:
        - ``"dimension"``: requested embedding dimensionality;
        - ``"prompt_name"``: name of a Sentence Transformers prompt to prepend.
        """
        if "dimension" in request.parameters:
            pooling_params = PoolingParams(dimensions=request.parameters["dimension"])
        else:
            pooling_params = None

        if "prompt_name" in request.parameters and self._sentence_transformers_config:
            prompt_name = request.parameters["prompt_name"]
            tokenizer = await self._engine.get_tokenizer()
            prompt = self._sentence_transformers_config.get("prompts", {}).get(prompt_name, None)
            num_prompt_tokens = len(tokenizer.tokenize(prompt)) if prompt else 0
        else:
            prompt = None
            num_prompt_tokens = 0

        if request.is_batched:
            embeddings = []
            num_tokens = 0
            # NOTE(review): num_prompt_tokens is reported once even though the
            # prompt is prepended to every document — confirm intended usage
            # accounting with the endpoint's billing semantics.
            for idx, document in enumerate(request.inputs):
                text = f"{prompt}{document}" if prompt else document

                output = await self.embeds(text, pooling_params, f"{ctx.request_id}-{idx}")
                num_tokens += len(output[0].prompt_token_ids)
                embeddings += [output[0].outputs.embedding]
        else:
            # fix: original inserted a stray space between prompt and input,
            # unlike the batched path; the prompt is concatenated verbatim.
            text = f"{prompt}{request.inputs}" if prompt else request.inputs

            output = await self.embeds(text, pooling_params, ctx.request_id)
            num_tokens = len(output[0].prompt_token_ids)
            embeddings = output[0].outputs.embedding

        return EmbeddingResponse(embeddings, prompt_tokens=num_prompt_tokens, num_tokens=num_tokens)


def entrypoint():
    """Boot the embedding endpoint: read config, build the handler, serve."""
    # Endpoint configuration is supplied through environment variables.
    config = EndpointConfig.from_env()

    logger.info(f"[Hugging Face Endpoint v{__version__}] Serving: {config.model_id}")

    handler = VllmEmbeddingHandler(config)

    # Imported lazily to keep module import light until serving starts.
    from hfendpoints.openai.embedding import EmbeddingEndpoint
    run(EmbeddingEndpoint(handler), config.interface, config.port)


# Allow running this module directly as a script.
if __name__ == "__main__":
    entrypoint()