update some code to comply with 0.5.1
- api_server.py +39 -8
- protocol.py +112 -21
- serving_chat.py +518 -158
- serving_completion.py +378 -397
- serving_embedding.py +144 -0
- serving_engine.py +13 -4
api_server.py
CHANGED
@@ -15,20 +15,27 @@ from fastapi.responses import JSONResponse, Response, StreamingResponse
 from prometheus_client import make_asgi_app
 from starlette.routing import Mount
 
-import vllm
 import vllm.envs as envs
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.entrypoints.openai.cli_args import make_arg_parser
+# yapf conflicts with isort for this block
+# yapf: disable
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               ChatCompletionResponse,
                                               CompletionRequest,
+                                              DetokenizeRequest,
+                                              DetokenizeResponse,
+                                              EmbeddingRequest, ErrorResponse,
+                                              TokenizeRequest,
+                                              TokenizeResponse)
+# yapf: enable
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
 from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
 from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
 from vllm.logger import init_logger
 from vllm.usage.usage_lib import UsageContext
+from vllm.version import __version__ as VLLM_VERSION
 
 TIMEOUT_KEEP_ALIVE = 5  # seconds
 
@@ -59,6 +66,7 @@ async def lifespan(app: fastapi.FastAPI):
 
 app = fastapi.FastAPI(lifespan=lifespan)
 
+
 def parse_args():
     parser = make_arg_parser()
     return parser.parse_args()
@@ -84,7 +92,29 @@ async def health() -> Response:
     return Response(status_code=200)
 
 
-@app.
+@app.post("/tokenize")
+async def tokenize(request: TokenizeRequest):
+    generator = await openai_serving_completion.create_tokenize(request)
+    if isinstance(generator, ErrorResponse):
+        return JSONResponse(content=generator.model_dump(),
+                            status_code=generator.code)
+    else:
+        assert isinstance(generator, TokenizeResponse)
+        return JSONResponse(content=generator.model_dump())
+
+
+@app.post("/detokenize")
+async def detokenize(request: DetokenizeRequest):
+    generator = await openai_serving_completion.create_detokenize(request)
+    if isinstance(generator, ErrorResponse):
+        return JSONResponse(content=generator.model_dump(),
+                            status_code=generator.code)
+    else:
+        assert isinstance(generator, DetokenizeResponse)
+        return JSONResponse(content=generator.model_dump())
+
+
+@app.get("/v1/models")
 async def show_available_models():
     models = await openai_serving_chat.show_available_models()
     return JSONResponse(content=models.model_dump())
@@ -92,11 +122,11 @@ async def show_available_models():
 
 @app.get("/version")
 async def show_version():
-    ver = {"version":
+    ver = {"version": VLLM_VERSION}
    return JSONResponse(content=ver)
 
 
-@app.post("/
+@app.post("/v1/chat/completions")
 async def create_chat_completion(request: ChatCompletionRequest,
                                  raw_request: Request):
     generator = await openai_serving_chat.create_chat_completion(
@@ -112,7 +142,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
         return JSONResponse(content=generator.model_dump())
 
 
-@app.post("/
+@app.post("/v1/completions")
 async def create_completion(request: CompletionRequest, raw_request: Request):
     generator = await openai_serving_completion.create_completion(
         request, raw_request)
@@ -126,7 +156,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
         return JSONResponse(content=generator.model_dump())
 
 
-@app.post("/
+@app.post("/v1/embeddings")
 async def create_embedding(request: EmbeddingRequest, raw_request: Request):
     generator = await openai_serving_embedding.create_embedding(
         request, raw_request)
@@ -173,7 +203,7 @@ if __name__ == "__main__":
             raise ValueError(f"Invalid middleware {middleware}. "
                              f"Must be a function or a class.")
 
-    logger.info("vLLM API server version %s",
+    logger.info("vLLM API server version %s", VLLM_VERSION)
     logger.info("args: %s", args)
 
     if args.served_model_name is not None:
@@ -182,6 +212,7 @@ if __name__ == "__main__":
         served_model_names = [args.model]
 
     engine_args = AsyncEngineArgs.from_cli_args(args)
+
    engine = AsyncLLMEngine.from_engine_args(
        engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
 
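For reference, a minimal sketch of exercising the new /tokenize and /detokenize routes added above. The base URL and model name are placeholders (not values from this commit); only the request/response fields defined in protocol.py are assumed.

import requests

BASE_URL = "http://localhost:8000"  # assumed local deployment
MODEL = "my-model"                  # replace with the served model name

# Tokenize a prompt via the new /tokenize route.
tok = requests.post(f"{BASE_URL}/tokenize",
                    json={"model": MODEL,
                          "prompt": "Hello, world!",
                          "add_special_tokens": True}).json()
print(tok["tokens"], tok["count"], tok["max_model_len"])

# Round-trip the token ids back to text via /detokenize.
detok = requests.post(f"{BASE_URL}/detokenize",
                      json={"model": MODEL,
                            "tokens": tok["tokens"]}).json()
print(detok["prompt"])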
protocol.py
CHANGED
@@ -102,6 +102,11 @@ class ResponseFormat(OpenAIBaseModel):
     type: Literal["text", "json_object"]
 
 
+class StreamOptions(OpenAIBaseModel):
+    include_usage: Optional[bool] = True
+    continuous_usage_stats: Optional[bool] = True
+
+
 class FunctionDefinition(OpenAIBaseModel):
     name: str
     description: Optional[str] = None
@@ -140,6 +145,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
                          le=torch.iinfo(torch.long).max)
     stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
     stream: Optional[bool] = False
+    stream_options: Optional[StreamOptions] = None
     temperature: Optional[float] = 0.7
     top_p: Optional[float] = 1.0
     tools: Optional[List[ChatCompletionToolsParam]] = None
@@ -185,6 +191,27 @@ class ChatCompletionRequest(OpenAIBaseModel):
                  "special tokens so this should be set to False (as is the "
                  "default)."),
     )
+    documents: Optional[List[Dict[str, str]]] = Field(
+        default=None,
+        description=
+        ("A list of dicts representing documents that will be accessible to "
+         "the model if it is performing RAG (retrieval-augmented generation)."
+         " If the template does not support RAG, this argument will have no "
+         "effect. We recommend that each document should be a dict containing "
+         "\"title\" and \"text\" keys."),
+    )
+    chat_template: Optional[str] = Field(
+        default=None,
+        description=(
+            "A Jinja template to use for this conversion. "
+            "If this is not passed, the model's default chat template will be "
+            "used instead."),
+    )
+    chat_template_kwargs: Optional[Dict[str, Any]] = Field(
+        default=None,
+        description=("Additional kwargs to pass to the template renderer. "
+                     "Will be accessible by the chat template."),
+    )
     include_stop_str_in_output: Optional[bool] = Field(
         default=False,
         description=(
@@ -229,15 +256,22 @@ class ChatCompletionRequest(OpenAIBaseModel):
 
         logits_processors = None
         if self.logit_bias:
+            logit_bias: Dict[int, float] = {}
+            try:
+                for token_id, bias in self.logit_bias.items():
+                    # Convert token_id to integer before we add to LLMEngine
+                    # Clamp the bias between -100 and 100 per OpenAI API spec
+                    logit_bias[int(token_id)] = min(100, max(-100, bias))
+            except ValueError as exc:
+                raise ValueError(f"Found token_id `{token_id}` in logit_bias "
+                                 f"but token_id must be an integer or string "
+                                 f"representing an integer") from exc
 
             def logit_bias_logits_processor(
                     token_ids: List[int],
                     logits: torch.Tensor) -> torch.Tensor:
-                # Clamp the bias between -100 and 100 per OpenAI API spec
-                bias = min(100, max(-100, bias))
-                logits[int(token_id)] += bias
+                for token_id, bias in logit_bias.items():
+                    logits[token_id] += bias
                 return logits
 
             logits_processors = [logit_bias_logits_processor]
@@ -269,6 +303,15 @@ class ChatCompletionRequest(OpenAIBaseModel):
             logits_processors=logits_processors,
         )
 
+    @model_validator(mode='before')
+    @classmethod
+    def validate_stream_options(cls, values):
+        if (values.get('stream_options') is not None
+                and not values.get('stream')):
+            raise ValueError(
+                "stream_options can only be set if stream is true")
+        return values
+
     @model_validator(mode="before")
     @classmethod
     def check_guided_decoding_count(cls, data):
@@ -308,9 +351,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
             raise ValueError(
                 "when using `top_logprobs`, `logprobs` must be set to true."
             )
-        elif
+        elif data["top_logprobs"] < 0:
             raise ValueError(
-                "`top_logprobs` must be a value
+                "`top_logprobs` must be a value a positive value.")
         return data
 
 
@@ -332,6 +375,7 @@ class CompletionRequest(OpenAIBaseModel):
                          le=torch.iinfo(torch.long).max)
     stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
     stream: Optional[bool] = False
+    stream_options: Optional[StreamOptions] = None
     suffix: Optional[str] = None
     temperature: Optional[float] = 1.0
     top_p: Optional[float] = 1.0
@@ -404,15 +448,22 @@ class CompletionRequest(OpenAIBaseModel):
 
         logits_processors = None
         if self.logit_bias:
+            logit_bias: Dict[int, float] = {}
+            try:
+                for token_id, bias in self.logit_bias.items():
+                    # Convert token_id to integer
+                    # Clamp the bias between -100 and 100 per OpenAI API spec
+                    logit_bias[int(token_id)] = min(100, max(-100, bias))
+            except ValueError as exc:
+                raise ValueError(f"Found token_id `{token_id}` in logit_bias "
+                                 f"but token_id must be an integer or string "
+                                 f"representing an integer") from exc
 
             def logit_bias_logits_processor(
                     token_ids: List[int],
                     logits: torch.Tensor) -> torch.Tensor:
-                # Clamp the bias between -100 and 100 per OpenAI API spec
-                bias = min(100, max(-100, bias))
-                logits[int(token_id)] += bias
+                for token_id, bias in logit_bias.items():
+                    logits[token_id] += bias
                 return logits
 
             logits_processors = [logit_bias_logits_processor]
@@ -463,9 +514,16 @@ class CompletionRequest(OpenAIBaseModel):
     @classmethod
     def check_logprobs(cls, data):
         if "logprobs" in data and data[
-                "logprobs"] is not None and not
-            raise ValueError(
+                "logprobs"] is not None and not data["logprobs"] >= 0:
+            raise ValueError("if passed, `logprobs` must be a positive value.")
+        return data
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_stream_options(cls, data):
+        if data.get("stream_options") and not data.get("stream"):
+            raise ValueError(
+                "Stream options can only be defined when stream is True.")
         return data
 
 
@@ -491,7 +549,8 @@ class CompletionLogProbs(OpenAIBaseModel):
     text_offset: List[int] = Field(default_factory=list)
     token_logprobs: List[Optional[float]] = Field(default_factory=list)
     tokens: List[str] = Field(default_factory=list)
-    top_logprobs:
+    top_logprobs: List[Optional[Dict[str,
+                                     float]]] = Field(default_factory=list)
 
 
 class CompletionResponseChoice(OpenAIBaseModel):
@@ -543,7 +602,7 @@ class CompletionStreamResponse(OpenAIBaseModel):
 class EmbeddingResponseData(BaseModel):
     index: int
     object: str = "embedding"
-    embedding: List[float]
+    embedding: Union[List[float], str]
 
 
 class EmbeddingResponse(BaseModel):
@@ -590,7 +649,7 @@ class ChatCompletionResponseChoice(OpenAIBaseModel):
     index: int
     message: ChatMessage
     logprobs: Optional[ChatCompletionLogProbs] = None
-    finish_reason: Optional[
+    finish_reason: Optional[str] = None
     stop_reason: Optional[Union[int, str]] = None
 
 
@@ -613,7 +672,7 @@ class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
     index: int
     delta: DeltaMessage
     logprobs: Optional[ChatCompletionLogProbs] = None
-    finish_reason: Optional[
+    finish_reason: Optional[str] = None
     stop_reason: Optional[Union[int, str]] = None
 
 
@@ -649,6 +708,17 @@ class BatchRequestInput(OpenAIBaseModel):
     body: Union[ChatCompletionRequest, ]
 
 
+class BatchResponseData(OpenAIBaseModel):
+    # HTTP status code of the response.
+    status_code: int = 200
+
+    # An unique identifier for the API request.
+    request_id: str
+
+    # The body of the response.
+    body: Union[ChatCompletionResponse, ]
+
+
 class BatchRequestOutput(OpenAIBaseModel):
     """
     The per-line object of the batch output and error files
@@ -660,8 +730,29 @@ class BatchRequestOutput(OpenAIBaseModel):
     # inputs.
     custom_id: str
 
-    response: Optional[
+    response: Optional[BatchResponseData]
 
     # For requests that failed with a non-HTTP error, this will contain more
     # information on the cause of the failure.
-    error: Optional[Any]
+    error: Optional[Any]
+
+
+class TokenizeRequest(OpenAIBaseModel):
+    model: str
+    prompt: str
+    add_special_tokens: bool = Field(default=True)
+
+
+class TokenizeResponse(OpenAIBaseModel):
+    tokens: List[int]
+    count: int
+    max_model_len: int
+
+
+class DetokenizeRequest(OpenAIBaseModel):
+    model: str
+    tokens: List[int]
+
+
+class DetokenizeResponse(OpenAIBaseModel):
+    prompt: str
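A minimal sketch of using the new stream_options field from a client, assuming a local server and a placeholder model name. Per the validators above, stream_options is only accepted together with stream=True; with include_usage, ordinary chunks carry usage=None and a final usage-only chunk (with empty choices) reports token counts.

import json
import requests

BASE_URL = "http://localhost:8000"  # assumed local deployment
MODEL = "my-model"                  # replace with the served model name

payload = {
    "model": MODEL,
    "messages": [{"role": "user", "content": "Say hi"}],
    "stream": True,
    # Rejected by validate_stream_options if "stream" is not set to true.
    "stream_options": {"include_usage": True},
}

with requests.post(f"{BASE_URL}/v1/chat/completions",
                   json=payload, stream=True) as resp:
    for line in resp.iter_lines():
        if not line or not line.startswith(b"data: "):
            continue
        data = line[len(b"data: "):]
        if data == b"[DONE]":
            break
        chunk = json.loads(data)
        # Regular chunks: usage is None; final chunk: choices=[] and usage set.
        print(chunk.get("choices"), chunk.get("usage"))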
serving_chat.py
CHANGED
@@ -1,224 +1,558 @@
-import time
 import codecs
+import time
+from dataclasses import dataclass, field
+from functools import cached_property
+from typing import (AsyncGenerator, AsyncIterator, Awaitable, Dict, Iterable,
+                    List, Optional)
+from typing import Sequence as GenericSequence
+from typing import TypedDict, Union, cast, final
+
 from fastapi import Request
-from
+from openai.types.chat import (ChatCompletionContentPartImageParam,
+                               ChatCompletionContentPartTextParam)
+
+from vllm.config import ModelConfig
 from vllm.engine.async_llm_engine import AsyncLLMEngine
-from protocol import (
+from vllm.entrypoints.openai.protocol import (
+    ChatCompletionContentPartParam, ChatCompletionLogProb,
+    ChatCompletionLogProbs, ChatCompletionLogProbsContent,
+    ChatCompletionMessageParam, ChatCompletionNamedToolChoiceParam,
     ChatCompletionRequest, ChatCompletionResponse,
     ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
     ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
-    UsageInfo)
+    FunctionCall, ToolCall, UsageInfo)
+from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
+                                                    OpenAIServing)
+from vllm.inputs import PromptInputs
+from vllm.logger import init_logger
+from vllm.model_executor.guided_decoding import (
+    get_guided_decoding_logits_processor)
+from vllm.multimodal import MultiModalDataDict
+from vllm.multimodal.utils import async_get_and_parse_image
 from vllm.outputs import RequestOutput
-from
+from vllm.sequence import Logprob
+from vllm.tracing import (contains_trace_headers, extract_trace_headers,
+                          log_tracing_disabled_warning)
+from vllm.utils import random_uuid
 
 logger = init_logger(__name__)
 
 
+@final  # So that it should be compatible with Dict[str, str]
+class ConversationMessage(TypedDict):
+    role: str
+    content: str
+
+
+@dataclass(frozen=True)
+class ChatMessageParseResult:
+    messages: List[ConversationMessage]
+    mm_futures: List[Awaitable[MultiModalDataDict]] = field(
+        default_factory=list)
+
+
 class OpenAIServingChat(OpenAIServing):
 
     def __init__(self,
                  engine: AsyncLLMEngine,
+                 model_config: ModelConfig,
+                 served_model_names: List[str],
                  response_role: str,
+                 lora_modules: Optional[List[LoRAModulePath]] = None,
+                 chat_template: Optional[str] = None):
+        super().__init__(engine=engine,
+                         model_config=model_config,
+                         served_model_names=served_model_names,
+                         lora_modules=lora_modules)
+
         self.response_role = response_role
         self._load_chat_template(chat_template)
 
+    def _load_chat_template(self, chat_template: Optional[str]):
+        tokenizer = self.tokenizer
+
+        if chat_template is not None:
+            try:
+                with open(chat_template, "r") as f:
+                    tokenizer.chat_template = f.read()
+            except OSError as e:
+                JINJA_CHARS = "{}\n"
+                if not any(c in chat_template for c in JINJA_CHARS):
+                    msg = (f"The supplied chat template ({chat_template}) "
+                           f"looks like a file path, but it failed to be "
+                           f"opened. Reason: {e}")
+                    raise ValueError(msg) from e
+
+                # If opening a file fails, set chat template to be args to
+                # ensure we decode so our escape are interpreted correctly
+                tokenizer.chat_template = codecs.decode(
+                    chat_template, "unicode_escape")
+
+            logger.info("Using supplied chat template:\n%s",
+                        tokenizer.chat_template)
+        elif tokenizer.chat_template is not None:
+            logger.info("Using default chat template:\n%s",
+                        tokenizer.chat_template)
+        else:
+            logger.warning(
+                "No chat template provided. Chat API will not work.")
+
+    @cached_property
+    def image_token_str(self) -> Optional[str]:
+        # TODO: Let user specify how to insert image tokens into prompt
+        # (similar to chat template)
+        model_type = self.model_config.hf_config.model_type
+        if model_type == "phi3_v":
+            # Workaround since this token is not defined in the tokenizer
+            return "<|image_1|>"
+        if model_type in ("blip-2", "chatglm", "fuyu", "minicpmv",
+                          "paligemma"):
+            # These models do not use image tokens in the prompt
+            return None
+        if model_type.startswith("llava"):
+            return self.tokenizer.decode(
+                self.model_config.hf_config.image_token_index)
+
+        else:
+            raise TypeError("Unknown model type: {model_type}")
+
+    # TODO: Let user specify how to insert image tokens into prompt
+    # (similar to chat template)
+    def _get_full_image_text_prompt(self, image_token_str: str,
+                                    text_prompt: str) -> str:
+        """Combine image and text prompts for vision language model"""
+
+        # NOTE: For now we assume all model architectures use the same
+        # image + text prompt format. This may change in the future.
+        return f"{image_token_str}\n{text_prompt}"
+
+    def _parse_chat_message_content_parts(
+        self,
+        role: str,
+        parts: Iterable[ChatCompletionContentPartParam],
+    ) -> ChatMessageParseResult:
+        texts: List[str] = []
+        mm_futures: List[Awaitable[MultiModalDataDict]] = []
+
+        for part in parts:
+            part_type = part["type"]
+            if part_type == "text":
+                text = cast(ChatCompletionContentPartTextParam, part)["text"]
+                texts.append(text)
+            elif part_type == "image_url":
+                if len(mm_futures) > 0:
+                    raise NotImplementedError(
+                        "Multiple 'image_url' input is currently not supported."
+                    )
+
+                image_url = cast(ChatCompletionContentPartImageParam,
+                                 part)["image_url"]
+
+                if image_url.get("detail", "auto") != "auto":
+                    logger.warning(
+                        "'image_url.detail' is currently not supported and "
+                        "will be ignored.")
+
+                image_future = async_get_and_parse_image(image_url["url"])
+                mm_futures.append(image_future)
+            else:
+                raise NotImplementedError(f"Unknown part type: {part_type}")
+
+        text_prompt = "\n".join(texts)
+
+        if mm_futures:
+            image_token_str = self.image_token_str
+            if image_token_str is not None:
+                if image_token_str in text_prompt:
+                    logger.warning(
+                        "Detected image token string in the text prompt. "
+                        "Skipping prompt formatting.")
+                else:
+                    text_prompt = self._get_full_image_text_prompt(
+                        image_token_str=image_token_str,
+                        text_prompt=text_prompt,
+                    )
+
+        messages = [ConversationMessage(role=role, content=text_prompt)]
+
+        return ChatMessageParseResult(messages=messages, mm_futures=mm_futures)
+
+    def _parse_chat_message_content(
+        self,
+        message: ChatCompletionMessageParam,
+    ) -> ChatMessageParseResult:
+        role = message["role"]
+        content = message.get("content")
+
+        if content is None:
+            return ChatMessageParseResult(messages=[], mm_futures=[])
+        if isinstance(content, str):
+            messages = [ConversationMessage(role=role, content=content)]
+            return ChatMessageParseResult(messages=messages, mm_futures=[])
+
+        return self._parse_chat_message_content_parts(role, content)
+
     async def create_chat_completion(
+        self,
+        request: ChatCompletionRequest,
+        raw_request: Optional[Request] = None
     ) -> Union[ErrorResponse, AsyncGenerator[str, None],
                ChatCompletionResponse]:
         """Completion API similar to OpenAI's API.
 
-        See
-        for the API specification. This API mimics the OpenAI
+        See https://platform.openai.com/docs/api-reference/chat/create
+        for the API specification. This API mimics the OpenAI
+        ChatCompletion API.
 
-        NOTE: Currently we do not support the following
+        NOTE: Currently we do not support the following feature:
             - function_call (Users should implement this by themselves)
-            - logit_bias (to be supported by vLLM engine)
         """
         error_check_ret = await self._check_model(request)
         if error_check_ret is not None:
             return error_check_ret
 
-        if request.logit_bias is not None and len(request.logit_bias) > 0:
-            # TODO: support logit_bias in vLLM engine.
-            return self.create_error_response(
-                "logit_bias is not currently supported")
-
         try:
+            conversation: List[ConversationMessage] = []
+            mm_futures: List[Awaitable[MultiModalDataDict]] = []
+
+            for msg in request.messages:
+                chat_parsed_result = self._parse_chat_message_content(msg)
+
+                conversation.extend(chat_parsed_result.messages)
+                mm_futures.extend(chat_parsed_result.mm_futures)
+
+            tool_dicts = None if request.tools is None else [
+                tool.model_dump() for tool in request.tools
+            ]
+
             prompt = self.tokenizer.apply_chat_template(
-                conversation=
+                conversation=conversation,
                 tokenize=False,
-                add_generation_prompt=request.add_generation_prompt
+                add_generation_prompt=request.add_generation_prompt,
+                tools=tool_dicts,
+                documents=request.documents,
+                chat_template=request.chat_template,
+                **(request.chat_template_kwargs or {}),
+            )
         except Exception as e:
-            logger.error(
+            logger.error("Error in applying chat template from request: %s", e)
+            return self.create_error_response(str(e))
+
+        mm_data: Optional[MultiModalDataDict] = None
+        try:
+            if len(mm_futures):
+                # since we support only single mm data currently
+                assert len(
+                    mm_futures
+                ) == 1, "Multiple 'image_url' input is currently not supported."
+                mm_data = await mm_futures[0]
+        except Exception as e:
+            logger.error("Error in loading multi-modal data: %s", e)
             return self.create_error_response(str(e))
 
         request_id = f"cmpl-{random_uuid()}"
         try:
+            # Tokenize/detokenize depending on prompt format (string/token list)
+            prompt_ids, prompt_text = self._validate_prompt_and_tokenize(
+                request,
+                prompt=prompt,
+                add_special_tokens=request.add_special_tokens)
             sampling_params = request.to_sampling_params()
+            lora_request = self._maybe_get_lora(request)
+            decoding_config = await self.engine.get_decoding_config()
+            guided_decoding_backend = request.guided_decoding_backend \
+                or decoding_config.guided_decoding_backend
+            guided_decode_logits_processor = (
+                await get_guided_decoding_logits_processor(
+                    guided_decoding_backend, request, await
+                    self.engine.get_tokenizer()))
+            if guided_decode_logits_processor:
+                if sampling_params.logits_processors is None:
+                    sampling_params.logits_processors = []
+                sampling_params.logits_processors.append(
+                    guided_decode_logits_processor)
         except ValueError as e:
             return self.create_error_response(str(e))
 
+        inputs: PromptInputs = {
+            "prompt": prompt_text,
+            "prompt_token_ids": prompt_ids,
+        }
+        if mm_data:
+            inputs["multi_modal_data"] = mm_data
+
+        is_tracing_enabled = await self.engine.is_tracing_enabled()
+        trace_headers = None
+        if is_tracing_enabled and raw_request:
+            trace_headers = extract_trace_headers(raw_request.headers)
+        if not is_tracing_enabled and raw_request and contains_trace_headers(
+                raw_request.headers):
+            log_tracing_disabled_warning()
+
+        result_generator = self.engine.generate(
+            inputs,
+            sampling_params,
+            request_id,
+            lora_request,
+            trace_headers=trace_headers,
+        )
         # Streaming response
         if request.stream:
             return self.chat_completion_stream_generator(
-                request, result_generator, request_id)
+                request, result_generator, request_id, conversation)
         else:
+            try:
+                return await self.chat_completion_full_generator(
+                    request, raw_request, result_generator, request_id,
+                    conversation)
+            except ValueError as e:
+                # TODO: Use a vllm-specific Validation Error
+                return self.create_error_response(str(e))
 
     def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
         if request.add_generation_prompt:
             return self.response_role
         else:
-            return request.messages[-1]
+            return request.messages[-1]["role"]
 
     async def chat_completion_stream_generator(
             self, request: ChatCompletionRequest,
-            result_generator: AsyncIterator[RequestOutput], request_id: str
-        model_name =
-        created_time = int(time.
+            result_generator: AsyncIterator[RequestOutput], request_id: str,
+            conversation: List[ConversationMessage]
+    ) -> AsyncGenerator[str, None]:
+        model_name = self.served_model_names[0]
+        created_time = int(time.time())
         chunk_object_type = "chat.completion.chunk"
+        first_iteration = True
 
         # Send response for each token for each request.n (index)
+        assert request.n is not None
         previous_texts = [""] * request.n
         previous_num_tokens = [0] * request.n
         finish_reason_sent = [False] * request.n
+        try:
+            async for res in result_generator:
+                # We need to do it here, because if there are exceptions in
+                # the result_generator, it needs to be sent as the FIRST
+                # response (by the try...catch).
+                if first_iteration:
+                    # Send first response for each request.n (index) with
+                    # the role
+                    role = self.get_chat_request_role(request)
+                    for i in range(request.n):
+                        choice_data = ChatCompletionResponseStreamChoice(
+                            index=i,
+                            delta=DeltaMessage(role=role),
+                            logprobs=None,
+                            finish_reason=None)
+                        chunk = ChatCompletionStreamResponse(
+                            id=request_id,
+                            object=chunk_object_type,
+                            created=created_time,
+                            choices=[choice_data],
+                            model=model_name)
+                        if (request.stream_options
+                                and request.stream_options.include_usage):
+                            chunk.usage = None
+                        data = chunk.model_dump_json(exclude_unset=True)
+                        yield f"data: {data}\n\n"
+
+                    # Send response to echo the input portion of the
+                    # last message
+                    if request.echo:
+                        last_msg_content = ""
+                        if conversation and conversation[-1].get(
+                                "content") and conversation[-1].get(
+                                    "role") == role:
+                            last_msg_content = conversation[-1]["content"]
+
+                        if last_msg_content:
+                            for i in range(request.n):
+                                choice_data = (
+                                    ChatCompletionResponseStreamChoice(
+                                        index=i,
+                                        delta=DeltaMessage(
+                                            content=last_msg_content),
+                                        finish_reason=None))
+                                chunk = ChatCompletionStreamResponse(
+                                    id=request_id,
+                                    object=chunk_object_type,
+                                    created=created_time,
+                                    choices=[choice_data],
+                                    logprobs=None,
+                                    model=model_name)
+                                if (request.stream_options and
+                                        request.stream_options.include_usage):
+                                    chunk.usage = None
+                                data = chunk.model_dump_json(
+                                    exclude_unset=True)
+                                yield f"data: {data}\n\n"
+                    first_iteration = False
+
+                for output in res.outputs:
+                    i = output.index
+
+                    if finish_reason_sent[i]:
+                        continue
+
+                    delta_token_ids = output.token_ids[previous_num_tokens[i]:]
+                    out_logprobs = output.logprobs[
+                        previous_num_tokens[i]:] if output.logprobs else None
+
+                    if request.logprobs and request.top_logprobs is not None:
+                        assert out_logprobs is not None, (
+                            "Did not output logprobs")
+                        logprobs = self._create_chat_logprobs(
+                            token_ids=delta_token_ids,
+                            top_logprobs=out_logprobs,
+                            num_output_top_logprobs=request.top_logprobs,
+                        )
+                    else:
+                        logprobs = None
+
+                    delta_text = output.text[len(previous_texts[i]):]
+                    previous_texts[i] = output.text
+                    previous_num_tokens[i] = len(output.token_ids)
+
+                    if request.tool_choice and type(
+                            request.tool_choice
+                    ) is ChatCompletionNamedToolChoiceParam:
+                        delta_message = DeltaMessage(tool_calls=[
+                            ToolCall(function=FunctionCall(
+                                name=request.tool_choice.function.name,
+                                arguments=delta_text))
+                        ])
+                    else:
+                        delta_message = DeltaMessage(content=delta_text)
+
+                    if output.finish_reason is None:
+                        # Send token-by-token response for each request.n
+
+                        choice_data = ChatCompletionResponseStreamChoice(
+                            index=i,
+                            delta=delta_message,
+                            logprobs=logprobs,
+                            finish_reason=None)
+                        chunk = ChatCompletionStreamResponse(
+                            id=request_id,
+                            object=chunk_object_type,
+                            created=created_time,
+                            choices=[choice_data],
+                            model=model_name)
+                        if (request.stream_options
+                                and request.stream_options.include_usage):
+                            chunk.usage = None
+                        data = chunk.model_dump_json(exclude_unset=True)
+                        yield f"data: {data}\n\n"
+                    else:
+                        # Send the finish response for each request.n only once
+                        prompt_tokens = len(res.prompt_token_ids)
+                        choice_data = ChatCompletionResponseStreamChoice(
+                            index=i,
+                            delta=delta_message,
+                            logprobs=logprobs,
+                            finish_reason=output.finish_reason,
+                            stop_reason=output.stop_reason)
+                        chunk = ChatCompletionStreamResponse(
+                            id=request_id,
+                            object=chunk_object_type,
+                            created=created_time,
+                            choices=[choice_data],
+                            model=model_name)
+                        if (request.stream_options
+                                and request.stream_options.include_usage):
+                            chunk.usage = None
+                        data = chunk.model_dump_json(exclude_unset=True)
+                        yield f"data: {data}\n\n"
+                        finish_reason_sent[i] = True
+
+            if (request.stream_options
+                    and request.stream_options.include_usage):
+                final_usage = UsageInfo(
+                    prompt_tokens=prompt_tokens,
+                    completion_tokens=previous_num_tokens[i],
+                    total_tokens=prompt_tokens + previous_num_tokens[i],
+                )
+
+                final_usage_chunk = ChatCompletionStreamResponse(
+                    id=request_id,
+                    object=chunk_object_type,
+                    created=created_time,
+                    choices=[],
+                    model=model_name,
+                    usage=final_usage)
+                final_usage_data = (final_usage_chunk.model_dump_json(
+                    exclude_unset=True, exclude_none=True))
+                yield f"data: {final_usage_data}\n\n"
+
+        except ValueError as e:
+            # TODO: Use a vllm-specific Validation Error
+            data = self.create_streaming_error_response(str(e))
+            yield f"data: {data}\n\n"
         # Send the final done message after all response.n are finished
         yield "data: [DONE]\n\n"
 
     async def chat_completion_full_generator(
+            self, request: ChatCompletionRequest, raw_request: Optional[Request],
+            result_generator: AsyncIterator[RequestOutput], request_id: str,
+            conversation: List[ConversationMessage]
+    ) -> Union[ErrorResponse, ChatCompletionResponse]:
 
-        model_name =
-        created_time = int(time.
-        final_res: RequestOutput = None
+        model_name = self.served_model_names[0]
+        created_time = int(time.time())
+        final_res: Optional[RequestOutput] = None
 
         async for res in result_generator:
-            if await raw_request.is_disconnected():
+            if raw_request is not None and await raw_request.is_disconnected():
                 # Abort the request if the client disconnects.
                 await self.engine.abort(request_id)
                 return self.create_error_response("Client disconnected")
             final_res = res
         assert final_res is not None
 
-        choices = []
+        choices: List[ChatCompletionResponseChoice] = []
+
         role = self.get_chat_request_role(request)
         for output in final_res.outputs:
+            token_ids = output.token_ids
+            out_logprobs = output.logprobs
+
+            if request.logprobs and request.top_logprobs is not None:
+                assert out_logprobs is not None, "Did not output logprobs"
+                logprobs = self._create_chat_logprobs(
+                    token_ids=token_ids,
+                    top_logprobs=out_logprobs,
+                    num_output_top_logprobs=request.top_logprobs,
+                )
+            else:
+                logprobs = None
+
+            if request.tool_choice and type(
+                    request.tool_choice) is ChatCompletionNamedToolChoiceParam:
+                message = ChatMessage(
+                    role=role,
+                    content="",
+                    tool_calls=[
+                        ToolCall(function=FunctionCall(
+                            name=request.tool_choice.function.name,
+                            arguments=output.text))
+                    ])
+            elif not request.tool_choice or request.tool_choice == "none":
+                message = ChatMessage(role=role, content=output.text)
+
             choice_data = ChatCompletionResponseChoice(
                 index=output.index,
-                message=
+                message=message,
+                logprobs=logprobs,
                 finish_reason=output.finish_reason,
+                stop_reason=output.stop_reason)
             choices.append(choice_data)
 
         if request.echo:
             last_msg_content = ""
-            if
-                    "role") == role:
-                last_msg_content = request.messages[-1]["content"]
+            if conversation and conversation[-1].get(
+                    "content") and conversation[-1].get("role") == role:
+                last_msg_content = conversation[-1]["content"]
 
         for choice in choices:
             full_message = last_msg_content + choice.message.content
@@ -242,24 +576,50 @@ class OpenAIServingChat(OpenAIServing):
 
         return response
 
+    def _get_top_logprobs(
+            self, logprobs: Dict[int, Logprob],
+            top_logprobs: Optional[int]) -> List[ChatCompletionLogProb]:
+        return [
+            ChatCompletionLogProb(
+                token=self._get_decoded_token(p[1], p[0]),
+                logprob=max(p[1].logprob, -9999.0),
+                bytes=list(
+                    self._get_decoded_token(p[1],
+                                            p[0]).encode("utf-8",
+                                                         errors="replace")))
+            for i, p in enumerate(logprobs.items())
+            if top_logprobs and i < top_logprobs
+        ]
 
+    def _create_chat_logprobs(
+        self,
+        token_ids: GenericSequence[int],
+        top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
+        num_output_top_logprobs: Optional[int] = None,
+    ) -> ChatCompletionLogProbs:
+        """Create OpenAI-style logprobs."""
+
+        logprobs_content = []
+
+        for i, token_id in enumerate(token_ids):
+            step_top_logprobs = top_logprobs[i]
+            if step_top_logprobs is None:
+                logprobs_content.append(
+                    ChatCompletionLogProbsContent(
+                        token=self.tokenizer.decode(token_id),
+                        bytes=list(
+                            self.tokenizer.decode(token_id).encode(
+                                "utf-8", errors="replace"))))
+            else:
+                logprobs_content.append(
+                    ChatCompletionLogProbsContent(
+                        token=step_top_logprobs[token_id].decoded_token,
+                        logprob=max(step_top_logprobs[token_id].logprob,
+                                    -9999.0),
+                        bytes=list(
+                            step_top_logprobs[token_id].decoded_token.encode(
+                                "utf-8", errors="replace")),
+                        top_logprobs=self._get_top_logprobs(
+                            step_top_logprobs, num_output_top_logprobs)))
+
+        return ChatCompletionLogProbs(content=logprobs_content)
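A minimal sketch of the kind of chat request the new message parsing above accepts, assuming a vision-capable served model; the URL, model name, and image link are placeholders, not values from this commit. Only a single image_url part is allowed per request, and documents/chat_template/chat_template_kwargs are forwarded to tokenizer.apply_chat_template only if the model's template supports them.

import requests

BASE_URL = "http://localhost:8000"  # assumed local deployment
MODEL = "my-model"                  # replace with the served model name

payload = {
    "model": MODEL,
    "messages": [{
        "role": "user",
        # Content parts: text plus one image_url; multiple images are
        # rejected by _parse_chat_message_content_parts.
        "content": [
            {"type": "text", "text": "What is in this image?"},
            {"type": "image_url",
             "image_url": {"url": "https://example.com/cat.png"}},
        ],
    }],
    # Optional, passed through to the chat template if it uses them:
    # "documents": [{"title": "Doc 1", "text": "Some reference text."}],
    # "chat_template_kwargs": {"some_flag": True},
}

resp = requests.post(f"{BASE_URL}/v1/chat/completions", json=payload)
print(resp.json())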
serving_completion.py
CHANGED
@@ -1,24 +1,26 @@
-import codecs
 import time
-from
-from typing import (AsyncGenerator, AsyncIterator, Dict, Iterable, List,
                     Optional)
 from typing import Sequence as GenericSequence
-from typing import
 
 from fastapi import Request
-from openai.types.chat import ChatCompletionContentPartTextParam
 
 from vllm.config import ModelConfig
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
                                                     OpenAIServing)
 from vllm.logger import init_logger
@@ -26,417 +28,364 @@ from vllm.model_executor.guided_decoding import (
     get_guided_decoding_logits_processor)
 from vllm.outputs import RequestOutput
 from vllm.sequence import Logprob
-from vllm.
 
 logger = init_logger(__name__)
 
 
-class
 
-    def __init__(self,
-                 engine: AsyncLLMEngine,
-                 model_config: ModelConfig,
                  served_model_names: List[str],
-                 lora_modules: Optional[List[LoRAModulePath]] = None,
-                 chat_template: Optional[str] = None):
         super().__init__(engine=engine,
                          model_config=model_config,
                          served_model_names=served_model_names,
                          lora_modules=lora_modules)
 
-    def _load_chat_template(self, chat_template: Optional[str]):
-        tokenizer = self.tokenizer
-
-        if chat_template is not None:
-            try:
-                with open(chat_template, "r") as f:
-                    tokenizer.chat_template = f.read()
-            except OSError as e:
-                JINJA_CHARS = "{}\n"
-                if not any(c in chat_template for c in JINJA_CHARS):
-                    msg = (f"The supplied chat template ({chat_template}) "
-                           f"looks like a file path, but it failed to be "
-                           f"opened. Reason: {e}")
-                    raise ValueError(msg) from e
-
-                # If opening a file fails, set chat template to be args to
-                # ensure we decode so our escape are interpreted correctly
-                tokenizer.chat_template = codecs.decode(
-                    chat_template, "unicode_escape")
-
-            logger.info("Using supplied chat template:\n%s",
-                        tokenizer.chat_template)
-        elif tokenizer.chat_template is not None:
-            logger.info("Using default chat template:\n%s",
-                        tokenizer.chat_template)
-        else:
-            logger.warning(
-                "No chat template provided. Chat API will not work.")
-
-    def _parse_chat_message_content_parts(
-        self,
-        role: str,
-        parts: Iterable[ChatCompletionContentPartParam],
-    ) -> ChatMessageParseResult:
-        texts: List[str] = []
-
-        for _, part in enumerate(parts):
-            part_type = part["type"]
-            if part_type == "text":
-                text = cast(ChatCompletionContentPartTextParam, part)["text"]
-
-                texts.append(text)
-            else:
-                raise NotImplementedError(f"Unknown part type: {part_type}")
-
-        messages = [ConversationMessage(role=role, content="\n".join(texts))]
-
-        return ChatMessageParseResult(messages=messages)
-
-    def _parse_chat_message_content(
-        self,
-        message: ChatCompletionMessageParam,
-    ) -> ChatMessageParseResult:
-        role = message["role"]
-        content = message.get("content")
-
-        if content is None:
-            return ChatMessageParseResult(messages=[])
-        if isinstance(content, str):
-            messages = [ConversationMessage(role=role, content=content)]
-            return ChatMessageParseResult(messages=messages)
-
-        return self._parse_chat_message_content_parts(role, content)
-
-    async def create_chat_completion(
-        self,
-        request: ChatCompletionRequest,
-        raw_request: Optional[Request] = None
-    ) -> Union[ErrorResponse, AsyncGenerator[str, None],
-               ChatCompletionResponse]:
         """Completion API similar to OpenAI's API.
 
-        See https://platform.openai.com/docs/api-reference/
-        for the API specification. This API mimics the OpenAI
-        ChatCompletion API.
 
         NOTE: Currently we do not support the following feature:
-            -
         """
         error_check_ret = await self._check_model(request)
         if error_check_ret is not None:
             return error_check_ret
 
-            parsed_msg = self._parse_chat_message_content(msg)
-
-            conversation.extend(parsed_msg.messages)
-
-            prompt = self.tokenizer.apply_chat_template(
-                conversation=conversation,
-                tokenize=False,
-                add_generation_prompt=request.add_generation_prompt,
-            )
-        except Exception as e:
-            logger.error("Error in applying chat template from request: %s", e)
-            return self.create_error_response(str(e))
 
         request_id = f"cmpl-{random_uuid()}"
         try:
-            # Tokenize/detokenize depending on prompt format (string/token list)
-            prompt_ids, prompt_text = self._validate_prompt_and_tokenize(
-                request,
-                prompt=prompt,
-                add_special_tokens=request.add_special_tokens)
             sampling_params = request.to_sampling_params()
             lora_request = self._maybe_get_lora(request)
             decoding_config = await self.engine.get_decoding_config()
             guided_decoding_backend = request.guided_decoding_backend \
                 or decoding_config.guided_decoding_backend
-
                 await get_guided_decoding_logits_processor(
                     guided_decoding_backend, request, await
                     self.engine.get_tokenizer()))
-            if
                 if sampling_params.logits_processors is None:
                     sampling_params.logits_processors = []
                 sampling_params.logits_processors.append(
-
         except ValueError as e:
             return self.create_error_response(str(e))
 
-        result_generator
         # Streaming response
         if request.stream:
-
-                request, result_generator, request_id, conversation)
-        else:
-            try:
-                return await self.chat_completion_full_generator(
-                    request, raw_request, result_generator, request_id,
-                    conversation)
-            except ValueError as e:
-                # TODO: Use a vllm-specific Validation Error
-                return self.create_error_response(str(e))
-
-    def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
-        if request.add_generation_prompt:
-            return self.response_role
-        else:
-            return request.messages[-1]["role"]
|
70 |
-
JINJA_CHARS = "{}\n"
|
71 |
-
if not any(c in chat_template for c in JINJA_CHARS):
|
72 |
-
msg = (f"The supplied chat template ({chat_template}) "
|
73 |
-
f"looks like a file path, but it failed to be "
|
74 |
-
f"opened. Reason: {e}")
|
75 |
-
raise ValueError(msg) from e
|
76 |
-
|
77 |
-
# If opening a file fails, set chat template to be args to
|
78 |
-
# ensure we decode so our escape are interpreted correctly
|
79 |
-
tokenizer.chat_template = codecs.decode(
|
80 |
-
chat_template, "unicode_escape")
|
81 |
-
|
82 |
-
logger.info("Using supplied chat template:\n%s",
|
83 |
-
tokenizer.chat_template)
|
84 |
-
elif tokenizer.chat_template is not None:
|
85 |
-
logger.info("Using default chat template:\n%s",
|
86 |
-
tokenizer.chat_template)
|
87 |
-
else:
|
88 |
-
logger.warning(
|
89 |
-
"No chat template provided. Chat API will not work.")
|
90 |
-
|
91 |
-
def _parse_chat_message_content_parts(
|
92 |
-
self,
|
93 |
-
role: str,
|
94 |
-
parts: Iterable[ChatCompletionContentPartParam],
|
95 |
-
) -> ChatMessageParseResult:
|
96 |
-
texts: List[str] = []
|
97 |
-
|
98 |
-
for _, part in enumerate(parts):
|
99 |
-
part_type = part["type"]
|
100 |
-
if part_type == "text":
|
101 |
-
text = cast(ChatCompletionContentPartTextParam, part)["text"]
|
102 |
-
|
103 |
-
texts.append(text)
|
104 |
-
else:
|
105 |
-
raise NotImplementedError(f"Unknown part type: {part_type}")
|
106 |
-
|
107 |
-
messages = [ConversationMessage(role=role, content="\n".join(texts))]
|
108 |
-
|
109 |
-
return ChatMessageParseResult(messages=messages)
|
110 |
-
|
111 |
-
def _parse_chat_message_content(
|
112 |
-
self,
|
113 |
-
message: ChatCompletionMessageParam,
|
114 |
-
) -> ChatMessageParseResult:
|
115 |
-
role = message["role"]
|
116 |
-
content = message.get("content")
|
117 |
-
|
118 |
-
if content is None:
|
119 |
-
return ChatMessageParseResult(messages=[])
|
120 |
-
if isinstance(content, str):
|
121 |
-
messages = [ConversationMessage(role=role, content=content)]
|
122 |
-
return ChatMessageParseResult(messages=messages)
|
123 |
-
|
124 |
-
return self._parse_chat_message_content_parts(role, content)
|
125 |
-
|
126 |
-
async def create_chat_completion(
|
127 |
-
self,
|
128 |
-
request: ChatCompletionRequest,
|
129 |
-
raw_request: Optional[Request] = None
|
130 |
-
) -> Union[ErrorResponse, AsyncGenerator[str, None],
|
131 |
-
ChatCompletionResponse]:
|
132 |
"""Completion API similar to OpenAI's API.
|
133 |
|
134 |
-
See https://platform.openai.com/docs/api-reference/
|
135 |
-
for the API specification. This API mimics the OpenAI
|
136 |
-
ChatCompletion API.
|
137 |
|
138 |
NOTE: Currently we do not support the following feature:
|
139 |
-
-
|
|
|
140 |
"""
|
141 |
error_check_ret = await self._check_model(request)
|
142 |
if error_check_ret is not None:
|
143 |
return error_check_ret
|
144 |
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
parsed_msg = self._parse_chat_message_content(msg)
|
150 |
-
|
151 |
-
conversation.extend(parsed_msg.messages)
|
152 |
-
|
153 |
-
prompt = self.tokenizer.apply_chat_template(
|
154 |
-
conversation=conversation,
|
155 |
-
tokenize=False,
|
156 |
-
add_generation_prompt=request.add_generation_prompt,
|
157 |
-
)
|
158 |
-
except Exception as e:
|
159 |
-
logger.error("Error in applying chat template from request: %s", e)
|
160 |
-
return self.create_error_response(str(e))
|
161 |
|
|
|
162 |
request_id = f"cmpl-{random_uuid()}"
|
|
|
|
|
|
|
|
|
163 |
try:
|
164 |
-
# Tokenize/detokenize depending on prompt format (string/token list)
|
165 |
-
prompt_ids, prompt_text = self._validate_prompt_and_tokenize(
|
166 |
-
request,
|
167 |
-
prompt=prompt,
|
168 |
-
add_special_tokens=request.add_special_tokens)
|
169 |
sampling_params = request.to_sampling_params()
|
170 |
lora_request = self._maybe_get_lora(request)
|
171 |
decoding_config = await self.engine.get_decoding_config()
|
172 |
guided_decoding_backend = request.guided_decoding_backend \
|
173 |
or decoding_config.guided_decoding_backend
|
174 |
-
|
175 |
await get_guided_decoding_logits_processor(
|
176 |
guided_decoding_backend, request, await
|
177 |
self.engine.get_tokenizer()))
|
178 |
-
if
|
179 |
if sampling_params.logits_processors is None:
|
180 |
sampling_params.logits_processors = []
|
181 |
sampling_params.logits_processors.append(
|
182 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
183 |
except ValueError as e:
|
|
|
184 |
return self.create_error_response(str(e))
|
185 |
|
186 |
-
result_generator
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
|
|
195 |
# Streaming response
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
196 |
if request.stream:
|
197 |
-
|
198 |
-
request, result_generator, request_id, conversation)
|
199 |
-
else:
|
200 |
-
try:
|
201 |
-
return await self.chat_completion_full_generator(
|
202 |
-
request, raw_request, result_generator, request_id,
|
203 |
-
conversation)
|
204 |
-
except ValueError as e:
|
205 |
-
# TODO: Use a vllm-specific Validation Error
|
206 |
-
return self.create_error_response(str(e))
|
207 |
-
|
208 |
-
def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
|
209 |
-
if request.add_generation_prompt:
|
210 |
-
return self.response_role
|
211 |
-
else:
|
212 |
-
return request.messages[-1]["role"]
|
213 |
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
conversation: List[ConversationMessage]
|
218 |
-
) -> AsyncGenerator[str, None]:
|
219 |
-
model_name = self.served_model_names[0]
|
220 |
-
created_time = int(time.time())
|
221 |
-
chunk_object_type = "chat.completion.chunk"
|
222 |
-
first_iteration = True
|
223 |
|
224 |
-
|
225 |
-
assert request.n is not None
|
226 |
-
previous_texts = [""] * request.n
|
227 |
-
previous_num_tokens = [0] * request.n
|
228 |
-
finish_reason_sent = [False] * request.n
|
229 |
-
try:
|
230 |
-
async for res in result_generator:
|
231 |
-
# We need to do it here, because if there are exceptions in
|
232 |
-
# the result_generator, it needs to be sent as the FIRST
|
233 |
-
# response (by the try...catch).
|
234 |
-
if first_iteration:
|
235 |
-
# Send first response for each request.n (index) with
|
236 |
-
# the role
|
237 |
-
role = self.get_chat_request_role(request)
|
238 |
-
for i in range(request.n):
|
239 |
-
choice_data = ChatCompletionResponseStreamChoice(
|
240 |
-
index=i,
|
241 |
-
delta=DeltaMessage(role=role),
|
242 |
-
logprobs=None,
|
243 |
-
finish_reason=None)
|
244 |
-
chunk = ChatCompletionStreamResponse(
|
245 |
-
id=request_id,
|
246 |
-
object=chunk_object_type,
|
247 |
-
created=created_time,
|
248 |
-
choices=[choice_data],
|
249 |
-
model=model_name)
|
250 |
-
data = chunk.model_dump_json(exclude_unset=True)
|
251 |
-
yield f"data: {data}\n\n"
|
252 |
-
|
253 |
-
# Send response to echo the input portion of the
|
254 |
-
# last message
|
255 |
-
if request.echo:
|
256 |
-
last_msg_content = ""
|
257 |
-
if conversation and conversation[-1].get(
|
258 |
-
"content") and conversation[-1].get(
|
259 |
-
"role") == role:
|
260 |
-
last_msg_content = conversation[-1]["content"]
|
261 |
-
|
262 |
-
if last_msg_content:
|
263 |
-
for i in range(request.n):
|
264 |
-
choice_data = (
|
265 |
-
ChatCompletionResponseStreamChoice(
|
266 |
-
index=i,
|
267 |
-
delta=DeltaMessage(
|
268 |
-
content=last_msg_content),
|
269 |
-
finish_reason=None))
|
270 |
-
chunk = ChatCompletionStreamResponse(
|
271 |
-
id=request_id,
|
272 |
-
object=chunk_object_type,
|
273 |
-
created=created_time,
|
274 |
-
choices=[choice_data],
|
275 |
-
logprobs=None,
|
276 |
-
model=model_name)
|
277 |
-
data = chunk.model_dump_json(
|
278 |
-
exclude_unset=True)
|
279 |
-
yield f"data: {data}\n\n"
|
280 |
-
first_iteration = False
|
281 |
|
282 |
-
|
283 |
-
i = output.index
|
284 |
|
285 |
-
|
286 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
287 |
|
288 |
-
|
289 |
-
|
290 |
-
previous_num_tokens[i]:] if output.logprobs else None
|
291 |
|
292 |
-
|
293 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
294 |
token_ids=delta_token_ids,
|
295 |
-
top_logprobs=
|
296 |
-
num_output_top_logprobs=request.
|
|
|
297 |
)
|
298 |
else:
|
299 |
logprobs = None
|
300 |
|
301 |
-
delta_text = output.text[len(previous_texts[i]):]
|
302 |
previous_texts[i] = output.text
|
303 |
previous_num_tokens[i] = len(output.token_ids)
|
304 |
-
|
305 |
-
|
306 |
-
|
307 |
-
|
308 |
-
|
309 |
-
|
310 |
-
|
311 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
312 |
])
|
313 |
-
|
314 |
-
|
315 |
-
|
316 |
-
|
317 |
-
|
318 |
-
|
319 |
-
|
320 |
-
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
-
|
325 |
-
|
326 |
-
|
327 |
-
|
328 |
-
|
329 |
-
|
330 |
-
|
331 |
-
|
332 |
-
|
333 |
-
|
334 |
-
|
335 |
-
|
336 |
-
|
337 |
-
|
338 |
-
|
339 |
-
|
340 |
-
|
341 |
-
|
342 |
-
|
343 |
-
|
344 |
-
|
345 |
-
finish_reason=output.finish_reason,
|
346 |
-
stop_reason=output.stop_reason)
|
347 |
-
chunk = ChatCompletionStreamResponse(
|
348 |
-
id=request_id,
|
349 |
-
object=chunk_object_type,
|
350 |
-
created=created_time,
|
351 |
-
choices=[choice_data],
|
352 |
-
model=model_name)
|
353 |
-
if final_usage is not None:
|
354 |
-
chunk.usage = final_usage
|
355 |
-
data = chunk.model_dump_json(exclude_unset=True,
|
356 |
-
exclude_none=True)
|
357 |
-
yield f"data: {data}\n\n"
|
358 |
-
finish_reason_sent[i] = True
|
359 |
except ValueError as e:
|
360 |
# TODO: Use a vllm-specific Validation Error
|
361 |
data = self.create_streaming_error_response(str(e))
|
362 |
yield f"data: {data}\n\n"
|
363 |
-
# Send the final done message after all response.n are finished
|
364 |
yield "data: [DONE]\n\n"
|
365 |
|
366 |
-
|
367 |
-
self,
|
368 |
-
|
369 |
-
|
370 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
371 |
|
372 |
-
|
373 |
-
|
374 |
-
|
375 |
|
376 |
-
async for res in result_generator:
|
377 |
-
if raw_request is not None and await raw_request.is_disconnected():
|
378 |
-
# Abort the request if the client disconnects.
|
379 |
-
await self.engine.abort(request_id)
|
380 |
-
return self.create_error_response("Client disconnected")
|
381 |
-
final_res = res
|
382 |
-
assert final_res is not None
|
383 |
-
|
384 |
-
choices = []
|
385 |
-
|
386 |
-
role = self.get_chat_request_role(request)
|
387 |
-
for output in final_res.outputs:
|
388 |
-
token_ids = output.token_ids
|
389 |
-
top_logprobs = output.logprobs
|
390 |
-
|
391 |
-
if request.logprobs:
|
392 |
-
logprobs = self._create_chat_logprobs(
|
393 |
-
token_ids=token_ids,
|
394 |
-
top_logprobs=top_logprobs,
|
395 |
-
num_output_top_logprobs=request.top_logprobs,
|
396 |
-
)
|
397 |
-
else:
|
398 |
-
logprobs = None
|
399 |
-
|
400 |
-
if request.tool_choice and type(
|
401 |
-
request.tool_choice) is ChatCompletionNamedToolChoiceParam:
|
402 |
-
message = ChatMessage(
|
403 |
-
role=role,
|
404 |
-
content="",
|
405 |
-
tool_calls=[
|
406 |
-
ToolCall(function=FunctionCall(
|
407 |
-
name=request.tool_choice.function.name,
|
408 |
-
arguments=output.text))
|
409 |
-
])
|
410 |
-
elif not request.tool_choice or request.tool_choice == "none":
|
411 |
-
message = ChatMessage(role=role, content=output.text)
|
412 |
-
|
413 |
-
choice_data = ChatCompletionResponseChoice(
|
414 |
-
index=output.index,
|
415 |
-
message=message,
|
416 |
-
logprobs=logprobs,
|
417 |
-
finish_reason=output.finish_reason,
|
418 |
-
stop_reason=output.stop_reason)
|
419 |
-
choices.append(choice_data)
|
420 |
-
|
421 |
-
if request.echo:
|
422 |
-
last_msg_content = ""
|
423 |
-
if conversation and conversation[-1].get(
|
424 |
-
"content") and conversation[-1].get("role") == role:
|
425 |
-
last_msg_content = conversation[-1]["content"]
|
426 |
-
|
427 |
-
for choice in choices:
|
428 |
-
full_message = last_msg_content + choice.message.content
|
429 |
-
choice.message.content = full_message
|
430 |
-
|
431 |
-
num_prompt_tokens = len(final_res.prompt_token_ids)
|
432 |
-
num_generated_tokens = sum(
|
433 |
-
len(output.token_ids) for output in final_res.outputs)
|
434 |
usage = UsageInfo(
|
435 |
prompt_tokens=num_prompt_tokens,
|
436 |
completion_tokens=num_generated_tokens,
|
437 |
total_tokens=num_prompt_tokens + num_generated_tokens,
|
438 |
)
|
439 |
-
|
|
|
440 |
id=request_id,
|
441 |
created=created_time,
|
442 |
model=model_name,
|
@@ -444,52 +393,84 @@ class OpenAIServingChat(OpenAIServing):
|
|
444 |
usage=usage,
|
445 |
)
|
446 |
|
447 |
-
|
448 |
-
|
449 |
-
def _get_top_logprobs(
|
450 |
-
self, logprobs: Dict[int, Logprob],
|
451 |
-
top_logprobs: Optional[int]) -> List[ChatCompletionLogProb]:
|
452 |
-
return [
|
453 |
-
ChatCompletionLogProb(
|
454 |
-
token=self._get_decoded_token(p[1], p[0]),
|
455 |
-
logprob=max(p[1].logprob, -9999.0),
|
456 |
-
bytes=list(
|
457 |
-
self._get_decoded_token(p[1],
|
458 |
-
p[0]).encode("utf-8",
|
459 |
-
errors="replace")))
|
460 |
-
for i, p in enumerate(logprobs.items())
|
461 |
-
if top_logprobs and i < top_logprobs
|
462 |
-
]
|
463 |
-
|
464 |
-
def _create_chat_logprobs(
|
465 |
self,
|
466 |
token_ids: GenericSequence[int],
|
467 |
top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
|
468 |
-
num_output_top_logprobs:
|
469 |
-
|
470 |
-
|
|
|
|
|
|
|
|
|
|
|
471 |
|
472 |
-
|
473 |
|
474 |
for i, token_id in enumerate(token_ids):
|
475 |
step_top_logprobs = top_logprobs[i]
|
476 |
if step_top_logprobs is None:
|
477 |
-
|
478 |
-
|
479 |
-
|
480 |
-
|
481 |
-
|
482 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
483 |
else:
|
484 |
-
|
485 |
-
|
486 |
-
|
487 |
-
|
488 |
-
|
489 |
-
|
490 |
-
|
491 |
-
|
492 |
-
|
493 |
-
|
494 |
-
|
495 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import time
|
2 |
+
from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List,
|
|
|
3 |
Optional)
|
4 |
from typing import Sequence as GenericSequence
|
5 |
+
from typing import Tuple
|
6 |
|
7 |
from fastapi import Request
|
|
|
8 |
|
9 |
from vllm.config import ModelConfig
|
10 |
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
11 |
+
# yapf conflicts with isort for this block
|
12 |
+
# yapf: disable
|
13 |
+
from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
|
14 |
+
CompletionRequest,
|
15 |
+
CompletionResponse,
|
16 |
+
CompletionResponseChoice,
|
17 |
+
CompletionResponseStreamChoice,
|
18 |
+
CompletionStreamResponse,
|
19 |
+
DetokenizeRequest,
|
20 |
+
DetokenizeResponse,
|
21 |
+
TokenizeRequest,
|
22 |
+
TokenizeResponse, UsageInfo)
|
23 |
+
# yapf: enable
|
24 |
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
|
25 |
OpenAIServing)
|
26 |
from vllm.logger import init_logger
|
27 |
from vllm.model_executor.guided_decoding import (
|
28 |
get_guided_decoding_logits_processor)
|
29 |
from vllm.outputs import RequestOutput
|
30 |
from vllm.sequence import Logprob
|
31 |
+
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
|
32 |
+
log_tracing_disabled_warning)
|
33 |
+
from vllm.utils import merge_async_iterators, random_uuid
|
34 |
|
35 |
logger = init_logger(__name__)
|
36 |
|
37 |
+
TypeTokenIDs = List[int]
|
38 |
+
TypeTopLogProbs = List[Optional[Dict[int, float]]]
|
39 |
+
TypeCreateLogProbsFn = Callable[
|
40 |
+
[TypeTokenIDs, TypeTopLogProbs, Optional[int], int], CompletionLogProbs]
|
41 |
+
|
42 |
+
|
43 |
+
def parse_prompt_format(prompt) -> Tuple[bool, list]:
|
44 |
+
# get the prompt, openai supports the following
|
45 |
+
# "a string, array of strings, array of tokens, or array of token arrays."
|
46 |
+
prompt_is_tokens = False
|
47 |
+
prompts = [prompt] # case 1: a string
|
48 |
+
if isinstance(prompt, list):
|
49 |
+
if len(prompt) == 0:
|
50 |
+
raise ValueError("please provide at least one prompt")
|
51 |
+
elif isinstance(prompt[0], str):
|
52 |
+
prompt_is_tokens = False
|
53 |
+
prompts = prompt # case 2: array of strings
|
54 |
+
elif isinstance(prompt[0], int):
|
55 |
+
prompt_is_tokens = True
|
56 |
+
prompts = [prompt] # case 3: array of tokens
|
57 |
+
elif isinstance(prompt[0], list) and isinstance(prompt[0][0], int):
|
58 |
+
prompt_is_tokens = True
|
59 |
+
prompts = prompt # case 4: array of token arrays
|
60 |
+
else:
|
61 |
+
raise ValueError("prompt must be a string, array of strings, "
|
62 |
+
"array of tokens, or array of token arrays")
|
63 |
+
return prompt_is_tokens, prompts
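For reference, a small sketch (not part of the diff) of what parse_prompt_format returns for the four prompt shapes the OpenAI Completions API accepts:

# Illustrative only; mirrors the branches above.
assert parse_prompt_format("Hi") == (False, ["Hi"])                 # a string
assert parse_prompt_format(["Hi", "Yo"]) == (False, ["Hi", "Yo"])   # array of strings
assert parse_prompt_format([1, 2, 3]) == (True, [[1, 2, 3]])        # array of tokens
assert parse_prompt_format([[1, 2], [3]]) == (True, [[1, 2], [3]])  # array of token arrays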
|
64 |
|
65 |
|
66 |
+
class OpenAIServingCompletion(OpenAIServing):
|
67 |
|
68 |
+
def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig,
|
|
|
|
|
69 |
served_model_names: List[str],
|
70 |
+
lora_modules: Optional[List[LoRAModulePath]]):
|
|
|
|
|
71 |
super().__init__(engine=engine,
|
72 |
model_config=model_config,
|
73 |
served_model_names=served_model_names,
|
74 |
lora_modules=lora_modules)
|
75 |
|
76 |
+
async def create_completion(self, request: CompletionRequest,
|
77 |
+
raw_request: Request):
|
|
|
78 |
"""Completion API similar to OpenAI's API.
|
79 |
|
80 |
+
See https://platform.openai.com/docs/api-reference/completions/create
|
81 |
+
for the API specification. This API mimics the OpenAI Completion API.
|
|
|
82 |
|
83 |
NOTE: Currently we do not support the following feature:
|
84 |
+
- suffix (the language models we currently support do not support
|
85 |
+
suffix)
|
86 |
"""
|
87 |
error_check_ret = await self._check_model(request)
|
88 |
if error_check_ret is not None:
|
89 |
return error_check_ret
|
90 |
|
91 |
+
# Return error for unsupported features.
|
92 |
+
if request.suffix is not None:
|
93 |
+
return self.create_error_response(
|
94 |
+
"suffix is not currently supported")
|
|
|
95 |
|
96 |
+
model_name = self.served_model_names[0]
|
97 |
request_id = f"cmpl-{random_uuid()}"
|
98 |
+
created_time = int(time.time())
|
99 |
+
|
100 |
+
# Schedule the request and get the result generator.
|
101 |
+
generators: List[AsyncIterator[RequestOutput]] = []
|
102 |
try:
|
|
|
103 |
sampling_params = request.to_sampling_params()
|
104 |
lora_request = self._maybe_get_lora(request)
|
105 |
decoding_config = await self.engine.get_decoding_config()
|
106 |
guided_decoding_backend = request.guided_decoding_backend \
|
107 |
or decoding_config.guided_decoding_backend
|
108 |
+
guided_decode_logit_processor = (
|
109 |
await get_guided_decoding_logits_processor(
|
110 |
guided_decoding_backend, request, await
|
111 |
self.engine.get_tokenizer()))
|
112 |
+
if guided_decode_logit_processor is not None:
|
113 |
if sampling_params.logits_processors is None:
|
114 |
sampling_params.logits_processors = []
|
115 |
sampling_params.logits_processors.append(
|
116 |
+
guided_decode_logit_processor)
|
117 |
+
prompt_is_tokens, prompts = parse_prompt_format(request.prompt)
|
118 |
+
|
119 |
+
for i, prompt in enumerate(prompts):
|
120 |
+
if prompt_is_tokens:
|
121 |
+
prompt_formats = self._validate_prompt_and_tokenize(
|
122 |
+
request,
|
123 |
+
prompt_ids=prompt,
|
124 |
+
truncate_prompt_tokens=sampling_params.
|
125 |
+
truncate_prompt_tokens)
|
126 |
+
else:
|
127 |
+
prompt_formats = self._validate_prompt_and_tokenize(
|
128 |
+
request,
|
129 |
+
prompt=prompt,
|
130 |
+
truncate_prompt_tokens=sampling_params.
|
131 |
+
truncate_prompt_tokens)
|
132 |
+
prompt_ids, prompt_text = prompt_formats
|
133 |
+
|
134 |
+
is_tracing_enabled = await self.engine.is_tracing_enabled()
|
135 |
+
trace_headers = None
|
136 |
+
if is_tracing_enabled:
|
137 |
+
trace_headers = extract_trace_headers(raw_request.headers)
|
138 |
+
if not is_tracing_enabled and contains_trace_headers(
|
139 |
+
raw_request.headers):
|
140 |
+
log_tracing_disabled_warning()
|
141 |
+
|
142 |
+
generator = self.engine.generate(
|
143 |
+
{
|
144 |
+
"prompt": prompt_text,
|
145 |
+
"prompt_token_ids": prompt_ids
|
146 |
+
},
|
147 |
+
sampling_params,
|
148 |
+
f"{request_id}-{i}",
|
149 |
+
lora_request=lora_request,
|
150 |
+
trace_headers=trace_headers,
|
151 |
+
)
|
152 |
+
|
153 |
+
generators.append(generator)
|
154 |
except ValueError as e:
|
155 |
+
# TODO: Use a vllm-specific Validation Error
|
156 |
return self.create_error_response(str(e))
|
157 |
|
158 |
+
result_generator: AsyncIterator[Tuple[
|
159 |
+
int, RequestOutput]] = merge_async_iterators(*generators)
|
160 |
+
|
161 |
+
# Similar to the OpenAI API, when n != best_of, we do not stream the
|
162 |
+
# results. In addition, we do not stream the results when use
|
163 |
+
# beam search.
|
164 |
+
stream = (request.stream
|
165 |
+
and (request.best_of is None or request.n == request.best_of)
|
166 |
+
and not request.use_beam_search)
|
167 |
+
|
168 |
# Streaming response
|
169 |
+
if stream:
|
170 |
+
return self.completion_stream_generator(request,
|
171 |
+
raw_request,
|
172 |
+
result_generator,
|
173 |
+
request_id,
|
174 |
+
created_time,
|
175 |
+
model_name,
|
176 |
+
num_prompts=len(prompts))
|
177 |
+
|
178 |
+
# Non-streaming response
|
179 |
+
final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts)
|
180 |
+
try:
|
181 |
+
async for i, res in result_generator:
|
182 |
+
if await raw_request.is_disconnected():
|
183 |
+
# Abort the request if the client disconnects.
|
184 |
+
await self.engine.abort(f"{request_id}-{i}")
|
185 |
+
return self.create_error_response("Client disconnected")
|
186 |
+
final_res_batch[i] = res
|
187 |
+
response = self.request_output_to_completion_response(
|
188 |
+
final_res_batch, request, request_id, created_time, model_name)
|
189 |
+
except ValueError as e:
|
190 |
+
# TODO: Use a vllm-specific Validation Error
|
191 |
+
return self.create_error_response(str(e))
|
192 |
+
|
193 |
+
# When user requests streaming but we don't stream, we still need to
|
194 |
+
# return a streaming response with a single event.
|
195 |
if request.stream:
|
196 |
+
response_json = response.model_dump_json()
|
|
|
197 |
|
198 |
+
async def fake_stream_generator() -> AsyncGenerator[str, None]:
|
199 |
+
yield f"data: {response_json}\n\n"
|
200 |
+
yield "data: [DONE]\n\n"
|
|
|
201 |
|
202 |
+
return fake_stream_generator()
|
|
|
203 |
|
204 |
+
return response
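Side note on the fan-out above: each prompt gets its own engine generator, and merge_async_iterators (imported above from vllm.utils) interleaves them as (prompt_index, output) pairs, which is what lets final_res_batch[i] and the streaming choice index be computed. A minimal toy sketch of that pattern; the dummy generator below is an assumption for illustration, not the vLLM engine:

import asyncio

from vllm.utils import merge_async_iterators

async def dummy_outputs(tag: str):
    # Stand-in for engine.generate(); yields a couple of fake "outputs".
    for step in range(2):
        await asyncio.sleep(0)
        yield f"{tag}/step{step}"

async def demo():
    merged = merge_async_iterators(dummy_outputs("prompt-0"), dummy_outputs("prompt-1"))
    async for prompt_idx, item in merged:
        # prompt_idx identifies which per-prompt generator produced the item.
        print(prompt_idx, item)

asyncio.run(demo())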
|
|
|
205 |
|
206 |
+
async def completion_stream_generator(
|
207 |
+
self,
|
208 |
+
request: CompletionRequest,
|
209 |
+
raw_request: Request,
|
210 |
+
result_generator: AsyncIterator[Tuple[int, RequestOutput]],
|
211 |
+
request_id: str,
|
212 |
+
created_time: int,
|
213 |
+
model_name: str,
|
214 |
+
num_prompts: int,
|
215 |
+
) -> AsyncGenerator[str, None]:
|
216 |
+
assert request.n is not None
|
217 |
+
previous_texts = [""] * request.n * num_prompts
|
218 |
+
previous_num_tokens = [0] * request.n * num_prompts
|
219 |
+
has_echoed = [False] * request.n * num_prompts
|
220 |
|
221 |
+
try:
|
222 |
+
async for prompt_idx, res in result_generator:
|
|
|
223 |
|
224 |
+
# Abort the request if the client disconnects.
|
225 |
+
if await raw_request.is_disconnected():
|
226 |
+
await self.engine.abort(f"{request_id}-{prompt_idx}")
|
227 |
+
raise StopAsyncIteration()
|
228 |
+
|
229 |
+
for output in res.outputs:
|
230 |
+
i = output.index + prompt_idx * request.n
|
231 |
+
# TODO(simon): optimize the performance by avoiding full
|
232 |
+
# text O(n^2) sending.
|
233 |
+
|
234 |
+
assert request.max_tokens is not None
|
235 |
+
if request.echo and request.max_tokens == 0:
|
236 |
+
# only return the prompt
|
237 |
+
delta_text = res.prompt
|
238 |
+
delta_token_ids = res.prompt_token_ids
|
239 |
+
out_logprobs = res.prompt_logprobs
|
240 |
+
has_echoed[i] = True
|
241 |
+
elif (request.echo and request.max_tokens > 0
|
242 |
+
and not has_echoed[i]):
|
243 |
+
# echo the prompt and first token
|
244 |
+
delta_text = res.prompt + output.text
|
245 |
+
delta_token_ids = (res.prompt_token_ids +
|
246 |
+
output.token_ids)
|
247 |
+
out_logprobs = res.prompt_logprobs + (output.logprobs
|
248 |
+
or [])
|
249 |
+
has_echoed[i] = True
|
250 |
+
else:
|
251 |
+
# return just the delta
|
252 |
+
delta_text = output.text[len(previous_texts[i]):]
|
253 |
+
delta_token_ids = output.token_ids[
|
254 |
+
previous_num_tokens[i]:]
|
255 |
+
out_logprobs = output.logprobs[previous_num_tokens[
|
256 |
+
i]:] if output.logprobs else None
|
257 |
+
|
258 |
+
if request.logprobs is not None:
|
259 |
+
assert out_logprobs is not None, (
|
260 |
+
"Did not output logprobs")
|
261 |
+
logprobs = self._create_completion_logprobs(
|
262 |
token_ids=delta_token_ids,
|
263 |
+
top_logprobs=out_logprobs,
|
264 |
+
num_output_top_logprobs=request.logprobs,
|
265 |
+
initial_text_offset=len(previous_texts[i]),
|
266 |
)
|
267 |
else:
|
268 |
logprobs = None
|
269 |
|
|
|
270 |
previous_texts[i] = output.text
|
271 |
previous_num_tokens[i] = len(output.token_ids)
|
272 |
+
finish_reason = output.finish_reason
|
273 |
+
stop_reason = output.stop_reason
|
274 |
+
|
275 |
+
chunk = CompletionStreamResponse(
|
276 |
+
id=request_id,
|
277 |
+
created=created_time,
|
278 |
+
model=model_name,
|
279 |
+
choices=[
|
280 |
+
CompletionResponseStreamChoice(
|
281 |
+
index=i,
|
282 |
+
text=delta_text,
|
283 |
+
logprobs=logprobs,
|
284 |
+
finish_reason=finish_reason,
|
285 |
+
stop_reason=stop_reason,
|
286 |
+
)
|
287 |
])
|
288 |
+
if (request.stream_options
|
289 |
+
and request.stream_options.include_usage):
|
290 |
+
if (request.stream_options.continuous_usage_stats
|
291 |
+
or output.finish_reason is not None):
|
292 |
+
prompt_tokens = len(res.prompt_token_ids)
|
293 |
+
completion_tokens = len(output.token_ids)
|
294 |
+
usage = UsageInfo(
|
295 |
+
prompt_tokens=prompt_tokens,
|
296 |
+
completion_tokens=completion_tokens,
|
297 |
+
total_tokens=prompt_tokens + completion_tokens,
|
298 |
+
)
|
299 |
+
if request.stream_options.continuous_usage_stats:
|
300 |
+
chunk.usage = usage
|
301 |
+
else:
|
302 |
+
chunk.usage = None
|
303 |
+
|
304 |
+
response_json = chunk.model_dump_json(exclude_unset=True)
|
305 |
+
yield f"data: {response_json}\n\n"
|
306 |
+
|
307 |
+
if (request.stream_options
|
308 |
+
and request.stream_options.include_usage):
|
309 |
+
final_usage_chunk = CompletionStreamResponse(
|
310 |
+
id=request_id,
|
311 |
+
created=created_time,
|
312 |
+
model=model_name,
|
313 |
+
choices=[],
|
314 |
+
usage=usage,
|
315 |
+
)
|
316 |
+
final_usage_data = (final_usage_chunk.model_dump_json(
|
317 |
+
exclude_unset=True, exclude_none=True))
|
318 |
+
yield f"data: {final_usage_data}\n\n"
|
319 |
+
|
|
|
320 |
except ValueError as e:
|
321 |
# TODO: Use a vllm-specific Validation Error
|
322 |
data = self.create_streaming_error_response(str(e))
|
323 |
yield f"data: {data}\n\n"
|
|
|
324 |
yield "data: [DONE]\n\n"
|
325 |
|
326 |
+
def request_output_to_completion_response(
|
327 |
+
self,
|
328 |
+
final_res_batch: List[RequestOutput],
|
329 |
+
request: CompletionRequest,
|
330 |
+
request_id: str,
|
331 |
+
created_time: int,
|
332 |
+
model_name: str,
|
333 |
+
) -> CompletionResponse:
|
334 |
+
choices: List[CompletionResponseChoice] = []
|
335 |
+
num_prompt_tokens = 0
|
336 |
+
num_generated_tokens = 0
|
337 |
+
for final_res in final_res_batch:
|
338 |
+
assert final_res is not None
|
339 |
+
prompt_token_ids = final_res.prompt_token_ids
|
340 |
+
prompt_logprobs = final_res.prompt_logprobs
|
341 |
+
prompt_text = final_res.prompt
|
342 |
+
|
343 |
+
for output in final_res.outputs:
|
344 |
+
assert request.max_tokens is not None
|
345 |
+
if request.echo and request.max_tokens == 0:
|
346 |
+
token_ids = prompt_token_ids
|
347 |
+
out_logprobs = prompt_logprobs
|
348 |
+
output_text = prompt_text
|
349 |
+
elif request.echo and request.max_tokens > 0:
|
350 |
+
token_ids = prompt_token_ids + list(output.token_ids)
|
351 |
+
out_logprobs = (prompt_logprobs + output.logprobs
|
352 |
+
if request.logprobs is not None else None)
|
353 |
+
output_text = prompt_text + output.text
|
354 |
+
else:
|
355 |
+
token_ids = output.token_ids
|
356 |
+
out_logprobs = output.logprobs
|
357 |
+
output_text = output.text
|
358 |
+
|
359 |
+
if request.logprobs is not None:
|
360 |
+
assert out_logprobs is not None, "Did not output logprobs"
|
361 |
+
logprobs = self._create_completion_logprobs(
|
362 |
+
token_ids=token_ids,
|
363 |
+
top_logprobs=out_logprobs,
|
364 |
+
num_output_top_logprobs=request.logprobs,
|
365 |
+
)
|
366 |
+
else:
|
367 |
+
logprobs = None
|
368 |
+
|
369 |
+
choice_data = CompletionResponseChoice(
|
370 |
+
index=len(choices),
|
371 |
+
text=output_text,
|
372 |
+
logprobs=logprobs,
|
373 |
+
finish_reason=output.finish_reason,
|
374 |
+
stop_reason=output.stop_reason,
|
375 |
+
)
|
376 |
+
choices.append(choice_data)
|
377 |
|
378 |
+
num_prompt_tokens += len(prompt_token_ids)
|
379 |
+
num_generated_tokens += sum(
|
380 |
+
len(output.token_ids) for output in final_res.outputs)
|
381 |
|
|
|
382 |
usage = UsageInfo(
|
383 |
prompt_tokens=num_prompt_tokens,
|
384 |
completion_tokens=num_generated_tokens,
|
385 |
total_tokens=num_prompt_tokens + num_generated_tokens,
|
386 |
)
|
387 |
+
|
388 |
+
return CompletionResponse(
|
389 |
id=request_id,
|
390 |
created=created_time,
|
391 |
model=model_name,
|
|
|
393 |
usage=usage,
|
394 |
)
|
395 |
|
396 |
+
def _create_completion_logprobs(
|
|
|
397 |
self,
|
398 |
token_ids: GenericSequence[int],
|
399 |
top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
|
400 |
+
num_output_top_logprobs: int,
|
401 |
+
initial_text_offset: int = 0,
|
402 |
+
) -> CompletionLogProbs:
|
403 |
+
"""Create logprobs for OpenAI Completion API."""
|
404 |
+
out_text_offset: List[int] = []
|
405 |
+
out_token_logprobs: List[Optional[float]] = []
|
406 |
+
out_tokens: List[str] = []
|
407 |
+
out_top_logprobs: List[Optional[Dict[str, float]]] = []
|
408 |
|
409 |
+
last_token_len = 0
|
410 |
|
411 |
for i, token_id in enumerate(token_ids):
|
412 |
step_top_logprobs = top_logprobs[i]
|
413 |
if step_top_logprobs is None:
|
414 |
+
token = self.tokenizer.decode(token_id)
|
415 |
+
out_tokens.append(token)
|
416 |
+
out_token_logprobs.append(None)
|
417 |
+
out_top_logprobs.append(None)
|
418 |
+
else:
|
419 |
+
token = self._get_decoded_token(step_top_logprobs[token_id],
|
420 |
+
token_id)
|
421 |
+
token_logprob = max(step_top_logprobs[token_id].logprob,
|
422 |
+
-9999.0)
|
423 |
+
out_tokens.append(token)
|
424 |
+
out_token_logprobs.append(token_logprob)
|
425 |
+
|
426 |
+
# makes sure to add the top num_output_top_logprobs + 1
|
427 |
+
# logprobs, as defined in the openai API
|
428 |
+
# (cf. https://github.com/openai/openai-openapi/blob/
|
429 |
+
# 893ba52242dbd5387a97b96444ee1c742cfce9bd/openapi.yaml#L7153)
|
430 |
+
out_top_logprobs.append({
|
431 |
+
# Convert float("-inf") to the
|
432 |
+
# JSON-serializable float that OpenAI uses
|
433 |
+
self._get_decoded_token(top_lp[1], top_lp[0]):
|
434 |
+
max(top_lp[1].logprob, -9999.0)
|
435 |
+
for i, top_lp in enumerate(step_top_logprobs.items())
|
436 |
+
if num_output_top_logprobs >= i
|
437 |
+
})
|
438 |
+
|
439 |
+
if len(out_text_offset) == 0:
|
440 |
+
out_text_offset.append(initial_text_offset)
|
441 |
else:
|
442 |
+
out_text_offset.append(out_text_offset[-1] + last_token_len)
|
443 |
+
last_token_len = len(token)
|
444 |
+
|
445 |
+
return CompletionLogProbs(
|
446 |
+
text_offset=out_text_offset,
|
447 |
+
token_logprobs=out_token_logprobs,
|
448 |
+
tokens=out_tokens,
|
449 |
+
top_logprobs=out_top_logprobs,
|
450 |
+
)
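_create_completion_logprobs fills the legacy Completions-style logprobs object: four parallel arrays indexed by emitted token. Roughly, for two generated tokens it yields something shaped like this (values invented for illustration):

from vllm.entrypoints.openai.protocol import CompletionLogProbs

CompletionLogProbs(
    text_offset=[0, 5],                      # running character offset of each token
    token_logprobs=[-0.11, -1.62],           # sampled-token logprobs (clamped at -9999.0)
    tokens=["Hello", " world"],
    top_logprobs=[{"Hello": -0.11, "Hi": -2.41},
                  {" world": -1.62, " there": -2.87}],
)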
|
451 |
+
|
452 |
+
async def create_tokenize(self,
|
453 |
+
request: TokenizeRequest) -> TokenizeResponse:
|
454 |
+
error_check_ret = await self._check_model(request)
|
455 |
+
if error_check_ret is not None:
|
456 |
+
return error_check_ret
|
457 |
+
|
458 |
+
(input_ids, input_text) = self._validate_prompt_and_tokenize(
|
459 |
+
request,
|
460 |
+
prompt=request.prompt,
|
461 |
+
add_special_tokens=request.add_special_tokens)
|
462 |
+
|
463 |
+
return TokenizeResponse(tokens=input_ids,
|
464 |
+
count=len(input_ids),
|
465 |
+
max_model_len=self.max_model_len)
|
466 |
+
|
467 |
+
async def create_detokenize(
|
468 |
+
self, request: DetokenizeRequest) -> DetokenizeResponse:
|
469 |
+
error_check_ret = await self._check_model(request)
|
470 |
+
if error_check_ret is not None:
|
471 |
+
return error_check_ret
|
472 |
+
|
473 |
+
(input_ids, input_text) = self._validate_prompt_and_tokenize(
|
474 |
+
request, prompt_ids=request.tokens)
|
475 |
+
|
476 |
+
return DetokenizeResponse(prompt=input_text)
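create_tokenize and create_detokenize back the new tokenizer endpoints registered in api_server.py. A hedged sketch of calling them from a client; the /tokenize and /detokenize paths and the local URL are assumptions here, while the field names come from the request/response models used above:

import requests

base = "http://localhost:8000"  # assumed local deployment

tok = requests.post(f"{base}/tokenize",
                    json={"model": "my-model", "prompt": "Hello world"}).json()
print(tok["tokens"], tok["count"], tok["max_model_len"])

detok = requests.post(f"{base}/detokenize",
                      json={"model": "my-model", "tokens": tok["tokens"]}).json()
print(detok["prompt"])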
|
serving_embedding.py
ADDED
@@ -0,0 +1,144 @@
|
|
|
1 |
+
import base64
|
2 |
+
import time
|
3 |
+
from typing import AsyncIterator, List, Optional, Tuple
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
from fastapi import Request
|
7 |
+
|
8 |
+
from vllm.config import ModelConfig
|
9 |
+
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
10 |
+
from vllm.entrypoints.openai.protocol import (EmbeddingRequest,
|
11 |
+
EmbeddingResponse,
|
12 |
+
EmbeddingResponseData, UsageInfo)
|
13 |
+
from vllm.entrypoints.openai.serving_completion import parse_prompt_format
|
14 |
+
from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
15 |
+
from vllm.logger import init_logger
|
16 |
+
from vllm.outputs import EmbeddingRequestOutput
|
17 |
+
from vllm.utils import merge_async_iterators, random_uuid
|
18 |
+
|
19 |
+
logger = init_logger(__name__)
|
20 |
+
|
21 |
+
TypeTokenIDs = List[int]
|
22 |
+
|
23 |
+
|
24 |
+
def request_output_to_embedding_response(
|
25 |
+
final_res_batch: List[EmbeddingRequestOutput], request_id: str,
|
26 |
+
created_time: int, model_name: str,
|
27 |
+
encoding_format: str) -> EmbeddingResponse:
|
28 |
+
data: List[EmbeddingResponseData] = []
|
29 |
+
num_prompt_tokens = 0
|
30 |
+
for idx, final_res in enumerate(final_res_batch):
|
31 |
+
assert final_res is not None
|
32 |
+
prompt_token_ids = final_res.prompt_token_ids
|
33 |
+
embedding = final_res.outputs.embedding
|
34 |
+
if encoding_format == "base64":
|
35 |
+
embedding = base64.b64encode(np.array(embedding))
|
36 |
+
embedding_data = EmbeddingResponseData(index=idx, embedding=embedding)
|
37 |
+
data.append(embedding_data)
|
38 |
+
|
39 |
+
num_prompt_tokens += len(prompt_token_ids)
|
40 |
+
|
41 |
+
usage = UsageInfo(
|
42 |
+
prompt_tokens=num_prompt_tokens,
|
43 |
+
total_tokens=num_prompt_tokens,
|
44 |
+
)
|
45 |
+
|
46 |
+
return EmbeddingResponse(
|
47 |
+
id=request_id,
|
48 |
+
created=created_time,
|
49 |
+
model=model_name,
|
50 |
+
data=data,
|
51 |
+
usage=usage,
|
52 |
+
)
|
53 |
+
|
54 |
+
|
55 |
+
class OpenAIServingEmbedding(OpenAIServing):
|
56 |
+
|
57 |
+
def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig,
|
58 |
+
served_model_names: List[str]):
|
59 |
+
super().__init__(engine=engine,
|
60 |
+
model_config=model_config,
|
61 |
+
served_model_names=served_model_names,
|
62 |
+
lora_modules=None)
|
63 |
+
self._check_embedding_mode(model_config.embedding_mode)
|
64 |
+
|
65 |
+
async def create_embedding(self, request: EmbeddingRequest,
|
66 |
+
raw_request: Request):
|
67 |
+
"""Completion API similar to OpenAI's API.
|
68 |
+
|
69 |
+
See https://platform.openai.com/docs/api-reference/embeddings/create
|
70 |
+
for the API specification. This API mimics the OpenAI Embedding API.
|
71 |
+
"""
|
72 |
+
error_check_ret = await self._check_model(request)
|
73 |
+
if error_check_ret is not None:
|
74 |
+
return error_check_ret
|
75 |
+
|
76 |
+
encoding_format = (request.encoding_format
|
77 |
+
if request.encoding_format else "float")
|
78 |
+
if request.dimensions is not None:
|
79 |
+
return self.create_error_response(
|
80 |
+
"dimensions is currently not supported")
|
81 |
+
|
82 |
+
model_name = request.model
|
83 |
+
request_id = f"cmpl-{random_uuid()}"
|
84 |
+
created_time = int(time.monotonic())
|
85 |
+
|
86 |
+
# Schedule the request and get the result generator.
|
87 |
+
generators = []
|
88 |
+
try:
|
89 |
+
prompt_is_tokens, prompts = parse_prompt_format(request.input)
|
90 |
+
pooling_params = request.to_pooling_params()
|
91 |
+
|
92 |
+
for i, prompt in enumerate(prompts):
|
93 |
+
if prompt_is_tokens:
|
94 |
+
prompt_formats = self._validate_prompt_and_tokenize(
|
95 |
+
request, prompt_ids=prompt)
|
96 |
+
else:
|
97 |
+
prompt_formats = self._validate_prompt_and_tokenize(
|
98 |
+
request, prompt=prompt)
|
99 |
+
|
100 |
+
prompt_ids, prompt_text = prompt_formats
|
101 |
+
|
102 |
+
generator = self.engine.encode(
|
103 |
+
{
|
104 |
+
"prompt": prompt_text,
|
105 |
+
"prompt_token_ids": prompt_ids
|
106 |
+
},
|
107 |
+
pooling_params,
|
108 |
+
f"{request_id}-{i}",
|
109 |
+
)
|
110 |
+
|
111 |
+
generators.append(generator)
|
112 |
+
except ValueError as e:
|
113 |
+
# TODO: Use a vllm-specific Validation Error
|
114 |
+
return self.create_error_response(str(e))
|
115 |
+
|
116 |
+
result_generator: AsyncIterator[Tuple[
|
117 |
+
int, EmbeddingRequestOutput]] = merge_async_iterators(*generators)
|
118 |
+
|
119 |
+
# Non-streaming response
|
120 |
+
final_res_batch: List[Optional[EmbeddingRequestOutput]]
|
121 |
+
final_res_batch = [None] * len(prompts)
|
122 |
+
try:
|
123 |
+
async for i, res in result_generator:
|
124 |
+
if await raw_request.is_disconnected():
|
125 |
+
# Abort the request if the client disconnects.
|
126 |
+
await self.engine.abort(f"{request_id}-{i}")
|
127 |
+
# TODO: Use a vllm-specific Validation Error
|
128 |
+
return self.create_error_response("Client disconnected")
|
129 |
+
final_res_batch[i] = res
|
130 |
+
response = request_output_to_embedding_response(
|
131 |
+
final_res_batch, request_id, created_time, model_name,
|
132 |
+
encoding_format)
|
133 |
+
except ValueError as e:
|
134 |
+
# TODO: Use a vllm-specific Validation Error
|
135 |
+
return self.create_error_response(str(e))
|
136 |
+
|
137 |
+
return response
|
138 |
+
|
139 |
+
def _check_embedding_mode(self, embedding_mode: bool):
|
140 |
+
if not embedding_mode:
|
141 |
+
logger.warning(
|
142 |
+
"embedding_mode is False. Embedding API will not work.")
|
143 |
+
else:
|
144 |
+
logger.info("Activating the server engine with embedding enabled.")
|
serving_engine.py
CHANGED
@@ -10,9 +10,10 @@ from vllm.config import ModelConfig
|
|
10 |
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
11 |
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
12 |
CompletionRequest,
|
|
|
13 |
EmbeddingRequest, ErrorResponse,
|
14 |
ModelCard, ModelList,
|
15 |
-
ModelPermission)
|
16 |
from vllm.logger import init_logger
|
17 |
from vllm.lora.request import LoRARequest
|
18 |
from vllm.sequence import Logprob
|
@@ -35,6 +36,7 @@ class OpenAIServing:
|
|
35 |
super().__init__()
|
36 |
|
37 |
self.engine = engine
|
|
|
38 |
self.max_model_len = model_config.max_model_len
|
39 |
|
40 |
# A separate tokenizer to map token IDs to strings.
|
@@ -99,8 +101,9 @@ class OpenAIServing:
|
|
99 |
return json_str
|
100 |
|
101 |
async def _check_model(
|
102 |
-
self, request: Union[
|
103 |
-
EmbeddingRequest
|
|
|
104 |
) -> Optional[ErrorResponse]:
|
105 |
if request.model in self.served_model_names:
|
106 |
return None
|
@@ -126,7 +129,8 @@ class OpenAIServing:
|
|
126 |
def _validate_prompt_and_tokenize(
|
127 |
self,
|
128 |
request: Union[ChatCompletionRequest, CompletionRequest,
|
129 |
-
EmbeddingRequest
|
|
|
130 |
prompt: Optional[str] = None,
|
131 |
prompt_ids: Optional[List[int]] = None,
|
132 |
truncate_prompt_tokens: Optional[Annotated[int,
|
@@ -174,6 +178,11 @@ class OpenAIServing:
|
|
174 |
f"generation. Please reduce the length of the input.", )
|
175 |
return input_ids, input_text
|
176 |
|
|
|
177 |
if request.max_tokens is None:
|
178 |
if token_num >= self.max_model_len:
|
179 |
raise ValueError(
|
|
|
10 |
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
11 |
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
12 |
CompletionRequest,
|
13 |
+
DetokenizeRequest,
|
14 |
EmbeddingRequest, ErrorResponse,
|
15 |
ModelCard, ModelList,
|
16 |
+
ModelPermission, TokenizeRequest)
|
17 |
from vllm.logger import init_logger
|
18 |
from vllm.lora.request import LoRARequest
|
19 |
from vllm.sequence import Logprob
|
|
|
36 |
super().__init__()
|
37 |
|
38 |
self.engine = engine
|
39 |
+
self.model_config = model_config
|
40 |
self.max_model_len = model_config.max_model_len
|
41 |
|
42 |
# A separate tokenizer to map token IDs to strings.
|
|
|
101 |
return json_str
|
102 |
|
103 |
async def _check_model(
|
104 |
+
self, request: Union[ChatCompletionRequest, CompletionRequest,
|
105 |
+
DetokenizeRequest, EmbeddingRequest,
|
106 |
+
TokenizeRequest]
|
107 |
) -> Optional[ErrorResponse]:
|
108 |
if request.model in self.served_model_names:
|
109 |
return None
|
|
|
129 |
def _validate_prompt_and_tokenize(
|
130 |
self,
|
131 |
request: Union[ChatCompletionRequest, CompletionRequest,
|
132 |
+
DetokenizeRequest, EmbeddingRequest,
|
133 |
+
TokenizeRequest],
|
134 |
prompt: Optional[str] = None,
|
135 |
prompt_ids: Optional[List[int]] = None,
|
136 |
truncate_prompt_tokens: Optional[Annotated[int,
|
|
|
178 |
f"generation. Please reduce the length of the input.", )
|
179 |
return input_ids, input_text
|
180 |
|
181 |
+
# Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens
|
182 |
+
# and does not require model context length validation
|
183 |
+
if isinstance(request, (TokenizeRequest, DetokenizeRequest)):
|
184 |
+
return input_ids, input_text
|
185 |
+
|
186 |
if request.max_tokens is None:
|
187 |
if token_num >= self.max_model_len:
|
188 |
raise ValueError(
|