| import torch |
| import torchaudio |
|
|
|
|
def get_spectrogram_shape(waveform_len, hop_length=1250, center=True, n_fft=None):
    """Return the number of STFT time frames produced for a waveform.

    Mirrors the frame count of ``torch.stft`` along the time axis, i.e.
    ``1 + (padded_len - n_fft) // hop_length`` where ``padded_len`` includes
    ``n_fft // 2`` samples of padding on each side when ``center`` is True.

    Args:
        waveform_len: Number of samples in the waveform.
        hop_length: Number of samples between successive frames.
        center: Whether the input is padded by ``n_fft // 2`` on both sides
            before framing (the ``center`` flag of ``torch.stft``).
        n_fft: FFT/window size. Optional when ``center`` is True: for any even
            ``n_fft`` the count reduces to ``waveform_len // hop_length + 1``,
            which is what is returned when ``n_fft`` is None. Required when
            ``center`` is False.

    Returns:
        The number of time frames.

    Raises:
        ValueError: If ``hop_length`` is not positive, if ``center`` is False
            and ``n_fft`` is not given, or if the (padded) waveform is shorter
            than ``n_fft`` so that no full frame fits.
    """
    if hop_length <= 0:
        raise ValueError(f"hop_length must be positive, got {hop_length}")

    if center:
        if n_fft is None:
            # With an even n_fft the 2 * (n_fft // 2) samples of padding exactly
            # cancel the window length, leaving len // hop + 1.
            return waveform_len // hop_length + 1
        padded_len = waveform_len + 2 * (n_fft // 2)
    else:
        if n_fft is None:
            raise ValueError("n_fft is required when center=False")
        padded_len = waveform_len

    if padded_len < n_fft:
        raise ValueError(
            f"waveform of {waveform_len} samples is too short for n_fft={n_fft}"
        )
    return (padded_len - n_fft) // hop_length + 1
|
|
|
|
def calculate_required_length(current_len, hop_length, patch_time_dim):
    """Compute a padded length whose spectrogram splits into an even number of patches.

    The spectrogram frame count is rounded up to the next multiple of
    ``2 * patch_time_dim``, so it divides into an even number of
    ``patch_time_dim``-wide time patches. The returned waveform length is the
    shortest length ``>= current_len`` that produces exactly that many frames.

    Assumes a centered STFT with an even ``n_fft`` (such as the
    ``MelSpectrogram(center=True, n_fft=4096)`` used in ``test_shapes``), for
    which ``frames = length // hop_length + 1``.

    Args:
        current_len: Current waveform length in samples.
        hop_length: STFT hop length in samples.
        patch_time_dim: Width of one time patch, in spectrogram frames.

    Returns:
        ``(target_spec_len, target_wave_len)``: the target number of frames and
        a waveform length (never shorter than ``current_len``) that yields it.

    Raises:
        ValueError: If ``hop_length`` or ``patch_time_dim`` is not positive.
    """
    if hop_length <= 0:
        raise ValueError(f"hop_length must be positive, got {hop_length}")
    if patch_time_dim <= 0:
        raise ValueError(f"patch_time_dim must be positive, got {patch_time_dim}")

    spec_len = current_len // hop_length + 1
    block_size = 2 * patch_time_dim

    # Round spec_len up to the next multiple of block_size (ceiling division).
    target_spec_len = -(-spec_len // block_size) * block_size

    # A centered STFT yields T frames for any length in [(T-1)*hop, T*hop - 1];
    # using T * hop would yield T + 1 frames. Take the lower bound, but never go
    # below current_len (which can happen when spec_len was already aligned), so
    # callers only ever pad and never truncate.
    target_wave_len = max(current_len, (target_spec_len - 1) * hop_length)
    return target_spec_len, target_wave_len
|
|
|
|
def test_shapes():
    """Check that padding to ``calculate_required_length`` gives the target frame count.

    For each test length, the real waveform is right-padded with zeros to
    the computed target length and run through a centered
    ``MelSpectrogram``. The test checks that the resulting number of time
    frames equals the predicted target and splits into an even number of
    ``patch_time_dim``-wide patches.

    Diagnostics are printed for every length. All failures are collected and
    reported in a single assertion at the end, so a MISMATCH actually fails
    the test.
    """
    hop_length = 1250
    patch_time_dim = 16

    lengths = [32000, 48000, 320000, 12345]

    mel = torchaudio.transforms.MelSpectrogram(
        sample_rate=32000,
        n_fft=4096,
        win_length=4096,
        hop_length=hop_length,
        n_mels=128,
        center=True,
    )

    print(f"Testing with hop_length={hop_length}, patch_time_dim={patch_time_dim}")

    failures = []
    for length in lengths:
        wave = torch.randn(1, length)
        spec_len = mel(wave).shape[-1]
        print(f"Wave: {length}, Spec: {spec_len}")

        target_spec_len, target_wave_len = calculate_required_length(
            length, hop_length, patch_time_dim
        )

        # Pad the actual waveform instead of generating a new one, so a target
        # length that would truncate the input is caught as a failure.
        if target_wave_len < length:
            failures.append(
                f"length={length}: target wave {target_wave_len} is shorter than input"
            )
            continue
        wave_pad = torch.nn.functional.pad(wave, (0, target_wave_len - length))
        spec_pad_len = mel(wave_pad).shape[-1]

        # Even number of patches <=> frame count divisible by 2 * patch_time_dim.
        even_patches = spec_pad_len % (2 * patch_time_dim) == 0

        print(f" Target Spec: {target_spec_len}, Target Wave: {target_wave_len}")
        print(f" Actual Spec: {spec_pad_len}")
        print(f" Even patches? {spec_pad_len / patch_time_dim} (Time patches)")
        print(f" Even time patches condition: {even_patches}")

        if spec_pad_len != target_spec_len:
            print(" MISMATCH!")
            failures.append(
                f"length={length}: expected {target_spec_len} frames, got {spec_pad_len}"
            )
        elif not even_patches:
            failures.append(
                f"length={length}: {spec_pad_len} frames is not a multiple of "
                f"{2 * patch_time_dim}"
            )

    assert not failures, "\n".join(failures)
|
|
|
|
if __name__ == "__main__":
    # Running this file directly performs the shape check; importing it
    # (for example from a test runner) only defines the functions.
    test_shapes()
|
|