asigalov61 commited on
Commit
c2734a9
1 Parent(s): a71ce75

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -35
app.py CHANGED
@@ -2,23 +2,15 @@ import argparse
2
  import glob
3
  import json
4
  import os.path
5
-
6
  import torch
7
- import torch.nn.functional as F
8
-
9
  import gradio as gr
10
-
11
  from x_transformer import *
12
- import tqdm
13
-
14
- from midi_synthesizer import synthesis
15
  import TMIDIX
16
-
17
- import matplotlib.pyplot as plt
18
 
19
  in_space = os.getenv("SYSTEM") == "spaces"
20
 
21
-
22
  # =================================================================================================
23
 
24
  @torch.no_grad()
@@ -136,7 +128,6 @@ def cancel_run(mid_seq):
136
  yield "Allegro-Music-Transformer-Music-Composition.mid", (44100, audio), [
137
  create_msg("visualizer_end", None)]
138
 
139
-
140
  # =================================================================================================
141
 
142
  def load_javascript(dir="javascript"):
@@ -145,7 +136,8 @@ def load_javascript(dir="javascript"):
145
  for path in scripts_list:
146
  with open(path, "r", encoding="utf8") as jsfile:
147
  javascript += f"\n<!-- {path} --><script>{jsfile.read()}</script>"
148
- template_response_ori = gr.routes.templates.TemplateResponse
 
149
 
150
  def template_response(*args, **kwargs):
151
  res = template_response_ori(*args, **kwargs)
@@ -154,27 +146,25 @@ def load_javascript(dir="javascript"):
154
  res.init_headers()
155
  return res
156
 
157
- gr.routes.templates.TemplateResponse = template_response
158
 
159
 
160
- class JSMsgReceiver(gr.HTML):
161
 
162
  def __init__(self, **kwargs):
163
- super().__init__(elem_id="msg_receiver", visible=False, **kwargs)
164
 
165
  def postprocess(self, y):
166
  if y:
167
  y = f"<p>{json.dumps(y)}</p>"
168
  return super().postprocess(y)
169
 
170
- def get_block_name(self) -> str:
171
- return "html"
172
-
173
 
174
  def create_msg(name, data):
175
  return {"name": name, "data": data}
176
 
177
-
178
  if __name__ == "__main__":
179
  parser = argparse.ArgumentParser()
180
  parser.add_argument("--share", action="store_true", default=False, help="share gradio app")
@@ -212,10 +202,10 @@ if __name__ == "__main__":
212
  print('Done!')
213
 
214
  load_javascript()
215
- app = gr.Blocks()
216
  with app:
217
- gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Allegro Music Transformer</h1>")
218
- gr.Markdown(
219
  "![Visitors](https://api.visitorbadge.io/api/visitors?path=asigalov61.Allegro-Music-Transformer&style=flat)\n\n"
220
  "Full-attention multi-instrumental music transformer featuring asymmetrical encoding with octo-velocity, and chords counters tokens, optimized for speed and performance\n\n"
221
  "Check out [Allegro Music Transformer](https://github.com/asigalov61/Allegro-Music-Transformer) on GitHub!\n\n"
@@ -225,20 +215,18 @@ if __name__ == "__main__":
225
  " for faster execution and endless generation"
226
  )
227
  js_msg = JSMsgReceiver()
228
- input_drums = gr.Checkbox(label="Drums Controls", value=False, info="Drums present or not")
229
  input_instrument = gr.Radio(
230
  ["Piano", "Guitar", "Bass", "Violin", "Cello", "Harp", "Trumpet", "Sax", "Flute", "Choir", "Organ"],
231
- value="Piano", label="Lead Instrument Controls", info="Desired lead instrument")
232
- input_num_tokens = gr.Slider(16, 512, value=256, label="Number of Tokens", info="Number of tokens to generate")
233
- run_btn = gr.Button("generate", variant="primary")
234
  interrupt_btn = gr.Button("interrupt")
235
 
236
- output_midi_seq = gr.Variable()
237
- output_midi_visualizer = gr.HTML(elem_id="midi_visualizer_container")
238
- output_audio = gr.Audio(label="output audio", format="mp3", elem_id="midi_audio")
239
- output_midi = gr.File(label="output midi", file_types=[".mid"])
240
- run_event = run_btn.click(GenerateMIDI, [input_num_tokens, input_drums, input_instrument],
241
- [output_midi_seq, output_midi, output_audio, js_msg])
242
- interrupt_btn.click(cancel_run, output_midi_seq, [output_midi, output_audio, js_msg],
243
- cancels=run_event, queue=False)
244
- app.queue(concurrency_count=1).launch(server_port=opt.port, share=opt.share, inbrowser=True)
 
2
  import glob
3
  import json
4
  import os.path
 
5
  import torch
 
 
6
  import gradio as gr
7
+ import matplotlib.pyplot as plt
8
  from x_transformer import *
 
 
 
9
  import TMIDIX
10
+ from midi_synthesizer import synthesis
 
11
 
12
  in_space = os.getenv("SYSTEM") == "spaces"
13
 
 
14
  # =================================================================================================
15
 
16
  @torch.no_grad()
 
128
  yield "Allegro-Music-Transformer-Music-Composition.mid", (44100, audio), [
129
  create_msg("visualizer_end", None)]
130
 
 
131
  # =================================================================================================
132
 
133
  def load_javascript(dir="javascript"):
 
136
  for path in scripts_list:
137
  with open(path, "r", encoding="utf8") as jsfile:
138
  javascript += f"\n<!-- {path} --><script>{jsfile.read()}</script>"
139
+
140
+ template_response_ori = gr.templates.TemplateResponse
141
 
142
  def template_response(*args, **kwargs):
143
  res = template_response_ori(*args, **kwargs)
 
146
  res.init_headers()
147
  return res
148
 
149
+ gr.templates.TemplateResponse = template_response
150
 
151
 
152
+ class JSMsgReceiver(gr.Interface):
153
 
154
  def __init__(self, **kwargs):
155
+ super().__init__(elem_id="msg_receiver", **kwargs)
156
 
157
  def postprocess(self, y):
158
  if y:
159
  y = f"<p>{json.dumps(y)}</p>"
160
  return super().postprocess(y)
161
 
162
+ def get_interface_name(self) -> str:
163
+ return "interface"
 
164
 
165
  def create_msg(name, data):
166
  return {"name": name, "data": data}
167
 
 
168
  if __name__ == "__main__":
169
  parser = argparse.ArgumentParser()
170
  parser.add_argument("--share", action="store_true", default=False, help="share gradio app")
 
202
  print('Done!')
203
 
204
  load_javascript()
205
+ app = gr.Interface()
206
  with app:
207
+ gr.markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Allegro Music Transformer</h1>")
208
+ gr.markdown(
209
  "![Visitors](https://api.visitorbadge.io/api/visitors?path=asigalov61.Allegro-Music-Transformer&style=flat)\n\n"
210
  "Full-attention multi-instrumental music transformer featuring asymmetrical encoding with octo-velocity, and chords counters tokens, optimized for speed and performance\n\n"
211
  "Check out [Allegro Music Transformer](https://github.com/asigalov61/Allegro-Music-Transformer) on GitHub!\n\n"
 
215
  " for faster execution and endless generation"
216
  )
217
  js_msg = JSMsgReceiver()
218
+ input_drums = gr.Checkbox(label="Drums Controls", default=False, type=bool, description="Drums present or not")
219
  input_instrument = gr.Radio(
220
  ["Piano", "Guitar", "Bass", "Violin", "Cello", "Harp", "Trumpet", "Sax", "Flute", "Choir", "Organ"],
221
+ default="Piano", label="Lead Instrument Controls", description="Desired lead instrument")
222
+ input_num_tokens = gr.Slider(minimum=16, maximum=512, default=256, label="Number of Tokens", description="Number of tokens to generate")
223
+ run_btn = gr.Button("generate", type="primary")
224
  interrupt_btn = gr.Button("interrupt")
225
 
226
+ output_midi_seq = gr.Output()
227
+ output_midi_visualizer = gr.HTML()
228
+ output_audio = gr.Audio(label="output audio", format="mp3")
229
+ output_midi = gr.File(label="output midi", type="mid")
230
+ run_event = run_btn.click(GenerateMIDI, inputs=[input_num_tokens, input_drums, input_instrument], outputs=[output_midi_seq, output_midi, output_audio, js_msg])
231
+ interrupt_btn.click(cancel_run, inputs=[output_midi_seq], outputs=[output_midi, output_audio, js_msg], cancel=run_event, queue=False)
232
+ app.run(port=opt.port, share=opt.share, inbrowser=True)