bandhav commited on
Commit
7089999
1 Parent(s): e6a6383
Files changed (2) hide show
  1. app.py +3 -2
  2. requirements.txt +1 -0
app.py CHANGED
@@ -42,7 +42,8 @@ def waveformer(audio, label_choices):
42
  fs, mixture = audio
43
  if fs != 44100:
44
  raise ValueError(fs)
45
- mixture = torch.from_numpy(mixture).unsqueeze(0)
 
46
 
47
  # Construct the query vector
48
  if len(label_choices) == 0:
@@ -54,7 +55,7 @@ def waveformer(audio, label_choices):
54
  with torch.no_grad():
55
  output = model(mixture, query)
56
 
57
- return fs, output.squeeze(0).numpy()
58
 
59
 
60
  label_checkbox = gr.CheckboxGroup(choices=TARGETS)
 
42
  fs, mixture = audio
43
  if fs != 44100:
44
  raise ValueError(fs)
45
+ mixture = torch.from_numpy(
46
+ mixture).unsqueeze(0).unsqueeze(0).to(torch.float)
47
 
48
  # Construct the query vector
49
  if len(label_choices) == 0:
 
55
  with torch.no_grad():
56
  output = model(mixture, query)
57
 
58
+ return fs, output.squeeze(0).squeeze(0).to(torch.short).numpy()
59
 
60
 
61
  label_checkbox = gr.CheckboxGroup(choices=TARGETS)
requirements.txt CHANGED
@@ -6,4 +6,5 @@ soundfile
6
  numpy
7
  speechbrain
8
  wget
 
9
 
 
6
  numpy
7
  speechbrain
8
  wget
9
+ torchmetrics
10