feat: enhance observation processing and debugging in RLAgent for MetaWorld policies
Browse files
agent.py
CHANGED
@@ -45,6 +45,20 @@ class RLAgent(AgentInterface):
|
|
45 |
if hasattr(self.policy, 'bias'):
|
46 |
print(f"Policy bias: {self.policy.bias}")
|
47 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
# Track episode state
|
49 |
self.episode_step = 0
|
50 |
self.max_episode_steps = kwargs.get("max_episode_steps", 200)
|
@@ -52,6 +66,9 @@ class RLAgent(AgentInterface):
|
|
52 |
# Policy scaling factor (can be adjusted if policy constants are too high)
|
53 |
self.policy_scale = kwargs.get("policy_scale", 1.0)
|
54 |
|
|
|
|
|
|
|
55 |
# Debug flags
|
56 |
self.debug_observations = True
|
57 |
self.debug_actions = True
|
@@ -83,12 +100,54 @@ class RLAgent(AgentInterface):
|
|
83 |
# Process observation to extract the format needed by the expert policy
|
84 |
processed_obs = self._process_observation(obs)
|
85 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
86 |
# Debug processed observation (reduced frequency)
|
87 |
print(f"Processed obs: shape={processed_obs.shape}, dtype={processed_obs.dtype}")
|
88 |
print(f"Processed obs sample: {processed_obs[:10]}...") # First 10 values
|
89 |
|
90 |
-
#
|
91 |
-
action_numpy =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
92 |
|
93 |
# Debug raw policy output (reduced frequency)
|
94 |
print(f"Raw policy action: {action_numpy}, type: {type(action_numpy)}")
|
@@ -136,58 +195,151 @@ class RLAgent(AgentInterface):
|
|
136 |
"""
|
137 |
Helper method to process observations for the MetaWorld expert policy.
|
138 |
|
139 |
-
MetaWorld policies typically expect
|
|
|
|
|
|
|
|
|
140 |
"""
|
141 |
if isinstance(obs, dict):
|
142 |
-
#
|
143 |
-
|
144 |
-
"observation",
|
145 |
-
"obs",
|
146 |
-
"
|
147 |
-
"achieved_goal",
|
148 |
-
"
|
149 |
]
|
150 |
-
|
151 |
processed_obs = None
|
152 |
-
for key in
|
153 |
if key in obs:
|
154 |
processed_obs = obs[key]
|
155 |
-
print(f"Using observation key: {key}")
|
156 |
break
|
157 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
158 |
if processed_obs is None:
|
159 |
-
#
|
160 |
-
|
|
|
|
|
|
|
161 |
for key, value in obs.items():
|
162 |
-
if isinstance(value,
|
163 |
-
flat_value =
|
164 |
-
|
165 |
-
print(f"
|
166 |
-
|
167 |
-
if
|
168 |
-
processed_obs = np.concatenate(
|
169 |
print(f"Concatenated observation shape: {processed_obs.shape}")
|
170 |
else:
|
171 |
-
# Last resort:
|
172 |
-
processed_obs =
|
173 |
-
print("No
|
174 |
else:
|
175 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
176 |
|
177 |
-
|
178 |
-
if not isinstance(processed_obs, np.ndarray):
|
179 |
-
try:
|
180 |
-
processed_obs = np.array(processed_obs, dtype=np.float32)
|
181 |
-
except Exception as e:
|
182 |
-
print(f"Failed to convert observation to numpy array: {e}")
|
183 |
-
# Return default observation size for MetaWorld reach task
|
184 |
-
processed_obs = np.zeros(39, dtype=np.float32)
|
185 |
|
186 |
-
|
187 |
-
|
188 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
189 |
|
190 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
191 |
|
192 |
def reset(self) -> None:
|
193 |
"""
|
|
|
45 |
if hasattr(self.policy, 'bias'):
|
46 |
print(f"Policy bias: {self.policy.bias}")
|
47 |
|
48 |
+
# Inspect policy methods to understand expected input format
|
49 |
+
if hasattr(self.policy, 'get_action'):
|
50 |
+
print(f"Policy has get_action method")
|
51 |
+
if hasattr(self.policy, '_get_obs'):
|
52 |
+
print(f"Policy has _get_obs method")
|
53 |
+
|
54 |
+
# Try to understand what observation format the policy expects
|
55 |
+
try:
|
56 |
+
# Some MetaWorld policies might have observation space info
|
57 |
+
if hasattr(self.policy, 'observation_space'):
|
58 |
+
print(f"Policy observation space: {self.policy.observation_space}")
|
59 |
+
except:
|
60 |
+
pass
|
61 |
+
|
62 |
# Track episode state
|
63 |
self.episode_step = 0
|
64 |
self.max_episode_steps = kwargs.get("max_episode_steps", 200)
|
|
|
66 |
# Policy scaling factor (can be adjusted if policy constants are too high)
|
67 |
self.policy_scale = kwargs.get("policy_scale", 1.0)
|
68 |
|
69 |
+
# Flag to try different observation processing strategies
|
70 |
+
self.try_alternative_obs = True
|
71 |
+
|
72 |
# Debug flags
|
73 |
self.debug_observations = True
|
74 |
self.debug_actions = True
|
|
|
100 |
# Process observation to extract the format needed by the expert policy
|
101 |
processed_obs = self._process_observation(obs)
|
102 |
|
103 |
+
# Optionally normalize observation
|
104 |
+
if self.try_alternative_obs:
|
105 |
+
processed_obs = self._normalize_observation(processed_obs)
|
106 |
+
|
107 |
+
# Debug: print all observation keys and their shapes to understand the structure
|
108 |
+
if isinstance(obs, dict):
|
109 |
+
print("Full observation keys and shapes:")
|
110 |
+
for key, value in obs.items():
|
111 |
+
if isinstance(value, np.ndarray):
|
112 |
+
print(f" {key}: shape={value.shape}, dtype={value.dtype}, range=[{value.min():.3f}, {value.max():.3f}]")
|
113 |
+
else:
|
114 |
+
print(f" {key}: {type(value)} = {value}")
|
115 |
+
|
116 |
# Debug processed observation (reduced frequency)
|
117 |
print(f"Processed obs: shape={processed_obs.shape}, dtype={processed_obs.dtype}")
|
118 |
print(f"Processed obs sample: {processed_obs[:10]}...") # First 10 values
|
119 |
|
120 |
+
# Try different approaches for the MetaWorld policy
|
121 |
+
action_numpy = None
|
122 |
+
|
123 |
+
# Strategy 1: Try with processed observation (39-dim flattened array)
|
124 |
+
try:
|
125 |
+
action_numpy = self.policy.get_action(processed_obs)
|
126 |
+
print(f"✓ Used processed 39-dim observation for policy")
|
127 |
+
except Exception as e1:
|
128 |
+
print(f"✗ Failed with processed observation: {e1}")
|
129 |
+
|
130 |
+
# Strategy 2: Try with raw observation if it's a dict
|
131 |
+
if action_numpy is None and isinstance(obs, dict):
|
132 |
+
try:
|
133 |
+
action_numpy = self.policy.get_action(obs)
|
134 |
+
print(f"✓ Used raw observation dictionary for policy")
|
135 |
+
except Exception as e2:
|
136 |
+
print(f"✗ Failed with raw observation dictionary: {e2}")
|
137 |
+
|
138 |
+
# Strategy 3: Try extracting specific MetaWorld observation components
|
139 |
+
try:
|
140 |
+
metaworld_obs = self._extract_metaworld_obs(obs)
|
141 |
+
if metaworld_obs is not None:
|
142 |
+
action_numpy = self.policy.get_action(metaworld_obs)
|
143 |
+
print(f"✓ Used extracted MetaWorld observation for policy")
|
144 |
+
except Exception as e3:
|
145 |
+
print(f"✗ Failed with extracted observation: {e3}")
|
146 |
+
|
147 |
+
# Final fallback
|
148 |
+
if action_numpy is None:
|
149 |
+
print("âš Using zero action as fallback")
|
150 |
+
action_numpy = np.zeros(4, dtype=np.float32)
|
151 |
|
152 |
# Debug raw policy output (reduced frequency)
|
153 |
print(f"Raw policy action: {action_numpy}, type: {type(action_numpy)}")
|
|
|
195 |
"""
|
196 |
Helper method to process observations for the MetaWorld expert policy.
|
197 |
|
198 |
+
MetaWorld reach task policies typically expect observations with:
|
199 |
+
- End effector position (3 values)
|
200 |
+
- Target position (3 values)
|
201 |
+
- Joint positions and velocities (various dimensions)
|
202 |
+
- Total around 39 dimensions for Sawyer reach task
|
203 |
"""
|
204 |
if isinstance(obs, dict):
|
205 |
+
# MetaWorld-specific observation keys for reach task
|
206 |
+
metaworld_keys = [
|
207 |
+
"observation", # Standard observation
|
208 |
+
"obs", # Alternative observation key
|
209 |
+
"state", # State observation
|
210 |
+
"achieved_goal", # For goal-based tasks
|
211 |
+
"desired_goal", # Target position
|
212 |
]
|
213 |
+
|
214 |
processed_obs = None
|
215 |
+
for key in metaworld_keys:
|
216 |
if key in obs:
|
217 |
processed_obs = obs[key]
|
218 |
+
print(f"Using MetaWorld observation key: {key}")
|
219 |
break
|
220 |
+
|
221 |
+
# If we found a specific key, ensure it's the right format
|
222 |
+
if processed_obs is not None:
|
223 |
+
if isinstance(processed_obs, np.ndarray):
|
224 |
+
# Ensure it's flattened and has the right dtype
|
225 |
+
processed_obs = processed_obs.flatten().astype(np.float32)
|
226 |
+
else:
|
227 |
+
processed_obs = np.array(processed_obs, dtype=np.float32).flatten()
|
228 |
+
|
229 |
if processed_obs is None:
|
230 |
+
# Fallback: concatenate relevant observation components
|
231 |
+
print("No standard MetaWorld key found, concatenating observation components")
|
232 |
+
|
233 |
+
# Look for position and velocity information
|
234 |
+
components = []
|
235 |
for key, value in obs.items():
|
236 |
+
if isinstance(value, np.ndarray) and len(value.flatten()) > 0:
|
237 |
+
flat_value = value.flatten().astype(np.float32)
|
238 |
+
components.append(flat_value)
|
239 |
+
print(f"Adding component {key}: shape={flat_value.shape}")
|
240 |
+
|
241 |
+
if components:
|
242 |
+
processed_obs = np.concatenate(components)
|
243 |
print(f"Concatenated observation shape: {processed_obs.shape}")
|
244 |
else:
|
245 |
+
# Last resort: create zeros
|
246 |
+
processed_obs = np.zeros(39, dtype=np.float32)
|
247 |
+
print("No valid observation components found, using zeros")
|
248 |
else:
|
249 |
+
# If obs is already an array, ensure it's properly formatted
|
250 |
+
processed_obs = np.array(obs, dtype=np.float32).flatten()
|
251 |
+
|
252 |
+
# Ensure we have the expected dimension for MetaWorld reach (typically 39)
|
253 |
+
if len(processed_obs) != 39:
|
254 |
+
print(f"Observation dimension mismatch: got {len(processed_obs)}, expected 39")
|
255 |
+
if len(processed_obs) < 39:
|
256 |
+
# Pad with zeros
|
257 |
+
padding = np.zeros(39 - len(processed_obs), dtype=np.float32)
|
258 |
+
processed_obs = np.concatenate([processed_obs, padding])
|
259 |
+
print(f"Padded observation to 39 dimensions")
|
260 |
+
else:
|
261 |
+
# Truncate
|
262 |
+
processed_obs = processed_obs[:39]
|
263 |
+
print(f"Truncated observation to 39 dimensions")
|
264 |
|
265 |
+
return processed_obs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
266 |
|
267 |
+
def _extract_metaworld_obs(self, obs):
|
268 |
+
"""
|
269 |
+
Extract MetaWorld-specific observation components for the reach task.
|
270 |
+
|
271 |
+
MetaWorld reach observations typically include:
|
272 |
+
- Joint positions (7 values for Sawyer)
|
273 |
+
- Joint velocities (7 values)
|
274 |
+
- End effector position (3 values)
|
275 |
+
- Target position (3 values)
|
276 |
+
- Other task-specific info
|
277 |
+
"""
|
278 |
+
if not isinstance(obs, dict):
|
279 |
+
return None
|
280 |
+
|
281 |
+
components = []
|
282 |
+
|
283 |
+
# Try to find joint positions
|
284 |
+
if 'qpos' in obs:
|
285 |
+
joint_pos = np.array(obs['qpos'], dtype=np.float32).flatten()
|
286 |
+
components.append(joint_pos)
|
287 |
+
print(f"Found joint positions: {joint_pos.shape}")
|
288 |
+
|
289 |
+
# Try to find joint velocities
|
290 |
+
if 'qvel' in obs:
|
291 |
+
joint_vel = np.array(obs['qvel'], dtype=np.float32).flatten()
|
292 |
+
components.append(joint_vel)
|
293 |
+
print(f"Found joint velocities: {joint_vel.shape}")
|
294 |
+
|
295 |
+
# Try to find end effector position
|
296 |
+
if 'eef_pos' in obs or 'achieved_goal' in obs:
|
297 |
+
eef_key = 'eef_pos' if 'eef_pos' in obs else 'achieved_goal'
|
298 |
+
eef_pos = np.array(obs[eef_key], dtype=np.float32).flatten()
|
299 |
+
if len(eef_pos) >= 3:
|
300 |
+
components.append(eef_pos[:3]) # Take first 3 values (x, y, z)
|
301 |
+
print(f"Found end effector position: {eef_pos[:3]}")
|
302 |
+
|
303 |
+
# Try to find target/goal position
|
304 |
+
if 'target_pos' in obs or 'desired_goal' in obs:
|
305 |
+
target_key = 'target_pos' if 'target_pos' in obs else 'desired_goal'
|
306 |
+
target_pos = np.array(obs[target_key], dtype=np.float32).flatten()
|
307 |
+
if len(target_pos) >= 3:
|
308 |
+
components.append(target_pos[:3]) # Take first 3 values (x, y, z)
|
309 |
+
print(f"Found target position: {target_pos[:3]}")
|
310 |
+
|
311 |
+
# If we found components, concatenate them
|
312 |
+
if components:
|
313 |
+
metaworld_obs = np.concatenate(components)
|
314 |
+
print(f"Extracted MetaWorld observation: {metaworld_obs.shape} dimensions")
|
315 |
+
return metaworld_obs
|
316 |
+
|
317 |
+
return None
|
318 |
+
|
319 |
+
def _normalize_observation(self, obs):
|
320 |
+
"""
|
321 |
+
Normalize observation if needed for MetaWorld policy.
|
322 |
|
323 |
+
Some MetaWorld policies expect normalized observations.
|
324 |
+
"""
|
325 |
+
if not isinstance(obs, np.ndarray):
|
326 |
+
return obs
|
327 |
+
|
328 |
+
# Check if observation values are in a reasonable range
|
329 |
+
obs_min, obs_max = obs.min(), obs.max()
|
330 |
+
|
331 |
+
# If values are very large or very small, they might need normalization
|
332 |
+
if abs(obs_max) > 10 or abs(obs_min) > 10:
|
333 |
+
print(f"Observation values seem large (min={obs_min:.3f}, max={obs_max:.3f}), normalizing...")
|
334 |
+
# Normalize to roughly [-1, 1] range
|
335 |
+
obs_mean = obs.mean()
|
336 |
+
obs_std = obs.std()
|
337 |
+
if obs_std > 0:
|
338 |
+
normalized_obs = (obs - obs_mean) / obs_std
|
339 |
+
print(f"Normalized observation range: [{normalized_obs.min():.3f}, {normalized_obs.max():.3f}]")
|
340 |
+
return normalized_obs
|
341 |
+
|
342 |
+
return obs
|
343 |
|
344 |
def reset(self) -> None:
|
345 |
"""
|