# Uploaded by silk-road: "Upload 18 files" (commit fee0ada).
import os
import json
import time
import hashlib
import requests
import copy
from .BaseLLM import BaseLLM
# Credentials read once, at import time, from the environment. They are the
# fallbacks BaiChuanAPIGPT uses when no explicit keys are passed in; either may
# be None when the variable is unset.
BAICHUAN_API_AK = os.environ.get("BAICHUAN_API_AK")  # API key, sent as the Bearer token
BAICHUAN_API_SK = os.environ.get("BAICHUAN_API_SK")  # secret used to compute the request signature
def sign(secret_key, data):
json_data = json.dumps(data)
time_stamp = int(time.time())
input_string = secret_key + json_data + str(time_stamp)
md5 = hashlib.md5()
md5.update(input_string.encode('utf-8'))
encrypted = md5.hexdigest()
return encrypted
def do_request(messages, api_key, secret_key):
url = "https://api.baichuan-ai.com/v1/chat"
data = {
"model": "Baichuan2-53B",
"messages": messages
}
signature = sign(secret_key, data)
headers = {
"Content-Type": "application/json",
"Authorization": "Bearer " + api_key,
"X-BC-Request-Id": "your requestId",
"X-BC-Timestamp": str(int(time.time())),
"X-BC-Signature": signature,
"X-BC-Sign-Algo": "MD5",
}
response = requests.post(url, data=json.dumps(data), headers=headers)
if response.status_code == 200:
return response.json()
else:
return None
class BaiChuanAPIGPT(BaseLLM):
    """Chat wrapper that sends its conversation through ``do_request``.

    The conversation lives in ``self.messages``: a list of
    ``{"role": ..., "content": ...}`` dicts that alternates ``"user"`` /
    ``"assistant"``, starting with ``"user"``. An odd length therefore means
    the last turn is a user turn. A message of the same role as the current
    last turn is merged into it on a new line, which keeps the alternation.

    Args:
        model: Name stored as ``model_name`` (used for display). The model
            sent to the API is chosen by ``do_request``.
        api_key: Baichuan API key; falls back to ``BAICHUAN_API_AK``.
        secret_key: Baichuan signing secret; falls back to ``BAICHUAN_API_SK``.
        verbose: Print progress messages.
        if_trick: When true, ``get_response`` inserts an extra instruction
            line before the last line of the final message it sends.
    """

    # Opening user turn used when an assistant message is recorded into an
    # empty history ("please role-play according to my requirements:").
    _ROLEPLAY_OPENER = "请根据我的要求进行角色扮演:"
    # Instruction injected by the trick ("please reply by imitating the
    # classic scenes above").
    _TRICK_LINE = '请模仿上述经典桥段进行回复\n'

    def __init__(self, model="baichuan-api", api_key=None, secret_key=None, verbose=False, if_trick=True):
        # Assigned before BaseLLM.__init__ runs, preserving the original order.
        self.if_trick = if_trick
        super(BaiChuanAPIGPT, self).__init__()
        self.api_key = api_key or BAICHUAN_API_AK
        self.secret_key = secret_key or BAICHUAN_API_SK
        self.verbose = verbose
        self.model_name = model
        self.messages = []
        if self.verbose:
            print('model name, ', self.model_name)
        if self.api_key is None or self.secret_key is None:
            print('Please set BAICHUAN_API_AK and BAICHUAN_API_SK')

    def initialize_message(self):
        """Clear the conversation history."""
        self.messages = []

    def ai_message(self, payload):
        """Record an assistant turn.

        If the history is empty, a fixed role-play opener is first added as
        the user turn, so the conversation still starts with ``"user"``. If
        the last turn is already an assistant turn, *payload* is appended to
        it on a new line.
        """
        if not self.messages:
            self.user_message(self._ROLEPLAY_OPENER)
            # Fall through so that *payload* is still recorded below.
        if len(self.messages) % 2 == 1:
            self.messages.append({"role": "assistant", "content": payload})
        else:
            self.messages[-1]["content"] += "\n" + payload

    def system_message(self, payload):
        """Record a system prompt.

        This wrapper sends the system prompt as a user turn. Routing it
        through ``user_message`` merges it into an existing trailing user
        turn instead of creating two consecutive user messages.
        """
        self.user_message(payload)

    def user_message(self, payload):
        """Record a user turn.

        Starts a new ``"user"`` message when the history has even length
        (empty, or the last turn is an assistant turn). Otherwise appends
        *payload* to the current user turn on a new line.
        """
        if len(self.messages) % 2 == 0:
            self.messages.append({"role": "user", "content": payload})
        else:
            self.messages[-1]["content"] += "\n" + payload

    @staticmethod
    def _extract_reply(response):
        """Return the quote-stripped content of the last message in *response*.

        Returns ``None`` when *response* is ``None`` or lacks a non-empty
        ``data.messages`` list whose last item has string ``content``.
        """
        try:
            content = response["data"]["messages"][-1]["content"]
        except (KeyError, IndexError, TypeError):
            return None
        if not isinstance(content, str):
            return None
        return content.strip("\"'")

    def get_response(self, max_try=5, sleep_interval=3):
        """Send the conversation and return the model's reply.

        Args:
            max_try: Maximum number of request attempts.
            sleep_interval: Seconds to sleep between failed attempts.

        Returns:
            The content of the last returned message, with surrounding single
            and double quotes stripped. Returns ``None`` if every attempt
            fails. An attempt fails when ``do_request`` returns ``None`` or
            the body has no usable ``data.messages``.
        """
        # Work on a copy so the trick line never leaks into self.messages.
        chat_messages = copy.deepcopy(self.messages)
        if self.if_trick and chat_messages:
            lines = chat_messages[-1]["content"].split('\n')
            # Place the instruction just before the last line of the final message.
            lines.insert(-1, self._TRICK_LINE)
            chat_messages[-1]["content"] = '\n'.join(lines)

        for attempt in range(max_try):
            response = do_request(chat_messages, self.api_key, self.secret_key)
            reply = self._extract_reply(response)
            if reply is not None:
                if self.verbose:
                    print('Get Baichuan API response success')
                return reply
            if self.verbose:
                print('Get Baichuan API response failed, retrying...')
            if attempt < max_try - 1:
                time.sleep(sleep_interval)
        return None

    def print_prompt(self):
        """Print every stored message as ``role: content``."""
        for message in self.messages:
            print(f"{message['role']}: {message['content']}")