ydshieh committed
Commit
2c5a28b
1 Parent(s): 16517d8

update to be as a base

Files changed (1)
  1. run_image_captioning_flax_reduced.py +128 -64
run_image_captioning_flax_reduced.py CHANGED
@@ -32,8 +32,8 @@ import datasets
 import nltk  # Here to have a nice missing dependency error message early on
 import numpy as np
 from datasets import Dataset, load_dataset, load_metric
-from tqdm import tqdm
 from PIL import Image
+from tqdm import tqdm
 
 import jax
 import jax.numpy as jnp
@@ -47,14 +47,14 @@ from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
 from huggingface_hub import Repository
 from transformers import (
     CONFIG_MAPPING,
-    FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING,
     FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
+    FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING,
     AutoConfig,
     AutoFeatureExtractor,
     AutoTokenizer,
+    FlaxAutoModelForVision2Seq,
     HfArgumentParser,
     is_tensorboard_available,
-    FlaxAutoModelForVision2Seq,
 )
 from transformers.file_utils import get_full_repo_name, is_offline_mode
 
@@ -113,8 +113,7 @@ class TrainingArguments:
     per_device_eval_batch_size: int = field(
         default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
     )
-    _block_size_doc = \
-    """
+    _block_size_doc = """
     The default value `0` will preprocess (tokenization + feature extraction) the whole dataset before training and
     cache the results. This uses more disk space, but avoids (repeated) processing time during training. This is a
     good option if your disk space is large enough to store the whole processed dataset.
@@ -124,10 +123,7 @@ class TrainingArguments:
    `batch_size` are yielded before processing the next block. This could avoid the heavy disk usage when the
    dataset is large.
    """
-    block_size: int = field(
-        default=0,
-        metadata={"help": _block_size_doc}
-    )
+    block_size: int = field(default=0, metadata={"help": _block_size_doc})
     learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."})
     weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."})
     adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"})
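
Note on the `block_size` hunk above: the help text now lives in a single `_block_size_doc` string, and the behavior it describes can be summarized with a minimal sketch (the `process` helper is hypothetical, not the script's actual `blockwise_data_loader`):

    def blockwise_batches(samples, block_size, batch_size, process):
        # `block_size` must be a multiple of `batch_size`; the script enforces
        # the same invariant against the global train batch size further down.
        assert block_size % batch_size == 0
        for start in range(0, len(samples), block_size):
            # Run the expensive preprocessing (tokenization + feature
            # extraction) on one block only, instead of the whole dataset.
            block = process(samples[start : start + block_size])
            # Serve the processed block as `block_size // batch_size` batches.
            for i in range(0, len(block), batch_size):
                yield block[i : i + batch_size]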
@@ -197,16 +193,21 @@ class ModelArguments:
         },
     )
     model_type: Optional[str] = field(
-        default='vision-encoder-decoder',
-        metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)}
+        default="vision-encoder-decoder",
+        metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
     )
     encoder_model_type: Optional[str] = field(
         default=None,
-        metadata={"help": "If training from scratch, pass a vision encoder model type from the library. For example, 'vit'"}
+        metadata={
+            "help": "If training from scratch, pass a vision encoder model type from the library. For example, 'vit'"
+        },
     )
     decoder_model_type: Optional[str] = field(
         default=None,
-        metadata={"help": "If training from scratch, pass a decoder model type from the list: " + ", ".join(DECODER_MODEL_TYPES)}
+        metadata={
+            "help": "If training from scratch, pass a decoder model type from the list: "
+            + ", ".join(DECODER_MODEL_TYPES)
+        },
     )
     config_name: Optional[str] = field(
         default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
@@ -218,10 +219,12 @@ class ModelArguments:
         default=None, metadata={"help": "Pretrained decoder config name or path if not the same as decoder_model_name"}
     )
     feature_extractor_name: Optional[str] = field(
-        default=None, metadata={"help": "Pretrained encoder feature extractor_name or path if not the same as encoder_model_name"}
+        default=None,
+        metadata={"help": "Pretrained encoder feature extractor_name or path if not the same as encoder_model_name"},
     )
     tokenizer_name: Optional[str] = field(
-        default=None, metadata={"help": "Pretrained decoder tokenizer name or path if not the same as decoder_model_name"}
+        default=None,
+        metadata={"help": "Pretrained decoder tokenizer name or path if not the same as decoder_model_name"},
     )
     cache_dir: Optional[str] = field(
         default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
@@ -505,7 +508,7 @@ def main():
     # Use specified `model_type` (default to `vision-encoder-decoder`)
     else:
 
-        if not model_args.model_type in MODEL_TYPES:
+        if model_args.model_type not in MODEL_TYPES:
             raise ValueError(
                 f"Unrecognized model identifier: {model_args.model_type}. Should contain one of {', '.join(MODEL_TYPES)}."
             )
@@ -516,29 +519,41 @@ def main():
 
         # Use explicit specified encoder config
         if model_args.encoder_config_name:
-            encoder_config = AutoConfig.from_pretrained(model_args.encoder_config_name, cache_dir=encoder_cache_dir)
+            encoder_config = AutoConfig.from_pretrained(
+                model_args.encoder_config_name, cache_dir=encoder_cache_dir
+            )
         # Use pretrained encoder model's config
         elif model_args.encoder_model_name_or_path:
-            encoder_config = AutoConfig.from_pretrained(model_args.encoder_model_name_or_path, cache_dir=encoder_cache_dir)
+            encoder_config = AutoConfig.from_pretrained(
+                model_args.encoder_model_name_or_path, cache_dir=encoder_cache_dir
+            )
         # Use specified encoder model type
         elif model_args.encoder_model_type:
            encoder_config = AutoConfig.for_model(model_args.encoder_model_type)
            logger.warning("You are instantiating a new config instance from scratch for the encoder.")
         else:
-            raise ValueError("Encoder Config: if pretrained config or model location is not provided, `encoder_model_type` is required.")
+            raise ValueError(
+                "Encoder Config: if pretrained config or model location is not provided, `encoder_model_type` is required."
+            )
 
         # Use explicit specified decoder config
         if model_args.decoder_config_name:
-            decoder_config = AutoConfig.from_pretrained(model_args.decoder_config_name, cache_dir=decoder_cache_dir)
+            decoder_config = AutoConfig.from_pretrained(
+                model_args.decoder_config_name, cache_dir=decoder_cache_dir
+            )
         # Use pretrained decoder model's config
         elif model_args.decoder_model_name_or_path:
-            decoder_config = AutoConfig.from_pretrained(model_args.decoder_model_name_or_path, cache_dir=decoder_cache_dir)
+            decoder_config = AutoConfig.from_pretrained(
+                model_args.decoder_model_name_or_path, cache_dir=decoder_cache_dir
+            )
         # Use specified decoder model type
         elif model_args.decoder_model_type:
            decoder_config = AutoConfig.for_model(model_args.decoder_model_type)
            logger.warning("You are instantiating a new config instance from scratch for the decoder.")
         else:
-            raise ValueError("Decoder Config: if pretrained config or model location is not provided, `decoder_model_type` is required.")
+            raise ValueError(
+                "Decoder Config: if pretrained config or model location is not provided, `decoder_model_type` is required."
+            )
 
         logger.info("Setting `config.is_decoder=True` and `config.add_cross_attention=True` for decoder_config")
         decoder_config.is_decoder = True
@@ -586,7 +601,9 @@ def main():
         )
     else:
         # model_class = FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING[config.__class__]
-        model = FlaxAutoModelForVision2Seq.from_config(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
+        model = FlaxAutoModelForVision2Seq.from_config(
+            config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
+        )
         model_class = model.__class__
 
         # encoder_class = FlaxAutoModel
@@ -604,10 +621,12 @@ def main():
                 model_args.encoder_model_name_or_path,
                 config=config.encoder,
                 seed=training_args.seed,
-                dtype=getattr(jnp, model_args.dtype)
+                dtype=getattr(jnp, model_args.dtype),
             )
         else:
-            encoder = encoder_class(config=config.encoder, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
+            encoder = encoder_class(
+                config=config.encoder, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
+            )
             logger.warning("You are instantiating a new model instance from scratch for the encoder.")
 
         if model_args.decoder_model_name_or_path:
@@ -615,10 +634,12 @@ def main():
                 model_args.decoder_model_name_or_path,
                 config=config.decoder,
                 seed=training_args.seed,
-                dtype=getattr(jnp, model_args.dtype)
+                dtype=getattr(jnp, model_args.dtype),
             )
         else:
-            decoder = decoder_class(config=config.decoder, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
+            decoder = decoder_class(
+                config=config.decoder, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
+            )
             logger.warning("You are instantiating a new model instance from scratch for the decoder.")
 
         model = model_class.from_encoder_decoder_pretrained(
@@ -646,7 +667,8 @@ def main():
     feature_extractor = None
     if model_args.feature_extractor_name:
         feature_extractor = AutoFeatureExtractor.from_pretrained(
-            model_args.feature_extractor_name, cache_dir=model_args.cache_dir,
+            model_args.feature_extractor_name,
+            cache_dir=model_args.cache_dir,
         )
     elif model_args.model_name_or_path:
         try:
@@ -684,7 +706,9 @@ def main():
     if not tokenizer:
         if model_args.decoder_model_name_or_path:
             tokenizer = AutoTokenizer.from_pretrained(
-                model_args.decoder_model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
+                model_args.decoder_model_name_or_path,
+                cache_dir=model_args.cache_dir,
+                use_fast=model_args.use_fast_tokenizer,
             )
         else:
             raise ValueError(
@@ -739,9 +763,9 @@ def main():
        for image_file in examples[image_column]:
            try:
                image = Image.open(image_file)
-                encoder_inputs = feature_extractor(images=image, return_tensors="np")
+                feature_extractor(images=image, return_tensors="np")
                bools.append(True)
-            except:
+            except Exception:
                bools.append(False)
 
        return bools
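
The switch from a bare `except:` to `except Exception:` in `filter_fn` is more than style: a bare `except` also catches `BaseException` subclasses such as `KeyboardInterrupt` and `SystemExit`, so interrupting a run could silently mark images as invalid instead of stopping. A minimal self-contained sketch of the filter's logic under that fix:

    from PIL import Image

    def is_readable(path):
        # Catches PIL/IO errors for corrupt or missing files, but lets
        # KeyboardInterrupt and SystemExit propagate.
        try:
            Image.open(path)
            return True
        except Exception:
            return False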
@@ -752,7 +776,7 @@ def main():
 
        captions = []
        for caption in examples[caption_column]:
-            captions.append(caption.lower() + ' ' + tokenizer.eos_token)
+            captions.append(caption.lower() + " " + tokenizer.eos_token)
 
        targets = captions
 
@@ -795,7 +819,7 @@ def main():
                img = Image.open(image_file)
                images.append(img)
                to_keep.append(True)
-            except:
+            except Exception:
                to_keep.append(False)
 
        for k, v in examples.items():
@@ -831,9 +855,11 @@ def main():
                ),
                dtype="float32",
            ),
-            "labels": datasets.Sequence(feature=datasets.Value(dtype='int32', id=None), length=-1, id=None),
-            "decoder_input_ids": datasets.Sequence(feature=datasets.Value(dtype='int32', id=None), length=-1, id=None),
-            "decoder_attention_mask": datasets.Sequence(feature=datasets.Value(dtype='int32', id=None), length=-1, id=None),
+            "labels": datasets.Sequence(feature=datasets.Value(dtype="int32", id=None), length=-1, id=None),
+            "decoder_input_ids": datasets.Sequence(feature=datasets.Value(dtype="int32", id=None), length=-1, id=None),
+            "decoder_attention_mask": datasets.Sequence(
+                feature=datasets.Value(dtype="int32", id=None), length=-1, id=None
+            ),
        }
    )
 
@@ -909,7 +935,9 @@ def main():
        # (if feature extraction is performed at the beginning, the filtering is done during preprocessing below
        # instead here.)
        if not run_feat_ext_at_beginning:
-            predict_dataset = predict_dataset.filter(filter_fn, batched=True, num_proc=data_args.preprocessing_num_workers)
+            predict_dataset = predict_dataset.filter(
+                filter_fn, batched=True, num_proc=data_args.preprocessing_num_workers
+            )
        predict_dataset = predict_dataset.map(
            function=function_kwarg,
            batched=True,
@@ -930,7 +958,9 @@ def main():
    train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
 
    if training_args.block_size % train_batch_size > 0:
-        raise ValueError(f"`training_args.block_size` needs to be a multiple of the global batch size. Got {training_args.block_size} and {train_batch_size} instead.")
+        raise ValueError(
+            f"`training_args.block_size` needs to be a multiple of the global batch size. Got {training_args.block_size} and {train_batch_size} instead."
+        )
 
    if training_args.do_train:
        steps_per_epoch = len(train_dataset) // train_batch_size
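
For a concrete sense of the check above: the global batch size is `per_device_train_batch_size * jax.device_count()`, so on 8 devices with a per-device batch size of 16 it is 128, and `block_size` must then be a multiple of 128. A worked example (numbers are illustrative):

    per_device_train_batch_size = 16
    device_count = 8  # e.g. jax.device_count() on a TPU v3-8
    train_batch_size = per_device_train_batch_size * device_count  # 128
    assert 1024 % train_batch_size == 0  # a valid block_size
    assert 1000 % train_batch_size != 0  # would trigger the ValueError above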
@@ -951,13 +981,13 @@ def main():
        test_steps = num_test_examples // eval_batch_size
 
    def blockwise_data_loader(
-        rng: jax.random.PRNGKey,
-        ds: Dataset,
-        block_size: int,
-        batch_size: int,
-        shuffle: bool = False,
-        keep_in_memory: bool = False,
-        split: str = ""
+        rng: jax.random.PRNGKey,
+        ds: Dataset,
+        block_size: int,
+        batch_size: int,
+        shuffle: bool = False,
+        keep_in_memory: bool = False,
+        split: str = "",
    ):
        """
        Wrap the simple `data_loader` in a block-wise way if `block_size` > 0, else it's the same as `data_loader`.
1165
 
1166
  def generate_step(params, batch):
1167
  model.params = params
1168
- output_ids = model.generate(batch['pixel_values'], **gen_kwargs)
1169
  return output_ids.sequences
1170
 
1171
  # Create parallel version of the train and eval step
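
The `# Create parallel version ...` context line refers to the usual Flax-examples pattern of wrapping steps in `jax.pmap`. A self-contained toy of that shard -> pmap -> merge flow (a sketch of the pattern with a stand-in `step`, not this script's actual parallelized `generate_step`):

    import jax
    import jax.numpy as jnp
    import numpy as np

    def step(params, batch):
        # Stand-in for `generate_step`: runs independently on each device.
        return batch * params["scale"]

    p_step = jax.pmap(step, axis_name="batch")

    n_dev = jax.local_device_count()
    batch = np.arange(n_dev * 4, dtype=np.float32).reshape(n_dev, 4)  # sharded batch
    params = jax.device_put_replicated({"scale": jnp.float32(2.0)}, jax.local_devices())
    out = p_step(params, batch)            # shape (n_dev, 4)
    merged = np.asarray(out).reshape(-1)   # fold the device dim back into the batch dim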
@@ -1212,7 +1242,13 @@ def main():
        if training_args.push_to_hub:
            repo.push_to_hub(commit_message=commit_msg, blocking=False)
 
-    def evaluation_loop(rng: jax.random.PRNGKey, dataset: Dataset, metric_key_prefix: str = "eval", ckpt_dir: str = "", is_prediction=False):
+    def evaluation_loop(
+        rng: jax.random.PRNGKey,
+        dataset: Dataset,
+        metric_key_prefix: str = "eval",
+        ckpt_dir: str = "",
+        is_prediction=False,
+    ):
 
        logger.info(f"*** {'Predict' if is_prediction else 'Evaluate'} ***")
 
@@ -1230,7 +1266,9 @@ def main():
            split="prediction" if is_prediction else "validation",
        )
        steps = len(dataset) // eval_batch_size
-        for _ in tqdm(range(steps), desc=f"{'Predicting' if is_prediction else 'Evaluating'}...", position=2, leave=False):
+        for _ in tqdm(
+            range(steps), desc=f"{'Predicting' if is_prediction else 'Evaluating'}...", position=2, leave=False
+        ):
            # Model forward
            batch = next(batches)
            _labels = batch.get("labels", None)
@@ -1260,7 +1298,12 @@ def main():
        if labels:
            rouge_metrics, decoded_preds, decoded_labels = compute_metrics(preds, labels)
            metrics.update(rouge_metrics)
-            rouge_desc = " ".join([f"{'Predict' if is_prediction else 'Eval'} {key}: {value} |" for key, value in rouge_metrics.items()])
+            rouge_desc = " ".join(
+                [
+                    f"{'Predict' if is_prediction else 'Eval'} {key}: {value} |"
+                    for key, value in rouge_metrics.items()
+                ]
+            )
            for pred, label in zip(decoded_preds, decoded_labels):
                pred = pred.replace("\n", " ")
                label = label.replace("\n", " ")
@@ -1293,28 +1336,37 @@ def main():
 
        # Save metrics (only for the evaluation/prediction being done along with training)
        if has_tensorboard and training_args.do_train:
-            write_metric(summary_writer, metrics, train_time=None, step=cur_step, metric_key_prefix=metric_key_prefix)
+            write_metric(
+                summary_writer, metrics, train_time=None, step=cur_step, metric_key_prefix=metric_key_prefix
+            )
 
        # save final metrics in json
-        metrics = {f"{metric_key_prefix}_{metric_name}": round(value.item(), 6) for metric_name, value in metrics.items()}
+        metrics = {
+            f"{metric_key_prefix}_{metric_name}": round(value.item(), 6)
+            for metric_name, value in metrics.items()
+        }
        _path = os.path.join(training_args.output_dir, ckpt_dir, f"{metric_key_prefix}_results.json")
        with open(_path, "w") as f:
            json.dump(metrics, f, indent=4, sort_keys=True)
 
        # Update report
-        with open(os.path.join(training_args.output_dir, 'log'), 'a', encoding='UTF-8') as fp:
-            fp.write(desc + '\n')
+        with open(os.path.join(training_args.output_dir, "log"), "a", encoding="UTF-8") as fp:
+            fp.write(desc + "\n")
 
        # Save generations
        if generations:
-            with open(os.path.join(training_args.output_dir, ckpt_dir, f'{metric_key_prefix}_generation.json'), 'w', encoding='UTF-8') as fp:
+            with open(
+                os.path.join(training_args.output_dir, ckpt_dir, f"{metric_key_prefix}_generation.json"),
+                "w",
+                encoding="UTF-8",
+            ) as fp:
                json.dump(generations, fp, ensure_ascii=False, indent=4)
 
    def evaluate(rng: jax.random.PRNGKey, dataset: Dataset, ckpt_dir: str = ""):
-        evaluation_loop(rng, dataset, metric_key_prefix='eval', ckpt_dir=ckpt_dir)
+        evaluation_loop(rng, dataset, metric_key_prefix="eval", ckpt_dir=ckpt_dir)
 
    def predict(rng: jax.random.PRNGKey, dataset: Dataset):
-        evaluation_loop(rng, dataset, metric_key_prefix='test', is_prediction=True)
+        evaluation_loop(rng, dataset, metric_key_prefix="test", is_prediction=True)
 
    input_rng = None
 
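One detail in the reflowed metrics dict above: the `value.item()` call matters because the metric values are JAX/NumPy scalars, which the `json` module cannot serialize; `.item()` converts them to plain Python floats first. A quick illustration:

    import json
    import numpy as np

    value = np.float32(0.1234567)
    # json.dumps(value) raises TypeError: float32 is not JSON serializable
    print(json.dumps(round(value.item(), 6)))  # -> 0.123457
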
@@ -1340,7 +1392,7 @@ def main():
                batch_size=train_batch_size,
                keep_in_memory=True,
                shuffle=True,
-                split="train"
+                split="train",
            )
 
            # train
@@ -1364,16 +1416,26 @@ def main():
 
                    logger.info(desc)
 
-                    with open(os.path.join(training_args.output_dir, 'log'), 'a', encoding='UTF-8') as fp:
-                        fp.write(desc + '\n')
+                    with open(os.path.join(training_args.output_dir, "log"), "a", encoding="UTF-8") as fp:
+                        fp.write(desc + "\n")
 
                    # Save metrics
                    if has_tensorboard and jax.process_index() == 0:
-                        write_metric(summary_writer, train_metrics, train_time=train_time, step=cur_step, metric_key_prefix="train")
+                        write_metric(
+                            summary_writer,
+                            train_metrics,
+                            train_time=train_time,
+                            step=cur_step,
+                            metric_key_prefix="train",
+                        )
 
                # ======================== Evaluating (inside an epoch) ==============================
 
-                if training_args.do_eval and (training_args.eval_steps is not None and training_args.eval_steps > 0) and cur_step % training_args.eval_steps == 0:
+                if (
+                    training_args.do_eval
+                    and (training_args.eval_steps is not None and training_args.eval_steps > 0)
+                    and cur_step % training_args.eval_steps == 0
+                ):
                    ckpt_dir = f"ckpt_epoch_{epoch + 1}_step_{cur_step}"
                    commit_msg = f"Saving weights and logs of epoch {epoch + 1} - step {cur_step}"
                    evaluate(input_rng, eval_dataset, ckpt_dir)
@@ -1386,12 +1448,14 @@ def main():
 
            logger.info(desc)
 
-            with open(os.path.join(training_args.output_dir, 'log'), 'a', encoding='UTF-8') as fp:
-                fp.write(desc + '\n')
+            with open(os.path.join(training_args.output_dir, "log"), "a", encoding="UTF-8") as fp:
+                fp.write(desc + "\n")
 
            # Save metrics
            if has_tensorboard and jax.process_index() == 0:
-                write_metric(summary_writer, train_metrics, train_time=train_time, step=cur_step, metric_key_prefix="train")
+                write_metric(
+                    summary_writer, train_metrics, train_time=train_time, step=cur_step, metric_key_prefix="train"
+                )
 
        # ======================== Evaluating (after each epoch) ==============================
 