from models.mossformer2_se.mossformer2 import MossFormer_MaskNet
import torch.nn as nn


class MossFormer2_SE_48K(nn.Module):
    """
    The MossFormer2_SE_48K model for speech enhancement.

    This class is a thin wrapper around ``TestNet``, which holds the
    MossFormer MaskNet. Its forward pass returns exactly what ``TestNet``
    returns: a single-element list containing the predicted mask.

    Arguments
    ---------
    args : Namespace
        Configuration arguments. Currently unused by this class; accepted
        so the constructor can be extended with hyperparameters later.

    Example
    ---------
    >>> model = MossFormer2_SE_48K(args).model   # the inner TestNet
    >>> x = torch.randn(10, 180, 2000)           # [B, N, S] with N = 180
    >>> out_list = model(x)                      # forward pass
    >>> mask = out_list[-1]                      # the predicted mask
    """

    def __init__(self, args):
        super().__init__()
        # All learnable parameters live in TestNet; `args` is not consumed here.
        self.model = TestNet()

    def forward(self, x):
        """
        Forward pass through the wrapped TestNet.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor of dimension [B, N, S], where B is the batch size,
            N is the number of channels (180, matching the MaskNet's
            ``in_channels``), and S is the sequence length (e.g., time frames).

        Returns
        -------
        out_list : list of torch.Tensor
            The list returned by ``TestNet.forward``: a single element holding
            the mask predicted by MossFormer_MaskNet.
        """
        # TestNet returns a one-element list ([mask]); unpacking it into two
        # names (outputs, mask) raises ValueError, so pass it through as-is.
        return self.model(x)


class TestNet(nn.Module):
    """
    The TestNet class for testing the MossFormer MaskNet implementation.

    Builds a MossFormer_MaskNet (in_channels=180, out_channels=512,
    out_channels_final=961) and returns its mask prediction wrapped in a list.

    Arguments
    ---------
    n_layers : int
        Intended depth of the model. Stored as ``self.n_layers`` but currently
        not used when building the network.
    """

    def __init__(self, n_layers=18):
        super().__init__()
        # Stored for reference only; not passed to MossFormer_MaskNet.
        self.n_layers = n_layers
        self.mossformer = MossFormer_MaskNet(
            in_channels=180, out_channels=512, out_channels_final=961
        )

    def forward(self, input):  # `input` shadows the builtin; kept for keyword callers
        """
        Forward pass through the TestNet model.

        Arguments
        ---------
        input : torch.Tensor
            Input tensor of dimension [B, N, S], where B is the batch size,
            N is the number of input channels (180), and S is the sequence
            length.

        Returns
        -------
        out_list : list of torch.Tensor
            A one-element list ``[mask]`` holding the MossFormer_MaskNet
            output. NOTE(review): the mask's shape is defined by
            MossFormer_MaskNet (not visible here) — confirm before relying on it.
        """
        # Swap the last two axes: [B, N, S] -> [B, S, N], so the MaskNet
        # receives the channel axis last.
        x = input.transpose(1, 2)
        mask = self.mossformer(x)
        return [mask]