soujanyaporia commited on
Commit
7de47a5
1 Parent(s): 7416727

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -47,7 +47,7 @@ class MusicFeaturePredictor:
47
  self.beats_model.to(device)
48
 
49
  beats_ckpt = f"{path}/beats/microsoft-deberta-v3-large.pt"
50
- beats_weight = torch.load(beats_ckpt, map_location="cpu")
51
  self.beats_model.load_state_dict(beats_weight)
52
 
53
  self.chords_tokenizer = AutoTokenizer.from_pretrained(
@@ -64,7 +64,7 @@ class MusicFeaturePredictor:
64
  self.chords_model.to(device)
65
 
66
  chords_ckpt = f"{path}/chords/flan-t5-large.bin"
67
- chords_weight = torch.load(chords_ckpt, map_location="cpu")
68
  self.chords_model.load_state_dict(chords_weight)
69
 
70
  def generate_beats(self, prompt):
 
47
  self.beats_model.to(device)
48
 
49
  beats_ckpt = f"{path}/beats/microsoft-deberta-v3-large.pt"
50
+ beats_weight = torch.load(beats_ckpt, map_location=device)
51
  self.beats_model.load_state_dict(beats_weight)
52
 
53
  self.chords_tokenizer = AutoTokenizer.from_pretrained(
 
64
  self.chords_model.to(device)
65
 
66
  chords_ckpt = f"{path}/chords/flan-t5-large.bin"
67
+ chords_weight = torch.load(chords_ckpt, map_location=device)
68
  self.chords_model.load_state_dict(chords_weight)
69
 
70
  def generate_beats(self, prompt):