File size: 944 Bytes
9ed74a2
 
 
 
8d88f31
7f3abba
 
d237660
9ed74a2
 
 
d237660
9ed74a2
 
d237660
9ed74a2
 
 
d237660
9ed74a2
 
2eb1a84
9ed74a2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
from typing import Dict, List, Any
from transformers import pipeline

class EndpointHandler:
    """Custom endpoint handler wrapping a transformers ASR pipeline.

    On construction the decoder prompt is pinned to a fixed language and task
    (Chinese transcription by default) through ``forced_decoder_ids``. This
    presumably targets a Whisper-style model, since it relies on the
    tokenizer's ``get_decoder_prompt_ids``; confirm against the deployed model.
    """

    def __init__(self, path="", language="Chinese", task="transcribe"):
        """Load the ASR pipeline from ``path`` and force the decoder prompt.

        Args:
            path: Model location passed to ``pipeline(model=...)``.
            language: Language name forwarded to
                ``tokenizer.get_decoder_prompt_ids``. Defaults to ``"Chinese"``.
            task: Task name forwarded to ``tokenizer.get_decoder_prompt_ids``.
                Defaults to ``"transcribe"``.
        """
        self.pipeline = pipeline(task="automatic-speech-recognition", model=path)
        forced_ids = self.pipeline.tokenizer.get_decoder_prompt_ids(
            language=language, task=task
        )
        self.pipeline.model.config.forced_decoder_ids = forced_ids
        # Mirror the same ids onto generation_config as well: newer
        # transformers versions read generation settings from there rather
        # than from model.config, so setting both covers either behavior.
        self.pipeline.model.generation_config.forced_decoder_ids = forced_ids

    def __call__(self, data: Dict[str, Any]) -> Any:
        """Run speech recognition on one request payload.

        Args:
            data: Request body. If it contains an ``"inputs"`` key, that value
                is passed to the ASR pipeline. Otherwise the whole ``data`` dict
                is passed through as-is. ``data`` itself is not modified.

        Returns:
            The raw pipeline output, which the endpoint runtime serializes.
            For a single input this is typically a dict such as
            ``{"text": ...}``; batched inputs yield a list of such dicts.
            Timestamps are never requested.
        """
        # .get instead of .pop: same value, but the caller's payload is left intact.
        inputs = data.get("inputs", data)
        return self.pipeline(inputs, return_timestamps=False)