# Page header captured with this file:
#   Jae-Won Chung — "Make one model selectable by user (#23)" (commit 7ed0b8b, unverified)
#   raw | history | blame · 10.6 kB
from __future__ import annotations
import json
import asyncio
from datetime import datetime
from typing import AsyncGenerator, Literal, Optional, TYPE_CHECKING
import aiohttp
from pytz import timezone
from pydantic import BaseModel, Field
from spitfight.log import get_logger
from spitfight.utils import BoundedExpiringDict, TokenGenerationBuffer, create_task
from spitfight.colosseum.controller.worker import WorkerService
from spitfight.prompt import apply_model_characteristics
if TYPE_CHECKING:
    # Used only in annotations; not imported at runtime (presumably to avoid an
    # import cycle with the router module — TODO confirm).
    from spitfight.colosseum.controller.router import ControllerConfig

# Operational messages go to `controller_logger`; serialized `RequestState`s are
# written to the dedicated "colosseum_requests" logger.
controller_logger, request_logger = get_logger(__name__), get_logger("colosseum_requests")
def now() -> datetime:
    """Return the current time as a timezone-aware datetime in US/Eastern."""
    eastern = timezone("US/Eastern")
    return datetime.now(tz=eastern)
# Internal states.
# Both "chose_*" stages are reached by voting on which response was better. Which
# one is recorded depends on whether the chosen response used less (or equal)
# energy or more energy than the other one.
# A well-behaved user follows one of these two paths:
#   "prompted" -> "chose_less_energy_response"
#   "prompted" -> "chose_more_energy_response" -> "voted_energy"
UserStage = Literal[
    "prompted", "chose_less_energy_response",
    "chose_more_energy_response", "voted_energy",
]
class RequestState(BaseModel):
    """Models the state of a Colosseum play.

    This model is also serialized as is and logged.
    """

    request_id: str
    # The two models assigned to this play; `model_index` 0/1 selects one of them.
    model_names: list[str]
    # The user's prompt before any model-specific formatting.
    raw_prompt: str
    model_preference: str
    # Per-model values, indexed by model index (0 or 1). Factories give every
    # instance its own lists instead of sharing mutable literals as defaults.
    responses: list[str] = Field(default_factory=lambda: ["UNSET", "UNSET"])
    model_prompts: list[str] = Field(default_factory=lambda: ["UNSET", "UNSET"])
    energy_consumptions: list[float] = Field(default_factory=lambda: [0.0, 0.0])
    response_victory_index: Optional[Literal[0, 1]] = None
    extra_energy_was_worth: Optional[bool] = None
    # The time of the most recent update made through any of the setters below
    # (a response being recorded, or a stage change from voting).
    timestamp: datetime = Field(default_factory=now)
    # The user's current stage.
    user_stage: UserStage = "prompted"
    # When the user does not follow the expected stage sequence, each unexpected
    # (previous_stage, next_stage) transition is recorded here.
    abnormal_stage_change: list[tuple[UserStage, UserStage]] = Field(default_factory=list)

    def set_response_and_energy(self, model_index: Literal[0, 1], response: str, energy_consumption: float) -> None:
        """Record one model's final response text and the energy it consumed."""
        self.timestamp = now()
        self.energy_consumptions[model_index] = energy_consumption
        self.responses[model_index] = response

    def set_response_vote(self, victory_index: Literal[0, 1]) -> None:
        """Record which response the user preferred and advance the user's stage.

        If the chosen response consumed less or equal energy than the other one,
        the next stage is "chose_less_energy_response". Otherwise it is
        "chose_more_energy_response". A transition that does not start from
        "prompted" is additionally recorded in `abnormal_stage_change`.
        """
        self.timestamp = now()
        # Next stage depends on the user's vote.
        energy_a, energy_b = self.energy_consumptions
        next_stage: UserStage
        if (victory_index == 0 and energy_a <= energy_b) or (victory_index == 1 and energy_a >= energy_b):
            next_stage = "chose_less_energy_response"
        else:
            next_stage = "chose_more_energy_response"
        # Detect abnormal stage change.
        if self.user_stage != "prompted":
            self.abnormal_stage_change.append((self.user_stage, next_stage))
        self.user_stage = next_stage
        self.response_victory_index = victory_index

    def set_energy_vote(self, is_worth: bool) -> None:
        """Record whether the user thought the extra energy was worth it.

        The stage always becomes "voted_energy". If the previous stage was not
        "chose_more_energy_response", the transition is recorded as abnormal.
        """
        self.timestamp = now()
        # Detect abnormal stage change.
        if self.user_stage != "chose_more_energy_response":
            self.abnormal_stage_change.append((self.user_stage, "voted_energy"))
        self.user_stage = "voted_energy"
        self.extra_energy_was_worth = is_worth
class GenerationConfig(BaseModel):
    """Text-generation settings shared by every request the controller sends.

    `Controller.prompt` unpacks this model with `**generation_config.dict()` into
    the worker client's `generate_stream` call. Field names must therefore match
    that call's keyword parameters.
    """

    # Length limit.
    max_new_tokens: int
    # Sampling switch and sampling parameters.
    do_sample: bool
    temperature: float
    repetition_penalty: float
    top_k: int
    top_p: float
class Controller:
    """Coordinates Colosseum plays: assigns models, streams responses, records votes.

    Per-request `RequestState`s are kept in a `BoundedExpiringDict`. A background
    task periodically re-checks workers and garbage-collects request states.
    """

    def __init__(
        self,
        background_task_interval: int,
        max_num_req_states: int,
        req_state_expiration_time: int,
        worker_service: WorkerService,
        generation_config: GenerationConfig,
    ):
        """Create the controller and start its background task.

        Args:
            background_task_interval: Seconds to sleep between background iterations.
            max_num_req_states: Capacity passed to `BoundedExpiringDict`.
            req_state_expiration_time: Expiration time passed to `BoundedExpiringDict`
                (units are defined by that class).
            worker_service: Provides the model workers.
            generation_config: Generation parameters forwarded to every worker request.
        """
        self.request_states: BoundedExpiringDict[str, RequestState] = \
            BoundedExpiringDict(max_num_req_states, req_state_expiration_time)
        self.worker_service = worker_service
        self.generation_config = generation_config
        self.background_task_handle = create_task(
            self._background_task(background_task_interval),
        )

    def shutdown(self) -> None:
        """Shutdown the controller by cancelling its background task."""
        self.background_task_handle.cancel()

    async def _background_task(self, heartbeat_interval: int) -> None:
        """Periodically check if dead workers are alive again and do request state GC.

        Runs until cancelled (see `shutdown`). A failure in either step is logged
        and the loop continues. Without this, one transient error would
        permanently stop both worker health checks and state garbage collection.
        """
        while True:
            await asyncio.sleep(heartbeat_interval)
            # `asyncio.CancelledError` is not an `Exception` subclass (Python 3.8+),
            # so these handlers do not prevent `shutdown()` from stopping the loop.
            try:
                await self.worker_service.check_workers()
            except Exception:
                controller_logger.exception("Checking worker status failed.")
            try:
                prev_num_req_states = len(self.request_states)
                self.request_states.cleanup()
                controller_logger.info(
                    "Request state garbage collection done: Removed %d requests",
                    prev_num_req_states - len(self.request_states),
                )
            except Exception:
                controller_logger.exception("Request state garbage collection failed.")

    def get_available_models(self) -> list[str]:
        """Return the names of models whose worker status is "up"."""
        return [
            worker.model_name
            for worker in self.worker_service.workers
            if worker.status == "up"
        ]

    def response_vote(self, request_id: str, victory_index: Literal[0, 1]) -> RequestState | None:
        """Record the user's response vote and return the new state.

        Returns None when no state is found for `request_id`.
        """
        if (state := self.request_states.get(request_id)) is None:
            return None
        state.set_response_vote(victory_index)
        # Choosing the less-energy response ends the play, because there is no
        # energy vote on that path, so the state can be dropped now. On the
        # more-energy path the state is kept until `energy_vote`.
        if state.user_stage == "chose_less_energy_response":
            self.request_states.pop(request_id)
        request_logger.info(state.json())
        return state

    def energy_vote(self, request_id: str, is_worth: bool) -> RequestState | None:
        """Record the user's energy vote and return the new state.

        Returns None when no state is found for `request_id`.
        """
        # Pop the state from the dict, since this is the last step in any case.
        if (state := self.request_states.pop(request_id)) is None:
            return None
        state.set_energy_vote(is_worth)
        request_logger.info(state.json())
        return state

    @staticmethod
    def _encode_chunk(chunk: str) -> bytes:
        """Encode a text chunk for the client as a JSON string followed by a NUL delimiter."""
        return json.dumps(chunk).encode() + b"\0"

    async def prompt(
        self,
        request_id: str,
        prompt: str,
        model_index: Literal[0, 1],
        model_preference: str,
    ) -> AsyncGenerator[bytes, None]:
        """Stream one model's response to the user's prompt.

        This is called once per model (`model_index` 0 and 1) for the same
        request. The first call assigns models to the request based on
        `model_preference`. When streaming ends, whether normally or through an
        error, the response text and consumed energy are stored in the request
        state and logged.

        Yields:
            JSON-encoded text chunks, each terminated by a NUL byte.

        Raises:
            KeyError: If the assigned worker is not found.
            RuntimeError: If the assigned worker is dead.
            aiohttp.ClientConnectorError: If the worker cannot be reached. The
                worker's status is set to "down" first.
        """
        # If it's the first time this method is called, assign models to the request.
        if request_id not in self.request_states:
            workers = self.worker_service.choose_based_on_preference(model_preference)
            model_names = [worker.model_name for worker in workers]
            self.request_states[request_id] = RequestState(
                request_id=request_id,
                raw_prompt=prompt,
                model_names=model_names,
                model_preference=model_preference,
            )
        request_state = self.request_states[request_id]
        model_name = request_state.model_names[model_index]
        try:
            worker = self.worker_service.get_worker(model_name)
        except KeyError:
            controller_logger.error("Worker %s not found.", model_name)
            raise
        except RuntimeError:
            controller_logger.error("Worker %s is dead.", model_name)
            raise

        # Models have different prompt formatting requirements and stopping criteria.
        prompt, stop_str, stop_token_ids = apply_model_characteristics(
            prompt=prompt,
            model_name=worker.model_id,
        )
        request_state.model_prompts[model_index] = prompt

        # Request the model worker to stream the response to the user's prompt.
        response = ""
        energy = 0.0
        client = worker.get_client()
        buffer = TokenGenerationBuffer(stop_str=stop_str)
        try:
            async for resp in client.generate_stream(
                prompt=prompt,
                stop_sequences=[stop_str] if stop_str is not None else None,
                **self.generation_config.dict(),
            ):
                # Even special tokens consume energy when they're generated.
                energy += resp.token.energy

                # Stop tokens usually don't overlap with (human-readable) stop sequences.
                if resp.token.id in stop_token_ids:
                    # If the buffer is not empty (i.e., we had partial stop_str matches),
                    # just yield it to the user.
                    if (chunk := buffer.token_buffer):
                        response += chunk
                        yield self._encode_chunk(chunk)
                    break

                # Skip special tokens.
                if resp.token.special:
                    continue

                # The buffer automatically handles `stop_str` partial and full matches.
                buffer.append(resp.token.text)
                if (chunk := buffer.pop()) is not None:
                    response += chunk
                    yield self._encode_chunk(chunk)
                elif buffer.matched_stop_str:
                    break
            else:
                # The stream ended without a stop token or a full `stop_str` match
                # (e.g., the token limit was reached). Text held back as a possible
                # partial `stop_str` match would otherwise be dropped, so flush it.
                if (chunk := buffer.token_buffer):
                    response += chunk
                    yield self._encode_chunk(chunk)
        except aiohttp.ClientConnectorError:
            worker.status = "down"
            controller_logger.error(
                "Problem talking to %s. Aborting and setting worker status to down",
                repr(worker),
            )
            raise
        except Exception:
            # Give the user whatever text was held back. Also record it, so the
            # logged response matches what was actually sent.
            if (chunk := buffer.token_buffer):
                response += chunk
                yield self._encode_chunk(chunk)
            raise
        finally:
            request_state.set_response_and_energy(model_index, response, energy)
            request_logger.info(request_state.json())
# Process-wide controller instance. It stays None until `init_global_controller` runs.
CONTROLLER: Optional[Controller] = None
def init_global_controller(config: ControllerConfig) -> None:
    """Build the process-wide `Controller` from `config` and store it in `CONTROLLER`.

    Calling this again replaces the stored controller. The previous controller's
    background task is not cancelled here.
    """
    global CONTROLLER
    worker_service = WorkerService(config.compose_files)
    generation_config = GenerationConfig(
        max_new_tokens=config.max_new_tokens,
        do_sample=config.do_sample,
        temperature=config.temperature,
        repetition_penalty=config.repetition_penalty,
        top_k=config.top_k,
        top_p=config.top_p,
    )
    CONTROLLER = Controller(
        background_task_interval=config.background_task_interval,
        max_num_req_states=config.max_num_req_states,
        req_state_expiration_time=config.req_state_expiration_time,
        worker_service=worker_service,
        generation_config=generation_config,
    )
def get_global_controller() -> Controller:
    """Return the process-wide controller.

    Raises:
        RuntimeError: If `init_global_controller` has not been called yet.
    """
    # Explicit check rather than `assert`: asserts are stripped under `python -O`,
    # which would make this silently return None.
    if CONTROLLER is None:
        raise RuntimeError(
            "Global controller is not initialized; call init_global_controller() first."
        )
    return CONTROLLER