File size: 2,536 Bytes
44921ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import re


class WhitespaceTokenSplitter:
    """Callable that splits text into tokens with character offsets.

    Despite the name, splitting is regex-driven rather than whitespace-driven:
    a token is either a run of word characters optionally joined by ``-`` or
    ``_`` (so ``state-of-the-art`` stays one token), or any single
    non-whitespace character (punctuation becomes its own token).
    """

    def __init__(self):
        # Compiled once at construction and reused for every call.
        self.whitespace_pattern = re.compile(r"\w+(?:[-_]\w+)*|\S")

    def __call__(self, text):
        """Yield ``(token, start, end)`` triples for each match in *text*."""
        pattern = self.whitespace_pattern
        for m in pattern.finditer(text):
            yield m.group(), m.start(), m.end()


# Shared module-level tokenizer instance used by the helper functions below.
tokenizer = WhitespaceTokenSplitter()


def get_char_label_map(ner_spans: list) -> dict[int, str]:
    """Map each character index covered by a span to that span's label.

    example -- {1:'label1', 2:'label1', 5:'label2', 6:'label2'}
    note: char indices that do not belong to any span are absent from the map;
    when spans overlap, the later span in ``ner_spans`` wins.

    Args:
        ner_spans: dicts with at least ``start``, ``end`` (end exclusive)
            and ``label`` keys.
    """
    char_label_map: dict[int, str] = {}
    for span in ner_spans:
        # Direct assignment instead of rebuilding the dict with a
        # {**a, **b} merge per span, which was accidentally O(n^2).
        for char_index in range(span["start"], span["end"]):
            char_label_map[char_index] = span["label"]
    return char_label_map


def get_tokens(text: str) -> list[str]:
    """Return only the token strings produced by the module tokenizer."""
    return [tok for tok, _start, _end in tokenizer(text)]


def get_token_offsets(text: str) -> list[tuple[int, int]]:
    """Return the (start, end) character offsets of each token in *text*."""
    return [(begin, finish) for _tok, begin, finish in tokenizer(text)]


def get_list_of_token_label_tuples(
    tokens: list[str],
    token_spans: list[tuple[int, int]],
    char_label_map: dict[int, str],
) -> list[tuple[str, str]]:
    """Pair each token with a label derived from its character offsets.

    returns a list of tuples with first element as token and second element as the label
    example - [('a', 'O'), ('cat', 'ANIMAL'), ('sits', 'O')]
    note: the label of a token is decided based on the max chars in the token
    belonging to a span; characters outside every span count toward 'O'.
    Ties go to the label that first reaches the max count within the token.
    """
    token_labels = []
    for token, offsets in zip(tokens, token_spans):
        # Zero-width tokens have no characters to vote with.
        if offsets[0] == offsets[1]:
            token_labels.append((token, "O"))
            continue
        char_labels = [
            char_label_map.get(char_index, "O") for char_index in range(*offsets)
        ]
        # Majority vote over the list itself, not over set(char_labels):
        # set iteration order varies across runs (string hash randomization),
        # which made tie-breaking nondeterministic.
        token_label = max(char_labels, key=char_labels.count)
        token_labels.append((token, token_label))
    return token_labels


def get_token_outputs(ner_spans, parent_text):
    """Tokenize *parent_text* and label each token from the NER spans.

    Args:
        ner_spans: span dicts with ``start``/``end``/``label`` keys.
        parent_text: the text the span offsets refer to.

    Returns:
        list of (token, label) tuples; unlabeled tokens get 'O'.
    """
    char_label_map = get_char_label_map(ner_spans)

    # Tokenize once instead of calling get_tokens() and get_token_offsets()
    # separately, which ran the regex over the text twice.
    tokens: list[str] = []
    token_offsets: list[tuple[int, int]] = []
    for token, start, end in tokenizer(parent_text):
        tokens.append(token)
        token_offsets.append((start, end))

    return get_list_of_token_label_tuples(tokens, token_offsets, char_label_map)


def get_token_output_labels(ner_spans, parent_text):
    """Return only the per-token labels, in token order."""
    return [label for _token, label in get_token_outputs(ner_spans, parent_text)]