# Page metadata captured with this file (not part of the Python source):
# uploaded by yuancwang, commit "init" (b725c5a), file size 4.1 kB.
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# -*- encoding: utf-8 -*-
import torch
import torch.nn as nn
from asteroid_filterbanks import Encoder, ParamSincFB
from .RawNetBasicBlock import Bottle2neck, PreEmphasis
class RawNet3(nn.Module):
    """RawNet3 embedding network operating directly on waveforms.

    Pipeline: pre-emphasis + instance norm -> sinc filterbank (magnitude,
    optionally log-compressed and normalised along time) -> three ``block``
    stages -> 1x1 aggregation conv (``layer4``, 1536 channels) -> attentive
    statistics pooling (weighted mean and std) -> BatchNorm -> Linear to
    ``nOut`` -> optional BatchNorm.

    Args:
        block: Block class, constructed as
            ``block(in_ch, out_ch, kernel_size=..., dilation=..., scale=..., pool=...)``.
        model_scale: Forwarded to every ``block`` as ``scale``.
        context: If True, the attention network additionally sees the
            per-channel mean and standard deviation over time (global context).
        summed: If True, ``layer3`` receives ``mp3(layer1_out) + layer2_out``;
            otherwise it receives ``layer2_out`` only.
        C: Channel width of the block stages; the sinc front end has
            ``C // 4`` output channels.
        **kwargs: Required keys: ``nOut`` (output size), ``encoder_type``
            (``"ECA"`` or ``"ASP"``), ``log_sinc`` (bool), ``norm_sinc``
            (``"mean"``, ``"mean_std"``; any other value disables the
            normalisation), ``out_bn`` (bool) and ``sinc_stride``.

    Raises:
        KeyError: If a required keyword argument is missing.
        ValueError: If ``encoder_type`` is neither ``"ECA"`` nor ``"ASP"``.
    """

    def __init__(self, block, model_scale, context, summed, C=1024, **kwargs):
        super().__init__()

        n_out = kwargs["nOut"]

        self.context = context
        self.encoder_type = kwargs["encoder_type"]
        self.log_sinc = kwargs["log_sinc"]
        self.norm_sinc = kwargs["norm_sinc"]
        self.out_bn = kwargs["out_bn"]
        self.summed = summed

        feat_dim = 1536  # channels produced by layer4

        # Resolve the pooling configuration before building any sub-module so
        # an invalid config fails fast without allocating the network.
        attn_input = feat_dim * 3 if self.context else feat_dim
        if self.encoder_type == "ECA":
            # One attention weight per (channel, frame).
            attn_output = feat_dim
        elif self.encoder_type == "ASP":
            # One weight per frame, broadcast across all channels.
            attn_output = 1
        else:
            raise ValueError(
                f"Undefined encoder: {self.encoder_type!r} (expected 'ECA' or 'ASP')"
            )

        # NOTE: sub-modules are created in the original order so parameter
        # initialisation (RNG consumption) and state_dict keys are unchanged.
        self.preprocess = nn.Sequential(
            PreEmphasis(), nn.InstanceNorm1d(1, eps=1e-4, affine=True)
        )
        self.conv1 = Encoder(
            ParamSincFB(
                C // 4,
                251,
                stride=kwargs["sinc_stride"],
            )
        )
        self.relu = nn.ReLU()
        # bn1 is not used in forward(); presumably kept so pretrained
        # checkpoints still load with matching keys -- verify before removing.
        self.bn1 = nn.BatchNorm1d(C // 4)

        self.layer1 = block(
            C // 4, C, kernel_size=3, dilation=2, scale=model_scale, pool=5
        )
        self.layer2 = block(C, C, kernel_size=3, dilation=3, scale=model_scale, pool=3)
        self.layer3 = block(C, C, kernel_size=3, dilation=4, scale=model_scale)
        self.layer4 = nn.Conv1d(3 * C, feat_dim, kernel_size=1)

        # Softmax over dim=2 (time) turns the scores into pooling weights.
        self.attention = nn.Sequential(
            nn.Conv1d(attn_input, 128, kernel_size=1),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Conv1d(128, attn_output, kernel_size=1),
            nn.Softmax(dim=2),
        )

        # Pooled vector = concat(weighted mean, weighted std) -> 2 * feat_dim.
        self.bn5 = nn.BatchNorm1d(2 * feat_dim)

        self.fc6 = nn.Linear(2 * feat_dim, n_out)
        self.bn6 = nn.BatchNorm1d(n_out)

        self.mp3 = nn.MaxPool1d(3)

    def forward(self, x):
        """Embed a mini-batch of waveforms.

        :param x: input mini-batch (bs, samp)
        :return: tensor of shape (bs, nOut)
        """
        # The front end runs with CUDA autocast disabled (presumably so the
        # log / std normalisation are not computed in reduced precision).
        with torch.autocast(device_type="cuda", enabled=False):
            x = self.preprocess(x)
            x = torch.abs(self.conv1(x))
            if self.log_sinc:
                x = torch.log(x + 1e-6)
            if self.norm_sinc == "mean":
                x = x - torch.mean(x, dim=-1, keepdim=True)
            elif self.norm_sinc == "mean_std":
                m = torch.mean(x, dim=-1, keepdim=True)
                # Floor the std out-of-place: torch.std saves its output for
                # backward, so an in-place masked write breaks autograd.
                s = torch.std(x, dim=-1, keepdim=True).clamp(min=0.001)
                x = (x - m) / s

        x1 = self.layer1(x)
        x2 = self.layer2(x1)
        # Pool layer1's output so its time axis matches x2/x3 in the
        # sum/concatenation below.
        x1_pooled = self.mp3(x1)
        x3 = self.layer3((x1_pooled + x2) if self.summed else x2)

        x = self.layer4(torch.cat((x1_pooled, x2, x3), dim=1))
        x = self.relu(x)

        t = x.size(-1)

        if self.context:
            # Append utterance-level mean/std (tiled over time) as context.
            mean = torch.mean(x, dim=2, keepdim=True)
            std = torch.sqrt(
                torch.var(x, dim=2, keepdim=True).clamp(min=1e-4, max=1e4)
            )
            global_x = torch.cat(
                (x, mean.repeat(1, 1, t), std.repeat(1, 1, t)), dim=1
            )
        else:
            global_x = x

        w = self.attention(global_x)

        # Attention-weighted mean and std over time; the variance is clamped
        # to keep sqrt (and its gradient) finite.
        mu = torch.sum(x * w, dim=2)
        sg = torch.sqrt(
            (torch.sum((x**2) * w, dim=2) - mu**2).clamp(min=1e-4, max=1e4)
        )

        x = torch.cat((mu, sg), 1)
        x = self.bn5(x)
        x = self.fc6(x)
        if self.out_bn:
            x = self.bn6(x)
        return x
def MainModel(**kwargs):
    """Build the default RawNet3 configuration.

    Uses ``Bottle2neck`` blocks with ``model_scale=8``, global context in the
    attention pooling (``context=True``) and the summed skip into ``layer3``
    (``summed=True``). Every other keyword argument (``nOut``,
    ``encoder_type``, ...) is forwarded unchanged to ``RawNet3``.
    """
    return RawNet3(
        Bottle2neck,
        model_scale=8,
        context=True,
        summed=True,
        **kwargs,
    )