from abc import ABC
from typing import *

import torch
from allennlp.common import Registrable
from allennlp.data.vocabulary import DEFAULT_OOV_TOKEN, Vocabulary
from allennlp.training.metrics import CategoricalAccuracy


class SpanTyping(Registrable, torch.nn.Module, ABC):
    """
    Models the probability p(child_label | child_span, parent_span, parent_label).

    The module keeps a boolean ``[n_label, n_label]`` matrix, registered as the
    buffer ``ontology`` and also reachable as ``self.onto``. Entry
    ``[parent, child]`` is True when ``child`` is permitted under ``parent``.
    Every entry starts as True until :meth:`load_ontology` restricts it.
    """

    def __init__(
            self,
            n_label: int,
            label_to_ignore: Optional[List[int]] = None,
    ):
        """
        :param n_label: Number of labels; the ontology matrix is [n_label, n_label].
        :param label_to_ignore: Label indexes in this list will be ignored. Usually this should include
        NULL, PADDING and UNKNOWN.
        """
        super().__init__()
        self.label_to_ignore = label_to_ignore or []
        self.acc_metric = CategoricalAccuracy()
        # With no ontology loaded, every (parent, child) pair is allowed.
        self.onto = torch.ones([n_label, n_label], dtype=torch.bool)
        self.register_buffer('ontology', self.onto)

    def load_ontology(self, path: str, vocab: Vocabulary):
        """
        Restrict the ontology matrix according to a tab-separated file.

        Each line is ``parent<TAB>child_1<TAB>child_2...``. Labels are looked up in the
        ``span_label`` namespace of ``vocab``. For every parent found in the file, its
        row is cleared and only the listed children are enabled. A later line for the
        same parent replaces an earlier one. Lines whose parent resolves to the OOV
        index are skipped, and children resolving to the OOV index are dropped.
        Parents that never appear in the file keep their current row.

        :param path: Path of the ontology file (read as UTF-8).
        :param vocab: Vocabulary used to map label strings to indexes.
        """
        unk_id = vocab.get_token_index(DEFAULT_OOV_TOKEN, 'span_label')
        # Edit the registered buffer, not ``self.onto``: ``Module.to()``/``.cuda()`` may
        # replace buffers with new tensors, which would leave ``self.onto`` stale.
        onto = self.ontology
        # ``with`` guarantees the file handle is closed even if a lookup raises.
        with open(path, encoding='utf-8') as ontology_file:
            for line in ontology_file:
                entities = [
                    vocab.get_token_index(ent, 'span_label')
                    for ent in line.rstrip('\n').split('\t')
                ]
                parent, children = entities[0], entities[1:]
                if parent == unk_id:
                    continue
                onto[parent, :] = False
                children = [child for child in children if child != unk_id]
                if children:
                    onto[parent, children] = True
        # Keep the plain attribute and the buffer pointing at the same tensor.
        self.onto = onto
        self.register_buffer('ontology', onto)

    def forward(
            self,
            span_vec: torch.Tensor,
            parent_at_span: torch.Tensor,
            span_labels: Optional[torch.Tensor],
            prediction_only: bool = False,
    ) -> Dict[str, torch.Tensor]:
        """
        Inputs: All features for typing a child span.
        Output: The loss of typing and predictions.

        This base implementation always raises; subclasses provide the model.

        :param span_vec: Shape [batch, span, token_dim]
        :param parent_at_span: Shape [batch, span]
        :param span_labels: Shape [batch, span]
        :param prediction_only: If True, no loss returned & metric will not be updated
        :return:
            loss: Loss for label prediction. (absent if prediction_only is True)
            prediction: Predicted labels.
        :raises NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def get_metric(self, reset: bool) -> Dict[str, float]:
        """
        Report the typing accuracy scaled by 100.

        :param reset: Forwarded to the underlying ``CategoricalAccuracy.get_metric``.
        :return: A dict with the single key ``typing_acc``.
        """
        return {"typing_acc": self.acc_metric.get_metric(reset) * 100}