sofianhw committed on
Commit
8f99309
1 Parent(s): b7ac138

update some code to comply with 0.5.1

Files changed (6)
  1. api_server.py +39 -8
  2. protocol.py +112 -21
  3. serving_chat.py +518 -158
  4. serving_completion.py +378 -397
  5. serving_embedding.py +144 -0
  6. serving_engine.py +13 -4
api_server.py CHANGED
@@ -15,20 +15,27 @@ from fastapi.responses import JSONResponse, Response, StreamingResponse
 from prometheus_client import make_asgi_app
 from starlette.routing import Mount
 
-import vllm
 import vllm.envs as envs
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.entrypoints.openai.cli_args import make_arg_parser
+# yapf conflicts with isort for this block
+# yapf: disable
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               ChatCompletionResponse,
                                               CompletionRequest,
-                                              EmbeddingRequest, ErrorResponse)
+                                              DetokenizeRequest,
+                                              DetokenizeResponse,
+                                              EmbeddingRequest, ErrorResponse,
+                                              TokenizeRequest,
+                                              TokenizeResponse)
+# yapf: enable
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
 from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
 from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
 from vllm.logger import init_logger
 from vllm.usage.usage_lib import UsageContext
+from vllm.version import __version__ as VLLM_VERSION
 
 TIMEOUT_KEEP_ALIVE = 5  # seconds
 
@@ -59,6 +66,7 @@ async def lifespan(app: fastapi.FastAPI):
 
 app = fastapi.FastAPI(lifespan=lifespan)
 
+
 def parse_args():
     parser = make_arg_parser()
     return parser.parse_args()
@@ -84,7 +92,29 @@ async def health() -> Response:
     return Response(status_code=200)
 
 
-@app.get("/api/v1/models")
+@app.post("/tokenize")
+async def tokenize(request: TokenizeRequest):
+    generator = await openai_serving_completion.create_tokenize(request)
+    if isinstance(generator, ErrorResponse):
+        return JSONResponse(content=generator.model_dump(),
+                            status_code=generator.code)
+    else:
+        assert isinstance(generator, TokenizeResponse)
+        return JSONResponse(content=generator.model_dump())
+
+
+@app.post("/detokenize")
+async def detokenize(request: DetokenizeRequest):
+    generator = await openai_serving_completion.create_detokenize(request)
+    if isinstance(generator, ErrorResponse):
+        return JSONResponse(content=generator.model_dump(),
+                            status_code=generator.code)
+    else:
+        assert isinstance(generator, DetokenizeResponse)
+        return JSONResponse(content=generator.model_dump())
+
+
+@app.get("/v1/models")
 async def show_available_models():
     models = await openai_serving_chat.show_available_models()
     return JSONResponse(content=models.model_dump())
@@ -92,11 +122,11 @@ async def show_available_models():
 
 @app.get("/version")
 async def show_version():
-    ver = {"version": vllm.__version__}
+    ver = {"version": VLLM_VERSION}
     return JSONResponse(content=ver)
 
 
-@app.post("/api/v1/chat/completions")
+@app.post("/v1/chat/completions")
 async def create_chat_completion(request: ChatCompletionRequest,
                                  raw_request: Request):
     generator = await openai_serving_chat.create_chat_completion(
@@ -112,7 +142,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
     return JSONResponse(content=generator.model_dump())
 
 
-@app.post("/api/v1/completions")
+@app.post("/v1/completions")
 async def create_completion(request: CompletionRequest, raw_request: Request):
     generator = await openai_serving_completion.create_completion(
         request, raw_request)
@@ -126,7 +156,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
     return JSONResponse(content=generator.model_dump())
 
 
-@app.post("/api/v1/embeddings")
+@app.post("/v1/embeddings")
 async def create_embedding(request: EmbeddingRequest, raw_request: Request):
     generator = await openai_serving_embedding.create_embedding(
         request, raw_request)
@@ -173,7 +203,7 @@ if __name__ == "__main__":
             raise ValueError(f"Invalid middleware {middleware}. "
                              f"Must be a function or a class.")
 
-    logger.info("vLLM API server version %s", vllm.__version__)
+    logger.info("vLLM API server version %s", VLLM_VERSION)
     logger.info("args: %s", args)
 
     if args.served_model_name is not None:
@@ -182,6 +212,7 @@ if __name__ == "__main__":
        served_model_names = [args.model]
 
    engine_args = AsyncEngineArgs.from_cli_args(args)
+
    engine = AsyncLLMEngine.from_engine_args(
        engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
 
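For reference, a minimal client sketch (not part of this commit) exercising the endpoints touched above: the OpenAI-style routes move from /api/v1/* to /v1/*, and /tokenize and /detokenize are new. The host, port, and model name below are placeholders; response fields follow the TokenizeResponse and DetokenizeResponse models added in protocol.py.

import requests

BASE_URL = "http://localhost:8000"  # placeholder host/port
MODEL = "my-model"                  # placeholder served model name

# New /tokenize endpoint: returns token ids, their count, and max_model_len.
tok = requests.post(f"{BASE_URL}/tokenize",
                    json={"model": MODEL, "prompt": "Hello, world!"}).json()

# New /detokenize endpoint: converts token ids back into text.
detok = requests.post(f"{BASE_URL}/detokenize",
                      json={"model": MODEL, "tokens": tok["tokens"]}).json()

# Chat completions now live under /v1/chat/completions (was /api/v1/...).
chat = requests.post(f"{BASE_URL}/v1/chat/completions",
                     json={
                         "model": MODEL,
                         "messages": [{"role": "user", "content": "Hi"}],
                     }).json()

print(tok["count"], detok["prompt"], chat["choices"][0]["message"]["content"])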
protocol.py CHANGED
@@ -102,6 +102,11 @@ class ResponseFormat(OpenAIBaseModel):
102
  type: Literal["text", "json_object"]
103
 
104
 
 
 
 
 
 
105
  class FunctionDefinition(OpenAIBaseModel):
106
  name: str
107
  description: Optional[str] = None
@@ -140,6 +145,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
140
  le=torch.iinfo(torch.long).max)
141
  stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
142
  stream: Optional[bool] = False
 
143
  temperature: Optional[float] = 0.7
144
  top_p: Optional[float] = 1.0
145
  tools: Optional[List[ChatCompletionToolsParam]] = None
@@ -185,6 +191,27 @@ class ChatCompletionRequest(OpenAIBaseModel):
185
  "special tokens so this should be set to False (as is the "
186
  "default)."),
187
)
 
 
 
188
  include_stop_str_in_output: Optional[bool] = Field(
189
  default=False,
190
  description=(
@@ -229,15 +256,22 @@ class ChatCompletionRequest(OpenAIBaseModel):
229
 
230
  logits_processors = None
231
  if self.logit_bias:
 
 
 
 
 
 
 
 
 
 
232
 
233
  def logit_bias_logits_processor(
234
  token_ids: List[int],
235
  logits: torch.Tensor) -> torch.Tensor:
236
- assert self.logit_bias is not None
237
- for token_id, bias in self.logit_bias.items():
238
- # Clamp the bias between -100 and 100 per OpenAI API spec
239
- bias = min(100, max(-100, bias))
240
- logits[int(token_id)] += bias
241
  return logits
242
 
243
  logits_processors = [logit_bias_logits_processor]
@@ -269,6 +303,15 @@ class ChatCompletionRequest(OpenAIBaseModel):
269
  logits_processors=logits_processors,
270
  )
271
 
 
 
 
 
 
 
 
 
 
272
  @model_validator(mode="before")
273
  @classmethod
274
  def check_guided_decoding_count(cls, data):
@@ -308,9 +351,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
308
  raise ValueError(
309
  "when using `top_logprobs`, `logprobs` must be set to true."
310
  )
311
- elif not 0 <= data["top_logprobs"] <= 20:
312
  raise ValueError(
313
- "`top_logprobs` must be a value in the interval [0, 20].")
314
  return data
315
 
316
 
@@ -332,6 +375,7 @@ class CompletionRequest(OpenAIBaseModel):
332
  le=torch.iinfo(torch.long).max)
333
  stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
334
  stream: Optional[bool] = False
 
335
  suffix: Optional[str] = None
336
  temperature: Optional[float] = 1.0
337
  top_p: Optional[float] = 1.0
@@ -404,15 +448,22 @@ class CompletionRequest(OpenAIBaseModel):
404
 
405
  logits_processors = None
406
  if self.logit_bias:
 
 
 
 
 
 
 
 
 
 
407
 
408
  def logit_bias_logits_processor(
409
  token_ids: List[int],
410
  logits: torch.Tensor) -> torch.Tensor:
411
- assert self.logit_bias is not None
412
- for token_id, bias in self.logit_bias.items():
413
- # Clamp the bias between -100 and 100 per OpenAI API spec
414
- bias = min(100, max(-100, bias))
415
- logits[int(token_id)] += bias
416
  return logits
417
 
418
  logits_processors = [logit_bias_logits_processor]
@@ -463,9 +514,16 @@ class CompletionRequest(OpenAIBaseModel):
463
  @classmethod
464
  def check_logprobs(cls, data):
465
  if "logprobs" in data and data[
466
- "logprobs"] is not None and not 0 <= data["logprobs"] <= 5:
467
- raise ValueError(("if passed, `logprobs` must be a value",
468
- " in the interval [0, 5]."))
 
 
 
 
 
 
 
469
  return data
470
 
471
 
@@ -491,7 +549,8 @@ class CompletionLogProbs(OpenAIBaseModel):
491
  text_offset: List[int] = Field(default_factory=list)
492
  token_logprobs: List[Optional[float]] = Field(default_factory=list)
493
  tokens: List[str] = Field(default_factory=list)
494
- top_logprobs: Optional[List[Optional[Dict[str, float]]]] = None
 
495
 
496
 
497
  class CompletionResponseChoice(OpenAIBaseModel):
@@ -543,7 +602,7 @@ class CompletionStreamResponse(OpenAIBaseModel):
543
  class EmbeddingResponseData(BaseModel):
544
  index: int
545
  object: str = "embedding"
546
- embedding: List[float]
547
 
548
 
549
  class EmbeddingResponse(BaseModel):
@@ -590,7 +649,7 @@ class ChatCompletionResponseChoice(OpenAIBaseModel):
590
  index: int
591
  message: ChatMessage
592
  logprobs: Optional[ChatCompletionLogProbs] = None
593
- finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None
594
  stop_reason: Optional[Union[int, str]] = None
595
 
596
 
@@ -613,7 +672,7 @@ class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
613
  index: int
614
  delta: DeltaMessage
615
  logprobs: Optional[ChatCompletionLogProbs] = None
616
- finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None
617
  stop_reason: Optional[Union[int, str]] = None
618
 
619
 
@@ -649,6 +708,17 @@ class BatchRequestInput(OpenAIBaseModel):
649
  body: Union[ChatCompletionRequest, ]
650
 
651
 
 
 
 
 
 
 
 
 
 
 
 
652
  class BatchRequestOutput(OpenAIBaseModel):
653
  """
654
  The per-line object of the batch output and error files
@@ -660,8 +730,29 @@ class BatchRequestOutput(OpenAIBaseModel):
660
  # inputs.
661
  custom_id: str
662
 
663
- response: Optional[ChatCompletionResponse]
664
 
665
  # For requests that failed with a non-HTTP error, this will contain more
666
  # information on the cause of the failure.
667
- error: Optional[Any]
 
 
 
102
  type: Literal["text", "json_object"]
103
 
104
 
105
+ class StreamOptions(OpenAIBaseModel):
106
+ include_usage: Optional[bool] = True
107
+ continuous_usage_stats: Optional[bool] = True
108
+
109
+
110
  class FunctionDefinition(OpenAIBaseModel):
111
  name: str
112
  description: Optional[str] = None
 
145
  le=torch.iinfo(torch.long).max)
146
  stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
147
  stream: Optional[bool] = False
148
+ stream_options: Optional[StreamOptions] = None
149
  temperature: Optional[float] = 0.7
150
  top_p: Optional[float] = 1.0
151
  tools: Optional[List[ChatCompletionToolsParam]] = None
 
191
  "special tokens so this should be set to False (as is the "
192
  "default)."),
193
  )
194
+ documents: Optional[List[Dict[str, str]]] = Field(
195
+ default=None,
196
+ description=
197
+ ("A list of dicts representing documents that will be accessible to "
198
+ "the model if it is performing RAG (retrieval-augmented generation)."
199
+ " If the template does not support RAG, this argument will have no "
200
+ "effect. We recommend that each document should be a dict containing "
201
+ "\"title\" and \"text\" keys."),
202
+ )
203
+ chat_template: Optional[str] = Field(
204
+ default=None,
205
+ description=(
206
+ "A Jinja template to use for this conversion. "
207
+ "If this is not passed, the model's default chat template will be "
208
+ "used instead."),
209
+ )
210
+ chat_template_kwargs: Optional[Dict[str, Any]] = Field(
211
+ default=None,
212
+ description=("Additional kwargs to pass to the template renderer. "
213
+ "Will be accessible by the chat template."),
214
+ )
215
  include_stop_str_in_output: Optional[bool] = Field(
216
  default=False,
217
  description=(
 
256
 
257
  logits_processors = None
258
  if self.logit_bias:
259
+ logit_bias: Dict[int, float] = {}
260
+ try:
261
+ for token_id, bias in self.logit_bias.items():
262
+ # Convert token_id to integer before we add to LLMEngine
263
+ # Clamp the bias between -100 and 100 per OpenAI API spec
264
+ logit_bias[int(token_id)] = min(100, max(-100, bias))
265
+ except ValueError as exc:
266
+ raise ValueError(f"Found token_id `{token_id}` in logit_bias "
267
+ f"but token_id must be an integer or string "
268
+ f"representing an integer") from exc
269
 
270
  def logit_bias_logits_processor(
271
  token_ids: List[int],
272
  logits: torch.Tensor) -> torch.Tensor:
273
+ for token_id, bias in logit_bias.items():
274
+ logits[token_id] += bias
 
 
 
275
  return logits
276
 
277
  logits_processors = [logit_bias_logits_processor]
 
303
  logits_processors=logits_processors,
304
  )
305
 
306
+ @model_validator(mode='before')
307
+ @classmethod
308
+ def validate_stream_options(cls, values):
309
+ if (values.get('stream_options') is not None
310
+ and not values.get('stream')):
311
+ raise ValueError(
312
+ "stream_options can only be set if stream is true")
313
+ return values
314
+
315
  @model_validator(mode="before")
316
  @classmethod
317
  def check_guided_decoding_count(cls, data):
 
351
  raise ValueError(
352
  "when using `top_logprobs`, `logprobs` must be set to true."
353
  )
354
+ elif data["top_logprobs"] < 0:
355
  raise ValueError(
356
+ "`top_logprobs` must be a positive value.")
357
  return data
358
 
359
 
 
375
  le=torch.iinfo(torch.long).max)
376
  stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
377
  stream: Optional[bool] = False
378
+ stream_options: Optional[StreamOptions] = None
379
  suffix: Optional[str] = None
380
  temperature: Optional[float] = 1.0
381
  top_p: Optional[float] = 1.0
 
448
 
449
  logits_processors = None
450
  if self.logit_bias:
451
+ logit_bias: Dict[int, float] = {}
452
+ try:
453
+ for token_id, bias in self.logit_bias.items():
454
+ # Convert token_id to integer
455
+ # Clamp the bias between -100 and 100 per OpenAI API spec
456
+ logit_bias[int(token_id)] = min(100, max(-100, bias))
457
+ except ValueError as exc:
458
+ raise ValueError(f"Found token_id `{token_id}` in logit_bias "
459
+ f"but token_id must be an integer or string "
460
+ f"representing an integer") from exc
461
 
462
  def logit_bias_logits_processor(
463
  token_ids: List[int],
464
  logits: torch.Tensor) -> torch.Tensor:
465
+ for token_id, bias in logit_bias.items():
466
+ logits[token_id] += bias
 
 
 
467
  return logits
468
 
469
  logits_processors = [logit_bias_logits_processor]
 
514
  @classmethod
515
  def check_logprobs(cls, data):
516
  if "logprobs" in data and data[
517
+ "logprobs"] is not None and not data["logprobs"] >= 0:
518
+ raise ValueError("if passed, `logprobs` must be a positive value.")
519
+ return data
520
+
521
+ @model_validator(mode="before")
522
+ @classmethod
523
+ def validate_stream_options(cls, data):
524
+ if data.get("stream_options") and not data.get("stream"):
525
+ raise ValueError(
526
+ "Stream options can only be defined when stream is True.")
527
  return data
528
 
529
 
 
549
  text_offset: List[int] = Field(default_factory=list)
550
  token_logprobs: List[Optional[float]] = Field(default_factory=list)
551
  tokens: List[str] = Field(default_factory=list)
552
+ top_logprobs: List[Optional[Dict[str,
553
+ float]]] = Field(default_factory=list)
554
 
555
 
556
  class CompletionResponseChoice(OpenAIBaseModel):
 
602
  class EmbeddingResponseData(BaseModel):
603
  index: int
604
  object: str = "embedding"
605
+ embedding: Union[List[float], str]
606
 
607
 
608
  class EmbeddingResponse(BaseModel):
 
649
  index: int
650
  message: ChatMessage
651
  logprobs: Optional[ChatCompletionLogProbs] = None
652
+ finish_reason: Optional[str] = None
653
  stop_reason: Optional[Union[int, str]] = None
654
 
655
 
 
672
  index: int
673
  delta: DeltaMessage
674
  logprobs: Optional[ChatCompletionLogProbs] = None
675
+ finish_reason: Optional[str] = None
676
  stop_reason: Optional[Union[int, str]] = None
677
 
678
 
 
708
  body: Union[ChatCompletionRequest, ]
709
 
710
 
711
+ class BatchResponseData(OpenAIBaseModel):
712
+ # HTTP status code of the response.
713
+ status_code: int = 200
714
+
715
+ # An unique identifier for the API request.
716
+ request_id: str
717
+
718
+ # The body of the response.
719
+ body: Union[ChatCompletionResponse, ]
720
+
721
+
722
  class BatchRequestOutput(OpenAIBaseModel):
723
  """
724
  The per-line object of the batch output and error files
 
730
  # inputs.
731
  custom_id: str
732
 
733
+ response: Optional[BatchResponseData]
734
 
735
  # For requests that failed with a non-HTTP error, this will contain more
736
  # information on the cause of the failure.
737
+ error: Optional[Any]
738
+
739
+
740
+ class TokenizeRequest(OpenAIBaseModel):
741
+ model: str
742
+ prompt: str
743
+ add_special_tokens: bool = Field(default=True)
744
+
745
+
746
+ class TokenizeResponse(OpenAIBaseModel):
747
+ tokens: List[int]
748
+ count: int
749
+ max_model_len: int
750
+
751
+
752
+ class DetokenizeRequest(OpenAIBaseModel):
753
+ model: str
754
+ tokens: List[int]
755
+
756
+
757
+ class DetokenizeResponse(OpenAIBaseModel):
758
+ prompt: str
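Two of the protocol changes above are easy to miss: stream_options is only accepted when stream is true (enforced by the new validate_stream_options validators), and logit_bias keys are now converted to int and clamped to [-100, 100] once, before the logits processor runs. The snippet below is an illustrative sketch of that clamping behaviour, not code from this commit; the token ids and bias values are made up.

import torch

raw_logit_bias = {"42": 250.0, "7": -3.5}  # keys may arrive as strings, per the OpenAI API

# Convert keys to int and clamp biases to [-100, 100], as the request models now do.
logit_bias = {int(token_id): min(100, max(-100, bias))
              for token_id, bias in raw_logit_bias.items()}

logits = torch.zeros(128)  # dummy vocabulary logits
for token_id, bias in logit_bias.items():
    logits[token_id] += bias

print(logits[42].item(), logits[7].item())  # 100.0 -3.5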
serving_chat.py CHANGED
@@ -1,224 +1,558 @@
1
- import time
2
  import codecs
 
 
 
 
 
 
 
 
3
  from fastapi import Request
4
- from typing import AsyncGenerator, AsyncIterator, Union
5
- from vllm.logger import init_logger
6
- from vllm.utils import random_uuid
 
7
  from vllm.engine.async_llm_engine import AsyncLLMEngine
8
- from protocol import (
 
 
 
9
  ChatCompletionRequest, ChatCompletionResponse,
10
  ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
11
  ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
12
- UsageInfo)
 
 
 
 
 
 
 
 
13
  from vllm.outputs import RequestOutput
14
- from serving_engine import OpenAIServing
 
 
 
15
 
16
  logger = init_logger(__name__)
17
 
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  class OpenAIServingChat(OpenAIServing):
20
 
21
  def __init__(self,
22
  engine: AsyncLLMEngine,
23
- served_model: str,
 
24
  response_role: str,
25
- chat_template=None):
26
- super().__init__(engine=engine, served_model=served_model)
 
 
 
 
 
27
  self.response_role = response_role
28
self._load_chat_template(chat_template)
 
 
 
30
  async def create_chat_completion(
31
- self, request: ChatCompletionRequest, raw_request: Request
 
 
32
  ) -> Union[ErrorResponse, AsyncGenerator[str, None],
33
  ChatCompletionResponse]:
34
  """Completion API similar to OpenAI's API.
35
 
36
- See https://platform.openai.com/docs/api-reference/chat/create
37
- for the API specification. This API mimics the OpenAI ChatCompletion API.
 
38
 
39
- NOTE: Currently we do not support the following features:
40
  - function_call (Users should implement this by themselves)
41
- - logit_bias (to be supported by vLLM engine)
42
  """
43
  error_check_ret = await self._check_model(request)
44
  if error_check_ret is not None:
45
  return error_check_ret
46
 
47
- if request.logit_bias is not None and len(request.logit_bias) > 0:
48
- # TODO: support logit_bias in vLLM engine.
49
- return self.create_error_response(
50
- "logit_bias is not currently supported")
51
-
52
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  prompt = self.tokenizer.apply_chat_template(
54
- conversation=request.messages,
55
  tokenize=False,
56
- add_generation_prompt=request.add_generation_prompt)
 
 
 
 
 
57
  except Exception as e:
58
- logger.error(
59
- f"Error in applying chat template from request: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
60
  return self.create_error_response(str(e))
61
 
62
  request_id = f"cmpl-{random_uuid()}"
63
  try:
64
- token_ids = self._validate_prompt_and_tokenize(request,
65
- prompt=prompt)
 
 
 
66
  sampling_params = request.to_sampling_params()
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  except ValueError as e:
68
  return self.create_error_response(str(e))
69
 
70
- result_generator = self.engine.generate(prompt, sampling_params,
71
- request_id, token_ids)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  # Streaming response
73
  if request.stream:
74
  return self.chat_completion_stream_generator(
75
- request, result_generator, request_id)
76
  else:
77
- return await self.chat_completion_full_generator(
78
- request, raw_request, result_generator, request_id)
 
 
 
 
 
79
 
80
  def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
81
  if request.add_generation_prompt:
82
  return self.response_role
83
  else:
84
- return request.messages[-1].role
85
 
86
  async def chat_completion_stream_generator(
87
  self, request: ChatCompletionRequest,
88
- result_generator: AsyncIterator[RequestOutput], request_id: str
89
- ) -> Union[ErrorResponse, AsyncGenerator[str, None]]:
90
-
91
- model_name = request.model
92
- created_time = int(time.monotonic())
93
  chunk_object_type = "chat.completion.chunk"
94
-
95
- # Send first response for each request.n (index) with the role
96
- role = self.get_chat_request_role(request)
97
- for i in range(request.n):
98
- choice_data = ChatCompletionResponseStreamChoice(
99
- index=i, delta=DeltaMessage(role=role), finish_reason=None)
100
- chunk = ChatCompletionStreamResponse(id=request_id,
101
- object=chunk_object_type,
102
- created=created_time,
103
- choices=[choice_data],
104
- model=model_name)
105
- data = chunk.model_dump_json(exclude_unset=True)
106
- yield f"data: {data}\n\n"
107
-
108
- # Send response to echo the input portion of the last message
109
- if request.echo:
110
- last_msg_content = ""
111
- if request.messages and isinstance(
112
- request.messages, list) and request.messages[-1].get(
113
- "content") and request.messages[-1].get(
114
- "role") == role:
115
- last_msg_content = request.messages[-1]["content"]
116
- if last_msg_content:
117
- for i in range(request.n):
118
- choice_data = ChatCompletionResponseStreamChoice(
119
- index=i,
120
- delta=DeltaMessage(content=last_msg_content),
121
- finish_reason=None)
122
- chunk = ChatCompletionStreamResponse(
123
- id=request_id,
124
- object=chunk_object_type,
125
- created=created_time,
126
- choices=[choice_data],
127
- model=model_name)
128
- data = chunk.model_dump_json(exclude_unset=True)
129
- yield f"data: {data}\n\n"
130
 
131
  # Send response for each token for each request.n (index)
 
132
  previous_texts = [""] * request.n
133
  previous_num_tokens = [0] * request.n
134
  finish_reason_sent = [False] * request.n
135
- async for res in result_generator:
136
- res: RequestOutput
137
- for output in res.outputs:
138
- i = output.index
139
-
140
- if finish_reason_sent[i]:
141
- continue
142
-
143
- delta_text = output.text[len(previous_texts[i]):]
144
- previous_texts[i] = output.text
145
- previous_num_tokens[i] = len(output.token_ids)
146
-
147
- if output.finish_reason is None:
148
- # Send token-by-token response for each request.n
149
- choice_data = ChatCompletionResponseStreamChoice(
150
- index=i,
151
- delta=DeltaMessage(content=delta_text),
152
- finish_reason=None)
153
- chunk = ChatCompletionStreamResponse(
154
- id=request_id,
155
- object=chunk_object_type,
156
- created=created_time,
157
- choices=[choice_data],
158
- model=model_name)
159
- data = chunk.model_dump_json(exclude_unset=True)
160
- yield f"data: {data}\n\n"
161
- else:
162
- # Send the finish response for each request.n only once
163
- prompt_tokens = len(res.prompt_token_ids)
164
- final_usage = UsageInfo(
165
- prompt_tokens=prompt_tokens,
166
- completion_tokens=previous_num_tokens[i],
167
- total_tokens=prompt_tokens + previous_num_tokens[i],
168
- )
169
- choice_data = ChatCompletionResponseStreamChoice(
170
- index=i,
171
- delta=DeltaMessage(content=delta_text),
172
- finish_reason=output.finish_reason)
173
- chunk = ChatCompletionStreamResponse(
174
- id=request_id,
175
- object=chunk_object_type,
176
- created=created_time,
177
- choices=[choice_data],
178
- model=model_name)
179
- if final_usage is not None:
180
- chunk.usage = final_usage
181
- data = chunk.model_dump_json(exclude_unset=True,
182
- exclude_none=True)
183
- yield f"data: {data}\n\n"
184
- finish_reason_sent[i] = True
 
 
 
185
  # Send the final done message after all response.n are finished
186
  yield "data: [DONE]\n\n"
187
 
188
  async def chat_completion_full_generator(
189
- self, request: ChatCompletionRequest, raw_request: Request,
190
- result_generator: AsyncIterator[RequestOutput],
191
- request_id: str) -> Union[ErrorResponse, ChatCompletionResponse]:
 
192
 
193
- model_name = request.model
194
- created_time = int(time.monotonic())
195
- final_res: RequestOutput = None
196
 
197
  async for res in result_generator:
198
- if await raw_request.is_disconnected():
199
  # Abort the request if the client disconnects.
200
  await self.engine.abort(request_id)
201
  return self.create_error_response("Client disconnected")
202
  final_res = res
203
  assert final_res is not None
204
 
205
- choices = []
 
206
  role = self.get_chat_request_role(request)
207
for output in final_res.outputs:
 
 
 
208
  choice_data = ChatCompletionResponseChoice(
209
  index=output.index,
210
- message=ChatMessage(role=role, content=output.text),
 
211
  finish_reason=output.finish_reason,
212
- )
213
  choices.append(choice_data)
214
 
215
  if request.echo:
216
  last_msg_content = ""
217
- if request.messages and isinstance(
218
- request.messages, list) and request.messages[-1].get(
219
- "content") and request.messages[-1].get(
220
- "role") == role:
221
- last_msg_content = request.messages[-1]["content"]
222
 
223
  for choice in choices:
224
  full_message = last_msg_content + choice.message.content
@@ -242,24 +576,50 @@ class OpenAIServingChat(OpenAIServing):
242
 
243
  return response
244
 
245
- def _load_chat_template(self, chat_template):
246
- if chat_template is not None:
247
- try:
248
- with open(chat_template, "r") as f:
249
- self.tokenizer.chat_template = f.read()
250
- except OSError:
251
- # If opening a file fails, set chat template to be args to
252
- # ensure we decode so our escape are interpreted correctly
253
- self.tokenizer.chat_template = codecs.decode(
254
- chat_template, "unicode_escape")
 
 
 
 
255
 
256
- logger.info(
257
- f"Using supplied chat template:\n{self.tokenizer.chat_template}"
258
- )
259
- elif self.tokenizer.chat_template is not None:
260
- logger.info(
261
- f"Using default chat template:\n{self.tokenizer.chat_template}"
262
- )
263
- else:
264
- logger.warning(
265
- "No chat template provided. Chat API will not work.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import codecs
2
+ import time
3
+ from dataclasses import dataclass, field
4
+ from functools import cached_property
5
+ from typing import (AsyncGenerator, AsyncIterator, Awaitable, Dict, Iterable,
6
+ List, Optional)
7
+ from typing import Sequence as GenericSequence
8
+ from typing import TypedDict, Union, cast, final
9
+
10
  from fastapi import Request
11
+ from openai.types.chat import (ChatCompletionContentPartImageParam,
12
+ ChatCompletionContentPartTextParam)
13
+
14
+ from vllm.config import ModelConfig
15
  from vllm.engine.async_llm_engine import AsyncLLMEngine
16
+ from vllm.entrypoints.openai.protocol import (
17
+ ChatCompletionContentPartParam, ChatCompletionLogProb,
18
+ ChatCompletionLogProbs, ChatCompletionLogProbsContent,
19
+ ChatCompletionMessageParam, ChatCompletionNamedToolChoiceParam,
20
  ChatCompletionRequest, ChatCompletionResponse,
21
  ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
22
  ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
23
+ FunctionCall, ToolCall, UsageInfo)
24
+ from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
25
+ OpenAIServing)
26
+ from vllm.inputs import PromptInputs
27
+ from vllm.logger import init_logger
28
+ from vllm.model_executor.guided_decoding import (
29
+ get_guided_decoding_logits_processor)
30
+ from vllm.multimodal import MultiModalDataDict
31
+ from vllm.multimodal.utils import async_get_and_parse_image
32
  from vllm.outputs import RequestOutput
33
+ from vllm.sequence import Logprob
34
+ from vllm.tracing import (contains_trace_headers, extract_trace_headers,
35
+ log_tracing_disabled_warning)
36
+ from vllm.utils import random_uuid
37
 
38
  logger = init_logger(__name__)
39
 
40
 
41
+ @final # So that it should be compatible with Dict[str, str]
42
+ class ConversationMessage(TypedDict):
43
+ role: str
44
+ content: str
45
+
46
+
47
+ @dataclass(frozen=True)
48
+ class ChatMessageParseResult:
49
+ messages: List[ConversationMessage]
50
+ mm_futures: List[Awaitable[MultiModalDataDict]] = field(
51
+ default_factory=list)
52
+
53
+
54
  class OpenAIServingChat(OpenAIServing):
55
 
56
  def __init__(self,
57
  engine: AsyncLLMEngine,
58
+ model_config: ModelConfig,
59
+ served_model_names: List[str],
60
  response_role: str,
61
+ lora_modules: Optional[List[LoRAModulePath]] = None,
62
+ chat_template: Optional[str] = None):
63
+ super().__init__(engine=engine,
64
+ model_config=model_config,
65
+ served_model_names=served_model_names,
66
+ lora_modules=lora_modules)
67
+
68
  self.response_role = response_role
69
  self._load_chat_template(chat_template)
70
 
71
+ def _load_chat_template(self, chat_template: Optional[str]):
72
+ tokenizer = self.tokenizer
73
+
74
+ if chat_template is not None:
75
+ try:
76
+ with open(chat_template, "r") as f:
77
+ tokenizer.chat_template = f.read()
78
+ except OSError as e:
79
+ JINJA_CHARS = "{}\n"
80
+ if not any(c in chat_template for c in JINJA_CHARS):
81
+ msg = (f"The supplied chat template ({chat_template}) "
82
+ f"looks like a file path, but it failed to be "
83
+ f"opened. Reason: {e}")
84
+ raise ValueError(msg) from e
85
+
86
+ # If opening a file fails, set chat template to be args to
87
+ # ensure we decode so our escape are interpreted correctly
88
+ tokenizer.chat_template = codecs.decode(
89
+ chat_template, "unicode_escape")
90
+
91
+ logger.info("Using supplied chat template:\n%s",
92
+ tokenizer.chat_template)
93
+ elif tokenizer.chat_template is not None:
94
+ logger.info("Using default chat template:\n%s",
95
+ tokenizer.chat_template)
96
+ else:
97
+ logger.warning(
98
+ "No chat template provided. Chat API will not work.")
99
+
100
+ @cached_property
101
+ def image_token_str(self) -> Optional[str]:
102
+ # TODO: Let user specify how to insert image tokens into prompt
103
+ # (similar to chat template)
104
+ model_type = self.model_config.hf_config.model_type
105
+ if model_type == "phi3_v":
106
+ # Workaround since this token is not defined in the tokenizer
107
+ return "<|image_1|>"
108
+ if model_type in ("blip-2", "chatglm", "fuyu", "minicpmv",
109
+ "paligemma"):
110
+ # These models do not use image tokens in the prompt
111
+ return None
112
+ if model_type.startswith("llava"):
113
+ return self.tokenizer.decode(
114
+ self.model_config.hf_config.image_token_index)
115
+
116
+ else:
117
+ raise TypeError(f"Unknown model type: {model_type}")
118
+
119
+ # TODO: Let user specify how to insert image tokens into prompt
120
+ # (similar to chat template)
121
+ def _get_full_image_text_prompt(self, image_token_str: str,
122
+ text_prompt: str) -> str:
123
+ """Combine image and text prompts for vision language model"""
124
+
125
+ # NOTE: For now we assume all model architectures use the same
126
+ # image + text prompt format. This may change in the future.
127
+ return f"{image_token_str}\n{text_prompt}"
128
+
129
+ def _parse_chat_message_content_parts(
130
+ self,
131
+ role: str,
132
+ parts: Iterable[ChatCompletionContentPartParam],
133
+ ) -> ChatMessageParseResult:
134
+ texts: List[str] = []
135
+ mm_futures: List[Awaitable[MultiModalDataDict]] = []
136
+
137
+ for part in parts:
138
+ part_type = part["type"]
139
+ if part_type == "text":
140
+ text = cast(ChatCompletionContentPartTextParam, part)["text"]
141
+ texts.append(text)
142
+ elif part_type == "image_url":
143
+ if len(mm_futures) > 0:
144
+ raise NotImplementedError(
145
+ "Multiple 'image_url' input is currently not supported."
146
+ )
147
+
148
+ image_url = cast(ChatCompletionContentPartImageParam,
149
+ part)["image_url"]
150
+
151
+ if image_url.get("detail", "auto") != "auto":
152
+ logger.warning(
153
+ "'image_url.detail' is currently not supported and "
154
+ "will be ignored.")
155
+
156
+ image_future = async_get_and_parse_image(image_url["url"])
157
+ mm_futures.append(image_future)
158
+ else:
159
+ raise NotImplementedError(f"Unknown part type: {part_type}")
160
+
161
+ text_prompt = "\n".join(texts)
162
+
163
+ if mm_futures:
164
+ image_token_str = self.image_token_str
165
+ if image_token_str is not None:
166
+ if image_token_str in text_prompt:
167
+ logger.warning(
168
+ "Detected image token string in the text prompt. "
169
+ "Skipping prompt formatting.")
170
+ else:
171
+ text_prompt = self._get_full_image_text_prompt(
172
+ image_token_str=image_token_str,
173
+ text_prompt=text_prompt,
174
+ )
175
+
176
+ messages = [ConversationMessage(role=role, content=text_prompt)]
177
+
178
+ return ChatMessageParseResult(messages=messages, mm_futures=mm_futures)
179
+
180
+ def _parse_chat_message_content(
181
+ self,
182
+ message: ChatCompletionMessageParam,
183
+ ) -> ChatMessageParseResult:
184
+ role = message["role"]
185
+ content = message.get("content")
186
+
187
+ if content is None:
188
+ return ChatMessageParseResult(messages=[], mm_futures=[])
189
+ if isinstance(content, str):
190
+ messages = [ConversationMessage(role=role, content=content)]
191
+ return ChatMessageParseResult(messages=messages, mm_futures=[])
192
+
193
+ return self._parse_chat_message_content_parts(role, content)
194
+
195
  async def create_chat_completion(
196
+ self,
197
+ request: ChatCompletionRequest,
198
+ raw_request: Optional[Request] = None
199
  ) -> Union[ErrorResponse, AsyncGenerator[str, None],
200
  ChatCompletionResponse]:
201
  """Completion API similar to OpenAI's API.
202
 
203
+ See https://platform.openai.com/docs/api-reference/chat/create
204
+ for the API specification. This API mimics the OpenAI
205
+ ChatCompletion API.
206
 
207
+ NOTE: Currently we do not support the following feature:
208
  - function_call (Users should implement this by themselves)
 
209
  """
210
  error_check_ret = await self._check_model(request)
211
  if error_check_ret is not None:
212
  return error_check_ret
213
 
 
 
 
 
 
214
  try:
215
+ conversation: List[ConversationMessage] = []
216
+ mm_futures: List[Awaitable[MultiModalDataDict]] = []
217
+
218
+ for msg in request.messages:
219
+ chat_parsed_result = self._parse_chat_message_content(msg)
220
+
221
+ conversation.extend(chat_parsed_result.messages)
222
+ mm_futures.extend(chat_parsed_result.mm_futures)
223
+
224
+ tool_dicts = None if request.tools is None else [
225
+ tool.model_dump() for tool in request.tools
226
+ ]
227
+
228
  prompt = self.tokenizer.apply_chat_template(
229
+ conversation=conversation,
230
  tokenize=False,
231
+ add_generation_prompt=request.add_generation_prompt,
232
+ tools=tool_dicts,
233
+ documents=request.documents,
234
+ chat_template=request.chat_template,
235
+ **(request.chat_template_kwargs or {}),
236
+ )
237
  except Exception as e:
238
+ logger.error("Error in applying chat template from request: %s", e)
239
+ return self.create_error_response(str(e))
240
+
241
+ mm_data: Optional[MultiModalDataDict] = None
242
+ try:
243
+ if len(mm_futures):
244
+ # since we support only single mm data currently
245
+ assert len(
246
+ mm_futures
247
+ ) == 1, "Multiple 'image_url' input is currently not supported."
248
+ mm_data = await mm_futures[0]
249
+ except Exception as e:
250
+ logger.error("Error in loading multi-modal data: %s", e)
251
  return self.create_error_response(str(e))
252
 
253
  request_id = f"cmpl-{random_uuid()}"
254
  try:
255
+ # Tokenize/detokenize depending on prompt format (string/token list)
256
+ prompt_ids, prompt_text = self._validate_prompt_and_tokenize(
257
+ request,
258
+ prompt=prompt,
259
+ add_special_tokens=request.add_special_tokens)
260
  sampling_params = request.to_sampling_params()
261
+ lora_request = self._maybe_get_lora(request)
262
+ decoding_config = await self.engine.get_decoding_config()
263
+ guided_decoding_backend = request.guided_decoding_backend \
264
+ or decoding_config.guided_decoding_backend
265
+ guided_decode_logits_processor = (
266
+ await get_guided_decoding_logits_processor(
267
+ guided_decoding_backend, request, await
268
+ self.engine.get_tokenizer()))
269
+ if guided_decode_logits_processor:
270
+ if sampling_params.logits_processors is None:
271
+ sampling_params.logits_processors = []
272
+ sampling_params.logits_processors.append(
273
+ guided_decode_logits_processor)
274
  except ValueError as e:
275
  return self.create_error_response(str(e))
276
 
277
+ inputs: PromptInputs = {
278
+ "prompt": prompt_text,
279
+ "prompt_token_ids": prompt_ids,
280
+ }
281
+ if mm_data:
282
+ inputs["multi_modal_data"] = mm_data
283
+
284
+ is_tracing_enabled = await self.engine.is_tracing_enabled()
285
+ trace_headers = None
286
+ if is_tracing_enabled and raw_request:
287
+ trace_headers = extract_trace_headers(raw_request.headers)
288
+ if not is_tracing_enabled and raw_request and contains_trace_headers(
289
+ raw_request.headers):
290
+ log_tracing_disabled_warning()
291
+
292
+ result_generator = self.engine.generate(
293
+ inputs,
294
+ sampling_params,
295
+ request_id,
296
+ lora_request,
297
+ trace_headers=trace_headers,
298
+ )
299
  # Streaming response
300
  if request.stream:
301
  return self.chat_completion_stream_generator(
302
+ request, result_generator, request_id, conversation)
303
  else:
304
+ try:
305
+ return await self.chat_completion_full_generator(
306
+ request, raw_request, result_generator, request_id,
307
+ conversation)
308
+ except ValueError as e:
309
+ # TODO: Use a vllm-specific Validation Error
310
+ return self.create_error_response(str(e))
311
 
312
  def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
313
  if request.add_generation_prompt:
314
  return self.response_role
315
  else:
316
+ return request.messages[-1]["role"]
317
 
318
  async def chat_completion_stream_generator(
319
  self, request: ChatCompletionRequest,
320
+ result_generator: AsyncIterator[RequestOutput], request_id: str,
321
+ conversation: List[ConversationMessage]
322
+ ) -> AsyncGenerator[str, None]:
323
+ model_name = self.served_model_names[0]
324
+ created_time = int(time.time())
325
  chunk_object_type = "chat.completion.chunk"
326
+ first_iteration = True
 
 
 
327
 
328
  # Send response for each token for each request.n (index)
329
+ assert request.n is not None
330
  previous_texts = [""] * request.n
331
  previous_num_tokens = [0] * request.n
332
  finish_reason_sent = [False] * request.n
333
+ try:
334
+ async for res in result_generator:
335
+ # We need to do it here, because if there are exceptions in
336
+ # the result_generator, it needs to be sent as the FIRST
337
+ # response (by the try...catch).
338
+ if first_iteration:
339
+ # Send first response for each request.n (index) with
340
+ # the role
341
+ role = self.get_chat_request_role(request)
342
+ for i in range(request.n):
343
+ choice_data = ChatCompletionResponseStreamChoice(
344
+ index=i,
345
+ delta=DeltaMessage(role=role),
346
+ logprobs=None,
347
+ finish_reason=None)
348
+ chunk = ChatCompletionStreamResponse(
349
+ id=request_id,
350
+ object=chunk_object_type,
351
+ created=created_time,
352
+ choices=[choice_data],
353
+ model=model_name)
354
+ if (request.stream_options
355
+ and request.stream_options.include_usage):
356
+ chunk.usage = None
357
+ data = chunk.model_dump_json(exclude_unset=True)
358
+ yield f"data: {data}\n\n"
359
+
360
+ # Send response to echo the input portion of the
361
+ # last message
362
+ if request.echo:
363
+ last_msg_content = ""
364
+ if conversation and conversation[-1].get(
365
+ "content") and conversation[-1].get(
366
+ "role") == role:
367
+ last_msg_content = conversation[-1]["content"]
368
+
369
+ if last_msg_content:
370
+ for i in range(request.n):
371
+ choice_data = (
372
+ ChatCompletionResponseStreamChoice(
373
+ index=i,
374
+ delta=DeltaMessage(
375
+ content=last_msg_content),
376
+ finish_reason=None))
377
+ chunk = ChatCompletionStreamResponse(
378
+ id=request_id,
379
+ object=chunk_object_type,
380
+ created=created_time,
381
+ choices=[choice_data],
382
+ logprobs=None,
383
+ model=model_name)
384
+ if (request.stream_options and
385
+ request.stream_options.include_usage):
386
+ chunk.usage = None
387
+ data = chunk.model_dump_json(
388
+ exclude_unset=True)
389
+ yield f"data: {data}\n\n"
390
+ first_iteration = False
391
+
392
+ for output in res.outputs:
393
+ i = output.index
394
+
395
+ if finish_reason_sent[i]:
396
+ continue
397
+
398
+ delta_token_ids = output.token_ids[previous_num_tokens[i]:]
399
+ out_logprobs = output.logprobs[
400
+ previous_num_tokens[i]:] if output.logprobs else None
401
+
402
+ if request.logprobs and request.top_logprobs is not None:
403
+ assert out_logprobs is not None, (
404
+ "Did not output logprobs")
405
+ logprobs = self._create_chat_logprobs(
406
+ token_ids=delta_token_ids,
407
+ top_logprobs=out_logprobs,
408
+ num_output_top_logprobs=request.top_logprobs,
409
+ )
410
+ else:
411
+ logprobs = None
412
+
413
+ delta_text = output.text[len(previous_texts[i]):]
414
+ previous_texts[i] = output.text
415
+ previous_num_tokens[i] = len(output.token_ids)
416
+
417
+ if request.tool_choice and type(
418
+ request.tool_choice
419
+ ) is ChatCompletionNamedToolChoiceParam:
420
+ delta_message = DeltaMessage(tool_calls=[
421
+ ToolCall(function=FunctionCall(
422
+ name=request.tool_choice.function.name,
423
+ arguments=delta_text))
424
+ ])
425
+ else:
426
+ delta_message = DeltaMessage(content=delta_text)
427
+
428
+ if output.finish_reason is None:
429
+ # Send token-by-token response for each request.n
430
+
431
+ choice_data = ChatCompletionResponseStreamChoice(
432
+ index=i,
433
+ delta=delta_message,
434
+ logprobs=logprobs,
435
+ finish_reason=None)
436
+ chunk = ChatCompletionStreamResponse(
437
+ id=request_id,
438
+ object=chunk_object_type,
439
+ created=created_time,
440
+ choices=[choice_data],
441
+ model=model_name)
442
+ if (request.stream_options
443
+ and request.stream_options.include_usage):
444
+ chunk.usage = None
445
+ data = chunk.model_dump_json(exclude_unset=True)
446
+ yield f"data: {data}\n\n"
447
+ else:
448
+ # Send the finish response for each request.n only once
449
+ prompt_tokens = len(res.prompt_token_ids)
450
+ choice_data = ChatCompletionResponseStreamChoice(
451
+ index=i,
452
+ delta=delta_message,
453
+ logprobs=logprobs,
454
+ finish_reason=output.finish_reason,
455
+ stop_reason=output.stop_reason)
456
+ chunk = ChatCompletionStreamResponse(
457
+ id=request_id,
458
+ object=chunk_object_type,
459
+ created=created_time,
460
+ choices=[choice_data],
461
+ model=model_name)
462
+ if (request.stream_options
463
+ and request.stream_options.include_usage):
464
+ chunk.usage = None
465
+ data = chunk.model_dump_json(exclude_unset=True)
466
+ yield f"data: {data}\n\n"
467
+ finish_reason_sent[i] = True
468
+
469
+ if (request.stream_options
470
+ and request.stream_options.include_usage):
471
+ final_usage = UsageInfo(
472
+ prompt_tokens=prompt_tokens,
473
+ completion_tokens=previous_num_tokens[i],
474
+ total_tokens=prompt_tokens + previous_num_tokens[i],
475
+ )
476
+
477
+ final_usage_chunk = ChatCompletionStreamResponse(
478
+ id=request_id,
479
+ object=chunk_object_type,
480
+ created=created_time,
481
+ choices=[],
482
+ model=model_name,
483
+ usage=final_usage)
484
+ final_usage_data = (final_usage_chunk.model_dump_json(
485
+ exclude_unset=True, exclude_none=True))
486
+ yield f"data: {final_usage_data}\n\n"
487
+
488
+ except ValueError as e:
489
+ # TODO: Use a vllm-specific Validation Error
490
+ data = self.create_streaming_error_response(str(e))
491
+ yield f"data: {data}\n\n"
492
  # Send the final done message after all response.n are finished
493
  yield "data: [DONE]\n\n"
494
 
495
  async def chat_completion_full_generator(
496
+ self, request: ChatCompletionRequest, raw_request: Optional[Request],
497
+ result_generator: AsyncIterator[RequestOutput], request_id: str,
498
+ conversation: List[ConversationMessage]
499
+ ) -> Union[ErrorResponse, ChatCompletionResponse]:
500
 
501
+ model_name = self.served_model_names[0]
502
+ created_time = int(time.time())
503
+ final_res: Optional[RequestOutput] = None
504
 
505
  async for res in result_generator:
506
+ if raw_request is not None and await raw_request.is_disconnected():
507
  # Abort the request if the client disconnects.
508
  await self.engine.abort(request_id)
509
  return self.create_error_response("Client disconnected")
510
  final_res = res
511
  assert final_res is not None
512
 
513
+ choices: List[ChatCompletionResponseChoice] = []
514
+
515
  role = self.get_chat_request_role(request)
516
  for output in final_res.outputs:
517
+ token_ids = output.token_ids
518
+ out_logprobs = output.logprobs
519
+
520
+ if request.logprobs and request.top_logprobs is not None:
521
+ assert out_logprobs is not None, "Did not output logprobs"
522
+ logprobs = self._create_chat_logprobs(
523
+ token_ids=token_ids,
524
+ top_logprobs=out_logprobs,
525
+ num_output_top_logprobs=request.top_logprobs,
526
+ )
527
+ else:
528
+ logprobs = None
529
+
530
+ if request.tool_choice and type(
531
+ request.tool_choice) is ChatCompletionNamedToolChoiceParam:
532
+ message = ChatMessage(
533
+ role=role,
534
+ content="",
535
+ tool_calls=[
536
+ ToolCall(function=FunctionCall(
537
+ name=request.tool_choice.function.name,
538
+ arguments=output.text))
539
+ ])
540
+ elif not request.tool_choice or request.tool_choice == "none":
541
+ message = ChatMessage(role=role, content=output.text)
542
+
543
  choice_data = ChatCompletionResponseChoice(
544
  index=output.index,
545
+ message=message,
546
+ logprobs=logprobs,
547
  finish_reason=output.finish_reason,
548
+ stop_reason=output.stop_reason)
549
  choices.append(choice_data)
550
 
551
  if request.echo:
552
  last_msg_content = ""
553
+ if conversation and conversation[-1].get(
554
+ "content") and conversation[-1].get("role") == role:
555
+ last_msg_content = conversation[-1]["content"]
 
 
556
 
557
  for choice in choices:
558
  full_message = last_msg_content + choice.message.content
 
576
 
577
  return response
578
 
579
+ def _get_top_logprobs(
580
+ self, logprobs: Dict[int, Logprob],
581
+ top_logprobs: Optional[int]) -> List[ChatCompletionLogProb]:
582
+ return [
583
+ ChatCompletionLogProb(
584
+ token=self._get_decoded_token(p[1], p[0]),
585
+ logprob=max(p[1].logprob, -9999.0),
586
+ bytes=list(
587
+ self._get_decoded_token(p[1],
588
+ p[0]).encode("utf-8",
589
+ errors="replace")))
590
+ for i, p in enumerate(logprobs.items())
591
+ if top_logprobs and i < top_logprobs
592
+ ]
593
 
594
+ def _create_chat_logprobs(
595
+ self,
596
+ token_ids: GenericSequence[int],
597
+ top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
598
+ num_output_top_logprobs: Optional[int] = None,
599
+ ) -> ChatCompletionLogProbs:
600
+ """Create OpenAI-style logprobs."""
601
+
602
+ logprobs_content = []
603
+
604
+ for i, token_id in enumerate(token_ids):
605
+ step_top_logprobs = top_logprobs[i]
606
+ if step_top_logprobs is None:
607
+ logprobs_content.append(
608
+ ChatCompletionLogProbsContent(
609
+ token=self.tokenizer.decode(token_id),
610
+ bytes=list(
611
+ self.tokenizer.decode(token_id).encode(
612
+ "utf-8", errors="replace"))))
613
+ else:
614
+ logprobs_content.append(
615
+ ChatCompletionLogProbsContent(
616
+ token=step_top_logprobs[token_id].decoded_token,
617
+ logprob=max(step_top_logprobs[token_id].logprob,
618
+ -9999.0),
619
+ bytes=list(
620
+ step_top_logprobs[token_id].decoded_token.encode(
621
+ "utf-8", errors="replace")),
622
+ top_logprobs=self._get_top_logprobs(
623
+ step_top_logprobs, num_output_top_logprobs)))
624
+
625
+ return ChatCompletionLogProbs(content=logprobs_content)
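To make the new message-parsing path in OpenAIServingChat concrete, here is an illustrative request body (not from this commit) using the multi-part content that _parse_chat_message_content now handles; the model name and image URL are placeholders. Note that the code above still rejects more than one image_url part per request.

chat_request = {
    "model": "my-vision-model",                 # placeholder served model name
    "messages": [{
        "role": "user",
        "content": [
            {"type": "text", "text": "What is shown in this image?"},
            {"type": "image_url",
             "image_url": {"url": "https://example.com/cat.png"}},
            # A second image_url part would hit the NotImplementedError above.
        ],
    }],
    "stream": True,
    "stream_options": {"include_usage": True},  # only valid when stream is true
}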
serving_completion.py CHANGED
@@ -1,24 +1,26 @@
1
- import codecs
2
  import time
3
- from dataclasses import dataclass
4
- from typing import (AsyncGenerator, AsyncIterator, Dict, Iterable, List,
5
  Optional)
6
  from typing import Sequence as GenericSequence
7
- from typing import TypedDict, Union, cast, final
8
 
9
  from fastapi import Request
10
- from openai.types.chat import ChatCompletionContentPartTextParam
11
 
12
  from vllm.config import ModelConfig
13
  from vllm.engine.async_llm_engine import AsyncLLMEngine
14
- from vllm.entrypoints.openai.protocol import (
15
- ChatCompletionContentPartParam, ChatCompletionLogProb,
16
- ChatCompletionLogProbs, ChatCompletionLogProbsContent,
17
- ChatCompletionMessageParam, ChatCompletionNamedToolChoiceParam,
18
- ChatCompletionRequest, ChatCompletionResponse,
19
- ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
20
- ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
21
- FunctionCall, ToolCall, UsageInfo)
 
 
 
 
 
22
  from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
23
  OpenAIServing)
24
  from vllm.logger import init_logger
@@ -26,417 +28,364 @@ from vllm.model_executor.guided_decoding import (
26
  get_guided_decoding_logits_processor)
27
  from vllm.outputs import RequestOutput
28
  from vllm.sequence import Logprob
29
- from vllm.utils import random_uuid
 
 
30
 
31
  logger = init_logger(__name__)
32
 
33
-
34
- @final # So that it should be compatible with Dict[str, str]
35
- class ConversationMessage(TypedDict):
36
- role: str
37
- content: str
38
-
39
-
40
- @dataclass(frozen=True)
41
- class ChatMessageParseResult:
42
- messages: List[ConversationMessage]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
 
45
- class OpenAIServingChat(OpenAIServing):
46
 
47
- def __init__(self,
48
- engine: AsyncLLMEngine,
49
- model_config: ModelConfig,
50
  served_model_names: List[str],
51
- response_role: str,
52
- lora_modules: Optional[List[LoRAModulePath]] = None,
53
- chat_template: Optional[str] = None):
54
  super().__init__(engine=engine,
55
  model_config=model_config,
56
  served_model_names=served_model_names,
57
  lora_modules=lora_modules)
58
 
59
- self.response_role = response_role
60
- self._load_chat_template(chat_template)
61
-
62
- def _load_chat_template(self, chat_template: Optional[str]):
63
- tokenizer = self.tokenizer
64
-
65
- if chat_template is not None:
66
- try:
67
- with open(chat_template, "r") as f:
68
- tokenizer.chat_template = f.read()
69
- except OSError as e:
70
- JINJA_CHARS = "{}\n"
71
- if not any(c in chat_template for c in JINJA_CHARS):
72
- msg = (f"The supplied chat template ({chat_template}) "
73
- f"looks like a file path, but it failed to be "
74
- f"opened. Reason: {e}")
75
- raise ValueError(msg) from e
76
-
77
- # If opening a file fails, set chat template to be args to
78
- # ensure we decode so our escape are interpreted correctly
79
- tokenizer.chat_template = codecs.decode(
80
- chat_template, "unicode_escape")
81
-
82
- logger.info("Using supplied chat template:\n%s",
83
- tokenizer.chat_template)
84
- elif tokenizer.chat_template is not None:
85
- logger.info("Using default chat template:\n%s",
86
- tokenizer.chat_template)
87
- else:
88
- logger.warning(
89
- "No chat template provided. Chat API will not work.")
90
-
91
- def _parse_chat_message_content_parts(
92
- self,
93
- role: str,
94
- parts: Iterable[ChatCompletionContentPartParam],
95
- ) -> ChatMessageParseResult:
96
- texts: List[str] = []
97
-
98
- for _, part in enumerate(parts):
99
- part_type = part["type"]
100
- if part_type == "text":
101
- text = cast(ChatCompletionContentPartTextParam, part)["text"]
102
-
103
- texts.append(text)
104
- else:
105
- raise NotImplementedError(f"Unknown part type: {part_type}")
106
-
107
- messages = [ConversationMessage(role=role, content="\n".join(texts))]
108
-
109
- return ChatMessageParseResult(messages=messages)
110
-
111
- def _parse_chat_message_content(
112
- self,
113
- message: ChatCompletionMessageParam,
114
- ) -> ChatMessageParseResult:
115
- role = message["role"]
116
- content = message.get("content")
117
-
118
- if content is None:
119
- return ChatMessageParseResult(messages=[])
120
- if isinstance(content, str):
121
- messages = [ConversationMessage(role=role, content=content)]
122
- return ChatMessageParseResult(messages=messages)
123
-
124
- return self._parse_chat_message_content_parts(role, content)
125
-
126
- async def create_chat_completion(
127
- self,
128
- request: ChatCompletionRequest,
129
- raw_request: Optional[Request] = None
130
- ) -> Union[ErrorResponse, AsyncGenerator[str, None],
131
- ChatCompletionResponse]:
132
  """Completion API similar to OpenAI's API.
133
 
134
- See https://platform.openai.com/docs/api-reference/chat/create
135
- for the API specification. This API mimics the OpenAI
136
- ChatCompletion API.
137
 
138
  NOTE: Currently we do not support the following feature:
139
- - function_call (Users should implement this by themselves)
 
140
  """
141
  error_check_ret = await self._check_model(request)
142
  if error_check_ret is not None:
143
  return error_check_ret
144
 
145
- try:
146
- conversation: List[ConversationMessage] = []
147
-
148
- for msg in request.messages:
149
- parsed_msg = self._parse_chat_message_content(msg)
150
-
151
- conversation.extend(parsed_msg.messages)
152
-
153
- prompt = self.tokenizer.apply_chat_template(
154
- conversation=conversation,
155
- tokenize=False,
156
- add_generation_prompt=request.add_generation_prompt,
157
- )
158
- except Exception as e:
159
- logger.error("Error in applying chat template from request: %s", e)
160
- return self.create_error_response(str(e))
161
 
 
162
  request_id = f"cmpl-{random_uuid()}"
163
  try:
164
- # Tokenize/detokenize depending on prompt format (string/token list)
165
- prompt_ids, prompt_text = self._validate_prompt_and_tokenize(
166
- request,
167
- prompt=prompt,
168
- add_special_tokens=request.add_special_tokens)
169
  sampling_params = request.to_sampling_params()
170
  lora_request = self._maybe_get_lora(request)
171
  decoding_config = await self.engine.get_decoding_config()
172
  guided_decoding_backend = request.guided_decoding_backend \
173
  or decoding_config.guided_decoding_backend
174
- guided_decode_logits_processor = (
175
  await get_guided_decoding_logits_processor(
176
  guided_decoding_backend, request, await
177
  self.engine.get_tokenizer()))
178
- if guided_decode_logits_processor:
179
  if sampling_params.logits_processors is None:
180
  sampling_params.logits_processors = []
181
  sampling_params.logits_processors.append(
182
- guided_decode_logits_processor)
183
  except ValueError as e:
 
184
  return self.create_error_response(str(e))
185
 
186
- result_generator = self.engine.generate(
187
- {
188
- "prompt": prompt_text,
189
- "prompt_token_ids": prompt_ids
190
- },
191
- sampling_params,
192
- request_id,
193
- lora_request,
194
- )
 
195
  # Streaming response
196
  if request.stream:
197
- return self.chat_completion_stream_generator(
198
- request, result_generator, request_id, conversation)
199
- else:
200
- try:
201
- return await self.chat_completion_full_generator(
202
- request, raw_request, result_generator, request_id,
203
- conversation)
204
- except ValueError as e:
205
- # TODO: Use a vllm-specific Validation Error
206
- return self.create_error_response(str(e))
207
-
208
- def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
209
- if request.add_generation_prompt:
210
- return self.response_role
211
- else:
212
- return request.messages[-1]["role"]
213
 
214
- async def chat_completion_stream_generator(
215
- self, request: ChatCompletionRequest,
216
- result_generator: AsyncIterator[RequestOutput], request_id: str,
217
- conversation: List[ConversationMessage]
218
- ) -> AsyncGenerator[str, None]:
219
- model_name = self.served_model_names[0]
220
- created_time = int(time.time())
221
- chunk_object_type = "chat.completion.chunk"
222
- first_iteration = True
223
 
224
- # Send response for each token for each request.n (index)
225
- assert request.n is not None
226
- previous_texts = [""] * request.n
227
- previous_num_tokens = [0] * request.n
228
- finish_reason_sent = [False] * request.n
229
- try:
230
- async for res in result_generator:
231
- # We need to do it here, because if there are exceptions in
232
- # the result_generator, it needs to be sent as the FIRST
233
- # response (by the try...catch).
234
- if first_iteration:
235
- # Send first response for each request.n (index) with
236
- # the role
237
- role = self.get_chat_request_role(request)
238
- for i in range(request.n):
239
- choice_data = ChatCompletionResponseStreamChoice(
240
- index=i,
241
- delta=DeltaMessage(role=role),
242
- logprobs=None,
243
- finish_reason=None)
244
- chunk = ChatCompletionStreamResponse(
245
- id=request_id,
246
- object=chunk_object_type,
247
- created=created_time,
248
- choices=[choice_data],
249
- model=model_name)
250
- data = chunk.model_dump_json(exclude_unset=True)
251
- yield f"data: {data}\n\n"
252
-
253
- # Send response to echo the input portion of the
254
- # last message
255
- if request.echo:
256
- last_msg_content = ""
257
- if conversation and conversation[-1].get(
258
- "content") and conversation[-1].get(
259
- "role") == role:
260
- last_msg_content = conversation[-1]["content"]
261
-
262
- if last_msg_content:
263
- for i in range(request.n):
264
- choice_data = (
265
- ChatCompletionResponseStreamChoice(
266
- index=i,
267
- delta=DeltaMessage(
268
- content=last_msg_content),
269
- finish_reason=None))
270
- chunk = ChatCompletionStreamResponse(
271
- id=request_id,
272
- object=chunk_object_type,
273
- created=created_time,
274
- choices=[choice_data],
275
- logprobs=None,
276
- model=model_name)
277
- data = chunk.model_dump_json(
278
- exclude_unset=True)
279
- yield f"data: {data}\n\n"
280
- first_iteration = False
281
 
282
- for output in res.outputs:
283
- i = output.index
284
 
285
- if finish_reason_sent[i]:
286
- continue
 
287
 
288
- delta_token_ids = output.token_ids[previous_num_tokens[i]:]
289
- top_logprobs = output.logprobs[
290
- previous_num_tokens[i]:] if output.logprobs else None
291
 
292
- if request.logprobs:
293
- logprobs = self._create_chat_logprobs(
294
  token_ids=delta_token_ids,
295
- top_logprobs=top_logprobs,
296
- num_output_top_logprobs=request.top_logprobs,
 
297
  )
298
  else:
299
  logprobs = None
300
 
301
- delta_text = output.text[len(previous_texts[i]):]
302
  previous_texts[i] = output.text
303
  previous_num_tokens[i] = len(output.token_ids)
304
-
305
- if request.tool_choice and type(
306
- request.tool_choice
307
- ) is ChatCompletionNamedToolChoiceParam:
308
- delta_message = DeltaMessage(tool_calls=[
309
- ToolCall(function=FunctionCall(
310
- name=request.tool_choice.function.name,
311
- arguments=delta_text))
312
  ])
313
- else:
314
- delta_message = DeltaMessage(content=delta_text)
315
-
316
- if output.finish_reason is None:
317
- # Send token-by-token response for each request.n
318
-
319
- choice_data = ChatCompletionResponseStreamChoice(
320
- index=i,
321
- delta=delta_message,
322
- logprobs=logprobs,
323
- finish_reason=None)
324
- chunk = ChatCompletionStreamResponse(
325
- id=request_id,
326
- object=chunk_object_type,
327
- created=created_time,
328
- choices=[choice_data],
329
- model=model_name)
330
- data = chunk.model_dump_json(exclude_unset=True)
331
- yield f"data: {data}\n\n"
332
- else:
333
- # Send the finish response for each request.n only once
334
- prompt_tokens = len(res.prompt_token_ids)
335
- final_usage = UsageInfo(
336
- prompt_tokens=prompt_tokens,
337
- completion_tokens=previous_num_tokens[i],
338
- total_tokens=prompt_tokens +
339
- previous_num_tokens[i],
340
- )
341
- choice_data = ChatCompletionResponseStreamChoice(
342
- index=i,
343
- delta=delta_message,
344
- logprobs=logprobs,
345
- finish_reason=output.finish_reason,
346
- stop_reason=output.stop_reason)
347
- chunk = ChatCompletionStreamResponse(
348
- id=request_id,
349
- object=chunk_object_type,
350
- created=created_time,
351
- choices=[choice_data],
352
- model=model_name)
353
- if final_usage is not None:
354
- chunk.usage = final_usage
355
- data = chunk.model_dump_json(exclude_unset=True,
356
- exclude_none=True)
357
- yield f"data: {data}\n\n"
358
- finish_reason_sent[i] = True
359
  except ValueError as e:
360
  # TODO: Use a vllm-specific Validation Error
361
  data = self.create_streaming_error_response(str(e))
362
  yield f"data: {data}\n\n"
363
- # Send the final done message after all response.n are finished
364
  yield "data: [DONE]\n\n"
365
 
366
- async def chat_completion_full_generator(
367
- self, request: ChatCompletionRequest, raw_request: Optional[Request],
368
- result_generator: AsyncIterator[RequestOutput], request_id: str,
369
- conversation: List[ConversationMessage]
370
- ) -> Union[ErrorResponse, ChatCompletionResponse]:
371
 
372
- model_name = self.served_model_names[0]
373
- created_time = int(time.time())
374
- final_res: Optional[RequestOutput] = None
375
 
376
- async for res in result_generator:
377
- if raw_request is not None and await raw_request.is_disconnected():
378
- # Abort the request if the client disconnects.
379
- await self.engine.abort(request_id)
380
- return self.create_error_response("Client disconnected")
381
- final_res = res
382
- assert final_res is not None
383
-
384
- choices = []
385
-
386
- role = self.get_chat_request_role(request)
387
- for output in final_res.outputs:
388
- token_ids = output.token_ids
389
- top_logprobs = output.logprobs
390
-
391
- if request.logprobs:
392
- logprobs = self._create_chat_logprobs(
393
- token_ids=token_ids,
394
- top_logprobs=top_logprobs,
395
- num_output_top_logprobs=request.top_logprobs,
396
- )
397
- else:
398
- logprobs = None
399
-
400
- if request.tool_choice and type(
401
- request.tool_choice) is ChatCompletionNamedToolChoiceParam:
402
- message = ChatMessage(
403
- role=role,
404
- content="",
405
- tool_calls=[
406
- ToolCall(function=FunctionCall(
407
- name=request.tool_choice.function.name,
408
- arguments=output.text))
409
- ])
410
- elif not request.tool_choice or request.tool_choice == "none":
411
- message = ChatMessage(role=role, content=output.text)
412
-
413
- choice_data = ChatCompletionResponseChoice(
414
- index=output.index,
415
- message=message,
416
- logprobs=logprobs,
417
- finish_reason=output.finish_reason,
418
- stop_reason=output.stop_reason)
419
- choices.append(choice_data)
420
-
421
- if request.echo:
422
- last_msg_content = ""
423
- if conversation and conversation[-1].get(
424
- "content") and conversation[-1].get("role") == role:
425
- last_msg_content = conversation[-1]["content"]
426
-
427
- for choice in choices:
428
- full_message = last_msg_content + choice.message.content
429
- choice.message.content = full_message
430
-
431
- num_prompt_tokens = len(final_res.prompt_token_ids)
432
- num_generated_tokens = sum(
433
- len(output.token_ids) for output in final_res.outputs)
434
  usage = UsageInfo(
435
  prompt_tokens=num_prompt_tokens,
436
  completion_tokens=num_generated_tokens,
437
  total_tokens=num_prompt_tokens + num_generated_tokens,
438
  )
439
- response = ChatCompletionResponse(
 
440
  id=request_id,
441
  created=created_time,
442
  model=model_name,
@@ -444,52 +393,84 @@ class OpenAIServingChat(OpenAIServing):
444
  usage=usage,
445
  )
446
 
447
- return response
448
-
449
- def _get_top_logprobs(
450
- self, logprobs: Dict[int, Logprob],
451
- top_logprobs: Optional[int]) -> List[ChatCompletionLogProb]:
452
- return [
453
- ChatCompletionLogProb(
454
- token=self._get_decoded_token(p[1], p[0]),
455
- logprob=max(p[1].logprob, -9999.0),
456
- bytes=list(
457
- self._get_decoded_token(p[1],
458
- p[0]).encode("utf-8",
459
- errors="replace")))
460
- for i, p in enumerate(logprobs.items())
461
- if top_logprobs and i < top_logprobs
462
- ]
463
-
464
- def _create_chat_logprobs(
465
  self,
466
  token_ids: GenericSequence[int],
467
  top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
468
- num_output_top_logprobs: Optional[int] = None,
469
- ) -> ChatCompletionLogProbs:
470
- """Create OpenAI-style logprobs."""
471
 
472
- logprobs_content = []
473
 
474
  for i, token_id in enumerate(token_ids):
475
  step_top_logprobs = top_logprobs[i]
476
  if step_top_logprobs is None:
477
- logprobs_content.append(
478
- ChatCompletionLogProbsContent(
479
- token=self.tokenizer.decode(token_id),
480
- bytes=list(
481
- self.tokenizer.decode(token_id).encode(
482
- "utf-8", errors="replace"))))
483
  else:
484
- logprobs_content.append(
485
- ChatCompletionLogProbsContent(
486
- token=step_top_logprobs[token_id].decoded_token,
487
- logprob=max(step_top_logprobs[token_id].logprob,
488
- -9999.0),
489
- bytes=list(
490
- step_top_logprobs[token_id].decoded_token.encode(
491
- "utf-8", errors="replace")),
492
- top_logprobs=self._get_top_logprobs(
493
- step_top_logprobs, num_output_top_logprobs)))
494
-
495
- return ChatCompletionLogProbs(content=logprobs_content)
serving_completion.py CHANGED
1
  import time
2
+ from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List,
 
3
  Optional)
4
  from typing import Sequence as GenericSequence
5
+ from typing import Tuple
6
 
7
  from fastapi import Request
 
8
 
9
  from vllm.config import ModelConfig
10
  from vllm.engine.async_llm_engine import AsyncLLMEngine
11
+ # yapf conflicts with isort for this block
12
+ # yapf: disable
13
+ from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
14
+ CompletionRequest,
15
+ CompletionResponse,
16
+ CompletionResponseChoice,
17
+ CompletionResponseStreamChoice,
18
+ CompletionStreamResponse,
19
+ DetokenizeRequest,
20
+ DetokenizeResponse,
21
+ TokenizeRequest,
22
+ TokenizeResponse, UsageInfo)
23
+ # yapf: enable
24
  from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
25
  OpenAIServing)
26
  from vllm.logger import init_logger
 
27
  from vllm.model_executor.guided_decoding import (
  get_guided_decoding_logits_processor)
29
  from vllm.outputs import RequestOutput
30
  from vllm.sequence import Logprob
31
+ from vllm.tracing import (contains_trace_headers, extract_trace_headers,
32
+ log_tracing_disabled_warning)
33
+ from vllm.utils import merge_async_iterators, random_uuid
34
 
35
  logger = init_logger(__name__)
36
 
37
+ TypeTokenIDs = List[int]
38
+ TypeTopLogProbs = List[Optional[Dict[int, float]]]
39
+ TypeCreateLogProbsFn = Callable[
40
+ [TypeTokenIDs, TypeTopLogProbs, Optional[int], int], CompletionLogProbs]
41
+
42
+
43
+ def parse_prompt_format(prompt) -> Tuple[bool, list]:
44
+ # get the prompt, openai supports the following
45
+ # "a string, array of strings, array of tokens, or array of token arrays."
46
+ prompt_is_tokens = False
47
+ prompts = [prompt] # case 1: a string
48
+ if isinstance(prompt, list):
49
+ if len(prompt) == 0:
50
+ raise ValueError("please provide at least one prompt")
51
+ elif isinstance(prompt[0], str):
52
+ prompt_is_tokens = False
53
+ prompts = prompt # case 2: array of strings
54
+ elif isinstance(prompt[0], int):
55
+ prompt_is_tokens = True
56
+ prompts = [prompt] # case 3: array of tokens
57
+ elif isinstance(prompt[0], list) and isinstance(prompt[0][0], int):
58
+ prompt_is_tokens = True
59
+ prompts = prompt # case 4: array of token arrays
60
+ else:
61
+ raise ValueError("prompt must be a string, array of strings, "
62
+ "array of tokens, or array of token arrays")
63
+ return prompt_is_tokens, prompts
64
 
65
 
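To illustrate the four accepted shapes, parse_prompt_format behaves roughly as follows (illustrative values only):

# (prompt_is_tokens, prompts) for each supported prompt shape
parse_prompt_format("Hello")               # (False, ["Hello"])
parse_prompt_format(["Hello", "world"])    # (False, ["Hello", "world"])
parse_prompt_format([1, 2, 3])             # (True,  [[1, 2, 3]])
parse_prompt_format([[1, 2], [3, 4]])      # (True,  [[1, 2], [3, 4]])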
66
+ class OpenAIServingCompletion(OpenAIServing):
67
 
68
+ def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig,
 
 
69
  served_model_names: List[str],
70
+ lora_modules: Optional[List[LoRAModulePath]]):
 
 
71
  super().__init__(engine=engine,
72
  model_config=model_config,
73
  served_model_names=served_model_names,
74
  lora_modules=lora_modules)
75
 
76
+ async def create_completion(self, request: CompletionRequest,
77
+ raw_request: Request):
78
  """Completion API similar to OpenAI's API.
79
 
80
+ See https://platform.openai.com/docs/api-reference/completions/create
81
+ for the API specification. This API mimics the OpenAI Completion API.
 
82
 
83
  NOTE: Currently we do not support the following feature:
84
+ - suffix (the language models we currently support do not support
85
+ suffix)
86
  """
87
  error_check_ret = await self._check_model(request)
88
  if error_check_ret is not None:
89
  return error_check_ret
90
 
91
+ # Return error for unsupported features.
92
+ if request.suffix is not None:
93
+ return self.create_error_response(
94
+ "suffix is not currently supported")
95
 
96
+ model_name = self.served_model_names[0]
97
  request_id = f"cmpl-{random_uuid()}"
98
+ created_time = int(time.time())
99
+
100
+ # Schedule the request and get the result generator.
101
+ generators: List[AsyncIterator[RequestOutput]] = []
102
  try:
103
  sampling_params = request.to_sampling_params()
104
  lora_request = self._maybe_get_lora(request)
105
  decoding_config = await self.engine.get_decoding_config()
106
  guided_decoding_backend = request.guided_decoding_backend \
107
  or decoding_config.guided_decoding_backend
108
+ guided_decode_logit_processor = (
109
  await get_guided_decoding_logits_processor(
110
  guided_decoding_backend, request, await
111
  self.engine.get_tokenizer()))
112
+ if guided_decode_logit_processor is not None:
113
  if sampling_params.logits_processors is None:
114
  sampling_params.logits_processors = []
115
  sampling_params.logits_processors.append(
116
+ guided_decode_logit_processor)
117
+ prompt_is_tokens, prompts = parse_prompt_format(request.prompt)
118
+
119
+ for i, prompt in enumerate(prompts):
120
+ if prompt_is_tokens:
121
+ prompt_formats = self._validate_prompt_and_tokenize(
122
+ request,
123
+ prompt_ids=prompt,
124
+ truncate_prompt_tokens=sampling_params.
125
+ truncate_prompt_tokens)
126
+ else:
127
+ prompt_formats = self._validate_prompt_and_tokenize(
128
+ request,
129
+ prompt=prompt,
130
+ truncate_prompt_tokens=sampling_params.
131
+ truncate_prompt_tokens)
132
+ prompt_ids, prompt_text = prompt_formats
133
+
134
+ is_tracing_enabled = await self.engine.is_tracing_enabled()
135
+ trace_headers = None
136
+ if is_tracing_enabled:
137
+ trace_headers = extract_trace_headers(raw_request.headers)
138
+ if not is_tracing_enabled and contains_trace_headers(
139
+ raw_request.headers):
140
+ log_tracing_disabled_warning()
141
+
142
+ generator = self.engine.generate(
143
+ {
144
+ "prompt": prompt_text,
145
+ "prompt_token_ids": prompt_ids
146
+ },
147
+ sampling_params,
148
+ f"{request_id}-{i}",
149
+ lora_request=lora_request,
150
+ trace_headers=trace_headers,
151
+ )
152
+
153
+ generators.append(generator)
154
  except ValueError as e:
155
+ # TODO: Use a vllm-specific Validation Error
156
  return self.create_error_response(str(e))
157
 
158
+ result_generator: AsyncIterator[Tuple[
159
+ int, RequestOutput]] = merge_async_iterators(*generators)
160
+
161
+ # Similar to the OpenAI API, when n != best_of, we do not stream the
162
+ # results. In addition, we do not stream the results when use
163
+ # beam search.
164
+ stream = (request.stream
165
+ and (request.best_of is None or request.n == request.best_of)
166
+ and not request.use_beam_search)
167
+
168
  # Streaming response
169
+ if stream:
170
+ return self.completion_stream_generator(request,
171
+ raw_request,
172
+ result_generator,
173
+ request_id,
174
+ created_time,
175
+ model_name,
176
+ num_prompts=len(prompts))
177
+
178
+ # Non-streaming response
179
+ final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts)
180
+ try:
181
+ async for i, res in result_generator:
182
+ if await raw_request.is_disconnected():
183
+ # Abort the request if the client disconnects.
184
+ await self.engine.abort(f"{request_id}-{i}")
185
+ return self.create_error_response("Client disconnected")
186
+ final_res_batch[i] = res
187
+ response = self.request_output_to_completion_response(
188
+ final_res_batch, request, request_id, created_time, model_name)
189
+ except ValueError as e:
190
+ # TODO: Use a vllm-specific Validation Error
191
+ return self.create_error_response(str(e))
192
+
193
+ # When user requests streaming but we don't stream, we still need to
194
+ # return a streaming response with a single event.
195
  if request.stream:
196
+ response_json = response.model_dump_json()
197
 
198
+ async def fake_stream_generator() -> AsyncGenerator[str, None]:
199
+ yield f"data: {response_json}\n\n"
200
+ yield "data: [DONE]\n\n"
201
 
202
+ return fake_stream_generator()
203
 
204
+ return response
 
205
 
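A minimal client-side sketch for exercising this handler over HTTP; the endpoint path, port, and model name are assumptions, not values taken from this repository:

import requests

# Hypothetical non-streaming completion request; adjust the URL to whatever
# route the server registers for completions.
resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={"model": "my-model", "prompt": "Say hello", "max_tokens": 16,
          "logprobs": 2},
)
print(resp.json()["choices"][0]["text"])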
206
+ async def completion_stream_generator(
207
+ self,
208
+ request: CompletionRequest,
209
+ raw_request: Request,
210
+ result_generator: AsyncIterator[Tuple[int, RequestOutput]],
211
+ request_id: str,
212
+ created_time: int,
213
+ model_name: str,
214
+ num_prompts: int,
215
+ ) -> AsyncGenerator[str, None]:
216
+ assert request.n is not None
217
+ previous_texts = [""] * request.n * num_prompts
218
+ previous_num_tokens = [0] * request.n * num_prompts
219
+ has_echoed = [False] * request.n * num_prompts
220
 
221
+ try:
222
+ async for prompt_idx, res in result_generator:
 
223
 
224
+ # Abort the request if the client disconnects.
225
+ if await raw_request.is_disconnected():
226
+ await self.engine.abort(f"{request_id}-{prompt_idx}")
227
+ raise StopAsyncIteration()
228
+
229
+ for output in res.outputs:
230
+ i = output.index + prompt_idx * request.n
231
+ # TODO(simon): optimize the performance by avoiding full
232
+ # text O(n^2) sending.
233
+
234
+ assert request.max_tokens is not None
235
+ if request.echo and request.max_tokens == 0:
236
+ # only return the prompt
237
+ delta_text = res.prompt
238
+ delta_token_ids = res.prompt_token_ids
239
+ out_logprobs = res.prompt_logprobs
240
+ has_echoed[i] = True
241
+ elif (request.echo and request.max_tokens > 0
242
+ and not has_echoed[i]):
243
+ # echo the prompt and first token
244
+ delta_text = res.prompt + output.text
245
+ delta_token_ids = (res.prompt_token_ids +
246
+ output.token_ids)
247
+ out_logprobs = res.prompt_logprobs + (output.logprobs
248
+ or [])
249
+ has_echoed[i] = True
250
+ else:
251
+ # return just the delta
252
+ delta_text = output.text[len(previous_texts[i]):]
253
+ delta_token_ids = output.token_ids[
254
+ previous_num_tokens[i]:]
255
+ out_logprobs = output.logprobs[previous_num_tokens[
256
+ i]:] if output.logprobs else None
257
+
258
+ if request.logprobs is not None:
259
+ assert out_logprobs is not None, (
260
+ "Did not output logprobs")
261
+ logprobs = self._create_completion_logprobs(
262
  token_ids=delta_token_ids,
263
+ top_logprobs=out_logprobs,
264
+ num_output_top_logprobs=request.logprobs,
265
+ initial_text_offset=len(previous_texts[i]),
266
  )
267
  else:
268
  logprobs = None
269
 
 
270
  previous_texts[i] = output.text
271
  previous_num_tokens[i] = len(output.token_ids)
272
+ finish_reason = output.finish_reason
273
+ stop_reason = output.stop_reason
274
+
275
+ chunk = CompletionStreamResponse(
276
+ id=request_id,
277
+ created=created_time,
278
+ model=model_name,
279
+ choices=[
280
+ CompletionResponseStreamChoice(
281
+ index=i,
282
+ text=delta_text,
283
+ logprobs=logprobs,
284
+ finish_reason=finish_reason,
285
+ stop_reason=stop_reason,
286
+ )
287
  ])
288
+ if (request.stream_options
289
+ and request.stream_options.include_usage):
290
+ if (request.stream_options.continuous_usage_stats
291
+ or output.finish_reason is not None):
292
+ prompt_tokens = len(res.prompt_token_ids)
293
+ completion_tokens = len(output.token_ids)
294
+ usage = UsageInfo(
295
+ prompt_tokens=prompt_tokens,
296
+ completion_tokens=completion_tokens,
297
+ total_tokens=prompt_tokens + completion_tokens,
298
+ )
299
+ if request.stream_options.continuous_usage_stats:
300
+ chunk.usage = usage
301
+ else:
302
+ chunk.usage = None
303
+
304
+ response_json = chunk.model_dump_json(exclude_unset=True)
305
+ yield f"data: {response_json}\n\n"
306
+
307
+ if (request.stream_options
308
+ and request.stream_options.include_usage):
309
+ final_usage_chunk = CompletionStreamResponse(
310
+ id=request_id,
311
+ created=created_time,
312
+ model=model_name,
313
+ choices=[],
314
+ usage=usage,
315
+ )
316
+ final_usage_data = (final_usage_chunk.model_dump_json(
317
+ exclude_unset=True, exclude_none=True))
318
+ yield f"data: {final_usage_data}\n\n"
319
+
 
320
  except ValueError as e:
321
  # TODO: Use a vllm-specific Validation Error
322
  data = self.create_streaming_error_response(str(e))
323
  yield f"data: {data}\n\n"
 
324
  yield "data: [DONE]\n\n"
325
 
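The generator above emits Server-Sent Events and ends with a "data: [DONE]" sentinel; a rough consumer sketch (endpoint and payload are assumptions) would parse it like this:

import json
import requests

# Hypothetical streaming client: read "data: <json>" lines until "[DONE]".
with requests.post("http://localhost:8000/v1/completions",
                   json={"model": "my-model", "prompt": "Hi",
                         "max_tokens": 16, "stream": True},
                   stream=True) as resp:
    for line in resp.iter_lines():
        if not line.startswith(b"data: "):
            continue
        payload = line[len(b"data: "):]
        if payload == b"[DONE]":
            break
        chunk = json.loads(payload)
        if chunk["choices"]:
            print(chunk["choices"][0]["text"], end="", flush=True)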
326
+ def request_output_to_completion_response(
327
+ self,
328
+ final_res_batch: List[RequestOutput],
329
+ request: CompletionRequest,
330
+ request_id: str,
331
+ created_time: int,
332
+ model_name: str,
333
+ ) -> CompletionResponse:
334
+ choices: List[CompletionResponseChoice] = []
335
+ num_prompt_tokens = 0
336
+ num_generated_tokens = 0
337
+ for final_res in final_res_batch:
338
+ assert final_res is not None
339
+ prompt_token_ids = final_res.prompt_token_ids
340
+ prompt_logprobs = final_res.prompt_logprobs
341
+ prompt_text = final_res.prompt
342
+
343
+ for output in final_res.outputs:
344
+ assert request.max_tokens is not None
345
+ if request.echo and request.max_tokens == 0:
346
+ token_ids = prompt_token_ids
347
+ out_logprobs = prompt_logprobs
348
+ output_text = prompt_text
349
+ elif request.echo and request.max_tokens > 0:
350
+ token_ids = prompt_token_ids + list(output.token_ids)
351
+ out_logprobs = (prompt_logprobs + output.logprobs
352
+ if request.logprobs is not None else None)
353
+ output_text = prompt_text + output.text
354
+ else:
355
+ token_ids = output.token_ids
356
+ out_logprobs = output.logprobs
357
+ output_text = output.text
358
+
359
+ if request.logprobs is not None:
360
+ assert out_logprobs is not None, "Did not output logprobs"
361
+ logprobs = self._create_completion_logprobs(
362
+ token_ids=token_ids,
363
+ top_logprobs=out_logprobs,
364
+ num_output_top_logprobs=request.logprobs,
365
+ )
366
+ else:
367
+ logprobs = None
368
+
369
+ choice_data = CompletionResponseChoice(
370
+ index=len(choices),
371
+ text=output_text,
372
+ logprobs=logprobs,
373
+ finish_reason=output.finish_reason,
374
+ stop_reason=output.stop_reason,
375
+ )
376
+ choices.append(choice_data)
377
 
378
+ num_prompt_tokens += len(prompt_token_ids)
379
+ num_generated_tokens += sum(
380
+ len(output.token_ids) for output in final_res.outputs)
381
 
382
  usage = UsageInfo(
383
  prompt_tokens=num_prompt_tokens,
384
  completion_tokens=num_generated_tokens,
385
  total_tokens=num_prompt_tokens + num_generated_tokens,
386
  )
387
+
388
+ return CompletionResponse(
389
  id=request_id,
390
  created=created_time,
391
  model=model_name,
392
  choices=choices,
393
  usage=usage,
394
  )
395
 
396
+ def _create_completion_logprobs(
397
  self,
398
  token_ids: GenericSequence[int],
399
  top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
400
+ num_output_top_logprobs: int,
401
+ initial_text_offset: int = 0,
402
+ ) -> CompletionLogProbs:
403
+ """Create logprobs for OpenAI Completion API."""
404
+ out_text_offset: List[int] = []
405
+ out_token_logprobs: List[Optional[float]] = []
406
+ out_tokens: List[str] = []
407
+ out_top_logprobs: List[Optional[Dict[str, float]]] = []
408
 
409
+ last_token_len = 0
410
 
411
  for i, token_id in enumerate(token_ids):
412
  step_top_logprobs = top_logprobs[i]
413
  if step_top_logprobs is None:
414
+ token = self.tokenizer.decode(token_id)
415
+ out_tokens.append(token)
416
+ out_token_logprobs.append(None)
417
+ out_top_logprobs.append(None)
418
+ else:
419
+ token = self._get_decoded_token(step_top_logprobs[token_id],
420
+ token_id)
421
+ token_logprob = max(step_top_logprobs[token_id].logprob,
422
+ -9999.0)
423
+ out_tokens.append(token)
424
+ out_token_logprobs.append(token_logprob)
425
+
426
+ # makes sure to add the top num_output_top_logprobs + 1
427
+ # logprobs, as defined in the openai API
428
+ # (cf. https://github.com/openai/openai-openapi/blob/
429
+ # 893ba52242dbd5387a97b96444ee1c742cfce9bd/openapi.yaml#L7153)
430
+ out_top_logprobs.append({
431
+ # Convert float("-inf") to the
432
+ # JSON-serializable float that OpenAI uses
433
+ self._get_decoded_token(top_lp[1], top_lp[0]):
434
+ max(top_lp[1].logprob, -9999.0)
435
+ for i, top_lp in enumerate(step_top_logprobs.items())
436
+ if num_output_top_logprobs >= i
437
+ })
438
+
439
+ if len(out_text_offset) == 0:
440
+ out_text_offset.append(initial_text_offset)
441
  else:
442
+ out_text_offset.append(out_text_offset[-1] + last_token_len)
443
+ last_token_len = len(token)
444
+
445
+ return CompletionLogProbs(
446
+ text_offset=out_text_offset,
447
+ token_logprobs=out_token_logprobs,
448
+ tokens=out_tokens,
449
+ top_logprobs=out_top_logprobs,
450
+ )
451
+
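The four lists assembled above are parallel, one entry per returned token. An illustrative (invented) value for a two-token completion:

# Illustrative shape only -- the numbers are made up, not real model output.
CompletionLogProbs(
    text_offset=[0, 5],                    # cumulative character offsets
    token_logprobs=[-0.12, -1.30],         # clamped below at -9999.0
    tokens=["Hello", " world"],
    top_logprobs=[{"Hello": -0.12, "Hi": -2.00},
                  {" world": -1.30, " there": -2.50}],
)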
452
+ async def create_tokenize(self,
453
+ request: TokenizeRequest) -> TokenizeResponse:
454
+ error_check_ret = await self._check_model(request)
455
+ if error_check_ret is not None:
456
+ return error_check_ret
457
+
458
+ (input_ids, input_text) = self._validate_prompt_and_tokenize(
459
+ request,
460
+ prompt=request.prompt,
461
+ add_special_tokens=request.add_special_tokens)
462
+
463
+ return TokenizeResponse(tokens=input_ids,
464
+ count=len(input_ids),
465
+ max_model_len=self.max_model_len)
466
+
467
+ async def create_detokenize(
468
+ self, request: DetokenizeRequest) -> DetokenizeResponse:
469
+ error_check_ret = await self._check_model(request)
470
+ if error_check_ret is not None:
471
+ return error_check_ret
472
+
473
+ (input_ids, input_text) = self._validate_prompt_and_tokenize(
474
+ request, prompt_ids=request.tokens)
475
+
476
+ return DetokenizeResponse(prompt=input_text)
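A rough sketch of driving the two new helpers directly from async code; the serving instance and model name are hypothetical, while the field names follow the request/response objects used above:

async def _tokenize_round_trip(serving: OpenAIServingCompletion) -> None:
    # Hypothetical round trip through the new tokenize/detokenize helpers.
    tok = await serving.create_tokenize(
        TokenizeRequest(model="my-model", prompt="Hello world",
                        add_special_tokens=True))
    detok = await serving.create_detokenize(
        DetokenizeRequest(model="my-model", tokens=tok.tokens))
    assert detok.prompt  # decodes back to text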
serving_embedding.py ADDED
@@ -0,0 +1,144 @@
1
+ import base64
2
+ import time
3
+ from typing import AsyncIterator, List, Optional, Tuple
4
+
5
+ import numpy as np
6
+ from fastapi import Request
7
+
8
+ from vllm.config import ModelConfig
9
+ from vllm.engine.async_llm_engine import AsyncLLMEngine
10
+ from vllm.entrypoints.openai.protocol import (EmbeddingRequest,
11
+ EmbeddingResponse,
12
+ EmbeddingResponseData, UsageInfo)
13
+ from vllm.entrypoints.openai.serving_completion import parse_prompt_format
14
+ from vllm.entrypoints.openai.serving_engine import OpenAIServing
15
+ from vllm.logger import init_logger
16
+ from vllm.outputs import EmbeddingRequestOutput
17
+ from vllm.utils import merge_async_iterators, random_uuid
18
+
19
+ logger = init_logger(__name__)
20
+
21
+ TypeTokenIDs = List[int]
22
+
23
+
24
+ def request_output_to_embedding_response(
25
+ final_res_batch: List[EmbeddingRequestOutput], request_id: str,
26
+ created_time: int, model_name: str,
27
+ encoding_format: str) -> EmbeddingResponse:
28
+ data: List[EmbeddingResponseData] = []
29
+ num_prompt_tokens = 0
30
+ for idx, final_res in enumerate(final_res_batch):
31
+ assert final_res is not None
32
+ prompt_token_ids = final_res.prompt_token_ids
33
+ embedding = final_res.outputs.embedding
34
+ if encoding_format == "base64":
35
+ embedding = base64.b64encode(np.array(embedding))
36
+ embedding_data = EmbeddingResponseData(index=idx, embedding=embedding)
37
+ data.append(embedding_data)
38
+
39
+ num_prompt_tokens += len(prompt_token_ids)
40
+
41
+ usage = UsageInfo(
42
+ prompt_tokens=num_prompt_tokens,
43
+ total_tokens=num_prompt_tokens,
44
+ )
45
+
46
+ return EmbeddingResponse(
47
+ id=request_id,
48
+ created=created_time,
49
+ model=model_name,
50
+ data=data,
51
+ usage=usage,
52
+ )
53
+
54
+
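Because the base64 branch above encodes the raw buffer of np.array(embedding), a client can recover the vector roughly as follows; the float64 dtype is an assumption based on np.array over a plain Python float list:

import base64
import numpy as np

def decode_embedding(b64_payload: str) -> np.ndarray:
    # Assumes the server built the buffer as float64 (NumPy's default for
    # Python floats); adjust the dtype if the engine returns another type.
    return np.frombuffer(base64.b64decode(b64_payload), dtype=np.float64)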
55
+ class OpenAIServingEmbedding(OpenAIServing):
56
+
57
+ def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig,
58
+ served_model_names: List[str]):
59
+ super().__init__(engine=engine,
60
+ model_config=model_config,
61
+ served_model_names=served_model_names,
62
+ lora_modules=None)
63
+ self._check_embedding_mode(model_config.embedding_mode)
64
+
65
+ async def create_embedding(self, request: EmbeddingRequest,
66
+ raw_request: Request):
67
+ """Embedding API similar to OpenAI's API.
68
+
69
+ See https://platform.openai.com/docs/api-reference/embeddings/create
70
+ for the API specification. This API mimics the OpenAI Embedding API.
71
+ """
72
+ error_check_ret = await self._check_model(request)
73
+ if error_check_ret is not None:
74
+ return error_check_ret
75
+
76
+ encoding_format = (request.encoding_format
77
+ if request.encoding_format else "float")
78
+ if request.dimensions is not None:
79
+ return self.create_error_response(
80
+ "dimensions is currently not supported")
81
+
82
+ model_name = request.model
83
+ request_id = f"cmpl-{random_uuid()}"
84
+ created_time = int(time.monotonic())
85
+
86
+ # Schedule the request and get the result generator.
87
+ generators = []
88
+ try:
89
+ prompt_is_tokens, prompts = parse_prompt_format(request.input)
90
+ pooling_params = request.to_pooling_params()
91
+
92
+ for i, prompt in enumerate(prompts):
93
+ if prompt_is_tokens:
94
+ prompt_formats = self._validate_prompt_and_tokenize(
95
+ request, prompt_ids=prompt)
96
+ else:
97
+ prompt_formats = self._validate_prompt_and_tokenize(
98
+ request, prompt=prompt)
99
+
100
+ prompt_ids, prompt_text = prompt_formats
101
+
102
+ generator = self.engine.encode(
103
+ {
104
+ "prompt": prompt_text,
105
+ "prompt_token_ids": prompt_ids
106
+ },
107
+ pooling_params,
108
+ f"{request_id}-{i}",
109
+ )
110
+
111
+ generators.append(generator)
112
+ except ValueError as e:
113
+ # TODO: Use a vllm-specific Validation Error
114
+ return self.create_error_response(str(e))
115
+
116
+ result_generator: AsyncIterator[Tuple[
117
+ int, EmbeddingRequestOutput]] = merge_async_iterators(*generators)
118
+
119
+ # Non-streaming response
120
+ final_res_batch: List[Optional[EmbeddingRequestOutput]]
121
+ final_res_batch = [None] * len(prompts)
122
+ try:
123
+ async for i, res in result_generator:
124
+ if await raw_request.is_disconnected():
125
+ # Abort the request if the client disconnects.
126
+ await self.engine.abort(f"{request_id}-{i}")
127
+ # TODO: Use a vllm-specific Validation Error
128
+ return self.create_error_response("Client disconnected")
129
+ final_res_batch[i] = res
130
+ response = request_output_to_embedding_response(
131
+ final_res_batch, request_id, created_time, model_name,
132
+ encoding_format)
133
+ except ValueError as e:
134
+ # TODO: Use a vllm-specific Validation Error
135
+ return self.create_error_response(str(e))
136
+
137
+ return response
138
+
139
+ def _check_embedding_mode(self, embedding_mode: bool):
140
+ if not embedding_mode:
141
+ logger.warning(
142
+ "embedding_mode is False. Embedding API will not work.")
143
+ else:
144
+ logger.info("Activating the server engine with embedding enabled.")
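A minimal request sketch against this handler over HTTP; the endpoint path and model name are assumptions. Passing "dimensions" would trigger the unsupported-parameter error above, and encoding_format falls back to "float" when omitted:

import requests

# Hypothetical embeddings call with two inputs.
resp = requests.post(
    "http://localhost:8000/v1/embeddings",
    json={"model": "my-embedding-model",
          "input": ["first sentence", "second sentence"],
          "encoding_format": "float"},
)
vectors = [item["embedding"] for item in resp.json()["data"]]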
serving_engine.py CHANGED
@@ -10,9 +10,10 @@ from vllm.config import ModelConfig
10
  from vllm.engine.async_llm_engine import AsyncLLMEngine
11
  from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
12
  CompletionRequest,
 
13
  EmbeddingRequest, ErrorResponse,
14
  ModelCard, ModelList,
15
- ModelPermission)
16
  from vllm.logger import init_logger
17
  from vllm.lora.request import LoRARequest
18
  from vllm.sequence import Logprob
@@ -35,6 +36,7 @@ class OpenAIServing:
35
  super().__init__()
36
 
37
  self.engine = engine
 
38
  self.max_model_len = model_config.max_model_len
39
 
40
  # A separate tokenizer to map token IDs to strings.
@@ -99,8 +101,9 @@ class OpenAIServing:
99
  return json_str
100
 
101
  async def _check_model(
102
- self, request: Union[CompletionRequest, ChatCompletionRequest,
103
- EmbeddingRequest]
 
104
  ) -> Optional[ErrorResponse]:
105
  if request.model in self.served_model_names:
106
  return None
@@ -126,7 +129,8 @@ class OpenAIServing:
126
  def _validate_prompt_and_tokenize(
127
  self,
128
  request: Union[ChatCompletionRequest, CompletionRequest,
129
- EmbeddingRequest],
 
130
  prompt: Optional[str] = None,
131
  prompt_ids: Optional[List[int]] = None,
132
  truncate_prompt_tokens: Optional[Annotated[int,
@@ -174,6 +178,11 @@ class OpenAIServing:
174
  f"generation. Please reduce the length of the input.", )
175
  return input_ids, input_text
176
 
177
  if request.max_tokens is None:
178
  if token_num >= self.max_model_len:
179
  raise ValueError(
 
10
  from vllm.engine.async_llm_engine import AsyncLLMEngine
11
  from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
12
  CompletionRequest,
13
+ DetokenizeRequest,
14
  EmbeddingRequest, ErrorResponse,
15
  ModelCard, ModelList,
16
+ ModelPermission, TokenizeRequest)
17
  from vllm.logger import init_logger
18
  from vllm.lora.request import LoRARequest
19
  from vllm.sequence import Logprob
 
36
  super().__init__()
37
 
38
  self.engine = engine
39
+ self.model_config = model_config
40
  self.max_model_len = model_config.max_model_len
41
 
42
  # A separate tokenizer to map token IDs to strings.
 
101
  return json_str
102
 
103
  async def _check_model(
104
+ self, request: Union[ChatCompletionRequest, CompletionRequest,
105
+ DetokenizeRequest, EmbeddingRequest,
106
+ TokenizeRequest]
107
  ) -> Optional[ErrorResponse]:
108
  if request.model in self.served_model_names:
109
  return None
 
129
  def _validate_prompt_and_tokenize(
130
  self,
131
  request: Union[ChatCompletionRequest, CompletionRequest,
132
+ DetokenizeRequest, EmbeddingRequest,
133
+ TokenizeRequest],
134
  prompt: Optional[str] = None,
135
  prompt_ids: Optional[List[int]] = None,
136
  truncate_prompt_tokens: Optional[Annotated[int,
 
178
  f"generation. Please reduce the length of the input.", )
179
  return input_ids, input_text
180
 
181
+ # Note: TokenizeRequest and DetokenizeRequest don't have max_tokens
182
+ # and do not require model context length validation
183
+ if isinstance(request, (TokenizeRequest, DetokenizeRequest)):
184
+ return input_ids, input_text
185
+
186
  if request.max_tokens is None:
187
  if token_num >= self.max_model_len:
188
  raise ValueError(