File size: 7,359 Bytes
ae1bdf7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
import librosa
import numpy as np
import torch

from tools import np_power_to_db, decode_stft, depad_STFT


def spectrogram_to_Gradio_image(spc):
    """Render a spectrogram as an RGB ``uint8`` image array.

    The input is collapsed to its last two axes, its absolute value is passed
    through ``np_power_to_db`` (imported from ``tools``; presumably returns
    decibels -- confirm), the result is flipped vertically and mapped linearly
    from the fixed window [-80, 0] onto [0, 255]. The red and green channels
    carry the data; the blue channel is a constant.

    Values falling outside the [-80, 0] window are clipped to the nearest
    bound instead of being left to overflow in the ``uint8`` cast.

    Args:
        spc: ``np.ndarray`` whose last two axes are read as
            (frequency, time). Any leading axes must have a total size of 1
            because the array is reshaped to exactly two dimensions.

    Returns:
        ``np.ndarray`` of shape (frequency, time, 3) and dtype ``uint8``.
    """
    frequency_resolution, time_resolution = spc.shape[-2], spc.shape[-1]
    spc = np.reshape(spc, (frequency_resolution, time_resolution))

    magnitude_spectrum = np.abs(spc)
    log_spectrum = np_power_to_db(magnitude_spectrum)
    flipped_log_spectrum = np.flipud(log_spectrum)

    colorful_spc = np.empty((frequency_resolution, time_resolution, 3))
    colorful_spc[:, :, 0] = flipped_log_spectrum
    colorful_spc[:, :, 1] = flipped_log_spectrum
    colorful_spc[:, :, 2] = -60.0  # constant blue tint

    # Map [-80, 0] onto [0, 1].
    rescaled = (colorful_spc + 80.0) / 80.0
    # Casting an out-of-range float to uint8 wraps around (or is
    # platform-dependent), so saturate explicitly before the cast.
    rescaled = np.clip(rescaled, 0.0, 1.0)
    return (255.0 * rescaled).astype(np.uint8)


def phase_to_Gradio_image(phase):
    """Render a phase map as an RGB ``uint8`` image array.

    The input is collapsed to its last two axes, flipped vertically and mapped
    linearly from [-1, 1] onto [0, 255]. The red and green channels carry the
    data; the blue channel is a constant (0.2 of full scale).

    Values outside [-1, 1] are clipped to the nearest bound instead of being
    left to overflow in the ``uint8`` cast.

    NOTE(review): the callers in this file pass ``np.angle(...)`` output,
    which spans [-pi, pi], so angles beyond +/-1 rad saturate. Confirm whether
    the intended normalisation is by 1 or by pi.

    Args:
        phase: ``np.ndarray`` whose last two axes are read as
            (frequency, time). Any leading axes must have a total size of 1
            because the array is reshaped to exactly two dimensions.

    Returns:
        ``np.ndarray`` of shape (frequency, time, 3) and dtype ``uint8``.
    """
    frequency_resolution, time_resolution = phase.shape[-2], phase.shape[-1]
    phase = np.reshape(phase, (frequency_resolution, time_resolution))

    flipped_phase = np.flipud(phase)
    # Map [-1, 1] onto [0, 1].
    flipped_phase = (flipped_phase + 1.0) / 2.0

    colorful_spc = np.zeros((frequency_resolution, time_resolution, 3))
    colorful_spc[:, :, 0] = flipped_phase
    colorful_spc[:, :, 1] = flipped_phase
    colorful_spc[:, :, 2] = 0.2

    # Casting an out-of-range float to uint8 wraps around (or is
    # platform-dependent), so saturate explicitly before the cast.
    colorful_spc = np.clip(colorful_spc, 0.0, 1.0)
    return (255.0 * colorful_spc).astype(np.uint8)


def latent_representation_to_Gradio_image(latent_representation):
    """Render a channel-first latent as an enlarged channel-last ``uint8`` image.

    Every channel is min-max scaled to [0, 255] independently. The array is
    then transposed from (channels, rows, columns) to
    (rows, columns, channels), each pixel is repeated 8 times along both
    spatial axes, and the result is flipped vertically.

    The input is never modified: scaling is computed into a new array, so
    neither a caller's ndarray nor a CPU tensor sharing memory with the
    converted array is overwritten. A channel whose values are all equal maps
    to 0 rather than dividing by zero.

    Args:
        latent_representation: 3-D ``np.ndarray``, or any object exposing
            ``.to("cpu").detach().numpy()`` (e.g. a ``torch.Tensor``), laid
            out as (channels, rows, columns).

    Returns:
        ``np.ndarray`` of shape (8 * rows, 8 * columns, channels) and dtype
        ``uint8``.
    """
    if not isinstance(latent_representation, np.ndarray):
        latent_representation = latent_representation.to("cpu").detach().numpy()
    image = latent_representation

    # Per-channel extrema, kept broadcastable against the full array.
    lows = image.min(axis=(1, 2), keepdims=True)
    highs = image.max(axis=(1, 2), keepdims=True)
    spans = highs - lows
    # A flat channel has zero span; dividing by 1 leaves it at 0.
    spans[spans == 0] = 1
    normalized = (image - lows) / spans * 255

    image_transposed = np.transpose(normalized, (1, 2, 0))
    enlarged_image = np.repeat(image_transposed, 8, axis=0)
    enlarged_image = np.repeat(enlarged_image, 8, axis=1)
    return np.flipud(enlarged_image).astype(np.uint8)


def InputBatch2Encode_STFT(encoder, STFT_batch, resolution=(512, 256), quantizer=None, squared=True):
    """Encode a batch of STFT representations and render each input item.

    The batch is moved to the encoder's device and encoded. Without a
    ``quantizer`` the encoder must return a 3-tuple whose last element is the
    latent batch; with a ``quantizer`` the encoder's single return value is
    the latent batch and is additionally passed through the quantizer.

    Each input item is then decoded with ``decode_stft`` / ``depad_STFT`` and
    turned into a spectrogram image, a phase image and a time-domain signal
    (``librosa.istft`` with hop length 256 and window length 1024).

    Args:
        encoder: callable model exposing ``parameters()``.
        STFT_batch: batch supporting ``.to(...)``, ``.detach()`` and
            ``.numpy()`` (e.g. a ``torch.Tensor``), iterated along axis 0.
        resolution: 2-item sequence; unpacked but not otherwise used.
        quantizer: optional callable returning
            ``(quantized_latents, loss, (a, b, c))``.
        squared: accepted for interface compatibility; not used here.

    Returns:
        Tuple ``(spectrogram_images, phase_images, signals, latents,
        quantized_latents)``; the first three are lists with one entry per
        batch item and ``quantized_latents`` is ``None`` without a quantizer.
    """
    # Todo: remove resolution hard-coding
    _frequency_bins, _time_frames = resolution

    def render(encoded_stft):
        # Decode one batch item into (spectrogram image, phase image, audio).
        complex_stft = depad_STFT(decode_stft(encoded_stft))
        magnitude, angle = np.abs(complex_stft), np.angle(complex_stft)
        spectrogram_image = spectrogram_to_Gradio_image(magnitude)
        phase_image = phase_to_Gradio_image(angle)
        signal = librosa.istft(complex_stft, hop_length=256, win_length=1024)
        return spectrogram_image, phase_image, signal

    target_device = next(encoder.parameters()).device
    if quantizer is None:
        _mu, _logvar, latents = encoder(STFT_batch.to(target_device))
        quantized_latents = None
    else:
        latents = encoder(STFT_batch.to(target_device))
        quantized_latents, _loss, (_, _, _) = quantizer(latents)

    spectrogram_images, phase_images, signals = [], [], []
    for encoded_stft in STFT_batch.to("cpu").detach().numpy():
        spectrogram_image, phase_image, signal = render(encoded_stft)
        spectrogram_images.append(spectrogram_image)
        phase_images.append(phase_image)
        signals.append(signal)

    return spectrogram_images, phase_images, signals, latents, quantized_latents


def encodeBatch2GradioOutput_STFT(decoder, latent_vector_batch, resolution=(512, 256), original_STFT_batch=None):
    """Decode a batch of latents and render every reconstruction.

    Each decoded item is converted with ``decode_stft`` / ``depad_STFT`` into
    a spectrogram image, a phase image and a time-domain signal
    (``librosa.istft`` with hop length 256 and window length 1024).

    When ``original_STFT_batch`` is given, channel 0 of each decoded item is
    then overwritten in place with channel 0 of the matching original item
    (presumably the amplitude channel, going by the result names -- confirm)
    and the item is rendered a second time.

    Args:
        decoder: callable model exposing ``parameters()``.
        latent_vector_batch: latent batch; an ``np.ndarray`` is converted to a
            tensor on the decoder's device, anything else is passed as is.
        resolution: 2-item sequence; unpacked but not otherwise used.
        original_STFT_batch: optional batch indexed as
            ``[item, channel, :, :]``.

    Returns:
        Tuple of six lists: spectrogram images, phase images and signals of
        the plain reconstructions, followed by the same three for the
        reconstructions with channel 0 replaced (empty lists when
        ``original_STFT_batch`` is ``None``).
    """
    # Todo: remove resolution hard-coding
    _frequency_bins, _time_frames = resolution

    def render(encoded_stft):
        # Decode one batch item into (spectrogram image, phase image, audio).
        complex_stft = depad_STFT(decode_stft(encoded_stft))
        magnitude, angle = np.abs(complex_stft), np.angle(complex_stft)
        spectrogram_image = spectrogram_to_Gradio_image(magnitude)
        phase_image = phase_to_Gradio_image(angle)
        signal = librosa.istft(complex_stft, hop_length=256, win_length=1024)
        return spectrogram_image, phase_image, signal

    if isinstance(latent_vector_batch, np.ndarray):
        as_tensor = torch.from_numpy(latent_vector_batch)
        latent_vector_batch = as_tensor.to(next(decoder.parameters()).device)

    decoded_batch = decoder(latent_vector_batch).to("cpu").detach().numpy()

    spectrogram_images, phase_images, signals = [], [], []
    swapped_spectrogram_images, swapped_phase_images, swapped_signals = [], [], []

    for position, encoded_stft in enumerate(decoded_batch):
        spectrogram_image, phase_image, signal = render(encoded_stft)
        spectrogram_images.append(spectrogram_image)
        phase_images.append(phase_image)
        signals.append(signal)

        if original_STFT_batch is None:
            continue

        # Replace channel 0 with the original's, then render again.
        encoded_stft[0, :, :] = original_STFT_batch[position, 0, :, :]
        spectrogram_image, phase_image, signal = render(encoded_stft)
        swapped_spectrogram_images.append(spectrogram_image)
        swapped_phase_images.append(phase_image)
        swapped_signals.append(signal)

    return (
        spectrogram_images,
        phase_images,
        signals,
        swapped_spectrogram_images,
        swapped_phase_images,
        swapped_signals,
    )



def add_instrument(source_dict, virtual_instruments_dict, virtual_instrument_name, sample_index):
    """Register one generated sample as a named virtual instrument.

    The entry is assembled from ``source_dict`` and stored under
    ``virtual_instrument_name`` inside
    ``virtual_instruments_dict["virtual_instruments"]``, replacing any entry
    already stored under that name. The mapping is updated in place.

    Args:
        source_dict: mapping holding the per-sample sequences
            ``"latent_representations"``,
            ``"quantized_latent_representations"``,
            ``"new_sound_rec_signals_gradio"``,
            ``"new_sound_spectrogram_gradio_images"`` and
            ``"new_sound_phase_gradio_images"``, plus a shared ``"sampler"``.
        virtual_instruments_dict: mapping with a ``"virtual_instruments"``
            sub-mapping.
        virtual_instrument_name: key under which the entry is stored.
        sample_index: index of the sample to take from each sequence.

    Returns:
        The same ``virtual_instruments_dict`` object that was passed in.
    """
    registry = virtual_instruments_dict["virtual_instruments"]

    def pick(key):
        # Select this sample's item from one of the per-sample sequences.
        return source_dict[key][sample_index]

    entry = {
        "latent_representation": pick("latent_representations"),
        "quantized_latent_representation": pick("quantized_latent_representations"),
        "sampler": source_dict["sampler"],
        "signal": pick("new_sound_rec_signals_gradio"),
        "spectrogram_gradio_image": pick("new_sound_spectrogram_gradio_images"),
        "phase_gradio_image": pick("new_sound_phase_gradio_images"),
    }

    registry[virtual_instrument_name] = entry
    virtual_instruments_dict["virtual_instruments"] = registry
    return virtual_instruments_dict