xjsc0's picture
1
64ec292
import os
from pathlib import Path
from typing import Dict, List
import numpy as np
import soundfile as sf
from scripts.redact_config import redact_config
from test import test_settings
from utils.settings import load_config
MODEL_CONFIGS = {
"config_apollo.yaml": {"model_type": "apollo"},
"config_dnr_bandit_bsrnn_multi_mus64.yaml": {"model_type": "bandit"},
"config_dnr_bandit_v2_mus64.yaml": {"model_type": "bandit_v2"},
"config_drumsep.yaml": {"model_type": "htdemucs"},
"config_htdemucs_6stems.yaml": {"model_type": "htdemucs"},
"config_musdb18_bs_roformer.yaml": {"model_type": "bs_roformer"},
"config_musdb18_demucs3_mmi.yaml": {"model_type": "htdemucs"},
"config_musdb18_htdemucs.yaml": {"model_type": "htdemucs"},
"config_musdb18_mdx23c.yaml": {"model_type": "mdx23c"},
"config_musdb18_mel_band_roformer.yaml": {"model_type": "mel_band_roformer"},
"config_musdb18_mel_band_roformer_all_stems.yaml": {
"model_type": "mel_band_roformer"
},
"config_musdb18_scnet.yaml": {"model_type": "scnet"},
"config_musdb18_scnet_large.yaml": {"model_type": "scnet"},
# 'config_musdb18_scnet_large_starrytong.yaml': {'model_type': 'scnet'},
"config_vocals_bandit_bsrnn_multi_mus64.yaml": {"model_type": "bandit"},
"config_vocals_bs_roformer.yaml": {"model_type": "bs_roformer"},
"config_vocals_htdemucs.yaml": {"model_type": "htdemucs"},
"config_vocals_mdx23c.yaml": {"model_type": "mdx23c"},
"config_vocals_mel_band_roformer.yaml": {"model_type": "mel_band_roformer"},
"config_vocals_scnet.yaml": {"model_type": "scnet"},
"config_vocals_scnet_large.yaml": {"model_type": "scnet"},
"config_vocals_scnet_unofficial.yaml": {"model_type": "scnet_unofficial"},
"config_vocals_segm_models.yaml": {"model_type": "segm_models"},
# 'config_vocals_swin_upernet.yaml': {'model_type': 'swin_upernet'},
# 'config_musdb18_torchseg.yaml': {'model_type': 'torchseg'},
# 'config_musdb18_segm_models.yaml': {'model_type': 'segm_models'},
# 'config_musdb18_bs_mamba2.yaml': {'model_type': 'bs_mamba2'},
# 'config_vocals_bs_mamba2.yaml': {'model_type': 'bs_mamba2'},
# 'config_vocals_torchseg.yaml': {'model_type': 'torchseg'}
}
# Folders for tests
ROOT_DIR = Path(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
CONFIGS_DIR = ROOT_DIR / "configs/"
TEST_DIR = ROOT_DIR / "tests_cache/"
TRAIN_DIR = TEST_DIR / "train_tracks/"
VALID_DIR = TEST_DIR / "valid_tracks/"
def create_dummy_tracks(
directory: Path,
num_tracks: int,
instruments: List[str],
duration: float = 5.0,
sample_rate: int = 44100,
) -> None:
"""
Generates random audio tracks for stems in two subdirectories within the specified directory.
Parameters:
----------
directory : Path
Path to the directory where the tracks will be saved.
num_tracks : int
Number of tracks to generate in each folder.
instruments : List[str]
List of instrument names (stems) to create.
duration : float, optional
Duration of each track in seconds. Default is 5.0.
sample_rate : int, optional
Sampling rate of the generated audio. Default is 44100 Hz.
Returns:
-------
None
"""
os.makedirs(directory, exist_ok=True)
for folder_name in [str(i) for i in range(1, num_tracks + 1)]:
folder_path = directory / folder_name
os.makedirs(folder_path, exist_ok=True)
for instrument in instruments:
# Generate random noice for each track
samples = int(duration * sample_rate)
track = np.random.uniform(-1.0, 1.0, (2, samples)).astype(np.float32)
file_path = folder_path / f"{instrument}.wav"
sf.write(file_path, track.T, sample_rate)
def cleanup_test_tracks() -> None:
"""
Removes all cached test tracks.
This function deletes the entire directory specified by the global `TEST_DIR` variable
if it exists.
Returns:
-------
None
This function does not return a value. It performs cleanup of test data.
"""
def modify_configs() -> Dict[str, Path]:
"""
Updates configuration files in the `configs` directory for use with test data.
This function processes configuration files defined in the global `MODEL_CONFIGS` dictionary,
modifies them to be compatible with test scenarios, and saves the updated configurations
in a test-specific directory.
Returns:
-------
Dict[str, Path]
A dictionary where the keys are the original configuration file names, and the values
are the paths to the updated configuration files.
"""
config_dir = CONFIGS_DIR
updated_configs = {}
for config, args in MODEL_CONFIGS.items():
model_type = args["model_type"]
config_path = config_dir / config
updated_config_path = redact_config(
{
"orig_config": str(config_path),
"model_type": model_type,
"new_config": str(TEST_DIR / "configs" / config),
}
)
updated_configs[config] = updated_config_path
return updated_configs
def run_tests() -> None:
"""
Executes validation tests for all configurations.
This function updates configurations, generates random dummy data for testing,
and runs a series of tests (training, validation, and inference checks) for each
model configuration specified in the global `MODEL_CONFIGS` dictionary.
Returns:
-------
None
"""
updated_configs = modify_configs()
# For every config
for config, args in MODEL_CONFIGS.items():
model_type = args["model_type"]
cfg = load_config(
model_type=model_type, config_path=TEST_DIR / "configs" / config
)
# Random tracks
create_dummy_tracks(
TRAIN_DIR, instruments=cfg.training.instruments + ["mixture"], num_tracks=2
)
create_dummy_tracks(
VALID_DIR, instruments=cfg.training.instruments + ["mixture"], num_tracks=2
)
print(f"\nRunning tests for model: {model_type} (config: {config})")
test_args = {
"check_train": False,
"check_valid": True,
"check_inference": True,
"config_path": updated_configs[config],
"data_path": str(TRAIN_DIR),
"valid_path": str(VALID_DIR),
"results_path": str(TEST_DIR / "results" / model_type),
"store_dir": str(TEST_DIR / "inference_results" / model_type),
"metrics": ["sdr", "si_sdr", "l1_freq"],
}
test_args.update(args)
test_settings(test_args, "admin")
print(f"Tests for model {model_type} completed successfully.")
# Remove test_cache
cleanup_test_tracks()
if __name__ == "__main__":
run_tests()