# flash_gen_bert_para / handler.py
# Last change by cguynup in commit 59fca60: "Update handler.py"
# Load model directly
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
class EndpointHandler():
    """Request handler wrapping a sequence-classification model.

    The tokenizer and model are loaded from ``path`` and wrapped in a
    ``text-classification`` pipeline. Each request is classified item by item,
    and the predicted class index of every item is returned.
    """

    # Maximum tokenized length forwarded to the tokenizer; longer inputs are
    # truncated. NOTE(review): 253 is carried over from the original handler —
    # confirm it matches the sequence length the model was trained with.
    MAX_LENGTH = 253

    def __init__(self, path=""):
        """Load the tokenizer and model from ``path`` and build the pipeline.

        Args:
            path: Directory or model id passed to ``from_pretrained``.
        """
        tokenizer = AutoTokenizer.from_pretrained(path)
        model = AutoModelForSequenceClassification.from_pretrained(path)
        self.pipeline = pipeline("text-classification", model=model, tokenizer=tokenizer)

    @staticmethod
    def _label_to_index(label):
        """Convert a pipeline label such as ``"LABEL_12"`` to its class index.

        Everything after the last underscore is parsed, so multi-digit indices
        work (reading only the last character would map ``"LABEL_10"`` to 0).
        """
        return int(label.rsplit("_", 1)[-1])

    def __call__(self, data):
        """Classify every item of the request payload.

        Args:
            data: Request payload dict. ``data["inputs"]`` is popped if present;
                otherwise the whole payload is used as the input. It may be a
                list of items or a single item (a string or a dict). Items are
                passed unchanged to the text-classification pipeline
                (presumably strings or sentence-pair dicts — confirm against
                the client).

        Returns:
            ``{"pairs": items, "evaluations": labels}``, where ``items`` is the
            input normalized to a list and ``labels[i]`` is the predicted class
            index for ``items[i]``.
        """
        inputs = data.pop("inputs", data)
        # A lone string or dict is a single item. Iterating it directly would
        # classify its individual characters or keys instead.
        items = [inputs] if isinstance(inputs, (str, dict)) else list(inputs)
        if not items:
            # Nothing to classify; skip the model call entirely.
            return {"pairs": items, "evaluations": []}
        outputs = self.pipeline(
            items,
            padding=True,
            truncation=True,
            max_length=self.MAX_LENGTH,
        )
        labels = [self._label_to_index(out["label"]) for out in outputs]
        return {
            "pairs": items,
            "evaluations": labels,
        }