from typing import Optional

import torch
from torch import nn
from transformers import AutoModel
from transformers import PreTrainedModel

from .configuration_leaf import LeafConfig
from .mappings import idx_to_ef, idx_to_classname


class LeafModel(PreTrainedModel):
    """
    LEAF model for text classification.
    """
    config_class = LeafConfig

    def __init__(self, config: LeafConfig):
        super().__init__(config)
        self._base_model = AutoModel.from_pretrained(config.model_name)
        self._device = "cuda" if torch.cuda.is_available() else "cpu"

        hidden_dim = self._base_model.config.hidden_size
        self.head = ClassificationHead(hidden_dim=hidden_dim, num_classes=2097,
                                       idx_to_ef=idx_to_ef, idx_to_classname=idx_to_classname,
                                       device=self._device)

    def forward(self, input_ids, attention_mask, **kwargs) -> dict:
        kwargs.setdefault("classes", None)
        outputs = self._base_model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        # Masked mean pooling: zero out padding positions, then average the
        # remaining token embeddings per sequence.
        attention_mask = attention_mask.unsqueeze(-1)
        masked_outputs = outputs * attention_mask.type_as(outputs)
        nom = masked_outputs.sum(dim=1)
        denom = attention_mask.sum(dim=1)
        denom = denom.masked_fill(denom == 0, 1)  # avoid division by zero for all-padding rows
        return self.head(nom / denom, **kwargs)


class ClassificationHead(nn.Module):
    """
    Model head to predict a categorical target variable.
    """

    def __init__(self, hidden_dim: int, num_classes: int, idx_to_ef: dict, idx_to_classname: Optional[dict],
                 device: str):
        super().__init__()
        self.linear = nn.Linear(in_features=hidden_dim, out_features=num_classes)
        self.loss = nn.CrossEntropyLoss()

        # Turn the dict into a lookup table. A non-persistent buffer stays out
        # of the state dict but still follows the module across .to(device) calls.
        ef_table = torch.tensor([idx_to_ef[k] for k in sorted(idx_to_ef.keys())], device=device)
        self.register_buffer("idx_to_ef", ef_table, persistent=False)
        self.idx_to_classname = idx_to_classname

    def forward(self, activations: torch.Tensor, classes: Optional[torch.Tensor] = None, **kwargs) -> dict:
        # Defined as forward (not __call__) so nn.Module's call machinery and
        # hooks keep working.
        return_dict = {}
        logits = self.linear(activations)
        return_dict["logits"] = logits
        if classes is not None:  # `if classes:` would raise on multi-element tensors
            loss = self.loss(logits, classes)
            return_dict["loss"] = loss
        # Softmax is monotonic, so the argmax of the raw logits gives the same
        # predicted class.
        predicted_classes = logits.argmax(dim=1)
        return_dict["class_idx"] = predicted_classes
        return_dict["ef_score"] = self.idx_to_ef[predicted_classes]
        if self.idx_to_classname:
            return_dict["class"] = [self.idx_to_classname[str(c)] for c in
                                    predicted_classes.cpu().numpy()]
        return return_dict
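

if __name__ == "__main__":
    # Minimal smoke test: a hedged sketch, not part of the released model.
    # It assumes LeafConfig (defined in configuration_leaf.py, not shown here)
    # accepts `model_name` as a keyword argument, and it uses
    # "distilbert-base-uncased" purely as a stand-in encoder; the actual
    # checkpoint may wrap a different base model.
    from transformers import AutoTokenizer

    config = LeafConfig(model_name="distilbert-base-uncased")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = LeafModel(config).to(device).eval()
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)

    batch = tokenizer(["an example sentence to classify"],
                      return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        out = model(input_ids=batch["input_ids"],
                    attention_mask=batch["attention_mask"])
    print(out["class_idx"], out["ef_score"])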