""" SRU Implementation """ |
import subprocess |
import platform |
import os |
import re |
import configargparse |
import torch |
import torch.nn as nn |
from torch.autograd import Function |
from torch.cuda.amp import custom_fwd, custom_bwd |
from collections import namedtuple |
class CheckSRU(configargparse.Action):
    """Argument-parser action that validates the SRU runtime on selection.

    When the parsed value is exactly ``"SRU"``, the environment is checked
    with ``check_sru_requirement(abort=True)``, so a missing dependency is
    reported while arguments are parsed. The value is then stored on the
    namespace under ``self.dest`` in every case.
    """

    def __init__(self, option_strings, dest, **kwargs):
        super().__init__(option_strings, dest, **kwargs)

    def __call__(self, parser, namespace, values, option_string=None):
        wants_sru = values == "SRU"
        if wants_sru:
            check_sru_requirement(abort=True)
        setattr(namespace, self.dest, values)
def check_sru_requirement(abort=False):
    """
    Check whether this environment can run the CUDA SRU implementation.

    The checks run in order and stop at the first failure:

    1. the ``cupy`` and ``pynvrtc`` packages are importable by the running
       interpreter;
    2. ``torch.cuda.is_available()`` is True;
    3. the ``LD_LIBRARY_PATH`` environment variable mentions ``cuda/lib``.

    Args:
        abort (bool): raise instead of returning False when a check fails.

    Returns:
        bool: True if every check passes; False if one fails and ``abort``
        is False.

    Raises:
        AssertionError: if a check fails and ``abort`` is True.
    """
    import importlib.util

    def _fail(message):
        # Single place for the "return False or raise" policy.
        if abort:
            raise AssertionError(message)
        return False

    # Ask the running interpreter rather than shelling out to ``pip freeze``:
    # the ``pip`` found on PATH may belong to another environment, and the
    # subprocess is slow and needs a different shell command per platform.
    missing = [
        name
        for name in ("cupy", "pynvrtc")
        if importlib.util.find_spec(name) is None
    ]
    if missing:
        return _fail(
            "Using SRU requires 'cupy' and 'pynvrtc' python packages installed."
        )
    if not torch.cuda.is_available():
        return _fail("Using SRU requires pytorch built with cuda.")
    # NOTE(review): LD_LIBRARY_PATH is a POSIX loader variable; on Windows
    # this check will normally fail - confirm whether Windows is supported.
    if "cuda/lib" not in os.getenv("LD_LIBRARY_PATH", ""):
        return _fail(
            "Using SRU requires setting cuda lib path, e.g. "
            "export LD_LIBRARY_PATH=/usr/local/cuda/lib64."
        )
    return True
# CUDA C source for the SRU kernels; load_sru_mod() compiles it at runtime
# with pynvrtc and looks up the four extern "C" kernels by name:
#   sru_fwd / sru_bwd        - unidirectional forward / backward pass
#   sru_bi_fwd / sru_bi_bwd  - bidirectional forward / backward pass
# Each thread owns one column: col < batch*d (batch*d*2 in the "bi" kernels).
# It walks the `len` time steps of that column sequentially.
# Per step, with u0..u2 the projections at the column and x the highway input:
#   g1 = sigmoid(u1 + bias1), g2 = sigmoid(u2 + bias2)
#   c_t = (c_{t-1} - u0) * g1 + u0
#   h_t = (act(c_t) * mask - x) * g2 + x
# activation_type selects act: 1 = tanh, 2 = relu, otherwise identity.
# k is the number of projections per column in `u`. With k == 3 the highway
# input is read from `x`; otherwise it is the 4th projection (up + 3), and
# x / grad_x must be NULL.
# In the bidirectional kernels, columns with (col % 2d) >= d are the reverse
# direction. The forward kernel walks them from the last step to the first,
# and the backward kernel walks them the other way.
# The backward kernels write per-column bias-gradient partial sums to
# grad_bias[col] and grad_bias[col + ncols]. SRU_Compute.backward sums them
# over the batch.
SRU_CODE = """
extern "C" {
__forceinline__ __device__ float sigmoidf(float x)
{
return 1.f / (1.f + expf(-x));
}
__forceinline__ __device__ float reluf(float x)
{
return (x > 0.f) ? x : 0.f;
}
__global__ void sru_fwd(const float * __restrict__ u,
const float * __restrict__ x,
const float * __restrict__ bias,
const float * __restrict__ init,
const float * __restrict__ mask_h,
const int len, const int batch,
const int d, const int k,
float * __restrict__ h,
float * __restrict__ c,
const int activation_type)
{
assert ((k == 3) || (x == NULL));
int ncols = batch*d;
int col = blockIdx.x * blockDim.x + threadIdx.x;
if (col >= ncols) return;
int ncols_u = ncols*k;
int ncols_x = (k == 3) ? ncols : ncols_u;
const float bias1 = *(bias + (col%d));
const float bias2 = *(bias + (col%d) + d);
const float mask = (mask_h == NULL) ? 1.0 : (*(mask_h + col));
float cur = *(init + col);
const float *up = u + (col*k);
const float *xp = (k == 3) ? (x + col) : (up + 3);
float *cp = c + col;
float *hp = h + col;
for (int row = 0; row < len; ++row)
{
float g1 = sigmoidf((*(up+1))+bias1);
float g2 = sigmoidf((*(up+2))+bias2);
cur = (cur-(*up))*g1 + (*up);
*cp = cur;
float val = (activation_type == 1) ? tanh(cur) : (
(activation_type == 2) ? reluf(cur) : cur
);
*hp = (val*mask-(*xp))*g2 + (*xp);
up += ncols_u;
xp += ncols_x;
cp += ncols;
hp += ncols;
}
}
__global__ void sru_bwd(const float * __restrict__ u,
const float * __restrict__ x,
const float * __restrict__ bias,
const float * __restrict__ init,
const float * __restrict__ mask_h,
const float * __restrict__ c,
const float * __restrict__ grad_h,
const float * __restrict__ grad_last,
const int len,
const int batch, const int d, const int k,
float * __restrict__ grad_u,
float * __restrict__ grad_x,
float * __restrict__ grad_bias,
float * __restrict__ grad_init,
int activation_type)
{
assert((k == 3) || (x == NULL));
assert((k == 3) || (grad_x == NULL));
int ncols = batch*d;
int col = blockIdx.x * blockDim.x + threadIdx.x;
if (col >= ncols) return;
int ncols_u = ncols*k;
int ncols_x = (k == 3) ? ncols : ncols_u;
const float bias1 = *(bias + (col%d));
const float bias2 = *(bias + (col%d) + d);
const float mask = (mask_h == NULL) ? 1.0 : (*(mask_h + col));
float gbias1 = 0;
float gbias2 = 0;
float cur = *(grad_last + col);
const float *up = u + (col*k) + (len-1)*ncols_u;
const float *xp = (k == 3) ? (x + col + (len-1)*ncols) : (up + 3);
const float *cp = c + col + (len-1)*ncols;
const float *ghp = grad_h + col + (len-1)*ncols;
float *gup = grad_u + (col*k) + (len-1)*ncols_u;
float *gxp = (k == 3) ? (grad_x + col + (len-1)*ncols) : (gup + 3);
for (int row = len-1; row >= 0; --row)
{
const float g1 = sigmoidf((*(up+1))+bias1);
const float g2 = sigmoidf((*(up+2))+bias2);
const float c_val = (activation_type == 1) ? tanh(*cp) : (
(activation_type == 2) ? reluf(*cp) : (*cp)
);
const float x_val = *xp;
const float u_val = *up;
const float prev_c_val = (row>0) ? (*(cp-ncols)) : (*(init+col));
const float gh_val = *ghp;
// h = c*g2 + x*(1-g2) = (c-x)*g2 + x
// c = c'*g1 + g0*(1-g1) = (c'-g0)*g1 + g0
// grad wrt x
*gxp = gh_val*(1-g2);
// grad wrt g2, u2 and bias2
float gg2 = gh_val*(c_val*mask-x_val)*(g2*(1-g2));
*(gup+2) = gg2;
gbias2 += gg2;
// grad wrt c
const float tmp = (activation_type == 1) ? (g2*(1-c_val*c_val)) : (
((activation_type == 0) || (c_val > 0)) ? g2 : 0.f
);
const float gc = gh_val*mask*tmp + cur;
// grad wrt u0
*gup = gc*(1-g1);
// grad wrt g1, u1, and bias1
float gg1 = gc*(prev_c_val-u_val)*(g1*(1-g1));
*(gup+1) = gg1;
gbias1 += gg1;
// grad wrt c'
cur = gc*g1;
up -= ncols_u;
xp -= ncols_x;
cp -= ncols;
gup -= ncols_u;
gxp -= ncols_x;
ghp -= ncols;
}
*(grad_bias + col) = gbias1;
*(grad_bias + col + ncols) = gbias2;
*(grad_init +col) = cur;
}
__global__ void sru_bi_fwd(const float * __restrict__ u,
const float * __restrict__ x,
const float * __restrict__ bias,
const float * __restrict__ init,
const float * __restrict__ mask_h,
const int len, const int batch,
const int d, const int k,
float * __restrict__ h,
float * __restrict__ c,
const int activation_type)
{
assert ((k == 3) || (x == NULL));
assert ((k == 3) || (k == 4));
int ncols = batch*d*2;
int col = blockIdx.x * blockDim.x + threadIdx.x;
if (col >= ncols) return;
int ncols_u = ncols*k;
int ncols_x = (k == 3) ? ncols : ncols_u;
const float mask = (mask_h == NULL) ? 1.0 : (*(mask_h + col));
float cur = *(init + col);
const int d2 = d*2;
const bool flip = (col%d2) >= d;
const float bias1 = *(bias + (col%d2));
const float bias2 = *(bias + (col%d2) + d2);
const float *up = u + (col*k);
const float *xp = (k == 3) ? (x + col) : (up + 3);
float *cp = c + col;
float *hp = h + col;
if (flip) {
up += (len-1)*ncols_u;
xp += (len-1)*ncols_x;
cp += (len-1)*ncols;
hp += (len-1)*ncols;
}
int ncols_u_ = flip ? -ncols_u : ncols_u;
int ncols_x_ = flip ? -ncols_x : ncols_x;
int ncols_ = flip ? -ncols : ncols;
for (int cnt = 0; cnt < len; ++cnt)
{
float g1 = sigmoidf((*(up+1))+bias1);
float g2 = sigmoidf((*(up+2))+bias2);
cur = (cur-(*up))*g1 + (*up);
*cp = cur;
float val = (activation_type == 1) ? tanh(cur) : (
(activation_type == 2) ? reluf(cur) : cur
);
*hp = (val*mask-(*xp))*g2 + (*xp);
up += ncols_u_;
xp += ncols_x_;
cp += ncols_;
hp += ncols_;
}
}
__global__ void sru_bi_bwd(const float * __restrict__ u,
const float * __restrict__ x,
const float * __restrict__ bias,
const float * __restrict__ init,
const float * __restrict__ mask_h,
const float * __restrict__ c,
const float * __restrict__ grad_h,
const float * __restrict__ grad_last,
const int len, const int batch,
const int d, const int k,
float * __restrict__ grad_u,
float * __restrict__ grad_x,
float * __restrict__ grad_bias,
float * __restrict__ grad_init,
int activation_type)
{
assert((k == 3) || (x == NULL));
assert((k == 3) || (grad_x == NULL));
assert((k == 3) || (k == 4));
int ncols = batch*d*2;
int col = blockIdx.x * blockDim.x + threadIdx.x;
if (col >= ncols) return;
int ncols_u = ncols*k;
int ncols_x = (k == 3) ? ncols : ncols_u;
const float mask = (mask_h == NULL) ? 1.0 : (*(mask_h + col));
float gbias1 = 0;
float gbias2 = 0;
float cur = *(grad_last + col);
const int d2 = d*2;
const bool flip = ((col%d2) >= d);
const float bias1 = *(bias + (col%d2));
const float bias2 = *(bias + (col%d2) + d2);
const float *up = u + (col*k);
const float *xp = (k == 3) ? (x + col) : (up + 3);
const float *cp = c + col;
const float *ghp = grad_h + col;
float *gup = grad_u + (col*k);
float *gxp = (k == 3) ? (grad_x + col) : (gup + 3);
if (!flip) {
up += (len-1)*ncols_u;
xp += (len-1)*ncols_x;
cp += (len-1)*ncols;
ghp += (len-1)*ncols;
gup += (len-1)*ncols_u;
gxp += (len-1)*ncols_x;
}
int ncols_u_ = flip ? -ncols_u : ncols_u;
int ncols_x_ = flip ? -ncols_x : ncols_x;
int ncols_ = flip ? -ncols : ncols;
for (int cnt = 0; cnt < len; ++cnt)
{
const float g1 = sigmoidf((*(up+1))+bias1);
const float g2 = sigmoidf((*(up+2))+bias2);
const float c_val = (activation_type == 1) ? tanh(*cp) : (
(activation_type == 2) ? reluf(*cp) : (*cp)
);
const float x_val = *xp;
const float u_val = *up;
const float prev_c_val = (cnt<len-1)?(*(cp-ncols_)):(*(init+col));
const float gh_val = *ghp;
// h = c*g2 + x*(1-g2) = (c-x)*g2 + x
// c = c'*g1 + g0*(1-g1) = (c'-g0)*g1 + g0
// grad wrt x
*gxp = gh_val*(1-g2);
// grad wrt g2, u2 and bias2
float gg2 = gh_val*(c_val*mask-x_val)*(g2*(1-g2));
*(gup+2) = gg2;
gbias2 += gg2;
// grad wrt c
const float tmp = (activation_type == 1) ? (g2*(1-c_val*c_val)) : (
((activation_type == 0) || (c_val > 0)) ? g2 : 0.f
);
const float gc = gh_val*mask*tmp + cur;
// grad wrt u0
*gup = gc*(1-g1);
// grad wrt g1, u1, and bias1
float gg1 = gc*(prev_c_val-u_val)*(g1*(1-g1));
*(gup+1) = gg1;
gbias1 += gg1;
// grad wrt c'
cur = gc*g1;
up -= ncols_u_;
xp -= ncols_x_;
cp -= ncols_;
gup -= ncols_u_;
gxp -= ncols_x_;
ghp -= ncols_;
}
*(grad_bias + col) = gbias1;
*(grad_bias + col + ncols) = gbias2;
*(grad_init +col) = cur;
}
}
"""
# Kernel handles and launch stream, filled in lazily by load_sru_mod().
# They start as None so SRU_Compute.maybe_load_sru_mod() can tell whether
# the CUDA module has been compiled yet.
SRU_FWD_FUNC, SRU_BWD_FUNC = None, None
SRU_BiFWD_FUNC, SRU_BiBWD_FUNC = None, None
SRU_STREAM = None


def load_sru_mod():
    """Compile ``SRU_CODE`` with NVRTC and publish the kernel handles.

    On success, sets the module globals ``SRU_FWD_FUNC``, ``SRU_BWD_FUNC``,
    ``SRU_BiFWD_FUNC``, ``SRU_BiBWD_FUNC`` and ``SRU_STREAM``.
    ``SRU_FWD_FUNC`` is assigned last because it is the "already loaded"
    flag checked by ``SRU_Compute.maybe_load_sru_mod``.

    Raises:
        AssertionError: if :func:`check_sru_requirement` fails. Previously
            the function returned silently and the first kernel launch then
            failed on a ``None`` handle.
    """
    # Without these declarations every assignment below would create a
    # function-local name and the module-level handles would stay None.
    global SRU_FWD_FUNC, SRU_BWD_FUNC, SRU_BiFWD_FUNC, SRU_BiBWD_FUNC
    global SRU_STREAM
    check_sru_requirement(abort=True)
    from cupy.cuda import function
    from pynvrtc.compiler import Program

    # Put a tensor on the GPU so PyTorch initialises CUDA before cupy loads
    # the compiled module (kept from upstream; presumably needed so both
    # libraries use the same device/context).
    device = torch.device("cuda")
    torch.rand(1, 1).to(device)

    sru_prog = Program(SRU_CODE.encode("utf-8"), "sru_prog.cu".encode("utf-8"))
    sru_ptx = sru_prog.compile()
    sru_mod = function.Module()
    sru_mod.load(bytes(sru_ptx.encode()))

    fwd = sru_mod.get_function("sru_fwd")
    bwd = sru_mod.get_function("sru_bwd")
    bi_fwd = sru_mod.get_function("sru_bi_fwd")
    bi_bwd = sru_mod.get_function("sru_bi_bwd")

    # Kernel launches receive ``stream=SRU_STREAM``. Wrap the raw handle of
    # the stream that is current *now* in an object exposing ``.ptr``.
    # Every later launch uses this same stream.
    stream = namedtuple("Stream", ["ptr"])
    SRU_STREAM = stream(ptr=torch.cuda.current_stream().cuda_stream)
    SRU_BWD_FUNC, SRU_BiFWD_FUNC, SRU_BiBWD_FUNC = bwd, bi_fwd, bi_bwd
    SRU_FWD_FUNC = fwd
class SRU_Compute(Function):
    """Autograd function that runs the SRU recurrence with the CUDA kernels.

    An instance is configured once and then called on tensors, e.g.
    ``SRU_Compute(activation_type, n_out, bidirectional)(u, x, bias, c0)``.
    The CUDA module is compiled on first construction through
    :meth:`maybe_load_sru_mod`.

    NOTE(review): this uses the legacy, instance-style
    ``torch.autograd.Function`` API (non-static ``forward``/``backward``
    keeping state on ``self``). Newer PyTorch releases may reject that
    style, so confirm it against the torch version in use.

    Args:
        activation_type (int): activation applied to c inside the kernels:
            0 = identity, 1 = tanh, 2 = relu.
        d_out (int): hidden size per direction.
        bidirectional (bool): use the bidirectional kernels.
    """

    def __init__(self, activation_type, d_out, bidirectional=False):
        SRU_Compute.maybe_load_sru_mod()
        super(SRU_Compute, self).__init__()
        self.activation_type = activation_type
        self.d_out = d_out
        self.bidirectional = bidirectional

    @staticmethod
    def maybe_load_sru_mod():
        """Compile and load the CUDA kernels the first time they are needed."""
        # ``globals().get`` rather than a bare name: if SRU_FWD_FUNC was never
        # defined at module level, a bare lookup raises NameError instead of
        # triggering the load.
        if globals().get("SRU_FWD_FUNC") is None:
            load_sru_mod()

    @custom_fwd
    def forward(self, u, x, bias, init=None, mask_h=None):
        """Run the SRU recurrence over all time steps.

        Args:
            u: input projections. ``u.size(-1) // d_out`` is the number of
                projections per hidden unit (summed over both directions
                when bidirectional).
            x: layer input, ``(len, batch, n_in)`` or ``(batch, n_in)``. The
                kernel reads it only with 3 projections per direction;
                otherwise the 4th projection in ``u`` is used instead.
            bias: gate biases, read by the kernels as ``bias1`` then ``bias2``.
            init: initial cell state, one value per kernel column
                (``batch * d_out * bidir``); zeros if None.
            mask_h: optional dropout mask, one value per kernel column.

        Returns:
            tuple: ``(h, last_hidden)``.
            - ``h`` is ``(len, batch, d_out * bidir)``, or
              ``(batch, d_out * bidir)`` for 2-D ``x``.
            - For 2-D ``x``, ``last_hidden`` is the full cell state.
            - When bidirectional, it is ``(2, batch, d_out)``: the final
              forward state, then the final backward state.
            - Otherwise it is ``(batch, d_out)``.
        """
        bidir = 2 if self.bidirectional else 1
        length = x.size(0) if x.dim() == 3 else 1
        batch = x.size(-2)
        d = self.d_out
        k = u.size(-1) // d
        k_ = k // 2 if self.bidirectional else k
        ncols = batch * d * bidir
        thread_per_block = min(512, ncols)
        num_block = (ncols - 1) // thread_per_block + 1
        init_ = x.new(ncols).zero_() if init is None else init
        size = (length, batch, d * bidir) if x.dim() == 3 else (batch, d * bidir)
        c = x.new(*size)
        h = x.new(*size)

        # Bind the contiguous tensors to locals so they stay alive until the
        # kernel is enqueued. With inline ``t.contiguous().data_ptr()``, a
        # temporary copy is freed as soon as its pointer is read, and a later
        # ``.contiguous()`` in the same argument list may reuse that memory.
        u_ = u.contiguous()
        x_ = x.contiguous() if k_ == 3 else None
        bias_ = bias.contiguous()
        init_ = init_.contiguous()
        mask_ = mask_h.contiguous() if mask_h is not None else None

        FUNC = SRU_FWD_FUNC if not self.bidirectional else SRU_BiFWD_FUNC
        FUNC(
            args=[
                u_.data_ptr(),
                x_.data_ptr() if x_ is not None else 0,
                bias_.data_ptr(),
                init_.data_ptr(),
                mask_.data_ptr() if mask_ is not None else 0,
                length,
                batch,
                d,
                k_,
                h.data_ptr(),
                c.data_ptr(),
                self.activation_type,
            ],
            block=(thread_per_block, 1, 1),
            grid=(num_block, 1, 1),
            stream=SRU_STREAM,
        )

        self.save_for_backward(u, x, bias, init, mask_h)
        self.intermediate = c
        if x.dim() == 2:
            last_hidden = c
        elif self.bidirectional:
            # The forward direction ends at the last step; the reverse
            # direction (columns d..2d) ends at step 0.
            last_hidden = torch.stack((c[-1, :, :d], c[0, :, d:]))
        else:
            last_hidden = c[-1]
        return h, last_hidden

    @custom_bwd
    def backward(self, grad_h, grad_last):
        """Back-propagate through the recurrence with the CUDA kernels.

        Returns:
            tuple: gradients for ``(u, x, bias, init, mask_h)``. ``x`` gets
            None when the kernel did not read it, and ``mask_h`` always
            gets None.
        """
        bidir = 2 if self.bidirectional else 1
        if self.bidirectional and grad_last.dim() == 3:
            # forward() returned the last state as (2, batch, d); the kernel
            # reads it in its (batch, 2 * d) column layout, forward direction
            # first. For 2-D input forward() already returned that layout.
            grad_last = torch.cat((grad_last[0], grad_last[1]), 1)
        u, x, bias, init, mask_h = self.saved_tensors
        c = self.intermediate
        length = x.size(0) if x.dim() == 3 else 1
        batch = x.size(-2)
        d = self.d_out
        k = u.size(-1) // d
        k_ = k // 2 if self.bidirectional else k
        ncols = batch * d * bidir
        thread_per_block = min(512, ncols)
        num_block = (ncols - 1) // thread_per_block + 1
        init_ = x.new(ncols).zero_() if init is None else init
        grad_u = u.new(*u.size())
        # Per-column partial sums for bias1 / bias2, reduced over the batch below.
        grad_bias = x.new(2, batch, d * bidir)
        grad_init = x.new(batch, d * bidir)
        grad_x = x.new(*x.size()) if k_ == 3 else None

        # Keep contiguous inputs referenced until the launch (see forward()).
        u_ = u.contiguous()
        x_ = x.contiguous() if k_ == 3 else None
        bias_ = bias.contiguous()
        init_ = init_.contiguous()
        mask_ = mask_h.contiguous() if mask_h is not None else None
        grad_h_ = grad_h.contiguous()
        grad_last_ = grad_last.contiguous()

        FUNC = SRU_BWD_FUNC if not self.bidirectional else SRU_BiBWD_FUNC
        FUNC(
            args=[
                u_.data_ptr(),
                x_.data_ptr() if x_ is not None else 0,
                bias_.data_ptr(),
                init_.data_ptr(),
                mask_.data_ptr() if mask_ is not None else 0,
                c.data_ptr(),
                grad_h_.data_ptr(),
                grad_last_.data_ptr(),
                length,
                batch,
                d,
                k_,
                grad_u.data_ptr(),
                grad_x.data_ptr() if grad_x is not None else 0,
                grad_bias.data_ptr(),
                grad_init.data_ptr(),
                self.activation_type,
            ],
            block=(thread_per_block, 1, 1),
            grid=(num_block, 1, 1),
            stream=SRU_STREAM,
        )
        return grad_u, grad_x, grad_bias.sum(1).view(-1), grad_init, None
class SRUCell(nn.Module):
    """One SRU layer (optionally bidirectional) backed by ``SRU_Compute``.

    Args:
        n_in (int): input feature size.
        n_out (int): hidden size per direction.
        dropout (float): probability for the mask handed to the kernel as
            ``mask_h`` (applied to the hidden output) while training.
        rnn_dropout (float): probability for the mask applied to the input
            before the projection while training.
        bidirectional (bool): add a second, time-reversed direction.
        use_tanh: truthy selects tanh on c (ignored when ``use_relu``).
        use_relu: truthy selects relu on c; takes precedence over tanh.
    """

    def __init__(
        self,
        n_in,
        n_out,
        dropout=0,
        rnn_dropout=0,
        bidirectional=False,
        use_tanh=1,
        use_relu=0,
    ):
        super().__init__()
        self.n_in = n_in
        self.n_out = n_out
        self.rnn_dropout = rnn_dropout
        self.dropout = dropout
        self.bidirectional = bidirectional
        if use_relu:
            self.activation_type = 2
        elif use_tanh:
            self.activation_type = 1
        else:
            self.activation_type = 0
        n_dirs = 2 if bidirectional else 1
        # A 4th projection replaces the highway input whenever the input
        # width differs from the (all-direction) output width.
        k = 3 if n_in == n_out * n_dirs else 4
        self.size_per_dir = n_out * k
        self.weight = nn.Parameter(torch.Tensor(n_in, self.size_per_dir * n_dirs))
        self.bias = nn.Parameter(torch.Tensor(n_out * 2 * n_dirs))
        self.init_weight()

    def init_weight(self):
        """Draw weights from U(-sqrt(3 / n_in), sqrt(3 / n_in)); zero the bias."""
        bound = (3.0 / self.n_in) ** 0.5
        self.weight.data.uniform_(-bound, bound)
        self.bias.data.zero_()

    def set_bias(self, bias_val=0):
        """Set the second half of ``bias`` to ``bias_val``.

        That half is what the kernels read as ``bias2``, the bias of the
        gate computed from the u2 projection.
        """
        offset = self.n_out * (2 if self.bidirectional else 1)
        self.bias.data[offset:].zero_().add_(bias_val)

    def forward(self, input, c0=None):
        """Apply the layer to a 2-D ``(batch, n_in)`` or 3-D ``(len, batch, n_in)`` input.

        Returns ``(h, c)`` as produced by ``SRU_Compute``. The un-dropped
        ``input`` is passed as the highway input.
        """
        assert input.dim() in (2, 3)
        n_in, n_out = self.n_in, self.n_out
        n_dirs = 2 if self.bidirectional else 1
        batch = input.size(-2)
        if c0 is None:
            c0 = input.data.new(batch, n_out * n_dirs).zero_()

        x = input
        if self.training and self.rnn_dropout > 0:
            keep = self.get_dropout_mask_((batch, n_in), self.rnn_dropout)
            x = input * keep.expand_as(input)

        flat = x.contiguous().view(-1, n_in) if x.dim() == 3 else x
        u = flat.mm(self.weight)

        inputs = [u, input, self.bias, c0]
        if self.training and self.dropout > 0:
            inputs.append(
                self.get_dropout_mask_((batch, n_out * n_dirs), self.dropout)
            )
        h, c = SRU_Compute(self.activation_type, n_out, self.bidirectional)(*inputs)
        return h, c

    def get_dropout_mask_(self, size, p):
        """Return an inverted-dropout mask of shape ``size``.

        Each entry is 0 with probability ``p`` and ``1 / (1 - p)`` otherwise.
        The mask is created on the same device/dtype as the weight.
        """
        keep_prob = 1 - p
        return self.weight.data.new(*size).bernoulli_(keep_prob).div_(keep_prob)
class SRU(nn.Module):
    """
    Implementation of "Training RNNs as Fast as CNNs"
    :cite:`DBLP:journals/corr/abs-1709-02755`

    TODO: turn to pytorch's implementation when it is available.

    This implementation is adopted from the author of the paper:
    https://github.com/taolei87/sru/blob/master/cuda_functional.py.

    Args:
      input_size (int): number of input features
      hidden_size (int): hidden dimension (per direction)
      num_layers (int): number of stacked layers
      dropout (float): dropout between stacked layers (not on the last)
      rnn_dropout (float): dropout on each layer's input (recurrent)
      bidirectional (bool): run each layer in both directions
      use_tanh (bool): tanh activation on the cell state
      use_relu (bool): relu activation on the cell state (overrides tanh)
    """

    def __init__(
        self,
        input_size,
        hidden_size,
        num_layers=2,
        dropout=0,
        rnn_dropout=0,
        bidirectional=False,
        use_tanh=1,
        use_relu=0,
    ):
        # Fail fast with a descriptive error if the CUDA toolchain is missing.
        check_sru_requirement(abort=True)
        super(SRU, self).__init__()
        self.n_in = input_size
        self.n_out = hidden_size
        self.depth = num_layers
        self.dropout = dropout
        self.rnn_dropout = rnn_dropout
        self.rnn_lst = nn.ModuleList()
        self.bidirectional = bidirectional
        self.out_size = hidden_size * 2 if bidirectional else hidden_size
        for i in range(num_layers):
            sru_cell = SRUCell(
                n_in=self.n_in if i == 0 else self.out_size,
                n_out=self.n_out,
                # No inter-layer dropout after the last layer.
                dropout=dropout if i + 1 != num_layers else 0,
                rnn_dropout=rnn_dropout,
                bidirectional=bidirectional,
                use_tanh=use_tanh,
                use_relu=use_relu,
            )
            self.rnn_lst.append(sru_cell)

    def set_bias(self, bias_val=0):
        """Forward ``bias_val`` to ``set_bias`` of every layer."""
        for cell in self.rnn_lst:
            cell.set_bias(bias_val)

    def forward(self, input, c0=None, return_hidden=True):
        """Run all layers over a 3-D ``input`` of shape ``(len, batch, input_size)``.

        Args:
            input: 3-D input tensor.
            c0: optional initial cell states: a tensor, or a tuple whose first
                element is used. Accepted layouts:
                ``(num_layers, batch, hidden_size * num_directions)``, or,
                when bidirectional, ``(num_layers * 2, batch, hidden_size)``.
                The latter is the layout this method returns.
            return_hidden (bool): also return the final cell states.

        Returns:
            The last layer's output. If ``return_hidden`` is set, also the
            final cell states, shaped ``(num_layers, batch, hidden_size)``,
            or ``(num_layers * 2, batch, hidden_size)`` when bidirectional.
        """
        assert input.dim() == 3
        dir_ = 2 if self.bidirectional else 1
        if c0 is None:
            zeros = input.data.new(input.size(1), self.n_out * dir_).zero_()
            c0 = [zeros for _ in range(self.depth)]
        else:
            if isinstance(c0, tuple):
                c0 = c0[0]
            assert c0.dim() == 3
            per_layer = c0.chunk(self.depth, 0)
            if self.bidirectional and c0.size(0) == 2 * self.depth:
                # Each chunk is (2, batch, n_out): [forward, backward]. The
                # kernels read the initial state as (batch, 2 * n_out) with the
                # forward direction first, so join the two directions on the
                # feature axis. Squeezing would leave the (2, batch, n_out)
                # layout, which the kernel would misread.
                c0 = [torch.cat((h[0], h[1]), 1) for h in per_layer]
            else:
                c0 = [h.squeeze(0) for h in per_layer]
        prevx = input
        lstc = []
        for i, rnn in enumerate(self.rnn_lst):
            h, c = rnn(prevx, c0[i])
            prevx = h
            lstc.append(c)
        if self.bidirectional:
            # Each layer's last state is already (2, batch, n_out).
            fh = torch.cat(lstc)
        else:
            fh = torch.stack(lstc)
        if return_hidden:
            return prevx, fh
        else:
            return prevx