# remfx/datasets.py
import os
import sys
import glob
import torch
import shutil
import torchaudio
import pytorch_lightning as pl
import random
from tqdm import tqdm
from pathlib import Path
from remfx import effects as effect_lib
from typing import Any, List, Dict
from torch.utils.data import Dataset, DataLoader
from remfx.utils import select_random_chunk
import multiprocessing
from auraloss.freq import MultiResolutionSTFTLoss
# Minimum multi-resolution STFT distance between a wet/dry pair; below this the
# removed-effect chain is considered inaudible and is re-sampled (EffectDataset)
STFT_THRESH = 1e-3
# Canonical list of effect classes; label indices below refer to this ordering
ALL_EFFECTS = effect_lib.Pedalboard_Effects
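# Train/val/test splits for each source dataset, keyed by singer, player,
# or session identifiers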
vocalset_splits = {
"train": [
"male1",
"male2",
"male3",
"male4",
"male5",
"male6",
"male7",
"male8",
"male9",
"female1",
"female2",
"female3",
"female4",
"female5",
"female6",
"female7",
],
"val": ["male10", "female8"],
"test": ["male11", "female9"],
}
guitarset_splits = {"train": ["00", "01", "02", "03"], "val": ["04"], "test": ["05"]}
dsd_100_splits = {
"train": ["train"],
"val": ["val"],
"test": ["test"],
}
idmt_drums_splits = {
"train": ["WaveDrum02", "TechnoDrum01"],
"val": ["RealDrum01"],
"test": ["TechnoDrum02", "WaveDrum01"],
}
def locate_files(root: str, mode: str):
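    """Gather audio file paths for the requested split ("train", "val", or "test").

    Returns one sorted list of paths per source dataset found under root
    (VocalSet, GuitarSet, DSD100, IDMT-SMT-Drums).
    """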
file_list = []
# ------------------------- VocalSet -------------------------
vocalset_dir = os.path.join(root, "VocalSet1-2")
if os.path.isdir(vocalset_dir):
# find all singer directories
singer_dirs = glob.glob(os.path.join(vocalset_dir, "data_by_singer", "*"))
singer_dirs = [
sd for sd in singer_dirs if os.path.basename(sd) in vocalset_splits[mode]
]
files = []
        for singer_dir in singer_dirs:
            # Without recursive=True, "**" matches exactly one directory level,
            # so this descends two levels below each singer directory
            files += glob.glob(os.path.join(singer_dir, "**", "**", "*.wav"))
print(f"Found {len(files)} files in VocalSet {mode}.")
file_list.append(sorted(files))
# ------------------------- GuitarSet -------------------------
guitarset_dir = os.path.join(root, "audio_mono-mic")
if os.path.isdir(guitarset_dir):
files = glob.glob(os.path.join(guitarset_dir, "*.wav"))
files = [
f
for f in files
if os.path.basename(f).split("_")[0] in guitarset_splits[mode]
]
print(f"Found {len(files)} files in GuitarSet {mode}.")
file_list.append(sorted(files))
# ------------------------- DSD100 ---------------------------------
    dsd_100_dir = os.path.join(root, "DSD100", "DSD100")
if os.path.isdir(dsd_100_dir):
files = glob.glob(
os.path.join(dsd_100_dir, mode, "**", "*.wav"),
recursive=True,
)
file_list.append(sorted(files))
print(f"Found {len(files)} files in DSD100 {mode}.")
# ------------------------- IDMT-SMT-DRUMS -------------------------
idmt_smt_drums_dir = os.path.join(root, "IDMT-SMT-DRUMS-V2")
if os.path.isdir(idmt_smt_drums_dir):
files = glob.glob(os.path.join(idmt_smt_drums_dir, "audio", "*.wav"))
files = [
f
for f in files
if os.path.basename(f).split("_")[0] in idmt_drums_splits[mode]
]
file_list.append(sorted(files))
print(f"Found {len(files)} files in IDMT-SMT-Drums {mode}.")
return file_list
def parallel_process_effects(
chunk_idx: int,
    proc_root: Path,
    files: List[List[str]],
    chunk_size: int,
    effects: Dict[str, torch.nn.Module],
effects_to_keep: list,
num_kept_effects: tuple,
shuffle_kept_effects: bool,
effects_to_remove: list,
num_removed_effects: tuple,
shuffle_removed_effects: bool,
sample_rate: int,
target_lufs_db: float,
):
"""Note: This function has an issue with random seed. It may not fully randomize the effects."""
chunk = None
random_dataset_choice = random.choice(files)
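    # Draw files from the chosen dataset until a usable chunk is found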
while chunk is None:
random_file_choice = random.choice(random_dataset_choice)
chunk = select_random_chunk(random_file_choice, chunk_size, sample_rate)
# Sum to mono
if chunk.shape[0] > 1:
chunk = chunk.sum(0, keepdim=True)
dry = chunk
# loudness normalization
normalize = effect_lib.LoudnessNormalize(sample_rate, target_lufs_db=target_lufs_db)
# Apply Kept Effects
# Shuffle effects if specified
if shuffle_kept_effects:
effect_indices = torch.randperm(len(effects_to_keep))
else:
effect_indices = torch.arange(len(effects_to_keep))
    # Sample how many kept effects to apply within [min, max]
    # (rounding makes the endpoints half as likely)
    r1, r2 = num_kept_effects
    num_kept_effects = torch.round((r1 - r2) * torch.rand(1) + r2).int()
    effect_indices = effect_indices[:num_kept_effects]
# Index in effect settings
effect_names_to_apply = [effects_to_keep[i] for i in effect_indices]
effects_to_apply = [effects[i] for i in effect_names_to_apply]
# Apply
dry_labels = []
for effect in effects_to_apply:
# Normalize in-between effects
dry = normalize(effect(dry))
dry_labels.append(ALL_EFFECTS.index(type(effect)))
# Apply effects_to_remove
# Shuffle effects if specified
if shuffle_removed_effects:
effect_indices = torch.randperm(len(effects_to_remove))
else:
effect_indices = torch.arange(len(effects_to_remove))
wet = torch.clone(dry)
    # Sample how many removed effects to apply within [min, max]
    # (rounding makes the endpoints half as likely)
    r1, r2 = num_removed_effects
    num_removed_effects = torch.round((r1 - r2) * torch.rand(1) + r2).int()
    effect_indices = effect_indices[:num_removed_effects]
# Index in effect settings
effect_names_to_apply = [effects_to_remove[i] for i in effect_indices]
effects_to_apply = [effects[i] for i in effect_names_to_apply]
# Apply
wet_labels = []
for effect in effects_to_apply:
# Normalize in-between effects
wet = normalize(effect(wet))
wet_labels.append(ALL_EFFECTS.index(type(effect)))
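    # Build multi-hot label tensors over the full effect vocabulary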
wet_labels_tensor = torch.zeros(len(ALL_EFFECTS))
dry_labels_tensor = torch.zeros(len(ALL_EFFECTS))
for label_idx in wet_labels:
wet_labels_tensor[label_idx] = 1.0
for label_idx in dry_labels:
dry_labels_tensor[label_idx] = 1.0
# Normalize
normalized_dry = normalize(dry)
normalized_wet = normalize(wet)
output_dir = proc_root / str(chunk_idx)
output_dir.mkdir(exist_ok=True)
torchaudio.save(output_dir / "input.wav", normalized_wet, sample_rate)
torchaudio.save(output_dir / "target.wav", normalized_dry, sample_rate)
torch.save(dry_labels_tensor, output_dir / "dry_effects.pt")
torch.save(wet_labels_tensor, output_dir / "wet_effects.pt")
# return normalized_dry, normalized_wet, dry_labels_tensor, wet_labels_tensor
class DynamicEffectDataset(Dataset):
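    """Renders randomly effected (wet, dry) chunks on the fly in __getitem__.

    Unlike EffectDataset, nothing is cached to disk; every access draws a new
    random chunk and a new random effect chain.
    """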
def __init__(
self,
root: str,
sample_rate: int,
chunk_size: int = 262144,
total_chunks: int = 1000,
        effect_modules: Dict[str, torch.nn.Module] = None,
effects_to_keep: List[str] = None,
effects_to_remove: List[str] = None,
num_kept_effects: List[int] = [1, 5],
num_removed_effects: List[int] = [1, 5],
shuffle_kept_effects: bool = True,
shuffle_removed_effects: bool = False,
render_files: bool = True,
render_root: str = None,
mode: str = "train",
parallel: bool = False,
) -> None:
super().__init__()
self.chunks = []
self.song_idx = []
self.root = Path(root)
self.render_root = Path(render_root)
self.chunk_size = chunk_size
self.total_chunks = total_chunks
self.sample_rate = sample_rate
self.mode = mode
self.num_kept_effects = num_kept_effects
self.num_removed_effects = num_removed_effects
self.effects_to_keep = [] if effects_to_keep is None else effects_to_keep
self.effects_to_remove = [] if effects_to_remove is None else effects_to_remove
self.normalize = effect_lib.LoudnessNormalize(sample_rate, target_lufs_db=-20)
self.effects = effect_modules
self.shuffle_kept_effects = shuffle_kept_effects
self.shuffle_removed_effects = shuffle_removed_effects
effects_string = "_".join(
self.effects_to_keep
+ ["_"]
+ self.effects_to_remove
+ ["_"]
+ [str(x) for x in num_kept_effects]
+ ["_"]
+ [str(x) for x in num_removed_effects]
)
# self.validate_effect_input()
# self.proc_root = self.render_root / "processed" / effects_string / self.mode
self.parallel = parallel
self.files = locate_files(self.root, self.mode)
def process_effects(self, dry: torch.Tensor):
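        """Apply kept effects to produce the dry target, then apply removed
        effects on top to produce the wet input.

        Returns (dry, wet, dry_labels, wet_labels), where the label tensors
        are multi-hot over ALL_EFFECTS.
        """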
# Apply Kept Effects
# Shuffle effects if specified
if self.shuffle_kept_effects:
effect_indices = torch.randperm(len(self.effects_to_keep))
else:
effect_indices = torch.arange(len(self.effects_to_keep))
        # Sample how many kept effects to apply within [min, max]
        # (rounding makes the endpoints half as likely)
        r1, r2 = self.num_kept_effects
        num_kept_effects = torch.round((r1 - r2) * torch.rand(1) + r2).int()
effect_indices = effect_indices[:num_kept_effects]
# Index in effect settings
effect_names_to_apply = [self.effects_to_keep[i] for i in effect_indices]
effects_to_apply = [self.effects[i] for i in effect_names_to_apply]
# Apply
dry_labels = []
for effect in effects_to_apply:
# Normalize in-between effects
dry = self.normalize(effect(dry))
dry_labels.append(ALL_EFFECTS.index(type(effect)))
# Apply effects_to_remove
# Shuffle effects if specified
if self.shuffle_removed_effects:
effect_indices = torch.randperm(len(self.effects_to_remove))
else:
effect_indices = torch.arange(len(self.effects_to_remove))
wet = torch.clone(dry)
        # Sample how many removed effects to apply within [min, max]
        # (rounding makes the endpoints half as likely)
        r1, r2 = self.num_removed_effects
        num_removed_effects = torch.round((r1 - r2) * torch.rand(1) + r2).int()
effect_indices = effect_indices[:num_removed_effects]
# Index in effect settings
effect_names_to_apply = [self.effects_to_remove[i] for i in effect_indices]
effects_to_apply = [self.effects[i] for i in effect_names_to_apply]
# Apply
wet_labels = []
for effect in effects_to_apply:
# Normalize in-between effects
wet = self.normalize(effect(wet))
wet_labels.append(ALL_EFFECTS.index(type(effect)))
wet_labels_tensor = torch.zeros(len(ALL_EFFECTS))
dry_labels_tensor = torch.zeros(len(ALL_EFFECTS))
for label_idx in wet_labels:
wet_labels_tensor[label_idx] = 1.0
for label_idx in dry_labels:
dry_labels_tensor[label_idx] = 1.0
# Normalize
normalized_dry = self.normalize(dry)
normalized_wet = self.normalize(wet)
return normalized_dry, normalized_wet, dry_labels_tensor, wet_labels_tensor
def __len__(self):
return self.total_chunks
def __getitem__(self, _: int):
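        # The index is ignored; each access renders a fresh random example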
chunk = None
random_dataset_choice = random.choice(self.files)
while chunk is None:
random_file_choice = random.choice(random_dataset_choice)
chunk = select_random_chunk(
random_file_choice, self.chunk_size, self.sample_rate
)
# Sum to mono
if chunk.shape[0] > 1:
chunk = chunk.sum(0, keepdim=True)
dry, wet, dry_effects, wet_effects = self.process_effects(chunk)
return wet, dry, dry_effects, wet_effects
class EffectDataset(Dataset):
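    """Pre-renders (wet, dry) chunk pairs to disk and serves them from
    render_root/processed/<effects_string>/<mode>.

    Example (hypothetical paths and effect modules):
        >>> dataset = EffectDataset(
        ...     root="./data", sample_rate=48000, render_root="./render",
        ...     effect_modules=my_effect_modules,
        ...     effects_to_keep=["reverb"], effects_to_remove=["distortion"],
        ...     mode="train",
        ... )
    """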
def __init__(
self,
root: str,
sample_rate: int,
chunk_size: int = 262144,
total_chunks: int = 1000,
        effect_modules: Dict[str, torch.nn.Module] = None,
effects_to_keep: List[str] = None,
effects_to_remove: List[str] = None,
num_kept_effects: List[int] = [1, 5],
num_removed_effects: List[int] = [1, 5],
shuffle_kept_effects: bool = True,
shuffle_removed_effects: bool = False,
render_files: bool = True,
render_root: str = None,
mode: str = "train",
parallel: bool = False,
):
super().__init__()
self.chunks = []
self.song_idx = []
self.root = Path(root)
self.render_root = Path(render_root)
self.chunk_size = chunk_size
self.total_chunks = total_chunks
self.sample_rate = sample_rate
self.mode = mode
self.num_kept_effects = num_kept_effects
self.num_removed_effects = num_removed_effects
self.effects_to_keep = [] if effects_to_keep is None else effects_to_keep
self.effects_to_remove = [] if effects_to_remove is None else effects_to_remove
self.normalize = effect_lib.LoudnessNormalize(sample_rate, target_lufs_db=-20)
self.mrstft = MultiResolutionSTFTLoss(sample_rate=sample_rate)
self.effects = effect_modules
self.shuffle_kept_effects = shuffle_kept_effects
self.shuffle_removed_effects = shuffle_removed_effects
effects_string = "_".join(
self.effects_to_keep
+ ["_"]
+ self.effects_to_remove
+ ["_"]
+ [str(x) for x in num_kept_effects]
+ ["_"]
+ [str(x) for x in num_removed_effects]
)
self.validate_effect_input()
self.proc_root = self.render_root / "processed" / effects_string / self.mode
self.parallel = parallel
self.files = locate_files(self.root, self.mode)
if self.proc_root.exists() and len(list(self.proc_root.iterdir())) > 0:
print("Found processed files.")
if render_files:
re_render = input(
"WARNING: By default, will re-render files.\n"
"Set render_files=False to skip re-rendering.\n"
"Are you sure you want to re-render? (y/n): "
)
if re_render != "y":
sys.exit()
shutil.rmtree(self.proc_root)
print("Total datasets:", len(self.files))
print("Processing files...")
if render_files:
# Split audio file into chunks, resample, then apply random effects
self.proc_root.mkdir(parents=True, exist_ok=True)
if self.parallel:
items = [
(
chunk_idx,
self.proc_root,
self.files,
self.chunk_size,
self.effects,
self.effects_to_keep,
self.num_kept_effects,
self.shuffle_kept_effects,
self.effects_to_remove,
self.num_removed_effects,
self.shuffle_removed_effects,
self.sample_rate,
-20.0,
)
for chunk_idx in range(self.total_chunks)
]
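                # NOTE: the worker count is hard-coded; see the docstring of
                # parallel_process_effects for a caveat about worker seeding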
with multiprocessing.Pool(processes=32) as pool:
pool.starmap(parallel_process_effects, items)
print(f"Done proccessing {self.total_chunks}", flush=True)
else:
for num_chunk in tqdm(range(self.total_chunks)):
chunk = None
random_dataset_choice = random.choice(self.files)
while chunk is None:
                        try:
                            random_file_choice = random.choice(random_dataset_choice)
                        except IndexError:
                            # random.choice raises IndexError on an empty file list
                            print("IndexError: empty dataset file list")
                            print(random_dataset_choice)
                            raise
chunk = select_random_chunk(
random_file_choice, self.chunk_size, self.sample_rate
)
# Sum to mono
if chunk.shape[0] > 1:
chunk = chunk.sum(0, keepdim=True)
dry, wet, dry_effects, wet_effects = self.process_effects(chunk)
output_dir = self.proc_root / str(num_chunk)
output_dir.mkdir(exist_ok=True)
torchaudio.save(output_dir / "input.wav", wet, self.sample_rate)
torchaudio.save(output_dir / "target.wav", dry, self.sample_rate)
torch.save(dry_effects, output_dir / "dry_effects.pt")
torch.save(wet_effects, output_dir / "wet_effects.pt")
print("Finished rendering")
else:
self.total_chunks = len(list(self.proc_root.iterdir()))
print("Total chunks:", self.total_chunks)
def __len__(self):
return self.total_chunks
    def __getitem__(self, idx):
        input_file = self.proc_root / str(idx) / "input.wav"
        target_file = self.proc_root / str(idx) / "target.wav"
        # Multi-hot effect label tensors saved at render time
        dry_effects = torch.load(self.proc_root / str(idx) / "dry_effects.pt")
        wet_effects = torch.load(self.proc_root / str(idx) / "wet_effects.pt")
        # "input" is the wet (effected) signal, "target" the dry signal
        input_audio, _ = torchaudio.load(input_file)
        target_audio, _ = torchaudio.load(target_file)
        return (input_audio, target_audio, dry_effects, wet_effects)
def validate_effect_input(self):
for effect in self.effects.values():
if type(effect) not in ALL_EFFECTS:
raise ValueError(
f"Effect {effect} not found in ALL_EFFECTS. "
f"Please choose from {ALL_EFFECTS}"
)
for effect in self.effects_to_keep:
if effect not in self.effects.keys():
raise ValueError(
f"Effect {effect} not found in self.effects. "
f"Please choose from {self.effects.keys()}"
)
for effect in self.effects_to_remove:
if effect not in self.effects.keys():
raise ValueError(
f"Effect {effect} not found in self.effects. "
f"Please choose from {self.effects.keys()}"
)
kept_str = "randomly" if self.shuffle_kept_effects else "in order"
rem_str = "randomly" if self.shuffle_removed_effects else "in order"
if self.num_kept_effects[0] > self.num_kept_effects[1]:
raise ValueError(
f"num_kept_effects must be a tuple of (min, max). "
f"Got {self.num_kept_effects}"
)
if self.num_kept_effects[0] == self.num_kept_effects[1]:
num_kept_str = f"{self.num_kept_effects[0]}"
else:
num_kept_str = (
f"Between {self.num_kept_effects[0]}-{self.num_kept_effects[1]}"
)
if self.num_removed_effects[0] > self.num_removed_effects[1]:
raise ValueError(
f"num_removed_effects must be a tuple of (min, max). "
f"Got {self.num_removed_effects}"
)
if self.num_removed_effects[0] == self.num_removed_effects[1]:
num_rem_str = f"{self.num_removed_effects[0]}"
else:
num_rem_str = (
f"Between {self.num_removed_effects[0]}-{self.num_removed_effects[1]}"
)
rem_fx = self.effects_to_remove
kept_fx = self.effects_to_keep
print(
f"Effect Summary: \n"
f"Apply kept effects: {kept_fx} ({num_kept_str}, chosen {kept_str}) -> Dry\n"
f"Apply remove effects: {rem_fx} ({num_rem_str}, chosen {rem_str}) -> Wet\n"
)
    def process_effects(self, dry: torch.Tensor):
        """Apply kept effects to produce the dry target, then apply removed
        effects on top to produce the wet input, re-sampling the removed chain
        until wet and dry are measurably different.
        """
        # Apply kept effects (these remain in both input and target)
        # Shuffle effects if specified
        if self.shuffle_kept_effects:
            effect_indices = torch.randperm(len(self.effects_to_keep))
        else:
            effect_indices = torch.arange(len(self.effects_to_keep))
        # Sample how many kept effects to apply within [min, max]
        # (rounding makes the endpoints half as likely)
        r1, r2 = self.num_kept_effects
        num_kept_effects = torch.round((r1 - r2) * torch.rand(1) + r2).int()
        effect_indices = effect_indices[:num_kept_effects]
        # Index in effect settings
        effect_names_to_apply = [self.effects_to_keep[i] for i in effect_indices]
        effects_to_apply = [self.effects[i] for i in effect_names_to_apply]
        # Apply kept effects once, normalizing loudness between effects
        dry_labels = []
        for effect in effects_to_apply:
            dry = self.normalize(effect(dry))
            dry_labels.append(ALL_EFFECTS.index(type(effect)))
        # Apply removed effects (present in input, absent from target),
        # re-sampling the chain until the wet/dry STFT distance exceeds
        # STFT_THRESH
        stft = 0
        while stft < STFT_THRESH:
            # Shuffle effects if specified
            if self.shuffle_removed_effects:
                effect_indices = torch.randperm(len(self.effects_to_remove))
            else:
                effect_indices = torch.arange(len(self.effects_to_remove))
            wet = torch.clone(dry)
            # Sample how many removed effects to apply within [min, max]
            # (rounding makes the endpoints half as likely)
            r1, r2 = self.num_removed_effects
            num_removed_effects = torch.round((r1 - r2) * torch.rand(1) + r2).int()
            removed_indices = effect_indices[:num_removed_effects]
            # Index in effect settings
            effect_names_to_apply = [self.effects_to_remove[i] for i in removed_indices]
            effects_to_apply = [self.effects[i] for i in effect_names_to_apply]
            # Apply removed effects, normalizing loudness between effects
            wet_labels = []
            for effect in effects_to_apply:
                wet = self.normalize(effect(wet))
                wet_labels.append(ALL_EFFECTS.index(type(effect)))
            # Build multi-hot label tensors over the full effect vocabulary
            wet_labels_tensor = torch.zeros(len(ALL_EFFECTS))
            dry_labels_tensor = torch.zeros(len(ALL_EFFECTS))
            for label_idx in wet_labels:
                wet_labels_tensor[label_idx] = 1.0
            for label_idx in dry_labels:
                dry_labels_tensor[label_idx] = 1.0
            # Normalize
            normalized_dry = self.normalize(dry)
            normalized_wet = self.normalize(wet)
            if num_removed_effects == 0:
                # Nothing was removed, so wet equals dry by construction; skip check
                break
            # Check STFT distance; pick a different removed chain if too small
            stft = self.mrstft(normalized_wet.unsqueeze(0), normalized_dry.unsqueeze(0))
        return normalized_dry, normalized_wet, dry_labels_tensor, wet_labels_tensor
class InferenceDataset(Dataset):
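    """Paired clean/effected audio for inference.

    Expects root/clean/*.wav and root/effected/*.wav, with pairs matched by
    sorted filename order.
    """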
def __init__(self, root: str, sample_rate: int, **kwargs):
self.root = Path(root)
self.sample_rate = sample_rate
self.clean_paths = sorted(list(self.root.glob("clean/*.wav")))
self.effected_paths = sorted(list(self.root.glob("effected/*.wav")))
def __len__(self) -> int:
return len(self.clean_paths)
    def __getitem__(self, idx: int) -> tuple:
clean_path = self.clean_paths[idx]
effected_path = self.effected_paths[idx]
clean_audio, sr = torchaudio.load(clean_path)
clean = torchaudio.functional.resample(clean_audio, sr, self.sample_rate)
effected_audio, sr = torchaudio.load(effected_path)
effected = torchaudio.functional.resample(effected_audio, sr, self.sample_rate)
# Sum to mono
clean = torch.sum(clean, dim=0, keepdim=True)
effected = torch.sum(effected, dim=0, keepdim=True)
# Pad or trim effected to clean
if effected.shape[1] > clean.shape[1]:
effected = effected[:, : clean.shape[1]]
elif effected.shape[1] < clean.shape[1]:
pad_size = clean.shape[1] - effected.shape[1]
effected = torch.nn.functional.pad(effected, (0, pad_size))
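        # Placeholder labels for inference: no known kept (dry) effects and all
        # effects flagged as possibly present in the wet signal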
dry_labels_tensor = torch.zeros(len(ALL_EFFECTS))
wet_labels_tensor = torch.ones(len(ALL_EFFECTS))
return effected, clean, dry_labels_tensor, wet_labels_tensor
class EffectDatamodule(pl.LightningDataModule):
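    """LightningDataModule wrapping pre-built train/val/test datasets.

    Example (hypothetical datasets):
        >>> datamodule = EffectDatamodule(
        ...     train_dataset, val_dataset, test_dataset,
        ...     train_batch_size=16, test_batch_size=1, num_workers=4,
        ... )
    """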
def __init__(
self,
train_dataset,
val_dataset,
test_dataset,
*,
train_batch_size: int,
test_batch_size: int,
num_workers: int,
pin_memory: bool = False,
        **kwargs: Any,
) -> None:
super().__init__()
self.train_dataset = train_dataset
self.val_dataset = val_dataset
self.test_dataset = test_dataset
self.train_batch_size = train_batch_size
self.test_batch_size = test_batch_size
self.num_workers = num_workers
self.pin_memory = pin_memory
def setup(self, stage: Any = None) -> None:
pass
def train_dataloader(self) -> DataLoader:
return DataLoader(
dataset=self.train_dataset,
batch_size=self.train_batch_size,
num_workers=self.num_workers,
pin_memory=self.pin_memory,
shuffle=True,
)
def val_dataloader(self) -> DataLoader:
return DataLoader(
dataset=self.val_dataset,
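            # Validation reuses the training batch size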
batch_size=self.train_batch_size,
num_workers=self.num_workers,
pin_memory=self.pin_memory,
shuffle=False,
)
def test_dataloader(self) -> DataLoader:
return DataLoader(
dataset=self.test_dataset,
batch_size=self.test_batch_size,
num_workers=self.num_workers,
pin_memory=self.pin_memory,
shuffle=False,
)