|
""" |
|
A controller manages distributed workers. |
|
It sends worker addresses to clients. |
|
""" |
|
import argparse |
|
import asyncio |
|
import dataclasses |
|
from enum import Enum, auto |
|
import json |
|
import logging |
|
import os |
|
import time |
|
from typing import List, Union |
|
import threading |
|
|
|
from fastapi import FastAPI, Request |
|
from fastapi.responses import StreamingResponse |
|
import numpy as np |
|
import requests |
|
import uvicorn |
|
|
|
from src.constants import ( |
|
CONTROLLER_HEART_BEAT_EXPIRATION, |
|
WORKER_API_TIMEOUT, |
|
ErrorCode, |
|
SERVER_ERROR_MSG, |
|
) |
|
from src.utils import build_logger |
|
|
|
|
|
# Module-wide logger used by every controller component below. It is built by
# the project helper src.utils.build_logger with the name "controller" and the
# file name "controller.log" (presumably the log file path; see build_logger).
logger = build_logger("controller", "controller.log")
|
|
|
|
|
class DispatchMethod(Enum):
    """Policy the controller uses to choose a worker for a request."""

    LOTTERY = auto()
    SHORTEST_QUEUE = auto()

    @classmethod
    def from_str(cls, name):
        """Parse a command-line name into a ``DispatchMethod`` member.

        Args:
            name: ``"lottery"`` or ``"shortest_queue"``.

        Returns:
            The matching member.

        Raises:
            ValueError: if ``name`` is not a supported dispatch method. The
                message includes the rejected value.
        """
        if name == "lottery":
            return cls.LOTTERY
        if name == "shortest_queue":
            return cls.SHORTEST_QUEUE
        raise ValueError(f"Invalid dispatch method: {name!r}")
|
|
|
|
|
@dataclasses.dataclass
class WorkerInfo:
    """Controller-side record of one registered worker."""

    # Model names the worker reported in its status.
    model_names: List[str]
    # Weight used by both dispatch policies (lottery probability, queue divisor).
    speed: int
    # Queue length from the worker's status/heart beat; the controller also
    # increments it locally when SHORTEST_QUEUE dispatches to this worker.
    queue_length: int
    # Whether the heart-beat reaper is allowed to expire this worker.
    check_heart_beat: bool
    # time.time() of registration or of the most recent heart beat; compared
    # numerically against the expiration cutoff, hence float (was mis-typed str).
    last_heart_beat: float
    # Whether the worker serves multimodal models.
    multimodal: bool
|
|
|
|
|
def heart_beat_controller(controller):
    """Expire overdue workers every CONTROLLER_HEART_BEAT_EXPIRATION seconds.

    Never returns. ``Controller.__init__`` runs it on a background thread.

    Args:
        controller: Object exposing ``remove_stale_workers_by_expiration()``.
    """
    while True:
        time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION)
        try:
            controller.remove_stale_workers_by_expiration()
        except Exception:
            # An uncaught exception would end this thread silently, and stale
            # workers would never be expired again. Log it and keep sweeping.
            logger.exception("Failed to remove stale workers")
|
|
|
|
|
class Controller:
    """Registry of workers plus the policy used to pick one per request.

    ``self.worker_info`` maps a worker's base URL to its ``WorkerInfo``. The
    dict is read and mutated both by request handlers and by the background
    heart-beat thread started in ``__init__``. Every loop over it therefore
    iterates a snapshot rather than the live dict.

    NOTE(review): several methods make blocking ``requests`` calls and are
    invoked from ``async def`` endpoints, so they block the event loop while
    waiting on a worker.
    """

    def __init__(self, dispatch_method: str):
        """Create the controller and start the stale-worker reaper thread.

        Args:
            dispatch_method: ``"lottery"`` or ``"shortest_queue"``.

        Raises:
            ValueError: from ``DispatchMethod.from_str`` for any other name.
        """
        self.worker_info = {}
        self.dispatch_method = DispatchMethod.from_str(dispatch_method)

        # daemon=True: heart_beat_controller loops forever, so a non-daemon
        # thread would keep the interpreter alive after the server exits.
        self.heart_beat_thread = threading.Thread(
            target=heart_beat_controller, args=(self,), daemon=True
        )
        self.heart_beat_thread.start()

    def register_worker(
        self,
        worker_name: str,
        check_heart_beat: bool,
        worker_status: Union[dict, None],
        multimodal: bool,
    ) -> bool:
        """Add a worker, or overwrite the entry of an already known one.

        Args:
            worker_name: Worker base URL; also the registry key.
            check_heart_beat: Whether the reaper may expire this worker.
            worker_status: Dict with ``model_names``, ``speed`` and
                ``queue_length``. When falsy, it is fetched from the worker.
            multimodal: Whether the worker serves multimodal models.

        Returns:
            True if the worker was registered. False if no status could be
            obtained.
        """
        if worker_name not in self.worker_info:
            logger.info(f"Register a new worker: {worker_name}")
        else:
            logger.info(f"Register an existing worker: {worker_name}")

        if not worker_status:
            worker_status = self.get_worker_status(worker_name)
        if not worker_status:
            return False

        self.worker_info[worker_name] = WorkerInfo(
            worker_status["model_names"],
            worker_status["speed"],
            worker_status["queue_length"],
            check_heart_beat,
            time.time(),
            multimodal,
        )

        logger.info(f"Register done: {worker_name}, {worker_status}")
        return True

    def get_worker_status(self, worker_name: str):
        """POST to ``{worker_name}/worker_get_status`` and decode the reply.

        Returns:
            The decoded JSON body. Returns None if the request fails, the
            worker answers with a non-200 status, or the body is not valid JSON.
        """
        try:
            r = requests.post(worker_name + "/worker_get_status", timeout=5)
        except requests.exceptions.RequestException as e:
            logger.error(f"Get status fails: {worker_name}, {e}")
            return None

        if r.status_code != 200:
            logger.error(f"Get status fails: {worker_name}, {r}")
            return None

        try:
            return r.json()
        except ValueError as e:
            # A malformed body is reported like any other status failure,
            # rather than escaping to callers that only expect dict-or-None.
            logger.error(f"Get status fails: {worker_name}, {e}")
            return None

    def remove_worker(self, worker_name: str):
        """Drop ``worker_name`` from the registry; a no-op if it is absent."""
        # pop() rather than del: the reaper thread and request handlers may
        # race to remove the same worker, and the loser must not raise.
        self.worker_info.pop(worker_name, None)

    def refresh_all_workers(self):
        """Re-register every known worker, dropping those whose status fails."""
        old_info = dict(self.worker_info)
        self.worker_info = {}

        for w_name, w_info in old_info.items():
            if not self.register_worker(
                w_name, w_info.check_heart_beat, None, w_info.multimodal
            ):
                logger.info(f"Remove stale worker: {w_name}")

    def _collect_model_names(self, multimodal=None):
        """Return distinct model names, optionally filtered by multimodality.

        Args:
            multimodal: None includes every worker. True or False keeps only
                workers whose ``multimodal`` flag has that truth value.
        """
        model_names = set()
        for w_info in list(self.worker_info.values()):
            if multimodal is None or bool(w_info.multimodal) == multimodal:
                model_names.update(w_info.model_names)
        return list(model_names)

    def list_models(self):
        """Return every model name served by any registered worker."""
        return self._collect_model_names()

    def list_multimodal_models(self):
        """Return model names served by workers flagged as multimodal."""
        return self._collect_model_names(multimodal=True)

    def list_language_models(self):
        """Return model names served by workers not flagged as multimodal."""
        return self._collect_model_names(multimodal=False)

    def get_worker_address(self, model_name: str) -> str:
        """Pick the address of a worker serving ``model_name``.

        LOTTERY samples a candidate with probability proportional to ``speed``.
        SHORTEST_QUEUE picks the candidate with the smallest
        ``queue_length / speed``.

        Returns:
            The worker address, or "" when no worker can serve the model.

        Raises:
            ValueError: if ``self.dispatch_method`` is not a known policy.
        """
        candidates = [
            (w_name, w_info)
            for w_name, w_info in list(self.worker_info.items())
            if model_name in w_info.model_names
        ]

        if self.dispatch_method == DispatchMethod.LOTTERY:
            return self._dispatch_lottery(candidates)
        if self.dispatch_method == DispatchMethod.SHORTEST_QUEUE:
            return self._dispatch_shortest_queue(candidates)
        raise ValueError(f"Invalid dispatch method: {self.dispatch_method}")

    def _dispatch_lottery(self, candidates):
        """Sample one (name, info) candidate's name, weighted by its speed."""
        worker_names = [w_name for w_name, _ in candidates]
        worker_speeds = np.array(
            [w_info.speed for _, w_info in candidates], dtype=np.float32
        )
        norm = np.sum(worker_speeds)
        if norm < 1e-4:
            # No candidates, or all of them report (near-)zero speed.
            return ""
        pt = np.random.choice(np.arange(len(worker_names)), p=worker_speeds / norm)
        return worker_names[pt]

    def _dispatch_shortest_queue(self, candidates):
        """Return the candidate name with the smallest queue_length / speed."""
        # Zero-speed workers are skipped. Dividing by their speed would crash
        # the request, and the lottery policy gives them zero probability.
        usable = [(w_name, w_info) for w_name, w_info in candidates if w_info.speed > 0]
        if not usable:
            return ""

        worker_names = [w_name for w_name, _ in usable]
        worker_qlen = [w_info.queue_length / w_info.speed for _, w_info in usable]
        min_index = int(np.argmin(worker_qlen))
        w_name = worker_names[min_index]
        # Bump the chosen worker's queue locally. receive_heart_beat later
        # overwrites it with the worker's reported value. The snapshot's
        # object is updated, so a concurrent removal cannot raise KeyError.
        usable[min_index][1].queue_length += 1
        logger.info(
            f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}"
        )
        return w_name

    def receive_heart_beat(self, worker_name: str, queue_length: int):
        """Record a heart beat and the reported queue length.

        Returns:
            True if the worker is registered. False otherwise, so the worker
            can tell it must re-register.
        """
        # A single lookup avoids a check-then-index race with the reaper thread.
        w_info = self.worker_info.get(worker_name)
        if w_info is None:
            logger.info(f"Receive unknown heart beat. {worker_name}")
            return False

        w_info.queue_length = queue_length
        w_info.last_heart_beat = time.time()
        logger.info(f"Receive heart beat. {worker_name}")
        return True

    def remove_stale_workers_by_expiration(self):
        """Remove heart-beat-checked workers silent for the expiration window."""
        expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION
        # Snapshot: this runs on the heart-beat thread while request handlers
        # may be adding or removing entries.
        to_delete = [
            worker_name
            for worker_name, w_info in list(self.worker_info.items())
            if w_info.check_heart_beat and w_info.last_heart_beat < expire
        ]

        for worker_name in to_delete:
            logger.info(f"Remove stale worker by expiration: {worker_name}")
            self.remove_worker(worker_name)

    def handle_no_worker(self, params):
        """Build the NUL-terminated JSON error chunk for "no worker available"."""
        logger.info(f"no worker: {params['model']}")
        ret = {
            "text": SERVER_ERROR_MSG,
            "error_code": ErrorCode.CONTROLLER_NO_WORKER,
        }
        return json.dumps(ret).encode() + b"\0"

    def handle_worker_timeout(self, worker_address):
        """Build the NUL-terminated JSON error chunk for a failed worker call."""
        logger.info(f"worker timeout: {worker_address}")
        ret = {
            "text": SERVER_ERROR_MSG,
            "error_code": ErrorCode.CONTROLLER_WORKER_TIMEOUT,
        }
        return json.dumps(ret).encode() + b"\0"

    def worker_api_get_status(self):
        """Aggregate live status by querying every registered worker.

        Workers whose status request fails are skipped.

        Returns:
            A dict with sorted ``model_names``, summed ``speed`` and summed
            ``queue_length``.
        """
        model_names = set()
        speed = 0
        queue_length = 0

        for w_name in list(self.worker_info):
            worker_status = self.get_worker_status(w_name)
            if worker_status is not None:
                model_names.update(worker_status["model_names"])
                speed += worker_status["speed"]
                queue_length += worker_status["queue_length"]

        return {
            "model_names": sorted(model_names),
            "speed": speed,
            "queue_length": queue_length,
        }

    def worker_api_generate_stream(self, params):
        """Proxy a streaming generation request to a worker for ``params["model"]``.

        Yields:
            NUL-terminated byte chunks relayed from the worker. Yields a single
            no-worker error chunk if no worker serves the model. Yields a
            worker-timeout error chunk if the worker request fails.
        """
        worker_addr = self.get_worker_address(params["model"])
        if not worker_addr:
            yield self.handle_no_worker(params)
            # Stop here. Falling through would POST to "" + path, which raises
            # a RequestException and emits a second (timeout) error chunk.
            return

        try:
            # The context manager releases the streamed connection, including
            # when the client disconnects and this generator is closed.
            with requests.post(
                worker_addr + "/worker_generate_stream",
                json=params,
                stream=True,
                timeout=WORKER_API_TIMEOUT,
            ) as response:
                for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
                    if chunk:
                        yield chunk + b"\0"
        except requests.exceptions.RequestException as e:
            logger.error(f"Stream from {worker_addr} fails: {e}")
            yield self.handle_worker_timeout(worker_addr)
|
|
|
|
|
# FastAPI application; the HTTP endpoints below register on it via decorators.
app = FastAPI()
|
|
|
|
|
@app.post("/register_worker")
async def register_worker(request: Request):
    """Register (or re-register) the worker described in the JSON body.

    Expected keys are ``worker_name`` and ``check_heart_beat`` (required),
    plus ``worker_status`` and ``multimodal`` (optional).

    Returns:
        ``{"success": bool}``. False means the controller could not obtain
        the worker's status, so the worker is not registered.
    """
    data = await request.json()
    success = controller.register_worker(
        data["worker_name"],
        data["check_heart_beat"],
        data.get("worker_status", None),
        data.get("multimodal", False),
    )
    # Report the outcome so a rejected worker can tell it is not registered.
    return {"success": success}
|
|
|
|
|
@app.post("/refresh_all_workers")
async def refresh_all_workers():
    """Re-query every registered worker and drop those that do not respond."""
    # Controller.refresh_all_workers returns nothing, so there is no value to
    # bind or forward.
    controller.refresh_all_workers()
|
|
|
|
|
@app.post("/list_models")
async def list_models():
    """Return ``{"models": [...]}`` with every model served by any worker."""
    return {"models": controller.list_models()}
|
|
|
|
|
@app.post("/list_multimodal_models")
async def list_multimodal_models():
    """Return ``{"models": [...]}`` with models of multimodal workers."""
    return {"models": controller.list_multimodal_models()}
|
|
|
|
|
@app.post("/list_language_models")
async def list_language_models():
    """Return ``{"models": [...]}`` with models of non-multimodal workers."""
    return {"models": controller.list_language_models()}
|
|
|
|
|
@app.post("/get_worker_address")
async def get_worker_address(request: Request):
    """Resolve the body's ``model`` to ``{"address": ...}`` ("" if unavailable)."""
    payload = await request.json()
    return {"address": controller.get_worker_address(payload["model"])}
|
|
|
|
|
@app.post("/receive_heart_beat")
async def receive_heart_beat(request: Request):
    """Record a worker heart beat; ``exist`` is False for unknown workers."""
    payload = await request.json()
    known = controller.receive_heart_beat(
        payload["worker_name"], payload["queue_length"]
    )
    return {"exist": known}
|
|
|
|
|
@app.post("/worker_generate_stream")
async def worker_api_generate_stream(request: Request):
    """Stream a generation from a worker chosen for the body's ``model``."""
    body = await request.json()
    return StreamingResponse(controller.worker_api_generate_stream(body))
|
|
|
|
|
@app.post("/worker_get_status")
async def worker_api_get_status(request: Request):
    """Return the aggregated status of all registered workers."""
    aggregated = controller.worker_api_get_status()
    return aggregated
|
|
|
|
|
@app.get("/test_connection")
async def check_connection(request: Request):
    """Liveness probe for ``GET /test_connection``; always returns "success".

    This handler must not reuse the name ``worker_api_get_status``. Doing so
    rebinds the module-level name of the ``/worker_get_status`` handler.
    """
    return "success"
|
|
|
|
|
def create_controller():
    """Parse command-line options and construct the Controller.

    Returns:
        A ``(args, controller)`` tuple: the argparse namespace and a Controller
        built with ``--dispatch-method``.
    """
    cli = argparse.ArgumentParser()
    option_specs = (
        ("--host", dict(type=str, default="localhost")),
        ("--port", dict(type=int, default=21001)),
        (
            "--dispatch-method",
            dict(
                type=str,
                choices=["lottery", "shortest_queue"],
                default="shortest_queue",
            ),
        ),
        (
            "--ssl",
            dict(
                action="store_true",
                required=False,
                default=False,
                help="Enable SSL. Requires OS Environment variables 'SSL_KEYFILE' and 'SSL_CERTFILE'.",
            ),
        ),
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)

    parsed = cli.parse_args()
    logger.info(f"args: {parsed}")

    return parsed, Controller(parsed.dispatch_method)
|
|
|
|
|
if __name__ == "__main__":
    # Assigns the module-level ``controller`` that the endpoint handlers use.
    args, controller = create_controller()

    # uvicorn TLS options. They are only filled in with --ssl, and the key and
    # certificate paths are read from the environment.
    tls_options = {}
    if args.ssl:
        tls_options = {
            "ssl_keyfile": os.environ["SSL_KEYFILE"],
            "ssl_certfile": os.environ["SSL_CERTFILE"],
        }

    uvicorn.run(app, host=args.host, port=args.port, log_level="info", **tls_options)
|
|