""" A model worker that executes the model. """ import argparse import base64 import gc import json import os from typing import List, Optional import uuid import base64 import numpy as np import torch import torch.nn.functional as F from transformers import set_seed import uvicorn from fastchat.constants import ErrorCode, SERVER_ERROR_MSG from fastchat.model.model_adapter import ( load_model, add_model_args, get_generate_stream_function, ) import requests from fastchat.modules.awq import AWQConfig from fastchat.modules.exllama import ExllamaConfig from fastchat.modules.xfastertransformer import XftConfig from fastchat.modules.gptq import GptqConfig from fastchat.serve.base_model_worker import BaseModelWorker, app from fastchat.utils import ( build_logger, get_context_length, str_to_torch_dtype, ) import os os.environ['TRANSFORMERS_CACHE'] = "/checkpoint/tianleli/cache" os.environ['HF_HOME'] = "/checkpoint/tianleli/cache" os.environ['HF_DATASETS_CACHE'] = "/checkpoint/tianleli/cache" worker_id = str(uuid.uuid4())[:8] logger = build_logger("model_worker", f"model_worker_{worker_id}.log") class ModelWorker(BaseModelWorker): def __init__( self, controller_addr: str, worker_addr: str, worker_id: str, model_path: str, model_names: List[str], limit_worker_concurrency: int, no_register: bool, device: str, num_gpus: int, max_gpu_memory: str, dtype: Optional[torch.dtype] = None, load_8bit: bool = False, cpu_offloading: bool = False, gptq_config: Optional[GptqConfig] = None, awq_config: Optional[AWQConfig] = None, exllama_config: Optional[ExllamaConfig] = None, xft_config: Optional[XftConfig] = None, stream_interval: int = 2, conv_template: Optional[str] = None, embed_in_truncate: bool = False, seed: Optional[int] = None, debug: bool = False, **kwargs, ): super().__init__( controller_addr, worker_addr, worker_id, model_path, model_names, limit_worker_concurrency, conv_template=conv_template, ) logger.info(f"Loading the model {self.model_names} on worker {worker_id} ...") self.model, self.tokenizer = load_model( model_path, device=device, num_gpus=num_gpus, max_gpu_memory=max_gpu_memory, dtype=dtype, load_8bit=load_8bit, cpu_offloading=cpu_offloading, gptq_config=gptq_config, awq_config=awq_config, exllama_config=exllama_config, xft_config=xft_config, debug=debug, ) self.device = device if self.tokenizer.pad_token == None: self.tokenizer.pad_token = self.tokenizer.eos_token if model_path.startswith("imagenhub"): self.context_len = get_context_length(self.model.pipe.config) else: self.context_len = get_context_length(self.model.config) logger.info(f"model type: {str(type(self.model)).lower()}") self.generate_stream_func = get_generate_stream_function(self.model, model_path) self.stream_interval = stream_interval self.embed_in_truncate = embed_in_truncate self.seed = seed if not no_register: self.init_heart_beat() def generate_stream_gate(self, params): self.call_ct += 1 # try: if self.seed is not None: set_seed(self.seed) for output in self.generate_stream_func( self.model, self.tokenizer, params, self.device, self.context_len, self.stream_interval, ): logger.info(f"output.shape: {output['text'].size}") # image = base64.b64encode(np.array(output["text"])).decode("utf-8") # image = base64.b64encode(np.array(output["text"]).tobytes()).decode("utf-8") image = np.array(output["text"]).tolist() logger.info(f"image.shape: {len(image)}") ret = { "text": image, "error_code": 0, } # if "usage" in output: # ret["usage"] = output["usage"] # if "finish_reason" in output: # ret["finish_reason"] = 
output["finish_reason"] # if "logprobs" in output: # ret["logprobs"] = output["logprobs"] yield json.dumps(ret).encode() + b"\0" # yield ret # except torch.cuda.OutOfMemoryError as e: # ret = { # "text": f"{SERVER_ERROR_MSG}\n\n({e})", # "error_code": ErrorCode.CUDA_OUT_OF_MEMORY, # } # yield json.dumps(ret).encode() + b"\0" # except (ValueError, RuntimeError) as e: # ret = { # "text": f"{SERVER_ERROR_MSG}\n\n({e})", # "error_code": ErrorCode.INTERNAL_ERROR, # } # yield json.dumps(ret).encode() + b"\0" def generate_gate(self, params): for x in self.generate_stream_gate(params): # return x pass return json.loads(x[:-1].decode()) # return x def __process_embed_chunk(self, input_ids, attention_mask, **model_type_dict): if model_type_dict.get("is_bert"): model_output = self.model(input_ids) if model_type_dict.get("is_robert"): data = model_output.last_hidden_state else: data = model_output[0] elif model_type_dict.get("is_t5"): model_output = self.model(input_ids, decoder_input_ids=input_ids) data = model_output.encoder_last_hidden_state else: model_output = self.model(input_ids, output_hidden_states=True) if model_type_dict.get("is_chatglm"): data = model_output.hidden_states[-1].transpose(0, 1) else: data = model_output.hidden_states[-1] mask = attention_mask.unsqueeze(-1).expand(data.size()).float() masked_embeddings = data * mask sum_embeddings = torch.sum(masked_embeddings, dim=1) token_num = torch.sum(attention_mask).item() return sum_embeddings, token_num def __encode_base64(self, embeddings: torch.Tensor) -> List[str]: embeddings = embeddings.cpu() return [ base64.b64encode(e.numpy().tobytes()).decode("utf-8") for e in embeddings ] @torch.inference_mode() def get_embeddings(self, params): self.call_ct += 1 try: tokenizer = self.tokenizer ret = {"embedding": [], "token_num": 0} model_type_dict = { "is_llama": "llama" in str(type(self.model)), "is_t5": "t5" in str(type(self.model)), "is_chatglm": "chatglm" in str(type(self.model)), "is_bert": "bert" in str(type(self.model)), "is_robert": "robert" in str(type(self.model)), } if self.embed_in_truncate: encoding = tokenizer.batch_encode_plus( params["input"], padding=True, truncation="longest_first", return_tensors="pt", max_length=self.context_len, ) else: encoding = tokenizer.batch_encode_plus( params["input"], padding=True, return_tensors="pt" ) input_ids = encoding["input_ids"].to(self.device) attention_mask = input_ids != tokenizer.pad_token_id base64_encode = params.get("encoding_format", None) if self.embed_in_truncate: chunk_embeddings, token_num = self.__process_embed_chunk( input_ids, attention_mask, **model_type_dict ) embedding = chunk_embeddings / token_num normalized_embeddings = F.normalize(embedding, p=2, dim=1) ret["token_num"] = token_num else: all_embeddings = [] all_token_num = 0 for i in range(0, input_ids.size(1), self.context_len): chunk_input_ids = input_ids[:, i : i + self.context_len] chunk_attention_mask = attention_mask[:, i : i + self.context_len] chunk_embeddings, token_num = self.__process_embed_chunk( chunk_input_ids, chunk_attention_mask, **model_type_dict ) all_embeddings.append(chunk_embeddings) all_token_num += token_num all_embeddings_tensor = torch.stack(all_embeddings) embedding = torch.sum(all_embeddings_tensor, dim=0) / all_token_num normalized_embeddings = F.normalize(embedding, p=2, dim=1) ret["token_num"] = all_token_num if base64_encode == "base64": out_embeddings = self.__encode_base64(normalized_embeddings) else: out_embeddings = normalized_embeddings.tolist() ret["embedding"] = out_embeddings 


def create_model_worker():
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=21002)
    parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
    parser.add_argument(
        "--controller-address", type=str, default="http://172.17.15.237:21001"
    )
    add_model_args(parser)
    parser.add_argument(
        "--model-names",
        type=lambda s: s.split(","),
        help="Optional display names, comma separated.",
    )
    parser.add_argument(
        "--conv-template", type=str, default=None, help="Conversation prompt template."
    )
    parser.add_argument("--embed-in-truncate", action="store_true")
    parser.add_argument(
        "--limit-worker-concurrency",
        type=int,
        default=5,
        help="Limit the model concurrency to prevent OOM.",
    )
    parser.add_argument("--stream-interval", type=int, default=2)
    parser.add_argument("--no-register", action="store_true")
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Overwrite the random seed for each generation.",
    )
    parser.add_argument(
        "--debug", action="store_true", help="Print debugging messages."
    )
    parser.add_argument(
        "--ssl",
        action="store_true",
        required=False,
        default=False,
        help="Enable SSL. Requires OS environment variables 'SSL_KEYFILE' and 'SSL_CERTFILE'.",
    )
    args = parser.parse_args()
    logger.info(f"args: {args}")

    if args.gpus:
        if len(args.gpus.split(",")) < args.num_gpus:
            raise ValueError(
                f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
            )
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
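        # e.g. `--gpus 0,1 --num-gpus 2` restricts the worker to the first
        # two CUDA devices via CUDA_VISIBLE_DEVICES.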

    gptq_config = GptqConfig(
        ckpt=args.gptq_ckpt or args.model_path,
        wbits=args.gptq_wbits,
        groupsize=args.gptq_groupsize,
        act_order=args.gptq_act_order,
    )
    awq_config = AWQConfig(
        ckpt=args.awq_ckpt or args.model_path,
        wbits=args.awq_wbits,
        groupsize=args.awq_groupsize,
    )
    if args.enable_exllama:
        exllama_config = ExllamaConfig(
            max_seq_len=args.exllama_max_seq_len,
            gpu_split=args.exllama_gpu_split,
            cache_8bit=args.exllama_cache_8bit,
        )
    else:
        exllama_config = None
    if args.enable_xft:
        xft_config = XftConfig(
            max_seq_len=args.xft_max_seq_len,
            data_type=args.xft_dtype,
        )
        if args.device != "cpu":
            print(
                "xFasterTransformer currently supports CPU only. "
                "Resetting device to CPU."
            )
            args.device = "cpu"
    else:
        xft_config = None

    worker = ModelWorker(
        args.controller_address,
        args.worker_address,
        worker_id,
        args.model_path,
        args.model_names,
        args.limit_worker_concurrency,
        no_register=args.no_register,
        device=args.device,
        num_gpus=args.num_gpus,
        max_gpu_memory=args.max_gpu_memory,
        dtype=str_to_torch_dtype(args.dtype),
        load_8bit=args.load_8bit,
        cpu_offloading=args.cpu_offloading,
        gptq_config=gptq_config,
        awq_config=awq_config,
        exllama_config=exllama_config,
        xft_config=xft_config,
        stream_interval=args.stream_interval,
        conv_template=args.conv_template,
        embed_in_truncate=args.embed_in_truncate,
        seed=args.seed,
        debug=args.debug,
    )
    return args, worker


if __name__ == "__main__":
    args, worker = create_model_worker()
    if args.ssl:
        uvicorn.run(
            app,
            host=args.host,
            port=args.port,
            log_level="info",
            ssl_keyfile=os.environ["SSL_KEYFILE"],
            ssl_certfile=os.environ["SSL_CERTFILE"],
        )
    else:
        uvicorn.run(app, host=args.host, port=args.port, log_level="info")
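

# Example launch (hypothetical model path and addresses, for illustration):
#
#   python3 model_worker.py --model-path lmsys/vicuna-7b-v1.5 \
#       --host 0.0.0.0 --port 21002 \
#       --worker-address http://localhost:21002 \
#       --controller-address http://localhost:21001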