import argparse
import logging
import os
import sys

import librosa
import torch
import tqdm

from .data.data import EvalDataLoader, EvalDataset
from . import distrib
from .utils import remove_pad
from .utils import bold, deserialize_model, LogProgress

logger = logging.getLogger(__name__)


def load_model():
    """Load the svoice model from 'checkpoint.th' into module globals.

    Side effects: sets the module-level globals ``device``, ``model`` and
    ``pkg`` so that :func:`separate_demo` can use them afterwards. The model
    is deserialized, put in eval mode and moved to GPU when available.
    """
    global device
    global model
    global pkg
    print("Loading svoice model if available...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    pkg = torch.load('checkpoint.th', map_location=device)
    # Checkpoints may wrap the model in a dict ({'model': ...}) or store it
    # directly; support both layouts.
    if 'model' in pkg:
        model = pkg['model']
    else:
        model = pkg
    model = deserialize_model(model)
    logger.debug(model)
    model.eval()
    model.to(device)
    print("svoice model loaded.")
    print("Device: {}".format(device))


# Command-line interface, used by the __main__ block below.
parser = argparse.ArgumentParser("Speech separation using MulCat blocks")
parser.add_argument("model_path", type=str, help="Model name")
parser.add_argument("out_dir", type=str, default="exp/result",
                    help="Directory putting enhanced wav files")
parser.add_argument("--mix_dir", type=str, default=None,
                    help="Directory including mix wav files")
parser.add_argument("--mix_json", type=str, default=None,
                    help="Json file including mix wav files")
parser.add_argument('--device', default="cuda")
parser.add_argument("--sample_rate", default=8000, type=int,
                    help="Sample rate")
parser.add_argument("--batch_size", default=1, type=int, help="Batch size")
parser.add_argument('-v', '--verbose', action='store_const',
                    const=logging.DEBUG, default=logging.INFO,
                    help="More loggging")


def save_wavs(estimate_source, mix_sig, lengths, filenames, out_dir, sr=16000):
    """Write the mixture and every estimated source of a batch as wav files.

    For each input file ``name.wav`` this writes ``name.wav`` (the trimmed
    mixture) and ``name_s1.wav`` ... ``name_sC.wav`` (one per estimated
    source) into ``out_dir``.

    :param estimate_source: batch of padded source estimates; after
        ``remove_pad``, entry ``i`` is indexed as ``[c]`` per source.
    :param mix_sig: batch of padded mixture signals.
    :param lengths: valid (unpadded) length of each batch entry.
    :param filenames: original mixture file paths, one per batch entry.
    :param out_dir: destination directory (must already exist).
    :param sr: output sample rate passed through to :func:`write`.
    """
    # Remove padding and flat
    flat_estimate = remove_pad(estimate_source, lengths)
    mix_sig = remove_pad(mix_sig, lengths)
    # Write result
    for i, filename in enumerate(filenames):
        # BUG FIX: the original used ``.strip(".wav")``, which strips the
        # *characters* '.', 'w', 'a', 'v' from both ends of the name
        # (e.g. "wave.wav" -> "e"). splitext drops only the extension.
        base = os.path.splitext(os.path.basename(filename))[0]
        filename = os.path.join(out_dir, base)
        write(mix_sig[i], filename + ".wav", sr=sr)
        C = flat_estimate[i].shape[0]
        # future support for wave playing
        for c in range(C):
            write(flat_estimate[i][c], filename + f"_s{c + 1}.wav", sr=sr)


def write(inputs, filename, sr=8000):
    """Write a single waveform to ``filename`` with peak normalization.

    NOTE(review): ``librosa.output.write_wav`` was removed in librosa 0.8;
    this code requires librosa < 0.8 (or a port to ``soundfile.write``) —
    confirm the pinned librosa version.
    """
    librosa.output.write_wav(filename, inputs, sr, norm=True)


def separate_demo(mix_dir='mix/', batch_size=1, sample_rate=16000):
    """Separate all mixtures in ``mix_dir`` using the globally loaded model.

    Requires :func:`load_model` to have been called first (it sets the
    module globals ``model`` and ``device`` this function reads).

    :returns: absolute paths of the separated wav files written under
        ``'separated'``, excluding any file ending in ``original.wav``.
    """
    mix_dir, mix_json = mix_dir, None
    out_dir = 'separated'
    # Load data
    eval_dataset = EvalDataset(
        mix_dir,
        mix_json,
        batch_size=batch_size,
        sample_rate=sample_rate,
    )
    eval_loader = distrib.loader(
        eval_dataset, batch_size=1, klass=EvalDataLoader)
    # Only rank 0 creates the output dir; everyone syncs before writing.
    if distrib.rank == 0:
        os.makedirs(out_dir, exist_ok=True)
    distrib.barrier()

    with torch.no_grad():
        for i, data in enumerate(tqdm.tqdm(eval_loader, ncols=120)):
            # Get batch data
            mixture, lengths, filenames = data
            mixture = mixture.to(device)
            lengths = lengths.to(device)
            # Forward: the model returns per-stage outputs; keep the last one.
            estimate_sources = model(mixture)[-1]
            # save wav files
            save_wavs(estimate_sources, mixture, lengths, filenames,
                      out_dir, sr=sample_rate)

    separated_files = [os.path.join(out_dir, f) for f in os.listdir(out_dir)]
    separated_files = [os.path.abspath(f) for f in separated_files]
    separated_files = [f for f in separated_files
                       if not f.endswith('original.wav')]
    return separated_files


def get_mix_paths(args):
    """Resolve the mixture directory and json, preferring ``args.dset.*``.

    Falls back to the flat ``args.mix_dir`` / ``args.mix_json`` attributes
    when ``args`` has no ``dset`` namespace (e.g. plain CLI invocation).
    """
    mix_dir = None
    mix_json = None
    # fix mix dir
    try:
        if args.dset.mix_dir:
            mix_dir = args.dset.mix_dir
    # Narrowed from a bare ``except:``: a missing ``dset`` attribute is the
    # only expected failure here; anything else should propagate.
    except AttributeError:
        mix_dir = args.mix_dir
    # fix mix json
    try:
        if args.dset.mix_json:
            mix_json = args.dset.mix_json
    except AttributeError:
        mix_json = args.mix_json
    return mix_dir, mix_json


def separate(args, model=None, local_out_dir=None):
    """Run source separation over a dataset of mixtures and save the wavs.

    :param args: namespace providing ``model_path``, ``device``,
        ``batch_size``, ``sample_rate``, ``out_dir`` and the mix location
        (see :func:`get_mix_paths`).
    :param model: optional pre-loaded model; when ``None`` it is loaded
        from ``args.model_path``.
    :param local_out_dir: overrides ``args.out_dir`` when given.
    """
    mix_dir, mix_json = get_mix_paths(args)
    # NOTE(review): this only logs and then proceeds with both paths None —
    # the EvalDataset construction below will be the one to fail. Consider
    # raising instead.
    if not mix_json and not mix_dir:
        logger.error("Must provide mix_dir or mix_json! "
                     "When providing mix_dir, mix_json is ignored.")
    # Load model
    if not model:
        # model
        pkg = torch.load(args.model_path)
        if 'model' in pkg:
            model = pkg['model']
        else:
            model = pkg
        model = deserialize_model(model)
        logger.debug(model)
    model.eval()
    model.to(args.device)
    if local_out_dir:
        out_dir = local_out_dir
    else:
        out_dir = args.out_dir

    # Load data
    eval_dataset = EvalDataset(
        mix_dir,
        mix_json,
        batch_size=args.batch_size,
        sample_rate=args.sample_rate,
    )
    eval_loader = distrib.loader(
        eval_dataset, batch_size=1, klass=EvalDataLoader)
    # Only rank 0 creates the output dir; everyone syncs before writing.
    if distrib.rank == 0:
        os.makedirs(out_dir, exist_ok=True)
    distrib.barrier()

    with torch.no_grad():
        for i, data in enumerate(tqdm.tqdm(eval_loader, ncols=120)):
            # Get batch data
            mixture, lengths, filenames = data
            mixture = mixture.to(args.device)
            lengths = lengths.to(args.device)
            # Forward: the model returns per-stage outputs; keep the last one.
            estimate_sources = model(mixture)[-1]
            # save wav files
            save_wavs(estimate_sources, mixture, lengths, filenames,
                      out_dir, sr=args.sample_rate)


if __name__ == "__main__":
    args = parser.parse_args()
    logging.basicConfig(stream=sys.stderr, level=args.verbose)
    logger.debug(args)
    separate(args, local_out_dir=args.out_dir)