balanced lm task
Browse files- norwegian_lm_large.gin +1 -1
- tasks.py +29 -0
norwegian_lm_large.gin
CHANGED
@@ -13,7 +13,7 @@ TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 512}
|
|
13 |
TRAIN_STEPS = 1_600_000
|
14 |
DROPOUT_RATE = 0.0 # Changed from the default since T5-1.1 recommends this.
|
15 |
INITIAL_CHECKPOINT_PATH = "gs://nb-t5x-us-central2/norwegian_NCC_plus_English_t5x_large/checkpoint_1500000"
|
16 |
-
PjitPartitioner.num_partitions = 1
|
17 |
utils.SaveCheckpointConfig.period = 5000
|
18 |
utils.SaveCheckpointConfig.keep = 3
|
19 |
|
|
|
13 |
TRAIN_STEPS = 1_600_000
|
14 |
DROPOUT_RATE = 0.0 # Changed from the default since T5-1.1 recommends this.
|
15 |
INITIAL_CHECKPOINT_PATH = "gs://nb-t5x-us-central2/norwegian_NCC_plus_English_t5x_large/checkpoint_1500000"
|
16 |
+
#PjitPartitioner.num_partitions = 1
|
17 |
utils.SaveCheckpointConfig.period = 5000
|
18 |
utils.SaveCheckpointConfig.keep = 3
|
19 |
|
tasks.py
CHANGED
@@ -273,3 +273,32 @@ TaskRegistry.add(
|
|
273 |
output_features={"targets": DEFAULT_OUTPUT_FEATURES["targets"]},
|
274 |
metric_fns=[]
|
275 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
273 |
output_features={"targets": DEFAULT_OUTPUT_FEATURES["targets"]},
|
274 |
metric_fns=[]
|
275 |
)
|
276 |
+
# Final pretraining task used in Raffel et al., 2019, adapted to NCC
|
277 |
+
dataset_name = 'NbAiLab/balanced_bokmaal_nynorsk'
|
278 |
+
dataset_params = {"path": dataset_name, "use_auth_token": True, "streaming": True}
|
279 |
+
dataset_shapes = None
|
280 |
+
TaskRegistry.add(
|
281 |
+
"balanced_bokmaal_nynorsk_prefix_lm_stream",
|
282 |
+
source=seqio.FunctionDataSource(
|
283 |
+
dataset_fn=functools.partial(dataset_fn, dataset_params=dataset_params),
|
284 |
+
splits=("train", "validation"),
|
285 |
+
caching_permitted=False,
|
286 |
+
num_input_examples=dataset_shapes,
|
287 |
+
),
|
288 |
+
preprocessors=[
|
289 |
+
functools.partial(
|
290 |
+
target_to_key, key_map={
|
291 |
+
"inputs": None,
|
292 |
+
"targets": None,
|
293 |
+
}, target_key="targets"),
|
294 |
+
seqio.preprocessors.tokenize,
|
295 |
+
# seqio.CacheDatasetPlaceholder(),
|
296 |
+
preprocessors.prefix_lm,
|
297 |
+
seqio.preprocessors.append_eos_after_trim,
|
298 |
+
],
|
299 |
+
output_features={"targets": DEFAULT_OUTPUT_FEATURES["targets"]},
|
300 |
+
metric_fns=[]
|
301 |
+
)
|
302 |
+
|
303 |
+
|
304 |
+
|