Spaces:
Runtime error
Runtime error
Synced repo using 'sync_with_huggingface' Github Action
Browse files- .gitattributes +9 -0
- RL_based/checkpoints/Acrobot-v1/expert/policy.pth +3 -0
- RL_based/checkpoints/Blackjack-v1/expert/policy.pth +3 -0
- RL_based/checkpoints/CartPole-v0/expert/policy.pth +3 -0
- RL_based/checkpoints/CartPole-v1/expert/policy.pth +3 -0
- RL_based/checkpoints/CliffWalking-v0/expert/policy.pth +3 -0
- RL_based/checkpoints/FrozenLake-v1/expert/policy.pth +3 -0
- RL_based/checkpoints/LunarLander-v2/expert/policy.pth +3 -0
- RL_based/checkpoints/MountainCar-v0/expert/policy.pth +3 -0
- RL_based/checkpoints/Taxi-v3/expert/policy.pth +3 -0
- deciders/act.py +2 -2
- deciders/cot.py +2 -2
- deciders/exe.py +2 -2
- deciders/reflexion.py +2 -2
- deciders/self_consistency.py +2 -2
- deciders/selfask.py +2 -2
- deciders/spp.py +2 -2
- deciders/utils.py +4 -26
- envs/__init__.py +56 -16
- envs/mujoco/invertedDoublePendulum_translator.py +19 -14
- envs/mujoco/invertedPendulum_translator.py +14 -19
- record_reflexion.csv +2 -1
- shell/test_reflexion.sh +28 -28
.gitattributes
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
RL_based/checkpoints/Blackjack-v1/expert/policy.pth filter=lfs diff=lfs merge=lfs -text
|
2 |
+
RL_based/checkpoints/FrozenLake-v1/expert/policy.pth filter=lfs diff=lfs merge=lfs -text
|
3 |
+
RL_based/checkpoints/CliffWalking-v0/expert/policy.pth filter=lfs diff=lfs merge=lfs -text
|
4 |
+
RL_based/checkpoints/Taxi-v3/expert/policy.pth filter=lfs diff=lfs merge=lfs -text
|
5 |
+
RL_based/checkpoints/Acrobot-v1/expert/policy.pth filter=lfs diff=lfs merge=lfs -text
|
6 |
+
RL_based/checkpoints/CartPole-v1/expert/policy.pth filter=lfs diff=lfs merge=lfs -text
|
7 |
+
RL_based/checkpoints/LunarLander-v2/expert/policy.pth filter=lfs diff=lfs merge=lfs -text
|
8 |
+
RL_based/checkpoints/CartPole-v0/expert/policy.pth filter=lfs diff=lfs merge=lfs -text
|
9 |
+
RL_based/checkpoints/MountainCar-v0/expert/policy.pth filter=lfs diff=lfs merge=lfs -text
|
RL_based/checkpoints/Acrobot-v1/expert/policy.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:09224e6c813ca9d51f45f0ee502d657c51e9356a10c970ff2169e04fcb7ceed0
|
3 |
+
size 16563471
|
RL_based/checkpoints/Blackjack-v1/expert/policy.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:beaedfb2de808aca82b241f2eac35641676ccdd2017be6e7deda0a97d5323cd5
|
3 |
+
size 16562959
|
RL_based/checkpoints/CartPole-v0/expert/policy.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:023a2c35ae9d0d26a6649bcb00f22ad98951cf37d5dcb76a65924f51ff27258a
|
3 |
+
size 16562959
|
RL_based/checkpoints/CartPole-v1/expert/policy.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:64a8c76d798320817b7e4a501f987efe36f874a4a248a189cd8c82c703d0a828
|
3 |
+
size 16562959
|
RL_based/checkpoints/CliffWalking-v0/expert/policy.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e2981b5bf16236a36c0acb8ebeb853b3877ac52eb0ac1340830d27fb87b8e2a1
|
3 |
+
size 16563983
|
RL_based/checkpoints/FrozenLake-v1/expert/policy.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a663cc382fbf6cd75318ea26f5d8b099e2818e3d4f3c9fd55e5690ace6f38821
|
3 |
+
size 16563983
|
RL_based/checkpoints/LunarLander-v2/expert/policy.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a62aaaea80ca273217fdd9c41eb4c3c46cdb06c843a5c2dddf6cbe7c77a9d1ff
|
3 |
+
size 16563983
|
RL_based/checkpoints/MountainCar-v0/expert/policy.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:01399cc32dabd7a945259188ba9274b6bc25233263764f8779c0ea987a514c77
|
3 |
+
size 16563471
|
RL_based/checkpoints/Taxi-v3/expert/policy.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7c1fb509e64f91f040de7dd2089d108f3871d9cfe7e4d4f695c01780d8eb296c
|
3 |
+
size 16565007
|
deciders/act.py
CHANGED
@@ -26,7 +26,7 @@ class RandomAct():
|
|
26 |
return action, '', '', '', 0, 0
|
27 |
|
28 |
class NaiveAct(gpt):
|
29 |
-
def __init__(self,
|
30 |
self.action_space = action_space
|
31 |
self.temperature = temperature
|
32 |
self.action_desc_dict = args.action_desc_dict
|
@@ -39,7 +39,7 @@ class NaiveAct(gpt):
|
|
39 |
else:
|
40 |
model = args.gpt_version
|
41 |
self.encoding = tiktoken.encoding_for_model(model)
|
42 |
-
super().__init__(args
|
43 |
self.distiller = distiller
|
44 |
self.fewshot_example_initialization(args.prompt_level, args.prompt_path, distiller = self.distiller)
|
45 |
if isinstance(self.action_space, Discrete):
|
|
|
26 |
return action, '', '', '', 0, 0
|
27 |
|
28 |
class NaiveAct(gpt):
|
29 |
+
def __init__(self, action_space, args, prompts, distiller, temperature=0.0, max_tokens=2048, logger=None):
|
30 |
self.action_space = action_space
|
31 |
self.temperature = temperature
|
32 |
self.action_desc_dict = args.action_desc_dict
|
|
|
39 |
else:
|
40 |
model = args.gpt_version
|
41 |
self.encoding = tiktoken.encoding_for_model(model)
|
42 |
+
super().__init__(args)
|
43 |
self.distiller = distiller
|
44 |
self.fewshot_example_initialization(args.prompt_level, args.prompt_path, distiller = self.distiller)
|
45 |
if isinstance(self.action_space, Discrete):
|
deciders/cot.py
CHANGED
@@ -17,8 +17,8 @@ from .utils import run_chain
|
|
17 |
|
18 |
|
19 |
class ChainOfThought(NaiveAct):
|
20 |
-
def __init__(self,
|
21 |
-
super().__init__(
|
22 |
|
23 |
def act(
|
24 |
self,
|
|
|
17 |
|
18 |
|
19 |
class ChainOfThought(NaiveAct):
|
20 |
+
def __init__(self, action_space, args, prompts, distiller, temperature=0.1, max_tokens=None, logger=None):
|
21 |
+
super().__init__(action_space, args, prompts, distiller, temperature, max_tokens,logger)
|
22 |
|
23 |
def act(
|
24 |
self,
|
deciders/exe.py
CHANGED
@@ -20,8 +20,8 @@ from loguru import logger
|
|
20 |
|
21 |
|
22 |
class EXE(NaiveAct):
|
23 |
-
def __init__(self,
|
24 |
-
super().__init__(
|
25 |
self.pre_memory = []
|
26 |
self.post_memory = []
|
27 |
self.is_first = True
|
|
|
20 |
|
21 |
|
22 |
class EXE(NaiveAct):
|
23 |
+
def __init__(self, action_space, args, prompts, distiller, temperature=0., max_tokens=None, logger=None, fixed_suggestion=None, fixed_insight=None):
|
24 |
+
super().__init__(action_space, args, prompts, distiller, temperature, max_tokens, logger)
|
25 |
self.pre_memory = []
|
26 |
self.post_memory = []
|
27 |
self.is_first = True
|
deciders/reflexion.py
CHANGED
@@ -19,8 +19,8 @@ from .utils import run_chain
|
|
19 |
|
20 |
|
21 |
class Reflexion(NaiveAct):
|
22 |
-
def __init__(self,
|
23 |
-
super().__init__(
|
24 |
|
25 |
def num_tokens_from_string(self,string: str) -> int:
|
26 |
"""Returns the number of tokens in a text string."""
|
|
|
19 |
|
20 |
|
21 |
class Reflexion(NaiveAct):
|
22 |
+
def __init__(self, action_space, args, prompts, distiller, temperature=0.1, max_tokens=None, logger=None):
|
23 |
+
super().__init__(action_space, args, prompts, distiller, temperature, max_tokens, logger)
|
24 |
|
25 |
def num_tokens_from_string(self,string: str) -> int:
|
26 |
"""Returns the number of tokens in a text string."""
|
deciders/self_consistency.py
CHANGED
@@ -17,9 +17,9 @@ from .utils import run_chain
|
|
17 |
|
18 |
|
19 |
class SelfConsistency(NaiveAct):
|
20 |
-
def __init__(self,
|
21 |
temperature = 0.7
|
22 |
-
super().__init__(
|
23 |
self.temperature = temperature
|
24 |
|
25 |
def act(
|
|
|
17 |
|
18 |
|
19 |
class SelfConsistency(NaiveAct):
|
20 |
+
def __init__(self, action_space, args, prompts, distiller, temperature=0.1, max_tokens=None, logger=None):
|
21 |
temperature = 0.7
|
22 |
+
super().__init__(action_space, args, prompts, distiller, temperature, max_tokens, logger)
|
23 |
self.temperature = temperature
|
24 |
|
25 |
def act(
|
deciders/selfask.py
CHANGED
@@ -17,8 +17,8 @@ from .utils import run_chain
|
|
17 |
|
18 |
|
19 |
class SelfAskAct(NaiveAct):
|
20 |
-
def __init__(self,
|
21 |
-
super().__init__(
|
22 |
|
23 |
def act(
|
24 |
self,
|
|
|
17 |
|
18 |
|
19 |
class SelfAskAct(NaiveAct):
|
20 |
+
def __init__(self, action_space, args, prompts, distiller, temperature=0.1, max_tokens=None, logger=None):
|
21 |
+
super().__init__(action_space, args, prompts, distiller, temperature, max_tokens,logger)
|
22 |
|
23 |
def act(
|
24 |
self,
|
deciders/spp.py
CHANGED
@@ -16,8 +16,8 @@ from .act import NaiveAct
|
|
16 |
from .utils import run_chain
|
17 |
|
18 |
class SPP(NaiveAct):
|
19 |
-
def __init__(self,
|
20 |
-
super().__init__(
|
21 |
|
22 |
def act(
|
23 |
self,
|
|
|
16 |
from .utils import run_chain
|
17 |
|
18 |
class SPP(NaiveAct):
|
19 |
+
def __init__(self, action_space, args, prompts, distiller, temperature=0.1, max_tokens=None, logger=None):
|
20 |
+
super().__init__(action_space, args, prompts, distiller, temperature, max_tokens, logger)
|
21 |
|
22 |
def act(
|
23 |
self,
|
deciders/utils.py
CHANGED
@@ -19,30 +19,8 @@ Model = Literal["gpt-4", "gpt-35-turbo", "text-davinci-003"]
|
|
19 |
# from .gpt import gpt
|
20 |
# gpt().__init__()
|
21 |
|
22 |
-
|
23 |
-
|
24 |
-
# def run_chain(chain, *args, **kwargs):
|
25 |
-
# return chain.run(*args, **kwargs)
|
26 |
-
import concurrent.futures
|
27 |
-
|
28 |
-
def timeout_decorator(timeout):
|
29 |
-
def decorator(function):
|
30 |
-
def wrapper(*args, **kwargs):
|
31 |
-
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
|
32 |
-
future = executor.submit(function, *args, **kwargs)
|
33 |
-
try:
|
34 |
-
return future.result(timeout)
|
35 |
-
except concurrent.futures.TimeoutError:
|
36 |
-
raise RuntimeError(
|
37 |
-
f"Function '{function.__name__}' timed out after {timeout} seconds"
|
38 |
-
)
|
39 |
-
except Exception as e:
|
40 |
-
raise e
|
41 |
-
return wrapper
|
42 |
-
return decorator
|
43 |
-
|
44 |
-
|
45 |
-
@timeout_decorator(30)
|
46 |
def run_chain(chain, *args, **kwargs):
|
47 |
return chain.run(*args, **kwargs)
|
48 |
|
@@ -76,7 +54,6 @@ def get_completion(prompt: str, api_type: str = "azure", engine: str = "gpt-35-t
|
|
76 |
temperature=temperature,
|
77 |
# request_timeout = 1
|
78 |
)
|
79 |
-
import pdb; pdb.set_trace()
|
80 |
return response.choices[0]["message"]["content"]
|
81 |
|
82 |
# @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
|
@@ -108,4 +85,5 @@ def get_chat(prompt: str, api_type: str = "azure", model: str = "gpt-35-turbo",
|
|
108 |
temperature=temperature,
|
109 |
# request_timeout = 1
|
110 |
)
|
111 |
-
return response.choices[0]["message"]["content"]
|
|
|
|
19 |
# from .gpt import gpt
|
20 |
# gpt().__init__()
|
21 |
|
22 |
+
import timeout_decorator
|
23 |
+
@timeout_decorator.timeout(30)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
def run_chain(chain, *args, **kwargs):
|
25 |
return chain.run(*args, **kwargs)
|
26 |
|
|
|
54 |
temperature=temperature,
|
55 |
# request_timeout = 1
|
56 |
)
|
|
|
57 |
return response.choices[0]["message"]["content"]
|
58 |
|
59 |
# @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
|
|
|
85 |
temperature=temperature,
|
86 |
# request_timeout = 1
|
87 |
)
|
88 |
+
return response.choices[0]["message"]["content"]
|
89 |
+
|
envs/__init__.py
CHANGED
@@ -18,25 +18,24 @@ from .atari import mspacman_policies, mspacman_translator
|
|
18 |
from .atari import montezumarevenge_policies, montezumarevenge_translator
|
19 |
register_environments()
|
20 |
|
21 |
-
from .mujoco import ant_translator, ant_policies
|
22 |
|
23 |
REGISTRY = {}
|
24 |
REGISTRY["sampling_wrapper"] = SettableStateEnv
|
25 |
REGISTRY["base_env"] = BaseEnv
|
26 |
-
REGISTRY["
|
27 |
-
REGISTRY["
|
28 |
REGISTRY["acrobot_init_translator"] = acrobot_translator.GameDescriber
|
29 |
REGISTRY["acrobot_basic_translator"] = acrobot_translator.BasicStateSequenceTranslator
|
30 |
REGISTRY["mountaincar_init_translator"] = mountaincar_translator.GameDescriber
|
31 |
REGISTRY["mountaincar_basic_translator"] = mountaincar_translator.BasicStateSequenceTranslator
|
32 |
|
33 |
-
REGISTRY["
|
34 |
REGISTRY["acrobot_policies"] = [acrobot_policies.dedicated_1_policy, acrobot_policies.dedicated_2_policy, acrobot_policies.dedicated_3_policy, acrobot_policies.pseudo_random_policy, acrobot_policies.real_random_policy]
|
35 |
REGISTRY["mountaincar_policies"] = [mountaincar_policies.dedicated_1_policy, mountaincar_policies.dedicated_2_policy, mountaincar_policies.dedicated_3_policy, mountaincar_policies.pseudo_random_policy, mountaincar_policies.real_random_policy]
|
36 |
|
37 |
-
REGISTRY["
|
38 |
-
REGISTRY["
|
39 |
-
REGISTRY["
|
40 |
|
41 |
REGISTRY["blackjack_init_translator"] = blackjack_translator.GameDescriber
|
42 |
REGISTRY["blackjack_basic_translator"] = blackjack_translator.BasicStateSequenceTranslator
|
@@ -55,9 +54,9 @@ REGISTRY["frozenlake_basic_translator"] = frozenlake_translator.BasicStateSequen
|
|
55 |
REGISTRY["frozenlake_policies"] = [frozenlake_policies.dedicated_1_policy, frozenlake_policies.dedicated_2_policy, frozenlake_policies.dedicated_3_policy, frozenlake_policies.dedicated_4_policy, frozenlake_policies.pseudo_random_policy, frozenlake_policies.real_random_policy]
|
56 |
|
57 |
|
58 |
-
REGISTRY["
|
59 |
-
REGISTRY["
|
60 |
-
REGISTRY["
|
61 |
|
62 |
|
63 |
REGISTRY["RepresentedBoxing_init_translator"] = Boxing_translator.GameDescriber
|
@@ -139,6 +138,47 @@ REGISTRY["RepresentedMontezumaRevenge_basic_policies"] = [
|
|
139 |
montezumarevenge_policies.dedicated_18_policy,
|
140 |
]
|
141 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
142 |
## For mujoco env
|
143 |
|
144 |
|
@@ -156,12 +196,12 @@ from .mujoco import walker2d_translator, walker2d_policies
|
|
156 |
|
157 |
|
158 |
|
159 |
-
REGISTRY["
|
160 |
-
REGISTRY["
|
161 |
-
REGISTRY["
|
162 |
-
REGISTRY["
|
163 |
-
REGISTRY["
|
164 |
-
REGISTRY["
|
165 |
|
166 |
|
167 |
REGISTRY["swimmer_init_translator"] = swimmer_translator.GameDescriber
|
|
|
18 |
from .atari import montezumarevenge_policies, montezumarevenge_translator
|
19 |
register_environments()
|
20 |
|
|
|
21 |
|
22 |
REGISTRY = {}
|
23 |
REGISTRY["sampling_wrapper"] = SettableStateEnv
|
24 |
REGISTRY["base_env"] = BaseEnv
|
25 |
+
REGISTRY["cart_init_translator"] = cartpole_translator.GameDescriber
|
26 |
+
REGISTRY["cart_basic_translator"] = cartpole_translator.BasicStateSequenceTranslator
|
27 |
REGISTRY["acrobot_init_translator"] = acrobot_translator.GameDescriber
|
28 |
REGISTRY["acrobot_basic_translator"] = acrobot_translator.BasicStateSequenceTranslator
|
29 |
REGISTRY["mountaincar_init_translator"] = mountaincar_translator.GameDescriber
|
30 |
REGISTRY["mountaincar_basic_translator"] = mountaincar_translator.BasicStateSequenceTranslator
|
31 |
|
32 |
+
REGISTRY["cart_policies"] = [cartpole_policies.dedicated_1_policy, cartpole_policies.dedicated_2_policy, cartpole_policies.pseudo_random_policy, cartpole_policies.real_random_policy]
|
33 |
REGISTRY["acrobot_policies"] = [acrobot_policies.dedicated_1_policy, acrobot_policies.dedicated_2_policy, acrobot_policies.dedicated_3_policy, acrobot_policies.pseudo_random_policy, acrobot_policies.real_random_policy]
|
34 |
REGISTRY["mountaincar_policies"] = [mountaincar_policies.dedicated_1_policy, mountaincar_policies.dedicated_2_policy, mountaincar_policies.dedicated_3_policy, mountaincar_policies.pseudo_random_policy, mountaincar_policies.real_random_policy]
|
35 |
|
36 |
+
REGISTRY["lunarLander_init_translator"] = LunarLander_translator.GameDescriber
|
37 |
+
REGISTRY["lunarLander_basic_translator"] = LunarLander_translator.BasicStateSequenceTranslator
|
38 |
+
REGISTRY["lunarLander_policies"] = [LunarLander_policies.dedicated_1_policy, LunarLander_policies.dedicated_2_policy, LunarLander_policies.dedicated_3_policy,LunarLander_policies.dedicated_4_policy, LunarLander_policies.pseudo_random_policy, LunarLander_policies.real_random_policy]
|
39 |
|
40 |
REGISTRY["blackjack_init_translator"] = blackjack_translator.GameDescriber
|
41 |
REGISTRY["blackjack_basic_translator"] = blackjack_translator.BasicStateSequenceTranslator
|
|
|
54 |
REGISTRY["frozenlake_policies"] = [frozenlake_policies.dedicated_1_policy, frozenlake_policies.dedicated_2_policy, frozenlake_policies.dedicated_3_policy, frozenlake_policies.dedicated_4_policy, frozenlake_policies.pseudo_random_policy, frozenlake_policies.real_random_policy]
|
55 |
|
56 |
|
57 |
+
REGISTRY["mountaincarContinuous_init_translator"] = mountaincarContinuous_translator.GameDescriber
|
58 |
+
REGISTRY["mountaincarContinuous_basic_translator"] = mountaincarContinuous_translator.BasicStateSequenceTranslator
|
59 |
+
REGISTRY["mountaincarContinuous_policies"] = [mountaincarContinuous_policies.pseudo_random_policy, mountaincarContinuous_policies.real_random_policy]
|
60 |
|
61 |
|
62 |
REGISTRY["RepresentedBoxing_init_translator"] = Boxing_translator.GameDescriber
|
|
|
138 |
montezumarevenge_policies.dedicated_18_policy,
|
139 |
]
|
140 |
|
141 |
+
REGISTRY["RepresentedMsPacman_init_translator"] = mspacman_translator.GameDescriber
|
142 |
+
REGISTRY["RepresentedMsPacman_basic_translator"] = mspacman_translator.BasicStateSequenceTranslator
|
143 |
+
REGISTRY["RepresentedMsPacman_basic_policies"] = [
|
144 |
+
mspacman_policies.real_random_policy,
|
145 |
+
mspacman_policies.pseudo_random_policy,
|
146 |
+
mspacman_policies.dedicated_1_policy,
|
147 |
+
mspacman_policies.dedicated_2_policy,
|
148 |
+
mspacman_policies.dedicated_3_policy,
|
149 |
+
mspacman_policies.dedicated_4_policy,
|
150 |
+
mspacman_policies.dedicated_5_policy,
|
151 |
+
mspacman_policies.dedicated_6_policy,
|
152 |
+
mspacman_policies.dedicated_7_policy,
|
153 |
+
mspacman_policies.dedicated_8_policy,
|
154 |
+
mspacman_policies.dedicated_9_policy,
|
155 |
+
]
|
156 |
+
|
157 |
+
REGISTRY["RepresentedMontezumaRevenge_init_translator"] = montezumarevenge_translator.GameDescriber
|
158 |
+
REGISTRY["RepresentedMontezumaRevenge_basic_translator"] = montezumarevenge_translator.BasicStateSequenceTranslator
|
159 |
+
REGISTRY["RepresentedMontezumaRevenge_basic_policies"] = [
|
160 |
+
montezumarevenge_policies.real_random_policy,
|
161 |
+
montezumarevenge_policies.pseudo_random_policy,
|
162 |
+
montezumarevenge_policies.dedicated_1_policy,
|
163 |
+
montezumarevenge_policies.dedicated_2_policy,
|
164 |
+
montezumarevenge_policies.dedicated_3_policy,
|
165 |
+
montezumarevenge_policies.dedicated_4_policy,
|
166 |
+
montezumarevenge_policies.dedicated_5_policy,
|
167 |
+
montezumarevenge_policies.dedicated_6_policy,
|
168 |
+
montezumarevenge_policies.dedicated_7_policy,
|
169 |
+
montezumarevenge_policies.dedicated_8_policy,
|
170 |
+
montezumarevenge_policies.dedicated_9_policy,
|
171 |
+
montezumarevenge_policies.dedicated_10_policy,
|
172 |
+
montezumarevenge_policies.dedicated_11_policy,
|
173 |
+
montezumarevenge_policies.dedicated_12_policy,
|
174 |
+
montezumarevenge_policies.dedicated_13_policy,
|
175 |
+
montezumarevenge_policies.dedicated_14_policy,
|
176 |
+
montezumarevenge_policies.dedicated_15_policy,
|
177 |
+
montezumarevenge_policies.dedicated_16_policy,
|
178 |
+
montezumarevenge_policies.dedicated_17_policy,
|
179 |
+
montezumarevenge_policies.dedicated_18_policy,
|
180 |
+
]
|
181 |
+
|
182 |
## For mujoco env
|
183 |
|
184 |
|
|
|
196 |
|
197 |
|
198 |
|
199 |
+
REGISTRY["invertedPendulum_init_translator"] = invertedPendulum_translator.GameDescriber
|
200 |
+
REGISTRY["invertedPendulum_basic_translator"] = invertedPendulum_translator.BasicStateSequenceTranslator
|
201 |
+
REGISTRY["invertedPendulum_policies"] = [invertedPendulum_policies.pseudo_random_policy, invertedPendulum_policies.real_random_policy]
|
202 |
+
REGISTRY["invertedDoublePendulum_init_translator"] = invertedDoublePendulum_translator.GameDescriber
|
203 |
+
REGISTRY["invertedDoublePendulum_basic_translator"] = invertedDoublePendulum_translator.BasicStateSequenceTranslator
|
204 |
+
REGISTRY["invertedDoublePendulum_policies"] = [invertedDoublePendulum_policies.pseudo_random_policy, invertedDoublePendulum_policies.real_random_policy]
|
205 |
|
206 |
|
207 |
REGISTRY["swimmer_init_translator"] = swimmer_translator.GameDescriber
|
envs/mujoco/invertedDoublePendulum_translator.py
CHANGED
@@ -7,9 +7,16 @@ class BasicLevelTranslator:
|
|
7 |
def translate(self, state):
|
8 |
res = (
|
9 |
f"Position of the cart: {state[0]:.2f} m\n"
|
10 |
-
f"
|
11 |
-
f"
|
12 |
-
f"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
)
|
14 |
return res
|
15 |
|
@@ -18,7 +25,7 @@ class GameDescriber:
|
|
18 |
self.is_only_local_obs = args.is_only_local_obs == 1
|
19 |
self.max_episode_len = args.max_episode_len
|
20 |
self.action_desc_dict = {
|
21 |
-
0: "Apply a force in the range [-
|
22 |
}
|
23 |
self.reward_desc_dict = {}
|
24 |
|
@@ -30,24 +37,22 @@ class GameDescriber:
|
|
30 |
|
31 |
def describe_goal(self):
|
32 |
return (
|
33 |
-
"The goal in the
|
34 |
-
"by applying continuous forces
|
35 |
)
|
36 |
|
37 |
def describe_game(self):
|
38 |
return (
|
39 |
-
"In the
|
40 |
-
"
|
41 |
-
"to the cart
|
42 |
-
"
|
43 |
-
"and angular velocities. The goal is to maintain balance as long as possible."
|
44 |
)
|
45 |
|
46 |
def describe_action(self):
|
47 |
return (
|
48 |
-
"Your next move: \n Please provide a numerical value
|
49 |
-
"
|
50 |
-
"in the right direction, and a negative value indicates applying force in the left direction."
|
51 |
)
|
52 |
|
53 |
class BasicStateSequenceTranslator(BasicLevelTranslator):
|
|
|
7 |
def translate(self, state):
|
8 |
res = (
|
9 |
f"Position of the cart: {state[0]:.2f} m\n"
|
10 |
+
f"Sine of the angle between cart and first pole: {state[1]:.2f}\n"
|
11 |
+
f"Sine of the angle between two poles: {state[2]:.2f}\n"
|
12 |
+
f"Cosine of the angle between cart and first pole: {state[3]:.2f}\n"
|
13 |
+
f"Cosine of the angle between two poles: {state[4]:.2f}\n"
|
14 |
+
f"Velocity of the cart: {state[5]:.2f} m/s\n"
|
15 |
+
f"Angular velocity of angle between cart and first pole: {state[6]:.2f} rad/s\n"
|
16 |
+
f"Angular velocity of angle between two poles: {state[7]:.2f} rad/s\n"
|
17 |
+
f"Constraint Force 1: {state[8]:.2f} N\n"
|
18 |
+
f"Constraint Force 2: {state[9]:.2f} N\n"
|
19 |
+
f"Constraint Force 3: {state[10]:.2f} N"
|
20 |
)
|
21 |
return res
|
22 |
|
|
|
25 |
self.is_only_local_obs = args.is_only_local_obs == 1
|
26 |
self.max_episode_len = args.max_episode_len
|
27 |
self.action_desc_dict = {
|
28 |
+
0: "Apply a force in the range [-3, 3] to the cart to control its motion.",
|
29 |
}
|
30 |
self.reward_desc_dict = {}
|
31 |
|
|
|
37 |
|
38 |
def describe_goal(self):
|
39 |
return (
|
40 |
+
"The goal in the InvertedDoublePendulum environment is to balance the two poles "\
|
41 |
+
"on top of the cart by applying continuous forces on the cart."
|
42 |
)
|
43 |
|
44 |
def describe_game(self):
|
45 |
return (
|
46 |
+
"In the InvertedDoublePendulum environment, you control a system with a cart and two poles. "\
|
47 |
+
"Your objective is to balance the two poles on top of the cart by applying continuous forces "\
|
48 |
+
"to the cart. The environment provides observations of the cart's position, angles of the poles, "\
|
49 |
+
"and their angular velocities. The episode ends when certain termination conditions are met."
|
|
|
50 |
)
|
51 |
|
52 |
def describe_action(self):
|
53 |
return (
|
54 |
+
"Your next move: \n Please provide a numerical value within the range of [-3,3], "\
|
55 |
+
"representing the force to be applied to the cart."
|
|
|
56 |
)
|
57 |
|
58 |
class BasicStateSequenceTranslator(BasicLevelTranslator):
|
envs/mujoco/invertedPendulum_translator.py
CHANGED
@@ -7,16 +7,9 @@ class BasicLevelTranslator:
|
|
7 |
def translate(self, state):
|
8 |
res = (
|
9 |
f"Position of the cart: {state[0]:.2f} m\n"
|
10 |
-
f"
|
11 |
-
f"
|
12 |
-
f"
|
13 |
-
f"Cosine of the angle between two poles: {state[4]:.2f}\n"
|
14 |
-
f"Velocity of the cart: {state[5]:.2f} m/s\n"
|
15 |
-
f"Angular velocity of angle between cart and first pole: {state[6]:.2f} rad/s\n"
|
16 |
-
f"Angular velocity of angle between two poles: {state[7]:.2f} rad/s\n"
|
17 |
-
f"Constraint Force 1: {state[8]:.2f} N\n"
|
18 |
-
f"Constraint Force 2: {state[9]:.2f} N\n"
|
19 |
-
f"Constraint Force 3: {state[10]:.2f} N"
|
20 |
)
|
21 |
return res
|
22 |
|
@@ -25,7 +18,7 @@ class GameDescriber:
|
|
25 |
self.is_only_local_obs = args.is_only_local_obs == 1
|
26 |
self.max_episode_len = args.max_episode_len
|
27 |
self.action_desc_dict = {
|
28 |
-
0: "Apply a force in the range [-
|
29 |
}
|
30 |
self.reward_desc_dict = {}
|
31 |
|
@@ -37,22 +30,24 @@ class GameDescriber:
|
|
37 |
|
38 |
def describe_goal(self):
|
39 |
return (
|
40 |
-
"The goal in the
|
41 |
-
"
|
42 |
)
|
43 |
|
44 |
def describe_game(self):
|
45 |
return (
|
46 |
-
"In the
|
47 |
-
"Your objective is to balance the
|
48 |
-
"to the cart
|
49 |
-
"
|
|
|
50 |
)
|
51 |
|
52 |
def describe_action(self):
|
53 |
return (
|
54 |
-
"Your next move: \n Please provide a numerical value
|
55 |
-
"
|
|
|
56 |
)
|
57 |
|
58 |
class BasicStateSequenceTranslator(BasicLevelTranslator):
|
|
|
7 |
def translate(self, state):
|
8 |
res = (
|
9 |
f"Position of the cart: {state[0]:.2f} m\n"
|
10 |
+
f"Vertical angle of the pole: {state[1]:.2f} rad\n"
|
11 |
+
f"Linear velocity of the cart: {state[2]:.2f} m/s\n"
|
12 |
+
f"Angular velocity of the pole: {state[3]:.2f} rad/s"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
)
|
14 |
return res
|
15 |
|
|
|
18 |
self.is_only_local_obs = args.is_only_local_obs == 1
|
19 |
self.max_episode_len = args.max_episode_len
|
20 |
self.action_desc_dict = {
|
21 |
+
0: "Apply a force in the range [-1, 1] to the cart to control its motion.",
|
22 |
}
|
23 |
self.reward_desc_dict = {}
|
24 |
|
|
|
30 |
|
31 |
def describe_goal(self):
|
32 |
return (
|
33 |
+
"The goal in the Inverted Pendulum environment is to balance the pole on top of the cart "\
|
34 |
+
"by applying continuous forces to the cart, keeping it upright."
|
35 |
)
|
36 |
|
37 |
def describe_game(self):
|
38 |
return (
|
39 |
+
"In the Inverted Pendulum environment, you control a cart that can move linearly with a pole "\
|
40 |
+
"attached to it. Your objective is to balance the pole on top of the cart by applying forces "\
|
41 |
+
"to the cart in a way that keeps the pole upright. "\
|
42 |
+
"The environment provides observations of the cart's position, pole angle, velocities, "\
|
43 |
+
"and angular velocities. The goal is to maintain balance as long as possible."
|
44 |
)
|
45 |
|
46 |
def describe_action(self):
|
47 |
return (
|
48 |
+
"Your next move: \n Please provide a numerical value for the force to be applied to the cart. "\
|
49 |
+
"This value should be within the range of [-3, 3], where a positive value indicates applying force "\
|
50 |
+
"in the right direction, and a negative value indicates applying force in the left direction."
|
51 |
)
|
52 |
|
53 |
class BasicStateSequenceTranslator(BasicLevelTranslator):
|
record_reflexion.csv
CHANGED
@@ -19,4 +19,5 @@ Walker2d-v4,1,expert,5000.0
|
|
19 |
Swimmer-v4,1,expert,44.4
|
20 |
Reacher-v4,1,expert,-2.6
|
21 |
Pusher-v4,1,expert,-52.3
|
22 |
-
|
|
|
|
19 |
Swimmer-v4,1,expert,44.4
|
20 |
Reacher-v4,1,expert,-2.6
|
21 |
Pusher-v4,1,expert,-52.3
|
22 |
+
InvertedPendulum-v4,1,expert,1000.0
|
23 |
+
InvertedDoublePendulum-v4,1,expert,9359.5
|
shell/test_reflexion.sh
CHANGED
@@ -1,43 +1,43 @@
|
|
1 |
|
2 |
# CartPole-v0
|
3 |
# Naive Actor
|
4 |
-
python main_reflexion.py --env_name CartPole-v0 --init_summarizer
|
5 |
-
python main_reflexion.py --env_name CartPole-v0 --init_summarizer
|
6 |
-
python main_reflexion.py --env_name CartPole-v0 --init_summarizer
|
7 |
-
python main_reflexion.py --env_name CartPole-v0 --init_summarizer
|
8 |
|
9 |
# COT
|
10 |
-
python main_reflexion.py --env_name CartPole-v0 --init_summarizer
|
11 |
-
python main_reflexion.py --env_name CartPole-v0 --init_summarizer
|
12 |
-
python main_reflexion.py --env_name CartPole-v0 --init_summarizer
|
13 |
-
python main_reflexion.py --env_name CartPole-v0 --init_summarizer
|
14 |
|
15 |
# self consistency
|
16 |
-
python main_reflexion.py --env_name CartPole-v0 --init_summarizer
|
17 |
-
python main_reflexion.py --env_name CartPole-v0 --init_summarizer
|
18 |
-
python main_reflexion.py --env_name CartPole-v0 --init_summarizer
|
19 |
-
python main_reflexion.py --env_name CartPole-v0 --init_summarizer
|
20 |
|
21 |
# self-ask
|
22 |
-
python main_reflexion.py --env_name CartPole-v0 --init_summarizer
|
23 |
-
python main_reflexion.py --env_name CartPole-v0 --init_summarizer
|
24 |
-
python main_reflexion.py --env_name CartPole-v0 --init_summarizer
|
25 |
-
python main_reflexion.py --env_name CartPole-v0 --init_summarizer
|
26 |
|
27 |
# SPP
|
28 |
-
python main_reflexion.py --env_name CartPole-v0 --init_summarizer
|
29 |
-
python main_reflexion.py --env_name CartPole-v0 --init_summarizer
|
30 |
-
python main_reflexion.py --env_name CartPole-v0 --init_summarizer
|
31 |
-
python main_reflexion.py --env_name CartPole-v0 --init_summarizer
|
32 |
|
33 |
# REFLEXION
|
34 |
-
python main_reflexion.py --env_name CartPole-v0 --init_summarizer
|
35 |
-
python main_reflexion.py --env_name CartPole-v0 --init_summarizer
|
36 |
-
python main_reflexion.py --env_name CartPole-v0 --init_summarizer
|
37 |
-
python main_reflexion.py --env_name CartPole-v0 --init_summarizer
|
38 |
|
39 |
# exe
|
40 |
-
python main_reflexion.py --env_name CartPole-v0 --init_summarizer
|
41 |
-
python main_reflexion.py --env_name CartPole-v0 --init_summarizer
|
42 |
-
python main_reflexion.py --env_name CartPole-v0 --init_summarizer
|
43 |
-
python main_reflexion.py --env_name CartPole-v0 --init_summarizer
|
|
|
1 |
|
2 |
# CartPole-v0
|
3 |
# Naive Actor
|
4 |
+
python main_reflexion.py --env_name CartPole-v0 --init_summarizer cart_init_translator --curr_summarizer cart_basic_translator --decider naive_actor --prompt_level 1 --num_trails 1
|
5 |
+
python main_reflexion.py --env_name CartPole-v0 --init_summarizer cart_init_translator --curr_summarizer cart_basic_translator --decider naive_actor --prompt_level 3 -num_trails 2 --distiller traj_distiller
|
6 |
+
python main_reflexion.py --env_name CartPole-v0 --init_summarizer cart_init_translator --curr_summarizer cart_basic_translator --decider naive_actor --prompt_level 4 --num_trails 1 --distiller traj_distiller --prompt_path "envs/classic_control/few_shot_examples/cartpole"
|
7 |
+
python main_reflexion.py --env_name CartPole-v0 --init_summarizer cart_init_translator --curr_summarizer cart_basic_translator --decider naive_actor --prompt_level 5 --num_trails 1
|
8 |
|
9 |
# COT
|
10 |
+
python main_reflexion.py --env_name CartPole-v0 --init_summarizer cart_init_translator --curr_summarizer cart_basic_translator --decider cot_actor --prompt_level 1 --num_trails 1
|
11 |
+
python main_reflexion.py --env_name CartPole-v0 --init_summarizer cart_init_translator --curr_summarizer cart_basic_translator --decider cot_actor --prompt_level 3 -num_trails 2 --distiller traj_distiller
|
12 |
+
python main_reflexion.py --env_name CartPole-v0 --init_summarizer cart_init_translator --curr_summarizer cart_basic_translator --decider cot_actor --prompt_level 4 --num_trails 1 --distiller traj_distiller --prompt_path "envs/classic_control/few_shot_examples/cartpole"
|
13 |
+
python main_reflexion.py --env_name CartPole-v0 --init_summarizer cart_init_translator --curr_summarizer cart_basic_translator --decider cot_actor --prompt_level 5 --num_trails 1
|
14 |
|
15 |
# self consistency
|
16 |
+
python main_reflexion.py --env_name CartPole-v0 --init_summarizer cart_init_translator --curr_summarizer cart_basic_translator --decider self_consistency_actor --prompt_level 1 --num_trails 1
|
17 |
+
python main_reflexion.py --env_name CartPole-v0 --init_summarizer cart_init_translator --curr_summarizer cart_basic_translator --decider self_consistency_actor --prompt_level 3 -num_trails 2 --distiller traj_distiller
|
18 |
+
python main_reflexion.py --env_name CartPole-v0 --init_summarizer cart_init_translator --curr_summarizer cart_basic_translator --decider self_consistency_actor --prompt_level 4 --num_trails 1 --distiller traj_distiller --prompt_path "envs/classic_control/few_shot_examples/cartpole"
|
19 |
+
python main_reflexion.py --env_name CartPole-v0 --init_summarizer cart_init_translator --curr_summarizer cart_basic_translator --decider self_consistency_actor --prompt_level 5 --num_trails 1
|
20 |
|
21 |
# self-ask
|
22 |
+
python main_reflexion.py --env_name CartPole-v0 --init_summarizer cart_init_translator --curr_summarizer cart_basic_translator --decider selfask_actor --prompt_level 1 --num_trails 1
|
23 |
+
python main_reflexion.py --env_name CartPole-v0 --init_summarizer cart_init_translator --curr_summarizer cart_basic_translator --decider selfask_actor --prompt_level 3 -num_trails 2 --distiller traj_distiller
|
24 |
+
python main_reflexion.py --env_name CartPole-v0 --init_summarizer cart_init_translator --curr_summarizer cart_basic_translator --decider selfask_actor --prompt_level 4 --num_trails 1 --distiller traj_distiller --prompt_path "envs/classic_control/few_shot_examples/cartpole"
|
25 |
+
python main_reflexion.py --env_name CartPole-v0 --init_summarizer cart_init_translator --curr_summarizer cart_basic_translator --decider selfask_actor --prompt_level 5 --num_trails 1
|
26 |
|
27 |
# SPP
|
28 |
+
python main_reflexion.py --env_name CartPole-v0 --init_summarizer cart_init_translator --curr_summarizer cart_basic_translator --decider spp_actor --prompt_level 1 --num_trails 1
|
29 |
+
python main_reflexion.py --env_name CartPole-v0 --init_summarizer cart_init_translator --curr_summarizer cart_basic_translator --decider spp_actor --prompt_level 3 -num_trails 2 --distiller traj_distiller
|
30 |
+
python main_reflexion.py --env_name CartPole-v0 --init_summarizer cart_init_translator --curr_summarizer cart_basic_translator --decider spp_actor --prompt_level 4 --num_trails 1 --distiller traj_distiller --prompt_path "envs/classic_control/few_shot_examples/cartpole"
|
31 |
+
python main_reflexion.py --env_name CartPole-v0 --init_summarizer cart_init_translator --curr_summarizer cart_basic_translator --decider spp_actor --prompt_level 5 --num_trails 1
|
32 |
|
33 |
# REFLEXION
|
34 |
+
python main_reflexion.py --env_name CartPole-v0 --init_summarizer cart_init_translator --curr_summarizer cart_basic_translator --decider reflexion_actor --prompt_level 1 --num_trails 1 --distiller reflect_distiller
|
35 |
+
python main_reflexion.py --env_name CartPole-v0 --init_summarizer cart_init_translator --curr_summarizer cart_basic_translator --decider reflexion_actor --prompt_level 3 -num_trails 2 --distiller reflect_distiller
|
36 |
+
python main_reflexion.py --env_name CartPole-v0 --init_summarizer cart_init_translator --curr_summarizer cart_basic_translator --decider reflexion_actor --prompt_level 4 --num_trails 1 --distiller reflect_distiller --prompt_path "envs/classic_control/few_shot_examples/cartpole"
|
37 |
+
python main_reflexion.py --env_name CartPole-v0 --init_summarizer cart_init_translator --curr_summarizer cart_basic_translator --decider reflexion_actor --prompt_level 5 --num_trails 1 --distiller reflect_distiller
|
38 |
|
39 |
# exe
|
40 |
+
python main_reflexion.py --env_name CartPole-v0 --init_summarizer cart_init_translator --curr_summarizer cart_basic_translator --decider exe_actor --prompt_level 1 --num_trails 1 --distiller guide_generator
|
41 |
+
python main_reflexion.py --env_name CartPole-v0 --init_summarizer cart_init_translator --curr_summarizer cart_basic_translator --decider exe_actor --prompt_level 3 -num_trails 2 --distiller guide_generator
|
42 |
+
python main_reflexion.py --env_name CartPole-v0 --init_summarizer cart_init_translator --curr_summarizer cart_basic_translator --decider exe_actor --prompt_level 4 --num_trails 1 --distiller guide_generator --prompt_path "envs/classic_control/few_shot_examples/cartpole"
|
43 |
+
python main_reflexion.py --env_name CartPole-v0 --init_summarizer cart_init_translator --curr_summarizer cart_basic_translator --decider exe_actor --prompt_level 5 --num_trails 1 --distiller guide_generator
|