Spaces:
Running
Running
from __future__ import annotations | |
import json | |
import contextlib | |
from uuid import uuid4, UUID | |
from typing import Generator, Literal | |
import requests | |
import gradio as gr | |
from spitfight.colosseum.common import ( | |
COLOSSEUM_MODELS_ROUTE, | |
COLOSSEUM_PROMPT_ROUTE, | |
COLOSSEUM_RESP_VOTE_ROUTE, | |
COLOSSEUM_ENERGY_VOTE_ROUTE, | |
ModelsResponse, | |
PromptRequest, | |
ResponseVoteRequest, | |
ResponseVoteResponse, | |
EnergyVoteRequest, | |
EnergyVoteResponse, | |
) | |
class ControllerClient:
    """Client for the Colosseum controller, to be used by Gradio.

    The same `request_id` is attached to the prompt request and to both vote
    requests sent by one client instance. This is presumably so the controller
    can associate them with one comparison round (confirm against the
    controller).
    """

    def __init__(self, controller_addr: str, timeout: int = 15, request_id: UUID | None = None) -> None:
        """Initialize the controller client.

        Args:
            controller_addr: Address of the controller, i.e. everything between
                `http://` and the route in the request URLs built below.
            timeout: Timeout (seconds) passed to every `requests` call.
            request_id: ID attached to every request. A fresh random UUID4 is
                generated when omitted.
        """
        self.controller_addr = controller_addr
        self.timeout = timeout
        self.request_id = str(uuid4()) if request_id is None else str(request_id)

    def fork(self) -> ControllerClient:
        """Return a copy of the client with a new request ID.

        The controller address and timeout are carried over unchanged.
        """
        return ControllerClient(
            controller_addr=self.controller_addr,
            timeout=self.timeout,
            request_id=uuid4(),
        )

    def get_available_models(self) -> list[str]:
        """Retrieve the list of available models.

        Raises:
            gr.Error: On connection failure, timeout, or an error status.
        """
        with _catch_requests_exceptions():
            resp = requests.get(
                f"http://{self.controller_addr}{COLOSSEUM_MODELS_ROUTE}",
                timeout=self.timeout,
            )
        _check_response(resp)
        return ModelsResponse(**resp.json()).available_models

    def prompt(
        self,
        prompt: str,
        index: Literal[0, 1],
        model_preference: str,
    ) -> Generator[str, None, None]:
        """Generate the response of the `index`th model with the prompt.

        `model_preference` is the user's preference for the model to use. It can be
        `"Random"` or one of the models in the list returned by `get_available_models`.

        Yields:
            Each non-empty NUL-delimited chunk of the streamed response body,
            decoded as UTF-8 and then parsed with `json.loads`.

        Raises:
            gr.Error: If the controller returns an error status. Connection
                errors and timeouts raised while sending the request or while
                reading the stream are also converted to `gr.Error`.
        """
        prompt_request = PromptRequest(
            request_id=self.request_id,
            prompt=prompt,
            model_index=index,
            model_preference=model_preference,
        )
        # The whole exchange, including reading the stream, runs inside the
        # wrapper so that mid-stream connection failures also become gr.Error.
        with _catch_requests_exceptions():
            resp = requests.post(
                f"http://{self.controller_addr}{COLOSSEUM_PROMPT_ROUTE}",
                json=prompt_request.dict(),
                stream=True,
                timeout=self.timeout,
            )
            # With `stream=True` the connection stays checked out until the body
            # is consumed or the response is closed. `with resp` releases it even
            # when the caller stops iterating early or `_check_response` raises.
            with resp:
                _check_response(resp)
                # XXX: Why can't the server just yield `text + "\n"` and here we just iter_lines?
                for chunk in resp.iter_lines(decode_unicode=False, delimiter=b"\0"):
                    if chunk:
                        yield json.loads(chunk.decode("utf-8"))

    def response_vote(self, victory_index: Literal[0, 1]) -> ResponseVoteResponse:
        """Notify the controller of the user's vote for the response.

        Args:
            victory_index: 0 or 1, identifying the response the user voted for.

        Raises:
            gr.Error: On connection failure, timeout, or an error status.
        """
        response_vote_request = ResponseVoteRequest(request_id=self.request_id, victory_index=victory_index)
        with _catch_requests_exceptions():
            resp = requests.post(
                f"http://{self.controller_addr}{COLOSSEUM_RESP_VOTE_ROUTE}",
                json=response_vote_request.dict(),
                # Without an explicit timeout, `requests` can wait indefinitely.
                timeout=self.timeout,
            )
        _check_response(resp)
        return ResponseVoteResponse(**resp.json())

    def energy_vote(self, is_worth: bool) -> EnergyVoteResponse:
        """Notify the controller of the user's vote for energy.

        Args:
            is_worth: The user's boolean energy vote, sent as `is_worth`.

        Raises:
            gr.Error: On connection failure, timeout, or an error status.
        """
        energy_vote_request = EnergyVoteRequest(request_id=self.request_id, is_worth=is_worth)
        with _catch_requests_exceptions():
            resp = requests.post(
                f"http://{self.controller_addr}{COLOSSEUM_ENERGY_VOTE_ROUTE}",
                json=energy_vote_request.dict(),
                # Without an explicit timeout, `requests` can wait indefinitely.
                timeout=self.timeout,
            )
        _check_response(resp)
        return EnergyVoteResponse(**resp.json())
@contextlib.contextmanager
def _catch_requests_exceptions() -> Generator[None, None, None]:
    """Catch requests exceptions and raise gr.Error instead.

    Used as `with _catch_requests_exceptions(): ...`. The decorator is
    required; without it this is a bare generator, and `with` rejects it.

    Raises:
        gr.Error: When the `with` body raises a `requests` ConnectionError or
            Timeout. The original exception is chained as the cause. Any other
            exception propagates unchanged.
    """
    try:
        yield
    except (requests.exceptions.ConnectionError, requests.exceptions.Timeout) as e:
        raise gr.Error("Failed to connect to our backend server. Please try again later.") from e
def _check_response(response: requests.Response) -> None:
    """Raise `gr.Error` if the controller responded with an error status.

    Responses with status below 400 pass through silently.

    Args:
        response: Response returned by a `requests` call to the controller.

    Raises:
        gr.Error: For a 4xx status, carrying the body's `detail` field when
            the body is JSON with that field, otherwise a generic message with
            the status code. For a 5xx status, a generic message.
    """
    status = response.status_code
    if 400 <= status < 500:
        try:
            detail = response.json()["detail"]
        except (ValueError, KeyError, TypeError):
            # Body is not JSON (e.g. an HTML error page from a proxy), is not a
            # JSON object, or lacks `detail`. Show a readable message instead of
            # leaking a decoding error to the user.
            detail = f"Request rejected by the backend server (HTTP {status})."
        raise gr.Error(str(detail))
    if status >= 500:
        raise gr.Error("Failed to talk to our backend server. Please try again later.")