from enum import Enum
from typing import Any, Dict

# from langchain_community.chat_models.huggingface import ChatHuggingFace
# from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint
from langchain_core import pydantic_v1
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.utils import get_from_dict_or_env
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from langchain_openai import ChatOpenAI


class LLMBackends(Enum):
    """LLMBackends

    Enum for LLMBackends.
    """

    VLLM = "VLLM"
    HFChat = "HFChat"
    Fireworks = "Fireworks"


class LazyChatHuggingFace(ChatHuggingFace):
    """ChatHuggingFace that initializes lazily.

    The model id is only resolved and the tokenizer only loaded when they
    were not supplied by the caller.
    """

    def __init__(self, **kwargs: Any):
        # Bypass ChatHuggingFace.__init__ and call BaseChatModel.__init__
        # directly so that construction does not trigger eager setup.
        BaseChatModel.__init__(self, **kwargs)

        from transformers import AutoTokenizer

        if not self.model_id:
            self._resolve_model_id()

        self.tokenizer = (
            AutoTokenizer.from_pretrained(self.model_id)
            if self.tokenizer is None
            else self.tokenizer
        )


class LazyHuggingFaceEndpoint(HuggingFaceEndpoint):
    """HuggingFaceEndpoint that skips the Hugging Face Hub login.

    A lazy endpoint avoids logging in with ``hf_token``, which may in fact be
    an HF OAuth token that only permits inference, not logging in.
    """

    @pydantic_v1.root_validator(pre=True)
    def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        return super().build_extra(values)

    @pydantic_v1.root_validator()
    def validate_environment(cls, values: Dict) -> Dict:  # noqa: UP006, N805
        """Validate that package is installed and that the API token is valid."""
        try:
            from huggingface_hub import AsyncInferenceClient, InferenceClient

        except ImportError:
            msg = (
                "Could not import huggingface_hub python package. "
                "Please install it with `pip install huggingface_hub`."
            )
            raise ImportError(msg)  # noqa: B904

        huggingfacehub_api_token = get_from_dict_or_env(
            values, "huggingfacehub_api_token", "HF_TOKEN"
        )

        values["client"] = InferenceClient(
            model=values["model"],
            timeout=values["timeout"],
            token=huggingfacehub_api_token,
            **values["server_kwargs"],
        )
        values["async_client"] = AsyncInferenceClient(
            model=values["model"],
            timeout=values["timeout"],
            token=huggingfacehub_api_token,
            **values["server_kwargs"],
        )

        return values


def get_chat_model_wrapper(
    model_id: str,
    inference_server_url: str,
    token: str,
    backend: str = "HFChat",
    **model_init_kwargs,
):
    """Build a chat model for the given backend.

    ``HFChat`` returns a ``LazyChatHuggingFace`` backed by a
    ``LazyHuggingFaceEndpoint``; ``VLLM`` and ``Fireworks`` return a
    ``ChatOpenAI`` client pointed at the OpenAI-compatible server.
    """
    backend = LLMBackends(backend)

    if backend == LLMBackends.HFChat:
        # llm = LazyHuggingFaceEndpoint(
        #     endpoint_url=inference_server_url,
        #     task="text-generation",
        #     huggingfacehub_api_token=token,
        #     **model_init_kwargs,
        # )

        llm = LazyHuggingFaceEndpoint(
            repo_id=model_id,
            task="text-generation",
            huggingfacehub_api_token=token,
            **model_init_kwargs,
        )

        from transformers import AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
        chat_model = LazyChatHuggingFace(
            llm=llm, model_id=model_id, tokenizer=tokenizer
        )

    elif backend in [LLMBackends.VLLM, LLMBackends.Fireworks]:
        chat_model = ChatOpenAI(
            model=model_id,
            openai_api_base=inference_server_url,  # type: ignore
            openai_api_key=token,  # type: ignore
            **model_init_kwargs,
        )

    else:
        raise ValueError(f"Backend {backend} not supported")

    return chat_model
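

# Minimal usage sketch (not part of the module API). The model id below is a
# placeholder, and HF_TOKEN is assumed to hold a valid Hugging Face token.
if __name__ == "__main__":
    import os

    chat_model = get_chat_model_wrapper(
        model_id="HuggingFaceH4/zephyr-7b-beta",  # placeholder chat model
        inference_server_url="",  # unused by the HFChat backend
        token=os.environ["HF_TOKEN"],
        backend="HFChat",
    )
    # For an OpenAI-compatible server, pass backend="VLLM" or "Fireworks"
    # together with the server URL and API key instead.
    print(chat_model.invoke("Say hello in one sentence.").content)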