import logging
from typing import Any, Dict

import gymnasium as gym
import metaworld
import numpy as np
import torch

from agent_interface import AgentInterface
from metaworld.policies.sawyer_reach_v3_policy import SawyerReachV3Policy


class RLAgent(AgentInterface):
    """
    MetaWorld agent implementation using the SawyerReachV3Policy expert policy.

    This agent uses the expert policy from MetaWorld for reach tasks.
    """

    def __init__(
        self,
        observation_space: gym.Space | None = None,
        action_space: gym.Space | None = None,
        seed: int | None = None,
        **kwargs,
    ):
        super().__init__(observation_space, action_space, seed, **kwargs)

        print(f"Initializing MetaWorld agent with seed {self.seed}")

        # Log the environment spaces if they were provided.
        if observation_space:
            print(f"Observation space: {observation_space}")
        if action_space:
            print(f"Action space: {action_space}")

        # Scripted expert policy for the Sawyer reach task.
        self.policy = SawyerReachV3Policy()
        print("Successfully initialized SawyerReachV3Policy")

        # Inspect optional policy attributes for debugging.
        if hasattr(self.policy, 'action_space'):
            print(f"Policy action space: {self.policy.action_space}")
        if hasattr(self.policy, 'scale'):
            print(f"Policy scale: {self.policy.scale}")
        if hasattr(self.policy, 'bias'):
            print(f"Policy bias: {self.policy.bias}")

        if hasattr(self.policy, 'get_action'):
            print("Policy has get_action method")
        if hasattr(self.policy, '_get_obs'):
            print("Policy has _get_obs method")

        try:
            if hasattr(self.policy, 'observation_space'):
                print(f"Policy observation space: {self.policy.observation_space}")
        except Exception:
            pass

        # Episode bookkeeping.
        self.episode_step = 0
        self.max_episode_steps = kwargs.get("max_episode_steps", 200)

        # Optional scaling applied to the policy's actions.
        self.policy_scale = kwargs.get("policy_scale", 1.0)

        # Whether to run the normalization heuristic on processed observations.
        self.try_alternative_obs = True

        # Debug flags controlling the verbose logging in act().
        self.debug_observations = True
        self.debug_actions = True

        print("MetaWorld agent initialized successfully")

    def act(self, obs: Dict[str, Any], **kwargs) -> torch.Tensor:
        """
        Process the observation and return an action using the MetaWorld expert policy.

        Args:
            obs: Observation from the environment
            kwargs: Additional arguments

        Returns:
            action: Action tensor to take in the environment
        """
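        # Observation-format fallbacks, tried in order:
        #   1. the flattened 39-dim vector built by _process_observation,
        #   2. the raw observation dict passed straight to the policy,
        #   3. a vector re-assembled by _extract_metaworld_obs,
        #   4. a zero action if every attempt fails.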
        try:
            # Debug: dump the raw observation structure.
            if self.debug_observations:
                print(f"Raw observation structure: {type(obs)}")
                if isinstance(obs, dict):
                    print("Observation keys and shapes:")
                    for key, value in obs.items():
                        if isinstance(value, np.ndarray):
                            print(f"  {key}: shape={value.shape}, dtype={value.dtype}, "
                                  f"range=[{value.min():.3f}, {value.max():.3f}]")
                        else:
                            print(f"  {key}: {type(value)} = {value}")

            # Flatten the observation into the 39-dim vector the policy expects.
            processed_obs = self._process_observation(obs)

            # Optionally normalize observations with suspiciously large values.
            if self.try_alternative_obs:
                processed_obs = self._normalize_observation(processed_obs)

            if self.debug_observations:
                print(f"Processed obs: shape={processed_obs.shape}, dtype={processed_obs.dtype}")
                print(f"Processed obs sample: {processed_obs[:10]}...")

            action_numpy = None

            # Attempt 1: the processed 39-dim observation vector.
            try:
                action_numpy = self.policy.get_action(processed_obs)
                print("✓ Used processed 39-dim observation for policy")
            except Exception as e1:
                print(f"✗ Failed with processed observation: {e1}")

            # Attempt 2: the raw observation dictionary.
            if action_numpy is None and isinstance(obs, dict):
                try:
                    action_numpy = self.policy.get_action(obs)
                    print("✓ Used raw observation dictionary for policy")
                except Exception as e2:
                    print(f"✗ Failed with raw observation dictionary: {e2}")

            # Attempt 3: an observation re-assembled from known MetaWorld keys.
            if action_numpy is None:
                try:
                    metaworld_obs = self._extract_metaworld_obs(obs)
                    if metaworld_obs is not None:
                        action_numpy = self.policy.get_action(metaworld_obs)
                        print("✓ Used extracted MetaWorld observation for policy")
                except Exception as e3:
                    print(f"✗ Failed with extracted observation: {e3}")

            # Final fallback: a zero action.
            if action_numpy is None:
                print("✗ Using zero action as fallback")
                action_numpy = np.zeros(4, dtype=np.float32)

print(f"Raw policy action: {action_numpy}, type: {type(action_numpy)}") |
|
print(f"Action shape: {np.array(action_numpy).shape}") |
|
|
|
|
|
if isinstance(action_numpy, (list, tuple)): |
|
action_tensor = torch.tensor(action_numpy, dtype=torch.float32) |
|
else: |
|
action_tensor = torch.from_numpy(np.array(action_numpy)).float() |
|
|
|
|
|
action_tensor = action_tensor * self.policy_scale |
|
|
|
|
|
action_tensor = torch.clamp(action_tensor, -1.0, 1.0) |
|
|
|
|
|
            # Match the action dimensionality expected by the environment.
            if self.action_space is not None and hasattr(self.action_space, 'shape'):
                expected_shape = self.action_space.shape[0]
                if action_tensor.shape[0] != expected_shape:
                    print(f"Action shape mismatch: got {action_tensor.shape[0]}, expected {expected_shape}")
                    if action_tensor.shape[0] < expected_shape:
                        padding = torch.zeros(expected_shape - action_tensor.shape[0])
                        action_tensor = torch.cat([action_tensor, padding])
                    else:
                        action_tensor = action_tensor[:expected_shape]

print(f"Final action tensor: {action_tensor}") |
|
|
|
self.episode_step += 1 |
|
return action_tensor |
|
|
|
except Exception as e: |
|
print(f"Error in act method: {e}") |
|
|
|
if isinstance(self.action_space, gym.spaces.Box): |
|
return torch.zeros(self.action_space.shape[0], dtype=torch.float32) |
|
else: |
|
return torch.zeros(4, dtype=torch.float32) |
|
|
|

    def _process_observation(self, obs):
        """
        Helper method to flatten observations for the MetaWorld expert policy.

        MetaWorld Sawyer reach policies expect a flat observation containing:
        - End effector (gripper) position (3 values)
        - Target/goal position (3 values)
        - Additional task state (gripper openness, object poses, previous frame)
        - A total of 39 dimensions for the Sawyer reach task
        """
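        # Note (assumption based on MetaWorld's scripted policies): the reach
        # expert mainly reads the end-effector position (first three entries)
        # and the goal position (last three entries) of this 39-dim vector,
        # so those slots matter most when padding or concatenating below.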
        if isinstance(obs, dict):
            # Common keys under which wrappers expose the flat observation vector.
            metaworld_keys = [
                "observation",
                "obs",
                "state",
                "achieved_goal",
                "desired_goal",
            ]

            processed_obs = None
            for key in metaworld_keys:
                if key in obs:
                    processed_obs = obs[key]
                    print(f"Using MetaWorld observation key: {key}")
                    break

            if processed_obs is not None:
                if isinstance(processed_obs, np.ndarray):
                    processed_obs = processed_obs.flatten().astype(np.float32)
                else:
                    processed_obs = np.array(processed_obs, dtype=np.float32).flatten()

            if processed_obs is None:
                # No standard key found: concatenate every array-valued entry.
                print("No standard MetaWorld key found, concatenating observation components")

                components = []
                for key, value in obs.items():
                    if isinstance(value, np.ndarray) and value.size > 0:
                        flat_value = value.flatten().astype(np.float32)
                        components.append(flat_value)
                        print(f"Adding component {key}: shape={flat_value.shape}")

                if components:
                    processed_obs = np.concatenate(components)
                    print(f"Concatenated observation shape: {processed_obs.shape}")
                else:
                    # Nothing usable at all: fall back to a zero vector.
                    processed_obs = np.zeros(39, dtype=np.float32)
                    print("No valid observation components found, using zeros")
        else:
            # Non-dict observations are flattened directly.
            processed_obs = np.array(obs, dtype=np.float32).flatten()

        # Pad or truncate to the 39 dimensions the policy expects.
        if len(processed_obs) != 39:
            print(f"Observation dimension mismatch: got {len(processed_obs)}, expected 39")
            if len(processed_obs) < 39:
                padding = np.zeros(39 - len(processed_obs), dtype=np.float32)
                processed_obs = np.concatenate([processed_obs, padding])
                print("Padded observation to 39 dimensions")
            else:
                processed_obs = processed_obs[:39]
                print("Truncated observation to 39 dimensions")

        return processed_obs

    def _extract_metaworld_obs(self, obs):
        """
        Extract MetaWorld-specific observation components for the reach task.

        This helper searches the observation dict for commonly used keys:
        - Joint positions ('qpos', 7 values for the Sawyer)
        - Joint velocities ('qvel', 7 values)
        - End effector position ('eef_pos' or 'achieved_goal', 3 values)
        - Target position ('target_pos' or 'desired_goal', 3 values)
        """
        if not isinstance(obs, dict):
            return None

        components = []

        # Joint positions.
        if 'qpos' in obs:
            joint_pos = np.array(obs['qpos'], dtype=np.float32).flatten()
            components.append(joint_pos)
            print(f"Found joint positions: {joint_pos.shape}")

        # Joint velocities.
        if 'qvel' in obs:
            joint_vel = np.array(obs['qvel'], dtype=np.float32).flatten()
            components.append(joint_vel)
            print(f"Found joint velocities: {joint_vel.shape}")

        # End effector position.
        if 'eef_pos' in obs or 'achieved_goal' in obs:
            eef_key = 'eef_pos' if 'eef_pos' in obs else 'achieved_goal'
            eef_pos = np.array(obs[eef_key], dtype=np.float32).flatten()
            if len(eef_pos) >= 3:
                components.append(eef_pos[:3])
                print(f"Found end effector position: {eef_pos[:3]}")

        # Target position.
        if 'target_pos' in obs or 'desired_goal' in obs:
            target_key = 'target_pos' if 'target_pos' in obs else 'desired_goal'
            target_pos = np.array(obs[target_key], dtype=np.float32).flatten()
            if len(target_pos) >= 3:
                components.append(target_pos[:3])
                print(f"Found target position: {target_pos[:3]}")

        if components:
            metaworld_obs = np.concatenate(components)
            print(f"Extracted MetaWorld observation: {metaworld_obs.shape} dimensions")
            return metaworld_obs

        return None

    def _normalize_observation(self, obs):
        """
        Normalize the observation if its values look out of range.

        Some policies expect normalized observations; this applies a simple
        z-score normalization, but only when values are suspiciously large.
        """
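        # Caveat (assumption): MetaWorld's scripted expert policies work on raw
        # Cartesian positions, so z-scoring a well-formed observation would
        # likely hurt them; the >10 threshold below is intended to trigger only
        # on clearly out-of-range inputs.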
        if not isinstance(obs, np.ndarray):
            return obs

        obs_min, obs_max = obs.min(), obs.max()

        # Only normalize when values fall well outside the expected range.
        if abs(obs_max) > 10 or abs(obs_min) > 10:
            print(f"Observation values seem large (min={obs_min:.3f}, max={obs_max:.3f}), normalizing...")
            obs_mean = obs.mean()
            obs_std = obs.std()
            if obs_std > 0:
                normalized_obs = (obs - obs_mean) / obs_std
                print(f"Normalized observation range: [{normalized_obs.min():.3f}, {normalized_obs.max():.3f}]")
                return normalized_obs

        return obs

    def reset(self) -> None:
        """
        Reset agent state between episodes.
        """
        print(f"Resetting agent after {self.episode_step} steps")
        self.episode_step = 0

        # Re-enable debug output for the next episode.
        self.debug_observations = True
        self.debug_actions = True

    def _build_model(self):
        """
        Build a neural network model for the agent.

        This is a placeholder for where you would define your neural network
        architecture using PyTorch, TensorFlow, or another framework.
        """
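        # Hypothetical sketch only (not used by this expert-policy agent): a
        # small PyTorch MLP mapping the 39-dim reach observation to the 4-dim
        # Sawyer action, assuming actions live in [-1, 1].
        #
        #   self.model = torch.nn.Sequential(
        #       torch.nn.Linear(39, 128),
        #       torch.nn.ReLU(),
        #       torch.nn.Linear(128, 128),
        #       torch.nn.ReLU(),
        #       torch.nn.Linear(128, 4),
        #       torch.nn.Tanh(),
        #   )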
        pass
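

# Minimal usage sketch (assumes a Gymnasium-compatible MetaWorld reach
# environment has already been constructed elsewhere as `env`; the exact
# constructor depends on the installed metaworld version):
#
#   agent = RLAgent(observation_space=env.observation_space,
#                   action_space=env.action_space, seed=0)
#   obs, info = env.reset()
#   for _ in range(agent.max_episode_steps):
#       action = agent.act(obs)
#       obs, reward, terminated, truncated, info = env.step(action.numpy())
#       if terminated or truncated:
#           break
#   agent.reset()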