import logging
import os
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

import httpx

from litellm.types.llms.openai import AllMessageValues, ChatCompletionRequest

if TYPE_CHECKING:
    from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj

    LoggingClass = LiteLLMLoggingObj
else:
    LoggingClass = Any

from litellm.llms.base_llm.chat.transformation import BaseLLMException

from ...openai.chat.gpt_transformation import OpenAIGPTConfig
from ..common_utils import HuggingFaceError, _fetch_inference_provider_mapping

logger = logging.getLogger(__name__)

BASE_URL = "https://router.huggingface.co"


class HuggingFaceChatConfig(OpenAIGPTConfig):
    """
    Reference: https://huggingface.co/docs/huggingface_hub/guides/inference
    """
    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: Dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        default_headers = {
            "content-type": "application/json",
        }
        if api_key is not None:
            default_headers["Authorization"] = f"Bearer {api_key}"
        headers = {**headers, **default_headers}
        return headers
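    # For example, with a hypothetical key "hf_xxx" the merged headers become:
    #   {"content-type": "application/json", "Authorization": "Bearer hf_xxx"}
    # (the defaults take precedence over caller-supplied headers).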
    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        return HuggingFaceError(
            status_code=status_code, message=error_message, headers=headers
        )
    def get_base_url(self, model: str, base_url: Optional[str]) -> Optional[str]:
        """
        Get the API base for the Hugging Face API.

        Do not add the chat/embedding/rerank extension here. Let the handler do this.
        """
        if model.startswith(("http://", "https://")):
            base_url = model
        elif base_url is None:
            base_url = os.getenv("HF_API_BASE") or os.getenv("HUGGINGFACE_API_BASE", "")
        return base_url
    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """
        Get the complete URL for the API call.

        For provider-specific routing through Hugging Face.
        """
        # 1. Check if api_base is provided
        if api_base is not None:
            complete_url = api_base
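        # 2. Check the HF_API_BASE / HUGGINGFACE_API_BASE environment variables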
elif os.getenv("HF_API_BASE") or os.getenv("HUGGINGFACE_API_BASE"): | |
complete_url = str(os.getenv("HF_API_BASE")) or str( | |
os.getenv("HUGGINGFACE_API_BASE") | |
) | |
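        # 3. Model passed as a full URL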
        elif model.startswith(("http://", "https://")):
            complete_url = model
        # 4. Default construction with provider
        else:
            # Parse provider and model
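            # e.g. "together/meta-llama/Llama-3.3-70B-Instruct" -> provider "together";
            # a bare "org/model" id defaults to the "hf-inference" provider.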
            first_part, remaining = model.split("/", 1)
            if "/" in remaining:
                provider = first_part
            else:
                provider = "hf-inference"

            if provider == "hf-inference":
                route = f"{provider}/models/{model}/v1/chat/completions"
            elif provider == "novita":
                route = f"{provider}/chat/completions"
            else:
                route = f"{provider}/v1/chat/completions"
            complete_url = f"{BASE_URL}/{route}"

        # Ensure URL doesn't end with a slash
        complete_url = complete_url.rstrip("/")
        return complete_url
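    # Routing sketch (model names are illustrative):
    #   "meta-llama/Llama-3.3-70B-Instruct"
    #       -> https://router.huggingface.co/hf-inference/models/meta-llama/Llama-3.3-70B-Instruct/v1/chat/completions
    #   "together/meta-llama/Llama-3.3-70B-Instruct"
    #       -> https://router.huggingface.co/together/v1/chat/completions
    #   "novita/deepseek/deepseek-r1"
    #       -> https://router.huggingface.co/novita/chat/completions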
    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        if "max_retries" in optional_params:
            logger.warning("`max_retries` is not supported. It will be ignored.")
            optional_params.pop("max_retries", None)
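        # Same parsing as get_complete_url: an explicit "<provider>/<org>/<model>"
        # names the provider; a bare "<org>/<model>" implies "hf-inference".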
        first_part, remaining = model.split("/", 1)
        if "/" in remaining:
            provider = first_part
            model_id = remaining
        else:
            provider = "hf-inference"
            model_id = model

        provider_mapping = _fetch_inference_provider_mapping(model_id)
        if provider not in provider_mapping:
            raise HuggingFaceError(
                message=f"Model {model_id} is not supported for provider {provider}",
                status_code=404,
                headers={},
            )
        provider_mapping = provider_mapping[provider]
        if provider_mapping["status"] == "staging":
            logger.warning(
                f"Model {model_id} is in staging mode for provider {provider}. Meant for test purposes only."
            )
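        # providerId is the model identifier the selected inference provider
        # expects; it may differ from the Hugging Face Hub model id.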
        mapped_model = provider_mapping["providerId"]
        messages = self._transform_messages(messages=messages, model=mapped_model)
        return dict(
            ChatCompletionRequest(
                model=mapped_model, messages=messages, **optional_params
            )
        )
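
# Minimal usage sketch (the model name is hypothetical; assumes the Hugging
# Face router endpoints above):
#
#   config = HuggingFaceChatConfig()
#   url = config.get_complete_url(
#       api_base=None,
#       api_key=None,
#       model="together/meta-llama/Llama-3.3-70B-Instruct",
#       optional_params={},
#       litellm_params={},
#   )
#   # url == "https://router.huggingface.co/together/v1/chat/completions"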