Spaces:
Sleeping
Sleeping
# Main class chaining the Planner, Worker and Solver stages.
import re | |
import time | |
from nodes.Planner import Planner | |
from nodes.Solver import Solver | |
from nodes.Worker import * | |
from utils.util import * | |
class PWS:
    """Plan-Work-Solve pipeline chaining a Planner, tool Workers and a Solver.

    The planner's output is parsed into "Plan:" lines and evidence
    assignments of the form ``#E<n> = Tool[input]``. Each tool call is run by
    the matching worker (earlier evidence is substituted into later inputs),
    and the solver answers the question from the plans plus the collected
    evidence.
    """

    # Evidence variables written by the planner, e.g. "#E1", "#E12".
    _EVIDENCE_VAR = re.compile(r"#E\d+")

    def __init__(self, available_tools=None, fewshot="\n", planner_model="text-davinci-003",
                 solver_model="text-davinci-003"):
        """Build the planner and solver and look up per-model token prices.

        Args:
            available_tools: tool names the planner may use; only these are
                dispatched through ``WORKER_REGISTRY``. Defaults to
                ``["Google", "LLM"]``.
            fewshot: few-shot prompt text passed to the Planner.
            planner_model: model name for the Planner (also used for pricing).
            solver_model: model name for the Solver (also used for pricing).
        """
        # Build the default per call: a list literal as the default value would
        # be one object shared by every PWS created without an explicit list.
        if available_tools is None:
            available_tools = ["Google", "LLM"]
        self.workers = available_tools
        self.planner = Planner(workers=self.workers,
                               model_name=planner_model,
                               fewshot=fewshot)
        self.solver = Solver(model_name=solver_model)
        self.plans = []
        self.planner_evidences = {}
        self.worker_evidences = {}
        self.tool_counter = {}
        self.planner_token_unit_price = get_token_unit_price(planner_model)
        self.solver_token_unit_price = get_token_unit_price(solver_model)
        # Tool-side token estimates are priced as text-davinci-003
        # (presumably the model the LLM/Calculator workers use -- confirm).
        self.tool_token_unit_price = get_token_unit_price("text-davinci-003")
        # Price per Google query; multiplied by the query count in run().
        self.google_unit_price = 0.01

    def run(self, input):
        """Answer one question with a single plan -> work -> solve pass.

        Args:
            input: the question line, e.g.
                ``"Question: What is the capital of France?"``.

        Returns:
            dict with keys ``wall_time`` (seconds), ``input``, ``output``,
            ``planner_log``, ``worker_log``, ``solver_log``, ``tool_usage``,
            ``steps`` (number of plans plus one), ``total_tokens``,
            ``token_cost``, ``tool_cost`` and ``total_cost``.
        """
        # run is stateless, so state from any previous call is reset first.
        self._reinitialize()
        start = time.time()

        # Plan
        planner_response = self.planner.run(input, log=True)
        plan = planner_response["output"]
        planner_log = planner_response["input"] + planner_response["output"]
        self.plans = self._parse_plans(plan)
        self.planner_evidences = self._parse_planner_evidences(plan)

        # Work
        self._get_worker_evidences()
        # Plan i is paired with evidence #E{i} by position. The planner may emit
        # fewer evidence lines than plans, so a missing one is logged as
        # "No evidence found" instead of raising KeyError.
        worker_log_parts = []
        for i, plan_line in enumerate(self.plans, start=1):
            evidence = self.worker_evidences.get(f"#E{i}", "No evidence found")
            worker_log_parts.append(f"{plan_line}\nEvidence:\n{evidence}\n")
        worker_log = "".join(worker_log_parts)

        # Solve
        solver_response = self.solver.run(input, worker_log, log=True)
        output = solver_response["output"]
        solver_log = solver_response["input"] + solver_response["output"]

        planner_tokens = planner_response["prompt_tokens"] + planner_response["completion_tokens"]
        solver_tokens = solver_response["prompt_tokens"] + solver_response["completion_tokens"]
        # Tool token counts are character-based estimates from _get_worker_evidences.
        tool_tokens = self.tool_counter.get("LLM_token", 0) + self.tool_counter.get("Calculator_token", 0)

        result = {}
        result["wall_time"] = time.time() - start
        result["input"] = input
        result["output"] = output
        result["planner_log"] = planner_log
        result["worker_log"] = worker_log
        result["solver_log"] = solver_log
        result["tool_usage"] = self.tool_counter
        result["steps"] = len(self.plans) + 1
        result["total_tokens"] = planner_tokens + solver_tokens + tool_tokens
        result["token_cost"] = (self.planner_token_unit_price * planner_tokens
                                + self.solver_token_unit_price * solver_tokens
                                + self.tool_token_unit_price * tool_tokens)
        result["tool_cost"] = self.tool_counter.get("Google", 0) * self.google_unit_price
        result["total_cost"] = result["token_cost"] + result["tool_cost"]
        return result

    def _parse_plans(self, response):
        """Return the lines of ``response`` that start with ``"Plan:"``, in order."""
        return [line for line in response.splitlines() if line.startswith("Plan:")]

    def _parse_planner_evidences(self, response):
        """Map evidence variables to their tool calls from ``#E<n> = ...`` lines.

        Lines that do not start with ``#E<digit>`` are ignored, as are such
        lines with no ``=`` (previously they raised). A left-hand side that is
        not exactly ``#E<digits>`` (e.g. ``#E2x``) is kept with the value
        ``"No evidence found"``.
        """
        evidences = {}
        for line in response.splitlines():
            # match() anchors at the start, so "#", "#E" and "#Ex" are skipped
            # instead of raising IndexError on short lines.
            if not self._EVIDENCE_VAR.match(line):
                continue
            e, sep, tool_call = line.partition("=")
            if not sep:
                continue  # malformed: no assignment on this line
            e, tool_call = e.strip(), tool_call.strip()
            # fullmatch (not len == 3) so multi-digit ids like #E10 are accepted.
            if self._EVIDENCE_VAR.fullmatch(e):
                evidences[e] = tool_call
            else:
                evidences[e] = "No evidence found"
        return evidences

    @staticmethod
    def _split_tool_call(tool_call):
        """Split ``"Tool[input]"`` into ``("Tool", "input")``.

        The tool name is stripped of surrounding whitespace. Everything from the
        last ``]`` onward is dropped; if there is no ``]`` the input is kept
        whole rather than losing its final character.
        """
        tool, tool_input = tool_call.split("[", 1)
        end = tool_input.rfind("]")
        if end != -1:
            tool_input = tool_input[:end]
        return tool.strip(), tool_input

    def _substitute_evidences(self, tool_input):
        """Replace each known ``#E<n>`` in ``tool_input`` with ``[<evidence>]``.

        Variables are matched as whole tokens in a single pass, so ``#E1`` never
        rewrites the prefix of ``#E10``. Text inserted from one evidence is not
        re-scanned, and unknown variables are left unchanged.
        """
        def replace(match):
            var = match.group(0)
            if var in self.worker_evidences:
                return "[" + self.worker_evidences[var] + "]"
            return var

        return self._EVIDENCE_VAR.sub(replace, tool_input)

    def _get_worker_evidences(self):
        """Run each planned tool call and store its result in ``self.worker_evidences``.

        Calls run in ``self.planner_evidences`` insertion order, so an input may
        reference evidence produced by an earlier call. A call without ``[`` is
        stored verbatim as its own evidence. A tool not listed in
        ``self.workers`` yields ``"No evidence found"``. Usage counters in
        ``self.tool_counter`` are updated for Google, LLM and Calculator.
        """
        for e, tool_call in self.planner_evidences.items():
            if "[" not in tool_call:
                self.worker_evidences[e] = tool_call
                continue
            tool, tool_input = self._split_tool_call(tool_call)
            tool_input = self._substitute_evidences(tool_input)
            if tool not in self.workers:
                self.worker_evidences[e] = "No evidence found"
                continue
            evidence = WORKER_REGISTRY[tool].run(tool_input)
            self.worker_evidences[e] = evidence
            if tool == "Google":
                self.tool_counter["Google"] = self.tool_counter.get("Google", 0) + 1  # number of queries
            elif tool == "LLM":
                # Token estimate: characters // 4.
                self.tool_counter["LLM_token"] = (self.tool_counter.get("LLM_token", 0)
                                                  + len(tool_input + evidence) // 4)
            elif tool == "Calculator":
                # The LLMMathChain prompt template is included in the estimate,
                # presumably because the Calculator worker sends it to the LLM.
                template = LLMMathChain(llm=OpenAI(), verbose=False).prompt.template
                self.tool_counter["Calculator_token"] = (self.tool_counter.get("Calculator_token", 0)
                                                         + len(template + tool_input + evidence) // 4)

    def _reinitialize(self):
        """Clear per-run state (plans, evidences, tool usage counters)."""
        self.plans = []
        self.planner_evidences = {}
        self.worker_evidences = {}
        self.tool_counter = {}
class PWS_Base(PWS):
    """PWS preset defaulting to the ``fewshots.HOTPOTQA_PWS_BASE`` prompt and Wikipedia + LLM tools."""

    def __init__(self, fewshot=fewshots.HOTPOTQA_PWS_BASE, planner_model="text-davinci-003",
                 solver_model="text-davinci-003", available_tools=None):
        """Forward to ``PWS.__init__`` with this preset's defaults.

        Args:
            fewshot: few-shot prompt text for the Planner.
            planner_model: model name for the Planner.
            solver_model: model name for the Solver.
            available_tools: tool names to enable; defaults to
                ``["Wikipedia", "LLM"]``.
        """
        # Build the default per call so instances never share one list object.
        if available_tools is None:
            available_tools = ["Wikipedia", "LLM"]
        super().__init__(available_tools=available_tools,
                         fewshot=fewshot,
                         planner_model=planner_model,
                         solver_model=solver_model)
class PWS_Extra(PWS):
    """PWS preset defaulting to the ``fewshots.HOTPOTQA_PWS_EXTRA`` prompt and Google + Calculator + LLM tools."""

    def __init__(self, fewshot=fewshots.HOTPOTQA_PWS_EXTRA, planner_model="text-davinci-003",
                 solver_model="text-davinci-003", available_tools=None):
        """Forward to ``PWS.__init__`` with this preset's defaults.

        Args:
            fewshot: few-shot prompt text for the Planner.
            planner_model: model name for the Planner.
            solver_model: model name for the Solver.
            available_tools: tool names to enable; defaults to
                ``["Google", "Calculator", "LLM"]``.
        """
        # Build the default per call so instances never share one list object.
        if available_tools is None:
            available_tools = ["Google", "Calculator", "LLM"]
        super().__init__(available_tools=available_tools,
                         fewshot=fewshot,
                         planner_model=planner_model,
                         solver_model=solver_model)