|
from typing import Dict, List, Any |
|
|
|
from punctuators.models.punc_cap_seg_model import PunctCapSegConfigONNX, PunctCapSegModelONNX |
|
|
|
|
|
class PreTrainedPipeline:
    """Callable wrapper around ``PunctCapSegModelONNX``.

    By its name, the wrapped model restores punctuation, capitalization and
    segmentation -- confirm against the ``punctuators`` docs. Calling an
    instance with a string returns ``[{"generated_text": ...}]``.
    """

    # Joins the per-input output segments returned by ``infer``. This is the
    # four-character sequence space, backslash, 'n', space -- NOT a newline.
    # Kept byte-for-byte from the original because consumers of
    # "generated_text" may depend on it.
    _SEGMENT_SEPARATOR: str = " \\n "

    def __init__(
        self,
        path: str,
        *,
        spe_filename: str = "spe_unigram_64k_lowercase_47lang.model",
        model_filename: str = "punct_cap_seg_47lang.onnx",
        config_filename: str = "config.yaml",
    ):
        """Build the ONNX model from the artifacts found in ``path``.

        Args:
            path: Directory holding the model artifacts; passed as
                ``directory`` to ``PunctCapSegConfigONNX``.
            spe_filename: Tokenizer model file name inside ``path``
                (presumably SentencePiece, per the ``spe_`` prefix).
            model_filename: ONNX model file name inside ``path``.
            config_filename: Model configuration file name inside ``path``.

        The defaults are the file names this wrapper originally hard-coded,
        so ``PreTrainedPipeline(path)`` behaves exactly as before.
        """
        cfg: PunctCapSegConfigONNX = PunctCapSegConfigONNX(
            directory=path,
            spe_filename=spe_filename,
            model_filename=model_filename,
            config_filename=config_filename,
        )
        self._punctuator: PunctCapSegModelONNX = PunctCapSegModelONNX(cfg)

    def __call__(self, data: str) -> List[Dict]:
        """Run the model on a single input string.

        Args:
            data: Raw input text.

        Returns:
            A one-element list ``[{"generated_text": text}]``. ``text`` is the
            list of strings the model produced for ``data`` (presumably one
            per detected segment), joined by ``_SEGMENT_SEPARATOR``.
        """
        # ``infer`` takes a batch of texts; send a batch of one and take the
        # only result back out.
        pred_texts: List[List[str]] = self._punctuator.infer([data])
        generated = self._SEGMENT_SEPARATOR.join(pred_texts[0])
        outputs: List[Dict] = [{"generated_text": generated}]
        return outputs
|
|