martinjurkovic commited on
Commit
1db1a8f
1 Parent(s): 1472770

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +3 -2
handler.py CHANGED
@@ -14,7 +14,7 @@ class EndpointHandler():
14
  # self.pipe.model.config.forced_decoder_ids = self.pipe.model.processor.get_decoder_prompt_ids(language="Slovenian", task="transcribe")
15
  # self.pipe.model.generation_config.forced_decoder_ids = self.pipe.model.config.forced_decoder_ids
16
 
17
- def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
18
  """
19
  data args:
20
  inputs (:obj: `str` | `PIL.Image` | `np.array`)
@@ -23,6 +23,7 @@ class EndpointHandler():
23
  A :obj:`list` | `dict`: will be serialized and returned
24
  """
25
  inputs = data.pop("inputs",data)
 
26
  # print("inputs", inputs)
27
- prediction = self.pipe(inputs, generate_kwargs={"language": "Slovenian", "task": "transcribe"})
28
  return prediction
 
14
  # self.pipe.model.config.forced_decoder_ids = self.pipe.model.processor.get_decoder_prompt_ids(language="Slovenian", task="transcribe")
15
  # self.pipe.model.generation_config.forced_decoder_ids = self.pipe.model.config.forced_decoder_ids
16
 
17
+ def __call__(self, data: Dict[str, Any], **kwargs) -> List[Dict[str, Any]]:
18
  """
19
  data args:
20
  inputs (:obj: `str` | `PIL.Image` | `np.array`)
 
23
  A :obj:`list` | `dict`: will be serialized and returned
24
  """
25
  inputs = data.pop("inputs",data)
26
+ language = kwargs.get("language", "sl")
27
  # print("inputs", inputs)
28
+ prediction = self.pipe(inputs, generate_kwargs={"language": language, "task": "transcribe"})
29
  return prediction