#!/usr/bin/env python
# -*- encoding: utf-8 -*-
'''
@File    :   zhipuai_llm.py
@Time    :   2023/10/16 22:06:26
@Author  :   0-yy-0
@Version :   1.0
@Contact :   310484121@qq.com
@License :   (C)Copyright 2017-2018, Liugroup-NLPR-CASIA
@Desc    :   Custom LLM class built on top of the Zhipu AI large language models
'''

from __future__ import annotations

import logging
import re
from typing import (
    Any,
    AsyncIterator,
    Dict,
    Iterator,
    List,
    Optional,
)

from langchain.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain.llms.base import LLM
from langchain.pydantic_v1 import Field, root_validator
from langchain.schema.output import GenerationChunk
from langchain.utils import get_from_dict_or_env

from project.llm.self_llm import Self_LLM

logger = logging.getLogger(__name__)


class ZhipuAILLM(Self_LLM):
    """Zhipu AI hosted open-source or customized models.

    To use, you should have the ``zhipuai`` python package installed and the
    environment variable ``ZHIPUAI_API_KEY`` set with your API key (or pass
    ``zhipuai_api_key`` directly). You can get an API key from
    https://open.bigmodel.cn/usercenter/apikeys

    Example:
        .. code-block:: python

            from project.llm.zhipuai_llm import ZhipuAILLM

            zhipuai_model = ZhipuAILLM(model="chatglm_std", temperature=temperature)
    """

    model_kwargs: Dict[str, Any] = Field(default_factory=dict)

    client: Any

    model: str = "chatglm_std"
    """Model name: one of chatglm_pro, chatglm_std, chatglm_lite."""

    zhipuai_api_key: Optional[str] = None

    incremental: Optional[bool] = True
    """Whether to return the results incrementally."""

    streaming: Optional[bool] = False
    """Whether to stream the results."""
    # streaming = -incremental

    request_timeout: Optional[int] = 60
    """Request timeout for chat HTTP requests, in seconds."""

    top_p: Optional[float] = 0.8
    temperature: Optional[float] = 0.95
    request_id: Optional[str] = None

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve the API key and build the ``zhipuai`` client."""
        values["zhipuai_api_key"] = get_from_dict_or_env(
            values,
            "zhipuai_api_key",
            "ZHIPUAI_API_KEY",
        )
        try:
            from zhipuai import ZhipuAI

            values["client"] = ZhipuAI(api_key=values["zhipuai_api_key"])
        except ImportError:
            raise ValueError(
                "zhipuai package not found, please install it with "
                "`pip install zhipuai`"
            )
        return values

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        return {
            **{"model": self.model},
            **super()._identifying_params,
        }

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "zhipuai"

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling the Zhipu AI API."""
        normal_params = {
            "streaming": self.streaming,
            "top_p": self.top_p,
            "temperature": self.temperature,
            "request_id": self.request_id,
        }
        return {**normal_params, **self.model_kwargs}

    def _convert_prompt_msg_params(
        self,
        prompt: str,
        **kwargs: Any,
    ) -> dict:
        return {
            **{"prompt": prompt, "model": self.model},
            **self._default_params,
            **kwargs,
        }
    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to a Zhipu AI model endpoint for each generation with a prompt.

        Args:
            prompt: The prompt to pass into the model.

        Returns:
            The string generated by the model.

        Example:
            .. code-block:: python

                response = zhipuai_model("Tell me a joke.")
        """
        if self.streaming:
            completion = ""
            for chunk in self._stream(prompt, stop, run_manager, **kwargs):
                completion += chunk.text
            return completion

        params = self._convert_prompt_msg_params(prompt, **kwargs)

        # Split the prompt at the last occurrence of the keyword "问题"
        # ("question"): everything before it becomes the system message and
        # the rest becomes the user message. If the keyword is absent, fall
        # back to sending the whole prompt as the user message.
        all_word = params["prompt"]
        keyword = "问题"
        indexes = [match.start() for match in re.finditer(keyword, all_word)]
        if indexes:
            last_index = indexes[-1]
            messages = [
                {"role": "system", "content": all_word[:last_index]},
                {"role": "user", "content": all_word[last_index:]},
            ]
        else:
            messages = [{"role": "user", "content": all_word}]

        # Sampling parameters are deliberately hard-coded here, overriding the
        # instance defaults from ``_default_params``.
        params = {
            "messages": messages,
            "model": self.model,
            "stream": False,
            "top_p": 0.8,
            "temperature": 0.01,
            "request_id": None,
        }
        logger.debug("params: %s", params)

        response_payload = self.client.chat.completions.create(**params)
        logger.debug("response_payload: %s", response_payload)

        return response_payload.choices[0].message.content

    async def _acall(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        if self.streaming:
            completion = ""
            async for chunk in self._astream(prompt, stop, run_manager, **kwargs):
                completion += chunk.text
            return completion

        params = self._convert_prompt_msg_params(prompt, **kwargs)
        # Note: ``async_invoke`` belongs to the legacy (v1) zhipuai SDK and is
        # not available on the v2 ``ZhipuAI`` client created in the validator.
        response = await self.client.async_invoke(**params)

        return response

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        params = self._convert_prompt_msg_params(prompt, **kwargs)

        # Note: ``invoke`` belongs to the legacy (v1) zhipuai SDK and is not
        # available on the v2 ``ZhipuAI`` client created in the validator.
        for res in self.client.invoke(**params):
            if res:
                chunk = GenerationChunk(text=res)
                yield chunk
                if run_manager:
                    run_manager.on_llm_new_token(chunk.text)

    async def _astream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[GenerationChunk]:
        params = self._convert_prompt_msg_params(prompt, **kwargs)

        # Note: ``ado`` belongs to the legacy (v1) zhipuai SDK and is not
        # available on the v2 ``ZhipuAI`` client created in the validator.
        async for res in await self.client.ado(**params):
            if res:
                chunk = GenerationChunk(text=res["data"]["choices"]["content"])
                yield chunk
                if run_manager:
                    await run_manager.on_llm_new_token(chunk.text)
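
# Usage sketch (illustrative, not part of the original module): a minimal way
# to exercise the class from the command line, assuming ZHIPUAI_API_KEY is set
# in the environment and Self_LLM supplies the remaining defaults. The model
# name mirrors the class default; the prompt follows the "问题" ("question")
# convention that ``_call`` uses to split system context from the user turn.
if __name__ == "__main__":
    zhipuai_model = ZhipuAILLM(model="chatglm_std", temperature=0.1)
    # Text before the last "问题" becomes the system message; the rest is the
    # user message.
    print(zhipuai_model("你是一个乐于助人的助手。问题：请讲一个笑话。"))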