rishiad commited on
Commit
3ddfff8
·
unverified ·
1 Parent(s): d36955c

feat: enhance observation processing and debugging in RLAgent for MetaWorld policies

Browse files
Files changed (1) hide show
  1. agent.py +191 -39
agent.py CHANGED
@@ -45,6 +45,20 @@ class RLAgent(AgentInterface):
45
  if hasattr(self.policy, 'bias'):
46
  print(f"Policy bias: {self.policy.bias}")
47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  # Track episode state
49
  self.episode_step = 0
50
  self.max_episode_steps = kwargs.get("max_episode_steps", 200)
@@ -52,6 +66,9 @@ class RLAgent(AgentInterface):
52
  # Policy scaling factor (can be adjusted if policy constants are too high)
53
  self.policy_scale = kwargs.get("policy_scale", 1.0)
54
 
 
 
 
55
  # Debug flags
56
  self.debug_observations = True
57
  self.debug_actions = True
@@ -83,12 +100,54 @@ class RLAgent(AgentInterface):
83
  # Process observation to extract the format needed by the expert policy
84
  processed_obs = self._process_observation(obs)
85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  # Debug processed observation (reduced frequency)
87
  print(f"Processed obs: shape={processed_obs.shape}, dtype={processed_obs.dtype}")
88
  print(f"Processed obs sample: {processed_obs[:10]}...") # First 10 values
89
 
90
- # Use the expert policy
91
- action_numpy = self.policy.get_action(processed_obs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
 
93
  # Debug raw policy output (reduced frequency)
94
  print(f"Raw policy action: {action_numpy}, type: {type(action_numpy)}")
@@ -136,58 +195,151 @@ class RLAgent(AgentInterface):
136
  """
137
  Helper method to process observations for the MetaWorld expert policy.
138
 
139
- MetaWorld policies typically expect a specific observation format.
 
 
 
 
140
  """
141
  if isinstance(obs, dict):
142
- # Try different keys that MetaWorld might use
143
- possible_keys = [
144
- "observation",
145
- "obs",
146
- "state_observation",
147
- "achieved_goal",
148
- "state"
149
  ]
150
-
151
  processed_obs = None
152
- for key in possible_keys:
153
  if key in obs:
154
  processed_obs = obs[key]
155
- print(f"Using observation key: {key}")
156
  break
157
-
 
 
 
 
 
 
 
 
158
  if processed_obs is None:
159
- # If none of the expected keys found, concatenate all numeric values
160
- numeric_values = []
 
 
 
161
  for key, value in obs.items():
162
- if isinstance(value, (np.ndarray, list, tuple)):
163
- flat_value = np.array(value).flatten()
164
- numeric_values.append(flat_value)
165
- print(f"Concatenating key {key}: shape={flat_value.shape}")
166
-
167
- if numeric_values:
168
- processed_obs = np.concatenate(numeric_values)
169
  print(f"Concatenated observation shape: {processed_obs.shape}")
170
  else:
171
- # Last resort: use first value
172
- processed_obs = next(iter(obs.values()))
173
- print("No numeric values found, using first observation value")
174
  else:
175
- processed_obs = obs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
 
177
- # Ensure numpy array
178
- if not isinstance(processed_obs, np.ndarray):
179
- try:
180
- processed_obs = np.array(processed_obs, dtype=np.float32)
181
- except Exception as e:
182
- print(f"Failed to convert observation to numpy array: {e}")
183
- # Return default observation size for MetaWorld reach task
184
- processed_obs = np.zeros(39, dtype=np.float32)
185
 
186
- # Ensure proper shape for MetaWorld reach policy
187
- if processed_obs.ndim > 1:
188
- processed_obs = processed_obs.flatten()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
 
190
- return processed_obs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
191
 
192
  def reset(self) -> None:
193
  """
 
45
  if hasattr(self.policy, 'bias'):
46
  print(f"Policy bias: {self.policy.bias}")
47
 
48
+ # Inspect policy methods to understand expected input format
49
+ if hasattr(self.policy, 'get_action'):
50
+ print(f"Policy has get_action method")
51
+ if hasattr(self.policy, '_get_obs'):
52
+ print(f"Policy has _get_obs method")
53
+
54
+ # Try to understand what observation format the policy expects
55
+ try:
56
+ # Some MetaWorld policies might have observation space info
57
+ if hasattr(self.policy, 'observation_space'):
58
+ print(f"Policy observation space: {self.policy.observation_space}")
59
+ except:
60
+ pass
61
+
62
  # Track episode state
63
  self.episode_step = 0
64
  self.max_episode_steps = kwargs.get("max_episode_steps", 200)
 
66
  # Policy scaling factor (can be adjusted if policy constants are too high)
67
  self.policy_scale = kwargs.get("policy_scale", 1.0)
68
 
69
+ # Flag to try different observation processing strategies
70
+ self.try_alternative_obs = True
71
+
72
  # Debug flags
73
  self.debug_observations = True
74
  self.debug_actions = True
 
100
  # Process observation to extract the format needed by the expert policy
101
  processed_obs = self._process_observation(obs)
102
 
103
+ # Optionally normalize observation
104
+ if self.try_alternative_obs:
105
+ processed_obs = self._normalize_observation(processed_obs)
106
+
107
+ # Debug: print all observation keys and their shapes to understand the structure
108
+ if isinstance(obs, dict):
109
+ print("Full observation keys and shapes:")
110
+ for key, value in obs.items():
111
+ if isinstance(value, np.ndarray):
112
+ print(f" {key}: shape={value.shape}, dtype={value.dtype}, range=[{value.min():.3f}, {value.max():.3f}]")
113
+ else:
114
+ print(f" {key}: {type(value)} = {value}")
115
+
116
  # Debug processed observation (reduced frequency)
117
  print(f"Processed obs: shape={processed_obs.shape}, dtype={processed_obs.dtype}")
118
  print(f"Processed obs sample: {processed_obs[:10]}...") # First 10 values
119
 
120
+ # Try different approaches for the MetaWorld policy
121
+ action_numpy = None
122
+
123
+ # Strategy 1: Try with processed observation (39-dim flattened array)
124
+ try:
125
+ action_numpy = self.policy.get_action(processed_obs)
126
+ print(f"✓ Used processed 39-dim observation for policy")
127
+ except Exception as e1:
128
+ print(f"✗ Failed with processed observation: {e1}")
129
+
130
+ # Strategy 2: Try with raw observation if it's a dict
131
+ if action_numpy is None and isinstance(obs, dict):
132
+ try:
133
+ action_numpy = self.policy.get_action(obs)
134
+ print(f"✓ Used raw observation dictionary for policy")
135
+ except Exception as e2:
136
+ print(f"✗ Failed with raw observation dictionary: {e2}")
137
+
138
+ # Strategy 3: Try extracting specific MetaWorld observation components
139
+ try:
140
+ metaworld_obs = self._extract_metaworld_obs(obs)
141
+ if metaworld_obs is not None:
142
+ action_numpy = self.policy.get_action(metaworld_obs)
143
+ print(f"✓ Used extracted MetaWorld observation for policy")
144
+ except Exception as e3:
145
+ print(f"✗ Failed with extracted observation: {e3}")
146
+
147
+ # Final fallback
148
+ if action_numpy is None:
149
+ print("âš  Using zero action as fallback")
150
+ action_numpy = np.zeros(4, dtype=np.float32)
151
 
152
  # Debug raw policy output (reduced frequency)
153
  print(f"Raw policy action: {action_numpy}, type: {type(action_numpy)}")
 
195
  """
196
  Helper method to process observations for the MetaWorld expert policy.
197
 
198
+ MetaWorld reach task policies typically expect observations with:
199
+ - End effector position (3 values)
200
+ - Target position (3 values)
201
+ - Joint positions and velocities (various dimensions)
202
+ - Total around 39 dimensions for Sawyer reach task
203
  """
204
  if isinstance(obs, dict):
205
+ # MetaWorld-specific observation keys for reach task
206
+ metaworld_keys = [
207
+ "observation", # Standard observation
208
+ "obs", # Alternative observation key
209
+ "state", # State observation
210
+ "achieved_goal", # For goal-based tasks
211
+ "desired_goal", # Target position
212
  ]
213
+
214
  processed_obs = None
215
+ for key in metaworld_keys:
216
  if key in obs:
217
  processed_obs = obs[key]
218
+ print(f"Using MetaWorld observation key: {key}")
219
  break
220
+
221
+ # If we found a specific key, ensure it's the right format
222
+ if processed_obs is not None:
223
+ if isinstance(processed_obs, np.ndarray):
224
+ # Ensure it's flattened and has the right dtype
225
+ processed_obs = processed_obs.flatten().astype(np.float32)
226
+ else:
227
+ processed_obs = np.array(processed_obs, dtype=np.float32).flatten()
228
+
229
  if processed_obs is None:
230
+ # Fallback: concatenate relevant observation components
231
+ print("No standard MetaWorld key found, concatenating observation components")
232
+
233
+ # Look for position and velocity information
234
+ components = []
235
  for key, value in obs.items():
236
+ if isinstance(value, np.ndarray) and len(value.flatten()) > 0:
237
+ flat_value = value.flatten().astype(np.float32)
238
+ components.append(flat_value)
239
+ print(f"Adding component {key}: shape={flat_value.shape}")
240
+
241
+ if components:
242
+ processed_obs = np.concatenate(components)
243
  print(f"Concatenated observation shape: {processed_obs.shape}")
244
  else:
245
+ # Last resort: create zeros
246
+ processed_obs = np.zeros(39, dtype=np.float32)
247
+ print("No valid observation components found, using zeros")
248
  else:
249
+ # If obs is already an array, ensure it's properly formatted
250
+ processed_obs = np.array(obs, dtype=np.float32).flatten()
251
+
252
+ # Ensure we have the expected dimension for MetaWorld reach (typically 39)
253
+ if len(processed_obs) != 39:
254
+ print(f"Observation dimension mismatch: got {len(processed_obs)}, expected 39")
255
+ if len(processed_obs) < 39:
256
+ # Pad with zeros
257
+ padding = np.zeros(39 - len(processed_obs), dtype=np.float32)
258
+ processed_obs = np.concatenate([processed_obs, padding])
259
+ print(f"Padded observation to 39 dimensions")
260
+ else:
261
+ # Truncate
262
+ processed_obs = processed_obs[:39]
263
+ print(f"Truncated observation to 39 dimensions")
264
 
265
+ return processed_obs
 
 
 
 
 
 
 
266
 
267
+ def _extract_metaworld_obs(self, obs):
268
+ """
269
+ Extract MetaWorld-specific observation components for the reach task.
270
+
271
+ MetaWorld reach observations typically include:
272
+ - Joint positions (7 values for Sawyer)
273
+ - Joint velocities (7 values)
274
+ - End effector position (3 values)
275
+ - Target position (3 values)
276
+ - Other task-specific info
277
+ """
278
+ if not isinstance(obs, dict):
279
+ return None
280
+
281
+ components = []
282
+
283
+ # Try to find joint positions
284
+ if 'qpos' in obs:
285
+ joint_pos = np.array(obs['qpos'], dtype=np.float32).flatten()
286
+ components.append(joint_pos)
287
+ print(f"Found joint positions: {joint_pos.shape}")
288
+
289
+ # Try to find joint velocities
290
+ if 'qvel' in obs:
291
+ joint_vel = np.array(obs['qvel'], dtype=np.float32).flatten()
292
+ components.append(joint_vel)
293
+ print(f"Found joint velocities: {joint_vel.shape}")
294
+
295
+ # Try to find end effector position
296
+ if 'eef_pos' in obs or 'achieved_goal' in obs:
297
+ eef_key = 'eef_pos' if 'eef_pos' in obs else 'achieved_goal'
298
+ eef_pos = np.array(obs[eef_key], dtype=np.float32).flatten()
299
+ if len(eef_pos) >= 3:
300
+ components.append(eef_pos[:3]) # Take first 3 values (x, y, z)
301
+ print(f"Found end effector position: {eef_pos[:3]}")
302
+
303
+ # Try to find target/goal position
304
+ if 'target_pos' in obs or 'desired_goal' in obs:
305
+ target_key = 'target_pos' if 'target_pos' in obs else 'desired_goal'
306
+ target_pos = np.array(obs[target_key], dtype=np.float32).flatten()
307
+ if len(target_pos) >= 3:
308
+ components.append(target_pos[:3]) # Take first 3 values (x, y, z)
309
+ print(f"Found target position: {target_pos[:3]}")
310
+
311
+ # If we found components, concatenate them
312
+ if components:
313
+ metaworld_obs = np.concatenate(components)
314
+ print(f"Extracted MetaWorld observation: {metaworld_obs.shape} dimensions")
315
+ return metaworld_obs
316
+
317
+ return None
318
+
319
+ def _normalize_observation(self, obs):
320
+ """
321
+ Normalize observation if needed for MetaWorld policy.
322
 
323
+ Some MetaWorld policies expect normalized observations.
324
+ """
325
+ if not isinstance(obs, np.ndarray):
326
+ return obs
327
+
328
+ # Check if observation values are in a reasonable range
329
+ obs_min, obs_max = obs.min(), obs.max()
330
+
331
+ # If values are very large or very small, they might need normalization
332
+ if abs(obs_max) > 10 or abs(obs_min) > 10:
333
+ print(f"Observation values seem large (min={obs_min:.3f}, max={obs_max:.3f}), normalizing...")
334
+ # Normalize to roughly [-1, 1] range
335
+ obs_mean = obs.mean()
336
+ obs_std = obs.std()
337
+ if obs_std > 0:
338
+ normalized_obs = (obs - obs_mean) / obs_std
339
+ print(f"Normalized observation range: [{normalized_obs.min():.3f}, {normalized_obs.max():.3f}]")
340
+ return normalized_obs
341
+
342
+ return obs
343
 
344
  def reset(self) -> None:
345
  """