Spaces:
Sleeping
Sleeping
Hugo Flores Garcia
commited on
Commit
·
5dafbac
1
Parent(s):
df7025d
fixes #3
Browse files- vampnet/beats.py +2 -1
vampnet/beats.py
CHANGED
@@ -9,6 +9,7 @@ from typing import Tuple
|
|
9 |
from typing import Union
|
10 |
|
11 |
import librosa
|
|
|
12 |
import numpy as np
|
13 |
from audiotools import AudioSignal
|
14 |
|
@@ -203,7 +204,7 @@ class WaveBeat(BeatTracker):
|
|
203 |
def __init__(self, ckpt_path: str = "checkpoints/wavebeat", device: str = "cpu"):
|
204 |
from wavebeat.dstcn import dsTCNModel
|
205 |
|
206 |
-
model = dsTCNModel.load_from_checkpoint(ckpt_path)
|
207 |
model.eval()
|
208 |
|
209 |
self.device = device
|
|
|
9 |
from typing import Union
|
10 |
|
11 |
import librosa
|
12 |
+
import torch
|
13 |
import numpy as np
|
14 |
from audiotools import AudioSignal
|
15 |
|
|
|
204 |
def __init__(self, ckpt_path: str = "checkpoints/wavebeat", device: str = "cpu"):
|
205 |
from wavebeat.dstcn import dsTCNModel
|
206 |
|
207 |
+
model = dsTCNModel.load_from_checkpoint(ckpt_path, map_location=torch.device(device))
|
208 |
model.eval()
|
209 |
|
210 |
self.device = device
|