Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	File size: 11,354 Bytes
			
			8e8cd3e  | 
								1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177  | 
								import argparse
import yamlargparse
import torch.nn as nn
class network_wrapper(nn.Module):
    """
    Load the inference configuration for a model and instantiate its network.

    Supported tasks are speech enhancement (SE), speech separation (SS) and
    target speaker extraction (TSE). For each task the wrapper builds a
    ``yamlargparse`` parser, reads ``config/inference/<model_name>.yaml`` and
    then constructs the network class matching the model name.

    NOTE: this class defines its own ``__call__`` and therefore replaces
    ``nn.Module.__call__``; calling an instance runs the loading logic below
    rather than a ``forward`` method.
    """

    def __init__(self):
        """
        Initialize the wrapper with no arguments, config path, model name or
        network selected yet.
        """
        super(network_wrapper, self).__init__()
        self.args = None  # Parsed arguments (set by the load_args_* methods)
        self.config_path = None  # Path to the YAML configuration file
        self.model_name = None  # Model name to be loaded based on the task
        # Defined up front so that reading `network` before a successful call
        # (or after an unsupported task/model) yields None instead of raising
        # AttributeError.
        self.network = None

    @staticmethod
    def _add_general_args(parser, checkpoint_dir, input_help, output_help):
        """
        Register the general inference options shared by every task.

        Args:
        - parser: The argument parser to extend.
        - checkpoint_dir (str): Default value for ``--checkpoint-dir``.
        - input_help (str): Help text for ``--input-path``.
        - output_help (str): Help text for ``--output-dir``.
        """
        parser.add_argument('--mode', type=str, default='inference', help='Modes: train or inference')
        parser.add_argument('--checkpoint-dir', dest='checkpoint_dir', type=str, default=checkpoint_dir, help='Checkpoint directory')
        parser.add_argument('--input-path', dest='input_path', type=str, help=input_help)
        parser.add_argument('--output-dir', dest='output_dir', type=str, help=output_help)
        parser.add_argument('--use-cuda', dest='use_cuda', default=1, type=int, help='Enable CUDA (1=True, 0=False)')
        parser.add_argument('--num-gpu', dest='num_gpu', type=int, default=1, help='Number of GPUs to use')

    @staticmethod
    def _add_decode_args(parser, window_help):
        """
        Register the segment/chunk decoding options shared by every task.

        Args:
        - parser: The argument parser to extend.
        - window_help (str): Help text for ``--decode-window``.
        """
        parser.add_argument('--one-time-decode-length', dest='one_time_decode_length', type=int, default=60, help='Max segment length for one-pass decoding')
        parser.add_argument('--decode-window', dest='decode_window', type=int, default=1, help=window_help)

    def load_args_se(self):
        """
        Parse the speech enhancement arguments from the model's YAML config.

        Sets ``self.config_path`` from ``self.model_name`` and stores the
        parsed namespace (paths, model settings, FFT parameters) in
        ``self.args``.
        """
        self.config_path = 'config/inference/' + self.model_name + '.yaml'
        parser = yamlargparse.ArgumentParser("Settings")

        # General model and inference settings
        # NOTE(review): unlike the SS/TSE loaders, no default is given for
        # --config here; kept as-is since the path is passed explicitly below.
        parser.add_argument('--config', help='Config file path', action=yamlargparse.ActionConfigFile)
        self._add_general_args(
            parser,
            'checkpoints/FRCRN_SE_16K',
            'Path for noisy audio input',
            'Directory for enhanced audio output',
        )

        # Model-specific settings
        parser.add_argument('--network', type=str, help='Select SE models: FRCRN_SE_16K, MossFormer2_SE_48K')
        parser.add_argument('--sampling-rate', dest='sampling_rate', type=int, default=16000, help='Sampling rate')
        self._add_decode_args(parser, 'Decoding chunk size')

        # FFT parameters for feature extraction
        parser.add_argument('--window-len', dest='win_len', type=int, default=400, help='Window length for framing')
        parser.add_argument('--window-inc', dest='win_inc', type=int, default=100, help='Window shift for framing')
        parser.add_argument('--fft-len', dest='fft_len', type=int, default=512, help='FFT length for feature extraction')
        parser.add_argument('--num-mels', dest='num_mels', type=int, default=60, help='Number of mel-spectrogram bins')
        parser.add_argument('--window-type', dest='win_type', type=str, default='hamming', help='Window type: hamming or hanning')

        # Parse arguments from the config file
        self.args = parser.parse_args(['--config', self.config_path])

    def load_args_ss(self):
        """
        Parse the speech separation arguments from the model's YAML config.

        Sets ``self.config_path`` from ``self.model_name`` and stores the
        parsed namespace (paths, speaker count, encoder and MossFormer
        settings) in ``self.args``.
        """
        self.config_path = 'config/inference/' + self.model_name + '.yaml'
        parser = yamlargparse.ArgumentParser("Settings")

        # General model and inference settings
        parser.add_argument('--config', default=self.config_path, help='Config file path', action=yamlargparse.ActionConfigFile)
        # NOTE(review): the checkpoint default names the FRCRN SE directory,
        # which looks like a copy-paste; presumably the YAML config overrides
        # it -- confirm before changing.
        self._add_general_args(
            parser,
            'checkpoints/FRCRN_SE_16K',
            'Path for mixed audio input',
            'Directory for separated audio output',
        )

        # Model-specific settings for speech separation
        parser.add_argument('--network', type=str, help='Select SS models: MossFormer2_SS_16K')
        parser.add_argument('--sampling-rate', dest='sampling_rate', type=int, default=16000, help='Sampling rate')
        parser.add_argument('--num-spks', dest='num_spks', type=int, default=2, help='Number of speakers to separate')
        self._add_decode_args(parser, 'Decoding chunk size')

        # Encoder settings
        parser.add_argument('--encoder_kernel-size', dest='encoder_kernel_size', type=int, default=16, help='Kernel size for Conv1D encoder')
        parser.add_argument('--encoder-embedding-dim', dest='encoder_embedding_dim', type=int, default=512, help='Embedding dimension from encoder')

        # MossFormer model parameters
        # NOTE(review): the option string is spelled "squence"; it is kept
        # unchanged for compatibility, while the dest uses the correct name.
        parser.add_argument('--mossformer-squence-dim', dest='mossformer_sequence_dim', type=int, default=512, help='Sequence dimension for MossFormer')
        parser.add_argument('--num-mossformer_layer', dest='num_mossformer_layer', type=int, default=24, help='Number of MossFormer layers')

        # Parse arguments from the config file
        self.args = parser.parse_args(['--config', self.config_path])

    def load_args_tse(self):
        """
        Parse the target speaker extraction arguments from the model's YAML config.

        Sets ``self.config_path`` from ``self.model_name`` and stores the
        parsed namespace (paths, CUDA settings, reference/audio network
        dictionaries and decoding parameters) in ``self.args``.
        """
        self.config_path = 'config/inference/' + self.model_name + '.yaml'
        parser = yamlargparse.ArgumentParser("Settings")

        # General model and inference settings
        parser.add_argument('--config', default=self.config_path, help='Config file path', action=yamlargparse.ActionConfigFile)
        self._add_general_args(
            parser,
            'checkpoint_dir/AV_MossFormer2_TSE_16K',
            'Path for mixed audio input',
            'Directory for separated audio output',
        )

        # Model-specific settings for target speaker extraction
        parser.add_argument('--network', type=str, help='Select TSE models(currently supports AV_MossFormer2_TSE_16K)')
        parser.add_argument('--sampling-rate', dest='sampling_rate', type=int, default=16000, help='Sampling rate (currently supports 16 kHz)')
        parser.add_argument('--network_reference', type=dict, help='a dictionary that contains the parameters of auxilary reference signal')
        parser.add_argument('--network_audio', type=dict, help='a dictionary that contains the network parameters')

        # Decode parameters for streaming or chunk-based decoding
        self._add_decode_args(parser, 'Chunk length for streaming')

        # Parse arguments from the config file
        self.args = parser.parse_args(['--config', self.config_path])

    def __call__(self, task, model_name):
        """
        Load the task-specific arguments and instantiate the requested model.

        Args:
        - task (str): The task type ('speech_enhancement', 'speech_separation', 'target_speaker_extraction').
        - model_name (str): The name of the model to load (e.g., 'FRCRN_SE_16K').

        Returns:
        - The instantiated neural network model, or None (after printing a
          message) when the task or the model name is not supported.
        """
        self.model_name = model_name  # Used by the loaders to locate the YAML config

        # Load arguments specific to the task
        if task == 'speech_enhancement':
            self.load_args_se()
        elif task == 'speech_separation':
            self.load_args_ss()
        elif task == 'target_speaker_extraction':
            self.load_args_tse()
        else:
            print(f'{task} is not supported, please select from: '
                  'speech_enhancement, speech_separation, or target_speaker_extraction')
            return None

        print(self.args)  # Display the parsed arguments
        self.args.task = task
        # The requested model name takes precedence over any `network` value
        # read from the config file.
        self.args.network = self.model_name

        # Network classes are imported lazily so that only the selected
        # model's class is imported.
        if self.args.network == 'FRCRN_SE_16K':
            from networks import CLS_FRCRN_SE_16K
            self.network = CLS_FRCRN_SE_16K(self.args)
        elif self.args.network == 'MossFormer2_SE_48K':
            from networks import CLS_MossFormer2_SE_48K
            self.network = CLS_MossFormer2_SE_48K(self.args)
        elif self.args.network == 'MossFormerGAN_SE_16K':
            from networks import CLS_MossFormerGAN_SE_16K
            self.network = CLS_MossFormerGAN_SE_16K(self.args)
        elif self.args.network == 'MossFormer2_SS_16K':
            from networks import CLS_MossFormer2_SS_16K
            self.network = CLS_MossFormer2_SS_16K(self.args)
        elif self.args.network == 'AV_MossFormer2_TSE_16K':
            from networks import CLS_AV_MossFormer2_TSE_16K
            self.network = CLS_AV_MossFormer2_TSE_16K(self.args)
        else:
            print("No network found!")
            return None

        return self.network
 |