import gc
import json
import time
import requests
import base64
import uvicorn
import argparse
import torch
from transformers import AutoModelForCausalLM, LlamaTokenizer, PreTrainedModel, PreTrainedTokenizer, \
    TextIteratorStreamer, CodeGenTokenizerFast as Tokenizer
from contextlib import asynccontextmanager
from loguru import logger
from typing import List, Literal, Union, Tuple, Optional
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from PIL import Image
from io import BytesIO
import os
import re
from threading import Thread
from moondream import Moondream, detect_device
import omnichat


# Request models
class TextContent(BaseModel):
    type: Literal["text"]
    text: str


class ImageUrl(BaseModel):
    url: str


class ImageUrlContent(BaseModel):
    type: Literal["image_url"]
    image_url: ImageUrl


ContentItem = Union[TextContent, ImageUrlContent]


class ChatMessageInput(BaseModel):
    role: Literal["user", "assistant", "system"]
    content: Union[str, List[ContentItem]]
    name: Optional[str] = None


class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[ChatMessageInput]
    temperature: Optional[float] = 0.8
    top_p: Optional[float] = 0.8
    max_tokens: Optional[int] = None
    stream: Optional[bool] = False
    # Additional parameters
    repetition_penalty: Optional[float] = 1.0


# Response models
class ChatMessageResponse(BaseModel):
    role: Literal["assistant"]
    content: Optional[str] = None
    name: Optional[str] = None


class ChatCompletionResponseChoice(BaseModel):
    index: int
    message: ChatMessageResponse


class DeltaMessage(BaseModel):
    role: Optional[Literal["user", "assistant", "system"]] = None
    content: Optional[str] = None


class ChatCompletionResponseStreamChoice(BaseModel):
    index: int
    delta: DeltaMessage


class UsageInfo(BaseModel):
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0


class ChatCompletionResponse(BaseModel):
    model: str
    object: Literal["chat.completion", "chat.completion.chunk"]
    choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
    usage: Optional[UsageInfo] = None
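
# Illustrative only: the kind of request body these models are meant to accept
# (a sketch, not taken from the original project's docs; all values are hypothetical):
#
# {
#   "model": "cogagent-chat-hf",
#   "messages": [
#     {"role": "user",
#      "content": [
#        {"type": "text", "text": "What is in this picture?"},
#        {"type": "image_url", "image_url": {"url": "https://example.com/cat.jpg"}}
#      ]}
#   ],
#   "temperature": 0.8,
#   "stream": false
# }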


# Image input handling
def process_img(input_data):
    if isinstance(input_data, str):
        # URL
        if input_data.startswith("http://") or input_data.startswith("https://"):
            response = requests.get(input_data)
            image_data = response.content
            pil_image = Image.open(BytesIO(image_data)).convert('RGB')
        # base64 data URI
        elif input_data.startswith("data:image/"):
            base64_data = input_data.split(",")[1]
            image_data = base64.b64decode(base64_data)
            pil_image = Image.open(BytesIO(image_data)).convert('RGB')
        # local file path
        else:
            pil_image = Image.open(input_data).convert('RGB')
    # already a PIL image
    elif isinstance(input_data, Image.Image):
        pil_image = input_data
    else:
        raise ValueError("Unsupported image input type")
    return pil_image
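
# process_img() accepts any of the following forms (examples are hypothetical):
#   process_img("https://example.com/cat.jpg")            # remote URL
#   process_img("data:image/png;base64,iVBORw0KGgo...")   # base64 data URI
#   process_img("./images/cat.jpg")                       # local path
#   process_img(Image.open("./images/cat.jpg"))           # PIL.Image instance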


# Conversation history handling
def process_history_and_images(messages: List[ChatMessageInput]) -> Tuple[
    Optional[str], Optional[List[Tuple[str, str]]], Optional[List[Image.Image]]]:
    formatted_history = []
    image_list = []
    last_user_query = ''

    for i, message in enumerate(messages):
        role = message.role
        content = message.content

        # collect the text parts
        if isinstance(content, list):
            text_content = ' '.join(item.text for item in content if isinstance(item, TextContent))
        else:
            text_content = content

        # collect the image parts
        if isinstance(content, list):
            for item in content:
                if isinstance(item, ImageUrlContent):
                    image_url = item.image_url.url
                    image = process_img(image_url)
                    image_list.append(image)

        if role == 'user':
            if i == len(messages) - 1:  # the last message is the current query
                last_user_query = text_content
            else:
                formatted_history.append((text_content, ''))
        elif role == 'assistant':
            if formatted_history:
                if formatted_history[-1][1] != '':
                    assert False, f"the last user query already has an answer: " \
                                  f"{formatted_history[-1][0]}, {formatted_history[-1][1]}, {text_content}"
                formatted_history[-1] = (formatted_history[-1][0], text_content)
            else:
                assert False, "assistant reply before any user message"
        else:
            assert False, f"unrecognized role: {role}"

    return last_user_query, formatted_history, image_list
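
# Example of the transformation (values are illustrative):
#   messages = [user("Hi" + image), assistant("Hello!"), user("Describe the image")]
# becomes
#   last_user_query   = "Describe the image"
#   formatted_history = [("Hi", "Hello!")]
#   image_list        = [<PIL.Image>]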


# Moondream inference
def generate_stream_moondream(params: dict):
    global model, tokenizer

    # Flatten the chat history into Moondream's "Question/Answer" prompt format
    def chat_history_to_prompt(history):
        prompt = ""
        for i, (old_query, response) in enumerate(history):
            prompt += f"Question: {old_query}\n\nAnswer: {response}\n\n"
        return prompt

    messages = params["messages"]
    prompt, formatted_history, image_list = process_history_and_images(messages)
    history = chat_history_to_prompt(formatted_history)
    # only the last image is used
    img = image_list[-1]

    # Build the inputs; for reference, the signature of Moondream.answer_question:
    '''
    answer_question(
        self,
        image_embeds,
        question,
        tokenizer,
        chat_history="",
        result_queue=None,
        **kwargs,
    )
    '''
    image_embeds = model.encode_image(img)
    streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
    gen_kwargs = {
        "image_embeds": image_embeds,
        "question": prompt,
        "tokenizer": tokenizer,
        "chat_history": history,
        "result_queue": None,
        "streamer": streamer,
    }
    thread = Thread(
        target=model.answer_question,
        kwargs=gen_kwargs,
    )
    # Token counts are not tracked for Moondream, so usage stays at zero
    input_echo_len = 0
    total_len = 0

    # Start generation in a background thread and stream from the TextIteratorStreamer
    thread.start()
    buffer = ""
    for new_text in streamer:
        # strip Moondream's end-of-answer marker ("<END") from the streamed chunks
        clean_text = re.sub("<$|END$", "", new_text)
        buffer += clean_text
        yield {
            "text": buffer.strip("<END"),
            "usage": {
                "prompt_tokens": input_echo_len,
                "completion_tokens": total_len - input_echo_len,
                "total_tokens": total_len,
            },
        }
    generated_ret = {
        "text": buffer.strip("<END"),
        "usage": {
            "prompt_tokens": input_echo_len,
            "completion_tokens": total_len - input_echo_len,
            "total_tokens": total_len,
        },
    }
    yield generated_ret


# Moondream single (non-streaming) response
def generate_moondream(params: dict):
    for response in generate_stream_moondream(params):
        pass
    return response


# CogVLM inference
def generate_stream_cogvlm(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, params: dict):
    """
    Generates a stream of responses using the CogVLM model in inference mode.
    It's optimized to handle continuous input-output interactions with the model in a streaming manner.
    """
    messages = params["messages"]
    temperature = float(params.get("temperature", 1.0))
    repetition_penalty = float(params.get("repetition_penalty", 1.0))
    top_p = float(params.get("top_p", 1.0))
    max_new_tokens = int(params.get("max_tokens", 256))

    query, history, image_list = process_history_and_images(messages)
    logger.debug(f"==== request ====\n{query}")

    # only the latest image can be handled
    input_by_model = model.build_conversation_input_ids(tokenizer, query=query, history=history,
                                                        images=[image_list[-1]])
    inputs = {
        'input_ids': input_by_model['input_ids'].unsqueeze(0).to(DEVICE),
        'token_type_ids': input_by_model['token_type_ids'].unsqueeze(0).to(DEVICE),
        'attention_mask': input_by_model['attention_mask'].unsqueeze(0).to(DEVICE),
        'images': [[input_by_model['images'][0].to(DEVICE).to(torch_type)]],
    }
    if 'cross_images' in input_by_model and input_by_model['cross_images']:
        inputs['cross_images'] = [[input_by_model['cross_images'][0].to(DEVICE).to(torch_type)]]

    input_echo_len = len(inputs["input_ids"][0])
    streamer = TextIteratorStreamer(tokenizer=tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs = {
        "repetition_penalty": repetition_penalty,
        "max_new_tokens": max_new_tokens,
        "do_sample": False,
        "top_p": top_p,
        'streamer': streamer,
    }
    if temperature > 1e-5:
        gen_kwargs["temperature"] = temperature

    total_len = input_echo_len
    generated_text = ""
    with torch.no_grad():
        model.generate(**inputs, **gen_kwargs)
        for next_text in streamer:
            generated_text += next_text
            # track the running token count so the usage numbers stay consistent
            total_len = input_echo_len + len(tokenizer.encode(generated_text, add_special_tokens=False))
            yield {
                "text": generated_text,
                "usage": {
                    "prompt_tokens": input_echo_len,
                    "completion_tokens": total_len - input_echo_len,
                    "total_tokens": total_len,
                },
            }
    ret = {
        "text": generated_text,
        "usage": {
            "prompt_tokens": input_echo_len,
            "completion_tokens": total_len - input_echo_len,
            "total_tokens": total_len,
        },
    }
    yield ret


# CogVLM single (non-streaming) response
def generate_cogvlm(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, params: dict):
    for response in generate_stream_cogvlm(model, tokenizer, params):
        pass
    return response


def generate_minicpm(model, params):
    messages = params["messages"]
    query, history, image_list = process_history_and_images(messages)

    # The new message is passed as a role/content dict, so convert the
    # (query, answer) history tuples into the same format before serializing
    msgs = []
    for old_query, old_answer in history:
        msgs.append({'role': 'user', 'content': old_query})
        msgs.append({'role': 'assistant', 'content': old_answer})
    msgs.append({'role': 'user', 'content': query})

    # only the last image is used; re-encode the PIL image as base64 JPEG
    image = image_list[-1]
    buffer = BytesIO()
    image.save(buffer, format="JPEG")  # adjust the format as needed
    buffer.seek(0)
    image_base64 = base64.b64encode(buffer.read())
    image_base64_str = image_base64.decode("utf-8")

    inputs = {'image': image_base64_str, 'question': json.dumps(msgs)}
    generation = model.chat(inputs)
    response = {"text": generation, "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}}
    logger.debug(response)
    return response


# Streaming response (stub; streaming is not implemented)
async def predict(model_id: str, params: dict):
    return "no stream"


torch.set_grad_enabled(False)


# Lifespan manager: free GPU memory on shutdown
@asynccontextmanager
async def lifespan(app: FastAPI):
    yield
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()


app = FastAPI(lifespan=lifespan)

# Allow cross-origin requests
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Chat completion route
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
    global model, tokenizer

    # validate the request
    if len(request.messages) < 1 or request.messages[-1].role == "assistant":
        raise HTTPException(status_code=400, detail="Invalid request")

    gen_params = dict(
        messages=request.messages,
        temperature=request.temperature,
        top_p=request.top_p,
        max_tokens=request.max_tokens or 1024,
        echo=False,
        stream=request.stream,
    )

    # streaming is not implemented; predict() above is only a stub
    if request.stream:
        generate = await predict(request.model, gen_params)
        raise HTTPException(status_code=400, detail=generate)

    # single (non-streaming) response
    if STATE_MOD == "cog":
        response = generate_cogvlm(model, tokenizer, gen_params)
    elif STATE_MOD == "moon":
        response = generate_moondream(gen_params)
    elif STATE_MOD == "mini":
        response = generate_minicpm(model, gen_params)

    usage = UsageInfo()
    message = ChatMessageResponse(
        role="assistant",
        content=response["text"],
    )
    logger.debug(f"==== message ====\n{message}")
    choice_data = ChatCompletionResponseChoice(
        index=0,
        message=message,
    )
    task_usage = UsageInfo.model_validate(response["usage"])
    for usage_key, usage_value in task_usage.model_dump().items():
        setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)
    return ChatCompletionResponse(model=request.model, choices=[choice_data], object="chat.completion", usage=usage)
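
# A minimal client sketch (assumes the server runs on localhost:8000 and the
# /v1/chat/completions path added above; not part of the original project):
#
#   import requests
#   payload = {
#       "model": "moondream",
#       "messages": [{"role": "user", "content": [
#           {"type": "text", "text": "Describe this image."},
#           {"type": "image_url", "image_url": {"url": "https://example.com/cat.jpg"}},
#       ]}],
#   }
#   r = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
#   print(r.json()["choices"][0]["message"]["content"])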


# Model-switching configuration
STATE_MOD = "moon"
MODEL_PATH = ""

# Model loading
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'


def load_mod(model_input, mod_type):
    global model, tokenizer, language_processor_version
    if mod_type == "cog":
        tokenizer_path = os.environ.get("TOKENIZER_PATH", 'lmsys/vicuna-7b-v1.5')
        tokenizer = LlamaTokenizer.from_pretrained(
            tokenizer_path,
            trust_remote_code=True,
            signal_type=language_processor_version
        )
        if 'cuda' in DEVICE:
            model = AutoModelForCausalLM.from_pretrained(
                model_input,
                trust_remote_code=True,
                load_in_4bit=True,
                torch_dtype=torch_type,
                low_cpu_mem_usage=True
            ).eval()
        else:
            model = AutoModelForCausalLM.from_pretrained(
                model_input,
                trust_remote_code=True
            ).float().to(DEVICE).eval()
    elif mod_type == "moon":
        device, dtype = detect_device()
        model = Moondream.from_pretrained(model_input).to(device=device, dtype=dtype).eval()
        tokenizer = Tokenizer.from_pretrained(model_input)
    elif mod_type == "mini":
        model, tokenizer = omnichat.OmniLMMChat(model_input), None


# Model-switching helpers
async def switch_vqa():
    global model, STATE_MOD, mod_vqa, language_processor_version
    STATE_MOD = "cog"
    del model
    model = None
    language_processor_version = "chat_old"
    load_mod(mod_vqa, STATE_MOD)


async def switch_chat():
    global model, STATE_MOD, mod_chat, language_processor_version
    STATE_MOD = "cog"
    del model
    model = None
    language_processor_version = "chat"
    load_mod(mod_chat, STATE_MOD)


async def switch_moon():
    global model, STATE_MOD, mod_moon
    STATE_MOD = "moon"
    del model
    model = None
    load_mod(mod_moon, STATE_MOD)


async def switch_mini():
    global model, STATE_MOD, mod_mini
    STATE_MOD = "mini"
    del model
    model = None
    load_mod(mod_mini, STATE_MOD)


# Shutdown: unload the current model
async def close():
    global model
    del model
    model = None
    gc.collect()


parser = argparse.ArgumentParser()
parser.add_argument("--mod", type=str, default="moondream")
args = parser.parse_args()
mod = args.mod

mod_vqa = './models/cogagent-vqa-hf'
mod_chat = './models/cogagent-chat-hf'
mod_moon = './models/moondream'
mod_mini = './models/MiniCPM-Llama3-V-2_5'

'''
mod_list = [
    "moondream",
    "Cog-vqa",
    "Cog-chat",
    "MiniCPM"
]
'''

if mod == "Cog-vqa":
    STATE_MOD = "cog"
    MODEL_PATH = mod_vqa
    language_processor_version = "chat_old"
elif mod == "Cog-chat":
    STATE_MOD = "cog"
    MODEL_PATH = mod_chat
    language_processor_version = "chat"
elif mod == "moondream":
    STATE_MOD = "moon"
    MODEL_PATH = mod_moon
elif mod == "MiniCPM":
    STATE_MOD = "mini"
    MODEL_PATH = mod_mini

if __name__ == "__main__":
    if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
        torch_type = torch.bfloat16
    else:
        torch_type = torch.float16
    print("======== Using torch type: {} with device: {} ========\n\n".format(torch_type, DEVICE))
    load_mod(MODEL_PATH, STATE_MOD)
    uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
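
# Example launch commands (illustrative; the script filename is a placeholder and
# the model paths configured above must exist locally):
#   python api_server.py --mod moondream
#   python api_server.py --mod Cog-vqa
#   python api_server.py --mod Cog-chat
#   python api_server.py --mod MiniCPM
# The server then listens on http://0.0.0.0:8000 and serves the chat route above.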