# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn.functional as F
from fairseq import utils
from fairseq.incremental_decoding_utils import with_incremental_state

from .conv_tbc import ConvTBC

from typing import Dict, Optional
from torch import Tensor


@with_incremental_state
class LinearizedConvolution(ConvTBC):
    """An optimized version of nn.Conv1d.

    At training time, this module uses ConvTBC, which is an optimized version
    of Conv1d. At inference time, it optimizes incremental generation (i.e.,
    one time step at a time) by replacing the convolutions with linear layers.
    Note that the input order changes from training to inference.
    """

    def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
        super().__init__(in_channels, out_channels, kernel_size, **kwargs)
        self._linearized_weight = None
        # any backward pass means the weights may change, so drop the
        # cached linearized weight (see _clear_linearized_weight below)
        self.register_backward_hook(self._clear_linearized_weight)

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        state = ConvTBC.state_dict(self, destination, prefix, keep_vars=keep_vars)
        # don't store redundant _linearized_weight in checkpoints
        if prefix + "_linearized_weight" in state:
            del state[prefix + "_linearized_weight"]
        return state

    def upgrade_state_dict_named(self, state_dict, name):
        prefix = name + "." if name != "" else ""
        if prefix + "_linearized_weight" in state_dict:
            del state_dict[prefix + "_linearized_weight"]

    def forward(
        self,
        input,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
    ):
        """
        Args:
            incremental_state: Used to buffer signal; if not None, then input is
                expected to contain a single frame. If the input order changes
                between time steps, call reorder_incremental_state.
        Input:
            Time x Batch x Channel during training
            Batch x Time x Channel during inference
        """
        if incremental_state is None:
            output = self.conv_tbc(input)
            if self.kernel_size[0] > 1 and self.padding[0] > 0:
                # remove future timesteps added by padding
                output = output[: -self.padding[0], :, :]
            return output

        # reshape weight
        weight = self._get_linearized_weight()
        kw = self.kernel_size[0]

        bsz = input.size(0)  # input: bsz x len x dim
        if kw > 1:
            input = input.data
            input_buffer = self._get_input_buffer(incremental_state)
            if input_buffer is None:
                input_buffer = input.new(bsz, kw, input.size(2)).zero_()
                self._set_input_buffer(incremental_state, input_buffer)
            else:
                # shift buffer
                input_buffer[:, :-1, :] = input_buffer[:, 1:, :].clone()
            # append next input
            input_buffer[:, -1, :] = input[:, -1, :]
            input = input_buffer
        with torch.no_grad():
            output = F.linear(input.view(bsz, -1), weight, self.bias)
        return output.view(bsz, 1, -1)
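
    # The incremental path above keeps a rolling window of the most recent
    # kernel_size input frames, shaped (bsz, kernel_size, dim): each call
    # shifts the window left by one frame, writes the newest frame into the
    # last slot, and computes the convolution output for that single timestep
    # as one F.linear over the flattened window.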

    def reorder_incremental_state(
        self,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
        new_order,
    ):
        input_buffer = self._get_input_buffer(incremental_state)
        if input_buffer is not None:
            input_buffer = input_buffer.index_select(0, new_order)
            self._set_input_buffer(incremental_state, input_buffer)

    def _get_input_buffer(
        self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
    ):
        return utils.get_incremental_state(self, incremental_state, "input_buffer")

    def _set_input_buffer(
        self,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
        new_buffer,
    ):
        return utils.set_incremental_state(
            self, incremental_state, "input_buffer", new_buffer
        )
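
    # Weight linearization: the ConvTBC weight has shape
    # (kernel_size, in_channels, out_channels). Flattening it into an
    # (out_channels, kernel_size * in_channels) matrix lets a single F.linear
    # over the flattened input window reproduce the convolution output for
    # the newest timestep.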

    def _get_linearized_weight(self):
        if self._linearized_weight is None:
            kw = self.kernel_size[0]
            weight = self.weight.transpose(2, 1).transpose(1, 0).contiguous()
            assert weight.size() == (self.out_channels, kw, self.in_channels)
            # cache the flattened weight so repeated decoding steps reuse it;
            # the backward hook registered in __init__ invalidates this cache
            self._linearized_weight = torch.nn.Parameter(
                weight.view(self.out_channels, -1)
            )
        return self._linearized_weight

    def _clear_linearized_weight(self, *args):
        self._linearized_weight = None
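

# Minimal smoke-test sketch (illustrative, not part of the upstream module):
# it checks that step-by-step incremental decoding reproduces the
# training-time ConvTBC forward pass. Assumes fairseq's ConvTBC and utils are
# importable; parameters are re-initialized explicitly here in case ConvTBC
# leaves them uninitialized.
if __name__ == "__main__":
    torch.manual_seed(0)
    tgt_len, bsz, dim, kw = 5, 2, 8, 3
    conv = LinearizedConvolution(dim, dim, kw, padding=kw - 1)
    torch.nn.init.normal_(conv.weight, std=0.1)
    torch.nn.init.zeros_(conv.bias)

    x = torch.randn(tgt_len, bsz, dim)  # Time x Batch x Channel (training layout)
    with torch.no_grad():
        full = conv(x)  # causal output; future padding stripped in forward()

    # Feed the same sequence one frame at a time in the inference layout.
    incremental_state: Dict[str, Dict[str, Optional[Tensor]]] = {}
    steps = []
    with torch.no_grad():
        for t in range(tgt_len):
            frame = x[t : t + 1].transpose(0, 1)  # Batch x 1 x Channel
            steps.append(conv(frame, incremental_state))
    stepped = torch.cat(steps, dim=1).transpose(0, 1)  # back to T x B x C

    assert torch.allclose(full, stepped, atol=1e-4)
    print("incremental decoding matches full convolution")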