| import torch |
| from torch.utils.data import Dataset |
| import json |
| import torchaudio |
| import os |
| from typing import Optional, Dict, Any, List, Tuple |
| import pandas as pd |
| import warnings |
| import random |
| from pathlib import Path |
| from collections import defaultdict |
|
|
|
|
|
|
|
|
class TEARSDataset(Dataset):
    """
    TEARS dataset class that loads audio and associated metadata/responses.

    Each JSON entry must provide ``audio_path`` (relative to ``tears_root``)
    and ``speaker`` (a dict with an ``id`` key). Entries may also provide
    parallel ``prompts`` / ``responses`` lists.

    Args:
        json_path (str): Path to the JSON file containing TEARS data.
        tears_root (str): Root directory containing TEARS audio files.
        sample_rate (int, optional): Target sample rate for audio. Defaults to 16000.
        duration (float, optional): Target duration in seconds. Defaults to 3.0.
        normalize_audio (bool, optional): Whether to standardize audio to zero
            mean / unit variance. Defaults to True.
        augment (bool, optional): Whether to apply random waveform
            augmentations (see ``augment_audio``). Defaults to True.

    Each item is a dict containing:
        - audio_tensor: torch.Tensor of shape (1, target_samples), mono.
        - sid: speaker identifier (``sample['speaker']['id']``).
        - metadata: the sample's ``speaker`` dict.
        - prompt: a randomly selected prompt. This is None when prompts or
          responses are missing, or when their lengths differ.
        - answer: the response paired with ``prompt``, with newlines replaced
          by spaces and whitespace stripped, or None.
        - filename: str, path to the audio file.
    """

    def __init__(
        self,
        json_path: str,
        tears_root: str,
        sample_rate: int = 16000,
        duration: float = 3.0,
        normalize_audio: bool = True,
        augment: bool = True
    ):
        super().__init__()

        with open(json_path, 'r', encoding='utf-8') as f:
            self.data = json.load(f)

        self.tears_root = Path(tears_root)
        self.sample_rate = sample_rate
        self.duration = duration
        self.normalize_audio = normalize_audio
        # Every returned audio tensor is cropped or zero-padded to this length.
        self.target_samples = int(duration * sample_rate)
        self.augment = augment

    def __len__(self) -> int:
        return len(self.data)

    def augment_audio(self, waveform: torch.Tensor, sample_rate: int) -> torch.Tensor:
        """Apply a random non-empty subset of augmentations, in random order.

        Args:
            waveform: Audio tensor, (channels, time), sampled at ``sample_rate``.
            sample_rate: Sample rate of ``waveform``. The sox effects restore
                this rate afterwards.

        Returns:
            The augmented waveform at ``sample_rate``. Its length may differ
            from the input's because ``time_stretch`` changes the duration.
        """
        # NOTE(review): torchaudio.sox_effects is deprecated in recent
        # torchaudio releases. Confirm it exists in the pinned version.
        augmentation_choices = ['time_stretch', 'pitch_shift', 'add_noise', 'spec_aug']
        num_augs = random.randint(1, len(augmentation_choices))

        for aug in random.sample(augmentation_choices, num_augs):
            if aug == 'time_stretch':
                # sox `speed` changes tempo and pitch together. `rate`
                # resamples back to the working sample rate.
                rate = random.uniform(0.8, 1.25)
                effects = [['speed', str(rate)], ['rate', str(sample_rate)]]
                waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
                    waveform, sample_rate, effects=effects
                )
            elif aug == 'pitch_shift':
                # Shift by -4..+4 semitones (100 cents each). 0 is a no-op.
                cents = random.randint(-4, 4) * 100
                if cents:
                    effects = [['pitch', str(cents)], ['rate', str(sample_rate)]]
                    waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
                        waveform, sample_rate, effects=effects
                    )
            elif aug == 'add_noise':
                noise = torch.randn_like(waveform) * random.uniform(0.001, 0.015)
                waveform = waveform + noise
            elif aug == 'spec_aug':
                # Waveform-domain analogue of SpecAugment time masking.
                waveform = self._time_mask(waveform)

        return waveform

    @staticmethod
    def _time_mask(waveform: torch.Tensor, max_fraction: float = 0.1) -> torch.Tensor:
        """Zero a random contiguous span covering at most ``max_fraction`` of the signal."""
        num_samples = waveform.shape[-1]
        max_width = int(num_samples * max_fraction)
        if max_width < 1:
            return waveform
        width = random.randint(1, max_width)
        start = random.randint(0, num_samples - width)
        masked = waveform.clone()  # never modify the caller's tensor in place
        masked[..., start:start + width] = 0.0
        return masked

    def _load_audio(self, audio_path: str) -> torch.Tensor:
        """Load ``audio_path``, downmix to mono and resample to ``self.sample_rate``."""
        audio, sr = torchaudio.load(audio_path)
        if audio.shape[0] > 1:
            # Average the channels so the output matches the documented
            # (1, target_samples) shape.
            audio = audio.mean(dim=0, keepdim=True)
        if sr != self.sample_rate:
            audio = torchaudio.transforms.Resample(sr, self.sample_rate)(audio)
        return audio

    @staticmethod
    def _normalize(audio: torch.Tensor) -> torch.Tensor:
        """Standardize to zero mean / unit (unbiased) std over the whole tensor."""
        if audio.numel() < 2:
            # The unbiased std of fewer than two values is NaN, so skip.
            return audio
        return (audio - torch.mean(audio)) / (torch.std(audio) + 1e-8)

    def _fix_length(self, audio: torch.Tensor) -> torch.Tensor:
        """Crop a random window of ``target_samples``, or right-pad with zeros."""
        num_samples = audio.shape[-1]
        if num_samples >= self.target_samples:
            start = random.randint(0, num_samples - self.target_samples)
            return audio[:, start:start + self.target_samples]
        return torch.nn.functional.pad(audio, (0, self.target_samples - num_samples))

    @staticmethod
    def _pick_prompt_response(sample: Dict[str, Any]) -> Tuple[Optional[str], Optional[str]]:
        """Pick a random (prompt, response) pair, or (None, None) if unavailable."""
        prompts = sample.get('prompts', [])
        responses = sample.get('responses', [])
        if not prompts or not responses or len(prompts) != len(responses):
            return None, None
        choice = random.randrange(len(prompts))
        return prompts[choice], responses[choice].replace("\n", " ").strip()

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        sample = self.data[idx]
        audio_path = str(self.tears_root / sample['audio_path'])

        try:
            audio = self._load_audio(audio_path)
        except Exception as e:
            # Best-effort: an unreadable file yields silence instead of
            # aborting the whole epoch.
            warnings.warn(f"Error loading audio file {audio_path}: {str(e)}")
            audio = torch.zeros(1, self.target_samples)
        else:
            if self.augment:
                try:
                    audio = self.augment_audio(audio, self.sample_rate)
                except Exception as e:
                    # Keep the clean audio rather than replacing it with silence.
                    warnings.warn(f"Error augmenting audio file {audio_path}: {str(e)}")
            if self.normalize_audio:
                audio = self._normalize(audio)
            audio = self._fix_length(audio)

        prompt, response = self._pick_prompt_response(sample)

        return {
            'audio_tensor': audio,
            'sid': sample['speaker']['id'],
            'metadata': sample['speaker'],
            'prompt': prompt,
            'answer': response,
            'filename': str(audio_path)
        }

    @staticmethod
    def redistribute_speakers(
        json_paths: Dict[str, str],
        split_ratios: Dict[str, float],
        seed: int = 42
    ) -> Dict[str, List[Dict]]:
        """
        Redistribute speakers across splits according to given ratios.

        All samples from every file in ``json_paths`` are pooled. Speakers are
        shuffled reproducibly and assigned to splits, so no speaker appears in
        more than one split. Any speakers left over after flooring go to the
        first split in ``split_ratios``.

        Args:
            json_paths: Dict mapping split names to json file paths.
            split_ratios: Dict mapping split names to desired ratios. Ratios
                must be non-negative and should sum to 1.
            seed: Random seed for reproducibility. A private generator is
                used, so the global ``random`` state is left untouched.

        Returns:
            Dict mapping every split name in ``split_ratios`` to its list of
            samples. Splits that receive no speakers map to an empty list.

        Raises:
            ValueError: If ``split_ratios`` is empty, contains a negative
                ratio, or sums to more than 1.
        """
        if not split_ratios:
            raise ValueError("split_ratios must contain at least one split")
        if any(ratio < 0 for ratio in split_ratios.values()):
            raise ValueError(f"split ratios must be non-negative, got {split_ratios}")
        if sum(split_ratios.values()) > 1 + 1e-6:
            raise ValueError(f"split ratios must sum to at most 1, got {split_ratios}")

        # random.Random(seed).shuffle yields the same permutation as
        # random.seed(seed); random.shuffle(...). It does so without
        # reseeding the global RNG that drives augmentation and cropping.
        rng = random.Random(seed)

        speaker_samples = defaultdict(list)
        for path in json_paths.values():
            with open(path, 'r', encoding='utf-8') as f:
                for sample in json.load(f):
                    speaker_samples[sample['speaker']['id']].append(sample)

        all_speakers = list(speaker_samples)
        rng.shuffle(all_speakers)

        total_speakers = len(all_speakers)
        split_speakers = {
            split: int(ratio * total_speakers)
            for split, ratio in split_ratios.items()
        }

        # Flooring can leave a few speakers unassigned; give them to the first split.
        remainder = total_speakers - sum(split_speakers.values())
        if remainder > 0:
            split_speakers[next(iter(split_speakers))] += remainder

        # Pre-create every split so that zero-speaker splits still appear.
        new_splits = defaultdict(list, {split: [] for split in split_speakers})
        current_idx = 0
        for split, num_speakers in split_speakers.items():
            for speaker_id in all_speakers[current_idx:current_idx + num_speakers]:
                new_splits[split].extend(speaker_samples[speaker_id])
            current_idx += num_speakers

        return new_splits

    @staticmethod
    def save_splits(splits: Dict[str, List[Dict]], output_dir: str):
        """Save each split to ``<output_dir>/tears_dataset_<split>_with_responses.json``.

        The output directory (and parents) is created if needed. Existing
        files are overwritten.
        """
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        for split_name, samples in splits.items():
            output_path = output_dir / f"tears_dataset_{split_name}_with_responses.json"
            with open(output_path, 'w', encoding='utf-8') as f:
                json.dump(samples, f, indent=2)
|
|
|
|