__author__ = 'Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/'


import argparse
import os
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import soundfile as sf
import torch
import torch.nn as nn
import yaml
import loralib as lora
from ml_collections import ConfigDict
from omegaconf import OmegaConf
from tqdm.auto import tqdm

# matplotlib is optional here: it is only needed for draw_spectrogram().
try:
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt
except ImportError:
    plt = None


def load_config(model_type: str, config_path: str) -> Union[ConfigDict, OmegaConf]:
    """
    Load the configuration from the specified path based on the model type.

    Parameters:
    ----------
    model_type : str
        The type of model to load (e.g., 'htdemucs', 'mdx23c', etc.).
    config_path : str
        The path to the YAML or OmegaConf configuration file.

    Returns:
    -------
    config : Union[ConfigDict, OmegaConf]
        The loaded configuration: OmegaConf for 'htdemucs', ConfigDict for all other model types.

    Raises:
    ------
    FileNotFoundError:
        If the configuration file at `config_path` is not found.
    ValueError:
        If there is an error loading the configuration file.
    """
    try:
        with open(config_path, 'r') as f:
            if model_type == 'htdemucs':
                config = OmegaConf.load(f)
            else:
                config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader))
        return config
    except FileNotFoundError:
        raise FileNotFoundError(f"Configuration file not found at {config_path}")
    except Exception as e:
        raise ValueError(f"Error loading configuration: {e}")
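
# Example (minimal sketch): load a config and inspect inference settings. The
# path below is illustrative, not a file guaranteed to exist in this repo.
#
#   config = load_config('mdx23c', 'configs/config_musdb18_mdx23c.yaml')
#   print(config.audio.chunk_size, config.inference.num_overlap)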


def get_model_from_config(model_type: str, config_path: str) -> Tuple:
    """
    Load the model specified by the model type and configuration file.

    Parameters:
    ----------
    model_type : str
        The type of model to load (e.g., 'mdx23c', 'htdemucs', 'scnet', etc.).
    config_path : str
        The path to the configuration file (YAML or OmegaConf format).

    Returns:
    -------
    model : nn.Module
        The initialized model for the given `model_type`.
    config : Any
        The configuration used to initialize the model. This could be in different formats
        depending on the model type (e.g., OmegaConf, ConfigDict).

    Raises:
    ------
    ValueError:
        If the `model_type` is unknown or an error occurs during model initialization.
    """
    config = load_config(model_type, config_path)

    if model_type == 'mdx23c':
        from models.mdx23c_tfc_tdf_v3 import TFC_TDF_net
        model = TFC_TDF_net(config)
    elif model_type == 'htdemucs':
        from models.demucs4ht import get_model
        model = get_model(config)
    elif model_type == 'segm_models':
        from models.segm_models import Segm_Models_Net
        model = Segm_Models_Net(config)
    elif model_type == 'torchseg':
        from models.torchseg_models import Torchseg_Net
        model = Torchseg_Net(config)
    elif model_type == 'mel_band_roformer':
        from models.bs_roformer import MelBandRoformer
        model = MelBandRoformer(**dict(config.model))
    elif model_type == 'bs_roformer':
        from models.bs_roformer import BSRoformer
        model = BSRoformer(**dict(config.model))
    elif model_type == 'swin_upernet':
        from models.upernet_swin_transformers import Swin_UperNet_Model
        model = Swin_UperNet_Model(config)
    elif model_type == 'bandit':
        from models.bandit.core.model import MultiMaskMultiSourceBandSplitRNNSimple
        model = MultiMaskMultiSourceBandSplitRNNSimple(**config.model)
    elif model_type == 'bandit_v2':
        from models.bandit_v2.bandit import Bandit
        model = Bandit(**config.kwargs)
    elif model_type == 'scnet_unofficial':
        from models.scnet_unofficial import SCNet
        model = SCNet(**config.model)
    elif model_type == 'scnet':
        from models.scnet import SCNet
        model = SCNet(**config.model)
    elif model_type == 'apollo':
        from models.look2hear.models import BaseModel
        model = BaseModel.apollo(**config.model)
    elif model_type == 'bs_mamba2':
        from models.ts_bs_mamba2 import Separator
        model = Separator(**config.model)
    elif model_type == 'experimental_mdx23c_stht':
        from models.mdx23c_tfc_tdf_v3_with_STHT import TFC_TDF_net
        model = TFC_TDF_net(config)
    elif model_type == 'mel_band_roformer_experimental':
        from models.bs_roformer.mel_band_roformer_experimental import MelBandRoformer
        model = MelBandRoformer(**dict(config.model))
    elif model_type == 'bs_roformer_experimental':
        from models.bs_roformer.bs_roformer_experimental import BSRoformer
        model = BSRoformer(**dict(config.model))
    elif model_type == 'bs_roformer_custom':
        from models.bs_roformer.bs_roformer_custom.bs_roformer import BSRoformer
        model = BSRoformer(**dict(config.model))
    elif model_type == 'scnet_tran':
        from models.scnet.scnet_tran import SCNet_Tran
        model = SCNet_Tran(**config.model)
    elif model_type == 'scnet_masked':
        from models.scnet.scnet_masked import SCNet
        model = SCNet(**config.model)
    elif model_type == 'conformer':
        from models.conformer_model import ConformerMSS, NeuralModel
        model = ConformerMSS(
            core=NeuralModel(**config.model),
            n_fft=config.stft.n_fft,
            hop_length=config.stft.hop_length,
            win_length=getattr(config.stft, 'win_length', config.stft.n_fft),
            center=config.stft.center
        )
    elif model_type == 'mel_band_conformer':
        from models.mel_band_conformer import MelBandConformer
        model = MelBandConformer(**config.model)
    else:
        raise ValueError(f"Unknown model type: {model_type}")

    return model, config
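
# Example (minimal sketch): build a model and count its parameters. The model
# type and config path are illustrative and must refer to an existing config.
#
#   model, config = get_model_from_config('bs_roformer', 'configs/config_bs_roformer.yaml')
#   n_params = sum(p.numel() for p in model.parameters())
#   print(f'Parameters: {n_params}')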


def read_audio_transposed(path: str, instr: Optional[str] = None, skip_err: bool = False) -> Tuple[Optional[np.ndarray], Optional[int]]:
    """
    Read an audio file, convert mono audio to a two-dimensional format,
    and transpose the data so that channels come first.

    Parameters
    ----------
    path : str
        Path to the audio file.
    instr : str, optional
        Name of the instrument (stem); used only in the skip message.
    skip_err : bool
        If True, read errors are not raised and (None, None) is returned instead.

    Returns
    -------
    Tuple[np.ndarray, int]
        A tuple containing:
        - Transposed audio data as a NumPy array with shape (channels, length).
          For mono audio, the shape will be (1, length).
        - Sampling rate (int), e.g., 44100.
    """
    try:
        mix, sr = sf.read(path)
    except Exception as e:
        if skip_err:
            print(f"No stem {instr}: skip!")
            return None, None
        else:
            raise RuntimeError(f"Error reading the file at {path}: {e}")
    else:
        if len(mix.shape) == 1:
            # Mono audio: add a channel dimension.
            mix = np.expand_dims(mix, axis=-1)
        return mix.T, sr
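
# Example (minimal sketch): read a mixture and check its layout. The file name
# is illustrative.
#
#   mix, sr = read_audio_transposed('mixture.wav')
#   if mix is not None:
#       print(mix.shape, sr)  # e.g. (2, 7059481) 44100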


def normalize_audio(audio: np.ndarray) -> Tuple[np.ndarray, Dict[str, float]]:
    """
    Normalize an audio signal by subtracting the mean and dividing by the standard deviation.

    Parameters:
    ----------
    audio : np.ndarray
        Input audio array with shape (channels, time) or (time,).

    Returns:
    -------
    Tuple[np.ndarray, Dict[str, float]]
        - Normalized audio array with the same shape as the input.
        - Dictionary containing the mean and standard deviation of the original audio.
    """
    mono = audio.mean(0)
    mean, std = mono.mean(), mono.std()
    return (audio - mean) / std, {"mean": mean, "std": std}


def denormalize_audio(audio: np.ndarray, norm_params: Dict[str, float]) -> np.ndarray:
    """
    Denormalize an audio signal by reversing the normalization process (multiplying by the standard deviation
    and adding the mean).

    Parameters:
    ----------
    audio : np.ndarray
        Normalized audio array to be denormalized.
    norm_params : Dict[str, float]
        Dictionary containing the 'mean' and 'std' values used for normalization.

    Returns:
    -------
    np.ndarray
        Denormalized audio array with the same shape as the input.
    """
    return audio * norm_params["std"] + norm_params["mean"]
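
# Example (minimal sketch): round-trip a signal through normalization. Random
# data is used purely for illustration.
#
#   audio = np.random.randn(2, 44100).astype(np.float32)
#   normalized, norm_params = normalize_audio(audio)
#   restored = denormalize_audio(normalized, norm_params)
#   assert np.allclose(audio, restored, atol=1e-5)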


def apply_tta(
    config,
    model: torch.nn.Module,
    mix: np.ndarray,
    waveforms_orig: Dict[str, np.ndarray],
    device: torch.device,
    model_type: str
) -> Dict[str, np.ndarray]:
    """
    Apply Test-Time Augmentation (TTA) for source separation.

    This function processes the input mixture with test-time augmentations, including
    channel inversion and polarity inversion, to enhance the separation results. The
    results from all augmentations are averaged with the original estimates to produce
    the final output.

    Parameters:
    ----------
    config : Any
        Configuration object containing model and processing parameters.
    model : torch.nn.Module
        The trained model used for source separation.
    mix : np.ndarray
        The mixed audio array with shape (channels, time).
    waveforms_orig : Dict[str, np.ndarray]
        Dictionary of original separated waveforms (before TTA) for each instrument.
    device : torch.device
        Device (CPU or CUDA) on which the model will be executed.
    model_type : str
        Type of the model being used (e.g., "htdemucs", "mdx23c").

    Returns:
    -------
    Dict[str, np.ndarray]
        Updated dictionary of separated waveforms after applying TTA.
    """
    # Channel inversion and polarity inversion of the input mixture.
    track_proc_list = [mix[::-1].copy(), -1.0 * mix.copy()]

    for i, augmented_mix in enumerate(track_proc_list):
        waveforms = demix(config, model, augmented_mix, device, model_type=model_type)
        for el in waveforms:
            if i == 0:
                # Undo the channel inversion on the estimates.
                waveforms_orig[el] += waveforms[el][::-1].copy()
            else:
                # Undo the polarity inversion on the estimates.
                waveforms_orig[el] -= waveforms[el]

    # Average the original estimates with the two augmented passes.
    for el in waveforms_orig:
        waveforms_orig[el] /= len(track_proc_list) + 1

    return waveforms_orig
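
# Example (minimal sketch): refine a first separation pass with TTA. `config`,
# `model`, `mix` and `device` are assumed to be set up as in the demix()
# example further below.
#
#   waveforms = demix(config, model, mix, device, model_type='mdx23c')
#   waveforms = apply_tta(config, model, mix, waveforms, device, 'mdx23c')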


def _getWindowingArray(window_size: int, fade_size: int) -> torch.Tensor:
    """
    Generate a windowing array with a linear fade-in at the beginning and a fade-out at the end.

    This function creates a window of size `window_size` where the first `fade_size` elements
    linearly increase from 0 to 1 (fade-in) and the last `fade_size` elements linearly decrease
    from 1 to 0 (fade-out). The middle part of the window is filled with ones.

    Parameters:
    ----------
    window_size : int
        The total size of the window.
    fade_size : int
        The size of the fade-in and fade-out regions.

    Returns:
    -------
    torch.Tensor
        A tensor of shape (window_size,) containing the generated windowing array.

    Example:
    -------
    If `window_size=10` and `fade_size=3`, the output will be:
    tensor([0.0000, 0.5000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 0.5000, 0.0000])
    """
    fadein = torch.linspace(0, 1, fade_size)
    fadeout = torch.linspace(1, 0, fade_size)

    window = torch.ones(window_size)
    window[-fade_size:] = fadeout
    window[:fade_size] = fadein
    return window


def demix(
    config: ConfigDict,
    model: torch.nn.Module,
    mix: np.ndarray,
    device: torch.device,
    model_type: str,
    pbar: bool = False
) -> Union[Dict[str, np.ndarray], np.ndarray]:
    """
    Unified function for audio source separation with support for multiple processing modes.

    This function separates audio into its constituent sources using either a generic custom logic
    or a Demucs-specific logic. It supports batch processing and overlapping window-based chunking
    for efficient and artifact-free separation.

    Parameters:
    ----------
    config : ConfigDict
        Configuration object containing audio and inference settings.
    model : torch.nn.Module
        The trained model used for audio source separation.
    mix : np.ndarray
        Input audio array with shape (channels, time).
    device : torch.device
        The computation device (CPU or CUDA).
    model_type : str
        Model type: 'htdemucs' selects the Demucs-specific logic; any other value
        selects the generic logic.
    pbar : bool, optional
        If True, displays a progress bar during chunk processing. Default is False.

    Returns:
    -------
    Union[Dict[str, np.ndarray], np.ndarray]
        - A dictionary mapping target instruments to separated audio sources if multiple instruments are present.
        - A numpy array of the separated source if only one instrument is present.
    """
    mix = torch.tensor(mix, dtype=torch.float32)

    if model_type == 'htdemucs':
        mode = 'demucs'
    else:
        mode = 'generic'

    if mode == 'demucs':
        chunk_size = config.training.samplerate * config.training.segment
        num_instruments = len(config.training.instruments)
        num_overlap = config.inference.num_overlap
    else:
        chunk_size = config.audio.chunk_size
        num_instruments = len(prefer_target_instrument(config))
        num_overlap = config.inference.num_overlap

    fade_size = chunk_size // 10
    step = chunk_size // num_overlap
    border = chunk_size - step
    length_init = mix.shape[-1]
    windowing_array = _getWindowingArray(chunk_size, fade_size)

    # Pad the mixture at both ends to avoid boundary artifacts.
    if length_init > 2 * border and border > 0:
        mix = nn.functional.pad(mix, (border, border), mode="reflect")

    batch_size = config.inference.batch_size

    use_amp = getattr(config.training, 'use_amp', True)

    with torch.cuda.amp.autocast(enabled=use_amp):
        with torch.inference_mode():
            # Accumulators for the weighted sum of chunk outputs and the weights.
            req_shape = (num_instruments,) + mix.shape
            result = torch.zeros(req_shape, dtype=torch.float32)
            counter = torch.zeros(req_shape, dtype=torch.float32)

            i = 0
            batch_data = []
            batch_locations = []
            progress_bar = tqdm(
                total=mix.shape[1], desc="Processing audio chunks", leave=False
            ) if pbar else None

            while i < mix.shape[1]:
                # Extract the chunk and pad it to chunk_size if needed.
                part = mix[:, i:i + chunk_size].to(device)
                chunk_len = part.shape[-1]
                if mode == "generic" and chunk_len > chunk_size // 2:
                    pad_mode = "reflect"
                else:
                    pad_mode = "constant"
                part = nn.functional.pad(part, (0, chunk_size - chunk_len), mode=pad_mode, value=0)

                batch_data.append(part)
                batch_locations.append((i, chunk_len))
                i += step

                # Run the model once a full batch is collected (or at the end).
                if len(batch_data) >= batch_size or i >= mix.shape[1]:
                    arr = torch.stack(batch_data, dim=0)
                    x = model(arr)

                    if mode == "generic":
                        window = windowing_array.clone()
                        if i - step == 0:  # First audio chunk: no fade-in.
                            window[:fade_size] = 1
                        elif i >= mix.shape[1]:  # Last audio chunk: no fade-out.
                            window[-fade_size:] = 1

                    for j, (start, seg_len) in enumerate(batch_locations):
                        if mode == "generic":
                            result[..., start:start + seg_len] += x[j, ..., :seg_len].cpu() * window[..., :seg_len]
                            counter[..., start:start + seg_len] += window[..., :seg_len]
                        else:
                            result[..., start:start + seg_len] += x[j, ..., :seg_len].cpu()
                            counter[..., start:start + seg_len] += 1.0

                    batch_data.clear()
                    batch_locations.clear()

                if progress_bar:
                    progress_bar.update(step)

            if progress_bar:
                progress_bar.close()

            # Compute the final estimated sources.
            estimated_sources = result / counter
            estimated_sources = estimated_sources.cpu().numpy()
            np.nan_to_num(estimated_sources, copy=False, nan=0.0)

            # Remove the padding added before chunking.
            if mode == "generic":
                if length_init > 2 * border and border > 0:
                    estimated_sources = estimated_sources[..., border:-border]

    if mode == "demucs":
        instruments = config.training.instruments
    else:
        instruments = prefer_target_instrument(config)

    ret_data = {k: v for k, v in zip(instruments, estimated_sources)}

    if mode == "demucs" and num_instruments <= 1:
        return estimated_sources
    else:
        return ret_data
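
# Example (minimal sketch): separate a full mixture. The file names are
# illustrative; `model` and `config` come from get_model_from_config().
#
#   device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#   model = model.to(device).eval()
#   mix, sr = read_audio_transposed('mixture.wav')
#   waveforms = demix(config, model, mix, device, model_type='mdx23c', pbar=True)
#   for instr, audio in waveforms.items():
#       sf.write(f'{instr}.wav', audio.T, sr)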


def prefer_target_instrument(config: ConfigDict) -> List[str]:
    """
    Return the list of target instruments based on the configuration.
    If a specific target instrument is specified in the configuration,
    it returns a list with that instrument. Otherwise, it returns the list of instruments.

    Parameters:
    ----------
    config : ConfigDict
        Configuration object containing the list of instruments or the target instrument.

    Returns:
    -------
    List[str]
        A list of target instruments.
    """
    if getattr(config.training, 'target_instrument', None):
        return [config.training.target_instrument]
    else:
        return config.training.instruments
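
# Example (minimal sketch): with `target_instrument: vocals` in the training
# config this returns ['vocals']; without it, the full instrument list.
#
#   instruments = prefer_target_instrument(config)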


def load_not_compatible_weights(model: torch.nn.Module, weights: str, verbose: bool = False) -> None:
    """
    Load weights into a model, handling mismatched shapes and dimensions.

    Args:
        model: PyTorch model into which the weights will be loaded.
        weights: Path to the weights file.
        verbose: If True, prints detailed information about matching and mismatched layers.
    """
    new_model = model.state_dict()
    old_model = torch.load(weights, weights_only=False)
    if 'state' in old_model:
        # Some checkpoints (e.g. htdemucs) nest the weights under 'state'.
        old_model = old_model['state']
    if 'state_dict' in old_model:
        # Some checkpoints (e.g. apollo) nest the weights under 'state_dict'.
        old_model = old_model['state_dict']

    for el in new_model:
        if el in old_model:
            if verbose:
                print(f'Match found for {el}!')
            if new_model[el].shape == old_model[el].shape:
                if verbose:
                    print('Action: Just copy weights!')
                new_model[el] = old_model[el]
            else:
                if len(new_model[el].shape) != len(old_model[el].shape):
                    if verbose:
                        print('Action: Different number of dimensions! Skip it...')
                else:
                    if verbose:
                        print(f'Shape is different: {tuple(new_model[el].shape)} != {tuple(old_model[el].shape)}')
                    ln = len(new_model[el].shape)
                    max_shape = []
                    slices_old = []
                    slices_new = []
                    for i in range(ln):
                        max_shape.append(max(new_model[el].shape[i], old_model[el].shape[i]))
                        slices_old.append(slice(0, old_model[el].shape[i]))
                        slices_new.append(slice(0, new_model[el].shape[i]))
                    slices_old = tuple(slices_old)
                    slices_new = tuple(slices_new)
                    # Copy the old weights into a zero tensor large enough for
                    # both shapes, then crop to the new shape.
                    max_matrix = np.zeros(max_shape, dtype=np.float32)
                    max_matrix[slices_old] = old_model[el].cpu().numpy()
                    max_matrix = torch.from_numpy(max_matrix)
                    new_model[el] = max_matrix[slices_new]
        else:
            if verbose:
                print(f'Match not found for {el}!')
    model.load_state_dict(new_model)
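
# Example (minimal sketch): warm-start a model from a checkpoint whose tensor
# shapes only partially match (e.g. a different number of output stems). The
# checkpoint path is illustrative.
#
#   model, config = get_model_from_config('scnet', 'configs/config_scnet.yaml')
#   load_not_compatible_weights(model, 'results/last_scnet.ckpt', verbose=True)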


def load_lora_weights(model: torch.nn.Module, lora_path: str, device: str = 'cpu') -> None:
    """
    Load LoRA weights into a model.

    This function updates the given model with LoRA-specific weights from the specified checkpoint file.
    It does not require the checkpoint to match the model's full state dictionary, as only LoRA layers are updated.

    Parameters:
    ----------
    model : nn.Module
        The PyTorch model into which the LoRA weights will be loaded.
    lora_path : str
        Path to the LoRA checkpoint file.
    device : str, optional
        The device to load the weights onto, by default 'cpu'. Common values are 'cpu' or 'cuda'.

    Returns:
    -------
    None
        The model is updated in place.
    """
    lora_state_dict = torch.load(lora_path, map_location=device, weights_only=False)
    model.load_state_dict(lora_state_dict, strict=False)


def load_start_checkpoint(args: argparse.Namespace, model: torch.nn.Module, type_='train') -> None:
    """
    Load the starting checkpoint for a model.

    Args:
        args: Parsed command-line arguments containing the checkpoint path.
        model: PyTorch model to load the checkpoint into.
        type_: How to load the weights: for 'train', weights that are not fully
            compatible with the model can be loaded.
    """
    print(f'Start from checkpoint: {args.start_check_point}')
    if type_ in ['train']:
        load_not_compatible_weights(model, args.start_check_point, verbose=False)
    else:
        device = 'cpu'
        if args.model_type in ['htdemucs', 'apollo']:
            state_dict = torch.load(args.start_check_point, map_location=device, weights_only=False)
            # Some checkpoints (e.g. htdemucs) nest the weights under 'state'.
            if 'state' in state_dict:
                state_dict = state_dict['state']
            # Some checkpoints (e.g. apollo) nest the weights under 'state_dict'.
            if 'state_dict' in state_dict:
                state_dict = state_dict['state_dict']
        else:
            state_dict = torch.load(args.start_check_point, map_location=device, weights_only=True)
        model.load_state_dict(state_dict)
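
# Example (minimal sketch): mimic the CLI arguments and load weights for
# inference. Attribute names follow this function's usage; the path is
# illustrative.
#
#   args = argparse.Namespace(
#       start_check_point='results/model_bs_roformer.ckpt',
#       model_type='bs_roformer',
#   )
#   load_start_checkpoint(args, model, type_='inference')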


def bind_lora_to_model(config: Dict[str, Any], model: nn.Module) -> nn.Module:
    """
    Replaces specific layers in the model with LoRA-extended versions.

    Parameters:
    ----------
    config : Dict[str, Any]
        Configuration containing parameters for LoRA. It should include a 'lora' key with parameters for `MergedLinear`.
    model : nn.Module
        The original model in which the layers will be replaced.

    Returns:
    -------
    nn.Module
        The modified model with the replaced layers.
    """
    if 'lora' not in config:
        raise ValueError("Configuration must contain the 'lora' key with parameters for LoRA.")

    replaced_layers = 0

    for name, module in model.named_modules():
        hierarchy = name.split('.')
        layer_name = hierarchy[-1]

        # Replace only nn.Linear layers.
        if isinstance(module, nn.Linear):
            try:
                # Walk down to the parent module of the layer.
                parent_module = model
                for submodule_name in hierarchy[:-1]:
                    parent_module = getattr(parent_module, submodule_name)

                # Swap the nn.Linear for a LoRA MergedLinear of the same shape.
                setattr(
                    parent_module,
                    layer_name,
                    lora.MergedLinear(
                        in_features=module.in_features,
                        out_features=module.out_features,
                        bias=module.bias is not None,
                        **config['lora']
                    )
                )
                replaced_layers += 1
            except Exception as e:
                print(f"Error replacing layer {name}: {e}")

    if replaced_layers == 0:
        print("Warning: No layers were replaced. Check the model structure and configuration.")
    else:
        print(f"Number of layers replaced with LoRA: {replaced_layers}")

    return model
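
# Example (minimal sketch): a typical LoRA workflow with this module. The
# 'lora' parameters and checkpoint path are illustrative; see loralib for the
# meaning of r / lora_alpha / enable_lora.
#
#   config['lora'] = {'r': 8, 'lora_alpha': 16, 'enable_lora': [True]}
#   model = bind_lora_to_model(config, model)
#   load_lora_weights(model, 'results/lora_vocals.ckpt')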


def draw_spectrogram(waveform, sample_rate, length, output_file):
    """
    Save a spectrogram image of the first `length` seconds of a waveform.

    Args:
        waveform: Audio data as a NumPy array with shape (time, channels).
        sample_rate: Sampling rate in Hz, e.g. 44100.
        length: Duration in seconds to visualize from the start of the waveform.
        output_file: Path of the image file to write.
    """
    if plt is None:
        print("Warning: matplotlib not available, skipping spectrogram")
        return

    import librosa
    import librosa.display

    # Take the first `length` seconds, downmix to mono and compute the STFT.
    x = waveform[:int(length * sample_rate), :]
    X = librosa.stft(x.mean(axis=-1))
    Xdb = librosa.amplitude_to_db(np.abs(X), ref=np.max)
    fig, ax = plt.subplots()
    img = librosa.display.specshow(
        Xdb,
        cmap='plasma',
        sr=sample_rate,
        x_axis='time',
        y_axis='linear',
        ax=ax
    )
    ax.set(title='File: ' + os.path.basename(output_file))
    fig.colorbar(img, ax=ax, format="%+2.f dB")
    if output_file is not None:
        plt.savefig(output_file)
    plt.close(fig)  # Free the figure to avoid leaking memory across calls.
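
# Example (minimal sketch): visualize the first 10 seconds of a separated stem.
# File names are illustrative.
#
#   audio, sr = sf.read('vocals.wav')  # shape: (time, channels)
#   draw_spectrogram(audio, sr, 10, 'vocals_spectrogram.png')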