boris committed
Commit cfb03b3
2 Parent(s): d8a2192 127e230

Merge pull request #60 from borisdayma/refactor/vqgan-jax
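This refactor drops the vendored copy under dalle_mini/vqgan_jax and imports the model from the standalone vqgan-jax package (pinned in dev/requirements.txt below). A minimal sketch of the new import path follows; the checkpoint id is illustrative and not part of this commit:

# Sketch only: assumes vqgan-jax is installed, e.g.
#   pip install git+https://github.com/patil-suraj/vqgan-jax.git
from vqgan_jax.modeling_flax_vqgan import VQModel  # previously dalle_mini.vqgan_jax.modeling_flax_vqgan

# Hypothetical checkpoint id; use whichever VQGAN weights your setup provides.
vqgan = VQModel.from_pretrained("flax-community/vqgan_f16_16384")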

.gitignore CHANGED
@@ -1 +1,3 @@
 __pycache__
+.ipynb_checkpoints
+.streamlit
app/app_gradio.py CHANGED
@@ -19,7 +19,7 @@ import numpy as np
 import matplotlib.pyplot as plt


-from dalle_mini.vqgan_jax.modeling_flax_vqgan import VQModel
+from vqgan_jax.modeling_flax_vqgan import VQModel
 from dalle_mini.model import CustomFlaxBartForConditionalGeneration

 import gradio as gr
dalle_mini/vqgan_jax/README.md DELETED
@@ -1,5 +0,0 @@
-## vqgan-jax
-
-Files copied from [patil-suraj/vqgan-jax](https://github.com/patil-suraj/vqgan-jax/tree/main/vqgan_jax)
-
-Required for VQGAN Jax model.
dalle_mini/vqgan_jax/__init__.py DELETED
File without changes
dalle_mini/vqgan_jax/configuration_vqgan.py DELETED
@@ -1,40 +0,0 @@
-from typing import Tuple
-
-from transformers import PretrainedConfig
-
-
-class VQGANConfig(PretrainedConfig):
-    def __init__(
-        self,
-        ch: int = 128,
-        out_ch: int = 3,
-        in_channels: int = 3,
-        num_res_blocks: int = 2,
-        resolution: int = 256,
-        z_channels: int = 256,
-        ch_mult: Tuple = (1, 1, 2, 2, 4),
-        attn_resolutions: int = (16,),
-        n_embed: int = 1024,
-        embed_dim: int = 256,
-        dropout: float = 0.0,
-        double_z: bool = False,
-        resamp_with_conv: bool = True,
-        give_pre_end: bool = False,
-        **kwargs,
-    ):
-        super().__init__(**kwargs)
-        self.ch = ch
-        self.out_ch = out_ch
-        self.in_channels = in_channels
-        self.num_res_blocks = num_res_blocks
-        self.resolution = resolution
-        self.z_channels = z_channels
-        self.ch_mult = list(ch_mult)
-        self.attn_resolutions = list(attn_resolutions)
-        self.n_embed = n_embed
-        self.embed_dim = embed_dim
-        self.dropout = dropout
-        self.double_z = double_z
-        self.resamp_with_conv = resamp_with_conv
-        self.give_pre_end = give_pre_end
-        self.num_resolutions = len(ch_mult)
dalle_mini/vqgan_jax/modeling_flax_vqgan.py DELETED
@@ -1,609 +0,0 @@
1
- # JAX implementation of VQGAN from taming-transformers https://github.com/CompVis/taming-transformers
2
-
3
- from functools import partial
4
- from typing import Tuple
5
- import math
6
-
7
- import jax
8
- import jax.numpy as jnp
9
- import numpy as np
10
- import flax.linen as nn
11
- from flax.core.frozen_dict import FrozenDict
12
-
13
- from transformers.modeling_flax_utils import FlaxPreTrainedModel
14
-
15
- from .configuration_vqgan import VQGANConfig
16
-
17
-
18
- class Upsample(nn.Module):
19
- in_channels: int
20
- with_conv: bool
21
- dtype: jnp.dtype = jnp.float32
22
-
23
- def setup(self):
24
- if self.with_conv:
25
- self.conv = nn.Conv(
26
- self.in_channels,
27
- kernel_size=(3, 3),
28
- strides=(1, 1),
29
- padding=((1, 1), (1, 1)),
30
- dtype=self.dtype,
31
- )
32
-
33
- def __call__(self, hidden_states):
34
- batch, height, width, channels = hidden_states.shape
35
- hidden_states = jax.image.resize(
36
- hidden_states,
37
- shape=(batch, height * 2, width * 2, channels),
38
- method="nearest",
39
- )
40
- if self.with_conv:
41
- hidden_states = self.conv(hidden_states)
42
- return hidden_states
43
-
44
-
45
- class Downsample(nn.Module):
46
- in_channels: int
47
- with_conv: bool
48
- dtype: jnp.dtype = jnp.float32
49
-
50
- def setup(self):
51
- if self.with_conv:
52
- self.conv = nn.Conv(
53
- self.in_channels,
54
- kernel_size=(3, 3),
55
- strides=(2, 2),
56
- padding="VALID",
57
- dtype=self.dtype,
58
- )
59
-
60
- def __call__(self, hidden_states):
61
- if self.with_conv:
62
- pad = ((0, 0), (0, 1), (0, 1), (0, 0)) # pad height and width dim
63
- hidden_states = jnp.pad(hidden_states, pad_width=pad)
64
- hidden_states = self.conv(hidden_states)
65
- else:
66
- hidden_states = nn.avg_pool(hidden_states, window_shape=(2, 2), strides=(2, 2), padding="VALID")
67
- return hidden_states
68
-
69
-
70
- class ResnetBlock(nn.Module):
71
- in_channels: int
72
- out_channels: int = None
73
- use_conv_shortcut: bool = False
74
- temb_channels: int = 512
75
- dropout_prob: float = 0.0
76
- dtype: jnp.dtype = jnp.float32
77
-
78
- def setup(self):
79
- self.out_channels_ = self.in_channels if self.out_channels is None else self.out_channels
80
-
81
- self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-6)
82
- self.conv1 = nn.Conv(
83
- self.out_channels_,
84
- kernel_size=(3, 3),
85
- strides=(1, 1),
86
- padding=((1, 1), (1, 1)),
87
- dtype=self.dtype,
88
- )
89
-
90
- if self.temb_channels:
91
- self.temb_proj = nn.Dense(self.out_channels_, dtype=self.dtype)
92
-
93
- self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-6)
94
- self.dropout = nn.Dropout(self.dropout_prob)
95
- self.conv2 = nn.Conv(
96
- self.out_channels_,
97
- kernel_size=(3, 3),
98
- strides=(1, 1),
99
- padding=((1, 1), (1, 1)),
100
- dtype=self.dtype,
101
- )
102
-
103
- if self.in_channels != self.out_channels_:
104
- if self.use_conv_shortcut:
105
- self.conv_shortcut = nn.Conv(
106
- self.out_channels_,
107
- kernel_size=(3, 3),
108
- strides=(1, 1),
109
- padding=((1, 1), (1, 1)),
110
- dtype=self.dtype,
111
- )
112
- else:
113
- self.nin_shortcut = nn.Conv(
114
- self.out_channels_,
115
- kernel_size=(1, 1),
116
- strides=(1, 1),
117
- padding="VALID",
118
- dtype=self.dtype,
119
- )
120
-
121
- def __call__(self, hidden_states, temb=None, deterministic: bool = True):
122
- residual = hidden_states
123
- hidden_states = self.norm1(hidden_states)
124
- hidden_states = nn.swish(hidden_states)
125
- hidden_states = self.conv1(hidden_states)
126
-
127
- if temb is not None:
128
- hidden_states = hidden_states + self.temb_proj(nn.swish(temb))[:, :, None, None] # TODO: check shapes
129
-
130
- hidden_states = self.norm2(hidden_states)
131
- hidden_states = nn.swish(hidden_states)
132
- hidden_states = self.dropout(hidden_states, deterministic)
133
- hidden_states = self.conv2(hidden_states)
134
-
135
- if self.in_channels != self.out_channels_:
136
- if self.use_conv_shortcut:
137
- residual = self.conv_shortcut(residual)
138
- else:
139
- residual = self.nin_shortcut(residual)
140
-
141
- return hidden_states + residual
142
-
143
-
144
- class AttnBlock(nn.Module):
145
- in_channels: int
146
- dtype: jnp.dtype = jnp.float32
147
-
148
- def setup(self):
149
- conv = partial(
150
- nn.Conv, self.in_channels, kernel_size=(1, 1), strides=(1, 1), padding="VALID", dtype=self.dtype
151
- )
152
-
153
- self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-6)
154
- self.q, self.k, self.v = conv(), conv(), conv()
155
- self.proj_out = conv()
156
-
157
- def __call__(self, hidden_states):
158
- residual = hidden_states
159
- hidden_states = self.norm(hidden_states)
160
-
161
- query = self.q(hidden_states)
162
- key = self.k(hidden_states)
163
- value = self.v(hidden_states)
164
-
165
- # compute attentions
166
- batch, height, width, channels = query.shape
167
- query = query.reshape((batch, height * width, channels))
168
- key = key.reshape((batch, height * width, channels))
169
- attn_weights = jnp.einsum("...qc,...kc->...qk", query, key)
170
- attn_weights = attn_weights * (int(channels) ** -0.5)
171
- attn_weights = nn.softmax(attn_weights, axis=2)
172
-
173
- ## attend to values
174
- value = value.reshape((batch, height * width, channels))
175
- hidden_states = jnp.einsum("...kc,...qk->...qc", value, attn_weights)
176
- hidden_states = hidden_states.reshape((batch, height, width, channels))
177
-
178
- hidden_states = self.proj_out(hidden_states)
179
- hidden_states = hidden_states + residual
180
- return hidden_states
181
-
182
-
183
- class UpsamplingBlock(nn.Module):
184
- config: VQGANConfig
185
- curr_res: int
186
- block_idx: int
187
- dtype: jnp.dtype = jnp.float32
188
-
189
- def setup(self):
190
- if self.block_idx == self.config.num_resolutions - 1:
191
- block_in = self.config.ch * self.config.ch_mult[-1]
192
- else:
193
- block_in = self.config.ch * self.config.ch_mult[self.block_idx + 1]
194
-
195
- block_out = self.config.ch * self.config.ch_mult[self.block_idx]
196
- self.temb_ch = 0
197
-
198
- res_blocks = []
199
- attn_blocks = []
200
- for _ in range(self.config.num_res_blocks + 1):
201
- res_blocks.append(
202
- ResnetBlock(
203
- block_in, block_out, temb_channels=self.temb_ch, dropout_prob=self.config.dropout, dtype=self.dtype
204
- )
205
- )
206
- block_in = block_out
207
- if self.curr_res in self.config.attn_resolutions:
208
- attn_blocks.append(AttnBlock(block_in, dtype=self.dtype))
209
-
210
- self.block = res_blocks
211
- self.attn = attn_blocks
212
-
213
- self.upsample = None
214
- if self.block_idx != 0:
215
- self.upsample = Upsample(block_in, self.config.resamp_with_conv, dtype=self.dtype)
216
-
217
- def __call__(self, hidden_states, temb=None, deterministic: bool = True):
218
- for res_block in self.block:
219
- hidden_states = res_block(hidden_states, temb, deterministic=deterministic)
220
- for attn_block in self.attn:
221
- hidden_states = attn_block(hidden_states)
222
-
223
- if self.upsample is not None:
224
- hidden_states = self.upsample(hidden_states)
225
-
226
- return hidden_states
227
-
228
-
229
- class DownsamplingBlock(nn.Module):
230
- config: VQGANConfig
231
- curr_res: int
232
- block_idx: int
233
- dtype: jnp.dtype = jnp.float32
234
-
235
- def setup(self):
236
- in_ch_mult = (1,) + tuple(self.config.ch_mult)
237
- block_in = self.config.ch * in_ch_mult[self.block_idx]
238
- block_out = self.config.ch * self.config.ch_mult[self.block_idx]
239
- self.temb_ch = 0
240
-
241
- res_blocks = []
242
- attn_blocks = []
243
- for _ in range(self.config.num_res_blocks):
244
- res_blocks.append(
245
- ResnetBlock(
246
- block_in, block_out, temb_channels=self.temb_ch, dropout_prob=self.config.dropout, dtype=self.dtype
247
- )
248
- )
249
- block_in = block_out
250
- if self.curr_res in self.config.attn_resolutions:
251
- attn_blocks.append(AttnBlock(block_in, dtype=self.dtype))
252
-
253
- self.block = res_blocks
254
- self.attn = attn_blocks
255
-
256
- self.downsample = None
257
- if self.block_idx != self.config.num_resolutions - 1:
258
- self.downsample = Downsample(block_in, self.config.resamp_with_conv, dtype=self.dtype)
259
-
260
- def __call__(self, hidden_states, temb=None, deterministic: bool = True):
261
- for res_block in self.block:
262
- hidden_states = res_block(hidden_states, temb, deterministic=deterministic)
263
- for attn_block in self.attn:
264
- hidden_states = attn_block(hidden_states)
265
-
266
- if self.downsample is not None:
267
- hidden_states = self.downsample(hidden_states)
268
-
269
- return hidden_states
270
-
271
-
272
- class MidBlock(nn.Module):
273
- in_channels: int
274
- temb_channels: int
275
- dropout: float
276
- dtype: jnp.dtype = jnp.float32
277
-
278
- def setup(self):
279
- self.block_1 = ResnetBlock(
280
- self.in_channels,
281
- self.in_channels,
282
- temb_channels=self.temb_channels,
283
- dropout_prob=self.dropout,
284
- dtype=self.dtype,
285
- )
286
- self.attn_1 = AttnBlock(self.in_channels, dtype=self.dtype)
287
- self.block_2 = ResnetBlock(
288
- self.in_channels,
289
- self.in_channels,
290
- temb_channels=self.temb_channels,
291
- dropout_prob=self.dropout,
292
- dtype=self.dtype,
293
- )
294
-
295
- def __call__(self, hidden_states, temb=None, deterministic: bool = True):
296
- hidden_states = self.block_1(hidden_states, temb, deterministic=deterministic)
297
- hidden_states = self.attn_1(hidden_states)
298
- hidden_states = self.block_2(hidden_states, temb, deterministic=deterministic)
299
- return hidden_states
300
-
301
-
302
- class Encoder(nn.Module):
303
- config: VQGANConfig
304
- dtype: jnp.dtype = jnp.float32
305
-
306
- def setup(self):
307
- self.temb_ch = 0
308
-
309
- # downsampling
310
- self.conv_in = nn.Conv(
311
- self.config.ch,
312
- kernel_size=(3, 3),
313
- strides=(1, 1),
314
- padding=((1, 1), (1, 1)),
315
- dtype=self.dtype,
316
- )
317
-
318
- curr_res = self.config.resolution
319
- downsample_blocks = []
320
- for i_level in range(self.config.num_resolutions):
321
- downsample_blocks.append(DownsamplingBlock(self.config, curr_res, block_idx=i_level, dtype=self.dtype))
322
-
323
- if i_level != self.config.num_resolutions - 1:
324
- curr_res = curr_res // 2
325
- self.down = downsample_blocks
326
-
327
- # middle
328
- mid_channels = self.config.ch * self.config.ch_mult[-1]
329
- self.mid = MidBlock(mid_channels, self.temb_ch, self.config.dropout, dtype=self.dtype)
330
-
331
- # end
332
- self.norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-6)
333
- self.conv_out = nn.Conv(
334
- 2 * self.config.z_channels if self.config.double_z else self.config.z_channels,
335
- kernel_size=(3, 3),
336
- strides=(1, 1),
337
- padding=((1, 1), (1, 1)),
338
- dtype=self.dtype,
339
- )
340
-
341
- def __call__(self, pixel_values, deterministic: bool = True):
342
- # timestep embedding
343
- temb = None
344
-
345
- # downsampling
346
- hidden_states = self.conv_in(pixel_values)
347
- for block in self.down:
348
- hidden_states = block(hidden_states, temb, deterministic=deterministic)
349
-
350
- # middle
351
- hidden_states = self.mid(hidden_states, temb, deterministic=deterministic)
352
-
353
- # end
354
- hidden_states = self.norm_out(hidden_states)
355
- hidden_states = nn.swish(hidden_states)
356
- hidden_states = self.conv_out(hidden_states)
357
-
358
- return hidden_states
359
-
360
-
361
- class Decoder(nn.Module):
362
- config: VQGANConfig
363
- dtype: jnp.dtype = jnp.float32
364
-
365
- def setup(self):
366
- self.temb_ch = 0
367
-
368
- # compute in_ch_mult, block_in and curr_res at lowest res
369
- block_in = self.config.ch * self.config.ch_mult[self.config.num_resolutions - 1]
370
- curr_res = self.config.resolution // 2 ** (self.config.num_resolutions - 1)
371
- self.z_shape = (1, self.config.z_channels, curr_res, curr_res)
372
- print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
373
-
374
- # z to block_in
375
- self.conv_in = nn.Conv(
376
- block_in,
377
- kernel_size=(3, 3),
378
- strides=(1, 1),
379
- padding=((1, 1), (1, 1)),
380
- dtype=self.dtype,
381
- )
382
-
383
- # middle
384
- self.mid = MidBlock(block_in, self.temb_ch, self.config.dropout, dtype=self.dtype)
385
-
386
- # upsampling
387
- upsample_blocks = []
388
- for i_level in reversed(range(self.config.num_resolutions)):
389
- upsample_blocks.append(UpsamplingBlock(self.config, curr_res, block_idx=i_level, dtype=self.dtype))
390
- if i_level != 0:
391
- curr_res = curr_res * 2
392
- self.up = list(reversed(upsample_blocks)) # reverse to get consistent order
393
-
394
- # end
395
- self.norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-6)
396
- self.conv_out = nn.Conv(
397
- self.config.out_ch,
398
- kernel_size=(3, 3),
399
- strides=(1, 1),
400
- padding=((1, 1), (1, 1)),
401
- dtype=self.dtype,
402
- )
403
-
404
- def __call__(self, hidden_states, deterministic: bool = True):
405
- # timestep embedding
406
- temb = None
407
-
408
- # z to block_in
409
- hidden_states = self.conv_in(hidden_states)
410
-
411
- # middle
412
- hidden_states = self.mid(hidden_states, temb, deterministic=deterministic)
413
-
414
- # upsampling
415
- for block in reversed(self.up):
416
- hidden_states = block(hidden_states, temb, deterministic=deterministic)
417
-
418
- # end
419
- if self.config.give_pre_end:
420
- return hidden_states
421
-
422
- hidden_states = self.norm_out(hidden_states)
423
- hidden_states = nn.swish(hidden_states)
424
- hidden_states = self.conv_out(hidden_states)
425
-
426
- return hidden_states
427
-
428
-
429
- class VectorQuantizer(nn.Module):
430
- """
431
- see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
432
- ____________________________________________
433
- Discretization bottleneck part of the VQ-VAE.
434
- Inputs:
435
- - n_e : number of embeddings
436
- - e_dim : dimension of embedding
437
- - beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
438
- _____________________________________________
439
- """
440
-
441
- config: VQGANConfig
442
- dtype: jnp.dtype = jnp.float32
443
-
444
- def setup(self):
445
- self.embedding = nn.Embed(self.config.n_embed, self.config.embed_dim, dtype=self.dtype) # TODO: init
446
-
447
- def __call__(self, hidden_states):
448
- """
449
- Inputs the output of the encoder network z and maps it to a discrete
450
- one-hot vector that is the index of the closest embedding vector e_j
451
- z (continuous) -> z_q (discrete)
452
- z.shape = (batch, channel, height, width)
453
- quantization pipeline:
454
- 1. get encoder input (B,C,H,W)
455
- 2. flatten input to (B*H*W,C)
456
- """
457
- # flatten
458
- hidden_states_flattended = hidden_states.reshape((-1, self.config.embed_dim))
459
-
460
- # dummy op to init the weights, so we can access them below
461
- self.embedding(jnp.ones((1, 1), dtype="i4"))
462
-
463
- # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
464
- emb_weights = self.variables["params"]["embedding"]["embedding"]
465
- distance = (
466
- jnp.sum(hidden_states_flattended ** 2, axis=1, keepdims=True)
467
- + jnp.sum(emb_weights ** 2, axis=1)
468
- - 2 * jnp.dot(hidden_states_flattended, emb_weights.T)
469
- )
470
-
471
- # get quantized latent vectors
472
- min_encoding_indices = jnp.argmin(distance, axis=1)
473
- z_q = self.embedding(min_encoding_indices).reshape(hidden_states.shape)
474
-
475
- # reshape to (batch, num_tokens)
476
- min_encoding_indices = min_encoding_indices.reshape(hidden_states.shape[0], -1)
477
-
478
- # compute the codebook_loss (q_loss) outside the model
479
- # here we return the embeddings and indices
480
- return z_q, min_encoding_indices
481
-
482
- def get_codebook_entry(self, indices, shape=None):
483
- # indices are expected to be of shape (batch, num_tokens)
484
- # get quantized latent vectors
485
- batch, num_tokens = indices.shape
486
- z_q = self.embedding(indices)
487
- z_q = z_q.reshape(batch, int(math.sqrt(num_tokens)), int(math.sqrt(num_tokens)), -1)
488
- return z_q
489
-
490
-
491
- class VQModule(nn.Module):
492
- config: VQGANConfig
493
- dtype: jnp.dtype = jnp.float32
494
-
495
- def setup(self):
496
- self.encoder = Encoder(self.config, dtype=self.dtype)
497
- self.decoder = Decoder(self.config, dtype=self.dtype)
498
- self.quantize = VectorQuantizer(self.config, dtype=self.dtype)
499
- self.quant_conv = nn.Conv(
500
- self.config.embed_dim,
501
- kernel_size=(1, 1),
502
- strides=(1, 1),
503
- padding="VALID",
504
- dtype=self.dtype,
505
- )
506
- self.post_quant_conv = nn.Conv(
507
- self.config.z_channels,
508
- kernel_size=(1, 1),
509
- strides=(1, 1),
510
- padding="VALID",
511
- dtype=self.dtype,
512
- )
513
-
514
- def encode(self, pixel_values, deterministic: bool = True):
515
- hidden_states = self.encoder(pixel_values, deterministic=deterministic)
516
- hidden_states = self.quant_conv(hidden_states)
517
- quant_states, indices = self.quantize(hidden_states)
518
- return quant_states, indices
519
-
520
- def decode(self, hidden_states, deterministic: bool = True):
521
- hidden_states = self.post_quant_conv(hidden_states)
522
- hidden_states = self.decoder(hidden_states, deterministic=deterministic)
523
- return hidden_states
524
-
525
- def decode_code(self, code_b):
526
- hidden_states = self.quantize.get_codebook_entry(code_b)
527
- hidden_states = self.decode(hidden_states)
528
- return hidden_states
529
-
530
- def __call__(self, pixel_values, deterministic: bool = True):
531
- quant_states, indices = self.encode(pixel_values, deterministic)
532
- hidden_states = self.decode(quant_states, deterministic)
533
- return hidden_states, indices
534
-
535
-
536
- class VQGANPreTrainedModel(FlaxPreTrainedModel):
537
- """
538
- An abstract class to handle weights initialization and a simple interface
539
- for downloading and loading pretrained models.
540
- """
541
-
542
- config_class = VQGANConfig
543
- base_model_prefix = "model"
544
- module_class: nn.Module = None
545
-
546
- def __init__(
547
- self,
548
- config: VQGANConfig,
549
- input_shape: Tuple = (1, 256, 256, 3),
550
- seed: int = 0,
551
- dtype: jnp.dtype = jnp.float32,
552
- **kwargs,
553
- ):
554
- module = self.module_class(config=config, dtype=dtype, **kwargs)
555
- super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
556
-
557
- def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
558
- # init input tensors
559
- pixel_values = jnp.zeros(input_shape, dtype=jnp.float32)
560
- params_rng, dropout_rng = jax.random.split(rng)
561
- rngs = {"params": params_rng, "dropout": dropout_rng}
562
-
563
- return self.module.init(rngs, pixel_values)["params"]
564
-
565
- def encode(self, pixel_values, params: dict = None, dropout_rng: jax.random.PRNGKey = None, train: bool = False):
566
- # Handle any PRNG if needed
567
- rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
568
-
569
- return self.module.apply(
570
- {"params": params or self.params}, jnp.array(pixel_values), not train, rngs=rngs, method=self.module.encode
571
- )
572
-
573
- def decode(self, hidden_states, params: dict = None, dropout_rng: jax.random.PRNGKey = None, train: bool = False):
574
- # Handle any PRNG if needed
575
- rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
576
-
577
- return self.module.apply(
578
- {"params": params or self.params},
579
- jnp.array(hidden_states),
580
- not train,
581
- rngs=rngs,
582
- method=self.module.decode,
583
- )
584
-
585
- def decode_code(self, indices, params: dict = None):
586
- return self.module.apply(
587
- {"params": params or self.params}, jnp.array(indices, dtype="i4"), method=self.module.decode_code
588
- )
589
-
590
- def __call__(
591
- self,
592
- pixel_values,
593
- params: dict = None,
594
- dropout_rng: jax.random.PRNGKey = None,
595
- train: bool = False,
596
- ):
597
- # Handle any PRNG if needed
598
- rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
599
-
600
- return self.module.apply(
601
- {"params": params or self.params},
602
- jnp.array(pixel_values),
603
- not train,
604
- rngs=rngs,
605
- )
606
-
607
-
608
- class VQModel(VQGANPreTrainedModel):
609
- module_class = VQModule
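The module removed above (Encoder, Decoder, VectorQuantizer, VQModel) is the same implementation now consumed from patil-suraj/vqgan-jax. As a rough round-trip sketch of how the model is used downstream, assuming the default 256x256 input and an illustrative checkpoint id (neither is taken from this diff):

import jax.numpy as jnp
from vqgan_jax.modeling_flax_vqgan import VQModel

vqgan = VQModel.from_pretrained("flax-community/vqgan_f16_16384")  # hypothetical checkpoint id

# Dummy NHWC batch matching the default input_shape (1, 256, 256, 3);
# real images are normally preprocessed/normalized before encoding.
images = jnp.zeros((1, 256, 256, 3), dtype=jnp.float32)

quant_states, indices = vqgan.encode(images)  # indices: (batch, num_tokens) codebook ids
reconstruction = vqgan.decode_code(indices)   # decoded back to (batch, 256, 256, 3)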
dev/notebooks/demo/model-sweep.py CHANGED
@@ -11,19 +11,15 @@ from flax.jax_utils import replicate, unreplicate
 from transformers.models.bart.modeling_flax_bart import *
 from transformers import BartTokenizer, FlaxBartForConditionalGeneration

-import io
-
-import requests
 from PIL import Image
 import numpy as np
 import matplotlib.pyplot as plt

-import torch
 import torchvision.transforms as T
 import torchvision.transforms.functional as TF
 from torchvision.transforms import InterpolationMode

-from dalle_mini.vqgan_jax.modeling_flax_vqgan import VQModel
+from vqgan_jax.modeling_flax_vqgan import VQModel

 # TODO: set those args in a config file
 OUTPUT_VOCAB_SIZE = 16384 + 1  # encoded image token space + 1 for bos
dev/notebooks/demo/tpu-demo.ipynb CHANGED
@@ -51,14 +51,6 @@
     "jax.devices()"
    ]
   },
-  {
-   "cell_type": "markdown",
-   "id": "d408065c",
-   "metadata": {},
-   "source": [
-    "`dalle_mini` is a local package that contains the VQGAN-JAX model by Suraj, and other utilities. You can also `cd` to the directory that contains your checkout of [`vqgan-jax`](https://github.com/patil-suraj/vqgan-jax.git)"
-   ]
-  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -66,8 +58,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "from dalle_mini.vqgan_jax.modeling_flax_vqgan import VQModel\n",
-    "#%cd /content/vqgan-jax"
+    "from vqgan_jax.modeling_flax_vqgan import VQModel"
    ]
   },
   {
@@ -447,7 +438,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.8.5"
+   "version": "3.8.10"
   }
  },
 "nbformat": 4,
dev/notebooks/encoding/vqgan-jax-encoding-with-captions.ipynb CHANGED
@@ -50,14 +50,6 @@
    "## VQGAN-JAX model"
   ]
  },
- {
-  "cell_type": "markdown",
-  "id": "bb408f6c",
-  "metadata": {},
-  "source": [
-   "`dalle_mini` is a local package that contains the VQGAN-JAX model and other utilities."
-  ]
- },
 {
  "cell_type": "code",
  "execution_count": 2,
@@ -65,7 +57,7 @@
   "metadata": {},
   "outputs": [],
   "source": [
-   "from dalle_mini.vqgan_jax.modeling_flax_vqgan import VQModel"
+   "from vqgan_jax.modeling_flax_vqgan import VQModel"
  ]
 },
 {
dev/notebooks/encoding/vqgan-jax-encoding-yfcc100m.ipynb CHANGED
@@ -52,14 +52,6 @@
    "## VQGAN-JAX model"
   ]
  },
- {
-  "cell_type": "markdown",
-  "id": "bb408f6c",
-  "metadata": {},
-  "source": [
-   "`dalle_mini` is a local package that contains the VQGAN-JAX model and other utilities."
-  ]
- },
 {
  "cell_type": "code",
  "execution_count": 93,
@@ -67,7 +59,7 @@
   "metadata": {},
   "outputs": [],
   "source": [
-   "from dalle_mini.vqgan_jax.modeling_flax_vqgan import VQModel"
+   "from vqgan_jax.modeling_flax_vqgan import VQModel"
  ]
 },
 {
@@ -1111,9 +1103,13 @@
   }
  ],
 "metadata": {
+ "interpreter": {
+  "hash": "db471c52d602b4f5f40ecaf278e88ccfef85c29d0a1a07185b0d51fc7acf4e26"
+ },
  "kernelspec": {
-  "name": "python3",
-  "display_name": "Python 3.9.0 64-bit ('Python39')"
+  "display_name": "Python 3 (ipykernel)",
+  "language": "python",
+  "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
@@ -1125,12 +1121,9 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-  "version": "3.9.0"
- },
- "interpreter": {
-  "hash": "db471c52d602b4f5f40ecaf278e88ccfef85c29d0a1a07185b0d51fc7acf4e26"
+  "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
-}
+}
dev/predictions/wandb-examples.py CHANGED
@@ -23,7 +23,7 @@ import torchvision.transforms as T
 import torchvision.transforms.functional as TF
 from torchvision.transforms import InterpolationMode

-from dalle_mini.vqgan_jax.modeling_flax_vqgan import VQModel
+from vqgan_jax.modeling_flax_vqgan import VQModel

 # TODO: set those args in a config file
 OUTPUT_VOCAB_SIZE = 16384 + 1  # encoded image token space + 1 for bos
dev/requirements.txt ADDED
@@ -0,0 +1,16 @@
+# Note: install with the following command:
+# pip install -r requirements.txt -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
+# Otherwise it won't find the appropriate libtpu_nightly
+requests
+jax[tpu]>=0.2.16
+-e git+https://github.com/huggingface/transformers.git@master#egg=transformers
+-e git+https://github.com/huggingface/datasets.git@master#egg=datasets
+flax
+jupyter
+wandb
+nltk
+optax
+git+https://github.com/patil-suraj/vqgan-jax.git@610d842dd33c739325a944102ed33acc07692dd5
+
+# Inference
+ftfy
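After installing with the command noted at the top of the file, a quick sanity check that the TPU-enabled jax build was picked up (a minimal sketch; the device list depends on your runtime):

import jax

# On a TPU VM this should list TpuDevice entries; on CPU-only hosts it falls back to CPU.
print(jax.devices())
print("local device count:", jax.local_device_count())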
examples/JAX_VQGAN_f16_16384_Reconstruction.ipynb ADDED
The diff for this file is too large to render. See raw diff