rishiad committed
Commit 7fdb0a1 · unverified · 1 Parent(s): b0ac93f

refactor: replace logging with print statements for debugging in RLAgent

Files changed (1)
  1. agent.py +106 -72
agent.py CHANGED
@@ -1,8 +1,3 @@
-"""
-Implementation of the AgentInterface for MetaWorld tasks.
-
-This agent uses the SawyerPickPlaceV2Policy from MetaWorld as an expert policy.
-"""
 
 import logging
 from typing import Any, Dict
@@ -31,17 +26,26 @@ class RLAgent(AgentInterface):
     ):
         super().__init__(observation_space, action_space, seed, **kwargs)
 
-        self.logger = logging.getLogger(__name__)
-        self.logger.info(f"Initializing MetaWorld agent with seed {self.seed}")
+        print(f"Initializing MetaWorld agent with seed {self.seed}")
+
+        # Log spaces for debugging
+        if observation_space:
+            print(f"Observation space: {observation_space}")
+        if action_space:
+            print(f"Action space: {action_space}")
 
         self.policy = SawyerReachV3Policy()
-        self.logger.info("Successfully initialized SawyerReachV3Policy")
+        print("Successfully initialized SawyerReachV3Policy")
 
         # Track episode state
         self.episode_step = 0
         self.max_episode_steps = kwargs.get("max_episode_steps", 200)
+
+        # Debug flags
+        self.debug_observations = True
+        self.debug_actions = True
 
-        self.logger.info("MetaWorld agent initialized successfully")
+        print("MetaWorld agent initialized successfully")
 
     def act(self, obs: Dict[str, Any], **kwargs) -> torch.Tensor:
         """
@@ -55,31 +59,60 @@ class RLAgent(AgentInterface):
             action: Action tensor to take in the environment
         """
         try:
+            # Debug observation structure
+            if self.debug_observations and self.episode_step % 20 == 0:
+                print(f"Raw observation structure: {type(obs)}")
+                if isinstance(obs, dict):
+                    print(f"Observation keys: {list(obs.keys())}")
+                    for key, value in obs.items():
+                        if isinstance(value, np.ndarray):
+                            print(f"  {key}: shape={value.shape}, dtype={value.dtype}")
+                        else:
+                            print(f"  {key}: {type(value)} = {value}")
+
             # Process observation to extract the format needed by the expert policy
             processed_obs = self._process_observation(obs)
 
-            # Use the expert policy (MetaWorld is always available)
-            # MetaWorld policies expect numpy arrays
-            action_numpy = self.policy.get_action(processed_obs)
-            action_tensor = torch.from_numpy(np.array(action_numpy)).float()
+            # Debug processed observation
+            if self.debug_observations and self.episode_step % 20 == 0:
+                print(f"Processed obs: shape={processed_obs.shape}, dtype={processed_obs.dtype}")
+                print(f"Processed obs sample: {processed_obs[:10]}...")  # First 10 values
 
-            # Log occasionally
-            if self.episode_step % 50 == 0:
-                self.logger.debug(f"Using expert policy action: {action_numpy}")
+            # Use the expert policy
+            action_numpy = self.policy.get_action(processed_obs)
+
+            # Debug raw policy output
+            if self.debug_actions and self.episode_step % 20 == 0:
+                print(f"Raw policy action: {action_numpy}, type: {type(action_numpy)}")
+                print(f"Action shape: {np.array(action_numpy).shape}")
+
+            # Convert to tensor
+            if isinstance(action_numpy, (list, tuple)):
+                action_tensor = torch.tensor(action_numpy, dtype=torch.float32)
+            else:
+                action_tensor = torch.from_numpy(np.array(action_numpy)).float()
+
+            # Ensure correct action dimensionality
+            if self.action_space and hasattr(self.action_space, 'shape'):
+                expected_shape = self.action_space.shape[0]
+                if action_tensor.shape[0] != expected_shape:
+                    print(f"Action shape mismatch: got {action_tensor.shape[0]}, expected {expected_shape}")
+                    # Pad or truncate as needed
+                    if action_tensor.shape[0] < expected_shape:
+                        padding = torch.zeros(expected_shape - action_tensor.shape[0])
+                        action_tensor = torch.cat([action_tensor, padding])
+                    else:
+                        action_tensor = action_tensor[:expected_shape]
+
+            # Debug final action
+            if self.debug_actions and self.episode_step % 20 == 0:
+                print(f"Final action tensor: {action_tensor}")
 
-            # Increment episode step
             self.episode_step += 1
-
-            # Occasionally log actions to avoid spam
-            if self.episode_step % 50 == 0:
-                self.logger.debug(
-                    f"Step {self.episode_step}: Action shape {action_tensor.shape}"
-                )
-
             return action_tensor
 
         except Exception as e:
-            self.logger.error(f"Error in act method: {e}", exc_info=True)
+            print(f"Error in act method: {e}")
             # Return zeros as a fallback
             if isinstance(self.action_space, gym.spaces.Box):
                 return torch.zeros(self.action_space.shape[0], dtype=torch.float32)
@@ -93,48 +126,56 @@ class RLAgent(AgentInterface):
         MetaWorld policies typically expect a specific observation format.
         """
         if isinstance(obs, dict):
-            # MetaWorld environment can return observations in different formats
-            if "observation" in obs:
-                # Standard format for goal-observable environments
-                processed_obs = obs["observation"]
-            elif "obs" in obs:
-                processed_obs = obs["obs"]
-            elif "state_observation" in obs:
-                # Some MetaWorld environments use this key
-                processed_obs = obs["state_observation"]
-            elif "goal_achieved" in obs:
-                # If we have information about goal achievement
-                # This might be needed for certain policy decisions
-                achievement = obs.get("goal_achieved", False)
-                base_obs = next(iter(obs.values()))
-                self.logger.debug(f"Goal achieved: {achievement}")
-                processed_obs = base_obs
-            else:
-                # If structure is unknown, use the first value
-                processed_obs = next(iter(obs.values()))
-                self.logger.debug(f"Using observation key: {next(iter(obs.keys()))}")
+            # Try different keys that MetaWorld might use
+            possible_keys = [
+                "observation",
+                "obs",
+                "state_observation",
+                "achieved_goal",
+                "state"
+            ]
+
+            processed_obs = None
+            for key in possible_keys:
+                if key in obs:
+                    processed_obs = obs[key]
+                    if self.debug_observations and self.episode_step % 50 == 0:
+                        print(f"Using observation key: {key}")
+                    break
+
+            if processed_obs is None:
+                # If none of the expected keys found, concatenate all numeric values
+                numeric_values = []
+                for key, value in obs.items():
+                    if isinstance(value, (np.ndarray, list, tuple)):
+                        flat_value = np.array(value).flatten()
+                        numeric_values.append(flat_value)
+                        if self.debug_observations and self.episode_step % 50 == 0:
+                            print(f"Concatenating key {key}: shape={flat_value.shape}")
+
+                if numeric_values:
+                    processed_obs = np.concatenate(numeric_values)
+                    if self.debug_observations and self.episode_step % 50 == 0:
+                        print(f"Concatenated observation shape: {processed_obs.shape}")
+                else:
+                    # Last resort: use first value
+                    processed_obs = next(iter(obs.values()))
+                    print("No numeric values found, using first observation value")
         else:
-            # If already a numpy array or similar, use directly
             processed_obs = obs
 
-        # Ensure we're returning a numpy array as expected by MetaWorld policies
+        # Ensure numpy array
         if not isinstance(processed_obs, np.ndarray):
             try:
                 processed_obs = np.array(processed_obs, dtype=np.float32)
             except Exception as e:
-                self.logger.error(f"Failed to convert observation to numpy array: {e}")
-                # Return a dummy observation if conversion fails
-                if (
-                    self.observation_space
-                    and hasattr(self.observation_space, "shape")
-                    and self.observation_space.shape is not None
-                ):
-                    processed_obs = np.zeros(
-                        self.observation_space.shape, dtype=np.float32
-                    )
-                else:
-                    # Typical MetaWorld observation dimension if all else fails
-                    processed_obs = np.zeros(39, dtype=np.float32)
+                print(f"Failed to convert observation to numpy array: {e}")
+                # Return default observation size for MetaWorld reach task
+                processed_obs = np.zeros(39, dtype=np.float32)
+
+        # Ensure proper shape for MetaWorld reach policy
+        if processed_obs.ndim > 1:
+            processed_obs = processed_obs.flatten()
 
         return processed_obs
 
@@ -142,9 +183,11 @@ class RLAgent(AgentInterface):
         """
        Reset agent state between episodes.
         """
-        self.logger.debug("Resetting agent")
+        print(f"Resetting agent after {self.episode_step} steps")
         self.episode_step = 0
-        # Any other stateful components would be reset here
+        # Reset debug flags if needed
+        self.debug_observations = True
+        self.debug_actions = True
 
     def _build_model(self):
         """
@@ -153,13 +196,4 @@ class RLAgent(AgentInterface):
         This is a placeholder for where you would define your neural network
         architecture using PyTorch, TensorFlow, or another framework.
         """
-        # Example of where you might build a simple PyTorch model
-        # model = torch.nn.Sequential(
-        #     torch.nn.Linear(self.observation_space.shape[0], 128),
-        #     torch.nn.ReLU(),
-        #     torch.nn.Linear(128, 64),
-        #     torch.nn.ReLU(),
-        #     torch.nn.Linear(64, self.action_space.shape[0]),
-        # )
-        # return model
-        pass
+        pass
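
For reference, the action-shape handling this commit adds to act() converts the expert policy's output to a float tensor, pads or truncates it to the action space's dimensionality, and throttles its debug prints with the step counter. Below is a self-contained sketch of that pattern; the helper name match_action_dim and the example sizes are illustrative assumptions, not part of this repository.

import numpy as np
import torch

def match_action_dim(action_numpy, expected_dim, step, debug=True):
    """Convert a policy action to a float tensor of size expected_dim (illustrative helper)."""
    action_tensor = torch.from_numpy(np.asarray(action_numpy, dtype=np.float32))
    if action_tensor.shape[0] != expected_dim:
        if debug and step % 20 == 0:  # throttled debug print, as in the diff
            print(f"Action shape mismatch: got {action_tensor.shape[0]}, expected {expected_dim}")
        if action_tensor.shape[0] < expected_dim:
            # Pad missing dimensions with zeros
            padding = torch.zeros(expected_dim - action_tensor.shape[0])
            action_tensor = torch.cat([action_tensor, padding])
        else:
            # Truncate extra dimensions
            action_tensor = action_tensor[:expected_dim]
    return action_tensor

# Example: a 3-dim policy output padded to a 4-dim Sawyer-style action space
print(match_action_dim([0.1, -0.2, 0.3], expected_dim=4, step=0))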