hzxwonder committed • Commit ac09190 • 1 parent: c640769
update

Files changed:
- deciders/utils.py  +26 -0
- distillers/self_reflection.py  +1 -1
- main_reflexion.py  +10 -10

deciders/utils.py  CHANGED

@@ -39,6 +39,22 @@ def get_completion(prompt: str, api_type: str = "azure", engine: str = "gpt-35-t
             # request_timeout = 1
         )
         return response.choices[0].text
+    elif api_type == "openai":
+        messages = [
+            {
+                "role": "user",
+                "content": prompt
+            }
+        ]
+        response = openai.ChatCompletion.create(
+            model=engine,
+            messages=messages,
+            max_tokens=max_tokens,
+            stop=stop_strs,
+            temperature=temperature,
+            # request_timeout = 1
+        )
+        return response.choices[0]["message"]["content"]
 
 # @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
 def get_chat(prompt: str, api_type: str = "azure", model: str = "gpt-35-turbo", engine: str = "gpt-35-turbo", temperature: float = 0.0, max_tokens: int = 256, stop_strs: Optional[List[str]] = None, is_batched: bool = False) -> str:
@@ -60,4 +76,14 @@ def get_chat(prompt: str, api_type: str = "azure", model: str = "gpt-35-turbo",
             # request_timeout = 1
         )
         return response.choices[0]["message"]["content"]
+    elif api_type == "openai":
+        response = openai.ChatCompletion.create(
+            model=model,
+            messages=messages,
+            max_tokens=max_tokens,
+            stop=stop_strs,
+            temperature=temperature,
+            # request_timeout = 1
+        )
+        return response.choices[0]["message"]["content"]
 
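
The added branches route requests through the plain OpenAI endpoint (the pre-1.0 openai SDK's ChatCompletion API) whenever api_type is "openai", instead of the Azure deployment path. A minimal usage sketch, assuming openai<1.0 is installed and OPENAI_API_KEY is set in the environment; the prompt text and the "gpt-3.5-turbo" model name are illustrative, not part of the commit:

import os
import openai
from deciders.utils import get_chat

# Assumption: legacy openai SDK (<1.0) that still exposes openai.ChatCompletion.
openai.api_key = os.environ["OPENAI_API_KEY"]

# With api_type="openai", get_chat now calls openai.ChatCompletion.create with
# `model` rather than the Azure `engine` deployment name.
reply = get_chat(
    "Describe the agent's next action in one sentence.",  # illustrative prompt
    api_type="openai",
    model="gpt-3.5-turbo",
    max_tokens=64,
)
print(reply)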

distillers/self_reflection.py  CHANGED

@@ -45,7 +45,7 @@ class RefletionGenerator():
         for i, m in enumerate(memory):
             query += f'Trial #{i}: {m}\n'
 
-        query += '\n\
+        query += '\n\nPlease give your new plan.'
         return query
 
     def generate(self, traj, memory, max_len_mem=5):
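
The one-line change makes the reflection query end with an explicit instruction to produce a new plan. A standalone sketch of the query being assembled, with made-up memory entries; the preamble string is hypothetical, only the loop and the appended instruction mirror the diff:

# Made-up trial memories, purely for illustration.
memory = ["fell into the hole at step 7", "reached the goal in 24 steps"]

query = "Your memory of past trials:\n"  # hypothetical preamble
for i, m in enumerate(memory):
    query += f'Trial #{i}: {m}\n'
query += '\n\nPlease give your new plan.'
print(query)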

main_reflexion.py  CHANGED

@@ -99,7 +99,7 @@ def _run(translator, environment, decider, max_episode_len, logfile, args, trail
     error_flag = True
     retry_num = 1
     for error_i in range(retry_num):
-
+        try:
             action, prompt, response, tokens, cost = decider.act(
                 state_description,
                 action_description,
@@ -127,15 +127,15 @@ def _run(translator, environment, decider, max_episode_len, logfile, args, trail
             current_total_cost += cost
             error_flag = False
             break
-
-
-
-
-
-
-
-
-
+        except Exception as e:
+            print(e)
+            if error_i < retry_num-1:
+                if "Cliff" in args.env_name or "Frozen" in args.env_name:
+                    decider.env_history.remove_invalid_state()
+                    decider.env_history.remove_invalid_state()
+                if logger:
+                    logger.debug(f"Error: {e}, Retry! ({error_i+1}/{retry_num})")
+                continue
     if error_flag:
         if "Continuous" in args.env_name:
             action = [decider.default_action]
|