# NOTE(review): stripped non-code page-scrape residue ("Spaces: / Running / Running")
# that preceded the module and broke parsing.
from nodes.LLMNode import LLMNode
from nodes.Worker import WORKER_REGISTRY
# NOTE(review): star import kept as-is — DEFAULT_PREFIX / DEFAULT_SUFFIX /
# DEFAULT_FEWSHOT come from this module; prefer explicit names once that
# module's public API is confirmed.
from prompts.planner import *
from utils.util import LLAMA_WEIGHTS
class Planner(LLMNode):
    """LLM node that drafts a plan for a task using the registered workers.

    The prompt is assembled as: prefix + worker-tool descriptions + few-shot
    examples + suffix + the task input. For LLaMA-family models (per
    ``LLAMA_WEIGHTS``) the prompt is instead a two-element list of
    (system-style text, task input).
    """

    def __init__(self, workers, prefix=DEFAULT_PREFIX, suffix=DEFAULT_SUFFIX, fewshot=DEFAULT_FEWSHOT,
                 model_name="text-davinci-003", stop=None):
        """Create a Planner.

        Args:
            workers: iterable of worker names; each must exist in
                ``WORKER_REGISTRY`` (looked up eagerly below, so a typo
                fails fast at construction time).
            prefix / suffix / fewshot: prompt sections (see class docstring).
            model_name: LLM identifier passed to the base node.
            stop: stop sequence(s) forwarded to the LLM call.
        """
        super().__init__("Planner", model_name, stop, input_type=str, output_type=str)
        self.workers = workers
        self.prefix = prefix
        self.worker_prompt = self._generate_worker_prompt()
        self.suffix = suffix
        self.fewshot = fewshot

    def run(self, input, log=False):
        """Run the planner on a task string.

        Args:
            input: the task description; must be an instance of
                ``self.input_type`` (``str``).
            log: when True, return the raw LLM response dict instead of
                just the completion text.

        Returns:
            The completion string, or the full response dict when ``log``.

        Raises:
            TypeError: if ``input`` has the wrong type.
        """
        # Explicit check rather than `assert`: asserts are stripped under -O.
        if not isinstance(input, self.input_type):
            raise TypeError(
                f"Planner.run expects {self.input_type.__name__}, "
                f"got {type(input).__name__}"
            )
        # Build only the prompt shape this model actually consumes
        # (the original always built the flat string, then threw it away
        # for LLaMA models).
        if self.model_name in LLAMA_WEIGHTS:
            # LLaMA-style: (system-ish text, user input) pair.
            prompt = [self.prefix + self.worker_prompt, input]
        else:
            prompt = self.prefix + self.worker_prompt + self.fewshot + self.suffix + input + '\n'
        response = self.call_llm(prompt, self.stop)
        # Index "output" before the log check, as the original did, so a
        # malformed response fails loudly in both modes.
        completion = response["output"]
        if log:
            return response
        return completion

    def _get_worker(self, name):
        """Look up a worker by name in the global registry.

        Raises:
            ValueError: if ``name`` is not registered (name included so a
                typo in ``workers`` is easy to spot).
        """
        try:
            return WORKER_REGISTRY[name]
        except KeyError:
            raise ValueError(f"Worker not found: {name}") from None

    def _generate_worker_prompt(self):
        """Render the tool-description section of the prompt.

        One line per worker: ``Name[input]: description``, with a header and
        a trailing blank line — byte-identical to the original formatting.
        """
        body = "".join(
            f"{worker.name}[input]: {worker.description}\n"
            for worker in map(self._get_worker, self.workers)
        )
        return "Tools can be one of the following:\n" + body + "\n"