"""
A model worker that executes the model.
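
Example launch (the model path and controller address below are deployment-specific
placeholders, not values from this repository):

    python3 -m fastchat.serve.model_worker \
        --model-path <path-or-hf-repo> \
        --controller-address http://<controller-host>:21001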
"""
import argparse
import base64
import gc
import json
import os
import uuid
from typing import List, Optional

# Point the Hugging Face caches at a shared checkpoint directory. These must be
# set before transformers (imported below and by fastchat) is loaded, otherwise
# the default cache locations are used.
os.environ["TRANSFORMERS_CACHE"] = "/checkpoint/tianleli/cache"
os.environ["HF_HOME"] = "/checkpoint/tianleli/cache"
os.environ["HF_DATASETS_CACHE"] = "/checkpoint/tianleli/cache"

import numpy as np
import requests
import torch
import torch.nn.functional as F
import uvicorn
from transformers import set_seed

from fastchat.constants import ErrorCode, SERVER_ERROR_MSG
from fastchat.model.model_adapter import (
    load_model,
    add_model_args,
    get_generate_stream_function,
)
from fastchat.modules.awq import AWQConfig
from fastchat.modules.exllama import ExllamaConfig
from fastchat.modules.gptq import GptqConfig
from fastchat.modules.xfastertransformer import XftConfig
from fastchat.serve.base_model_worker import BaseModelWorker, app
from fastchat.utils import (
    build_logger,
    get_context_length,
    str_to_torch_dtype,
)

worker_id = str(uuid.uuid4())[:8]
logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
class ModelWorker(BaseModelWorker):
def __init__(
self,
controller_addr: str,
worker_addr: str,
worker_id: str,
model_path: str,
model_names: List[str],
limit_worker_concurrency: int,
no_register: bool,
device: str,
num_gpus: int,
max_gpu_memory: str,
dtype: Optional[torch.dtype] = None,
load_8bit: bool = False,
cpu_offloading: bool = False,
gptq_config: Optional[GptqConfig] = None,
awq_config: Optional[AWQConfig] = None,
exllama_config: Optional[ExllamaConfig] = None,
xft_config: Optional[XftConfig] = None,
stream_interval: int = 2,
conv_template: Optional[str] = None,
embed_in_truncate: bool = False,
seed: Optional[int] = None,
debug: bool = False,
**kwargs,
):
super().__init__(
controller_addr,
worker_addr,
worker_id,
model_path,
model_names,
limit_worker_concurrency,
conv_template=conv_template,
)
logger.info(f"Loading the model {self.model_names} on worker {worker_id} ...")
self.model, self.tokenizer = load_model(
model_path,
device=device,
num_gpus=num_gpus,
max_gpu_memory=max_gpu_memory,
dtype=dtype,
load_8bit=load_8bit,
cpu_offloading=cpu_offloading,
gptq_config=gptq_config,
awq_config=awq_config,
exllama_config=exllama_config,
xft_config=xft_config,
debug=debug,
)
self.device = device
        if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
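        # ImagenHub models wrap their generation pipeline in `self.model.pipe`,
        # so the context length is read from the pipeline config instead.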
if model_path.startswith("imagenhub"):
self.context_len = get_context_length(self.model.pipe.config)
else:
self.context_len = get_context_length(self.model.config)
logger.info(f"model type: {str(type(self.model)).lower()}")
self.generate_stream_func = get_generate_stream_function(self.model, model_path)
self.stream_interval = stream_interval
self.embed_in_truncate = embed_in_truncate
self.seed = seed
if not no_register:
self.init_heart_beat()
    def generate_stream_gate(self, params):
        self.call_ct += 1
        try:
            if self.seed is not None:
                set_seed(self.seed)
            for output in self.generate_stream_func(
                self.model,
                self.tokenizer,
                params,
                self.device,
                self.context_len,
                self.stream_interval,
            ):
                logger.info(f"output.shape: {output['text'].size}")
                # The stream function returns image data in the "text" field;
                # convert it to a plain nested list so it can be JSON-serialized.
                # (Alternative: base64-encode the raw bytes, e.g.
                # base64.b64encode(np.array(output["text"]).tobytes()).decode("utf-8"))
                image = np.array(output["text"]).tolist()
                logger.info(f"image.shape: {len(image)}")
                ret = {
                    "text": image,
                    "error_code": 0,
                }
                if "usage" in output:
                    ret["usage"] = output["usage"]
                if "finish_reason" in output:
                    ret["finish_reason"] = output["finish_reason"]
                if "logprobs" in output:
                    ret["logprobs"] = output["logprobs"]
                yield json.dumps(ret).encode() + b"\0"
        except torch.cuda.OutOfMemoryError as e:
            ret = {
                "text": f"{SERVER_ERROR_MSG}\n\n({e})",
                "error_code": ErrorCode.CUDA_OUT_OF_MEMORY,
            }
            yield json.dumps(ret).encode() + b"\0"
        except (ValueError, RuntimeError) as e:
            ret = {
                "text": f"{SERVER_ERROR_MSG}\n\n({e})",
                "error_code": ErrorCode.INTERNAL_ERROR,
            }
            yield json.dumps(ret).encode() + b"\0"
    def generate_gate(self, params):
        # Drain the stream and return only the final, complete output.
        for x in self.generate_stream_gate(params):
            pass
        return json.loads(x[:-1].decode())
def __process_embed_chunk(self, input_ids, attention_mask, **model_type_dict):
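        # One forward pass over a single chunk: returns the attention-masked sum
        # of the last hidden states and the number of non-padding tokens, which
        # the caller uses for mean pooling.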
if model_type_dict.get("is_bert"):
model_output = self.model(input_ids)
if model_type_dict.get("is_robert"):
data = model_output.last_hidden_state
else:
data = model_output[0]
elif model_type_dict.get("is_t5"):
model_output = self.model(input_ids, decoder_input_ids=input_ids)
data = model_output.encoder_last_hidden_state
else:
model_output = self.model(input_ids, output_hidden_states=True)
if model_type_dict.get("is_chatglm"):
data = model_output.hidden_states[-1].transpose(0, 1)
else:
data = model_output.hidden_states[-1]
mask = attention_mask.unsqueeze(-1).expand(data.size()).float()
masked_embeddings = data * mask
sum_embeddings = torch.sum(masked_embeddings, dim=1)
token_num = torch.sum(attention_mask).item()
return sum_embeddings, token_num
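    # Encode each embedding row as base64 so it can be returned compactly in
    # JSON when the client requests encoding_format="base64".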
def __encode_base64(self, embeddings: torch.Tensor) -> List[str]:
embeddings = embeddings.cpu()
return [
base64.b64encode(e.numpy().tobytes()).decode("utf-8") for e in embeddings
]
@torch.inference_mode()
def get_embeddings(self, params):
self.call_ct += 1
try:
tokenizer = self.tokenizer
ret = {"embedding": [], "token_num": 0}
model_type_dict = {
"is_llama": "llama" in str(type(self.model)),
"is_t5": "t5" in str(type(self.model)),
"is_chatglm": "chatglm" in str(type(self.model)),
"is_bert": "bert" in str(type(self.model)),
"is_robert": "robert" in str(type(self.model)),
}
if self.embed_in_truncate:
encoding = tokenizer.batch_encode_plus(
params["input"],
padding=True,
truncation="longest_first",
return_tensors="pt",
max_length=self.context_len,
)
else:
encoding = tokenizer.batch_encode_plus(
params["input"], padding=True, return_tensors="pt"
)
input_ids = encoding["input_ids"].to(self.device)
attention_mask = input_ids != tokenizer.pad_token_id
base64_encode = params.get("encoding_format", None)
if self.embed_in_truncate:
chunk_embeddings, token_num = self.__process_embed_chunk(
input_ids, attention_mask, **model_type_dict
)
embedding = chunk_embeddings / token_num
normalized_embeddings = F.normalize(embedding, p=2, dim=1)
ret["token_num"] = token_num
else:
all_embeddings = []
all_token_num = 0
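                # Without truncation, split the input into context-length chunks,
                # embed each chunk, then mean-pool over all real tokens.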
for i in range(0, input_ids.size(1), self.context_len):
chunk_input_ids = input_ids[:, i : i + self.context_len]
chunk_attention_mask = attention_mask[:, i : i + self.context_len]
chunk_embeddings, token_num = self.__process_embed_chunk(
chunk_input_ids, chunk_attention_mask, **model_type_dict
)
all_embeddings.append(chunk_embeddings)
all_token_num += token_num
all_embeddings_tensor = torch.stack(all_embeddings)
embedding = torch.sum(all_embeddings_tensor, dim=0) / all_token_num
normalized_embeddings = F.normalize(embedding, p=2, dim=1)
ret["token_num"] = all_token_num
if base64_encode == "base64":
out_embeddings = self.__encode_base64(normalized_embeddings)
else:
out_embeddings = normalized_embeddings.tolist()
ret["embedding"] = out_embeddings
gc.collect()
torch.cuda.empty_cache()
if self.device == "xpu":
torch.xpu.empty_cache()
if self.device == "npu":
torch.npu.empty_cache()
except torch.cuda.OutOfMemoryError as e:
ret = {
"text": f"{SERVER_ERROR_MSG}\n\n({e})",
"error_code": ErrorCode.CUDA_OUT_OF_MEMORY,
}
except (ValueError, RuntimeError) as e:
ret = {
"text": f"{SERVER_ERROR_MSG}\n\n({e})",
"error_code": ErrorCode.INTERNAL_ERROR,
}
return ret
def create_model_worker():
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=21002)
parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
parser.add_argument(
"--controller-address", type=str, default="http://172.17.15.237:21001"
)
add_model_args(parser)
parser.add_argument(
"--model-names",
type=lambda s: s.split(","),
help="Optional display comma separated names",
)
parser.add_argument(
"--conv-template", type=str, default=None, help="Conversation prompt template."
)
parser.add_argument("--embed-in-truncate", action="store_true")
parser.add_argument(
"--limit-worker-concurrency",
type=int,
default=5,
help="Limit the model concurrency to prevent OOM.",
)
parser.add_argument("--stream-interval", type=int, default=2)
parser.add_argument("--no-register", action="store_true")
parser.add_argument(
"--seed",
type=int,
default=None,
help="Overwrite the random seed for each generation.",
)
parser.add_argument(
"--debug", type=bool, default=False, help="Print debugging messages"
)
parser.add_argument(
"--ssl",
action="store_true",
required=False,
default=False,
help="Enable SSL. Requires OS Environment variables 'SSL_KEYFILE' and 'SSL_CERTFILE'.",
)
args = parser.parse_args()
logger.info(f"args: {args}")
if args.gpus:
if len(args.gpus.split(",")) < args.num_gpus:
raise ValueError(
f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
)
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
gptq_config = GptqConfig(
ckpt=args.gptq_ckpt or args.model_path,
wbits=args.gptq_wbits,
groupsize=args.gptq_groupsize,
act_order=args.gptq_act_order,
)
awq_config = AWQConfig(
ckpt=args.awq_ckpt or args.model_path,
wbits=args.awq_wbits,
groupsize=args.awq_groupsize,
)
if args.enable_exllama:
exllama_config = ExllamaConfig(
max_seq_len=args.exllama_max_seq_len,
gpu_split=args.exllama_gpu_split,
cache_8bit=args.exllama_cache_8bit,
)
else:
exllama_config = None
if args.enable_xft:
xft_config = XftConfig(
max_seq_len=args.xft_max_seq_len,
data_type=args.xft_dtype,
)
if args.device != "cpu":
print("xFasterTransformer now is only support CPUs. Reset device to CPU")
args.device = "cpu"
else:
xft_config = None
worker = ModelWorker(
args.controller_address,
args.worker_address,
worker_id,
args.model_path,
args.model_names,
args.limit_worker_concurrency,
no_register=args.no_register,
device=args.device,
num_gpus=args.num_gpus,
max_gpu_memory=args.max_gpu_memory,
dtype=str_to_torch_dtype(args.dtype),
load_8bit=args.load_8bit,
cpu_offloading=args.cpu_offloading,
gptq_config=gptq_config,
awq_config=awq_config,
exllama_config=exllama_config,
xft_config=xft_config,
stream_interval=args.stream_interval,
conv_template=args.conv_template,
embed_in_truncate=args.embed_in_truncate,
seed=args.seed,
debug=args.debug,
)
return args, worker
if __name__ == "__main__":
args, worker = create_model_worker()
if args.ssl:
uvicorn.run(
app,
host=args.host,
port=args.port,
log_level="info",
ssl_keyfile=os.environ["SSL_KEYFILE"],
ssl_certfile=os.environ["SSL_CERTFILE"],
)
else:
uvicorn.run(app, host=args.host, port=args.port, log_level="info")