|
""" SRU Implementation """ |
|
|
|
|
|
import subprocess |
|
import platform |
|
import os |
|
import re |
|
import configargparse |
|
import torch |
|
import torch.nn as nn |
|
from torch.autograd import Function |
|
from torch.cuda.amp import custom_fwd, custom_bwd |
|
from collections import namedtuple |
|
|
|
|
|
|
|
class CheckSRU(configargparse.Action): |
|
def __init__(self, option_strings, dest, **kwargs): |
|
super(CheckSRU, self).__init__(option_strings, dest, **kwargs) |
|
|
|
def __call__(self, parser, namespace, values, option_string=None): |
|
if values == 'SRU': |
|
check_sru_requirement(abort=True) |
|
|
|
setattr(namespace, self.dest, values) |
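
# A minimal usage sketch (the option name below is illustrative, not defined in
# this module): registering CheckSRU as the argparse action makes '-rnn_type SRU'
# trigger check_sru_requirement() at option-parsing time.
#
#     parser = configargparse.ArgumentParser()
#     parser.add_argument('-rnn_type', default='LSTM',
#                         choices=['LSTM', 'GRU', 'SRU'], action=CheckSRU)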
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def check_sru_requirement(abort=False): |
|
""" |
|
    Return True if the check passes; if the check fails and abort is True,
    raise an AssertionError, otherwise return False.
|
""" |
|
|
|
|
|
try: |
|
if platform.system() == 'Windows': |
|
subprocess.check_output('pip freeze | findstr cupy', shell=True) |
|
subprocess.check_output('pip freeze | findstr pynvrtc', |
|
shell=True) |
|
else: |
|
subprocess.check_output('pip freeze | grep -w cupy', shell=True) |
|
subprocess.check_output('pip freeze | grep -w pynvrtc', |
|
shell=True) |
|
except subprocess.CalledProcessError: |
|
if not abort: |
|
return False |
|
        raise AssertionError("Using SRU requires the 'cupy' and 'pynvrtc' "
                             "Python packages to be installed.")
|
|
|
|
|
    if not torch.cuda.is_available():
|
if not abort: |
|
return False |
|
        raise AssertionError("Using SRU requires PyTorch built with CUDA.")
|
|
|
|
|
pattern = re.compile(".*cuda/lib.*") |
|
ld_path = os.getenv('LD_LIBRARY_PATH', "") |
|
if re.match(pattern, ld_path) is None: |
|
if not abort: |
|
return False |
|
        raise AssertionError("Using SRU requires setting the CUDA lib path, "
                             "e.g. export LD_LIBRARY_PATH=/usr/local/cuda/lib64.")
|
|
|
return True |
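
# check_sru_requirement() can also be used as a soft probe (it returns a bool);
# SRU.__init__ below calls it with abort=True to fail fast when the
# requirements are missing.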
|
|
|
|
|
SRU_CODE = """ |
|
extern "C" { |
|
__forceinline__ __device__ float sigmoidf(float x) |
|
{ |
|
return 1.f / (1.f + expf(-x)); |
|
} |
|
__forceinline__ __device__ float reluf(float x) |
|
{ |
|
return (x > 0.f) ? x : 0.f; |
|
} |
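    /*
     * Forward pass: one thread per (batch, feature) column, looping over time.
     * Per step (highway formulation):
     *   c_t = g1 * c_{t-1} + (1 - g1) * u0_t,          g1 = sigmoid(u1_t + bias1)
     *   h_t = g2 * act(c_t) * mask + (1 - g2) * x_t,   g2 = sigmoid(u2_t + bias2)
     * With k == 3 the highway term x is read from the input tensor; with k == 4
     * it is the fourth column of u.
     */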
|
__global__ void sru_fwd(const float * __restrict__ u, |
|
const float * __restrict__ x, |
|
const float * __restrict__ bias, |
|
const float * __restrict__ init, |
|
const float * __restrict__ mask_h, |
|
const int len, const int batch, |
|
const int d, const int k, |
|
float * __restrict__ h, |
|
float * __restrict__ c, |
|
const int activation_type) |
|
{ |
|
assert ((k == 3) || (x == NULL)); |
|
int ncols = batch*d; |
|
int col = blockIdx.x * blockDim.x + threadIdx.x; |
|
if (col >= ncols) return; |
|
int ncols_u = ncols*k; |
|
int ncols_x = (k == 3) ? ncols : ncols_u; |
|
const float bias1 = *(bias + (col%d)); |
|
const float bias2 = *(bias + (col%d) + d); |
|
const float mask = (mask_h == NULL) ? 1.0 : (*(mask_h + col)); |
|
float cur = *(init + col); |
|
const float *up = u + (col*k); |
|
const float *xp = (k == 3) ? (x + col) : (up + 3); |
|
float *cp = c + col; |
|
float *hp = h + col; |
|
for (int row = 0; row < len; ++row) |
|
{ |
|
float g1 = sigmoidf((*(up+1))+bias1); |
|
float g2 = sigmoidf((*(up+2))+bias2); |
|
cur = (cur-(*up))*g1 + (*up); |
|
*cp = cur; |
|
float val = (activation_type == 1) ? tanh(cur) : ( |
|
(activation_type == 2) ? reluf(cur) : cur |
|
); |
|
*hp = (val*mask-(*xp))*g2 + (*xp); |
|
up += ncols_u; |
|
xp += ncols_x; |
|
cp += ncols; |
|
hp += ncols; |
|
} |
|
} |
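    /*
     * Backward pass: walks time in reverse, recomputing the gates from u and
     * bias, and writes gradients for u, x (only when k == 3), the two bias
     * vectors and the initial state. 'cur' carries the running gradient with
     * respect to the previous cell state.
     */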
|
__global__ void sru_bwd(const float * __restrict__ u, |
|
const float * __restrict__ x, |
|
const float * __restrict__ bias, |
|
const float * __restrict__ init, |
|
const float * __restrict__ mask_h, |
|
const float * __restrict__ c, |
|
const float * __restrict__ grad_h, |
|
const float * __restrict__ grad_last, |
|
const int len, |
|
const int batch, const int d, const int k, |
|
float * __restrict__ grad_u, |
|
float * __restrict__ grad_x, |
|
float * __restrict__ grad_bias, |
|
float * __restrict__ grad_init, |
|
int activation_type) |
|
{ |
|
assert((k == 3) || (x == NULL)); |
|
assert((k == 3) || (grad_x == NULL)); |
|
int ncols = batch*d; |
|
int col = blockIdx.x * blockDim.x + threadIdx.x; |
|
if (col >= ncols) return; |
|
int ncols_u = ncols*k; |
|
int ncols_x = (k == 3) ? ncols : ncols_u; |
|
const float bias1 = *(bias + (col%d)); |
|
const float bias2 = *(bias + (col%d) + d); |
|
const float mask = (mask_h == NULL) ? 1.0 : (*(mask_h + col)); |
|
float gbias1 = 0; |
|
float gbias2 = 0; |
|
float cur = *(grad_last + col); |
|
const float *up = u + (col*k) + (len-1)*ncols_u; |
|
const float *xp = (k == 3) ? (x + col + (len-1)*ncols) : (up + 3); |
|
const float *cp = c + col + (len-1)*ncols; |
|
const float *ghp = grad_h + col + (len-1)*ncols; |
|
float *gup = grad_u + (col*k) + (len-1)*ncols_u; |
|
float *gxp = (k == 3) ? (grad_x + col + (len-1)*ncols) : (gup + 3); |
|
for (int row = len-1; row >= 0; --row) |
|
{ |
|
const float g1 = sigmoidf((*(up+1))+bias1); |
|
const float g2 = sigmoidf((*(up+2))+bias2); |
|
const float c_val = (activation_type == 1) ? tanh(*cp) : ( |
|
(activation_type == 2) ? reluf(*cp) : (*cp) |
|
); |
|
const float x_val = *xp; |
|
const float u_val = *up; |
|
const float prev_c_val = (row>0) ? (*(cp-ncols)) : (*(init+col)); |
|
const float gh_val = *ghp; |
|
// h = c*g2 + x*(1-g2) = (c-x)*g2 + x |
|
// c = c'*g1 + g0*(1-g1) = (c'-g0)*g1 + g0 |
|
// grad wrt x |
|
*gxp = gh_val*(1-g2); |
|
// grad wrt g2, u2 and bias2 |
|
float gg2 = gh_val*(c_val*mask-x_val)*(g2*(1-g2)); |
|
*(gup+2) = gg2; |
|
gbias2 += gg2; |
|
// grad wrt c |
|
const float tmp = (activation_type == 1) ? (g2*(1-c_val*c_val)) : ( |
|
((activation_type == 0) || (c_val > 0)) ? g2 : 0.f |
|
); |
|
const float gc = gh_val*mask*tmp + cur; |
|
// grad wrt u0 |
|
*gup = gc*(1-g1); |
|
// grad wrt g1, u1, and bias1 |
|
float gg1 = gc*(prev_c_val-u_val)*(g1*(1-g1)); |
|
*(gup+1) = gg1; |
|
gbias1 += gg1; |
|
// grad wrt c' |
|
cur = gc*g1; |
|
up -= ncols_u; |
|
xp -= ncols_x; |
|
cp -= ncols; |
|
gup -= ncols_u; |
|
gxp -= ncols_x; |
|
ghp -= ncols; |
|
} |
|
*(grad_bias + col) = gbias1; |
|
*(grad_bias + col + ncols) = gbias2; |
|
*(grad_init +col) = cur; |
|
} |
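    /*
     * Bidirectional forward pass: columns with (col % (2*d)) >= d belong to the
     * reverse direction and are traversed from the last time step backwards;
     * the per-step update is identical to sru_fwd.
     */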
|
__global__ void sru_bi_fwd(const float * __restrict__ u, |
|
const float * __restrict__ x, |
|
const float * __restrict__ bias, |
|
const float * __restrict__ init, |
|
const float * __restrict__ mask_h, |
|
const int len, const int batch, |
|
const int d, const int k, |
|
float * __restrict__ h, |
|
float * __restrict__ c, |
|
const int activation_type) |
|
{ |
|
assert ((k == 3) || (x == NULL)); |
|
assert ((k == 3) || (k == 4)); |
|
int ncols = batch*d*2; |
|
int col = blockIdx.x * blockDim.x + threadIdx.x; |
|
if (col >= ncols) return; |
|
int ncols_u = ncols*k; |
|
int ncols_x = (k == 3) ? ncols : ncols_u; |
|
const float mask = (mask_h == NULL) ? 1.0 : (*(mask_h + col)); |
|
float cur = *(init + col); |
|
const int d2 = d*2; |
|
const bool flip = (col%d2) >= d; |
|
const float bias1 = *(bias + (col%d2)); |
|
const float bias2 = *(bias + (col%d2) + d2); |
|
const float *up = u + (col*k); |
|
const float *xp = (k == 3) ? (x + col) : (up + 3); |
|
float *cp = c + col; |
|
float *hp = h + col; |
|
if (flip) { |
|
up += (len-1)*ncols_u; |
|
xp += (len-1)*ncols_x; |
|
cp += (len-1)*ncols; |
|
hp += (len-1)*ncols; |
|
} |
|
int ncols_u_ = flip ? -ncols_u : ncols_u; |
|
int ncols_x_ = flip ? -ncols_x : ncols_x; |
|
int ncols_ = flip ? -ncols : ncols; |
|
for (int cnt = 0; cnt < len; ++cnt) |
|
{ |
|
float g1 = sigmoidf((*(up+1))+bias1); |
|
float g2 = sigmoidf((*(up+2))+bias2); |
|
cur = (cur-(*up))*g1 + (*up); |
|
*cp = cur; |
|
float val = (activation_type == 1) ? tanh(cur) : ( |
|
(activation_type == 2) ? reluf(cur) : cur |
|
); |
|
*hp = (val*mask-(*xp))*g2 + (*xp); |
|
up += ncols_u_; |
|
xp += ncols_x_; |
|
cp += ncols_; |
|
hp += ncols_; |
|
} |
|
} |
|
__global__ void sru_bi_bwd(const float * __restrict__ u, |
|
const float * __restrict__ x, |
|
const float * __restrict__ bias, |
|
const float * __restrict__ init, |
|
const float * __restrict__ mask_h, |
|
const float * __restrict__ c, |
|
const float * __restrict__ grad_h, |
|
const float * __restrict__ grad_last, |
|
const int len, const int batch, |
|
const int d, const int k, |
|
float * __restrict__ grad_u, |
|
float * __restrict__ grad_x, |
|
float * __restrict__ grad_bias, |
|
float * __restrict__ grad_init, |
|
int activation_type) |
|
{ |
|
assert((k == 3) || (x == NULL)); |
|
assert((k == 3) || (grad_x == NULL)); |
|
assert((k == 3) || (k == 4)); |
|
int ncols = batch*d*2; |
|
int col = blockIdx.x * blockDim.x + threadIdx.x; |
|
if (col >= ncols) return; |
|
int ncols_u = ncols*k; |
|
int ncols_x = (k == 3) ? ncols : ncols_u; |
|
const float mask = (mask_h == NULL) ? 1.0 : (*(mask_h + col)); |
|
float gbias1 = 0; |
|
float gbias2 = 0; |
|
float cur = *(grad_last + col); |
|
const int d2 = d*2; |
|
const bool flip = ((col%d2) >= d); |
|
const float bias1 = *(bias + (col%d2)); |
|
const float bias2 = *(bias + (col%d2) + d2); |
|
const float *up = u + (col*k); |
|
const float *xp = (k == 3) ? (x + col) : (up + 3); |
|
const float *cp = c + col; |
|
const float *ghp = grad_h + col; |
|
float *gup = grad_u + (col*k); |
|
float *gxp = (k == 3) ? (grad_x + col) : (gup + 3); |
|
if (!flip) { |
|
up += (len-1)*ncols_u; |
|
xp += (len-1)*ncols_x; |
|
cp += (len-1)*ncols; |
|
ghp += (len-1)*ncols; |
|
gup += (len-1)*ncols_u; |
|
gxp += (len-1)*ncols_x; |
|
} |
|
int ncols_u_ = flip ? -ncols_u : ncols_u; |
|
int ncols_x_ = flip ? -ncols_x : ncols_x; |
|
int ncols_ = flip ? -ncols : ncols; |
|
for (int cnt = 0; cnt < len; ++cnt) |
|
{ |
|
const float g1 = sigmoidf((*(up+1))+bias1); |
|
const float g2 = sigmoidf((*(up+2))+bias2); |
|
const float c_val = (activation_type == 1) ? tanh(*cp) : ( |
|
(activation_type == 2) ? reluf(*cp) : (*cp) |
|
); |
|
const float x_val = *xp; |
|
const float u_val = *up; |
|
const float prev_c_val = (cnt<len-1)?(*(cp-ncols_)):(*(init+col)); |
|
const float gh_val = *ghp; |
|
// h = c*g2 + x*(1-g2) = (c-x)*g2 + x |
|
// c = c'*g1 + g0*(1-g1) = (c'-g0)*g1 + g0 |
|
// grad wrt x |
|
*gxp = gh_val*(1-g2); |
|
// grad wrt g2, u2 and bias2 |
|
float gg2 = gh_val*(c_val*mask-x_val)*(g2*(1-g2)); |
|
*(gup+2) = gg2; |
|
gbias2 += gg2; |
|
// grad wrt c |
|
const float tmp = (activation_type == 1) ? (g2*(1-c_val*c_val)) : ( |
|
((activation_type == 0) || (c_val > 0)) ? g2 : 0.f |
|
); |
|
const float gc = gh_val*mask*tmp + cur; |
|
// grad wrt u0 |
|
*gup = gc*(1-g1); |
|
// grad wrt g1, u1, and bias1 |
|
float gg1 = gc*(prev_c_val-u_val)*(g1*(1-g1)); |
|
*(gup+1) = gg1; |
|
gbias1 += gg1; |
|
// grad wrt c' |
|
cur = gc*g1; |
|
up -= ncols_u_; |
|
xp -= ncols_x_; |
|
cp -= ncols_; |
|
gup -= ncols_u_; |
|
gxp -= ncols_x_; |
|
ghp -= ncols_; |
|
} |
|
*(grad_bias + col) = gbias1; |
|
*(grad_bias + col + ncols) = gbias2; |
|
*(grad_init +col) = cur; |
|
} |
|
} |
|
""" |
|
SRU_FWD_FUNC, SRU_BWD_FUNC = None, None |
|
SRU_BiFWD_FUNC, SRU_BiBWD_FUNC = None, None |
|
SRU_STREAM = None |
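# The kernel handles and the CUDA stream above are populated lazily by
# load_sru_mod(), which runs the first time an SRU_Compute is constructed
# (see SRU_Compute.maybe_load_sru_mod).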
|
|
|
|
|
def load_sru_mod(): |
|
global SRU_FWD_FUNC, SRU_BWD_FUNC, SRU_BiFWD_FUNC, SRU_BiBWD_FUNC |
|
global SRU_STREAM |
|
if check_sru_requirement(): |
|
from cupy.cuda import function |
|
from pynvrtc.compiler import Program |
|
|
|
|
|
device = torch.device("cuda") |
|
tmp_ = torch.rand(1, 1).to(device) |
|
|
|
sru_prog = Program(SRU_CODE.encode('utf-8'), |
|
'sru_prog.cu'.encode('utf-8')) |
|
sru_ptx = sru_prog.compile() |
|
sru_mod = function.Module() |
|
sru_mod.load(bytes(sru_ptx.encode())) |
|
|
|
SRU_FWD_FUNC = sru_mod.get_function('sru_fwd') |
|
SRU_BWD_FUNC = sru_mod.get_function('sru_bwd') |
|
SRU_BiFWD_FUNC = sru_mod.get_function('sru_bi_fwd') |
|
SRU_BiBWD_FUNC = sru_mod.get_function('sru_bi_bwd') |
|
|
|
stream = namedtuple('Stream', ['ptr']) |
|
SRU_STREAM = stream(ptr=torch.cuda.current_stream().cuda_stream) |
|
|
|
|
|
class SRU_Compute(Function): |
|
|
|
def __init__(self, activation_type, d_out, bidirectional=False): |
|
SRU_Compute.maybe_load_sru_mod() |
|
super(SRU_Compute, self).__init__() |
|
self.activation_type = activation_type |
|
self.d_out = d_out |
|
self.bidirectional = bidirectional |
|
|
|
@staticmethod |
|
def maybe_load_sru_mod(): |
|
global SRU_FWD_FUNC |
|
|
|
if SRU_FWD_FUNC is None: |
|
load_sru_mod() |
|
|
|
@custom_fwd |
|
def forward(self, u, x, bias, init=None, mask_h=None): |
|
bidir = 2 if self.bidirectional else 1 |
|
length = x.size(0) if x.dim() == 3 else 1 |
|
batch = x.size(-2) |
|
d = self.d_out |
|
k = u.size(-1) // d |
|
k_ = k // 2 if self.bidirectional else k |
|
ncols = batch * d * bidir |
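        # launch one CUDA thread per (batch x feature x direction) column; the
        # kernel itself loops over the time dimension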
|
thread_per_block = min(512, ncols) |
|
num_block = (ncols - 1) // thread_per_block + 1 |
|
|
|
init_ = x.new(ncols).zero_() if init is None else init |
|
size = (length, batch, d * bidir) if x.dim() == 3 else (batch, d * bidir) |
|
c = x.new(*size) |
|
h = x.new(*size) |
|
|
|
FUNC = SRU_FWD_FUNC if not self.bidirectional else SRU_BiFWD_FUNC |
|
FUNC(args=[ |
|
u.contiguous().data_ptr(), |
|
x.contiguous().data_ptr() if k_ == 3 else 0, |
|
bias.data_ptr(), |
|
init_.contiguous().data_ptr(), |
|
mask_h.data_ptr() if mask_h is not None else 0, |
|
length, |
|
batch, |
|
d, |
|
k_, |
|
h.data_ptr(), |
|
c.data_ptr(), |
|
self.activation_type], |
|
block=(thread_per_block, 1, 1), grid=(num_block, 1, 1), |
|
stream=SRU_STREAM |
|
) |
|
|
|
self.save_for_backward(u, x, bias, init, mask_h) |
|
self.intermediate = c |
|
if x.dim() == 2: |
|
last_hidden = c |
|
elif self.bidirectional: |
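            # forward direction: last time step, first d features;
            # backward direction: time step 0, last d features -> (2, batch, d)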
|
|
|
last_hidden = torch.stack((c[-1, :, :d], c[0, :, d:])) |
|
else: |
|
last_hidden = c[-1] |
|
return h, last_hidden |
|
|
|
@custom_bwd |
|
def backward(self, grad_h, grad_last): |
|
if self.bidirectional: |
|
grad_last = torch.cat((grad_last[0], grad_last[1]), 1) |
|
bidir = 2 if self.bidirectional else 1 |
|
u, x, bias, init, mask_h = self.saved_tensors |
|
c = self.intermediate |
|
length = x.size(0) if x.dim() == 3 else 1 |
|
batch = x.size(-2) |
|
d = self.d_out |
|
k = u.size(-1) // d |
|
k_ = k // 2 if self.bidirectional else k |
|
ncols = batch * d * bidir |
|
thread_per_block = min(512, ncols) |
|
num_block = (ncols - 1) // thread_per_block + 1 |
|
|
|
init_ = x.new(ncols).zero_() if init is None else init |
|
grad_u = u.new(*u.size()) |
|
grad_bias = x.new(2, batch, d * bidir) |
|
        grad_init = x.new(batch, d * bidir)

grad_x = x.new(*x.size()) if k_ == 3 else None |
|
|
|
FUNC = SRU_BWD_FUNC if not self.bidirectional else SRU_BiBWD_FUNC |
|
FUNC(args=[ |
|
u.contiguous().data_ptr(), |
|
x.contiguous().data_ptr() if k_ == 3 else 0, |
|
bias.data_ptr(), |
|
init_.contiguous().data_ptr(), |
|
mask_h.data_ptr() if mask_h is not None else 0, |
|
c.data_ptr(), |
|
grad_h.contiguous().data_ptr(), |
|
grad_last.contiguous().data_ptr(), |
|
length, |
|
batch, |
|
d, |
|
k_, |
|
grad_u.data_ptr(), |
|
grad_x.data_ptr() if k_ == 3 else 0, |
|
grad_bias.data_ptr(), |
|
grad_init.data_ptr(), |
|
self.activation_type], |
|
block=(thread_per_block, 1, 1), grid=(num_block, 1, 1), |
|
stream=SRU_STREAM |
|
) |
|
return grad_u, grad_x, grad_bias.sum(1).view(-1), grad_init, None |
|
|
|
|
|
class SRUCell(nn.Module): |
|
def __init__(self, n_in, n_out, dropout=0, rnn_dropout=0, |
|
bidirectional=False, use_tanh=1, use_relu=0): |
|
super(SRUCell, self).__init__() |
|
self.n_in = n_in |
|
self.n_out = n_out |
|
self.rnn_dropout = rnn_dropout |
|
self.dropout = dropout |
|
self.bidirectional = bidirectional |
|
self.activation_type = 2 if use_relu else (1 if use_tanh else 0) |
|
|
|
out_size = n_out * 2 if bidirectional else n_out |
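        # k == 3 when the highway connection can reuse the input directly
        # (n_in == out_size); otherwise a fourth projection column supplies it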
|
k = 4 if n_in != out_size else 3 |
|
self.size_per_dir = n_out * k |
|
self.weight = nn.Parameter(torch.Tensor( |
|
n_in, |
|
self.size_per_dir * 2 if bidirectional else self.size_per_dir |
|
)) |
|
self.bias = nn.Parameter(torch.Tensor( |
|
n_out * 4 if bidirectional else n_out * 2 |
|
)) |
|
self.init_weight() |
|
|
|
def init_weight(self): |
|
val_range = (3.0 / self.n_in)**0.5 |
|
self.weight.data.uniform_(-val_range, val_range) |
|
self.bias.data.zero_() |
|
|
|
def set_bias(self, bias_val=0): |
|
n_out = self.n_out |
|
if self.bidirectional: |
|
self.bias.data[n_out * 2:].zero_().add_(bias_val) |
|
else: |
|
self.bias.data[n_out:].zero_().add_(bias_val) |
|
|
|
def forward(self, input, c0=None): |
|
assert input.dim() == 2 or input.dim() == 3 |
|
n_in, n_out = self.n_in, self.n_out |
|
batch = input.size(-2) |
|
if c0 is None: |
|
c0 = input.data.new( |
|
batch, n_out if not self.bidirectional else n_out * 2 |
|
).zero_() |
|
|
|
if self.training and (self.rnn_dropout > 0): |
|
mask = self.get_dropout_mask_((batch, n_in), self.rnn_dropout) |
|
x = input * mask.expand_as(input) |
|
else: |
|
x = input |
|
|
|
x_2d = x if x.dim() == 2 else x.contiguous().view(-1, n_in) |
|
u = x_2d.mm(self.weight) |
|
|
|
if self.training and (self.dropout > 0): |
|
bidir = 2 if self.bidirectional else 1 |
|
mask_h = self.get_dropout_mask_( |
|
(batch, n_out * bidir), self.dropout) |
|
h, c = SRU_Compute(self.activation_type, n_out, |
|
self.bidirectional)( |
|
u, input, self.bias, c0, mask_h |
|
) |
|
else: |
|
h, c = SRU_Compute(self.activation_type, n_out, |
|
self.bidirectional)( |
|
u, input, self.bias, c0 |
|
) |
|
|
|
return h, c |
|
|
|
def get_dropout_mask_(self, size, p): |
|
w = self.weight.data |
|
return w.new(*size).bernoulli_(1 - p).div_(1 - p) |
|
|
|
|
|
class SRU(nn.Module): |
|
""" |
|
Implementation of "Training RNNs as Fast as CNNs" |
|
:cite:`DBLP:journals/corr/abs-1709-02755` |
|
|
|
    TODO: switch to PyTorch's native implementation when one becomes available.

    This implementation is adapted from the author's code:
    https://github.com/taolei87/sru/blob/master/cuda_functional.py.
|
|
|
Args: |
|
        input_size (int): input feature dimension
        hidden_size (int): hidden dimension
        num_layers (int): number of stacked layers
        dropout (float): dropout applied between stacked layers
        rnn_dropout (float): dropout applied to the recurrent input
        bidirectional (bool): run each layer in both directions
        use_tanh (bool): use tanh as the activation
        use_relu (bool): use ReLU as the activation
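
    Example (a sketch; assumes a CUDA build with cupy/pynvrtc installed)::

        rnn = SRU(500, 500, num_layers=2, bidirectional=True).cuda()
        out, hidden = rnn(torch.randn(20, 8, 500).cuda())
        # out: (len, batch, 2 * hidden_size)
        # hidden: (num_layers * 2, batch, hidden_size)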
|
""" |
|
|
|
def __init__(self, input_size, hidden_size, |
|
num_layers=2, dropout=0, rnn_dropout=0, |
|
bidirectional=False, use_tanh=1, use_relu=0): |
|
|
|
|
|
check_sru_requirement(abort=True) |
|
super(SRU, self).__init__() |
|
self.n_in = input_size |
|
self.n_out = hidden_size |
|
self.depth = num_layers |
|
self.dropout = dropout |
|
self.rnn_dropout = rnn_dropout |
|
self.rnn_lst = nn.ModuleList() |
|
self.bidirectional = bidirectional |
|
self.out_size = hidden_size * 2 if bidirectional else hidden_size |
|
|
|
for i in range(num_layers): |
|
sru_cell = SRUCell( |
|
n_in=self.n_in if i == 0 else self.out_size, |
|
n_out=self.n_out, |
|
dropout=dropout if i + 1 != num_layers else 0, |
|
rnn_dropout=rnn_dropout, |
|
bidirectional=bidirectional, |
|
use_tanh=use_tanh, |
|
use_relu=use_relu, |
|
) |
|
self.rnn_lst.append(sru_cell) |
|
|
|
def set_bias(self, bias_val=0): |
|
        for layer in self.rnn_lst:
            layer.set_bias(bias_val)
|
|
|
def forward(self, input, c0=None, return_hidden=True): |
|
assert input.dim() == 3 |
|
dir_ = 2 if self.bidirectional else 1 |
|
if c0 is None: |
|
zeros = input.data.new( |
|
input.size(1), self.n_out * dir_ |
|
).zero_() |
|
c0 = [zeros for i in range(self.depth)] |
|
else: |
|
if isinstance(c0, tuple): |
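                # the initial state may arrive wrapped in a tuple
                # (e.g. by a decoder state); use its first element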
|
|
|
c0 = c0[0] |
|
assert c0.dim() == 3 |
|
c0 = [h.squeeze(0) for h in c0.chunk(self.depth, 0)] |
|
|
|
prevx = input |
|
lstc = [] |
|
for i, rnn in enumerate(self.rnn_lst): |
|
h, c = rnn(prevx, c0[i]) |
|
prevx = h |
|
lstc.append(c) |
|
|
|
if self.bidirectional: |
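            # fh: (num_layers * 2) x batch x hidden_size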
|
|
|
fh = torch.cat(lstc) |
|
else: |
|
fh = torch.stack(lstc) |
|
|
|
if return_hidden: |
|
return prevx, fh |
|
else: |
|
return prevx |
|
|