1-800-BAD-CODE commited on
Commit
70c8b16
1 Parent(s): 03cc4e4

Create pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +21 -0
pipeline.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+
3
+ from punctuators.models.punc_cap_seg_model import PunctCapSegConfigONNX, PunctCapSegModelONNX
4
+
5
+
6
+ class PreTrainedPipeline():
7
+ def __init__(self, path: str):
8
+ cfg: PunctCapSegConfigONNX = PunctCapSegConfigONNX(
9
+ directory=path,
10
+ spe_filename="sp.model",
11
+ model_filename="model.onnx",
12
+ config_filename="config.yaml",
13
+ )
14
+ self._punctuator: PunctCapSegModelONNX = PunctCapSegModelONNX(cfg)
15
+
16
+ def __call__(self, data: str) -> List[Dict]:
17
+ # Use list to generate a batch of size 1
18
+ pred_texts: List[List[str]] = self._punctuator.infer([data])
19
+ # Can't figure out how to make the text gen widget print multiple lines; use a '\n' for now.
20
+ outputs: List[Dict] = [{"generated_text": " \\n ".join(pred_texts[0])}]
21
+ return outputs