# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
from fairseq.incremental_decoding_utils import (
    FairseqIncrementalState,
    with_incremental_state,
)
from fairseq.modules.fairseq_dropout import FairseqDropout
from torch import Tensor

from .unfold import unfold1d


def DynamicConv(
    input_size,
    kernel_size=1,
    padding_l=None,
    num_heads=1,
    weight_dropout=0.0,
    weight_softmax=False,
    renorm_padding=False,
    bias=False,
    conv_bias=False,
    query_size=None,
    in_proj=False,
):
    if torch.cuda.is_available():
        try:
            from fairseq.modules.dynamicconv_layer import DynamicconvLayer

            return DynamicconvLayer(
                input_size,
                kernel_size=kernel_size,
                padding_l=padding_l,
                num_heads=num_heads,
                weight_dropout=weight_dropout,
                weight_softmax=weight_softmax,
                renorm_padding=renorm_padding,
                bias=bias,
                conv_bias=conv_bias,
                query_size=query_size,
            )
        except ImportError as e:
            print(e)
    return DynamicConv1dTBC(
        input_size,
        kernel_size=kernel_size,
        padding_l=padding_l,
        num_heads=num_heads,
        weight_dropout=weight_dropout,
        weight_softmax=weight_softmax,
        renorm_padding=renorm_padding,
        bias=bias,
        conv_bias=conv_bias,
        query_size=query_size,
    )
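

# Usage sketch for the DynamicConv factory with illustrative sizes (this helper
# is not part of the original module). On a CUDA machine the factory tries the
# fused DynamicconvLayer kernel and otherwise falls back to the pure-PyTorch
# DynamicConv1dTBC below; the sketch assumes both accept a T x B x C tensor as
# the first argument of forward().
def _example_dynamic_conv_factory():
    conv = DynamicConv(
        input_size=256,
        kernel_size=3,
        padding_l=2,  # kernel_size - 1, i.e. causal "same" padding
        num_heads=8,
        weight_softmax=True,
    )
    device = "cuda" if torch.cuda.is_available() else "cpu"
    conv = conv.to(device)
    x = torch.randn(10, 2, 256, device=device)  # T x B x C
    return conv(x)  # T x B x C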


def Linear(in_features, out_features, bias=True):
    m = nn.Linear(in_features, out_features, bias)
    nn.init.xavier_uniform_(m.weight)
    if bias:
        nn.init.constant_(m.bias, 0.0)
    return m


@with_incremental_state
class DynamicConv1dTBC(nn.Module):
| """Dynamic lightweight convolution taking T x B x C inputs | |
| Args: | |
| input_size: # of channels of the input | |
| kernel_size: convolution channels | |
| padding_l: padding to the left when using "same" padding | |
| num_heads: number of heads used. The weight is of shape (num_heads, 1, kernel_size) | |
| weight_dropout: the drop rate of the DropConnect to drop the weight | |
| weight_softmax: normalize the weight with softmax before the convolution | |
| renorm_padding: re-normalize the filters to ignore the padded part (only the non-padding parts sum up to 1) | |
| bias: use bias | |
| conv_bias: bias of the convolution | |
| query_size: specified when feeding a different input as the query | |
| in_proj: project the input and generate the filter together | |
| Shape: | |
| Input: TxBxC, i.e. (timesteps, batch_size, input_size) | |
| Output: TxBxC, i.e. (timesteps, batch_size, input_size) | |
| Attributes: | |
| weight: the learnable weights of the module of shape | |
| `(num_heads, 1, kernel_size)` | |
| bias: the learnable bias of the module of shape `(input_size)` | |
| """ | |

    def __init__(
        self,
        input_size,
        kernel_size=1,
        padding_l=None,
        num_heads=1,
        weight_dropout=0.0,
        weight_softmax=False,
        renorm_padding=False,
        bias=False,
        conv_bias=False,
        query_size=None,
        in_proj=False,
    ):
        super().__init__()
        self.input_size = input_size
        self.query_size = input_size if query_size is None else query_size
        self.kernel_size = kernel_size
        self.padding_l = padding_l
        self.num_heads = num_heads
        self.weight_dropout_module = FairseqDropout(
            weight_dropout, module_name=self.__class__.__name__
        )
        self.weight_softmax = weight_softmax
        self.renorm_padding = renorm_padding

        if in_proj:
            self.weight_linear = Linear(
                self.input_size, self.input_size + num_heads * kernel_size * 1
            )
        else:
            self.weight_linear = Linear(
                self.query_size, num_heads * kernel_size * 1, bias=bias
            )
        if conv_bias:
            self.conv_bias = nn.Parameter(torch.Tensor(input_size))
        else:
            self.conv_bias = None
        self.reset_parameters()

    @property
    def in_proj(self):
        return (
            self.weight_linear.out_features
            == self.input_size + self.num_heads * self.kernel_size
        )

    def reset_parameters(self):
        self.weight_linear.reset_parameters()
        if self.conv_bias is not None:
            nn.init.constant_(self.conv_bias, 0.0)

    def forward(self, x, incremental_state=None, query=None, unfold=None):
        """Assuming the input, x, of the shape T x B x C and producing an output in the shape T x B x C
        args:
            x: Input of shape T x B x C, i.e. (timesteps, batch_size, input_size)
            incremental_state: A dict to keep the state
            unfold: unfold the input or not. If not, we use the matrix trick instead
            query: use the specified query to predict the conv filters
        """
        unfold = (
            x.size(0) > 512 if unfold is None else unfold
        )  # use unfold mode as default for long sequences to save memory
        unfold = unfold or (incremental_state is not None)
        assert query is None or not self.in_proj

        if query is None:
            query = x
        if unfold:
            output = self._forward_unfolded(x, incremental_state, query)
        else:
            output = self._forward_expanded(x, incremental_state, query)

        if self.conv_bias is not None:
            output = output + self.conv_bias.view(1, 1, -1)
        return output

    def _forward_unfolded(self, x, incremental_state, query):
        """The conventional implementation of convolutions.
        Unfolding the input by having a window shifting to the right."""
        T, B, C = x.size()
        K, H = self.kernel_size, self.num_heads
        R = C // H
        assert R * H == C == self.input_size

        if self.in_proj:
            proj = self.weight_linear(x)
            x = proj.narrow(2, 0, self.input_size).contiguous()
            weight = (
                proj.narrow(2, self.input_size, H * K).contiguous().view(T * B * H, -1)
            )
        else:
            weight = self.weight_linear(query).view(T * B * H, -1)

        # renorm_padding is only implemented in _forward_expanded
        assert not self.renorm_padding or incremental_state is not None

        if incremental_state is not None:
            input_buffer = self._get_input_buffer(incremental_state)
            if input_buffer is None:
                input_buffer = x.new()
            x_unfold = torch.cat([input_buffer, x.unsqueeze(3)], dim=3)
            if self.kernel_size > 1:
                self._set_input_buffer(
                    incremental_state, x_unfold[:, :, :, -self.kernel_size + 1 :]
                )
            x_unfold = x_unfold.view(T * B * H, R, -1)
        else:
            padding_l = self.padding_l
            if K > T and padding_l == K - 1:
                weight = weight.narrow(1, K - T, T)
                K, padding_l = T, T - 1
            # unfold the input: T x B x C --> T' x B x C x K
            x_unfold = unfold1d(x, K, padding_l, 0)
            x_unfold = x_unfold.view(T * B * H, R, K)

        if self.weight_softmax and not self.renorm_padding:
            weight = F.softmax(weight, dim=1)

        weight = weight.narrow(1, 0, K)

        if incremental_state is not None:
            weight = weight[:, -x_unfold.size(2) :]
            K = weight.size(1)

        if self.weight_softmax and self.renorm_padding:
            weight = F.softmax(weight, dim=1)

        weight = self.weight_dropout_module(weight, inplace=False)

        output = torch.bmm(x_unfold, weight.unsqueeze(2))  # T*B*H x R x 1
        output = output.view(T, B, C)
        return output
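
    # Shape summary for _forward_unfolded (non-incremental path), with
    # R = C // H channels per head:
    #   x:      T x B x C  --unfold1d-->  T x B x C x K  --view-->  (T*B*H) x R x K
    #   weight: (T*B*H) x K, one K-tap filter per head and timestep
    #   bmm:    (T*B*H) x R x 1  --view-->  T x B x C
    # so each output channel is a weighted sum over a K-wide window of the
    # corresponding input channel, with weights predicted from the query.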

    def _forward_expanded(self, x, incremental_state, query):
        """Turn the convolution filters into band matrices and do matrix multiplication.
        This is faster when the sequence is short, but less memory efficient.
        This is not used in the decoder during inference.
        """
        T, B, C = x.size()
        K, H = self.kernel_size, self.num_heads
        R = C // H
        assert R * H == C == self.input_size
        if self.in_proj:
            proj = self.weight_linear(x)
            x = proj.narrow(2, 0, self.input_size).contiguous()
            weight = (
                proj.narrow(2, self.input_size, H * K).contiguous().view(T * B * H, -1)
            )
        else:
            weight = self.weight_linear(query).view(T * B * H, -1)

        if not self.renorm_padding:
            if self.weight_softmax:
                weight = F.softmax(weight, dim=1)
            weight = self.weight_dropout_module(weight, inplace=False)
        weight = weight.narrow(1, 0, K).contiguous()
        weight = weight.view(T, B * H, K).transpose(0, 1)

        x = x.view(T, B * H, R).transpose(0, 1)
        if self.weight_softmax and self.renorm_padding:
            # turn the convolution filters into band matrices
            weight_expanded = weight.new(B * H, T, T + K - 1).fill_(float("-inf"))
            weight_expanded.as_strided(
                (B * H, T, K), (T * (T + K - 1), T + K, 1)
            ).copy_(weight)
            weight_expanded = weight_expanded.narrow(2, self.padding_l, T)
            # normalize the weight over valid positions like self-attention
            weight_expanded = F.softmax(weight_expanded, dim=2)
            weight_expanded = self.weight_dropout_module(weight_expanded, inplace=False)
        else:
            P = self.padding_l
            # For efficiency, we cut the kernel size and reduce the padding when the kernel is larger than the length
            if K > T and P == K - 1:
                weight = weight.narrow(2, K - T, T)
                K, P = T, T - 1
            # turn the convolution filters into band matrices
            weight_expanded = weight.new_zeros(B * H, T, T + K - 1, requires_grad=False)
            weight_expanded.as_strided(
                (B * H, T, K), (T * (T + K - 1), T + K, 1)
            ).copy_(weight)
            weight_expanded = weight_expanded.narrow(2, P, T)  # B*H x T x T
        output = torch.bmm(weight_expanded, x)
        output = output.transpose(0, 1).contiguous().view(T, B, C)
        return output
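
    # How the band-matrix trick above works: the K predicted taps for timestep
    # t are written into row t of a (T, T + K - 1) buffer at column offset t
    # (that is what as_strided(...).copy_(weight) does), and narrowing by the
    # left padding keeps T columns, leaving one banded T x T matrix per head
    # and sequence. A single bmm with the (B*H) x T x R inputs then applies
    # every position-specific filter at once.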

    def reorder_incremental_state(self, incremental_state, new_order):
        input_buffer = self._get_input_buffer(incremental_state)
        if input_buffer is not None:
            input_buffer = input_buffer.index_select(1, new_order)
            self._set_input_buffer(incremental_state, input_buffer)

    def _get_input_buffer(self, incremental_state):
        return utils.get_incremental_state(self, incremental_state, "input_buffer")

    def _set_input_buffer(self, incremental_state, new_buffer):
        return utils.set_incremental_state(
            self, incremental_state, "input_buffer", new_buffer
        )

    def extra_repr(self):
        s = "{}, kernel_size={}, padding_l={}, num_heads={}, weight_softmax={}, conv_bias={}, renorm_padding={}, in_proj={}".format(
            self.input_size,
            self.kernel_size,
            self.padding_l,
            self.num_heads,
            self.weight_softmax,
            self.conv_bias is not None,
            self.renorm_padding,
            self.in_proj,
        )
        if self.query_size != self.input_size:
            s += ", query_size={}".format(self.query_size)
        if self.weight_dropout_module.p > 0.0:
            s += ", weight_dropout={}".format(self.weight_dropout_module.p)
        return s
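

# A minimal usage sketch of DynamicConv1dTBC with illustrative sizes (this
# helper is not part of the original fairseq module). The unfolded and
# band-matrix ("expanded") code paths are expected to produce the same result
# up to floating-point tolerance, so the `unfold` flag mainly trades memory
# for speed.
def _example_dynamic_conv_tbc():
    conv = DynamicConv1dTBC(
        input_size=64,
        kernel_size=3,
        padding_l=2,  # kernel_size - 1, i.e. causal "same" padding
        num_heads=4,
        weight_softmax=True,
    )
    x = torch.randn(20, 2, 64)  # T x B x C
    y_unfold = conv(x, unfold=True)  # explicit window unfolding
    y_expand = conv(x, unfold=False)  # band-matrix multiplication
    assert y_unfold.shape == x.shape
    assert torch.allclose(y_unfold, y_expand, atol=1e-5)
    return y_unfold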


class DynamicConv_scripatable(nn.Module, FairseqIncrementalState):
    """Dynamic lightweight convolution taking T x B x C inputs
    Args:
        input_size: # of channels of the input
        kernel_size: size of the convolution kernel
        padding_l: padding to the left when using "same" padding
        num_heads: number of heads used. The weight is of shape (num_heads, 1, kernel_size)
        weight_dropout: the drop rate of the DropConnect to drop the weight
        weight_softmax: normalize the weight with softmax before the convolution
        renorm_padding: re-normalize the filters to ignore the padded part (only the non-padding parts sum up to 1)
        bias: use a bias in the linear layer that predicts the filters
        conv_bias: add a learnable bias to the convolution output
        query_size: specified when feeding a different input as the query
        in_proj: project the input and generate the filter together
    Shape:
        Input: TxBxC, i.e. (timesteps, batch_size, input_size)
        Output: TxBxC, i.e. (timesteps, batch_size, input_size)
    Attributes:
        weight_linear: the learnable linear layer that predicts the
            per-timestep convolution filters (num_heads * kernel_size values)
        conv_bias: the learnable bias of the module of shape `(1, 1, input_size)`,
            applied only when ``conv_bias=True``
    """

    def __init__(
        self,
        input_size,
        kernel_size=1,
        padding_l=None,
        num_heads=1,
        weight_dropout=0.0,
        weight_softmax=False,
        renorm_padding=False,
        bias=False,
        conv_bias=False,
        query_size=None,
        in_proj=False,
    ):
        super().__init__()
        self.input_size = input_size
        self.query_size = input_size if query_size is None else query_size
        self.kernel_size = kernel_size
        self.padding_l = padding_l
        self.num_heads = num_heads
        self.weight_dropout_module = FairseqDropout(
            weight_dropout, module_name=self.__class__.__name__
        )
        self.weight_softmax = weight_softmax
        self.renorm_padding = renorm_padding

        if in_proj:
            self.weight_linear = Linear(
                self.input_size, self.input_size + num_heads * kernel_size * 1
            )
        else:
            self.weight_linear = Linear(
                self.query_size, num_heads * kernel_size * 1, bias=bias
            )
        self.in_proj = (
            self.weight_linear.out_features
            == self.input_size + self.num_heads * self.kernel_size
        )
        self.has_conv_bias = conv_bias
        self.conv_bias = nn.Parameter(torch.Tensor(input_size).view(1, 1, -1))
        self.init_incremental_state()

        self.reset_parameters()

    def reset_parameters(self):
        self.weight_linear.reset_parameters()
        if self.has_conv_bias:
            nn.init.constant_(self.conv_bias, 0.0)

    def forward(
        self,
        x,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        query: Optional[Tensor] = None,
    ):
        """Assuming the input, x, of the shape T x B x C and producing an output in the shape T x B x C
        args:
            x: Input of shape T x B x C, i.e. (timesteps, batch_size, input_size)
            incremental_state: A dict to keep the state
            query: use the specified query to predict the conv filters
        """
        assert query is None or not self.in_proj

        if query is None:
            query = x

        output = self._forward_unfolded(x, incremental_state, query)

        if self.has_conv_bias:
            output = output + self.conv_bias
        return output

    def _forward_unfolded(
        self,
        x,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
        query,
    ):
        """The conventional implementation of convolutions.
        Unfolding the input by having a window shifting to the right."""
        T, B, C = x.size()
        K, H = self.kernel_size, self.num_heads
        R = C // H
        assert R * H == C == self.input_size
        TxBxH = T * B * H

        if self.in_proj:
            proj = self.weight_linear(x)
            x = proj.narrow(2, 0, self.input_size).contiguous()
            weight = proj.narrow(2, self.input_size, H * K).contiguous().view(TxBxH, -1)
        else:
            weight = self.weight_linear(query).view(TxBxH, -1)

        # renorm_padding is only implemented in _forward_expanded
        assert not self.renorm_padding or incremental_state is not None

        if incremental_state is not None:
            input_buffer = self._get_input_buffer(incremental_state)
            if input_buffer is not None:
                x_unfold = torch.cat([input_buffer, x.unsqueeze(3)], dim=3)
            else:
                x_unfold = x.unsqueeze(3).clone()
            if self.kernel_size > 1:
                self._set_input_buffer(
                    incremental_state, x_unfold[:, :, :, -self.kernel_size + 1 :]
                )
            x_unfold = x_unfold.view(TxBxH, R, -1)
        else:
            padding_l = self.padding_l
            if K > T and padding_l == K - 1:
                weight = weight.narrow(1, K - T, T)
                K, padding_l = T, T - 1
            # unfold the input: T x B x C --> T' x B x C x K
            x_unfold = unfold1d(x, K, padding_l, 0.0)
            x_unfold = x_unfold.view(TxBxH, R, K)

        if self.weight_softmax and not self.renorm_padding:
            weight = F.softmax(weight, dim=1)

        weight = weight.narrow(1, 0, K)

        if incremental_state is not None:
            weight = weight[:, -(x_unfold.size(2)) :]
            K = weight.size(1)

        if self.weight_softmax and self.renorm_padding:
            weight = F.softmax(weight, dim=1)

        weight = self.weight_dropout_module(weight, inplace=False)

        output = torch.bmm(x_unfold, weight.unsqueeze(2))  # TxBxH x R x 1
        output = output.view(T, B, C)
        return output

    def reorder_incremental_state(
        self,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
        new_order: Tensor,
    ):
        input_buffer = self._get_input_buffer(incremental_state)
        if input_buffer is not None:
            input_buffer = input_buffer.index_select(1, new_order)
            self._set_input_buffer(incremental_state, input_buffer)

    def _get_input_buffer(
        self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
    ):
        result = self.get_incremental_state(incremental_state, "input_buffer")
        if result is not None and "input_buffer" in result:
            return result["input_buffer"]
        else:
            return None

    def _set_input_buffer(
        self,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
        new_buffer: Optional[Tensor],
    ):
        result = self.set_incremental_state(
            incremental_state, "input_buffer", {"input_buffer": new_buffer}
        )
        if result is not None:
            incremental_state = result
        return incremental_state

    def extra_repr(self):
        s = "{}, kernel_size={}, padding_l={}, num_heads={}, weight_softmax={}, conv_bias={}, renorm_padding={}, in_proj={}".format(  # noqa
            self.input_size,
            self.kernel_size,
            self.padding_l,
            self.num_heads,
            self.weight_softmax,
            self.has_conv_bias,
            self.renorm_padding,
            self.in_proj,
        )
        if self.query_size != self.input_size:
            s += ", query_size={}".format(self.query_size)
        if self.weight_dropout_module.p > 0.0:
            s += ", weight_dropout={}".format(self.weight_dropout_module.p)
        return s
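

# A step-by-step decoding sketch for the scriptable variant (illustrative
# sizes; this helper is not part of the original fairseq module). Feeding one
# timestep at a time with a shared incremental_state dict makes the module
# cache the last kernel_size - 1 inputs between calls, mirroring how a decoder
# uses it at inference time.
def _example_incremental_decoding():
    conv = DynamicConv_scripatable(
        input_size=64,
        kernel_size=3,
        padding_l=2,
        num_heads=4,
        weight_softmax=True,
    )
    incremental_state: Dict[str, Dict[str, Optional[Tensor]]] = {}
    steps = []
    for _ in range(5):
        x_t = torch.randn(1, 2, 64)  # a single timestep: 1 x B x C
        steps.append(conv(x_t, incremental_state=incremental_state))
    return torch.cat(steps, dim=0)  # T x B x C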