ydshieh commited on
Commit
5cb8b84
•
1 Parent(s): 790ecb1

clean repo

Browse files
.gitattributes CHANGED
@@ -21,3 +21,5 @@ wit_data_dir/test/test.tsv filter=lfs diff=lfs merge=lfs -text
21
  train.json filter=lfs diff=lfs merge=lfs -text
22
  val.json filter=lfs diff=lfs merge=lfs -text
23
  outputs/ckpt_5/flax_model.msgpack filter=lfs diff=lfs merge=lfs -text
 
 
21
  train.json filter=lfs diff=lfs merge=lfs -text
22
  val.json filter=lfs diff=lfs merge=lfs -text
23
  outputs/ckpt_5/flax_model.msgpack filter=lfs diff=lfs merge=lfs -text
24
+ C:/Users/33611/Desktop/hub/vit-gpt2/coco_data/train.json filter=lfs diff=lfs merge=lfs -text
25
+ C:/Users/33611/Desktop/hub/vit-gpt2/coco_data/val.json filter=lfs diff=lfs merge=lfs -text
train.json → coco_data/train.json RENAMED
File without changes
val.json → coco_data/val.json RENAMED
File without changes
outputs-wit/.gitattributes DELETED
@@ -1 +0,0 @@
1
- events.out.tfevents.1626423408.t1v-n-cab111a8-w-0.820839.3.v2 filter=lfs diff=lfs merge=lfs -text
 
outputs-wit/ckpt_7/.gitattributes DELETED
@@ -1,2 +0,0 @@
1
- flax_model.msgpack filter=lfs diff=lfs merge=lfs -text
2
- config.json filter=lfs diff=lfs merge=lfs -text
 
 
outputs-wit/ckpt_7/config.json DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:162da9f3eab4459a15a6ae25e50ddaa379938891fb788ae8a395c7de0977e81d
3
- size 4229
 
 
 
outputs-wit/ckpt_7/flax_model.msgpack DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:64476c6a5afee2acfefaa36f1eb14a97d13999e787cdfb2fc361e7df4acbf562
3
- size 1012706583
 
 
 
outputs-wit/events.out.tfevents.1626423408.t1v-n-cab111a8-w-0.820839.3.v2 DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:cc8d3f9c17dc59b0c63eb82d776844477760da16269693db8d13985d95516206
3
- size 476565
 
 
 
outputs-wit/summary.txt DELETED
@@ -1,21 +0,0 @@
1
- Epoch... (1/20 | Loss: 2.377821683883667, Learning Rate: 9.501103704678826e-06)
2
- Epoch... (1/20 | Eval Loss: 2.2912707328796387 | Eval rouge1: 17.1532 | Eval rouge2: 2.1991 | Eval rougeL: 12.1665 | Eval rougeLsum: 13.7971 | Eval gen_len: 58.3398 |)
3
- Predict Loss: 2.3000617027282715 | Predict rouge1: 17.3228 | Predict rouge2: 2.1974 | Predict rougeL: 12.2257 | Predict rougeLsum: 13.863 | Predict gen_len: 58.0887 |)
4
- Epoch... (2/20 | Loss: 2.318763017654419, Learning Rate: 9.001103535410948e-06)
5
- Epoch... (2/20 | Eval Loss: 2.2495603561401367 | Eval rouge1: 13.6938 | Eval rouge2: 0.963 | Eval rougeL: 10.0782 | Eval rougeLsum: 10.8405 | Eval gen_len: 58.3382 |)
6
- Predict Loss: 2.2592482566833496 | Predict rouge1: 13.7749 | Predict rouge2: 0.9371 | Predict rougeL: 10.0138 | Predict rougeLsum: 10.8695 | Predict gen_len: 58.1836 |)
7
- Epoch... (3/20 | Loss: 2.3419060707092285, Learning Rate: 8.501104275637772e-06)
8
- Epoch... (3/20 | Eval Loss: 2.22269344329834 | Eval rouge1: 12.0579 | Eval rouge2: 0.7251 | Eval rougeL: 9.092 | Eval rougeLsum: 9.3802 | Eval gen_len: 60.7578 |)
9
- Predict Loss: 2.233069896697998 | Predict rouge1: 12.5721 | Predict rouge2: 0.8881 | Predict rougeL: 9.4823 | Predict rougeLsum: 9.7638 | Predict gen_len: 60.5006 |)
10
- Epoch... (4/20 | Loss: 2.2800769805908203, Learning Rate: 8.001104106369894e-06)
11
- Epoch... (4/20 | Eval Loss: 2.2039794921875 | Eval rouge1: 14.2541 | Eval rouge2: 0.7585 | Eval rougeL: 10.3604 | Eval rougeLsum: 11.1679 | Eval gen_len: 60.3655 |)
12
- Predict Loss: 2.214798927307129 | Predict rouge1: 14.4009 | Predict rouge2: 0.8344 | Predict rougeL: 10.3895 | Predict rougeLsum: 11.2357 | Predict gen_len: 60.2483 |)
13
- Epoch... (5/20 | Loss: 2.220062494277954, Learning Rate: 7.501103482354665e-06)
14
- Epoch... (5/20 | Eval Loss: 2.1913952827453613 | Eval rouge1: 14.1698 | Eval rouge2: 0.8184 | Eval rougeL: 10.2918 | Eval rougeLsum: 11.245 | Eval gen_len: 60.1311 |)
15
- Predict Loss: 2.202223300933838 | Predict rouge1: 14.4567 | Predict rouge2: 0.9169 | Predict rougeL: 10.5117 | Predict rougeLsum: 11.3823 | Predict gen_len: 59.875 |)
16
- Epoch... (6/20 | Loss: 2.105752944946289, Learning Rate: 7.001103767834138e-06)
17
- Epoch... (6/20 | Eval Loss: 2.1800718307495117 | Eval rouge1: 14.6613 | Eval rouge2: 0.924 | Eval rougeL: 10.5021 | Eval rougeLsum: 11.672 | Eval gen_len: 61.7065 |)
18
- Predict Loss: 2.1911959648132324 | Predict rouge1: 14.972 | Predict rouge2: 0.9993 | Predict rougeL: 10.7166 | Predict rougeLsum: 11.843 | Predict gen_len: 61.8092 |)
19
- Epoch... (7/20 | Loss: 2.1191587448120117, Learning Rate: 6.50110359856626e-06)
20
- Epoch... (7/20 | Eval Loss: 2.1725244522094727 | Eval rouge1: 12.9676 | Eval rouge2: 1.1282 | Eval rougeL: 9.5649 | Eval rougeLsum: 10.702 | Eval gen_len: 59.4275 |)
21
- Predict Loss: 2.1837007999420166 | Predict rouge1: 13.161 | Predict rouge2: 1.1852 | Predict rougeL: 9.7344 | Predict rougeLsum: 10.9045 | Predict gen_len: 59.3945 |)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
run_summarization_coco.py → run_image_caption.py RENAMED
File without changes
run_summarization.py DELETED
@@ -1,832 +0,0 @@
1
- #!/usr/bin/env python
2
- # coding=utf-8
3
- # Copyright 2021 The HuggingFace Team All rights reserved.
4
- #
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
- # limitations under the License.
16
- """
17
- Fine-tuning the library models for summarization.
18
- """
19
- # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
20
-
21
- import sys, os
22
-
23
- current_path = os.path.dirname(os.path.abspath(__file__))
24
- sys.path.append(current_path)
25
-
26
- import logging
27
- import os
28
- import sys
29
- import time
30
- from dataclasses import dataclass, field
31
- from functools import partial
32
- from pathlib import Path
33
- from typing import Callable, Optional
34
-
35
- import datasets
36
- import nltk # Here to have a nice missing dependency error message early on
37
- import numpy as np
38
- from datasets import Dataset, load_dataset, load_metric
39
- from tqdm import tqdm
40
-
41
- import jax
42
- import jax.numpy as jnp
43
- import optax
44
- import transformers
45
- from filelock import FileLock
46
- from flax import jax_utils, traverse_util
47
- from flax.jax_utils import unreplicate
48
- from flax.training import train_state
49
- from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
50
- from transformers import (
51
- CONFIG_MAPPING,
52
- FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
53
- AutoConfig,
54
- AutoTokenizer,
55
- FlaxAutoModelForSeq2SeqLM,
56
- HfArgumentParser,
57
- TrainingArguments,
58
- is_tensorboard_available,
59
- )
60
- from transformers.file_utils import is_offline_mode
61
-
62
- from transformers import ViTFeatureExtractor, GPT2Tokenizer, GPT2Config
63
- from vit_gpt2.modeling_flax_vit_gpt2_lm import FlaxViTGPT2LMForConditionalGeneration
64
-
65
- logger = logging.getLogger(__name__)
66
-
67
- try:
68
- nltk.data.find("tokenizers/punkt")
69
- except (LookupError, OSError):
70
- if is_offline_mode():
71
- raise LookupError(
72
- "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files"
73
- )
74
- with FileLock(".lock") as lock:
75
- nltk.download("punkt", quiet=True)
76
-
77
-
78
- MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.keys())
79
- MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
80
-
81
-
82
- @dataclass
83
- class ModelArguments:
84
- """
85
- Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
86
- """
87
-
88
- model_name_or_path: Optional[str] = field(
89
- default=None,
90
- metadata={
91
- "help": "The model checkpoint for weights initialization."
92
- "Don't set if you want to train a model from scratch."
93
- },
94
- )
95
- model_type: Optional[str] = field(
96
- default=None,
97
- metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
98
- )
99
- config_name: Optional[str] = field(
100
- default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
101
- )
102
- tokenizer_name: Optional[str] = field(
103
- default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
104
- )
105
- cache_dir: Optional[str] = field(
106
- default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
107
- )
108
- use_fast_tokenizer: bool = field(
109
- default=True,
110
- metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
111
- )
112
- dtype: Optional[str] = field(
113
- default="float32",
114
- metadata={
115
- "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
116
- },
117
- )
118
-
119
-
120
- @dataclass
121
- class DataTrainingArguments:
122
- """
123
- Arguments pertaining to what data we are going to input our model for training and eval.
124
- """
125
-
126
- dataset_name: Optional[str] = field(
127
- default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
128
- )
129
- dataset_config_name: Optional[str] = field(
130
- default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
131
- )
132
- text_column: Optional[str] = field(
133
- default=None,
134
- metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
135
- )
136
- summary_column: Optional[str] = field(
137
- default=None,
138
- metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."},
139
- )
140
- train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
141
- validation_file: Optional[str] = field(
142
- default=None,
143
- metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
144
- )
145
- max_source_length: Optional[int] = field(
146
- default=1024,
147
- metadata={
148
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
149
- "than this will be truncated, sequences shorter will be padded."
150
- },
151
- )
152
- max_target_length: Optional[int] = field(
153
- default=128,
154
- metadata={
155
- "help": "The maximum total sequence length for target text after tokenization. Sequences longer "
156
- "than this will be truncated, sequences shorter will be padded."
157
- },
158
- )
159
- val_max_target_length: Optional[int] = field(
160
- default=None,
161
- metadata={
162
- "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
163
- "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
164
- "This argument is also used to override the `max_length` param of `model.generate`, which is used "
165
- "during evaluation."
166
- },
167
- )
168
- max_train_samples: Optional[int] = field(
169
- default=None,
170
- metadata={
171
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
172
- "value if set."
173
- },
174
- )
175
- max_eval_samples: Optional[int] = field(
176
- default=None,
177
- metadata={
178
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
179
- "value if set."
180
- },
181
- )
182
- max_predict_samples: Optional[int] = field(
183
- default=None,
184
- metadata={
185
- "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
186
- "value if set."
187
- },
188
- )
189
- preprocessing_num_workers: Optional[int] = field(
190
- default=None,
191
- metadata={"help": "The number of processes to use for the preprocessing."},
192
- )
193
- source_prefix: Optional[str] = field(
194
- default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
195
- )
196
- predict_with_generate: bool = field(
197
- default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."}
198
- )
199
- num_beams: Optional[int] = field(
200
- default=None,
201
- metadata={
202
- "help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`, "
203
- "which is used during evaluation."
204
- },
205
- )
206
- overwrite_cache: bool = field(
207
- default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
208
- )
209
-
210
- def __post_init__(self):
211
- if self.dataset_name is None and self.train_file is None and self.validation_file is None:
212
- raise ValueError("Need either a dataset name or a training/validation file.")
213
- else:
214
- if self.train_file is not None:
215
- extension = self.train_file.split(".")[-1]
216
- assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
217
- if self.validation_file is not None:
218
- extension = self.validation_file.split(".")[-1]
219
- assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
220
- if self.val_max_target_length is None:
221
- self.val_max_target_length = self.max_target_length
222
-
223
-
224
- summarization_name_mapping = {
225
- "amazon_reviews_multi": ("review_body", "review_title"),
226
- "big_patent": ("description", "abstract"),
227
- "cnn_dailymail": ("article", "highlights"),
228
- "orange_sum": ("text", "summary"),
229
- "pn_summary": ("article", "summary"),
230
- "psc": ("extract_text", "summary_text"),
231
- "samsum": ("dialogue", "summary"),
232
- "thaisum": ("body", "summary"),
233
- "xglue": ("news_body", "news_title"),
234
- "xsum": ("document", "summary"),
235
- "wiki_summary": ("article", "highlights"),
236
- }
237
-
238
-
239
- class TrainState(train_state.TrainState):
240
- dropout_rng: jnp.ndarray
241
-
242
- def replicate(self):
243
- return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
244
-
245
-
246
- def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False):
247
- """
248
- Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
249
- Shuffle batches if `shuffle` is `True`.
250
- """
251
- steps_per_epoch = len(dataset) // batch_size
252
-
253
- if shuffle:
254
- batch_idx = jax.random.permutation(rng, len(dataset))
255
- else:
256
- batch_idx = jnp.arange(len(dataset))
257
-
258
- batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch.
259
- batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
260
-
261
- for idx in batch_idx:
262
- batch = dataset[idx]
263
- batch = {k: jnp.array(v) for k, v in batch.items()}
264
-
265
- batch = shard(batch)
266
-
267
- yield batch
268
-
269
-
270
- def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
271
- summary_writer.scalar("train_time", train_time, step)
272
-
273
- train_metrics = get_metrics(train_metrics)
274
- for key, vals in train_metrics.items():
275
- tag = f"train_{key}"
276
- for i, val in enumerate(vals):
277
- summary_writer.scalar(tag, val, step - len(vals) + i + 1)
278
-
279
- for metric_name, value in eval_metrics.items():
280
- summary_writer.scalar(f"eval_{metric_name}", value, step)
281
-
282
-
283
- def create_learning_rate_fn(
284
- train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
285
- ) -> Callable[[int], jnp.array]:
286
- """Returns a linear warmup, linear_decay learning rate function."""
287
- steps_per_epoch = train_ds_size // train_batch_size
288
- num_train_steps = steps_per_epoch * num_train_epochs
289
- warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
290
- decay_fn = optax.linear_schedule(
291
- init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
292
- )
293
- schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
294
- return schedule_fn
295
-
296
-
297
- def main():
298
- # See all possible arguments in src/transformers/training_args.py
299
- # or by passing the --help flag to this script.
300
- # We now keep distinct sets of args, for a cleaner separation of concerns.
301
-
302
- parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
303
- if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
304
- # If we pass only one argument to the script and it's the path to a json file,
305
- # let's parse it to get our arguments.
306
- model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
307
- else:
308
- model_args, data_args, training_args = parser.parse_args_into_dataclasses()
309
-
310
- if (
311
- os.path.exists(training_args.output_dir)
312
- and os.listdir(training_args.output_dir)
313
- and training_args.do_train
314
- and not training_args.overwrite_output_dir
315
- ):
316
- raise ValueError(
317
- f"Output directory ({training_args.output_dir}) already exists and is not empty."
318
- "Use --overwrite_output_dir to overcome."
319
- )
320
-
321
- # Make one log on every process with the configuration for debugging.
322
- logging.basicConfig(
323
- format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
324
- datefmt="%m/%d/%Y %H:%M:%S",
325
- level=logging.INFO,
326
- )
327
- # Setup logging, we only want one process per machine to log things on the screen.
328
- logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
329
- if jax.process_index() == 0:
330
- datasets.utils.logging.set_verbosity_warning()
331
- transformers.utils.logging.set_verbosity_info()
332
- else:
333
- datasets.utils.logging.set_verbosity_error()
334
- transformers.utils.logging.set_verbosity_error()
335
-
336
- # Set the verbosity to info of the Transformers logger (on main process only):
337
- logger.info(f"Training/evaluation parameters {training_args}")
338
-
339
- # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
340
- # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
341
- # (the dataset will be downloaded automatically from the datasets Hub).
342
- #
343
- # For CSV/JSON files this script will use the first column for the full texts and the second column for the
344
- # summaries (unless you specify column names for this with the `text_column` and `summary_column` arguments).
345
- #
346
- if data_args.dataset_name is not None:
347
- # Downloading and loading a dataset from the hub.
348
- dataset = load_dataset(
349
- data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir, keep_in_memory=False, data_dir='./wit_data_dir/'
350
- )
351
- else:
352
- data_files = {}
353
- if data_args.train_file is not None:
354
- data_files["train"] = data_args.train_file
355
- extension = data_args.train_file.split(".")[-1]
356
- if data_args.validation_file is not None:
357
- data_files["validation"] = data_args.validation_file
358
- extension = data_args.validation_file.split(".")[-1]
359
- if data_args.test_file is not None:
360
- data_files["test"] = data_args.test_file
361
- extension = data_args.test_file.split(".")[-1]
362
- dataset = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
363
-
364
- vit_name_path = 'google/vit-base-patch16-224-in21k'
365
- gpt2_name_path = 'asi/gpt-fr-cased-small'
366
-
367
- gpt2_config = GPT2Config.from_pretrained(gpt2_name_path)
368
- gpt2_config.add_cross_attention = True
369
-
370
-
371
- vit_gpt2_name_path = ''
372
-
373
- feature_extractor = ViTFeatureExtractor.from_pretrained(vit_name_path)
374
-
375
- tokenizer = GPT2Tokenizer.from_pretrained(gpt2_name_path)
376
-
377
- if not vit_gpt2_name_path:
378
- assert vit_name_path
379
- assert gpt2_name_path
380
- vit_gpt2_model = FlaxViTGPT2LMForConditionalGeneration.from_vit_gpt2_pretrained(
381
- vit_name_path, gpt2_name_path
382
- )
383
- else:
384
- vit_gpt2_model = FlaxViTGPT2LMForConditionalGeneration.from_pretrained(
385
- vit_gpt2_name_path
386
- )
387
-
388
- model = vit_gpt2_model
389
- model.config.is_encoder_decoder = True
390
- model.config.decoder_start_token_id = gpt2_config.bos_token_id
391
- model.config.bos_token_id = gpt2_config.bos_token_id
392
- model.config.eos_token_id = gpt2_config.eos_token_id
393
- model.config.pad_token_id = gpt2_config.pad_token_id
394
-
395
- # Preprocessing the datasets.
396
- # We need to tokenize inputs and targets.
397
- if training_args.do_train:
398
- column_names = dataset["train"].column_names
399
- elif training_args.do_eval:
400
- column_names = dataset["validation"].column_names
401
- elif training_args.do_predict:
402
- column_names = dataset["test"].column_names
403
- else:
404
- logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
405
- return
406
-
407
- image_file_column = 'image_file'
408
- caption_column = 'caption'
409
- pixels_file_column = 'pixels_file'
410
-
411
- # Temporarily set max_target_length for training.
412
- max_target_length = data_args.max_target_length
413
-
414
- # In Flax, for seq2seq models we need to pass `decoder_input_ids`
415
- # as the Flax models don't accept `labels`, we need to prepare the decoder_input_ids here
416
- # for that dynamically import the `shift_tokens_right` function from the model file
417
- model_module = __import__(vit_gpt2_model.__module__, fromlist=["shift_tokens_right"])
418
- shift_tokens_right_fn = getattr(model_module, "shift_tokens_right")
419
-
420
- # Setting padding="max_length" as we need fixed length inputs for jitted functions
421
- def preprocess_function(examples):
422
-
423
- pixels_file = examples[pixels_file_column]
424
- if not pixels_file:
425
- assert examples[image_file_column]
426
- _pixel_values = []
427
- for y in examples[image_file_column]:
428
- with Image.open(y) as image:
429
- encoder_inputs = feature_extractor(images=image, return_tensors="np")
430
- x = encoder_inputs.pixel_values
431
- _pixel_values.append(x)
432
- pixel_values = np.concatenate(_pixel_values)
433
- else:
434
- pixel_values = np.concatenate([np.load(x) for x in pixels_file])
435
-
436
- targets = examples[caption_column]
437
-
438
- # Add eos_token!!
439
- targets = [x + ' ' + tokenizer.eos_token for x in targets]
440
-
441
- model_inputs = {}
442
- model_inputs['pixel_values'] = pixel_values
443
-
444
- # Setup the tokenizer for targets
445
- with tokenizer.as_target_tokenizer():
446
- labels = tokenizer(
447
- targets, max_length=max_target_length, padding="max_length", truncation=True, return_tensors="np"
448
- )
449
-
450
- model_inputs["labels"] = labels["input_ids"]
451
-
452
- #print(labels["input_ids"])
453
- #print(gpt2_config.pad_token_id)
454
- #rint(gpt2_config.bos_token_id)
455
-
456
- decoder_input_ids = shift_tokens_right_fn(
457
- jnp.array(labels["input_ids"]), gpt2_config.pad_token_id, gpt2_config.bos_token_id
458
- )
459
- model_inputs["input_ids"] = np.asarray(decoder_input_ids)
460
-
461
- # We need decoder_attention_mask so we can ignore pad tokens from loss
462
- model_inputs["attention_mask"] = labels["attention_mask"]
463
-
464
- return model_inputs
465
-
466
- if training_args.do_train:
467
- if "train" not in dataset:
468
- raise ValueError("--do_train requires a train dataset")
469
- train_dataset = dataset["train"]
470
- if data_args.max_train_samples is not None:
471
- train_dataset = train_dataset.select(range(data_args.max_train_samples))
472
-
473
- train_dataset = train_dataset.map(
474
- preprocess_function,
475
- batched=True,
476
- num_proc=data_args.preprocessing_num_workers,
477
- remove_columns=column_names,
478
- load_from_cache_file=not data_args.overwrite_cache,
479
- desc="Running tokenizer on train dataset",
480
- )
481
-
482
- if training_args.do_eval:
483
- max_target_length = data_args.val_max_target_length
484
- if "validation" not in dataset:
485
- raise ValueError("--do_eval requires a validation dataset")
486
- eval_dataset = dataset["validation"]
487
- if data_args.max_eval_samples is not None:
488
- eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
489
- eval_dataset = eval_dataset.map(
490
- preprocess_function,
491
- batched=True,
492
- num_proc=data_args.preprocessing_num_workers,
493
- remove_columns=column_names,
494
- load_from_cache_file=not data_args.overwrite_cache,
495
- desc="Running tokenizer on validation dataset",
496
- )
497
-
498
- if training_args.do_predict:
499
- max_target_length = data_args.val_max_target_length
500
- if "test" not in dataset:
501
- raise ValueError("--do_predict requires a test dataset")
502
- predict_dataset = dataset["test"]
503
- if data_args.max_predict_samples is not None:
504
- predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
505
- predict_dataset = predict_dataset.map(
506
- preprocess_function,
507
- batched=True,
508
- num_proc=data_args.preprocessing_num_workers,
509
- remove_columns=column_names,
510
- load_from_cache_file=not data_args.overwrite_cache,
511
- desc="Running tokenizer on prediction dataset",
512
- )
513
-
514
- # Metric
515
- metric = load_metric("rouge")
516
-
517
- def postprocess_text(preds, labels):
518
- preds = [pred.strip() for pred in preds]
519
- labels = [label.strip() for label in labels]
520
-
521
- # rougeLSum expects newline after each sentence
522
- preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
523
- labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]
524
-
525
- return preds, labels
526
-
527
- def compute_metrics(preds, labels):
528
- decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
529
- decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
530
-
531
- # Some simple post-processing
532
- decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
533
-
534
- result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
535
- # Extract a few results from ROUGE
536
- result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
537
-
538
- prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
539
- result["gen_len"] = np.mean(prediction_lens)
540
- result = {k: round(v, 4) for k, v in result.items()}
541
- return result
542
-
543
- # Enable tensorboard only on the master node
544
- has_tensorboard = is_tensorboard_available()
545
- if has_tensorboard and jax.process_index() == 0:
546
- try:
547
- from flax.metrics.tensorboard import SummaryWriter
548
-
549
- summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
550
- except ImportError as ie:
551
- has_tensorboard = False
552
- logger.warning(
553
- f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
554
- )
555
- else:
556
- logger.warning(
557
- "Unable to display metrics through TensorBoard because the package is not installed: "
558
- "Please run pip install tensorboard to enable."
559
- )
560
-
561
- # Initialize our training
562
- rng = jax.random.PRNGKey(training_args.seed)
563
- rng, dropout_rng = jax.random.split(rng)
564
-
565
- # Store some constant
566
- num_epochs = int(training_args.num_train_epochs)
567
- train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
568
- eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
569
- steps_per_epoch = len(train_dataset) // train_batch_size
570
- total_train_steps = steps_per_epoch * num_epochs
571
-
572
- # Create learning rate schedule
573
- linear_decay_lr_schedule_fn = create_learning_rate_fn(
574
- len(train_dataset),
575
- train_batch_size,
576
- training_args.num_train_epochs,
577
- training_args.warmup_steps,
578
- training_args.learning_rate,
579
- )
580
-
581
- # We use Optax's "masking" functionality to not apply weight decay
582
- # to bias and LayerNorm scale parameters. decay_mask_fn returns a
583
- # mask boolean with the same structure as the parameters.
584
- # The mask is True for parameters that should be decayed.
585
- # Note that this mask is specifically adapted for FlaxBart.
586
- # For FlaxT5, one should correct the layer norm parameter naming
587
- # accordingly - see `run_t5_mlm_flax.py` e.g.
588
- def decay_mask_fn(params):
589
- flat_params = traverse_util.flatten_dict(params)
590
- layer_norm_params = [
591
- (name, "scale") for name in ["self_attn_layer_norm", "layernorm_embedding", "final_layer_norm"]
592
- ]
593
- flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_params) for path in flat_params}
594
- return traverse_util.unflatten_dict(flat_mask)
595
-
596
- # create adam optimizer
597
- adamw = optax.adamw(
598
- learning_rate=linear_decay_lr_schedule_fn,
599
- b1=training_args.adam_beta1,
600
- b2=training_args.adam_beta2,
601
- eps=training_args.adam_epsilon,
602
- weight_decay=training_args.weight_decay,
603
- mask=decay_mask_fn,
604
- )
605
-
606
- # Setup train state
607
- state = TrainState.create(apply_fn=vit_gpt2_model.__call__, params=vit_gpt2_model.params, tx=adamw, dropout_rng=dropout_rng)
608
-
609
- # label smoothed cross entropy
610
- def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
611
- """
612
- The label smoothing implementation is adapted from Flax's official example:
613
- https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104
614
- """
615
- vocab_size = logits.shape[-1]
616
- confidence = 1.0 - label_smoothing_factor
617
- low_confidence = (1.0 - confidence) / (vocab_size - 1)
618
- normalizing_constant = -(
619
- confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
620
- )
621
- soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence)
622
-
623
- loss = optax.softmax_cross_entropy(logits, soft_labels)
624
- loss = loss - normalizing_constant
625
-
626
- # ignore padded tokens from loss
627
- loss = loss * padding_mask
628
- loss = loss.sum() / padding_mask.sum()
629
- return loss
630
-
631
- # Define gradient update step fn
632
- def train_step(state, batch, label_smoothing_factor=0.0):
633
- dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
634
-
635
- def compute_loss(params):
636
- labels = batch.pop("labels")
637
- logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
638
- loss = loss_fn(logits, labels, batch["attention_mask"], label_smoothing_factor)
639
- return loss
640
-
641
- grad_fn = jax.value_and_grad(compute_loss)
642
- loss, grad = grad_fn(state.params)
643
- grad = jax.lax.pmean(grad, "batch")
644
-
645
- new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
646
-
647
- metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
648
- metrics = jax.lax.pmean(metrics, axis_name="batch")
649
-
650
- return new_state, metrics
651
-
652
- # Define eval fn
653
- def eval_step(params, batch, label_smoothing_factor=0.0):
654
- labels = batch.pop("labels")
655
- logits = model(**batch, params=params, train=False)[0]
656
- loss = loss_fn(logits, labels, batch["attention_mask"], label_smoothing_factor)
657
-
658
- # summarize metrics
659
- metrics = {"loss": loss}
660
- metrics = jax.lax.pmean(metrics, axis_name="batch")
661
- return metrics
662
-
663
- # Define generation function
664
- max_length = (
665
- data_args.val_max_target_length if data_args.val_max_target_length is not None else model.config.max_length
666
- )
667
- num_beams = data_args.num_beams if data_args.num_beams is not None else model.config.num_beams
668
- gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
669
-
670
- def generate_step(params, batch):
671
- model.params = params
672
- # output_ids = model.generate(batch["pixel_values"], **gen_kwargs)
673
-
674
- #encoder_outputs = model.encode(pixel_values=batch['pixel_values'])
675
- #output_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], encoder_outputs=encoder_outputs, **gen_kwargs)
676
-
677
- # encoder_outputs = model.encode(pixel_values=batch['pixel_values'], params=params, train=False)
678
- output_ids = model.generate(batch['pixel_values'], **gen_kwargs)
679
-
680
-
681
- return output_ids.sequences
682
-
683
- # Create parallel version of the train and eval step
684
- p_train_step = jax.pmap(
685
- partial(train_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch", donate_argnums=(0,)
686
- )
687
- p_eval_step = jax.pmap(partial(eval_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch")
688
- p_generate_step = jax.pmap(generate_step, "batch")
689
-
690
- # Replicate the train state on each device
691
- state = state.replicate()
692
-
693
- logger.info("***** Running training *****")
694
- logger.info(f" Num examples = {len(train_dataset)}")
695
- logger.info(f" Num Epochs = {num_epochs}")
696
- logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
697
- logger.info(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
698
- logger.info(f" Total optimization steps = {total_train_steps}")
699
-
700
- train_time = 0
701
- epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
702
- for epoch in epochs:
703
- # ======================== Training ================================
704
- train_start = time.time()
705
-
706
- # Create sampling rng
707
- rng, input_rng = jax.random.split(rng)
708
- train_metrics = []
709
-
710
- # Generate an epoch by shuffling sampling indices from the train dataset
711
- train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
712
- steps_per_epoch = len(train_dataset) // train_batch_size
713
- # train
714
- for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
715
- batch = next(train_loader)
716
- state, train_metric = p_train_step(state, batch)
717
- train_metrics.append(train_metric)
718
-
719
- train_time += time.time() - train_start
720
-
721
- train_metric = unreplicate(train_metric)
722
-
723
- desc = f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
724
- epochs.write(desc)
725
- epochs.desc = desc
726
- logger.info(desc)
727
- with open(os.path.join(training_args.output_dir, f'report.txt'), 'a', encoding='UTF-8') as fp:
728
- fp.write(desc + '\n')
729
-
730
-
731
- # ======================== Evaluating ==============================
732
- eval_metrics = []
733
- eval_preds = []
734
- eval_labels = []
735
-
736
- eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
737
- eval_steps = len(eval_dataset) // eval_batch_size
738
- for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
739
- # Model forward
740
- batch = next(eval_loader)
741
- labels = batch["labels"]
742
-
743
- metrics = p_eval_step(state.params, batch)
744
- eval_metrics.append(metrics)
745
-
746
- # generation
747
- if data_args.predict_with_generate:
748
- generated_ids = p_generate_step(state.params, batch)
749
- eval_preds.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
750
- eval_labels.extend(jax.device_get(labels.reshape(-1, labels.shape[-1])))
751
-
752
- # normalize eval metrics
753
- eval_metrics = get_metrics(eval_metrics)
754
- eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
755
-
756
- # compute ROUGE metrics
757
- rouge_desc = ""
758
- if data_args.predict_with_generate:
759
- rouge_metrics = compute_metrics(eval_preds, eval_labels)
760
- eval_metrics.update(rouge_metrics)
761
- rouge_desc = " ".join([f"Eval {key}: {value} |" for key, value in rouge_metrics.items()])
762
-
763
- # Print metrics and update progress bar
764
- desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | {rouge_desc})"
765
- epochs.write(desc)
766
- epochs.desc = desc
767
- logger.info(desc)
768
- with open(os.path.join(training_args.output_dir, f'report.txt'), 'a', encoding='UTF-8') as fp:
769
- fp.write(desc + '\n')
770
-
771
-
772
- # Save metrics
773
- if has_tensorboard and jax.process_index() == 0:
774
- cur_step = epoch * (len(train_dataset) // train_batch_size)
775
- write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)
776
-
777
- # ======================== Prediction loop ==============================
778
- if training_args.do_predict:
779
- logger.info("*** Predict ***")
780
-
781
- pred_metrics = []
782
- pred_generations = []
783
- pred_labels = []
784
-
785
- pred_loader = data_loader(input_rng, predict_dataset, eval_batch_size)
786
- pred_steps = len(predict_dataset) // eval_batch_size
787
- for _ in tqdm(range(pred_steps), desc="Predicting...", position=2, leave=False):
788
- # Model forward
789
- batch = next(pred_loader)
790
- labels = batch["labels"]
791
-
792
- metrics = p_eval_step(state.params, batch)
793
- pred_metrics.append(metrics)
794
-
795
- # generation
796
- if data_args.predict_with_generate:
797
- generated_ids = p_generate_step(state.params, batch)
798
- pred_generations.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
799
- pred_labels.extend(jax.device_get(labels.reshape(-1, labels.shape[-1])))
800
-
801
- # normalize prediction metrics
802
- pred_metrics = get_metrics(pred_metrics)
803
- pred_metrics = jax.tree_map(jnp.mean, pred_metrics)
804
-
805
- # compute ROUGE metrics
806
- rouge_desc = ""
807
- if data_args.predict_with_generate:
808
- rouge_metrics = compute_metrics(pred_generations, pred_labels)
809
- pred_metrics.update(rouge_metrics)
810
- rouge_desc = " ".join([f"Predict {key}: {value} |" for key, value in rouge_metrics.items()])
811
-
812
- # Print metrics
813
- desc = f"Predict Loss: {pred_metrics['loss']} | {rouge_desc})"
814
- epochs.write(desc)
815
- epochs.desc = desc
816
- logger.info(desc)
817
- with open(os.path.join(training_args.output_dir, f'report.txt'), 'a', encoding='UTF-8') as fp:
818
- fp.write(desc + '\n')
819
-
820
-
821
- # save checkpoint after each epoch and push checkpoint to the hub
822
- if jax.process_index() == 0:
823
- params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
824
- model.save_pretrained(
825
- os.path.join(training_args.output_dir, f'ckpt_{epoch+1}'),
826
- params=params,
827
- push_to_hub=training_args.push_to_hub,
828
- commit_message=f"Saving weights and logs of epoch {epoch+1}",
829
- )
830
-
831
- if __name__ == "__main__":
832
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
test_wit_dataset_script.py DELETED
@@ -1,23 +0,0 @@
1
- import csv
2
- import json
3
- import os
4
-
5
- import datasets
6
- import pandas as pd
7
- import numpy as np
8
-
9
- ds = datasets.load_dataset('./wit_dataset_script.py', data_dir='./wit_data_dir/')
10
- test_ds = ds['test']
11
-
12
-
13
- def transform(example):
14
-
15
- example['pixel_values'] = np.load(example['pixels_file'])
16
- return example
17
-
18
-
19
- test_ds = test_ds.map(transform)
20
-
21
- for x in test_ds:
22
- print(x)
23
- break
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
wit_data_dir/dev/dev.tsv DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:ef1ecdcd132885a8f29c8707fad649431c6ff3d9bbd295d56b8520e7046c0eb7
3
- size 1418232
 
 
 
wit_data_dir/test/test.tsv DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:f0517292749005808b1d1d75343c76b8b16c3ed74fde030f7af8b611ad7b4d5d
3
- size 1406997
 
 
 
wit_data_dir/train/train.tsv DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:267de5cc6965e95f44795e78019ad9d0dfc648bee83a9fdb9cf9a92e8e4ac9d3
3
- size 45417661
 
 
 
wit_dataset_script.py DELETED
@@ -1,145 +0,0 @@
1
- import csv
2
- import json
3
- import os
4
-
5
- import datasets
6
- import pandas as pd
7
- import numpy as np
8
-
9
-
10
- # TODO: Add BibTeX citation
11
- # Find for instance the citation on arxiv or on the dataset repo/website
12
- _CITATION = """\
13
- @InProceedings{huggingface:dataset,
14
- title = {A great new dataset},
15
- author={huggingface, Inc.
16
- },
17
- year={2020}
18
- }
19
- """
20
-
21
- # TODO: Add description of the dataset here
22
- # You can copy an official description
23
- _DESCRIPTION = """\
24
- This new dataset is designed to solve this great NLP task and is crafted with a lot of care.
25
- """
26
-
27
- # TODO: Add a link to an official homepage for the dataset here
28
- _HOMEPAGE = ""
29
-
30
- # TODO: Add the licence for the dataset here if you can find it
31
- _LICENSE = ""
32
-
33
- # TODO: Add link to the official dataset URLs here
34
- # The HuggingFace dataset library don't host the datasets but only point to the original files
35
- # This can be an arbitrary nested dict/list of URLs (see below in `_split_generators` method)
36
- _URLs = {
37
- }
38
-
39
-
40
- # TODO: Name of the dataset usually match the script name with CamelCase instead of snake_case
41
- class WITDataset(datasets.GeneratorBasedBuilder):
42
- """TODO: Short description of my dataset."""
43
-
44
- VERSION = datasets.Version("1.1.0")
45
-
46
- DEFAULT_CONFIG_NAME = "en"
47
-
48
- def _info(self):
49
- # TODO: This method specifies the datasets.DatasetInfo object which contains informations and typings for the dataset
50
-
51
- features = datasets.Features(
52
- {
53
- "id": datasets.Value("int64"),
54
- "lang": datasets.Value("string"),
55
- "caption": datasets.Value("string"),
56
- "context": datasets.Value("string"),
57
- "image_url": datasets.Value("string"),
58
- "page_url": datasets.Value("string"),
59
- "image_file": datasets.Value("string"),
60
- "pixels_file": datasets.Value("string")
61
- # These are the features of your dataset like images, labels ...
62
- }
63
- )
64
-
65
- return datasets.DatasetInfo(
66
- # This is the description that will appear on the datasets page.
67
- description=_DESCRIPTION,
68
- # This defines the different columns of the dataset and their types
69
- features=features, # Here we define them above because they are different between the two configurations
70
- # If there's a common (input, target) tuple from the features,
71
- # specify them here. They'll be used if as_supervised=True in
72
- # builder.as_dataset.
73
- supervised_keys=None,
74
- # Homepage of the dataset for documentation
75
- homepage=_HOMEPAGE,
76
- # License for the dataset if available
77
- license=_LICENSE,
78
- # Citation for the dataset
79
- citation=_CITATION,
80
- )
81
-
82
- def _split_generators(self, dl_manager):
83
- """Returns SplitGenerators."""
84
- # TODO: This method is tasked with downloading/extracting the data and defining the splits depending on the configuration
85
- # If several configurations are possible (listed in BUILDER_CONFIGS), the configuration selected by the user is in self.config.name
86
-
87
- data_dir = self.config.data_dir
88
-
89
- return [
90
- datasets.SplitGenerator(
91
- name=datasets.Split.TRAIN,
92
- # These kwargs will be passed to _generate_examples
93
- gen_kwargs={
94
- "data_dir": os.path.join(data_dir, "train"),
95
- "split": "train",
96
- },
97
- ),
98
- datasets.SplitGenerator(
99
- name=datasets.Split.TEST,
100
- # These kwargs will be passed to _generate_examples
101
- gen_kwargs={
102
- "data_dir": os.path.join(data_dir, "test"),
103
- "split": "test"
104
- },
105
- ),
106
- datasets.SplitGenerator(
107
- name=datasets.Split.VALIDATION,
108
- # These kwargs will be passed to _generate_examples
109
- gen_kwargs={
110
- "data_dir": os.path.join(data_dir, "dev"),
111
- "split": "dev",
112
- },
113
- ),
114
- ]
115
-
116
- def _generate_examples(
117
- self, data_dir, split # method parameters are unpacked from `gen_kwargs` as given in `_split_generators`
118
- ):
119
- """ Yields examples as (key, example) tuples. """
120
- # This method handles input defined in _split_generators to yield (key, example) tuples from the dataset.
121
- # The `key` is here for legacy reason (tfds) and is not important in itself.
122
-
123
- df = pd.read_csv(os.path.join(data_dir, f'{split}.tsv'), sep='\t')
124
-
125
- for id_, row in df.iterrows():
126
-
127
- _id = row[0]
128
-
129
- # null caption and context
130
- if type(row[4]) != str or type(row[5]) != str:
131
- continue
132
-
133
- image_file = os.path.join(data_dir, 'images', f'{_id}.jpg')
134
- pixels_file = os.path.join(data_dir, 'numpy', f'{_id}.npy')
135
-
136
- yield id_, {
137
- "id": row[0],
138
- "lang": row[1],
139
- "caption": row[4],
140
- "context": row[5],
141
- "image_url": row[2],
142
- "page_url": row[3],
143
- "image_file": image_file,
144
- "pixels_file": pixels_file
145
- }