silk-road's picture
Upload 18 files
fee0ada
raw
history blame contribute delete
No virus
3.72 kB
import os
import json
import time
import hashlib
import requests
import copy
from .BaseLLM import BaseLLM
# Default Baichuan credentials, read once at import time from the
# environment. Each is None when the corresponding variable is unset.
BAICHUAN_API_AK, BAICHUAN_API_SK = (
    os.environ.get("BAICHUAN_API_AK"),
    os.environ.get("BAICHUAN_API_SK"),
)
def sign(secret_key, data, time_stamp=None):
    """Compute the MD5 request signature sent in ``X-BC-Signature``.

    The signed string is ``secret_key`` + ``json.dumps(data)`` (default
    settings) + the Unix timestamp as a decimal string.

    Args:
        secret_key: Baichuan API secret key.
        data: Request payload; must be JSON-serialisable.
        time_stamp: Integer Unix timestamp to sign. Defaults to the current
            time. Callers that also send the timestamp in a header should pass
            the same value here so the two agree.

    Returns:
        The 32-character lowercase hexadecimal MD5 digest.
    """
    if time_stamp is None:
        time_stamp = int(time.time())
    json_data = json.dumps(data)
    input_string = secret_key + json_data + str(time_stamp)
    return hashlib.md5(input_string.encode('utf-8')).hexdigest()


def do_request(messages, api_key, secret_key, model="Baichuan2-53B", timeout=120):
    """Send a single chat request to the Baichuan API.

    Args:
        messages: List of ``{"role": ..., "content": ...}`` dicts.
        api_key: Baichuan API key, sent as a Bearer token.
        secret_key: Baichuan secret key used by ``sign`` on the payload.
        model: Model name placed in the request body.
        timeout: Seconds to wait for the HTTP response before giving up.

    Returns:
        The decoded JSON body on HTTP 200, otherwise ``None``. Network errors,
        timeouts and non-JSON bodies also yield ``None`` so callers can treat
        every failure the same way (e.g. retry).
    """
    url = "https://api.baichuan-ai.com/v1/chat"
    data = {
        "model": model,
        "messages": messages
    }
    # Read the clock once: the signed timestamp and the X-BC-Timestamp header
    # must carry the same value. Two separate time.time() calls can straddle a
    # second boundary, and the server presumably checks the signature against
    # the header value.
    time_stamp = int(time.time())
    signature = sign(secret_key, data, time_stamp)
    headers = {
        "Content-Type": "application/json",
        "Authorization": "Bearer " + api_key,
        # NOTE(review): placeholder request id kept as-is; a per-request
        # unique value may be intended.
        "X-BC-Request-Id": "your requestId",
        "X-BC-Timestamp": str(time_stamp),
        "X-BC-Signature": signature,
        "X-BC-Sign-Algo": "MD5",
    }
    try:
        response = requests.post(url, data=json.dumps(data), headers=headers, timeout=timeout)
    except requests.exceptions.RequestException:
        return None
    if response.status_code != 200:
        return None
    try:
        return response.json()
    except ValueError:
        # Body was not valid JSON.
        return None
class BaiChuanAPIGPT(BaseLLM):
    """Chat wrapper that sends a role-play conversation to the Baichuan API.

    ``self.messages`` holds alternating ``user`` / ``assistant`` entries that
    start with ``user``. An even length means the list is empty or ends with
    an assistant turn; an odd length means it ends with a user turn. The
    mutators below maintain this parity invariant. They merge consecutive
    payloads of the same role with a newline instead of appending a second
    entry with that role.
    """

    def __init__(self, model="baichuan-api", api_key=None, secret_key=None, verbose=False, if_trick=True):
        """Create the wrapper.

        Args:
            model: Label stored as ``model_name``; this class does not pass it
                to ``do_request``.
            api_key: Baichuan API key; falls back to ``BAICHUAN_API_AK``.
            secret_key: Baichuan secret key; falls back to ``BAICHUAN_API_SK``.
            verbose: Print progress messages when True.
            if_trick: When true, ``get_response`` injects an extra
                instruction line into the final turn before sending.
        """
        self.if_trick = if_trick
        super().__init__()
        self.api_key = api_key or BAICHUAN_API_AK
        self.secret_key = secret_key or BAICHUAN_API_SK
        self.verbose = verbose
        self.model_name = model
        self.messages = []
        if self.verbose:
            print('model name, ', self.model_name)
        if self.api_key is None or self.secret_key is None:
            print('Please set BAICHUAN_API_AK and BAICHUAN_API_SK')

    def initialize_message(self):
        """Clear the conversation history."""
        self.messages = []

    def ai_message(self, payload):
        """Add an assistant turn.

        On an empty history a placeholder user turn is inserted first, so the
        conversation still opens with ``user``. Then ``payload`` is added as
        an assistant turn. If the history already ends with an assistant
        turn, ``payload`` is appended to it on a new line.
        """
        if not self.messages:
            # Placeholder opener: "Please role-play according to my requirements:".
            # Fall through afterwards so the payload itself is not lost.
            self.user_message("请根据我的要求进行角色扮演:")
        if len(self.messages) % 2 == 1:
            self.messages.append({"role": "assistant", "content": payload})
        else:
            self.messages[-1]["content"] += "\n" + payload

    def system_message(self, payload):
        """Add a system prompt, sent to the API as a ``user`` turn.

        Routed through ``user_message`` so that a system prompt following a
        user turn is merged into it rather than breaking the role alternation.
        """
        self.user_message(payload)

    def user_message(self, payload):
        """Add a user turn, or extend the current one if the last turn is user."""
        if len(self.messages) % 2 == 0:
            self.messages.append({"role": "user", "content": payload})
        else:
            self.messages[-1]["content"] += "\n" + payload

    @staticmethod
    def _extract_reply(response):
        """Return the quote-stripped last message of a response, or None if absent."""
        if not isinstance(response, dict):
            return None
        data = response.get('data')
        if not isinstance(data, dict):
            return None
        messages = data.get('messages')
        if not messages:
            return None
        content = messages[-1].get('content')
        if content is None:
            return None
        return content.strip("\"'")

    def get_response(self):
        """Send the current history and return the assistant's reply.

        Tries up to ``max_try`` times, sleeping ``sleep_interval`` seconds
        between failed attempts. A response without ``data.messages`` counts
        as a failure, as do network errors raised by ``requests``.

        Returns:
            The content of the last entry in the response's ``data.messages``
            with surrounding single/double quotes stripped, or ``None`` if
            every attempt failed.
        """
        max_try = 5
        sleep_interval = 3

        # Work on a copy so the trick instruction never leaks into self.messages.
        chat_messages = copy.deepcopy(self.messages)
        if self.if_trick and chat_messages:
            # Insert "please imitate the classic scenes above when replying"
            # just before the last line of the final turn. The duplicated
            # leading character is kept as-is to avoid changing the prompt.
            lines = chat_messages[-1]["content"].split('\n')
            lines.insert(-1, '请请模仿上述经典桥段进行回复\n')
            chat_messages[-1]["content"] = '\n'.join(lines)

        for attempt in range(max_try):
            try:
                response = do_request(chat_messages, self.api_key, self.secret_key)
            except requests.exceptions.RequestException:
                response = None
            reply = self._extract_reply(response)
            if reply is not None:
                if self.verbose:
                    print('Get Baichuan API response success')
                return reply
            if self.verbose:
                print('Get Baichuan API response failed, retrying...')
            # No point sleeping after the final attempt.
            if attempt < max_try - 1:
                time.sleep(sleep_interval)
        return None

    def print_prompt(self):
        """Print each message of the history as ``role: content``."""
        for message in self.messages:
            print(f"{message['role']}: {message['content']}")