Spaces:
No application file
No application file
""" | |
A model worker that calls Hugging Face inference endpoints.

Register models in a JSON file with the following format:
{
    "falcon-180b-chat": {
        "model_names": ["falcon-180B-chat"],
        "api_base": "https://api-inference.huggingface.co/models",
        "model_path": "tiiuae/falcon-180B-chat",
        "token": "hf_XXX",
        "context_length": 2048
    },
    "zephyr-7b-beta": {
        "model_names": ["zephyr-7b-beta"],
        "model_path": "",
        "api_base": "xxx",
        "token": "hf_XXX",
        "context_length": 4096
    }
}

"model_path", "api_base", "token", and "context_length" are required; all other
keys are optional. "model_names" may be a string or a list of strings and
defaults to the last "/"-separated component of the entry's key; "conv_template"
may also be given. When "model_path" is empty, "api_base" is used as the complete
endpoint URL; otherwise requests are sent to "<api_base>/<model_path>".
""" | |
import argparse | |
import asyncio | |
import json | |
import uuid | |
import os | |
from typing import List, Optional | |
import requests | |
import uvicorn | |
from fastapi import BackgroundTasks, FastAPI, Request | |
from fastapi.responses import JSONResponse, StreamingResponse | |
from huggingface_hub import InferenceClient | |
from fastchat.constants import SERVER_ERROR_MSG, ErrorCode | |
from fastchat.serve.base_model_worker import BaseModelWorker | |
from fastchat.utils import build_logger | |
# Short random identifier for this worker process; it also names the log file.
worker_id = uuid.uuid4().hex[:8]
logger = build_logger("model_worker", f"model_worker_{worker_id}.log")

# Every HuggingfaceApiWorker created at startup, in model-info-file order.
workers = []
# Model name -> worker serving that name; the HTTP handlers route requests by it.
worker_map = {}

app = FastAPI()
# reference to | |
# https://github.com/philschmid/easyllm/blob/cbd908b3b3f44a97a22cb0fc2c93df3660bacdad/easyllm/clients/huggingface.py#L374-L392 | |
def get_gen_kwargs( | |
params, | |
seed: Optional[int] = None, | |
): | |
stop = params.get("stop", None) | |
if isinstance(stop, list): | |
stop_sequences = stop | |
elif isinstance(stop, str): | |
stop_sequences = [stop] | |
else: | |
stop_sequences = [] | |
gen_kwargs = { | |
"do_sample": True, | |
"return_full_text": bool(params.get("echo", False)), | |
"max_new_tokens": int(params.get("max_new_tokens", 256)), | |
"top_p": float(params.get("top_p", 1.0)), | |
"temperature": float(params.get("temperature", 1.0)), | |
"stop_sequences": stop_sequences, | |
"repetition_penalty": float(params.get("repetition_penalty", 1.0)), | |
"top_k": params.get("top_k", None), | |
"seed": seed, | |
} | |
if gen_kwargs["top_p"] == 1: | |
gen_kwargs["top_p"] = 0.9999999 | |
if gen_kwargs["top_p"] == 0: | |
gen_kwargs.pop("top_p") | |
if gen_kwargs["temperature"] == 0: | |
gen_kwargs.pop("temperature") | |
gen_kwargs["do_sample"] = False | |
return gen_kwargs | |
def could_be_stop(text, stop):
    """Return True if ``text`` ends with a non-empty prefix of any stop string.

    Used while streaming to hold back output that might turn out to be the
    beginning of a stop sequence once more tokens arrive.
    """
    return any(
        text.endswith(candidate[:length])
        for candidate in stop
        for length in range(1, len(candidate) + 1)
    )
class HuggingfaceApiWorker(BaseModelWorker):
    """Model worker that forwards generation to a Hugging Face inference endpoint.

    No model is loaded locally: each request is sent through
    ``huggingface_hub.InferenceClient``. Token counting returns 0 and
    embeddings are not supported.
    """

    def __init__(
        self,
        controller_addr: str,
        worker_addr: str,
        worker_id: str,
        model_path: str,
        api_base: str,
        token: str,
        context_length: int,
        model_names: List[str],
        limit_worker_concurrency: int,
        no_register: bool,
        conv_template: Optional[str] = None,
        seed: Optional[int] = None,
        **kwargs,
    ):
        super().__init__(
            controller_addr,
            worker_addr,
            worker_id,
            model_path,
            model_names,
            limit_worker_concurrency,
            conv_template=conv_template,
        )

        self.model_path = model_path
        self.api_base = api_base
        self.token = token
        self.context_len = context_length
        self.seed = seed

        logger.info(
            f"Connecting with huggingface api {self.model_path} as {self.model_names} on worker {worker_id} ..."
        )

        if not no_register:
            self.init_heart_beat()

    def count_token(self, params):
        """Return a zero token count: there is no tokenizer for a remote model."""
        ret = {
            "count": 0,
            "error_code": 0,
        }
        return ret

    @staticmethod
    def _encode_chunk(text, reason):
        """Serialize one streaming payload in the NUL-terminated JSON format."""
        ret = {
            "text": text,
            "error_code": 0,
            "finish_reason": reason,
        }
        return json.dumps(ret).encode() + b"\0"

    def generate_stream_gate(self, params):
        """Stream generated text as NUL-terminated JSON payloads.

        Each payload holds the full text generated so far (``text``),
        ``error_code`` 0 and ``finish_reason`` ("stop", "length" or None).
        The stream always ends with a payload carrying the final text,
        including when generation stopped on a stop sequence (which is
        stripped from the text). If anything raises, a single payload with
        ``ErrorCode.INTERNAL_ERROR`` is yielded instead.
        """
        self.call_ct += 1

        prompt = params["prompt"]
        gen_kwargs = get_gen_kwargs(params, seed=self.seed)
        # Copy so that extending the list never mutates the caller's params["stop"].
        stop = list(gen_kwargs["stop_sequences"])
        if "falcon" in self.model_path and "chat" in self.model_path:
            stop.extend(["\nUser:", "<|endoftext|>", " User:", "###"])
        # Deduplicate and drop empty strings (an empty stop would match at every
        # position and text[:-0] would erase the whole output). Longest first,
        # so when one stop is a suffix of another the longer one is stripped.
        stop = sorted(
            {s for s in stop if isinstance(s, str) and s},
            key=lambda s: (-len(s), s),
        )
        gen_kwargs["stop_sequences"] = stop

        logger.info(f"prompt: {prompt}")
        logger.info(f"gen_kwargs: {gen_kwargs}")

        try:
            if self.model_path == "":
                url = f"{self.api_base}"
            else:
                url = f"{self.api_base}/{self.model_path}"
            client = InferenceClient(url, token=self.token)
            res = client.text_generation(
                prompt, stream=True, details=True, **gen_kwargs
            )

            reason = None
            text = ""
            sent_any = False
            # True while the latest (text, reason) has not been yielded yet.
            pending = False
            for chunk in res:
                if chunk.token.special:
                    continue
                text += chunk.token.text

                matched = next((x for x in stop if text.endswith(x)), None)
                if matched is not None:
                    text = text[: -len(matched)]
                    reason = "stop"
                    pending = True
                    break

                # Record the finish reason before a possible hold-back below,
                # so it is not lost if this turns out to be the last chunk.
                if (
                    chunk.details is not None
                    and chunk.details.finish_reason is not None
                ):
                    reason = chunk.details.finish_reason
                # Finish reasons other than "stop"/"length" are reported as None.
                if reason not in ["stop", "length"]:
                    reason = None

                if could_be_stop(text, stop):
                    # Hold back text that might be the start of a stop sequence.
                    pending = True
                    continue

                yield self._encode_chunk(text, reason)
                sent_any = True
                pending = False

            # Flush the final state: text stripped of a matched stop sequence,
            # text held back at end of stream, or an empty generation.
            if pending or not sent_any:
                yield self._encode_chunk(text, reason)
        except Exception as e:
            ret = {
                "text": f"{SERVER_ERROR_MSG}\n\n({e})",
                "error_code": ErrorCode.INTERNAL_ERROR,
            }
            yield json.dumps(ret).encode() + b"\0"

    def generate_gate(self, params):
        """Run a full generation and return the last streamed payload as a dict."""
        last = None
        for last in self.generate_stream_gate(params):
            pass
        if last is None:
            # Defensive: generate_stream_gate always yields at least once.
            return {"text": "", "error_code": 0, "finish_reason": None}
        return json.loads(last[:-1].decode())

    def get_embeddings(self, params):
        """Embeddings are not available through this worker."""
        raise NotImplementedError()
def release_worker_semaphore(worker):
    """Give one concurrency permit back to ``worker``'s semaphore."""
    semaphore = worker.semaphore
    semaphore.release()
def acquire_worker_semaphore(worker):
    """Return an awaitable that takes one of ``worker``'s concurrency permits.

    The semaphore is created lazily on first use, sized by
    ``worker.limit_worker_concurrency`` (presumably so it is constructed
    inside the server's running event loop).
    """
    semaphore = worker.semaphore
    if semaphore is None:
        semaphore = asyncio.Semaphore(worker.limit_worker_concurrency)
        worker.semaphore = semaphore
    return semaphore.acquire()
def create_background_tasks(worker):
    """Build background tasks that release ``worker``'s permit after a response.

    The task is a coroutine function, so Starlette awaits it on the event loop.
    A plain sync callable would be run in a threadpool, and asyncio
    semaphores are not thread-safe to release from another thread.
    """

    async def _release():
        release_worker_semaphore(worker)

    background_tasks = BackgroundTasks()
    background_tasks.add_task(_release)
    return background_tasks
@app.post("/worker_generate_stream")
async def api_generate_stream(request: Request):
    """Stream a generation for ``params["model"]`` as NUL-terminated JSON chunks.

    The concurrency permit acquired here is released by the background tasks
    attached to the streaming response.
    """
    params = await request.json()
    worker = worker_map[params["model"]]
    await acquire_worker_semaphore(worker)
    generator = worker.generate_stream_gate(params)
    background_tasks = create_background_tasks(worker)
    return StreamingResponse(generator, background=background_tasks)
@app.post("/worker_generate")
async def api_generate(request: Request):
    """Run a complete (non-streaming) generation and return the final payload."""
    params = await request.json()
    worker = worker_map[params["model"]]
    await acquire_worker_semaphore(worker)
    try:
        # generate_gate performs blocking network I/O; run it in the default
        # executor so the event loop keeps serving other requests meanwhile.
        loop = asyncio.get_running_loop()
        output = await loop.run_in_executor(None, worker.generate_gate, params)
    finally:
        # Always return the permit, even if generation raised.
        release_worker_semaphore(worker)
    return JSONResponse(output)
@app.post("/worker_get_embeddings")
async def api_get_embeddings(request: Request):
    """Return embeddings for ``params`` from the worker serving ``params["model"]``."""
    params = await request.json()
    worker = worker_map[params["model"]]
    await acquire_worker_semaphore(worker)
    try:
        embedding = worker.get_embeddings(params)
    finally:
        # HuggingfaceApiWorker.get_embeddings raises NotImplementedError;
        # without this, every such call would leak one concurrency permit.
        release_worker_semaphore(worker)
    return JSONResponse(content=embedding)
@app.post("/worker_get_status")
async def api_get_status(request: Request):
    """Report model names and total queue length across all hosted workers."""
    return {
        "model_names": [name for w in workers for name in w.model_names],
        "speed": 1,
        "queue_length": sum(w.get_queue_length() for w in workers),
    }
@app.post("/count_token")
async def api_count_token(request: Request):
    """Return the token count reported by the worker serving ``params["model"]``."""
    params = await request.json()
    worker = worker_map[params["model"]]
    return worker.count_token(params)
@app.post("/worker_get_conv_template")
async def api_get_conv(request: Request):
    """Return the conversation template of the worker serving ``params["model"]``."""
    params = await request.json()
    worker = worker_map[params["model"]]
    return worker.get_conv_template()
@app.post("/model_details")
async def api_model_details(request: Request):
    """Return the configured context length of the worker serving ``params["model"]``."""
    params = await request.json()
    worker = worker_map[params["model"]]
    return {"context_length": worker.context_len}
def _mask_token(token):
    """Return a log-safe rendering of an API token (at most a 3-char prefix)."""
    if not token:
        return token
    return f"{token[:3]}***" if len(token) > 6 else "***"


def create_huggingface_api_worker():
    """Parse CLI args, create one HuggingfaceApiWorker per model, and register them.

    Every model listed in ``--model-info-file`` becomes a worker. Each worker is
    appended to the module-level ``workers`` list and indexed in ``worker_map``
    under each of its model names. All models are then registered with the
    controller in a single ``/register_worker`` request.

    Returns:
        ``(args, workers)``: the parsed argument namespace and the list of
        created workers.

    Raises:
        KeyError: if a model entry lacks "model_path", "api_base", "token" or
            "context_length".
        ValueError: if the model info file defines no models.
        RuntimeError: if the controller does not answer the registration
            request with HTTP 200.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=21002)
    parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
    parser.add_argument(
        "--controller-address", type=str, default="http://localhost:21001"
    )
    # all model-related parameters are listed in --model-info-file
    parser.add_argument(
        "--model-info-file",
        type=str,
        required=True,
        help="Huggingface API model's info file path",
    )

    parser.add_argument(
        "--limit-worker-concurrency",
        type=int,
        default=5,
        help="Limit the model concurrency to prevent OOM.",
    )
    parser.add_argument("--no-register", action="store_true")
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Overwrite the random seed for each generation.",
    )
    parser.add_argument(
        "--ssl",
        action="store_true",
        required=False,
        default=False,
        help="Enable SSL. Requires OS Environment variables 'SSL_KEYFILE' and 'SSL_CERTFILE'.",
    )
    args = parser.parse_args()

    with open(args.model_info_file, "r", encoding="UTF-8") as f:
        model_info = json.load(f)

    logger.info(f"args: {args}")

    if not model_info:
        raise ValueError(f"No models defined in {args.model_info_file}")

    # Parse every entry before creating any worker, so a malformed entry fails
    # before heartbeats are started for the earlier ones.
    configs = []
    for key, info in model_info.items():
        model_names = info.get("model_names", [key.split("/")[-1]])
        if isinstance(model_names, str):
            model_names = [model_names]
        configs.append(
            {
                "model_path": info["model_path"],
                "api_base": info["api_base"],
                "token": info["token"],
                "context_length": info["context_length"],
                "model_names": model_names,
                "conv_template": info.get("conv_template", None),
            }
        )

    logger.info(f"Model paths: {[c['model_path'] for c in configs]}")
    logger.info(f"API bases: {[c['api_base'] for c in configs]}")
    # Never write raw API tokens to the log file.
    logger.info(f"Tokens: {[_mask_token(c['token']) for c in configs]}")
    logger.info(f"Context lengths: {[c['context_length'] for c in configs]}")
    logger.info(f"Model names: {[c['model_names'] for c in configs]}")
    logger.info(f"Conv templates: {[c['conv_template'] for c in configs]}")

    for cfg in configs:
        worker = HuggingfaceApiWorker(
            args.controller_address,
            args.worker_address,
            worker_id,
            cfg["model_path"],
            cfg["api_base"],
            cfg["token"],
            cfg["context_length"],
            cfg["model_names"],
            args.limit_worker_concurrency,
            no_register=args.no_register,
            conv_template=cfg["conv_template"],
            seed=args.seed,
        )
        workers.append(worker)
        for name in cfg["model_names"]:
            worker_map[name] = worker

    # register all the models
    url = args.controller_address + "/register_worker"
    data = {
        "worker_name": workers[0].worker_addr,
        "check_heart_beat": not args.no_register,
        "worker_status": {
            "model_names": [name for w in workers for name in w.model_names],
            "speed": 1,
            "queue_length": sum(w.get_queue_length() for w in workers),
        },
    }
    r = requests.post(url, json=data)
    # An explicit check instead of `assert`, which is stripped under `python -O`.
    if r.status_code != 200:
        raise RuntimeError(
            f"Failed to register workers with controller at {url}: "
            f"HTTP {r.status_code}"
        )

    return args, workers
if __name__ == "__main__":
    args, workers = create_huggingface_api_worker()

    # TLS is opt-in; the key and certificate paths come from the environment.
    tls_options = {}
    if args.ssl:
        tls_options["ssl_keyfile"] = os.environ["SSL_KEYFILE"]
        tls_options["ssl_certfile"] = os.environ["SSL_CERTFILE"]

    uvicorn.run(
        app,
        host=args.host,
        port=args.port,
        log_level="info",
        **tls_options,
    )