# WeixuanYuan's picture
# Upload 70 files
# bd6e54b verified
import os
import librosa
import mido
import numpy as np
import torch
from tools import read_wav_to_numpy, pad_STFT, encode_stft
from webUI.natural_language_guided_4.gradio_webUI import GradioWebUI
from webUI.natural_language_guided_4.utils import InputBatch2Encode_STFT
def load_presets(gradioWebUI: GradioWebUI):
    """Load the preset virtual instruments and MIDI files bundled with the web UI.

    Each preset WAV under ``webUI/presets/instruments`` is STFT-encoded and pushed
    through the VAE encoder/quantizer; each preset MIDI under ``webUI/presets/midis``
    is parsed with :mod:`mido`.

    Args:
        gradioWebUI: Shared configuration/model holder. Only the attributes this
            loader actually needs are read (VAE encoder/quantizer/scale, time
            resolution, device, ``squared``).

    Returns:
        Tuple ``(virtual_instruments, midis)``:
            virtual_instruments: dict mapping ``"preset_<name>"`` to a dict with
                the instrument's (quantized) latent representation as numpy
                arrays, the sampler name, the raw ``(sample_rate, audio)`` signal,
                and Gradio-ready spectrogram/phase images.
            midis: dict mapping MIDI file name to a ``mido.MidiFile``.
    """
    # Read only the configuration that is used below; GradioWebUI exposes many
    # more attributes (uNet, CLAP, timesteps, ...) that this loader ignores.
    time_resolution = gradioWebUI.time_resolution
    VAE_scale = gradioWebUI.VAE_scale
    width = int(time_resolution / VAE_scale)
    VAE_quantizer = gradioWebUI.VAE_quantizer
    VAE_encoder = gradioWebUI.VAE_encoder
    device = gradioWebUI.device
    squared = gradioWebUI.squared

    def add_preset_instruments(virtual_instruments, instrument_name):
        """Encode one preset WAV into latent space and register it in-place."""
        instruments_path = os.path.join("webUI", "presets", "instruments", f"{instrument_name}.wav")
        sample_rate, origin_audio = read_wav_to_numpy(instruments_path)

        D = librosa.stft(origin_audio, n_fft=1024, hop_length=256, win_length=1024)
        padded_D = pad_STFT(D)
        encoded_D = encode_stft(padded_D)

        # Single-instrument batch: just add a leading batch axis of size 1
        # (replaces the previous no-op np.repeat(..., 1, axis=0)).
        origin_spectrogram_batch_tensor = torch.from_numpy(
            encoded_D[np.newaxis, :, :, :]).float().to(device)

        # NOTE(review): the frequency resolution 512 is hard-coded (original
        # TODO). It presumably equals gradioWebUI.freq_resolution — confirm
        # before substituting the attribute.
        origin_flipped_log_spectrums, origin_flipped_phases, origin_signals, origin_latent_representations, quantized_origin_latent_representations = InputBatch2Encode_STFT(
            VAE_encoder, origin_spectrogram_batch_tensor, resolution=(512, width * VAE_scale),
            quantizer=VAE_quantizer, squared=squared)

        virtual_instrument = {
            "latent_representation": origin_latent_representations[0].to("cpu").detach().numpy(),
            "quantized_latent_representation": quantized_origin_latent_representations[0].to(
                "cpu").detach().numpy(),
            "sampler": "ddim",
            "signal": (sample_rate, origin_audio),
            "spectrogram_gradio_image": origin_flipped_log_spectrums[0],
            "phase_gradio_image": origin_flipped_phases[0]}
        virtual_instruments[f"preset_{instrument_name}"] = virtual_instrument
        return virtual_instruments

    virtual_instruments = {}
    preset_instrument_names = ["ax", "electronic_sound", "organ", "synth_lead", "keyboard", "string"]
    for preset_instrument_name in preset_instrument_names:
        virtual_instruments = add_preset_instruments(virtual_instruments, preset_instrument_name)

    def load_midi_files():
        """Parse the bundled preset MIDI files into mido.MidiFile objects."""
        midis_dict = {}
        midi_file_names = ["Ode_to_Joy_Easy_variation", "Air_on_the_G_String", "Canon_in_D"]
        for midi_file_name in midi_file_names:
            midi_path = os.path.join("webUI", "presets", "midis", f"{midi_file_name}.mid")
            midis_dict[midi_file_name] = mido.MidiFile(midi_path)
        return midis_dict

    midis = load_midi_files()
    return virtual_instruments, midis