
Model Card for ViT Fine-Tuned on CIFAR-10

This is a toy experiment in fine-tuning a Vision Transformer (ViT) with Hugging Face Transformers.

Model Details

The model was fine-tuned on CIFAR-10 for 1000 steps and reached 98.7% accuracy on the test split.
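Test accuracy here is top-1 accuracy: the fraction of the 10,000 CIFAR-10 test images whose highest-scoring logit matches the ground-truth label. The card does not include its evaluation script, so the following is only a minimal plain-Python sketch of that metric; the helper name `top1_accuracy` and the toy inputs are hypothetical.

```python
from typing import Sequence


def top1_accuracy(logits: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Return the fraction of rows whose argmax equals the ground-truth label.

    Hypothetical helper: `logits` holds one row of class scores per image,
    `labels` holds the matching class indices (0-9 for CIFAR-10).
    """
    if len(logits) != len(labels):
        raise ValueError("logits and labels must have the same length")
    if not labels:
        raise ValueError("cannot compute accuracy on an empty set")
    correct = 0
    for row, label in zip(logits, labels):
        # Index of the highest-scoring class, like logits.argmax(-1) in PyTorch
        predicted = max(range(len(row)), key=row.__getitem__)
        correct += int(predicted == label)
    return correct / len(labels)


# Toy example: three images, three classes (CIFAR-10 would have ten)
example_logits = [[0.1, 2.0, -1.0], [1.5, 0.2, 0.3], [0.0, 0.1, 0.9]]
example_labels = [1, 0, 1]
print(f"{top1_accuracy(example_logits, example_labels):.4f}")  # 0.6667
```

In the toy example the first two predictions are correct and the third is not, giving 2/3.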

Model Description

  • Developed by: verypro
  • Model type: Vision Transformer
  • License: MIT
  • Fine-tuned from model: google/vit-base-patch16-224
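The card does not say how the 1000-class ImageNet head of google/vit-base-patch16-224 was replaced with a 10-class CIFAR-10 head. One common approach in Transformers is to pass `num_labels`, `id2label`/`label2id`, and `ignore_mismatched_sizes=True` to `from_pretrained`. The sketch below assumes that approach, and also assumes the label order matches torchvision's `CIFAR10.classes`, which the usage example relies on. `CIFAR10_CLASSES`, `build_label_maps`, and `load_base_model` are hypothetical names, not part of this repository.

```python
# Class order as returned by torchvision's datasets.CIFAR10(...).classes
CIFAR10_CLASSES = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]


def build_label_maps(classes):
    """Build the id2label / label2id dicts stored in the model config."""
    id2label = {i: name for i, name in enumerate(classes)}
    label2id = {name: i for i, name in enumerate(classes)}
    return id2label, label2id


def load_base_model(checkpoint="google/vit-base-patch16-224"):
    """Load the ImageNet ViT and swap its 1000-way head for a 10-way head.

    Assumption: this mirrors a typical setup, not necessarily the author's.
    The import is deferred so build_label_maps works without transformers.
    """
    from transformers import ViTForImageClassification

    id2label, label2id = build_label_maps(CIFAR10_CLASSES)
    return ViTForImageClassification.from_pretrained(
        checkpoint,
        num_labels=len(CIFAR10_CLASSES),
        id2label=id2label,
        label2id=label2id,
        # Classifier shapes differ (1000 vs 10), so the head is re-initialized
        ignore_mismatched_sizes=True,
    )


id2label, label2id = build_label_maps(CIFAR10_CLASSES)
print(id2label[3], label2id["cat"])  # cat 3
```

Storing the label maps in the config lets `model.config.id2label` name predictions without relying on torchvision's class list.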

Uses

from transformers import ViTImageProcessor, ViTForImageClassification
from torchvision import datasets

# Load the fine-tuned model and its image processor
image_processor = ViTImageProcessor.from_pretrained('verypro/vit-base-patch16-224-cifar10')
model = ViTForImageClassification.from_pretrained('verypro/vit-base-patch16-224-cifar10')


# Load the CIFAR-10 test split (downloaded to ./data on first run)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True)

sample = test_dataset[0]
image = sample[0]
gt_label = sample[1]

# Save the original image and print its ground-truth label
image.save("original.png")
print(f"Ground truth class: '{test_dataset.classes[gt_label]}'")

inputs = image_processor(image, return_tensors="pt")
outputs = model(**inputs)

logits = outputs.logits
print(logits)

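predicted_class_idx = logits.argmax(-1).item()
# Assumes the model's label order matches torchvision's CIFAR10.classes
predicted_class_label = test_dataset.classes[predicted_class_idx]
# Note: "confidence" here is the raw logit of the top class, not a probability
print(f"Predicted class: '{predicted_class_label}', confidence: {logits[0, predicted_class_idx]:.2f}")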

The output of the code above should look like this:

Ground truth class: 'cat'
tensor([[-1.1497, -0.1080, -0.7349,  9.2517, -1.3094,  0.5403, -0.9521, -1.0223,
         -1.4102, -1.5389]], grad_fn=<AddmmBackward0>)
Predicted class: 'cat', confidence: 9.25
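The value printed as `confidence` is the raw logit of the top class (9.25 above), not a probability. To turn logits into probabilities you can apply softmax over the class dimension, which is what `torch.softmax(logits, dim=-1)` does in PyTorch. Below is a dependency-free sketch applied to the logits from the sample output; the `softmax` helper is hypothetical.

```python
import math


def softmax(logits):
    """Numerically stable softmax: subtract the max before exponentiating."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Logits copied from the sample output above (class index 3 is 'cat')
sample_logits = [-1.1497, -0.1080, -0.7349, 9.2517, -1.3094,
                 0.5403, -0.9521, -1.0223, -1.4102, -1.5389]
probs = softmax(sample_logits)
top = max(range(len(probs)), key=probs.__getitem__)
print(f"Predicted index: {top}, probability: {probs[top]:.4f}")
```

Softmax does not change the ranking, so the top class stays index 3 ('cat'); for this sample its probability is above 0.999.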