PardisSzah committed
Commit c96e0a1
Parent(s): 360b18c
commit files to HF hub

Files changed:
- PersianTextFormalizerPipeline.py +19 -0
- config.json +9 -0
PersianTextFormalizerPipeline.py
ADDED
@@ -0,0 +1,19 @@
+from transformers import Pipeline, T5ForConditionalGeneration, AutoTokenizer
+
+class PersianTextFormalizerPipeline(Pipeline):
+
+    def _sanitize_parameters(self, **kwargs):
+        preprocess_kwargs = {}
+        if "second_text" in kwargs:
+            preprocess_kwargs["second_text"] = kwargs["second_text"]
+        return preprocess_kwargs, {}, {}
+
+    def preprocess(self, text, second_text=None):
+        inputs = self.tokenizer.encode("informal: " + text, return_tensors='pt', max_length=128, truncation=True, padding='max_length')
+        return inputs.to(self.device)
+
+    def _forward(self, model_inputs):
+        return self.model.generate(model_inputs, max_length=128, num_beams=4, temperature=0.7)
+
+    def postprocess(self, model_outputs):
+        return self.tokenizer.decode(model_outputs[0], skip_special_tokens=True)
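
The class fills in the four hooks that transformers.Pipeline calls in order: _sanitize_parameters routes keyword arguments to the stages, preprocess tokenizes the input with an "informal: " prefix, _forward runs beam-search generation, and postprocess decodes the first generated sequence. A minimal local sketch of registering and calling the class without going through the Hub (the checkpoint id below is an assumption, not taken from this commit):

# Local sketch: register the custom pipeline class and call it directly.
# Assumes PersianTextFormalizerPipeline.py is importable from the working directory;
# the checkpoint id "PardisSzah/PersianTextFormalizer_M" is an assumption.
from transformers import T5ForConditionalGeneration, pipeline
from transformers.pipelines import PIPELINE_REGISTRY
from PersianTextFormalizerPipeline import PersianTextFormalizerPipeline

PIPELINE_REGISTRY.register_pipeline(
    "text2text-PersianTextFormalizer_M",           # same task name as in config.json
    pipeline_class=PersianTextFormalizerPipeline,
    pt_model=T5ForConditionalGeneration,
)

formalizer = pipeline(
    "text2text-PersianTextFormalizer_M",
    model="PardisSzah/PersianTextFormalizer_M",    # assumed repo/checkpoint id
)
print(formalizer("متن غیررسمی"))  # placeholder informal input; prints the formal rewrite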
config.json
CHANGED
@@ -4,6 +4,15 @@
     "T5ForConditionalGeneration"
   ],
   "classifier_dropout": 0.0,
+  "custom_pipelines": {
+    "text2text-PersianTextFormalizer_M": {
+      "impl": "PersianTextFormalizerPipeline.PersianTextFormalizerPipeline",
+      "pt": [
+        "T5ForConditionalGeneration"
+      ],
+      "tf": []
+    }
+  },
   "d_ff": 2048,
   "d_kv": 64,
   "d_model": 768,
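
The custom_pipelines block is what lets the Hub map the task name to the class shipped in PersianTextFormalizerPipeline.py: loading the repository with trust_remote_code=True fetches the module named in "impl" and instantiates it as the pipeline. A minimal usage sketch (the repository id below is an assumption; substitute the actual model id):

# Hub usage sketch: the custom_pipelines entry in config.json maps the task name
# to PersianTextFormalizerPipeline; trust_remote_code=True lets that module load.
# The repository id below is an assumption; replace it with the actual model id.
from transformers import pipeline

formalizer = pipeline(
    "text2text-PersianTextFormalizer_M",
    model="PardisSzah/PersianTextFormalizer_M",  # assumed repo id
    trust_remote_code=True,
)
print(formalizer("متن غیررسمی"))  # returns the formal Persian rewrite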