# Snapshot 74e8f2f, file size: 5,750 bytes.
# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MLP-Mixer model."""
from typing import Optional, Tuple
from absl import logging
from big_vision import utils
from big_vision.models import common
import einops
import flax.linen as nn
import flax.training.checkpoints
import jax
import jax.numpy as jnp
class MlpBlock(nn.Module):
  """Dense -> GELU -> Dense block whose output width equals its input width."""
  # Number of features produced by the first (hidden) dense layer.
  mlp_dim: int

  @nn.compact
  def __call__(self, x):
    """Applies the two dense layers along the last axis of `x`.

    Args:
      x: Input array; the size of its last axis sets the output width.

    Returns:
      The transformed array, with the same last-axis size as `x`.
    """
    out_features = x.shape[-1]
    # Creation order matters: unnamed submodules are auto-numbered, so the
    # hidden layer must be built before the output layer.
    hidden = nn.gelu(nn.Dense(self.mlp_dim)(x))
    return nn.Dense(out_features)(hidden)
class MixerBlock(nn.Module):
  """One Mixer layer: a token-mixing MLP followed by a channel-mixing MLP.

  Both sub-layers are residual, and each residual branch is multiplied by its
  own stochastic-depth mask.
  """
  # Hidden width of the MLP applied after swapping axes 1 and 2.
  tokens_mlp_dim: int
  # Hidden width of the MLP applied along the last axis.
  channels_mlp_dim: int
  # Probability handed to the stochastic-depth mask for both branches.
  drop_p: float

  @nn.compact
  def __call__(self, x, *, train=False):
    """Runs the block.

    Args:
      x: Input array. Axes 1 and 2 get swapped for token mixing, so it is
        presumably laid out as (batch, tokens, channels) - confirm with caller.
      train: When True, stochastic depth is active (needs a "dropout" rng).

    Returns:
      An array of the same shape as `x`.
    """
    deterministic = not train

    # Token mixing: move the token axis last so the MLP mixes across tokens.
    token_branch = nn.LayerNorm()(x)
    token_branch = jnp.swapaxes(token_branch, 1, 2)
    token_branch = MlpBlock(
        self.tokens_mlp_dim, name="token_mixing")(token_branch)
    token_branch = jnp.swapaxes(token_branch, 1, 2)
    keep = _stoch_depth_mask(x, self.drop_p, deterministic, self.make_rng)
    x = x + token_branch * keep

    # Channel mixing: the MLP acts on the last axis directly.
    channel_branch = nn.LayerNorm()(x)
    channel_branch = MlpBlock(
        self.channels_mlp_dim, name="channel_mixing")(channel_branch)
    keep = _stoch_depth_mask(x, self.drop_p, deterministic, self.make_rng)
    return x + channel_branch * keep
class MlpMixer(nn.Module):
  """MLP-Mixer: patch-embedding stem, a stack of MixerBlocks, optional head."""
  # Kernel size and stride of the patch-embedding convolution ("stem").
  patch_size: Tuple[int, int]
  # Output size of the classification head; a falsy value (None, 0) skips it.
  num_classes: Optional[int]
  # How many MixerBlocks are stacked after the stem.
  num_blocks: int
  # Number of features produced by the stem convolution.
  hidden_dim: int
  # Hidden width forwarded to every block's token-mixing MLP.
  tokens_mlp_dim: int
  # Hidden width forwarded to every block's channel-mixing MLP.
  channels_mlp_dim: int
  # Not read anywhere inside this class.
  model_name: Optional[str] = None
  # Drop probability given to the last block; earlier blocks get a fraction.
  stoch_depth: float = 0.0

  @nn.compact
  def __call__(self, image, *, train=False):
    """Runs the model.

    Args:
      image: Input to the stem convolution; the rearrange pattern below treats
        the convolution output as (n, h, w, c).
      train: Forwarded to every MixerBlock to toggle stochastic depth.

    Returns:
      A pair `(x, out)`: `x` is the logits when `num_classes` is truthy and the
      pooled pre-logits otherwise; `out` is a dict of intermediate results
      keyed by "stem", "input_tokens", "block_<i>", "pre_logits" and, if a
      head exists, "logits".
    """
    out = {}

    patch_embed = nn.Conv(
        self.hidden_dim, self.patch_size, strides=self.patch_size, name="stem")
    x = patch_embed(image)
    out["stem"] = x

    # Flatten the spatial grid into a single token axis.
    x = einops.rearrange(x, "n h w c -> n (h w) c")
    out["input_tokens"] = x

    # The drop probability grows linearly with depth, from 0 at the first
    # block up to `stoch_depth` at the last; max() guards the one-block case.
    last_index = max(self.num_blocks - 1, 1)
    for i in range(self.num_blocks):
      drop_p = (i / last_index) * self.stoch_depth
      block = MixerBlock(self.tokens_mlp_dim, self.channels_mlp_dim, drop_p)
      x = block(x, train=train)
      out[f"block_{i}"] = x

    x = nn.LayerNorm(name="pre_head_layer_norm")(x)
    x = jnp.mean(x, axis=1)  # Average over the token axis.
    out["pre_logits"] = x

    if self.num_classes:
      head = nn.Dense(
          self.num_classes, kernel_init=nn.initializers.zeros, name="head")
      x = head(x)
      out["logits"] = x

    return x, out
def Model(num_classes=None, *, variant=None, **kw):  # pylint: disable=invalid-name
  """Builds an `MlpMixer`, optionally from a variant string such as "L/16".

  Args:
    num_classes: Forwarded to `MlpMixer` as `num_classes`.
    variant: Optional "<size>/<patch>" string, where <size> is one of "S",
      "B", "L", "H" and <patch> is an integer used for both patch dimensions.
    **kw: Extra `MlpMixer` fields. Anything given here takes precedence over
      the values implied by `variant`.

  Returns:
    The configured `MlpMixer` module.

  Raises:
    ValueError: If `variant` does not split into exactly two parts on "/", or
      if the patch part is not an integer.
    KeyError: If the size part of `variant` is not a known preset.
  """
  if variant is not None:
    model_size, patch = variant.split("/")
    patch = int(patch)
    kw.setdefault("patch_size", (patch, patch))

    preset_fields = (
        "hidden_dim", "num_blocks", "channels_mlp_dim", "tokens_mlp_dim")
    presets = {
        "S": (512, 8, 2048, 256),
        "B": (768, 12, 3072, 384),
        "L": (1024, 24, 4096, 512),
        "H": (1280, 32, 5120, 640),
    }
    # Only fill in what the caller did not specify explicitly.
    for field, value in zip(preset_fields, presets[model_size]):
      kw.setdefault(field, value)

  logging.info("Mixer config: %s", kw)
  return MlpMixer(num_classes=num_classes, **kw)
def load(init_params, init_file, model_cfg, dont_load=()):
  """Loads checkpoint parameters and merges them with `init_params`.

  Args:
    init_params: Freshly initialized parameters, handed to
      `common.merge_params` together with the restored ones.
    init_file: Checkpoint location, or one of the shortcut names listed below
      for the published MLP-Mixer checkpoints.
    model_cfg: Unused.
    dont_load: Forwarded to `common.merge_params`; presumably names of
      parameters to keep from `init_params` - confirm against that helper.

  Returns:
    The result of `common.merge_params`.
  """
  del model_cfg

  # Shortcut names for some canonical paper checkpoints:
  shortcuts = {
      # pylint: disable=line-too-long
      # Pretrained models from the MLP-Mixer paper: https://arxiv.org/abs/2105.01601.
      "B-i1k/16": "gs://mixer_models/imagenet1k/Mixer-B_16.npz",
      "L-i1k/16": "gs://mixer_models/imagenet1k/Mixer-L_16.npz",
      "B-i21k/16": "gs://mixer_models/imagenet21k/Mixer-B_16.npz",
      "L-i21k/16": "gs://mixer_models/imagenet21k/Mixer-L_16.npz",
      # pylint: enable=line-too-long
  }
  checkpoint_path = shortcuts.get(init_file, init_file)

  params = utils.load_params(checkpoint_path)
  params = flax.training.checkpoints.convert_pre_linen(params)

  def unflatten_dense(flat):
    """Splits a flat {bias1, kernel1, bias2, kernel2} dict into two layers."""
    first = {
        "bias": flat["bias1"].squeeze(),
        "kernel": flat["kernel1"].squeeze(),
    }
    second = {
        "bias": flat["bias2"].squeeze(),
        "kernel": flat["kernel2"].squeeze(),
    }
    return {"Dense_0": first, "Dense_1": second}

  if "Mixer" in params:
    # Old layout: everything except the stem lives under "Mixer". Hoist each
    # entry to the top level under the name this file's modules use.
    legacy = params.pop("Mixer")
    # Take the norm out first so that only encoder blocks remain in `legacy`.
    params["pre_head_layer_norm"] = legacy.pop("encoder_norm")
    params["stem"] = params.pop("embedding")

    prefix = "encoderblock_"
    for old_name, block in legacy.items():
      assert old_name.startswith(prefix), old_name
      block["token_mixing"] = unflatten_dense(
          block.pop("token_mixing_phase_0"))
      block["channel_mixing"] = unflatten_dense(
          block.pop("channel_mixing_phase_0"))
      params["MixerBlock_" + old_name[len(prefix):]] = block

  # possibly use the random init for some of the params (such as, the head).
  return common.merge_params(params, init_params, dont_load)
def _stoch_depth_mask(x, drop_p, deterministic, make_rng):
if not deterministic and drop_p:
shape = (x.shape[0],) + (1,) * (x.ndim - 1)
return 1.0 - jax.random.bernoulli(make_rng("dropout"), drop_p, shape)
return 1.0
|