from typing import NamedTuple
from argparse import ArgumentParser
from tqdm import tqdm
import logging

import numpy as np
import torch as T
from torch.nn import functional as F

import diac_utils as du

logger = logging.getLogger(__file__)
logger.setLevel(logging.INFO)
def logln(*texts: str):
    # logger.info(' '.join(texts))
    print(*texts)

# Relative improvement:
#   T.mean((pred_c.argmax('c') == gt) - (pred_m.argmax('c') == gt))
# Coverage Confidence:
#   pred_c.argmax('c')[pred_c.argmax('c') != pred_m.argmax('c')].mean()
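# A runnable toy version of the two quantities above (shapes and values are
# illustrative assumptions, not the real pipeline's data):
#
#   pred_c = T.randn(2, 5, 15)          # contextual logits [b, t, classes]
#   pred_m = T.randn(2, 5, 15)          # base logits       [b, t, classes]
#   gt     = T.randint(0, 15, (2, 5))   # gold class ids
#   rel_impr = T.mean((pred_c.argmax(-1) == gt).float()
#                     - (pred_m.argmax(-1) == gt).float())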
class PartialDiacMetrics(NamedTuple):
    diff_total: float     # net gain of ctxt over base on selected chars, / all chars
    worse_total: float    # selected chars the base already got right, / all chars
    diff_relative: float  # net gain of ctxt over base, / selected chars
    der_total: float      # diacritic error rate (DER) of ctxt over all chars
    selectivity: float    # fraction of chars marked by the selection mask
    hidden_der: float     # DER of base on the unselected (hidden) chars
    partial_der: float    # DER of ctxt on the selected (marked) chars
    reader_error: float   # selectivity-weighted mix of partial and hidden DER
def load_data(path: str):
    if path.endswith('.txt'):
        with open(path, 'r', encoding='utf-8') as fin:
            return fin.readlines()
    else:
        return T.load(path)
def parse_data(
    data,
    logits: bool = False,
    side=None,
):
    if logits:
        # Tensor dump with per-side logits: recover hard predictions from them.
        ld = data['line_data']
        diac_logits = T.tensor(ld[f'diac_logits_{side}'])
        # diac_pred: T.Tensor = ld['diac_pred']
        diac_pred: T.Tensor = diac_logits.argmax(dim=-1)
        diac_gt: T.Tensor = ld['diac_gt']
        # diac_logits = (ld['diac_logits_ctxt'], ld['diac_logits_base'])
        return diac_pred, diac_gt, diac_logits
    if isinstance(data, dict):
        # Tensor dump with hard predictions (optionally re-derived per side).
        ld = data.get('line_data_fix', data['line_data'])
        if side is None:
            diac_pred: T.Tensor = ld['diac_pred']
        else:
            diac_pred: T.Tensor = ld[f'diac_logits_{side}'].argmax(axis=-1)
        diac_gt: T.Tensor = ld['diac_gt']
        return diac_pred, diac_gt
    elif isinstance(data, list):
        # Raw text lines: extract diacritic class ids and pad to a rectangle.
        data_indices = [
            du.diac_ids_of_line(du.strip_tatweel(du.normalize_spaces(line)))
            for line in data
        ]
        max_len = max(map(len, data_indices))
        out = np.full((len(data), max_len), fill_value=du.DIAC_PAD_IDX)
        for i_line, line_indices in enumerate(data_indices):
            out[i_line][:len(line_indices)] = line_indices
        return out, None
    elif isinstance(data, (T.Tensor, np.ndarray)):
        # Already index-encoded.
        return data, None
    else:
        raise NotImplementedError
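# Hedged usage sketch for `load_data` + `parse_data` (the path is a
# placeholder, not a real artifact):
#
#   pred, gt, logits = parse_data(
#       load_data('ctxt_dump.pt'), logits=True, side='ctxt',
#   )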
def make_mask_hard(
    pred_c: T.Tensor,
    pred_m: T.Tensor,
):
    # Mark every char where the two systems disagree.
    selection = (pred_c != pred_m)
    return selection
def make_mask_logits(
    pred_c: T.Tensor,
    pred_m: T.Tensor,
    threshold: float = 0.1,
    version: str = '2',
) -> T.BoolTensor:
    logger.warning(f"{version=}, {threshold=}")
    pred_c = T.softmax(T.as_tensor(pred_c), dim=-1)
    pred_m = T.softmax(T.as_tensor(pred_m), dim=-1)
    # pred_i = pred_c.argmax(dim=-1)
    if version == 'hard':
        # Plain disagreement between the two argmax predictions.
        selection = pred_c.argmax(-1) != pred_m.argmax(-1)
    elif version == '0':
        # ctxt more confident than base, and base confident enough to count.
        selection = pred_c.max(dim=-1).values > pred_m.max(dim=-1).values
        selection = selection & (pred_m.max(dim=-1).values > threshold)
    elif version == '1':
        # Top-class confidence gap exceeds the threshold.
        pred_c_conf = pred_c.max(dim=-1).values
        pred_m_conf = pred_m.max(dim=-1).values
        selection = (pred_c_conf - pred_m_conf) > threshold
    elif version == '1.1':
        # As '1', but symmetric in sign.
        pred_c_conf = pred_c.max(dim=-1).values
        pred_m_conf = pred_m.max(dim=-1).values
        selection = (pred_c_conf - pred_m_conf).abs() > threshold
    elif version.startswith('2'):
        # Probability-mass gain on a reference class (ctxt's or base's argmax).
        if version == '2':
            max_c = pred_c.argmax(dim=-1, keepdim=True)
            selection = T.gather(pred_c - pred_m, dim=-1, index=max_c) > threshold
        elif version == '2.1':
            max_c = pred_m.argmax(dim=-1, keepdim=True)
            selection = T.gather(pred_c - pred_m, dim=-1, index=max_c) > threshold
        elif version == '2.abs':
            max_c = pred_c.argmax(dim=-1, keepdim=True)
            selection = T.gather(pred_c - pred_m, dim=-1, index=max_c).abs() > threshold
        elif version == '2.1.abs':
            max_c = pred_m.argmax(dim=-1, keepdim=True)
            selection = T.gather(pred_c - pred_m, dim=-1, index=max_c).abs() > threshold
    elif version == '3':
        # Largest per-class probability gain exceeds the threshold.
        selection = (pred_c - pred_m).max(dim=-1).values > threshold
    elif version == '4':
        # Hard disagreement, gated by a probability-mass gain on ctxt's argmax.
        selection_hard = (pred_c.argmax(-1) != pred_m.argmax(-1))
        # selection_logits = (pred_c.max(-1).values - pred_m.max(-1).values) > threshold
        selection_logits = T.gather(pred_c - pred_m, dim=-1, index=pred_c.argmax(-1, keepdim=True)) > threshold
        selection = selection_hard & selection_logits.squeeze()
    else:
        raise ValueError(f"Unknown selection version: {version!r}")
    # selection = (pred_c != pred_m)
    return selection.squeeze()
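# Toy comparison of the mask variants (values are illustrative assumptions):
#
#   c = T.tensor([[2.0, 0.1, 0.1], [0.5, 0.4, 0.1]])  # ctxt logits [t, classes]
#   m = T.tensor([[0.1, 2.0, 0.1], [0.4, 0.5, 0.1]])  # base logits
#   make_mask_hard(c.argmax(-1), m.argmax(-1))          # -> tensor([True, True])
#   make_mask_logits(c, m, threshold=0.1, version='2')  # -> tensor([True, False])
#   # '2' keeps only chars where ctxt's top class gained >= 0.1 probability
#   # mass over base, so the near-tie in the second row drops out.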
def analysis_summary(
    pred_c : T.LongTensor,
    pred_m : T.LongTensor,
    labels : T.LongTensor,
    padding_mask: T.BoolTensor,
    *,
    selection : T.Tensor = None,
    random: bool = False,
    logits: tuple = None,
):
    #^ pred_c: [b tw tc | ClassId]
    #^ pred_m: [b tw tc | ClassId]
    #^ labels: [b tw tc | ClassId]
    padding_mask = T.as_tensor(padding_mask)
    # padding_mask[:, 200:] = False
    nonpad_mask = ~padding_mask
    num_chars = nonpad_mask.sum()
    if logits is not None:
        # Ensemble the two systems by averaging their probabilities.
        logits = tuple(map(T.as_tensor, logits))
        # pred_c = (logits[0] + logits[1]).argmax(-1)
        pred_c = (T.softmax(logits[0], dim=-1) + T.softmax(logits[1], dim=-1)).argmax(-1)
    pred_c = T.as_tensor(pred_c)[nonpad_mask]
    pred_m = T.as_tensor(pred_m)[nonpad_mask]
    labels = T.as_tensor(labels)[nonpad_mask]
    #^ : [(b * tw * tc) | ClassId]
    ctxt_match = (pred_c == labels).float()
    base_match = (pred_m == labels).float()
    selection = T.as_tensor(selection)[nonpad_mask]
    if random:
        # Control condition: a random mask with the same marking rate.
        selection = pred_c.new_empty(pred_c.shape).bernoulli_(p=selection.float().mean()).to(bool)
    unselected = ~selection
    assert num_chars > 0
    assert selection.sum() > 0
    base_accuracy = base_match[unselected].sum() / unselected.sum()
    ctxt_accuracy = ctxt_match[selection].sum() / selection.sum()
    correct_total = ctxt_match.sum() / num_chars
    der_total = 1 - correct_total
    cmp = (ctxt_match - base_match)[selection]
    diff = T.sum(cmp)
    diff_total = diff / num_chars
    diff_relative = diff / selection.sum()
    selectivity = selection.sum() / num_chars
    worse_total = base_match[selection].sum() / num_chars
    hidden_der = 1.0 - base_accuracy
    partial_der = 1.0 - ctxt_accuracy
    reader_error = selectivity * partial_der + (1 - selectivity) * hidden_der
    return PartialDiacMetrics(
        diff_total    = round(diff_total.item() * 100, 2),
        worse_total   = round(worse_total.item() * 100, 2),
        diff_relative = round(diff_relative.item() * 100, 2),
        der_total     = round(der_total.item() * 100, 2),
        selectivity   = round(selectivity.item() * 100, 2),
        hidden_der    = round(hidden_der.item() * 100, 2),
        partial_der   = round(partial_der.item() * 100, 2),
        reader_error  = round(reader_error.item() * 100, 2),
    )
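# End-to-end sketch on synthetic tensors (illustrative only; real inputs come
# from the dumps loaded in `cli` below):
#
#   gt   = T.randint(0, 15, (4, 32))
#   base = T.randint(0, 15, (4, 32))
#   ctxt = T.where(T.rand(4, 32) < 0.5, gt, base)  # ctxt agrees with gold more
#   pad  = T.zeros(4, 32, dtype=T.bool)
#   analysis_summary(ctxt, base, gt, pad,
#                    selection=make_mask_hard(ctxt, base))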
def relative_improvement_soft(
    pred_c : T.Tensor,
    pred_m : T.Tensor,
    labels : T.LongTensor,
    padding_mask: T.Tensor,
):
    #^ pred_c: [b tw tc Classes="15"]
    #^ pred_m: [b tw tc Classes="15"]
    padding_mask = T.as_tensor(padding_mask)
    nonpad_mask = 1 - padding_mask.float()
    num_chars = nonpad_mask.sum()
    pred_c = T.as_tensor(pred_c)[~padding_mask]
    pred_m = T.as_tensor(pred_m)[~padding_mask]
    #^ : [(b * tw * tc), Classes]
    labels = T.as_tensor(labels)[~padding_mask]
    #^ : [(b * tw * tc) | ClassId]
    # Probability mass each system assigns to the gold class.
    # (`gather` needs the index to match the input's rank, hence the unsqueeze.)
    ctxt_match = T.gather(pred_c, dim=1, index=labels.unsqueeze(1)).squeeze(1)
    base_match = T.gather(pred_m, dim=1, index=labels.unsqueeze(1)).squeeze(1)
    selection = (pred_c.argmax(-1) != pred_m.argmax(-1))
    better = T.sum(ctxt_match - base_match) / num_chars
    selectivity = selection.sum() / num_chars
    worse = base_match[selection].sum() / num_chars
    return better, worse, selectivity
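# Sketch call, assuming softmax-normalized [b, t, classes] inputs:
#
#   probs_c = T.softmax(T.randn(4, 32, 15), dim=-1)
#   probs_m = T.softmax(T.randn(4, 32, 15), dim=-1)
#   better, worse, sel = relative_improvement_soft(
#       probs_c, probs_m,
#       T.randint(0, 15, (4, 32)),
#       T.zeros(4, 32, dtype=T.bool),
#   )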
def relative_improvement_masked_soft(
    pred_c: T.Tensor,
    pred_m: T.Tensor,
    ground_truth: T.LongTensor,
    padding_mask: T.Tensor,
):
    # Not wired up yet; the body below is kept as an unreachable sketch.
    raise NotImplementedError
    #^ pred_c: [b tw tc "13"]
    #^ pred_m: [b tw tc "13"]
    #^ ground_truth: [b tw tc ClassId]
    nonpad_mask = 1 - padding_mask
    selection_mask = pred_c.argmax(3) != pred_m.argmax(3)
    #^ selection_mask: [b tw tc]
    probs = F.softmax(pred_c.clone(), dim=-1)
    probs_gt = T.gather(probs, dim=-1, index=ground_truth.unsqueeze(-1)).squeeze(-1)
    #^ probs_gt: [b tw tc]
    result = probs_gt[selection_mask & nonpad_mask].mean()
    return result
def coverage_confidence(
    pred_c: T.Tensor,
    pred_m: T.Tensor,
    padding_mask: T.Tensor,
    # selection_mask: T.Tensor,
):
    # Not wired up yet; the body below is kept as an unreachable sketch.
    raise NotImplementedError
    #^ pred_c: [b tw tc "13"]
    #^ pred_m: [b tw tc "13"]
    #^ selection_mask: [b tw tc (bool)]
    pred_c_id = pred_c.argmax(3)
    pred_m_id = pred_m.argmax(3)
    selected = pred_c_id[pred_c_id != pred_m_id]
    nonpad_mask = 1 - padding_mask
    result = selected.sum() / nonpad_mask.sum()
    return result
def cli():
    parser = ArgumentParser('Compare diacritics from base/ctxt systems with partial diac metrics.')
    parser.add_argument('-m', '--model-output-base', help="Path to tensor.pt dump file of base diacs.")
    parser.add_argument('-c', '--model-output-ctxt', help="Path to tensor.pt dump file of ctxt diacs.")
    parser.add_argument('--gt', default=None, help="Path to tensor.pt for gt only.")
    parser.add_argument('--mode', choices=['hard', 'logits'], default='hard')
    args = parser.parse_args()

    model_output_base = parse_data(
        load_data(args.model_output_base),
        # logits=args.mode == 'logits',
        logits=True,
        side='base',
    )
    model_output_ctxt = parse_data(
        load_data(args.model_output_ctxt),
        # logits=args.mode == 'logits',
        logits=True,
        side='ctxt',
    )
    #^ shape: [b, tc] -> ClassId
    diacs_pred = model_output_base
    logln(f"{model_output_base[0].shape=} , {model_output_ctxt[0].shape=}")
    assert len(model_output_base[0]) == len(model_output_ctxt[0])
    # for diacs_base, diacs_ctxt in zip(
    #     tqdm(model_output_base, dynamic_ncols=True),
    #     model_output_ctxt
    # ):
    #     diacs = np.where(diacs_base != diacs_ctxt, diacs_ctxt, 0)[diacs_ctxt != -1]  #< Ignore padding
    xc = model_output_ctxt
    xm = model_output_base
    # if args.mode == 'logits':
    # elif args.mode == 'hard':
    #     xc = model_output_ctxt
    #     xm = model_output_base
    # if args.gt is not None:
    #     ground_truth = parse_data(load_data(args.gt))[1]
    ground_truth = None  #< Initialized so the assert below fires instead of a NameError.
    if xm[1] is not None:
        ground_truth = xm[1]
    elif xc[1] is not None:
        ground_truth = xc[1]
    assert ground_truth is not None
    if args.mode == 'hard':
        selection = make_mask_hard(xc[0], xm[0])
    elif args.mode == 'logits':
        selection = make_mask_logits(xc[2], xm[2])

    # `ground_truth == -1` marks the padding positions.
    metrics = analysis_summary(
        xc[0], xm[0], ground_truth, ground_truth == -1,
        selection=selection,
        logits=(xc[2], xm[2]),
    )
    logln("Actual Totals:", metrics)

    metrics = analysis_summary(
        xc[0], xm[0], ground_truth, ground_truth == -1, random=True,
        selection=selection,
        logits=(xc[2], xm[2]),
    )
    logln("Random Marked Chars:", metrics)
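
if __name__ == '__main__':
    # Entry point guard (an assumption: the dump ends at `cli`'s definition,
    # but the argparse usage suggests the script is meant to be run directly).
    cli()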