amank committed on
Commit 7839b8e
1 Parent(s): 139e10d

Made changes to the cleaning code, modified the number of warmup steps, and switched to getting eval samples from the validation split

Files changed (5)
  1. .gitignore +2 -0
  2. .vscode/launch.json +3 -2
  3. run_mlm_flax_stream.py +33 -11
  4. run_stream.sh +2 -1
  5. utils.py +30 -8
.gitignore CHANGED
@@ -1 +1,3 @@
 __pycache__
+events.out.tfevents*
+*xplane.pb
.vscode/launch.json CHANGED
@@ -17,8 +17,8 @@
         "--dataset_name","mc4",
         "--dataset_config_name","hi",
         "--max_seq_length","256",
-        "--per_device_train_batch_size","128",
-        "--per_device_eval_batch_size","128",
+        "--per_device_train_batch_size","16",
+        "--per_device_eval_batch_size","16",
         "--learning_rate","3e-4",
         "--warmup_steps","1000",
         "--overwrite_output_dir",
@@ -26,6 +26,7 @@
         "--adam_beta2","0.98",
         "--num_train_steps","10000",
         "--num_eval_samples","5000",
+        "--preprocessing_num_workers", "90",
         "--logging_steps","250",
         "--eval_steps","1000"
     ],
run_mlm_flax_stream.py CHANGED
@@ -31,7 +31,7 @@ from dataclasses import dataclass, field
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple
 
-from utils import keep_devnagri
+from utils import keep_devnagri_hf_doc
 
 import datasets
 import numpy as np
@@ -60,6 +60,7 @@ from transformers import (
 )
 
 
+
 # if datasets.__version__ <= "1.8.0":
 #     raise ValueError("Make sure to upgrade `datasets` to a version >= 1.9.0 to use dataset streaming")
 
@@ -320,7 +321,6 @@ if __name__ == "__main__":
     # See all possible arguments in src/transformers/training_args.py
     # or by passing the --help flag to this script.
     # We now keep distinct sets of args, for a cleaner separation of concerns.
-
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
@@ -375,6 +375,13 @@ if __name__ == "__main__":
         streaming=True,
         split="train",
     )
+    validation_dataset = load_dataset(
+        data_args.dataset_name,
+        data_args.dataset_config_name,
+        cache_dir=model_args.cache_dir,
+        streaming=True,
+        split="validation",
+    )
 
     if model_args.config_name:
         config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
@@ -404,17 +411,26 @@ if __name__ == "__main__":
     def tokenize_function(examples):
         return tokenizer(examples[data_args.text_column_name], return_special_tokens_mask=True)
 
-    cleaned_dataset = dataset.map(
-        keep_devnagri,
-        batched=False,
+    shuffle_seed = training_args.seed
+    shuffled_dataset = dataset.shuffle(buffer_size=data_args.shuffle_buffer_size, seed=shuffle_seed)
+
+    cleaned_dataset = shuffled_dataset.map(
+        keep_devnagri_hf_doc,
+        batched=True
     )
     tokenized_datasets = cleaned_dataset.map(
         tokenize_function,
-        batched=True,
+        batched=True
+    )
+
+    cleaned_validation_dataset = dataset.map(
+        keep_devnagri_hf_doc,
+        batched=True
+    )
+    tokenized_validation_datasets = cleaned_validation_dataset.map(
+        tokenize_function,
+        batched=True
     )
-
-    shuffle_seed = training_args.seed
-    tokenized_datasets = tokenized_datasets.shuffle(buffer_size=data_args.shuffle_buffer_size, seed=shuffle_seed)
 
     has_tensorboard = is_tensorboard_available()
     if has_tensorboard and jax.process_index() == 0:
@@ -428,6 +444,10 @@ if __name__ == "__main__":
 
     summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
 
+    # code for manual tpu profiling
+    import jax.profiler
+    server = jax.profiler.start_server(9999)
+
     # Data collator
     # This one will take care of randomly masking the tokens.
     data_collator = FlaxDataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=data_args.mlm_probability)
@@ -446,6 +466,7 @@ if __name__ == "__main__":
     )
     if jax.device_count() < 8:
         print('Number of device as per jax device count is {}. Press Enter to continue'.format(jax.device_count()))
+        input()
 
     # Store some constant
     num_epochs = int(training_args.num_train_epochs)
@@ -556,9 +577,10 @@ if __name__ == "__main__":
     eval_metrics = []
 
     training_iter = iter(tokenized_datasets)
+    validation_iter = iter(tokenized_validation_datasets)
 
     max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
-    doc_count, eval_samples = advance_iter_and_group_samples(training_iter, data_args.num_eval_samples, max_seq_length)
+    _, eval_samples = advance_iter_and_group_samples(validation_iter, data_args.num_eval_samples, max_seq_length)
 
     steps = tqdm(range(num_train_steps), desc="Training...", position=0)
     docs_progress_bar = tqdm(range(dataset_doc_count * num_epochs), desc="Docs Processed...", position=0)
@@ -575,7 +597,7 @@
 
         training_iter = iter(tokenized_datasets)
 
-        _, eval_dataset = advance_iter_and_group_samples(training_iter, data_args.num_eval_samples, max_seq_length)
+        _, eval_samples = advance_iter_and_group_samples(validation_iter, data_args.num_eval_samples, max_seq_length)
         doc_count, samples = advance_iter_and_group_samples(training_iter, train_batch_size, max_seq_length)
 
 
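Note: the sketch below is not part of the commit. It is a minimal, self-contained illustration of the streaming data flow the changes above aim at (shuffle, clean, and tokenize the train split lazily, with eval samples now drawn from a separate streaming validation split). The checkpoint name, seed, and buffer size are placeholder values, and the script's advance_iter_and_group_samples helper is not reproduced here.

    # Illustrative sketch only; assumes utils.keep_devnagri_hf_doc is importable.
    from datasets import load_dataset
    from transformers import AutoTokenizer
    from utils import keep_devnagri_hf_doc

    tokenizer = AutoTokenizer.from_pretrained("roberta-base")  # placeholder checkpoint

    def tokenize_function(examples):
        # mc4 exposes the raw documents under the "text" column
        return tokenizer(examples["text"], return_special_tokens_mask=True)

    train_stream = load_dataset("mc4", "hi", streaming=True, split="train")
    validation_stream = load_dataset("mc4", "hi", streaming=True, split="validation")

    # Shuffle only the training stream, then clean and tokenize both streams lazily.
    train_stream = train_stream.shuffle(buffer_size=10000, seed=42)
    train_stream = train_stream.map(keep_devnagri_hf_doc, batched=True).map(tokenize_function, batched=True)
    validation_stream = validation_stream.map(keep_devnagri_hf_doc, batched=True).map(tokenize_function, batched=True)

    training_iter = iter(train_stream)
    validation_iter = iter(validation_stream)  # eval samples are drawn from here, not from the train iterator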
run_stream.sh CHANGED
@@ -10,11 +10,12 @@ python3 -c "import jax; print(jax.devices())"
     --per_device_train_batch_size="128" \
     --per_device_eval_batch_size="128" \
     --learning_rate="3e-4" \
-    --warmup_steps="1000" \
+    --warmup_steps="10000" \
     --overwrite_output_dir \
     --adam_beta1="0.9" \
     --adam_beta2="0.98" \
     --num_train_steps="10000" \
     --num_eval_samples="5000" \
+    --preprocessing_num_workers="90" \
     --logging_steps="250" \
     --eval_steps="1000"
utils.py CHANGED
@@ -1,7 +1,7 @@
 import regex as re
 import string
 
-def keep_devnagri(document:str):
+def keep_devnagri(text:str):
     """
     Remove all non Devnagri characters from the text.
     Code adapted from https://huggingface.co/flax-community/roberta-base-mr/blob/64d2c745f264f09c3d5b678a718746b2613887db/mr_clean_text.py
@@ -9,7 +9,6 @@ def keep_devnagri(document:str):
     @param text: str Text to be cleaned
     @return: Union[str, bool]
     """
-    text = document['text']
     pattern = r'[\p{Devanagari}0-9।\s\.\!]+'
 
     # regex pattern for all puntuation symbols
@@ -24,11 +23,34 @@ def keep_devnagri(document:str):
     # identify if the clean text only consists of punctuation
     is_just_punctuation = len(re.sub(punctuation_regex, "", cleaned)) == 0
 
-    # to handle the tokenizer as empty string may cause issues
-    # also this only happens for 5 out of 10000 docs, should not
-    # affect the results
-    if is_just_punctuation:
-        document['text'] = " "
+    return cleaned, is_just_punctuation
+
+def keep_devnagri_hf_doc(document):
+    if isinstance(document['text'], str):
+        batched = False
+    elif isinstance(document['text'], list):
+        batched = True
     else:
-        document['text'] = cleaned
+        raise TypeError("Document must be a dictionary or list.")
+
+    def get_clean_text(text):
+        cleaned_text, is_just_punctuation = keep_devnagri(text)
+        # to handle the tokenizer as empty string may cause issues
+        # also this only happens for 5 out of 10000 docs, should not
+        # affect the results
+        cleaned_text = cleaned_text if not is_just_punctuation else " "
+        return cleaned_text
+
+    if batched:
+        text_ls = document['text']
+        cleaned_text_ls = []
+        for text in text_ls:
+            cleaned_text = get_clean_text(text)
+            cleaned_text_ls.append(cleaned_text)
+        document['text'] = cleaned_text_ls
+    else:
+        text = document['text']
+        cleaned_text = get_clean_text(text)
+        document['text'] = cleaned_text
+
     return document
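For context, here is a hypothetical usage sketch (not included in the commit) of the reworked cleaner, exercising the non-batched and batched paths that keep_devnagri_hf_doc dispatches between; the input strings are made up.

    # Hypothetical usage sketch for the new cleaning helpers.
    from utils import keep_devnagri_hf_doc

    # Non-batched call: 'text' is a single string.
    doc = {"text": "नमस्ते दुनिया! hello 123"}
    print(keep_devnagri_hf_doc(doc)["text"])   # Latin text removed; Devanagari, digits and . ! । kept

    # Batched call: 'text' is a list of strings, as datasets passes with batched=True.
    batch = {"text": ["यह एक परीक्षण है।", "!!!"]}
    print(keep_devnagri_hf_doc(batch)["text"])  # a punctuation-only document becomes " ", never the empty string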