import gradio as gr from load import LoadModel from generate import GenerateMidiText from constants import INSTRUMENT_CLASSES from decoder import TextDecoder from utils import get_miditok, index_has_substring from playback import get_music from matplotlib import pylab import sys import matplotlib from generation_utils import plot_piano_roll import numpy as np matplotlib.use("Agg") import matplotlib.pyplot as plt sys.modules["pylab"] = pylab model_repo = "JammyMachina/elec-gmusic-familized-model-13-12__17-35-53" revision = "ddf00f90d6d27e4cc0cb99c04a22a8f0a16c933e" n_bar_generated = 8 # model_repo = "JammyMachina/improved_4bars-mdl" # n_bar_generated = 4 model, tokenizer = LoadModel( model_repo, from_huggingface=True, revision=revision ).load_model_and_tokenizer() genesis = GenerateMidiText( model, tokenizer, ) genesis.set_nb_bars_generated(n_bars=n_bar_generated) miditok = get_miditok() decoder = TextDecoder(miditok) def define_prompt(state, genesis): if len(state) == 0: input_prompt = "PIECE_START " else: input_prompt = genesis.get_whole_piece_from_bar_dict() return input_prompt def generator( regenerate, temp, density, instrument, state, add_bars=False, add_bar_count=1 ): inst = next( (inst for inst in INSTRUMENT_CLASSES if inst["name"] == instrument), {"family_number": "DRUMS"}, )["family_number"] inst_index = index_has_substring(state, "INST=" + str(inst)) # Regenerate if regenerate: state.pop(inst_index) genesis.delete_one_track(inst_index) generated_text = ( genesis.get_whole_piece_from_bar_dict() ) # maybe not useful here inst_index = -1 # reset to last generated # Generate if not add_bars: # NEW TRACK input_prompt = define_prompt(state, genesis) generated_text = genesis.generate_one_new_track( inst, density, temp, input_prompt=input_prompt ) else: # NEW BARS genesis.generate_n_more_bars(add_bar_count) # for all instruments generated_text = genesis.get_whole_piece_from_bar_dict() decoder.get_midi(generated_text, "mixed.mid") mixed_inst_midi, mixed_audio = get_music("mixed.mid") inst_text = genesis.get_selected_track_as_text(inst_index) inst_midi_name = f"{instrument}.mid" decoder.get_midi(inst_text, inst_midi_name) _, inst_audio = get_music(inst_midi_name) piano_roll = plot_piano_roll(mixed_inst_midi) state.append(inst_text) return inst_text, (44100, inst_audio), piano_roll, state, (44100, mixed_audio) def instrument_row(default_inst): with gr.Row(): with gr.Column(scale=1, min_width=50): inst = gr.Dropdown( [inst["name"] for inst in INSTRUMENT_CLASSES] + ["Drums"], value=default_inst, label="Instrument", ) temp = gr.Number(value=0.7, label="Creativity") density = gr.Dropdown([0, 1, 2, 3], value=3, label="Density") with gr.Column(scale=3): output_txt = gr.Textbox(label="output", lines=10, max_lines=10) with gr.Column(scale=1, min_width=100): inst_audio = gr.Audio(label="Audio") regenerate = gr.Checkbox(value=False, label="Regenerate") # add_bars = gr.Checkbox(value=False, label="Add Bars") # add_bar_count = gr.Dropdown([1, 2, 4, 8], value=1, label="Add Bars") gen_btn = gr.Button("Generate") gen_btn.click( fn=generator, inputs=[ regenerate, temp, density, inst, state, ], outputs=[output_txt, inst_audio, piano_roll, state, mixed_audio], ) with gr.Blocks(cache_examples=False) as demo: state = gr.State([]) mixed_audio = gr.Audio(label="Mixed Audio") piano_roll = gr.Plot(label="Piano Roll") instrument_row("Drums") instrument_row("Bass") instrument_row("Synth Lead") # instrument_row("Piano") demo.launch(debug=True) """ TODO: DEPLOY TODO: temp file situation TODO: clear cache situation TODO: reset button TODO: instrument mapping business TODO: Y lim axis of piano roll TODO: add a button to save the generated midi TODO: add improvise button TODO: making the piano roll fit on the horizontal scale TODO: set values for temperature as it is done for density TODO: set the color situation to be dark background TODO: make regeration default when an intrument has already been track has already been generated TODO: Add bar should be now set for the whole piece - regenerrate should regenerate the added bars only on all instruments TODO: row height to fix TODO: reset state of tick boxes after used maybe (regenerate, add bars) ; TODO: block regenerate if add bar on """