from enum import Enum
from typing import Any, Dict

from langchain_community.chat_models.huggingface import ChatHuggingFace
from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint
from langchain_core import pydantic_v1
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.utils import get_from_dict_or_env
from langchain_openai import ChatOpenAI


class LLMBackends(Enum):
    """Supported LLM serving backends."""

    VLLM = "VLLM"
    HFChat = "HFChat"
    Fireworks = "Fireworks"


class LazyChatHuggingFace(ChatHuggingFace):
    """ChatHuggingFace variant that resolves the model id and tokenizer lazily."""

    def __init__(self, **kwargs: Any):
        # Initialize only the BaseChatModel fields; model id and tokenizer are
        # resolved here instead of through ChatHuggingFace.__init__.
        BaseChatModel.__init__(self, **kwargs)

        from transformers import AutoTokenizer

        if not self.model_id:
            self._resolve_model_id()

        # Load the tokenizer for the resolved model unless one was passed in.
        self.tokenizer = (
            AutoTokenizer.from_pretrained(self.model_id)
            if self.tokenizer is None
            else self.tokenizer
        )


class LazyHuggingFaceEndpoint(HuggingFaceEndpoint):
    """HuggingFaceEndpoint variant that validates the environment without logging in."""

    # We're using a lazy endpoint to avoid logging in with hf_token,
    # which might in fact be an hf_oauth token that only permits inference,
    # not logging in.
    @pydantic_v1.root_validator()
    def validate_environment(cls, values: Dict) -> Dict:  # noqa: UP006, N805
        """Validate that huggingface_hub is installed and that the API token is set."""
        try:
            from huggingface_hub import AsyncInferenceClient, InferenceClient
        except ImportError:
            msg = (
                "Could not import huggingface_hub python package. "
                "Please install it with `pip install huggingface_hub`."
            )
            raise ImportError(msg)  # noqa: B904

        huggingfacehub_api_token = get_from_dict_or_env(
            values, "huggingfacehub_api_token", "HF_TOKEN"
        )

        values["client"] = InferenceClient(
            model=values["model"],
            timeout=values["timeout"],
            token=huggingfacehub_api_token,
            **values["server_kwargs"],
        )
        values["async_client"] = AsyncInferenceClient(
            model=values["model"],
            timeout=values["timeout"],
            token=huggingfacehub_api_token,
            **values["server_kwargs"],
        )

        return values


def get_chat_model_wrapper(
    model_id: str,
    inference_server_url: str,
    token: str,
    backend: str = "HFChat",
    **model_init_kwargs: Any,
):
    """Build a chat model for the given backend ("HFChat", "VLLM", or "Fireworks")."""
    backend = LLMBackends(backend)

    if backend == LLMBackends.HFChat:
        llm = LazyHuggingFaceEndpoint(
            endpoint_url=inference_server_url,
            task="text-generation",
            huggingfacehub_api_token=token,
            **model_init_kwargs,
        )

        from transformers import AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)

        chat_model = LazyChatHuggingFace(
            llm=llm, model_id=model_id, tokenizer=tokenizer
        )
    elif backend in [LLMBackends.VLLM, LLMBackends.Fireworks]:
        # vLLM and Fireworks expose OpenAI-compatible APIs, so reuse ChatOpenAI.
        chat_model = ChatOpenAI(
            model=model_id,
            openai_api_base=inference_server_url,  # type: ignore
            openai_api_key=token,  # type: ignore
            **model_init_kwargs,
        )
    else:
        msg = f"Unsupported backend: {backend}"
        raise ValueError(msg)

    return chat_model
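

if __name__ == "__main__":
    # Minimal usage sketch for the HFChat backend. The model id, endpoint URL,
    # and HF_TOKEN environment variable are placeholders/assumptions, not values
    # defined elsewhere in this module.
    import os

    chat_model = get_chat_model_wrapper(
        model_id="mistralai/Mistral-7B-Instruct-v0.2",
        inference_server_url=(
            "https://api-inference.huggingface.co/models/"
            "mistralai/Mistral-7B-Instruct-v0.2"
        ),
        token=os.environ["HF_TOKEN"],
        backend="HFChat",
    )
    print(chat_model.invoke("Say hello in one short sentence.").content)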