File size: 4,945 Bytes
2ec0615
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8f7d113
2ec0615
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b9cb210
8e0633a
2ec0615
 
b9cb210
2ec0615
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8f7d113
 
 
 
 
2ec0615
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import gradio as gr
from load import LoadModel
from generate import GenerateMidiText
from constants import INSTRUMENT_CLASSES
from decoder import TextDecoder
from utils import get_miditok, index_has_substring
from playback import get_music
from matplotlib import pylab
import sys
import matplotlib
from generation_utils import plot_piano_roll
import numpy as np

# Select the non-interactive Agg backend (the app only renders figures for
# the web UI). This call is placed before the pyplot import below on purpose.
matplotlib.use("Agg")
import matplotlib.pyplot as plt

# Register matplotlib's pylab module under the top-level name "pylab".
# NOTE(review): presumably a shim for some dependency that does
# `import pylab` — confirm it is still needed.
sys.modules["pylab"] = pylab

# Hugging Face model repository, pinned to a fixed revision (commit hash) so
# the app always loads the same weights.
model_repo = "JammyMachina/elec-gmusic-familized-model-13-12__17-35-53"
revision = "ddf00f90d6d27e4cc0cb99c04a22a8f0a16c933e"
# Bars generated per track; handed to genesis.set_nb_bars_generated(...)
# when the UI is built.
n_bar_generated = 8
# Alternative 4-bar model, kept for reference:
# model_repo = "JammyMachina/improved_4bars-mdl"
# n_bar_generated = 4

# Load the model and tokenizer once, at import time.
model, tokenizer = LoadModel(
    model_repo, from_huggingface=True, revision=revision
).load_model_and_tokenizer()


# MIDI tokenizer and the text-to-MIDI decoder; generator() calls
# decoder.get_midi(text, filename) to write .mid files.
miditok = get_miditok()
decoder = TextDecoder(miditok)


def define_prompt(state, genesis):
    """Return the text prompt used to seed the next track generation.

    When no track has been generated yet (``state`` is empty) the prompt is
    just the ``"PIECE_START "`` token. Otherwise it is the whole piece so far,
    as returned by ``genesis.get_whole_piece_from_bar_dict()``.
    """
    if len(state):
        # Continue from everything generated so far.
        return genesis.get_whole_piece_from_bar_dict()
    # Fresh piece: start from the bare start token (trailing space intended).
    return "PIECE_START "


def generator(
    regenerate, temp, density, instrument, state, add_bars=False, add_bar_count=1
):
    """Generate (or regenerate) one instrument track and render the results.

    Click handler for each instrument row's "Generate" button. It uses the
    module-level ``genesis`` generator and ``decoder``.

    Args:
        regenerate: If True and ``state`` already holds a track for
            ``instrument``, that track is removed from both ``state`` and
            ``genesis`` before a new one is generated.
        temp: Sampling temperature, forwarded to
            ``genesis.generate_one_new_track``.
        density: Density level, forwarded to
            ``genesis.generate_one_new_track``.
        instrument: Instrument display name, looked up in
            ``INSTRUMENT_CLASSES``; unknown names (e.g. "Drums") fall back to
            the ``"DRUMS"`` family.
        state: Session list of generated track texts; mutated in place.
        add_bars: If True, extend the whole piece by ``add_bar_count`` bars
            instead of generating a new track.
        add_bar_count: Number of bars to add when ``add_bars`` is True.

    Returns:
        ``(inst_text, (sample_rate, inst_audio), piano_roll, state,
        (sample_rate, mixed_audio))``.
    """
    inst = next(
        (entry for entry in INSTRUMENT_CLASSES if entry["name"] == instrument),
        {"family_number": "DRUMS"},
    )["family_number"]
    inst_marker = "INST=" + str(inst)

    inst_index = index_has_substring(state, inst_marker)

    # Regenerate: drop this instrument's existing track before generating a
    # replacement. Only do so when such a track actually exists; otherwise
    # the "not found" index would delete some other instrument's track (or
    # fail on an empty state).
    if regenerate and any(inst_marker in track for track in state):
        state.pop(inst_index)
        genesis.delete_one_track(inst_index)
        inst_index = -1  # reset to last generated

    if not add_bars:
        # NEW TRACK. The freshly generated track is the last one (the same
        # convention as the regenerate reset above), so always select it with
        # -1. The index found earlier may point at an older track of this
        # instrument when Regenerate is unchecked.
        input_prompt = define_prompt(state, genesis)
        generated_text = genesis.generate_one_new_track(
            inst, density, temp, input_prompt=input_prompt
        )
        inst_index = -1
    else:
        # NEW BARS, added for all instruments.
        # NOTE(review): this mode appends a copy of an existing track to
        # `state` below; it is not wired into the UI yet.
        genesis.generate_n_more_bars(add_bar_count)
        generated_text = genesis.get_whole_piece_from_bar_dict()

    # Render the full mix: text -> MIDI file -> audio (+ MIDI object for the plot).
    decoder.get_midi(generated_text, "mixed.mid")
    mixed_inst_midi, mixed_audio = get_music("mixed.mid")

    # Render the selected instrument track on its own.
    inst_text = genesis.get_selected_track_as_text(inst_index)
    inst_midi_name = f"{instrument}.mid"
    decoder.get_midi(inst_text, inst_midi_name)
    _, inst_audio = get_music(inst_midi_name)

    piano_roll = plot_piano_roll(mixed_inst_midi)
    state.append(inst_text)

    sample_rate = 44100  # presumably get_music's render rate — TODO confirm
    return inst_text, (sample_rate, inst_audio), piano_roll, state, (sample_rate, mixed_audio)


def instrument_row(default_inst):
    """Build one instrument row of controls and wire its Generate button.

    Layout, left to right:
      * an instrument dropdown (pre-selected to ``default_inst``), a
        "Creativity" number (temperature) and a "Density" dropdown;
      * a textbox showing the generated track text;
      * the track's audio player, a "Regenerate" checkbox and the
        "Generate" button.

    Clicking "Generate" calls ``generator``. The result updates this row's
    text and audio, plus the shared ``piano_roll``, ``state`` and
    ``mixed_audio`` components. Those module-level components must already
    exist when this function is called inside the ``gr.Blocks`` context.
    """
    instrument_choices = [entry["name"] for entry in INSTRUMENT_CLASSES]
    instrument_choices.append("Drums")

    with gr.Row():
        with gr.Column(scale=1, min_width=50):
            instrument_dd = gr.Dropdown(
                instrument_choices,
                value=default_inst,
                label="Instrument",
            )
            creativity = gr.Number(value=0.7, label="Creativity")
            density_dd = gr.Dropdown([0, 1, 2, 3], value=3, label="Density")

        with gr.Column(scale=3):
            track_text = gr.Textbox(label="output", lines=10, max_lines=10)

        with gr.Column(scale=1, min_width=100):
            track_audio = gr.Audio(label="Audio")
            regenerate_box = gr.Checkbox(value=False, label="Regenerate")
            # "Add bars" controls (a checkbox plus a bar-count dropdown) are
            # disabled until that generation mode is exposed in the UI.
            generate_button = gr.Button("Generate")
            # Input order must match generator(regenerate, temp, density,
            # instrument, state).
            handler_inputs = [
                regenerate_box,
                creativity,
                density_dd,
                instrument_dd,
                state,
            ]
            handler_outputs = [track_text, track_audio, piano_roll, state, mixed_audio]
            generate_button.click(
                fn=generator,
                inputs=handler_inputs,
                outputs=handler_outputs,
            )


# Build the web UI. The order in which components are created inside this
# context defines the page layout: mixed audio, piano roll, then one row per
# instrument.
# NOTE(review): cache_examples is an Interface/Examples option; gr.Blocks
# does not appear to use it — confirm against the installed Gradio version.
with gr.Blocks(cache_examples=False) as demo:
    # Single MIDI-text generator, created once at module level and read as a
    # global by generator().
    # NOTE(review): `state` below is per-session in Gradio while `genesis` is
    # shared, so concurrent users may drive them out of sync.
    genesis = GenerateMidiText(
        model,
        tokenizer,
    )
    genesis.set_nb_bars_generated(n_bars=n_bar_generated)
    # Session list of generated track texts (starts empty).
    state = gr.State([])
    # Shared outputs updated by every instrument row's Generate button.
    mixed_audio = gr.Audio(label="Mixed Audio")
    piano_roll = gr.Plot(label="Piano Roll")
    instrument_row("Drums")
    instrument_row("Bass")
    instrument_row("Synth Lead")
    # instrument_row("Piano")

# Start the server as soon as the module runs (there is no __main__ guard).
demo.launch(debug=True)

"""
TODO: DEPLOY
TODO: temp file situation
TODO: clear cache situation
TODO: reset button
TODO: instrument mapping business
TODO: Y lim axis of piano roll
TODO: add a button to save the generated midi
TODO: add improvise button
TODO: making the piano roll fit on the horizontal scale
TODO: set values for temperature as it is done for density
TODO: set the color situation to be dark background
TODO: make regeneration the default when a track has already been generated for the instrument
TODO: "Add bars" should now apply to the whole piece - regenerate should then regenerate only the added bars, across all instruments
TODO: row height to fix

TODO: maybe reset the checkboxes (regenerate, add bars) after they are used
TODO: disable regenerate while add bars is on
"""