amuse / mel_module.py
alppo's picture
add generator module
e599c74
from typing import Optional
from config import config
import numpy as np
import librosa
from PIL import Image
import soundfile as sf
import warnings
warnings.filterwarnings("ignore", category=UserWarning, module='librosa')
class Mel:
    """Convert between a ``.wav`` file, a uint8 mel spectrogram and a PIL image.

    Exactly one source is used at construction time, in this order of
    precedence: ``file_path``, then ``image``, then ``spectrogram``.  The
    spectrogram is stored as a ``uint8`` array in which 0 corresponds to
    ``-top_db`` dB and 255 to 0 dB relative to the loudest bin.

    Args:
        file_path: Path of an audio file; only paths containing ``.wav`` are
            loaded.
        spectrogram: Already computed spectrogram array; it is cast to
            ``uint8`` and also turned into a PIL image.
        image: PIL image whose pixel array is used as the spectrogram.
        x_res: Number of time frames kept from the computed spectrogram.
        y_res: Number of mel bins (used as ``n_mels``).
        sample_rate: Sample rate used when loading and regenerating audio.
        n_fft: FFT window size passed to librosa.
        hop_length: Hop length (in samples) passed to librosa.
        top_db: Dynamic range in dB used when scaling to and from ``uint8``.
        n_iter: Passed as ``n_iter`` to ``librosa.feature.inverse.mel_to_audio``.

    Raises:
        ValueError: If ``file_path``, ``spectrogram`` or ``image`` is given
            with the wrong type.
    """

    def __init__(
        self,
        file_path: Optional[str] = None,
        spectrogram: Optional[np.ndarray] = None,
        image: Optional[Image.Image] = None,
        x_res: int = config.image_size,
        y_res: int = config.image_size,
        sample_rate: int = config.sample_rate,
        n_fft: int = 2048,
        hop_length: int = 882,
        top_db: int = 80,
        n_iter: int = 32,
    ):
        # Validate the inputs before touching any state so that a failed
        # construction does no work at all.
        if file_path is not None and not isinstance(file_path, str):
            raise ValueError("file_path must be a string")
        if spectrogram is not None and not isinstance(spectrogram, np.ndarray):
            raise ValueError("spectrogram must be an ndarray")
        if image is not None and not isinstance(image, Image.Image):
            raise ValueError("image must be a PIL Image")

        self.hop_length = hop_length
        self.sr = sample_rate
        self.n_fft = n_fft
        self.top_db = top_db
        self.n_iter = n_iter
        self.x_res = x_res
        self.y_res = y_res
        self.n_mels = self.y_res
        self.slice_size = self.x_res * self.hop_length - 1
        self.file_path = file_path
        self.spectrogram = spectrogram
        self.image = image

        # Precedence: audio file, then image, then raw spectrogram.
        if file_path is not None:
            self.load_file()
        elif image is not None:
            self.load_spectrogram()
        elif spectrogram is not None:
            self.load_image()
        else:
            print("Both file path and image are None!")

    def _require_spectrogram(self):
        """Raise ``ValueError`` when no spectrogram is available."""
        if self.spectrogram is None:
            raise ValueError("No spectrogram loaded")

    def load_file(self):
        """Compute ``self.spectrogram`` and ``self.image`` from ``self.file_path``.

        Loading is best-effort: any failure (including an unsupported file
        type) is reported on stdout instead of being raised.
        """
        try:
            # Previously a non-.wav path fell through to an unbound ``audio``
            # variable and produced a misleading UnboundLocalError message.
            if ".wav" not in self.file_path:
                raise ValueError("only .wav files are supported")

            # Load audio
            audio, _ = librosa.load(self.file_path, mono=True, sr=self.sr)

            # Pad with silence so that at least ``x_res`` frames are produced.
            min_samples = self.x_res * self.hop_length
            if len(audio) < min_samples:
                audio = np.concatenate([audio, np.zeros((min_samples - len(audio),))])

            # Compute mel spectrogram
            S = librosa.feature.melspectrogram(
                y=audio,
                sr=self.sr,
                n_fft=self.n_fft,
                hop_length=self.hop_length,
                n_mels=self.n_mels,
                fmax=self.sr // 2,
            )
            log_S = librosa.power_to_db(S, ref=np.max, top_db=self.top_db)
            # Ensure the spectrogram is of the desired size
            log_S = log_S[:self.y_res, :self.x_res]

            # Map [-top_db, 0] dB onto [0, 255]; the +0.5 rounds to nearest
            # before the truncating cast.
            scaled = ((log_S + self.top_db) * 255 / self.top_db).clip(0, 255)
            self.spectrogram = (scaled + 0.5).astype(np.uint8)
            self.image = Image.fromarray(self.spectrogram)
        except Exception as e:
            print(f"Error loading {self.file_path}: {e}")

    def load_spectrogram(self):
        """Set ``self.spectrogram`` from the pixel array of ``self.image``."""
        self.spectrogram = np.array(self.image)

    def load_image(self):
        """Cast ``self.spectrogram`` to ``uint8`` and build ``self.image`` from it."""
        self.spectrogram = self.spectrogram.astype("uint8")
        self.image = Image.fromarray(self.spectrogram)

    def get_spectrogram(self):
        """Return the stored spectrogram array (``None`` if nothing was loaded)."""
        return self.spectrogram

    def get_image(self):
        """Return the stored PIL image (``None`` if nothing was loaded)."""
        return self.image

    def get_audio(self):
        """Reconstruct audio from the stored spectrogram.

        Returns:
            An ``Audio`` object built from the reconstructed samples and
            ``self.sr``.

        Raises:
            ValueError: If no spectrogram has been loaded.
        """
        self._require_spectrogram()
        # Invert the uint8 scaling applied in ``load_file``: [0, 255] -> [-top_db, 0] dB.
        log_S = self.spectrogram.astype("float") * self.top_db / 255 - self.top_db
        S = librosa.db_to_power(log_S)
        audio = librosa.feature.inverse.mel_to_audio(
            S, sr=self.sr, n_fft=self.n_fft, hop_length=self.hop_length, n_iter=self.n_iter
        )
        # NOTE(review): `Audio` is not imported in the visible part of this
        # file; presumably `IPython.display.Audio` -- confirm and add the import.
        return Audio(audio, rate=self.sr)

    def save_audio(self):
        """Write the reconstructed audio to ``config.generated_track_path``."""
        audio = self.get_audio()
        # NOTE(review): assumes the object returned by `get_audio` exposes
        # sample data as `.data` and the sample rate as `.rate` -- TODO confirm.
        sf.write(config.generated_track_path, audio.data, audio.rate)
        print(f"Audio saved to {config.generated_track_path}")

    def plot_spectrogram(self):
        """Show the stored spectrogram as an image plot.

        Raises:
            ValueError: If no spectrogram has been loaded.
        """
        self._require_spectrogram()
        # NOTE(review): `plt` is not imported in the visible part of this
        # file; presumably `matplotlib.pyplot` -- confirm and add the import.
        plt.figure(figsize=(10, 4))
        plt.imshow(self.spectrogram, aspect='auto', origin='lower', cmap='viridis')
        plt.colorbar(label='Magnitude')
        plt.title('Mel Spectrogram')
        plt.xlabel('Time (frames)')
        plt.ylabel('Frequency (Mel bins)')
        plt.tight_layout()
        plt.show()