import torch
import torch.nn.functional as F


class PreEmphasis(torch.nn.Module):
    """Apply a first-order pre-emphasis filter along the last axis.

    For each row ``x`` of a 2-D input this computes
    ``y[t] = x[t] - coef * x[t - 1]`` for ``t >= 1``.

    The first sample has no predecessor, so the input is reflect-padded by
    one sample on the left. That padded sample equals ``x[1]``, which gives
    ``y[0] = x[0] - coef * x[1]``. Output length equals input length.
    """

    def __init__(self, coef: float = 0.97):
        """Build the filter.

        Args:
            coef: Pre-emphasis coefficient that multiplies the previous sample.
        """
        super().__init__()
        self.coef = coef
        # PyTorch convolutions compute cross-correlation, not true
        # convolution. The kernel is therefore stored already flipped:
        # [-coef, 1] instead of [1, -coef].
        # Shape (out_channels=1, in_channels=1, kernel_size=2) for conv1d.
        # Registered as a buffer so it follows .to()/.cuda() and is saved in
        # state_dict under the same key as before.
        self.register_buffer(
            'flipped_filter',
            torch.tensor([-self.coef, 1.0], dtype=torch.float32).view(1, 1, 2),
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Pre-emphasize each row of ``input``.

        Args:
            input: 2-D tensor. conv1d treats dim 0 as the batch and filters
                along the last dim. Its dtype/device must be compatible with
                the ``flipped_filter`` buffer (float32 unless the module was
                converted).

        Returns:
            Tensor with the same shape as ``input``.

        Raises:
            ValueError: if ``input`` is not 2-D. Note that reflect padding also
                requires the last dim to have at least 2 samples; PyTorch
                rejects shorter inputs.
        """
        # Explicit check rather than `assert`, so it still runs under `python -O`.
        if input.dim() != 2:
            raise ValueError('The number of dimensions of input tensor must be 2!')
        # (B, T) -> (B, 1, T): conv1d expects a channel axis.
        signal = input.unsqueeze(1)
        # Reflect-pad one sample on the left so output length == input length.
        signal = F.pad(signal, (1, 0), 'reflect')
        return F.conv1d(signal, self.flipped_filter).squeeze(1)