File size: 5,276 Bytes
bd7cdc8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8479bee
 
 
 
 
 
 
54f092c
 
 
b4e4404
 
 
 
54f092c
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import json
from copy import deepcopy
from typing import Any, Dict, List

from flow_modules.aiflows.ChatFlowModule import ChatAtomicFlow

from dataclasses import dataclass


@dataclass
class Command:
    """One callable command exposed to the controller model.

    Instances are rendered into the system prompt's command manual by
    ``ControllerFlow_ExtLib._build_commands_manual``.
    """
    name: str  # command identifier listed in the manual
    description: str  # human-readable explanation shown to the model
    input_args: List[str]  # names of the arguments the command expects

class ControllerFlow_ExtLib(ChatAtomicFlow):
    def __init__(self, commands: List[Command], **kwargs):
        """Set up the controller flow.

        Pre-fills the system prompt template with the rendered command
        manual and placeholder plan/log values (the real values are
        injected on every call in ``_update_prompts_and_input``).

        :param commands: commands to expose to the model.
        :param kwargs: forwarded unchanged to the base flow constructor.
        """
        super().__init__(**kwargs)
        commands_manual = self._build_commands_manual(commands)
        self.system_message_prompt_template = self.system_message_prompt_template.partial(
            commands=commands_manual,
            plan_file_location="no location yet",
            plan="no plan yet",
            logs="no logs yet",
        )
        # Appended to 'goal'/'result' on each run to steer the model toward
        # strictly JSON-parsable output.
        self.hint_for_model = """
        Make sure your response is in the following format:
              Response Format:
              {
              "command": "call one of the commands you have e.g. `write_code`",
              "command_args": {
                  "arg name": "value"
                  }
              }
        """

    @staticmethod
    def _build_commands_manual(commands: List[Command]) -> str:
        """Render *commands* as a numbered manual for the system prompt.

        Each entry shows the command name, its description, and a JSON
        schema of its input arguments with ``YOUR_<ARG>`` placeholders.
        """
        entries = []
        for number, cmd in enumerate(commands, start=1):
            schema = json.dumps({arg: f"YOUR_{arg.upper()}" for arg in cmd.input_args})
            entries.append(
                f"{number}. {cmd.name}: {cmd.description} Input arguments (given in the JSON schema): {schema}\n"
            )
        return "".join(entries)

    def _get_plan_file_location(self, input_data: Dict[str, Any]):
        """Return the plan file location recorded in ``memory_files``.

        NOTE: validation uses ``assert`` (matching the sibling getters),
        so it is stripped when Python runs with ``-O``.
        """
        assert "memory_files" in input_data, "memory_files not passed to Extlib/Controller"
        memory_files = input_data["memory_files"]
        assert "plan" in memory_files, "plan not in memory files"
        return memory_files["plan"]

    def _get_plan_content(self, input_data: Dict[str, Any]):
        """Return the current plan text, or ``'No plan yet'`` when empty."""
        assert "plan" in input_data, "plan not passed to Extlib/Controller"
        plan = input_data["plan"]
        # len() (not truthiness) keeps the original failure mode for
        # non-sized values.
        return plan if len(plan) else 'No plan yet'

    def _get_logs_content(self, input_data: Dict[str, Any]):
        """Return the current logs text, or ``"No logs yet"`` when empty."""
        assert "logs" in input_data, "logs not passed to Extlib/Controller"
        logs = input_data["logs"]
        # len() (not truthiness) keeps the original failure mode for
        # non-sized values.
        return logs if len(logs) else "No logs yet"

    @classmethod
    def instantiate_from_config(cls, config):
        """Alternate constructor: build the flow from a config dict.

        Deep-copies the config, wires up prompts and backend via the
        base-class helpers, and parses the ``commands`` section (a mapping
        of name -> {description, input_args}) into ``Command`` objects.
        """
        flow_config = deepcopy(config)

        kwargs = {"flow_config": flow_config}

        # ~~~ Set up prompts ~~~
        kwargs.update(cls._set_up_prompts(flow_config))

        # ~~~Set up backend ~~~
        kwargs.update(cls._set_up_backend(flow_config))

        # ~~~ Set up commands ~~~
        kwargs["commands"] = [
            Command(cmd_name, cmd_conf["description"], cmd_conf["input_args"])
            for cmd_name, cmd_conf in flow_config["commands"].items()
        ]

        # ~~~ Instantiate flow ~~~
        return cls(**kwargs)

    def _update_prompts_and_input(self, input_data: Dict[str, Any]):
        """Refresh prompt state from the incoming payload.

        Mutates ``input_data`` in place: the JSON-format hint is appended
        to 'goal' and/or 'result' when present. Then re-partials the
        system prompt with the current plan location, plan text and logs.
        """
        # Nudge the model toward the required JSON response format.
        for key in ('goal', 'result'):
            if key in input_data:
                input_data[key] += self.hint_for_model

        self.system_message_prompt_template = self.system_message_prompt_template.partial(
            plan_file_location=self._get_plan_file_location(input_data),
            plan=self._get_plan_content(input_data),
            logs=self._get_logs_content(input_data)
        )

    def run(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
        """Query the model and return its next command as a parsed dict.

        Repeatedly calls the underlying chat flow until the output contains
        a parsable JSON object; on each failure, a corrective message is
        fed back as the new 'result'.

        :param input_data: payload with 'memory_files', 'plan', 'logs' and
            optionally 'goal'/'result'; mutated via _update_prompts_and_input.
        :return: the parsed JSON object (expected keys: "command",
            "command_args" per the hint prompt).

        NOTE(review): the retry loop has no iteration cap — it spins (and
        spends API calls) for as long as the model keeps returning
        unparsable output.
        """
        self._update_prompts_and_input(input_data)

        # ~~~when conversation is initialized, append the updated system prompts to the chat history ~~~
        if self._is_conversation_initialized():
            updated_system_message_content = self._get_message(self.system_message_prompt_template, input_data)
            self._state_update_add_chat_message(content=updated_system_message_content,
                                                role=self.flow_config["system_name"])

        while True:
            api_output = super().run(input_data)["api_output"].strip()
            try:
                # Extract the outermost {...} span: the model sometimes wraps
                # the JSON in prose or code fences.
                start = api_output.index("{")
                end = api_output.rindex("}") + 1
                return json.loads(api_output[start:end])
            # str.index/rindex raise ValueError when no braces are present,
            # and json.JSONDecodeError subclasses ValueError, so one clause
            # covers both failure modes.  (The original caught only
            # JSONDecodeError — listed twice — and crashed on brace-less
            # output instead of retrying.)
            except ValueError:
                updated_system_message_content = self._get_message(self.system_message_prompt_template, input_data)
                self._state_update_add_chat_message(content=updated_system_message_content,
                                                    role=self.flow_config["system_name"])
                new_goal = "The previous response cannot be parsed with json.loads. Next time, do not provide any comments or code blocks. Make sure your next response is purely json parsable."
                new_input_data = input_data.copy()
                new_input_data['result'] = new_goal
                input_data = new_input_data