from huggingface_hub import PyTorchModelHubMixin

import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import Union


class FFNN(nn.Module, PyTorchModelHubMixin):
    """Feed-forward classifier head that operates on precomputed embeddings."""

    def __init__(self, config: dict) -> None:
        super().__init__()

        self.input_layer = nn.Linear(config["embedding_dim"], config["hidden_dim"])

        # num_layers counts the input layer, so add num_layers - 1 hidden layers.
        self.hidden_layers = nn.ModuleList(
            [
                nn.Linear(config["hidden_dim"], config["hidden_dim"])
                for _ in range(config["num_layers"] - 1)
            ]
        )

        self.output_layer = nn.Linear(config["hidden_dim"], config["output_dim"])

        # id2label is keyed by string ids, matching how the mapping
        # round-trips through a serialized JSON config.
        self.id2label = config["id2label"]
        self.label2id = config["label2id"]

    def forward(self, embeddings: torch.Tensor) -> torch.Tensor:
        z = F.relu(self.input_layer(embeddings))
        for hidden_layer in self.hidden_layers:
            z = F.relu(hidden_layer(z))
        output = self.output_layer(z)
        # Softmax over the class dimension (dim=-1), not dim=0, so batched
        # 2D inputs are normalized per row rather than across the batch.
        # Note: the helpers below call these values "logits", but they are
        # post-softmax probabilities; argmax is unaffected either way.
        return F.softmax(output, dim=-1)

    def convert_logits_to_top_ids(self, logits: torch.Tensor) -> list[int]:
        if logits.dim() not in (1, 2):
            raise ValueError("logits must be a 1- or 2-dimensional tensor")

        # Promote a single row to a batch of one so the loop below is uniform.
        if logits.dim() == 1:
            logits = logits.unsqueeze(0)

        return [logits_row.argmax().item() for logits_row in logits]

    def convert_logits_to_labels(self, logits: torch.Tensor) -> list[str]:
        if logits.dim() not in (1, 2):
            raise ValueError("logits must be a 1- or 2-dimensional tensor")

        if logits.dim() == 1:
            logits = logits.unsqueeze(0)

        # id2label uses string keys, so cast the argmax id before lookup.
        return [
            self.id2label[str(logits_row.argmax().item())] for logits_row in logits
        ]

    def predict(
        self, embeddings: torch.Tensor, return_ids: bool = False
    ) -> Union[list[str], list[int]]:
        if embeddings.dim() not in (1, 2):
            raise ValueError("embeddings must be a 1- or 2-dimensional tensor")

        with torch.no_grad():
            logits = self.forward(embeddings)

            if return_ids:
                return self.convert_logits_to_top_ids(logits)

            return self.convert_logits_to_labels(logits)

    def generate_labeled_logits(
        self, embeddings: torch.Tensor
    ) -> list[dict[str, float]]:
        if embeddings.dim() not in (1, 2):
            raise ValueError("embeddings must be a 1- or 2-dimensional tensor")

        with torch.no_grad():
            logits = self.forward(embeddings)

        if logits.dim() == 1:
            logits = logits.unsqueeze(0)

        # Map each class score to its human-readable label; .item() converts
        # the 0-dim tensor to a plain Python float so the dicts are
        # JSON-serializable.
        return [
            {
                self.id2label[str(label_id)]: logit.item()
                for label_id, logit in enumerate(logits_row)
            }
            for logits_row in logits
        ]