|
from ..presets import * |
|
from ..utils import * |
|
|
|
from .base_model import BaseLLMModel |
|
|
|
|
|
class ERNIE_Client(BaseLLMModel):
    """Chat client for Baidu ERNIE (Wenxin) models.

    Every request first exchanges the API key / secret key pair for an access
    token, then POSTs the conversation history to the chat endpoint selected
    by ``model_name``.
    """

    # Common prefix of all chat endpoints.
    _CHAT_BASE = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/"
    # Endpoint path per supported model name.
    _MODEL_PATHS = {
        "ERNIE-Bot-turbo": "eb-instant",
        "ERNIE-Bot": "completions",
        "ERNIE-Bot-4": "completions_pro",
    }
    _TOKEN_URL = "https://aip.baidubce.com/oauth/2.0/token"
    # Prefix of each event line in a streamed response.
    _SSE_PREFIX = b"data:"
    _JSON_HEADERS = {'Content-Type': 'application/json'}

    def __init__(self, model_name, api_key, secret_key) -> None:
        """Store the credentials and resolve the chat endpoint for ``model_name``.

        :raises Exception: if either the API key or the secret key is None.
        """
        super().__init__(model_name=model_name)
        self.api_key = api_key
        self.api_secret = secret_key
        if self.api_secret is None or self.api_key is None:
            raise Exception("请在配置文件或者环境变量中设置文心一言的API Key 和 Secret Key")

        # Endpoint URL ending in "access_token=" so the token can be appended.
        # Unknown model names are still accepted here (as before) and leave the
        # URL as None; the error is raised when a request is attempted.
        path = self._MODEL_PATHS.get(self.model_name)
        self.ERNIE_url = (
            None if path is None else f"{self._CHAT_BASE}{path}?access_token="
        )

    def get_access_token(self):
        """Exchange the API key (AK) and secret key (SK) for an access token.

        :return: the access token string, or None if the response is not JSON
            or does not contain an ``access_token`` field (e.g. bad credentials).
        """
        params = {
            "grant_type": "client_credentials",
            "client_id": self.api_key,
            "client_secret": self.api_secret,
        }
        headers = {
            'Content-Type': 'application/json',
            'Accept': 'application/json',
        }
        # ``params`` URL-encodes the credentials, unlike plain concatenation.
        response = requests.request(
            "POST", self._TOKEN_URL, params=params, headers=headers, data=json.dumps("")
        )
        try:
            body = response.json()
        except ValueError:
            # Body was not JSON (e.g. a gateway error page).
            return None
        if not isinstance(body, dict):
            return None
        return body.get("access_token")

    def _chat_url(self):
        """Return the chat URL with a fresh access token, or None if no token was obtained.

        :raises ValueError: if ``model_name`` is not a supported ERNIE model.
        """
        if self.ERNIE_url is None:
            raise ValueError(f"Unsupported ERNIE model: {self.model_name}")
        token = self.get_access_token()
        if token is None:
            return None
        return self.ERNIE_url + token

    def _build_payload(self, stream):
        """Serialize the conversation history into a chat request body."""
        # System-role messages are filtered out of ``messages``. NOTE(review):
        # this means the configured system prompt is never sent; the chat
        # API's separate top-level ``system`` field could carry it — confirm
        # its length limits before enabling.
        messages = [msg for msg in self.history if msg["role"] != "system"]
        return json.dumps({"messages": messages, "stream": stream})

    @staticmethod
    def _error_detail(body):
        """Return ``"<code>: <msg>"`` if *body* is an API error object, else None."""
        if isinstance(body, dict) and "error_code" in body:
            return f"{body.get('error_code')}: {body.get('error_msg', '')}"
        return None

    def get_answer_stream_iter(self):
        """Stream the reply, yielding the accumulated text after every chunk.

        Yields a single error message instead if the token cannot be obtained,
        the HTTP status is not 200, or the server returns an error object.
        """
        url = self._chat_url()
        if url is None:
            yield STANDARD_ERROR_MSG + GENERAL_ERROR_MSG
            return

        # ``with`` closes the streamed connection even if the consumer stops
        # iterating early or an error makes us return mid-stream.
        with requests.request(
            "POST", url, headers=self._JSON_HEADERS,
            data=self._build_payload(stream=True), stream=True,
        ) as response:
            if response.status_code != 200:
                yield STANDARD_ERROR_MSG + GENERAL_ERROR_MSG
                return

            partial_text = ""
            for raw_line in response.iter_lines():
                if not raw_line:
                    continue
                # Event lines start with "data:"; anything else is parsed
                # whole, presumably an error object — verify against API docs.
                if raw_line.startswith(self._SSE_PREFIX):
                    raw_line = raw_line[len(self._SSE_PREFIX):]
                try:
                    chunk = json.loads(raw_line)
                except ValueError:
                    continue  # skip lines that are not JSON
                if not isinstance(chunk, dict):
                    continue
                detail = self._error_detail(chunk)
                if detail is not None:
                    yield f"{STANDARD_ERROR_MSG}{GENERAL_ERROR_MSG} ({detail})"
                    return
                partial_text += chunk.get("result", "")
                yield partial_text

    def get_answer_at_once(self):
        """Request the complete reply in a single, non-streaming call.

        :return: ``(text, len(text))`` on success, or ``("获取资源错误", 0)``
            when the token cannot be obtained, the HTTP status is not 200, or
            the body carries no ``result``.
        """
        url = self._chat_url()
        if url is None:
            return "获取资源错误", 0

        # Ask for a plain JSON body; a streamed (SSE) body cannot be decoded
        # with ``response.json()``.
        response = requests.request(
            "POST", url, headers=self._JSON_HEADERS,
            data=self._build_payload(stream=False),
        )
        if response.status_code != 200:
            return "获取资源错误", 0
        try:
            body = response.json()
        except ValueError:
            return "获取资源错误", 0
        if not isinstance(body, dict) or "result" not in body:
            return "获取资源错误", 0

        text = str(body["result"])
        return text, len(text)
|
|
|
|
|
|