# the-jam-machine-app / playground.py
# Last commit: 8e0633a "path fix vol2 wtfff" by m41w4r3.exe
# (Scraped page metadata: raw | history | blame · No virus · 4.92 kB)
import gradio as gr
from load import LoadModel
from generate import GenerateMidiText
from constants import INSTRUMENT_CLASSES
from decoder import TextDecoder
from utils import get_miditok, index_has_substring
from playback import get_music
from matplotlib import pylab
import sys
import matplotlib
from generation_utils import plot_piano_roll
import numpy as np
matplotlib.use("Agg")
import matplotlib.pyplot as plt
# Expose matplotlib's pylab under the top-level module name "pylab".
# NOTE(review): presumably so that an `import pylab` elsewhere (e.g. inside a
# dependency) resolves to this module -- confirm.
sys.modules["pylab"] = pylab

# Hugging Face checkpoint to load, pinned to a specific revision.
model_repo = "JammyMachina/elec-gmusic-familized-model-13-12__17-35-53"
revision = "ddf00f90d6d27e4cc0cb99c04a22a8f0a16c933e"
# Bars per generated track; handed to genesis.set_nb_bars_generated below.
n_bar_generated = 8
# Alternative checkpoint -- swap both values together:
# model_repo = "JammyMachina/improved_4bars-mdl"
# n_bar_generated = 4

# Loaded once at import time; the Gradio callbacks below share these objects.
model, tokenizer = LoadModel(
    model_repo,
    revision=revision,
    from_huggingface=True,
).load_model_and_tokenizer()

# Text generator shared by every instrument row.
genesis = GenerateMidiText(model, tokenizer)
genesis.set_nb_bars_generated(n_bars=n_bar_generated)

# Tokenizer and decoder used to turn generated text back into MIDI files.
miditok = get_miditok()
decoder = TextDecoder(miditok)
def define_prompt(state, genesis):
    """Return the text prompt used to generate the next track.

    Args:
        state: list of already generated track texts.
        genesis: generator object; only consulted when ``state`` is non-empty.

    Returns:
        ``"PIECE_START "`` when nothing has been generated yet, otherwise the
        whole piece so far as returned by
        ``genesis.get_whole_piece_from_bar_dict()``.
    """
    # Continue the existing piece if any track exists; otherwise start fresh.
    return genesis.get_whole_piece_from_bar_dict() if len(state) != 0 else "PIECE_START "
def generator(
    regenerate, temp, density, instrument, state, add_bars=False, add_bar_count=1
):
    """Generate (or regenerate) one instrument track and render the results.

    Gradio callback bound to each instrument row's "Generate" button.

    Args:
        regenerate: if True and ``state`` already holds a track for
            ``instrument``, that track is deleted before a new one is made.
            If no such track exists, a new track is simply generated.
        temp: sampling temperature ("Creativity"), forwarded to the generator.
        density: density level, forwarded to the generator.
        instrument: name chosen in the row's dropdown. Names not found in
            INSTRUMENT_CLASSES (e.g. "Drums") map to family "DRUMS".
        state: per-session list of generated track texts, kept index-aligned
            with the tracks held by ``genesis``. Mutated in place.
        add_bars: if True, extend every track by ``add_bar_count`` bars
            instead of generating a new track (not exposed in the UI yet).
        add_bar_count: number of bars to add when ``add_bars`` is True.

    Returns:
        ``(track_text, (sample_rate, track_audio), piano_roll, state,
        (sample_rate, mixed_audio))``, matching the row's ``outputs`` list.

    Side effects:
        Writes ``mixed.mid`` and ``<instrument>.mid`` via ``decoder.get_midi``
        (presumably into the working directory).
    """
    # NOTE(review): presumably the rate get_music renders at -- confirm.
    sample_rate = 44100

    inst = next(
        (entry for entry in INSTRUMENT_CLASSES if entry["name"] == instrument),
        {"family_number": "DRUMS"},
    )["family_number"]
    inst_index = index_has_substring(state, "INST=" + str(inst))
    # NOTE(review): assumes index_has_substring returns -1 (or None) when no
    # entry matches -- confirm in utils. Feeding -1 to pop/delete_one_track
    # would silently drop the latest track of a *different* instrument.
    track_exists = inst_index is not None and inst_index >= 0

    if regenerate and track_exists:
        # Drop the old track from both the UI state and the generator so the
        # two stay index-aligned.
        state.pop(inst_index)
        genesis.delete_one_track(inst_index)
        inst_index = -1  # the replacement will be the last generated track

    if add_bars:
        # NEW BARS: extends every instrument's track at once.
        genesis.generate_n_more_bars(add_bar_count)
        generated_text = genesis.get_whole_piece_from_bar_dict()
    else:
        # NEW TRACK: assumed to be appended after the existing tracks (the
        # original "-1 = last generated" convention), so select the last one.
        # Without this, an already-present instrument would keep pointing at
        # its old track and the row would show stale text/audio.
        input_prompt = define_prompt(state, genesis)
        generated_text = genesis.generate_one_new_track(
            inst, density, temp, input_prompt=input_prompt
        )
        inst_index = -1

    # Render the whole piece, then the selected track on its own.
    decoder.get_midi(generated_text, "mixed.mid")
    mixed_inst_midi, mixed_audio = get_music("mixed.mid")

    inst_text = genesis.get_selected_track_as_text(inst_index)
    inst_midi_name = f"{instrument}.mid"
    decoder.get_midi(inst_text, inst_midi_name)
    _, inst_audio = get_music(inst_midi_name)

    piano_roll = plot_piano_roll(mixed_inst_midi)
    # NOTE(review): in add_bars mode this appends an extra entry rather than
    # updating existing ones, so state drifts out of alignment with genesis.
    state.append(inst_text)
    return (
        inst_text,
        (sample_rate, inst_audio),
        piano_roll,
        state,
        (sample_rate, mixed_audio),
    )
def instrument_row(default_inst):
    """Lay out one instrument's control row and wire its Generate button.

    Called inside the ``gr.Blocks`` context at the bottom of this file. The
    click handler also reads and writes the module-level ``state``,
    ``piano_roll`` and ``mixed_audio`` components created there.

    Args:
        default_inst: instrument name preselected in the row's dropdown.
    """
    instrument_choices = [entry["name"] for entry in INSTRUMENT_CLASSES]
    instrument_choices.append("Drums")

    with gr.Row():
        # Left column: which instrument to generate, and how.
        with gr.Column(scale=1, min_width=50):
            instrument_picker = gr.Dropdown(
                instrument_choices,
                value=default_inst,
                label="Instrument",
            )
            creativity = gr.Number(value=0.7, label="Creativity")
            density_picker = gr.Dropdown([0, 1, 2, 3], value=3, label="Density")

        # Middle column: text of the generated track.
        with gr.Column(scale=3):
            track_text = gr.Textbox(label="output", lines=10, max_lines=10)

        # Right column: this track's audio plus the action controls.
        with gr.Column(scale=1, min_width=100):
            track_audio = gr.Audio(label="Audio")
            regenerate_box = gr.Checkbox(value=False, label="Regenerate")
            # "Add Bars" checkbox and bar-count dropdown are not wired up yet,
            # so generator() always runs with its add_bars defaults.
            generate_button = gr.Button("Generate")
            generate_button.click(
                fn=generator,
                inputs=[
                    regenerate_box,
                    creativity,
                    density_picker,
                    instrument_picker,
                    state,
                ],
                outputs=[track_text, track_audio, piano_roll, state, mixed_audio],
            )
with gr.Blocks(cache_examples=False) as demo:
    # Components shared by every instrument row: the per-session list of
    # generated tracks, the full-mix audio player and its piano roll.
    state = gr.State([])
    mixed_audio = gr.Audio(label="Mixed Audio")
    piano_roll = gr.Plot(label="Piano Roll")

    # One control row per instrument ("Piano" is currently left out).
    for default_instrument in ("Drums", "Bass", "Synth Lead"):
        instrument_row(default_instrument)

demo.launch(debug=True)
"""
TODO: DEPLOY
TODO: temp file situation
TODO: clear cache situation
TODO: reset button
TODO: instrument mapping business
TODO: set the Y-axis limits of the piano roll
TODO: add a button to save the generated MIDI
TODO: add an improvise button
TODO: make the piano roll fit the horizontal scale
TODO: offer preset values for temperature, as is done for density
TODO: use a dark background color scheme
TODO: make regeneration the default when a track has already been generated for an instrument
TODO: Add Bars should now apply to the whole piece; Regenerate should regenerate only the added bars, across all instruments
TODO: fix the row height
TODO: maybe reset the checkboxes (Regenerate, Add Bars) after they are used
TODO: block Regenerate when Add Bars is on
"""