

A Swin Transformer V2 image classification model, trained on ImageNet-1k.

Disclaimer: This is a port of the torchvision model weights to the Apple MLX framework.

How to use

pip install mlx-image

Here is how to use this model for image classification:

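import mlx.core as mx

from mlxim.model import create_model
from mlxim.io import read_rgb
from mlxim.transform import ImageNetTransform

# Preprocess the image at the model's 256x256 input resolution
transform = ImageNetTransform(train=False, img_size=256)
x = transform(read_rgb("cat.png"))
# Add a leading batch dimension
x = mx.expand_dims(x, 0)

model = create_model("swin_v2_small_patch4_window8_256")

# Unnormalized class scores for the batch
logits = model(x)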
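
The logits are unnormalized class scores. A minimal sketch of turning them into top-k predictions follows. It assumes the logits have first been converted to a NumPy array (for example with `np.array(logits)`), and `top_k_predictions` is a hypothetical helper, not part of mlx-image. Random values stand in for real model output.

```python
import numpy as np


def top_k_predictions(logits, k=5):
    """Return, per image, the k (class_index, probability) pairs with the highest scores."""
    # Hypothetical helper for illustration; not part of mlx-image.
    logits = np.atleast_2d(np.asarray(logits, dtype=np.float64))
    # Numerically stable softmax over the class axis
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    probs = exp / exp.sum(axis=-1, keepdims=True)
    # Class indices sorted by descending probability, truncated to k
    top = np.argsort(probs, axis=-1)[:, ::-1][:, :k]
    return [[(int(i), float(p[i])) for i in row] for row, p in zip(top, probs)]


# Stand-in for real model output: one image, 1000 ImageNet-1k classes
fake_logits = np.random.default_rng(0).normal(size=(1, 1000))
for class_index, prob in top_k_predictions(fake_logits, k=5)[0]:
    print(class_index, round(prob, 4))
```

The class indices still have to be mapped to ImageNet-1k label names, which this sketch leaves out.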

You can also use the embeddings from the layer before the classification head:

import mlx.core as mx

from mlxim.model import create_model
from mlxim.io import read_rgb
from mlxim.transform import ImageNetTransform

# Preprocess the image at the model's 256x256 input resolution
transform = ImageNetTransform(train=False, img_size=256)
x = transform(read_rgb("cat.png"))
# Add a leading batch dimension
x = mx.expand_dims(x, 0)

# first option: create the model without a classification head
model = create_model("swin_v2_small_patch4_window8_256", num_classes=0)

embeds = model(x)

# second option: keep the head and request the features explicitly
model = create_model("swin_v2_small_patch4_window8_256")

embeds = model.get_features(x)
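
One common use of such embeddings is comparing two images by cosine similarity. The sketch below assumes each embedding has been converted to a NumPy array (for example with `np.array(embeds)`). `cosine_similarity` is a hypothetical helper, not part of mlx-image, and the short vectors are placeholders for real embeddings.

```python
import numpy as np


def cosine_similarity(a, b):
    """Cosine similarity between two embedding vectors, in the range [-1, 1]."""
    # Hypothetical helper for illustration; not part of mlx-image.
    a = np.asarray(a, dtype=np.float64).ravel()
    b = np.asarray(b, dtype=np.float64).ravel()
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    if denom == 0.0:
        # Similarity is undefined for a zero vector; report 0.0 by convention here
        return 0.0
    return float(a @ b / denom)


# Placeholder vectors standing in for the embeddings of two images
emb_a = np.array([0.2, 0.9, -0.4])
emb_b = np.array([0.1, 0.8, -0.5])
print(cosine_similarity(emb_a, emb_b))
```

Values near 1 indicate that the two embeddings point in nearly the same direction.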

Model Comparison

Explore the metrics of this model in mlx-image model results.

Model size: 49.8M params

Dataset used to train mlx-vision/swin_v2_small_patch4_window8_256-mlxim