versae commited on
Commit
a38611e
1 Parent(s): 1dc4fb8

Fixes and eval configs

Browse files
evaluation/paws.yaml CHANGED
@@ -36,8 +36,6 @@ parameters:
36
  value: ./outputs
37
  overwrite_output_dir:
38
  value: true
39
- resume_from_checkpoint:
40
- value: false
41
  max_seq_length:
42
  value: 512
43
  pad_to_max_length:
36
  value: ./outputs
37
  overwrite_output_dir:
38
  value: true
 
 
39
  max_seq_length:
40
  value: 512
41
  pad_to_max_length:
evaluation/token.yaml ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: BERTIN NER and POS es
2
+ project: bertin-eval
3
+ enitity: versae
4
+ program: run_ner.py
5
+ command:
6
+ - ${env}
7
+ - ${interpreter}
8
+ - ${program}
9
+ - ${args}
10
+ method: grid
11
+ metric:
12
+ name: eval/accuracy
13
+ goal: maximize
14
+ parameters:
15
+ model_name_or_path:
16
+ values:
17
+ - bertin-project/bertin-base-gaussian-exp-512seqlen
18
+ - bertin-project/bertin-base-random-exp-512seqlen
19
+ - bertin-project/bertin-base-gaussian
20
+ - bertin-project/bertin-base-stepwise
21
+ - bertin-project/bertin-base-random
22
+ - bertin-project/bertin-roberta-base-spanish
23
+ - flax-community/bertin-roberta-large-spanish
24
+ - BSC-TeMU/roberta-base-bne
25
+ - dccuchile/bert-base-spanish-wwm-cased
26
+ - bert-base-multilingual-cased
27
+ num_train_epochs:
28
+ values: [5]
29
+ task_name:
30
+ values:
31
+ - ner
32
+ - pos
33
+ dataset_name:
34
+ value: conll2002
35
+ dataset_config_name:
36
+ value: es
37
+ output_dir:
38
+ value: ./outputs
39
+ overwrite_output_dir:
40
+ value: true
41
+ pad_to_max_length:
42
+ value: true
43
+ per_device_train_batch_size:
44
+ value: 16
45
+ per_device_eval_batch_size:
46
+ value: 16
47
+ save_total_limit:
48
+ value: 1
49
+ do_train:
50
+ value: true
51
+ do_eval:
52
+ value: true
53
+
evaluation/xnli.yaml CHANGED
@@ -36,8 +36,6 @@ parameters:
36
  value: ./outputs
37
  overwrite_output_dir:
38
  value: true
39
- resume_from_checkpoint:
40
- value: false
41
  max_seq_length:
42
  value: 512
43
  pad_to_max_length:
36
  value: ./outputs
37
  overwrite_output_dir:
38
  value: true
 
 
39
  max_seq_length:
40
  value: 512
41
  pad_to_max_length:
images/bertin.png CHANGED
run_mlm_flax_stream.py CHANGED
@@ -384,8 +384,8 @@ def to_f32(t):
384
 
385
 
386
  def convert(output_dir, destination_dir="./"):
387
- shutil.copyfile(Path(output_dir) / "flax_model.msgpack", destination_dir)
388
- shutil.copyfile(Path(output_dir) / "config.json", destination_dir)
389
  # Saving extra files from config.json and tokenizer.json files
390
  tokenizer = AutoTokenizer.from_pretrained(destination_dir)
391
  tokenizer.save_pretrained(destination_dir)
@@ -611,8 +611,8 @@ if __name__ == "__main__":
611
 
612
  # Setup train state
613
  state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw)
614
- saved_step = 0
615
- if "checkpoint" in model_args.model_name_or_path:
616
  params, opt_state, saved_step, args, data_collator = restore_checkpoint(model_args.model_name_or_path, state)
617
  # Create learning rate schedule
618
  warmup_fn = optax.linear_schedule(
@@ -714,8 +714,9 @@ if __name__ == "__main__":
714
  max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
715
  eval_samples = advance_iter_and_group_samples(training_iter, data_args.num_eval_samples, max_seq_length)
716
 
 
717
  steps = tqdm(range(num_train_steps), desc="Training...", position=0)
718
- for step in range(saved_step, num_train_steps):
719
  if step < saved_step:
720
  steps.update(1)
721
  continue
@@ -827,5 +828,5 @@ if __name__ == "__main__":
827
  training_args.output_dir,
828
  params=params,
829
  push_to_hub=training_args.push_to_hub,
830
- commit_message=last_desc,
831
  )
384
 
385
 
386
  def convert(output_dir, destination_dir="./"):
387
+ shutil.copyfile(Path(output_dir) / "flax_model.msgpack", Path(destination_dir) / "flax_model.msgpack")
388
+ shutil.copyfile(Path(output_dir) / "config.json", Path(destination_dir) / "config.json")
389
  # Saving extra files from config.json and tokenizer.json files
390
  tokenizer = AutoTokenizer.from_pretrained(destination_dir)
391
  tokenizer.save_pretrained(destination_dir)
611
 
612
  # Setup train state
613
  state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw)
614
+ saved_step = -1
615
+ if model_args.model_name_or_path and "checkpoint" in model_args.model_name_or_path:
616
  params, opt_state, saved_step, args, data_collator = restore_checkpoint(model_args.model_name_or_path, state)
617
  # Create learning rate schedule
618
  warmup_fn = optax.linear_schedule(
714
  max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
715
  eval_samples = advance_iter_and_group_samples(training_iter, data_args.num_eval_samples, max_seq_length)
716
 
717
+ last_desc = ""
718
  steps = tqdm(range(num_train_steps), desc="Training...", position=0)
719
+ for step in range(num_train_steps):
720
  if step < saved_step:
721
  steps.update(1)
722
  continue
828
  training_args.output_dir,
829
  params=params,
830
  push_to_hub=training_args.push_to_hub,
831
+ commit_message=last_desc or "Saving model after training",
832
  )
utils/download_mc4es_sampled.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import gzip
3
+ import json
4
+ import sys
5
+
6
+ import requests
7
+ from tqdm import tqdm
8
+
9
+ _DATA_URL_TRAIN = "https://huggingface.co/datasets/bertin-project/mc4-es-sampled/resolve/main/mc4-es-train-50M-{config}-shard-{index:04d}-of-{n_shards:04d}.json.gz"
10
+
11
+
12
+ def main(config="stepwise"):
13
+ data_urls = [
14
+ _DATA_URL_TRAIN.format(
15
+ config=config,
16
+ index=index + 1,
17
+ n_shards=1024,
18
+ )
19
+ for index in range(1024)
20
+ ]
21
+ with open(f"mc4-es-train-50M-{config}.jsonl", "w") as f:
22
+ for dara_url in tqdm(data_urls):
23
+ response = requests.get(dara_url)
24
+ bio = io.BytesIO(response.content)
25
+ with gzip.open(bio, "rt", encoding="utf8") as g:
26
+ for line in g:
27
+ json_line = json.loads(line.strip())
28
+ f.write(json.dumps(json_line) + "\n")
29
+
30
+
31
+ if __name__ == "__main__":
32
+ main(sys.argv[1])