# NOTE: web file-viewer residue preceded the code here (file size "5,178 Bytes",
# a per-line commit-hash column and the 1-130 line-number gutter); it is not
# part of the module.
import os
import random
from string import ascii_letters
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
import mido
import uvicorn
import gradio as gr
from inference.inference import generate_groove
# Create the FastAPI app that serves the static MIDI files and hosts the UI.
app = FastAPI()

# StaticFiles refuses to start when its directory does not exist, so make sure
# both folders are present before mounting them. "uploads" is where
# load_midi_file stores uploaded files; "generated" is presumably where the
# generated grooves are written (verify against generate_groove).
os.makedirs("./uploads", exist_ok=True)
os.makedirs("./generated", exist_ok=True)
app.mount("/uploads", StaticFiles(directory="./uploads"), name="uploads")
app.mount("/generated", StaticFiles(directory="./generated"), name="generated")
def display_bpm(bpm: float) -> str:
    """Return the label text shown in the read-only BPM box.

    Args:
        bpm: Current value of the BPM input; it is inserted using its default
            string formatting.

    Returns:
        The text ``"BPM: <bpm>"``.
    """
    return f"BPM: {bpm}"
# HTML snippet pairing a <midi-visualizer> piano roll with a <midi-player>.
# It is filled in with str.format():
#   {filepath} - MIDI source used as `src` by both elements
#   {id}       - suffix appended to the visualizer's element id, so that the
#                player's `visualizer` selector points at the matching roll
visualizer_html_template = """
<div>
<midi-visualizer type="piano-roll" id="myPianoRollVisualizer{id}"
src="{filepath}">
</midi-visualizer>
<midi-player
src="{filepath}"
sound-font="https://storage.googleapis.com/magentadata/js/soundfonts/jazz_kit" visualizer="#myPianoRollVisualizer{id}">
</midi-player>
</div>
"""
def load_midi_file(input_midi, bpm_input):
    """Store an uploaded MIDI file on disk and build its preview player.

    Args:
        input_midi: Raw bytes of the uploaded file (the upload widget is
            created with ``type="binary"``), or a falsy value when nothing
            has been uploaded.
        bpm_input: Current value of the BPM field. Accepted because the click
            handler passes it, but not used here.

    Returns:
        A ``(html, filepath)`` tuple: the visualizer/player markup pointing at
        the stored file, and the relative path it was stored under. Returns
        ``(None, None)`` when there is no upload.
    """
    if not input_midi:
        return None, None

    # Random 16-letter file name for the stored copy. It is built outside the
    # f-string because reusing double quotes inside an f-string expression is a
    # syntax error before Python 3.12.
    random_stem = "".join(random.choices(ascii_letters, k=16))
    midi_filename = f"{random_stem}.mid"

    # Make the function usable even if the upload folder was removed.
    os.makedirs("uploads/", exist_ok=True)
    filepath = os.path.join("uploads/", midi_filename)
    with open(filepath, "wb") as fd:
        fd.write(input_midi)

    return visualizer_html_template.format(filepath=filepath, id="input"), filepath
def run_inference(midi_filename: str, count: int = 1):
    """Generate groove variations for a stored MIDI file and build their players.

    Args:
        midi_filename: Path of the input MIDI file; passed straight through to
            ``generate_groove``.
        count: Number of variations requested from ``generate_groove``.

    Returns:
        One flat list: the visualizer HTML for every generated file, followed
        by the generated file paths themselves, in the same order.

    Raises:
        gr.Error: If no MIDI file has been loaded yet.
    """
    if not midi_filename:
        # The stored path stays None until a file has been loaded. Report that
        # in the UI instead of handing None to generate_groove.
        # NOTE(review): assumes generate_groove cannot work without an input file.
        raise gr.Error("Please load a MIDI file before generating.")

    filenames = list(generate_groove(midi_filename, count))
    visualizers = [
        visualizer_html_template.format(filepath=filename, id=index)
        for index, filename in enumerate(filenames)
    ]
    return visualizers + filenames
# Extra markup passed to gr.Blocks(head=...): one combined jsDelivr bundle that
# loads Tone.js, the Magenta music core library, focus-visible and
# html-midi-player.
head = """
<script src="https://cdn.jsdelivr.net/combine/npm/tone@14.7.58,npm/@magenta/music@1.23.1/es6/core.js,npm/focus-visible@5,npm/html-midi-player@1.5.0"></script>
"""
# Number of generated variations. One output tab (player + download button) is
# built per variation, and the same number is handed to run_inference as its
# `count`, so the handler's return list lines up with the output components.
NUM_OUTPUTS = 4

block = gr.Blocks(head=head)
with block:
    # NOTE(review): the container nesting below is inferred from the statement
    # order and section markers; confirm it matches the intended layout.

    # Path of the uploaded file as stored on disk; None until a file is loaded.
    midi_filepath = gr.State()

    ### MIDI UPLOAD AND PREVIEW ###
    with gr.Group():
        input_midi = gr.File(label="Upload basic drum part", file_types=[".midi", ".mid"], type="binary")
        bpm_input = gr.Number(value=120, label="BPM", interactive=True)
        load_btn = gr.Button("load midi file")
        midi_player = gr.HTML()
        # Saves the upload and shows its player; also records the stored path.
        load_event = load_btn.click(load_midi_file, [input_midi, bpm_input], [midi_player, midi_filepath])

    # NOTE(review): `html` is read here but not referenced anywhere else in the
    # visible part of this module.
    with open("js/midi-player.html") as fd:
        html = fd.read()

    ### GENERATION SETTINGS ###
    # NOTE(review): genre, complexity and length are displayed but are not among
    # the inputs passed to run_inference below.
    with gr.Group():
        gr.Markdown("## Generation Settings")
        with gr.Row():
            generate_genre = gr.Dropdown(["Rock", "Pop", "Reggae", "Jazz", "Metal"], label="Genre", interactive=True)
            # `step` takes a numeric increment; the `float` type object is not
            # a valid step. 1 keeps the 1-10 scale on whole numbers.
            # NOTE(review): use a fractional step (e.g. 0.1) if fractional
            # complexity values were intended.
            generate_complexity = gr.Slider(1, 10, value=5, label="Complexity", info="Choose between 1 and 10", step=1, interactive=True)
        with gr.Row():
            bpm_display = gr.Textbox(label="BPM value", interactive=False)
            # Mirror every change of the BPM input into the read-only text box.
            bpm_input.change(fn=display_bpm, inputs=bpm_input, outputs=bpm_display)
            generate_length = gr.Dropdown(["1 Bar", "2 Bars", "3 Bars", "4 Bars", "5 Bars"], label="Length", interactive=True)
        with gr.Row():
            generate_button = gr.Button("Generate")

    ### OUTPUT ###
    number_outputs = gr.State(NUM_OUTPUTS)
    midi_player_outputs = []
    download_buttons = []
    with gr.Group():
        gr.Markdown("## Output")
        for tab_number in range(1, NUM_OUTPUTS + 1):
            with gr.Tab(str(tab_number)):
                with gr.Row():
                    midi_player_outputs.append(gr.HTML())
                # TODO: per-output BPM readout and "Sound" selector
                # (Drums / Snare / Kick) are not implemented yet.
                download_buttons.append(gr.DownloadButton("Download"))

    # run_inference returns all player HTML snippets first, then all file paths,
    # so the outputs are listed in that same order.
    run_event = generate_button.click(
        run_inference,
        [midi_filepath, number_outputs],
        midi_player_outputs + download_buttons,
    )
# Mount the Gradio UI at the root path of the FastAPI app.
app = gr.mount_gradio_app(app, block, path="/")

# Serve the combined app when this module is run as a script.
if __name__ == "__main__":
    # Listen on all interfaces, port 7860.
    uvicorn.run(app, host="0.0.0.0", port=7860)