File size: 2,490 Bytes
e70cca1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
from dataclasses import dataclass
from typing import Optional, Tuple

import datasets
import torch
from datasets import load_dataset
from torch import nn
from transformers import AutoConfig, AutoModelForTokenClassification, AutoTokenizer
from transformers.modeling_outputs import TokenClassifierOutput
from transformers.modeling_utils import PreTrainedModel

from .crf import MaskedCRFLoss


@dataclass
class TokenClassifierCRFOutput(TokenClassifierOutput):
    """Output of a token classifier with a CRF head.

    Extends ``TokenClassifierOutput`` with the scores produced by the CRF
    layer. Every field defaults to ``None`` so that a caller can fill in only
    what was computed (e.g. ``best_path`` is only set when decoding was
    requested).
    """

    # CRF loss as returned by the CRF layer.
    loss: Optional[torch.FloatTensor] = None
    # Score of the gold tag sequence under the CRF.
    real_path_score: Optional[torch.FloatTensor] = None
    # NOTE(review): presumably the partition-function (all paths) score — confirm against .crf.
    total_score: Optional[torch.FloatTensor] = None
    # Score of the highest-scoring (decoded) tag sequence, if computed.
    best_path_score: Optional[torch.FloatTensor] = None
    # Decoded tag ids; PretrainedCRFModel.forward stores these batch-first.
    best_path: Optional[torch.LongTensor] = None
    # Passed through unchanged from the underlying encoder output.
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


class PretrainedCRFModel(PreTrainedModel):
    """Token-classification encoder whose logits are scored by a CRF layer.

    The encoder is an ``AutoModelForTokenClassification`` loaded from
    ``config._name_or_path``. Its per-token logits are used as emission
    scores for ``MaskedCRFLoss``.
    """

    config_class = AutoConfig

    def __init__(self, config):
        super().__init__(config)
        # NOTE(review): loads pretrained encoder weights from
        # config._name_or_path every time this class is constructed.
        self.encoder = AutoModelForTokenClassification.from_pretrained(
            config._name_or_path, config=config
        )
        # CRF head sized by the number of labels in the shared config.
        self.crf_model = MaskedCRFLoss(self.config.num_labels)
        self.post_init()

    def forward(
        self,
        input_ids=None,
        token_type_ids=None,
        attention_mask=None,
        labels=None,
        return_best_path=False,
        **kwargs
    ):
        """Run the encoder, then score (and optionally decode) with the CRF.

        Args:
            input_ids: Forwarded to the encoder.
            token_type_ids: Forwarded to the encoder only when not ``None``,
                so encoders whose ``forward`` lacks this argument still work.
            attention_mask: Forwarded to the encoder.
            labels: Batch-first gold tag ids. Positions equal to ``-100``
                are excluded from the CRF mask. Required, because the labels
                also define that mask. For pure decoding, pass placeholder
                tag ids where predictions are wanted and ``-100`` elsewhere.
            return_best_path: Forwarded to the CRF. When set, the decoded
                path is returned batch-first in ``best_path``.
            **kwargs: Extra encoder arguments. ``return_dict`` is always
                forced to ``True`` because ``.logits`` is read from the
                encoder output.

        Returns:
            TokenClassifierCRFOutput: the CRF loss and scores, the batch-first
            best path (or ``None``), and the encoder's hidden states and
            attentions.

        Raises:
            ValueError: If ``labels`` is ``None``.
        """
        if labels is None:
            raise ValueError(
                "PretrainedCRFModel.forward requires `labels`: they define the "
                "CRF mask (positions equal to -100 are ignored). For inference, "
                "pass placeholder tag ids where predictions are wanted and -100 "
                "elsewhere."
            )

        encoder_inputs = dict(kwargs)
        # Some encoders (e.g. DistilBERT) reject a token_type_ids keyword even when it is None.
        if token_type_ids is not None:
            encoder_inputs["token_type_ids"] = token_type_ids
        # `.logits` is accessed below, so the encoder must not return a plain tuple.
        encoder_inputs["return_dict"] = True
        encoder_output = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **encoder_inputs
        )

        # Swap dims 0 and 1 so sequence length is the first dim, as the CRF expects.
        emissions = encoder_output.logits.transpose(1, 0)
        tags = labels.transpose(1, 0)
        mask = tags != -100
        # The CRF cannot index tag id -100; replace it with 0 at masked-out positions.
        tags = tags.where(mask, 0)

        crf_output = self.crf_model(
            emissions, tags, mask, return_best_path=return_best_path
        )

        # Undo the transpose so the decoded path is batch-first again.
        best_path = crf_output.best_path
        if best_path is not None:
            best_path = best_path.transpose(1, 0)

        return TokenClassifierCRFOutput(
            loss=crf_output.loss,
            real_path_score=crf_output.real_path_score,
            total_score=crf_output.total_score,
            best_path_score=crf_output.best_path_score,
            best_path=best_path,
            hidden_states=encoder_output.hidden_states,
            attentions=encoder_output.attentions,
        )