Spaces:
Runtime error
Runtime error
| # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. ========= | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. ========= | |
| from collections import defaultdict | |
| from typing import Dict, List, Optional, Tuple | |
| from camel.messages import BaseMessage | |
| from camel.types import TerminationMode | |
| from .base import ResponseTerminator | |
| class ResponseWordsTerminator(ResponseTerminator): | |
| r"""Terminate agent when some words reached to occurrence | |
| limit by any message of the response. | |
| Args: | |
| words_dict (dict): Dictionary of words and its occurrence | |
| threshold. | |
| case_sensitive (bool): Whether count the words as | |
| case-sensitive. (default: :obj:`False`) | |
| mode (TerminationMode): Whether terminate agent if any | |
| or all pre-set words reached the threshold. | |
| (default: :obj:`TerminationMode.ANY`) | |
| """ | |
| def __init__( | |
| self, | |
| words_dict: Dict[str, int], | |
| case_sensitive: bool = False, | |
| mode: TerminationMode = TerminationMode.ANY, | |
| ): | |
| super().__init__() | |
| self.words_dict = words_dict | |
| self.case_sensitive = case_sensitive | |
| self.mode = mode | |
| self._word_count_dict: List[Dict[str, int]] = [] | |
| self._validate() | |
| def _validate(self): | |
| if len(self.words_dict) == 0: | |
| raise ValueError("`words_dict` cannot be empty") | |
| for word in self.words_dict: | |
| threshold = self.words_dict[word] | |
| if threshold <= 0: | |
| raise ValueError( | |
| f"Threshold for word `{word}` should " | |
| f"be larger than 0, got `{threshold}`" | |
| ) | |
| def is_terminated( | |
| self, messages: List[BaseMessage] | |
| ) -> Tuple[bool, Optional[str]]: | |
| r"""Whether terminate the agent by checking the occurrence | |
| of specified words reached to preset thresholds. | |
| Args: | |
| messages (list): List of :obj:`BaseMessage` from a response. | |
| Returns: | |
| tuple: A tuple containing whether the agent should be | |
| terminated and a string of termination reason. | |
| """ | |
| if self._terminated: | |
| return True, self._termination_reason | |
| for i in range(len(messages)): | |
| if i >= len(self._word_count_dict): | |
| self._word_count_dict.append(defaultdict(int)) | |
| for word in self.words_dict: | |
| special_word = word if self.case_sensitive else word.lower() | |
| for i, message in enumerate(messages): | |
| if self.case_sensitive: | |
| content = message.content | |
| else: | |
| content = message.content.lower() | |
| if special_word in content: | |
| self._word_count_dict[i][word] += 1 | |
| num_reached: List[int] = [] | |
| all_reasons: List[List[str]] = [] | |
| for i in range(len(self._word_count_dict)): | |
| reached = 0 | |
| reasons: List[str] = [] | |
| for word, value in self._word_count_dict[i].items(): | |
| if value >= self.words_dict[word]: | |
| reached += 1 | |
| reason = ( | |
| f"Word `{word}` appears {value} times in the " | |
| f"{i + 1} message of the response which has " | |
| f"reached termination threshold " | |
| f"{self.words_dict[word]}." | |
| ) | |
| reasons.append(reason) | |
| all_reasons.append(reasons) | |
| num_reached.append(reached) | |
| for i, reached in enumerate(num_reached): | |
| if self.mode == TerminationMode.ANY: | |
| if reached > 0: | |
| self._terminated = True | |
| self._termination_reason = "\n".join(all_reasons[i]) | |
| elif self.mode == TerminationMode.ALL: | |
| if reached >= len(self.words_dict): | |
| self._terminated = True | |
| self._termination_reason = "\n".join(all_reasons[i]) | |
| else: | |
| raise ValueError( | |
| f"Unsupported termination mode " f"`{self.mode}`" | |
| ) | |
| return self._terminated, self._termination_reason | |
| def reset(self): | |
| r"""Reset the terminator.""" | |
| self._terminated = False | |
| self._termination_reason = None | |
| self._word_count_dict = defaultdict(int) | |