| """ |
| Implementation of the AgentInterface for MetaWorld tasks. |
| |
| This agent uses the SawyerReachV3Policy from MetaWorld as an expert policy. |
| """ |
|
|
| import logging |
| from typing import Any, Dict |
|
|
| import gymnasium as gym |
| import metaworld |
| import numpy as np |
| import torch |
| from agent_interface import AgentInterface |
| from metaworld.policies.sawyer_reach_v3_policy import SawyerReachV3Policy |
|
|
|
|
| class RLAgent(AgentInterface): |
| """ |
| MetaWorld agent implementation using the SawyerReachV3Policy expert policy. |
| |
| This agent uses the expert policy from MetaWorld for reach tasks. |
| """ |
|
|
| def __init__( |
| self, |
| observation_space: gym.Space | None = None, |
| action_space: gym.Space | None = None, |
| seed: int | None = None, |
| **kwargs, |
| ): |
| super().__init__(observation_space, action_space, seed, **kwargs) |
|
|
| self.logger = logging.getLogger(__name__) |
| self.logger.info(f"Initializing MetaWorld agent with seed {self.seed}") |
|
|
| self.policy = SawyerReachV3Policy() |
| self.logger.info("Successfully initialized SawyerReachV3Policy") |
|
|
| |
| self.episode_step = 0 |
| self.max_episode_steps = kwargs.get("max_episode_steps", 200) |
|
|
| self.logger.info("MetaWorld agent initialized successfully") |
|
|
| def act(self, obs: Dict[str, Any], **kwargs) -> torch.Tensor: |
| """ |
| Process the observation and return an action using the MetaWorld expert policy. |
| |
| Args: |
| obs: Observation from the environment |
| kwargs: Additional arguments |
| |
| Returns: |
| action: Action tensor to take in the environment |
| """ |
| try: |
| |
| processed_obs = self._process_observation(obs) |
|
|
| |
| |
| action_numpy = self.policy.get_action(processed_obs) |
| action_tensor = torch.from_numpy(np.array(action_numpy)).float() |
|
|
| |
| if self.episode_step % 50 == 0: |
| self.logger.debug(f"Using expert policy action: {action_numpy}") |
|
|
| |
| self.episode_step += 1 |
|
|
| |
| if self.episode_step % 50 == 0: |
| self.logger.debug( |
| f"Step {self.episode_step}: Action shape {action_tensor.shape}" |
| ) |
|
|
| return action_tensor |
|
|
| except Exception as e: |
| self.logger.error(f"Error in act method: {e}", exc_info=True) |
| |
| if isinstance(self.action_space, gym.spaces.Box): |
| return torch.zeros(self.action_space.shape[0], dtype=torch.float32) |
| else: |
| return torch.zeros(4, dtype=torch.float32) |
|
|
| def _process_observation(self, obs): |
| """ |
| Helper method to process observations for the MetaWorld expert policy. |
| |
| MetaWorld policies typically expect a specific observation format. |
| """ |
| if isinstance(obs, dict): |
| |
| if "observation" in obs: |
| |
| processed_obs = obs["observation"] |
| elif "obs" in obs: |
| processed_obs = obs["obs"] |
| elif "state_observation" in obs: |
| |
| processed_obs = obs["state_observation"] |
| elif "goal_achieved" in obs: |
| |
| |
| achievement = obs.get("goal_achieved", False) |
| base_obs = next(iter(obs.values())) |
| self.logger.debug(f"Goal achieved: {achievement}") |
| processed_obs = base_obs |
| else: |
| |
| processed_obs = next(iter(obs.values())) |
| self.logger.debug(f"Using observation key: {next(iter(obs.keys()))}") |
| else: |
| |
| processed_obs = obs |
|
|
| |
| if not isinstance(processed_obs, np.ndarray): |
| try: |
| processed_obs = np.array(processed_obs, dtype=np.float32) |
| except Exception as e: |
| self.logger.error(f"Failed to convert observation to numpy array: {e}") |
| |
| if ( |
| self.observation_space |
| and hasattr(self.observation_space, "shape") |
| and self.observation_space.shape is not None |
| ): |
| processed_obs = np.zeros( |
| self.observation_space.shape, dtype=np.float32 |
| ) |
| else: |
| |
| processed_obs = np.zeros(39, dtype=np.float32) |
|
|
| return processed_obs |
|
|
| def reset(self) -> None: |
| """ |
| Reset agent state between episodes. |
| """ |
| self.logger.debug("Resetting agent") |
| self.episode_step = 0 |
| |
|
|
| def _build_model(self): |
| """ |
| Build a neural network model for the agent. |
| |
| This is a placeholder for where you would define your neural network |
| architecture using PyTorch, TensorFlow, or another framework. |
| """ |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| pass |
|
|