"""Spectrogram dataset preparation: convert .wav files to mel spectrograms,
cache them on disk, and serve them through a PyTorch Dataset/DataLoader."""

import argparse
import json
import logging
import os
import pickle
import random
import time

import librosa
import numpy as np
import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader, Sampler
from torchvision import transforms
from tqdm import tqdm

def configure_logging():
    logging.basicConfig(
        level=logging.DEBUG,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[logging.StreamHandler()],
    )
    logging.info("Logging is set up.")
    print("Logging is set up.")

def parse_args():
    parser = argparse.ArgumentParser(description='Spectrogram Dataset Preparation')
    parser.add_argument('--config', type=str, required=True, help='Path to the config file')
    return parser.parse_args()

def load_config(config_path):
    logging.info(f"Loading configuration from {config_path}")
    print(f"Loading configuration from {config_path}")
    try:
        with open(config_path, 'r') as f:
            config = json.load(f)
        logging.info("Configuration loaded successfully")
        print("Configuration loaded successfully")
        return config
    except Exception as e:
        logging.error(f"Failed to load config file: {e}", exc_info=True)
        print(f"Failed to load config file: {e}")
        raise

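# Example config for reference. The key names below are the ones this script
# actually reads; the values are only illustrative defaults, not recommendations:
#
# {
#     "directory": "data/audio",
#     "output_directory": "data/spectrograms",
#     "sample_rate": 44100,
#     "n_fft": 2048,
#     "hop_length": 256,
#     "n_mels": 128,
#     "min_duration": 0.1,
#     "max_frames": 1024,
#     "batch_size": 16,
#     "augment": true,
#     "noise_snr": 10,
#     "pitch_steps": 2
# }
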
def validate_audio(y, sr, target_sr=44100, min_duration=0.1):
    logging.debug(f"Validating audio with sr={sr}, target_sr={target_sr}, min_duration={min_duration}")
    print(f"Validating audio with sr={sr}, target_sr={target_sr}, min_duration={min_duration}")
    if sr != target_sr:
        logging.warning(f"Resampling from {sr} to {target_sr}")
        print(f"Resampling from {sr} to {target_sr}")
        y = librosa.resample(y, orig_sr=sr, target_sr=target_sr)
    if len(y) < min_duration * target_sr:
        pad_length = int(min_duration * target_sr - len(y))
        y = np.pad(y, (0, pad_length), mode='constant')
        logging.info(f"Audio file padded with {pad_length} samples")
        print(f"Audio file padded with {pad_length} samples")
    return y, target_sr

def strip_silence(y, sr, top_db=20, pad_duration=0.1):
    logging.debug(f"Stripping silence with sr={sr}, top_db={top_db}, pad_duration={pad_duration}")
    print(f"Stripping silence with sr={sr}, top_db={top_db}, pad_duration={pad_duration}")
    y_trimmed, _ = librosa.effects.trim(y, top_db=top_db)
    pad_length = int(pad_duration * sr)
    y_padded = np.pad(y_trimmed, pad_length, mode='constant')
    return y_padded

def audio_to_spectrogram(file_path, n_fft=2048, hop_length=256, n_mels=128, target_sr=44100, min_duration=0.1):
    try:
        logging.info(f"Loading file: {file_path}")
        print(f"Loading file: {file_path}")
        y, sr = librosa.load(file_path, sr=None)
        logging.debug(f"Loaded file: {file_path} with sr={sr}")
        print(f"Loaded file: {file_path} with sr={sr}")
        y, sr = validate_audio(y, sr, target_sr, min_duration)
        y = strip_silence(y, sr)
    except Exception as e:
        logging.error(f"Error reading {file_path}: {e}", exc_info=True)
        print(f"Error reading {file_path}: {e}")
        return None

    y = librosa.util.normalize(y)
    S = librosa.feature.melspectrogram(y=y, sr=sr, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels)
    S_dB = librosa.power_to_db(S, ref=np.max)
    logging.debug(f"Generated spectrogram for file: {file_path}")
    print(f"Generated spectrogram for file: {file_path}")

    return S_dB

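# Rough shape sanity check (illustrative, ignoring the trim/pad steps above):
# with librosa's default center-padded frames, 1 s of 44.1 kHz audio at
# hop_length=256 yields 1 + 44100 // 256 = 173 frames, so S_dB.shape would be
# about (128, 173).
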
def validate_spectrogram(spectrogram, n_mels=128):
    logging.debug(f"Validating spectrogram with n_mels={n_mels}")
    print(f"Validating spectrogram with n_mels={n_mels}")
    if spectrogram.shape[0] != n_mels:
        raise ValueError(f"Spectrogram has incorrect number of mel bands: {spectrogram.shape[0]}")
    if spectrogram.shape[1] == 0:
        raise ValueError("Spectrogram has zero frames")
    return True

def save_spectrogram(spectrogram, save_path):
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    np.save(save_path, spectrogram)
    logging.debug(f"Spectrogram saved at {save_path}")
    print(f"Spectrogram saved at {save_path}")

class AddNoise(torch.nn.Module):
    def __init__(self, noise_type='white', snr=10):
        super(AddNoise, self).__init__()
        # noise_type is currently unused; only white noise is generated.
        self.noise_type = noise_type
        self.snr = snr

    def forward(self, waveform):
        # Scale white noise so that 20 * log10(||signal|| / ||noise||) == self.snr.
        noise = torch.randn_like(waveform)
        signal_power = waveform.norm(p=2)
        noise_power = noise.norm(p=2)
        noise = noise * (signal_power / noise_power) / (10 ** (self.snr / 20))
        return waveform + noise

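# Minimal sanity check for AddNoise (hypothetical helper, not called by the
# pipeline): the scaling in forward() targets a fixed SNR in dB.
def _demo_add_noise():
    torch.manual_seed(0)
    waveform = torch.sin(torch.linspace(0, 2 * np.pi * 440, 44100))  # dummy 1 s, 440 Hz tone
    noisy = AddNoise(snr=10)(waveform)
    residual = noisy - waveform
    snr_db = 20 * torch.log10(waveform.norm(p=2) / residual.norm(p=2))
    print(f"Measured SNR: {snr_db.item():.1f} dB")  # should print ~10.0 dB
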
class SpectrogramDataset(Dataset):
    def __init__(self, config, directory, process_new=True):
        logging.info("Initializing SpectrogramDataset...")
        print("Initializing SpectrogramDataset...")
        self.directory = directory
        self.output_directory = config['output_directory']
        self.spectrograms = []
        self.labels = []
        self.label_to_index = {}
        self.process_new = process_new
        self.config = config

        self.cache_path = os.path.join(self.output_directory, 'cache_data.npy')
        self.dataset_path = os.path.join(self.output_directory, 'spectrogram_dataset.pkl')

        if os.path.exists(self.dataset_path):
            self.load_dataset()
        else:
            if os.path.exists(self.cache_path):
                os.remove(self.cache_path)
                logging.info(f"Cache cleared at {self.cache_path}")
                print(f"Cache cleared at {self.cache_path}")

            self.load_data()
            self.save_dataset()

        # SpecAugment-style masking, applied in the spectrogram domain.
        self.transforms = transforms.Compose([
            torchaudio.transforms.FrequencyMasking(freq_mask_param=30),
            torchaudio.transforms.TimeMasking(time_mask_param=30)
        ]) if self.config['augment'] else None

        # Caveat: AddNoise and PitchShift are waveform-domain transforms, yet
        # __getitem__ applies them to dB mel spectrograms; treat this branch
        # of the augmentation as experimental.
        self.audio_transforms = torch.nn.Sequential(
            AddNoise(snr=self.config['noise_snr']),
            torchaudio.transforms.PitchShift(self.config['sample_rate'], n_steps=self.config['pitch_steps'])
        ) if self.config['augment'] else None
        logging.info("SpectrogramDataset initialized successfully")
        print("SpectrogramDataset initialized successfully")

    def save_dataset(self):
        with open(self.dataset_path, 'wb') as f:
            pickle.dump(self, f)
        logging.info(f"Dataset object saved at {self.dataset_path}")
        print(f"Dataset object saved at {self.dataset_path}")

    def load_dataset(self):
        with open(self.dataset_path, 'rb') as f:
            obj = pickle.load(f)
        self.__dict__.update(obj.__dict__)
        logging.info(f"Dataset object loaded from {self.dataset_path}")
        print(f"Dataset object loaded from {self.dataset_path}")

    def process_file(self, file_path):
        logging.debug(f"Processing file: {file_path}")
        print(f"Processing file: {file_path}")
        try:
            # The label is the name of the file's parent directory.
            label = os.path.basename(os.path.dirname(file_path))
            if label not in self.label_to_index:
                self.label_to_index[label] = len(self.label_to_index)
            relative_path = os.path.relpath(file_path, self.directory)
            spectrogram_path = os.path.join(self.output_directory, os.path.splitext(relative_path)[0] + '_spectrogram.npy')
            if not os.path.exists(spectrogram_path) and self.process_new:
                spectrogram = audio_to_spectrogram(
                    file_path,
                    n_fft=self.config['n_fft'],
                    hop_length=self.config['hop_length'],
                    n_mels=self.config['n_mels'],
                    target_sr=self.config['sample_rate'],
                    min_duration=self.config['min_duration'],
                )
                if spectrogram is not None:
                    # Truncate overly long spectrograms to a fixed frame budget.
                    if spectrogram.shape[1] > self.config['max_frames']:
                        spectrogram = spectrogram[:, :self.config['max_frames']]
                    try:
                        validate_spectrogram(spectrogram, n_mels=self.config['n_mels'])
                        save_spectrogram(spectrogram, spectrogram_path)
                        logging.debug(f"Spectrogram saved: {spectrogram_path}")
                        print(f"Spectrogram saved: {spectrogram_path}")
                    except Exception as e:
                        logging.error(f"Error validating/saving spectrogram: {e}", exc_info=True)
                        print(f"Error validating/saving spectrogram: {e}")
            if os.path.exists(spectrogram_path):
                try:
                    spectrogram = np.load(spectrogram_path)
                    validate_spectrogram(spectrogram, n_mels=self.config['n_mels'])
                    spectrogram_tensor = torch.tensor(spectrogram, dtype=torch.float32)
                    self.spectrograms.append(spectrogram_tensor)
                    self.labels.append(self.label_to_index[label])
                    logging.debug(f"Spectrogram loaded and appended for file: {file_path}")
                    print(f"Spectrogram loaded and appended for file: {file_path}")
                except Exception as e:
                    logging.error(f"Error loading spectrogram {spectrogram_path}: {e}", exc_info=True)
                    print(f"Error loading spectrogram {spectrogram_path}: {e}")
        except Exception as e:
            logging.error(f"Exception in process_file: {e}", exc_info=True)
            print(f"Exception in process_file: {e}")

    def load_data(self):
        start_time = time.time()
        logging.info("Starting to load and process files...")
        print("Starting to load and process files...")
        files_to_process = [
            os.path.join(root, file)
            for root, _, files in os.walk(self.directory)
            for file in files if file.lower().endswith('.wav')
        ]
        total_files = len(files_to_process)
        logging.info(f"Total files to process: {total_files}")
        print(f"Total files to process: {total_files}")

        for file_path in tqdm(files_to_process, desc="Processing files"):
            self.process_file(file_path)

        end_time = time.time()
        logging.info(f"Data loading and processing took {end_time - start_time:.2f} seconds")
        print(f"Data loading and processing took {end_time - start_time:.2f} seconds")

        self.save_cached_data(self.cache_path)

    def save_cached_data(self, cache_path):
        os.makedirs(os.path.dirname(cache_path), exist_ok=True)
        # np.save pickles this dict of tensor lists into a single .npy file.
        np.save(cache_path, {'spectrograms': self.spectrograms, 'labels': self.labels})
        logging.debug(f"Cached data saved at {cache_path}")
        print(f"Cached data saved at {cache_path}")

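    # To restore this cache later (illustrative; because np.save pickled the
    # dict, allow_pickle=True is required on load):
    #
    #     cached = np.load(cache_path, allow_pickle=True).item()
    #     spectrograms, labels = cached['spectrograms'], cached['labels']
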
    def __len__(self):
        return len(self.spectrograms)

    def __getitem__(self, idx):
        spectrogram, label = self.spectrograms[idx], self.labels[idx]
        if self.config['augment']:
            # Noise/pitch augmentation is only applied to spectrograms with at
            # least 256 frames; masking is applied to every item.
            if spectrogram.shape[1] >= 256:
                spectrogram = self.audio_transforms(spectrogram.unsqueeze(0)).squeeze(0)
            spectrogram = self.transforms(spectrogram.unsqueeze(0)).squeeze(0)
        return spectrogram, label

def collate_fn(batch):
    spectrograms, labels = zip(*batch)
    labels = torch.tensor(labels, dtype=torch.long)
    max_length = max(s.size(1) for s in spectrograms)
    max_freq = max(s.size(0) for s in spectrograms)
    # Zero-pad every spectrogram up to the largest (freq, time) shape in the batch.
    spectrograms_padded = torch.zeros(len(spectrograms), max_freq, max_length)
    for i, s in enumerate(spectrograms):
        # Drop a trailing singleton channel dimension if one slipped through.
        if s.dim() == 3 and s.size(2) == 1:
            s = s.squeeze(2)
        spectrograms_padded[i, :s.size(0), :s.size(1)] = s
    return spectrograms_padded, labels

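# Illustrative check of the padding behaviour (hypothetical helper with
# made-up shapes, not called by the pipeline):
def _demo_collate():
    batch = [(torch.ones(128, 300), 0), (torch.ones(128, 180), 1)]
    padded, labels = collate_fn(batch)
    print(padded.shape)  # torch.Size([2, 128, 300]); the shorter item is zero-padded
    print(labels)        # tensor([0, 1])
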
class SmartBatchingSampler(Sampler):
    def __init__(self, data_source, batch_size):
        self.data_source = data_source
        self.batch_size = batch_size

    def __iter__(self):
        # Sort indices by spectrogram length so each batch groups similarly
        # sized items, then shuffle whole batches so epochs still vary.
        sorted_indices = sorted(range(len(self.data_source)),
                                key=lambda i: self.data_source[i][0].shape[1], reverse=True)
        # The final slice already contains any partial batch, so every index
        # is yielded exactly once per epoch.
        pooled_indices = [sorted_indices[i:i + self.batch_size] for i in range(0, len(sorted_indices), self.batch_size)]
        random.shuffle(pooled_indices)
        for p in pooled_indices:
            yield from p

    def __len__(self):
        # A per-sample sampler reports how many indices it yields per epoch.
        return len(self.data_source)

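# Illustrative ordering check for the sampler (hypothetical helper with a
# tiny fake dataset, not called by the pipeline):
def _demo_sampler():
    fake = [(torch.ones(128, n), 0) for n in (10, 50, 30, 20)]
    order = list(SmartBatchingSampler(fake, batch_size=2))
    print(order)  # either [1, 2, 3, 0] or [3, 0, 1, 2]: length-sorted pairs, batch order shuffled
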
if __name__ == '__main__':
    print("Starting script")
    try:
        # Configure logging first so that messages emitted while loading the
        # config are not silently dropped.
        configure_logging()
        print("Logging configured")

        args = parse_args()
        print(f"Arguments parsed: {args}")
        config = load_config(args.config)
        print(f"Config loaded: {config}")

        logging.info("Script started.")
        dataset = SpectrogramDataset(config, config['directory'], process_new=True)
        dataloader = DataLoader(
            dataset,
            batch_size=config['batch_size'],
            collate_fn=collate_fn,
            sampler=SmartBatchingSampler(dataset, config['batch_size']),
        )
        for batch in dataloader:
            spectrograms, labels = batch
            logging.info(f"Spectrograms batch shape: {spectrograms.shape}")
            logging.info(f"Labels batch shape: {labels.shape}")
            print(f"Spectrograms batch shape: {spectrograms.shape}")
            print(f"Labels batch shape: {labels.shape}")
            break

        logging.info(f"Total files processed: {len(dataset)}")
        print(f"Total files processed: {len(dataset)}")
    except Exception as e:
        logging.error(f"Exception occurred: {e}", exc_info=True)
        print(f"Exception occurred: {e}")
    finally:
        logging.info("Script ended.")
        print("Script ended")