# HTML-Pruner-Phi-3.8B / tree_gen_utils.py
# zstanjj's picture
# Upload 13 files
# 3f61ae8 verified
from collections import defaultdict
from collections import deque
from typing import List, Tuple

import bs4
import numpy as np
from anytree import Node, RenderTree
from anytree import PreOrderIter
from anytree.exporter import DotExporter
def nodenamefunc(node):
    """Return the label ``"<name>|<prob>|<input_ids>"`` for *node*.

    *node* may be any object exposing ``name``, ``prob`` and ``input_ids``
    attributes; each one is rendered with its default string formatting.
    """
    return f"{node.name}|{node.prob}|{node.input_ids}"
class TokenDotExporter(DotExporter):
    """Exporter whose iteration yields one text line per tree node.

    Iterating an instance produces only node lines: the escaped node name,
    followed by `` {attrs}`` when the node-attribute function returns
    something other than ``None``. No edge lines are produced by this class.
    """

    def __init__(self, node, **kwargs):
        """Forward *node* and every keyword option to ``DotExporter``."""
        super().__init__(node, **kwargs)

    def __iter__(self):
        """Return an iterator over the rendered node lines."""
        # Fall back to the base-class defaults for anything not configured.
        # NOTE(review): relies on the ``_default_*`` helpers being provided by
        # the installed anytree ``DotExporter`` -- confirm against the pinned
        # anytree version.
        nodenamefunc = self.nodenamefunc or self._default_nodenamefunc
        nodeattrfunc = self.nodeattrfunc or self._default_nodeattrfunc
        filter_ = self.filter_ or self._default_filter
        return self.__iter_nodes(nodenamefunc, nodeattrfunc, filter_)

    def __iter_nodes(self, nodenamefunc, nodeattrfunc, filter_):
        """Yield one line per node visited in pre-order from ``self.node``."""
        for node in PreOrderIter(
            self.node, filter_=filter_, stop=self.stop, maxlevel=self.maxlevel
        ):
            nodename = nodenamefunc(node)
            nodeattr = nodeattrfunc(node)
            # Attributes are optional; omit the braces entirely when absent.
            nodeattr = " {%s}" % nodeattr if nodeattr is not None else ""
            yield '%s%s' % (DotExporter.esc(nodename), nodeattr)
class TokenIdNode(Node):
    """Tree node that always exposes ``input_ids`` and ``prob`` attributes."""

    def __init__(self, name, parent=None, children=None, **kwargs):
        """Create the node and set its token attributes.

        Args:
            name: Node name, passed through to ``Node``.
            parent: Optional parent node, passed through to ``Node``.
            children: Optional child nodes, passed through to ``Node``.
            **kwargs: Extra keyword arguments, all passed through to ``Node``.
                ``input_ids`` and ``prob`` are additionally read here.
        """
        super().__init__(name, parent, children, **kwargs)
        # Guarantee both attributes exist even when the caller omitted them;
        # a fresh list is created per instance, so the default is not shared.
        self.input_ids = kwargs.get('input_ids', [])
        self.prob = kwargs.get('prob', np.float32(0.0))
def split_tree(
    soup: bs4.BeautifulSoup,
    max_node_words: int = 0,
    max_depth: int = 64,
) -> List[Tuple[bs4.element.Tag, List[str], bool]]:
    """Split a parsed document into blocks of at most ``max_node_words`` words.

    The DOM is walked breadth-first. A tag whose text has more than
    ``max_node_words`` whitespace-separated words is split further into its
    child tags; otherwise it is emitted as one block.

    Side effects on *soup*:
        * Sibling tags that share a name are renamed in place by appending a
          running index (``div`` -> ``div0``, ``div1``, ...), so every emitted
          path identifies a single tag.
        * When the whole document fits in one block and it does not have
          exactly one top-level tag, the top-level tags are appended to a new
          ``<html>`` wrapper tag (Beautiful Soup re-parents appended tags).

    Args:
        soup: Parsed document to split.
        max_node_words: Maximum number of words allowed in one block.
        max_depth: Maximum path length that is still split further. A tag
            whose path already has this many components is emitted whole even
            if it exceeds ``max_node_words``.

    Returns:
        A list of ``(tag, path, is_leaf)`` tuples. ``path`` is the list of
        (possibly renamed) tag names from the root to ``tag``. ``is_leaf`` is
        ``True`` for a tag emitted as a whole block, and ``False`` for a tag
        whose child tags were handled separately but whose own direct text
        exceeds ``max_node_words`` words.
    """
    if len(soup.get_text().split()) <= max_node_words:
        # The entire document already fits in a single block.
        top_level_tags = [
            c for c in soup.contents if isinstance(c, bs4.element.Tag)
        ]
        if len(top_level_tags) == 1:
            only_tag = top_level_tags[0]
            return [(only_tag, [only_tag.name], True)]
        # Zero or several top-level tags: add an html tag to wrap them all.
        wrapper_soup = bs4.BeautifulSoup("", 'html.parser')
        wrapper = wrapper_soup.new_tag("html")
        wrapper_soup.append(wrapper)
        for tag in top_level_tags:
            wrapper.append(tag)
        return [(wrapper, ["html"], True)]

    target_trees = []  # [(tag, path, is_leaf)]
    # Split the entire DOM tree into subtrees until each subtree has at most
    # max_node_words words. A deque keeps the FIFO order with O(1) pops.
    pending = deque([(soup, [])])
    while pending:
        node, path = pending.popleft()

        # Count child tags per name to detect duplicate sibling names.
        name_totals = defaultdict(int)
        for child in node.contents:
            if isinstance(child, bs4.element.Tag):
                name_totals[child.name] += 1
        next_suffix = dict.fromkeys(name_totals, 0)

        bare_word_count = 0
        for child in node.contents:
            if not isinstance(child, bs4.element.Tag):
                # Text that sits directly inside this node, outside any tag.
                bare_word_count += len(str(child).split())
                continue

            # Rename tags that share their name with a sibling.
            if name_totals[child.name] > 1:
                base_name = child.name
                child.name = f"{base_name}{next_suffix[base_name]}"
                next_suffix[base_name] += 1
            child_path = path + [child.name]

            child_words = len(child.get_text().split())
            # Keep splitting only while the child is too large and the
            # depth limit has not been reached.
            if child_words > max_node_words and len(child_path) < max_depth:
                pending.append((child, child_path))
            else:
                target_trees.append((child, child_path, True))

        if not name_totals:
            # No child tags at all: the node itself is a leaf block.
            target_trees.append((node, path, True))
        elif bare_word_count > max_node_words:
            # The node's own direct text is too large to be dropped.
            target_trees.append((node, path, False))
    return target_trees