|
|
|
|
|
''' |
|
@File : zhipuai_llm.py |
|
@Time : 2023/10/16 22:06:26 |
|
@Author : 0-yy-0 |
|
@Version : 1.0 |
|
@Contact : 310484121@qq.com |
|
@License : (C)Copyright 2017-2018, Liugroup-NLPR-CASIA |
|
@Desc : 基于智谱 AI 大模型自定义 LLM 类 |
|
''' |
|
|
|
from __future__ import annotations |
|
|
|
import logging |
|
from typing import ( |
|
Any, |
|
AsyncIterator, |
|
Dict, |
|
Iterator, |
|
List, |
|
Optional, |
|
) |
|
|
|
from langchain.callbacks.manager import ( |
|
AsyncCallbackManagerForLLMRun, |
|
CallbackManagerForLLMRun, |
|
) |
|
from langchain.llms.base import LLM |
|
from langchain.pydantic_v1 import Field, root_validator |
|
from langchain.schema.output import GenerationChunk |
|
from langchain.utils import get_from_dict_or_env |
|
from llm.self_llm import Self_LLM |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
class ZhipuAILLM(Self_LLM): |
|
"""Zhipuai hosted open source or customized models. |
|
|
|
To use, you should have the ``zhipuai`` python package installed, and |
|
the environment variable ``zhipuai_api_key`` set with |
|
your API key and Secret Key. |
|
|
|
zhipuai_api_key are required parameters which you could get from |
|
https://open.bigmodel.cn/usercenter/apikeys |
|
|
|
Example: |
|
.. code-block:: python |
|
|
|
from langchain.llms import ZhipuAILLM |
|
zhipuai_model = ZhipuAILLM(model="chatglm_std", temperature=temperature) |
|
|
|
""" |
|
|
|
model_kwargs: Dict[str, Any] = Field(default_factory=dict) |
|
|
|
client: Any |
|
|
|
model: str = "chatglm_std" |
|
"""Model name in chatglm_pro, chatglm_std, chatglm_lite. """ |
|
|
|
zhipuai_api_key: Optional[str] = None |
|
|
|
incremental: Optional[bool] = True |
|
"""Whether to incremental the results or not.""" |
|
|
|
streaming: Optional[bool] = False |
|
"""Whether to streaming the results or not.""" |
|
|
|
|
|
request_timeout: Optional[int] = 60 |
|
"""request timeout for chat http requests""" |
|
|
|
top_p: Optional[float] = 0.8 |
|
temperature: Optional[float] = 0.95 |
|
request_id: Optional[float] = None |
|
|
|
@root_validator() |
|
def validate_enviroment(cls, values: Dict) -> Dict: |
|
|
|
values["zhipuai_api_key"] = get_from_dict_or_env( |
|
values, |
|
"zhipuai_api_key", |
|
"ZHIPUAI_API_KEY", |
|
) |
|
|
|
params = { |
|
"zhipuai_api_key": values["zhipuai_api_key"], |
|
"model": values["model"], |
|
} |
|
try: |
|
import zhipuai |
|
|
|
zhipuai.api_key = values["zhipuai_api_key"] |
|
values["client"] = zhipuai.model_api |
|
except ImportError: |
|
raise ValueError( |
|
"zhipuai package not found, please install it with " |
|
"`pip install zhipuai`" |
|
) |
|
return values |
|
|
|
@property |
|
def _identifying_params(self) -> Dict[str, Any]: |
|
return { |
|
**{"model": self.model}, |
|
**super()._identifying_params, |
|
} |
|
|
|
@property |
|
def _llm_type(self) -> str: |
|
"""Return type of llm.""" |
|
return "zhipuai" |
|
|
|
@property |
|
def _default_params(self) -> Dict[str, Any]: |
|
"""Get the default parameters for calling OpenAI API.""" |
|
normal_params = { |
|
"streaming": self.streaming, |
|
"top_p": self.top_p, |
|
"temperature": self.temperature, |
|
"request_id": self.request_id, |
|
} |
|
|
|
return {**normal_params, **self.model_kwargs} |
|
|
|
def _convert_prompt_msg_params( |
|
self, |
|
prompt: str, |
|
**kwargs: Any, |
|
) -> dict: |
|
return { |
|
**{"prompt": prompt, "model": self.model}, |
|
**self._default_params, |
|
**kwargs, |
|
} |
|
|
|
def _call( |
|
self, |
|
prompt: str, |
|
stop: Optional[List[str]] = None, |
|
run_manager: Optional[CallbackManagerForLLMRun] = None, |
|
**kwargs: Any, |
|
) -> str: |
|
"""Call out to an zhipuai models endpoint for each generation with a prompt. |
|
Args: |
|
prompt: The prompt to pass into the model. |
|
Returns: |
|
The string generated by the model. |
|
|
|
Example: |
|
.. code-block:: python |
|
response = zhipuai_model("Tell me a joke.") |
|
""" |
|
if self.streaming: |
|
completion = "" |
|
for chunk in self._stream(prompt, stop, run_manager, **kwargs): |
|
completion += chunk.text |
|
return completion |
|
params = self._convert_prompt_msg_params(prompt, **kwargs) |
|
|
|
response_payload = self.client.invoke(**params) |
|
return response_payload["data"]["choices"][-1]["content"].strip('"').strip(" ") |
|
|
|
async def _acall( |
|
self, |
|
prompt: str, |
|
stop: Optional[List[str]] = None, |
|
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, |
|
**kwargs: Any, |
|
) -> str: |
|
if self.streaming: |
|
completion = "" |
|
async for chunk in self._astream(prompt, stop, run_manager, **kwargs): |
|
completion += chunk.text |
|
return completion |
|
|
|
params = self._convert_prompt_msg_params(prompt, **kwargs) |
|
|
|
response = await self.client.async_invoke(**params) |
|
|
|
return response_payload |
|
|
|
def _stream( |
|
self, |
|
prompt: str, |
|
stop: Optional[List[str]] = None, |
|
run_manager: Optional[CallbackManagerForLLMRun] = None, |
|
**kwargs: Any, |
|
) -> Iterator[GenerationChunk]: |
|
params = self._convert_prompt_msg_params(prompt, **kwargs) |
|
|
|
for res in self.client.invoke(**params): |
|
if res: |
|
chunk = GenerationChunk(text=res) |
|
yield chunk |
|
if run_manager: |
|
run_manager.on_llm_new_token(chunk.text) |
|
|
|
async def _astream( |
|
|
|
self, |
|
prompt: str, |
|
stop: Optional[List[str]] = None, |
|
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, |
|
**kwargs: Any, |
|
) -> AsyncIterator[GenerationChunk]: |
|
params = self._convert_prompt_msg_params(prompt, **kwargs) |
|
|
|
async for res in await self.client.ado(**params): |
|
if res: |
|
chunk = GenerationChunk(text=res["data"]["choices"]["content"]) |
|
|
|
yield chunk |
|
if run_manager: |
|
await run_manager.on_llm_new_token(chunk.text) |
|
|