from typing import (
    Tuple,
    List,
    Optional,
    Dict,
    Callable,
    Union,
    cast,
)
from collections import namedtuple
from abc import ABC, abstractmethod
from dataclasses import dataclass
import numpy as np
import torch as T
from torch import nn
from torch.nn import functional as F
from torch import Tensor
from .rnn_base import (
    IRecurrentCell,
    IRecurrentCellBuilder,
    RecurrentLayer,
    RecurrentLayerStack,
)
__all__ = [
    'K_LSTM',
    'K_LSTM_Cell',
    'K_LSTM_Cell_Builder',
]
ACTIVATIONS = {
    'sigmoid': nn.Sigmoid(),
    'tanh': nn.Tanh(),
    'hard_tanh': nn.Hardtanh(),
    'relu': nn.ReLU(),
}
GateSpans = namedtuple('GateSpans', ['I', 'F', 'G', 'O'])
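# GateSpans bundles per-gate slices of the fused kernels; the field order
# (I, F, G, O) matches the chunk order used throughout this module.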
@dataclass
class K_LSTM_Cell_Builder(IRecurrentCellBuilder):
    vertical_dropout            : float = 0.0
    recurrent_dropout           : float = 0.0
    recurrent_dropout_mode      : str   = 'gal_tied'
    input_kernel_initialization : str   = 'xavier_uniform'
    recurrent_activation        : Union[str, nn.Module] = 'sigmoid'
    tied_forget_gate            : bool  = False

    def make(self, input_size: int) -> 'K_LSTM_Cell':
        return K_LSTM_Cell(input_size, self)
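# Construction sketch (assumes IRecurrentCellBuilder is itself a dataclass that
# exposes a `hidden_size` field, since the cell reads `args.hidden_size` below):
#
#   builder = K_LSTM_Cell_Builder(hidden_size=256, recurrent_dropout=0.25)
#   cell    = builder.make(input_size=128)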
class K_LSTM_Cell(IRecurrentCell):
    def __repr__(self):
        return (
            f'{self.__class__.__name__}('
            + ', '.join(
                [
                    f'in: {self.Dx}',
                    f'hid: {self.Dh}',
                    f'rdo: {self.recurrent_dropout_p} @{self.recurrent_dropout_mode}',
                    f'vdo: {self.vertical_dropout_p}',
                ]
            )
            + ')'
        )
    def __init__(
        self,
        input_size: int,
        args: K_LSTM_Cell_Builder,
    ):
        super().__init__()
        self._args = args
        self.Dx = input_size
        self.Dh = args.hidden_size

        self.recurrent_kernel = nn.Linear(self.Dh, self.Dh * 4)
        self.input_kernel = nn.Linear(self.Dx, self.Dh * 4)

        self.recurrent_dropout_p = args.recurrent_dropout or 0.0
        self.vertical_dropout_p = args.vertical_dropout or 0.0
        self.recurrent_dropout_mode = args.recurrent_dropout_mode
        self.recurrent_dropout = nn.Dropout(self.recurrent_dropout_p)
        self.vertical_dropout = nn.Dropout(self.vertical_dropout_p)

        self.tied_forget_gate = args.tied_forget_gate

        if isinstance(args.recurrent_activation, str):
            self.fun_rec = ACTIVATIONS[args.recurrent_activation]
        else:
            self.fun_rec = args.recurrent_activation

        self.reset_parameters_()
    # @T.jit.ignore
    def get_recurrent_weights(self):
        # type: () -> Tuple[GateSpans, GateSpans]
        W = self.recurrent_kernel.weight.chunk(4, 0)
        b = self.recurrent_kernel.bias.chunk(4, 0)
        W = GateSpans(W[0], W[1], W[2], W[3])
        b = GateSpans(b[0], b[1], b[2], b[3])
        return W, b

    # @T.jit.ignore
    def get_input_weights(self):
        # type: () -> Tuple[GateSpans, GateSpans]
        W = self.input_kernel.weight.chunk(4, 0)
        b = self.input_kernel.bias.chunk(4, 0)
        W = GateSpans(W[0], W[1], W[2], W[3])
        b = GateSpans(b[0], b[1], b[2], b[3])
        return W, b
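    # The fused kernels have weight shape [4*Dh, Dh] (recurrent) and [4*Dh, Dx]
    # (input); chunking along dim 0 recovers the per-gate blocks in I, F, G, O
    # order, as named by GateSpans.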
    @T.jit.ignore
    def reset_parameters_(self):
        rw, rb = self.get_recurrent_weights()
        iw, ib = self.get_input_weights()

        nn.init.zeros_(self.input_kernel.bias)
        nn.init.zeros_(self.recurrent_kernel.bias)
        nn.init.ones_(rb.F)
        #^ forget bias

        for W in rw:
            nn.init.orthogonal_(W)
        for W in iw:
            nn.init.xavier_uniform_(W)
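    # Initialization scheme: zero biases with the forget-gate bias set to 1,
    # orthogonal recurrent weights, Xavier-uniform input weights. Note that the
    # builder's `input_kernel_initialization` field is not consulted here.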
    @T.jit.export
    def get_init_state(self, input: Tensor) -> Tuple[Tensor, Tensor]:
        batch_size = input.shape[1]
        h0 = T.zeros(batch_size, self.Dh, device=input.device)
        c0 = T.zeros(batch_size, self.Dh, device=input.device)
        return (h0, c0)
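    # `input` is expected to be time-major, i.e. [t b i]: the batch size is
    # read from dim 1, matching the per-step [b i] tensors consumed by `loop`.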
    def apply_input_kernel(self, xt: Tensor) -> List[Tensor]:
        xto = self.vertical_dropout(xt)
        out = self.input_kernel(xto).chunk(4, 1)
        # return cast(List[Tensor], out)
        return out
    def apply_recurrent_kernel(self, h_tm1: Tensor):
        #^ h_tm1 : [b h]
        mode = self.recurrent_dropout_mode
        if mode == 'gal_tied':
            hto = self.recurrent_dropout(h_tm1)
            out = self.recurrent_kernel(hto)
            #^ out : [b 4h]
            outs = out.chunk(4, -1)
        elif mode == 'gal_gates':
            outs = []
            WW, bb = self.get_recurrent_weights()
            for i in range(4):
                hto = self.recurrent_dropout(h_tm1)
                outs.append(F.linear(hto, WW[i], bb[i]))
        else:
            outs = self.recurrent_kernel(h_tm1).chunk(4, -1)
        return outs
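    # Recurrent dropout modes:
    #   'gal_tied'  : one dropout mask on h_{t-1}, shared by all four gates of
    #                 the fused projection.
    #   'gal_gates' : an independent dropout mask is sampled for each gate's
    #                 projection.
    #   'semeniuta' : handled in forward() instead, by dropping the candidate
    #                 g_t (https://arxiv.org/abs/1603.05118).
    #   anything else: no recurrent dropout is applied to h_{t-1}.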
    def forward(self, input, state):
        # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
        #^ input : [b i]
        #^ state.h : [b h]
        (h_tm1, c_tm1) = state

        Xi, Xf, Xg, Xo = self.apply_input_kernel(input)
        Hi, Hf, Hg, Ho = self.apply_recurrent_kernel(h_tm1)

        ft = self.fun_rec(Xf + Hf)
        ot = self.fun_rec(Xo + Ho)
        if self.tied_forget_gate:
            it = 1.0 - ft
        else:
            it = self.fun_rec(Xi + Hi)

        gt = T.tanh(Xg + Hg)  # * np.sqrt(3)
        if self.recurrent_dropout_mode == 'semeniuta':
            #* https://arxiv.org/abs/1603.05118
            gt = self.recurrent_dropout(gt)

        ct = (ft * c_tm1) + (it * gt)
        ht = ot * T.tanh(ct)
        return ht, (ht, ct)
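    # Per-step update implemented above (sigma = fun_rec, sigmoid by default):
    #   f_t = sigma(W_f x_t + U_f h_{t-1} + b_f)
    #   i_t = 1 - f_t if tied_forget_gate else sigma(W_i x_t + U_i h_{t-1} + b_i)
    #   o_t = sigma(W_o x_t + U_o h_{t-1} + b_o)
    #   g_t = tanh(W_g x_t + U_g h_{t-1} + b_g)
    #   c_t = f_t * c_{t-1} + i_t * g_t
    #   h_t = o_t * tanh(c_t)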
    @T.jit.export
    def loop(self, inputs, state_t0, mask=None):
        # type: (List[Tensor], Tuple[Tensor, Tensor], Optional[List[Tensor]]) -> Tuple[List[Tensor], Tuple[Tensor, Tensor]]
        '''
        Loops over the t (time) steps in `inputs`, a list of per-step [b i]
        tensors (e.g. a [t b i] tensor unbound along dim 0).
        `mask` is accepted for interface compatibility but is not used here.
        '''
        #^ inputs : t * [b i]
        #^ state_t0[i] : [b s]
        #^ out : [t b h]
        state = state_t0
        outs = []
        for xt in inputs:
            ht, state = self(xt, state)
            outs.append(ht)
        return outs, state
class K_LSTM(RecurrentLayerStack):
    def __init__(
        self,
        *args,
        **kwargs,
    ):
        builder = K_LSTM_Cell_Builder
        super().__init__(
            builder,
            *args, **kwargs,
        )
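# Minimal usage sketch. This assumes RecurrentLayerStack instantiates the
# builder from the forwarded *args/**kwargs (e.g. hidden_size, num_layers);
# the exact signature lives in rnn_base and should be checked there.
#
#   rnn = K_LSTM(input_size=128, hidden_size=256, num_layers=2,
#                recurrent_dropout=0.25, recurrent_dropout_mode='gal_tied')
#   x = T.randn(50, 32, 128)   # [t b i], time-major
#   y, state = rnn(x)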