yourusername's picture
:beers: cheers
66a6dc0
raw history blame
No virus
2.99 kB
import io
import torch
import PIL.Image
import numpy as np
import scipy.signal
import librosa.display
import matplotlib.pyplot as plt
from torch.functional import Tensor
from torchvision.transforms import ToTensor
def compute_comparison_spectrogram(
    x: np.ndarray,
    y: np.ndarray,
    sample_rate: float = 44100,
    n_fft: int = 2048,
    hop_length: int = 1024,
) -> Tensor:
    """Render the dB spectrograms of two signals as one stacked image.

    The spectrogram of ``x`` is drawn in the top panel and that of ``y`` in
    the bottom panel, both with a time x-axis and a log-frequency y-axis.
    The figure is encoded as JPEG in memory and converted to a tensor with
    torchvision's ``ToTensor``.

    Args:
        x: First signal, passed to ``librosa.stft``.
        y: Second signal, passed to ``librosa.stft``.
        sample_rate: Sample rate used to label the axes.
        n_fft: FFT size of the STFT.
        hop_length: Hop size of the STFT (also used for the time axis).

    Returns:
        The rendered figure as an image tensor.
    """
    # Compute both spectrograms before touching matplotlib so that an STFT
    # failure cannot leave a half-built figure behind.
    spectrograms_db = []
    for signal in (x, y):
        spectrum = librosa.stft(signal, n_fft=n_fft, hop_length=hop_length)
        # ref=np.max: each panel is expressed relative to its own peak.
        spectrograms_db.append(
            librosa.amplitude_to_db(np.abs(spectrum), ref=np.max)
        )

    fig, axs = plt.subplots(figsize=(9, 6), nrows=2)
    try:
        for ax, spec_db in zip(axs, spectrograms_db):
            librosa.display.specshow(
                spec_db,
                ax=ax,
                hop_length=hop_length,
                x_axis="time",
                y_axis="log",
                sr=sample_rate,
            )
        fig.tight_layout()

        buf = io.BytesIO()
        fig.savefig(buf, format="jpeg")
        buf.seek(0)
        image = ToTensor()(PIL.Image.open(buf))
    finally:
        # Close only the figure created here; closing "all" would also
        # destroy figures owned by the caller.
        plt.close(fig)
    return image
def plot_multi_spectrum(
    ys=None,
    Hs=None,
    legend=None,
    title="Spectrum",
    filename=None,
    sample_rate=44100,
    n_fft=1024,
    zero_mean=False,
):
    """Plot several magnitude spectra on one log-frequency axis.

    Args:
        ys: Iterable of signals. Used only when ``Hs`` is None; each signal
            is turned into a smoothed average spectrum via
            ``get_average_spectrum`` and ``smooth_spectrum``.
        Hs: Optional precomputed magnitude spectra, one per curve.
        legend: Optional sequence of curve labels. A curve whose label
            contains ``"Target"`` is drawn as a dashed black line. Curves
            without a label are drawn with the default style.
        title: Plot title.
        filename: If given, the plot is also written to ``{filename}.png``.
        sample_rate: Sample rate used to build the frequency axis.
        n_fft: FFT size used for the analysis and the frequency axis.
        zero_mean: If True, subtract each curve's mean dB value.

    Returns:
        The rendered figure as an image tensor (JPEG-encoded in memory and
        converted with torchvision's ``ToTensor``).

    Raises:
        ValueError: If neither ``ys`` nor ``Hs`` is provided.
    """
    if Hs is None:
        if ys is None:
            raise ValueError("plot_multi_spectrum requires either `ys` or `Hs`")
        Hs = [smooth_spectrum(get_average_spectrum(y, n_fft)) for y in ys]

    labels = [] if legend is None else list(legend)

    bin_width = (sample_rate / 2) / (n_fft // 2)
    # Build the axis from an integer bin count: a float-step arange can come
    # out one element too long or short because of rounding.
    freqs = np.arange(n_fft // 2 + 1) * bin_width

    fig, ax1 = plt.subplots()
    try:
        for idx, H in enumerate(Hs):
            H = np.nan_to_num(H)
            H = np.clip(H, 0, np.max(H))
            # The small offset keeps log10 finite for zero-magnitude bins.
            H_dB = 20 * np.log10(H + 1e-8)
            if zero_mean:
                H_dB -= np.mean(H_dB)

            # Tolerate fewer labels than curves instead of raising IndexError.
            label = labels[idx] if idx < len(labels) else ""
            if "Target" in label:
                ax1.plot(freqs, H_dB, linestyle="--", color="k")
            else:
                ax1.plot(freqs, H_dB)

        if labels:
            ax1.legend(labels)
        ax1.set_xscale("log")
        ax1.set_ylim([-80, 0])
        ax1.set_xlim([100, 11000])
        ax1.set_title(title)
        ax1.set_ylabel("Magnitude (dB)")
        ax1.set_xlabel("Frequency (Hz)")
        ax1.grid(c="lightgray", which="both")

        # Apply the layout before any export so the PNG on disk and the
        # returned tensor show the same figure.
        fig.tight_layout()

        if filename is not None:
            fig.savefig(f"{filename}.png", dpi=300)

        buf = io.BytesIO()
        fig.savefig(buf, format="jpeg")
        buf.seek(0)
        image = ToTensor()(PIL.Image.open(buf))
    finally:
        # Close only the figure created here; closing "all" would also
        # destroy figures owned by the caller.
        plt.close(fig)
    return image
def smooth_spectrum(H, window_length=1025, polyorder=2):
    """Smooth a magnitude spectrum with a Savitzky-Golay filter.

    Args:
        H: Spectrum to smooth; converted with ``np.asarray``.
        window_length: Filter window length. Defaults to 1025.
        polyorder: Order of the fitted polynomial. Defaults to 2.

    Returns:
        The smoothed spectrum as a NumPy array. If ``H`` has fewer bins than
        ``window_length``, the window shrinks to the largest odd length that
        fits. If even that window cannot hold a polynomial of ``polyorder``,
        ``H`` is returned unsmoothed.
    """
    H = np.asarray(H)
    n_bins = H.shape[-1] if H.ndim > 0 else 0

    if window_length > n_bins:
        # savgol_filter rejects a window longer than the signal, so shrink it
        # (kept odd) rather than failing on short spectra.
        window_length = n_bins if n_bins % 2 else n_bins - 1
        if window_length <= polyorder:
            return H

    return scipy.signal.savgol_filter(H, window_length, polyorder)
def get_average_spectrum(x, n_fft):
    """Return the frame-averaged STFT magnitude of ``x`` as a flat tensor.

    Args:
        x: Signal tensor handed to ``torch.stft``.
        n_fft: FFT size passed to ``torch.stft``.

    Returns:
        The magnitude of the normalized STFT, averaged over its last
        (frame) dimension and flattened to one dimension.
    """
    stft_frames = torch.stft(x, n_fft, normalized=True, return_complex=True)
    magnitude = torch.abs(stft_frames)
    # Collapse the frame dimension, then flatten whatever remains.
    frame_mean = torch.mean(magnitude, dim=-1)
    return frame_mean.view(-1)