|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""audio_test.py""" |
|
import unittest |
|
import os |
|
import numpy as np |
|
import wave |
|
import tempfile |
|
from utils.audio import load_audio_file |
|
from utils.audio import get_audio_file_info |
|
from utils.audio import slice_padded_array |
|
from utils.audio import slice_padded_array_for_subbatch |
|
from utils.audio import write_wav_file |
|
|
|
|
|
class TestLoadAudioFile(unittest.TestCase): |
|
|
|
def create_temp_wav_file(self, duration: float, fs: int = 16000) -> str: |
|
n_samples = int(duration * fs) |
|
temp_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False) |
|
temp_filename = temp_file.name |
|
|
|
data = np.random.randint(-2**15, 2**15, n_samples, dtype=np.int16) |
|
|
|
with wave.open(temp_filename, 'wb') as f: |
|
f.setnchannels(1) |
|
f.setsampwidth(2) |
|
f.setframerate(fs) |
|
f.writeframes(data.tobytes()) |
|
|
|
return temp_filename |
|
|
|
def test_load_audio_file(self): |
|
duration = 3.0 |
|
fs = 16000 |
|
temp_filename = self.create_temp_wav_file(duration, fs) |
|
|
|
|
|
audio_data = load_audio_file(temp_filename, dtype=np.int16) |
|
file_fs, n_frames, n_channels = get_audio_file_info(temp_filename) |
|
|
|
self.assertEqual(len(audio_data), n_frames) |
|
self.assertEqual(file_fs, fs) |
|
self.assertEqual(n_channels, 1) |
|
|
|
|
|
seg_start_sec = 1.0 |
|
seg_length_sec = 1.0 |
|
audio_data = load_audio_file(temp_filename, seg_start_sec, seg_length_sec, dtype=np.int16) |
|
|
|
self.assertEqual(len(audio_data), int(seg_length_sec * fs)) |
|
|
|
|
|
with self.assertRaises(NotImplementedError): |
|
load_audio_file("unsupported.xyz") |
|
|
|
|
|
class TestSliceArray(unittest.TestCase): |
|
|
|
def setUp(self): |
|
self.x = np.random.randint(0, 10, size=(1, 10000)) |
|
|
|
def test_without_padding(self): |
|
sliced_x = slice_padded_array(self.x, slice_length=100, slice_hop=50, pad=False) |
|
self.assertEqual(sliced_x.shape, (199, 100)) |
|
|
|
def test_with_padding(self): |
|
sliced_x = slice_padded_array(self.x, slice_length=100, slice_hop=50, pad=True) |
|
self.assertEqual(sliced_x.shape, (199, 100)) |
|
|
|
def test_content(self): |
|
sliced_x = slice_padded_array(self.x, slice_length=100, slice_hop=50, pad=True) |
|
for i in range(sliced_x.shape[0] - 1): |
|
np.testing.assert_array_equal(sliced_x[i, :], self.x[:, i * 50:i * 50 + 100].flatten()) |
|
|
|
last_slice = sliced_x[-1, :] |
|
last_slice_no_padding = self.x[:, -100:].flatten() |
|
np.testing.assert_array_equal(last_slice[:len(last_slice_no_padding)], last_slice_no_padding) |
|
|
|
|
|
class TestSlicePadForSubbatch(unittest.TestCase): |
|
|
|
def test_slice_padded_array_for_subbatch(self): |
|
input_array = np.random.randn(6, 10) |
|
slice_length = 4 |
|
slice_hop = 2 |
|
pad = True |
|
sub_batch_size = 4 |
|
|
|
expected_output_shape = (4, 4) |
|
|
|
|
|
result = slice_padded_array_for_subbatch(input_array, slice_length, slice_hop, pad, sub_batch_size) |
|
|
|
|
|
self.assertEqual(result.shape, expected_output_shape) |
|
|
|
|
|
self.assertEqual(result.shape[0] % sub_batch_size, 0) |
|
|
|
|
|
class TestWriteWavFile(unittest.TestCase): |
|
|
|
def test_write_wav_file_z(self): |
|
|
|
samplerate = 16000 |
|
duration = 1 |
|
t = np.linspace(0, duration, int(samplerate * duration), endpoint=False) |
|
x = np.sin(2 * np.pi * 440 * t) |
|
|
|
|
|
filename = "extras/test.wav" |
|
write_wav_file(filename, x, samplerate) |
|
|
|
|
|
with wave.open(filename, "rb") as wav_file: |
|
|
|
self.assertEqual(wav_file.getnchannels(), 1) |
|
self.assertEqual(wav_file.getsampwidth(), 2) |
|
self.assertEqual(wav_file.getframerate(), samplerate) |
|
self.assertEqual(wav_file.getnframes(), len(x)) |
|
|
|
|
|
data = wav_file.readframes(len(x)) |
|
|
|
|
|
x_read = np.frombuffer(data, dtype=np.int16) / 32767.0 |
|
|
|
|
|
np.testing.assert_allclose(x_read, x, atol=1e-4) |
|
|
|
|
|
os.remove(filename) |
|
|
|
|
|
if __name__ == '__main__': |
|
unittest.main() |
|
|