import logging
from typing import Any, Dict

import gymnasium as gym
import metaworld
import numpy as np
import torch

from agent_interface import AgentInterface
from metaworld.policies.sawyer_reach_v3_policy import SawyerReachV3Policy


class RLAgent(AgentInterface):
    """
    MetaWorld agent implementation using the SawyerReachV3Policy expert policy.

    This agent uses the expert policy from MetaWorld for reach tasks.
    """

    def __init__(
        self,
        observation_space: gym.Space | None = None,
        action_space: gym.Space | None = None,
        seed: int | None = None,
        **kwargs,
    ):
        super().__init__(observation_space, action_space, seed, **kwargs)

        print(f"Initializing MetaWorld agent with seed {self.seed}")

        # Log spaces for debugging
        if observation_space:
            print(f"Observation space: {observation_space}")
        if action_space:
            print(f"Action space: {action_space}")

        self.policy = SawyerReachV3Policy()
        print("Successfully initialized SawyerReachV3Policy")

        # Check if policy has any scaling attributes that might need adjustment
        if hasattr(self.policy, 'action_space'):
            print(f"Policy action space: {self.policy.action_space}")
        if hasattr(self.policy, 'scale'):
            print(f"Policy scale: {self.policy.scale}")
        if hasattr(self.policy, 'bias'):
            print(f"Policy bias: {self.policy.bias}")

        # Inspect policy methods to understand expected input format
        if hasattr(self.policy, 'get_action'):
            print("Policy has get_action method")
        if hasattr(self.policy, '_get_obs'):
            print("Policy has _get_obs method")

        # Try to understand what observation format the policy expects
        try:
            # Some MetaWorld policies might have observation space info
            if hasattr(self.policy, 'observation_space'):
                print(f"Policy observation space: {self.policy.observation_space}")
        except Exception:
            pass

        # Track episode state
        self.episode_step = 0
        self.max_episode_steps = kwargs.get("max_episode_steps", 200)

        # Policy scaling factor (can be adjusted if policy constants are too high)
        self.policy_scale = kwargs.get("policy_scale", 1.0)

        # Flag to try different observation processing strategies
        self.try_alternative_obs = True

        # Debug flags
        self.debug_observations = True
        self.debug_actions = True

        print("MetaWorld agent initialized successfully")

    def act(self, obs: Dict[str, Any], **kwargs) -> torch.Tensor:
        """
        Process the observation and return an action using the MetaWorld expert policy.

        Args:
            obs: Observation from the environment
            kwargs: Additional arguments

        Returns:
            action: Action tensor to take in the environment
        """
        try:
            # Debug observation structure (reduced frequency)
            print(f"Raw observation structure: {type(obs)}")
            if isinstance(obs, dict):
                print(f"Observation keys: {list(obs.keys())}")
                for key, value in obs.items():
                    if isinstance(value, np.ndarray):
                        print(f"  {key}: shape={value.shape}, dtype={value.dtype}")
                    else:
                        print(f"  {key}: {type(value)} = {value}")

            # Process observation to extract the format needed by the expert policy
            processed_obs = self._process_observation(obs)

            # Optionally normalize observation
            if self.try_alternative_obs:
                processed_obs = self._normalize_observation(processed_obs)

            # Debug: print all observation keys and their shapes to understand the structure
            if isinstance(obs, dict):
                print("Full observation keys and shapes:")
                for key, value in obs.items():
                    if isinstance(value, np.ndarray):
                        print(
                            f"  {key}: shape={value.shape}, dtype={value.dtype}, "
                            f"range=[{value.min():.3f}, {value.max():.3f}]"
                        )
                    else:
                        print(f"  {key}: {type(value)} = {value}")

            # Debug processed observation (reduced frequency)
            print(f"Processed obs: shape={processed_obs.shape}, dtype={processed_obs.dtype}")
            print(f"Processed obs sample: {processed_obs[:10]}...")  # First 10 values

            # Try different approaches for the MetaWorld policy
            action_numpy = None

            # Strategy 1: Try with processed observation (39-dim flattened array)
            try:
                action_numpy = self.policy.get_action(processed_obs)
                print("✓ Used processed 39-dim observation for policy")
            except Exception as e1:
                print(f"✗ Failed with processed observation: {e1}")

            # Strategy 2: Try with raw observation if it's a dict
            if action_numpy is None and isinstance(obs, dict):
                try:
                    action_numpy = self.policy.get_action(obs)
                    print("✓ Used raw observation dictionary for policy")
                except Exception as e2:
                    print(f"✗ Failed with raw observation dictionary: {e2}")

            # Strategy 3: Try extracting specific MetaWorld observation components
            # (only if the earlier strategies did not already produce an action)
            if action_numpy is None:
                try:
                    metaworld_obs = self._extract_metaworld_obs(obs)
                    if metaworld_obs is not None:
                        action_numpy = self.policy.get_action(metaworld_obs)
                        print("✓ Used extracted MetaWorld observation for policy")
                except Exception as e3:
                    print(f"✗ Failed with extracted observation: {e3}")

            # Final fallback
            if action_numpy is None:
                print("⚠ Using zero action as fallback")
                action_numpy = np.zeros(4, dtype=np.float32)

            # Debug raw policy output (reduced frequency)
            print(f"Raw policy action: {action_numpy}, type: {type(action_numpy)}")
            print(f"Action shape: {np.array(action_numpy).shape}")

            # Convert to tensor
            if isinstance(action_numpy, (list, tuple)):
                action_tensor = torch.tensor(action_numpy, dtype=torch.float32)
            else:
                action_tensor = torch.from_numpy(np.array(action_numpy)).float()

            # Apply scaling factor if needed (helps with policy constants that may be too high)
            action_tensor = action_tensor * self.policy_scale

            # Clip actions to [-1, 1] range to handle policy constants that may be too high
            action_tensor = torch.clamp(action_tensor, -1.0, 1.0)

            # Ensure correct action dimensionality
            if self.action_space and hasattr(self.action_space, 'shape'):
                expected_shape = self.action_space.shape[0]
                if action_tensor.shape[0] != expected_shape:
                    print(f"Action shape mismatch: got {action_tensor.shape[0]}, expected {expected_shape}")
                    # Pad or truncate as needed
                    if action_tensor.shape[0] < expected_shape:
                        padding = torch.zeros(expected_shape - action_tensor.shape[0])
                        action_tensor = torch.cat([action_tensor, padding])
                    else:
                        action_tensor = action_tensor[:expected_shape]
            # Debug final action (reduced frequency)
            print(f"Final action tensor: {action_tensor}")

            self.episode_step += 1
            return action_tensor

        except Exception as e:
            print(f"Error in act method: {e}")
            # Return zeros as a fallback
            if isinstance(self.action_space, gym.spaces.Box):
                return torch.zeros(self.action_space.shape[0], dtype=torch.float32)
            else:
                return torch.zeros(4, dtype=torch.float32)

    def _process_observation(self, obs):
        """
        Helper method to process observations for the MetaWorld expert policy.

        MetaWorld reach task policies typically expect observations with:
        - End effector position (3 values)
        - Target position (3 values)
        - Joint positions and velocities (various dimensions)
        - Total around 39 dimensions for Sawyer reach task
        """
        if isinstance(obs, dict):
            # MetaWorld-specific observation keys for reach task
            metaworld_keys = [
                "observation",    # Standard observation
                "obs",            # Alternative observation key
                "state",          # State observation
                "achieved_goal",  # For goal-based tasks
                "desired_goal",   # Target position
            ]

            processed_obs = None
            for key in metaworld_keys:
                if key in obs:
                    processed_obs = obs[key]
                    print(f"Using MetaWorld observation key: {key}")
                    break

            # If we found a specific key, ensure it's the right format
            if processed_obs is not None:
                if isinstance(processed_obs, np.ndarray):
                    # Ensure it's flattened and has the right dtype
                    processed_obs = processed_obs.flatten().astype(np.float32)
                else:
                    processed_obs = np.array(processed_obs, dtype=np.float32).flatten()

            if processed_obs is None:
                # Fallback: concatenate relevant observation components
                print("No standard MetaWorld key found, concatenating observation components")
                # Look for position and velocity information
                components = []
                for key, value in obs.items():
                    if isinstance(value, np.ndarray) and len(value.flatten()) > 0:
                        flat_value = value.flatten().astype(np.float32)
                        components.append(flat_value)
                        print(f"Adding component {key}: shape={flat_value.shape}")

                if components:
                    processed_obs = np.concatenate(components)
                    print(f"Concatenated observation shape: {processed_obs.shape}")
                else:
                    # Last resort: create zeros
                    processed_obs = np.zeros(39, dtype=np.float32)
                    print("No valid observation components found, using zeros")
        else:
            # If obs is already an array, ensure it's properly formatted
            processed_obs = np.array(obs, dtype=np.float32).flatten()

        # Ensure we have the expected dimension for MetaWorld reach (typically 39)
        if len(processed_obs) != 39:
            print(f"Observation dimension mismatch: got {len(processed_obs)}, expected 39")
            if len(processed_obs) < 39:
                # Pad with zeros
                padding = np.zeros(39 - len(processed_obs), dtype=np.float32)
                processed_obs = np.concatenate([processed_obs, padding])
                print("Padded observation to 39 dimensions")
            else:
                # Truncate
                processed_obs = processed_obs[:39]
                print("Truncated observation to 39 dimensions")

        return processed_obs

    def _extract_metaworld_obs(self, obs):
        """
        Extract MetaWorld-specific observation components for the reach task.

        MetaWorld reach observations typically include:
        - Joint positions (7 values for Sawyer)
        - Joint velocities (7 values)
        - End effector position (3 values)
        - Target position (3 values)
        - Other task-specific info
        """
        if not isinstance(obs, dict):
            return None

        components = []

        # Try to find joint positions
        if 'qpos' in obs:
            joint_pos = np.array(obs['qpos'], dtype=np.float32).flatten()
            components.append(joint_pos)
            print(f"Found joint positions: {joint_pos.shape}")

        # Try to find joint velocities
        if 'qvel' in obs:
            joint_vel = np.array(obs['qvel'], dtype=np.float32).flatten()
            components.append(joint_vel)
            print(f"Found joint velocities: {joint_vel.shape}")

        # Try to find end effector position
        if 'eef_pos' in obs or 'achieved_goal' in obs:
            eef_key = 'eef_pos' if 'eef_pos' in obs else 'achieved_goal'
            eef_pos = np.array(obs[eef_key], dtype=np.float32).flatten()
            if len(eef_pos) >= 3:
                components.append(eef_pos[:3])  # Take first 3 values (x, y, z)
                print(f"Found end effector position: {eef_pos[:3]}")

        # Try to find target/goal position
        if 'target_pos' in obs or 'desired_goal' in obs:
            target_key = 'target_pos' if 'target_pos' in obs else 'desired_goal'
            target_pos = np.array(obs[target_key], dtype=np.float32).flatten()
            if len(target_pos) >= 3:
                components.append(target_pos[:3])  # Take first 3 values (x, y, z)
                print(f"Found target position: {target_pos[:3]}")

        # If we found components, concatenate them
        if components:
            metaworld_obs = np.concatenate(components)
            print(f"Extracted MetaWorld observation: {metaworld_obs.shape} dimensions")
            return metaworld_obs

        return None

    def _normalize_observation(self, obs):
        """
        Normalize observation if needed for MetaWorld policy.
        Some MetaWorld policies expect normalized observations.
        """
        if not isinstance(obs, np.ndarray):
            return obs

        # Check if observation values are in a reasonable range
        obs_min, obs_max = obs.min(), obs.max()

        # If values are very large or very small, they might need normalization
        if abs(obs_max) > 10 or abs(obs_min) > 10:
            print(f"Observation values seem large (min={obs_min:.3f}, max={obs_max:.3f}), normalizing...")
            # Normalize to roughly [-1, 1] range
            obs_mean = obs.mean()
            obs_std = obs.std()
            if obs_std > 0:
                normalized_obs = (obs - obs_mean) / obs_std
                print(f"Normalized observation range: [{normalized_obs.min():.3f}, {normalized_obs.max():.3f}]")
                return normalized_obs

        return obs

    def reset(self) -> None:
        """
        Reset agent state between episodes.
        """
        print(f"Resetting agent after {self.episode_step} steps")
        self.episode_step = 0

        # Reset debug flags if needed
        self.debug_observations = True
        self.debug_actions = True

    def _build_model(self):
        """
        Build a neural network model for the agent.

        This is a placeholder for where you would define your neural network
        architecture using PyTorch, TensorFlow, or another framework.
        """
        pass
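

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (added for illustration; not part of the original
# agent code). It shows how the agent might be exercised without a live
# MetaWorld environment: the 39-dim observation and 4-dim action shapes are
# assumed Sawyer reach defaults, the dummy zero observation is purely
# illustrative, and the call assumes AgentInterface accepts these keyword
# arguments. Adjust or remove for your own setup.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Hypothetical spaces mirroring the Sawyer reach task (assumption).
    obs_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(39,), dtype=np.float32)
    act_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(4,), dtype=np.float32)

    agent = RLAgent(observation_space=obs_space, action_space=act_space, seed=0)

    # Dummy dict observation; _process_observation picks up the "observation" key.
    dummy_obs = {"observation": np.zeros(39, dtype=np.float32)}
    action = agent.act(dummy_obs)
    print(f"Smoke-test action: {action} (shape={tuple(action.shape)})")

    agent.reset()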