pere commited on
Commit
d3986eb
1 Parent(s): b032b6f

First attempt

Browse files
Files changed (6) hide show
  1. README.md +16 -0
  2. base_wmt_infer.gin +23 -0
  3. finetune_mt5_sentencefix.gin +41 -0
  4. interference.sh +16 -0
  5. tasks.py +65 -0
  6. train.sh +12 -0
README.md ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: cc
3
+ ---
4
+ # Multi-Lingual DeUnCaser - Base mT5 Version
5
+ The output from Automated Speak Recognition software is usually uncased and without any punctation. This does not make a very readable text.
6
+
7
+ The DeUnCaser is a sequence-to-sequence model that is reversing this process. It adds punctation, and capitalises the correct words. In some languages this means adding capital letters at start of sentences and on all proper nouns, in other languages, like German, it means capitalising the first letter of all nouns. It will also make attempts at adding hyphens and parentheses if this is making the meaning clearer.
8
+
9
+ It is using based on the multi-lingual T5 model. It is finetuned for 100,000 steps. The finetuning scripts is based on 100,000 training examples from each of the 44 languages with Latin alphabet that is both part of OSCAR and the mT5 training set: Afrikaans, Albanian, Basque, Catalan, Cebuano, Czech, Danish, Dutch, English, Esperanto, Estonian, Finnish, French, Galician, German, Haitian Creole, Hungarian, Icelandic, Indonesian, Irish, Italian, Kurdish, Latin, Latvian, Lithuanian, Luxembourgish, Malagasy, Malay, Maltese, Norwegian Bokmål, Norwegian Nynorsk, Polish, Portuguese, Romanian, Slovak, Spanish, Sundanese, Swahili, Swedish, Turkish, Uzbek, Vietnamese, Welsh, West Frisian.
10
+
11
+ A Notebook for creating the training corpus is available [here](https://colab.research.google.com/drive/1bkH94z-0wIQP8Pz0qXFndhoQsokU-78x?usp=sharing).
12
+
13
+
14
+
15
+
16
+
base_wmt_infer.gin ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __gin__ import dynamic_registration
2
+ import tasks
3
+
4
+ import __main__ as infer_script
5
+ from t5.data import mixtures
6
+ from t5x import partitioning
7
+ from t5x import utils
8
+
9
+ include "t5x/examples/t5/mt5/base.gin"
10
+ include "t5x/configs/runs/infer.gin"
11
+
12
+ DROPOUT_RATE = 0.0 # unused but needs to be specified
13
+ MIXTURE_OR_TASK_NAME = "sentencefix"
14
+ TASK_FEATURE_LENGTHS = {"inputs": 256, "targets": 256}
15
+
16
+ infer_script.infer:
17
+ partitioner = @partitioning.ModelBasedPjitPartitioner()
18
+
19
+ partitioning.ModelBasedPjitPartitioner.num_partitions = 1
20
+
21
+ utils.DatasetConfig:
22
+ split = "test"
23
+ batch_size = 32
finetune_mt5_sentencefix.gin ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __gin__ import dynamic_registration
2
+ import tasks
3
+
4
+ import __main__ as train_script
5
+ from t5.data import mixtures
6
+ from t5x import models
7
+ from t5x import partitioning
8
+ from t5x import utils
9
+
10
+ include "t5x/examples/t5/mt5/base.gin"
11
+ include "t5x/configs/runs/finetune.gin"
12
+
13
+ MIXTURE_OR_TASK_NAME = "sentencefix"
14
+ TASK_FEATURE_LENGTHS = {"inputs": 256, "targets": 256}
15
+ TRAIN_STEPS = 1_050_000 # 1000000 pre-trained steps + 20000 fine-tuning steps.
16
+ USE_CACHED_TASKS = False
17
+ DROPOUT_RATE = 0.0
18
+ RANDOM_SEED = 0
19
+
20
+ # `LOSS_NORMALIZING_FACTOR`: When fine-tuning a model that was pre-trained
21
+ # using Mesh Tensorflow (e.g. the public T5 / mT5 / ByT5 models), this should be
22
+ # set to `pretraining batch_size` * `target_token_length`. For T5 and T5.1.1:
23
+ # `2048 * 114`. For mT5: `1024 * 229`. For ByT5: `1024 * 189`.
24
+ LOSS_NORMALIZING_FACTOR = 234496
25
+ INITIAL_CHECKPOINT_PATH = "gs://t5-data/pretrained_models/t5x/mt5_base/checkpoint_1000000"
26
+
27
+ train_script.train:
28
+ eval_period = 500
29
+ partitioner = @partitioning.ModelBasedPjitPartitioner()
30
+
31
+ # `num_decodes` is equivalent to a beam size in a beam search decoding.
32
+ models.EncoderDecoderModel.predict_batch_with_aux.num_decodes = 4
33
+
34
+ partitioning.ModelBasedPjitPartitioner.num_partitions = 2
35
+
36
+
37
+ #from t5.models import mesh_transformer
38
+ #import t5.models
39
+ #mesh_transformer.learning_rate_schedules.constant_learning_rate.learning_rate = 0.0005
40
+ #run.learning_rate_schedule = @learning_rate_schedules.constant_learning_rate
41
+
interference.sh ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ INFER_OUTPUT_DIR="output" # directory to write infer output
2
+ T5X_DIR="../../t5x" # directory where the t5x is cloned, e.g., ${HOME}"/t5x".
3
+ TFDS_DATA_DIR="gs://nb-t5x/corpus_multi_sentencefix_mt5/"
4
+ CHECKPOINT_PATH="gs://nb-t5x/corpus_multi_sentencefix_mt5/checkpoint_1100000"
5
+ PROJECT_DIR=${HOME}"/mymodel/multi_sentencefix_mt5"
6
+ export PYTHONPATH=${PROJECT_DIR}
7
+
8
+ python3 ${T5X_DIR}/t5x/infer.py \
9
+ --gin_search_paths=${PROJECT_DIR} \
10
+ --gin_file="base_wmt_infer.gin" \
11
+ --gin.CHECKPOINT_PATH=\"${CHECKPOINT_PATH}\" \
12
+ --gin.INFER_OUTPUT_DIR=\"${INFER_OUTPUT_DIR}\" \
13
+ --tfds_data_dir=${TFDS_DATA_DIR}
14
+
15
+
16
+
tasks.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # /home/perk/mymodel/sentencefix/tasks.py
2
+
3
+ import functools
4
+ import seqio
5
+ import tensorflow_datasets as tfds
6
+ from t5.evaluation import metrics
7
+ from t5.data import preprocessors
8
+ import t5
9
+ import tensorflow.compat.v1 as tf
10
+
11
+ tsv_path = {
12
+ "train": "gs://nb-t5x/corpus/train/train.tsv",
13
+ "validation": "gs://nb-t5x/corpus/eval/eval.tsv",
14
+ "test": "gs://nb-t5x/corpus/test/test.tsv"
15
+ }
16
+
17
+
18
+ vocabulary = t5.data.ByteVocabulary()
19
+
20
+ DEFAULT_OUTPUT_FEATURES = {
21
+ "inputs":
22
+ seqio.Feature(
23
+ vocabulary=vocabulary, add_eos=True),
24
+ "targets":
25
+ seqio.Feature(
26
+ vocabulary=vocabulary, add_eos=True)
27
+ }
28
+
29
+ def sentencefix_preprocessor(ds):
30
+ def normalize_text(text):
31
+ """Lowercase and remove quotes from a TensorFlow string."""
32
+ text = tf.strings.regex_replace(text,"'(.*)'", r"\1")
33
+ return text
34
+
35
+ def to_inputs_and_targets(ex):
36
+ """Map {"source": ..., "source": ...}->{"target": ..., "target": ...}."""
37
+ return {
38
+ "inputs":
39
+ tf.strings.join(
40
+ [normalize_text(ex["source"])]),
41
+ "targets":
42
+ tf.strings.join(
43
+ [normalize_text(ex["target"])]),
44
+ }
45
+ return ds.map(to_inputs_and_targets,
46
+ num_parallel_calls=tf.data.experimental.AUTOTUNE)
47
+
48
+
49
+ seqio.TaskRegistry.add(
50
+ "sentencefix",
51
+ source=seqio.TextLineDataSource(
52
+ split_to_filepattern=tsv_path,
53
+ #num_input_examples=num_nq_examples
54
+ ),
55
+ preprocessors=[
56
+ functools.partial(
57
+ t5.data.preprocessors.parse_tsv,
58
+ field_names=["source", "target"]),
59
+ sentencefix_preprocessor,
60
+ seqio.preprocessors.tokenize_and_append_eos,
61
+ ],
62
+ #metric_fns=[t5.evaluation.metrics.bleu],
63
+ output_features=DEFAULT_OUTPUT_FEATURES,
64
+ )
65
+
train.sh ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ PROJECT_DIR=${HOME}"/mymodel/multi-sentencefix-mt5"
2
+ T5X_DIR="../../t5x" # directory where the t5x is cloned.
3
+ TFDS_DATA_DIR="gs://nb-t5x/corpus_multi_sentencefix_mt5"
4
+ MODEL_DIR="gs://nb-t5x/model_multi_sentencefix_mt5"
5
+ export PYTHONPATH=${PROJECT_DIR}
6
+
7
+ python3 ${T5X_DIR}/t5x/train.py \
8
+ --gin_search_paths=${PROJECT_DIR} \
9
+ --gin_file="finetune_mt5_sentencefix.gin" \
10
+ --gin.MODEL_DIR="'${MODEL_DIR}'" \
11
+ --tfds_data_dir=${TFDS_DATA_DIR}
12
+