File size: 2,558 Bytes
fee0ada
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
# ErnieGPT.py
from pyexpat import model
import erniebot 
#以下密钥信息从os环境获取
import os
import copy

# appid = os.environ['APPID']
# api_secret = os.environ['APISecret'] 
# api_key = os.environ['APIKey']
erniebot.api_type = os.environ["APIType"]
erniebot.access_token = os.environ["ErnieAccess"]

from .BaseLLM import BaseLLM

class ErnieGPT(BaseLLM):

    def __init__(self,model="ernie-bot", ernie_trick = True ):
        super(ErnieGPT,self).__init__()
        self.model = model
        if model not in ["ernie-bot", "ernie-bot-turbo", "ernie-vilg-v2", "ernie-text-embedding", "ernie-bot-8k", "ernie-bot-4"]:
            raise Exception("Unknown Ernie model")
        # SparkApi.answer =""
        self.messages = []

        self.ernie_trick = ernie_trick
        

    def initialize_message(self):
        self.messages = []

    def ai_message(self, payload):
        if len(self.messages) == 0:
            self.user_message("请根据我的要求进行角色扮演:")
        elif len(self.messages) % 2 == 1:
            self.messages.append({"role":"assistant","content":payload})
        elif len(self.messages)% 2 == 0:
            self.messages[-1]["content"] += "\n"+ payload

    def system_message(self, payload):
        
        self.messages.append({"role":"user","content":payload}) 
        

    def user_message(self, payload):
        if len(self.messages) % 2 == 0:
            self.messages.append({"role":"user","content":payload})
            # self.messages[-1]["content"] += 
        elif len(self.messages)% 2 == 1:
            self.messages[-1]["content"] += "\n"+ payload

    def get_response(self):
        # question = checklen(getText("user",Input))
        chat_messages = copy.deepcopy(self.messages)

        lines = chat_messages[-1]["content"].split('\n')

        if self.ernie_trick:
            lines.insert(-1, '请请模仿上述经典桥段进行回复\n')
        
        chat_messages[-1]["content"] = '\n'.join(lines)

        # chat_messages[-1]["content"] = "请请模仿上述经典桥段进行回复\n" + chat_messages[-1]["content"] 
        response = erniebot.ChatCompletion.create(model=self.model, messages=chat_messages)
        # message_json = [{"role": "user", "content": self.messages}]
        # SparkApi.answer =""
        # SparkApi.main(appid,api_key,api_secret,self.Spark_url,self.domain,message_json)
        return response["result"]
    
    def print_prompt(self):
        for message in self.messages:
            print(f"{message['role']}: {message['content']}")