Komorebi / modules /models /minimax.py
meteor-2023's picture
Duplicate from JohnSmith9982/ChuanhuChatGPT
3678cf8
raw
history blame
6.05 kB
import json
import os
import colorama
import requests
import logging
from modules.models.base_model import BaseLLMModel
from modules.presets import STANDARD_ERROR_MSG, GENERAL_ERROR_MSG, TIMEOUT_STREAMING, TIMEOUT_ALL, i18n
# MiniMax group id, read once at import time from the MINIMAX_GROUP_ID
# environment variable; falls back to an empty string when it is unset.
group_id = os.getenv("MINIMAX_GROUP_ID", "")
class MiniMax_Client(BaseLLMModel):
    """Chat client for the MiniMax text chat-completion HTTP API.

    API documentation: https://api.minimax.chat/document/guides/chat
    """

    def __init__(self, model_name, api_key, user_name="", system_prompt=None):
        """Set up the endpoint URL, credentials and an empty chat history.

        Args:
            model_name: Model identifier; a ``minimax-`` prefix is stripped
                before the name is sent to the API.
            api_key: Bearer token placed in the ``Authorization`` header.
            user_name: Forwarded to the base class as ``user``.
            system_prompt: Optional prompt text sent as ``prompt``.
        """
        super().__init__(model_name=model_name, user=user_name)
        self.url = f'https://api.minimax.chat/v1/text/chatcompletion?GroupId={group_id}'
        self.history = []
        self.api_key = api_key
        self.system_prompt = system_prompt
        # Kept for backward compatibility; requests build their headers from
        # the current ``self.api_key`` via ``_request_headers``.
        self.headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json"
        }

    def _request_headers(self):
        """Return HTTP headers built from the *current* ``self.api_key``."""
        # NOTE(review): the key is presumably replaceable after construction
        # (e.g. by the base class) -- building per request avoids a stale token.
        return {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }

    def _minimax_temperature(self):
        """Map the base-model temperature onto the MiniMax range.

        MiniMax temperature is (0, 1] while the base model uses [0, 2], and
        MiniMax 0.9 corresponds to base 1, so the value is rescaled
        piecewise. The result is clamped so that a base temperature of 0
        does not produce 0, which lies outside the open lower bound.
        """
        base = self.temperature
        if base <= 1:
            scaled = base * 0.9
        else:
            scaled = 0.9 + (base - 1) / 10
        return min(max(scaled, 0.01), 1.0)

    def get_answer_at_once(self):
        """Request a complete (non-streaming) reply for the latest message.

        Returns:
            A tuple ``(answer, total_token_count)`` taken from the ``reply``
            and ``usage.total_tokens`` fields of the JSON response.

        Raises:
            Exception: When the response carries no ``reply``; the message is
                ``"<status_code> - <status_msg>"`` if ``base_resp`` is present.
        """
        # NOTE(review): only the most recent history entry is sent here,
        # unlike ``_get_response`` which sends the whole history -- confirm
        # that this is intended.
        request_body = {
            "model": self.model_name.replace('minimax-', ''),
            "temperature": self._minimax_temperature(),
            "skip_info_mask": True,
            'messages': [{"sender_type": "USER", "text": self.history[-1]['content']}],
        }
        if self.n_choices:
            request_body['beam_width'] = self.n_choices
        if self.system_prompt:
            request_body['prompt'] = self.system_prompt
        if self.max_generation_token:
            request_body['tokens_to_generate'] = self.max_generation_token
        if self.top_p:
            request_body['top_p'] = self.top_p

        # A timeout keeps a stalled connection from blocking forever.
        response = requests.post(
            self.url,
            headers=self._request_headers(),
            json=request_body,
            timeout=TIMEOUT_ALL,
        )
        res = response.json()

        if not isinstance(res, dict) or 'reply' not in res:
            # Surface the API's own error instead of a bare KeyError.
            base_resp = res.get('base_resp') if isinstance(res, dict) else None
            if base_resp:
                raise Exception(
                    f"{base_resp.get('status_code')} - {base_resp.get('status_msg')}"
                )
            raise Exception(res)

        answer = res['reply']
        total_token_count = res["usage"]["total_tokens"]
        return answer, total_token_count

    def get_answer_stream_iter(self):
        """Yield the reply text accumulated so far as streamed pieces arrive.

        Yields ``STANDARD_ERROR_MSG + GENERAL_ERROR_MSG`` once when the
        request could not be sent.
        """
        response = self._get_response(stream=True)
        if response is None:
            yield STANDARD_ERROR_MSG + GENERAL_ERROR_MSG
            return

        partial_text = ""
        for piece in self._decode_chat_response(response):
            partial_text += piece
            yield partial_text

    def _get_response(self, stream=False):
        """POST the whole chat history to the API.

        Args:
            stream: When true, request standard SSE streaming and use
                ``TIMEOUT_STREAMING``; otherwise use ``TIMEOUT_ALL``.

        Returns:
            The ``requests`` response object, or ``None`` if sending failed.
        """
        logging.debug(colorama.Fore.YELLOW +
                      f"{self.history}" + colorama.Fore.RESET)

        # Every entry that is not from the user is sent as a BOT message.
        messages = [
            {
                "sender_type": "USER" if msg['role'] == 'user' else "BOT",
                "text": msg['content'],
            }
            for msg in self.history
        ]

        request_body = {
            "model": self.model_name.replace('minimax-', ''),
            "temperature": self._minimax_temperature(),
            "skip_info_mask": True,
            'messages': messages,
        }
        if self.n_choices:
            request_body['beam_width'] = self.n_choices
        if self.system_prompt:
            lines = self.system_prompt.splitlines()
            first_line = lines[0]
            # A short first line containing ":" is read as "user_name:bot_name".
            if ":" in first_line and len(first_line) < 20:
                name_parts = first_line.split(":")
                request_body["role_meta"] = {
                    "user_name": name_parts[0],
                    "bot_name": name_parts[1],
                }
                # Drop the role header itself (the first line), not the last
                # line of the prompt.
                lines.pop(0)
            request_body["prompt"] = "\n".join(lines)
        if self.max_generation_token:
            request_body['tokens_to_generate'] = self.max_generation_token
        else:
            request_body['tokens_to_generate'] = 512
        if self.top_p:
            request_body['top_p'] = self.top_p

        if stream:
            timeout = TIMEOUT_STREAMING
            request_body['stream'] = True
            request_body['use_standard_sse'] = True
        else:
            timeout = TIMEOUT_ALL

        try:
            response = requests.post(
                self.url,
                headers=self._request_headers(),
                json=request_body,
                stream=stream,
                timeout=timeout,
            )
        except Exception:
            # Best effort: callers treat None as "request failed". Unlike a
            # bare ``except``, KeyboardInterrupt/SystemExit still propagate.
            logging.exception("MiniMax request failed")
            return None
        return response

    def _decode_chat_response(self, response):
        """Yield the ``delta`` of each streamed frame of ``response``.

        Each non-empty line is expected to be ``data: <json>``. Lines that do
        not parse that way are collected and reported as an error once the
        stream ends. When the frame with ``finish_reason == "stop"`` arrives,
        the token count of this exchange is appended to
        ``self.all_token_counts`` and iteration stops.

        Raises:
            Exception: When the API reported an error, either inside a frame
                or as a non-SSE body; the message is
                ``"<status_code> - <status_msg>"`` if ``base_resp`` is present.
        """
        error_msg = ""
        for raw_line in response.iter_lines():
            if not raw_line:
                continue
            line = raw_line.decode()
            logging.debug(line)
            try:
                # Strip the 6-character "data: " prefix before parsing.
                frame = json.loads(line[6:])
            except json.JSONDecodeError:
                logging.error(i18n("JSON解析错误,收到的内容: ") + f"{line}")
                error_msg += line
                continue

            choices = frame.get("choices") if isinstance(frame, dict) else None
            if not choices:
                # A frame without choices would otherwise raise KeyError.
                # NOTE(review): a non-zero base_resp.status_code is presumed
                # to signal an error -- confirm against the API docs.
                base_resp = frame.get("base_resp") if isinstance(frame, dict) else None
                if base_resp and base_resp.get("status_code"):
                    raise Exception(
                        f"{base_resp.get('status_code')} - {base_resp.get('status_msg')}"
                    )
                continue

            choice = choices[0]
            if "delta" not in choice:
                continue
            if choice.get("finish_reason") == "stop":
                usage = frame.get("usage")
                if usage and "total_tokens" in usage:
                    # The API reports a running total; store only the tokens
                    # added since the previously recorded counts.
                    self.all_token_counts.append(
                        usage["total_tokens"] - sum(self.all_token_counts)
                    )
                break
            yield choice["delta"]

        if error_msg:
            try:
                parsed_error = json.loads(error_msg)
            except json.JSONDecodeError:
                # Not JSON at all: report the raw text.
                raise Exception(error_msg) from None
            if isinstance(parsed_error, dict) and 'base_resp' in parsed_error:
                status_code = parsed_error['base_resp']['status_code']
                status_msg = parsed_error['base_resp']['status_msg']
                raise Exception(f"{status_code} - {status_msg}")
            raise Exception(parsed_error)