Spaces:
Sleeping
Sleeping
File size: 4,648 Bytes
469eae6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
from typing import List, Optional, Union
import httpx
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllEmbeddingInputValues, AllMessageValues
from litellm.types.utils import EmbeddingResponse, Usage
from ..common_utils import InfinityError
class InfinityEmbeddingConfig(BaseEmbeddingConfig):
    """
    Request/response translation for Infinity embedding servers.

    Reference: https://infinity.modal.michaelfeil.eu/docs
    """

    # OpenAI embedding param name -> Infinity request field name.
    # Single source of truth for both `get_supported_openai_params` and
    # `map_openai_params`, so the two can never drift apart.
    _OPENAI_TO_INFINITY_PARAMS = {
        "encoding_format": "encoding_format",
        "modality": "modality",
        "dimensions": "output_dimension",
    }

    def __init__(self) -> None:
        pass

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """
        Build the embeddings endpoint URL from `api_base`.

        Trailing slashes are stripped and `/embeddings` is appended unless the
        base already ends with it.

        Raises:
            ValueError: if `api_base` is None or empty (an empty base would
                otherwise produce the relative URL "/embeddings").
        """
        if not api_base:
            raise ValueError("api_base is required for Infinity embeddings")
        base = api_base.rstrip("/")
        if base.endswith("/embeddings"):
            return base
        return f"{base}/embeddings"

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """
        Return the HTTP headers for an Infinity request.

        The API key falls back to the `INFINITY_API_KEY` secret. A bearer
        Authorization header is added only when a non-empty key is available.
        Caller-supplied `headers` override every default, including
        Authorization.
        """
        if api_key is None:
            api_key = get_secret_str("INFINITY_API_KEY")
        default_headers = {
            "accept": "application/json",
            "Content-Type": "application/json",
        }
        # Without this guard an unset key is sent as the literal "Bearer None".
        if api_key:
            default_headers["Authorization"] = f"Bearer {api_key}"
        return {**default_headers, **headers}

    def get_supported_openai_params(self, model: str) -> list:
        """Return the OpenAI embedding params this provider can translate."""
        return list(self._OPENAI_TO_INFINITY_PARAMS)

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """
        Map OpenAI params to Infinity params.

        Copies each supported param present in `non_default_params` into
        `optional_params` under its Infinity name (e.g. `dimensions` becomes
        `output_dimension`). Other params are ignored here. `optional_params`
        is updated in place and also returned.

        Reference: https://infinity.modal.michaelfeil.eu/docs
        """
        for openai_name, infinity_name in self._OPENAI_TO_INFINITY_PARAMS.items():
            if openai_name in non_default_params:
                optional_params[infinity_name] = non_default_params[openai_name]
        return optional_params

    def transform_embedding_request(
        self,
        model: str,
        input: AllEmbeddingInputValues,
        optional_params: dict,
        headers: dict,
    ) -> dict:
        """Build the JSON request body: input, model, then mapped optional params."""
        return {
            "input": input,
            "model": model,
            **optional_params,
        }

    def transform_embedding_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: EmbeddingResponse,
        logging_obj: LiteLLMLoggingObj,
        api_key: Optional[str] = None,
        request_data: Optional[dict] = None,
        optional_params: Optional[dict] = None,
        litellm_params: Optional[dict] = None,
    ) -> EmbeddingResponse:
        """
        Populate `model_response` from an Infinity JSON response.

        Copies `model`, `data`, `object` and the `usage` token counts. If the
        response omits `model`, the requested `model` is used instead. Missing
        or null usage counts default to 0.

        Raises:
            InfinityError: if the body is not JSON or is not a JSON object. The
                error carries the raw response text and HTTP status code.
        """
        try:
            payload = raw_response.json()
        except Exception as e:
            raise InfinityError(
                message=raw_response.text, status_code=raw_response.status_code
            ) from e
        if not isinstance(payload, dict):
            raise InfinityError(
                message=raw_response.text, status_code=raw_response.status_code
            )

        model_response.model = payload.get("model") or model
        model_response.data = payload.get("data")
        model_response.object = payload.get("object")

        # `"usage": null` is valid JSON; `or {}` keeps the lookups below safe.
        usage_block = payload.get("usage") or {}
        model_response.usage = Usage(
            prompt_tokens=usage_block.get("prompt_tokens") or 0,
            total_tokens=usage_block.get("total_tokens") or 0,
        )
        return model_response

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        """Wrap an HTTP error from the Infinity server in `InfinityError`."""
        return InfinityError(
            message=error_message, status_code=status_code, headers=headers
        )
|