Spaces:
Running
Running
File size: 10,576 Bytes
8ff63e4 7ed0b8b 8ff63e4 e38f79f 7ed0b8b e38f79f 8ff63e4 7ed0b8b 8ff63e4 7ed0b8b 8ff63e4 7ed0b8b 8ff63e4 e38f79f 8ff63e4 7ed0b8b 8ff63e4 e38f79f 8ff63e4 e38f79f 8ff63e4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 |
from __future__ import annotations
import json
import asyncio
from datetime import datetime
from typing import AsyncGenerator, Literal, Optional, TYPE_CHECKING
import aiohttp
from pytz import timezone
from pydantic import BaseModel, Field
from spitfight.log import get_logger
from spitfight.utils import BoundedExpiringDict, TokenGenerationBuffer, create_task
from spitfight.colosseum.controller.worker import WorkerService
from spitfight.prompt import apply_model_characteristics
if TYPE_CHECKING:
from spitfight.colosseum.controller.router import ControllerConfig
# Two module-level loggers: `controller_logger` for controller operations
# (worker/GC messages), `request_logger` for serialized `RequestState` records.
controller_logger, request_logger = get_logger(__name__), get_logger("colosseum_requests")
def now() -> datetime:
    """Return the current time as a timezone-aware datetime in US/Eastern."""
    eastern = timezone("US/Eastern")
    return datetime.now(eastern)
# Internal states
# The two "chose_*" stages are both the result of voting on a response.
# A normal user will sequentially go through either
# "prompted" -> "chose_less_energy_response", or
# "prompted" -> "chose_more_energy_response" -> "voted_energy"
# Any other transition is appended to `RequestState.abnormal_stage_change`
# instead of being rejected.
UserStage = Literal[
    "prompted",
    "chose_less_energy_response",
    "chose_more_energy_response",
    "voted_energy",
]
class RequestState(BaseModel):
    """Everything recorded about one Colosseum play.

    A play is one prompt answered by two models (model index 0 and 1). The
    model is serialized verbatim with `.json()` for the request log, so every
    field below ends up in the logs.
    """

    request_id: str
    # Names of the two models assigned to this play, indexed by model index.
    model_names: list[str]
    # The prompt before any model-specific formatting was applied.
    raw_prompt: str
    model_preference: str
    # Per-model final response text.
    responses: list[str] = ["UNSET", "UNSET"]
    # Per-model prompt after model-specific formatting.
    model_prompts: list[str] = ["UNSET", "UNSET"]
    # Per-model energy consumed while generating the response.
    energy_consumptions: list[float] = [0.0, 0.0]
    # Index of the response the user voted for, once they have voted.
    response_victory_index: Optional[Literal[0, 1]] = None
    # Whether the user judged the extra energy worth it, once they answered.
    extra_energy_was_worth: Optional[bool] = None
    # When the state was last modified by one of the setters below.
    timestamp: datetime = Field(default_factory=now)
    # Current position in the voting flow.
    user_stage: UserStage = "prompted"
    # (from_stage, to_stage) for every transition outside the normal flows
    # "prompted" -> "chose_less_energy_response" and
    # "prompted" -> "chose_more_energy_response" -> "voted_energy".
    abnormal_stage_change: list[tuple[UserStage, UserStage]] = []

    def set_response_and_energy(self, model_index: Literal[0, 1], response: str, energy_consumption: float) -> None:
        """Store one model's final response and the energy spent producing it."""
        self.timestamp = now()
        self.energy_consumptions[model_index], self.responses[model_index] = energy_consumption, response

    def set_response_vote(self, victory_index: Literal[0, 1]) -> None:
        """Record the user's preferred response and advance the user's stage.

        Picking the response whose energy is not greater than the other's moves
        the user to "chose_less_energy_response" (so a tie always counts as the
        less-energy choice); otherwise to "chose_more_energy_response".
        """
        self.timestamp = now()
        a_energy, b_energy = self.energy_consumptions
        picked_cheaper = (victory_index == 0 and a_energy <= b_energy) or (
            victory_index == 1 and b_energy <= a_energy
        )
        next_stage = "chose_less_energy_response" if picked_cheaper else "chose_more_energy_response"
        # A response vote is only expected directly after prompting.
        if self.user_stage != "prompted":
            self.abnormal_stage_change.append((self.user_stage, next_stage))
        self.response_victory_index = victory_index
        self.user_stage = next_stage

    def set_energy_vote(self, is_worth: bool) -> None:
        """Record whether the extra energy of the preferred response was worth it."""
        self.timestamp = now()
        expected_stage = "chose_more_energy_response"
        # An energy vote is only expected after choosing the more-energy response.
        if self.user_stage != expected_stage:
            self.abnormal_stage_change.append((self.user_stage, "voted_energy"))
        self.extra_energy_was_worth = is_worth
        self.user_stage = "voted_energy"
class GenerationConfig(BaseModel):
    """Generation parameters applied to every model response request.

    `Controller.prompt` expands `.dict()` of this model as keyword arguments to
    the worker client's `generate_stream`, so each field name is passed through
    unchanged as a parameter name of that call.
    """
    max_new_tokens: int
    do_sample: bool
    temperature: float
    repetition_penalty: float
    top_k: int
    top_p: float
class Controller:
    """Orchestrates Colosseum plays.

    Assigns models to each request, streams model responses from workers,
    records user votes in per-request `RequestState`s, and runs a periodic
    background task for worker health checks and request-state cleanup.
    """

    def __init__(
        self,
        background_task_interval: int,
        max_num_req_states: int,
        req_state_expiration_time: int,
        worker_service: WorkerService,
        generation_config: GenerationConfig,
    ):
        """Initialize the controller and start its background task.

        Args:
            background_task_interval: Seconds to sleep between background task runs.
            max_num_req_states: Maximum number of request states (passed to `BoundedExpiringDict`).
            req_state_expiration_time: Expiration setting passed to `BoundedExpiringDict`.
            worker_service: Used to choose, look up, and health-check model workers.
            generation_config: Parameters forwarded to every generation request.
        """
        self.request_states: BoundedExpiringDict[str, RequestState] = \
            BoundedExpiringDict(max_num_req_states, req_state_expiration_time)
        self.worker_service = worker_service
        self.generation_config = generation_config
        self.background_task_handle = create_task(
            self._background_task(background_task_interval),
        )

    def shutdown(self) -> None:
        """Shutdown the controller by cancelling its background task."""
        self.background_task_handle.cancel()

    async def _background_task(self, heartbeat_interval: int) -> None:
        """Periodically check if dead workers are alive again and do request state GC.

        Errors in one iteration are logged and the loop keeps going, so a single
        transient failure cannot permanently disable health checks or GC.
        Cancellation (`shutdown()`) still stops the loop, because
        `asyncio.CancelledError` is not an `Exception` subclass.
        """
        while True:
            await asyncio.sleep(heartbeat_interval)
            try:
                await self.worker_service.check_workers()
            except Exception:
                controller_logger.exception("Failed to check worker status.")
            try:
                prev_num_req_states = len(self.request_states)
                self.request_states.cleanup()
                controller_logger.info(
                    "Request state garbage collection done: Removed %d requests",
                    prev_num_req_states - len(self.request_states),
                )
            except Exception:
                controller_logger.exception("Request state garbage collection failed.")

    def get_available_models(self) -> list[str]:
        """Return the names of models whose worker status is "up"."""
        return [
            worker.model_name
            for worker in self.worker_service.workers
            if worker.status == "up"
        ]

    def response_vote(self, request_id: str, victory_index: Literal[0, 1]) -> RequestState | None:
        """Record the user's response vote and return the new state.

        Returns None when `request_states` holds no state for `request_id`.
        """
        if (state := self.request_states.get(request_id)) is not None:
            state.set_response_vote(victory_index)
            # Choosing the less-energy response ends the play (there is no
            # energy vote to collect), so stop tracking the request.
            if state.user_stage == "chose_less_energy_response":
                self.request_states.pop(request_id)
            request_logger.info(state.json())
            return state
        return None

    def energy_vote(self, request_id: str, is_worth: bool) -> RequestState | None:
        """Record the user's energy vote and return the new state.

        Returns None when `request_states` holds no state for `request_id`.
        """
        # Pop the state from the dict, since this is the last step in any case.
        if (state := self.request_states.pop(request_id)) is not None:
            state.set_energy_vote(is_worth)
            request_logger.info(state.json())
            return state
        return None

    @staticmethod
    def _encode_chunk(chunk: str) -> bytes:
        """Serialize a text chunk as a JSON string followed by a NUL delimiter."""
        return json.dumps(chunk).encode() + b"\0"

    async def prompt(
        self,
        request_id: str,
        prompt: str,
        model_index: Literal[0, 1],
        model_preference: str,
    ) -> AsyncGenerator[bytes, None]:
        """Stream one model's response to `prompt` as NUL-terminated JSON string chunks.

        Called twice per request, once per `model_index`. The first call assigns
        models according to `model_preference`. Every call records its response
        text and accumulated energy in the request state when streaming ends,
        whether it finished normally or failed.

        Raises:
            KeyError: The assigned model's worker was not found.
            RuntimeError: The assigned model's worker is dead.
            aiohttp.ClientConnectorError: The worker could not be reached; it is
                marked "down" before the error propagates.
        """
        # If it's the first time this method is called, assign models to the request.
        if request_id not in self.request_states:
            workers = self.worker_service.choose_based_on_preference(model_preference)
            self.request_states[request_id] = RequestState(
                request_id=request_id,
                raw_prompt=prompt,
                model_names=[worker.model_name for worker in workers],
                model_preference=model_preference,
            )
        request_state = self.request_states[request_id]
        model_name = request_state.model_names[model_index]

        try:
            worker = self.worker_service.get_worker(model_name)
        except KeyError:
            controller_logger.error("Worker %s not found.", model_name)
            raise
        except RuntimeError:
            controller_logger.error("Worker %s is dead.", model_name)
            raise

        # Models have different prompt formatting requirements and stopping criteria.
        prompt, stop_str, stop_token_ids = apply_model_characteristics(
            prompt=prompt,
            model_name=worker.model_id,
        )
        request_state.model_prompts[model_index] = prompt

        # Request the model worker to stream the response to the user's prompt.
        response = ""
        energy = 0.0
        client = worker.get_client()
        buffer = TokenGenerationBuffer(stop_str=stop_str)
        try:
            async for resp in client.generate_stream(
                prompt=prompt,
                stop_sequences=[stop_str] if stop_str is not None else None,
                **self.generation_config.dict(),
            ):
                # Even special tokens consume energy when they're generated.
                energy += resp.token.energy
                # Stop tokens usually don't overlap with (human-readable) stop sequences.
                if resp.token.id in stop_token_ids:
                    # Text held back as a partial stop_str match is real output.
                    if (chunk := buffer.token_buffer):
                        response += chunk
                        yield self._encode_chunk(chunk)
                    break
                # Skip special tokens.
                if resp.token.special:
                    continue
                # The buffer automatically handles `stop_str` partial and full matches.
                buffer.append(resp.token.text)
                if (chunk := buffer.pop()) is not None:
                    response += chunk
                    yield self._encode_chunk(chunk)
                elif buffer.matched_stop_str:
                    break
            else:
                # The stream ended without a stop token or a full stop_str match
                # (e.g. the token limit was reached). Flush any text still held
                # back as a partial stop_str match instead of silently dropping it.
                if not buffer.matched_stop_str and (chunk := buffer.token_buffer):
                    response += chunk
                    yield self._encode_chunk(chunk)
        except aiohttp.ClientConnectorError:
            worker.status = "down"
            controller_logger.error(
                "Problem talking to %s. Aborting and setting worker status to down",
                repr(worker),
            )
            raise
        except Exception:
            # Hand any held-back text to the user before propagating, and keep
            # the logged response consistent with what was actually sent.
            if (chunk := buffer.token_buffer):
                response += chunk
                yield self._encode_chunk(chunk)
            raise
        finally:
            request_state.set_response_and_energy(model_index, response, energy)
            request_logger.info(request_state.json())
# Process-wide controller; set by `init_global_controller`, read via `get_global_controller`.
CONTROLLER: Optional[Controller] = None
def init_global_controller(config: ControllerConfig) -> None:
    """Build the process-wide `Controller` from `config` and store it in `CONTROLLER`.

    The new controller is fully constructed before being installed. If a
    controller was already installed, it is then shut down, so its background
    task does not keep running alongside the new one's.
    """
    global CONTROLLER
    new_controller = Controller(
        background_task_interval=config.background_task_interval,
        max_num_req_states=config.max_num_req_states,
        req_state_expiration_time=config.req_state_expiration_time,
        worker_service=WorkerService(config.compose_files),
        generation_config=GenerationConfig(
            max_new_tokens=config.max_new_tokens,
            do_sample=config.do_sample,
            temperature=config.temperature,
            repetition_penalty=config.repetition_penalty,
            top_k=config.top_k,
            top_p=config.top_p,
        ),
    )
    if CONTROLLER is not None:
        CONTROLLER.shutdown()
    CONTROLLER = new_controller
def get_global_controller() -> Controller:
    """Return the process-wide controller.

    Raises:
        RuntimeError: If `init_global_controller` has not been called yet.
            An explicit raise is used because an `assert` would be stripped
            under `python -O`, letting `None` escape to callers.
    """
    if CONTROLLER is None:
        raise RuntimeError("Global controller is not initialized; call init_global_controller first.")
    return CONTROLLER
|