File size: 8,849 Bytes
23c9a86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9c31af5
23c9a86
9c31af5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23c9a86
 
 
 
9c31af5
 
 
 
 
23c9a86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9c31af5
 
 
 
 
 
 
23c9a86
 
 
 
 
 
 
 
9c31af5
 
 
 
 
 
 
 
 
23c9a86
 
 
 
 
9c31af5
 
 
 
 
 
 
23c9a86
 
 
 
 
 
 
9c31af5
 
 
 
 
 
 
23c9a86
 
 
 
 
 
 
 
9c31af5
 
 
 
 
 
23c9a86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9c31af5
 
 
 
23c9a86
 
 
 
 
 
 
 
 
 
 
 
 
 
9c31af5
 
 
 
 
 
23c9a86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
import json
from copy import deepcopy
from typing import Any, Dict, List

from flow_modules.aiflows.ChatFlowModule import ChatAtomicFlow

from dataclasses import dataclass


@dataclass
class Command:
    """One command the controller may instruct the model to call.

    Instances are built from the ``commands`` mapping of the flow config and
    rendered into the system prompt's command manual.
    """

    # Command identifier, listed first in the rendered manual entry.
    name: str
    # Free-text explanation shown to the model next to the name.
    description: str
    # Argument names; rendered as a JSON placeholder object in the manual.
    input_args: List[str]


class ControllerFlow_ExtLib(ChatAtomicFlow):
    """Refer to: https://huggingface.co/aiflows/JarvisFlowModule/blob/main/Controller_JarvisFlow.py for a detailed doc.
    This flow inherits from ChatAtomicFlow and is used as a controller of the ExtendLibraryFlow.

    *Input Interface Non Initialized*:
    - `goal`
    - `memory_files`
    - `plan`
    - `logs`

    *Input Interface Initialized*:
    - `goal`
    - `memory_files`
    - `plan`
    - `logs`
    - `result`

    *Output Interface*:
    - `command`
    - `command_args`

    *Configuration Parameters*:
    - `input_interface_non_initialized`: a list of input interface names when the conversation starts.
    - `input_interface_initialized`: a list of input interface names when the conversation is initialized.
    - `output_interface`: the output of the controller, it is the command and the command arguments.
    - `commands`: a list of commands that the controller can call. Each command has a name, a description, and a list of input arguments.
    - `system_message_prompt_template`: the system message prompt template.
    - `init_human_message_prompt_template`: the initial human (user) message prompt template.
    - `human_message_prompt_template`: the human (user) message prompt template.
    - `previous_messages`: the sliding window of previous messages.
    """
    def __init__(
            self,
            commands: List[Command],
            **kwargs):
        """
        The constructor of the ControllerFlow_ExtLib class.
        :param commands: a list of commands that the controller can call. Each command has a name, a description, and a list of input arguments.
        :param kwargs: the configuration parameters of the ControllerFlow_ExtLib class.
        """
        super().__init__(**kwargs)
        # Pre-fill the system prompt: the command manual is fixed for the flow's lifetime,
        # while plan/logs placeholders are overwritten on every run.
        self.system_message_prompt_template = self.system_message_prompt_template.partial(
            commands=self._build_commands_manual(commands),
            plan_file_location="no location yet",
            plan="no plan yet",
            logs="no logs yet",
        )
        # Appended to `goal`/`result` so the model is reminded of the expected JSON reply shape.
        self.hint_for_model = """
        Make sure your response is in the following format:
              Response Format:
              {
              "command": "call one of the commands you have e.g. `write_code`",
              "command_args": {
                  "arg name": "value"
                  }
              }
        """

    @staticmethod
    def _build_commands_manual(commands: List[Command]) -> str:
        """
        Build the manual for the commands, one numbered line per command.
        :param commands: a list of commands that the controller can call. Each command has a name, a description, and a list of input arguments.
        :type commands: List[Command]
        :return: the manual for the commands (each line terminated by a newline).
        :rtype: str
        """
        lines = []
        for i, command in enumerate(commands, start=1):
            # Render the arguments as placeholders, e.g. {"goal": "YOUR_GOAL"}.
            command_input_json_schema = json.dumps(
                {input_arg: f"YOUR_{input_arg.upper()}" for input_arg in command.input_args})
            lines.append(f"{i}. {command.name}: {command.description} "
                         f"Input arguments (given in the JSON schema): {command_input_json_schema}\n")
        return "".join(lines)

    def _get_plan_file_location(self, input_data: Dict[str, Any]):
        """ Get the plan file location from the input data.
        :param input_data: the input data.
        :raise AssertionError: if `memory_files` is not passed to the controller.
        :raise AssertionError: if `plan` is not a key of the memory files.
        :type input_data: Dict[str, Any]
        :return: the plan file location.
        :rtype: str
        """
        assert "memory_files" in input_data, "memory_files not passed to Extlib/Controller"
        assert "plan" in input_data["memory_files"], "plan not in memory files"
        return input_data["memory_files"]["plan"]

    def _get_plan_content(self, input_data: Dict[str, Any]):
        """ Get the plan content from the input data.
        :param input_data: the input data.
        :raise AssertionError: if the plan is not passed to the controller.
        :type input_data: Dict[str, Any]
        :return: the plan content, or 'No plan yet' when it is empty or None.
        :rtype: str
        """
        assert "plan" in input_data, "plan not passed to Extlib/Controller"
        plan_content = input_data["plan"]
        # Truthiness check also covers None, which `len()` would crash on.
        if not plan_content:
            plan_content = 'No plan yet'
        return plan_content

    def _get_logs_content(self, input_data: Dict[str, Any]):
        """ Get the logs content from the input data.
        :param input_data: the input data.
        :raise AssertionError: if the logs is not passed to the controller.
        :type input_data: Dict[str, Any]
        :return: the logs content, or "No logs yet" when it is empty or None.
        :rtype: str
        """
        assert "logs" in input_data, "logs not passed to Extlib/Controller"
        logs_content = input_data["logs"]
        # Truthiness check also covers None, which `len()` would crash on.
        if not logs_content:
            logs_content = "No logs yet"
        return logs_content

    @classmethod
    def instantiate_from_config(cls, config):
        """Instantiate the flow from a configuration.
        :param config: the configuration; its `commands` entry maps each command name to a
            dict with `description` and `input_args`.
        :type config: Dict[str, Any]
        :return: the instantiated flow.
        :rtype: ControllerFlow_ExtLib
        """
        flow_config = deepcopy(config)

        kwargs = {"flow_config": flow_config}

        # ~~~ Set up prompts ~~~
        kwargs.update(cls._set_up_prompts(flow_config))

        # ~~~Set up backend ~~~
        kwargs.update(cls._set_up_backend(flow_config))

        # ~~~ Set up commands ~~~
        commands = [
            Command(name, command_conf["description"], command_conf["input_args"])
            for name, command_conf in flow_config["commands"].items()
        ]
        kwargs.update({"commands": commands})

        # ~~~ Instantiate flow ~~~
        return cls(**kwargs)

    def _update_prompts_and_input(self, input_data: Dict[str, Any]):
        """Update the prompts and the input data. The system prompts are updated with the new memory.
        Mutates `input_data` in place: the response-format hint is appended to `goal` and/or `result`.
        :param input_data: the input data.
        :type input_data: Dict[str, Any]
        """
        if 'goal' in input_data:
            input_data['goal'] += self.hint_for_model
        if 'result' in input_data:
            input_data['result'] += self.hint_for_model
        plan_file_location = self._get_plan_file_location(input_data)
        plan_content = self._get_plan_content(input_data)
        logs_content = self._get_logs_content(input_data)
        self.system_message_prompt_template = self.system_message_prompt_template.partial(
            plan_file_location=plan_file_location,
            plan=plan_content,
            logs=logs_content
        )

    @staticmethod
    def _parse_json_from_output(api_output: str) -> Dict[str, Any]:
        """Parse the span from the first '{' to the last '}' of a model reply as JSON.
        :param api_output: the raw model reply.
        :type api_output: str
        :raise ValueError: if the reply has no braces or the span is not valid JSON
            (``json.JSONDecodeError`` is a subclass of ``ValueError``).
        :return: the parsed JSON value.
        :rtype: Dict[str, Any]
        """
        start = api_output.index("{")
        end = api_output.rindex("}") + 1
        return json.loads(api_output[start:end])

    def run(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
        """Run the flow, re-prompting the model until its reply contains parsable JSON.
        :param input_data: the input data (not mutated; a shallow copy is used internally).
        :type input_data: Dict[str, Any]
        :return: the output of the flow.
        :rtype: Dict[str, Any]
        """
        # Shallow copy: the hint appended to `goal`/`result` must not leak into (or
        # accumulate in) the caller's dict.
        input_data = dict(input_data)
        self._update_prompts_and_input(input_data)

        # ~~~when conversation is initialized, append the updated system prompts to the chat history ~~~
        if self._is_conversation_initialized():
            updated_system_message_content = self._get_message(self.system_message_prompt_template, input_data)
            self._state_update_add_chat_message(content=updated_system_message_content,
                                                role=self.flow_config["system_name"])

        # NOTE(review): retries are unbounded; a model that never emits valid JSON loops forever.
        while True:
            api_output = super().run(input_data)["api_output"].strip()
            try:
                return self._parse_json_from_output(api_output)
            except ValueError:
                # Covers both a reply without any braces (str.index/rindex) and invalid JSON.
                updated_system_message_content = self._get_message(self.system_message_prompt_template, input_data)
                self._state_update_add_chat_message(content=updated_system_message_content,
                                                    role=self.flow_config["system_name"])
                new_goal = "The previous respond cannot be parsed with json.loads. Next time, do not provide any comments or code blocks. Make sure your next response is purely json parsable."
                input_data = {**input_data, 'result': new_goal}