bandhav commited on
Commit
b7c0655
1 Parent(s): 7089999

Fixed normalization

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -43,7 +43,7 @@ def waveformer(audio, label_choices):
43
  if fs != 44100:
44
  raise ValueError(fs)
45
  mixture = torch.from_numpy(
46
- mixture).unsqueeze(0).unsqueeze(0).to(torch.float)
47
 
48
  # Construct the query vector
49
  if len(label_choices) == 0:
@@ -53,7 +53,7 @@ def waveformer(audio, label_choices):
53
  query[0, TARGETS.index(t)] = 1.
54
 
55
  with torch.no_grad():
56
- output = model(mixture, query)
57
 
58
  return fs, output.squeeze(0).squeeze(0).to(torch.short).numpy()
59
 
 
43
  if fs != 44100:
44
  raise ValueError(fs)
45
  mixture = torch.from_numpy(
46
+ mixture).unsqueeze(0).unsqueeze(0).to(torch.float) / (2.0 ** 15)
47
 
48
  # Construct the query vector
49
  if len(label_choices) == 0:
 
53
  query[0, TARGETS.index(t)] = 1.
54
 
55
  with torch.no_grad():
56
+ output = (2.0 ** 15) * model(mixture, query)
57
 
58
  return fs, output.squeeze(0).squeeze(0).to(torch.short).numpy()
59