markuplm-phish / processor.py
pogzyb's picture
Add custom processor
787e5bf verified
raw
history blame
No virus
3.24 kB
from typing import Optional, Union
import transformers
from bs4 import BeautifulSoup
class MarkupLMPhishProcessor(transformers.MarkupLMProcessor):
    """MarkupLM processor that slims HTML down before it is tokenized.

    Every HTML string passed to ``__call__`` is first run through
    ``_preprocess``, which empties ``<style>``/``<script>`` elements and
    unwraps every tag that is not listed in ``keep_tags_ctx``.  The
    result is then handed to ``transformers.MarkupLMProcessor.__call__``.
    """

    def __init__(self, *args, **kwargs):
        """Initialise the parent processor and the list of tags to keep."""
        super().__init__(*args, **kwargs)
        # Tags that survive preprocessing; any other tag (except
        # style/script, handled separately) is unwrapped.  Kept as a plain
        # list so the attribute stays unchanged for existing users.
        self.keep_tags_ctx = [
            "html",
            "head",
            "body",
            "h1",
            "h2",
            "h3",
            "h4",
            "h5",
            "h6",
            "p",
            "a",
            "button",
            "span",
            "div",
            "iframe",
            "table",
        ]

    def _preprocess(self, html_string: str) -> str:
        """Return ``html_string`` with extraneous markup removed.

        ``<style>`` and ``<script>`` elements are kept but emptied; tags
        not in ``self.keep_tags_ctx`` are unwrapped (tag removed, children
        kept); all other tags are left untouched.
        """
        # Most webpages are huge. BERT's "attention" is limited to 512 tokens.
        # In order to give the model more context to work with, we strip
        # extraneous tags/content from the page to help with the binary
        # classification task.
        soup = BeautifulSoup(html_string, "html.parser")
        # Build the lookup set once per document instead of scanning the
        # list for every tag.
        keep_tags = frozenset(self.keep_tags_ctx)
        # find_all() returns an already-built result list, so mutating the
        # tree inside the loop does not disturb the iteration.
        for tag in soup.find_all(True):
            if tag.name in ("style", "script"):
                # keep the meaning of the tag, but remove its contents to
                # save space
                tag.string = ""
            elif tag.name not in keep_tags:
                # remove tag, but keep its contents
                tag.unwrap()
        return str(soup)

    def __call__(
        self,
        html_strings=None,
        nodes=None,
        xpaths=None,
        node_labels=None,
        questions=None,
        add_special_tokens: bool = True,
        padding: Union[bool, str, transformers.utils.generic.PaddingStrategy] = False,
        truncation: Optional[
            Union[bool, str, transformers.tokenization_utils_base.TruncationStrategy]
        ] = None,
        max_length: Optional[int] = None,
        stride: int = 0,
        pad_to_multiple_of: Optional[int] = None,
        return_token_type_ids: Optional[bool] = None,
        return_attention_mask: Optional[bool] = None,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        return_tensors: Optional[
            Union[str, transformers.utils.generic.TensorType]
        ] = None,
        **kwargs,
    ) -> transformers.tokenization_utils_base.BatchEncoding:
        """Preprocess ``html_strings`` and delegate to the parent processor.

        Args:
            html_strings: A single HTML string, or a list/tuple of HTML
                strings.  Each string is passed through ``_preprocess``
                before the parent is invoked.  Any other value (including
                ``None``) is forwarded unchanged.
            nodes, xpaths, node_labels, questions: Forwarded unchanged.
            add_special_tokens ... return_tensors, **kwargs: Tokenization
                options, forwarded unchanged.

        Returns:
            Whatever ``transformers.MarkupLMProcessor.__call__`` returns
            for the preprocessed input.
        """
        # custom html_strings preprocessing
        if isinstance(html_strings, str):
            html_strings = self._preprocess(html_strings)
        elif isinstance(html_strings, (list, tuple)):
            # Tuples are handled like lists so a batch given as a tuple is
            # not silently passed through without preprocessing.
            html_strings = [self._preprocess(hs) for hs in html_strings]

        # invoke the parent method; forward by keyword so the call does not
        # depend on the parent's positional parameter order.
        return super().__call__(
            html_strings=html_strings,
            nodes=nodes,
            xpaths=xpaths,
            node_labels=node_labels,
            questions=questions,
            add_special_tokens=add_special_tokens,
            padding=padding,
            truncation=truncation,
            max_length=max_length,
            stride=stride,
            pad_to_multiple_of=pad_to_multiple_of,
            return_token_type_ids=return_token_type_ids,
            return_attention_mask=return_attention_mask,
            return_overflowing_tokens=return_overflowing_tokens,
            return_special_tokens_mask=return_special_tokens_mask,
            return_offsets_mapping=return_offsets_mapping,
            return_length=return_length,
            verbose=verbose,
            return_tensors=return_tensors,
            **kwargs,
        )