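"""Feed-forward classifier (FFNN) with Hugging Face Hub serialization.

The class below wraps a simple MLP over precomputed embeddings and exposes
helpers to turn its outputs into label ids, label strings, or per-label scores.
"""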
from typing import Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import PyTorchModelHubMixin

class FFNN(nn.Module, PyTorchModelHubMixin):
    """Feed-forward classifier over precomputed embeddings.

    PyTorchModelHubMixin adds `save_pretrained` / `from_pretrained` /
    `push_to_hub`, so the model and its config round-trip through the Hub.
    """

    def __init__(self, config: dict) -> None:
        super().__init__()
        self.input_layer = nn.Linear(config["embedding_dim"], config["hidden_dim"])
        # The input layer counts as layer 1, so num_layers - 1 hidden layers follow.
        self.hidden_layers = nn.ModuleList()
        for _ in range(1, config["num_layers"]):
            self.hidden_layers.append(
                nn.Linear(config["hidden_dim"], config["hidden_dim"])
            )
        self.output_layer = nn.Linear(config["hidden_dim"], config["output_dim"])
        # Label maps are keyed by stringified ids, matching JSON-serialized configs.
        self.id2label = config["id2label"]
        self.label2id = config["label2id"]
    def forward(self, embeddings: torch.Tensor) -> torch.Tensor:
        z = F.relu(self.input_layer(embeddings))
        for hidden_layer in self.hidden_layers:
            z = F.relu(hidden_layer(z))
        output = self.output_layer(z)
        # Normalize over the class dimension (the last one), not dim=0,
        # which would mix rows when a 2-D batch is passed in.
        return F.softmax(output, dim=-1)
    def convert_logits_to_top_ids(self, logits: torch.Tensor) -> list[int]:
        if logits.dim() not in (1, 2):
            raise ValueError("logits must be a 1- or 2-dimensional tensor")
        if logits.dim() == 1:
            # Treat a single row as a batch of one.
            logits = logits.unsqueeze(0)
        return [logits_row.argmax().item() for logits_row in logits]
    def convert_logits_to_labels(self, logits: torch.Tensor) -> list[str]:
        # Reuse the id lookup, then map each top id through id2label.
        return [
            self.id2label[str(top_id)]
            for top_id in self.convert_logits_to_top_ids(logits)
        ]
    def predict(
        self, embeddings: torch.Tensor, return_ids: bool = False
    ) -> Union[list[str], list[int]]:
        if embeddings.dim() not in (1, 2):
            raise ValueError("embeddings must be a 1- or 2-dimensional tensor")
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            logits = self.forward(embeddings)
        if return_ids:
            return self.convert_logits_to_top_ids(logits)
        return self.convert_logits_to_labels(logits)
    def generate_labeled_logits(
        self, embeddings: torch.Tensor
    ) -> list[dict[str, float]]:
        if embeddings.dim() not in (1, 2):
            raise ValueError("embeddings must be a 1- or 2-dimensional tensor")
        with torch.no_grad():
            logits = self.forward(embeddings)
        if logits.dim() == 1:
            logits = logits.unsqueeze(0)
        # One {label: score} dict per input row. Despite the name, these are
        # post-softmax probabilities, since forward() already applies softmax.
        labeled_logits_list = []
        for logits_row in logits:
            labeled_logits = {
                self.id2label[str(label_id)]: logit.item()
                for label_id, logit in enumerate(logits_row)
            }
            labeled_logits_list.append(labeled_logits)
        return labeled_logits_list
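
# Minimal usage sketch. Everything below is illustrative: the config values
# and label names are assumptions, not values from any released checkpoint.
# A trained model would normally be loaded with
# FFNN.from_pretrained("<repo_id>") via PyTorchModelHubMixin instead.
if __name__ == "__main__":
    config = {
        "embedding_dim": 768,  # hypothetical embedding width
        "hidden_dim": 256,     # hypothetical
        "num_layers": 3,       # input layer + 2 hidden layers
        "output_dim": 2,       # hypothetical binary classifier
        "id2label": {"0": "negative", "1": "positive"},  # hypothetical labels
        "label2id": {"negative": 0, "positive": 1},
    }
    model = FFNN(config)

    batch = torch.randn(4, config["embedding_dim"])  # fake embeddings
    print(model.predict(batch))                   # e.g. ['positive', 'negative', ...]
    print(model.predict(batch, return_ids=True))  # e.g. [1, 0, ...]
    print(model.generate_labeled_logits(batch[:1]))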