1-800-BAD-CODE
commited on
Commit
•
70c8b16
1
Parent(s):
03cc4e4
Create pipeline.py
Browse files- pipeline.py +21 -0
pipeline.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, List, Any
|
2 |
+
|
3 |
+
from punctuators.models.punc_cap_seg_model import PunctCapSegConfigONNX, PunctCapSegModelONNX
|
4 |
+
|
5 |
+
|
6 |
+
class PreTrainedPipeline():
|
7 |
+
def __init__(self, path: str):
|
8 |
+
cfg: PunctCapSegConfigONNX = PunctCapSegConfigONNX(
|
9 |
+
directory=path,
|
10 |
+
spe_filename="sp.model",
|
11 |
+
model_filename="model.onnx",
|
12 |
+
config_filename="config.yaml",
|
13 |
+
)
|
14 |
+
self._punctuator: PunctCapSegModelONNX = PunctCapSegModelONNX(cfg)
|
15 |
+
|
16 |
+
def __call__(self, data: str) -> List[Dict]:
|
17 |
+
# Use list to generate a batch of size 1
|
18 |
+
pred_texts: List[List[str]] = self._punctuator.infer([data])
|
19 |
+
# Can't figure out how to make the text gen widget print multiple lines; use a '\n' for now.
|
20 |
+
outputs: List[Dict] = [{"generated_text": " \\n ".join(pred_texts[0])}]
|
21 |
+
return outputs
|