File size: 5,036 Bytes
8ff63e4
 
 
 
 
 
 
 
 
 
 
 
7ed0b8b
8ff63e4
 
 
 
7ed0b8b
8ff63e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7ed0b8b
 
 
 
8ff63e4
 
 
 
 
7ed0b8b
 
 
 
 
 
8ff63e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import os
import json

import uvicorn
from pydantic import BaseSettings
from fastapi import FastAPI, Depends
from fastapi.responses import StreamingResponse
from fastapi.exceptions import HTTPException
from text_generation.errors import OverloadedError, UnknownError, ValidationError

from spitfight.log import get_logger, init_queued_root_logger, shutdown_queued_root_loggers
from spitfight.colosseum.common import (
    COLOSSEUM_MODELS_ROUTE,
    COLOSSEUM_PROMPT_ROUTE,
    COLOSSEUM_RESP_VOTE_ROUTE,
    COLOSSEUM_ENERGY_VOTE_ROUTE,
    COLOSSEUM_HEALTH_ROUTE,
    ModelsResponse,
    PromptRequest,
    ResponseVoteRequest,
    ResponseVoteResponse,
    EnergyVoteRequest,
    EnergyVoteResponse,
)
from spitfight.colosseum.controller.controller import (
    Controller,
    init_global_controller,
    get_global_controller,
)
from spitfight.utils import prepend_generator


class ControllerConfig(BaseSettings):
    """Controller settings automatically loaded from environment variables.

    The values written below are only defaults; ``BaseSettings`` lets each
    field be overridden from the environment. The instance is handed as a
    whole to ``init_global_controller`` at startup, so most fields are
    interpreted by the Controller (defined elsewhere), not by this module.

    NOTE(review): ``BaseSettings`` is imported directly from ``pydantic``,
    which only works with pydantic v1 (v2 moved it to ``pydantic-settings``).
    Env-var naming (prefix / case sensitivity) follows pydantic's defaults
    since no ``Config`` is customized here — confirm against the deployment.
    """
    # Controller
    # presumably seconds between background maintenance runs — TODO confirm units in Controller
    background_task_interval: int = 300
    # presumably a cap on the number of tracked per-request states — verify in Controller
    max_num_req_states: int = 10000
    # presumably seconds after which a request state expires; the vote
    # endpoints answer 410 "timeout expired" when the controller no longer
    # has the state — confirm units in Controller
    req_state_expiration_time: int = 600
    # Docker Compose files describing the model deployments (assumed to be
    # read by the Controller to discover models — verify)
    compose_files: list[str] = ["deployment/docker-compose-0.yaml", "deployment/docker-compose-1.yaml"]

    # Logging
    # Directory joined with each *_log_file name below in startup_event.
    log_dir: str = "/logs"
    # Target of the "spitfight.colosseum.controller" logger.
    controller_log_file: str = "controller.log"
    # Target of the "colosseum_requests" logger.
    request_log_file: str = "requests.log"
    # Target of the "uvicorn" logger.
    uvicorn_log_file: str = "uvicorn.log"

    # Generation
    # presumably forwarded by the Controller as TGI sampling parameters — confirm
    max_new_tokens: int = 512
    do_sample: bool = True
    temperature: float = 1.0
    repetition_penalty: float = 1.0
    top_k: int = 50
    top_p: float = 0.95


# ASGI application; every route handler below registers itself on it via decorators.
app = FastAPI()
# Built once at import time from environment variables (see ControllerConfig);
# the startup hook reads it for log paths and to construct the controller.
settings = ControllerConfig()
# Module logger used by the route handlers to record failed requests.
logger = get_logger("spitfight.colosseum.controller.router")

@app.on_event("startup")
async def startup_event():
    """Attach file-backed queued loggers, then create the global controller.

    Each logger writes to ``settings.log_dir`` joined with its configured
    file name. The controller is created last, presumably so that anything
    it logs during construction already has a destination.
    """
    logger_targets = (
        ("uvicorn", settings.uvicorn_log_file),
        ("spitfight.colosseum.controller", settings.controller_log_file),
        ("colosseum_requests", settings.request_log_file),
    )
    for logger_name, file_name in logger_targets:
        log_path = os.path.join(settings.log_dir, file_name)
        init_queued_root_logger(logger_name, log_path)

    init_global_controller(settings)

@app.on_event("shutdown")
async def shutdown_event():
    """Shut down the global controller, then the queued root loggers.

    ``shutdown_queued_root_loggers()`` runs in a ``finally`` block. If
    fetching or shutting down the controller raises (possibly because
    startup never completed), the logging machinery is still torn down.
    The original exception then continues to propagate.
    """
    try:
        get_global_controller().shutdown()
    finally:
        # Always release the queued loggers, even when controller shutdown
        # failed; skipping this could lose buffered log records.
        shutdown_queued_root_loggers()

@app.get(COLOSSEUM_MODELS_ROUTE, response_model=ModelsResponse)
async def models(controller: Controller = Depends(get_global_controller)):
    """Return the model list reported by ``controller.get_available_models()``."""
    model_list = controller.get_available_models()
    return ModelsResponse(available_models=model_list)

@app.post(COLOSSEUM_PROMPT_ROUTE)
async def prompt(
    request: PromptRequest,
    controller: Controller = Depends(get_global_controller),
):
    """Stream a model's response to a Colosseum prompt.

    The controller's async generator is advanced once before any response is
    returned. TGI errors raised while producing the first token can then
    still be mapped to a proper HTTP status; once a ``StreamingResponse`` has
    started sending, the status code can no longer change.

    Returns:
        StreamingResponse: the first token followed by the rest of the
        stream. If the generator yields nothing, the response instead holds a
        single JSON-encoded placeholder message terminated by a NUL byte.

    Raises:
        HTTPException: 429 when the model is overloaded, 422 on a TGI
            validation error, 500 on an unknown TGI error.
    """
    generator = controller.prompt(
        request.request_id,
        request.prompt,
        request.model_index,
        request.model_preference,
    )

    # First try to get the first token in order to catch TGI errors.
    try:
        first_token = await generator.__anext__()
    except OverloadedError as e:
        # The model name is only needed for the log line. A missing/expired
        # request state must not turn this 429 into an unrelated 500.
        try:
            name = controller.request_states[request.request_id].model_names[request.model_index]
        except (KeyError, IndexError):
            name = "<unknown>"
        logger.warning("Model %s is overloaded. Failed request: %s", name, repr(request))
        raise HTTPException(status_code=429, detail="Model overloaded. Please try again later.") from e
    except ValidationError as e:
        logger.info("TGI returned validation error: %s. Failed request: %s", str(e), repr(request))
        raise HTTPException(status_code=422, detail=str(e)) from e
    except StopAsyncIteration:
        logger.info("TGI returned empty response. Failed request: %s", repr(request))
        # Same wire format as normal chunks: JSON-encoded text followed by NUL.
        return StreamingResponse(
            iter([json.dumps("*The model generated an empty response.*").encode() + b"\0"]),
        )
    except UnknownError as e:
        logger.error("TGI returned unknown error: %s. Failed request: %s", str(e), repr(request))
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Re-attach the already-consumed first token in front of the remaining stream.
    return StreamingResponse(prepend_generator(first_token, generator))

@app.post(COLOSSEUM_RESP_VOTE_ROUTE, response_model=ResponseVoteResponse)
async def response_vote(
    request: ResponseVoteRequest,
    controller: Controller = Depends(get_global_controller),
):
    """Forward a response vote (``victory_index``) to the controller.

    On success, returns the energy consumptions and model names from the
    state the controller hands back.

    Raises:
        HTTPException: 410 when the controller returns ``None`` for this
            request ID, i.e. the battle session is no longer available.
    """
    state = controller.response_vote(request.request_id, request.victory_index)
    if state is not None:
        return ResponseVoteResponse(
            energy_consumptions=state.energy_consumptions,
            model_names=state.model_names,
        )
    raise HTTPException(status_code=410, detail="Colosseum battle session timeout expired.")

@app.post(COLOSSEUM_ENERGY_VOTE_ROUTE, response_model=EnergyVoteResponse)
async def energy_vote(
    request: EnergyVoteRequest,
    controller: Controller = Depends(get_global_controller),
):
    """Forward an energy vote (``is_worth``) to the controller.

    On success, returns the model names from the state the controller
    hands back.

    Raises:
        HTTPException: 410 when the controller returns ``None`` for this
            request ID, i.e. the battle session is no longer available.
    """
    state = controller.energy_vote(request.request_id, request.is_worth)
    if state is not None:
        return EnergyVoteResponse(model_names=state.model_names)
    raise HTTPException(status_code=410, detail="Colosseum battle session timeout expired.")

@app.get(COLOSSEUM_HEALTH_ROUTE)
async def health():
    """Health-check endpoint; unconditionally answers ``"OK"``."""
    status = "OK"
    return status


if __name__ == "__main__":
    # Listen on all interfaces. log_config=None tells uvicorn not to install
    # its own logging configuration, so the queued loggers configured in
    # startup_event stay in charge.
    uvicorn.run(
        app,
        host="0.0.0.0",
        log_config=None,
    )