Jarvis-K commited on
Commit
5f98914
1 Parent(s): d731338

add support to multi-dim con

Browse files
deciders/act.py CHANGED
@@ -3,7 +3,7 @@
3
  import openai
4
  from .gpt import gpt
5
  from loguru import logger
6
- from .parser import PARSERS
7
  from langchain.output_parsers import PydanticOutputParser
8
  from langchain.output_parsers import OutputFixingParser
9
  from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
@@ -12,13 +12,18 @@ import tiktoken
12
  import json
13
  import re
14
  from .utils import run_chain
 
15
 
16
  class RandomAct():
17
  def __init__(self, action_space):
18
  self.action_space = action_space
19
 
20
  def act(self, state_description, action_description, env_info, game_description=None, goal_description=None):
21
- return self.action_space.sample()+1, '', '', '', 0, 0
 
 
 
 
22
 
23
  class NaiveAct(gpt):
24
  def __init__(self, action_space, args, prompts, distiller, temperature=0.0, max_tokens=2048, logger=None):
@@ -37,7 +42,10 @@ class NaiveAct(gpt):
37
  super().__init__(args)
38
  self.distiller = distiller
39
  self.fewshot_example_initialization(args.prompt_level, args.prompt_path, distiller = self.distiller)
40
- self.default_action = 1
 
 
 
41
  self.parser = self._parser_initialization()
42
  self.irr_game_description = ''
43
  self.memory = []
@@ -82,11 +90,12 @@ class NaiveAct(gpt):
82
 
83
 
84
  def _parser_initialization(self):
85
- if hasattr(self.action_space, 'n'):
86
- assert self.action_space.n in PARSERS.keys(), f'Action space {self.action_space} is not supported.'
87
  num_action = self.action_space.n
88
- else:
89
- num_action = 1
 
90
 
91
  if self.args.api_type == "azure":
92
  autofixing_chat = AzureChatOpenAI(
@@ -204,7 +213,6 @@ class NaiveAct(gpt):
204
  prompt, res = self.response(state_description, action_description, env_info, game_description, goal_description, my_mem)
205
  action_str = res.choices[0].text.strip()
206
  print(f'my anwser is {action_str}')
207
- # import pdb; pdb.set_trace()
208
  try:
209
  if "Continuous" in self.args.env_name:
210
  action = float(re.findall(r"[-+]?\d*\.\d+", action_str)[0])
 
3
  import openai
4
  from .gpt import gpt
5
  from loguru import logger
6
+ from .parser import DISPARSERS, CONPARSERS
7
  from langchain.output_parsers import PydanticOutputParser
8
  from langchain.output_parsers import OutputFixingParser
9
  from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
 
12
  import json
13
  import re
14
  from .utils import run_chain
15
+ from gym.spaces import Discrete
16
 
17
  class RandomAct():
18
  def __init__(self, action_space):
19
  self.action_space = action_space
20
 
21
  def act(self, state_description, action_description, env_info, game_description=None, goal_description=None):
22
+ if isinstance(self.action_space, Discrete):
23
+ action = self.action_space.sample()+1
24
+ else:
25
+ action = self.action_space.sample()
26
+ return action, '', '', '', 0, 0
27
 
28
  class NaiveAct(gpt):
29
  def __init__(self, action_space, args, prompts, distiller, temperature=0.0, max_tokens=2048, logger=None):
 
42
  super().__init__(args)
43
  self.distiller = distiller
44
  self.fewshot_example_initialization(args.prompt_level, args.prompt_path, distiller = self.distiller)
45
+ if isinstance(self.action_space, Discrete):
46
+ self.default_action = 1
47
+ else:
48
+ self.default_action = [0 for ind in range(self.action_space.shape[0])]
49
  self.parser = self._parser_initialization()
50
  self.irr_game_description = ''
51
  self.memory = []
 
90
 
91
 
92
  def _parser_initialization(self):
93
+ if isinstance(self.action_space, Discrete):
94
+ PARSERS = DISPARSERS
95
  num_action = self.action_space.n
96
+ else:
97
+ PARSERS = CONPARSERS
98
+ num_action = self.action_space.shape[0]
99
 
100
  if self.args.api_type == "azure":
101
  autofixing_chat = AzureChatOpenAI(
 
213
  prompt, res = self.response(state_description, action_description, env_info, game_description, goal_description, my_mem)
214
  action_str = res.choices[0].text.strip()
215
  print(f'my anwser is {action_str}')
 
216
  try:
217
  if "Continuous" in self.args.env_name:
218
  action = float(re.findall(r"[-+]?\d*\.\d+", action_str)[0])
deciders/parser.py CHANGED
@@ -1,4 +1,5 @@
1
  from pydantic import BaseModel, Field, validator
 
2
 
3
  class DisActionModel(BaseModel):
4
  action: int = Field(description="the chosen action to perform")
@@ -17,15 +18,37 @@ def generate_action_class(max_action):
17
  return type(f"{max_action}Action", (DisActionModel,), {'action_is_valid': DisActionModel.create_validator(max_action)})
18
 
19
  # Dictionary of parsers with dynamic class generation
20
- PARSERS = {num: generate_action_class(num) for num in [2, 3, 4, 6, 9, 18]}
21
-
22
- # class ContinuousAction(BaseModel):
23
- # action: float = Field(description="the choosed action to perform")
24
- # # You can add custom validation logic easily with Pydantic.
25
- # @validator('action')
26
- # def action_is_valid(cls, field):
27
- # if not (field >= -1 and field <= 1):
28
- # raise ValueError("Action is not valid ([-1,1])!")
29
- # return field
30
-
31
- # PARSERS = {1:ContinuousAction, 2: TwoAction, 3: ThreeAction, 4: FourAction, 6: SixAction, 9:NineAction, 18: FullAtariAction}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from pydantic import BaseModel, Field, validator
2
+ from typing import List
3
 
4
  class DisActionModel(BaseModel):
5
  action: int = Field(description="the chosen action to perform")
 
18
  return type(f"{max_action}Action", (DisActionModel,), {'action_is_valid': DisActionModel.create_validator(max_action)})
19
 
20
  # Dictionary of parsers with dynamic class generation
21
+ DISPARSERS = {num: generate_action_class(num) for num in [2, 3, 4, 6, 9, 18]}
22
+
23
+ class ContinuousActionBase(BaseModel):
24
+ action: List[float] = Field(description="the chosen continuous actions to perform")
25
+
26
+ @classmethod
27
+ def set_expected_length(cls, length):
28
+ cls.expected_length = length
29
+
30
+ @validator('action', pre=True)
31
+ def validate_length(cls, action):
32
+ if len(action) != cls.expected_length:
33
+ raise ValueError(f"The action list must have exactly {cls.expected_length} items.")
34
+ return action
35
+
36
+ @validator('action', each_item=True)
37
+ def action_is_valid(cls, item):
38
+ if not -1 <= item <= 1:
39
+ raise ValueError("Each action dimension must be in the range [-1, 1]!")
40
+ return item
41
+
42
+ # Generate classes dynamically
43
+ def generate_continuous_action_class(expected_length):
44
+ NewClass = type(
45
+ f"{expected_length}DContinuousAction",
46
+ (ContinuousActionBase,),
47
+ {}
48
+ )
49
+ NewClass.set_expected_length(expected_length)
50
+ return NewClass
51
+
52
+
53
+ # Dictionary of parsers with dynamic class generation
54
+ CONPARSERS = {length: generate_continuous_action_class(length) for length in range(1, 17)}
envs/base_env.py CHANGED
@@ -1,6 +1,7 @@
1
  # This file contains functions for interacting with the CartPole environment
2
 
3
  import gym
 
4
 
5
  class SettableStateEnv(gym.Wrapper):
6
  def __init__(self, env):
@@ -55,10 +56,11 @@ class BaseEnv(gym.Wrapper):
55
 
56
  def step_llm(self, action):
57
  potential_next_state = self.get_potential_next_state(action)
58
- if "Continuous" in self.env_name:
59
- state, reward, terminated, _, info = super().step(action)
60
- else:
61
  state, reward, terminated, _, info = super().step(action-1)
 
 
 
62
  self.transition_data['action'] = action
63
  self.transition_data['next_state'] = state
64
  self.transition_data['reward'] = reward
 
1
  # This file contains functions for interacting with the CartPole environment
2
 
3
  import gym
4
+ from gym.spaces import Discrete
5
 
6
  class SettableStateEnv(gym.Wrapper):
7
  def __init__(self, env):
 
56
 
57
  def step_llm(self, action):
58
  potential_next_state = self.get_potential_next_state(action)
59
+ if isinstance(self.action_space, Discrete):
 
 
60
  state, reward, terminated, _, info = super().step(action-1)
61
+ else:
62
+ state, reward, terminated, _, info = super().step(action)
63
+
64
  self.transition_data['action'] = action
65
  self.transition_data['next_state'] = state
66
  self.transition_data['reward'] = reward
main_reflexion.py CHANGED
@@ -17,6 +17,9 @@ import random
17
  import numpy as np
18
  import datetime
19
  from loguru import logger
 
 
 
20
 
21
 
22
  def set_seed(seed):
@@ -109,9 +112,6 @@ def _run(translator, environment, decider, max_episode_len, logfile, args, trail
109
  logfile
110
  )
111
 
112
- if "Continuous" in args.env_name:
113
- action = [action]
114
-
115
  state_description, reward, termination, truncation, env_info = environment.step_llm(
116
  action
117
  )
@@ -137,10 +137,6 @@ def _run(translator, environment, decider, max_episode_len, logfile, args, trail
137
  logger.debug(f"Error: {e}, Retry! ({error_i+1}/{retry_num})")
138
  continue
139
  if error_flag:
140
- if "Continuous" in args.env_name:
141
- action = [decider.default_action]
142
- else:
143
- action = decider.default_action
144
  state_description, reward, termination, truncation, env_info = environment.step_llm(
145
  action
146
  )
@@ -164,7 +160,7 @@ def _run(translator, environment, decider, max_episode_len, logfile, args, trail
164
  logger.info(f"current_total_cost: {current_total_cost}")
165
  logger.info(f"Now it is round {round}.")
166
 
167
- frames.append(environment.render())
168
  if termination or truncation:
169
  if logger:
170
  logger.info(f"Terminated!")
 
17
  import numpy as np
18
  import datetime
19
  from loguru import logger
20
+ from gym.spaces import Discrete
21
+
22
+
23
 
24
 
25
  def set_seed(seed):
 
112
  logfile
113
  )
114
 
 
 
 
115
  state_description, reward, termination, truncation, env_info = environment.step_llm(
116
  action
117
  )
 
137
  logger.debug(f"Error: {e}, Retry! ({error_i+1}/{retry_num})")
138
  continue
139
  if error_flag:
 
 
 
 
140
  state_description, reward, termination, truncation, env_info = environment.step_llm(
141
  action
142
  )
 
160
  logger.info(f"current_total_cost: {current_total_cost}")
161
  logger.info(f"Now it is round {round}.")
162
 
163
+ # frames.append(environment.render())
164
  if termination or truncation:
165
  if logger:
166
  logger.info(f"Terminated!")
record_reflexion.csv CHANGED
@@ -10,4 +10,5 @@ FrozenLake-v1,1,expert,200.0
10
  MountainCarContinuous-v0,1,expert,200.0
11
  RepresentedBoxing-v0,1,expert,200.0
12
  RepresentedPong-v0,1,expert,200.0
 
13
 
 
10
  MountainCarContinuous-v0,1,expert,200.0
11
  RepresentedBoxing-v0,1,expert,200.0
12
  RepresentedPong-v0,1,expert,200.0
13
+ Ant-v4,1,expert,100
14