""" |
|
This script can be used to visualize the errors made by a (duplex) TN system. |
|
More specifically, after running the evaluation script `duplex_text_normalization_test.py`, |
|
a log file containing info about the errors will be generated. The location of this file |
|
is determined by the argument `inference.errors_log_fp`. After that, we can use this |
|
script to generate a HTML visualization. |
|
|
|
USAGE Example: |
|
# python analyze_errors.py \ |
|
--errors_log_fp=PATH_TO_ERRORS_LOG_FILE_PATH \ |
|
--visualization_fp=PATH_TO_VISUALIZATION_FILE_PATH |
|
""" |
|
|
|
from argparse import ArgumentParser |
|
from typing import List |
|
|
|
from nemo.collections.nlp.data.text_normalization import constants |
|
|
|
|
|
|
|
def lcs(X, Y): |
|
""" Function for finding the longest common subsequence between two lists. |
|
In this script, this function is particular used for aligning between the |
|
ground-truth output string and the predicted string (for visualization purpose). |
|
Args: |
|
X: a list |
|
Y: a list |
|
|
|
Returns: a list which is the longest common subsequence between X and Y |
|
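    Example (illustrative):
        >>> lcs(['the', 'cat', 'sat'], ['the', 'dog', 'sat'])
        ['the', 'sat']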
""" |
|
m, n = len(X), len(Y) |
|
L = [[0 for x in range(n + 1)] for x in range(m + 1)] |
|
|
|
|
|
|
|
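    # Fill the DP table: L[i][j] is the length of the LCS of X[:i] and Y[:j].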
for i in range(m + 1): |
|
for j in range(n + 1): |
|
if i == 0 or j == 0: |
|
L[i][j] = 0 |
|
elif X[i - 1] == Y[j - 1]: |
|
L[i][j] = L[i - 1][j - 1] + 1 |
|
else: |
|
L[i][j] = max(L[i - 1][j], L[i][j - 1]) |
|
|
|
|
|
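    # L[m][n] now holds the length of the longest common subsequence.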
index = L[m][n] |
|
|
|
|
|
lcs = [''] * (index + 1) |
|
lcs[index] = '' |
|
|
|
|
|
|
|
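    # Backtrack through the DP table to recover the actual subsequence.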
i = m |
|
j = n |
|
while i > 0 and j > 0: |
|
|
|
|
|
|
|
if X[i - 1] == Y[j - 1]: |
|
lcs[index - 1] = X[i - 1] |
|
i -= 1 |
|
j -= 1 |
|
index -= 1 |
|
|
|
|
|
|
|
elif L[i - 1][j] > L[i][j - 1]: |
|
i -= 1 |
|
else: |
|
j -= 1 |
|
|
|
return lcs[:-1] |
|
|
|
|
|
|
|
class ErrorCase: |
|
""" |
|
This class represents an error case |
|
|
|
Args: |
|
_input: Original input string |
|
target: Ground-truth target string |
|
pred: Predicted string |
|
mode: A string indicates the mode (i.e., constants.ITN_MODE or constants.TN_MODE) |
|
""" |
|
|
|
def __init__(self, _input: str, target: str, pred: str, classes: str, mode: str): |
|
self._input = _input |
|
self.target = target |
|
self.pred = pred |
|
self.mode = mode |
|
self.classes = classes |
|
|
|
|
|
self.target_tokens = self.target.split(' ') |
|
self.pred_tokens = self.pred.split(' ') |
|
|
|
|
|
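        # Align the target and predicted tokens via their longest common subsequence;
        # tokens that belong to the LCS are flagged so they can be styled differently
        # from the mismatched tokens in the HTML output.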
lcs_tokens = lcs(self.target_tokens, self.pred_tokens) |
|
target_tokens_hightlight = [False] * len(self.target_tokens) |
|
pred_tokens_hightlight = [False] * len(self.pred_tokens) |
|
target_idx, pred_idx = 0, 0 |
|
for token in lcs_tokens: |
|
while self.target_tokens[target_idx] != token: |
|
target_idx += 1 |
|
while self.pred_tokens[pred_idx] != token: |
|
pred_idx += 1 |
|
target_tokens_hightlight[target_idx] = True |
|
pred_tokens_hightlight[pred_idx] = True |
|
target_idx += 1 |
|
pred_idx += 1 |
|
|
|
|
|
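        # Group consecutive tokens with the same flag into contiguous spans.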
self.target_spans = self.get_spans(target_tokens_hightlight) |
|
self.pred_spans = self.get_spans(pred_tokens_hightlight) |
|
|
|
|
|
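        # Collect the spans that are not part of the LCS (i.e., the mismatched parts).
        # These are currently only computed locally and are not used further.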
unhighlighted_target_spans = [] |
|
for ix, t in enumerate(self.target_spans): |
|
if not t[-1]: |
|
unhighlighted_target_spans.append((ix, t)) |
|
|
|
unhighlighted_pred_spans = [] |
|
for ix, t in enumerate(self.pred_spans): |
|
if not t[-1]: |
|
unhighlighted_pred_spans.append((ix, t)) |
|
|
|
@classmethod |
|
def from_lines(cls, lines: List[str], mode: str): |
|
""" |
|
This method returns an instance of ErrorCase from raw string lines. |
|
|
|
Args: |
|
lines: A list of raw string lines for the error case. |
|
            mode: A string indicating the mode (i.e., constants.ITN_MODE or constants.TN_MODE)
|
|
|
Returns: an instance of ErrorCase. |
|
""" |
|
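        # Each error case is expected to contain lines starting with 'Original Input',
        # 'Predicted Str', 'Ground-Truth', and 'Ground Classes', as written to the error log.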
for line in lines: |
|
if line.startswith('Original Input'): |
|
_input = line[line.find(':') + 1 :].strip() |
|
elif line.startswith('Predicted Str'): |
|
pred = line[line.find(':') + 1 :].strip() |
|
elif line.startswith('Ground-Truth'): |
|
target = line[line.find(':') + 1 :].strip() |
|
elif line.startswith('Ground Classes'): |
|
classes = line[line.find(':') + 1 :].strip() |
|
return cls(_input, target, pred, classes, mode) |
|
|
|
def get_html(self): |
|
""" |
|
This method returns a HTML string representing this error case instance. |
|
Returns: a string contains the HTML representing this error case instance. |
|
""" |
|
html_str = '' |
|
|
|
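        # Pad the shorter labels so that the Input/Target/Prediction/Classes lines line up
        # in the rendered HTML; the multiplier differs between TN and ITN because the label
        # texts differ in length.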
input_form = 'Written' if self.mode == constants.TN_MODE else 'Spoken' |
|
padding_multiplier = 1 if self.mode == constants.TN_MODE else 2 |
|
        padding_spaces = '&nbsp;' * padding_multiplier  # non-breaking spaces so the padding survives HTML rendering
|
input_str = f'<b>[Input ({input_form})]{padding_spaces}</b>: {self._input}</br>\n' |
|
html_str += input_str + ' ' |
|
|
|
target_html = self.get_spans_html(self.target_spans, self.target_tokens) |
|
target_form = 'Spoken' if self.mode == constants.TN_MODE else 'Written' |
|
target_str = f'<b>[Target ({target_form})]</b>: {target_html}</br>\n' |
|
html_str += target_str + ' ' |
|
|
|
pred_html = self.get_spans_html(self.pred_spans, self.pred_tokens) |
|
padding_multiplier = 10 if self.mode == constants.TN_MODE else 11 |
|
        padding_spaces = '&nbsp;' * padding_multiplier
|
pred_str = f'<b>[Prediction]{padding_spaces}</b>: {pred_html}</br>\n' |
|
html_str += pred_str + ' ' |
|
|
|
padding_multiplier = 15 if self.mode == constants.TN_MODE else 16 |
|
        padding_spaces = '&nbsp;' * padding_multiplier
|
class_str = f'<b>[Classes]{padding_spaces}</b>: {self.classes}</br>\n' |
|
html_str += class_str + ' ' |
|
|
|
html_str += '</br>\n' |
|
return html_str |
|
|
|
def get_spans(self, tokens_hightlight): |
|
""" |
|
This method extracts the list of spans. |
|
|
|
Args: |
|
tokens_hightlight: A list of boolean values where each value indicates whether a token needs to be hightlighted. |
|
|
|
Returns: |
|
spans: A list of spans. Each span is represented by a tuple of 3 elements: (1) Start Index (2) End Index (3) A boolean value indicating whether the span needs to be hightlighted. |
|
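        Example (illustrative):
            [True, True, False, True] -> [(0, 1, True), (2, 2, False), (3, 3, True)]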
""" |
|
spans, nb_tokens = [], len(tokens_hightlight) |
|
cur_start_idx, cur_bool_val = 0, tokens_hightlight[0] |
|
for idx in range(nb_tokens): |
|
if idx == nb_tokens - 1: |
|
if tokens_hightlight[idx] != cur_bool_val: |
|
spans.append((cur_start_idx, nb_tokens - 2, cur_bool_val)) |
|
spans.append((nb_tokens - 1, nb_tokens - 1, tokens_hightlight[idx])) |
|
else: |
|
spans.append((cur_start_idx, nb_tokens - 1, cur_bool_val)) |
|
else: |
|
if tokens_hightlight[idx] != cur_bool_val: |
|
spans.append((cur_start_idx, idx - 1, cur_bool_val)) |
|
cur_start_idx, cur_bool_val = idx, tokens_hightlight[idx] |
|
return spans |
|
|
|
def get_spans_html(self, spans, tokens): |
|
""" |
|
This method generates a HTML string for a string sequence from its spans. |
|
|
|
Args: |
|
spans: A list of contiguous spans in a sequence. Each span is represented by a tuple of 3 elements: (1) Start Index (2) End Index (3) A boolean value indicating whether the span needs to be hightlighted. |
|
tokens: All tokens in the sequence |
|
Returns: |
|
html_str: A HTML string for the string sequence. |
|
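        Example (illustrative):
            Spans [(0, 0, False), (1, 1, True)] with tokens ['hello', 'world'] yield
            'hello' in black followed by 'world' in red.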
""" |
|
html_str = '' |
|
        for start, end, highlighted in spans:
            color = 'red' if highlighted else 'black'
|
span_tokens = tokens[start : end + 1] |
|
span_str = '<span style="color:{}">{}</span> '.format(color, ' '.join(span_tokens)) |
|
html_str += span_str |
|
return html_str |
|
|
|
|
|
|
|
def analyze(errors_log_fp: str, visualization_fp: str): |
|
""" |
|
    This method generates an HTML visualization of the error cases logged in a log file.
|
|
|
Args: |
|
errors_log_fp: Path to the error log file |
|
visualization_fp: Path to the output visualization file |
|
|
|
""" |
|
|
|
with open(errors_log_fp, 'r', encoding='utf-8') as f: |
|
lines = f.readlines() |
|
|
|
|
|
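    # The log is assumed to consist of 8-line blocks: one mode line ('Forward Problem'
    # or 'Backward Problem'), six lines with the error details, and one separator line.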
tn_error_cases, itn_error_cases = [], [] |
|
for ix in range(0, len(lines), 8): |
|
mode_line = lines[ix] |
|
info_lines = lines[ix + 1 : ix + 7] |
|
|
|
if mode_line.startswith('Forward Problem'): |
|
mode = constants.TN_MODE |
|
tn_error_cases.append(ErrorCase.from_lines(info_lines, mode)) |
|
elif mode_line.startswith('Backward Problem'): |
|
mode = constants.ITN_MODE |
|
itn_error_cases.append(ErrorCase.from_lines(info_lines, mode)) |
|
|
|
|
|
print('---- Text Normalization ----') |
|
print('Number of TN errors: {}'.format(len(tn_error_cases))) |
|
|
|
print('---- Inverse Text Normalization ---- ') |
|
print('Number of ITN errors: {}'.format(len(itn_error_cases))) |
|
|
|
|
|
with open(visualization_fp, 'w+', encoding='utf-8') as f: |
|
|
|
f.write('Appendix</br>') |
|
f.write('<a href="#tn_section">Text Normalization Analysis.</a></br>') |
|
f.write('<a href="#itn_section">Inverse Text Normalization Analysis.</a>') |
|
|
|
|
|
f.write('<h2 id="tn_section">Text Normalization</h2>\n') |
|
for errorcase in tn_error_cases: |
|
f.write(errorcase.get_html()) |
|
|
|
|
|
f.write('<h2 id="itn_section">Inverse Text Normalization</h2>\n') |
|
for errorcase in itn_error_cases: |
|
f.write(errorcase.get_html()) |
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
parser = ArgumentParser() |
|
parser.add_argument('--errors_log_fp', help='Path to the error log file', required=True) |
|
parser.add_argument('--visualization_fp', help='Path to the output visualization file', required=True) |
|
args = parser.parse_args() |
|
|
|
analyze(args.errors_log_fp, args.visualization_fp) |
|
|