risedev commited on
Commit
8e71a92
1 Parent(s): 5448fc1

Upload handler.py

Browse files
Files changed (1) hide show
  1. handler.py +19 -26
handler.py CHANGED
@@ -1,34 +1,27 @@
1
- from typing import Dict, List, Any
2
- from optimum.onnxruntime import ORTModelForSequenceClassification
3
- from transformers import pipeline, AutoTokenizer
 
 
4
 
5
 
6
  class EndpointHandler():
7
  def __init__(self, path=""):
8
- # load the optimized model
9
- model = ORTModelForSequenceClassification.from_pretrained(path)
10
- tokenizer = AutoTokenizer.from_pretrained(path)
11
- # create inference pipeline
12
- self.pipeline = pipeline("zero-shot-image-classification", model=model, tokenizer=tokenizer)
13
-
14
-
15
- def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
16
  """
17
- Args:
18
- data (:obj:):
19
- includes the input data and the parameters for the inference.
20
- Return:
21
- A :obj:`list`:. The object returned should be a list of one list like [[{"label": 0.9939950108528137}]] containing :
22
- - "label": A string representing what the label/class is. There can be multiple labels.
23
- - "score": A score between 0 and 1 describing how confident the model is for this label/class.
24
  """
25
  inputs = data.pop("inputs", data)
26
- parameters = data.pop("parameters", None)
27
 
28
- # pass inputs with all kwargs in data
29
- if parameters is not None:
30
- prediction = self.pipeline(inputs, **parameters)
31
- else:
32
- prediction = self.pipeline(inputs)
33
- # postprocess the prediction
34
- return prediction
 
1
+ from typing import Dict, List, Any
2
+ from PIL import Image
3
+ from io import BytesIO
4
+ from transformers import pipeline
5
+ import base64
6
 
7
 
8
  class EndpointHandler():
9
  def __init__(self, path=""):
10
+ self.pipeline=pipeline("zero-shot-image-classification",model=path)
11
+
12
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
 
 
 
 
 
13
  """
14
+ data args:
15
+ images (:obj:`string`)
16
+ candiates (:obj:`list`)
17
+ Return:
18
+ A :obj:`list`:. The list contains items that are dicts should be liked {"label": "XXX", "score": 0.82}
 
 
19
  """
20
  inputs = data.pop("inputs", data)
 
21
 
22
+ # decode base64 image to PIL
23
+ image = Image.open(BytesIO(base64.b64decode(inputs['image'])))
24
+
25
+ # run prediction one image wit provided candiates
26
+ prediction = self.pipeline(images=[image], candidate_labels=inputs["candiates"])
27
+ return prediction[0]