# ReWOO-Demo / nodes / Planner.py
# Hosting-page metadata captured with this file: author billxbf, commit "init" (926675f), 1.5 kB.
from nodes.LLMNode import LLMNode
from nodes.Worker import WORKER_REGISTRY
from prompts.planner import *
from utils.util import LLAMA_WEIGHTS
class Planner(LLMNode):
    """LLM node that asks the model for a plan over a set of available workers.

    The prompt is assembled from a prefix, a generated listing of the
    configured workers ("tools"), few-shot examples, a suffix and the task
    input supplied to :meth:`run`.
    """

    def __init__(self, workers, prefix=DEFAULT_PREFIX, suffix=DEFAULT_SUFFIX, fewshot=DEFAULT_FEWSHOT,
                 model_name="text-davinci-003", stop=None):
        """Store the prompt pieces and pre-render the worker listing.

        Args:
            workers: Iterable of worker names; each one must be a key of
                ``WORKER_REGISTRY``.
            prefix: Text placed at the very start of the prompt.
            suffix: Text placed after the few-shot examples, right before
                the task input.
            fewshot: Few-shot examples placed after the worker listing.
            model_name: Model identifier forwarded to ``LLMNode.__init__``.
            stop: Stop value forwarded to ``LLMNode.__init__``.

        Raises:
            ValueError: If any name in ``workers`` is not registered.
        """
        super().__init__("Planner", model_name, stop, input_type=str, output_type=str)
        self.workers = workers
        self.prefix = prefix
        # Must come after ``self.workers`` is assigned: the helper iterates it.
        self.worker_prompt = self._generate_worker_prompt()
        self.suffix = suffix
        self.fewshot = fewshot

    def run(self, input, log=False):
        """Build the planner prompt for ``input`` and query the model.

        Args:
            input: The task text; must be an instance of ``self.input_type``.
            log: When true, return the whole response from ``call_llm``
                instead of only its ``"output"`` entry.

        Returns:
            The response returned by ``self.call_llm`` when ``log`` is true,
            otherwise ``response["output"]``.

        Raises:
            AssertionError: If ``input`` is not an instance of
                ``self.input_type``.
        """
        # Kept as an assertion so the exception type seen by callers is
        # unchanged; note that assertions are skipped under ``python -O``.
        assert isinstance(input, self.input_type), (
            f"Planner input must be {self.input_type!r}, got {type(input)!r}"
        )

        if self.model_name in LLAMA_WEIGHTS:
            # Models listed in LLAMA_WEIGHTS get a two-element list
            # [instruction, input]; few-shot examples and suffix are omitted.
            prompt = [self.prefix + self.worker_prompt, input]
        else:
            prompt = self.prefix + self.worker_prompt + self.fewshot + self.suffix + input + '\n'

        response = self.call_llm(prompt, self.stop)
        if log:
            return response
        return response["output"]

    def _get_worker(self, name):
        """Return the ``WORKER_REGISTRY`` entry registered under ``name``.

        Raises:
            ValueError: If ``name`` is not a key of ``WORKER_REGISTRY``.
        """
        if name not in WORKER_REGISTRY:
            # Include the offending name so a misconfigured worker list is
            # easy to diagnose.
            raise ValueError(f"Worker not found: {name!r}")
        return WORKER_REGISTRY[name]

    def _generate_worker_prompt(self):
        """Render the tool listing: a header line, one line per worker, then a blank line.

        Each worker line has the form ``<name>[input]: <description>``, using
        the ``name`` and ``description`` attributes of the registry entry.

        Raises:
            ValueError: If a configured worker name is not registered.
        """
        tool_lines = []
        for name in self.workers:
            worker = self._get_worker(name)
            tool_lines.append(f"{worker.name}[input]: {worker.description}\n")
        return "Tools can be one of the following:\n" + "".join(tool_lines) + "\n"