import json
from typing import List
import logging
import os
import sys

sys.path.append(os.getcwd())

from ..base import Action, NON_FILE_TYPES

# from cllm.services.tog import TaskSolver, TaskDecomposer, config
# from cllm.services.nlp.llms import ChatOpenAI, MessageMemory
from cllm.services.tog.api import tog, task_decomposer
from collections import OrderedDict
import copy

logger = logging.getLogger(__name__)
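

# Planner drives the tool-use pipeline implemented below: it decomposes a user
# request into subtasks, remaps subtask arguments onto resources already recorded
# in `memory`, asks the remote `tog` service for solutions, and parses them into
# stages of `Action` objects.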
class Planner:
    def __init__(
        self, streaming=False, backend="remote", device="cuda:0", **llm_kwargs
    ):
        self.streaming = streaming
        if backend == "local":
            pass
            # self.cfg = config
            # self.device = device
            # self.mem = MessageMemory(**self.cfg.memory)
            # self.llm = ChatOpenAI(temperature=0.2, **llm_kwargs)
            # self.tog = TaskSolver(self.llm, self.cfg.task_solver_config, device).solve
            # self.decomposer = TaskDecomposer(device, self.cfg.task_decomposer_cfg).solve
        elif backend == "remote":
            self.decomposer = task_decomposer
            self.tog = tog
        else:
            raise ValueError("Backend should be chosen from [remote, local]")
    def _find_latest_resource(self, resources, r_type):
        """Return the most recently added resource whose type matches ``r_type``."""
        for key, val in list(resources.items())[::-1]:
            if val == r_type:
                return key
        return None
    def _check_task_decomposition(
        self, task_decomposition: str | list, available_resources: dict
    ):
        """Rebind subtask arguments to resources that actually exist in memory.

        Returns the checked decomposition as a JSON string, or None when no
        resource of the required type is available for some argument.
        """
        copy_task_decomposition = copy.deepcopy(task_decomposition)
        available_resources = copy.deepcopy(available_resources)
        if isinstance(copy_task_decomposition, str):
            copy_task_decomposition = json.loads(copy_task_decomposition)
        for subtask in copy_task_decomposition:
            for arg in subtask["args"]:
                if arg["type"] in NON_FILE_TYPES:
                    continue
                r_type = available_resources.get(arg["value"], "None").split(".")[-1]
                if arg["value"] not in available_resources or arg["type"] != r_type:
                    new_value = self._find_latest_resource(
                        available_resources, arg["type"]
                    )
                    if new_value is None:
                        logger.error(
                            f"No available resource for {arg['value']} with type {arg['type']}"
                        )
                        return None
                    arg["value"] = new_value
            available_resources[subtask["returns"][0]["value"]] = subtask["returns"][0][
                "type"
            ]
        return json.dumps(copy_task_decomposition, indent=2, ensure_ascii=False)
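    # Shape assumed for each decomposed subtask, inferred from the keys accessed
    # above (the real service output may carry additional keys):
    #
    #   {
    #       "args":    [{"type": "image", "value": "image_1.png", ...}],
    #       "returns": [{"type": "text", "value": "text_1", ...}]
    #   }
    #
    # "image_1.png" and "text_1" are hypothetical resource names used only for
    # illustration.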
    def wrap_request(self, request, memory):
        """Prepend the list of available resources (name -> type) to the request."""
        logger.info(memory)
        resource_list = {k: v.split(".")[-1] for k, v in memory.items()}
        request = f"Resource list: {resource_list}\n{request}"
        logger.info(f"Input: {request}")
        return request
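    # For example (hypothetical values), a memory of {"image_1.png": "xxx.image"}
    # would produce a wrapped request such as:
    #
    #   Resource list: {'image_1.png': 'image'}
    #   <original request text>
    #
    # Only the last dot-separated segment of each memory value is kept as the type.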
    def solve_streaming(self, request: str, memory: dict = OrderedDict()):
        """Yield the checked task decomposition, then the solutions from the solver."""
        request = self.wrap_request(request, memory)
        sub_tasks = self.decomposer(request, streaming=self.streaming)
        logger.info(f"Task decomposition: \n{sub_tasks}")
        sub_tasks = self._check_task_decomposition(sub_tasks, memory)
        yield sub_tasks
        if sub_tasks in [None, "", []]:
            yield None
        else:
            solutions = self.tog(request, sub_tasks, streaming=self.streaming)
            yield solutions
    def solve(self, request: str, memory: dict = OrderedDict()) -> List:
        request = self.wrap_request(request, memory)
        sub_tasks = self.decomposer(request)
        solutions = self.tog(request, sub_tasks)
        logger.info(f"solutions: {solutions}")
        return sub_tasks, solutions
    def plan(self, task, memory: dict = OrderedDict()) -> List:
        if self.streaming:
            return self.solve_streaming(task, memory)
        else:
            return self.solve(task, memory)
    def _check_solutions(self, solution: List | str) -> bool:
        """Return True when every subtask has at least one non-empty solution."""
        if isinstance(solution, str):
            solution = json.loads(solution)
        if len(solution) == 0:
            return False
        valid = True
        for i, stage_candidate in enumerate(solution):
            if len(stage_candidate) == 0:
                logger.error(f"No solution is found in {i + 1}-th subtask.")
                valid = False
            elif (
                "solution" in stage_candidate[0]
                and len(stage_candidate[0]["solution"]) == 0
            ):
                logger.error(f"No solution is found in {i + 1}-th subtask.")
                valid = False
            else:
                logger.info(f"Solutions for {i + 1}-th subtask:\n{stage_candidate}")
        return valid
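    # Structure assumed for `solution`, inferred from the indexing in parse()
    # below (one candidate list per subtask; only the first candidate is used):
    #
    #   [
    #       [  # candidates for subtask 0
    #           {"solution": [{"tool_name": "...", "args": [...], "returns": [...]}]},
    #           ...
    #       ],
    #       ...
    #   ]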
    def parse(self, solution: List | str) -> List[List[Action]]:
        """Convert raw solutions into stages of Action objects, one stage per subtask."""
        if isinstance(solution, str):
            solution = json.loads(solution)
        if not self._check_solutions(solution):
            return None
        if isinstance(solution[0], Action):
            return solution
        stages = []
        for stage_candidate in solution:
            # Take the first candidate's solution for this subtask.
            stage = stage_candidate[0]["solution"]
            actions = []
            for action in stage:
                inputs = {arg["name"]: arg["value"] for arg in action["args"]}
                outputs = [r["value"] for r in action["returns"]]
                actions.append(
                    Action(action["tool_name"], inputs=inputs, outputs=outputs)
                )
            stages.append(actions)
        return stages
    def __call__(
        self, request: str, memory: dict = OrderedDict()
    ) -> List[List[Action]]:
        _, solutions = self.solve(request, memory)
        return self.parse(solutions)
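

# Minimal usage sketch: it assumes the remote `tog` / `task_decomposer` services
# are reachable and that `memory` maps resource names to dotted strings whose last
# segment is the resource type; the resource name and request below are purely
# illustrative. The relative import above means this module must be run with
# `python -m` from its parent package rather than as a standalone script.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    planner = Planner(streaming=False, backend="remote")
    memory = OrderedDict([("image_1.png", "image_1.png.image")])  # hypothetical resource
    stages = planner("Describe the objects in image_1.png", memory)
    for stage in stages or []:
        for action in stage:
            logger.info(action)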