# FinDeBERTa / handler.py
# Uploaded by ritessshhh via huggingface_hub ("Upload folder using huggingface_hub"),
# commit 3e4ee15 (verified).
import torch
import numpy as np
from typing import Dict, List, Any
class EndpointHandler:
    """Inference Endpoint handler for per-class-thresholded classification.

    Every class is scored independently with a sigmoid, and a class is reported
    only when its probability reaches that class's threshold loaded from
    ``thresholds.npy``.
    """

    def __init__(self, path=""):
        """Load the model, tokenizer and per-class thresholds from *path*.

        Args:
            path: Directory holding the saved model, tokenizer and
                ``thresholds.npy``. The empty default resolves against the
                current working directory.

        Raises:
            ValueError: If ``thresholds.npy`` holds fewer thresholds than the
                model has labels (a single value is broadcast to all labels).
        """
        import os

        from transformers import AutoModelForSequenceClassification, AutoTokenizer

        self.model = AutoModelForSequenceClassification.from_pretrained(path)
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model.eval()

        # os.path.join keeps the empty default path relative ("thresholds.npy")
        # instead of producing "/thresholds.npy" at the filesystem root.
        thresholds = np.asarray(np.load(os.path.join(path, "thresholds.npy"))).reshape(-1)
        num_labels = self.model.config.num_labels
        if thresholds.size == 1:
            # One shared threshold: apply it to every class.
            thresholds = np.full(num_labels, thresholds[0])
        elif thresholds.size < num_labels:
            # Fail at load time rather than with an IndexError on a request.
            raise ValueError(
                f"thresholds.npy has {thresholds.size} entries but the model "
                f"has {num_labels} labels"
            )
        self.thresholds = thresholds

    def __call__(self, data: Dict[str, Any]) -> List[Any]:
        """Classify the text(s) found under ``data["inputs"]``.

        Args:
            data: Request payload. ``data["inputs"]`` is a single string or a
                list of strings; if the key is absent, ``data`` itself is used.
                The payload is not modified.

        Returns:
            For a single string: a list of ``{"label": str, "score": float}``
            dicts for the classes whose probability meets their threshold,
            sorted by score descending. For a list of strings: one such list
            per input, in input order (``[]`` for an empty list).
        """
        inputs_text = data.get("inputs", data)
        is_batch = isinstance(inputs_text, (list, tuple))
        if is_batch:
            inputs_text = list(inputs_text)
            if not inputs_text:
                return []

        inputs = self.tokenizer(
            inputs_text,
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=128,
        )

        with torch.no_grad():
            logits = self.model(**inputs).logits
        # Independent sigmoid per class (not softmax): several labels may pass.
        probs = torch.sigmoid(logits).cpu().numpy()

        # Score every row of the batch, not just the first one.
        results = [self._select_labels(row) for row in probs]
        return results if is_batch else results[0]

    def _select_labels(self, probs: np.ndarray) -> List[Dict[str, Any]]:
        """Return the labels whose probability meets its class threshold, best first.

        Args:
            probs: 1-D array of per-class probabilities for one input.
        """
        id2label = self.model.config.id2label
        passed = np.flatnonzero(probs >= self.thresholds[: probs.shape[0]])
        predictions = [
            {"label": id2label[int(idx)], "score": float(probs[idx])}
            for idx in passed
        ]
        # Stable sort: ties keep class-index order, as before.
        predictions.sort(key=lambda p: p["score"], reverse=True)
        return predictions