ydshieh committed
Commit a99072f
1 Parent(s): 66d7526

update model.py and add coco files

coco_dataset_script.py ADDED
@@ -0,0 +1,142 @@
+ import csv
+ import json
+ import os
+
+ import datasets
+ import pandas as pd
+ import numpy as np
+
+
+ # TODO: Add BibTeX citation
+ # Find for instance the citation on arxiv or on the dataset repo/website
+ _CITATION = """\
+ @InProceedings{huggingface:dataset,
+ title = {A great new dataset},
+ author={huggingface, Inc.
+ },
+ year={2020}
+ }
+ """
+
+ # TODO: Add description of the dataset here
+ # You can copy an official description
+ _DESCRIPTION = """\
+ This new dataset is designed to solve this great NLP task and is crafted with a lot of care.
+ """
+
+ # TODO: Add a link to an official homepage for the dataset here
+ _HOMEPAGE = ""
+
+ # TODO: Add the licence for the dataset here if you can find it
+ _LICENSE = ""
+
+ # TODO: Add link to the official dataset URLs here
+ # The HuggingFace datasets library doesn't host the datasets but only points to the original files
+ # This can be an arbitrary nested dict/list of URLs (see below in `_split_generators` method)
+ _URLs = {
+ }
+
+
+ # TODO: Name of the dataset usually matches the script name with CamelCase instead of snake_case
+ class COCODataset(datasets.GeneratorBasedBuilder):
+     """TODO: Short description of my dataset."""
+
+     VERSION = datasets.Version("1.1.0")
+
+     DEFAULT_CONFIG_NAME = "en"
+
+     def _info(self):
+         # TODO: This method specifies the datasets.DatasetInfo object which contains information and typings for the dataset
+
+         features = datasets.Features(
+             {
+                 "id": datasets.Value("int64"),
+                 "en": datasets.Value("string"),
+                 "fr": datasets.Value("string"),
+                 "image_id": datasets.Value("int64"),
+                 "image_file": datasets.Value("string")
+                 # These are the features of your dataset like images, labels ...
+             }
+         )
+
+         return datasets.DatasetInfo(
+             # This is the description that will appear on the datasets page.
+             description=_DESCRIPTION,
+             # This defines the different columns of the dataset and their types
+             features=features,  # Here we define them above because they are different between the two configurations
+             # If there's a common (input, target) tuple from the features,
+             # specify them here. They'll be used if as_supervised=True in
+             # builder.as_dataset.
+             supervised_keys=None,
+             # Homepage of the dataset for documentation
+             homepage=_HOMEPAGE,
+             # License for the dataset if available
+             license=_LICENSE,
+             # Citation for the dataset
+             citation=_CITATION,
+         )
+
+     def _split_generators(self, dl_manager):
+         """Returns SplitGenerators."""
+         # TODO: This method is tasked with downloading/extracting the data and defining the splits depending on the configuration
+         # If several configurations are possible (listed in BUILDER_CONFIGS), the configuration selected by the user is in self.config.name
+
+         data_dir = self.config.data_dir
+
+         return [
+             datasets.SplitGenerator(
+                 name=datasets.Split.TRAIN,
+                 # These kwargs will be passed to _generate_examples
+                 gen_kwargs={
+                     "data_dir": data_dir,
+                     "split": "train",
+                 },
+             ),
+             datasets.SplitGenerator(
+                 name=datasets.Split.TEST,
+                 # These kwargs will be passed to _generate_examples
+                 gen_kwargs={
+                     "data_dir": data_dir,
+                     "split": "test"
+                 },
+             ),
+             datasets.SplitGenerator(
+                 name=datasets.Split.VALIDATION,
+                 # These kwargs will be passed to _generate_examples
+                 gen_kwargs={
+                     "data_dir": data_dir,
+                     "split": "val",
+                 },
+             ),
+         ]
+
+     def _generate_examples(
+         self, data_dir, split  # method parameters are unpacked from `gen_kwargs` as given in `_split_generators`
+     ):
+         """ Yields examples as (key, example) tuples. """
+         # This method handles input defined in _split_generators to yield (key, example) tuples from the dataset.
+         # The `key` is here for legacy reasons (tfds) and is not important in itself.
+
+         # /home/33611/caption/
+         #     train2014
+
+         if split == 'dev':
+             split = 'val'
+
+         with open(os.path.join(data_dir, f'{split}.json')) as fp:
+             examples = json.load(fp)
+
+         for id_, ex in enumerate(examples):
+
+             image_id = ex["image_id"]
+             fn = f'COCO_{split}2014_{str(image_id).zfill(12)}.jpg'
+
+             image_file = os.path.join(data_dir, f'{split}2014', fn)
+
+             yield id_, {
+                 "id": ex["id"],
+                 "en": ex["caption"],
+                 "fr": ex["fr"],
+                 "image_id": ex["image_id"],
+                 "image_file": image_file
+             }
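
A minimal usage sketch for the builder above (not part of this commit): the expected data_dir layout, with {split}.json caption files next to {split}2014 COCO image folders, is assumed from _generate_examples, and the path below is only a placeholder; test_coco_dataset_script.py further down exercises the same call.

    import datasets

    # data_dir must contain train.json / val.json / test.json and the
    # matching train2014 / val2014 / test2014 image folders.
    ds = datasets.load_dataset('./coco_dataset_script.py', data_dir='/path/to/caption/')
    print(ds['train'][0])  # keys: 'id', 'en', 'fr', 'image_id', 'image_file'
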
model.py CHANGED
@@ -3,6 +3,9 @@ import sys, os
current_path = os.path.dirname(os.path.abspath(__file__))
sys.path.append(current_path)

+ # jax
+ import jax
+
# Main model - ViTGPT2LM
from vit_gpt2.modeling_flax_vit_gpt2_lm import FlaxViTGPT2LMForConditionalGeneration

@@ -24,21 +27,27 @@ feature_extractor = ViTFeatureExtractor.from_pretrained(vit_model_name)
gpt2_model_name = 'asi/gpt-fr-cased-small'
tokenizer = GPT2Tokenizer.from_pretrained(gpt2_model_name)

- max_length = 16
- num_beams = 4
+ max_length = 64
+ num_beams = 16
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}


- def predict(image):
+ @jax.jit
+ def predict_fn(pixel_values):
+
+     return flax_vit_gpt2_lm.generate(pixel_values, **gen_kwargs)
+
+ def predict(image, pxs=None):

-     image = Image.open(requests.get(url, stream=True).raw)
    # batch dim is added automatically
    encoder_inputs = feature_extractor(images=image, return_tensors="jax")
    pixel_values = encoder_inputs.pixel_values

+     if pxs is not None:
+         pixel_values = pxs
+
    # generation
-     batch = {'pixel_values': pixel_values}
-     generation = flax_vit_gpt2_lm.generate(batch['pixel_values'], **gen_kwargs)
+     generation = predict_fn(pixel_values)

    token_ids = np.array(generation.sequences)[0]
    caption = tokenizer.decode(token_ids)
@@ -48,10 +57,33 @@ def predict(image):

if __name__ == '__main__':

+     from datetime import datetime
+
+     idx = 11
+     url = f'./wit_data_dir/train/images/{idx}.jpg'
+     image = Image.open(url)
+
+     encoder_inputs = feature_extractor(images=image, return_tensors="np")
+     pv1 = encoder_inputs.pixel_values
+     pv2 = np.load(f'./wit_data_dir/train/numpy/{idx}.npy')
+     print(np.sum(np.abs(pv1 - pv2)))
+
+     s = datetime.now()
+     caption, token_ids = predict(image, pxs=pv2)
+     e = datetime.now()
+     e = (e - s).total_seconds()
+     print(e)
+
+     print(f'token_ids: {token_ids}')
+     print(f'caption: {caption}')

-     url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
-     image = Image.open(requests.get(url, stream=True).raw)
-     caption, token_ids = predict(image)
+     for _ in range(1):
+         s = datetime.now()
+         caption, token_ids = predict(image, pxs=None)
+         e = datetime.now()
+         e = (e - s).total_seconds()
+         print(e)
+         print('-' * 20)

    print(f'token_ids: {token_ids}')
    print(f'caption: {caption}')
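
The timing block above mostly measures jax.jit compilation: the first call to predict_fn traces generate and compiles it with XLA, while later calls with the same input shape reuse the cached executable. A self-contained sketch of that warm-up effect (illustrative only, not part of the commit):

    import time
    import jax
    import jax.numpy as jnp

    @jax.jit
    def f(x):
        # stand-in for the jitted generation call
        return (x * x + 1.0).sum()

    x = jnp.ones((1, 3, 224, 224))
    for i in range(3):
        s = time.time()
        f(x).block_until_ready()  # iteration 0 includes tracing + XLA compilation
        print(i, time.time() - s)
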
run_summarization_coco.py ADDED
@@ -0,0 +1,826 @@
+ #!/usr/bin/env python
+ # coding=utf-8
+ # Copyright 2021 The HuggingFace Team All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """
+ Fine-tuning the library models for summarization.
+ """
+ # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
+
+ import sys, os
+
+ current_path = os.path.dirname(os.path.abspath(__file__))
+ sys.path.append(current_path)
+
+ import logging
+ import os
+ import sys
+ import time
+ from dataclasses import dataclass, field
+ from functools import partial
+ from pathlib import Path
+ from typing import Callable, Optional
+
+ import datasets
+ import nltk  # Here to have a nice missing dependency error message early on
+ import numpy as np
+ from PIL import Image  # used by preprocess_function below
+ from datasets import Dataset, load_dataset, load_metric
+ from tqdm import tqdm
+
+ import jax
+ import jax.numpy as jnp
+ import optax
+ import transformers
+ from filelock import FileLock
+ from flax import jax_utils, traverse_util
+ from flax.jax_utils import unreplicate
+ from flax.training import train_state
+ from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
+ from transformers import (
+     CONFIG_MAPPING,
+     FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
+     AutoConfig,
+     AutoTokenizer,
+     FlaxAutoModelForSeq2SeqLM,
+     HfArgumentParser,
+     TrainingArguments,
+     is_tensorboard_available,
+ )
+ from transformers.file_utils import is_offline_mode
+
+ from transformers import ViTFeatureExtractor, GPT2Tokenizer, GPT2Config
+ from vit_gpt2.modeling_flax_vit_gpt2_lm import FlaxViTGPT2LMForConditionalGeneration
+
+ logger = logging.getLogger(__name__)
+
+ try:
+     nltk.data.find("tokenizers/punkt")
+ except (LookupError, OSError):
+     if is_offline_mode():
+         raise LookupError(
+             "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files"
+         )
+     with FileLock(".lock") as lock:
+         nltk.download("punkt", quiet=True)
+
+
+ MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.keys())
+ MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
+
+
+ @dataclass
+ class ModelArguments:
+     """
+     Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
+     """
+
+     model_name_or_path: Optional[str] = field(
+         default=None,
+         metadata={
+             "help": "The model checkpoint for weights initialization."
+             "Don't set if you want to train a model from scratch."
+         },
+     )
+     model_type: Optional[str] = field(
+         default=None,
+         metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
+     )
+     config_name: Optional[str] = field(
+         default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
+     )
+     tokenizer_name: Optional[str] = field(
+         default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
+     )
+     cache_dir: Optional[str] = field(
+         default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
+     )
+     use_fast_tokenizer: bool = field(
+         default=True,
+         metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
+     )
+     dtype: Optional[str] = field(
+         default="float32",
+         metadata={
+             "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
+         },
+     )
+
+
+ @dataclass
+ class DataTrainingArguments:
+     """
+     Arguments pertaining to what data we are going to input our model for training and eval.
+     """
+
+     dataset_name: Optional[str] = field(
+         default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
+     )
+     dataset_config_name: Optional[str] = field(
+         default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
+     )
+     text_column: Optional[str] = field(
+         default=None,
+         metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
+     )
+     summary_column: Optional[str] = field(
+         default=None,
+         metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."},
+     )
+     train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
+     validation_file: Optional[str] = field(
+         default=None,
+         metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
+     )
+     max_source_length: Optional[int] = field(
+         default=1024,
+         metadata={
+             "help": "The maximum total input sequence length after tokenization. Sequences longer "
+             "than this will be truncated, sequences shorter will be padded."
+         },
+     )
+     max_target_length: Optional[int] = field(
+         default=128,
+         metadata={
+             "help": "The maximum total sequence length for target text after tokenization. Sequences longer "
+             "than this will be truncated, sequences shorter will be padded."
+         },
+     )
+     val_max_target_length: Optional[int] = field(
+         default=None,
+         metadata={
+             "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
+             "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
+             "This argument is also used to override the `max_length` param of `model.generate`, which is used "
+             "during evaluation."
+         },
+     )
+     max_train_samples: Optional[int] = field(
+         default=None,
+         metadata={
+             "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
+             "value if set."
+         },
+     )
+     max_eval_samples: Optional[int] = field(
+         default=None,
+         metadata={
+             "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+             "value if set."
+         },
+     )
+     max_predict_samples: Optional[int] = field(
+         default=None,
+         metadata={
+             "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
+             "value if set."
+         },
+     )
+     preprocessing_num_workers: Optional[int] = field(
+         default=None,
+         metadata={"help": "The number of processes to use for the preprocessing."},
+     )
+     source_prefix: Optional[str] = field(
+         default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
+     )
+     predict_with_generate: bool = field(
+         default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."}
+     )
+     num_beams: Optional[int] = field(
+         default=None,
+         metadata={
+             "help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`, "
+             "which is used during evaluation."
+         },
+     )
+     overwrite_cache: bool = field(
+         default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
+     )
+
+     def __post_init__(self):
+         if self.dataset_name is None and self.train_file is None and self.validation_file is None:
+             raise ValueError("Need either a dataset name or a training/validation file.")
+         else:
+             if self.train_file is not None:
+                 extension = self.train_file.split(".")[-1]
+                 assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
+             if self.validation_file is not None:
+                 extension = self.validation_file.split(".")[-1]
+                 assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
+         if self.val_max_target_length is None:
+             self.val_max_target_length = self.max_target_length
+
+
+ summarization_name_mapping = {
+     "amazon_reviews_multi": ("review_body", "review_title"),
+     "big_patent": ("description", "abstract"),
+     "cnn_dailymail": ("article", "highlights"),
+     "orange_sum": ("text", "summary"),
+     "pn_summary": ("article", "summary"),
+     "psc": ("extract_text", "summary_text"),
+     "samsum": ("dialogue", "summary"),
+     "thaisum": ("body", "summary"),
+     "xglue": ("news_body", "news_title"),
+     "xsum": ("document", "summary"),
+     "wiki_summary": ("article", "highlights"),
+ }
+
+
+ class TrainState(train_state.TrainState):
+     dropout_rng: jnp.ndarray
+
+     def replicate(self):
+         return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
+
+
+ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False):
+     """
+     Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
+     Shuffle batches if `shuffle` is `True`.
+     """
+     steps_per_epoch = len(dataset) // batch_size
+
+     if shuffle:
+         batch_idx = jax.random.permutation(rng, len(dataset))
+     else:
+         batch_idx = jnp.arange(len(dataset))
+
+     batch_idx = batch_idx[: steps_per_epoch * batch_size]  # Skip incomplete batch.
+     batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
+
+     for idx in batch_idx:
+         batch = dataset[idx]
+         batch = {k: jnp.array(v) for k, v in batch.items()}
+
+         batch = shard(batch)
+
+         yield batch
+
+
+ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
+     summary_writer.scalar("train_time", train_time, step)
+
+     train_metrics = get_metrics(train_metrics)
+     for key, vals in train_metrics.items():
+         tag = f"train_{key}"
+         for i, val in enumerate(vals):
+             summary_writer.scalar(tag, val, step - len(vals) + i + 1)
+
+     for metric_name, value in eval_metrics.items():
+         summary_writer.scalar(f"eval_{metric_name}", value, step)
+
+
+ def create_learning_rate_fn(
+     train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
+ ) -> Callable[[int], jnp.array]:
+     """Returns a linear warmup, linear_decay learning rate function."""
+     steps_per_epoch = train_ds_size // train_batch_size
+     num_train_steps = steps_per_epoch * num_train_epochs
+     warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
+     decay_fn = optax.linear_schedule(
+         init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
+     )
+     schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
+     return schedule_fn
+
+
+ def main():
+     # See all possible arguments in src/transformers/training_args.py
+     # or by passing the --help flag to this script.
+     # We now keep distinct sets of args, for a cleaner separation of concerns.
+
+     parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
+     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
+         # If we pass only one argument to the script and it's the path to a json file,
+         # let's parse it to get our arguments.
+         model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
+     else:
+         model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+
+     if (
+         os.path.exists(training_args.output_dir)
+         and os.listdir(training_args.output_dir)
+         and training_args.do_train
+         and not training_args.overwrite_output_dir
+     ):
+         raise ValueError(
+             f"Output directory ({training_args.output_dir}) already exists and is not empty."
+             "Use --overwrite_output_dir to overcome."
+         )
+
+     # Make one log on every process with the configuration for debugging.
+     logging.basicConfig(
+         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+         datefmt="%m/%d/%Y %H:%M:%S",
+         level=logging.INFO,
+     )
+     # Setup logging, we only want one process per machine to log things on the screen.
+     logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
+     if jax.process_index() == 0:
+         datasets.utils.logging.set_verbosity_warning()
+         transformers.utils.logging.set_verbosity_info()
+     else:
+         datasets.utils.logging.set_verbosity_error()
+         transformers.utils.logging.set_verbosity_error()
+
+     # Set the verbosity to info of the Transformers logger (on main process only):
+     logger.info(f"Training/evaluation parameters {training_args}")
+
+     # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
+     # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
+     # (the dataset will be downloaded automatically from the datasets Hub).
+     #
+     # For CSV/JSON files this script will use the first column for the full texts and the second column for the
+     # summaries (unless you specify column names for this with the `text_column` and `summary_column` arguments).
+     #
+     if data_args.dataset_name is not None:
+         # Downloading and loading a dataset from the hub.
+         dataset = load_dataset(
+             data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir, keep_in_memory=False, data_dir='/home/33611/caption/'
+         )
+     else:
+         data_files = {}
+         if data_args.train_file is not None:
+             data_files["train"] = data_args.train_file
+             extension = data_args.train_file.split(".")[-1]
+         if data_args.validation_file is not None:
+             data_files["validation"] = data_args.validation_file
+             extension = data_args.validation_file.split(".")[-1]
+         if data_args.test_file is not None:
+             data_files["test"] = data_args.test_file
+             extension = data_args.test_file.split(".")[-1]
+         dataset = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
+
+     vit_name_path = 'google/vit-base-patch16-224-in21k'
+     gpt2_name_path = 'asi/gpt-fr-cased-small'
+
+     gpt2_config = GPT2Config.from_pretrained(gpt2_name_path)
+     gpt2_config.add_cross_attention = True
+
+
+     vit_gpt2_name_path = ''
+
+     feature_extractor = ViTFeatureExtractor.from_pretrained(vit_name_path)
+
+     tokenizer = GPT2Tokenizer.from_pretrained(gpt2_name_path)
+
+     if not vit_gpt2_name_path:
+         assert vit_name_path
+         assert gpt2_name_path
+         vit_gpt2_model = FlaxViTGPT2LMForConditionalGeneration.from_vit_gpt2_pretrained(
+             vit_name_path, gpt2_name_path
+         )
+     else:
+         vit_gpt2_model = FlaxViTGPT2LMForConditionalGeneration.from_pretrained(
+             vit_gpt2_name_path
+         )
+
+     model = vit_gpt2_model
+     model.config.is_encoder_decoder = True
+     model.config.decoder_start_token_id = gpt2_config.bos_token_id
+     model.config.bos_token_id = gpt2_config.bos_token_id
+     model.config.eos_token_id = gpt2_config.eos_token_id
+     model.config.pad_token_id = gpt2_config.pad_token_id
+
+     # Preprocessing the datasets.
+     # We need to tokenize inputs and targets.
+     if training_args.do_train:
+         column_names = dataset["train"].column_names
+     elif training_args.do_eval:
+         column_names = dataset["validation"].column_names
+     elif training_args.do_predict:
+         column_names = dataset["test"].column_names
+     else:
+         logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
+         return
+
+     image_file_column = 'image_file'
+     caption_column = 'fr'
+
+     # Temporarily set max_target_length for training.
+     max_target_length = data_args.max_target_length
+
+     # In Flax, for seq2seq models we need to pass `decoder_input_ids`
+     # as the Flax models don't accept `labels`, we need to prepare the decoder_input_ids here
+     # for that dynamically import the `shift_tokens_right` function from the model file
+     model_module = __import__(vit_gpt2_model.__module__, fromlist=["shift_tokens_right"])
+     shift_tokens_right_fn = getattr(model_module, "shift_tokens_right")
+
+     # Setting padding="max_length" as we need fixed length inputs for jitted functions
+     def preprocess_function(examples):
+
+         _pixel_values = []
+         for y in examples[image_file_column]:
+             with Image.open(y) as image:
+                 encoder_inputs = feature_extractor(images=image, return_tensors="np")
+                 x = encoder_inputs.pixel_values
+                 _pixel_values.append(x)
+         pixel_values = np.concatenate(_pixel_values)
+
+         targets = examples[caption_column]
+
+         # Add eos_token!!
+         targets = [x.lower() + ' ' + tokenizer.eos_token for x in targets]
+
+         model_inputs = {}
+         model_inputs['pixel_values'] = pixel_values
+
+         # Setup the tokenizer for targets
+         with tokenizer.as_target_tokenizer():
+             labels = tokenizer(
+                 targets, max_length=max_target_length, padding="max_length", truncation=True, return_tensors="np"
+             )
+
+         model_inputs["labels"] = labels["input_ids"]
+
+         # print(labels["input_ids"])
+         # print(gpt2_config.pad_token_id)
+         # print(gpt2_config.bos_token_id)
+
+         decoder_input_ids = shift_tokens_right_fn(
+             jnp.array(labels["input_ids"]), gpt2_config.pad_token_id, gpt2_config.bos_token_id
+         )
+         model_inputs["input_ids"] = np.asarray(decoder_input_ids)
+
+         # We need decoder_attention_mask so we can ignore pad tokens from loss
+         model_inputs["attention_mask"] = labels["attention_mask"]
+
+         return model_inputs
+
+     if training_args.do_train:
+         if "train" not in dataset:
+             raise ValueError("--do_train requires a train dataset")
+         train_dataset = dataset["train"]
+         if data_args.max_train_samples is not None:
+             train_dataset = train_dataset.select(range(data_args.max_train_samples))
+
+         train_dataset = train_dataset.map(
+             preprocess_function,
+             batched=True,
+             num_proc=data_args.preprocessing_num_workers,
+             remove_columns=column_names,
+             load_from_cache_file=not data_args.overwrite_cache,
+             desc="Running tokenizer on train dataset",
+         )
+
+     if training_args.do_eval:
+         max_target_length = data_args.val_max_target_length
+         if "validation" not in dataset:
+             raise ValueError("--do_eval requires a validation dataset")
+         eval_dataset = dataset["validation"]
+         if data_args.max_eval_samples is not None:
+             eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
+         eval_dataset = eval_dataset.map(
+             preprocess_function,
+             batched=True,
+             num_proc=data_args.preprocessing_num_workers,
+             remove_columns=column_names,
+             load_from_cache_file=not data_args.overwrite_cache,
+             desc="Running tokenizer on validation dataset",
+         )
+
+     if training_args.do_predict:
+         max_target_length = data_args.val_max_target_length
+         if "test" not in dataset:
+             raise ValueError("--do_predict requires a test dataset")
+         predict_dataset = dataset["test"]
+         if data_args.max_predict_samples is not None:
+             predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
+         predict_dataset = predict_dataset.map(
+             preprocess_function,
+             batched=True,
+             num_proc=data_args.preprocessing_num_workers,
+             remove_columns=column_names,
+             load_from_cache_file=not data_args.overwrite_cache,
+             desc="Running tokenizer on prediction dataset",
+         )
+
+     # Metric
+     metric = load_metric("rouge")
+
+     def postprocess_text(preds, labels):
+         preds = [pred.strip() for pred in preds]
+         labels = [label.strip() for label in labels]
+
+         # rougeLSum expects newline after each sentence
+         preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
+         labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]
+
+         return preds, labels
+
+     def compute_metrics(preds, labels):
+         decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
+         decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
+
+         # Some simple post-processing
+         decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
+
+         result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
+         # Extract a few results from ROUGE
+         result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
+
+         prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
+         result["gen_len"] = np.mean(prediction_lens)
+         result = {k: round(v, 4) for k, v in result.items()}
+         return result
+
+     # Enable tensorboard only on the master node
+     has_tensorboard = is_tensorboard_available()
+     if has_tensorboard and jax.process_index() == 0:
+         try:
+             from flax.metrics.tensorboard import SummaryWriter
+
+             summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
+         except ImportError as ie:
+             has_tensorboard = False
+             logger.warning(
+                 f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
+             )
+     else:
+         logger.warning(
+             "Unable to display metrics through TensorBoard because the package is not installed: "
+             "Please run pip install tensorboard to enable."
+         )
+
+     # Initialize our training
+     rng = jax.random.PRNGKey(training_args.seed)
+     rng, dropout_rng = jax.random.split(rng)
+
+     # Store some constant
+     num_epochs = int(training_args.num_train_epochs)
+     train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
+     eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
+     steps_per_epoch = len(train_dataset) // train_batch_size
+     total_train_steps = steps_per_epoch * num_epochs
+
+     # Create learning rate schedule
+     linear_decay_lr_schedule_fn = create_learning_rate_fn(
+         len(train_dataset),
+         train_batch_size,
+         training_args.num_train_epochs,
+         training_args.warmup_steps,
+         training_args.learning_rate,
+     )
+
+     # We use Optax's "masking" functionality to not apply weight decay
+     # to bias and LayerNorm scale parameters. decay_mask_fn returns a
+     # mask boolean with the same structure as the parameters.
+     # The mask is True for parameters that should be decayed.
+     # Note that this mask is specifically adapted for FlaxBart.
+     # For FlaxT5, one should correct the layer norm parameter naming
+     # accordingly - see `run_t5_mlm_flax.py` e.g.
+     def decay_mask_fn(params):
+         flat_params = traverse_util.flatten_dict(params)
+         layer_norm_params = [
+             (name, "scale") for name in ["self_attn_layer_norm", "layernorm_embedding", "final_layer_norm"]
+         ]
+         flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_params) for path in flat_params}
+         return traverse_util.unflatten_dict(flat_mask)
+
+     # create adam optimizer
+     adamw = optax.adamw(
+         learning_rate=linear_decay_lr_schedule_fn,
+         b1=training_args.adam_beta1,
+         b2=training_args.adam_beta2,
+         eps=training_args.adam_epsilon,
+         weight_decay=training_args.weight_decay,
+         mask=decay_mask_fn,
+     )
+
+     # Setup train state
+     state = TrainState.create(apply_fn=vit_gpt2_model.__call__, params=vit_gpt2_model.params, tx=adamw, dropout_rng=dropout_rng)
+
+     # label smoothed cross entropy
+     def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
+         """
+         The label smoothing implementation is adapted from Flax's official example:
+         https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104
+         """
+         vocab_size = logits.shape[-1]
+         confidence = 1.0 - label_smoothing_factor
+         low_confidence = (1.0 - confidence) / (vocab_size - 1)
+         normalizing_constant = -(
+             confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
+         )
+         soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence)
+
+         loss = optax.softmax_cross_entropy(logits, soft_labels)
+         loss = loss - normalizing_constant
+
+         # ignore padded tokens from loss
+         loss = loss * padding_mask
+         loss = loss.sum() / padding_mask.sum()
+         return loss
+
+     # Define gradient update step fn
+     def train_step(state, batch, label_smoothing_factor=0.0):
+         dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
+
+         def compute_loss(params):
+             labels = batch.pop("labels")
+             logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
+             loss = loss_fn(logits, labels, batch["attention_mask"], label_smoothing_factor)
+             return loss
+
+         grad_fn = jax.value_and_grad(compute_loss)
+         loss, grad = grad_fn(state.params)
+         grad = jax.lax.pmean(grad, "batch")
+
+         new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
+
+         metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
+         metrics = jax.lax.pmean(metrics, axis_name="batch")
+
+         return new_state, metrics
+
+     # Define eval fn
+     def eval_step(params, batch, label_smoothing_factor=0.0):
+         labels = batch.pop("labels")
+         logits = model(**batch, params=params, train=False)[0]
+         loss = loss_fn(logits, labels, batch["attention_mask"], label_smoothing_factor)
+
+         # summarize metrics
+         metrics = {"loss": loss}
+         metrics = jax.lax.pmean(metrics, axis_name="batch")
+         return metrics
+
+     # Define generation function
+     max_length = (
+         data_args.val_max_target_length if data_args.val_max_target_length is not None else model.config.max_length
+     )
+     num_beams = data_args.num_beams if data_args.num_beams is not None else model.config.num_beams
+     gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
+
+     def generate_step(params, batch):
+         model.params = params
+         # output_ids = model.generate(batch["pixel_values"], **gen_kwargs)
+
+         # encoder_outputs = model.encode(pixel_values=batch['pixel_values'])
+         # output_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], encoder_outputs=encoder_outputs, **gen_kwargs)
+
+         # encoder_outputs = model.encode(pixel_values=batch['pixel_values'], params=params, train=False)
+         output_ids = model.generate(batch['pixel_values'], **gen_kwargs)
+
+
+         return output_ids.sequences
+
+     # Create parallel version of the train and eval step
+     p_train_step = jax.pmap(
+         partial(train_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch", donate_argnums=(0,)
+     )
+     p_eval_step = jax.pmap(partial(eval_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch")
+     p_generate_step = jax.pmap(generate_step, "batch")
+
+     # Replicate the train state on each device
+     state = state.replicate()
+
+     logger.info("***** Running training *****")
+     logger.info(f" Num examples = {len(train_dataset)}")
+     logger.info(f" Num Epochs = {num_epochs}")
+     logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
+     logger.info(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
+     logger.info(f" Total optimization steps = {total_train_steps}")
+
+     train_time = 0
+     epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
+     for epoch in epochs:
+         # ======================== Training ================================
+         train_start = time.time()
+
+         # Create sampling rng
+         rng, input_rng = jax.random.split(rng)
+         train_metrics = []
+
+         # Generate an epoch by shuffling sampling indices from the train dataset
+         train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
+         steps_per_epoch = len(train_dataset) // train_batch_size
+         # train
+         for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
+             batch = next(train_loader)
+             state, train_metric = p_train_step(state, batch)
+             train_metrics.append(train_metric)
+
+         train_time += time.time() - train_start
+
+         train_metric = unreplicate(train_metric)
+
+         desc = f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
+         epochs.write(desc)
+         epochs.desc = desc
+         logger.info(desc)
+         with open(os.path.join(training_args.output_dir, f'report.txt'), 'a', encoding='UTF-8') as fp:
+             fp.write(desc + '\n')
+
+
+         # ======================== Evaluating ==============================
+         eval_metrics = []
+         eval_preds = []
+         eval_labels = []
+
+         eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
+         eval_steps = len(eval_dataset) // eval_batch_size
+         for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
+             # Model forward
+             batch = next(eval_loader)
+             labels = batch["labels"]
+
+             metrics = p_eval_step(state.params, batch)
+             eval_metrics.append(metrics)
+
+             # generation
+             if data_args.predict_with_generate:
+                 generated_ids = p_generate_step(state.params, batch)
+                 eval_preds.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
+                 eval_labels.extend(jax.device_get(labels.reshape(-1, labels.shape[-1])))
+
+         # normalize eval metrics
+         eval_metrics = get_metrics(eval_metrics)
+         eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
+
+         # compute ROUGE metrics
+         rouge_desc = ""
+         if data_args.predict_with_generate:
+             rouge_metrics = compute_metrics(eval_preds, eval_labels)
+             eval_metrics.update(rouge_metrics)
+             rouge_desc = " ".join([f"Eval {key}: {value} |" for key, value in rouge_metrics.items()])
+
+         # Print metrics and update progress bar
+         desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | {rouge_desc})"
+         epochs.write(desc)
+         epochs.desc = desc
+         logger.info(desc)
+         with open(os.path.join(training_args.output_dir, f'report.txt'), 'a', encoding='UTF-8') as fp:
+             fp.write(desc + '\n')
+
+
+         # Save metrics
+         if has_tensorboard and jax.process_index() == 0:
+             cur_step = epoch * (len(train_dataset) // train_batch_size)
+             write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)
+
+         # ======================== Prediction loop ==============================
+         if training_args.do_predict:
+             logger.info("*** Predict ***")
+
+             pred_metrics = []
+             pred_generations = []
+             pred_labels = []
+
+             pred_loader = data_loader(input_rng, predict_dataset, eval_batch_size)
+             pred_steps = len(predict_dataset) // eval_batch_size
+             for _ in tqdm(range(pred_steps), desc="Predicting...", position=2, leave=False):
+                 # Model forward
+                 batch = next(pred_loader)
+                 labels = batch["labels"]
+
+                 metrics = p_eval_step(state.params, batch)
+                 pred_metrics.append(metrics)
+
+                 # generation
+                 if data_args.predict_with_generate:
+                     generated_ids = p_generate_step(state.params, batch)
+                     pred_generations.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
+                     pred_labels.extend(jax.device_get(labels.reshape(-1, labels.shape[-1])))
+
+             # normalize prediction metrics
+             pred_metrics = get_metrics(pred_metrics)
+             pred_metrics = jax.tree_map(jnp.mean, pred_metrics)
+
+             # compute ROUGE metrics
+             rouge_desc = ""
+             if data_args.predict_with_generate:
+                 rouge_metrics = compute_metrics(pred_generations, pred_labels)
+                 pred_metrics.update(rouge_metrics)
+                 rouge_desc = " ".join([f"Predict {key}: {value} |" for key, value in rouge_metrics.items()])
+
+             # Print metrics
+             desc = f"Predict Loss: {pred_metrics['loss']} | {rouge_desc})"
+             epochs.write(desc)
+             epochs.desc = desc
+             logger.info(desc)
+             with open(os.path.join(training_args.output_dir, f'report.txt'), 'a', encoding='UTF-8') as fp:
+                 fp.write(desc + '\n')
+
+
+         # save checkpoint after each epoch and push checkpoint to the hub
+         if jax.process_index() == 0:
+             params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
+             model.save_pretrained(
+                 os.path.join(training_args.output_dir, f'ckpt_{epoch+1}'),
+                 params=params,
+                 push_to_hub=training_args.push_to_hub,
+                 commit_message=f"Saving weights and logs of epoch {epoch+1}",
+             )
+
+ if __name__ == "__main__":
+     main()
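
run_summarization_coco.py prepares decoder_input_ids by dynamically importing shift_tokens_right from the model module; that implementation lives in vit_gpt2/modeling_flax_vit_gpt2_lm.py and is not part of this commit. As a rough illustration of what such a right-shift does (hypothetical sketch, not the actual vit_gpt2 function):

    import numpy as np

    def shift_tokens_right_sketch(input_ids, pad_token_id, decoder_start_token_id):
        # Prepend the start token and drop the last position so the decoder
        # sees the previous target token at every step (teacher forcing).
        shifted = np.full_like(input_ids, pad_token_id)
        shifted[:, 0] = decoder_start_token_id
        shifted[:, 1:] = input_ids[:, :-1]
        return shifted

    labels = np.array([[464, 3290, 318, 50256]])  # caption token ids ending with eos
    print(shift_tokens_right_sketch(labels, pad_token_id=50256, decoder_start_token_id=50256))
    # [[50256   464  3290   318]]
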
test_coco_dataset_script.py ADDED
@@ -0,0 +1,27 @@
+ import csv
+ import json
+ import os
+
+ import datasets
+ import pandas as pd
+ import numpy as np
+
+ ds = datasets.load_dataset('./coco_dataset_script.py', data_dir='/home/33611/caption/')
+ ds = ds['train']
+
+
+ def transform(example):
+
+     example['pixel_values'] = np.load(example['pixels_file'])
+     return example
+
+
+ # ds = ds.map(transform)
+
+ n = 0
+ for x in ds:
+     n += 1
+     assert os.path.isfile(x['image_file'])
+     if n == 10:
+         print(x)
+         break