Add runner, fix some bugs
Browse files- src/data_utils.py +5 -4
- src/normalizer.py +12 -36
- src/run.sh +50 -0
- src/run_clm_flax.py +74 -51
- src/run_config.sh +13 -0
- src/run_tokenizer.sh +20 -0
- src/train_tokenizer.py +11 -6
src/data_utils.py
CHANGED
@@ -22,13 +22,14 @@ def filter_by_num_tokens(text, gt=64):
|
|
22 |
def filter_by_num_sents(text, gt=2):
|
23 |
return len(sent_tokenize(text)) > gt
|
24 |
|
25 |
-
|
|
|
26 |
comma = text.split(",")
|
27 |
-
colon = re.findall(r
|
28 |
virgool = text.split("،")
|
29 |
-
length_add = len(comma)+len(colon)+len(virgool)
|
30 |
|
31 |
-
return
|
32 |
|
33 |
|
34 |
def normalizer(text, do_lowercase=False):
|
|
|
22 |
def filter_by_num_sents(text, gt=2):
|
23 |
return len(sent_tokenize(text)) > gt
|
24 |
|
25 |
+
|
26 |
+
def filter_by_adv(text, ratio=50):
|
27 |
comma = text.split(",")
|
28 |
+
colon = re.findall(r"""(?:([^\W]+):([^\W]+))""", text)
|
29 |
virgool = text.split("،")
|
30 |
+
length_add = len(comma) + len(colon) + len(virgool)
|
31 |
|
32 |
+
return length_add < ratio
|
33 |
|
34 |
|
35 |
def normalizer(text, do_lowercase=False):
|
src/normalizer.py
CHANGED
@@ -25,13 +25,15 @@ def multiple_replace(text, chars_to_mapping):
|
|
25 |
pattern = "|".join(map(re.escape, chars_to_mapping.keys()))
|
26 |
return re.sub(pattern, lambda m: chars_to_mapping[m.group()], str(text))
|
27 |
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
|
|
|
|
35 |
|
36 |
def clean_url(text):
|
37 |
# removing html tags
|
@@ -86,7 +88,7 @@ def normalize(text, zwnj="\u200c", tokenized=False):
|
|
86 |
text = DOUBLE_QUOTE_REGEX.sub('"', text)
|
87 |
text = CURRENCY_REGEX.sub(r" \1 ", text)
|
88 |
text = clean_url(text)
|
89 |
-
text =
|
90 |
text = URL_REGEX.sub(" ", text)
|
91 |
text = EMAIL_REGEX.sub(" ", text)
|
92 |
text = PHONE_REGEX.sub(r" \1 ", text)
|
@@ -128,34 +130,8 @@ def normalize(text, zwnj="\u200c", tokenized=False):
|
|
128 |
if __name__ == '__main__':
|
129 |
import textwrap
|
130 |
|
131 |
-
# input_text = "دارهٔ تحقیقات فدرال در سال ۱۹۰۸ به نام ادارهٔ تحقیقات (BOI یا BI) بنیانگذاری شد. نام این سازمان در سال ۱۹۳۵ به ادارهٔ تحقیقات فدرال تغییر یافت. دفتر مرکزی افبیآی در ساختمان جی. ادگار هوور در شهر واشینگتن، دی.سی. واقع شدهاست."
|
132 |
-
# input_text = "یونان (به یونانی: Ελλάδα, اِلادا)"
|
133 |
-
# input_text = "نسخهٔ"
|
134 |
-
# input_text = "ὑ蕉Ұ제ṅ尘̲改座◦花芝秀黄天자埃澤ಿ ˈazbab اینجا ایران خانهشما است؟!۱۲۳۱۲۳۱۳۱۲ اَلْحُرُوفُ ٱلْعَرَبِیَّة"
|
135 |
-
input_text = """
|
136 |
-
قـــــــــــــــــرار بود با هم کنـــــــــــــار بیایم نه اینکه از کنــــــــــــار هم رد بشیم...!!!
|
137 |
-
اگر روزی دلت لبریز غم بود گذارت بر مزار کهنه ام بود بگو این بی نصیب خفته در خاک یه روزی عاشق و دیوانه ام بود...
|
138 |
-
خبر به دورترین نقطه جهان برسد نخواست او به من خسته ، بی کمان برسد شکنجه بیشتر از این که پیش چشم خودت کسی که سهم تو باشد به دیگران برسد خدا کند ، که نفرین نمی کنم نکند به او که عاشق او بوده ام زیان برسد خدا کند فقط این عشق از سرم برود خدا کند که فقط زود آن زمان برسد...
|
139 |
-
ترسم که شبی از غم ناگه بمیرم در بستر دلسوز با آه بمیرم آن لحظه آخر که اجل گفت بمیر ای کاش تو را بینم و آنگاه بمیرم
|
140 |
-
خوشبختی را دیروز به حراج گذاشتند ولی حیف که من زاده ی امروزم... خدایا جهنم فرداست پس چرا امروز می سوزم ؟!!
|
141 |
-
در کرانه محبت رنگ چشمانت را دیدم در عمق بی کران آسمان رخ تو را دیدم دیدنت آرزوی من شده آن را از من دریغ مکن
|
142 |
-
کوه باشی صخره هایت می شوم... اشک باشی دیدگانت می شوم... رود باشی چشمه سارت می شوم... دوست باشی دوستدارت می شوم...!
|
143 |
-
امتیاز شما به این کانالوبسایت پردیس ایرانیان آموزش دوره های تخصصی ماساژ زیر نظر مربی فنی و حرفه ای عنوان سایت : پردیس ایرانیان لینک سایت : http://pardisiranian.ir/ آی دی اینستاگرام : pardisiranian.info@ لینک اینستاگرام: […]
|
144 |
-
Leave a commentپیج اینستاگرام, کانال برنزی, لینک سایتآموزش ماساژ زیر نظر فنی حرفه ای, آموزش ماساژ فنی حرفه ای, آموزش ماساژ فنی حرفه ای تهران, دوره ماساژ فدراسیون پزشکی ورزشی, فدراسیون بین المللی ماساژ ایران, ماساژ آروماتراپی, ماساژ تخصصی, ماساژ سوئدی, ماساژ لیفتینگ صورت, مربی آموزش ماساژ
|
145 |
-
بایگانی کانالها گزینش ماه اکتبر 2018 سپتامبر 2018 آگوست 2018 جولای 2018 ژوئن 2018 می 2018 آوریل 2018 فوریه 2018
|
146 |
-
کانال ها بر اساس موضوع گزینش دسته پیج اینستاگرام ربات پیام رسان بله سایر کانال ها کانال آموزشی کانال آی گپ کانال اجتماعی-سیاسی-فرهنگی کانال ایتا کانال برنزی کانال بله کانال بیسفون کانال تجاری کانال تفریحی کانال تلگرام کانال خبری کانال رسمی کانال سروش کانال سلام کانال شخصی کانال گپ کانال محلی کانال نقره ای کانال های طلایی کانال ویسپی لینک سایت مطالب سایت
|
147 |
-
کانال تبلیغاتی "باکس آگهی" در پیام رسان بله عنوان کانال : کانال تبلیغاتی "باکس آگهی" در پیام رسان بله نام...
|
148 |
-
عنوان سایت:کسب در آمد از طریق گوشی همراه و اپلیکیشن ۷۰۳۰ لینک سایت : https://7030.ir/r/sayed97 توضیحات کامل کانال:سامانه «هفتاد سی»...
|
149 |
-
دکوراسیون داخلی سالن های زیبایی نام پیام رسان: تلگرام و اینستاگرام لینک کانال تلگرام : https://t.me/designer_ir لینک صفحه اینستاگرام :...
|
150 |
-
نصب -تعمیر و فروش انواع کولرگازی عنوان کانال : نصب -تعمیر و فروش انواع کولرگازی نام پیام رسان: تلگرام لینک...
|
151 |
-
کانال بازیگر ایرانی خانم شبنم قلی خانی عنوان کانال : کانال بازیگر ایرانی خانم شبنم قلی خانی نام پیام رسان:...
|
152 |
-
تارنمای جامع ثبت کانال ;کانال سروش,کانال تلگرام,کانال بیسفون, کانال گپ, کانال ویسپی, کانال بیسفون پلاس, کانال ایتا, کانال آی گپ, کانال سلام, پیج اینستاگرام,کانال بله,گروه سروش,گروه تلگرامی,گروه بیسفون,گروه گپ,گروه آی گپ,گروه ایتا,گروه سلام,گروه بله,تبلیغات کانال تلگرام,تبلیغات کانال بیسفون,تبلیغات کانال گپ,تبلیغات کانال ویسپی,تبلیغات کانال بیسفون,تبلیغات کانال ایتا,تبلیغات کانال آی گپ,تبلیغات کانال سلام,تبلیغات کانال بله,گروه تبلیغاتی پیام رسان سروش,گروه تبلیغاتی تلگرام,گروه تبلیغاتی بیسفون,گروه تبلیغاتی گپ,گروه تبلیغاتی ویسپی,گروه تبلیغاتی آی گپ,گروه تبلیغاتی سلام,گروه تبلیغاتی بله,افزایش مببر سروش,خرید ممبر سروش,حرید ممبر تلگرام,افزایش ممبر تلگرام,خرید ممبر بیسفون,افزایش ممبر بیسفون,خرید ممبر ایتا,افزایش ممبر ایتا,خرید ممبرآی گپ,افزایش ممبر آی گپ,حرید ممبر گپ,افزایش ممبرگپ,خرید ممبرسلام,افزایش ممبر سلام,خرید ممبر بله ,افزایش ممبر بله,درگاه پیام رسان داخلی,فروشگاه پیام رسان سروش,فروشگاه پیام رسان بله,فروشگاه پیام رسان گپ,فروشگاه پیام رسان آِ گپ,تبلیغات پیام رسانهای داخلی
|
153 |
-
آموزش ماساژ زیر نظر فنی حرفه ای آموزش ماساژ فنی حرفه ای آموزش ماساژ فنی حرفه ای تهران آون اتوکلاو ارگانیک تریتا انکوباتور انکوباتور یخچالدار تجهیز کامل آزمایشگاه غذایی تعمیر مبل تور آنتالیا تور ارمنستان تور استانبول تور چین تور گرجستان خرید ظرف شویی خرید کولر گازی دوره ماساژ فدراسیون پزشکی ورزشی ربات پیام رسان بله سانتریفیوژ ژربر سانتریفیوژ یونیورسال فدراسیون بین المللی ماساژ ایران فروشگاه ارگانیک تریتا فروشگاه ارگانیک در پیام رسان بله فروشگاه ارگانیک شیراز فروشگاه معتبر ارگانیک ماساژ آروماتراپی ماساژ تخصصی ماساژ سوئدی ماساژ لیفتینگ صورت مربی آموزش ماساژ هود شیمیایی هود لامینار پی اچ متر دیجیتال کانال آموزشی کانال باکس آگهی کانال بله بانک ملی کانال تریتا کانال تفریحی کانال پت شاپ کانال پیام رسان بله کانال گردشگری کلنی کانتر کوره الکتریکی کولر گازی
|
154 |
-
تمامی حقوق مادی و معنوی سایت متعلق است به:تارنمای جامع ثبت و معرفی رایگان کانال پیام رسانهای ایرانی و خارجی
|
155 |
-
"""
|
156 |
-
|
157 |
# input_text = " «هفتاد سی» "
|
158 |
-
input_text = normalize(input_text)
|
159 |
# input_text = DOUBLE_QUOTE_REGEX.sub('"', input_text)
|
160 |
-
print(textwrap.fill(input_text))
|
161 |
# print(normalize(input_text, tokenized=True))
|
|
|
25 |
pattern = "|".join(map(re.escape, chars_to_mapping.keys()))
|
26 |
return re.sub(pattern, lambda m: chars_to_mapping[m.group()], str(text))
|
27 |
|
28 |
+
|
29 |
+
def remove_adv_by_tag_name(text, tag_name):
|
30 |
+
found = text.find(tag_name)
|
31 |
+
|
32 |
+
if found > 0:
|
33 |
+
text = text[:found]
|
34 |
+
|
35 |
+
return text
|
36 |
+
|
37 |
|
38 |
def clean_url(text):
|
39 |
# removing html tags
|
|
|
88 |
text = DOUBLE_QUOTE_REGEX.sub('"', text)
|
89 |
text = CURRENCY_REGEX.sub(r" \1 ", text)
|
90 |
text = clean_url(text)
|
91 |
+
text = remove_adv_by_tag_name(text, tag_name="برچسب ها :")
|
92 |
text = URL_REGEX.sub(" ", text)
|
93 |
text = EMAIL_REGEX.sub(" ", text)
|
94 |
text = PHONE_REGEX.sub(r" \1 ", text)
|
|
|
130 |
if __name__ == '__main__':
|
131 |
import textwrap
|
132 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
133 |
# input_text = " «هفتاد سی» "
|
134 |
+
# input_text = normalize(input_text)
|
135 |
# input_text = DOUBLE_QUOTE_REGEX.sub('"', input_text)
|
136 |
+
# print(textwrap.fill(input_text))
|
137 |
# print(normalize(input_text, tokenized=True))
|
src/run.sh
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
export LC_ALL=C.UTF-8
|
4 |
+
export LANG=C.UTF-8
|
5 |
+
|
6 |
+
#export MODEL_NAME_OR_PATH=t5-base
|
7 |
+
export OUTPUT_DIR=/home/username/code/gpt2-medium-persian
|
8 |
+
export MODEL_TYPE=gpt2
|
9 |
+
export CONFIG_NAME=/home/username/code/gpt2-medium-persian
|
10 |
+
export TOKENIZER_NAME=/home/username/code/gpt2-medium-persian
|
11 |
+
|
12 |
+
#export TRAIN_FILE=/home/username/code/data/...csv
|
13 |
+
#export VALIDATION_FILE=/home/username/code/data/...csv
|
14 |
+
#export TEST_FILE=/home/username/code/data/...csv
|
15 |
+
export DATASET_NAME=oscar
|
16 |
+
export DATASET_CONFIG_NAME=unshuffled_deduplicated_fa
|
17 |
+
export MAX_SEQUENCE_LENGTH=1024
|
18 |
+
|
19 |
+
#export MAX_TRAIN_SAMPLE=5000
|
20 |
+
#export MAX_EVAL_SAMPLES=5000
|
21 |
+
|
22 |
+
export PER_DEVICE_TRAIN_BATCH_SIZE=8
|
23 |
+
export PER_DEVICE_EVAL_BATCH_SIZE=8
|
24 |
+
export NUM_TRAIN_EPOCHS=10.0
|
25 |
+
export LEARNING_RATE=1e-3
|
26 |
+
export WARMUP_STEPS=5000
|
27 |
+
export LOGGING_STEPS=500
|
28 |
+
export EVAL_STEPS=2500
|
29 |
+
export SAVE_STEPS=2500
|
30 |
+
|
31 |
+
python src/run_clm.py \
|
32 |
+
--output_dir="$OUTPUT_DIR" \
|
33 |
+
--model_type="$MODEL_TYPE" \
|
34 |
+
--config_name="$CONFIG_NAME" \
|
35 |
+
--tokenizer_name="$TOKENIZER_NAME" \
|
36 |
+
--dataset_name="$DATASET_NAME" \
|
37 |
+
--dataset_config_name="$DATASET_CONFIG_NAME" \
|
38 |
+
--max_seq_length="$MAX_SEQUENCE_LENGTH" \
|
39 |
+
--per_device_train_batch_size=$PER_DEVICE_TRAIN_BATCH_SIZE \
|
40 |
+
--per_device_eval_batch_size=$PER_DEVICE_EVAL_BATCH_SIZE \
|
41 |
+
--num_train_epochs=$NUM_TRAIN_EPOCHS \
|
42 |
+
--learning_rate=$LEARNING_RATE \
|
43 |
+
--warmup_steps=$WARMUP_STEPS \
|
44 |
+
--logging_step=$LOGGING_STEPS \
|
45 |
+
--eval_steps=$EVAL_STEPS \
|
46 |
+
--save_steps=$SAVE_STEPS \
|
47 |
+
--do_train \
|
48 |
+
--do_eval \
|
49 |
+
--overwrite_output_dir \
|
50 |
+
--push_to_hub
|
src/run_clm_flax.py
CHANGED
@@ -60,6 +60,7 @@ from data_utils import (
|
|
60 |
filter_by_lang_regex,
|
61 |
filter_by_num_tokens,
|
62 |
filter_by_num_sents,
|
|
|
63 |
normalizer
|
64 |
)
|
65 |
|
@@ -359,8 +360,9 @@ def main():
|
|
359 |
# https://huggingface.co/docs/datasets/loading_datasets.html.
|
360 |
logger.info("Preprocessing the dataset")
|
361 |
dataset = raw_dataset.filter(lambda example: filter_by_lang_regex(example["text"], ratio=0.75))
|
362 |
-
dataset = dataset.filter(lambda example: filter_by_num_tokens(example["text"], gt=
|
363 |
dataset = dataset.filter(lambda example: filter_by_num_sents(example["text"], gt=2))
|
|
|
364 |
dataset = dataset.map(normalizer)
|
365 |
logger.info(f"Preprocessed dataset kept {len(dataset)} out of {len(raw_dataset)}")
|
366 |
|
@@ -461,7 +463,8 @@ def main():
|
|
461 |
total_length = len(concatenated_examples[list(examples.keys())[0]])
|
462 |
# We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
|
463 |
# customize this part to your needs.
|
464 |
-
|
|
|
465 |
# Split by chunks of max_len.
|
466 |
result = {
|
467 |
k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
|
@@ -538,17 +541,24 @@ def main():
|
|
538 |
return traverse_util.unflatten_dict(flat_mask)
|
539 |
|
540 |
# create adam optimizer
|
541 |
-
|
542 |
-
|
543 |
-
|
544 |
-
|
545 |
-
|
546 |
-
|
547 |
-
|
548 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
549 |
|
550 |
# Setup train state
|
551 |
-
state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=
|
552 |
|
553 |
def loss_fn(logits, labels):
|
554 |
shift_logits = logits[..., :-1, :]
|
@@ -623,11 +633,12 @@ def main():
|
|
623 |
|
624 |
cur_step = epoch * (len(train_dataset) // train_batch_size) + step
|
625 |
|
626 |
-
if cur_step % training_args.logging_steps and cur_step > 0:
|
627 |
# Save metrics
|
628 |
train_metric = unreplicate(train_metric)
|
629 |
train_time += time.time() - train_start
|
630 |
if has_tensorboard and jax.process_index() == 0:
|
|
|
631 |
write_train_metric(summary_writer, train_metrics, train_time, cur_step)
|
632 |
|
633 |
epochs.write(
|
@@ -636,45 +647,57 @@ def main():
|
|
636 |
|
637 |
train_metrics = []
|
638 |
|
639 |
-
|
640 |
-
|
641 |
-
|
642 |
-
|
643 |
-
|
644 |
-
|
645 |
-
|
646 |
-
|
647 |
-
|
648 |
-
|
649 |
-
|
650 |
-
|
651 |
-
|
652 |
-
|
653 |
-
|
654 |
-
|
655 |
-
|
656 |
-
|
657 |
-
|
658 |
-
|
659 |
-
|
660 |
-
|
661 |
-
|
662 |
-
|
663 |
-
|
664 |
-
|
665 |
-
|
666 |
-
|
667 |
-
|
668 |
-
|
669 |
-
|
670 |
-
|
671 |
-
|
672 |
-
|
673 |
-
|
674 |
-
|
675 |
-
|
676 |
-
|
677 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
678 |
|
679 |
|
680 |
if __name__ == "__main__":
|
|
|
60 |
filter_by_lang_regex,
|
61 |
filter_by_num_tokens,
|
62 |
filter_by_num_sents,
|
63 |
+
filter_by_adv,
|
64 |
normalizer
|
65 |
)
|
66 |
|
|
|
360 |
# https://huggingface.co/docs/datasets/loading_datasets.html.
|
361 |
logger.info("Preprocessing the dataset")
|
362 |
dataset = raw_dataset.filter(lambda example: filter_by_lang_regex(example["text"], ratio=0.75))
|
363 |
+
dataset = dataset.filter(lambda example: filter_by_num_tokens(example["text"], gt=64))
|
364 |
dataset = dataset.filter(lambda example: filter_by_num_sents(example["text"], gt=2))
|
365 |
+
dataset = dataset.filter(lambda example: filter_by_adv(example["text"], ratio=50))
|
366 |
dataset = dataset.map(normalizer)
|
367 |
logger.info(f"Preprocessed dataset kept {len(dataset)} out of {len(raw_dataset)}")
|
368 |
|
|
|
463 |
total_length = len(concatenated_examples[list(examples.keys())[0]])
|
464 |
# We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
|
465 |
# customize this part to your needs.
|
466 |
+
if total_length >= block_size:
|
467 |
+
total_length = (total_length // block_size) * block_size
|
468 |
# Split by chunks of max_len.
|
469 |
result = {
|
470 |
k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
|
|
|
541 |
return traverse_util.unflatten_dict(flat_mask)
|
542 |
|
543 |
# create adam optimizer
|
544 |
+
if training_args.adafactor:
|
545 |
+
# We use the default parameters here to initialize adafactor,
|
546 |
+
# For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
|
547 |
+
optimizer = optax.adafactor(
|
548 |
+
learning_rate=linear_decay_lr_schedule_fn,
|
549 |
+
)
|
550 |
+
else:
|
551 |
+
optimizer = optax.adamw(
|
552 |
+
learning_rate=linear_decay_lr_schedule_fn,
|
553 |
+
b1=training_args.adam_beta1,
|
554 |
+
b2=training_args.adam_beta2,
|
555 |
+
eps=training_args.adam_epsilon,
|
556 |
+
weight_decay=training_args.weight_decay,
|
557 |
+
mask=decay_mask_fn,
|
558 |
+
)
|
559 |
|
560 |
# Setup train state
|
561 |
+
state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer, dropout_rng=dropout_rng)
|
562 |
|
563 |
def loss_fn(logits, labels):
|
564 |
shift_logits = logits[..., :-1, :]
|
|
|
633 |
|
634 |
cur_step = epoch * (len(train_dataset) // train_batch_size) + step
|
635 |
|
636 |
+
if cur_step % training_args.logging_steps == 0 and cur_step > 0:
|
637 |
# Save metrics
|
638 |
train_metric = unreplicate(train_metric)
|
639 |
train_time += time.time() - train_start
|
640 |
if has_tensorboard and jax.process_index() == 0:
|
641 |
+
logger.info(f"*** Writing training summary after {cur_step} steps ***")
|
642 |
write_train_metric(summary_writer, train_metrics, train_time, cur_step)
|
643 |
|
644 |
epochs.write(
|
|
|
647 |
|
648 |
train_metrics = []
|
649 |
|
650 |
+
if cur_step % training_args.eval_steps == 0 and cur_step > 0 and training_args.do_eval:
|
651 |
+
logger.info(f"*** Evaluation after {cur_step} steps ***")
|
652 |
+
|
653 |
+
eval_metrics = []
|
654 |
+
eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
|
655 |
+
eval_steps = len(eval_dataset) // eval_batch_size
|
656 |
+
for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
|
657 |
+
# Model forward
|
658 |
+
batch = next(eval_loader)
|
659 |
+
metrics = p_eval_step(state.params, batch)
|
660 |
+
eval_metrics.append(metrics)
|
661 |
+
|
662 |
+
# normalize eval metrics
|
663 |
+
eval_metrics = get_metrics(eval_metrics)
|
664 |
+
eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
|
665 |
+
|
666 |
+
try:
|
667 |
+
eval_metrics["perplexity"] = math.exp(eval_metrics["loss"])
|
668 |
+
except OverflowError:
|
669 |
+
eval_metrics["perplexity"] = float("inf")
|
670 |
+
|
671 |
+
# Print metrics and update progress bar
|
672 |
+
desc = f"Step... ({cur_step} | Eval Loss: {eval_metrics['loss']} | Eval Perplexity: {eval_metrics['perplexity']})"
|
673 |
+
epochs.write(desc)
|
674 |
+
epochs.desc = desc
|
675 |
+
|
676 |
+
# Save metrics
|
677 |
+
if has_tensorboard and jax.process_index() == 0:
|
678 |
+
logger.info(f"*** Writing evaluation summary after {cur_step} steps ***")
|
679 |
+
# cur_step = epoch * (len(train_dataset) // train_batch_size)
|
680 |
+
write_eval_metric(summary_writer, eval_metrics, cur_step)
|
681 |
+
|
682 |
+
if cur_step % training_args.save_steps == 0 and cur_step > 0:
|
683 |
+
logger.info(f"*** Saving checkpoints after {cur_step} steps ***")
|
684 |
+
# save checkpoint after each epoch and push checkpoint to the hub
|
685 |
+
if jax.process_index() == 0:
|
686 |
+
params = jax.device_get(unreplicate(state.params))
|
687 |
+
model.save_pretrained(
|
688 |
+
training_args.output_dir,
|
689 |
+
params=params,
|
690 |
+
push_to_hub=training_args.push_to_hub,
|
691 |
+
commit_message=f"Saving weights and logs of step {cur_step}",
|
692 |
+
)
|
693 |
+
|
694 |
+
if not os.path.exists(os.path.join(training_args.output_dir, "tokenizer.json")):
|
695 |
+
logger.info(f"*** Saving tokenizer ***")
|
696 |
+
tokenizer.save_pretrained(
|
697 |
+
training_args.output_dir,
|
698 |
+
push_to_hub=training_args.push_to_hub,
|
699 |
+
commit_message=f"Saving tokenizer",
|
700 |
+
)
|
701 |
|
702 |
|
703 |
if __name__ == "__main__":
|
src/run_config.sh
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
export LC_ALL=C.UTF-8
|
4 |
+
export LANG=C.UTF-8
|
5 |
+
|
6 |
+
export OUTPUT_DIR=./
|
7 |
+
#export OUTPUT_DIR=/home/username/code/gpt2-medium-persian
|
8 |
+
export NAME_OR_PATH=gpt2-medium
|
9 |
+
|
10 |
+
python src/create_config.py \
|
11 |
+
--output_dir="$OUTPUT_DIR" \
|
12 |
+
--name_or_path="$NAME_OR_PATH" \
|
13 |
+
--params='{"vocab_size": 50000}'
|
src/run_tokenizer.sh
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
export LC_ALL=C.UTF-8
|
4 |
+
export LANG=C.UTF-8
|
5 |
+
|
6 |
+
export OUTPUT_DIR=/home/username/code/gpt2-medium-persian
|
7 |
+
export DATASET_NAME=oscar
|
8 |
+
export DATASET_CONFIG_NAME=unshuffled_deduplicated_fa
|
9 |
+
export VOCAB_SIZE=50000
|
10 |
+
export MIN_FREQUENCY=2
|
11 |
+
export SPECIAL_TOKENS='<s>','<pad>','</s>','<unk>','<mask>','<|endoftext|>','<|startoftext|>','<sep>','<cls>','<nl>','<tab>','<zwnj>','[U1]','[U2]','[U3]','[U4]','[U5]','[U6]','[U7]','[U8]','[U9]','[U10]','[U11]','[U12]','[U13]','[U14]','[U15]','[U16]','[U17]','[U18]','[U19]','[U20]'
|
12 |
+
|
13 |
+
|
14 |
+
python src/train_tokenizer.py \
|
15 |
+
--output_dir="$OUTPUT_DIR" \
|
16 |
+
--dataset_name="$DATASET_NAME" \
|
17 |
+
--dataset_config_name="$DATASET_CONFIG_NAME" \
|
18 |
+
--vocab_size=$VOCAB_SIZE \
|
19 |
+
--min_frequency=$MIN_FREQUENCY \
|
20 |
+
--special_tokens="$SPECIAL_TOKENS"
|
src/train_tokenizer.py
CHANGED
@@ -3,7 +3,7 @@ import logging
|
|
3 |
import os
|
4 |
import sys
|
5 |
from dataclasses import dataclass, field
|
6 |
-
from typing import Dict, List, Optional, Tuple
|
7 |
|
8 |
from datasets import load_dataset
|
9 |
from tokenizers import ByteLevelBPETokenizer
|
@@ -15,6 +15,7 @@ from data_utils import (
|
|
15 |
filter_by_lang_regex,
|
16 |
filter_by_num_tokens,
|
17 |
filter_by_num_sents,
|
|
|
18 |
normalizer
|
19 |
)
|
20 |
|
@@ -42,12 +43,12 @@ class TokenizerArguments:
|
|
42 |
default=None,
|
43 |
metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
|
44 |
)
|
45 |
-
special_tokens: Optional[
|
46 |
default=None,
|
47 |
metadata={"help": "The list of special tokens that you want to add in your training."}
|
48 |
)
|
49 |
vocab_size: Optional[int] = field(
|
50 |
-
default=
|
51 |
metadata={"help": "The size of the final vocabulary, including all tokens and alphabet"}
|
52 |
)
|
53 |
min_frequency: Optional[int] = field(
|
@@ -61,13 +62,16 @@ class TokenizerArguments:
|
|
61 |
|
62 |
def __post_init__(self):
|
63 |
if self.special_tokens is None:
|
64 |
-
|
65 |
"<s>", "<pad>", "</s>", "<unk>", "<mask>",
|
66 |
"<|endoftext|>", "<|startoftext|>",
|
67 |
"<sep>", "<cls>", "<nl>", "<tab>", "<zwnj>"
|
68 |
]
|
|
|
|
|
|
|
69 |
|
70 |
-
self.special_tokens =
|
71 |
if self.dataset_name is None and self.train_file is None:
|
72 |
raise ValueError("Need either a dataset name or a training file.")
|
73 |
else:
|
@@ -100,7 +104,7 @@ def main():
|
|
100 |
tokenizer_args.dataset_name,
|
101 |
tokenizer_args.dataset_config_name,
|
102 |
cache_dir=tokenizer_args.cache_dir,
|
103 |
-
split="train
|
104 |
)
|
105 |
else:
|
106 |
data_files = {"train": tokenizer_args.train_file}
|
@@ -119,6 +123,7 @@ def main():
|
|
119 |
dataset = raw_dataset.filter(lambda example: filter_by_lang_regex(example["text"], ratio=0.75))
|
120 |
dataset = dataset.filter(lambda example: filter_by_num_tokens(example["text"], gt=64))
|
121 |
dataset = dataset.filter(lambda example: filter_by_num_sents(example["text"], gt=2))
|
|
|
122 |
dataset = dataset.map(normalizer)
|
123 |
logger.info(f"Preprocessed dataset kept {len(dataset)} out of {len(raw_dataset)}")
|
124 |
|
|
|
3 |
import os
|
4 |
import sys
|
5 |
from dataclasses import dataclass, field
|
6 |
+
from typing import Dict, List, Optional, Tuple, Union, Any
|
7 |
|
8 |
from datasets import load_dataset
|
9 |
from tokenizers import ByteLevelBPETokenizer
|
|
|
15 |
filter_by_lang_regex,
|
16 |
filter_by_num_tokens,
|
17 |
filter_by_num_sents,
|
18 |
+
filter_by_adv,
|
19 |
normalizer
|
20 |
)
|
21 |
|
|
|
43 |
default=None,
|
44 |
metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
|
45 |
)
|
46 |
+
special_tokens: Optional[str] = field(
|
47 |
default=None,
|
48 |
metadata={"help": "The list of special tokens that you want to add in your training."}
|
49 |
)
|
50 |
vocab_size: Optional[int] = field(
|
51 |
+
default=56000,
|
52 |
metadata={"help": "The size of the final vocabulary, including all tokens and alphabet"}
|
53 |
)
|
54 |
min_frequency: Optional[int] = field(
|
|
|
62 |
|
63 |
def __post_init__(self):
|
64 |
if self.special_tokens is None:
|
65 |
+
special_tokens = [
|
66 |
"<s>", "<pad>", "</s>", "<unk>", "<mask>",
|
67 |
"<|endoftext|>", "<|startoftext|>",
|
68 |
"<sep>", "<cls>", "<nl>", "<tab>", "<zwnj>"
|
69 |
]
|
70 |
+
special_tokens += [f"[U{i}]" for i in range(1, 21)]
|
71 |
+
else:
|
72 |
+
special_tokens = list(self.special_tokens.split(","))
|
73 |
|
74 |
+
self.special_tokens = special_tokens
|
75 |
if self.dataset_name is None and self.train_file is None:
|
76 |
raise ValueError("Need either a dataset name or a training file.")
|
77 |
else:
|
|
|
104 |
tokenizer_args.dataset_name,
|
105 |
tokenizer_args.dataset_config_name,
|
106 |
cache_dir=tokenizer_args.cache_dir,
|
107 |
+
split="train"
|
108 |
)
|
109 |
else:
|
110 |
data_files = {"train": tokenizer_args.train_file}
|
|
|
123 |
dataset = raw_dataset.filter(lambda example: filter_by_lang_regex(example["text"], ratio=0.75))
|
124 |
dataset = dataset.filter(lambda example: filter_by_num_tokens(example["text"], gt=64))
|
125 |
dataset = dataset.filter(lambda example: filter_by_num_sents(example["text"], gt=2))
|
126 |
+
dataset = dataset.filter(lambda example: filter_by_adv(example["text"], ratio=50))
|
127 |
dataset = dataset.map(normalizer)
|
128 |
logger.info(f"Preprocessed dataset kept {len(dataset)} out of {len(raw_dataset)}")
|
129 |
|