# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

from collections import OrderedDict

import torch
from torch import nn
import torch.nn.functional as F

from modules.campplus.layers import (DenseLayer, StatsPool, TDNNLayer,
                                     CAMDenseTDNNBlock, TransitLayer,
                                     BasicResBlock, get_nonlinear)


class FCM(nn.Module):
    """2D convolutional front-end that processes the input feature map
    before the TDNN backbone, reducing the frequency axis by a factor of 8."""

    def __init__(self,
                 block=BasicResBlock,
                 num_blocks=[2, 2],
                 m_channels=32,
                 feat_dim=80):
        super(FCM, self).__init__()
        self.in_planes = m_channels
        self.conv1 = nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(m_channels)

        self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=2)
        self.layer2 = self._make_layer(block, m_channels, num_blocks[1], stride=2)

        self.conv2 = nn.Conv2d(m_channels, m_channels, kernel_size=3, stride=(2, 1), padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(m_channels)
        # Three stride-2 stages reduce the frequency axis by 8 overall, so the
        # flattened output carries m_channels * (feat_dim // 8) channels.
        self.out_channels = m_channels * (feat_dim // 8)

    def _make_layer(self, block, planes, num_blocks, stride):
        # Only the first block of each stage downsamples; the rest use stride 1.
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        x = x.unsqueeze(1)  # (B, F, T) -> (B, 1, F, T)
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = F.relu(self.bn2(self.conv2(out)))

        # Flatten channels and the reduced frequency axis into one feature axis.
        shape = out.shape
        out = out.reshape(shape[0], shape[1] * shape[2], shape[3])
        return out
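

# A quick shape check (a sketch, not part of the upstream file): with the
# defaults feat_dim=80 and m_channels=32, FCM flattens to 32 * (80 // 8) = 320
# output channels, so a (batch, mel, frames) input of shape (4, 80, 200) should
# yield a (4, 320, T') tensor; T' depends on how BasicResBlock in
# modules.campplus.layers applies its stride along the time axis.
#
#   fcm = FCM(feat_dim=80)
#   print(fcm(torch.randn(4, 80, 200)).shape)  # torch.Size([4, 320, T'])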


class CAMPPlus(nn.Module):
    """CAM++ speaker-embedding backbone: an FCM front-end followed by a
    densely connected TDNN stack with statistics pooling."""

    def __init__(self,
                 feat_dim=80,
                 embedding_size=512,
                 growth_rate=32,
                 bn_size=4,
                 init_channels=128,
                 config_str='batchnorm-relu',
                 memory_efficient=True):
        super(CAMPPlus, self).__init__()

        self.head = FCM(feat_dim=feat_dim)
        channels = self.head.out_channels

        self.xvector = nn.Sequential(
            OrderedDict([
                ('tdnn',
                 TDNNLayer(channels,
                           init_channels,
                           5,
                           stride=2,
                           dilation=1,
                           padding=-1,
                           config_str=config_str)),
            ]))
        channels = init_channels
        # Three dense TDNN stages: each block grows the channel count by
        # num_layers * growth_rate, then a transit layer halves it.
        for i, (num_layers, kernel_size,
                dilation) in enumerate(zip((12, 24, 16), (3, 3, 3), (1, 2, 2))):
            block = CAMDenseTDNNBlock(num_layers=num_layers,
                                      in_channels=channels,
                                      out_channels=growth_rate,
                                      bn_channels=bn_size * growth_rate,
                                      kernel_size=kernel_size,
                                      dilation=dilation,
                                      config_str=config_str,
                                      memory_efficient=memory_efficient)
            self.xvector.add_module('block%d' % (i + 1), block)
            channels = channels + num_layers * growth_rate
            self.xvector.add_module(
                'transit%d' % (i + 1),
                TransitLayer(channels,
                             channels // 2,
                             bias=False,
                             config_str=config_str))
            channels //= 2

        self.xvector.add_module('out_nonlinear',
                                get_nonlinear(config_str, channels))

        # StatsPool concatenates per-channel statistics over time, doubling
        # the channel count before the final embedding layer.
        self.xvector.add_module('stats', StatsPool())
        self.xvector.add_module(
            'dense',
            DenseLayer(channels * 2, embedding_size, config_str='batchnorm_'))

        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight.data)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x):
        x = x.permute(0, 2, 1)  # (B, T, F) -> (B, F, T)
        x = self.head(x)
        x = self.xvector(x)
        return x
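

# A minimal usage sketch (an illustrative addition, not part of the upstream
# file), assuming inputs are batches of acoustic features shaped
# (batch, frames, feat_dim), matching the permute in CAMPPlus.forward.
if __name__ == '__main__':
    # memory_efficient typically enables gradient checkpointing in the dense
    # blocks, which is unnecessary for inference, so it is disabled here.
    model = CAMPPlus(feat_dim=80, embedding_size=512, memory_efficient=False)
    model.eval()
    feats = torch.randn(2, 300, 80)  # 2 utterances, 300 frames, 80-dim features
    with torch.no_grad():
        emb = model(feats)
    print(emb.shape)  # expected: torch.Size([2, 512])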