File size: 3,479 Bytes
7bab2f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from ..presets import *
from ..utils import *

from .base_model import BaseLLMModel


class ERNIE_Client(BaseLLMModel):
    """Chat client for Baidu's ERNIE (文心一言) models.

    Exchanges the API Key / Secret Key pair for an access token on every
    request and posts ``self.history`` (plus ``self.system_prompt``) to the
    chat endpoint that matches ``model_name``.
    """

    # Chat endpoint per supported model name. The access token is appended
    # verbatim after the trailing ``access_token=`` when a request is made.
    _MODEL_URLS = {
        "ERNIE-Bot-turbo": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant?access_token=",
        "ERNIE-Bot": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions?access_token=",
        "ERNIE-Bot-4": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro?access_token=",
    }

    # OAuth endpoint that trades the AK/SK pair for an access token.
    _TOKEN_URL = "https://aip.baidubce.com/oauth/2.0/token"

    def __init__(self, model_name, api_key, secret_key) -> None:
        """Store credentials and resolve the chat endpoint for ``model_name``.

        :param model_name: one of the keys of ``_MODEL_URLS``.
        :param api_key: Baidu API Key.
        :param secret_key: Baidu Secret Key.
        :raises Exception: if either key is missing/empty or the model is
            not supported.
        """
        super().__init__(model_name=model_name)
        self.api_key = api_key
        self.api_secret = secret_key
        # Empty strings are rejected too: they can only fail later, at token time.
        if not self.api_key or not self.api_secret:
            raise Exception("请在配置文件或者环境变量中设置文心一言的API Key 和 Secret Key")

        # Fail fast instead of leaving ERNIE_url unset (which used to surface
        # as an AttributeError on the first request).
        if self.model_name not in self._MODEL_URLS:
            raise Exception(f"不支持的文心一言模型: {self.model_name}")
        self.ERNIE_url = self._MODEL_URLS[self.model_name]

    def get_access_token(self):
        """Generate an access token from the AK/SK pair.

        :return: the access token string.
        :raises Exception: if the response carries no ``access_token``
            (e.g. invalid credentials); the message includes the
            ``error`` / ``error_description`` fields of the response.
        """
        # Passing the credentials via ``params`` lets requests URL-encode them
        # instead of splicing raw strings into the query.
        params = {
            "grant_type": "client_credentials",
            "client_id": self.api_key,
            "client_secret": self.api_secret,
        }
        headers = {
            'Content-Type': 'application/json',
            'Accept': 'application/json'
        }

        response = requests.post(self._TOKEN_URL, params=params, headers=headers, data=json.dumps(""))
        data = response.json()
        if "access_token" not in data:
            raise Exception(
                f"获取文心一言 Access Token 失败: {data.get('error')} {data.get('error_description')}"
            )
        return data["access_token"]

    def _build_request(self, stream):
        """Return ``(url, json_payload)`` for a chat call.

        :param stream: value written to the payload's ``stream`` flag.
        """
        url = self.ERNIE_url + self.get_access_token()
        # The ERNIE ``messages`` list takes user/assistant turns only; the
        # persona goes in the top-level ``system`` field (per Baidu's chat
        # API docs). Drop any system-role entries from the history.
        messages = [msg for msg in self.history if msg["role"] != "system"]
        body = {"messages": messages, "stream": stream}
        if self.system_prompt:
            body["system"] = self.system_prompt
        return url, json.dumps(body)

    @staticmethod
    def _decode_event(raw_line):
        """Parse one non-empty line of a streaming response body.

        Stream chunks arrive as ``data: {...}``. API errors arrive as a bare
        JSON object, so the prefix is stripped only when present.
        """
        prefix = b"data:"
        if raw_line.startswith(prefix):
            raw_line = raw_line[len(prefix):]
        return json.loads(raw_line)

    def get_answer_stream_iter(self):
        """Stream the reply, yielding the accumulated text after each chunk.

        Yields an error message instead when the HTTP status is not 200 or
        the API reports an ``error_code``.
        """
        url, payload = self._build_request(stream=True)
        headers = {
            'Content-Type': 'application/json'
        }

        # ``with`` guarantees the streamed connection is released even if the
        # consumer stops iterating early.
        with requests.post(url, headers=headers, data=payload, stream=True) as response:
            if response.status_code != 200:
                yield STANDARD_ERROR_MSG + GENERAL_ERROR_MSG
                return

            partial_text = ""
            for raw_line in response.iter_lines():
                if not raw_line:
                    continue
                event = self._decode_event(raw_line)
                if "error_code" in event:
                    yield (
                        STANDARD_ERROR_MSG + GENERAL_ERROR_MSG
                        + f" ({event['error_code']}: {event.get('error_msg', '')})"
                    )
                    return
                partial_text += event.get("result", "")
                yield partial_text

    def get_answer_at_once(self):
        """Request the whole reply in a single non-streaming call.

        :return: ``(text, len(text))`` — the length is a character count —
            or ``("获取资源错误", 0)`` on an HTTP or API error.
        """
        # Non-streaming: with "stream": True the server answers with SSE
        # lines, which response.json() cannot parse.
        url, payload = self._build_request(stream=False)
        headers = {
            'Content-Type': 'application/json'
        }

        response = requests.post(url, headers=headers, data=payload)
        if response.status_code != 200:
            return "获取资源错误", 0

        data = response.json()
        if "error_code" in data:
            return "获取资源错误", 0
        text = str(data["result"])
        return text, len(text)