| | """TD3+BC: TD3 with Behavior Cloning regularization for offline RL. |
| | |
| | All computations done in normalized space: |
| | - States: zero mean, unit variance (from dataset stats) |
| | - Actions: scaled to [-1, 1] using joint limits |
| | """ |
| |
|
| | import os |
| | import csv |
| | import copy |
| | import numpy as np |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| |
|
| | from offline_dataset import OfflineRLDataset |
| |
|
| |
|
| | |
# Per-joint action bounds used to scale raw actions to/from [-1, 1].
# 9 dims: 7 arm joints + 2 gripper fingers (see `joint_names` in main()).
# Gripper range is [0.0, 0.05] — presumably meters of finger travel; confirm
# against the robot description.
JOINT_LIMITS_LOW = torch.tensor(
    [-1.606, -1.221, -3.142, -2.251, -3.142, -2.16, -3.142, 0.0, 0.0],
    dtype=torch.float32,
)
JOINT_LIMITS_HIGH = torch.tensor(
    [1.606, 1.518, 3.142, 2.251, 3.142, 3.142, 3.142, 0.05, 0.05],
    dtype=torch.float32,
)
| |
|
| |
|
def normalize_action(action, low, high):
    """Linearly rescale a raw action from [low, high] into [-1, 1]."""
    span = high - low
    return (action - low) / span * 2.0 - 1.0
| |
|
| |
|
def denormalize_action(action_norm, low, high):
    """Map a normalized action in [-1, 1] back into the raw range [low, high]."""
    half_span = 0.5 * (high - low)
    return low + (action_norm + 1.0) * half_span
| |
|
| |
|
class Actor(nn.Module):
    """Deterministic policy: raw state in, tanh-squashed action in [-1, 1]."""

    def __init__(self, state_dim, action_dim, state_mean, state_std):
        super().__init__()
        layers = [
            nn.Linear(state_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, action_dim),
            nn.Tanh(),
        ]
        self.net = nn.Sequential(*layers)
        # Buffers travel with the module: saved in state_dict, moved by .to().
        self.register_buffer("state_mean", state_mean)
        self.register_buffer("state_std", state_std)
        self.register_buffer("action_low", JOINT_LIMITS_LOW)
        self.register_buffer("action_high", JOINT_LIMITS_HIGH)

    def forward(self, state):
        """Z-score the raw state internally; return a normalized action in [-1, 1]."""
        normalized = (state - self.state_mean) / self.state_std
        return self.net(normalized)

    def get_raw_action(self, state):
        """Return the policy action mapped back into joint-limit space."""
        return denormalize_action(self.forward(state), self.action_low, self.action_high)
| |
|
| |
|
class Critic(nn.Module):
    """Twin Q-networks (clipped double-Q) with LayerNorm for stable offline training."""

    def __init__(self, state_dim, action_dim):
        super().__init__()

        # Both heads share the same architecture; build them with one helper.
        # (q1 is constructed before q2, preserving parameter-init order.)
        def build_q_head():
            return nn.Sequential(
                nn.Linear(state_dim + action_dim, 256),
                nn.LayerNorm(256),
                nn.ReLU(),
                nn.Linear(256, 256),
                nn.LayerNorm(256),
                nn.ReLU(),
                nn.Linear(256, 1),
            )

        self.q1 = build_q_head()
        self.q2 = build_q_head()

    def forward(self, state_norm, action_norm):
        """Return (Q1, Q2) for the concatenated normalized state-action pair."""
        joint = torch.cat([state_norm, action_norm], dim=-1)
        return self.q1(joint), self.q2(joint)

    def q1_forward(self, state_norm, action_norm):
        """Return only the Q1 estimate (used for the actor update)."""
        joint = torch.cat([state_norm, action_norm], dim=-1)
        return self.q1(joint)
| |
|
| |
|
class TD3BC:
    """TD3 with behavior-cloning regularization (TD3+BC) for offline RL.

    The critic operates in normalized space (z-scored states, actions in
    [-1, 1]); the actor takes RAW states and normalizes them internally.

    Args:
        state_dim: Dimension of the raw state vector.
        action_dim: Dimension of the action vector.
        state_mean: Per-dimension state mean for normalization; defaults to
            zeros (identity normalization) when None.
        state_std: Per-dimension state std; defaults to ones when None.
        lr: Adam learning rate shared by actor and critic.
        discount: Bellman discount factor.
        tau: Polyak averaging coefficient for the target networks.
        policy_noise: Std of the target-policy smoothing noise.
        noise_clip: Clamp bound for the smoothing noise.
        policy_delay: Number of critic updates per actor/target update.
        alpha: BC trade-off; the actor weight is lambda = alpha / mean|Q|.
        device: Torch device string.
    """

    def __init__(
        self,
        state_dim=9,
        action_dim=9,
        state_mean=None,
        state_std=None,
        lr=3e-4,
        discount=0.99,
        tau=0.005,
        policy_noise=0.2,
        noise_clip=0.5,
        policy_delay=2,
        alpha=2.5,
        device="cuda",
    ):
        self.device = device
        self.discount = discount
        self.tau = tau
        self.policy_noise = policy_noise
        self.noise_clip = noise_clip
        self.policy_delay = policy_delay
        self.alpha = alpha
        self.max_action = 1.0

        # BUGFIX: the declared defaults of None previously crashed on
        # `.to(device)` below. Fall back to identity normalization so the
        # documented defaults actually work.
        if state_mean is None:
            state_mean = torch.zeros(state_dim, dtype=torch.float32)
        if state_std is None:
            state_std = torch.ones(state_dim, dtype=torch.float32)

        self.state_mean = state_mean.to(device)
        self.state_std = state_std.to(device)
        self.action_low = JOINT_LIMITS_LOW.to(device)
        self.action_high = JOINT_LIMITS_HIGH.to(device)

        self.actor = Actor(state_dim, action_dim, state_mean, state_std).to(device)
        self.actor_target = copy.deepcopy(self.actor)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=lr)

        self.critic = Critic(state_dim, action_dim).to(device)
        self.critic_target = copy.deepcopy(self.critic)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=lr)

        self.total_it = 0

    def _normalize_state(self, state):
        """Z-score a raw state batch with the dataset statistics."""
        return (state - self.state_mean) / self.state_std

    def _normalize_action(self, action):
        """Scale a raw action from joint-limit space into [-1, 1]."""
        return normalize_action(action, self.action_low, self.action_high)

    def train_step(self, state, action, reward, next_state, done):
        """Run one TD3+BC update.

        state/action/next_state are raw (unnormalized) batches; reward and
        done are 1-D batches. Returns a dict of scalar metrics; the actor
        metrics are 0.0 on steps where the delayed actor update is skipped.
        """
        self.total_it += 1

        s_norm = self._normalize_state(state)
        a_norm = self._normalize_action(action)
        ns_norm = self._normalize_state(next_state)

        with torch.no_grad():
            # Target-policy smoothing: clipped Gaussian noise on the target action.
            noise = (torch.randn_like(a_norm) * self.policy_noise).clamp(
                -self.noise_clip, self.noise_clip
            )
            # The actor takes the RAW next state (it normalizes internally).
            next_a_norm = (self.actor_target(next_state) + noise).clamp(-1.0, 1.0)

            # Clipped double-Q Bellman target.
            target_q1, target_q2 = self.critic_target(ns_norm, next_a_norm)
            target_q = torch.min(target_q1, target_q2)
            target_q = reward.unsqueeze(-1) + (1.0 - done.unsqueeze(-1)) * self.discount * target_q
            # NOTE(review): hard clamp assumes returns lie in [-1, 2] —
            # confirm against the dataset's reward scale.
            target_q = target_q.clamp(-1.0, 2.0)

        current_q1, current_q2 = self.critic(s_norm, a_norm)
        critic_loss = F.mse_loss(current_q1, target_q) + F.mse_loss(current_q2, target_q)

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        nn.utils.clip_grad_norm_(self.critic.parameters(), 1.0)
        self.critic_optimizer.step()

        actor_loss_val = 0.0
        bc_loss_val = 0.0
        q_value_mean = 0.0

        # Delayed policy update: actor and targets move once per `policy_delay`
        # critic updates (TD3).
        if self.total_it % self.policy_delay == 0:
            pi_norm = self.actor(state)
            q_val = self.critic.q1_forward(s_norm, pi_norm)

            # Adaptive trade-off from the TD3+BC paper: lambda = alpha / E[|Q|].
            lam = self.alpha / self.critic.q1_forward(s_norm, a_norm).abs().mean().detach()

            # BC term keeps the policy close to the dataset actions.
            bc_loss = ((pi_norm - a_norm) ** 2).mean()

            actor_loss = -lam * q_val.mean() + bc_loss

            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            nn.utils.clip_grad_norm_(self.actor.parameters(), 1.0)
            self.actor_optimizer.step()

            # Polyak-average both target networks toward the online networks.
            for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1.0 - self.tau) * target_param.data)
            for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1.0 - self.tau) * target_param.data)

            actor_loss_val = actor_loss.item()
            bc_loss_val = bc_loss.item()
            q_value_mean = q_val.mean().item()

        return {
            "critic_loss": critic_loss.item(),
            "actor_loss": actor_loss_val,
            "bc_loss": bc_loss_val,
            "q_value_mean": q_value_mean,
            "q_value_std": current_q1.std().item(),
        }

    def save(self, filepath):
        """Serialize all networks, optimizers, and the iteration counter."""
        torch.save({
            "actor": self.actor.state_dict(),
            "critic": self.critic.state_dict(),
            "actor_target": self.actor_target.state_dict(),
            "critic_target": self.critic_target.state_dict(),
            "actor_optimizer": self.actor_optimizer.state_dict(),
            "critic_optimizer": self.critic_optimizer.state_dict(),
            "total_it": self.total_it,
        }, filepath)

    def load(self, filepath):
        """Restore a checkpoint produced by `save`, mapped onto self.device."""
        checkpoint = torch.load(filepath, map_location=self.device)
        self.actor.load_state_dict(checkpoint["actor"])
        self.critic.load_state_dict(checkpoint["critic"])
        self.actor_target.load_state_dict(checkpoint["actor_target"])
        self.critic_target.load_state_dict(checkpoint["critic_target"])
        self.actor_optimizer.load_state_dict(checkpoint["actor_optimizer"])
        self.critic_optimizer.load_state_dict(checkpoint["critic_optimizer"])
        self.total_it = checkpoint["total_it"]
| |
|
| |
|
def main():
    """Train TD3+BC on the offline dataset, log metrics, and validate the policy."""
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", default="/code/zxx240000/training/offline_rl/data/offline_dataset.npz")
    parser.add_argument("--output_dir", default="/code/zxx240000/training/offline_rl/results/td3_bc")
    parser.add_argument("--num_iterations", type=int, default=100000)
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--discount", type=float, default=0.99)
    parser.add_argument("--tau", type=float, default=0.005)
    parser.add_argument("--policy_noise", type=float, default=0.2)
    parser.add_argument("--noise_clip", type=float, default=0.5)
    parser.add_argument("--policy_delay", type=int, default=2)
    parser.add_argument("--alpha", type=float, default=2.5)
    parser.add_argument("--eval_freq", type=int, default=10000)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    args = parser.parse_args()

    # Reproducibility.
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    ckpt_dir = os.path.join(args.output_dir, "checkpoints")
    os.makedirs(ckpt_dir, exist_ok=True)

    print(f"Loading dataset from {args.dataset}")
    dataset = OfflineRLDataset(args.dataset, device=args.device)
    print(f" {dataset.size} transitions loaded")

    state_mean = dataset.state_mean.to(args.device)
    state_std = dataset.state_std.to(args.device)

    agent = TD3BC(
        state_dim=9,
        action_dim=9,
        state_mean=state_mean,
        state_std=state_std,
        lr=args.lr,
        discount=args.discount,
        tau=args.tau,
        policy_noise=args.policy_noise,
        noise_clip=args.noise_clip,
        policy_delay=args.policy_delay,
        alpha=args.alpha,
        device=args.device,
    )
    print(f"TD3+BC agent created on {args.device}")

    # Hoist the device-side joint limits once; they are reused by every
    # limit check below (previously re-transferred on each use).
    limits_low = JOINT_LIMITS_LOW.to(args.device)
    limits_high = JOINT_LIMITS_HIGH.to(args.device)

    # Sanity-check action normalization over the whole dataset.
    raw_actions = dataset.actions.to(args.device)
    norm_actions = normalize_action(raw_actions, limits_low, limits_high)
    print(f" Normalized action range: [{norm_actions.min():.3f}, {norm_actions.max():.3f}]")
    print(f" Normalized action mean: {norm_actions.mean(0).cpu().numpy()}")

    log_path = os.path.join(args.output_dir, "training_log.csv")
    # FIX: context-manage the CSV log so the handle is closed even if
    # training raises (it was previously only closed on the success path).
    with open(log_path, "w", newline="") as log_file:
        log_writer = csv.writer(log_file)
        log_writer.writerow(["step", "critic_loss", "actor_loss", "bc_loss", "q_value_mean", "q_value_std"])

        # Running sums per eval window; actor metrics average only over the
        # steps where the delayed actor update actually ran.
        running = {"critic_loss": 0, "actor_loss": 0, "bc_loss": 0, "q_value_mean": 0, "q_value_std": 0}
        actor_updates = 0

        print(f"\nStarting training for {args.num_iterations} iterations...")
        for step in range(1, args.num_iterations + 1):
            state, action, reward, next_state, done = dataset.sample(args.batch_size)
            metrics = agent.train_step(state, action, reward, next_state, done)

            running["critic_loss"] += metrics["critic_loss"]
            running["q_value_std"] += metrics["q_value_std"]
            # FIX: detect actor-update steps by the delay schedule rather than
            # the fragile `actor_loss != 0` sentinel (an actor loss of exactly
            # 0.0 would have been miscounted). train_step increments total_it
            # once per call starting from 0, so total_it == step here.
            if step % args.policy_delay == 0:
                running["actor_loss"] += metrics["actor_loss"]
                running["bc_loss"] += metrics["bc_loss"]
                running["q_value_mean"] += metrics["q_value_mean"]
                actor_updates += 1

            if step % args.eval_freq == 0:
                n = args.eval_freq
                n_actor = max(actor_updates, 1)
                avg_critic = running["critic_loss"] / n
                avg_actor = running["actor_loss"] / n_actor
                avg_bc = running["bc_loss"] / n_actor
                avg_q_mean = running["q_value_mean"] / n_actor
                avg_q_std = running["q_value_std"] / n

                log_writer.writerow([step, f"{avg_critic:.6f}", f"{avg_actor:.6f}",
                                     f"{avg_bc:.6f}", f"{avg_q_mean:.6f}", f"{avg_q_std:.6f}"])
                log_file.flush()

                print(f"Step {step:>6d} | Critic: {avg_critic:.6f} | Actor: {avg_actor:.6f} | "
                      f"BC: {avg_bc:.6f} | Q-mean: {avg_q_mean:.4f} | Q-std: {avg_q_std:.4f}")

                ckpt_path = os.path.join(ckpt_dir, f"checkpoint_{step}.pt")
                agent.save(ckpt_path)

                # Reset the window accumulators.
                running = {k: 0 for k in running}
                actor_updates = 0

                # Spot-check that the policy stays inside the joint limits.
                with torch.no_grad():
                    test_states = dataset.states[:100].to(args.device)
                    test_actions_raw = agent.actor.get_raw_action(test_states)
                    a_min = test_actions_raw.min(dim=0).values.cpu().numpy()
                    a_max = test_actions_raw.max(dim=0).values.cpu().numpy()
                    within_limits = (
                        (test_actions_raw >= limits_low - 1e-5).all()
                        and (test_actions_raw <= limits_high + 1e-5).all()
                    )
                    if not within_limits:
                        print(" WARNING: Policy outputs outside joint limits!")
                        print(f" Min: {a_min}")
                        print(f" Max: {a_max}")

    best_path = os.path.join(args.output_dir, "best_model.pt")
    agent.save(best_path)
    print(f"\nFinal model saved to {best_path}")

    print("\n=== FINAL VALIDATION ===")
    with torch.no_grad():
        all_states = dataset.states.to(args.device)
        chunk_size = 4096
        # Chunked forward pass keeps peak memory bounded on large datasets.
        all_actions = []
        for i in range(0, len(all_states), chunk_size):
            chunk = all_states[i:i+chunk_size]
            all_actions.append(agent.actor.get_raw_action(chunk))
        all_actions = torch.cat(all_actions, dim=0)

        print("Policy action statistics (raw joint space):")
        joint_names = ["shoulder_pan", "shoulder_lift", "upperarm_roll", "elbow_flex",
                       "forearm_roll", "wrist_flex", "wrist_roll", "l_gripper", "r_gripper"]
        for i, name in enumerate(joint_names):
            a = all_actions[:, i]
            print(f" {name}: min={a.min():.4f}, max={a.max():.4f}, mean={a.mean():.4f}, "
                  f"limits=[{JOINT_LIMITS_LOW[i]:.3f}, {JOINT_LIMITS_HIGH[i]:.3f}]")

        within = (
            (all_actions >= limits_low - 1e-5).all()
            and (all_actions <= limits_high + 1e-5).all()
        )
        print(f"\nAll actions within joint limits: {within.item()}")

    # Round-trip check: a fresh agent must load the saved checkpoint cleanly.
    print("\nVerifying saved model loads correctly...")
    agent2 = TD3BC(state_dim=9, action_dim=9, state_mean=state_mean, state_std=state_std, device=args.device)
    agent2.load(best_path)
    with torch.no_grad():
        test_s = dataset.states[:10].to(args.device)
        test_a = agent2.actor.get_raw_action(test_s)
        print(f" Loaded model produces actions: shape={test_a.shape}, range=[{test_a.min():.4f}, {test_a.max():.4f}]")

    print(f"\nTraining log saved to {log_path}")
    print(f"Checkpoints saved to {ckpt_dir}")
    print(f"Best model saved to {best_path}")
| |
|
| |
|
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()
| |
|