from collections import OrderedDict
import numpy as np
import torch
import torch.optim as optim
from torch import nn as nn
import rlkit.torch.pytorch_util as ptu
from rlkit.core.eval_util import create_stats_ordered_dict
from rlkit.torch.torch_rl_algorithm import TorchTrainer
class SACTrainer(TorchTrainer):
def __init__(
self,
env,
policy,
qf1,
qf2,
target_qf1,
target_qf2,
discount=0.99,
reward_scale=1.0,
policy_lr=1e-3,
qf_lr=1e-3,
optimizer_class=optim.Adam,
soft_target_tau=1e-2,
target_update_period=1,
plotter=None,
render_eval_paths=False,
use_automatic_entropy_tuning=True,
target_entropy=None,
):
super().__init__()
self.env = env
self.policy = policy
self.qf1 = qf1
self.qf2 = qf2
self.target_qf1 = target_qf1
self.target_qf2 = target_qf2
self.soft_target_tau = soft_target_tau
self.target_update_period = target_update_period
self.use_automatic_entropy_tuning = use_automatic_entropy_tuning
if self.use_automatic_entropy_tuning:
            if target_entropy is not None:
                self.target_entropy = target_entropy
            else:
                # Heuristic from Tuomas Haarnoja (SAC authors): -dim(action space)
                self.target_entropy = -np.prod(self.env.action_space.shape).item()
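            # Optimize log(alpha) so that alpha = exp(log_alpha) stays positive.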
self.log_alpha = ptu.zeros(1, requires_grad=True)
self.alpha_optimizer = optimizer_class(
[self.log_alpha],
lr=policy_lr,
)
self.plotter = plotter
self.render_eval_paths = render_eval_paths
self.qf_criterion = nn.MSELoss()
self.vf_criterion = nn.MSELoss()
self.policy_optimizer = optimizer_class(
self.policy.parameters(),
lr=policy_lr,
)
self.qf1_optimizer = optimizer_class(
self.qf1.parameters(),
lr=qf_lr,
)
self.qf2_optimizer = optimizer_class(
self.qf2.parameters(),
lr=qf_lr,
)
self.discount = discount
self.reward_scale = reward_scale
self.eval_statistics = OrderedDict()
self._n_train_steps_total = 0
self._need_to_update_eval_statistics = True
def train_from_torch(self, batch):
rewards = batch['rewards']
terminals = batch['terminals']
obs = batch['observations']
actions = batch['actions']
next_obs = batch['next_observations']
"""
Policy and Alpha Loss
"""
new_obs_actions, policy_mean, policy_log_std, log_pi, *_ = self.policy(
obs, reparameterize=True, return_log_prob=True,
)
if self.use_automatic_entropy_tuning:
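            # Temperature loss: J(alpha) = E[-log_alpha * (log_pi + target_entropy)];
            # detach() keeps this update from back-propagating into the policy.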
alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()
self.alpha_optimizer.zero_grad()
alpha_loss.backward()
self.alpha_optimizer.step()
alpha = self.log_alpha.exp()
else:
alpha_loss = 0
alpha = 1
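        # Clipped double-Q: use the minimum of the two critics to reduce
        # overestimation bias.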
q_new_actions = torch.min(
self.qf1(obs, new_obs_actions),
self.qf2(obs, new_obs_actions),
)
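        # Policy objective: maximize E[Q(s, a) - alpha * log_pi(a|s)] over
        # reparameterized actions, i.e. minimize its negation.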
policy_loss = (alpha*log_pi - q_new_actions).mean()
"""
QF Loss
"""
q1_pred = self.qf1(obs, actions)
q2_pred = self.qf2(obs, actions)
# Make sure policy accounts for squashing functions like tanh correctly!
new_next_actions, _, _, new_log_pi, *_ = self.policy(
next_obs, reparameterize=True, return_log_prob=True,
)
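        # Soft Bellman backup:
        # target = r + gamma * (1 - done) * (min_i Q_target_i(s', a') - alpha * log_pi(a'|s'))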
target_q_values = torch.min(
self.target_qf1(next_obs, new_next_actions),
self.target_qf2(next_obs, new_next_actions),
) - alpha * new_log_pi
q_target = self.reward_scale * rewards + (1. - terminals) * self.discount * target_q_values
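        # detach() treats the target as a constant, so no gradients flow into
        # the target networks or the policy through q_target.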
qf1_loss = self.qf_criterion(q1_pred, q_target.detach())
qf2_loss = self.qf_criterion(q2_pred, q_target.detach())
"""
Update networks
"""
self.qf1_optimizer.zero_grad()
qf1_loss.backward()
self.qf1_optimizer.step()
self.qf2_optimizer.zero_grad()
qf2_loss.backward()
self.qf2_optimizer.step()
self.policy_optimizer.zero_grad()
policy_loss.backward()
self.policy_optimizer.step()
"""
Soft Updates
"""
if self._n_train_steps_total % self.target_update_period == 0:
ptu.soft_update_from_to(
self.qf1, self.target_qf1, self.soft_target_tau
)
ptu.soft_update_from_to(
self.qf2, self.target_qf2, self.soft_target_tau
)
"""
Save some statistics for eval
"""
if self._need_to_update_eval_statistics:
self._need_to_update_eval_statistics = False
"""
Eval should set this to None.
This way, these statistics are only computed for one batch.
"""
policy_loss = (log_pi - q_new_actions).mean()
self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
self.eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy(
policy_loss
))
self.eval_statistics.update(create_stats_ordered_dict(
'Q1 Predictions',
ptu.get_numpy(q1_pred),
))
self.eval_statistics.update(create_stats_ordered_dict(
'Q2 Predictions',
ptu.get_numpy(q2_pred),
))
self.eval_statistics.update(create_stats_ordered_dict(
'Q Targets',
ptu.get_numpy(q_target),
))
self.eval_statistics.update(create_stats_ordered_dict(
'Log Pis',
ptu.get_numpy(log_pi),
))
self.eval_statistics.update(create_stats_ordered_dict(
'Policy mu',
ptu.get_numpy(policy_mean),
))
self.eval_statistics.update(create_stats_ordered_dict(
'Policy log std',
ptu.get_numpy(policy_log_std),
))
if self.use_automatic_entropy_tuning:
self.eval_statistics['Alpha'] = alpha.item()
self.eval_statistics['Alpha Loss'] = alpha_loss.item()
self._n_train_steps_total += 1
def get_diagnostics(self):
return self.eval_statistics
def end_epoch(self, epoch):
self._need_to_update_eval_statistics = True
@property
def networks(self):
return [
self.policy,
self.qf1,
self.qf2,
self.target_qf1,
self.target_qf2,
]
    def get_snapshot(self):
        return dict(
            policy=self.policy,
            qf1=self.qf1,
            qf2=self.qf2,
            target_qf1=self.target_qf1,
            target_qf2=self.target_qf2,
        )
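

# Usage sketch: one minimal, illustrative way to wire SACTrainer together,
# assuming rlkit's FlattenMlp and TanhGaussianPolicy network classes and a
# gym continuous-control environment are available (check the network API of
# your rlkit version); the hyperparameters below are placeholders, not tuned
# values.
#
#     import gym
#     from rlkit.torch.networks import FlattenMlp
#     from rlkit.torch.sac.policies import TanhGaussianPolicy
#
#     env = gym.make('HalfCheetah-v2')
#     obs_dim = env.observation_space.low.size
#     action_dim = env.action_space.low.size
#
#     def make_qf():
#         return FlattenMlp(input_size=obs_dim + action_dim, output_size=1,
#                           hidden_sizes=[256, 256])
#
#     trainer = SACTrainer(
#         env=env,
#         policy=TanhGaussianPolicy(obs_dim=obs_dim, action_dim=action_dim,
#                                   hidden_sizes=[256, 256]),
#         qf1=make_qf(), qf2=make_qf(),
#         target_qf1=make_qf(), target_qf2=make_qf(),
#     )
#     # An rlkit batch RL algorithm then repeatedly calls
#     # trainer.train_from_torch(batch) with dicts of torch tensors.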