Spaces:
Runtime error
Runtime error
waidhoferj
commited on
Commit
•
17f9fb1
1
Parent(s):
7b37b0e
device cpu
Browse files
app.py
CHANGED
@@ -10,7 +10,7 @@ import json
|
|
10 |
from functools import cache
|
11 |
import pandas as pd
|
12 |
|
13 |
-
|
14 |
|
15 |
@cache
|
16 |
def get_model(device) -> tuple[ResidualDancer, np.ndarray]:
|
@@ -43,10 +43,10 @@ def predict(audio: tuple[int, np.ndarray]) -> list[str]:
|
|
43 |
expected_duration = 6
|
44 |
threshold = 0.5
|
45 |
sample_len = sample_rate * expected_duration
|
46 |
-
|
47 |
|
48 |
audio_pipeline = get_pipeline(sample_rate)
|
49 |
-
model, labels = get_model(
|
50 |
|
51 |
if sample_len > len(waveform):
|
52 |
raise gr.Error("You must record for at least 6 seconds")
|
@@ -60,7 +60,7 @@ def predict(audio: tuple[int, np.ndarray]) -> list[str]:
|
|
60 |
waveform = waveform.astype("float32")
|
61 |
waveform = torch.from_numpy(waveform)
|
62 |
spectrogram = audio_pipeline(waveform)
|
63 |
-
spectrogram = spectrogram.unsqueeze(0).to(
|
64 |
|
65 |
with torch.no_grad():
|
66 |
results = model(spectrogram)
|
|
|
10 |
from functools import cache
|
11 |
import pandas as pd
|
12 |
|
13 |
+
DEVICE = "cpu"
|
14 |
|
15 |
@cache
|
16 |
def get_model(device) -> tuple[ResidualDancer, np.ndarray]:
|
|
|
43 |
expected_duration = 6
|
44 |
threshold = 0.5
|
45 |
sample_len = sample_rate * expected_duration
|
46 |
+
|
47 |
|
48 |
audio_pipeline = get_pipeline(sample_rate)
|
49 |
+
model, labels = get_model(DEVICE)
|
50 |
|
51 |
if sample_len > len(waveform):
|
52 |
raise gr.Error("You must record for at least 6 seconds")
|
|
|
60 |
waveform = waveform.astype("float32")
|
61 |
waveform = torch.from_numpy(waveform)
|
62 |
spectrogram = audio_pipeline(waveform)
|
63 |
+
spectrogram = spectrogram.unsqueeze(0).to(DEVICE)
|
64 |
|
65 |
with torch.no_grad():
|
66 |
results = model(spectrogram)
|