# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from torch import nn
from torch.nn import functional as F


class Conv1d(nn.Conv1d):
    """Extended nn.Conv1d for incremental dilated convolutions"""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.clear_buffer()
        self._linearized_weight = None
        self.register_backward_hook(self._clear_linearized_weight)

    def incremental_forward(self, input):
        # input: (B, T, C); only the last timestep input[:, -1, :] is consumed,
        # and a single output frame of shape (B, 1, C_out) is returned.

        # run forward pre hooks (e.g., weight norm)
        for hook in self._forward_pre_hooks.values():
            hook(self, input)

        # reshape weight
        weight = self._get_linearized_weight()
        kw = self.kernel_size[0]
        dilation = self.dilation[0]

        bsz = input.size(0)
        if kw > 1:
            input = input.data
            if self.input_buffer is None:
                # buffer spans the full dilated receptive field of the kernel
                self.input_buffer = input.new(
                    bsz, kw + (kw - 1) * (dilation - 1), input.size(2)
                )
                self.input_buffer.zero_()
            else:
                # shift buffer
                self.input_buffer[:, :-1, :] = self.input_buffer[:, 1:, :].clone()
            # append next input
            self.input_buffer[:, -1, :] = input[:, -1, :]
            input = self.input_buffer
            if dilation > 1:
                # keep only the kw frames the dilated kernel actually touches
                input = input[:, 0::dilation, :].contiguous()
        output = F.linear(input.view(bsz, -1), weight, self.bias)
        return output.view(bsz, 1, -1)

    def clear_buffer(self):
        self.input_buffer = None

    def _get_linearized_weight(self):
        # flatten the convolution weight to (out_channels, kw * in_channels)
        # so one incremental step reduces to a single matrix multiply
        if self._linearized_weight is None:
            kw = self.kernel_size[0]
            # nn.Conv1d
            if self.weight.size() == (self.out_channels, self.in_channels, kw):
                weight = self.weight.transpose(1, 2).contiguous()
            else:
                # fairseq.modules.conv_tbc.ConvTBC
                weight = self.weight.transpose(2, 1).transpose(1, 0).contiguous()
            assert weight.size() == (self.out_channels, kw, self.in_channels)
            self._linearized_weight = weight.view(self.out_channels, -1)
        return self._linearized_weight

    def _clear_linearized_weight(self, *args):
        self._linearized_weight = None
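

if __name__ == "__main__":
    # Minimal usage sketch added for illustration (not part of the original
    # module): it checks, step by step, that incremental_forward matches the
    # regular causal Conv1d forward when the left context is zero-padded.
    # The sizes (B, T, C_in, C_out, kw, d) below are arbitrary assumptions.
    import torch

    torch.manual_seed(0)
    B, T, C_in, C_out, kw, d = 2, 16, 4, 6, 3, 2

    conv = Conv1d(C_in, C_out, kernel_size=kw, dilation=d, padding=(kw - 1) * d)
    conv.eval()

    x = torch.randn(B, C_in, T)  # regular forward expects (B, C, T)
    with torch.no_grad():
        # nn.Conv1d pads both sides; keeping the first T frames yields the
        # causal (left-padded) output
        y_ref = conv(x)[:, :, :T].transpose(1, 2)  # (B, T, C_out)

        conv.clear_buffer()
        steps = []
        for t in range(T):
            frame = x[:, :, t].unsqueeze(1)  # one timestep, shape (B, 1, C_in)
            steps.append(conv.incremental_forward(frame))
        y_inc = torch.cat(steps, dim=1)  # (B, T, C_out)

    print("max abs difference:", (y_ref - y_inc).abs().max().item())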