---
datasets:
- Matthijs/snacks
model-index:
- name: matteopilotto/vit-base-patch16-224-in21k-snacks
  results:
  - task:
      type: image-classification
      name: Image Classification
    dataset:
      name: Matthijs/snacks
      type: Matthijs/snacks
      config: default
      split: test
    metrics:
    - name: Accuracy
      type: accuracy
      value: 0.8928571428571429
      verified: true
    - name: Precision Macro
      type: precision
      value: 0.8990033704680036
      verified: true
    - name: Precision Micro
      type: precision
      value: 0.8928571428571429
      verified: true
    - name: Precision Weighted
      type: precision
      value: 0.8972398709051788
      verified: true
    - name: Recall Macro
      type: recall
      value: 0.8914608843537415
      verified: true
    - name: Recall Micro
      type: recall
      value: 0.8928571428571429
      verified: true
    - name: Recall Weighted
      type: recall
      value: 0.8928571428571429
      verified: true
    - name: F1 Macro
      type: f1
      value: 0.892544821273258
      verified: true
    - name: F1 Micro
      type: f1
      value: 0.8928571428571429
      verified: true
    - name: F1 Weighted
      type: f1
      value: 0.8924168605019522
      verified: true
    - name: loss
      type: loss
      value: 0.479541540145874
      verified: true
---

# Vision Transformer fine-tuned on the `Matthijs/snacks` dataset

Vision Transformer (ViT) model pre-trained on ImageNet-21k and fine-tuned on [**Matthijs/snacks**](https://huggingface.co/datasets/Matthijs/snacks) for 5 epochs using various data augmentation transformations from `torchvision`.

The model achieves **94.97%** accuracy on the validation set and **94.43%** on the test set.

## Data augmentation pipeline

The code block below shows the transformations applied during pre-processing to augment the original dataset. The augmented images were generated on the fly with the `set_transform` method (see the usage sketch after the code block).

```python
from transformers import ViTFeatureExtractor
from torchvision.transforms import (
    Compose,
    Normalize,
    Resize,
    RandomResizedCrop,
    RandomHorizontalFlip,
    RandomAdjustSharpness,
    ToTensor,
)

checkpoint = 'google/vit-base-patch16-224-in21k'
feature_extractor = ViTFeatureExtractor.from_pretrained(checkpoint)

# transformations on the training set:
# random crops, flips and sharpness jitter at the model's input resolution,
# normalized with the mean/std the checkpoint was pre-trained with
train_aug_transforms = Compose([
    RandomResizedCrop(size=feature_extractor.size),
    RandomHorizontalFlip(p=0.5),
    RandomAdjustSharpness(sharpness_factor=5, p=0.5),
    ToTensor(),
    Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std),
])

# transformations on the validation/test set: deterministic resize + normalize only
valid_aug_transforms = Compose([
    Resize(size=(feature_extractor.size, feature_extractor.size)),
    ToTensor(),
    Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std),
])
```
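The snippet below is a minimal sketch of how these pipelines can be attached to the dataset with `set_transform`. It assumes the `datasets` library, `train`/`validation`/`test` splits, and an `image` column of PIL images; the preprocessing functions and the `pixel_values` key are illustrative, not part of the original card.

```python
from datasets import load_dataset

dataset = load_dataset('Matthijs/snacks')

def preprocess_train(batch):
    # apply the stochastic training augmentations to each PIL image in the batch
    batch['pixel_values'] = [train_aug_transforms(img.convert('RGB')) for img in batch['image']]
    return batch

def preprocess_valid(batch):
    # deterministic resize + normalization for evaluation
    batch['pixel_values'] = [valid_aug_transforms(img.convert('RGB')) for img in batch['image']]
    return batch

# transforms run lazily, each time an example is accessed
dataset['train'].set_transform(preprocess_train)
dataset['validation'].set_transform(preprocess_valid)
dataset['test'].set_transform(preprocess_valid)
```

Because `set_transform` applies the function lazily on access, each training epoch sees freshly sampled crops and flips rather than a fixed augmented copy of the data.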
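## How to use

As a hedged usage sketch (not part of the original card), the fine-tuned checkpoint can be loaded through the standard `pipeline` API; the image path below is a hypothetical placeholder.

```python
from transformers import pipeline

# load the fine-tuned checkpoint from the Hub
classifier = pipeline(
    'image-classification',
    model='matteopilotto/vit-base-patch16-224-in21k-snacks',
)

# 'snack.jpg' is a hypothetical placeholder for a local image file
print(classifier('snack.jpg'))
```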