|
|
from typing import Dict
|
|
|
from torch import nn
|
|
|
import torch
|
|
|
from huggingface_hub import PyTorchModelHubMixin
|
|
|
|
|
|
class ClassifierHead(
    nn.Module,
    PyTorchModelHubMixin,
    repo_url="https://huggingface.co/davidgray/health-query-triage",
    pipeline_tag="text-classification",
    library_name="PyTorch",
    tags=["medical", "classification"],
):
    """
    Feed-forward classification head applied to sentence embeddings.

    The network is a stack of three linear layers (embedding_dim -> 512 -> 512
    -> num_classes); each of the first two is followed by ELU and Dropout(0.5).
    The keyword arguments in the class statement are handed to
    ``PyTorchModelHubMixin`` as Hub metadata.

    Args:
        num_classes (int): Number of output classes (size of the logits).
        embedding_dim (int): Size of each input embedding. Defaults to 768.
    """

    def __init__(self, num_classes: int, embedding_dim: int = 768):
        super().__init__()

        # NOTE: the attribute names below become state_dict key prefixes;
        # renaming them would break loading of previously saved weights.
        self.linear_elu_stack = nn.Sequential(
            nn.Linear(embedding_dim, 512),
            nn.ELU(),
            nn.Dropout(0.5),
            nn.Linear(512, 512),
            nn.ELU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes),
        )

        # Normalises over the last (class) dimension.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Calculates logits from the sentence embedding.

        No softmax is applied here; the returned values are raw logits.

        Args:
            features (Dict[str, torch.Tensor]): Output dictionary from the Sentence Transformer body,
                containing 'sentence_embedding'.

        Returns:
            Dict[str, torch.Tensor]: Dictionary with the 'logits' key.
        """
        embeddings = features['sentence_embedding']
        logits = self.linear_elu_stack(embeddings)
        return {"logits": logits}

    def predict(self, embeddings: torch.Tensor) -> torch.Tensor:
        """
        Classifies embeddings into integer labels in the range [0, num_classes).

        Delegates to ``predict_proba`` and takes the argmax over the last
        dimension.

        Args:
            embeddings (torch.Tensor): Tensor with shape [num_inputs, embedding_size].

        Returns:
            torch.Tensor: Integer labels with shape [num_inputs].
        """
        proba = self.predict_proba(embeddings)
        return torch.argmax(proba, dim=-1)

    def predict_proba(self, embeddings: torch.Tensor) -> torch.Tensor:
        """
        Classifies embeddings into probabilities for each class (summing to 1).

        Runs in eval mode (dropout disabled) with gradient tracking turned
        off. The module's previous train/eval mode is restored afterwards,
        even if the forward pass raises.

        Args:
            embeddings (torch.Tensor): Tensor with shape [num_inputs, embedding_size].

        Returns:
            torch.Tensor: Float probabilities with shape [num_inputs, num_classes].
        """
        # Remember the current mode so inference does not silently flip an
        # eval-mode model back into training mode (which would re-enable
        # dropout for every later forward pass).
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                logits = self.linear_elu_stack(embeddings)
                probabilities = self.softmax(logits)
        finally:
            self.train(was_training)

        return probabilities

    def get_loss_fn(self) -> nn.Module:
        """
        Returns an initialized loss function for training.

        Returns:
            nn.Module: A new ``nn.CrossEntropyLoss`` instance with default settings.
        """
        return nn.CrossEntropyLoss()
|
|
|
|