silk-road's picture
Upload 18 files
fee0ada
# ErnieGPT.py
from pyexpat import model
import erniebot
#以下密钥信息从os环境获取
import os
import copy
# The erniebot SDK is configured globally when this module is imported.
# Both values are read from environment variables; if either is missing,
# os.environ[...] raises KeyError and the import of this module fails.
#   APIType     -> erniebot.api_type
#   ErnieAccess -> erniebot.access_token
erniebot.api_type = os.environ["APIType"]
erniebot.access_token = os.environ["ErnieAccess"]
from .BaseLLM import BaseLLM
class ErnieGPT(BaseLLM):
    """Chat wrapper around ``erniebot.ChatCompletion``.

    The conversation is kept in ``self.messages`` as ``{"role", "content"}``
    dicts that alternate user / assistant, starting with a user turn. Every
    method relies on that invariant:

    * odd length  -> the last message is a user turn;
    * even length -> the list is empty or the last message is an assistant turn.

    A new message whose role matches the last turn is merged into it with a
    newline instead of being appended, so the alternation is preserved.
    """

    # Model names accepted by the constructor.
    # NOTE(review): "ernie-vilg-v2" and "ernie-text-embedding" do not look like
    # chat models; confirm ChatCompletion actually accepts them.
    SUPPORTED_MODELS = frozenset({
        "ernie-bot",
        "ernie-bot-turbo",
        "ernie-vilg-v2",
        "ernie-text-embedding",
        "ernie-bot-8k",
        "ernie-bot-4",
    })

    # Placeholder user turn ("Please role-play according to my requirements:")
    # used when an assistant message arrives before any user message.
    _OPENING_USER_PROMPT = "请根据我的要求进行角色扮演:"

    # Instruction ("imitate the classic scenes above when replying") spliced
    # into the final message when ``ernie_trick`` is enabled.
    # NOTE(review): the doubled "请请" looks like a typo; kept verbatim because
    # it is prompt text sent to the model.
    _TRICK_PROMPT = "请请模仿上述经典桥段进行回复\n"

    def __init__(self, model="ernie-bot", ernie_trick=True):
        """Create a client for one ERNIE model.

        Args:
            model: Model name passed to ``erniebot.ChatCompletion.create``;
                must be one of ``SUPPORTED_MODELS``.
            ernie_trick: When true, ``get_response`` inserts ``_TRICK_PROMPT``
                into the last message of each request.

        Raises:
            ValueError: If ``model`` is not in ``SUPPORTED_MODELS``.
        """
        super().__init__()
        if model not in self.SUPPORTED_MODELS:
            raise ValueError(f"Unknown Ernie model: {model!r}")
        self.model = model
        self.messages = []
        self.ernie_trick = ernie_trick

    def initialize_message(self):
        """Discard the stored conversation."""
        self.messages = []

    def ai_message(self, payload):
        """Add an assistant turn, merging it into a preceding assistant turn.

        The conversation must start with a user turn, so if the history is
        empty a placeholder user message is inserted first. The payload is
        then stored (previously it was silently dropped in that case).
        """
        if not self.messages:
            self.user_message(self._OPENING_USER_PROMPT)
        if len(self.messages) % 2 == 1:
            # Last turn is a user turn: start a new assistant turn.
            self.messages.append({"role": "assistant", "content": payload})
        else:
            # Last turn is already an assistant turn: merge into it.
            self.messages[-1]["content"] += "\n" + payload

    def system_message(self, payload):
        """Add a system prompt.

        The prompt is stored with role "user", and is routed through
        ``user_message`` so that it merges into a trailing user turn instead
        of breaking the user/assistant alternation.
        """
        self.user_message(payload)

    def user_message(self, payload):
        """Add a user turn, merging it into a preceding user turn."""
        if len(self.messages) % 2 == 0:
            # Empty history or last turn is an assistant turn: start a user turn.
            self.messages.append({"role": "user", "content": payload})
        else:
            # Last turn is already a user turn: merge into it.
            self.messages[-1]["content"] += "\n" + payload

    def get_response(self):
        """Send the conversation to ERNIE and return the reply text.

        A deep copy of the history is sent, so the ``ernie_trick`` instruction
        affects only this request and never the stored ``self.messages``.

        Returns:
            ``response["result"]`` from ``erniebot.ChatCompletion.create``.

        Raises:
            IndexError: If ``ernie_trick`` is enabled and no messages have
                been added yet.
        """
        chat_messages = copy.deepcopy(self.messages)
        if self.ernie_trick:
            # Insert the instruction just before the last line of the final
            # message; for a single-line message it lands in front of it.
            lines = chat_messages[-1]["content"].split("\n")
            lines.insert(-1, self._TRICK_PROMPT)
            chat_messages[-1]["content"] = "\n".join(lines)
        response = erniebot.ChatCompletion.create(model=self.model, messages=chat_messages)
        return response["result"]

    def print_prompt(self):
        """Print every stored message as ``role: content``."""
        for message in self.messages:
            print(f"{message['role']}: {message['content']}")