"""Audio loading utils."""
import os
import numpy as np
import torch
import torchaudio
import decord
import librosa
import einops
import PIL
import matplotlib.pyplot as plt
# Add serif font
plt.rcParams['font.family'] = 'serif'
from PIL import Image, ImageOps
import librosa.display
import shared.utils as su
def read_info(path):
"""
    Reads stream metadata (audio and video) of the given media file using
    ffmpeg.probe.

    Args:
        path (str): path to the media file

    Returns:
        dict: with keys `video` and `audio`; each value is the corresponding
            ffprobe stream dict, or None if the stream is absent.
"""
import ffmpeg
probe = ffmpeg.probe(path)
audio_info = next(
(s for s in probe['streams'] if s['codec_type'] == 'audio'),
None,
)
video_info = next(
(s for s in probe['streams'] if s['codec_type'] == 'video'),
None,
)
return dict(video=video_info, audio=audio_info)
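
# Illustrative usage of `read_info` (a minimal sketch; the file name below is
# a placeholder, not part of this repository):
#
#   info = read_info("sample.mp4")
#   sample_rate = int(info["audio"]["sample_rate"])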
def load_audio_clips(
audio_path,
clips,
sr,
clip_len,
backend='decord',
load_entire=False,
cut_to_clip_len=True,
):
"""
Loads audio clips from the given audio file.
Args:
audio_path (str): path to the audio file
clips (np.ndarray): sized [T, 2], where T is the number of clips
and each row is a pair of start and end times of the clip
sr (int): sample rate
clip_len (float): length of the audio clip in seconds
backend (str): backend to use for loading audio clips
load_entire (bool): whether to load the entire audio file
        cut_to_clip_len (bool): whether to cut the audio clip to clip_len

    Returns:
        list: one torch.Tensor of shape [1, num_samples] per clip (a single
            clip if load_entire is True).
"""
if backend == 'torchaudio':
audio_info = read_info(audio_path)["audio"]
true_sr = int(audio_info["sample_rate"])
true_nf = audio_info["duration_ts"]
audio_duration = true_nf / true_sr
# metadata = torchaudio.info(audio_path)
# true_sr = metadata.sample_rate
# true_nf = metadata.num_frames
elif backend == "decord":
# duration = librosa.get_duration(filename=audio_path)
ar = decord.AudioReader(audio_path, sample_rate=sr, mono=True)
        # mono=False gives NaNs in the loaded audio.
        # This (https://gist.github.com/nateraw/fcc2bdb9c8738224957c8617c3360445) might
        # be a related issue. Ignoring for now; if stereo is needed, use the
        # torchaudio backend instead.
true_nf = ar.shape[1]
audio_duration = ar.shape[1] / sr
else:
raise ValueError(f"Unknown backend: {backend}")
if load_entire:
# Load the entire audio as a single clip and return
if backend == 'torchaudio':
y, _ = torchaudio.load(audio_path)
if y.shape[0] > 1:
# Convert to a single channel
y = y.mean(dim=0, keepdim=True)
resampler = torchaudio.transforms.Resample(true_sr, sr)
y = resampler(y)
audio = y
elif backend == "decord":
audio = ar.get_batch(np.arange(true_nf)).asnumpy()
audio = torch.from_numpy(audio)
return [audio]
else:
# Clip the clips to avoid going out of bounds
clips = np.clip(clips, 0, audio_duration)
audio_clips = []
for st, et in clips:
if backend == 'torchaudio':
# Load audio within the given time range
sf = max(int(true_sr * st), 0)
ef = min(int(true_sr * et), true_nf)
nf = ef - sf
y, _ = torchaudio.load(audio_path, frame_offset=sf, num_frames=nf)
# Stereo to mono
if y.shape[0] > 1:
# Convert to a single channel
y = y.mean(dim=0, keepdim=True)
# Resample to the given sample rate
resampler = torchaudio.transforms.Resample(true_sr, sr)
y = resampler(y)
audio = y
elif backend == "decord":
# Load audio within the given time range
sf = max(int(st * sr), 0)
ef = min(int(et * sr), true_nf)
audio = ar.get_batch(np.arange(sf, ef)).asnumpy()
audio = torch.from_numpy(audio)
# No need to convert to mono since we are using mono=True
# No need to resample since we are using sample_rate=sr
else:
raise ValueError(f"Unknown backend: {backend}")
# Pad the clip to clip_len
nf_reqd = int(clip_len * sr)
nf_curr = audio.shape[1]
npad_side = max(0, nf_reqd - nf_curr)
if nf_curr < nf_reqd:
audio = torch.nn.functional.pad(audio, (0, npad_side))
elif (nf_curr > nf_reqd) and cut_to_clip_len:
audio = audio[:, :nf_reqd]
audio_clips.append(audio)
return audio_clips
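
# Illustrative usage of `load_audio_clips` (a minimal sketch; the file path
# and clip times are placeholders, not part of this repository):
#
#   clips = np.array([[0.0, 2.0], [3.0, 5.0]])  # two 2-second clips
#   audio_clips = load_audio_clips(
#       "sample.wav", clips, sr=16000, clip_len=2.0, backend="decord",
#   )
#   # -> list of 2 tensors, each of shape [1, 32000]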
def show_audio_clips_waveform(
audio_clips, clips, title=None, show=True, figsize=(10, 2),
):
"""
    Visualizes the waveforms of the given audio clips.

    Args:
        audio_clips (list): list of audio clips (torch.Tensor), one per clip
        clips (np.ndarray): sized [T, 2], where T is the number of clips
            and each row is a pair of start and end times of the clip
        title (str): title of the plot (currently unused)
        show (bool): whether to show the plot; otherwise it is saved to disk
        figsize (tuple): figure size
"""
clip_centers = (clips[:, 0] + clips[:, 1]) / 2
clip_durations = clips[:, 1] - clips[:, 0]
fig, ax = plt.subplots(1, len(audio_clips), figsize=figsize)
if len(audio_clips) == 1:
ax = [ax]
for i, audio in enumerate(audio_clips):
        # Map samples to absolute time: the clip spans [start, end],
        # i.e. center +/- half the duration.
        timestamps = np.linspace(
            clip_centers[i] - clip_durations[i] / 2.,
            clip_centers[i] + clip_durations[i] / 2.,
            audio.shape[-1],
        )
ax[i].plot(timestamps, audio.squeeze().numpy(), alpha=0.5)
ax[i].set_title(f'$t=$ {clip_centers[i]:.2f}')
ax[i].grid(alpha=0.4)
plt.tight_layout()
if show:
plt.show()
else:
plt.savefig('audio_clips_waveform.png')
# TODO: preprocess audio clips (e.g., wav-to-spectrogram, etc.)
# Note that this is different from transforms applied as augmentation
# during training. This is more like a preprocessing step that is applied
# to the entire audio before sampling the clips.
import torchaudio.functional as TAF
import torchaudio.transforms as TAT
def load_audio(path, sr=16000, **kwargs):
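    """Loads audio with torchaudio, converts to mono, and resamples to `sr`."""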
y, true_sr = torchaudio.load(path, **kwargs)
y = y.mean(dim=0, keepdim=True)
resampler = torchaudio.transforms.Resample(true_sr, sr)
y = resampler(y)
return y, sr
def load_audio_librosa(path, sr=16000, **kwargs):
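    """Loads audio with librosa (resampled to `sr`) and returns a [1, T] torch tensor."""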
y, true_sr = librosa.load(path, sr=sr, **kwargs)
y = torch.from_numpy(y).unsqueeze(0)
return y, sr
def librosa_harmonic_spectrogram_db(
y, sr=16000, n_fft=512, hop_length=256, margin=16., n_mels=64,
):
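    """
    Computes a dB-scaled spectrogram of the harmonic component of `y`.

    The harmonic part is obtained via harmonic-percussive source separation
    (HPSS) on the STFT. If `n_mels` is given, the dB spectrogram is projected
    onto a mel filterbank; otherwise the linear-frequency dB spectrogram is
    returned.
    """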
if isinstance(y, torch.Tensor):
y = y.numpy()
if len(y.shape) == 2:
y = y.mean(axis=0)
# center=True outputs 1 more frame than center=False
# Currently, using just center=False
D = librosa.stft(y, n_fft=n_fft, hop_length=hop_length, center=False)
DH, DP = librosa.decompose.hpss(D, margin=margin)
amplitude_h = np.sqrt(2) * np.abs(DH)
if n_mels is None:
# Usual dB spectrogram
SH = librosa.amplitude_to_db(amplitude_h, ref=np.max)
else:
# Mel-scaled dB spectrogram
S = librosa.amplitude_to_db(amplitude_h)
SH = librosa.feature.melspectrogram(S=S, n_mels=n_mels, sr=sr)
return SH
def show_logmelspectrogram(
S,
sr,
n_fft=512,
hop_length=256,
figsize=(10, 3),
ax=None,
show=True,
title="LogMelSpectrogram",
xlabel="Time (s)",
ylabel="Mel bins (Hz)",
return_as_image=False,
):
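    """Displays a log-mel spectrogram; optionally returns the rendered figure as a PIL image."""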
    if ax is None:
        fig, ax = plt.subplots(1, 1, figsize=figsize)
    else:
        # Needed below when returning the figure as an image
        fig = ax.figure
librosa.display.specshow(
S,
sr=sr,
hop_length=hop_length,
n_fft=n_fft,
y_axis='mel',
x_axis='time',
ax=ax,
auto_aspect=True,
)
ax.set_title(title)
ax.set_xlabel(xlabel)
ax.set_ylabel(ylabel)
if return_as_image:
fig.canvas.draw()
image = PIL.Image.frombytes(
'RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb(),
)
plt.close(fig)
return image
if show:
plt.show()
def show_logspectrogram(
S, sr, n_fft=512, hop_length=256, figsize=(10, 3), ax=None, show=True,
):
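    """Displays a log (linear-frequency) spectrogram."""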
if ax is None:
fig, ax = plt.subplots(1, 1, figsize=figsize)
librosa.display.specshow(
S,
sr=sr,
hop_length=hop_length,
n_fft=n_fft,
y_axis='linear',
x_axis='time',
ax=ax,
)
ax.set_title("LogSpectrogram")
if show:
plt.show()
def audio_clips_wav_to_spec(
audio_clips, n_fft=512, hop_length=256, margin=16., n_mels=None,
):
"""
Converts the given audio clips to spectrograms.
Args:
audio_clips (list): list of audio clips
n_fft (int): number of FFT points
hop_length (int): hop length
margin (float): margin for harmonic-percussive source separation
n_mels (int): number of mel bands (optional, if None, then dB spectrogram is returned)
"""
audio_specs = []
for audio in audio_clips:
spec = librosa_harmonic_spectrogram_db(
audio,
n_fft=n_fft,
hop_length=hop_length,
margin=margin,
n_mels=n_mels,
)
spec = torch.from_numpy(spec).unsqueeze(0)
audio_specs.append(spec)
return audio_specs
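
# Illustrative usage of `audio_clips_wav_to_spec` (a minimal sketch, following
# on from the `load_audio_clips` example above):
#
#   audio_specs = audio_clips_wav_to_spec(
#       audio_clips, n_fft=512, hop_length=256, margin=16., n_mels=64,
#   )
#   # -> list of tensors of shape [1, 64, n_frames]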
def show_audio_clips_spec(
audio_specs,
clips,
sr,
n_fft=512,
hop_length=256,
margin=16.,
cmap='magma',
n_mels=None,
show=True,
):
"""
Visualizes the given audio clips.
Args:
        audio_specs (list): list of audio spectrograms, one per clip
        clips (np.ndarray): sized [T, 2], where T is the number of clips
            and each row is a pair of start and end times of the clip
        sr (int): sample rate
        n_fft (int): number of FFT points
        hop_length (int): hop length
        margin (float): unused here; kept for symmetry with the wav-to-spec step
        cmap (str): colormap
        n_mels (int): if None, a linear-frequency axis is used, else a mel axis
        show (bool): whether to show the plot; otherwise it is saved to disk
"""
clip_centers = (clips[:, 0] + clips[:, 1]) / 2
clip_durations = clips[:, 1] - clips[:, 0]
fig, ax = plt.subplots(1, len(audio_specs), figsize=(10, 4))
if len(audio_specs) == 1:
ax = [ax]
for i, spec in enumerate(audio_specs):
clip_start = clips[i][0]
# ax[i].imshow(spec, aspect='auto', origin='lower')
if isinstance(spec, torch.Tensor):
spec = spec.numpy()
if len(spec.shape) == 3:
spec = spec[0]
args = dict(
data=spec,
sr=sr,
n_fft=n_fft,
hop_length=hop_length,
ax=ax[i],
x_axis="time",
cmap=cmap,
)
if n_mels is None:
args.update(dict(y_axis="linear"))
else:
args.update(dict(y_axis="mel"))
librosa.display.specshow(**args)
        # Shift the xtick labels by clip_start so the time axis is absolute
        xticks = ax[i].get_xticks()
        ax[i].set_xticks(xticks)
        ax[i].set_xticklabels([f'{x + clip_start:.1f}' for x in xticks])
ax[i].set_title(f'$t=$ {clip_centers[i]:.2f}')
plt.tight_layout()
if show:
plt.show()
else:
plt.savefig('audio_clips_spec.png')
def basic_pipeline_audio_clips(
audio_clips,
spec_args=None,
audio_transform=None,
stack=True,
):
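    """
    Applies waveform transforms, optional wav-to-spectrogram conversion, and
    spectrogram transforms to a list of audio clips; optionally stacks them
    into a single tensor.
    """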
    # audio_transform may be None (no transforms) or a dict with optional
    # 'wave' and 'spec' entries
    audio_transform = audio_transform or {}
    wave_transform = audio_transform.get('wave', None)
    spec_transform = audio_transform.get('spec', None)
# Apply transforms to raw waveforms
if wave_transform is not None:
audio_clips = wave_transform(audio_clips)
if spec_args is not None:
# Convert waveforms to spectrograms
audio_clips = audio_clips_wav_to_spec(audio_clips, **spec_args)
# Apply transforms to spectrograms
if spec_transform is not None:
audio_clips = spec_transform(audio_clips)
if stack:
audio_clips = torch.stack(audio_clips)
return audio_clips
def load_and_process_audio(
audio_path,
clips,
cut_to_clip_len=True,
load_entire=False,
audio_transform=None,
aload_args=dict(),
apipe_args=dict(),
):
"""Loads and preprocess audio."""
# [C1] Load video clips: List[torch.Tensor]
audio_clips = load_audio_clips(
audio_path=audio_path,
clips=clips,
load_entire=load_entire,
cut_to_clip_len=cut_to_clip_len,
**aload_args,
)
# [C2] Pipeline: [Preprocessing -> Transform]
audio_clips = basic_pipeline_audio_clips(
audio_clips=audio_clips,
audio_transform=audio_transform,
**apipe_args,
)
return audio_clips
def crop_height(image, height):
"""Crops image from the top and bottom to the desired height."""
width, curr_height = image.size
if curr_height < height:
raise ValueError(f"Height of the image is less than {height}")
top = (curr_height - height) // 2
bottom = top + height
return image.crop((0, top, width, bottom))
def pad_to_height(image, height):
"""Pads image with black strips at the top and bottom."""
width, curr_height = image.size
if curr_height > height:
raise ValueError(f"Height of the image is already greater than {height}")
top = (height - curr_height) // 2
bottom = height - curr_height - top
return ImageOps.expand(image, (0, top, 0, bottom), fill="black")
def crop_width(image, width):
"""Crops image from the left and right to the desired width."""
curr_width, height = image.size
if curr_width < width:
raise ValueError(f"Width of the image is less than {width}")
left = (curr_width - width) // 2
right = left + width
return image.crop((left, 0, right, height))
def crop_or_pad_height(image, height):
"""Crops or pads image to the desired height."""
width, curr_height = image.size
if curr_height < height:
return pad_to_height(image, height)
elif curr_height > height:
return crop_height(image, height)
return image
def crop_or_pad_width(image, width):
"""Crops or pads image to the desired width."""
curr_width, height = image.size
if curr_width < width:
return pad_to_width(image, width)
elif curr_width > width:
return crop_width(image, width)
return image
def pad_to_width(image, width):
"""Pads image with black strips at the left and right."""
curr_width, height = image.size
if curr_width > width:
raise ValueError(f"Width of the image is already greater than {width}")
left = (width - curr_width) // 2
right = width - curr_width - left
return ImageOps.expand(image, (left, 0, right, 0), fill="black")
def crop_or_pad_to_size(image, size=(270, 480)):
"""Crops or pads image to the desired size."""
image = crop_or_pad_height(image, size[1])
image = crop_or_pad_width(image, size[0])
return image
if __name__ == "__main__":
import decord
import sound_of_water.data.audio_transforms as at
# Testing on a sample file
file_path = "media_assets/ayNzH0uygFw_9.0_21.0.mp4"
assert os.path.exists(file_path), f"File not found: {file_path}"
# Define audio transforms
cfg_transform = {
"audio": {
"wave": [
{
"name": "AddNoise",
"args": {
"noise_level": 0.001
},
"augmentation": True,
},
{
"name": "ChangeVolume",
"args": {
"volume_factor": [0.8, 1.2]
},
"augmentation": True,
},
{
"name": "Wav2Vec2WaveformProcessor",
"args": {
"model_name": "facebook/wav2vec2-base-960h",
"sr": 16000
}
}
],
"spec": None,
}
}
audio_transform = at.define_audio_transforms(
cfg_transform, augment=False,
)
# Define audio load arguments
aload_args = {
"sr": 16000,
"clip_len": None,
"backend": "decord",
}
# Define audio pipeline arguments
apipe_args = {
"spec_args": None,
"stack": True,
}
    # Run the pipeline (this output is passed to the model)
audio = load_and_process_audio(
audio_path=file_path,
clips=None,
load_entire=True,
cut_to_clip_len=False,
audio_transform=audio_transform,
aload_args=aload_args,
apipe_args=apipe_args,
)[0]
    # Arguments used for visualisation
visualise_args = {
"sr": 16000,
"n_fft": 400,
"hop_length": 320,
"n_mels": 64,
"margin": 16.,
"C": 340 * 100.,
"audio_output_fps": 49.,
}
y = load_audio_clips(
audio_path=file_path,
clips=None,
load_entire=True,
cut_to_clip_len=False,
**aload_args,
)[0]
S = librosa_harmonic_spectrogram_db(
y,
sr=visualise_args["sr"],
n_fft=visualise_args["n_fft"],
hop_length=visualise_args["hop_length"],
n_mels=visualise_args['n_mels'],
)
# Load video frame
vr = decord.VideoReader(file_path, num_threads=1)
frame = PIL.Image.fromarray(vr[0].asnumpy())
"""
# Cut to desired width
new_width, new_height = 270, 480
width, height = frame.size
if width > new_width:
# Crop the width
left = (width - new_width) // 2
right = left + new_width
frame = frame.crop((left, 0, right, height))
else:
# Resize along width to have the desired width
frame = su.visualize.resize_width(frame, new_width)
assert frame.size[0] == new_width, \
f"Width mismatch: {frame.size[0]} != {new_width}"
# Now pad/crop to desired height
if height > new_height:
# Crop the height
top = (height - new_height) // 2
bottom = top + new_height
frame = frame.crop((0, top, new_width, bottom))
else:
# Pad the height
frame = pad_to_height(frame, new_height)
assert frame.size[1] == new_height, \
f"Height mismatch: {frame.size[1]} != {new_height}"
"""
frame = crop_or_pad_to_size(frame)
# frame.save("1.png")
# Visualise
fig, axes = plt.subplots(
1, 2, figsize=(13, 4), width_ratios=[0.25, 0.75],
)
ax = axes[0]
ax.imshow(frame, aspect="auto")
ax.set_title("Example frame")
ax.set_xticks([])
ax.set_yticks([])
ax = axes[1]
show_logmelspectrogram(
S=S,
ax=ax,
show=False,
sr=visualise_args["sr"],
n_fft=visualise_args["n_fft"],
hop_length=visualise_args["hop_length"],
)
plt.savefig("./media_assets/audio_visualisation.png", bbox_inches="tight")
plt.close()