File size: 1,156 Bytes
d7d7e6d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from typing import Dict, List, Any
import mii
import json

class EndpointHandler():
    """Request handler that serves a text-classification model through DeepSpeed-MII.

    Construction deploys the model with ``mii.deploy`` and keeps a query
    handle to that deployment. Calling the instance forwards one request
    payload to the deployment and returns the reply's ``response`` field.
    """

    def __init__(self, path="", *, model="philschmid/finbert-tone-endpoint-ds",
                 deployment_name="bert", dtype="fp16"):
        """Deploy ``model`` with MII and open a query handle to it.

        Args:
            path: Accepted but not used by this handler; the model is taken
                from ``model`` instead.
            model: Model id passed to ``mii.deploy``.
            deployment_name: Name given to the MII deployment; the same name
                is used to open the query handle (stored as ``deploy_name``).
            dtype: Value stored under the ``"dtype"`` key of ``mii_config``.
        """
        self.deploy_name = deployment_name
        mii_config = {"dtype": dtype}
        mii.deploy(task='text-classification',
                   model=model,
                   deployment_name=self.deploy_name,
                   mii_config=mii_config)
        # Handle used by __call__ to send queries to the running deployment.
        self.pipeline = mii.mii_query_handle(self.deploy_name)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run one inference request against the MII deployment.

        Args:
            data: Request payload. Recognised keys:
                inputs: The value sent to the model as ``{"query": inputs}``.
                    If absent, the remaining payload dict itself (minus
                    ``parameters``) is used as the input.
                parameters: Optional mapping of keyword arguments forwarded
                    to ``query``. When absent or ``None``, no extra keyword
                    arguments are passed.
                The caller's dict is not modified.

        Returns:
            The ``response`` attribute of the MII reply, returned as-is for
            serialization. NOTE(review): its concrete type depends on the
            MII version / task and is not visible here — confirm against
            the annotation.
        """
        # Work on a shallow copy so popping keys does not mutate the caller's dict.
        payload = dict(data)
        inputs = payload.pop("inputs", payload)
        parameters = payload.pop("parameters", None)

        # None means "no extra kwargs"; any other value is unpacked as-is, so
        # a non-mapping still raises TypeError exactly as before.
        kwargs = parameters if parameters is not None else {}
        prediction = self.pipeline.query({"query": inputs}, **kwargs)
        return prediction.response