File size: 3,748 Bytes
f8b2708
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import functools
import seqio
import tensorflow as tf
import t5.data
from datasets import load_dataset, load_from_disk
from t5.data import postprocessors
from t5.data import preprocessors
from t5.evaluation import metrics
from seqio import FunctionDataSource, utils

from ul2_objective import ul2_objective

# Mixture-of-denoisers settings taken from the UL2 paper
# (https://arxiv.org/pdf/2205.05131.pdf, section 3.1.2, table 1).
# Entry i of a *_SPAN_LENGTHS list (mean noise span length) lines up with
# entry i of the matching *_CORRUPT_RATES list (noise density); both are
# concatenated R-then-X when passed to ul2_objective.
R_DENOISER_SPAN_LENGTHS = [3.0, 8.0]
R_DENOISER_CORRUPT_RATES = [0.15, 0.15]
X_DENOISER_SPAN_LENGTHS = [3.0, 8.0, 64.0, 64.0]
X_DENOISER_CORRUPT_RATES = [0.5, 0.5, 0.15, 0.5]

# Prefix strings for each denoiser family (R, X, S), used as
# ul2_objective's optional_task_prefixes.
R_DENOISER_TOKEN_PREFIX = '[NLU]'
X_DENOISER_TOKEN_PREFIX = '[NLG]'
S_DENOISER_TOKEN_PREFIX = '[S2S]'

# Short alias for seqio's task registry.
TaskRegistry = seqio.TaskRegistry

# SentencePiece vocabulary read from 'spiece.model' (resolved relative to the
# current working directory); extra_ids=0 appends no extra ids.
vocabulary = seqio.SentencePieceVocabulary('spiece.model', extra_ids=0)

# Feature specs shared by tasks in this module: both features use the same
# vocabulary and get EOS appended; "inputs" is optional.
DEFAULT_OUTPUT_FEATURES = dict(
    inputs=seqio.Feature(vocabulary=vocabulary, add_eos=True, required=False),
    targets=seqio.Feature(vocabulary=vocabulary, add_eos=True),
)


def gen_dataset(split, shuffle=False, seed=None, column="text", dataset=None):
    """Yield the non-None ``column`` values of one dataset split, cycling over it.

    Args:
        split: Split name; converted with ``str`` and used to index ``dataset``.
        shuffle: If true, shuffle the selected split once before iterating.
        seed: Seed forwarded to the split's ``shuffle``. ``None`` means
            unseeded; ``0`` is treated as a real seed.
        column: Field read from each row.
        dataset: Mapping from split name to an iterable of row dicts (e.g. a
            ``datasets.DatasetDict``). When ``shuffle`` is set, the split
            object must provide ``shuffle(seed=...)``.

    Yields:
        ``item[column]`` for every row where it is not ``None``, repeating the
        split pass after pass. If a whole pass yields nothing (empty split or
        all values ``None``), the generator stops instead of busy-looping.

    Raises:
        TypeError: If ``dataset`` is None.
    """
    if dataset is None:
        raise TypeError("gen_dataset() requires a `dataset` argument")

    rows = dataset[str(split)]
    if shuffle:
        # Shuffle only the split that is read. DatasetDict.shuffle(seed=s)
        # applies s to each split, so the order is the same, but shuffling
        # the whole dict would also build index mappings for unused splits.
        # Passing seed=None is equivalent to an unseeded shuffle.
        rows = rows.shuffle(seed=seed)

    while True:
        produced = False
        for item in rows:
            value = item[column]
            if value is not None:
                produced = True
                yield value
        if not produced:
            # Nothing can ever be emitted; end the stream (like tf.data's
            # repeat() on an empty dataset) rather than spinning forever.
            return


def dataset_fn(split, shuffle_files, seed=None, dataset=None):
    """Build a ``tf.data.Dataset`` of scalar strings backed by ``gen_dataset``.

    Args:
        split: Split name forwarded to ``gen_dataset``.
        shuffle_files: Forwarded as ``gen_dataset``'s ``shuffle`` flag.
        seed: Forwarded as ``gen_dataset``'s ``seed``.
        dataset: Dataset object forwarded to ``gen_dataset``.

    Returns:
        A ``tf.data.Dataset`` created with ``from_generator`` whose elements
        are ``tf.string`` scalars.

    Note:
        The ``TensorSpec`` name is the module-level ``dataset_name``, which is
        resolved when this function is called (it is defined further down in
        the module).
    """
    make_examples = functools.partial(
        gen_dataset, split, shuffle=shuffle_files, seed=seed, dataset=dataset
    )
    element_spec = tf.TensorSpec(shape=(), dtype=tf.string, name=dataset_name)
    return tf.data.Dataset.from_generator(make_examples, output_signature=element_spec)


@utils.map_over_dataset
def target_to_key(x, key_map, target_key):
    """Return a copy of ``key_map`` with ``target_key`` set to ``x``.

    The decorator applies this to each dataset element; ``key_map`` itself
    is never mutated, because a fresh dict is built per call.
    """
    features = dict(key_map)
    features[target_key] = x
    return features


# Path of the dataset on disk; also used as the TensorSpec name in dataset_fn.
dataset_name = "/researchdisk/lm_training_dataset_full"
dataset_params = {"from_disk_path": dataset_name}

# A "from_disk_path" entry selects datasets.load_from_disk; any other params
# are passed straight through to datasets.load_dataset.
if "from_disk_path" not in dataset_params:
    dataset = load_dataset(**dataset_params)
else:
    dataset = load_from_disk(dataset_params["from_disk_path"])

# Row count per split, reported to seqio as num_input_examples.
dataset_shapes = {
    split_name: dataset[split_name].num_rows
    for split_name in ("train", "validation")
}

# Register the UL2 pretraining task: raw text rows -> {"inputs": None,
# "targets": text} -> tokenization -> UL2 mixture of denoisers -> EOS
# appended after trimming.
TaskRegistry.add(
    "pretrain_finnish_ul2",
    source=seqio.FunctionDataSource(
        dataset_fn=functools.partial(dataset_fn, dataset=dataset),
        splits=("train", "validation"),
        caching_permitted=False,
        num_input_examples=dataset_shapes,
    ),
    preprocessors=[
        # Place each raw string under "targets"; "inputs" starts as None.
        functools.partial(
            target_to_key,
            key_map={"inputs": None, "targets": None},
            target_key="targets",
        ),
        seqio.preprocessors.tokenize,
        functools.partial(
            ul2_objective,
            shard_ds=False,
            # Span-corruption denoisers, R-denoisers first, then X-denoisers.
            mean_noise_span_lengths=R_DENOISER_SPAN_LENGTHS + X_DENOISER_SPAN_LENGTHS,
            noise_densities=R_DENOISER_CORRUPT_RATES + X_DENOISER_CORRUPT_RATES,
            # Also use S-denoising; its rate and prefix are the final entries
            # of `rates` and `optional_task_prefixes` below.
            use_prefix_lm_task=True,
            # 40% of the mixture split evenly over the R-denoisers, 40% split
            # evenly over the X-denoisers, and 20% for S-denoising (suggested
            # in the UL2 paper, chapter 4.5).
            rates=(
                [0.4 / len(R_DENOISER_SPAN_LENGTHS) for _ in R_DENOISER_SPAN_LENGTHS]
                + [0.4 / len(X_DENOISER_SPAN_LENGTHS) for _ in X_DENOISER_SPAN_LENGTHS]
                + [0.2]
            ),
            optional_task_prefixes=(
                [R_DENOISER_TOKEN_PREFIX for _ in R_DENOISER_SPAN_LENGTHS]
                + [X_DENOISER_TOKEN_PREFIX for _ in X_DENOISER_SPAN_LENGTHS]
                + [S_DENOISER_TOKEN_PREFIX]
            ),
            # Leave room for the task prefix token.
            reserved_for_packing=1,
        ),
        seqio.preprocessors.append_eos_after_trim,
    ],
    output_features={"targets": DEFAULT_OUTPUT_FEATURES["targets"]},
    metric_fns=[metrics.accuracy],
)