|
from abc import ABC |
|
from typing import * |
|
|
|
import torch |
|
from allennlp.common import Registrable |
|
from allennlp.data.vocabulary import DEFAULT_OOV_TOKEN, Vocabulary |
|
from allennlp.training.metrics import CategoricalAccuracy |
|
|
|
|
|
class SpanTyping(Registrable, torch.nn.Module, ABC): |
|
""" |
|
Models the probability p(child_label | child_span, parent_span, parent_label). |
|
""" |
|
def __init__( |
|
self, |
|
n_label: int, |
|
label_to_ignore: Optional[List[int]] = None, |
|
): |
|
""" |
|
:param label_to_ignore: Label indexes in this list will be ignored. |
|
Usually this should include NULL, PADDING and UNKNOWN. |
|
""" |
|
super().__init__() |
|
self.label_to_ignore = label_to_ignore or list() |
|
self.acc_metric = CategoricalAccuracy() |
|
self.onto = torch.ones([n_label, n_label], dtype=torch.bool) |
|
self.register_buffer('ontology', self.onto) |
|
|
|
def load_ontology(self, path: str, vocab: Vocabulary): |
|
unk_id = vocab.get_token_index(DEFAULT_OOV_TOKEN, 'span_label') |
|
for line in open(path).readlines(): |
|
entities = [vocab.get_token_index(ent, 'span_label') for ent in line.replace('\n', '').split('\t')] |
|
parent, children = entities[0], entities[1:] |
|
if parent == unk_id: |
|
continue |
|
self.onto[parent, :] = False |
|
children = list(filter(lambda x: x != unk_id, children)) |
|
self.onto[parent, children] = True |
|
self.register_buffer('ontology', self.onto) |
|
|
|
def forward( |
|
self, |
|
span_vec: torch.Tensor, |
|
parent_at_span: torch.Tensor, |
|
span_labels: Optional[torch.Tensor], |
|
prediction_only: bool = False, |
|
) -> Dict[str, torch.Tensor]: |
|
""" |
|
Inputs: All features for typing a child span. |
|
Output: The loss of typing and predictions. |
|
:param span_vec: Shape [batch, span, token_dim] |
|
:param parent_at_span: Shape [batch, span] |
|
:param span_labels: Shape [batch, span] |
|
:param prediction_only: If True, no loss returned & metric will not be updated |
|
:return: |
|
loss: Loss for label prediction. (absent of pred_only = True) |
|
prediction: Predicted labels. |
|
""" |
|
raise NotImplementedError |
|
|
|
def get_metric(self, reset): |
|
return{ |
|
"typing_acc": self.acc_metric.get_metric(reset) * 100 |
|
} |
|
|