import copy

import torch
import torch.nn.functional as F

from SGVLB import SGVLB
from network import Net, Critic
|
|
|
|
class BPDAgent(object):
    def __init__(
        self,
        env,
        args,
        env_info,
        thresholds,
        datasize,
        device,
        discount,
        tau,
        noise_clip,
        policy_freq,
        h,
        num_teacher_param,
    ):
        self.args = args
        self.env = env
        self.env_info = env_info

        # Sparse variational actor (student policy) and its target network,
        # regularized through the SGVLB objective
        self.actor = Net(env_info['state_dim'], env_info['action_dim'], env_info['action_bound'],
                         args.student_hidden_dims, thresholds['ALPHA_THRESHOLD'], thresholds['THETA_THRESHOLD'],
                         device=device).to(device)
        self.actor_target = copy.deepcopy(self.actor)
        self.sgvlb = SGVLB(self.actor, datasize, loss_type='l2', device=device)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=3e-4)

        # Twin critics (TD3-style) and their target network
        self.critic = Critic(env_info['state_dim'], env_info['action_dim']).to(device)
        self.critic_target = copy.deepcopy(self.critic)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=3e-4)

        self.discount = discount
        self.tau = tau
        self.noise_clip = noise_clip  # stored but not used in train(); no target policy smoothing noise is applied below
        self.policy_freq = policy_freq
        self.datasize = datasize
        self.h = h

        self.total_it = 0
        self.kl_weight = 0
|
|
    def set_kl_weight(self, kl_weight):
        self.kl_weight = kl_weight
|
|
    def test(self):
        # Evaluate the current actor for several episodes without gradient tracking
        self.actor.eval()
        with torch.no_grad():
            return_list = []
            for epi_cnt in range(self.args.num_test_epi):
                episode_return = 0
                done = False
                state, _ = self.env.reset()
                while not done:
                    action = self.actor(state)
                    action = action.cpu().numpy()[0]
                    next_state, reward, terminated, truncated, _ = self.env.step(action)
                    done = terminated or truncated
                    episode_return += reward
                    state = next_state
                return_list.append(episode_return)

        avg_return = sum(return_list) / len(return_list)
        max_return = max(return_list)
        min_return = min(return_list)

        return avg_return, max_return, min_return
|
|
    def train(self, transition):
        # One TD3-style update step: critic regression on the clipped double-Q
        # target, followed by a delayed actor update regularized by the SGVLB
        self.actor.train()

        self.total_it += 1

        states, actions, rewards, next_states, dones = transition

        with torch.no_grad():
            # Target action from the target actor, clipped to the valid action range
            next_actions = (
                self.actor_target(next_states)
            ).clamp(self.env_info['action_bound'][0], self.env_info['action_bound'][1])

            # Clipped double-Q bootstrap target
            target_Q1, target_Q2 = self.critic_target(next_states, next_actions)
            target_Q = torch.min(target_Q1, target_Q2)
            target_Q = rewards + (1 - dones) * self.discount * target_Q

        # Critic update
        current_Q1, current_Q2 = self.critic(states, actions)
        critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # Delayed actor and target-network update
        if self.total_it % self.policy_freq == 0:
            pi = self.actor(states)
            Q = self.critic.Q1(states, pi)
            # Normalize the Q term by its mean magnitude so it stays commensurate
            # with the SGVLB loss, which scales with the dataset size
            lmbda = (self.h * self.datasize) / Q.abs().mean().detach()

            # Maximize Q while fitting the dataset actions via the SGVLB objective
            actor_loss = -lmbda * Q.mean() + self.sgvlb(pi, actions, self.kl_weight)

            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()

            # Soft-update the target networks
            for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

            for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
|
|
    def __del__(self):
        del self.actor
        del self.actor_target
        del self.critic
        del self.critic_target
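

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the training code).
# Assumptions not defined in this module: a Gymnasium-style environment with
# the 5-tuple step API used in test(), an `args` object carrying
# `student_hidden_dims` and `num_test_epi`, and a replay buffer whose
# `sample()` returns the (states, actions, rewards, next_states, dones)
# tensors that train() unpacks. The hyperparameter values and the KL-weight
# schedule below are illustrative guesses, not values from this repository.
# ---------------------------------------------------------------------------
def example_run(env, args, replay_buffer, device="cpu", num_updates=10_000):
    env_info = {
        'state_dim': env.observation_space.shape[0],
        'action_dim': env.action_space.shape[0],
        'action_bound': (float(env.action_space.low[0]), float(env.action_space.high[0])),
    }
    thresholds = {'ALPHA_THRESHOLD': 3.0, 'THETA_THRESHOLD': 1e-3}  # assumed pruning thresholds

    agent = BPDAgent(
        env, args, env_info, thresholds,
        datasize=len(replay_buffer), device=device,
        discount=0.99, tau=0.005, noise_clip=0.5, policy_freq=2,  # common TD3 defaults (assumed)
        h=2.5,                    # illustrative value only
        num_teacher_param=None,   # not used inside __init__ above
    )

    for step in range(1, num_updates + 1):
        # Anneal the KL weight of the variational objective (schedule is an assumption)
        agent.set_kl_weight(min(1.0, step / num_updates))
        agent.train(replay_buffer.sample())

        if step % 1000 == 0:
            avg_ret, max_ret, min_ret = agent.test()
            print(f"step {step}: avg {avg_ret:.1f}, max {max_ret:.1f}, min {min_ret:.1f}")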
|
|
|
|