from ...torch_core import *
from torch.utils.cpp_extension import load
from torch.autograd import Function
__all__ = ['QRNNLayer', 'QRNN']
import fastai
# Compile and load the fused ForgetMult CUDA kernels on the fly (GPU only);
# `forget_mult_CPU` below provides the pure-PyTorch fallback.
if torch.cuda.is_available():
    fastai_path = Path(fastai.__path__[0])/'text'/'models'
    files = ['forget_mult_cuda.cpp', 'forget_mult_cuda_kernel.cu']
    forget_mult_cuda = load(name='forget_mult_cuda', sources=[fastai_path/f for f in files])
    files = ['bwd_forget_mult_cuda.cpp', 'bwd_forget_mult_cuda_kernel.cu']
    bwd_forget_mult_cuda = load(name='bwd_forget_mult_cuda', sources=[fastai_path/f for f in files])
def dispatch_cuda(cuda_class, cpu_func, x):
    "Depending on the device of `x`, return the CUDA `Function` or the CPU fallback."
    return cuda_class.apply if x.device.type == 'cuda' else cpu_func
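# Illustrative use (mirrors `QRNNLayer.forward` below): on a CUDA tensor this returns
# `ForgetMultGPU.apply`, on a CPU tensor `forget_mult_CPU`; both accept
# (x, f, hidden_init, batch_first).
#   forget_mult = dispatch_cuda(ForgetMultGPU, forget_mult_CPU, x)
#   h = forget_mult(z_gate, f_gate, None, True)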
class ForgetMultGPU(Function):
    "Compute h_t = f_t * x_t + (1 - f_t) * h_(t-1), scanning the sequence left to right on the GPU."

    @staticmethod
    def forward(ctx, x:Tensor, f:Tensor, hidden_init:Optional[Tensor]=None, batch_first:bool=True):
        if batch_first:
            batch_size, seq_size, hidden_size = f.size()
            output = f.new_zeros(batch_size, seq_size + 1, hidden_size)
            if hidden_init is not None: output[:, 0] = hidden_init
            else: output.zero_()
        else:
            seq_size, batch_size, hidden_size = f.size()
            output = f.new(seq_size + 1, batch_size, hidden_size)
            if hidden_init is not None: output[0] = hidden_init
            else: output.zero_()
        output = forget_mult_cuda.forward(x, f, output, batch_first)
        ctx.save_for_backward(x, f, hidden_init, output)
        ctx.batch_first = batch_first
        return output[:,1:] if batch_first else output[1:]

    @staticmethod
    def backward(ctx, grad_output):
        x, f, hidden_init, output = ctx.saved_tensors
        grad_x, grad_f, grad_h = forget_mult_cuda.backward(x, f, output, grad_output, ctx.batch_first)
        return (grad_x, grad_f, (None if hidden_init is None else grad_h), None)
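# Illustrative call (assumes a CUDA device and that the extension above compiled):
#   x = torch.rand(8, 20, 64, device='cuda')       # (batch, seq, hidden)
#   f = torch.sigmoid(torch.randn_like(x))
#   h = ForgetMultGPU.apply(x.contiguous(), f.contiguous(), None, True)   # same shape as x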
class BwdForgetMultGPU(Function):
    "Same recurrence as `ForgetMultGPU`, but scanning right to left (used by backward QRNN layers)."

    @staticmethod
    def forward(ctx, x:Tensor, f:Tensor, hidden_init:Optional[Tensor]=None, batch_first:bool=True):
        if batch_first:
            batch_size, seq_size, hidden_size = f.size()
            output = f.new(batch_size, seq_size + 1, hidden_size)
            if hidden_init is not None: output[:, -1] = hidden_init
            else: output.zero_()
        else:
            seq_size, batch_size, hidden_size = f.size()
            output = f.new(seq_size + 1, batch_size, hidden_size)
            if hidden_init is not None: output[-1] = hidden_init
            else: output.zero_()
        output = bwd_forget_mult_cuda.forward(x, f, output, batch_first)
        ctx.save_for_backward(x, f, hidden_init, output)
        ctx.batch_first = batch_first
        return output[:,:-1] if batch_first else output[:-1]

    @staticmethod
    def backward(ctx, grad_output:Tensor):
        x, f, hidden_init, output = ctx.saved_tensors
        grad_x, grad_f, grad_h = bwd_forget_mult_cuda.backward(x, f, output, grad_output, ctx.batch_first)
        return (grad_x, grad_f, (None if hidden_init is None else grad_h), None)
def forget_mult_CPU(x:Tensor, f:Tensor, hidden_init:Optional[Tensor]=None, batch_first:bool=True, backward:bool=False):
    "Pure-PyTorch fallback: sequentially compute h_t = f_t * x_t + (1 - f_t) * h_(t-1) along the sequence."
    result = []
    dim = (1 if batch_first else 0)
    forgets = f.split(1, dim=dim)
    inputs = x.split(1, dim=dim)
    prev_h = None if hidden_init is None else hidden_init.unsqueeze(1 if batch_first else 0)
    idx_range = range(len(inputs)-1,-1,-1) if backward else range(len(inputs))
    for i in idx_range:
        prev_h = inputs[i] * forgets[i] if prev_h is None else inputs[i] * forgets[i] + (1-forgets[i]) * prev_h
        if backward: result.insert(0, prev_h)
        else:        result.append(prev_h)
    return torch.cat(result, dim=dim)
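# Minimal sanity check (illustrative only): the CPU and GPU paths are interchangeable
# through `dispatch_cuda`, so they share shapes and semantics.
#   x = torch.rand(2, 5, 4)                         # (batch, seq, hidden), batch_first=True
#   f = torch.sigmoid(torch.randn(2, 5, 4))
#   h = forget_mult_CPU(x, f)                       # no initial hidden state
#   h0 = torch.zeros(2, 4)
#   h2 = forget_mult_CPU(x, f, hidden_init=h0)      # h2[:, 0] == f[:, 0]*x[:, 0] + (1-f[:, 0])*h0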
class QRNNLayer(Module):
    "Apply a single layer Quasi-Recurrent Neural Network (QRNN) to an input sequence."

    def __init__(self, input_size:int, hidden_size:int=None, save_prev_x:bool=False, zoneout:float=0, window:int=1,
                 output_gate:bool=True, batch_first:bool=True, backward:bool=False):
        super().__init__()
        assert window in [1, 2], "This QRNN implementation currently only handles convolutional windows of size 1 or 2"
        self.save_prev_x,self.zoneout,self.window = save_prev_x,zoneout,window
        self.output_gate,self.batch_first,self.backward = output_gate,batch_first,backward
        hidden_size = ifnone(hidden_size, input_size)
        #One large matmul with concat is faster than N small matmuls and no concat
        mult = (3 if output_gate else 2)
        self.linear = nn.Linear(window * input_size, mult * hidden_size)
        self.prevX = None

    def reset(self):
        "If you are saving the previous value of x, call this when starting on a new sequence."
        self.prevX = None

    def forward(self, inp, hid=None):
        y = self.linear(self._get_source(inp))
        if self.output_gate: z_gate,f_gate,o_gate = y.chunk(3, dim=2)
        else:                z_gate,f_gate        = y.chunk(2, dim=2)
        z_gate.tanh_()
        f_gate.sigmoid_()
        #Zoneout: drop random elements of the forget gate, so those units keep their previous state.
        if self.zoneout and self.training:
            mask = dropout_mask(f_gate, f_gate.size(), self.zoneout).requires_grad_(False)
            f_gate = f_gate * mask
        z_gate,f_gate = z_gate.contiguous(),f_gate.contiguous()
        if self.backward: forget_mult = dispatch_cuda(BwdForgetMultGPU, partial(forget_mult_CPU, backward=True), inp)
        else:             forget_mult = dispatch_cuda(ForgetMultGPU, forget_mult_CPU, inp)
        c_gate = forget_mult(z_gate, f_gate, hid, self.batch_first)
        output = torch.sigmoid(o_gate) * c_gate if self.output_gate else c_gate
        if self.window > 1 and self.save_prev_x:
            if self.backward: self.prevX = (inp[:, :1]  if self.batch_first else inp[:1] ).detach()
            else:             self.prevX = (inp[:, -1:] if self.batch_first else inp[-1:]).detach()
        idx = 0 if self.backward else -1
        return output, (c_gate[:, idx] if self.batch_first else c_gate[idx])

    def _get_source(self, inp):
        "For window=2, pair each timestep with its previous one (its next one for a backward layer)."
        if self.window == 1: return inp
        dim = (1 if self.batch_first else 0)
        inp_shift = [torch.zeros_like(inp[:,:1] if self.batch_first else inp[:1]) if self.prevX is None else self.prevX]
        if self.backward: inp_shift.insert(0, inp[:,1:] if self.batch_first else inp[1:])
        else:             inp_shift.append(inp[:,:-1] if self.batch_first else inp[:-1])
        inp_shift = torch.cat(inp_shift, dim)
        return torch.cat([inp, inp_shift], 2)
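# Illustrative usage (not part of the library): a single forward-direction layer on a
# batch-first input of shape (batch, seq, input_size).
#   layer = QRNNLayer(input_size=32, hidden_size=64, window=2, save_prev_x=True)
#   x = torch.rand(8, 20, 32)
#   out, h = layer(x)      # out: (8, 20, 64); h: (8, 64), the hidden state at the last timestep
#   layer.reset()          # drop the stored prevX before an unrelated sequence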
class QRNN(Module):
    "Apply a multiple layer Quasi-Recurrent Neural Network (QRNN) to an input sequence."

    def __init__(self, input_size:int, hidden_size:int, n_layers:int=1, bias:bool=True, batch_first:bool=True,
                 dropout:float=0, bidirectional:bool=False, save_prev_x:bool=False, zoneout:float=0, window:int=None,
                 output_gate:bool=True):
        assert not (save_prev_x and bidirectional), "Can't save the previous X with bidirectional."
        assert bias == True, 'Removing underlying bias is not yet supported'
        super().__init__()
        kwargs = dict(batch_first=batch_first, zoneout=zoneout, output_gate=output_gate)
        #By default the first layer uses a window of 2, the following layers a window of 1.
        self.layers = nn.ModuleList([QRNNLayer(input_size if l == 0 else hidden_size, hidden_size, save_prev_x=save_prev_x,
                                               window=((2 if l == 0 else 1) if window is None else window), **kwargs)
                                     for l in range(n_layers)])
        if bidirectional:
            self.layers_bwd = nn.ModuleList([QRNNLayer(input_size if l == 0 else hidden_size, hidden_size,
                                                       backward=True, window=((2 if l == 0 else 1) if window is None else window),
                                                       **kwargs) for l in range(n_layers)])
        self.n_layers,self.batch_first,self.dropout,self.bidirectional = n_layers,batch_first,dropout,bidirectional

    def reset(self):
        "If your convolutional window is greater than 1 and you save previous xs, reset at the beginning of each new sequence."
        for layer in self.layers: layer.reset()
        if self.bidirectional:
            for layer in self.layers_bwd: layer.reset()

    def forward(self, inp, hid=None):
        new_hid = []
        if self.bidirectional: inp_bwd = inp.clone()
        for i, layer in enumerate(self.layers):
            inp, h = layer(inp, None if hid is None else hid[2*i if self.bidirectional else i])
            new_hid.append(h)
            if self.bidirectional:
                inp_bwd, h_bwd = self.layers_bwd[i](inp_bwd, None if hid is None else hid[2*i+1])
                new_hid.append(h_bwd)
            if self.dropout != 0 and i < len(self.layers) - 1:
                #Inter-layer dropout, applied to the outputs that feed the next layer
                inp = F.dropout(inp, p=self.dropout, training=self.training, inplace=False)
                if self.bidirectional:
                    inp_bwd = F.dropout(inp_bwd, p=self.dropout, training=self.training, inplace=False)
        if self.bidirectional: inp = torch.cat([inp, inp_bwd], dim=2)
        return inp, torch.stack(new_hid, 0)
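# Illustrative end-to-end example (a sketch, not part of the library; only runs when this
# module is executed as a script): shapes for a 2-layer bidirectional QRNN, batch-first.
if __name__ == '__main__':
    qrnn = QRNN(input_size=32, hidden_size=64, n_layers=2, bidirectional=True, dropout=0.1)
    x = torch.rand(8, 20, 32)       # (batch, seq, input_size)
    out, hid = qrnn(x)
    print(out.shape)                # torch.Size([8, 20, 128]): forward and backward outputs concatenated
    print(hid.shape)                # torch.Size([4, 8, 64]):   one hidden state per direction per layer
    qrnn.reset()                    # clears any stored prevX (relevant when save_prev_x=True)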