File size: 3,997 Bytes
8ff63e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
from __future__ import annotations

import json
import unittest
import contextlib
from uuid import uuid4, UUID
from copy import deepcopy
from typing import Generator, Literal

import requests
import gradio as gr

from spitfight.colosseum.common import (
    COLOSSEUM_PROMPT_ROUTE,
    COLOSSEUM_RESP_VOTE_ROUTE,
    COLOSSEUM_ENERGY_VOTE_ROUTE,
    PromptRequest,
    ResponseVoteRequest,
    ResponseVoteResponse,
    EnergyVoteRequest,
    EnergyVoteResponse,
)


class ControllerClient:
    """Client for the Colosseum controller, to be used by Gradio.

    Every call made through one client sends the same ``request_id`` (with the
    prompt and with both votes). Use :meth:`fork` to get a client that carries
    a fresh request ID.
    """

    def __init__(self, controller_addr: str, timeout: int = 15, request_id: UUID | None = None) -> None:
        """Initialize the controller client.

        Args:
            controller_addr: Controller address without a scheme (e.g. ``host:port``);
                requests are sent to ``http://{controller_addr}{route}``.
            timeout: Timeout passed to every HTTP request made by this client.
            request_id: ID sent with every request. A random UUID4 is generated
                when omitted.
        """
        self.controller_addr = controller_addr
        self.timeout = timeout
        # `str(None)` is the truthy string "None", so `str(request_id) or ...`
        # never fell back to a fresh UUID; test for None explicitly instead.
        self.request_id = str(uuid4() if request_id is None else request_id)

    def fork(self) -> ControllerClient:
        """Return a copy of the client (same address and timeout) with a new request ID."""
        return ControllerClient(
            controller_addr=self.controller_addr,
            timeout=self.timeout,
            request_id=uuid4(),
        )

    def prompt(self, prompt: str, index: Literal[0, 1]) -> Generator[str, None, None]:
        """Stream the response of the `index`th model to `prompt`.

        The response body is split on NUL bytes; each non-empty chunk is decoded
        as UTF-8, parsed as JSON and yielded.

        Raises:
            gr.Error: If the controller cannot be reached, a connection error or
                timeout occurs while streaming, or it returns a 4xx/5xx status.
        """
        prompt_request = PromptRequest(request_id=self.request_id, prompt=prompt, model_index=index)
        with _catch_requests_exceptions():
            resp = requests.post(
                f"http://{self.controller_addr}{COLOSSEUM_PROMPT_ROUTE}",
                json=prompt_request.dict(),
                stream=True,
                timeout=self.timeout,
            )
        # With `stream=True` the connection stays open until the body is read;
        # `with resp` releases it even if the consumer stops iterating early.
        with resp:
            _check_response(resp)
            # XXX: Why can't the server just yield `text + "\n"` and here we just iter_lines?
            # Network failures can also happen while reading the stream, so they
            # are mapped to gr.Error the same way as failures of the initial POST.
            with _catch_requests_exceptions():
                for chunk in resp.iter_lines(decode_unicode=False, delimiter=b"\0"):
                    if chunk:
                        yield json.loads(chunk.decode("utf-8"))

    def response_vote(self, victory_index: Literal[0, 1]) -> ResponseVoteResponse:
        """Notify the controller of the user's vote for the response.

        Returns:
            The controller's JSON reply parsed into a `ResponseVoteResponse`.

        Raises:
            gr.Error: If the controller cannot be reached or returns a 4xx/5xx status.
        """
        response_vote_request = ResponseVoteRequest(request_id=self.request_id, victory_index=victory_index)
        with _catch_requests_exceptions():
            resp = requests.post(
                f"http://{self.controller_addr}{COLOSSEUM_RESP_VOTE_ROUTE}",
                json=response_vote_request.dict(),
                # Without a timeout, requests may wait forever on a stuck server.
                timeout=self.timeout,
            )
        _check_response(resp)
        return ResponseVoteResponse(**resp.json())

    def energy_vote(self, is_worth: bool) -> EnergyVoteResponse:
        """Notify the controller of the user's vote for energy.

        Returns:
            The controller's JSON reply parsed into an `EnergyVoteResponse`.

        Raises:
            gr.Error: If the controller cannot be reached or returns a 4xx/5xx status.
        """
        energy_vote_request = EnergyVoteRequest(request_id=self.request_id, is_worth=is_worth)
        with _catch_requests_exceptions():
            resp = requests.post(
                f"http://{self.controller_addr}{COLOSSEUM_ENERGY_VOTE_ROUTE}",
                json=energy_vote_request.dict(),
                # Without a timeout, requests may wait forever on a stuck server.
                timeout=self.timeout,
            )
        _check_response(resp)
        return EnergyVoteResponse(**resp.json())


@contextlib.contextmanager
def _catch_requests_exceptions():
    """Catch requests exceptions and raise gr.Error instead."""
    try:
        yield
    except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
        raise gr.Error("Failed to connect to our the backend server. Please try again later.")


def _check_response(response: requests.Response) -> None:
    if 400 <= response.status_code < 500:
        raise gr.Error(response.json()["detail"])
    elif response.status_code >= 500:
        raise gr.Error("Failed to talk to our backend server. Please try again later.")


class TestControllerClient(unittest.TestCase):
    """Unit tests for `ControllerClient` that never contact a controller."""

    def test_new_uuid_on_deepcopy(self):
        """`fork` gives every copy a distinct request ID and keeps the connection settings."""
        # The client prepends "http://" itself, so pass a bare host:port.
        client = ControllerClient("localhost:8000", timeout=7)
        forks = [client.fork() for _ in range(50)]
        request_ids = [forked.request_id for forked in forks]
        # unittest assertions, unlike bare `assert`, are not stripped under `python -O`.
        self.assertEqual(len(set(request_ids)), len(request_ids))
        self.assertNotIn(client.request_id, request_ids)
        for forked in forks:
            self.assertEqual(forked.controller_addr, client.controller_addr)
            self.assertEqual(forked.timeout, client.timeout)


if __name__ == "__main__":
    # Running this module directly executes the unittest cases defined above.
    unittest.main()