from typing import NamedTuple
from argparse import ArgumentParser
import logging

import numpy as np
import torch as T
from torch.nn import functional as F

import diac_utils as du

logger = logging.getLogger(__file__)
logger.setLevel(logging.INFO)


def logln(*texts: str):
    # logger.info(' '.join(texts))
    print(*texts)


# Relative improvement:
#   T.mean((pred_c.argmax('c') == gt) - (pred_m.argmax('c') == gt))
# Coverage Confidence:
#   pred_c.argmax('c')[pred_c.argmax('c') != pred_m.argmax('c')].mean()

class PartialDiacMetrics(NamedTuple):
    diff_total: float
    worse_total: float
    diff_relative: float
    der_total: float
    selectivity: float
    hidden_der: float
    partial_der: float
    reader_error: float


def load_data(path: str):
    if path.endswith('.txt'):
        with open(path, 'r', encoding='utf-8') as fin:
            return fin.readlines()
    else:
        return T.load(path)


def parse_data(
    data,
    logits: bool = False,
    side=None,
):
    if logits:
        ld = data['line_data']
        diac_logits = T.as_tensor(ld[f'diac_logits_{side}'])
        # diac_pred: T.Tensor = ld['diac_pred']
        diac_pred: T.Tensor = diac_logits.argmax(dim=-1)
        diac_gt: T.Tensor = ld['diac_gt']
        # diac_logits = (ld['diac_logits_ctxt'], ld['diac_logits_base'])
        return diac_pred, diac_gt, diac_logits
    if isinstance(data, dict):
        ld = data.get('line_data_fix', data['line_data'])
        if side is None:
            diac_pred: T.Tensor = ld['diac_pred']
        else:
            diac_pred: T.Tensor = ld[f'diac_logits_{side}'].argmax(axis=-1)
        diac_gt: T.Tensor = ld['diac_gt']
        return diac_pred, diac_gt
    elif isinstance(data, list):
        # Raw text lines: extract per-character diacritic ids and pad to a rectangle.
        data_indices = [
            du.diac_ids_of_line(du.strip_tatweel(du.normalize_spaces(line)))
            for line in data
        ]
        max_len = max(map(len, data_indices))
        out = np.full((len(data), max_len), fill_value=du.DIAC_PAD_IDX)
        for i_line, line_indices in enumerate(data_indices):
            out[i_line][:len(line_indices)] = line_indices
        return out, None
    elif isinstance(data, (T.Tensor, np.ndarray)):
        return data, None
    else:
        raise NotImplementedError


def make_mask_hard(
    pred_c: T.Tensor,
    pred_m: T.Tensor,
):
    # Mark every position where the two systems' predictions disagree.
    selection = (pred_c != pred_m)
    return selection


def make_mask_logits(
    pred_c: T.Tensor,
    pred_m: T.Tensor,
    threshold: float = 0.1,
    version: str = '2',
) -> T.BoolTensor:
    logger.warning(f"{version=}, {threshold=}")
    pred_c = T.softmax(T.as_tensor(pred_c), dim=-1)
    pred_m = T.softmax(T.as_tensor(pred_m), dim=-1)
    # pred_i = pred_c.argmax(dim=-1)
    if version == 'hard':
        # Plain disagreement between the argmax predictions.
        selection = pred_c.argmax(-1) != pred_m.argmax(-1)
    elif version == '0':
        # ctxt more confident than base, and base clears the threshold.
        selection = pred_c.max(dim=-1).values > pred_m.max(dim=-1).values
        selection = selection & (pred_m.max(dim=-1).values > threshold)
    elif version == '1':
        # Top-class confidence gap exceeds the threshold.
        pred_c_conf = pred_c.max(dim=-1).values
        pred_m_conf = pred_m.max(dim=-1).values
        selection = (pred_c_conf - pred_m_conf) > threshold
    elif version == '1.1':
        # Like '1' but symmetric in which model is more confident.
        pred_c_conf = pred_c.max(dim=-1).values
        pred_m_conf = pred_m.max(dim=-1).values
        selection = (pred_c_conf - pred_m_conf).abs() > threshold
    elif version.startswith('2'):
        # Probability-mass gap measured at one model's argmax class.
        if version == '2':
            max_c = pred_c.argmax(dim=-1, keepdim=True)
            selection = T.gather(pred_c - pred_m, dim=-1, index=max_c) > threshold
        elif version == '2.1':
            max_c = pred_m.argmax(dim=-1, keepdim=True)
            selection = T.gather(pred_c - pred_m, dim=-1, index=max_c) > threshold
        elif version == '2.abs':
            max_c = pred_c.argmax(dim=-1, keepdim=True)
            selection = T.gather(pred_c - pred_m, dim=-1, index=max_c).abs() > threshold
        elif version == '2.1.abs':
            max_c = pred_m.argmax(dim=-1, keepdim=True)
            selection = T.gather(pred_c - pred_m, dim=-1, index=max_c).abs() > threshold
    elif version == '3':
        # Largest per-class probability gap exceeds the threshold.
        selection = (pred_c - pred_m).max(dim=-1).values > threshold
    elif version == '4':
        # Hard disagreement gated by the probability-mass gap of version '2'.
        selection_hard = (pred_c.argmax(-1) != pred_m.argmax(-1))
        # selection_logits = (pred_c.max(-1).values - pred_m.max(-1).values) > threshold
        selection_logits = T.gather(
            pred_c - pred_m, dim=-1,
            index=pred_c.argmax(-1, keepdim=True),
        ) > threshold
        selection = selection_hard & selection_logits.squeeze(-1)
    # selection = (pred_c != pred_m)
    return selection.squeeze()
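
# Illustrative sketch (not part of the original pipeline): a made-up pair of
# logit tensors showing what the default version '2' of make_mask_logits
# selects. Defined as a helper; nothing calls it at import time.
def _example_make_mask_logits():
    ctxt = T.tensor([[[2.0, 0.1, 0.1], [0.5, 0.5, 0.5]]])  # [b=1, tc=2, classes=3]
    base = T.tensor([[[0.1, 2.0, 0.1], [0.5, 0.5, 0.5]]])  # [b=1, tc=2, classes=3]
    # Position 0: ctxt puts ~0.77 on class 0 where base puts ~0.12, so the
    # probability-mass gap at ctxt's argmax clears threshold=0.1 -> selected.
    # Position 1: both distributions are uniform, the gap is 0 -> not selected.
    return make_mask_logits(ctxt, base, threshold=0.1, version='2')
    #^ tensor([ True, False])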

def analysis_summary(
    pred_c : T.LongTensor,
    pred_m : T.LongTensor,
    labels : T.LongTensor,
    padding_mask: T.BoolTensor,
    *,
    selection : T.Tensor = None,
    random: bool = False,
    logits: tuple = None,
):
    #^ pred_c: [b tw tc | ClassId]
    #^ pred_m: [b tw tc | ClassId]
    #^ labels: [b tw tc | ClassId]
    padding_mask = T.as_tensor(padding_mask)
    # padding_mask[:, 200:] = False
    nonpad_mask = ~padding_mask
    num_chars = nonpad_mask.sum()
    if logits is not None:
        # Ensemble: average the two systems' probabilities, then take the argmax.
        logits = tuple(map(T.as_tensor, logits))
        # pred_c = (logits[0] + logits[1]).argmax(-1)
        pred_c = (T.softmax(logits[0], dim=-1) + T.softmax(logits[1], dim=-1)).argmax(-1)
    pred_c = T.as_tensor(pred_c)[nonpad_mask]
    pred_m = T.as_tensor(pred_m)[nonpad_mask]
    labels = T.as_tensor(labels)[nonpad_mask]
    #^ : [(b * tw * tc) | ClassId]
    ctxt_match = (pred_c == labels).float()
    base_match = (pred_m == labels).float()
    selection = T.as_tensor(selection)[nonpad_mask]
    if random:
        # Control condition: mark a random subset with the same expected size.
        selection = pred_c.new_empty(pred_c.shape).bernoulli_(p=selection.float().mean().item()).to(bool)
    unselected = ~selection
    assert num_chars > 0
    assert selection.sum() > 0
    base_accuracy = base_match[unselected].sum() / unselected.sum()
    ctxt_accuracy = ctxt_match[selection].sum() / selection.sum()
    correct_total = ctxt_match.sum() / num_chars
    der_total = 1 - correct_total
    cmp = (ctxt_match - base_match)[selection]
    diff = T.sum(cmp)
    diff_total = diff / num_chars
    diff_relative = diff / selection.sum()
    selectivity = selection.sum() / num_chars
    worse_total = base_match[selection].sum() / num_chars
    # DER on the unmarked (hidden) vs. marked (partial) subsets; a reader who
    # trusts the marks experiences the selectivity-weighted mix of the two.
    hidden_der = 1.0 - base_accuracy
    partial_der = 1.0 - ctxt_accuracy
    reader_error = selectivity * partial_der + (1 - selectivity) * hidden_der
    return PartialDiacMetrics(
        diff_total = round(diff_total.item() * 100, 2),
        worse_total = round(worse_total.item() * 100, 2),
        diff_relative = round(diff_relative.item() * 100, 2),
        der_total = round(der_total.item() * 100, 2),
        selectivity = round(selectivity.item() * 100, 2),
        hidden_der = round(hidden_der.item() * 100, 2),
        partial_der = round(partial_der.item() * 100, 2),
        reader_error = round(reader_error.item() * 100, 2),
    )
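
# Worked example (made-up numbers) of the reader_error mix computed in
# analysis_summary(): a reader keeps the marked diacritics from the ctxt
# system and effectively reads base-quality diacritics everywhere else.
def _example_reader_error():
    selectivity = 0.2   # 20% of characters are marked
    partial_der = 0.05  # error rate on the marked characters
    hidden_der = 0.10   # error rate on the unmarked characters
    reader_error = selectivity * partial_der + (1 - selectivity) * hidden_der
    #^ 0.2 * 0.05 + 0.8 * 0.10 = 0.09, i.e. a 9% effective error rate.
    return reader_error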

def relative_improvement_soft(
    pred_c : T.Tensor,
    pred_m : T.Tensor,
    labels : T.LongTensor,
    padding_mask: T.Tensor,
):
    #^ pred_c: [b tw tc Classes="15"]
    #^ pred_m: [b tw tc Classes="15"]
    padding_mask = T.as_tensor(padding_mask)
    nonpad_mask = 1 - padding_mask.float()
    num_chars = nonpad_mask.sum()
    pred_c = T.as_tensor(pred_c)[~padding_mask]
    pred_m = T.as_tensor(pred_m)[~padding_mask]
    #^ : [(b * tw * tc), Classes]
    labels = T.as_tensor(labels)[~padding_mask]
    #^ : [(b * tw * tc) | ClassId]
    # Probability mass each system assigns to the ground-truth class.
    # (gather needs the index to match the input's rank, hence the unsqueeze.)
    ctxt_match = T.gather(pred_c, dim=1, index=labels.unsqueeze(1)).squeeze(1)
    base_match = T.gather(pred_m, dim=1, index=labels.unsqueeze(1)).squeeze(1)
    selection = (pred_c.argmax(-1) != pred_m.argmax(-1))
    better = T.sum(ctxt_match - base_match) / num_chars
    selectivity = selection.sum() / num_chars
    worse = base_match[selection].sum() / num_chars
    return better, worse, selectivity


def relative_improvement_masked_soft(
    pred_c: T.Tensor,
    pred_m: T.Tensor,
    ground_truth: T.LongTensor,
    padding_mask: T.Tensor,
):
    raise NotImplementedError
    #^ pred_c: [b tw tc "13"]
    #^ pred_m: [b tw tc "13"]
    #^ ground_truth: [b tw tc ClassId]
    nonpad_mask = ~padding_mask
    selection_mask = pred_c.argmax(3) != pred_m.argmax(3)
    #^ selection_mask: [b tw tc]
    probs = F.softmax(pred_c.clone(), dim=-1)
    probs_gt = T.gather(probs, dim=-1, index=ground_truth.unsqueeze(-1)).squeeze(-1)
    #^ probs_gt: [b tw tc]
    result = probs_gt[selection_mask & nonpad_mask].mean()
    return result


def coverage_confidence(
    pred_c: T.Tensor,
    pred_m: T.Tensor,
    padding_mask: T.Tensor,
    # selection_mask: T.Tensor,
):
    raise NotImplementedError
    #^ pred_c: [b tw tc "13"]
    #^ pred_m: [b tw tc "13"]
    #^ selection_mask: [b tw tc (bool)]
    pred_c_id = pred_c.argmax(3)
    pred_m_id = pred_m.argmax(3)
    selected = pred_c_id[pred_c_id != pred_m_id]
    nonpad_mask = ~padding_mask
    result = selected.sum() / nonpad_mask.sum()
    return result


def cli():
    parser = ArgumentParser(
        description='Compare diacritics from base/ctxt systems with partial diac metrics.'
    )
    parser.add_argument('-m', '--model-output-base', help="Path to tensor.pt dump file of base diacs.")
    parser.add_argument('-c', '--model-output-ctxt', help="Path to tensor.pt dump file of ctxt diacs.")
    parser.add_argument('--gt', default=None, help="Path to tensor.pt for gt only.")
    parser.add_argument('--mode', choices=['hard', 'logits'], default='hard')
    args = parser.parse_args()

    model_output_base = parse_data(
        load_data(args.model_output_base),
        # logits=args.mode == 'logits',
        logits=True,
        side='base',
    )
    model_output_ctxt = parse_data(
        load_data(args.model_output_ctxt),
        # logits=args.mode == 'logits',
        logits=True,
        side='ctxt',
    )
    #^ shape: [b, tc] -> ClassId
    logln(f"{model_output_base[0].shape=} , {model_output_ctxt[0].shape=}")
    assert len(model_output_base[0]) == len(model_output_ctxt[0])

    xc = model_output_ctxt
    xm = model_output_base

    # Take the ground truth from whichever dump carries it.
    # if args.gt is not None:
    #     ground_truth = parse_data(load_data(args.gt))[1]
    ground_truth = None
    if xm[1] is not None:
        ground_truth = xm[1]
    elif xc[1] is not None:
        ground_truth = xc[1]
    assert ground_truth is not None

    if args.mode == 'hard':
        selection = make_mask_hard(xc[0], xm[0])
    elif args.mode == 'logits':
        selection = make_mask_logits(xc[2], xm[2])

    # Padding positions are marked with -1 in the ground truth.
    metrics = analysis_summary(
        xc[0], xm[0],
        ground_truth,
        ground_truth == -1,
        selection=selection,
        logits=(xc[2], xm[2]),
    )
    logln("Actual Totals:", metrics)

    metrics = analysis_summary(
        xc[0], xm[0],
        ground_truth,
        ground_truth == -1,
        random=True,
        selection=selection,
        logits=(xc[2], xm[2]),
    )
    logln("Random Marked Chars:", metrics)
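
if __name__ == '__main__':
    # Assumed entry point: the script appears intended to be run directly.
    cli()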