File size: 4,752 Bytes
ad16788 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
#!/usr/bin/env python3
import argparse
import logging
import sys
from typing import List
from typing import Union
from mir_eval.separation import bss_eval_sources
import numpy as np
from pystoi import stoi
import torch
from typeguard import check_argument_types
from espnet.utils.cli_utils import get_commandline_args
from espnet2.enh.espnet_model import ESPnetEnhancementModel
from espnet2.fileio.datadir_writer import DatadirWriter
from espnet2.fileio.sound_scp import SoundScpReader
from espnet2.utils import config_argparse
def scoring(
    output_dir: str,
    dtype: str,
    log_level: Union[int, str],
    key_file: str,
    ref_scp: List[str],
    inf_scp: List[str],
    ref_channel: int,
):
    """Score inferred (enhanced/separated) audio against references.

    For every utterance ID listed in ``key_file``, the signals of all speakers
    are loaded from the reference and inferred scp files. The speaker
    permutation is estimated with ``bss_eval_sources``. STOI, SI-SNR, SDR,
    SAR and SIR are then written per speaker through ``DatadirWriter``. A
    ``wav_spk{N}`` entry is also written; it records the inferred scp entry
    that was matched to reference speaker ``N``.

    Args:
        output_dir: Directory handed to ``DatadirWriter`` for the score files.
        dtype: Sample dtype passed to ``SoundScpReader``.
        log_level: Level passed to ``logging.basicConfig``.
        key_file: Text file whose first whitespace-separated token on each
            non-blank line is an utterance ID. Utterances are scored in file
            order.
        ref_scp: One reference scp file per speaker.
        inf_scp: One inferred scp file per speaker. It must have the same
            length as ``ref_scp``.
        ref_channel: Index into the last axis of the audio arrays. It is
            used when the reference (or both reference and output) carries
            an extra channel axis.

    Raises:
        ValueError: If the numbers of reference and inferred scp files
            differ, ``key_file`` lists no utterances, a reference/inferred
            scp pair lists different keys, reference and output shapes
            disagree after channel selection, or the output has a channel
            axis that the reference lacks.
    """
    assert check_argument_types()
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    # Explicit raise instead of assert: input validation must survive `-O`.
    if len(ref_scp) != len(inf_scp):
        raise ValueError(
            f"Got {len(ref_scp)} reference scp files but {len(inf_scp)} "
            f"inferred scp files: {ref_scp} vs {inf_scp}"
        )
    num_spk = len(ref_scp)

    # Use a context manager so the key file is closed; skip blank lines,
    # which would otherwise crash on `split(...)[0]`.
    with open(key_file, encoding="utf-8") as f:
        keys = [line.split(maxsplit=1)[0] for line in f if line.strip()]
    if not keys:
        raise ValueError(f"No utterance keys found in {key_file}")

    ref_readers = [SoundScpReader(f, dtype=dtype, normalize=True) for f in ref_scp]
    inf_readers = [SoundScpReader(f, dtype=dtype, normalize=True) for f in inf_scp]

    # The sample rate is taken from the first reference utterance only;
    # presumably all utterances share it — TODO confirm for mixed-rate data.
    sample_rate, _ = ref_readers[0][keys[0]]

    # Each reference/inferred scp pair must cover the same utterances.
    for ref_path, inf_path, ref_reader, inf_reader in zip(
        ref_scp, inf_scp, ref_readers, inf_readers
    ):
        if inf_reader.keys() != ref_reader.keys():
            raise ValueError(f"Keys of {inf_path} do not match keys of {ref_path}")

    with DatadirWriter(output_dir) as writer:
        for key in keys:
            # Stack speakers along a new leading axis.
            ref = np.array([reader[key][1] for reader in ref_readers])
            inf = np.array([reader[key][1] for reader in inf_readers])
            if ref.ndim > inf.ndim:
                # The reference has an extra trailing channel axis and the
                # output does not: score against the selected channel only.
                ref = ref[..., ref_channel]
                if ref.shape != inf.shape:
                    raise ValueError(
                        f"Shape mismatch for {key}: "
                        f"reference {ref.shape} vs inferred {inf.shape}"
                    )
            elif ref.ndim < inf.ndim:
                # The output has a channel axis that the reference lacks.
                raise ValueError(
                    "Reference must be multi-channel when the "
                    "network output is multi-channel."
                )
            elif ref.ndim == inf.ndim == 3:
                # Both have a channel axis: compare the same channel of each.
                ref = ref[..., ref_channel]
                inf = inf[..., ref_channel]

            sdr, sir, sar, perm = bss_eval_sources(ref, inf, compute_permutation=True)
            for i in range(num_spk):
                # perm[i] is the index of the inferred source matched to
                # reference speaker i.
                j = int(perm[i])
                stoi_score = stoi(ref[i], inf[j], fs_sig=sample_rate)
                # si_snr_loss is the negated SI-SNR, so flip the sign back.
                si_snr_score = -float(
                    ESPnetEnhancementModel.si_snr_loss(
                        torch.from_numpy(ref[i][None, ...]),
                        torch.from_numpy(inf[j][None, ...]),
                    )
                )
                writer[f"STOI_spk{i + 1}"][key] = str(stoi_score)
                writer[f"SI_SNR_spk{i + 1}"][key] = str(si_snr_score)
                writer[f"SDR_spk{i + 1}"][key] = str(sdr[i])
                writer[f"SAR_spk{i + 1}"][key] = str(sar[i])
                writer[f"SIR_spk{i + 1}"][key] = str(sir[i])
                # Record which inferred scp entry was matched to speaker i.
                writer[f"wav_spk{i + 1}"][key] = inf_readers[j].data[key]
def get_parser():
    """Build the command-line parser for the scoring script.

    Returns:
        A ``config_argparse.ArgumentParser``. ``--ref_scp`` and ``--inf_scp``
        are repeatable, one occurrence per speaker. ``--key_file`` is
        required because ``scoring`` opens it unconditionally.
    """
    parser = config_argparse.ArgumentParser(
        description="Frontend inference",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # Note(kamo): Use '_' instead of '-' as separator.
    # '-' is confusing if written in yaml.
    parser.add_argument(
        "--log_level",
        type=lambda x: x.upper(),
        default="INFO",
        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
        help="The verbose level of logging",
    )
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument(
        "--dtype",
        default="float32",
        choices=["float16", "float32", "float64"],
        help="Data type",
    )

    group = parser.add_argument_group("Input data related")
    group.add_argument(
        "--ref_scp",
        type=str,
        required=True,
        action="append",
        help="Reference scp file; repeat once per speaker",
    )
    group.add_argument(
        "--inf_scp",
        type=str,
        required=True,
        action="append",
        help="Inferred scp file; repeat once per speaker, same order as --ref_scp",
    )
    # Required: scoring() calls open(key_file) unconditionally, so omitting
    # this option previously crashed later with a less clear error.
    group.add_argument(
        "--key_file",
        type=str,
        required=True,
        help="File whose first column lists the utterance IDs to score",
    )
    group.add_argument(
        "--ref_channel",
        type=int,
        default=0,
        help="Channel index used when the reference is multi-channel",
    )
    return parser
def main(cmd=None):
    """Parse command-line options and run :func:`scoring` with them.

    Args:
        cmd: Argument list to parse. When None, the parser uses its own
            default argument source (``sys.argv[1:]`` for standard argparse).
    """
    # Echo the exact invocation to stderr before doing any work.
    print(get_commandline_args(), file=sys.stderr)
    options = vars(get_parser().parse_args(cmd))
    # The parser may contribute a "config" entry that scoring() does not
    # accept, so drop it before forwarding the remaining options.
    options.pop("config", None)
    scoring(**options)


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not on import.
    main()
|