asigalov61 commited on
Commit
498b808
1 Parent(s): d7c4dcb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -49
app.py CHANGED
@@ -7,7 +7,7 @@ import torch.nn.functional as F
7
 
8
  import gradio as gr
9
 
10
- import onnxruntime as rt
11
  import tqdm
12
 
13
  from midi_synthesizer import synthesis
@@ -32,58 +32,24 @@ def GenerateMIDI(idrums, iinstr, progress=gr.Progress()):
32
 
33
  start_tokens = [3087, drums, 3075+first_note_instrument_number]
34
 
35
- seq_len = 512
36
- max_seq_len = 2048
37
- temperature = 0.9
38
- verbose=False
39
- return_prime=False
40
-
41
-
42
- out = torch.FloatTensor([start_tokens])
43
-
44
- st = len(start_tokens)
45
-
46
- if verbose:
47
- print("Generating sequence of max length:", seq_len)
48
-
49
- progress(0, desc="Starting...")
50
- step = 0
51
-
52
- for i in progress.tqdm(range(seq_len)):
53
-
54
- try:
55
- x = out[:, -max_seq_len:]
56
 
57
- torch_in = x.tolist()[0]
58
 
59
- logits = torch.FloatTensor(session.run(None, {'input': [torch_in]})[0])[:, -1]
60
-
61
- probs = F.softmax(logits / temperature, dim=-1)
62
 
63
- sample = torch.multinomial(probs, 1)
 
 
 
 
64
 
65
- out = torch.cat((out, sample), dim=-1)
66
-
67
- if step % 16 == 0:
68
- print(step, '/', seq_len)
69
-
70
- step += 1
71
-
72
- if step >= seq_len:
73
- break
74
-
75
- except Exception as e:
76
- print('Error', e)
77
- break
78
-
79
- if return_prime:
80
- melody_chords_f = out[:, :]
81
-
82
- else:
83
- melody_chords_f = out[:, st:]
84
-
85
- melody_chords_f = melody_chords_f.tolist()[0]
86
 
 
 
87
  print('=' * 70)
88
  print('Sample INTs', melody_chords_f[:12])
89
  print('=' * 70)
@@ -196,7 +162,31 @@ if __name__ == "__main__":
196
  opt = parser.parse_args()
197
 
198
  print('Loading model...')
199
- session = rt.InferenceSession('Allegro_Music_Transformer_Small_Trained_Model_56000_steps_0.9399_loss_0.7374_acc.onnx', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
  print('Done!')
201
 
202
  app = gr.Blocks()
 
7
 
8
  import gradio as gr
9
 
10
+ from x_transformer import *
11
  import tqdm
12
 
13
  from midi_synthesizer import synthesis
 
32
 
33
  start_tokens = [3087, drums, 3075+first_note_instrument_number]
34
 
35
+ print('Selected Improv sequence:')
36
+ print(start_tokens)
37
+ print('=' * 70)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
+ inp = [start_tokens] * number_of_batches_to_generate
40
 
41
+ inp = torch.LongTensor(inp).cpu()
 
 
42
 
43
+ out = model.module.generate(inp,
44
+ number_of_tokens_tp_generate,
45
+ temperature=temperature,
46
+ return_prime=False,
47
+ verbose=True)
48
 
49
+ melody_chords_f = out[0].tolist()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
+ print('=' * 70)
52
+ print('Done!')
53
  print('=' * 70)
54
  print('Sample INTs', melody_chords_f[:12])
55
  print('=' * 70)
 
162
  opt = parser.parse_args()
163
 
164
  print('Loading model...')
165
+
166
+ SEQ_LEN = 2048
167
+
168
+ # instantiate the model
169
+
170
+ model = TransformerWrapper(
171
+ num_tokens = 3088,
172
+ max_seq_len = SEQ_LEN,
173
+ attn_layers = Decoder(dim = 1024, depth = 32, heads = 8)
174
+ )
175
+
176
+ model = AutoregressiveWrapper(model)
177
+
178
+ model = torch.nn.DataParallel(model)
179
+
180
+ model.cpu()
181
+ print('=' * 70)
182
+
183
+ print('Loading model checkpoint...')
184
+
185
+ model.load_state_dict(torch.load(full_path_to_model_checkpoint, map_location='cpu'))
186
+ print('=' * 70)
187
+
188
+ model.eval()
189
+
190
  print('Done!')
191
 
192
  app = gr.Blocks()