File size: 3,502 Bytes
bd6e54b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import os

import librosa
import mido
import numpy as np
import torch

from tools import read_wav_to_numpy, pad_STFT, encode_stft
from webUI.natural_language_guided_4.gradio_webUI import GradioWebUI
from webUI.natural_language_guided_4.utils import InputBatch2Encode_STFT


def load_presets(gradioWebUI: GradioWebUI):
    """Load preset virtual instruments and preset MIDI files for the web UI.

    Each preset instrument WAV is STFT-encoded and pushed through the VAE
    encoder/quantizer to obtain latent representations; a small set of
    bundled MIDI files is loaded alongside.

    Args:
        gradioWebUI: Shared configuration/model container providing the VAE
            encoder and quantizer, STFT/VAE resolutions, target device and
            the ``squared`` encoding flag.

    Returns:
        Tuple ``(virtual_instruments, midis)``:
            * ``virtual_instruments`` maps ``"preset_<name>"`` to a dict with
              keys ``latent_representation``, ``quantized_latent_representation``
              (both CPU numpy arrays), ``sampler`` (``"ddim"``), ``signal``
              (``(sample_rate, audio)`` tuple) and the Gradio-ready
              ``spectrogram_gradio_image`` / ``phase_gradio_image``.
            * ``midis`` maps a preset MIDI file name to a ``mido.MidiFile``.
    """
    # Read only the configuration actually used below.  Previously this
    # function also pulled uNet, timesteps, CLAP, the decoder, etc. from the
    # config without ever using them (and gradioWebUI.sample_rate was
    # immediately shadowed inside the helper) — those dead reads are removed.
    time_resolution = gradioWebUI.time_resolution
    VAE_scale = gradioWebUI.VAE_scale
    width = int(time_resolution / VAE_scale)
    VAE_quantizer = gradioWebUI.VAE_quantizer
    VAE_encoder = gradioWebUI.VAE_encoder
    device = gradioWebUI.device
    squared = gradioWebUI.squared

    def add_preset_instruments(virtual_instruments, instrument_name):
        """Encode one preset WAV and add it to *virtual_instruments*; returns the dict."""
        instruments_path = os.path.join("webUI", "presets", "instruments", f"{instrument_name}.wav")
        # NOTE: this sample_rate comes from the WAV file itself and is the
        # value stored in the preset (not gradioWebUI.sample_rate).
        sample_rate, origin_audio = read_wav_to_numpy(instruments_path)

        D = librosa.stft(origin_audio, n_fft=1024, hop_length=256, win_length=1024)
        padded_D = pad_STFT(D)
        encoded_D = encode_stft(padded_D)

        # Batch size is fixed at 1: insert a leading batch axis directly
        # instead of the previous redundant np.repeat(..., 1, axis=0) copy.
        origin_spectrogram_batch_tensor = torch.from_numpy(
            encoded_D[np.newaxis, :, :, :]).float().to(device)

        # Todo: remove hard-coding (the 512 frequency resolution should come
        # from gradioWebUI.freq_resolution).
        origin_flipped_log_spectrums, origin_flipped_phases, origin_signals, origin_latent_representations, quantized_origin_latent_representations = InputBatch2Encode_STFT(
            VAE_encoder, origin_spectrogram_batch_tensor, resolution=(512, width * VAE_scale), quantizer=VAE_quantizer,
            squared=squared)

        virtual_instruments[f"preset_{instrument_name}"] = {
            "latent_representation": origin_latent_representations[0].to("cpu").detach().numpy(),
            "quantized_latent_representation": quantized_origin_latent_representations[0].to("cpu").detach().numpy(),
            "sampler": "ddim",
            "signal": (sample_rate, origin_audio),
            "spectrogram_gradio_image": origin_flipped_log_spectrums[0],
            "phase_gradio_image": origin_flipped_phases[0],
        }
        return virtual_instruments

    virtual_instruments = {}
    preset_instrument_names = ["ax", "electronic_sound", "organ", "synth_lead", "keyboard", "string"]
    for preset_instrument_name in preset_instrument_names:
        virtual_instruments = add_preset_instruments(virtual_instruments, preset_instrument_name)

    def load_midi_files():
        """Load the bundled preset MIDI files, keyed by file name (no extension)."""
        midi_file_names = ["Ode_to_Joy_Easy_variation", "Air_on_the_G_String", "Canon_in_D"]
        midis_dict = {}
        for midi_file_name in midi_file_names:
            midi_path = os.path.join("webUI", "presets", "midis", f"{midi_file_name}.mid")
            midis_dict[midi_file_name] = mido.MidiFile(midi_path)
        return midis_dict

    midis = load_midi_files()

    return virtual_instruments, midis