ydshieh
committed on
Commit
•
2c5a28b
1
Parent(s):
16517d8
update to be as a base
Browse files
- run_image_captioning_flax_reduced.py +128 -64
run_image_captioning_flax_reduced.py
CHANGED
@@ -32,8 +32,8 @@ import datasets
|
|
32 |
import nltk # Here to have a nice missing dependency error message early on
|
33 |
import numpy as np
|
34 |
from datasets import Dataset, load_dataset, load_metric
|
35 |
-
from tqdm import tqdm
|
36 |
from PIL import Image
|
|
|
37 |
|
38 |
import jax
|
39 |
import jax.numpy as jnp
|
@@ -47,14 +47,14 @@ from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_ke
|
|
47 |
from huggingface_hub import Repository
|
48 |
from transformers import (
|
49 |
CONFIG_MAPPING,
|
50 |
-
FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING,
|
51 |
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
|
|
|
52 |
AutoConfig,
|
53 |
AutoFeatureExtractor,
|
54 |
AutoTokenizer,
|
|
|
55 |
HfArgumentParser,
|
56 |
is_tensorboard_available,
|
57 |
-
FlaxAutoModelForVision2Seq,
|
58 |
)
|
59 |
from transformers.file_utils import get_full_repo_name, is_offline_mode
|
60 |
|
@@ -113,8 +113,7 @@ class TrainingArguments:
|
|
113 |
per_device_eval_batch_size: int = field(
|
114 |
default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
|
115 |
)
|
116 |
-
_block_size_doc =
|
117 |
-
"""
|
118 |
The default value `0` will preprocess (tokenization + feature extraction) the whole dataset before training and
|
119 |
cache the results. This uses more disk space, but avoids (repeated) processing time during training. This is a
|
120 |
good option if your disk space is large enough to store the whole processed dataset.
|
@@ -124,10 +123,7 @@ class TrainingArguments:
|
|
124 |
`batch_size` are yielded before processing the next block. This could avoid the heavy disk usage when the
|
125 |
dataset is large.
|
126 |
"""
|
127 |
-
block_size: int = field(
|
128 |
-
default=0,
|
129 |
-
metadata={"help": _block_size_doc}
|
130 |
-
)
|
131 |
learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."})
|
132 |
weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."})
|
133 |
adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"})
|
@@ -197,16 +193,21 @@ class ModelArguments:
|
|
197 |
},
|
198 |
)
|
199 |
model_type: Optional[str] = field(
|
200 |
-
default=
|
201 |
-
metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)}
|
202 |
)
|
203 |
encoder_model_type: Optional[str] = field(
|
204 |
default=None,
|
205 |
-
metadata={
|
|
|
|
|
206 |
)
|
207 |
decoder_model_type: Optional[str] = field(
|
208 |
default=None,
|
209 |
-
metadata={
|
|
|
|
|
|
|
210 |
)
|
211 |
config_name: Optional[str] = field(
|
212 |
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
|
@@ -218,10 +219,12 @@ class ModelArguments:
|
|
218 |
default=None, metadata={"help": "Pretrained decoder config name or path if not the same as decoder_model_name"}
|
219 |
)
|
220 |
feature_extractor_name: Optional[str] = field(
|
221 |
-
default=None,
|
|
|
222 |
)
|
223 |
tokenizer_name: Optional[str] = field(
|
224 |
-
default=None,
|
|
|
225 |
)
|
226 |
cache_dir: Optional[str] = field(
|
227 |
default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
|
@@ -505,7 +508,7 @@ def main():
|
|
505 |
# Use specified `model_type` (default to `vision-encoder-decoder`)
|
506 |
else:
|
507 |
|
508 |
-
if
|
509 |
raise ValueError(
|
510 |
f"Unrecognized model identifier: {model_args.model_type}. Should contain one of {', '.join(MODEL_TYPES)}."
|
511 |
)
|
@@ -516,29 +519,41 @@ def main():
|
|
516 |
|
517 |
# Use explicit specified encoder config
|
518 |
if model_args.encoder_config_name:
|
519 |
-
encoder_config = AutoConfig.from_pretrained(
|
|
|
|
|
520 |
# Use pretrained encoder model's config
|
521 |
elif model_args.encoder_model_name_or_path:
|
522 |
-
encoder_config = AutoConfig.from_pretrained(
|
|
|
|
|
523 |
# Use specified encoder model type
|
524 |
elif model_args.encoder_model_type:
|
525 |
encoder_config = AutoConfig.for_model(model_args.encoder_model_type)
|
526 |
logger.warning("You are instantiating a new config instance from scratch for the encoder.")
|
527 |
else:
|
528 |
-
raise ValueError(
|
|
|
|
|
529 |
|
530 |
# Use explicit specified decoder config
|
531 |
if model_args.decoder_config_name:
|
532 |
-
decoder_config = AutoConfig.from_pretrained(
|
|
|
|
|
533 |
# Use pretrained decoder model's config
|
534 |
elif model_args.decoder_model_name_or_path:
|
535 |
-
decoder_config = AutoConfig.from_pretrained(
|
|
|
|
|
536 |
# Use specified decoder model type
|
537 |
elif model_args.decoder_model_type:
|
538 |
decoder_config = AutoConfig.for_model(model_args.decoder_model_type)
|
539 |
logger.warning("You are instantiating a new config instance from scratch for the decoder.")
|
540 |
else:
|
541 |
-
raise ValueError(
|
|
|
|
|
542 |
|
543 |
logger.info("Setting `config.is_decoder=True` and `config.add_cross_attention=True` for decoder_config")
|
544 |
decoder_config.is_decoder = True
|
@@ -586,7 +601,9 @@ def main():
|
|
586 |
)
|
587 |
else:
|
588 |
# model_class = FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING[config.__class__]
|
589 |
-
model = FlaxAutoModelForVision2Seq.from_config(
|
|
|
|
|
590 |
model_class = model.__class__
|
591 |
|
592 |
# encoder_class = FlaxAutoModel
|
@@ -604,10 +621,12 @@ def main():
|
|
604 |
model_args.encoder_model_name_or_path,
|
605 |
config=config.encoder,
|
606 |
seed=training_args.seed,
|
607 |
-
dtype=getattr(jnp, model_args.dtype)
|
608 |
)
|
609 |
else:
|
610 |
-
encoder = encoder_class(
|
|
|
|
|
611 |
logger.warning("You are instantiating a new model instance from scratch for the encoder.")
|
612 |
|
613 |
if model_args.decoder_model_name_or_path:
|
@@ -615,10 +634,12 @@ def main():
|
|
615 |
model_args.decoder_model_name_or_path,
|
616 |
config=config.decoder,
|
617 |
seed=training_args.seed,
|
618 |
-
dtype=getattr(jnp, model_args.dtype)
|
619 |
)
|
620 |
else:
|
621 |
-
decoder = decoder_class(
|
|
|
|
|
622 |
logger.warning("You are instantiating a new model instance from scratch for the decoder.")
|
623 |
|
624 |
model = model_class.from_encoder_decoder_pretrained(
|
@@ -646,7 +667,8 @@ def main():
|
|
646 |
feature_extractor = None
|
647 |
if model_args.feature_extractor_name:
|
648 |
feature_extractor = AutoFeatureExtractor.from_pretrained(
|
649 |
-
model_args.feature_extractor_name,
|
|
|
650 |
)
|
651 |
elif model_args.model_name_or_path:
|
652 |
try:
|
@@ -684,7 +706,9 @@ def main():
|
|
684 |
if not tokenizer:
|
685 |
if model_args.decoder_model_name_or_path:
|
686 |
tokenizer = AutoTokenizer.from_pretrained(
|
687 |
-
model_args.decoder_model_name_or_path,
|
|
|
|
|
688 |
)
|
689 |
else:
|
690 |
raise ValueError(
|
@@ -739,9 +763,9 @@ def main():
|
|
739 |
for image_file in examples[image_column]:
|
740 |
try:
|
741 |
image = Image.open(image_file)
|
742 |
-
|
743 |
bools.append(True)
|
744 |
-
except:
|
745 |
bools.append(False)
|
746 |
|
747 |
return bools
|
@@ -752,7 +776,7 @@ def main():
|
|
752 |
|
753 |
captions = []
|
754 |
for caption in examples[caption_column]:
|
755 |
-
|
756 |
|
757 |
targets = captions
|
758 |
|
@@ -795,7 +819,7 @@ def main():
|
|
795 |
img = Image.open(image_file)
|
796 |
images.append(img)
|
797 |
to_keep.append(True)
|
798 |
-
except:
|
799 |
to_keep.append(False)
|
800 |
|
801 |
for k, v in examples.items():
|
@@ -831,9 +855,11 @@ def main():
|
|
831 |
),
|
832 |
dtype="float32",
|
833 |
),
|
834 |
-
"labels": datasets.Sequence(feature=datasets.Value(dtype=
|
835 |
-
"decoder_input_ids": datasets.Sequence(feature=datasets.Value(dtype=
|
836 |
-
"decoder_attention_mask": datasets.Sequence(
|
|
|
|
|
837 |
}
|
838 |
)
|
839 |
|
@@ -909,7 +935,9 @@ def main():
|
|
909 |
# (if feature extraction is performed at the beginning, the filtering is done during preprocessing below
|
910 |
# instead here.)
|
911 |
if not run_feat_ext_at_beginning:
|
912 |
-
predict_dataset = predict_dataset.filter(
|
|
|
|
|
913 |
predict_dataset = predict_dataset.map(
|
914 |
function=function_kwarg,
|
915 |
batched=True,
|
@@ -930,7 +958,9 @@ def main():
|
|
930 |
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
|
931 |
|
932 |
if training_args.block_size % train_batch_size > 0:
|
933 |
-
raise ValueError(
|
|
|
|
|
934 |
|
935 |
if training_args.do_train:
|
936 |
steps_per_epoch = len(train_dataset) // train_batch_size
|
@@ -951,13 +981,13 @@ def main():
|
|
951 |
test_steps = num_test_examples // eval_batch_size
|
952 |
|
953 |
def blockwise_data_loader(
|
954 |
-
|
955 |
-
|
956 |
-
|
957 |
-
|
958 |
-
|
959 |
-
|
960 |
-
|
961 |
):
|
962 |
"""
|
963 |
Wrap the simple `data_loader` in a block-wise way if `block_size` > 0, else it's the same as `data_loader`.
|
@@ -1165,7 +1195,7 @@ def main():
|
|
1165 |
|
1166 |
def generate_step(params, batch):
|
1167 |
model.params = params
|
1168 |
-
output_ids = model.generate(batch[
|
1169 |
return output_ids.sequences
|
1170 |
|
1171 |
# Create parallel version of the train and eval step
|
@@ -1212,7 +1242,13 @@ def main():
|
|
1212 |
if training_args.push_to_hub:
|
1213 |
repo.push_to_hub(commit_message=commit_msg, blocking=False)
|
1214 |
|
1215 |
-
def evaluation_loop(
|
|
|
|
|
|
|
|
|
|
|
|
|
1216 |
|
1217 |
logger.info(f"*** {'Predict' if is_prediction else 'Evaluate'} ***")
|
1218 |
|
@@ -1230,7 +1266,9 @@ def main():
|
|
1230 |
split="prediction" if is_prediction else "validation",
|
1231 |
)
|
1232 |
steps = len(dataset) // eval_batch_size
|
1233 |
-
for _ in tqdm(
|
|
|
|
|
1234 |
# Model forward
|
1235 |
batch = next(batches)
|
1236 |
_labels = batch.get("labels", None)
|
@@ -1260,7 +1298,12 @@ def main():
|
|
1260 |
if labels:
|
1261 |
rouge_metrics, decoded_preds, decoded_labels = compute_metrics(preds, labels)
|
1262 |
metrics.update(rouge_metrics)
|
1263 |
-
rouge_desc = " ".join(
|
|
|
|
|
|
|
|
|
|
|
1264 |
for pred, label in zip(decoded_preds, decoded_labels):
|
1265 |
pred = pred.replace("\n", " ")
|
1266 |
label = label.replace("\n", " ")
|
@@ -1293,28 +1336,37 @@ def main():
|
|
1293 |
|
1294 |
# Save metrics (only for the evaluation/prediction being done along with training)
|
1295 |
if has_tensorboard and training_args.do_train:
|
1296 |
-
write_metric(
|
|
|
|
|
1297 |
|
1298 |
# save final metrics in json
|
1299 |
-
metrics = {
|
|
|
|
|
|
|
1300 |
_path = os.path.join(training_args.output_dir, ckpt_dir, f"{metric_key_prefix}_results.json")
|
1301 |
with open(_path, "w") as f:
|
1302 |
json.dump(metrics, f, indent=4, sort_keys=True)
|
1303 |
|
1304 |
# Update report
|
1305 |
-
with open(os.path.join(training_args.output_dir,
|
1306 |
-
fp.write(desc +
|
1307 |
|
1308 |
# Save generations
|
1309 |
if generations:
|
1310 |
-
with open(
|
|
|
|
|
|
|
|
|
1311 |
json.dump(generations, fp, ensure_ascii=False, indent=4)
|
1312 |
|
1313 |
def evaluate(rng: jax.random.PRNGKey, dataset: Dataset, ckpt_dir: str = ""):
|
1314 |
-
evaluation_loop(rng, dataset, metric_key_prefix=
|
1315 |
|
1316 |
def predict(rng: jax.random.PRNGKey, dataset: Dataset):
|
1317 |
-
evaluation_loop(rng, dataset, metric_key_prefix=
|
1318 |
|
1319 |
input_rng = None
|
1320 |
|
@@ -1340,7 +1392,7 @@ def main():
|
|
1340 |
batch_size=train_batch_size,
|
1341 |
keep_in_memory=True,
|
1342 |
shuffle=True,
|
1343 |
-
split="train"
|
1344 |
)
|
1345 |
|
1346 |
# train
|
@@ -1364,16 +1416,26 @@ def main():
|
|
1364 |
|
1365 |
logger.info(desc)
|
1366 |
|
1367 |
-
with open(os.path.join(training_args.output_dir,
|
1368 |
-
fp.write(desc +
|
1369 |
|
1370 |
# Save metrics
|
1371 |
if has_tensorboard and jax.process_index() == 0:
|
1372 |
-
write_metric(
|
|
|
|
|
|
|
|
|
|
|
|
|
1373 |
|
1374 |
# ======================== Evaluating (inside an epoch) ==============================
|
1375 |
|
1376 |
-
if
|
|
|
|
|
|
|
|
|
1377 |
ckpt_dir = f"ckpt_epoch_{epoch + 1}_step_{cur_step}"
|
1378 |
commit_msg = f"Saving weights and logs of epoch {epoch + 1} - step {cur_step}"
|
1379 |
evaluate(input_rng, eval_dataset, ckpt_dir)
|
@@ -1386,12 +1448,14 @@ def main():
|
|
1386 |
|
1387 |
logger.info(desc)
|
1388 |
|
1389 |
-
with open(os.path.join(training_args.output_dir,
|
1390 |
-
fp.write(desc +
|
1391 |
|
1392 |
# Save metrics
|
1393 |
if has_tensorboard and jax.process_index() == 0:
|
1394 |
-
write_metric(
|
|
|
|
|
1395 |
|
1396 |
# ======================== Evaluating (after each epoch) ==============================
|
1397 |
|
|
|
32 |
import nltk # Here to have a nice missing dependency error message early on
|
33 |
import numpy as np
|
34 |
from datasets import Dataset, load_dataset, load_metric
|
|
|
35 |
from PIL import Image
|
36 |
+
from tqdm import tqdm
|
37 |
|
38 |
import jax
|
39 |
import jax.numpy as jnp
|
|
|
47 |
from huggingface_hub import Repository
|
48 |
from transformers import (
|
49 |
CONFIG_MAPPING,
|
|
|
50 |
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
|
51 |
+
FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING,
|
52 |
AutoConfig,
|
53 |
AutoFeatureExtractor,
|
54 |
AutoTokenizer,
|
55 |
+
FlaxAutoModelForVision2Seq,
|
56 |
HfArgumentParser,
|
57 |
is_tensorboard_available,
|
|
|
58 |
)
|
59 |
from transformers.file_utils import get_full_repo_name, is_offline_mode
|
60 |
|
|
|
113 |
per_device_eval_batch_size: int = field(
|
114 |
default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
|
115 |
)
|
116 |
+
_block_size_doc = """
|
|
|
117 |
The default value `0` will preprocess (tokenization + feature extraction) the whole dataset before training and
|
118 |
cache the results. This uses more disk space, but avoids (repeated) processing time during training. This is a
|
119 |
good option if your disk space is large enough to store the whole processed dataset.
|
|
|
123 |
`batch_size` are yielded before processing the next block. This could avoid the heavy disk usage when the
|
124 |
dataset is large.
|
125 |
"""
|
126 |
+
block_size: int = field(default=0, metadata={"help": _block_size_doc})
|
|
|
|
|
|
|
127 |
learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."})
|
128 |
weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."})
|
129 |
adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"})
|
|
|
193 |
},
|
194 |
)
|
195 |
model_type: Optional[str] = field(
|
196 |
+
default="vision-encoder-decoder",
|
197 |
+
metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
|
198 |
)
|
199 |
encoder_model_type: Optional[str] = field(
|
200 |
default=None,
|
201 |
+
metadata={
|
202 |
+
"help": "If training from scratch, pass a vision encoder model type from the library. For example, 'vit'"
|
203 |
+
},
|
204 |
)
|
205 |
decoder_model_type: Optional[str] = field(
|
206 |
default=None,
|
207 |
+
metadata={
|
208 |
+
"help": "If training from scratch, pass a decoder model type from the list: "
|
209 |
+
+ ", ".join(DECODER_MODEL_TYPES)
|
210 |
+
},
|
211 |
)
|
212 |
config_name: Optional[str] = field(
|
213 |
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
|
|
|
219 |
default=None, metadata={"help": "Pretrained decoder config name or path if not the same as decoder_model_name"}
|
220 |
)
|
221 |
feature_extractor_name: Optional[str] = field(
|
222 |
+
default=None,
|
223 |
+
metadata={"help": "Pretrained encoder feature extractor_name or path if not the same as encoder_model_name"},
|
224 |
)
|
225 |
tokenizer_name: Optional[str] = field(
|
226 |
+
default=None,
|
227 |
+
metadata={"help": "Pretrained decoder tokenizer name or path if not the same as decoder_model_name"},
|
228 |
)
|
229 |
cache_dir: Optional[str] = field(
|
230 |
default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
|
|
|
508 |
# Use specified `model_type` (default to `vision-encoder-decoder`)
|
509 |
else:
|
510 |
|
511 |
+
if model_args.model_type not in MODEL_TYPES:
|
512 |
raise ValueError(
|
513 |
f"Unrecognized model identifier: {model_args.model_type}. Should contain one of {', '.join(MODEL_TYPES)}."
|
514 |
)
|
|
|
519 |
|
520 |
# Use explicit specified encoder config
|
521 |
if model_args.encoder_config_name:
|
522 |
+
encoder_config = AutoConfig.from_pretrained(
|
523 |
+
model_args.encoder_config_name, cache_dir=encoder_cache_dir
|
524 |
+
)
|
525 |
# Use pretrained encoder model's config
|
526 |
elif model_args.encoder_model_name_or_path:
|
527 |
+
encoder_config = AutoConfig.from_pretrained(
|
528 |
+
model_args.encoder_model_name_or_path, cache_dir=encoder_cache_dir
|
529 |
+
)
|
530 |
# Use specified encoder model type
|
531 |
elif model_args.encoder_model_type:
|
532 |
encoder_config = AutoConfig.for_model(model_args.encoder_model_type)
|
533 |
logger.warning("You are instantiating a new config instance from scratch for the encoder.")
|
534 |
else:
|
535 |
+
raise ValueError(
|
536 |
+
"Encoder Config: if pretrained config or model location is not provided, `encoder_model_type` is required."
|
537 |
+
)
|
538 |
|
539 |
# Use explicit specified decoder config
|
540 |
if model_args.decoder_config_name:
|
541 |
+
decoder_config = AutoConfig.from_pretrained(
|
542 |
+
model_args.decoder_config_name, cache_dir=decoder_cache_dir
|
543 |
+
)
|
544 |
# Use pretrained decoder model's config
|
545 |
elif model_args.decoder_model_name_or_path:
|
546 |
+
decoder_config = AutoConfig.from_pretrained(
|
547 |
+
model_args.decoder_model_name_or_path, cache_dir=decoder_cache_dir
|
548 |
+
)
|
549 |
# Use specified decoder model type
|
550 |
elif model_args.decoder_model_type:
|
551 |
decoder_config = AutoConfig.for_model(model_args.decoder_model_type)
|
552 |
logger.warning("You are instantiating a new config instance from scratch for the decoder.")
|
553 |
else:
|
554 |
+
raise ValueError(
|
555 |
+
"Decoder Config: if pretrained config or model location is not provided, `decoder_model_type` is required."
|
556 |
+
)
|
557 |
|
558 |
logger.info("Setting `config.is_decoder=True` and `config.add_cross_attention=True` for decoder_config")
|
559 |
decoder_config.is_decoder = True
|
|
|
601 |
)
|
602 |
else:
|
603 |
# model_class = FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING[config.__class__]
|
604 |
+
model = FlaxAutoModelForVision2Seq.from_config(
|
605 |
+
config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
|
606 |
+
)
|
607 |
model_class = model.__class__
|
608 |
|
609 |
# encoder_class = FlaxAutoModel
|
|
|
621 |
model_args.encoder_model_name_or_path,
|
622 |
config=config.encoder,
|
623 |
seed=training_args.seed,
|
624 |
+
dtype=getattr(jnp, model_args.dtype),
|
625 |
)
|
626 |
else:
|
627 |
+
encoder = encoder_class(
|
628 |
+
config=config.encoder, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
|
629 |
+
)
|
630 |
logger.warning("You are instantiating a new model instance from scratch for the encoder.")
|
631 |
|
632 |
if model_args.decoder_model_name_or_path:
|
|
|
634 |
model_args.decoder_model_name_or_path,
|
635 |
config=config.decoder,
|
636 |
seed=training_args.seed,
|
637 |
+
dtype=getattr(jnp, model_args.dtype),
|
638 |
)
|
639 |
else:
|
640 |
+
decoder = decoder_class(
|
641 |
+
config=config.decoder, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
|
642 |
+
)
|
643 |
logger.warning("You are instantiating a new model instance from scratch for the decoder.")
|
644 |
|
645 |
model = model_class.from_encoder_decoder_pretrained(
|
|
|
667 |
feature_extractor = None
|
668 |
if model_args.feature_extractor_name:
|
669 |
feature_extractor = AutoFeatureExtractor.from_pretrained(
|
670 |
+
model_args.feature_extractor_name,
|
671 |
+
cache_dir=model_args.cache_dir,
|
672 |
)
|
673 |
elif model_args.model_name_or_path:
|
674 |
try:
|
|
|
706 |
if not tokenizer:
|
707 |
if model_args.decoder_model_name_or_path:
|
708 |
tokenizer = AutoTokenizer.from_pretrained(
|
709 |
+
model_args.decoder_model_name_or_path,
|
710 |
+
cache_dir=model_args.cache_dir,
|
711 |
+
use_fast=model_args.use_fast_tokenizer,
|
712 |
)
|
713 |
else:
|
714 |
raise ValueError(
|
|
|
763 |
for image_file in examples[image_column]:
|
764 |
try:
|
765 |
image = Image.open(image_file)
|
766 |
+
feature_extractor(images=image, return_tensors="np")
|
767 |
bools.append(True)
|
768 |
+
except Exception:
|
769 |
bools.append(False)
|
770 |
|
771 |
return bools
|
|
|
776 |
|
777 |
captions = []
|
778 |
for caption in examples[caption_column]:
|
779 |
+
captions.append(caption.lower() + " " + tokenizer.eos_token)
|
780 |
|
781 |
targets = captions
|
782 |
|
|
|
819 |
img = Image.open(image_file)
|
820 |
images.append(img)
|
821 |
to_keep.append(True)
|
822 |
+
except Exception:
|
823 |
to_keep.append(False)
|
824 |
|
825 |
for k, v in examples.items():
|
|
|
855 |
),
|
856 |
dtype="float32",
|
857 |
),
|
858 |
+
"labels": datasets.Sequence(feature=datasets.Value(dtype="int32", id=None), length=-1, id=None),
|
859 |
+
"decoder_input_ids": datasets.Sequence(feature=datasets.Value(dtype="int32", id=None), length=-1, id=None),
|
860 |
+
"decoder_attention_mask": datasets.Sequence(
|
861 |
+
feature=datasets.Value(dtype="int32", id=None), length=-1, id=None
|
862 |
+
),
|
863 |
}
|
864 |
)
|
865 |
|
|
|
935 |
# (if feature extraction is performed at the beginning, the filtering is done during preprocessing below
|
936 |
# instead here.)
|
937 |
if not run_feat_ext_at_beginning:
|
938 |
+
predict_dataset = predict_dataset.filter(
|
939 |
+
filter_fn, batched=True, num_proc=data_args.preprocessing_num_workers
|
940 |
+
)
|
941 |
predict_dataset = predict_dataset.map(
|
942 |
function=function_kwarg,
|
943 |
batched=True,
|
|
|
958 |
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
|
959 |
|
960 |
if training_args.block_size % train_batch_size > 0:
|
961 |
+
raise ValueError(
|
962 |
+
f"`training_args.block_size` needs to be a multiple of the global batch size. Got {training_args.block_size} and {train_batch_size} instead."
|
963 |
+
)
|
964 |
|
965 |
if training_args.do_train:
|
966 |
steps_per_epoch = len(train_dataset) // train_batch_size
|
|
|
981 |
test_steps = num_test_examples // eval_batch_size
|
982 |
|
983 |
def blockwise_data_loader(
|
984 |
+
rng: jax.random.PRNGKey,
|
985 |
+
ds: Dataset,
|
986 |
+
block_size: int,
|
987 |
+
batch_size: int,
|
988 |
+
shuffle: bool = False,
|
989 |
+
keep_in_memory: bool = False,
|
990 |
+
split: str = "",
|
991 |
):
|
992 |
"""
|
993 |
Wrap the simple `data_loader` in a block-wise way if `block_size` > 0, else it's the same as `data_loader`.
|
|
|
1195 |
|
1196 |
def generate_step(params, batch):
|
1197 |
model.params = params
|
1198 |
+
output_ids = model.generate(batch["pixel_values"], **gen_kwargs)
|
1199 |
return output_ids.sequences
|
1200 |
|
1201 |
# Create parallel version of the train and eval step
|
|
|
1242 |
if training_args.push_to_hub:
|
1243 |
repo.push_to_hub(commit_message=commit_msg, blocking=False)
|
1244 |
|
1245 |
+
def evaluation_loop(
|
1246 |
+
rng: jax.random.PRNGKey,
|
1247 |
+
dataset: Dataset,
|
1248 |
+
metric_key_prefix: str = "eval",
|
1249 |
+
ckpt_dir: str = "",
|
1250 |
+
is_prediction=False,
|
1251 |
+
):
|
1252 |
|
1253 |
logger.info(f"*** {'Predict' if is_prediction else 'Evaluate'} ***")
|
1254 |
|
|
|
1266 |
split="prediction" if is_prediction else "validation",
|
1267 |
)
|
1268 |
steps = len(dataset) // eval_batch_size
|
1269 |
+
for _ in tqdm(
|
1270 |
+
range(steps), desc=f"{'Predicting' if is_prediction else 'Evaluating'}...", position=2, leave=False
|
1271 |
+
):
|
1272 |
# Model forward
|
1273 |
batch = next(batches)
|
1274 |
_labels = batch.get("labels", None)
|
|
|
1298 |
if labels:
|
1299 |
rouge_metrics, decoded_preds, decoded_labels = compute_metrics(preds, labels)
|
1300 |
metrics.update(rouge_metrics)
|
1301 |
+
rouge_desc = " ".join(
|
1302 |
+
[
|
1303 |
+
f"{'Predict' if is_prediction else 'Eval'} {key}: {value} |"
|
1304 |
+
for key, value in rouge_metrics.items()
|
1305 |
+
]
|
1306 |
+
)
|
1307 |
for pred, label in zip(decoded_preds, decoded_labels):
|
1308 |
pred = pred.replace("\n", " ")
|
1309 |
label = label.replace("\n", " ")
|
|
|
1336 |
|
1337 |
# Save metrics (only for the evaluation/prediction being done along with training)
|
1338 |
if has_tensorboard and training_args.do_train:
|
1339 |
+
write_metric(
|
1340 |
+
summary_writer, metrics, train_time=None, step=cur_step, metric_key_prefix=metric_key_prefix
|
1341 |
+
)
|
1342 |
|
1343 |
# save final metrics in json
|
1344 |
+
metrics = {
|
1345 |
+
f"{metric_key_prefix}_{metric_name}": round(value.item(), 6)
|
1346 |
+
for metric_name, value in metrics.items()
|
1347 |
+
}
|
1348 |
_path = os.path.join(training_args.output_dir, ckpt_dir, f"{metric_key_prefix}_results.json")
|
1349 |
with open(_path, "w") as f:
|
1350 |
json.dump(metrics, f, indent=4, sort_keys=True)
|
1351 |
|
1352 |
# Update report
|
1353 |
+
with open(os.path.join(training_args.output_dir, "log"), "a", encoding="UTF-8") as fp:
|
1354 |
+
fp.write(desc + "\n")
|
1355 |
|
1356 |
# Save generations
|
1357 |
if generations:
|
1358 |
+
with open(
|
1359 |
+
os.path.join(training_args.output_dir, ckpt_dir, f"{metric_key_prefix}_generation.json"),
|
1360 |
+
"w",
|
1361 |
+
encoding="UTF-8",
|
1362 |
+
) as fp:
|
1363 |
json.dump(generations, fp, ensure_ascii=False, indent=4)
|
1364 |
|
1365 |
def evaluate(rng: jax.random.PRNGKey, dataset: Dataset, ckpt_dir: str = ""):
|
1366 |
+
evaluation_loop(rng, dataset, metric_key_prefix="eval", ckpt_dir=ckpt_dir)
|
1367 |
|
1368 |
def predict(rng: jax.random.PRNGKey, dataset: Dataset):
|
1369 |
+
evaluation_loop(rng, dataset, metric_key_prefix="test", is_prediction=True)
|
1370 |
|
1371 |
input_rng = None
|
1372 |
|
|
|
1392 |
batch_size=train_batch_size,
|
1393 |
keep_in_memory=True,
|
1394 |
shuffle=True,
|
1395 |
+
split="train",
|
1396 |
)
|
1397 |
|
1398 |
# train
|
|
|
1416 |
|
1417 |
logger.info(desc)
|
1418 |
|
1419 |
+
with open(os.path.join(training_args.output_dir, "log"), "a", encoding="UTF-8") as fp:
|
1420 |
+
fp.write(desc + "\n")
|
1421 |
|
1422 |
# Save metrics
|
1423 |
if has_tensorboard and jax.process_index() == 0:
|
1424 |
+
write_metric(
|
1425 |
+
summary_writer,
|
1426 |
+
train_metrics,
|
1427 |
+
train_time=train_time,
|
1428 |
+
step=cur_step,
|
1429 |
+
metric_key_prefix="train",
|
1430 |
+
)
|
1431 |
|
1432 |
# ======================== Evaluating (inside an epoch) ==============================
|
1433 |
|
1434 |
+
if (
|
1435 |
+
training_args.do_eval
|
1436 |
+
and (training_args.eval_steps is not None and training_args.eval_steps > 0)
|
1437 |
+
and cur_step % training_args.eval_steps == 0
|
1438 |
+
):
|
1439 |
ckpt_dir = f"ckpt_epoch_{epoch + 1}_step_{cur_step}"
|
1440 |
commit_msg = f"Saving weights and logs of epoch {epoch + 1} - step {cur_step}"
|
1441 |
evaluate(input_rng, eval_dataset, ckpt_dir)
|
|
|
1448 |
|
1449 |
logger.info(desc)
|
1450 |
|
1451 |
+
with open(os.path.join(training_args.output_dir, "log"), "a", encoding="UTF-8") as fp:
|
1452 |
+
fp.write(desc + "\n")
|
1453 |
|
1454 |
# Save metrics
|
1455 |
if has_tensorboard and jax.process_index() == 0:
|
1456 |
+
write_metric(
|
1457 |
+
summary_writer, train_metrics, train_time=train_time, step=cur_step, metric_key_prefix="train"
|
1458 |
+
)
|
1459 |
|
1460 |
# ======================== Evaluating (after each epoch) ==============================
|
1461 |
|