"""MLP-Mixer model.""" |
|
|
|
from typing import Optional, Tuple |
|
from absl import logging |
|
|
|
from big_vision import utils |
|
from big_vision.models import common |
|
|
|
import einops |
|
import flax.linen as nn |
|
import flax.training.checkpoints |
|
import jax |
|
import jax.numpy as jnp |


class MlpBlock(nn.Module):
  """A two-layer MLP (Dense -> GELU -> Dense) applied along the last axis."""
  mlp_dim: int

  @nn.compact
  def __call__(self, x):
    y = nn.Dense(self.mlp_dim)(x)
    y = nn.gelu(y)
    return nn.Dense(x.shape[-1])(y)


class MixerBlock(nn.Module):
  """Mixer block layer."""
  tokens_mlp_dim: int
  channels_mlp_dim: int
  drop_p: float

  @nn.compact
  def __call__(self, x, *, train=False):
    # Token mixing: transpose so the MLP mixes information across tokens.
    y = nn.LayerNorm()(x)
    y = jnp.swapaxes(y, 1, 2)
    y = MlpBlock(self.tokens_mlp_dim, name="token_mixing")(y)
    y = jnp.swapaxes(y, 1, 2)
    x = x + y * _stoch_depth_mask(x, self.drop_p, not train, self.make_rng)
    # Channel mixing: the MLP mixes information across channels, per token.
    y = nn.LayerNorm()(x)
    y = MlpBlock(self.channels_mlp_dim, name="channel_mixing")(y)
    return x + y * _stoch_depth_mask(x, self.drop_p, not train, self.make_rng)
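
# Shape sketch (illustrative sizes, e.g. a 224x224 image with 16x16 patches):
# for x of shape (n, 196, 768), token mixing transposes to (n, 768, 196) and
# applies the MLP across the 196 tokens, while channel mixing applies the MLP
# across the 768 channels. Both branches are residual and share the same
# stochastic-depth rate drop_p.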


class MlpMixer(nn.Module):
  """Mixer architecture."""
  patch_size: Tuple[int, int]
  num_classes: Optional[int]
  num_blocks: int
  hidden_dim: int
  tokens_mlp_dim: int
  channels_mlp_dim: int
  model_name: Optional[str] = None
  stoch_depth: float = 0.0

  @nn.compact
  def __call__(self, image, *, train=False):
    out = {}
    # Patch-embed the image with a non-overlapping strided convolution.
    x = out["stem"] = nn.Conv(self.hidden_dim, self.patch_size,
                              strides=self.patch_size, name="stem")(image)
    x = out["input_tokens"] = einops.rearrange(x, "n h w c -> n (h w) c")
    for i in range(self.num_blocks):
      # The stochastic-depth drop rate grows linearly with block depth.
      drop_p = (i / max(self.num_blocks - 1, 1)) * self.stoch_depth
      x = out[f"block_{i}"] = MixerBlock(
          self.tokens_mlp_dim, self.channels_mlp_dim, drop_p)(x, train=train)
    x = nn.LayerNorm(name="pre_head_layer_norm")(x)
    # Global average pooling over tokens, then an optional linear head.
    x = out["pre_logits"] = jnp.mean(x, axis=1)
    if self.num_classes:
      x = out["logits"] = nn.Dense(
          self.num_classes, kernel_init=nn.initializers.zeros, name="head")(x)
    return x, out


def Model(num_classes=None, *, variant=None, **kw):
  """Factory function to easily create a Model variant like "L/16"."""
  if variant is not None:
    model_size, patch = variant.split("/")
    kw.setdefault("patch_size", (int(patch), int(patch)))
    config = {
        "S": {"hidden_dim": 512, "num_blocks": 8,
              "channels_mlp_dim": 2048, "tokens_mlp_dim": 256},
        "B": {"hidden_dim": 768, "num_blocks": 12,
              "channels_mlp_dim": 3072, "tokens_mlp_dim": 384},
        "L": {"hidden_dim": 1024, "num_blocks": 24,
              "channels_mlp_dim": 4096, "tokens_mlp_dim": 512},
        "H": {"hidden_dim": 1280, "num_blocks": 32,
              "channels_mlp_dim": 5120, "tokens_mlp_dim": 640},
    }[model_size]
    for k, v in config.items():
      kw.setdefault(k, v)

  logging.info("Mixer config: %s", kw)
  return MlpMixer(num_classes=num_classes, **kw)
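
# Example usage (a sketch, not part of this module): build a Mixer-B/16
# classifier and run a forward pass. The 1000-class head and the 224x224x3
# input shape are assumptions, matching the common ImageNet-1k setup.
#
#   model = Model(num_classes=1000, variant="B/16")
#   dummy = jnp.zeros((1, 224, 224, 3), jnp.float32)
#   variables = model.init(jax.random.PRNGKey(0), dummy)
#   logits, out = model.apply(variables, dummy)  # logits: (1, 1000)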


def load(init_params, init_file, model_cfg, dont_load=()):
  """Load checkpoint."""
  del model_cfg

  # Shortcut names for the public checkpoints released with the paper:
  init_file = {
      "B-i1k/16": "gs://mixer_models/imagenet1k/Mixer-B_16.npz",
      "L-i1k/16": "gs://mixer_models/imagenet1k/Mixer-L_16.npz",
      "B-i21k/16": "gs://mixer_models/imagenet21k/Mixer-B_16.npz",
      "L-i21k/16": "gs://mixer_models/imagenet21k/Mixer-L_16.npz",
  }.get(init_file, init_file)
  restored_params = utils.load_params(init_file)
  restored_params = flax.training.checkpoints.convert_pre_linen(restored_params)

  # The original checkpoints nest all blocks under a "Mixer" scope and use
  # different parameter names; rename them to match this module's layout.
  if "Mixer" in restored_params:
    restored_params["pre_head_layer_norm"] = restored_params["Mixer"].pop(
        "encoder_norm")
    restored_params["stem"] = restored_params.pop("embedding")

    def unflatten_dense(d):
      return {
          "Dense_0": {"bias": d["bias1"].squeeze(),
                      "kernel": d["kernel1"].squeeze()},
          "Dense_1": {"bias": d["bias2"].squeeze(),
                      "kernel": d["kernel2"].squeeze()},
      }

    for k, v in restored_params["Mixer"].items():
      assert k.startswith("encoderblock_"), k
      v["token_mixing"] = unflatten_dense(v.pop("token_mixing_phase_0"))
      v["channel_mixing"] = unflatten_dense(v.pop("channel_mixing_phase_0"))
      restored_params["MixerBlock_" + k[len("encoderblock_"):]] = v
    del restored_params["Mixer"]

  # Keep the random init for any params matched by `dont_load` (e.g. the head).
  restored_params = common.merge_params(restored_params, init_params, dont_load)
  return restored_params
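
# Example usage (a sketch): start from a fresh init and overwrite it with the
# public ImageNet-21k Mixer-B/16 weights, keeping the randomly-initialized
# head. `variables` is assumed to come from `model.init` as sketched above;
# the `dont_load` patterns shown are illustrative.
#
#   params = load(variables["params"], "B-i21k/16", model_cfg=None,
#                 dont_load=("head/kernel", "head/bias"))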


def _stoch_depth_mask(x, drop_p, deterministic, make_rng):
  """Returns a per-example mask (1=keep, 0=drop) for stochastic depth."""
  if not deterministic and drop_p:
    # One Bernoulli draw per example, broadcast over all remaining axes.
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    return 1.0 - jax.random.bernoulli(make_rng("dropout"), drop_p, shape)
  return 1.0
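
# Behavior sketch (illustrative numbers): during training with drop_p=0.1 and
# a batch of 8, the mask has shape (8, 1, 1) and zeroes the residual branch
# for roughly 10% of the examples; in eval (or with drop_p=0) it is just 1.0.
# Note this implementation does not rescale the kept branches by 1/(1 - drop_p).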