soujanyaporia commited on
Commit
3dd4ae8
1 Parent(s): 60a6a17

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -6
app.py CHANGED
@@ -22,8 +22,16 @@ from gradio import Markdown
22
 
23
  import spaces
24
 
 
 
 
 
 
 
 
 
25
  class MusicFeaturePredictor:
26
- def __init__(self, path, device="cuda:0", cache_dir=None, local_files_only=False):
27
  self.beats_tokenizer = AutoTokenizer.from_pretrained(
28
  "microsoft/deberta-v3-large",
29
  use_fast=False,
@@ -147,7 +155,7 @@ class Mustango:
147
  def __init__(
148
  self,
149
  name="declare-lab/mustango",
150
- device="cuda:0",
151
  cache_dir=None,
152
  local_files_only=False,
153
  ):
@@ -217,10 +225,11 @@ class Mustango:
217
 
218
 
219
  # Initialize Mustango
220
- if torch.cuda.is_available():
221
- mustango = Mustango()
222
- else:
223
- mustango = Mustango(device="cpu")
 
224
 
225
  # output_wave = mustango.generate("This techno song features a synth lead playing the main melody.", 5, 3, disable_progress=False)
226
  @spaces.GPU(duration=120)
 
22
 
23
  import spaces
24
 
25
+ # Automatic device detection
26
+ if torch.cuda.is_available():
27
+ device_type = "cuda"
28
+ device_selection = "cuda:0"
29
+ else:
30
+ device_type = "cpu"
31
+ device_selection = "cpu"
32
+
33
  class MusicFeaturePredictor:
34
+ def __init__(self, path, device=device_selection, cache_dir=None, local_files_only=False):
35
  self.beats_tokenizer = AutoTokenizer.from_pretrained(
36
  "microsoft/deberta-v3-large",
37
  use_fast=False,
 
155
  def __init__(
156
  self,
157
  name="declare-lab/mustango",
158
+ device=device_selection,
159
  cache_dir=None,
160
  local_files_only=False,
161
  ):
 
225
 
226
 
227
  # Initialize Mustango
228
+ mustango = Mustango(device="cpu")
229
+ # if torch.cuda.is_available():
230
+ # mustango = Mustango()
231
+ # else:
232
+ # mustango = Mustango(device="cpu")
233
 
234
  # output_wave = mustango.generate("This techno song features a synth lead playing the main melody.", 5, 3, disable_progress=False)
235
  @spaces.GPU(duration=120)