|
|
|
|
|
from functools import partial |
|
from typing import Optional, Tuple
|
import math |
|
|
|
import jax |
|
import jax.numpy as jnp |
|
import numpy as np |
|
import flax.linen as nn |
|
from flax.core.frozen_dict import FrozenDict |
|
|
|
from transformers.modeling_flax_utils import FlaxPreTrainedModel |
|
|
|
from .configuration_vqgan import VQGANConfig |
|
|
|
|
|
class Upsample(nn.Module): |
|
in_channels: int |
|
with_conv: bool |
|
dtype: jnp.dtype = jnp.float32 |
|
|
|
def setup(self): |
|
if self.with_conv: |
|
self.conv = nn.Conv( |
|
self.in_channels, |
|
kernel_size=(3, 3), |
|
strides=(1, 1), |
|
padding=((1, 1), (1, 1)), |
|
dtype=self.dtype, |
|
) |
|
|
|
def __call__(self, hidden_states): |
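        # double the spatial resolution with nearest-neighbor interpolation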
|
batch, height, width, channels = hidden_states.shape |
|
hidden_states = jax.image.resize( |
|
hidden_states, |
|
shape=(batch, height * 2, width * 2, channels), |
|
method="nearest", |
|
) |
|
if self.with_conv: |
|
hidden_states = self.conv(hidden_states) |
|
return hidden_states |
|
|
|
|
|
class Downsample(nn.Module): |
|
in_channels: int |
|
with_conv: bool |
|
dtype: jnp.dtype = jnp.float32 |
|
|
|
def setup(self): |
|
if self.with_conv: |
|
self.conv = nn.Conv( |
|
self.in_channels, |
|
kernel_size=(3, 3), |
|
strides=(2, 2), |
|
padding="VALID", |
|
dtype=self.dtype, |
|
) |
|
|
|
def __call__(self, hidden_states): |
|
if self.with_conv: |
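            # pad one row/column at the bottom/right so the VALID 3x3 stride-2
            # convolution exactly halves the spatial dimensions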
|
pad = ((0, 0), (0, 1), (0, 1), (0, 0)) |
|
hidden_states = jnp.pad(hidden_states, pad_width=pad) |
|
hidden_states = self.conv(hidden_states) |
|
else: |
|
hidden_states = nn.avg_pool(hidden_states, |
|
window_shape=(2, 2), |
|
strides=(2, 2), |
|
padding="VALID") |
|
return hidden_states |
|
|
|
|
|
class ResnetBlock(nn.Module): |
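    """Residual block of two GroupNorm -> swish -> 3x3 conv layers, with an
    optional timestep-embedding projection and a learned shortcut when the
    channel count changes."""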
|
in_channels: int |
|
    out_channels: Optional[int] = None
|
use_conv_shortcut: bool = False |
|
temb_channels: int = 512 |
|
dropout_prob: float = 0.0 |
|
dtype: jnp.dtype = jnp.float32 |
|
|
|
def setup(self): |
|
        # Flax modules cannot reassign declared fields, so the resolved value
        # is stored under a trailing-underscore name
        self.out_channels_ = self.in_channels if self.out_channels is None else self.out_channels
|
|
|
self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-6) |
|
self.conv1 = nn.Conv( |
|
self.out_channels_, |
|
kernel_size=(3, 3), |
|
strides=(1, 1), |
|
padding=((1, 1), (1, 1)), |
|
dtype=self.dtype, |
|
) |
|
|
|
if self.temb_channels: |
|
self.temb_proj = nn.Dense(self.out_channels_, dtype=self.dtype) |
|
|
|
self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-6) |
|
self.dropout = nn.Dropout(self.dropout_prob) |
|
self.conv2 = nn.Conv( |
|
self.out_channels_, |
|
kernel_size=(3, 3), |
|
strides=(1, 1), |
|
padding=((1, 1), (1, 1)), |
|
dtype=self.dtype, |
|
) |
|
|
|
if self.in_channels != self.out_channels_: |
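            # the residual branch needs a projection when channel counts differ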
|
if self.use_conv_shortcut: |
|
self.conv_shortcut = nn.Conv( |
|
self.out_channels_, |
|
kernel_size=(3, 3), |
|
strides=(1, 1), |
|
padding=((1, 1), (1, 1)), |
|
dtype=self.dtype, |
|
) |
|
else: |
|
self.nin_shortcut = nn.Conv( |
|
self.out_channels_, |
|
kernel_size=(1, 1), |
|
strides=(1, 1), |
|
padding="VALID", |
|
dtype=self.dtype, |
|
) |
|
|
|
def __call__(self, hidden_states, temb=None, deterministic: bool = True): |
|
residual = hidden_states |
|
hidden_states = self.norm1(hidden_states) |
|
hidden_states = nn.swish(hidden_states) |
|
hidden_states = self.conv1(hidden_states) |
|
|
|
        if temb is not None:

            # project the timestep embedding and broadcast it across the
            # spatial dimensions (NHWC layout, so channels go last)
            hidden_states = hidden_states + self.temb_proj(

                nn.swish(temb))[:, None, None, :]
|
|
|
hidden_states = self.norm2(hidden_states) |
|
hidden_states = nn.swish(hidden_states) |
|
hidden_states = self.dropout(hidden_states, deterministic) |
|
hidden_states = self.conv2(hidden_states) |
|
|
|
if self.in_channels != self.out_channels_: |
|
if self.use_conv_shortcut: |
|
residual = self.conv_shortcut(residual) |
|
else: |
|
residual = self.nin_shortcut(residual) |
|
|
|
return hidden_states + residual |
|
|
|
|
|
class AttnBlock(nn.Module): |
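    """Single-head self-attention over all spatial positions (a non-local
    block), with 1x1 convolutions as query/key/value/output projections."""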
|
in_channels: int |
|
dtype: jnp.dtype = jnp.float32 |
|
|
|
def setup(self): |
|
conv = partial(nn.Conv, |
|
self.in_channels, |
|
kernel_size=(1, 1), |
|
strides=(1, 1), |
|
padding="VALID", |
|
dtype=self.dtype) |
|
|
|
self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-6) |
|
self.q, self.k, self.v = conv(), conv(), conv() |
|
self.proj_out = conv() |
|
|
|
def __call__(self, hidden_states): |
|
residual = hidden_states |
|
hidden_states = self.norm(hidden_states) |
|
|
|
query = self.q(hidden_states) |
|
key = self.k(hidden_states) |
|
value = self.v(hidden_states) |
|
|
|
|
|
batch, height, width, channels = query.shape |
|
query = query.reshape((batch, height * width, channels)) |
|
key = key.reshape((batch, height * width, channels)) |
|
        attn_weights = jnp.einsum("...qc,...kc->...qk", query, key)

        attn_weights = attn_weights * (channels**-0.5)

        attn_weights = nn.softmax(attn_weights, axis=-1)
|
|
|
|
|
value = value.reshape((batch, height * width, channels)) |
|
hidden_states = jnp.einsum("...kc,...qk->...qc", value, attn_weights) |
|
hidden_states = hidden_states.reshape((batch, height, width, channels)) |
|
|
|
hidden_states = self.proj_out(hidden_states) |
|
hidden_states = hidden_states + residual |
|
return hidden_states |
|
|
|
|
|
class UpsamplingBlock(nn.Module): |
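    """One decoder stage: num_res_blocks + 1 residual blocks, attention at the
    configured resolutions, and 2x upsampling except at the final,
    highest-resolution stage (block_idx == 0)."""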
|
config: VQGANConfig |
|
curr_res: int |
|
block_idx: int |
|
dtype: jnp.dtype = jnp.float32 |
|
|
|
def setup(self): |
|
if self.block_idx == self.config.num_resolutions - 1: |
|
block_in = self.config.ch * self.config.ch_mult[-1] |
|
else: |
|
block_in = self.config.ch * self.config.ch_mult[self.block_idx + 1] |
|
|
|
block_out = self.config.ch * self.config.ch_mult[self.block_idx] |
|
self.temb_ch = 0 |
|
|
|
res_blocks = [] |
|
attn_blocks = [] |
|
for _ in range(self.config.num_res_blocks + 1): |
|
res_blocks.append( |
|
ResnetBlock(block_in, |
|
block_out, |
|
temb_channels=self.temb_ch, |
|
dropout_prob=self.config.dropout, |
|
dtype=self.dtype)) |
|
block_in = block_out |
|
if self.curr_res in self.config.attn_resolutions: |
|
attn_blocks.append(AttnBlock(block_in, dtype=self.dtype)) |
|
|
|
self.block = res_blocks |
|
self.attn = attn_blocks |
|
|
|
self.upsample = None |
|
if self.block_idx != 0: |
|
self.upsample = Upsample(block_in, |
|
self.config.resamp_with_conv, |
|
dtype=self.dtype) |
|
|
|
def __call__(self, hidden_states, temb=None, deterministic: bool = True): |
|
for res_block in self.block: |
|
hidden_states = res_block(hidden_states, |
|
temb, |
|
deterministic=deterministic) |
|
for attn_block in self.attn: |
|
hidden_states = attn_block(hidden_states) |
|
|
|
if self.upsample is not None: |
|
hidden_states = self.upsample(hidden_states) |
|
|
|
return hidden_states |
|
|
|
|
|
class DownsamplingBlock(nn.Module): |
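    """One encoder stage: num_res_blocks residual blocks, attention at the
    configured resolutions, and 2x downsampling except at the last stage."""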
|
config: VQGANConfig |
|
curr_res: int |
|
block_idx: int |
|
dtype: jnp.dtype = jnp.float32 |
|
|
|
def setup(self): |
|
in_ch_mult = (1, ) + tuple(self.config.ch_mult) |
|
block_in = self.config.ch * in_ch_mult[self.block_idx] |
|
block_out = self.config.ch * self.config.ch_mult[self.block_idx] |
|
self.temb_ch = 0 |
|
|
|
res_blocks = [] |
|
attn_blocks = [] |
|
for _ in range(self.config.num_res_blocks): |
|
res_blocks.append( |
|
ResnetBlock(block_in, |
|
block_out, |
|
temb_channels=self.temb_ch, |
|
dropout_prob=self.config.dropout, |
|
dtype=self.dtype)) |
|
block_in = block_out |
|
if self.curr_res in self.config.attn_resolutions: |
|
attn_blocks.append(AttnBlock(block_in, dtype=self.dtype)) |
|
|
|
self.block = res_blocks |
|
self.attn = attn_blocks |
|
|
|
self.downsample = None |
|
if self.block_idx != self.config.num_resolutions - 1: |
|
self.downsample = Downsample(block_in, |
|
self.config.resamp_with_conv, |
|
dtype=self.dtype) |
|
|
|
def __call__(self, hidden_states, temb=None, deterministic: bool = True): |
|
for res_block in self.block: |
|
hidden_states = res_block(hidden_states, |
|
temb, |
|
deterministic=deterministic) |
|
for attn_block in self.attn: |
|
hidden_states = attn_block(hidden_states) |
|
|
|
if self.downsample is not None: |
|
hidden_states = self.downsample(hidden_states) |
|
|
|
return hidden_states |
|
|
|
|
|
class MidBlock(nn.Module): |
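    """Bottleneck at the lowest resolution: ResnetBlock -> AttnBlock ->
    ResnetBlock."""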
|
in_channels: int |
|
temb_channels: int |
|
dropout: float |
|
dtype: jnp.dtype = jnp.float32 |
|
|
|
def setup(self): |
|
self.block_1 = ResnetBlock( |
|
self.in_channels, |
|
self.in_channels, |
|
temb_channels=self.temb_channels, |
|
dropout_prob=self.dropout, |
|
dtype=self.dtype, |
|
) |
|
self.attn_1 = AttnBlock(self.in_channels, dtype=self.dtype) |
|
self.block_2 = ResnetBlock( |
|
self.in_channels, |
|
self.in_channels, |
|
temb_channels=self.temb_channels, |
|
dropout_prob=self.dropout, |
|
dtype=self.dtype, |
|
) |
|
|
|
def __call__(self, hidden_states, temb=None, deterministic: bool = True): |
|
hidden_states = self.block_1(hidden_states, |
|
temb, |
|
deterministic=deterministic) |
|
hidden_states = self.attn_1(hidden_states) |
|
hidden_states = self.block_2(hidden_states, |
|
temb, |
|
deterministic=deterministic) |
|
return hidden_states |
|
|
|
|
|
class Encoder(nn.Module): |
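    """Encodes images (B, H, W, C) into a latent map downsampled by a factor of
    2**(num_resolutions - 1), with z_channels output channels (doubled when
    config.double_z is set)."""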
|
config: VQGANConfig |
|
dtype: jnp.dtype = jnp.float32 |
|
|
|
def setup(self): |
|
self.temb_ch = 0 |
|
|
|
|
|
self.conv_in = nn.Conv( |
|
self.config.ch, |
|
kernel_size=(3, 3), |
|
strides=(1, 1), |
|
padding=((1, 1), (1, 1)), |
|
dtype=self.dtype, |
|
) |
|
|
|
curr_res = self.config.resolution |
|
downsample_blocks = [] |
|
for i_level in range(self.config.num_resolutions): |
|
downsample_blocks.append( |
|
DownsamplingBlock(self.config, |
|
curr_res, |
|
block_idx=i_level, |
|
dtype=self.dtype)) |
|
|
|
if i_level != self.config.num_resolutions - 1: |
|
curr_res = curr_res // 2 |
|
self.down = downsample_blocks |
|
|
|
|
|
mid_channels = self.config.ch * self.config.ch_mult[-1] |
|
self.mid = MidBlock(mid_channels, |
|
self.temb_ch, |
|
self.config.dropout, |
|
dtype=self.dtype) |
|
|
|
|
|
self.norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-6) |
|
self.conv_out = nn.Conv( |
|
2 * self.config.z_channels |
|
if self.config.double_z else self.config.z_channels, |
|
kernel_size=(3, 3), |
|
strides=(1, 1), |
|
padding=((1, 1), (1, 1)), |
|
dtype=self.dtype, |
|
) |
|
|
|
def __call__(self, pixel_values, deterministic: bool = True): |
|
|
|
temb = None |
|
|
|
|
|
hidden_states = self.conv_in(pixel_values) |
|
for block in self.down: |
|
hidden_states = block(hidden_states, temb, deterministic=deterministic) |
|
|
|
|
|
hidden_states = self.mid(hidden_states, temb, deterministic=deterministic) |
|
|
|
|
|
hidden_states = self.norm_out(hidden_states) |
|
hidden_states = nn.swish(hidden_states) |
|
hidden_states = self.conv_out(hidden_states) |
|
|
|
return hidden_states |
|
|
|
|
|
class Decoder(nn.Module): |
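    """Decodes a latent map back to image space, upsampling by a factor of
    2**(num_resolutions - 1) and producing config.out_ch output channels."""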
|
config: VQGANConfig |
|
dtype: jnp.dtype = jnp.float32 |
|
|
|
def setup(self): |
|
self.temb_ch = 0 |
|
|
|
|
|
        block_in = self.config.ch * self.config.ch_mult[-1]
|
curr_res = self.config.resolution // 2**(self.config.num_resolutions - 1) |
|
        self.z_shape = (1, curr_res, curr_res, self.config.z_channels)

        print("Working with z of shape {} = {} dimensions.".format(

            self.z_shape, np.prod(self.z_shape)))
|
|
|
|
|
self.conv_in = nn.Conv( |
|
block_in, |
|
kernel_size=(3, 3), |
|
strides=(1, 1), |
|
padding=((1, 1), (1, 1)), |
|
dtype=self.dtype, |
|
) |
|
|
|
|
|
self.mid = MidBlock(block_in, |
|
self.temb_ch, |
|
self.config.dropout, |
|
dtype=self.dtype) |
|
|
|
|
|
upsample_blocks = [] |
|
for i_level in reversed(range(self.config.num_resolutions)): |
|
upsample_blocks.append( |
|
UpsamplingBlock(self.config, |
|
curr_res, |
|
block_idx=i_level, |
|
dtype=self.dtype)) |
|
if i_level != 0: |
|
curr_res = curr_res * 2 |
|
        # blocks were built from low to high resolution; store them so that
        # self.up[i] corresponds to block_idx i, and iterate in reverse below
        self.up = list(reversed(upsample_blocks))
|
|
|
|
|
self.norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-6) |
|
self.conv_out = nn.Conv( |
|
self.config.out_ch, |
|
kernel_size=(3, 3), |
|
strides=(1, 1), |
|
padding=((1, 1), (1, 1)), |
|
dtype=self.dtype, |
|
) |
|
|
|
def __call__(self, hidden_states, deterministic: bool = True): |
|
|
|
temb = None |
|
|
|
|
|
hidden_states = self.conv_in(hidden_states) |
|
|
|
|
|
hidden_states = self.mid(hidden_states, temb, deterministic=deterministic) |
|
|
|
|
|
for block in reversed(self.up): |
|
hidden_states = block(hidden_states, temb, deterministic=deterministic) |
|
|
|
|
|
if self.config.give_pre_end: |
|
return hidden_states |
|
|
|
hidden_states = self.norm_out(hidden_states) |
|
hidden_states = nn.swish(hidden_states) |
|
hidden_states = self.conv_out(hidden_states) |
|
|
|
return hidden_states |
|
|
|
|
|
class VectorQuantizer(nn.Module): |
|
""" |
|
see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py |
|
____________________________________________ |
|
Discretization bottleneck part of the VQ-VAE. |
|
Inputs: |
|
- n_e : number of embeddings |
|
- e_dim : dimension of embedding |
|
- beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2 |
|
_____________________________________________ |
|
""" |
|
|
|
config: VQGANConfig |
|
dtype: jnp.dtype = jnp.float32 |
|
|
|
def setup(self): |
|
self.embedding = nn.Embed(self.config.n_embed, |
|
self.config.embed_dim, |
|
dtype=self.dtype) |
|
|
|
def __call__(self, hidden_states): |
|
""" |
|
Inputs the output of the encoder network z and maps it to a discrete |
|
one-hot vector that is the index of the closest embedding vector e_j |
|
z (continuous) -> z_q (discrete) |
|
        z.shape = (batch, height, width, channel)

        quantization pipeline:

            1. get encoder input (B,H,W,C)

            2. flatten input to (B*H*W,C)
|
""" |
|
|
|
        hidden_states_flattened = hidden_states.reshape(
|
(-1, self.config.embed_dim)) |
|
|
|
|
|
        # dummy op to force initialization of the embedding table so that its
        # weights can be read from self.variables below
        self.embedding(jnp.ones((1, 1), dtype="i4"))
|
|
|
|
|
emb_weights = self.variables["params"]["embedding"]["embedding"] |
|
        # squared L2 distance to each code: ||z||^2 + ||e||^2 - 2 * z @ e^T
        distance = (jnp.sum(hidden_states_flattened**2, axis=1, keepdims=True) +

                    jnp.sum(emb_weights**2, axis=1) -

                    2 * jnp.dot(hidden_states_flattened, emb_weights.T))
|
|
|
|
|
min_encoding_indices = jnp.argmin(distance, axis=1) |
|
z_q = self.embedding(min_encoding_indices).reshape(hidden_states.shape) |
|
|
|
|
|
min_encoding_indices = min_encoding_indices.reshape( |
|
hidden_states.shape[0], -1) |
|
|
|
|
|
|
|
return z_q, min_encoding_indices |
|
|
|
def get_codebook_entry(self, indices, shape=None): |
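        # indices: (batch, num_tokens); a square token grid is assumed below
        # (the shape argument is currently unused)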
|
|
|
|
|
batch, num_tokens = indices.shape |
|
z_q = self.embedding(indices) |
|
z_q = z_q.reshape(batch, int(math.sqrt(num_tokens)), |
|
int(math.sqrt(num_tokens)), -1) |
|
return z_q |
|
|
|
|
|
class VQModule(nn.Module): |
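    """Full VQGAN autoencoder: Encoder -> 1x1 quant_conv -> VectorQuantizer ->
    1x1 post_quant_conv -> Decoder."""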
|
config: VQGANConfig |
|
dtype: jnp.dtype = jnp.float32 |
|
|
|
def setup(self): |
|
self.encoder = Encoder(self.config, dtype=self.dtype) |
|
self.decoder = Decoder(self.config, dtype=self.dtype) |
|
self.quantize = VectorQuantizer(self.config, dtype=self.dtype) |
|
self.quant_conv = nn.Conv( |
|
self.config.embed_dim, |
|
kernel_size=(1, 1), |
|
strides=(1, 1), |
|
padding="VALID", |
|
dtype=self.dtype, |
|
) |
|
self.post_quant_conv = nn.Conv( |
|
self.config.z_channels, |
|
kernel_size=(1, 1), |
|
strides=(1, 1), |
|
padding="VALID", |
|
dtype=self.dtype, |
|
) |
|
|
|
def encode(self, pixel_values, deterministic: bool = True): |
|
hidden_states = self.encoder(pixel_values, deterministic=deterministic) |
|
hidden_states = self.quant_conv(hidden_states) |
|
quant_states, indices = self.quantize(hidden_states) |
|
return quant_states, indices |
|
|
|
def decode(self, hidden_states, deterministic: bool = True): |
|
hidden_states = self.post_quant_conv(hidden_states) |
|
hidden_states = self.decoder(hidden_states, deterministic=deterministic) |
|
return hidden_states |
|
|
|
def decode_code(self, code_b): |
|
hidden_states = self.quantize.get_codebook_entry(code_b) |
|
hidden_states = self.decode(hidden_states) |
|
return hidden_states |
|
|
|
def __call__(self, pixel_values, deterministic: bool = True): |
|
quant_states, indices = self.encode(pixel_values, deterministic) |
|
hidden_states = self.decode(quant_states, deterministic) |
|
return hidden_states, indices |
|
|
|
|
|
class VQGANPreTrainedModel(FlaxPreTrainedModel): |
|
""" |
|
An abstract class to handle weights initialization and a simple interface |
|
for downloading and loading pretrained models. |
|
""" |
|
|
|
config_class = VQGANConfig |
|
base_model_prefix = "model" |
|
module_class: nn.Module = None |
|
|
|
def __init__( |
|
self, |
|
config: VQGANConfig, |
|
input_shape: Tuple = (1, 256, 256, 3), |
|
seed: int = 0, |
|
dtype: jnp.dtype = jnp.float32, |
|
**kwargs, |
|
): |
|
module = self.module_class(config=config, dtype=dtype, **kwargs) |
|
super().__init__(config, |
|
module, |
|
input_shape=input_shape, |
|
seed=seed, |
|
dtype=dtype) |
|
|
|
def init_weights(self, rng: jax.random.PRNGKey, |
|
input_shape: Tuple) -> FrozenDict: |
|
|
|
pixel_values = jnp.zeros(input_shape, dtype=jnp.float32) |
|
params_rng, dropout_rng = jax.random.split(rng) |
|
rngs = {"params": params_rng, "dropout": dropout_rng} |
|
|
|
return self.module.init(rngs, pixel_values)["params"] |
|
|
|
def encode(self, |
|
pixel_values, |
|
params: dict = None, |
|
dropout_rng: jax.random.PRNGKey = None, |
|
train: bool = False): |
|
|
|
rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} |
|
|
|
return self.module.apply({"params": params or self.params}, |
|
jnp.array(pixel_values), |
|
not train, |
|
rngs=rngs, |
|
method=self.module.encode) |
|
|
|
def decode(self, |
|
hidden_states, |
|
params: dict = None, |
|
dropout_rng: jax.random.PRNGKey = None, |
|
train: bool = False): |
|
|
|
rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} |
|
|
|
return self.module.apply( |
|
{"params": params or self.params}, |
|
jnp.array(hidden_states), |
|
not train, |
|
rngs=rngs, |
|
method=self.module.decode, |
|
) |
|
|
|
def decode_code(self, indices, params: dict = None): |
|
return self.module.apply({"params": params or self.params}, |
|
jnp.array(indices, dtype="i4"), |
|
method=self.module.decode_code) |
|
|
|
def __call__( |
|
self, |
|
pixel_values, |
|
params: dict = None, |
|
dropout_rng: jax.random.PRNGKey = None, |
|
train: bool = False, |
|
): |
|
|
|
rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} |
|
|
|
return self.module.apply( |
|
{"params": params or self.params}, |
|
jnp.array(pixel_values), |
|
not train, |
|
rngs=rngs, |
|
) |
|
|
|
|
|
class VQModel(VQGANPreTrainedModel): |
|
module_class = VQModule |
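

# Example usage (a minimal sketch; the checkpoint name is an assumption, any
# compatible Flax VQGAN checkpoint on the Hugging Face Hub would work):
#
#     model = VQModel.from_pretrained("dalle-mini/vqgan_imagenet_f16_16384")
#     pixel_values = jnp.zeros((1, 256, 256, 3), dtype=jnp.float32)  # NHWC
#     quant_states, indices = model.encode(pixel_values)
#     reconstruction = model.decode(quant_states)
#     reconstruction_from_codes = model.decode_code(indices)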