Spaces:
Build error
Build error
File size: 2,437 Bytes
6680682 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 |
from abc import ABC
from typing import *
import torch
from allennlp.common import Registrable
from allennlp.data.vocabulary import DEFAULT_OOV_TOKEN, Vocabulary
from allennlp.training.metrics import CategoricalAccuracy
class SpanTyping(Registrable, torch.nn.Module, ABC):
    """
    Models the probability p(child_label | child_span, parent_span, parent_label).

    The module keeps a boolean ``[n_label, n_label]`` matrix, registered as the
    buffer ``ontology``. Entry ``[parent, child]`` is True when ``child`` is an
    allowed label under ``parent``. Every entry starts as True; rows are only
    restricted once :meth:`load_ontology` has read them from a file.
    """

    def __init__(
            self,
            n_label: int,
            label_to_ignore: Optional[List[int]] = None,
    ):
        """
        :param n_label: Number of labels; the ontology matrix is n_label x n_label.
        :param label_to_ignore: Label indexes in this list will be ignored.
            Usually this should include NULL, PADDING and UNKNOWN.
        """
        super().__init__()
        self.label_to_ignore = label_to_ignore or []
        self.acc_metric = CategoricalAccuracy()
        # Permissive default: every (parent, child) pair is allowed.
        self.onto = torch.ones([n_label, n_label], dtype=torch.bool)
        self.register_buffer('ontology', self.onto)

    def load_ontology(self, path: str, vocab: Vocabulary) -> None:
        """
        Restrict the ontology matrix according to a tab-separated file.

        Each line is read as ``parent<TAB>child_1<TAB>child_2...``. Labels are
        mapped to indexes through the ``span_label`` namespace of ``vocab``.
        Lines whose parent maps to the OOV index are skipped, and children that
        map to the OOV index are dropped. For every remaining parent the whole
        row is cleared first, so a parent that appears on several lines keeps
        only the children of its last line.

        :param path: Path of the ontology file.
        :param vocab: Vocabulary used to turn label strings into indexes.
        """
        unk_id = vocab.get_token_index(DEFAULT_OOV_TOKEN, 'span_label')
        # The context manager closes the file even if a lookup below raises.
        with open(path) as ontology_file:
            for line in ontology_file:
                entities = [
                    vocab.get_token_index(ent, 'span_label')
                    for ent in line.rstrip('\n').split('\t')
                ]
                parent, children = entities[0], entities[1:]
                if parent == unk_id:
                    continue
                self.onto[parent, :] = False
                known_children = [child for child in children if child != unk_id]
                # Nothing to enable when no child is known; skipping also avoids
                # indexing the tensor with an empty list.
                if known_children:
                    self.onto[parent, known_children] = True
        # Re-register so the `ontology` buffer refers to the tensor updated above.
        # NOTE(review): `self.onto` is a plain attribute, so it presumably is not
        # moved by Module.to()/cuda() together with the buffer - confirm that
        # callers load the ontology before moving the model to another device.
        self.register_buffer('ontology', self.onto)

    def forward(
            self,
            span_vec: torch.Tensor,
            parent_at_span: torch.Tensor,
            span_labels: Optional[torch.Tensor],
            prediction_only: bool = False,
    ) -> Dict[str, torch.Tensor]:
        """
        Inputs: All features for typing a child span.
        Output: The loss of typing and predictions.

        This base implementation only raises; subclasses provide the model.

        :param span_vec: Shape [batch, span, token_dim]
        :param parent_at_span: Shape [batch, span]
        :param span_labels: Shape [batch, span]
        :param prediction_only: If True, no loss returned & metric will not be updated
        :return:
            loss: Loss for label prediction. (absent if prediction_only is True)
            prediction: Predicted labels.
        :raises NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def get_metric(self, reset: bool):
        """
        Report the typing accuracy tracked by ``acc_metric``, multiplied by 100.

        :param reset: Forwarded unchanged to ``acc_metric.get_metric``.
        :return: A dict with the single key ``typing_acc``.
        """
        return {
            "typing_acc": self.acc_metric.get_metric(reset) * 100
        }
|