import json
import os
from typing import Any, Callable, Dict, List, Literal, Optional, Union, get_args

import httpx

import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.custom_httpx.http_handler import (
    AsyncHTTPHandler,
    HTTPHandler,
    get_async_httpx_client,
)
from litellm.types.utils import EmbeddingResponse

from ...base import BaseLLM
from ..common_utils import HuggingFaceError
from .transformation import HuggingFaceEmbeddingConfig

config = HuggingFaceEmbeddingConfig()

HF_HUB_URL = "https://huggingface.co"

hf_tasks_embeddings = Literal[  # pipeline tags + hf tei endpoints - https://huggingface.github.io/text-embeddings-inference/#/
    "sentence-similarity", "feature-extraction", "rerank", "embed", "similarity"
]

def get_hf_task_embedding_for_model(
    model: str, task_type: Optional[str], api_base: str
) -> Optional[str]:
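    """Resolve the embedding task for a model.

    Returns the caller-supplied ``task_type`` if it is one of the supported
    pipeline tags, otherwise fetches the model's ``pipeline_tag`` from the
    HF Hub API (``{api_base}/api/models/{model}``) with a blocking request.
    """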
    if task_type is not None:
        if task_type in get_args(hf_tasks_embeddings):
            return task_type
        else:
            raise Exception(
                "Invalid task_type={}. Expected one of={}".format(
                    task_type, hf_tasks_embeddings
                )
            )

    http_client = HTTPHandler(concurrent_limit=1)

    model_info = http_client.get(url=f"{api_base}/api/models/{model}")

    model_info_dict = model_info.json()

    pipeline_tag: Optional[str] = model_info_dict.get("pipeline_tag", None)

    return pipeline_tag

async def async_get_hf_task_embedding_for_model(
    model: str, task_type: Optional[str], api_base: str
) -> Optional[str]:
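    """Async variant of ``get_hf_task_embedding_for_model``.

    Validates an explicit ``task_type`` or looks up the model's
    ``pipeline_tag`` from the HF Hub API without blocking the event loop.
    """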
    if task_type is not None:
        if task_type in get_args(hf_tasks_embeddings):
            return task_type
        else:
            raise Exception(
                "Invalid task_type={}. Expected one of={}".format(
                    task_type, hf_tasks_embeddings
                )
            )

    http_client = get_async_httpx_client(
        llm_provider=litellm.LlmProviders.HUGGINGFACE,
    )

    model_info = await http_client.get(url=f"{api_base}/api/models/{model}")

    model_info_dict = model_info.json()

    pipeline_tag: Optional[str] = model_info_dict.get("pipeline_tag", None)

    return pipeline_tag

class HuggingFaceEmbedding(BaseLLM):
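    """Handler for Hugging Face embedding endpoints (Inference API / text-embeddings-inference)."""
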
    _client_session: Optional[httpx.Client] = None
    _aclient_session: Optional[httpx.AsyncClient] = None

    def __init__(self) -> None:
        super().__init__()

    def _transform_input_on_pipeline_tag(
        self, input: List, pipeline_tag: Optional[str]
    ) -> dict:
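        """Shape the request payload for the detected pipeline tag.

        ``sentence-similarity``/``similarity`` and ``rerank`` need a source
        sentence (or query) plus at least one comparison sentence; anything
        else falls through to the plain ``{"inputs": ...}`` payload.
        """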
        if pipeline_tag is None:
            return {"inputs": input}

        if pipeline_tag == "sentence-similarity" or pipeline_tag == "similarity":
            if len(input) < 2:
                raise HuggingFaceError(
                    status_code=400,
                    message="sentence-similarity requires 2+ sentences",
                )
            return {"inputs": {"source_sentence": input[0], "sentences": input[1:]}}
        elif pipeline_tag == "rerank":
            if len(input) < 2:
                raise HuggingFaceError(
                    status_code=400,
                    message="reranker requires 2+ sentences",
                )
            return {"inputs": {"query": input[0], "texts": input[1:]}}
        return {"inputs": input}  # default to feature-extraction pipeline tag

    async def _async_transform_input(
        self,
        model: str,
        task_type: Optional[str],
        embed_url: str,
        input: List,
        optional_params: dict,
    ) -> dict:
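        """Async payload builder: resolves the task tag without blocking the
        event loop, then attaches any remaining ``optional_params`` under
        ``options``."""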
        hf_task = await async_get_hf_task_embedding_for_model(
            model=model, task_type=task_type, api_base=HF_HUB_URL
        )

        data = self._transform_input_on_pipeline_tag(input=input, pipeline_tag=hf_task)

        if len(optional_params.keys()) > 0:
            data["options"] = optional_params

        return data

    def _process_optional_params(self, data: dict, optional_params: dict) -> dict:
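        """Route extra params into the HF request body: known option keys go
        under ``options``, known generation params under ``parameters``, and
        everything else is passed through at the top level."""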
        special_options_keys = config.get_special_options_params()
        special_parameters_keys = [
            "min_length",
            "max_length",
            "top_k",
            "top_p",
            "temperature",
            "repetition_penalty",
            "max_time",
        ]

        for k, v in optional_params.items():
            if k in special_options_keys:
                data.setdefault("options", {})
                data["options"][k] = v
            elif k in special_parameters_keys:
                data.setdefault("parameters", {})
                data["parameters"][k] = v
            else:
                data[k] = v
        return data

    def _transform_input(
        self,
        input: List,
        model: str,
        call_type: Literal["sync", "async"],
        optional_params: dict,
        embed_url: str,
    ) -> dict:
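        """Build the request payload for a sync or async embedding call.

        For ``call_type="async"`` this returns the (un-awaited) coroutine from
        ``_async_transform_input``; the caller is expected to await it.
        """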
        data: Dict = {}

        ## TRANSFORMATION ##
        if "sentence-transformers" in model:
            if len(input) < 2:  # payload needs a source sentence plus 1+ comparisons
                raise HuggingFaceError(
                    status_code=400,
                    message="sentence transformers requires 2+ sentences",
                )
            data = {"inputs": {"source_sentence": input[0], "sentences": input[1:]}}
        else:
            data = {"inputs": input}

            task_type = optional_params.pop("input_type", None)

            if call_type == "sync":
                hf_task = get_hf_task_embedding_for_model(
                    model=model, task_type=task_type, api_base=HF_HUB_URL
                )
            elif call_type == "async":
                return self._async_transform_input(
                    model=model,
                    task_type=task_type,
                    embed_url=embed_url,
                    input=input,
                    optional_params=optional_params,
                )  # type: ignore

            data = self._transform_input_on_pipeline_tag(
                input=input, pipeline_tag=hf_task
            )

        if len(optional_params.keys()) > 0:
            data = self._process_optional_params(
                data=data, optional_params=optional_params
            )

        return data

    def _process_embedding_response(
        self,
        embeddings: dict,
        model_response: EmbeddingResponse,
        model: str,
        input: List,
        encoding: Any,
    ) -> EmbeddingResponse:
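        """Normalize the raw HF response into an OpenAI-style
        ``EmbeddingResponse`` and attach token usage computed with
        ``encoding`` over the input texts."""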
        output_data = []
        if "similarities" in embeddings:
            for idx, embedding in enumerate(embeddings["similarities"]):
                output_data.append(
                    {
                        "object": "embedding",
                        "index": idx,
                        "embedding": embedding,  # flatten list returned from hf
                    }
                )
        else:
            for idx, embedding in enumerate(embeddings):
                if isinstance(embedding, float):
                    output_data.append(
                        {
                            "object": "embedding",
                            "index": idx,
                            "embedding": embedding,  # flatten list returned from hf
                        }
                    )
                elif isinstance(embedding, list) and isinstance(embedding[0], float):
                    output_data.append(
                        {
                            "object": "embedding",
                            "index": idx,
                            "embedding": embedding,  # flatten list returned from hf
                        }
                    )
                else:
                    output_data.append(
                        {
                            "object": "embedding",
                            "index": idx,
                            "embedding": embedding[0][0],  # flatten list returned from hf
                        }
                    )
        model_response.object = "list"
        model_response.data = output_data
        model_response.model = model
        input_tokens = 0
        for text in input:
            input_tokens += len(encoding.encode(text))

        setattr(
            model_response,
            "usage",
            litellm.Usage(
                prompt_tokens=input_tokens,
                completion_tokens=0,  # embeddings produce no completion tokens
                total_tokens=input_tokens,
                prompt_tokens_details=None,
                completion_tokens_details=None,
            ),
        )
        return model_response

    async def aembedding(
        self,
        model: str,
        input: list,
        model_response: litellm.utils.EmbeddingResponse,
        timeout: Union[float, httpx.Timeout],
        logging_obj: LiteLLMLoggingObj,
        optional_params: dict,
        api_base: str,
        api_key: Optional[str],
        headers: dict,
        encoding: Callable,
        client: Optional[AsyncHTTPHandler] = None,
    ):
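        """Async embedding call. Note the payload is built via the synchronous
        transform path (``call_type="sync"``), so the HF Hub task lookup it
        performs is a blocking request."""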
        ## TRANSFORMATION ##
        data = self._transform_input(
            input=input,
            model=model,
            call_type="sync",
            optional_params=optional_params,
            embed_url=api_base,
        )
        ## LOGGING
        logging_obj.pre_call(
            input=input,
            api_key=api_key,
            additional_args={
                "complete_input_dict": data,
                "headers": headers,
                "api_base": api_base,
            },
        )
        ## COMPLETION CALL
        if client is None:
            client = get_async_httpx_client(
                llm_provider=litellm.LlmProviders.HUGGINGFACE,
            )

        response = await client.post(api_base, headers=headers, data=json.dumps(data))

        ## LOGGING
        logging_obj.post_call(
            input=input,
            api_key=api_key,
            additional_args={"complete_input_dict": data},
            original_response=response,
        )

        embeddings = response.json()

        if "error" in embeddings:
            raise HuggingFaceError(status_code=500, message=embeddings["error"])

        ## PROCESS RESPONSE ##
        return self._process_embedding_response(
            embeddings=embeddings,
            model_response=model_response,
            model=model,
            input=input,
            encoding=encoding,
        )

    def embedding(
        self,
        model: str,
        input: list,
        model_response: EmbeddingResponse,
        optional_params: dict,
        litellm_params: dict,
        logging_obj: LiteLLMLoggingObj,
        encoding: Callable,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
        aembedding: Optional[bool] = None,
        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
        headers: dict = {},
    ) -> EmbeddingResponse:
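        """Sync entrypoint. Resolves headers and the embedding URL, then either
        dispatches to ``aembedding`` (when ``aembedding=True``) or performs the
        request with a blocking HTTP client."""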
        super().embedding()
        headers = config.validate_environment(
            api_key=api_key,
            headers=headers,
            model=model,
            optional_params=optional_params,
            messages=[],
            litellm_params=litellm_params,
        )
        task_type = optional_params.pop("input_type", None)
        task = get_hf_task_embedding_for_model(
            model=model, task_type=task_type, api_base=HF_HUB_URL
        )
        # print_verbose(f"{model}, {task}")
        embed_url = ""
        if "https" in model:
            embed_url = model
        elif api_base:
            embed_url = api_base
        elif "HF_API_BASE" in os.environ:
            embed_url = os.getenv("HF_API_BASE", "")
        elif "HUGGINGFACE_API_BASE" in os.environ:
            embed_url = os.getenv("HUGGINGFACE_API_BASE", "")
        else:
            embed_url = (
                f"https://router.huggingface.co/hf-inference/pipeline/{task}/{model}"
            )
        ## ROUTING ##
        if aembedding is True:
            return self.aembedding(
                input=input,
                model_response=model_response,
                timeout=timeout,
                logging_obj=logging_obj,
                headers=headers,
                api_base=embed_url,  # type: ignore
                api_key=api_key,
                client=client if isinstance(client, AsyncHTTPHandler) else None,
                model=model,
                optional_params=optional_params,
                encoding=encoding,
            )
        ## TRANSFORMATION ##
        data = self._transform_input(
            input=input,
            model=model,
            call_type="sync",
            optional_params=optional_params,
            embed_url=embed_url,
        )

        ## LOGGING
        logging_obj.pre_call(
            input=input,
            api_key=api_key,
            additional_args={
                "complete_input_dict": data,
                "headers": headers,
                "api_base": embed_url,
            },
        )
        ## COMPLETION CALL
        if client is None or not isinstance(client, HTTPHandler):
            client = HTTPHandler(concurrent_limit=1)
        response = client.post(embed_url, headers=headers, data=json.dumps(data))

        ## LOGGING
        logging_obj.post_call(
            input=input,
            api_key=api_key,
            additional_args={"complete_input_dict": data},
            original_response=response,
        )

        embeddings = response.json()

        if "error" in embeddings:
            raise HuggingFaceError(status_code=500, message=embeddings["error"])

        ## PROCESS RESPONSE ##
        return self._process_embedding_response(
            embeddings=embeddings,
            model_response=model_response,
            model=model,
            input=input,
            encoding=encoding,
        )