# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class PreEmphasis(torch.nn.Module):
    """First-order pre-emphasis filter: ``y[t] = x[t] - coef * x[t - 1]``.

    Args:
        coef: Weight applied to the previous sample.
    """

    def __init__(self, coef: float = 0.97) -> None:
        super().__init__()
        self.coef = coef
        # make kernel
        # In pytorch, the convolution operation uses cross-correlation. So, filter is flipped.
        self.register_buffer(
            "flipped_filter",
            torch.FloatTensor([-self.coef, 1.0]).unsqueeze(0).unsqueeze(0),
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Filter a 2-D tensor along its last dimension.

        Args:
            input: 2-D tensor; the last dimension is the one being filtered.

        Returns:
            3-D tensor with a singleton dimension inserted at index 1 and the
            same last-dimension length as ``input``.

        Raises:
            AssertionError: If ``input`` is not 2-D.
        """
        # Kept as an assertion so callers that already catch AssertionError
        # are unaffected.
        assert (
            len(input.size()) == 2
        ), "The number of dimensions of input tensor must be 2!"
        # reflect padding to match lengths of in/out
        input = input.unsqueeze(1)
        input = F.pad(input, (1, 0), "reflect")
        return F.conv1d(input, self.flipped_filter)


class AFMS(nn.Module):
    """
    Alpha-Feature map scaling, added to the output of each residual block[1,2].

    Reference:
    [1] RawNet2 : https://www.isca-speech.org/archive/Interspeech_2020/pdfs/1011.pdf
    [2] AFMS : https://www.koreascience.or.kr/article/JAKO202029757857763.page
    """

    def __init__(self, nb_dim: int) -> None:
        super().__init__()
        # One learnable offset per channel; broadcasts over the last dimension.
        self.alpha = nn.Parameter(torch.ones((nb_dim, 1)))
        self.fc = nn.Linear(nb_dim, nb_dim)
        self.sig = nn.Sigmoid()

    def forward(self, x):
        """Return ``(x + alpha) * sigmoid(fc(mean of x over the last dim))``.

        ``x`` is indexed as (batch, channels, length) with ``channels`` equal
        to ``nb_dim``.
        """
        # Per-channel average over the last dimension -> (batch, channels).
        y = F.adaptive_avg_pool1d(x, 1).view(x.size(0), -1)
        # Per-channel gate in (0, 1), reshaped to broadcast against x.
        y = self.sig(self.fc(y)).view(x.size(0), x.size(1), -1)
        x = x + self.alpha
        x = x * y
        return x


class Bottle2neck(nn.Module):
    """Res2Net-style bottleneck block with AFMS applied to its output.

    The 1x1-expanded features are split into ``scale`` channel groups of
    ``width = planes // scale`` channels. The first ``scale - 1`` groups pass
    through dilated convolutions, each one receiving the previous group's
    output added to its own input; the last group is passed through untouched.

    Args:
        inplanes: Number of input channels.
        planes: Number of output channels.
        kernel_size: Kernel size of the per-group convolutions (required).
        dilation: Dilation of the per-group convolutions (required).
        scale: Number of channel groups.
        pool: Kernel size of the trailing ``nn.MaxPool1d``; a falsy value
            disables pooling.

    Raises:
        TypeError: If ``kernel_size`` or ``dilation`` is left as ``None``.
    """

    def __init__(
        self,
        inplanes,
        planes,
        kernel_size=None,
        dilation=None,
        scale=4,
        pool=False,
    ):
        super().__init__()
        # The None defaults are placeholders only; without this check the
        # failure surfaces as an opaque TypeError from the padding arithmetic.
        if kernel_size is None or dilation is None:
            raise TypeError(
                "Bottle2neck requires integer kernel_size and dilation, got "
                f"kernel_size={kernel_size!r}, dilation={dilation!r}"
            )

        width = int(math.floor(planes / scale))
        self.conv1 = nn.Conv1d(inplanes, width * scale, kernel_size=1)
        self.bn1 = nn.BatchNorm1d(width * scale)
        self.nums = scale - 1

        # Padding that keeps the length unchanged for odd kernel sizes.
        num_pad = math.floor(kernel_size / 2) * dilation
        self.convs = nn.ModuleList(
            nn.Conv1d(
                width,
                width,
                kernel_size=kernel_size,
                dilation=dilation,
                padding=num_pad,
            )
            for _ in range(self.nums)
        )
        self.bns = nn.ModuleList(nn.BatchNorm1d(width) for _ in range(self.nums))

        self.conv3 = nn.Conv1d(width * scale, planes, kernel_size=1)
        self.bn3 = nn.BatchNorm1d(planes)
        self.relu = nn.ReLU()
        self.width = width

        self.mp = nn.MaxPool1d(pool) if pool else False
        self.afms = AFMS(planes)

        if inplanes != planes:  # if change in number of filters
            self.residual = nn.Sequential(
                nn.Conv1d(inplanes, planes, kernel_size=1, stride=1, bias=False)
            )
        else:
            self.residual = nn.Identity()

    def forward(self, x):
        """Apply the block to ``x``, indexed as (batch, inplanes, length).

        Returns:
            Tensor with ``planes`` channels; the last dimension is reduced by
            the max-pool when ``pool`` was set.
        """
        residual = self.residual(x)

        out = self.conv1(x)
        out = self.relu(out)
        out = self.bn1(out)

        spx = torch.split(out, self.width, 1)

        # Collect the group outputs and concatenate once. This avoids
        # re-copying the accumulated tensor on every iteration and keeps the
        # channel count right when there are no convolved groups (scale == 1).
        branches = []
        sp = None
        for i, (conv, bn) in enumerate(zip(self.convs, self.bns)):
            sp = spx[i] if sp is None else sp + spx[i]
            sp = conv(sp)
            sp = self.relu(sp)
            sp = bn(sp)
            branches.append(sp)
        # The last group bypasses the convolutions.
        branches.append(spx[self.nums])
        out = torch.cat(branches, 1)

        out = self.conv3(out)
        out = self.relu(out)
        out = self.bn3(out)

        out += residual
        if self.mp:
            out = self.mp(out)
        out = self.afms(out)

        return out