|
--- |
|
license: cc-by-nc-4.0 |
|
pipeline_tag: image-classification |
|
--- |
|
# Hiera (Tiny) |
|
|
|
Hiera is a hierarchical vision transformer that is a much more efficient alternative to previous hierarchical models such as ConvNeXt and Swin. |
|
Vanilla transformer architectures (Dosovitskiy et al., 2020) are popular, simple, and scalable, and they enable pretraining strategies such as MAE (He et al., 2022). |
|
However, because they use the same spatial resolution and number of channels throughout the network, ViTs make inefficient use of their parameters. This |

is in contrast to prior “hierarchical” or “multi-scale” models (e.g., Krizhevsky et al. (2012); He et al. (2016)), which use fewer channels but higher spatial resolution in early stages |

with simpler features, and more channels but lower spatial resolution later in the model with more complex features. |
|
However, these hierarchical models add complex, specialized modules to reach state-of-the-art accuracy on ImageNet-1k, and that overhead makes them slower in practice. |
|
Hiera addresses this by stripping out those specialized modules and instead teaching the model spatial biases through MAE pretraining. |
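The hierarchical design described above can be sketched in a few lines. This is an illustrative example, not Hiera's exact configuration: the stem size, base channel count, and number of stages below are hypothetical values chosen to show the pattern of halving resolution while doubling channels.

```python
# Illustrative sketch of a hierarchical ("multi-scale") backbone:
# each stage halves spatial resolution and doubles the channel count.
# The specific numbers here are hypothetical, not Hiera's actual config.
def stage_shapes(input_size=224, patch=4, base_channels=96, num_stages=4):
    """Return a (resolution, channels) tuple per stage, assuming a 4x
    patchify stem and 2x downsampling between stages."""
    res = input_size // patch
    ch = base_channels
    shapes = []
    for _ in range(num_stages):
        shapes.append((res, ch))
        res //= 2
        ch *= 2
    return shapes

print(stage_shapes())  # [(56, 96), (28, 192), (14, 384), (7, 768)]
```

Early stages run at high resolution with few channels (cheap, simple features); later stages run at low resolution with many channels (expensive, complex features), which is exactly the parameter allocation a plain ViT lacks.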
|
![image/png](https://cdn-uploads.huggingface.co/production/uploads/6141a88b3a0ec78603c9e784/ogkud4qc564bPX3f0bGXO.png) |
|
|
|
## How to Use |
|
|
|
Here is how to use this model to classify an image from the COCO 2017 dataset into one of the 1,000 ImageNet classes: |
|
First, clone the repository and install `timm`: |
|
```bash |
|
git lfs install |
|
git clone https://huggingface.co/merve/hiera-tiny-ft-224-in1k |
|
pip install timm |
|
cd hiera-tiny-ft-224-in1k |
|
``` |
|
|
|
```python |
|
from torchvision import transforms |
|
from torchvision.transforms.functional import InterpolationMode |
|
from PIL import Image |
|
import hiera |
|
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD |
|
import requests |
|
import sys |
|
sys.path.append("..") |
|
|
|
# load the Tiny variant fine-tuned on ImageNet-1k after MAE pretraining |

model = hiera.hiera_tiny_224(pretrained=True, checkpoint="mae_in1k_ft_in1k") |
|
input_size = 224 |
|
url = 'http://images.cocodataset.org/val2017/000000039769.jpg' |
|
image = Image.open(requests.get(url, stream=True).raw) |
|
|
|
# preprocess the image |
|
transform_list = [ |
|
transforms.Resize(int((256 / 224) * input_size), interpolation=InterpolationMode.BICUBIC), |
|
transforms.CenterCrop(input_size) |
|
] |
|
transform_vis = transforms.Compose(transform_list) |
|
transform_norm = transforms.Compose(transform_list + [ |
|
transforms.ToTensor(), |
|
transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD), |
|
]) |
|
img_vis = transform_vis(image) |
|
img_norm = transform_norm(image) |
|
|
|
# get the predicted ImageNet class index |
|
out = model(img_norm[None, ...]) |
|
# tabby cat |
|
out.argmax(dim=-1).item() |
|
``` |
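`argmax` only gives a class index; to report probabilities or the top few candidates, you can apply a softmax to the logits. The helper below is a hypothetical illustration in pure Python (in practice you would call `out.softmax(dim=-1).topk(k)` on the model's output tensor, e.g. with `logits = out[0].tolist()`):

```python
import math

def topk_predictions(logits, k=5):
    """Return the k highest-scoring class indices with their softmax
    probabilities. `logits` is a flat list of per-class scores
    (hypothetical helper; e.g. pass out[0].tolist())."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    ranked = sorted(range(len(logits)), key=lambda i: probs[i], reverse=True)
    return [(i, probs[i]) for i in ranked[:k]]

# toy logits: class 2 has the largest score, so it ranks first
print(topk_predictions([0.1, 1.0, 4.0, 0.5], k=2))
```

To turn an index like 281 into a human-readable label such as "tabby cat", map it through an ImageNet-1k class-name list (for example, the `imagenet_classes` lists shipped with common vision repos).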
|
|
|
You can try the fine-tuned model [here](https://colab.research.google.com/drive/1WIYWaCWiv5QK-MpNr-bEvqgTS1DIW19Z?usp=sharing). |