# CS-4700-Demo / FFNN.py
from huggingface_hub import PyTorchModelHubMixin
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Union
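
# The constructor takes a plain config dict. A sketch of its expected shape,
# inferred from the keys read in __init__ (the concrete values here are
# hypothetical, not the demo's actual settings):
#
#   {
#       "embedding_dim": 768,  # size of each input embedding
#       "hidden_dim": 256,     # width of every hidden layer
#       "num_layers": 3,       # 1 input layer + (num_layers - 1) hidden layers
#       "output_dim": 2,       # number of classes
#       "id2label": {"0": "negative", "1": "positive"},  # string-keyed ids
#       "label2id": {"negative": 0, "positive": 1},
#   }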
class FFNN(nn.Module, PyTorchModelHubMixin):
    def __init__(self, config: dict) -> None:
        super().__init__()
        self.input_layer = nn.Linear(config["embedding_dim"], config["hidden_dim"])
        # One input projection plus (num_layers - 1) hidden layers of equal width.
        self.hidden_layers = nn.ModuleList(
            nn.Linear(config["hidden_dim"], config["hidden_dim"])
            for _ in range(config["num_layers"] - 1)
        )
        self.output_layer = nn.Linear(config["hidden_dim"], config["output_dim"])
        # Label maps are keyed by stringified ids, following the Hugging Face
        # config convention (JSON object keys are always strings).
        self.id2label = config["id2label"]
        self.label2id = config["label2id"]

    def forward(self, embeddings: torch.Tensor) -> torch.Tensor:
        z = F.relu(self.input_layer(embeddings))
        for hidden_layer in self.hidden_layers:
            z = F.relu(hidden_layer(z))
        output = self.output_layer(z)
        # Normalize over the class dimension (the last one), so batched 2-D
        # input is handled correctly; dim=0 would normalize across the batch.
        # Note the result is softmax probabilities, not raw logits, even though
        # the methods below call it "logits".
        return F.softmax(output, dim=-1)

    def convert_logits_to_top_ids(self, logits: torch.Tensor) -> list[int]:
        if logits.dim() not in (1, 2):
            raise ValueError("logits must be a 1- or 2-dimensional tensor")
        if logits.dim() == 1:
            logits = logits.unsqueeze(0)  # treat a single row as a batch of one
        return [logits_row.argmax().item() for logits_row in logits]

    def convert_logits_to_labels(self, logits: torch.Tensor) -> list[str]:
        if logits.dim() not in (1, 2):
            raise ValueError("logits must be a 1- or 2-dimensional tensor")
        if logits.dim() == 1:
            logits = logits.unsqueeze(0)
        # id2label is string-keyed, so cast each argmax id to str before lookup.
        return [
            self.id2label[str(logits_row.argmax().item())] for logits_row in logits
        ]

    def predict(
        self, embeddings: torch.Tensor, return_ids: bool = False
    ) -> Union[list[str], list[int]]:
        if embeddings.dim() not in (1, 2):
            raise ValueError("embeddings must be a 1- or 2-dimensional tensor")
        with torch.no_grad():
            logits = self.forward(embeddings)
        if return_ids:
            return self.convert_logits_to_top_ids(logits)
        return self.convert_logits_to_labels(logits)

    def generate_labeled_logits(
        self, embeddings: torch.Tensor
    ) -> list[dict[str, float]]:
        if embeddings.dim() not in (1, 2):
            raise ValueError("embeddings must be a 1- or 2-dimensional tensor")
        with torch.no_grad():
            logits = self.forward(embeddings)
        if logits.dim() == 1:
            logits = logits.unsqueeze(0)
        # One {label: score} dict per input row, with scores as Python floats.
        return [
            {
                self.id2label[str(label_id)]: logit.item()
                for label_id, logit in enumerate(logits_row)
            }
            for logits_row in logits
        ]
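

# A minimal usage sketch under the hypothetical config sketched above. FFNN
# also inherits from_pretrained()/push_to_hub() via PyTorchModelHubMixin, but
# those need a real Hub repo, so only local inference is shown here.
if __name__ == "__main__":
    config = {
        "embedding_dim": 768,
        "hidden_dim": 256,
        "num_layers": 3,
        "output_dim": 2,
        "id2label": {"0": "negative", "1": "positive"},
        "label2id": {"negative": 0, "positive": 1},
    }
    model = FFNN(config)
    model.eval()

    # Two random vectors stand in for real encoder embeddings.
    embeddings = torch.randn(2, config["embedding_dim"])
    print(model.predict(embeddings))                   # e.g. ["positive", "negative"]
    print(model.predict(embeddings, return_ids=True))  # e.g. [1, 0]
    print(model.generate_labeled_logits(embeddings))   # per-class scores per row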