File size: 2,274 Bytes
24d0437
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from collections import defaultdict
from typing import Dict


def get_spans_from_bio(bioes_tags, bioes_scores=None):
    """Decode a sequence of BIO/BIOES tags into labelled spans.

    Tags are read left to right:

    * ``""``, ``"O"`` and ``"_"`` are outside tags; they close any open span.
    * ``B-`` and ``S-`` tags always open a new span (closing any open one).
    * An ``I-`` tag opens a new span when its label differs from the previous
      tag's label (so IOB-style input without ``B-`` tags is accepted);
      otherwise it extends the open span.
    * Any other inside tag (e.g. ``E-``) extends the open span regardless of
      its label, or starts one if no span is open.

    Args:
        bioes_tags: Sequence of tag strings such as ``"B-PER"`` or ``"O"``.
            The sequence is not modified.
        bioes_scores: Optional per-token scores aligned with ``bioes_tags``.
            When ``None`` or empty, every token is scored ``1.0``.

    Returns:
        A list of ``(token_indices, score, label)`` tuples in order of
        appearance. ``score`` is the mean of the span's token scores.
        ``label`` is the label with the highest accumulated weight among the
        span's tokens: the token that opened the span contributes 1.1, every
        other token 1.0, and ties go to the label seen first.
    """
    # Iterate over a copy with a dummy "O" so the final span gets closed,
    # instead of appending to (and thereby mutating) the caller's sequence.
    tags = list(bioes_tags) + ["O"]
    # Decide once whether scores are usable; `is not None` avoids truth-testing
    # array-like containers, while an empty sequence still means "no scores".
    use_scores = bioes_scores is not None and len(bioes_scores) > 0

    found_spans = []
    current_tag_weights: Dict[str, float] = defaultdict(float)
    previous_tag = "O-"
    current_span = []
    current_span_scores = []

    for idx, bioes_tag in enumerate(tags):
        # non-set tags are OUT tags
        if bioes_tag in ("", "O", "_"):
            bioes_tag = "O-"

        # anything that is not OUT is IN
        in_span = bioes_tag != "O-"
        prefix, label = bioes_tag[:2], bioes_tag[2:]

        # B- and S- tags always start a new span; in IOB format an I- tag
        # starts one if it follows an O or a tag with a different label.
        starts_new_span = prefix in ("B-", "S-") or (
            prefix == "I-" and previous_tag[2:] != label
        )

        # close the open span when reaching an O or when a new span begins
        if (starts_new_span or not in_span) and current_span:
            span_score = sum(current_span_scores) / len(current_span_scores)
            span_value = max(current_tag_weights.items(), key=lambda k_v: k_v[1])[0]
            found_spans.append((current_span, span_score, span_value))

            # reset accumulators for the next span
            current_span = []
            current_span_scores = []
            current_tag_weights = defaultdict(float)

        if in_span:
            current_span.append(idx)
            current_span_scores.append(bioes_scores[idx] if use_scores else 1.0)
            # the token that opens a span gets a slightly larger vote
            current_tag_weights[label] += 1.1 if starts_new_span else 1.0

        # remember previous tag
        previous_tag = bioes_tag

    return found_spans