osanseviero
commited on
Commit
•
7cc474b
1
Parent(s):
149555d
Change requirements and add vqgan_jax clone
Browse files- requirements.txt +0 -1
- vqgan_jax/configuration_vqgan.py +40 -0
- vqgan_jax/modeling_flax_vqgan.py +666 -0
requirements.txt
CHANGED
@@ -1,3 +1,2 @@
|
|
1 |
transformers
|
2 |
flax
|
3 |
-
git+https://github.com/patil-suraj/vqgan-jax.git
|
|
|
1 |
transformers
|
2 |
flax
|
|
vqgan_jax/configuration_vqgan.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Tuple
|
2 |
+
|
3 |
+
from transformers import PretrainedConfig
|
4 |
+
|
5 |
+
|
6 |
+
class VQGANConfig(PretrainedConfig):
|
7 |
+
def __init__(
|
8 |
+
self,
|
9 |
+
ch: int = 128,
|
10 |
+
out_ch: int = 3,
|
11 |
+
in_channels: int = 3,
|
12 |
+
num_res_blocks: int = 2,
|
13 |
+
resolution: int = 256,
|
14 |
+
z_channels: int = 256,
|
15 |
+
ch_mult: Tuple = (1, 1, 2, 2, 4),
|
16 |
+
attn_resolutions: int = (16,),
|
17 |
+
n_embed: int = 1024,
|
18 |
+
embed_dim: int = 256,
|
19 |
+
dropout: float = 0.0,
|
20 |
+
double_z: bool = False,
|
21 |
+
resamp_with_conv: bool = True,
|
22 |
+
give_pre_end: bool = False,
|
23 |
+
**kwargs,
|
24 |
+
):
|
25 |
+
super().__init__(**kwargs)
|
26 |
+
self.ch = ch
|
27 |
+
self.out_ch = out_ch
|
28 |
+
self.in_channels = in_channels
|
29 |
+
self.num_res_blocks = num_res_blocks
|
30 |
+
self.resolution = resolution
|
31 |
+
self.z_channels = z_channels
|
32 |
+
self.ch_mult = list(ch_mult)
|
33 |
+
self.attn_resolutions = list(attn_resolutions)
|
34 |
+
self.n_embed = n_embed
|
35 |
+
self.embed_dim = embed_dim
|
36 |
+
self.dropout = dropout
|
37 |
+
self.double_z = double_z
|
38 |
+
self.resamp_with_conv = resamp_with_conv
|
39 |
+
self.give_pre_end = give_pre_end
|
40 |
+
self.num_resolutions = len(ch_mult)
|
vqgan_jax/modeling_flax_vqgan.py
ADDED
@@ -0,0 +1,666 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# JAX implementation of VQGAN from taming-transformers https://github.com/CompVis/taming-transformers
|
2 |
+
|
3 |
+
from functools import partial
|
4 |
+
from typing import Tuple
|
5 |
+
import math
|
6 |
+
|
7 |
+
import jax
|
8 |
+
import jax.numpy as jnp
|
9 |
+
import numpy as np
|
10 |
+
import flax.linen as nn
|
11 |
+
from flax.core.frozen_dict import FrozenDict
|
12 |
+
|
13 |
+
from transformers.modeling_flax_utils import FlaxPreTrainedModel
|
14 |
+
|
15 |
+
from .configuration_vqgan import VQGANConfig
|
16 |
+
|
17 |
+
|
18 |
+
class Upsample(nn.Module):
|
19 |
+
in_channels: int
|
20 |
+
with_conv: bool
|
21 |
+
dtype: jnp.dtype = jnp.float32
|
22 |
+
|
23 |
+
def setup(self):
|
24 |
+
if self.with_conv:
|
25 |
+
self.conv = nn.Conv(
|
26 |
+
self.in_channels,
|
27 |
+
kernel_size=(3, 3),
|
28 |
+
strides=(1, 1),
|
29 |
+
padding=((1, 1), (1, 1)),
|
30 |
+
dtype=self.dtype,
|
31 |
+
)
|
32 |
+
|
33 |
+
def __call__(self, hidden_states):
|
34 |
+
batch, height, width, channels = hidden_states.shape
|
35 |
+
hidden_states = jax.image.resize(
|
36 |
+
hidden_states,
|
37 |
+
shape=(batch, height * 2, width * 2, channels),
|
38 |
+
method="nearest",
|
39 |
+
)
|
40 |
+
if self.with_conv:
|
41 |
+
hidden_states = self.conv(hidden_states)
|
42 |
+
return hidden_states
|
43 |
+
|
44 |
+
|
45 |
+
class Downsample(nn.Module):
|
46 |
+
in_channels: int
|
47 |
+
with_conv: bool
|
48 |
+
dtype: jnp.dtype = jnp.float32
|
49 |
+
|
50 |
+
def setup(self):
|
51 |
+
if self.with_conv:
|
52 |
+
self.conv = nn.Conv(
|
53 |
+
self.in_channels,
|
54 |
+
kernel_size=(3, 3),
|
55 |
+
strides=(2, 2),
|
56 |
+
padding="VALID",
|
57 |
+
dtype=self.dtype,
|
58 |
+
)
|
59 |
+
|
60 |
+
def __call__(self, hidden_states):
|
61 |
+
if self.with_conv:
|
62 |
+
pad = ((0, 0), (0, 1), (0, 1), (0, 0)) # pad height and width dim
|
63 |
+
hidden_states = jnp.pad(hidden_states, pad_width=pad)
|
64 |
+
hidden_states = self.conv(hidden_states)
|
65 |
+
else:
|
66 |
+
hidden_states = nn.avg_pool(hidden_states,
|
67 |
+
window_shape=(2, 2),
|
68 |
+
strides=(2, 2),
|
69 |
+
padding="VALID")
|
70 |
+
return hidden_states
|
71 |
+
|
72 |
+
|
73 |
+
class ResnetBlock(nn.Module):
|
74 |
+
in_channels: int
|
75 |
+
out_channels: int = None
|
76 |
+
use_conv_shortcut: bool = False
|
77 |
+
temb_channels: int = 512
|
78 |
+
dropout_prob: float = 0.0
|
79 |
+
dtype: jnp.dtype = jnp.float32
|
80 |
+
|
81 |
+
def setup(self):
|
82 |
+
self.out_channels_ = self.in_channels if self.out_channels is None else self.out_channels
|
83 |
+
|
84 |
+
self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-6)
|
85 |
+
self.conv1 = nn.Conv(
|
86 |
+
self.out_channels_,
|
87 |
+
kernel_size=(3, 3),
|
88 |
+
strides=(1, 1),
|
89 |
+
padding=((1, 1), (1, 1)),
|
90 |
+
dtype=self.dtype,
|
91 |
+
)
|
92 |
+
|
93 |
+
if self.temb_channels:
|
94 |
+
self.temb_proj = nn.Dense(self.out_channels_, dtype=self.dtype)
|
95 |
+
|
96 |
+
self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-6)
|
97 |
+
self.dropout = nn.Dropout(self.dropout_prob)
|
98 |
+
self.conv2 = nn.Conv(
|
99 |
+
self.out_channels_,
|
100 |
+
kernel_size=(3, 3),
|
101 |
+
strides=(1, 1),
|
102 |
+
padding=((1, 1), (1, 1)),
|
103 |
+
dtype=self.dtype,
|
104 |
+
)
|
105 |
+
|
106 |
+
if self.in_channels != self.out_channels_:
|
107 |
+
if self.use_conv_shortcut:
|
108 |
+
self.conv_shortcut = nn.Conv(
|
109 |
+
self.out_channels_,
|
110 |
+
kernel_size=(3, 3),
|
111 |
+
strides=(1, 1),
|
112 |
+
padding=((1, 1), (1, 1)),
|
113 |
+
dtype=self.dtype,
|
114 |
+
)
|
115 |
+
else:
|
116 |
+
self.nin_shortcut = nn.Conv(
|
117 |
+
self.out_channels_,
|
118 |
+
kernel_size=(1, 1),
|
119 |
+
strides=(1, 1),
|
120 |
+
padding="VALID",
|
121 |
+
dtype=self.dtype,
|
122 |
+
)
|
123 |
+
|
124 |
+
def __call__(self, hidden_states, temb=None, deterministic: bool = True):
|
125 |
+
residual = hidden_states
|
126 |
+
hidden_states = self.norm1(hidden_states)
|
127 |
+
hidden_states = nn.swish(hidden_states)
|
128 |
+
hidden_states = self.conv1(hidden_states)
|
129 |
+
|
130 |
+
if temb is not None:
|
131 |
+
hidden_states = hidden_states + self.temb_proj(
|
132 |
+
nn.swish(temb))[:, :, None, None] # TODO: check shapes
|
133 |
+
|
134 |
+
hidden_states = self.norm2(hidden_states)
|
135 |
+
hidden_states = nn.swish(hidden_states)
|
136 |
+
hidden_states = self.dropout(hidden_states, deterministic)
|
137 |
+
hidden_states = self.conv2(hidden_states)
|
138 |
+
|
139 |
+
if self.in_channels != self.out_channels_:
|
140 |
+
if self.use_conv_shortcut:
|
141 |
+
residual = self.conv_shortcut(residual)
|
142 |
+
else:
|
143 |
+
residual = self.nin_shortcut(residual)
|
144 |
+
|
145 |
+
return hidden_states + residual
|
146 |
+
|
147 |
+
|
148 |
+
class AttnBlock(nn.Module):
|
149 |
+
in_channels: int
|
150 |
+
dtype: jnp.dtype = jnp.float32
|
151 |
+
|
152 |
+
def setup(self):
|
153 |
+
conv = partial(nn.Conv,
|
154 |
+
self.in_channels,
|
155 |
+
kernel_size=(1, 1),
|
156 |
+
strides=(1, 1),
|
157 |
+
padding="VALID",
|
158 |
+
dtype=self.dtype)
|
159 |
+
|
160 |
+
self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-6)
|
161 |
+
self.q, self.k, self.v = conv(), conv(), conv()
|
162 |
+
self.proj_out = conv()
|
163 |
+
|
164 |
+
def __call__(self, hidden_states):
|
165 |
+
residual = hidden_states
|
166 |
+
hidden_states = self.norm(hidden_states)
|
167 |
+
|
168 |
+
query = self.q(hidden_states)
|
169 |
+
key = self.k(hidden_states)
|
170 |
+
value = self.v(hidden_states)
|
171 |
+
|
172 |
+
# compute attentions
|
173 |
+
batch, height, width, channels = query.shape
|
174 |
+
query = query.reshape((batch, height * width, channels))
|
175 |
+
key = key.reshape((batch, height * width, channels))
|
176 |
+
attn_weights = jnp.einsum("...qc,...kc->...qk", query, key)
|
177 |
+
attn_weights = attn_weights * (int(channels)**-0.5)
|
178 |
+
attn_weights = nn.softmax(attn_weights, axis=2)
|
179 |
+
|
180 |
+
## attend to values
|
181 |
+
value = value.reshape((batch, height * width, channels))
|
182 |
+
hidden_states = jnp.einsum("...kc,...qk->...qc", value, attn_weights)
|
183 |
+
hidden_states = hidden_states.reshape((batch, height, width, channels))
|
184 |
+
|
185 |
+
hidden_states = self.proj_out(hidden_states)
|
186 |
+
hidden_states = hidden_states + residual
|
187 |
+
return hidden_states
|
188 |
+
|
189 |
+
|
190 |
+
class UpsamplingBlock(nn.Module):
|
191 |
+
config: VQGANConfig
|
192 |
+
curr_res: int
|
193 |
+
block_idx: int
|
194 |
+
dtype: jnp.dtype = jnp.float32
|
195 |
+
|
196 |
+
def setup(self):
|
197 |
+
if self.block_idx == self.config.num_resolutions - 1:
|
198 |
+
block_in = self.config.ch * self.config.ch_mult[-1]
|
199 |
+
else:
|
200 |
+
block_in = self.config.ch * self.config.ch_mult[self.block_idx + 1]
|
201 |
+
|
202 |
+
block_out = self.config.ch * self.config.ch_mult[self.block_idx]
|
203 |
+
self.temb_ch = 0
|
204 |
+
|
205 |
+
res_blocks = []
|
206 |
+
attn_blocks = []
|
207 |
+
for _ in range(self.config.num_res_blocks + 1):
|
208 |
+
res_blocks.append(
|
209 |
+
ResnetBlock(block_in,
|
210 |
+
block_out,
|
211 |
+
temb_channels=self.temb_ch,
|
212 |
+
dropout_prob=self.config.dropout,
|
213 |
+
dtype=self.dtype))
|
214 |
+
block_in = block_out
|
215 |
+
if self.curr_res in self.config.attn_resolutions:
|
216 |
+
attn_blocks.append(AttnBlock(block_in, dtype=self.dtype))
|
217 |
+
|
218 |
+
self.block = res_blocks
|
219 |
+
self.attn = attn_blocks
|
220 |
+
|
221 |
+
self.upsample = None
|
222 |
+
if self.block_idx != 0:
|
223 |
+
self.upsample = Upsample(block_in,
|
224 |
+
self.config.resamp_with_conv,
|
225 |
+
dtype=self.dtype)
|
226 |
+
|
227 |
+
def __call__(self, hidden_states, temb=None, deterministic: bool = True):
|
228 |
+
for res_block in self.block:
|
229 |
+
hidden_states = res_block(hidden_states,
|
230 |
+
temb,
|
231 |
+
deterministic=deterministic)
|
232 |
+
for attn_block in self.attn:
|
233 |
+
hidden_states = attn_block(hidden_states)
|
234 |
+
|
235 |
+
if self.upsample is not None:
|
236 |
+
hidden_states = self.upsample(hidden_states)
|
237 |
+
|
238 |
+
return hidden_states
|
239 |
+
|
240 |
+
|
241 |
+
class DownsamplingBlock(nn.Module):
|
242 |
+
config: VQGANConfig
|
243 |
+
curr_res: int
|
244 |
+
block_idx: int
|
245 |
+
dtype: jnp.dtype = jnp.float32
|
246 |
+
|
247 |
+
def setup(self):
|
248 |
+
in_ch_mult = (1, ) + tuple(self.config.ch_mult)
|
249 |
+
block_in = self.config.ch * in_ch_mult[self.block_idx]
|
250 |
+
block_out = self.config.ch * self.config.ch_mult[self.block_idx]
|
251 |
+
self.temb_ch = 0
|
252 |
+
|
253 |
+
res_blocks = []
|
254 |
+
attn_blocks = []
|
255 |
+
for _ in range(self.config.num_res_blocks):
|
256 |
+
res_blocks.append(
|
257 |
+
ResnetBlock(block_in,
|
258 |
+
block_out,
|
259 |
+
temb_channels=self.temb_ch,
|
260 |
+
dropout_prob=self.config.dropout,
|
261 |
+
dtype=self.dtype))
|
262 |
+
block_in = block_out
|
263 |
+
if self.curr_res in self.config.attn_resolutions:
|
264 |
+
attn_blocks.append(AttnBlock(block_in, dtype=self.dtype))
|
265 |
+
|
266 |
+
self.block = res_blocks
|
267 |
+
self.attn = attn_blocks
|
268 |
+
|
269 |
+
self.downsample = None
|
270 |
+
if self.block_idx != self.config.num_resolutions - 1:
|
271 |
+
self.downsample = Downsample(block_in,
|
272 |
+
self.config.resamp_with_conv,
|
273 |
+
dtype=self.dtype)
|
274 |
+
|
275 |
+
def __call__(self, hidden_states, temb=None, deterministic: bool = True):
|
276 |
+
for res_block in self.block:
|
277 |
+
hidden_states = res_block(hidden_states,
|
278 |
+
temb,
|
279 |
+
deterministic=deterministic)
|
280 |
+
for attn_block in self.attn:
|
281 |
+
hidden_states = attn_block(hidden_states)
|
282 |
+
|
283 |
+
if self.downsample is not None:
|
284 |
+
hidden_states = self.downsample(hidden_states)
|
285 |
+
|
286 |
+
return hidden_states
|
287 |
+
|
288 |
+
|
289 |
+
class MidBlock(nn.Module):
|
290 |
+
in_channels: int
|
291 |
+
temb_channels: int
|
292 |
+
dropout: float
|
293 |
+
dtype: jnp.dtype = jnp.float32
|
294 |
+
|
295 |
+
def setup(self):
|
296 |
+
self.block_1 = ResnetBlock(
|
297 |
+
self.in_channels,
|
298 |
+
self.in_channels,
|
299 |
+
temb_channels=self.temb_channels,
|
300 |
+
dropout_prob=self.dropout,
|
301 |
+
dtype=self.dtype,
|
302 |
+
)
|
303 |
+
self.attn_1 = AttnBlock(self.in_channels, dtype=self.dtype)
|
304 |
+
self.block_2 = ResnetBlock(
|
305 |
+
self.in_channels,
|
306 |
+
self.in_channels,
|
307 |
+
temb_channels=self.temb_channels,
|
308 |
+
dropout_prob=self.dropout,
|
309 |
+
dtype=self.dtype,
|
310 |
+
)
|
311 |
+
|
312 |
+
def __call__(self, hidden_states, temb=None, deterministic: bool = True):
|
313 |
+
hidden_states = self.block_1(hidden_states,
|
314 |
+
temb,
|
315 |
+
deterministic=deterministic)
|
316 |
+
hidden_states = self.attn_1(hidden_states)
|
317 |
+
hidden_states = self.block_2(hidden_states,
|
318 |
+
temb,
|
319 |
+
deterministic=deterministic)
|
320 |
+
return hidden_states
|
321 |
+
|
322 |
+
|
323 |
+
class Encoder(nn.Module):
|
324 |
+
config: VQGANConfig
|
325 |
+
dtype: jnp.dtype = jnp.float32
|
326 |
+
|
327 |
+
def setup(self):
|
328 |
+
self.temb_ch = 0
|
329 |
+
|
330 |
+
# downsampling
|
331 |
+
self.conv_in = nn.Conv(
|
332 |
+
self.config.ch,
|
333 |
+
kernel_size=(3, 3),
|
334 |
+
strides=(1, 1),
|
335 |
+
padding=((1, 1), (1, 1)),
|
336 |
+
dtype=self.dtype,
|
337 |
+
)
|
338 |
+
|
339 |
+
curr_res = self.config.resolution
|
340 |
+
downsample_blocks = []
|
341 |
+
for i_level in range(self.config.num_resolutions):
|
342 |
+
downsample_blocks.append(
|
343 |
+
DownsamplingBlock(self.config,
|
344 |
+
curr_res,
|
345 |
+
block_idx=i_level,
|
346 |
+
dtype=self.dtype))
|
347 |
+
|
348 |
+
if i_level != self.config.num_resolutions - 1:
|
349 |
+
curr_res = curr_res // 2
|
350 |
+
self.down = downsample_blocks
|
351 |
+
|
352 |
+
# middle
|
353 |
+
mid_channels = self.config.ch * self.config.ch_mult[-1]
|
354 |
+
self.mid = MidBlock(mid_channels,
|
355 |
+
self.temb_ch,
|
356 |
+
self.config.dropout,
|
357 |
+
dtype=self.dtype)
|
358 |
+
|
359 |
+
# end
|
360 |
+
self.norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-6)
|
361 |
+
self.conv_out = nn.Conv(
|
362 |
+
2 * self.config.z_channels
|
363 |
+
if self.config.double_z else self.config.z_channels,
|
364 |
+
kernel_size=(3, 3),
|
365 |
+
strides=(1, 1),
|
366 |
+
padding=((1, 1), (1, 1)),
|
367 |
+
dtype=self.dtype,
|
368 |
+
)
|
369 |
+
|
370 |
+
def __call__(self, pixel_values, deterministic: bool = True):
|
371 |
+
# timestep embedding
|
372 |
+
temb = None
|
373 |
+
|
374 |
+
# downsampling
|
375 |
+
hidden_states = self.conv_in(pixel_values)
|
376 |
+
for block in self.down:
|
377 |
+
hidden_states = block(hidden_states, temb, deterministic=deterministic)
|
378 |
+
|
379 |
+
# middle
|
380 |
+
hidden_states = self.mid(hidden_states, temb, deterministic=deterministic)
|
381 |
+
|
382 |
+
# end
|
383 |
+
hidden_states = self.norm_out(hidden_states)
|
384 |
+
hidden_states = nn.swish(hidden_states)
|
385 |
+
hidden_states = self.conv_out(hidden_states)
|
386 |
+
|
387 |
+
return hidden_states
|
388 |
+
|
389 |
+
|
390 |
+
class Decoder(nn.Module):
|
391 |
+
config: VQGANConfig
|
392 |
+
dtype: jnp.dtype = jnp.float32
|
393 |
+
|
394 |
+
def setup(self):
|
395 |
+
self.temb_ch = 0
|
396 |
+
|
397 |
+
# compute in_ch_mult, block_in and curr_res at lowest res
|
398 |
+
block_in = self.config.ch * self.config.ch_mult[self.config.num_resolutions
|
399 |
+
- 1]
|
400 |
+
curr_res = self.config.resolution // 2**(self.config.num_resolutions - 1)
|
401 |
+
self.z_shape = (1, self.config.z_channels, curr_res, curr_res)
|
402 |
+
print("Working with z of shape {} = {} dimensions.".format(
|
403 |
+
self.z_shape, np.prod(self.z_shape)))
|
404 |
+
|
405 |
+
# z to block_in
|
406 |
+
self.conv_in = nn.Conv(
|
407 |
+
block_in,
|
408 |
+
kernel_size=(3, 3),
|
409 |
+
strides=(1, 1),
|
410 |
+
padding=((1, 1), (1, 1)),
|
411 |
+
dtype=self.dtype,
|
412 |
+
)
|
413 |
+
|
414 |
+
# middle
|
415 |
+
self.mid = MidBlock(block_in,
|
416 |
+
self.temb_ch,
|
417 |
+
self.config.dropout,
|
418 |
+
dtype=self.dtype)
|
419 |
+
|
420 |
+
# upsampling
|
421 |
+
upsample_blocks = []
|
422 |
+
for i_level in reversed(range(self.config.num_resolutions)):
|
423 |
+
upsample_blocks.append(
|
424 |
+
UpsamplingBlock(self.config,
|
425 |
+
curr_res,
|
426 |
+
block_idx=i_level,
|
427 |
+
dtype=self.dtype))
|
428 |
+
if i_level != 0:
|
429 |
+
curr_res = curr_res * 2
|
430 |
+
self.up = list(
|
431 |
+
reversed(upsample_blocks)) # reverse to get consistent order
|
432 |
+
|
433 |
+
# end
|
434 |
+
self.norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-6)
|
435 |
+
self.conv_out = nn.Conv(
|
436 |
+
self.config.out_ch,
|
437 |
+
kernel_size=(3, 3),
|
438 |
+
strides=(1, 1),
|
439 |
+
padding=((1, 1), (1, 1)),
|
440 |
+
dtype=self.dtype,
|
441 |
+
)
|
442 |
+
|
443 |
+
def __call__(self, hidden_states, deterministic: bool = True):
|
444 |
+
# timestep embedding
|
445 |
+
temb = None
|
446 |
+
|
447 |
+
# z to block_in
|
448 |
+
hidden_states = self.conv_in(hidden_states)
|
449 |
+
|
450 |
+
# middle
|
451 |
+
hidden_states = self.mid(hidden_states, temb, deterministic=deterministic)
|
452 |
+
|
453 |
+
# upsampling
|
454 |
+
for block in reversed(self.up):
|
455 |
+
hidden_states = block(hidden_states, temb, deterministic=deterministic)
|
456 |
+
|
457 |
+
# end
|
458 |
+
if self.config.give_pre_end:
|
459 |
+
return hidden_states
|
460 |
+
|
461 |
+
hidden_states = self.norm_out(hidden_states)
|
462 |
+
hidden_states = nn.swish(hidden_states)
|
463 |
+
hidden_states = self.conv_out(hidden_states)
|
464 |
+
|
465 |
+
return hidden_states
|
466 |
+
|
467 |
+
|
468 |
+
class VectorQuantizer(nn.Module):
|
469 |
+
"""
|
470 |
+
see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
|
471 |
+
____________________________________________
|
472 |
+
Discretization bottleneck part of the VQ-VAE.
|
473 |
+
Inputs:
|
474 |
+
- n_e : number of embeddings
|
475 |
+
- e_dim : dimension of embedding
|
476 |
+
- beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
|
477 |
+
_____________________________________________
|
478 |
+
"""
|
479 |
+
|
480 |
+
config: VQGANConfig
|
481 |
+
dtype: jnp.dtype = jnp.float32
|
482 |
+
|
483 |
+
def setup(self):
|
484 |
+
self.embedding = nn.Embed(self.config.n_embed,
|
485 |
+
self.config.embed_dim,
|
486 |
+
dtype=self.dtype) # TODO: init
|
487 |
+
|
488 |
+
def __call__(self, hidden_states):
|
489 |
+
"""
|
490 |
+
Inputs the output of the encoder network z and maps it to a discrete
|
491 |
+
one-hot vector that is the index of the closest embedding vector e_j
|
492 |
+
z (continuous) -> z_q (discrete)
|
493 |
+
z.shape = (batch, channel, height, width)
|
494 |
+
quantization pipeline:
|
495 |
+
1. get encoder input (B,C,H,W)
|
496 |
+
2. flatten input to (B*H*W,C)
|
497 |
+
"""
|
498 |
+
# flatten
|
499 |
+
hidden_states_flattended = hidden_states.reshape(
|
500 |
+
(-1, self.config.embed_dim))
|
501 |
+
|
502 |
+
# dummy op to init the weights, so we can access them below
|
503 |
+
self.embedding(jnp.ones((1, 1), dtype="i4"))
|
504 |
+
|
505 |
+
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
|
506 |
+
emb_weights = self.variables["params"]["embedding"]["embedding"]
|
507 |
+
distance = (jnp.sum(hidden_states_flattended**2, axis=1, keepdims=True) +
|
508 |
+
jnp.sum(emb_weights**2, axis=1) -
|
509 |
+
2 * jnp.dot(hidden_states_flattended, emb_weights.T))
|
510 |
+
|
511 |
+
# get quantized latent vectors
|
512 |
+
min_encoding_indices = jnp.argmin(distance, axis=1)
|
513 |
+
z_q = self.embedding(min_encoding_indices).reshape(hidden_states.shape)
|
514 |
+
|
515 |
+
# reshape to (batch, num_tokens)
|
516 |
+
min_encoding_indices = min_encoding_indices.reshape(
|
517 |
+
hidden_states.shape[0], -1)
|
518 |
+
|
519 |
+
# compute the codebook_loss (q_loss) outside the model
|
520 |
+
# here we return the embeddings and indices
|
521 |
+
return z_q, min_encoding_indices
|
522 |
+
|
523 |
+
def get_codebook_entry(self, indices, shape=None):
|
524 |
+
# indices are expected to be of shape (batch, num_tokens)
|
525 |
+
# get quantized latent vectors
|
526 |
+
batch, num_tokens = indices.shape
|
527 |
+
z_q = self.embedding(indices)
|
528 |
+
z_q = z_q.reshape(batch, int(math.sqrt(num_tokens)),
|
529 |
+
int(math.sqrt(num_tokens)), -1)
|
530 |
+
return z_q
|
531 |
+
|
532 |
+
|
533 |
+
class VQModule(nn.Module):
|
534 |
+
config: VQGANConfig
|
535 |
+
dtype: jnp.dtype = jnp.float32
|
536 |
+
|
537 |
+
def setup(self):
|
538 |
+
self.encoder = Encoder(self.config, dtype=self.dtype)
|
539 |
+
self.decoder = Decoder(self.config, dtype=self.dtype)
|
540 |
+
self.quantize = VectorQuantizer(self.config, dtype=self.dtype)
|
541 |
+
self.quant_conv = nn.Conv(
|
542 |
+
self.config.embed_dim,
|
543 |
+
kernel_size=(1, 1),
|
544 |
+
strides=(1, 1),
|
545 |
+
padding="VALID",
|
546 |
+
dtype=self.dtype,
|
547 |
+
)
|
548 |
+
self.post_quant_conv = nn.Conv(
|
549 |
+
self.config.z_channels,
|
550 |
+
kernel_size=(1, 1),
|
551 |
+
strides=(1, 1),
|
552 |
+
padding="VALID",
|
553 |
+
dtype=self.dtype,
|
554 |
+
)
|
555 |
+
|
556 |
+
def encode(self, pixel_values, deterministic: bool = True):
|
557 |
+
hidden_states = self.encoder(pixel_values, deterministic=deterministic)
|
558 |
+
hidden_states = self.quant_conv(hidden_states)
|
559 |
+
quant_states, indices = self.quantize(hidden_states)
|
560 |
+
return quant_states, indices
|
561 |
+
|
562 |
+
def decode(self, hidden_states, deterministic: bool = True):
|
563 |
+
hidden_states = self.post_quant_conv(hidden_states)
|
564 |
+
hidden_states = self.decoder(hidden_states, deterministic=deterministic)
|
565 |
+
return hidden_states
|
566 |
+
|
567 |
+
def decode_code(self, code_b):
|
568 |
+
hidden_states = self.quantize.get_codebook_entry(code_b)
|
569 |
+
hidden_states = self.decode(hidden_states)
|
570 |
+
return hidden_states
|
571 |
+
|
572 |
+
def __call__(self, pixel_values, deterministic: bool = True):
|
573 |
+
quant_states, indices = self.encode(pixel_values, deterministic)
|
574 |
+
hidden_states = self.decode(quant_states, deterministic)
|
575 |
+
return hidden_states, indices
|
576 |
+
|
577 |
+
|
578 |
+
class VQGANPreTrainedModel(FlaxPreTrainedModel):
|
579 |
+
"""
|
580 |
+
An abstract class to handle weights initialization and a simple interface
|
581 |
+
for downloading and loading pretrained models.
|
582 |
+
"""
|
583 |
+
|
584 |
+
config_class = VQGANConfig
|
585 |
+
base_model_prefix = "model"
|
586 |
+
module_class: nn.Module = None
|
587 |
+
|
588 |
+
def __init__(
|
589 |
+
self,
|
590 |
+
config: VQGANConfig,
|
591 |
+
input_shape: Tuple = (1, 256, 256, 3),
|
592 |
+
seed: int = 0,
|
593 |
+
dtype: jnp.dtype = jnp.float32,
|
594 |
+
**kwargs,
|
595 |
+
):
|
596 |
+
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
597 |
+
super().__init__(config,
|
598 |
+
module,
|
599 |
+
input_shape=input_shape,
|
600 |
+
seed=seed,
|
601 |
+
dtype=dtype)
|
602 |
+
|
603 |
+
def init_weights(self, rng: jax.random.PRNGKey,
|
604 |
+
input_shape: Tuple) -> FrozenDict:
|
605 |
+
# init input tensors
|
606 |
+
pixel_values = jnp.zeros(input_shape, dtype=jnp.float32)
|
607 |
+
params_rng, dropout_rng = jax.random.split(rng)
|
608 |
+
rngs = {"params": params_rng, "dropout": dropout_rng}
|
609 |
+
|
610 |
+
return self.module.init(rngs, pixel_values)["params"]
|
611 |
+
|
612 |
+
def encode(self,
|
613 |
+
pixel_values,
|
614 |
+
params: dict = None,
|
615 |
+
dropout_rng: jax.random.PRNGKey = None,
|
616 |
+
train: bool = False):
|
617 |
+
# Handle any PRNG if needed
|
618 |
+
rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
|
619 |
+
|
620 |
+
return self.module.apply({"params": params or self.params},
|
621 |
+
jnp.array(pixel_values),
|
622 |
+
not train,
|
623 |
+
rngs=rngs,
|
624 |
+
method=self.module.encode)
|
625 |
+
|
626 |
+
def decode(self,
|
627 |
+
hidden_states,
|
628 |
+
params: dict = None,
|
629 |
+
dropout_rng: jax.random.PRNGKey = None,
|
630 |
+
train: bool = False):
|
631 |
+
# Handle any PRNG if needed
|
632 |
+
rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
|
633 |
+
|
634 |
+
return self.module.apply(
|
635 |
+
{"params": params or self.params},
|
636 |
+
jnp.array(hidden_states),
|
637 |
+
not train,
|
638 |
+
rngs=rngs,
|
639 |
+
method=self.module.decode,
|
640 |
+
)
|
641 |
+
|
642 |
+
def decode_code(self, indices, params: dict = None):
|
643 |
+
return self.module.apply({"params": params or self.params},
|
644 |
+
jnp.array(indices, dtype="i4"),
|
645 |
+
method=self.module.decode_code)
|
646 |
+
|
647 |
+
def __call__(
|
648 |
+
self,
|
649 |
+
pixel_values,
|
650 |
+
params: dict = None,
|
651 |
+
dropout_rng: jax.random.PRNGKey = None,
|
652 |
+
train: bool = False,
|
653 |
+
):
|
654 |
+
# Handle any PRNG if needed
|
655 |
+
rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
|
656 |
+
|
657 |
+
return self.module.apply(
|
658 |
+
{"params": params or self.params},
|
659 |
+
jnp.array(pixel_values),
|
660 |
+
not train,
|
661 |
+
rngs=rngs,
|
662 |
+
)
|
663 |
+
|
664 |
+
|
665 |
+
class VQModel(VQGANPreTrainedModel):
|
666 |
+
module_class = VQModule
|