# VulBERTa-mlm / tokenization_vulberta.py
from typing import List
from tokenizers import NormalizedString, PreTokenizedString
from tokenizers.pre_tokenizers import PreTokenizer
from transformers import PreTrainedTokenizerFast
try:
    from clang import cindex
except ModuleNotFoundError as e:
    raise ModuleNotFoundError(
        "VulBERTa Clang tokenizer requires `libclang`. Please install it via `pip install libclang`.",
    ) from e
class ClangPreTokenizer:
    """Pre-tokenizer that splits C/C++ source code into Clang lexer tokens."""

    cidx = cindex.Index.create()

    def clang_split(
        self,
        i: int,
        normalized_string: NormalizedString,
    ) -> List[NormalizedString]:
        tok = []
        # Parse the input as an in-memory C file so libclang can lex it; only the
        # raw token stream is needed, so no include paths or options are passed.
        tu = self.cidx.parse(
            "tmp.c",
            args=[""],
            unsaved_files=[("tmp.c", str(normalized_string.original))],
            options=0,
        )
        # Collect the spelling of every token in the translation unit, skipping
        # any empty spellings.
        for t in tu.get_tokens(extent=tu.cursor.extent):
            spelling = t.spelling.strip()
            if spelling == "":
                continue
            tok.append(NormalizedString(spelling))
        return tok

    def pre_tokenize(self, pretok: PreTokenizedString):
        # Hook called by `tokenizers`; delegates the actual splitting to Clang.
        pretok.split(self.clang_split)
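# Illustrative sketch (comments only, nothing runs at import time): wrapping the
# splitter in `PreTokenizer.custom` and calling `pre_tokenize_str` on a C snippet
# yields one piece per Clang lexer token, e.g.
#
#     PreTokenizer.custom(ClangPreTokenizer()).pre_tokenize_str("int x = 1;")
#
# produces the spellings ["int", "x", "=", "1", ";"]. Offsets are not preserved,
# because `clang_split` builds fresh NormalizedString objects rather than slicing
# the original string.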
class VulBERTaTokenizer(PreTrainedTokenizerFast):
    """Fast tokenizer that pre-tokenizes source code with Clang before subword tokenization."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Custom Python pre-tokenizers cannot be serialized into tokenizer.json,
        # so the Clang splitter is (re-)attached to the backend tokenizer here.
        self._tokenizer.pre_tokenizer = PreTokenizer.custom(ClangPreTokenizer())
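

# Usage sketch. The repository id below and the need for `trust_remote_code=True`
# are assumptions based on where this module is hosted; adjust them to your own
# checkpoint.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    # AutoTokenizer resolves to VulBERTaTokenizer via the repository's tokenizer
    # config; __init__ above then re-attaches the Clang pre-tokenizer.
    tokenizer = AutoTokenizer.from_pretrained(
        "claudios/VulBERTa-mlm", trust_remote_code=True
    )
    print(tokenizer.tokenize("int main() { return 0; }"))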