Spaces:
Runtime error
Runtime error
Synced repo using 'sync_with_huggingface' Github Action
Browse files
- app.py +7 -0
- deciders/act.py +2 -2
- deciders/cot.py +2 -2
- deciders/exe.py +2 -2
- deciders/reflexion.py +2 -2
- deciders/self_consistency.py +2 -2
- deciders/selfask.py +2 -2
- deciders/spp.py +2 -2
- deciders/utils.py +26 -4
- envs/__init__.py +16 -56
- envs/mujoco/invertedDoublePendulum_translator.py +14 -19
- envs/mujoco/invertedPendulum_translator.py +19 -14
- record_reflexion.csv +1 -2
- shell/test_reflexion.sh +28 -28
app.py
CHANGED
@@ -260,6 +260,7 @@ def main_progress(
|
|
260 |
|
261 |
if __name__ == "__main__":
|
262 |
|
|
|
263 |
|
264 |
# install Atari ROMs
|
265 |
subprocess.run(['AutoROM', '--accept-license'])
|
@@ -357,6 +358,12 @@ if __name__ == "__main__":
|
|
357 |
"FrozenLake-v1",
|
358 |
"MountainCarContinuous-v0",
|
359 |
"Ant-v4",
|
|
|
|
|
|
|
|
|
|
|
|
|
360 |
"RepresentedBoxing-v0",
|
361 |
"RepresentedPong-v0",
|
362 |
"RepresentedMsPacman-v0",
|
|
|
260 |
|
261 |
if __name__ == "__main__":
|
262 |
|
263 |
+
# Github action test 8
|
264 |
|
265 |
# install Atari ROMs
|
266 |
subprocess.run(['AutoROM', '--accept-license'])
|
|
|
358 |
"FrozenLake-v1",
|
359 |
"MountainCarContinuous-v0",
|
360 |
"Ant-v4",
|
361 |
+
"HalfCheetah-v4",
|
362 |
+
"Hopper-v4",
|
363 |
+
"Walker2d-v4",
|
364 |
+
"Swimmer-v4",
|
365 |
+
"Reacher-v4",
|
366 |
+
"Pusher-v4",
|
367 |
"RepresentedBoxing-v0",
|
368 |
"RepresentedPong-v0",
|
369 |
"RepresentedMsPacman-v0",
|
deciders/act.py
CHANGED
@@ -26,7 +26,7 @@ class RandomAct():
|
|
26 |
return action, '', '', '', 0, 0
|
27 |
|
28 |
class NaiveAct(gpt):
|
29 |
-
def __init__(self, action_space, args, prompts, distiller, temperature=0.0, max_tokens=2048, logger=None):
|
30 |
self.action_space = action_space
|
31 |
self.temperature = temperature
|
32 |
self.action_desc_dict = args.action_desc_dict
|
@@ -39,7 +39,7 @@ class NaiveAct(gpt):
|
|
39 |
else:
|
40 |
model = args.gpt_version
|
41 |
self.encoding = tiktoken.encoding_for_model(model)
|
42 |
-
super().__init__(args)
|
43 |
self.distiller = distiller
|
44 |
self.fewshot_example_initialization(args.prompt_level, args.prompt_path, distiller = self.distiller)
|
45 |
if isinstance(self.action_space, Discrete):
|
|
|
26 |
return action, '', '', '', 0, 0
|
27 |
|
28 |
class NaiveAct(gpt):
|
29 |
+
def __init__(self, openai_key, action_space, args, prompts, distiller, temperature=0.0, max_tokens=2048, logger=None):
|
30 |
self.action_space = action_space
|
31 |
self.temperature = temperature
|
32 |
self.action_desc_dict = args.action_desc_dict
|
|
|
39 |
else:
|
40 |
model = args.gpt_version
|
41 |
self.encoding = tiktoken.encoding_for_model(model)
|
42 |
+
super().__init__(args, openai_key)
|
43 |
self.distiller = distiller
|
44 |
self.fewshot_example_initialization(args.prompt_level, args.prompt_path, distiller = self.distiller)
|
45 |
if isinstance(self.action_space, Discrete):
|
deciders/cot.py
CHANGED
@@ -17,8 +17,8 @@ from .utils import run_chain
|
|
17 |
|
18 |
|
19 |
class ChainOfThought(NaiveAct):
|
20 |
-
def __init__(self, action_space, args, prompts, distiller, temperature=0.1, max_tokens=None, logger=None):
|
21 |
-
super().__init__(action_space, args, prompts, distiller, temperature, max_tokens,logger)
|
22 |
|
23 |
def act(
|
24 |
self,
|
|
|
17 |
|
18 |
|
19 |
class ChainOfThought(NaiveAct):
|
20 |
+
def __init__(self, openai_key, action_space, args, prompts, distiller, temperature=0.1, max_tokens=None, logger=None):
|
21 |
+
super().__init__(openai_key, action_space, args, prompts, distiller, temperature, max_tokens,logger)
|
22 |
|
23 |
def act(
|
24 |
self,
|
deciders/exe.py
CHANGED
@@ -20,8 +20,8 @@ from loguru import logger
|
|
20 |
|
21 |
|
22 |
class EXE(NaiveAct):
|
23 |
-
def __init__(self, action_space, args, prompts, distiller, temperature=0., max_tokens=None, logger=None, fixed_suggestion=None, fixed_insight=None):
|
24 |
-
super().__init__(action_space, args, prompts, distiller, temperature, max_tokens, logger)
|
25 |
self.pre_memory = []
|
26 |
self.post_memory = []
|
27 |
self.is_first = True
|
|
|
20 |
|
21 |
|
22 |
class EXE(NaiveAct):
|
23 |
+
def __init__(self, openai_key, action_space, args, prompts, distiller, temperature=0., max_tokens=None, logger=None, fixed_suggestion=None, fixed_insight=None):
|
24 |
+
super().__init__(openai_key, action_space, args, prompts, distiller, temperature, max_tokens, logger)
|
25 |
self.pre_memory = []
|
26 |
self.post_memory = []
|
27 |
self.is_first = True
|
deciders/reflexion.py
CHANGED
@@ -19,8 +19,8 @@ from .utils import run_chain
|
|
19 |
|
20 |
|
21 |
class Reflexion(NaiveAct):
|
22 |
-
def __init__(self, action_space, args, prompts, distiller, temperature=0.1, max_tokens=None, logger=None):
|
23 |
-
super().__init__(action_space, args, prompts, distiller, temperature, max_tokens, logger)
|
24 |
|
25 |
def num_tokens_from_string(self,string: str) -> int:
|
26 |
"""Returns the number of tokens in a text string."""
|
|
|
19 |
|
20 |
|
21 |
class Reflexion(NaiveAct):
|
22 |
+
def __init__(self, openai_key, action_space, args, prompts, distiller, temperature=0.1, max_tokens=None, logger=None):
|
23 |
+
super().__init__(openai_key, action_space, args, prompts, distiller, temperature, max_tokens, logger)
|
24 |
|
25 |
def num_tokens_from_string(self,string: str) -> int:
|
26 |
"""Returns the number of tokens in a text string."""
|
deciders/self_consistency.py
CHANGED
@@ -17,9 +17,9 @@ from .utils import run_chain
|
|
17 |
|
18 |
|
19 |
class SelfConsistency(NaiveAct):
|
20 |
-
def __init__(self, action_space, args, prompts, distiller, temperature=0.1, max_tokens=None, logger=None):
|
21 |
temperature = 0.7
|
22 |
-
super().__init__(action_space, args, prompts, distiller, temperature, max_tokens, logger)
|
23 |
self.temperature = temperature
|
24 |
|
25 |
def act(
|
|
|
17 |
|
18 |
|
19 |
class SelfConsistency(NaiveAct):
|
20 |
+
def __init__(self, openai_key, action_space, args, prompts, distiller, temperature=0.1, max_tokens=None, logger=None):
|
21 |
temperature = 0.7
|
22 |
+
super().__init__(openai_key, action_space, args, prompts, distiller, temperature, max_tokens, logger)
|
23 |
self.temperature = temperature
|
24 |
|
25 |
def act(
|
deciders/selfask.py
CHANGED
@@ -17,8 +17,8 @@ from .utils import run_chain
|
|
17 |
|
18 |
|
19 |
class SelfAskAct(NaiveAct):
|
20 |
-
def __init__(self, action_space, args, prompts, distiller, temperature=0.1, max_tokens=None, logger=None):
|
21 |
-
super().__init__(action_space, args, prompts, distiller, temperature, max_tokens,logger)
|
22 |
|
23 |
def act(
|
24 |
self,
|
|
|
17 |
|
18 |
|
19 |
class SelfAskAct(NaiveAct):
|
20 |
+
def __init__(self, openai_key, action_space, args, prompts, distiller, temperature=0.1, max_tokens=None, logger=None):
|
21 |
+
super().__init__(openai_key, action_space, args, prompts, distiller, temperature, max_tokens,logger)
|
22 |
|
23 |
def act(
|
24 |
self,
|
deciders/spp.py
CHANGED
@@ -16,8 +16,8 @@ from .act import NaiveAct
|
|
16 |
from .utils import run_chain
|
17 |
|
18 |
class SPP(NaiveAct):
|
19 |
-
def __init__(self, action_space, args, prompts, distiller, temperature=0.1, max_tokens=None, logger=None):
|
20 |
-
super().__init__(action_space, args, prompts, distiller, temperature, max_tokens, logger)
|
21 |
|
22 |
def act(
|
23 |
self,
|
|
|
16 |
from .utils import run_chain
|
17 |
|
18 |
class SPP(NaiveAct):
|
19 |
+
def __init__(self, openai_key, action_space, args, prompts, distiller, temperature=0.1, max_tokens=None, logger=None):
|
20 |
+
super().__init__(openai_key, action_space, args, prompts, distiller, temperature, max_tokens, logger)
|
21 |
|
22 |
def act(
|
23 |
self,
|
deciders/utils.py
CHANGED
@@ -19,8 +19,30 @@ Model = Literal["gpt-4", "gpt-35-turbo", "text-davinci-003"]
|
|
19 |
# from .gpt import gpt
|
20 |
# gpt().__init__()
|
21 |
|
22 |
-
import timeout_decorator
|
23 |
-
@timeout_decorator.timeout(30)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
def run_chain(chain, *args, **kwargs):
|
25 |
return chain.run(*args, **kwargs)
|
26 |
|
@@ -54,6 +76,7 @@ def get_completion(prompt: str, api_type: str = "azure", engine: str = "gpt-35-t
|
|
54 |
temperature=temperature,
|
55 |
# request_timeout = 1
|
56 |
)
|
|
|
57 |
return response.choices[0]["message"]["content"]
|
58 |
|
59 |
# @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
|
@@ -85,5 +108,4 @@ def get_chat(prompt: str, api_type: str = "azure", model: str = "gpt-35-turbo",
|
|
85 |
temperature=temperature,
|
86 |
# request_timeout = 1
|
87 |
)
|
88 |
-
return response.choices[0]["message"]["content"]
|
89 |
-
|
|
|
19 |
# from .gpt import gpt
|
20 |
# gpt().__init__()
|
21 |
|
22 |
+
# import timeout_decorator
|
23 |
+
# @timeout_decorator.timeout(30)
|
24 |
+
# def run_chain(chain, *args, **kwargs):
|
25 |
+
# return chain.run(*args, **kwargs)
|
26 |
+
import concurrent.futures
|
27 |
+
|
28 |
+
def timeout_decorator(timeout):
    """Build a decorator that bounds how long the caller waits for a function.

    The wrapped function runs on a single worker thread. If it has not
    produced a result within ``timeout`` seconds, the caller gets a
    ``RuntimeError`` immediately.

    Args:
        timeout: Maximum number of seconds to wait for the result.

    Returns:
        A decorator. The decorated function returns whatever the original
        returns and re-raises whatever the original raises.

    Raises:
        RuntimeError: From the decorated function, when the wait exceeds
            ``timeout`` seconds.

    Note:
        A thread cannot be killed from outside, so a timed-out call keeps
        running in the background until it finishes on its own; only the
        caller is released.
    """
    import functools

    def decorator(function):
        @functools.wraps(function)
        def wrapper(*args, **kwargs):
            # Deliberately NOT a `with` block: leaving the context manager
            # calls shutdown(wait=True), which would block until the worker
            # finishes and so defeat the timeout entirely.
            executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
            future = executor.submit(function, *args, **kwargs)
            try:
                return future.result(timeout)
            except concurrent.futures.TimeoutError as exc:
                raise RuntimeError(
                    f"Function '{function.__name__}' timed out after {timeout} seconds"
                ) from exc
            finally:
                # Release the pool without waiting on a still-running worker.
                executor.shutdown(wait=False)
        return wrapper
    return decorator


@timeout_decorator(30)
def run_chain(chain, *args, **kwargs):
    """Call ``chain.run(*args, **kwargs)``, giving up after 30 seconds.

    Raises:
        RuntimeError: If ``chain.run`` has not returned within 30 seconds.
    """
    return chain.run(*args, **kwargs)
|
48 |
|
|
|
76 |
temperature=temperature,
|
77 |
# request_timeout = 1
|
78 |
)
|
79 |
+
import pdb; pdb.set_trace()
|
80 |
return response.choices[0]["message"]["content"]
|
81 |
|
82 |
# @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
|
|
|
108 |
temperature=temperature,
|
109 |
# request_timeout = 1
|
110 |
)
|
111 |
+
return response.choices[0]["message"]["content"]
|
|
envs/__init__.py
CHANGED
@@ -18,24 +18,25 @@ from .atari import mspacman_policies, mspacman_translator
|
|
18 |
from .atari import montezumarevenge_policies, montezumarevenge_translator
|
19 |
register_environments()
|
20 |
|
|
|
21 |
|
22 |
REGISTRY = {}
|
23 |
REGISTRY["sampling_wrapper"] = SettableStateEnv
|
24 |
REGISTRY["base_env"] = BaseEnv
|
25 |
-
REGISTRY["
|
26 |
-
REGISTRY["
|
27 |
REGISTRY["acrobot_init_translator"] = acrobot_translator.GameDescriber
|
28 |
REGISTRY["acrobot_basic_translator"] = acrobot_translator.BasicStateSequenceTranslator
|
29 |
REGISTRY["mountaincar_init_translator"] = mountaincar_translator.GameDescriber
|
30 |
REGISTRY["mountaincar_basic_translator"] = mountaincar_translator.BasicStateSequenceTranslator
|
31 |
|
32 |
-
REGISTRY["
|
33 |
REGISTRY["acrobot_policies"] = [acrobot_policies.dedicated_1_policy, acrobot_policies.dedicated_2_policy, acrobot_policies.dedicated_3_policy, acrobot_policies.pseudo_random_policy, acrobot_policies.real_random_policy]
|
34 |
REGISTRY["mountaincar_policies"] = [mountaincar_policies.dedicated_1_policy, mountaincar_policies.dedicated_2_policy, mountaincar_policies.dedicated_3_policy, mountaincar_policies.pseudo_random_policy, mountaincar_policies.real_random_policy]
|
35 |
|
36 |
-
REGISTRY["
|
37 |
-
REGISTRY["
|
38 |
-
REGISTRY["
|
39 |
|
40 |
REGISTRY["blackjack_init_translator"] = blackjack_translator.GameDescriber
|
41 |
REGISTRY["blackjack_basic_translator"] = blackjack_translator.BasicStateSequenceTranslator
|
@@ -54,9 +55,9 @@ REGISTRY["frozenlake_basic_translator"] = frozenlake_translator.BasicStateSequen
|
|
54 |
REGISTRY["frozenlake_policies"] = [frozenlake_policies.dedicated_1_policy, frozenlake_policies.dedicated_2_policy, frozenlake_policies.dedicated_3_policy, frozenlake_policies.dedicated_4_policy, frozenlake_policies.pseudo_random_policy, frozenlake_policies.real_random_policy]
|
55 |
|
56 |
|
57 |
-
REGISTRY["
|
58 |
-
REGISTRY["
|
59 |
-
REGISTRY["
|
60 |
|
61 |
|
62 |
REGISTRY["RepresentedBoxing_init_translator"] = Boxing_translator.GameDescriber
|
@@ -138,47 +139,6 @@ REGISTRY["RepresentedMontezumaRevenge_basic_policies"] = [
|
|
138 |
montezumarevenge_policies.dedicated_18_policy,
|
139 |
]
|
140 |
|
141 |
-
REGISTRY["RepresentedMsPacman_init_translator"] = mspacman_translator.GameDescriber
|
142 |
-
REGISTRY["RepresentedMsPacman_basic_translator"] = mspacman_translator.BasicStateSequenceTranslator
|
143 |
-
REGISTRY["RepresentedMsPacman_basic_policies"] = [
|
144 |
-
mspacman_policies.real_random_policy,
|
145 |
-
mspacman_policies.pseudo_random_policy,
|
146 |
-
mspacman_policies.dedicated_1_policy,
|
147 |
-
mspacman_policies.dedicated_2_policy,
|
148 |
-
mspacman_policies.dedicated_3_policy,
|
149 |
-
mspacman_policies.dedicated_4_policy,
|
150 |
-
mspacman_policies.dedicated_5_policy,
|
151 |
-
mspacman_policies.dedicated_6_policy,
|
152 |
-
mspacman_policies.dedicated_7_policy,
|
153 |
-
mspacman_policies.dedicated_8_policy,
|
154 |
-
mspacman_policies.dedicated_9_policy,
|
155 |
-
]
|
156 |
-
|
157 |
-
REGISTRY["RepresentedMontezumaRevenge_init_translator"] = montezumarevenge_translator.GameDescriber
|
158 |
-
REGISTRY["RepresentedMontezumaRevenge_basic_translator"] = montezumarevenge_translator.BasicStateSequenceTranslator
|
159 |
-
REGISTRY["RepresentedMontezumaRevenge_basic_policies"] = [
|
160 |
-
montezumarevenge_policies.real_random_policy,
|
161 |
-
montezumarevenge_policies.pseudo_random_policy,
|
162 |
-
montezumarevenge_policies.dedicated_1_policy,
|
163 |
-
montezumarevenge_policies.dedicated_2_policy,
|
164 |
-
montezumarevenge_policies.dedicated_3_policy,
|
165 |
-
montezumarevenge_policies.dedicated_4_policy,
|
166 |
-
montezumarevenge_policies.dedicated_5_policy,
|
167 |
-
montezumarevenge_policies.dedicated_6_policy,
|
168 |
-
montezumarevenge_policies.dedicated_7_policy,
|
169 |
-
montezumarevenge_policies.dedicated_8_policy,
|
170 |
-
montezumarevenge_policies.dedicated_9_policy,
|
171 |
-
montezumarevenge_policies.dedicated_10_policy,
|
172 |
-
montezumarevenge_policies.dedicated_11_policy,
|
173 |
-
montezumarevenge_policies.dedicated_12_policy,
|
174 |
-
montezumarevenge_policies.dedicated_13_policy,
|
175 |
-
montezumarevenge_policies.dedicated_14_policy,
|
176 |
-
montezumarevenge_policies.dedicated_15_policy,
|
177 |
-
montezumarevenge_policies.dedicated_16_policy,
|
178 |
-
montezumarevenge_policies.dedicated_17_policy,
|
179 |
-
montezumarevenge_policies.dedicated_18_policy,
|
180 |
-
]
|
181 |
-
|
182 |
## For mujoco env
|
183 |
|
184 |
|
@@ -196,12 +156,12 @@ from .mujoco import walker2d_translator, walker2d_policies
|
|
196 |
|
197 |
|
198 |
|
199 |
-
REGISTRY["
|
200 |
-
REGISTRY["
|
201 |
-
REGISTRY["
|
202 |
-
REGISTRY["
|
203 |
-
REGISTRY["
|
204 |
-
REGISTRY["
|
205 |
|
206 |
|
207 |
REGISTRY["swimmer_init_translator"] = swimmer_translator.GameDescriber
|
|
|
18 |
from .atari import montezumarevenge_policies, montezumarevenge_translator
|
19 |
register_environments()
|
20 |
|
21 |
+
from .mujoco import ant_translator, ant_policies
|
22 |
|
23 |
REGISTRY = {}
|
24 |
REGISTRY["sampling_wrapper"] = SettableStateEnv
|
25 |
REGISTRY["base_env"] = BaseEnv
|
26 |
+
REGISTRY["cartpole_init_translator"] = cartpole_translator.GameDescriber
|
27 |
+
REGISTRY["cartpole_basic_translator"] = cartpole_translator.BasicStateSequenceTranslator
|
28 |
REGISTRY["acrobot_init_translator"] = acrobot_translator.GameDescriber
|
29 |
REGISTRY["acrobot_basic_translator"] = acrobot_translator.BasicStateSequenceTranslator
|
30 |
REGISTRY["mountaincar_init_translator"] = mountaincar_translator.GameDescriber
|
31 |
REGISTRY["mountaincar_basic_translator"] = mountaincar_translator.BasicStateSequenceTranslator
|
32 |
|
33 |
+
REGISTRY["cartpole_policies"] = [cartpole_policies.dedicated_1_policy, cartpole_policies.dedicated_2_policy, cartpole_policies.pseudo_random_policy, cartpole_policies.real_random_policy]
|
34 |
REGISTRY["acrobot_policies"] = [acrobot_policies.dedicated_1_policy, acrobot_policies.dedicated_2_policy, acrobot_policies.dedicated_3_policy, acrobot_policies.pseudo_random_policy, acrobot_policies.real_random_policy]
|
35 |
REGISTRY["mountaincar_policies"] = [mountaincar_policies.dedicated_1_policy, mountaincar_policies.dedicated_2_policy, mountaincar_policies.dedicated_3_policy, mountaincar_policies.pseudo_random_policy, mountaincar_policies.real_random_policy]
|
36 |
|
37 |
+
REGISTRY["lunarlander_init_translator"] = LunarLander_translator.GameDescriber
|
38 |
+
REGISTRY["lunarlander_basic_translator"] = LunarLander_translator.BasicStateSequenceTranslator
|
39 |
+
REGISTRY["lunarlander_policies"] = [LunarLander_policies.dedicated_1_policy, LunarLander_policies.dedicated_2_policy, LunarLander_policies.dedicated_3_policy,LunarLander_policies.dedicated_4_policy, LunarLander_policies.pseudo_random_policy, LunarLander_policies.real_random_policy]
|
40 |
|
41 |
REGISTRY["blackjack_init_translator"] = blackjack_translator.GameDescriber
|
42 |
REGISTRY["blackjack_basic_translator"] = blackjack_translator.BasicStateSequenceTranslator
|
|
|
55 |
REGISTRY["frozenlake_policies"] = [frozenlake_policies.dedicated_1_policy, frozenlake_policies.dedicated_2_policy, frozenlake_policies.dedicated_3_policy, frozenlake_policies.dedicated_4_policy, frozenlake_policies.pseudo_random_policy, frozenlake_policies.real_random_policy]
|
56 |
|
57 |
|
58 |
+
REGISTRY["mountaincarcontinuous_init_translator"] = mountaincarContinuous_translator.GameDescriber
|
59 |
+
REGISTRY["mountaincarcontinuous_basic_translator"] = mountaincarContinuous_translator.BasicStateSequenceTranslator
|
60 |
+
REGISTRY["mountaincarcontinuous_policies"] = [mountaincarContinuous_policies.pseudo_random_policy, mountaincarContinuous_policies.real_random_policy]
|
61 |
|
62 |
|
63 |
REGISTRY["RepresentedBoxing_init_translator"] = Boxing_translator.GameDescriber
|
|
|
139 |
montezumarevenge_policies.dedicated_18_policy,
|
140 |
]
|
141 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
142 |
## For mujoco env
|
143 |
|
144 |
|
|
|
156 |
|
157 |
|
158 |
|
159 |
+
REGISTRY["invertedpendulum_init_translator"] = invertedPendulum_translator.GameDescriber
|
160 |
+
REGISTRY["invertedpendulum_basic_translator"] = invertedPendulum_translator.BasicStateSequenceTranslator
|
161 |
+
REGISTRY["invertedpendulum_policies"] = [invertedPendulum_policies.pseudo_random_policy, invertedPendulum_policies.real_random_policy]
|
162 |
+
REGISTRY["inverteddoublependulum_init_translator"] = invertedDoublePendulum_translator.GameDescriber
|
163 |
+
REGISTRY["inverteddoublependulum_basic_translator"] = invertedDoublePendulum_translator.BasicStateSequenceTranslator
|
164 |
+
REGISTRY["inverteddoublependulum_policies"] = [invertedDoublePendulum_policies.pseudo_random_policy, invertedDoublePendulum_policies.real_random_policy]
|
165 |
|
166 |
|
167 |
REGISTRY["swimmer_init_translator"] = swimmer_translator.GameDescriber
|
envs/mujoco/invertedDoublePendulum_translator.py
CHANGED
@@ -7,16 +7,9 @@ class BasicLevelTranslator:
|
|
7 |
def translate(self, state):
|
8 |
res = (
|
9 |
f"Position of the cart: {state[0]:.2f} m\n"
|
10 |
-
f"
|
11 |
-
f"
|
12 |
-
f"
|
13 |
-
f"Cosine of the angle between two poles: {state[4]:.2f}\n"
|
14 |
-
f"Velocity of the cart: {state[5]:.2f} m/s\n"
|
15 |
-
f"Angular velocity of angle between cart and first pole: {state[6]:.2f} rad/s\n"
|
16 |
-
f"Angular velocity of angle between two poles: {state[7]:.2f} rad/s\n"
|
17 |
-
f"Constraint Force 1: {state[8]:.2f} N\n"
|
18 |
-
f"Constraint Force 2: {state[9]:.2f} N\n"
|
19 |
-
f"Constraint Force 3: {state[10]:.2f} N"
|
20 |
)
|
21 |
return res
|
22 |
|
@@ -25,7 +18,7 @@ class GameDescriber:
|
|
25 |
self.is_only_local_obs = args.is_only_local_obs == 1
|
26 |
self.max_episode_len = args.max_episode_len
|
27 |
self.action_desc_dict = {
|
28 |
-
0: "Apply a force in the range [-
|
29 |
}
|
30 |
self.reward_desc_dict = {}
|
31 |
|
@@ -37,22 +30,24 @@ class GameDescriber:
|
|
37 |
|
38 |
def describe_goal(self):
|
39 |
return (
|
40 |
-
"The goal in the
|
41 |
-
"
|
42 |
)
|
43 |
|
44 |
def describe_game(self):
|
45 |
return (
|
46 |
-
"In the
|
47 |
-
"Your objective is to balance the
|
48 |
-
"to the cart
|
49 |
-
"
|
|
|
50 |
)
|
51 |
|
52 |
def describe_action(self):
|
53 |
return (
|
54 |
-
"Your next move: \n Please provide a numerical value
|
55 |
-
"
|
|
|
56 |
)
|
57 |
|
58 |
class BasicStateSequenceTranslator(BasicLevelTranslator):
|
|
|
7 |
def translate(self, state):
|
8 |
res = (
|
9 |
f"Position of the cart: {state[0]:.2f} m\n"
|
10 |
+
f"Vertical angle of the pole: {state[1]:.2f} rad\n"
|
11 |
+
f"Linear velocity of the cart: {state[2]:.2f} m/s\n"
|
12 |
+
f"Angular velocity of the pole: {state[3]:.2f} rad/s"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
)
|
14 |
return res
|
15 |
|
|
|
18 |
self.is_only_local_obs = args.is_only_local_obs == 1
|
19 |
self.max_episode_len = args.max_episode_len
|
20 |
self.action_desc_dict = {
|
21 |
+
0: "Apply a force in the range [-1, 1] to the cart to control its motion.",
|
22 |
}
|
23 |
self.reward_desc_dict = {}
|
24 |
|
|
|
30 |
|
31 |
def describe_goal(self):
|
32 |
return (
|
33 |
+
"The goal in the Inverted Pendulum environment is to balance the pole on top of the cart "\
|
34 |
+
"by applying continuous forces to the cart, keeping it upright."
|
35 |
)
|
36 |
|
37 |
def describe_game(self):
|
38 |
return (
|
39 |
+
"In the Inverted Pendulum environment, you control a cart that can move linearly with a pole "\
|
40 |
+
"attached to it. Your objective is to balance the pole on top of the cart by applying forces "\
|
41 |
+
"to the cart in a way that keeps the pole upright. "\
|
42 |
+
"The environment provides observations of the cart's position, pole angle, velocities, "\
|
43 |
+
"and angular velocities. The goal is to maintain balance as long as possible."
|
44 |
)
|
45 |
|
46 |
def describe_action(self):
|
47 |
return (
|
48 |
+
"Your next move: \n Please provide a numerical value for the force to be applied to the cart. "\
|
49 |
+
"This value should be within the range of [-3, 3], where a positive value indicates applying force "\
|
50 |
+
"in the right direction, and a negative value indicates applying force in the left direction."
|
51 |
)
|
52 |
|
53 |
class BasicStateSequenceTranslator(BasicLevelTranslator):
|
envs/mujoco/invertedPendulum_translator.py
CHANGED
@@ -7,9 +7,16 @@ class BasicLevelTranslator:
|
|
7 |
def translate(self, state):
|
8 |
res = (
|
9 |
f"Position of the cart: {state[0]:.2f} m\n"
|
10 |
-
f"
|
11 |
-
f"
|
12 |
-
f"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
)
|
14 |
return res
|
15 |
|
@@ -18,7 +25,7 @@ class GameDescriber:
|
|
18 |
self.is_only_local_obs = args.is_only_local_obs == 1
|
19 |
self.max_episode_len = args.max_episode_len
|
20 |
self.action_desc_dict = {
|
21 |
-
0: "Apply a force in the range [-
|
22 |
}
|
23 |
self.reward_desc_dict = {}
|
24 |
|
@@ -30,24 +37,22 @@ class GameDescriber:
|
|
30 |
|
31 |
def describe_goal(self):
|
32 |
return (
|
33 |
-
"The goal in the
|
34 |
-
"by applying continuous forces
|
35 |
)
|
36 |
|
37 |
def describe_game(self):
|
38 |
return (
|
39 |
-
"In the
|
40 |
-
"
|
41 |
-
"to the cart
|
42 |
-
"
|
43 |
-
"and angular velocities. The goal is to maintain balance as long as possible."
|
44 |
)
|
45 |
|
46 |
def describe_action(self):
|
47 |
return (
|
48 |
-
"Your next move: \n Please provide a numerical value
|
49 |
-
"
|
50 |
-
"in the right direction, and a negative value indicates applying force in the left direction."
|
51 |
)
|
52 |
|
53 |
class BasicStateSequenceTranslator(BasicLevelTranslator):
|
|
|
7 |
def translate(self, state):
|
8 |
res = (
|
9 |
f"Position of the cart: {state[0]:.2f} m\n"
|
10 |
+
f"Sine of the angle between cart and first pole: {state[1]:.2f}\n"
|
11 |
+
f"Sine of the angle between two poles: {state[2]:.2f}\n"
|
12 |
+
f"Cosine of the angle between cart and first pole: {state[3]:.2f}\n"
|
13 |
+
f"Cosine of the angle between two poles: {state[4]:.2f}\n"
|
14 |
+
f"Velocity of the cart: {state[5]:.2f} m/s\n"
|
15 |
+
f"Angular velocity of angle between cart and first pole: {state[6]:.2f} rad/s\n"
|
16 |
+
f"Angular velocity of angle between two poles: {state[7]:.2f} rad/s\n"
|
17 |
+
f"Constraint Force 1: {state[8]:.2f} N\n"
|
18 |
+
f"Constraint Force 2: {state[9]:.2f} N\n"
|
19 |
+
f"Constraint Force 3: {state[10]:.2f} N"
|
20 |
)
|
21 |
return res
|
22 |
|
|
|
25 |
self.is_only_local_obs = args.is_only_local_obs == 1
|
26 |
self.max_episode_len = args.max_episode_len
|
27 |
self.action_desc_dict = {
|
28 |
+
0: "Apply a force in the range [-3, 3] to the cart to control its motion.",
|
29 |
}
|
30 |
self.reward_desc_dict = {}
|
31 |
|
|
|
37 |
|
38 |
def describe_goal(self):
|
39 |
return (
|
40 |
+
"The goal in the InvertedDoublePendulum environment is to balance the two poles "\
|
41 |
+
"on top of the cart by applying continuous forces on the cart."
|
42 |
)
|
43 |
|
44 |
def describe_game(self):
|
45 |
return (
|
46 |
+
"In the InvertedDoublePendulum environment, you control a system with a cart and two poles. "\
|
47 |
+
"Your objective is to balance the two poles on top of the cart by applying continuous forces "\
|
48 |
+
"to the cart. The environment provides observations of the cart's position, angles of the poles, "\
|
49 |
+
"and their angular velocities. The episode ends when certain termination conditions are met."
|
|
|
50 |
)
|
51 |
|
52 |
def describe_action(self):
|
53 |
return (
|
54 |
+
"Your next move: \n Please provide a numerical value within the range of [-3,3], "\
|
55 |
+
"representing the force to be applied to the cart."
|
|
|
56 |
)
|
57 |
|
58 |
class BasicStateSequenceTranslator(BasicLevelTranslator):
|
record_reflexion.csv
CHANGED
@@ -19,5 +19,4 @@ Walker2d-v4,1,expert,5000.0
|
|
19 |
Swimmer-v4,1,expert,44.4
|
20 |
Reacher-v4,1,expert,-2.6
|
21 |
Pusher-v4,1,expert,-52.3
|
22 |
-
|
23 |
-
InvertedDoublePendulum-v4,1,expert,9359.5
|
|
|
19 |
Swimmer-v4,1,expert,44.4
|
20 |
Reacher-v4,1,expert,-2.6
|
21 |
Pusher-v4,1,expert,-52.3
|
22 |
+
|
|
shell/test_reflexion.sh
CHANGED
@@ -1,43 +1,43 @@
|
|
1 |
|
2 |
# CartPole-v0
|
3 |
# Naive Actor
|
4 |
-
python main_reflexion.py --env_name CartPole-v0 --init_summarizer
|
5 |
-
python main_reflexion.py --env_name CartPole-v0 --init_summarizer
|
6 |
-
python main_reflexion.py --env_name CartPole-v0 --init_summarizer
|
7 |
-
python main_reflexion.py --env_name CartPole-v0 --init_summarizer
|
8 |
|
9 |
# COT
|
10 |
-
python main_reflexion.py --env_name CartPole-v0 --init_summarizer
|
11 |
-
python main_reflexion.py --env_name CartPole-v0 --init_summarizer
|
12 |
-
python main_reflexion.py --env_name CartPole-v0 --init_summarizer
|
13 |
-
python main_reflexion.py --env_name CartPole-v0 --init_summarizer
|
14 |
|
15 |
# self consistency
|
16 |
-
python main_reflexion.py --env_name CartPole-v0 --init_summarizer
|
17 |
-
python main_reflexion.py --env_name CartPole-v0 --init_summarizer
|
18 |
-
python main_reflexion.py --env_name CartPole-v0 --init_summarizer
|
19 |
-
python main_reflexion.py --env_name CartPole-v0 --init_summarizer
|
20 |
|
21 |
# self-ask
|
22 |
-
python main_reflexion.py --env_name CartPole-v0 --init_summarizer
|
23 |
-
python main_reflexion.py --env_name CartPole-v0 --init_summarizer
|
24 |
-
python main_reflexion.py --env_name CartPole-v0 --init_summarizer
|
25 |
-
python main_reflexion.py --env_name CartPole-v0 --init_summarizer
|
26 |
|
27 |
# SPP
|
28 |
-
python main_reflexion.py --env_name CartPole-v0 --init_summarizer
|
29 |
-
python main_reflexion.py --env_name CartPole-v0 --init_summarizer
|
30 |
-
python main_reflexion.py --env_name CartPole-v0 --init_summarizer
|
31 |
-
python main_reflexion.py --env_name CartPole-v0 --init_summarizer
|
32 |
|
33 |
# REFLEXION
|
34 |
-
python main_reflexion.py --env_name CartPole-v0 --init_summarizer
|
35 |
-
python main_reflexion.py --env_name CartPole-v0 --init_summarizer
|
36 |
-
python main_reflexion.py --env_name CartPole-v0 --init_summarizer
|
37 |
-
python main_reflexion.py --env_name CartPole-v0 --init_summarizer
|
38 |
|
39 |
# exe
|
40 |
-
python main_reflexion.py --env_name CartPole-v0 --init_summarizer
|
41 |
-
python main_reflexion.py --env_name CartPole-v0 --init_summarizer
|
42 |
-
python main_reflexion.py --env_name CartPole-v0 --init_summarizer
|
43 |
-
python main_reflexion.py --env_name CartPole-v0 --init_summarizer
|
|
|
1 |
|
2 |
# CartPole-v0
# Runs main_reflexion.py on CartPole-v0 once per decider and prompt level.
# Every command uses the long-form --num_trails flag; the single-dash form
# (-num_trails) does not match the flag spelling used by the other commands.
# Naive Actor
python main_reflexion.py --env_name CartPole-v0 --init_summarizer cartpole_init_translator --curr_summarizer cartpole_basic_translator --decider naive_actor --prompt_level 1 --num_trails 1
python main_reflexion.py --env_name CartPole-v0 --init_summarizer cartpole_init_translator --curr_summarizer cartpole_basic_translator --decider naive_actor --prompt_level 3 --num_trails 2 --distiller traj_distiller
python main_reflexion.py --env_name CartPole-v0 --init_summarizer cartpole_init_translator --curr_summarizer cartpole_basic_translator --decider naive_actor --prompt_level 4 --num_trails 1 --distiller traj_distiller --prompt_path "envs/classic_control/few_shot_examples/cartpole"
python main_reflexion.py --env_name CartPole-v0 --init_summarizer cartpole_init_translator --curr_summarizer cartpole_basic_translator --decider naive_actor --prompt_level 5 --num_trails 1

# COT
python main_reflexion.py --env_name CartPole-v0 --init_summarizer cartpole_init_translator --curr_summarizer cartpole_basic_translator --decider cot_actor --prompt_level 1 --num_trails 1
python main_reflexion.py --env_name CartPole-v0 --init_summarizer cartpole_init_translator --curr_summarizer cartpole_basic_translator --decider cot_actor --prompt_level 3 --num_trails 2 --distiller traj_distiller
python main_reflexion.py --env_name CartPole-v0 --init_summarizer cartpole_init_translator --curr_summarizer cartpole_basic_translator --decider cot_actor --prompt_level 4 --num_trails 1 --distiller traj_distiller --prompt_path "envs/classic_control/few_shot_examples/cartpole"
python main_reflexion.py --env_name CartPole-v0 --init_summarizer cartpole_init_translator --curr_summarizer cartpole_basic_translator --decider cot_actor --prompt_level 5 --num_trails 1

# self consistency
python main_reflexion.py --env_name CartPole-v0 --init_summarizer cartpole_init_translator --curr_summarizer cartpole_basic_translator --decider self_consistency_actor --prompt_level 1 --num_trails 1
python main_reflexion.py --env_name CartPole-v0 --init_summarizer cartpole_init_translator --curr_summarizer cartpole_basic_translator --decider self_consistency_actor --prompt_level 3 --num_trails 2 --distiller traj_distiller
python main_reflexion.py --env_name CartPole-v0 --init_summarizer cartpole_init_translator --curr_summarizer cartpole_basic_translator --decider self_consistency_actor --prompt_level 4 --num_trails 1 --distiller traj_distiller --prompt_path "envs/classic_control/few_shot_examples/cartpole"
python main_reflexion.py --env_name CartPole-v0 --init_summarizer cartpole_init_translator --curr_summarizer cartpole_basic_translator --decider self_consistency_actor --prompt_level 5 --num_trails 1

# self-ask
python main_reflexion.py --env_name CartPole-v0 --init_summarizer cartpole_init_translator --curr_summarizer cartpole_basic_translator --decider selfask_actor --prompt_level 1 --num_trails 1
python main_reflexion.py --env_name CartPole-v0 --init_summarizer cartpole_init_translator --curr_summarizer cartpole_basic_translator --decider selfask_actor --prompt_level 3 --num_trails 2 --distiller traj_distiller
python main_reflexion.py --env_name CartPole-v0 --init_summarizer cartpole_init_translator --curr_summarizer cartpole_basic_translator --decider selfask_actor --prompt_level 4 --num_trails 1 --distiller traj_distiller --prompt_path "envs/classic_control/few_shot_examples/cartpole"
python main_reflexion.py --env_name CartPole-v0 --init_summarizer cartpole_init_translator --curr_summarizer cartpole_basic_translator --decider selfask_actor --prompt_level 5 --num_trails 1

# SPP
python main_reflexion.py --env_name CartPole-v0 --init_summarizer cartpole_init_translator --curr_summarizer cartpole_basic_translator --decider spp_actor --prompt_level 1 --num_trails 1
python main_reflexion.py --env_name CartPole-v0 --init_summarizer cartpole_init_translator --curr_summarizer cartpole_basic_translator --decider spp_actor --prompt_level 3 --num_trails 2 --distiller traj_distiller
python main_reflexion.py --env_name CartPole-v0 --init_summarizer cartpole_init_translator --curr_summarizer cartpole_basic_translator --decider spp_actor --prompt_level 4 --num_trails 1 --distiller traj_distiller --prompt_path "envs/classic_control/few_shot_examples/cartpole"
python main_reflexion.py --env_name CartPole-v0 --init_summarizer cartpole_init_translator --curr_summarizer cartpole_basic_translator --decider spp_actor --prompt_level 5 --num_trails 1

# REFLEXION
python main_reflexion.py --env_name CartPole-v0 --init_summarizer cartpole_init_translator --curr_summarizer cartpole_basic_translator --decider reflexion_actor --prompt_level 1 --num_trails 1 --distiller reflect_distiller
python main_reflexion.py --env_name CartPole-v0 --init_summarizer cartpole_init_translator --curr_summarizer cartpole_basic_translator --decider reflexion_actor --prompt_level 3 --num_trails 2 --distiller reflect_distiller
python main_reflexion.py --env_name CartPole-v0 --init_summarizer cartpole_init_translator --curr_summarizer cartpole_basic_translator --decider reflexion_actor --prompt_level 4 --num_trails 1 --distiller reflect_distiller --prompt_path "envs/classic_control/few_shot_examples/cartpole"
python main_reflexion.py --env_name CartPole-v0 --init_summarizer cartpole_init_translator --curr_summarizer cartpole_basic_translator --decider reflexion_actor --prompt_level 5 --num_trails 1 --distiller reflect_distiller

# exe
python main_reflexion.py --env_name CartPole-v0 --init_summarizer cartpole_init_translator --curr_summarizer cartpole_basic_translator --decider exe_actor --prompt_level 1 --num_trails 1 --distiller guide_generator
python main_reflexion.py --env_name CartPole-v0 --init_summarizer cartpole_init_translator --curr_summarizer cartpole_basic_translator --decider exe_actor --prompt_level 3 --num_trails 2 --distiller guide_generator
python main_reflexion.py --env_name CartPole-v0 --init_summarizer cartpole_init_translator --curr_summarizer cartpole_basic_translator --decider exe_actor --prompt_level 4 --num_trails 1 --distiller guide_generator --prompt_path "envs/classic_control/few_shot_examples/cartpole"
python main_reflexion.py --env_name CartPole-v0 --init_summarizer cartpole_init_translator --curr_summarizer cartpole_basic_translator --decider exe_actor --prompt_level 5 --num_trails 1 --distiller guide_generator
|