asigalov61 commited on
Commit
8453f63
1 Parent(s): de46ee3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -2
app.py CHANGED
@@ -2,6 +2,9 @@ import argparse
2
  import glob
3
  import os.path
4
 
 
 
 
5
  import gradio as gr
6
  import numpy as np
7
  import onnxruntime as rt
@@ -13,7 +16,54 @@ import TMIDIX
13
 
14
  in_space = os.getenv("SYSTEM") == "spaces"
15
 
16
- providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  def load_javascript(dir="javascript"):
19
  scripts_list = glob.glob(f"{dir}/*.js")
@@ -55,7 +105,7 @@ if __name__ == "__main__":
55
  opt = parser.parse_args()
56
 
57
 
58
- providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
59
 
60
  session = rt.InferenceSession('Allegro_Music_Transformer_Small_Trained_Model_56000_steps_0.9399_loss_0.7374_acc.onnx', providers=providers)
61
 
 
2
  import glob
3
  import os.path
4
 
5
+ import torch
6
+ import torch.nn.functional as F
7
+
8
  import gradio as gr
9
  import numpy as np
10
  import onnxruntime as rt
 
16
 
17
  in_space = os.getenv("SYSTEM") == "spaces"
18
 
19
+ providers = ['CPUExecutionProvider']
20
+
21
+ #=================================================================================================
22
+
23
+ def generate(
24
+ start_tokens,
25
+ seq_len,
26
+ max_seq_len = 2048,
27
+ temperature = 0.9,
28
+ verbose=True,
29
+ return_prime=False,
30
+ ):
31
+
32
+ out = torch.LongTensor([start_tokens])
33
+
34
+ st = len(start_tokens)
35
+
36
+ if verbose:
37
+ print("Generating sequence of max length:", seq_len)
38
+
39
+ for s in range(seq_len):
40
+ x = out[:, -max_seq_len:]
41
+
42
+ torch_in = x.tolist()[0]
43
+
44
+ logits = torch.FloatTensor(session.run(None, {'input': [torch_in]})[0])[:, -1]
45
+
46
+ filtered_logits = logits
47
+
48
+ probs = F.softmax(filtered_logits / temperature, dim=-1)
49
+
50
+ sample = torch.multinomial(probs, 1)
51
+
52
+ out = torch.cat((out, sample), dim=-1)
53
+
54
+ if verbose:
55
+ if s % 32 == 0:
56
+ print(s, '/', seq_len)
57
+
58
+ if return_prime:
59
+ return out[:, :]
60
+
61
+ else:
62
+ return out[:, st:]
63
+
64
+
65
+ #=================================================================================================
66
+
67
 
68
  def load_javascript(dir="javascript"):
69
  scripts_list = glob.glob(f"{dir}/*.js")
 
105
  opt = parser.parse_args()
106
 
107
 
108
+ providers = ['CPUExecutionProvider']
109
 
110
  session = rt.InferenceSession('Allegro_Music_Transformer_Small_Trained_Model_56000_steps_0.9399_loss_0.7374_acc.onnx', providers=providers)
111