asigalov61 committed on
Commit
037c7a6
1 Parent(s): 5173099

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -39
app.py CHANGED
@@ -1,4 +1,3 @@
1
- import argparse
2
  import glob
3
  import json
4
  import os.path
@@ -11,6 +10,7 @@ import torch
11
  import torch.nn.functional as F
12
 
13
  import gradio as gr
 
14
 
15
  from x_transformer import *
16
  import tqdm
@@ -24,7 +24,7 @@ in_space = os.getenv("SYSTEM") == "spaces"
24
 
25
  # =================================================================================================
26
 
27
- @torch.no_grad()
28
  def GenerateMIDI(num_tok, idrums, iinstr):
29
  print('=' * 70)
30
  print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
@@ -83,6 +83,38 @@ def GenerateMIDI(num_tok, idrums, iinstr):
83
 
84
  yield output, None, None, [create_msg("visualizer_clear", None)]
85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  outy = start_tokens
87
 
88
  ctime = 0
@@ -201,42 +233,6 @@ if __name__ == "__main__":
201
  print('App start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
202
  print('=' * 70)
203
 
204
- parser = argparse.ArgumentParser()
205
- parser.add_argument("--share", action="store_true", default=False, help="share gradio app")
206
- parser.add_argument("--port", type=int, default=7860, help="gradio server port")
207
- opt = parser.parse_args()
208
-
209
- print('Loading model...')
210
-
211
- SEQ_LEN = 2048
212
-
213
- # instantiate the model
214
-
215
- model = TransformerWrapper(
216
- num_tokens=3088,
217
- max_seq_len=SEQ_LEN,
218
- attn_layers=Decoder(dim=1024, depth=16, heads=8)
219
- )
220
-
221
- model = AutoregressiveWrapper(model)
222
-
223
- model = torch.nn.DataParallel(model)
224
-
225
- model.cpu()
226
- print('=' * 70)
227
-
228
- print('Loading model checkpoint...')
229
-
230
- model.load_state_dict(
231
- torch.load('Allegro_Music_Transformer_Tiny_Trained_Model_80000_steps_0.9457_loss_0.7443_acc.pth',
232
- map_location='cpu'))
233
- print('=' * 70)
234
-
235
- model.eval()
236
-
237
- print('Done!')
238
- print('=' * 70)
239
-
240
  load_javascript()
241
  app = gr.Blocks()
242
  with app:
@@ -267,4 +263,4 @@ if __name__ == "__main__":
267
  [output_midi_seq, output_midi, output_audio, js_msg])
268
  interrupt_btn.click(cancel_run, output_midi_seq, [output_midi, output_audio, js_msg],
269
  cancels=run_event, queue=False)
270
- app.queue(concurrency_count=1).launch(server_port=opt.port, share=opt.share, inbrowser=True)
 
 
1
  import glob
2
  import json
3
  import os.path
 
10
  import torch.nn.functional as F
11
 
12
  import gradio as gr
13
+ import spaces
14
 
15
  from x_transformer import *
16
  import tqdm
 
24
 
25
  # =================================================================================================
26
 
27
+ @spaces.GPU
28
  def GenerateMIDI(num_tok, idrums, iinstr):
29
  print('=' * 70)
30
  print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
 
83
 
84
  yield output, None, None, [create_msg("visualizer_clear", None)]
85
 
86
+
87
+ print('Loading model...')
88
+
89
+ SEQ_LEN = 2048
90
+
91
+ # instantiate the model
92
+
93
+ model = TransformerWrapper(
94
+ num_tokens=3088,
95
+ max_seq_len=SEQ_LEN,
96
+ attn_layers=Decoder(dim=1024, depth=16, heads=8)
97
+ )
98
+
99
+ model = AutoregressiveWrapper(model)
100
+
101
+ model = torch.nn.DataParallel(model)
102
+
103
+ model.cpu()
104
+ print('=' * 70)
105
+
106
+ print('Loading model checkpoint...')
107
+
108
+ model.load_state_dict(
109
+ torch.load('Allegro_Music_Transformer_Tiny_Trained_Model_80000_steps_0.9457_loss_0.7443_acc.pth',
110
+ map_location='cpu'))
111
+ print('=' * 70)
112
+
113
+ model.eval()
114
+
115
+ print('Done!')
116
+ print('=' * 70)
117
+
118
  outy = start_tokens
119
 
120
  ctime = 0
 
233
  print('App start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
234
  print('=' * 70)
235
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
236
  load_javascript()
237
  app = gr.Blocks()
238
  with app:
 
263
  [output_midi_seq, output_midi, output_audio, js_msg])
264
  interrupt_btn.click(cancel_run, output_midi_seq, [output_midi, output_audio, js_msg],
265
  cancels=run_event, queue=False)
266
+ app.queue().launch()