from typing import Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import PyTorchModelHubMixin


class FFNN(nn.Module, PyTorchModelHubMixin):
    def __init__(self, config: dict) -> None:
        super().__init__()
        self.input_layer = nn.Linear(config["embedding_dim"], config["hidden_dim"])
        # num_layers counts the input layer plus (num_layers - 1) hidden layers.
        self.hidden_layers = nn.ModuleList()
        for _ in range(config["num_layers"] - 1):
            self.hidden_layers.append(
                nn.Linear(config["hidden_dim"], config["hidden_dim"])
            )
        self.output_layer = nn.Linear(config["hidden_dim"], config["output_dim"])
        self.id2label = config["id2label"]
        self.label2id = config["label2id"]

    def forward(self, embeddings: torch.Tensor) -> torch.Tensor:
        z = F.relu(self.input_layer(embeddings))
        for hidden_layer in self.hidden_layers:
            z = F.relu(hidden_layer(z))
        output = self.output_layer(z)
        # Normalize over the class dimension; dim=-1 works for both a single
        # embedding (1-D input) and a batch (2-D input), whereas dim=0 would
        # wrongly normalize across the batch. Note the output is therefore
        # softmax probabilities, even though downstream methods call it "logits".
        return F.softmax(output, dim=-1)

    def convert_logits_to_top_ids(self, logits: torch.Tensor) -> list[int]:
        if logits.dim() not in (1, 2):
            raise ValueError("logits must be a 1- or 2-dimensional tensor")
        if logits.dim() == 1:
            logits = logits.unsqueeze(0)  # treat a single row as a batch of one
        return [logits_row.argmax().item() for logits_row in logits]

    def convert_logits_to_labels(self, logits: torch.Tensor) -> list[str]:
        if logits.dim() not in (1, 2):
            raise ValueError("logits must be a 1- or 2-dimensional tensor")
        if logits.dim() == 1:
            logits = logits.unsqueeze(0)
        # id2label keys are strings because the config round-trips through JSON.
        return [self.id2label[str(row.argmax().item())] for row in logits]

    def predict(
        self, embeddings: torch.Tensor, return_ids: bool = False
    ) -> Union[list[str], list[int]]:
        if embeddings.dim() not in (1, 2):
            raise ValueError("embeddings must be a 1- or 2-dimensional tensor")
        with torch.no_grad():
            logits = self.forward(embeddings)
        if return_ids:
            return self.convert_logits_to_top_ids(logits)
        return self.convert_logits_to_labels(logits)

    def generate_labeled_logits(
        self, embeddings: torch.Tensor
    ) -> list[dict[str, float]]:
        if embeddings.dim() not in (1, 2):
            raise ValueError("embeddings must be a 1- or 2-dimensional tensor")
        with torch.no_grad():
            logits = self.forward(embeddings)
        if logits.dim() == 1:
            logits = logits.unsqueeze(0)
        labeled_logits_list = []
        for logits_row in logits:
            # .item() converts each 0-d tensor to a plain float, matching the
            # declared return type.
            labeled_logits = {
                self.id2label[str(label_id)]: logit.item()
                for label_id, logit in enumerate(logits_row)
            }
            labeled_logits_list.append(labeled_logits)
        return labeled_logits_list
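

# A minimal usage sketch, not part of the class above. The config values and
# the "user/ffnn-demo" repo id are hypothetical placeholders. save_pretrained,
# from_pretrained, and push_to_hub are provided by PyTorchModelHubMixin;
# depending on the huggingface_hub version, the config dict may be captured
# from __init__ automatically or may need to be passed explicitly via the
# `config=` argument of save_pretrained.
if __name__ == "__main__":
    config = {
        "embedding_dim": 768,
        "hidden_dim": 256,
        "output_dim": 2,
        "num_layers": 3,
        "id2label": {"0": "negative", "1": "positive"},
        "label2id": {"negative": 0, "positive": 1},
    }
    model = FFNN(config)

    # Batch of 4 random stand-in embeddings; with untrained weights the
    # predicted labels are arbitrary.
    embeddings = torch.randn(4, config["embedding_dim"])
    print(model.predict(embeddings))                      # e.g. ["positive", ...]
    print(model.predict(embeddings, return_ids=True))     # e.g. [1, 0, 1, 1]
    print(model.generate_labeled_logits(embeddings[0]))   # [{"negative": ..., "positive": ...}]

    # Persist locally, then reload through the mixin.
    model.save_pretrained("ffnn-demo", config=config)
    reloaded = FFNN.from_pretrained("ffnn-demo")
    # Pushing to the Hugging Face Hub requires authentication:
    # model.push_to_hub("user/ffnn-demo")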