|
import logging |
|
from abc import ABC |
|
from typing import Dict, List, Optional, Tuple, Union
|
|
|
import numpy as np |
|
from allennlp.common.util import END_SYMBOL |
|
from allennlp.data.dataset_readers.dataset_reader import DatasetReader |
|
from allennlp.data.dataset_readers.dataset_utils.span_utils import bio_tags_to_spans |
|
from allennlp.data.fields import ArrayField, Field, LabelField, ListField, MetadataField, TextField
|
from allennlp.data.token_indexers import PretrainedTransformerIndexer |
|
from allennlp.data.tokenizers import PretrainedTransformerTokenizer, Token |
|
|
|
from ..utils import Span, BIOSmoothing, apply_bio_smoothing |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
@DatasetReader.register('span') |
|
class SpanReader(DatasetReader, ABC): |
|
def __init__( |
|
self, |
|
pretrained_model: str, |
|
max_length: int = 512, |
|
ignore_label: bool = False, |
|
debug: bool = False, |
|
**extras |
|
) -> None: |
|
""" |
|
:param pretrained_model: The name of the pretrained model. E.g. xlm-roberta-large |
|
:param max_length: Sequences longer than this limit will be truncated. |
|
:param ignore_label: If True, label on spans will be anonymized. |
|
:param debug: True to turn on debugging mode. |
|
:param span_proposals: Needed for "enumeration" scheme, but not needed for "BIO". |
|
If True, it will try to enumerate candidate spans in the sentence, which will then be fed into |
|
a binary classifier (EnumSpanFinder). |
|
Note: It might take time to propose spans. And better to use SpacyTokenizer if you want to call |
|
constituency parser or dependency parser. |
|
:param maximum_negative_spans: Necessary for EnumSpanFinder. |
|
:param extras: Args to DatasetReader. |
|
""" |
|
super().__init__(**extras) |
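        # Word pieces are indexed with the pretrained transformer's own vocabulary,
        # kept under a dedicated 'pieces' namespace.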
|
self.word_indexer = { |
|
'pieces': PretrainedTransformerIndexer(pretrained_model, namespace='pieces') |
|
} |
|
|
|
self._pretrained_model_name = pretrained_model |
|
self.debug = debug |
|
self.ignore_label = ignore_label |
|
|
|
self._pretrained_tokenizer = PretrainedTransformerTokenizer(pretrained_model) |
|
self.max_length = max_length |
|
self.n_span_removed = 0 |
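        # `n_span_removed` is refreshed on every `prepare_inputs` call; it counts the
        # spans dropped because they overlapped.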
|
|
|
def retokenize( |
|
self, sentence: List[str], truncate: bool = True |
|
) -> Tuple[List[str], List[Optional[Tuple[int, int]]]]: |
|
pieces, offsets = self._pretrained_tokenizer.intra_word_tokenize(sentence) |
|
pieces = list(map(str, pieces)) |
|
if truncate: |
|
pieces = pieces[:self.max_length] |
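            # Ensure the last piece is the end symbol (it may have been cut off by truncation).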
|
pieces[-1] = END_SYMBOL |
|
return pieces, offsets |
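
    # Hypothetical illustration (actual pieces depend on the pretrained tokenizer):
    #   retokenize(['unhappy', 'dog'])
    #   -> (['<s>', 'un', 'happy', 'dog', '</s>'], [(1, 2), (3, 3)])
    # where offsets[i] is the inclusive piece range covering the i-th input word.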
|
|
|
def prepare_inputs( |
|
self, |
|
sentence: List[str], |
|
spans: Optional[Union[List[Span], Span]] = None, |
|
truncate: bool = True, |
|
label_type: str = 'string', |
|
) -> Dict[str, Field]: |
|
""" |
|
Prepare inputs and auxiliary variables for span model. |
|
:param sentence: A list of tokens. Do not pass in any special tokens, like BOS or EOS. |
|
Necessary for both training and testing. |
|
:param spans: Optional. For training, spans passed in will be considered as positive examples; the spans |
|
that are automatically proposed and not in the positive set will be considered as negative examples. |
|
Necessary for training. |
|
:param truncate: If True, sequence will be truncated if it's longer than `self.max_training_length` |
|
:param label_type: One of [string, list]. |
|
|
|
:return: Dict of AllenNLP fields. For detailed of explanation of every field, refer to the comments |
|
below. For the shape of every field, check the module doc. |
|
Fields list: |
|
- words |
|
- span_labels |
|
- span_boundary |
|
- parent_indices |
|
- parent_mask |
|
- bio_seqs |
|
- raw_sentence |
|
- raw_spans |
|
- proposed_spans |
|
""" |
|
fields = dict() |
|
|
|
        pieces, offsets = self.retokenize(sentence, truncate)
        fields['tokens'] = TextField(list(map(Token, pieces)), self.word_indexer)
        # Keep the raw sentence, pieces, and offsets around for decoding and debugging.
        raw_inputs = {'sentence': sentence, 'pieces': pieces, 'offsets': offsets}
        fields['raw_inputs'] = MetadataField(raw_inputs)
|
|
|
if spans is None: |
|
return fields |
|
|
|
        # Wrap the given spans under a single virtual root; overlapping spans are
        # removed, since the BIO scheme below cannot encode them.
        vr = spans if isinstance(spans, Span) else Span.virtual_root(spans)
        self.n_span_removed = vr.remove_overlapping()
        raw_inputs['spans'] = vr

        # Re-index the spans from word indices to word-piece indices.
        vr = vr.re_index(offsets)
        if truncate:
            vr.truncate(self.max_length)
        if self.ignore_label:
            vr.ignore_labels()

        span_boundary = list()        # piece-level (start, end) of every span
        span_labels = list()          # label of every span
        span_parent_indices = list()  # index of each span's parent within `flatten_spans`
        parent_mask = [False] * vr.n_nodes  # True for spans that have children
        flatten_spans = list(vr.bfs())  # BFS order; index 0 is the virtual root
|
        for span_idx, span in enumerate(flatten_spans):
            if span.is_parent:
                parent_mask[span_idx] = True

            # The virtual root has no parent; point it at itself (index 0, the BFS head).
            parent_idx = flatten_spans.index(span.parent) if span.parent else 0
            span_parent_indices.append(parent_idx)
            span_boundary.append(span.boundary)
            span_labels.append(span.label)
|
|
|
        bio_tag_list: List[List[str]] = list()
        bio_configs: List[List[BIOSmoothing]] = list()
        bio_seqs: List[np.ndarray] = list()
|
|
|
        # Build one BIO sequence over the word pieces for every parent span: each child
        # gets a 'B' on its first piece and 'I' on the rest; everything else stays 'O'.
        for _, parent in filter(lambda node: node[1].is_parent, enumerate(flatten_spans)):
            bio_tags = ['O'] * len(pieces)
            bio_tag_list.append(bio_tags)
            bio_smooth: List[BIOSmoothing] = [parent.child_smooth.clone() for _ in pieces]
            bio_configs.append(bio_smooth)
            for child in parent:
                # Children must not overlap: `remove_overlapping` was called above.
                assert all(bio_tags[bio_idx] == 'O' for bio_idx in range(child.start_idx, child.end_idx + 1))
                if child.smooth_weight is not None:
                    for i in range(child.start_idx, child.end_idx + 1):
                        bio_smooth[i].weight = child.smooth_weight
                bio_tags[child.start_idx] = 'B'
                for word_idx in range(child.start_idx + 1, child.end_idx + 1):
                    bio_tags[word_idx] = 'I'
            bio_seqs.append(apply_bio_smoothing(bio_smooth, bio_tags))
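
        # For example, a parent whose only child covers pieces 1-2 of a 5-piece sequence
        # yields tags ['O', 'B', 'I', 'O', 'O'], which `apply_bio_smoothing` then turns
        # into smoothed per-piece tag distributions.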
|
|
|
        fields['span_boundary'] = ArrayField(
            np.array(span_boundary), padding_value=0, dtype=np.int64
        )
        fields['parent_indices'] = ArrayField(np.array(span_parent_indices), padding_value=0, dtype=np.int64)
        if label_type == 'string':
            fields['span_labels'] = ListField([LabelField(label, 'span_label') for label in span_labels])
        elif label_type == 'list':
            fields['span_labels'] = ArrayField(np.array(span_labels))
        else:
            raise NotImplementedError(f'Unsupported label_type: {label_type}')
        fields['parent_mask'] = ArrayField(np.array(parent_mask), padding_value=False, dtype=np.bool_)
        # One smoothed BIO distribution per parent span, aligned with `pieces`.
        fields['bio_seqs'] = ArrayField(np.stack(bio_seqs))
|
|
|
self._sanity_check( |
|
flatten_spans, pieces, bio_tag_list, parent_mask, span_boundary, span_labels, span_parent_indices |
|
) |
|
|
|
return fields |
|
|
|
@staticmethod |
|
def _sanity_check( |
|
flatten_spans, words, bio_tag_list, parent_mask, span_boundary, span_labels, parent_indices, verbose=False |
|
): |
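        """Check that the flattened spans, BIO tags, and assembled fields are mutually consistent."""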
|
|
|
assert len(parent_mask) == len(span_boundary) == len(span_labels) == len(parent_indices) |
|
for (parent_idx, parent_span), bio_tags in zip( |
|
filter(lambda x: x[1].is_parent, enumerate(flatten_spans)), bio_tag_list |
|
): |
|
assert parent_mask[parent_idx] |
|
parent_s, parent_e = span_boundary[parent_idx] |
|
if verbose: |
|
print('Parent: ', span_labels[parent_idx], 'Text: ', ' '.join(words[parent_s:parent_e+1])) |
|
print(f'It contains {len(parent_span)} children.') |
|
for child in parent_span: |
|
child_idx = flatten_spans.index(child) |
|
assert parent_indices[child_idx] == flatten_spans.index(parent_span) |
|
if verbose: |
|
child_s, child_e = span_boundary[child_idx] |
|
                    print(' ', span_labels[child_idx], 'Text: ', ' '.join(words[child_s:child_e + 1]))
|
|
|
if verbose: |
|
                print('Children derived from BIO tags:')
|
for _, (start, end) in bio_tags_to_spans(bio_tags): |
|
print(words[start:end+1]) |
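

# A minimal usage sketch (hypothetical names; assumes a concrete subclass that
# implements `_read`, and that the pretrained weights can be loaded):
#
#     reader = MySpanReader(pretrained_model='xlm-roberta-large', max_length=256)
#     fields = reader.prepare_inputs(['The', 'cat', 'sat', '.'])      # inference
#     fields = reader.prepare_inputs(tokens, spans=gold_spans)        # training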
|
|