asigalov61 commited on
Commit
2a68ddd
1 Parent(s): 1543d57

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -22
app.py CHANGED
@@ -2,15 +2,23 @@ import argparse
2
  import glob
3
  import json
4
  import os.path
 
5
  import torch
 
 
6
  import gradio as gr
7
- import matplotlib.pyplot as plt
8
  from x_transformer import *
9
- import TMIDIX
 
10
  from midi_synthesizer import synthesis
 
 
 
11
 
12
  in_space = os.getenv("SYSTEM") == "spaces"
13
 
 
14
  # =================================================================================================
15
 
16
  @torch.no_grad()
@@ -128,6 +136,7 @@ def cancel_run(mid_seq):
128
  yield "Allegro-Music-Transformer-Music-Composition.mid", (44100, audio), [
129
  create_msg("visualizer_end", None)]
130
 
 
131
  # =================================================================================================
132
 
133
  def load_javascript(dir="javascript"):
@@ -136,8 +145,7 @@ def load_javascript(dir="javascript"):
136
  for path in scripts_list:
137
  with open(path, "r", encoding="utf8") as jsfile:
138
  javascript += f"\n<!-- {path} --><script>{jsfile.read()}</script>"
139
-
140
- template_response_ori = gr.templates.TemplateResponse
141
 
142
  def template_response(*args, **kwargs):
143
  res = template_response_ori(*args, **kwargs)
@@ -149,22 +157,24 @@ def load_javascript(dir="javascript"):
149
  gr.routes.templates.TemplateResponse = template_response
150
 
151
 
152
- class JSMsgReceiver(gr.Interface):
153
 
154
  def __init__(self, **kwargs):
155
- super().__init__(elem_id="msg_receiver", **kwargs)
156
 
157
  def postprocess(self, y):
158
  if y:
159
  y = f"<p>{json.dumps(y)}</p>"
160
  return super().postprocess(y)
161
 
162
- def get_interface_name(self) -> str:
163
- return "interface"
 
164
 
165
  def create_msg(name, data):
166
  return {"name": name, "data": data}
167
 
 
168
  if __name__ == "__main__":
169
  parser = argparse.ArgumentParser()
170
  parser.add_argument("--share", action="store_true", default=False, help="share gradio app")
@@ -202,10 +212,10 @@ if __name__ == "__main__":
202
  print('Done!')
203
 
204
  load_javascript()
205
- app = gr.Interface()
206
  with app:
207
- gr.markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Allegro Music Transformer</h1>")
208
- gr.markdown(
209
  "![Visitors](https://api.visitorbadge.io/api/visitors?path=asigalov61.Allegro-Music-Transformer&style=flat)\n\n"
210
  "Full-attention multi-instrumental music transformer featuring asymmetrical encoding with octo-velocity, and chords counters tokens, optimized for speed and performance\n\n"
211
  "Check out [Allegro Music Transformer](https://github.com/asigalov61/Allegro-Music-Transformer) on GitHub!\n\n"
@@ -215,18 +225,20 @@ if __name__ == "__main__":
215
  " for faster execution and endless generation"
216
  )
217
  js_msg = JSMsgReceiver()
218
- input_drums = gr.Checkbox(label="Drums Controls", default=False, type=bool, description="Drums present or not")
219
  input_instrument = gr.Radio(
220
  ["Piano", "Guitar", "Bass", "Violin", "Cello", "Harp", "Trumpet", "Sax", "Flute", "Choir", "Organ"],
221
- default="Piano", label="Lead Instrument Controls", description="Desired lead instrument")
222
- input_num_tokens = gr.Slider(minimum=16, maximum=512, default=256, label="Number of Tokens", description="Number of tokens to generate")
223
- run_btn = gr.Button("generate", type="primary")
224
  interrupt_btn = gr.Button("interrupt")
225
 
226
- output_midi_seq = gr.Output()
227
- output_midi_visualizer = gr.HTML()
228
- output_audio = gr.Audio(label="output audio", format="mp3")
229
- output_midi = gr.File(label="output midi", type="mid")
230
- run_event = run_btn.click(GenerateMIDI, inputs=[input_num_tokens, input_drums, input_instrument], outputs=[output_midi_seq, output_midi, output_audio, js_msg])
231
- interrupt_btn.click(cancel_run, inputs=[output_midi_seq], outputs=[output_midi, output_audio, js_msg], cancel=run_event, queue=False)
232
- app.run(port=opt.port, share=opt.share, inbrowser=True)
 
 
 
2
  import glob
3
  import json
4
  import os.path
5
+
6
  import torch
7
+ import torch.nn.functional as F
8
+
9
  import gradio as gr
10
+
11
  from x_transformer import *
12
+ import tqdm
13
+
14
  from midi_synthesizer import synthesis
15
+ import TMIDIX
16
+
17
+ import matplotlib.pyplot as plt
18
 
19
  in_space = os.getenv("SYSTEM") == "spaces"
20
 
21
+
22
  # =================================================================================================
23
 
24
  @torch.no_grad()
 
136
  yield "Allegro-Music-Transformer-Music-Composition.mid", (44100, audio), [
137
  create_msg("visualizer_end", None)]
138
 
139
+
140
  # =================================================================================================
141
 
142
  def load_javascript(dir="javascript"):
 
145
  for path in scripts_list:
146
  with open(path, "r", encoding="utf8") as jsfile:
147
  javascript += f"\n<!-- {path} --><script>{jsfile.read()}</script>"
148
+ template_response_ori = gr.routes.templates.TemplateResponse
 
149
 
150
  def template_response(*args, **kwargs):
151
  res = template_response_ori(*args, **kwargs)
 
157
  gr.routes.templates.TemplateResponse = template_response
158
 
159
 
160
+ class JSMsgReceiver(gr.HTML):
161
 
162
  def __init__(self, **kwargs):
163
+ super().__init__(elem_id="msg_receiver", visible=False, **kwargs)
164
 
165
  def postprocess(self, y):
166
  if y:
167
  y = f"<p>{json.dumps(y)}</p>"
168
  return super().postprocess(y)
169
 
170
+ def get_block_name(self) -> str:
171
+ return "html"
172
+
173
 
174
  def create_msg(name, data):
175
  return {"name": name, "data": data}
176
 
177
+
178
  if __name__ == "__main__":
179
  parser = argparse.ArgumentParser()
180
  parser.add_argument("--share", action="store_true", default=False, help="share gradio app")
 
212
  print('Done!')
213
 
214
  load_javascript()
215
+ app = gr.Blocks()
216
  with app:
217
+ gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Allegro Music Transformer</h1>")
218
+ gr.Markdown(
219
  "![Visitors](https://api.visitorbadge.io/api/visitors?path=asigalov61.Allegro-Music-Transformer&style=flat)\n\n"
220
  "Full-attention multi-instrumental music transformer featuring asymmetrical encoding with octo-velocity, and chords counters tokens, optimized for speed and performance\n\n"
221
  "Check out [Allegro Music Transformer](https://github.com/asigalov61/Allegro-Music-Transformer) on GitHub!\n\n"
 
225
  " for faster execution and endless generation"
226
  )
227
  js_msg = JSMsgReceiver()
228
+ input_drums = gr.Checkbox(label="Drums Controls", value=False, info="Drums present or not")
229
  input_instrument = gr.Radio(
230
  ["Piano", "Guitar", "Bass", "Violin", "Cello", "Harp", "Trumpet", "Sax", "Flute", "Choir", "Organ"],
231
+ value="Piano", label="Lead Instrument Controls", info="Desired lead instrument")
232
+ input_num_tokens = gr.Slider(16, 512, value=256, label="Number of Tokens", info="Number of tokens to generate")
233
+ run_btn = gr.Button("generate", variant="primary")
234
  interrupt_btn = gr.Button("interrupt")
235
 
236
+ output_midi_seq = gr.Variable()
237
+ output_midi_visualizer = gr.HTML(elem_id="midi_visualizer_container")
238
+ output_audio = gr.Audio(label="output audio", format="mp3", elem_id="midi_audio")
239
+ output_midi = gr.File(label="output midi", file_types=[".mid"])
240
+ run_event = run_btn.click(GenerateMIDI, [input_num_tokens, input_drums, input_instrument],
241
+ [output_midi_seq, output_midi, output_audio, js_msg])
242
+ interrupt_btn.click(cancel_run, output_midi_seq, [output_midi, output_audio, js_msg],
243
+ cancels=run_event, queue=False)
244
+ app.queue(concurrency_count=1).launch(server_port=opt.port, share=opt.share, inbrowser=True)