import os
import sys
import glob
import torch
import shutil
import torchaudio
import pytorch_lightning as pl
import random
from tqdm import tqdm
from pathlib import Path
from remfx import effects as effect_lib
from typing import Any, Dict, List, Tuple
from torch.utils.data import Dataset, DataLoader
from remfx.utils import select_random_chunk
import multiprocessing
from auraloss.freq import MultiResolutionSTFTLoss

STFT_THRESH = 1e-3
ALL_EFFECTS = effect_lib.Pedalboard_Effects

vocalset_splits = {
    "train": [
        "male1",
        "male2",
        "male3",
        "male4",
        "male5",
        "male6",
        "male7",
        "male8",
        "male9",
        "female1",
        "female2",
        "female3",
        "female4",
        "female5",
        "female6",
        "female7",
    ],
    "val": ["male10", "female8"],
    "test": ["male11", "female9"],
}
guitarset_splits = {"train": ["00", "01", "02", "03"], "val": ["04"], "test": ["05"]}
dsd_100_splits = {
    "train": ["train"],
    "val": ["val"],
    "test": ["test"],
}
idmt_drums_splits = {
    "train": ["WaveDrum02", "TechnoDrum01"],
    "val": ["RealDrum01"],
    "test": ["TechnoDrum02", "WaveDrum01"],
}


def locate_files(root: str, mode: str):
    file_list = []
    # ------------------------- VocalSet -------------------------
    vocalset_dir = os.path.join(root, "VocalSet1-2")
    if os.path.isdir(vocalset_dir):
        # find all singer directories
        singer_dirs = glob.glob(os.path.join(vocalset_dir, "data_by_singer", "*"))
        singer_dirs = [
            sd for sd in singer_dirs if os.path.basename(sd) in vocalset_splits[mode]
        ]
        files = []
        for singer_dir in singer_dirs:
            files += glob.glob(os.path.join(singer_dir, "**", "**", "*.wav"))
        print(f"Found {len(files)} files in VocalSet {mode}.")
        file_list.append(sorted(files))

    # ------------------------- GuitarSet -------------------------
    guitarset_dir = os.path.join(root, "audio_mono-mic")
    if os.path.isdir(guitarset_dir):
        files = glob.glob(os.path.join(guitarset_dir, "*.wav"))
        files = [
            f
            for f in files
            if os.path.basename(f).split("_")[0] in guitarset_splits[mode]
        ]
        print(f"Found {len(files)} files in GuitarSet {mode}.")
        file_list.append(sorted(files))

    # ------------------------- DSD100 -------------------------
    dsd_100_dir = os.path.join(root, "DSD100/DSD100")
    if os.path.isdir(dsd_100_dir):
        files = glob.glob(
            os.path.join(dsd_100_dir, mode, "**", "*.wav"),
            recursive=True,
        )
        file_list.append(sorted(files))
        print(f"Found {len(files)} files in DSD100 {mode}.")

    # ------------------------- IDMT-SMT-DRUMS -------------------------
    idmt_smt_drums_dir = os.path.join(root, "IDMT-SMT-DRUMS-V2")
    if os.path.isdir(idmt_smt_drums_dir):
        files = glob.glob(os.path.join(idmt_smt_drums_dir, "audio", "*.wav"))
        files = [
            f
            for f in files
            if os.path.basename(f).split("_")[0] in idmt_drums_splits[mode]
        ]
        file_list.append(sorted(files))
        print(f"Found {len(files)} files in IDMT-SMT-Drums {mode}.")

    return file_list


def parallel_process_effects(
    chunk_idx: int,
    proc_root: Path,
    files: list,
    chunk_size: int,
    effects: Dict[str, torch.nn.Module],
    effects_to_keep: list,
    num_kept_effects: tuple,
    shuffle_kept_effects: bool,
    effects_to_remove: list,
    num_removed_effects: tuple,
    shuffle_removed_effects: bool,
    sample_rate: int,
    target_lufs_db: float,
):
    """Render one (wet, dry) chunk pair to disk, for use with multiprocessing.

    Note: this function has a random-seed issue. Worker processes may inherit
    identical RNG state (e.g. under fork), so the random choices below can
    repeat across workers and the effects may not be fully randomized. See the
    seeding sketch after this function.
    """
    chunk = None
    random_dataset_choice = random.choice(files)
    while chunk is None:
        random_file_choice = random.choice(random_dataset_choice)
        chunk = select_random_chunk(random_file_choice, chunk_size, sample_rate)

    # Sum to mono
    if chunk.shape[0] > 1:
        chunk = chunk.sum(0, keepdim=True)

    dry = chunk
    # loudness normalization
    normalize = effect_lib.LoudnessNormalize(sample_rate, target_lufs_db=target_lufs_db)

    # Apply Kept Effects
    # Shuffle effects if specified
    if shuffle_kept_effects:
        effect_indices = torch.randperm(len(effects_to_keep))
    else:
        effect_indices = torch.arange(len(effects_to_keep))
    # Sample the number of kept effects uniformly in [min, max]
    r1 = num_kept_effects[0]
    r2 = num_kept_effects[1]
    num_kept_effects = torch.round((r1 - r2) * torch.rand(1) + r2).int()
    effect_indices = effect_indices[:num_kept_effects]
    # Index in effect settings
    effect_names_to_apply = [effects_to_keep[i] for i in effect_indices]
    effects_to_apply = [effects[i] for i in effect_names_to_apply]
    # Apply
    dry_labels = []
    for effect in effects_to_apply:
        # Normalize in-between effects
        dry = normalize(effect(dry))
        dry_labels.append(ALL_EFFECTS.index(type(effect)))

    # Apply effects_to_remove
    # Shuffle effects if specified
    if shuffle_removed_effects:
        effect_indices = torch.randperm(len(effects_to_remove))
    else:
        effect_indices = torch.arange(len(effects_to_remove))
    wet = torch.clone(dry)
    # Sample the number of removed effects uniformly in [min, max]
    r1 = num_removed_effects[0]
    r2 = num_removed_effects[1]
    num_removed_effects = torch.round((r1 - r2) * torch.rand(1) + r2).int()
    effect_indices = effect_indices[:num_removed_effects]
    # Index in effect settings
    effect_names_to_apply = [effects_to_remove[i] for i in effect_indices]
    effects_to_apply = [effects[i] for i in effect_names_to_apply]
    # Apply
    wet_labels = []
    for effect in effects_to_apply:
        # Normalize in-between effects
        wet = normalize(effect(wet))
        wet_labels.append(ALL_EFFECTS.index(type(effect)))

    # Multi-hot label tensors over the full effect vocabulary
    wet_labels_tensor = torch.zeros(len(ALL_EFFECTS))
    dry_labels_tensor = torch.zeros(len(ALL_EFFECTS))
    for label_idx in wet_labels:
        wet_labels_tensor[label_idx] = 1.0
    for label_idx in dry_labels:
        dry_labels_tensor[label_idx] = 1.0

    # Normalize
    normalized_dry = normalize(dry)
    normalized_wet = normalize(wet)

    output_dir = proc_root / str(chunk_idx)
    output_dir.mkdir(exist_ok=True)
    torchaudio.save(output_dir / "input.wav", normalized_wet, sample_rate)
    torchaudio.save(output_dir / "target.wav", normalized_dry, sample_rate)
    torch.save(dry_labels_tensor, output_dir / "dry_effects.pt")
    torch.save(wet_labels_tensor, output_dir / "wet_effects.pt")

    # return normalized_dry, normalized_wet, dry_labels_tensor, wet_labels_tensor
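
# A minimal sketch of one way to address the seeding note above: derive a
# distinct per-task seed from the chunk index so each pool task draws
# independent random choices. `_seeded_parallel_process_effects` is a
# hypothetical wrapper, not used elsewhere in this module; it could be passed
# to pool.map(_seeded_parallel_process_effects, items) in place of
# pool.starmap(parallel_process_effects, items).
def _seeded_parallel_process_effects(args):
    chunk_idx = args[0]
    random.seed(chunk_idx)  # seeds Python's RNG (dataset/file choices)
    torch.manual_seed(chunk_idx)  # seeds torch's RNG (effect counts/order)
    return parallel_process_effects(*args)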
class DynamicEffectDataset(Dataset):
    def __init__(
        self,
        root: str,
        sample_rate: int,
        chunk_size: int = 262144,
        total_chunks: int = 1000,
        effect_modules: Dict[str, torch.nn.Module] = None,
        effects_to_keep: List[str] = None,
        effects_to_remove: List[str] = None,
        num_kept_effects: List[int] = [1, 5],
        num_removed_effects: List[int] = [1, 5],
        shuffle_kept_effects: bool = True,
        shuffle_removed_effects: bool = False,
        render_files: bool = True,
        render_root: str = None,
        mode: str = "train",
        parallel: bool = False,
    ) -> None:
        super().__init__()
        self.chunks = []
        self.song_idx = []
        self.root = Path(root)
        self.render_root = Path(render_root)
        self.chunk_size = chunk_size
        self.total_chunks = total_chunks
        self.sample_rate = sample_rate
        self.mode = mode
        self.num_kept_effects = num_kept_effects
        self.num_removed_effects = num_removed_effects
        self.effects_to_keep = [] if effects_to_keep is None else effects_to_keep
        self.effects_to_remove = [] if effects_to_remove is None else effects_to_remove
        self.normalize = effect_lib.LoudnessNormalize(sample_rate, target_lufs_db=-20)
        self.effects = effect_modules
        self.shuffle_kept_effects = shuffle_kept_effects
        self.shuffle_removed_effects = shuffle_removed_effects
        effects_string = "_".join(
            self.effects_to_keep
            + ["_"]
            + self.effects_to_remove
            + ["_"]
            + [str(x) for x in num_kept_effects]
            + ["_"]
            + [str(x) for x in num_removed_effects]
        )
        # self.validate_effect_input()
        # self.proc_root = self.render_root / "processed" / effects_string / self.mode
        self.parallel = parallel
        self.files = locate_files(self.root, self.mode)

    def process_effects(self, dry: torch.Tensor):
        # Apply Kept Effects
        # Shuffle effects if specified
        if self.shuffle_kept_effects:
            effect_indices = torch.randperm(len(self.effects_to_keep))
        else:
            effect_indices = torch.arange(len(self.effects_to_keep))
        # Sample the number of kept effects uniformly in [min, max]
        r1 = self.num_kept_effects[0]
        r2 = self.num_kept_effects[1]
        num_kept_effects = torch.round((r1 - r2) * torch.rand(1) + r2).int()
        effect_indices = effect_indices[:num_kept_effects]
        # Index in effect settings
        effect_names_to_apply = [self.effects_to_keep[i] for i in effect_indices]
        effects_to_apply = [self.effects[i] for i in effect_names_to_apply]
        # Apply
        dry_labels = []
        for effect in effects_to_apply:
            # Normalize in-between effects
            dry = self.normalize(effect(dry))
            dry_labels.append(ALL_EFFECTS.index(type(effect)))

        # Apply effects_to_remove
        # Shuffle effects if specified
        if self.shuffle_removed_effects:
            effect_indices = torch.randperm(len(self.effects_to_remove))
        else:
            effect_indices = torch.arange(len(self.effects_to_remove))
        wet = torch.clone(dry)
        # Sample the number of removed effects uniformly in [min, max]
        r1 = self.num_removed_effects[0]
        r2 = self.num_removed_effects[1]
        num_removed_effects = torch.round((r1 - r2) * torch.rand(1) + r2).int()
        effect_indices = effect_indices[:num_removed_effects]
        # Index in effect settings
        effect_names_to_apply = [self.effects_to_remove[i] for i in effect_indices]
        effects_to_apply = [self.effects[i] for i in effect_names_to_apply]
        # Apply
        wet_labels = []
        for effect in effects_to_apply:
            # Normalize in-between effects
            wet = self.normalize(effect(wet))
            wet_labels.append(ALL_EFFECTS.index(type(effect)))

        wet_labels_tensor = torch.zeros(len(ALL_EFFECTS))
        dry_labels_tensor = torch.zeros(len(ALL_EFFECTS))
        for label_idx in wet_labels:
            wet_labels_tensor[label_idx] = 1.0
        for label_idx in dry_labels:
            dry_labels_tensor[label_idx] = 1.0

        # Normalize
        normalized_dry = self.normalize(dry)
        normalized_wet = self.normalize(wet)
        return normalized_dry, normalized_wet, dry_labels_tensor, wet_labels_tensor

    def __len__(self):
        return self.total_chunks

    def __getitem__(self, _: int):
        chunk = None
        random_dataset_choice = random.choice(self.files)
        while chunk is None:
            random_file_choice = random.choice(random_dataset_choice)
            chunk = select_random_chunk(
                random_file_choice, self.chunk_size, self.sample_rate
            )

        # Sum to mono
        if chunk.shape[0] > 1:
            chunk = chunk.sum(0, keepdim=True)

        dry, wet, dry_effects, wet_effects = self.process_effects(chunk)
        return wet, dry, dry_effects, wet_effects
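
# A minimal usage sketch for DynamicEffectDataset, which renders (wet, dry)
# pairs on the fly in __getitem__ rather than to disk. The paths, sample rate,
# and the "distortion" key are hypothetical placeholders; the effect module is
# assumed to be an instance of one of the classes in
# effect_lib.Pedalboard_Effects:
#
#     dataset = DynamicEffectDataset(
#         root="./data",
#         render_root="./render",
#         sample_rate=48000,
#         effect_modules={"distortion": <an instance from ALL_EFFECTS>},
#         effects_to_remove=["distortion"],
#         num_removed_effects=[1, 1],
#     )
#     wet, dry, dry_fx, wet_fx = dataset[0]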
class EffectDataset(Dataset):
    def __init__(
        self,
        root: str,
        sample_rate: int,
        chunk_size: int = 262144,
        total_chunks: int = 1000,
        effect_modules: Dict[str, torch.nn.Module] = None,
        effects_to_keep: List[str] = None,
        effects_to_remove: List[str] = None,
        num_kept_effects: List[int] = [1, 5],
        num_removed_effects: List[int] = [1, 5],
        shuffle_kept_effects: bool = True,
        shuffle_removed_effects: bool = False,
        render_files: bool = True,
        render_root: str = None,
        mode: str = "train",
        parallel: bool = False,
    ):
        super().__init__()
        self.chunks = []
        self.song_idx = []
        self.root = Path(root)
        self.render_root = Path(render_root)
        self.chunk_size = chunk_size
        self.total_chunks = total_chunks
        self.sample_rate = sample_rate
        self.mode = mode
        self.num_kept_effects = num_kept_effects
        self.num_removed_effects = num_removed_effects
        self.effects_to_keep = [] if effects_to_keep is None else effects_to_keep
        self.effects_to_remove = [] if effects_to_remove is None else effects_to_remove
        self.normalize = effect_lib.LoudnessNormalize(sample_rate, target_lufs_db=-20)
        self.mrstft = MultiResolutionSTFTLoss(sample_rate=sample_rate)
        self.effects = effect_modules
        self.shuffle_kept_effects = shuffle_kept_effects
        self.shuffle_removed_effects = shuffle_removed_effects
        effects_string = "_".join(
            self.effects_to_keep
            + ["_"]
            + self.effects_to_remove
            + ["_"]
            + [str(x) for x in num_kept_effects]
            + ["_"]
            + [str(x) for x in num_removed_effects]
        )
        self.validate_effect_input()
        self.proc_root = self.render_root / "processed" / effects_string / self.mode
        self.parallel = parallel
        self.files = locate_files(self.root, self.mode)

        if self.proc_root.exists() and len(list(self.proc_root.iterdir())) > 0:
            print("Found processed files.")
            if render_files:
                re_render = input(
                    "WARNING: By default, will re-render files.\n"
                    "Set render_files=False to skip re-rendering.\n"
                    "Are you sure you want to re-render? (y/n): "
                )
                if re_render != "y":
                    sys.exit()
                shutil.rmtree(self.proc_root)

        print("Total datasets:", len(self.files))
        print("Processing files...")
        if render_files:
            # Split audio file into chunks, resample, then apply random effects
            self.proc_root.mkdir(parents=True, exist_ok=True)
            if self.parallel:
                items = [
                    (
                        chunk_idx,
                        self.proc_root,
                        self.files,
                        self.chunk_size,
                        self.effects,
                        self.effects_to_keep,
                        self.num_kept_effects,
                        self.shuffle_kept_effects,
                        self.effects_to_remove,
                        self.num_removed_effects,
                        self.shuffle_removed_effects,
                        self.sample_rate,
                        -20.0,
                    )
                    for chunk_idx in range(self.total_chunks)
                ]
                with multiprocessing.Pool(processes=32) as pool:
                    pool.starmap(parallel_process_effects, items)
                print(f"Done processing {self.total_chunks}", flush=True)
            else:
                for num_chunk in tqdm(range(self.total_chunks)):
                    chunk = None
                    random_dataset_choice = random.choice(self.files)
                    while chunk is None:
                        try:
                            random_file_choice = random.choice(random_dataset_choice)
                        except IndexError:
                            print("IndexError")
                            print(random_dataset_choice)
                            raise
                        chunk = select_random_chunk(
                            random_file_choice, self.chunk_size, self.sample_rate
                        )

                    # Sum to mono
                    if chunk.shape[0] > 1:
                        chunk = chunk.sum(0, keepdim=True)

                    dry, wet, dry_effects, wet_effects = self.process_effects(chunk)
                    output_dir = self.proc_root / str(num_chunk)
                    output_dir.mkdir(exist_ok=True)
                    torchaudio.save(output_dir / "input.wav", wet, self.sample_rate)
                    torchaudio.save(output_dir / "target.wav", dry, self.sample_rate)
                    torch.save(dry_effects, output_dir / "dry_effects.pt")
                    torch.save(wet_effects, output_dir / "wet_effects.pt")
            print("Finished rendering")
        else:
            self.total_chunks = len(list(self.proc_root.iterdir()))
        print("Total chunks:", self.total_chunks)
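
    # For reference, the rendering step above produces one directory per chunk:
    #   <render_root>/processed/<effects_string>/<mode>/<idx>/input.wav   (wet)
    #   <render_root>/processed/<effects_string>/<mode>/<idx>/target.wav  (dry)
    # plus dry_effects.pt and wet_effects.pt holding the multi-hot label
    # tensors, all of which __getitem__ loads back below.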
    def __len__(self):
        return self.total_chunks

    def __getitem__(self, idx):
        input_file = self.proc_root / str(idx) / "input.wav"
        target_file = self.proc_root / str(idx) / "target.wav"
        dry_effect_names = torch.load(self.proc_root / str(idx) / "dry_effects.pt")
        wet_effect_names = torch.load(self.proc_root / str(idx) / "wet_effects.pt")
        input, sr = torchaudio.load(input_file)
        target, sr = torchaudio.load(target_file)
        return (input, target, dry_effect_names, wet_effect_names)

    def validate_effect_input(self):
        for effect in self.effects.values():
            if type(effect) not in ALL_EFFECTS:
                raise ValueError(
                    f"Effect {effect} not found in ALL_EFFECTS. "
                    f"Please choose from {ALL_EFFECTS}"
                )
        for effect in self.effects_to_keep:
            if effect not in self.effects.keys():
                raise ValueError(
                    f"Effect {effect} not found in self.effects. "
                    f"Please choose from {self.effects.keys()}"
                )
        for effect in self.effects_to_remove:
            if effect not in self.effects.keys():
                raise ValueError(
                    f"Effect {effect} not found in self.effects. "
                    f"Please choose from {self.effects.keys()}"
                )
        kept_str = "randomly" if self.shuffle_kept_effects else "in order"
        rem_str = "randomly" if self.shuffle_removed_effects else "in order"
        if self.num_kept_effects[0] > self.num_kept_effects[1]:
            raise ValueError(
                f"num_kept_effects must be a tuple of (min, max). "
                f"Got {self.num_kept_effects}"
            )
        if self.num_kept_effects[0] == self.num_kept_effects[1]:
            num_kept_str = f"{self.num_kept_effects[0]}"
        else:
            num_kept_str = (
                f"Between {self.num_kept_effects[0]}-{self.num_kept_effects[1]}"
            )
        if self.num_removed_effects[0] > self.num_removed_effects[1]:
            raise ValueError(
                f"num_removed_effects must be a tuple of (min, max). "
                f"Got {self.num_removed_effects}"
            )
        if self.num_removed_effects[0] == self.num_removed_effects[1]:
            num_rem_str = f"{self.num_removed_effects[0]}"
        else:
            num_rem_str = (
                f"Between {self.num_removed_effects[0]}-{self.num_removed_effects[1]}"
            )
        rem_fx = self.effects_to_remove
        kept_fx = self.effects_to_keep
        print(
            f"Effect Summary: \n"
            f"Apply kept effects: {kept_fx} ({num_kept_str}, chosen {kept_str}) -> Dry\n"
            f"Apply remove effects: {rem_fx} ({num_rem_str}, chosen {rem_str}) -> Wet\n"
        )

    def process_effects(self, dry: torch.Tensor):
        # Apply Kept Effects
        # Shuffle effects if specified
        if self.shuffle_kept_effects:
            effect_indices = torch.randperm(len(self.effects_to_keep))
        else:
            effect_indices = torch.arange(len(self.effects_to_keep))
        r1 = self.num_kept_effects[0]
        r2 = self.num_kept_effects[1]
        num_kept_effects = torch.round((r1 - r2) * torch.rand(1) + r2).int()
        effect_indices = effect_indices[:num_kept_effects]
        # Index in effect settings
        effect_names_to_apply = [self.effects_to_keep[i] for i in effect_indices]
        effects_to_apply = [self.effects[i] for i in effect_names_to_apply]

        # stft comparison
        stft = 0
        while stft < STFT_THRESH:
            # Apply
            dry_labels = []
            for effect in effects_to_apply:
                # Normalize in-between effects
                dry = self.normalize(effect(dry))
                dry_labels.append(ALL_EFFECTS.index(type(effect)))

            # Apply effects_to_remove
            # Shuffle effects if specified
            if self.shuffle_removed_effects:
                effect_indices = torch.randperm(len(self.effects_to_remove))
            else:
                effect_indices = torch.arange(len(self.effects_to_remove))
            wet = torch.clone(dry)
            r1 = self.num_removed_effects[0]
            r2 = self.num_removed_effects[1]
            num_removed_effects = torch.round((r1 - r2) * torch.rand(1) + r2).int()
            effect_indices = effect_indices[:num_removed_effects]
            # Index in effect settings
            effect_names_to_apply = [self.effects_to_remove[i] for i in effect_indices]
            effects_to_apply = [self.effects[i] for i in effect_names_to_apply]
            # Apply
            wet_labels = []
            for effect in effects_to_apply:
                # Normalize in-between effects
                wet = self.normalize(effect(wet))
                wet_labels.append(ALL_EFFECTS.index(type(effect)))

            wet_labels_tensor = torch.zeros(len(ALL_EFFECTS))
            dry_labels_tensor = torch.zeros(len(ALL_EFFECTS))
            for label_idx in wet_labels:
                wet_labels_tensor[label_idx] = 1.0
            for label_idx in dry_labels:
                dry_labels_tensor[label_idx] = 1.0

            # Normalize
            normalized_dry = self.normalize(dry)
            normalized_wet = self.normalize(wet)

            # Check STFT, pick different effects if necessary
            if num_removed_effects == 0:
                # No need to check if no effects removed
                break
            stft = self.mrstft(normalized_wet.unsqueeze(0), normalized_dry.unsqueeze(0))

        return normalized_dry, normalized_wet, dry_labels_tensor, wet_labels_tensor


class InferenceDataset(Dataset):
    def __init__(self, root: str, sample_rate: int, **kwargs):
        self.root = Path(root)
        self.sample_rate = sample_rate
        self.clean_paths = sorted(list(self.root.glob("clean/*.wav")))
        self.effected_paths = sorted(list(self.root.glob("effected/*.wav")))

    def __len__(self) -> int:
        return len(self.clean_paths)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, ...]:
        clean_path = self.clean_paths[idx]
        effected_path = self.effected_paths[idx]
        clean_audio, sr = torchaudio.load(clean_path)
        clean = torchaudio.functional.resample(clean_audio, sr, self.sample_rate)
        effected_audio, sr = torchaudio.load(effected_path)
        effected = torchaudio.functional.resample(effected_audio, sr, self.sample_rate)
        # Sum to mono
        clean = torch.sum(clean, dim=0, keepdim=True)
        effected = torch.sum(effected, dim=0, keepdim=True)
        # Pad or trim effected to clean
        if effected.shape[1] > clean.shape[1]:
            effected = effected[:, : clean.shape[1]]
        elif effected.shape[1] < clean.shape[1]:
            pad_size = clean.shape[1] - effected.shape[1]
            effected = torch.nn.functional.pad(effected, (0, pad_size))
        dry_labels_tensor = torch.zeros(len(ALL_EFFECTS))
        wet_labels_tensor = torch.ones(len(ALL_EFFECTS))
        return effected, clean, dry_labels_tensor, wet_labels_tensor


class EffectDatamodule(pl.LightningDataModule):
    def __init__(
        self,
        train_dataset,
        val_dataset,
        test_dataset,
        *,
        train_batch_size: int,
        test_batch_size: int,
        num_workers: int,
        pin_memory: bool = False,
        **kwargs: int,
    ) -> None:
        super().__init__()
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.test_dataset = test_dataset
        self.train_batch_size = train_batch_size
        self.test_batch_size = test_batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory

    def setup(self, stage: Any = None) -> None:
        pass

    def train_dataloader(self) -> DataLoader:
        return DataLoader(
            dataset=self.train_dataset,
            batch_size=self.train_batch_size,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            shuffle=True,
        )

    def val_dataloader(self) -> DataLoader:
        return DataLoader(
            dataset=self.val_dataset,
            batch_size=self.train_batch_size,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            shuffle=False,
        )

    def test_dataloader(self) -> DataLoader:
        return DataLoader(
            dataset=self.test_dataset,
            batch_size=self.test_batch_size,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            shuffle=False,
        )
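
# A minimal end-to-end usage sketch. All paths, batch sizes, and effect
# configuration below are hypothetical placeholders, not values shipped with
# this module; each EffectDataset item is an (input=wet, target=dry,
# dry_labels, wet_labels) tuple:
#
#     train_ds = EffectDataset(root="./data", sample_rate=48000, mode="train", ...)
#     val_ds = EffectDataset(root="./data", sample_rate=48000, mode="val", ...)
#     test_ds = EffectDataset(root="./data", sample_rate=48000, mode="test", ...)
#     dm = EffectDatamodule(
#         train_ds, val_ds, test_ds,
#         train_batch_size=16, test_batch_size=1, num_workers=4,
#     )
#     wet, dry, dry_fx, wet_fx = next(iter(dm.train_dataloader()))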