Spaces:
Running
on
Zero
Running
on
Zero
| import csv | |
| import torch | |
| import numpy as np | |
| import logging | |
| from torch_mir_eval.separation import bss_eval_sources | |
| import fast_bss_eval | |
| from ..losses import ( | |
| PITLossWrapper, | |
| pairwise_neg_sisdr, | |
| pairwise_neg_snr, | |
| singlesrc_neg_sisdr, | |
| PairwiseNegSDR, | |
| ) | |
# Logger scoped to this module's import path.
logger = logging.getLogger(name=__name__)
class MetricsTracker:
    """Accumulate per-utterance separation metrics and log them to a CSV file.

    Each call scores one utterance. It computes SDR, SDR improvement, SI-SNR
    and SI-SNR improvement, writes them as one CSV row, and keeps them in
    memory. :meth:`update` reports running means, and :meth:`final` appends
    ``avg``/``std`` summary rows.

    The CSV file is opened in ``__init__`` and stays open until :meth:`final`
    or :meth:`close` is called. The tracker can also be used as a context
    manager (``with MetricsTracker(path) as tracker: ...``) so the file is
    closed even if evaluation raises.

    Args:
        save_file: Path of the CSV file to create/overwrite.
    """

    def __init__(self, save_file: str = ""):
        self.all_sdrs = []
        self.all_sdrs_i = []
        self.all_sisnrs = []
        self.all_sisnrs_i = []

        # Build the loss wrappers before opening the file, so a failure here
        # neither truncates an existing CSV nor leaks an open handle.
        # NOTE(review): zero_mean=False skips mean removal, which differs from
        # the common SI-SNR definition. Confirm this is intended.
        self.pit_sisnr = PITLossWrapper(
            PairwiseNegSDR("sisdr", zero_mean=False), pit_from="pw_mtx"
        )
        # Not referenced elsewhere in this class; kept as a public attribute.
        self.pit_snr = PITLossWrapper(
            PairwiseNegSDR("snr", zero_mean=False), pit_from="pw_mtx"
        )

        csv_columns = ["snt_id", "sdr", "sdr_i", "si-snr", "si-snr_i"]
        # The csv module requires newline=""; otherwise the "\r\n" row
        # terminator is translated again on Windows, producing blank rows.
        self.results_csv = open(save_file, "w", newline="")
        self.writer = csv.DictWriter(self.results_csv, fieldnames=csv_columns)
        self.writer.writeheader()

    def __call__(self, mix, clean, estimate, key):
        """Score one utterance, write its CSV row and accumulate the metrics.

        Args:
            mix: Mixture waveform. It is stacked ``clean.shape[0]`` times along
                a new leading axis to act as the "no separation" baseline, so
                it presumably has shape ``clean.shape[1:]`` (e.g. ``(time,)``).
                TODO: confirm against the caller.
            clean: Reference sources, presumably ``(n_src, time)``.
            estimate: Separated sources, presumably the same layout as ``clean``.
            key: Utterance identifier written to the ``snt_id`` column.
        """
        # These are evaluation metrics only; no autograd graph is needed.
        with torch.no_grad():
            # Baseline: pretend every separated source is just the mixture.
            mix_as_estimate = torch.stack([mix] * clean.shape[0], dim=0)

            # The PIT wrapper yields a *negative* SI-SDR loss (batch of one),
            # hence the sign flips when recording below.
            neg_sisnr = self.pit_sisnr(estimate.unsqueeze(0), clean.unsqueeze(0))
            neg_sisnr_base = self.pit_sisnr(
                mix_as_estimate.unsqueeze(0), clean.unsqueeze(0)
            )
            neg_sisnr_i = neg_sisnr - neg_sisnr_base

            # sdr_pit_loss is a loss as well, so negate its mean to get SDR.
            sdr = -fast_bss_eval.sdr_pit_loss(estimate, clean).mean()
            sdr_base = -fast_bss_eval.sdr_pit_loss(mix_as_estimate, clean).mean()
            sdr_i = sdr - sdr_base

        # Convert each metric to a Python float once, then reuse it for both
        # the CSV row and the in-memory history.
        scores = {
            "sdr": sdr.item(),
            "sdr_i": sdr_i.item(),
            "si-snr": -neg_sisnr.item(),
            "si-snr_i": -neg_sisnr_i.item(),
        }
        self.writer.writerow({"snt_id": key, **scores})

        self.all_sdrs.append(scores["sdr"])
        self.all_sdrs_i.append(scores["sdr_i"])
        self.all_sisnrs.append(scores["si-snr"])
        self.all_sisnrs_i.append(scores["si-snr_i"])

    def update(self):
        """Return the running means of SDR improvement and SI-SNR improvement.

        Returns:
            dict with keys ``"sdr_i"`` and ``"si-snr_i"``. If nothing has been
            recorded yet, the values are NaN (NumPy emits a RuntimeWarning).
        """
        return {
            "sdr_i": np.array(self.all_sdrs_i).mean(),
            "si-snr_i": np.array(self.all_sisnrs_i).mean(),
        }

    def final(self):
        """Append ``avg`` and ``std`` summary rows, then close the CSV file."""
        self.writer.writerow(self._summary_row("avg", np.mean))
        self.writer.writerow(self._summary_row("std", np.std))
        self.close()

    def _summary_row(self, label, reduce):
        """Build one CSV row by applying ``reduce`` to every accumulated metric."""
        history = {
            "sdr": self.all_sdrs,
            "sdr_i": self.all_sdrs_i,
            "si-snr": self.all_sisnrs,
            "si-snr_i": self.all_sisnrs_i,
        }
        row = {"snt_id": label}
        # Write plain floats so the CSV text does not depend on NumPy's
        # scalar repr.
        row.update(
            {name: float(reduce(np.asarray(values))) for name, values in history.items()}
        )
        return row

    def close(self):
        """Close the CSV file. Calling this more than once is harmless."""
        self.results_csv.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False