waidhoferj commited on
Commit
17f9fb1
1 Parent(s): 7b37b0e

device cpu

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -10,7 +10,7 @@ import json
10
  from functools import cache
11
  import pandas as pd
12
 
13
-
14
 
15
  @cache
16
  def get_model(device) -> tuple[ResidualDancer, np.ndarray]:
@@ -43,10 +43,10 @@ def predict(audio: tuple[int, np.ndarray]) -> list[str]:
43
  expected_duration = 6
44
  threshold = 0.5
45
  sample_len = sample_rate * expected_duration
46
- device = "mps"
47
 
48
  audio_pipeline = get_pipeline(sample_rate)
49
- model, labels = get_model(device)
50
 
51
  if sample_len > len(waveform):
52
  raise gr.Error("You must record for at least 6 seconds")
@@ -60,7 +60,7 @@ def predict(audio: tuple[int, np.ndarray]) -> list[str]:
60
  waveform = waveform.astype("float32")
61
  waveform = torch.from_numpy(waveform)
62
  spectrogram = audio_pipeline(waveform)
63
- spectrogram = spectrogram.unsqueeze(0).to(device)
64
 
65
  with torch.no_grad():
66
  results = model(spectrogram)
 
10
  from functools import cache
11
  import pandas as pd
12
 
13
+ DEVICE = "cpu"
14
 
15
  @cache
16
  def get_model(device) -> tuple[ResidualDancer, np.ndarray]:
 
43
  expected_duration = 6
44
  threshold = 0.5
45
  sample_len = sample_rate * expected_duration
46
+
47
 
48
  audio_pipeline = get_pipeline(sample_rate)
49
+ model, labels = get_model(DEVICE)
50
 
51
  if sample_len > len(waveform):
52
  raise gr.Error("You must record for at least 6 seconds")
 
60
  waveform = waveform.astype("float32")
61
  waveform = torch.from_numpy(waveform)
62
  spectrogram = audio_pipeline(waveform)
63
+ spectrogram = spectrogram.unsqueeze(0).to(DEVICE)
64
 
65
  with torch.no_grad():
66
  results = model(spectrogram)