# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Compute input examples for VGGish from audio waveform."""

# Modification: Return torch tensors rather than numpy arrays

import numpy as np
import resampy
import soundfile as sf
import torch

from . import mel_features
from . import vggish_params


def waveform_to_examples(data, sample_rate, numFrames, fps, return_tensor=True):
    """Convert an audio waveform into log mel spectrogram features for VGGish.

    Args:
      data: np.array of either one dimension (mono) or two dimensions
        (multi-channel, with the outer dimension representing channels).
        Each sample is generally expected to lie in the range [-1.0, +1.0],
        although this is not required.
      sample_rate: Sample rate of data.
      numFrames: Number of video frames to align the audio features to; the
        output is padded/cropped to numFrames * 4 spectrogram frames.
      fps: Video frame rate. STFT window/hop lengths are scaled by 25 / fps
        so that 4 spectrogram frames span one video frame at any frame rate
        (25 fps is the reference rate).
      return_tensor: If True, return a PyTorch float tensor (with
        requires_grad=True) instead of a numpy array.

    Returns:
      2-D array/tensor of shape [numFrames * 4, num_bands] containing a log
      mel spectrogram, where num_bands is vggish_params.NUM_MEL_BINS and the
      frame hop is vggish_params.STFT_HOP_LENGTH_SECONDS * 25 / fps.
    """
    # Convert to mono.
    if len(data.shape) > 1:
        data = np.mean(data, axis=1)
    # Resample to the rate assumed by VGGish.
    if sample_rate != vggish_params.SAMPLE_RATE:
        data = resampy.resample(data, sample_rate, vggish_params.SAMPLE_RATE)

    # Scale window/hop so the spectrogram frame rate tracks the video frame
    # rate (reference: 25 fps video -> unscaled VGGish window/hop).
    window_length_seconds = vggish_params.STFT_WINDOW_LENGTH_SECONDS * 25. / fps
    hop_length_seconds = vggish_params.STFT_HOP_LENGTH_SECONDS * 25. / fps

    # Compute log mel spectrogram features.
    log_mel = mel_features.log_mel_spectrogram(
        data,
        audio_sample_rate=vggish_params.SAMPLE_RATE,
        log_offset=vggish_params.LOG_OFFSET,
        window_length_secs=window_length_seconds,
        hop_length_secs=hop_length_seconds,
        num_mel_bins=vggish_params.NUM_MEL_BINS,
        lower_edge_hertz=vggish_params.MEL_MIN_HZ,
        upper_edge_hertz=vggish_params.MEL_MAX_HZ)

    # Pad (by wrapping the signal) or crop so the spectrogram has exactly
    # numFrames * 4 rows, i.e. 4 audio feature frames per video frame.
    maxAudio = int(numFrames * 4)
    if log_mel.shape[0] < maxAudio:
        shortage = maxAudio - log_mel.shape[0]
        log_mel = np.pad(log_mel, ((0, shortage), (0, 0)), 'wrap')
    log_mel = log_mel[:int(round(numFrames * 4)), :]

    # Frame features into examples.
    # features_sample_rate = 1.0 / vggish_params.STFT_HOP_LENGTH_SECONDS
    # example_window_length = int(round(vggish_params.EXAMPLE_WINDOW_SECONDS * features_sample_rate))
    # example_hop_length = int(round(vggish_params.EXAMPLE_HOP_SECONDS * features_sample_rate))
    # log_mel_examples = mel_features.frame(log_mel,
    #                                       window_length=example_window_length,
    #                                       hop_length=example_hop_length)

    if return_tensor:
        # BUG FIX: the original referenced the undefined name
        # `log_mel_examples` here (a leftover from the commented-out framing
        # code above), raising NameError whenever return_tensor was True.
        # Convert the 2-D log mel spectrogram itself instead.
        log_mel = torch.tensor(log_mel, requires_grad=True).float()

    return log_mel


def wavfile_to_examples(wav_file, return_tensor=True):
    """Convenience wrapper around waveform_to_examples() for a common WAV format.

    Args:
      wav_file: String path to a file, or a file-like object. The file is
        assumed to contain WAV audio data with signed 16-bit PCM samples.
      return_tensor: Return data as a PyTorch tensor ready for VGGish.

    Returns:
      See waveform_to_examples.
    """
    wav_data, sr = sf.read(wav_file, dtype='int16')
    assert wav_data.dtype == np.int16, 'Bad sample type: %r' % wav_data.dtype
    samples = wav_data / 32768.0  # Convert to [-1.0, +1.0]
    # BUG FIX: the original called waveform_to_examples(samples, sr,
    # return_tensor), passing return_tensor into the required positional
    # `numFrames` parameter and omitting `fps` entirely (TypeError). Derive
    # numFrames from the clip duration at the 25 fps reference rate so the
    # whole file is covered, and forward return_tensor by keyword.
    fps = 25.
    numFrames = int(np.ceil(len(samples) / float(sr) * fps))
    return waveform_to_examples(samples, sr, numFrames, fps,
                                return_tensor=return_tensor)