File size: 6,051 Bytes
b28a1a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import json
import os

import colorama
import requests
import logging

from modules.models.base_model import BaseLLMModel
from modules.presets import STANDARD_ERROR_MSG, GENERAL_ERROR_MSG, TIMEOUT_STREAMING, TIMEOUT_ALL, i18n

# MiniMax account group id, read once at import time from the environment;
# defaults to an empty string when MINIMAX_GROUP_ID is unset.
group_id = os.getenv("MINIMAX_GROUP_ID", "")


class MiniMax_Client(BaseLLMModel):
    """
    MiniMax chat-completion client.

    API documentation: https://api.minimax.chat/document/guides/chat
    """

    def __init__(self, model_name, api_key, user_name="", system_prompt=None):
        super().__init__(model_name=model_name, user=user_name)
        # group_id is read once from MINIMAX_GROUP_ID at module import time.
        self.url = f'https://api.minimax.chat/v1/text/chatcompletion?GroupId={group_id}'
        self.history = []
        self.api_key = api_key
        self.system_prompt = system_prompt
        # Kept for callers that may read it; outgoing requests build their
        # headers from self.api_key at send time (see _get_response).
        self.headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json"
        }

    def _minimax_temperature(self):
        """Map the base model temperature onto MiniMax's scale.

        MiniMax accepts a temperature in (0, 1] while the base model uses
        [0, 2]; base 1.0 corresponds to MiniMax 0.9. Values in [0, 1] are
        scaled by 0.9, and values in (1, 2] map linearly onto (0.9, 1.0].
        """
        if self.temperature <= 1:
            return self.temperature * 0.9
        return 0.9 + (self.temperature - 1) / 10

    def get_answer_at_once(self):
        """Request a complete (non-streaming) reply for the whole history.

        Uses the same request construction as the streaming path, so the full
        conversation, role metadata, token limit and timeout are all applied.

        Returns:
            tuple: ``(answer, total_token_count)`` taken from the ``reply`` and
            ``usage.total_tokens`` fields of the response. If the HTTP request
            itself fails, returns the generic error message and 0 tokens.
        """
        response = self._get_response(stream=False)
        if response is None:
            return STANDARD_ERROR_MSG + GENERAL_ERROR_MSG, 0

        res = response.json()
        answer = res['reply']
        total_token_count = res["usage"]["total_tokens"]
        return answer, total_token_count

    def get_answer_stream_iter(self):
        """Yield the reply accumulated so far after each streamed delta.

        Yields the generic error message once if the HTTP request fails.
        """
        response = self._get_response(stream=True)
        if response is None:
            yield STANDARD_ERROR_MSG + GENERAL_ERROR_MSG
            return

        partial_text = ""
        for delta in self._decode_chat_response(response):
            partial_text += delta
            yield partial_text

    def _get_response(self, stream=False):
        """Send the current history to MiniMax.

        Args:
            stream: when True, request a standard-SSE streaming response and
                use TIMEOUT_STREAMING; otherwise use TIMEOUT_ALL.

        Returns:
            The ``requests`` response, or None if sending the request raised.
        """
        logging.debug(colorama.Fore.YELLOW +
                      f"{self.history}" + colorama.Fore.RESET)
        # Built per call from self.api_key so an updated key is used.
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }

        # Every non-user role (e.g. assistant) is sent as the BOT side.
        messages = [
            {
                "sender_type": "USER" if msg['role'] == 'user' else "BOT",
                "text": msg['content'],
            }
            for msg in self.history
        ]

        request_body = {
            "model": self.model_name.replace('minimax-', ''),
            "temperature": self._minimax_temperature(),
            "skip_info_mask": True,
            'messages': messages,
        }
        if self.n_choices:
            request_body['beam_width'] = self.n_choices
        if self.system_prompt:
            lines = self.system_prompt.splitlines()
            # A short first line containing ":" is treated as role metadata
            # in the form "user_name:bot_name".
            if ":" in lines[0] and len(lines[0]) < 20:
                names = lines[0].split(":")
                request_body["role_meta"] = {
                    "user_name": names[0],
                    "bot_name": names[1]
                }
                # Drop the metadata header line (the first line), keeping
                # the rest of the prompt intact.
                lines.pop(0)
            request_body["prompt"] = "\n".join(lines)
        request_body['tokens_to_generate'] = self.max_generation_token or 512
        if self.top_p:
            request_body['top_p'] = self.top_p

        if stream:
            timeout = TIMEOUT_STREAMING
            request_body['stream'] = True
            request_body['use_standard_sse'] = True
        else:
            timeout = TIMEOUT_ALL
        try:
            response = requests.post(
                self.url,
                headers=headers,
                json=request_body,
                stream=stream,
                timeout=timeout,
            )
        except Exception as e:
            # Best-effort: callers treat None as "request failed".
            logging.error(f"MiniMax request failed: {e}")
            return None

        return response

    def _decode_chat_response(self, response):
        """Yield text deltas from a streaming (SSE) MiniMax response.

        Each non-empty line is expected to look like ``data: {json}``; the
        first 6 characters are stripped before decoding. Lines that fail to
        decode are collected as error text. On a chunk whose finish_reason is
        "stop", the token usage is recorded and iteration ends.

        Raises:
            Exception: if any undecodable content was received, with
            ``"<status_code> - <status_msg>"`` when the collected text is JSON
            carrying ``base_resp``, otherwise with the collected content.
        """
        error_msg = ""
        try:
            for raw_line in response.iter_lines():
                if not raw_line:
                    continue
                line = raw_line.decode()
                line_length = len(line)
                logging.debug(line)
                try:
                    chunk = json.loads(line[6:])
                except json.JSONDecodeError:
                    logging.error(i18n("JSON解析错误,收到的内容: ") + f"{line}")
                    error_msg += line
                    continue
                if line_length > 6 and "delta" in chunk["choices"][0]:
                    choice = chunk["choices"][0]
                    if choice.get("finish_reason") == "stop":
                        # all_token_counts stores per-turn counts; usage
                        # reports the running total, so store the difference.
                        self.all_token_counts.append(
                            chunk["usage"]["total_tokens"] - sum(self.all_token_counts)
                        )
                        break
                    yield choice["delta"]
        finally:
            # Release the streamed connection on stop, error, or early close.
            response.close()

        if error_msg:
            try:
                error_msg = json.loads(error_msg)
                if 'base_resp' in error_msg:
                    status_code = error_msg['base_resp']['status_code']
                    status_msg = error_msg['base_resp']['status_msg']
                    raise Exception(f"{status_code} - {status_msg}")
            except json.JSONDecodeError:
                pass
            raise Exception(error_msg)