|
import json |
|
from copy import deepcopy |
|
from typing import Any, Dict, List |
|
from flow_modules.aiflows.OpenAIChatFlowModule import OpenAIChatAtomicFlow |
|
|
|
from dataclasses import dataclass |
|
|
|
|
|
@dataclass()
class Command:
    """Description of one command that ``ControllerAtomicFlow`` lists in its prompt.

    Fields are positional in declaration order: ``name``, ``description``,
    ``input_args``.
    """

    # Command identifier, printed first in the rendered command manual.
    name: str
    # Free-text explanation printed after the name in the command manual.
    description: str
    # Argument names; each becomes a placeholder key in the example JSON input.
    input_args: List[str]
|
|
|
|
|
class ControllerAtomicFlow(OpenAIChatAtomicFlow):
    """Chat flow that presents a manual of commands to the LLM and parses its JSON reply.

    The numbered command manual is rendered once at construction time and bound
    into the system-message prompt template as the ``commands`` variable.
    """

    def __init__(self, commands: List[Command], **kwargs):
        """Initialise the underlying chat flow and bind the command manual.

        Args:
            commands: Commands to list in the system prompt, in display order.
            **kwargs: Forwarded unchanged to ``OpenAIChatAtomicFlow.__init__``.
        """
        super().__init__(**kwargs)
        # Pre-fill the ``commands`` template variable via ``partial`` so the
        # manual is rendered only once rather than on every prompt build.
        self.system_message_prompt_template = self.system_message_prompt_template.partial(
            commands=self._build_commands_manual(commands)
        )

    @staticmethod
    def _build_commands_manual(commands: List[Command]) -> str:
        """Render ``commands`` as a numbered manual, one newline-terminated line each.

        Each line has the form::

            <n>. <name>: <description> Input arguments (given in the JSON schema): {"arg": "YOUR_ARG", ...}

        Numbering starts at 1. An empty ``commands`` list yields ``""``.
        """
        lines = []
        for number, command in enumerate(commands, start=1):
            # Example payload showing which keys the model should fill in,
            # e.g. {"query": "YOUR_QUERY"}.
            command_input_json_schema = json.dumps(
                {input_arg: f"YOUR_{input_arg.upper()}" for input_arg in command.input_args}
            )
            lines.append(
                f"{number}. {command.name}: {command.description} "
                f"Input arguments (given in the JSON schema): {command_input_json_schema}\n"
            )
        # join avoids the quadratic cost of repeated string concatenation.
        return "".join(lines)

    @classmethod
    def instantiate_from_config(cls, config):
        """Build a controller from a configuration mapping.

        ``config`` is deep-copied before use, so this method does not mutate the
        caller's object. Besides whatever ``_set_up_prompts`` reads, it must
        contain a ``"commands"`` mapping of command name to a dict with
        ``"description"`` and ``"input_args"`` entries.

        Raises:
            KeyError: if ``"commands"`` or a command's ``"description"`` /
                ``"input_args"`` entry is missing.
        """
        flow_config = deepcopy(config)

        kwargs = {"flow_config": flow_config}
        kwargs.update(cls._set_up_prompts(flow_config))
        kwargs["commands"] = [
            Command(name, command_conf["description"], command_conf["input_args"])
            for name, command_conf in flow_config["commands"].items()
        ]
        return cls(**kwargs)

    def run(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
        """Run the chat flow and return the model's reply decoded from JSON.

        A Markdown code fence around the reply (```` ``` ```` or
        ```` ```json ````) is removed before decoding, since chat models often
        wrap JSON answers that way.

        Raises:
            json.JSONDecodeError: if the (unfenced) reply is not valid JSON.
        """
        api_output = super().run(input_data)["api_output"].strip()
        return json.loads(self._strip_code_fences(api_output))

    @staticmethod
    def _strip_code_fences(text: str) -> str:
        """Return ``text`` without a surrounding Markdown code fence, if it has one.

        Text that does not start with ```` ``` ```` is returned unchanged. Valid
        JSON can never start with a backtick, so this never alters parseable input.
        """
        if not text.startswith("```"):
            return text
        # Drop the opening fence line, which may carry a language tag like ``json``.
        _, _, body = text.partition("\n")
        body = body.rstrip()
        if body.endswith("```"):
            body = body[: -len("```")]
        return body.strip()
|
|