import re
from collections import Counter


class WhitespaceTokenSplitter:
    """Regex tokenizer: word runs (keeping internal ``-``/``_``) or single non-space chars."""

    def __init__(self):
        # \w+(?:[-_]\w+)*  -> "state-of-the-art" stays one token
        # \S               -> any leftover punctuation becomes its own token
        self.whitespace_pattern = re.compile(r"\w+(?:[-_]\w+)*|\S")

    def __call__(self, text):
        """Yield (token, start, end) triples with character offsets into *text*."""
        for match in self.whitespace_pattern.finditer(text):
            yield match.group(), match.start(), match.end()


tokenizer = WhitespaceTokenSplitter()


def get_char_label_map(ner_spans: list) -> dict[int, str]:
    """return a dict with char indices(int) as keys and the label they belong to as values
    example -- {1:'label1', 2: 'label1', 5:'label2', 6:'label2'}
    note: the char indices that do not belong to a span do not exist in the map;
    if spans overlap, the span appearing later in ner_spans wins for the shared chars
    """
    char_label_map: dict[int, str] = {}
    for span in ner_spans:
        # update in place instead of rebuilding the whole dict per span
        # (the original {**acc, **new} pattern is accidentally quadratic)
        for char_index in range(span["start"], span["end"]):
            char_label_map[char_index] = span["label"]
    return char_label_map


def get_tokens(text: str) -> list[str]:
    """Return just the token strings of *text*, in order."""
    return [token for token, start, end in tokenizer(text)]


def get_token_offsets(text: str) -> list[tuple[int, int]]:
    """Return the (start, end) character offsets of each token of *text*, in order."""
    return [(start, end) for token, start, end in tokenizer(text)]


def get_list_of_token_label_tuples(
    tokens: list[str],
    token_spans: list[tuple[int, int]],
    char_label_map: dict[int, str],
) -> list[tuple[str, str]]:
    """
    returns a list of tuples with first element as token and second element as the label
    example - [('a', 'O'), ('cat', 'ANIMAL'), ('sits', 'O')]
    note: the label of a token is decided based on the max chars in the token
    belonging to a span; ties break deterministically toward the label seen first
    """
    token_labels = []
    for token, offsets in zip(tokens, token_spans):
        if offsets[0] == offsets[1]:
            # zero-width token: nothing to vote on, default to outside
            token_labels.append((token, "O"))
            continue
        char_labels = [
            char_label_map.get(char_index, "O") for char_index in range(*offsets)
        ]
        # majority vote over the token's characters; Counter is linear and its
        # most_common tie-break (first encountered) is deterministic, unlike
        # max(set(...), key=list.count) which depends on set iteration order
        token_label = Counter(char_labels).most_common(1)[0][0]
        token_labels.append((token, token_label))
    return token_labels


def get_token_outputs(ner_spans, parent_text) -> list[tuple[str, str]]:
    """Tokenize *parent_text* and label each token from the character-level spans."""
    char_label_map = get_char_label_map(ner_spans)
    # tokenize once and split the triples, instead of running the tokenizer
    # twice via get_tokens() + get_token_offsets()
    tokens: list[str] = []
    token_offsets: list[tuple[int, int]] = []
    for token, start, end in tokenizer(parent_text):
        tokens.append(token)
        token_offsets.append((start, end))
    return get_list_of_token_label_tuples(tokens, token_offsets, char_label_map)


def get_token_output_labels(ner_spans, parent_text) -> list[str]:
    """Return only the per-token labels (aligned with get_tokens(parent_text))."""
    return [label for token, label in get_token_outputs(ner_spans, parent_text)]