Spaces:
Sleeping
Sleeping
import os | |
import librosa | |
import mido | |
import numpy as np | |
import torch | |
from tools import read_wav_to_numpy, pad_STFT, encode_stft | |
from webUI.natural_language_guided_4.gradio_webUI import GradioWebUI | |
from webUI.natural_language_guided_4.utils import InputBatch2Encode_STFT | |
def load_presets(gradioWebUI: GradioWebUI):
    """Load the preset virtual instruments and preset MIDI files for the web UI.

    Each preset instrument is a .wav under ``webUI/presets/instruments/``; it is
    STFT-encoded and pushed through the VAE encoder so the UI can reuse its
    latent representation directly. Preset MIDI files live under
    ``webUI/presets/midis/``.

    Args:
        gradioWebUI: shared configuration/model container for the web UI.

    Returns:
        tuple: ``(virtual_instruments, midis)`` where ``virtual_instruments``
        maps ``"preset_<name>"`` to a dict of latent representations, sampler
        name, raw signal, and Gradio-ready spectrogram/phase images, and
        ``midis`` maps MIDI file names to ``mido.MidiFile`` objects.
    """
    # Only the configuration actually used below is pulled into locals;
    # the original also loaded uNet/CLAP/etc. which this function never touches.
    time_resolution = gradioWebUI.time_resolution
    VAE_scale = gradioWebUI.VAE_scale
    width = int(time_resolution / VAE_scale)
    VAE_quantizer = gradioWebUI.VAE_quantizer
    VAE_encoder = gradioWebUI.VAE_encoder
    device = gradioWebUI.device
    squared = gradioWebUI.squared

    def add_preset_instruments(virtual_instruments, instrument_name):
        # Read one preset .wav and register its VAE encoding under
        # the key "preset_<instrument_name>".
        instruments_path = os.path.join("webUI", "presets", "instruments", f"{instrument_name}.wav")
        preset_sample_rate, origin_audio = read_wav_to_numpy(instruments_path)

        D = librosa.stft(origin_audio, n_fft=1024, hop_length=256, win_length=1024)
        padded_D = pad_STFT(D)
        encoded_D = encode_stft(padded_D)

        # Batch of exactly one spectrogram (the original no-op np.repeat with
        # count 1 is replaced by a plain leading-axis expansion).
        origin_spectrogram_batch_tensor = torch.from_numpy(
            encoded_D[np.newaxis, ...]).float().to(device)

        # TODO: remove hard-coded frequency resolution (512).
        (origin_flipped_log_spectrums, origin_flipped_phases, origin_signals,
         origin_latent_representations,
         quantized_origin_latent_representations) = InputBatch2Encode_STFT(
            VAE_encoder, origin_spectrogram_batch_tensor,
            resolution=(512, width * VAE_scale), quantizer=VAE_quantizer,
            squared=squared)

        virtual_instruments[f"preset_{instrument_name}"] = {
            "latent_representation":
                origin_latent_representations[0].to("cpu").detach().numpy(),
            "quantized_latent_representation":
                quantized_origin_latent_representations[0].to("cpu").detach().numpy(),
            "sampler": "ddim",
            # The signal keeps the sample rate read from the preset file itself.
            "signal": (preset_sample_rate, origin_audio),
            "spectrogram_gradio_image": origin_flipped_log_spectrums[0],
            "phase_gradio_image": origin_flipped_phases[0],
        }
        return virtual_instruments

    virtual_instruments = {}
    preset_instrument_names = ["ax", "electronic_sound", "organ", "synth_lead", "keyboard", "string"]
    for preset_instrument_name in preset_instrument_names:
        virtual_instruments = add_preset_instruments(virtual_instruments, preset_instrument_name)

    def load_midi_files():
        # Parse each bundled .mid into a mido.MidiFile, keyed by file name.
        midis_dict = {}
        midi_file_names = ["Ode_to_Joy_Easy_variation", "Air_on_the_G_String", "Canon_in_D"]
        for midi_file_name in midi_file_names:
            midi_path = os.path.join("webUI", "presets", "midis", f"{midi_file_name}.mid")
            midis_dict[midi_file_name] = mido.MidiFile(midi_path)
        return midis_dict

    midis = load_midi_files()
    return virtual_instruments, midis