# FIRE/src/serve/base_model_worker.py
import asyncio
import threading
import time
from typing import List

from fastapi import FastAPI, Request, BackgroundTasks
from fastapi.responses import StreamingResponse, JSONResponse
import requests

from src.constants import WORKER_HEART_BEAT_INTERVAL
from src.conversation import Conversation
from src.utils import pretty_print_semaphore, build_logger

worker = None
logger = None
app = FastAPI()
def heart_beat_worker(obj):
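    """Background loop: sleep for WORKER_HEART_BEAT_INTERVAL seconds, then push a heartbeat to the controller."""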
    while True:
        time.sleep(WORKER_HEART_BEAT_INTERVAL)
        obj.send_heart_beat()
class BaseModelWorker:
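    """Base class shared by model workers: handles controller registration,
    heartbeats, and queue-length reporting for the FastAPI endpoints below."""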
    def __init__(
        self,
        controller_addr: str,
        worker_addr: str,
        worker_id: str,
        model_path: str,
        model_names: List[str],
        limit_worker_concurrency: int,
        conv_template: str = None,
        multimodal: bool = False,
    ):
        global logger, worker

        self.controller_addr = controller_addr
        self.worker_addr = worker_addr
        self.worker_id = worker_id
        if model_path.endswith("/"):
            model_path = model_path[:-1]
        self.model_names = model_names or [model_path.split("/")[-1]]
        self.limit_worker_concurrency = limit_worker_concurrency
        self.conv = self.make_conv_template(conv_template, model_path)
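        # sep_style comes back as a fastchat SeparatorStyle (an IntEnum); store it
        # as a plain int so the template serializes cleanly over the API.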
        self.conv.sep_style = int(self.conv.sep_style)
        self.multimodal = multimodal
        self.tokenizer = None
        self.context_len = None
        self.call_ct = 0
        self.semaphore = None
        self.heart_beat_thread = None

        if logger is None:
            logger = build_logger("model_worker", f"model_worker_{self.worker_id}.log")
        if worker is None:
            worker = self
    def make_conv_template(
        self,
        conv_template: str = None,
        model_path: str = None,
    ) -> Conversation:
        """
        Can be overridden to customize the conversation template for different model workers.
        """
        from fastchat.conversation import get_conv_template
        from fastchat.model.model_adapter import get_conversation_template

        if conv_template:
            conv = get_conv_template(conv_template)
        else:
            conv = get_conversation_template(model_path)
        return conv
    def init_heart_beat(self):
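        """Register with the controller, then start a daemon thread that keeps sending heartbeats."""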
        self.register_to_controller()
        self.heart_beat_thread = threading.Thread(
            target=heart_beat_worker,
            args=(self,),
            daemon=True,
        )
        self.heart_beat_thread.start()

    def register_to_controller(self):
        logger.info("Register to controller")

        url = self.controller_addr + "/register_worker"
        data = {
            "worker_name": self.worker_addr,
            "check_heart_beat": True,
            "worker_status": self.get_status(),
            "multimodal": self.multimodal,
        }
        r = requests.post(url, json=data)
        assert r.status_code == 200
    def send_heart_beat(self):
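        """Report the current queue length to the controller; re-register if the
        controller no longer tracks this worker."""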
        logger.info(
            f"Send heart beat. Models: {self.model_names}. "
            f"Semaphore: {pretty_print_semaphore(self.semaphore)}. "
            f"call_ct: {self.call_ct}. "
            f"worker_id: {self.worker_id}. "
        )

        url = self.controller_addr + "/receive_heart_beat"
        while True:
            try:
                ret = requests.post(
                    url,
                    json={
                        "worker_name": self.worker_addr,
                        "queue_length": self.get_queue_length(),
                    },
                    timeout=5,
                )
                exist = ret.json()["exist"]
                break
            except (requests.exceptions.RequestException, KeyError) as e:
                logger.error(f"heart beat error: {e}")
            time.sleep(5)

        if not exist:
            self.register_to_controller()
    def get_queue_length(self):
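        """Approximate the number of pending requests: permits in use plus tasks
        waiting on the semaphore (peeks at asyncio.Semaphore internals)."""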
        if self.semaphore is None:
            return 0
        else:
            semaphore_value = (
                self.semaphore._value
                if self.semaphore._value is not None
                else self.limit_worker_concurrency
            )
            waiter_count = (
                0 if self.semaphore._waiters is None else len(self.semaphore._waiters)
            )
            return self.limit_worker_concurrency - semaphore_value + waiter_count

    def get_status(self):
        return {
            "model_names": self.model_names,
            "speed": 1,
            "queue_length": self.get_queue_length(),
        }
    def count_token(self, params):
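        """Count prompt tokens with the worker's tokenizer, falling back to
        tokenizer.num_tokens() when the tokenizer is not callable."""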
        prompt = params["prompt"]

        try:
            input_ids = self.tokenizer(prompt).input_ids
            input_echo_len = len(input_ids)
        except TypeError:
            input_echo_len = self.tokenizer.num_tokens(prompt)

        ret = {
            "count": input_echo_len,
            "error_code": 0,
        }
        return ret

    def get_conv_template(self):
        return {"conv": self.conv}

    def generate_stream_gate(self, params):
        raise NotImplementedError

    def generate_gate(self, params):
        raise NotImplementedError

    def get_embeddings(self, params):
        raise NotImplementedError
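# Module-level request gating: a lazily created asyncio.Semaphore caps concurrent
# requests at worker.limit_worker_concurrency; each endpoint acquires a slot before
# calling into the worker and releases it when the response is done.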
def release_worker_semaphore():
    worker.semaphore.release()


def acquire_worker_semaphore():
    if worker.semaphore is None:
        worker.semaphore = asyncio.Semaphore(worker.limit_worker_concurrency)
    return worker.semaphore.acquire()


def create_background_tasks():
    background_tasks = BackgroundTasks()
    background_tasks.add_task(release_worker_semaphore)
    return background_tasks
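# The streaming endpoint releases its semaphore slot via a BackgroundTask, i.e. only
# after the StreamingResponse has finished sending, so long generations keep their slot.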
@app.post("/worker_generate_stream")
async def api_generate_stream(request: Request):
params = await request.json()
await acquire_worker_semaphore()
generator = worker.generate_stream_gate(params)
background_tasks = create_background_tasks()
return StreamingResponse(generator, background=background_tasks)
@app.post("/worker_generate")
async def api_generate(request: Request):
params = await request.json()
await acquire_worker_semaphore()
output = await asyncio.to_thread(worker.generate_gate, params)
release_worker_semaphore()
return JSONResponse(output)
@app.post("/worker_get_embeddings")
async def api_get_embeddings(request: Request):
params = await request.json()
await acquire_worker_semaphore()
embedding = worker.get_embeddings(params)
release_worker_semaphore()
return JSONResponse(content=embedding)
@app.post("/worker_get_status")
async def api_get_status(request: Request):
return worker.get_status()
@app.post("/count_token")
async def api_count_token(request: Request):
params = await request.json()
return worker.count_token(params)
@app.post("/worker_get_conv_template")
async def api_get_conv(request: Request):
return worker.get_conv_template()
@app.post("/model_details")
async def api_model_details(request: Request):
return {"context_length": worker.context_len}