hellow-langChain / project /llm /zhipuai_llm.py
guangliang.yin
提示词优化-3
bf8722b
raw
history blame
No virus
7.15 kB
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
'''
@File : zhipuai_llm.py
@Time : 2023/10/16 22:06:26
@Author : 0-yy-0
@Version : 1.0
@Contact : 310484121@qq.com
@License : (C)Copyright 2017-2018, Liugroup-NLPR-CASIA
@Desc : 基于智谱 AI 大模型自定义 LLM 类
'''
from __future__ import annotations
import logging
from typing import (
Any,
AsyncIterator,
Dict,
Iterator,
List,
Optional,
)
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.llms.base import LLM
from langchain.pydantic_v1 import Field, root_validator
from langchain.schema.output import GenerationChunk
from langchain.utils import get_from_dict_or_env
from project.llm.self_llm import Self_LLM
import re
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
class ZhipuAILLM(Self_LLM):
    """LangChain LLM backed by ZhipuAI hosted chat models.

    To use, install the ``zhipuai`` python package and either pass
    ``zhipuai_api_key`` or set the ``ZHIPUAI_API_KEY`` environment variable.
    API keys are available at https://open.bigmodel.cn/usercenter/apikeys

    Prompts are sent as chat messages: the text before the last "问题"
    ("question") marker becomes the system message and the remainder becomes
    the user message (see ``_build_messages``).

    Example:
        .. code-block:: python

            from project.llm.zhipuai_llm import ZhipuAILLM
            zhipuai_model = ZhipuAILLM(model="chatglm_std", temperature=temperature)
    """

    # Extra request arguments merged in by ``_default_params``.
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    # ``zhipuai.ZhipuAI`` client; always (re)built by ``validate_enviroment``.
    client: Any
    # Model name, e.g. chatglm_pro, chatglm_std, chatglm_lite.
    model: str = "chatglm_std"
    # Falls back to the ZHIPUAI_API_KEY environment variable when not given.
    zhipuai_api_key: Optional[str] = None
    # Not read by this class; kept so existing constructor calls keep working.
    incremental: Optional[bool] = True
    # When True, ``_call``/``_acall`` assemble the answer from the streaming API.
    streaming: Optional[bool] = False
    # Request timeout for chat HTTP requests.
    # NOTE(review): not currently passed to the client.
    request_timeout: Optional[int] = 60
    top_p: Optional[float] = 0.8
    temperature: Optional[float] = 0.95
    # Request id forwarded on the streaming path. Typed as ``str`` so string
    # ids are accepted; pydantic still coerces numbers to strings.
    request_id: Optional[str] = None

    @root_validator()
    def validate_enviroment(cls, values: Dict) -> Dict:
        """Resolve the API key and create the ZhipuAI client.

        Raises:
            ValueError: if the ``zhipuai`` package cannot be imported.
        """
        values["zhipuai_api_key"] = get_from_dict_or_env(
            values,
            "zhipuai_api_key",
            "ZHIPUAI_API_KEY",
        )
        try:
            from zhipuai import ZhipuAI
        except ImportError as err:
            raise ValueError(
                "zhipuai package not found, please install it with "
                "`pip install zhipuai`"
            ) from err
        values["client"] = ZhipuAI(api_key=values["zhipuai_api_key"])
        return values

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Parameters that identify this LLM instance (model + base params)."""
        return {"model": self.model, **super()._identifying_params}

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "zhipuai"

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Default request parameters for the ZhipuAI API, plus ``model_kwargs``."""
        normal_params = {
            "streaming": self.streaming,
            "top_p": self.top_p,
            "temperature": self.temperature,
            "request_id": self.request_id,
        }
        return {**normal_params, **self.model_kwargs}

    def _convert_prompt_msg_params(
        self,
        prompt: str,
        **kwargs: Any,
    ) -> dict:
        """Merge prompt, model, default params and per-call kwargs into one dict.

        Not used by the request methods of this class; kept for callers that
        may rely on it.
        """
        return {
            **{"prompt": prompt, "model": self.model},
            **self._default_params,
            **kwargs,
        }

    @staticmethod
    def _build_messages(prompt: str) -> List[Dict[str, str]]:
        """Split ``prompt`` into chat messages at the last "问题" marker.

        Everything before the marker becomes the system message; the marker
        and everything after it becomes the user message. If the marker is
        absent, or starts the prompt, the whole prompt is sent as a single
        user message instead of failing.
        """
        split_at = prompt.rfind("问题")
        if split_at <= 0:
            return [{"role": "user", "content": prompt}]
        return [
            {"role": "system", "content": prompt[:split_at]},
            {"role": "user", "content": prompt[split_at:]},
        ]

    @staticmethod
    def _enforce_stop(text: str, stop: Optional[List[str]]) -> str:
        """Cut ``text`` at the earliest occurrence of any non-empty stop string."""
        patterns = [re.escape(word) for word in (stop or []) if word]
        if not patterns:
            return text
        return re.split("|".join(patterns), text, maxsplit=1)[0]

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to a zhipuai model endpoint for a single prompt.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional stop strings; the output is truncated at the first
                occurrence of any of them.
            run_manager: Callback manager, used by the streaming path.

        Returns:
            The string generated by the model.

        Example:
            .. code-block:: python

                response = zhipuai_model("Tell me a joke.")
        """
        if self.streaming:
            completion = "".join(
                chunk.text
                for chunk in self._stream(prompt, stop, run_manager, **kwargs)
            )
            return self._enforce_stop(completion, stop)

        params = {
            "messages": self._build_messages(prompt),
            "model": self.model,
            "stream": False,
            # NOTE(review): these hard-coded sampling values take precedence
            # over the ``top_p``/``temperature`` fields on this path; kept
            # unchanged to preserve existing outputs.
            "top_p": 0.8,
            "temperature": 0.01,
            "request_id": None,
        }
        logger.debug("params: %s", params)
        response_payload = self.client.chat.completions.create(**params)
        logger.debug("response_payload: %s", response_payload)
        return self._enforce_stop(response_payload.choices[0].message.content, stop)

    async def _acall(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Asynchronous counterpart of ``_call``."""
        if self.streaming:
            completion = "".join(
                [
                    chunk.text
                    async for chunk in self._astream(
                        prompt, stop, run_manager, **kwargs
                    )
                ]
            )
            return self._enforce_stop(completion, stop)

        import asyncio
        import functools

        # ``_call`` uses the client synchronously, so run it in the default
        # executor to avoid blocking the event loop. The non-streaming
        # ``_call`` path does not use a run manager, so None is passed.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None, functools.partial(self._call, prompt, stop, None, **kwargs)
        )

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        """Yield the model's answer incrementally via the streaming chat API.

        Uses the configured ``top_p``/``temperature``/``request_id`` and
        ``model_kwargs``; per-call ``kwargs`` override them. ``stop`` is not
        applied per chunk; ``_call``/``_acall`` truncate the assembled text.
        """
        params = {"model": self.model, **self._default_params, **kwargs}
        # ``streaming`` is a setting of this class, not a request argument.
        params.pop("streaming", None)
        params["messages"] = self._build_messages(prompt)
        params["stream"] = True

        for event in self.client.chat.completions.create(**params):
            if not event.choices:
                continue
            text = event.choices[0].delta.content
            if not text:
                continue
            chunk = GenerationChunk(text=text)
            yield chunk
            if run_manager:
                run_manager.on_llm_new_token(chunk.text)

    async def _astream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[GenerationChunk]:
        """Asynchronous counterpart of ``_stream``.

        Advances the synchronous ``_stream`` generator one chunk at a time in
        the default executor so the event loop is not blocked on network I/O.
        Token callbacks are sent to the async ``run_manager`` here.
        """
        import asyncio

        loop = asyncio.get_running_loop()
        iterator = self._stream(prompt, stop, None, **kwargs)
        done = object()
        while True:
            chunk = await loop.run_in_executor(None, next, iterator, done)
            if chunk is done:
                break
            yield chunk
            if run_manager:
                await run_manager.on_llm_new_token(chunk.text)