import json
import logging
import os

import colorama
import requests

from modules.models.base_model import BaseLLMModel
from modules.presets import STANDARD_ERROR_MSG, GENERAL_ERROR_MSG, TIMEOUT_STREAMING, TIMEOUT_ALL, i18n

# Every MiniMax request is scoped to a group; read the ID once at import time.
group_id = os.environ.get("MINIMAX_GROUP_ID", "")


class MiniMax_Client(BaseLLMModel):
    """
    MiniMax client.
    API reference: https://api.minimax.chat/document/guides/chat
    """

    def __init__(self, model_name, api_key, user_name="", system_prompt=None):
        super().__init__(model_name=model_name, user=user_name)
        self.url = f"https://api.minimax.chat/v1/text/chatcompletion?GroupId={group_id}"
        self.history = []
        self.api_key = api_key
        self.system_prompt = system_prompt
        self.headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }

    def get_answer_at_once(self):
        # Map the UI temperature range (0, 2] onto MiniMax's (0, 1]: values up
        # to 1 are scaled by 0.9, values above 1 are squeezed into 0.9-1.0.
        temperature = self.temperature * 0.9 if self.temperature <= 1 else 0.9 + (self.temperature - 1) / 10

        request_body = {
            "model": self.model_name.replace("minimax-", ""),
            "temperature": temperature,
            "skip_info_mask": True,  # do not mask private info in the reply
            "messages": [{"sender_type": "USER", "text": self.history[-1]["content"]}],
        }
        if self.n_choices:
            request_body["beam_width"] = self.n_choices
        if self.system_prompt:
            request_body["prompt"] = self.system_prompt
        if self.max_generation_token:
            request_body["tokens_to_generate"] = self.max_generation_token
        if self.top_p:
            request_body["top_p"] = self.top_p

        response = requests.post(self.url, headers=self.headers, json=request_body)

        res = response.json()
        answer = res["reply"]
        total_token_count = res["usage"]["total_tokens"]
        return answer, total_token_count

    def get_answer_stream_iter(self):
        response = self._get_response(stream=True)
        if response is not None:
            # Accumulate deltas so every yield carries the full text so far.
            partial_text = ""
            for delta in self._decode_chat_response(response):
                partial_text += delta
                yield partial_text
        else:
            yield STANDARD_ERROR_MSG + GENERAL_ERROR_MSG

    def _get_response(self, stream=False):
        minimax_api_key = self.api_key
        history = self.history
        logging.debug(colorama.Fore.YELLOW + f"{history}" + colorama.Fore.RESET)
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {minimax_api_key}",
        }

        # Same temperature mapping as in get_answer_at_once.
        temperature = self.temperature * 0.9 if self.temperature <= 1 else 0.9 + (self.temperature - 1) / 10

        # MiniMax labels turns with sender_type (USER/BOT) rather than
        # OpenAI-style roles.
        messages = []
        for msg in self.history:
            if msg["role"] == "user":
                messages.append({"sender_type": "USER", "text": msg["content"]})
            else:
                messages.append({"sender_type": "BOT", "text": msg["content"]})

        request_body = {
            "model": self.model_name.replace("minimax-", ""),
            "temperature": temperature,
            "skip_info_mask": True,
            "messages": messages,
        }
        if self.n_choices:
            request_body["beam_width"] = self.n_choices
        if self.system_prompt:
            # A short first line of the form "user_name:bot_name" is treated as
            # role metadata; the remaining lines become the prompt.
            lines = self.system_prompt.splitlines()
            if lines[0].find(":") != -1 and len(lines[0]) < 20:
                request_body["role_meta"] = {
                    "user_name": lines[0].split(":")[0],
                    "bot_name": lines[0].split(":")[1],
                }
                lines.pop(0)  # drop the metadata line, not the last prompt line
            request_body["prompt"] = "\n".join(lines)
        if self.max_generation_token:
            request_body["tokens_to_generate"] = self.max_generation_token
        else:
            request_body["tokens_to_generate"] = 512
        if self.top_p:
            request_body["top_p"] = self.top_p

        if stream:
            timeout = TIMEOUT_STREAMING
            request_body["stream"] = True
            request_body["use_standard_sse"] = True
        else:
            timeout = TIMEOUT_ALL
        try:
            response = requests.post(
                self.url,
                headers=headers,
                json=request_body,
                stream=stream,
                timeout=timeout,
            )
        except requests.RequestException:
            logging.exception("MiniMax request failed")
            return None

        return response

    def _decode_chat_response(self, response):
        error_msg = ""
        for chunk in response.iter_lines():
            if chunk:
                chunk = chunk.decode()
                chunk_length = len(chunk)
                logging.debug(chunk)
                try:
                    # SSE events are prefixed with "data: "; strip it before parsing.
                    chunk = json.loads(chunk[6:])
                except json.JSONDecodeError:
                    print(i18n("JSON parse error. Received: ") + f"{chunk}")
                    error_msg += chunk
                    continue
                if chunk_length > 6 and "delta" in chunk["choices"][0]:
                    if "finish_reason" in chunk["choices"][0] and chunk["choices"][0]["finish_reason"] == "stop":
                        # The final event reports cumulative usage; record only
                        # the tokens not already counted.
                        self.all_token_counts.append(chunk["usage"]["total_tokens"] - sum(self.all_token_counts))
                        break
                    try:
                        yield chunk["choices"][0]["delta"]
                    except Exception as e:
                        logging.error(f"Error: {e}")
                        continue
        if error_msg:
            try:
                error_msg = json.loads(error_msg)
                if "base_resp" in error_msg:
                    status_code = error_msg["base_resp"]["status_code"]
                    status_msg = error_msg["base_resp"]["status_msg"]
                    raise Exception(f"{status_code} - {status_msg}")
            except json.JSONDecodeError:
                pass
            raise Exception(error_msg)
|
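
# A minimal usage sketch, not part of the original module. It assumes that
# MINIMAX_GROUP_ID is set in the environment, that BaseLLMModel initializes the
# sampling attributes used above (temperature, top_p, n_choices,
# max_generation_token, all_token_counts), and that MINIMAX_API_KEY and the
# "minimax-abab5-chat" model name are placeholders for your own values.
if __name__ == "__main__":
    client = MiniMax_Client(
        model_name="minimax-abab5-chat",
        api_key=os.environ.get("MINIMAX_API_KEY", ""),
    )
    # history entries use the app-internal role/content shape; the client
    # converts them to MiniMax's sender_type/text shape before sending.
    client.history = [{"role": "user", "content": "Hello!"}]
    answer, token_count = client.get_answer_at_once()
    print(answer, token_count)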