m3hrdadfi commited on
Commit
69dd1b0
1 Parent(s): c6f6fcb

Initialize

Browse files
config.json ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "activation_function": "gelu_new",
3
+ "architectures": [
4
+ "GPTNeoForCausalLM"
5
+ ],
6
+ "attention_dropout": 0,
7
+ "attention_layers": [
8
+ "global",
9
+ "local",
10
+ "global",
11
+ "local",
12
+ "global",
13
+ "local",
14
+ "global",
15
+ "local",
16
+ "global",
17
+ "local",
18
+ "global",
19
+ "local",
20
+ "global",
21
+ "local",
22
+ "global",
23
+ "local",
24
+ "global",
25
+ "local",
26
+ "global",
27
+ "local",
28
+ "global",
29
+ "local",
30
+ "global",
31
+ "local"
32
+ ],
33
+ "attention_types": [
34
+ [
35
+ [
36
+ "global",
37
+ "local"
38
+ ],
39
+ 12
40
+ ]
41
+ ],
42
+ "bos_token_id": 5,
43
+ "embed_dropout": 0,
44
+ "eos_token_id": 5,
45
+ "gradient_checkpointing": false,
46
+ "hidden_size": 2048,
47
+ "initializer_range": 0.02,
48
+ "intermediate_size": null,
49
+ "layer_norm_epsilon": 1e-05,
50
+ "max_position_embeddings": 2048,
51
+ "model_type": "gpt_neo",
52
+ "num_heads": 16,
53
+ "num_layers": 24,
54
+ "pad_token_id": 5,
55
+ "resid_dropout": 0,
56
+ "summary_activation": null,
57
+ "summary_first_dropout": 0.1,
58
+ "summary_proj_to_labels": true,
59
+ "summary_type": "cls_index",
60
+ "summary_use_proj": true,
61
+ "task_specific_params": {
62
+ "text-generation": {
63
+ "do_sample": true,
64
+ "max_length": 50,
65
+ "temperature": 0.9
66
+ }
67
+ },
68
+ "tokenizer_class": "GPT2Tokenizer",
69
+ "transformers_version": "4.9.0.dev0",
70
+ "use_cache": true,
71
+ "vocab_size": 50000,
72
+ "window_size": 256
73
+ }
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
notes/.keep ADDED
File without changes
src/convert_flax_to_pytorch.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+
4
+ import jax
5
+ import jax.numpy as jnp
6
+
7
+ from transformers import AutoTokenizer
8
+ from transformers import FlaxGPT2LMHeadModel
9
+ from transformers import GPT2LMHeadModel
10
+
11
+ tokenizer = AutoTokenizer.from_pretrained("../")
12
+ tokenizer.pad_token = tokenizer.eos_token
13
+
14
+ model_fx = FlaxGPT2LMHeadModel.from_pretrained("../")
15
+
16
+ # def to_f32(t):
17
+ # return jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t)
18
+
19
+ # model_fx.params = to_f32(model_fx.params)
20
+ # model_fx.save_pretrained("./fx")
21
+
22
+ model_pt = GPT2LMHeadModel.from_pretrained("../", from_flax=True)
23
+ model_pt.save_pretrained("./pt")
24
+
25
+ input_ids = np.asarray(2 * [128 * [0]], dtype=np.int32)
26
+ input_ids_pt = torch.tensor(input_ids)
27
+
28
+ logits_pt = model_pt(input_ids_pt).logits
29
+ print(logits_pt)
30
+ logits_fx = model_fx(input_ids).logits
31
+ print(logits_fx)
src/convert_flax_to_tf.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+
4
+ import jax
5
+ import jax.numpy as jnp
6
+
7
+ from transformers import AutoTokenizer
8
+ from transformers import GPT2LMHeadModel
9
+ from transformers import TFGPT2LMHeadModel
10
+
11
+ tokenizer = AutoTokenizer.from_pretrained("../")
12
+ tokenizer.pad_token = tokenizer.eos_token
13
+
14
+ model_pt = GPT2LMHeadModel.from_pretrained("./pt")
15
+ model_tf = TFGPT2LMHeadModel.from_pretrained("./pt", from_pt=True)
16
+ model_tf.save_pretrained("./tf")
17
+
18
+ input_ids = np.asarray(2 * [128 * [0]], dtype=np.int32)
19
+ input_ids_pt = torch.tensor(input_ids)
20
+
21
+ logits_pt = model_pt(input_ids_pt).logits
22
+ print(logits_pt)
23
+ logits_tf = model_tf(input_ids).logits
24
+ print(logits_tf)
src/create_config.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ast
2
+ import logging
3
+ import os
4
+ import sys
5
+ from dataclasses import dataclass, field
6
+ from typing import Dict, List, Optional, Tuple
7
+
8
+ from transformers import (
9
+ HfArgumentParser,
10
+ AutoConfig
11
+ )
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
+ @dataclass
17
+ class ConfigArguments:
18
+ """
19
+ Arguments to which config we are going to set up.
20
+ """
21
+ output_dir: str = field(
22
+ default=".",
23
+ metadata={"help": "The output directory where the config will be written."},
24
+ )
25
+ name_or_path: Optional[str] = field(
26
+ default=None,
27
+ metadata={
28
+ "help": "The model checkpoint for weights initialization."
29
+ "Don't set if you want to train a model from scratch."
30
+ },
31
+ )
32
+ params: Optional[str] = field(
33
+ default=None,
34
+ metadata={"help": "Custom configuration for the specific `name_or_path`"}
35
+ )
36
+
37
+ def __post_init__(self):
38
+ if self.params:
39
+ try:
40
+ self.params = ast.literal_eval(self.params)
41
+ except Exception as e:
42
+ print(f"Your custom parameters do not acceptable due to {e}")
43
+
44
+
45
+ def main():
46
+ parser = HfArgumentParser([ConfigArguments])
47
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
48
+ # If we pass only one argument to the script and it's the path to a json file,
49
+ # let's parse it to get our arguments.
50
+ config_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))[0]
51
+ else:
52
+ config_args = parser.parse_args_into_dataclasses()[0]
53
+
54
+ # Setup logging
55
+ logging.basicConfig(
56
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
57
+ datefmt="%m/%d/%Y %H:%M:%S",
58
+ handlers=[logging.StreamHandler(sys.stdout)],
59
+ )
60
+ logger.setLevel(logging.INFO)
61
+
62
+ logger.info(f"Setting up configuration {config_args.name_or_path} with extra params {config_args.params}")
63
+
64
+ if config_args.params and isinstance(config_args.params, dict):
65
+ config = AutoConfig.from_pretrained(config_args.name_or_path, **config_args.params)
66
+ else:
67
+ config = AutoConfig.from_pretrained(config_args.name_or_path)
68
+
69
+ logger.info(f"Your configuration saved here {config_args.output_dir}/config.json")
70
+ config.save_pretrained(config_args.output_dir)
71
+
72
+
73
+ if __name__ == '__main__':
74
+ main()
src/create_dataset.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ast
2
+ import logging
3
+ import os
4
+ import sys
5
+ from dataclasses import dataclass, field
6
+ import pandas as pd
7
+ from sklearn.model_selection import train_test_split
8
+ from tqdm import tqdm
9
+ from typing import Dict, List, Optional, Tuple
10
+ from datasets import load_dataset
11
+ from transformers import (
12
+ HfArgumentParser,
13
+ )
14
+ from data_utils import (
15
+ filter_by_lang_regex,
16
+ filter_by_num_tokens,
17
+ filter_by_num_sents,
18
+ filter_by_adv,
19
+ normalizer
20
+ )
21
+
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
+ @dataclass
27
+ class DataArguments:
28
+ """
29
+ Arguments to which dataset we are going to set up.
30
+ """
31
+ output_dir: str = field(
32
+ default=".",
33
+ metadata={"help": "The output directory where the config will be written."},
34
+ )
35
+ dataset_name: str = field(
36
+ default=None,
37
+ metadata={"help": "The name of the dataset to use (via the datasets library)."}
38
+ )
39
+ dataset_config_name: Optional[str] = field(
40
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
41
+ )
42
+ train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
43
+ cache_dir: Optional[str] = field(
44
+ default=None,
45
+ metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
46
+ )
47
+ def main():
48
+ parser = HfArgumentParser([DataArguments])
49
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
50
+ # If we pass only one argument to the script and it's the path to a json file,
51
+ # let's parse it to get our arguments.
52
+ data_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))[0]
53
+ else:
54
+ data_args = parser.parse_args_into_dataclasses()[0]
55
+ # Setup logging
56
+ logging.basicConfig(
57
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
58
+ datefmt="%m/%d/%Y %H:%M:%S",
59
+ handlers=[logging.StreamHandler(sys.stdout)],
60
+ )
61
+ logger.setLevel(logging.INFO)
62
+ logger.info(f"Preparing the dataset")
63
+ if data_args.dataset_name is not None:
64
+ dataset = load_dataset(
65
+ data_args.dataset_name,
66
+ data_args.dataset_config_name,
67
+ cache_dir=data_args.cache_dir,
68
+ split="train"
69
+ )
70
+ else:
71
+ data_files = {"train": data_args.train_file}
72
+ extension = data_args.train_file.split(".")[-1]
73
+ if extension == "txt":
74
+ extension = "text"
75
+
76
+ dataset = load_dataset(
77
+ extension,
78
+ data_files=data_files,
79
+ delimiter="\t",
80
+ cache_dir=data_args.cache_dir,
81
+ )
82
+
83
+ logger.info(f"dataset: {dataset}")
84
+
85
+ def data_preparation(item_dict):
86
+ if "text" not in item_dict:
87
+ return None
88
+
89
+ text = item_dict["text"]
90
+
91
+ status = filter_by_lang_regex(text, ratio=0.75)
92
+ if not status:
93
+ return None
94
+
95
+ status = filter_by_num_tokens(text, gt=64)
96
+ if not status:
97
+ return None
98
+
99
+ status = filter_by_num_sents(text, gt=2)
100
+ if not status:
101
+ return None
102
+
103
+ status = filter_by_adv(text, ratio=50)
104
+ if not status:
105
+ return None
106
+
107
+ text = normalizer(text)
108
+ return {"text": text}
109
+
110
+ data_dict = []
111
+ for item in tqdm(dataset, position=0, total=len(dataset)):
112
+ item = data_preparation(item)
113
+
114
+ if item:
115
+ data_dict.append(item)
116
+
117
+ data_df = pd.DataFrame(data_dict)
118
+
119
+ logger.info(f"Preparation - [before] consists of {len(dataset)} records!")
120
+ logger.info(f"Preparation - [after] consists of {len(data_df)} records!")
121
+
122
+ train, test = train_test_split(data_df, test_size=0.01, random_state=101)
123
+
124
+ train = train.reset_index(drop=True)
125
+ test = test.reset_index(drop=True)
126
+
127
+ logger.info(f"Preparation of [train] set consists of {len(train)} records!")
128
+ logger.info(f"Preparation of [test] set consists of {len(test)} records!")
129
+
130
+ os.makedirs(data_args.output_dir, exist_ok=True)
131
+ train.to_csv(os.path.join(data_args.output_dir, "train.csv"), sep="\t", encoding="utf-8", index=False)
132
+ test.to_csv(os.path.join(data_args.output_dir, "test.csv"), sep="\t", encoding="utf-8", index=False)
133
+ logger.info(f"Data saved here {data_args.output_dir}")
134
+
135
+ if __name__ == '__main__':
136
+ main()
src/data_utils.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from hazm import word_tokenize
2
+ from hazm import sent_tokenize
3
+ import re
4
+ import six
5
+
6
+ from normalizer import normalize
7
+
8
+ persian_regex = "0-9۰۱۲۳۴۵۶۷۸۹ءآئابتثجحخدذرزسشصضطظعغفقلمنهوپچژکگیە\u200c"
9
+
10
+
11
+ def filter_by_lang_regex(text, ratio=0.7, regex="0-9۰۱۲۳۴۵۶۷۸۹ءآئابتثجحخدذرزسشصضطظعغفقلمنهوپچژکگیە\u200c"):
12
+ candidate_text = re.sub(r"[^" + regex + "]+", " ", six.ensure_str(text)).replace(" ", "")
13
+ text = text.replace(" ", "")
14
+
15
+ return (len(candidate_text) / len(text)) > ratio
16
+
17
+
18
+ def filter_by_num_tokens(text, gt=64):
19
+ return len(word_tokenize(text)) > gt
20
+
21
+
22
+ def filter_by_num_sents(text, gt=2):
23
+ return len(sent_tokenize(text)) > gt
24
+
25
+
26
+ def filter_by_adv(text, ratio=50):
27
+ comma = text.split(",")
28
+ colon = re.findall(r"""(?:([^\W]+):([^\W]+))""", text)
29
+ virgool = text.split("،")
30
+ length_add = len(comma) + len(colon) + len(virgool)
31
+
32
+ return length_add < ratio
33
+
34
+
35
+ def normalizer(text, do_lowercase=False):
36
+ text = normalize(text)
37
+
38
+ if do_lowercase:
39
+ text = text.lower()
40
+
41
+ return text
42
+
src/dictionary.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ characters = {
2
+ "ك": "ک",
3
+ "دِ": "د",
4
+ "بِ": "ب",
5
+ "زِ": "ز",
6
+ "ذِ": "ذ",
7
+ "شِ": "ش",
8
+ "سِ": "س",
9
+ "ى": "ی",
10
+ "ي": "ی",
11
+ "ؤ": "و",
12
+ "ے": "ی",
13
+ "ۀ": "ه",
14
+ "ﭘ": "پ",
15
+ "ﮐ": "ک",
16
+ "ﯽ": "ی",
17
+ "ﺎ": "ا",
18
+ "ﺑ": "ب",
19
+ "ﺘ": "ت",
20
+ "ﺧ": "خ",
21
+ "ﺩ": "د",
22
+ "ﺱ": "س",
23
+ "ﻀ": "ض",
24
+ "ﻌ": "ع",
25
+ "ﻟ": "ل",
26
+ "ﻡ": "م",
27
+ "ﻢ": "م",
28
+ "ﻪ": "ه",
29
+ "ﻮ": "و",
30
+ "ﺍ": "ا",
31
+ "ة": "ه",
32
+ "ﯾ": "ی",
33
+ "ﯿ": "ی",
34
+ "ﺒ": "ب",
35
+ "ﺖ": "ت",
36
+ "ﺪ": "د",
37
+ "ﺮ": "ر",
38
+ "ﺴ": "س",
39
+ "ﺷ": "ش",
40
+ "ﺸ": "ش",
41
+ "ﻋ": "ع",
42
+ "ﻤ": "م",
43
+ "ﻥ": "ن",
44
+ "ﻧ": "ن",
45
+ "ﻭ": "و",
46
+ "ﺭ": "ر",
47
+ "ﮔ": "گ",
48
+ "إ": "ا",
49
+ "ٕ": " ",
50
+ "ھ": "ه",
51
+ "...": ".",
52
+ "…": ".",
53
+ "-": " - ",
54
+ "هٔ": "ه",
55
+ "ﻯ": "ی",
56
+ "ﻛ": "ک",
57
+ "ﭼ": "چ",
58
+ "ﺓ": "ه",
59
+ "ﻴ": "ی",
60
+ "ﻊ": "ع",
61
+ "ﮬ": "ه",
62
+ "ﺟ": "ج",
63
+ "ﺳ": "س",
64
+ "ﻦ": "ن",
65
+ "ﺬ": "ذ",
66
+ "ﺋ": "ئ",
67
+ "ﷲ": "لله",
68
+ "ﺞ": "ج",
69
+ "ﺙ": "ث",
70
+ "ﻗ": "ق",
71
+ "ﮪ": "ه",
72
+ "ﺰ": "ز",
73
+ "ﯼ": "ی",
74
+ "ٺ": "ت",
75
+ "ﺻ": "ص",
76
+ "ﻂ": "ط",
77
+ "ﻣ": "م",
78
+ "ﻈ": "ظ",
79
+ "ﺐ": "ب",
80
+ "ﻍ": "غ",
81
+ "ݸ": "و",
82
+ "ﻨ": "ن",
83
+ "ﻝ": "ل",
84
+ "ﻩ": "ه",
85
+ "ﻲ": "ی",
86
+ "ﻐ": "غ",
87
+ "ﺲ": "س",
88
+ "ﺁ": "آ",
89
+ "ڔ": "ر",
90
+ "ﺫ": "ذ",
91
+ "ﭻ": "چ",
92
+ "ﺠ": "ج",
93
+ "ﯙ": "و",
94
+ "ﮏ": "ک",
95
+ "ﺣ": "ح",
96
+ "ﺝ": "ج",
97
+ "ﺼ": "ص",
98
+ "ﻳ": "ی",
99
+ "ﻘ": "ق",
100
+ "ﺨ": "خ",
101
+ "ﻔ": "ف",
102
+ "ﻎ": "غ",
103
+ "ئ": "ی",
104
+ "ﻓ": "ف",
105
+ "ﻕ": "ق",
106
+ "ﮋ": "ژ",
107
+ "ﺗ": "ت",
108
+ "ﻁ": "ط",
109
+ "ﺯ": "ز",
110
+ "ﮕ": "گ",
111
+ "ﺌ": "ئ",
112
+ "ﺵ": "ش",
113
+ "ۮ": "د",
114
+ "ﻫ": "ه",
115
+ "ﻬ": "ه",
116
+ "ﻏ": "غ",
117
+ "ﻰ": "ی",
118
+ "﷼": "ریال",
119
+ "ﺿ": "ض",
120
+ "ﺛ": "ث",
121
+ "ݐ": "پ",
122
+ "ﺏ": "ب",
123
+ "ﭙ": "پ",
124
+ "ﭽ": "چ",
125
+ "ﺜ": "ث",
126
+ "ﻃ": "ط",
127
+ "ۂ": "ه",
128
+ "ﻑ": "ف",
129
+ "ﺕ": "ت",
130
+ "ﻞ": "ل",
131
+ }
132
+
133
+ special_tokens = {}
134
+
135
+ words_map = {
136
+ "Leave a comment": "",
137
+ "[…]": "",
138
+ "[.]": "",
139
+ }
src/normalizer.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hazm
2
+ import re
3
+ import string
4
+
5
+ from regexes.currency import CURRENCY_REGEX
6
+ from regexes.email import EMAIL_REGEX
7
+ from regexes.latin import LATIN_REGEX
8
+ from regexes.latin import LATIN_REGEX, LATIN_WITH_SPECIAL_REGEX
9
+ from regexes.number import NUMBERS_REGEX
10
+ from regexes.phone import PHONE_REGEX
11
+ from regexes.quote import DOUBLE_QUOTE_REGEX, SINGLE_QUOTE_REGEX
12
+ from regexes.url import URL_REGEX
13
+ from regexes.persian import PERSIAN_REGEX
14
+ from regexes.punk import PUNK_REGEX
15
+ import dictionary
16
+
17
+ allowed_char = string.ascii_letters + string.digits + ':/@_-. '
18
+
19
+
20
+ def make_trans(list_a, list_b):
21
+ return dict((ord(a), b) for a, b in zip(list_a, list_b))
22
+
23
+
24
+ def multiple_replace(text, chars_to_mapping):
25
+ pattern = "|".join(map(re.escape, chars_to_mapping.keys()))
26
+ return re.sub(pattern, lambda m: chars_to_mapping[m.group()], str(text))
27
+
28
+
29
+ def remove_adv_by_tag_name(text, tag_name):
30
+ found = text.find(tag_name)
31
+
32
+ if found > 0:
33
+ text = text[:found]
34
+
35
+ return text
36
+
37
+
38
+ def clean_url(text):
39
+ # removing html tags
40
+ text = re.sub('<.*?>', '', text)
41
+
42
+ # removing normal(without space urls)
43
+ text = re.sub(r'(?:(?:http|https):\/\/)?([-a-zA-Z0-9.]{2,256}\.[a-z]{2,4})\b(?:\/[-a-zA-Z0-9@:%_\+.~#?&//=]*)?', "",
44
+ text)
45
+
46
+ # removing urls that contains space
47
+ result = ''
48
+ for char in text:
49
+ if char in allowed_char:
50
+ result += char
51
+ result = result.replace(' ', '')
52
+ result = result.split(':')
53
+ for phrase in result:
54
+ p = phrase
55
+ if '//' in p:
56
+ if ('https :' + p) in text:
57
+ text = text.replace('https :' + p, '')
58
+ elif ('http :' + p) in text:
59
+ text = text.replace('http :' + p, '')
60
+ elif '@' in p:
61
+ if p in text:
62
+ text = text.replace(p, '')
63
+
64
+ return text
65
+
66
+
67
+ ar2fa_digits = make_trans("٠١٢٣٤٥٦٧٨٩٪", "۰۱۲۳۴۵۶۷۸۹٪")
68
+ fa2en_digits = make_trans("۰۱۲۳۴۵۶۷۸۹٪", "0123456789%")
69
+ normalizer = hazm.Normalizer(persian_numbers=True, punctuation_spacing=False)
70
+
71
+
72
+ def normalize(text, zwnj="\u200c", tokenized=False):
73
+ text = text.replace("\n", " ").replace("\t", " ")
74
+ text = re.sub(r"\u200c+", "\u200c", text)
75
+ text = text.replace('ـ', '')
76
+ text = normalizer.normalize(text)
77
+
78
+ if len(dictionary.characters) > 0:
79
+ text = multiple_replace(text, dictionary.characters)
80
+
81
+ if len(dictionary.words_map) > 0:
82
+ text = multiple_replace(text, dictionary.words_map)
83
+
84
+ text = text.translate(ar2fa_digits)
85
+ text = text.translate(fa2en_digits)
86
+
87
+ text = SINGLE_QUOTE_REGEX.sub("'", text)
88
+ text = DOUBLE_QUOTE_REGEX.sub('"', text)
89
+ text = CURRENCY_REGEX.sub(r" \1 ", text)
90
+ text = clean_url(text)
91
+ text = remove_adv_by_tag_name(text, tag_name="برچسب ها :")
92
+ text = URL_REGEX.sub(" ", text)
93
+ text = EMAIL_REGEX.sub(" ", text)
94
+ text = PHONE_REGEX.sub(r" \1 ", text)
95
+ text = NUMBERS_REGEX.sub(r" \1 ", text)
96
+ text = LATIN_REGEX.sub(r" \1 ", text)
97
+ # text = PUNK_REGEX.sub(r" \1 ", text) # must be remained the same!
98
+
99
+ # Allow only english and persian characters
100
+ text = re.sub(PERSIAN_REGEX, " ", text)
101
+
102
+ text = text.replace(f" {zwnj} ", f"{zwnj}")
103
+ text = text.replace(f"{zwnj} ", f"{zwnj}")
104
+ text = text.replace(f" {zwnj}", f"{zwnj}")
105
+
106
+ if len(dictionary.special_tokens) > 0:
107
+ text = multiple_replace(text, dictionary.special_tokens)
108
+
109
+ tokens = []
110
+ for token in text.split():
111
+ token = token.strip()
112
+ if token:
113
+ if token.startswith(zwnj) and token.endswith(zwnj):
114
+ token = token[1:-1]
115
+ if token.startswith(zwnj):
116
+ token = token[1:]
117
+ elif token.endswith(zwnj):
118
+ token = token[:-1]
119
+ else:
120
+ token = token
121
+
122
+ tokens.append(token)
123
+
124
+ if tokenized:
125
+ return tokens
126
+
127
+ return " ".join(tokens)
128
+
129
+
130
+
131
+ if __name__ == '__main__':
132
+ import textwrap
133
+
134
+ # input_text = " «هفتاد سی» "
135
+ # input_text = normalize(input_text)
136
+ # input_text = DOUBLE_QUOTE_REGEX.sub('"', input_text)
137
+ # print(textwrap.fill(input_text))
138
+ # print(normalize(input_text, tokenized=True))
src/regexes/__init__.py ADDED
File without changes
src/regexes/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (155 Bytes). View file
 
src/regexes/__pycache__/currency.cpython-38.pyc ADDED
Binary file (678 Bytes). View file
 
src/regexes/__pycache__/email.cpython-38.pyc ADDED
Binary file (469 Bytes). View file
 
src/regexes/__pycache__/latin.cpython-38.pyc ADDED
Binary file (369 Bytes). View file
 
src/regexes/__pycache__/number.cpython-38.pyc ADDED
Binary file (335 Bytes). View file
 
src/regexes/__pycache__/persian.cpython-38.pyc ADDED
Binary file (536 Bytes). View file
 
src/regexes/__pycache__/phone.cpython-38.pyc ADDED
Binary file (365 Bytes). View file
 
src/regexes/__pycache__/punk.cpython-38.pyc ADDED
Binary file (296 Bytes). View file
 
src/regexes/__pycache__/quote.cpython-38.pyc ADDED
Binary file (576 Bytes). View file
 
src/regexes/__pycache__/url.cpython-38.pyc ADDED
Binary file (764 Bytes). View file
 
src/regexes/currency.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+ CURRENCIES = {
4
+ "$": "USD",
5
+ "zł": "PLN",
6
+ "£": "GBP",
7
+ "¥": "JPY",
8
+ "฿": "THB",
9
+ "₡": "CRC",
10
+ "₦": "NGN",
11
+ "₩": "KRW",
12
+ "₪": "ILS",
13
+ "₫": "VND",
14
+ "€": "EUR",
15
+ "₱": "PHP",
16
+ "₲": "PYG",
17
+ "₴": "UAH",
18
+ "₹": "INR",
19
+ "﷼": "IRR",
20
+ }
21
+ CURRENCY_REGEX = re.compile(
22
+ "({})+".format("|".join(re.escape(c) for c in CURRENCIES.keys()))
23
+ )
src/regexes/email.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ import re
2
+
3
+ EMAIL_REGEX = re.compile(
4
+ r"(?:^|(?<=[^\w@.)]))([\w+-](\.(?!\.))?)*?[\w+-](@|[(<{\[]at[)>}\]])(?:(?:[a-z\\u00a1-\\uffff0-9]-?)*[a-z\\u00a1-\\uffff0-9]+)(?:\.(?:[a-z\\u00a1-\\uffff0-9]-?)*[a-z\\u00a1-\\uffff0-9]+)*(?:\.(?:[a-z\\u00a1-\\uffff]{2,}))",
5
+ flags=re.IGNORECASE | re.UNICODE,
6
+ )
src/regexes/latin.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+ LATIN_WITH_SPECIAL_REGEX = re.compile(
4
+ r"(\b(?!URL|EMAIL|PHONE|NUMBER|CUR|LATIN\b)[0-9a-zA-Z]+)"
5
+ )
6
+
7
+ LATIN_REGEX = re.compile(
8
+ r"([0-9a-zA-Z]+)"
9
+ )
10
+
11
+ LATIN_SPACES_REGEX = re.compile(
12
+ r"([0-9a-zA-Z])"
13
+ )
src/regexes/number.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ import re
2
+
3
+ NUMBERS_REGEX = re.compile(
4
+ r"(?:^|(?<=[^\w,.]))[+–-]?(([1-9]\d{0,2}(,\d{3})+(\.\d*)?)|([1-9]\d{0,2}([ .]\d{3})+(,\d*)?)|(\d*?[.,]\d+)|\d+)(?:$|(?=\b))"
5
+ )
src/regexes/persian.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+
4
+ PERSIAN_ALPHA = "ءآئابتثجحخدذرزسشصضطظعغفقلمنهوپچژکگیە" # noqa: E501
5
+ PERSIAN_DIGIT = "۰۱۲۳۴۵۶۷۸۹"
6
+
7
+
8
+ ZWNJ = "\u200c"
9
+ PUNK = '\!\"\#\$\%\&\'\(\)\*\+\,\-\.\/\:\;\<\=\>\?\@\[\]\^\_\`\{\|\}\~\«\»\؟\:\×\٬\٫\﷼\٪\،'
10
+
11
+ PERSIAN = (
12
+ "a-zA-Z0-9" +
13
+ PERSIAN_ALPHA +
14
+ PERSIAN_DIGIT +
15
+ ZWNJ +
16
+ PUNK
17
+ )
18
+
19
+ PERSIAN_REGEX = r"[^" + PERSIAN + "+]"
src/regexes/phone.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ import re
2
+
3
+
4
+ PHONE_REGEX = re.compile(
5
+ r"((?:^|(?<=[^\w)]))(((\+?[01])|(\+\d{2}))[ .-]?)?(\(?\d{3,4}\)?/?[ .-]?)?(\d{3}[ .-]?\d{4})(\s?(?:ext\.?|[#x-])\s?\d{2,6})?(?:$|(?=\W)))|\+?\d{4,5}[ .-/]\d{6,9}"
6
+ )
src/regexes/punk.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ import re
2
+
3
+ PUNK_REGEX = re.compile(
4
+ r"([\!\"\#\$\%\&\'\(\)\*\+\,\-\.\/\:\;\=\?\@\[\\\]\^\_\`\{\|\}\~\«\»\⸮\؟\،\٬\٫\؛])"
5
+ )
src/regexes/quote.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+
4
+ strange_double_quotes = [
5
+ "«",
6
+ "‹",
7
+ "»",
8
+ "›",
9
+ "„",
10
+ "“",
11
+ "‟",
12
+ "”",
13
+ "❝",
14
+ "❞",
15
+ "❮",
16
+ "❯",
17
+ "〝",
18
+ "〞",
19
+ "〟",
20
+ """,
21
+ ]
22
+ strange_single_quotes = ["‘", "‛", "’", "❛", "❜", "`", "´", "‘", "’"]
23
+
24
+ DOUBLE_QUOTE_REGEX = re.compile("|".join(strange_double_quotes))
25
+ SINGLE_QUOTE_REGEX = re.compile("|".join(strange_single_quotes))
src/regexes/url.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+ URL_REGEX = re.compile(
4
+ r"(?:^|(?<![\w\/\.]))"
5
+ # protocol identifier
6
+ # r"(?:(?:https?|ftp)://)" <-- alt?
7
+ r"(?:(?:https?:\/\/|ftp:\/\/|www\d{0,3}\.))"
8
+ # user:pass authentication
9
+ r"(?:\S+(?::\S*)?@)?" r"(?:"
10
+ # IP address exclusion
11
+ # private & local networks
12
+ r"(?!(?:10|127)(?:\.\d{1,3}){3})"
13
+ r"(?!(?:169\.254|192\.168)(?:\.\d{1,3}){2})"
14
+ r"(?!172\.(?:1[6-9]|2\d|3[0-1])(?:\.\d{1,3}){2})"
15
+ # IP address dotted notation octets
16
+ # excludes loopback network 0.0.0.0
17
+ # excludes reserved space >= 224.0.0.0
18
+ # excludes network & broadcast addresses
19
+ # (first & last IP address of each class)
20
+ r"(?:[1-9]\d?|1\d\d|2[01]\d|22[0-3])"
21
+ r"(?:\.(?:1?\d{1,2}|2[0-4]\d|25[0-5])){2}"
22
+ r"(?:\.(?:[1-9]\d?|1\d\d|2[0-4]\d|25[0-4]))"
23
+ r"|"
24
+ # host name
25
+ r"(?:(?:[a-z\\u00a1-\\uffff0-9]-?)*[a-z\\u00a1-\\uffff0-9]+)"
26
+ # domain name
27
+ r"(?:\.(?:[a-z\\u00a1-\\uffff0-9]-?)*[a-z\\u00a1-\\uffff0-9]+)*"
28
+ # TLD identifier
29
+ r"(?:\.(?:[a-z\\u00a1-\\uffff]{2,}))" r"|" r"(?:(localhost))" r")"
30
+ # port number
31
+ r"(?::\d{2,5})?"
32
+ # resource path
33
+ r"(?:\/[^\)\]\}\s]*)?",
34
+ # r"(?:$|(?![\w?!+&\/\)]))",
35
+ # @jfilter: I removed the line above from the regex because I don't understand what it is used for, maybe it was useful?
36
+ # But I made sure that it does not include ), ] and } in the URL.
37
+ flags=re.UNICODE | re.IGNORECASE,
38
+ )
src/requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ datasets >= 1.1.3
2
+ jax>=0.2.8
3
+ jaxlib>=0.1.59
4
+ flax>=0.3.4
5
+ optax>=0.0.8
6
+ hazm
7
+ tensorboard
src/run.sh ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ export LC_ALL=C.UTF-8
4
+ export LANG=C.UTF-8
5
+
6
+ # export MODEL_NAME_OR_PATH=/home/m3hrdadfi/code/gpt-neo-1.3B-persian
7
+ export OUTPUT_DIR=/home/m3hrdadfi/code/gpt-neo-1.3B-persian
8
+ export MODEL_TYPE=gpt_neo
9
+ export CONFIG_NAME=/home/m3hrdadfi/code/gpt-neo-1.3B-persian
10
+ export TOKENIZER_NAME=/home/m3hrdadfi/code/gpt-neo-1.3B-persian
11
+
12
+ export TRAIN_FILE=/home/m3hrdadfi/data/train-fixed.csv
13
+ export VALIDATION_FILE=/home/m3hrdadfi/data/test-fixed.csv
14
+ export TEST_FILE=/home/m3hrdadfi/code/data/test-fixed.csv
15
+ # export DATASET_NAME=oscar
16
+ # export DATASET_CONFIG_NAME=unshuffled_deduplicated_fa
17
+ export MAX_SEQUENCE_LENGTH=512
18
+
19
+ #export MAX_TRAIN_SAMPLE=5000
20
+ #export MAX_EVAL_SAMPLES=5000
21
+
22
+ export PER_DEVICE_TRAIN_BATCH_SIZE=16
23
+ export PER_DEVICE_EVAL_BATCH_SIZE=16
24
+ export NUM_TRAIN_EPOCHS=5.0
25
+ export LEARNING_RATE=1e-3
26
+ export WARMUP_STEPS=5000
27
+ export LOGGING_STEPS=500
28
+ export EVAL_STEPS=2500
29
+ export SAVE_STEPS=2500
30
+
31
+ # python src/run_clm_flax.py \
32
+ # --output_dir="$OUTPUT_DIR" \
33
+ # --model_name_or_path="$MODEL_NAME_OR_PATH" \
34
+ # --train_file="$TRAIN_FILE" \
35
+ # --validation_file="$VALIDATION_FILE" \
36
+ # --block_size=$MAX_SEQUENCE_LENGTH \
37
+ # --per_device_train_batch_size=$PER_DEVICE_TRAIN_BATCH_SIZE \
38
+ # --per_device_eval_batch_size=$PER_DEVICE_EVAL_BATCH_SIZE \
39
+ # --num_train_epochs=$NUM_TRAIN_EPOCHS \
40
+ # --learning_rate=$LEARNING_RATE \
41
+ # --warmup_steps=$WARMUP_STEPS \
42
+ # --logging_step=$LOGGING_STEPS \
43
+ # --eval_steps=$EVAL_STEPS \
44
+ # --save_steps=$SAVE_STEPS \
45
+ # --do_train \
46
+ # --do_eval \
47
+ # --overwrite_output_dir \
48
+ # --push_to_hub
49
+
50
+ python src/run_clm_flax.py \
51
+ --output_dir="$OUTPUT_DIR" \
52
+ --model_type="$MODEL_TYPE" \
53
+ --config_name="$CONFIG_NAME" \
54
+ --tokenizer_name="$TOKENIZER_NAME" \
55
+ --dataset_name="$DATASET_NAME" \
56
+ --dataset_config_name="$DATASET_CONFIG_NAME" \
57
+ --block_size=$MAX_SEQUENCE_LENGTH \
58
+ --per_device_train_batch_size=$PER_DEVICE_TRAIN_BATCH_SIZE \
59
+ --per_device_eval_batch_size=$PER_DEVICE_EVAL_BATCH_SIZE \
60
+ --num_train_epochs=$NUM_TRAIN_EPOCHS \
61
+ --learning_rate=$LEARNING_RATE \
62
+ --warmup_steps=$WARMUP_STEPS \
63
+ --logging_step=$LOGGING_STEPS \
64
+ --eval_steps=$EVAL_STEPS \
65
+ --save_steps=$SAVE_STEPS \
66
+ --do_train \
67
+ --do_eval \
68
+ --overwrite_output_dir \
69
+ --push_to_hub
src/run_clm_flax.py ADDED
@@ -0,0 +1,709 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2021 The HuggingFace Team All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Pre-training/Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) on a text file or a dataset.
18
+
19
+ Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
20
+ https://huggingface.co/models?filter=causal-lm
21
+ """
22
+ # You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments.
23
+
24
+ import logging
25
+ import math
26
+ import os
27
+ import sys
28
+ import time
29
+ from dataclasses import dataclass, field
30
+ from pathlib import Path
31
+ from typing import Callable, Optional
32
+
33
+ import datasets
34
+ from datasets import Dataset, load_dataset
35
+ from tqdm import tqdm
36
+
37
+ import jax
38
+ from jax import lax
39
+ import jax.numpy as jnp
40
+ import optax
41
+ import transformers
42
+ from flax import jax_utils, traverse_util
43
+ from flax.jax_utils import unreplicate
44
+ from flax.training import checkpoints
45
+ from flax.training import train_state
46
+ from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
47
+ from transformers import (
48
+ CONFIG_MAPPING,
49
+ FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
50
+ AutoConfig,
51
+ AutoTokenizer,
52
+ FlaxAutoModelForCausalLM,
53
+ HfArgumentParser,
54
+ TrainingArguments,
55
+ is_tensorboard_available,
56
+ )
57
+ from transformers.testing_utils import CaptureLogger
58
+
59
+ from data_utils import (
60
+ filter_by_lang_regex,
61
+ filter_by_num_tokens,
62
+ filter_by_num_sents,
63
+ filter_by_adv,
64
+ normalizer
65
+ )
66
+
67
+ print(jax.devices())
68
+
69
+ logger = logging.getLogger(__name__)
70
+
71
+ # Cache the result
72
+ has_tensorboard = is_tensorboard_available()
73
+ if has_tensorboard:
74
+ try:
75
+ from flax.metrics.tensorboard import SummaryWriter
76
+ except ImportError as ie:
77
+ has_tensorboard = False
78
+ print(f"Unable to display metrics through TensorBoard because some package are not installed: {ie}")
79
+
80
+ else:
81
+ print(
82
+ "Unable to display metrics through TensorBoard because the package is not installed: "
83
+ "Please run pip install tensorboard to enable."
84
+ )
85
+
86
+ MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_CAUSAL_LM_MAPPING.keys())
87
+ MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
88
+
89
+
90
+ @dataclass
91
+ class ModelArguments:
92
+ """
93
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
94
+ """
95
+
96
+ model_name_or_path: Optional[str] = field(
97
+ default=None,
98
+ metadata={
99
+ "help": "The model checkpoint for weights initialization."
100
+ "Don't set if you want to train a model from scratch."
101
+ },
102
+ )
103
+ model_type: Optional[str] = field(
104
+ default=None,
105
+ metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
106
+ )
107
+ config_name: Optional[str] = field(
108
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
109
+ )
110
+ tokenizer_name: Optional[str] = field(
111
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
112
+ )
113
+ cache_dir: Optional[str] = field(
114
+ default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
115
+ )
116
+ use_fast_tokenizer: bool = field(
117
+ default=True,
118
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
119
+ )
120
+ dtype: Optional[str] = field(
121
+ default="float32",
122
+ metadata={
123
+ "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
124
+ },
125
+ )
126
+
127
+
128
+ @dataclass
129
+ class DataTrainingArguments:
130
+ """
131
+ Arguments pertaining to what data we are going to input our model for training and eval.
132
+ """
133
+
134
+ dataset_name: Optional[str] = field(
135
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
136
+ )
137
+ dataset_config_name: Optional[str] = field(
138
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
139
+ )
140
+ train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
141
+ validation_file: Optional[str] = field(
142
+ default=None,
143
+ metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
144
+ )
145
+ max_train_samples: Optional[int] = field(
146
+ default=None,
147
+ metadata={
148
+ "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
149
+ "value if set."
150
+ },
151
+ )
152
+ max_eval_samples: Optional[int] = field(
153
+ default=None,
154
+ metadata={
155
+ "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
156
+ "value if set."
157
+ },
158
+ )
159
+ overwrite_cache: bool = field(
160
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
161
+ )
162
+ validation_split_percentage: Optional[int] = field(
163
+ default=1,
164
+ metadata={
165
+ "help": "The percentage of the train set used as validation set in case there's no validation split"
166
+ },
167
+ )
168
+ block_size: Optional[int] = field(
169
+ default=None,
170
+ metadata={
171
+ "help": "Optional input sequence length after tokenization. "
172
+ "The training dataset will be truncated in block of this size for training. "
173
+ "Default to the model max input length for single sentence inputs (take into account special tokens)."
174
+ },
175
+ )
176
+ overwrite_cache: bool = field(
177
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
178
+ )
179
+ preprocessing_num_workers: Optional[int] = field(
180
+ default=None,
181
+ metadata={"help": "The number of processes to use for the preprocessing."},
182
+ )
183
+
184
+ def __post_init__(self):
185
+ if self.dataset_name is None and self.train_file is None and self.validation_file is None:
186
+ raise ValueError("Need either a dataset name or a training/validation file.")
187
+ else:
188
+ if self.train_file is not None:
189
+ extension = self.train_file.split(".")[-1]
190
+ assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
191
+ if self.validation_file is not None:
192
+ extension = self.validation_file.split(".")[-1]
193
+ assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
194
+
195
+
196
+ class TrainState(train_state.TrainState):
197
+ dropout_rng: jnp.ndarray
198
+
199
+ def replicate(self):
200
+ return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
201
+
202
+
203
+ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False):
204
+ """
205
+ Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
206
+ Shuffle batches if `shuffle` is `True`.
207
+ """
208
+ steps_per_epoch = len(dataset) // batch_size
209
+
210
+ if shuffle:
211
+ batch_idx = jax.random.permutation(rng, len(dataset))
212
+ else:
213
+ batch_idx = jnp.arange(len(dataset))
214
+
215
+ batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch.
216
+ batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
217
+
218
+ for idx in batch_idx:
219
+ batch = dataset[idx]
220
+ batch = {k: jnp.array(v) for k, v in batch.items()}
221
+
222
+ batch = shard(batch)
223
+
224
+ yield batch
225
+
226
+
227
+ # def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
228
+ # summary_writer.scalar("train_time", train_time, step)
229
+ #
230
+ # train_metrics = get_metrics(train_metrics)
231
+ # for key, vals in train_metrics.items():
232
+ # tag = f"train_{key}"
233
+ # for i, val in enumerate(vals):
234
+ # summary_writer.scalar(tag, val, step - len(vals) + i + 1)
235
+ #
236
+ # for metric_name, value in eval_metrics.items():
237
+ # summary_writer.scalar(f"eval_{metric_name}", value, step)
238
+
239
+ def write_train_metric(summary_writer, train_metrics, train_time, step):
240
+ summary_writer.scalar("train_time", train_time, step)
241
+
242
+ train_metrics = get_metrics(train_metrics)
243
+ for key, vals in train_metrics.items():
244
+ tag = f"train_{key}"
245
+ for i, val in enumerate(vals):
246
+ summary_writer.scalar(tag, val, step - len(vals) + i + 1)
247
+
248
+
249
+ def write_eval_metric(summary_writer, eval_metrics, step):
250
+ for metric_name, value in eval_metrics.items():
251
+ summary_writer.scalar(f"eval_{metric_name}", value, step)
252
+
253
+
254
+ def create_learning_rate_fn(
255
+ train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
256
+ ) -> Callable[[int], jnp.array]:
257
+ """Returns a linear warmup, linear_decay learning rate function."""
258
+ steps_per_epoch = train_ds_size // train_batch_size
259
+ num_train_steps = steps_per_epoch * num_train_epochs
260
+ warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
261
+ decay_fn = optax.linear_schedule(
262
+ init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
263
+ )
264
+ schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
265
+ return schedule_fn
266
+
267
+
268
+ def main():
269
+ # See all possible arguments in src/transformers/training_args.py
270
+ # or by passing the --help flag to this script.
271
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
272
+
273
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
274
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
275
+ # If we pass only one argument to the script and it's the path to a json file,
276
+ # let's parse it to get our arguments.
277
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
278
+ else:
279
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
280
+
281
+ if (
282
+ os.path.exists(training_args.output_dir)
283
+ and os.listdir(training_args.output_dir)
284
+ and training_args.do_train
285
+ and not training_args.overwrite_output_dir
286
+ ):
287
+ raise ValueError(
288
+ f"Output directory ({training_args.output_dir}) already exists and is not empty."
289
+ "Use --overwrite_output_dir to overcome."
290
+ )
291
+
292
+ # Make one log on every process with the configuration for debugging.
293
+ logging.basicConfig(
294
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
295
+ datefmt="%m/%d/%Y %H:%M:%S",
296
+ level=logging.INFO,
297
+ )
298
+ # Setup logging, we only want one process per machine to log things on the screen.
299
+ logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
300
+ if jax.process_index() == 0:
301
+ datasets.utils.logging.set_verbosity_warning()
302
+ transformers.utils.logging.set_verbosity_info()
303
+ else:
304
+ datasets.utils.logging.set_verbosity_error()
305
+ transformers.utils.logging.set_verbosity_error()
306
+
307
+ # Set the verbosity to info of the Transformers logger (on main process only):
308
+ logger.info(f"Training/evaluation parameters {training_args}")
309
+
310
+ checkpoints_dir = os.path.join(training_args.output_dir, "checkpoints")
311
+ os.makedirs(checkpoints_dir, exist_ok=True)
312
+
313
+ # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
314
+ # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
315
+ # (the dataset will be downloaded automatically from the datasets Hub).
316
+ #
317
+ # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
318
+ # 'text' is found. You can easily tweak this behavior (see below).
319
+ #
320
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
321
+ # download the dataset.
322
+ if data_args.dataset_name is not None:
323
+ # Downloading and loading a dataset from the hub.
324
+ raw_dataset = load_dataset(
325
+ data_args.dataset_name,
326
+ data_args.dataset_config_name,
327
+ cache_dir=model_args.cache_dir,
328
+ keep_in_memory=False
329
+ )
330
+
331
+ if "validation" not in raw_dataset.keys():
332
+ raw_dataset["validation"] = load_dataset(
333
+ data_args.dataset_name,
334
+ data_args.dataset_config_name,
335
+ split=f"train[:{data_args.validation_split_percentage}%]",
336
+ cache_dir=model_args.cache_dir,
337
+ )
338
+ raw_dataset["train"] = load_dataset(
339
+ data_args.dataset_name,
340
+ data_args.dataset_config_name,
341
+ split=f"train[{data_args.validation_split_percentage}%:]",
342
+ cache_dir=model_args.cache_dir,
343
+ )
344
+ else:
345
+ data_files = {}
346
+ if data_args.train_file is not None:
347
+ data_files["train"] = data_args.train_file
348
+ if data_args.validation_file is not None:
349
+ data_files["validation"] = data_args.validation_file
350
+ extension = data_args.train_file.split(".")[-1]
351
+ if extension == "txt":
352
+ extension = "text"
353
+
354
+ raw_dataset = load_dataset(
355
+ extension,
356
+ data_files=data_files,
357
+ delimiter="\t",
358
+ cache_dir=model_args.cache_dir
359
+ )
360
+
361
+ # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
362
+ # https://huggingface.co/docs/datasets/loading_datasets.html.
363
+ # logger.info("Preprocessing the dataset")
364
+ # dataset = raw_dataset.filter(lambda example: filter_by_lang_regex(example["text"], ratio=0.75))
365
+ # dataset = dataset.filter(lambda example: filter_by_num_tokens(example["text"], gt=64))
366
+ # dataset = dataset.filter(lambda example: filter_by_num_sents(example["text"], gt=2))
367
+ # dataset = dataset.filter(lambda example: filter_by_adv(example["text"], ratio=50))
368
+ # dataset = dataset.map(normalizer)
369
+ # logger.info(f"Preprocessed dataset kept {len(dataset)} out of {len(raw_dataset)}")
370
+ dataset = raw_dataset
371
+ logger.info(f"dataset: {dataset}")
372
+
373
+ # Load pretrained model and tokenizer
374
+
375
+ # Distributed training:
376
+ # The .from_pretrained methods guarantee that only one local process can concurrently
377
+ # download model & vocab.
378
+ if model_args.config_name:
379
+ config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
380
+ elif model_args.model_name_or_path:
381
+ config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
382
+ else:
383
+ config = CONFIG_MAPPING[model_args.model_type]()
384
+ logger.warning("You are instantiating a new config instance from scratch.")
385
+
386
+ if model_args.tokenizer_name:
387
+ tokenizer = AutoTokenizer.from_pretrained(
388
+ model_args.tokenizer_name,
389
+ cache_dir=model_args.cache_dir,
390
+ use_fast=model_args.use_fast_tokenizer
391
+ )
392
+ elif model_args.model_name_or_path:
393
+ tokenizer = AutoTokenizer.from_pretrained(
394
+ model_args.model_name_or_path,
395
+ cache_dir=model_args.cache_dir,
396
+ use_fast=model_args.use_fast_tokenizer
397
+ )
398
+ else:
399
+ raise ValueError(
400
+ "You are instantiating a new tokenizer from scratch. This is not supported by this script."
401
+ "You can do it from another script, save it, and load it from here, using --tokenizer_name."
402
+ )
403
+
404
+ if model_args.model_name_or_path:
405
+ model = FlaxAutoModelForCausalLM.from_pretrained(
406
+ model_args.model_name_or_path,
407
+ config=config,
408
+ seed=training_args.seed,
409
+ dtype=getattr(jnp, model_args.dtype)
410
+ )
411
+ else:
412
+ model = FlaxAutoModelForCausalLM.from_config(
413
+ config,
414
+ seed=training_args.seed,
415
+ dtype=getattr(jnp, model_args.dtype)
416
+ )
417
+
418
+ # Preprocessing the datasets.
419
+ # First we tokenize all the texts.
420
+ if training_args.do_train:
421
+ column_names = dataset["train"].column_names
422
+ else:
423
+ column_names = dataset["validation"].column_names
424
+ text_column_name = "text" if "text" in column_names else column_names[0]
425
+ logger.info(f"text_column_name: {text_column_name}")
426
+
427
+ # since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function
428
+ tok_logger = transformers.utils.logging.get_logger("transformers.tokenization_utils_base")
429
+
430
+ def tokenize_function(examples):
431
+ with CaptureLogger(tok_logger) as cl:
432
+ output = tokenizer(examples[text_column_name])
433
+ # clm input could be much much longer than block_size
434
+ if "Token indices sequence length is longer than the" in cl.out:
435
+ tok_logger.warning(
436
+ "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits before being passed to the model."
437
+ )
438
+ return output
439
+
440
+ tokenized_datasets = dataset.map(
441
+ tokenize_function,
442
+ batched=True,
443
+ num_proc=data_args.preprocessing_num_workers,
444
+ remove_columns=column_names,
445
+ load_from_cache_file=not data_args.overwrite_cache,
446
+ )
447
+
448
+ if data_args.block_size is None:
449
+ block_size = tokenizer.model_max_length
450
+ if block_size > config.max_position_embeddings:
451
+ logger.warning(
452
+ f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
453
+ "Picking 1024 instead. You can change that default value by passing --block_size xxx."
454
+ )
455
+ block_size = 1024
456
+ else:
457
+ if data_args.block_size > tokenizer.model_max_length:
458
+ logger.warning(
459
+ f"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model"
460
+ f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}."
461
+ )
462
+ block_size = min(data_args.block_size, tokenizer.model_max_length)
463
+
464
+ # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
465
+ def group_texts(examples):
466
+ # Concatenate all texts.
467
+ concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
468
+ total_length = len(concatenated_examples[list(examples.keys())[0]])
469
+ # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
470
+ # customize this part to your needs.
471
+ if total_length >= block_size:
472
+ total_length = (total_length // block_size) * block_size
473
+ # Split by chunks of max_len.
474
+ result = {
475
+ k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
476
+ for k, t in concatenated_examples.items()
477
+ }
478
+ result["labels"] = result["input_ids"].copy()
479
+ return result
480
+
481
+ # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
482
+ # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
483
+ # to preprocess.
484
+ #
485
+ # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
486
+ # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
487
+
488
+ lm_datasets = tokenized_datasets.map(
489
+ group_texts,
490
+ batched=True,
491
+ num_proc=data_args.preprocessing_num_workers,
492
+ load_from_cache_file=not data_args.overwrite_cache,
493
+ )
494
+
495
+ if training_args.do_train:
496
+ if "train" not in tokenized_datasets:
497
+ raise ValueError("--do_train requires a train dataset")
498
+ train_dataset = lm_datasets["train"]
499
+ if data_args.max_train_samples is not None:
500
+ train_dataset = train_dataset.select(range(data_args.max_train_samples))
501
+
502
+ if training_args.do_eval:
503
+ if "validation" not in tokenized_datasets:
504
+ raise ValueError("--do_eval requires a validation dataset")
505
+ eval_dataset = lm_datasets["validation"]
506
+ if data_args.max_eval_samples is not None:
507
+ eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
508
+
509
+ # Enable tensorboard only on the master node
510
+ if has_tensorboard and jax.process_index() == 0:
511
+ summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
512
+
513
+ # Initialize our training
514
+ rng = jax.random.PRNGKey(training_args.seed)
515
+ rng, dropout_rng = jax.random.split(rng)
516
+
517
+ # Store some constant
518
+ num_epochs = int(training_args.num_train_epochs)
519
+ train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
520
+ eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
521
+ steps_per_epoch = len(train_dataset) // train_batch_size
522
+ total_train_steps = steps_per_epoch * num_epochs
523
+
524
+ # Create learning rate schedule
525
+ linear_decay_lr_schedule_fn = create_learning_rate_fn(
526
+ len(train_dataset),
527
+ train_batch_size,
528
+ training_args.num_train_epochs,
529
+ training_args.warmup_steps,
530
+ training_args.learning_rate,
531
+ )
532
+
533
+ # We use Optax's "masking" functionality to not apply weight decay
534
+ # to bias and LayerNorm scale parameters. decay_mask_fn returns a
535
+ # mask boolean with the same structure as the parameters.
536
+ # The mask is True for parameters that should be decayed.
537
+ # Note that this mask is specifically adapted for FlaxGPT2.
538
+ # For other models, one should correct the layer norm parameter naming
539
+ # accordingly.
540
+ def decay_mask_fn(params):
541
+ flat_params = traverse_util.flatten_dict(params)
542
+ flat_mask = {
543
+ path: (path[-1] != "bias" and path[-2:] not in [("ln_1", "scale"), ("ln_2", "scale"), ("ln_f", "scale")])
544
+ for path in flat_params
545
+ }
546
+ return traverse_util.unflatten_dict(flat_mask)
547
+
548
+ # create adam optimizer
549
+ if training_args.adafactor:
550
+ # We use the default parameters here to initialize adafactor,
551
+ # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
552
+ optimizer = optax.adafactor(
553
+ learning_rate=linear_decay_lr_schedule_fn,
554
+ )
555
+ else:
556
+ optimizer = optax.adamw(
557
+ learning_rate=linear_decay_lr_schedule_fn,
558
+ b1=training_args.adam_beta1,
559
+ b2=training_args.adam_beta2,
560
+ eps=training_args.adam_epsilon,
561
+ weight_decay=training_args.weight_decay,
562
+ mask=decay_mask_fn,
563
+ )
564
+
565
+ # Setup train state
566
+ state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer, dropout_rng=dropout_rng)
567
+
568
+ def loss_fn(logits, labels):
569
+ shift_logits = logits[..., :-1, :]
570
+ shift_labels = labels[..., 1:]
571
+ loss = optax.softmax_cross_entropy(shift_logits, onehot(shift_labels, shift_logits.shape[-1]))
572
+ return loss.mean()
573
+
574
+ # Define gradient update step fn
575
+ def train_step(state, batch):
576
+ dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
577
+
578
+ def compute_loss(params):
579
+ labels = batch.pop("labels")
580
+ logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
581
+ loss = loss_fn(logits, labels)
582
+ return loss
583
+
584
+ grad_fn = jax.value_and_grad(compute_loss)
585
+ loss, grad = grad_fn(state.params)
586
+ grad = jax.lax.pmean(grad, "batch")
587
+
588
+ new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
589
+
590
+ metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
591
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
592
+
593
+ return new_state, metrics
594
+
595
+ # Define eval fn
596
+ def eval_step(params, batch):
597
+ labels = batch.pop("labels")
598
+ logits = model(**batch, params=params, train=False)[0]
599
+ loss = loss_fn(logits, labels)
600
+
601
+ # summarize metrics
602
+ metrics = {"loss": loss}
603
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
604
+ return metrics
605
+
606
+ # Create parallel version of the train and eval step
607
+ p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
608
+ p_eval_step = jax.pmap(eval_step, "batch")
609
+
610
+ # Replicate the train state on each device
611
+ state = state.replicate()
612
+
613
+ logger.info("***** Running training *****")
614
+ logger.info(f" Num examples = {len(train_dataset)}")
615
+ logger.info(f" Num Epochs = {num_epochs}")
616
+ logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
617
+ logger.info(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
618
+ logger.info(f" Total optimization steps = {total_train_steps}")
619
+
620
+ train_time = 0
621
+ train_metrics = []
622
+ epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
623
+ for epoch in epochs:
624
+ # ======================== Training ================================
625
+ train_start = time.time()
626
+
627
+ # Create sampling rng
628
+ rng, input_rng = jax.random.split(rng)
629
+
630
+ # Generate an epoch by shuffling sampling indices from the train dataset
631
+ train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
632
+ steps_per_epoch = len(train_dataset) // train_batch_size
633
+ # train
634
+ for step in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
635
+ batch = next(train_loader)
636
+ state, train_metric = p_train_step(state, batch)
637
+ train_metrics.append(train_metric)
638
+
639
+ cur_step = epoch * (len(train_dataset) // train_batch_size) + step
640
+
641
+ if cur_step % training_args.logging_steps == 0 and cur_step > 0:
642
+ # Save metrics
643
+ train_metric = unreplicate(train_metric)
644
+ train_time += time.time() - train_start
645
+ if has_tensorboard and jax.process_index() == 0:
646
+ logger.info(f"*** Writing training summary after {cur_step} steps ***")
647
+ write_train_metric(summary_writer, train_metrics, train_time, cur_step)
648
+
649
+ epochs.write(
650
+ f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
651
+ )
652
+
653
+ train_metrics = []
654
+
655
+ if cur_step % training_args.eval_steps == 0 and cur_step > 0 and training_args.do_eval:
656
+ logger.info(f"*** Evaluation after {cur_step} steps ***")
657
+
658
+ eval_metrics = []
659
+ eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
660
+ eval_steps = len(eval_dataset) // eval_batch_size
661
+ for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
662
+ # Model forward
663
+ batch = next(eval_loader)
664
+ metrics = p_eval_step(state.params, batch)
665
+ eval_metrics.append(metrics)
666
+
667
+ # normalize eval metrics
668
+ eval_metrics = get_metrics(eval_metrics)
669
+ eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
670
+
671
+ try:
672
+ eval_metrics["perplexity"] = math.exp(eval_metrics["loss"])
673
+ except OverflowError:
674
+ eval_metrics["perplexity"] = float("inf")
675
+
676
+ # Print metrics and update progress bar
677
+ desc = f"Step... ({cur_step} | Eval Loss: {eval_metrics['loss']} | Eval Perplexity: {eval_metrics['perplexity']})"
678
+ epochs.write(desc)
679
+ epochs.desc = desc
680
+
681
+ # Save metrics
682
+ if has_tensorboard and jax.process_index() == 0:
683
+ logger.info(f"*** Writing evaluation summary after {cur_step} steps ***")
684
+ # cur_step = epoch * (len(train_dataset) // train_batch_size)
685
+ write_eval_metric(summary_writer, eval_metrics, cur_step)
686
+
687
+ if cur_step % training_args.save_steps == 0 and cur_step > 0:
688
+ logger.info(f"*** Saving checkpoints after {cur_step} steps ***")
689
+ # save the model weights every save_steps and push them to the hub
690
+ if jax.process_index() == 0:
691
+ params = jax.device_get(unreplicate(state.params))
692
+ model.save_pretrained(
693
+ training_args.output_dir,
694
+ params=params,
695
+ push_to_hub=training_args.push_to_hub,
696
+ commit_message=f"Saving weights and logs of step {cur_step}",
697
+ )
698
+
699
+ if not os.path.exists(os.path.join(training_args.output_dir, "tokenizer.json")):
700
+ logger.info(f"*** Saving tokenizer ***")
701
+ tokenizer.save_pretrained(
702
+ training_args.output_dir,
703
+ push_to_hub=training_args.push_to_hub,
704
+ commit_message=f"Saving tokenizer",
705
+ )
706
+
707
+
708
+ if __name__ == "__main__":
709
+ main()
src/run_clm_flax_with_ckpts.py ADDED
@@ -0,0 +1,700 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2021 The HuggingFace Team All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Pre-training/Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) on a text file or a dataset.
18
+
19
+ Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
20
+ https://huggingface.co/models?filter=causal-lm
21
+ """
22
+ # You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments.
23
+
24
+ import logging
25
+ import math
26
+ import os
27
+ import sys
28
+ import time
29
+ from dataclasses import dataclass, field
30
+ from pathlib import Path
31
+ from typing import Callable, Optional
32
+
33
+ import datasets
34
+ from datasets import Dataset, load_dataset
35
+ from tqdm import tqdm
36
+
37
+ import jax
38
+ from jax import lax
39
+ import jax.numpy as jnp
40
+ import optax
41
+ import transformers
42
+ from flax import jax_utils, traverse_util
43
+ from flax.jax_utils import unreplicate
44
+ from flax.training import checkpoints
45
+ from flax.training import train_state
46
+ from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
47
+ from transformers import (
48
+ CONFIG_MAPPING,
49
+ FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
50
+ AutoConfig,
51
+ AutoTokenizer,
52
+ FlaxAutoModelForCausalLM,
53
+ HfArgumentParser,
54
+ TrainingArguments,
55
+ is_tensorboard_available,
56
+ )
57
+ from transformers.testing_utils import CaptureLogger
58
+
59
+ logger = logging.getLogger(__name__)
60
+
61
+ # Check once whether TensorBoard is available and cache the result
62
+ has_tensorboard = is_tensorboard_available()
63
+ if has_tensorboard:
64
+ try:
65
+ from flax.metrics.tensorboard import SummaryWriter
66
+ except ImportError as ie:
67
+ has_tensorboard = False
68
+ print(f"Unable to display metrics through TensorBoard because some package are not installed: {ie}")
69
+
70
+ else:
71
+ print(
72
+ "Unable to display metrics through TensorBoard because the package is not installed: "
73
+ "Please run pip install tensorboard to enable."
74
+ )
75
+
76
+ MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_CAUSAL_LM_MAPPING.keys())
77
+ MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
78
+
79
+
80
+ @dataclass
81
+ class ModelArguments:
82
+ """
83
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
84
+ """
85
+
86
+ model_name_or_path: Optional[str] = field(
87
+ default=None,
88
+ metadata={
89
+ "help": "The model checkpoint for weights initialization."
90
+ "Don't set if you want to train a model from scratch."
91
+ },
92
+ )
93
+ model_type: Optional[str] = field(
94
+ default=None,
95
+ metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
96
+ )
97
+ config_name: Optional[str] = field(
98
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
99
+ )
100
+ tokenizer_name: Optional[str] = field(
101
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
102
+ )
103
+ cache_dir: Optional[str] = field(
104
+ default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
105
+ )
106
+ use_fast_tokenizer: bool = field(
107
+ default=True,
108
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
109
+ )
110
+ dtype: Optional[str] = field(
111
+ default="float32",
112
+ metadata={
113
+ "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
114
+ },
115
+ )
116
+
117
+
118
+ @dataclass
119
+ class DataTrainingArguments:
120
+ """
121
+ Arguments pertaining to what data we are going to input our model for training and eval.
122
+ """
123
+
124
+ dataset_name: Optional[str] = field(
125
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
126
+ )
127
+ dataset_config_name: Optional[str] = field(
128
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
129
+ )
130
+ train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
131
+ validation_file: Optional[str] = field(
132
+ default=None,
133
+ metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
134
+ )
135
+ max_train_samples: Optional[int] = field(
136
+ default=None,
137
+ metadata={
138
+ "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
139
+ "value if set."
140
+ },
141
+ )
142
+ max_eval_samples: Optional[int] = field(
143
+ default=None,
144
+ metadata={
145
+ "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
146
+ "value if set."
147
+ },
148
+ )
149
+ overwrite_cache: bool = field(
150
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
151
+ )
152
+ validation_split_percentage: Optional[int] = field(
153
+ default=5,
154
+ metadata={
155
+ "help": "The percentage of the train set used as validation set in case there's no validation split"
156
+ },
157
+ )
158
+ block_size: Optional[int] = field(
159
+ default=None,
160
+ metadata={
161
+ "help": "Optional input sequence length after tokenization. "
162
+ "The training dataset will be truncated in block of this size for training. "
163
+ "Default to the model max input length for single sentence inputs (take into account special tokens)."
164
+ },
165
+ )
166
+ overwrite_cache: bool = field(
167
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
168
+ )
169
+ preprocessing_num_workers: Optional[int] = field(
170
+ default=None,
171
+ metadata={"help": "The number of processes to use for the preprocessing."},
172
+ )
173
+
174
+ def __post_init__(self):
175
+ if self.dataset_name is None and self.train_file is None and self.validation_file is None:
176
+ raise ValueError("Need either a dataset name or a training/validation file.")
177
+ else:
178
+ if self.train_file is not None:
179
+ extension = self.train_file.split(".")[-1]
180
+ assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
181
+ if self.validation_file is not None:
182
+ extension = self.validation_file.split(".")[-1]
183
+ assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
184
+
185
+
186
+ class TrainState(train_state.TrainState):
187
+ dropout_rng: jnp.ndarray
188
+
189
+ def replicate(self):
190
+ return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
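+ # The custom TrainState also carries the dropout PRNG key; replicate() copies the state to every
+ # local device and shards the key so each replica draws different dropout masks.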
191
+
192
+
193
+ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False):
194
+ """
195
+ Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
196
+ Shuffle batches if `shuffle` is `True`.
197
+ """
198
+ steps_per_epoch = len(dataset) // batch_size
199
+
200
+ if shuffle:
201
+ batch_idx = jax.random.permutation(rng, len(dataset))
202
+ else:
203
+ batch_idx = jnp.arange(len(dataset))
204
+
205
+ batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch.
206
+ batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
207
+
208
+ for idx in batch_idx:
209
+ batch = dataset[idx]
210
+ batch = {k: jnp.array(v) for k, v in batch.items()}
211
+
212
+ batch = shard(batch)
213
+
214
+ yield batch
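+ # shard() reshapes every array from (batch, ...) to (local_device_count, batch // local_device_count, ...),
+ # which is the layout expected by the pmap-ed train and eval steps.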
215
+
216
+
217
+ # def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
218
+ # summary_writer.scalar("train_time", train_time, step)
219
+ #
220
+ # train_metrics = get_metrics(train_metrics)
221
+ # for key, vals in train_metrics.items():
222
+ # tag = f"train_{key}"
223
+ # for i, val in enumerate(vals):
224
+ # summary_writer.scalar(tag, val, step - len(vals) + i + 1)
225
+ #
226
+ # for metric_name, value in eval_metrics.items():
227
+ # summary_writer.scalar(f"eval_{metric_name}", value, step)
228
+
229
+ def write_train_metric(summary_writer, train_metrics, train_time, step):
230
+ summary_writer.scalar("train_time", train_time, step)
231
+
232
+ train_metrics = get_metrics(train_metrics)
233
+ for key, vals in train_metrics.items():
234
+ tag = f"train_{key}"
235
+ for i, val in enumerate(vals):
236
+ summary_writer.scalar(tag, val, step - len(vals) + i + 1)
237
+
238
+
239
+ def write_eval_metric(summary_writer, eval_metrics, step):
240
+ for metric_name, value in eval_metrics.items():
241
+ summary_writer.scalar(f"eval_{metric_name}", value, step)
242
+
243
+
244
+ def create_learning_rate_fn(
245
+ train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
246
+ ) -> Callable[[int], jnp.array]:
247
+ """Returns a linear warmup, linear_decay learning rate function."""
248
+ steps_per_epoch = train_ds_size // train_batch_size
249
+ num_train_steps = steps_per_epoch * num_train_epochs
250
+ warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
251
+ decay_fn = optax.linear_schedule(
252
+ init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
253
+ )
254
+ schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
255
+ return schedule_fn
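+ # Illustrative example (numbers are not from the actual config): with num_warmup_steps=1000,
+ # num_train_steps=10000 and learning_rate=1e-3, the schedule rises linearly from 0 to 1e-3
+ # over the first 1000 steps and then decays linearly back to 0 over the remaining 9000 steps.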
256
+
257
+
258
+ def restore_checkpoint(state, workdir):
259
+ return checkpoints.restore_checkpoint(workdir, state)
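+ # flax's restore_checkpoint returns `state` unchanged when no checkpoint exists in `workdir`,
+ # so a fresh run simply starts from step 0.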
260
+
261
+
262
+ def save_checkpoint(state, workdir):
263
+ if jax.process_index() == 0:
264
+ # get train state from the first replica
265
+ state = jax.device_get(jax.tree_map(lambda x: x[0], state))
266
+ step = int(state.step)
267
+ checkpoints.save_checkpoint(workdir, state, step, keep=3)
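+ # keep=3 retains only the three most recent checkpoints in the directory.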
268
+
269
+
270
+ def main():
271
+ # See all possible arguments in src/transformers/training_args.py
272
+ # or by passing the --help flag to this script.
273
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
274
+
275
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
276
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
277
+ # If we pass only one argument to the script and it's the path to a json file,
278
+ # let's parse it to get our arguments.
279
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
280
+ else:
281
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
282
+
283
+ if (
284
+ os.path.exists(training_args.output_dir)
285
+ and os.listdir(training_args.output_dir)
286
+ and training_args.do_train
287
+ and not training_args.overwrite_output_dir
288
+ ):
289
+ raise ValueError(
290
+ f"Output directory ({training_args.output_dir}) already exists and is not empty."
291
+ "Use --overwrite_output_dir to overcome."
292
+ )
293
+
294
+ # Make one log on every process with the configuration for debugging.
295
+ logging.basicConfig(
296
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
297
+ datefmt="%m/%d/%Y %H:%M:%S",
298
+ level=logging.INFO,
299
+ )
300
+ # Setup logging, we only want one process per machine to log things on the screen.
301
+ logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
302
+ if jax.process_index() == 0:
303
+ datasets.utils.logging.set_verbosity_warning()
304
+ transformers.utils.logging.set_verbosity_info()
305
+ else:
306
+ datasets.utils.logging.set_verbosity_error()
307
+ transformers.utils.logging.set_verbosity_error()
308
+
309
+ # Set the verbosity to info of the Transformers logger (on main process only):
310
+ logger.info(f"Training/evaluation parameters {training_args}")
311
+
312
+ checkpoints_dir = os.path.join(training_args.output_dir, "checkpoints")
313
+ os.makedirs(checkpoints_dir, exist_ok=True)
314
+
315
+ # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
316
+ # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
317
+ # (the dataset will be downloaded automatically from the datasets Hub).
318
+ #
319
+ # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
320
+ # 'text' is found. You can easily tweak this behavior (see below).
321
+ #
322
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
323
+ # download the dataset.
324
+ if data_args.dataset_name is not None:
325
+ # Downloading and loading a dataset from the hub.
326
+ dataset = load_dataset(
327
+ data_args.dataset_name,
328
+ data_args.dataset_config_name,
329
+ cache_dir=model_args.cache_dir,
330
+ keep_in_memory=False
331
+ )
332
+
333
+ if "validation" not in dataset.keys():
334
+ dataset["validation"] = load_dataset(
335
+ data_args.dataset_name,
336
+ data_args.dataset_config_name,
337
+ split=f"train[:{data_args.validation_split_percentage}%]",
338
+ cache_dir=model_args.cache_dir,
339
+ )
340
+ dataset["train"] = load_dataset(
341
+ data_args.dataset_name,
342
+ data_args.dataset_config_name,
343
+ split=f"train[{data_args.validation_split_percentage}%:]",
344
+ cache_dir=model_args.cache_dir,
345
+ )
346
+ else:
347
+ data_files = {}
348
+ if data_args.train_file is not None:
349
+ data_files["train"] = data_args.train_file
350
+ if data_args.validation_file is not None:
351
+ data_files["validation"] = data_args.validation_file
352
+ extension = data_args.train_file.split(".")[-1]
353
+ if extension == "txt":
354
+ extension = "text"
355
+
356
+ dataset = load_dataset(
357
+ extension,
358
+ data_files=data_files,
359
+ delimiter="\t",
360
+ cache_dir=model_args.cache_dir
361
+ )
362
+
363
+ # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
364
+ # https://huggingface.co/docs/datasets/loading_datasets.html.
365
+
366
+ # Load pretrained model and tokenizer
367
+
368
+ # Distributed training:
369
+ # The .from_pretrained methods guarantee that only one local process can concurrently
370
+ # download model & vocab.
371
+ if model_args.config_name:
372
+ config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
373
+ elif model_args.model_name_or_path:
374
+ config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
375
+ else:
376
+ config = CONFIG_MAPPING[model_args.model_type]()
377
+ logger.warning("You are instantiating a new config instance from scratch.")
378
+
379
+ if model_args.tokenizer_name:
380
+ tokenizer = AutoTokenizer.from_pretrained(
381
+ model_args.tokenizer_name,
382
+ cache_dir=model_args.cache_dir,
383
+ use_fast=model_args.use_fast_tokenizer
384
+ )
385
+ elif model_args.model_name_or_path:
386
+ tokenizer = AutoTokenizer.from_pretrained(
387
+ model_args.model_name_or_path,
388
+ cache_dir=model_args.cache_dir,
389
+ use_fast=model_args.use_fast_tokenizer
390
+ )
391
+ else:
392
+ raise ValueError(
393
+ "You are instantiating a new tokenizer from scratch. This is not supported by this script."
394
+ "You can do it from another script, save it, and load it from here, using --tokenizer_name."
395
+ )
396
+
397
+ if model_args.model_name_or_path:
398
+ model = FlaxAutoModelForCausalLM.from_pretrained(
399
+ model_args.model_name_or_path,
400
+ config=config,
401
+ seed=training_args.seed,
402
+ dtype=getattr(jnp, model_args.dtype)
403
+ )
404
+ else:
405
+ model = FlaxAutoModelForCausalLM.from_config(
406
+ config,
407
+ seed=training_args.seed,
408
+ dtype=getattr(jnp, model_args.dtype)
409
+ )
410
+
411
+ # Preprocessing the datasets.
412
+ # First we tokenize all the texts.
413
+ if training_args.do_train:
414
+ column_names = dataset["train"].column_names
415
+ else:
416
+ column_names = dataset["validation"].column_names
417
+ text_column_name = "text" if "text" in column_names else column_names[0]
418
+
419
+ # This function will be pickled by the datasets Hasher; load the tokenization logger up front to avoid a _LazyModule error.
420
+ tok_logger = transformers.utils.logging.get_logger("transformers.tokenization_utils_base")
421
+
422
+ def tokenize_function(examples):
423
+ with CaptureLogger(tok_logger) as cl:
424
+ output = tokenizer(examples[text_column_name])
425
+ # clm input could be much much longer than block_size
426
+ if "Token indices sequence length is longer than the" in cl.out:
427
+ tok_logger.warning(
428
+ "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits before being passed to the model."
429
+ )
430
+ return output
431
+
432
+ tokenized_datasets = dataset.map(
433
+ tokenize_function,
434
+ batched=True,
435
+ num_proc=data_args.preprocessing_num_workers,
436
+ remove_columns=column_names,
437
+ load_from_cache_file=not data_args.overwrite_cache,
438
+ )
439
+
440
+ if data_args.block_size is None:
441
+ block_size = tokenizer.model_max_length
442
+ if block_size > config.max_position_embeddings:
443
+ logger.warning(
444
+ f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
445
+ "Picking 1024 instead. You can change that default value by passing --block_size xxx."
446
+ )
447
+ block_size = 1024
448
+ else:
449
+ if data_args.block_size > tokenizer.model_max_length:
450
+ logger.warning(
451
+ f"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model"
452
+ f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}."
453
+ )
454
+ block_size = min(data_args.block_size, tokenizer.model_max_length)
455
+
456
+ # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
457
+ def group_texts(examples):
458
+ # Concatenate all texts.
459
+ concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
460
+ total_length = len(concatenated_examples[list(examples.keys())[0]])
461
+ # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
462
+ # customize this part to your needs.
463
+ total_length = (total_length // block_size) * block_size
464
+ # Split by chunks of max_len.
465
+ result = {
466
+ k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
467
+ for k, t in concatenated_examples.items()
468
+ }
469
+ result["labels"] = result["input_ids"].copy()
470
+ return result
471
+
472
+ # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
473
+ # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
474
+ # to preprocess.
475
+ #
476
+ # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
477
+ # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
478
+
479
+ lm_datasets = tokenized_datasets.map(
480
+ group_texts,
481
+ batched=True,
482
+ num_proc=data_args.preprocessing_num_workers,
483
+ load_from_cache_file=not data_args.overwrite_cache,
484
+ )
485
+
486
+ if training_args.do_train:
487
+ if "train" not in tokenized_datasets:
488
+ raise ValueError("--do_train requires a train dataset")
489
+ train_dataset = lm_datasets["train"]
490
+ if data_args.max_train_samples is not None:
491
+ train_dataset = train_dataset.select(range(data_args.max_train_samples))
492
+
493
+ if training_args.do_eval:
494
+ if "validation" not in tokenized_datasets:
495
+ raise ValueError("--do_eval requires a validation dataset")
496
+ eval_dataset = lm_datasets["validation"]
497
+ if data_args.max_eval_samples is not None:
498
+ eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
499
+
500
+ # Enable tensorboard only on the master node
501
+ if has_tensorboard and jax.process_index() == 0:
502
+ summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
503
+
504
+ # Initialize our training
505
+ rng = jax.random.PRNGKey(training_args.seed)
506
+ rng, dropout_rng = jax.random.split(rng)
507
+
508
+ # Store some constant
509
+ num_epochs = int(training_args.num_train_epochs)
510
+ train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
511
+ eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
512
+ steps_per_epoch = len(train_dataset) // train_batch_size
513
+
514
+ # total_train_steps = steps_per_epoch * num_epochs
515
+ if training_args.max_steps == -1:
516
+ total_train_steps = steps_per_epoch * num_epochs
517
+ else:
518
+ total_train_steps = training_args.max_steps
519
+
520
+ # Create learning rate schedule
521
+ linear_decay_lr_schedule_fn = create_learning_rate_fn(
522
+ len(train_dataset),
523
+ train_batch_size,
524
+ training_args.num_train_epochs,
525
+ training_args.warmup_steps,
526
+ training_args.learning_rate,
527
+ )
528
+
529
+ # We use Optax's "masking" functionality to not apply weight decay
530
+ # to bias and LayerNorm scale parameters. decay_mask_fn returns a
531
+ # mask boolean with the same structure as the parameters.
532
+ # The mask is True for parameters that should be decayed.
533
+ # Note that this mask is specifically adapted for FlaxGPT2.
534
+ # For other models, one should correct the layer norm parameter naming
535
+ # accordingly.
536
+ def decay_mask_fn(params):
537
+ flat_params = traverse_util.flatten_dict(params)
538
+ flat_mask = {
539
+ path: (path[-1] != "bias" and path[-2:] not in [("ln_1", "scale"), ("ln_2", "scale"), ("ln_f", "scale")])
540
+ for path in flat_params
541
+ }
542
+ return traverse_util.unflatten_dict(flat_mask)
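+ # For FlaxGPT2 this maps e.g. paths ending in "bias" or in ("ln_1", "scale") to False (no weight
+ # decay) and kernel parameters such as (..., "c_attn", "kernel") to True (decayed).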
543
+
544
+ # create adam optimizer
545
+ adamw = optax.adamw(
546
+ learning_rate=linear_decay_lr_schedule_fn,
547
+ b1=training_args.adam_beta1,
548
+ b2=training_args.adam_beta2,
549
+ eps=training_args.adam_epsilon,
550
+ weight_decay=training_args.weight_decay,
551
+ mask=decay_mask_fn,
552
+ )
553
+
554
+ # Setup train state
555
+ state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)
556
+
557
+ # Restore states
558
+ state = restore_checkpoint(state, checkpoints_dir)
559
+ step_offset = int(state.step) # step_offset > 0 if restarting from checkpoint
560
+ epoch_offset = int(num_epochs - ((total_train_steps - step_offset) / steps_per_epoch))
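+ # epoch_offset estimates how many full epochs were completed before the restored step,
+ # so training resumes from that epoch instead of epoch 0.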
561
+
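+ # Standard causal-LM loss: the logits at position t are scored against the token at position t + 1,
+ # so the last logit and the first label are dropped before the cross-entropy.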
562
+ def loss_fn(logits, labels):
563
+ shift_logits = logits[..., :-1, :]
564
+ shift_labels = labels[..., 1:]
565
+ loss = optax.softmax_cross_entropy(shift_logits, onehot(shift_labels, shift_logits.shape[-1]))
566
+ return loss.mean()
567
+
568
+ # Define gradient update step fn
569
+ def train_step(state, batch):
570
+ dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
571
+
572
+ def compute_loss(params):
573
+ labels = batch.pop("labels")
574
+ logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
575
+ loss = loss_fn(logits, labels)
576
+ return loss
577
+
578
+ grad_fn = jax.value_and_grad(compute_loss)
579
+ loss, grad = grad_fn(state.params)
580
+ grad = jax.lax.pmean(grad, "batch")
581
+
582
+ new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
583
+
584
+ metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
585
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
586
+
587
+ return new_state, metrics
588
+
589
+ # Define eval fn
590
+ def eval_step(params, batch):
591
+ labels = batch.pop("labels")
592
+ logits = model(**batch, params=params, train=False)[0]
593
+ loss = loss_fn(logits, labels)
594
+
595
+ # summarize metrics
596
+ metrics = {"loss": loss}
597
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
598
+ return metrics
599
+
600
+ # Create parallel version of the train and eval step
601
+ p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
602
+ p_eval_step = jax.pmap(eval_step, "batch")
603
+
604
+ # Replicate the train state on each device
605
+ state = state.replicate()
606
+
607
+ logger.info("***** Running training *****")
608
+ logger.info(f" Num examples = {len(train_dataset)}")
609
+ logger.info(f" Num Epochs = {num_epochs}")
610
+ logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
611
+ logger.info(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
612
+ logger.info(f" Total optimization steps = {total_train_steps}")
613
+
614
+ if step_offset > 0:
615
+ logger.info(" Continuing training from checkpoint")
616
+ logger.info(f" Continuing training from epoch {epoch_offset}")
617
+ logger.info(f" Continuing training from global step {step_offset}")
618
+
619
+ train_time = 0
620
+ train_metrics = []
621
+ epochs = tqdm(range(epoch_offset, num_epochs), desc=f"Epoch ... ({epoch_offset + 1}/{num_epochs})", position=0)
622
+ for epoch in epochs:
623
+ # ======================== Training ================================
624
+ train_start = time.time()
625
+
626
+ # Create sampling rng
627
+ rng, input_rng = jax.random.split(rng)
628
+
629
+ # Generate an epoch by shuffling sampling indices from the train dataset
630
+ train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
631
+ steps_per_epoch = len(train_dataset) // train_batch_size
632
+ num_steps = min(steps_per_epoch, abs(step_offset - (steps_per_epoch * (epoch + 1))))  # cap at steps_per_epoch so the data loader is never exhausted
633
+
634
+ # train
635
+ for step in tqdm(range(num_steps), desc="Training...", position=1, leave=False):
636
+ batch = next(train_loader)
637
+ state, train_metric = p_train_step(state, batch)
638
+ train_metrics.append(train_metric)
639
+
640
+ cur_step = epoch * (len(train_dataset) // train_batch_size) + step
641
+
642
+ if cur_step % training_args.logging_steps == 0 and cur_step > 0:
643
+ # Save metrics
644
+ train_metric = unreplicate(train_metric)
645
+ train_time += time.time() - train_start
646
+ if has_tensorboard and jax.process_index() == 0:
647
+ write_train_metric(summary_writer, train_metrics, train_time, cur_step)
648
+
649
+ epochs.write(
650
+ f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
651
+ )
652
+
653
+ train_metrics = []
654
+
655
+ if cur_step % training_args.save_steps == 0 and cur_step > 0:
656
+ save_checkpoint(state, checkpoints_dir)
657
+
658
+ # ======================== Evaluating ==============================
659
+ eval_metrics = []
660
+ eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
661
+ eval_steps = len(eval_dataset) // eval_batch_size
662
+ for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
663
+ # Model forward
664
+ batch = next(eval_loader)
665
+ metrics = p_eval_step(state.params, batch)
666
+ eval_metrics.append(metrics)
667
+
668
+ # normalize eval metrics
669
+ eval_metrics = get_metrics(eval_metrics)
670
+
671
+ eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
672
+
673
+ try:
674
+ eval_metrics["perplexity"] = math.exp(eval_metrics["loss"])
675
+ except OverflowError:
676
+ eval_metrics["perplexity"] = float("inf")
677
+
678
+ # Print metrics and update progress bar
679
+ desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | Eval Perplexity: {eval_metrics['perplexity']})"
680
+ epochs.write(desc)
681
+ epochs.desc = desc
682
+
683
+ # Save metrics
684
+ if has_tensorboard and jax.process_index() == 0:
685
+ cur_step = epoch * (len(train_dataset) // train_batch_size)
686
+ write_eval_metric(summary_writer, eval_metrics, cur_step)
687
+
688
+ # save checkpoint after each epoch and push checkpoint to the hub
689
+ if jax.process_index() == 0:
690
+ params = jax.device_get(unreplicate(state.params))
691
+ model.save_pretrained(
692
+ training_args.output_dir,
693
+ params=params,
694
+ push_to_hub=training_args.push_to_hub,
695
+ commit_message=f"Saving weights and logs of epoch {epoch + 1}",
696
+ )
697
+
698
+
699
+ if __name__ == "__main__":
700
+ main()
src/run_config.sh ADDED
@@ -0,0 +1,13 @@
1
+ #!/bin/bash
2
+
3
+ export LC_ALL=C.UTF-8
4
+ export LANG=C.UTF-8
5
+
6
+ # export OUTPUT_DIR=./
7
+ export OUTPUT_DIR=/home/m3hrdadfi/code/gpt-neo-1.3B-persian
8
+ export NAME_OR_PATH=EleutherAI/gpt-neo-1.3B
9
+
10
+ python src/create_config.py \
11
+ --output_dir="$OUTPUT_DIR" \
12
+ --name_or_path="$NAME_OR_PATH" \
13
+ --params='{"vocab_size": 50000,"bos_token_id": 5, "eos_token_id": 5, "pad_token_id": 5}'
src/run_dataset.sh ADDED
@@ -0,0 +1,13 @@
1
+ #!/bin/bash
2
+
3
+ export LC_ALL=C.UTF-8
4
+ export LANG=C.UTF-8
5
+
6
+ export OUTPUT_DIR=/home/m3hrdadfi/data/
7
+ export DATASET_NAME=oscar
8
+ export DATASET_CONFIG_NAME=unshuffled_deduplicated_fa
9
+
10
+ python src/create_dataset.py \
11
+ --output_dir="$OUTPUT_DIR" \
12
+ --dataset_name="$DATASET_NAME" \
13
+ --dataset_config_name="$DATASET_CONFIG_NAME"
src/run_tokenizer.sh ADDED
@@ -0,0 +1,20 @@
1
+ #!/bin/bash
2
+
3
+ export LC_ALL=C.UTF-8
4
+ export LANG=C.UTF-8
5
+
6
+ export OUTPUT_DIR=/home/saied/code/gpt2-medium-persian
7
+ export DATASET_NAME=oscar
8
+ export DATASET_CONFIG_NAME=unshuffled_deduplicated_fa
9
+ export VOCAB_SIZE=50000
10
+ export MIN_FREQUENCY=2
11
+ export SPECIAL_TOKENS='<s>','<pad>','</s>','<unk>','<mask>','<|endoftext|>','<|startoftext|>','<sep>','<cls>','<nl>','<tab>','<zwnj>','[U1]','[U2]','[U3]','[U4]','[U5]','[U6]','[U7]','[U8]','[U9]','[U10]','[U11]','[U12]','[U13]','[U14]','[U15]','[U16]','[U17]','[U18]','[U19]','[U20]'
12
+
13
+
14
+ python src/train_tokenizer.py \
15
+ --output_dir="$OUTPUT_DIR" \
16
+ --dataset_name="$DATASET_NAME" \
17
+ --dataset_config_name="$DATASET_CONFIG_NAME" \
18
+ --vocab_size=$VOCAB_SIZE \
19
+ --min_frequency=$MIN_FREQUENCY \
20
+ --special_tokens="$SPECIAL_TOKENS"
src/train_tokenizer.py ADDED
@@ -0,0 +1,151 @@
1
+ import ast
2
+ import logging
3
+ import os
4
+ import sys
5
+ from dataclasses import dataclass, field
6
+ from typing import Dict, List, Optional, Tuple, Union, Any
7
+
8
+ from datasets import load_dataset
9
+ from tokenizers import ByteLevelBPETokenizer
10
+ from transformers import (
11
+ HfArgumentParser,
12
+ )
13
+
14
+ from data_utils import (
15
+ filter_by_lang_regex,
16
+ filter_by_num_tokens,
17
+ filter_by_num_sents,
18
+ filter_by_adv,
19
+ normalizer
20
+ )
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ @dataclass
26
+ class TokenizerArguments:
27
+ """
28
+ Arguments controlling how the tokenizer is set up and trained.
29
+ """
30
+
31
+ output_dir: str = field(
32
+ default=".",
33
+ metadata={"help": "The output directory where the config will be written."},
34
+ )
35
+ dataset_name: Optional[str] = field(
36
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
37
+ )
38
+ dataset_config_name: Optional[str] = field(
39
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
40
+ )
41
+ train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
42
+ cache_dir: Optional[str] = field(
43
+ default=None,
44
+ metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
45
+ )
46
+ special_tokens: Optional[str] = field(
47
+ default=None,
48
+ metadata={"help": "The list of special tokens that you want to add in your training."}
49
+ )
50
+ vocab_size: Optional[int] = field(
51
+ default=56000,
52
+ metadata={"help": "The size of the final vocabulary, including all tokens and alphabet"}
53
+ )
54
+ min_frequency: Optional[int] = field(
55
+ default=2,
56
+ metadata={"help": "The minimum frequency a pair should have in order to be merged"}
57
+ )
58
+ show_progress: Optional[bool] = field(
59
+ default=True,
60
+ metadata={"help": "Whether to show progress bars while training"}
61
+ )
62
+
63
+ def __post_init__(self):
64
+ if self.special_tokens is None:
65
+ special_tokens = [
66
+ "<s>", "<pad>", "</s>", "<unk>", "<mask>",
67
+ "<|endoftext|>", "<|startoftext|>",
68
+ "<sep>", "<cls>", "<nl>", "<tab>", "<zwnj>"
69
+ ]
70
+ special_tokens += [f"[U{i}]" for i in range(1, 21)]
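+ # [U1] .. [U20] are presumably reserved placeholder tokens, kept unused so they can be
+ # repurposed later without resizing the vocabulary.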
71
+ else:
72
+ special_tokens = list(self.special_tokens.split(","))
73
+
74
+ self.special_tokens = special_tokens
75
+ if self.dataset_name is None and self.train_file is None:
76
+ raise ValueError("Need either a dataset name or a training file.")
77
+ else:
78
+ if self.train_file is not None:
79
+ extension = self.train_file.split(".")[-1]
80
+ assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
81
+
82
+
83
+ def main():
84
+ parser = HfArgumentParser([TokenizerArguments])
85
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
86
+ # If we pass only one argument to the script and it's the path to a json file,
87
+ # let's parse it to get our arguments.
88
+ tokenizer_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))[0]
89
+ else:
90
+ tokenizer_args = parser.parse_args_into_dataclasses()[0]
91
+
92
+ # Setup logging
93
+ logging.basicConfig(
94
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
95
+ datefmt="%m/%d/%Y %H:%M:%S",
96
+ handlers=[logging.StreamHandler(sys.stdout)],
97
+ )
98
+ logger.setLevel(logging.INFO)
99
+
100
+ logger.info(f"Training tokenizer")
101
+
102
+ if tokenizer_args.dataset_name is not None:
103
+ raw_dataset = load_dataset(
104
+ tokenizer_args.dataset_name,
105
+ tokenizer_args.dataset_config_name,
106
+ cache_dir=tokenizer_args.cache_dir,
107
+ split="train"
108
+ )
109
+ else:
110
+ data_files = {"train": tokenizer_args.train_file}
111
+ extension = tokenizer_args.train_file.split(".")[-1]
112
+ if extension == "txt":
113
+ extension = "text"
114
+
115
+ raw_dataset = load_dataset(
116
+ extension,
117
+ data_files=data_files,
118
+ delimiter="\t",
119
+ cache_dir=tokenizer_args.cache_dir,
120
+ )
121
+
122
+ logger.info("Preprocessing the dataset")
123
+ dataset = raw_dataset.filter(lambda example: filter_by_lang_regex(example["text"], ratio=0.75))
124
+ dataset = dataset.filter(lambda example: filter_by_num_tokens(example["text"], gt=64))
125
+ dataset = dataset.filter(lambda example: filter_by_num_sents(example["text"], gt=2))
126
+ dataset = dataset.filter(lambda example: filter_by_adv(example["text"], ratio=50))
127
+ dataset = dataset.map(normalizer)
128
+ logger.info(f"Preprocessed dataset kept {len(dataset)} out of {len(raw_dataset)}")
129
+
130
+ tokenizer = ByteLevelBPETokenizer()
131
+
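+ # Streams the corpus to the trainer in chunks of `batch_size` examples so the whole dataset
+ # never has to be materialized as a single Python list.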
132
+ def batch_iterative(batch_size=1000):
133
+ for i in range(0, len(dataset), batch_size):
134
+ yield dataset[i: i + batch_size]["text"]
135
+
136
+ tokenizer.train_from_iterator(
137
+ batch_iterative(),
138
+ vocab_size=tokenizer_args.vocab_size,
139
+ special_tokens=tokenizer_args.special_tokens,
140
+ min_frequency=tokenizer_args.min_frequency,
141
+ show_progress=tokenizer_args.show_progress,
142
+ )
143
+
144
+ logger.info(f"Your tokenizer saved here {tokenizer_args.output_dir}")
145
+ os.makedirs(tokenizer_args.output_dir, exist_ok=True)
146
+ tokenizer.save_model(tokenizer_args.output_dir)
147
+ tokenizer.save(f"{tokenizer_args.output_dir}/tokenizer.json", pretty=True)
148
+
149
+
150
+ if __name__ == '__main__':
151
+ main()
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
vocab.json ADDED
The diff for this file is too large to render. See raw diff