# BEST-RQ-2 / audio-embeddings / src/data/yt1b_datamodule.py
# Author: ltuncay
# Submission to the Interspeech 2026 Audio Encoder Capability Challenge
# (commit eca55dc, verified)
import os
from functools import partial
from typing import Any, Dict, List, Optional, Union
import lightning as L
import numpy as np
import pandas as pd
import torch
import torchaudio
from torch.utils.data import DataLoader, Dataset
from src.data.audio_utils import DatasetResamplerCropper, collate_audio_batch
class YT1BDataset(Dataset):
    """
    Dataset for YT-Temporal-1B data using Parquet metadata files.

    Args:
        parquet_path (str): Path to the parquet file containing metadata (must include
            'file_path', 'video_id', 'duration_sec'). If a 'sample_rate' column exists,
            it is used to avoid probing files for source sample rate.
        min_duration (Optional[float]): Minimum duration in seconds to include a file.
        max_duration (Optional[float]): Maximum duration in seconds to include a file.
        transform (Optional[callable]): Optional transform to apply to the waveform.
        max_length (Optional[int]): Maximum length of the waveform in samples
            (at target_sample_rate).
        target_sample_rate (int): Target sample rate for the waveform. Defaults to 16000.
        decode_window_sec (Optional[float]): Optional decode window length in seconds. If None,
            defaults to max_length / target_sample_rate (when max_length is set).

    Raises:
        FileNotFoundError: If `parquet_path` does not exist.
        ImportError: If no parquet engine (pyarrow) is installed.
        ValueError: If required columns are missing or duration bounds are invalid.
    """

    def __init__(
        self,
        parquet_path: str,
        min_duration: Optional[float] = None,
        max_duration: Optional[float] = 30.0,
        transform: Optional[Any] = None,
        max_length: Optional[int] = None,
        target_sample_rate: int = 16000,
        decode_window_sec: Optional[float] = None,
    ):
        print(f"Loading metadata from {parquet_path}...")
        self.transform = transform
        self.max_length = max_length
        self.target_sample_rate = target_sample_rate
        self.decode_window_sec = decode_window_sec

        # --- Metadata Loading ---
        if not os.path.exists(parquet_path):
            raise FileNotFoundError(f"Parquet file not found at: {parquet_path}")
        # Pyarrow is required for read_parquet; chain the cause for easier debugging.
        try:
            df = pd.read_parquet(parquet_path)
        except ImportError as e:
            raise ImportError(
                "Please install pyarrow to read parquet files: `uv add pyarrow`"
            ) from e

        required_cols = {"file_path", "video_id", "duration_sec"}
        if not required_cols.issubset(df.columns):
            # Strictly enforcing the expected schema; some datasets may use other names.
            raise ValueError(
                f"Parquet file must contain columns: {required_cols}. Found: {df.columns.tolist()}"
            )

        # --- Duration filter validation ---
        if min_duration is not None and min_duration < 0:
            raise ValueError(f"min_duration must be >= 0, got {min_duration}")
        if max_duration is not None and max_duration < 0:
            raise ValueError(f"max_duration must be >= 0, got {max_duration}")
        if (
            min_duration is not None
            and max_duration is not None
            and min_duration > max_duration
        ):
            raise ValueError(
                "min_duration must be <= max_duration; "
                f"got min_duration={min_duration}, max_duration={max_duration}"
            )
        if min_duration is not None:
            df = df[df["duration_sec"] >= min_duration]
        if max_duration is not None:
            df = df[df["duration_sec"] <= max_duration]

        self.ids = df["video_id"].tolist()
        self.paths = df["file_path"].tolist()
        self.durations_sec = df["duration_sec"].tolist()

        # Optional per-file source sample rates; None entries fall back to probing
        # the file header in __getitem__.
        if "sample_rate" in df.columns:
            sample_rates = pd.to_numeric(df["sample_rate"], errors="coerce").to_numpy(
                dtype=np.float64
            )
            self.source_sample_rates: Optional[list[Optional[int]]] = [
                int(sr) if np.isfinite(sr) and sr > 0 else None for sr in sample_rates
            ]
        else:
            self.source_sample_rates = None

        self.length = len(self.ids)

        # --- Resampler ---
        # Uses the optimized class that caches resamplers
        self.resampler = DatasetResamplerCropper(
            target_sr=target_sample_rate, max_length=max_length
        )
        print(f"Dataset loaded. Length: {self.length:,}")

    def __len__(self) -> int:
        return self.length

    def _source_sample_rate(self, idx: int, audio_path: str) -> int:
        """Return the source sample rate for item `idx`, probing the file if unknown."""
        if self.source_sample_rates is not None:
            cached_sr = self.source_sample_rates[idx]
            if cached_sr is not None:
                return cached_sr
        # Decode a single frame only to obtain the header's sample rate.
        _, probed_sr = torchaudio.load(audio_path, frame_offset=0, num_frames=1)
        return probed_sr

    def _decode(self, idx: int, audio_path: str):
        """Decode the file (optionally a random window of it); returns (waveform, sr).

        May raise whatever the underlying decoder raises; the caller is expected
        to handle failures.
        """
        decode_window_sec = self.decode_window_sec
        if decode_window_sec is None and self.max_length is not None:
            decode_window_sec = self.max_length / self.target_sample_rate
        if decode_window_sec is None:
            # No window configured: decode the entire file.
            return torchaudio.load(audio_path)

        duration_sec = float(self.durations_sec[idx])
        if duration_sec <= 0:
            # Metadata duration is unreliable; fall back to a full decode.
            return torchaudio.load(audio_path)

        source_sr = self._source_sample_rate(idx, audio_path)
        total_frames = max(1, int(duration_sec * source_sr))
        max_decode_frames = max(1, int(decode_window_sec * source_sr))
        decode_frames = min(max_decode_frames, total_frames)
        if total_frames > decode_frames:
            # Random crop start. Uses torch's RNG rather than numpy: numpy's RNG
            # state is duplicated across forked DataLoader workers, so np.random
            # here would yield identical crop offsets in every worker. torch's
            # generator is re-seeded per worker by the DataLoader.
            max_start = total_frames - decode_frames
            frame_offset = int(torch.randint(0, max_start + 1, (1,)).item())
        else:
            frame_offset = 0
        return torchaudio.load(
            audio_path,
            frame_offset=frame_offset,
            num_frames=decode_frames,
        )

    def __getitem__(self, idx: int) -> Dict[str, Union[torch.Tensor, str, int]]:
        audio_path = self.paths[idx]
        audio_id = self.ids[idx]
        # Load waveform
        try:
            waveform, sr = self._decode(idx, audio_path)
        except Exception as e:
            print(f"Error loading {audio_path}: {e}")
            # Return a dummy silent waveform to prevent crash
            len_samples = (
                self.max_length if self.max_length else self.target_sample_rate
            )
            return {
                "waveform": torch.zeros(1, len_samples),
                "audio_name": audio_id,
                "index": idx,
                "error": True,
            }
        # Mix down to mono if necessary
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)
        # Resample and crop
        waveform = self.resampler(waveform, source_sr=sr)
        # Ensure channel dim exists [1, T] if resampler stripped it or returned [T]
        if waveform.ndim == 1:
            waveform = waveform.unsqueeze(0)
        if self.transform:
            waveform = self.transform(waveform)
        return {
            "waveform": waveform,
            "audio_name": audio_id,
            "index": idx,
        }
class YT1BDataModule(L.LightningDataModule):
    """
    LightningDataModule for YT-Temporal-1B.

    Args:
        data_dir (str): Root directory for data.
        train_parquet (str): Filename of training parquet file.
        val_parquet (str): Filename of validation parquet file.
        test_parquet (str): Filename of test parquet file.
        batch_size (int): Batch size for dataloaders.
        num_workers (int): Number of workers for dataloaders.
        pin_memory (bool): Whether to pin memory in dataloaders.
        max_audio_length_sec (Optional[float]): Maximum audio length in seconds.
        min_duration_sec (Optional[float]): Minimum audio duration in seconds to filter.
        max_duration_sec (Optional[float]): Maximum audio duration in seconds to filter.
        target_sample_rate (int): Target sample rate.
        collate_mode (str): 'pad' or 'truncate'.
        decode_window_sec (Optional[float]): Optional decode window length in seconds. If None,
            defaults to max_audio_length_sec.
    """

    def __init__(
        self,
        data_dir: str = "data/YT-Temporal-1B",
        train_parquet: str = "train_metadata.parquet",
        val_parquet: str = "val_metadata.parquet",
        test_parquet: str = "val_metadata.parquet",
        batch_size: int = 64,
        num_workers: int = 4,
        pin_memory: bool = True,
        max_audio_length_sec: Optional[float] = 10.0,
        min_duration_sec: Optional[float] = None,
        max_duration_sec: Optional[float] = 30.0,
        target_sample_rate: int = 16000,
        collate_mode: str = "pad",
        decode_window_sec: Optional[float] = None,
    ):
        super().__init__()
        self.save_hyperparameters()
        self.data_dir = data_dir
        self.train_parquet_path = os.path.join(data_dir, train_parquet)
        self.val_parquet_path = os.path.join(data_dir, val_parquet)
        self.test_parquet_path = os.path.join(data_dir, test_parquet)
        # Convert the maximum length from seconds to samples at the target rate.
        if max_audio_length_sec is not None:
            self.max_audio_length = int(max_audio_length_sec * target_sample_rate)
        else:
            self.max_audio_length = None
        self.train_dataset: Optional[YT1BDataset] = None
        self.val_dataset: Optional[YT1BDataset] = None
        self.test_dataset: Optional[YT1BDataset] = None

    def _build_dataset(self, parquet_path: str) -> YT1BDataset:
        """Construct a YT1BDataset for `parquet_path` from stored hyperparameters."""
        return YT1BDataset(
            parquet_path,
            min_duration=self.hparams["min_duration_sec"],
            max_duration=self.hparams["max_duration_sec"],
            max_length=self.max_audio_length,
            target_sample_rate=self.hparams["target_sample_rate"],
            decode_window_sec=self.hparams["decode_window_sec"],
        )

    def setup(self, stage: Optional[str] = None) -> None:
        """Create datasets for the requested stage (all stages when stage is None)."""
        if stage == "fit" or stage is None:
            if os.path.exists(self.train_parquet_path):
                self.train_dataset = self._build_dataset(self.train_parquet_path)
            if os.path.exists(self.val_parquet_path):
                self.val_dataset = self._build_dataset(self.val_parquet_path)
        # Bug fix: previously only `stage == "test"` reached this branch, so
        # setup(None) never initialized the test dataset, unlike the fit branch.
        if stage == "test" or stage is None:
            if os.path.exists(self.test_parquet_path):
                self.test_dataset = self._build_dataset(self.test_parquet_path)

    def _make_dataloader(self, dataset: YT1BDataset, shuffle: bool) -> DataLoader:
        """Shared DataLoader construction for all splits."""
        return DataLoader(
            dataset,
            batch_size=self.hparams["batch_size"],
            shuffle=shuffle,
            num_workers=self.hparams["num_workers"],
            pin_memory=self.hparams["pin_memory"],
            # Keep workers alive across epochs when multiprocessing is enabled.
            persistent_workers=self.hparams["num_workers"] > 0,
            collate_fn=partial(self.collate_fn, mode=self.hparams["collate_mode"]),
        )

    def train_dataloader(self) -> DataLoader:
        if not self.train_dataset:
            raise RuntimeError(
                f"Train dataset not initialized. File not found: {self.train_parquet_path}"
            )
        return self._make_dataloader(self.train_dataset, shuffle=True)

    def val_dataloader(self) -> DataLoader:
        if not self.val_dataset:
            # Often validation sets are missing in large scale pretraining or we use a subset of train
            # For now, raise strict error or return empty list (lightning supports empty list for no val)
            # Raising error is safer to debug configuration issues.
            raise RuntimeError(
                f"Val dataset not initialized. File not found: {self.val_parquet_path}"
            )
        return self._make_dataloader(self.val_dataset, shuffle=False)

    def test_dataloader(self) -> DataLoader:
        if not self.test_dataset:
            raise RuntimeError(
                f"Test dataset not initialized. File not found: {self.test_parquet_path}"
            )
        # Consistency fix: the test loader previously omitted persistent_workers
        # while train/val set it; the shared helper now applies it everywhere.
        return self._make_dataloader(self.test_dataset, shuffle=False)

    @staticmethod
    def collate_fn(batch: List[Dict[str, Any]], mode: str = "pad") -> Dict[str, Any]:
        """Drop items that failed to load, then delegate to collate_audio_batch."""
        # Filter out errors
        batch = [x for x in batch if not x.get("error", False)]
        if len(batch) == 0:
            raise RuntimeError("All items in batch failed to load.")
        return collate_audio_batch(
            batch=batch,
            waveform_key="waveform",
            mode=mode,
        )