WordsApp / llm /wenxin_llm.py
dht-tb16p
Commit 1st version
e60c070
raw
history blame
3.24 kB
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
'''
@File : wenxin_llm.py
@Time : 2023/10/16 18:53:26
@Author : Logan Zou
@Version : 1.0
@Contact : loganzou0421@163.com
@License : (C)Copyright 2017-2018, Liugroup-NLPR-CASIA
@Desc : 基于百度文心大模型自定义 LLM 类
'''
from langchain.llms.base import LLM
from typing import Any, List, Mapping, Optional, Dict, Union, Tuple
from pydantic import Field
from llm.self_llm import Self_LLM
import json
import requests
from langchain.callbacks.manager import CallbackManagerForLLMRun
# 调用文心 API 的工具函数
def get_access_token(api_key : str, secret_key : str):
"""
使用 API Key,Secret Key 获取access_token,替换下列示例中的应用API Key、应用Secret Key
"""
# 指定网址
url = f"https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id={api_key}&client_secret={secret_key}"
# 设置 POST 访问
payload = json.dumps("")
headers = {
'Content-Type': 'application/json',
'Accept': 'application/json'
}
# 通过 POST 访问获取账户对应的 access_token
response = requests.request("POST", url, headers=headers, data=payload)
return response.json().get("access_token")
class Wenxin_LLM(Self_LLM):
# 文心大模型的自定义 LLM
# URL
url : str = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant?access_token={}"
# Secret_Key
secret_key : str = None
# access_token
access_token: str = None
def init_access_token(self):
if self.api_key != None and self.secret_key != None:
# 两个 Key 均非空才可以获取 access_token
try:
self.access_token = get_access_token(self.api_key, self.secret_key)
except Exception as e:
print(e)
print("获取 access_token 失败,请检查 Key")
else:
print("API_Key 或 Secret_Key 为空,请检查 Key")
def _call(self, prompt : str, stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any):
# 如果 access_token 为空,初始化 access_token
if self.access_token == None:
self.init_access_token()
# API 调用 url
url = self.url.format(self.access_token)
# 配置 POST 参数
payload = json.dumps({
"messages": [
{
"role": "user",# user prompt
"content": "{}".format(prompt)# 输入的 prompt
}
],
'temperature' : self.temperature
})
headers = {
'Content-Type': 'application/json'
}
# 发起请求
response = requests.request("POST", url, headers=headers, data=payload, timeout=self.request_timeout)
if response.status_code == 200:
# 返回的是一个 Json 字符串
js = json.loads(response.text)
# print(js)
return js["result"]
else:
return "请求失败"
@property
def _llm_type(self) -> str:
return "Wenxin"