
outputs

This model is a fine-tuned version of openai/clip-vit-base-patch32 on the CIFAR10 dataset. It achieves the following results on the evaluation set:

  • Loss: 0.8115
  • Accuracy: 0.8255

Model description

Training and evaluation data

Training procedure

Training hyperparameters

The following hyperparameters were used during training:

  • learning_rate: 0.0002
  • train_batch_size: 10
  • eval_batch_size: 8
  • seed: 42
  • optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
  • lr_scheduler_type: linear
  • num_epochs: 4
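With lr_scheduler_type: linear, the learning rate falls from 0.0002 to zero over the run. The sketch below shows that shape under two assumptions: there is no warmup (the card lists none), and the run has 19,200 optimizer steps in total (the last step in the training results table). The function name `linear_lr` is hypothetical. It mirrors a linear warmup-then-decay schedule and is not the library's own code.

```python
def linear_lr(step, base_lr=2e-4, total_steps=19200, warmup_steps=0):
    """Learning rate at a given optimizer step under a linear schedule.

    Ramps up linearly over warmup_steps (assumed 0 here), then decays
    linearly so that it reaches 0 at total_steps and stays there.
    """
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    remaining = max(0, total_steps - step)
    return base_lr * remaining / max(1, total_steps - warmup_steps)


# Learning rate at the start, after each of the first two epochs, and at the end.
for s in (0, 4800, 9600, 19200):
    print(s, linear_lr(s))
```

Under these assumptions the learning rate is halved at the midpoint of training (step 9,600) and reaches zero at the last logged step.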

Training results

Training Loss Epoch Step Validation Loss Accuracy
1.7258 0.02 100 1.6999 0.8048
1.669 0.04 200 1.6798 0.8055
1.6704 0.06 300 1.6599 0.8053
1.6655 0.08 400 1.6407 0.8047
1.5754 0.1 500 1.6223 0.809
1.6159 0.12 600 1.6040 0.8068
1.5663 0.15 700 1.5858 0.8073
1.5426 0.17 800 1.5677 0.8095
1.5794 0.19 900 1.5506 0.808
1.5504 0.21 1000 1.5342 0.8035
1.554 0.23 1100 1.5179 0.802
1.4831 0.25 1200 1.5022 0.7972
1.4718 0.27 1300 1.4867 0.7955
1.5206 0.29 1400 1.4716 0.796
1.4534 0.31 1500 1.4567 0.7963
1.3932 0.33 1600 1.4427 0.7875
1.4635 0.35 1700 1.4289 0.789
1.4339 0.38 1800 1.4151 0.793
1.4492 0.4 1900 1.4016 0.7973
1.4369 0.42 2000 1.3881 0.8018
1.4007 0.44 2100 1.3754 0.801
1.3697 0.46 2200 1.3627 0.8025
1.3298 0.48 2300 1.3505 0.8048
1.2809 0.5 2400 1.3386 0.8068
1.2989 0.52 2500 1.3272 0.8067
1.2958 0.54 2600 1.3159 0.81
1.3072 0.56 2700 1.3048 0.8097
1.2545 0.58 2800 1.2943 0.809
1.2722 0.6 2900 1.2834 0.8112
1.2628 0.62 3000 1.2732 0.8102
1.2357 0.65 3100 1.2632 0.8105
1.3189 0.67 3200 1.2532 0.8093
1.2465 0.69 3300 1.2436 0.8097
1.2579 0.71 3400 1.2342 0.8087
1.1963 0.73 3500 1.2249 0.8085
1.1701 0.75 3600 1.2159 0.8092
1.2117 0.77 3700 1.2069 0.8113
1.1907 0.79 3800 1.1984 0.8112
1.1903 0.81 3900 1.1902 0.8115
1.2357 0.83 4000 1.1821 0.8115
1.1924 0.85 4100 1.1738 0.8117
1.1914 0.88 4200 1.1657 0.8133
1.1536 0.9 4300 1.1580 0.8148
1.1893 0.92 4400 1.1505 0.8158
1.1811 0.94 4500 1.1433 0.8158
1.0182 0.96 4600 1.1358 0.8165
1.0396 0.98 4700 1.1287 0.8158
1.1502 1.0 4800 1.1217 0.816
1.1764 1.02 4900 1.1147 0.8158
1.1508 1.04 5000 1.1080 0.8152
1.0518 1.06 5100 1.1015 0.8155
1.0648 1.08 5200 1.0952 0.816
1.1631 1.1 5300 1.0889 0.8153
1.0629 1.12 5400 1.0826 0.8152
1.1151 1.15 5500 1.0771 0.815
1.1377 1.17 5600 1.0711 0.8145
1.0353 1.19 5700 1.0652 0.8158
1.068 1.21 5800 1.0594 0.815
1.0834 1.23 5900 1.0538 0.8162
1.0002 1.25 6000 1.0483 0.8165
1.0024 1.27 6100 1.0428 0.817
1.0609 1.29 6200 1.0376 0.817
1.0901 1.31 6300 1.0324 0.816
1.0772 1.33 6400 1.0275 0.8173
0.9434 1.35 6500 1.0226 0.817
0.9692 1.38 6600 1.0178 0.8157
1.0461 1.4 6700 1.0131 0.8155
1.0583 1.42 6800 1.0086 0.8143
0.9369 1.44 6900 1.0042 0.8157
1.0685 1.46 7000 0.9998 0.8152
1.062 1.48 7100 0.9955 0.8153
1.0394 1.5 7200 0.9912 0.8142
1.031 1.52 7300 0.9870 0.8157
0.9556 1.54 7400 0.9829 0.8155
0.9846 1.56 7500 0.9789 0.8152
0.9995 1.58 7600 0.9750 0.8158
1.0273 1.6 7700 0.9711 0.8163
0.9383 1.62 7800 0.9674 0.817
0.951 1.65 7900 0.9634 0.8163
0.9457 1.67 8000 0.9598 0.8167
1.012 1.69 8100 0.9563 0.816
0.9683 1.71 8200 0.9529 0.8158
0.9582 1.73 8300 0.9495 0.8157
0.9005 1.75 8400 0.9461 0.8162
0.888 1.77 8500 0.9428 0.8175
0.9267 1.79 8600 0.9396 0.8168
0.9298 1.81 8700 0.9364 0.8168
1.0072 1.83 8800 0.9334 0.8167
0.9425 1.85 8900 0.9303 0.8158
0.9729 1.88 9000 0.9273 0.8168
0.9104 1.9 9100 0.9244 0.8175
0.9153 1.92 9200 0.9216 0.817
0.9115 1.94 9300 0.9188 0.8165
0.9079 1.96 9400 0.9161 0.8168
0.8453 1.98 9500 0.9133 0.8175
0.8323 2.0 9600 0.9107 0.817
0.9071 2.02 9700 0.9080 0.8183
0.9331 2.04 9800 0.9054 0.8185
0.886 2.06 9900 0.9029 0.8193
0.8562 2.08 10000 0.9006 0.8183
0.8904 2.1 10100 0.8980 0.8193
0.8247 2.12 10200 0.8956 0.8188
0.8114 2.15 10300 0.8934 0.8202
0.96 2.17 10400 0.8912 0.8198
0.9326 2.19 10500 0.8889 0.8198
0.8057 2.21 10600 0.8867 0.8195
0.8266 2.23 10700 0.8846 0.8188
0.7909 2.25 10800 0.8823 0.82
0.886 2.27 10900 0.8803 0.8192
0.8691 2.29 11000 0.8783 0.8193
0.8676 2.31 11100 0.8763 0.8187
0.8147 2.33 11200 0.8744 0.819
0.7723 2.35 11300 0.8725 0.8195
0.9222 2.38 11400 0.8705 0.8188
0.9692 2.4 11500 0.8687 0.8195
0.8792 2.42 11600 0.8669 0.8188
0.939 2.44 11700 0.8650 0.8193
0.9093 2.46 11800 0.8633 0.8188
0.7794 2.48 11900 0.8616 0.8182
0.8572 2.5 12000 0.8599 0.8182
0.9035 2.52 12100 0.8582 0.8185
0.8063 2.54 12200 0.8566 0.8193
0.8935 2.56 12300 0.8550 0.8195
0.7991 2.58 12400 0.8535 0.8192
0.856 2.6 12500 0.8520 0.8195
0.8374 2.62 12600 0.8505 0.8197
0.8418 2.65 12700 0.8490 0.8203
0.9232 2.67 12800 0.8475 0.8208
0.8335 2.69 12900 0.8462 0.8207
0.8659 2.71 13000 0.8449 0.8205
0.9798 2.73 13100 0.8435 0.8205
0.7288 2.75 13200 0.8423 0.8205
0.9086 2.77 13300 0.8411 0.821
0.7912 2.79 13400 0.8398 0.8205
0.8675 2.81 13500 0.8386 0.8202
0.8045 2.83 13600 0.8374 0.8198
0.8421 2.85 13700 0.8362 0.8202
0.7453 2.88 13800 0.8350 0.8202
0.7348 2.9 13900 0.8339 0.8203
0.8977 2.92 14000 0.8328 0.8205
0.859 2.94 14100 0.8318 0.821
0.8571 2.96 14200 0.8307 0.8212
0.8158 2.98 14300 0.8297 0.8215
0.8635 3.0 14400 0.8287 0.8215
0.9095 3.02 14500 0.8277 0.8215
0.8491 3.04 14600 0.8268 0.8217
0.9136 3.06 14700 0.8259 0.8223
0.8652 3.08 14800 0.8250 0.8218
0.9299 3.1 14900 0.8242 0.8215
0.8259 3.12 15000 0.8233 0.8215
0.775 3.15 15100 0.8225 0.8222
0.801 3.17 15200 0.8217 0.8217
0.8535 3.19 15300 0.8209 0.8215
0.7973 3.21 15400 0.8202 0.8217
0.8937 3.23 15500 0.8195 0.8213
0.7632 3.25 15600 0.8188 0.821
0.8117 3.27 15700 0.8181 0.8212
0.8941 3.29 15800 0.8174 0.8217
0.802 3.31 15900 0.8168 0.8225
0.8303 3.33 16000 0.8161 0.8217
0.8264 3.35 16100 0.8155 0.8218
0.8411 3.38 16200 0.8149 0.8213
0.9378 3.4 16300 0.8143 0.8218
0.8514 3.42 16400 0.8138 0.8217
0.7313 3.44 16500 0.8133 0.8222
0.8238 3.46 16600 0.8128 0.8218
0.7876 3.48 16700 0.8123 0.8222
0.8364 3.5 16800 0.8118 0.8222
0.7049 3.52 16900 0.8114 0.8222
0.9101 3.54 17000 0.8109 0.8218
0.7984 3.56 17100 0.8105 0.822
0.85 3.58 17200 0.8101 0.8218
0.8677 3.6 17300 0.8098 0.822
0.8797 3.62 17400 0.8094 0.8218
0.7847 3.65 17500 0.8091 0.8222
0.8415 3.67 17600 0.8088 0.8218
0.8702 3.69 17700 0.8085 0.8222
0.8979 3.71 17800 0.8082 0.8222
0.8387 3.73 17900 0.8080 0.8222
0.8467 3.75 18000 0.8077 0.822
0.8729 3.77 18100 0.8075 0.822
0.8291 3.79 18200 0.8073 0.8222
0.7897 3.81 18300 0.8072 0.8222
0.8039 3.83 18400 0.8070 0.822
0.771 3.85 18500 0.8069 0.8223
0.7704 3.88 18600 0.8067 0.8223
0.7695 3.9 18700 0.8066 0.8223
0.8958 3.92 18800 0.8066 0.8223
0.8342 3.94 18900 0.8065 0.8223
0.8725 3.96 19000 0.8064 0.8225
0.8657 3.98 19100 0.8064 0.8225
0.779 4.0 19200 0.8064 0.8225
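The Step and Epoch columns can be cross-checked against the hyperparameters with some simple arithmetic. In the sketch below, `epoch_at` is a hypothetical helper. It assumes a single device and no gradient accumulation, since the card does not state either. Under those assumptions, 19,200 steps over 4 epochs at batch size 10 implies 48,000 training examples. If Andron00e/CIFAR10-custom contains all 60,000 CIFAR-10 images (also an assumption), that is exactly the 80% training split that `train_test_split(test_size=0.2)` produces in the example further down. Note that the example uses `per_device_train_batch_size=16`, which would give 3,000 steps per epoch rather than 4,800. The logged run therefore appears to match the batch size of 10 listed above.

```python
# Values copied from the card: num_epochs and train_batch_size from the
# hyperparameter list, total_steps from the last row of the table.
total_steps = 19200
num_epochs = 4
train_batch_size = 10

# Optimizer steps per epoch.
steps_per_epoch = total_steps // num_epochs

# Training examples implied by that step count, assuming a single device
# and no gradient accumulation (neither is stated in the card).
implied_train_examples = steps_per_epoch * train_batch_size


def epoch_at(step):
    """Fractional epoch for a logged step, rounded to 2 decimals like the Epoch column."""
    return round(step / steps_per_epoch, 2)


print(steps_per_epoch, implied_train_examples)  # 4800 48000
print(epoch_at(100), epoch_at(700), epoch_at(19200))  # 0.02 0.15 4.0
```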

Framework versions

  • Transformers 4.35.2
  • Pytorch 2.1.0+cu118
  • Datasets 2.15.0
  • Tokenizers 0.15.0

Example of usage

Simple demo for Google Colab

```python
!pip install datasets transformers[torch] accelerate -U
!git clone https://github.com/Andron00e/CLIPForImageClassification
%cd CLIPForImageClassification/clip_for_classification

import numpy as np
import torch
from datasets import DatasetDict, load_dataset, load_metric
from transformers import CLIPProcessor, Trainer, TrainingArguments
from modeling_clipforimageclassification import CLIPForImageClassification

# Load the processor and the fine-tuned model with a 10-class head.
processor = CLIPProcessor.from_pretrained("Andron00e/CLIPForImageClassification-v1")
model = CLIPForImageClassification.from_pretrained("Andron00e/CLIPForImageClassification-v1", 10)

# Split the data 80/10/10 into train, validation and test sets.
dataset = load_dataset("Andron00e/CIFAR10-custom")
dataset = dataset["train"].train_test_split(test_size=0.2)
val_test = dataset["test"].train_test_split(test_size=0.5)
dataset = DatasetDict({
    "train": dataset["train"],
    "validation": val_test["train"],
    "test": val_test["test"],
})

classes = {0: "airplane", 1: "automobile", 2: "bird", 3: "cat", 4: "deer", 5: "dog", 6: "frog", 7: "horse", 8: "ship", 9: "truck"}

def transform(example_batch):
    # Encode each image together with the text of its class name.
    inputs = processor(text=[classes[x] for x in example_batch['labels']], images=example_batch['image'], padding=True, return_tensors='pt')
    inputs['labels'] = example_batch['labels']
    return inputs

def collate_fn(batch):
    # Stack the per-example tensors into one batch for the model.
    return {
        'input_ids': torch.stack([x['input_ids'] for x in batch]),
        'attention_mask': torch.stack([x['attention_mask'] for x in batch]),
        'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
        'labels': torch.tensor([x['labels'] for x in batch])
    }

metric = load_metric("accuracy")

def compute_metrics(p):
    # The predicted class is the argmax over the per-class logits.
    return metric.compute(predictions=np.argmax(p.predictions, axis=1), references=p.label_ids)

training_args = TrainingArguments(
  output_dir="./outputs",
  per_device_train_batch_size=16,
  evaluation_strategy="steps",
  num_train_epochs=4,
  fp16=False,
  save_steps=100,
  eval_steps=100,
  logging_steps=10,
  learning_rate=2e-4,
  save_total_limit=2,
  remove_unused_columns=False,
  push_to_hub=False,
  report_to='tensorboard',
  load_best_model_at_end=True,
)

# Apply the preprocessing on the fly whenever examples are accessed.
prepared = dataset.with_transform(transform)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=prepared["train"],
    eval_dataset=prepared["validation"],
    tokenizer=processor,
)

train_results = trainer.train()
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()

# Final evaluation on the held-out test split.
metrics = trainer.evaluate(prepared["test"])
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

%cd ..
%cd ..
```
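The `compute_metrics` function in the demo relies on `load_metric`, which `datasets` deprecated in favor of the separate `evaluate` library and removed in `datasets` 3.0. A minimal NumPy-only equivalent is sketched below. It assumes `p.predictions` is a single array of per-class logits with shape (N, num_classes) and `p.label_ids` holds the integer labels. The `SimpleNamespace` object stands in for the `EvalPrediction` that the Trainer passes in, and the toy logits are made up for illustration. They are not model output.

```python
from types import SimpleNamespace

import numpy as np


def compute_metrics(p):
    """Accuracy from Trainer predictions, computed with NumPy only.

    Assumes p.predictions holds logits of shape (N, num_classes) and
    p.label_ids holds integer labels of shape (N,).
    """
    preds = np.argmax(p.predictions, axis=1)
    return {"accuracy": float((preds == p.label_ids).mean())}


# Toy check: 4 examples, 3 classes (illustrative numbers only).
toy = SimpleNamespace(
    predictions=np.array([[2.0, 0.1, 0.3],
                          [0.2, 1.5, 0.1],
                          [0.9, 0.2, 0.4],
                          [0.1, 0.3, 2.2]]),
    label_ids=np.array([0, 1, 2, 2]),
)
print(compute_metrics(toy))  # {'accuracy': 0.75}
```

The returned dictionary uses the same `"accuracy"` key as the metric in the demo, so it can be passed to the `Trainer` as `compute_metrics` without other changes.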
Model size: 151M params (Safetensors, tensor type F32)