File size: 4,627 Bytes
bdadbc2
 
258e7ae
bdadbc2
 
258e7ae
 
 
 
 
 
 
 
 
 
 
bdadbc2
258e7ae
 
 
 
 
 
527ef24
258e7ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bdadbc2
 
 
 
 
 
 
 
 
 
 
258e7ae
bdadbc2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
258e7ae
bdadbc2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
258e7ae
 
 
 
 
 
bdadbc2
 
 
 
 
258e7ae
 
 
 
 
 
 
 
bdadbc2
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import json
from copy import deepcopy
from typing import Any, Dict, List

from flow_modules.aiflows.ChatFlowModule import ChatAtomicFlow


from dataclasses import dataclass


@dataclass()
class Command:
    """One command the controller model may issue to a subordinate.

    Instances are rendered into the "commands" manual of the controller's
    system prompt.
    """

    # Command identifier; taken from the key under ``commands`` in the flow config.
    name: str
    # Free-text explanation of what the command does, shown to the model.
    description: str
    # Names of the arguments the model must supply in ``command_args``.
    input_args: List[str]


class Controller_CoderFlow(ChatAtomicFlow):
    """Chat-based controller that asks the model which subordinate command to run.

    The system prompt is seeded with a manual of the available commands.
    On every ``run`` it is refreshed with the current plan, code library and
    logs taken from the input data. The model is instructed to reply with a
    JSON object of the form ``{"command": ..., "command_args": {...}}``.
    If the reply cannot be parsed as JSON, the model is asked exactly once
    more.
    """

    def __init__(
            self,
            commands: List[Command],
            **kwargs):
        """Build the controller and pre-fill the system prompt placeholders.

        Args:
            commands: commands the model may choose from; rendered into the
                ``commands`` slot of the system prompt.
            **kwargs: forwarded unchanged to ``ChatAtomicFlow.__init__``.
        """
        super().__init__(**kwargs)
        # Placeholder values until ``run`` supplies real content from its input.
        self.system_message_prompt_template = self.system_message_prompt_template.partial(
            commands=self._build_commands_manual(commands),
            plan="no plans yet",
            plan_file_location="no location yet",
            code_library="no code library yet",
            code_library_location="no location yet",
            logs="no logs yet",
        )
        # Appended to the model-facing input to restate the expected reply format.
        self.hint_for_model = """
        Make sure your response is in the following format:
              Response Format:
              {
              "command": "call one of the subordinates",
              "command_args": {
                  "arg name": "value"
                  }
              }
        """

    @staticmethod
    def _build_commands_manual(commands: List[Command]) -> str:
        """Render ``commands`` as a numbered, newline-terminated list for the prompt.

        Each entry reads ``"<n>. <name>: <description> Input arguments (given in
        the JSON schema): <schema>"``. The schema is a JSON object mapping every
        input argument to a ``YOUR_<ARG>`` placeholder. Numbering starts at 1;
        an empty ``commands`` yields ``""``.
        """
        entries = []
        for index, command in enumerate(commands, start=1):
            command_input_json_schema = json.dumps(
                {input_arg: f"YOUR_{input_arg.upper()}" for input_arg in command.input_args})
            entries.append(
                f"{index}. {command.name}: {command.description} "
                f"Input arguments (given in the JSON schema): {command_input_json_schema}\n"
            )
        return "".join(entries)

    def _get_content_file_location(self, input_data, content_name):
        """Return ``input_data["memory_files"][content_name]``.

        The value is presumably the file location of a memory file (plan,
        code_library, ...); this is not checked here.

        Raises:
            AssertionError: if ``memory_files`` or ``content_name`` is missing.
        """
        assert "memory_files" in input_data, "memory_files not passed to Coder/Controller"
        assert content_name in input_data["memory_files"], f"{content_name} not in memory files"
        return input_data["memory_files"][content_name]

    def _get_content(self, input_data, content_name):
        """Return ``input_data[content_name]``, or ``"No <content_name> yet"`` if it is empty.

        Raises:
            AssertionError: if ``content_name`` is not a key of ``input_data``.
        """
        assert content_name in input_data, f"{content_name} not passed to Coder/Controller"
        content = input_data[content_name]
        # A truthiness check also covers ``None``, which ``len()`` would reject.
        if not content:
            content = f'No {content_name} yet'
        return content

    @classmethod
    def instantiate_from_config(cls, config):
        """Build the flow from a config mapping.

        ``config["commands"]`` must map each command name to a dict with
        ``"description"`` and ``"input_args"`` keys. Prompt and backend
        construction is delegated to ``_set_up_prompts`` / ``_set_up_backend``
        (presumably inherited from ``ChatAtomicFlow``).
        """
        # Deep copy so later mutation of the caller's config cannot affect this flow.
        flow_config = deepcopy(config)

        kwargs = {"flow_config": flow_config}

        # ~~~ Set up prompts ~~~
        kwargs.update(cls._set_up_prompts(flow_config))

        # ~~~Set up backend ~~~
        kwargs.update(cls._set_up_backend(flow_config))

        # ~~~ Set up commands ~~~
        kwargs["commands"] = [
            Command(name, command_conf["description"], command_conf["input_args"])
            for name, command_conf in flow_config["commands"].items()
        ]

        # ~~~ Instantiate flow ~~~
        return cls(**kwargs)

    def _update_prompts_and_input(self, input_data: Dict[str, Any]):
        """Add the reply-format hint to the input and refresh the system prompt.

        Mutates ``input_data`` in place. The hint is appended to ``goal`` and
        ``result`` when present. ``input_data`` must contain ``plan``,
        ``code_library`` and ``logs``, and ``memory_files`` must contain
        ``plan`` and ``code_library`` (enforced by assertions).
        """
        if 'goal' in input_data:
            input_data['goal'] += self.hint_for_model
        if 'result' in input_data:
            input_data['result'] += self.hint_for_model
        plan_file_location = self._get_content_file_location(input_data, "plan")
        plan_content = self._get_content(input_data, "plan")
        code_library_location = self._get_content_file_location(input_data, "code_library")
        code_library_content = self._get_content(input_data, "code_library")
        logs_content = self._get_content(input_data, "logs")
        self.system_message_prompt_template = self.system_message_prompt_template.partial(
            plan_file_location=plan_file_location,
            plan=plan_content,
            code_library_location=code_library_location,
            code_library=code_library_content,
            logs=logs_content
        )

    def run(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
        """Query the model and return its reply parsed as JSON.

        If the first reply is not valid JSON, the model is queried once more.
        That retry sets ``result`` to a message asking for JSON output,
        followed by the format hint.

        Raises:
            json.JSONDecodeError: if the retried reply is still not valid JSON.
        """
        # Shallow copy: the hint appended below must not leak into the caller's
        # dict or accumulate if the same dict is passed again.
        input_data = dict(input_data)
        self._update_prompts_and_input(input_data)
        api_output = super().run(input_data)["api_output"].strip()
        try:
            return json.loads(api_output)
        except json.decoder.JSONDecodeError:
            pass

        retry_input_data = dict(input_data)
        # Restate the expected format, as _update_prompts_and_input does for 'result'.
        retry_input_data['result'] = (
            "The previous respond cannot be parsed with json.loads. "
            "Make sure your next response is in JSON format."
            + self.hint_for_model
        )
        new_api_output = super().run(retry_input_data)["api_output"].strip()
        return json.loads(new_api_output)