from typing import Optional, Union

import transformers
from bs4 import BeautifulSoup


class MarkupLMPhishProcessor(transformers.MarkupLMProcessor):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Tags worth keeping as context for the classifier; everything else is
        # unwrapped during preprocessing.
        self.keep_tags_ctx = [
            "html",
            "head",
            "body",
            "h1",
            "h2",
            "h3",
            "h4",
            "h5",
            "h6",
            "p",
            "a",
            "button",
            "span",
            "div",
            "iframe",
            "table",
        ]

    def _preprocess(self, html_string: str) -> str:
        # Most webpages are huge, but the model's input is capped at 512 tokens.
        # To give the model more useful context, strip extraneous tags/content
        # from the page before tokenization; this helps the binary
        # classification task.
        soup = BeautifulSoup(html_string, "html.parser")
        for tag in soup.find_all(True):
            if tag.name in ("style", "script"):
                # keep the tag itself (its presence is meaningful), but drop
                # its contents to save space
                tag.string = ""
            elif tag.name not in self.keep_tags_ctx:
                # remove the tag, but keep its contents
                tag.unwrap()
        return str(soup)

    def __call__(
        self,
        html_strings=None,
        nodes=None,
        xpaths=None,
        node_labels=None,
        questions=None,
        add_special_tokens: bool = True,
        padding: Union[bool, str, transformers.utils.generic.PaddingStrategy] = False,
        truncation: Optional[
            Union[bool, str, transformers.tokenization_utils_base.TruncationStrategy]
        ] = None,
        max_length: Optional[int] = None,
        stride: int = 0,
        pad_to_multiple_of: Optional[int] = None,
        return_token_type_ids: Optional[bool] = None,
        return_attention_mask: Optional[bool] = None,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        return_tensors: Optional[
            Union[str, transformers.utils.generic.TensorType]
        ] = None,
        **kwargs,
    ) -> transformers.tokenization_utils_base.BatchEncoding:
        # custom html_strings preprocessing
        if html_strings is not None:
            if isinstance(html_strings, list):
                html_strings = [self._preprocess(hs) for hs in html_strings]
            elif isinstance(html_strings, str):
                html_strings = self._preprocess(html_strings)

        # invoke the parent method with the (possibly preprocessed) inputs
        return super().__call__(
            html_strings,
            nodes,
            xpaths,
            node_labels,
            questions,
            add_special_tokens,
            padding,
            truncation,
            max_length,
            stride,
            pad_to_multiple_of,
            return_token_type_ids,
            return_attention_mask,
            return_overflowing_tokens,
            return_special_tokens_mask,
            return_offsets_mapping,
            return_length,
            verbose,
            return_tensors,
            **kwargs,
        )
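

# --- Minimal usage sketch (not part of the class above) ---
# Hedged example: it assumes the public "microsoft/markuplm-base" checkpoint and
# uses a toy HTML string as a stand-in input, not real phishing data.
# from_pretrained() is inherited from transformers' ProcessorMixin, so it returns
# a MarkupLMPhishProcessor instance.
if __name__ == "__main__":
    processor = MarkupLMPhishProcessor.from_pretrained("microsoft/markuplm-base")

    toy_html = (
        "<html><head><style>body { color: red; }</style></head>"
        "<body><div class='login'><p>Verify your account</p>"
        "<a href='https://example.com/login'>Sign in</a></div></body></html>"
    )

    # _preprocess drops the <style> contents and unwraps unlisted tags before the
    # parent processor tokenizes the page and extracts XPaths.
    encoding = processor(
        toy_html, truncation=True, max_length=512, return_tensors="pt"
    )

    # Keys include input_ids, attention_mask, and the xpath_* sequences used by MarkupLM.
    print(sorted(encoding.keys()))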