Hugo Flores Garcia commited on
Commit
5dafbac
1 Parent(s): df7025d
Files changed (1) hide show
  1. vampnet/beats.py +2 -1
vampnet/beats.py CHANGED
@@ -9,6 +9,7 @@ from typing import Tuple
9
  from typing import Union
10
 
11
  import librosa
 
12
  import numpy as np
13
  from audiotools import AudioSignal
14
 
@@ -203,7 +204,7 @@ class WaveBeat(BeatTracker):
203
  def __init__(self, ckpt_path: str = "checkpoints/wavebeat", device: str = "cpu"):
204
  from wavebeat.dstcn import dsTCNModel
205
 
206
- model = dsTCNModel.load_from_checkpoint(ckpt_path)
207
  model.eval()
208
 
209
  self.device = device
 
9
  from typing import Union
10
 
11
  import librosa
12
+ import torch
13
  import numpy as np
14
  from audiotools import AudioSignal
15
 
 
204
  def __init__(self, ckpt_path: str = "checkpoints/wavebeat", device: str = "cpu"):
205
  from wavebeat.dstcn import dsTCNModel
206
 
207
+ model = dsTCNModel.load_from_checkpoint(ckpt_path, map_location=torch.device(device))
208
  model.eval()
209
 
210
  self.device = device