# Uploaded by ASesYusuf1 via huggingface_hub ("Upload folder using huggingface_hub"),
# commit 3978e51, 8.99 kB.
import os
import sys
import argparse
# Add the repository root (the parent of this script's directory) to sys.path so the project modules below can be imported
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from valid import check_validation
from inference import proc_folder
from train import train_model
from scripts.redact_config import redact_config
from scripts.valid_to_inference import copying_files
from scripts.trim import trim_directory
# Default settings for the test run. Values given on the command line (or via
# ``dict_args``) that are not None replace these entries. Empty strings are
# placeholders: model_type, start_check_point, config_path, data_path and
# valid_path must be filled in for a 'user' run, while an empty input_folder is
# derived from valid_path automatically.
base_args = dict(
    device_ids='0',
    model_type='',
    start_check_point='',
    config_path='',
    data_path='',
    valid_path='',
    results_path='tests/train_results',
    store_dir='tests/valid_inference_result',
    input_folder='',
    metrics=[
        'neg_log_wmse', 'l1_freq', 'si_sdr', 'sdr',
        'aura_stft', 'aura_mrstft', 'bleedless', 'fullness',
    ],
    # Number of folders kept when the train/valid sets are trimmed.
    max_folders=2,
)
def build_parser():
    """Build the argument parser for the test script.

    Options without an explicit default parse to None. ``test_settings`` only
    copies non-None values into ``base_args``, so those options fall back to the
    defaults stored there.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser()

    # Which stages to exercise.
    parser.add_argument("--check_train", action='store_true', help="Check train or not")
    parser.add_argument("--check_valid", action='store_true', help="Check validation or not")
    parser.add_argument("--check_inference", action='store_true', help="Check inference or not")

    # Model / path selection (default None -> value taken from base_args).
    parser.add_argument('--device_ids', type=str, help='Device IDs for training/inference')
    parser.add_argument('--model_type', type=str, help='Model type')
    parser.add_argument('--start_check_point', type=str, help='Path to the checkpoint to start from')
    parser.add_argument('--config_path', type=str, help='Path to the configuration file')
    parser.add_argument('--data_path', type=str, help='Path to the training data')
    parser.add_argument('--valid_path', type=str, help='Path to the validation data')
    parser.add_argument('--results_path', type=str, help='Path to save training results')
    parser.add_argument('--store_dir', type=str, help='Path to store validation/inference results')
    parser.add_argument('--input_folder', type=str, help='Path to the input folder for inference')
    parser.add_argument('--metrics', nargs='+', help='List of metrics to evaluate')
    # Parsed as int so a CLI value has the same type as the integer default in base_args.
    parser.add_argument('--max_folders', type=int, help='Maximum number of folders to process')

    # Training options.
    parser.add_argument("--dataset_type", type=int, default=1,
                        help="Dataset type. Must be one of: 1, 2, 3 or 4.")
    parser.add_argument("--num_workers", type=int, default=0, help="dataloader num_workers")
    parser.add_argument("--pin_memory", action='store_true', help="dataloader pin_memory")
    parser.add_argument("--seed", type=int, default=0, help="random seed")
    parser.add_argument("--use_multistft_loss", action='store_true',
                        help="Use MultiSTFT Loss (from auraloss package)")
    parser.add_argument("--use_mse_loss", action='store_true', help="Use default MSE loss")
    parser.add_argument("--use_l1_loss", action='store_true', help="Use L1 loss")
    parser.add_argument("--wandb_key", type=str, default='', help='wandb API Key')
    parser.add_argument("--pre_valid", action='store_true', help='Run validation before training')
    parser.add_argument("--metric_for_scheduler", default="sdr",
                        choices=['sdr', 'l1_freq', 'si_sdr', 'neg_log_wmse', 'aura_stft', 'aura_mrstft', 'bleedless',
                                 'fullness'], help='Metric which will be used for scheduler.')
    parser.add_argument("--train_lora", action='store_true', help="Train with LoRA")
    parser.add_argument("--lora_checkpoint", type=str, default='', help="Initial checkpoint to LoRA weights")

    # Validation / inference options.
    parser.add_argument("--extension", type=str, default='wav', help="Choose extension for validation")
    parser.add_argument("--use_tta", action='store_true',
                        help="Flag adds test time augmentation during inference (polarity and channel inverse)."
                             " While this triples the runtime, it reduces noise and slightly improves prediction quality.")
    parser.add_argument("--extract_instrumental", action='store_true',
                        help="invert vocals to get instrumental if provided")
    parser.add_argument("--disable_detailed_pbar", action='store_true', help="disable detailed progress bar")
    parser.add_argument("--force_cpu", action='store_true', help="Force the use of CPU even if CUDA is available")
    parser.add_argument("--flac_file", action='store_true', help="Output flac file instead of wav")
    parser.add_argument("--pcm_type", type=str, choices=['PCM_16', 'PCM_24'], default='PCM_24',
                        help="PCM type for FLAC files (PCM_16 or PCM_24)")
    parser.add_argument("--draw_spectro", type=float, default=0,
                        help="If --store_dir is set then code will generate spectrograms for resulted stems as well."
                             " Value defines for how many seconds of track spectrogram will be generated.")
    return parser


def parse_args(dict_args=None):
    """Return the script options as an ``argparse.Namespace``.

    Args:
        dict_args: Mapping of option name -> value. If given, the command line is
            ignored. Parser defaults are taken and then overlaid with this mapping;
            values from the mapping are used as-is, with no type conversion. If
            None, options are parsed from ``sys.argv``.

    Returns:
        argparse.Namespace: one attribute per option, plus any extra keys from
        ``dict_args``.
    """
    parser = build_parser()
    if dict_args is None:
        return parser.parse_args()
    # parse_args([]) yields the defaults without reading sys.argv.
    args_dict = vars(parser.parse_args([]))
    args_dict.update(dict_args)
    return argparse.Namespace(**args_dict)
def test_settings(dict_args, test_type):
    """Prepare test data and run the requested validation/inference/training checks.

    Options (from ``dict_args`` or, when it is None, the command line) that are
    not None are written into the module-level ``base_args`` dict. The function then:

    1. Rewrites the config via ``redact_config``.
    2. Trims the training and validation folders with ``trim_directory``.
    3. Copies validation data into an inference input folder.
    4. Runs whichever stages are enabled by ``check_valid``, ``check_inference``
       and ``check_train``.

    Args:
        dict_args: Mapping of option overrides, or None to parse ``sys.argv``.
        test_type: When ``'user'``, the required model/path options are checked
            before any work is done.

    Raises:
        ValueError: If ``test_type == 'user'`` and any of model_type, config_path,
            start_check_point, data_path or valid_path is empty.
    """
    # Parse from cmd (or take dict_args).
    cli_args = parse_args(dict_args)

    # Non-None options add to or replace entries in base_args.
    # NOTE(review): this mutates the module-level dict, so values set by a
    # previous call in the same process (e.g. the redacted config path or the
    # trimmed data paths) persist unless overridden again.
    for key, value in vars(cli_args).items():
        if value is not None:
            base_args[key] = value

    if test_type == 'user':
        # Check required arguments.
        required = ['model_type', 'config_path', 'start_check_point', 'data_path', 'valid_path']
        missing_args = [arg for arg in required if not base_args[arg]]
        if missing_args:
            missing_args_str = ', '.join(f'--{arg}' for arg in missing_args)
            raise ValueError(
                f"The following arguments are required but missing: {missing_args_str}."
                f" Please specify them either via command-line arguments or directly in `base_args`.")

    # Replace the config with the path returned by redact_config.
    base_args['config_path'] = redact_config({'orig_config': base_args['config_path'],
                                              'model_type': base_args['model_type'],
                                              'new_config': ''})

    # max_folders may arrive as a string (e.g. through dict_args); coerce it so
    # trim_directory receives a number (presumably used as a count — verify).
    base_args['max_folders'] = int(base_args['max_folders'])

    # Trim train, then valid; the path returned by trim_directory replaces the original.
    for path_key in ('data_path', 'valid_path'):
        base_args[path_key] = trim_directory({'input_directory': base_args[path_key],
                                              'max_folders': base_args['max_folders']})

    # Valid to inference: default the inference input folder to a 'for_inference'
    # sibling of the validation folder. normpath drops a trailing separator, which
    # would otherwise make dirname() return the validation folder itself.
    if not base_args['input_folder']:
        valid_parent = os.path.dirname(os.path.normpath(base_args['valid_path']))
        base_args['input_folder'] = os.path.join(valid_parent, 'for_inference')
    copying_files({'valid_path': base_args['valid_path'],
                   'inference_dir': base_args['input_folder'],
                   'max_mixtures': 1})

    if base_args['check_valid']:
        valid_args = {key: base_args[key] for key in ['model_type', 'config_path', 'start_check_point',
                                                      'store_dir', 'device_ids', 'num_workers', 'pin_memory',
                                                      'extension', 'use_tta', 'metrics', 'lora_checkpoint',
                                                      'draw_spectro']}
        # check_validation receives the validation path wrapped in a list.
        valid_args['valid_path'] = [base_args['valid_path']]
        print('Start validation.')
        check_validation(valid_args)
        print(f'Validation ended. See results in {base_args["store_dir"]}')

    if base_args['check_inference']:
        inference_args = {key: base_args[key] for key in ['model_type', 'config_path', 'start_check_point',
                                                          'input_folder', 'store_dir', 'device_ids',
                                                          'extract_instrumental', 'disable_detailed_pbar',
                                                          'force_cpu', 'flac_file', 'pcm_type',
                                                          'use_tta', 'lora_checkpoint', 'draw_spectro']}
        print('Start inference.')
        proc_folder(inference_args)
        print(f'Inference ended. See results in {base_args["store_dir"]}')

    if base_args['check_train']:
        train_args = {key: base_args[key] for key in ['model_type', 'config_path', 'start_check_point',
                                                      'results_path', 'data_path', 'dataset_type', 'valid_path',
                                                      'num_workers', 'pin_memory', 'seed', 'device_ids',
                                                      'use_multistft_loss', 'use_mse_loss', 'use_l1_loss',
                                                      'wandb_key', 'pre_valid', 'metrics',
                                                      'metric_for_scheduler', 'train_lora', 'lora_checkpoint']}
        print('Start train.')
        train_model(train_args)

    print('End!')
if "__main__" == __name__:
    # Script entry point: take options from the command line (dict_args=None)
    # and enable the required-argument checks of the 'user' test type.
    test_settings(dict_args=None, test_type='user')