zjowowen's picture
init space
079c32c
raw
history blame
4.32 kB
from typing import Dict, Any
import torch
from ding.rl_utils import q_nstep_td_data, q_nstep_td_error
from ding.policy import DQNPolicy
from ding.utils import POLICY_REGISTRY
from ding.policy.common_utils import default_preprocess_learn
from ding.torch_utils import to_device
@POLICY_REGISTRY.register('md_dqn')
class MultiDiscreteDQNPolicy(DQNPolicy):
    r"""
    Overview:
        DQN policy variant registered as ``md_dqn``. It reuses everything from ``DQNPolicy`` and only overrides \
        the learn-mode forward pass so that the model may output either a single Q-value tensor or a list of \
        Q-value tensors (one entry per discrete action dimension).
    """

    def _forward_learn(self, data: dict) -> Dict[str, Any]:
        """
        Overview:
            Run one training step: preprocess the batch, compute the n-step TD loss, back-propagate, step the \
            optimizer and update the target model. Whether the single or the multi-discrete path is taken is \
            decided at runtime by checking if the model's ``logit`` output is a ``list``.
        Arguments:
            - data (:obj:`Dict[str, Any]`): A training batch, handed to ``default_preprocess_learn`` before use.
        Returns:
            - info_dict (:obj:`Dict[str, Any]`): Training statistics of this step.
        ArgumentsKeys:
            - read after preprocessing: ``obs``, ``next_obs``, ``action``, ``reward``, ``done``, ``weight``
            - optional: ``value_gamma``
        ReturnsKeys:
            - ``cur_lr``, ``total_loss``, ``q_value_mean``, ``priority``
        """
        batch = default_preprocess_learn(
            data,
            use_priority=self._priority,
            use_priority_IS_weight=self._cfg.priority_IS_weight,
            ignore_done=self._cfg.learn.ignore_done,
            use_nstep=True
        )
        if self._cuda:
            batch = to_device(batch, self._device)

        # ====================
        # Q-learning forward
        # ====================
        self._learn_model.train()
        self._target_model.train()

        # Q values of the current observation, produced by the main model (gradients flow through this).
        q_value = self._learn_model.forward(batch['obs'])['logit']

        # Everything derived from the next observation is only used as a target, so no graph is needed.
        with torch.no_grad():
            # Target network Q values.
            target_q_value = self._target_model.forward(batch['next_obs'])['logit']
            # Max q value action, as selected by the main model.
            target_q_action = self._learn_model.forward(batch['next_obs'])['action']

        value_gamma = batch.get('value_gamma')

        def _nstep_td(q, target_q, action, next_action):
            # Shared by both paths: pack one head's tensors and evaluate the n-step TD error on them.
            packed = q_nstep_td_data(
                q, target_q, action, next_action, batch['reward'], batch['done'], batch['weight']
            )
            return q_nstep_td_error(packed, self._gamma, nstep=self._nstep, value_gamma=value_gamma)

        if not isinstance(q_value, list):
            # Single discrete action head.
            loss, td_error_per_sample = _nstep_td(q_value, target_q_value, batch['action'], target_q_action)
            q_value_mean = q_value.mean().item()
        else:
            # Multi-discrete: one TD loss per action head, averaged afterwards.
            act_num = len(q_value)
            head_losses, head_td_errors, head_q_means = [], [], []
            for idx, head_q in enumerate(q_value):
                head_loss, head_td_error = _nstep_td(
                    head_q, target_q_value[idx], batch['action'][idx], target_q_action[idx]
                )
                head_losses.append(head_loss)
                head_td_errors.append(head_td_error.abs())
                head_q_means.append(head_q.mean().item())
            # The small epsilon in the denominators is kept exactly as in the averaging formula of this policy.
            loss = sum(head_losses) / (len(head_losses) + 1e-8)
            td_error_per_sample = sum(head_td_errors) / (len(head_td_errors) + 1e-8)
            q_value_mean = sum(head_q_means) / act_num

        # ====================
        # Q-learning update
        # ====================
        self._optimizer.zero_grad()
        loss.backward()
        if self._cfg.multi_gpu:
            self.sync_gradients(self._learn_model)
        self._optimizer.step()

        # =============
        # after update
        # =============
        self._target_model.update(self._learn_model.state_dict())

        info_dict = {
            'cur_lr': self._optimizer.defaults['lr'],
            'total_loss': loss.item(),
            'q_value_mean': q_value_mean,
            'priority': td_error_per_sample.abs().tolist(),
        }
        return info_dict