File size: 10,434 Bytes
bd6e54b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
import torch
import gradio as gr
import mido
from io import BytesIO
# import pyrubberband as pyrb

from webUI.natural_language_guided_4.track_maker import DiffSynth, Track


def get_arrangement_module(gradioWebUI, virtual_instruments_state, midi_files_state):
    """Build the "Arrangement" tab of the Gradio web UI.

    Lets the user select or upload a MIDI file, assign one virtual
    instrument to each MIDI track, and synthesize the whole piece with
    ``DiffSynth``.

    Args:
        gradioWebUI: Shared application object exposing the models
            (``uNet``, VAE, CLAP), device, sample rate and slider-factory
            helpers used below.
        virtual_instruments_state: ``gr.State`` holding a dict with key
            ``"virtual_instruments"`` mapping instrument name ->
            instrument config (at least ``"latent_representation"`` and
            ``"sampler"``; see ``make_track``).
        midi_files_state: ``gr.State`` holding a dict with key ``"midis"``
            mapping MIDI name -> ``mido.MidiFile``.

    Returns:
        None. All components and callbacks are registered on the active
        ``gr.Blocks`` context as a side effect.
    """
    # Load configurations
    uNet = gradioWebUI.uNet
    freq_resolution, time_resolution = gradioWebUI.freq_resolution, gradioWebUI.time_resolution
    VAE_scale = gradioWebUI.VAE_scale
    # Latent-space spatial dims: spectrogram resolution divided by the VAE
    # downscale factor. (height/width/channels are not read again in this
    # module — kept for parity with sibling modules.)
    height, width, channels = int(freq_resolution / VAE_scale), int(time_resolution / VAE_scale), gradioWebUI.channels

    timesteps = gradioWebUI.timesteps
    VAE_quantizer = gradioWebUI.VAE_quantizer
    VAE_decoder = gradioWebUI.VAE_decoder
    CLAP = gradioWebUI.CLAP
    CLAP_tokenizer = gradioWebUI.CLAP_tokenizer
    device = gradioWebUI.device
    squared = gradioWebUI.squared
    sample_rate = gradioWebUI.sample_rate
    noise_strategy = gradioWebUI.noise_strategy

    def read_midi(midi, midi_dict):
        """Parse an uploaded MIDI file and register it as "uploaded_midi".

        Args:
            midi: Raw bytes of the uploaded file (``gr.File`` with
                ``type="binary"``).
            midi_dict: Current value of ``midi_files_state``.

        Returns:
            Dict of Gradio updates: info textbox, current-midi name state,
            and the updated midi-files state.
        """
        mid = mido.MidiFile(file=BytesIO(midi))
        tracks = [Track(t, mid.ticks_per_beat) for t in mid.tracks]

        midi_info_text = f"Uploaded midi:"
        for i, track in enumerate(tracks):
            midi_info_text += f"\n{len(track.events)} events loaded from Track {i}."

        # Read, mutate, then re-assign the inner dict so the new value is
        # what gets written back into the gr.State.
        midis = midi_dict["midis"]
        midis["uploaded_midi"] = mid
        midi_dict["midis"] = midis

        # NOTE(review): the summary text is passed as `placeholder`, not
        # `value`, so the textbox itself stays empty — presumably
        # intentional (matches select_midi below); confirm.
        return {midi_info_textbox: gr.Textbox(label="Midi info", lines=10,
                                              placeholder=midi_info_text),
                current_midi_state: "uploaded_midi",
                midi_files_state: midi_dict}

    def make_track(inpaint_steps, current_midi_name, midi_dict, max_notes, noising_strength, attack, before_release, current_instruments,

                   virtual_instruments_dict):
        """Synthesize the selected MIDI with the chosen instruments.

        Builds one sampling config per instrument name in
        ``current_instruments``, instantiates ``DiffSynth`` and renders the
        full piece.

        Returns:
            Dict mapping the output audio component to a
            ``(sample_rate, waveform)`` tuple for ``gr.Audio``.
        """
        # Full re-noising (strength == 1) is the expected setting here;
        # anything lower is allowed but flagged.
        if noising_strength < 1:
            print(f"Warning: making track with noising_strength = {noising_strength} < 1")
        virtual_instruments = virtual_instruments_dict["virtual_instruments"]
        sample_steps = int(inpaint_steps)

        print(f"current_instruments: {current_instruments}")
        instrument_names = current_instruments
        instruments_configs = {}

        # Build one diffusion/sampling config per instrument referenced by
        # the per-track selection.
        for virtual_instrument_name in instrument_names:
            virtual_instrument = virtual_instruments[virtual_instrument_name]

            # Stored latent is a plain (nested-list) array; rebuild the
            # tensor on the model device.
            latent_representation = torch.tensor(virtual_instrument["latent_representation"], dtype=torch.float32).to(
                device)
            sampler = virtual_instrument["sampler"]

            batchsize = 1

            # Expand to a batch dimension expected by the sampler.
            latent_representation = latent_representation.repeat(batchsize, 1, 1, 1)

            instruments_configs[virtual_instrument_name] = {
                'sample_steps': sample_steps,
                'sampler': sampler,
                'noising_strength': noising_strength,
                'latent_representation': latent_representation,
                'attack': attack,
                'before_release': before_release}

        diffSynth = DiffSynth(instruments_configs, uNet, VAE_quantizer, VAE_decoder, CLAP, CLAP_tokenizer, device)

        midis = midi_dict["midis"]
        mid = midis[current_midi_name]
        # max_notes caps per-track synthesis (see slider info: prevents
        # Gradio timeouts on long pieces).
        full_audio = diffSynth.get_music(mid, instrument_names, max_notes=max_notes)

        return {track_audio: (sample_rate, full_audio)}

    with gr.Tab("Arrangement"):
        default_instrument = "preset_string"
        # One default-instrument slot per possible track. NOTE(review): the
        # hard cap of 100 slots assumes no MIDI has more tracks — confirm.
        current_instruments_state = gr.State(value=[default_instrument for _ in range(100)])
        current_midi_state = gr.State(value="Ode_to_Joy_Easy_variation")

        gr.Markdown("Make music with generated sounds!")
        with gr.Row(variant="panel"):
            with gr.Column(scale=3):

                # Re-rendered whenever midi_files_state changes (e.g. after
                # an upload), so the dropdown always lists current MIDIs.
                @gr.render(inputs=midi_files_state)
                def check_midis(midi_dict):
                    midis = midi_dict["midis"]
                    midi_names = list(midis.keys())

                    instrument_dropdown = gr.Dropdown(
                        midi_names, label="Select from preset midi files", value="Ode_to_Joy_Easy_variation"
                    )

                    def select_midi(midi_name):
                        """Show a per-track event summary for the chosen MIDI and update state."""
                        # print(f"midi_name: {midi_name}")
                        mid = midis[midi_name]
                        tracks = [Track(t, mid.ticks_per_beat) for t in mid.tracks]
                        midi_info_text = f"Name: {midi_name}"
                        for i, track in enumerate(tracks):
                            midi_info_text += f"\n{len(track.events)} events loaded from Track {i}."

                        # Summary shown via `placeholder` (see read_midi).
                        return {midi_info_textbox: gr.Textbox(label="Midi info", lines=10,
                                                              placeholder=midi_info_text),
                                current_midi_state: midi_name}

                    instrument_dropdown.select(select_midi, inputs=instrument_dropdown,
                                               outputs=[midi_info_textbox, current_midi_state])

                midi_file = gr.File(label="Upload a midi file", type="binary", scale=1)
                midi_info_textbox = gr.Textbox(label="Midi info", lines=10,
                                               placeholder="Please select/upload a midi on the left.", scale=3,
                                               visible=False)

            with gr.Column(scale=3, ):

                # Re-rendered when the selected MIDI or the instrument
                # library changes: one instrument dropdown per MIDI track.
                @gr.render(inputs=[current_midi_state, midi_files_state, virtual_instruments_state])
                def render_select_instruments(current_midi_name, midi_dict, virtual_instruments_dict):

                    virtual_instruments = virtual_instruments_dict["virtual_instruments"]
                    instrument_names = list(virtual_instruments.keys())

                    midis = midi_dict["midis"]
                    mid = midis[current_midi_name]
                    tracks = [Track(t, mid.ticks_per_beat) for t in mid.tracks]

                    dropdowns = []
                    for i, track in enumerate(tracks):
                        dropdowns.append(gr.Dropdown(
                            instrument_names, value=default_instrument, label=f"Track {i}:  {len(track.events)} notes",
                            info=f"Select an instrument to play this track!"
                        ))

                    def select_instruments(*instruments):
                        # Receives the values of ALL dropdowns; replaces the
                        # whole current_instruments_state list. NOTE(review):
                        # this shrinks the state from 100 slots to one per
                        # track — downstream code only indexes by track, so
                        # this looks fine; confirm.
                        return instruments

                    # Any single dropdown change re-reads every dropdown so
                    # the full per-track selection stays in sync.
                    for d in dropdowns:
                        d.select(select_instruments, inputs=dropdowns,
                                 outputs=current_instruments_state)


            with gr.Column(scale=3):
                max_notes_slider = gr.Slider(minimum=10.0, maximum=999.0, value=100.0, step=1.0,
                                             label="Maximum number of synthesized notes in each track",
                                             info="Lower this value to prevent Gradio timeouts")
                make_track_button = gr.Button(variant="primary", value="Make track", scale=1)
                track_audio = gr.Audio(type="numpy", label="Play music", interactive=False)

        # Hidden configuration rows. Only a subset of these controls
        # (inpaint_steps, noising_strength, attack, before_release) is wired
        # into make_track below; the rest appear to be kept for parity with
        # other tabs.
        with gr.Row(variant="panel", visible=False):
            with gr.Tab("Origin sound"):
                inpaint_steps_slider = gr.Slider(minimum=5.0, maximum=999.0, value=20.0, step=1.0,
                                                 label="inpaint_steps")
                noising_strength_slider = gradioWebUI.get_noising_strength_slider(default_noising_strength=1.)
                end_noise_level_ratio_slider = gr.Slider(minimum=0.0, maximum=1., value=0.0, step=0.01,
                                                         label="end_noise_level_ratio")
                attack_slider = gr.Slider(minimum=0.0, maximum=1.5, value=0.5, step=0.01, label="attack in sec")
                before_release_slider = gr.Slider(minimum=0.0, maximum=1.5, value=0.5, step=0.01,
                                                  label="before_release in sec")
                release_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.3, step=0.01, label="release in sec")
                mask_flexivity_slider = gr.Slider(minimum=0.01, maximum=1.00, value=1., step=0.01,
                                                  label="mask_flexivity")
            with gr.Tab("Length adjustment config"):
                use_dynamic_mask_checkbox = gr.Checkbox(label="Use dynamic mask", value=True)
                test_duration_envelope_button = gr.Button(variant="primary", value="Apply envelope", scale=1)
                test_duration_stretch_button = gr.Button(variant="primary", value="Apply stretch", scale=1)
                test_duration_inpaint_button = gr.Button(variant="primary", value="Inpaint different duration", scale=1)
                duration_slider = gradioWebUI.get_duration_slider()
            with gr.Tab("Pitch shift config"):
                pitch_shift_radio = gr.Radio(choices=["librosa", "torchaudio", "rubberband"],
                                             value="librosa")

        # Hidden spectrogram/phase previews — declared but not wired to any
        # callback in this module.
        with gr.Row(variant="panel", visible=False):
            with gr.Column(scale=2):
                with gr.Row(variant="panel"):
                    source_sound_spectrogram_image = gr.Image(label="New sound spectrogram", type="numpy",
                                                              height=600, scale=1)
                    source_sound_phase_image = gr.Image(label="New sound phase", type="numpy",
                                                        height=600, scale=1)

    # Wire the synthesis button; input order must match make_track's
    # positional parameters.
    make_track_button.click(make_track,
                            inputs=[inpaint_steps_slider, current_midi_state, midi_files_state,
                                    max_notes_slider, noising_strength_slider,
                                    attack_slider,
                                    before_release_slider,
                                    current_instruments_state,
                                    virtual_instruments_state],
                            outputs=[track_audio])

    # Parse a newly uploaded MIDI file as soon as it changes.
    midi_file.change(read_midi,
                     inputs=[midi_file,
                             midi_files_state],
                     outputs=[midi_info_textbox,
                              current_midi_state,
                              midi_files_state])