# BEST-RQ-2 / audio-embeddings / scripts / verify_cropping.py
# Uploaded by ltuncay (commit eca55dc, verified)
# Submission to the Interspeech 2026 Audio Encoder Capability Challenge
import torch
import numpy as np
from src.data.audioset_datamodule import AudioSetDataset
# Mock Dataset inheriting from AudioSetDataset to test logic without H5
class MockAudioSetDataset(AudioSetDataset):
    """In-memory stand-in for ``AudioSetDataset`` used to exercise cropping.

    The parent constructor is never called, so no H5 file is touched. Item
    ``idx`` is a synthetic clip of ``lengths[idx]`` float32 samples whose
    values equal their own sample index (0, 1, ..., L-1). The first value of
    a returned clip therefore tells you where a crop started.

    Args:
        lengths: Original clip length, in samples, for every index.
        max_length: Crop target. Clips longer than this are randomly cropped
            to exactly ``max_length`` samples. ``None`` disables cropping.
    """

    def __init__(self, lengths, max_length=None):
        # Deliberately skips super().__init__(). The parent presumably opens
        # the H5 file, which this mock must avoid (TODO confirm).
        self.lengths = lengths
        self.max_length = max_length
        self.transform = None
        self.valid_indices = list(range(len(lengths)))
        self.h5_file = None  # never opened: _open_h5 is a no-op below

    def _open_h5(self):
        """Do nothing: there is no H5 file behind this mock."""

    def __getitem__(self, idx):
        """Return a dict shaped like a real item, with a possibly cropped clip.

        The waveform is a ``(1, n)`` float32 tensor. The other fields are
        placeholders.
        """
        n_samples = self.lengths[idx]
        signal = np.arange(n_samples, dtype=np.float32)

        # Same random-crop rule as AudioSetDataset (per the comment above the
        # class). The start is uniform over [0, len - limit] inclusive,
        # because randint's upper bound is exclusive. Uses the global
        # np.random state.
        limit = self.max_length
        if limit is not None and signal.shape[0] > limit:
            offset = np.random.randint(0, signal.shape[0] - limit + 1)
            signal = signal[offset : offset + limit]

        return {
            "waveform": torch.from_numpy(signal).unsqueeze(0),
            # All-zero 527-dim target. Presumably one slot per AudioSet
            # class; confirm against the real dataset.
            "target": torch.zeros(527),
            "audio_name": f"audio_{idx}",
            "index": idx,
        }
def test_random_cropping():
    """Check that ``MockAudioSetDataset`` random-crops long clips correctly.

    Clips of 50, 100, 150 and 200 samples are each drawn five times with
    ``max_length=100``. Every draw must:

    * have exactly ``min(orig_len, 100)`` samples;
    * start at an offset in ``[0, orig_len - 100]``, which means offset 0
      for clips that need no cropping;
    * be a contiguous slice of the original clip. The mock fills samples
      with their own index, so a valid crop reads ``start, start + 1, ...``.

    If a clip that could be cropped at more than five offsets returns the
    same offset on every draw, a warning is printed. This is not a failure.

    Raises:
        AssertionError: Lists every failed check. The script then exits
            non-zero and pytest reports a failure instead of passing silently.
    """
    max_len = 100
    lengths = [50, 100, 150, 200]
    n_trials = 5  # draws per clip, to observe the randomness of the crop start
    dataset = MockAudioSetDataset(lengths, max_length=max_len)
    print(f"Testing with max_length={max_len}")

    failures = []

    def fail(msg):
        # Report immediately (as before) and remember the failure for the
        # final assertion.
        print(f"FAIL: {msg}")
        failures.append(msg)

    for i, orig_len in enumerate(lengths):
        expected_len = min(orig_len, max_len)
        max_start = orig_len - expected_len
        starts = []
        # Check every draw, not only the last one.
        for _ in range(n_trials):
            wave = dataset[i]["waveform"]
            got_len = wave.shape[-1]
            if got_len != expected_len:
                if orig_len > max_len:
                    fail(f"Index {i} should be cropped to {max_len}, got {got_len}")
                else:
                    fail(f"Index {i} should be {orig_len}, got {got_len}")
                continue

            # The first sample value equals the crop's start offset.
            start = int(wave[0, 0].item())
            starts.append(start)
            if not 0 <= start <= max_start:
                fail(
                    f"Index {i} (orig {orig_len}) crop starts at {start}, "
                    f"outside [0, {max_start}]"
                )
                continue
            expected = torch.arange(start, start + expected_len, dtype=wave.dtype)
            if not torch.equal(wave[0], expected):
                fail(
                    f"Index {i} (orig {orig_len}) crop starting at {start} "
                    f"is not a contiguous slice of the original clip"
                )

        print(f"Index {i} (orig {orig_len}): Starts = {starts}")
        # With more than 5 possible offsets, identical starts on every draw
        # are very unlikely, so they hint at non-random cropping.
        if max_start > 5 and len(starts) > 1 and len(set(starts)) == 1:
            print(f"WARNING: Index {i} might not be random? Starts: {starts}")

    print("Test finished.")
    # Explicit raise rather than `assert`, so the check also runs under `python -O`.
    if failures:
        raise AssertionError(
            f"{len(failures)} cropping check(s) failed:\n" + "\n".join(failures)
        )
# Run the cropping verification when this file is executed directly rather than imported.
if __name__ == "__main__":
    test_random_cropping()