""" | |
Translates from OpenAI's `/v1/chat/completions` endpoint to Triton's `/generate` endpoint. | |
""" | |
import json
from typing import Any, AsyncIterator, Dict, Iterator, List, Literal, Optional, Union

from httpx import Headers, Response

from litellm.constants import DEFAULT_MAX_TOKENS_FOR_TRITON
from litellm.litellm_core_utils.prompt_templates.factory import prompt_factory
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
from litellm.llms.base_llm.chat.transformation import (
    BaseConfig,
    BaseLLMException,
    LiteLLMLoggingObj,
)
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import (
    ChatCompletionToolCallChunk,
    ChatCompletionUsageBlock,
    Choices,
    GenericStreamingChunk,
    Message,
    ModelResponse,
)

from ..common_utils import TritonError


class TritonConfig(BaseConfig):
    """
    Base class for Triton configurations.

    Handles routing between the /infer and /generate Triton completion endpoints.
    """

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[Dict, Headers]
    ) -> BaseLLMException:
        return TritonError(
            status_code=status_code, message=error_message, headers=headers
        )

    def validate_environment(
        self,
        headers: Dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: Dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> Dict:
        return {"Content-Type": "application/json"}

    def get_supported_openai_params(self, model: str) -> List:
        return ["max_tokens", "max_completion_tokens"]

    def map_openai_params(
        self,
        non_default_params: Dict,
        optional_params: Dict,
        model: str,
        drop_params: bool,
    ) -> Dict:
        for param, value in non_default_params.items():
            if param == "max_tokens" or param == "max_completion_tokens":
                optional_params[param] = value
        return optional_params
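
    # Example (illustrative): only the max-token parameters are copied into
    # optional_params by this method; other OpenAI params pass through untouched.
    #
    #   TritonConfig().map_openai_params(
    #       non_default_params={"max_completion_tokens": 256, "temperature": 0.2},
    #       optional_params={},
    #       model="llama-3-8b",  # assumed model name
    #       drop_params=False,
    #   )
    #   # -> {"max_completion_tokens": 256}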

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        if api_base is None:
            raise ValueError("api_base is required")
        llm_type = self._get_triton_llm_type(api_base)
        if llm_type == "generate" and stream:
            return api_base + "_stream"
        return api_base
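
    # URL routing sketch (host and model name are assumptions):
    #
    #   .../v2/models/llama/generate + stream=False -> .../v2/models/llama/generate
    #   .../v2/models/llama/generate + stream=True  -> .../v2/models/llama/generate_stream
    #   .../v2/models/llama/infer    + stream=True  -> .../v2/models/llama/infer (unchanged)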

    def transform_response(
        self,
        model: str,
        raw_response: Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: Dict,
        messages: List[AllMessageValues],
        optional_params: Dict,
        litellm_params: Dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        api_base = litellm_params.get("api_base", "")
        llm_type = self._get_triton_llm_type(api_base)
        if llm_type == "generate":
            return TritonGenerateConfig().transform_response(
                model=model,
                raw_response=raw_response,
                model_response=model_response,
                logging_obj=logging_obj,
                request_data=request_data,
                messages=messages,
                optional_params=optional_params,
                litellm_params=litellm_params,
                encoding=encoding,
                api_key=api_key,
                json_mode=json_mode,
            )
        elif llm_type == "infer":
            return TritonInferConfig().transform_response(
                model=model,
                raw_response=raw_response,
                model_response=model_response,
                logging_obj=logging_obj,
                request_data=request_data,
                messages=messages,
                optional_params=optional_params,
                litellm_params=litellm_params,
                encoding=encoding,
                api_key=api_key,
                json_mode=json_mode,
            )
        return model_response

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        api_base = litellm_params.get("api_base", "")
        llm_type = self._get_triton_llm_type(api_base)
        if llm_type == "generate":
            return TritonGenerateConfig().transform_request(
                model=model,
                messages=messages,
                optional_params=optional_params,
                litellm_params=litellm_params,
                headers=headers,
            )
        elif llm_type == "infer":
            return TritonInferConfig().transform_request(
                model=model,
                messages=messages,
                optional_params=optional_params,
                litellm_params=litellm_params,
                headers=headers,
            )
        return {}

    def _get_triton_llm_type(self, api_base: str) -> Literal["generate", "infer"]:
        if api_base.endswith("/generate"):
            return "generate"
        elif api_base.endswith("/infer"):
            return "infer"
        else:
            raise ValueError(f"Invalid Triton API base: {api_base}")

    def get_model_response_iterator(
        self,
        streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
        sync_stream: bool,
        json_mode: Optional[bool] = False,
    ) -> Any:
        return TritonResponseIterator(
            streaming_response=streaming_response,
            sync_stream=sync_stream,
            json_mode=json_mode,
        )


class TritonGenerateConfig(TritonConfig):
    """
    Transformations for the Triton /generate endpoint (a TensorRT-LLM model).
    """

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        inference_params = optional_params.copy()
        stream = inference_params.pop("stream", False)
        data_for_triton: Dict[str, Any] = {
            "text_input": prompt_factory(model=model, messages=messages),
            "parameters": {
                "max_tokens": int(
                    optional_params.get("max_tokens", DEFAULT_MAX_TOKENS_FOR_TRITON)
                ),
            },
            "stream": bool(stream),
        }
        data_for_triton["parameters"].update(inference_params)
        return data_for_triton
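
    # Example payload produced by transform_request above (illustrative; the prompt
    # string depends on prompt_factory and max_tokens falls back to
    # DEFAULT_MAX_TOKENS_FOR_TRITON when the caller does not set it):
    #
    #   {
    #       "text_input": "<prompt built by prompt_factory>",
    #       "parameters": {"max_tokens": DEFAULT_MAX_TOKENS_FOR_TRITON},
    #       "stream": False,
    #   }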

    def transform_response(
        self,
        model: str,
        raw_response: Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: Dict,
        messages: List[AllMessageValues],
        optional_params: Dict,
        litellm_params: Dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        try:
            raw_response_json = raw_response.json()
        except Exception:
            raise TritonError(
                message=raw_response.text, status_code=raw_response.status_code
            )
        model_response.choices = [
            Choices(index=0, message=Message(content=raw_response_json["text_output"]))
        ]
        return model_response
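
    # The /generate endpoint returns JSON with a "text_output" field, roughly
    # (illustrative shape, extra fields omitted):
    #
    #   {"model_name": "llama", "text_output": "Hello! How can I help?"}
    #
    # Only "text_output" is read; it becomes the assistant message content.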


class TritonInferConfig(TritonConfig):
    """
    Transformations for the Triton /infer endpoint (a custom model served through Triton's infer API).
    """

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        text_input = messages[0].get("content", "")
        data_for_triton = {
            "inputs": [
                {
                    "name": "text_input",
                    "shape": [1],
                    "datatype": "BYTES",
                    "data": [text_input],
                }
            ]
        }

        for k, v in optional_params.items():
            if not (k == "stream" or k == "max_retries"):
                datatype = "INT32" if isinstance(v, int) else "BYTES"
                datatype = "FP32" if isinstance(v, float) else datatype
                data_for_triton["inputs"].append(
                    {"name": k, "shape": [1], "datatype": datatype, "data": [v]}
                )

        if "max_tokens" not in optional_params:
            data_for_triton["inputs"].append(
                {
                    "name": "max_tokens",
                    "shape": [1],
                    "datatype": "INT32",
                    "data": [20],
                }
            )
        return data_for_triton
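
    # Example /infer payload produced above (illustrative; only the first message's
    # content is used as "text_input", and max_tokens defaults to 20 if unset):
    #
    #   {
    #       "inputs": [
    #           {"name": "text_input", "shape": [1], "datatype": "BYTES", "data": ["Hello"]},
    #           {"name": "max_tokens", "shape": [1], "datatype": "INT32", "data": [20]},
    #       ]
    #   }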

    def transform_response(
        self,
        model: str,
        raw_response: Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: Dict,
        messages: List[AllMessageValues],
        optional_params: Dict,
        litellm_params: Dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        try:
            raw_response_json = raw_response.json()
        except Exception:
            raise TritonError(
                message=raw_response.text, status_code=raw_response.status_code
            )

        _triton_response_data = raw_response_json["outputs"][0]["data"]
        triton_response_data: Optional[str] = None
        if isinstance(_triton_response_data, list):
            triton_response_data = "".join(_triton_response_data)
        else:
            triton_response_data = _triton_response_data

        model_response.choices = [
            Choices(
                index=0,
                message=Message(content=triton_response_data),
            )
        ]
        return model_response
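
    # Expected /infer response shape (illustrative): the first output tensor's "data"
    # list is concatenated into the assistant message content.
    #
    #   {
    #       "outputs": [
    #           {"name": "text_output", "datatype": "BYTES", "shape": [1], "data": ["Hello!"]}
    #       ]
    #   }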


class TritonResponseIterator(BaseModelResponseIterator):
    def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
        try:
            text = ""
            tool_use: Optional[ChatCompletionToolCallChunk] = None
            is_finished = False
            finish_reason = ""
            usage: Optional[ChatCompletionUsageBlock] = None
            provider_specific_fields = None
            index = int(chunk.get("index", 0))

            # set values
            text = chunk.get("text_output", "")
            finish_reason = chunk.get("stop_reason", "")
            is_finished = chunk.get("is_finished", False)

            return GenericStreamingChunk(
                text=text,
                tool_use=tool_use,
                is_finished=is_finished,
                finish_reason=finish_reason,
                usage=usage,
                index=index,
                provider_specific_fields=provider_specific_fields,
            )
        except json.JSONDecodeError:
            raise ValueError(f"Failed to decode JSON from chunk: {chunk}")