pere commited on
Commit
b73aadd
1 Parent(s): f431aa4

balanced lm task

Browse files
Files changed (2) hide show
  1. norwegian_lm_large.gin +1 -1
  2. tasks.py +29 -0
norwegian_lm_large.gin CHANGED
@@ -13,7 +13,7 @@ TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 512}
13
  TRAIN_STEPS = 1_600_000
14
  DROPOUT_RATE = 0.0 # Changed from the default since T5-1.1 recomments this.
15
  INITIAL_CHECKPOINT_PATH = "gs://nb-t5x-us-central2/norwegian_NCC_plus_English_t5x_large/checkpoint_1500000"
16
- PjitPartitioner.num_partitions = 1
17
  utils.SaveCheckpointConfig.period = 5000
18
  utils.SaveCheckpointConfig.keep = 3
19
 
 
13
  TRAIN_STEPS = 1_600_000
14
  DROPOUT_RATE = 0.0 # Changed from the default since T5-1.1 recomments this.
15
  INITIAL_CHECKPOINT_PATH = "gs://nb-t5x-us-central2/norwegian_NCC_plus_English_t5x_large/checkpoint_1500000"
16
+ #PjitPartitioner.num_partitions = 1
17
  utils.SaveCheckpointConfig.period = 5000
18
  utils.SaveCheckpointConfig.keep = 3
19
 
tasks.py CHANGED
@@ -273,3 +273,32 @@ TaskRegistry.add(
273
  output_features={"targets": DEFAULT_OUTPUT_FEATURES["targets"]},
274
  metric_fns=[]
275
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
273
  output_features={"targets": DEFAULT_OUTPUT_FEATURES["targets"]},
274
  metric_fns=[]
275
  )
276
+ # Final pretraining task used in Raffel et al., 2019 adaptated to NCC
277
+ dataset_name = 'NbAiLab/balanced_bokmaal_nynorsk'
278
+ dataset_params = {"path": dataset_name, "use_auth_token": True, "streaming": True}
279
+ dataset_shapes = None
280
+ TaskRegistry.add(
281
+ "balanced_bokmaal_nynorsk_prefix_lm_stream",
282
+ source=seqio.FunctionDataSource(
283
+ dataset_fn=functools.partial(dataset_fn, dataset_params=dataset_params),
284
+ splits=("train", "validation"),
285
+ caching_permitted=False,
286
+ num_input_examples=dataset_shapes,
287
+ ),
288
+ preprocessors=[
289
+ functools.partial(
290
+ target_to_key, key_map={
291
+ "inputs": None,
292
+ "targets": None,
293
+ }, target_key="targets"),
294
+ seqio.preprocessors.tokenize,
295
+ # seqio.CacheDatasetPlaceholder(),
296
+ preprocessors.prefix_lm,
297
+ seqio.preprocessors.append_eos_after_trim,
298
+ ],
299
+ output_features={"targets": DEFAULT_OUTPUT_FEATURES["targets"]},
300
+ metric_fns=[]
301
+ )
302
+
303
+
304
+