Spaces:
Running
Running
from __future__ import annotations | |
import json | |
import unittest | |
import contextlib | |
from uuid import uuid4, UUID | |
from copy import deepcopy | |
from typing import Generator, Literal | |
import requests | |
import gradio as gr | |
from spitfight.colosseum.common import ( | |
COLOSSEUM_PROMPT_ROUTE, | |
COLOSSEUM_RESP_VOTE_ROUTE, | |
COLOSSEUM_ENERGY_VOTE_ROUTE, | |
PromptRequest, | |
ResponseVoteRequest, | |
ResponseVoteResponse, | |
EnergyVoteRequest, | |
EnergyVoteResponse, | |
) | |
class ControllerClient:
    """Client for the Colosseum controller, to be used by Gradio.

    Each instance carries one request ID that is attached to every request it
    sends (prompt, response vote, energy vote). Use `fork` to obtain a client
    for the same controller with a fresh request ID.
    """

    def __init__(self, controller_addr: str, timeout: int = 15, request_id: UUID | None = None) -> None:
        """Initialize the controller client.

        Args:
            controller_addr: Address of the controller without a scheme;
                `http://` is prepended when building request URLs.
            timeout: Timeout in seconds handed to every `requests.post` call.
            request_id: ID to attach to requests. A random UUID4 is generated
                when omitted.
        """
        self.controller_addr = controller_addr
        self.timeout = timeout
        # `str(None)` is the truthy string "None", so the fallback must be
        # chosen *before* converting to `str`; otherwise every client created
        # without an explicit ID would share the request ID "None".
        self.request_id = str(uuid4() if request_id is None else request_id)

    def fork(self) -> ControllerClient:
        """Return a copy of the client with a new request ID."""
        return ControllerClient(
            controller_addr=self.controller_addr,
            timeout=self.timeout,
            request_id=uuid4(),
        )

    def prompt(self, prompt: str, index: Literal[0, 1]) -> Generator[str, None, None]:
        """Generate the response of the `index`th model with the prompt.

        Args:
            prompt: Prompt text sent to the controller.
            index: Which of the two models should answer.

        Yields:
            The JSON-decoded value of every non-empty NUL-delimited chunk of
            the streamed HTTP response.

        Raises:
            gr.Error: Raised through `_catch_requests_exceptions` and
                `_check_response` on connection failures, timeouts, and
                HTTP error statuses.
        """
        prompt_request = PromptRequest(request_id=self.request_id, prompt=prompt, model_index=index)
        with _catch_requests_exceptions():
            resp = requests.post(
                f"http://{self.controller_addr}{COLOSSEUM_PROMPT_ROUTE}",
                json=prompt_request.dict(),
                stream=True,
                timeout=self.timeout,
            )
        # With `stream=True` the connection stays checked out until the body
        # is consumed or the response is closed. Closing it here also covers
        # the caller abandoning this generator early and error statuses.
        with resp:
            _check_response(resp)
            # XXX: Why can't the server just yield `text + "\n"` and here we just iter_lines?
            for chunk in resp.iter_lines(decode_unicode=False, delimiter=b"\0"):
                if chunk:
                    yield json.loads(chunk.decode("utf-8"))

    def response_vote(self, victory_index: Literal[0, 1]) -> ResponseVoteResponse:
        """Notify the controller of the user's vote for the response.

        Args:
            victory_index: Index of the model whose response the user preferred.

        Returns:
            The controller's reply parsed into a `ResponseVoteResponse`.

        Raises:
            gr.Error: On connection failures, timeouts, and HTTP error statuses.
        """
        response_vote_request = ResponseVoteRequest(request_id=self.request_id, victory_index=victory_index)
        with _catch_requests_exceptions():
            resp = requests.post(
                f"http://{self.controller_addr}{COLOSSEUM_RESP_VOTE_ROUTE}",
                json=response_vote_request.dict(),
                # Without a timeout `requests` may block forever on a stalled server.
                timeout=self.timeout,
            )
        _check_response(resp)
        return ResponseVoteResponse(**resp.json())

    def energy_vote(self, is_worth: bool) -> EnergyVoteResponse:
        """Notify the controller of the user's vote for energy.

        Args:
            is_worth: The user's yes/no energy vote, forwarded as-is.

        Returns:
            The controller's reply parsed into an `EnergyVoteResponse`.

        Raises:
            gr.Error: On connection failures, timeouts, and HTTP error statuses.
        """
        energy_vote_request = EnergyVoteRequest(request_id=self.request_id, is_worth=is_worth)
        with _catch_requests_exceptions():
            resp = requests.post(
                f"http://{self.controller_addr}{COLOSSEUM_ENERGY_VOTE_ROUTE}",
                json=energy_vote_request.dict(),
                # Without a timeout `requests` may block forever on a stalled server.
                timeout=self.timeout,
            )
        _check_response(resp)
        return EnergyVoteResponse(**resp.json())
def _catch_requests_exceptions(): | |
"""Catch requests exceptions and raise gr.Error instead.""" | |
try: | |
yield | |
except (requests.exceptions.ConnectionError, requests.exceptions.Timeout): | |
raise gr.Error("Failed to connect to our the backend server. Please try again later.") | |
def _check_response(response: requests.Response) -> None: | |
if 400 <= response.status_code < 500: | |
raise gr.Error(response.json()["detail"]) | |
elif response.status_code >= 500: | |
raise gr.Error("Failed to talk to our backend server. Please try again later.") | |
class TestControllerClient(unittest.TestCase):
    """Unit tests for `ControllerClient` that do not need a running controller."""

    def test_new_uuid_on_deepcopy(self):
        """Forked clients must each carry their own request ID.

        NOTE(review): the method name mentions deepcopy, but the check goes
        through `fork()`; the name is presumably historical and is kept so
        that existing test selectors keep working.
        """
        # The address is given without a scheme because the client prepends
        # "http://" itself when building URLs. No request is sent here.
        client = ControllerClient("localhost:8000")
        forks = [client.fork() for _ in range(50)]
        request_ids = [client.request_id, *(fork.request_id for fork in forks)]
        # `assertEqual` keeps working under `python -O`, unlike a bare `assert`.
        self.assertEqual(len(set(request_ids)), len(request_ids))
# Run the unit tests defined in this module when it is executed directly as a script.
if __name__ == "__main__":
    unittest.main()