Spaces:
Sleeping
Sleeping
import json | |
from typing import Any, Callable, Optional, Union | |
import httpx | |
import litellm | |
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj | |
from litellm.llms.custom_httpx.http_handler import ( | |
AsyncHTTPHandler, | |
HTTPHandler, | |
get_async_httpx_client, | |
) | |
from litellm.types.llms.bedrock import CohereEmbeddingRequest | |
from litellm.types.utils import EmbeddingResponse | |
from .transformation import CohereEmbeddingConfig | |
def validate_environment(api_key, headers: dict): | |
headers.update( | |
{ | |
"Request-Source": "unspecified:litellm", | |
"accept": "application/json", | |
"content-type": "application/json", | |
} | |
) | |
if api_key: | |
headers["Authorization"] = f"Bearer {api_key}" | |
return headers | |
class CohereError(Exception): | |
def __init__(self, status_code, message): | |
self.status_code = status_code | |
self.message = message | |
self.request = httpx.Request( | |
method="POST", url="https://api.cohere.ai/v1/generate" | |
) | |
self.response = httpx.Response(status_code=status_code, request=self.request) | |
super().__init__( | |
self.message | |
) # Call the base class constructor with the parameters it needs | |
async def async_embedding( | |
model: str, | |
data: Union[dict, CohereEmbeddingRequest], | |
input: list, | |
model_response: litellm.utils.EmbeddingResponse, | |
timeout: Optional[Union[float, httpx.Timeout]], | |
logging_obj: LiteLLMLoggingObj, | |
optional_params: dict, | |
api_base: str, | |
api_key: Optional[str], | |
headers: dict, | |
encoding: Callable, | |
client: Optional[AsyncHTTPHandler] = None, | |
): | |
## LOGGING | |
logging_obj.pre_call( | |
input=input, | |
api_key=api_key, | |
additional_args={ | |
"complete_input_dict": data, | |
"headers": headers, | |
"api_base": api_base, | |
}, | |
) | |
## COMPLETION CALL | |
if client is None: | |
client = get_async_httpx_client( | |
llm_provider=litellm.LlmProviders.COHERE, | |
params={"timeout": timeout}, | |
) | |
try: | |
response = await client.post(api_base, headers=headers, data=json.dumps(data)) | |
except httpx.HTTPStatusError as e: | |
## LOGGING | |
logging_obj.post_call( | |
input=input, | |
api_key=api_key, | |
additional_args={"complete_input_dict": data}, | |
original_response=e.response.text, | |
) | |
raise e | |
except Exception as e: | |
## LOGGING | |
logging_obj.post_call( | |
input=input, | |
api_key=api_key, | |
additional_args={"complete_input_dict": data}, | |
original_response=str(e), | |
) | |
raise e | |
## PROCESS RESPONSE ## | |
return CohereEmbeddingConfig()._transform_response( | |
response=response, | |
api_key=api_key, | |
logging_obj=logging_obj, | |
data=data, | |
model_response=model_response, | |
model=model, | |
encoding=encoding, | |
input=input, | |
) | |
def embedding( | |
model: str, | |
input: list, | |
model_response: EmbeddingResponse, | |
logging_obj: LiteLLMLoggingObj, | |
optional_params: dict, | |
headers: dict, | |
encoding: Any, | |
data: Optional[Union[dict, CohereEmbeddingRequest]] = None, | |
complete_api_base: Optional[str] = None, | |
api_key: Optional[str] = None, | |
aembedding: Optional[bool] = None, | |
timeout: Optional[Union[float, httpx.Timeout]] = httpx.Timeout(None), | |
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None, | |
): | |
headers = validate_environment(api_key, headers=headers) | |
embed_url = complete_api_base or "https://api.cohere.ai/v1/embed" | |
model = model | |
data = data or CohereEmbeddingConfig()._transform_request( | |
model=model, input=input, inference_params=optional_params | |
) | |
## ROUTING | |
if aembedding is True: | |
return async_embedding( | |
model=model, | |
data=data, | |
input=input, | |
model_response=model_response, | |
timeout=timeout, | |
logging_obj=logging_obj, | |
optional_params=optional_params, | |
api_base=embed_url, | |
api_key=api_key, | |
headers=headers, | |
encoding=encoding, | |
client=( | |
client | |
if client is not None and isinstance(client, AsyncHTTPHandler) | |
else None | |
), | |
) | |
## LOGGING | |
logging_obj.pre_call( | |
input=input, | |
api_key=api_key, | |
additional_args={"complete_input_dict": data}, | |
) | |
## COMPLETION CALL | |
if client is None or not isinstance(client, HTTPHandler): | |
client = HTTPHandler(concurrent_limit=1) | |
response = client.post(embed_url, headers=headers, data=json.dumps(data)) | |
return CohereEmbeddingConfig()._transform_response( | |
response=response, | |
api_key=api_key, | |
logging_obj=logging_obj, | |
data=data, | |
model_response=model_response, | |
model=model, | |
encoding=encoding, | |
input=input, | |
) | |