refactor: replace logging with print statements for debugging in RLAgent
agent.py
CHANGED
@@ -1,8 +1,3 @@
-"""
-Implementation of the AgentInterface for MetaWorld tasks.
-
-This agent uses the SawyerPickPlaceV2Policy from MetaWorld as an expert policy.
-"""
 
 import logging
 from typing import Any, Dict
@@ -31,17 +26,26 @@ class RLAgent(AgentInterface):
     ):
         super().__init__(observation_space, action_space, seed, **kwargs)
 
-
-
+        print(f"Initializing MetaWorld agent with seed {self.seed}")
+
+        # Log spaces for debugging
+        if observation_space:
+            print(f"Observation space: {observation_space}")
+        if action_space:
+            print(f"Action space: {action_space}")
 
         self.policy = SawyerReachV3Policy()
-
+        print("Successfully initialized SawyerReachV3Policy")
 
         # Track episode state
         self.episode_step = 0
         self.max_episode_steps = kwargs.get("max_episode_steps", 200)
+
+        # Debug flags
+        self.debug_observations = True
+        self.debug_actions = True
 
-
+        print("MetaWorld agent initialized successfully")
 
     def act(self, obs: Dict[str, Any], **kwargs) -> torch.Tensor:
         """
@@ -55,31 +59,60 @@
             action: Action tensor to take in the environment
         """
         try:
+            # Debug observation structure
+            if self.debug_observations and self.episode_step % 20 == 0:
+                print(f"Raw observation structure: {type(obs)}")
+                if isinstance(obs, dict):
+                    print(f"Observation keys: {list(obs.keys())}")
+                    for key, value in obs.items():
+                        if isinstance(value, np.ndarray):
+                            print(f" {key}: shape={value.shape}, dtype={value.dtype}")
+                        else:
+                            print(f" {key}: {type(value)} = {value}")
+
             # Process observation to extract the format needed by the expert policy
             processed_obs = self._process_observation(obs)
 
-            #
-
-
-
+            # Debug processed observation
+            if self.debug_observations and self.episode_step % 20 == 0:
+                print(f"Processed obs: shape={processed_obs.shape}, dtype={processed_obs.dtype}")
+                print(f"Processed obs sample: {processed_obs[:10]}...") # First 10 values
 
-            #
-
-
+            # Use the expert policy
+            action_numpy = self.policy.get_action(processed_obs)
+
+            # Debug raw policy output
+            if self.debug_actions and self.episode_step % 20 == 0:
+                print(f"Raw policy action: {action_numpy}, type: {type(action_numpy)}")
+                print(f"Action shape: {np.array(action_numpy).shape}")
+
+            # Convert to tensor
+            if isinstance(action_numpy, (list, tuple)):
+                action_tensor = torch.tensor(action_numpy, dtype=torch.float32)
+            else:
+                action_tensor = torch.from_numpy(np.array(action_numpy)).float()
+
+            # Ensure correct action dimensionality
+            if self.action_space and hasattr(self.action_space, 'shape'):
+                expected_shape = self.action_space.shape[0]
+                if action_tensor.shape[0] != expected_shape:
+                    print(f"Action shape mismatch: got {action_tensor.shape[0]}, expected {expected_shape}")
+                    # Pad or truncate as needed
+                    if action_tensor.shape[0] < expected_shape:
+                        padding = torch.zeros(expected_shape - action_tensor.shape[0])
+                        action_tensor = torch.cat([action_tensor, padding])
+                    else:
+                        action_tensor = action_tensor[:expected_shape]
+
+            # Debug final action
+            if self.debug_actions and self.episode_step % 20 == 0:
+                print(f"Final action tensor: {action_tensor}")
 
-            # Increment episode step
             self.episode_step += 1
-
-            # Occasionally log actions to avoid spam
-            if self.episode_step % 50 == 0:
-                self.logger.debug(
-                    f"Step {self.episode_step}: Action shape {action_tensor.shape}"
-                )
-
             return action_tensor
 
         except Exception as e:
-
+            print(f"Error in act method: {e}")
             # Return zeros as a fallback
             if isinstance(self.action_space, gym.spaces.Box):
                 return torch.zeros(self.action_space.shape[0], dtype=torch.float32)
@@ -93,48 +126,56 @@ class RLAgent(AgentInterface):
         MetaWorld policies typically expect a specific observation format.
         """
         if isinstance(obs, dict):
-            #
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            # Try different keys that MetaWorld might use
+            possible_keys = [
+                "observation",
+                "obs",
+                "state_observation",
+                "achieved_goal",
+                "state"
+            ]
+
+            processed_obs = None
+            for key in possible_keys:
+                if key in obs:
+                    processed_obs = obs[key]
+                    if self.debug_observations and self.episode_step % 50 == 0:
+                        print(f"Using observation key: {key}")
+                    break
+
+            if processed_obs is None:
+                # If none of the expected keys found, concatenate all numeric values
+                numeric_values = []
+                for key, value in obs.items():
+                    if isinstance(value, (np.ndarray, list, tuple)):
+                        flat_value = np.array(value).flatten()
+                        numeric_values.append(flat_value)
+                        if self.debug_observations and self.episode_step % 50 == 0:
+                            print(f"Concatenating key {key}: shape={flat_value.shape}")
+
+                if numeric_values:
+                    processed_obs = np.concatenate(numeric_values)
+                    if self.debug_observations and self.episode_step % 50 == 0:
+                        print(f"Concatenated observation shape: {processed_obs.shape}")
+                else:
+                    # Last resort: use first value
+                    processed_obs = next(iter(obs.values()))
+                    print("No numeric values found, using first observation value")
         else:
-            # If already a numpy array or similar, use directly
             processed_obs = obs
 
-        # Ensure
+        # Ensure numpy array
        if not isinstance(processed_obs, np.ndarray):
             try:
                 processed_obs = np.array(processed_obs, dtype=np.float32)
             except Exception as e:
-
-                # Return
-
-
-
-
-
-                    processed_obs = np.zeros(
-                        self.observation_space.shape, dtype=np.float32
-                    )
-                else:
-                    # Typical MetaWorld observation dimension if all else fails
-                    processed_obs = np.zeros(39, dtype=np.float32)
+                print(f"Failed to convert observation to numpy array: {e}")
+                # Return default observation size for MetaWorld reach task
+                processed_obs = np.zeros(39, dtype=np.float32)
+
+        # Ensure proper shape for MetaWorld reach policy
+        if processed_obs.ndim > 1:
+            processed_obs = processed_obs.flatten()
 
         return processed_obs
 
@@ -142,9 +183,11 @@ class RLAgent(AgentInterface):
         """
         Reset agent state between episodes.
        """
-
+        print(f"Resetting agent after {self.episode_step} steps")
         self.episode_step = 0
-        #
+        # Reset debug flags if needed
+        self.debug_observations = True
+        self.debug_actions = True
 
     def _build_model(self):
         """
@@ -153,13 +196,4 @@ class RLAgent(AgentInterface):
         This is a placeholder for where you would define your neural network
         architecture using PyTorch, TensorFlow, or another framework.
         """
-
-        # model = torch.nn.Sequential(
-        #     torch.nn.Linear(self.observation_space.shape[0], 128),
-        #     torch.nn.ReLU(),
-        #     torch.nn.Linear(128, 64),
-        #     torch.nn.ReLU(),
-        #     torch.nn.Linear(64, self.action_space.shape[0]),
-        # )
-        # return model
-        pass
+        pass
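
The comment block removed from _build_model above sketched a small feed-forward network. As a reference only, and not part of this commit, a minimal runnable version of that sketch, assuming 1-D gym.spaces.Box observation and action spaces, could look like the following; the expert-policy agent never calls it, so _build_model stays a placeholder.

import torch

def build_model(observation_space, action_space):
    """Hypothetical feed-forward policy network matching the removed comment block.

    Assumes Box spaces with 1-D shapes; for illustration only.
    """
    return torch.nn.Sequential(
        torch.nn.Linear(observation_space.shape[0], 128),
        torch.nn.ReLU(),
        torch.nn.Linear(128, 64),
        torch.nn.ReLU(),
        torch.nn.Linear(64, action_space.shape[0]),
    )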
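For context on how the new print() diagnostics surface at runtime, here is a minimal, hypothetical smoke test. It assumes the RLAgent constructor mirrors the super().__init__ call above and that the reset method shown is named reset(); it also substitutes a stand-in environment for a real MetaWorld reach task (metaworld must still be installed for SawyerReachV3Policy).

import numpy as np

class DummyReachEnv:
    # Stand-in environment: 39-dim observations in a dict, the layout the Sawyer reach policy expects.
    def reset(self):
        return {"observation": np.zeros(39, dtype=np.float32)}

    def step(self, action):
        return {"observation": np.random.randn(39).astype(np.float32)}, 0.0, False, {}

env = DummyReachEnv()
agent = RLAgent(observation_space=None, action_space=None, seed=0)  # assumed signature

obs = env.reset()
for _ in range(60):  # enough steps for the every-20-step debug prints to fire a few times
    action = agent.act(obs)  # periodically prints raw/processed observation and action details
    obs, reward, done, info = env.step(action.numpy())
agent.reset()  # prints the step count before clearing it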