Fixes and eval configs
- evaluation/paws.yaml +0 -2
- evaluation/token.yaml +53 -0
- evaluation/xnli.yaml +0 -2
- images/bertin.png +0 -0
- run_mlm_flax_stream.py +7 -6
- utils/download_mc4es_sampled.py +32 -0
evaluation/paws.yaml
CHANGED
@@ -36,8 +36,6 @@ parameters:
     value: ./outputs
   overwrite_output_dir:
     value: true
-  resume_from_checkpoint:
-    value: false
   max_seq_length:
     value: 512
   pad_to_max_length:
evaluation/token.yaml
ADDED
@@ -0,0 +1,53 @@
+name: BERTIN NER and POS es
+project: bertin-eval
+entity: versae
+program: run_ner.py
+command:
+  - ${env}
+  - ${interpreter}
+  - ${program}
+  - ${args}
+method: grid
+metric:
+  name: eval/accuracy
+  goal: maximize
+parameters:
+  model_name_or_path:
+    values:
+      - bertin-project/bertin-base-gaussian-exp-512seqlen
+      - bertin-project/bertin-base-random-exp-512seqlen
+      - bertin-project/bertin-base-gaussian
+      - bertin-project/bertin-base-stepwise
+      - bertin-project/bertin-base-random
+      - bertin-project/bertin-roberta-base-spanish
+      - flax-community/bertin-roberta-large-spanish
+      - BSC-TeMU/roberta-base-bne
+      - dccuchile/bert-base-spanish-wwm-cased
+      - bert-base-multilingual-cased
+  num_train_epochs:
+    values: [5]
+  task_name:
+    values:
+      - ner
+      - pos
+  dataset_name:
+    value: conll2002
+  dataset_config_name:
+    value: es
+  output_dir:
+    value: ./outputs
+  overwrite_output_dir:
+    value: true
+  pad_to_max_length:
+    value: true
+  per_device_train_batch_size:
+    value: 16
+  per_device_eval_batch_size:
+    value: 16
+  save_total_limit:
+    value: 1
+  do_train:
+    value: true
+  do_eval:
+    value: true
+
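For reference, a minimal sketch (not part of this commit) of registering and running a sweep file like the one above through the wandb Python API; it assumes wandb and PyYAML are installed and that you are logged in to Weights & Biases:

import yaml
import wandb

# Load the sweep definition added in this commit; project and entity
# are taken straight from the YAML itself.
with open("evaluation/token.yaml") as f:
    sweep_config = yaml.safe_load(f)

# Register the grid sweep, then start an agent that claims runs from it;
# each run invokes run_ner.py with one parameter combination.
sweep_id = wandb.sweep(sweep_config, entity="versae", project="bertin-eval")
wandb.agent(sweep_id)

Equivalently from the CLI: `wandb sweep evaluation/token.yaml` followed by `wandb agent <sweep_id>`.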
evaluation/xnli.yaml
CHANGED
@@ -36,8 +36,6 @@ parameters:
     value: ./outputs
   overwrite_output_dir:
     value: true
-  resume_from_checkpoint:
-    value: false
   max_seq_length:
     value: 512
   pad_to_max_length:
images/bertin.png
CHANGED (binary file not shown)
run_mlm_flax_stream.py
CHANGED
@@ -384,8 +384,8 @@ def to_f32(t):
 
 
 def convert(output_dir, destination_dir="./"):
-    shutil.copyfile(Path(output_dir) / "flax_model.msgpack", destination_dir)
-    shutil.copyfile(Path(output_dir) / "config.json", destination_dir)
+    shutil.copyfile(Path(output_dir) / "flax_model.msgpack", Path(destination_dir) / "flax_model.msgpack")
+    shutil.copyfile(Path(output_dir) / "config.json", Path(destination_dir) / "config.json")
     # Saving extra files from config.json and tokenizer.json files
     tokenizer = AutoTokenizer.from_pretrained(destination_dir)
     tokenizer.save_pretrained(destination_dir)

@@ -611,8 +611,8 @@ if __name__ == "__main__":
 
     # Setup train state
     state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw)
-    saved_step =
-    if "checkpoint" in model_args.model_name_or_path:
+    saved_step = -1
+    if model_args.model_name_or_path and "checkpoint" in model_args.model_name_or_path:
         params, opt_state, saved_step, args, data_collator = restore_checkpoint(model_args.model_name_or_path, state)
     # Create learning rate schedule
     warmup_fn = optax.linear_schedule(

@@ -714,8 +714,9 @@ if __name__ == "__main__":
     max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
     eval_samples = advance_iter_and_group_samples(training_iter, data_args.num_eval_samples, max_seq_length)
 
+    last_desc = ""
     steps = tqdm(range(num_train_steps), desc="Training...", position=0)
-    for step in range(
+    for step in range(num_train_steps):
         if step < saved_step:
             steps.update(1)
             continue

@@ -827,5 +828,5 @@ if __name__ == "__main__":
             training_args.output_dir,
             params=params,
             push_to_hub=training_args.push_to_hub,
-            commit_message=last_desc,
+            commit_message=last_desc or "Saving model after training",
         )
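A note on the convert() change above: shutil.copyfile() requires the destination to be a complete file name, and handing it a directory (as the old lines did) raises an OSError. A minimal sketch of the difference, with illustrative paths:

import shutil
from pathlib import Path

src = Path("outputs") / "config.json"
# shutil.copyfile(src, "./")  # fails: destination must be a file path, not a directory
shutil.copyfile(src, Path(".") / "config.json")  # works: explicit target file name
# shutil.copy(src, ".") would also work, since shutil.copy accepts a directory.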
utils/download_mc4es_sampled.py
ADDED
@@ -0,0 +1,32 @@
+import io
+import gzip
+import json
+import sys
+
+import requests
+from tqdm import tqdm
+
+_DATA_URL_TRAIN = "https://huggingface.co/datasets/bertin-project/mc4-es-sampled/resolve/main/mc4-es-train-50M-{config}-shard-{index:04d}-of-{n_shards:04d}.json.gz"
+
+
+def main(config="stepwise"):
+    data_urls = [
+        _DATA_URL_TRAIN.format(
+            config=config,
+            index=index + 1,
+            n_shards=1024,
+        )
+        for index in range(1024)
+    ]
+    with open(f"mc4-es-train-50M-{config}.jsonl", "w") as f:
+        for data_url in tqdm(data_urls):
+            response = requests.get(data_url)
+            bio = io.BytesIO(response.content)
+            with gzip.open(bio, "rt", encoding="utf8") as g:
+                for line in g:
+                    json_line = json.loads(line.strip())
+                    f.write(json.dumps(json_line) + "\n")
+
+
+if __name__ == "__main__":
+    main(sys.argv[1] if len(sys.argv) > 1 else "stepwise")
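Usage is e.g. `python utils/download_mc4es_sampled.py gaussian`, which writes mc4-es-train-50M-gaussian.jsonl to the working directory. Since each shard is buffered whole via response.content before decompression, a hedged alternative (not in the commit; the helper name append_shard is illustrative) is to stream the gzip payload instead, using requests' stream=True and the file-like response.raw:

import gzip
import json

import requests

def append_shard(url, out_file):
    # Stream the shard instead of buffering response.content in memory.
    with requests.get(url, stream=True) as response:
        response.raise_for_status()
        # response.raw is file-like, so gzip can decompress it on the fly.
        with gzip.open(response.raw, "rt", encoding="utf8") as g:
            for line in g:
                out_file.write(json.dumps(json.loads(line.strip())) + "\n")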