soujanyaporia commited on
Commit
7416727
1 Parent(s): c963e98

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -6
app.py CHANGED
@@ -176,7 +176,7 @@ class Mustango:
176
  main_config["scheduler_name"],
177
  unet_model_config_path=f"{path}/configs/music_diffusion_model_config.json",
178
  ).to(device)
179
- self.model.device = device
180
 
181
  vae_weights = torch.load(
182
  f"{path}/vae/pytorch_model_vae.bin", map_location=device
@@ -226,11 +226,11 @@ class Mustango:
226
 
227
  # Initialize Mustango
228
  mustango = Mustango(device="cpu")
229
- mustango.vae.to(device_selection)
230
- mustango.stft.to(device_selection)
231
- mustango.model.to(device_selection)
232
- mustango.music_model.beats_model.to(device_selection)
233
- mustango.music_model.chords_model.to(device_selection)
234
  # if torch.cuda.is_available():
235
  # mustango = Mustango()
236
  # else:
 
176
  main_config["scheduler_name"],
177
  unet_model_config_path=f"{path}/configs/music_diffusion_model_config.json",
178
  ).to(device)
179
+ # self.model.device = device
180
 
181
  vae_weights = torch.load(
182
  f"{path}/vae/pytorch_model_vae.bin", map_location=device
 
226
 
227
  # Initialize Mustango
228
  mustango = Mustango(device="cpu")
229
+ mustango.vae.to(device_type)
230
+ mustango.stft.to(device_type)
231
+ mustango.model.to(device_type)
232
+ mustango.music_model.beats_model.to(device_type)
233
+ mustango.music_model.chords_model.to(device_type)
234
  # if torch.cuda.is_available():
235
  # mustango = Mustango()
236
  # else: