Hugo Flores Garcia commited on
Commit
c1b9ba0
1 Parent(s): e3c7f46

use items instead of tensors

Browse files
Files changed (1) hide show
  1. scripts/exp/eval.py +6 -5
scripts/exp/eval.py CHANGED
@@ -59,12 +59,13 @@ def eval(
59
 
60
  pbar = tqdm(zip(baseline_files, cond_files), total=len(baseline_files))
61
  for baseline_file, cond_file in pbar:
 
62
  assert baseline_file.stem == cond_file.stem, f"baseline file {baseline_file} and cond file {cond_file} do not match"
63
  pbar.set_description(baseline_file.stem)
64
 
65
  # load the files
66
- baseline_sig = AudioSignal(baseline_file)
67
- cond_sig = AudioSignal(cond_file)
68
 
69
  # compute the metrics
70
  try:
@@ -72,9 +73,9 @@ def eval(
72
  except:
73
  vsq = 0.0
74
  metrics.append({
75
- "sisdr": sisdr_loss(baseline_sig, cond_sig),
76
- "stft": stft_loss(baseline_sig, cond_sig),
77
- "mel": mel_loss(baseline_sig, cond_sig),
78
  "frechet": frechet_score,
79
  "visqol": vsq,
80
  "condition": condition,
 
59
 
60
  pbar = tqdm(zip(baseline_files, cond_files), total=len(baseline_files))
61
  for baseline_file, cond_file in pbar:
62
+ # make sure the files match (same name)
63
  assert baseline_file.stem == cond_file.stem, f"baseline file {baseline_file} and cond file {cond_file} do not match"
64
  pbar.set_description(baseline_file.stem)
65
 
66
  # load the files
67
+ baseline_sig = AudioSignal(str(baseline_file))
68
+ cond_sig = AudioSignal(str(cond_file))
69
 
70
  # compute the metrics
71
  try:
 
73
  except:
74
  vsq = 0.0
75
  metrics.append({
76
+ "sisdr": sisdr_loss(baseline_sig, cond_sig).item(),
77
+ "stft": stft_loss(baseline_sig, cond_sig).item(),
78
+ "mel": mel_loss(baseline_sig, cond_sig).item(),
79
  "frechet": frechet_score,
80
  "visqol": vsq,
81
  "condition": condition,