JohnSmith9982's picture
Upload 114 files
7bab2f2
raw history blame
No virus
3.48 kB
from ..presets import *
from ..utils import *
from .base_model import BaseLLMModel
class ERNIE_Client(BaseLLMModel):
    """Chat client for Baidu ERNIE (文心一言) models on the Wenxin Workshop API.

    The API Key / Secret Key pair is exchanged for an OAuth access token on
    every request, and the chat endpoint is chosen from ``model_name``.
    """

    # Chat endpoint per supported model; the access token is appended to the
    # trailing ``access_token=`` query parameter.
    _ENDPOINTS = {
        "ERNIE-Bot-turbo": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant?access_token=",
        "ERNIE-Bot": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions?access_token=",
        "ERNIE-Bot-4": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro?access_token=",
    }
    _TOKEN_URL = "https://aip.baidubce.com/oauth/2.0/token"

    def __init__(self, model_name, api_key, secret_key) -> None:
        """Create the client.

        :param model_name: one of the keys of ``_ENDPOINTS``.
        :param api_key: Baidu API Key (AK).
        :param secret_key: Baidu Secret Key (SK).
        :raises Exception: if either key is missing or empty.
        """
        super().__init__(model_name=model_name)
        self.api_key = api_key
        self.api_secret = secret_key
        # Blank config entries are as unusable as missing ones.
        if not self.api_key or not self.api_secret:
            raise Exception("请在配置文件或者环境变量中设置文心一言的API Key 和 Secret Key")
        # None for unsupported models; _chat_url reports that clearly on use.
        self.ERNIE_url = self._ENDPOINTS.get(self.model_name)

    def get_access_token(self):
        """Generate an access token from the AK/SK pair.

        :return: the access token string, or None if the token endpoint's
            reply does not contain one (e.g. invalid credentials).
        """
        # params= URL-encodes the credentials instead of splicing them raw
        # into the query string.
        params = {
            "client_id": self.api_key,
            "client_secret": self.api_secret,
            "grant_type": "client_credentials",
        }
        headers = {
            'Content-Type': 'application/json',
            'Accept': 'application/json'
        }
        response = requests.post(self._TOKEN_URL, params=params, headers=headers, data=json.dumps(""))
        try:
            data = response.json()
        except ValueError:
            return None
        if not isinstance(data, dict):
            return None
        return data.get("access_token")

    def _chat_url(self):
        """Return the chat URL with a fresh access token, or None if no token.

        :raises ValueError: if ``model_name`` has no known endpoint.
        """
        if self.ERNIE_url is None:
            raise ValueError(f"Unsupported ERNIE model: {self.model_name}")
        token = self.get_access_token()
        if token is None:
            return None
        return self.ERNIE_url + token

    def _build_payload(self, stream):
        """Serialize the conversation history into a chat request body."""
        # Messages with role "system" are kept out of "messages" (presumably
        # the endpoint only accepts user/assistant turns — confirm against the
        # Baidu docs); the system prompt goes in the top-level "system" field
        # instead of being silently dropped.
        messages = [msg for msg in self.history if msg["role"] != "system"]
        body = {"messages": messages, "stream": stream}
        if self.system_prompt:
            body["system"] = self.system_prompt
        return json.dumps(body)

    @staticmethod
    def _api_error(data):
        """Return "code: message" if *data* is an API error reply, else None."""
        if isinstance(data, dict) and "error_code" in data:
            return f"{data.get('error_code')}: {data.get('error_msg', '')}"
        return None

    def get_answer_stream_iter(self):
        """Stream the reply, yielding the accumulated text after each chunk.

        Yields an error message instead when the request or the API fails.
        """
        url = self._chat_url()
        if url is None:
            yield STANDARD_ERROR_MSG + GENERAL_ERROR_MSG
            return
        headers = {'Content-Type': 'application/json'}
        # The with-block closes the streamed connection even when the caller
        # stops iterating early.
        with requests.post(url, headers=headers, data=self._build_payload(stream=True), stream=True) as response:
            if response.status_code != 200:
                yield STANDARD_ERROR_MSG + GENERAL_ERROR_MSG
                return
            partial_text = ""
            for raw_line in response.iter_lines():
                if not raw_line:
                    continue
                line = raw_line.decode("utf-8").strip()
                # Stream chunks arrive as "data: {json}"; an error reply is a
                # bare JSON object. Ignore any other event-stream lines.
                if line.startswith("data:"):
                    line = line[len("data:"):]
                elif not line.startswith("{"):
                    continue
                chunk = json.loads(line)
                error = self._api_error(chunk)
                if error is not None:
                    yield STANDARD_ERROR_MSG + error
                    return
                partial_text += chunk.get("result", "")
                yield partial_text

    def get_answer_at_once(self):
        """Fetch the complete reply in one non-streaming request.

        :return: ``(reply_text, len(reply_text))``, or an error string and 0
            on failure.
        """
        url = self._chat_url()
        if url is None:
            return "获取资源错误", 0
        headers = {'Content-Type': 'application/json'}
        # stream=False so the body is a single JSON object response.json()
        # can parse (a streamed body is a sequence of event lines).
        response = requests.post(url, headers=headers, data=self._build_payload(stream=False))
        if response.status_code != 200:
            return "获取资源错误", 0
        data = response.json()
        error = self._api_error(data)
        if error is not None:
            return "获取资源错误: " + error, 0
        result = str(data.get("result", ""))
        return result, len(result)