Spaces:
Running
Running
#!/usr/bin/env python3 | |
# -*- coding: utf-8 -*- | |
"""Module for visualizing audio data and chorus predictions.""" | |
from typing import List | |
import librosa | |
import librosa.display | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import os | |
from chorus_detection.audio.processor import AudioFeature | |
def plot_meter_lines(ax: plt.Axes, meter_grid_times: np.ndarray) -> None:
    """Draw a dashed vertical line at every meter-grid time.

    Args:
        ax: The matplotlib axes object that receives one ``axvline`` per time.
        meter_grid_times: Array of x positions (times) at which lines are drawn.
    """
    # One shared style for every grid line: thin, grey, dashed, semi-transparent.
    line_style = dict(color='grey', linestyle='--', linewidth=1, alpha=0.6)
    for grid_time in meter_grid_times:
        ax.axvline(x=grid_time, **line_style)
def plot_predictions(audio_features: AudioFeature, binary_predictions: np.ndarray) -> None:
    """Plot the audio waveform and overlay the predicted chorus locations.

    Draws the harmonic and percussive components as waveforms, overlays the
    meter grid, shades every segment whose prediction equals 1, and shows the
    figure without blocking.

    Args:
        audio_features: An object containing audio features and components.
            This function reads its ``y``, ``y_harm``, ``y_perc``, ``sr``,
            ``hop_length``, ``meter_grid`` and ``audio_path`` attributes.
        binary_predictions: Array of binary predictions indicating chorus
            locations. Entry ``i`` covers the span from meter-grid time ``i``
            to grid time ``i + 1`` (or to the end of the audio for the last
            grid time). A chorus entry whose index lies beyond the meter grid
            raises ``IndexError``.
    """
    meter_grid_times = librosa.frames_to_time(
        audio_features.meter_grid, sr=audio_features.sr, hop_length=audio_features.hop_length)
    # Signal length in seconds; reused for the final segment, the x-limits and the ticks.
    duration = len(audio_features.y) / audio_features.sr
    last_grid_index = len(meter_grid_times) - 1

    _, ax = plt.subplots(figsize=(12.5, 3), dpi=96)

    # Display harmonic and percussive components
    librosa.display.waveshow(audio_features.y_harm, sr=audio_features.sr,
                             alpha=0.8, ax=ax, color='deepskyblue')
    librosa.display.waveshow(audio_features.y_perc, sr=audio_features.sr,
                             alpha=0.7, ax=ax, color='plum')
    plot_meter_lines(ax, meter_grid_times)

    # Highlight chorus sections. Spans are drawn unlabelled on purpose: the
    # legend entry comes from the single proxy patch below, so "Chorus" appears
    # exactly once whether or not the very first segment is a chorus.
    for i, prediction in enumerate(binary_predictions):
        if prediction != 1:
            continue
        start_time = meter_grid_times[i]
        # The last grid time has no successor, so its segment runs to the end of the audio.
        end_time = meter_grid_times[i + 1] if i < last_grid_index else duration
        ax.axvspan(start_time, end_time, color='green', alpha=0.3)

    # Set plot limits and labels
    ax.set_xlim([0, duration])
    ax.set_ylabel('Amplitude')
    audio_file_name = os.path.basename(audio_features.audio_path)
    ax.set_title(
        f'Chorus Predictions for {os.path.splitext(audio_file_name)[0]}')

    # Add legend: keep any labelled artists already on the axes and add one
    # proxy patch whose style matches the chorus shading.
    chorus_patch = plt.Rectangle((0, 0), 1, 1, fc='green', alpha=0.3)
    handles, labels = ax.get_legend_handles_labels()
    handles.append(chorus_patch)
    labels.append('Chorus')
    ax.legend(handles=handles, labels=labels)

    # Set x-tick labels in minutes:seconds format, one tick every 10 seconds
    xticks = np.arange(0, duration, 10)
    xlabels = [f"{int(tick // 60)}:{int(tick % 60):02d}" for tick in xticks]
    ax.set_xticks(xticks)
    ax.set_xticklabels(xlabels)

    plt.tight_layout()
    plt.show(block=False)