hzxwonder committed on
Commit
ac09190
1 Parent(s): c640769
deciders/utils.py CHANGED
@@ -39,6 +39,22 @@ def get_completion(prompt: str, api_type: str = "azure", engine: str = "gpt-35-t
39
  # request_timeout = 1
40
  )
41
  return response.choices[0].text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
  # @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
44
  def get_chat(prompt: str, api_type: str = "azure", model: str = "gpt-35-turbo", engine: str = "gpt-35-turbo", temperature: float = 0.0, max_tokens: int = 256, stop_strs: Optional[List[str]] = None, is_batched: bool = False) -> str:
@@ -60,4 +76,14 @@ def get_chat(prompt: str, api_type: str = "azure", model: str = "gpt-35-turbo",
60
  # request_timeout = 1
61
  )
62
  return response.choices[0]["message"]["content"]
 
 
 
 
 
 
 
 
 
 
63
 
 
39
  # request_timeout = 1
40
  )
41
  return response.choices[0].text
42
+ elif api_type == "openai":
43
+ messages = [
44
+ {
45
+ "role": "user",
46
+ "content": prompt
47
+ }
48
+ ]
49
+ response = openai.ChatCompletion.create(
50
+ model=engine,
51
+ messages=messages,
52
+ max_tokens=max_tokens,
53
+ stop=stop_strs,
54
+ temperature=temperature,
55
+ # request_timeout = 1
56
+ )
57
+ return response.choices[0]["message"]["content"]
58
 
59
  # @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
60
  def get_chat(prompt: str, api_type: str = "azure", model: str = "gpt-35-turbo", engine: str = "gpt-35-turbo", temperature: float = 0.0, max_tokens: int = 256, stop_strs: Optional[List[str]] = None, is_batched: bool = False) -> str:
 
76
  # request_timeout = 1
77
  )
78
  return response.choices[0]["message"]["content"]
79
+ elif api_type == "openai":
80
+ response = openai.ChatCompletion.create(
81
+ model=model,
82
+ messages=messages,
83
+ max_tokens=max_tokens,
84
+ stop=stop_strs,
85
+ temperature=temperature,
86
+ # request_timeout = 1
87
+ )
88
+ return response.choices[0]["message"]["content"]
89
 
distillers/self_reflection.py CHANGED
@@ -45,7 +45,7 @@ class RefletionGenerator():
45
  for i, m in enumerate(memory):
46
  query += f'Trial #{i}: {m}\n'
47
 
48
- query += '\n\nNew plan:'
49
  return query
50
 
51
  def generate(self, traj, memory, max_len_mem=5):
 
45
  for i, m in enumerate(memory):
46
  query += f'Trial #{i}: {m}\n'
47
 
48
+ query += '\n\nPlease give your new plan.'
49
  return query
50
 
51
  def generate(self, traj, memory, max_len_mem=5):
main_reflexion.py CHANGED
@@ -99,7 +99,7 @@ def _run(translator, environment, decider, max_episode_len, logfile, args, trail
99
  error_flag = True
100
  retry_num = 1
101
  for error_i in range(retry_num):
102
- # try:
103
  action, prompt, response, tokens, cost = decider.act(
104
  state_description,
105
  action_description,
@@ -127,15 +127,15 @@ def _run(translator, environment, decider, max_episode_len, logfile, args, trail
127
  current_total_cost += cost
128
  error_flag = False
129
  break
130
- # except Exception as e:
131
- # print(e)
132
- # if error_i < retry_num-1:
133
- # if "Cliff" in args.env_name or "Frozen" in args.env_name:
134
- # decider.env_history.remove_invalid_state()
135
- # decider.env_history.remove_invalid_state()
136
- # if logger:
137
- # logger.debug(f"Error: {e}, Retry! ({error_i+1}/{retry_num})")
138
- # continue
139
  if error_flag:
140
  if "Continuous" in args.env_name:
141
  action = [decider.default_action]
 
99
  error_flag = True
100
  retry_num = 1
101
  for error_i in range(retry_num):
102
+ try:
103
  action, prompt, response, tokens, cost = decider.act(
104
  state_description,
105
  action_description,
 
127
  current_total_cost += cost
128
  error_flag = False
129
  break
130
+ except Exception as e:
131
+ print(e)
132
+ if error_i < retry_num-1:
133
+ if "Cliff" in args.env_name or "Frozen" in args.env_name:
134
+ decider.env_history.remove_invalid_state()
135
+ decider.env_history.remove_invalid_state()
136
+ if logger:
137
+ logger.debug(f"Error: {e}, Retry! ({error_i+1}/{retry_num})")
138
+ continue
139
  if error_flag:
140
  if "Continuous" in args.env_name:
141
  action = [decider.default_action]