# Copyright (c) Facebook, Inc. and its affiliates. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import numpy as np import os.path as op import re from tabulate import tabulate from collections import Counter def comp_purity(p_xy, axis): max_p = p_xy.max(axis=axis) marg_p = p_xy.sum(axis=axis) indv_pur = max_p / marg_p aggr_pur = max_p.sum() return indv_pur, aggr_pur def comp_entropy(p): return (-p * np.log(p + 1e-8)).sum() def comp_norm_mutual_info(p_xy): p_x = p_xy.sum(axis=1, keepdims=True) p_y = p_xy.sum(axis=0, keepdims=True) pmi = np.log(p_xy / np.matmul(p_x, p_y) + 1e-8) mi = (p_xy * pmi).sum() h_x = comp_entropy(p_x) h_y = comp_entropy(p_y) return mi, mi / h_x, mi / h_y, h_x, h_y def pad(labs, n): if n == 0: return np.array(labs) return np.concatenate([[labs[0]] * n, labs, [labs[-1]] * n]) def comp_avg_seg_dur(labs_list): n_frms = 0 n_segs = 0 for labs in labs_list: labs = np.array(labs) edges = np.zeros(len(labs)).astype(bool) edges[0] = True edges[1:] = labs[1:] != labs[:-1] n_frms += len(edges) n_segs += edges.astype(int).sum() return n_frms / n_segs def comp_joint_prob(uid2refs, uid2hyps): """ Args: pad: padding for spliced-feature derived labels """ cnts = Counter() skipped = [] abs_frmdiff = 0 for uid in uid2refs: if uid not in uid2hyps: skipped.append(uid) continue refs = uid2refs[uid] hyps = uid2hyps[uid] abs_frmdiff += abs(len(refs) - len(hyps)) min_len = min(len(refs), len(hyps)) refs = refs[:min_len] hyps = hyps[:min_len] cnts.update(zip(refs, hyps)) tot = sum(cnts.values()) ref_set = sorted({ref for ref, _ in cnts.keys()}) hyp_set = sorted({hyp for _, hyp in cnts.keys()}) ref2pid = dict(zip(ref_set, range(len(ref_set)))) hyp2lid = dict(zip(hyp_set, range(len(hyp_set)))) # print(hyp_set) p_xy = np.zeros((len(ref2pid), len(hyp2lid)), dtype=float) for (ref, hyp), cnt in cnts.items(): p_xy[ref2pid[ref], hyp2lid[hyp]] = cnt p_xy /= p_xy.sum() return p_xy, ref2pid, hyp2lid, tot, abs_frmdiff, skipped def read_phn(tsv_path, rm_stress=True): uid2phns = {} with open(tsv_path) as f: for line in f: uid, phns = line.rstrip().split("\t") phns = phns.split(",") if rm_stress: phns = [re.sub("[0-9]", "", phn) for phn in phns] uid2phns[uid] = phns return uid2phns def read_lab(tsv_path, lab_path, pad_len=0, upsample=1): """ tsv is needed to retrieve the uids for the labels """ with open(tsv_path) as f: f.readline() uids = [op.splitext(op.basename(line.rstrip().split()[0]))[0] for line in f] with open(lab_path) as f: labs_list = [pad(line.rstrip().split(), pad_len).repeat(upsample) for line in f] assert len(uids) == len(labs_list) return dict(zip(uids, labs_list)) def main_lab_lab( tsv_dir, lab_dir, lab_name, lab_sets, ref_dir, ref_name, pad_len=0, upsample=1, verbose=False, ): # assume tsv_dir is the same for both the reference and the hypotheses tsv_dir = lab_dir if tsv_dir is None else tsv_dir uid2refs = {} for s in lab_sets: uid2refs.update(read_lab(f"{tsv_dir}/{s}.tsv", f"{ref_dir}/{s}.{ref_name}")) uid2hyps = {} for s in lab_sets: uid2hyps.update( read_lab( f"{tsv_dir}/{s}.tsv", f"{lab_dir}/{s}.{lab_name}", pad_len, upsample ) ) _main(uid2refs, uid2hyps, verbose) def main_phn_lab( tsv_dir, lab_dir, lab_name, lab_sets, phn_dir, phn_sets, pad_len=0, upsample=1, verbose=False, ): uid2refs = {} for s in phn_sets: uid2refs.update(read_phn(f"{phn_dir}/{s}.tsv")) uid2hyps = {} tsv_dir = lab_dir if tsv_dir is None else tsv_dir for s in lab_sets: uid2hyps.update( read_lab( f"{tsv_dir}/{s}.tsv", f"{lab_dir}/{s}.{lab_name}", pad_len, upsample ) ) _main(uid2refs, uid2hyps, verbose) def _main(uid2refs, uid2hyps, verbose): (p_xy, ref2pid, hyp2lid, tot, frmdiff, skipped) = comp_joint_prob( uid2refs, uid2hyps ) ref_pur_by_hyp, ref_pur = comp_purity(p_xy, axis=0) hyp_pur_by_ref, hyp_pur = comp_purity(p_xy, axis=1) (mi, mi_norm_by_ref, mi_norm_by_hyp, h_ref, h_hyp) = comp_norm_mutual_info(p_xy) outputs = { "ref pur": ref_pur, "hyp pur": hyp_pur, "H(ref)": h_ref, "H(hyp)": h_hyp, "MI": mi, "MI/H(ref)": mi_norm_by_ref, "ref segL": comp_avg_seg_dur(uid2refs.values()), "hyp segL": comp_avg_seg_dur(uid2hyps.values()), "p_xy shape": p_xy.shape, "frm tot": tot, "frm diff": frmdiff, "utt tot": len(uid2refs), "utt miss": len(skipped), } print(tabulate([outputs.values()], outputs.keys(), floatfmt=".4f")) if __name__ == "__main__": """ compute quality of labels with respect to phone or another labels if set """ import argparse parser = argparse.ArgumentParser() parser.add_argument("tsv_dir") parser.add_argument("lab_dir") parser.add_argument("lab_name") parser.add_argument("--lab_sets", default=["valid"], type=str, nargs="+") parser.add_argument( "--phn_dir", default="/checkpoint/wnhsu/data/librispeech/960h/fa/raw_phn/phone_frame_align_v1", ) parser.add_argument( "--phn_sets", default=["dev-clean", "dev-other"], type=str, nargs="+" ) parser.add_argument("--pad_len", default=0, type=int, help="padding for hypotheses") parser.add_argument( "--upsample", default=1, type=int, help="upsample factor for hypotheses" ) parser.add_argument("--ref_lab_dir", default="") parser.add_argument("--ref_lab_name", default="") parser.add_argument("--verbose", action="store_true") args = parser.parse_args() if args.ref_lab_dir and args.ref_lab_name: main_lab_lab( args.tsv_dir, args.lab_dir, args.lab_name, args.lab_sets, args.ref_lab_dir, args.ref_lab_name, args.pad_len, args.upsample, args.verbose, ) else: main_phn_lab( args.tsv_dir, args.lab_dir, args.lab_name, args.lab_sets, args.phn_dir, args.phn_sets, args.pad_len, args.upsample, args.verbose, )