import torch
from torch import nn
from torch.autograd import Function
import torch.nn.functional as F

import lightconv_cuda
from fairseq import utils
from fairseq.incremental_decoding_utils import with_incremental_state
from fairseq.modules.fairseq_dropout import FairseqDropout
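
# lightconvFunction is a thin autograd wrapper around the compiled
# lightconv_cuda extension: forward() runs the fused CUDA kernel and saves the
# inputs, backward() returns gradients for x and weights (and None for the
# non-tensor padding_l argument).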
class lightconvFunction(Function):

    @staticmethod
    def forward(ctx, x, weights, padding_l):
        ctx.padding_l = padding_l
        outputs = lightconv_cuda.forward(x, weights, padding_l)
        variables = [x, weights]
        ctx.save_for_backward(*variables)
        return outputs[0]

    @staticmethod
    def backward(ctx, grad_output):
        outputs = lightconv_cuda.backward(
            grad_output.contiguous(), ctx.padding_l, *ctx.saved_tensors
        )
        grad_input, grad_weights = outputs
        return grad_input, grad_weights, None
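
# LightconvLayer implements the CUDA-backed lightweight convolution: a
# depthwise convolution whose (num_heads, kernel_size) kernel is shared across
# channel groups, optionally softmax-normalized along the kernel dimension and
# regularized with dropout on the weights.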
@with_incremental_state
class LightconvLayer(nn.Module):
    def __init__(
        self,
        input_size,
        kernel_size=1,
        padding_l=None,
        weight_softmax=False,
        num_heads=1,
        weight_dropout=0.,
        bias=False,
    ):
        super(LightconvLayer, self).__init__()
        self.input_size = input_size
        self.kernel_size = kernel_size
        self.padding_l = padding_l
        self.num_heads = num_heads
        self.weight_softmax = weight_softmax
        self.weight_dropout_module = FairseqDropout(
            weight_dropout, module_name=self.__class__.__name__
        )

        self.weight = nn.Parameter(torch.Tensor(num_heads, kernel_size))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(input_size))
        else:
            self.bias = None
        self.reset_parameters()

    def upgrade_state_dict_named(self, state_dict, name):
        # Older checkpoints stored the weight as (num_heads, 1, kernel_size);
        # squeeze the singleton dimension so it matches the 2D parameter.
        prefix = name + '.' if name != '' else ''
        for k, v in state_dict.items():
            if k.endswith(prefix + 'weight'):
                if v.dim() == 3 and v.size(1) == 1:
                    state_dict[k] = v.squeeze(1)

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.weight)
        if self.bias is not None:
            nn.init.constant_(self.bias, 0.)

    def forward(self, x, incremental_state=None):
        # x: (T, B, C) with C == input_size.

        # Incremental decoding: keep a rolling buffer of the last
        # kernel_size - 1 time steps and compute the convolution as a
        # batched matrix multiplication over the buffered window.
        if incremental_state is not None:
            T, B, C = x.size()
            K, H = self.kernel_size, self.num_heads
            R = C // H
            input_buffer = self._get_input_buffer(incremental_state)
            if input_buffer is None:
                input_buffer = x.new()
            x_unfold = torch.cat([input_buffer, x.unsqueeze(3)], dim=3)
            if self.kernel_size > 1:
                self._set_input_buffer(incremental_state, x_unfold[:, :, :, -self.kernel_size+1:])
            x_unfold = x_unfold.view(T*B*H, R, -1)

            weight = self.weight
            if self.weight_softmax:
                weight = F.softmax(weight.float(), dim=1).type_as(weight)

            # Keep only the kernel taps that overlap the buffered window.
            weight = weight[:, -x_unfold.size(2):]

            K = weight.size(1)

            weight = weight.view(1, H, K).expand(T*B, H, K).contiguous().view(T*B*H, K, 1)

            weight = self.weight_dropout_module(weight)
            output = torch.bmm(x_unfold, weight)
            output = output.view(T, B, C)
            return output

        # Training / full-sequence path: the fused CUDA kernel expects the
        # (B, C, T) layout, so permute in and permute back afterwards.
        else:
            x = x.permute(1, 2, 0).contiguous()
            weight = self.weight
            if self.weight_softmax:
                weight = F.softmax(self.weight, -1)
            if self.weight_dropout_module.p:
                weight = self.weight_dropout_module(weight)
            return lightconvFunction.apply(x, weight, self.padding_l).permute(2, 0, 1)
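
    # reorder_incremental_state is called during beam search: the buffered
    # context is re-indexed along its batch/beam dimension (dim 1) so each
    # buffer keeps following its own hypothesis.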
    def reorder_incremental_state(self, incremental_state, new_order):
        input_buffer = self._get_input_buffer(incremental_state)
        if input_buffer is not None:
            input_buffer = input_buffer.index_select(1, new_order)
            self._set_input_buffer(incremental_state, input_buffer)

    def _get_input_buffer(self, incremental_state):
        return utils.get_incremental_state(self, incremental_state, 'input_buffer')

    def _set_input_buffer(self, incremental_state, new_buffer):
        return utils.set_incremental_state(self, incremental_state, 'input_buffer', new_buffer)

    def half(self):
        # Cast only floating-point parameters/buffers to fp16; leave any
        # integer tensors untouched.
        return self._apply(lambda t: t.half() if t.is_floating_point() else t)
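

# Minimal usage sketch, assuming the lightconv_cuda extension has been built
# and a CUDA device is available; the input follows the (T, B, C) layout that
# forward() above expects. The shapes and padding choice here are illustrative.
if __name__ == '__main__':
    if torch.cuda.is_available():
        T, B, C, H, K = 8, 2, 16, 4, 3
        layer = LightconvLayer(
            input_size=C,
            kernel_size=K,
            padding_l=K - 1,  # left padding keeps the convolution causal
            weight_softmax=True,
            num_heads=H,
        ).cuda()
        x = torch.randn(T, B, C, device='cuda')
        out = layer(x)
        print(out.shape)  # expected: torch.Size([8, 2, 16])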