from enum import Enum
from typing import Any, Dict

from langchain_core import pydantic_v1
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.utils import get_from_dict_or_env
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from langchain_openai import ChatOpenAI


class LLMBackends(Enum):
    """Supported inference backends for serving chat models."""

    VLLM = "VLLM"
    HFChat = "HFChat"
    Fireworks = "Fireworks"


class LazyChatHuggingFace(ChatHuggingFace):
    """ChatHuggingFace that defers model-id resolution and tokenizer loading.

    Bypasses ``ChatHuggingFace.__init__`` so the model id is resolved and the
    tokenizer is loaded only when they were not supplied explicitly.
    """

    def __init__(self, **kwargs: Any):
        # Initialise only the BaseChatModel fields instead of running the
        # parent constructor.
        BaseChatModel.__init__(self, **kwargs)

        from transformers import AutoTokenizer

        if not self.model_id:
            self._resolve_model_id()

        # Load a tokenizer for the resolved model id only if none was given.
        self.tokenizer = (
            AutoTokenizer.from_pretrained(self.model_id)
            if self.tokenizer is None
            else self.tokenizer
        )


class LazyHuggingFaceEndpoint(HuggingFaceEndpoint):
    """HuggingFaceEndpoint that builds its clients without validating the token."""

    @pydantic_v1.root_validator(pre=True)
    def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        # Defer to the parent's handling of extra model kwargs.
        return super().build_extra(values)

    @pydantic_v1.root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Check that huggingface_hub is installed and build the clients."""
        try:
            from huggingface_hub import AsyncInferenceClient, InferenceClient
        except ImportError:
            msg = (
                "Could not import huggingface_hub python package. "
                "Please install it with `pip install huggingface_hub`."
            )
            raise ImportError(msg)

        huggingfacehub_api_token = get_from_dict_or_env(
            values, "huggingfacehub_api_token", "HF_TOKEN"
        )

        # Build the sync and async clients directly; the token is passed
        # through as-is and is not validated against the Hub.
        values["client"] = InferenceClient(
            model=values["model"],
            timeout=values["timeout"],
            token=huggingfacehub_api_token,
            **values["server_kwargs"],
        )
        values["async_client"] = AsyncInferenceClient(
            model=values["model"],
            timeout=values["timeout"],
            token=huggingfacehub_api_token,
            **values["server_kwargs"],
        )

        return values


def get_chat_model_wrapper(
    model_id: str,
    inference_server_url: str,
    token: str,
    backend: str = "HFChat",
    **model_init_kwargs: Any,
):
    """Instantiate a chat model for the given model id and backend.

    ``HFChat`` wraps a Hugging Face endpoint via the lazy classes above;
    ``VLLM`` and ``Fireworks`` go through the OpenAI-compatible client.
    """
    backend = LLMBackends(backend)

    if backend == LLMBackends.HFChat:
        llm = LazyHuggingFaceEndpoint(
            repo_id=model_id,
            task="text-generation",
            huggingfacehub_api_token=token,
            **model_init_kwargs,
        )

        from transformers import AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
        chat_model = LazyChatHuggingFace(
            llm=llm, model_id=model_id, tokenizer=tokenizer
        )

    elif backend in [LLMBackends.VLLM, LLMBackends.Fireworks]:
        chat_model = ChatOpenAI(
            model=model_id,
            openai_api_base=inference_server_url,
            openai_api_key=token,
            **model_init_kwargs,
        )

    else:
        raise ValueError(f"Backend {backend} not supported")

    return chat_model
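

# Example usage (a minimal sketch -- the model id, URL, and token below are
# placeholder values, not part of this module):
#
#     chat_model = get_chat_model_wrapper(
#         model_id="mistralai/Mistral-7B-Instruct-v0.2",
#         inference_server_url="http://localhost:8000/v1",
#         token="<api-token>",
#         backend="VLLM",
#     )
#     chat_model.invoke("Hello!")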