File size: 4,945 Bytes
2ec0615 8f7d113 2ec0615 b9cb210 8e0633a 2ec0615 b9cb210 2ec0615 8f7d113 2ec0615 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
import gradio as gr
from load import LoadModel
from generate import GenerateMidiText
from constants import INSTRUMENT_CLASSES
from decoder import TextDecoder
from utils import get_miditok, index_has_substring
from playback import get_music
from matplotlib import pylab
import sys
import matplotlib
from generation_utils import plot_piano_roll
import numpy as np
matplotlib.use("Agg")
import matplotlib.pyplot as plt
sys.modules["pylab"] = pylab
# Hugging Face checkpoint used for generation. ``revision`` looks like a commit
# hash pinning the exact checkpoint version (presumably for reproducibility).
model_repo = "JammyMachina/elec-gmusic-familized-model-13-12__17-35-53"
revision = "ddf00f90d6d27e4cc0cb99c04a22a8f0a16c933e"
n_bar_generated = 8
# Alternative 4-bar checkpoint (swap both lines together):
# model_repo = "JammyMachina/improved_4bars-mdl"
# n_bar_generated = 4

model, tokenizer = LoadModel(
    model_repo,
    revision=revision,
    from_huggingface=True,
).load_model_and_tokenizer()

# Decoder used by the callbacks to write generated text out as .mid files.
miditok = get_miditok()
decoder = TextDecoder(miditok)
def define_prompt(state, genesis):
    """Return the text prompt for generating the next track.

    Args:
        state: List of previously generated track texts (the ``gr.State`` value).
        genesis: Generator object; only consulted when ``state`` is non-empty.

    Returns:
        ``"PIECE_START "`` when nothing has been generated yet, otherwise the
        whole piece so far as returned by
        ``genesis.get_whole_piece_from_bar_dict()``.
    """
    if state:
        return genesis.get_whole_piece_from_bar_dict()
    return "PIECE_START "
def generator(
    regenerate, temp, density, instrument, state, add_bars=False, add_bar_count=1
):
    """Generate (or regenerate) one instrument track and render the results.

    Click callback for each instrument row's "Generate" button. Relies on the
    module-level ``genesis`` generator and ``decoder`` defined in this file.

    Args:
        regenerate: If True and ``instrument`` already has a track in ``state``,
            that track is removed before a new one is generated. Ignored when
            the instrument has no track yet.
        temp: Temperature forwarded to ``genesis.generate_one_new_track``.
        density: Density forwarded to ``genesis.generate_one_new_track``.
        instrument: Instrument name chosen in the row's dropdown.
        state: List of track texts held in the ``gr.State``. It is mutated in
            place and also returned. It is kept index-aligned with the tracks in
            ``genesis``: the same index is used to delete from both.
        add_bars: If True, extend the piece by ``add_bar_count`` bars instead of
            generating a new track (not exposed in the UI yet).
        add_bar_count: Number of bars to add when ``add_bars`` is True.

    Returns:
        ``(inst_text, (44100, inst_audio), piano_roll, state,
        (44100, mixed_audio))``, matching the click handler's ``outputs``. The
        audio entries use the ``(sample_rate, data)`` form accepted by
        ``gr.Audio``.
    """
    # Map the dropdown name to its family number; names missing from
    # INSTRUMENT_CLASSES (e.g. "Drums", which the dropdown appends separately)
    # fall back to the drums family.
    inst = next(
        (cls for cls in INSTRUMENT_CLASSES if cls["name"] == instrument),
        {"family_number": "DRUMS"},
    )["family_number"]
    inst_tag = "INST=" + str(inst)

    # Check presence explicitly instead of trusting index_has_substring's
    # "not found" return value, which must never be used as a real index.
    has_track = any(inst_tag in track for track in state)
    inst_index = index_has_substring(state, inst_tag) if has_track else -1

    # Regenerate: only when there really is a previous take to replace.
    # Otherwise another instrument's track would be deleted, or pop() would
    # fail on an empty state.
    if regenerate and has_track:
        state.pop(inst_index)
        genesis.delete_one_track(inst_index)
        inst_index = -1  # reset to last generated

    if add_bars:
        # NEW BARS
        genesis.generate_n_more_bars(add_bar_count)  # for all instruments
        generated_text = genesis.get_whole_piece_from_bar_dict()
    else:
        # NEW TRACK
        input_prompt = define_prompt(state, genesis)
        generated_text = genesis.generate_one_new_track(
            inst, density, temp, input_prompt=input_prompt
        )
        # The track just generated is the last one. Select it even if this
        # instrument already had an earlier track; otherwise the old take would
        # be shown and appended to state a second time.
        inst_index = -1

    decoder.get_midi(generated_text, "mixed.mid")
    mixed_inst_midi, mixed_audio = get_music("mixed.mid")

    inst_text = genesis.get_selected_track_as_text(inst_index)
    inst_midi_name = f"{instrument}.mid"
    decoder.get_midi(inst_text, inst_midi_name)
    _, inst_audio = get_music(inst_midi_name)
    piano_roll = plot_piano_roll(mixed_inst_midi)

    # NOTE(review): in add_bars mode this appends another copy of an existing
    # track; keeping state in sync for that mode is still unresolved.
    state.append(inst_text)
    return inst_text, (44100, inst_audio), piano_roll, state, (44100, mixed_audio)
def instrument_row(default_inst):
    """Lay out one instrument's controls and wire its Generate button.

    Must run inside the ``gr.Blocks`` context. The click handler reads and
    writes the module-level ``state``, ``piano_roll`` and ``mixed_audio``
    components created there.

    Args:
        default_inst: Instrument name preselected in the dropdown.
    """
    # "Drums" is offered in addition to the names in INSTRUMENT_CLASSES.
    instrument_choices = [cls["name"] for cls in INSTRUMENT_CLASSES]
    instrument_choices.append("Drums")

    with gr.Row():
        # Left column: generation settings.
        with gr.Column(scale=1, min_width=50):
            inst_dropdown = gr.Dropdown(
                instrument_choices, value=default_inst, label="Instrument"
            )
            creativity = gr.Number(value=0.7, label="Creativity")
            density_choice = gr.Dropdown([0, 1, 2, 3], value=3, label="Density")

        # Middle column: the generated track as text.
        with gr.Column(scale=3):
            track_text = gr.Textbox(label="output", lines=10, max_lines=10)

        # Right column: per-track audio plus the action controls.
        # (Add-bars controls are not exposed in the UI yet.)
        with gr.Column(scale=1, min_width=100):
            track_audio = gr.Audio(label="Audio")
            regen_box = gr.Checkbox(value=False, label="Regenerate")
            generate_button = gr.Button("Generate")
            generate_button.click(
                fn=generator,
                inputs=[regen_box, creativity, density_choice, inst_dropdown, state],
                outputs=[track_text, track_audio, piano_roll, state, mixed_audio],
            )
# Build the UI. ``genesis``, ``state``, ``mixed_audio`` and ``piano_roll`` are
# deliberately module-level names: generator() and instrument_row() look them
# up as globals when they run.
with gr.Blocks() as demo:
    # One generator shared by every instrument row; it holds the piece being
    # built and is mutated by each Generate click.
    genesis = GenerateMidiText(model, tokenizer)
    genesis.set_nb_bars_generated(n_bars=n_bar_generated)

    # Components shared by all rows: the track list, the mixed audio and the
    # piano roll of the mixed piece.
    state = gr.State([])
    mixed_audio = gr.Audio(label="Mixed Audio")
    piano_roll = gr.Plot(label="Piano Roll")

    instrument_row("Drums")
    instrument_row("Bass")
    instrument_row("Synth Lead")
    # instrument_row("Piano")

demo.launch(debug=True)
"""
TODO: DEPLOY
TODO: temp file situation
TODO: clear cache situation
TODO: reset button
TODO: instrument mapping business
TODO: set the y-axis limits of the piano roll
TODO: add a button to save the generated midi
TODO: add improvise button
TODO: making the piano roll fit on the horizontal scale
TODO: set values for temperature as it is done for density
TODO: set the color situation to be dark background
TODO: make regeneration the default when a track has already been generated for the selected instrument
TODO: Add Bars should now apply to the whole piece; Regenerate should then regenerate only the added bars, across all instruments
TODO: row height to fix
TODO: maybe reset the checkboxes (Regenerate, Add Bars) after each use
TODO: block regenerate if add bar on
"""
|