import os
from typing import *

import torch
from allennlp.common.from_params import Params, T, pop_and_construct_arg
from allennlp.data.vocabulary import Vocabulary, DEFAULT_PADDING_TOKEN, DEFAULT_OOV_TOKEN
from allennlp.models.model import Model
from allennlp.modules import TextFieldEmbedder
from allennlp.modules.seq2seq_encoders.pytorch_seq2seq_wrapper import Seq2SeqEncoder
from allennlp.modules.span_extractors import SpanExtractor
from allennlp.training.metrics import Metric

from ..metrics import ExactMatch
from ..modules import SpanFinder, SpanTyping
from ..utils import num2mask, VIRTUAL_ROOT, Span, tensor2span


@Model.register("span")
class SpanModel(Model):
    """
    Identify/find spans; link them as a tree; label them.
    """
    default_predictor = 'span'

    def __init__(
            self,
            vocab: Vocabulary,
            # Modules
            word_embedding: TextFieldEmbedder,
            span_extractor: SpanExtractor,
            span_finder: SpanFinder,
            span_typing: SpanTyping,
            # Config
            typing_loss_factor: float = 1.,
            max_recursion_depth: int = -1,
            max_decoding_spans: int = -1,
            debug: bool = False,
            # Ontology Constraints
            ontology_path: Optional[str] = None,
            # Metrics
            metrics: Optional[List[Metric]] = None,
    ) -> None:
        """
        Note for the jsonnet file: it doesn't strictly follow the init signature of every module,
        because we override the from_params method. Check SpanModel.from_params or the example
        jsonnet file.
        :param vocab: No need to specify.
        ## Modules
        :param word_embedding: Refer to the module doc.
        :param span_extractor: Refer to the module doc.
        :param span_finder: Refer to the module doc.
        :param span_typing: Refer to the module doc.
        ## Configs
        :param typing_loss_factor: loss = span_finder_loss + span_typing_loss * typing_loss_factor
        :param max_recursion_depth: Maximum tree depth for inference. E.g., 1 for shallow event
            typing, 2 for SRL, -1 (unlimited) for dependency parsing.
        :param max_decoding_spans: Maximum number of spans for inference. -1 for unlimited.
        :param debug: Currently unused.
        """
        self._pad_idx = vocab.get_token_index(DEFAULT_PADDING_TOKEN, 'token')
        self._null_idx = vocab.get_token_index(DEFAULT_OOV_TOKEN, 'span_label')
        super().__init__(vocab)

        self.word_embedding = word_embedding
        self._span_finder = span_finder
        self._span_extractor = span_extractor
        self._span_typing = span_typing

        self.metrics = [ExactMatch(True), ExactMatch(False)]
        if metrics is not None:
            self.metrics.extend(metrics)

        if ontology_path is not None and os.path.exists(ontology_path):
            self._span_typing.load_ontology(ontology_path, self.vocab)

        self._max_decoding_spans = max_decoding_spans
        self._typing_loss_factor = typing_loss_factor
        self._max_recursion_depth = max_recursion_depth
        self.debug = debug
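    # Rough training-time data flow (a sketch for orientation; names mirror forward() below):
    #   token_vec = word_embedding(tokens)                    # [batch, word, token_dim]
    #   span_vec  = span_extractor(token_vec, span_boundary)  # [batch, span, span_dim]
    #   loss      = span_finder_loss + span_typing_loss * typing_loss_factor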
    def forward(
            self,
            tokens: Dict[str, Dict[str, torch.Tensor]],
            span_boundary: Optional[torch.Tensor] = None,
            span_labels: Optional[torch.Tensor] = None,
            parent_indices: Optional[torch.Tensor] = None,
            parent_mask: Optional[torch.Tensor] = None,
            bio_seqs: Optional[torch.Tensor] = None,
            raw_inputs: Optional[dict] = None,
            meta: Optional[dict] = None,
            **extra,
    ) -> Dict[str, torch.Tensor]:
        """
        For training, provide all of the arguments below. For inference, it's enough to provide
        only the words.
        :param tokens: Indexed input sentence. Shape: [batch, token]
        :param span_boundary: Start and end indices for every span. Note this includes both
            parent and non-parent spans. Shape: [batch, span, 2]. For the last dim, [0] is the
            start idx and [1] is the end idx.
        :param span_labels: Indexed labels for spans, including parent and non-parent ones.
            Shape: [batch, span]
        :param parent_indices: The parent span idx of every span. Shape: [batch, span]
        :param parent_mask: True if this span is a parent. Shape: [batch, span]
        :param bio_seqs: Shape: [batch, parent, token, 3]
        :param raw_inputs: Raw inputs; copied to the outputs.
        :param meta: Meta information; copied to the outputs.
        :return:
            - loss: training loss
            - prediction: predicted spans
            - meta: meta info copied from the input
            - inputs: input sentences and spans (if they exist)
        """
        ret = {'inputs': raw_inputs, 'meta': meta or dict()}

        is_eval = span_labels is not None and not self.training  # evaluation on dev set
        is_test = span_labels is None  # test on test set
        # Shape [batch]
        num_spans = (span_labels != -1).sum(1) if span_labels is not None else None
        num_words = tokens['pieces']['mask'].sum(1)
        # Shape [batch, word, token_dim]
        token_vec = self.word_embedding(tokens)

        if span_labels is not None:
            # Revise the padding value from -1 to 0
            span_labels[span_labels == -1] = 0

        # Calculate loss
        if self.training or is_eval:
            # Shape [batch, span, span_dim]
            span_vec = self._span_extractor(token_vec, span_boundary)
            finder_rst = self._span_finder(
                token_vec, num2mask(num_words), span_vec, num2mask(num_spans), span_labels,
                parent_indices, parent_mask, bio_seqs
            )
            typing_rst = self._span_typing(span_vec, parent_indices, span_labels)
            ret['loss'] = finder_rst['loss'] + typing_rst['loss'] * self._typing_loss_factor

        # Decoding
        if is_eval or is_test:
            pred_span_boundary, pred_span_labels, pred_parent_indices, pred_cursor, pred_label_confidence \
                = self.inference(num_words, token_vec, **extra)
            prediction = self.post_process_pred(
                pred_span_boundary, pred_span_labels, pred_parent_indices, pred_cursor, pred_label_confidence
            )
            for pred, raw_in in zip(prediction, raw_inputs):
                pred.re_index(raw_in['offsets'], True, True, True)
                pred.remove_overlapping()
            ret['prediction'] = prediction
            if 'spans' in raw_inputs[0]:
                for pred, raw_in in zip(prediction, raw_inputs):
                    gold = raw_in['spans']
                    for metric in self.metrics:
                        metric(pred, gold)

        return ret
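    # Hedged usage sketch (assumes `batch` comes from the project's data pipeline and matches
    # forward()'s signature; not part of the original file):
    #   model.eval()
    #   with torch.no_grad():
    #       out = model(**batch)      # without span_labels this takes the is_test path
    #   spans = out['prediction']     # List[Span], re-indexed to raw token offsets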
    def inference(
            self,
            num_words: torch.Tensor,
            token_vec: torch.Tensor,
            **auxiliaries,
    ):
        n_batch = num_words.shape[0]
        # The decoding results are preserved in the following tensors, all starting with `pred`.
        # During inference, we completely ignore the arguments that default to None in the
        # forward method. The span indexing space is shifted to the decoding span space,
        # since we do not have gold spans now.

        # Boundary indices of every predicted span
        pred_span_boundary = num_words.new_zeros([n_batch, self._max_decoding_spans, 2])
        # Labels (and corresponding confidence) for predicted spans
        pred_span_labels = num_words.new_full(
            [n_batch, self._max_decoding_spans], self.vocab.get_token_index(VIRTUAL_ROOT, 'span_label')
        )
        pred_label_confidence = num_words.new_zeros([n_batch, self._max_decoding_spans])
        # Spans masked as True will be treated as parents in the next round
        pred_parent_mask = num_words.new_zeros([n_batch, self._max_decoding_spans], dtype=torch.bool)
        pred_parent_mask[:, 0] = True
        # Parent index (in the span indexing space) for every span
        pred_parent_indices = num_words.new_zeros([n_batch, self._max_decoding_spans])
        # What index have we reached for each batch?
        pred_cursor = num_words.new_ones([n_batch])

        # Pass environment variables to the handler. Extra variables will be ignored,
        # so pass the union of variables that are needed by the different modules.
        span_find_handler = self._span_finder.inference_forward_handler(
            token_vec, num2mask(num_words), self._span_extractor, **auxiliaries
        )

        # Every step here is one layer of the tree. It deals with all the parents from the last
        # layer, so there might be 0 to multiple parents for a batch in a single step.
        for _ in range(self._max_recursion_depth):
            cursor_before_find = pred_cursor.clone()
            span_find_handler(
                pred_span_boundary, pred_span_labels, pred_parent_mask, pred_parent_indices, pred_cursor
            )
            # Labels of old spans are re-predicted. It doesn't matter, since their results
            # shouldn't change, in theory.
            span_typing_ret = self._span_typing(
                self._span_extractor(token_vec, pred_span_boundary), pred_parent_indices, pred_span_labels, True
            )
            pred_span_labels = span_typing_ret['prediction']
            pred_label_confidence = span_typing_ret['label_confidence']
            pred_span_labels[:, 0] = self.vocab.get_token_index(VIRTUAL_ROOT, 'span_label')
            pred_parent_mask = (
                num2mask(cursor_before_find, self._max_decoding_spans)
                ^ num2mask(pred_cursor, self._max_decoding_spans)
            )

            # Break the inference loop if 1) all batches reach the max span limit, OR 2) no parent
            # was predicted at the last step, OR 3) the max recursion limit is reached (the
            # for-loop condition).
            if (pred_cursor == self._max_decoding_spans).all() or pred_parent_mask.sum() == 0:
                break

        return pred_span_boundary, pred_span_labels, pred_parent_indices, pred_cursor, pred_label_confidence
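    # Note on the loop above: each iteration expands one layer of the tree. pred_parent_mask is
    # the XOR of the cursor masks before/after the find step, i.e., exactly the spans newly added
    # in this iteration, which then serve as the parents of the next one.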
    def one_step_prediction(
            self,
            tokens: Dict[str, Dict[str, torch.Tensor]],
            parent_boundary: torch.Tensor,
            parent_labels: torch.Tensor,
    ):
        """
        Single-step prediction. Given parent span boundary indices, return the corresponding
        children spans and their labels.
        Restriction: Each sentence contains exactly 1 parent.
        For efficient multi-layer prediction, i.e., given a root, predict the whole tree,
        refer to the `forward` method.
        :param tokens: See forward.
        :param parent_boundary: Pairs of (start_idx, end_idx) for parents. Shape [batch, 2]
        :param parent_labels: Labels for parents. Shape [batch]
            Note: If `no_label` is on in the span_finder module, this will be ignored.
        :return:
            children_boundary: (start_idx, end_idx) for every child span. Padded with (0, 0).
                Shape [batch, children, 2]
            children_labels: Label for every child span. Padded with null_idx.
                Shape [batch, children]
            num_children: The number of children predicted for each parent/batch.
                Shape [batch]
                Tip: You can use the num2mask method to convert this to a bool tensor mask.
            children_distribution: Label distribution for every child span.
                Shape [batch, children, n_label]
        """
        num_words = tokens['pieces']['mask'].sum(1)
        # Shape [batch, word, token_dim]
        token_vec = self.word_embedding(tokens)
        n_batch = token_vec.shape[0]

        # The following variables assume the parent is the 0-th span, and we let the model
        # extend the span list.
        pred_span_boundary = num_words.new_zeros([n_batch, self._max_decoding_spans, 2])
        pred_span_boundary[:, 0] = parent_boundary
        pred_span_labels = num_words.new_full([n_batch, self._max_decoding_spans], self._null_idx)
        pred_span_labels[:, 0] = parent_labels
        pred_parent_mask = num_words.new_zeros(pred_span_labels.shape, dtype=torch.bool)
        pred_parent_mask[:, 0] = True
        pred_parent_indices = num_words.new_zeros([n_batch, self._max_decoding_spans])
        # We start from idx 1, since 0 is the parent.
        pred_cursor = num_words.new_ones([n_batch])

        span_find_handler = self._span_finder.inference_forward_handler(
            token_vec, num2mask(num_words), self._span_extractor
        )
        span_find_handler(
            pred_span_boundary, pred_span_labels, pred_parent_mask, pred_parent_indices, pred_cursor
        )
        typing_out = self._span_typing(
            self._span_extractor(token_vec, pred_span_boundary), pred_parent_indices, pred_span_labels, True
        )
        pred_span_labels = typing_out['prediction']

        # Now remove the parent
        num_children = pred_cursor - 1
        max_children = int(num_children.max())
        children_boundary = pred_span_boundary[:, 1:max_children + 1]
        children_labels = pred_span_labels[:, 1:max_children + 1]
        children_distribution = typing_out['distribution'][:, 1:max_children + 1]

        return children_boundary, children_labels, num_children, children_distribution

    def post_process_pred(
            self, span_boundary, span_labels, parent_indices, num_spans, label_confidence
    ) -> List[Span]:
        pred_spans = tensor2span(
            span_boundary, span_labels, parent_indices, num_spans, label_confidence,
            self.vocab.get_index_to_token_vocabulary('span_label'),
            label_ignore=[self._null_idx],
        )
        return pred_spans

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        ret = dict()
        if reset:
            for metric in self.metrics:
                ret.update(metric.get_metric(reset))
        ret.update(self._span_finder.get_metrics(reset))
        ret.update(self._span_typing.get_metric(reset))
        return ret

    @classmethod
    def from_params(
            cls: Type[T],
            params: Params,
            constructor_to_call: Callable[..., T] = None,
            constructor_to_inspect: Callable[..., T] = None,
            **extras,
    ) -> T:
        """
        Specify the dependencies between modules. E.g., the input dim of a module might depend
        on the output dim of another module.
        """
        vocab = extras['vocab']
        word_embedding = pop_and_construct_arg(
            'SpanModel', 'word_embedding', TextFieldEmbedder, None, params, **extras
        )
        label_dim, token_emb_dim = params.pop('label_dim'), word_embedding.get_output_dim()
        span_extractor = pop_and_construct_arg(
            'SpanModel', 'span_extractor', SpanExtractor, None, params, input_dim=token_emb_dim, **extras
        )
        label_embedding = torch.nn.Embedding(vocab.get_vocab_size('span_label'), label_dim)
        extras['label_emb'] = label_embedding

        if params.get('span_finder').get('type') == 'bio':
            bio_encoder = Seq2SeqEncoder.from_params(
                params['span_finder'].pop('bio_encoder'),
                input_size=span_extractor.get_output_dim() + token_emb_dim + label_dim,
                input_dim=span_extractor.get_output_dim() + token_emb_dim + label_dim,
                **extras
            )
            extras['span_finder'] = SpanFinder.from_params(
                params.pop('span_finder'), bio_encoder=bio_encoder, **extras
            )
        else:
            extras['span_finder'] = pop_and_construct_arg(
                'SpanModel', 'span_finder', SpanFinder, None, params, **extras
            )
        extras['span_finder'].label_emb = label_embedding

        if params.get('span_typing').get('type') == 'mlp':
            extras['span_typing'] = SpanTyping.from_params(
                params.pop('span_typing'),
                input_dim=span_extractor.get_output_dim() * 2 + label_dim,
                n_category=vocab.get_vocab_size('span_label'),
                label_to_ignore=[
                    vocab.get_token_index(lti, 'span_label')
                    for lti in [DEFAULT_OOV_TOKEN, DEFAULT_PADDING_TOKEN]
                ],
                **extras
            )
        else:
            extras['span_typing'] = pop_and_construct_arg(
                'SpanModel', 'span_typing', SpanTyping, None, params, **extras
            )
        extras['span_typing'].label_emb = label_embedding

        return super().from_params(
            params,
            word_embedding=word_embedding,
            span_extractor=span_extractor,
            **extras
        )
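# Hedged config sketch (illustrative only; the type names "span", "bio", and "mlp" follow the
# registrations referenced in from_params above, but the keys shown and all values here are
# placeholders, not the project's actual example jsonnet):
#   "model": {
#     "type": "span",
#     "label_dim": 64,
#     "word_embedding": { ... },
#     "span_extractor": { ... },
#     "span_finder": { "type": "bio", "bio_encoder": { ... } },
#     "span_typing": { "type": "mlp" },
#     "typing_loss_factor": 8,
#     "max_decoding_spans": 128,
#     "max_recursion_depth": 2
#   }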