import os

import librosa
import numpy as np
import torch
from torch.utils.data import Dataset
from transformers import AutoFeatureExtractor


class DemoDataset(Dataset):
    """Dataset pairing a single query audio with a set of demonstration audios."""

    def __init__(self, demonstration_paths, query_path, sample_rate=16000):
        self.sample_rate = sample_rate
        self.query_path = query_path

        # Accept either a single path or a list of paths
        if isinstance(demonstration_paths, str):
            self.demonstration_paths = [demonstration_paths]
        else:
            self.demonstration_paths = demonstration_paths

        # Load feature extractor
        self.feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")

        print(f'Number of demonstration audios: {len(self.demonstration_paths)}')
        print(f'Query audio: {self.query_path}')

    def load_pad(self, path, max_length=64000):
        """Load an audio file at the target sample rate and pad/trim it to max_length samples."""
        waveform, _ = librosa.load(path, sr=self.sample_rate)
        return self.pad(waveform, max_length)

    def pad(self, x, max_len=64000):
        """Pad (with zeros) or trim a 1-D waveform to a fixed length."""
        x_len = x.shape[0]
        if x_len >= max_len:
            return x[:max_len]
        pad_length = max_len - x_len
        return np.concatenate([x, np.zeros(pad_length, dtype=x.dtype)], axis=0)

    def __len__(self):
        return 1  # Only one query audio

    def __getitem__(self, idx):
        # Load the query audio as a 1-D float32 waveform; the wav2vec2 feature
        # extractor expects raw 1-D audio and returns batched tensors of shape (1, num_samples)
        query_waveform = self.load_pad(self.query_path).astype(np.float32)

        # Extract features for the query audio
        main_features = self.feature_extractor(
            query_waveform,
            sampling_rate=self.sample_rate,
            padding=True,
            return_attention_mask=True,
            return_tensors="pt"
        )

        # Extract features for each demonstration audio
        prompt_features = []
        for demo_path in self.demonstration_paths:
            demo_waveform = self.load_pad(demo_path).astype(np.float32)

            prompt_feature = self.feature_extractor(
                demo_waveform,
                sampling_rate=self.sample_rate,
                padding=True,
                return_attention_mask=True,
                return_tensors="pt"
            )
            prompt_features.append(prompt_feature)

        return {
            'main_features': main_features,
            'prompt_features': prompt_features,
            'file_name': os.path.basename(self.query_path),
            'file_path': self.query_path
        }


def collate_fn(batch):
    """
    Collate function for the dataloader.

    Args:
        batch: list of dictionaries, each containing:
            - main_features: feature extractor output for the query audio
            - prompt_features: list of feature extractor outputs, one per demonstration audio
            - file_name: query file name
            - file_path: query file path
    """
    # Concatenate main (query) features across the batch
    main_features_keys = batch[0]['main_features'].keys()
    main_features = {}
    for key in main_features_keys:
        main_features[key] = torch.cat([item['main_features'][key] for item in batch], dim=0)

    # Concatenate prompt (demonstration) features across the batch, one dict per prompt
    num_prompts = len(batch[0]['prompt_features'])
    prompt_features = []
    for i in range(num_prompts):
        prompt_feature = {}
        for key in main_features_keys:
            prompt_feature[key] = torch.cat([item['prompt_features'][i][key] for item in batch], dim=0)
        prompt_features.append(prompt_feature)

    # Collect file names and paths
    file_names = [item['file_name'] for item in batch]
    file_paths = [item['file_path'] for item in batch]

    return {
        'main_features': main_features,
        'prompt_features': prompt_features,
        'file_names': file_names,
        'file_paths': file_paths
    }


if __name__ == '__main__':
    # Test the dataset
    demo_paths = ["examples/demo1.wav", "examples/demo2.wav"]
    query_path = "examples/query.wav"

    dataset = DemoDataset(demo_paths, query_path)
    print(dataset[0])
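    # --- Usage sketch (not part of the original test; the example paths above are assumed to exist) ---
    # Wraps the dataset in a DataLoader with the collate_fn defined above. batch_size=1
    # matches __len__ == 1; the model that consumes these batches is not shown here.
    from torch.utils.data import DataLoader

    loader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=collate_fn)
    for batch in loader:
        # 'main_features' holds the batched query features, e.g. input_values of shape (1, num_samples);
        # 'prompt_features' is a list with one feature dict per demonstration audio
        print(batch['file_names'])
        print(batch['main_features']['input_values'].shape)
        print(len(batch['prompt_features']))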