File size: 2,683 Bytes
7db0ae4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
# This file runs a health check for the LLM, used on litellm/proxy

import asyncio
import random
from typing import Optional

import litellm
import logging
from litellm._logging import print_verbose


logger = logging.getLogger(__name__)


ILLEGAL_DISPLAY_PARAMS = ["messages", "api_key", "prompt", "input"]


def _get_random_llm_message():
    """
    Get a random message from the LLM.
    """
    messages = ["Hey how's it going?", "What's 1 + 1?"]

    return [{"role": "user", "content": random.choice(messages)}]


def _clean_litellm_params(litellm_params: dict):
    """
    Clean the litellm params for display to users.
    """
    return {k: v for k, v in litellm_params.items() if k not in ILLEGAL_DISPLAY_PARAMS}


async def _perform_health_check(model_list: list):
    """
    Perform a health check for each model in the list.
    """
    tasks = []
    for model in model_list:
        litellm_params = model["litellm_params"]
        model_info = model.get("model_info", {})
        litellm_params["messages"] = _get_random_llm_message()
        mode = model_info.get("mode", None)
        tasks.append(
            litellm.ahealth_check(
                litellm_params,
                mode=mode,
                prompt="test from litellm",
                input=["test from litellm"],
            )
        )

    results = await asyncio.gather(*tasks)

    healthy_endpoints = []
    unhealthy_endpoints = []

    for is_healthy, model in zip(results, model_list):
        cleaned_litellm_params = _clean_litellm_params(model["litellm_params"])

        if isinstance(is_healthy, dict) and "error" not in is_healthy:
            healthy_endpoints.append({**cleaned_litellm_params, **is_healthy})
        elif isinstance(is_healthy, dict):
            unhealthy_endpoints.append({**cleaned_litellm_params, **is_healthy})
        else:
            unhealthy_endpoints.append(cleaned_litellm_params)

    return healthy_endpoints, unhealthy_endpoints


async def perform_health_check(
    model_list: list, model: Optional[str] = None, cli_model: Optional[str] = None
):
    """
    Perform a health check on the system.

    Returns:
        (bool): True if the health check passes, False otherwise.
    """
    if not model_list:
        if cli_model:
            model_list = [
                {"model_name": cli_model, "litellm_params": {"model": cli_model}}
            ]
        else:
            return [], []

    if model is not None:
        model_list = [x for x in model_list if x["litellm_params"]["model"] == model]

    healthy_endpoints, unhealthy_endpoints = await _perform_health_check(model_list)

    return healthy_endpoints, unhealthy_endpoints