# Z-Image-Turbo-MLX / autoencoder.py
# Uploaded by illusion615 via huggingface_hub (commit 64566e4, verified).
"""AutoencoderKL Decoder β€” pure MLX implementation.
Decodes latent representations to RGB images without PyTorch/diffusers
dependency. Architecture matches diffusers AutoencoderKL with the
Z-Image-Turbo VAE config:
latent_channels = 16
block_out_channels = [128, 256, 512, 512]
layers_per_block = 2 (decoder uses layers_per_block + 1 = 3)
norm_num_groups = 32
mid_block_add_attention = true
force_upcast = true (diffusers runs this VAE in float32; this module does not cast, so pass float32 weights and latents)
scaling_factor = 0.3611
shift_factor = 0.1159
Data format: NHWC throughout (MLX convention).
"""
from __future__ import annotations
import math
import mlx.core as mx
import mlx.nn as nn
# diffusers' VAE GroupNorm layers use eps=1e-6.  ``pytorch_compatible=True``
# asks MLX to group channels the same way PyTorch does, so converted
# GroupNorm weights line up with the right channel groups.
_GN_EPS = 1e-6


def _gn(groups: int, channels: int) -> nn.GroupNorm:
    """Return an affine GroupNorm over ``channels`` split into ``groups`` groups.

    Every normalisation layer in this file is built here so they all share
    the diffusers-matching epsilon and PyTorch-compatible grouping.
    """
    norm = nn.GroupNorm(
        groups,
        channels,
        eps=_GN_EPS,
        pytorch_compatible=True,
    )
    return norm
# ── Building blocks ──────────────────────────────────────────────
class ResnetBlock2D(nn.Module):
    """Pre-activation residual block (NHWC).

    Computes ``skip(x) + conv2(silu(norm2(conv1(silu(norm1(x))))))``.
    ``skip`` is the identity when the width is unchanged, and a 1x1
    convolution (``conv_shortcut``) when ``in_channels != out_channels``.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the block.
        groups: GroupNorm group count used by both norms.
    """

    def __init__(self, in_channels: int, out_channels: int, groups: int = 32):
        super().__init__()
        self.norm1 = _gn(groups, in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.norm2 = _gn(groups, out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        # The residual path only needs a projection when the width changes.
        needs_projection = in_channels != out_channels
        self.conv_shortcut = (
            nn.Conv2d(in_channels, out_channels, kernel_size=1)
            if needs_projection
            else None
        )

    def __call__(self, x: mx.array) -> mx.array:
        skip = x if self.conv_shortcut is None else self.conv_shortcut(x)
        h = self.conv1(nn.silu(self.norm1(x)))
        h = self.conv2(nn.silu(self.norm2(h)))
        return h + skip
class AttentionBlock(nn.Module):
    """Single-head spatial self-attention with a residual connection (NHWC).

    Steps:
    - Group-normalise the feature map.
    - Flatten it to ``H*W`` tokens of width ``C``.
    - Attend with a single head, scaling scores by ``1/sqrt(C)``.
    - Project through ``to_out[0]`` and add the result back onto the input.

    Args:
        channels: channel width ``C`` of the feature map.
        groups: GroupNorm group count.
    """

    def __init__(self, channels: int, groups: int = 32):
        super().__init__()
        self.group_norm = _gn(groups, channels)
        self.to_q = nn.Linear(channels, channels)
        self.to_k = nn.Linear(channels, channels)
        self.to_v = nn.Linear(channels, channels)
        # A one-element list gives the parameter path ``to_out.0``, mirroring
        # diffusers' Sequential-wrapped output projection.
        self.to_out = [nn.Linear(channels, channels)]

    def __call__(self, x: mx.array) -> mx.array:
        batch, height, width, width_c = x.shape
        tokens = self.group_norm(x).reshape(batch, height * width, width_c)

        query = self.to_q(tokens)
        key = self.to_k(tokens)
        value = self.to_v(tokens)

        inv_sqrt_dim = 1.0 / math.sqrt(width_c)
        scores = (query @ key.transpose(0, 2, 1)) * inv_sqrt_dim
        weights = mx.softmax(scores, axis=-1)

        projected = self.to_out[0](weights @ value)
        return projected.reshape(batch, height, width, width_c) + x
class Upsample2D(nn.Module):
    """2x nearest-neighbour upsample followed by a 3x3 conv (NHWC).

    Each input pixel is copied into a 2x2 patch. A same-padded 3x3
    convolution (``conv``) then mixes the result. The channel count is
    unchanged.

    Args:
        channels: number of input and output channels.
    """

    def __init__(self, channels: int):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def __call__(self, x: mx.array) -> mx.array:
        B, H, W, C = x.shape
        # Nearest-neighbour 2x in a single step:
        # - insert unit axes after H and W,
        # - broadcast each to length 2,
        # - fold back, so out[b, 2h+i, 2w+j, c] == x[b, h, w, c].
        # Only the final (B, 2H, 2W, C) array is produced. Two chained
        # ``mx.repeat`` calls would also build a (B, 2H, W, C) intermediate.
        x = mx.broadcast_to(x[:, :, None, :, None, :], (B, H, 2, W, 2, C))
        x = x.reshape(B, 2 * H, 2 * W, C)
        return self.conv(x)
class UpDecoderBlock2D(nn.Module):
    """One decoder stage: ``num_layers`` resnets, then an optional upsampler.

    The first resnet maps ``in_channels`` to ``out_channels``; the rest keep
    ``out_channels``. When ``add_upsample`` is true, a single ``Upsample2D``
    doubles H and W at the end of the stage.

    Args:
        in_channels: channels entering the stage.
        out_channels: channels produced by every resnet in the stage.
        num_layers: number of resnet blocks.
        add_upsample: whether to append the 2x upsampler.
        groups: GroupNorm group count forwarded to each resnet.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_layers: int = 3,
        add_upsample: bool = True,
        groups: int = 32,
    ):
        super().__init__()
        self.resnets = [
            ResnetBlock2D(out_channels if idx else in_channels, out_channels, groups)
            for idx in range(num_layers)
        ]
        self.upsamplers = [Upsample2D(out_channels)] if add_upsample else []

    def __call__(self, x: mx.array) -> mx.array:
        # Resnets first, then (at most one) upsampler, strictly in sequence.
        for layer in [*self.resnets, *self.upsamplers]:
            x = layer(x)
        return x
class MidBlock2D(nn.Module):
    """Decoder bottleneck: resnet -> self-attention -> resnet (NHWC).

    The channel width is preserved throughout. Sub-modules are kept in lists,
    so their parameter paths are ``resnets.0``, ``attentions.0`` and
    ``resnets.1``.

    Args:
        channels: channel width of the feature map.
        groups: GroupNorm group count for every sub-module.
    """

    def __init__(self, channels: int, groups: int = 32):
        super().__init__()
        self.resnets = [ResnetBlock2D(channels, channels, groups) for _ in range(2)]
        self.attentions = [AttentionBlock(channels, groups)]

    def __call__(self, x: mx.array) -> mx.array:
        before_attn, after_attn = self.resnets
        (attention,) = self.attentions
        return after_attn(attention(before_attn(x)))
# ── Decoder ──────────────────────────────────────────────────────
class Decoder(nn.Module):
    """AutoencoderKL Decoder (NHWC, pure MLX).

    Attribute names (``conv_in``, ``mid_block``, ``up_blocks``,
    ``conv_norm_out``, ``conv_out``) mirror diffusers' weight-key paths after
    the ``decoder.`` prefix is stripped, so weights can be loaded directly.

    Args:
        latent_channels: channels of the latent input.
        block_out_channels: encoder channel schedule; the decoder walks it in
            reverse, one up-stage per entry.
        layers_per_block: each up-stage uses ``layers_per_block + 1`` resnets.
        norm_num_groups: GroupNorm group count everywhere.
    """

    def __init__(
        self,
        latent_channels: int = 16,
        block_out_channels: tuple[int, ...] = (128, 256, 512, 512),
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
    ):
        super().__init__()
        # Reverse the encoder schedule, e.g. (128, 256, 512, 512) -> [512, 512, 256, 128].
        stage_channels = list(block_out_channels)[::-1]
        widest = stage_channels[0]

        self.conv_in = nn.Conv2d(latent_channels, widest, kernel_size=3, padding=1)
        self.mid_block = MidBlock2D(widest, norm_num_groups)

        # Each stage consumes the previous stage's width; the first stage
        # consumes the mid-block output. Every stage except the last
        # upsamples 2x.
        stage_inputs = [widest, *stage_channels[:-1]]
        last_stage = len(stage_channels) - 1
        self.up_blocks = [
            UpDecoderBlock2D(
                in_channels=c_in,
                out_channels=c_out,
                num_layers=layers_per_block + 1,
                add_upsample=idx != last_stage,
                groups=norm_num_groups,
            )
            for idx, (c_in, c_out) in enumerate(zip(stage_inputs, stage_channels))
        ]

        narrowest = stage_channels[-1]
        self.conv_norm_out = _gn(norm_num_groups, narrowest)
        self.conv_out = nn.Conv2d(narrowest, 3, kernel_size=3, padding=1)

    def __call__(self, z: mx.array) -> mx.array:
        """Decode latents into an image.

        Args:
            z: (B, H, W, latent_channels) NHWC latents, **already scaled**.
                The caller presumably undoes the VAE scaling/shift first
                (not done here); confirm against the calling pipeline.

        Returns:
            (B, H * s, W * s, 3) decoded image, where
            ``s = 2 ** (len(block_out_channels) - 1)``. That is
            (B, 8H, 8W, 3) for the default four-stage config.
        """
        h = self.mid_block(self.conv_in(z))
        for stage in self.up_blocks:
            h = stage(h)
        return self.conv_out(nn.silu(self.conv_norm_out(h)))