| import torch |
| import torchaudio |
| import torchaudio.transforms as T |
| import torch.nn.functional as F |
| import torchaudio.functional as AF |
|
|
| import numpy as np |
| import pandas as pd |
| import matplotlib.pyplot as plt |
| from pathlib import Path |
| import random |
|
|
| import noisereduce as nr |
| import librosa |
|
|
|
|
| import scipy |
|
|
| import pickle |
| import os |
| from tqdm import tqdm |
|
|
|
|
class Load:
    """Loads an audio signal into memory in normalized form"""

    def load(self, file_path):
        """Read an audio file via torchaudio.

        Returns a ``(signal, sample_rate)`` pair where ``signal`` is a
        float tensor laid out channels-first and normalized by torchaudio.
        """
        waveform, rate = torchaudio.load(
            file_path,
            channels_first=True,
            normalize=True,
        )
        return waveform, rate
|
|
class StereoToMono:
    """Applies mapping from stereo to mono"""

    def stereo_to_mono(self, stereo_signal):
        """Collapse the channel axis (dim 0) by averaging.

        ``keepdim=True`` preserves a singleton channel dimension so the
        result is still shaped ``(1, num_samples)``.
        """
        averaged = stereo_signal.mean(dim=0, keepdim=True)
        return averaged
| |
class Resample:
    """Applies resampling onto a signal"""

    def __init__(self):
        # Last sample rates seen by resample(); kept for inspection/debugging.
        self.sr_in = None
        self.sr_out = None

    def resample(self, signal, sr_in, sr_out, debug=True):
        """Resample ``signal`` from ``sr_in`` to ``sr_out``.

        Returns ``(signal, sr_out)``; the signal is returned untouched when
        the rates already match. ``debug`` now gates *all* progress messages
        (previously the 'Resampling...' message printed unconditionally),
        and the 'remsampling' typo in the no-op message is fixed.
        """
        self.sr_in = sr_in
        self.sr_out = sr_out
        if sr_in == sr_out:
            if debug:
                print('No resampling needed')
            return signal, sr_out
        if debug:
            print('Resampling the signal...')
        resampler = torchaudio.transforms.Resample(orig_freq=self.sr_in, new_freq=self.sr_out)
        return resampler(signal), sr_out
|
|
class NoiseRemover:
    """Wraps noisereduce to denoise a single-channel torch signal."""

    def __init__(self):
        # Cached state from the most recent remove_noise() call.
        self._sr = None
        self._signal = None
        self._denoised_signal = None

    def remove_noise(self, signal, sr):
        """Denoise a ``(1, num_samples)`` tensor at sample rate ``sr``.

        Returns ``(denoised_signal, sr)`` where the denoised signal is a
        tensor with the channel dimension restored.
        """
        self._sr = sr
        raw = signal.squeeze(0).numpy()
        self._signal = raw
        cleaned = nr.reduce_noise(y=raw, sr=sr)
        self._denoised_signal = torch.tensor(cleaned).unsqueeze(0)
        return self._denoised_signal, sr
| |
|
|
class TruncateOrPad:
    """Dynamically truncates or pads depending on the signal"""

    def __init__(self, max_duration: int, sr_out: int = 16_000):
        self.max_duration = max_duration
        self.sr_out = sr_out
        # Target length in samples for every output signal.
        self.tot_samples_expected = sr_out * max_duration

    def truncate_or_pad(self, signal, debug=True):
        """Force ``signal`` to exactly ``tot_samples_expected`` samples.

        Longer signals are truncated from the end; shorter ones are
        right-padded with zeros. ``debug`` now gates *all* progress
        messages (previously only the no-op branch honored it).
        """
        tot_samples = signal.shape[-1]
        if tot_samples == self.tot_samples_expected:
            if debug:
                print('Signal already at max duration')
            return signal
        if tot_samples > self.tot_samples_expected:
            if debug:
                print('Truncating the signal')
            return self._truncate(signal)
        if debug:
            print('Padding the signal')
        return self._pad(signal)

    def _truncate(self, signal):
        # Keep only the leading tot_samples_expected samples.
        return signal[..., :self.tot_samples_expected]

    def _pad(self, signal):
        # Right-pad the time axis with zeros up to the expected length.
        pad_amount = self.tot_samples_expected - signal.shape[-1]
        return F.pad(signal, (0, pad_amount))
|
|
class FeatureExtractor:
    """Extracts features: linear, log spectrograms, mel spectrograms"""

    def __init__(self, n_fft=1024, hop_length=256, sr=16000, n_mels=80):
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.sr = sr
        self.n_mels = n_mels
        self._window = torch.hann_window(n_fft)
        # Transforms are created lazily and cached: the original rebuilt
        # MelSpectrogram (including its mel filterbank) and AmplitudeToDB
        # on every call, which is wasteful when processing many files.
        self._mel_transform = None
        self._db_transform = None

    def stft_spec(self, signal):
        """Complex STFT of ``signal`` (Hann window, centered frames)."""
        return torch.stft(
            signal,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            window=self._window.to(device=signal.device, dtype=signal.dtype),
            center=True,
            return_complex=True
        )

    def linear_mag(self, signal):
        """stft -> abs"""
        return self.stft_spec(signal).abs()

    def linear_power(self, signal):
        """stft -> abs -> **2"""
        return self.linear_mag(signal).pow(2)

    def mel_scale(self, signal):
        """Mel spectrogram (power)"""
        if self._mel_transform is None:
            self._mel_transform = torchaudio.transforms.MelSpectrogram(
                sample_rate=self.sr,
                n_fft=self.n_fft,
                hop_length=self.hop_length,
                n_mels=self.n_mels,
                center=True,
                power=2.0
            )
        return self._mel_transform(signal)

    def log_mag(self, signal, eps=1e-10):
        """Magnitude spectrogram in dB; eps avoids log10(0)."""
        return 20 * torch.log10(self.linear_mag(signal) + eps)

    def log_power(self, signal, eps=1e-10):
        """Power spectrogram in dB; eps avoids log10(0)."""
        return 10 * torch.log10(self.linear_power(signal) + eps)

    def log_mel_scale(self, signal):
        """Log-mel spectrogram for classification"""
        if self._db_transform is None:
            self._db_transform = torchaudio.transforms.AmplitudeToDB(top_db=80)
        return self._db_transform(self.mel_scale(signal))
|
|
|
|
class NormalizeFeatures:
    """Feature normalization helpers."""

    @staticmethod
    def min_max_normalize(mel: torch.Tensor):
        """Rescale ``mel`` into [0, 1].

        Returns ``(mel_norm, min_val, max_val)`` so the original scale can
        be recovered later. A small epsilon in the denominator guards
        against division by zero on constant inputs.
        """
        lo, hi = mel.min(), mel.max()
        scaled = (mel - lo) / (hi - lo + 1e-8)
        return scaled, lo, hi
|
|
|
|
class BirdDatasetSaver:
    """Persists extracted features to disk, grouped by bird category.

    Layout: ``save_dir/<category>/classification/<stem>_logmel.pt`` and
    ``save_dir/<category>/generation/<stem>_mel.pt``.
    """

    def __init__(self, save_dir):
        self.save_dir = save_dir
        os.makedirs(save_dir, exist_ok=True)

    def save(self, bird_category: str, audio_file_name: str, log_mel: torch.Tensor, mel_norm: torch.Tensor):
        """Write the two feature tensors for one audio file."""
        base = Path(self.save_dir) / bird_category
        stem = Path(audio_file_name).stem
        # (subdirectory, tensor, filename suffix) per output kind.
        outputs = (
            ("classification", log_mel, "_logmel.pt"),
            ("generation", mel_norm, "_mel.pt"),
        )
        for subdir, tensor, suffix in outputs:
            target_dir = base / subdir
            target_dir.mkdir(parents=True, exist_ok=True)
            torch.save(tensor, target_dir / f"{stem}{suffix}")
|
|
|
|
class PreprocessingPipeline:
    """End-to-end preprocessing: load -> mono -> resample -> fixed length -> features -> save."""

    def __init__(self, save_dir, max_duration=4, sr_out=22050, n_fft=1024, hop_length=256, n_mels=80, debug = False):
        self.loader = Load()
        self.stereo2mono = StereoToMono()
        self.resampler = Resample()
        self.truncate_pad = TruncateOrPad(max_duration=max_duration, sr_out=sr_out)
        self.fe = FeatureExtractor(n_fft=n_fft, hop_length=hop_length, sr=sr_out, n_mels=n_mels)
        self.normer = NormalizeFeatures()
        self.saver = BirdDatasetSaver(save_dir)
        self.sr_out = sr_out
        self.debug = debug

    def process_file(self, bird_category, audio_file_path):
        """Run the full preprocessing chain on one audio file and persist both features."""
        file_name = Path(audio_file_path).name

        signal, sr = self.loader.load(audio_file_path)
        signal = self.stereo2mono.stereo_to_mono(signal)
        signal, sr = self.resampler.resample(signal, sr, self.sr_out, self.debug)
        signal = self.truncate_pad.truncate_or_pad(signal, self.debug)

        # Two feature variants: log-mel for classification, min-max-normalized
        # mel for generation.
        log_mel = self.fe.log_mel_scale(signal)
        mel_norm = self.normer.min_max_normalize(self.fe.mel_scale(signal))[0]

        self.saver.save(bird_category, file_name, log_mel, mel_norm)

    def process_dataset(self, root_dir):
        """Process every .wav under root_dir/<category>/ directories."""
        for bird_category in tqdm(os.listdir(root_dir)):
            category_path = os.path.join(root_dir, bird_category)
            if not os.path.isdir(category_path):
                continue
            wav_names = (name for name in os.listdir(category_path) if name.endswith(".wav"))
            for audio_file in wav_names:
                self.process_file(bird_category, os.path.join(category_path, audio_file))
|
|