from typing import Dict, List, Any

from punctuators.models.punc_cap_seg_model import PunctCapSegConfigONNX, PunctCapSegModelONNX


class PreTrainedPipeline:
    """Inference pipeline around a ``PunctCapSegModelONNX`` model.

    Judging by the class and file names, the model presumably performs
    punctuation, capitalization and sentence segmentation. TODO confirm
    against the punctuators documentation.
    """

    def __init__(self, path: str):
        """Build the model configuration and load the ONNX model.

        Args:
            path: Passed to ``PunctCapSegConfigONNX`` as ``directory``. It is
                expected to hold the SentencePiece model, the ONNX graph and
                the YAML config named below.
        """
        model_config: PunctCapSegConfigONNX = PunctCapSegConfigONNX(
            directory=path,
            spe_filename="spe_32k_lc_en.model",
            model_filename="punct_cap_seg_en.onnx",
            config_filename="config.yaml",
        )
        self._punctuator: PunctCapSegModelONNX = PunctCapSegModelONNX(model_config)

    def __call__(self, data: str) -> List[Dict]:
        """Run the model on a single input string.

        Args:
            data: Raw input text.

        Returns:
            A one-element list holding a dict whose "generated_text" value is
            the model's output strings for ``data``, joined by a space, a
            literal backslash, the letter n, and a space.
        """
        # The model API takes a batch, so wrap the lone input in a list
        # and take the first (only) result back out.
        batch_predictions: List[List[str]] = self._punctuator.infer([data])
        first_result = batch_predictions[0]
        # The text-generation widget does not render real line breaks, so a
        # literal "\n" marker (backslash + n) separates the outputs instead.
        separator = " \\n "
        return [{"generated_text": separator.join(first_result)}]