import os |
import sys |
import argparse |
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) |
from valid import check_validation |
from inference import proc_folder |
from train import train_model |
from scripts.redact_config import redact_config |
from scripts.valid_to_inference import copying_files |
from scripts.trim import trim_directory |
# Default settings for the test runner. ``test_settings`` layers every non-None
# value coming from the command line (or from its ``dict_args``) over these.
# In 'user' mode, model_type, config_path, start_check_point, data_path and
# valid_path must end up non-empty, otherwise a ValueError is raised.
base_args = dict(
    device_ids='0',
    model_type='',
    start_check_point='',
    config_path='',
    data_path='',
    valid_path='',
    # Output location for training (see --results_path).
    results_path='tests/train_results',
    # Output location for validation/inference (see --store_dir).
    store_dir='tests/valid_inference_result',
    # Left empty: in 'user' mode it is derived from valid_path.
    input_folder='',
    metrics=[
        'neg_log_wmse', 'l1_freq', 'si_sdr', 'sdr',
        'aura_stft', 'aura_mrstft', 'bleedless', 'fullness',
    ],
    # Passed to trim_directory for both the train and validation data.
    max_folders=2,
)
def parse_args(dict_args):
    """Build the test-runner argument namespace.

    Args:
        dict_args: When not None, the command line is ignored. The parser's
            defaults are used, and every key of this dict overrides or adds an
            attribute. When None, ``sys.argv`` is parsed.

    Returns:
        argparse.Namespace with one attribute per option (plus any extra keys
        from ``dict_args``). Options without an explicit default are None, and
        ``store_true`` flags default to False.
    """
    parser = argparse.ArgumentParser()

    # Which stages to run.
    parser.add_argument("--check_train", action='store_true', help="Check train or not")
    parser.add_argument("--check_valid", action='store_true', help="Check validation or not")
    parser.add_argument("--check_inference", action='store_true', help="Check inference or not")

    # Settings that, when given, override the module-level defaults.
    parser.add_argument('--device_ids', type=str, help='Device IDs for training/inference')
    parser.add_argument('--model_type', type=str, help='Model type')
    parser.add_argument('--start_check_point', type=str, help='Path to the checkpoint to start from')
    parser.add_argument('--config_path', type=str, help='Path to the configuration file')
    parser.add_argument('--data_path', type=str, help='Path to the training data')
    parser.add_argument('--valid_path', type=str, help='Path to the validation data')
    parser.add_argument('--results_path', type=str, help='Path to save training results')
    parser.add_argument('--store_dir', type=str, help='Path to store validation/inference results')
    parser.add_argument('--input_folder', type=str, help='Path to the input folder for inference')
    parser.add_argument('--metrics', nargs='+', help='List of metrics to evaluate')
    # int, to match the integer default used when the option is omitted.
    parser.add_argument('--max_folders', type=int, help='Maximum number of folders to process')

    # Training options.
    parser.add_argument("--dataset_type", type=int, default=1,
                        help="Dataset type. Must be one of: 1, 2, 3 or 4.")
    parser.add_argument("--num_workers", type=int, default=0, help="dataloader num_workers")
    parser.add_argument("--pin_memory", action='store_true', help="dataloader pin_memory")
    parser.add_argument("--seed", type=int, default=0, help="random seed")
    parser.add_argument("--use_multistft_loss", action='store_true',
                        help="Use MultiSTFT Loss (from auraloss package)")
    parser.add_argument("--use_mse_loss", action='store_true', help="Use default MSE loss")
    parser.add_argument("--use_l1_loss", action='store_true', help="Use L1 loss")
    parser.add_argument("--wandb_key", type=str, default='', help='wandb API Key')
    parser.add_argument("--pre_valid", action='store_true', help='Run validation before training')
    parser.add_argument("--metric_for_scheduler", default="sdr",
                        choices=['sdr', 'l1_freq', 'si_sdr', 'neg_log_wmse', 'aura_stft', 'aura_mrstft', 'bleedless',
                                 'fullness'], help='Metric which will be used for scheduler.')
    parser.add_argument("--train_lora", action='store_true', help="Train with LoRA")
    parser.add_argument("--lora_checkpoint", type=str, default='', help="Initial checkpoint to LoRA weights")

    # Validation / inference options.
    parser.add_argument("--extension", type=str, default='wav', help="Choose extension for validation")
    parser.add_argument("--use_tta", action='store_true',
                        help="Flag adds test time augmentation during inference (polarity and channel inverse)."
                             " While this triples the runtime, it reduces noise and slightly improves prediction quality.")
    parser.add_argument("--extract_instrumental", action='store_true',
                        help="invert vocals to get instrumental if provided")
    parser.add_argument("--disable_detailed_pbar", action='store_true', help="disable detailed progress bar")
    parser.add_argument("--force_cpu", action='store_true', help="Force the use of CPU even if CUDA is available")
    parser.add_argument("--flac_file", action='store_true', help="Output flac file instead of wav")
    parser.add_argument("--pcm_type", type=str, choices=['PCM_16', 'PCM_24'], default='PCM_24',
                        help="PCM type for FLAC files (PCM_16 or PCM_24)")
    parser.add_argument("--draw_spectro", type=float, default=0,
                        help="If --store_dir is set then code will generate spectrograms for resulted stems as well."
                             " Value defines for how many seconds of the track spectrogram will be generated.")

    if dict_args is None:
        return parser.parse_args()

    # Programmatic use: start from the parser defaults (no CLI parsing) and
    # let the caller's dict override/extend them.
    merged = vars(parser.parse_args([]))
    merged.update(dict_args)
    return argparse.Namespace(**merged)
def _select(settings, keys):
    """Return a new dict holding only ``keys`` copied from ``settings``."""
    return {key: settings[key] for key in keys}


def _prepare_user_data(settings):
    """Validate required settings and prepare config/data for a 'user' run.

    Updates ``settings`` in place:
    - ``config_path`` becomes the value returned by ``redact_config``;
    - ``data_path`` and ``valid_path`` become the values returned by
      ``trim_directory`` (called with ``max_folders``);
    - ``input_folder`` defaults to ``<parent of valid_path>/for_inference``
      when empty, and ``copying_files`` is then called to fill it from the
      validation data.

    Raises:
        ValueError: if any required setting is empty.
    """
    required = ['model_type', 'config_path', 'start_check_point', 'data_path', 'valid_path']
    missing_args = [arg for arg in required if not settings[arg]]
    if missing_args:
        missing_args_str = ', '.join(f'--{arg}' for arg in missing_args)
        raise ValueError(
            f"The following arguments are required but missing: {missing_args_str}."
            f" Please specify them either via command-line arguments or directly in `base_args`.")

    settings['config_path'] = redact_config({'orig_config': settings['config_path'],
                                             'model_type': settings['model_type'],
                                             'new_config': ''})

    # Train data first, then validation data (same order as before).
    for path_key in ('data_path', 'valid_path'):
        settings[path_key] = trim_directory({'input_directory': settings[path_key],
                                             'max_folders': settings['max_folders']})

    if not settings['input_folder']:
        settings['input_folder'] = os.path.join(os.path.dirname(settings['valid_path']), 'for_inference')
    copying_files({'valid_path': settings['valid_path'],
                   'inference_dir': settings['input_folder'],
                   'max_mixtures': 1})


def test_settings(dict_args, test_type):
    """Run the requested checks: validation, inference and/or training.

    Settings start from a deep copy of the module-level ``base_args``. Every
    non-None value returned by ``parse_args(dict_args)`` overrides the copy.
    ``base_args`` itself is never modified, so repeated calls (for example one
    per model) do not leak paths or configs into each other.

    Args:
        dict_args: Settings dict used instead of the command line, or None to
            parse ``sys.argv``.
        test_type: 'user' additionally validates the required settings,
            redacts the config, trims the data folders and prepares the
            inference input folder. Any other value skips that preparation.

    Raises:
        ValueError: in 'user' mode, when a required setting is empty.
    """
    import copy

    settings = copy.deepcopy(base_args)
    cli_values = vars(parse_args(dict_args))
    settings.update({key: value for key, value in cli_values.items() if value is not None})

    if test_type == 'user':
        _prepare_user_data(settings)

    if settings.get('check_valid'):
        valid_args = _select(settings, ['model_type', 'config_path', 'start_check_point',
                                        'store_dir', 'device_ids', 'num_workers', 'pin_memory', 'extension',
                                        'use_tta', 'metrics', 'lora_checkpoint', 'draw_spectro'])
        # check_validation receives a list of validation folders.
        valid_args['valid_path'] = [settings['valid_path']]
        print('Start validation.')
        check_validation(valid_args)
        print(f'Validation ended. See results in {settings["store_dir"]}')

    if settings.get('check_inference'):
        inference_args = _select(settings, ['model_type', 'config_path', 'start_check_point', 'input_folder',
                                            'store_dir', 'device_ids', 'extract_instrumental',
                                            'disable_detailed_pbar', 'force_cpu', 'flac_file', 'pcm_type',
                                            'use_tta', 'lora_checkpoint', 'draw_spectro'])
        print('Start inference.')
        proc_folder(inference_args)
        print(f'Inference ended. See results in {settings["store_dir"]}')

    if settings.get('check_train'):
        train_args = _select(settings, ['model_type', 'config_path', 'start_check_point', 'results_path',
                                        'data_path', 'dataset_type', 'valid_path', 'num_workers', 'pin_memory',
                                        'seed', 'device_ids', 'use_multistft_loss', 'use_mse_loss',
                                        'use_l1_loss', 'wandb_key', 'pre_valid', 'metrics',
                                        'metric_for_scheduler', 'train_lora', 'lora_checkpoint'])
        print('Start train.')
        train_model(train_args)

    print('End!')
if __name__ == "__main__":
    # Script entry point: settings come from the command line (no dict),
    # and the full 'user' preparation is performed before the checks run.
    test_settings(dict_args=None, test_type='user')