Spaces:
Runtime error
Runtime error
# Copyright 2020 IBM
# Author: peter.zhong@au1.ibm.com
#
# This is free software; you can redistribute it and/or modify
# it under the terms of the Apache 2.0 License.
#
# This software is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# Apache 2.0 License for more details.
from rapidfuzz.distance import Levenshtein | |
from apted import APTED, Config | |
from apted.helpers import Tree | |
from lxml import etree, html | |
from collections import deque | |
from .parallel import parallel_process | |
from tqdm import tqdm | |
class TableTree(Tree):
    """Tree node describing one HTML table element.

    Each node stores the element ``tag``, its ``colspan`` / ``rowspan``
    values, the cell ``content``, and a list of child nodes built from any
    extra positional arguments.
    """

    def __init__(self, tag, colspan=None, rowspan=None, content=None, *children):
        self.tag = tag
        self.colspan, self.rowspan = colspan, rowspan
        self.content = content
        self.children = [*children]

    def bracket(self):
        """Return this subtree serialized in bracket notation."""
        if self.tag != 'td':
            # Non-cell nodes are described by their tag alone.
            head = '"tag": %s' % self.tag
        else:
            head = '"tag": %s, "colspan": %d, "rowspan": %d, "text": %s' % (
                self.tag, self.colspan, self.rowspan, self.content)
        nested = ''.join(child.bracket() for child in self.children)
        return '{' + head + nested + '}'
class CustomConfig(Config):
    """Edit-distance configuration that prices renames by node attributes."""

    def rename(self, node1, node2):
        """Return the cost of renaming ``node1`` into ``node2``.

        Nodes whose tag, colspan or rowspan differ cost 1. Matching ``td``
        nodes with content on at least one side cost the normalized
        Levenshtein distance between their contents. Every other matching
        pair costs 0.
        """
        attributes_match = (
            node1.tag == node2.tag
            and node1.colspan == node2.colspan
            and node1.rowspan == node2.rowspan
        )
        if not attributes_match:
            return 1.
        if node1.tag == 'td' and (node1.content or node2.content):
            return Levenshtein.normalized_distance(node1.content, node2.content)
        return 0.
class CustomConfig_del_short(Config):
    """Edit-distance configuration that masks very short cell contents."""

    def rename(self, node1, node2):
        """Return the cost of renaming ``node1`` into ``node2``.

        Nodes whose tag, colspan or rowspan differ cost 1. For matching
        ``td`` nodes with content on at least one side, any content shorter
        than three items is replaced by the placeholder ``['####']`` before
        the normalized Levenshtein distance is computed. Every other
        matching pair costs 0.
        """
        attributes_match = (
            node1.tag == node2.tag
            and node1.colspan == node2.colspan
            and node1.rowspan == node2.rowspan
        )
        if not attributes_match:
            return 1.
        if node1.tag != 'td' or not (node1.content or node2.content):
            return 0.
        # Short contents are swapped for a fixed placeholder so they all
        # compare as the same value.
        left = ['####'] if len(node1.content) < 3 else node1.content
        right = ['####'] if len(node2.content) < 3 else node2.content
        return Levenshtein.normalized_distance(left, right)
class CustomConfig_del_block(Config):
    """Edit-distance configuration that ignores single-space tokens in cells."""

    def rename(self, node1, node2):
        """Return the cost of renaming ``node1`` into ``node2``.

        Nodes whose tag, colspan or rowspan differ cost 1. For matching
        ``td`` nodes with content on at least one side, every ``' '`` item
        is dropped from both contents before the normalized Levenshtein
        distance is computed. Every other matching pair costs 0.

        The nodes themselves are never modified: the space-free sequences
        are built as new lists, so computing a cost has no side effect on
        the trees and writes nothing to stdout.
        """
        if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan):
            return 1.
        if node1.tag == 'td':
            if node1.content or node2.content:
                # Filter into fresh lists (single pass) instead of popping
                # from the nodes' own content lists.
                node1_content = [tok for tok in node1.content if tok != ' ']
                node2_content = [tok for tok in node2.content if tok != ' ']
                return Levenshtein.normalized_distance(node1_content, node2_content)
        return 0.
class TEDS(object):
    """Tree Edit Distance based Similarity between two HTML tables.

    Parameters:
        structure_only: when true, cell contents are ignored and every
            ``td`` node is given an empty content list.
        n_jobs: number of jobs handed to ``parallel_process`` by the batch
            methods; with ``1`` the samples are scored sequentially.
        ignore_nodes: optional iterable of tag names passed to
            ``etree.strip_tags`` on both tables before scoring.
    """

    def __init__(self, structure_only=False, n_jobs=1, ignore_nodes=None):
        assert isinstance(n_jobs, int) and (
            n_jobs >= 1), 'n_jobs must be an integer greater than or equal to 1'
        self.structure_only = structure_only
        self.n_jobs = n_jobs
        self.ignore_nodes = ignore_nodes
        # Scratch buffer filled by tokenize(); reset for every cell.
        self.__tokens__ = []

    def tokenize(self, node):
        """Append the tokens of ``node`` and its subtree to ``self.__tokens__``.

        Tags become ``<tag>`` / ``</tag>`` tokens and text is split into
        single characters. No closing token is emitted for ``unk`` nodes,
        and the tail text of a ``td`` node is skipped.
        """
        self.__tokens__.append('<%s>' % node.tag)
        if node.text is not None:
            self.__tokens__ += list(node.text)
        # Iterate the element directly; getchildren() is deprecated.
        for child in node:
            self.tokenize(child)
        if node.tag != 'unk':
            self.__tokens__.append('</%s>' % node.tag)
        if node.tag != 'td' and node.tail is not None:
            self.__tokens__ += list(node.tail)

    def load_html_tree(self, node, parent=None):
        """Convert an HTML element tree to the ``TableTree`` format used by apted.

        ``td`` elements become leaves carrying colspan, rowspan (both
        defaulting to 1) and their tokenized content; all other elements
        are converted recursively. Returns the new root when ``parent`` is
        ``None``; otherwise the new node is appended to ``parent.children``
        and ``None`` is returned.
        """
        if node.tag == 'td':
            if self.structure_only:
                cell = []
            else:
                self.__tokens__ = []
                self.tokenize(node)
                # Drop the enclosing '<td>' and '</td>' tokens; the slice is
                # already a new list.
                cell = self.__tokens__[1:-1]
            new_node = TableTree(node.tag,
                                 int(node.attrib.get('colspan', '1')),
                                 int(node.attrib.get('rowspan', '1')),
                                 cell)
        else:
            new_node = TableTree(node.tag, None, None, None)
        if parent is not None:
            parent.children.append(new_node)
        if node.tag != 'td':
            for child in node:
                self.load_html_tree(child, new_node)
        if parent is None:
            return new_node
        return None

    def evaluate(self, pred, true):
        """Compute the TEDS score between a prediction and its ground truth.

        Both arguments are HTML strings. The score is
        ``1 - edit_distance / max(node counts)``, where the node counts are
        the numbers of descendant elements of each ``body/table``. Returns
        0.0 when either input is empty or either document has no
        ``body/table`` element, and 1.0 when both tables have no descendant
        elements at all.
        """
        if (not pred) or (not true):
            return 0.0
        parser = html.HTMLParser(remove_comments=True, encoding='utf-8')
        pred = html.fromstring(pred, parser=parser)
        true = html.fromstring(true, parser=parser)
        pred_tables = pred.xpath('body/table')
        true_tables = true.xpath('body/table')
        if not (pred_tables and true_tables):
            return 0.0
        pred = pred_tables[0]
        true = true_tables[0]
        if self.ignore_nodes:
            etree.strip_tags(pred, *self.ignore_nodes)
            etree.strip_tags(true, *self.ignore_nodes)
        n_nodes_pred = len(pred.xpath(".//*"))
        n_nodes_true = len(true.xpath(".//*"))
        n_nodes = max(n_nodes_pred, n_nodes_true)
        if n_nodes == 0:
            # Two tables without any descendant are identical; returning
            # here avoids dividing by zero below.
            return 1.0
        tree_pred = self.load_html_tree(pred)
        tree_true = self.load_html_tree(true)
        distance = APTED(tree_pred, tree_true,
                         CustomConfig()).compute_edit_distance()
        return 1.0 - (float(distance) / n_nodes)

    def batch_evaluate(self, pred_json, true_json):
        """Compute TEDS scores for a batch of samples keyed by file name.

        @params pred_json: {'FILENAME': 'HTML CODE', ...}
        @params true_json: {'FILENAME': {'html': 'HTML CODE'}, ...}
        @output: {'FILENAME': 'TEDS SCORE', ...}

        File names missing from ``pred_json`` are scored against an empty
        prediction.
        """
        samples = true_json.keys()
        if self.n_jobs == 1:
            scores = [self.evaluate(pred_json.get(filename, ''),
                                    true_json[filename]['html'])
                      for filename in tqdm(samples)]
        else:
            inputs = [{'pred': pred_json.get(filename, ''),
                       'true': true_json[filename]['html']}
                      for filename in samples]
            scores = parallel_process(
                inputs, self.evaluate, use_kwargs=True, n_jobs=self.n_jobs, front_num=1)
        scores = dict(zip(samples, scores))
        return scores

    def batch_evaluate_html(self, pred_htmls, true_htmls):
        """Compute TEDS scores for paired sequences of HTML strings.

        Predictions and ground truths are paired with ``zip``; the scores
        are returned in the same order.
        """
        if self.n_jobs == 1:
            scores = [self.evaluate(pred_html, true_html)
                      for (pred_html, true_html) in zip(pred_htmls, true_htmls)]
        else:
            inputs = [{"pred": pred_html, "true": true_html}
                      for (pred_html, true_html) in zip(pred_htmls, true_htmls)]
            scores = parallel_process(
                inputs, self.evaluate, use_kwargs=True, n_jobs=self.n_jobs, front_num=1)
        return scores
if __name__ == '__main__':
    # Demo: score the sample predictions against the sample ground truth
    # and pretty-print the per-file results.
    import json
    import pprint

    with open('sample_pred.json') as pred_file:
        predictions = json.load(pred_file)
    with open('sample_gt.json') as gt_file:
        ground_truth = json.load(gt_file)
    pprint.pprint(TEDS(n_jobs=4).batch_evaluate(predictions, ground_truth))