import os
import sys
import glob
import torch
import shutil
import torchaudio
import pytorch_lightning as pl
import torch.nn.functional as F
from tqdm import tqdm
from pathlib import Path
from remfx import effects
from ordered_set import OrderedSet
from typing import Any, Dict, List, Optional
from torch.utils.data import Dataset, DataLoader
from remfx.utils import create_sequential_chunks

# https://zenodo.org/record/1193957 -> VocalSet
ALL_EFFECTS = effects.Pedalboard_Effects
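# The position of each class in ALL_EFFECTS defines the index used for the
# multi-hot effect-label tensors built in process_effects() below.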

singer_splits = {
    "train": [
        "male1",
        "male2",
        "male3",
        "male4",
        "male5",
        "male6",
        "male7",
        "male8",
        "male9",
        "female1",
        "female2",
        "female3",
        "female4",
        "female5",
        "female6",
        "female7",
    ],
    "val": ["male10", "female8"],
    "test": ["male11", "female9"],
}


class VocalSet(Dataset):
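    """VocalSet audio chunks with randomly applied Pedalboard effects.

    Each chunk is rendered once to disk as a (wet, dry) pair: "kept" effects
    are applied to both signals, then "removed" effects are applied to the
    wet copy only, so a model can be trained to undo the removed effects.
    """
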
    def __init__(
        self,
        root: str,
        sample_rate: int,
        chunk_size: int = 262144,
        effect_modules: Optional[Dict[str, torch.nn.Module]] = None,
        effects_to_use: Optional[List[str]] = None,
        effects_to_remove: Optional[List[str]] = None,
        max_kept_effects: int = -1,
        max_removed_effects: int = 1,
        shuffle_kept_effects: bool = True,
        shuffle_removed_effects: bool = False,
        render_files: bool = True,
        render_root: Optional[str] = None,
        mode: str = "train",
    ):
        super().__init__()
        self.chunks = []
        self.song_idx = []
        self.root = Path(root)
        self.render_root = Path(render_root)
        self.chunk_size = chunk_size
        self.sample_rate = sample_rate
        self.mode = mode
        self.max_kept_effects = max_kept_effects
        self.max_removed_effects = max_removed_effects
        self.effects_to_use = effects_to_use
        self.effects_to_remove = effects_to_remove
        self.normalize = effects.LoudnessNormalize(sample_rate, target_lufs_db=-20)
        self.effects = effect_modules
        self.shuffle_kept_effects = shuffle_kept_effects
        self.shuffle_removed_effects = shuffle_removed_effects
        effects_string = "_".join(self.effects_to_use + ["_"] + self.effects_to_remove)
        self.effects_to_keep = self.validate_effect_input()
        self.proc_root = self.render_root / "processed" / effects_string / self.mode
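        # On-disk cache layout:
        #   render_root/processed/<effects_string>/<mode>/<chunk_idx>/
        #       input.wav (wet), target.wav (dry),
        #       dry_effects.pt, wet_effects.pt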
        # Find all singer directories belonging to this split
        singer_dirs = glob.glob(os.path.join(self.root, "data_by_singer", "*"))
        singer_dirs = [
            sd for sd in singer_dirs if os.path.basename(sd) in singer_splits[mode]
        ]
        self.files = []
        for singer_dir in singer_dirs:
            self.files += glob.glob(os.path.join(singer_dir, "**", "**", "*.wav"))
        self.files = sorted(self.files)
        if self.proc_root.exists() and len(list(self.proc_root.iterdir())) > 0:
            print("Found processed files.")
            if render_files:
                re_render = input(
                    "WARNING: By default, will re-render files.\n"
                    "Set render_files=False to skip re-rendering.\n"
                    "Are you sure you want to re-render? (y/n): "
                )
                if re_render != "y":
                    sys.exit()
                shutil.rmtree(self.proc_root)
        self.num_chunks = 0
        print("Total files:", len(self.files))
        print("Processing files...")
        if render_files:
            # Split each audio file into chunks, resample, then apply random effects
            self.proc_root.mkdir(parents=True, exist_ok=True)
            for audio_file in tqdm(self.files, total=len(self.files)):
                chunks, orig_sr = create_sequential_chunks(audio_file, self.chunk_size)
                for chunk in chunks:
                    resampled_chunk = torchaudio.functional.resample(
                        chunk, orig_sr, sample_rate
                    )
                    if resampled_chunk.shape[-1] < chunk_size:
                        # Skip if chunk is too small after resampling
                        continue
                    dry, wet, dry_effects, wet_effects = self.process_effects(
                        resampled_chunk
                    )
                    output_dir = self.proc_root / str(self.num_chunks)
                    output_dir.mkdir(exist_ok=True)
                    torchaudio.save(output_dir / "input.wav", wet, self.sample_rate)
                    torchaudio.save(output_dir / "target.wav", dry, self.sample_rate)
                    torch.save(dry_effects, output_dir / "dry_effects.pt")
                    torch.save(wet_effects, output_dir / "wet_effects.pt")
                    self.num_chunks += 1
        else:
            self.num_chunks = len(list(self.proc_root.iterdir()))
        print(
            f"Found {len(self.files)} {self.mode} files.\n"
            f"Total chunks: {self.num_chunks}"
        )

    def __len__(self):
        return self.num_chunks

    def __getitem__(self, idx):
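        """Load one rendered chunk: (wet input, dry target, dry labels, wet labels)."""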
        input_file = self.proc_root / str(idx) / "input.wav"
        target_file = self.proc_root / str(idx) / "target.wav"
        dry_effect_names = torch.load(self.proc_root / str(idx) / "dry_effects.pt")
        wet_effect_names = torch.load(self.proc_root / str(idx) / "wet_effects.pt")
        input_audio, sr = torchaudio.load(input_file)
        target_audio, sr = torchaudio.load(target_file)
        return (input_audio, target_audio, dry_effect_names, wet_effect_names)

    def validate_effect_input(self):
        for effect in self.effects.values():
            if type(effect) not in ALL_EFFECTS:
                raise ValueError(
                    f"Effect {effect} not found in ALL_EFFECTS. "
                    f"Please choose from {ALL_EFFECTS}"
                )
        for effect in self.effects_to_use:
            if effect not in self.effects.keys():
                raise ValueError(
                    f"Effect {effect} not found in self.effects. "
                    f"Please choose from {self.effects.keys()}"
                )
        for effect in self.effects_to_remove:
            if effect not in self.effects.keys():
                raise ValueError(
                    f"Effect {effect} not found in self.effects. "
                    f"Please choose from {self.effects.keys()}"
                )
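        # Kept effects = effects_to_use minus effects_to_remove, order preserved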
        kept_fx = list(
            OrderedSet(self.effects_to_use) - OrderedSet(self.effects_to_remove)
        )
        kept_str = "randomly" if self.shuffle_kept_effects else "in order"
        rem_fx = self.effects_to_remove
        rem_str = "randomly" if self.shuffle_removed_effects else "in order"
        if self.max_kept_effects == -1:
            num_kept_str = len(kept_fx)
        else:
            num_kept_str = f"Up to {self.max_kept_effects}"
        if self.max_removed_effects == -1:
            num_rem_str = len(rem_fx)
        else:
            num_rem_str = f"Up to {self.max_removed_effects}"
        print(
            f"Effect Summary:\n"
            f"Apply kept effects: {kept_fx} ({num_kept_str}, chosen {kept_str}) -> Dry\n"
            f"Apply removed effects: {rem_fx} ({num_rem_str}, chosen {rem_str}) -> Wet\n"
        )
        return kept_fx
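
    # Effect pipeline: kept effects are baked into both the dry and wet signals;
    # removed effects are then applied to the wet copy only, so the training
    # target is the wet signal with only the removed effects undone.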
    def process_effects(self, dry: torch.Tensor):
        # Apply kept effects, shuffled if specified
        if self.shuffle_kept_effects:
            effect_indices = torch.randperm(len(self.effects_to_keep))
        else:
            effect_indices = torch.arange(len(self.effects_to_keep))
        # Pick between 1 and max_kept_effects effects (all of them if -1)
        if self.max_kept_effects != -1:
            num_kept_effects = int(torch.rand(1).item() * self.max_kept_effects) + 1
        else:
            num_kept_effects = len(self.effects_to_keep)
        effect_indices = effect_indices[:num_kept_effects]
        # Look up the effect modules by name
        effect_names_to_apply = [self.effects_to_keep[i] for i in effect_indices]
        effects_to_apply = [self.effects[i] for i in effect_names_to_apply]
        # Apply
        dry_labels = []
        for effect in effects_to_apply:
            dry = effect(dry)
            dry_labels.append(ALL_EFFECTS.index(type(effect)))
        # Apply effects_to_remove on top of the wet copy, shuffled if specified
        wet = torch.clone(dry)
        if self.shuffle_removed_effects:
            effect_indices = torch.randperm(len(self.effects_to_remove))
        else:
            effect_indices = torch.arange(len(self.effects_to_remove))
        # Pick between 1 and max_removed_effects effects (all of them if -1),
        # mirroring the kept-effects branch above
        if self.max_removed_effects != -1:
            num_removed_effects = (
                int(torch.rand(1).item() * self.max_removed_effects) + 1
            )
        else:
            num_removed_effects = len(self.effects_to_remove)
        effect_indices = effect_indices[:num_removed_effects]
        # Look up the effect modules by name
        effect_names_to_apply = [self.effects_to_remove[i] for i in effect_indices]
        effects_to_apply = [self.effects[i] for i in effect_names_to_apply]
        # Apply
        wet_labels = []
        for effect in effects_to_apply:
            wet = effect(wet)
            wet_labels.append(ALL_EFFECTS.index(type(effect)))
        # Multi-hot label tensors indexed by position in ALL_EFFECTS
        wet_labels_tensor = torch.zeros(len(ALL_EFFECTS))
        dry_labels_tensor = torch.zeros(len(ALL_EFFECTS))
        for label_idx in wet_labels:
            wet_labels_tensor[label_idx] = 1.0
        for label_idx in dry_labels:
            dry_labels_tensor[label_idx] = 1.0
        # Loudness-normalize both signals to the same target LUFS
        normalized_dry = self.normalize(dry)
        normalized_wet = self.normalize(wet)
        return normalized_dry, normalized_wet, dry_labels_tensor, wet_labels_tensor


class VocalSetDatamodule(pl.LightningDataModule):
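    """Thin LightningDataModule wrapping pre-built train/val/test VocalSet datasets."""
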
    def __init__(
        self,
        train_dataset,
        val_dataset,
        test_dataset,
        *,
        batch_size: int,
        num_workers: int,
        pin_memory: bool = False,
        **kwargs: int,
    ) -> None:
        super().__init__()
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.test_dataset = test_dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory

    def setup(self, stage: Any = None) -> None:
        pass

    def train_dataloader(self) -> DataLoader:
        return DataLoader(
            dataset=self.train_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            shuffle=True,
        )

    def val_dataloader(self) -> DataLoader:
        return DataLoader(
            dataset=self.val_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            shuffle=False,
        )

    def test_dataloader(self) -> DataLoader:
        return DataLoader(
            dataset=self.test_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            shuffle=False,
        )
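

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. Assumptions are
    # flagged inline: the dataset and render paths are placeholders, and each
    # class in ALL_EFFECTS is assumed to accept a `sample_rate` keyword
    # argument; adjust both to your setup.
    sr = 48000  # assumed project sample rate
    effect_modules = {cls.__name__: cls(sample_rate=sr) for cls in ALL_EFFECTS}
    names = list(effect_modules.keys())
    train_ds = VocalSet(
        root="./data/VocalSet",  # placeholder path to the VocalSet download
        sample_rate=sr,
        effect_modules=effect_modules,
        effects_to_use=names,  # any effect may be applied...
        effects_to_remove=names[:1],  # ...but train to remove only the first
        render_root="./data/rendered",  # placeholder cache location
        mode="train",
    )
    datamodule = VocalSetDatamodule(
        train_ds, train_ds, train_ds,  # reuse one split purely for illustration
        batch_size=16,
        num_workers=4,
    )
    wet, dry, dry_labels, wet_labels = train_ds[0]
    print(wet.shape, dry.shape, dry_labels, wet_labels)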