philschmid's picture
philschmid HF staff
Update pipeline.py
770e1d9
raw history blame
No virus
1.36 kB
from typing import Dict, List, Any
from optimum.onnxruntime import ORTModelForFeatureExtraction
from transformers import pipeline, AutoTokenizer
def cls_pooling(model_output):
    """Return the hidden state at sequence position 0 for every batch item.

    Args:
        model_output: Object exposing a ``last_hidden_state`` attribute that
            supports ``[:, 0]`` indexing (presumably a
            ``(batch, seq_len, hidden)`` tensor -- TODO confirm).

    Returns:
        ``last_hidden_state[:, 0]``, i.e. one vector per batch item. Going by
        the function name, position 0 is expected to hold the tokenizer's
        [CLS] token -- verify against the model in use.
    """
    return model_output.last_hidden_state[:, 0]
class PreTrainedPipeline:
    """Feature-extraction pipeline backed by an ONNX Runtime model.

    The instance is callable: it tokenizes the input, runs the model and
    returns the pooled embedding produced by ``cls_pooling``.
    """

    def __init__(self, path=""):
        """Load the model and tokenizer from ``path``.

        Args:
            path: Location passed to both ``from_pretrained`` calls.
        """
        # load the optimized model
        self.model = ORTModelForFeatureExtraction.from_pretrained(path)
        # The tokenizer is created with a maximum length of 128; combined with
        # ``truncation=True`` in ``__call__`` this bounds the input length.
        self.tokenizer = AutoTokenizer.from_pretrained(path, model_max_length=128)

    def __call__(self, inputs: Any) -> Dict[str, List[float]]:
        """Embed ``inputs`` and return the vector of the first sequence.

        Args:
            inputs: Text handed directly to the tokenizer (presumably a single
                string -- TODO confirm against the caller). If a batch is
                passed, only the embedding of its first element is returned.

        Returns:
            A dict ``{"vectors": [...]}`` whose value is the pooled embedding
            of the first sequence as a flat list of floats.
        """
        # tokenize the input
        encoded_input = self.tokenizer(
            inputs, padding="longest", truncation=True, return_tensors="pt"
        )
        # run the model
        model_output = self.model(**encoded_input, return_dict=True)
        embeddings = cls_pooling(model_output)
        # ``tolist()`` already yields plain Python floats, so no per-element
        # conversion is needed. Row 0 is the first (or only) input sequence.
        return {"vectors": embeddings[0].tolist()}