Spaces:
Paused
Paused
| """ | |
| Re rank api | |
| LiteLLM supports the re rank API format, no parameter transformation occurs | |
| """ | |
| from typing import Any, Dict, List, Optional, Union | |
| import litellm | |
| from litellm.llms.base import BaseLLM | |
| from litellm.llms.custom_httpx.http_handler import ( | |
| _get_httpx_client, | |
| get_async_httpx_client, | |
| ) | |
| from litellm.llms.together_ai.rerank.transformation import TogetherAIRerankConfig | |
| from litellm.types.rerank import RerankRequest, RerankResponse | |
class TogetherAIRerank(BaseLLM):
    """Handler that sends rerank requests to the Together AI rerank endpoint.

    Provides a synchronous entry point (`rerank`) that can also dispatch to
    the asynchronous implementation (`async_rerank`) when `_is_async` is set.
    """

    # Single source of truth for the endpoint used by both sync and async paths.
    _TOGETHER_RERANK_URL = "https://api.together.xyz/v1/rerank"

    @staticmethod
    def _together_rerank_headers(api_key: str) -> Dict[str, str]:
        """Build the HTTP headers sent with every rerank request."""
        return {
            "accept": "application/json",
            "content-type": "application/json",
            "authorization": f"Bearer {api_key}",
        }

    @staticmethod
    def _parse_together_rerank_response(response: Any) -> RerankResponse:
        """Turn an HTTP response into a `RerankResponse`.

        Raises:
            Exception: if the response status code is not 200; the message is
                the raw response text.
        """
        if response.status_code != 200:
            raise Exception(response.text)
        return TogetherAIRerankConfig()._transform_response(response.json())

    def rerank(
        self,
        model: str,
        api_key: str,
        query: str,
        documents: List[Union[str, Dict[str, Any]]],
        top_n: Optional[int] = None,
        rank_fields: Optional[List[str]] = None,
        return_documents: Optional[bool] = True,
        max_chunks_per_doc: Optional[int] = None,
        _is_async: Optional[bool] = False,
    ) -> RerankResponse:
        """Rerank `documents` against `query` using the Together AI endpoint.

        Args:
            model: Model name placed in the request body.
            api_key: Key sent as a bearer token in the `authorization` header.
            query: Query string placed in the request body.
            documents: Documents to rerank, as strings or dicts.
            top_n: Optional `top_n` request field; omitted when None.
            rank_fields: Optional `rank_fields` request field; omitted when None.
            return_documents: `return_documents` request field.
            max_chunks_per_doc: Not supported; must be None.
            _is_async: When truthy, the call returns the coroutine produced by
                `async_rerank` instead of performing a blocking request.

        Returns:
            The transformed `RerankResponse` (or, when `_is_async` is truthy,
            a coroutine that resolves to one).

        Raises:
            ValueError: if `max_chunks_per_doc` is not None.
            Exception: if the endpoint responds with a non-200 status code.
        """
        # Fail fast on the unsupported argument, before any request object or
        # HTTP client is created.
        if max_chunks_per_doc is not None:
            raise ValueError("TogetherAI does not support max_chunks_per_doc")

        request_data = RerankRequest(
            model=model,
            query=query,
            top_n=top_n,
            documents=documents,
            rank_fields=rank_fields,
            return_documents=return_documents,
        )
        # exclude None values from request_data
        request_data_dict = request_data.dict(exclude_none=True)

        if _is_async:
            # Not awaited here: the caller receives the coroutine.
            return self.async_rerank(request_data_dict, api_key)  # type: ignore # Call async method

        # The sync client is only needed on the blocking path.
        client = _get_httpx_client()
        response = client.post(
            self._TOGETHER_RERANK_URL,
            headers=self._together_rerank_headers(api_key),
            json=request_data_dict,
        )
        return self._parse_together_rerank_response(response)

    async def async_rerank(
        self,
        request_data_dict: Dict[str, Any],
        api_key: str,
    ) -> RerankResponse:
        """Asynchronously POST an already-serialized rerank request.

        Args:
            request_data_dict: JSON-serializable request body.
            api_key: Key sent as a bearer token in the `authorization` header.

        Returns:
            The transformed `RerankResponse`.

        Raises:
            Exception: if the endpoint responds with a non-200 status code.
        """
        client = get_async_httpx_client(
            llm_provider=litellm.LlmProviders.TOGETHER_AI
        )  # Use async client
        response = await client.post(
            self._TOGETHER_RERANK_URL,
            headers=self._together_rerank_headers(api_key),
            json=request_data_dict,
        )
        return self._parse_together_rerank_response(response)