boris committed
Commit a2dcee4
Parent: 179282e

fix: support smelu

Files changed (2):
  1. README.md +89 -78
  2. src/dalle_mini/model/modeling.py +25 -1
README.md CHANGED
@@ -133,127 +133,138 @@ Main optimizer (Distributed Shampoo) from "[Scalable Second Order Optimization f
 ### Citations
 
 ```text
-@misc{ramesh2021zeroshot,
-title={Zero-Shot Text-to-Image Generation},
-author={Aditya Ramesh and Mikhail Pavlov and Gabriel Goh and Scott Gray and Chelsea Voss and Alec Radford and Mark Chen and Ilya Sutskever},
-year={2021},
-eprint={2102.12092},
-archivePrefix={arXiv},
-primaryClass={cs.CV}
+@misc{
+title={Zero-Shot Text-to-Image Generation},
+author={Aditya Ramesh and Mikhail Pavlov and Gabriel Goh and Scott Gray and Chelsea Voss and Alec Radford and Mark Chen and Ilya Sutskever},
+year={2021},
+eprint={2102.12092},
+archivePrefix={arXiv},
+primaryClass={cs.CV}
 }
 ```
 
 ```text
-@misc{radford2021learning,
-title={Learning Transferable Visual Models From Natural Language Supervision},
-author={Alec Radford and Jong Wook Kim and Chris Hallacy and Aditya Ramesh and Gabriel Goh and Sandhini Agarwal and Girish Sastry and Amanda Askell and Pamela Mishkin and Jack Clark and Gretchen Krueger and Ilya Sutskever},
-year={2021},
-eprint={2103.00020},
-archivePrefix={arXiv},
-primaryClass={cs.CV}
+@misc{
+title={Learning Transferable Visual Models From Natural Language Supervision},
+author={Alec Radford and Jong Wook Kim and Chris Hallacy and Aditya Ramesh and Gabriel Goh and Sandhini Agarwal and Girish Sastry and Amanda Askell and Pamela Mishkin and Jack Clark and Gretchen Krueger and Ilya Sutskever},
+year={2021},
+eprint={2103.00020},
+archivePrefix={arXiv},
+primaryClass={cs.CV}
 }
 ```
 
 ```text
-@misc{esser2021taming,
-title={Taming Transformers for High-Resolution Image Synthesis},
-author={Patrick Esser and Robin Rombach and Björn Ommer},
-year={2021},
-eprint={2012.09841},
-archivePrefix={arXiv},
-primaryClass={cs.CV}
+@misc{
+title={Taming Transformers for High-Resolution Image Synthesis},
+author={Patrick Esser and Robin Rombach and Björn Ommer},
+year={2021},
+eprint={2012.09841},
+archivePrefix={arXiv},
+primaryClass={cs.CV}
 }
 ```
 
 ```text
-@misc{lewis2019bart,
-title={BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension},
-author={Mike Lewis and Yinhan Liu and Naman Goyal and Marjan Ghazvininejad and Abdelrahman Mohamed and Omer Levy and Ves Stoyanov and Luke Zettlemoyer},
-year={2019},
-eprint={1910.13461},
-archivePrefix={arXiv},
-primaryClass={cs.CL}
+@misc{
+title={BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension},
+author={Mike Lewis and Yinhan Liu and Naman Goyal and Marjan Ghazvininejad and Abdelrahman Mohamed and Omer Levy and Ves Stoyanov and Luke Zettlemoyer},
+year={2019},
+eprint={1910.13461},
+archivePrefix={arXiv},
+primaryClass={cs.CL}
 }
 ```
 
 ```text
-@misc{anil2021scalable,
-title={Scalable Second Order Optimization for Deep Learning},
-author={Rohan Anil and Vineet Gupta and Tomer Koren and Kevin Regan and Yoram Singer},
-year={2021},
-eprint={2002.09018},
-archivePrefix={arXiv},
-primaryClass={cs.LG}
+@misc{
+title={Scalable Second Order Optimization for Deep Learning},
+author={Rohan Anil and Vineet Gupta and Tomer Koren and Kevin Regan and Yoram Singer},
+year={2021},
+eprint={2002.09018},
+archivePrefix={arXiv},
+primaryClass={cs.LG}
 }
 ```
 
 ```text
-@misc{shazeer2020glu,
-title={GLU Variants Improve Transformer},
-author={Noam Shazeer},
-year={2020},
-url={https://arxiv.org/abs/2002.05202}
+@misc{
+title={GLU Variants Improve Transformer},
+author={Noam Shazeer},
+year={2020},
+url={https://arxiv.org/abs/2002.05202}
 }
 ```
 
 ```text
-@misc{wang_ma_dong_huang_zhang_wei_2022,
-title={DeepNet: Scaling transformers to 1,000 layers},
-author={Wang, Hongyu and Ma, Shuming and Dong, Li and Huang, Shaohan and Zhang, Dongdong and Wei, Furu},
-year={2022},
-eprint={2203.00555}
-archivePrefix={arXiv},
-primaryClass={cs.LG}
+@misc{
+title={DeepNet: Scaling transformers to 1,000 layers},
+author={Wang, Hongyu and Ma, Shuming and Dong, Li and Huang, Shaohan and Zhang, Dongdong and Wei, Furu},
+year={2022},
+eprint={2203.00555}
+archivePrefix={arXiv},
+primaryClass={cs.LG}
 }
 ```
 
 ```text
-@misc{shleifer2021normformer,
-title={NormFormer: Improved Transformer Pretraining with Extra Normalization},
-author={Sam Shleifer and Jason Weston and Myle Ott},
-year={2021},
-eprint={2110.09456},
-archivePrefix={arXiv},
-primaryClass={cs.CL}
+@misc{
+title={NormFormer: Improved Transformer Pretraining with Extra Normalization},
+author={Sam Shleifer and Jason Weston and Myle Ott},
+year={2021},
+eprint={2110.09456},
+archivePrefix={arXiv},
+primaryClass={cs.CL}
 }
 ```
 
 ```text
-@inproceedings{liu2021swinv2,
-title={Swin Transformer V2: Scaling Up Capacity and Resolution},
-author={Ze Liu and Han Hu and Yutong Lin and Zhuliang Yao and Zhenda Xie and Yixuan Wei and Jia Ning and Yue Cao and Zheng Zhang and Li Dong and Furu Wei and Baining Guo},
-booktitle={International Conference on Computer Vision and Pattern Recognition (CVPR)},
-year={2022}
+@inproceedings{
+title={Swin Transformer V2: Scaling Up Capacity and Resolution},
+author={Ze Liu and Han Hu and Yutong Lin and Zhuliang Yao and Zhenda Xie and Yixuan Wei and Jia Ning and Yue Cao and Zheng Zhang and Li Dong and Furu Wei and Baining Guo},
+booktitle={International Conference on Computer Vision and Pattern Recognition (CVPR)},
+year={2022}
 }
 ```
 
 ```text
-@misc{ding2021cogview,
-title = {CogView: Mastering Text-to-Image Generation via Transformers},
-author = {Ming Ding and Zhuoyi Yang and Wenyi Hong and Wendi Zheng and Chang Zhou and Da Yin and Junyang Lin and Xu Zou and Zhou Shao and Hongxia Yang and Jie Tang},
-year = {2021},
-eprint = {2105.13290},
-archivePrefix = {arXiv},
-primaryClass = {cs.CV}
+@misc{
+title = {CogView: Mastering Text-to-Image Generation via Transformers},
+author = {Ming Ding and Zhuoyi Yang and Wenyi Hong and Wendi Zheng and Chang Zhou and Da Yin and Junyang Lin and Xu Zou and Zhou Shao and Hongxia Yang and Jie Tang},
+year = {2021},
+eprint = {2105.13290},
+archivePrefix = {arXiv},
+primaryClass = {cs.CV}
 }
 ```
 
 ```text
-@misc{zhang2019root,
-title = {Root Mean Square Layer Normalization},
-author = {Biao Zhang and Rico Sennrich},
-year = {2019},
-eprint = {1910.07467},
-archivePrefix = {arXiv},
-primaryClass = {cs.LG}
+@misc{
+title = {Root Mean Square Layer Normalization},
+author = {Biao Zhang and Rico Sennrich},
+year = {2019},
+eprint = {1910.07467},
+archivePrefix = {arXiv},
+primaryClass = {cs.LG}
 }
 ```
 
 ```text
-@misc{title = {Sinkformers: Transformers with Doubly Stochastic Attention},
-url = {https://arxiv.org/abs/2110.11773},
-author = {Sander, Michael E. and Ablin, Pierre and Blondel, Mathieu and Peyré, Gabriel},
-publisher = {arXiv},
-year = {2021},
+@misc{
+title = {Sinkformers: Transformers with Doubly Stochastic Attention},
+url = {https://arxiv.org/abs/2110.11773},
+author = {Sander, Michael E. and Ablin, Pierre and Blondel, Mathieu and Peyré, Gabriel},
+publisher = {arXiv},
+year = {2021},
+}
+```
+
+```text
+@misc{
+title = {Smooth activations and reproducibility in deep networks},
+url = {https://arxiv.org/abs/2010.09931},
+author = {Shamir, Gil I. and Lin, Dong and Coviello, Lorenzo},
+publisher = {arXiv},
+year = {2020},
 }
 ```
src/dalle_mini/model/modeling.py CHANGED
@@ -32,7 +32,7 @@ from flax.linen import partitioning as nn_partitioning
 from flax.linen.linear import PrecisionLike
 from flax.serialization import from_bytes
 from flax.traverse_util import flatten_dict, unflatten_dict
-from jax import lax
+from jax import custom_jvp, lax
 from jax.random import PRNGKey
 from transformers.configuration_utils import PretrainedConfig
 from transformers.file_utils import (
@@ -68,6 +68,30 @@ logger = logging.get_logger(__name__)
 remat = nn_partitioning.remat
 
 
+def smelu(beta: Any = 1.0):
+    """
+    Implementation of "Real World Large Scale Recommendation Systems Reproducibility and Smooth Activations"
+    https://arxiv.org/abs/2202.06499
+    """
+
+    @custom_jvp
+    @jax.jit
+    def _smelu(x: Any) -> Any:
+        x = jnp.where(x <= -beta, 0.0, x)
+        return jnp.where(x >= beta, x, jnp.square(x + beta) / (4 * beta))
+
+    _smelu.defjvps(
+        lambda g, ans, x: lax.select(
+            x == -beta,
+            lax.full_like(g, 0),
+            lax.select(x == beta, lax.full_like(g, 1), g),
+        )
+    )
+    return _smelu
+
+
+ACT2FN.update({"smelu": smelu})
+
 # deepnet initialization
 def deepnet_init(gain=1):
     init = jax.nn.initializers.glorot_normal()
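
For context, SmeLU ("Smooth ReLU") replaces the kink of ReLU with a quadratic blend: the activation is 0 below `-beta`, `(x + beta)**2 / (4 * beta)` on `[-beta, beta]`, and the identity above `beta`, so both the value and the first derivative are continuous. The `custom_jvp`/`defjvps` pair in the committed code overrides the tangent exactly at the two seam points `x == ±beta` and passes it through unchanged elsewhere. Below is a minimal, self-contained sketch of the paper's piecewise definition written with plain nested `jnp.where`; the name `smelu_reference` and the probe values are illustrative only, not part of the commit:

```python
import jax
import jax.numpy as jnp


def smelu_reference(beta: float = 1.0):
    """Sketch of SmeLU (https://arxiv.org/abs/2010.09931): 0 below -beta,
    (x + beta)**2 / (4 * beta) on [-beta, beta], identity above beta."""

    def f(x):
        # Nested selects spell out the three pieces of the definition.
        return jnp.where(
            x <= -beta,
            0.0,
            jnp.where(x >= beta, x, jnp.square(x + beta) / (4 * beta)),
        )

    return f


act = smelu_reference(beta=1.0)
xs = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])
print(act(xs))                      # [0.   0.   0.25 1.   2.  ]
print(jax.vmap(jax.grad(act))(xs))  # [0.  0.  0.5 1.  1. ]
```

With this sketch, plain autodiff already yields the piecewise derivative (0 below `-beta`, `(x + beta) / (2 * beta)` in the quadratic region, 1 above `beta`), including the values 0 and 1 at the seams that the committed `defjvps` pins explicitly. The `ACT2FN.update(...)` line registers the factory in the transformers activation table, presumably so a model config can select the activation by name like the existing entries in that mapping.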