# zjowowen's picture
# init space
# 079c32c
from typing import Optional, Union, List, Tuple, Dict
import math
import torch
import torch.nn as nn
import treetensor.torch as ttorch
import ding
from ding.torch_utils.network.normalization import build_normalization
# Optional HPC LSTM implementation: stays ``None`` unless ``ding.enable_hpc_rl`` is truthy, in which case it is
# bound to the LSTM class shipped by ``hpc_rll``.
HPCLSTM = None
if ding.enable_hpc_rl:
    from hpc_rll.torch_utils.network.rnn import LSTM as HPCLSTM
def is_sequence(data):
    """
    Overview:
        Check whether ``data`` is an instance of ``list`` or ``tuple``.
    Arguments:
        - data: The object to inspect.
    Returns:
        - boolean: ``True`` when ``data`` is a list or a tuple (subclasses included), ``False`` otherwise.
    """
    return isinstance(data, (list, tuple))
def sequence_mask(lengths: torch.Tensor, max_len: Optional[int] = None) -> torch.BoolTensor:
    """
    Overview:
        Build a boolean mask for a batch of variable-length sequences: entry ``[i, j]`` is ``True`` exactly when \
        position ``j`` is smaller than the length of sequence ``i``.
    Arguments:
        - lengths (:obj:`torch.Tensor`): Length of every sequence. A 1-dim tensor is reshaped to ``(n, 1)`` first.
        - max_len (:obj:`int`, optional): Upper bound on the mask width. When ``None`` the width is \
            ``lengths.max()``; otherwise it is ``min(max_len, lengths.max())``, so the mask is never wider than \
            the longest sequence.
    Returns:
        - masks (:obj:`torch.BoolTensor`): Mask with one row per element of ``lengths``, placed on \
            ``lengths.device``.
    """
    if lengths.dim() == 1:
        lengths = lengths.unsqueeze(dim=1)
    num_seqs = lengths.numel()
    longest = lengths.max()
    width = longest if max_len is None else min(max_len, longest)
    # One row of positions 0..width-1, cast to the dtype of ``lengths`` so the comparison below is well-typed.
    positions = torch.arange(0, width).type_as(lengths)
    # Replicate the row per sequence, then compare against each (broadcast) length.
    mask = positions.repeat(num_seqs, 1).lt(lengths)
    return mask.to(lengths.device)
class LSTMForwardWrapper(object):
    """
    Overview:
        Mixin providing the hooks that run around an RNN ``forward`` call: ``_before_forward`` converts the \
        caller-supplied previous state into a pair of batched tensors, and ``_after_forward`` converts the \
        resulting ``(h, c)`` pair back into dict / list-of-dict form.
    Interfaces:
        ``_before_forward``, ``_after_forward``
    """

    def _before_forward(self, inputs: torch.Tensor, prev_state: Union[None, List[Dict]]) -> torch.Tensor:
        """
        Overview:
            Normalize ``prev_state`` into a two-element ``(h, c)`` container before the RNN ``forward`` call.
        Arguments:
            - inputs (:obj:`torch.Tensor`): Input of the RNN, shape [seq_len, batch_size, input_size]. Only \
                the batch size, dtype and device are read here.
            - prev_state (:obj:`Union[None, List[Dict]]`): One of the following forms. ``None``: both states \
                are all-zero tensors of shape [num_layers, batch_size, hidden_size]. A list or tuple with one \
                entry per sample: each entry is ``None`` (zeros), a dict / ``ttorch.Tensor`` whose values are \
                used in iteration order, or an already ordered pair; entries are concatenated along dim 1. \
                A dict: its values are returned as a list.
        Returns:
            - prev_state: Two-element tuple or list holding the batched previous states.
        Raises:
            - RuntimeError: If a list/tuple ``prev_state`` does not have exactly ``batch_size`` entries.
            - TypeError: If ``prev_state`` is none of the supported forms.
        """
        assert hasattr(self, 'num_layers')
        assert hasattr(self, 'hidden_size')
        _, batch_size = inputs.shape[:2]

        def zero_state(width: int) -> torch.Tensor:
            # All-zero state covering ``width`` samples, matching the dtype and device of ``inputs``.
            num_directions = 1
            return torch.zeros(
                num_directions * self.num_layers, width, self.hidden_size, dtype=inputs.dtype, device=inputs.device
            )

        if prev_state is None:
            blank = zero_state(batch_size)
            # The same zero tensor object is used for both h and c.
            return (blank, blank)

        if is_sequence(prev_state):
            if len(prev_state) != batch_size:
                raise RuntimeError(
                    "prev_state number is not equal to batch_size: {}/{}".format(len(prev_state), batch_size)
                )
            blank = zero_state(1)
            per_sample = []
            for item in prev_state:
                if item is None:
                    per_sample.append([blank, blank])
                elif isinstance(item, (Dict, ttorch.Tensor)):
                    per_sample.append(list(item.values()))
                else:
                    per_sample.append(item)
            # Transpose [sample][component] -> [component][sample], then merge the samples along the batch dim.
            return [torch.cat(parts, dim=1) for parts in zip(*per_sample)]

        if isinstance(prev_state, dict):
            return list(prev_state.values())

        raise TypeError("not support prev_state type: {}".format(type(prev_state)))

    def _after_forward(self,
                       next_state: Tuple[torch.Tensor],
                       list_next_state: bool = False) -> Union[List[Dict], Dict[str, torch.Tensor]]:
        """
        Overview:
            Convert the ``(h, c)`` state produced by the RNN ``forward`` call into its returned format.
        Arguments:
            - next_state (:obj:`Tuple[torch.Tensor]`): The next state as an ``(h, c)`` pair.
            - list_next_state (:obj:`bool`, optional): If True, split the state along dim 1 and return one \
                ``{'h': ..., 'c': ...}`` dict per sample; otherwise return a single dict. Default is False.
        Returns:
            - next_state (:obj:`Union[List[Dict], Dict[str, torch.Tensor]]`): The post-processed next state.
        """
        if not list_next_state:
            return dict(zip(['h', 'c'], next_state))

        h, c = next_state
        batch_size = h.shape[1]
        h_parts = torch.chunk(h, batch_size, dim=1)
        c_parts = torch.chunk(c, batch_size, dim=1)
        return [{'h': h_i, 'c': c_i} for h_i, c_i in zip(h_parts, c_parts)]
class LSTM(nn.Module, LSTMForwardWrapper):
    """
    Overview:
        Multi-layer LSTM written with explicit weight parameters. In every layer the input projection and the \
        hidden projection each pass through their own normalization module (built from ``norm_type``) before \
        being summed into the gate pre-activations.
    Interfaces:
        ``__init__``, ``forward``
    .. note::
        For a primer on LSTM, refer to https://zhuanlan.zhihu.com/p/32085405.
    """

    def __init__(
            self,
            input_size: int,
            hidden_size: int,
            num_layers: int,
            norm_type: Optional[str] = None,
            dropout: float = 0.
    ) -> None:
        """
        Overview:
            Create the parameters, normalization modules and optional dropout of the LSTM, then initialize them.
        Arguments:
            - input_size (:obj:`int`): Size of the input vector.
            - hidden_size (:obj:`int`): Size of the hidden state vector.
            - num_layers (:obj:`int`): Number of LSTM layers.
            - norm_type (:obj:`Optional[str]`): Normalization type passed to ``build_normalization``, \
                default is None.
            - dropout (:obj:`float`): Dropout rate; a dropout module is only created when it is > 0. \
                Default is 0.
        """
        super(LSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # The four gates (i, f, o, u) are computed together, hence the 4x width.
        gate_width = hidden_size * 4
        norm_func = build_normalization(norm_type)
        # Two normalization modules per layer: index 2*l for the input path, 2*l + 1 for the hidden path.
        self.norm = nn.ModuleList([norm_func(gate_width) for _ in range(2 * num_layers)])

        self.wx = nn.ParameterList()
        self.wh = nn.ParameterList()
        for layer in range(num_layers):
            # Only the first layer consumes the raw input; deeper layers consume the previous layer's output.
            in_features = input_size if layer == 0 else hidden_size
            self.wx.append(nn.Parameter(torch.zeros(in_features, gate_width)))
            self.wh.append(nn.Parameter(torch.zeros(hidden_size, gate_width)))

        self.bias = nn.Parameter(torch.zeros(num_layers, gate_width))
        self.use_dropout = dropout > 0.
        if self.use_dropout:
            self.dropout = nn.Dropout(dropout)
        self._init()

    def _init(self):
        """
        Overview:
            Draw every weight and bias uniformly from ``[-sqrt(1 / hidden_size), sqrt(1 / hidden_size)]``.
        """
        bound = math.sqrt(1. / self.hidden_size)
        for layer in range(self.num_layers):
            for weight in (self.wx[layer], self.wh[layer]):
                torch.nn.init.uniform_(weight, -bound, bound)
            if self.bias is not None:
                torch.nn.init.uniform_(self.bias[layer], -bound, bound)

    def forward(self,
                inputs: torch.Tensor,
                prev_state: torch.Tensor,
                list_next_state: bool = True) -> Tuple[torch.Tensor, Union[torch.Tensor, list]]:
        """
        Overview:
            Run the LSTM over the whole input sequence, layer by layer and step by step.
        Arguments:
            - inputs (:obj:`torch.Tensor`): Input vector of cell, size [seq_len, batch_size, input_size].
            - prev_state (:obj:`torch.Tensor`): Previous state, \
                size [num_directions*num_layers, batch_size, hidden_size]. It is first normalized by \
                ``_before_forward``.
            - list_next_state (:obj:`bool`): Whether to return next_state in list format, default is True.
        Returns:
            - x (:obj:`torch.Tensor`): Hidden outputs of the last layer, stacked over the time steps.
            - next_state (:obj:`Union[torch.Tensor, list]`): Final ``(h, c)`` of every layer, formatted by \
                ``_after_forward``.
        """
        seq_len, _ = inputs.shape[:2]
        H, C = self._before_forward(inputs, prev_state)

        layer_input = inputs
        final_states = []
        for layer in range(self.num_layers):
            h, c = H[layer], C[layer]
            norm_x, norm_h = self.norm[layer * 2], self.norm[layer * 2 + 1]
            w_x, w_h = self.wx[layer], self.wh[layer]

            step_outputs = []
            for step in range(seq_len):
                gate = norm_x(torch.matmul(layer_input[step], w_x)) + norm_h(torch.matmul(h, w_h))
                if self.bias is not None:
                    gate += self.bias[layer]
                # Split the joint pre-activation into input, forget, output and candidate parts.
                i, f, o, u = torch.chunk(gate, 4, dim=1)
                i = torch.sigmoid(i)
                f = torch.sigmoid(f)
                o = torch.sigmoid(o)
                u = torch.tanh(u)
                c = f * c + i * u
                h = o * torch.tanh(c)
                step_outputs.append(h)

            # Keep the state reached after the last time step of this layer.
            final_states.append((h, c))
            layer_input = torch.stack(step_outputs, dim=0)
            # Dropout is applied between layers only, never on the output of the last layer.
            if self.use_dropout and layer != self.num_layers - 1:
                layer_input = self.dropout(layer_input)

        # Regroup [(h_0, c_0), (h_1, c_1), ...] into [stacked h, stacked c] along a new leading layer dim.
        stacked_state = [torch.stack(parts, dim=0) for parts in zip(*final_states)]
        next_state = self._after_forward(stacked_state, list_next_state)
        return layer_input, next_state
class PytorchLSTM(nn.LSTM, LSTMForwardWrapper):
    """
    Overview:
        Thin wrapper around PyTorch's ``nn.LSTM`` that converts the previous state with ``_before_forward`` \
        and the returned state with ``_after_forward``. For more details on nn.LSTM, refer to \
        https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html#torch.nn.LSTM
    Interfaces:
        ``forward``
    """

    def forward(self,
                inputs: torch.Tensor,
                prev_state: torch.Tensor,
                list_next_state: bool = True) -> Tuple[torch.Tensor, Union[torch.Tensor, list]]:
        """
        Overview:
            Call ``nn.LSTM.forward`` with the preprocessed previous state and post-process the new state.
        Arguments:
            - inputs (:obj:`torch.Tensor`): Input vector of cell, size [seq_len, batch_size, input_size].
            - prev_state (:obj:`torch.Tensor`): Previous state, size [num_directions*num_layers, batch_size, \
                hidden_size]. It is first normalized by ``_before_forward``.
            - list_next_state (:obj:`bool`): Whether to return next_state in list format, default is True.
        Returns:
            - output (:obj:`torch.Tensor`): Output from LSTM.
            - next_state (:obj:`Union[torch.Tensor, list]`): Hidden state from LSTM, formatted by \
                ``_after_forward``.
        """
        state_in = self._before_forward(inputs, prev_state)
        output, state_out = nn.LSTM.forward(self, inputs, state_in)
        return output, self._after_forward(state_out, list_next_state)
class GRU(nn.GRUCell, LSTMForwardWrapper):
    """
    Overview:
        Wrapper around ``torch.nn.GRUCell`` that accepts and returns states in the same format as the LSTM \
        classes in this file, using the ``LSTMForwardWrapper`` hooks.
    Interfaces:
        ``__init__``, ``forward``
    Properties:
        hidden_size, num_layers
    .. note::
        For further details, refer to the official PyTorch documentation:
        <https://pytorch.org/docs/stable/generated/torch.nn.GRUCell.html#torch.nn.GRUCell>
    """

    def __init__(self, input_size: int, hidden_size: int, num_layers: int) -> None:
        """
        Overview:
            Initialize the underlying ``nn.GRUCell`` and record the attributes read by ``_before_forward``.
        Arguments:
            - input_size (:obj:`int`): The size of the input vector.
            - hidden_size (:obj:`int`): The size of the hidden state vector.
            - num_layers (:obj:`int`): The number of GRU layers. It is only stored on the instance; the \
                underlying ``nn.GRUCell`` is built from ``input_size`` and ``hidden_size`` alone.
        """
        super(GRU, self).__init__(input_size, hidden_size)
        self.hidden_size = hidden_size
        self.num_layers = num_layers

    def forward(self,
                inputs: torch.Tensor,
                prev_state: Optional[torch.Tensor] = None,
                list_next_state: bool = True) -> Tuple[torch.Tensor, Union[torch.Tensor, List]]:
        """
        Overview:
            Run one ``nn.GRUCell.forward`` step with LSTM-style state handling.
        Arguments:
            - inputs (:obj:`torch.Tensor`): Input vector of cell, tensor of size [seq_len, batch_size, input_size].
            - prev_state (:obj:`Optional[torch.Tensor]`): None or tensor of \
                size [num_directions*num_layers, batch_size, hidden_size].
            - list_next_state (:obj:`bool`): Whether to return next_state in list format (default is True).
        Returns:
            - output (:obj:`torch.Tensor`): Output from GRU, i.e. the new hidden state with a leading dim of 1.
            - next_state (:obj:`torch.Tensor` or :obj:`list`): Hidden state from GRU, formatted by \
                ``_after_forward``.
        """
        # ``_before_forward`` yields an LSTM-style (h, c) pair; a GRU only has a hidden state, so c is dropped.
        hidden, _ = self._before_forward(inputs, prev_state)
        # NOTE(review): squeeze(0) only removes the leading dim when it has size 1, so this presumably expects
        # seq_len == 1 and num_layers == 1 -- confirm against callers.
        step_input = inputs.squeeze(0)
        hidden = hidden.squeeze(0)
        new_hidden = nn.GRUCell.forward(self, step_input, hidden).unsqueeze(0)
        # ``_after_forward`` expects an (h, c) pair, so a copy of the hidden state is supplied in the c slot.
        next_state = self._after_forward([new_hidden, new_hidden.clone()], list_next_state)
        return new_hidden, next_state
def get_lstm(
        lstm_type: str,
        input_size: int,
        hidden_size: int,
        num_layers: int = 1,
        norm_type: str = 'LN',
        dropout: float = 0.,
        seq_len: Optional[int] = None,
        batch_size: Optional[int] = None
) -> Union[LSTM, PytorchLSTM]:
    """
    Overview:
        Build and return the RNN module selected by ``lstm_type``.
    Arguments:
        - lstm_type (:obj:`str`): Version of RNN cell. Supported options are ['normal', 'pytorch', 'hpc', 'gru'].
        - input_size (:obj:`int`): Size of the input vector.
        - hidden_size (:obj:`int`): Size of the hidden state vector.
        - num_layers (:obj:`int`): Number of layers (default is 1). Must be 1 for 'gru'.
        - norm_type (:obj:`str`): Type of normalization (default is 'LN'). Only passed to 'normal' and 'hpc'.
        - dropout (:obj:`float`): Dropout rate (default is 0.0). Not passed to 'gru'.
        - seq_len (:obj:`Optional[int]`): Sequence length (default is None). Only passed to 'hpc'.
        - batch_size (:obj:`Optional[int]`): Batch size (default is None). Only passed to 'hpc'.
    Returns:
        - lstm: ``LSTM`` for 'normal', ``PytorchLSTM`` for 'pytorch', ``HPCLSTM`` moved to CUDA for 'hpc', \
            ``GRU`` for 'gru'.
    Raises:
        - AssertionError: If ``lstm_type`` is not supported, or ``num_layers != 1`` with 'gru'.
        - RuntimeError: If 'hpc' is requested while the HPC LSTM implementation is unavailable.
    """
    supported_types = ['normal', 'pytorch', 'hpc', 'gru']
    assert lstm_type in supported_types, \
        "unsupported lstm_type: {}, expected one of {}".format(lstm_type, supported_types)

    if lstm_type == 'normal':
        return LSTM(input_size, hidden_size, num_layers, norm_type, dropout=dropout)
    if lstm_type == 'pytorch':
        return PytorchLSTM(input_size, hidden_size, num_layers, dropout=dropout)
    if lstm_type == 'hpc':
        # HPCLSTM is None when the HPC extension was not imported; fail with a clear message instead of the
        # opaque "'NoneType' object is not callable" that calling it would otherwise produce.
        if HPCLSTM is None:
            raise RuntimeError(
                "lstm_type 'hpc' is unavailable: HPCLSTM was not imported (ding.enable_hpc_rl is disabled)"
            )
        return HPCLSTM(seq_len, batch_size, input_size, hidden_size, num_layers, norm_type, dropout).cuda()
    if lstm_type == 'gru':
        assert num_layers == 1, "GRU only supports num_layers == 1, got {}".format(num_layers)
        return GRU(input_size, hidden_size, num_layers)