Jae-Won Chung
Reset to random
3c356de
from __future__ import annotations
import json
import contextlib
from uuid import uuid4, UUID
from typing import Generator, Literal
import requests
import gradio as gr
from spitfight.colosseum.common import (
COLOSSEUM_MODELS_ROUTE,
COLOSSEUM_PROMPT_ROUTE,
COLOSSEUM_RESP_VOTE_ROUTE,
COLOSSEUM_ENERGY_VOTE_ROUTE,
ModelsResponse,
PromptRequest,
ResponseVoteRequest,
ResponseVoteResponse,
EnergyVoteRequest,
EnergyVoteResponse,
)
class ControllerClient:
"""Client for the Colosseum controller, to be used by Gradio."""
def __init__(self, controller_addr: str, timeout: int = 15, request_id: UUID | None = None) -> None:
"""Initialize the controller client."""
self.controller_addr = controller_addr
self.timeout = timeout
self.request_id = str(uuid4()) if request_id is None else str(request_id)
def fork(self) -> ControllerClient:
"""Return a copy of the client with a new request ID."""
return ControllerClient(
controller_addr=self.controller_addr,
timeout=self.timeout,
request_id=uuid4(),
)
def get_available_models(self) -> list[str]:
"""Retrieve the list of available models."""
with _catch_requests_exceptions():
resp = requests.get(
f"http://{self.controller_addr}{COLOSSEUM_MODELS_ROUTE}",
timeout=self.timeout,
)
_check_response(resp)
return ModelsResponse(**resp.json()).available_models
def prompt(
self,
prompt: str,
index: Literal[0, 1],
model_preference: str,
) -> Generator[str, None, None]:
"""Generate the response of the `index`th model with the prompt.
`user_pref` is the user's preference for the model to use. It can be
`"Random"` or one of the models in the list returned by `get_available_models`.
"""
prompt_request = PromptRequest(
request_id=self.request_id,
prompt=prompt,
model_index=index,
model_preference=model_preference,
)
with _catch_requests_exceptions():
resp = requests.post(
f"http://{self.controller_addr}{COLOSSEUM_PROMPT_ROUTE}",
json=prompt_request.dict(),
stream=True,
timeout=self.timeout,
)
_check_response(resp)
# XXX: Why can't the server just yield `text + "\n"` and here we just iter_lines?
for chunk in resp.iter_lines(decode_unicode=False, delimiter=b"\0"):
if chunk:
yield json.loads(chunk.decode("utf-8"))
def response_vote(self, victory_index: Literal[0, 1]) -> ResponseVoteResponse:
"""Notify the controller of the user's vote for the response."""
response_vote_request = ResponseVoteRequest(request_id=self.request_id, victory_index=victory_index)
with _catch_requests_exceptions():
resp = requests.post(
f"http://{self.controller_addr}{COLOSSEUM_RESP_VOTE_ROUTE}",
json=response_vote_request.dict(),
)
_check_response(resp)
return ResponseVoteResponse(**resp.json())
def energy_vote(self, is_worth: bool) -> EnergyVoteResponse:
"""Notify the controller of the user's vote for energy."""
energy_vote_request = EnergyVoteRequest(request_id=self.request_id, is_worth=is_worth)
with _catch_requests_exceptions():
resp = requests.post(
f"http://{self.controller_addr}{COLOSSEUM_ENERGY_VOTE_ROUTE}",
json=energy_vote_request.dict(),
)
_check_response(resp)
return EnergyVoteResponse(**resp.json())
@contextlib.contextmanager
def _catch_requests_exceptions():
"""Catch requests exceptions and raise gr.Error instead."""
try:
yield
except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
raise gr.Error("Failed to connect to our the backend server. Please try again later.")
def _check_response(response: requests.Response) -> None:
if 400 <= response.status_code < 500:
raise gr.Error(response.json()["detail"])
elif response.status_code >= 500:
raise gr.Error("Failed to talk to our backend server. Please try again later.")