mattricesound commited on
Commit
fb059dc
1 Parent(s): 22b77e6

Add GPU support

Browse files
Files changed (1) hide show
  1. app.py +3 -1
app.py CHANGED
@@ -97,7 +97,9 @@ def load_model(version='melody'):
97
  global MODEL
98
  print("Loading model", version)
99
  if MODEL is None or MODEL.name != version:
100
- MODEL = MusicGen.get_pretrained(version)
 
 
101
 
102
 
103
  def _do_predictions(texts, melodies, duration, progress=False, **gen_kwargs):
 
97
  global MODEL
98
  print("Loading model", version)
99
  if MODEL is None or MODEL.name != version:
100
+ # If gpu is not available, we'll use cpu.
101
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
102
+ MODEL = MusicGen.get_pretrained(version, device=device)
103
 
104
 
105
  def _do_predictions(texts, melodies, duration, progress=False, **gen_kwargs):