Spaces:
Runtime error
Runtime error
from collections import defaultdict | |
from typing import Dict | |
def get_spans_from_bio(bioes_tags, bioes_scores=None):
    """Group a sequence of BIO / BIOES / IOB tags into labelled spans.

    Tags are strings of the form ``"<prefix>-<label>"`` (e.g. ``"B-PER"``).
    The values ``""``, ``"O"`` and ``"_"`` are treated as "outside" tags.

    A new span starts at every ``B-`` or ``S-`` tag, and at an ``I-`` tag whose
    label differs from the label of the previous tag (which covers IOB input
    where a span may begin with ``I-``). A span ends at the next outside tag,
    at the start of the next span, or at the end of the sequence.

    Args:
        bioes_tags: Sequence of tag strings, one per token. It is not modified.
        bioes_scores: Optional sequence of per-token scores, indexable by token
            position. When missing or empty, every token scores ``1.0``.

    Returns:
        A list of ``(token_indices, span_score, span_label)`` tuples in order
        of appearance, where ``token_indices`` is the list of positions in the
        span, ``span_score`` is the mean of the token scores, and
        ``span_label`` is the label with the highest accumulated weight (the
        span-opening token counts 1.1, every other token 1.0).
    """
    # Work on a copy with a trailing dummy "O" so the final span gets closed
    # without mutating the caller's list.
    tags = list(bioes_tags)
    tags.append("O")

    # Decide once whether scores are available; avoids relying on the
    # truthiness of the container (ambiguous for array-like inputs).
    has_scores = bioes_scores is not None and len(bioes_scores) > 0

    found_spans = []

    # State of the span currently being built.
    current_tag_weights: Dict[str, float] = defaultdict(float)
    current_span = []
    current_span_scores = []
    previous_tag = "O-"

    for idx, tag in enumerate(tags):
        # Normalize all "outside" spellings to a single sentinel so that the
        # prefix/label slicing below works uniformly.
        if tag in ("", "O", "_"):
            tag = "O-"

        in_span = tag != "O-"
        prefix, label = tag[:2], tag[2:]

        # B-/S- always open a span; in IOB format an I- opens one when it
        # follows an outside tag or a tag with a different label.
        starts_new_span = prefix in ("B-", "S-") or (
            prefix == "I-" and previous_tag[2:] != label
        )

        # Close the open span when we hit an outside tag or a new span begins.
        if current_span and (starts_new_span or not in_span):
            span_score = sum(current_span_scores) / len(current_span_scores)
            # max() returns the first maximal item, matching a stable
            # descending sort followed by taking the first element.
            span_value = max(current_tag_weights.items(), key=lambda k_v: k_v[1])[0]
            found_spans.append((current_span, span_score, span_value))

            current_span = []
            current_span_scores = []
            current_tag_weights = defaultdict(float)

        if in_span:
            current_span.append(idx)
            current_span_scores.append(bioes_scores[idx] if has_scores else 1.0)
            # The opening token gets a slightly higher vote so it wins ties
            # when labels inside one span disagree.
            current_tag_weights[label] += 1.1 if starts_new_span else 1.0

        previous_tag = tag

    return found_spans