import json
import logging
import os

import colorama
import requests

from modules.models.base_model import BaseLLMModel
from modules.presets import STANDARD_ERROR_MSG, GENERAL_ERROR_MSG, TIMEOUT_STREAMING, TIMEOUT_ALL, i18n

group_id = os.environ.get("MINIMAX_GROUP_ID", "")


class MiniMax_Client(BaseLLMModel):
    """
    MiniMax Client
    API docs: https://api.minimax.chat/document/guides/chat
    """

    def __init__(self, model_name, api_key, user_name="", system_prompt=None):
        super().__init__(model_name=model_name, user=user_name)
        self.url = f'https://api.minimax.chat/v1/text/chatcompletion?GroupId={group_id}'
        self.history = []
        self.api_key = api_key
        self.system_prompt = system_prompt
        self.headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json"
        }

    def get_answer_at_once(self):
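        # Map the UI temperature range [0, 2] onto MiniMax's [0, 1]:
        # values up to 1 are scaled by 0.9, values above 1 are compressed into (0.9, 1.0].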
        temperature = self.temperature * 0.9 if self.temperature <= 1 else 0.9 + (self.temperature - 1) / 10

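        # Single-turn request: MiniMax expects messages as {"sender_type", "text"}
        # objects, and only the most recent user message is sent here.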
        request_body = {
            "model": self.model_name.replace('minimax-', ''),
            "temperature": temperature,
            "skip_info_mask": True,
            'messages': [{"sender_type": "USER", "text": self.history[-1]['content']}]
        }
        if self.n_choices:
            request_body['beam_width'] = self.n_choices
        if self.system_prompt:
            request_body['prompt'] = self.system_prompt
        if self.max_generation_token:
            request_body['tokens_to_generate'] = self.max_generation_token
        if self.top_p:
            request_body['top_p'] = self.top_p

        response = requests.post(self.url, headers=self.headers, json=request_body)

        res = response.json()
        answer = res['reply']
        total_token_count = res["usage"]["total_tokens"]
        return answer, total_token_count

    def get_answer_stream_iter(self):
        response = self._get_response(stream=True)
        if response is not None:
            # Named to avoid shadowing the builtin iter().
            delta_iter = self._decode_chat_response(response)
            partial_text = ""
            for i in delta_iter:
                partial_text += i
                yield partial_text
        else:
            yield STANDARD_ERROR_MSG + GENERAL_ERROR_MSG

    def _get_response(self, stream=False):
        minimax_api_key = self.api_key
        history = self.history
        logging.debug(colorama.Fore.YELLOW +
                      f"{history}" + colorama.Fore.RESET)
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {minimax_api_key}",
        }

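        # Same temperature mapping as in get_answer_at_once.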
        temperature = self.temperature * 0.9 if self.temperature <= 1 else 0.9 + (self.temperature - 1) / 10

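        # Convert the OpenAI-style role/content history into MiniMax's
        # sender_type/text message format.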
        messages = []
        for msg in self.history:
            if msg['role'] == 'user':
                messages.append({"sender_type": "USER", "text": msg['content']})
            else:
                messages.append({"sender_type": "BOT", "text": msg['content']})

        request_body = {
            "model": self.model_name.replace('minimax-', ''),
            "temperature": temperature,
            "skip_info_mask": True,
            'messages': messages
        }
        if self.n_choices:
            request_body['beam_width'] = self.n_choices
        if self.system_prompt:
            lines = self.system_prompt.splitlines()
            # A short first line of the form "user_name:bot_name" is treated as
            # role metadata rather than as part of the prompt text.
            if lines[0].find(":") != -1 and len(lines[0]) < 20:
                request_body["role_meta"] = {
                    "user_name": lines[0].split(":")[0],
                    "bot_name": lines[0].split(":")[1]
                }
                lines.pop(0)  # drop the consumed metadata line, not the last line
            request_body["prompt"] = "\n".join(lines)
        if self.max_generation_token:
            request_body['tokens_to_generate'] = self.max_generation_token
        else:
            request_body['tokens_to_generate'] = 512
        if self.top_p:
            request_body['top_p'] = self.top_p

        if stream:
            timeout = TIMEOUT_STREAMING
            request_body['stream'] = True
            # use_standard_sse makes the API emit standard "data: ..." events,
            # which _decode_chat_response relies on when stripping the prefix.
            request_body['use_standard_sse'] = True
        else:
            timeout = TIMEOUT_ALL
        try:
            response = requests.post(
                self.url,
                headers=headers,
                json=request_body,
                stream=stream,
                timeout=timeout,
            )
        except requests.RequestException as e:
            logging.error(f"MiniMax request failed: {e}")
            return None

        return response

    def _decode_chat_response(self, response):
        error_msg = ""
        for chunk in response.iter_lines():
            if chunk:
                chunk = chunk.decode()
                chunk_length = len(chunk)
                logging.debug(chunk)
                try:
                    # Strip the SSE "data: " prefix (6 characters) before parsing JSON.
                    chunk = json.loads(chunk[6:])
                except json.JSONDecodeError:
                    logging.error(i18n("JSON解析错误,收到的内容: ") + f"{chunk}")
                    error_msg += chunk
                    continue
                if chunk_length > 6 and "delta" in chunk["choices"][0]:
                    if "finish_reason" in chunk["choices"][0] and chunk["choices"][0]["finish_reason"] == "stop":
                        # The final event carries cumulative usage; record only the
                        # tokens not already counted.
                        self.all_token_counts.append(chunk["usage"]["total_tokens"] - sum(self.all_token_counts))
                        break
                    try:
                        yield chunk["choices"][0]["delta"]
                    except Exception as e:
                        logging.error(f"Error: {e}")
                        continue
        if error_msg:
            try:
                error_msg = json.loads(error_msg)
                if 'base_resp' in error_msg:
                    status_code = error_msg['base_resp']['status_code']
                    status_msg = error_msg['base_resp']['status_msg']
                    raise Exception(f"{status_code} - {status_msg}")
            except json.JSONDecodeError:
                pass
            raise Exception(error_msg)
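
# Minimal usage sketch (illustrative only; the model name is an assumption and
# conversation state is normally managed by BaseLLMModel rather than appended
# by hand):
#
#   client = MiniMax_Client("minimax-abab5.5-chat", os.environ.get("MINIMAX_API_KEY", ""))
#   client.history.append({"role": "user", "content": "Hello"})
#   answer, total_token_count = client.get_answer_at_once()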