patchbanks commited on
Commit
3f08fef
·
verified ·
1 Parent(s): b0cd41b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -1
app.py CHANGED
@@ -27,7 +27,7 @@ max_new_tokens = 768
27
 
28
  seed = random.randint(1, 100000)
29
  torch.manual_seed(seed)
30
- device = 'cpu' if torch.cuda.is_available() else 'cpu'
31
  dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16'
32
  compile = False
33
  exec(open('configurator.py').read())
@@ -81,6 +81,7 @@ def clear_midi(dir):
81
  clear_midi(temp_dir)
82
 
83
 
 
84
  def generate_midi(temperature, top_k):
85
  start_ids = encode(start)
86
  x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])
 
27
 
28
  seed = random.randint(1, 100000)
29
  torch.manual_seed(seed)
30
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
31
  dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16'
32
  compile = False
33
  exec(open('configurator.py').read())
 
81
  clear_midi(temp_dir)
82
 
83
 
84
+ @spaces.GPU(duration=15)
85
  def generate_midi(temperature, top_k):
86
  start_ids = encode(start)
87
  x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])