email-classifier-bert / handler.py
vertigoq3's picture
Add inference endpoint handler
8c3fc6d verified
"""
Handler para el Inference Endpoint del clasificador de emails
"""
import torch
import numpy as np
import pickle
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from huggingface_hub import hf_hub_download
class EndpointHandler:
    """Custom Inference Endpoint handler for the email classifier.

    The handler loads a sequence-classification model, its tokenizer and a
    pickled label encoder, then classifies one text per call.

    Args:
        path: Optional local directory holding the repository files. The
            Inference Endpoints toolkit passes this when it instantiates the
            handler. When the directory contains the model files they are
            loaded from disk; otherwise everything is fetched from the Hub
            repository ``_REPO_ID``. Calling ``EndpointHandler()`` with no
            argument keeps the Hub-only behavior.
    """

    # Hub repository hosting the model, tokenizer and label encoder.
    _REPO_ID = "vertigoq3/email-classifier-bert"
    _ENCODER_FILENAME = "label_encoder.pkl"

    def __init__(self, path=""):
        self.path = path
        self.model = None
        self.tokenizer = None
        self.encoder = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.load_model()

    @staticmethod
    def _to_native(value):
        """Return numpy-style scalars as plain Python objects; pass others through."""
        item = getattr(value, "item", None)
        return item() if callable(item) else value

    def load_model(self):
        """Load the model, tokenizer and label encoder onto ``self.device``.

        Raises:
            Exception: Any loading failure is printed and re-raised so the
                endpoint fails at startup instead of serving without a model.
        """
        # Local import keeps this block independent of the file-level imports.
        import os

        try:
            local_dir = getattr(self, "path", "") or ""

            # Only trust the local directory if it really contains a model;
            # otherwise fall back to the Hub repository as before.
            has_local_model = bool(local_dir) and os.path.isfile(
                os.path.join(local_dir, "config.json")
            )
            source = local_dir if has_local_model else self._REPO_ID

            self.model = AutoModelForSequenceClassification.from_pretrained(source)
            self.tokenizer = AutoTokenizer.from_pretrained(source)

            # Move to the selected device and switch to inference mode.
            self.model.to(self.device)
            self.model.eval()

            # Prefer the label encoder shipped next to the model, if present.
            local_encoder = (
                os.path.join(local_dir, self._ENCODER_FILENAME) if local_dir else ""
            )
            if local_encoder and os.path.isfile(local_encoder):
                encoder_path = local_encoder
            else:
                encoder_path = hf_hub_download(
                    repo_id=self._REPO_ID,
                    filename=self._ENCODER_FILENAME,
                )

            # SECURITY: pickle.load can execute arbitrary code. This is only
            # acceptable because the file comes from the model's own trusted
            # repository; never point this at an untrusted source.
            with open(encoder_path, "rb") as f:
                self.encoder = pickle.load(f)
        except Exception as e:
            print(f"Error al cargar modelo: {e}")
            raise

    def __call__(self, inputs):
        """Classify a single text and return the prediction.

        Args:
            inputs: Either the text itself, a dict carrying the text under the
                ``"inputs"`` key, or any other object (classified via
                ``str(inputs)``).

        Returns:
            On success, a dict with ``"predicted_class"``, ``"confidence"``
            and ``"all_probabilities"`` (label -> probability), using native
            Python values. On any failure, ``{"error": <message>}``; errors
            are reported in the response rather than raised.
        """
        try:
            if isinstance(inputs, str):
                text = inputs
            elif isinstance(inputs, dict) and "inputs" in inputs:
                text = inputs["inputs"]
            else:
                text = str(inputs)

            # Tokenize and move every tensor to the model's device.
            tokenized = self.tokenizer(
                text,
                return_tensors="pt",
                truncation=True,
                padding=True,
                max_length=512,
            )
            tokenized = {k: v.to(self.device) for k, v in tokenized.items()}

            # Classify without tracking gradients.
            with torch.no_grad():
                logits = self.model(**tokenized).logits
                probabilities = torch.softmax(logits, dim=-1)

            # .item() requires a single prediction; a multi-row batch raises
            # here and is reported through the error response below.
            predicted_class_id = torch.argmax(probabilities, dim=-1).item()

            # One conversion to Python floats instead of one per class.
            scores = probabilities[0].tolist()

            predicted_class = self._to_native(
                self.encoder.inverse_transform([predicted_class_id])[0]
            )
            return {
                "predicted_class": predicted_class,
                "confidence": scores[predicted_class_id],
                "all_probabilities": {
                    self._to_native(label): scores[i]
                    for i, label in enumerate(self.encoder.classes_)
                },
            }
        except Exception as e:
            return {"error": str(e)}
# Create the module-level handler instance; the constructor calls load_model(),
# so the model is loaded as soon as this module is imported.
handler = EndpointHandler()