philschmid HF staff committed on
Commit
4865f8a
1 Parent(s): 770e1d9

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +5 -13
pipeline.py CHANGED
@@ -1,6 +1,6 @@
1
- from typing import Dict, List, Any
2
  from optimum.onnxruntime import ORTModelForFeatureExtraction
3
- from transformers import pipeline, AutoTokenizer
4
 
5
  def cls_pooling(model_output):
6
  return model_output.last_hidden_state[:,0]
@@ -11,20 +11,12 @@ class PreTrainedPipeline():
11
  self.tokenizer = AutoTokenizer.from_pretrained(path, model_max_length=128)
12
 
13
 
14
- def __call__(self, inputs: Any) -> Dict[str, Any]:
15
- """
16
- Args:
17
- data (:obj:`str`):
18
- a string containing some text
19
- Return:
20
- A :obj:`list`:. The object returned should be a list of one list like [[{"label": 0.9939950108528137}]] containing :
21
- - "label": A string representing what the label/class is. There can be multiple labels.
22
- - "score": A score between 0 and 1 describing how confident the model is for this label/class.
23
- """
24
  # tokenize the input
25
  encoded_input = self.tokenizer(inputs, padding="longest", truncation=True, return_tensors='pt')
26
  # run the model
27
  model_output = self.model(**encoded_input, return_dict=True)
28
  embeddings = cls_pooling(model_output)
29
 
30
- return {"vectors": [float(vec) for vec in embeddings[0].tolist()]}
 
 
1
+ from typing import Dict,List, Any
2
  from optimum.onnxruntime import ORTModelForFeatureExtraction
3
+ from transformers import AutoTokenizer
4
 
5
  def cls_pooling(model_output):
6
  return model_output.last_hidden_state[:,0]
 
11
  self.tokenizer = AutoTokenizer.from_pretrained(path, model_max_length=128)
12
 
13
 
14
+ def __call__(self, inputs: Any) -> Dict[str, List[float]]:
 
 
 
 
 
 
 
 
 
15
  # tokenize the input
16
  encoded_input = self.tokenizer(inputs, padding="longest", truncation=True, return_tensors='pt')
17
  # run the model
18
  model_output = self.model(**encoded_input, return_dict=True)
19
  embeddings = cls_pooling(model_output)
20
 
21
+ return {"vectors": embeddings[0].tolist()}
22
+