Spaces:
Runtime error
Runtime error
| # coding: utf-8 | |
| __author__ = "Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/" | |
| import itertools | |
| import multiprocessing | |
| import os | |
| import pickle | |
| import random | |
| import warnings | |
| from concurrent.futures import ThreadPoolExecutor, as_completed | |
| from glob import glob | |
| from typing import Union | |
| import audiomentations as AU | |
| import numpy as np | |
| import pedalboard as PB | |
| import soundfile as sf | |
| import torch | |
| import torch.distributed as dist | |
| from ml_collections import ConfigDict | |
| from omegaconf import OmegaConf | |
| from torch.utils.data import DataLoader | |
| from torch.utils.data.distributed import DistributedSampler | |
| from tqdm import tqdm | |
| from tqdm.auto import tqdm | |
| warnings.filterwarnings("ignore") | |
| import argparse | |
def prepare_data(
    config: Union[ConfigDict, OmegaConf], args: argparse.Namespace, batch_size: int
) -> DataLoader:
    """
    Build the training DataLoader. If torch.distributed.is_initialized() is True,
    construct a DDP DataLoader with DistributedSampler; otherwise, construct a
    regular DataLoader.

    Args:
        config: Dataset configuration passed to MSSDataset.
        args: Must provide data_path, results_path, dataset_type, and DataLoader settings.
        batch_size: Per-process mini-batch size.

    Returns:
        Configured DataLoader for the training split.
    """
    # Fix: metadata path was built twice (once per branch); compute it once.
    metadata_path = os.path.join(
        args.results_path, f"metadata_{args.dataset_type}.pkl"
    )
    if dist.is_initialized():
        rank = dist.get_rank()
        world_size = dist.get_world_size()
        # For all types except 5 the dataset length is num_steps * batch_size,
        # so scale by world_size to maintain "num_steps" semantics across the
        # whole world. Type 5 enumerates precomputed chunks and is not scaled.
        if args.dataset_type != 5:
            dataset_batch = batch_size * world_size
        else:
            dataset_batch = batch_size
        trainset = MSSDataset(
            config,
            args.data_path,
            batch_size=dataset_batch,
            metadata_path=metadata_path,
            dataset_type=args.dataset_type,
        )
        sampler = DistributedSampler(
            trainset, num_replicas=world_size, rank=rank, shuffle=True, drop_last=True
        )
        train_loader = DataLoader(
            trainset,
            batch_size=batch_size,  # per-process batch size
            sampler=sampler,  # sampler handles shuffling in DDP
            num_workers=args.num_workers,
            pin_memory=args.pin_memory,
            persistent_workers=args.persistent_workers,
            prefetch_factor=args.prefetch_factor,
        )
    else:
        trainset = MSSDataset(
            config,
            args.data_path,
            batch_size=batch_size,
            metadata_path=metadata_path,
            dataset_type=args.dataset_type,
        )
        train_loader = DataLoader(
            trainset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=args.num_workers,
            pin_memory=args.pin_memory,
            persistent_workers=args.persistent_workers,
            prefetch_factor=args.prefetch_factor,
        )
    return train_loader
def load_chunk(path, length, chunk_size, offset=None, target_channels=2):
    """
    Read a fixed-size audio chunk from *path*.

    Args:
        path: Audio file readable by soundfile.
        length: Total number of frames in the file (from sf.info).
        chunk_size: Number of frames to return.
        offset: Start frame; a random valid position is drawn when None.
        target_channels: Channel count of the returned array.

    Returns:
        np.ndarray of shape (target_channels, chunk_size), zero-padded at the
        end when the file is shorter than chunk_size.

    Raises:
        ValueError: if the file has more than one channel but fewer than
            target_channels (no sensible up-mix exists).
    """
    if chunk_size <= length:
        if offset is None:
            start = np.random.randint(length - chunk_size + 1)
        else:
            start = offset
        x = sf.read(path, dtype="float32", start=start, frames=chunk_size)[0]
    else:
        # File shorter than the chunk: read everything available, pad below.
        if offset is None:
            start = 0
        else:
            start = offset
        frames_to_read = length
        x = sf.read(path, dtype="float32", start=start, frames=frames_to_read)[0]
    if x.ndim == 1:  # mono files come back as 1-D
        x = x[:, None]
    if x.shape[0] < chunk_size:
        pad = np.zeros((chunk_size - x.shape[0], x.shape[1]), dtype=np.float32)
        x = np.concatenate([x, pad], axis=0)
    elif x.shape[0] > chunk_size:
        x = x[:chunk_size]
    ch = x.shape[1]
    if ch == target_channels:
        pass
    elif ch > target_channels:
        x = x[:, :target_channels]
    elif ch == 1:
        # Bug fix: duplicate the mono channel target_channels times
        # (was hard-coded to 2, which broke target_channels != 2).
        x = np.repeat(x, target_channels, axis=1)
    else:
        raise ValueError(f"Path: {path}, num_channels: {ch}")
    return x.T
def get_track_set_length(params):
    """Return (path, minimum stem length in frames) for one track folder.

    ``params`` is (path, instruments, file_types, dataset_type). The minimum
    across stems is returned so soundfile reads never overflow when stem
    lengths differ.
    """
    folder, stems, extensions, ds_type = params
    verbose = (
        not dist.is_initialized() or dist.get_rank() == 0
    ) and ds_type != 7
    # Check lengths of all instruments (it can be different in some cases)
    found_lengths = []
    for stem in stems:
        stem_frames = -1
        for ext in extensions:
            candidate = folder + "/{}.{}".format(stem, ext)
            if os.path.isfile(candidate):
                stem_frames = sf.info(candidate).frames
                break
        if stem_frames == -1:
            if verbose:
                print('Cant find file "{}" in folder {}'.format(stem, folder))
            continue
        found_lengths.append(stem_frames)
    found_lengths = np.array(found_lengths)
    if found_lengths.min() != found_lengths.max() and verbose:
        print(
            f"Warning: lengths of stems are different for path: {folder}. ({found_lengths.min()} != {found_lengths.max()})"
        )
    # We use minimum to allow overflow for soundfile read in non-equal length cases
    return folder, found_lengths.min()
# For multiprocessing
def get_track_length(params):
    """Return (path, frame count) for a single audio file path."""
    audio_path = params
    return (audio_path, sf.info(audio_path).frames)
def process_chunk_worker(args):
    """Multiprocessing worker: decide whether a chunk is loud enough.

    ``args`` is (task, instruments, file_types, min_mean_abs, default_chunk_size)
    with task = (track_path, track_length, offset, chunk_size).

    Returns:
        (track_path, offset, ok) where ok is True only when every instrument
        has a readable stem whose mean absolute amplitude reaches min_mean_abs.
    """
    task, stems, extensions, threshold, _default_chunk = args
    folder, total_frames, start, n_frames = task
    try:
        for stem in stems:
            loud = False
            for ext in extensions:
                candidate = folder + "/{}.{}".format(stem, ext)
                if not os.path.isfile(candidate):
                    continue
                try:
                    audio = load_chunk(
                        candidate,
                        length=total_frames,
                        offset=start,
                        chunk_size=n_frames,
                    )
                except Exception:
                    # Unreadable stem -> reject the whole chunk.
                    return (folder, start, False)
                if np.abs(audio).mean() >= threshold:
                    loud = True
                    break
            if not loud:
                return (folder, start, False)
        return (folder, start, True)
    except Exception:
        return (folder, start, False)
| class MSSDataset(torch.utils.data.Dataset): | |
    def __init__(
        self,
        config,
        data_path,
        metadata_path="metadata.pkl",
        dataset_type=1,
        batch_size=None,
        verbose=True,
    ):
        """
        Music-source-separation training dataset.

        Args:
            config: ConfigDict/OmegaConf with ``training`` and ``audio`` sections.
            data_path: Folder (or list of folders / CSV paths) with the tracks.
            metadata_path: Pickle cache for (track_path, length) metadata.
            dataset_type: Sampling strategy; the code below handles 1-7.
            batch_size: Overrides config.training.batch_size when given.
            verbose: Print progress/info (rank 0 only under DDP).
        """
        self.verbose = verbose
        self.config = config
        self.dataset_type = dataset_type  # 1, 2, 3, 4 or 5
        self.data_path = data_path
        self.instruments = instruments = config.training.instruments
        if batch_size is None:
            batch_size = config.training.batch_size
        self.batch_size = batch_size
        # Supported stem file extensions, tried in this order.
        self.file_types = ["wav", "flac"]
        self.metadata_path = metadata_path
        # Only rank 0 prints in DDP runs to avoid duplicated logs.
        should_print = not dist.is_initialized() or dist.get_rank() == 0
        # Augmentation block
        self.aug = False
        if "augmentations" in config:
            if config["augmentations"].enable is True:
                if self.verbose and should_print:
                    print("Use augmentation for training")
                self.aug = True
        else:
            if self.verbose and should_print:
                print(
                    "There is no augmentations block in config. Augmentations disabled for training..."
                )
        metadata = self.get_metadata()
        if self.dataset_type in [1, 4, 5, 6, 7]:
            # Folder-per-track types: metadata is a flat list of (path, length).
            if len(metadata) > 0:
                if self.verbose and should_print:
                    print("Found tracks in dataset: {}".format(len(metadata)))
            else:
                if should_print:
                    print("No tracks found for training. Check paths you provided!")
                exit()
        else:
            # Types 2/3: metadata is a dict keyed by instrument.
            for instr in self.instruments:
                if self.verbose and should_print:
                    print(
                        "Found tracks for {} in dataset: {}".format(
                            instr, len(metadata[instr])
                        )
                    )
        self.metadata = metadata
        self.chunk_size = config.audio.chunk_size
        self.min_mean_abs = config.audio.min_mean_abs
        # Chunk precomputation only makes sense with a positive loudness threshold.
        self.do_chunks = (
            config.training.get("precompute_chunks", False)
            and float(self.min_mean_abs) > 0
        )
        # For dataset_type 5 - precompute all chunks
        # NOTE: precedence is `5 or ((4 or 6) and do_chunks)` - type 5 always
        # precomputes; types 4/6 only when do_chunks is enabled.
        if (
            self.dataset_type == 5
            or (self.dataset_type == 4 or self.dataset_type == 6)
            and self.do_chunks
        ):
            self._initialize_chunks_metadata()
        if self.dataset_type == 7:
            self._build_class_to_tracks()
| def __len__(self): | |
| if self.dataset_type == 5: | |
| return len(self.chunks_metadata) | |
| return self.config.training.num_steps * self.batch_size | |
    def __getitem__(self, index):
        """
        Produce one training sample.

        Returns:
            (stems, mix) for most dataset types, stems shaped
            (num_instruments, channels, chunk_size);
            (stems, mix, active_stem_ids) for dataset_type 7;
            (single_stem, mix) when config.training.target_instrument is set.

        ``index`` is only meaningful for dataset_type 5 (precomputed chunks);
        the other types sample randomly and ignore it.
        """
        if self.dataset_type == 7:
            res, mix, active_stem_ids = self.load_class_balanced_aligned()
        elif self.dataset_type == 5:
            track_path, offset = self.chunks_metadata[index]
            res = self._load_chunk_by_offset(track_path, offset)
        elif self.dataset_type in [1, 2, 3]:
            res = self.load_random_mix()
        else:  # type 4 or 6
            if self.do_chunks:
                # Sample a random precomputed (track, offset) pair.
                track_path, offset = self.chunks_metadata[
                    np.random.randint(len(self.chunks_metadata))
                ]
                res = self._load_chunk_by_offset(track_path, offset)
            else:
                if self.dataset_type == 6:
                    res, mix = self.load_aligned_data()
                else:
                    res, _ = self.load_aligned_data()
        # Randomly change loudness of each stem
        if self.aug:
            if "loudness" in self.config["augmentations"]:
                if self.config["augmentations"]["loudness"]:
                    loud_values = np.random.uniform(
                        low=self.config["augmentations"]["loudness_min"],
                        high=self.config["augmentations"]["loudness_max"],
                        size=(len(res),),
                    )
                    loud_values = torch.tensor(loud_values, dtype=torch.float32)
                    res *= loud_values[:, None, None]
        # Types 6 and 7 already carry a real mixture; otherwise the mix is the
        # sum of the (possibly re-weighted) stems.
        if self.dataset_type != 6 and self.dataset_type != 7:
            mix = res.sum(0)
        if self.aug:
            if "mp3_compression_on_mixture" in self.config["augmentations"]:
                apply_aug = AU.Mp3Compression(
                    min_bitrate=self.config["augmentations"][
                        "mp3_compression_on_mixture_bitrate_min"
                    ],
                    max_bitrate=self.config["augmentations"][
                        "mp3_compression_on_mixture_bitrate_max"
                    ],
                    backend=self.config["augmentations"][
                        "mp3_compression_on_mixture_backend"
                    ],
                    p=self.config["augmentations"]["mp3_compression_on_mixture"],
                )
                mix_conv = mix.cpu().numpy().astype(np.float32)
                required_shape = mix_conv.shape
                # NOTE(review): sample_rate 44100 is hard-coded here - confirm
                # it matches the audio config.
                mix = apply_aug(samples=mix_conv, sample_rate=44100)
                # Sometimes it gives longer audio (so we cut)
                if mix.shape != required_shape:
                    mix = mix[..., : required_shape[-1]]
                mix = torch.tensor(mix, dtype=torch.float32)
        # If we need to optimize only given stem
        if self.config.training.target_instrument is not None:
            index = self.config.training.instruments.index(
                self.config.training.target_instrument
            )
            return res[index : index + 1], mix
        if self.dataset_type == 7:
            return res, mix, active_stem_ids
        return res, mix
    def _build_class_to_tracks(self):
        """
        Build (and cache as JSON) the mapping instrument -> list of track paths
        containing that instrument, used by dataset_type 7 for class-balanced
        sampling. Instruments present in more than ``max_class_presence_ratio``
        of all tracks are dropped from balancing.
        """
        import json

        should_print = not dist.is_initialized() or dist.get_rank() == 0
        # NOTE(review): cache is written to the current working directory,
        # not results_path - confirm this is intentional.
        cache_path = "class_to_tracks_cache.json"
        total_tracks = len(self.metadata)
        max_ratio = self.config.training.get("max_class_presence_ratio", 0.4)
        if os.path.isfile(cache_path):
            if should_print:
                print("[dataset_type=7] Loading class_to_tracks from cache")
            with open(cache_path, "r", encoding="utf8") as f:
                cache = json.load(f)
            # Cache is valid only if built from the same track count and ratio.
            if (
                cache.get("total_tracks") == total_tracks
                and cache.get("max_ratio") == max_ratio
            ):
                self.class_to_tracks = cache["class_to_tracks"]
                self.available_classes = list(self.class_to_tracks.keys())
                if should_print:
                    print(
                        f"[dataset_type=7] Loaded {len(self.available_classes)} classes from cache"
                    )
                return
            else:
                if should_print:
                    print("[dataset_type=7] Cache invalid, rebuilding")
        class_to_tracks = {instr: [] for instr in self.instruments}
        track_iter = self.metadata
        if should_print:
            track_iter = tqdm(
                self.metadata,
                desc="[dataset_type=7] Building class_to_tracks",
                total=total_tracks,
            )
        # A track belongs to a class if any supported file extension exists.
        for track_path, _ in track_iter:
            for instr in self.instruments:
                for ext in self.file_types:
                    path = f"{track_path}/{instr}.{ext}"
                    if os.path.isfile(path):
                        class_to_tracks[instr].append(track_path)
                        break
        # Drop empty classes and over-represented ones.
        filtered_class_to_tracks = {}
        for instr, tracks in class_to_tracks.items():
            count = len(tracks)
            ratio = count / total_tracks
            if count == 0:
                continue
            if ratio > max_ratio:
                if should_print:
                    print(
                        f"[dataset_type=7] Skip frequent stem '{instr}': "
                        f"{count}/{total_tracks} ({ratio:.1%})"
                    )
                continue
            filtered_class_to_tracks[instr] = tracks
        if len(filtered_class_to_tracks) == 0:
            raise RuntimeError(
                "dataset_type 7: all classes were filtered out by frequency threshold"
            )
        self.class_to_tracks = filtered_class_to_tracks
        self.available_classes = list(filtered_class_to_tracks.keys())
        if should_print:
            print("[dataset_type=7] Saving class_to_tracks cache")
        with open(cache_path, "w", encoding="utf8") as f:
            json.dump(
                {
                    "total_tracks": total_tracks,
                    "max_ratio": max_ratio,
                    "class_to_tracks": filtered_class_to_tracks,
                },
                f,
                indent=2,
            )
        if should_print:
            print(
                f"[dataset_type=7] Using {len(self.available_classes)} balanced classes "
                f"out of {len(self.instruments)} instruments"
            )
| def load_class_balanced_aligned(self): | |
| """ | |
| 1) Randomly choose instrument (class) | |
| 2) Randomly choose track containing this instrument | |
| 3) Load aligned chunk from this track | |
| """ | |
| should_print = not dist.is_initialized() or dist.get_rank() == 0 | |
| instr = random.choice(self.available_classes) | |
| track_path = random.choice(self.class_to_tracks[instr]) | |
| # Find track length | |
| track_length = None | |
| for path, length in self.metadata: | |
| if path == track_path: | |
| track_length = length | |
| break | |
| if track_length is None: | |
| raise RuntimeError(f"Track length not found: {track_path}") | |
| if track_length >= self.chunk_size: | |
| offset = np.random.randint(track_length - self.chunk_size + 1) | |
| else: | |
| offset = None | |
| mix = None | |
| for extension in self.file_types: | |
| path_to_mix_file = f"{track_path}/mixture.{extension}" | |
| if os.path.isfile(path_to_mix_file): | |
| try: | |
| mix = load_chunk( | |
| path_to_mix_file, track_length, self.chunk_size, offset=offset | |
| ) | |
| break | |
| except Exception as e: | |
| print(e) | |
| res = [] | |
| active_stem_ids = [] | |
| for idx, instr in enumerate(self.instruments): | |
| found = False | |
| for extension in self.file_types: | |
| path_to_audio_file = f"{track_path}/{instr}.{extension}" | |
| if os.path.isfile(path_to_audio_file): | |
| try: | |
| source = load_chunk( | |
| path_to_audio_file, | |
| track_length, | |
| self.chunk_size, | |
| offset=offset, | |
| ) | |
| active_stem_ids.append(idx) | |
| found = True | |
| break | |
| except Exception as e: | |
| print(e) | |
| if not found: | |
| source = np.zeros((2, self.chunk_size), dtype=np.float32) | |
| res.append(source) | |
| res = np.stack(res, axis=0) | |
| if mix is None: | |
| mix = np.sum(res, axis=0) | |
| if self.aug: | |
| for i, instr in enumerate(self.instruments): | |
| res[i] = self.augm_data(res[i], instr) | |
| return ( | |
| torch.tensor(res, dtype=torch.float32), | |
| torch.tensor(mix, dtype=torch.float32), | |
| active_stem_ids, | |
| ) | |
| def _initialize_chunks_metadata(self): | |
| should_print = not dist.is_initialized() or dist.get_rank() == 0 | |
| chunks_cache_path = self.metadata_path.replace(".pkl", "_chunks.pkl") | |
| current_config = { | |
| "chunk_size": self.chunk_size, | |
| "min_mean_abs": self.min_mean_abs, | |
| "instruments": sorted(self.instruments), | |
| } | |
| if os.path.exists(chunks_cache_path): | |
| try: | |
| cached_chunks = pickle.load(open(chunks_cache_path, "rb")) | |
| cached_config = cached_chunks.get("config", {}) | |
| config_matches = ( | |
| cached_config.get("chunk_size") == current_config["chunk_size"] | |
| and cached_config.get("min_mean_abs") | |
| == current_config["min_mean_abs"] | |
| and cached_config.get("instruments") | |
| == current_config["instruments"] | |
| ) | |
| if config_matches: | |
| self.chunks_metadata = cached_chunks["chunks_metadata"] | |
| if self.verbose and should_print: | |
| print( | |
| f"Loaded {len(self.chunks_metadata)} cached chunks from {chunks_cache_path}" | |
| ) | |
| else: | |
| if self.verbose and should_print: | |
| print("Config changed, recomputing chunks...") | |
| print(f"Cached config: {cached_config}") | |
| print(f"Current config: {current_config}") | |
| self.chunks_metadata = self._precompute_and_cache_chunks( | |
| chunks_cache_path, current_config | |
| ) | |
| except Exception as e: | |
| if self.verbose and should_print: | |
| print(f"Chunks cache corrupted ({e}), recomputing...") | |
| self.chunks_metadata = self._precompute_and_cache_chunks( | |
| chunks_cache_path, current_config | |
| ) | |
| else: | |
| self.chunks_metadata = self._precompute_and_cache_chunks( | |
| chunks_cache_path, current_config | |
| ) | |
| if self.verbose and should_print: | |
| print(f"Precomputed {len(self.chunks_metadata)} chunks") | |
| def _precompute_and_cache_chunks(self, cache_path, config): | |
| """Precompute all chunks and save to cache with config""" | |
| if self.dataset_type == 4 or self.dataset_type == 6: | |
| chunks_metadata = self._precompute_random_chunks() | |
| elif self.dataset_type == 5: | |
| chunks_metadata = self._precompute_chunks() | |
| else: | |
| raise "Only dataset type 4, 5 can be precomputed" | |
| cache_data = {"chunks_metadata": chunks_metadata, "config": config} | |
| pickle.dump(cache_data, open(cache_path, "wb")) | |
| return chunks_metadata | |
| def _precompute_chunks(self): | |
| """Precompute all chunks for dataset_type 5 with overlap 2 using multiprocessing""" | |
| should_print = not dist.is_initialized() or dist.get_rank() == 0 | |
| tasks = [] | |
| for track_path, track_length in self.metadata: | |
| if track_length < self.chunk_size: | |
| tasks.append((track_path, track_length, 0, track_length)) | |
| else: | |
| step = self.chunk_size // 2 | |
| num_chunks = (track_length - self.chunk_size) // step + 1 | |
| for i in range(num_chunks): | |
| offset = i * step | |
| tasks.append((track_path, track_length, offset, self.chunk_size)) | |
| if should_print: | |
| print(f"Total tasks to process: {len(tasks)}") | |
| if multiprocessing.cpu_count() > 1: | |
| chunks_metadata = self._process_tasks_parallel(tasks, should_print) | |
| else: | |
| chunks_metadata = self._process_tasks_sequential(tasks, should_print) | |
| if self.verbose and should_print: | |
| print( | |
| f"Created {len(chunks_metadata)} good chunks from {len(self.metadata)} tracks" | |
| ) | |
| return chunks_metadata | |
    def _precompute_random_chunks(self):
        """Precompute exact number of good chunks.

        Repeatedly samples random (track, offset) candidates in batches and
        keeps the loud-enough ones until ``target_count`` chunks are found.
        NOTE(review): if the data contains too few loud chunks this loop can
        run indefinitely - there is no attempt limit. Confirm acceptable.
        """
        should_print = not dist.is_initialized() or dist.get_rank() == 0
        # Default target: one chunk per training sample over the whole run.
        target_count = self.config.training.get(
            "num_precompute_chunks",
            self.config.training.num_steps
            * self.batch_size
            * self.config.training.num_epochs,
        )
        chunks_metadata = []
        if should_print:
            print(f"Generating exactly {target_count} good chunks...")
        with tqdm(total=target_count, desc="Progress good chunks") as pbar:
            while len(chunks_metadata) < target_count:
                batch_size = self.config.training.get(
                    "precompute_batch_for_chunks", 500
                )
                tasks = []
                # `need` caps the progress-bar update so it never overshoots.
                need = target_count - len(chunks_metadata)
                for i in range(batch_size):
                    track_path, track_length = random.choice(self.metadata)
                    if track_length < self.chunk_size:
                        tasks.append((track_path, track_length, 0, track_length))
                    else:
                        offset = np.random.randint(track_length - self.chunk_size + 1)
                        tasks.append(
                            (track_path, track_length, offset, self.chunk_size)
                        )
                if multiprocessing.cpu_count() > 1:
                    good_chunks = self._process_tasks_parallel(tasks, False)
                else:
                    good_chunks = self._process_tasks_sequential(tasks, False)
                chunks_metadata.extend(good_chunks)
                pbar.update(min(len(good_chunks), need))
        # Trim the overshoot from the last batch.
        chunks_metadata = chunks_metadata[:target_count]
        return chunks_metadata
| def _process_tasks_sequential(self, tasks, should_print): | |
| chunks_metadata = [] | |
| pbar = tqdm(tasks, desc="Processing chunks") if should_print else tasks | |
| for task in pbar: | |
| track_path, track_length, offset, chunk_size = task | |
| if self._is_chunk_loud_enough(track_path, offset, chunk_size, track_length): | |
| chunks_metadata.append((track_path, offset)) | |
| return chunks_metadata | |
    def _process_tasks_parallel(self, tasks, should_print):
        """Filter tasks with a process pool; return kept (path, offset) pairs.

        Uses the module-level ``process_chunk_worker`` (not a bound method) so
        the work items remain picklable for multiprocessing.
        """
        chunks_metadata = []
        with multiprocessing.Pool(processes=multiprocessing.cpu_count()) as pool:
            # Bundle per-task data with the shared filter settings for the worker.
            worker_args = [
                (
                    task,
                    self.instruments,
                    self.file_types,
                    self.min_mean_abs,
                    self.chunk_size,
                )
                for task in tasks
            ]
            results = []
            if should_print:
                with tqdm(total=len(tasks), desc="Processing chunks") as pbar:
                    # imap_unordered yields results as soon as workers finish.
                    for i, result in enumerate(
                        pool.imap_unordered(process_chunk_worker, worker_args)
                    ):
                        results.append(result)
                        pbar.update(1)
            else:
                for result in pool.imap_unordered(process_chunk_worker, worker_args):
                    results.append(result)
        # Keep only the chunks the worker marked as loud enough.
        for result in results:
            track_path, offset, is_loud_enough = result
            if is_loud_enough:
                chunks_metadata.append((track_path, offset))
        return chunks_metadata
| def _is_chunk_loud_enough(self, track_path, offset, chunk_size, track_length): | |
| try: | |
| for instrument in self.instruments: | |
| instrument_loud_enough = False | |
| for extension in self.file_types: | |
| path_to_audio_file = track_path + "/{}.{}".format( | |
| instrument, extension | |
| ) | |
| if os.path.isfile(path_to_audio_file): | |
| try: | |
| source = load_chunk( | |
| path_to_audio_file, | |
| length=track_length, | |
| offset=offset, | |
| chunk_size=chunk_size, | |
| ) | |
| if np.abs(source).mean() >= self.min_mean_abs: | |
| instrument_loud_enough = True | |
| break | |
| except Exception as e: | |
| if not dist.is_initialized() or dist.get_rank() == 0: | |
| print( | |
| "Error loading: {} Path: {}".format( | |
| e, path_to_audio_file | |
| ) | |
| ) | |
| return False | |
| if not instrument_loud_enough: | |
| return False | |
| return True | |
| except Exception as e: | |
| if not dist.is_initialized() or dist.get_rank() == 0: | |
| print( | |
| "Error checking chunk loudness: {} Path: {}".format(e, track_path) | |
| ) | |
| return False | |
| def read_from_metadata_cache(self, track_paths, instr=None): | |
| should_print = not dist.is_initialized() or dist.get_rank() == 0 | |
| metadata = [] | |
| if os.path.isfile(self.metadata_path): | |
| if self.verbose and should_print: | |
| print("Found metadata cache file: {}".format(self.metadata_path)) | |
| old_metadata = pickle.load(open(self.metadata_path, "rb")) | |
| else: | |
| return track_paths, metadata | |
| if instr: | |
| old_metadata = old_metadata[instr] | |
| # We will not re-read tracks existed in old metadata file | |
| track_paths_set = set(track_paths) | |
| for old_path, file_size in old_metadata: | |
| if old_path in track_paths_set: | |
| metadata.append([old_path, file_size]) | |
| track_paths_set.remove(old_path) | |
| track_paths = list(track_paths_set) | |
| if len(metadata) > 0 and should_print: | |
| print("Old metadata was used for {} tracks.".format(len(metadata))) | |
| return track_paths, metadata | |
| def get_metadata(self): | |
| read_metadata_procs = multiprocessing.cpu_count() - 2 | |
| should_print = not dist.is_initialized() or dist.get_rank() == 0 | |
| if "read_metadata_procs" in self.config["training"]: | |
| read_metadata_procs = int(self.config["training"]["read_metadata_procs"]) | |
| if self.verbose and should_print: | |
| print( | |
| "Dataset type:", | |
| self.dataset_type, | |
| "Processes to use:", | |
| read_metadata_procs, | |
| "\nCollecting metadata for", | |
| str(self.data_path), | |
| ) | |
| if self.dataset_type in [1, 4, 5, 6, 7]: # Added type 7 | |
| track_paths = [] | |
| if type(self.data_path) == list: | |
| for tp in self.data_path: | |
| tracks_for_folder = sorted(glob(tp + "/*")) | |
| if len(tracks_for_folder) == 0 and should_print: | |
| print( | |
| "Warning: no tracks found in folder '{}'. Please check it!".format( | |
| tp | |
| ) | |
| ) | |
| track_paths += tracks_for_folder | |
| else: | |
| track_paths += sorted(glob(self.data_path + "/*")) | |
| track_paths = [ | |
| path | |
| for path in track_paths | |
| if os.path.basename(path)[0] != "." and os.path.isdir(path) | |
| ] | |
| track_paths, metadata = self.read_from_metadata_cache(track_paths, None) | |
| if read_metadata_procs <= 1: | |
| pbar = tqdm(track_paths) if should_print else track_paths | |
| for path in pbar: | |
| track_path, track_length = get_track_set_length( | |
| (path, self.instruments, self.file_types, self.dataset_type) | |
| ) | |
| metadata.append((track_path, track_length)) | |
| else: | |
| with ThreadPoolExecutor(max_workers=read_metadata_procs) as executor: | |
| futures = [ | |
| executor.submit(get_track_set_length, args) | |
| for args in zip( | |
| track_paths, | |
| itertools.repeat(self.instruments), | |
| itertools.repeat(self.file_types), | |
| itertools.repeat(self.dataset_type), | |
| ) | |
| ] | |
| if should_print: | |
| for f in tqdm(as_completed(futures), total=len(futures)): | |
| track_path, track_length = f.result() | |
| metadata.append((track_path, track_length)) | |
| else: | |
| for f in as_completed(futures): | |
| metadata.append(f.result()) | |
| elif self.dataset_type == 2: | |
| metadata = dict() | |
| for instr in self.instruments: | |
| metadata[instr] = [] | |
| track_paths = [] | |
| if type(self.data_path) == list: | |
| for tp in self.data_path: | |
| track_paths += sorted(glob(tp + "/{}/*.wav".format(instr))) | |
| track_paths += sorted(glob(tp + "/{}/*.flac".format(instr))) | |
| else: | |
| track_paths += sorted( | |
| glob(self.data_path + "/{}/*.wav".format(instr)) | |
| ) | |
| track_paths += sorted( | |
| glob(self.data_path + "/{}/*.flac".format(instr)) | |
| ) | |
| track_paths, metadata[instr] = self.read_from_metadata_cache( | |
| track_paths, instr | |
| ) | |
| if read_metadata_procs <= 1: | |
| pbar = tqdm(track_paths) if should_print else track_paths | |
| for path in pbar: | |
| length = sf.info(path).frames | |
| metadata[instr].append((path, length)) | |
| else: | |
| p = multiprocessing.Pool(processes=read_metadata_procs) | |
| track_iter = p.imap(get_track_length, track_paths) | |
| if should_print: | |
| track_iter = tqdm(track_iter, total=len(track_paths)) | |
| for out in track_iter: | |
| metadata[instr].append(out) | |
| p.close() | |
| elif self.dataset_type == 3: | |
| import pandas as pd | |
| if type(self.data_path) != list: | |
| data_path = [self.data_path] | |
| metadata = dict() | |
| for i in range(len(self.data_path)): | |
| if self.verbose and should_print: | |
| print("Reading tracks from: {}".format(self.data_path[i])) | |
| df = pd.read_csv(self.data_path[i]) | |
| skipped = 0 | |
| for instr in self.instruments: | |
| part = df[df["instrum"] == instr].copy() | |
| if should_print: | |
| print("Tracks found for {}: {}".format(instr, len(part))) | |
| for instr in self.instruments: | |
| part = df[df["instrum"] == instr].copy() | |
| metadata[instr] = [] | |
| track_paths = list(part["path"].values) | |
| track_paths, metadata[instr] = self.read_from_metadata_cache( | |
| track_paths, instr | |
| ) | |
| pbar = tqdm(track_paths) if should_print else track_paths | |
| for path in pbar: | |
| if not os.path.isfile(path): | |
| if should_print: | |
| print("Cant find track: {}".format(path)) | |
| skipped += 1 | |
| continue | |
| # print(path) | |
| try: | |
| length = sf.info(path).frames | |
| except: | |
| if should_print: | |
| print("Problem with path: {}".format(path)) | |
| skipped += 1 | |
| continue | |
| metadata[instr].append((path, length)) | |
| if skipped > 0 and should_print: | |
| print("Missing tracks: {} from {}".format(skipped, len(df))) | |
| else: | |
| if should_print: | |
| print( | |
| "Unknown dataset type: {}. Must be 1, 2, 3, 4, 5 or 6".format( | |
| self.dataset_type | |
| ) | |
| ) | |
| exit() | |
| # Save metadata | |
| pickle.dump(metadata, open(self.metadata_path, "wb")) | |
| return metadata | |
    def load_source(self, metadata, instr):
        """
        Load one random chunk for ``instr``, resampling tracks until the chunk
        is loud enough (mean abs >= self.min_mean_abs).

        For folder-based types (1/4/5/6/7) ``metadata`` is a list of
        (track_path, length); for types 2/3 it is a dict keyed by instrument.

        NOTE(review): for folder-based types, if the chosen track has no file
        for ``instr`` under any extension, ``source`` is never assigned and
        this raises UnboundLocalError - confirm missing stems cannot occur here.

        Returns:
            float32 tensor of shape (channels, chunk_size).
        """
        should_print = not dist.is_initialized() or dist.get_rank() == 0
        while True:
            if self.dataset_type in [1, 4, 5, 6, 7]:
                track_path, track_length = random.choice(metadata)
                for extension in self.file_types:
                    path_to_audio_file = track_path + "/{}.{}".format(instr, extension)
                    if os.path.isfile(path_to_audio_file):
                        try:
                            source = load_chunk(
                                path_to_audio_file, track_length, self.chunk_size
                            )
                        except Exception as e:
                            # Sometimes error during FLAC reading, catch it and use zero stem
                            if should_print:
                                print(
                                    "Error: {} Path: {}".format(e, path_to_audio_file)
                                )
                            source = np.zeros((2, self.chunk_size), dtype=np.float32)
                        break
            else:
                track_path, track_length = random.choice(metadata[instr])
                try:
                    source = load_chunk(track_path, track_length, self.chunk_size)
                except Exception as e:
                    # Sometimes error during FLAC reading, catch it and use zero stem
                    if should_print:
                        print("Error: {} Path: {}".format(e, track_path))
                    source = np.zeros((2, self.chunk_size), dtype=np.float32)
            if np.abs(source).mean() >= self.min_mean_abs:  # remove quiet chunks
                break
        if self.aug:
            source = self.augm_data(source, instr)
        return torch.tensor(source, dtype=torch.float32)
| def load_random_mix(self): | |
| res = [] | |
| for instr in self.instruments: | |
| s1 = self.load_source(self.metadata, instr) | |
| # Mixup augmentation. Multiple mix of same type of stems | |
| if self.aug: | |
| if "mixup" in self.config["augmentations"]: | |
| if self.config["augmentations"].mixup: | |
| mixup = [s1] | |
| for prob in self.config.augmentations.mixup_probs: | |
| if random.uniform(0, 1) < prob: | |
| s2 = self.load_source(self.metadata, instr) | |
| mixup.append(s2) | |
| mixup = torch.stack(mixup, dim=0) | |
| loud_values = np.random.uniform( | |
| low=self.config.augmentations.loudness_min, | |
| high=self.config.augmentations.loudness_max, | |
| size=(len(mixup),), | |
| ) | |
| loud_values = torch.tensor(loud_values, dtype=torch.float32) | |
| mixup *= loud_values[:, None, None] | |
| s1 = mixup.mean(dim=0, dtype=torch.float32) | |
| res.append(s1) | |
| res = torch.stack(res) | |
| return res | |
| def _load_chunk_by_offset(self, track_path, offset): | |
| """Load specific chunk by track path and offset""" | |
| should_print = not dist.is_initialized() or dist.get_rank() == 0 | |
| res = [] | |
| for instr in self.instruments: | |
| for extension in self.file_types: | |
| path_to_audio_file = track_path + "/{}.{}".format(instr, extension) | |
| if os.path.isfile(path_to_audio_file): | |
| try: | |
| # Get track length from metadata | |
| track_length = None | |
| for path, length in self.metadata: | |
| if path == track_path: | |
| track_length = length | |
| break | |
| if track_length is None: | |
| source = np.zeros((2, self.chunk_size), dtype=np.float32) | |
| else: | |
| source = load_chunk( | |
| path_to_audio_file, | |
| track_length, | |
| self.chunk_size, | |
| offset=offset, | |
| ) | |
| except Exception as e: | |
| if should_print: | |
| print("Error: {} Path: {}".format(e, path_to_audio_file)) | |
| source = np.zeros((2, self.chunk_size), dtype=np.float32) | |
| break | |
| else: | |
| source = np.zeros((2, self.chunk_size), dtype=np.float32) | |
| res.append(source) | |
| res = np.stack(res, axis=0) | |
| if self.aug: | |
| for i, instr in enumerate(self.instruments): | |
| res[i] = self.augm_data(res[i], instr) | |
| return torch.tensor(res, dtype=torch.float32) | |
    def load_aligned_data(self):
        """Pick one random track and return time-aligned chunks for all stems.

        A single random offset is shared by every instrument so the stems stay
        time-aligned.  If any stem chunk is quieter than ``self.min_mean_abs``,
        a new offset is tried (up to 10 attempts).  A pre-mixed ``mixture``
        file is loaded from the same offset when present; otherwise the mix is
        reconstructed as the sum of the stems.

        Returns:
            Tuple ``(stems, mix)`` of float32 tensors with shapes
            (num_instruments, 2, chunk_size) and (2, chunk_size).
        """
        track_path, track_length = random.choice(self.metadata)
        # Only rank 0 (or a non-distributed run) prints warnings.
        should_print = not dist.is_initialized() or dist.get_rank() == 0
        attempts = 10
        while attempts:
            if track_length >= self.chunk_size:
                # One shared offset keeps all stems aligned to each other.
                common_offset = np.random.randint(track_length - self.chunk_size + 1)
            else:
                # Track shorter than the chunk; offset=None is handed to
                # load_chunk (presumably it pads — confirm in load_chunk).
                common_offset = None
            res = []
            silent_chunks = 0
            for i in self.instruments:
                found = False
                for extension in self.file_types:
                    path_to_audio_file = f"{track_path}/{i}.{extension}"
                    if os.path.isfile(path_to_audio_file):
                        found = True
                        try:
                            source = load_chunk(
                                path_to_audio_file,
                                track_length,
                                self.chunk_size,
                                offset=common_offset,
                            )
                        except Exception as e:
                            # Unreadable stem: warn and substitute silence.
                            if should_print:
                                print(f"Error: {e} Path: {path_to_audio_file}")
                            source = np.zeros((2, self.chunk_size), dtype=np.float32)
                        break
                if not found:
                    # No file for this instrument in any extension: use silence.
                    source = np.zeros((2, self.chunk_size), dtype=np.float32)
                res.append(source)
                if np.abs(source).mean() < self.min_mean_abs:  # remove quiet chunks
                    silent_chunks += 1
            # Try to load a pre-rendered mixture from the same offset.
            mix = None
            for extension in self.file_types:
                path_to_mix_file = track_path + "/mixture.{}".format(extension)
                if os.path.isfile(path_to_mix_file):
                    try:
                        mix = load_chunk(
                            path_to_mix_file,
                            track_length,
                            self.chunk_size,
                            offset=common_offset,
                        )
                    except Exception as e:
                        # On failure mix stays None and is rebuilt from stems below.
                        if should_print:
                            print(
                                "Error loading mix: {} Path: {}".format(
                                    e, path_to_mix_file
                                )
                            )
                    break
            if silent_chunks == 0:
                # Every stem is loud enough — accept this offset.
                break
            attempts -= 1
            if attempts <= 0 and should_print:
                print("Attempts max!", track_path)
            if common_offset is None:
                # Short track: retrying cannot produce a different offset.
                break
        try:
            res = np.stack(res, axis=0)
        except Exception as e:
            # Stems with mismatched shapes cannot be stacked; fall back to silence.
            print(
                "Error during stacking stems: {} Track Length: {} Track path: {}".format(
                    str(e), track_length, track_path
                )
            )
            res = np.zeros(
                (len(self.instruments), 2, self.chunk_size), dtype=np.float32
            )
        if mix is None:
            # No mixture file (or it failed to load): sum the stems instead.
            mix = res.sum(0)
        if self.aug:
            for i, instr in enumerate(self.instruments):
                res[i] = self.augm_data(res[i], instr)
        return torch.tensor(res, dtype=torch.float32), torch.tensor(
            mix, dtype=torch.float32
        )
| def augm_data(self, source, instr): | |
| # source.shape = (2, 261120) - first channels, second length | |
| source_shape = source.shape | |
| applied_augs = [] | |
| if "all" in self.config["augmentations"]: | |
| augs = self.config["augmentations"]["all"] | |
| else: | |
| augs = dict() | |
| # We need to add to all augmentations specific augs for stem. And rewrite values if needed | |
| if instr in self.config["augmentations"]: | |
| for el in self.config["augmentations"][instr]: | |
| augs[el] = self.config["augmentations"][instr][el] | |
| # Channel shuffle | |
| if "channel_shuffle" in augs: | |
| if augs["channel_shuffle"] > 0: | |
| if random.uniform(0, 1) < augs["channel_shuffle"]: | |
| source = source[::-1].copy() | |
| applied_augs.append("channel_shuffle") | |
| # Random inverse | |
| if "random_inverse" in augs: | |
| if augs["random_inverse"] > 0: | |
| if random.uniform(0, 1) < augs["random_inverse"]: | |
| source = source[:, ::-1].copy() | |
| applied_augs.append("random_inverse") | |
| # Random polarity (multiply -1) | |
| if "random_polarity" in augs: | |
| if augs["random_polarity"] > 0: | |
| if random.uniform(0, 1) < augs["random_polarity"]: | |
| source = -source.copy() | |
| applied_augs.append("random_polarity") | |
| # Random pitch shift | |
| if "pitch_shift" in augs: | |
| if augs["pitch_shift"] > 0: | |
| if random.uniform(0, 1) < augs["pitch_shift"]: | |
| apply_aug = AU.PitchShift( | |
| min_semitones=augs["pitch_shift_min_semitones"], | |
| max_semitones=augs["pitch_shift_max_semitones"], | |
| p=1.0, | |
| ) | |
| source = apply_aug(samples=source, sample_rate=44100) | |
| applied_augs.append("pitch_shift") | |
| # Random seven band parametric eq | |
| if "seven_band_parametric_eq" in augs: | |
| if augs["seven_band_parametric_eq"] > 0: | |
| if random.uniform(0, 1) < augs["seven_band_parametric_eq"]: | |
| apply_aug = AU.SevenBandParametricEQ( | |
| min_gain_db=augs["seven_band_parametric_eq_min_gain_db"], | |
| max_gain_db=augs["seven_band_parametric_eq_max_gain_db"], | |
| p=1.0, | |
| ) | |
| source = apply_aug(samples=source, sample_rate=44100) | |
| applied_augs.append("seven_band_parametric_eq") | |
| # Random tanh distortion | |
| if "tanh_distortion" in augs: | |
| if augs["tanh_distortion"] > 0: | |
| if random.uniform(0, 1) < augs["tanh_distortion"]: | |
| apply_aug = AU.TanhDistortion( | |
| min_distortion=augs["tanh_distortion_min"], | |
| max_distortion=augs["tanh_distortion_max"], | |
| p=1.0, | |
| ) | |
| source = apply_aug(samples=source, sample_rate=44100) | |
| applied_augs.append("tanh_distortion") | |
| # Random MP3 Compression | |
| if "mp3_compression" in augs: | |
| if augs["mp3_compression"] > 0: | |
| if random.uniform(0, 1) < augs["mp3_compression"]: | |
| apply_aug = AU.Mp3Compression( | |
| min_bitrate=augs["mp3_compression_min_bitrate"], | |
| max_bitrate=augs["mp3_compression_max_bitrate"], | |
| backend=augs["mp3_compression_backend"], | |
| p=1.0, | |
| ) | |
| source = apply_aug(samples=source, sample_rate=44100) | |
| applied_augs.append("mp3_compression") | |
| # Random AddGaussianNoise | |
| if "gaussian_noise" in augs: | |
| if augs["gaussian_noise"] > 0: | |
| if random.uniform(0, 1) < augs["gaussian_noise"]: | |
| apply_aug = AU.AddGaussianNoise( | |
| min_amplitude=augs["gaussian_noise_min_amplitude"], | |
| max_amplitude=augs["gaussian_noise_max_amplitude"], | |
| p=1.0, | |
| ) | |
| source = apply_aug(samples=source, sample_rate=44100) | |
| applied_augs.append("gaussian_noise") | |
| # Random TimeStretch | |
| if "time_stretch" in augs: | |
| if augs["time_stretch"] > 0: | |
| if random.uniform(0, 1) < augs["time_stretch"]: | |
| apply_aug = AU.TimeStretch( | |
| min_rate=augs["time_stretch_min_rate"], | |
| max_rate=augs["time_stretch_max_rate"], | |
| leave_length_unchanged=True, | |
| p=1.0, | |
| ) | |
| source = apply_aug(samples=source, sample_rate=44100) | |
| applied_augs.append("time_stretch") | |
| # Possible fix of shape | |
| if source_shape != source.shape: | |
| source = source[..., : source_shape[-1]] | |
| # Random Reverb | |
| if "pedalboard_reverb" in augs: | |
| if augs["pedalboard_reverb"] > 0: | |
| if random.uniform(0, 1) < augs["pedalboard_reverb"]: | |
| room_size = random.uniform( | |
| augs["pedalboard_reverb_room_size_min"], | |
| augs["pedalboard_reverb_room_size_max"], | |
| ) | |
| damping = random.uniform( | |
| augs["pedalboard_reverb_damping_min"], | |
| augs["pedalboard_reverb_damping_max"], | |
| ) | |
| wet_level = random.uniform( | |
| augs["pedalboard_reverb_wet_level_min"], | |
| augs["pedalboard_reverb_wet_level_max"], | |
| ) | |
| dry_level = random.uniform( | |
| augs["pedalboard_reverb_dry_level_min"], | |
| augs["pedalboard_reverb_dry_level_max"], | |
| ) | |
| width = random.uniform( | |
| augs["pedalboard_reverb_width_min"], | |
| augs["pedalboard_reverb_width_max"], | |
| ) | |
| board = PB.Pedalboard( | |
| [ | |
| PB.Reverb( | |
| room_size=room_size, # 0.1 - 0.9 | |
| damping=damping, # 0.1 - 0.9 | |
| wet_level=wet_level, # 0.1 - 0.9 | |
| dry_level=dry_level, # 0.1 - 0.9 | |
| width=width, # 0.9 - 1.0 | |
| freeze_mode=0.0, | |
| ) | |
| ] | |
| ) | |
| source = board(source, 44100) | |
| applied_augs.append("pedalboard_reverb") | |
| # Random Chorus | |
| if "pedalboard_chorus" in augs: | |
| if augs["pedalboard_chorus"] > 0: | |
| if random.uniform(0, 1) < augs["pedalboard_chorus"]: | |
| rate_hz = random.uniform( | |
| augs["pedalboard_chorus_rate_hz_min"], | |
| augs["pedalboard_chorus_rate_hz_max"], | |
| ) | |
| depth = random.uniform( | |
| augs["pedalboard_chorus_depth_min"], | |
| augs["pedalboard_chorus_depth_max"], | |
| ) | |
| centre_delay_ms = random.uniform( | |
| augs["pedalboard_chorus_centre_delay_ms_min"], | |
| augs["pedalboard_chorus_centre_delay_ms_max"], | |
| ) | |
| feedback = random.uniform( | |
| augs["pedalboard_chorus_feedback_min"], | |
| augs["pedalboard_chorus_feedback_max"], | |
| ) | |
| mix = random.uniform( | |
| augs["pedalboard_chorus_mix_min"], | |
| augs["pedalboard_chorus_mix_max"], | |
| ) | |
| board = PB.Pedalboard( | |
| [ | |
| PB.Chorus( | |
| rate_hz=rate_hz, | |
| depth=depth, | |
| centre_delay_ms=centre_delay_ms, | |
| feedback=feedback, | |
| mix=mix, | |
| ) | |
| ] | |
| ) | |
| source = board(source, 44100) | |
| applied_augs.append("pedalboard_chorus") | |
| # Random Phazer | |
| if "pedalboard_phazer" in augs: | |
| if augs["pedalboard_phazer"] > 0: | |
| if random.uniform(0, 1) < augs["pedalboard_phazer"]: | |
| rate_hz = random.uniform( | |
| augs["pedalboard_phazer_rate_hz_min"], | |
| augs["pedalboard_phazer_rate_hz_max"], | |
| ) | |
| depth = random.uniform( | |
| augs["pedalboard_phazer_depth_min"], | |
| augs["pedalboard_phazer_depth_max"], | |
| ) | |
| centre_frequency_hz = random.uniform( | |
| augs["pedalboard_phazer_centre_frequency_hz_min"], | |
| augs["pedalboard_phazer_centre_frequency_hz_max"], | |
| ) | |
| feedback = random.uniform( | |
| augs["pedalboard_phazer_feedback_min"], | |
| augs["pedalboard_phazer_feedback_max"], | |
| ) | |
| mix = random.uniform( | |
| augs["pedalboard_phazer_mix_min"], | |
| augs["pedalboard_phazer_mix_max"], | |
| ) | |
| board = PB.Pedalboard( | |
| [ | |
| PB.Phaser( | |
| rate_hz=rate_hz, | |
| depth=depth, | |
| centre_frequency_hz=centre_frequency_hz, | |
| feedback=feedback, | |
| mix=mix, | |
| ) | |
| ] | |
| ) | |
| source = board(source, 44100) | |
| applied_augs.append("pedalboard_phazer") | |
| # Random Distortion | |
| if "pedalboard_distortion" in augs: | |
| if augs["pedalboard_distortion"] > 0: | |
| if random.uniform(0, 1) < augs["pedalboard_distortion"]: | |
| drive_db = random.uniform( | |
| augs["pedalboard_distortion_drive_db_min"], | |
| augs["pedalboard_distortion_drive_db_max"], | |
| ) | |
| board = PB.Pedalboard( | |
| [ | |
| PB.Distortion( | |
| drive_db=drive_db, | |
| ) | |
| ] | |
| ) | |
| source = board(source, 44100) | |
| applied_augs.append("pedalboard_distortion") | |
| # Random PitchShift | |
| if "pedalboard_pitch_shift" in augs: | |
| if augs["pedalboard_pitch_shift"] > 0: | |
| if random.uniform(0, 1) < augs["pedalboard_pitch_shift"]: | |
| semitones = random.uniform( | |
| augs["pedalboard_pitch_shift_semitones_min"], | |
| augs["pedalboard_pitch_shift_semitones_max"], | |
| ) | |
| board = PB.Pedalboard([PB.PitchShift(semitones=semitones)]) | |
| source = board(source, 44100) | |
| applied_augs.append("pedalboard_pitch_shift") | |
| # Random Resample | |
| if "pedalboard_resample" in augs: | |
| if augs["pedalboard_resample"] > 0: | |
| if random.uniform(0, 1) < augs["pedalboard_resample"]: | |
| target_sample_rate = random.uniform( | |
| augs["pedalboard_resample_target_sample_rate_min"], | |
| augs["pedalboard_resample_target_sample_rate_max"], | |
| ) | |
| board = PB.Pedalboard( | |
| [PB.Resample(target_sample_rate=target_sample_rate)] | |
| ) | |
| source = board(source, 44100) | |
| applied_augs.append("pedalboard_resample") | |
| # Random Bitcrash | |
| if "pedalboard_bitcrash" in augs: | |
| if augs["pedalboard_bitcrash"] > 0: | |
| if random.uniform(0, 1) < augs["pedalboard_bitcrash"]: | |
| bit_depth = random.uniform( | |
| augs["pedalboard_bitcrash_bit_depth_min"], | |
| augs["pedalboard_bitcrash_bit_depth_max"], | |
| ) | |
| board = PB.Pedalboard([PB.Bitcrush(bit_depth=bit_depth)]) | |
| source = board(source, 44100) | |
| applied_augs.append("pedalboard_bitcrash") | |
| # Random MP3Compressor | |
| if "pedalboard_mp3_compressor" in augs: | |
| if augs["pedalboard_mp3_compressor"] > 0: | |
| if random.uniform(0, 1) < augs["pedalboard_mp3_compressor"]: | |
| vbr_quality = random.uniform( | |
| augs["pedalboard_mp3_compressor_pedalboard_mp3_compressor_min"], | |
| augs["pedalboard_mp3_compressor_pedalboard_mp3_compressor_max"], | |
| ) | |
| board = PB.Pedalboard([PB.MP3Compressor(vbr_quality=vbr_quality)]) | |
| source = board(source, 44100) | |
| applied_augs.append("pedalboard_mp3_compressor") | |
| # print(applied_augs) | |
| return source | |