asigalov61
commited on
Commit
·
8c8ea80
1
Parent(s):
56ab42f
Update app.py
Browse files
app.py
CHANGED
@@ -84,7 +84,7 @@ def create_msg(name, data):
|
|
84 |
return {"name": name, "data": data}
|
85 |
|
86 |
|
87 |
-
def run(
|
88 |
mid_seq = []
|
89 |
gen_events = int(gen_events)
|
90 |
max_len = gen_events
|
@@ -92,55 +92,32 @@ def run(model_name, tab, instruments, drum_kit, mid, midi_events, gen_events, te
|
|
92 |
disable_patch_change = False
|
93 |
disable_channels = None
|
94 |
if tab == 0:
|
95 |
-
|
96 |
-
|
97 |
-
patches = {}
|
98 |
-
for instr in instruments:
|
99 |
-
patches[i] = patch2number[instr]
|
100 |
-
i = (i + 1) if i != 8 else 10
|
101 |
-
if drum_kit != "None":
|
102 |
-
patches[9] = drum_kits2number[drum_kit]
|
103 |
-
for i, (c, p) in enumerate(patches.items()):
|
104 |
-
mid.append(tokenizer.event2tokens(["patch_change", 0, 0, i, c, p]))
|
105 |
-
mid_seq = mid
|
106 |
-
mid = np.asarray(mid, dtype=np.int64)
|
107 |
-
if len(instruments) > 0:
|
108 |
-
disable_patch_change = True
|
109 |
-
disable_channels = [i for i in range(16) if i not in patches]
|
110 |
elif mid is not None:
|
111 |
-
|
112 |
-
|
113 |
-
mid = mid[:int(midi_events)]
|
114 |
-
max_len += len(mid)
|
115 |
-
for token_seq in mid:
|
116 |
-
mid_seq.append(token_seq.tolist())
|
117 |
init_msgs = [create_msg("visualizer_clear", None)]
|
118 |
for tokens in mid_seq:
|
119 |
-
init_msgs.append(create_msg("visualizer_append",
|
120 |
yield mid_seq, None, None, init_msgs
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
for i, token_seq in enumerate(generator):
|
126 |
-
token_seq = token_seq.tolist()
|
127 |
-
mid_seq.append(token_seq)
|
128 |
-
event = tokenizer.tokens2event(token_seq)
|
129 |
-
yield mid_seq, None, None, [create_msg("visualizer_append", event), create_msg("progress", [i + 1, gen_events])]
|
130 |
-
mid = tokenizer.detokenize(mid_seq)
|
131 |
with open(f"output.mid", 'wb') as f:
|
132 |
-
f.write(MIDI.score2midi(
|
133 |
-
audio = synthesis(MIDI.score2opus(
|
134 |
yield mid_seq, "output.mid", (44100, audio), [create_msg("visualizer_end", None)]
|
135 |
|
136 |
|
137 |
def cancel_run(mid_seq):
|
138 |
if mid_seq is None:
|
139 |
return None, None
|
140 |
-
|
141 |
with open(f"output.mid", 'wb') as f:
|
142 |
-
f.write(MIDI.score2midi(
|
143 |
-
audio = synthesis(MIDI.score2opus(
|
144 |
return "output.mid", (44100, audio), [create_msg("visualizer_end", None)]
|
145 |
|
146 |
|
@@ -174,11 +151,6 @@ class JSMsgReceiver(gr.HTML):
|
|
174 |
def get_block_name(self) -> str:
|
175 |
return "html"
|
176 |
|
177 |
-
number2drum_kits = {-1: "None", 0: "Standard", 8: "Room", 16: "Power", 24: "Electric", 25: "TR-808", 32: "Jazz",
|
178 |
-
40: "Blush", 48: "Orchestra"}
|
179 |
-
patch2number = {v: k for k, v in MIDI.Number2patch.items()}
|
180 |
-
drum_kits2number = {v: k for k, v in number2drum_kits.items()}
|
181 |
-
|
182 |
if __name__ == "__main__":
|
183 |
parser = argparse.ArgumentParser()
|
184 |
parser.add_argument("--share", action="store_true", default=False, help="share gradio app")
|
@@ -233,9 +205,7 @@ if __name__ == "__main__":
|
|
233 |
output_midi_visualizer = gr.HTML(elem_id="midi_visualizer_container")
|
234 |
output_audio = gr.Audio(label="output audio", format="mp3", elem_id="midi_audio")
|
235 |
output_midi = gr.File(label="output midi", file_types=[".mid"])
|
236 |
-
run_event =
|
237 |
-
input_midi_events, input_gen_events, input_temp, input_top_p, input_top_k,
|
238 |
-
input_allow_cc],
|
239 |
[output_midi_seq, output_midi, output_audio, js_msg])
|
240 |
stop_btn.click(cancel_run, output_midi_seq, [output_midi, output_audio, js_msg], cancels=run_event, queue=False)
|
241 |
app.queue(1).launch(server_port=opt.port, share=opt.share, inbrowser=True)
|
|
|
84 |
return {"name": name, "data": data}
|
85 |
|
86 |
|
87 |
+
def run(search_prompt):
|
88 |
mid_seq = []
|
89 |
gen_events = int(gen_events)
|
90 |
max_len = gen_events
|
|
|
92 |
disable_patch_change = False
|
93 |
disable_channels = None
|
94 |
if tab == 0:
|
95 |
+
mid_seq = []
|
96 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
97 |
elif mid is not None:
|
98 |
+
mid_seq = MIDI.midi2score(mid)
|
99 |
+
|
|
|
|
|
|
|
|
|
100 |
init_msgs = [create_msg("visualizer_clear", None)]
|
101 |
for tokens in mid_seq:
|
102 |
+
init_msgs.append(create_msg("visualizer_append", tokens))
|
103 |
yield mid_seq, None, None, init_msgs
|
104 |
+
|
105 |
+
for i in range(len(mid_seq)):
|
106 |
+
yield mid_seq, None, None, [create_msg("visualizer_append", mid_seq[i]), create_msg("progress", [i + 1, mid_seq[i+1]])]
|
107 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
108 |
with open(f"output.mid", 'wb') as f:
|
109 |
+
f.write(MIDI.score2midi(mid_seq))
|
110 |
+
audio = synthesis(MIDI.score2opus(mid_seq), soundfont_path)
|
111 |
yield mid_seq, "output.mid", (44100, audio), [create_msg("visualizer_end", None)]
|
112 |
|
113 |
|
114 |
def cancel_run(mid_seq):
|
115 |
if mid_seq is None:
|
116 |
return None, None
|
117 |
+
|
118 |
with open(f"output.mid", 'wb') as f:
|
119 |
+
f.write(MIDI.score2midi(mid_seq))
|
120 |
+
audio = synthesis(MIDI.score2opus(mid_seq), soundfont_path)
|
121 |
return "output.mid", (44100, audio), [create_msg("visualizer_end", None)]
|
122 |
|
123 |
|
|
|
151 |
def get_block_name(self) -> str:
|
152 |
return "html"
|
153 |
|
|
|
|
|
|
|
|
|
|
|
154 |
if __name__ == "__main__":
|
155 |
parser = argparse.ArgumentParser()
|
156 |
parser.add_argument("--share", action="store_true", default=False, help="share gradio app")
|
|
|
205 |
output_midi_visualizer = gr.HTML(elem_id="midi_visualizer_container")
|
206 |
output_audio = gr.Audio(label="output audio", format="mp3", elem_id="midi_audio")
|
207 |
output_midi = gr.File(label="output midi", file_types=[".mid"])
|
208 |
+
run_event = search_btn.click(run, [search_prompt],
|
|
|
|
|
209 |
[output_midi_seq, output_midi, output_audio, js_msg])
|
210 |
stop_btn.click(cancel_run, output_midi_seq, [output_midi, output_audio, js_msg], cancels=run_event, queue=False)
|
211 |
app.queue(1).launch(server_port=opt.port, share=opt.share, inbrowser=True)
|