# zjowowen's picture
# init space
# 079c32c
from typing import Union, Dict, Optional
import torch
import torch.nn as nn
from ding.utils import SequenceType, squeeze, MODEL_REGISTRY
from ..common import ReparameterizationHead, RegressionHead, DiscreteHead, MultiHead, \
FCEncoder, ConvEncoder
@MODEL_REGISTRY.register('acer')
class ACER(nn.Module):
    """
    Overview:
        The model of algorithm ACER (Actor Critic with Experience Replay), from the paper
        "Sample Efficient Actor-Critic with Experience Replay" (https://arxiv.org/abs/1611.01224).
        The actor and the critic are built from separate encoder and head instances.
    Interfaces:
        ``__init__``, ``forward``, ``compute_actor``, ``compute_critic``
    """
    # Names of the methods that ``forward`` is allowed to dispatch to.
    mode = ['compute_actor', 'compute_critic']

    def __init__(
            self,
            obs_shape: Union[int, SequenceType],
            action_shape: Union[int, SequenceType],
            # NOTE: mutable default; this class only forwards it to the encoders and never mutates it.
            encoder_hidden_size_list: SequenceType = [128, 128, 64],
            actor_head_hidden_size: int = 64,
            actor_head_layer_num: int = 1,
            critic_head_hidden_size: int = 64,
            critic_head_layer_num: int = 1,
            activation: Optional[nn.Module] = nn.ReLU(),
            norm_type: Optional[str] = None,
    ) -> None:
        """
        Overview:
            Init the ACER Model according to arguments.
        Arguments:
            - obs_shape (:obj:`Union[int, SequenceType]`): Observation's space. After ``squeeze``, an ``int`` or \
                a length-1 shape selects ``FCEncoder``; a length-3 shape selects ``ConvEncoder``.
            - action_shape (:obj:`Union[int, SequenceType]`): Action's space, passed (after ``squeeze``) as the \
                second argument of both heads.
            - encoder_hidden_size_list (:obj:`SequenceType`): Hidden sizes passed to both the actor encoder and \
                the critic encoder.
            - actor_head_hidden_size (:obj:`int`): The ``hidden_size`` to pass to the actor's ``DiscreteHead``.
            - actor_head_layer_num (:obj:`int`): The layer number to pass to the actor's ``DiscreteHead``.
            - critic_head_hidden_size (:obj:`int`): The ``hidden_size`` to pass to the critic's ``RegressionHead``.
            - critic_head_layer_num (:obj:`int`): The layer number to pass to the critic's ``RegressionHead``.
            - activation (:obj:`Optional[nn.Module]`): The activation passed to every encoder and head. The \
                default is a single ``nn.ReLU()`` instance created at definition time and shared by all of them.
            - norm_type (:obj:`Optional[str]`):
                The type of normalization to use, see ``ding.torch_utils.fc_block`` for more details.
        Raises:
            - RuntimeError: If the squeezed ``obs_shape`` is neither an ``int`` nor of length 1 or 3.

        .. note::
            NOTE(review): the heads receive ``*_head_hidden_size`` as their first argument while the encoders are
            sized by ``encoder_hidden_size_list``; presumably the head input width has to equal the encoder output
            width (the defaults, 64 and ``[128, 128, 64]``, look consistent) -- confirm against the encoder/head
            implementations before changing either value independently.
        """
        super(ACER, self).__init__()
        obs_shape = squeeze(obs_shape)
        action_shape = squeeze(action_shape)

        # Pick the pre-defined encoder from the observation layout.
        if isinstance(obs_shape, int) or len(obs_shape) == 1:
            encoder_cls = FCEncoder
        elif len(obs_shape) == 3:
            encoder_cls = ConvEncoder
        else:
            raise RuntimeError(
                "not support obs_shape for pre-defined encoder: {}, please customize your own ACER".format(obs_shape)
            )

        # Submodules are created in a fixed order (actor encoder, critic encoder, critic head, actor head)
        # so that parameter initialization order and ``state_dict`` key order stay stable.
        self.actor_encoder = encoder_cls(
            obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type
        )
        self.critic_encoder = encoder_cls(
            obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type
        )
        self.critic_head = RegressionHead(
            critic_head_hidden_size, action_shape, critic_head_layer_num, activation=activation, norm_type=norm_type
        )
        self.actor_head = DiscreteHead(
            actor_head_hidden_size, action_shape, actor_head_layer_num, activation=activation, norm_type=norm_type
        )

        # Grouped views over the same submodules, e.g. for building per-network optimizers.
        self.actor = nn.ModuleList([self.actor_encoder, self.actor_head])
        self.critic = nn.ModuleList([self.critic_encoder, self.critic_head])

    def forward(self, inputs: Union[torch.Tensor, Dict], mode: str) -> Dict:
        """
        Overview:
            Dispatch ``inputs`` to the method named by ``mode`` and return its output.
        Arguments:
            - inputs (:obj:`Union[torch.Tensor, Dict]`): The input forwarded unchanged to the selected method.
            - mode (:obj:`str`): Name of the forward mode, one of ``self.mode`` \
                (``'compute_actor'`` or ``'compute_critic'``).
        Returns:
            - outputs (:obj:`Dict`): Outputs of network forward.
        Raises:
            - AssertionError: If ``mode`` is not listed in ``self.mode``.
        Shapes (Actor):
            - obs (:obj:`torch.Tensor`): :math:`(B, N1)`, where B is batch size and N1 is ``obs_shape``
            - logit (:obj:`torch.FloatTensor`): :math:`(B, N2)`, where B is batch size and N2 is ``action_shape``
        Shapes (Critic):
            - inputs (:obj:`torch.Tensor`): :math:`(B, N1)`, B is batch size and N1 corresponds to ``obs_shape``
            - q_value (:obj:`torch.FloatTensor`): :math:`(B, N2)`, where B is batch size and N2 is ``action_shape``
        """
        assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
        return getattr(self, mode)(inputs)

    def compute_actor(self, inputs: torch.Tensor) -> Dict:
        """
        Overview:
            Run the actor: pass ``inputs`` through ``actor_encoder`` and then ``actor_head``, and return the
            head's output unchanged.
        Arguments:
            - inputs (:obj:`torch.Tensor`): The observation tensor fed to the actor encoder.
        Returns:
            - outputs (:obj:`Dict`): The output of ``actor_head``.
        ReturnsKeys:
            - logit (:obj:`torch.FloatTensor`): :math:`(B, N1)`, where B is batch size and N1 is ``action_shape`` \
                (key produced by ``DiscreteHead``, as the example below relies on).
        Shapes:
            - inputs (:obj:`torch.Tensor`): :math:`(B, N0)`, B is batch size and N0 corresponds to ``obs_shape``
            - logit (:obj:`torch.FloatTensor`): :math:`(B, N1)`, where B is batch size and N1 is ``action_shape``
        Examples:
            >>> model = ACER(64, 64)
            >>> inputs = torch.randn(4, 64)
            >>> actor_outputs = model(inputs, 'compute_actor')
            >>> assert actor_outputs['logit'].shape == torch.Size([4, 64])
        """
        embedding = self.actor_encoder(inputs)
        return self.actor_head(embedding)

    def compute_critic(self, inputs: torch.Tensor) -> Dict:
        """
        Overview:
            Run the critic: pass ``inputs`` through ``critic_encoder`` and then ``critic_head``, and expose the
            head's ``pred`` entry under the key ``q_value``.
        Arguments:
            - inputs (:obj:`torch.Tensor`): The observation tensor fed to the critic encoder.
        Returns:
            - outputs (:obj:`Dict`): Q-value output.
        ReturnKeys:
            - q_value (:obj:`torch.Tensor`): The ``pred`` output of ``critic_head``.
        Shapes:
            - inputs (:obj:`torch.Tensor`): :math:`(B, N1)`, where B is batch size and N1 is ``obs_shape``
            - q_value (:obj:`torch.FloatTensor`): :math:`(B, N2)`, where B is batch size and N2 is ``action_shape``.
        Examples:
            >>> N = 64
            >>> inputs = torch.randn(4, N)
            >>> model = ACER(obs_shape=(N, ), action_shape=5)
            >>> model(inputs, mode='compute_critic')['q_value']
        """
        embedding = self.critic_encoder(inputs)
        head_output = self.critic_head(embedding)
        # Re-key the regression head's prediction so callers read it as a Q-value.
        return {"q_value": head_output['pred']}