from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

import httpx

from litellm.types.rerank import OptionalRerankParams, RerankBilledUnits, RerankResponse
from litellm.types.utils import ModelInfo

from ..chat.transformation import BaseLLMException

if TYPE_CHECKING:
    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj

    LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
    LiteLLMLoggingObj = Any

class BaseRerankConfig(ABC):
    @abstractmethod
    def validate_environment(
        self,
        headers: dict,
        model: str,
        api_key: Optional[str] = None,
    ) -> dict:
        """Return the request headers (including any provider auth) for the rerank call."""
        pass

    @abstractmethod
    def transform_rerank_request(
        self,
        model: str,
        optional_rerank_params: OptionalRerankParams,
        headers: dict,
    ) -> dict:
        """Translate the normalized rerank params into the provider's request body."""
        return {}

    @abstractmethod
    def transform_rerank_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: RerankResponse,
        logging_obj: LiteLLMLoggingObj,
        api_key: Optional[str] = None,
        request_data: dict = {},
        optional_params: dict = {},
        litellm_params: dict = {},
    ) -> RerankResponse:
        """Parse the provider's raw HTTP response into a RerankResponse."""
        return model_response

    def get_complete_url(self, api_base: Optional[str], model: str) -> str:
        """
        OPTIONAL

        Get the complete URL for the request.

        Some providers need `model` in the `api_base`.
        """
        return api_base or ""
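
    # get_complete_url example (hypothetical provider): a backend that embeds the
    # model name in the endpoint could return f"{api_base}/v1/{model}/rerank" here.
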
    @abstractmethod
    def get_supported_cohere_rerank_params(self, model: str) -> list:
        """Return the Cohere-style rerank params this provider supports for `model`."""
        pass

    @abstractmethod
    def map_cohere_rerank_params(
        self,
        non_default_params: dict,
        model: str,
        drop_params: bool,
        query: str,
        documents: List[Union[str, Dict[str, Any]]],
        custom_llm_provider: Optional[str] = None,
        top_n: Optional[int] = None,
        rank_fields: Optional[List[str]] = None,
        return_documents: Optional[bool] = True,
        max_chunks_per_doc: Optional[int] = None,
        max_tokens_per_doc: Optional[int] = None,
    ) -> OptionalRerankParams:
        """Map Cohere-style rerank arguments onto this provider's optional params."""
        pass

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        raise BaseLLMException(
            status_code=status_code,
            message=error_message,
            headers=headers,
        )

    def calculate_rerank_cost(
        self,
        model: str,
        custom_llm_provider: Optional[str] = None,
        billed_units: Optional[RerankBilledUnits] = None,
        model_info: Optional[ModelInfo] = None,
    ) -> Tuple[float, float]:
        """
        Calculates the cost per query for a given rerank model.

        Input:
            - model: str, the model name without the provider prefix
            - custom_llm_provider: str, the provider used for the model. If provided, used to check that the litellm model info is for that provider.
            - billed_units: RerankBilledUnits, the billed units (e.g. search units) reported for the request
            - model_info: ModelInfo, the model info for the given model

        Returns:
            Tuple[float, float] - prompt_cost_in_usd, completion_cost_in_usd
        """
        if (
            model_info is None
            or "input_cost_per_query" not in model_info
            or model_info["input_cost_per_query"] is None
            or billed_units is None
        ):
            return 0.0, 0.0

        search_units = billed_units.get("search_units")

        if search_units is None:
            return 0.0, 0.0

        prompt_cost = model_info["input_cost_per_query"] * search_units

        return prompt_cost, 0.0
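

# ---------------------------------------------------------------------------
# Hedged usage sketch (illustration only, not part of the base class): a minimal,
# hypothetical provider config showing how the hooks above fit together. The class
# name "ExampleRerankConfig", the "https://api.example.com" endpoint, and the exact
# payload / response shapes are assumptions for illustration.
# ---------------------------------------------------------------------------
class ExampleRerankConfig(BaseRerankConfig):
    def validate_environment(
        self,
        headers: dict,
        model: str,
        api_key: Optional[str] = None,
    ) -> dict:
        # Attach bearer auth; fail loudly if no key was supplied.
        if api_key is None:
            raise ValueError("api_key is required for the example provider")
        return {**headers, "Authorization": f"Bearer {api_key}"}

    def get_complete_url(self, api_base: Optional[str], model: str) -> str:
        return (api_base or "https://api.example.com") + "/v1/rerank"

    def get_supported_cohere_rerank_params(self, model: str) -> list:
        return ["query", "documents", "top_n", "return_documents"]

    def map_cohere_rerank_params(
        self,
        non_default_params: dict,
        model: str,
        drop_params: bool,
        query: str,
        documents: List[Union[str, Dict[str, Any]]],
        custom_llm_provider: Optional[str] = None,
        top_n: Optional[int] = None,
        rank_fields: Optional[List[str]] = None,
        return_documents: Optional[bool] = True,
        max_chunks_per_doc: Optional[int] = None,
        max_tokens_per_doc: Optional[int] = None,
    ) -> OptionalRerankParams:
        # Keep only the params this hypothetical provider understands; the keys used
        # here are assumed to exist on the OptionalRerankParams TypedDict.
        params: OptionalRerankParams = {"query": query, "documents": documents}
        if top_n is not None:
            params["top_n"] = top_n
        if return_documents is not None:
            params["return_documents"] = return_documents
        return params

    def transform_rerank_request(
        self,
        model: str,
        optional_rerank_params: OptionalRerankParams,
        headers: dict,
    ) -> dict:
        # Assumed request body: the model name plus the normalized params.
        return {"model": model, **optional_rerank_params}

    def transform_rerank_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: RerankResponse,
        logging_obj: LiteLLMLoggingObj,
        api_key: Optional[str] = None,
        request_data: dict = {},
        optional_params: dict = {},
        litellm_params: dict = {},
    ) -> RerankResponse:
        # Assumes the provider returns a Cohere-compatible JSON body whose keys map
        # directly onto RerankResponse's fields.
        return RerankResponse(**raw_response.json())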