File size: 2,437 Bytes
6680682
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from abc import ABC
from typing import *

import torch
from allennlp.common import Registrable
from allennlp.data.vocabulary import DEFAULT_OOV_TOKEN, Vocabulary
from allennlp.training.metrics import CategoricalAccuracy


class SpanTyping(Registrable, torch.nn.Module, ABC):
    """
    Models the probability p(child_label | child_span, parent_span, parent_label).

    This is an abstract base: ``forward`` raises ``NotImplementedError`` here and
    has to be provided by a subclass. The base class only owns the shared state:
    an accuracy metric and an ``[n_label, n_label]`` boolean table registered as
    the buffer ``ontology``, which ``load_ontology`` indexes as
    ``[parent_label, child_label]``.
    """
    def __init__(
            self,
            n_label: int,
            label_to_ignore: Optional[List[int]] = None,
    ):
        """
        :param n_label: Number of labels; the ontology table has shape
            ``[n_label, n_label]``.
        :param label_to_ignore: Label indexes in this list will be ignored.
            Usually this should include NULL, PADDING and UNKNOWN.
        """
        super().__init__()
        # ``or list()`` gives every instance its own fresh list when None is passed.
        self.label_to_ignore = label_to_ignore or list()
        self.acc_metric = CategoricalAccuracy()
        # Every (parent, child) entry starts as True; rows are only narrowed
        # once ``load_ontology`` is called.
        self.onto = torch.ones([n_label, n_label], dtype=torch.bool)
        # Registered as a buffer so it follows the module across devices and is
        # stored in the state dict. ``self.onto`` is kept as an alias of it.
        self.register_buffer('ontology', self.onto)

    def load_ontology(self, path: str, vocab: Vocabulary):
        """
        Fill the ontology table from a tab-separated text file.

        Each line is read as ``parent<TAB>child_1<TAB>child_2...``. Every field
        is mapped to an index through the ``'span_label'`` namespace of
        ``vocab``. For each parent that is known to the vocabulary, its row is
        cleared and then set to True at the columns of its known children.
        Lines whose parent maps to the OOV index are skipped, and children that
        map to the OOV index are dropped.

        :param path: Path of the ontology file.
        :param vocab: Vocabulary used to turn label strings into indexes.
        """
        unk_id = vocab.get_token_index(DEFAULT_OOV_TOKEN, 'span_label')
        # Edit the registered buffer itself rather than the ``self.onto`` alias:
        # after the module has been moved (e.g. ``.to(device)``) the alias may
        # still point at the old tensor, and re-registering that tensor would
        # silently drag the buffer back to the old device.
        table = self.ontology
        # The context manager closes the file even if a lookup raises.
        with open(path) as ontology_file:
            for line in ontology_file:
                entities = [
                    vocab.get_token_index(ent, 'span_label')
                    for ent in line.rstrip('\n').split('\t')
                ]
                parent, children = entities[0], entities[1:]
                if parent == unk_id:
                    # Parent label is not in the vocabulary; nothing to record.
                    continue
                # Reset the row first so only the listed children remain True.
                table[parent, :] = False
                children = [child for child in children if child != unk_id]
                table[parent, children] = True
        # Keep the public alias and the buffer pointing at the same tensor.
        self.onto = table
        self.register_buffer('ontology', table)

    def forward(
            self,
            span_vec: torch.Tensor,
            parent_at_span: torch.Tensor,
            span_labels: Optional[torch.Tensor],
            prediction_only: bool = False,
    ) -> Dict[str, torch.Tensor]:
        """
        Inputs: All features for typing a child span.
        Output: The loss of typing and predictions.

        Not implemented in this base class; always raises.

        :param span_vec: Shape [batch, span, token_dim]
        :param parent_at_span: Shape [batch, span]
        :param span_labels: Shape [batch, span]
        :param prediction_only: If True, no loss returned & metric will not be updated
        :return:
            loss: Loss for label prediction. (absent if prediction_only = True)
            prediction: Predicted labels.
        :raises NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def get_metric(self, reset: bool):
        """
        Report the typing accuracy as a percentage.

        :param reset: Forwarded to ``CategoricalAccuracy.get_metric``.
        :return: Dict with the single key ``"typing_acc"``, holding the metric
            value multiplied by 100.
        """
        return {
            "typing_acc": self.acc_metric.get_metric(reset) * 100
        }