|
|
|
""" Implement a pyTorch LSTM with hard sigmoid reccurent activation functions. |
|
Adapted from the non-CUDA variant of the PyTorch LSTM at
|
https://github.com/pytorch/pytorch/blob/master/torch/nn/_functions/rnn.py |
|
""" |
|
|
|
from __future__ import print_function, division |
|
import math |
|
import torch |
|
|
|
from torch.nn import Module |
|
from torch.nn.parameter import Parameter |
|
from torch.nn.utils.rnn import PackedSequence |
|
import torch.nn.functional as F |
|
|
|
class LSTMHardSigmoid(Module): |
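    """An LSTM module with the same constructor and forward interface as
    ``torch.nn.LSTM``, but using a hard sigmoid for the input, forget and
    output gate activations (see :func:`LSTMCell` below)."""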
|
|
|
def __init__(self, input_size, hidden_size, |
|
num_layers=1, bias=True, batch_first=False, |
|
dropout=0, bidirectional=False): |
|
super(LSTMHardSigmoid, self).__init__() |
|
self.input_size = input_size |
|
self.hidden_size = hidden_size |
|
self.num_layers = num_layers |
|
self.bias = bias |
|
self.batch_first = batch_first |
|
self.dropout = dropout |
|
self.dropout_state = {} |
|
self.bidirectional = bidirectional |
|
num_directions = 2 if bidirectional else 1 |
|
|
|
gate_size = 4 * hidden_size |
|
|
|
self._all_weights = [] |
|
for layer in range(num_layers): |
|
for direction in range(num_directions): |
|
layer_input_size = input_size if layer == 0 else hidden_size * num_directions |
|
|
|
w_ih = Parameter(torch.Tensor(gate_size, layer_input_size)) |
|
w_hh = Parameter(torch.Tensor(gate_size, hidden_size)) |
|
b_ih = Parameter(torch.Tensor(gate_size)) |
|
b_hh = Parameter(torch.Tensor(gate_size)) |
|
layer_params = (w_ih, w_hh, b_ih, b_hh) |
|
|
|
suffix = '_reverse' if direction == 1 else '' |
|
param_names = ['weight_ih_l{}{}', 'weight_hh_l{}{}'] |
|
if bias: |
|
param_names += ['bias_ih_l{}{}', 'bias_hh_l{}{}'] |
|
param_names = [x.format(layer, suffix) for x in param_names] |
|
|
|
for name, param in zip(param_names, layer_params): |
|
setattr(self, name, param) |
|
self._all_weights.append(param_names) |
|
|
|
self.flatten_parameters() |
|
self.reset_parameters() |
|
|
|
def flatten_parameters(self): |
|
"""Resets parameter data pointer so that they can use faster code paths. |
|
|
|
        Right now, this is a no-op since we don't use CUDA acceleration.
|
""" |
|
self._data_ptrs = [] |
|
|
|
def _apply(self, fn): |
|
ret = super(LSTMHardSigmoid, self)._apply(fn) |
|
self.flatten_parameters() |
|
return ret |
|
|
|
def reset_parameters(self): |
|
stdv = 1.0 / math.sqrt(self.hidden_size) |
|
for weight in self.parameters(): |
|
weight.data.uniform_(-stdv, stdv) |
|
|
|
def forward(self, input, hx=None): |
|
is_packed = isinstance(input, PackedSequence) |
|
if is_packed: |
|
input, batch_sizes = input |
|
max_batch_size = batch_sizes[0] |
|
else: |
|
batch_sizes = None |
|
max_batch_size = input.size(0) if self.batch_first else input.size(1) |
|
|
|
if hx is None: |
|
num_directions = 2 if self.bidirectional else 1 |
|
hx = torch.autograd.Variable(input.data.new(self.num_layers * |
|
num_directions, |
|
max_batch_size, |
|
self.hidden_size).zero_(), requires_grad=False) |
|
hx = (hx, hx) |
|
|
|
has_flat_weights = list(p.data.data_ptr() for p in self.parameters()) == self._data_ptrs |
|
if has_flat_weights: |
|
first_data = next(self.parameters()).data |
|
assert first_data.storage().size() == self._param_buf_size |
|
flat_weight = first_data.new().set_(first_data.storage(), 0, torch.Size([self._param_buf_size])) |
|
else: |
|
flat_weight = None |
|
func = AutogradRNN( |
|
self.input_size, |
|
self.hidden_size, |
|
num_layers=self.num_layers, |
|
batch_first=self.batch_first, |
|
dropout=self.dropout, |
|
train=self.training, |
|
bidirectional=self.bidirectional, |
|
batch_sizes=batch_sizes, |
|
dropout_state=self.dropout_state, |
|
flat_weight=flat_weight |
|
) |
|
output, hidden = func(input, self.all_weights, hx) |
|
if is_packed: |
|
output = PackedSequence(output, batch_sizes) |
|
return output, hidden |
|
|
|
def __repr__(self): |
|
s = '{name}({input_size}, {hidden_size}' |
|
if self.num_layers != 1: |
|
s += ', num_layers={num_layers}' |
|
if self.bias is not True: |
|
s += ', bias={bias}' |
|
if self.batch_first is not False: |
|
s += ', batch_first={batch_first}' |
|
if self.dropout != 0: |
|
s += ', dropout={dropout}' |
|
if self.bidirectional is not False: |
|
s += ', bidirectional={bidirectional}' |
|
s += ')' |
|
return s.format(name=self.__class__.__name__, **self.__dict__) |
|
|
|
def __setstate__(self, d): |
|
super(LSTMHardSigmoid, self).__setstate__(d) |
|
self.__dict__.setdefault('_data_ptrs', []) |
|
if 'all_weights' in d: |
|
self._all_weights = d['all_weights'] |
|
if isinstance(self._all_weights[0][0], str): |
|
return |
|
num_layers = self.num_layers |
|
num_directions = 2 if self.bidirectional else 1 |
|
self._all_weights = [] |
|
for layer in range(num_layers): |
|
for direction in range(num_directions): |
|
suffix = '_reverse' if direction == 1 else '' |
|
weights = ['weight_ih_l{}{}', 'weight_hh_l{}{}', 'bias_ih_l{}{}', 'bias_hh_l{}{}'] |
|
weights = [x.format(layer, suffix) for x in weights] |
|
if self.bias: |
|
self._all_weights += [weights] |
|
else: |
|
self._all_weights += [weights[:2]] |
|
|
|
@property |
|
def all_weights(self): |
|
return [[getattr(self, weight) for weight in weights] for weights in self._all_weights] |
|
|
|
def AutogradRNN(input_size, hidden_size, num_layers=1, batch_first=False, |
|
dropout=0, train=True, bidirectional=False, batch_sizes=None, |
|
dropout_state=None, flat_weight=None): |
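    """Build the forward function for a stacked, optionally bidirectional LSTM
    that runs step by step on the autograd engine, using the hard-sigmoid
    :func:`LSTMCell` defined below."""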
|
|
|
cell = LSTMCell |
|
|
|
if batch_sizes is None: |
|
rec_factory = Recurrent |
|
else: |
|
rec_factory = variable_recurrent_factory(batch_sizes) |
|
|
|
if bidirectional: |
|
layer = (rec_factory(cell), rec_factory(cell, reverse=True)) |
|
else: |
|
layer = (rec_factory(cell),) |
|
|
|
func = StackedRNN(layer, |
|
num_layers, |
|
True, |
|
dropout=dropout, |
|
train=train) |
|
|
|
def forward(input, weight, hidden): |
|
if batch_first and batch_sizes is None: |
|
input = input.transpose(0, 1) |
|
|
|
nexth, output = func(input, hidden, weight) |
|
|
|
if batch_first and batch_sizes is None: |
|
output = output.transpose(0, 1) |
|
|
|
return output, nexth |
|
|
|
return forward |
|
|
|
def Recurrent(inner, reverse=False): |
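    """Wrap a cell function so it is applied to every time step of a
    fixed-length (non-packed) sequence, optionally iterating in reverse."""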
|
def forward(input, hidden, weight): |
|
output = [] |
|
steps = range(input.size(0) - 1, -1, -1) if reverse else range(input.size(0)) |
|
for i in steps: |
|
hidden = inner(input[i], hidden, *weight) |
|
|
|
output.append(hidden[0] if isinstance(hidden, tuple) else hidden) |
|
|
|
if reverse: |
|
output.reverse() |
|
output = torch.cat(output, 0).view(input.size(0), *output[0].size()) |
|
|
|
return hidden, output |
|
|
|
return forward |
|
|
|
|
|
def variable_recurrent_factory(batch_sizes): |
|
def fac(inner, reverse=False): |
|
if reverse: |
|
return VariableRecurrentReverse(batch_sizes, inner) |
|
else: |
|
return VariableRecurrent(batch_sizes, inner) |
|
return fac |
|
|
|
def VariableRecurrent(batch_sizes, inner): |
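    """Wrap a cell function so it is applied to a packed sequence, shrinking
    the effective batch as shorter sequences end and collecting their final
    hidden states."""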
|
def forward(input, hidden, weight): |
|
output = [] |
|
input_offset = 0 |
|
last_batch_size = batch_sizes[0] |
|
hiddens = [] |
|
flat_hidden = not isinstance(hidden, tuple) |
|
if flat_hidden: |
|
hidden = (hidden,) |
|
for batch_size in batch_sizes: |
|
step_input = input[input_offset:input_offset + batch_size] |
|
input_offset += batch_size |
|
|
|
dec = last_batch_size - batch_size |
|
if dec > 0: |
|
hiddens.append(tuple(h[-dec:] for h in hidden)) |
|
hidden = tuple(h[:-dec] for h in hidden) |
|
last_batch_size = batch_size |
|
|
|
if flat_hidden: |
|
hidden = (inner(step_input, hidden[0], *weight),) |
|
else: |
|
hidden = inner(step_input, hidden, *weight) |
|
|
|
output.append(hidden[0]) |
|
hiddens.append(hidden) |
|
hiddens.reverse() |
|
|
|
hidden = tuple(torch.cat(h, 0) for h in zip(*hiddens)) |
|
assert hidden[0].size(0) == batch_sizes[0] |
|
if flat_hidden: |
|
hidden = hidden[0] |
|
output = torch.cat(output, 0) |
|
|
|
return hidden, output |
|
|
|
return forward |
|
|
|
|
|
def VariableRecurrentReverse(batch_sizes, inner): |
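    """Wrap a cell function so it is applied to a packed sequence in reverse
    time order, growing the effective batch from the initial hidden state as
    longer sequences become active."""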
|
def forward(input, hidden, weight): |
|
output = [] |
|
input_offset = input.size(0) |
|
last_batch_size = batch_sizes[-1] |
|
initial_hidden = hidden |
|
flat_hidden = not isinstance(hidden, tuple) |
|
if flat_hidden: |
|
hidden = (hidden,) |
|
initial_hidden = (initial_hidden,) |
|
hidden = tuple(h[:batch_sizes[-1]] for h in hidden) |
|
for batch_size in reversed(batch_sizes): |
|
inc = batch_size - last_batch_size |
|
if inc > 0: |
|
hidden = tuple(torch.cat((h, ih[last_batch_size:batch_size]), 0) |
|
for h, ih in zip(hidden, initial_hidden)) |
|
last_batch_size = batch_size |
|
step_input = input[input_offset - batch_size:input_offset] |
|
input_offset -= batch_size |
|
|
|
if flat_hidden: |
|
hidden = (inner(step_input, hidden[0], *weight),) |
|
else: |
|
hidden = inner(step_input, hidden, *weight) |
|
output.append(hidden[0]) |
|
|
|
output.reverse() |
|
output = torch.cat(output, 0) |
|
if flat_hidden: |
|
hidden = hidden[0] |
|
return hidden, output |
|
|
|
return forward |
|
|
|
def StackedRNN(inners, num_layers, lstm=False, dropout=0, train=True): |
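    """Compose the per-direction recurrence functions into a multi-layer RNN:
    direction outputs are concatenated along the feature dimension and dropout
    is applied between layers."""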
|
|
|
num_directions = len(inners) |
|
total_layers = num_layers * num_directions |
|
|
|
def forward(input, hidden, weight): |
|
assert(len(weight) == total_layers) |
|
next_hidden = [] |
|
|
|
if lstm: |
|
hidden = list(zip(*hidden)) |
|
|
|
for i in range(num_layers): |
|
all_output = [] |
|
for j, inner in enumerate(inners): |
|
l = i * num_directions + j |
|
|
|
hy, output = inner(input, hidden[l], weight[l]) |
|
next_hidden.append(hy) |
|
all_output.append(output) |
|
|
|
input = torch.cat(all_output, input.dim() - 1) |
|
|
|
if dropout != 0 and i < num_layers - 1: |
|
input = F.dropout(input, p=dropout, training=train, inplace=False) |
|
|
|
if lstm: |
|
next_h, next_c = zip(*next_hidden) |
|
next_hidden = ( |
|
torch.cat(next_h, 0).view(total_layers, *next_h[0].size()), |
|
torch.cat(next_c, 0).view(total_layers, *next_c[0].size()) |
|
) |
|
else: |
|
next_hidden = torch.cat(next_hidden, 0).view( |
|
total_layers, *next_hidden[0].size()) |
|
|
|
return next_hidden, input |
|
|
|
return forward |
|
|
|
def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None): |
|
""" |
|
A modified LSTM cell with hard sigmoid activation on the input, forget and output gates. |
|
""" |
|
hx, cx = hidden |
|
gates = F.linear(input, w_ih, b_ih) + F.linear(hx, w_hh, b_hh) |
|
|
|
ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) |
|
|
|
ingate = hard_sigmoid(ingate) |
|
forgetgate = hard_sigmoid(forgetgate) |
|
cellgate = F.tanh(cellgate) |
|
outgate = hard_sigmoid(outgate) |
|
|
|
cy = (forgetgate * cx) + (ingate * cellgate) |
|
hy = outgate * F.tanh(cy) |
|
|
|
return hy, cy |
|
|
|
def hard_sigmoid(x): |
|
""" |
|
    Computes the element-wise hard sigmoid of x, i.e. clip(0.2 * x + 0.5, 0, 1).
|
See e.g. https://github.com/Theano/Theano/blob/master/theano/tensor/nnet/sigm.py#L279 |
|
""" |
|
    x = (0.2 * x) + 0.5
    # Clamp to [0, 1] via two negated thresholds: the first caps values at 1,
    # the second floors them at 0.
    x = F.threshold(-x, -1, -1)
    x = F.threshold(-x, 0, 0)
|
return x |
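

# A minimal usage sketch (not part of the adapted upstream code): the module,
# layer and batch sizes below are illustrative assumptions, following the
# default (seq_len, batch, input_size) layout also used by torch.nn.LSTM.
if __name__ == '__main__':
    lstm = LSTMHardSigmoid(input_size=8, hidden_size=16, num_layers=2,
                           bidirectional=True)
    x = torch.autograd.Variable(torch.randn(5, 3, 8))  # seq_len=5, batch=3
    output, (h_n, c_n) = lstm(x)
    print(output.size())  # (5, 3, 32): features = num_directions * hidden_size
    print(h_n.size())     # (4, 3, 16): num_layers * num_directions hidden states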
|
|