|
import os |
|
import re |
|
import sys |
|
# Extend the import path so the project-level imports below (e.g. `agents`,
# `tools`) can presumably be resolved: first the parent of the *current*
# working directory -- computed before the chdir, so statement order
# matters -- then the hard-coded project root. Afterwards, switch the working
# directory to the folder containing this script.
sys.path.extend([
    os.path.abspath(os.path.join(os.getcwd(), "..")),
    '/home/xj/toolAugEnv/code/toolConstraint',
])
os.chdir(os.path.dirname(os.path.abspath(__file__)))
|
from agents.prompts import planner_agent_prompt, cot_planner_agent_prompt, react_planner_agent_prompt,react_reflect_planner_agent_prompt,reflect_prompt |
|
|
|
import json |
|
import time |
|
from langchain.callbacks import get_openai_callback |
|
|
|
from tqdm import tqdm |
|
from tools.planner.apis import Planner, ReactPlanner, ReactReflectPlanner |
|
import openai |
|
|
|
# Send HTTP and HTTPS traffic from this process through a local proxy on
# port 7890 (presumably needed to reach the LLM APIs used by the planner;
# libraries that honor *_proxy env vars will pick this up).
os.environ.update(
    http_proxy="http://127.0.0.1:7890",
    https_proxy="http://127.0.0.1:7890",
)
|
|
|
|
|
|
|
def load_line_json_data(filename):
    """Load a JSON Lines file into a list of decoded records.

    Each non-blank line of *filename* is decoded with ``json.loads``. Blank
    or whitespace-only lines -- including an entirely empty file or trailing
    newlines -- are skipped instead of being fed to the decoder.

    Args:
        filename: Path to a UTF-8 encoded JSON Lines file.

    Returns:
        A list holding one decoded JSON value per non-blank line, in file
        order. An empty file yields ``[]``.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # Iterate line by line instead of read().split('\n'): the old form turned
    # an empty file (or any blank line) into json.loads('') and crashed.
    with open(filename, 'r', encoding='utf-8') as f:
        return [json.loads(line) for line in f if line.strip()]
|
|
|
def extract_numbers_from_filenames(directory):
    """Return the integer IDs of ``annotation_<N>.json`` files in *directory*.

    Only entries whose whole name is ``annotation_`` + digits + ``.json`` are
    counted. For example, ``annotation_12.json`` yields ``12``, while
    ``annotation_12.json.bak`` and ``annotation_12xjson`` are ignored.

    Args:
        directory: Path of the directory to scan (non-recursive).

    Returns:
        A list of ints in ``os.listdir`` order, which is arbitrary (not sorted).
    """
    # Escape the dot and require a full-name match so that backups or
    # near-miss names are not counted; compile once for the whole listing.
    pattern = re.compile(r'annotation_(\d+)\.json')
    return [
        int(match.group(1))
        for match in map(pattern.fullmatch, os.listdir(directory))
        if match
    ]
|
|
|
|
|
def catch_openai_api_error():
    """Print a short label describing the exception currently being handled.

    Meant to be called from inside an ``except`` clause. It reads the active
    exception class from ``sys.exc_info()`` and compares it by exact equality
    (``==``, not ``issubclass``) against a few ``openai.error`` classes, in
    this order: APIConnectionError, RateLimitError, APIError,
    AuthenticationError. The first match is printed by name. For
    RateLimitError it then sleeps 60 seconds, presumably so the caller's
    retry backs off. Anything else, including ``None`` when no exception is
    active, is printed as ``API error: <class>``.

    NOTE(review): ``openai.error`` is the module layout of pre-1.0 releases
    of the openai package -- confirm the pinned version.
    """
    exc_type = sys.exc_info()[0]
    for label in ("APIConnectionError", "RateLimitError", "APIError", "AuthenticationError"):
        # Look up each class lazily, one at a time, stopping at the first match.
        if exc_type == getattr(openai.error, label):
            print(label)
            if label == "RateLimitError":
                time.sleep(60)
            return
    print("API error:", exc_type)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Experiment configuration: pick one option from each list by index.
    model_name = ['gpt-3.5-turbo-1106', 'gpt-4-1106-preview', 'gemini', 'mixtral'][1]
    set_type = ['dev', 'test'][0]
    method = ['direct', 'cot', 'react', 'reflexion'][0]

    directory = f'/home/xj/toolAugEnv/code/toolConstraint/data/final_data/{set_type}'
    # Output location: one plan_<N>.json file per query.
    results_dir = f'/home/xj/toolAugEnv/code/toolConstraint/results/{set_type}'

    query_data_list = load_line_json_data(os.path.join(directory, 'query/query.jsonl'))
    # Query numbers are 1-based: query N is query_data_list[N - 1].
    numbers = list(range(1, len(query_data_list) + 1))

    if method == 'direct':
        planner = Planner(model_name=model_name, agent_prompt=planner_agent_prompt)
    elif method == 'cot':
        planner = Planner(model_name=model_name, agent_prompt=cot_planner_agent_prompt)
    elif method == 'react':
        planner = ReactPlanner(model_name=model_name, agent_prompt=react_planner_agent_prompt)
    elif method == 'reflexion':
        planner = ReactReflectPlanner(
            model_name=model_name,
            agent_prompt=react_reflect_planner_agent_prompt,
            reflect_prompt=reflect_prompt,
        )
    else:
        # Fail fast here rather than with a NameError on `planner` later.
        raise ValueError(f"Unknown method: {method!r}")

    # Create the output directory once instead of re-checking every iteration.
    os.makedirs(results_dir, exist_ok=True)

    with get_openai_callback() as cb:
        for number in tqdm(numbers):
            # Use `with` so each input file is closed promptly.
            with open(os.path.join(directory, f'plan/human_collected_info_{number}.json')) as f:
                human_collected_info_data = json.load(f)
            query_data = query_data_list[number - 1]

            # Re-run the planner until it returns a non-None result.
            # NOTE(review): this retries with no limit and no delay if the
            # planner keeps returning None -- consider bounding it.
            while True:
                if method in ('react', 'reflexion'):
                    planner_results, scratchpad = planner.run(human_collected_info_data, query_data['query'])
                else:
                    planner_results = planner.run(human_collected_info_data, query_data['query'])
                if planner_results is not None:
                    break
            print(planner_results)

            plan_path = os.path.join(results_dir, f'plan_{number}.json')
            # Merge into an existing plan file so that keys written by other
            # model/method runs into the last record are preserved.
            if os.path.exists(plan_path):
                with open(plan_path) as f:
                    result = json.load(f)
            else:
                result = [{}]
            if not result:
                # An existing but empty list would make result[-1] raise IndexError.
                result = [{}]
            if method in ('react', 'reflexion'):
                result[-1][f'{model_name}_{method}_collected_info_results_logs'] = scratchpad
            result[-1][f'{model_name}_{method}_collected_info_results'] = planner_results

            with open(plan_path, 'w') as f:
                json.dump(result, f, indent=4)
        print(cb)