zwgao's picture
add file
3fdcc70
import json
from typing import List
import logging
import os
import sys
sys.path.append(os.getcwd())
from ..base import Action, NON_FILE_TYPES
# from cllm.services.tog import TaskSolver, TaskDecomposer, config
# from cllm.services.nlp.llms import ChatOpenAI, MessageMemory
from cllm.services.tog.api import tog, task_decomposer
from collections import OrderedDict
import copy
# Module-scoped logger; every Planner method reports progress and errors through it.
logger = logging.getLogger(name=__name__)
class Planner:
def __init__(
self, streaming=False, backend="remote", device="cuda:0", **llm_kwargs
):
self.streaming = streaming
if backend == "local":
pass
# self.cfg = config
# self.device = device
# self.mem = MessageMemory(**self.cfg.memory)
# self.llm = ChatOpenAI(temperature=0.2, **llm_kwargs)
# self.tog = TaskSolver(self.llm, self.cfg.task_solver_config, device).solve
# self.decomposer = TaskDecomposer(device, self.cfg.task_decomposer_cfg).solve
elif backend == "remote":
self.decomposer = task_decomposer
self.tog = tog
else:
raise ValueError("Backend should be chosen from [remote, local]")
def _find_latest_resource(self, resources, type):
for key, val in list(resources.items())[::-1]:
if val == type:
return key
return None
def _check_task_decomposition(
self, task_decomposition: str | list, available_resources: dict
):
copy_task_decomposition = copy.deepcopy(task_decomposition)
available_resources = copy.deepcopy(available_resources)
if isinstance(copy_task_decomposition, str):
copy_task_decomposition = json.loads(copy_task_decomposition)
for subtask in copy_task_decomposition:
for arg in subtask["args"]:
if arg["type"] in NON_FILE_TYPES:
continue
r_type = available_resources.get(arg["value"], "None").split(".")[-1]
if arg["value"] not in available_resources or arg["type"] != r_type:
new_value = self._find_latest_resource(
available_resources, arg["type"]
)
if new_value is None:
logger.error(
f"No available resource for {arg['value']} with type {arg['type']}"
)
return None
arg["value"] = new_value
available_resources[subtask["returns"][0]["value"]] = subtask["returns"][0][
"type"
]
return json.dumps(copy_task_decomposition, indent=2, ensure_ascii=False)
def wrap_request(self, request, memory):
logger.info(memory)
resource_list = {k: v.split(".")[-1] for k, v in memory.items()}
request = f"Resource list: {resource_list}\n{request}"
logger.info(f"Input: {request}")
return request
def solve_streaming(self, request: str, memory: dict = OrderedDict()):
request = self.wrap_request(request, memory)
sub_tasks = self.decomposer(request, streaming=self.streaming)
logger.info(f"Task decomposition: \n{sub_tasks}")
sub_tasks = self._check_task_decomposition(sub_tasks, memory)
yield sub_tasks
if sub_tasks in [None, "", []]:
yield None
else:
solutions = self.tog(request, sub_tasks, streaming=self.streaming)
yield solutions
def solve(self, request: str, memory: dict = OrderedDict()) -> List:
self.wrap_request(request, memory)
sub_tasks = self.decomposer(request)
solutions = self.tog(request, sub_tasks)
print(f"solutions: {solutions}")
return sub_tasks, solutions
def plan(self, task, memory: dict = OrderedDict()) -> List:
if self.streaming:
return self.solve_streaming(task, memory)
else:
return self.solve(task, memory)
def _check_solutions(self, solution: List | str) -> bool:
if isinstance(solution, str):
solution = json.loads(solution)
if len(solution) == 0:
return False
valid = True
for i, stage_candiate in enumerate(solution):
if len(stage_candiate) == 0:
logger.error(f"No solution is found in {i}-th subtask.")
valid = False
elif (
"solution" in stage_candiate[0]
and len(stage_candiate[0]["solution"]) == 0
):
logger.error(f"No solution is found in {i+1}-th subtask.")
valid = False
else:
logger.info(f"Solutions for {i+1}-th subtask:\n{stage_candiate}")
return valid
def parse(self, solution: List | str) -> List[List[Action]]:
if isinstance(solution, str):
solution = json.loads(solution)
if not self._check_solutions(solution):
return None
if isinstance(solution[0], Action):
return solution
stages = []
for i, stage_candiate in enumerate(solution):
stage = stage_candiate[0]["solution"]
actions = []
for action in stage:
inputs = {arg["name"]: arg["value"] for arg in action["args"]}
outputs = [r["value"] for r in action["returns"]]
actions.append(
Action(action["tool_name"], inputs=inputs, outputs=outputs)
)
stages.append(actions)
return stages
def __call__(
self, request: str, memory: dict = OrderedDict()
) -> List[List[Action]]:
solution = self.solve(request, memory)
return self.parse(solution)