boris committed on
Commit
f234ccf
·
2 Parent(s): e5a52b9 26651dd

Merge branch 'main' of https://github.com/borisdayma/dalle-mini into add-custom-model

Browse files
Files changed (45) hide show
  1. .github/workflows/check_size.yml +17 -0
  2. .github/workflows/style.yml +20 -0
  3. .github/workflows/sync_to_hub.yml +20 -0
  4. .github/workflows/sync_to_hub_debug.yml +17 -0
  5. .gitignore +4 -0
  6. CITATION.cff +44 -0
  7. LICENSE +201 -0
  8. Makefile +5 -0
  9. README.md +144 -30
  10. app/gradio/app_gradio.py +179 -0
  11. app/gradio/requirements.txt +4 -0
  12. app/streamlit/app.py +117 -0
  13. app/streamlit/img/loading.gif +0 -0
  14. dalle_mini/data.py +261 -0
  15. dalle_mini/dataset.py +0 -122
  16. dalle_mini/model.py +64 -0
  17. dalle_mini/text.py +258 -0
  18. dalle_mini/vqgan_jax/__init__.py +0 -0
  19. dalle_mini/vqgan_jax/configuration_vqgan.py +0 -40
  20. dalle_mini/vqgan_jax/convert_pt_model_to_jax.py +0 -109
  21. dalle_mini/vqgan_jax/modeling_flax_vqgan.py +0 -609
  22. data/CC12M_downloader.py +0 -91
  23. data/CC3M_downloader.py +0 -62
  24. demo/CustomBARTv4b_model-generate.ipynb +0 -566
  25. demo/demo_notebook.ipynb +0 -583
  26. encoding/vqgan-jax-encoding-with-captions.ipynb +0 -363
  27. encoding/vqgan-jax-encoding-yfcc100m.ipynb +0 -1136
  28. encoding/vqgan-jax-encoding.ipynb +0 -0
  29. environment.yaml +0 -10
  30. img/logo.png +0 -0
  31. model/data-pipeline.ipynb +0 -385
  32. pyproject.toml +2 -0
  33. requirements.txt +0 -9
  34. seq2seq/do_big_run.sh +0 -16
  35. seq2seq/do_small_run.sh +0 -16
  36. seq2seq/requirements.txt +0 -8
  37. seq2seq/run_seq2seq_flax.py +0 -897
  38. setup.cfg +27 -0
  39. setup.py +4 -0
  40. tools/dataset/encode_dataset.ipynb +371 -0
  41. tools/inference/inference_pipeline.ipynb +0 -0
  42. tools/inference/log_inference_samples.ipynb +434 -0
  43. tools/inference/samples.txt +124 -0
  44. {seq2seq → tools/train}/sweep.yaml +34 -23
  45. tools/train/train.py +857 -0
.github/workflows/check_size.yml ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Check file size
2
+
3
+ on:
4
+ pull_request:
5
+ branches: [main]
6
+
7
+ # to run this workflow manually from the Actions tab
8
+ workflow_dispatch:
9
+
10
+ jobs:
11
+ sync-to-hub:
12
+ runs-on: ubuntu-latest
13
+ steps:
14
+ - name: Check large files
15
+ uses: ActionsDesk/lfs-warning@v2.0
16
+ with:
17
+ filesizelimit: 10485760 # = 10MB, so we can sync to HF spaces
.github/workflows/style.yml ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Lint
2
+
3
+ on:
4
+ push:
5
+ branches: [main]
6
+ pull_request:
7
+ branches: [main]
8
+
9
+ jobs:
10
+ lint:
11
+ runs-on: ubuntu-latest
12
+ steps:
13
+ - uses: actions/checkout@v2
14
+ - uses: psf/black@stable
15
+ - uses: actions/setup-python@v2
16
+ with:
17
+ python-version: 3.9
18
+ - name: Install requirements
19
+ run: pip install ".[dev]"
20
+ - uses: jamescurtin/isort-action@master
.github/workflows/sync_to_hub.yml ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Sync to Hugging Face hub
2
+
3
+ on:
4
+ push:
5
+ branches: [main]
6
+
7
+ # to run this workflow manually from the Actions tab
8
+ workflow_dispatch:
9
+
10
+ jobs:
11
+ sync-to-hub:
12
+ runs-on: ubuntu-latest
13
+ steps:
14
+ - uses: actions/checkout@v2
15
+ with:
16
+ fetch-depth: 0
17
+ - name: Push to hub
18
+ env:
19
+ HF_TOKEN: ${{ secrets.HF_TOKEN }}
20
+ run: git push https://boris:$HF_TOKEN@huggingface.co/spaces/flax-community/dalle-mini main
.github/workflows/sync_to_hub_debug.yml ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Deploy to debug app
2
+
3
+ on:
4
+ # to run this workflow manually from the Actions tab
5
+ workflow_dispatch:
6
+
7
+ jobs:
8
+ sync-to-hub-debug:
9
+ runs-on: ubuntu-latest
10
+ steps:
11
+ - uses: actions/checkout@v2
12
+ with:
13
+ fetch-depth: 0
14
+ - name: Push to hub
15
+ env:
16
+ HF_TOKEN: ${{ secrets.HF_TOKEN }}
17
+ run: git push --force https://boris:$HF_TOKEN@huggingface.co/spaces/flax-community/dalle-mini-debug +HEAD:main
.gitignore CHANGED
@@ -1 +1,5 @@
1
  __pycache__
 
 
 
 
 
1
  __pycache__
2
+ .ipynb_checkpoints
3
+ .streamlit
4
+ wandb/
5
+ *.egg-info/
CITATION.cff ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # YAML 1.2
2
+ ---
3
+ abstract: "DALL·E mini is a JAX/Flax reimplementation of OpenAI's DALL·E that requires much smaller hardware resources. By simplifying the architecture and model memory requirements, as well as leveraging open-source code and pre-trained models, we were able to create a model that is 27 times smaller than the original DALL·E and train it on a single TPU v3-8 for only 3 days. DALL·E mini achieves impressive results, albeit of a lower quality than the original system. It can be used for exploration and further experimentation on commodity hardware."
4
+ authors:
5
+ -
6
+ family-names: Dayma
7
+ given-names: Boris
8
+ -
9
+ family-names: Patil
10
+ given-names: Suraj
11
+ -
12
+ family-names: Cuenca
13
+ given-names: Pedro
14
+ -
15
+ family-names: Saifullah
16
+ given-names: Khalid
17
+ -
18
+ family-names: Abraham
19
+ given-names: Tanishq
20
+ -
21
+ family-names: "Lê Khắc"
22
+ given-names: "Phúc"
23
+ -
24
+ family-names: Melas
25
+ given-names: Luke
26
+ -
27
+ family-names: Ghosh
28
+ given-names: Ritobrata
29
+ cff-version: "1.1.0"
30
+ date-released: 2021-07-29
31
+ identifiers:
32
+ keywords:
33
+ - dalle
34
+ - "text-to-image generation"
35
+ - transformer
36
+ - "zero-shot"
37
+ - JAX
38
+ license: "Apache-2.0"
39
+ doi: 10.5281/zenodo.5146400
40
+ message: "If you use this project, please cite it using these metadata."
41
+ repository-code: "https://github.com/borisdayma/dalle-mini"
42
+ title: "DALL·E Mini"
43
+ version: "v0.1-alpha"
44
+ ...
LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright 2021 The DALL·E mini Authors
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
Makefile ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ .PHONY: style
2
+
3
+ style:
4
+ black .
5
+ isort .
README.md CHANGED
@@ -1,42 +1,156 @@
1
- ## DALL-E Mini - Generate image from text
 
 
 
 
 
 
 
 
2
 
3
- ## Tentative Strategy of training (proposed by Luke and Suraj)
4
 
5
- ### Data:
6
- * [Conceptual 12M](https://github.com/google-research-datasets/conceptual-12m) Dataset (already loaded and preprocessed in TPU VM by Luke).
7
- * [YFCC100M Subset](https://github.com/openai/CLIP/blob/main/data/yfcc100m.md)
8
- * [Coneptual Captions 3M](https://github.com/google-research-datasets/conceptual-captions)
9
 
10
- ### Architecture:
11
- * Use the Taming Transformers VQ-GAN (with 16384 tokens)
12
- * Use a seq2seq (language encoder --> image decoder) model with a pretrained non-autoregressive encoder (e.g. BERT) and an autoregressive decoder (like GPT).
13
 
14
- ### Remaining Architecture Questions:
15
- * Whether to freeze the text encoder?
16
- * Whether to finetune the VQ-GAN?
17
- * Which text encoder to use (e.g. BERT, RoBERTa, etc.)?
18
- * Hyperparameter choices for the decoder (e.g. positional embedding, initialization, etc.)
19
 
20
- ## TODO
21
 
22
- * experiment with flax/jax and setup of the TPU instance that we should get shortly
23
- * work on dataset loading - [see suggested datasets](https://discuss.huggingface.co/t/dall-e-mini-version/7324/4)
24
- * Optionally create the OpenAI YFCC100M subset (see [this post](https://discuss.huggingface.co/t/dall-e-mini-version/7324/30?u=boris))
25
- * work on text/image encoding
26
- * concatenate inputs (not sure if we need fixed length for text or use a special token separating text & image)
27
- * adapt training script
28
- * create inference function
29
- * integrate CLIP for better results (only if we have the time)
30
- * work on a demo (streamlit or colab or maybe just HF widget)
31
- * document (set up repo on model hub per instructions, start on README writeup…)
32
- * help with coordinating activities & progress
33
 
 
34
 
35
- ## Dependencies Installation
36
- You should create a new python virtual environment and install the project dependencies inside the virtual env. You need to use the `-f` (`--find-links`) option for `pip` to be able to find the appropriate `libtpu` required for the TPU hardware:
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  ```
39
- $ pip install -r requirements.txt -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  ```
41
 
42
- If you use `conda`, you can create the virtual env and install everything using: `conda env update -f environments.yaml`
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: DALL·E mini
3
+ emoji: 🥑
4
+ colorFrom: yellow
5
+ colorTo: green
6
+ sdk: streamlit
7
+ app_file: app/streamlit/app.py
8
+ pinned: True
9
+ ---
10
 
11
+ # DALL·E Mini
12
 
13
+ [![Join us on Discord](https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white)](https://discord.gg/xBPBXfcFHd)
 
 
 
14
 
15
+ _Generate images from a text prompt_
 
 
16
 
17
+ <img src="img/logo.png" width="200">
 
 
 
 
18
 
19
+ Our logo was generated with DALL·E mini using the prompt "logo of an armchair in the shape of an avocado".
20
 
21
+ You can create your own pictures with [the demo](https://huggingface.co/spaces/flax-community/dalle-mini).
 
 
 
 
 
 
 
 
 
 
22
 
23
+ ## How does it work?
24
 
25
+ Refer to [our report](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA).
 
26
 
27
+ ## Development
28
+
29
+ ### Dependencies Installation
30
+
31
+ For inference only, use `pip install git+https://github.com/borisdayma/dalle-mini.git`.
32
+
33
+ For development, clone the repo and use `pip install -e ".[dev]"`.
34
+
35
+ ### Training of VQGAN
36
+
37
+ The VQGAN was trained using [taming-transformers](https://github.com/CompVis/taming-transformers).
38
+
39
+ We recommend using the latest version available.
40
+
41
+ ### Conversion of VQGAN to JAX
42
+
43
+ Use [patil-suraj/vqgan-jax](https://github.com/patil-suraj/vqgan-jax).
44
+
45
+ ### Training of Seq2Seq
46
+
47
+ Use [`tools/train/train.py`](tools/train/train.py).
48
+
49
+ You can also adjust the [sweep configuration file](https://docs.wandb.ai/guides/sweeps) if you need to perform a hyperparameter search.
50
+
51
+ ### Inference Pipeline
52
+
53
+ To generate sample predictions and understand the inference pipeline step by step, refer to [`tools/inference/inference_pipeline.ipynb`](tools/inference/inference_pipeline.ipynb).
54
+
55
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/borisdayma/dalle-mini/blob/main/tools/inference/inference_pipeline.ipynb)
56
+
57
+ ## FAQ
58
+
59
+ ### Where to find the latest models?
60
+
61
+ Trained models are on 🤗 Model Hub:
62
+
63
+ - [VQGAN-f16-16384](https://huggingface.co/flax-community/vqgan_f16_16384) for encoding/decoding images
64
+ - [DALL·E mini](https://huggingface.co/flax-community/dalle-mini) for generating images from a text prompt
65
+
66
+ ### Where does the logo come from?
67
+
68
+ The "armchair in the shape of an avocado" was used by OpenAI when releasing DALL·E to illustrate the model's capabilities. Having successful predictions on this prompt represents a big milestone to us.
69
+
70
+ ## Authors & Contributors
71
+
72
+ ### Main Authors
73
+
74
+ - [Boris Dayma](https://github.com/borisdayma)
75
+ - [Suraj Patil](https://github.com/patil-suraj)
76
+ - [Pedro Cuenca](https://github.com/pcuenca)
77
+
78
+ ### Other Members of dalle-mini team during FLAX/JAX community week
79
+
80
+ - [Khalid Saifullah](https://github.com/khalidsaifullaah)
81
+ - [Tanishq Abraham](https://github.com/tmabraham)
82
+ - [Phúc Lê Khắc](https://github.com/lkhphuc)
83
+ - [Luke Melas](https://github.com/lukemelas)
84
+ - [Ritobrata Ghosh](https://github.com/ghosh-r)
85
+
86
+ ### Contributing
87
+
88
+ Join the community on the [DALLE-Pytorch Discord](https://discord.gg/xBPBXfcFHd).
89
+ Any contribution is welcome, from reporting issues to proposing fixes/improvements or testing the model on cool prompts!
90
+
91
+ ## Acknowledgements
92
+
93
+ - 🤗 Hugging Face for organizing [the FLAX/JAX community week](https://github.com/huggingface/transformers/tree/master/examples/research_projects/jax-projects)
94
+ - Google [TPU Research Cloud (TRC) program](https://sites.research.google/trc/) for providing computing resources
95
+ - [Weights & Biases](https://wandb.com/) for providing the infrastructure for experiment tracking and model management
96
+
97
+ ## Citing DALL·E mini
98
+
99
+ If you find DALL·E mini useful in your research or wish to refer to it, please use the following BibTeX entry.
100
+
101
+ ```
102
+ @misc{Dayma_DALL·E_Mini_2021,
103
+ author = {Dayma, Boris and Patil, Suraj and Cuenca, Pedro and Saifullah, Khalid and Abraham, Tanishq and Lê Khắc, Phúc and Melas, Luke and Ghosh, Ritobrata},
104
+ doi = {10.5281/zenodo.5146400},
105
+ month = {7},
106
+ title = {DALL·E Mini},
107
+ url = {https://github.com/borisdayma/dalle-mini},
108
+ year = {2021}
109
+ }
110
  ```
111
+
112
+ ## References
113
+
114
+ ```
115
+ @misc{ramesh2021zeroshot,
116
+ title={Zero-Shot Text-to-Image Generation},
117
+ author={Aditya Ramesh and Mikhail Pavlov and Gabriel Goh and Scott Gray and Chelsea Voss and Alec Radford and Mark Chen and Ilya Sutskever},
118
+ year={2021},
119
+ eprint={2102.12092},
120
+ archivePrefix={arXiv},
121
+ primaryClass={cs.CV}
122
+ }
123
+ ```
124
+
125
+ ```
126
+ @misc{esser2021taming,
127
+ title={Taming Transformers for High-Resolution Image Synthesis},
128
+ author={Patrick Esser and Robin Rombach and Björn Ommer},
129
+ year={2021},
130
+ eprint={2012.09841},
131
+ archivePrefix={arXiv},
132
+ primaryClass={cs.CV}
133
+ }
134
  ```
135
 
136
+ ```
137
+ @misc{lewis2019bart,
138
+ title={BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension},
139
+ author={Mike Lewis and Yinhan Liu and Naman Goyal and Marjan Ghazvininejad and Abdelrahman Mohamed and Omer Levy and Ves Stoyanov and Luke Zettlemoyer},
140
+ year={2019},
141
+ eprint={1910.13461},
142
+ archivePrefix={arXiv},
143
+ primaryClass={cs.CL}
144
+ }
145
+ ```
146
+
147
+ ```
148
+ @misc{radford2021learning,
149
+ title={Learning Transferable Visual Models From Natural Language Supervision},
150
+ author={Alec Radford and Jong Wook Kim and Chris Hallacy and Aditya Ramesh and Gabriel Goh and Sandhini Agarwal and Girish Sastry and Amanda Askell and Pamela Mishkin and Jack Clark and Gretchen Krueger and Ilya Sutskever},
151
+ year={2021},
152
+ eprint={2103.00020},
153
+ archivePrefix={arXiv},
154
+ primaryClass={cs.CV}
155
+ }
156
+ ```
app/gradio/app_gradio.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+
4
+ # Uncomment to run on cpu
5
+ # import os
6
+ # os.environ["JAX_PLATFORM_NAME"] = "cpu"
7
+
8
+ import random
9
+
10
+ import gradio as gr
11
+ import jax
12
+ import numpy as np
13
+ from flax.jax_utils import replicate
14
+ from flax.training.common_utils import shard
15
+ from PIL import Image, ImageDraw, ImageFont
16
+
17
+ # ## CLIP Scoring
18
+ from transformers import BartTokenizer, CLIPProcessor, FlaxCLIPModel
19
+ from vqgan_jax.modeling_flax_vqgan import VQModel
20
+
21
+ from dalle_mini.model import CustomFlaxBartForConditionalGeneration
22
+
23
# Hub repositories and the exact revisions the demo is pinned to.
DALLE_REPO = "flax-community/dalle-mini"
VQGAN_REPO = "flax-community/vqgan_f16_16384"
DALLE_COMMIT_ID = "4d34126d0df8bc4a692ae933e3b902a1fa8b6114"
VQGAN_COMMIT_ID = "90cc46addd2dd8f5be21586a9a23e1b95aa506a9"

# The tokenizer and the seq2seq model are loaded from the same pinned revision.
tokenizer = BartTokenizer.from_pretrained(DALLE_REPO, revision=DALLE_COMMIT_ID)
model = CustomFlaxBartForConditionalGeneration.from_pretrained(
    DALLE_REPO,
    revision=DALLE_COMMIT_ID,
)

# VQGAN used by get_images() to decode generated codes.
vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)
34
+
35
+
36
def captioned_strip(images, caption=None, rows=1):
    """Paste images into a grid with ``rows`` rows, optionally below a caption.

    Images are placed column by column: image ``i`` goes to column
    ``i // rows`` and row ``i % rows``. Every cell uses the size of the
    first image.

    Args:
        images: Non-empty sequence of PIL images.
        caption: Optional text drawn in white inside a 48-pixel banner
            at the top of the result. No banner is added when ``None``.
        rows: Number of rows in the grid.

    Returns:
        A new RGB ``PIL.Image`` containing the grid.

    Raises:
        IndexError: If ``images`` is empty.
    """
    increased_h = 0 if caption is None else 48
    w, h = images[0].size[0], images[0].size[1]

    # Ceiling division: when len(images) is not a multiple of ``rows`` the
    # last, partially filled column must still fit on the canvas (floor
    # division used to paste it outside the image).
    cols = -(-len(images) // rows)
    img = Image.new("RGB", (cols * w, h * rows + increased_h))
    for i, img_ in enumerate(images):
        col, row = divmod(i, rows)
        img.paste(img_, (col * w, increased_h + row * h))

    if caption is not None:
        draw = ImageDraw.Draw(img)
        font = ImageFont.truetype(
            "/usr/share/fonts/truetype/liberation2/LiberationMono-Bold.ttf", 40
        )
        draw.text((20, 3), caption, (255, 255, 255), font=font)
    return img
50
+
51
+
52
def custom_to_pil(x):
    """Convert an array of pixel values into an RGB PIL image.

    Values are clipped to [0, 1], scaled by 255 and cast to ``uint8``
    before being handed to ``Image.fromarray``; the result is converted
    to RGB when PIL picked another mode.
    """
    scaled = 255 * np.clip(x, 0.0, 1.0)
    pil_image = Image.fromarray(scaled.astype(np.uint8))
    if pil_image.mode == "RGB":
        return pil_image
    return pil_image.convert("RGB")
59
+
60
+
61
def generate(input, rng, params):
    """Sample one sequence per input with the module-level seq2seq ``model``.

    Args:
        input: Mapping of tokenizer outputs, expanded as keyword arguments
            of ``model.generate``.
        rng: PRNG key used for sampling.
        params: Model parameters to run with.

    Returns:
        Whatever ``model.generate`` returns (callers read ``.sequences``).
    """
    # Pure sampling (no beam search); token id 50000 doubles as EOS and PAD.
    sampling_options = dict(
        max_length=257,
        num_beams=1,
        do_sample=True,
        eos_token_id=50000,
        pad_token_id=50000,
    )
    return model.generate(**input, prng_key=rng, params=params, **sampling_options)
72
+
73
+
74
def get_images(indices, params):
    """Decode code ``indices`` with the module-level ``vqgan`` using ``params``."""
    decoded = vqgan.decode_code(indices, params=params)
    return decoded
76
+
77
+
78
# Device-parallel versions of the generation and decoding steps.
p_generate = jax.pmap(generate, axis_name="batch")
p_get_images = jax.pmap(get_images, axis_name="batch")

# One copy of each parameter tree per device, as expected by the pmapped calls.
vqgan_params = replicate(vqgan.params)
bart_params = replicate(model.params)

# CLIP model and processor used by clip_top_k() to rank generated images.
clip = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
print("Initialize FlaxCLIPModel")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
print("Initialize CLIPProcessor")
88
+
89
+
90
def hallucinate(prompt, num_images=64):
    """Generate candidate images for ``prompt``.

    The prompt is tokenized (padded/truncated to 128 tokens), replicated
    once per local device and sharded. Each loop iteration samples one
    sequence per device with a fresh random seed and decodes it to an
    image, so ``num_images // jax.local_device_count()`` iterations run.

    Args:
        prompt: Text prompt.
        num_images: Requested number of candidates; rounded down to a
            multiple of the local device count.

    Returns:
        List of RGB PIL images (empty when ``num_images`` is smaller than
        the local device count).
    """
    # ``shard`` and ``pmap`` both operate on the *local* devices, so the same
    # count is used everywhere (identical to device_count() on a single host).
    n_devices = jax.local_device_count()

    inputs = tokenizer(
        [prompt] * n_devices,
        return_tensors="jax",
        padding="max_length",
        truncation=True,
        max_length=128,
    ).data
    inputs = shard(inputs)

    all_images = []
    for _ in range(num_images // n_devices):
        # Integer bound: randint() rejects float arguments such as 1e7 on
        # Python 3.12+ (deprecated since 3.10).
        seed = random.randint(0, 10**7)
        rngs = jax.random.split(jax.random.PRNGKey(seed), n_devices)

        indices = p_generate(inputs, rngs, bart_params).sequences
        # Drop the first generated token of every sequence before decoding
        # (presumably the decoder start token; confirm against the model).
        indices = indices[:, :, 1:]

        images = p_get_images(indices, vqgan_params)
        # Remove the per-device batch axis of size 1.
        images = np.squeeze(np.asarray(images), 1)
        all_images.extend(custom_to_pil(image) for image in images)
    return all_images
114
+
115
+
116
def clip_top_k(prompt, images, k=8):
    """Return the ``k`` images that CLIP scores highest for ``prompt``.

    The images are ranked by the first row of ``logits_per_text`` and
    returned best match first.
    """
    clip_inputs = processor(
        text=prompt, images=images, return_tensors="np", padding=True
    )
    text_logits = clip(**clip_inputs).logits_per_text
    # argsort is ascending: keep the last k indices and reverse them.
    ranked = np.array(text_logits[0]).argsort()
    best_first = ranked[-k:][::-1]
    return [images[idx] for idx in best_first]
122
+
123
+
124
def compose_predictions(images, caption=None):
    """Lay ``images`` out side by side in one row, optionally below a caption.

    This is the single-row case of ``captioned_strip``: image ``i`` is
    pasted at x offset ``i * width`` and a 48-pixel caption banner is added
    at the top only when ``caption`` is given.
    """
    return captioned_strip(images, caption=caption, rows=1)
138
+
139
+
140
def top_k_predictions(prompt, num_candidates=32, k=8):
    """Generate ``num_candidates`` images for ``prompt`` and keep the CLIP top ``k``."""
    candidates = hallucinate(prompt, num_images=num_candidates)
    return clip_top_k(prompt, candidates, k=k)
144
+
145
+
146
def run_inference(prompt, num_images=32, num_preds=8):
    """Gradio callback: build the title HTML and the image strip for ``prompt``.

    Args:
        prompt: Text entered by the user.
        num_images: Number of candidate images to generate.
        num_preds: Number of best candidates to keep.

    Returns:
        Tuple ``(output_title, predictions)``: an HTML snippet showing the
        prompt in bold, and a single PIL image containing the selected
        predictions.
    """
    import html  # local import: only needed to escape the user-provided prompt

    images = top_k_predictions(prompt, num_candidates=num_images, k=num_preds)
    predictions = captioned_strip(images)
    # The prompt is free-form user input rendered by an HTML output widget,
    # so escape it instead of interpolating raw markup.
    output_title = f"""
    <b>{html.escape(prompt)}</b>
    """
    return (output_title, predictions)
153
+
154
+
155
# Output widgets, in the order of the tuple returned by run_inference().
outputs = [
    gr.outputs.HTML(label=""),  # To be used as title
    gr.outputs.Image(label=""),
]

description = """
DALL·E-mini is an AI model that generates images from any prompt you give! Generate images from text:
"""

# Keyword options of the Gradio interface, gathered in one place.
interface_options = dict(
    inputs=[gr.inputs.Textbox(label="What do you want to see?")],
    outputs=outputs,
    title="DALL·E mini",
    description=description,
    article="<p style='text-align: center'> Created by Boris Dayma et al. 2021 | <a href='https://github.com/borisdayma/dalle-mini'>GitHub</a> | <a href='https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA'>Report</a></p>",
    layout="vertical",
    theme="huggingface",
    examples=[
        ["an armchair in the shape of an avocado"],
        ["snowy mountains by the sea"],
    ],
    allow_flagging=False,
    live=False,
    # server_port=8999
)

# Build the demo around run_inference() and serve it with a public share link.
gr.Interface(run_inference, **interface_options).launch(share=True)
app/gradio/requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # Requirements for huggingface spaces
2
+ gradio>=2.2.3
3
+ flax
4
+ transformers
app/streamlit/app.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+
4
import base64
import html
from io import BytesIO

import requests
import streamlit as st
from PIL import Image
10
+
11
+
12
class ServiceError(Exception):
    """Error raised when the backend replies with a non-200 HTTP status.

    The status is available as `status_code` and as the sole exception arg.
    """

    def __init__(self, status_code):
        super().__init__(status_code)
        self.status_code = status_code
15
+
16
+
17
def get_images_from_backend(prompt, backend_url):
    """POST `prompt` to `backend_url` and decode the images it returns.

    The response JSON must hold base64-encoded images under the "images" key.

    Raises:
        ServiceError: if the HTTP status is anything other than 200.
    """
    response = requests.post(backend_url, json={"prompt": prompt})
    if response.status_code != 200:
        raise ServiceError(response.status_code)

    encoded_images = response.json()["images"]
    return [
        Image.open(BytesIO(base64.b64decode(payload)))
        for payload in encoded_images
    ]
25
+
26
+
27
# Sidebar, part 1: centered project logo (raw HTML, hence unsafe_allow_html).
st.sidebar.markdown(
    """
<style>
.aligncenter {
    text-align: center;
}
</style>
<p class="aligncenter">
    <img src="https://raw.githubusercontent.com/borisdayma/dalle-mini/main/img/logo.png"/>
</p>
""",
    unsafe_allow_html=True,
)
# Sidebar, part 2: short description, credits and project links.
st.sidebar.markdown(
    """
___
<p style='text-align: center'>
DALL·E mini is an AI model that generates images from any prompt you give!
</p>

<p style='text-align: center'>
Created by Boris Dayma et al. 2021
<br/>
<a href="https://github.com/borisdayma/dalle-mini" target="_blank">GitHub</a> | <a href="https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA" target="_blank">Project Report</a>
</p>
""",
    unsafe_allow_html=True,
)
55
+
56
st.header("DALL·E mini")
st.subheader("Generate images from text")

prompt = st.text_input("What do you want to see?")

# When True, a missing BACKEND_SERVER secret is explained in detail instead
# of showing the generic error message.
DEBUG = False
if prompt != "":
    # The prompt is user input embedded in raw HTML below
    # (unsafe_allow_html=True), so escape it to keep "<", "&" or tags in the
    # prompt from being interpreted as markup.
    safe_prompt = html.escape(prompt)

    # Placeholder that first shows the "loading" banner and is later replaced
    # by the prompt title or by an error message.
    container = st.empty()
    container.markdown(
        f"""
        <style> p {{ margin:0 }} div {{ margin:0 }} </style>
        <div data-stale="false" class="element-container css-1e5imcs e1tzin5v1">
        <div class="stAlert">
        <div role="alert" data-baseweb="notification" class="st-ae st-af st-ag st-ah st-ai st-aj st-ak st-g3 st-am st-b8 st-ao st-ap st-aq st-ar st-as st-at st-au st-av st-aw st-ax st-ay st-az st-b9 st-b1 st-b2 st-b3 st-b4 st-b5 st-b6">
        <div class="st-b7">
        <div class="css-whx05o e13vu3m50">
        <div data-testid="stMarkdownContainer" class="css-1ekf893 e16nr0p30">
                <img src="https://raw.githubusercontent.com/borisdayma/dalle-mini/main/app/streamlit/img/loading.gif" width="30"/>
                Generating predictions for: <b>{safe_prompt}</b>
        </div>
        </div>
        </div>
        </div>
        </div>
        </div>
        <small><i>Predictions may take up to 40s under high load. Please stand by.</i></small>
        """,
        unsafe_allow_html=True,
    )

    try:
        backend_url = st.secrets["BACKEND_SERVER"]
        print(f"Getting selections: {prompt}")
        selected = get_images_from_backend(prompt, backend_url)

        margin = 0.1  # for better position of zoom in arrow
        n_columns = 3
        # Layout is [image, margin, image, margin, image]: images live in the
        # even-indexed columns, hence the "* 2" below.
        cols = st.columns([1] + [margin, 1] * (n_columns - 1))
        for i, img in enumerate(selected):
            cols[(i % n_columns) * 2].image(img)
        container.markdown(f"**{prompt}**")

        st.button("Again!", key="again_button")

    except ServiceError as error:
        container.text(f"Service unavailable, status: {error.status_code}")
    except KeyError:
        if DEBUG:
            container.markdown(
                """
                **Error: BACKEND_SERVER unset**

                Please, create a file called `.streamlit/secrets.toml` inside the app's folder and include a line to configure the server URL:
                ```
                BACKEND_SERVER="<server url>"
                ```
                """
            )
        else:
            container.markdown(
                "Error -5, please try again or [report it](mailto:pcuenca-dalle@guenever.net)."
            )
app/streamlit/img/loading.gif ADDED
dalle_mini/data.py ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+ from functools import partial
3
+
4
+ import jax
5
+ import jax.numpy as jnp
6
+ import numpy as np
7
+ from datasets import Dataset, load_dataset
8
+ from flax.training.common_utils import shard
9
+
10
+ from .text import TextNormalizer
11
+
12
+
13
@dataclass
class Dataset:
    """Load, preprocess and batch a caption / image-encoding dataset.

    The dataset is loaded in `__post_init__` with `load_dataset`. Call
    `preprocess` once to normalize/tokenize it, then `dataloader` to iterate
    over sharded batches.

    `train_dataset`, `eval_dataset` and `rng_dataset` are not constructor
    arguments: they are only set when the corresponding split/mode is used,
    which is why the code checks for them with `hasattr`.
    """

    dataset_repo_or_path: str
    train_file: str = None
    validation_file: str = None
    dataset_type: str = "dataset"
    streaming: bool = True
    use_auth_token: bool = False
    text_column: str = "caption"
    encoding_column: str = "encoding"
    max_source_length: int = 128
    max_train_samples: int = None
    max_eval_samples: int = None
    preprocessing_num_workers: int = None
    overwrite_cache: bool = False
    do_train: bool = False
    do_eval: bool = True
    seed_dataset: int = None
    train_dataset: Dataset = field(init=False)
    eval_dataset: Dataset = field(init=False)
    rng_dataset: jnp.ndarray = field(init=False)

    def __post_init__(self):
        """Load the requested splits and optionally truncate them.

        Raises:
            ValueError: if `do_train`/`do_eval` is set but the loaded dataset
                has no "train"/"validation" split.
        """
        # define data_files
        if self.train_file is not None or self.validation_file is not None:
            data_files = {
                "train": self.train_file,
                "validation": self.validation_file,
            }
        else:
            data_files = None

        # load dataset
        dataset = load_dataset(
            self.dataset_repo_or_path,
            data_files=data_files,
            streaming=self.streaming,
            use_auth_token=self.use_auth_token,
        )
        if self.do_train:
            if "train" not in dataset:
                raise ValueError("Training requires a training dataset")
            self.train_dataset = dataset["train"]
            if self.max_train_samples is not None:
                self.train_dataset = (
                    self.train_dataset.take(self.max_train_samples)
                    if self.streaming
                    else self.train_dataset.select(range(self.max_train_samples))
                )
        if self.do_eval:
            if "validation" not in dataset:
                raise ValueError("Evaluating requires a validation dataset")
            self.eval_dataset = dataset["validation"]
            if self.max_eval_samples is not None:
                self.eval_dataset = (
                    self.eval_dataset.take(self.max_eval_samples)
                    if self.streaming
                    else self.eval_dataset.select(range(self.max_eval_samples))
                )

    def preprocess(self, tokenizer, decoder_start_token_id, normalize_text):
        """Shuffle (streaming) or seed (non-streaming), normalize and tokenize.

        Args:
            tokenizer: callable used by `preprocess_function` on the captions.
            decoder_start_token_id: token id prepended to the decoder inputs.
            normalize_text: if true, captions go through `TextNormalizer`
                before tokenization.
        """
        if self.streaming:
            # we need to shuffle early in streaming mode
            if hasattr(self, "train_dataset"):
                self.train_dataset = self.train_dataset.shuffle(1000, self.seed_dataset)
        else:
            # prepare rng for later shuffling
            if self.seed_dataset is None:
                self.seed_dataset = np.random.get_state()[1][0]
            self.rng_dataset = jax.random.PRNGKey(self.seed_dataset)

        # normalize text
        if normalize_text:
            text_normalizer = TextNormalizer()
            partial_normalize_function = partial(
                normalize_function,
                text_column=self.text_column,
                text_normalizer=text_normalizer,
            )
            for ds in ["train_dataset", "eval_dataset"]:
                if hasattr(self, ds):
                    setattr(
                        self,
                        ds,
                        (
                            getattr(self, ds).map(partial_normalize_function)
                            if self.streaming
                            else getattr(self, ds).map(
                                partial_normalize_function,
                                num_proc=self.preprocessing_num_workers,
                                load_from_cache_file=not self.overwrite_cache,
                                desc="Normalizing datasets",
                            )
                        ),
                    )

        # preprocess
        partial_preprocess_function = partial(
            preprocess_function,
            tokenizer=tokenizer,
            text_column=self.text_column,
            encoding_column=self.encoding_column,
            max_source_length=self.max_source_length,
            decoder_start_token_id=decoder_start_token_id,
        )
        for ds in ["train_dataset", "eval_dataset"]:
            if hasattr(self, ds):
                setattr(
                    self,
                    ds,
                    (
                        getattr(self, ds).map(
                            partial_preprocess_function,
                            batched=True,
                        )
                        if self.streaming
                        else getattr(self, ds).map(
                            partial_preprocess_function,
                            batched=True,
                            # `ds` is the attribute *name*; the columns must be
                            # read from the dataset object itself.
                            remove_columns=getattr(self, ds).column_names,
                            num_proc=self.preprocessing_num_workers,
                            load_from_cache_file=not self.overwrite_cache,
                            desc="Preprocessing datasets",
                        )
                    ),
                )

    def dataloader(self, split, batch_size, epoch=None):
        """Return a generator of sharded batches for `split`.

        Args:
            split: "train" or "eval".
            batch_size: number of samples per yielded batch; incomplete
                trailing batches are dropped.
            epoch: forwarded to `set_epoch` on the streaming train dataset.

        Raises:
            ValueError: if `split` is neither "train" nor "eval".
        """

        def _dataloader_datasets_non_streaming(
            dataset: Dataset,
            batch_size: int,
            rng: jax.random.PRNGKey = None,
        ):
            """
            Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
            Samples are shuffled when `rng` is given, otherwise kept in order.
            """
            steps_per_epoch = len(dataset) // batch_size

            if rng is not None:
                batch_idx = jax.random.permutation(rng, len(dataset))
            else:
                batch_idx = jnp.arange(len(dataset))

            batch_idx = batch_idx[
                : steps_per_epoch * batch_size
            ]  # Skip incomplete batch.
            batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))

            for idx in batch_idx:
                batch = dataset[idx]
                batch = {k: jnp.array(v) for k, v in batch.items()}
                batch = shard(batch)
                yield batch

        def _dataloader_datasets_streaming(dataset: Dataset, batch_size: int):
            """Accumulate streamed items into sharded batches of `batch_size`."""
            keys = ["input_ids", "attention_mask", "labels", "decoder_input_ids"]
            batch = {k: [] for k in keys}
            for item in dataset:
                # Only collect the model inputs: streamed items may still
                # carry other columns (the streaming `map` in `preprocess`
                # does not remove any), which would otherwise raise KeyError.
                for k in keys:
                    batch[k].append(item[k])
                if len(batch[keys[0]]) == batch_size:
                    batch = {k: jnp.array(v) for k, v in batch.items()}
                    batch = shard(batch)
                    yield batch
                    batch = {k: [] for k in keys}

        if split == "train":
            ds = self.train_dataset
        elif split == "eval":
            ds = self.eval_dataset
        else:
            raise ValueError(f'split must be "train" or "eval", got {split}')

        if self.streaming:
            if split == "train":
                ds.set_epoch(epoch)
            return _dataloader_datasets_streaming(ds, batch_size)
        else:
            # Evaluation is not shuffled: without this default `input_rng`
            # would be unbound for the "eval" split.
            input_rng = None
            if split == "train":
                self.rng_dataset, input_rng = jax.random.split(self.rng_dataset)
            return _dataloader_datasets_non_streaming(ds, batch_size, input_rng)

    @property
    def length(self):
        """Tuple `(train_length, eval_length)`; an entry is None when unknown."""
        len_train_dataset, len_eval_dataset = None, None
        if self.streaming:
            # we don't know the length, let's just assume max_samples if defined
            if self.max_train_samples is not None:
                len_train_dataset = self.max_train_samples
            if self.max_eval_samples is not None:
                len_eval_dataset = self.max_eval_samples
        else:
            len_train_dataset = (
                len(self.train_dataset) if hasattr(self, "train_dataset") else None
            )
            len_eval_dataset = (
                len(self.eval_dataset) if hasattr(self, "eval_dataset") else None
            )
        return len_train_dataset, len_eval_dataset
213
+
214
+
215
def shift_tokens_right(input_ids: np.array, decoder_start_token_id: int):
    """
    Shift input ids one token to the right.

    The first column becomes `decoder_start_token_id` and the last token of
    every row is dropped. `input_ids` is a 2-D array (it is indexed as
    `[:, 1:]`) and is left untouched; the result has the same shape and dtype.
    """
    # zeros_like keeps the integer dtype of the ids; np.zeros(shape) would
    # silently turn them into float64.
    shifted_input_ids = np.zeros_like(input_ids)
    shifted_input_ids[:, 1:] = input_ids[:, :-1]
    shifted_input_ids[:, 0] = decoder_start_token_id
    return shifted_input_ids
223
+
224
+
225
def normalize_function(example, text_column, text_normalizer):
    """Apply `text_normalizer` to `example[text_column]` in place.

    Returns the same (mutated) `example` mapping.
    """
    normalized_text = text_normalizer(example[text_column])
    example[text_column] = normalized_text
    return example
228
+
229
+
230
def preprocess_function(
    examples,
    tokenizer,
    text_column,
    encoding_column,
    max_source_length,
    decoder_start_token_id,
):
    """Tokenize a batch of captions and attach the decoder targets.

    Returns the tokenizer output extended with:
      * "labels": `examples[encoding_column]` as a numpy array,
      * "decoder_input_ids": the labels shifted right by one position, with
        `decoder_start_token_id` in front.
    """
    captions = examples[text_column]
    # Fixed-length padding: jitted functions need inputs of constant shape.
    model_inputs = tokenizer(
        captions,
        max_length=max_source_length,
        padding="max_length",
        truncation=True,
        return_tensors="np",
    )

    # The labels are the target indices; the loss needs them in addition to
    # the decoder inputs.
    target_ids = np.asarray(examples[encoding_column])
    model_inputs["labels"] = target_ids

    # Decoder inputs: same tokens with bos prepended and the last one dropped.
    model_inputs["decoder_input_ids"] = shift_tokens_right(
        target_ids, decoder_start_token_id
    )

    return model_inputs
dalle_mini/dataset.py DELETED
@@ -1,122 +0,0 @@
1
- """
2
- An image-caption dataset dataloader.
3
- Luke Melas-Kyriazi, 2021
4
- """
5
- import warnings
6
- from typing import Optional, Callable
7
- from pathlib import Path
8
- import numpy as np
9
- import torch
10
- import pandas as pd
11
- from torch.utils.data import Dataset
12
- from torchvision.datasets.folder import default_loader
13
- from PIL import ImageFile
14
- from PIL.Image import DecompressionBombWarning
15
- ImageFile.LOAD_TRUNCATED_IMAGES = True
16
- warnings.filterwarnings("ignore", category=UserWarning)
17
- warnings.filterwarnings("ignore", category=DecompressionBombWarning)
18
-
19
-
20
class CaptionDataset(Dataset):
    """
    A PyTorch Dataset class for (image, texts) tasks. Note that this dataset
    returns the raw text rather than tokens. This is done on purpose, because
    it's easy to tokenize a batch of text after loading it from this dataset.
    """

    def __init__(self, *, images_root: str, captions_path: str, text_transform: Optional[Callable] = None,
                 image_transform: Optional[Callable] = None, image_transform_type: str = 'torchvision',
                 include_captions: bool = True):
        """
        :param images_root: folder where images are stored
        :param captions_path: path to a tab-separated file (with header) that maps image filenames to captions
        :param image_transform: image transform pipeline
        :param text_transform: text transform pipeline (stored, but not applied by `__getitem__`)
        :param image_transform_type: image transform type, either `torchvision` or `albumentations`
        :param include_captions: Returns a dictionary with `image`, `text` if `true`; otherwise returns just the images.
        """

        # Base path for images
        self.images_root = Path(images_root)

        # Load captions as DataFrame
        self.captions = pd.read_csv(captions_path, delimiter='\t', header=0)
        self.captions['image_file'] = self.captions['image_file'].astype(str)

        # PyTorch transformation pipeline for the image (normalizing, etc.)
        self.text_transform = text_transform
        self.image_transform = image_transform
        self.image_transform_type = image_transform_type.lower()
        assert self.image_transform_type in ['torchvision', 'albumentations']

        # Total number of datapoints
        self.size = len(self.captions)

        # Return image+captions or just images
        self.include_captions = include_captions

    def verify_that_all_images_exist(self):
        """Print the path of every listed image file missing under `images_root`."""
        for image_file in self.captions['image_file']:
            p = self.images_root / image_file
            if not p.is_file():
                print(f'file does not exist: {p}')

    def _get_raw_image(self, i):
        """Load the image of row `i` with torchvision's `default_loader`."""
        image_file = self.captions.iloc[i]['image_file']
        image_path = self.images_root / image_file
        image = default_loader(image_path)
        return image

    def _get_raw_text(self, i):
        """Return the untransformed caption of row `i`."""
        return self.captions.iloc[i]['caption']

    def __getitem__(self, i):
        """Return `{'image', 'text'}` for row `i`, or only the image if captions are excluded."""
        image = self._get_raw_image(i)
        caption = self._get_raw_text(i)
        if self.image_transform is not None:
            if self.image_transform_type == 'torchvision':
                image = self.image_transform(image)
            elif self.image_transform_type == 'albumentations':
                # albumentations pipelines take and return keyword dictionaries
                image = self.image_transform(image=np.array(image))['image']
            else:
                raise NotImplementedError(f"{self.image_transform_type=}")
        return {'image': image, 'text': caption} if self.include_captions else image

    def __len__(self):
        """Number of rows in the captions file."""
        return self.size
87
-
88
-
89
# Manual smoke test: build a dataset with albumentations transforms and print
# the shapes of one batch.
if __name__ == "__main__":
    import albumentations as A
    from albumentations.pytorch import ToTensorV2
    from transformers import AutoTokenizer

    # Paths
    images_root = './images'
    captions_path = './images-list-clean.tsv'

    # Create transforms
    tokenizer = AutoTokenizer.from_pretrained('distilroberta-base')
    def tokenize(text):
        return tokenizer(text, max_length=32, truncation=True, return_tensors='pt', padding='max_length')
    image_transform = A.Compose([
        A.Resize(256, 256), A.CenterCrop(256, 256),
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), ToTensorV2()])

    # Create dataset
    dataset = CaptionDataset(
        images_root=images_root,
        captions_path=captions_path,
        image_transform=image_transform,
        text_transform=tokenize,
        image_transform_type='albumentations')

    # Create dataloader
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=2)
    batch = next(iter(dataloader))
    print({k: (v.shape if isinstance(v, torch.Tensor) else v) for k, v in batch.items()})

    # # (Optional) Check that all the images exist
    # dataset = CaptionDataset(images_root=images_root, captions_path=captions_path)
    # dataset.verify_that_all_images_exist()
    # print('Done')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dalle_mini/model.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import flax.linen as nn
2
+ import jax
3
+ from transformers import BartConfig
4
+ from transformers.models.bart.modeling_flax_bart import (
5
+ FlaxBartDecoder,
6
+ FlaxBartEncoder,
7
+ FlaxBartForConditionalGeneration,
8
+ FlaxBartForConditionalGenerationModule,
9
+ FlaxBartModule,
10
+ )
11
+
12
+
13
class CustomFlaxBartModule(FlaxBartModule):
    """BART module whose decoder works on image tokens.

    The encoder keeps the `shared` text embedding, while the decoder gets its
    own embedding and a config derived from the main one, with the position
    count and vocabulary size overridden for image tokens.
    """

    def setup(self):
        # we keep shared to easily load pre-trained weights
        self.shared = nn.Embed(
            self.config.vocab_size,
            self.config.d_model,
            embedding_init=jax.nn.initializers.normal(self.config.init_std),
        )
        # a separate embedding is used for the decoder
        self.decoder_embed = nn.Embed(
            self.config.image_vocab_size + 1,
            self.config.d_model,
            embedding_init=jax.nn.initializers.normal(self.config.init_std),
        )
        self.encoder = FlaxBartEncoder(
            self.config, dtype=self.dtype, embed_tokens=self.shared
        )

        # the decoder has a different config
        # TODO: should not be needed once we have custom config/module
        # Rebuild the config from its dict with `from_dict`: passing the dict
        # positionally would bind it to BartConfig's first parameter
        # (vocab_size) and leave every other setting at its default.
        decoder_config = BartConfig.from_dict(self.config.to_dict())
        decoder_config.max_position_embeddings = (
            self.config.image_length + 1  # image tokens + BOS
        )
        decoder_config.vocab_size = self.config.image_vocab_size + 1
        self.decoder = FlaxBartDecoder(
            decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed
        )
41
+
42
+
43
class CustomFlaxBartForConditionalGenerationModule(
    FlaxBartForConditionalGenerationModule
):
    """Conditional-generation head whose output space is the image vocabulary."""

    def setup(self):
        cfg = self.config

        # set default config: keep existing values, fill in missing ones
        for option, default in (
            ("normalize_text", False),
            ("image_length", 256),
            ("image_vocab_size", 16384),
        ):
            setattr(cfg, option, getattr(cfg, option, default))

        # encoded image token space + 1 for bos
        num_outputs = cfg.image_vocab_size + 1

        self.model = CustomFlaxBartModule(config=cfg, dtype=self.dtype)
        self.lm_head = nn.Dense(
            num_outputs,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(cfg.init_std),
        )
        self.final_logits_bias = self.param(
            "final_logits_bias", self.bias_init, (1, num_outputs)
        )
61
+
62
+
63
class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):
    """`FlaxBartForConditionalGeneration` wired to the custom image-token module."""

    # Swap in the module with the separate decoder embedding and image-sized head.
    module_class = CustomFlaxBartForConditionalGenerationModule
dalle_mini/text.py ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utilities for processing text.
3
+ """
4
+
5
+ import html
6
+ import math
7
+ import random
8
+ import re
9
+ from pathlib import Path
10
+
11
+ import ftfy
12
+ from huggingface_hub import hf_hub_download
13
+ from unidecode import unidecode
14
+
15
# based on wiki word occurrence
# (replacement text, weight) pairs; replace_person_token samples from them
person_token = [("a person", 282265), ("someone", 121194), ("somebody", 12219)]
# marker used to build temporary placeholders in the pre_/post_process helpers
temp_token = "xtokx"  # avoid repeating chars
18
+
19
+
20
class HashtagProcessor:
    """Split a hashtag (or any string without spaces) into likely words.

    Adapted from the wordninja library: the first word of each line of a
    frequency file gets a cost of `log(rank + 1)`, and the cheapest
    segmentation is found by dynamic programming.
    """

    # Adapted from wordninja library
    # We use our wikipedia word count + a good heuristic to make it work
    def __init__(self):
        wiki_word_frequency = hf_hub_download(
            "dalle-mini/dalle-mini", filename="enwiki-words-frequency.txt"
        )
        # Explicit encoding so the result does not depend on the platform
        # default (assumes the file is UTF-8 - TODO confirm); blank lines are
        # skipped instead of raising IndexError.
        lines = Path(wiki_word_frequency).read_text(encoding="utf-8").splitlines()
        words = (line.split()[0] for line in lines if line.strip())
        self._word_cost = {
            str(word): math.log(float(rank + 1)) for rank, word in enumerate(words)
        }
        self._max_word = max(len(word) for word in self._word_cost)
        self._SPLIT_RE = re.compile("[^a-zA-Z0-9']+")

    def __call__(self, s):
        """Uses dynamic programming to infer the location of spaces in a string without spaces."""
        pieces = [self._split(x) for x in self._SPLIT_RE.split(s)]
        return " ".join([item for sublist in pieces for item in sublist])

    def _split(self, s):
        """Return an iterator over the minimal-cost words of `s`, in order."""

        # Find the best match for the i first characters, assuming cost has
        # been built for the i-1 first characters.
        # Returns a pair (match_cost, match_length).
        def best_match(i):
            candidates = enumerate(reversed(cost[max(0, i - self._max_word) : i]))
            return min(
                # 9e999 overflows to infinity: unknown words are never chosen
                # when a known segmentation exists
                (c + self._word_cost.get(s[i - k - 1 : i].lower(), 9e999), k + 1)
                for k, c in candidates
            )

        # Build the cost array
        cost = [0]
        for i in range(1, len(s) + 1):
            c, _ = best_match(i)
            cost.append(c)

        # Backtrack to recover the minimal-cost string.
        out = []
        i = len(s)
        while i > 0:
            c, k = best_match(i)
            assert c == cost[i]
            newToken = True
            if not s[i - k : i] == "'":  # ignore a lone apostrophe
                if len(out) > 0:
                    # re-attach split 's and split digits
                    if out[-1] == "'s" or (
                        s[i - 1].isdigit() and out[-1][0].isdigit()
                    ):  # digit followed by digit
                        out[-1] = (
                            s[i - k : i] + out[-1]
                        )  # combine current token with previous token
                        newToken = False

            if newToken:
                out.append(s[i - k : i])

            i -= k

        return reversed(out)
82
+
83
+
84
def replace_person_token(t):
    """Replace `<person>` placeholders with natural wording (used for CC12M).

    Runs of `<person>` joined by commas, spaces or "and" become " people ";
    each remaining `<person>` becomes an entry sampled from `person_token`
    according to its weight.
    """
    # raw string: "\s" in a normal literal is an invalid escape sequence
    t = re.sub(r"<person>([,\s]*(and)*[,\s]*<person>)+", " people ", t)
    while "<person>" in t:
        t = t.replace(
            "<person>", f" {random.choices(*tuple(zip(*person_token)))[0]} ", 1
        )
    return t
92
+
93
+
94
def fix_html(t):
    """Unescape HTML entities twice, so doubly-escaped text is decoded too."""
    # from OpenAI CLIP
    once = html.unescape(t)
    return html.unescape(once)
97
+
98
+
99
def replace_punctuation_with_commas(t):
    """Replace each punctuation character of the set below with a comma."""
    # raw string: "\]", "\-" and "\/" are invalid escapes in a normal literal
    return re.sub(r"[()[\].,|:;?!=+~\-\/{}]", ",", t)
101
+
102
+
103
def simplify_quotes(t):
    """Turn every single quote, double quote or backtick into ` " `."""
    any_quote = """['"`]"""
    return re.sub(any_quote, ' " ', t)
105
+
106
+
107
def merge_quotes(t):
    """Collapse runs of double quotes (and surrounding spaces) into one ` " `."""
    # raw string: "\s" in a normal literal is an invalid escape sequence
    return re.sub(r'(\s*"+\s*)+', ' " ', t)
109
+
110
+
111
def remove_comma_numbers(t):
    """Remove thousands separators from numbers ("1,000" -> "1000")."""

    def _f(t):
        # raw string: "\d" in a normal literal is an invalid escape sequence
        return re.sub(r"(\d),(\d{3})", r"\1\2", t)

    # Two passes: a single pass cannot match overlapping groups such as the
    # second separator of "1,000,000".
    return _f(_f(t))
116
+
117
+
118
def pre_process_dot_numbers(t):
    """Protect a dot between two word characters with a temporary placeholder."""
    # raw string: "\w" and "\." are invalid escapes in a normal literal
    return re.sub(r"(\w)\.(\w)", fr"\1{temp_token}dot{temp_token}\2", t)
120
+
121
+
122
def post_process_dot_numbers(t):
    """Turn the temporary dot placeholder back into "."."""
    placeholder = f"{temp_token}dot{temp_token}"
    return re.sub(placeholder, ".", t)
124
+
125
+
126
def pre_process_quotes(t):
    """Protect apostrophes of common contractions with a temporary placeholder."""
    # allows quotes only for 's, 't, 'd, 'm, 'll, 're, 've
    contraction_quote = r"'(?=([stdm]|(ll)|(re)|(ve)|(ll))\b)"
    placeholder = fr"{temp_token}quote{temp_token}"
    return re.sub(contraction_quote, placeholder, t)
131
+
132
+
133
def post_process_quotes(t):
    """Turn the temporary quote placeholder back into an apostrophe."""
    placeholder = f"{temp_token}quote{temp_token}"
    return re.sub(placeholder, "'", t)
135
+
136
+
137
def pre_process_dates(t):
    """Protect a slash between two digits with a temporary placeholder."""
    # raw string: "\d" in a normal literal is an invalid escape sequence
    return re.sub(r"(\d)/(\d)", fr"\1{temp_token}slash{temp_token}\2", t)
139
+
140
+
141
def post_process_dates(t):
    """Turn the temporary slash placeholder back into "/"."""
    placeholder = f"{temp_token}slash{temp_token}"
    return re.sub(placeholder, "/", t)
143
+
144
+
145
def merge_commas(t):
    """Collapse runs of commas (and surrounding spaces) into a single ", "."""
    # raw string: "\s" in a normal literal is an invalid escape sequence
    return re.sub(r"(\s*,+\s*)+", ", ", t)
147
+
148
+
149
def add_space_after_commas(t):
    """Insert a space after every comma."""
    return t.replace(",", ", ")
151
+
152
+
153
def handle_special_chars(t):
    "Handle special characters"
    # raw strings below: "\w" and "\/" are invalid escapes in normal literals
    # replace "-" with a space when between words without space
    t = re.sub(r"(\w)-(\w)", r"\1 \2", t)
    # always add space around some characters
    return re.sub(r"([%&\/$*])", r" \1 ", t)
159
+
160
+
161
def expand_hashtags(t, hashtag_processor):
    "Remove # and try to split words"
    # raw string: "\w" in a normal literal is an invalid escape sequence
    return re.sub(r"#(\w+)", lambda m: hashtag_processor(m.group(1)), t)
164
+
165
+
166
# characters treated as noise: underscore, hash and backslash
_re_ignore_chars = r"[_#\\]"


def ignore_chars(t):
    """Replace each ignored character (see `_re_ignore_chars`) with a space."""
    cleaned = re.sub(_re_ignore_chars, " ", t)
    return cleaned
172
+
173
+
174
def remove_extra_spaces(t):
    """Collapse every run of whitespace (spaces, tabs, newlines) into one space."""
    # raw string: "\s" in a normal literal is an invalid escape sequence
    return re.sub(r"\s+", " ", t)
177
+
178
+
179
def remove_repeating_chars(t):
    """Reduce a non-digit character repeated 4+ times to a single instance.

    Three repetitions are kept because of roman numerals such as 'VIII'.
    """
    four_or_more = r"(\D)(\1{3,})"
    return re.sub(four_or_more, r"\1", t)
182
+
183
+
184
def remove_urls(t):
    """Delete every run of non-space characters that starts with "http"."""
    url_pattern = r"http\S+"
    return re.sub(url_pattern, "", t)
186
+
187
+
188
def remove_html_tags(t):
    """Strip `<...>` tags, keeping the text between them."""
    tag_pattern = "<[^<]+?>"
    return re.sub(tag_pattern, "", t)
190
+
191
+
192
def remove_first_last_commas(t):
    """Strip whitespace, then at most one trailing and one leading comma."""
    t = t.strip()
    if t.endswith(","):
        t = t[:-1]
    if t.startswith(","):
        t = t[1:]
    return t.strip()
197
+
198
+
199
def remove_wiki_ref(t):
    """Remove a `[number]` reference at the very start and/or end of the text."""
    leading_ref = r"\A\s*\[\d+\]"
    trailing_ref = r"\[\d+\]\s*\Z"
    without_leading = re.sub(leading_ref, "", t)
    return re.sub(trailing_ref, "", without_leading)
202
+
203
+
204
class TextNormalizer:
    """Normalize text.

    Calling an instance runs the text through a fixed sequence of cleaning
    steps and returns the result prefixed with a single space. The order of
    the steps matters: the `pre_process_*` helpers protect some characters
    with placeholders before punctuation is rewritten, and the matching
    `post_process_*` helpers restore them afterwards.
    """

    def __init__(self):
        # Constructing the processor loads its word-frequency data.
        self._hashtag_processor = HashtagProcessor()

    def __call__(self, t):
        # fix some characters
        t = ftfy.fix_text(t)
        # fix html
        t = fix_html(t)
        # decode and simplify text: see unidecode library
        t = unidecode(t)
        # lower case
        t = t.lower()
        # replace <PERSON> (for CC12M)
        t = replace_person_token(t)
        # remove wiki reference (for WIT)
        t = remove_wiki_ref(t)
        # remove html tags
        t = remove_html_tags(t)
        # remove urls
        t = remove_urls(t)
        # remove commas in numbers
        t = remove_comma_numbers(t)
        # handle dots in numbers and quotes - Part 1
        t = pre_process_dot_numbers(t)
        t = pre_process_quotes(t)
        t = pre_process_dates(t)
        # handle special characters
        t = handle_special_chars(t)
        # handle hashtags
        t = expand_hashtags(t, self._hashtag_processor)
        # ignore useless characters
        t = ignore_chars(t)
        # simplify quotes
        t = simplify_quotes(t)
        # all punctuation becomes commas
        t = replace_punctuation_with_commas(t)
        # handle dots in numbers and quotes - Part 2
        t = post_process_dot_numbers(t)
        t = post_process_quotes(t)
        t = post_process_dates(t)
        # handle repeating characters
        t = remove_repeating_chars(t)
        # merge quotes
        t = merge_quotes(t)
        # merge commas
        t = merge_commas(t)
        # remove multiple spaces
        t = remove_extra_spaces(t)
        # remove first and last comma
        t = remove_first_last_commas(t)
        # always start with a space
        return f" {t}"
dalle_mini/vqgan_jax/__init__.py DELETED
File without changes
dalle_mini/vqgan_jax/configuration_vqgan.py DELETED
@@ -1,40 +0,0 @@
1
- from typing import Tuple
2
-
3
- from transformers import PretrainedConfig
4
-
5
-
6
class VQGANConfig(PretrainedConfig):
    """Configuration holder for the VQGAN model.

    Every constructor argument is stored as an attribute of the same name;
    remaining keyword arguments are forwarded to `PretrainedConfig`.
    """

    def __init__(
        self,
        ch: int = 128,
        out_ch: int = 3,
        in_channels: int = 3,
        num_res_blocks: int = 2,
        resolution: int = 256,
        z_channels: int = 256,
        ch_mult: Tuple = (1, 1, 2, 2, 4),
        attn_resolutions: Tuple = (16,),
        n_embed: int = 1024,
        embed_dim: int = 256,
        dropout: float = 0.0,
        double_z: bool = False,
        resamp_with_conv: bool = True,
        give_pre_end: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.ch = ch
        self.out_ch = out_ch
        self.in_channels = in_channels
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.z_channels = z_channels
        # tuples are stored as lists
        self.ch_mult = list(ch_mult)
        self.attn_resolutions = list(attn_resolutions)
        self.n_embed = n_embed
        self.embed_dim = embed_dim
        self.dropout = dropout
        self.double_z = double_z
        self.resamp_with_conv = resamp_with_conv
        self.give_pre_end = give_pre_end
        # derived: one resolution level per channel multiplier
        self.num_resolutions = len(ch_mult)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dalle_mini/vqgan_jax/convert_pt_model_to_jax.py DELETED
@@ -1,109 +0,0 @@
1
- import re
2
-
3
- import jax.numpy as jnp
4
- from flax.traverse_util import flatten_dict, unflatten_dict
5
-
6
- import torch
7
-
8
- from modeling_flax_vqgan import VQModel
9
- from configuration_vqgan import VQGANConfig
10
-
11
-
12
- regex = r"\w+[.]\d+"
13
-
14
-
15
- def rename_key(key):
16
- pats = re.findall(regex, key)
17
- for pat in pats:
18
- key = key.replace(pat, "_".join(pat.split(".")))
19
- return key
20
-
21
-
22
- # Adapted from https://github.com/huggingface/transformers/blob/ff5cdc086be1e0c3e2bbad8e3469b34cffb55a85/src/transformers/modeling_flax_pytorch_utils.py#L61
23
- def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
24
- # convert pytorch tensor to numpy
25
- pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
26
-
27
- random_flax_state_dict = flatten_dict(flax_model.params)
28
- flax_state_dict = {}
29
-
30
- remove_base_model_prefix = (flax_model.base_model_prefix not in flax_model.params) and (
31
- flax_model.base_model_prefix in set([k.split(".")[0] for k in pt_state_dict.keys()])
32
- )
33
- add_base_model_prefix = (flax_model.base_model_prefix in flax_model.params) and (
34
- flax_model.base_model_prefix not in set([k.split(".")[0] for k in pt_state_dict.keys()])
35
- )
36
-
37
- # Need to change some parameters name to match Flax names so that we don't have to fork any layer
38
- for pt_key, pt_tensor in pt_state_dict.items():
39
- pt_tuple_key = tuple(pt_key.split("."))
40
-
41
- has_base_model_prefix = pt_tuple_key[0] == flax_model.base_model_prefix
42
- require_base_model_prefix = (flax_model.base_model_prefix,) + pt_tuple_key in random_flax_state_dict
43
-
44
- if remove_base_model_prefix and has_base_model_prefix:
45
- pt_tuple_key = pt_tuple_key[1:]
46
- elif add_base_model_prefix and require_base_model_prefix:
47
- pt_tuple_key = (flax_model.base_model_prefix,) + pt_tuple_key
48
-
49
- # Correctly rename weight parameters
50
- if (
51
- "norm" in pt_key
52
- and (pt_tuple_key[-1] == "bias")
53
- and (pt_tuple_key[:-1] + ("bias",) in random_flax_state_dict)
54
- ):
55
- pt_tensor = pt_tensor[None, None, None, :]
56
- elif (
57
- "norm" in pt_key
58
- and (pt_tuple_key[-1] == "bias")
59
- and (pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict)
60
- ):
61
- pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
62
- pt_tensor = pt_tensor[None, None, None, :]
63
- elif pt_tuple_key[-1] in ["weight", "gamma"] and pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict:
64
- pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
65
- pt_tensor = pt_tensor[None, None, None, :]
66
- if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + ("embedding",) in random_flax_state_dict:
67
- pt_tuple_key = pt_tuple_key[:-1] + ("embedding",)
68
- elif pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4 and pt_tuple_key not in random_flax_state_dict:
69
- # conv layer
70
- pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
71
- pt_tensor = pt_tensor.transpose(2, 3, 1, 0)
72
- elif pt_tuple_key[-1] == "weight" and pt_tuple_key not in random_flax_state_dict:
73
- # linear layer
74
- pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
75
- pt_tensor = pt_tensor.T
76
- elif pt_tuple_key[-1] == "gamma":
77
- pt_tuple_key = pt_tuple_key[:-1] + ("weight",)
78
- elif pt_tuple_key[-1] == "beta":
79
- pt_tuple_key = pt_tuple_key[:-1] + ("bias",)
80
-
81
- if pt_tuple_key in random_flax_state_dict:
82
- if pt_tensor.shape != random_flax_state_dict[pt_tuple_key].shape:
83
- raise ValueError(
84
- f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
85
- f"{random_flax_state_dict[pt_tuple_key].shape}, but is {pt_tensor.shape}."
86
- )
87
-
88
- # also add unexpected weight so that warning is thrown
89
- flax_state_dict[pt_tuple_key] = jnp.asarray(pt_tensor)
90
-
91
- return unflatten_dict(flax_state_dict)
92
-
93
-
94
- def convert_model(config_path, pt_state_dict_path, save_path):
95
- config = VQGANConfig.from_pretrained(config_path)
96
- model = VQModel(config)
97
-
98
- state_dict = torch.load(pt_state_dict_path, map_location="cpu")["state_dict"]
99
- keys = list(state_dict.keys())
100
- for key in keys:
101
- if key.startswith("loss"):
102
- state_dict.pop(key)
103
- continue
104
- renamed_key = rename_key(key)
105
- state_dict[renamed_key] = state_dict.pop(key)
106
-
107
- state = convert_pytorch_state_dict_to_flax(state_dict, model)
108
- model.params = unflatten_dict(state)
109
- model.save_pretrained(save_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dalle_mini/vqgan_jax/modeling_flax_vqgan.py DELETED
@@ -1,609 +0,0 @@
1
- # JAX implementation of VQGAN from taming-transformers https://github.com/CompVis/taming-transformers
2
-
3
- from functools import partial
4
- from typing import Tuple
5
- import math
6
-
7
- import jax
8
- import jax.numpy as jnp
9
- import numpy as np
10
- import flax.linen as nn
11
- from flax.core.frozen_dict import FrozenDict
12
-
13
- from transformers.modeling_flax_utils import FlaxPreTrainedModel
14
-
15
- from .configuration_vqgan import VQGANConfig
16
-
17
-
18
- class Upsample(nn.Module):
19
- in_channels: int
20
- with_conv: bool
21
- dtype: jnp.dtype = jnp.float32
22
-
23
- def setup(self):
24
- if self.with_conv:
25
- self.conv = nn.Conv(
26
- self.in_channels,
27
- kernel_size=(3, 3),
28
- strides=(1, 1),
29
- padding=((1, 1), (1, 1)),
30
- dtype=self.dtype,
31
- )
32
-
33
- def __call__(self, hidden_states):
34
- batch, height, width, channels = hidden_states.shape
35
- hidden_states = jax.image.resize(
36
- hidden_states,
37
- shape=(batch, height * 2, width * 2, channels),
38
- method="nearest",
39
- )
40
- if self.with_conv:
41
- hidden_states = self.conv(hidden_states)
42
- return hidden_states
43
-
44
-
45
- class Downsample(nn.Module):
46
- in_channels: int
47
- with_conv: bool
48
- dtype: jnp.dtype = jnp.float32
49
-
50
- def setup(self):
51
- if self.with_conv:
52
- self.conv = nn.Conv(
53
- self.in_channels,
54
- kernel_size=(3, 3),
55
- strides=(2, 2),
56
- padding="VALID",
57
- dtype=self.dtype,
58
- )
59
-
60
- def __call__(self, hidden_states):
61
- if self.with_conv:
62
- pad = ((0, 0), (0, 1), (0, 1), (0, 0)) # pad height and width dim
63
- hidden_states = jnp.pad(hidden_states, pad_width=pad)
64
- hidden_states = self.conv(hidden_states)
65
- else:
66
- hidden_states = nn.avg_pool(hidden_states, window_shape=(2, 2), strides=(2, 2), padding="VALID")
67
- return hidden_states
68
-
69
-
70
- class ResnetBlock(nn.Module):
71
- in_channels: int
72
- out_channels: int = None
73
- use_conv_shortcut: bool = False
74
- temb_channels: int = 512
75
- dropout_prob: float = 0.0
76
- dtype: jnp.dtype = jnp.float32
77
-
78
- def setup(self):
79
- self.out_channels_ = self.in_channels if self.out_channels is None else self.out_channels
80
-
81
- self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-6)
82
- self.conv1 = nn.Conv(
83
- self.out_channels_,
84
- kernel_size=(3, 3),
85
- strides=(1, 1),
86
- padding=((1, 1), (1, 1)),
87
- dtype=self.dtype,
88
- )
89
-
90
- if self.temb_channels:
91
- self.temb_proj = nn.Dense(self.out_channels_, dtype=self.dtype)
92
-
93
- self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-6)
94
- self.dropout = nn.Dropout(self.dropout_prob)
95
- self.conv2 = nn.Conv(
96
- self.out_channels_,
97
- kernel_size=(3, 3),
98
- strides=(1, 1),
99
- padding=((1, 1), (1, 1)),
100
- dtype=self.dtype,
101
- )
102
-
103
- if self.in_channels != self.out_channels_:
104
- if self.use_conv_shortcut:
105
- self.conv_shortcut = nn.Conv(
106
- self.out_channels_,
107
- kernel_size=(3, 3),
108
- strides=(1, 1),
109
- padding=((1, 1), (1, 1)),
110
- dtype=self.dtype,
111
- )
112
- else:
113
- self.nin_shortcut = nn.Conv(
114
- self.out_channels_,
115
- kernel_size=(1, 1),
116
- strides=(1, 1),
117
- padding="VALID",
118
- dtype=self.dtype,
119
- )
120
-
121
- def __call__(self, hidden_states, temb=None, deterministic: bool = True):
122
- residual = hidden_states
123
- hidden_states = self.norm1(hidden_states)
124
- hidden_states = nn.swish(hidden_states)
125
- hidden_states = self.conv1(hidden_states)
126
-
127
- if temb is not None:
128
- hidden_states = hidden_states + self.temb_proj(nn.swish(temb))[:, :, None, None] # TODO: check shapes
129
-
130
- hidden_states = self.norm2(hidden_states)
131
- hidden_states = nn.swish(hidden_states)
132
- hidden_states = self.dropout(hidden_states, deterministic)
133
- hidden_states = self.conv2(hidden_states)
134
-
135
- if self.in_channels != self.out_channels_:
136
- if self.use_conv_shortcut:
137
- residual = self.conv_shortcut(residual)
138
- else:
139
- residual = self.nin_shortcut(residual)
140
-
141
- return hidden_states + residual
142
-
143
-
144
- class AttnBlock(nn.Module):
145
- in_channels: int
146
- dtype: jnp.dtype = jnp.float32
147
-
148
- def setup(self):
149
- conv = partial(
150
- nn.Conv, self.in_channels, kernel_size=(1, 1), strides=(1, 1), padding="VALID", dtype=self.dtype
151
- )
152
-
153
- self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-6)
154
- self.q, self.k, self.v = conv(), conv(), conv()
155
- self.proj_out = conv()
156
-
157
- def __call__(self, hidden_states):
158
- residual = hidden_states
159
- hidden_states = self.norm(hidden_states)
160
-
161
- query = self.q(hidden_states)
162
- key = self.k(hidden_states)
163
- value = self.v(hidden_states)
164
-
165
- # compute attentions
166
- batch, height, width, channels = query.shape
167
- query = query.reshape((batch, height * width, channels))
168
- key = key.reshape((batch, height * width, channels))
169
- attn_weights = jnp.einsum("...qc,...kc->...qk", query, key)
170
- attn_weights = attn_weights * (int(channels) ** -0.5)
171
- attn_weights = nn.softmax(attn_weights, axis=2)
172
-
173
- ## attend to values
174
- value = value.reshape((batch, height * width, channels))
175
- hidden_states = jnp.einsum("...kc,...qk->...qc", value, attn_weights)
176
- hidden_states = hidden_states.reshape((batch, height, width, channels))
177
-
178
- hidden_states = self.proj_out(hidden_states)
179
- hidden_states = hidden_states + residual
180
- return hidden_states
181
-
182
-
183
- class UpsamplingBlock(nn.Module):
184
- config: VQGANConfig
185
- curr_res: int
186
- block_idx: int
187
- dtype: jnp.dtype = jnp.float32
188
-
189
- def setup(self):
190
- if self.block_idx == self.config.num_resolutions - 1:
191
- block_in = self.config.ch * self.config.ch_mult[-1]
192
- else:
193
- block_in = self.config.ch * self.config.ch_mult[self.block_idx + 1]
194
-
195
- block_out = self.config.ch * self.config.ch_mult[self.block_idx]
196
- self.temb_ch = 0
197
-
198
- res_blocks = []
199
- attn_blocks = []
200
- for _ in range(self.config.num_res_blocks + 1):
201
- res_blocks.append(
202
- ResnetBlock(
203
- block_in, block_out, temb_channels=self.temb_ch, dropout_prob=self.config.dropout, dtype=self.dtype
204
- )
205
- )
206
- block_in = block_out
207
- if self.curr_res in self.config.attn_resolutions:
208
- attn_blocks.append(AttnBlock(block_in, dtype=self.dtype))
209
-
210
- self.block = res_blocks
211
- self.attn = attn_blocks
212
-
213
- self.upsample = None
214
- if self.block_idx != 0:
215
- self.upsample = Upsample(block_in, self.config.resamp_with_conv, dtype=self.dtype)
216
-
217
- def __call__(self, hidden_states, temb=None, deterministic: bool = True):
218
- for res_block in self.block:
219
- hidden_states = res_block(hidden_states, temb, deterministic=deterministic)
220
- for attn_block in self.attn:
221
- hidden_states = attn_block(hidden_states)
222
-
223
- if self.upsample is not None:
224
- hidden_states = self.upsample(hidden_states)
225
-
226
- return hidden_states
227
-
228
-
229
- class DownsamplingBlock(nn.Module):
230
- config: VQGANConfig
231
- curr_res: int
232
- block_idx: int
233
- dtype: jnp.dtype = jnp.float32
234
-
235
- def setup(self):
236
- in_ch_mult = (1,) + tuple(self.config.ch_mult)
237
- block_in = self.config.ch * in_ch_mult[self.block_idx]
238
- block_out = self.config.ch * self.config.ch_mult[self.block_idx]
239
- self.temb_ch = 0
240
-
241
- res_blocks = []
242
- attn_blocks = []
243
- for _ in range(self.config.num_res_blocks):
244
- res_blocks.append(
245
- ResnetBlock(
246
- block_in, block_out, temb_channels=self.temb_ch, dropout_prob=self.config.dropout, dtype=self.dtype
247
- )
248
- )
249
- block_in = block_out
250
- if self.curr_res in self.config.attn_resolutions:
251
- attn_blocks.append(AttnBlock(block_in, dtype=self.dtype))
252
-
253
- self.block = res_blocks
254
- self.attn = attn_blocks
255
-
256
- self.downsample = None
257
- if self.block_idx != self.config.num_resolutions - 1:
258
- self.downsample = Downsample(block_in, self.config.resamp_with_conv, dtype=self.dtype)
259
-
260
- def __call__(self, hidden_states, temb=None, deterministic: bool = True):
261
- for res_block in self.block:
262
- hidden_states = res_block(hidden_states, temb, deterministic=deterministic)
263
- for attn_block in self.attn:
264
- hidden_states = attn_block(hidden_states)
265
-
266
- if self.downsample is not None:
267
- hidden_states = self.downsample(hidden_states)
268
-
269
- return hidden_states
270
-
271
-
272
- class MidBlock(nn.Module):
273
- in_channels: int
274
- temb_channels: int
275
- dropout: float
276
- dtype: jnp.dtype = jnp.float32
277
-
278
- def setup(self):
279
- self.block_1 = ResnetBlock(
280
- self.in_channels,
281
- self.in_channels,
282
- temb_channels=self.temb_channels,
283
- dropout_prob=self.dropout,
284
- dtype=self.dtype,
285
- )
286
- self.attn_1 = AttnBlock(self.in_channels, dtype=self.dtype)
287
- self.block_2 = ResnetBlock(
288
- self.in_channels,
289
- self.in_channels,
290
- temb_channels=self.temb_channels,
291
- dropout_prob=self.dropout,
292
- dtype=self.dtype,
293
- )
294
-
295
- def __call__(self, hidden_states, temb=None, deterministic: bool = True):
296
- hidden_states = self.block_1(hidden_states, temb, deterministic=deterministic)
297
- hidden_states = self.attn_1(hidden_states)
298
- hidden_states = self.block_2(hidden_states, temb, deterministic=deterministic)
299
- return hidden_states
300
-
301
-
302
- class Encoder(nn.Module):
303
- config: VQGANConfig
304
- dtype: jnp.dtype = jnp.float32
305
-
306
- def setup(self):
307
- self.temb_ch = 0
308
-
309
- # downsampling
310
- self.conv_in = nn.Conv(
311
- self.config.ch,
312
- kernel_size=(3, 3),
313
- strides=(1, 1),
314
- padding=((1, 1), (1, 1)),
315
- dtype=self.dtype,
316
- )
317
-
318
- curr_res = self.config.resolution
319
- downsample_blocks = []
320
- for i_level in range(self.config.num_resolutions):
321
- downsample_blocks.append(DownsamplingBlock(self.config, curr_res, block_idx=i_level, dtype=self.dtype))
322
-
323
- if i_level != self.config.num_resolutions - 1:
324
- curr_res = curr_res // 2
325
- self.down = downsample_blocks
326
-
327
- # middle
328
- mid_channels = self.config.ch * self.config.ch_mult[-1]
329
- self.mid = MidBlock(mid_channels, self.temb_ch, self.config.dropout, dtype=self.dtype)
330
-
331
- # end
332
- self.norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-6)
333
- self.conv_out = nn.Conv(
334
- 2 * self.config.z_channels if self.config.double_z else self.config.z_channels,
335
- kernel_size=(3, 3),
336
- strides=(1, 1),
337
- padding=((1, 1), (1, 1)),
338
- dtype=self.dtype,
339
- )
340
-
341
- def __call__(self, pixel_values, deterministic: bool = True):
342
- # timestep embedding
343
- temb = None
344
-
345
- # downsampling
346
- hidden_states = self.conv_in(pixel_values)
347
- for block in self.down:
348
- hidden_states = block(hidden_states, temb, deterministic=deterministic)
349
-
350
- # middle
351
- hidden_states = self.mid(hidden_states, temb, deterministic=deterministic)
352
-
353
- # end
354
- hidden_states = self.norm_out(hidden_states)
355
- hidden_states = nn.swish(hidden_states)
356
- hidden_states = self.conv_out(hidden_states)
357
-
358
- return hidden_states
359
-
360
-
361
- class Decoder(nn.Module):
362
- config: VQGANConfig
363
- dtype: jnp.dtype = jnp.float32
364
-
365
- def setup(self):
366
- self.temb_ch = 0
367
-
368
- # compute in_ch_mult, block_in and curr_res at lowest res
369
- block_in = self.config.ch * self.config.ch_mult[self.config.num_resolutions - 1]
370
- curr_res = self.config.resolution // 2 ** (self.config.num_resolutions - 1)
371
- self.z_shape = (1, self.config.z_channels, curr_res, curr_res)
372
- print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
373
-
374
- # z to block_in
375
- self.conv_in = nn.Conv(
376
- block_in,
377
- kernel_size=(3, 3),
378
- strides=(1, 1),
379
- padding=((1, 1), (1, 1)),
380
- dtype=self.dtype,
381
- )
382
-
383
- # middle
384
- self.mid = MidBlock(block_in, self.temb_ch, self.config.dropout, dtype=self.dtype)
385
-
386
- # upsampling
387
- upsample_blocks = []
388
- for i_level in reversed(range(self.config.num_resolutions)):
389
- upsample_blocks.append(UpsamplingBlock(self.config, curr_res, block_idx=i_level, dtype=self.dtype))
390
- if i_level != 0:
391
- curr_res = curr_res * 2
392
- self.up = list(reversed(upsample_blocks)) # reverse to get consistent order
393
-
394
- # end
395
- self.norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-6)
396
- self.conv_out = nn.Conv(
397
- self.config.out_ch,
398
- kernel_size=(3, 3),
399
- strides=(1, 1),
400
- padding=((1, 1), (1, 1)),
401
- dtype=self.dtype,
402
- )
403
-
404
- def __call__(self, hidden_states, deterministic: bool = True):
405
- # timestep embedding
406
- temb = None
407
-
408
- # z to block_in
409
- hidden_states = self.conv_in(hidden_states)
410
-
411
- # middle
412
- hidden_states = self.mid(hidden_states, temb, deterministic=deterministic)
413
-
414
- # upsampling
415
- for block in reversed(self.up):
416
- hidden_states = block(hidden_states, temb, deterministic=deterministic)
417
-
418
- # end
419
- if self.config.give_pre_end:
420
- return hidden_states
421
-
422
- hidden_states = self.norm_out(hidden_states)
423
- hidden_states = nn.swish(hidden_states)
424
- hidden_states = self.conv_out(hidden_states)
425
-
426
- return hidden_states
427
-
428
-
429
- class VectorQuantizer(nn.Module):
430
- """
431
- see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
432
- ____________________________________________
433
- Discretization bottleneck part of the VQ-VAE.
434
- Inputs:
435
- - n_e : number of embeddings
436
- - e_dim : dimension of embedding
437
- - beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
438
- _____________________________________________
439
- """
440
-
441
- config: VQGANConfig
442
- dtype: jnp.dtype = jnp.float32
443
-
444
- def setup(self):
445
- self.embedding = nn.Embed(self.config.n_embed, self.config.embed_dim, dtype=self.dtype) # TODO: init
446
-
447
- def __call__(self, hidden_states):
448
- """
449
- Inputs the output of the encoder network z and maps it to a discrete
450
- one-hot vector that is the index of the closest embedding vector e_j
451
- z (continuous) -> z_q (discrete)
452
- z.shape = (batch, channel, height, width)
453
- quantization pipeline:
454
- 1. get encoder input (B,C,H,W)
455
- 2. flatten input to (B*H*W,C)
456
- """
457
- # flatten
458
- hidden_states_flattended = hidden_states.reshape((-1, self.config.embed_dim))
459
-
460
- # dummy op to init the weights, so we can access them below
461
- self.embedding(jnp.ones((1, 1), dtype="i4"))
462
-
463
- # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
464
- emb_weights = self.variables["params"]["embedding"]["embedding"]
465
- distance = (
466
- jnp.sum(hidden_states_flattended ** 2, axis=1, keepdims=True)
467
- + jnp.sum(emb_weights ** 2, axis=1)
468
- - 2 * jnp.dot(hidden_states_flattended, emb_weights.T)
469
- )
470
-
471
- # get quantized latent vectors
472
- min_encoding_indices = jnp.argmin(distance, axis=1)
473
- z_q = self.embedding(min_encoding_indices).reshape(hidden_states.shape)
474
-
475
- # reshape to (batch, num_tokens)
476
- min_encoding_indices = min_encoding_indices.reshape(hidden_states.shape[0], -1)
477
-
478
- # compute the codebook_loss (q_loss) outside the model
479
- # here we return the embeddings and indices
480
- return z_q, min_encoding_indices
481
-
482
- def get_codebook_entry(self, indices, shape=None):
483
- # indices are expected to be of shape (batch, num_tokens)
484
- # get quantized latent vectors
485
- batch, num_tokens = indices.shape
486
- z_q = self.embedding(indices)
487
- z_q = z_q.reshape(batch, int(math.sqrt(num_tokens)), int(math.sqrt(num_tokens)), -1)
488
- return z_q
489
-
490
-
491
- class VQModule(nn.Module):
492
- config: VQGANConfig
493
- dtype: jnp.dtype = jnp.float32
494
-
495
- def setup(self):
496
- self.encoder = Encoder(self.config, dtype=self.dtype)
497
- self.decoder = Decoder(self.config, dtype=self.dtype)
498
- self.quantize = VectorQuantizer(self.config, dtype=self.dtype)
499
- self.quant_conv = nn.Conv(
500
- self.config.embed_dim,
501
- kernel_size=(1, 1),
502
- strides=(1, 1),
503
- padding="VALID",
504
- dtype=self.dtype,
505
- )
506
- self.post_quant_conv = nn.Conv(
507
- self.config.z_channels,
508
- kernel_size=(1, 1),
509
- strides=(1, 1),
510
- padding="VALID",
511
- dtype=self.dtype,
512
- )
513
-
514
- def encode(self, pixel_values, deterministic: bool = True):
515
- hidden_states = self.encoder(pixel_values, deterministic=deterministic)
516
- hidden_states = self.quant_conv(hidden_states)
517
- quant_states, indices = self.quantize(hidden_states)
518
- return quant_states, indices
519
-
520
- def decode(self, hidden_states, deterministic: bool = True):
521
- hidden_states = self.post_quant_conv(hidden_states)
522
- hidden_states = self.decoder(hidden_states, deterministic=deterministic)
523
- return hidden_states
524
-
525
- def decode_code(self, code_b):
526
- hidden_states = self.quantize.get_codebook_entry(code_b)
527
- hidden_states = self.decode(hidden_states)
528
- return hidden_states
529
-
530
- def __call__(self, pixel_values, deterministic: bool = True):
531
- quant_states, indices = self.encode(pixel_values, deterministic)
532
- hidden_states = self.decode(quant_states, deterministic)
533
- return hidden_states, indices
534
-
535
-
536
- class VQGANPreTrainedModel(FlaxPreTrainedModel):
537
- """
538
- An abstract class to handle weights initialization and a simple interface
539
- for downloading and loading pretrained models.
540
- """
541
-
542
- config_class = VQGANConfig
543
- base_model_prefix = "model"
544
- module_class: nn.Module = None
545
-
546
- def __init__(
547
- self,
548
- config: VQGANConfig,
549
- input_shape: Tuple = (1, 256, 256, 3),
550
- seed: int = 0,
551
- dtype: jnp.dtype = jnp.float32,
552
- **kwargs,
553
- ):
554
- module = self.module_class(config=config, dtype=dtype, **kwargs)
555
- super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
556
-
557
- def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
558
- # init input tensors
559
- pixel_values = jnp.zeros(input_shape, dtype=jnp.float32)
560
- params_rng, dropout_rng = jax.random.split(rng)
561
- rngs = {"params": params_rng, "dropout": dropout_rng}
562
-
563
- return self.module.init(rngs, pixel_values)["params"]
564
-
565
- def encode(self, pixel_values, params: dict = None, dropout_rng: jax.random.PRNGKey = None, train: bool = False):
566
- # Handle any PRNG if needed
567
- rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
568
-
569
- return self.module.apply(
570
- {"params": params or self.params}, jnp.array(pixel_values), not train, rngs=rngs, method=self.module.encode
571
- )
572
-
573
- def decode(self, hidden_states, params: dict = None, dropout_rng: jax.random.PRNGKey = None, train: bool = False):
574
- # Handle any PRNG if needed
575
- rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
576
-
577
- return self.module.apply(
578
- {"params": params or self.params},
579
- jnp.array(hidden_states),
580
- not train,
581
- rngs=rngs,
582
- method=self.module.decode,
583
- )
584
-
585
- def decode_code(self, indices, params: dict = None):
586
- return self.module.apply(
587
- {"params": params or self.params}, jnp.array(indices, dtype="i4"), method=self.module.decode_code
588
- )
589
-
590
- def __call__(
591
- self,
592
- pixel_values,
593
- params: dict = None,
594
- dropout_rng: jax.random.PRNGKey = None,
595
- train: bool = False,
596
- ):
597
- # Handle any PRNG if needed
598
- rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
599
-
600
- return self.module.apply(
601
- {"params": params or self.params},
602
- jnp.array(pixel_values),
603
- not train,
604
- rngs=rngs,
605
- )
606
-
607
-
608
- class VQModel(VQGANPreTrainedModel):
609
- module_class = VQModule
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
data/CC12M_downloader.py DELETED
@@ -1,91 +0,0 @@
1
- # Luke Melas-Kyriazi's code. (https://twitter.com/lukemelas)
2
-
3
- #%%
4
- import sys
5
- import os
6
- from datetime import datetime
7
- import pandas as pd
8
- import contexttimer
9
- from urllib.request import urlopen
10
- import requests
11
- from PIL import Image
12
- import torch
13
- from torchvision.transforms import functional as TF
14
- from multiprocessing import Pool
15
- from tqdm import tqdm
16
- import logging
17
-
18
- # Setup
19
- logging.basicConfig(filename='download.log', filemode='w', level=logging.INFO)
20
- requests.packages.urllib3.disable_warnings(requests.packages.urllib3.exceptions.InsecureRequestWarning)
21
-
22
-
23
- # # For downloading SVG images (I can't get this to work)
24
- # from io import BytesIO
25
- # import cairosvg
26
-
27
- #%%
28
- # Load data
29
- print(f'Starting to load at {datetime.now().isoformat(timespec="minutes")}')
30
- with contexttimer.Timer(prefix="Loading from tsv"):
31
- df = pd.read_csv('./cc12m.tsv', delimiter='\t', header=None)
32
-
33
- url_to_idx_map = {url: index for index, url, caption in df.itertuples()}
34
- print(f'Loaded {len(url_to_idx_map)} urls')
35
-
36
- #%%
37
- df.head()
38
-
39
- #%%
40
-
41
- # Note: it seems that there are no SVG images
42
- df.sample(10000)[1].str.contains('.svg').sum()
43
-
44
- #%%
45
- # Resize function
46
- def resize(img):
47
- max_size_of_short_side = 512
48
- if min(img.size) > max_size_of_short_side:
49
- img = TF.resize(img, size=max_size_of_short_side, interpolation=Image.LANCZOS)
50
- return img
51
-
52
- base_dir = os.path.join(os.getcwd(), 'images')
53
-
54
- def process(item):
55
- url, image_id = item
56
- try:
57
- base_url = os.path.basename(url) # extract base url
58
- stem, ext = os.path.splitext(base_url) # split into stem and extension
59
- filename = f'{image_id:08d}---{stem}.jpg' # create filename
60
- filepath = os.path.join(base_dir, filename) # concat to get filepath
61
- if not os.path.isfile(filepath):
62
- # if filepath.endswith('.svg'):
63
- # raise NotImplementedError()
64
- # image_bytes = BytesIO() # create a bytestream
65
- # cairosvg.svg2png(url=url, write_to=image_bytes) # convert svg into image
66
- # else:
67
- req = requests.get(url, stream=True, timeout=1, verify=False).raw
68
- image = Image.open(req).convert('RGB')
69
- if min(image.size) > 512:
70
- image = TF.resize(image, size=512, interpolation=Image.LANCZOS)
71
- # image = resize(image) # resize PIL image
72
- image.save(filepath) # save PIL image
73
- except Exception as e:
74
- logging.info(" ".join(repr(e).splitlines()))
75
- logging.error(url)
76
-
77
- #%%
78
- #for i, item in enumerate(tqdm(url_to_idx_map.items(), total=len(url_to_idx_map))):
79
- # process(item)
80
- # if i > 100:
81
- # break
82
-
83
- # Use multiprocessing for speed
84
- list_of_items = list(url_to_idx_map.items())
85
- print(len(list_of_items))
86
- list_of_items = list_of_items[10_000_000:]
87
- print(len(list_of_items))
88
- with Pool(128) as p:
89
- r = list(tqdm(p.imap(process, list_of_items), total=len(list_of_items)))
90
- print('DONE')
91
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
data/CC3M_downloader.py DELETED
@@ -1,62 +0,0 @@
1
- '''
2
- This script was adapted from Luke Melas-Kyriazi's code. (https://twitter.com/lukemelas)
3
- Few changes were made for the particular dataset. You're required to have the `.tsv` file downloaded in your directory.
4
- Find them here- [https://github.com/google-research-datasets/conceptual-captions]
5
- '''
6
-
7
- import sys
8
- import os
9
- from datetime import datetime
10
- import pandas as pd
11
- import contexttimer
12
- from urllib.request import urlopen
13
- import requests
14
- from PIL import Image
15
- import torch
16
- from torchvision.transforms import functional as TF
17
- from multiprocessing import Pool
18
- from tqdm import tqdm
19
- import logging
20
- import sys
21
-
22
- # Setup
23
- logging.basicConfig(filename='download.log', filemode='w', level=logging.INFO)
24
- requests.packages.urllib3.disable_warnings(requests.packages.urllib3.exceptions.InsecureRequestWarning)
25
-
26
- if len(sys.argv) != 3:
27
- print("Provide .tsv file name & output directory. e.g. python downloader.py Train-GCC-training.tsv training")
28
- exit(1)
29
-
30
- # Load data
31
- print(f'Starting to load at {datetime.now().isoformat(timespec="minutes")}')
32
- with contexttimer.Timer(prefix="Loading from tsv"):
33
- df = pd.read_csv(sys.argv[1], delimiter='\t', header=None)
34
-
35
- url_to_idx_map = {url: index for index, caption, url in df.itertuples()}
36
- print(f'Loaded {len(url_to_idx_map)} urls')
37
-
38
- base_dir = os.path.join(os.getcwd(), sys.argv[2])
39
-
40
- def process(item):
41
- url, image_id = item
42
- try:
43
- base_url = os.path.basename(url) # extract base url
44
- stem, ext = os.path.splitext(base_url) # split into stem and extension
45
- filename = f'{image_id:08d}---{stem}.jpg' # create filename
46
- filepath = os.path.join(base_dir, filename) # concat to get filepath
47
- if not os.path.isfile(filepath):
48
- req = requests.get(url, stream=True, timeout=1, verify=False).raw
49
- image = Image.open(req).convert('RGB')
50
- if min(image.size) > 512:
51
- image = TF.resize(image, size=512, interpolation=Image.LANCZOS)
52
- image.save(filepath) # save PIL image
53
- except Exception as e:
54
- logging.info(" ".join(repr(e).splitlines()))
55
- logging.error(url)
56
-
57
- list_of_items = list(url_to_idx_map.items())
58
- print(len(list_of_items))
59
-
60
- with Pool(128) as p:
61
- r = list(tqdm(p.imap(process, list_of_items), total=len(list_of_items)))
62
- print('DONE')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
demo/CustomBARTv4b_model-generate.ipynb DELETED
@@ -1,566 +0,0 @@
1
- {
2
- "nbformat": 4,
3
- "nbformat_minor": 0,
4
- "metadata": {
5
- "colab": {
6
- "name": "CustomBARTv4b-model-generate.ipynb",
7
- "provenance": [],
8
- "collapsed_sections": [],
9
- "machine_shape": "hm"
10
- },
11
- "kernelspec": {
12
- "name": "python3",
13
- "display_name": "Python 3"
14
- },
15
- "language_info": {
16
- "name": "python"
17
- },
18
- "accelerator": "TPU"
19
- },
20
- "cells": [
21
- {
22
- "cell_type": "markdown",
23
- "metadata": {
24
- "id": "ewer-Q-0w2xA"
25
- },
26
- "source": [
27
- "# Installation"
28
- ]
29
- },
30
- {
31
- "cell_type": "code",
32
- "metadata": {
33
- "colab": {
34
- "base_uri": "https://localhost:8080/"
35
- },
36
- "id": "NpsF9ipLLl2s",
37
- "outputId": "10bf54aa-b89d-4e42-9777-bc97b00a5f32"
38
- },
39
- "source": [
40
- "!pip install git+https://github.com/huggingface/transformers/\n",
41
- "!pip install git+https://github.com/google/flax"
42
- ],
43
- "execution_count": 1,
44
- "outputs": [
45
- {
46
- "output_type": "stream",
47
- "text": [
48
- "Collecting git+https://github.com/huggingface/transformers/\n",
49
- " Cloning https://github.com/huggingface/transformers/ to /tmp/pip-req-build-oxejx1op\n",
50
- " Running command git clone -q https://github.com/huggingface/transformers/ /tmp/pip-req-build-oxejx1op\n",
51
- " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
52
- " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
53
- " Preparing wheel metadata ... \u001b[?25l\u001b[?25hdone\n",
54
- "Requirement already satisfied (use --upgrade to upgrade): transformers==4.9.0.dev0 from git+https://github.com/huggingface/transformers/ in /usr/local/lib/python3.7/dist-packages\n",
55
- "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.7/dist-packages (from transformers==4.9.0.dev0) (1.19.5)\n",
56
- "Requirement already satisfied: packaging in /usr/local/lib/python3.7/dist-packages (from transformers==4.9.0.dev0) (20.9)\n",
57
- "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.7/dist-packages (from transformers==4.9.0.dev0) (5.4.1)\n",
58
- "Requirement already satisfied: sacremoses in /usr/local/lib/python3.7/dist-packages (from transformers==4.9.0.dev0) (0.0.45)\n",
59
- "Requirement already satisfied: importlib-metadata; python_version < \"3.8\" in /usr/local/lib/python3.7/dist-packages (from transformers==4.9.0.dev0) (4.6.0)\n",
60
- "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.7/dist-packages (from transformers==4.9.0.dev0) (4.41.1)\n",
61
- "Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from transformers==4.9.0.dev0) (3.0.12)\n",
62
- "Requirement already satisfied: huggingface-hub==0.0.12 in /usr/local/lib/python3.7/dist-packages (from transformers==4.9.0.dev0) (0.0.12)\n",
63
- "Requirement already satisfied: tokenizers<0.11,>=0.10.1 in /usr/local/lib/python3.7/dist-packages (from transformers==4.9.0.dev0) (0.10.3)\n",
64
- "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.7/dist-packages (from transformers==4.9.0.dev0) (2019.12.20)\n",
65
- "Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from transformers==4.9.0.dev0) (2.23.0)\n",
66
- "Requirement already satisfied: pyparsing>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging->transformers==4.9.0.dev0) (2.4.7)\n",
67
- "Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers==4.9.0.dev0) (1.15.0)\n",
68
- "Requirement already satisfied: joblib in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers==4.9.0.dev0) (1.0.1)\n",
69
- "Requirement already satisfied: click in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers==4.9.0.dev0) (7.1.2)\n",
70
- "Requirement already satisfied: typing-extensions>=3.6.4; python_version < \"3.8\" in /usr/local/lib/python3.7/dist-packages (from importlib-metadata; python_version < \"3.8\"->transformers==4.9.0.dev0) (3.7.4.3)\n",
71
- "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata; python_version < \"3.8\"->transformers==4.9.0.dev0) (3.4.1)\n",
72
- "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->transformers==4.9.0.dev0) (2021.5.30)\n",
73
- "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->transformers==4.9.0.dev0) (3.0.4)\n",
74
- "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->transformers==4.9.0.dev0) (1.24.3)\n",
75
- "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->transformers==4.9.0.dev0) (2.10)\n",
76
- "Building wheels for collected packages: transformers\n",
77
- " Building wheel for transformers (PEP 517) ... \u001b[?25l\u001b[?25hdone\n",
78
- " Created wheel for transformers: filename=transformers-4.9.0.dev0-cp37-none-any.whl size=2582229 sha256=249c593273ccca3027c6427d2c6fd749a89f21d722d628d97eb438a2cf3185a8\n",
79
- " Stored in directory: /tmp/pip-ephem-wheel-cache-l2rqt1b7/wheels/61/69/33/974fccec4d0ab5feee9fe83bd93e680d269a805be9ede5ec60\n",
80
- "Successfully built transformers\n",
81
- "Collecting git+https://github.com/google/flax\n",
82
- " Cloning https://github.com/google/flax to /tmp/pip-req-build-rt9g1_wx\n",
83
- " Running command git clone -q https://github.com/google/flax /tmp/pip-req-build-rt9g1_wx\n",
84
- "Requirement already satisfied (use --upgrade to upgrade): flax==0.3.4 from git+https://github.com/google/flax in /usr/local/lib/python3.7/dist-packages\n",
85
- "Requirement already satisfied: numpy>=1.12 in /usr/local/lib/python3.7/dist-packages (from flax==0.3.4) (1.19.5)\n",
86
- "Requirement already satisfied: jax>=0.2.13 in /usr/local/lib/python3.7/dist-packages (from flax==0.3.4) (0.2.13)\n",
87
- "Requirement already satisfied: matplotlib in /usr/local/lib/python3.7/dist-packages (from flax==0.3.4) (3.2.2)\n",
88
- "Requirement already satisfied: msgpack in /usr/local/lib/python3.7/dist-packages (from flax==0.3.4) (1.0.2)\n",
89
- "Requirement already satisfied: optax in /usr/local/lib/python3.7/dist-packages (from flax==0.3.4) (0.0.9)\n",
90
- "Requirement already satisfied: opt-einsum in /usr/local/lib/python3.7/dist-packages (from jax>=0.2.13->flax==0.3.4) (3.3.0)\n",
91
- "Requirement already satisfied: absl-py in /usr/local/lib/python3.7/dist-packages (from jax>=0.2.13->flax==0.3.4) (0.12.0)\n",
92
- "Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->flax==0.3.4) (2.8.1)\n",
93
- "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.7/dist-packages (from matplotlib->flax==0.3.4) (0.10.0)\n",
94
- "Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->flax==0.3.4) (2.4.7)\n",
95
- "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->flax==0.3.4) (1.3.1)\n",
96
- "Requirement already satisfied: chex>=0.0.4 in /usr/local/lib/python3.7/dist-packages (from optax->flax==0.3.4) (0.0.8)\n",
97
- "Requirement already satisfied: jaxlib>=0.1.37 in /usr/local/lib/python3.7/dist-packages (from optax->flax==0.3.4) (0.1.66+cuda110)\n",
98
- "Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from absl-py->jax>=0.2.13->flax==0.3.4) (1.15.0)\n",
99
- "Requirement already satisfied: dm-tree>=0.1.5 in /usr/local/lib/python3.7/dist-packages (from chex>=0.0.4->optax->flax==0.3.4) (0.1.6)\n",
100
- "Requirement already satisfied: toolz>=0.9.0 in /usr/local/lib/python3.7/dist-packages (from chex>=0.0.4->optax->flax==0.3.4) (0.11.1)\n",
101
- "Requirement already satisfied: flatbuffers in /usr/local/lib/python3.7/dist-packages (from jaxlib>=0.1.37->optax->flax==0.3.4) (1.12)\n",
102
- "Requirement already satisfied: scipy in /usr/local/lib/python3.7/dist-packages (from jaxlib>=0.1.37->optax->flax==0.3.4) (1.4.1)\n",
103
- "Building wheels for collected packages: flax\n",
104
- " Building wheel for flax (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
105
- " Created wheel for flax: filename=flax-0.3.4-cp37-none-any.whl size=184692 sha256=503b27995f372afe33631e71572d5edc1fffd4d2e0a4cd206d291ad6b0e4c299\n",
106
- " Stored in directory: /tmp/pip-ephem-wheel-cache-g1pzxnv6/wheels/3d/26/f4/0ea6051d7352289d9e4f8178348452b35a9a97bde6035405a5\n",
107
- "Successfully built flax\n"
108
- ],
109
- "name": "stdout"
110
- }
111
- ]
112
- },
113
- {
114
- "cell_type": "code",
115
- "metadata": {
116
- "id": "M1wVkrpjU6zO"
117
- },
118
- "source": [
119
- "%load_ext autoreload\n",
120
- "%autoreload 2"
121
- ],
122
- "execution_count": 2,
123
- "outputs": []
124
- },
125
- {
126
- "cell_type": "markdown",
127
- "metadata": {
128
- "id": "t47CH1H_IOT8"
129
- },
130
- "source": [
131
- "# Custom BART Model"
132
- ]
133
- },
134
- {
135
- "cell_type": "code",
136
- "metadata": {
137
- "id": "9jQnM6S2vCpn"
138
- },
139
- "source": [
140
- "# TODO: set those args in a config file\n",
141
- "OUTPUT_VOCAB_SIZE = 16384 + 1 # encoded image token space + 1 for bos\n",
142
- "OUTPUT_LENGTH = 256 + 1 # number of encoded tokens + 1 for bos\n",
143
- "BOS_TOKEN_ID = 16384\n",
144
- "BASE_MODEL = 'facebook/bart-large'"
145
- ],
146
- "execution_count": 3,
147
- "outputs": []
148
- },
149
- {
150
- "cell_type": "code",
151
- "metadata": {
152
- "id": "_eEaJVxAKpV5"
153
- },
154
- "source": [
155
- "import jax\n",
156
- "import flax.linen as nn\n",
157
- "\n",
158
- "from transformers.models.bart.modeling_flax_bart import *\n",
159
- "from transformers import BartTokenizer, FlaxBartForConditionalGeneration\n",
160
- "\n",
161
- "class CustomFlaxBartModule(FlaxBartModule):\n",
162
- " def setup(self):\n",
163
- " # we keep shared to easily load pre-trained weights\n",
164
- " self.shared = nn.Embed(\n",
165
- " self.config.vocab_size,\n",
166
- " self.config.d_model,\n",
167
- " embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),\n",
168
- " dtype=self.dtype,\n",
169
- " )\n",
170
- " # a separate embedding is used for the decoder\n",
171
- " self.decoder_embed = nn.Embed(\n",
172
- " OUTPUT_VOCAB_SIZE,\n",
173
- " self.config.d_model,\n",
174
- " embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),\n",
175
- " dtype=self.dtype,\n",
176
- " )\n",
177
- " self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)\n",
178
- "\n",
179
- " # the decoder has a different config\n",
180
- " decoder_config = BartConfig(self.config.to_dict())\n",
181
- " decoder_config.max_position_embeddings = OUTPUT_LENGTH\n",
182
- " decoder_config.vocab_size = OUTPUT_VOCAB_SIZE\n",
183
- " self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)\n",
184
- "\n",
185
- "class CustomFlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):\n",
186
- " def setup(self):\n",
187
- " self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)\n",
188
- " self.lm_head = nn.Dense(\n",
189
- " OUTPUT_VOCAB_SIZE,\n",
190
- " use_bias=False,\n",
191
- " dtype=self.dtype,\n",
192
- " kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),\n",
193
- " )\n",
194
- " self.final_logits_bias = self.param(\"final_logits_bias\", self.bias_init, (1, OUTPUT_VOCAB_SIZE))\n",
195
- "\n",
196
- "class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):\n",
197
- " module_class = CustomFlaxBartForConditionalGenerationModule"
198
- ],
199
- "execution_count": 4,
200
- "outputs": []
201
- },
202
- {
203
- "cell_type": "code",
204
- "metadata": {
205
- "id": "S7CP9Td9m2ge",
206
- "colab": {
207
- "base_uri": "https://localhost:8080/"
208
- },
209
- "outputId": "5638ef68-9c40-46f7-90ba-a4d05b61360d"
210
- },
211
- "source": [
212
- "# load pre-trained model for encoder weights\n",
213
- "base_model = FlaxBartForConditionalGeneration.from_pretrained(BASE_MODEL)"
214
- ],
215
- "execution_count": 5,
216
- "outputs": [
217
- {
218
- "output_type": "stream",
219
- "text": [
220
- "WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
221
- ],
222
- "name": "stderr"
223
- }
224
- ]
225
- },
226
- {
227
- "cell_type": "code",
228
- "metadata": {
229
- "id": "6lmynR-poceH"
230
- },
231
- "source": [
232
- "# set up our new model config\n",
233
- "config = BartConfig.from_pretrained(BASE_MODEL)\n",
234
- "config.tie_word_embeddings = False\n",
235
- "config.decoder_start_token_id = BOS_TOKEN_ID\n",
236
- "config.bos_token_id = BOS_TOKEN_ID # should not be used\n",
237
- "config.pos_token_id = BOS_TOKEN_ID # should not be used\n",
238
- "#config.eos_token_id = None # prevents generation from stopping until we reach max_length"
239
- ],
240
- "execution_count": 6,
241
- "outputs": []
242
- },
243
- {
244
- "cell_type": "code",
245
- "metadata": {
246
- "id": "_6-XKK40oEfP"
247
- },
248
- "source": [
249
- "# create our model and initialize it randomly\n",
250
- "model = CustomFlaxBartForConditionalGeneration(config)"
251
- ],
252
- "execution_count": 7,
253
- "outputs": []
254
- },
255
- {
256
- "cell_type": "code",
257
- "metadata": {
258
- "id": "-r_hZestr-NR"
259
- },
260
- "source": [
261
- "# use pretrained weights\n",
262
- "model.params['model']['encoder'] = base_model.params['model']['encoder']\n",
263
- "model.params['model']['shared'] = base_model.params['model']['shared']"
264
- ],
265
- "execution_count": 8,
266
- "outputs": []
267
- },
268
- {
269
- "cell_type": "code",
270
- "metadata": {
271
- "id": "5NEX8f62sVjx"
272
- },
273
- "source": [
274
- "# no need for base_model anymore\n",
275
- "del base_model"
276
- ],
277
- "execution_count": 9,
278
- "outputs": []
279
- },
280
- {
281
- "cell_type": "code",
282
- "metadata": {
283
- "colab": {
284
- "base_uri": "https://localhost:8080/"
285
- },
286
- "id": "Jz032w73nHEf",
287
- "outputId": "994d8e85-bff7-480b-8b69-f69dedc15c49"
288
- },
289
- "source": [
290
- "# we verify that the shape has not been modified\n",
291
- "model.params['final_logits_bias'].shape"
292
- ],
293
- "execution_count": 10,
294
- "outputs": [
295
- {
296
- "output_type": "execute_result",
297
- "data": {
298
- "text/plain": [
299
- "(1, 16385)"
300
- ]
301
- },
302
- "metadata": {
303
- "tags": []
304
- },
305
- "execution_count": 10
306
- }
307
- ]
308
- },
309
- {
310
- "cell_type": "markdown",
311
- "metadata": {
312
- "id": "zLl24Ez5t7x1"
313
- },
314
- "source": [
315
- "## Inference"
316
- ]
317
- },
318
- {
319
- "cell_type": "code",
320
- "metadata": {
321
- "id": "XLLA2NK3uDQr"
322
- },
323
- "source": [
324
- "tokenizer = BartTokenizer.from_pretrained(BASE_MODEL)"
325
- ],
326
- "execution_count": 11,
327
- "outputs": []
328
- },
329
- {
330
- "cell_type": "code",
331
- "metadata": {
332
- "colab": {
333
- "base_uri": "https://localhost:8080/"
334
- },
335
- "id": "Ntow53I_t81D",
336
- "outputId": "59289cdd-1429-4720-cc87-88810c4b99ac"
337
- },
338
- "source": [
339
- "text = \"My friends are cool but they eat too many carbs.\"\n",
340
- "inputs = tokenizer(text, max_length=1024, return_tensors='jax')\n",
341
- "encoder_outputs = model.encode(**inputs)"
342
- ],
343
- "execution_count": 12,
344
- "outputs": [
345
- {
346
- "output_type": "stream",
347
- "text": [
348
- "Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.\n"
349
- ],
350
- "name": "stderr"
351
- }
352
- ]
353
- },
354
- {
355
- "cell_type": "code",
356
- "metadata": {
357
- "colab": {
358
- "base_uri": "https://localhost:8080/"
359
- },
360
- "id": "vcRNJnJ_uJOJ",
361
- "outputId": "025afd54-7908-4a9c-fb59-e40bd3458711"
362
- },
363
- "source": [
364
- "decoder_start_token_id = model.config.decoder_start_token_id\n",
365
- "decoder_start_token_id"
366
- ],
367
- "execution_count": 13,
368
- "outputs": [
369
- {
370
- "output_type": "execute_result",
371
- "data": {
372
- "text/plain": [
373
- "16384"
374
- ]
375
- },
376
- "metadata": {
377
- "tags": []
378
- },
379
- "execution_count": 13
380
- }
381
- ]
382
- },
383
- {
384
- "cell_type": "code",
385
- "metadata": {
386
- "id": "6QWmEwL_uMld"
387
- },
388
- "source": [
389
- "decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype=\"i4\") * decoder_start_token_id\n",
390
- "outputs = model.decode(decoder_input_ids, encoder_outputs)"
391
- ],
392
- "execution_count": 14,
393
- "outputs": []
394
- },
395
- {
396
- "cell_type": "code",
397
- "metadata": {
398
- "colab": {
399
- "base_uri": "https://localhost:8080/"
400
- },
401
- "id": "c_ys3yWBothF",
402
- "outputId": "40d4d584-e0a8-44cb-bbea-0ffa38d50a53"
403
- },
404
- "source": [
405
- "outputs"
406
- ],
407
- "execution_count": 15,
408
- "outputs": [
409
- {
410
- "output_type": "execute_result",
411
- "data": {
412
- "text/plain": [
413
- "FlaxCausalLMOutputWithCrossAttentions([('logits',\n",
414
- " DeviceArray([[[ 0.5263986 , -2.0947676 , -0.18830685, ..., 0.7599884 ,\n",
415
- " 0.6746795 , -1.0411576 ]]], dtype=float32))])"
416
- ]
417
- },
418
- "metadata": {
419
- "tags": []
420
- },
421
- "execution_count": 15
422
- }
423
- ]
424
- },
425
- {
426
- "cell_type": "code",
427
- "metadata": {
428
- "colab": {
429
- "base_uri": "https://localhost:8080/"
430
- },
431
- "id": "O6s0wtB_uTC_",
432
- "outputId": "bc0e9e80-e346-4e99-d28e-3f658eda1f66"
433
- },
434
- "source": [
435
- "outputs.logits.shape"
436
- ],
437
- "execution_count": 16,
438
- "outputs": [
439
- {
440
- "output_type": "execute_result",
441
- "data": {
442
- "text/plain": [
443
- "(1, 1, 16385)"
444
- ]
445
- },
446
- "metadata": {
447
- "tags": []
448
- },
449
- "execution_count": 16
450
- }
451
- ]
452
- },
453
- {
454
- "cell_type": "code",
455
- "metadata": {
456
- "colab": {
457
- "base_uri": "https://localhost:8080/"
458
- },
459
- "id": "ELzemGP3uBzy",
460
- "outputId": "dc12f98a-1ccf-450d-ba2a-9c29d7d14885"
461
- },
462
- "source": [
463
- "outputs.logits.argmax(axis=-1)"
464
- ],
465
- "execution_count": 17,
466
- "outputs": [
467
- {
468
- "output_type": "execute_result",
469
- "data": {
470
- "text/plain": [
471
- "DeviceArray([[12459]], dtype=int32)"
472
- ]
473
- },
474
- "metadata": {
475
- "tags": []
476
- },
477
- "execution_count": 17
478
- }
479
- ]
480
- },
481
- {
482
- "cell_type": "code",
483
- "metadata": {
484
- "colab": {
485
- "base_uri": "https://localhost:8080/"
486
- },
487
- "id": "fQjikkGEunpx",
488
- "outputId": "3dba0209-ad4e-4069-be38-6c599c677ef1"
489
- },
490
- "source": [
491
- "model.config.bos_token_id, model.config.eos_token_id, model.config.pad_token_id"
492
- ],
493
- "execution_count": 18,
494
- "outputs": [
495
- {
496
- "output_type": "execute_result",
497
- "data": {
498
- "text/plain": [
499
- "(16384, 2, 1)"
500
- ]
501
- },
502
- "metadata": {
503
- "tags": []
504
- },
505
- "execution_count": 18
506
- }
507
- ]
508
- },
509
- {
510
- "cell_type": "code",
511
- "metadata": {
512
- "id": "P32mJJSbrU1F"
513
- },
514
- "source": [
515
- "input_ids_test = tokenizer.encode('I enjoy walking with my cute dog', return_tensors='jax')"
516
- ],
517
- "execution_count": 19,
518
- "outputs": []
519
- },
520
- {
521
- "cell_type": "code",
522
- "metadata": {
523
- "id": "C7cHbIHruELT"
524
- },
525
- "source": [
526
- "greedy_output = model.generate(input_ids_test, max_length=50)"
527
- ],
528
- "execution_count": 20,
529
- "outputs": []
530
- },
531
- {
532
- "cell_type": "code",
533
- "metadata": {
534
- "colab": {
535
- "base_uri": "https://localhost:8080/"
536
- },
537
- "id": "jYugh9cOuwc9",
538
- "outputId": "19c3a2ee-e7bc-4f1f-9c86-06bd7337b537"
539
- },
540
- "source": [
541
- "greedy_output[0]"
542
- ],
543
- "execution_count": 21,
544
- "outputs": [
545
- {
546
- "output_type": "execute_result",
547
- "data": {
548
- "text/plain": [
549
- "DeviceArray([[16384, 0, 3570, 13405, 10186, 2392, 16362, 1869,\n",
550
- " 15772, 13546, 15772, 13546, 9348, 14791, 15772, 15772,\n",
551
- " 15772, 11272, 15772, 13546, 15772, 15772, 13546, 15772,\n",
552
- " 13546, 15772, 6642, 15772, 10776, 6431, 15772, 14567,\n",
553
- " 13406, 15772, 14567, 6235, 15772, 4909, 16160, 568,\n",
554
- " 4664, 6650, 8952, 9089, 15772, 5952, 7375, 10843,\n",
555
- " 8952, 2]], dtype=int32)"
556
- ]
557
- },
558
- "metadata": {
559
- "tags": []
560
- },
561
- "execution_count": 21
562
- }
563
- ]
564
- }
565
- ]
566
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
demo/demo_notebook.ipynb DELETED
@@ -1,583 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "markdown",
5
- "metadata": {
6
- "id": "ewer-Q-0w2xA"
7
- },
8
- "source": [
9
- "# Installation"
10
- ]
11
- },
12
- {
13
- "cell_type": "code",
14
- "execution_count": null,
15
- "metadata": {
16
- "colab": {
17
- "base_uri": "https://localhost:8080/"
18
- },
19
- "id": "NpsF9ipLLl2s",
20
- "outputId": "10bf54aa-b89d-4e42-9777-bc97b00a5f32"
21
- },
22
- "outputs": [],
23
- "source": [
24
- "#!pip install git+https://github.com/huggingface/transformers/\n",
25
- "#!pip install git+https://github.com/google/flax"
26
- ]
27
- },
28
- {
29
- "cell_type": "code",
30
- "execution_count": 1,
31
- "metadata": {
32
- "id": "M1wVkrpjU6zO"
33
- },
34
- "outputs": [],
35
- "source": [
36
- "%load_ext autoreload\n",
37
- "%autoreload 2"
38
- ]
39
- },
40
- {
41
- "cell_type": "code",
42
- "execution_count": 2,
43
- "metadata": {},
44
- "outputs": [
45
- {
46
- "name": "stdout",
47
- "output_type": "stream",
48
- "text": [
49
- "/home/tmabraham/vqgan-jax\n"
50
- ]
51
- }
52
- ],
53
- "source": [
54
- "%cd ../../vqgan-jax"
55
- ]
56
- },
57
- {
58
- "cell_type": "markdown",
59
- "metadata": {
60
- "id": "t47CH1H_IOT8"
61
- },
62
- "source": [
63
- "# Custom BART Model"
64
- ]
65
- },
66
- {
67
- "cell_type": "code",
68
- "execution_count": 3,
69
- "metadata": {
70
- "id": "9jQnM6S2vCpn"
71
- },
72
- "outputs": [],
73
- "source": [
74
- "# TODO: set those args in a config file\n",
75
- "OUTPUT_VOCAB_SIZE = 16384 + 1 # encoded image token space + 1 for bos\n",
76
- "OUTPUT_LENGTH = 256 + 1 # number of encoded tokens + 1 for bos\n",
77
- "BOS_TOKEN_ID = 16384\n",
78
- "BASE_MODEL = 'facebook/bart-large'"
79
- ]
80
- },
81
- {
82
- "cell_type": "code",
83
- "execution_count": 4,
84
- "metadata": {
85
- "id": "_eEaJVxAKpV5"
86
- },
87
- "outputs": [],
88
- "source": [
89
- "import jax\n",
90
- "import flax.linen as nn\n",
91
- "\n",
92
- "from transformers.models.bart.modeling_flax_bart import *\n",
93
- "from transformers import BartTokenizer, FlaxBartForConditionalGeneration\n",
94
- "\n",
95
- "class CustomFlaxBartModule(FlaxBartModule):\n",
96
- " def setup(self):\n",
97
- " # we keep shared to easily load pre-trained weights\n",
98
- " self.shared = nn.Embed(\n",
99
- " self.config.vocab_size,\n",
100
- " self.config.d_model,\n",
101
- " embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),\n",
102
- " dtype=self.dtype,\n",
103
- " )\n",
104
- " # a separate embedding is used for the decoder\n",
105
- " self.decoder_embed = nn.Embed(\n",
106
- " OUTPUT_VOCAB_SIZE,\n",
107
- " self.config.d_model,\n",
108
- " embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),\n",
109
- " dtype=self.dtype,\n",
110
- " )\n",
111
- " self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)\n",
112
- "\n",
113
- " # the decoder has a different config\n",
114
- " decoder_config = BartConfig(self.config.to_dict())\n",
115
- " decoder_config.max_position_embeddings = OUTPUT_LENGTH\n",
116
- " decoder_config.vocab_size = OUTPUT_VOCAB_SIZE\n",
117
- " self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)\n",
118
- "\n",
119
- "class CustomFlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):\n",
120
- " def setup(self):\n",
121
- " self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)\n",
122
- " self.lm_head = nn.Dense(\n",
123
- " OUTPUT_VOCAB_SIZE,\n",
124
- " use_bias=False,\n",
125
- " dtype=self.dtype,\n",
126
- " kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),\n",
127
- " )\n",
128
- " self.final_logits_bias = self.param(\"final_logits_bias\", self.bias_init, (1, OUTPUT_VOCAB_SIZE))\n",
129
- "\n",
130
- "class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):\n",
131
- " module_class = CustomFlaxBartForConditionalGenerationModule"
132
- ]
133
- },
134
- {
135
- "cell_type": "code",
136
- "execution_count": 5,
137
- "metadata": {
138
- "scrolled": true
139
- },
140
- "outputs": [
141
- {
142
- "name": "stderr",
143
- "output_type": "stream",
144
- "text": [
145
- "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mtmabraham\u001b[0m (use `wandb login --relogin` to force relogin)\n"
146
- ]
147
- },
148
- {
149
- "data": {
150
- "text/html": [
151
- "\n",
152
- " Tracking run with wandb version 0.10.33<br/>\n",
153
- " Syncing run <strong style=\"color:#cdcd00\">rare-night-7</strong> to <a href=\"https://wandb.ai\" target=\"_blank\">Weights & Biases</a> <a href=\"https://docs.wandb.com/integrations/jupyter.html\" target=\"_blank\">(Documentation)</a>.<br/>\n",
154
- " Project page: <a href=\"https://wandb.ai/tmabraham/vqgan-jax\" target=\"_blank\">https://wandb.ai/tmabraham/vqgan-jax</a><br/>\n",
155
- " Run page: <a href=\"https://wandb.ai/tmabraham/vqgan-jax/runs/qzxavce8\" target=\"_blank\">https://wandb.ai/tmabraham/vqgan-jax/runs/qzxavce8</a><br/>\n",
156
- " Run data is saved locally in <code>/home/tmabraham/vqgan-jax/wandb/run-20210715_075019-qzxavce8</code><br/><br/>\n",
157
- " "
158
- ],
159
- "text/plain": [
160
- "<IPython.core.display.HTML object>"
161
- ]
162
- },
163
- "metadata": {},
164
- "output_type": "display_data"
165
- },
166
- {
167
- "name": "stderr",
168
- "output_type": "stream",
169
- "text": [
170
- "\u001b[34m\u001b[1mwandb\u001b[0m: Downloading large artifact model-1ef8yxby:latest, 1674.97MB. 2 files... Done. 0:0:0\n"
171
- ]
172
- }
173
- ],
174
- "source": [
175
- "import wandb\n",
176
- "run = wandb.init()\n",
177
- "artifact = run.use_artifact('wandb/hf-flax-dalle-mini/model-1ef8yxby:latest', type='bart_model')\n",
178
- "artifact_dir = artifact.download()"
179
- ]
180
- },
181
- {
182
- "cell_type": "code",
183
- "execution_count": 6,
184
- "metadata": {
185
- "id": "_6-XKK40oEfP",
186
- "scrolled": true
187
- },
188
- "outputs": [
189
- {
190
- "name": "stderr",
191
- "output_type": "stream",
192
- "text": [
193
- "/home/tmabraham/dalle-mini/src/transformers/src/transformers/models/bart/configuration_bart.py:180: UserWarning: Please make sure the config includes `forced_bos_token_id=16384` in future versions.The config can simply be saved and uploaded again to be fixed.\n",
194
- " warnings.warn(\n",
195
- "INFO:absl:Starting the local TPU driver.\n",
196
- "INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://\n",
197
- "INFO:absl:Unable to initialize backend 'gpu': Not found: Could not find registered platform with name: \"cuda\". Available platform names are: TPU Interpreter Host\n"
198
- ]
199
- }
200
- ],
201
- "source": [
202
- "# create our model and initialize it randomly\n",
203
- "model = CustomFlaxBartForConditionalGeneration.from_pretrained(artifact_dir)"
204
- ]
205
- },
206
- {
207
- "cell_type": "code",
208
- "execution_count": 7,
209
- "metadata": {},
210
- "outputs": [],
211
- "source": [
212
- "model.config.forced_bos_token_id = None"
213
- ]
214
- },
215
- {
216
- "cell_type": "code",
217
- "execution_count": 8,
218
- "metadata": {
219
- "colab": {
220
- "base_uri": "https://localhost:8080/"
221
- },
222
- "id": "Jz032w73nHEf",
223
- "outputId": "994d8e85-bff7-480b-8b69-f69dedc15c49"
224
- },
225
- "outputs": [
226
- {
227
- "data": {
228
- "text/plain": [
229
- "(1, 16385)"
230
- ]
231
- },
232
- "execution_count": 8,
233
- "metadata": {},
234
- "output_type": "execute_result"
235
- }
236
- ],
237
- "source": [
238
- "# we verify that the shape has not been modified\n",
239
- "model.params['final_logits_bias'].shape"
240
- ]
241
- },
242
- {
243
- "cell_type": "markdown",
244
- "metadata": {
245
- "id": "zLl24Ez5t7x1"
246
- },
247
- "source": [
248
- "## Inference"
249
- ]
250
- },
251
- {
252
- "cell_type": "code",
253
- "execution_count": 9,
254
- "metadata": {
255
- "id": "XLLA2NK3uDQr"
256
- },
257
- "outputs": [],
258
- "source": [
259
- "tokenizer = BartTokenizer.from_pretrained(BASE_MODEL)"
260
- ]
261
- },
262
- {
263
- "cell_type": "code",
264
- "execution_count": 10,
265
- "metadata": {},
266
- "outputs": [],
267
- "source": [
268
- "input_text = ['I enjoy walking with my cute dog']*8"
269
- ]
270
- },
271
- {
272
- "cell_type": "code",
273
- "execution_count": 11,
274
- "metadata": {
275
- "id": "P32mJJSbrU1F"
276
- },
277
- "outputs": [],
278
- "source": [
279
- "input_ids_test = tokenizer(input_text, return_tensors='jax')"
280
- ]
281
- },
282
- {
283
- "cell_type": "code",
284
- "execution_count": 12,
285
- "metadata": {},
286
- "outputs": [
287
- {
288
- "data": {
289
- "text/plain": [
290
- "{'input_ids': DeviceArray([[ 0, 100, 2254, 3051, 19, 127, 11962, 2335,\n",
291
- " 2],\n",
292
- " [ 0, 100, 2254, 3051, 19, 127, 11962, 2335,\n",
293
- " 2],\n",
294
- " [ 0, 100, 2254, 3051, 19, 127, 11962, 2335,\n",
295
- " 2],\n",
296
- " [ 0, 100, 2254, 3051, 19, 127, 11962, 2335,\n",
297
- " 2],\n",
298
- " [ 0, 100, 2254, 3051, 19, 127, 11962, 2335,\n",
299
- " 2],\n",
300
- " [ 0, 100, 2254, 3051, 19, 127, 11962, 2335,\n",
301
- " 2],\n",
302
- " [ 0, 100, 2254, 3051, 19, 127, 11962, 2335,\n",
303
- " 2],\n",
304
- " [ 0, 100, 2254, 3051, 19, 127, 11962, 2335,\n",
305
- " 2]], dtype=int32), 'attention_mask': DeviceArray([[1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
306
- " [1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
307
- " [1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
308
- " [1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
309
- " [1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
310
- " [1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
311
- " [1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
312
- " [1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=int32)}"
313
- ]
314
- },
315
- "execution_count": 12,
316
- "metadata": {},
317
- "output_type": "execute_result"
318
- }
319
- ],
320
- "source": [
321
- "input_ids_test"
322
- ]
323
- },
324
- {
325
- "cell_type": "code",
326
- "execution_count": 13,
327
- "metadata": {
328
- "id": "C7cHbIHruELT"
329
- },
330
- "outputs": [],
331
- "source": [
332
- "greedy_output = model.generate(input_ids_test['input_ids'], max_length=257)"
333
- ]
334
- },
335
- {
336
- "cell_type": "code",
337
- "execution_count": 14,
338
- "metadata": {},
339
- "outputs": [
340
- {
341
- "data": {
342
- "text/plain": [
343
- "(8, 257)"
344
- ]
345
- },
346
- "execution_count": 14,
347
- "metadata": {},
348
- "output_type": "execute_result"
349
- }
350
- ],
351
- "source": [
352
- "greedy_output[0].shape"
353
- ]
354
- },
355
- {
356
- "cell_type": "code",
357
- "execution_count": 15,
358
- "metadata": {
359
- "colab": {
360
- "base_uri": "https://localhost:8080/"
361
- },
362
- "id": "jYugh9cOuwc9",
363
- "outputId": "19c3a2ee-e7bc-4f1f-9c86-06bd7337b537"
364
- },
365
- "outputs": [
366
- {
367
- "data": {
368
- "text/plain": [
369
- "DeviceArray([[16384, 10042, 10042, ..., 10042, 10042, 9570],\n",
370
- " [16384, 10042, 10042, ..., 10042, 10042, 9570],\n",
371
- " [16384, 10042, 10042, ..., 10042, 10042, 9570],\n",
372
- " ...,\n",
373
- " [16384, 10042, 10042, ..., 10042, 10042, 9570],\n",
374
- " [16384, 10042, 10042, ..., 10042, 10042, 9570],\n",
375
- " [16384, 10042, 10042, ..., 10042, 10042, 9570]], dtype=int32)"
376
- ]
377
- },
378
- "execution_count": 15,
379
- "metadata": {},
380
- "output_type": "execute_result"
381
- }
382
- ],
383
- "source": [
384
- "greedy_output[0]"
385
- ]
386
- },
387
- {
388
- "cell_type": "code",
389
- "execution_count": 16,
390
- "metadata": {},
391
- "outputs": [
392
- {
393
- "data": {
394
- "text/plain": [
395
- "DeviceArray([16384, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
396
- " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
397
- " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
398
- " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
399
- " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
400
- " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
401
- " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
402
- " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
403
- " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
404
- " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
405
- " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
406
- " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
407
- " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
408
- " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
409
- " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
410
- " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
411
- " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
412
- " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
413
- " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
414
- " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
415
- " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
416
- " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
417
- " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
418
- " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
419
- " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
420
- " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
421
- " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
422
- " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
423
- " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
424
- " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
425
- " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
426
- " 10042, 10042, 10042, 10042, 10042, 10042, 10042, 10042,\n",
427
- " 9570], dtype=int32)"
428
- ]
429
- },
430
- "execution_count": 16,
431
- "metadata": {},
432
- "output_type": "execute_result"
433
- }
434
- ],
435
- "source": [
436
- "greedy_output[0][0]"
437
- ]
438
- },
439
- {
440
- "cell_type": "markdown",
441
- "metadata": {},
442
- "source": [
443
- "# VGAN Jax"
444
- ]
445
- },
446
- {
447
- "cell_type": "code",
448
- "execution_count": 17,
449
- "metadata": {},
450
- "outputs": [],
451
- "source": [
452
- "import io\n",
453
- "\n",
454
- "import requests\n",
455
- "from PIL import Image\n",
456
- "import numpy as np\n",
457
- "\n",
458
- "import torch\n",
459
- "import torchvision.transforms as T\n",
460
- "import torchvision.transforms.functional as TF\n",
461
- "from torchvision.transforms import InterpolationMode"
462
- ]
463
- },
464
- {
465
- "cell_type": "code",
466
- "execution_count": 18,
467
- "metadata": {},
468
- "outputs": [],
469
- "source": [
470
- "from modeling_flax_vqgan import VQModel"
471
- ]
472
- },
473
- {
474
- "cell_type": "code",
475
- "execution_count": 19,
476
- "metadata": {},
477
- "outputs": [],
478
- "source": [
479
- "def custom_to_pil(x):\n",
480
- " x = np.clip(x, 0., 1.)\n",
481
- " x = (255*x).astype(np.uint8)\n",
482
- " x = Image.fromarray(x)\n",
483
- " if not x.mode == \"RGB\":\n",
484
- " x = x.convert(\"RGB\")\n",
485
- " return x"
486
- ]
487
- },
488
- {
489
- "cell_type": "code",
490
- "execution_count": 20,
491
- "metadata": {
492
- "colab": {
493
- "base_uri": "https://localhost:8080/"
494
- },
495
- "id": "Jz032w73nHEf",
496
- "outputId": "994d8e85-bff7-480b-8b69-f69dedc15c49",
497
- "scrolled": true
498
- },
499
- "outputs": [
500
- {
501
- "name": "stdout",
502
- "output_type": "stream",
503
- "text": [
504
- "Working with z of shape (1, 256, 16, 16) = 65536 dimensions.\n"
505
- ]
506
- }
507
- ],
508
- "source": [
509
- "model = VQModel.from_pretrained(\"flax-community/vqgan_f16_16384\")"
510
- ]
511
- },
512
- {
513
- "cell_type": "code",
514
- "execution_count": 21,
515
- "metadata": {},
516
- "outputs": [],
517
- "source": [
518
- "def get_images(indices, model):\n",
519
- " indices = indices[:, 1:]\n",
520
- " print(indices.shape)\n",
521
- " img = model.decode_code(indices)\n",
522
- " return img"
523
- ]
524
- },
525
- {
526
- "cell_type": "code",
527
- "execution_count": 22,
528
- "metadata": {},
529
- "outputs": [
530
- {
531
- "name": "stdout",
532
- "output_type": "stream",
533
- "text": [
534
- "(1, 256)\n",
535
- "Working with z of shape (1, 256, 16, 16) = 65536 dimensions.\n"
536
- ]
537
- },
538
- {
539
- "data": {
540
- "image/png": "\n",
541
- "text/plain": [
542
- "<PIL.Image.Image image mode=RGB size=256x256 at 0x7FA20677A400>"
543
- ]
544
- },
545
- "execution_count": 22,
546
- "metadata": {},
547
- "output_type": "execute_result"
548
- }
549
- ],
550
- "source": [
551
- "custom_to_pil(np.asarray(get_images(jnp.expand_dims(greedy_output[0][0],0), model)[0]))"
552
- ]
553
- }
554
- ],
555
- "metadata": {
556
- "accelerator": "TPU",
557
- "colab": {
558
- "collapsed_sections": [],
559
- "machine_shape": "hm",
560
- "name": "CustomBARTv4b-model-generate.ipynb",
561
- "provenance": []
562
- },
563
- "kernelspec": {
564
- "display_name": "Python 3",
565
- "language": "python",
566
- "name": "python3"
567
- },
568
- "language_info": {
569
- "codemirror_mode": {
570
- "name": "ipython",
571
- "version": 3
572
- },
573
- "file_extension": ".py",
574
- "mimetype": "text/x-python",
575
- "name": "python",
576
- "nbconvert_exporter": "python",
577
- "pygments_lexer": "ipython3",
578
- "version": "3.8.8"
579
- }
580
- },
581
- "nbformat": 4,
582
- "nbformat_minor": 1
583
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
encoding/vqgan-jax-encoding-with-captions.ipynb DELETED
@@ -1,363 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "markdown",
5
- "id": "d0b72877",
6
- "metadata": {},
7
- "source": [
8
- "# vqgan-jax-encoding-with-captions"
9
- ]
10
- },
11
- {
12
- "cell_type": "markdown",
13
- "id": "875c82b3",
14
- "metadata": {},
15
- "source": [
16
- "Notebook based on [vqgan-jax-reconstruction](https://colab.research.google.com/drive/1mdXXsMbV6K_LTvCh3IImRsFIWcKU5m1w?usp=sharing) by @surajpatil.\n",
17
- "\n",
18
- "We process a `tsv` file with `image_file` and `caption` fields, and add a `vqgan_indices` column with indices extracted from a VQGAN-JAX model."
19
- ]
20
- },
21
- {
22
- "cell_type": "code",
23
- "execution_count": 1,
24
- "id": "3b59489e",
25
- "metadata": {},
26
- "outputs": [],
27
- "source": [
28
- "import io\n",
29
- "\n",
30
- "import requests\n",
31
- "from PIL import Image\n",
32
- "import numpy as np\n",
33
- "from tqdm import tqdm\n",
34
- "\n",
35
- "import torch\n",
36
- "import torchvision.transforms as T\n",
37
- "import torchvision.transforms.functional as TF\n",
38
- "from torchvision.transforms import InterpolationMode\n",
39
- "from torch.utils.data import Dataset, DataLoader\n",
40
- "\n",
41
- "import jax\n",
42
- "from jax import pmap"
43
- ]
44
- },
45
- {
46
- "cell_type": "markdown",
47
- "id": "511c3b9e",
48
- "metadata": {},
49
- "source": [
50
- "## VQGAN-JAX model"
51
- ]
52
- },
53
- {
54
- "cell_type": "markdown",
55
- "id": "bb408f6c",
56
- "metadata": {},
57
- "source": [
58
- "`dalle_mini` is a local package that contains the VQGAN-JAX model and other utilities."
59
- ]
60
- },
61
- {
62
- "cell_type": "code",
63
- "execution_count": 2,
64
- "id": "2ca50dc7",
65
- "metadata": {},
66
- "outputs": [],
67
- "source": [
68
- "from dalle_mini.vqgan_jax.modeling_flax_vqgan import VQModel"
69
- ]
70
- },
71
- {
72
- "cell_type": "markdown",
73
- "id": "7b60da9a",
74
- "metadata": {},
75
- "source": [
76
- "We'll use a VQGAN trained by using Taming Transformers and converted to a JAX model."
77
- ]
78
- },
79
- {
80
- "cell_type": "code",
81
- "execution_count": 3,
82
- "id": "29ce8b15",
83
- "metadata": {},
84
- "outputs": [
85
- {
86
- "data": {
87
- "application/vnd.jupyter.widget-view+json": {
88
- "model_id": "db406bdfc5d5428eaeae1631a04989dd",
89
- "version_major": 2,
90
- "version_minor": 0
91
- },
92
- "text/plain": [
93
- "Downloading: 0%| | 0.00/433 [00:00<?, ?B/s]"
94
- ]
95
- },
96
- "metadata": {},
97
- "output_type": "display_data"
98
- },
99
- {
100
- "data": {
101
- "application/vnd.jupyter.widget-view+json": {
102
- "model_id": "3e37f07fba6d48fca70313ae1fa8cc32",
103
- "version_major": 2,
104
- "version_minor": 0
105
- },
106
- "text/plain": [
107
- "Downloading: 0%| | 0.00/304M [00:00<?, ?B/s]"
108
- ]
109
- },
110
- "metadata": {},
111
- "output_type": "display_data"
112
- },
113
- {
114
- "name": "stderr",
115
- "output_type": "stream",
116
- "text": [
117
- "INFO:absl:Starting the local TPU driver.\n",
118
- "INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://\n",
119
- "INFO:absl:Unable to initialize backend 'gpu': Not found: Could not find registered platform with name: \"cuda\". Available platform names are: Interpreter Host TPU\n"
120
- ]
121
- },
122
- {
123
- "name": "stdout",
124
- "output_type": "stream",
125
- "text": [
126
- "Working with z of shape (1, 256, 16, 16) = 65536 dimensions.\n"
127
- ]
128
- }
129
- ],
130
- "source": [
131
- "model = VQModel.from_pretrained(\"flax-community/vqgan_f16_16384\")"
132
- ]
133
- },
134
- {
135
- "cell_type": "markdown",
136
- "id": "c7c4c1e6",
137
- "metadata": {},
138
- "source": [
139
- "## Dataset"
140
- ]
141
- },
142
- {
143
- "cell_type": "markdown",
144
- "id": "7014a7ce",
145
- "metadata": {},
146
- "source": [
147
- "We use Luke Melas-Kyriazi's `dataset.py` which reads image paths and captions from a tsv file that contains both. We only need the images for encoding."
148
- ]
149
- },
150
- {
151
- "cell_type": "code",
152
- "execution_count": 4,
153
- "id": "85832702",
154
- "metadata": {},
155
- "outputs": [],
156
- "source": [
157
- "from dalle_mini.dataset import *"
158
- ]
159
- },
160
- {
161
- "cell_type": "code",
162
- "execution_count": 5,
163
- "id": "81b19eca",
164
- "metadata": {},
165
- "outputs": [],
166
- "source": [
167
- "cc12m_images = '/data/CC12M/images'\n",
168
- "cc12m_list = '/data/CC12M/images-list-clean.tsv'\n",
169
- "# cc12m_list = '/data/CC12M/images-10000.tsv'\n",
170
- "cc12m_output = '/data/CC12M/images-encoded.tsv'"
171
- ]
172
- },
173
- {
174
- "cell_type": "code",
175
- "execution_count": 6,
176
- "id": "fecc9a00",
177
- "metadata": {},
178
- "outputs": [],
179
- "source": [
180
- "image_size = 256\n",
181
- "def image_transform(image):\n",
182
- " s = min(image.size)\n",
183
- " r = image_size / s\n",
184
- " s = (round(r * image.size[1]), round(r * image.size[0]))\n",
185
- " image = TF.resize(image, s, interpolation=InterpolationMode.LANCZOS)\n",
186
- " image = TF.center_crop(image, output_size = 2 * [image_size])\n",
187
- " image = torch.unsqueeze(T.ToTensor()(image), 0)\n",
188
- " image = image.permute(0, 2, 3, 1).numpy()\n",
189
- " return image"
190
- ]
191
- },
192
- {
193
- "cell_type": "code",
194
- "execution_count": 7,
195
- "id": "4ce2211f",
196
- "metadata": {},
197
- "outputs": [],
198
- "source": [
199
- "dataset = CaptionDataset(\n",
200
- " images_root=cc12m_images,\n",
201
- " captions_path=cc12m_list,\n",
202
- " image_transform=image_transform,\n",
203
- " image_transform_type='torchvision',\n",
204
- " include_captions=False\n",
205
- ")"
206
- ]
207
- },
208
- {
209
- "cell_type": "code",
210
- "execution_count": 8,
211
- "id": "cc922704",
212
- "metadata": {},
213
- "outputs": [
214
- {
215
- "data": {
216
- "text/plain": [
217
- "8592141"
218
- ]
219
- },
220
- "execution_count": 8,
221
- "metadata": {},
222
- "output_type": "execute_result"
223
- }
224
- ],
225
- "source": [
226
- "len(dataset)"
227
- ]
228
- },
229
- {
230
- "cell_type": "markdown",
231
- "id": "62ad01c3",
232
- "metadata": {},
233
- "source": [
234
- "## Encoding"
235
- ]
236
- },
237
- {
238
- "cell_type": "code",
239
- "execution_count": 9,
240
- "id": "88f36d0b",
241
- "metadata": {},
242
- "outputs": [],
243
- "source": [
244
- "def encode(model, batch):\n",
245
- "# print(\"jitting encode function\")\n",
246
- " _, indices = model.encode(batch)\n",
247
- " return indices"
248
- ]
249
- },
250
- {
251
- "cell_type": "code",
252
- "execution_count": 10,
253
- "id": "1f35f0cb",
254
- "metadata": {},
255
- "outputs": [],
256
- "source": [
257
- "def superbatch_generator(dataloader, num_tpus):\n",
258
- " iter_loader = iter(dataloader)\n",
259
- " for batch in iter_loader:\n",
260
- " superbatch = [batch.squeeze(1)]\n",
261
- " try:\n",
262
- " for b in range(num_tpus-1):\n",
263
- " batch = next(iter_loader)\n",
264
- " if batch is None:\n",
265
- " break\n",
266
- " # Skip incomplete last batch\n",
267
- " if batch.shape[0] == dataloader.batch_size:\n",
268
- " superbatch.append(batch.squeeze(1))\n",
269
- " except StopIteration:\n",
270
- " pass\n",
271
- " superbatch = torch.stack(superbatch, axis=0)\n",
272
- " yield superbatch"
273
- ]
274
- },
275
- {
276
- "cell_type": "code",
277
- "execution_count": 11,
278
- "id": "2210705b",
279
- "metadata": {},
280
- "outputs": [],
281
- "source": [
282
- "import os\n",
283
- "\n",
284
- "def encode_captioned_dataset(dataset, output_tsv, batch_size=32, num_workers=16):\n",
285
- " if os.path.isfile(output_tsv):\n",
286
- " print(f\"Destination file {output_tsv} already exists, please move away.\")\n",
287
- " return\n",
288
- " \n",
289
- " num_tpus = 8 \n",
290
- " dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)\n",
291
- " superbatches = superbatch_generator(dataloader, num_tpus=num_tpus)\n",
292
- " \n",
293
- " p_encoder = pmap(lambda batch: encode(model, batch))\n",
294
- "\n",
295
- " # We save each superbatch to avoid reallocation of buffers as we process them.\n",
296
- " # We keep the file open to prevent excessive file seeks.\n",
297
- " with open(output_tsv, \"w\") as file:\n",
298
- " iterations = len(dataset) // (batch_size * num_tpus)\n",
299
- " for n in tqdm(range(iterations)):\n",
300
- " superbatch = next(superbatches)\n",
301
- " encoded = p_encoder(superbatch.numpy())\n",
302
- " encoded = encoded.reshape(-1, encoded.shape[-1])\n",
303
- "\n",
304
- " # Extract fields from the dataset internal `captions` property, and save to disk\n",
305
- " start_index = n * batch_size * num_tpus\n",
306
- " end_index = (n+1) * batch_size * num_tpus\n",
307
- " paths = dataset.captions[\"image_file\"][start_index:end_index].values\n",
308
- " captions = dataset.captions[\"caption\"][start_index:end_index].values\n",
309
- " encoded_as_string = list(map(lambda item: np.array2string(item, separator=',', max_line_width=50000, formatter={'int':lambda x: str(x)}), encoded))\n",
310
- " batch_df = pd.DataFrame.from_dict({\"image_file\": paths, \"caption\": captions, \"encoding\": encoded_as_string})\n",
311
- " batch_df.to_csv(file, sep='\\t', header=(n==0), index=None)\n",
312
- " "
313
- ]
314
- },
315
- {
316
- "cell_type": "code",
317
- "execution_count": null,
318
- "id": "7704863d",
319
- "metadata": {},
320
- "outputs": [
321
- {
322
- "name": "stderr",
323
- "output_type": "stream",
324
- "text": [
325
- " 4%|██▋ | 621/16781 [07:09<3:02:46, 1.47it/s]"
326
- ]
327
- }
328
- ],
329
- "source": [
330
- "encode_captioned_dataset(dataset, cc12m_output, batch_size=64, num_workers=16)"
331
- ]
332
- },
333
- {
334
- "cell_type": "markdown",
335
- "id": "8953dd84",
336
- "metadata": {},
337
- "source": [
338
- "----"
339
- ]
340
- }
341
- ],
342
- "metadata": {
343
- "kernelspec": {
344
- "display_name": "Python 3 (ipykernel)",
345
- "language": "python",
346
- "name": "python3"
347
- },
348
- "language_info": {
349
- "codemirror_mode": {
350
- "name": "ipython",
351
- "version": 3
352
- },
353
- "file_extension": ".py",
354
- "mimetype": "text/x-python",
355
- "name": "python",
356
- "nbconvert_exporter": "python",
357
- "pygments_lexer": "ipython3",
358
- "version": "3.8.10"
359
- }
360
- },
361
- "nbformat": 4,
362
- "nbformat_minor": 5
363
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
encoding/vqgan-jax-encoding-yfcc100m.ipynb DELETED
@@ -1,1136 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "markdown",
5
- "id": "d0b72877",
6
- "metadata": {},
7
- "source": [
8
- "# vqgan-jax-encoding-yfcc100m"
9
- ]
10
- },
11
- {
12
- "cell_type": "markdown",
13
- "id": "ba7b31e6",
14
- "metadata": {},
15
- "source": [
16
- "Same as `vqgan-jax-encoding-with-captions`, but for YFCC100M.\n",
17
- "\n",
18
- "This dataset was prepared by @borisdayma in Json lines format."
19
- ]
20
- },
21
- {
22
- "cell_type": "code",
23
- "execution_count": 92,
24
- "id": "3b59489e",
25
- "metadata": {},
26
- "outputs": [],
27
- "source": [
28
- "import io\n",
29
- "\n",
30
- "import requests\n",
31
- "from PIL import Image\n",
32
- "import numpy as np\n",
33
- "from tqdm import tqdm\n",
34
- "\n",
35
- "import torch\n",
36
- "import torchvision.transforms as T\n",
37
- "import torchvision.transforms.functional as TF\n",
38
- "from torchvision.transforms import InterpolationMode\n",
39
- "from torch.utils.data import Dataset, DataLoader\n",
40
- "from torchvision.datasets.folder import default_loader\n",
41
- "import os\n",
42
- "\n",
43
- "import jax\n",
44
- "from jax import pmap"
45
- ]
46
- },
47
- {
48
- "cell_type": "markdown",
49
- "id": "511c3b9e",
50
- "metadata": {},
51
- "source": [
52
- "## VQGAN-JAX model"
53
- ]
54
- },
55
- {
56
- "cell_type": "markdown",
57
- "id": "bb408f6c",
58
- "metadata": {},
59
- "source": [
60
- "`dalle_mini` is a local package that contains the VQGAN-JAX model and other utilities."
61
- ]
62
- },
63
- {
64
- "cell_type": "code",
65
- "execution_count": 93,
66
- "id": "2ca50dc7",
67
- "metadata": {},
68
- "outputs": [],
69
- "source": [
70
- "from dalle_mini.vqgan_jax.modeling_flax_vqgan import VQModel"
71
- ]
72
- },
73
- {
74
- "cell_type": "markdown",
75
- "id": "7b60da9a",
76
- "metadata": {},
77
- "source": [
78
- "We'll use a VQGAN trained by using Taming Transformers and converted to a JAX model."
79
- ]
80
- },
81
- {
82
- "cell_type": "code",
83
- "execution_count": 167,
84
- "id": "29ce8b15",
85
- "metadata": {},
86
- "outputs": [
87
- {
88
- "name": "stdout",
89
- "output_type": "stream",
90
- "text": [
91
- "Working with z of shape (1, 256, 16, 16) = 65536 dimensions.\n"
92
- ]
93
- }
94
- ],
95
- "source": [
96
- "model = VQModel.from_pretrained(\"flax-community/vqgan_f16_16384\")"
97
- ]
98
- },
99
- {
100
- "cell_type": "markdown",
101
- "id": "c7c4c1e6",
102
- "metadata": {},
103
- "source": [
104
- "## Dataset"
105
- ]
106
- },
107
- {
108
- "cell_type": "code",
109
- "execution_count": 94,
110
- "id": "33861477",
111
- "metadata": {},
112
- "outputs": [],
113
- "source": [
114
- "import pandas as pd\n",
115
- "from pathlib import Path"
116
- ]
117
- },
118
- {
119
- "cell_type": "code",
120
- "execution_count": 134,
121
- "id": "81b19eca",
122
- "metadata": {},
123
- "outputs": [],
124
- "source": [
125
- "yfcc100m = Path('/home/khali/TPU-Test/YFCC100M_OpenAI_subset')\n",
126
- "# Images are 'sharded' from the following directory\n",
127
- "yfcc100m_images = yfcc100m/'data'/'data'/'images'\n",
128
- "yfcc100m_metadata = yfcc100m/'metadata_YFCC100M.jsonl'\n",
129
- "yfcc100m_output = yfcc100m/'metadata_encoded.tsv'"
130
- ]
131
- },
132
- {
133
- "cell_type": "markdown",
134
- "id": "1c58bb4a",
135
- "metadata": {},
136
- "source": [
137
- "### Cleanup"
138
- ]
139
- },
140
- {
141
- "cell_type": "markdown",
142
- "id": "1a14ae3d",
143
- "metadata": {},
144
- "source": [
145
- "We need to select entries with images that exist. Otherwise we can't build batches because `Dataloader` does not support `None` in batches. We use Huggingface Datasets, I understand they support threaded reading of jsonl files, and I was running out of memory when using pandas."
146
- ]
147
- },
148
- {
149
- "cell_type": "code",
150
- "execution_count": 96,
151
- "id": "7811648c",
152
- "metadata": {},
153
- "outputs": [],
154
- "source": [
155
- "import datasets\n",
156
- "from datasets import Dataset, load_dataset"
157
- ]
158
- },
159
- {
160
- "cell_type": "code",
161
- "execution_count": 10,
162
- "id": "4811a230",
163
- "metadata": {},
164
- "outputs": [
165
- {
166
- "name": "stderr",
167
- "output_type": "stream",
168
- "text": [
169
- "tcmalloc: large alloc 1254047744 bytes == 0xb2b08000 @ 0x7f9e78632680 0x7f9e78653824 0x585b92 0x504d56 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56 0x56acb6 0x5f5956 0x56aadf 0x5f5956 0x56acb6 0x568d9a 0x5f5b33 0x50b7f8 0x5f2702 0x56c332\n",
170
- "tcmalloc: large alloc 1254047744 bytes == 0xfd74e000 @ 0x7f9e78632680 0x7f9e78653824 0x590214 0x586f90 0x56e1f3 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56 0x56acb6 0x5f5956 0x56aadf 0x5f5956 0x56acb6 0x568d9a 0x5f5b33 0x50b7f8 0x5f2702 0x56c332\n",
171
- "tcmalloc: large alloc 5016190976 bytes == 0x148b42000 @ 0x7f9e78632680 0x7f9e78653824 0x5b9144 0x7f9b2929127e 0x7f9b29291a19 0x7f9b29291886 0x7f9b29291cef 0x7f9b2928f204 0x5f2cc9 0x5f30ff 0x5705f6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56\n",
172
- "tcmalloc: large alloc 5019099136 bytes == 0x273f12000 @ 0x7f9e78632680 0x7f9e78653824 0x5b9144 0x7f9b2929127e 0x7f9b29291a19 0x7f9b29291886 0x7f9b29291cef 0x7f9b2928f204 0x5f2cc9 0x5f30ff 0x5705f6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56\n",
173
- "tcmalloc: large alloc 5019811840 bytes == 0x39f9a8000 @ 0x7f9e78632680 0x7f9e78653824 0x5b9144 0x7f9b2929127e 0x7f9b29291a19 0x7f9b29291886 0x7f9b29291cef 0x7f9b2928f204 0x5f2cc9 0x5f30ff 0x5705f6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56\n",
174
- "tcmalloc: large alloc 5024571392 bytes == 0x4cb4ec000 @ 0x7f9e78632680 0x7f9e78653824 0x5b9144 0x7f9b2929127e 0x7f9b29291a19 0x7f9b29291886 0x7f9b29291cef 0x7f9b2928f204 0x5f2cc9 0x5f30ff 0x5705f6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56\n",
175
- "tcmalloc: large alloc 5021097984 bytes == 0x4cb4ec000 @ 0x7f9e78632680 0x7f9e78653824 0x5b9144 0x7f9b2929127e 0x7f9b29291a19 0x7f9b29291886 0x7f9b29291cef 0x7f9b2928f204 0x5f2cc9 0x5f30ff 0x5705f6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56\n",
176
- "tcmalloc: large alloc 5022818304 bytes == 0x4cb4ec000 @ 0x7f9e78632680 0x7f9e78653824 0x5b9144 0x7f9b2929127e 0x7f9b29291a19 0x7f9b29291886 0x7f9b29291cef 0x7f9b2928f204 0x5f2cc9 0x5f30ff 0x5705f6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56\n",
177
- "tcmalloc: large alloc 5020794880 bytes == 0x4cb4ec000 @ 0x7f9e78632680 0x7f9e78653824 0x5b9144 0x7f9b2929127e 0x7f9b29291a19 0x7f9b29291886 0x7f9b29291cef 0x7f9b2928f204 0x5f2cc9 0x5f30ff 0x5705f6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56\n",
178
- "tcmalloc: large alloc 5019451392 bytes == 0x39f9a8000 @ 0x7f9e78632680 0x7f9e78653824 0x5b9144 0x7f9b2929127e 0x7f9b29291a19 0x7f9b29291886 0x7f9b29291cef 0x7f9b2928f204 0x5f2cc9 0x5f30ff 0x5705f6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56\n",
179
- "tcmalloc: large alloc 5020565504 bytes == 0x4cb4ec000 @ 0x7f9e78632680 0x7f9e78653824 0x5b9144 0x7f9b2929127e 0x7f9b29291a19 0x7f9b29291886 0x7f9b29291cef 0x7f9b2928f204 0x5f2cc9 0x5f30ff 0x5705f6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56\n",
180
- "tcmalloc: large alloc 5012561920 bytes == 0x273f12000 @ 0x7f9e78632680 0x7f9e78653824 0x5b9144 0x7f9b2929127e 0x7f9b29291a19 0x7f9b29291886 0x7f9b29291cef 0x7f9b2928f204 0x5f2cc9 0x5f30ff 0x5705f6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56\n",
181
- "tcmalloc: large alloc 5021835264 bytes == 0x5f6cba000 @ 0x7f9e78632680 0x7f9e78653824 0x5b9144 0x7f9b2929127e 0x7f9b29291a19 0x7f9b29291886 0x7f9b29291cef 0x7f9b2928f204 0x5f2cc9 0x5f30ff 0x5705f6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56\n",
182
- "tcmalloc: large alloc 5017436160 bytes == 0x273f12000 @ 0x7f9e78632680 0x7f9e78653824 0x5b9144 0x7f9b2929127e 0x7f9b29291a19 0x7f9b29291886 0x7f9b29291cef 0x7f9b2928f204 0x5f2cc9 0x5f30ff 0x5705f6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56\n"
183
- ]
184
- }
185
- ],
186
- "source": [
187
- "# The metadata is too bog to load into memory at once, so chopping it into chunks\n",
188
- "chunk_size=1000000\n",
189
- "batch_no=1\n",
190
- "for chunk in pd.read_json(yfcc100m_metadata, orient=\"records\", lines=True,chunksize=chunk_size):\n",
191
- " chunk.to_csv('./chunks/chunk'+str(batch_no)+'.tsv', sep=\"\\t\", index=False)\n",
192
- " batch_no+=1"
193
- ]
194
- },
195
- {
196
- "cell_type": "code",
197
- "execution_count": 25,
198
- "id": "46b2f083",
199
- "metadata": {},
200
- "outputs": [
201
- {
202
- "data": {
203
- "text/html": [
204
- "<div>\n",
205
- "<style scoped>\n",
206
- " .dataframe tbody tr th:only-of-type {\n",
207
- " vertical-align: middle;\n",
208
- " }\n",
209
- "\n",
210
- " .dataframe tbody tr th {\n",
211
- " vertical-align: top;\n",
212
- " }\n",
213
- "\n",
214
- " .dataframe thead th {\n",
215
- " text-align: right;\n",
216
- " }\n",
217
- "</style>\n",
218
- "<table border=\"1\" class=\"dataframe\">\n",
219
- " <thead>\n",
220
- " <tr style=\"text-align: right;\">\n",
221
- " <th></th>\n",
222
- " <th>photoid</th>\n",
223
- " <th>uid</th>\n",
224
- " <th>unickname</th>\n",
225
- " <th>datetaken</th>\n",
226
- " <th>dateuploaded</th>\n",
227
- " <th>capturedevice</th>\n",
228
- " <th>title</th>\n",
229
- " <th>description</th>\n",
230
- " <th>usertags</th>\n",
231
- " <th>machinetags</th>\n",
232
- " <th>...</th>\n",
233
- " <th>licenseurl</th>\n",
234
- " <th>serverid</th>\n",
235
- " <th>farmid</th>\n",
236
- " <th>secret</th>\n",
237
- " <th>secretoriginal</th>\n",
238
- " <th>ext</th>\n",
239
- " <th>marker</th>\n",
240
- " <th>key</th>\n",
241
- " <th>title_clean</th>\n",
242
- " <th>description_clean</th>\n",
243
- " </tr>\n",
244
- " </thead>\n",
245
- " <tbody>\n",
246
- " <tr>\n",
247
- " <th>0</th>\n",
248
- " <td>137943</td>\n",
249
- " <td>48600072071@N01</td>\n",
250
- " <td>doctor+paradox</td>\n",
251
- " <td>2004-08-01 18:13:06.0</td>\n",
252
- " <td>1091409186</td>\n",
253
- " <td>NaN</td>\n",
254
- " <td>A+Picture+Share%21</td>\n",
255
- " <td>Antenna</td>\n",
256
- " <td>cameraphone,cayugaheights,green,hydrant,ithaca...</td>\n",
257
- " <td>NaN</td>\n",
258
- " <td>...</td>\n",
259
- " <td>http://creativecommons.org/licenses/by-nc-sa/2.0/</td>\n",
260
- " <td>1</td>\n",
261
- " <td>1</td>\n",
262
- " <td>1650c7cdc6</td>\n",
263
- " <td>1650c7cdc6</td>\n",
264
- " <td>jpg</td>\n",
265
- " <td>0</td>\n",
266
- " <td>d29e7c6a3028418c64eb15e3cf577c2</td>\n",
267
- " <td>A Picture Share!</td>\n",
268
- " <td>Antenna</td>\n",
269
- " </tr>\n",
270
- " <tr>\n",
271
- " <th>1</th>\n",
272
- " <td>1246361</td>\n",
273
- " <td>44124324682@N01</td>\n",
274
- " <td>mharrsch</td>\n",
275
- " <td>2004-11-03 23:04:02.0</td>\n",
276
- " <td>1099523042</td>\n",
277
- " <td>NaN</td>\n",
278
- " <td>An+ornate+Roman+urn</td>\n",
279
- " <td>Photographed+at+the+%3Ca+href%3D%22http%3A%2F%...</td>\n",
280
- " <td>ancient,baltimore,burial,death,empire,funeral,...</td>\n",
281
- " <td>NaN</td>\n",
282
- " <td>...</td>\n",
283
- " <td>http://creativecommons.org/licenses/by-nc-sa/2.0/</td>\n",
284
- " <td>1</td>\n",
285
- " <td>1</td>\n",
286
- " <td>cf37054610</td>\n",
287
- " <td>cf37054610</td>\n",
288
- " <td>jpg</td>\n",
289
- " <td>0</td>\n",
290
- " <td>d29f01b149167d683f9ddde464bb3db</td>\n",
291
- " <td>An ornate Roman urn</td>\n",
292
- " <td>Photographed at the Walters Art Museum, Baltim...</td>\n",
293
- " </tr>\n",
294
- " <tr>\n",
295
- " <th>2</th>\n",
296
- " <td>1251599</td>\n",
297
- " <td>51035803024@N01</td>\n",
298
- " <td>bmitd67</td>\n",
299
- " <td>2004-10-30 17:09:32.0</td>\n",
300
- " <td>1099538888</td>\n",
301
- " <td>Canon+PowerShot+S30</td>\n",
302
- " <td>Jai+%26+Tara+on+the+Cumberland</td>\n",
303
- " <td>Another+trip+for+the+happy+couple.</td>\n",
304
- " <td>blue+heron,cumberland+river,jai,tara,tennessee</td>\n",
305
- " <td>NaN</td>\n",
306
- " <td>...</td>\n",
307
- " <td>http://creativecommons.org/licenses/by-nc-sa/2.0/</td>\n",
308
- " <td>1</td>\n",
309
- " <td>1</td>\n",
310
- " <td>4a4234e32c</td>\n",
311
- " <td>4a4234e32c</td>\n",
312
- " <td>jpg</td>\n",
313
- " <td>0</td>\n",
314
- " <td>d296e9e34bdae41edb6c679ff824ab2a</td>\n",
315
- " <td>Jai &amp; Tara on the Cumberland</td>\n",
316
- " <td>Another trip for the happy couple.</td>\n",
317
- " </tr>\n",
318
- " <tr>\n",
319
- " <th>3</th>\n",
320
- " <td>2348587</td>\n",
321
- " <td>73621375@N00</td>\n",
322
- " <td>Thom+Watson</td>\n",
323
- " <td>2004-12-18 21:08:09.0</td>\n",
324
- " <td>1103497228</td>\n",
325
- " <td>SONY+DSC-W1</td>\n",
326
- " <td>Castle+gate+-+%22lite-brited%22</td>\n",
327
- " <td>Taken+at+the+Miracle+of+Lights+display+in+Cent...</td>\n",
328
- " <td>bullrunpark,castle,centreville,christmas,decor...</td>\n",
329
- " <td>NaN</td>\n",
330
- " <td>...</td>\n",
331
- " <td>http://creativecommons.org/licenses/by-nc-sa/2.0/</td>\n",
332
- " <td>2</td>\n",
333
- " <td>1</td>\n",
334
- " <td>7162c974c3</td>\n",
335
- " <td>7162c974c3</td>\n",
336
- " <td>jpg</td>\n",
337
- " <td>0</td>\n",
338
- " <td>d29ce96395848478b1e8396e44899</td>\n",
339
- " <td>Castle gate - \"lite-brited\"</td>\n",
340
- " <td>Taken at the Miracle of Lights display in Cent...</td>\n",
341
- " </tr>\n",
342
- " <tr>\n",
343
- " <th>4</th>\n",
344
- " <td>3516047</td>\n",
345
- " <td>48600072071@N01</td>\n",
346
- " <td>doctor+paradox</td>\n",
347
- " <td>2005-01-18 16:44:18.0</td>\n",
348
- " <td>1106084658</td>\n",
349
- " <td>NaN</td>\n",
350
- " <td>A+Picture+Share%21</td>\n",
351
- " <td>Tabular</td>\n",
352
- " <td>cameraphone,moblog,unfound</td>\n",
353
- " <td>NaN</td>\n",
354
- " <td>...</td>\n",
355
- " <td>http://creativecommons.org/licenses/by-nc-sa/2.0/</td>\n",
356
- " <td>3</td>\n",
357
- " <td>1</td>\n",
358
- " <td>663e0d8b3d</td>\n",
359
- " <td>663e0d8b3d</td>\n",
360
- " <td>jpg</td>\n",
361
- " <td>0</td>\n",
362
- " <td>d29abf32c4e12ff881f975b70e0cec0</td>\n",
363
- " <td>A Picture Share!</td>\n",
364
- " <td>Tabular</td>\n",
365
- " </tr>\n",
366
- " <tr>\n",
367
- " <th>...</th>\n",
368
- " <td>...</td>\n",
369
- " <td>...</td>\n",
370
- " <td>...</td>\n",
371
- " <td>...</td>\n",
372
- " <td>...</td>\n",
373
- " <td>...</td>\n",
374
- " <td>...</td>\n",
375
- " <td>...</td>\n",
376
- " <td>...</td>\n",
377
- " <td>...</td>\n",
378
- " <td>...</td>\n",
379
- " <td>...</td>\n",
380
- " <td>...</td>\n",
381
- " <td>...</td>\n",
382
- " <td>...</td>\n",
383
- " <td>...</td>\n",
384
- " <td>...</td>\n",
385
- " <td>...</td>\n",
386
- " <td>...</td>\n",
387
- " <td>...</td>\n",
388
- " <td>...</td>\n",
389
- " </tr>\n",
390
- " <tr>\n",
391
- " <th>999995</th>\n",
392
- " <td>4648651054</td>\n",
393
- " <td>24511045@N04</td>\n",
394
- " <td>mtfrazier</td>\n",
395
- " <td>2010-05-02 15:47:45.0</td>\n",
396
- " <td>1275083371</td>\n",
397
- " <td>Canon+EOS+50D</td>\n",
398
- " <td>U.S.+Navy+Blue+Angels%3A+2010</td>\n",
399
- " <td>2+May+2010%0ASunday%0ASt.+Joseph%2C+Missouri</td>\n",
400
- " <td>NaN</td>\n",
401
- " <td>NaN</td>\n",
402
- " <td>...</td>\n",
403
- " <td>http://creativecommons.org/licenses/by-nc-nd/2.0/</td>\n",
404
- " <td>4072</td>\n",
405
- " <td>5</td>\n",
406
- " <td>2d12d73fb0</td>\n",
407
- " <td>dd5856ea42</td>\n",
408
- " <td>jpg</td>\n",
409
- " <td>0</td>\n",
410
- " <td>60fa2911cb81eb25b356e9fee978aef</td>\n",
411
- " <td>U.S. Navy Blue Angels: 2010</td>\n",
412
- " <td>2 May 2010 Sunday St. Joseph, Missouri</td>\n",
413
- " </tr>\n",
414
- " <tr>\n",
415
- " <th>999996</th>\n",
416
- " <td>4652130996</td>\n",
417
- " <td>21963865@N04</td>\n",
418
- " <td>GRAB1.0</td>\n",
419
- " <td>2010-05-29 19:23:10.0</td>\n",
420
- " <td>1275200833</td>\n",
421
- " <td>SONY+DSLR-A230</td>\n",
422
- " <td>Attempts+on+Her+Life</td>\n",
423
- " <td>BAPA+1+production+of+Martin+Crimp%27s+Attempts...</td>\n",
424
- " <td>NaN</td>\n",
425
- " <td>NaN</td>\n",
426
- " <td>...</td>\n",
427
- " <td>http://creativecommons.org/licenses/by-nc-nd/2.0/</td>\n",
428
- " <td>4003</td>\n",
429
- " <td>5</td>\n",
430
- " <td>8889121579</td>\n",
431
- " <td>2f46599456</td>\n",
432
- " <td>jpg</td>\n",
433
- " <td>0</td>\n",
434
- " <td>60f5ef5ce4c2d24566226abebd67d4</td>\n",
435
- " <td>Attempts on Her Life</td>\n",
436
- " <td>BAPA 1 production of Martin Crimp's Attempts o...</td>\n",
437
- " </tr>\n",
438
- " <tr>\n",
439
- " <th>999997</th>\n",
440
- " <td>4652568339</td>\n",
441
- " <td>64025277@N00</td>\n",
442
- " <td>1Sock</td>\n",
443
- " <td>2010-05-13 15:38:37.0</td>\n",
444
- " <td>1275234267</td>\n",
445
- " <td>Canon+EOS+DIGITAL+REBEL+XT</td>\n",
446
- " <td>Carlsbad+Caverns+3</td>\n",
447
- " <td>%E2%99%A5%E2%99%A5%E2%99%A5%E2%99%A5%E2%99%A5%...</td>\n",
448
- " <td>carlsbad,carlsbad+caverns,cave,faa,new+mexico,...</td>\n",
449
- " <td>NaN</td>\n",
450
- " <td>...</td>\n",
451
- " <td>http://creativecommons.org/licenses/by-nc-nd/2.0/</td>\n",
452
- " <td>4010</td>\n",
453
- " <td>5</td>\n",
454
- " <td>0a1808a69e</td>\n",
455
- " <td>cf6d348e3d</td>\n",
456
- " <td>jpg</td>\n",
457
- " <td>0</td>\n",
458
- " <td>60f029482d1d1028fda5281daf498f</td>\n",
459
- " <td>Carlsbad Caverns 3</td>\n",
460
- " <td>♥♥♥♥♥♥♥ Interested in purchasing this photogra...</td>\n",
461
- " </tr>\n",
462
- " <tr>\n",
463
- " <th>999998</th>\n",
464
- " <td>4653110895</td>\n",
465
- " <td>20483509@N00</td>\n",
466
- " <td>subberculture</td>\n",
467
- " <td>2010-05-30 15:37:05.0</td>\n",
468
- " <td>1275245596</td>\n",
469
- " <td>Canon+DIGITAL+IXUS+40</td>\n",
470
- " <td>Want</td>\n",
471
- " <td>Isn%27t+that+gorgeous%3F</td>\n",
472
- " <td>2010,edinburgh+museum,may,phonebox,wood</td>\n",
473
- " <td>NaN</td>\n",
474
- " <td>...</td>\n",
475
- " <td>http://creativecommons.org/licenses/by-sa/2.0/</td>\n",
476
- " <td>4066</td>\n",
477
- " <td>5</td>\n",
478
- " <td>77c3b3a254</td>\n",
479
- " <td>c4697e1511</td>\n",
480
- " <td>jpg</td>\n",
481
- " <td>0</td>\n",
482
- " <td>60f72775f433cf8de3efaeb431866153</td>\n",
483
- " <td>Want</td>\n",
484
- " <td>Isn't that gorgeous?</td>\n",
485
- " </tr>\n",
486
- " <tr>\n",
487
- " <th>999999</th>\n",
488
- " <td>4655503987</td>\n",
489
- " <td>8457193@N07</td>\n",
490
- " <td>zackojones</td>\n",
491
- " <td>2010-05-30 15:34:58.0</td>\n",
492
- " <td>1275310230</td>\n",
493
- " <td>Canon+EOS+7D</td>\n",
494
- " <td>Summertime</td>\n",
495
- " <td>You+gotta+love+it%21</td>\n",
496
- " <td>georgia,savannah,united+states,us</td>\n",
497
- " <td>NaN</td>\n",
498
- " <td>...</td>\n",
499
- " <td>http://creativecommons.org/licenses/by-nc-sa/2.0/</td>\n",
500
- " <td>4043</td>\n",
501
- " <td>5</td>\n",
502
- " <td>caff543bfe</td>\n",
503
- " <td>f60952ac4d</td>\n",
504
- " <td>jpg</td>\n",
505
- " <td>0</td>\n",
506
- " <td>60f687e11b913bce461e9525d8047e0</td>\n",
507
- " <td>Summertime</td>\n",
508
- " <td>You gotta love it!</td>\n",
509
- " </tr>\n",
510
- " </tbody>\n",
511
- "</table>\n",
512
- "<p>1000000 rows × 26 columns</p>\n",
513
- "</div>"
514
- ],
515
- "text/plain": [
516
- " photoid uid unickname datetaken \\\n",
517
- "0 137943 48600072071@N01 doctor+paradox 2004-08-01 18:13:06.0 \n",
518
- "1 1246361 44124324682@N01 mharrsch 2004-11-03 23:04:02.0 \n",
519
- "2 1251599 51035803024@N01 bmitd67 2004-10-30 17:09:32.0 \n",
520
- "3 2348587 73621375@N00 Thom+Watson 2004-12-18 21:08:09.0 \n",
521
- "4 3516047 48600072071@N01 doctor+paradox 2005-01-18 16:44:18.0 \n",
522
- "... ... ... ... ... \n",
523
- "999995 4648651054 24511045@N04 mtfrazier 2010-05-02 15:47:45.0 \n",
524
- "999996 4652130996 21963865@N04 GRAB1.0 2010-05-29 19:23:10.0 \n",
525
- "999997 4652568339 64025277@N00 1Sock 2010-05-13 15:38:37.0 \n",
526
- "999998 4653110895 20483509@N00 subberculture 2010-05-30 15:37:05.0 \n",
527
- "999999 4655503987 8457193@N07 zackojones 2010-05-30 15:34:58.0 \n",
528
- "\n",
529
- " dateuploaded capturedevice \\\n",
530
- "0 1091409186 NaN \n",
531
- "1 1099523042 NaN \n",
532
- "2 1099538888 Canon+PowerShot+S30 \n",
533
- "3 1103497228 SONY+DSC-W1 \n",
534
- "4 1106084658 NaN \n",
535
- "... ... ... \n",
536
- "999995 1275083371 Canon+EOS+50D \n",
537
- "999996 1275200833 SONY+DSLR-A230 \n",
538
- "999997 1275234267 Canon+EOS+DIGITAL+REBEL+XT \n",
539
- "999998 1275245596 Canon+DIGITAL+IXUS+40 \n",
540
- "999999 1275310230 Canon+EOS+7D \n",
541
- "\n",
542
- " title \\\n",
543
- "0 A+Picture+Share%21 \n",
544
- "1 An+ornate+Roman+urn \n",
545
- "2 Jai+%26+Tara+on+the+Cumberland \n",
546
- "3 Castle+gate+-+%22lite-brited%22 \n",
547
- "4 A+Picture+Share%21 \n",
548
- "... ... \n",
549
- "999995 U.S.+Navy+Blue+Angels%3A+2010 \n",
550
- "999996 Attempts+on+Her+Life \n",
551
- "999997 Carlsbad+Caverns+3 \n",
552
- "999998 Want \n",
553
- "999999 Summertime \n",
554
- "\n",
555
- " description \\\n",
556
- "0 Antenna \n",
557
- "1 Photographed+at+the+%3Ca+href%3D%22http%3A%2F%... \n",
558
- "2 Another+trip+for+the+happy+couple. \n",
559
- "3 Taken+at+the+Miracle+of+Lights+display+in+Cent... \n",
560
- "4 Tabular \n",
561
- "... ... \n",
562
- "999995 2+May+2010%0ASunday%0ASt.+Joseph%2C+Missouri \n",
563
- "999996 BAPA+1+production+of+Martin+Crimp%27s+Attempts... \n",
564
- "999997 %E2%99%A5%E2%99%A5%E2%99%A5%E2%99%A5%E2%99%A5%... \n",
565
- "999998 Isn%27t+that+gorgeous%3F \n",
566
- "999999 You+gotta+love+it%21 \n",
567
- "\n",
568
- " usertags machinetags ... \\\n",
569
- "0 cameraphone,cayugaheights,green,hydrant,ithaca... NaN ... \n",
570
- "1 ancient,baltimore,burial,death,empire,funeral,... NaN ... \n",
571
- "2 blue+heron,cumberland+river,jai,tara,tennessee NaN ... \n",
572
- "3 bullrunpark,castle,centreville,christmas,decor... NaN ... \n",
573
- "4 cameraphone,moblog,unfound NaN ... \n",
574
- "... ... ... ... \n",
575
- "999995 NaN NaN ... \n",
576
- "999996 NaN NaN ... \n",
577
- "999997 carlsbad,carlsbad+caverns,cave,faa,new+mexico,... NaN ... \n",
578
- "999998 2010,edinburgh+museum,may,phonebox,wood NaN ... \n",
579
- "999999 georgia,savannah,united+states,us NaN ... \n",
580
- "\n",
581
- " licenseurl serverid farmid \\\n",
582
- "0 http://creativecommons.org/licenses/by-nc-sa/2.0/ 1 1 \n",
583
- "1 http://creativecommons.org/licenses/by-nc-sa/2.0/ 1 1 \n",
584
- "2 http://creativecommons.org/licenses/by-nc-sa/2.0/ 1 1 \n",
585
- "3 http://creativecommons.org/licenses/by-nc-sa/2.0/ 2 1 \n",
586
- "4 http://creativecommons.org/licenses/by-nc-sa/2.0/ 3 1 \n",
587
- "... ... ... ... \n",
588
- "999995 http://creativecommons.org/licenses/by-nc-nd/2.0/ 4072 5 \n",
589
- "999996 http://creativecommons.org/licenses/by-nc-nd/2.0/ 4003 5 \n",
590
- "999997 http://creativecommons.org/licenses/by-nc-nd/2.0/ 4010 5 \n",
591
- "999998 http://creativecommons.org/licenses/by-sa/2.0/ 4066 5 \n",
592
- "999999 http://creativecommons.org/licenses/by-nc-sa/2.0/ 4043 5 \n",
593
- "\n",
594
- " secret secretoriginal ext marker \\\n",
595
- "0 1650c7cdc6 1650c7cdc6 jpg 0 \n",
596
- "1 cf37054610 cf37054610 jpg 0 \n",
597
- "2 4a4234e32c 4a4234e32c jpg 0 \n",
598
- "3 7162c974c3 7162c974c3 jpg 0 \n",
599
- "4 663e0d8b3d 663e0d8b3d jpg 0 \n",
600
- "... ... ... ... ... \n",
601
- "999995 2d12d73fb0 dd5856ea42 jpg 0 \n",
602
- "999996 8889121579 2f46599456 jpg 0 \n",
603
- "999997 0a1808a69e cf6d348e3d jpg 0 \n",
604
- "999998 77c3b3a254 c4697e1511 jpg 0 \n",
605
- "999999 caff543bfe f60952ac4d jpg 0 \n",
606
- "\n",
607
- " key title_clean \\\n",
608
- "0 d29e7c6a3028418c64eb15e3cf577c2 A Picture Share! \n",
609
- "1 d29f01b149167d683f9ddde464bb3db An ornate Roman urn \n",
610
- "2 d296e9e34bdae41edb6c679ff824ab2a Jai & Tara on the Cumberland \n",
611
- "3 d29ce96395848478b1e8396e44899 Castle gate - \"lite-brited\" \n",
612
- "4 d29abf32c4e12ff881f975b70e0cec0 A Picture Share! \n",
613
- "... ... ... \n",
614
- "999995 60fa2911cb81eb25b356e9fee978aef U.S. Navy Blue Angels: 2010 \n",
615
- "999996 60f5ef5ce4c2d24566226abebd67d4 Attempts on Her Life \n",
616
- "999997 60f029482d1d1028fda5281daf498f Carlsbad Caverns 3 \n",
617
- "999998 60f72775f433cf8de3efaeb431866153 Want \n",
618
- "999999 60f687e11b913bce461e9525d8047e0 Summertime \n",
619
- "\n",
620
- " description_clean \n",
621
- "0 Antenna \n",
622
- "1 Photographed at the Walters Art Museum, Baltim... \n",
623
- "2 Another trip for the happy couple. \n",
624
- "3 Taken at the Miracle of Lights display in Cent... \n",
625
- "4 Tabular \n",
626
- "... ... \n",
627
- "999995 2 May 2010 Sunday St. Joseph, Missouri \n",
628
- "999996 BAPA 1 production of Martin Crimp's Attempts o... \n",
629
- "999997 ♥♥♥♥♥♥♥ Interested in purchasing this photogra... \n",
630
- "999998 Isn't that gorgeous? \n",
631
- "999999 You gotta love it! \n",
632
- "\n",
633
- "[1000000 rows x 26 columns]"
634
- ]
635
- },
636
- "execution_count": 25,
637
- "metadata": {},
638
- "output_type": "execute_result"
639
- }
640
- ],
641
- "source": [
642
-     "# looking at a chunk\n",
643
- "pd.read_csv(\"./chunks/chunk1.tsv\", sep=\"\\t\")"
644
- ]
645
- },
646
- {
647
- "cell_type": "code",
648
- "execution_count": 98,
649
- "id": "c51c5597",
650
- "metadata": {},
651
- "outputs": [
652
- {
653
- "data": {
654
- "text/html": [
655
- "<div>\n",
656
- "<style scoped>\n",
657
- " .dataframe tbody tr th:only-of-type {\n",
658
- " vertical-align: middle;\n",
659
- " }\n",
660
- "\n",
661
- " .dataframe tbody tr th {\n",
662
- " vertical-align: top;\n",
663
- " }\n",
664
- "\n",
665
- " .dataframe thead th {\n",
666
- " text-align: right;\n",
667
- " }\n",
668
- "</style>\n",
669
- "<table border=\"1\" class=\"dataframe\">\n",
670
- " <thead>\n",
671
- " <tr style=\"text-align: right;\">\n",
672
- " <th></th>\n",
673
- " <th>key</th>\n",
674
- " <th>title_clean</th>\n",
675
- " <th>description_clean</th>\n",
676
- " <th>ext</th>\n",
677
- " </tr>\n",
678
- " </thead>\n",
679
- " <tbody>\n",
680
- " <tr>\n",
681
- " <th>0</th>\n",
682
- " <td>d29e7c6a3028418c64eb15e3cf577c2</td>\n",
683
- " <td>A Picture Share!</td>\n",
684
- " <td>Antenna</td>\n",
685
- " <td>jpg</td>\n",
686
- " </tr>\n",
687
- " <tr>\n",
688
- " <th>1</th>\n",
689
- " <td>d29f01b149167d683f9ddde464bb3db</td>\n",
690
- " <td>An ornate Roman urn</td>\n",
691
- " <td>Photographed at the Walters Art Museum, Baltim...</td>\n",
692
- " <td>jpg</td>\n",
693
- " </tr>\n",
694
- " <tr>\n",
695
- " <th>2</th>\n",
696
- " <td>d296e9e34bdae41edb6c679ff824ab2a</td>\n",
697
- " <td>Jai &amp; Tara on the Cumberland</td>\n",
698
- " <td>Another trip for the happy couple.</td>\n",
699
- " <td>jpg</td>\n",
700
- " </tr>\n",
701
- " <tr>\n",
702
- " <th>3</th>\n",
703
- " <td>d29ce96395848478b1e8396e44899</td>\n",
704
- " <td>Castle gate - \"lite-brited\"</td>\n",
705
- " <td>Taken at the Miracle of Lights display in Cent...</td>\n",
706
- " <td>jpg</td>\n",
707
- " </tr>\n",
708
- " <tr>\n",
709
- " <th>4</th>\n",
710
- " <td>d29abf32c4e12ff881f975b70e0cec0</td>\n",
711
- " <td>A Picture Share!</td>\n",
712
- " <td>Tabular</td>\n",
713
- " <td>jpg</td>\n",
714
- " </tr>\n",
715
- " </tbody>\n",
716
- "</table>\n",
717
- "</div>"
718
- ],
719
- "text/plain": [
720
- " key title_clean \\\n",
721
- "0 d29e7c6a3028418c64eb15e3cf577c2 A Picture Share! \n",
722
- "1 d29f01b149167d683f9ddde464bb3db An ornate Roman urn \n",
723
- "2 d296e9e34bdae41edb6c679ff824ab2a Jai & Tara on the Cumberland \n",
724
- "3 d29ce96395848478b1e8396e44899 Castle gate - \"lite-brited\" \n",
725
- "4 d29abf32c4e12ff881f975b70e0cec0 A Picture Share! \n",
726
- "\n",
727
- " description_clean ext \n",
728
- "0 Antenna jpg \n",
729
- "1 Photographed at the Walters Art Museum, Baltim... jpg \n",
730
- "2 Another trip for the happy couple. jpg \n",
731
- "3 Taken at the Miracle of Lights display in Cent... jpg \n",
732
- "4 Tabular jpg "
733
- ]
734
- },
735
- "execution_count": 98,
736
- "metadata": {},
737
- "output_type": "execute_result"
738
- }
739
- ],
740
- "source": [
741
- "# Looking at a chunk with only the relevant columns that we need\n",
742
- "df = pd.read_csv(\"./chunks/chunk1.tsv\", sep=\"\\t\")[[\"key\", \"title_clean\", \"description_clean\", \"ext\"]]\n",
743
- "df.head()"
744
- ]
745
- },
746
- {
747
- "cell_type": "markdown",
748
- "id": "cc1668f8",
749
- "metadata": {},
750
- "source": [
751
-     "### Grabbing each chunk from the folder, cleaning it up, keeping only the entries whose image exists, and appending them to the global df"
752
- ]
753
- },
754
- {
755
- "cell_type": "code",
756
- "execution_count": null,
757
- "id": "abbcccf3",
758
- "metadata": {},
759
- "outputs": [],
760
- "source": [
761
-     "# helper that decides whether the image with a given id exists in storage; we only keep the entries that we have images for\n",
762
- "def image_exists(item):\n",
763
- " name, _, _, ext, _ = item\n",
764
- " root=str(yfcc100m_images)\n",
765
- " image_path = (Path(root)/name[0:3]/name[3:6]/name).with_suffix(\".\"+ext)\n",
766
- " if image_path.exists():\n",
767
- " return True\n",
768
- " else:\n",
769
- " return None"
770
- ]
771
- },
772
- {
773
- "cell_type": "code",
774
- "execution_count": 86,
775
- "id": "44fa86ab",
776
- "metadata": {},
777
- "outputs": [],
778
- "source": [
779
- "# This cell does it all, grabs each chunk, cleans it up based on image existing condition, etc.\n",
780
- "global_df = pd.DataFrame()\n",
781
- "chunks_dir = \"./chunks\"\n",
782
- "for filename in os.listdir(chunks_dir):\n",
783
- " df = pd.read_csv(f\"./chunks/{str(filename)}\", sep=\"\\t\")[[\"key\", \"title_clean\", \"description_clean\", \"ext\"]]\n",
784
- " df['caption'] = df[\"title_clean\"]+\". \"+df['description_clean']\n",
785
- " df['is_exist'] = df.apply(image_exists, axis=1)\n",
786
- " df = df.dropna()[[\"key\", \"caption\"]]\n",
787
- " df.columns = ['image_file', 'caption']\n",
788
- " global_df = global_df.append(df, ignore_index=True)"
789
- ]
790
- },
791
- {
792
- "cell_type": "code",
793
- "execution_count": 89,
794
- "id": "45024fdc",
795
- "metadata": {},
796
- "outputs": [],
797
- "source": [
798
- "# saving the tsv to disk\n",
799
- "global_df.to_csv('./chunks/YFCC_subset_clean.tsv', sep=\"\\t\", index=False)"
800
- ]
801
- },
802
- {
803
- "cell_type": "code",
804
- "execution_count": 101,
805
- "id": "dca4eb73",
806
- "metadata": {},
807
- "outputs": [],
808
- "source": [
809
- "# loading the tsv from disk (for explicitness, also my electricity was gone, glad it happened after I saved to the disk :( )\n",
810
- "\n",
811
- "dataset = pd.read_csv(f\"./chunks/YFCC_subset_clean.tsv\", sep=\"\\t\")"
812
- ]
813
- },
814
- {
815
- "cell_type": "code",
816
- "execution_count": 153,
817
- "id": "a511264a",
818
- "metadata": {},
819
- "outputs": [],
820
- "source": [
821
- "\"\"\"\n",
822
-     "Modified version of Luke Melas-Kyriazi's dataset.py for YFCC\n",
823
- "\"\"\"\n",
824
- "import warnings\n",
825
- "from typing import Optional, Callable\n",
826
- "from pathlib import Path\n",
827
- "import numpy as np\n",
828
- "import torch\n",
829
- "import pandas as pd\n",
830
- "from torch.utils.data import Dataset\n",
831
- "from torchvision.datasets.folder import default_loader\n",
832
- "from PIL import ImageFile\n",
833
- "from PIL.Image import DecompressionBombWarning\n",
834
- "ImageFile.LOAD_TRUNCATED_IMAGES = True\n",
835
- "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
836
- "warnings.filterwarnings(\"ignore\", category=DecompressionBombWarning)\n",
837
- "\n",
838
- "\n",
839
- "class CaptionDataset(Dataset):\n",
840
- " \"\"\"\n",
841
- " A PyTorch Dataset class for (image, texts) tasks. Note that this dataset \n",
842
- " returns the raw text rather than tokens. This is done on purpose, because\n",
843
- " it's easy to tokenize a batch of text after loading it from this dataset.\n",
844
- " \"\"\"\n",
845
- "\n",
846
- " def __init__(self, *, images_root: str, captions_path: str, text_transform: Optional[Callable] = None, \n",
847
- " image_transform: Optional[Callable] = None, image_transform_type: str = 'torchvision',\n",
848
- " include_captions: bool = True):\n",
849
- " \"\"\"\n",
850
- " :param images_root: folder where images are stored\n",
851
- " :param captions_path: path to csv that maps image filenames to captions\n",
852
- " :param image_transform: image transform pipeline\n",
853
-     "        :param text_transform: text transform pipeline\n",
854
- " :param image_transform_type: image transform type, either `torchvision` or `albumentations`\n",
855
- " :param include_captions: Returns a dictionary with `image`, `text` if `true`; otherwise returns just the images.\n",
856
- " \"\"\"\n",
857
- "\n",
858
- " # Base path for images\n",
859
- " self.images_root = Path(images_root)\n",
860
- "\n",
861
- " # Load captions as DataFrame\n",
862
- " self.captions = pd.read_csv(f\"./chunks/YFCC_subset_clean.tsv\", sep=\"\\t\")\n",
863
- " self.captions['image_file'] = self.captions['image_file'].astype(str)\n",
864
- "\n",
865
- " # PyTorch transformation pipeline for the image (normalizing, etc.)\n",
866
- " self.text_transform = text_transform\n",
867
- " self.image_transform = image_transform\n",
868
- " self.image_transform_type = image_transform_type.lower()\n",
869
- " assert self.image_transform_type in ['torchvision', 'albumentations']\n",
870
- "\n",
871
- " # Total number of datapoints\n",
872
- " self.size = len(self.captions)\n",
873
- "\n",
874
- " # Return image+captions or just images\n",
875
- " self.include_captions = include_captions\n",
876
- " \n",
877
- " def image_exists(item):\n",
878
- " name, caption = item\n",
879
- " root=str(self.images_root)\n",
880
- " image_path = (Path(root)/name[0:3]/name[3:6]/name).with_suffix(\".jpg\")\n",
881
- "\n",
882
- " return image_path.exists()\n",
883
- "\n",
884
- " def verify_that_all_images_exist(self):\n",
885
- " for image_file in self.captions['image_file']:\n",
886
- " if not image_exists:\n",
887
- " print(f'file does not exist: {p}')\n",
888
- "\n",
889
- " def _get_raw_image(self, i):\n",
890
- " name = self.captions.iloc[i]['image_file']\n",
891
- " image_path = (Path(self.images_root)/name[0:3]/name[3:6]/name).with_suffix(\".jpg\")\n",
892
- " image = default_loader(image_path)\n",
893
- " return image\n",
894
- "\n",
895
- " def _get_raw_text(self, i):\n",
896
- " return self.captions.iloc[i]['caption']\n",
897
- "\n",
898
- " def __getitem__(self, i):\n",
899
- " image = self._get_raw_image(i)\n",
900
- " caption = self._get_raw_text(i)\n",
901
- " if self.image_transform is not None:\n",
902
- " if self.image_transform_type == 'torchvision':\n",
903
- " image = self.image_transform(image)\n",
904
- " elif self.image_transform_type == 'albumentations':\n",
905
- " image = self.image_transform(image=np.array(image))['image']\n",
906
- " else:\n",
907
- " raise NotImplementedError(f\"{self.image_transform_type=}\")\n",
908
- " return {'image': image, 'text': caption} if self.include_captions else image\n",
909
- "\n",
910
- " def __len__(self):\n",
911
- " return self.size\n",
912
- "\n",
913
- "\n",
914
- "if __name__ == \"__main__\":\n",
915
- " import albumentations as A\n",
916
- " from albumentations.pytorch import ToTensorV2\n",
917
- " from transformers import AutoTokenizer\n",
918
- " \n",
919
- "\n",
920
- " images_root = \"/home/khali/TPU-Test/YFCC100M_OpenAI_subset/data/data/images\"\n",
921
- " captions_path = './YFCC_subset_clean.tsv'\n",
922
- " image_size = 256\n",
923
- " \n",
924
- " # Create transforms\n",
925
- " def image_transform(image):\n",
926
- " s = min(image.size)\n",
927
- " r = image_size / s\n",
928
- " s = (round(r * image.size[1]), round(r * image.size[0]))\n",
929
- " image = TF.resize(image, s, interpolation=InterpolationMode.LANCZOS)\n",
930
- " image = TF.center_crop(image, output_size = 2 * [image_size])\n",
931
- " image = torch.unsqueeze(T.ToTensor()(image), 0)\n",
932
- " image = image.permute(0, 2, 3, 1).numpy()\n",
933
- " return image\n",
934
- " \n",
935
- " # Create dataset\n",
936
- " dataset = CaptionDataset(\n",
937
- " images_root=images_root,\n",
938
- " captions_path=captions_path,\n",
939
- " image_transform=image_transform,\n",
940
- " image_transform_type='torchvision',\n",
941
- " include_captions=False\n",
942
- " )"
943
- ]
944
- },
945
- {
946
- "cell_type": "code",
947
- "execution_count": 155,
948
- "id": "cc922704",
949
- "metadata": {},
950
- "outputs": [
951
- {
952
- "data": {
953
- "text/plain": [
954
- "2483316"
955
- ]
956
- },
957
- "execution_count": 155,
958
- "metadata": {},
959
- "output_type": "execute_result"
960
- }
961
- ],
962
- "source": [
963
- "len(dataset)"
964
- ]
965
- },
966
- {
967
- "cell_type": "code",
968
- "execution_count": 156,
969
- "id": "6e47ba46",
970
- "metadata": {},
971
- "outputs": [],
972
- "source": [
973
- "dataloader = DataLoader(dataset, batch_size=32, num_workers=4)"
974
- ]
975
- },
976
- {
977
- "cell_type": "code",
978
- "execution_count": 1,
979
- "id": "c8a130eb",
980
- "metadata": {},
981
- "outputs": [],
982
- "source": [
983
- "# looking at a batch\n",
984
- "next(iter(dataloader))"
985
- ]
986
- },
987
- {
988
- "cell_type": "code",
989
- "execution_count": null,
990
- "id": "c192fd44",
991
- "metadata": {},
992
- "outputs": [],
993
- "source": [
994
- "# import matplotlib.pyplot as plt\n",
995
- "# for tensor_image, _ in dataloader:\n",
996
- "# print(tensor_image)\n",
997
- "# plt.imshow(tensor_image.permute(1, 2, 0))\n",
998
- "# break"
999
- ]
1000
- },
1001
- {
1002
- "cell_type": "markdown",
1003
- "id": "62ad01c3",
1004
- "metadata": {},
1005
- "source": [
1006
- "## Encoding"
1007
- ]
1008
- },
1009
- {
1010
- "cell_type": "code",
1011
- "execution_count": 158,
1012
- "id": "88f36d0b",
1013
- "metadata": {},
1014
- "outputs": [],
1015
- "source": [
1016
- "def encode(model, batch):\n",
1017
- "# print(\"jitting encode function\")\n",
1018
- " _, indices = model.encode(batch)\n",
1019
- " return indices"
1020
- ]
1021
- },
1022
- {
1023
- "cell_type": "code",
1024
- "execution_count": 160,
1025
- "id": "1f35f0cb",
1026
- "metadata": {},
1027
- "outputs": [],
1028
- "source": [
1029
- "def superbatch_generator(dataloader, num_tpus):\n",
1030
- " iter_loader = iter(dataloader)\n",
1031
- " for batch in iter_loader:\n",
1032
- " superbatch = [batch.squeeze(1)]\n",
1033
- " try:\n",
1034
- " for b in range(num_tpus-1):\n",
1035
- " batch = next(iter_loader)\n",
1036
- " if batch is None:\n",
1037
- " break\n",
1038
- " # Skip incomplete last batch\n",
1039
- " if batch.shape[0] == dataloader.batch_size:\n",
1040
- " superbatch.append(batch.squeeze(1))\n",
1041
- " except StopIteration:\n",
1042
- " pass\n",
1043
- " superbatch = torch.stack(superbatch, axis=0)\n",
1044
- " yield superbatch"
1045
- ]
1046
- },
1047
- {
1048
- "cell_type": "code",
1049
- "execution_count": 170,
1050
- "id": "2210705b",
1051
- "metadata": {},
1052
- "outputs": [],
1053
- "source": [
1054
- "import os\n",
1055
- "\n",
1056
- "def encode_captioned_dataset(dataset, output_tsv, batch_size=32, num_workers=16):\n",
1057
- " if os.path.isfile(output_tsv):\n",
1058
- " print(f\"Destination file {output_tsv} already exists, please move away.\")\n",
1059
- " return\n",
1060
- " \n",
1061
- " num_tpus = 8 \n",
1062
- " dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)\n",
1063
- " superbatches = superbatch_generator(dataloader, num_tpus=num_tpus)\n",
1064
- " \n",
1065
- " p_encoder = pmap(lambda batch: encode(model, batch))\n",
1066
- "\n",
1067
- " # We save each superbatch to avoid reallocation of buffers as we process them.\n",
1068
- " # We keep the file open to prevent excessive file seeks.\n",
1069
- " with open(output_tsv, \"w\") as file:\n",
1070
- " iterations = len(dataset) // (batch_size * num_tpus)\n",
1071
- " for n in tqdm(range(iterations)):\n",
1072
- " superbatch = next(superbatches)\n",
1073
- " encoded = p_encoder(superbatch.numpy())\n",
1074
- " encoded = encoded.reshape(-1, encoded.shape[-1])\n",
1075
- "\n",
1076
- " # Extract fields from the dataset internal `captions` property, and save to disk\n",
1077
- " start_index = n * batch_size * num_tpus\n",
1078
- " end_index = (n+1) * batch_size * num_tpus\n",
1079
- " paths = dataset.captions[\"image_file\"][start_index:end_index].values\n",
1080
- " captions = dataset.captions[\"caption\"][start_index:end_index].values\n",
1081
- " encoded_as_string = list(map(lambda item: np.array2string(item, separator=',', max_line_width=50000, formatter={'int':lambda x: str(x)}), encoded))\n",
1082
- " batch_df = pd.DataFrame.from_dict({\"image_file\": paths, \"caption\": captions, \"encoding\": encoded_as_string})\n",
1083
- " batch_df.to_csv(file, sep='\\t', header=(n==0), index=None)"
1084
- ]
1085
- },
1086
- {
1087
- "cell_type": "code",
1088
- "execution_count": 171,
1089
- "id": "7704863d",
1090
- "metadata": {},
1091
- "outputs": [
1092
- {
1093
- "name": "stderr",
1094
- "output_type": "stream",
1095
- "text": [
1096
- "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4850/4850 [2:27:51<00:00, 1.83s/it]\n"
1097
- ]
1098
- }
1099
- ],
1100
- "source": [
1101
- "encode_captioned_dataset(dataset, yfcc100m_output, batch_size=64, num_workers=16)"
1102
- ]
1103
- },
1104
- {
1105
- "cell_type": "markdown",
1106
- "id": "8953dd84",
1107
- "metadata": {},
1108
- "source": [
1109
- "----"
1110
- ]
1111
- }
1112
- ],
1113
- "metadata": {
1114
- "kernelspec": {
1115
- "name": "python3",
1116
- "display_name": "Python 3.9.0 64-bit ('Python39')"
1117
- },
1118
- "language_info": {
1119
- "codemirror_mode": {
1120
- "name": "ipython",
1121
- "version": 3
1122
- },
1123
- "file_extension": ".py",
1124
- "mimetype": "text/x-python",
1125
- "name": "python",
1126
- "nbconvert_exporter": "python",
1127
- "pygments_lexer": "ipython3",
1128
- "version": "3.9.0"
1129
- },
1130
- "interpreter": {
1131
- "hash": "db471c52d602b4f5f40ecaf278e88ccfef85c29d0a1a07185b0d51fc7acf4e26"
1132
- }
1133
- },
1134
- "nbformat": 4,
1135
- "nbformat_minor": 5
1136
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
encoding/vqgan-jax-encoding.ipynb DELETED
The diff for this file is too large to render. See raw diff
 
environment.yaml DELETED
@@ -1,10 +0,0 @@
1
- name: dalle
2
- channels:
3
- - defaults
4
- dependencies:
5
- - python=3.9.5
6
- - pip=21.1.3
7
- - ipython=7.22.0
8
- - cudatoolkit
9
- - pip:
10
- - -r requirements.txt
 
 
 
 
 
 
 
 
 
 
 
img/logo.png ADDED
model/data-pipeline.ipynb DELETED
@@ -1,385 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "markdown",
5
- "id": "bf8fb38a",
6
- "metadata": {},
7
- "source": [
8
- "# Data Pipeline"
9
- ]
10
- },
11
- {
12
- "cell_type": "code",
13
- "execution_count": 1,
14
- "id": "9b83dcb9",
15
- "metadata": {},
16
- "outputs": [],
17
- "source": [
18
- "from dataclasses import dataclass, field\n",
19
- "from pathlib import Path\n",
20
- "\n",
21
- "import datasets\n",
22
- "from datasets import Dataset, load_dataset\n",
23
- "import numpy as np\n",
24
- "\n",
25
- "from transformers import BartTokenizer\n",
26
- "\n",
27
- "from tqdm import tqdm\n",
28
- "\n",
29
- "import jax\n",
30
- "import jax.numpy as jnp\n",
31
- "\n",
32
- "from flax.training.common_utils import shard"
33
- ]
34
- },
35
- {
36
- "cell_type": "markdown",
37
- "id": "a661a89e",
38
- "metadata": {},
39
- "source": [
40
- "File containing image paths, captions and VQGAN-encoded indices."
41
- ]
42
- },
43
- {
44
- "cell_type": "code",
45
- "execution_count": 2,
46
- "id": "0e84e889",
47
- "metadata": {},
48
- "outputs": [],
49
- "source": [
50
- "datafile = '/data/CC12M/images-encoded-10000.tsv' # 9999 encoded images from CC12M"
51
- ]
52
- },
53
- {
54
- "cell_type": "markdown",
55
- "id": "7fdc640b",
56
- "metadata": {},
57
- "source": [
58
- "TODO: generate train/test splits if necessary."
59
- ]
60
- },
61
- {
62
- "cell_type": "code",
63
- "execution_count": 3,
64
- "id": "cc6789b4",
65
- "metadata": {},
66
- "outputs": [
67
- {
68
- "name": "stderr",
69
- "output_type": "stream",
70
- "text": [
71
- "Using custom data configuration default-91833df78e844785\n",
72
- "Reusing dataset csv (/home/pedro/.cache/huggingface/datasets/csv/default-91833df78e844785/0.0.0/e138af468cb14e747fb46a19c787ffcfa5170c821476d20d5304287ce12bbc23)\n"
73
- ]
74
- }
75
- ],
76
- "source": [
77
- "dataset = load_dataset('csv', delimiter='\\t', data_files=[datafile])"
78
- ]
79
- },
80
- {
81
- "cell_type": "code",
82
- "execution_count": 4,
83
- "id": "f3ed4919",
84
- "metadata": {},
85
- "outputs": [
86
- {
87
- "data": {
88
- "text/plain": [
89
- "DatasetDict({\n",
90
- " train: Dataset({\n",
91
- " features: ['image_file', 'caption', 'encoding'],\n",
92
- " num_rows: 9999\n",
93
- " })\n",
94
- "})"
95
- ]
96
- },
97
- "execution_count": 4,
98
- "metadata": {},
99
- "output_type": "execute_result"
100
- }
101
- ],
102
- "source": [
103
- "dataset"
104
- ]
105
- },
106
- {
107
- "cell_type": "code",
108
- "execution_count": 5,
109
- "id": "a70c7354",
110
- "metadata": {},
111
- "outputs": [
112
- {
113
- "data": {
114
- "text/plain": [
115
- "Dataset({\n",
116
- " features: ['image_file', 'caption', 'encoding'],\n",
117
- " num_rows: 9999\n",
118
- "})"
119
- ]
120
- },
121
- "execution_count": 5,
122
- "metadata": {},
123
- "output_type": "execute_result"
124
- }
125
- ],
126
- "source": [
127
- "dataset = dataset[\"train\"]\n",
128
- "dataset"
129
- ]
130
- },
131
- {
132
- "cell_type": "markdown",
133
- "id": "a73454cf",
134
- "metadata": {},
135
- "source": [
136
- "We don't really need the `image_file` field for training. We'll drop it during pre-processing because we won't be able to numericalize it to a `jnp.array`, which would be required in JAX."
137
- ]
138
- },
139
- {
140
- "cell_type": "markdown",
141
- "id": "7c0fa992",
142
- "metadata": {},
143
- "source": [
144
- "## Preprocessing"
145
- ]
146
- },
147
- {
148
- "cell_type": "markdown",
149
- "id": "a0e36582",
150
- "metadata": {},
151
- "source": [
152
- "The `encoding` field contains a string representation of the encoded indices. We'll convert them to numbers. We also need to tokenize the captions."
153
- ]
154
- },
155
- {
156
- "cell_type": "code",
157
- "execution_count": 6,
158
- "id": "d46f6ac5",
159
- "metadata": {},
160
- "outputs": [],
161
- "source": [
162
- "# Setting padding=\"max_length\" as we need fixed length inputs for jitted functions\n",
163
- "max_length = 256 # Read from data_args.max_source_length\n",
164
- "tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')\n",
165
- "image_bos = 16384 # Max token is 16383 in our VQGAN configuration"
166
- ]
167
- },
168
- {
169
- "cell_type": "code",
170
- "execution_count": 7,
171
- "id": "4cac6643",
172
- "metadata": {},
173
- "outputs": [],
174
- "source": [
175
- "def preprocess_function(examples):\n",
176
- " inputs = examples[\"caption\"]\n",
177
- "# inputs = [prefix + inp for inp in inputs] # Do we need this?\n",
178
- " model_inputs = tokenizer(\n",
179
- " inputs, max_length=max_length, padding=\"max_length\", truncation=True, return_tensors=\"np\"\n",
180
- " )\n",
181
- "\n",
182
- " model_inputs[\"labels\"] = [[image_bos] + eval(indices) for indices in examples['encoding']]\n",
183
- "\n",
184
- " return model_inputs"
185
- ]
186
- },
187
- {
188
- "cell_type": "code",
189
- "execution_count": 8,
190
- "id": "e6a4cb91",
191
- "metadata": {},
192
- "outputs": [],
193
- "source": [
194
- "num_workers = 48 # We have 96 processors in the TPU\n",
195
- "column_names = dataset.column_names\n",
196
- "input_dataset = dataset.map(preprocess_function,\n",
197
- " remove_columns=column_names,\n",
198
- " batched=True,\n",
199
- " num_proc=48\n",
200
- ")"
201
- ]
202
- },
203
- {
204
- "cell_type": "code",
205
- "execution_count": 9,
206
- "id": "a9b1b467",
207
- "metadata": {},
208
- "outputs": [],
209
- "source": [
210
- "def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False):\n",
211
- " \"\"\"\n",
212
- " Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.\n",
213
- " Shuffle batches if `shuffle` is `True`.\n",
214
- " \"\"\"\n",
215
- " steps_per_epoch = len(dataset) // batch_size\n",
216
- "\n",
217
- " if shuffle:\n",
218
- " batch_idx = jax.random.permutation(rng, len(dataset))\n",
219
- " else:\n",
220
- " batch_idx = jnp.arange(len(dataset))\n",
221
- "\n",
222
- " batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch.\n",
223
- " batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))\n",
224
- "\n",
225
- " for idx in batch_idx:\n",
226
- " batch = dataset[idx] \n",
227
- " batch = {k: jnp.array(v) for k, v in batch.items()}\n",
228
- " batch = shard(batch)\n",
229
- " yield batch"
230
- ]
231
- },
232
- {
233
- "cell_type": "code",
234
- "execution_count": 10,
235
- "id": "0a628505",
236
- "metadata": {},
237
- "outputs": [
238
- {
239
- "name": "stderr",
240
- "output_type": "stream",
241
- "text": [
242
- "INFO:absl:Starting the local TPU driver.\n",
243
- "INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://\n",
244
- "INFO:absl:Unable to initialize backend 'gpu': Not found: Could not find registered platform with name: \"cuda\". Available platform names are: Host TPU Interpreter\n"
245
- ]
246
- }
247
- ],
248
- "source": [
249
- "rng = jax.random.PRNGKey(23) # Use training_args.seed\n",
250
- "batch_size = 64 # Per device\n",
251
- "super_batch_size = batch_size * jax.device_count()"
252
- ]
253
- },
254
- {
255
- "cell_type": "code",
256
- "execution_count": 11,
257
- "id": "b3a5ce7d",
258
- "metadata": {},
259
- "outputs": [],
260
- "source": [
261
- "loader = data_loader(rng, input_dataset, batch_size=super_batch_size)"
262
- ]
263
- },
264
- {
265
- "cell_type": "code",
266
- "execution_count": 12,
267
- "id": "67aa8f9c",
268
- "metadata": {},
269
- "outputs": [],
270
- "source": [
271
- "superbatch = next(iter(loader))"
272
- ]
273
- },
274
- {
275
- "cell_type": "code",
276
- "execution_count": 13,
277
- "id": "7cd99402",
278
- "metadata": {},
279
- "outputs": [
280
- {
281
- "data": {
282
- "text/plain": [
283
- "dict_keys(['attention_mask', 'input_ids', 'labels'])"
284
- ]
285
- },
286
- "execution_count": 13,
287
- "metadata": {},
288
- "output_type": "execute_result"
289
- }
290
- ],
291
- "source": [
292
- "superbatch.keys()"
293
- ]
294
- },
295
- {
296
- "cell_type": "code",
297
- "execution_count": 14,
298
- "id": "652a4a9e",
299
- "metadata": {},
300
- "outputs": [
301
- {
302
- "data": {
303
- "text/plain": [
304
- "8"
305
- ]
306
- },
307
- "execution_count": 14,
308
- "metadata": {},
309
- "output_type": "execute_result"
310
- }
311
- ],
312
- "source": [
313
- "len(superbatch[\"labels\"])"
314
- ]
315
- },
316
- {
317
- "cell_type": "code",
318
- "execution_count": 15,
319
- "id": "de7de4e8",
320
- "metadata": {},
321
- "outputs": [
322
- {
323
- "data": {
324
- "text/plain": [
325
- "(8, 64, 257)"
326
- ]
327
- },
328
- "execution_count": 15,
329
- "metadata": {},
330
- "output_type": "execute_result"
331
- }
332
- ],
333
- "source": [
334
- "superbatch[\"labels\"].shape"
335
- ]
336
- },
337
- {
338
- "cell_type": "markdown",
339
- "id": "6800153b",
340
- "metadata": {},
341
- "source": [
342
- "Any image sequence should begin with `image_bos`:"
343
- ]
344
- },
345
- {
346
- "cell_type": "code",
347
- "execution_count": 16,
348
- "id": "cfe23a71",
349
- "metadata": {},
350
- "outputs": [],
351
- "source": [
352
- "assert superbatch[\"labels\"][1][5][0].item() == image_bos"
353
- ]
354
- },
355
- {
356
- "cell_type": "code",
357
- "execution_count": null,
358
- "id": "0fb899b4",
359
- "metadata": {},
360
- "outputs": [],
361
- "source": []
362
- }
363
- ],
364
- "metadata": {
365
- "kernelspec": {
366
- "display_name": "Python 3 (ipykernel)",
367
- "language": "python",
368
- "name": "python3"
369
- },
370
- "language_info": {
371
- "codemirror_mode": {
372
- "name": "ipython",
373
- "version": 3
374
- },
375
- "file_extension": ".py",
376
- "mimetype": "text/x-python",
377
- "name": "python",
378
- "nbconvert_exporter": "python",
379
- "pygments_lexer": "ipython3",
380
- "version": "3.8.10"
381
- }
382
- },
383
- "nbformat": 4,
384
- "nbformat_minor": 5
385
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pyproject.toml ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ [tool.isort]
2
+ profile = "black"
requirements.txt DELETED
@@ -1,9 +0,0 @@
1
- # Note: install with the following command:
2
- # pip install -r requirements.txt -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
3
- # Otherwise it won't find the appropriate libtpu_nightly
4
- requests
5
- jax[tpu]>=0.2.16
6
- -e git+https://github.com/huggingface/transformers.git@master#egg=transformers
7
- -e git+https://github.com/huggingface/datasets.git@master#egg=datasets
8
- flax
9
- jupyter
 
 
 
 
 
 
 
 
 
 
seq2seq/do_big_run.sh DELETED
@@ -1,16 +0,0 @@
1
- python run_seq2seq_flax.py \
2
- --max_source_length 128 \
3
- --train_file /data/CC12M/encoded-small-train.tsv \ # ignored for now in our script
4
- --validation_file /data/CC12M/encoded-small-valid.tsv \ # ignored for now in our script
5
- --output_dir output \
6
- --per_device_train_batch_size 56 \
7
- --per_device_eval_batch_size 56 \
8
- --preprocessing_num_workers 80 \
9
- --warmup_steps 125 \
10
- --gradient_accumulation_steps 8 \
11
- --do_train \
12
- --do_eval \
13
- --adafactor \
14
- --num_train_epochs 10 \
15
- --log_model \
16
- --learning_rate 0.001
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
seq2seq/do_small_run.sh DELETED
@@ -1,16 +0,0 @@
1
- python run_seq2seq_flax.py \
2
- --max_source_length 128 \
3
- --train_file /data/CC12M/encoded-small-train.tsv \ # ignored for now in our script
4
- --validation_file /data/CC12M/encoded-small-valid.tsv \ # ignored for now in our script
5
- --output_dir output \
6
- --per_device_train_batch_size 56 \
7
- --per_device_eval_batch_size 56 \
8
- --preprocessing_num_workers 80 \
9
- --warmup_steps 125 \
10
- --gradient_accumulation_steps 8 \
11
- --do_train \
12
- --do_eval \
13
- --adafactor \
14
- --num_train_epochs 1 \
15
- --max_train_samples 20000 \
16
- --learning_rate 0.003
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
seq2seq/requirements.txt DELETED
@@ -1,8 +0,0 @@
1
- datasets >= 1.1.3
2
- jax>=0.2.8
3
- jaxlib>=0.1.59
4
- flax>=0.3.4
5
- optax>=0.0.8
6
- tensorboard
7
- nltk
8
- wandb
 
 
 
 
 
 
 
 
 
seq2seq/run_seq2seq_flax.py DELETED
@@ -1,897 +0,0 @@
1
- #!/usr/bin/env python
2
- # coding=utf-8
3
- # Copyright 2021 The HuggingFace Team All rights reserved.
4
- #
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
- # limitations under the License.
16
- """
17
- Fine-tuning the library models for seq2seq, text to image.
18
- Script adapted from run_summarization_flax.py
19
- """
20
- # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
21
-
22
- import os
23
- # set a common huggingface cache folder (used with datasets and transformers) and wandb cache folder (used with artifacts)
24
- os.environ['HF_HOME'] = '/data/huggingface/' # required before importing transformers & datasets
25
- os.environ['WANDB_CACHE_DIR'] = '/data/wandb/' # required before importing wandb
26
-
27
- import logging as pylogging # To avoid collision with transformers.utils.logging
28
- import sys
29
- import time
30
- from dataclasses import dataclass, field
31
- from functools import partial
32
- from pathlib import Path
33
- from typing import Callable, Optional
34
-
35
- import datasets
36
- import nltk # Here to have a nice missing dependency error message early on
37
- import numpy as np
38
- from datasets import Dataset, load_dataset, load_metric
39
- from tqdm import tqdm
40
-
41
- import jax
42
- import jax.numpy as jnp
43
- import optax
44
- import transformers
45
- from filelock import FileLock
46
- from flax import jax_utils, traverse_util
47
- import flax.linen as nn
48
- from flax.jax_utils import unreplicate
49
- from flax.training import train_state
50
- from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
51
- from transformers import (
52
- CONFIG_MAPPING,
53
- FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
54
- AutoConfig,
55
- AutoTokenizer,
56
- FlaxAutoModelForSeq2SeqLM,
57
- FlaxBartForConditionalGeneration,
58
- HfArgumentParser,
59
- TrainingArguments,
60
- )
61
- from transformers.models.bart.modeling_flax_bart import *
62
- from transformers.file_utils import is_offline_mode
63
-
64
- import wandb
65
-
66
- logger = pylogging.getLogger(__name__)
67
-
68
- try:
69
- nltk.data.find("tokenizers/punkt")
70
- except (LookupError, OSError):
71
- if is_offline_mode():
72
- raise LookupError(
73
- "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files"
74
- )
75
- with FileLock(".lock") as lock:
76
- nltk.download("punkt", quiet=True)
77
-
78
-
79
- MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.keys())
80
- MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
81
-
82
-
83
- # Model hyperparameters, for convenience
84
- OUTPUT_VOCAB_SIZE = 16384 + 1 # encoded image token space + 1 for bos
85
- OUTPUT_LENGTH = 256 + 1 # number of encoded tokens + 1 for bos
86
- BOS_TOKEN_ID = 16384
87
- BASE_MODEL = 'facebook/bart-large-cnn' # we currently have issues with bart-large
88
-
89
-
90
- @dataclass
91
- class ModelArguments:
92
- """
93
- Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
94
- """
95
-
96
- model_name_or_path: Optional[str] = field(
97
- default=BASE_MODEL,
98
- metadata={
99
- "help": "The model checkpoint for weights initialization."
100
- "Don't set if you want to train a model from scratch."
101
- },
102
- )
103
- model_type: Optional[str] = field(
104
- default=None,
105
- metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
106
- )
107
- config_name: Optional[str] = field(
108
- default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
109
- )
110
- tokenizer_name: Optional[str] = field(
111
- default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
112
- )
113
- cache_dir: Optional[str] = field(
114
- default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
115
- )
116
- use_fast_tokenizer: bool = field(
117
- default=True,
118
- metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
119
- )
120
- dtype: Optional[str] = field(
121
- default="float32",
122
- metadata={
123
- "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
124
- },
125
- )
126
-
127
-
128
- @dataclass
129
- class DataTrainingArguments:
130
- """
131
- Arguments pertaining to what data we are going to input our model for training and eval.
132
- """
133
-
134
- dataset_name: Optional[str] = field(
135
- default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
136
- )
137
- dataset_config_name: Optional[str] = field(
138
- default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
139
- )
140
- text_column: Optional[str] = field(
141
- default='caption',
142
- metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
143
- )
144
- encoding_column: Optional[str] = field(
145
- default='encoding',
146
- metadata={"help": "The name of the column in the datasets containing the image encodings."},
147
- )
148
- train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
149
- validation_file: Optional[str] = field(
150
- default=None,
151
- metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
152
- )
153
- test_file: Optional[str] = field(
154
- default=None,
155
- metadata={"help": "An optional input predict data file to do prediction on (a text file)."},
156
- )
157
- max_source_length: Optional[int] = field(
158
- default=128,
159
- metadata={
160
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
161
- "than this will be truncated, sequences shorter will be padded."
162
- },
163
- )
164
- no_decay: bool = field(
165
- default=False, metadata={"help": "Whether to use decay in the learning rate scheduler."}
166
- )
167
- max_target_length: Optional[int] = field(
168
- default=OUTPUT_LENGTH,
169
- metadata={
170
- "help": "The maximum total sequence length for target text after tokenization. Sequences longer "
171
- "than this will be truncated, sequences shorter will be padded."
172
- },
173
- )
174
- val_max_target_length: Optional[int] = field(
175
- default=OUTPUT_LENGTH,
176
- metadata={
177
- "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
178
- "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
179
- "This argument is also used to override the `max_length` param of `model.generate`, which is used "
180
- "during evaluation."
181
- },
182
- )
183
- max_train_samples: Optional[int] = field(
184
- default=None,
185
- metadata={
186
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
187
- "value if set."
188
- },
189
- )
190
- max_eval_samples: Optional[int] = field(
191
- default=None,
192
- metadata={
193
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
194
- "value if set."
195
- },
196
- )
197
- max_predict_samples: Optional[int] = field(
198
- default=None,
199
- metadata={
200
- "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
201
- "value if set."
202
- },
203
- )
204
- preprocessing_num_workers: Optional[int] = field(
205
- default=80, # ensure we have the same datasets cached data and avoid using too much space
206
- metadata={"help": "The number of processes to use for the preprocessing."},
207
- )
208
- source_prefix: Optional[str] = field(
209
- default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
210
- )
211
- predict_with_generate: bool = field(
212
- default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."}
213
- )
214
- num_beams: Optional[int] = field(
215
- default=None,
216
- metadata={
217
- "help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`, "
218
- "which is used during evaluation."
219
- },
220
- )
221
- overwrite_cache: bool = field(
222
- default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
223
- )
224
- log_interval: Optional[int] = field(
225
- default=40,
226
- metadata={
227
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
228
- "value if set."
229
- },
230
- )
231
- log_model: bool = field(
232
- default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
233
- )
234
- save_model_steps: Optional[int] = field(
235
- default=3000, # about once every hour in our experiments
236
- metadata={
237
- "help": "For logging the model more frequently. Used only when `log_model` is set."
238
- },
239
- )
240
-
241
- def __post_init__(self):
242
- if self.dataset_name is None and self.train_file is None and self.validation_file is None:
243
- raise ValueError("Need either a dataset name or a training/validation file.")
244
- else:
245
- if self.train_file is not None:
246
- extension = self.train_file.split(".")[-1]
247
- assert extension in ["tsv", "csv", "json"], "`train_file` should be a tsv, csv or json file."
248
- if self.validation_file is not None:
249
- extension = self.validation_file.split(".")[-1]
250
- assert extension in ["tsv", "csv", "json"], "`validation_file` should be a tsv, csv or json file."
251
- if self.val_max_target_length is None:
252
- self.val_max_target_length = self.max_target_length
253
-
254
-
255
- class TrainState(train_state.TrainState):
256
- dropout_rng: jnp.ndarray
257
- grad_accum: jnp.ndarray
258
- optimizer_step: int
259
-
260
- def replicate(self):
261
- return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
262
-
263
-
264
- class CustomFlaxBartModule(FlaxBartModule):
265
- def setup(self):
266
- # we keep shared to easily load pre-trained weights
267
- self.shared = nn.Embed(
268
- self.config.vocab_size,
269
- self.config.d_model,
270
- embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
271
- dtype=self.dtype,
272
- )
273
- # a separate embedding is used for the decoder
274
- self.decoder_embed = nn.Embed(
275
- OUTPUT_VOCAB_SIZE,
276
- self.config.d_model,
277
- embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
278
- dtype=self.dtype,
279
- )
280
- self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
281
-
282
- # the decoder has a different config
283
- decoder_config = BartConfig(self.config.to_dict())
284
- decoder_config.max_position_embeddings = OUTPUT_LENGTH
285
- decoder_config.min_length = OUTPUT_LENGTH
286
- decoder_config.max_length = OUTPUT_LENGTH
287
- decoder_config.vocab_size = OUTPUT_VOCAB_SIZE
288
- self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)
289
-
290
- class CustomFlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):
291
- def setup(self):
292
- self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)
293
- self.lm_head = nn.Dense(
294
- OUTPUT_VOCAB_SIZE,
295
- use_bias=False,
296
- dtype=self.dtype,
297
- kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
298
- )
299
- self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, OUTPUT_VOCAB_SIZE))
300
-
301
- class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):
302
- module_class = CustomFlaxBartForConditionalGenerationModule
303
-
304
-
305
- def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False):
306
- """
307
- Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
308
- Shuffle batches if `shuffle` is `True`.
309
- """
310
- steps_per_epoch = len(dataset) // batch_size
311
-
312
- if shuffle:
313
- batch_idx = jax.random.permutation(rng, len(dataset))
314
- else:
315
- batch_idx = jnp.arange(len(dataset))
316
-
317
- batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch.
318
- batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
319
-
320
- for idx in batch_idx:
321
- batch = dataset[idx]
322
- batch = {k: jnp.array(v) for k, v in batch.items()}
323
-
324
- batch = shard(batch)
325
-
326
- yield batch
327
-
328
-
329
- def create_learning_rate_fn(
330
- train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float, no_decay: bool
331
- ) -> Callable[[int], jnp.array]:
332
- """Returns a linear warmup, linear_decay learning rate function."""
333
- steps_per_epoch = train_ds_size // train_batch_size
334
- num_train_steps = steps_per_epoch * num_train_epochs
335
- warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
336
- if no_decay:
337
- return warmup_fn
338
- decay_fn = optax.linear_schedule(
339
- init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
340
- )
341
- schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
342
- return schedule_fn
343
-
344
-
345
- def wandb_log(metrics, step=None, prefix=None):
346
- if jax.process_index() == 0:
347
- log_metrics = {f'{prefix}/{k}' if prefix is not None else k: jax.device_get(v) for k,v in metrics.items()}
348
- if step is not None:
349
- log_metrics['train/step'] = step
350
- wandb.log(log_metrics)
351
-
352
-
353
- def main():
354
- # See all possible arguments in src/transformers/training_args.py
355
- # or by passing the --help flag to this script.
356
- # We now keep distinct sets of args, for a cleaner separation of concerns.
357
-
358
- parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
359
- if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
360
- # If we pass only one argument to the script and it's the path to a json file,
361
- # let's parse it to get our arguments.
362
- model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
363
- else:
364
- model_args, data_args, training_args = parser.parse_args_into_dataclasses()
365
-
366
- logger.warning(f"eval_steps has been manually hardcoded") # TODO: remove it later, convenient for now
367
- training_args.eval_steps = 400
368
-
369
- if (
370
- os.path.exists(training_args.output_dir)
371
- and os.listdir(training_args.output_dir)
372
- and training_args.do_train
373
- and not training_args.overwrite_output_dir
374
- ):
375
- raise ValueError(
376
- f"Output directory ({training_args.output_dir}) already exists and is not empty."
377
- "Use --overwrite_output_dir to overcome."
378
- )
379
-
380
- # Set up wandb run
381
- wandb.init(
382
- entity='wandb',
383
- project='hf-flax-dalle-mini',
384
- job_type='Seq2SeqVQGAN',
385
- config=parser.parse_args()
386
- )
387
-
388
- # set default x-axis as 'train/step'
389
- wandb.define_metric('train/step')
390
- wandb.define_metric('*', step_metric='train/step')
391
-
392
- # Make one log on every process with the configuration for debugging.
393
- pylogging.basicConfig(
394
- format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
395
- datefmt="%m/%d/%Y %H:%M:%S",
396
- level=pylogging.INFO,
397
- )
398
- # Setup logging, we only want one process per machine to log things on the screen.
399
- logger.setLevel(pylogging.INFO if jax.process_index() == 0 else pylogging.ERROR)
400
- if jax.process_index() == 0:
401
- datasets.utils.logging.set_verbosity_warning()
402
- transformers.utils.logging.set_verbosity_info()
403
- else:
404
- datasets.utils.logging.set_verbosity_error()
405
- transformers.utils.logging.set_verbosity_error()
406
-
407
- # Set the verbosity to info of the Transformers logger (on main process only):
408
- logger.info(f"Training/evaluation parameters {training_args}")
409
-
410
- # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
411
- # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
412
- # (the dataset will be downloaded automatically from the datasets Hub).
413
- #
414
- data_files = {}
415
- logger.warning(f"Datasets path have been manually hardcoded") # TODO: remove it later, convenient for now
416
- if data_args.train_file is not None:
417
- data_files["train"] = ["/data/CC3M/training-encoded.tsv", "/data/CC12M/encoded-train.tsv"]
418
- if data_args.validation_file is not None:
419
- data_files["validation"] = ["/data/CC3M/validation-encoded.tsv"]
420
- if data_args.test_file is not None:
421
- data_files["test"] = data_args.test_file
422
- dataset = load_dataset("csv", data_files=data_files, cache_dir=model_args.cache_dir, delimiter="\t")
423
- # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
424
- # https://huggingface.co/docs/datasets/loading_datasets.html.
425
-
426
- # Load pretrained model and tokenizer
427
- base_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
428
- model_args.model_name_or_path, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
429
- )
430
- tokenizer = AutoTokenizer.from_pretrained(
431
- model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
432
- )
433
-
434
- # Set up our new model config
435
- config = BartConfig.from_pretrained(model_args.model_name_or_path)
436
- config.tie_word_embeddings = False
437
- config.decoder_start_token_id = BOS_TOKEN_ID
438
- config.bos_token_id = BOS_TOKEN_ID # should not be used
439
- config.pos_token_id = BOS_TOKEN_ID # should not be needed (as we generate until max_length)
440
- config.eos_token_id = BOS_TOKEN_ID + 1 # unreachable
441
- config.forced_bos_token_id = None # we don't need this token
442
- config.forced_eos_token_id = None # we don't need this token
443
- #config.min_length = data_args.max_target_length # Set only in decoder?
444
- #config.max_length = data_args.max_target_length # Set only in decoder?
445
-
446
- print(f"TPUs: {jax.device_count()}")
447
- assert jax.device_count() == 8, "TPUs in use, please check running processes"
448
-
449
- # Create a custom model and initialize it randomly
450
- model = CustomFlaxBartForConditionalGeneration(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
451
-
452
- # Use pre-trained weights for encoder
453
- model.params['model']['encoder'] = base_model.params['model']['encoder']
454
- model.params['model']['shared'] = base_model.params['model']['shared']
455
- del base_model
456
-
457
- prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
458
-
459
- # Preprocessing the datasets.
460
- # We need to tokenize inputs and targets.
461
- if training_args.do_train:
462
- column_names = dataset["train"].column_names
463
- elif training_args.do_eval:
464
- column_names = dataset["validation"].column_names
465
- elif training_args.do_predict:
466
- column_names = dataset["test"].column_names
467
- else:
468
- logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
469
- return
470
-
471
- # Get the column names for input/target.
472
- text_column = data_args.text_column
473
- encoding_column = data_args.encoding_column
474
-
475
- # Temporarily set max_target_length for training.
476
- max_target_length = data_args.max_target_length
477
-
478
- def shift_tokens_right(input_ids: np.array, decoder_start_token_id: int):
479
- """
480
- Shift input ids one token to the right.
481
- """
482
- shifted_input_ids = np.zeros(input_ids.shape)
483
- shifted_input_ids[:, 1:] = input_ids[:, :-1]
484
- shifted_input_ids[:, 0] = decoder_start_token_id
485
- return shifted_input_ids
486
-
487
- def preprocess_function(examples):
488
- inputs = examples[text_column]
489
- inputs = [prefix + inp for inp in inputs]
490
- # Setting padding="max_length" as we need fixed length inputs for jitted functions
491
- model_inputs = tokenizer(
492
- inputs, max_length=data_args.max_source_length, padding="max_length", truncation=True, return_tensors="np"
493
- )
494
-
495
- # set up targets
496
- # Note: labels correspond to our target indices
497
- # decoder input ids are the same but shifted to the right with bos at the beginning (and without last token)
498
- labels = [eval(indices) for indices in examples['encoding']]
499
- labels = np.asarray(labels)
500
-
501
- # We need the labels, in addition to the decoder_input_ids, for the compute_loss function
502
- model_inputs["labels"] = labels
503
-
504
- # In our case, this prepends the bos token and removes the last one
505
- decoder_input_ids = shift_tokens_right(labels, config.decoder_start_token_id)
506
- model_inputs["decoder_input_ids"] = decoder_input_ids
507
-
508
- return model_inputs
509
-
510
- if training_args.do_train:
511
- if "train" not in dataset:
512
- raise ValueError("--do_train requires a train dataset")
513
- train_dataset = dataset["train"]
514
- if data_args.max_train_samples is not None:
515
- train_dataset = train_dataset.select(range(data_args.max_train_samples))
516
- train_dataset = train_dataset.map(
517
- preprocess_function,
518
- batched=True,
519
- num_proc=data_args.preprocessing_num_workers,
520
- remove_columns=column_names,
521
- load_from_cache_file=not data_args.overwrite_cache,
522
- desc="Running tokenizer on train dataset",
523
- )
524
-
525
- if training_args.do_eval:
526
- max_target_length = data_args.val_max_target_length
527
- if "validation" not in dataset:
528
- raise ValueError("--do_eval requires a validation dataset")
529
- eval_dataset = dataset["validation"]
530
- if data_args.max_eval_samples is not None:
531
- eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
532
- eval_dataset = eval_dataset.map(
533
- preprocess_function,
534
- batched=True,
535
- num_proc=data_args.preprocessing_num_workers,
536
- remove_columns=column_names,
537
- load_from_cache_file=not data_args.overwrite_cache,
538
- desc="Running tokenizer on validation dataset",
539
- )
540
-
541
- if training_args.do_predict:
542
- max_target_length = data_args.val_max_target_length
543
- if "test" not in dataset:
544
- raise ValueError("--do_predict requires a test dataset")
545
- predict_dataset = dataset["test"]
546
- if data_args.max_predict_samples is not None:
547
- predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
548
- predict_dataset = predict_dataset.map(
549
- preprocess_function,
550
- batched=True,
551
- num_proc=data_args.preprocessing_num_workers,
552
- remove_columns=column_names,
553
- load_from_cache_file=not data_args.overwrite_cache,
554
- desc="Running tokenizer on prediction dataset",
555
- )
556
-
557
- # Metric
558
- #metric = load_metric("rouge")
559
-
560
- def postprocess_text(preds, labels):
561
- preds = [pred.strip() for pred in preds]
562
- labels = [label.strip() for label in labels]
563
-
564
- # rougeLSum expects newline after each sentence
565
- preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
566
- labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]
567
-
568
- return preds, labels
569
-
570
- def compute_metrics(preds, labels):
571
- decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
572
- decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
573
-
574
- # Some simple post-processing
575
- decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
576
-
577
- result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
578
- # Extract a few results from ROUGE
579
- result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
580
-
581
- prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
582
- result["gen_len"] = np.mean(prediction_lens)
583
- result = {k: round(v, 4) for k, v in result.items()}
584
- return result
585
-
586
- # Initialize our training
587
- rng = jax.random.PRNGKey(training_args.seed)
588
- rng, dropout_rng = jax.random.split(rng)
589
-
590
- # Store some constant
591
- num_epochs = int(training_args.num_train_epochs)
592
- train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
593
- total_batch_size = int(train_batch_size) * training_args.gradient_accumulation_steps
594
- eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
595
- steps_per_epoch = len(train_dataset) // train_batch_size
596
- total_steps = steps_per_epoch * num_epochs
597
- total_optimization_steps = (len(train_dataset) // total_batch_size) * num_epochs
598
-
599
- # Create learning rate schedule
600
- linear_decay_lr_schedule_fn = create_learning_rate_fn(
601
- len(train_dataset),
602
- total_batch_size,
603
- training_args.num_train_epochs,
604
- training_args.warmup_steps,
605
- training_args.learning_rate,
606
- data_args.no_decay
607
- )
608
-
609
- # We use Optax's "masking" functionality to not apply weight decay
610
- # to bias and LayerNorm scale parameters. decay_mask_fn returns a
611
- # mask boolean with the same structure as the parameters.
612
- # The mask is True for parameters that should be decayed.
613
- # Note that this mask is specifically adapted for FlaxBart.
614
- # For FlaxT5, one should correct the layer norm parameter naming
615
- # accordingly - see `run_t5_mlm_flax.py` e.g.
616
- def decay_mask_fn(params):
617
- flat_params = traverse_util.flatten_dict(params)
618
- layer_norm_params = [
619
- (name, "scale") for name in ["self_attn_layer_norm", "layernorm_embedding", "final_layer_norm"]
620
- ]
621
- flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_params) for path in flat_params}
622
- return traverse_util.unflatten_dict(flat_mask)
623
-
624
- # create adam optimizer
625
- if training_args.adafactor:
626
- # We use the default parameters here to initialize adafactor,
627
- # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
628
- optimizer = optax.adafactor(
629
- learning_rate=linear_decay_lr_schedule_fn,
630
- )
631
- else:
632
- optimizer = optax.adamw(
633
- learning_rate=linear_decay_lr_schedule_fn,
634
- b1=training_args.adam_beta1,
635
- b2=training_args.adam_beta2,
636
- eps=training_args.adam_epsilon,
637
- weight_decay=training_args.weight_decay,
638
- mask=decay_mask_fn,
639
- )
640
-
641
- # Setup train state
642
- state = TrainState.create(
643
- apply_fn=model.__call__,
644
- params=model.params,
645
- tx=optimizer,
646
- dropout_rng=dropout_rng,
647
- grad_accum=jax.tree_map(jnp.zeros_like, model.params),
648
- optimizer_step=0,
649
- )
650
-
651
- # label smoothed cross entropy
652
- def loss_fn(logits, labels):
653
- loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1]))
654
- loss = loss.mean()
655
- return loss
656
-
657
- # Define gradient update step fn
658
- def train_step(state, batch):
659
- dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
660
-
661
- def compute_loss(params):
662
- labels = batch.pop("labels")
663
- logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
664
- loss = loss_fn(logits, labels)
665
- return loss
666
-
667
- grad_fn = jax.value_and_grad(compute_loss)
668
- loss, grads = grad_fn(state.params)
669
- grad_accum = jax.tree_multimap(lambda x, y: x + y, grads, state.grad_accum)
670
-
671
- def update_fn():
672
- grads = jax.tree_map(lambda x: x / training_args.gradient_accumulation_steps, grad_accum)
673
- grads = jax.lax.pmean(grads, "batch")
674
- new_state = state.apply_gradients(
675
- grads=grads, grad_accum=jax.tree_map(jnp.zeros_like, grads), optimizer_step=state.optimizer_step + 1
676
- )
677
- return new_state
678
-
679
- new_state = jax.lax.cond(
680
- (state.step + 1) % training_args.gradient_accumulation_steps == 0,
681
- lambda _: update_fn(),
682
- lambda _: state.replace(grad_accum=grad_accum, step=state.step + 1),
683
- None,
684
- )
685
-
686
- metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.optimizer_step)}
687
- metrics = jax.lax.pmean(metrics, axis_name="batch")
688
-
689
- return new_state.replace(dropout_rng=new_dropout_rng), metrics
690
-
691
- # Define eval fn
692
- def eval_step(params, batch):
693
- labels = batch.pop("labels")
694
- logits = model(**batch, params=params, train=False)[0]
695
- loss = loss_fn(logits, labels)
696
-
697
- # summarize metrics
698
- metrics = {"loss": loss}
699
- metrics = jax.lax.pmean(metrics, axis_name="batch")
700
- return metrics
701
-
702
- # Define generation function
703
- max_length = (
704
- data_args.val_max_target_length if data_args.val_max_target_length is not None else model.config.max_length
705
- )
706
- num_beams = data_args.num_beams if data_args.num_beams is not None else model.config.num_beams
707
- gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
708
-
709
- def generate_step(params, batch):
710
- model.params = params
711
- output_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], **gen_kwargs)
712
- return output_ids.sequences
713
-
714
- # Create parallel version of the train and eval step
715
- p_train_step = jax.pmap(
716
- train_step, "batch", donate_argnums=(0,)
717
- )
718
- p_eval_step = jax.pmap(eval_step, "batch")
719
- p_generate_step = jax.pmap(generate_step, "batch")
720
-
721
- # Replicate the train state on each device
722
- state = state.replicate()
723
-
724
- logger.info("***** Running training *****")
725
- logger.info(f" Num examples = {len(train_dataset)}")
726
- logger.info(f" Num Epochs = {num_epochs}")
727
- logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
728
- logger.info(
729
- f" Total train batch size (w. parallel & distributed) = {train_batch_size * training_args.gradient_accumulation_steps}"
730
- )
731
- logger.info(f" Total global steps = {total_steps}")
732
- logger.info(f" Total optimization steps = {total_optimization_steps}")
733
-
734
- train_time = 0
735
- epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
736
- global_step = 0
737
-
738
- def run_evaluation():
739
- # ======================== Evaluating ==============================
740
- eval_metrics = []
741
- if training_args.do_eval:
742
- eval_preds = []
743
- eval_labels = []
744
-
745
- eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
746
- eval_steps = len(eval_dataset) // eval_batch_size
747
- for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
748
- # Model forward
749
- batch = next(eval_loader)
750
- labels = batch["labels"]
751
-
752
- metrics = p_eval_step(state.params, batch)
753
- eval_metrics.append(metrics)
754
-
755
- # generation
756
- if data_args.predict_with_generate:
757
- generated_ids = p_generate_step(state.params, batch)
758
- eval_preds.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
759
- eval_labels.extend(jax.device_get(labels.reshape(-1, labels.shape[-1])))
760
-
761
- # normalize eval metrics
762
- eval_metrics = get_metrics(eval_metrics)
763
- eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
764
-
765
- # log metrics
766
- wandb_log(eval_metrics, step=global_step, prefix='eval')
767
-
768
- # compute ROUGE metrics
769
- rouge_desc = ""
770
- # if data_args.predict_with_generate:
771
- # rouge_metrics = compute_metrics(eval_preds, eval_labels)
772
- # eval_metrics.update(rouge_metrics)
773
- # rouge_desc = " ".join([f"Eval {key}: {value} |" for key, value in rouge_metrics.items()])
774
-
775
- # Print metrics and update progress bar
776
- desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | {rouge_desc})"
777
- epochs.write(desc)
778
- epochs.desc = desc
779
-
780
- return eval_metrics
781
-
782
- def run_save_model(step, epoch, eval_metrics=None):
783
- if jax.process_index() == 0:
784
- params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
785
-
786
- # save model locally
787
- model.save_pretrained(
788
- training_args.output_dir,
789
- params=params,
790
- )
791
-
792
- # save to W&B
793
- if data_args.log_model:
794
- metadata = {'step': step, 'epoch': epoch}
795
- if eval_metrics is not None:
796
- metadata['eval/loss'] = eval_metrics['loss']
797
- artifact = wandb.Artifact(
798
- name=f"model-{wandb.run.id}", type="bart_model", metadata=metadata
799
- )
800
- artifact.add_file(str(Path(training_args.output_dir) / 'flax_model.msgpack'))
801
- artifact.add_file(str(Path(training_args.output_dir) / 'config.json'))
802
- wandb.run.log_artifact(artifact)
803
-
804
- # save to the hub
805
- if training_args.push_to_hub:
806
- model.save_pretrained(
807
- training_args.output_dir,
808
- params=params,
809
- push_to_hub=training_args.push_to_hub,
810
- commit_message=f"Saving weights and logs of epoch {epoch+1}",
811
- temp_dir=True # avoid issues with being in a repository
812
- )
813
-
814
- for epoch in epochs:
815
- # ======================== Training ================================
816
- train_start = time.time()
817
-
818
- # Create sampling rng
819
- rng, input_rng = jax.random.split(rng)
820
-
821
- # Generate an epoch by shuffling sampling indices from the train dataset
822
- train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
823
- steps_per_epoch = len(train_dataset) // train_batch_size
824
- # train
825
- for step in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
826
- global_step +=1
827
- batch = next(train_loader)
828
- state, train_metric = p_train_step(state, batch)
829
-
830
- if global_step % data_args.log_interval == 0 and jax.process_index() == 0:
831
- # log metrics
832
- wandb_log(unreplicate(train_metric), step=global_step, prefix='train')
833
-
834
- if global_step % training_args.eval_steps == 0:
835
- run_evaluation()
836
-
837
- if global_step % data_args.save_model_steps == 0:
838
- run_save_model(global_step, epoch)
839
-
840
- # log final train metrics
841
- wandb_log(unreplicate(train_metric), step=global_step, prefix='train')
842
-
843
- train_time += time.time() - train_start
844
- train_metric = unreplicate(train_metric)
845
- epochs.write(
846
- f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
847
- )
848
-
849
- # Final evaluation
850
- eval_metrics = run_evaluation()
851
-
852
- # save checkpoint after each epoch and push checkpoint to the hub
853
- run_save_model(global_step, epoch, eval_metrics)
854
-
855
-
856
- # ======================== Prediction loop ==============================
857
- if training_args.do_predict:
858
- logger.info("*** Predict ***")
859
-
860
- pred_metrics = []
861
- pred_generations = []
862
- pred_labels = []
863
-
864
- pred_loader = data_loader(input_rng, predict_dataset, eval_batch_size)
865
- pred_steps = len(predict_dataset) // eval_batch_size
866
- for _ in tqdm(range(pred_steps), desc="Predicting...", position=2, leave=False):
867
- # Model forward
868
- batch = next(pred_loader)
869
- labels = batch["labels"]
870
-
871
- metrics = p_eval_step(state.params, batch)
872
- pred_metrics.append(metrics)
873
-
874
- # generation
875
- if data_args.predict_with_generate:
876
- generated_ids = p_generate_step(state.params, batch)
877
- pred_generations.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
878
- pred_labels.extend(jax.device_get(labels.reshape(-1, labels.shape[-1])))
879
-
880
- # normalize prediction metrics
881
- pred_metrics = get_metrics(pred_metrics)
882
- pred_metrics = jax.tree_map(jnp.mean, pred_metrics)
883
-
884
- # compute ROUGE metrics
885
- rouge_desc = ""
886
- if data_args.predict_with_generate:
887
- rouge_metrics = compute_metrics(pred_generations, pred_labels)
888
- pred_metrics.update(rouge_metrics)
889
- rouge_desc = " ".join([f"Predict {key}: {value} |" for key, value in rouge_metrics.items()])
890
-
891
- # Print metrics
892
- desc = f"Predict Loss: {pred_metrics['loss']} | {rouge_desc})"
893
- logger.info(desc)
894
-
895
-
896
- if __name__ == "__main__":
897
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
setup.cfg ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [metadata]
2
+ name = dalle_mini
3
+ version = attr: dalle_mini.__version__
4
+ description = DALL·E mini - Generate images from a text prompt
5
+ long_description = file: README.md
6
+ long_description_content_type = text/markdown
7
+ url = https://github.com/borisdayma/dalle-mini
8
+ project_urls =
9
+ Bug Tracker = https://github.com/borisdayma/dalle-mini/issues
10
+
11
+ [options]
12
+ packages = find:
13
+ install_requires =
14
+ transformers
15
+ unidecode
16
+ ftfy
17
+ pillow
18
+ jax
19
+ flax
20
+
21
+ [options.extras_require]
22
+ dev =
23
+ tqdm
24
+ wandb
25
+ optax
26
+ black[jupyter]
27
+ isort
setup.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from setuptools import setup
2
+
3
+ if __name__ == "__main__":
4
+ setup()
tools/dataset/encode_dataset.ipynb ADDED
@@ -0,0 +1,371 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "d0b72877",
6
+ "metadata": {},
7
+ "source": [
8
+ "# Pre-encoding a dataset for DALLE·mini"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "markdown",
13
+ "id": "ba7b31e6",
14
+ "metadata": {},
15
+ "source": [
16
+ "This notebook shows how to pre-encode images to token sequences using JAX, VQGAN and a dataset in the [`webdataset` format](https://webdataset.github.io/webdataset/).\n",
17
+ "\n",
18
+ "Adapt it to your own dataset and image encoder.\n",
19
+ "\n",
20
+ "At the end you should have a dataset of pairs:\n",
21
+ "* a caption defined as a string\n",
22
+ "* an encoded image defined as a list of int."
23
+ ]
24
+ },
25
+ {
26
+ "cell_type": "code",
27
+ "execution_count": null,
28
+ "id": "3b59489e",
29
+ "metadata": {},
30
+ "outputs": [],
31
+ "source": [
32
+ "from tqdm.notebook import tqdm\n",
33
+ "\n",
34
+ "import torchvision.transforms as T\n",
35
+ "\n",
36
+ "import webdataset as wds\n",
37
+ "\n",
38
+ "import jax\n",
39
+ "import braceexpand\n",
40
+ "from pathlib import Path"
41
+ ]
42
+ },
43
+ {
44
+ "cell_type": "markdown",
45
+ "id": "c7c4c1e6",
46
+ "metadata": {},
47
+ "source": [
48
+ "## Configuration Parameters"
49
+ ]
50
+ },
51
+ {
52
+ "cell_type": "code",
53
+ "execution_count": 3,
54
+ "id": "1265dbfe",
55
+ "metadata": {},
56
+ "outputs": [],
57
+ "source": [
58
+ "shards = \"my_images/shard-{0000..0008}.tar\" # defined using braceexpand format as used by webdataset\n",
59
+ "encoded_output = Path(\"encoded_data\") # where we will save our encoded data\n",
60
+ "\n",
61
+ "VQGAN_REPO, VQGAN_COMMIT_ID = (\n",
62
+ " \"dalle-mini/vqgan_imagenet_f16_16384\",\n",
63
+ " \"85eb5d3b51a1c62a0cc8f4ccdee9882c0d0bd384\",\n",
64
+ ")\n",
65
+ "\n",
66
+ "# good defaults for a TPU v3-8\n",
67
+ "batch_size = 128 # Per device\n",
68
+ "num_workers = 8 # For parallel processing\n",
69
+ "total_bs = batch_size * jax.device_count() # You can use a smaller size while testing\n",
70
+ "save_frequency = 128 # Number of batches to create a new file (180MB for f16 and 720MB for f8 per file)"
71
+ ]
72
+ },
73
+ {
74
+ "cell_type": "code",
75
+ "execution_count": 5,
76
+ "id": "cd956ec6-7d98-4d4d-a454-f80fe857eadd",
77
+ "metadata": {},
78
+ "outputs": [
79
+ {
80
+ "data": {
81
+ "text/plain": [
82
+ "['XXX/shard-0000.tar',\n",
83
+ " 'XXX/shard-0001.tar',\n",
84
+ " 'XXX/shard-0002.tar',\n",
85
+ " 'XXX/shard-0003.tar',\n",
86
+ " 'XXX/shard-0004.tar',\n",
87
+ " 'XXX/shard-0005.tar',\n",
88
+ " 'XXX/shard-0006.tar',\n",
89
+ " 'XXX/shard-0007.tar',\n",
90
+ " 'XXX/shard-0008.tar']"
91
+ ]
92
+ },
93
+ "execution_count": 5,
94
+ "metadata": {},
95
+ "output_type": "execute_result"
96
+ }
97
+ ],
98
+ "source": [
99
+ "shards = list(\n",
100
+ " braceexpand.braceexpand(shards)\n",
101
+ ") # better display for tqdm with known length"
102
+ ]
103
+ },
104
+ {
105
+ "cell_type": "markdown",
106
+ "id": "75dba8e2",
107
+ "metadata": {},
108
+ "source": [
109
+ "## Load data"
110
+ ]
111
+ },
112
+ {
113
+ "cell_type": "markdown",
114
+ "id": "a1e8fb95",
115
+ "metadata": {},
116
+ "source": [
117
+ "We load data using `webdataset`."
118
+ ]
119
+ },
120
+ {
121
+ "cell_type": "code",
122
+ "execution_count": null,
123
+ "id": "9ef5de9e",
124
+ "metadata": {},
125
+ "outputs": [],
126
+ "source": [
127
+ "ds = (\n",
128
+ " wds.WebDataset(shards, handler=wds.warn_and_continue)\n",
129
+ " .decode(\"rgb\", handler=wds.warn_and_continue)\n",
130
+ " .to_tuple(\"jpg\", \"txt\") # assumes image is in `jpg` and caption in `txt`\n",
131
+ " .batched(total_bs) # load in batch per worker (faster)\n",
132
+ ")"
133
+ ]
134
+ },
135
+ {
136
+ "cell_type": "markdown",
137
+ "id": "90981824",
138
+ "metadata": {},
139
+ "source": [
140
+ "Note:\n",
141
+ "* you can also shuffle shards and items using `shardshuffle` and `shuffle` if necessary.\n",
142
+ "* you may need to resize images in your pipeline (with `map_dict` for example), we assume they are already set to 256x256.\n",
143
+ "* you can also filter out some items using `select`."
144
+ ]
145
+ },
146
+ {
147
+ "cell_type": "markdown",
148
+ "id": "129c377d",
149
+ "metadata": {},
150
+ "source": [
151
+ "We can now inspect our data."
152
+ ]
153
+ },
154
+ {
155
+ "cell_type": "code",
156
+ "execution_count": null,
157
+ "id": "8cac98cb",
158
+ "metadata": {
159
+ "scrolled": true
160
+ },
161
+ "outputs": [],
162
+ "source": [
163
+ "%%time\n",
164
+ "images, captions = next(iter(ds))"
165
+ ]
166
+ },
167
+ {
168
+ "cell_type": "code",
169
+ "execution_count": null,
170
+ "id": "cd268fbf",
171
+ "metadata": {},
172
+ "outputs": [],
173
+ "source": [
174
+ "images.shape"
175
+ ]
176
+ },
177
+ {
178
+ "cell_type": "code",
179
+ "execution_count": null,
180
+ "id": "5acfc4d8",
181
+ "metadata": {},
182
+ "outputs": [],
183
+ "source": [
184
+ "captions[:10]"
185
+ ]
186
+ },
187
+ {
188
+ "cell_type": "code",
189
+ "execution_count": null,
190
+ "id": "c24693c0",
191
+ "metadata": {},
192
+ "outputs": [],
193
+ "source": [
194
+ "T.ToPILImage()(images[0].permute(2, 0, 1))"
195
+ ]
196
+ },
197
+ {
198
+ "cell_type": "markdown",
199
+ "id": "3059ffb1",
200
+ "metadata": {},
201
+ "source": [
202
+ "Finally we create our dataloader."
203
+ ]
204
+ },
205
+ {
206
+ "cell_type": "code",
207
+ "execution_count": null,
208
+ "id": "c227c551",
209
+ "metadata": {},
210
+ "outputs": [],
211
+ "source": [
212
+ "dl = (\n",
213
+ " wds.WebLoader(ds, batch_size=None, num_workers=8).unbatched().batched(total_bs)\n",
214
+ ") # avoid partial batch at the end of each worker"
215
+ ]
216
+ },
217
+ {
218
+ "cell_type": "markdown",
219
+ "id": "a354472b",
220
+ "metadata": {},
221
+ "source": [
222
+ "## Image encoder\n",
223
+ "\n",
224
+ "We'll use a VQGAN trained with Taming Transformers and converted to a JAX model."
225
+ ]
226
+ },
227
+ {
228
+ "cell_type": "code",
229
+ "execution_count": null,
230
+ "id": "47a8b818",
231
+ "metadata": {
232
+ "scrolled": true
233
+ },
234
+ "outputs": [],
235
+ "source": [
236
+ "from vqgan_jax.modeling_flax_vqgan import VQModel\n",
237
+ "from flax.jax_utils import replicate\n",
238
+ "\n",
239
+ "vqgan = VQModel.from_pretrained(\"flax-community/vqgan_f16_16384\")\n",
240
+ "vqgan_params = replicate(vqgan.params)"
241
+ ]
242
+ },
243
+ {
244
+ "cell_type": "markdown",
245
+ "id": "62ad01c3",
246
+ "metadata": {},
247
+ "source": [
248
+ "## Encoding"
249
+ ]
250
+ },
251
+ {
252
+ "cell_type": "markdown",
253
+ "id": "20357f74",
254
+ "metadata": {},
255
+ "source": [
256
+ "Encoding is really simple using `shard` to automatically distribute batches across devices and `pmap`."
257
+ ]
258
+ },
259
+ {
260
+ "cell_type": "code",
261
+ "execution_count": null,
262
+ "id": "322a4619",
263
+ "metadata": {},
264
+ "outputs": [],
265
+ "source": [
266
+ "from flax.training.common_utils import shard\n",
267
+ "from functools import partial\n",
268
+ "\n",
269
+ "\n",
270
+ "@partial(jax.pmap, axis_name=\"batch\")\n",
271
+ "def p_encode(batch, params):\n",
272
+ " # Not sure if we should `replicate` params, does not seem to have any effect\n",
273
+ " _, indices = vqgan.encode(batch, params=params)\n",
274
+ " return indices"
275
+ ]
276
+ },
277
+ {
278
+ "cell_type": "code",
279
+ "execution_count": null,
280
+ "id": "ff6c10d4",
281
+ "metadata": {},
282
+ "outputs": [],
283
+ "source": [
284
+ "import pandas as pd\n",
285
+ "\n",
286
+ "\n",
287
+ "def encode_dataset(dataloader, output_dir, save_frequency):\n",
288
+ " output_dir.mkdir(parents=True, exist_ok=True)\n",
289
+ " all_captions = []\n",
290
+ " all_encoding = []\n",
291
+ " n_file = 1\n",
292
+ " for idx, (images, captions) in enumerate(tqdm(dataloader)):\n",
293
+ " images = images.numpy()\n",
294
+ " n = len(images) // 8 * 8\n",
295
+ " if n != len(images):\n",
296
+ " # get the max number of images we can (multiple of 8)\n",
297
+ " print(f\"Different sizes {n} vs {len(images)}\")\n",
298
+ " images = images[:n]\n",
299
+ " captions = captions[:n]\n",
300
+ " if not len(captions):\n",
301
+ " print(f\"No images/captions in batch...\")\n",
302
+ " continue\n",
303
+ " images = shard(images)\n",
304
+ " encoded = p_encode(images, vqgan_params)\n",
305
+ " encoded = encoded.reshape(-1, encoded.shape[-1])\n",
306
+ " all_captions.extend(captions)\n",
307
+ " all_encoding.extend(encoded.tolist())\n",
308
+ "\n",
309
+ " # save files\n",
310
+ " if (idx + 1) % save_frequency == 0:\n",
311
+ " print(f\"Saving file {n_file}\")\n",
312
+ " batch_df = pd.DataFrame.from_dict(\n",
313
+ " {\"caption\": all_captions, \"encoding\": all_encoding}\n",
314
+ " )\n",
315
+ " batch_df.to_parquet(f\"{output_dir}/{n_file:03d}.parquet\")\n",
316
+ " all_captions = []\n",
317
+ " all_encoding = []\n",
318
+ " n_file += 1\n",
319
+ "\n",
320
+ " if len(all_captions):\n",
321
+ " print(f\"Saving final file {n_file}\")\n",
322
+ " batch_df = pd.DataFrame.from_dict(\n",
323
+ " {\"caption\": all_captions, \"encoding\": all_encoding}\n",
324
+ " )\n",
325
+ " batch_df.to_parquet(f\"{output_dir}/{n_file:03d}.parquet\")"
326
+ ]
327
+ },
328
+ {
329
+ "cell_type": "code",
330
+ "execution_count": null,
331
+ "id": "7704863d",
332
+ "metadata": {},
333
+ "outputs": [],
334
+ "source": [
335
+ "encode_dataset(dl, output_dir=encoded_output, save_frequency=save_frequency)"
336
+ ]
337
+ },
338
+ {
339
+ "cell_type": "markdown",
340
+ "id": "8953dd84",
341
+ "metadata": {},
342
+ "source": [
343
+ "----"
344
+ ]
345
+ }
346
+ ],
347
+ "metadata": {
348
+ "interpreter": {
349
+ "hash": "db471c52d602b4f5f40ecaf278e88ccfef85c29d0a1a07185b0d51fc7acf4e26"
350
+ },
351
+ "kernelspec": {
352
+ "display_name": "Python 3 (ipykernel)",
353
+ "language": "python",
354
+ "name": "python3"
355
+ },
356
+ "language_info": {
357
+ "codemirror_mode": {
358
+ "name": "ipython",
359
+ "version": 3
360
+ },
361
+ "file_extension": ".py",
362
+ "mimetype": "text/x-python",
363
+ "name": "python",
364
+ "nbconvert_exporter": "python",
365
+ "pygments_lexer": "ipython3",
366
+ "version": "3.9.7"
367
+ }
368
+ },
369
+ "nbformat": 4,
370
+ "nbformat_minor": 5
371
+ }
tools/inference/inference_pipeline.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
tools/inference/log_inference_samples.ipynb ADDED
@@ -0,0 +1,434 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "4ff2a984-b8b2-4a69-89cf-0d16da2393c8",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import tempfile\n",
11
+ "from functools import partial\n",
12
+ "import random\n",
13
+ "import numpy as np\n",
14
+ "from PIL import Image\n",
15
+ "from tqdm.notebook import tqdm\n",
16
+ "import jax\n",
17
+ "import jax.numpy as jnp\n",
18
+ "from flax.training.common_utils import shard, shard_prng_key\n",
19
+ "from flax.jax_utils import replicate\n",
20
+ "import wandb\n",
21
+ "from dalle_mini.model import CustomFlaxBartForConditionalGeneration\n",
22
+ "from vqgan_jax.modeling_flax_vqgan import VQModel\n",
23
+ "from transformers import BartTokenizer, CLIPProcessor, FlaxCLIPModel\n",
24
+ "from dalle_mini.text import TextNormalizer"
25
+ ]
26
+ },
27
+ {
28
+ "cell_type": "code",
29
+ "execution_count": null,
30
+ "id": "92f4557c-fd7f-4edc-81c2-de0b0a10c270",
31
+ "metadata": {},
32
+ "outputs": [],
33
+ "source": [
34
+ "run_ids = [\"63otg87g\"]\n",
35
+ "ENTITY, PROJECT = \"dalle-mini\", \"dalle-mini\" # used only for training run\n",
36
+ "VQGAN_REPO, VQGAN_COMMIT_ID = (\n",
37
+ " \"dalle-mini/vqgan_imagenet_f16_16384\",\n",
38
+ " \"e93a26e7707683d349bf5d5c41c5b0ef69b677a9\",\n",
39
+ ")\n",
40
+ "latest_only = True # log only latest or all versions\n",
41
+ "suffix = \"\" # mainly for duplicate inference runs with a deleted version\n",
42
+ "add_clip_32 = False"
43
+ ]
44
+ },
45
+ {
46
+ "cell_type": "code",
47
+ "execution_count": null,
48
+ "id": "71f27b96-7e6c-4472-a2e4-e99a8fb67a72",
49
+ "metadata": {},
50
+ "outputs": [],
51
+ "source": [
52
+ "# model.generate parameters (temperature is not used yet)\n",
53
+ "gen_top_k = None\n",
54
+ "gen_top_p = None\n",
55
+ "temperature = None"
56
+ ]
57
+ },
58
+ {
59
+ "cell_type": "code",
60
+ "execution_count": null,
61
+ "id": "93b2e24b-f0e5-4abe-a3ec-0aa834cc3bf3",
62
+ "metadata": {},
63
+ "outputs": [],
64
+ "source": [
65
+ "batch_size = 8\n",
66
+ "num_images = 128\n",
67
+ "top_k = 8\n",
68
+ "text_normalizer = TextNormalizer()\n",
69
+ "padding_item = \"NONE\"\n",
70
+ "seed = random.randint(0, 2 ** 32 - 1)\n",
71
+ "key = jax.random.PRNGKey(seed)\n",
72
+ "api = wandb.Api()"
73
+ ]
74
+ },
75
+ {
76
+ "cell_type": "code",
77
+ "execution_count": null,
78
+ "id": "c6a878fa-4bf5-4978-abb5-e235841d765b",
79
+ "metadata": {},
80
+ "outputs": [],
81
+ "source": [
82
+ "vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)\n",
83
+ "vqgan_params = replicate(vqgan.params)\n",
84
+ "\n",
85
+ "clip16 = FlaxCLIPModel.from_pretrained(\"openai/clip-vit-base-patch16\")\n",
86
+ "processor16 = CLIPProcessor.from_pretrained(\"openai/clip-vit-base-patch16\")\n",
87
+ "clip16_params = replicate(clip16.params)\n",
88
+ "\n",
89
+ "if add_clip_32:\n",
90
+ " clip32 = FlaxCLIPModel.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
91
+ " processor32 = CLIPProcessor.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
92
+ " clip32_params = replicate(clip32.params)"
93
+ ]
94
+ },
95
+ {
96
+ "cell_type": "code",
97
+ "execution_count": null,
98
+ "id": "a500dd07-dbc3-477d-80d4-2b73a3b83ef3",
99
+ "metadata": {},
100
+ "outputs": [],
101
+ "source": [
102
+ "@partial(jax.pmap, axis_name=\"batch\")\n",
103
+ "def p_decode(indices, params):\n",
104
+ " return vqgan.decode_code(indices, params=params)\n",
105
+ "\n",
106
+ "\n",
107
+ "@partial(jax.pmap, axis_name=\"batch\")\n",
108
+ "def p_clip16(inputs, params):\n",
109
+ " logits = clip16(params=params, **inputs).logits_per_image\n",
110
+ " return logits\n",
111
+ "\n",
112
+ "\n",
113
+ "if add_clip_32:\n",
114
+ "\n",
115
+ " @partial(jax.pmap, axis_name=\"batch\")\n",
116
+ " def p_clip32(inputs, params):\n",
117
+ " logits = clip32(params=params, **inputs).logits_per_image\n",
118
+ " return logits"
119
+ ]
120
+ },
121
+ {
122
+ "cell_type": "code",
123
+ "execution_count": null,
124
+ "id": "e57797ab-0b3a-4490-be58-03d8d1c23fe9",
125
+ "metadata": {},
126
+ "outputs": [],
127
+ "source": [
128
+ "with open(\"samples.txt\", encoding=\"utf8\") as f:\n",
129
+ " samples = [l.strip() for l in f.readlines()]\n",
130
+ " # make list multiple of batch_size by adding elements\n",
131
+ " samples_to_add = [padding_item] * (-len(samples) % batch_size)\n",
132
+ " samples.extend(samples_to_add)\n",
133
+ " # reshape\n",
134
+ " samples = [samples[i : i + batch_size] for i in range(0, len(samples), batch_size)]"
135
+ ]
136
+ },
137
+ {
138
+ "cell_type": "code",
139
+ "execution_count": null,
140
+ "id": "f3e02d9d-4ee1-49e7-a7bc-4d8b139e9614",
141
+ "metadata": {},
142
+ "outputs": [],
143
+ "source": [
144
+ "def get_artifact_versions(run_id, latest_only=False):\n",
145
+ " try:\n",
146
+ " if latest_only:\n",
147
+ " return [\n",
148
+ " api.artifact(\n",
149
+ " type=\"bart_model\", name=f\"{ENTITY}/{PROJECT}/model-{run_id}:latest\"\n",
150
+ " )\n",
151
+ " ]\n",
152
+ " else:\n",
153
+ " return api.artifact_versions(\n",
154
+ " type_name=\"bart_model\",\n",
155
+ " name=f\"{ENTITY}/{PROJECT}/model-{run_id}\",\n",
156
+ " per_page=10000,\n",
157
+ " )\n",
158
+ " except:\n",
159
+ " return []"
160
+ ]
161
+ },
162
+ {
163
+ "cell_type": "code",
164
+ "execution_count": null,
165
+ "id": "f0d7ed17-7abb-4a31-ab3c-a12b9039a570",
166
+ "metadata": {},
167
+ "outputs": [],
168
+ "source": [
169
+ "def get_training_config(run_id):\n",
170
+ " training_run = api.run(f\"{ENTITY}/{PROJECT}/{run_id}\")\n",
171
+ " config = training_run.config\n",
172
+ " return config"
173
+ ]
174
+ },
175
+ {
176
+ "cell_type": "code",
177
+ "execution_count": null,
178
+ "id": "7e784a43-626d-4e8d-9e47-a23775b2f35f",
179
+ "metadata": {},
180
+ "outputs": [],
181
+ "source": [
182
+ "# retrieve inference run details\n",
183
+ "def get_last_inference_version(run_id):\n",
184
+ " try:\n",
185
+ " inference_run = api.run(f\"dalle-mini/dalle-mini/{run_id}-clip16{suffix}\")\n",
186
+ " return inference_run.summary.get(\"version\", None)\n",
187
+ " except:\n",
188
+ " return None"
189
+ ]
190
+ },
191
+ {
192
+ "cell_type": "code",
193
+ "execution_count": null,
194
+ "id": "d1cc9993-1bfc-4ec6-a004-c056189c42ac",
195
+ "metadata": {},
196
+ "outputs": [],
197
+ "source": [
198
+ "# compile functions - needed only once per run\n",
199
+ "def pmap_model_function(model):\n",
200
+ " @partial(jax.pmap, axis_name=\"batch\")\n",
201
+ " def _generate(tokenized_prompt, key, params):\n",
202
+ " return model.generate(\n",
203
+ " **tokenized_prompt,\n",
204
+ " do_sample=True,\n",
205
+ " num_beams=1,\n",
206
+ " prng_key=key,\n",
207
+ " params=params,\n",
208
+ " top_k=gen_top_k,\n",
209
+ " top_p=gen_top_p\n",
210
+ " )\n",
211
+ "\n",
212
+ " return _generate"
213
+ ]
214
+ },
215
+ {
216
+ "cell_type": "code",
217
+ "execution_count": null,
218
+ "id": "23b2444c-67a9-44d7-abd1-187ed83a9431",
219
+ "metadata": {},
220
+ "outputs": [],
221
+ "source": [
222
+ "run_id = run_ids[0]\n",
223
+ "# TODO: loop over runs"
224
+ ]
225
+ },
226
+ {
227
+ "cell_type": "code",
228
+ "execution_count": null,
229
+ "id": "bba70f33-af8b-4eb3-9973-7be672301a0b",
230
+ "metadata": {},
231
+ "outputs": [],
232
+ "source": [
233
+ "artifact_versions = get_artifact_versions(run_id, latest_only)\n",
234
+ "last_inference_version = get_last_inference_version(run_id)\n",
235
+ "training_config = get_training_config(run_id)\n",
236
+ "run = None\n",
237
+ "p_generate = None\n",
238
+ "model_files = [\n",
239
+ " \"config.json\",\n",
240
+ " \"flax_model.msgpack\",\n",
241
+ " \"merges.txt\",\n",
242
+ " \"special_tokens_map.json\",\n",
243
+ " \"tokenizer.json\",\n",
244
+ " \"tokenizer_config.json\",\n",
245
+ " \"vocab.json\",\n",
246
+ "]\n",
247
+ "for artifact in artifact_versions:\n",
248
+ " print(f\"Processing artifact: {artifact.name}\")\n",
249
+ " version = int(artifact.version[1:])\n",
250
+ " results16, results32 = [], []\n",
251
+ " columns = [\"Caption\"] + [f\"Image {i+1}\" for i in range(top_k)]\n",
252
+ "\n",
253
+ " if latest_only:\n",
254
+ " assert last_inference_version is None or version > last_inference_version\n",
255
+ " else:\n",
256
+ " if last_inference_version is None:\n",
257
+ " # we should start from v0\n",
258
+ " assert version == 0\n",
259
+ " elif version <= last_inference_version:\n",
260
+ " print(\n",
261
+ f\"v{version} has already been logged (versions logged up to v{last_inference_version})\"\n",
262
+ " )\n",
263
+ " else:\n",
264
+ " # check we are logging the correct version\n",
265
+ " assert version == last_inference_version + 1\n",
266
+ "\n",
267
+ " # start/resume corresponding run\n",
268
+ " if run is None:\n",
269
+ " run = wandb.init(\n",
270
+ " job_type=\"inference\",\n",
271
+ " entity=\"dalle-mini\",\n",
272
+ " project=\"dalle-mini\",\n",
273
+ " config=training_config,\n",
274
+ " id=f\"{run_id}-clip16{suffix}\",\n",
275
+ " resume=\"allow\",\n",
276
+ " )\n",
277
+ "\n",
278
+ " # work in temporary directory\n",
279
+ " with tempfile.TemporaryDirectory() as tmp:\n",
280
+ "\n",
281
+ " # download model files\n",
282
+ " artifact = run.use_artifact(artifact)\n",
283
+ " for f in model_files:\n",
284
+ " artifact.get_path(f).download(tmp)\n",
285
+ "\n",
286
+ " # load tokenizer and model\n",
287
+ " tokenizer = BartTokenizer.from_pretrained(tmp)\n",
288
+ " model = CustomFlaxBartForConditionalGeneration.from_pretrained(tmp)\n",
289
+ " model_params = replicate(model.params)\n",
290
+ "\n",
291
+ " # pmap model function needs to happen only once per model config\n",
292
+ " if p_generate is None:\n",
293
+ " p_generate = pmap_model_function(model)\n",
294
+ "\n",
295
+ " # process one batch of captions\n",
296
+ " for batch in tqdm(samples):\n",
297
+ " processed_prompts = (\n",
298
+ " [text_normalizer(x) for x in batch]\n",
299
+ " if model.config.normalize_text\n",
300
+ " else list(batch)\n",
301
+ " )\n",
302
+ "\n",
303
+ " # repeat the prompts to distribute over each device and tokenize\n",
304
+ " processed_prompts = processed_prompts * jax.device_count()\n",
305
+ " tokenized_prompt = tokenizer(\n",
306
+ " processed_prompts,\n",
307
+ " return_tensors=\"jax\",\n",
308
+ " padding=\"max_length\",\n",
309
+ " truncation=True,\n",
310
+ " max_length=128,\n",
311
+ " ).data\n",
312
+ " tokenized_prompt = shard(tokenized_prompt)\n",
313
+ "\n",
314
+ " # generate images\n",
315
+ " images = []\n",
316
+ " pbar = tqdm(\n",
317
+ " range(num_images // jax.device_count()),\n",
318
+ " desc=\"Generating Images\",\n",
319
+ " leave=True,\n",
320
+ " )\n",
321
+ " for i in pbar:\n",
322
+ " key, subkey = jax.random.split(key)\n",
323
+ " encoded_images = p_generate(\n",
324
+ " tokenized_prompt, shard_prng_key(subkey), model_params\n",
325
+ " )\n",
326
+ " encoded_images = encoded_images.sequences[..., 1:]\n",
327
+ " decoded_images = p_decode(encoded_images, vqgan_params)\n",
328
+ " decoded_images = decoded_images.clip(0.0, 1.0).reshape(\n",
329
+ " (-1, 256, 256, 3)\n",
330
+ " )\n",
331
+ " for img in decoded_images:\n",
332
+ " images.append(\n",
333
+ " Image.fromarray(np.asarray(img * 255, dtype=np.uint8))\n",
334
+ " )\n",
335
+ "\n",
336
+ " def add_clip_results(results, processor, p_clip, clip_params):\n",
337
+ " clip_inputs = processor(\n",
338
+ " text=batch,\n",
339
+ " images=images,\n",
340
+ " return_tensors=\"np\",\n",
341
+ " padding=\"max_length\",\n",
342
+ " max_length=77,\n",
343
+ " truncation=True,\n",
344
+ " ).data\n",
345
+ " # each shard will have one prompt, images need to be reorganized to be associated to the correct shard\n",
346
+ " images_per_prompt_indices = np.asarray(\n",
347
+ " range(0, len(images), batch_size)\n",
348
+ " )\n",
349
+ " clip_inputs[\"pixel_values\"] = jnp.concatenate(\n",
350
+ " list(\n",
351
+ " clip_inputs[\"pixel_values\"][images_per_prompt_indices + i]\n",
352
+ " for i in range(batch_size)\n",
353
+ " )\n",
354
+ " )\n",
355
+ " clip_inputs = shard(clip_inputs)\n",
356
+ " logits = p_clip(clip_inputs, clip_params)\n",
357
+ " logits = logits.reshape(-1, num_images)\n",
358
+ " top_scores = logits.argsort()[:, -top_k:][..., ::-1]\n",
359
+ " logits = jax.device_get(logits)\n",
360
+ " # add to results table\n",
361
+ " for i, (idx, scores, sample) in enumerate(\n",
362
+ " zip(top_scores, logits, batch)\n",
363
+ " ):\n",
364
+ " if sample == padding_item:\n",
365
+ " continue\n",
366
+ " cur_images = [images[x] for x in images_per_prompt_indices + i]\n",
367
+ " top_images = [\n",
368
+ " wandb.Image(cur_images[x], caption=f\"Score: {scores[x]:.2f}\")\n",
369
+ " for x in idx\n",
370
+ " ]\n",
371
+ " results.append([sample] + top_images)\n",
372
+ "\n",
373
+ " # get clip scores\n",
374
+ " pbar.set_description(\"Calculating CLIP 16 scores\")\n",
375
+ " add_clip_results(results16, processor16, p_clip16, clip16_params)\n",
376
+ "\n",
377
+ " # get clip 32 scores\n",
378
+ " if add_clip_32:\n",
379
+ " pbar.set_description(\"Calculating CLIP 32 scores\")\n",
380
+ " add_clip_results(results32, processor32, p_clip32, clip32_params)\n",
381
+ "\n",
382
+ " pbar.close()\n",
383
+ "\n",
384
+ " # log results\n",
385
+ " table = wandb.Table(columns=columns, data=results16)\n",
386
+ " run.log({\"Samples\": table, \"version\": version})\n",
387
+ " wandb.finish()\n",
388
+ "\n",
389
+ " if add_clip_32:\n",
390
+ " run = wandb.init(\n",
391
+ " job_type=\"inference\",\n",
392
+ " entity=\"dalle-mini\",\n",
393
+ " project=\"dalle-mini\",\n",
394
+ " config=training_config,\n",
395
+ " id=f\"{run_id}-clip32{suffix}\",\n",
396
+ " resume=\"allow\",\n",
397
+ " )\n",
398
+ " table = wandb.Table(columns=columns, data=results32)\n",
399
+ " run.log({\"Samples\": table, \"version\": version})\n",
400
+ " wandb.finish()\n",
401
+ " run = None # ensure we don't log on this run"
402
+ ]
403
+ },
404
+ {
405
+ "cell_type": "code",
406
+ "execution_count": null,
407
+ "id": "415d3f54-7226-43de-9eea-4283a948dc93",
408
+ "metadata": {},
409
+ "outputs": [],
410
+ "source": []
411
+ }
412
+ ],
413
+ "metadata": {
414
+ "kernelspec": {
415
+ "display_name": "Python 3 (ipykernel)",
416
+ "language": "python",
417
+ "name": "python3"
418
+ },
419
+ "language_info": {
420
+ "codemirror_mode": {
421
+ "name": "ipython",
422
+ "version": 3
423
+ },
424
+ "file_extension": ".py",
425
+ "mimetype": "text/x-python",
426
+ "name": "python",
427
+ "nbconvert_exporter": "python",
428
+ "pygments_lexer": "ipython3",
429
+ "version": "3.9.7"
430
+ }
431
+ },
432
+ "nbformat": 4,
433
+ "nbformat_minor": 5
434
+ }
tools/inference/samples.txt ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ t-shirt, size M
2
+ flower dress, size M
3
+ white snow covered mountain under blue sky during daytime
4
+ aerial view of the beach during daytime
5
+ aerial view of the beach at night
6
+ a beautiful sunset at a beach with a shell on the shore
7
+ a farmhouse surrounded by beautiful flowers
8
+ sunset over green mountains
9
+ a photo of san francisco golden gate bridge
10
+ painting of an oniric forest glade surrounded by tall trees
11
+ a graphite sketch of a gothic cathedral
12
+ a graphite sketch of Elon Musk
13
+ still life in the style of Kandinsky
14
+ still life in the style of Picasso
15
+ a colorful stairway to heaven
16
+ a background consisting of colors blue, green, and red
17
+ Mohammed Ali and Mike Tyson in a match
18
+ Pele and Maradona in a match
19
+ view of Mars from space
20
+ a picture of the Eiffel tower on the moon
21
+ a picture of the Eiffel tower on the moon, Earth is in the background
22
+ watercolor of the Eiffel tower on the moon
23
+ the moon is a skull
24
+ epic sword fight
25
+ underwater cathedral
26
+ a photo of a fantasy version of New York City
27
+ a picture of fantasy kingdoms
28
+ a volcano erupting next to San Francisco golden gate bridge
29
+ Paris in a far future, futuristic Paris
30
+ real painting of an alien from Monet
31
+ the communist statue of liberty
32
+ robots taking control over humans
33
+ illustration of an astronaut in a space suit playing guitar
34
+ a clown wearing a spacesuit floating in space
35
+ a dog playing with a ball
36
+ a cat sits on top of an alligator
37
+ a very cute cat laying by a big bike
38
+ a rat holding a red lightsaber in a white background
39
+ a very cute giraffe making a funny face
40
+ A unicorn is passing by a rainbow in a field of flowers
41
+ an elephant made of carrots
42
+ an elephant on a unicycle during a circus
43
+ photography of a penguin watching television
44
+ a penguin is walking on the Moon, Earth is in the background
45
+ a penguin standing on a tower of books holds onto a rope from a helicopter
46
+ rat wearing a crown
47
+ looking into the sky, 10 airplanes are seen overhead
48
+ shelves filled with books and alchemy potion bottles
49
+ this is a detailed high-resolution scan of a human brain
50
+ a restaurant menu
51
+ a bottle of coca-cola on a table
52
+ a peanut
53
+ a cross-section view of a walnut
54
+ a living room with two white armchairs and a painting of the Colosseum. The painting is mounted above a modern fireplace.
55
+ a long line of alternating green and red blocks
56
+ a long line of green blocks on a beach at sunset
57
+ a long line of peaches on a beach at sunset
58
+ a picture of a castle from minecraft
59
+ a cute pikachu teapot
60
+ an illustration of pikachu sitting on a bench eating an ice cream
61
+ mario is jumping over a zebra
62
+ famous anime hero
63
+ star wars concept art
64
+ Cartoon of a carrot with big eyes
65
+ a cartoon of a superhero bear
66
+ an illustration of a cute skeleton wearing a blue hoodie
67
+ illustration of a baby shark swimming around corals
68
+ an illustration of an avocado in a beanie riding a motorcycle
69
+ logo of a robot wearing glasses and reading a book
70
+ illustration of a cactus lifting weights
71
+ logo of a cactus lifting weights
72
+ a photo of a camera from the future
73
+ a skeleton with the shape of a spider
74
+ a collection of glasses is sitting on a table
75
+ a painting of a capybara sitting on a mountain during fall in surrealist style
76
+ a pentagonal green clock
77
+ a small red block sitting on a large green block
78
+ a storefront that has the word 'openai' written on it
79
+ a tattoo of a black broccoli
80
+ a variety of clocks is sitting on a table
81
+ a table has a train model on it with other cars and things
82
+ a pixel art illustration of an eagle sitting in a field in the afternoon
83
+ an emoji of a baby fox wearing a blue hat, green gloves, red shirt, and yellow pants
84
+ an emoji of a baby penguin wearing a blue hat, blue gloves, red shirt, and green pants
85
+ an extreme close-up view of a capybara sitting in a field
86
+ an illustration of a baby cucumber with a mustache playing chess
87
+ an illustration of a baby daikon radish in a tutu walking a dog
88
+ an illustration of a baby hedgehog in a cape staring at its reflection in a mirror
89
+ an illustration of a baby panda with headphones holding an umbrella in the rain
90
+ urinals are lined up in a jungle
91
+ a muscular banana sitting upright on a bench smoking watching a banana on television, high definition photography
92
+ a human face
93
+ a person is holding a phone and a waterbottle, running a marathon
94
+ a child eating a birthday cake near some balloons
95
+ Young woman riding her bike through the forest
96
+ the best soccer team of the world
97
+ the best football team of the world
98
+ the best basketball team of the world
99
+ happy, happiness
100
+ sad, sadness
101
+ the representation of infinity
102
+ the end of the world
103
+ the last sunrise on earth
104
+ a portrait of a nightmare creature watching at you
105
+ an avocado armchair
106
+ an armchair in the shape of an avocado
107
+ illustration of an avocado armchair
108
+ illustration of an armchair in the shape of an avocado
109
+ logo of an avocado armchair
110
+ an avocado armchair flying into space
111
+ a cute avocado armchair singing karaoke on stage in front of a crowd of strawberry shaped lamps
112
+ an illustration of an avocado in a christmas sweater staring at its reflection in a mirror
113
+ illustration of an avocado armchair getting married to a pineapple
114
+ half human half cat
115
+ half human half dog
116
+ half human half pen
117
+ half human half garbage
118
+ half human half avocado
119
+ half human half Eiffel tower
120
+ a propaganda poster for transhumanism
121
+ a propaganda poster for building a space elevator
122
+ a beautiful epic fantasy painting of a space elevator
123
+ a transformer architecture
124
+ a transformer in real life
{seq2seq → tools/train}/sweep.yaml RENAMED
@@ -1,6 +1,6 @@
1
- program: run_seq2seq_flax.py
2
- entity: wandb
3
- project: hf-flax-dalle-mini
4
  method: random
5
  metric:
6
  name: eval/loss
@@ -8,36 +8,47 @@ metric:
8
  parameters:
9
  learning_rate:
10
  distribution: log_uniform
11
- # from exp(min) to exp(max), ie 5e-5 to 5e-3 on log scale
12
- min: -9.9
13
- max: -5.3
14
  gradient_accumulation_steps:
15
  value: 8
16
  warmup_steps:
17
- # in term of optimization steps so multiplied by gradient accumulation
18
- value: 125
19
  command:
20
  - python3
21
  - ${program}
22
- - "--train_file"
23
- - "/data/CC12M/encoded-small-train.tsv"
24
- - "--validation_file"
25
- - "/data/CC12M/encoded-small-valid.tsv"
26
- - "--output_dir"
27
- - "./output_sweep"
28
- - "--overwrite_output_dir"
29
- - "--adafactor"
30
- - "--num_train_epochs"
31
- - 1
32
- - "--max_train_samples"
33
- - 1500000
 
 
34
  - "--per_device_train_batch_size"
35
  - 56
36
  - "--per_device_eval_batch_size"
37
  - 56
38
- - "--preprocessing_num_workers"
39
- - 80
40
- - "--no_decay"
41
  - "--do_train"
42
  - "--do_eval"
 
 
 
 
 
 
 
 
 
 
 
43
  - ${args}
 
1
+ program: train.py
2
+ entity: dalle-mini
3
+ project: dalle-mini
4
  method: random
5
  metric:
6
  name: eval/loss
 
8
  parameters:
9
  learning_rate:
10
  distribution: log_uniform
11
+ # from exp(min) to exp(max)
12
+ min: -6.9
13
+ max: -3.5
14
  gradient_accumulation_steps:
15
  value: 8
16
  warmup_steps:
17
+ value: 4000
18
+ #TODO: outdated command
19
  command:
20
  - python3
21
  - ${program}
22
+ - "--tokenizer_name"
23
+ - "boris/dalle-mini-tokenizer"
24
+ - "--config_name"
25
+ - "facebook/bart-large-cnn"
26
+ - "--dataset_repo_or_path"
27
+ - "boris/gis_vqgan_f16_16384"
28
+ - "--streaming"
29
+ - "--use_auth_token"
30
+ - "--image_vocab_size"
31
+ - 16384
32
+ - "--image_length"
33
+ - 256
34
+ - "--normalize_text"
35
+ - True
36
  - "--per_device_train_batch_size"
37
  - 56
38
  - "--per_device_eval_batch_size"
39
  - 56
40
+ - "--adafactor"
 
 
41
  - "--do_train"
42
  - "--do_eval"
43
+ - "--num_train_epochs"
44
+ - 1
45
+ - "--logging_steps"
46
+ - 40
47
+ - "--eval_steps"
48
+ - 800
49
+ - "--output_dir"
50
+ - "./output"
51
+ - "--overwrite_output_dir"
52
+ - "--max_train_samples"
53
+ - 10000000
54
  - ${args}
tools/train/train.py ADDED
@@ -0,0 +1,857 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2021 The HuggingFace Team All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Fine-tuning the library models for seq2seq, text to image.
18
+ Script adapted from run_summarization_flax.py
19
+ """
20
+
21
+ import json
22
+ import logging
23
+ import os
24
+ import sys
25
+ import time
26
+ from dataclasses import asdict, dataclass, field
27
+ from pathlib import Path
28
+ from typing import Callable, Optional
29
+
30
+ import datasets
31
+ import jax
32
+ import jax.numpy as jnp
33
+ import optax
34
+ import transformers
35
+ import wandb
36
+ from datasets import Dataset
37
+ from flax import jax_utils, traverse_util
38
+ from flax.jax_utils import unreplicate
39
+ from flax.serialization import from_bytes, to_bytes
40
+ from flax.training import train_state
41
+ from flax.training.common_utils import get_metrics, onehot, shard_prng_key
42
+ from tqdm import tqdm
43
+ from transformers import AutoTokenizer, HfArgumentParser
44
+ from transformers.models.bart.modeling_flax_bart import BartConfig
45
+
46
+ from dalle_mini.data import Dataset
47
+ from dalle_mini.model import CustomFlaxBartForConditionalGeneration
48
+
49
+ logger = logging.getLogger(__name__)
50
+
51
+
52
+ @dataclass
53
+ class ModelArguments:
54
+ """
55
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
56
+ """
57
+
58
+ model_name_or_path: Optional[str] = field(
59
+ default=None,
60
+ metadata={
61
+ "help": "The model checkpoint for weights initialization. "
62
+ "Don't set if you want to train a model from scratch."
63
+ },
64
+ )
65
+ config_name: Optional[str] = field(
66
+ default=None,
67
+ metadata={
68
+ "help": "Pretrained config name or path if not the same as model_name"
69
+ },
70
+ )
71
+ image_vocab_size: Optional[int] = field(
72
+ default=None,
73
+ metadata={"help": "Vocab size of image encoder"},
74
+ )
75
+ image_length: Optional[int] = field(
76
+ default=None,
77
+ metadata={"help": "Number of tokens per image"},
78
+ )
79
+ tokenizer_name: Optional[str] = field(
80
+ default=None,
81
+ metadata={
82
+ "help": "Pretrained tokenizer name or path if not the same as model_name_or_path"
83
+ },
84
+ )
85
+ normalize_text: Optional[bool] = field(
86
+ default=None,
87
+ metadata={
88
+ "help": "Whether to normalize text or not. By default, we refer to base model or don't normalize for new models."
89
+ },
90
+ )
91
+ dtype: Optional[str] = field(
92
+ default="float32",
93
+ metadata={
94
+ "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
95
+ },
96
+ )
97
+
98
+
99
+ @dataclass
100
+ class DataTrainingArguments:
101
+ """
102
+ Arguments pertaining to what data we are going to input our model for training and eval.
103
+ """
104
+
105
+ text_column: Optional[str] = field(
106
+ default="caption",
107
+ metadata={
108
+ "help": "The name of the column in the datasets containing the captions."
109
+ },
110
+ )
111
+ encoding_column: Optional[str] = field(
112
+ default="encoding",
113
+ metadata={
114
+ "help": "The name of the column in the datasets containing the image encodings."
115
+ },
116
+ )
117
+ dataset_repo_or_path: str = field(
118
+ default=None,
119
+ metadata={"help": "The dataset repository containing encoded files."},
120
+ )
121
+ train_file: Optional[str] = field(
122
+ default=None,
123
+ metadata={"help": "The input training data file (glob acceptable)."},
124
+ )
125
+ validation_file: Optional[str] = field(
126
+ default=None,
127
+ metadata={"help": "An optional input evaluation data file (glob acceptable)."},
128
+ )
129
+ dataset_type: str = field(
130
+ default="datasets",
131
+ metadata={"help": "Either 🤗 'datasets' (default) or 'webdataset'."},
132
+ )
133
+ # data loading should not be a bottleneck so we use "streaming" mode by default
134
+ streaming: bool = field(
135
+ default=True,
136
+ metadata={"help": "Whether to stream the dataset."},
137
+ )
138
+ use_auth_token: bool = field(
139
+ default=False,
140
+ metadata={
141
+ "help": "Whether to use the authentication token for private datasets."
142
+ },
143
+ )
144
+ max_source_length: Optional[int] = field(
145
+ default=128,
146
+ metadata={
147
+ "help": "The maximum total input sequence length after tokenization. Sequences longer "
148
+ "than this will be truncated, sequences shorter will be padded."
149
+ },
150
+ )
151
+ max_train_samples: Optional[int] = field(
152
+ default=None,
153
+ metadata={
154
+ "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
155
+ "value if set."
156
+ },
157
+ )
158
+ max_eval_samples: Optional[int] = field(
159
+ default=None,
160
+ metadata={
161
+ "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
162
+ "value if set."
163
+ },
164
+ )
165
+ preprocessing_num_workers: Optional[int] = field(
166
+ default=None,
167
+ metadata={
168
+ "help": "The number of processes to use for the preprocessing. Not used in streaming mode."
169
+ },
170
+ )
171
+ overwrite_cache: bool = field(
172
+ default=False,
173
+ metadata={
174
+ "help": "Overwrite the cached training and evaluation sets. Not used in streaming mode."
175
+ },
176
+ )
177
+ # default seed of None ensures we don't repeat the same items if script was interrupted during an epoch
178
+ seed_dataset: int = field(
179
+ default=None,
180
+ metadata={
181
+ "help": "Random seed for the dataset that will be set at the beginning of training."
182
+ },
183
+ )
184
+
185
+ def __post_init__(self):
186
+ if self.dataset_repo_or_path is None:
187
+ raise ValueError("Need a dataset repository or path.")
188
+
189
+
190
@dataclass
class TrainingArguments:
    """
    Arguments pertaining to training parameters.

    Every field is exposed as a command-line flag of the same name through
    `HfArgumentParser`; the `help` metadata is the text shown by `--help`.
    Only `output_dir` is required, all other fields have defaults.
    """

    # --- output location -------------------------------------------------
    output_dir: str = field(
        metadata={
            "help": "The output directory where the model predictions and checkpoints will be written."
        },
    )
    overwrite_output_dir: bool = field(
        default=False,
        metadata={
            "help": (
                "Overwrite the content of the output directory. "
                "Use this to continue training if output_dir points to a checkpoint directory."
            )
        },
    )

    # --- which phases to run ---------------------------------------------
    do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
    do_eval: bool = field(
        default=False, metadata={"help": "Whether to run eval on the dev set."}
    )

    # --- batch sizes (per device; multiplied by the device count by the caller)
    per_device_train_batch_size: int = field(
        default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for training."}
    )
    per_device_eval_batch_size: int = field(
        default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
    )

    gradient_accumulation_steps: int = field(
        default=1,
        metadata={
            "help": "Number of updates steps to accumulate before performing a backward/update pass."
        },
    )

    # --- optimizer ---------------------------------------------------------
    learning_rate: float = field(
        default=5e-5, metadata={"help": "The initial learning rate."}
    )
    adafactor: bool = field(
        default=False,
        metadata={"help": "Whether or not to replace AdamW by Adafactor."},
    )
    # The default is None ("no weight decay requested"), so the annotation must
    # be Optional; a bare `float` annotation contradicted the default value.
    weight_decay: Optional[float] = field(
        default=None, metadata={"help": "Weight decay if we apply some."}
    )
    adam_beta1: float = field(
        default=0.9, metadata={"help": "Beta1 for AdamW optimizer"}
    )
    adam_beta2: float = field(
        default=0.999, metadata={"help": "Beta2 for AdamW optimizer"}
    )
    adam_epsilon: float = field(
        default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."}
    )
    max_grad_norm: float = field(
        default=1.0, metadata={"help": "Max gradient norm for Adafactor."}
    )
    use_decay: bool = field(
        default=False,
        metadata={"help": "Whether to use decay in the learning rate scheduler."},
    )

    # --- schedule length ---------------------------------------------------
    num_train_epochs: float = field(
        default=3.0, metadata={"help": "Total number of training epochs to perform."}
    )
    warmup_steps: int = field(
        default=0, metadata={"help": "Linear warmup over warmup_steps."}
    )

    # --- logging / evaluation / checkpoint cadence (in update steps) ------
    logging_steps: int = field(
        default=40, metadata={"help": "Log every X updates steps."}
    )
    eval_steps: int = field(
        default=400, metadata={"help": "Run an evaluation every X steps."}
    )
    save_steps: int = field(
        default=4000, metadata={"help": "Save checkpoint every X updates steps."}
    )
    log_model: bool = field(
        default=False,
        metadata={"help": "Log model to wandb at `save_steps` frequency."},
    )

    # --- reproducibility ---------------------------------------------------
    seed_model: int = field(
        default=42,
        metadata={
            "help": "Random seed for the model that will be set at the beginning of training."
        },
    )

    # --- publishing / resuming ---------------------------------------------
    push_to_hub: bool = field(
        default=False,
        metadata={
            "help": "Whether or not to upload the trained model to the model hub after training."
        },
    )

    resume_from_checkpoint: Optional[str] = field(
        default=None,
        metadata={"help": "Reference to a wandb artifact for resuming training."},
    )
296
+
297
+
298
class TrainState(train_state.TrainState):
    """Train state extended with a dropout PRNG key and run bookkeeping."""

    dropout_rng: jnp.ndarray = None
    epoch: int = 0
    train_time: float = 0.0  # total time the model trained
    train_samples: int = 0  # number of samples seen

    def replicate(self):
        """Return a replicated copy of this state whose dropout key is sharded."""
        replicated_state = jax_utils.replicate(self)
        sharded_rng = shard_prng_key(self.dropout_rng)
        return replicated_state.replace(dropout_rng=sharded_rng)

    def restore_state(self, artifact_dir):
        """Return a copy of this state updated from the files in `artifact_dir`.

        Reads `opt_state.msgpack` (optimizer state, deserialized using the
        current optimizer state as the target structure) and
        `training_state.json` (step, train_time and train_samples).
        The `epoch` field is not restored.
        """
        checkpoint_dir = Path(artifact_dir)

        # optimizer state
        serialized_opt_state = (checkpoint_dir / "opt_state.msgpack").read_bytes()
        restored_opt_state = from_bytes(self.opt_state, serialized_opt_state)

        # scalar counters
        saved = json.loads((checkpoint_dir / "training_state.json").read_text())

        return self.replace(
            opt_state=restored_opt_state,
            step=saved["step"],
            train_time=saved["train_time"],
            train_samples=saved["train_samples"],
        )
325
+
326
+
327
def create_learning_rate_fn(
    num_warmup_steps: int,
    learning_rate: float,
    use_decay: bool,
    num_train_steps: int = None,  # used only with `use_decay`, typically train_size // batch_size * num_epochs
) -> Callable[[int], jnp.array]:
    """Build the learning rate schedule.

    The schedule ramps linearly from 0 to `learning_rate` over
    `num_warmup_steps`. When `use_decay` is set, a second linear segment from
    `learning_rate` down to 0 over the remaining
    `num_train_steps - num_warmup_steps` steps is joined after the warmup;
    otherwise only the warmup schedule is returned.
    """
    warmup_kwargs = dict(
        init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps
    )

    # Without decay the warmup schedule is the whole schedule.
    if not use_decay:
        return optax.linear_schedule(**warmup_kwargs)

    assert (
        num_train_steps is not None
    ), "Learning rate with decay requires number of training steps"

    ramp_up = optax.linear_schedule(**warmup_kwargs)
    ramp_down = optax.linear_schedule(
        init_value=learning_rate,
        end_value=0,
        transition_steps=num_train_steps - num_warmup_steps,
    )
    # Switch from warmup to decay once `num_warmup_steps` have elapsed.
    return optax.join_schedules(
        schedules=[ramp_up, ramp_down], boundaries=[num_warmup_steps]
    )
352
+
353
+
354
def wandb_log(metrics, step=None, prefix=None):
    """Log `metrics` to wandb from process 0 only.

    Keys are rewritten to `"{prefix}/{key}"` when `prefix` is given. When
    `step` is given it is logged alongside the metrics under `"train/step"`.
    """
    # Only the first JAX process reports, other processes do nothing.
    if jax.process_index() != 0:
        return

    if prefix is None:
        payload = {name: value for name, value in metrics.items()}
    else:
        payload = {f"{prefix}/{name}": value for name, value in metrics.items()}

    if step is not None:
        payload["train/step"] = step
    wandb.log(payload)
362
+
363
+
364
def main():
    """Parse arguments, build model/tokenizer/dataset and run the training loop.

    Arguments come either from the command line or from a single JSON file
    (when the only argument ends with `.json`). The run is tracked in wandb;
    when `resume_from_checkpoint` is set, model, tokenizer and optimizer state
    are loaded from that wandb artifact. Evaluation and checkpointing happen
    every `eval_steps` / `save_steps` steps and at the end of every epoch.

    Raises:
        ValueError: if the output directory exists, is not empty, training is
            requested and `overwrite_output_dir` is not set.
    """
    # See all possible arguments by passing the --help flag to this script.
    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, TrainingArguments)
    )
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(
            json_file=os.path.abspath(sys.argv[1])
        )
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. "
            "Use --overwrite_output_dir to overcome."
        )

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Setup logging, we only want one process per machine to log things on the screen.
    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
    if jax.process_index() == 0:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # Set the verbosity to info of the Transformers logger (on main process only):
    logger.info(f"Training/evaluation parameters {training_args}")

    # Load dataset
    dataset = Dataset(
        **asdict(data_args),
        do_train=training_args.do_train,
        do_eval=training_args.do_eval,
    )

    # Set up wandb run
    # The config is built from the dataclasses that were actually parsed.
    # Re-parsing sys.argv with `parser.parse_args()` aborts when the arguments
    # were supplied through a JSON file (required flags are then missing).
    wandb.init(
        entity="dalle-mini",
        project="dalle-mini",
        job_type="Seq2Seq",
        config={
            **asdict(model_args),
            **asdict(data_args),
            **asdict(training_args),
        },
    )

    if training_args.resume_from_checkpoint is not None:
        artifact = wandb.run.use_artifact(training_args.resume_from_checkpoint)
        artifact_dir = artifact.download()

        # load model
        model = CustomFlaxBartForConditionalGeneration.from_pretrained(artifact_dir)
        # avoid OOM on TPU: see https://github.com/google/flax/issues/1658
        print(model.params)

        # load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(
            artifact_dir,
            use_fast=True,
        )

    else:
        # Set up our new model config
        # TODO: simplify with custom config class
        if model_args.config_name:
            config = BartConfig.from_pretrained(model_args.config_name)
        else:
            config = BartConfig.from_pretrained(model_args.model_name_or_path)
        if model_args.image_vocab_size:
            config.image_vocab_size = model_args.image_vocab_size
        # default of None so a missing attribute triggers the assertion message
        # below instead of an AttributeError
        assert (
            getattr(config, "image_vocab_size", None) is not None
        ), "image_vocab_size must be specified when not present in base model/config"
        if model_args.image_length:
            config.image_length = model_args.image_length
        assert (
            getattr(config, "image_length", None) is not None
        ), "image_length must be specified when not present in base model/config"
        # we append decoder bos to image vocab
        config.decoder_start_token_id = config.image_vocab_size
        # ensure we don't generate bos (in addition to decoder start token)
        config.force_bos_token_to_be_generated = False
        config.forced_bos_token_id = None  # we don't need this token
        config.forced_eos_token_id = None  # we don't need this token

        config.tie_word_embeddings = False
        config.min_length = config.image_length + 1
        config.max_length = config.image_length + 1

        # below tokens need to be set to avoid error during generation (converted to jnp.array)
        # they are not expected to be used and are set to unreachable token id
        config.bos_token_id = config.image_vocab_size + 1
        # NOTE(review): `pos_token_id` looks like a typo for `pad_token_id`;
        # left unchanged because it alters the saved config - confirm intent.
        config.pos_token_id = config.image_vocab_size + 1
        config.eos_token_id = config.image_vocab_size + 1

        # save whether we normalize the text
        if model_args.normalize_text is not None:
            config.normalize_text = model_args.normalize_text
        else:
            config.normalize_text = getattr(config, "normalize_text", False)

        # Load or create new model
        if model_args.model_name_or_path:
            model = CustomFlaxBartForConditionalGeneration.from_pretrained(
                model_args.model_name_or_path,
                config=config,
                seed=training_args.seed_model,
                dtype=getattr(jnp, model_args.dtype),
            )
            # avoid OOM on TPU: see https://github.com/google/flax/issues/1658
            print(model.params)
        else:
            model = CustomFlaxBartForConditionalGeneration(
                config,
                seed=training_args.seed_model,
                dtype=getattr(jnp, model_args.dtype),
            )

        # Load tokenizer
        if model_args.tokenizer_name is not None:
            tokenizer = AutoTokenizer.from_pretrained(
                model_args.tokenizer_name, use_fast=True
            )
        else:
            tokenizer = AutoTokenizer.from_pretrained(
                model_args.model_name_or_path,
                use_fast=True,
            )

    logger.info(f"TPUs: {jax.device_count()}")
    assert jax.device_count() == 8, "TPUs in use, please check running processes"

    # Preprocessing the datasets.
    # We need to normalize and tokenize inputs and targets.

    dataset.preprocess(
        tokenizer=tokenizer,
        decoder_start_token_id=model.config.decoder_start_token_id,
        normalize_text=model.config.normalize_text,
    )

    # Initialize our training
    rng = jax.random.PRNGKey(training_args.seed_model)
    rng, dropout_rng = jax.random.split(rng)

    # Store some constant
    num_epochs = int(training_args.num_train_epochs)
    train_batch_size = (
        int(training_args.per_device_train_batch_size) * jax.device_count()
    )
    batch_size_per_update = train_batch_size * training_args.gradient_accumulation_steps
    eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
    # lengths may be None (handled below), in which case step counts are unknown
    len_train_dataset, len_eval_dataset = dataset.length
    steps_per_epoch = (
        len_train_dataset // train_batch_size if len_train_dataset is not None else None
    )
    num_train_steps = (
        steps_per_epoch * num_epochs if steps_per_epoch is not None else None
    )

    # Create learning rate schedule
    learning_rate_fn = create_learning_rate_fn(
        training_args.warmup_steps,
        training_args.learning_rate,
        training_args.use_decay,
        num_train_steps,
    )

    # We use Optax's "masking" functionality to not apply weight decay
    # to bias and LayerNorm scale parameters. decay_mask_fn returns a
    # mask boolean with the same structure as the parameters.
    # The mask is True for parameters that should be decayed.
    # Note that this mask is specifically adapted for FlaxBart.
    def decay_mask_fn(params):
        flat_params = traverse_util.flatten_dict(params)
        layer_norm_params = [
            (name, "scale")
            for name in [
                "self_attn_layer_norm",
                "layernorm_embedding",
                "final_layer_norm",
            ]
        ]
        flat_mask = {
            path: (path[-1] != "bias" and path[-2:] not in layer_norm_params)
            for path in flat_params
        }
        return traverse_util.unflatten_dict(flat_mask)

    # create optimizer (Adafactor or AdamW)
    if training_args.adafactor:
        # We use the default parameters here to initialize adafactor,
        # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
        optimizer = optax.adafactor(
            learning_rate=learning_rate_fn,
            weight_decay_rate=training_args.weight_decay,
            weight_decay_mask=decay_mask_fn,
            clipping_threshold=training_args.max_grad_norm,
        )
    else:
        # adamw multiplies the parameters by `weight_decay`, so the default of
        # None (no decay requested) must be passed as 0.0
        adamw_weight_decay = (
            training_args.weight_decay
            if training_args.weight_decay is not None
            else 0.0
        )
        optimizer = optax.adamw(
            learning_rate=learning_rate_fn,
            b1=training_args.adam_beta1,
            b2=training_args.adam_beta2,
            eps=training_args.adam_epsilon,
            weight_decay=adamw_weight_decay,
            mask=decay_mask_fn,
        )

    # add gradient accumulation
    if training_args.gradient_accumulation_steps > 1:
        optimizer = optax.chain(
            optax.apply_every(training_args.gradient_accumulation_steps), optimizer
        )

    # Setup train state
    state = TrainState.create(
        apply_fn=model.__call__,
        params=model.params,
        tx=optimizer,
        dropout_rng=dropout_rng,
    )
    if training_args.resume_from_checkpoint is not None:
        # restore optimizer state and other parameters
        # we currently ignore partial epoch training: see https://github.com/borisdayma/dalle-mini/issues/105
        state = state.restore_state(artifact_dir)

    # mean softmax cross entropy against one-hot labels (no label smoothing is applied)
    def loss_fn(logits, labels):
        loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1]))
        loss = loss.mean()
        return loss

    # Define gradient update step fn
    def train_step(state, batch, delta_time):
        dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)

        def compute_loss(params, batch):
            labels = batch.pop("labels")
            logits = state.apply_fn(
                **batch, params=params, dropout_rng=dropout_rng, train=True
            )[0]
            loss = loss_fn(logits, labels)
            return loss

        grad_fn = jax.value_and_grad(compute_loss)
        loss, grads = grad_fn(state.params, batch)
        # average gradients over the "batch" pmap axis
        grads = jax.lax.pmean(grads, "batch")
        state = state.apply_gradients(
            grads=grads,
            dropout_rng=new_dropout_rng,
            train_time=state.train_time + delta_time,
            train_samples=state.train_samples + train_batch_size,
        )

        metrics = {
            "loss": loss,
            "learning_rate": learning_rate_fn(state.step),
        }
        metrics = jax.lax.pmean(metrics, axis_name="batch")

        return state, metrics

    # Define eval fn
    def eval_step(params, batch):
        labels = batch.pop("labels")
        logits = model(**batch, params=params, train=False)[0]
        loss = loss_fn(logits, labels)

        # summarize metrics
        metrics = {"loss": loss}
        metrics = jax.lax.pmean(metrics, axis_name="batch")
        return metrics

    # Create parallel version of the train and eval step
    p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
    p_eval_step = jax.pmap(eval_step, "batch")

    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len_train_dataset}")
    logger.info(f" Num Epochs = {num_epochs}")
    logger.info(
        f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}"
    )
    logger.info(
        f" Total train batch size (w. parallel, distributed & gradient accumulation) = {batch_size_per_update}"
    )
    epochs = tqdm(
        range(state.epoch, num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0
    )

    # set default x-axis as 'train/step'
    wandb_log({}, step=state.step)
    wandb.define_metric("*", step_metric="train/step")

    # add interesting config parameters
    wandb.config.update(
        {
            "len_train_dataset": len_train_dataset,
            "len_eval_dataset": len_eval_dataset,
            "batch_size_per_update": batch_size_per_update,
        }
    )

    # replicate state on each device
    state = state.replicate()

    def run_evaluation():
        # ======================== Evaluating ==============================
        # reads `state` and `epoch` from the enclosing scope at call time
        eval_metrics = []
        if training_args.do_eval:
            eval_loader = dataset.dataloader("eval", eval_batch_size)
            eval_steps = (
                len_eval_dataset // eval_batch_size
                if len_eval_dataset is not None
                else None
            )
            for batch in tqdm(
                eval_loader,
                desc="Evaluating...",
                position=2,
                leave=False,
                total=eval_steps,
            ):
                # Model forward
                metrics = p_eval_step(state.params, batch)
                eval_metrics.append(metrics)

            # normalize eval metrics
            eval_metrics = get_metrics(eval_metrics)
            eval_metrics = jax.tree_map(jnp.mean, eval_metrics)

            # log metrics
            wandb_log(eval_metrics, step=unreplicate(state.step), prefix="eval")

            # Print metrics and update progress bar
            desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
            epochs.write(desc)
            epochs.desc = desc

        return eval_metrics

    def run_save_model(state, eval_metrics=None):
        # only the first process writes checkpoints
        if jax.process_index() == 0:
            params = jax.device_get(unreplicate(state.params))
            # save model locally
            model.save_pretrained(
                training_args.output_dir,
                params=params,
            )

            # save tokenizer
            tokenizer.save_pretrained(training_args.output_dir)

            # save state
            opt_state = unreplicate(state.opt_state)
            with (Path(training_args.output_dir) / "opt_state.msgpack").open("wb") as f:
                f.write(to_bytes(opt_state))
            state_dict = {
                k: jax.device_get(unreplicate(getattr(state, k))).item()
                for k in ["step", "epoch", "train_time", "train_samples"]
            }
            with (Path(training_args.output_dir) / "training_state.json").open(
                "w"
            ) as f:
                json.dump(
                    state_dict,
                    f,
                )

            # save to W&B
            if training_args.log_model:
                # save some space
                c = wandb.wandb_sdk.wandb_artifacts.get_artifacts_cache()
                c.cleanup(wandb.util.from_human_size("10GB"))

                metadata = dict(state_dict)
                if eval_metrics is not None:
                    metadata["eval"] = eval_metrics
                artifact = wandb.Artifact(
                    name=f"model-{wandb.run.id}", type="bart_model", metadata=metadata
                )
                artifact.add_file(
                    str(Path(training_args.output_dir) / "flax_model.msgpack")
                )
                artifact.add_file(str(Path(training_args.output_dir) / "config.json"))
                artifact.add_file(
                    str(Path(training_args.output_dir) / "tokenizer.json")
                )
                artifact.add_file(
                    str(Path(training_args.output_dir) / "tokenizer_config.json")
                )
                artifact.add_file(str(Path(training_args.output_dir) / "vocab.json"))
                artifact.add_file(str(Path(training_args.output_dir) / "merges.txt"))
                artifact.add_file(
                    str(Path(training_args.output_dir) / "special_tokens_map.json")
                )
                artifact.add_file(
                    str(Path(training_args.output_dir) / "opt_state.msgpack")
                )
                artifact.add_file(
                    str(Path(training_args.output_dir) / "training_state.json")
                )

                wandb.run.log_artifact(artifact)

            # save to the hub
            if training_args.push_to_hub:
                model.save_pretrained(
                    training_args.output_dir,
                    params=params,
                    push_to_hub=training_args.push_to_hub,
                    commit_message=f"Saving weights and logs at step {unreplicate(state.step)+1}",
                    temp_dir=True,  # avoid issues with being in a repository
                )

    # init variables
    last_time = time.perf_counter()
    train_metrics = None

    for epoch in epochs:
        # `replace` returns a new state (the state object is not modified in
        # place), so the result must be assigned for the epoch to be recorded
        state = state.replace(epoch=jax_utils.replicate(epoch))
        # ======================== Training ================================
        wandb_log({"train/epoch": epoch}, step=unreplicate(state.step))

        # Generate an epoch by shuffling sampling indices from the train dataset
        train_loader = dataset.dataloader("train", train_batch_size)
        # train
        for batch in tqdm(
            train_loader,
            desc="Training...",
            position=1,
            leave=False,
            total=steps_per_epoch,
        ):

            # calculate delta time (we have a lag of one step but it's ok)
            new_time = time.perf_counter()
            delta_time = new_time - last_time
            last_time = new_time

            # train step
            state, train_metrics = p_train_step(
                state, batch, jax_utils.replicate(delta_time)
            )
            step = unreplicate(state.step)

            if step % training_args.logging_steps == 0 and jax.process_index() == 0:
                # log metrics
                metrics = unreplicate(train_metrics)
                # log state parameters
                state_dict = {
                    k.split("_")[-1]: unreplicate(getattr(state, k))
                    for k in ["epoch", "train_time", "train_samples"]
                }
                wandb_log({**metrics, **state_dict}, step=step, prefix="train")

            eval_metrics = None
            if training_args.eval_steps and step % training_args.eval_steps == 0:
                eval_metrics = run_evaluation()

            if step % training_args.save_steps == 0:
                run_save_model(state, eval_metrics)

        # log final train metrics
        if train_metrics is not None:
            train_metrics = unreplicate(train_metrics)
            wandb_log(train_metrics, step=step, prefix="train")

            epochs.write(
                f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metrics['loss']}, Learning Rate: {train_metrics['learning_rate']})"
            )

        # Final evaluation
        eval_metrics = run_evaluation()

        # save checkpoint after each epoch
        run_save_model(state, eval_metrics)
855
+
856
# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()