hsaest committed
Commit 47ec500
1 Parent(s): c623a92

Delete tools/planner

tools/planner/__pycache__/apis.cpython-39.pyc DELETED
Binary file (11.2 kB)
 
tools/planner/__pycache__/env.cpython-39.pyc DELETED
Binary file (5.9 kB)
 
tools/planner/apis.py DELETED
@@ -1,388 +0,0 @@
- import sys
- import os
- sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))
- from langchain.prompts import PromptTemplate
- from agents.prompts import planner_agent_prompt, cot_planner_agent_prompt, react_planner_agent_prompt,reflect_prompt,react_reflect_planner_agent_prompt, REFLECTION_HEADER
- from langchain.chat_models import ChatOpenAI
- from langchain.llms.base import BaseLLM
- from langchain.schema import (
-     AIMessage,
-     HumanMessage,
-     SystemMessage
- )
- from env import ReactEnv,ReactReflectEnv
- import tiktoken
- import re
- import openai
- import time
- from enum import Enum
- from typing import List, Union, Literal
- from langchain_google_genai import ChatGoogleGenerativeAI
-
-
- def catch_openai_api_error():
-     error = sys.exc_info()[0]
-     if error == openai.error.APIConnectionError:
-         print("APIConnectionError")
-     elif error == openai.error.RateLimitError:
-         print("RateLimitError")
-         time.sleep(60)
-     elif error == openai.error.APIError:
-         print("APIError")
-     elif error == openai.error.AuthenticationError:
-         print("AuthenticationError")
-     else:
-         print("API error:", error)
-
-
- class ReflexionStrategy(Enum):
-     """
-     REFLEXION: Apply reflexion to the next reasoning trace
-     """
-     REFLEXION = 'reflexion'
-
-
- class Planner:
-     def __init__(self,
-                  # args,
-                  agent_prompt: PromptTemplate = planner_agent_prompt,
-                  model_name: str = 'gpt-3.5-turbo-1106',
-                  ) -> None:
-
-         self.agent_prompt = agent_prompt
-         self.scratchpad: str = ''
-         self.model_name = model_name
-         self.enc = tiktoken.encoding_for_model("gpt-3.5-turbo")
-
-         if model_name in ['mistral-7B-32K']:
-             self.llm = ChatOpenAI(temperature=0,
-                                   max_tokens=4096,
-                                   openai_api_key="EMPTY",
-                                   openai_api_base="http://localhost:8301/v1",
-                                   model_name="gpt-3.5-turbo")
-
-         if model_name in ['ChatGLM3-6B-32K']:
-             self.llm = ChatOpenAI(temperature=0,
-                                   max_tokens=4096,
-                                   openai_api_key="EMPTY",
-                                   openai_api_base="http://localhost:8501/v1",
-                                   model_name="gpt-3.5-turbo")
-
-         elif model_name in ['mixtral']:
-             self.max_token_length = 30000
-             self.llm = ChatOpenAI(temperature=0,
-                                   max_tokens=4096,
-                                   openai_api_key="EMPTY",
-                                   openai_api_base="http://10.176.40.135:8000/v1",
-                                   model_name="/home/huggingface_models/models--mistralai--Mixtral-8x7B-Instruct-v0.1/snapshots/e0bbb53cee412aba95f3b3fa4fc0265b1a0788b2")
-
-         elif model_name in ['gemini']:
-             self.llm = ChatGoogleGenerativeAI(temperature=0,model="gemini-pro",google_api_key='AIzaSyDarE2hG-cCeE6-GzNcEHflQa4kjY0QCK0')
-         else:
-             self.llm = ChatOpenAI(model_name=model_name, temperature=0, max_tokens=4096, openai_api_key='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')
-
-
-         print(f"PlannerAgent {model_name} loaded.")
-
-     def run(self, text, query, log_file=None) -> str:
-         if log_file:
-             log_file.write('\n---------------Planner\n'+self._build_agent_prompt(text, query))
-         # print(self._build_agent_prompt(text, query))
-         if self.model_name in ['gemini']:
-             return str(self.llm.invoke(self._build_agent_prompt(text, query)).content)
-         else:
-             if len(self.enc.encode(self._build_agent_prompt(text, query))) > 12000:
-                 return 'Max Token Length Exceeded.'
-             else:
-                 return self.llm([HumanMessage(content=self._build_agent_prompt(text, query))]).content
-
-     def _build_agent_prompt(self, text, query) -> str:
-         return self.agent_prompt.format(
-             text=text,
-             query=query)
-
-
- class ReactPlanner:
-     """
-     A question answering ReAct Agent.
-     """
-     def __init__(self,
-                  agent_prompt: PromptTemplate = react_planner_agent_prompt,
-                  model_name: str = 'gpt-3.5-turbo-1106',
-                  ) -> None:
-
-         self.agent_prompt = agent_prompt
-         self.react_llm = ChatOpenAI(model_name=model_name, temperature=0, max_tokens=1024, openai_api_key='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',model_kwargs={"stop": ["Action","Thought","Observation"]})
-         self.env = ReactEnv()
-         self.query = None
-         self.max_steps = 30
-         self.reset()
-         self.finished = False
-         self.answer = ''
-         self.enc = tiktoken.encoding_for_model("gpt-3.5-turbo")
-
-     def run(self, text, query, reset = True) -> None:
-
-         self.query = query
-         self.text = text
-
-         if reset:
-             self.reset()
-
-
-         while not (self.is_halted() or self.is_finished()):
-             self.step()
-
-         return self.answer, self.scratchpad
-
-
-     def step(self) -> None:
-         # Think
-         self.scratchpad += f'\nThought {self.curr_step}:'
-         self.scratchpad += ' ' + self.prompt_agent()
-         print(self.scratchpad.split('\n')[-1])
-
-         # Act
-         self.scratchpad += f'\nAction {self.curr_step}:'
-         action = self.prompt_agent()
-         self.scratchpad += ' ' + action
-         print(self.scratchpad.split('\n')[-1])
-
-         # Observe
-         self.scratchpad += f'\nObservation {self.curr_step}: '
-
-         action_type, action_arg = parse_action(action)
-
-         if action_type == 'CostEnquiry':
-             try:
-                 input_arg = eval(action_arg)
-                 if type(input_arg) != dict:
-                     raise ValueError('The sub plan can not be parsed into json format, please check. Only one day plan is supported.')
-                 observation = f'Cost: {self.env.run(input_arg)}'
-             except SyntaxError:
-                 observation = f'The sub plan can not be parsed into json format, please check.'
-             except ValueError as e:
-                 observation = str(e)
-
-         elif action_type == 'Finish':
-             self.finished = True
-             observation = f'The plan is finished.'
-             self.answer = action_arg
-
-         else:
-             observation = f'Action {action_type} is not supported.'
-
-         self.curr_step += 1
-
-         self.scratchpad += observation
-         print(self.scratchpad.split('\n')[-1])
-
-     def prompt_agent(self) -> str:
-         while True:
-             try:
-                 return format_step(self.react_llm([HumanMessage(content=self._build_agent_prompt())]).content)
-             except:
-                 catch_openai_api_error()
-                 print(self._build_agent_prompt())
-                 print(len(self.enc.encode(self._build_agent_prompt())))
-                 time.sleep(5)
-
-     def _build_agent_prompt(self) -> str:
-         return self.agent_prompt.format(
-             query = self.query,
-             text = self.text,
-             scratchpad = self.scratchpad)
-
-     def is_finished(self) -> bool:
-         return self.finished
-
-     def is_halted(self) -> bool:
-         return ((self.curr_step > self.max_steps) or (
-                 len(self.enc.encode(self._build_agent_prompt())) > 14000)) and not self.finished
-
-     def reset(self) -> None:
-         self.scratchpad = ''
-         self.answer = ''
-         self.curr_step = 1
-         self.finished = False
-
-
- class ReactReflectPlanner:
-     """
-     A question answering Self-Reflecting React Agent.
-     """
-     def __init__(self,
-                  agent_prompt: PromptTemplate = react_reflect_planner_agent_prompt,
-                  reflect_prompt: PromptTemplate = reflect_prompt,
-                  model_name: str = 'gpt-3.5-turbo-1106',
-                  ) -> None:
-
-         self.agent_prompt = agent_prompt
-         self.reflect_prompt = reflect_prompt
-         if model_name in ['gemini']:
-             self.react_llm = ChatGoogleGenerativeAI(temperature=0,model="gemini-pro",google_api_key='AIzaSyDarE2hG-cCeE6-GzNcEHflQa4kjY0QCK0')
-             self.reflect_llm = ChatGoogleGenerativeAI(temperature=0,model="gemini-pro",google_api_key='AIzaSyDarE2hG-cCeE6-GzNcEHflQa4kjY0QCK0')
-         else:
-             self.react_llm = ChatOpenAI(model_name=model_name, temperature=0, max_tokens=1024, openai_api_key='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',model_kwargs={"stop": ["Action","Thought","Observation"]})
-             self.reflect_llm = ChatOpenAI(model_name=model_name, temperature=0, max_tokens=1024, openai_api_key='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',model_kwargs={"stop": ["Action","Thought","Observation"]})
-         self.model_name = model_name
-         self.env = ReactReflectEnv()
-         self.query = None
-         self.max_steps = 30
-         self.reset()
-         self.finished = False
-         self.answer = ''
-         self.reflections: List[str] = []
-         self.reflections_str: str = ''
-         self.enc = tiktoken.encoding_for_model("gpt-3.5-turbo")
-
-     def run(self, text, query, reset = True) -> None:
-
-         self.query = query
-         self.text = text
-
-         if reset:
-             self.reset()
-
-
-         while not (self.is_halted() or self.is_finished()):
-             self.step()
-             if self.env.is_terminated and not self.finished:
-                 self.reflect(ReflexionStrategy.REFLEXION)
-
-
-         return self.answer, self.scratchpad
-
-
-     def step(self) -> None:
-         # Think
-         self.scratchpad += f'\nThought {self.curr_step}:'
-         self.scratchpad += ' ' + self.prompt_agent()
-         print(self.scratchpad.split('\n')[-1])
-
-         # Act
-         self.scratchpad += f'\nAction {self.curr_step}:'
-         action = self.prompt_agent()
-         self.scratchpad += ' ' + action
-         print(self.scratchpad.split('\n')[-1])
-
-         # Observe
-         self.scratchpad += f'\nObservation {self.curr_step}: '
-
-         action_type, action_arg = parse_action(action)
-
-         if action_type == 'CostEnquiry':
-             try:
-                 input_arg = eval(action_arg)
-                 if type(input_arg) != dict:
-                     raise ValueError('The sub plan can not be parsed into json format, please check. Only one day plan is supported.')
-                 observation = f'Cost: {self.env.run(input_arg)}'
-             except SyntaxError:
-                 observation = f'The sub plan can not be parsed into json format, please check.'
-             except ValueError as e:
-                 observation = str(e)
-
-         elif action_type == 'Finish':
-             self.finished = True
-             observation = f'The plan is finished.'
-             self.answer = action_arg
-
-         else:
-             observation = f'Action {action_type} is not supported.'
-
-         self.curr_step += 1
-
-         self.scratchpad += observation
-         print(self.scratchpad.split('\n')[-1])
-
-     def reflect(self, strategy: ReflexionStrategy) -> None:
-         print('Reflecting...')
-         if strategy == ReflexionStrategy.REFLEXION:
-             self.reflections += [self.prompt_reflection()]
-             self.reflections_str = format_reflections(self.reflections)
-         else:
-             raise NotImplementedError(f'Unknown reflection strategy: {strategy}')
-         print(self.reflections_str)
-
-     def prompt_agent(self) -> str:
-         while True:
-             try:
-                 if self.model_name in ['gemini']:
-                     return format_step(self.react_llm.invoke(self._build_agent_prompt()).content)
-                 else:
-                     return format_step(self.react_llm([HumanMessage(content=self._build_agent_prompt())]).content)
-             except:
-                 catch_openai_api_error()
-                 print(self._build_agent_prompt())
-                 print(len(self.enc.encode(self._build_agent_prompt())))
-                 time.sleep(5)
-
-     def prompt_reflection(self) -> str:
-         while True:
-             try:
-                 if self.model_name in ['gemini']:
-                     return format_step(self.reflect_llm.invoke(self._build_reflection_prompt()).content)
-                 else:
-                     return format_step(self.reflect_llm([HumanMessage(content=self._build_reflection_prompt())]).content)
-             except:
-                 catch_openai_api_error()
-                 print(self._build_reflection_prompt())
-                 print(len(self.enc.encode(self._build_reflection_prompt())))
-                 time.sleep(5)
-
-     def _build_agent_prompt(self) -> str:
-         return self.agent_prompt.format(
-             query = self.query,
-             text = self.text,
-             scratchpad = self.scratchpad,
-             reflections = self.reflections_str)
-
-     def _build_reflection_prompt(self) -> str:
-         return self.reflect_prompt.format(
-             query = self.query,
-             text = self.text,
-             scratchpad = self.scratchpad)
-
-     def is_finished(self) -> bool:
-         return self.finished
-
-     def is_halted(self) -> bool:
-         return ((self.curr_step > self.max_steps) or (
-                 len(self.enc.encode(self._build_agent_prompt())) > 14000)) and not self.finished
-
-     def reset(self) -> None:
-         self.scratchpad = ''
-         self.answer = ''
-         self.curr_step = 1
-         self.finished = False
-         self.reflections = []
-         self.reflections_str = ''
-         self.env.reset()
-
- def format_step(step: str) -> str:
-     return step.strip('\n').strip().replace('\n', '')
-
- def parse_action(string):
-     pattern = r'^(\w+)\[(.+)\]$'
-     match = re.match(pattern, string)
-
-     try:
-         if match:
-             action_type = match.group(1)
-             action_arg = match.group(2)
-             return action_type, action_arg
-         else:
-             return None, None
-
-     except:
-         return None, None
-
- def format_reflections(reflections: List[str],
-                        header: str = REFLECTION_HEADER) -> str:
-     if reflections == []:
-         return ''
-     else:
-         return header + 'Reflections:\n- ' + '\n- '.join([r.strip() for r in reflections])
-
- # if __name__ == '__main__':
-
 
tools/planner/env.py DELETED
@@ -1,208 +0,0 @@
- from tools.flights.apis import Flights
- from tools.accommodations.apis import Accommodations
- from tools.restaurants.apis import Restaurants
- from tools.googleDistanceMatrix.apis import GoogleDistanceMatrix
- from tools.googlePlaces.apis import GooglePlaces
- from tools.attractions.apis import Attractions
- from evaluation.hardConstriant import extract_from_to,get_valid_name_city
- import math
-
-
-
-
-
- class ReactEnv:
-     def __init__(self):
-
-         self.flight = Flights()
-         self.accommodation = Accommodations()
-         self.restaurants = Restaurants()
-         self.googleDistanceMatrix = GoogleDistanceMatrix()
-         self.googlePlaces = GooglePlaces()
-         self.attractions = Attractions()
-
-     def run(self, tested_data):
-
-         total_cost = 0
-         unit = tested_data
-         people_number = tested_data['people_number']
-         returned_info = []
-
-         if 'transportation' in unit and unit['transportation'] and unit['transportation'] != '-':
-             value = unit['transportation']
-             org_city, dest_city = extract_from_to(value)
-             if org_city == None or dest_city == None:
-                 org_city, dest_city = extract_from_to(unit['current_city'])
-             if 'flight number' in value.lower():
-                 try:
-                     res = self.flight.data[self.flight.data['Flight Number'] == value.split('Flight Number: ')[1].split(',')[0]]
-                     if len(res) > 0:
-                         total_cost += res['Price'].values[0] * people_number
-                     else:
-                         returned_info.append('The filght information is not valid')
-                 except:
-                     returned_info.append('The filght information is not valid')
-
-             elif 'self-driving' in value.lower() or 'taxi' in value.lower():
-                 try:
-                     if 'self-driving' in value.lower():
-                         # print(org_city,dest_city)
-                         cost = self.googleDistanceMatrix.run_for_evaluation(org_city,dest_city,'self-driving')['cost']
-                         if cost == None:
-                             returned_info.append('The transporation information is not valid, please check.')
-                         else:
-                             total_cost += cost * math.ceil(people_number * 1.0 / 5)
-                     else:
-                         cost = self.googleDistanceMatrix.run_for_evaluation(org_city,dest_city,'taxi')['cost']
-                         if cost == None:
-                             returned_info.append('The transporation information is not valid, please check.')
-                         else:
-                             total_cost += cost * math.ceil(people_number * 1.0 / 4)
-                 except:
-                     returned_info.append('The transporation information is not valid, please check. You have to make sure there are two cities (from A to B) in your transportation plan.')
-
-         if 'breakfast' in unit and unit['breakfast'] and unit['breakfast'] != '-':
-             name, city = get_valid_name_city(unit['breakfast'])
-             if name != '-' and city != '-':
-                 res = self.restaurants.data[(self.restaurants.data['Name'] == name) & (self.restaurants.data['City'] == city)]
-                 if len(res) > 0:
-                     total_cost += res['Average Cost'].values[0] * people_number
-                 else:
-                     returned_info.append('The breakfase information is not valid, please check.')
-
-         if 'lunch' in unit and unit['lunch'] and unit['lunch'] != '-':
-             name, city = get_valid_name_city(unit['lunch'])
-             if name != '-' and city != '-':
-                 res = self.restaurants.data[(self.restaurants.data['Name'] == name) & (self.restaurants.data['City'] == city)]
-                 if len(res) > 0:
-                     total_cost += res['Average Cost'].values[0] * people_number
-                 else:
-                     returned_info.append('The lunch information is not valid, please check.')
-
-         if 'dinner' in unit and unit['dinner'] and unit['dinner'] != '-':
-             name, city = get_valid_name_city(unit['dinner'])
-             if name != '-' and city != '-':
-                 res = self.restaurants.data[(self.restaurants.data['Name'] == name) & (self.restaurants.data['City'] == city)]
-                 if len(res) > 0:
-                     total_cost += res['Average Cost'].values[0] * people_number
-                 else:
-                     returned_info.append('The dinner information is not valid, please check.')
-
-         if 'accommodation' in unit and unit['accommodation'] and unit['accommodation'] != '-':
-             name, city = get_valid_name_city(unit['accommodation'])
-             if name != '-' and city != '-':
-                 res = self.accommodation.data[(self.accommodation.data['NAME'] == name) & (self.accommodation.data['city'] == city)]
-                 if len(res) > 0:
-                     total_cost += res['price'].values[0] * math.ceil(people_number * 1.0 / res['maximum occupancy'].values[0])
-                 else:
-                     returned_info.append('The accommodation information is not valid, please check.')
-
-         if len(returned_info) == 0:
-             return "The cost of your plan is " + str(total_cost) + " dollars."
-         else:
-             message = "Sorry, the cost of your plan is not available because of the following reasons:"
-             for idx, info in enumerate(returned_info):
-                 message += str(idx + 1) + ". " + info + " " + '\t'
-             return message
-
- class ReactReflectEnv(ReactEnv):
-     def __init__(self):
-         super().__init__()
-         self.is_terminated = False
-         self.max_retry_step = 3
-         self.retry_step = 0
-
-     def reset(self):
-         self.is_terminated = False
-         self.retry_step = 0
-
-     def run(self, tested_data):
-         total_cost = 0
-         unit = tested_data
-         people_number = tested_data['people_number']
-         returned_info = []
-
-         if 'transportation' in unit and unit['transportation'] and unit['transportation'] != '-':
-             value = unit['transportation']
-             org_city, dest_city = extract_from_to(value)
-             if org_city == None or dest_city == None:
-                 org_city, dest_city = extract_from_to(unit['current_city'])
-
-
-             if org_city == None or dest_city == None:
-                 returned_info.append('The transporation information is not valid, please check.')
-
-             else:
-                 if 'flight number' in value.lower():
-                     try:
-                         res = self.flight.data[self.flight.data['Flight Number'] == value.split('Flight Number: ')[1].split(',')[0]]
-                         if len(res) > 0:
-                             total_cost += res['Price'].values[0] * people_number
-                         else:
-                             returned_info.append('The filght information is not valid')
-                     except:
-                         returned_info.append('The filght information is not valid')
-
-                 elif 'self-driving' in value.lower() or 'taxi' in value.lower():
-                     if 'self-driving' in value.lower():
-                         cost = self.googleDistanceMatrix.run_for_evaluation(org_city,dest_city,'self-driving')['cost']
-                         if cost == None:
-                             returned_info.append('The transporation information is not valid, please check.')
-                         else:
-                             total_cost += cost * math.ceil(people_number * 1.0 / 5)
-                     else:
-                         cost = self.googleDistanceMatrix.run_for_evaluation(org_city,dest_city,'taxi')['cost']
-                         if cost == None:
-                             returned_info.append('The transporation information is not valid, please check.')
-                         else:
-                             total_cost += cost * math.ceil(people_number * 1.0 / 4)
-
-         if 'breakfast' in unit and unit['breakfast'] and unit['breakfast'] != '-':
-             name, city = get_valid_name_city(unit['breakfast'])
-             if name != '-' and city != '-':
-                 res = self.restaurants.data[(self.restaurants.data['Name'] == name) & (self.restaurants.data['City'] == city)]
-                 if len(res) > 0:
-                     total_cost += res['Average Cost'].values[0] * people_number
-                 else:
-                     returned_info.append('The breakfase information is not valid, please check.')
-
-         if 'lunch' in unit and unit['lunch'] and unit['lunch'] != '-':
-             name, city = get_valid_name_city(unit['lunch'])
-             if name != '-' and city != '-':
-                 res = self.restaurants.data[(self.restaurants.data['Name'] == name) & (self.restaurants.data['City'] == city)]
-                 if len(res) > 0:
-                     total_cost += res['Average Cost'].values[0] * people_number
-                 else:
-                     returned_info.append('The lunch information is not valid, please check.')
-
-         if 'dinner' in unit and unit['dinner'] and unit['dinner'] != '-':
-             name, city = get_valid_name_city(unit['dinner'])
-             if name != '-' and city != '-':
-                 res = self.restaurants.data[(self.restaurants.data['Name'] == name) & (self.restaurants.data['City'] == city)]
-                 if len(res) > 0:
-                     total_cost += res['Average Cost'].values[0] * people_number
-                 else:
-                     returned_info.append('The dinner information is not valid, please check.')
-
-         if 'accommodation' in unit and unit['accommodation'] and unit['accommodation'] != '-':
-             name, city = get_valid_name_city(unit['accommodation'])
-             if name != '-' and city != '-':
-                 res = self.accommodation.data[(self.accommodation.data['NAME'] == name) & (self.accommodation.data['city'] == city)]
-                 if len(res) > 0:
-                     total_cost += res['price'].values[0] * math.ceil(people_number * 1.0 / res['maximum occupancy'].values[0])
-                 else:
-                     returned_info.append('The accommodation information is not valid, please check.')
-
-         if len(returned_info) == 0:
-             self.retry_step = 0
-             self.is_terminated = False
-             return "The cost of your plan is " + str(total_cost) + " dollars."
-         else:
-             message = "Sorry, the cost of your plan is not available because of the following reasons:"
-             for idx, info in enumerate(returned_info):
-                 message += str(idx + 1) + ". " + info + " " + '\t'
-             self.retry_step += 1
-             if self.retry_step >= self.max_retry_step:
-                 self.is_terminated = True
-             return message
-
 
tools/planner/planner_with_human_annotated_info.py DELETED
@@ -1,124 +0,0 @@
- import os
- import re
- import sys
- sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))
- sys.path.append('/home/xj/toolAugEnv/code/toolConstraint')
- # print(sys.path)
- os.chdir(os.path.dirname(os.path.abspath(__file__)))
- from agents.prompts import planner_agent_prompt, cot_planner_agent_prompt, react_planner_agent_prompt,react_reflect_planner_agent_prompt,reflect_prompt
- # from annotation.src.utils import get_valid_name_city,extract_before_parenthesis, extract_numbers_from_filenames
- import json
- import time
- from langchain.callbacks import get_openai_callback
-
- from tqdm import tqdm
- from tools.planner.apis import Planner, ReactPlanner, ReactReflectPlanner
- import openai
-
- os.environ["http_proxy"] = "http://127.0.0.1:7890"
- os.environ["https_proxy"] = "http://127.0.0.1:7890"
-
-
-
- def load_line_json_data(filename):
-     data = []
-     with open(filename, 'r', encoding='utf-8') as f:
-         for line in f.read().strip().split('\n'):
-             unit = json.loads(line)
-             data.append(unit)
-     return data
-
- def extract_numbers_from_filenames(directory):
-     # Define the pattern to match files
-     pattern = r'annotation_(\d+).json'
-
-     # List all files in the directory
-     files = os.listdir(directory)
-
-     # Extract numbers from filenames that match the pattern
-     numbers = [int(re.search(pattern, file).group(1)) for file in files if re.match(pattern, file)]
-
-     return numbers
-
-
- def catch_openai_api_error():
-     error = sys.exc_info()[0]
-     if error == openai.error.APIConnectionError:
-         print("APIConnectionError")
-     elif error == openai.error.RateLimitError:
-         print("RateLimitError")
-         time.sleep(60)
-     elif error == openai.error.APIError:
-         print("APIError")
-     elif error == openai.error.AuthenticationError:
-         print("AuthenticationError")
-     else:
-         print("API error:", error)
-
- # if __name__ == "__main__":
- #     user_name = 'zk'
- #     directory = '../../data/annotation/{}'.format(user_name)
- #     query_data_list = load_line_json_data('../../data/query/{}.jsonl'.format(user_name))
- #     numbers = extract_numbers_from_filenames(directory)
- #     with get_openai_callback() as cb:
- #         for number in tqdm(numbers[:10]):
- #             print(number)
- #             json_data = json.load(open(os.path.join(directory, 'annotation_{}.json'.format(number))))
- #             human_collected_info_data = json.load(open(os.path.join(directory, 'human_collected_info_{}.json'.format(number))))
- #             query_data = query_data_list[number-1]
- #             planner_results = planner.run(human_collected_info_data, query_data['query'])
- #             org_result = json.load(open(os.path.join('../../results/turbo16k-turbo16k/{}/plan_{}.json'.format(user_name,number))))
- #             # org_result.append({})
- #             org_result[-1]['chatgpt_human_collected_info_results'] = planner_results
- #             # write to json file
- #             # with open(os.path.join('../../results/turbo16k-turbo16k/{}/plan_{}.json'.format(user_name,number)), 'w') as f:
- #             #     json.dump(org_result, f, indent=4)
- #     print(cb)
-
- if __name__ == "__main__":
-     model_name=['gpt-3.5-turbo-1106','gpt-4-1106-preview','gemini','mixtral'][1]
-     set_type = ['dev','test'][0]
-     method = ['direct','cot','react','reflexion'][0]
-     directory = f'/home/xj/toolAugEnv/code/toolConstraint/data/final_data/{set_type}'
-     query_data_list = load_line_json_data(os.path.join(directory, 'query/query.jsonl'))
-     numbers = [i for i in range(1,len(query_data_list)+1)]
-
-     if method == 'direct':
-         planner = Planner(model_name=model_name, agent_prompt=planner_agent_prompt)
-     elif method == 'cot':
-         planner = Planner(model_name=model_name, agent_prompt=cot_planner_agent_prompt)
-     elif method == 'react':
-         planner = ReactPlanner(model_name=model_name, agent_prompt=react_planner_agent_prompt)
-     elif method == 'reflexion':
-         planner = ReactReflectPlanner(model_name=model_name, agent_prompt=react_reflect_planner_agent_prompt,reflect_prompt=reflect_prompt)
-
-
-     with get_openai_callback() as cb:
-         for number in tqdm(numbers[:]):
-             # print(number)
-             # json_data = json.load(open(os.path.join(directory, 'plan/annotation_{}.json'.format(number))))
-             human_collected_info_data = json.load(open(os.path.join(directory, 'plan/human_collected_info_{}.json'.format(number))))
-             query_data = query_data_list[number-1]
-
-             while True:
-                 if method in ['react','reflexion']:
-                     planner_results, scratchpad = planner.run(human_collected_info_data, query_data['query'])
-                 else:
-                     planner_results = planner.run(human_collected_info_data, query_data['query'])
-                 if planner_results != None:
-                     break
-             print(planner_results)
-             # check if the directory exists
-             if not os.path.exists(os.path.join(f'/home/xj/toolAugEnv/code/toolConstraint/results/{set_type}')):
-                 os.makedirs(os.path.join(f'/home/xj/toolAugEnv/code/toolConstraint/results/{set_type}'))
-             if not os.path.exists(os.path.join(f'/home/xj/toolAugEnv/code/toolConstraint/results/{set_type}/plan_{number}.json')):
-                 result = [{}]
-             else:
-                 result = json.load(open(os.path.join(f'/home/xj/toolAugEnv/code/toolConstraint/results/{set_type}/plan_{number}.json')))
-             if method in ['react','reflexion']:
-                 result[-1][f'{model_name}_{method}_collected_info_results_logs'] = scratchpad
-             result[-1][f'{model_name}_{method}_collected_info_results'] = planner_results
-             # write to json file
-             with open(os.path.join(f'/home/xj/toolAugEnv/code/toolConstraint/results/{set_type}/plan_{number}.json'), 'w') as f:
-                 json.dump(result, f, indent=4)
-     print(cb)
 
tools/planner/test.py DELETED
@@ -1 +0,0 @@
- print(eval("[ddd"))