File size: 5,178 Bytes
ff3c666
 
 
 
 
 
 
efb26ff
cb4d93d
efb26ff
ff3c666
 
38bae21
22f9998
efb26ff
b1a906d
 
 
ff3c666
 
cb4d93d
ff3c666
 
 
 
cb4d93d
ff3c666
 
 
 
cb4d93d
 
ff3c666
 
 
 
cb4d93d
 
ff3c666
 
cb4d93d
ff3c666
cb4d93d
 
 
 
 
 
 
 
ff3c666
 
 
 
 
 
 
cb4d93d
 
 
8478564
 
ff3c666
 
cb4d93d
ff3c666
 
 
cb4d93d
 
 
b1a906d
ca4414c
b1a906d
 
 
ca4414c
b1a906d
 
 
 
ca4414c
b1a906d
ca4414c
 
cb4d93d
 
 
73f624e
 
 
cb4d93d
 
 
 
 
73f624e
 
 
cb4d93d
 
 
 
 
73f624e
 
 
cb4d93d
 
 
 
 
73f624e
 
 
cb4d93d
 
 
 
 
73f624e
cb4d93d
 
73f624e
ff3c666
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import os
import random
from string import ascii_letters
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
import mido
import uvicorn
import gradio as gr
from inference.inference import generate_groove

# FastAPI application hosting the Gradio UI. MIDI files are served straight
# from disk: ./uploads (written by load_midi_file) and ./generated
# (presumably where generate_groove writes its output — TODO confirm).
app = FastAPI()
for _static_dir in ("uploads", "generated"):
    app.mount(
        f"/{_static_dir}",
        StaticFiles(directory=f"./{_static_dir}"),
        name=_static_dir,
    )
del _static_dir

def display_bpm(bpm):
    """Return the label text shown in the read-only "BPM value" textbox.

    Args:
        bpm: Current value of the BPM number input.

    Returns:
        The string ``"BPM: <bpm>"``.
    """
    return "BPM: {}".format(bpm)

# HTML snippet pairing a piano-roll <midi-visualizer> with a <midi-player>
# (web components from html-midi-player, whose script is loaded via `head`).
# str.format placeholders:
#   {filepath} - path/URL of the MIDI file to draw and play.
#   {id}       - suffix that makes the visualizer's DOM id unique, so each
#                player's `visualizer="#..."` attribute targets its own roll
#                when several snippets are on the same page.
# Playback uses the hosted Magenta "jazz_kit" soundfont.
visualizer_html_template = """
<div>
    <midi-visualizer type="piano-roll" id="myPianoRollVisualizer{id}"
        src="{filepath}">
    </midi-visualizer>
    <midi-player
        src="{filepath}"
        sound-font="https://storage.googleapis.com/magentadata/js/soundfonts/jazz_kit" visualizer="#myPianoRollVisualizer{id}">
    </midi-player>
</div>
"""




def load_midi_file(input_midi, bpm_input):
    """Save an uploaded MIDI file under ``uploads/`` and build its preview.

    Args:
        input_midi: Raw bytes of the uploaded file (the ``gr.File`` input is
            declared with ``type="binary"``), or a falsy value when nothing
            has been uploaded.
        bpm_input: Value of the BPM field. Currently unused; kept so the
            Gradio event signature stays unchanged.

    Returns:
        A ``(html, filepath)`` pair: the visualizer/player HTML pointing at
        the saved file and the relative path it was written to, or
        ``(None, None)`` when no file was provided.
    """
    if not input_midi:
        return None, None

    # Random 16-letter name so successive uploads are unlikely to collide.
    # Built with "".join instead of a nested-double-quote f-string, which is
    # a SyntaxError on Python versions before 3.12 (PEP 701).
    midi_filename = "".join(random.choices(ascii_letters, k=16)) + ".mid"
    # The same relative path is used to write the file and as the player's
    # src; the ./uploads directory is mounted as static files at /uploads.
    filepath = os.path.join("uploads/", midi_filename)
    with open(filepath, "wb") as fd:
        fd.write(input_midi)
    return visualizer_html_template.format(filepath=filepath, id="input"), filepath

def run_inference(midi_filename: str, count: int=1):
    """Generate grooves from a loaded MIDI file and build their previews.

    Args:
        midi_filename: Path of the input MIDI file, as stored in the
            ``midi_filepath`` state by ``load_midi_file``. NOTE(review): this
            is ``None`` if the user never clicked "load midi file"; how
            ``generate_groove`` handles that is not visible here.
        count: Number of variations requested from ``generate_groove``.

    Returns:
        A flat list: one visualizer HTML snippet per generated file, followed
        by the generated file paths in the same order. The "Generate" click
        handler maps this onto four HTML outputs and then four download
        buttons, so it presumably expects exactly ``count`` files back —
        TODO confirm against generate_groove.
    """
    # Materialize once so tuples/generators work: the result is both
    # enumerated and concatenated onto a list below.
    filenames = list(generate_groove(midi_filename, count))
    # The index gives each snippet a unique DOM id so every player drives
    # its own piano roll.
    visualizers = [
        visualizer_html_template.format(filepath=filename, id=index)
        for index, filename in enumerate(filenames)
    ]
    return visualizers + filenames


# Extra <head> markup: loads Tone.js, Magenta's core build and
# html-midi-player, which provide the <midi-visualizer>/<midi-player>
# elements used by visualizer_html_template.
head = """
<script src="https://cdn.jsdelivr.net/combine/npm/tone@14.7.58,npm/@magenta/music@1.23.1/es6/core.js,npm/focus-visible@5,npm/html-midi-player@1.5.0"></script>
"""
block = gr.Blocks(head=head)
with block:
    # Path of the most recently loaded input MIDI file (set by load_midi_file).
    midi_filepath = gr.State()
    ### MIDI UPLOAD AND PREVIEW ###
    with gr.Group():
        # type="binary" makes the upload reach load_midi_file as raw bytes.
        input_midi = gr.File(label="Upload basic drum part", file_types=[".midi", ".mid"], type="binary")
        bpm_input = gr.Number(value=120, label="BPM", interactive=True)
        load_btn = gr.Button("load midi file")
        midi_player = gr.HTML()
        run_event = load_btn.click(load_midi_file, [input_midi, bpm_input], [midi_player, midi_filepath])

        # NOTE(review): `html` is not referenced anywhere in this file; this
        # read only makes startup fail if js/midi-player.html is missing.
        with open("js/midi-player.html") as fd:
            html = fd.read()

    ### GENERATION SETTINGS ###
    # TODO(review): genre, complexity, length and BPM are not passed to
    # run_inference, so they currently have no effect on generation.
    with gr.Group():
        gr.Markdown("## Generation Settings")
        with gr.Row():
            generate_genre = gr.Dropdown(["Rock", "Pop", "Reggae", "Jazz", "Metal"], label="Genre", interactive=True)
            # `step` must be a number; the previous `step=float` passed the
            # type object itself. Whole steps match the "between 1 and 10"
            # hint; use e.g. 0.1 if fractional complexity is wanted.
            generate_complexity = gr.Slider(1, 10, value=5, label="Complexity", info="Choose between 1 and 10", step=1, interactive=True)

        with gr.Row():
            bpm_display = gr.Textbox(label="BPM value", interactive=False)
            bpm_input.change(fn=display_bpm, inputs=bpm_input, outputs=bpm_display)
            generate_length = gr.Dropdown(["1 Bar", "2 Bars", "3 Bars", "4 Bars", "5 Bars"], label="Length", interactive=True)

        with gr.Row():
            generate_button = gr.Button("Generate")


    ### OUTPUT ###
    # Number of variations to request; must match the four output tabs below.
    number_outputs = gr.State(4)
    with gr.Group():
        gr.Markdown("## Output")
        with gr.Tab("1"):
            with gr.Row():
                midi_player_output_1 = gr.HTML()
            # with gr.Row():
            #     bpm_output = gr.Number(value=120, label="BPM", interactive=False)
            #     output_instrument = gr.Dropdown(["Drums", "Snare", "Kick"], label="Sound", interactive=True)
            download_button_1 = gr.DownloadButton("Download")

        with gr.Tab("2"):
            with gr.Row():
                midi_player_output_2 = gr.HTML()
            # with gr.Row():
            #     bpm_output = gr.Number(value=123, label="BPM", interactive=False)
            #     output_instrument = gr.Dropdown(["Drums", "Snare", "Kick"], label="Sound", interactive=True)
            download_button_2 = gr.DownloadButton("Download")

        with gr.Tab("3"):
            with gr.Row():
                midi_player_output_3 = gr.HTML()
            # with gr.Row():
            #     bpm_output = gr.Number(value=168, label="BPM", interactive=False)
            #     # output_instrument = gr.Dropdown(["Drums", "Snare", "Kick"], label="Sound", interactive=True)
            download_button_3 = gr.DownloadButton("Download")

        with gr.Tab("4"):
            with gr.Row():
                midi_player_output_4 = gr.HTML()
            # with gr.Row():
            #     bpm_output = gr.Number(value=222, label="BPM", interactive=False)
            #     output_instrument = gr.Dropdown(["Drums", "Snare", "Kick"], label="Sound", interactive=True)
            download_button_4 = gr.DownloadButton("Download")

        # run_inference returns four HTML snippets followed by four file
        # paths, matching this output order.
        run_event = generate_button.click(run_inference, [midi_filepath, number_outputs],
            [midi_player_output_1, midi_player_output_2, midi_player_output_3, midi_player_output_4, download_button_1, download_button_2, download_button_3, download_button_4])


# Attach the Gradio UI at the root path of the FastAPI app defined above.
app = gr.mount_gradio_app(app, blocks=block, path="/")

if __name__ == "__main__":
    # Serve on every network interface, port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")