dh-mc committed
Commit d2aae87 · 1 Parent(s): 4aab576

creating openai compatible server

TurtleSoupBaseline/openai_api_server.py CHANGED
@@ -17,7 +17,9 @@ from transformers import AutoTokenizer, LogitsProcessor
 from sse_starlette.sse import EventSourceResponse
 
 EventSourceResponse.DEFAULT_PING_INTERVAL = 1000
-MODEL_PATH = 'THUDM/glm-4-9b-chat'
+MODEL_PATH = (
+    "../llama-factory/saves/internlm2_5_7b/lora/sft_bf16_p2_full/checkpoint-528"
+)
 MAX_MODEL_LENGTH = 8192
 
 
@@ -125,14 +127,16 @@ class ChatCompletionResponse(BaseModel):
     model: str
     id: str
     object: Literal["chat.completion", "chat.completion.chunk"]
-    choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
+    choices: List[
+        Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]
+    ]
     created: Optional[int] = Field(default_factory=lambda: int(time.time()))
     usage: Optional[UsageInfo] = None
 
 
 class InvalidScoreLogitsProcessor(LogitsProcessor):
     def __call__(
-            self, input_ids: torch.LongTensor, scores: torch.FloatTensor
+        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
     ) -> torch.FloatTensor:
         if torch.isnan(scores).any() or torch.isinf(scores).any():
             scores.zero_()
@@ -154,13 +158,10 @@ def process_response(output: str, use_tool: bool = False) -> Union[str, dict]:
                 parameters = eval(content.strip())
                 content = {
                     "name": metadata.strip(),
-                    "arguments": json.dumps(parameters, ensure_ascii=False)
+                    "arguments": json.dumps(parameters, ensure_ascii=False),
                 }
             else:
-                content = {
-                    "name": metadata.strip(),
-                    "content": content
-                }
+                content = {"name": metadata.strip(), "content": content}
     return content
 
 
@@ -174,7 +175,9 @@ async def generate_stream_glm4(params):
     top_p = float(params.get("top_p", 1.0))
     max_new_tokens = int(params.get("max_tokens", 8192))
     messages = process_messages(messages, tools=tools, tool_choice=tool_choice)
-    inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
+    inputs = tokenizer.apply_chat_template(
+        messages, add_generation_prompt=True, tokenize=False
+    )
     params_dict = {
         "n": 1,
         "best_of": 1,
@@ -195,7 +198,9 @@ async def generate_stream_glm4(params):
         "skip_special_tokens": True,
     }
     sampling_params = SamplingParams(**params_dict)
-    async for output in engine.generate(inputs=inputs, sampling_params=sampling_params, request_id="glm-4-9b"):
+    async for output in engine.generate(
+        inputs=inputs, sampling_params=sampling_params, request_id="glm-4-9b"
+    ):
         output_len = len(output.outputs[0].token_ids)
         input_len = len(output.prompt_token_ids)
         ret = {
@@ -203,7 +208,7 @@ async def generate_stream_glm4(params):
             "usage": {
                 "prompt_tokens": input_len,
                 "completion_tokens": output_len,
-                "total_tokens": output_len + input_len
+                "total_tokens": output_len + input_len,
             },
             "finish_reason": output.outputs[0].finish_reason,
         }
@@ -218,12 +223,13 @@ def process_messages(messages, tools=None, tool_choice="none"):
     msg_has_sys = False
 
     def filter_tools(tool_choice, tools):
-        function_name = tool_choice.get('function', {}).get('name', None)
+        function_name = tool_choice.get("function", {}).get("name", None)
         if not function_name:
            return []
        filtered_tools = [
-            tool for tool in tools
-            if tool.get('function', {}).get('name') == function_name
+            tool
+            for tool in tools
+            if tool.get("function", {}).get("name") == function_name
        ]
        return filtered_tools
 
@@ -231,13 +237,7 @@ def process_messages(messages, tools=None, tool_choice="none"):
         if isinstance(tool_choice, dict):
             tools = filter_tools(tool_choice, tools)
         if tools:
-            messages.append(
-                {
-                    "role": "system",
-                    "content": None,
-                    "tools": tools
-                }
-            )
+            messages.append({"role": "system", "content": None, "tools": tools})
             msg_has_sys = True
 
     # add to metadata
@@ -246,19 +246,14 @@ def process_messages(messages, tools=None, tool_choice="none"):
             {
                 "role": "assistant",
                 "metadata": tool_choice["function"]["name"],
-                "content": ""
+                "content": "",
             }
         )
 
     for m in _messages:
         role, content, func_call = m.role, m.content, m.function_call
         if role == "function":
-            messages.append(
-                {
-                    "role": "observation",
-                    "content": content
-                }
-            )
+            messages.append({"role": "observation", "content": content})
         elif role == "assistant" and func_call is not None:
             for response in content.split("<|assistant|>"):
                 if "\n" in response:
@@ -266,11 +261,7 @@ def process_messages(messages, tools=None, tool_choice="none"):
                 else:
                     metadata, sub_content = "", response
                 messages.append(
-                    {
-                        "role": role,
-                        "metadata": metadata,
-                        "content": sub_content.strip()
-                    }
+                    {"role": role, "metadata": metadata, "content": sub_content.strip()}
                 )
         else:
             if role == "system" and msg_has_sys:
@@ -315,7 +306,9 @@ async def create_chat_completion(request: ChatCompletionRequest):
         predict_stream_generator = predict_stream(request.model, gen_params)
         output = await anext(predict_stream_generator)
         if output:
-            return EventSourceResponse(predict_stream_generator, media_type="text/event-stream")
+            return EventSourceResponse(
+                predict_stream_generator, media_type="text/event-stream"
+            )
         logger.debug(f"First result output:\n{output}")
 
         function_call = None
@@ -332,7 +325,9 @@ async def create_chat_completion(request: ChatCompletionRequest):
             if not gen_params.get("messages"):
                 gen_params["messages"] = []
             gen_params["messages"].append(ChatMessage(role="assistant", content=output))
-            gen_params["messages"].append(ChatMessage(role="tool", name=function_call.name, content=tool_response))
+            gen_params["messages"].append(
+                ChatMessage(role="tool", name=function_call.name, content=tool_response)
+            )
             generate = predict(request.model, gen_params)
             return EventSourceResponse(generate, media_type="text/event-stream")
         else:
@@ -354,7 +349,8 @@ async def create_chat_completion(request: ChatCompletionRequest):
             function_call = process_response(response["text"], use_tool=True)
         except:
             logger.warning(
-                "Failed to parse tool call, maybe the response is not a function call(such as cogview drawing) or have been answered.")
+                "Failed to parse tool call, maybe the response is not a function call(such as cogview drawing) or have been answered."
+            )
 
     if isinstance(function_call, dict):
         finish_reason = "function_call"
@@ -363,7 +359,9 @@ async def create_chat_completion(request: ChatCompletionRequest):
     message = ChatMessage(
         role="assistant",
         content=response["text"],
-        function_call=function_call if isinstance(function_call, FunctionCallResponse) else None,
+        function_call=(
+            function_call if isinstance(function_call, FunctionCallResponse) else None
+        ),
     )
 
     logger.debug(f"==== message ====\n{message}")
@@ -382,23 +380,23 @@ async def create_chat_completion(request: ChatCompletionRequest):
         id="",  # for open_source model, id is empty
         choices=[choice_data],
         object="chat.completion",
-        usage=usage
+        usage=usage,
     )
 
 
 async def predict(model_id: str, params: dict):
     choice_data = ChatCompletionResponseStreamChoice(
-        index=0,
-        delta=DeltaMessage(role="assistant"),
-        finish_reason=None
+        index=0, delta=DeltaMessage(role="assistant"), finish_reason=None
+    )
+    chunk = ChatCompletionResponse(
+        model=model_id, id="", choices=[choice_data], object="chat.completion.chunk"
     )
-    chunk = ChatCompletionResponse(model=model_id, id="", choices=[choice_data], object="chat.completion.chunk")
     yield "{}".format(chunk.model_dump_json(exclude_unset=True))
 
     previous_text = ""
     async for new_response in generate_stream_glm4(params):
         decoded_unicode = new_response["text"]
-        delta_text = decoded_unicode[len(previous_text):]
+        delta_text = decoded_unicode[len(previous_text) :]
         previous_text = decoded_unicode
 
         finish_reason = new_response["finish_reason"]
@@ -411,7 +409,8 @@ async def predict(model_id: str, params: dict):
                 function_call = process_response(decoded_unicode, use_tool=True)
             except:
                 logger.warning(
-                    "Failed to parse tool call, maybe the response is not a tool call or have been answered.")
+                    "Failed to parse tool call, maybe the response is not a tool call or have been answered."
+                )
 
         if isinstance(function_call, dict):
             function_call = FunctionCallResponse(**function_call)
@@ -419,48 +418,42 @@ async def predict(model_id: str, params: dict):
         delta = DeltaMessage(
             content=delta_text,
             role="assistant",
-            function_call=function_call if isinstance(function_call, FunctionCallResponse) else None,
+            function_call=(
+                function_call
+                if isinstance(function_call, FunctionCallResponse)
+                else None
+            ),
         )
 
         choice_data = ChatCompletionResponseStreamChoice(
-            index=0,
-            delta=delta,
-            finish_reason=finish_reason
+            index=0, delta=delta, finish_reason=finish_reason
        )
        chunk = ChatCompletionResponse(
-            model=model_id,
-            id="",
-            choices=[choice_data],
-            object="chat.completion.chunk"
+            model=model_id, id="", choices=[choice_data], object="chat.completion.chunk"
        )
        yield "{}".format(chunk.model_dump_json(exclude_unset=True))
 
     choice_data = ChatCompletionResponseStreamChoice(
-        index=0,
-        delta=DeltaMessage(),
-        finish_reason="stop"
+        index=0, delta=DeltaMessage(), finish_reason="stop"
    )
    chunk = ChatCompletionResponse(
-        model=model_id,
-        id="",
-        choices=[choice_data],
-        object="chat.completion.chunk"
+        model=model_id, id="", choices=[choice_data], object="chat.completion.chunk"
    )
    yield "{}".format(chunk.model_dump_json(exclude_unset=True))
-    yield '[DONE]'
+    yield "[DONE]"
 
 
 async def predict_stream(model_id, gen_params):
     output = ""
     is_function_call = False
     has_send_first_chunk = False
-    async for new_response in generate_stream_glm4(gen_params):
+    async for new_response in generate_stream_glm4(gen_params):
         decoded_unicode = new_response["text"]
-        delta_text = decoded_unicode[len(output):]
+        delta_text = decoded_unicode[len(output) :]
        output = decoded_unicode
 
        if not is_function_call and len(output) > 7:
-            is_function_call = output and 'get_' in output
+            is_function_call = output and "get_" in output
            if is_function_call:
                continue
@@ -472,16 +465,14 @@ async def predict_stream(model_id, gen_params):
                 function_call=None,
             )
             choice_data = ChatCompletionResponseStreamChoice(
-                index=0,
-                delta=message,
-                finish_reason=finish_reason
+                index=0, delta=message, finish_reason=finish_reason
             )
             chunk = ChatCompletionResponse(
                 model=model_id,
                 id="",
                 choices=[choice_data],
                 created=int(time.time()),
-                object="chat.completion.chunk"
+                object="chat.completion.chunk",
             )
             yield "{}".format(chunk.model_dump_json(exclude_unset=True))
 
@@ -493,41 +484,39 @@ async def predict_stream(model_id, gen_params):
                 function_call=None,
             )
             choice_data = ChatCompletionResponseStreamChoice(
-                index=0,
-                delta=message,
-                finish_reason=finish_reason
+                index=0, delta=message, finish_reason=finish_reason
             )
             chunk = ChatCompletionResponse(
                 model=model_id,
                 id="",
                 choices=[choice_data],
                 created=int(time.time()),
-                object="chat.completion.chunk"
+                object="chat.completion.chunk",
             )
             yield "{}".format(chunk.model_dump_json(exclude_unset=True))
 
     if is_function_call:
         yield output
     else:
-        yield '[DONE]'
+        yield "[DONE]"
 
 
 async def parse_output_text(model_id: str, value: str):
     choice_data = ChatCompletionResponseStreamChoice(
-        index=0,
-        delta=DeltaMessage(role="assistant", content=value),
-        finish_reason=None
+        index=0, delta=DeltaMessage(role="assistant", content=value), finish_reason=None
+    )
+    chunk = ChatCompletionResponse(
+        model=model_id, id="", choices=[choice_data], object="chat.completion.chunk"
    )
-    chunk = ChatCompletionResponse(model=model_id, id="", choices=[choice_data], object="chat.completion.chunk")
    yield "{}".format(chunk.model_dump_json(exclude_unset=True))
    choice_data = ChatCompletionResponseStreamChoice(
-        index=0,
-        delta=DeltaMessage(),
-        finish_reason="stop"
+        index=0, delta=DeltaMessage(), finish_reason="stop"
+    )
+    chunk = ChatCompletionResponse(
+        model=model_id, id="", choices=[choice_data], object="chat.completion.chunk"
    )
-    chunk = ChatCompletionResponse(model=model_id, id="", choices=[choice_data], object="chat.completion.chunk")
    yield "{}".format(chunk.model_dump_json(exclude_unset=True))
-    yield '[DONE]'
+    yield "[DONE]"
 
 
 if __name__ == "__main__":
@@ -546,4 +535,4 @@ if __name__ == "__main__":
         max_model_len=MAX_MODEL_LENGTH,
     )
     engine = AsyncLLMEngine.from_engine_args(engine_args)
-    uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
+    uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
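
Note on the main change: MODEL_PATH now points at a local LLaMA-Factory checkpoint instead of THUDM/glm-4-9b-chat, while generate_stream_glm4 still builds the prompt with tokenizer.apply_chat_template and tags vLLM requests as "glm-4-9b". A minimal sketch for checking that the new checkpoint's tokenizer actually ships a chat template before launching the server; trust_remote_code and the sample message are assumptions, and the path is copied verbatim from the diff:

```python
# Sketch: verify the new MODEL_PATH provides a usable chat template,
# since generate_stream_glm4 relies on tokenizer.apply_chat_template.
# trust_remote_code=True is an assumption (common for InternLM2 tokenizers);
# the checkpoint directory must exist locally.
from transformers import AutoTokenizer

MODEL_PATH = "../llama-factory/saves/internlm2_5_7b/lora/sft_bf16_p2_full/checkpoint-528"

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "hello"}],
    add_generation_prompt=True,
    tokenize=False,
)
print(prompt)  # prints the fully rendered prompt string the server will feed to vLLM
```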
requirements.txt CHANGED
@@ -14,3 +14,5 @@ langchain_openai==0.1.13
 wandb==0.17.4
 # triton
 # xformers
+uvicorn
+vllm
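
With uvicorn and vllm added, the server can be started with python TurtleSoupBaseline/openai_api_server.py and consumed like any OpenAI-compatible endpoint. A client sketch, assuming the app exposes the standard /v1/chat/completions route on port 8000, enforces no authentication, and treats the model name as a label only; the openai package (v1 SDK) is an extra dependency not listed above:

```python
# Sketch of a streaming request against the local server.
# Assumptions: /v1/chat/completions is the mounted route (standard for
# OpenAI-compatible servers), api_key is unused, and the model string is
# only echoed back in the response, so any placeholder works.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

stream = client.chat.completions.create(
    model="glm-4-9b",  # placeholder name; generation uses the server-side MODEL_PATH
    messages=[{"role": "user", "content": "Give me a short turtle soup puzzle."}],
    stream=True,
)
for chunk in stream:
    delta = chunk.choices[0].delta
    if delta.content:
        print(delta.content, end="", flush=True)
```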