from typing import Optional, Tuple, Dict
import torch
import torch.nn as nn
from ding.utils import MODEL_REGISTRY, SequenceType
from ding.torch_utils.network.transformer import Attention
from ding.torch_utils.network.nn_module import fc_block
from ..common import FCEncoder, ConvEncoder


class PCTransformer(nn.Module):
"""
Overview:
        The transformer block used in the neural network of algorithms related to Procedure Cloning (PC).
Interfaces:
``__init__``, ``forward``.
"""
def __init__(
self, cnn_hidden: int, att_hidden: int, att_heads: int, drop_p: float, max_T: int, n_att: int,
feedforward_hidden: int, n_feedforward: int
) -> None:
"""
Overview:
Initialize the procedure cloning transformer model according to corresponding input arguments.
Arguments:
- cnn_hidden (:obj:`int`): The last channel dimension of CNN encoder, such as 32.
- att_hidden (:obj:`int`): The dimension of attention blocks, such as 32.
- att_heads (:obj:`int`): The number of heads in attention blocks, such as 4.
            - drop_p (:obj:`float`): The dropout rate of attention, such as 0.5.
            - max_T (:obj:`int`): The sequence length of procedure cloning, such as 4.
            - n_att (:obj:`int`): The number of attention layers, such as 4.
            - feedforward_hidden (:obj:`int`): The dimension of feedforward layers, such as 32.
            - n_feedforward (:obj:`int`): The number of feedforward layers, such as 4.
"""
        super().__init__()
        self.n_att = n_att
        self.n_feedforward = n_feedforward
        # Use nn.ModuleList (rather than a plain Python list) so that the sub-layers are
        # registered properly, i.e. their parameters are tracked and moved across devices.
        # Build each LayerNorm separately: ``[nn.LayerNorm(d)] * n`` would alias one module n times.
        self.attention_layer = nn.ModuleList()
        self.norm_layer = nn.ModuleList([nn.LayerNorm(att_hidden) for _ in range(n_att)])
        self.attention_layer.append(Attention(cnn_hidden, att_hidden, att_hidden, att_heads, nn.Dropout(drop_p)))
        for _ in range(n_att - 1):
            self.attention_layer.append(Attention(att_hidden, att_hidden, att_hidden, att_heads, nn.Dropout(drop_p)))
        self.att_drop = nn.Dropout(drop_p)
        self.fc_blocks = nn.ModuleList()
        self.fc_blocks.append(fc_block(att_hidden, feedforward_hidden, activation=nn.ReLU()))
        for _ in range(n_feedforward - 1):
            self.fc_blocks.append(fc_block(feedforward_hidden, feedforward_hidden, activation=nn.ReLU()))
        self.norm_layer.extend([nn.LayerNorm(feedforward_hidden) for _ in range(n_feedforward)])
        # Lower-triangular (causal) attention mask, registered as a buffer so it follows
        # the module when it is moved to another device.
        self.register_buffer('mask', torch.tril(torch.ones((max_T, max_T), dtype=torch.bool)).view(1, 1, max_T, max_T))

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Overview:
The unique execution (forward) method of PCTransformer.
Arguments:
- x (:obj:`torch.Tensor`): Sequential data of several hidden states.
Returns:
            - output (:obj:`torch.Tensor`): The output tensor, with the same batch and sequence dimensions \
                as the input and a last dimension equal to ``feedforward_hidden``.
Examples:
>>> model = PCTransformer(128, 128, 8, 0, 16, 2, 128, 2)
>>> h = torch.randn((2, 16, 128))
>>> h = model(h)
>>> assert h.shape == torch.Size([2, 16, 128])
"""
for i in range(self.n_att):
x = self.att_drop(self.attention_layer[i](x, self.mask))
x = self.norm_layer[i](x)
for i in range(self.n_feedforward):
x = self.fc_blocks[i](x)
x = self.norm_layer[i + self.n_att](x)
return x
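
# Illustrative: the lower-triangular ``mask`` built in ``PCTransformer.__init__`` makes the
# attention causal, i.e. token t can only attend to tokens at positions <= t. For max_T = 4,
# the mask (before the (1, 1, T, T) reshape) looks like:
#
#   >>> torch.tril(torch.ones((4, 4), dtype=torch.bool))
#   tensor([[ True, False, False, False],
#           [ True,  True, False, False],
#           [ True,  True,  True, False],
#           [ True,  True,  True,  True]])
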
@MODEL_REGISTRY.register('pc_mcts')
class ProcedureCloningMCTS(nn.Module):
"""
Overview:
        The neural network of algorithms related to Procedure Cloning (PC).
Interfaces:
``__init__``, ``forward``.
"""
def __init__(
self,
obs_shape: SequenceType,
action_dim: int,
cnn_hidden_list: SequenceType = [128, 128, 256, 256, 256],
cnn_activation: nn.Module = nn.ReLU(),
cnn_kernel_size: SequenceType = [3, 3, 3, 3, 3],
cnn_stride: SequenceType = [1, 1, 1, 1, 1],
cnn_padding: SequenceType = [1, 1, 1, 1, 1],
mlp_hidden_list: SequenceType = [256, 256],
mlp_activation: nn.Module = nn.ReLU(),
att_heads: int = 8,
att_hidden: int = 128,
n_att: int = 4,
n_feedforward: int = 2,
feedforward_hidden: int = 256,
drop_p: float = 0.5,
max_T: int = 17
) -> None:
"""
Overview:
Initialize the MCTS procedure cloning model according to corresponding input arguments.
Arguments:
- obs_shape (:obj:`SequenceType`): Observation space shape, such as [4, 84, 84].
- action_dim (:obj:`int`): Action space shape, such as 6.
- cnn_hidden_list (:obj:`SequenceType`): The cnn channel dims for each block, such as\
[128, 128, 256, 256, 256].
- cnn_activation (:obj:`nn.Module`): The activation function for cnn blocks, such as ``nn.ReLU()``.
- cnn_kernel_size (:obj:`SequenceType`): The kernel size for each cnn block, such as [3, 3, 3, 3, 3].
- cnn_stride (:obj:`SequenceType`): The stride for each cnn block, such as [1, 1, 1, 1, 1].
- cnn_padding (:obj:`SequenceType`): The padding for each cnn block, such as [1, 1, 1, 1, 1].
- mlp_hidden_list (:obj:`SequenceType`): The last dim for this must match the last dim of \
``cnn_hidden_list``, such as [256, 256].
- mlp_activation (:obj:`nn.Module`): The activation function for mlp layers, such as ``nn.ReLU()``.
- att_heads (:obj:`int`): The number of attention heads in transformer, such as 8.
            - att_hidden (:obj:`int`): The attention hidden dimension in transformer, such as 128.
            - n_att (:obj:`int`): The number of attention blocks in transformer, such as 4.
            - n_feedforward (:obj:`int`): The number of feedforward layers in transformer, such as 2.
            - feedforward_hidden (:obj:`int`): The dimension of feedforward layers in transformer, such as 256.
            - drop_p (:obj:`float`): The dropout rate of attention, such as 0.5.
- max_T (:obj:`int`): The sequence length of procedure cloning, such as 17.
"""
super().__init__()
# Conv Encoder
self.embed_state = ConvEncoder(
obs_shape, cnn_hidden_list, cnn_activation, cnn_kernel_size, cnn_stride, cnn_padding
)
self.embed_action = FCEncoder(action_dim, mlp_hidden_list, activation=mlp_activation)
self.cnn_hidden_list = cnn_hidden_list
assert cnn_hidden_list[-1] == mlp_hidden_list[-1]
        # The attention, normalization and feedforward blocks are all built inside PCTransformer.
        self.transformer = PCTransformer(
            cnn_hidden_list[-1], att_hidden, att_heads, drop_p, max_T, n_att, feedforward_hidden, n_feedforward
        )
self.predict_goal = torch.nn.Linear(cnn_hidden_list[-1], cnn_hidden_list[-1])
        self.predict_action = torch.nn.Linear(cnn_hidden_list[-1], action_dim)

def forward(self, states: torch.Tensor, goals: torch.Tensor,
actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
        Overview:
            ProcedureCloningMCTS forward computation graph: take the states, goals and actions tensors as \
            input, and calculate the predicted goal and actions.
        Arguments:
            - states (:obj:`torch.Tensor`): The observation at the current time step.
            - goals (:obj:`torch.Tensor`): The target observation after a period.
            - actions (:obj:`torch.Tensor`): The actions executed during the period.
        Returns:
            - outputs (:obj:`Tuple[torch.Tensor, torch.Tensor]`): The predicted goal and the predicted actions.
Examples:
>>> inputs = { \
'states': torch.randn(2, 3, 64, 64), \
'goals': torch.randn(2, 3, 64, 64), \
'actions': torch.randn(2, 15, 9) \
}
>>> model = ProcedureCloningMCTS(obs_shape=(3, 64, 64), action_dim=9)
>>> goal_preds, action_preds = model(inputs['states'], inputs['goals'], inputs['actions'])
>>> assert goal_preds.shape == (2, 256)
>>> assert action_preds.shape == (2, 16, 9)
"""
        B, T, _ = actions.shape
        # State and goal embeddings, shape: (B, 1, h_dim).
        state_embeddings = self.embed_state(states).reshape(B, 1, self.cnn_hidden_list[-1])
        goal_embeddings = self.embed_state(goals).reshape(B, 1, self.cnn_hidden_list[-1])
        # Action embeddings, shape: (B, T, h_dim).
        actions_embeddings = self.embed_action(actions)
        # Token sequence [state, goal, action_1, ..., action_T], shape: (B, T + 2, h_dim).
        h = torch.cat((state_embeddings, goal_embeddings, actions_embeddings), dim=1)
        h = self.transformer(h)
        h = h.reshape(B, T + 2, self.cnn_hidden_list[-1])
goal_preds = self.predict_goal(h[:, 0, :])
action_preds = self.predict_action(h[:, 1:, :])
return goal_preds, action_preds
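
# Illustrative: with the defaults above, ``max_T = 17`` corresponds to 15 action tokens
# plus the state token and the goal token, matching the docstring example (actions of
# shape (2, 15, 9) yield ``action_preds`` of shape (2, 16, 9)).

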
class BFSConvEncoder(nn.Module):
    """
    Overview:
        The ``BFSConvEncoder`` used to encode raw 3-dim observations, outputting a feature map with the \
        same height and width as the input.
    Interfaces:
        ``__init__``, ``forward``.
    """

def __init__(
self,
obs_shape: SequenceType,
hidden_size_list: SequenceType = [32, 64, 64, 128],
activation: Optional[nn.Module] = nn.ReLU(),
kernel_size: SequenceType = [8, 4, 3],
stride: SequenceType = [4, 2, 1],
padding: Optional[SequenceType] = None,
) -> None:
"""
Overview:
Init the ``BFSConvolution Encoder`` according to the provided arguments.
Arguments:
- obs_shape (:obj:`SequenceType`): Sequence of ``in_channel``, plus one or more ``input size``.
            - hidden_size_list (:obj:`SequenceType`): Sequence of ``hidden_size`` (output channels) of \
                subsequent conv layers.
            - activation (:obj:`nn.Module`): Type of activation to use in the conv layers. \
                Default is ``nn.ReLU()``.
- kernel_size (:obj:`SequenceType`): Sequence of ``kernel_size`` of subsequent conv layers.
- stride (:obj:`SequenceType`): Sequence of ``stride`` of subsequent conv layers.
- padding (:obj:`SequenceType`): Padding added to all four sides of the input for each conv layer. \
See ``nn.Conv2d`` for more details. Default is ``None``.
"""
        super().__init__()
self.obs_shape = obs_shape
self.act = activation
self.hidden_size_list = hidden_size_list
if padding is None:
padding = [0 for _ in range(len(kernel_size))]
layers = []
input_size = obs_shape[0] # in_channel
for i in range(len(kernel_size)):
layers.append(nn.Conv2d(input_size, hidden_size_list[i], kernel_size[i], stride[i], padding[i]))
layers.append(self.act)
input_size = hidden_size_list[i]
        # Drop the trailing activation so the final conv layer outputs a raw (un-activated) feature map.
        layers = layers[:-1]
self.main = nn.Sequential(*layers)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Overview:
Return output tensor of the env observation.
Arguments:
- x (:obj:`torch.Tensor`): Env raw observation.
Returns:
            - outputs (:obj:`torch.Tensor`): Output feature map, with the same height and width as the input.
Examples:
>>> model = BFSConvEncoder([3, 16, 16], [32, 32, 4], kernel_size=[3, 3, 3], stride=[1, 1, 1]\
, padding=[1, 1, 1])
>>> inputs = torch.randn(3, 16, 16).unsqueeze(0)
>>> outputs = model(inputs)
            >>> assert outputs.shape == torch.Size([1, 4, 16, 16])
"""
return self.main(x)
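
# Illustrative: with kernel 3, stride 1 and padding 1 (as used by ProcedureCloningBFS below),
# every conv layer preserves the spatial size, so a (B, C, H, W) input yields a
# (B, hidden_size_list[-1], H, W) feature map.

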
@MODEL_REGISTRY.register('pc_bfs')
class ProcedureCloningBFS(nn.Module):
"""
Overview:
        The neural network introduced in Procedure Cloning (PC) to process 3-dim observations. Given an \
        input, this model will perform several 3x3 convolutions and output a feature map with the same \
        height and width as the input. The channel number of the output will be ``action_shape + 1``.
Interfaces:
``__init__``, ``forward``.
"""
def __init__(
self,
obs_shape: SequenceType,
action_shape: int,
encoder_hidden_size_list: SequenceType = [128, 128, 256, 256],
):
"""
Overview:
            Initialize the ``ProcedureCloningBFS`` model according to the provided arguments.
        Arguments:
            - obs_shape (:obj:`SequenceType`): Sequence of ``in_channel``, plus one or more ``input size``, \
                such as [4, 84, 84].
            - action_shape (:obj:`int`): Action space shape, such as 6.
            - encoder_hidden_size_list (:obj:`SequenceType`): The cnn channel dims for each block, \
                such as [128, 128, 256, 256].
"""
super().__init__()
num_layers = len(encoder_hidden_size_list)
kernel_sizes = (3, ) * (num_layers + 1)
stride_sizes = (1, ) * (num_layers + 1)
padding_sizes = (1, ) * (num_layers + 1)
        # The output channel number equals action_shape + 1.
        # Copy before appending so the mutable default list argument is not modified in place.
        encoder_hidden_size_list = list(encoder_hidden_size_list) + [action_shape + 1]
self._encoder = BFSConvEncoder(
obs_shape=obs_shape,
hidden_size_list=encoder_hidden_size_list,
kernel_size=kernel_sizes,
stride=stride_sizes,
padding=padding_sizes,
)
def forward(self, x: torch.Tensor) -> Dict:
"""
Overview:
            The computation graph. Given a 3-dim observation, this function will return a tensor with the same \
            height and width. The channel number of the output will be ``action_shape + 1``.
Arguments:
- x (:obj:`torch.Tensor`): The input observation tensor data.
Returns:
- outputs (:obj:`Dict`): The output dict of model's forward computation graph, \
only contains a single key ``logit``.
Examples:
>>> model = ProcedureCloningBFS([3, 16, 16], 4)
>>> inputs = torch.randn(16, 16, 3).unsqueeze(0)
>>> outputs = model(inputs)
            >>> assert outputs['logit'].shape == torch.Size([1, 16, 16, 5])
"""
        # (B, H, W, C) -> (B, C, H, W) for the conv encoder, then back to channels-last on output.
        x = x.permute(0, 3, 1, 2)
        x = self._encoder(x)
        return {'logit': x.permute(0, 2, 3, 1)}
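

if __name__ == '__main__':
    # Minimal smoke test (illustrative sketch; shape values follow the docstring examples above).
    mcts_model = ProcedureCloningMCTS(obs_shape=(3, 64, 64), action_dim=9)
    goal_preds, action_preds = mcts_model(
        torch.randn(2, 3, 64, 64), torch.randn(2, 3, 64, 64), torch.randn(2, 15, 9)
    )
    assert goal_preds.shape == (2, 256)
    assert action_preds.shape == (2, 16, 9)

    bfs_model = ProcedureCloningBFS([3, 16, 16], 4)
    outputs = bfs_model(torch.randn(1, 16, 16, 3))
    assert outputs['logit'].shape == torch.Size([1, 16, 16, 5])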