from typing import Any, Dict
from enum import Enum

# from langchain_community.chat_models.huggingface import ChatHuggingFace
# from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint
from langchain_core import pydantic_v1
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.utils import get_from_dict_or_env
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from langchain_openai import ChatOpenAI


class LLMBackends(Enum):
    """Enum of the supported LLM inference backends."""

    VLLM = "VLLM"
    HFChat = "HFChat"
    Fireworks = "Fireworks"


class LazyChatHuggingFace(ChatHuggingFace):
    """ChatHuggingFace that resolves the model id only when it is missing and reuses a provided tokenizer."""

    def __init__(self, **kwargs: Any):
        # Bypass ChatHuggingFace.__init__ and initialise the base chat model directly,
        # so the model-id resolution below only happens when it is actually needed.
        BaseChatModel.__init__(self, **kwargs)

        from transformers import AutoTokenizer

        # Only resolve the model id when it was not provided explicitly.
        if not self.model_id:
            self._resolve_model_id()

        # Reuse a caller-supplied tokenizer; otherwise load it from the Hub.
        self.tokenizer = (
            AutoTokenizer.from_pretrained(self.model_id)
            if self.tokenizer is None
            else self.tokenizer
        )


class LazyHuggingFaceEndpoint(HuggingFaceEndpoint):
    """HuggingFaceEndpoint whose environment validation never logs in with the token."""

    # We're using a lazy endpoint to avoid logging in with hf_token,
    # which might in fact be an hf_oauth token that only permits inference,
    # not logging in.
    def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        return super().build_extra(values)

    def validate_environment(cls, values: Dict) -> Dict:  # noqa: UP006, N805
        """Validate that huggingface_hub is installed and build the sync and async inference clients."""
        try:
            from huggingface_hub import AsyncInferenceClient, InferenceClient
        except ImportError:
            msg = (
                "Could not import huggingface_hub python package. "
                "Please install it with `pip install huggingface_hub`."
            )
            raise ImportError(msg)  # noqa: B904

        # Read the token from the passed values or from the HF_TOKEN environment
        # variable; it is only used to build the clients, never to call `login`.
        huggingfacehub_api_token = get_from_dict_or_env(
            values, "huggingfacehub_api_token", "HF_TOKEN"
        )

        values["client"] = InferenceClient(
            model=values["model"],
            timeout=values["timeout"],
            token=huggingfacehub_api_token,
            **values["server_kwargs"],
        )
        values["async_client"] = AsyncInferenceClient(
            model=values["model"],
            timeout=values["timeout"],
            token=huggingfacehub_api_token,
            **values["server_kwargs"],
        )
        return values
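

# Illustrative sketch (not part of the original module): the lazy endpoint can also be
# built directly against an arbitrary inference server, mirroring the commented-out
# branch in `get_chat_model_wrapper` below. The URL and token here are placeholders;
# an inference-only hf_oauth token is sufficient because no `login` call is made.
def _example_lazy_endpoint(inference_server_url: str, token: str) -> LazyHuggingFaceEndpoint:
    return LazyHuggingFaceEndpoint(
        endpoint_url=inference_server_url,  # e.g. a dedicated text-generation endpoint URL
        task="text-generation",
        huggingfacehub_api_token=token,
    )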


def get_chat_model_wrapper(
    model_id: str,
    inference_server_url: str,
    token: str,
    backend: str = "HFChat",
    **model_init_kwargs,
):
    """Build a LangChain chat model for the given backend ("HFChat", "VLLM" or "Fireworks")."""
    backend = LLMBackends(backend)

    if backend == LLMBackends.HFChat:
        # llm = LazyHuggingFaceEndpoint(
        #     endpoint_url=inference_server_url,
        #     task="text-generation",
        #     huggingfacehub_api_token=token,
        #     **model_init_kwargs,
        # )
        llm = LazyHuggingFaceEndpoint(
            repo_id=model_id,
            task="text-generation",
            huggingfacehub_api_token=token,
            **model_init_kwargs,
        )

        from transformers import AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
        chat_model = LazyChatHuggingFace(llm=llm, model_id=model_id, tokenizer=tokenizer)
    elif backend in [LLMBackends.VLLM, LLMBackends.Fireworks]:
        chat_model = ChatOpenAI(
            model=model_id,
            openai_api_base=inference_server_url,  # type: ignore
            openai_api_key=token,  # type: ignore
            **model_init_kwargs,
        )
    else:
        raise ValueError(f"Backend {backend} not supported")

    return chat_model
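

if __name__ == "__main__":
    # Minimal usage sketch, assuming a vLLM server exposing an OpenAI-compatible API is
    # reachable at the URL below; the model id, URL and token are placeholders, not
    # values taken from the original project.
    import os

    chat_model = get_chat_model_wrapper(
        model_id="my-org/my-model",
        inference_server_url="http://localhost:8000/v1",
        token=os.environ.get("HF_TOKEN", "dummy-token"),
        backend=LLMBackends.VLLM.value,
        temperature=0.1,
    )
    print(chat_model.invoke("Say hello in one short sentence."))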