library_name: keras-hub
tags:
  - image-segmentation
  - keras

Model Overview

This is a Keras implementation of the MixTransformer (MiT) architecture, intended for use as the backbone of the SegFormer architecture. The model is supported in both KerasCV and KerasHub. KerasCV is no longer actively developed, so prefer KerasHub.

References:

Links

Installation

Keras and KerasHub can be installed with:

pip install -U -q keras-hub
pip install -U -q "keras>=3"

JAX, TensorFlow, and PyTorch come preinstalled in Kaggle Notebooks. For instructions on installing them in another environment, see the Keras Getting Started page.
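Keras 3 chooses its backend from the `KERAS_BACKEND` environment variable, which is read when `keras` is first imported. The sketch below shows one way to set it from Python. The `set_keras_backend` helper is hypothetical (it is not part of Keras), the list of names covers only the three backends mentioned above, and `"jax"` is just an example choice.

```python
import os

# Backends named in this card; Keras may support others.
SUPPORTED_BACKENDS = ("jax", "tensorflow", "torch")


def set_keras_backend(name: str) -> None:
    """Select the Keras backend. Call this before the first `import keras`."""
    if name not in SUPPORTED_BACKENDS:
        raise ValueError(
            f"Unknown backend {name!r}; expected one of {SUPPORTED_BACKENDS}"
        )
    os.environ["KERAS_BACKEND"] = name


set_keras_backend("jax")
```

Setting the variable after `keras` has already been imported has no effect on the running process, so this belongs at the very top of a script or notebook.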

Presets

The following model checkpoints are provided by the Keras team. Weights have been ported from https://dl.fbaipublicfiles.com/segment_anything/. Full code examples for each are available below.

| Preset name | Parameters | Description |
|---|---|---|
| mit_b0_ade20k_512 | 3.32M | MiT (MixTransformer) model with 8 transformer blocks, trained on the ADE20K dataset with an input resolution of 512x512 pixels. |
| mit_b1_ade20k_512 | 13.16M | MiT (MixTransformer) model with 8 transformer blocks, trained on the ADE20K dataset with an input resolution of 512x512 pixels. |
| mit_b2_ade20k_512 | 24.20M | MiT (MixTransformer) model with 16 transformer blocks, trained on the ADE20K dataset with an input resolution of 512x512 pixels. |
| mit_b3_ade20k_512 | 44.08M | MiT (MixTransformer) model with 28 transformer blocks, trained on the ADE20K dataset with an input resolution of 512x512 pixels. |
| mit_b4_ade20k_512 | 60.85M | MiT (MixTransformer) model with 41 transformer blocks, trained on the ADE20K dataset with an input resolution of 512x512 pixels. |
| mit_b5_ade20k_640 | 81.45M | MiT (MixTransformer) model with 52 transformer blocks, trained on the ADE20K dataset with an input resolution of 640x640 pixels. |
| mit_b0_cityscapes_1024 | 3.32M | MiT (MixTransformer) model with 8 transformer blocks, trained on the Cityscapes dataset with an input resolution of 1024x1024 pixels. |
| mit_b1_cityscapes_1024 | 13.16M | MiT (MixTransformer) model with 8 transformer blocks, trained on the Cityscapes dataset with an input resolution of 1024x1024 pixels. |
| mit_b2_cityscapes_1024 | 24.20M | MiT (MixTransformer) model with 16 transformer blocks, trained on the Cityscapes dataset with an input resolution of 1024x1024 pixels. |
| mit_b3_cityscapes_1024 | 44.08M | MiT (MixTransformer) model with 28 transformer blocks, trained on the Cityscapes dataset with an input resolution of 1024x1024 pixels. |
| mit_b4_cityscapes_1024 | 60.85M | MiT (MixTransformer) model with 41 transformer blocks, trained on the Cityscapes dataset with an input resolution of 1024x1024 pixels. |
| mit_b5_cityscapes_1024 | 81.45M | MiT (MixTransformer) model with 52 transformer blocks, trained on the Cityscapes dataset with an input resolution of 1024x1024 pixels. |
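The preset names listed above follow a `mit_<variant>_<dataset>_<resolution>` pattern. As a small illustration, the sketch below splits a name into those parts. `PresetInfo` and `parse_preset_name` are hypothetical helpers written for this card, not part of KerasHub.

```python
from typing import NamedTuple


class PresetInfo(NamedTuple):
    variant: str     # MiT size, e.g. "b2"
    dataset: str     # training dataset, e.g. "cityscapes"
    resolution: int  # side length of the square training input, in pixels


def parse_preset_name(name: str) -> PresetInfo:
    """Split a name such as "mit_b2_cityscapes_1024" into its parts."""
    parts = name.split("_")
    if len(parts) != 4 or parts[0] != "mit":
        raise ValueError(f"Unexpected preset name: {name!r}")
    _, variant, dataset, resolution = parts
    return PresetInfo(variant, dataset, int(resolution))


info = parse_preset_name("mit_b5_ade20k_640")
print(info.variant, info.dataset, info.resolution)
```

This can be handy for picking a resize target that matches the resolution a checkpoint was trained at.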

Example Usage

Using the class with a backbone:

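import keras
import keras_hub
import numpy as np

images = np.ones(shape=(1, 96, 96, 3))
labels = np.zeros(shape=(1, 96, 96, 1))
backbone = keras_hub.models.MiTBackbone.from_preset("mit_b2_cityscapes_1024")

# The backbone only produces feature maps. As an illustrative assumption,
# attach a minimal binary segmentation head and resize back to the label size.
model = keras.Sequential(
    [
        backbone,
        keras.layers.Conv2D(1, kernel_size=1, activation="sigmoid"),
        keras.layers.Resizing(96, 96),
    ]
)

# Evaluate model
model(images)

# Train model
model.compile(
    optimizer="adam",
    loss=keras.losses.BinaryCrossentropy(from_logits=False),
    metrics=["accuracy"],
)
model.fit(images, labels, epochs=3)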
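The backbone on its own returns downsampled feature maps rather than a per-pixel mask. In the SegFormer design, the four MiT stages reduce the input by factors of 4, 8, 16, and 32. The sketch below works out the resulting spatial sizes for a given input. The stride tuple and the `pyramid_sizes` helper are assumptions for illustration (taken from the SegFormer design rather than read from a preset's configuration), and the ceiling division assumes "same"-style padding.

```python
import math

# Assumed cumulative downsampling factor of each of the four MiT stages.
STAGE_STRIDES = (4, 8, 16, 32)


def pyramid_sizes(height: int, width: int) -> list[tuple[int, int]]:
    """Return the (height, width) of each stage's feature map."""
    return [
        (math.ceil(height / s), math.ceil(width / s)) for s in STAGE_STRIDES
    ]


for stride, (h, w) in zip(STAGE_STRIDES, pyramid_sizes(96, 96)):
    print(f"stride {stride:>2}: {h} x {w}")
```

Under these assumptions, a 96x96 input leaves only a 3x3 grid at the final stage, which is why a segmentation head has to upsample before comparing against full-resolution labels.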

Example Usage with Hugging Face URI

Using the class with a backbone:

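import keras
import keras_hub
import numpy as np

images = np.ones(shape=(1, 96, 96, 3))
labels = np.zeros(shape=(1, 96, 96, 1))
# The "hf://" prefix loads the preset from the Hugging Face Hub.
backbone = keras_hub.models.MiTBackbone.from_preset("hf://keras/mit_b2_cityscapes_1024")

# The backbone only produces feature maps. As an illustrative assumption,
# attach a minimal binary segmentation head and resize back to the label size.
model = keras.Sequential(
    [
        backbone,
        keras.layers.Conv2D(1, kernel_size=1, activation="sigmoid"),
        keras.layers.Resizing(96, 96),
    ]
)

# Evaluate model
model(images)

# Train model
model.compile(
    optimizer="adam",
    loss=keras.losses.BinaryCrossentropy(from_logits=False),
    metrics=["accuracy"],
)
model.fit(images, labels, epochs=3)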