{action_name}: {action_input}\n'
    else:
        return ""


class ChuanhuCallbackHandler(BaseCallbackHandler):
    def __init__(self, callback) -> None:
        """Initialize callback handler."""
        self.callback = callback

    def on_agent_action(
        self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
    ) -> Any:
        self.callback(get_action_description(action.log))

    def on_tool_end(
        self,
        output: str,
        color: Optional[str] = None,
        observation_prefix: Optional[str] = None,
        llm_prefix: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """If not the final action, print out observation."""
        # if observation_prefix is not None:
        #     self.callback(f"\n\n{observation_prefix}")
        # self.callback(output)
        # if llm_prefix is not None:
        #     self.callback(f"\n\n{llm_prefix}")
        if observation_prefix is not None:
            logging.info(observation_prefix)
        self.callback(output)
        if llm_prefix is not None:
            logging.info(llm_prefix)

    def on_agent_finish(
        self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
    ) -> None:
        # self.callback(f"{finish.log}\n\n")
        logging.info(finish.log)

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Run on new LLM token. Only available when streaming is enabled."""
        self.callback(token)

    def on_chat_model_start(
        self,
        serialized: Dict[str, Any],
        messages: List[List[BaseMessage]],
        **kwargs: Any,
    ) -> Any:
        """Run when a chat model starts running."""
        pass


class ModelType(Enum):
    Unknown = -1
    OpenAI = 0
    ChatGLM = 1
    LLaMA = 2
    XMChat = 3
    StableLM = 4
    MOSS = 5
    YuanAI = 6
    Minimax = 7
    ChuanhuAgent = 8
    GooglePaLM = 9
    LangchainChat = 10
    Midjourney = 11
    Spark = 12
    OpenAIInstruct = 13
    Claude = 14
    Qwen = 15
    OpenAIVision = 16
    ERNIE = 17
    DALLE3 = 18

    @classmethod
    def get_type(cls, model_name: str):
        model_type = None
        model_name_lower = model_name.lower()
        if "gpt" in model_name_lower:
            if "instruct" in model_name_lower:
                model_type = ModelType.OpenAIInstruct
            elif "vision" in model_name_lower:
                model_type = ModelType.OpenAIVision
            else:
                model_type = ModelType.OpenAI
        elif "chatglm" in model_name_lower:
            model_type = ModelType.ChatGLM
        elif "llama" in model_name_lower or "alpaca" in model_name_lower:
            model_type = ModelType.LLaMA
        elif "xmchat" in model_name_lower or "imp" in model_name_lower:
            model_type = ModelType.XMChat
        elif "stablelm" in model_name_lower:
            model_type = ModelType.StableLM
        elif "moss" in model_name_lower:
            model_type = ModelType.MOSS
        elif "yuanai" in model_name_lower:
            model_type = ModelType.YuanAI
        elif "minimax" in model_name_lower:
            model_type = ModelType.Minimax
        elif "川虎助理" in model_name_lower:  # "Chuanhu Assistant"
            model_type = ModelType.ChuanhuAgent
        elif "palm" in model_name_lower:
            model_type = ModelType.GooglePaLM
        elif "midjourney" in model_name_lower:
            model_type = ModelType.Midjourney
        elif "azure" in model_name_lower or "api" in model_name_lower:
            model_type = ModelType.LangchainChat
        elif "星火大模型" in model_name_lower:  # "Spark LLM" (iFlytek)
            model_type = ModelType.Spark
        elif "claude" in model_name_lower:
            model_type = ModelType.Claude
        elif "qwen" in model_name_lower:
            model_type = ModelType.Qwen
        elif "ernie" in model_name_lower:
            model_type = ModelType.ERNIE
        elif "dall" in model_name_lower:
            model_type = ModelType.DALLE3
        else:
            model_type = ModelType.LLaMA
        return model_type


class BaseLLMModel:
    def __init__(
        self,
        model_name,
        system_prompt=INITIAL_SYSTEM_PROMPT,
        temperature=1.0,
        top_p=1.0,
        n_choices=1,
        stop="",
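        # The remaining keyword defaults appear to mirror the OpenAI
        # chat-completion sampling parameters (max_tokens, presence_penalty,
        # frequency_penalty, logit_bias, user); None / 0 / "" defers to the
        # backend's own default.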
        max_generation_token=None,
        presence_penalty=0,
        frequency_penalty=0,
        logit_bias=None,
        user="",
        single_turn=False,
    ) -> None:
        self.history = []
        self.all_token_counts = []
        try:
            self.model_name = MODEL_METADATA[model_name]["model_name"]
        except KeyError:
            self.model_name = model_name
        self.model_type = ModelType.get_type(model_name)
        try:
            self.token_upper_limit = MODEL_METADATA[model_name]["token_limit"]
        except KeyError:
            self.token_upper_limit = DEFAULT_TOKEN_LIMIT
        self.interrupted = False
        self.system_prompt = system_prompt
        self.api_key = None
        self.need_api_key = False
        self.history_file_path = get_first_history_name(user)
        self.user_name = user
        self.chatbot = []

        self.default_single_turn = single_turn
        self.default_temperature = temperature
        self.default_top_p = top_p
        self.default_n_choices = n_choices
        self.default_stop_sequence = stop
        self.default_max_generation_token = max_generation_token
        self.default_presence_penalty = presence_penalty
        self.default_frequency_penalty = frequency_penalty
        self.default_logit_bias = logit_bias
        self.default_user_identifier = user

        self.single_turn = single_turn
        self.temperature = temperature
        self.top_p = top_p
        self.n_choices = n_choices
        self.stop_sequence = stop
        self.max_generation_token = max_generation_token
        self.presence_penalty = presence_penalty
        self.frequency_penalty = frequency_penalty
        self.logit_bias = logit_bias
        self.user_identifier = user

        self.metadata = {}

    def get_answer_stream_iter(self):
        """Implement stream prediction.
        Conversations are stored in self.history, with the most recent question in OpenAI format.
        Should return a generator that yields the next word (str) in the answer.
        """
        logging.warning(
            "Stream prediction is not implemented. Using at once prediction instead."
        )
        response, _ = self.get_answer_at_once()
        yield response

    def get_answer_at_once(self):
        """Implement at once prediction.
        Conversations are stored in self.history, with the most recent question in OpenAI format.
        Should return:
            the answer (str)
            the total token count (int)
        """
        logging.warning(
            "At once prediction is not implemented. Using stream prediction instead."
        )
        response_iter = self.get_answer_stream_iter()
        response = ""  # guard against an empty stream
        count = 0
        for response in response_iter:
            count += 1
        return response, sum(self.all_token_counts) + count

    def billing_info(self):
        """Get billing information; implement if needed."""
        # logging.warning("billing info not implemented, using default")
        return BILLING_NOT_APPLICABLE_MSG

    def count_token(self, user_input):
        """Get the token count of the input; implement if needed."""
        # logging.warning("token count not implemented, using default")
        return len(user_input)

    def stream_next_chatbot(self, inputs, chatbot, fake_input=None, display_append=""):
        def get_return_value():
            return chatbot, status_text

        status_text = i18n("开始实时传输回答……")  # "Starting to stream the answer..."
        if fake_input:
            chatbot.append((fake_input, ""))
        else:
            chatbot.append((inputs, ""))

        user_token_count = self.count_token(inputs)
        self.all_token_counts.append(user_token_count)
        logging.debug(f"Input token count: {user_token_count}")

        stream_iter = self.get_answer_stream_iter()

        if display_append:
            display_append = (
                '\n\n
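

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not from the original file). `EchoModel` and
# the "echo-model" name are hypothetical, invented only to show the minimal
# override surface: a concrete backend implements get_answer_at_once() (or
# get_answer_stream_iter()), and BaseLLMModel bridges whichever is missing.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class EchoModel(BaseLLMModel):
        def get_answer_at_once(self):
            # self.history holds OpenAI-style {"role": ..., "content": ...}
            # dicts, with the newest user question last.
            question = self.history[-1]["content"]
            answer = f"You said: {question}"
            return answer, self.count_token(question) + self.count_token(answer)

    model = EchoModel(model_name="echo-model", user="demo-user")
    model.history.append({"role": "user", "content": "hello"})
    print(model.get_answer_at_once())  # ("You said: hello", 20)

    # Names matching no keyword fall through to the LLaMA bucket:
    print(ModelType.get_type("echo-model"))  # ModelType.LLaMA

    # ChuanhuCallbackHandler forwards each streamed token to the callable it
    # wraps, so collecting a stream is just list.append:
    chunks = []
    handler = ChuanhuCallbackHandler(chunks.append)
    handler.on_llm_new_token("Hel")
    handler.on_llm_new_token("lo")
    assert "".join(chunks) == "Hello"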