from huggingface_hub import PyTorchModelHubMixin
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Union


class FFNN(nn.Module, PyTorchModelHubMixin):
    def __init__(self, config: dict) -> None:
        super().__init__()
        # Input projection from the embedding space to the hidden dimension.
        self.input_layer = nn.Linear(config["embedding_dim"], config["hidden_dim"])
        # Remaining (num_layers - 1) hidden layers, all hidden_dim -> hidden_dim.
        self.hidden_layers = nn.ModuleList()
        for _ in range(1, config["num_layers"]):
            self.hidden_layers.append(
                nn.Linear(config["hidden_dim"], config["hidden_dim"])
            )
        self.output_layer = nn.Linear(config["hidden_dim"], config["output_dim"])
        # Mappings between class ids (stored as string keys) and readable labels.
        self.id2label = config["id2label"]
        self.label2id = config["label2id"]

    def forward(self, embeddings: torch.Tensor) -> torch.Tensor:
        z = F.relu(self.input_layer(embeddings))
        for hidden_layer in self.hidden_layers:
            z = F.relu(hidden_layer(z))
        output = self.output_layer(z)
        # Softmax over the class dimension (the last one), so batched 2D input
        # is normalized per row rather than across the batch.
        return F.softmax(output, dim=-1)

    def convert_logits_to_top_ids(self, logits: torch.Tensor) -> list[int]:
        if logits.dim() not in (1, 2):
            raise ValueError("logits must either be a 1 or 2 dimensional tensor")
        if logits.dim() == 1:
            logits = logits.unsqueeze(0)
        return [logits_row.argmax().item() for logits_row in logits]

    def convert_logits_to_labels(self, logits: torch.Tensor) -> list[str]:
        if logits.dim() not in (1, 2):
            raise ValueError("logits must either be a 1 or 2 dimensional tensor")
        if logits.dim() == 1:
            logits = logits.unsqueeze(0)
        labels = []
        for logits_row in logits:
            labels.append(self.id2label[str(logits_row.argmax().item())])
        return labels

    def predict(
        self, embeddings: torch.Tensor, return_ids: bool = False
    ) -> Union[list[str], list[int]]:
        if embeddings.dim() not in (1, 2):
            raise ValueError("embeddings must either be a 1 or 2 dimensional tensor")
        with torch.no_grad():
            logits = self.forward(embeddings)
        if return_ids:
            return self.convert_logits_to_top_ids(logits)
        return self.convert_logits_to_labels(logits)

    def generate_labeled_logits(
        self, embeddings: torch.Tensor
    ) -> list[dict[str, float]]:
        if embeddings.dim() not in (1, 2):
            raise ValueError("embeddings must either be a 1 or 2 dimensional tensor")
        with torch.no_grad():
            logits = self.forward(embeddings)
        if logits.dim() == 1:
            logits = logits.unsqueeze(0)
        labeled_logits_list = []
        for logits_row in logits:
            labeled_logits = {}
            for label_id, logit in enumerate(logits_row):
                # .item() converts the 0-dim tensor to a plain Python float,
                # matching the declared return type.
                labeled_logits[self.id2label[str(label_id)]] = logit.item()
            labeled_logits_list.append(labeled_logits)
        return labeled_logits_list
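

# --- Usage sketch (not part of the model definition) ---
# A minimal, hedged example of how this class might be instantiated and used.
# The config values, repo id, and embedding tensors below are illustrative
# placeholders, not values taken from any particular trained checkpoint.
if __name__ == "__main__":
    config = {
        "embedding_dim": 384,  # must match the upstream embedding model
        "hidden_dim": 128,
        "num_layers": 2,
        "output_dim": 2,
        "id2label": {"0": "negative", "1": "positive"},
        "label2id": {"negative": 0, "positive": 1},
    }
    model = FFNN(config)
    model.eval()

    # Single embedding (1D) and a batch of embeddings (2D) are both accepted.
    single = torch.randn(config["embedding_dim"])
    batch = torch.randn(4, config["embedding_dim"])
    print(model.predict(single))                  # e.g. ["positive"]
    print(model.predict(batch, return_ids=True))  # e.g. [1, 0, 1, 1]
    print(model.generate_labeled_logits(single))  # [{"negative": ..., "positive": ...}]

    # Because the class mixes in PyTorchModelHubMixin, it can also be pushed to
    # and reloaded from the Hugging Face Hub; "user/ffnn-classifier" is a placeholder.
    # model.push_to_hub("user/ffnn-classifier")
    # model = FFNN.from_pretrained("user/ffnn-classifier")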