m3hrdadfi commited on
Commit
31bf2aa
1 Parent(s): 8918872

Add runner, fix some bugs

Browse files
src/data_utils.py CHANGED
@@ -22,13 +22,14 @@ def filter_by_num_tokens(text, gt=64):
22
  def filter_by_num_sents(text, gt=2):
23
  return len(sent_tokenize(text)) > gt
24
 
25
- def remove_adds(text,ratio=50):
 
26
  comma = text.split(",")
27
- colon = re.findall(r'(?:([^\W]+):([^\W]+))',text)
28
  virgool = text.split("،")
29
- length_add = len(comma)+len(colon)+len(virgool)
30
 
31
- return True if length_add < ratio else False
32
 
33
 
34
  def normalizer(text, do_lowercase=False):
 
22
  def filter_by_num_sents(text, gt=2):
23
  return len(sent_tokenize(text)) > gt
24
 
25
+
26
+ def filter_by_adv(text, ratio=50):
27
  comma = text.split(",")
28
+ colon = re.findall(r"""(?:([^\W]+):([^\W]+))""", text)
29
  virgool = text.split("،")
30
+ length_add = len(comma) + len(colon) + len(virgool)
31
 
32
+ return length_add < ratio
33
 
34
 
35
  def normalizer(text, do_lowercase=False):
src/normalizer.py CHANGED
@@ -25,13 +25,15 @@ def multiple_replace(text, chars_to_mapping):
25
  pattern = "|".join(map(re.escape, chars_to_mapping.keys()))
26
  return re.sub(pattern, lambda m: chars_to_mapping[m.group()], str(text))
27
 
28
- def remove_tags(text):
29
- tag = "برچسب ها :"
30
- try:
31
- text = text[:text.index(tag)]
32
- return text
33
- except:
34
- return text
 
 
35
 
36
  def clean_url(text):
37
  # removing html tags
@@ -86,7 +88,7 @@ def normalize(text, zwnj="\u200c", tokenized=False):
86
  text = DOUBLE_QUOTE_REGEX.sub('"', text)
87
  text = CURRENCY_REGEX.sub(r" \1 ", text)
88
  text = clean_url(text)
89
- text = remove_tags(text)
90
  text = URL_REGEX.sub(" ", text)
91
  text = EMAIL_REGEX.sub(" ", text)
92
  text = PHONE_REGEX.sub(r" \1 ", text)
@@ -128,34 +130,8 @@ def normalize(text, zwnj="\u200c", tokenized=False):
128
  if __name__ == '__main__':
129
  import textwrap
130
 
131
- # input_text = "دارهٔ تحقیقات فدرال در سال ۱۹۰۸ به نام ادارهٔ تحقیقات (BOI یا BI) بنیان‌گذاری شد. نام این سازمان در سال ۱۹۳۵ به ادارهٔ تحقیقات فدرال تغییر یافت. دفتر مرکزی اف‌بی‌آی در ساختمان جی. ادگار هوور در شهر واشینگتن، دی.سی. واقع شده‌است."
132
- # input_text = "یونان (به یونانی: Ελλάδα, اِلادا)"
133
- # input_text = "نسخهٔ"
134
- # input_text = "ὑ蕉Ұ제ṅ尘̲改座◦花芝秀黄天자埃澤ಿ ˈazbab اینجا ایران خانه‌شما است؟!۱۲۳۱۲۳۱۳۱۲ اَلْحُرُوفُ ٱلْعَرَبِیَّة"
135
- input_text = """
136
- قـــــــــــــــــرار بود با هم کنـــــــــــــار بیایم نه اینکه از کنــــــــــــار هم رد بشیم...!!!
137
- اگر روزی دلت لبریز غم بود گذارت بر مزار کهنه ام بود بگو این بی نصیب خفته در خاک یه روزی عاشق و دیوانه ام بود...
138
- خبر به دورترین نقطه جهان برسد نخواست او به من خسته ، بی کمان برسد شکنجه بیشتر از این که پیش چشم خودت کسی که سهم تو باشد به دیگران برسد خدا کند ، که نفرین نمی کنم نکند به او که عاشق او بوده ام زیان برسد خدا کند فقط این عشق از سرم برود خدا کند که فقط زود آن زمان برسد...
139
- ترسم که شبی از غم ناگه بمیرم در بستر دلسوز با آه بمیرم آن لحظه آخر که اجل گفت بمیر ای کاش تو را بینم و آنگاه بمیرم
140
- خوشبختی را دیروز به حراج گذاشتند ولی حیف که من زاده ی امروزم... خدایا جهنم فرداست پس چرا امروز می سوزم ؟!!
141
- در کرانه محبت رنگ چشمانت را دیدم در عمق بی کران آسمان رخ تو را دیدم دیدنت آرزوی من شده آن را از من دریغ مکن
142
- کوه باشی صخره هایت می شوم... اشک باشی دیدگانت می شوم... رود باشی چشمه سارت می شوم... دوست باشی دوستدارت می شوم...!
143
- امتیاز شما به این کانالوبسایت پردیس ایرانیان آموزش دوره های تخصصی ماساژ زیر نظر مربی فنی و حرفه ای عنوان سایت : پردیس ایرانیان لینک سایت : http://pardisiranian.ir/ آی دی اینستاگرام : pardisiranian.info@ لینک اینستاگرام: […]
144
- Leave a commentپیج اینستاگرام, کانال برنزی, لینک سایتآموزش ماساژ زیر نظر فنی حرفه ای, آموزش ماساژ فنی حرفه ای, آموزش ماساژ فنی حرفه ای تهران, دوره ماساژ فدراسیون پزشکی ورزشی, فدراسیون بین المللی ماساژ ایران, ماساژ آروماتراپی, ماساژ تخصصی, ماساژ سوئدی, ماساژ لیفتینگ صورت, مربی آموزش ماساژ
145
- بایگانی کانالها گزینش ماه اکتبر 2018 سپتامبر 2018 آگوست 2018 جولای 2018 ژوئن 2018 می 2018 آوریل 2018 فوریه 2018
146
- کانال ها بر اساس موضوع گزینش دسته پیج اینستاگرام ربات پیام رسان بله سایر کانال ها کانال آموزشی کانال آی گپ کانال اجتماعی-سیاسی-فرهنگی کانال ایتا کانال برنزی کانال بله کانال بیسفون کانال تجاری کانال تفریحی کانال تلگرام کانال خبری کانال رسمی کانال سروش کانال سلام کانال شخصی کانال گپ کانال محلی کانال نقره ای کانال های طلایی کانال ویسپی لینک سایت مطالب سایت
147
- کانال تبلیغاتی "باکس آگهی" در پیام رسان بله عنوان کانال : کانال تبلیغاتی "باکس آگهی" در پیام رسان بله نام...
148
- عنوان سایت:کسب در آمد از طریق گوشی همراه و اپلیکیشن ۷۰۳۰ لینک سایت : https://7030.ir/r/sayed97 توضیحات کامل کانال:سامانه «هفتاد سی»...
149
- دکوراسیون داخلی سالن های زیبایی نام پیام رسان: تلگرام و اینستاگرام لینک کانال تلگرام : https://t.me/designer_ir لینک صفحه اینستاگرام :...
150
- نصب -تعمیر و فروش انواع کولرگازی عنوان کانال : نصب -تعمیر و فروش انواع کولرگازی نام پیام رسان: تلگرام لینک...
151
- کانال بازیگر ایرانی خانم شبنم قلی خانی عنوان کانال : کانال بازیگر ایرانی خانم شبنم قلی خانی نام پیام رسان:...
152
- تارنمای جامع ثبت کانال ;کانال سروش,کانال تلگرام,کانال بیسفون, کانال گپ, کانال ویسپی, کانال بیسفون پلاس, کانال ایتا, کانال آی گپ, کانال سلام, پیج اینستاگرام,کانال بله,گروه سروش,گروه تلگرامی,گروه بیسفون,گروه گپ,گروه آی گپ,گروه ایتا,گروه سلام,گروه بله,تبلیغات کانال تلگرام,تبلیغات کانال بیسفون,تبلیغات کانال گپ,تبلیغات کانال ویسپی,تبلیغات کانال بیسفون,تبلیغات کانال ایتا,تبلیغات کانال آی گپ,تبلیغات کانال سلام,تبلیغات کانال بله,گروه تبلیغاتی پیام رسان سروش,گروه تبلیغاتی تلگرام,گروه تبلیغاتی بیسفون,گروه تبلیغاتی گپ,گروه تبلیغاتی ویسپی,گروه تبلیغاتی آی گپ,گروه تبلیغاتی سلام,گروه تبلیغاتی بله,افزایش مببر سروش,خرید ممبر سروش,حرید ممبر تلگرام,افزایش ممبر تلگرام,خرید ممبر بیسفون,افزایش ممبر بیسفون,خرید ممبر ایتا,افزایش ممبر ایتا,خرید ممبرآی گپ,افزایش ممبر آی گپ,حرید ممبر گپ,افزایش ممبرگپ,خرید ممبرسلام,افزایش ممبر سلام,خرید ممبر بله ,افزایش ممبر بله,درگاه پیام رسان داخلی,فروشگاه پیام رسان سروش,فروشگاه پیام رسان بله,فروشگاه پیام رسان گپ,فروشگاه پیام رسان آِ گپ,تبلیغات پیام رسانهای داخلی
153
- آموزش ماساژ زیر نظر فنی حرفه ای آموزش ماساژ فنی حرفه ای آموزش ماساژ فنی حرفه ای تهران آون اتوکلاو ارگانیک تریتا انکوباتور انکوباتور یخچالدار تجهیز کامل آزمایشگاه غذایی تعمیر مبل تور آنتالیا تور ارمنستان تور استانبول تور چین تور گرجستان خرید ظرف شویی خرید کولر گازی دوره ماساژ فدراسیون پزشکی ورزشی ربات پیام رسان بله سانتریفیوژ ژربر سانتریفیوژ یونیورسال فدراسیون بین المللی ماساژ ایران فروشگاه ارگانیک تریتا فروشگاه ارگانیک در پیام رسان بله فروشگاه ارگانیک شیراز فروشگاه معتبر ارگانیک ماساژ آروماتراپی ماساژ تخصصی ماساژ سوئدی ماساژ لیفتینگ صورت مربی آموزش ماساژ هود شیمیایی هود لامینار پی اچ متر دیجیتال کانال آموزشی کانال باکس آگهی کانال بله بانک ملی کانال تریتا کانال تفریحی کانال پت شاپ کانال پیام رسان بله کانال گردشگری کلنی کانتر کوره الکتریکی کولر گازی
154
- تمامی حقوق مادی و معنوی سایت متعلق است به:تارنمای جامع ثبت و معرفی رایگان کانال پیام رسانهای ایرانی و خارجی
155
- """
156
-
157
  # input_text = " «هفتاد سی» "
158
- input_text = normalize(input_text)
159
  # input_text = DOUBLE_QUOTE_REGEX.sub('"', input_text)
160
- print(textwrap.fill(input_text))
161
  # print(normalize(input_text, tokenized=True))
 
25
  pattern = "|".join(map(re.escape, chars_to_mapping.keys()))
26
  return re.sub(pattern, lambda m: chars_to_mapping[m.group()], str(text))
27
 
28
+
29
+ def remove_adv_by_tag_name(text, tag_name):
30
+ found = text.find(tag_name)
31
+
32
+ if found > 0:
33
+ text = text[:found]
34
+
35
+ return text
36
+
37
 
38
  def clean_url(text):
39
  # removing html tags
 
88
  text = DOUBLE_QUOTE_REGEX.sub('"', text)
89
  text = CURRENCY_REGEX.sub(r" \1 ", text)
90
  text = clean_url(text)
91
+ text = remove_adv_by_tag_name(text, tag_name="برچسب ها :")
92
  text = URL_REGEX.sub(" ", text)
93
  text = EMAIL_REGEX.sub(" ", text)
94
  text = PHONE_REGEX.sub(r" \1 ", text)
 
130
  if __name__ == '__main__':
131
  import textwrap
132
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  # input_text = " «هفتاد سی» "
134
+ # input_text = normalize(input_text)
135
  # input_text = DOUBLE_QUOTE_REGEX.sub('"', input_text)
136
+ # print(textwrap.fill(input_text))
137
  # print(normalize(input_text, tokenized=True))
src/run.sh ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ export LC_ALL=C.UTF-8
4
+ export LANG=C.UTF-8
5
+
6
+ #export MODEL_NAME_OR_PATH=t5-base
7
+ export OUTPUT_DIR=/home/username/code/gpt2-medium-persian
8
+ export MODEL_TYPE=gpt2
9
+ export CONFIG_NAME=/home/username/code/gpt2-medium-persian
10
+ export TOKENIZER_NAME=/home/username/code/gpt2-medium-persian
11
+
12
+ #export TRAIN_FILE=/home/username/code/data/...csv
13
+ #export VALIDATION_FILE=/home/username/code/data/...csv
14
+ #export TEST_FILE=/home/username/code/data/...csv
15
+ export DATASET_NAME=oscar
16
+ export DATASET_CONFIG_NAME=unshuffled_deduplicated_fa
17
+ export MAX_SEQUENCE_LENGTH=1024
18
+
19
+ #export MAX_TRAIN_SAMPLE=5000
20
+ #export MAX_EVAL_SAMPLES=5000
21
+
22
+ export PER_DEVICE_TRAIN_BATCH_SIZE=8
23
+ export PER_DEVICE_EVAL_BATCH_SIZE=8
24
+ export NUM_TRAIN_EPOCHS=10.0
25
+ export LEARNING_RATE=1e-3
26
+ export WARMUP_STEPS=5000
27
+ export LOGGING_STEPS=500
28
+ export EVAL_STEPS=2500
29
+ export SAVE_STEPS=2500
30
+
31
+ python src/run_clm.py \
32
+ --output_dir="$OUTPUT_DIR" \
33
+ --model_type="$MODEL_TYPE" \
34
+ --config_name="$CONFIG_NAME" \
35
+ --tokenizer_name="$TOKENIZER_NAME" \
36
+ --dataset_name="$DATASET_NAME" \
37
+ --dataset_config_name="$DATASET_CONFIG_NAME" \
38
+ --max_seq_length="$MAX_SEQUENCE_LENGTH" \
39
+ --per_device_train_batch_size=$PER_DEVICE_TRAIN_BATCH_SIZE \
40
+ --per_device_eval_batch_size=$PER_DEVICE_EVAL_BATCH_SIZE \
41
+ --num_train_epochs=$NUM_TRAIN_EPOCHS \
42
+ --learning_rate=$LEARNING_RATE \
43
+ --warmup_steps=$WARMUP_STEPS \
44
+ --logging_step=$LOGGING_STEPS \
45
+ --eval_steps=$EVAL_STEPS \
46
+ --save_steps=$SAVE_STEPS \
47
+ --do_train \
48
+ --do_eval \
49
+ --overwrite_output_dir \
50
+ --push_to_hub
src/run_clm_flax.py CHANGED
@@ -60,6 +60,7 @@ from data_utils import (
60
  filter_by_lang_regex,
61
  filter_by_num_tokens,
62
  filter_by_num_sents,
 
63
  normalizer
64
  )
65
 
@@ -359,8 +360,9 @@ def main():
359
  # https://huggingface.co/docs/datasets/loading_datasets.html.
360
  logger.info("Preprocessing the dataset")
361
  dataset = raw_dataset.filter(lambda example: filter_by_lang_regex(example["text"], ratio=0.75))
362
- dataset = dataset.filter(lambda example: filter_by_num_tokens(example["text"], gt=128))
363
  dataset = dataset.filter(lambda example: filter_by_num_sents(example["text"], gt=2))
 
364
  dataset = dataset.map(normalizer)
365
  logger.info(f"Preprocessed dataset kept {len(dataset)} out of {len(raw_dataset)}")
366
 
@@ -461,7 +463,8 @@ def main():
461
  total_length = len(concatenated_examples[list(examples.keys())[0]])
462
  # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
463
  # customize this part to your needs.
464
- total_length = (total_length // block_size) * block_size
 
465
  # Split by chunks of max_len.
466
  result = {
467
  k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
@@ -538,17 +541,24 @@ def main():
538
  return traverse_util.unflatten_dict(flat_mask)
539
 
540
  # create adam optimizer
541
- adamw = optax.adamw(
542
- learning_rate=linear_decay_lr_schedule_fn,
543
- b1=training_args.adam_beta1,
544
- b2=training_args.adam_beta2,
545
- eps=training_args.adam_epsilon,
546
- weight_decay=training_args.weight_decay,
547
- mask=decay_mask_fn,
548
- )
 
 
 
 
 
 
 
549
 
550
  # Setup train state
551
- state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)
552
 
553
  def loss_fn(logits, labels):
554
  shift_logits = logits[..., :-1, :]
@@ -623,11 +633,12 @@ def main():
623
 
624
  cur_step = epoch * (len(train_dataset) // train_batch_size) + step
625
 
626
- if cur_step % training_args.logging_steps and cur_step > 0:
627
  # Save metrics
628
  train_metric = unreplicate(train_metric)
629
  train_time += time.time() - train_start
630
  if has_tensorboard and jax.process_index() == 0:
 
631
  write_train_metric(summary_writer, train_metrics, train_time, cur_step)
632
 
633
  epochs.write(
@@ -636,45 +647,57 @@ def main():
636
 
637
  train_metrics = []
638
 
639
- # ======================== Evaluating ==============================
640
- eval_metrics = []
641
- eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
642
- eval_steps = len(eval_dataset) // eval_batch_size
643
- for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
644
- # Model forward
645
- batch = next(eval_loader)
646
- metrics = p_eval_step(state.params, batch)
647
- eval_metrics.append(metrics)
648
-
649
- # normalize eval metrics
650
- eval_metrics = get_metrics(eval_metrics)
651
-
652
- eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
653
-
654
- try:
655
- eval_metrics["perplexity"] = math.exp(eval_metrics["loss"])
656
- except OverflowError:
657
- eval_metrics["perplexity"] = float("inf")
658
-
659
- # Print metrics and update progress bar
660
- desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | Eval Perplexity: {eval_metrics['perplexity']})"
661
- epochs.write(desc)
662
- epochs.desc = desc
663
-
664
- # Save metrics
665
- if has_tensorboard and jax.process_index() == 0:
666
- cur_step = epoch * (len(train_dataset) // train_batch_size)
667
- write_eval_metric(summary_writer, eval_metrics, cur_step)
668
-
669
- # save checkpoint after each epoch and push checkpoint to the hub
670
- if jax.process_index() == 0:
671
- params = jax.device_get(unreplicate(state.params))
672
- model.save_pretrained(
673
- training_args.output_dir,
674
- params=params,
675
- push_to_hub=training_args.push_to_hub,
676
- commit_message=f"Saving weights and logs of epoch {epoch + 1}",
677
- )
 
 
 
 
 
 
 
 
 
 
 
 
678
 
679
 
680
  if __name__ == "__main__":
 
60
  filter_by_lang_regex,
61
  filter_by_num_tokens,
62
  filter_by_num_sents,
63
+ filter_by_adv,
64
  normalizer
65
  )
66
 
 
360
  # https://huggingface.co/docs/datasets/loading_datasets.html.
361
  logger.info("Preprocessing the dataset")
362
  dataset = raw_dataset.filter(lambda example: filter_by_lang_regex(example["text"], ratio=0.75))
363
+ dataset = dataset.filter(lambda example: filter_by_num_tokens(example["text"], gt=64))
364
  dataset = dataset.filter(lambda example: filter_by_num_sents(example["text"], gt=2))
365
+ dataset = dataset.filter(lambda example: filter_by_adv(example["text"], ratio=50))
366
  dataset = dataset.map(normalizer)
367
  logger.info(f"Preprocessed dataset kept {len(dataset)} out of {len(raw_dataset)}")
368
 
 
463
  total_length = len(concatenated_examples[list(examples.keys())[0]])
464
  # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
465
  # customize this part to your needs.
466
+ if total_length >= block_size:
467
+ total_length = (total_length // block_size) * block_size
468
  # Split by chunks of max_len.
469
  result = {
470
  k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
 
541
  return traverse_util.unflatten_dict(flat_mask)
542
 
543
  # create adam optimizer
544
+ if training_args.adafactor:
545
+ # We use the default parameters here to initialize adafactor,
546
+ # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
547
+ optimizer = optax.adafactor(
548
+ learning_rate=linear_decay_lr_schedule_fn,
549
+ )
550
+ else:
551
+ optimizer = optax.adamw(
552
+ learning_rate=linear_decay_lr_schedule_fn,
553
+ b1=training_args.adam_beta1,
554
+ b2=training_args.adam_beta2,
555
+ eps=training_args.adam_epsilon,
556
+ weight_decay=training_args.weight_decay,
557
+ mask=decay_mask_fn,
558
+ )
559
 
560
  # Setup train state
561
+ state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer, dropout_rng=dropout_rng)
562
 
563
  def loss_fn(logits, labels):
564
  shift_logits = logits[..., :-1, :]
 
633
 
634
  cur_step = epoch * (len(train_dataset) // train_batch_size) + step
635
 
636
+ if cur_step % training_args.logging_steps == 0 and cur_step > 0:
637
  # Save metrics
638
  train_metric = unreplicate(train_metric)
639
  train_time += time.time() - train_start
640
  if has_tensorboard and jax.process_index() == 0:
641
+ logger.info(f"*** Writing training summary after {cur_step} steps ***")
642
  write_train_metric(summary_writer, train_metrics, train_time, cur_step)
643
 
644
  epochs.write(
 
647
 
648
  train_metrics = []
649
 
650
+ if cur_step % training_args.eval_steps == 0 and cur_step > 0 and training_args.do_eval:
651
+ logger.info(f"*** Evaluation after {cur_step} steps ***")
652
+
653
+ eval_metrics = []
654
+ eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
655
+ eval_steps = len(eval_dataset) // eval_batch_size
656
+ for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
657
+ # Model forward
658
+ batch = next(eval_loader)
659
+ metrics = p_eval_step(state.params, batch)
660
+ eval_metrics.append(metrics)
661
+
662
+ # normalize eval metrics
663
+ eval_metrics = get_metrics(eval_metrics)
664
+ eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
665
+
666
+ try:
667
+ eval_metrics["perplexity"] = math.exp(eval_metrics["loss"])
668
+ except OverflowError:
669
+ eval_metrics["perplexity"] = float("inf")
670
+
671
+ # Print metrics and update progress bar
672
+ desc = f"Step... ({cur_step} | Eval Loss: {eval_metrics['loss']} | Eval Perplexity: {eval_metrics['perplexity']})"
673
+ epochs.write(desc)
674
+ epochs.desc = desc
675
+
676
+ # Save metrics
677
+ if has_tensorboard and jax.process_index() == 0:
678
+ logger.info(f"*** Writing evaluation summary after {cur_step} steps ***")
679
+ # cur_step = epoch * (len(train_dataset) // train_batch_size)
680
+ write_eval_metric(summary_writer, eval_metrics, cur_step)
681
+
682
+ if cur_step % training_args.save_steps == 0 and cur_step > 0:
683
+ logger.info(f"*** Saving checkpoints after {cur_step} steps ***")
684
+ # save checkpoint after each epoch and push checkpoint to the hub
685
+ if jax.process_index() == 0:
686
+ params = jax.device_get(unreplicate(state.params))
687
+ model.save_pretrained(
688
+ training_args.output_dir,
689
+ params=params,
690
+ push_to_hub=training_args.push_to_hub,
691
+ commit_message=f"Saving weights and logs of step {cur_step}",
692
+ )
693
+
694
+ if not os.path.exists(os.path.join(training_args.output_dir, "tokenizer.json")):
695
+ logger.info(f"*** Saving tokenizer ***")
696
+ tokenizer.save_pretrained(
697
+ training_args.output_dir,
698
+ push_to_hub=training_args.push_to_hub,
699
+ commit_message=f"Saving tokenizer",
700
+ )
701
 
702
 
703
  if __name__ == "__main__":
src/run_config.sh ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ export LC_ALL=C.UTF-8
4
+ export LANG=C.UTF-8
5
+
6
+ export OUTPUT_DIR=./
7
+ #export OUTPUT_DIR=/home/username/code/gpt2-medium-persian
8
+ export NAME_OR_PATH=gpt2-medium
9
+
10
+ python src/create_config.py \
11
+ --output_dir="$OUTPUT_DIR" \
12
+ --name_or_path="$NAME_OR_PATH" \
13
+ --params='{"vocab_size": 50000}'
src/run_tokenizer.sh ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ export LC_ALL=C.UTF-8
4
+ export LANG=C.UTF-8
5
+
6
+ export OUTPUT_DIR=/home/username/code/gpt2-medium-persian
7
+ export DATASET_NAME=oscar
8
+ export DATASET_CONFIG_NAME=unshuffled_deduplicated_fa
9
+ export VOCAB_SIZE=50000
10
+ export MIN_FREQUENCY=2
11
+ export SPECIAL_TOKENS='<s>','<pad>','</s>','<unk>','<mask>','<|endoftext|>','<|startoftext|>','<sep>','<cls>','<nl>','<tab>','<zwnj>','[U1]','[U2]','[U3]','[U4]','[U5]','[U6]','[U7]','[U8]','[U9]','[U10]','[U11]','[U12]','[U13]','[U14]','[U15]','[U16]','[U17]','[U18]','[U19]','[U20]'
12
+
13
+
14
+ python src/train_tokenizer.py \
15
+ --output_dir="$OUTPUT_DIR" \
16
+ --dataset_name="$DATASET_NAME" \
17
+ --dataset_config_name="$DATASET_CONFIG_NAME" \
18
+ --vocab_size=$VOCAB_SIZE \
19
+ --min_frequency=$MIN_FREQUENCY \
20
+ --special_tokens="$SPECIAL_TOKENS"
src/train_tokenizer.py CHANGED
@@ -3,7 +3,7 @@ import logging
3
  import os
4
  import sys
5
  from dataclasses import dataclass, field
6
- from typing import Dict, List, Optional, Tuple
7
 
8
  from datasets import load_dataset
9
  from tokenizers import ByteLevelBPETokenizer
@@ -15,6 +15,7 @@ from data_utils import (
15
  filter_by_lang_regex,
16
  filter_by_num_tokens,
17
  filter_by_num_sents,
 
18
  normalizer
19
  )
20
 
@@ -42,12 +43,12 @@ class TokenizerArguments:
42
  default=None,
43
  metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
44
  )
45
- special_tokens: Optional[List[str]] = field(
46
  default=None,
47
  metadata={"help": "The list of special tokens that you want to add in your training."}
48
  )
49
  vocab_size: Optional[int] = field(
50
- default=50257,
51
  metadata={"help": "The size of the final vocabulary, including all tokens and alphabet"}
52
  )
53
  min_frequency: Optional[int] = field(
@@ -61,13 +62,16 @@ class TokenizerArguments:
61
 
62
  def __post_init__(self):
63
  if self.special_tokens is None:
64
- self.special_tokens = [
65
  "<s>", "<pad>", "</s>", "<unk>", "<mask>",
66
  "<|endoftext|>", "<|startoftext|>",
67
  "<sep>", "<cls>", "<nl>", "<tab>", "<zwnj>"
68
  ]
 
 
 
69
 
70
- self.special_tokens = self.special_tokens + [f"[U{i}]" for i in range(1, 21)]
71
  if self.dataset_name is None and self.train_file is None:
72
  raise ValueError("Need either a dataset name or a training file.")
73
  else:
@@ -100,7 +104,7 @@ def main():
100
  tokenizer_args.dataset_name,
101
  tokenizer_args.dataset_config_name,
102
  cache_dir=tokenizer_args.cache_dir,
103
- split="train[:10%]"
104
  )
105
  else:
106
  data_files = {"train": tokenizer_args.train_file}
@@ -119,6 +123,7 @@ def main():
119
  dataset = raw_dataset.filter(lambda example: filter_by_lang_regex(example["text"], ratio=0.75))
120
  dataset = dataset.filter(lambda example: filter_by_num_tokens(example["text"], gt=64))
121
  dataset = dataset.filter(lambda example: filter_by_num_sents(example["text"], gt=2))
 
122
  dataset = dataset.map(normalizer)
123
  logger.info(f"Preprocessed dataset kept {len(dataset)} out of {len(raw_dataset)}")
124
 
 
3
  import os
4
  import sys
5
  from dataclasses import dataclass, field
6
+ from typing import Dict, List, Optional, Tuple, Union, Any
7
 
8
  from datasets import load_dataset
9
  from tokenizers import ByteLevelBPETokenizer
 
15
  filter_by_lang_regex,
16
  filter_by_num_tokens,
17
  filter_by_num_sents,
18
+ filter_by_adv,
19
  normalizer
20
  )
21
 
 
43
  default=None,
44
  metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
45
  )
46
+ special_tokens: Optional[str] = field(
47
  default=None,
48
  metadata={"help": "The list of special tokens that you want to add in your training."}
49
  )
50
  vocab_size: Optional[int] = field(
51
+ default=56000,
52
  metadata={"help": "The size of the final vocabulary, including all tokens and alphabet"}
53
  )
54
  min_frequency: Optional[int] = field(
 
62
 
63
  def __post_init__(self):
64
  if self.special_tokens is None:
65
+ special_tokens = [
66
  "<s>", "<pad>", "</s>", "<unk>", "<mask>",
67
  "<|endoftext|>", "<|startoftext|>",
68
  "<sep>", "<cls>", "<nl>", "<tab>", "<zwnj>"
69
  ]
70
+ special_tokens += [f"[U{i}]" for i in range(1, 21)]
71
+ else:
72
+ special_tokens = list(self.special_tokens.split(","))
73
 
74
+ self.special_tokens = special_tokens
75
  if self.dataset_name is None and self.train_file is None:
76
  raise ValueError("Need either a dataset name or a training file.")
77
  else:
 
104
  tokenizer_args.dataset_name,
105
  tokenizer_args.dataset_config_name,
106
  cache_dir=tokenizer_args.cache_dir,
107
+ split="train"
108
  )
109
  else:
110
  data_files = {"train": tokenizer_args.train_file}
 
123
  dataset = raw_dataset.filter(lambda example: filter_by_lang_regex(example["text"], ratio=0.75))
124
  dataset = dataset.filter(lambda example: filter_by_num_tokens(example["text"], gt=64))
125
  dataset = dataset.filter(lambda example: filter_by_num_sents(example["text"], gt=2))
126
+ dataset = dataset.filter(lambda example: filter_by_adv(example["text"], ratio=50))
127
  dataset = dataset.map(normalizer)
128
  logger.info(f"Preprocessed dataset kept {len(dataset)} out of {len(raw_dataset)}")
129