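"""OpenAI-compatible chat completion server built on FastAPI and vLLM's
AsyncLLMEngine. It exposes health, model-listing, and (streaming) chat
completion endpoints for a GLM-4-style chat model, including basic parsing
of tool calls from the model output."""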
import os
import time
from asyncio.log import logger
import uvicorn
import gc
import json
import torch
from vllm import SamplingParams, AsyncEngineArgs, AsyncLLMEngine
from fastapi import FastAPI, HTTPException, Response
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
from typing import List, Literal, Optional, Union
from pydantic import BaseModel, Field
from transformers import AutoTokenizer, LogitsProcessor
from sse_starlette.sse import EventSourceResponse

EventSourceResponse.DEFAULT_PING_INTERVAL = 1000

MODEL_PATH = "../llama-factory/merged_models/internlm2_5-7b-chat-1m_sft_bf16_p2_full"
MAX_MODEL_LENGTH = 8192


@asynccontextmanager
async def lifespan(app: FastAPI):
    yield
    # Free GPU memory when the app shuts down.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()


app = FastAPI(lifespan=lifespan)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class ModelCard(BaseModel):
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "owner"
    root: Optional[str] = None
    parent: Optional[str] = None
    permission: Optional[list] = None


class ModelList(BaseModel):
    object: str = "list"
    data: List[ModelCard] = []


class FunctionCallResponse(BaseModel):
    name: Optional[str] = None
    arguments: Optional[str] = None


class ChatMessage(BaseModel):
    role: Literal["user", "assistant", "system", "tool"]
    content: Optional[str] = None
    name: Optional[str] = None
    function_call: Optional[FunctionCallResponse] = None


class DeltaMessage(BaseModel):
    role: Optional[Literal["user", "assistant", "system"]] = None
    content: Optional[str] = None
    function_call: Optional[FunctionCallResponse] = None


class EmbeddingRequest(BaseModel):
    input: Union[List[str], str]
    model: str


class CompletionUsage(BaseModel):
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


class EmbeddingResponse(BaseModel):
    data: list
    model: str
    object: str
    usage: CompletionUsage


class UsageInfo(BaseModel):
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0


class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[ChatMessage]
    temperature: Optional[float] = 0.8
    top_p: Optional[float] = 0.8
    max_tokens: Optional[int] = None
    stream: Optional[bool] = False
    tools: Optional[Union[dict, List[dict]]] = None
    # Note: the default "None" (capital N) differs from the "none" sentinel checked
    # in process_messages(), so tool definitions are injected whenever tools are set.
    tool_choice: Optional[Union[str, dict]] = "None"
    repetition_penalty: Optional[float] = 1.1


class ChatCompletionResponseChoice(BaseModel):
    index: int
    message: ChatMessage
    finish_reason: Literal["stop", "length", "function_call"]


class ChatCompletionResponseStreamChoice(BaseModel):
    delta: DeltaMessage
    finish_reason: Optional[Literal["stop", "length", "function_call"]]
    index: int


class ChatCompletionResponse(BaseModel):
    model: str
    id: str
    object: Literal["chat.completion", "chat.completion.chunk"]
    choices: List[
        Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]
    ]
    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
    usage: Optional[UsageInfo] = None


class InvalidScoreLogitsProcessor(LogitsProcessor):
    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        # If the logits contain NaN/Inf, zero them out and strongly boost token id 5
        # so sampling does not crash on an invalid distribution.
        if torch.isnan(scores).any() or torch.isinf(scores).any():
            scores.zero_()
            scores[..., 5] = 5e4
        return scores
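

# process_response() parses the raw text produced by the model. Based on the parsing
# logic below, a tool call is assumed to look roughly like:
#     get_weather
#     {"location": "Beijing"}
# i.e. the first line carries the function name ("metadata") and the remaining lines
# carry the arguments, which are evaluated and re-serialized as JSON. (The example
# name and arguments above are illustrative, not taken from the model.)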
def process_response(output: str, use_tool: bool = False) -> Union[str, dict]:
    content = ""
    for response in output.split("<|assistant|>"):
        if "\n" in response:
            metadata, content = response.split("\n", maxsplit=1)
        else:
            metadata, content = "", response
        if not metadata.strip():
            content = content.strip()
        else:
            if use_tool:
                parameters = eval(content.strip())
                content = {
                    "name": metadata.strip(),
                    "arguments": json.dumps(parameters, ensure_ascii=False),
                }
            else:
                content = {"name": metadata.strip(), "content": content}
    return content
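

# generate_stream_glm4() drives vLLM's AsyncLLMEngine and yields, for every generation
# step, the cumulative output text together with token usage and the finish reason.
# The hard-coded stop_token_ids below are assumed to be the GLM-4 chat special tokens;
# they would need to change for a different tokenizer.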
async def generate_stream_glm4(params):
    messages = params["messages"]
    tools = params["tools"]
    tool_choice = params["tool_choice"]
    temperature = float(params.get("temperature", 1.0))
    repetition_penalty = float(params.get("repetition_penalty", 1.0))
    top_p = float(params.get("top_p", 1.0))
    max_new_tokens = int(params.get("max_tokens", 8192))

    messages = process_messages(messages, tools=tools, tool_choice=tool_choice)
    inputs = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=False
    )
    params_dict = {
        "n": 1,
        "best_of": 1,
        "presence_penalty": 1.0,
        "frequency_penalty": 0.0,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": -1,
        "repetition_penalty": repetition_penalty,
        "use_beam_search": False,
        "length_penalty": 1,
        "early_stopping": False,
        "stop_token_ids": [151329, 151336, 151338],
        "ignore_eos": False,
        "max_tokens": max_new_tokens,
        "logprobs": None,
        "prompt_logprobs": None,
        "skip_special_tokens": True,
    }
    sampling_params = SamplingParams(**params_dict)
    # vLLM expects a unique request_id for every generate() call.
    async for output in engine.generate(
        inputs=inputs,
        sampling_params=sampling_params,
        request_id=f"glm-4-9b-{time.time()}",
    ):
        output_len = len(output.outputs[0].token_ids)
        input_len = len(output.prompt_token_ids)
        ret = {
            "text": output.outputs[0].text,
            "usage": {
                "prompt_tokens": input_len,
                "completion_tokens": output_len,
                "total_tokens": output_len + input_len,
            },
            "finish_reason": output.outputs[0].finish_reason,
        }
        yield ret
    gc.collect()
    torch.cuda.empty_cache()
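

# process_messages() converts the OpenAI-style message list into the role layout
# expected by the GLM-4 chat template: tool definitions ride on a system message,
# tool/function results become "observation" messages, and an assistant message's
# "metadata" field carries the name of the function being called.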
def process_messages(messages, tools=None, tool_choice="none"):
    _messages = messages
    messages = []
    msg_has_sys = False

    def filter_tools(tool_choice, tools):
        function_name = tool_choice.get("function", {}).get("name", None)
        if not function_name:
            return []
        filtered_tools = [
            tool
            for tool in tools
            if tool.get("function", {}).get("name") == function_name
        ]
        return filtered_tools

    if tool_choice != "none":
        if isinstance(tool_choice, dict):
            tools = filter_tools(tool_choice, tools)
        if tools:
            messages.append({"role": "system", "content": None, "tools": tools})
            msg_has_sys = True

    # add to metadata
    if isinstance(tool_choice, dict) and tools:
        messages.append(
            {
                "role": "assistant",
                "metadata": tool_choice["function"]["name"],
                "content": "",
            }
        )
    for m in _messages:
        role, content, func_call = m.role, m.content, m.function_call
        # Tool/function results are mapped onto the "observation" role.
        if role in ("function", "tool"):
            messages.append({"role": "observation", "content": content})
        elif role == "assistant" and func_call is not None:
            for response in content.split("<|assistant|>"):
                if "\n" in response:
                    metadata, sub_content = response.split("\n", maxsplit=1)
                else:
                    metadata, sub_content = "", response
                messages.append(
                    {"role": role, "metadata": metadata, "content": sub_content.strip()}
                )
        else:
            if role == "system" and msg_has_sys:
                msg_has_sys = False
                continue
            messages.append({"role": role, "content": content})
    return messages


@app.get("/health")
async def health() -> Response:
    """Health check."""
    return Response(status_code=200)


@app.get("/v1/models", response_model=ModelList)
async def list_models():
    model_card = ModelCard(id="glm-4")
    return ModelList(data=[model_card])


@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
    if len(request.messages) < 1 or request.messages[-1].role == "assistant":
        raise HTTPException(status_code=400, detail="Invalid request")

    gen_params = dict(
        messages=request.messages,
        temperature=request.temperature,
        top_p=request.top_p,
        max_tokens=request.max_tokens or 1024,
        echo=False,
        stream=request.stream,
        repetition_penalty=request.repetition_penalty,
        tools=request.tools,
        tool_choice=request.tool_choice,
    )
logger.debug(f"==== request ====\n{gen_params}") | |
if request.stream: | |
predict_stream_generator = predict_stream(request.model, gen_params) | |
output = await anext(predict_stream_generator) | |
if output: | |
return EventSourceResponse( | |
predict_stream_generator, media_type="text/event-stream" | |
) | |
logger.debug(f"First result output:\n{output}") | |
function_call = None | |
if output and request.tools: | |
try: | |
function_call = process_response(output, use_tool=True) | |
except: | |
logger.warning("Failed to parse tool call") | |
# CallFunction | |
if isinstance(function_call, dict): | |
function_call = FunctionCallResponse(**function_call) | |
tool_response = "" | |
if not gen_params.get("messages"): | |
gen_params["messages"] = [] | |
gen_params["messages"].append(ChatMessage(role="assistant", content=output)) | |
gen_params["messages"].append( | |
ChatMessage(role="tool", name=function_call.name, content=tool_response) | |
) | |
generate = predict(request.model, gen_params) | |
return EventSourceResponse(generate, media_type="text/event-stream") | |
else: | |
generate = parse_output_text(request.model, output) | |
return EventSourceResponse(generate, media_type="text/event-stream") | |
response = "" | |
async for response in generate_stream_glm4(gen_params): | |
pass | |
if response["text"].startswith("\n"): | |
response["text"] = response["text"][1:] | |
response["text"] = response["text"].strip() | |
usage = UsageInfo() | |
function_call, finish_reason = None, "stop" | |
if request.tools: | |
try: | |
function_call = process_response(response["text"], use_tool=True) | |
except: | |
logger.warning( | |
"Failed to parse tool call, maybe the response is not a function call(such as cogview drawing) or have been answered." | |
) | |
if isinstance(function_call, dict): | |
finish_reason = "function_call" | |
function_call = FunctionCallResponse(**function_call) | |
message = ChatMessage( | |
role="assistant", | |
content=response["text"], | |
function_call=( | |
function_call if isinstance(function_call, FunctionCallResponse) else None | |
), | |
) | |
logger.debug(f"==== message ====\n{message}") | |
choice_data = ChatCompletionResponseChoice( | |
index=0, | |
message=message, | |
finish_reason=finish_reason, | |
) | |
task_usage = UsageInfo.model_validate(response["usage"]) | |
for usage_key, usage_value in task_usage.model_dump().items(): | |
setattr(usage, usage_key, getattr(usage, usage_key) + usage_value) | |
return ChatCompletionResponse( | |
model=request.model, | |
id="", # for open_source model, id is empty | |
choices=[choice_data], | |
object="chat.completion", | |
usage=usage, | |
) | |
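

# predict() streams the full OpenAI "chat.completion.chunk" protocol over SSE:
# an initial chunk carrying only the assistant role, one chunk per text delta
# (with any parsed function call attached), a final chunk with
# finish_reason="stop", and a closing "[DONE]" sentinel.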
async def predict(model_id: str, params: dict):
    choice_data = ChatCompletionResponseStreamChoice(
        index=0, delta=DeltaMessage(role="assistant"), finish_reason=None
    )
    chunk = ChatCompletionResponse(
        model=model_id, id="", choices=[choice_data], object="chat.completion.chunk"
    )
    yield "{}".format(chunk.model_dump_json(exclude_unset=True))

    previous_text = ""
    async for new_response in generate_stream_glm4(params):
        decoded_unicode = new_response["text"]
        delta_text = decoded_unicode[len(previous_text):]
        previous_text = decoded_unicode
        finish_reason = new_response["finish_reason"]
        if len(delta_text) == 0 and finish_reason != "function_call":
            continue

        function_call = None
        if finish_reason == "function_call":
            try:
                function_call = process_response(decoded_unicode, use_tool=True)
            except Exception:
                logger.warning(
                    "Failed to parse tool call; the response may not be a tool "
                    "call or may already have been answered."
                )
        if isinstance(function_call, dict):
            function_call = FunctionCallResponse(**function_call)

        delta = DeltaMessage(
            content=delta_text,
            role="assistant",
            function_call=(
                function_call
                if isinstance(function_call, FunctionCallResponse)
                else None
            ),
        )
        choice_data = ChatCompletionResponseStreamChoice(
            index=0, delta=delta, finish_reason=finish_reason
        )
        chunk = ChatCompletionResponse(
            model=model_id, id="", choices=[choice_data], object="chat.completion.chunk"
        )
        yield "{}".format(chunk.model_dump_json(exclude_unset=True))

    choice_data = ChatCompletionResponseStreamChoice(
        index=0, delta=DeltaMessage(), finish_reason="stop"
    )
    chunk = ChatCompletionResponse(
        model=model_id, id="", choices=[choice_data], object="chat.completion.chunk"
    )
    yield "{}".format(chunk.model_dump_json(exclude_unset=True))
    yield "[DONE]"
async def predict_stream(model_id, gen_params):
    output = ""
    is_function_call = False
    has_send_first_chunk = False
    async for new_response in generate_stream_glm4(gen_params):
        decoded_unicode = new_response["text"]
        delta_text = decoded_unicode[len(output):]
        output = decoded_unicode

        if not is_function_call and len(output) > 7:
            is_function_call = output and "get_" in output
        if is_function_call:
            continue

        finish_reason = new_response["finish_reason"]
        if not has_send_first_chunk:
            message = DeltaMessage(
                content="",
                role="assistant",
                function_call=None,
            )
            choice_data = ChatCompletionResponseStreamChoice(
                index=0, delta=message, finish_reason=finish_reason
            )
            chunk = ChatCompletionResponse(
                model=model_id,
                id="",
                choices=[choice_data],
                created=int(time.time()),
                object="chat.completion.chunk",
            )
            yield "{}".format(chunk.model_dump_json(exclude_unset=True))

        send_msg = delta_text if has_send_first_chunk else output
        has_send_first_chunk = True
        message = DeltaMessage(
            content=send_msg,
            role="assistant",
            function_call=None,
        )
        choice_data = ChatCompletionResponseStreamChoice(
            index=0, delta=message, finish_reason=finish_reason
        )
        chunk = ChatCompletionResponse(
            model=model_id,
            id="",
            choices=[choice_data],
            created=int(time.time()),
            object="chat.completion.chunk",
        )
        yield "{}".format(chunk.model_dump_json(exclude_unset=True))

    if is_function_call:
        yield output
    else:
        yield "[DONE]"


async def parse_output_text(model_id: str, value: str):
    choice_data = ChatCompletionResponseStreamChoice(
        index=0, delta=DeltaMessage(role="assistant", content=value), finish_reason=None
    )
    chunk = ChatCompletionResponse(
        model=model_id, id="", choices=[choice_data], object="chat.completion.chunk"
    )
    yield "{}".format(chunk.model_dump_json(exclude_unset=True))

    choice_data = ChatCompletionResponseStreamChoice(
        index=0, delta=DeltaMessage(), finish_reason="stop"
    )
    chunk = ChatCompletionResponse(
        model=model_id, id="", choices=[choice_data], object="chat.completion.chunk"
    )
    yield "{}".format(chunk.model_dump_json(exclude_unset=True))
    yield "[DONE]"


if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
    engine_args = AsyncEngineArgs(
        model=MODEL_PATH,
        tokenizer=MODEL_PATH,
        tensor_parallel_size=1,
        dtype="bfloat16",
        trust_remote_code=True,
        gpu_memory_utilization=0.9,
        enforce_eager=True,
        worker_use_ray=True,
        engine_use_ray=False,
        disable_log_requests=True,
        max_model_len=MAX_MODEL_LENGTH,
    )
    engine = AsyncLLMEngine.from_engine_args(engine_args)
    uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
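
# Minimal client sketch (assumes the server is reachable at localhost:8000 via the
# routes registered above; the model name must be one listed by /v1/models):
#
#   curl http://localhost:8000/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "glm-4",
#          "messages": [{"role": "user", "content": "Hello"}],
#          "stream": false}'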