Text Classification
PyTorch
Safetensors
English
eurovoc
Inference Endpoints
eurovoc_eu / handler.py
scampion's picture
Update handler.py
988762a
raw
history blame
2.82 kB
import logging
import pickle
from typing import Dict, List, Any

import numpy as np
import torch
from eurovoc import EurovocTagger
from sklearn.preprocessing import MultiLabelBinarizer
from transformers import AutoTokenizer
# Model identifier handed to AutoTokenizer.from_pretrained below and to
# EurovocTagger.from_pretrained (as bert_model_name) in EndpointHandler.
BERT_MODEL_NAME = "EuropeanParliament/EUBERT"
# Used in two different roles: as the tokenizer's max_length (a token count)
# and as the size, in characters, of each text chunk in
# EndpointHandler.get_prediction.
MAX_LEN = 512
# Upper bound on how many leading characters of the input are chunked and
# scored (50 chunks of MAX_LEN characters); text beyond it is ignored.
TEXT_MAX_LEN = MAX_LEN * 50
# Module-level tokenizer, created once at import time and used by
# EndpointHandler._get_prediction.
tokenizer = AutoTokenizer.from_pretrained(BERT_MODEL_NAME)
class EndpointHandler:
    """Inference handler that tags a text with EuroVoc labels.

    The handler loads a fitted label binarizer and an ``EurovocTagger`` model
    from a directory, splits incoming text into fixed-size character chunks,
    scores every chunk with the model and averages the per-chunk scores.
    """

    # Class-level placeholder; every instance replaces it in ``__init__`` with
    # the fitted binarizer unpickled from ``mlb.pickle``.
    mlb = MultiLabelBinarizer()

    def __init__(self, path=""):
        """Load the label binarizer and the tagger model from ``path``.

        Args:
            path: Directory containing ``mlb.pickle`` and the model weights
                understood by ``EurovocTagger.from_pretrained``.
        """
        # NOTE(security): unpickling executes arbitrary code, so ``path`` must
        # point at a trusted model directory.
        with open(f"{path}/mlb.pickle", "rb") as mlb_file:
            self.mlb = pickle.load(mlb_file)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = EurovocTagger.from_pretrained(
            path,
            bert_model_name=BERT_MODEL_NAME,
            n_classes=len(self.mlb.classes_),
            map_location=self.device,
        )
        self.model.eval()
        self.model.freeze()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Score a request payload and return the best matching labels.

        Args:
            data: Request payload. Recognised keys (all are popped from the
                dict): ``inputs`` (the text to tag, required), ``topk``
                (maximum number of labels returned, default 5), ``threshold``
                (labels must score strictly above it, default 0.16) and
                ``debug`` (default False).

        Returns:
            ``{"results": [...]}`` where each entry is
            ``{"label": ..., "score": float}``, sorted by descending score.
            With ``debug`` set, the averaged raw scores (``values``) and the
            input text (``input``) are included as well.

        Raises:
            ValueError: If ``inputs`` is missing or is not a string.
        """
        text = data.pop("inputs", data)
        topk = data.pop("topk", 5)
        threshold = data.pop("threshold", 0.16)
        debug = data.pop("debug", False)

        # Without an "inputs" key ``text`` is the payload dict itself, which
        # cannot be chunked; fail with a clear message instead of an obscure
        # slicing/indexing error further down.
        if not isinstance(text, str):
            raise ValueError(
                f'"inputs" must be a string, got {type(text).__name__}'
            )

        prediction = self.get_prediction(text)
        # prediction[0] is paired element-wise with the binarizer's classes.
        scores = prediction[0].tolist()
        results = [
            {"label": label, "score": float(score)}
            for label, score in zip(self.mlb.classes_, scores)
            if score > threshold
        ]
        # list.sort is stable, so ties keep the binarizer's class order.
        results.sort(key=lambda entry: entry["score"], reverse=True)
        results = results[:topk]

        if debug:
            # NOTE(review): ``values`` is a NumPy array; confirm the serving
            # layer's serializer accepts it.
            return {"results": results, "values": prediction, "input": text}
        return {"results": results}

    def get_prediction(self, text):
        """Return the model scores for ``text`` averaged across its chunks.

        Only the first ``TEXT_MAX_LEN`` characters are considered; they are
        cut into consecutive pieces of ``MAX_LEN`` characters (not tokens) and
        each piece is scored independently.

        Args:
            text: The string to score.

        Returns:
            A NumPy array: the element-wise mean of the per-chunk predictions.
        """
        limit = min(len(text), TEXT_MAX_LEN)
        chunks = [text[start:start + MAX_LEN] for start in range(0, limit, MAX_LEN)]
        if not chunks:
            # Empty text yields no chunks; score it as a single empty chunk so
            # the result keeps the usual shape instead of collapsing to NaN.
            chunks = [text]
        predictions = [self._get_prediction(chunk) for chunk in chunks]
        return np.array(predictions).mean(axis=0)

    def _model_device(self):
        """Return the device holding the model's parameters (CPU if unknown)."""
        try:
            return next(self.model.parameters()).device
        except (AttributeError, StopIteration):
            # Tokenizer output is created on CPU, so this leaves it in place.
            return torch.device("cpu")

    def _get_prediction(self, text):
        """Tokenize one chunk and return the model's prediction as NumPy.

        Args:
            text: A single chunk of text; it is truncated or padded to
                ``MAX_LEN`` tokens by the tokenizer.

        Returns:
            The second element of the model's output, moved to CPU and
            converted to a NumPy array.
        """
        item = tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=MAX_LEN,
            return_token_type_ids=False,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt')
        # The tokenizer returns CPU tensors; move them next to the model's
        # weights so the forward pass does not fail on a device mismatch.
        device = self._model_device()
        input_ids = item["input_ids"].to(device)
        attention_mask = item["attention_mask"].to(device)
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            _, prediction = self.model(input_ids, attention_mask)
        prediction = prediction.detach().cpu().numpy()
        # Debug-level log instead of printing; the raw text is deliberately
        # left out so request content does not end up in the logs.
        logging.getLogger(__name__).debug(
            "prediction for chunk of %d characters: %s", len(text), prediction
        )
        return prediction