diff --git a/.gitattributes b/.gitattributes index 6d34772f5ca361021038b404fb913ec8dc0b1a5a..a19a1377366c871891f180ff0dc5448919e0328b 100644 --- a/.gitattributes +++ b/.gitattributes @@ -25,3 +25,36 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zstandard filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +coco_dataset/data/* filter=lfs diff=lfs merge=lfs -text +coco_dataset/data/**/ filter=lfs diff=lfs merge=lfs -text +coco_dataset/data/**/* filter=lfs diff=lfs merge=lfs -text +image_caption_dataset/coco_dataset/* filter=lfs diff=lfs merge=lfs -text +image_caption_dataset/coco_dataset/**/ filter=lfs diff=lfs merge=lfs -text +image_caption_dataset/coco_dataset/**/*/ filter=lfs diff=lfs merge=lfs -text +image_caption_dataset/coco_dataset/**" filter=lfs diff=lfs merge=lfs -text +image_caption_dataset/coco_dataset/**/*" filter=lfs diff=lfs merge=lfs -text +image_caption_dataset/coco_dataset/**/**" filter=lfs diff=lfs merge=lfs -text +image_caption_dataset/coco_dataset/**/**/* filter=lfs diff=lfs merge=lfs -text +image_caption_dataset/coco_dataset/**/**/**" filter=lfs diff=lfs merge=lfs -text +image_caption_dataset/coco_dataset/**/**/**/* filter=lfs diff=lfs merge=lfs -text +image_caption_dataset/coco_dataset/**/**/**/*" filter=lfs diff=lfs merge=lfs -text +image_caption_dataset/coco_dataset/**/**/**/**" filter=lfs diff=lfs merge=lfs -text +image_caption_dataset/coco_dataset/**/**/**/**/* filter=lfs diff=lfs merge=lfs -text +image_caption_dataset/coco_dataset/**/**/**/**/*" filter=lfs diff=lfs merge=lfs -text +image_caption_dataset/coco_dataset/**/**/**/**/**" filter=lfs diff=lfs merge=lfs -text +image_caption_dataset/coco_dataset/**/**/**/**/**/* filter=lfs diff=lfs merge=lfs -text +image_caption_dataset/coco_dataset/**/**/**/**/**/*" filter=lfs diff=lfs merge=lfs -text +image_caption_dataset/coco_dataset/**/**/**/**/**/**" filter=lfs diff=lfs merge=lfs -text +image_caption_dataset/coco_dataset/**/**/**/**/**/**/* filter=lfs diff=lfs merge=lfs -text +image_caption_dataset/coco_dataset/**/**/**/**/**/**/*" filter=lfs diff=lfs merge=lfs -text +image_caption_dataset/coco_dataset/**/**/**/**/**/**/**" filter=lfs diff=lfs merge=lfs -text +coco_dataset/dummy_data/* filter=lfs diff=lfs merge=lfs -text +coco_dataset/dummy_data/*" filter=lfs diff=lfs merge=lfs -text +coco_dataset/dummy_data/**" filter=lfs diff=lfs merge=lfs -text +coco_dataset/dummy_data/**/* filter=lfs diff=lfs merge=lfs -text +coco_dataset/dummy_data/**/*" filter=lfs diff=lfs merge=lfs -text +coco_dataset/dummy_data/**/**" filter=lfs diff=lfs merge=lfs -text +coco_dataset/dummy_data/**/**/* filter=lfs diff=lfs merge=lfs -text +coco_dataset/dummy_data/**/**/*" filter=lfs diff=lfs merge=lfs -text +coco_dataset/dummy_data/**/**/**" filter=lfs diff=lfs merge=lfs -text + diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..5793fe023435c70af4220533c9eadb0d3919eb77 --- /dev/null +++ b/README.md @@ -0,0 +1,41 @@ +## Example + +The model is by no means a state-of-the-art model, but nevertheless +produces reasonable image captioning results. It was mainly fine-tuned +as a proof-of-concept for the 🤗 FlaxVisionEncoderDecoder Framework. 
+ +The model can be used as follows: + +```python + +import requests +from PIL import Image + +from transformers import ViTFeatureExtractor, AutoTokenizer, FlaxVisionEncoderDecoderModel + +loc = "ydshieh/flax-vit-gpt2-coco-en" + +feature_extractor = ViTFeatureExtractor.from_pretrained(loc) +tokenizer = AutoTokenizer.from_pretrained(loc) +model = FlaxVisionEncoderDecoderModel.from_pretrained(loc) + +# We will verify our results on an image of cute cats +url = "http://images.cocodataset.org/val2017/000000039769.jpg" +with Image.open(requests.get(url, stream=True).raw) as img: + pixel_values = feature_extractor(images=img, return_tensors="np").pixel_values + +def generate_step(pixel_values): + + output_ids = model.generate(pixel_values, max_length=16, num_beams=4).sequences + preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True) + preds = [pred.strip() for pred in preds] + + return preds + +preds = generate_step(pixel_values) +print(preds) + +# should produce +# ['a cat laying on top of a couch next to another cat'] + +``` \ No newline at end of file diff --git a/coco_dataset/coco_dataset.py b/coco_dataset/coco_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..ee6d98d4801bf26d51627e41039b169687e0820b --- /dev/null +++ b/coco_dataset/coco_dataset.py @@ -0,0 +1,191 @@ +import csv +import json +import os + +import datasets +import pandas as pd +import numpy as np + + +class ImageCaptionBuilderConfig(datasets.BuilderConfig): + + def __init__(self, name, splits, **kwargs): + + super().__init__(name, **kwargs) + + self.splits = splits + + +# TODO: Add BibTeX citation +# Find for instance the citation on arxiv or on the dataset repo/website +_CITATION = """\ +@InProceedings{None, + title = {COCO dataset}, + author={...}, + year={...} +} +""" + +# TODO: Add description of the dataset here +# You can copy an official description +_DESCRIPTION = """\ + +""" + +# TODO: Add a link to an official homepage for the dataset here +_HOMEPAGE = "" + +# TODO: Add the licence for the dataset here if you can find it +_LICENSE = "" + +# TODO: Add link to the official dataset URLs here +# The HuggingFace dataset library don't host the datasets but only point to the original files +# This can be an arbitrary nested dict/list of URLs (see below in `_split_generators` method) +_URLs = {} + + +# TODO: Name of the dataset usually match the script name with CamelCase instead of snake_case +class ImageCaptionDataset(datasets.GeneratorBasedBuilder): + """TODO: Short description of my dataset.""" + + VERSION = datasets.Version("0.0.0") + + BUILDER_CONFIG_CLASS = ImageCaptionBuilderConfig + BUILDER_CONFIGS = [ + ImageCaptionBuilderConfig(name='2017', splits=['train', 'valid', 'test']), + ] + DEFAULT_CONFIG_NAME = "2017" + + def _info(self): + # TODO: This method specifies the datasets.DatasetInfo object which contains informations and typings for the dataset + + feature_dict = { + "image_id": datasets.Value("int64"), + "caption_id": datasets.Value("int64"), + "caption": datasets.Value("string"), + "height": datasets.Value("int64"), + "width": datasets.Value("int64"), + "file_name": datasets.Value("string"), + "coco_url": datasets.Value("string"), + "image_path": datasets.Value("string"), + } + + features = datasets.Features(feature_dict) + + return datasets.DatasetInfo( + # This is the description that will appear on the datasets page. 
+ description=_DESCRIPTION, + # This defines the different columns of the dataset and their types + features=features, # Here we define them above because they are different between the two configurations + # If there's a common (input, target) tuple from the features, + # specify them here. They'll be used if as_supervised=True in + # builder.as_dataset. + supervised_keys=None, + # Homepage of the dataset for documentation + homepage=_HOMEPAGE, + # License for the dataset if available + license=_LICENSE, + # Citation for the dataset + citation=_CITATION, + ) + + def _split_generators(self, dl_manager): + """Returns SplitGenerators.""" + # TODO: This method is tasked with downloading/extracting the data and defining the splits depending on the configuration + # If several configurations are possible (listed in BUILDER_CONFIGS), the configuration selected by the user is in self.config.name + + data_dir = self.config.data_dir + + splits = [] + for split in self.config.splits: + if split == 'train': + dataset = datasets.SplitGenerator( + name=datasets.Split.TRAIN, + # These kwargs will be passed to _generate_examples + gen_kwargs={ + "json_path": os.path.join(data_dir, f"captions_train{self.config.name}.json"), + "image_dir": os.path.join(data_dir, f'train{self.config.name}'), + "split": "train", + } + ) + elif split in ['val', 'valid', 'validation', 'dev']: + dataset = datasets.SplitGenerator( + name=datasets.Split.VALIDATION, + # These kwargs will be passed to _generate_examples + gen_kwargs={ + "json_path": os.path.join(data_dir, f"captions_val{self.config.name}.json"), + "image_dir": os.path.join(data_dir, f'val{self.config.name}'), + "split": "valid", + }, + ) + elif split == 'test': + dataset = datasets.SplitGenerator( + name=datasets.Split.TEST, + # These kwargs will be passed to _generate_examples + gen_kwargs={ + "json_path": os.path.join(data_dir, f'image_info_test{self.config.name}.json'), + "image_dir": os.path.join(data_dir, f'test{self.config.name}'), + "split": "test", + }, + ) + else: + continue + + splits.append(dataset) + + return splits + + def _generate_examples( + # method parameters are unpacked from `gen_kwargs` as given in `_split_generators` + self, json_path, image_dir, split + ): + """ Yields examples as (key, example) tuples. """ + # This method handles input defined in _split_generators to yield (key, example) tuples from the dataset. + # The `key` is here for legacy reason (tfds) and is not important in itself. 
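As an aside, the `gen_kwargs` assembled in `_split_generators` above imply a particular on-disk layout for `data_dir`. The sketch below only prints which of those paths exist; `/path/to/coco` is a placeholder, and `2017` is the builder's default config name:

```python
import os

data_dir = "/path/to/coco"  # placeholder: wherever the COCO annotation/image archives were extracted
config_name = "2017"        # default config of the builder above

# Files and directories that _split_generators looks for:
expected = [
    os.path.join(data_dir, f"captions_train{config_name}.json"),   # train annotations
    os.path.join(data_dir, f"train{config_name}"),                 # train images
    os.path.join(data_dir, f"captions_val{config_name}.json"),     # validation annotations
    os.path.join(data_dir, f"val{config_name}"),                   # validation images
    os.path.join(data_dir, f"image_info_test{config_name}.json"),  # test image info (no captions)
    os.path.join(data_dir, f"test{config_name}"),                  # test images
]
for path in expected:
    print(f"{path}: {'found' if os.path.exists(path) else 'missing'}")
```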
+ + _features = ["image_id", "caption_id", "caption", "height", "width", "file_name", "coco_url", "image_path", "id"] + features = list(_features) + + if split in "valid": + split = "val" + + with open(json_path, 'r', encoding='UTF-8') as fp: + data = json.load(fp) + + # list of dict + images = data["images"] + entries = images + + # build a dict of image_id -> image info dict + d = {image["id"]: image for image in images} + + # list of dict + if split in ["train", "val"]: + annotations = data["annotations"] + + # build a dict of image_id -> + for annotation in annotations: + _id = annotation["id"] + image_info = d[annotation["image_id"]] + annotation.update(image_info) + annotation["id"] = _id + + entries = annotations + + for id_, entry in enumerate(entries): + + entry = {k: v for k, v in entry.items() if k in features} + + if split == "test": + entry["image_id"] = entry["id"] + entry["id"] = -1 + entry["caption"] = -1 + + entry["caption_id"] = entry.pop("id") + entry["image_path"] = os.path.join(image_dir, entry["file_name"]) + + entry = {k: entry[k] for k in _features if k in entry} + + print(entry) + + yield str((entry["image_id"], entry["caption_id"])), entry diff --git a/coco_dataset/dummy_data/annotations_trainval2017.zip b/coco_dataset/dummy_data/annotations_trainval2017.zip new file mode 100644 index 0000000000000000000000000000000000000000..6a89f7420500b3568a5a6eec8ba83d15706426bc --- /dev/null +++ b/coco_dataset/dummy_data/annotations_trainval2017.zip @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:803f8cc8c2bd46633be6866aa0230d25916e0c2de6cad85bf11573fbe352efc6 +size 7458 diff --git a/coco_dataset/dummy_data/captions_train2017.json b/coco_dataset/dummy_data/captions_train2017.json new file mode 100644 index 0000000000000000000000000000000000000000..0352243262429b0fb062718361af0b6efc8a5f6a --- /dev/null +++ b/coco_dataset/dummy_data/captions_train2017.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a3c77d344bd683c0e9b235892a2bc6db096bcd470feb80a5928b98696739a34d +size 20225 diff --git a/coco_dataset/dummy_data/captions_val2017.json b/coco_dataset/dummy_data/captions_val2017.json new file mode 100644 index 0000000000000000000000000000000000000000..786fa33407330df382581f6883a34859d2d50c5e --- /dev/null +++ b/coco_dataset/dummy_data/captions_val2017.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d081d092d0c7f7d052ce5759f366c0ddfe743418ea4d26ccad6a9d91ae0e4f51 +size 20148 diff --git a/coco_dataset/dummy_data/image_info_test-dev2017.json b/coco_dataset/dummy_data/image_info_test-dev2017.json new file mode 100644 index 0000000000000000000000000000000000000000..1e2cd33849a00e3a0354637aa7b75471da805666 --- /dev/null +++ b/coco_dataset/dummy_data/image_info_test-dev2017.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:47b81764a5e1b843363a4ee847c12caa19f2f80fef0c43fb512110c575f27b61 +size 14535 diff --git a/coco_dataset/dummy_data/image_info_test2017.json b/coco_dataset/dummy_data/image_info_test2017.json new file mode 100644 index 0000000000000000000000000000000000000000..c8de8da0877e86d64eb9655d4c4c478ec2831e29 --- /dev/null +++ b/coco_dataset/dummy_data/image_info_test2017.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:eba76001d86881501288d8600812a43dcda2c3dd2f4bc46b84966ef45412ee80 +size 15439 diff --git a/coco_dataset/dummy_data/image_info_test2017.zip b/coco_dataset/dummy_data/image_info_test2017.zip new file mode 100644 index 
0000000000000000000000000000000000000000..03e4638446ebe78b18c37ca553c3d082a4156d05 --- /dev/null +++ b/coco_dataset/dummy_data/image_info_test2017.zip @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e30ca36eba1cb0a631bf78201c92f508be9e105777d0c8b763edbc8b8517d5a2 +size 3498 diff --git a/coco_dataset/dummy_data/test2017.zip b/coco_dataset/dummy_data/test2017.zip new file mode 100644 index 0000000000000000000000000000000000000000..53e3a5765873debc3a8fa39a22ef68269e9c1029 --- /dev/null +++ b/coco_dataset/dummy_data/test2017.zip @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:abdc12551acb154ed65c38f493ce930ad22fffd74c1b1e1f811a46212eb28e91 +size 3023535 diff --git a/coco_dataset/dummy_data/test2017/000000000001.jpg b/coco_dataset/dummy_data/test2017/000000000001.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e6434c6a0dd65a582adee0816b102d22c530d463 --- /dev/null +++ b/coco_dataset/dummy_data/test2017/000000000001.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:91d74bd2a576101fdc78a448f8f8c52816dda6ec2e37e0bffcb9e30e950647ae +size 159192 diff --git a/coco_dataset/dummy_data/test2017/000000000016.jpg b/coco_dataset/dummy_data/test2017/000000000016.jpg new file mode 100644 index 0000000000000000000000000000000000000000..46220df94f9513ed4b3dac64faae31e322a5408c --- /dev/null +++ b/coco_dataset/dummy_data/test2017/000000000016.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e77b3610a4d60decaf6399090e772f45f6bb69d13902e46f1aa67e6ffa44c05d +size 230884 diff --git a/coco_dataset/dummy_data/test2017/000000000019.jpg b/coco_dataset/dummy_data/test2017/000000000019.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ded45b4859638234671483d8990292c865ff6153 --- /dev/null +++ b/coco_dataset/dummy_data/test2017/000000000019.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bd645634cc290ad2cba0d0544aea02347edb8510dbc69586d367afc68bafd5ed +size 284532 diff --git a/coco_dataset/dummy_data/test2017/000000000057.jpg b/coco_dataset/dummy_data/test2017/000000000057.jpg new file mode 100644 index 0000000000000000000000000000000000000000..44af5c9449c3be1114eb698dc011b5ec562ce707 --- /dev/null +++ b/coco_dataset/dummy_data/test2017/000000000057.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2c00c990da26ba8b195d8e1ffaa1085cad53be95ec8de83c5d4941c1dd0912f2 +size 189943 diff --git a/coco_dataset/dummy_data/test2017/000000000063.jpg b/coco_dataset/dummy_data/test2017/000000000063.jpg new file mode 100644 index 0000000000000000000000000000000000000000..fe04bd1c5fe428bcfe606d79b4c7e1c0daf78053 --- /dev/null +++ b/coco_dataset/dummy_data/test2017/000000000063.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8e6eb85f9bce5a7622123bec2c2b62d182b35f162b1d04f7dc4597ab055ac6db +size 216259 diff --git a/coco_dataset/dummy_data/test2017/000000000069.jpg b/coco_dataset/dummy_data/test2017/000000000069.jpg new file mode 100644 index 0000000000000000000000000000000000000000..422984fc518e47c95606b3bfa437a51021183e20 --- /dev/null +++ b/coco_dataset/dummy_data/test2017/000000000069.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:47193896385f8d284f3e20dd457253d2e51d4bd77e7e3531e4731a5f131a4493 +size 112853 diff --git a/coco_dataset/dummy_data/test2017/000000000080.jpg b/coco_dataset/dummy_data/test2017/000000000080.jpg new file mode 100644 index 
0000000000000000000000000000000000000000..c3d504ecf85668c63d77ce486ef3f5f60e74e20a --- /dev/null +++ b/coco_dataset/dummy_data/test2017/000000000080.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:77393198302233c5df1f3b5bfe9e7fcd20af37402db4e29692d608aa0090b678 +size 113415 diff --git a/coco_dataset/dummy_data/test2017/000000000090.jpg b/coco_dataset/dummy_data/test2017/000000000090.jpg new file mode 100644 index 0000000000000000000000000000000000000000..9f78f62c92c171dbc4a027c13b68991b465c1c0f --- /dev/null +++ b/coco_dataset/dummy_data/test2017/000000000090.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9ff72dc5f69998b7777802f2d6c6dc261c67cf9f5ade8e0f94e030edd422467b +size 254598 diff --git a/coco_dataset/dummy_data/test2017/000000000106.jpg b/coco_dataset/dummy_data/test2017/000000000106.jpg new file mode 100644 index 0000000000000000000000000000000000000000..75af18f208ce7320070b56490b3549f80b3757fb --- /dev/null +++ b/coco_dataset/dummy_data/test2017/000000000106.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4797648ce66c98cd4d12945b314120aa85b73f32db6a33f5bd07d7a290e872bb +size 247602 diff --git a/coco_dataset/dummy_data/test2017/000000000108.jpg b/coco_dataset/dummy_data/test2017/000000000108.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3ff12338e33b9d75f4099b9eb930345db9a83d39 --- /dev/null +++ b/coco_dataset/dummy_data/test2017/000000000108.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9dac65dacd78f1dbd389025524ca74055f3441419ec694e54633f699ef4ae1ab +size 272953 diff --git a/coco_dataset/dummy_data/test2017/000000000128.jpg b/coco_dataset/dummy_data/test2017/000000000128.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c824ce67664121ed41d7925857397d4152ce1267 --- /dev/null +++ b/coco_dataset/dummy_data/test2017/000000000128.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4741971ce3b5f5d3cf46b91ecce5cde63bcad143045c9739ecd39fee8bcd1832 +size 212272 diff --git a/coco_dataset/dummy_data/test2017/000000000155.jpg b/coco_dataset/dummy_data/test2017/000000000155.jpg new file mode 100644 index 0000000000000000000000000000000000000000..20747f26b59134e5f74a0f030b9904804e2b0ae9 --- /dev/null +++ b/coco_dataset/dummy_data/test2017/000000000155.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:413ead1673ab92a827f3225de59b3100f591bb64f7886b31b692a712592ebe49 +size 79729 diff --git a/coco_dataset/dummy_data/test2017/000000000161.jpg b/coco_dataset/dummy_data/test2017/000000000161.jpg new file mode 100644 index 0000000000000000000000000000000000000000..609bd92b22647d7f3ddb1cf201d238593a96da43 --- /dev/null +++ b/coco_dataset/dummy_data/test2017/000000000161.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:22eaba0e35f9e452c8e208aabb570dcb5f92fe16d0d30cc28e44a22d613370d6 +size 146892 diff --git a/coco_dataset/dummy_data/test2017/000000000171.jpg b/coco_dataset/dummy_data/test2017/000000000171.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c0385bd89ae3ec8e2db39f4844891539e11321e0 --- /dev/null +++ b/coco_dataset/dummy_data/test2017/000000000171.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:51f983839ec43b66572f88b187c8da76902e2da62967cf1b7fc8f8444bf544f6 +size 61337 diff --git a/coco_dataset/dummy_data/test2017/000000000178.jpg b/coco_dataset/dummy_data/test2017/000000000178.jpg new file mode 100644 
index 0000000000000000000000000000000000000000..56ac19edf74e6dd6d19c7d7755103e49fb3f0868 --- /dev/null +++ b/coco_dataset/dummy_data/test2017/000000000178.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:47a086149db68c389925d97da1c56ca05e5375a3c29f418a61c5a155a5e6d9a2 +size 172866 diff --git a/coco_dataset/dummy_data/test2017/000000000180.jpg b/coco_dataset/dummy_data/test2017/000000000180.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c51367d3bec2f57550609948af505669db1e2e0b --- /dev/null +++ b/coco_dataset/dummy_data/test2017/000000000180.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:01a19909f472d5fd46e3604c164166b422e51995f06496b790a2e80576c90262 +size 285666 diff --git a/coco_dataset/dummy_data/train2017.zip b/coco_dataset/dummy_data/train2017.zip new file mode 100644 index 0000000000000000000000000000000000000000..4a86cf00a77de9410aeb4603b8ac5d43d07a4715 --- /dev/null +++ b/coco_dataset/dummy_data/train2017.zip @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:370a953a7907709727f53c989d8282a1c0e47fb6fc0900cebf1c67ba362c6e73 +size 3590323 diff --git a/coco_dataset/dummy_data/train2017/000000000009.jpg b/coco_dataset/dummy_data/train2017/000000000009.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c65ce6293c9d3963de3e63b9110234e3b027acf9 --- /dev/null +++ b/coco_dataset/dummy_data/train2017/000000000009.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:35cdfe8259aca40d564baf33ee749d82ce852446bd9574f0c47551d8bfffda99 +size 224297 diff --git a/coco_dataset/dummy_data/train2017/000000000025.jpg b/coco_dataset/dummy_data/train2017/000000000025.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e1f50b2d4109efb4dc89c9c534abd2fe220993c5 --- /dev/null +++ b/coco_dataset/dummy_data/train2017/000000000025.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d8f12a26d8803701cabac80494b080f998e5ed9bafaf61a2825ce6212c85487a +size 196370 diff --git a/coco_dataset/dummy_data/train2017/000000000030.jpg b/coco_dataset/dummy_data/train2017/000000000030.jpg new file mode 100644 index 0000000000000000000000000000000000000000..18b4f2d6c291f753a0d1cc69ee64ae18cd3c7555 --- /dev/null +++ b/coco_dataset/dummy_data/train2017/000000000030.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0444b10826d376ad9075805061405f6071a62b80eda29c5f284ed77b093d5b1d +size 71463 diff --git a/coco_dataset/dummy_data/train2017/000000000034.jpg b/coco_dataset/dummy_data/train2017/000000000034.jpg new file mode 100644 index 0000000000000000000000000000000000000000..2b2ac9b895eee1c2180810246d6e72a43fc9e94e --- /dev/null +++ b/coco_dataset/dummy_data/train2017/000000000034.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2c46871034fa901ae795a8bb916ba7f2f728507cab9e511cced0986bd083d193 +size 406018 diff --git a/coco_dataset/dummy_data/train2017/000000000036.jpg b/coco_dataset/dummy_data/train2017/000000000036.jpg new file mode 100644 index 0000000000000000000000000000000000000000..9018aca6cb016e9f8622afd493faf2178de3041c --- /dev/null +++ b/coco_dataset/dummy_data/train2017/000000000036.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7b04d9d0a1ea8b930e11f293a832bdfe8d43892fdb96f1038219196d41a86b95 +size 260207 diff --git a/coco_dataset/dummy_data/train2017/000000000042.jpg b/coco_dataset/dummy_data/train2017/000000000042.jpg new file mode 100644 index 
0000000000000000000000000000000000000000..cf3660610f0046ac45892e13bb6ff422a6a16b73 --- /dev/null +++ b/coco_dataset/dummy_data/train2017/000000000042.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9da8dbb3f415b0549a2c7ef8930e245af324191cc4c4465a7d75f05d96b0781d +size 213308 diff --git a/coco_dataset/dummy_data/train2017/000000000049.jpg b/coco_dataset/dummy_data/train2017/000000000049.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a61767d21e0e6c09d45ca4fd2e6cc4a3169e585a --- /dev/null +++ b/coco_dataset/dummy_data/train2017/000000000049.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0a29047d4bd83d26ae5370c03447bfa1ccd7005362ee180510e47af61746273a +size 124619 diff --git a/coco_dataset/dummy_data/train2017/000000000061.jpg b/coco_dataset/dummy_data/train2017/000000000061.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f22b08c280a1f114140e80f18453d0f1cca7330c --- /dev/null +++ b/coco_dataset/dummy_data/train2017/000000000061.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:03919a304272b466ac9415ecc8eebf7e6ec6d5a370e027ea8ed8c348e73d4024 +size 400343 diff --git a/coco_dataset/dummy_data/train2017/000000000064.jpg b/coco_dataset/dummy_data/train2017/000000000064.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d47059a181bafc49e6196e5a2cc5f6c0e6ad93e0 --- /dev/null +++ b/coco_dataset/dummy_data/train2017/000000000064.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:efb62bc395400df8dfe6ab5ea3577913ca029753628f37af4bcca8144fd67af5 +size 220869 diff --git a/coco_dataset/dummy_data/train2017/000000000071.jpg b/coco_dataset/dummy_data/train2017/000000000071.jpg new file mode 100644 index 0000000000000000000000000000000000000000..38e97afdd147910e9fe6f424bf95066255a13466 --- /dev/null +++ b/coco_dataset/dummy_data/train2017/000000000071.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c9e8680a3bdd1021d8cce453d91255df472e9c7cf79bf6312b80b5df8d7b665c +size 214185 diff --git a/coco_dataset/dummy_data/train2017/000000000072.jpg b/coco_dataset/dummy_data/train2017/000000000072.jpg new file mode 100644 index 0000000000000000000000000000000000000000..8b843084cff6a8f0eac7254deb856ef37a9ce2a2 --- /dev/null +++ b/coco_dataset/dummy_data/train2017/000000000072.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:34985f0e3d9e71652c47bce8813b6a1925f431c0925a094b37208a97df670a7b +size 239116 diff --git a/coco_dataset/dummy_data/train2017/000000000073.jpg b/coco_dataset/dummy_data/train2017/000000000073.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1570979f0727257b6a8b64c9004f1cccd6378bda --- /dev/null +++ b/coco_dataset/dummy_data/train2017/000000000073.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a202bf8a3327508e135aa7a2ba01014f67cd68253d2eed33dd7a5e1e24523aef +size 383651 diff --git a/coco_dataset/dummy_data/train2017/000000000074.jpg b/coco_dataset/dummy_data/train2017/000000000074.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ab09fb5437f5b7514a4fd1b5779cd61676e7a0f6 --- /dev/null +++ b/coco_dataset/dummy_data/train2017/000000000074.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:de6e1fbb6b8569bb85d2e82c9aa6cc528ac87fba11135c1f3691789a27fc13a5 +size 176151 diff --git a/coco_dataset/dummy_data/train2017/000000000077.jpg 
b/coco_dataset/dummy_data/train2017/000000000077.jpg new file mode 100644 index 0000000000000000000000000000000000000000..868906bac10d9d6e76f4e1565484b439674867d5 --- /dev/null +++ b/coco_dataset/dummy_data/train2017/000000000077.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c45ca68cf1c5af1efb6dc04df82e8821580e97862761914e665c7b965850b14d +size 159213 diff --git a/coco_dataset/dummy_data/train2017/000000000078.jpg b/coco_dataset/dummy_data/train2017/000000000078.jpg new file mode 100644 index 0000000000000000000000000000000000000000..0b83076ac2bbbf826be90fed3d750cabb019a4e5 --- /dev/null +++ b/coco_dataset/dummy_data/train2017/000000000078.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fbf237b294405624b5d662fe7f45a4c1b77bc275ade45f0737f5a6f6b1b6bb19 +size 209700 diff --git a/coco_dataset/dummy_data/train2017/000000000081.jpg b/coco_dataset/dummy_data/train2017/000000000081.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c244e2c615c897cf5052dff9224cca90ac261241 --- /dev/null +++ b/coco_dataset/dummy_data/train2017/000000000081.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1a57ac5980eff9d319871e17b5fe17363b5788861957d665cfe9e9c43fc9c27a +size 113261 diff --git a/coco_dataset/dummy_data/val2017.zip b/coco_dataset/dummy_data/val2017.zip new file mode 100644 index 0000000000000000000000000000000000000000..ba2cf905d5b854183f1dd8a5d662c79d29cc3713 --- /dev/null +++ b/coco_dataset/dummy_data/val2017.zip @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c3625319b0adc788f918dea4a32a3c551323fd883888e8a03668cb47dd823b94 +size 2556555 diff --git a/coco_dataset/dummy_data/val2017/000000000139.jpg b/coco_dataset/dummy_data/val2017/000000000139.jpg new file mode 100644 index 0000000000000000000000000000000000000000..fcc0ee640f200f4557b4b77fd2fd010cf1d1bc2c --- /dev/null +++ b/coco_dataset/dummy_data/val2017/000000000139.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ffe0f0cec3b2e27aab1967229cdf0a0d7751dcdd5800322f0b8ac0dffb3b8a8d +size 161811 diff --git a/coco_dataset/dummy_data/val2017/000000000285.jpg b/coco_dataset/dummy_data/val2017/000000000285.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d6fac3296dfd30328248b229bcea70f1d6691119 --- /dev/null +++ b/coco_dataset/dummy_data/val2017/000000000285.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f3a2974ce3686332609124c70e3e6a2e3aca43fccf1cd1bd7c5c03820977f57d +size 335861 diff --git a/coco_dataset/dummy_data/val2017/000000000632.jpg b/coco_dataset/dummy_data/val2017/000000000632.jpg new file mode 100644 index 0000000000000000000000000000000000000000..8a13832dd7305067b5d33cfa7ce6d695c3972347 --- /dev/null +++ b/coco_dataset/dummy_data/val2017/000000000632.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a4cd7f45ac1ce27eaafb254b23af7c0b18a064be08870ceaaf03b2147f2ce550 +size 155667 diff --git a/coco_dataset/dummy_data/val2017/000000000724.jpg b/coco_dataset/dummy_data/val2017/000000000724.jpg new file mode 100644 index 0000000000000000000000000000000000000000..5c43ddad3428ab43630dea7797807454e64f6d80 --- /dev/null +++ b/coco_dataset/dummy_data/val2017/000000000724.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5c0e559c75d3969c8e3e297b61f61063f78045c9d4802b526ba616361f3823fd +size 130107 diff --git a/coco_dataset/dummy_data/val2017/000000000776.jpg 
b/coco_dataset/dummy_data/val2017/000000000776.jpg new file mode 100644 index 0000000000000000000000000000000000000000..5138e0a56d766e85e4a40578a2908e1731154c3e --- /dev/null +++ b/coco_dataset/dummy_data/val2017/000000000776.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1dd31e9059c491992be2f562624eb4093e17aee08b4f7baf5ff9ea24543b0a33 +size 176410 diff --git a/coco_dataset/dummy_data/val2017/000000000785.jpg b/coco_dataset/dummy_data/val2017/000000000785.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d1d462ab49b9ec885e390c8c84a5c5e4527ef141 --- /dev/null +++ b/coco_dataset/dummy_data/val2017/000000000785.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:83981537a7baeafbeb9c8cb67b3484dc26433f574b3685d021fa537e277e4726 +size 133674 diff --git a/coco_dataset/dummy_data/val2017/000000000802.jpg b/coco_dataset/dummy_data/val2017/000000000802.jpg new file mode 100644 index 0000000000000000000000000000000000000000..cf2e6cd173050412836b19118275aba16ce6736e --- /dev/null +++ b/coco_dataset/dummy_data/val2017/000000000802.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d5b79e7fa716f85ca86f46e5f518da9b6c5e26414925624d76e6476861aec495 +size 62406 diff --git a/coco_dataset/dummy_data/val2017/000000000872.jpg b/coco_dataset/dummy_data/val2017/000000000872.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a2c7e4a2a048db3163173011ccb93c49b1de483f --- /dev/null +++ b/coco_dataset/dummy_data/val2017/000000000872.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c2aa138ee3a59b057a7ba6fc5a6a18e62af531aa7dab78a7bfd33c1cd7e55eb6 +size 317669 diff --git a/coco_dataset/dummy_data/val2017/000000000885.jpg b/coco_dataset/dummy_data/val2017/000000000885.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3ded3f060579f826dae2b9c7a71ee62453679689 --- /dev/null +++ b/coco_dataset/dummy_data/val2017/000000000885.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1c67b783f78b18ddb5dd20aaba4d4aacaf98e5277051688a96da096afa9cdf1f +size 111441 diff --git a/coco_dataset/dummy_data/val2017/000000001000.jpg b/coco_dataset/dummy_data/val2017/000000001000.jpg new file mode 100644 index 0000000000000000000000000000000000000000..4d2aa4b1ac4940208a8e42cd490ab5bb32429109 --- /dev/null +++ b/coco_dataset/dummy_data/val2017/000000001000.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:24bb77a31928404e45a0454b06f6a0bd54a8db103590c9d7917288a1e0269f05 +size 321136 diff --git a/coco_dataset/dummy_data/val2017/000000001268.jpg b/coco_dataset/dummy_data/val2017/000000001268.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c6ddf5be5f14536c46129b3444e430531163a4ea --- /dev/null +++ b/coco_dataset/dummy_data/val2017/000000001268.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8e85a2e71ee512fe748a3a3be9bb40b8591e8c501c0f3896efcb5be5999ec236 +size 180751 diff --git a/coco_dataset/dummy_data/val2017/000000001296.jpg b/coco_dataset/dummy_data/val2017/000000001296.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d28461589ff55ec552d4833770fc547fa66f3c1e --- /dev/null +++ b/coco_dataset/dummy_data/val2017/000000001296.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:34a98da7c11bc8811e6eec445145bb16f3d5f4338945ba116edcc1fe554d181b +size 204829 diff --git a/coco_dataset/dummy_data/val2017/000000001353.jpg 
b/coco_dataset/dummy_data/val2017/000000001353.jpg new file mode 100644 index 0000000000000000000000000000000000000000..be7f8c3ae472f6123ec2e932d6be67a4c3db8f73 --- /dev/null +++ b/coco_dataset/dummy_data/val2017/000000001353.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1cccef217a1e5eec2acc74d7c08a105da3ef6fdb8514e221cdea9b0b373b60ea +size 83033 diff --git a/coco_dataset/dummy_data/val2017/000000001425.jpg b/coco_dataset/dummy_data/val2017/000000001425.jpg new file mode 100644 index 0000000000000000000000000000000000000000..05e68673ae997394ecd2b92f7bc99c384577a182 --- /dev/null +++ b/coco_dataset/dummy_data/val2017/000000001425.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:71a37c25cdf944cf0a88b66d00a66cf899a88bd55c7ac80fcf1c76329d617e6c +size 111309 diff --git a/coco_dataset/dummy_data/val2017/000000001490.jpg b/coco_dataset/dummy_data/val2017/000000001490.jpg new file mode 100644 index 0000000000000000000000000000000000000000..fb7baf9b0be777cd2c15673f7a479e8554616bac --- /dev/null +++ b/coco_dataset/dummy_data/val2017/000000001490.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a044804a0b0186a1267acae34975932f14ede96c4a2b78fcee74e6c2fbd8414d +size 101098 diff --git a/coco_dataset/dummy_data/val2017/000000001503.jpg b/coco_dataset/dummy_data/val2017/000000001503.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d5b25450971ea28d4f5b83bb62a4cdefd5cf3453 --- /dev/null +++ b/coco_dataset/dummy_data/val2017/000000001503.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:06c721f0a62bcb9fcb2d2d0cc0b9edd63092084a8773a18c4b5a7bd6fb4f095d +size 15259 diff --git a/coco_dataset/usage_example.py b/coco_dataset/usage_example.py new file mode 100644 index 0000000000000000000000000000000000000000..6e8013fc457ffd1918eea8228ebd3558a2cfd0ce --- /dev/null +++ b/coco_dataset/usage_example.py @@ -0,0 +1,32 @@ +from datasets import load_dataset +from PIL import Image +import os + + +dataset_name = "coco_dataset.py" +dataset_config_name = "2017" +cache_dir = None +keep_in_memory = False +data_dir = "./dummy_data/" + +dataset = load_dataset( + dataset_name, dataset_config_name, cache_dir=cache_dir, keep_in_memory=keep_in_memory, data_dir=data_dir +) + +for example in dataset["train"]: + print(example) + with Image.open(os.path.join(example['image_path'])) as image: + image.show() + break + +for _idx, example in enumerate(dataset["validation"]): + print(example) + with Image.open(os.path.join(example['image_path'])) as image: + image.show() + break + +for _idx, example in enumerate(dataset["test"]): + print(example) + with Image.open(os.path.join(example['image_path'])) as image: + image.show() + break diff --git a/create_dummy_pretrained_models.py b/create_dummy_pretrained_models.py new file mode 100644 index 0000000000000000000000000000000000000000..adbb5d6c77ea52bd8c3cc7fa671806e9d01b9b98 --- /dev/null +++ b/create_dummy_pretrained_models.py @@ -0,0 +1,93 @@ +from transformers import ViTConfig, FlaxViTModel, GPT2Config, FlaxGPT2Model, FlaxAutoModelForVision2Seq, FlaxVisionEncoderDecoderModel, ViTFeatureExtractor, GPT2Tokenizer + +print(1) + +hidden_size = 8 +num_hidden_layers = 2 +num_attention_heads = 2 +intermediate_size = 16 + +n_embd = 8 +n_layer = 2 +n_head = 2 +n_inner = 16 + +encoder_config = ViTConfig( + hidden_size=hidden_size, + num_hidden_layers=num_hidden_layers, + num_attention_heads=num_attention_heads, + intermediate_size=intermediate_size, +) +decoder_config = 
GPT2Config( + n_embd=n_embd, + n_layer=n_layer, + n_head=n_head, + n_inner=n_inner, +) + + +print(2) + +encoder = FlaxViTModel(encoder_config) +decoder = FlaxGPT2Model(decoder_config) +encoder.save_pretrained("./encoder-decoder/encoder") +decoder.save_pretrained("./encoder-decoder/decoder") + +print(3) + +enocder_decoder = FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained( + "./encoder-decoder/encoder", + "./encoder-decoder/decoder", +) +enocder_decoder.save_pretrained("./encoder-decoder") +enocder_decoder = FlaxAutoModelForVision2Seq.from_pretrained("./encoder-decoder") + +print(4) + +config = enocder_decoder.config + +decoder_start_token_id = getattr(config, "decoder_start_token_id", None) +if not decoder_start_token_id and getattr(config, "decoder", None): + decoder_start_token_id = getattr(config.decoder, "decoder_start_token_id", None) +bos_token_id = getattr(config, "bos_token_id", None) +if not bos_token_id and getattr(config, "decoder", None): + bos_token_id = getattr(config.decoder, "bos_token_id", None) +eos_token_id = getattr(config, "eos_token_id", None) +if not eos_token_id and getattr(config, "decoder", None): + eos_token_id = getattr(config.decoder, "eos_token_id", None) +pad_token_id = getattr(config, "pad_token_id", None) +if not pad_token_id and getattr(config, "decoder", None): + pad_token_id = getattr(config.decoder, "pad_token_id", None) + +if decoder_start_token_id is None: + decoder_start_token_id = bos_token_id +if pad_token_id is None: + pad_token_id = eos_token_id + +config.decoder_start_token_id = decoder_start_token_id +config.bos_token_id = bos_token_id +config.eos_token_id = eos_token_id +config.pad_token_id = pad_token_id + +if getattr(config, "decoder", None): + config.decoder.decoder_start_token_id = decoder_start_token_id + config.decoder.bos_token_id = bos_token_id + config.decoder.eos_token_id = eos_token_id + config.decoder.pad_token_id = pad_token_id + +fe = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k") +tokenizer = GPT2Tokenizer.from_pretrained("gpt2") +tokenizer.pad_token = tokenizer.convert_ids_to_tokens(config.pad_token_id) + +fe.save_pretrained("./encoder-decoder/encoder") +tokenizer.save_pretrained("./encoder-decoder/decoder") + +targets = ['i love dog', 'you cat is very cute'] + +# Setup the tokenizer for targets +with tokenizer.as_target_tokenizer(): + labels = tokenizer( + targets, max_length=8, padding="max_length", truncation=True, return_tensors="np" + ) + + print(labels) diff --git a/run_image_captioning_flax.py b/run_image_captioning_flax.py new file mode 100644 index 0000000000000000000000000000000000000000..c81a7522e70f3b8001bebe6d4cc9c5688204bfd0 --- /dev/null +++ b/run_image_captioning_flax.py @@ -0,0 +1,1224 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2021 The HuggingFace Team All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Fine-tuning the library models for summarization. +""" +# You can also adapt this script on your own sequence to sequence task. 
Pointers for this are left as comments. +import json +import logging +import os +import sys +import time +from dataclasses import dataclass, field +import datetime +from functools import partial +from pathlib import Path +from typing import Callable, Optional + +import datasets +import nltk # Here to have a nice missing dependency error message early on +import numpy as np +from datasets import Dataset, load_dataset, load_metric +from tqdm import tqdm +from PIL import Image + +import jax +import jax.numpy as jnp +import optax +import transformers +from filelock import FileLock +from flax import jax_utils, traverse_util +from flax.jax_utils import unreplicate +from flax.training import train_state +from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key +from huggingface_hub import Repository +from transformers import ( + CONFIG_MAPPING, + FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING, + AutoConfig, + AutoFeatureExtractor, + AutoTokenizer, + HfArgumentParser, + TrainingArguments, + is_tensorboard_available, + FlaxAutoModelForVision2Seq, +) +from transformers.file_utils import get_full_repo_name, is_offline_mode + + +logger = logging.getLogger(__name__) + +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +try: + nltk.data.find("tokenizers/punkt") +except (LookupError, OSError): + if is_offline_mode(): + raise LookupError( + "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files" + ) + with FileLock(".lock") as lock: + nltk.download("punkt", quiet=True) + + +MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING.keys()) +MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) + + +# Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right +def shift_tokens_right(input_ids: np.ndarray, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray: + """ + Shift input ids one token to the right. + """ + shifted_input_ids = np.zeros_like(input_ids) + shifted_input_ids[:, 1:] = input_ids[:, :-1] + shifted_input_ids[:, 0] = decoder_start_token_id + + shifted_input_ids = np.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids) + return shifted_input_ids + + +@dataclass +class CustomTrainingArguments(TrainingArguments): + + do_predict_during_training: bool = field(default=None, metadata={"help": "???"}) + do_predict_after_evaluation: bool = field(default=None, metadata={"help": "???"}) + block_size: int = field(default=None, metadata={"help": "???"}) + + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch. + """ + + model_name_or_path: Optional[str] = field( + default=None, + metadata={ + "help": "The model checkpoint for weights initialization." + "Don't set if you want to train a model from scratch." + }, + ) + encoder_model_name_or_path: Optional[str] = field( + default=None, + metadata={ + "help": "The encoder model checkpoint for weights initialization." + "Don't set if you want to train a model from scratch." + }, + ) + decoder_model_name_or_path: Optional[str] = field( + default=None, + metadata={ + "help": "The decoder model checkpoint for weights initialization." + "Don't set if you want to train a model from scratch." 
+ }, + ) + model_type: Optional[str] = field( + default='vision-encoder-decoder', + metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)}, + ) + encoder_model_type: Optional[str] = field( + default=None, + metadata={"help": "If training from scratch, pass a encoder model type from the list: " + ", ".join(MODEL_TYPES)}, + ) + decoder_model_type: Optional[str] = field( + default=None, + metadata={"help": "If training from scratch, pass a decoder model type from the list: " + ", ".join(MODEL_TYPES)}, + ) + config_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} + ) + encoder_config_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained config name or path if not the same as encoder_model_name"} + ) + decoder_config_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained config name or path if not the same as decoder_model_name"} + ) + feature_extractor_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained feature extractor_name name or path if not the same as encoder_model_name"} + ) + tokenizer_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as decoder_model_name"} + ) + cache_dir: Optional[str] = field( + default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"} + ) + use_fast_tokenizer: bool = field( + default=True, + metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, + ) + dtype: Optional[str] = field( + default="float32", + metadata={ + "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`." + }, + ) + + +@dataclass +class DataTrainingArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. + """ + + dataset_name: Optional[str] = field( + default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} + ) + dataset_config_name: Optional[str] = field( + default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} + ) + data_dir: Optional[str] = field( + default=None, metadata={"help": "The data directory of the dataset to use (via the datasets library)."} + ) + image_column: Optional[str] = field( + default=None, + metadata={"help": "The name of the column in the datasets containing the full image file paths (for image captioning)."}, + ) + caption_column: Optional[str] = field( + default=None, + metadata={"help": "The name of the column in the datasets containing the image captions (for image captioning)."}, + ) + train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."}) + validation_file: Optional[str] = field( + default=None, + metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."}, + ) + test_file: Optional[str] = field( + default=None, + metadata={"help": "An optional input predict data file to do prediction on (a text file)."}, + ) + max_source_length: Optional[int] = field( + default=1024, + metadata={ + "help": "The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded." 
+ }, + ) + max_target_length: Optional[int] = field( + default=128, + metadata={ + "help": "The maximum total sequence length for target text after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded." + }, + ) + val_max_target_length: Optional[int] = field( + default=None, + metadata={ + "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`." + "This argument is also used to override the `max_length` param of `model.generate`, which is used " + "during evaluation." + }, + ) + max_train_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + }, + ) + max_eval_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " + "value if set." + }, + ) + max_predict_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this " + "value if set." + }, + ) + preprocessing_num_workers: Optional[int] = field( + default=None, + metadata={"help": "The number of processes to use for the preprocessing."}, + ) + predict_with_generate: bool = field( + default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."} + ) + num_beams: Optional[int] = field( + default=None, + metadata={ + "help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`, " + "which is used during evaluation." + }, + ) + overwrite_cache: bool = field( + default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} + ) + + def __post_init__(self): + if self.dataset_name is None and self.train_file is None and self.validation_file is None: + raise ValueError("Need either a dataset name or a training/validation file.") + else: + if self.train_file is not None: + extension = self.train_file.split(".")[-1] + assert extension in ["csv", "json"], "`train_file` should be a csv or a json file." + if self.validation_file is not None: + extension = self.validation_file.split(".")[-1] + assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." + if self.val_max_target_length is None: + self.val_max_target_length = self.max_target_length + + +image_captioning_name_mapping = { + "image_caption_dataset.py": ("image_file", "caption"), +} + + +class TrainState(train_state.TrainState): + dropout_rng: jnp.ndarray + + def replicate(self): + return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng)) + + +def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False): + """ + Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices. + Shuffle batches if `shuffle` is `True`. + """ + steps_per_epoch = len(dataset) // batch_size + + if shuffle: + batch_idx = jax.random.permutation(rng, len(dataset)) + else: + batch_idx = jnp.arange(len(dataset)) + + batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch. 
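+ # Group the (possibly shuffled) indices into one row per training step; each row
+ # becomes one batch below, and `shard` splits that batch across the local devices
+ # (its leading dimension becomes jax.local_device_count()).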
+ batch_idx = batch_idx.reshape((steps_per_epoch, batch_size)) + + for idx in batch_idx: + batch = dataset[idx] + batch = {k: jnp.array(v) for k, v in batch.items()} + + batch = shard(batch) + + yield batch + + +def write_metric(summary_writer, mode, metrics, step, train_time=None): + + if train_time: + summary_writer.scalar("train_time", train_time, step) + + if mode == "train": + metrics = get_metrics(metrics) + for key, vals in metrics.items(): + tag = f"{mode}_{key}" + for i, val in enumerate(vals): + summary_writer.scalar(tag, val, step - len(vals) + i + 1) + + elif mode in ["valid", "pred"]: + for metric_name, value in metrics.items(): + summary_writer.scalar(f"{mode}_{metric_name}", value, step) + + +def create_learning_rate_fn( + train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float +) -> Callable[[int], jnp.array]: + """Returns a linear warmup, linear_decay learning rate function.""" + steps_per_epoch = train_ds_size // train_batch_size + num_train_steps = steps_per_epoch * num_train_epochs + warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps) + decay_fn = optax.linear_schedule( + init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps + ) + schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]) + return schedule_fn + + +def main(): + # See all possible arguments in src/transformers/training_args.py + # or by passing the --help flag to this script. + # We now keep distinct sets of args, for a cleaner separation of concerns. + + parser = HfArgumentParser((ModelArguments, DataTrainingArguments, CustomTrainingArguments)) + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + # If we pass only one argument to the script and it's the path to a json file, + # let's parse it to get our arguments. + model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + if ( + os.path.exists(training_args.output_dir) + and os.listdir(training_args.output_dir) + and training_args.do_train + and not training_args.overwrite_output_dir + ): + raise ValueError( + f"Output directory ({training_args.output_dir}) already exists and is not empty." + "Use --overwrite_output_dir to overcome." + ) + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + # Setup logging, we only want one process per machine to log things on the screen. 
+ logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR) + if jax.process_index() == 0: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + + # Set the verbosity to info of the Transformers logger (on main process only): + logger.info(f"Training/evaluation parameters {training_args}") + + # Handle the repository creation + if training_args.push_to_hub: + if training_args.hub_model_id is None: + repo_name = get_full_repo_name( + Path(training_args.output_dir).absolute().name, token=training_args.hub_token + ) + else: + repo_name = training_args.hub_model_id + repo = Repository(training_args.output_dir, clone_from=repo_name) + + # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below) + # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ + # (the dataset will be downloaded automatically from the datasets Hub). + # + # For CSV/JSON files this script will use the first column for the full texts and the second column for the + # summaries (unless you specify column names for this with the `text_column` and `summary_column` arguments). + # + if data_args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + dataset = load_dataset( + data_args.dataset_name, data_args.dataset_config_name, keep_in_memory=False, data_dir=data_args.data_dir, + cache_dir="./dataset_cache/" + ) + else: + data_files = {} + if data_args.train_file is not None: + data_files["train"] = data_args.train_file + extension = data_args.train_file.split(".")[-1] + if data_args.validation_file is not None: + data_files["validation"] = data_args.validation_file + extension = data_args.validation_file.split(".")[-1] + if data_args.test_file is not None: + data_files["test"] = data_args.test_file + extension = data_args.test_file.split(".")[-1] + # TODO: Check + dataset = load_dataset(extension, data_files=data_files, cache_dir="./dataset_cache/", data_dir=data_args.data_dir, ) + # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at + # https://huggingface.co/docs/datasets/loading_datasets.html. + + # Load pretrained model and tokenizer + + encoder_cache_dir, decoder_cache_dir = None, None + if model_args.cache_dir: + encoder_cache_dir = os.path.join(model_args.cache_dir, "encoder") + decoder_cache_dir = os.path.join(model_args.cache_dir, "decoder") + + # Use explicit specified config + if model_args.config_name: + config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir) + # Use pretrained model's config + elif model_args.model_name_or_path: + config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir) + # Use specified `model_type` (default to `vision-encoder-decoder`) + else: + + if not model_args.model_type in MODEL_TYPES: + raise ValueError( + f"Unrecognized model identifier: {model_args.model_type}. Should contain one of {', '.join(MODEL_TYPES)}." 
+ ) + config_class = CONFIG_MAPPING[model_args.model_type] + + # Deal with encoder-decoder models that require specifying encoder/decoder + if hasattr(config_class, "from_encoder_decoder_configs"): + + # Use explicit specified encoder config + if model_args.encoder_config_name: + encoder_config = AutoConfig.from_pretrained(model_args.encoder_config_name, cache_dir=encoder_cache_dir) + # Use pretrained encoder model's config + elif model_args.encoder_model_name_or_path: + encoder_config = AutoConfig.from_pretrained(model_args.encoder_model_name_or_path, cache_dir=encoder_cache_dir) + # Use specified encoder model type + elif model_args.encoder_model_type: + encoder_config = AutoConfig.for_model(model_args.encoder_model_type) + logger.warning("You are instantiating a new config instance from scratch for the encoder.") + else: + raise ValueError("Encoder Config: if pretrained config or model location is not provided, `encoder_model_type` is required.") + + # Use explicit specified decoder config + if model_args.decoder_config_name: + decoder_config = AutoConfig.from_pretrained(model_args.decoder_config_name, cache_dir=decoder_cache_dir) + # Use pretrained decoder model's config + elif model_args.decoder_model_name_or_path: + decoder_config = AutoConfig.from_pretrained(model_args.decoder_model_name_or_path, cache_dir=decoder_cache_dir) + # Use specified decoder model type + elif model_args.decoder_model_type: + decoder_config = AutoConfig.for_model(model_args.decoder_model_type) + logger.warning("You are instantiating a new config instance from scratch for the decoder.") + else: + raise ValueError("Decoder Config: if pretrained config or model location is not provided, `decoder_model_type` is required.") + + logger.info("Setting `config.is_decoder=True` and `config.add_cross_attention=True` for decoder_config") + decoder_config.is_decoder = True + decoder_config.add_cross_attention = True + + config = config_class.from_encoder_decoder_configs(encoder_config, decoder_config) + # For self-contained model + else: + config = config_class() + logger.warning("You are instantiating a new config instance from scratch.") + + decoder_start_token_id = getattr(config, "decoder_start_token_id", None) + if not decoder_start_token_id and getattr(config, "decoder", None): + decoder_start_token_id = getattr(config.decoder, "decoder_start_token_id", None) + bos_token_id = getattr(config, "bos_token_id", None) + if not bos_token_id and getattr(config, "decoder", None): + bos_token_id = getattr(config.decoder, "bos_token_id", None) + eos_token_id = getattr(config, "eos_token_id", None) + if not eos_token_id and getattr(config, "decoder", None): + eos_token_id = getattr(config.decoder, "eos_token_id", None) + pad_token_id = getattr(config, "pad_token_id", None) + if not pad_token_id and getattr(config, "decoder", None): + pad_token_id = getattr(config.decoder, "pad_token_id", None) + + if decoder_start_token_id is None: + decoder_start_token_id = bos_token_id + if pad_token_id is None: + pad_token_id = eos_token_id + + if getattr(config, "decoder", None): + config.decoder.decoder_start_token_id = decoder_start_token_id + config.decoder.bos_token_id = bos_token_id + config.decoder.eos_token_id = eos_token_id + config.decoder.pad_token_id = pad_token_id + + # Set `encoder-decoder` (top-level) specific config + config.decoder_start_token_id = decoder_start_token_id + config.bos_token_id = bos_token_id + config.eos_token_id = eos_token_id + config.pad_token_id = pad_token_id + + if model_args.model_name_or_path: + 
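+ # When a single pretrained checkpoint is given, the whole vision-encoder-decoder model is
+ # loaded from it directly; otherwise the encoder and decoder are assembled further below
+ # from separate pretrained (or freshly initialized) models via `from_encoder_decoder_pretrained`.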
model = FlaxAutoModelForVision2Seq.from_pretrained( + model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype) + ) + else: + # model_class = FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING[config.__class__] + model = FlaxAutoModelForVision2Seq.from_config(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)) + model_class = model.__class__ + + # encoder_class = FlaxAutoModel + # decoder_class = FlaxAutoModelForCausalLM + module = model.module.bind(model.params) + encoder_class_name = type(module.encoder).__name__.replace("Module", "Model") + decoder_class_name = type(module.decoder).__name__.replace("Module", "Model") + encoder_class = getattr(transformers, encoder_class_name, None) + decoder_class = getattr(transformers, decoder_class_name, None) + + if hasattr(model_class, "from_encoder_decoder_pretrained"): + + if model_args.encoder_model_name_or_path: + encoder = encoder_class.from_pretrained( + model_args.encoder_model_name_or_path, + config=config.encoder, + seed=training_args.seed, + dtype=getattr(jnp, model_args.dtype) + ) + else: + encoder = encoder_class(config=config.encoder, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)) + logger.warning("You are instantiating a new model instance from scratch for the encoder.") + + if model_args.decoder_model_name_or_path: + decoder = decoder_class.from_pretrained( + model_args.decoder_model_name_or_path, + config=config.decoder, + seed=training_args.seed, + dtype=getattr(jnp, model_args.dtype) + ) + else: + decoder = decoder_class(config=config.decoder, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)) + logger.warning("You are instantiating a new model instance from scratch for the decoder.") + + model = model_class.from_encoder_decoder_pretrained( + model_args.encoder_model_name_or_path, + model_args.decoder_model_name_or_path, + encoder_model=encoder, + decoder_model=decoder, + encoder_config=config.encoder, + decoder_config=config.decoder, + encoder_seed=training_args.seed, + decoder_seed=training_args.seed, + encoder_dtype=getattr(jnp, model_args.dtype), + decoder_dtype=getattr(jnp, model_args.dtype), + ) + + # Set `encoder-decoder` (top-level) specific config + model.config.decoder_start_token_id = decoder_start_token_id + model.config.bos_token_id = bos_token_id + model.config.eos_token_id = eos_token_id + model.config.pad_token_id = pad_token_id + + else: + logger.warning("You are instantiating a new model instance from scratch.") + + feature_extractor = None + if model_args.feature_extractor_name: + feature_extractor = AutoFeatureExtractor.from_pretrained( + model_args.feature_extractor_name, cache_dir=model_args.cache_dir, + ) + elif model_args.model_name_or_path: + try: + feature_extractor = AutoFeatureExtractor.from_pretrained( + model_args.model_name_or_path, cache_dir=model_args.cache_dir + ) + except ValueError as e: + logger.warning(e) + # Check encoder + if not feature_extractor: + if model_args.encoder_model_name_or_path: + feature_extractor = AutoFeatureExtractor.from_pretrained( + model_args.encoder_model_name_or_path, cache_dir=model_args.cache_dir + ) + else: + raise ValueError( + "You are instantiating a new feature extractor from scratch. This is not supported by this script." + "You can do it from another script, save it, and load it from here, using --feature_extractor_name." 
+ ) + + tokenizer = None + if model_args.tokenizer_name: + tokenizer = AutoTokenizer.from_pretrained( + model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer + ) + elif model_args.model_name_or_path: + try: + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer + ) + except ValueError as e: + logger.warning(e) + # Check decoder + if not tokenizer: + if model_args.decoder_model_name_or_path: + tokenizer = AutoTokenizer.from_pretrained( + model_args.decoder_model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer + ) + else: + raise ValueError( + "You are instantiating a new tokenizer from scratch. This is not supported by this script." + "You can do it from another script, save it, and load it from here, using --tokenizer_name." + ) + tokenizer.pad_token = tokenizer.convert_ids_to_tokens(config.pad_token_id) + + # Preprocessing the datasets. + # We need to tokenize inputs and targets. + if training_args.do_train: + column_names = dataset["train"].column_names + elif training_args.do_eval: + column_names = dataset["validation"].column_names + elif training_args.do_predict: + column_names = dataset["test"].column_names + else: + logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.") + return + + # Get the column names for input/target. + dataset_columns = image_captioning_name_mapping.get(data_args.dataset_name, None) + if data_args.image_column is None: + assert dataset_columns is not None + image_column = dataset_columns[0] + else: + image_column = data_args.image_column + if image_column not in column_names: + raise ValueError( + f"--image_column' value '{data_args.image_column}' needs to be one of: {', '.join(column_names)}" + ) + if data_args.caption_column is None: + assert dataset_columns is not None + caption_column = dataset_columns[1] + else: + caption_column = data_args.caption_column + if caption_column not in column_names: + raise ValueError( + f"--caption_column' value '{data_args.caption_column}' needs to be one of: {', '.join(column_names)}" + ) + + # In Flax, for seq2seq models we need to pass `decoder_input_ids` + # as the Flax models don't accept `labels`, we need to prepare the decoder_input_ids here + # for that dynamically import the `shift_tokens_right` function from the model file + model_module = __import__(model.__module__, fromlist=["shift_tokens_right"]) + shift_tokens_right_fn = getattr(model_module, "shift_tokens_right", shift_tokens_right) + + def filter_fn(examples): + + bools = [] + for image_file in examples[image_column]: + with Image.open(image_file) as image: + try: + feature_extractor(images=image, return_tensors="np") + bools.append(True) + except: + bools.append(False) + + return bools + + # Setting padding="max_length" as we need fixed length inputs for jitted functions + def preprocess_function(examples, max_target_length): + + pixel_values = [] + captions = [] + for image_file, caption in zip(examples[image_column], examples[caption_column]): + with Image.open(image_file) as image: + try: + encoder_inputs = feature_extractor(images=image, return_tensors="np") + except: + continue + pixel_values.append(encoder_inputs.pixel_values) + captions.append(caption.lower() + ' ' + tokenizer.eos_token) + + pixel_values = np.concatenate(pixel_values) + targets = captions + + model_inputs = {} + model_inputs['pixel_values'] = pixel_values + + # Setup the 
tokenizer for targets + with tokenizer.as_target_tokenizer(): + labels = tokenizer( + targets, max_length=max_target_length, padding="max_length", truncation=True, return_tensors="np" + ) + + model_inputs["labels"] = labels["input_ids"] + decoder_input_ids = shift_tokens_right_fn( + jnp.array(labels["input_ids"]), config.pad_token_id, config.decoder_start_token_id + ) + model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids) + + # We need decoder_attention_mask so we can ignore pad tokens from loss + model_inputs["decoder_attention_mask"] = labels["attention_mask"] + + return model_inputs + + features = datasets.Features( + { + "pixel_values": datasets.Array3D( + shape=( + getattr(config.encoder, "num_channels", 3), + config.encoder.image_size, + config.encoder.image_size, + ), + dtype="float32", + ), + "labels": datasets.Sequence(feature=datasets.Value(dtype='int32', id=None), length=-1, id=None), + "decoder_input_ids": datasets.Sequence(feature=datasets.Value(dtype='int32', id=None), length=-1, id=None), + "decoder_attention_mask": datasets.Sequence(feature=datasets.Value(dtype='int32', id=None), length=-1, id=None), + } + ) + + if training_args.do_train: + if "train" not in dataset: + raise ValueError("--do_train requires a train dataset") + train_dataset = dataset["train"] + # remove problematic examples + train_dataset = train_dataset.filter(filter_fn, batched=True, num_proc=data_args.preprocessing_num_workers) + if data_args.max_train_samples is not None: + train_dataset = train_dataset.select(range(data_args.max_train_samples)) + + if training_args.do_eval: + if "validation" not in dataset: + raise ValueError("--do_eval requires a validation dataset") + eval_dataset = dataset["validation"] + # remove problematic examples + eval_dataset = eval_dataset.filter(filter_fn, batched=True, num_proc=data_args.preprocessing_num_workers) + if data_args.max_eval_samples is not None: + eval_dataset = eval_dataset.select(range(data_args.max_eval_samples)) + + if training_args.do_predict: + if "test" not in dataset: + raise ValueError("--do_predict requires a test dataset") + predict_dataset = dataset["test"] + # remove problematic examples + predict_dataset = predict_dataset.filter(filter_fn, batched=True, num_proc=data_args.preprocessing_num_workers) + if data_args.max_predict_samples is not None: + predict_dataset = predict_dataset.select(range(data_args.max_predict_samples)) + + # Split the dataset into several chunks - each chunk is processed (.map) without cache to create a + # data loader separately (in a sequential order). 
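+ # Illustrative example (numbers are not defaults): with `--block_size 8` and a total batch
+ # size of 2, each epoch is consumed in chunks of 8 examples; a chunk is pushed through
+ # `preprocess_function` (image loading + feature extraction) just before it is needed and
+ # yields 4 batches, so only one preprocessed chunk has to be kept in memory at a time.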
+ block_size = training_args.block_size + + # Store some constant + + train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count() + + if training_args.do_train: + steps_per_epoch = len(train_dataset) // train_batch_size + num_train_examples_per_epoch = steps_per_epoch * train_batch_size + num_epochs = int(training_args.num_train_epochs) + total_train_steps = steps_per_epoch * num_epochs + else: + num_train_examples_per_epoch = 0 + + eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count() + + if training_args.do_eval: + num_eval_examples = len(eval_dataset) + eval_steps = num_eval_examples // eval_batch_size + int(num_eval_examples % eval_batch_size > 0) + + if training_args.do_predict: + num_test_examples = len(predict_dataset) + test_steps = num_test_examples // eval_batch_size + int(num_test_examples % eval_batch_size > 0) + + def get_batch_iter(rng, ds, block_size, batch_size, shuffle=False, drop_last_batch=False, split=""): + + _block_size = block_size + + if not block_size: + block_size = len(ds) + + steps_per_split = block_size // batch_size + num_examples = len(ds) + steps = num_examples // batch_size + int(num_examples % batch_size > 0 and not drop_last_batch) + num_splits = steps // steps_per_split + int(steps % steps_per_split > 0) + + if drop_last_batch: + num_examples = steps * batch_size + + if shuffle: + indices = jax.random.permutation(input_rng, len(ds)) + else: + indices = jnp.arange(len(ds)) + + # Temporarily set max_target_length for training or evaluation/prediction. + max_target_length = data_args.max_target_length + if split in ["valid", "test"]: + max_target_length = data_args.val_max_target_length + + for idx in range(num_splits): + + start_idx = block_size * idx + end_idx = block_size * (idx + 1) + + selected_indices = indices[start_idx:end_idx] + + _ds = ds.select(selected_indices) + + names = { + "train": "train", + "valid": "validation", + "test": "prediction", + } + + # load image files from disk + feature processing: multiprocessing + keep the results in memory + # (save the results to disk + load them in training step: > 40x much slower - possible issue for images) + _ds =_ds.map( + preprocess_function, + batched=True, + num_proc=data_args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=not data_args.overwrite_cache, + features=features, + keep_in_memory=_block_size, + desc=f"Running tokenizer on {names[split]} dataset".replace(" ", " "), + fn_kwargs={"max_target_length": max_target_length}, + ) + _ds = _ds.with_format("numpy") + + # No need to shuffle here + loader = data_loader(rng, _ds, batch_size=batch_size, shuffle=False) + + for batch in loader: + yield batch + + # Metric + metric = load_metric("rouge") + + def postprocess_text(preds, labels): + preds = [pred.strip() for pred in preds] + labels = [label.strip() for label in labels] + + # rougeLSum expects newline after each sentence + preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds] + labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels] + + return preds, labels + + def compute_metrics(preds, labels): + decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True) + decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) + + # Some simple post-processing + decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels) + + result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True) + # Extract a few results 
from ROUGE + result = {key: value.mid.fmeasure * 100 for key, value in result.items()} + + prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds] + result["gen_len"] = np.mean(prediction_lens) + result = {k: round(v, 6) for k, v in result.items()} + + return result, decoded_preds, decoded_labels + + # Enable tensorboard only on the master node + has_tensorboard = is_tensorboard_available() + if has_tensorboard and jax.process_index() == 0: + try: + from flax.metrics.tensorboard import SummaryWriter + + summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir)) + except ImportError as ie: + has_tensorboard = False + logger.warning( + f"Unable to display metrics through TensorBoard because some package are not installed: {ie}" + ) + else: + logger.warning( + "Unable to display metrics through TensorBoard because the package is not installed: " + "Please run pip install tensorboard to enable." + ) + + # Initialize our training + rng = jax.random.PRNGKey(training_args.seed) + rng, dropout_rng = jax.random.split(rng) + + # Create learning rate schedule + linear_decay_lr_schedule_fn = create_learning_rate_fn( + num_train_examples_per_epoch, + train_batch_size, + training_args.num_train_epochs, + training_args.warmup_steps, + training_args.learning_rate, + ) + + # We use Optax's "masking" functionality to not apply weight decay + # to bias and LayerNorm scale parameters. decay_mask_fn returns a + # mask boolean with the same structure as the parameters. + # The mask is True for parameters that should be decayed. + # Note that this mask is specifically adapted for FlaxBart. + # For FlaxT5, one should correct the layer norm parameter naming + # accordingly - see `run_t5_mlm_flax.py` e.g. + def decay_mask_fn(params): + flat_params = traverse_util.flatten_dict(params) + layer_norm_params = [ + (name, "scale") for name in ["self_attn_layer_norm", "layernorm_embedding", "final_layer_norm"] + ] + flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_params) for path in flat_params} + return traverse_util.unflatten_dict(flat_mask) + + # create adam optimizer + adamw = optax.adamw( + learning_rate=linear_decay_lr_schedule_fn, + b1=training_args.adam_beta1, + b2=training_args.adam_beta2, + eps=training_args.adam_epsilon, + weight_decay=training_args.weight_decay, + mask=decay_mask_fn, + ) + + # Setup train state + state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng) + + # label smoothed cross entropy + def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0): + """ + The label smoothing implementation is adapted from Flax's official example: + https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104 + """ + vocab_size = logits.shape[-1] + confidence = 1.0 - label_smoothing_factor + low_confidence = (1.0 - confidence) / (vocab_size - 1) + normalizing_constant = -( + confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20) + ) + soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence) + + loss = optax.softmax_cross_entropy(logits, soft_labels) + loss = loss - normalizing_constant + + # ignore padded tokens from loss + loss = loss * padding_mask + loss = loss.sum() / padding_mask.sum() + return loss + + # Define gradient update step fn + def train_step(state, batch, label_smoothing_factor=0.0): + dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng) 
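+ # Split off a fresh dropout key for this step; the new key is carried forward in the
+ # (purely functional) train state so consecutive steps never reuse the same randomness.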
+ + def compute_loss(params): + labels = batch.pop("labels") + logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0] + loss = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor) + return loss + + grad_fn = jax.value_and_grad(compute_loss) + loss, grad = grad_fn(state.params) + grad = jax.lax.pmean(grad, "batch") + + new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng) + + metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)} + metrics = jax.lax.pmean(metrics, axis_name="batch") + + return new_state, metrics + + # Define eval fn + def eval_step(params, batch, label_smoothing_factor=0.0): + labels = batch.pop("labels") + logits = model(**batch, params=params, train=False)[0] + loss = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor) + + # summarize metrics + metrics = {"loss": loss} + metrics = jax.lax.pmean(metrics, axis_name="batch") + return metrics + + # Define generation function + max_length = ( + data_args.val_max_target_length if data_args.val_max_target_length is not None else model.config.max_length + ) + num_beams = data_args.num_beams if data_args.num_beams is not None else model.config.num_beams + gen_kwargs = {"max_length": max_length, "num_beams": num_beams} + + def generate_step(params, batch): + model.params = params + output_ids = model.generate(batch['pixel_values'], **gen_kwargs) + return output_ids.sequences + + # Create parallel version of the train and eval step + p_train_step = jax.pmap( + partial(train_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch", donate_argnums=(0,) + ) + p_eval_step = jax.pmap(partial(eval_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch") + p_generate_step = jax.pmap(generate_step, "batch") + + # Replicate the train state on each device + state = state.replicate() + + if training_args.do_train: + logger.info("***** Running training *****") + logger.info(f" Num train examples = {len(train_dataset)}") + logger.info(f" Num train examples per epoch = {num_train_examples_per_epoch}") + logger.info(f" Num Epochs = {num_epochs}") + logger.info(f" Instantaneous train batch size per device = {training_args.per_device_train_batch_size}") + logger.info(f" Total train batch size (w. parallel & distributed) = {train_batch_size}") + logger.info(f" Optimization steps per epoch = {steps_per_epoch}") + logger.info(f" Total optimization steps = {total_train_steps}") + if training_args.do_eval: + logger.info(f" Num evaluation examples = {num_eval_examples}") + logger.info(f" Instantaneous evaluation batch size per device = {training_args.per_device_eval_batch_size}") + logger.info(f" Total evaluation batch size (w. parallel & distributed) = {eval_batch_size}") + logger.info(f" Evaluation steps = {eval_steps}") + if training_args.do_predict: + logger.info(f" Num test examples = {num_test_examples}") + logger.info(f" Instantaneous test batch size per device = {training_args.per_device_eval_batch_size}") + logger.info(f" Total test batch size (w. parallel & distributed) = {eval_batch_size}") + logger.info(f" Total train batch size (w. 
parallel & distributed) = {eval_batch_size}") + logger.info(f" Test steps = {test_steps}") + + # create output directory + if not os.path.isdir(os.path.join(training_args.output_dir)): + os.makedirs(os.path.join(training_args.output_dir), exist_ok=True) + + def save_results(epoch, step): + + # save checkpoint after each epoch and push checkpoint to the hub + if jax.process_index() == 0: + params = jax.device_get(jax.tree_map(lambda x: x[0], state.params)) + dir_name = f'ckpt_epoch_{epoch + 1}_step_{step}' + model.save_pretrained(os.path.join(training_args.output_dir, dir_name), params=params) + tokenizer.save_pretrained(os.path.join(training_args.output_dir, dir_name)) + if training_args.push_to_hub: + commit_msg = f"Saving weights and logs of epoch {epoch + 1}- step {step}" + repo.push_to_hub(commit_message=commit_msg, blocking=False) + + def run_eval_or_test(rng, dataset, name, is_inside_training=True): + + if name not in ["valid", "test"]: + raise ValueError(f"`name` must be either \"valid\" or \"test\". Got {name} instead.") + + logger.info(f"*** {'Predict' if name == 'test' else 'Evaluate'} ***") + + metrics = [] + preds = [] + labels = [] + + batches = get_batch_iter(rng, dataset, block_size=block_size, batch_size=eval_batch_size, shuffle=False, split=name) + steps = len(dataset) // eval_batch_size + int(len(dataset) % eval_batch_size > 0) + for _ in tqdm(range(steps), desc=f"{'Predicting' if name == 'test' else 'Evaluating'}...", position=2, leave=False): + # Model forward + batch = next(batches) + _labels = batch.get("labels", None) + if name == "valid" and _labels is None: + raise ValueError("Validation dataset requires `labels`") + + if _labels is not None: + _metrics = p_eval_step(state.params, batch) + metrics.append(_metrics) + + # generation + if data_args.predict_with_generate: + generated_ids = p_generate_step(state.params, batch) + preds.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"]))) + if _labels is not None: + labels.extend(jax.device_get(_labels.reshape(-1, _labels.shape[-1]))) + + if metrics: + # normalize metrics + metrics = get_metrics(metrics) + metrics = jax.tree_map(jnp.mean, metrics) + + # compute ROUGE metrics + generations = [] + rouge_desc = "" + if data_args.predict_with_generate: + if labels: + rouge_metrics, decoded_preds, decoded_labels = compute_metrics(preds, labels) + metrics.update(rouge_metrics) + rouge_desc = " ".join([f"{'Predict' if name == 'test' else 'Eval'} {key}: {value} |" for key, value in rouge_metrics.items()]) + for pred, label in zip(decoded_preds, decoded_labels): + pred = pred.replace("\n", " ") + label = label.replace("\n", " ") + generations.append({"label": label, "pred": pred}) + else: + decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True) + # Some simple post-processing + decoded_preds = [pred.strip() for pred in decoded_preds] + # rougeLSum expects newline after each sentence + decoded_preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in decoded_preds] + for pred in decoded_preds: + pred = pred.replace("\n", " ") + generations.append({"pred": pred}) + + if metrics: + # Print metrics and update progress bar + desc = f"{'Predict' if name == 'test' else 'Eval'} Loss: {metrics['loss']} | {rouge_desc})" + if is_inside_training: + desc = f"Epoch... 
({epoch + 1}/{num_epochs} | Step: {cur_step} | " + desc + epochs.write(desc) + epochs.desc = desc + logger.info(desc) + + if jax.process_index() == 0: + + ckpt_dir = "" + if is_inside_training: + ckpt_dir = f'ckpt_epoch_{epoch + 1}_step_{cur_step}' + if not os.path.isdir(os.path.join(training_args.output_dir, ckpt_dir)): + os.makedirs(os.path.join(training_args.output_dir, ckpt_dir), exist_ok=True) + + if metrics: + + # save final metrics in json + metrics = {f"{name}_{metric_name}": round(value.item(), 6) for metric_name, value in metrics.items()} + path = os.path.join(training_args.output_dir, ckpt_dir, f"{name}_results.json") + with open(path, "w") as f: + json.dump(metrics, f, indent=4, sort_keys=True) + + # Update report + with open(os.path.join(training_args.output_dir, 'report.txt'), 'a', encoding='UTF-8') as fp: + fp.write(desc + '\n') + + # Save metrics + if has_tensorboard and is_inside_training: + write_metric(summary_writer, name, metrics, cur_step) + + # Save generations + if generations: + with open(os.path.join(training_args.output_dir, ckpt_dir, f'generation_{name}.json'), 'w', encoding='UTF-8') as fp: + json.dump(generations, fp, ensure_ascii=False, indent=4) + + input_rng = None + + if training_args.do_train: + + cur_step = 0 + train_time = 0 + epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0) + + for epoch in epochs: + + # ======================== Training ================================ + + # Create sampling rng + rng, input_rng = jax.random.split(rng) + + train_metrics = [] + + train_batches = get_batch_iter(input_rng, train_dataset, block_size=block_size, batch_size=train_batch_size, shuffle=True, drop_last_batch=training_args.dataloader_drop_last, split="train") + + # train + for (batch_idx, _) in enumerate(tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False)): + + cur_step += 1 + batch = next(train_batches) + batch_start = time.time() + state, train_metric = p_train_step(state, batch) + train_metrics.append(train_metric) + train_time += time.time() - batch_start + + if cur_step % training_args.logging_steps == 0 or (training_args.eval_steps is not None and cur_step % training_args.eval_steps == 0) or cur_step % steps_per_epoch == 0: + + time_per_step = train_time / cur_step + + _train_metric = unreplicate(train_metric) + desc = f"Epoch... 
({epoch + 1}/{num_epochs} | Step: {cur_step} | Loss: {_train_metric['loss']} | Learning Rate: {_train_metric['learning_rate']} | Time per step: {time_per_step})"
+ epochs.write(desc)
+ epochs.desc = desc
+ logger.info(desc)
+ with open(os.path.join(training_args.output_dir, 'report.txt'), 'a', encoding='UTF-8') as fp:
+ fp.write(desc + '\n')
+
+ # Save metrics
+ if has_tensorboard and jax.process_index() == 0:
+ write_metric(summary_writer, "train", train_metrics, cur_step, train_time=train_time)
+
+ # ======================== Evaluating ==============================
+
+ if training_args.do_eval and ((training_args.eval_steps is not None and cur_step % training_args.eval_steps == 0) or cur_step % steps_per_epoch == 0):
+ run_eval_or_test(input_rng, eval_dataset, name="valid", is_inside_training=True)
+
+ # ======================== Prediction loop ==============================
+
+ # run prediction after evaluation if specified, otherwise only after each epoch
+ if training_args.do_predict and training_args.do_predict_during_training and training_args.do_predict_after_evaluation:
+ run_eval_or_test(input_rng, predict_dataset, name='test', is_inside_training=True)
+
+ # ======================== Save ==============================
+
+ save_results(epoch + 1, cur_step)
+
+ # run prediction at the end of each epoch (if it is not already run after each evaluation)
+ if training_args.do_predict and training_args.do_predict_during_training and not training_args.do_predict_after_evaluation:
+ run_eval_or_test(input_rng, predict_dataset, name='test', is_inside_training=True)
+ save_results(epoch + 1, cur_step)
+
+ # Create sampling rng
+ if input_rng is None:
+ rng, input_rng = jax.random.split(rng)
+
+ # run prediction after training, or as a standalone step (if it was not already run during training)
+ if training_args.do_predict and not (training_args.do_train and training_args.do_predict_during_training):
+ run_eval_or_test(input_rng, predict_dataset, name='test', is_inside_training=False)
+
+
+if __name__ == "__main__":
+
+ main()
diff --git a/usage_example.txt b/usage_example.txt
new file mode 100644
index 0000000000000000000000000000000000000000..cdc0750c729d2b2382febe85e55d091c304d2912
--- /dev/null
+++ b/usage_example.txt
@@ -0,0 +1,3 @@
+python3 run_image_captioning_flax.py --dataset_name ./coco_dataset/coco_dataset.py --dataset_config_name=2017 --data_dir ./coco_dataset/dummy_data/ --caption_column caption --image_column image_path --max_target_length 8 --num_beams 2 --output_dir debug --do_train --per_device_train_batch_size 2 --do_eval --per_device_eval_batch_size 4 --do_predict --max_train_samples 16 --max_eval_samples 16 --max_predict_samples 16 --preprocessing_num_workers 2 --num_train_epochs 3 --learning_rate 3e-5 --do_predict_after_evaluation --predict_with_generate --logging_steps 2 --block_size 8 --encoder_model_name_or_path ./encoder-decoder/encoder/ --decoder_model_name_or_path ./encoder-decoder/decoder/
+
+python3 run_image_captioning_flax.py --dataset_name ./image_caption_dataset/image_caption_dataset.py --dataset_config_name=coco_2017 --data_dir ./image_caption_dataset/coco_dataset/ --caption_column en --image_column image_path --max_target_length 8 --num_beams 2 --output_dir debug --do_train --per_device_train_batch_size 2 --do_eval --per_device_eval_batch_size 4 --do_predict --max_train_samples 16 --max_eval_samples 16 --max_predict_samples 16 --preprocessing_num_workers 2 --num_train_epochs 3 --learning_rate 3e-5 --do_predict_after_evaluation --predict_with_generate --logging_steps 2 --block_size 8 --encoder_model_name_or_path ./encoder-decoder/encoder/ --decoder_model_name_or_path ./encoder-decoder/decoder