# partial_dd_metrics.py
from typing import NamedTuple
from argparse import ArgumentParser
import logging

import numpy as np
import torch as T
from torch.nn import functional as F

import diac_utils as du
logger = logging.getLogger(__file__)
logger.setLevel(logging.INFO)

def logln(*texts: str):
    # Plain print keeps output unbuffered and free of logger formatting.
    print(*texts)
# Relative improvement (sketch):
#   T.mean((pred_c.argmax(-1) == gt).float() - (pred_m.argmax(-1) == gt).float())
# Coverage confidence (sketch):
#   pred_c.argmax(-1)[pred_c.argmax(-1) != pred_m.argmax(-1)].mean()
class PartialDiacMetrics(NamedTuple):
    diff_total: float     # net accuracy gained on selected chars, over all chars
    worse_total: float    # share of all chars that are selected yet base-correct
    diff_relative: float  # net accuracy gained within the selected chars only
    der_total: float      # DER of the contextual predictions over all chars
    selectivity: float    # fraction of chars selected to carry diacritics
    hidden_der: float     # base-model DER on the unselected (hidden) chars
    partial_der: float    # contextual-model DER on the selected chars
    reader_error: float   # selectivity-weighted mix of partial and hidden DER
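
# The headline metric is reader_error: a reader sees explicit diacritics only
# on the selected characters (contextual model, partial_der) and must infer
# the hidden rest (base model, hidden_der), so
#     reader_error = selectivity * partial_der + (1 - selectivity) * hidden_der
# Made-up example: selecting 20% of characters at 2% DER while the hidden 80%
# carry 8% DER gives 0.2 * 2 + 0.8 * 8 = 6.8%.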
def load_data(path: str):
if path.endswith('.txt'):
with open(path, 'r', encoding='utf-8') as fin:
return fin.readlines()
else:
return T.load(path)
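
# Usage sketch (paths are illustrative):
#   lines = load_data('gt.txt')    # raw text lines
#   dump  = load_data('preds.pt')  # pickled dict with a 'line_data' entry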
def parse_data(
    data,
    logits: bool = False,
    side: str = None,
):
    if logits:
        ld = data['line_data']
        diac_logits = T.tensor(ld[f'diac_logits_{side}'])
        # Derive the hard prediction from the requested side's logits.
        diac_pred: T.Tensor = diac_logits.argmax(dim=-1)
        diac_gt: T.Tensor = ld['diac_gt']
        return diac_pred, diac_gt, diac_logits
if isinstance(data, dict):
ld = data.get('line_data_fix', data['line_data'])
if side is None:
diac_pred: T.Tensor = ld['diac_pred']
else:
diac_pred: T.Tensor = ld[f'diac_logits_{side}'].argmax(axis=-1)
diac_gt : T.Tensor = ld['diac_gt']
return diac_pred, diac_gt
elif isinstance(data, list):
data_indices = [
du.diac_ids_of_line(du.strip_tatweel(du.normalize_spaces(line)))
for line in data
]
max_len = max(map(len, data_indices))
out = np.full((len(data), max_len), fill_value=du.DIAC_PAD_IDX)
for i_line, line_indices in enumerate(data_indices):
out[i_line][:len(line_indices)] = line_indices
return out, None
elif isinstance(data, (T.Tensor, np.ndarray)):
return data, None
else:
raise NotImplementedError
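
# Example with a hypothetical dump (keys as read above):
#   pred, gt, logits = parse_data(T.load('ctxt_dump.pt'), logits=True, side='ctxt')
#   assert pred.shape == logits.shape[:-1]  # argmax drops the class dimension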
def make_mask_hard(
    pred_c: T.Tensor,
    pred_m: T.Tensor,
):
    # Select every character on which the two systems disagree.
    selection = (pred_c != pred_m)
    return selection
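
# Illustrative values (not from a real model):
#   make_mask_hard(T.tensor([1, 2, 3]), T.tensor([1, 0, 3]))
#   # -> tensor([False,  True, False])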
def make_mask_logits(
    pred_c: T.Tensor,
    pred_m: T.Tensor,
    threshold: float = 0.1,
    version: str = '2',
) -> T.BoolTensor:
    logger.warning(f"{version=}, {threshold=}")
    pred_c = T.softmax(T.as_tensor(pred_c), dim=-1)
    pred_m = T.softmax(T.as_tensor(pred_m), dim=-1)
    if version == 'hard':
        # Plain disagreement between the two argmax predictions.
        selection = pred_c.argmax(-1) != pred_m.argmax(-1)
    elif version == '0':
        # Contextual model is more confident, and base confidence clears the threshold.
        selection = pred_c.max(dim=-1).values > pred_m.max(dim=-1).values
        selection = selection & (pred_m.max(dim=-1).values > threshold)
    elif version == '1':
        # Contextual confidence exceeds base confidence by at least `threshold`.
        pred_c_conf = pred_c.max(dim=-1).values
        pred_m_conf = pred_m.max(dim=-1).values
        selection = (pred_c_conf - pred_m_conf) > threshold
    elif version == '1.1':
        # As '1', but the gap may go in either direction.
        pred_c_conf = pred_c.max(dim=-1).values
        pred_m_conf = pred_m.max(dim=-1).values
        selection = (pred_c_conf - pred_m_conf).abs() > threshold
    elif version.startswith('2'):
        # Probability gap measured at one model's argmax class.
        if version == '2':
            max_c = pred_c.argmax(dim=-1, keepdim=True)
            selection = T.gather(pred_c - pred_m, dim=-1, index=max_c) > threshold
        elif version == '2.1':
            max_c = pred_m.argmax(dim=-1, keepdim=True)
            selection = T.gather(pred_c - pred_m, dim=-1, index=max_c) > threshold
        elif version == '2.abs':
            max_c = pred_c.argmax(dim=-1, keepdim=True)
            selection = T.gather(pred_c - pred_m, dim=-1, index=max_c).abs() > threshold
        elif version == '2.1.abs':
            max_c = pred_m.argmax(dim=-1, keepdim=True)
            selection = T.gather(pred_c - pred_m, dim=-1, index=max_c).abs() > threshold
        else:
            raise ValueError(f"Unknown version: {version!r}")
    elif version == '3':
        # Largest per-class probability gain exceeds the threshold.
        selection = (pred_c - pred_m).max(dim=-1).values > threshold
    elif version == '4':
        # Hard disagreement combined with a sufficient probability gap.
        selection_hard = (pred_c.argmax(-1) != pred_m.argmax(-1))
        selection_logits = T.gather(
            pred_c - pred_m, dim=-1, index=pred_c.argmax(-1, keepdim=True)
        ) > threshold
        selection = selection_hard & selection_logits.squeeze()
    else:
        raise ValueError(f"Unknown version: {version!r}")
    # NOTE: squeeze() drops the trailing class dim left by the '2*' variants
    # (and would also drop a batch dim of size 1).
    return selection.squeeze()
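
# Toy example for the default version '2' (made-up logits; batch of 1, two
# characters, three classes). Only the second character moves enough
# probability mass onto the contextual argmax class to cross the threshold:
#   logits_c = T.tensor([[[2.0, 0.1, 0.1], [0.1, 3.0, 0.1]]])
#   logits_m = T.tensor([[[2.0, 0.1, 0.1], [0.1, 0.1, 3.0]]])
#   make_mask_logits(logits_c, logits_m)  # -> tensor([False,  True])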
def analysis_summary(
pred_c : T.LongTensor,
pred_m : T.LongTensor,
labels : T.LongTensor,
padding_mask: T.BoolTensor,
*,
selection : T.Tensor = None,
random: bool = False,
logits: tuple = None
):
#^ pred_c: [b tw tc | ClassId]
#^ pred_m: [b tw tc | ClassId]
#^ labels: [b tw tc | ClassId]
    padding_mask = T.as_tensor(padding_mask)
    nonpad_mask = ~padding_mask
    num_chars = nonpad_mask.sum()
    if logits is not None:
        # Ensemble the two systems by averaging their class probabilities.
        logits = tuple(map(T.as_tensor, logits))
        pred_c = (T.softmax(logits[0], dim=-1) + T.softmax(logits[1], dim=-1)).argmax(-1)
    pred_c = T.as_tensor(pred_c)[nonpad_mask]
    pred_m = T.as_tensor(pred_m)[nonpad_mask]
    labels = T.as_tensor(labels)[nonpad_mask]
    #^ : [(b * tw * tc) | ClassId]
    ctxt_match = (pred_c == labels).float()
    base_match = (pred_m == labels).float()
    if selection is None:
        # Default to the hard disagreement mask; inputs here are already flattened.
        selection = make_mask_hard(pred_c, pred_m)
    else:
        selection = T.as_tensor(selection)[nonpad_mask]
    if random:
        # Control condition: mark a random subset of the same expected size.
        selection = pred_c.new_empty(pred_c.shape).bernoulli_(p=selection.float().mean()).to(bool)
    unselected = ~selection
    assert num_chars > 0
    assert selection.sum() > 0
    base_accuracy = base_match[unselected].sum() / unselected.sum()
    ctxt_accuracy = ctxt_match[selection].sum() / selection.sum()
    correct_total = ctxt_match.sum() / num_chars
    der_total = 1 - correct_total
    cmp = (ctxt_match - base_match)[selection]
    diff = T.sum(cmp)
    diff_total = diff / num_chars
    diff_relative = diff / selection.sum()
    selectivity = selection.sum() / num_chars
    worse_total = base_match[selection].sum() / num_chars
    hidden_der = 1.0 - base_accuracy
    partial_der = 1.0 - ctxt_accuracy
    # What a reader effectively experiences: contextual DER on the marked
    # (selected) characters, base DER on the hidden (unmarked) ones.
    reader_error = selectivity * partial_der + (1 - selectivity) * hidden_der
return PartialDiacMetrics(
diff_total = round(diff_total.item() * 100, 2),
worse_total = round(worse_total.item() * 100, 2),
diff_relative = round(diff_relative.item() * 100, 2),
der_total = round(der_total.item() * 100, 2),
selectivity = round(selectivity.item() * 100, 2),
hidden_der = round(hidden_der.item() * 100, 2),
partial_der = round(partial_der.item() * 100, 2),
reader_error = round(reader_error.item() * 100, 2)
)
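
# Usage sketch (inputs as produced by parse_data above; -1 marks padding):
#   metrics = analysis_summary(
#       pred_c, pred_m, labels, padding_mask=(labels == -1),
#       selection=make_mask_hard(pred_c, pred_m),
#   )
#   logln(metrics.reader_error)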
def relative_improvement_soft(
pred_c : T.Tensor,
pred_m : T.Tensor,
labels : T.LongTensor,
padding_mask: T.Tensor,
):
#^ pred_c: [b tw tc Classes="15"]
#^ pred_m: [b tw tc Classes="15"]
padding_mask = T.tensor(padding_mask)
nonpad_mask = 1 - padding_mask.float()
num_chars = nonpad_mask.sum()
pred_c = T.tensor(pred_c)[~padding_mask]
pred_m = T.tensor(pred_m)[~padding_mask]
#^ : [(b * tw * tc), Classes]
labels = T.tensor(labels)[~padding_mask]
#^ : [(b * tw * tc) | ClassId]
    # gather() requires the index to have the same rank as the input,
    # so lift the class ids into an explicit trailing dim.
    ctxt_match = T.gather(pred_c, dim=1, index=labels.unsqueeze(1)).squeeze(1)
    base_match = T.gather(pred_m, dim=1, index=labels.unsqueeze(1)).squeeze(1)
    selection = (pred_c.argmax(-1) != pred_m.argmax(-1))
better = T.sum(ctxt_match - base_match) / num_chars
selectivity = selection.sum() / num_chars
worse = base_match[selection].sum() / num_chars
return better, worse, selectivity
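
# Toy call (hypothetical soft scores over 3 classes for 2 characters):
#   probs_c = T.tensor([[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]])
#   probs_m = T.tensor([[0.6, 0.3, 0.1], [0.3, 0.1, 0.6]])
#   labels  = T.tensor([0, 1])
#   pad     = T.zeros(2, dtype=T.bool)
#   relative_improvement_soft(probs_c, probs_m, labels, pad)
#   # -> better=0.40, worse=0.05, selectivity=0.50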
def relative_improvement_masked_soft(
pred_c: T.Tensor,
pred_m: T.Tensor,
ground_truth: T.LongTensor,
padding_mask: T.Tensor,
):
    # Not wired into the CLI yet; the body below is an unreachable sketch.
    raise NotImplementedError
    #^ pred_c: [b tw tc "13"]
    #^ pred_m: [b tw tc "13"]
    #^ ground_truth: [b tw tc ClassId]
    nonpad_mask = ~padding_mask  # boolean mask, so it can be &-combined below
    selection_mask = pred_c.argmax(3) != pred_m.argmax(3)
#^ selection_mask: [b tw tc]
probs = F.softmax(pred_c.clone(), dim=-1)
probs_gt = T.gather(probs, dim=-1, index=ground_truth.unsqueeze(-1)).squeeze(-1)
#^ probs_gt: [b tw tc]
result = probs_gt[selection_mask & nonpad_mask].mean()
return result
def coverage_confidence(
pred_c: T.Tensor,
pred_m: T.Tensor,
padding_mask: T.Tensor,
):
    # Not wired into the CLI yet; the body below is an unreachable sketch.
    raise NotImplementedError
    #^ pred_c: [b tw tc "13"]
    #^ pred_m: [b tw tc "13"]
    pred_c_id = pred_c.argmax(3)
    pred_m_id = pred_m.argmax(3)
    selected = pred_c_id[pred_c_id != pred_m_id]
    nonpad_mask = ~padding_mask  # boolean mask
    # Fraction of non-pad characters on which the two systems disagree.
    result = selected.numel() / nonpad_mask.sum()
    return result
def cli():
    parser = ArgumentParser(description='Compare diacritics from base/ctxt systems with partial diac metrics.')
parser.add_argument('-m', '--model-output-base', help="Path to tensor.pt dump files of base diacs.")
parser.add_argument('-c', '--model-output-ctxt', help="Path to tensor.pt dump files of ctxt diacs.")
parser.add_argument('--gt', default=None, help="Path to tensor.pt for gt only.")
parser.add_argument('--mode', choices=['hard', 'logits'], default='hard')
args = parser.parse_args()
    # Both dumps are loaded with their logits: analysis_summary() ensembles
    # them regardless of --mode, so logits must always be present.
    model_output_base = parse_data(
        load_data(args.model_output_base),
        logits=True,
        side='base',
    )
    model_output_ctxt = parse_data(
        load_data(args.model_output_ctxt),
        logits=True,
        side='ctxt',
    )
#^ shape: [b, tc] -> ClassId
    logln(f"{model_output_base[0].shape=} , {model_output_ctxt[0].shape=}")
    assert len(model_output_base[0]) == len(model_output_ctxt[0])
    xc = model_output_ctxt
    xm = model_output_base
    # Take the ground truth from whichever dump carries it.
    ground_truth = xm[1] if xm[1] is not None else xc[1]
    assert ground_truth is not None, "Neither dump contains diac_gt."
if args.mode == 'hard':
selection = make_mask_hard(xc[0], xm[0])
elif args.mode == 'logits':
selection = make_mask_logits(xc[2], xm[2])
    metrics = analysis_summary(
        xc[0], xm[0], ground_truth, ground_truth == -1,  #< -1 marks padding
        selection=selection,
        logits=(xc[2], xm[2]),
    )
    logln("Actual Totals:", metrics)
    # Baseline: a random mask with the same expected selectivity.
    metrics = analysis_summary(
        xc[0], xm[0], ground_truth, ground_truth == -1, random=True,
        selection=selection,
        logits=(xc[2], xm[2]),
    )
    logln("Random Marked Chars:", metrics)