asigalov61 commited on
Commit
f24d883
1 Parent(s): 09f50c7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -7
app.py CHANGED
@@ -6,10 +6,9 @@ import torch
6
  import torch.nn.functional as F
7
 
8
  import gradio as gr
9
- import numpy as np
10
  import onnxruntime as rt
11
  import tqdm
12
- import json
13
 
14
  from midi_synthesizer import synthesis
15
  import TMIDIX
@@ -18,9 +17,6 @@ in_space = os.getenv("SYSTEM") == "spaces"
18
 
19
  #=================================================================================================
20
 
21
- def create_msg(name, data):
22
- return {"name": name, "data": data}
23
-
24
  def GenerateMIDI():
25
 
26
  start_tokens = [3087, 3073+1, 3075+1]
@@ -45,12 +41,14 @@ def GenerateMIDI():
45
  try:
46
 
47
  x = out[:, -max_seq_len:]
48
-
49
  torch_in = x.tolist()[0]
50
 
51
  logits = torch.FloatTensor(session.run(None, {'input': [torch_in]})[0])[:, -1]
52
 
53
- probs = F.softmax(logits / temperature, dim=-1)
 
 
54
 
55
  sample = torch.multinomial(probs, 1)
56
 
 
6
  import torch.nn.functional as F
7
 
8
  import gradio as gr
9
+
10
  import onnxruntime as rt
11
  import tqdm
 
12
 
13
  from midi_synthesizer import synthesis
14
  import TMIDIX
 
17
 
18
  #=================================================================================================
19
 
 
 
 
20
  def GenerateMIDI():
21
 
22
  start_tokens = [3087, 3073+1, 3075+1]
 
41
  try:
42
 
43
  x = out[:, -max_seq_len:]
44
+
45
  torch_in = x.tolist()[0]
46
 
47
  logits = torch.FloatTensor(session.run(None, {'input': [torch_in]})[0])[:, -1]
48
 
49
+ filtered_logits = logits
50
+
51
+ probs = F.softmax(filtered_logits / temperature, dim=-1)
52
 
53
  sample = torch.multinomial(probs, 1)
54